In [10]:
import math
from random import randint

import numpy as np
import pandas as pd

rate_table = pd.read_table("./mutation_rate.tsv")
# Look-up dict keyed by (state, pos) with the per-position rate as value.
mutation_rate = {
    (state, pos): rate
    for state, pos, rate in zip(rate_table["state"], rate_table["pos"], rate_table["rate"])
}
mutation_rate

{(0, 0): 0.09024896265560166,
 (0, 1): 0.020713463751438437,
 (0, 2): 0.09248956884561893,
 (0, 3): 0.12156448202959833,
 (0, 4): 0.1042319749216301,
 (0, 5): 0.08647140864714087,
 (0, 6): 0.03686934023285899,
 (0, 7): 0.16951219512195126,
 (0, 8): 0.19682151589242053,
 (0, 9): 0.06385542168674699,
 (1, 0): 0.7510373443983402,
 (1, 1): 0.9626006904487916,
 (1, 2): 0.8630041724617524,
 (1, 3): 0.7473572938689218,
 (1, 4): 0.8244514106583072,
 (1, 5): 0.910041841004184,
 (1, 6): 0.9301423027166882,
 (1, 7): 0.6621951219512195,
 (1, 8): 0.6381418092909535,
 (1, 9): 0.9319277108433734,
 (2, 0): 0.15871369294605808,
 (2, 1): 0.01668584579976985,
 (2, 2): 0.04450625869262865,
 (2, 3): 0.13107822410147993,
 (2, 4): 0.0713166144200627,
 (2, 5): 0.0034867503486750353,
 (2, 6): 0.032988357050452784,
 (2, 7): 0.16829268292682928,
 (2, 8): 0.16503667481662593,
 (2, 9): 0.004216867469879518}

In [11]:
class TreeNode(object):
    """A node of a binary clustering tree.

    Leaf nodes carry a ``label`` and ``seq``; internal nodes link two children
    through ``left`` / ``right``. ``leaves`` lists the indices of the leaf
    observations contained under this node.
    """

    def __init__(self, label=None, seq=None, parent=None, leaves=None):
        self.label = label
        self.seq = seq
        self.left = None
        self.right = None
        self.parent = parent
        # Build a fresh list per node: a mutable default (``leaves=[]``) would be
        # one list object shared by every node created without ``leaves``.
        self.leaves = [] if leaves is None else leaves

