In [94]:
import msprime
import tskit
import hgt_simulation
import hgt_sim_args
import numpy as np
import random
import torch
import os
import time
import re
import secrets
import copy
import json
import h5py
import uuid
import glob
import shutil

from random import randint
from collections import namedtuple
from collections import defaultdict
from collections import deque
from typing import List, Union
from sbi.utils import BoxUniform
from sklearn.decomposition import PCA
from scipy.spatial.distance import pdist, squareform

from concurrent.futures import ProcessPoolExecutor


# --- Simulation parameters ---------------------------------------------------
num_samples = 1000  # number of sampled genomes (leaves)
theta = 1  # not referenced in the visible part of this script
rho = 1  # rate handed to the loss-mutation simulation further down
hgt_rate = 20
num_genes = 1  # passed to the HGT simulation as the number of sites
nucleotide_mutation_rate = 0.1
gene_length = 1000  # sequence length of each per-gene tree
pca_dimensions = 10
ce_from_nwk = None  # optional core tree in newick format; simulated if None
seed = None  # set to an int to make the whole run reproducible
distance_matrix = None
multidimensional_scaling_dimensions = 100

start_time = time.time()

if ce_from_nwk is None and num_samples is None:
    raise ValueError(
        "Neither a core tree or parameters for simulation were provided. Choose either."
    )

# Resolve the seed BEFORE the first simulation.  Previously the core tree was
# simulated with ``random_seed=None`` whenever no seed was given, so re-running
# the script with the seed drawn here could not reproduce the same core tree.
# The range [2, 2**32 - 3] keeps ``seed - 1`` (used further down) at >= 1.
if seed is None:
    seed = secrets.randbelow(2**32 - 4) + 2

if ce_from_nwk is None:
    # No core tree supplied: simulate a single haploid, non-recombining tree
    # and hand it on in newick format.
    core_tree = msprime.sim_ancestry(
        samples=num_samples,
        sequence_length=1,
        ploidy=1,
        recombination_rate=0,
        gene_conversion_rate=0,
        gene_conversion_tract_length=1,  # One gene
        random_seed=seed,
    )

    ce_from_nwk = core_tree.first().newick()

# The number of MDS dimensions is capped at the number of samples.
if num_samples < multidimensional_scaling_dimensions:
    multidimensional_scaling_dimensions = num_samples

random.seed(seed)
np.random.seed(seed)
#print("Seed: ", seed)

### Simulate the HGT events on top of the core tree.

# All recombination-like processes except HGT are switched off; the core
# tree is supplied in newick form rather than as a tree sequence.
args = hgt_sim_args.Args(
    sample_size=num_samples,
    num_sites=num_genes,
    hgt_rate=hgt_rate,
    recombination_rate=0,
    gene_conversion_rate=0,
    ce_from_nwk=ce_from_nwk,
    ce_from_ts=None,
    random_seed=seed,
)

# run_simulate returns the tree sequence together with the HGT edges.
ts, hgt_edges = hgt_simulation.run_simulate(args)
        
### Place mutations

alleles = ["absent", "present"]

# Gain model: the root state is "absent" and every mutation event ends in
# "present", whatever the previous state was.
gains_model = msprime.MatrixMutationModel(
    alleles=alleles,
    root_distribution=[1, 0],
    transition_matrix=[[0, 1], [0, 1]],
)

ts_gains = msprime.sim_mutations(
    ts, rate=1, model=gains_model, keep=True, random_seed=seed
)

# Keep layering gain mutations (with a different seed each round) until the
# number of sites has reached the sequence length.
k = 1
while ts_gains.num_sites < ts_gains.sequence_length:
    ts_gains = msprime.sim_mutations(
        ts_gains, rate=1, model=gains_model, keep=True, random_seed=seed + k
    )
    k += 1

# Reduce the gains to exactly one randomly chosen mutation per site.
tables = ts_gains.dump_tables()

