# Tutorial 3 - Tanglegrams

`baltic` will take pretty long with a tree that has thousands of tips. Mitigate this with the `artist collections` objects in `matplotlib`. 

In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.collections import PatchCollection, LineCollection
import matplotlib.path as mpath
import matplotlib.lines as mlines
from matplotlib import gridspec
import matplotlib as mpl
mpl.rcParams['pdf.fonttype'] = 42
mpl.rcParams['ps.fonttype'] = 42

font = {'family' : 'sans-serif',
        'sans-serif': 'Ariel'}
mpl.rc('font', **font)

import numpy as np
import pandas as pd
import os
import sys
from functools import reduce
import subprocess
import time

# Set random state
my_randomstate = np.random.RandomState(12345)

import xio
import patho_tools as pt

import baltic3 as bt
import baltic3_utils as btu
#from Bio import Phylo

my_random_state = np.random.RandomState(12345)
seg_ls = ["H3", "N2", "PB1", "PB2", "PA", "MP", "NP", "NS"]


def get_leaf_coords(tipname, tre):
    """Searches for part of, or all of, a tipname in the list of tipnames in a tree. 
    Returns the first instance if multiple matches exist, but this fails silently."""
    
    found = 0
    tip_x = 0
    tip_y = 0
    for lf in tre.leaves:
        if tipname in lf.name:
            tip_x = lf.height
            tip_y = lf.y
            found = 1
    if found == 0:
        print("ERROR: %s not found!" % tipname)
    return tip_x, tip_y

In [2]:
# Load trees
# The order in which the trees are loaded will be the order that they're drawn in
# This will populate next_seg_dict{} and x_offset_dict{}
t_dict = {}
for seg in seg_ls:
    print(seg)
    fn = "tango_"+seg+"_mls.nex"
    t_dict[seg] = btu.austechia_read_tree(fn, date_delim="|")


H3
Number of objects found in tree string: 1199

Tree height: 0.026984
Tree length: 1.004646
strictly bifurcating tree

Numbers of objects in tree: 1199 (599 nodes and 600 leaves)

Highest tip date: 2017.0082
N2
Number of objects found in tree string: 1199

Tree height: 0.025195
Tree length: 0.949045
strictly bifurcating tree

Numbers of objects in tree: 1199 (599 nodes and 600 leaves)

Highest tip date: 2017.0082
PB1
Number of objects found in tree string: 1199

Tree height: 0.022644
Tree length: 0.730114
strictly bifurcating tree

Numbers of objects in tree: 1199 (599 nodes and 600 leaves)

Highest tip date: 2017.0082
PB2
Number of objects found in tree string: 1199

Tree height: 0.030902
Tree length: 0.713739
strictly bifurcating tree

Numbers of objects in tree: 1199 (599 nodes and 600 leaves)

Highest tip date: 2017.0082
PA
Number of objects found in tree string: 1199

Tree height: 0.021206
Tree length: 0.675799
strictly bifurcating tree

