In [6]:
import sys 
import networkx as nx
import pandas as pd
import numpy as np
import pickle as pic
import random

import cassiopeia.TreeSolver.simulation_tools.simulation_utils as sim_utils
import cassiopeia.TreeSolver.simulation_tools.dataset_generation as data_gen
from cassiopeia.TreeSolver.Node import Node
from cassiopeia.TreeSolver.Cassiopeia_Tree import Cassiopeia_Tree

from tqdm import tqdm_notebook
import matplotlib.pyplot as plt

import subprocess

#import seaborn as sns
import os

In [7]:
def get_character_matrix(nodes):
    """Build a character matrix with one row per node.

    Each node's ``char_string`` is cut at the first ``'_'`` and the remaining
    prefix is split on ``'|'`` to obtain that node's character states.

    :param nodes: iterable of objects exposing a ``char_string`` attribute.
    :return: ``pd.DataFrame`` with one row per node (input order) and one
        column per character, using default integer labels.
    """
    rows = [node.char_string.split("_")[0].split("|") for node in nodes]
    return pd.DataFrame(rows)

def compute_priors(C, S, p, mean=0.01, disp=0.1, skew_factor = 0.05, num_skew=1, empirical = np.array([]), mixture = 0):
    """Sample per-character prior probabilities over mutation states.

    For every character ``i`` in ``range(C)`` a weight vector over ``S``
    mutated states is obtained, either as ``sorted(empirical)`` or as ``S``
    sorted draws of ``np.random.negative_binomial(mean, disp)``. The weights
    are normalized and scaled by ``mut_rate = p * (1 + num_skew * skew_factor)``;
    the unmutated state ``'0'`` receives ``1 - mut_rate``.

    :param C: number of characters.
    :param S: number of mutated states per character (keys ``'1'``..``str(S)``).
    :param p: base mutation probability.
    :param mean: ``n`` argument passed to ``np.random.negative_binomial``.
    :param disp: ``p`` argument passed to ``np.random.negative_binomial``.
    :param skew_factor: together with ``num_skew``, inflates the mutation rate
        by the factor ``1 + num_skew * skew_factor`` (same for every character).
    :param num_skew: see ``skew_factor``.
    :param empirical: optional weights used instead of random draws when
        non-empty; must contain at least ``S`` entries.
    :param mixture: if > 0, each normalized weight is replaced by ``U(0, 1)``
        with this probability and the vector is renormalized.
    :return: ``(prior_probabilities, sp)``. ``prior_probabilities[i]`` maps a
        state string to its probability. ``sp[i]`` holds the mixed weights
        before the second normalization (only populated when ``mixture > 0``).
    :raises ValueError: if a character's weights sum to zero, which previously
        produced NaN priors silently.
    """

    def _normalize(weights, char_idx):
        # Scale weights to sum to 1; an all-zero vector cannot be normalized.
        total = np.sum(weights)
        if total == 0:
            raise ValueError(
                "compute_priors: state weights for character %d sum to zero; "
                "use a larger `mean` or non-zero `empirical` weights" % char_idx
            )
        return [w / (1.0 * total) for w in weights]

    sp = {}
    prior_probabilities = {}
    # The rate does not depend on the character index, so compute it once.
    mut_rate = p * (1 + num_skew * skew_factor)
    use_empirical = len(empirical) > 0

    for i in range(0, C):
        if use_empirical:
            # A fresh sorted list per character: the mixture step edits it in place.
            sampled_probabilities = sorted(empirical)
        else:
            sampled_probabilities = sorted([np.random.negative_binomial(mean, disp) for _ in range(S)])
        prior_probabilities[i] = {'0': (1 - mut_rate)}
        sampled_probabilities = _normalize(sampled_probabilities, i)

        if mixture > 0:
            # Replace each weight by U(0, 1) with probability `mixture`.
            for k in range(len(sampled_probabilities)):
                if np.random.uniform() <= mixture:
                    sampled_probabilities[k] = np.random.uniform()
            sp[i] = sampled_probabilities
            sampled_probabilities = _normalize(sampled_probabilities, i)

        for j in range(1, S + 1):
            prior_probabilities[i][str(j)] = mut_rate * sampled_probabilities[j - 1]

    return prior_probabilities, sp

