# This file provides code to calculate the correction probabilities via the functions alpha(tree, N) and beta(tree, N), where `tree` is a tree in tskit format and N is the effective population size.

# In [32]:
import math
import random
import pandas as pd
import numpy as np
import scipy as sp
import matplotlib as mpl
import matplotlib.pyplot as plt
import tskit
import msprime
from scipy import stats
import statistics

# In [33]:
def find_no_lineages(time, T):
    """Return ``len(T) + 1`` minus the number of entries of ``T`` that are ``<= time``.

    Callers use the result as an index into the timeline ``T`` (see
    ``sub_prob`` / ``branch_prob``).
    """
    covered = sum(1 for point in T if time >= point)
    return len(T) + 1 - covered

def create_T(tree, N):
    """Build the scaled, descending timeline ``[-1, t_max, ..., t_min, 0]``.

    Every node of ``tree`` whose time is non-zero contributes
    ``tree.time(node) / (2*N)`` (duplicate times are kept). Those values are
    listed in decreasing order, ``0`` is appended as the last entry and
    ``-1`` is prepended as a leading placeholder.
    """
    scaled = sorted(
        (tree.time(node)/(2*N) for node in tree.nodes() if tree.time(node) != 0),
        reverse=True,
    )
    return [-1] + scaled + [0]

def create_t(T):
    """Return ``[-1, -1]`` followed by the gaps ``T[i-1] - T[i]`` for ``i >= 2``.

    The two leading ``-1`` placeholders keep ``t[i]`` aligned with the
    interval that ends at ``T[i]``.
    """
    gaps = [newer - older for newer, older in zip(T[1:], T[2:])]
    return [-1, -1] + gaps

def calculate_A_ij(i, j):
    """Return the coefficient A[i, j] computed from the global timelines ``T`` and ``t``.

    Diagonal (``i == j``): ``-1/i * exp(-i*T[i-1])``.
    Otherwise: ``exp(-i*T[i-1] - sum_{k=j+1}^{i-1} k*t[k]) * 1/j * (1 - exp(-j*t[j]))``.

    ``T`` and ``t`` are module globals assigned by ``full_prob`` / ``alpha``.
    """
    if i != j:
        exponent = -i*T[i-1]
        for k in range(j + 1, i):
            exponent -= k*t[k]
        return math.exp(exponent)*1/j*(1-math.exp(-j*t[j]))
    return -1/i*math.exp(-i*T[i-1])
        
def create_A(tree):
    """Return a ``len(T) x len(T)`` matrix of ``calculate_A_ij`` coefficients.

    Only entries with ``2 <= j <= i < len(T)`` are filled; the rest stay 0.
    Reads the global timeline ``T``; the ``tree`` argument is accepted but
    not used.
    """
    size = len(T)
    A = np.zeros((size, size))
    for row in range(2, size):
        for col in range(2, row + 1):
            A[row, col] = calculate_A_ij(row, col)
    return A

def create_G(tree, N):
    """Map every non-root node to ``[own time, parent time]``, both divided by ``2*N``."""
    root = tree.root
    scale = 2*N
    return {
        node: [tree.time(node)/scale, tree.time(tree.parent(node))/scale]
        for node in tree.nodes()
        if node != root
    }

# In [34]:
def sub_prob(branch, i):
    """Return the contribution of timeline interval ``i`` to ``branch``.

    Sums ``A[i, j]`` for ``j`` from just below the branch's upper end up to
    ``i``. That sum is scaled by ``exp(i*T[i-1]) - exp(i*T[i])``, ``t[i]`` is
    added, and the result is divided by ``i``. Reads the globals ``G``, ``T``,
    ``t`` and ``A``.
    """
    first = find_no_lineages(G[branch][1], T) + 1
    acc = 0
    for j in range(first, i + 1):
        acc = acc + A[i, j]
    scale = math.exp(i*T[i-1]) - math.exp(i*T[i])
    return (acc*scale + t[i])*1/i

