In [None]:
# Clone tree construction
#
# Greedy spanning-tree search over the pairwise distances between the
# encoded cluster profiles: the closest pair overall seeds the tree, then
# the closest pair linking a cluster inside the tree to one outside it is
# added until every cluster is included. Edges use 1-based cluster numbers.

cluster_embedding = model.encode(tmodel.fc2.weight.T)[0].detach().numpy()
pair_cluster = squareform(pdist(cluster_embedding))
# Zero distances (the diagonal, and any coincident pair) are masked so that
# nanargmin never selects them.
pair_cluster[pair_cluster == 0] = np.nan

included = np.zeros(cluster_no, dtype = bool)
edge = []
edge_distance = []


def _attach(connection):
    """Record `connection` as a tree edge and mask distances inside the tree."""
    first, second = connection
    edge.append((first + 1, second + 1))
    # The distance has to be read before the masking below overwrites it.
    edge_distance.append(pair_cluster[connection])
    included[first] = True
    included[second] = True
    # A pair with both ends already in the tree can no longer become an edge.
    pair_cluster[included[:, None] & included[None, :]] = np.nan


# Seed: the closest pair of clusters overall.
connection = np.unravel_index(np.nanargmin(pair_cluster), pair_cluster.shape)
_attach(connection)

while not included.all():
    # Only rows belonging to clusters already in the tree may supply an edge.
    possible = np.full(pair_cluster.shape, np.nan)
    possible[included, :] = pair_cluster[included, :]
    connection = np.unravel_index(np.nanargmin(possible), pair_cluster.shape)
    _attach(connection)

In [None]:
# 2 proportion Z-tests
#
# For every tree edge, compares the two clusters it connects column by
# column with a pooled two-proportion Z statistic, then converts the
# statistics into two-sided p-values.

import scipy.stats

# Inverse squared distance of every tree edge.
edge_weight = 1 / (np.array(edge_distance) ** 2)
# Number of members of every cluster.
node_weight = np.array([len(cluster[w]) for w in range(len(cluster))])

SNP_cluster = tmodel.fc2.weight.T.detach().numpy()
Zmatrix = np.empty((len(edge), SNP.shape[1]))

# Densify and filter DP once: it does not depend on the edge being tested,
# so rebuilding it inside the loop only repeats the same expensive conversion.
DP_selected = np.asarray(DP.todense())[allgene_var > threshold_var, :].T

for u in range(len(edge)):

    # Edges store 1-based cluster numbers; convert back to 0-based indices.
    first = edge[u][0] - 1
    second = edge[u][1] - 1

    p1 = SNP_cluster[first]
    p2 = SNP_cluster[second]
    # Column-wise totals of DP over the members of each cluster.
    n1 = np.sum(DP_selected[cluster[first], :], 0)
    n2 = np.sum(DP_selected[cluster[second], :], 0)
    # Pooled proportion, floored at 0.01.
    # NOTE(review): there is no matching upper bound, so p0 == 1 still gives a
    # zero denominator below; n1 == 0 or n2 == 0 divides by zero. Confirm intent.
    p0 = np.clip((n1 * p1 + n2 * p2) / (n1 + n2), 0.01, None)
    Zmatrix[u, :] = (p1 - p2) / np.sqrt(p0 * (1 - p0) * (1 / n1 + 1 / n2))

# Two-sided p-value, capped at 1.
p_val = np.clip(scipy.stats.norm.sf(abs(Zmatrix)) * 2, None, 1)

# Per edge, keep the columns whose p-value is below the cut-off and list
# them (1-based) from the smallest p-value to the largest.
correction = False
SNP_diff = []
SNP_diff_pval = []

# The cut-off does not change between edges, so it is chosen once:
# 0.05 divided by the length of allgene_var when `correction` is set,
# otherwise the fixed value 0.05 / 10e33.
if correction:
    cutoff = 0.05 / allgene_var.shape[0]
else:
    cutoff = 0.05 / 10e33

snp_number = np.arange(1, p_val.shape[1] + 1)

for w in range(len(edge)):

    edge_pval = p_val[w]
    significant = edge_pval < cutoff
    ranking = np.argsort(edge_pval[significant])

    SNP_diff.append(snp_number[significant][ranking])
    SNP_diff_pval.append(edge_pval[significant][ranking])

In [None]:
# Clone tree visualization
#
# Draws the clone tree: edge labels show the number of significant SNPs on
# that edge, node sizes follow node_weight and node colors follow colors.

import networkx as nx

graph = nx.Graph()

for u in range(len(edge)):

    graph.add_edge(edge[u][0], edge[u][1], weight = edge_weight[u], SNP = str(SNP_diff[u].shape[0]))

# NOTE(review): spring_layout starts from random positions, so the picture
# changes between runs unless a seed is supplied.
pos = nx.spring_layout(graph, iterations = 3000, weight = 'weight', k = 10 / np.sqrt(graph.order()))