def count_all_dropouts_leaves(leaves):
    """Count missing character entries across a collection of leaves.

    An entry is missing when it equals ``'-'`` or ``'*'`` after splitting
    ``leaf.get_character_string()`` on ``'|'``.

    :param leaves: iterable of nodes exposing ``get_character_string()``.
    :return: total number of ``'-'``/``'*'`` entries.
    """
    missing_markers = ('-', '*')
    return sum(
        1
        for leaf in leaves
        for state in leaf.get_character_string().split('|')
        if state in missing_markers
    )

In [45]:
def overlay_mutation(network, mutation_prob_map, basal_rate, cassette_size):
    """Overlay simulated mutations and resection dropouts onto ``network``.

    The root is the first node (in ``network`` iteration order) with in-degree 0.
    ``mutation_helper`` is then run from it, starting from ``root.char_vec``
    with no dropouts.

    :param network: tree of nodes carrying a ``'lifespan'`` attribute.
    :param mutation_prob_map: ``{char_index: {state: prob}}`` priors, e.g. from
        ``compute_priors``. The ``'0'`` entry is ignored. This argument is
        **not modified**: a separate map is built in which each character's
        mutated-state probabilities are renormalized to sum to 1 (the
        distribution of the outcome given that the character mutates).
    :param basal_rate: per-timestep mutation probability of a character.
    :param cassette_size: characters per cassette, used for resection dropout.
    """
    root = [n for n in network if network.in_degree(n) == 0][0]
    mutation_cache = {}

    # Build a new map instead of editing the caller's dicts in place: a shallow
    # `.copy()` by the caller would not protect the inner per-character dicts.
    conditional_map = {}
    for char_idx, state_probs in mutation_prob_map.items():
        mutated = {state: prob for state, prob in state_probs.items() if state != '0'}
        total = sum(mutated.values())
        conditional_map[char_idx] = {state: prob / total for state, prob in mutated.items()}

    conditional_map['basal_mut_rate'] = basal_rate

    mutation_helper(network, root, conditional_map, mutation_cache, root.char_vec, [], cassette_size)