mutations_by_site = {}
for mut in tables.mutations:
    mutations_by_site.setdefault(mut.site, []).append(mut)

tables.mutations.clear()

for site, mutations in mutations_by_site.items():
    selected_mutation = random.choice(mutations)

    tables.mutations.add_row(
        site=selected_mutation.site,
        node=selected_mutation.node,
        derived_state=selected_mutation.derived_state,
        parent=-1,
        metadata=None,
        time=selected_mutation.time,
    )

    # One "absent" sentinel mutation per leaf, placed at a tiny positive
    # time so that it can be recognised by its time stamp later on.
    for leaf_position in range(num_samples):
        tables.mutations.add_row(
            site=selected_mutation.site,
            node=leaf_position,
            derived_state="absent",
            time=0.00000000001,
        )

ts_gains = tables.tree_sequence()


# Place losses:

# Loss model: every mutation event ends in "absent".
losses_model = msprime.MatrixMutationModel(
    alleles=alleles,
    root_distribution=[1, 0],
    transition_matrix=[[1, 0], [1, 0]],
)

# keep=True retains the gain and sentinel mutations that are already present.
ts_gains_losses = msprime.sim_mutations(
    ts_gains,
    rate=rho,
    model=losses_model,
    keep=True,
    random_seed=seed - 1,
)

tables = ts_gains_losses.dump_tables()

# The clonal root is taken to be the MRCA of all samples in the first tree.
clonal_node = ts_gains.first().mrca(*range(num_samples))

print("Clonal node:", clonal_node)

# Every site receives an additional "absent" mutation at the clonal root,
# time-stamped with the time of that node.
for site in mutations_by_site:
    tables.mutations.add_row(
        site=site,
        node=clonal_node,
        derived_state="absent",
        parent=-1,
        metadata=None,
        time=tables.nodes.time[clonal_node],
    )

tables.sort()
ts_gains_losses = tables.tree_sequence()

### Calculate the gene absence presence matrix:

# Lightweight record for a leaf that ends up carrying the gene.
MutationRecord = namedtuple("MutationRecord", "site_id mutation_id node is_hgt")

# Start from the gains/losses tables but with an empty mutation table; it is
# refilled site by site below.
tables = ts_gains_losses.dump_tables()
tables.mutations.clear()  # SIMPLE VERSION!

# NOTE(review): the "- 1" offsets presumably mirror the node numbering used by
# hgt_simulation for HGT events -- confirm against that module.
hgt_parent_nodes = [edge.parent - 1 for edge in hgt_edges]
hgt_children_nodes = [edge.child for edge in hgt_edges]

# Maps a node to the list of nodes reachable from it through an HGT edge.
hgt_parent_children = defaultdict(list)
for parent in hgt_parent_nodes:
    hgt_parent_children[parent].append(parent - 1)

# Template for the per-gene trees: same nodes and populations as the full
# tree sequence, no edges yet, sequence length equal to the gene length.
tables_gene = tskit.TableCollection(sequence_length=gene_length)
tables_gene.nodes.replace_with(ts_gains_losses.tables.nodes)
tables_gene.populations.replace_with(ts_gains_losses.tables.populations)

# One independent copy of the template per gene, plus per-gene bookkeeping.
tables_gene_list = [copy.deepcopy(tables_gene) for _ in range(num_genes)]
gene_number_hgt_events_passed = [0] * num_genes
gene_trees_list = []