def branch_prob(branch):
    """Sum ``sub_prob(branch, i)`` over the timeline intervals the branch spans.

    The interval range is derived from the branch's upper (``G[branch][1]``)
    and lower (``G[branch][0]``) scaled times via ``find_no_lineages``.
    """
    upper_time, lower_time = G[branch][1], G[branch][0]
    first = find_no_lineages(upper_time, T) + 1
    last = find_no_lineages(lower_time, T)
    total = 0
    for i in range(first, last + 1):
        total = total + sub_prob(branch, i)
    return total

# In [35]:
def full_prob(tree, N):
    """Return the summed branch probability, normalised by total branch length.

    Computes ``sum(branch_prob(x) for every non-root node x)``, divides it by
    ``tree.total_branch_length`` and multiplies by ``2*N``.

    Side effect: rebinds the module globals ``T``, ``t``, ``G`` and ``A``,
    which the helper functions read.
    """
    global T, t, G, A
    T = create_T(tree, N)
    t = create_t(T)
    G = create_G(tree, N)
    A = create_A(tree)
    root = tree.root
    total = 0
    for node in tree.nodes():
        if node == root:
            continue
        total = total + branch_prob(node)
    return total/tree.total_branch_length*2*N

# In [36]:
# Helpers that convert an msprime (tskit) tree into the dictionary format used by the beta calculation below.
def TupleToSet(T):
    """Return a new set holding the elements of the iterable ``T``."""
    return set(T)
    
def SetToElement(S):
    """Return the last element produced by iterating ``S``, or None if ``S`` is empty."""
    items = list(S)
    return items[-1] if items else None
    
def FindRoot(T):
    """Follow parent links from node 0 until the parent is -1; return that last node."""
    node = 0
    parent = T.parent(node)
    while parent != -1:
        node, parent = parent, T.parent(parent)
    return node
    
def TreeNodes(T):
    """Return the set of node ids yielded by ``T.nodes()``."""
    return set(T.nodes())
    
def FindPair(b, T):
    """Return the other child of ``b``'s parent, or None when ``b`` has no sibling.

    If the parent has several other children, the one visited last while
    iterating the set of them is returned.
    """
    others = set(T.children(T.parent(b))) - {b}
    sibling = None
    for sibling in others:
        pass
    return sibling

def getLongBranch(G):
    """Return, in dictionary order, the keys of ``G`` whose entry has ``[4] == True``."""
    # `== True` (not truthiness) is kept deliberately to match the flag test exactly.
    return [node for node, entry in G.items() if entry[4] == True]  # noqa: E712

def BuildTreeDict(T, N):
    """Build the per-node dictionary consumed by the beta calculation.

    Parameters
    ----------
    T : tree object providing ``nodes``, ``time``, ``parent``, ``children``
        and ``branch_length`` (a tskit tree in this notebook).
    N : effective population size; all times are divided by ``2*N``.

    Returns
    -------
    dict mapping node -> 6-element list:
        [0] sibling node from ``FindPair`` (None for the root)
        [1] parent node (None for the root)
        [2] scaled time of the node
        [3] scaled time of the parent (root: its own scaled time + 100)
        [4] True if this node is the longer branch of its sibling pair
            (ties go to whichever of the pair is visited last), else False;
            always False for the root
        [5] ``T.children(node)``
    """
    G = {}
    # The root (and each node's sibling) are fixed for a given tree. Compute
    # them once instead of re-walking from node 0 on every iteration, which
    # made the original loop O(n * depth).
    root = FindRoot(T)
    for element in TreeNodes(T):
        if element == root:
            own_time = T.time(element)/(2*N)
            G[element] = [None, None, own_time, own_time + 100, False, T.children(element)]
            continue
        pair = FindPair(element, T)
        parent = T.parent(element)
        parent_time = T.time(parent)/(2*N)
        # Flag both members of the pair in one step so exactly one is "long".
        # `>=` means the element currently being visited wins ties.
        if T.branch_length(element) >= T.branch_length(pair):
            element_long, pair_long = True, False
        else:
            element_long, pair_long = False, True
        # Insert element first, then its pair, to keep the original dict
        # insertion order (getLongBranch iterates in that order).
        G[element] = [pair, parent, T.time(element)/(2*N), parent_time,
                      element_long, T.children(element)]
        G[pair] = [element, parent, T.time(pair)/(2*N), parent_time,
                   pair_long, T.children(pair)]
    return G
    