def mutation_helper(network, node, mutation_prob_map, mutation_cache, curr_mutations, dropout_indices, cassette_size):
    """Sample mutations and resection dropouts for ``node``, then recurse into its children.

    :param network: tree whose nodes carry a ``'lifespan'`` attribute, used as
        the branch length in timesteps.
    :param node: current node; its ``char_vec`` and ``char_string`` are overwritten.
    :param mutation_prob_map: ``{char_index: {state: prob}}`` giving each
        character's outcome distribution when it mutates, plus the key
        ``'basal_mut_rate'`` holding the per-timestep mutation probability
        (as prepared by ``overlay_mutation``).
    :param mutation_cache: dict ``t -> P(a character mutates within t steps)``,
        shared across the whole traversal and filled lazily.
    :param curr_mutations: the parent's character vector (copied, not modified).
    :param dropout_indices: the parent's heritable dropout indices (copied, not modified).
    :param cassette_size: characters per cassette, used to group simultaneous cuts.

    Side effects: sets ``node.char_vec``, ``node.char_string`` and
    ``network.nodes[node]['dropout']`` for ``node`` and every visited descendant.
    """
    new_sample = curr_mutations.copy()
    new_dropout_indices = dropout_indices.copy()
    t = network.nodes[node]['lifespan']
    if t == 0:
        # Zero-length branch: inherit the parent's state. Recursion stops here;
        # in phylo_forward_pass only childless nodes can keep lifespan 0.
        node.char_vec = new_sample
        node.char_string = '|'.join([str(char) for char in new_sample])
        network.nodes[node]['dropout'] = new_dropout_indices
        return
    mut_rate = mutation_prob_map['basal_mut_rate']

    # mutation_cache[t] = P(a character mutates at least once in t timesteps)
    #                   = sum_{k=1..t} r * (1 - r) ** (k - 1), extended incrementally.
    if len(mutation_cache) == 0:
        mutation_cache[1] = mut_rate
    if t in mutation_cache:
        p = mutation_cache[t]
    else:
        t_p = max(mutation_cache.keys())
        p = mutation_cache[t_p]
        for t_temp in range(t_p + 1, t + 1):
            p += mut_rate * (1 - mut_rate) ** (t_temp - 1)
            mutation_cache[t_temp] = p

    # Only unmutated characters that have not dropped out can still be cut.
    base_chars = [i for i, char in enumerate(new_sample)
                  if char == '0' and i not in new_dropout_indices]

    times = {}
    draws = np.random.binomial(len(base_chars), p)
    # Draw *distinct* characters. Sampling with replacement could pick one
    # character twice; it would then collide with itself in `times` below and
    # be dropped out as if two cuts had happened in its cassette.
    chosen_ind = np.random.choice(base_chars, draws, replace=False)
    for i in chosen_ind:
        values, probabilities = zip(*mutation_prob_map[i].items())
        new_sample[i] = np.random.choice(values, p=probabilities)
        # Cut time on this branch; cuts whose +/-2-step windows overlap are
        # treated as simultaneous.
        time = np.random.choice(range(1, t + 1))
        for ti in range(time - 2, time + 3):
            if 1 <= ti <= t:
                times.setdefault(ti, []).append(i)

    for time in sorted(times.keys()):
        if len(times[time]) > 1:
            not_dropped = [i for i in times[time] if i not in new_dropout_indices]
            for c in range(0, len(new_sample) // cassette_size):
                lo, hi = c * cassette_size, (c + 1) * cassette_size
                cass_indices = [i for i in not_dropped if lo <= i < hi]
                if len(cass_indices) > 1:
                    # Simultaneous cuts in one cassette excise everything between them.
                    new_dropout_indices.extend(range(min(cass_indices), max(cass_indices) + 1))

    node.char_vec = new_sample
    node.char_string = '|'.join([str(char) for char in new_sample])
    network.nodes[node]['dropout'] = new_dropout_indices

    if network.out_degree(node) > 0:
        for child in network.successors(node):
            mutation_helper(network, child, mutation_prob_map, mutation_cache, new_sample, new_dropout_indices, cassette_size)

def overlay_heritable_dropout(network):
    """Mask heritable dropouts across the whole tree, starting from its root.

    The root is the first node (in ``network`` iteration order) with
    in-degree 0; ``h_dropout_helper`` then walks it and every descendant.
    """
    roots = [candidate for candidate in network if network.in_degree(candidate) == 0]
    h_dropout_helper(network, roots[0])

def h_dropout_helper(network, node):
    """Apply heritable dropout to ``node`` and, recursively, to its descendants.

    Each index listed in ``network.nodes[node]['dropout']`` is set to ``'-'``
    in a copy of ``node.char_vec``. The copy becomes the node's ``char_vec``,
    and ``char_string`` is rebuilt as the ``'|'``-joined states.
    """
    masked = node.char_vec.copy()
    for idx in network.nodes[node]['dropout']:
        masked[idx] = '-'
    node.char_vec = masked
    node.char_string = '|'.join(str(state) for state in masked)

    if network.out_degree(node) == 0:
        return
    for child in network.successors(node):
        h_dropout_helper(network, child)
            
def add_stochastic_leaves(leaves, dropout_prob, cassette_size):
    """Apply non-heritable, per-leaf stochastic dropout at cassette granularity.

    For every leaf, each complete cassette (``len(char_vec) // cassette_size``
    of them) is independently blanked to ``'-'`` when
    ``random.uniform(0, 1) <= dropout_prob``. Trailing characters that do not
    fill a whole cassette are never dropped. The leaf's ``char_vec`` is
    replaced by the updated copy and ``char_string`` is rebuilt.

    Note: this uses Python's ``random`` module, so seeding ``np.random`` alone
    does not make this step reproducible.
    """
    for leaf in leaves:
        states = leaf.char_vec.copy()
        n_full_cassettes = len(states) // cassette_size
        for cassette in range(n_full_cassettes):
            if random.uniform(0, 1) <= dropout_prob:
                start = cassette * cassette_size
                for pos in range(start, start + cassette_size):
                    states[pos] = '-'
        leaf.char_vec = states
        leaf.char_string = '|'.join(map(str, states))
        
# def mutation_helper_old(network, node, mutation_prob_map, mutation_cache, curr_mutations):
#     new_sample = curr_mutations
#     t = network.nodes[node]['lifespan']
#     for i in range(0, len(new_sample)):
#         if new_sample[i] == '0':
#             values, probabilities = zip(*mutation_prob_map[i].items())
#             if t in mutation_cache[i]:
#                 new_probs = mutation_cache[i][t]
#             else:
#                 new_probs = []
#                 t_p = 0
#                 if len(mutation_cache) == 0:
#                     t_p = 1
#                     new_probs = probabilities
#                 else:
#                     t_p = max(mutation_cache[i])
#                     new_probs = mutation_cache[i][t_p]
#                 for t_temp in range(t_p, t + 1):
#                     new_probs[0] *= probabilities[0]
#                     for p in range(1, len(new_probs)):
#                         new_probs += probabilities[0] ** (t_temp - 1) * probabilities[p]
#                     mutation_cache[i][t_temp] = new_probs
#             new_character = np.random.choice(values, p=new_probs)
#             new_sample[i] = new_character
#     node.char_vec = new_sample
#     node.char_string = '|'.join([str(char) for char in new_sample])
    
#     if network.out_degree(node) > 0:
#         for i in network.successors(node):
#             mutation_helper(network, i, mutation_prob_map, mutation_cache, new_sample)

In [46]:
def phylo_forward_pass(
    cassette_size = 3,
    cassette_number = 10,
    timesteps = 100, 
    min_division_rate = 0.076,
    #U = lambda: np.random.exponential(1, 1),
    fitness_rate = 0.000,
    epsilon = 0.001,
    cell_death = 0.001
):
    """Simulate a forward-time lineage tree with per-cell fitness drift and death.

    Starts from a single cell whose ``cassette_size * cassette_number``
    characters are all ``'0'`` and runs ``timesteps + 1`` rounds
    (t = 0..timesteps). In each round every living cell first has its
    ``'lifespan'`` node attribute incremented. Then:

    * it dies if ``np.random.random() < cell_death``. The dead cell and any
      ancestors left without children are removed from the graph, stopping
      at a node that still has children or at the root;
    * otherwise, with probability ``fitness_rate``, its ``'fitness'`` changes
      by ``+epsilon`` or ``-epsilon`` with equal odds (fitness is not clamped);
    * then, with probability equal to its (possibly updated) fitness, it
      divides into two children. The children inherit its fitness and start
      with ``'lifespan'`` 0. If it does not divide, it stays in the population.

    No mutations are placed here: every node keeps the all-``'0'`` character
    list. ``'lifespan'`` (the rounds a cell spent in the population before
    dividing or the run ending) is the branch-length information later used
    by ``mutation_helper``.

    :param cassette_size: characters per cassette.
    :param cassette_number: number of cassettes.
    :param timesteps: number of rounds minus one (the loop runs ``timesteps + 1`` times).
    :param min_division_rate: the root's initial ``'fitness'``, i.e. its
        per-round division probability.
    :param fitness_rate: per-round probability of a fitness change.
    :param epsilon: magnitude of each fitness change.
    :param cell_death: per-round death probability.
    :return: ``(state_tree, leaves)``.
        * ``state_tree`` is ``Cassiopeia_Tree('simulated', network=network)``,
          where each graph node has been relabelled to
          ``Node("StateNode<i>", <character list>, pid=<id>, is_target=False)``
          parsed from its string key.
        * ``leaves`` is the list of relabelled nodes with out-degree 0 and
          in-degree 1. It is empty if every cell died (the root has in-degree 0).
    """
    
    characters = cassette_size * cassette_number
    
#     division_rate = min_division_rate + np.random.exponential(1, 1) * (1 - min_division_rate) # probability that cell will double per time-step
    division_rate = min_division_rate
    
    network = nx.DiGraph()
    # A cell is [character list, unique id string]. Graph keys are
    # sim_utils.node_to_string(cell); judging by the parsing at the end of this
    # function, presumably "<'|'-joined chars>_<id>" (TODO confirm).
    current_cells = [[['0' for _ in range(0, characters)], '0']]
    
    network.add_node(sim_utils.node_to_string(current_cells[0]))
    network.nodes[sim_utils.node_to_string(current_cells[0])]['fitness'] = division_rate
    network.nodes[sim_utils.node_to_string(current_cells[0])]['lifespan'] = 0
    uniq = 1  # next unique cell id
    
    for t in range(0, timesteps + 1):
        temp_current_cells = []  # population for the next round
#         if len(current_cells) == 0:
#             print("all cells dead, terminating")
#             break
#         current_fitnesses = [network.nodes[sim_utils.node_to_string(n)]['fitness'] for n in current_cells]
#         norm = np.max(current_fitnesses)
        
        for node in current_cells:
            fitness = network.nodes[sim_utils.node_to_string(node)]['fitness']
            # One more round spent in the population as this node.
            network.nodes[sim_utils.node_to_string(node)]['lifespan'] += 1
            
            if np.random.random() >= cell_death:
                
                if np.random.random() <= fitness_rate: #cell gains a fitness mutation
#                     if t == (timesteps - 1):
#                         network.node[node]['fitness'] = fitness
#                     else:
#                     s = max(1e-20, U()[0])
                    if np.random.random() <= 0.5:
                        fitness = fitness + epsilon
                    else:
                        fitness = fitness - epsilon
                    network.nodes[sim_utils.node_to_string(node)]['fitness'] = fitness
                
                if np.random.random() <= fitness: # t != (timesteps - 1): #cell divides
                    # Two daughters: same (unmutated) characters, fresh ids,
                    # parent's fitness, and a new branch (lifespan 0).
                    for _ in range(0,2):
                        parent_fitness = network.nodes[sim_utils.node_to_string(node)]['fitness']
                        temp_current_cells.append([node[0], str(uniq)])
                        network.add_edge(sim_utils.node_to_string(node), sim_utils.node_to_string([node[0], str(uniq)]))
                        network.nodes[sim_utils.node_to_string([node[0], str(uniq)])]['fitness'] = parent_fitness
                        network.nodes[sim_utils.node_to_string([node[0], str(uniq)])]['lifespan'] = 0
                        uniq += 1
                else: #cell does not divide
                    temp_current_cells.append(node)
                                    
            else: #if cell dies
                # Prune the dead lineage: walk up removing childless nodes,
                # stopping at an ancestor with other children or at the root.
                curr_parent = sim_utils.node_to_string(node)
                while network.out_degree(curr_parent) < 1 and network.in_degree(curr_parent) > 0:
                    next_parent = list(network.predecessors(curr_parent))[0]
                    network.remove_node(curr_parent)
                    curr_parent = next_parent
                
        current_cells = temp_current_cells
#         print("timestep:" + str(t))
#         print("size:" + str(len(current_cells)))
        
    # Replace string keys with Node objects parsed from "<chars>_<id>".
    rdict = {}
    i = 0
    for n in network.nodes:
        nn = Node("StateNode" + str(i), n.split("_")[0].split("|"), pid = n.split("_")[1], is_target=False)
        i += 1
        rdict[n] = nn

    network = nx.relabel_nodes(network, rdict)
    
#     source = [x for x in network.nodes() if network.in_degree(x)==0][0]

#     max_depth = max(nx.shortest_path_length(network,source,node) for node in network.nodes())
#     shortest_paths = nx.shortest_path_length(network,source)

#     leaves = [x for x in network.nodes() if network.out_degree(x)==0 and network.in_degree(x) == 1 and shortest_paths[x] == max_depth]

    # Dead lineages were pruned, so childless non-root nodes are the surviving cells.
    leaves = [n for n in network if network.out_degree(n) == 0 and network.in_degree(n) == 1] 
    
    state_tree = Cassiopeia_Tree('simulated', network = network)
    return state_tree, leaves

In [81]:
def states_per_char(cm):
    """Count the distinct mutated states observed in each column of ``cm``.

    ``'0'`` (unmutated) and ``'-'`` (missing) are not counted. Any other value,
    including ``'*'``, counts as a state.

    :param cm: character matrix as a ``pd.DataFrame`` (rows = cells, columns = characters).
    :return: list with one count per column, in column order.
    """
    ignored = ('0', '-')
    n_rows, n_cols = cm.shape
    counts = []
    for col in range(n_cols):
        observed = set()
        for row in range(n_rows):
            state = cm.iloc[row, col]
            if state not in ignored:
                observed.add(state)
        counts.append(len(observed))
    return counts

In [128]:
# --- Benchmark configuration ---------------------------------------------
# Output directory (site-specific absolute path; adjust for your environment).
path = "/data/yosef2/users/richardz/projects/Yule/benchmarking/t500"
os.makedirs(path, exist_ok=True)

timesteps = 500
division_rate = 0.0166
fitness_rate = 0.0025
epsilon = 0.001
cell_death = 0.006
mutation_rate = 0.00026
cassette_size = 3
cassette_number = 10
num_characters = cassette_size * cassette_number  # width of the simulated character matrix
num_states = 100          # mutated states per character in the priors
prior_nb_mean = 5         # `mean` (n) argument of compute_priors' negative binomial
prior_nb_disp = 0.5       # `disp` (p) argument of compute_priors' negative binomial
dropout_rate = 0.20
min_leaves, max_leaves = 300, 500
num_replicates = 50
random_seed = None        # set to an int to make the whole run reproducible

if random_seed is not None:
    # Tree/mutation simulation uses np.random; add_stochastic_leaves uses `random`.
    np.random.seed(random_seed)
    random.seed(random_seed)

size = []       # leaf count per replicate
drop_perc = []  # fraction of missing ('-' or '*') entries per replicate
avg_spc = []    # mean number of observed mutated states per character

for i in range(1, num_replicates + 1):
    # Rejection-sample trees until the leaf count lies in [min_leaves, max_leaves].
    while True:
        out, leaves = phylo_forward_pass(cassette_size = cassette_size,
                                         cassette_number = cassette_number,
                                         timesteps = timesteps,
                                         min_division_rate = division_rate,
                                         fitness_rate = fitness_rate,
                                         epsilon = epsilon,
                                         cell_death = cell_death)
        if min_leaves <= len(leaves) <= max_leaves:
            break

    prior_probabilities = compute_priors(num_characters, num_states, mutation_rate,
                                         prior_nb_mean, prior_nb_disp,
                                         skew_factor=0.0, num_skew=1)[0]
    with open(path + '/priors' + str(i) + '.pkl', 'wb') as fh:
        pic.dump(prior_probabilities, fh)

    # Hand over per-character copies so the overlay cannot alter the saved priors.
    overlay_mutation(out.network,
                     {char: dict(states) for char, states in prior_probabilities.items()},
                     mutation_rate, cassette_size)
    ground_cm = get_character_matrix(leaves)
    ground_cm.to_csv(path + '/ground_truth_cm' + str(i) + '.txt', sep = '\t')

    overlay_heritable_dropout(out.network)
    add_stochastic_leaves(leaves, dropout_rate, cassette_size)
    dropout_cm = get_character_matrix(leaves).astype(str)
    dropout_cm.index = ['c' + str(k) for k in range(dropout_cm.shape[0])]
    dropout_cm.to_csv(path + '/dropout_cm' + str(i) + '.txt', sep = '\t')
    with open(path + '/dropout_net' + str(i) + '.pkl', 'wb') as fh:
        pic.dump(out, fh)

    num_dropped = int(dropout_cm.isin(['-', '*']).to_numpy().sum())
    drop_perc.append(num_dropped / dropout_cm.size)

    num_leaves = len(leaves)
    size.append(num_leaves)

    spc = states_per_char(dropout_cm)
    mean_spc = sum(spc) / len(spc)
    avg_spc.append(mean_spc)

    print(i, num_leaves, mean_spc)

1 336 4.866666666666666
2 347 4.633333333333334
3 334 5.5
4 489 7.366666666666666
5 481 7.0
6 344 5.266666666666667
7 355 5.566666666666666
8 396 6.233333333333333
9 348 4.366666666666666
10 445 6.9
11 339 5.4
12 487 6.2
13 309 4.7
14 442 6.8
15 328 5.0
16 348 5.333333333333333
17 358 5.2
18 416 6.933333333333334
19 362 4.466666666666667
20 410 6.833333333333333
21 434 7.266666666666667
22 395 7.1
23 417 6.233333333333333
24 331 5.6
25 341 5.433333333333334
26 399 5.4
27 384 5.333333333333333
28 346 5.333333333333333
29 491 6.5
30 306 4.366666666666666
31 495 6.9
32 403 6.033333333333333
33 436 7.133333333333334
34 377 5.166666666666667
35 383 6.1
36 366 5.933333333333334
37 316 4.166666666666667
38 370 5.433333333333334
39 404 4.966666666666667
40 479 7.7
41 349 5.033333333333333
42 468 7.966666666666667
43 476 7.933333333333334
44 466 5.966666666666667
45 325 3.8666666666666667
46 361 5.433333333333334
47 419 6.633333333333334
48 413 5.566666666666666
49 418 6.833333333333333
50 320 

In [129]:
import statistics as stat

# Mean, then median, for: dropout fraction, average states per character, leaf count.
for metric in (drop_perc, avg_spc, size):
    print(stat.mean(metric))
    print(stat.median(metric))

0.2012195847616055
0.20070501394030804
5.841333333333333
5.566666666666666
391.24
383.5


In [None]:
# Per-replicate summary table built only from statistics the simulation loop
# actually collects (heights/widths/degree stats are never computed here).
df = pd.DataFrame([size, drop_perc, avg_spc])
df.index = ["size", "drop_perc", "avg_spc"]
df

In [None]:
# Persist the per-replicate summary table next to the simulated data.
stats_file = path + '/tree_stats.txt'
df.to_csv(stats_file, sep='\t')

In [None]:
# Re-overlay mutations on the last simulated tree with a per-step rate of ~1e-4,
# treating the characters as cassettes of 30.
# NOTE(review): 30 differs from the cassette_size (3) used in the benchmark loop,
# and the overlay starts from the root's already-overlaid char_vec — confirm intended.
overlay_mutation(
    network=out.network,
    mutation_prob_map=prior_probabilities.copy(),
    basal_rate=(1 - .9999),
    cassette_size=30,
)
overlay_heritable_dropout(out.network)

In [None]:
# Debugging aid: each node's character string followed by its lifespan.
graph = out.network
for state_node in graph.nodes():
    print(state_node.get_character_string())
    print(graph.nodes[state_node]['lifespan'])

In [None]:
def post_process_tree(network):
    """Collapse nodes whose character vector equals their parent's, in place.

    The root is the first node (in ``network`` iteration order) with in-degree
    0. Each of its children is passed to ``post_process_helper``, which removes
    duplicates of the parent and re-attaches their children.
    """
    root = [n for n in network if network.in_degree(n) == 0][0]
    # Snapshot the successors: the helper may remove root's children and add
    # new edges from root, and mutating the graph while iterating a live
    # successor iterator raises "dictionary changed size during iteration".
    for node in list(network.successors(root)):
        post_process_helper(network, node, root)
    
def post_process_helper(network, node, parent):
    """Recursively remove ``node`` if its character vector equals ``parent``'s.

    If ``node.char_vec == parent.char_vec``, ``node`` is removed and each of its
    children is attached directly to ``parent`` (via ``add_edge`` without
    attributes) and processed against ``parent``. Otherwise processing
    continues below ``node``. Modifies ``network`` in place.
    """
    # Snapshot the children before mutating the graph. Removing or adding nodes
    # and edges while iterating a live successor iterator raises RuntimeError.
    children = list(network.successors(node))
    if parent.char_vec == node.char_vec:
        network.remove_node(node)
        for child in children:
            network.add_edge(parent, child)
            post_process_helper(network, child, parent)
    else:
        for child in children:
            post_process_helper(network, child, node)

