|
| 1 | +#!/usr/bin/env python |
"""Arrow drawing example for the new fancy_arrow facilities.

Code contributed by: Rob Knight <rob@spot.colorado.edu>

usage:

  python arrow_demo.py realistic|full|sample|extreme [length|width|alpha]

The optional second argument selects which arrow property encodes the
rate; it defaults to 'length'.

"""
| 12 | +from pylab import * |
| 13 | + |
# Switch matplotlib's text rendering to LaTeX (usetex) at import time; the
# base and rate labels drawn by make_arrow_plot are '$...$' math strings.
# NOTE(review): presumably this requires a working LaTeX installation --
# confirm before running on a new machine.
rc('text', usetex=True)
# Numbered rate name -> two-letter base-pair key.
rates_to_bases = dict(
    r1='AT', r2='TA', r3='GA', r4='AG', r5='CA', r6='AC',
    r7='GT', r8='TG', r9='CT', r10='TC', r11='GC', r12='CG',
)
# Base-pair key -> numbered rate name (the inverse of rates_to_bases).
numbered_bases_to_rates = dict(
    (pair, rate) for rate, pair in rates_to_bases.items())
# Base-pair key -> lettered rate name, i.e. 'r' followed by the pair itself.
lettered_bases_to_rates = dict(
    (pair, 'r' + pair) for pair in rates_to_bases.values())
def add_dicts(d1, d2):
    """Return a new dict holding d1's items overlaid with d2's items.

    Neither argument is modified.  Where both dicts share a key, the
    value from d2 is the one kept.
    """
    merged = d1.copy()
    merged.update(d2)
    return merged
