In [91]:
import numpy as np

def birth_death_clause(Nt, lamb_t, mniu, di, df, birth=True):
    p = lamb_t if birth else mniu
    return np.exp(-Nt*(lamb_t + mniu)*(df-di))*p

def end_exponential_clause(Nt, lamb_t, mniu, di, df):
    return np.exp(-Nt*(lamb_t + mniu)*(df-di))

def get_lamb_t(Nt, lamb_0=0.8, Bn=-0.075):
    return lamb_0 - Bn*Nt

In [92]:
get_lamb_t(1)

0.875

In [93]:
get_lamb_t(2)

0.9500000000000001

# 1.

In [94]:
events = [(5,True), (10,True), (20,None)]

tree_prob = 1
number_species = 1
mniu = 0.1

for i in range(len(events)):
    if i==0:
        di,df,event = 0,events[0][0],events[0][1]
    else:
        di,df,event = events[i-1][0], events[i][0], events[i][1]

    if event is None:
        tree_prob *= end_exponential_clause(number_species, get_lamb_t(number_species), mniu, di,df)
    else:
        tree_prob *= birth_death_clause(number_species, get_lamb_t(number_species), mniu, di,df, event)
        number_species += 1 if event else -1

print(tree_prob)

3.8460527355226085e-22


# 2.

In [95]:
from  scipy.integrate import dblquad, quad
from functools import partial

integrand1 = lambda x,y: birth_death_clause(1, get_lamb_t(1), mniu, 0, x, birth=True) * birth_death_clause(2, get_lamb_t(2), mniu, x, y, birth=False) * birth_death_clause(1, get_lamb_t(1), mniu, y, 5, birth=True)
integrand21 = lambda x: birth_death_clause(1, get_lamb_t(1), mniu, 0, x, birth=True) * birth_death_clause(2, get_lamb_t(2), mniu, x, 5, birth=True)
integrand22 = lambda y: birth_death_clause(3, get_lamb_t(3), mniu, 5, y, birth=False) * birth_death_clause(2, get_lamb_t(2), mniu, y, 10, birth=True)
integrand31 = integrand21
integrand32 = lambda y: birth_death_clause(4, get_lamb_t(4), mniu, 10, y, birth=False) * end_exponential_clause(3, get_lamb_t(3), mniu, y, 20)
integrand4 = lambda x,y: birth_death_clause(2, get_lamb_t(2), mniu, 5, x, birth=True) * birth_death_clause(3, get_lamb_t(3), mniu, x, y, birth=False) * birth_death_clause(2, get_lamb_t(2), mniu, y, 10, birth=True)
integrand51 = lambda x: birth_death_clause(2, get_lamb_t(2), mniu, 5, x, birth=True) * birth_death_clause(3, get_lamb_t(3), mniu, x, 10, birth=True)
integrand52 = lambda y: birth_death_clause(4, get_lamb_t(4), mniu, 10, y, birth=False) * end_exponential_clause(3, get_lamb_t(3), mniu, y, 20)
integrand6 = lambda x,y: birth_death_clause(3, get_lamb_t(3), mniu, 10, x, birth=True) * birth_death_clause(4, get_lamb_t(4), mniu, x, y, birth=False) * end_exponential_clause(3, get_lamb_t(3), mniu, y, 20)

In [96]:
result1 = dblquad(integrand1, 0, 5, lambda x: x, lambda x: 5)[0] * birth_death_clause(2, get_lamb_t(2), mniu, 5, 10, birth=True) * end_exponential_clause(3, get_lamb_t(3), mniu, 10, 20)
result2 = quad(integrand21, 0, 5)[0] * quad(integrand22, 5, 10)[0] * end_exponential_clause(3, get_lamb_t(3), mniu, 10, 20)
result3 = quad(integrand31, 0, 5)[0] * birth_death_clause(3, get_lamb_t(3), mniu, 5, 10, birth=True) * quad(integrand32, 10, 20)[0]
result4 = birth_death_clause(1, get_lamb_t(1), mniu, 0, 5, birth=True) * dblquad(integrand4, 5, 10, lambda x: x, lambda x: 10)[0] * end_exponential_clause(3, get_lamb_t(3), mniu, 10, 20)
result5 = birth_death_clause(1, get_lamb_t(1), mniu, 0, 5, birth=True) * quad(integrand51, 5, 10)[0] * quad(integrand52, 10, 20)[0]
result6 = birth_death_clause(1, get_lamb_t(1), mniu, 0, 5, birth=True) * birth_death_clause(2, get_lamb_t(2), mniu, 5, 10, birth=True) * dblquad(integrand6, 10, 20, lambda x: x, lambda x: 20)[0]

In [97]:
result = result1 + result2 + result3 + result4 + result5 + result6
print(result)

2.999818763407265e-17


In [99]:
case1 = lambda x,y: integrand1(x,y) * birth_death_clause(2, get_lamb_t(2), mniu, 5, 10, birth=True) * end_exponential_clause(3, get_lamb_t(3), mniu, 10, 20)
case2 = lambda x,y: integrand21(x) * integrand22(y) * end_exponential_clause(3, get_lamb_t(3), mniu, 10, 20)
case3 = lambda x,y: integrand31(x) * birth_death_clause(3, get_lamb_t(3), mniu, 5, 10, birth=True) * integrand32(y)
case4 = lambda x,y: birth_death_clause(1, get_lamb_t(1), mniu, 0, 5, birth=True) * integrand4(x,y) * end_exponential_clause(3, get_lamb_t(3), mniu, 10, 20)
case5 = lambda x,y: birth_death_clause(1, get_lamb_t(1), mniu, 0, 5, birth=True) * integrand51(x) * integrand52(y)
case6 = lambda x,y: birth_death_clause(1, get_lamb_t(1), mniu, 0, 5, birth=True) * birth_death_clause(2, get_lamb_t(2), mniu, 5, 10, birth=True) * integrand6(x,y)

def get_case_prob_Nt(x,y):
    if x < 5 and y < 5:
        return case1(x,y), 1
    elif x < 5 and y < 10:
        return case2(x,y), 1
    elif x < 5 and y < 20:
        return case3(x,y), 1
    elif x < 10 and y < 10:
        return case4(x,y), 2
    elif x < 10 and y < 20:
        return case5(x,y), 2
    else:
        return case6(x,y), 3


# 3.

In [126]:
probabilities = []
num_trials = 100000

T = 20
for _ in range(num_trials):
    start = np.random.uniform(0,T)
    finnish = np.random.uniform(start,T)
    case_prob, Nt = get_case_prob_Nt(start, finnish)
    normalization = (1/T) * (1/(T-start)) * (1/Nt)
    probabilities.append(case_prob / normalization)
    
mean = np.mean(probabilities)
std = np.std(probabilities)
print(mean, std)

1.1956572077393737e-21 2.616847154746734e-21


In [127]:
probabilities = []
num_trials = 100000

T = 20
for _ in range(num_trials):
    l = [np.random.uniform(0,T) for _ in range(2)]
    start, finnish = min(l), max(l)
    case_prob, Nt = get_case_prob_Nt(start, finnish)
    normalization = (1/T) * (1/T) * (1/Nt)
    probabilities.append(case_prob / normalization)

mean = np.mean(probabilities)
std = np.std(probabilities)
print(mean, std)

2.3985796881650967e-21 6.4912965975179474e-21