def BuildTreeTimeline(T, N):
    """Return the distinct scaled node times in decreasing order, headed by max + 100.

    Each node time is divided by ``2*N``. Duplicates are removed, and the
    largest value plus 100 is placed first as a padding entry.
    """
    distinct = sorted({T.time(u)/(2*N) for u in T.nodes()}, reverse=True)
    return [distinct[0] + 100] + distinct
    
def BuildTimeInterval(T, N):
    """Return the gaps between consecutive entries of ``BuildTreeTimeline(T, N)``.

    Element ``k`` is ``timeline[k] - timeline[k+1]``; because the timeline is
    descending, every gap is non-negative.
    """
    timeline = BuildTreeTimeline(T, N)
    return [older - younger for older, younger in zip(timeline, timeline[1:])]

# In [37]:
def lambda_dict(T):
    """Map each value of ``T`` to its position in ``T`` (the last position wins on repeats)."""
    return {value: index for index, value in enumerate(T)}

def cumulative_timeline(t):
    """Return prefix sums of ``(k+1) * t[k]``, starting with 0.

    The result has ``len(t) + 1`` entries; entry ``m`` equals
    ``sum((k+1) * t[k] for k < m)``, accumulated left to right.
    """
    running = 0
    prefix = [running]
    for weight, gap in enumerate(t, start=1):
        running = running + weight*gap
        prefix.append(running)
    return prefix

def new_cal_A_prob_matrix(n):
    """Variant of cal_A_prob_matrix that multiplies by a factor ``c_T[i-1]``.

    Fills the global ``new_A_prob_sum_matrix`` with row-wise running sums.
    ``c_T[i-1]`` appears where cal_A_prob_matrix evaluates
    ``exp(-i*T_tmp[i-1])`` inline; presumably it is a precomputed cache of that
    factor (confirm).

    NOTE(review): ``c_T`` is not assigned anywhere in this file, and
    ``new_Mem_full_prob`` calls ``cal_A_prob_matrix`` rather than this function.
    Calling it as-is raises NameError unless ``c_T`` is defined elsewhere.
    NOTE(review): the matrix is fixed at 100x100, so ``n`` above 100 raises
    IndexError.
    """
    global new_A_prob_sum_matrix
    new_A_prob_sum_matrix = np.zeros([100, 100])
    for i in range(1, n+1):
        for j in range(1, i):
            # For j == 1, column index j-2 == -1 wraps to the last column,
            # which is still 0 here and so acts as the empty running sum.
            new_A_prob_sum_matrix[i-1, j-1] = new_A_prob_sum_matrix[i-1, j-2] + (1/j)*(1-math.exp(-j*(t_tmp[j-1])))*c_T[i-1]*math.exp(-(c_timeline[i-1]-c_timeline[j]))
        # Diagonal entry: running sum plus the (-1/i) term.
        new_A_prob_sum_matrix[i-1, i-1] = new_A_prob_sum_matrix[i-1, i-2] + (-1/i)*c_T[i-1]

