In [None]:
#!/usr/bin/python3
#author: @rayezh

import numpy as np
import pandas as pd
from itertools import combinations
import glob, random, re, os, sys, shutil
import argparse
import math
import dendropy
from Bio import Phylo
from nexus import NexusReader
from dendropy.calculate import treecompare, treemeasure
from tree_evaluations import evaluate_normalized_RF_distance,evaluate_triplet_distance, evaluate_triplet_distance_correlation
from Bio.Phylo.TreeConstruction import DistanceTreeConstructor,DistanceMatrix,_Matrix
from sklearn.ensemble import RandomForestRegressor
from tmc_wrapper.triplets_distance import triplets_score
from collections import Counter, defaultdict

In [None]:
def one_hot_embedding(lineage):
    """
    Compute a rarity weight for every (position, symbol) pair in the recordings.

    embedding:
    symbol_set[loc][mutation] = weight

    A symbol seen ``count`` times at a position gets the weight
    ``log10(100 / count) + 1``, so symbols that are rarer at that position
    receive a larger weight.

    input:
        lineage: a dictionary of taxa:record, where each record is a sequence
                 of symbols (one per position).
    output:
        symbol_set: dict mapping position -> defaultdict(symbol -> weight).
                    Positions 0..199 are always present; further positions are
                    added when a record is longer than 200 symbols. Symbols
                    never seen at a position read as 0.
        all_symbol: sorted list of the distinct symbols found in any record.
    """
    # Keep the historical 200 positions as a minimum, but grow the table when
    # a record is longer so that counting never fails with a KeyError.
    num_positions = max([200] + [len(cell) for cell in lineage.values()])
    symbol_set = {loc: defaultdict(int) for loc in range(num_positions)}
    distinct_symbols = set()
    for cell in lineage.values():
        for loc, symbol in enumerate(cell):
            symbol_set[loc][symbol] += 1
        distinct_symbols.update(cell)
    all_symbol = sorted(distinct_symbols)
    for counts in symbol_set.values():
        # Only existing keys are reassigned, so iterating while writing is safe.
        for symbol, count in counts.items():
            # NOTE(review): 100 is a fixed constant, presumably the number of
            # cells per data set -- confirm before changing.
            counts[symbol] = np.log10(100 / count) + 1  # log base = 10
    return symbol_set, all_symbol

In [None]:
def fix_gap(record):
    """
    Mark every gap with start and end:
        A gap can be uniquely marked by the mutations at both ends: mutation
        type and location.
        Gaps resulting from three or more simultaneous mutations are marked
        too: every further end symbol is appended to the same mark as long as
        it is directly followed by another '-'.
    input:
        mutation recordings in list format.
    output:
        a new list of mutation recordings with gaps marked; the input list is
        not modified. Every position from the opening symbol to the closing
        symbol (inclusive) is replaced by the same mark, e.g.
        ['A', '-', '-', 'B'] -> ['gapA0_B3'] * 4.
        A run of '-' without an opening symbol (at the very start of the
        record) or without a closing symbol is left unchanged.
    """
    # Symbols that may open or close a gap.
    valid_symbols = frozenset('0' + 'ABCDEFGHIJKLMNOPQRSTUVWXYZ' + 'abcd')
    record_new = record.copy()
    in_gap = False
    for i, symbol in enumerate(record):
        if symbol == '-':
            # i > 0 is required: at i == 0 the lookup record[i-1] would wrap
            # around to the last element, open a gap at index -1 and later
            # overwrite the final position of the record.
            if not in_gap and i > 0 and record[i - 1] in valid_symbols:
                start_idx = i - 1
                sign = 'gap' + record[start_idx] + str(start_idx)
                in_gap = True
        elif in_gap and symbol in valid_symbols:
            sign = sign + '_' + symbol + str(i)
            if i + 1 == len(record) or record[i + 1] != '-':  # gap ends
                in_gap = False
                for j in range(start_idx, i + 1):
                    record_new[j] = sign
    return record_new

