In [1]:
import msprime, stdpopsim, tskit, gzip, attr, demes, demesdraw

from demes import convert

from matplotlib import collections  as mc
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import time as t
import multiprocessing as mp
from intervaltree import Interval, IntervalTree
from itertools import chain
from collections import defaultdict

ModuleNotFoundError: No module named 'msprime'

In [None]:
import msprime, stdpopsim, tskit, gzip

from matplotlib import collections  as mc
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import time as t
import multiprocessing as mp
from intervaltree import Interval, IntervalTree
from itertools import chain
from collections import defaultdict

In [None]:
# Load the demographic model from the YAML file one directory above the notebook.
# NOTE(review): `demes` is only imported by the first import cell; the second
# import cell omits `demes`, `demesdraw` and `convert`.
multi_graph = demes.load("../multi-way-demes.yaml")

In [None]:
# Draw the demographic model with demesdraw on a log10-scaled y axis.
fig_w, fig_h = plt.figaspect(12 / 16.0)  # width/height for a 12:16 height-to-width ratio
_, ax = plt.subplots(figsize=(fig_w, fig_h), dpi=300)
ax.set_yscale("log", base=10)
# A fixed seed is passed to the drawing call (presumably to keep the layout
# of migration lines stable between runs — confirm in the demesdraw docs).
multi = demesdraw.tubes(multi_graph, 
                        num_lines_per_migration=3,
                        ax=ax,
                        inf_ratio=0.6,
                        seed=1111)

In [None]:
# Convert the demes graph into msprime and stdpopsim representations.
# NOTE(review): `ms_model` is indexed as [0], [1], [2] in the next cell —
# presumably (population_configurations, demographic_events, migration_matrix);
# confirm against the installed demes version.
ms_model = convert.to_msprime(multi_graph)
popsim_model = convert.to_stdpopsim(multi_graph)

In [None]:
# Build an msprime Demography object from the three old-style components
# held in `ms_model`.
ms_demography = msprime.Demography.from_old_style(population_configurations=ms_model[0], 
                                      demographic_events=ms_model[1],
                                      migration_matrix=ms_model[2])

In [None]:
# Random seed shared by the ancestry and mutation simulations below.
seed = 1111

# Human chromosome 22 parameters taken from stdpopsim.
species = stdpopsim.get_species("HomSap")
contig = species.get_contig("chr22")
nbp = contig.recombination_map.get_length()  # contig length; not used in the visible cells
mu = contig.mutation_rate  # per-site mutation rate passed to sim_mutations
# HapMapII (GRCh37) genetic map for chr22, used as the recombination rate map.
hapmap = species.get_genetic_map("HapMapII_GRCh37").get_chromosome_map("chr22").map

In [None]:
# Simulate ancestry for seven sampled populations under `ms_demography`,
# with the HapMap recombination map and gene conversion.
# NOTE(review): the sample counts are presumably numbers of diploid
# individuals (msprime's default ploidy) — confirm.
ts = msprime.sim_ancestry(samples={"Nama": 5000,
                                   "SAC": 5000,
                                   "GBR": 1000,
                                   "MSL": 1000,
                                   "EP": 1000,
                                   "EAS": 1000,
                                   "SAS": 1000},
                          demography=ms_demography, 
                          recombination_rate=hapmap,
                          gene_conversion_rate=2e-7,
                          gene_conversion_tract_length=125,
                          random_seed=seed,
                          # keep migration records; later cells iterate over .migrations()
                          record_migrations=True
)

In [None]:
# Save the ancestry-only tree sequence to disk.
# NOTE(review): absolute, machine-specific path — will fail on another machine.
ts.dump("/home/gerald/Documents/PhD/papers/paper4/multi_5000.tree")

In [None]:
# Add mutations to the simulated ancestry at rate `mu`, reusing the same seed.
ts_mu = msprime.sim_mutations(ts, 
                               rate=mu, 
                               random_seed=seed)

In [None]:
# Save the tree sequence with mutations to disk.
# NOTE(review): absolute, machine-specific path — will fail on another machine.
ts_mu.dump("/home/gerald/Documents/PhD/papers/paper4/multi_mu_5000.tree")