In [12]:
class hcluster(object):
    """Average-linkage agglomerative clustering of cell state strings.

    Every row of ``df`` becomes a leaf. The two closest clusters are merged
    repeatedly until a single tree remains; it is stored in ``self.root`` and
    can be serialised with :meth:`write_newick`.

    Parameters
    ----------
    mutation_rate : dict
        Maps ``(state, position)`` -> rate, where ``state`` is ``int`` of one
        character of a state string and ``position`` its 0-based index. Rates
        are passed to ``math.log`` and therefore must be > 0.
    df : pandas.DataFrame
        One row per observation with string columns ``cell`` and ``state``.
        All ``state`` strings must have the same length.
    """

    def __init__(self, mutation_rate, df):
        self.mutation_rate = mutation_rate
        self.df = df
        # Leaves are numbered by *position* (0..n-1) rather than by index label
        # so they line up with the positional distance matrix self.D even when
        # df has a non-default index.
        self.all_tree_nodes = [
            TreeNode(label='_'.join([row["cell"], row["state"]]),
                     seq=row["state"],
                     leaves=[pos])
            for pos, (_, row) in enumerate(df.iterrows())
        ]
        self.root = TreeNode(leaves=list(range(len(self.all_tree_nodes))))
        # Weight of the mismatch term; h_clustering() overrides it. Set here so
        # initialize_dist_matrix() also works when called on its own.
        self.lambda1 = 0.1

    @staticmethod
    def _check_same_length(s1, s2):
        """Raise ValueError unless the two state strings have equal length."""
        if len(s1) != len(s2):
            raise ValueError("state strings differ in length: %d vs %d" % (len(s1), len(s2)))

    def initialize_dist_matrix(self, df):
        """Return the symmetric leaf-to-leaf distance matrix for ``df``.

        The result is ``D1 + self.lambda1 * D2``, where ``D1[i, j]`` is
        :meth:`sequence_dist` and ``D2[i, j]`` is :meth:`sequence_dist2` of the
        i-th and j-th ``state`` strings. The diagonal is 0.
        """
        states = list(df["state"])  # positional access, independent of df.index
        n = len(states)
        dist_matrix = np.zeros((n, n))
        dist_matrix2 = np.zeros((n, n))
        for i in range(n):
            for j in range(i):
                dist_matrix[i, j] = self.sequence_dist(states[i], states[j])
                dist_matrix2[i, j] = self.sequence_dist2(states[i], states[j])
        # Only the lower triangle was filled; mirror it.
        dist_matrix = dist_matrix + dist_matrix.T
        dist_matrix2 = dist_matrix2 + dist_matrix2.T
        return dist_matrix + self.lambda1 * dist_matrix2

    def hamming_dist(self, s1, s2):
        """Number of positions at which ``s1`` and ``s2`` differ (indexed over ``len(s1)``)."""
        return sum([s1[i] != s2[i] for i in range(len(s1))])

    def hamming1_dist(self, s1, s2):
        """Number of positions where the strings differ or ``s1`` is not ``"1"``."""
        return sum([s1[i] != s2[i] or s1[i] != "1" for i in range(len(s1))])

    def sequence_dist(self, s1, s2):
        """Log-rate similarity of two equal-length state strings.

        Sums ``log(mutation_rate[(int(c), i)])`` over the positions ``i`` where
        both strings carry the same character ``c``; differing positions add 0.
        Assuming rates below 1, every shared state lowers the value, so more
        similar sequences get smaller (more negative) distances.

        Raises
        ------
        ValueError
            If the strings differ in length.
        """
        self._check_same_length(s1, s2)
        return sum(
            math.log(self.mutation_rate[(int(a), i)]) if a == b else 0
            for i, (a, b) in enumerate(zip(s1, s2))
        )

    def sequence_dist2(self, s1, s2):
        """Mismatch penalty of two equal-length state strings.

        Sums ``-log(rate(a, i) * rate(b, i))`` over the positions ``i`` where
        the characters ``a`` and ``b`` differ; equal positions add 0.

        Raises
        ------
        ValueError
            If the strings differ in length.
        """
        self._check_same_length(s1, s2)
        return sum(
            0 if a == b
            else -math.log(self.mutation_rate[int(a), i] * self.mutation_rate[int(b), i])
            for i, (a, b) in enumerate(zip(s1, s2))
        )

    def merge_node(self, D):
        """Return ``(i, j)``, the indices of the closest pair of distinct clusters.

        Works on a copy with the diagonal set to +inf so a cluster is never
        paired with itself. The caller's matrix (self.D on the first merge)
        is left untouched. Ties resolve to the first minimum in row-major
        order (``np.argmin``).
        """
        D = np.array(D, dtype=float)  # copy
        np.fill_diagonal(D, np.inf)
        i, j = np.unravel_index(np.argmin(D), D.shape)
        return int(i), int(j)

    def update_dist_matrix(self, clusters):
        """Average-linkage distance matrix between ``clusters``.

        Entry ``(i, j)`` is the mean of ``self.D`` over every leaf pair drawn
        from ``clusters[i].leaves`` x ``clusters[j].leaves``; the diagonal is 0.
        """
        def cluster_dist(c1, c2):
            # Plain sequential sum (same order as a nested loop) keeps values
            # bit-identical, which matters for tie-breaking in merge_node.
            total = sum(self.D[a, b] for a in c1 for b in c2)
            return total / (len(c1) * len(c2))

        n = len(clusters)
        D = np.zeros((n, n))
        for i in range(n):
            for j in range(i):
                D[i, j] = cluster_dist(clusters[i].leaves, clusters[j].leaves)
        return D + D.T

    def h_clustering(self, lambda1=0.1):
        """Build the tree by repeatedly merging the two closest clusters.

        Parameters
        ----------
        lambda1 : float
            Weight of the mismatch term (:meth:`sequence_dist2`) in the initial
            distance matrix.

        Sets ``self.D`` (leaf-level distances) and ``self.root``.

        Raises
        ------
        ValueError
            If ``df`` has no rows.
        """
        self.lambda1 = lambda1
        clusters = list(self.all_tree_nodes)
        if not clusters:
            raise ValueError("cannot cluster an empty DataFrame")
        self.D = self.initialize_dist_matrix(self.df)

        D = self.D
        while len(clusters) > 1:
            i, j = self.merge_node(D)
            parent = TreeNode()
            parent.left = clusters[i]
            parent.right = clusters[j]
            clusters[i].parent = parent
            clusters[j].parent = parent
            parent.leaves = clusters[i].leaves + clusters[j].leaves
            clusters = [c for idx, c in enumerate(clusters) if idx not in (i, j)] + [parent]
            D = self.update_dist_matrix(clusters)
        # With one observation nothing is merged: the lone leaf is the root.
        self.root = clusters[0]

    def write_newick(self):
        """Serialise ``self.root`` as a Newick string with the root labelled ``root``.

        Leaves are written as their ``label``; internal nodes as ``(left,right)``.
        """
        def helper(node):
            if node is None:
                return None
            if node.left is None and node.right is None:
                return "%s" % (node.label,)
            return "(%s,%s)" % (helper(node.left), helper(node.right))

        body = helper(self.root)
        if self.root is not None and self.root.left is None and self.root.right is None:
            # Single-leaf tree: wrap the leaf so its label does not fuse with
            # the root label (which would give e.g. "1_0101root;").
            body = "(%s)" % body
        return body + "root;"