# networkx pairs node_size / node_color with nodes by position in `nodelist`,
# and a graph lists its nodes in the order the edges were added, not in
# cluster order. Look up each node's own size and color from its 1-based
# cluster number so attributes cannot end up on the wrong node.
# Assumes node_weight and colors are indexed by 0-based cluster id.
nodelist = list(graph.nodes())
node_sizes = [node_weight[node - 1] for node in nodelist]
node_colors = [colors[node - 1] for node in nodelist]

plt.figure(figsize = (13, 7))
nx.draw(graph, pos, nodelist = nodelist, with_labels = True, font_weight = 'bold', node_size = node_sizes, node_color = node_colors)
nx.draw_networkx_edge_labels(graph, pos, nx.get_edge_attributes(graph, 'SNP'), rotate = False, alpha = 0.75)
plt.title("Phylogenetic Tree", fontsize = 30)
plt.show()

In [None]:
# Heatmap sorted by clone tree

import seaborn as sns
import pandas as pd

# Root: the cluster whose encoded profile has the smallest squared norm
# (kept 1-based, like the node numbers stored in `edge`).
cluster_embedding = model.encode(tmodel.fc2.weight.T)[0].detach().numpy()
root = np.argmin(np.sum(cluster_embedding ** 2, 1)) + 1

# Depth-first walk over the tree edges starting at the root. Each edge is
# used once; cluster_order collects the 0-based cluster indices in the order
# they are first reached.
move = 0
edge_map = list(edge)
current_position = root
cluster_order = [root - 1]
history = []

while move < len(edge):

    for candidate in edge_map:

        if current_position in candidate:

            # Remember where we came from so the walk can backtrack later.
            history.append(current_position)
            start, end = candidate
            current_position = end if start == current_position else start

            cluster_order.append(current_position - 1)
            move += 1
            edge_map.remove(candidate)
            break

    else:

        # No unused edge touches the current node: step back to the previous one.
        current_position = history.pop()

def moving_average(a, n):
    """Return the simple moving average of `a` over windows of `n` values.

    Parameters
    ----------
    a : array_like
        Input values; they are accumulated as floats in flattened order.
    n : int
        Window length, at least 1.

    Returns
    -------
    numpy.ndarray
        The mean of every run of `n` consecutive values. The result has
        `len(a) - n + 1` entries and is empty when `n` exceeds the input
        length.

    Raises
    ------
    ValueError
        If `n` is smaller than 1.
    """
    # Without this check n == 0 fails with an unrelated broadcasting error
    # and a negative n silently returns meaningless numbers.
    if n < 1:
        raise ValueError("moving_average window n must be at least 1")

    ret = np.cumsum(a, dtype = float)
    # Difference of running sums n apart gives the sum of each window.
    ret[n: ] = ret[n: ] - ret[ :-n]

    return ret[n - 1: ] / n

# Rank the model's output columns by their variance, largest first.
var_gene = torch.var(model.fc3(model.encode(SNP)[0].detach()).detach(), 0)
rank_gene = torch.flipud(torch.argsort(var_gene))

# Member indices of every cluster, concatenated in tree order.
cluster_sorted = np.concatenate([cluster[w] for w in cluster_order], axis = None).astype(int)

# Rows of SNP reordered by cluster; the 30 top-ranked columns become heatmap rows.
SNP_sorted = SNP.detach().numpy()[cluster_sorted, :]
SNP_sorted = SNP_sorted[:, rank_gene][:, 0 : 30].T

# 1-based labels of the heatmap columns.
cell_number = np.arange(1, SNP_sorted.shape[1] + 1)

cluster_colors = pd.Series(label[cluster_no - 2][0][cluster_sorted]).map(dict(zip(np.arange(0, cluster_no), colors)))
# seaborn matches a Series passed as col_colors to the data columns by index
# label. The default 0-based index would shift every color by one column and
# leave the last column without a color, so use the same labels as the data.
cluster_colors.index = cell_number

heatmap_data = pd.DataFrame(SNP_sorted, index = information[allgene_var > threshold_var]["POS"].to_numpy()[rank_gene][0 : 30], columns = cell_number)
fig = sns.clustermap(heatmap_data, row_cluster = False, col_cluster = False, col_colors = cluster_colors, figsize = (20, 15))

# One tick per cluster, centered on its block of columns: the midpoints of
# consecutive cumulative cluster sizes.
fig.ax_col_colors.set_xticks(moving_average(np.cumsum([0] + list(node_weight[cluster_order])), 2))
fig.ax_col_colors.set_xticklabels(np.arange(1, cluster_no + 1)[cluster_order])
fig.ax_col_colors.xaxis.set_tick_params(size = 0)
fig.ax_col_colors.xaxis.tick_top()
plt.show()