# For every site, walk from the node that carries the gain ("present")
# mutation towards the leaves, additionally following the HGT edges stored in
# `hgt_parent_children`, and work out which leaves still carry the gene.
# Side effects per site:
#   * rows are appended to `tables.mutations` (the rebuilt mutation table),
#   * edges are appended to `tables_gene_list[site.id]` (the per-gene tree),
#   * `gene_number_hgt_events_passed[site.id]` is incremented.
for tree in ts_gains_losses.trees():

    for site in tree.sites():

        # Flag per node: True once the node has been entered through an HGT
        # edge during the first pass. Rebuilt for every site.
        hgt_parent_children_passed = [False] * ts_gains_losses.num_nodes
        mutations = site.mutations
        #site = list(tree.sites())[0]

        # The first "present" mutation of the site is treated as the gain.
        # Raises IndexError if the site has no "present" mutation.
        present_mutation = [m for m in mutations if m.derived_state == "present"][0]
        # All nodes that carry any "absent" mutation at this site (regardless of time).
        absent_mutation_nodes = {m.node for m in mutations if m.derived_state == "absent"}

        # node -> "absent" mutations that are younger than the gain.
        absent_mutations = defaultdict(list)
        for m in mutations:
            if m.derived_state == "absent" and m.time < present_mutation.time:
                absent_mutations[m.node].append(m)

        # node -> True once the children of that node have been queued.
        branching_nodes_reached_before = defaultdict(list)
        for node_id in range(ts_gains_losses.num_nodes):
            branching_nodes_reached_before[node_id] = False

        print("Present_mutation at node: ", present_mutation.node)

        # Work items are tuples (node, hgt_on_path, hgt_in_last_step, n_hgt_edges).
        branching_nodes_to_process = deque([(present_mutation.node, False, False, 0)])
        # The second variable describes if a hgt edge was passed the whole way down to the actual node. 
        # The third describes if a hgt edge was passed in the last step.
        # The fourth is the number of hgt edges that were passed.

        # Leaves found to carry the gene, as MutationRecord entries.
        child_mutations = []

        if present_mutation.node < num_samples: # Gain directly above leaf:
            # The leaf carries the gene only if the gain is the second-youngest
            # mutation on that node, i.e. only the sentinel lies below it.
            if present_mutation.id == sorted([mut for mut in mutations if mut.node == present_mutation.node], key=lambda m: m.time)[1].id:
                sentinel_mutation = min([mut for mut in absent_mutations[present_mutation.node]], key=lambda m: m.time)
                child_mutations.append(MutationRecord(
                    site_id=site.id,
                    mutation_id=sentinel_mutation.id,
                    node=sentinel_mutation.node,
                    is_hgt=False
                ))
                tables_gene_list[site.id].edges.add_row(
                        left = 0, right = gene_length, parent = tree.parent(present_mutation.node), child = sentinel_mutation.node
                )

        else: 

            # To see, which HGT edges are passed, we have to go through the tree two times. 
            # First, we detect all passed HGT edges, then we calculate the presence of mutations in the leaves.

            # First time going through:
            while branching_nodes_to_process:

                last_branching_node = branching_nodes_to_process.popleft()
                selected_branch_nodes_to_process = deque([last_branching_node])
                # Mark the parent of the dequeued node so that its children
                # are not queued a second time.
                branching_nodes_reached_before[tree.parent(last_branching_node[0])] = True

                while selected_branch_nodes_to_process:

                    child_node = selected_branch_nodes_to_process.popleft()    

                    # If there is a mutation on the edge, find the earliest one.
                    if not child_node[2] and child_node[0] in absent_mutation_nodes:

                        absent_mutation_after_gain_at_node = absent_mutations[child_node[0]]

                        # Only keep descending when none of the "absent"
                        # mutations on this node is younger than the gain;
                        # otherwise this pass stops at the node.
                        if not absent_mutation_after_gain_at_node: # empty
                            children = tree.children(child_node[0])
                            if len(children) > 1:
                                # Branching node: its children become new
                                # starting points in the outer queue.
                                for child in reversed(children):
                                    if not branching_nodes_reached_before[child_node[0]]:
                                        branching_nodes_to_process.extendleft([(child, child_node[1], False, 0)])
                                    #else:
                                    #    print("Child passed before", " Child: ", child)
                            else:
                                # Unary node: continue along the same branch.
                                for child in reversed(children):
                                    selected_branch_nodes_to_process.extendleft([(child, child_node[1], False, child_node[3])])
                            # Follow an outgoing HGT edge, if this node has one.
                            if hgt_parent_children[child_node[0]]:
                                selected_branch_nodes_to_process.extendleft([(hgt_parent_children[child_node[0]][0], True, True, child_node[3] + 1)])
                                hgt_parent_children_passed[hgt_parent_children[child_node[0]][0]] = True # The child of the hgt_edge is marked
                                #print("HGT at Node: ", hgt_parent_children[child_node[0]][0])


                    # If there is no mutation, add child nodes.
                    else:
                        children = tree.children(child_node[0])
                        if len(children) > 1:
                            for child in reversed(children):
                                #print("Child: ", child)
                                if not branching_nodes_reached_before[child_node[0]]:
                                    branching_nodes_to_process.extendleft([(child, child_node[1], False, 0)])
                                #else:
                                #    print("Child passed before", " Child: ", child)
                        else:
                            for child in reversed(children):
                                #print("Child: ", child)
                                selected_branch_nodes_to_process.extendleft([(child, child_node[1], False, child_node[3])])
                        if hgt_parent_children[child_node[0]]:
                            selected_branch_nodes_to_process.extendleft([(hgt_parent_children[child_node[0]][0], True, True, child_node[3] + 1)])
                            hgt_parent_children_passed[hgt_parent_children[child_node[0]][0]] = True # The child of the hgt_edge is marked
                            #print("HGT at Node: ", hgt_parent_children[child_node[0]][0])

            # Second time going through:

            # Restart from the gain with fresh bookkeeping;
            # hgt_parent_children_passed is kept from the first pass.
            branching_nodes_to_process = deque([(present_mutation.node, False, False, 0)])

            branching_nodes_reached_before = defaultdict(list)
            for node_id in range(ts_gains_losses.num_nodes):
                branching_nodes_reached_before[node_id] = False

            while branching_nodes_to_process:

                last_branching_node = branching_nodes_to_process.popleft()
                selected_branch_nodes_to_process = deque([last_branching_node])
                branching_nodes_reached_before[tree.parent(last_branching_node[0])] = True

                while selected_branch_nodes_to_process:

                    child_node = selected_branch_nodes_to_process.popleft()

                    # A node that was entered through an HGT edge in the first
                    # pass is skipped when it is reached by a non-HGT step.
                    if not child_node[2] and hgt_parent_children_passed[child_node[0]]:
                        #print("Incoming HGT edge registered at node: ", child_node[0])
                        continue

                    # If there is a mutation on the edge, find the earliest one.
                    if not child_node[2] and child_node[0] in absent_mutation_nodes:

                        absent_mutation_after_gain_at_node = absent_mutations[child_node[0]]

                        if absent_mutation_after_gain_at_node: # not empty
                            # Largest time = oldest "absent" mutation below the gain.
                            earliest_mutation = max(
                                absent_mutations[child_node[0]], 
                                key=lambda m: m.time
                            )

                            # The time stamp of the leaf sentinels: if the
                            # oldest "absent" mutation is the sentinel itself,
                            # no real loss lies in between and the leaf
                            # carries the gene. Otherwise the branch ends here.
                            if earliest_mutation.time == 0.00000000001:
                                earliest_mutation.derived_state = "present"
                                child_mutations.append(MutationRecord(
                                    site_id=site.id,
                                    mutation_id=earliest_mutation.id,
                                    node=earliest_mutation.node,
                                    is_hgt=child_node[1]
                                ))
                                # Gene-tree edge from the parent of the current
                                # branch start down to the leaf.
                                tables_gene_list[site.id].edges.add_row(
                                    left = 0, right = gene_length, parent = tree.parent(last_branching_node[0]), child = earliest_mutation.node
                                )
                                gene_number_hgt_events_passed[site.id] += child_node[3]
                        else:
                            children = tree.children(child_node[0])
                            if len(children) > 1:
                                for child in reversed(children):
                                    if not branching_nodes_reached_before[child_node[0]]:
                                        branching_nodes_to_process.extendleft([(child, child_node[1], False, 0)])
                                # Branching node: close the current branch with
                                # an edge in the gene tree.
                                tables_gene_list[site.id].edges.add_row(
                                        left = 0, right = gene_length, parent = tree.parent(last_branching_node[0]), child = child_node[0]
                                )
                                gene_number_hgt_events_passed[site.id] += child_node[3]
                            else:
                                for child in reversed(children):
                                    selected_branch_nodes_to_process.extendleft([(child, child_node[1], False, child_node[3])])
                            if hgt_parent_children[child_node[0]]:
                                selected_branch_nodes_to_process.extendleft([(hgt_parent_children[child_node[0]][0], True, True, child_node[3] + 1)])  


                    # If there is no mutation, add child nodes.
                    else:
                        children = tree.children(child_node[0])
                        if len(children) > 1:
                            for child in reversed(children):
                                if not branching_nodes_reached_before[child_node[0]]:
                                    branching_nodes_to_process.extendleft([(child, child_node[1], False, 0)])
                            tables_gene_list[site.id].edges.add_row(
                                    left = 0, right = gene_length, parent = tree.parent(last_branching_node[0]), child = child_node[0]
                            )
                            gene_number_hgt_events_passed[site.id] += child_node[3]
                        else:
                            for child in reversed(children):
                                selected_branch_nodes_to_process.extendleft([(child, child_node[1], False, child_node[3])])
                        if hgt_parent_children[child_node[0]]:
                            selected_branch_nodes_to_process.extendleft([(hgt_parent_children[child_node[0]][0], True, True, child_node[3] + 1)])


        # Stable sort: records reached via HGT come first, non-HGT records last.
        child_mutations.sort(key=lambda mut: not mut.is_hgt) # Will set is_hgt to False later if there are paths without hgt events to the leaf.

        # We have to address multiple paths to the same destination, some with HGT and others without it:
        # keep one record per leaf; is_hgt stays True only if every path to
        # that leaf passed an HGT edge.
        unique_mutations = {}

        for mut in child_mutations:
            if mut.node not in unique_mutations:
                unique_mutations[mut.node] = mut
            else:
                existing_mut = unique_mutations[mut.node]
                if not existing_mut.is_hgt or not mut.is_hgt:
                    unique_mutations[mut.node] = mut._replace(is_hgt=False)

        child_mutations_filtered = list(unique_mutations.values())

        # Copy every non-sentinel mutation of the site into the new table,
        # tagging its state in the metadata (3 = "absent", 7 = "present").
        for mutation in mutations:
            if mutation.time > 0.00000000001:
                if mutation.derived_state == "absent":
                    metadata_value = bytes([3]) 
                elif mutation.derived_state == "present":
                    metadata_value = bytes([7])
                tables.mutations.add_row(
                    site=site.id,
                    node=mutation.node,
                    derived_state=mutation.derived_state,
                    parent=-1,
                    metadata=metadata_value,
                    time=mutation.time,
                )

        # Add a "present" mutation at the sentinel time for every leaf that
        # carries the gene; the metadata byte is 1 if it was reached only via
        # HGT and 0 otherwise.
        for mutation in child_mutations_filtered:

            tables.mutations.add_row(
                site=site.id,
                node=mutation.node,
                derived_state="present",
                parent=-1,
                metadata=bytes([mutation.is_hgt]),
                time=0.00000000001,
            )
    

