In [None]:
import pandas as pd
import numpy as np
import scipy as sp
import matplotlib.pyplot as plt
#import seaborn as sns

In [None]:
from random import random
from tqdm.notebook import tqdm

In [None]:
import time
import multiprocess as mp

### Data analysis

In [None]:
def heterozygosity_map(chromosome, fname = None,
                       ref_path = './Data/glioblastoma_BT_S2/ref.csv',
                       alt_path = './Data/glioblastoma_BT_S2/alt.csv',
                       out_dir = './figures'):
    '''
    Plot the per-entry reference-read proportion as a heat map and save it as an image.

    Arguments:
        chromosome: label used for the figure title and for the default file name.
            It is NOT used to select data; the full ref/alt tables are always plotted.
        fname: output file name inside out_dir. Defaults to 'map_chr<chromosome>.png'.
        ref_path, alt_path: CSV files with reference / alternative read counts,
            passed straight to read_data.
        out_dir: directory the figure is written to; created if it does not exist.

    NOTE(review): read_data is not defined in this cell; it presumably comes from
    the star-import of utilities further down the notebook -- confirm, and make sure
    that import has run before calling this function.
    '''
    import os  # local import: only needed here for output-path handling

    ref, alt = read_data(ref_path, alt_path)

    # One pseudo-count for each of ref and alt keeps the ratio defined when both are 0
    # (such entries get a proportion of exactly 0.5).
    ref_proportion = (ref + 1) / (ref + alt + 2)
    # Opacity grows with total coverage: arctan(0) = 0 makes uncovered entries fully
    # transparent, and the value approaches 1 as the count increases.
    alpha = 2 * np.arctan(ref + alt) / np.pi

    plt.figure(figsize=(24,16))
    # "viridis": yellow for 1, purple for 0, green/blue for 0.5
    # (https://matplotlib.org/3.5.1/tutorials/colors/colormaps.html)
    plt.imshow(ref_proportion.T, cmap = 'viridis', vmin = 0., vmax = 1., alpha = alpha.T)
    plt.title(chromosome, fontsize = 17)
    # Axis labels assume read_data returns arrays laid out as (locus, cell), so that
    # the transposed image has cells on the y axis -- TODO confirm.
    plt.xlabel('locus index', fontsize = 17)
    plt.ylabel('cell index', fontsize = 17)

    if fname is None:
        fname = 'map_chr' + str(chromosome) + '.png'
    # savefig does not create missing directories, so make sure the target exists.
    os.makedirs(out_dir, exist_ok = True)
    plt.savefig(os.path.join(out_dir, fname))

### Data generator

In [None]:
from data_generator import *
from utilities import *

In [None]:
# Set up the simulator. The positional arguments (50, 400) are presumably the number
# of cells and loci -- TODO confirm against DataGenerator's signature.
# NOTE(review): no random seed is set in the visible cells, so the tree, mutations and
# reads generated below will differ between runs.
dg = DataGenerator(50, 400, coverage_sampler=coverage_sampler())
dg.random_tree()
# genotype_freq = [1., 0., 0.] puts all weight on the first genotype; the order of the
# three entries is defined by DataGenerator -- verify there.
dg.random_mutations(mut_prop = 0.5, genotype_freq = [1., 0., 0.])

In [None]:
# Simulated reference / alternative read counts. Later cells sub-select the second
# axis of both arrays with `selected`.
ref_raw, alt_raw = dg.generate_reads()

In [None]:
from mutation_detection import mut_type_posteriors

In [None]:
# Posterior over mutation types, computed with equal genotype frequencies (1/3 each).
posteriors = mut_type_posteriors(
    ref_raw,
    alt_raw,
    genotype_freq={'R': 1/3, 'H': 1/3, 'A': 1/3},
)

# Keep the rows of `posteriors` whose summed mass over columns 3 onwards exceeds 0.1.
selected = np.flatnonzero(np.sum(posteriors[:, 3:], axis=1) > 0.1)

# Restrict both count matrices to the selected entries along their second axis.
ref = ref_raw[:, selected]
alt = alt_raw[:, selected]

In [None]:
# Index (0-3) of the largest posterior among columns 3 onwards, for each selected row.
mut_type = np.argmax(posteriors[selected, 3:], axis=1)

# Look-up tables: entry k is the genotype label assigned to mutation type k.
gt1 = np.array(['R', 'H', 'H', 'A'])[mut_type]
gt2 = np.array(['H', 'A', 'R', 'H'])[mut_type]

In [None]:
# Number of entries in gt2 (one per element of `selected` when run in notebook order).
len(gt2)

In [None]:
# Count selected loci where both inferred genotypes match the simulator's,
# i.e. correct mutation and correct direction.
gt1_matches = gt1 == dg.gt1[selected]
gt2_matches = gt2 == dg.gt2[selected]
np.sum(gt1_matches & gt2_matches)

In [None]:
# Count selected loci where the inferred genotype pair matches the simulator's pair
# in either order, i.e. correct mutation, either direction.
true_gt1, true_gt2 = dg.gt1[selected], dg.gt2[selected]
same_direction = (gt1 == true_gt1) & (gt2 == true_gt2)
swapped_direction = (gt2 == true_gt1) & (gt1 == true_gt2)
np.sum(same_direction | swapped_direction)

### Tree inference with generated data