In [None]:
# Collect every node at time 0 together with a label for its population.
# Population ids without an entry in the lookup table get no label, so
# `pop_nodes` can end up shorter than `sample_nodes` in that case.
sample_nodes = []
pop_nodes = []
pop_label_by_id = {
    4: "Nama",
    6: "MSL",
    7: "GBR",
    8: "EP",
    10: "EAS",
    11: "SAS",
    12: "SAC",
}

for node in ts_mu.nodes():
    if node.time == 0:
        sample_nodes.append(node.id)
        label = pop_label_by_id.get(node.population)
        if label is not None:
            pop_nodes.append(label)

In [None]:
# Name each diploid individual after its two consecutive sample nodes
# (population label + node id for each haplotype), then export a gzipped VCF.
n_dip_indv = len(sample_nodes) // 2
indv_names = [
    f"{pop_nodes[2*i]}{sample_nodes[2*i]}_{pop_nodes[2*i+1]}{sample_nodes[2*i+1]}"
    for i in range(n_dip_indv)
]
with gzip.open("/home/gerald/Documents/PhD/papers/paper4/Multi_sim_5000.vcf.gz", "wt") as vcf_file:
    ts_mu.write_vcf(
        vcf_file,
        position_transform="legacy",
        contig_id=22,
        individual_names=indv_names,
    )

In [None]:
def merge_intervals(intervals, atol=1e-8):
    """Combine adjacent (start, stop) ancestry intervals into contiguous tracts.

    Parameters
    ----------
    intervals : iterable of (start, stop)
        Population-specific intervals, in genomic order.
    atol : float, optional
        Absolute tolerance used to decide that an interval starts where the
        previous one stopped. The comparison is purely absolute: a relative
        tolerance (as in ``np.allclose``'s default ``rtol=1e-5``) would treat
        intervals hundreds of base pairs apart as touching at genomic
        coordinates in the tens of millions.

    Returns
    -------
    numpy.ndarray
        Array of shape (n, 2) with one ``[start, stop]`` row per merged
        tract; shape (0, 2) when ``intervals`` is empty.
    """
    merged_intervals = []
    start = None
    stop = None
    for newstart, newstop in intervals:
        if start is None:  # first interval opens the first tract
            start, stop = newstart, newstop
        elif abs(newstart - stop) <= atol:  # touches the open tract: extend it
            stop = newstop
        else:  # gap: close the open tract and start a new one
            merged_intervals.append((start, stop))
            start, stop = newstart, newstop
    if start is None:
        # No intervals at all: return an empty (0, 2) array instead of failing.
        return np.empty((0, 2))
    # Close the terminal tract.
    merged_intervals.append((start, stop))
    return np.array(merged_intervals)

def plot_merged_segments(merged_segments_from_pop):
    """Rough plot of local ancestry tracts, one horizontal level per population.

    Parameters
    ----------
    merged_segments_from_pop : dict
        Maps an integer population id to an array of ``[start, stop]`` rows
        (the per-population output of ``merge_intervals``). Must not be
        empty, because the largest key is used to size the y axis.
    """
    # Only `chain` is imported from itertools at file level, so bring
    # `repeat` into scope here.
    from itertools import repeat

    fig, ax = plt.subplots(figsize=(10,2))
    for anc_pop, ms in merged_segments_from_pop.items():
        # Each segment is a horizontal line from (start, pop) to (stop, pop).
        lines = list(zip(zip(ms[:,0], repeat(anc_pop)), zip(ms[:,1], repeat(anc_pop))))
        lc = mc.LineCollection(lines, linewidths=8)
        ax.add_collection(lc)
    ax.margins(0.1)
    maxpop = max(merged_segments_from_pop.keys())
    plt.ylim(-.2, maxpop+.2)
    plt.yticks(range(maxpop+1), range(maxpop+1))
    plt.ylabel('source population')
    plt.xlabel('bp position')
    