In [14]:
import glob
import numpy as np
import math

def solver(args, out_dir="../../Output/Subchallenge1/HC", lambda1=0):
    """Cluster one sample and write its tree to ``<out_dir>/<sample>.nwk``.

    Parameters
    ----------
    args : sequence
        ``(df, sample)``: the sample's DataFrame (columns ``cell``, ``state``)
        and the sample name used as the output file stem.
    out_dir : str
        Output directory; created if it does not exist.
    lambda1 : float
        Mismatch weight passed to ``hcluster.h_clustering``.
    """
    import os

    df, sample = args[0], args[1]
    res = hcluster(mutation_rate, df)
    res.h_clustering(lambda1=lambda1)
    # Serialise before opening the file so a failure cannot leave an empty .nwk.
    newick = res.write_newick()
    os.makedirs(out_dir, exist_ok=True)
    outfile = os.path.join(out_dir, "%s.nwk" % sample)
    with open(outfile, "w") as f:
        f.write(newick)

import glob
import os

DATA_DIR = "../../Data/Subchallenge1"
# Sorted for a deterministic processing order.
paths = sorted(glob.glob(os.path.join(DATA_DIR, "*.txt")))
# Sample name = file stem. Each file is read from its globbed path instead of a
# path rebuilt from a name truncated at the first "."; glob's "*" never matches
# dot-files, so every stem is non-empty.
samples = [os.path.splitext(os.path.basename(p))[0] for p in paths]
for path, sample in zip(paths, samples):
    df = pd.read_table(path, dtype="str")
    solver([df, sample])

In [78]:
import glob
import numpy as np
import math

def solver(args, out_dir="../../Output/Subchallenge1/HC_hamming", lambda1=0):
    """Cluster one sample and write its tree to ``<out_dir>/<sample>.nwk``.

    NOTE(review): despite the "HC_hamming" directory, this runs the same
    clustering as the HC run (``hcluster.h_clustering(lambda1=0)``);
    ``hcluster.hamming_dist`` is never called. Confirm which distance was
    intended for this output.

    Parameters
    ----------
    args : sequence
        ``(df, sample)``: the sample's DataFrame (columns ``cell``, ``state``)
        and the sample name used as the output file stem.
    out_dir : str
        Output directory; created if it does not exist.
    lambda1 : float
        Mismatch weight passed to ``hcluster.h_clustering``.
    """
    import os

    df, sample = args[0], args[1]
    res = hcluster(mutation_rate, df)
    res.h_clustering(lambda1=lambda1)
    # Serialise before opening the file so a failure cannot leave an empty .nwk.
    newick = res.write_newick()
    os.makedirs(out_dir, exist_ok=True)
    outfile = os.path.join(out_dir, "%s.nwk" % sample)
    with open(outfile, "w") as f:
        f.write(newick)