def cal_A_prob_matrix(n):
    """Fill the global ``new_A_prob_sum_matrix`` with row-wise cumulative sums.

    For each ``i`` in ``1..n`` (stored in row ``i-1``), column ``j-1`` with
    ``j < i`` holds the running sum over ``1..j`` of

        (1/j) * (1 - exp(-j*t_tmp[j-1]))
              * exp(-i*T_tmp[i-1] - (c_timeline[i-1] - c_timeline[j]))

    The diagonal entry ``(i-1, i-1)`` adds ``-(1/i) * exp(-i*T_tmp[i-1])`` to
    that running sum.

    Reads the module globals ``T_tmp``, ``t_tmp`` and ``c_timeline`` (assigned
    by ``new_Mem_full_prob``).

    The matrix is sized ``(n+1) x (n+1)``. The previous fixed 100x100 shape
    raised IndexError for ``n > 100``. For ``n == 100`` it let the final
    diagonal entry land in the column that index -1 must read as zero. The
    extra trailing row and column are never written, so index -1 always reads
    0. That covers both the ``j == 1`` case below and lookups of the form
    ``Mem_lambda[...] - 1 == -1`` in ``new_Mem_P_type1/2``.
    """
    global new_A_prob_sum_matrix
    size = n + 1
    new_A_prob_sum_matrix = np.zeros([size, size])
    for i in range(1, n+1):
        for j in range(1, i):
            # j == 1 reads column -1: the never-written zero column.
            new_A_prob_sum_matrix[i-1, j-1] = new_A_prob_sum_matrix[i-1, j-2] + (1/j)*(1-math.exp(-j*(t_tmp[j-1])))*math.exp(-i*T_tmp[i-1]-(c_timeline[i-1]-c_timeline[j]))
        new_A_prob_sum_matrix[i-1, i-1] = new_A_prob_sum_matrix[i-1, i-2] + (-1/i)*math.exp(-i*T_tmp[i-1])
            
def new_Mem_P_type1(i, b):
    """Return the type-1 contribution of timeline interval ``i`` for long branch ``b``.

    Computes ``(1/i) * (t_tmp[i-1] + P_2 * (exp(i*T_tmp[i-1]) - exp(i*T_tmp[i])))``.
    ``P_2`` combines two column differences from row ``i-1`` of the global
    running-sum matrix ``new_A_prob_sum_matrix``.

    In the dictionary built by BuildTreeDict, ``G_tmp[x][0]`` is x's sibling,
    ``[1]`` its parent, ``[2]`` its scaled time and ``[3]`` its parent's scaled
    time. ``Mem_lambda`` maps a timeline value to its index in ``T_tmp``.

    Reads the module globals ``t_tmp``, ``T_tmp``, ``G_tmp``, ``Mem_lambda``
    and ``new_A_prob_sum_matrix``.
    """
    P_1 = t_tmp[i-1]
    # Diagonal running sum minus the running sum up to the column for the
    # upper time of b's parent (the parent's own [3] entry).
    P_2 = new_A_prob_sum_matrix[i-1, i-1] - new_A_prob_sum_matrix[i-1, Mem_lambda[G_tmp[G_tmp[b][1]][3]]-1]
    # Plus the partial sum between the columns for the sibling's own time
    # and the sibling's parent time.
    P_2 = P_2 + new_A_prob_sum_matrix[i-1, Mem_lambda[G_tmp[G_tmp[b][0]][2]]-1] - new_A_prob_sum_matrix[i-1, Mem_lambda[G_tmp[G_tmp[b][0]][3]]-1]
    P = P_1+P_2*(math.exp(i*T_tmp[i-1])-math.exp(i*T_tmp[i]))
    P = (1/i)*P
    return(P)
    
def new_Mem_P_type2(i, b):
    """Return the type-2 contribution of timeline interval ``i`` for long branch ``b``.

    Computes ``(1/i) * (2*t_tmp[i-1] + P_2 * (exp(i*T_tmp[i-1]) - exp(i*T_tmp[i])))``.
    ``P_2`` is twice the partial sum from the column for the upper time of b's
    parent up to the diagonal. From that, the partial sum between the columns
    for the parent's upper time and the parent's own time is subtracted. All
    sums are read from row ``i-1`` of the global ``new_A_prob_sum_matrix``.

    In the dictionary built by BuildTreeDict, ``G_tmp[x][1]`` is x's parent,
    ``[2]`` its scaled time and ``[3]`` its parent's scaled time.
    ``Mem_lambda`` maps a timeline value to its index in ``T_tmp``.
    """
    P_1 = t_tmp[i-1]
    P_2 = 2*(new_A_prob_sum_matrix[i-1, i-1] - new_A_prob_sum_matrix[i-1, Mem_lambda[G_tmp[G_tmp[b][1]][3]]-1])
    P_2 = P_2 - (new_A_prob_sum_matrix[i-1, Mem_lambda[G_tmp[G_tmp[b][1]][2]]-1] - new_A_prob_sum_matrix[i-1, Mem_lambda[G_tmp[G_tmp[b][1]][3]]-1])
    # P_1 is doubled here, in contrast to the type-1 term.
    P = 2*P_1+P_2*(math.exp(i*T_tmp[i-1])-math.exp(i*T_tmp[i]))
    P = (1/i)*P
    return(P)