In [None]:
from tree_inference import *
from mutation_detection import likelihood_matrices
from utilities import path_len_dist

In [None]:
# NOTE(review): this rebinds ref, alt, gt1 and gt2, which earlier cells derived from
# `posteriors`/`selected`; re-running those earlier comparison cells after this one
# will use mismatched arrays.
ref, alt, gt1, gt2 = filter_mutations(ref_raw, alt_raw, method = 'threshold', t = 0.5)

In [None]:
# Build the two likelihood matrices that are passed to TreeOptimizer.fit in the cells below.
likelihoods1, likelihoods2 = likelihood_matrices(ref, alt, gt1, gt2)

#### True tree

In [None]:
# Score the simulator's own tree instead of searching for one.
optz = TreeOptimizer()
optz.fit(likelihoods1, likelihoods2, reversible = True)
# Replace the optimizer's cell tree with a copy of dg.tree and give it the
# optimizer's mutation count before refreshing.
optz.ct = dg.tree.copy()
optz.ct.n_mut = optz.n_mut
# presumably recomputes the cell-tree quantities (e.g. ct_joint) for the new tree -- TODO confirm
optz.update_ct()

In [None]:
# ct_joint divided by the number of entries in likelihoods1 (reported further down
# in this notebook as 'Cell tree mean loglikelihood').
optz.ct_joint / likelihoods1.size

In [None]:
# optz.ct was set to a copy of dg.tree in the cell above, so this compares the tree with its own copy.
print('Distance matrix MSE to real tree:', path_len_dist(optz.ct, dg.tree))

In [None]:
# Derive the mutation tree's structure from the current cell tree.
optz.mt.fit_structure(optz.ct)
# The mt_L column belonging to the mutation-tree root is set to the sum of
# likelihoods1 along axis 1.
optz.mt_L[:,optz.mt.root.ID] = np.sum(optz.likelihoods1, axis = 1)
# presumably refreshes the mutation-tree quantities (e.g. mt_joint) -- TODO confirm
optz.update_mt()

In [None]:
# mt_joint divided by the number of entries in likelihoods1 (reported further down
# in this notebook as 'Mutation tree mean loglikelihood').
optz.mt_joint / likelihoods1.size

#### All mutations reversible

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Fresh optimizer for the reversible setting.
optz = TreeOptimizer()
# NOTE(review): likelihoods2 is passed first here, whereas the other fit() calls in
# this notebook pass (likelihoods1, likelihoods2) -- confirm the swap is intentional.
optz.fit(likelihoods2, likelihoods1, reversible = True)
# Distance to the simulator's tree before optimize() is called.
print('Distance matrix MSE to real tree:', path_len_dist(optz.ct, dg.tree))

In [None]:
# 'c' and 'm' presumably name the cell-tree and mutation-tree search spaces -- TODO confirm
# against TreeOptimizer.optimize. The irreversible run below calls optimize() with defaults.
optz.optimize(spaces = ['c', 'm'])

In [None]:
# Likelihood history recorded by the optimizer, divided by the number of entries in likelihoods1.
mean_likelihoods = np.array(optz.likelihood_history) / likelihoods1.size
# With a single array argument plt.plot uses the position in the history as x.
plt.plot(mean_likelihoods)

In [None]:
# Summary after optimize(): distance to the simulator's tree and both joint values,
# each divided by the number of entries in likelihoods1.
print('MSE of distance matrix:', path_len_dist(optz.ct, dg.tree))
print('Cell tree mean loglikelihood:', optz.ct_joint / likelihoods1.size)
print('Mutation tree mean loglikelihood:', optz.mt_joint / likelihoods1.size)

#### No mutations reversible

In [None]:
# Fresh optimizer for the irreversible setting (reversible = False).
optz = TreeOptimizer()
optz.fit(likelihoods1, likelihoods2, reversible = False)
# Distance to the simulator's tree before optimize() is called.
print('Distance matrix MSE to real tree:', path_len_dist(optz.ct, dg.tree))

In [None]:
# Run the search with optimize()'s default arguments (the reversible run above passes spaces = ['c', 'm']).
optz.optimize()

In [None]:
# Likelihood history recorded by the optimizer, divided by the number of entries in likelihoods1.
# NOTE(review): rebinds mean_likelihoods from the reversible run above.
mean_likelihoods = np.array(optz.likelihood_history) / likelihoods1.size
# With a single array argument plt.plot uses the position in the history as x.
plt.plot(mean_likelihoods)

In [None]:
# Summary after optimize(): distance to the simulator's tree and both joint values,
# each divided by the number of entries in likelihoods1.
print('MSE of distance matrix:', path_len_dist(optz.ct, dg.tree))
print('Cell tree mean loglikelihood:', optz.ct_joint / likelihoods1.size)
print('Mutation tree mean loglikelihood:', optz.mt_joint / likelihoods1.size)

### Other tests

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from utilities import *

In [7]:
import numpy as np

In [10]:
# Load previously saved result arrays from ./test_results/reversibility_50c_100m/.
# Judging by the file names these are tree distances for the reversible ('rev') and
# irreversible ('irr') settings -- TODO confirm against the script that wrote them.
dist_rev = np.load('./test_results/reversibility_50c_100m/dist_rev.npy')
dist_irr = np.load('./test_results/reversibility_50c_100m/dist_irr.npy')