# Clade Labelling

Development workbook for clade-labelling functions. 

### Input
 - A tree file in which some of the tip names are reference names.
 - A list of the reference names.

### Output
A dataframe listing each tip name and its clade label.

A tip name may have no clade label if it is too far from even its nearest reference to be meaningfully assigned to that reference.

In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.collections import PatchCollection, LineCollection
import matplotlib.path as mpath
import matplotlib.lines as mlines
from matplotlib import gridspec

import numpy as np
import pandas as pd
import os
# NOTE(review): machine-specific absolute path. Relative file reads later in
# this notebook (e.g. "H3_test_tree.nex") depend on this working directory;
# presumably the local helper modules imported below do too -- TODO confirm.
os.chdir("/users/dten0001/Google Drive/baltic3/examples")
import sys
from functools import reduce
import subprocess
import time

# Set random state
my_randomstate = np.random.RandomState(12345)

import xio
import my_utils as xu
import biophylo_utils as bpu
import baltic3 as bt
import baltic3_utils as btu

from Bio import Phylo


def tip_to_tip_distance(my_tree, tip1, tip2):
    """Return the length of the path tip1 -> MRCA -> tip2.

    PARAMS
    ------
    my_tree: biopython tree object. Must provide common_ancestor() and
        distance().
    tip1, tip2: the two tips to compare. Documented as biopython Clade
        objects; the caller in this file passes tip names (str) instead.

    RETURNS
    -------
    float. Distance from tip1 to the most recent common ancestor of the
    two tips, plus the distance from tip2 to that same ancestor.
    """
    # Find the shared ancestor once, then measure each leg against it.
    ancestor = my_tree.common_ancestor(tip1, tip2)
    legs = [my_tree.distance(tip, ancestor) for tip in (tip1, tip2)]

    return legs[0] + legs[1]


def get_tt_dist_from_tree(my_tree, ref_names_ls, verbose=True):
    """Tabulate the distance from every non-reference tip to every reference.

    PARAMS
    ------
    my_tree: biopython tree object.
    ref_names_ls: list of str. Names of the reference tips.
    verbose: bool. If True, print the tip counts and the elapsed time.

    RETURNS
    -------
    df: pandas DataFrame with one row per non-reference tip name. The first
        column, "tip_name", holds the tip name. It is followed by one
        "dist_to_<reference name>" column per entry of ref_names_ls, in the
        same order, holding the value of tip_to_tip_distance().
    """
    t0 = time.time()

    # Read the tip names from the tree that was passed in, not from a
    # notebook-level global.
    names_ls = [x.name for x in my_tree.get_terminals()]
    names_set = set(names_ls)

    # A missing reference is only reported here; the distance lookups below
    # are still attempted for it.
    for ref_nm in ref_names_ls:
        if ref_nm not in names_set:
            print("WARNING: %s not found in input tree!" % ref_nm)

    # Non-reference names, de-duplicated but kept in tree order so the row
    # order of the result is reproducible between runs.
    ref_names_set = set(ref_names_ls)
    non_ref_names_ls = list(
        dict.fromkeys(nm for nm in names_ls if nm not in ref_names_set)
    )

    if verbose:
        print("No. of reference tip names = %s" % len(ref_names_ls))
        print("No. of non-reference tip names = %s" % len(non_ref_names_ls))

    # One row per non-reference tip: its name, then its distance to each
    # reference in ref_names_ls order.
    contents = []
    for nm in non_ref_names_ls:
        dists = [tip_to_tip_distance(my_tree, nm, ref_nm)
                 for ref_nm in ref_names_ls]
        contents.append([nm] + dists)

    col_names = ["tip_name"] + ["dist_to_" + ref_nm for ref_nm in ref_names_ls]

    df = pd.DataFrame(data=contents,
                      columns=col_names)

    if verbose:
        print("Done in %.2fs" % (time.time() - t0))

    return df


def get_clade_labels(df):
    """Placeholder for the clade-labelling step (work in progress).

    Intended to take the dataframe returned by get_tt_dist_from_tree() and
    add, per row, the minimum distance and the corresponding nearest
    reference as extra columns. None of that is implemented yet: the
    function only reads the column index, prints "WIP" and returns None.
    """
    # TODO: take the row-wise minimum over the distance columns only; the
    # "tip_name" column has to be left out of that computation.
    column_index = df.columns  # read but not used yet
    print("WIP")
    return None
    

In [4]:
# Load the test tree, then midpoint-root and ladderize it.
tree = Phylo.read("H3_test_tree.nex", 'nexus')
tree.root_at_midpoint()
tree.ladderize()

# Reference tips are those whose first "|"-delimited field is "'A/H3N2"
# (the leading quote character is part of the tip name).
names_ls = [tip.name for tip in tree.get_terminals()]
ref_names_ls = [nm for nm in names_ls if nm.partition("|")[0] == "'A/H3N2"]