def new_Mem_full_prob(G, T, t):
    """Return the summed type-1/type-2 probability over long branches, divided by ``L``.

    Parameters
    ----------
    G : dict built by BuildTreeDict (node -> [sibling, parent, time,
        parent time, is-long flag, children]).
    T : descending timeline from BuildTreeTimeline.
    t : interval lengths from BuildTimeInterval (``t[k] = T[k] - T[k+1]``).

    Side effects: assigns the module globals ``G_tmp``, ``T_tmp``, ``t_tmp``,
    ``c_timeline``, ``Mem_lambda`` and (via cal_A_prob_matrix)
    ``new_A_prob_sum_matrix``, which the ``new_Mem_P_type*`` helpers read.
    """
    # Publish the inputs as globals for the helper functions.
    global G_tmp
    G_tmp = G
    global T_tmp
    T_tmp = T
    global t_tmp
    t_tmp = t
    global c_timeline
    c_timeline = cumulative_timeline(t)
    global Mem_lambda
    Mem_lambda = lambda_dict(T)
    # L = sum over k >= 1 of (k+1)*t[k]. Interval 0 (the +100 padding that
    # BuildTreeTimeline puts above the largest time) is excluded. Presumably
    # this is the total scaled branch length when interval k holds k+1
    # lineages -- confirm.
    L = c_timeline[-1] - c_timeline[1]
    n = len(t)
    cal_A_prob_matrix(n)
    long_branch = getLongBranch(G)
    p = 0
    for b in long_branch:
        # Intervals from just below the shared parent time (the sibling's [3])
        # down to the sibling's own time: type-2 terms, counted twice.
        range_l = Mem_lambda[G[G[b][0]][3]]+1
        range_u = Mem_lambda[G[G[b][0]][2]]+1
        for k in range(range_l, range_u):
            p = p + 2*new_Mem_P_type2(k, b)
        # Intervals from just below the sibling's time down to b's own time:
        # type-1 terms.
        range_l = Mem_lambda[G[G[b][0]][2]]+1
        range_u = Mem_lambda[G[b][2]]+1
        for k in range(range_l, range_u):
            p = p + new_Mem_P_type1(k, b)
    p = p/L
    return(p)

def beta(tree, N):
    """Return the correction probability beta for ``tree``.

    Parameters
    ----------
    tree : tskit tree.
    N : effective population size; node times are scaled by ``1/(2*N)``.

    Returns
    -------
    ``1 - new_Mem_full_prob(G, T, t)`` for the dictionary, timeline and
    intervals built from ``tree`` at population size ``N``.

    The previous version ignored ``N`` and always used 10000. Callers that
    passed 10000 get the same result as before.
    """
    G = BuildTreeDict(tree, N)
    T = BuildTreeTimeline(tree, N)
    t = BuildTimeInterval(tree, N)
    return 1 - new_Mem_full_prob(G, T, t)

# In [38]:
def find_no_lineages(time, T):
    """Return ``len(T) + 1`` minus how many entries of ``T`` are ``<= time``.

    Callers use the result as an index into the timeline ``T``.
    """
    count = len(T) + 1
    for point in T:
        if point <= time:
            count -= 1
    return count