| 24 | + |
def make_arrow_plot(data, size=4, display='length', shape='right',
                    max_arrow_width=0.03, arrow_sep=0.02, alpha=0.5,
                    normalize_data=False, ec=None, labelcolor=None,
                    head_starts_at_zero=True,
                    rate_labels=lettered_bases_to_rates,
                    **kwargs):
    """Makes an arrow plot on the current pylab figure.

    The four bases are drawn as text at the corners of the unit square
    and one arrow is drawn for each of the twelve ordered base pairs.

    Parameters:

    data: dict with probabilities for the bases ('A', 'T', 'G', 'C') and
        pair transitions ('AT', 'TA', ...).  The dict is never modified.
    size: size of the graph in inches.
    display: 'length', 'width', or 'alpha' for arrow property to change.
        Any other value draws every arrow at full length, width and alpha.
    shape: 'full', 'left', or 'right' for full or half arrows.
    max_arrow_width: maximum width of an arrow, data coordinates.
    arrow_sep: separation between arrows in a pair, data coordinates.
    alpha: maximum opacity of arrows, default 0.5.
    normalize_data: if true, arrows are drawn from a rescaled copy of
        data in which the largest pair rate maps to the maximum arrow
        size.  Base label sizes always use the unscaled values.  When no
        pair rate is positive the data is used unscaled.
    ec: edge color of the arrows; defaults to each arrow's face color.
    labelcolor: color of the rate labels; defaults to the arrow's face
        color.
    head_starts_at_zero: passed through to arrow().
    rate_labels: dict mapping each pair to its label text; the first
        character is typeset as the symbol, the rest as its subscript.

    **kwargs is accepted so existing callers keep working, but it is not
    currently forwarded to arrow().

    Raises ValueError if a label position is not one of 'left',
    'absolute', 'right' or 'center'.
    """

    xlim(-0.5, 1.5)
    ylim(-0.5, 1.5)
    fig = gcf()
    # set_figsize_inches is the historical Figure method name; prefer
    # set_size_inches when the installed matplotlib provides it.
    resize = getattr(fig, 'set_size_inches', None)
    if resize is None:
        resize = fig.set_figsize_inches
    resize(size, size)
    xticks([])
    yticks([])
    max_text_size = size*12
    min_text_size = size
    label_text_size = size*2.5
    text_params = {'ha': 'center', 'va': 'center', 'family': 'sans-serif',
                   'fontweight': 'bold'}
    r2 = sqrt(2)

    # unit direction of each arrow
    deltas = {
        'AT': (1, 0),
        'TA': (-1, 0),
        'GA': (0, 1),
        'AG': (0, -1),
        'CA': (-1/r2, 1/r2),
        'AC': (1/r2, -1/r2),
        'GT': (1/r2, 1/r2),
        'TG': (-1/r2, -1/r2),
        'CT': (0, 1),
        'TC': (0, -1),
        'GC': (1, 0),
        'CG': (-1, 0),
        }

    # face color of each arrow
    colors = {
        'AT': 'r',
        'TA': 'k',
        'GA': 'g',
        'AG': 'r',
        'CA': 'b',
        'AC': 'r',
        'GT': 'g',
        'TG': 'k',
        'CT': 'b',
        'TC': 'k',
        'GC': 'g',
        'CG': 'b',
        }

    # where along its arrow each rate label is placed
    label_positions = {
        'AT': 'center',
        'TA': 'center',
        'GA': 'center',
        'AG': 'center',
        'CA': 'left',
        'AC': 'left',
        'GT': 'left',
        'TG': 'left',
        'CT': 'center',
        'TC': 'center',
        'GC': 'center',
        'CG': 'center',
        }

    def do_fontsize(k):
        # font size grows with the square root of the base probability,
        # clamped to [min_text_size, max_text_size]
        return float(clip(max_text_size*sqrt(data[k]),
                          min_text_size, max_text_size))

    text(0, 1, '$A_3$', color='r', size=do_fontsize('A'), **text_params)
    text(1, 1, '$T_3$', color='k', size=do_fontsize('T'), **text_params)
    text(0, 0, '$G_3$', color='g', size=do_fontsize('G'), **text_params)
    text(1, 0, '$C_3$', color='b', size=do_fontsize('C'), **text_params)

    arrow_h_offset = 0.25  # data coordinates, empirically determined
    max_arrow_length = 1 - 2*arrow_h_offset

    max_head_width = 2.5*max_arrow_width
    max_head_length = 2*max_arrow_width
    arrow_params = {'length_includes_head': True, 'shape': shape,
                    'head_starts_at_zero': head_starts_at_zero}
    sf = 0.6  # max arrow size represents this in data coords

    d = (r2/2 + arrow_h_offset - 0.5)/r2  # distance for diags
    r2v = arrow_sep/r2  # offset for diags

    # tuple of x, y for start position
    positions = {
        'AT': (arrow_h_offset, 1+arrow_sep),
        'TA': (1-arrow_h_offset, 1-arrow_sep),
        'GA': (-arrow_sep, arrow_h_offset),
        'AG': (arrow_sep, 1-arrow_h_offset),
        'CA': (1-d-r2v, d-r2v),
        'AC': (d+r2v, 1-d+r2v),
        'GT': (d-r2v, d+r2v),
        'TG': (1-d+r2v, 1-d-r2v),
        'CT': (1-arrow_sep, arrow_h_offset),
        'TC': (1+arrow_sep, 1-arrow_h_offset),
        'GC': (arrow_h_offset, arrow_sep),
        'CG': (1-arrow_h_offset, -arrow_sep),
        }

    # Values used to size the arrows.  Normalisation works on a copy so
    # the caller's dict is left untouched.
    rates = data
    if normalize_data:
        # find maximum value for rates, i.e. where keys are 2 chars long
        pair_rates = [v for k, v in data.items() if len(k) == 2]
        max_val = max(pair_rates + [0])
        # divide rates by max val, multiply by arrow scale factor; skip
        # when there is nothing positive to scale by (avoids dividing by 0)
        if max_val > 0:
            rates = dict((k, float(v)/max_val*sf) for k, v in data.items())

    def draw_arrow(pair, alpha=alpha, ec=ec, labelcolor=labelcolor):
        # set the length of the arrow
        if display == 'length':
            length = max_head_length + \
                (max_arrow_length - max_head_length)*rates[pair]/sf
        else:
            length = max_arrow_length
        # set the transparency of the arrow
        if display == 'alpha':
            alpha = min(rates[pair]/sf, alpha)
        # set the width of the arrow
        if display == 'width':
            scale = rates[pair]/sf
            width = max_arrow_width*scale
            head_width = max_head_width*scale
            head_length = max_head_length*scale
        else:
            width = max_arrow_width
            head_width = max_head_width
            head_length = max_head_length

        fc = colors[pair]
        ec = ec or fc

        x_scale, y_scale = deltas[pair]
        x_pos, y_pos = positions[pair]
        arrow(x_pos, y_pos, x_scale*length, y_scale*length,
              fc=fc, ec=ec, alpha=alpha, width=width,
              head_width=head_width, head_length=head_length,
              **arrow_params)

        # Label offset in the arrow's own frame: along_arrow is measured
        # from the arrow start along its direction, off_arrow to its side.
        where = label_positions[pair]
        if where == 'left':
            along_arrow, off_arrow = 3*max_arrow_width, 3*max_arrow_width
        elif where == 'absolute':
            along_arrow, off_arrow = max_arrow_length/2.0, 3*max_arrow_width
        elif where == 'right':
            along_arrow, off_arrow = \
                length - 3*max_arrow_width, 3*max_arrow_width
        elif where == 'center':
            along_arrow, off_arrow = length/2.0, 3*max_arrow_width
        else:
            raise ValueError("Got unknown position parameter %s" % where)

        # Rotate that offset into data coordinates using the arrow's unit
        # direction as (cos, sin), then shift to the arrow start.  This is
        # the row vector [along, off] times [[cx, sx], [-sx, cx]].
        sx, cx = y_scale, x_scale
        x = along_arrow*cx - off_arrow*sx + x_pos
        y = along_arrow*sx + off_arrow*cx + y_pos

        orig_label = rate_labels[pair]
        label = r'$%s_{_{\mathrm{%s}}}$' % (orig_label[0], orig_label[1:])

        text(x, y, label, size=label_text_size, ha='center', va='center',
             color=labelcolor or fc)

    for pair in positions:
        draw_arrow(pair)