Numbers of objects in tree: 1199 (599 node

In [3]:
# =================================== PARAMS ===================================
# TREE PARAMS
branchWidth=0.5 # line thickness of branches

# compute x_offset_dict from tree heights
x_offset_dict = {seg_ls[0]:0}
interval_dist_multiplier = 1.1
for i in range(1, len(seg_ls)):
    prev_seg = seg_ls[i-1]
    x_offset_dict[seg_ls[i]] = x_offset_dict[prev_seg] + t_dict[prev_seg].treeHeight*interval_dist_multiplier

# Prep next seg dict
next_seg_dict = {}
keys_ls = list(x_offset_dict.keys())
for i in range(len(keys_ls)-1):
    next_seg_dict[keys_ls[i]] = keys_ls[i+1]


clade_cdict = {"A/Singapore/INFIMH-16-0019/2016":"red", 
               'A/SouthAustralia/30/2012':"orange", 
              'A/HongKong/4801/2014':"blue", 
              'A/SouthAustralia/55/2014':"green"}


In [5]:
# Gytis' disentangling block
tip_positions={x:{} for x in seg_ls} ## remember the position of each tip in each tree

for t,tr in enumerate(t_dict.keys()): ## iterate over trees
    cur_tree=t_dict[tr] ## fetch tree object
    for k in cur_tree.Objects:
        if k.branchType=='leaf':
            tip_positions[tr][k.name]=(k.height,k.y) ## remember XY position of tip

cmap=mpl.cm.Spectral



In [10]:
for X in range(10): ## 10 untangling iterations
    print('iteration %d'%(X+1))
    for t,tr in enumerate(seg_ls): ## iterate over each tree
        print(tr)
        ptr=seg_ls[t-1] ## previous tree
        ntr=seg_ls[t] ## next tree
        seg=t_dict[ptr] ## fetch appropriate tree
        nex_seg=t_dict[ntr]
        for k in sorted(nex_seg.Objects,key=lambda q:q.height): ## iterate over branches from most recent to oldest
            if k.branchType=='node': ## can only sort nodes
                #leaves=[[seg.tipMap[tip] for tip in w.leaves] if w.branchType=='node' else [w.name] for w in k.children] ## descendent tips in current order
                leaves = k.leaves
                
                for c in range(len(leaves)):
                    leaves[c]=sorted(leaves[c],key=lambda x:tip_positions[ntr][x][1])
                
                ys=[sorted([tip_positions[ntr][w][1] for w in cl]) for cl in leaves] ## extract y positions of descendents
                merge_ys=[i for s in ys for i in s] ## flatten list of tip y coordinates
                ypos=range(min(merge_ys),max(merge_ys)+1) ## get y positions of tips in current order
                
                order={i:x for i,x in enumerate(leaves)} ## dict of tip order: tip name
                new_order=sorted(order.keys(),key=lambda x:-np.mean([(tip_positions[ptr][order[x][w]][1]-ypos[w]) for w in range(len(order[x]))])) 
                ## get new order by sorting existing order based on y position differences
                
                if new_order!=range(len(leaves)): ## if new order is not current order
                    k.children=[k.children[i] for i in new_order] ## assign new order of child branches
                    nex_seg.drawTree() ## update y positions

                    for w in nex_seg.Objects: ## iterate over objects in next tree
                        if w.branchType=='leaf':
                            tip_positions[ntr][w.name]=(w.height,w.y) ## remember new positions
                
        if t==0: ## if first tree
            trees[segments[t]].drawTree() ## update positions
            lvs=sorted([w for w in trees[segments[t]].Objects if w.branchType=='leaf'],key=lambda x:x.y) ## get leaves in y position order
            
            norm=mpl.colors.Normalize(0,len(lvs))
            pos_colours={w.name:cmap(norm(w.y)) for w in lvs} ## assign colour

iteration 1
H3
H3


KeyError: 'A'

In [12]:
tip_positions["H3"]

{'A/H3N2|A/HongKong/4801/2014|HongKong|2014': (0.010303999999999919, 216),
 'A/H3N2|A/Singapore/INFIMH-16-0019/2016|Singapore|2016-06-14': (0.022037999999999953,
  517),
 'A/H3N2|A/SouthAustralia/30/2012|Australia|2012-05-10': (0.012556999999999914,
  27),
 'A/H3N2|A/SouthAustralia/55/2014|Australia|2014-06-29': (0.01028799999999992,
  95),
 'A/H3N2|A/Stockholm/6/2014|Sweden|2014-02-06': (0.009051999999999916, 78),
 'A/H3N2|A/Switzerland/9715293/2013|Switzerland|2013-12-06': (0.01028799999999992,
  93),
 'A/H3N2|A/Victoria/361/2011|Australia|2011-10-24': (0.002661999999999905, 1),
 'H3N2|A/Brisbane/1/2012|70382-HA|H3|QLD/GoldCoast|2012-01-06': (0.004115999999999915,
  54),
 'H3N2|A/Brisbane/1/2013|70459-HA|H3|QLD/Brisbane|2013-01-03': (0.007192999999999918,
  200),
 'H3N2|A/Brisbane/1/2017|70973-HA|H3|QLD/SunshineCoast|2017-01-02': (0.01778299999999992,
  91),
 'H3N2|A/Brisbane/100/2014|70499-HA|H3|QLD/Brisbane|2014-04-07': (0.01092899999999992,
  225),
 'H3N2|A/Brisbane/1000/2015|7062

In [4]:
fig,ax = plt.subplots(figsize=(20, 9.3),facecolor='w')
# ==================== Draw just the HA tree ====================
preview = 0
patches_ls = []
lines_ls = []

# to store connecting lines between non-reference tips
dotted_lines_ls = []
# to store connecting lines between reference tips
dotted_lines_ref_ls = []

for i in range(len(seg_ls)):
    seg = seg_ls[i]
    print("Drawing " + seg + "...")
    
    # Segment labels
    x_offset = x_offset_dict[seg]
    ax.text(x_offset*1.05, 610, seg, size=20)
    
    for k in t_dict[seg].Objects:
        c = 'k'
        x=k.height
        y=k.y

        xp = k.parent.height
        if x is None: # matplotlib won't plot Nones, like root
            x = x_offset
        if xp==None:
            xp = x + x_offset

        if isinstance(k,bt.leaf) or k.branchType=='leaf':
            # Plot the tips of the references
            if k.name.split("|")[0] == "A/H3N2":
                if k.name.split("|")[1] in list(clade_cdict.keys()):
                    ax.scatter(x+x_offset, y, s=50, alpha=0.25, c=clade_cdict[k.name.split("|")[1]])
                    ax.text(x+x_offset, y, k.name.split("|")[1], size=10, color="k")
                
                # Draw connecting lines between references
                if seg != seg_ls[-1]:
                    next_x_offset = x_offset_dict[next_seg_dict[seg]]
                    x1, y1 = get_leaf_coords(k.name.split("|")[1], t_dict[next_seg_dict[seg]])
                    x1=x1+next_x_offset

                    # Draw connecting lines
                    line = np.array([[x0, y0],[x1, y1]])
                    dotted_lines_ref_ls.append(line)
                
            # ========== Draw connecting lines ==========
            if seg != seg_ls[-1]:
                # Coords of the current tip
                x0 = x+x_offset
                y0 = y
                # Get coords of the tip in the next tree
                next_x_offset = x_offset_dict[next_seg_dict[seg]]
                x1, y1 = get_leaf_coords(k.name.split("|")[1], t_dict[next_seg_dict[seg]])
                x1=x1+next_x_offset

                # Draw connecting lines
                #d_temp = dm2.loc[dm2["tip_name_lookup"]==k.name.split("|")[1]]
                #if len(d_temp) > 0:
                #    clade_label = d_temp["clade_label_iso_name"].values[0]
                #    ax.plot([x0, x1], [y0, y1], c=clade_cdict[clade_label], lw=branchWidth/2, alpha=0.4)


        elif isinstance(k,bt.node) or k.branchType=='node':
            line = np.array([[x+x_offset, k.children[0].y], [x+x_offset, k.children[-1].y]])
            lines_ls.append(line)

        line = np.array([[xp+x_offset, y], [x+x_offset, y]])
        lines_ls.append(line)
        
patch_collection = PatchCollection(patches_ls, color="k", zorder=10)
line_collection = LineCollection(lines_ls, lw=branchWidth,color='k', zorder=10)
ax.add_collection(patch_collection)
ax.add_collection(line_collection)

ax.add_collection(line_collection)
line_collection = LineCollection(dotted_lines_ref_ls, 
                                 lw=branchWidth/1.5,
                                 linestyle="dashed",
                                 color='red', 
                                 alpha=0.5,
                                 zorder=10)
ax.add_collection(line_collection)


ax.set_ylim(-0.1*t_dict["H3"].ySpan,t_dict["H3"].ySpan*1.1)
x_right_edge = (x_offset_dict[seg_ls[-1]]+t_dict[seg_ls[-1]].treeHeight)*1.05
ax.set_xlim(-0.003,x_right_edge)

plt.axis("off")
plt.tight_layout()
plt.savefig("aus_h3n2_tango.pdf", bbox="inches")

plt.show()

Drawing H3...


NameError: name 'get_leaf_coords' is not defined

  (prop.get_family(), self.defaultFamily[fontext]))


ValueError: Image size of 1161x308430 pixels is too large. It must be less than 2^16 in each direction.

<matplotlib.figure.Figure at 0x111193f98>

In [None]:
t_dict["H3"].ySpan