In [None]:
# Cell tree construction
#
# Greedily grows a spanning tree over the cells: the first edge is the
# globally smallest entry of `pair_latent`; afterwards the smallest remaining
# entry that links an already-included cell (row) to a not-yet-included cell
# (column) is added until every cell is included.
#
# Outputs:
#   edge          - list of (from, to) tuples with 1-based cell labels
#   edge_distance - `pair_latent` value of each edge, aligned with `edge`
#   included      - boolean mask of the cells attached to the tree

no_cell = pair_latent.shape[0]

# Do the NaN bookkeeping on a private copy. Masking `pair_latent` itself
# leaves it entirely NaN once every cell is included, which breaks later cells
# that still read distances from it and makes this cell fail when re-run.
remaining = np.array(pair_latent, dtype = float)
np.fill_diagonal(remaining, np.nan)  # a cell is never linked to itself

included = np.zeros(no_cell, dtype = bool)
edge = []
edge_distance = []

while not included.all():

    if included.any():

        # Only rows of cells already in the tree may offer the next edge.
        candidates = np.full(remaining.shape, np.nan)
        candidates[included, :] = remaining[included, :]

    else:

        # Seed step: nothing is included yet, so take the global minimum.
        candidates = remaining

    connection = np.unravel_index(np.nanargmin(candidates), remaining.shape)
    edge.append((connection[0] + 1, connection[1] + 1))
    edge_distance.append(remaining[connection])
    included[list(connection)] = True

    # Pairs of cells that are both in the tree can no longer form an edge.
    remaining[np.ix_(included, included)] = np.nan

In [None]:
# Heatmap sorted by cell tree

import seaborn as sns
import pandas as pd

# The walk starts at the cell whose latent vector has the smallest squared
# norm. Positions are 1-based labels, matching the entries stored in `edge`;
# `cell_order` collects the corresponding 0-based indices.
root = np.argmin(np.sum(latent.numpy() ** 2, 1)) + 1
current_position = root
cell_order = [root - 1]
edge_map = edge.copy()
history = []
move = 0

# Depth-first walk: follow the first unused edge touching the current cell;
# when none is left, step back to the cell we arrived from.
while move < len(edge):

    for candidate in edge_map:

        if current_position in candidate:

            history.append(current_position)
            # Move to whichever endpoint of the edge is not the current cell.
            current_position = candidate[1] if candidate[0] == current_position else candidate[0]
            cell_order.append(current_position - 1)
            edge_map.remove(candidate)
            move += 1
            break

    else:

        # No unused edge here: backtrack one step.
        current_position = history[-1]
        del history[-1]

# Rank genes by the variance (over dim 0) of model.fc3 applied to the first
# output of model.encode, largest variance first.
var_gene = torch.var(
    model.fc3(model.encode(SNP)[0].detach()).detach(),
    0,
)
rank_gene = torch.flipud(torch.argsort(var_gene))

cell_sorted = cell_order

# Rows: the 30 top-ranked genes; columns: cells in tree order.
SNP_sorted = SNP.detach().numpy()[cell_sorted, :][:, rank_gene][:, 0 : 30].T

# A leading 0 is prepended so that Series index i (1..n) carries the label of
# the i-th sorted cell, lining up with the 1-based column names used below.
heatmap_labels = np.hstack((np.array(0), label[cluster_no - 2][0][cell_sorted]))
heatmap_palette = dict(zip(np.arange(0, cluster_no), colors))
cell_colors = pd.Series(heatmap_labels).map(heatmap_palette)

heatmap_rows = information[allgene_var > threshold_var]["POS"].to_numpy()[rank_gene][0 : 30]
heatmap_frame = pd.DataFrame(
    SNP_sorted,
    index = heatmap_rows,
    columns = np.arange(1, SNP_sorted.shape[1] + 1),
)

fig = sns.clustermap(
    heatmap_frame,
    row_cluster = False,
    col_cluster = False,
    col_colors = cell_colors,
    figsize = (20, 15),
)
plt.show()

In [None]:
# Cell tree in principal components 1 and 2

from sklearn.decomposition import PCA

# Project the latent vectors onto their principal components.
# `fit_transform` fits the model itself, so a separate `fit` on the same data
# beforehand only repeats the decomposition.
pca = PCA()
pca_latent = pca.fit_transform(latent.numpy())

