Skip to content

Commit de93f2b

Browse files
committed
added rob knights fancy arrow
svn path=/trunk/matplotlib/; revision=2188
1 parent dd525e4 commit de93f2b

File tree

6 files changed

+440
-31
lines changed

6 files changed

+440
-31
lines changed

CHANGELOG

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
2006-03-21 Added Rob Knight's arrow code; see examples/arrow_demo.py - JDH
2+
13
2006-03-20 Added support for masking values with nan's, using ADS's
24
isnan module and the new API. Works for *Agg backends and
35
the postscript backend- DSD

boilerplate.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ def %(func)s(*args, **kwargs):
4747
# these methods are all simple wrappers of Axes methods by the same
4848
# name.
4949
_plotcommands = (
50+
'arrow',
5051
'axhline',
5152
'axhspan',
5253
'axvline',

examples/arrow_demo.py

Lines changed: 313 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,313 @@
1+
#!/usr/bin/env python
2+
"""Arrow drawing example for the new fancy_arrow facilities.
3+
4+
Code contributed by: Rob Knight <rob@spot.colorado.edu>
5+
6+
usage:
7+
8+
python arrow_demo.py realistic|full|sample|extreme
9+
10+
11+
"""
12+
from pylab import *
13+
14+
rc('text', usetex=True)
15+
rates_to_bases={'r1':'AT', 'r2':'TA', 'r3':'GA','r4':'AG','r5':'CA','r6':'AC', \
16+
'r7':'GT', 'r8':'TG', 'r9':'CT','r10':'TC','r11':'GC','r12':'CG'}
17+
numbered_bases_to_rates = dict([(v,k) for k, v in rates_to_bases.items()])
18+
lettered_bases_to_rates = dict([(v, 'r'+v) for k, v in rates_to_bases.items()])
19+
def add_dicts(d1, d2):
20+
"""Adds two dicts and returns the result."""
21+
result = d1.copy()
22+
result.update(d2)
23+
return result
24+
25+
def make_arrow_plot(data, size=4, display='length', shape='right', \
26+
max_arrow_width=0.03, arrow_sep = 0.02, alpha=0.5, \
27+
normalize_data=False, ec=None, labelcolor=None, \
28+
head_starts_at_zero=True, rate_labels=lettered_bases_to_rates,\
29+
**kwargs):
30+
"""Makes an arrow plot.
31+
32+
Parameters:
33+
34+
data: dict with probabilities for the bases and pair transitions.
35+
size: size of the graph in inches.
36+
display: 'length', 'width', or 'alpha' for arrow property to change.
37+
shape: 'full', 'left', or 'right' for full or half arrows.
38+
max_arrow_width: maximum width of an arrow, data coordinates.
39+
arrow_sep: separation between arrows in a pair, data coordinates.
40+
alpha: maximum opacity of arrows, default 0.8.
41+
42+
**kwargs can be anything allowed by a Arrow object, e.g.
43+
linewidth and edgecolor.
44+
"""
45+
46+
xlim(-0.5,1.5)
47+
ylim(-0.5,1.5)
48+
gcf().set_figsize_inches(size,size)
49+
xticks([])
50+
yticks([])
51+
max_text_size = size*12
52+
min_text_size = size
53+
label_text_size = size*2.5
54+
text_params={'ha':'center', 'va':'center', 'family':'sans-serif',\
55+
'fontweight':'bold'}
56+
r2 = sqrt(2)
57+
58+
deltas = {\
59+
'AT':(1,0),
60+
'TA':(-1,0),
61+
'GA':(0,1),
62+
'AG':(0,-1),
63+
'CA':(-1/r2, 1/r2),
64+
'AC':(1/r2, -1/r2),
65+
'GT':(1/r2, 1/r2),
66+
'TG':(-1/r2,-1/r2),
67+
'CT':(0,1),
68+
'TC':(0,-1),
69+
'GC':(1,0),
70+
'CG':(-1,0)
71+
}
72+
73+
colors = {\
74+
'AT':'r',
75+
'TA':'k',
76+
'GA':'g',
77+
'AG':'r',
78+
'CA':'b',
79+
'AC':'r',
80+
'GT':'g',
81+
'TG':'k',
82+
'CT':'b',
83+
'TC':'k',
84+
'GC':'g',
85+
'CG':'b'
86+
}
87+
88+
label_positions = {\
89+
'AT':'center',
90+
'TA':'center',
91+
'GA':'center',
92+
'AG':'center',
93+
'CA':'left',
94+
'AC':'left',
95+
'GT':'left',
96+
'TG':'left',
97+
'CT':'center',
98+
'TC':'center',
99+
'GC':'center',
100+
'CG':'center'
101+
}
102+
103+
104+
def do_fontsize(k):
105+
return float(clip(max_text_size*sqrt(data[k]),\
106+
min_text_size,max_text_size))
107+
108+
A = text(0,1, '$A_3$', color='r', size=do_fontsize('A'), **text_params)
109+
T = text(1,1, '$T_3$', color='k', size=do_fontsize('T'), **text_params)
110+
G = text(0,0, '$G_3$', color='g', size=do_fontsize('G'), **text_params)
111+
C = text(1,0, '$C_3$', color='b', size=do_fontsize('C'), **text_params)
112+
113+
arrow_h_offset = 0.25 #data coordinates, empirically determined
114+
max_arrow_length = 1 - 2*arrow_h_offset
115+
116+
max_arrow_width = max_arrow_width
117+
max_head_width = 2.5*max_arrow_width
118+
max_head_length = 2*max_arrow_width
119+
arrow_params={'length_includes_head':True, 'shape':shape, \
120+
'head_starts_at_zero':head_starts_at_zero}
121+
ax = gca()
122+
sf = 0.6 #max arrow size represents this in data coords
123+
124+
d = (r2/2 + arrow_h_offset - 0.5)/r2 #distance for diags
125+
r2v = arrow_sep/r2 #offset for diags
126+
127+
#tuple of x, y for start position
128+
positions = {\
129+
'AT': (arrow_h_offset, 1+arrow_sep),
130+
'TA': (1-arrow_h_offset, 1-arrow_sep),
131+
'GA': (-arrow_sep, arrow_h_offset),
132+
'AG': (arrow_sep, 1-arrow_h_offset),
133+
'CA': (1-d-r2v, d-r2v),
134+
'AC': (d+r2v, 1-d+r2v),
135+
'GT': (d-r2v, d+r2v),
136+
'TG': (1-d+r2v, 1-d-r2v),
137+
'CT': (1-arrow_sep, arrow_h_offset),
138+
'TC': (1+arrow_sep, 1-arrow_h_offset),
139+
'GC': (arrow_h_offset, arrow_sep),
140+
'CG': (1-arrow_h_offset, -arrow_sep),
141+
}
142+
143+
if normalize_data:
144+
#find maximum value for rates, i.e. where keys are 2 chars long
145+
max_val = 0
146+
for k, v in data.items():
147+
if len(k) == 2:
148+
max_val = max(max_val, v)
149+
#divide rates by max val, multiply by arrow scale factor
150+
for k, v in data.items():
151+
data[k] = v/max_val*sf
152+
153+
def draw_arrow(pair, alpha=alpha, ec=ec, labelcolor=labelcolor):
154+
#set the length of the arrow
155+
if display == 'length':
156+
length = max_head_length+(max_arrow_length-max_head_length)*\
157+
data[pair]/sf
158+
else:
159+
length = max_arrow_length
160+
#set the transparency of the arrow
161+
if display == 'alph':
162+
alpha = min(data[pair]/sf, alpha)
163+
else:
164+
alpha=alpha
165+
#set the width of the arrow
166+
if display == 'width':
167+
scale = data[pair]/sf
168+
width = max_arrow_width*scale
169+
head_width = max_head_width*scale
170+
head_length = max_head_length*scale
171+
else:
172+
width = max_arrow_width
173+
head_width = max_head_width
174+
head_length = max_head_length
175+
176+
fc = colors[pair]
177+
ec = ec or fc
178+
179+
x_scale, y_scale = deltas[pair]
180+
x_pos, y_pos = positions[pair]
181+
arrow(x_pos, y_pos, x_scale*length, y_scale*length, \
182+
fc=fc, ec=ec, alpha=alpha, width=width, head_width=head_width, \
183+
head_length=head_length, **arrow_params)
184+
185+
#figure out coordinates for text
186+
#if drawing relative to base: x and y are same as for arrow
187+
#dx and dy are one arrow width left and up
188+
#need to rotate based on direction of arrow, use x_scale and y_scale
189+
#as sin x and cos x?
190+
sx, cx = y_scale, x_scale
191+
192+
where = label_positions[pair]
193+
if where == 'left':
194+
orig_position = 3*array([[max_arrow_width, max_arrow_width]])
195+
elif where == 'absolute':
196+
orig_position = array([[max_arrow_length/2.0, 3*max_arrow_width]])
197+
elif where == 'right':
198+
orig_position = array([[length-3*max_arrow_width,\
199+
3*max_arrow_width]])
200+
elif where == 'center':
201+
orig_position = array([[length/2.0, 3*max_arrow_width]])
202+
else:
203+
raise ValueError, "Got unknown position parameter %s" % where
204+
205+
206+
207+
M = array([[cx, sx],[-sx,cx]])
208+
coords = matrixmultiply(orig_position, M) + [[x_pos, y_pos]]
209+
x, y = ravel(coords)
210+
orig_label = rate_labels[pair]
211+
label = '$%s_{_{\mathrm{%s}}}$' % (orig_label[0], orig_label[1:])
212+
213+
text(x, y, label, size=label_text_size, ha='center', va='center', \
214+
color=labelcolor or fc)
215+
216+
for p in positions.keys():
217+
draw_arrow(p)
218+
219+
#test data
220+
all_on_max = dict([(i, 1) for i in 'TCAG'] + \
221+
[(i+j, 0.6) for i in 'TCAG' for j in 'TCAG'])
222+
223+
realistic_data = {
224+
'A':0.4,
225+
'T':0.3,
226+
'G':0.5,
227+
'C':0.2,
228+
'AT':0.4,
229+
'AC':0.3,
230+
'AG':0.2,
231+
'TA':0.2,
232+
'TC':0.3,
233+
'TG':0.4,
234+
'CT':0.2,
235+
'CG':0.3,
236+
'CA':0.2,
237+
'GA':0.1,
238+
'GT':0.4,
239+
'GC':0.1,
240+
}
241+
242+
extreme_data = {
243+
'A':0.75,
244+
'T':0.10,
245+
'G':0.10,
246+
'C':0.05,
247+
'AT':0.6,
248+
'AC':0.3,
249+
'AG':0.1,
250+
'TA':0.02,
251+
'TC':0.3,
252+
'TG':0.01,
253+
'CT':0.2,
254+
'CG':0.5,
255+
'CA':0.2,
256+
'GA':0.1,
257+
'GT':0.4,
258+
'GC':0.2,
259+
}
260+
261+
sample_data = {
262+
'A':0.2137,
263+
'T':0.3541,
264+
'G':0.1946,
265+
'C':0.2376,
266+
'AT':0.0228,
267+
'AC':0.0684,
268+
'AG':0.2056,
269+
'TA':0.0315,
270+
'TC':0.0629,
271+
'TG':0.0315,
272+
'CT':0.1355,
273+
'CG':0.0401,
274+
'CA':0.0703,
275+
'GA':0.1824,
276+
'GT':0.0387,
277+
'GC':0.1106,
278+
}
279+
280+
281+
if __name__ == '__main__':
282+
from sys import argv
283+
if len(argv) > 1:
284+
if argv[1] == 'full':
285+
d = all_on_max
286+
scaled = False
287+
elif argv[1] == 'extreme':
288+
d = extreme_data
289+
scaled = False
290+
elif argv[1] == 'realistic':
291+
d = realistic_data
292+
scaled = False
293+
elif argv[1] == 'sample':
294+
d = sample_data
295+
scaled = True
296+
else:
297+
d = all_on_max
298+
scaled=False
299+
if len(argv) > 2:
300+
display = argv[2]
301+
else:
302+
display = 'length'
303+
304+
size = 4
305+
figure(figsize=(size,size))
306+
307+
make_arrow_plot(d, display=display, linewidth=0.001, edgecolor=None,
308+
normalize_data=scaled, head_starts_at_zero=True, size=size)
309+
310+
draw()
311+
savefig('arrows.png')
312+
print 'Example saved to file "arrows.png"'
313+
show()

lib/matplotlib/axes.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from matplotlib.numerix.mlab import flipud, amin, amax
3434

3535
from matplotlib import rcParams
36-
from patches import Patch, Rectangle, Circle, Polygon, Arrow, Wedge, Shadow, bbox_artist
36+
from patches import Patch, Rectangle, Circle, Polygon, Arrow, Wedge, Shadow, FancyArrow, bbox_artist
3737
from table import Table
3838
from text import Text, TextWithDash, _process_text_args
3939
from transforms import Bbox, Point, Value, Affine, NonseparableTransformation
@@ -463,6 +463,12 @@ def _set_lim_and_transforms(self):
463463
if self._sharey:
464464
self.transData.set_funcy(self._sharey.transData.get_funcy())
465465

466+
def arrow(self, x, y, dx, dy, **kwargs):
467+
"""Draws arrow on specified axis from (x,y) to (x+dx,y+dy)."""
468+
a = FancyArrow(x, y, dx, dy, **kwargs)
469+
self.add_artist(a)
470+
return a
471+
466472
def axhline(self, y=0, xmin=0, xmax=1, **kwargs):
467473
"""
468474
AXHLINE(y=0, xmin=0, xmax=1, **kwargs)

0 commit comments

Comments
 (0)