import glob
import os

DATA_DIR = "../../Data/Subchallenge1"
# Sorted for a deterministic processing order.
paths = sorted(glob.glob(os.path.join(DATA_DIR, "*.txt")))
# Sample name = file stem. Each file is read from its globbed path instead of a
# path rebuilt from a name truncated at the first "."; glob's "*" never matches
# dot-files, so every stem is non-empty.
samples = [os.path.splitext(os.path.basename(p))[0] for p in paths]
for path, sample in zip(paths, samples):
    df = pd.read_table(path, dtype="str")
    solver([df, sample])

In [15]:
import glob
import os

N_TEST_SAMPLES = 30  # test inputs are named sub1_test_1 .. sub1_test_30
DATA_DIR = "../../Data/Subchallenge1"
OUT_DIR = "../../Output/Subchallenge1/HC"

ID = []
trees = []
samples = [os.path.splitext(os.path.basename(p))[0]
           for p in glob.glob(os.path.join(DATA_DIR, "*test*.txt"))]
# Fail fast, listing every missing input, before any clustering work is done.
expected = ["sub1_test_%d" % k for k in range(1, N_TEST_SAMPLES + 1)]
missing = sorted(set(expected) - set(samples))
assert not missing, "missing test inputs: %s" % missing

os.makedirs(OUT_DIR, exist_ok=True)
for dream_id, sample in enumerate(expected, start=1):
    df = pd.read_table(os.path.join(DATA_DIR, "%s.txt" % sample), dtype="str")
    print(sample, len(df))
    res = hcluster(mutation_rate, df)
    res.h_clustering(lambda1=0)
    newick = res.write_newick()  # serialise once; reused for file and submission
    # Append only after success so ID and trees stay aligned.
    ID.append(dream_id)
    trees.append(newick)
    with open(os.path.join(OUT_DIR, "%s.nwk" % sample), "w") as f:
        f.write(newick)

sub1_test_1 4
sub1_test_2 19
sub1_test_3 20
sub1_test_4 6
sub1_test_5 4
sub1_test_6 20
sub1_test_7 9
sub1_test_8 19
sub1_test_9 13
sub1_test_10 14
sub1_test_11 29
sub1_test_12 20
sub1_test_13 4
sub1_test_14 21
sub1_test_15 7
sub1_test_16 12
sub1_test_17 13
sub1_test_18 10
sub1_test_19 11
sub1_test_20 9
sub1_test_21 15
sub1_test_22 9
sub1_test_23 14
sub1_test_24 18
sub1_test_25 27
sub1_test_26 4
sub1_test_27 6
sub1_test_28 23
sub1_test_29 29
sub1_test_30 15


In [33]:
# Submission table: one row per test sample (dreamID) with its Newick tree.
assert len(ID) == len(trees), "ID and trees are out of sync"
df = pd.DataFrame(data={"dreamID": ID, "nw": trees})
df.to_csv("../../Output/Subchallenge1/Jasper_Hu_sub1_submission.txt", index=False, sep="\t")

In [34]:
# Preview the submission table (dreamID and Newick string per test sample).
df

Unnamed: 0,dreamID,nw
0,1,"(4_2212210001,(2_2112210001,(1_2012210001,3_20..."
1,2,"((8_1111111111,17_1111111111),(18_1111011001,(..."
2,3,"(5_1111111111,((15_0011011221,(11_1111011221,(..."
3,4,"((3_2221012200,6_2221012221),((1_2202000001,5_..."
4,5,"(3_2102111000,(2_2100111000,(1_2100101000,4_21..."
5,6,"((3_2111111111,20_2111111211),((((1_0122210001..."
6,7,"(((1_2002110021,9_2002111000),(5_2102111021,8_..."
7,8,"((8_2211100221,(15_2110101221,(2_2111101221,18..."
8,9,"((7_1212210021,(10_2212110021,12_2212110021)),..."
9,10,"((11_2102111000,((7_2001120012,12_2001120012),..."