| 218 | + |
# test data

# Every base at 1 and every ordered two-letter combination of the bases
# (same-base combinations included) at 0.6.
all_on_max = dict((base, 1) for base in 'TCAG')
all_on_max.update(
    (first + second, 0.6) for first in 'TCAG' for second in 'TCAG')
| 222 | + |
# 'realistic' demo data set: the four base values first, then the twelve
# pair rates grouped by the first letter of the pair.
realistic_data = {
    'A': 0.4, 'T': 0.3, 'G': 0.5, 'C': 0.2,
    'AT': 0.4, 'AC': 0.3, 'AG': 0.2,
    'TA': 0.2, 'TC': 0.3, 'TG': 0.4,
    'CT': 0.2, 'CG': 0.3, 'CA': 0.2,
    'GA': 0.1, 'GT': 0.4, 'GC': 0.1,
}
| 241 | + |
# 'extreme' demo data set: the four base values first, then the twelve
# pair rates grouped by the first letter of the pair.
extreme_data = {
    'A': 0.75, 'T': 0.10, 'G': 0.10, 'C': 0.05,
    'AT': 0.6, 'AC': 0.3, 'AG': 0.1,
    'TA': 0.02, 'TC': 0.3, 'TG': 0.01,
    'CT': 0.2, 'CG': 0.5, 'CA': 0.2,
    'GA': 0.1, 'GT': 0.4, 'GC': 0.2,
}
| 260 | + |
# 'sample' demo data set: the four base values first, then the twelve
# pair rates grouped by the first letter of the pair.
sample_data = {
    'A': 0.2137, 'T': 0.3541, 'G': 0.1946, 'C': 0.2376,
    'AT': 0.0228, 'AC': 0.0684, 'AG': 0.2056,
    'TA': 0.0315, 'TC': 0.0629, 'TG': 0.0315,
    'CT': 0.1355, 'CG': 0.0401, 'CA': 0.0703,
    'GA': 0.1824, 'GT': 0.0387, 'GC': 0.1106,
}
| 279 | + |
| 280 | + |
if __name__ == '__main__':
    # Script entry point:
    #   argv[1] (optional) picks the data set, argv[2] (optional) picks the
    #   arrow property that encodes the rate.
    from sys import argv

    # data set name -> (data dict, whether to normalize it before plotting)
    datasets = {
        'full': (all_on_max, False),
        'extreme': (extreme_data, False),
        'realistic': (realistic_data, False),
        'sample': (sample_data, True),
    }
    # A missing or unrecognised name falls back to the 'full' data set, so
    # d and scaled are always bound before plotting.
    if len(argv) > 1 and argv[1] in datasets:
        d, scaled = datasets[argv[1]]
    else:
        d, scaled = datasets['full']

    if len(argv) > 2:
        display = argv[2]
    else:
        display = 'length'

    size = 4
    figure(figsize=(size, size))

    make_arrow_plot(d, display=display, linewidth=0.001, edgecolor=None,
                    normalize_data=scaled, head_starts_at_zero=True,
                    size=size)

    draw()
    savefig('arrows.png')
    # Parenthesised single-argument form prints the same text on both
    # Python 2 and Python 3.
    print('Example saved to file "arrows.png"')
    show()
0 commit comments