def find_local_ancestry(sample, time, ts, mig_int_tree, tree_idx):
    """Find the ancestral population of one sample node along the genome.

    For every tree listed in ``tree_idx`` the sample's lineage is followed
    towards the root for as long as the next parent is younger than ``time``.
    The population of the node reached is taken from the tree, unless
    ``mig_int_tree`` holds migration records for that node covering the left
    edge of the tree; in that case the ``dest`` of the oldest such record is
    used instead.

    Parameters
    ----------
    sample : int
        Node id of the sample.
    time : float
        Time threshold at which ancestral populations are defined.
    ts : tskit.TreeSequence
        Tree sequence to walk.
    mig_int_tree : dict
        Maps a node id to a list of migration records (objects exposing
        ``left``, ``right``, ``time`` and ``dest``) younger than ``time``.
    tree_idx : iterable of int
        Indexes of the trees to inspect.

    Returns
    -------
    dict
        ``{sample: {population: array of [start, stop] tracts}}``.
    """
    t0 = t.time()
    pop_at_time_of_tree = dict()
    intervals_of_tree = dict()
    tree = ts.first()
    for i in tree_idx:
        tree.seek_index(i)
        # Climb from the sample while the parent is still younger than `time`.
        target = sample
        parent_node = tree.parent(target)
        while parent_node != tskit.NULL and tree.time(parent_node) < time:
            target = parent_node
            parent_node = tree.parent(target)
        # Migration records of that node covering the tree's left coordinate.
        # Intervals are half-open [left, right), so a record that ends exactly
        # where this tree starts does not apply.
        tree_left = tree.interval[0]
        overlapping_migrations = [
            mig for mig in mig_int_tree.get(target, ())
            if mig.left <= tree_left < mig.right
        ]
        if overlapping_migrations:
            # Oldest overlapping migration decides the population.
            last_mig = sorted(overlapping_migrations, key=lambda x: x.time)[-1]
            anc_pop = last_mig.dest
        else:
            # No migration recorded: the node's own population.
            anc_pop = tree.population(target)
        pop_at_time_of_tree[tree.index] = anc_pop
        intervals_of_tree[tree.index] = tree.interval
    # Group the tree intervals by ancestral population, in tree order.
    segments_from_pop = defaultdict(list)
    for ti, anc_pop in pop_at_time_of_tree.items():
        segments_from_pop[anc_pop].append(intervals_of_tree[ti])
    # Merge adjacent intervals from the same population.
    merged_segments_from_pop = dict()
    for anc_pop in segments_from_pop:
        merged_segments_from_pop[anc_pop] = merge_intervals(segments_from_pop[anc_pop])
    t5 = t.time()
    print(sample, "done in ", t5-t0, "seconds.")
    return({sample: merged_segments_from_pop})

In [None]:
# Load a tree sequence with mutations from disk, rebinding `ts_mu`.
# NOTE(review): this file name differs from the one written by the
# `ts_mu.dump(...)` cell above — confirm it is the intended input.
ts_mu = tskit.load("../multi_sim_5000_22_mu.tree")

In [None]:
# Read the second column (label 1) of a headerless tab-separated file.
# The next cell compares integer site positions against these values, so
# they are presumably the positions to keep — confirm.
ex = list(pd.read_csv("../ids/ids_22.txt", sep="\t", header=None)[1])

In [None]:
# Ids of the sites whose integer position is NOT listed in `ex`; these are
# the sites handed to delete_sites() in the next cell.
# Membership is tested against a set: scanning the `ex` list once per site
# is quadratic for no benefit.
ex_positions = set(ex)
ex_ids = [
    site.id
    for site in ts_mu.sites()
    if int(site.position) not in ex_positions
]

In [None]:
# Tree sequence with the sites collected in `ex_ids` removed.
ts_trim = ts_mu.delete_sites(ex_ids)

In [None]:
# Rebuild the list of time-0 nodes and their population labels for the
# tree sequence currently bound to `ts_mu`.
# Population ids without an entry in the lookup table get no label, so
# `pop_nodes` can end up shorter than `sample_nodes` in that case.
sample_nodes = []
pop_nodes = []
pop_label_by_id = {
    4: "Nama",
    6: "MSL",
    7: "GBR",
    8: "EP",
    10: "EAS",
    11: "SAS",
    12: "SAC",
}

for node in ts_mu.nodes():
    if node.time == 0:
        sample_nodes.append(node.id)
        label = pop_label_by_id.get(node.population)
        if label is not None:
            pop_nodes.append(label)