# Tree sequence carrying the rebuilt presence/absence mutations.
mts = tables.tree_sequence()

# Simulate the tree for each gene:

# Two-allele nucleotide model: the root is always "C"; a mutation on "C"
# always yields "M", a mutation on "M" reverts to "C" with probability 1/3.
nucleotide_mutation = msprime.MatrixMutationModel(
    alleles=["C", "M"],
    root_distribution=[1.0, 0.0],  # only C at the root
    transition_matrix=[
        # to C   to M
        [0.0, 1],
        [1 / 3, 2 / 3],
    ],
)

# Turn the collected per-gene edge tables into tree sequences and drop
# nucleotide mutations onto them.
for i in range(num_genes):
    tables_gene_list[i].sort()
    gene_trees_list.append(
        msprime.sim_mutations(
            tables_gene_list[i].tree_sequence(),
            rate=nucleotide_mutation_rate,
            model=nucleotide_mutation,
            keep=True,
            random_seed=seed - 1,
        )
    )

# Calculate the different alleles:
# one (variants x samples) genotype array per gene.
alleles_list = [
    np.array([variant.genotypes for variant in gene_ts.variants()])
    for gene_ts in gene_trees_list
]

### Compute the gene presence and absence:

gene_absence_presence_matrix = []
for var in mts.variants():
    gene_absence_presence_matrix.append(var.genotypes)