In [None]:
def find_nearest_pair(lineage, embed,  mutation_all):
    """
    Find the two cells whose recordings are most similar.

    Every cell becomes one row of a weighted matrix in which each position
    owns a block of len(mutation_all) columns:
        - state '0' fills the whole block with 1;
        - any other state puts len(mutation_all) * embed[position][state] in
          the column of that state.
    The similarity of two cells is the dot product of their rows divided by
    the dot product of their rows in a second matrix that is 1 on every
    position whose state does not start with 'gap' and 0 elsewhere.

    input:
        lineage: a dictionary of taxa:record ['A', 'B', 'C', ...].
        embed: weights indexed as embed[position][state].
        mutation_all: list of all states; the index of a state in this list
                      selects its column inside a position block.
    output:
        sister: the taxa of the nearest pair, as a two-element list.
    raises:
        ValueError: if lineage holds fewer than two cells.
    """
    print('finding the nearest pair in',len(lineage),'cells ...')
    cell_names = list(lineage.keys())
    num_cells = len(cell_names)
    if num_cells < 2:
        # With a single cell only the diagonal exists and the search below
        # could never terminate.
        raise ValueError('find_nearest_pair needs at least two cells, got %d' % num_cells)

    num_mutation_all = len(mutation_all)
    # Keep the historical width of 200 positions as a minimum, but widen the
    # matrices when a record is longer so that indexing stays in bounds.
    num_positions = max([200] + [len(record) for record in lineage.values()])
    width = num_mutation_all * num_positions
    mat = np.zeros((num_cells, width))       # weighted cell/mutation matrix
    dist_mat = np.zeros((num_cells, width))  # 1 on non-gap positions, else 0

    for row, record in enumerate(lineage.values()):
        for idx, state in enumerate(record):
            lo = idx * num_mutation_all
            hi = lo + num_mutation_all
            if state == '0':
                mat[row, lo:hi] = 1
            else:
                # num_mutation_all lifts the weights
                mat[row, lo + mutation_all.index(state)] = num_mutation_all * embed[idx][state]
            if not state.startswith('gap'):
                dist_mat[row, lo:hi] = 1

    product = mat.dot(mat.T)
    dist_product = dist_mat.dot(dist_mat.T)
    # NOTE(review): two cells without any shared non-gap position give a zero
    # denominator here (inf/nan plus a RuntimeWarning) -- confirm how such
    # pairs should be ranked.
    product = np.divide(product, dist_product)
    np.fill_diagonal(product, 0)

    idx = product.argmax()
    # A cell must not be paired with itself: push a winning diagonal entry
    # down and search again.
    while idx // num_cells == idx % num_cells:
        product[idx // num_cells, idx // num_cells] -= 1
        idx = product.argmax()

    pair1 = cell_names[idx // num_cells]
    pair2 = cell_names[idx % num_cells]
    print(product[idx // num_cells, idx % num_cells],(pair1, pair2))
    return [pair1, pair2]

In [None]:
def hierachical_clustering(X):
    '''
    Build a rooted tree by repeatedly merging the two most similar cells.

    input: recordings of cells; list of tuples-(cell_name, recording)
    output: clustered tree in newick format (with root)

    Each recording is first passed through fix_gap. While at least two
    clusters remain, the weights are recomputed, the nearest pair is looked
    up, and the pair is replaced by one cluster named '(left,right)' whose
    recording is merged position by position:
        - equal states are kept;
        - if exactly one state is a gap mark, the other state is kept;
        - otherwise the position is reset to '0'.
    '''
    lineage = {name: fix_gap(recording) for name, recording in X}

    while len(lineage) > 1:
        # customed weights for each round of clustering
        embed, all_symbol = one_hot_embedding(lineage)
        left, right = find_nearest_pair(lineage, embed, all_symbol)
        left_record = lineage[left]
        right_record = lineage[right]

        merged = []
        for pos in range(len(left_record)):
            a = left_record[pos]
            b = right_record[pos]
            if a == b:
                merged.append(a)
                continue
            a_is_gap = a.startswith('gap')
            b_is_gap = b.startswith('gap')
            if a_is_gap != b_is_gap:
                # exactly one side is a gap mark: keep the informative side
                merged.append(b if a_is_gap else a)
            else:
                merged.append('0')

        del lineage[left]
        del lineage[right]
        lineage['(' + left + ',' + right + ')'] = merged

    return list(lineage.keys())[0] + 'root;'

In [None]:
def main():
    """
    Reconstruct a tree for each of the training recordings numbered 1 to 100.

    For every number NNNN (zero-padded to four digits) the file
    '../SubC2_train_TXT/SubC2_train_NNNN.txt' is read. Each non-blank line
    holds a cell name and its recording separated by a tab; the recording is
    split into single characters. The tree returned by hierachical_clustering
    is written to 'SubC2_train_NNNN.nw' in the current directory.
    """
    for i in range(1,101):
        print(i)
        input_file_path = '../SubC2_train_TXT/SubC2_train_'+format(i,'04d')+'.txt'   #path of input recordings
        output_file_path = 'SubC2_train_'+format(i,'04d')+'.nw'           #path of reconstructed lineage
        X = []
        with open(input_file_path, 'r') as Xfile:
            for line in Xfile:
                if not line.strip():
                    # a blank line has no recording column
                    continue
                table = line.rstrip().split('\t')
                X.append((table[0], list(table[1])))
        rec_tree = hierachical_clustering(X)
        # The context manager closes (and flushes) the output file for every tree.
        with open(output_file_path, 'w') as rec_file:
            rec_file.write(rec_tree)

In [None]:
# Entry point: call main() only when this code runs with __name__ set to
# '__main__' (executed directly), not when it is imported as a module.
if __name__ == '__main__':
    main()