In [None]:
# Number of sites carried by each tree of the trimmed tree sequence.
tree_sites = [tree.num_sites for tree in ts_trim.trees()]

In [None]:
# Indexes of the trees that still carry at least one site after trimming.
tree_idx = [tree.index for tree in ts_trim.trees() if tree.num_sites > 0]

In [None]:
# How many trees carry at least one retained site.
len(tree_idx)

In [None]:
# Time threshold used by the local-ancestry search.
time = 100
# Dictionary of migration records younger than `time`, keyed by the node
# each record refers to.
mig_int_tree = {}
for migration in ts_mu.migrations():
    if migration.time < time:
        mig_int_tree.setdefault(migration.node, []).append(migration)

In [None]:
# Local ancestry of sample node 0, restricted to the trees that carry sites.
# The single-sample find_local_ancestry defined above requires `tree_idx`
# as its fifth argument.
test = find_local_ancestry(0, time, ts_trim, mig_int_tree, tree_idx)

In [None]:
def find_local_ancestry(samples, time, ts, mig_int_tree):
    """Find the ancestral population of several sample nodes along the genome.

    NOTE: this redefines the single-sample ``find_local_ancestry`` declared
    earlier in the notebook; it takes a collection of samples and no
    ``tree_idx`` argument.

    Every tree with at least one site is visited once. For each sample the
    lineage is followed towards the root for as long as the next parent is
    younger than ``time``. The population of the node reached is taken from
    the tree, unless ``mig_int_tree`` holds migration records for that node
    covering the left edge of the tree; in that case the ``dest`` of the
    oldest such record is used instead.

    Parameters
    ----------
    samples : iterable of int
        Node ids of the samples.
    time : float
        Time threshold at which ancestral populations are defined.
    ts : tskit.TreeSequence
        Tree sequence to walk.
    mig_int_tree : dict
        Maps a node id to a list of migration records (objects exposing
        ``left``, ``right``, ``time`` and ``dest``) younger than ``time``.

    Returns
    -------
    dict
        ``{sample: {population: array of [start, stop] tracts}}``.
    """
    t0 = t.time()
    samples = list(samples)  # iterated several times below
    # sample -> {tree index: (ancestral population, tree interval)}
    ancestry_of_tree = {sample: dict() for sample in samples}
    for tree in ts.trees():
        if(tree.num_sites > 0):
            tree_left = tree.interval[0]
            for sample in samples:
                # Climb from the sample while the parent is younger than `time`.
                target = sample
                parent_node = tree.parent(target)
                while parent_node != tskit.NULL and tree.time(parent_node) < time:
                    target = parent_node
                    parent_node = tree.parent(target)
                # Intervals are half-open [left, right), so a record that ends
                # exactly where this tree starts does not apply.
                overlapping_migrations = [
                    mig for mig in mig_int_tree.get(target, ())
                    if mig.left <= tree_left < mig.right
                ]
                if overlapping_migrations:
                    # Oldest overlapping migration decides the population.
                    last_mig = sorted(overlapping_migrations, key=lambda x: x.time)[-1]
                    anc_pop = last_mig.dest
                else:
                    anc_pop = tree.population(target)
                ancestry_of_tree[sample][tree.index] = (anc_pop, tree.interval)
    merged_segments_from_pop = dict()
    for sample in samples:
        # Group this sample's tree intervals by population, in tree order,
        # then merge adjacent intervals from the same population.
        segments_from_pop = defaultdict(list)
        for anc_pop, interval in ancestry_of_tree[sample].values():
            segments_from_pop[anc_pop].append(interval)
        merged_segments_from_pop[sample] = {
            anc_pop: merge_intervals(segments)
            for anc_pop, segments in segments_from_pop.items()
        }
    t5 = t.time()
    print("done in ", t5-t0, "seconds.")
    return(merged_segments_from_pop)

In [None]:
# Scratch check of nested-dict assignment and lookup (displays "A").
a = dict()
a[0] = {1: "A"}
a[0][1]