def create_T(tree, N):
    """Return ``[-1, t_max, ..., t_min, 0]`` from the non-zero node times scaled by ``1/(2*N)``.

    Duplicate times are kept. ``-1`` is a leading placeholder and ``0`` is the
    final entry.
    """
    timeline = [tree.time(node)/(2*N) for node in tree.nodes() if tree.time(node) != 0]
    timeline.sort(reverse=True)
    timeline.append(0)
    timeline.insert(0, -1)
    return timeline

def create_t(T):
    """Return ``[-1, -1]`` followed by ``T[k-1] - T[k]`` for every ``k >= 2``."""
    gaps = [-1, -1]
    gaps.extend(T[k-1] - T[k] for k in range(2, len(T)))
    return gaps

def calculate_A_ij(i, j):
    """Return A[i, j] from the global timelines ``T`` and ``t``.

    ``i == j`` gives ``-1/i * exp(-i*T[i-1])``. Otherwise the exponent
    ``-i*T[i-1]`` is reduced by ``k*t[k]`` for ``j < k < i``, and the result is
    ``exp(exponent) * 1/j * (1 - exp(-j*t[j]))``.
    """
    head = -i*T[i-1]
    if i == j:
        return -1/i*math.exp(head)
    exponent = head
    for k in range(j+1, i):
        exponent = exponent - k*t[k]
    return math.exp(exponent)*1/j*(1-math.exp(-j*t[j]))
        
def create_A(tree):
    """Return the square matrix of ``calculate_A_ij`` values sized by the global ``T``.

    Row ``i`` receives entries for ``j`` in ``2..i``. Everything else is zero.
    ``tree`` is unused.
    """
    n = len(T) - 1
    matrix = np.zeros((n+1, n+1))
    for i in range(2, n+1):
        matrix[i, 2:i+1] = [calculate_A_ij(i, j) for j in range(2, i+1)]
    return matrix

def create_G(tree, N):
    """Return ``{node: [scaled node time, scaled parent time]}`` for every non-root node."""
    root = tree.root
    branches = {}
    for node in tree.nodes():
        if node == root:
            continue
        lower = tree.time(node)/(2*N)
        upper = tree.time(tree.parent(node))/(2*N)
        branches[node] = [lower, upper]
    return branches

# In [39]:
def sub_prob(branch, i):
    """Return the interval-``i`` term for ``branch`` using the globals ``G``, ``T``, ``t`` and ``A``.

    Computes ``(sum_j A[i, j] * (exp(i*T[i-1]) - exp(i*T[i])) + t[i]) / i``.
    ``j`` runs from just below the branch's upper time up to ``i``.
    """
    start = find_no_lineages(G[branch][1], T) + 1
    weight = 0
    for j in range(start, i + 1):
        weight = weight + A[i, j]
    weight = weight*(math.exp(i*T[i-1]) - math.exp(i*T[i]))
    return (weight + t[i])*1/i

def branch_prob(branch):
    """Return the sum of ``sub_prob(branch, i)`` over every interval ``i`` the branch covers."""
    span = range(find_no_lineages(G[branch][1], T) + 1,
                 find_no_lineages(G[branch][0], T) + 1)
    result = 0
    for i in span:
        result = result + sub_prob(branch, i)
    return result

def alpha(tree, N):
    """Return the correction probability alpha for a tskit ``tree``.

    Computes ``1 - S / tree.total_branch_length * 2*N``, where ``S`` is the sum
    of ``branch_prob`` over all non-root nodes. ``N`` is the effective
    population size used to scale node times.

    Side effect: rebinds the module globals ``T``, ``t``, ``G`` and ``A``,
    which the helper functions read.
    """
    global T, t, G, A
    T = create_T(tree, N)
    t = create_t(T)
    G = create_G(tree, N)
    A = create_A(tree)
    root = tree.root
    branch_total = 0
    for node in tree.nodes():
        if node != root:
            branch_total = branch_total + branch_prob(node)
    scaled = branch_total/tree.total_branch_length*2*N
    return 1 - scaled