gene_absence_presence_matrix = np.array(gene_absence_presence_matrix)

print("Number of present genes: ", sum(gene_absence_presence_matrix[0]))
print("gene_number_hgt_events_passed: ", gene_number_hgt_events_passed)

Clonal node: 6147
Present_mutation at node:  9282
Number of present genes:  843
gene_number_hgt_events_passed:  [914]


In [95]:
# Second way of counting HGT events: climb from every leaf that carries the
# gene towards the clonal root and count the nodes flagged in
# hgt_parent_children_passed on the way.
# NOTE(review): hgt_parent_children_passed is whatever the previous cell left
# behind for its last site -- confirm this is intended when num_genes > 1.
clonal_root_node = clonal_node
gene_number_hgt_events_passed = [0] * num_genes

start_time = time.time()

for tree in mts.trees():
    for site in tree.sites():

        # Both maps start out with every node id mapped to False.
        clonal_nodes = defaultdict(list, dict.fromkeys(range(mts.num_nodes), False))
        reached_nodes_from_leaves = defaultdict(
            list, dict.fromkeys(range(mts.num_nodes), False)
        )

        # Flag every node below (and including) the clonal root.
        stack = [clonal_root_node]
        while stack:
            node = stack.pop()
            clonal_nodes[node] = True
            children = tree.children(node)
            stack.extend(children)

        # Start the upward walk at all leaves that carry the gene.
        stack = []
        for node_id in range(mts.num_samples):
            if gene_absence_presence_matrix[site.id][node_id]:
                #reached_nodes_from_leaves[node_id] = True
                stack.append(node_id)

        while stack:
            node = stack.pop()
            #print(node)
            parent = tree.parent(node)
            # Nothing to do when the parent was handled already or the node
            # id is not below the clonal root id.
            if reached_nodes_from_leaves[parent] or node >= clonal_root_node:
                continue
            if hgt_parent_children_passed[node]:
                # The node was entered via HGT: count the event and stop here.
                #print("HGT event spotted at node: ", node)
                gene_number_hgt_events_passed[site.id] += 1
            else:
                # Ordinary edge: keep climbing.
                stack.append(parent)
            reached_nodes_from_leaves[node] = True


end_time = time.time()
print(end_time-start_time)

print("Number of genes in leaves", sum(gene_absence_presence_matrix[0]))
print("Number of HGT:" , gene_number_hgt_events_passed[0])

#mts.draw_svg()

0.019999027252197266
Number of genes in leaves 843
Number of HGT: 78


In [65]:
hgt_parent_children_passed

[False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False]

In [60]:
reached_nodes_from_leaves

defaultdict(list,
            {0: True,
             1: True,
             2: True,
             3: True,
             4: True,
             5: True,
             6: False,
             7: False,
             8: True,
             9: True,
             10: True,
             11: True,
             12: False,
             13: False,
             14: False,
             15: False,
             16: False,
             17: False,
             18: False,
             19: False,
             20: False,
             21: False,
             22: False,
             23: False,
             24: False,
             25: True,
             26: False,
             27: False,
             28: False,
             29: False,
             30: False,
             31: True,
             32: False})