In [None]:
# Local ancestry for sample nodes 0-4, using the multi-sample
# find_local_ancestry defined in the cell above.
test3 = find_local_ancestry([0,1,2,3,4], time, ts_trim, mig_int_tree)

In [None]:
# Population label stored at position 19999 of `pop_nodes`.
pop_nodes[19999]

In [None]:
# Ancestry tracts of sample node 4. The multi-sample result above is bound
# to `test3`; no cell defines `test2`.
test3[4]

In [None]:
# Scratch version of the ancestor lookup for sample nodes 0 and 1: for each
# tree with at least one site, record the oldest ancestor of the sample
# whose own parent is not younger than `time`.
t0 = t.time()
ancestor_before_timex_of_tree = {sample: dict() for sample in [0,1]}
for tree in ts_trim.trees():
    if not tree.num_sites > 0:
        continue
    for sample in [0,1]:
        target = sample
        parent_node = tree.parent(target)
        # Stop at a root (no parent) or when the parent reaches `time`.
        while parent_node != tskit.NULL and tree.time(parent_node) < time:
            target = parent_node
            parent_node = tree.parent(target)
        ancestor_before_timex_of_tree[sample][tree.index] = target

In [None]:
# Display the {sample: {tree index: ancestor node}} mapping built above.
ancestor_before_timex_of_tree

In [None]:
# Run the local-ancestry search in a process pool, one task per sample, and
# collect the per-sample results into `ADM_la`.
ADM_la = dict()
samples = sample_nodes[0:1]

if __name__ == '__main__':
    with mp.Pool(2) as p:
        # The find_local_ancestry in force at this point takes a collection
        # of samples, so each task gets a one-element list.
        results = [p.apply_async(find_local_ancestry, args=([sample], time, ts_trim, mig_int_tree)) for sample in samples]
        for r in results:
            # Each result is {sample: tracts}; fetch it once and merge it in.
            ADM_la.update(r.get())

In [None]:
# Display the first tree that carries at least one site.
ts_trim.at_index(tree_idx[0])

In [None]:
# Ancestry tracts collected for sample node 0.
ADM_la[0]

In [None]:
# Ancestry tracts collected for sample node 1.
# NOTE(review): the pool cell above sets `samples = sample_nodes[0:1]`, a
# single sample, so key 1 is presumably missing here — confirm.
ADM_la[1]

In [None]:
# Build a table with one row per genomic segment and one column per sample,
# holding the ancestry code of that sample on that segment (0 = none).
# NOTE(review): `ADM_sub` is not defined in any visible cell — presumably a
# subset of `ADM_la`; confirm before running.
pos_raw = sorted(set(list(chain.from_iterable(chain.from_iterable(pd.DataFrame(ADM_sub).iloc[0])))))
pos = np.array(list(map(lambda x: int(x), pos_raw)))

# NOTE(review): "chm" is hard-coded to 19 although the simulation above uses
# chr22 — confirm.
base = pd.DataFrame({"chm":[19]*(len(pos)-1), "spos":pos[:-1], "epos":pos[1:]})

# Ancestral population whose tracts are written into the table. The fill
# value used to be `int(a)`, but `a` is the scratch dict from an earlier cell
# (the `for a in range(anc)` loop is commented out), so it raised a TypeError.
anc_pop = 1

# First row index of every segment start / stop, looked up per tract instead
# of rebuilding and scanning a list for each one.
start_index = dict()
for row, value in enumerate(pos[:-1].tolist()):
    start_index.setdefault(value, row)
stop_index = dict()
for row, value in enumerate(pos[1:].tolist()):
    stop_index.setdefault(value, row)

inds = dict()
# NOTE(review): hard-coded sample range; every id must be a key of `ADM_la`.
for i in range(15192,15384):
    ind = ADM_la[i]  # fails with the missing sample id rather than on None
    info = list([0] * len(pos[1:]))
    tracts = ind.get(anc_pop)
    if tracts is not None:
        for p in tracts:
            start = start_index[int(p[0])]
            stop = stop_index[int(p[1])] + 1
            info[start:stop] = [anc_pop] * len(range(start, stop))
    inds[i] = np.array(info)

inds = pd.DataFrame(inds)
sp = pd.concat([base, inds], axis=1)