# Distance table: one row per non-reference tip, one column per reference.
df = get_tt_dist_from_tree(tree, ref_names_ls)

No. of reference tip names = 7
No. of non-reference tip names = 593
Done in 7.70s


In [5]:
# Everything after the first column is treated as a distance column.
df_cols = df.columns[1:].tolist()
dist_only = df[df_cols]

# Name of the column holding the smallest distance, and that distance itself.
df["min_col"] = dist_only.idxmin(axis=1)
df["min_dist"] = dist_only.min(axis=1)

In [6]:
df

Unnamed: 0,tip_name,dist_to_'A/H3N2|A/Victoria/361/2011|Australia|2011-10-24',dist_to_'A/H3N2|A/SouthAustralia/30/2012|Australia|2012-05-10',dist_to_'A/H3N2|A/Stockholm/6/2014|Sweden|2014-02-06',dist_to_'A/H3N2|A/Switzerland/9715293/2013|Switzerland|2013-12-06',dist_to_'A/H3N2|A/SouthAustralia/55/2014|Australia|2014-06-29',dist_to_'A/H3N2|A/HongKong/4801/2014|HongKong|2014',dist_to_'A/H3N2|A/Singapore/INFIMH-16-0019/2016|Singapore|2016-06-14',min_col
0,'H3N2|A/Tasmania/3/2016|70811-HA|H3|TAS|2016-0...,0.020387,0.030282,0.021011,0.022247,0.022247,0.009877,0.006795,dist_to_'A/H3N2|A/Singapore/INFIMH-16-0019/201...
1,'H3N2|A/Sydney/224/2015|70795-HA|H3|NSW/StMary...,0.019144,0.029039,0.019768,0.021004,0.021004,0.008634,0.016676,dist_to_'A/H3N2|A/HongKong/4801/2014|HongKong|...
2,'H3N2|A/Victoria/21/2015|70701-HA|H3|VIC/Kew|2...,0.017905,0.027800,0.018529,0.019765,0.019765,0.007395,0.015437,dist_to_'A/H3N2|A/HongKong/4801/2014|HongKong|...
3,'H3N2|A/Sydney/1005/2014|70521-HA|H3|NSW/Yagoo...,0.013582,0.023477,0.010516,0.011752,0.011752,0.015458,0.027192,dist_to_'A/H3N2|A/Stockholm/6/2014|Sweden|2014...
4,'H3N2|A/SouthAustralia/108/2016|70908-HA|H3|SA...,0.020348,0.030243,0.009864,0.009860,0.009860,0.022224,0.033958,dist_to_'A/H3N2|A/Switzerland/9715293/2013|Swi...
5,'H3N2|A/Sydney/53/2015|70681-HA|H3|NSW/Sydney|...,0.017901,0.027796,0.018525,0.019761,0.019761,0.007391,0.012965,dist_to_'A/H3N2|A/HongKong/4801/2014|HongKong|...
6,'H3N2|A/Newcastle/190/2016|70925-HA|H3|NSW|201...,0.021012,0.030907,0.021636,0.022872,0.022872,0.010502,0.013602,dist_to_'A/H3N2|A/HongKong/4801/2014|HongKong|...
7,'H3N2|A/Brisbane/299/2016|70957-HA|H3|QLD/More...,0.029646,0.039541,0.030270,0.031506,0.031506,0.019136,0.012360,dist_to_'A/H3N2|A/Singapore/INFIMH-16-0019/201...
8,'H3N2|A/Canberra/127/2013|70480-HA|H3|ACT/Canb...,0.009235,0.019130,0.006169,0.007405,0.007405,0.011111,0.022845,dist_to_'A/H3N2|A/Stockholm/6/2014|Sweden|2014...
9,'H3N2|A/Sydney/233/2012|70434-HA|H3|NSW/Hallid...,0.008016,0.010511,0.017272,0.018508,0.018508,0.018524,0.030258,dist_to_'A/H3N2|A/Victoria/361/2011|Australia|...


In [None]:
# NOTE(review): seg_ls is not defined in the visible part of this notebook.
# df_dict is never filled in (the first loop only prints), so the
# df_dict[seg] lookup below cannot succeed as written.
df_dict = {}
for seg in seg_ls:
    print(seg)


def _merge_on_v_name(left, right):
    # Join two per-segment tables on their shared "v_name" column.
    return pd.merge(left, right, on="v_name")


# Keep the three columns of interest per segment and suffix the label and
# distance columns with the segment name so they stay distinct after merging.
df_dict2 = {}
for seg in seg_ls:
    d_temp = df_dict[seg][["v_name", "clade_label", "tt_dist_min"]]
    d_temp.columns = ["v_name", "clade_label_" + seg, "tt_dist_min_" + seg]
    df_dict2[seg] = d_temp

# Fold the per-segment tables into one table, in seg_ls order.
df_ls = [df_dict2[seg] for seg in seg_ls]
dz = reduce(_merge_on_v_name, df_ls)