fig, ax1 = plt.subplots(1, 1)
fig.set_size_inches(10, 10)
fig.suptitle("Phylogenetic Tree", fontsize = 30)
ax1.set_xlabel("PCA 1", fontsize = 15)
ax1.set_ylabel("PCA 2", fontsize = 15)

# One thin segment per tree edge. Edge entries are 1-based cell labels, hence
# the -1 when indexing. Iterating over `edge` itself avoids relying on it
# holding exactly `no_cell - 1` entries.
for first, second in edge:

    endpoints = [first - 1, second - 1]
    ax1.plot(pca_latent[endpoints, 0], pca_latent[endpoints, 1], color = "k", linewidth = 0.2)

# Overlay the cells, one colour per cluster.
for m in range(cluster_no):

    ax1.scatter(pca_latent[cluster[m], 0], pca_latent[cluster[m], 1], s = 20, color = colors[m])

plt.show()

In [None]:
# Cell tree in UMAP embeddings

import umap

# Two-dimensional UMAP embedding of the latent vectors (default settings).
reducer = umap.UMAP()
embedding = reducer.fit_transform(latent.numpy())

fig, ax1 = plt.subplots(figsize = (10, 10))
fig.suptitle("Phylogenetic Tree", fontsize = 30)

# Tree edges as thin black segments; edge entries are 1-based cell labels.
for r in range(no_cell - 1):

    endpoints = [node - 1 for node in edge[r]]
    ax1.plot(embedding[endpoints, 0], embedding[endpoints, 1], color = "k", linewidth = 0.2)

# Cells on top, one colour per cluster.
for m in range(cluster_no):

    members = cluster[m]
    ax1.scatter(embedding[members, 0], embedding[members, 1], s = 7, color = colors[m])

plt.show()

In [None]:
# Cell tree in polar coordinates
#
# Every cell gets a radius (its squared norm in PCA space) and an angle.
# Angles are propagated along tree edges from the lower-radius endpoint to the
# higher-radius one; the angular step comes from the law of cosines.
# NOTE(review): the formula treats `radii` and the edge distances as squared
# lengths, i.e. it assumes `pair_latent` held squared Euclidean distances -
# confirm against the cell that builds `pair_latent`.

radii = np.sum(pca_latent ** 2, 1)
rad_rank = np.argsort(radii)

# Zeros instead of np.empty: cells that never receive an angle from a
# lower-radius neighbour (including the innermost cell, which is the starting
# point) then sit at angle 0 rather than at whatever was in memory.
angle = np.zeros(no_cell)

# Adjacency list with 0-based indices, filled in edge order so that angles are
# assigned in the same order as a scan over `edge` for each cell would give.
# Distances come from `edge_distance`, recorded when each edge was chosen,
# because the tree-construction cell overwrites `pair_latent` with NaN.
neighbours = [[] for _ in range(no_cell)]

for (first, second), squared_gap in zip(edge, edge_distance):

    neighbours[first - 1].append((second - 1, squared_gap))
    neighbours[second - 1].append((first - 1, squared_gap))

for h in rad_rank:

    for related_node, squared_gap in neighbours[h]:

        # Only push the angle outwards, towards cells at least as far out.
        if radii[related_node] >= radii[h]:

            denominator = 2 * ((radii[h] * radii[related_node]) ** (0.5))

            if denominator > 0:

                cos_angle = (radii[h] + radii[related_node] - squared_gap) / denominator
                # Rounding can leave the ratio marginally outside [-1, 1],
                # where arccos would return NaN.
                cos_angle = np.clip(cos_angle, -1.0, 1.0)

            else:

                # A cell exactly at the origin has no defined direction;
                # carry the angle over unchanged.
                cos_angle = 1.0

            angle[related_node] = angle[h] + np.arccos(cos_angle)

# NOTE(review): `radii` holds squared norms, so points are drawn at the squared
# distance from the origin - confirm this scaling is intended.
X = radii * np.cos(angle)
Y = radii * np.sin(angle)

fig, ax1 = plt.subplots(1, 1)
fig.set_size_inches(10, 10)
fig.suptitle("Phylogenetic Tree", fontsize = 30)

# One thin segment per tree edge (edge entries are 1-based cell labels).
for first, second in edge:

    endpoints = [first - 1, second - 1]
    ax1.plot(X[endpoints], Y[endpoints], color = "k", linewidth = 0.2)

# Cells on top, one colour per cluster.
for m in range(cluster_no):

    ax1.scatter(X[cluster[m]], Y[cluster[m]], s = 7, color = colors[m])

plt.show()