In [None]:
import numpy as np
import scipy as sp
import matplotlib.pyplot as plt
from time import time as timer
import Branching_Process as bp
from tqdm import tqdm
import importlib
import seaborn as sns
importlib.reload(bp)  # pick up edits to Branching_Process without restarting the kernel
sns.set_palette('Set2')  # default colour cycle for plots that do not pass an explicit colour
# Explicit colour list used by the plots below (they index up to cmap[7], so it needs >= 8 colours).
cmap = sns.color_palette('Accent')

# Branching Process walkthrough
- The helper functions for evaluating things such as the PGF, mean, variance and other quantities from the branching process are in the package Branching_Process (imported as bp)
- To use functions from the package, prefix them with `bp.` (e.g. `bp.cmj` creates a cmj object, which is a Crump-Mode-Jagers process)
- The main purpose of this notebook is to have a play around with swapping in and out various features of the branching process and seeing what effect it has on various outputs. 
- Example features could be: different generation time distributions, different time-varying infectivity, different dependence on time-since-infection, including importation, other, fancier extensions! 
- Example outputs could be: mean and variance of the process, number of cases over time, prevalence/cumulative incidence, probability of extinction, time to threshold/extinction.  

### Set some parameters for the epidemic - these can be changed as you wish! 


In [None]:
# Time grid: Tmax + 1 evenly spaced points from 0 to Tmax (one per day).
Tmax = 100
nsteps = 1 + Tmax
time = np.linspace(0, Tmax, num=nsteps)
time_step = np.diff(time)[0]  # grid spacing

# Effective reproduction number.
R_eff = 2

# Day on which a lockdown starts, if one is modelled.
t_lockdown = 30

# Gamma generation-interval parameters (shape a, scale b) taken from COVID-19;
# the mean is a * b, roughly 5.5 days.
a, b = 6.6, 0.833

# Define distributions for the lifetime, infectivity profile and population-level transmission rate



### Define the lifetime distribution (the infectious period, here using the gamma generation-interval parameters above) - this is the random period of time for which an infectious individual is able to infect others. You need to give both the cdf and pdf for this.


In [None]:
def lifetime_gamma(t, tau):
    """CDF of the gamma(a, scale=b) lifetime evaluated at t.

    ``tau`` is accepted so the signature matches what the cmj object is given,
    but it is not used. ``a`` and ``b`` are read from the global parameters.
    """
    return sp.stats.gamma.cdf(t, scale = b, a = a)

def lifetime_gamma_pdf(t, tau):
    """PDF of the gamma(a, scale=b) lifetime evaluated at t (``tau`` unused)."""
    return sp.stats.gamma.pdf(t, scale = b, a = a)

def lifetime_gamma_survival_function(t):
    """Probability that the lifetime exceeds t."""
    cdf_value = lifetime_gamma(t, 0)
    return 1 - cdf_value

# Select the lifetime distribution used by the rest of the notebook.
lifetime, lifetime_pdf = lifetime_gamma, lifetime_gamma_pdf

# Integral of the survival function over [0, 100] days (i.e. the mean lifetime,
# truncated at 100 days).
lifetime_survival_integral = sp.integrate.quad(lifetime_gamma_survival_function, 0, 100)[0]


### Define how infectiousness changes over an individual's lifetime. This should be a deterministic function of time-since-infection, tau. You need both the function itself and its time-derivative.

### Also define the population-level time-varying changes in transmission. Start with this being constant. Later in the notebook there is an example where $\rho$ varies over time to model an intervention, such as a lockdown

In [None]:
# Choose some functions for time-varying infectivity.
# Each option comes as a pair: a cumulative infectiousness curve and its time-derivative.

def infectiousness_gamma(t):
    """Cumulative infectiousness: twice the gamma(a, scale=b) CDF."""
    return 2 * sp.stats.gamma.cdf(t, a = a, scale = b)

def infectiousness_gamma_deriv(t):
    """Derivative of ``infectiousness_gamma``: twice the gamma(a, scale=b) pdf."""
    return 2 * sp.stats.gamma.pdf(t, a = a, scale = b)

def infectiousness_new(t):
    """Cumulative infectiousness: gamma CDF with shape 5.5 and scale b."""
    return sp.stats.gamma.cdf(t, scale = b, a = 5.5)

def infectiousness_new_deriv(t):
    """Derivative of ``infectiousness_new``: gamma pdf with shape 5.5 and scale b."""
    return sp.stats.gamma.pdf(t, scale = b, a = 5.5)

# Constant infectiousness rate: cumulative infectiousness grows linearly in t.
def infectiousness_const(t, const = 1):
    """Cumulative infectiousness ``const * t``, shaped like ``t``."""
    return const * t * np.ones_like(t)

def infectiousness_const_deriv(t, const = 1):
    """Derivative of ``infectiousness_const``: the constant ``const``, shaped like ``t``."""
    return const * np.ones_like(t)

# Constant transmission parameter
def rho_const(t, R=None):
    """Constant population-level transmission rate, broadcast to the shape of ``t``.

    Parameters
    ----------
    t : scalar or array
        Time(s) at which to evaluate the rate.
    R : float, optional
        The constant value to return. ``None`` (default) uses the global ``R_eff``,
        which is what callers that pass only ``t`` have always received; an explicit
        ``R`` is now honoured instead of being silently ignored.

    Returns
    -------
    ``R * np.ones_like(t)``
    """
    if R is None:
        R = R_eff
    return R * np.ones_like(t)

# Transmission parameter reduced by a lockdown
def rho_lockdown(t, R1 = R_eff, R2 = 0.8, t_lockdown = 30):
    """Piecewise-constant transmission rate: R1 up to and including day
    ``t_lockdown``, R2 afterwards."""
    pre_lockdown = (t <= t_lockdown)
    post_lockdown = 1 - pre_lockdown
    return R1 * pre_lockdown + R2 * post_lockdown



In [None]:
# Rescale the chosen infectiousness curve so it can be read as the proportion of a person's
# total infectivity delivered over the course of their infection.
# Make sure you change both the infectiousness and its derivative!!

# Choose here which infectiousness curve you want.
infectiousness_not_scaled = infectiousness_const
infectiousness_deriv_not_scaled = infectiousness_const_deriv

def combined_integral(t):
    """Unscaled infectiousness rate at t weighted by the probability of still being infectious."""
    still_infectious = 1 - lifetime(t, 0)
    return infectiousness_deriv_not_scaled(t) * still_infectious

# Normalising constant: integral of the weighted rate over [0, 100], rounded to 5 decimals.
integral = sp.integrate.quad(combined_integral, 0, 100)[0]
integral = np.round(integral, 5)

def infectiousness(t):
    """Scaled cumulative infectiousness."""
    return infectiousness_not_scaled(t) / integral

def infectiousness_deriv(t):
    """Scaled infectiousness rate (derivative of ``infectiousness``)."""
    return infectiousness_deriv_not_scaled(t) / integral

# Choose here which global transmission parameter/curve you want.
rho = rho_const

In [None]:
# Plot the lifetime pdf, the scaled infectiousness rate and the transmission rate.
# The first two depend on time since infection (tau); rho depends on calendar time (t).

fig, axs = plt.subplots(1, 3, figsize = (10, 6))
axs[0].grid(alpha = 0.5)
axs[0].plot(time, lifetime_pdf(time, 0), color = cmap[0])
axs[0].set_xlim([0, 15])
axs[0].set_xlabel('Time Since Infection')
axs[0].set_title('Inf. Period, ' + r'$g(\tau)$')

axs[1].grid(alpha = 0.5)
axs[1].plot(time, infectiousness_deriv(time), color = cmap[1])
axs[1].set_xlim([0, 15])
axs[1].set_xlabel('Time Since Infection')
axs[1].set_title('Infectiousness, ' + r'$k(\tau)$')

axs[2].grid(alpha = 0.5)
axs[2].plot(time, rho(time), color = cmap[2])
# To preview a lockdown instead, plot e.g. rho_lockdown(time, 1.5, 0.5, 7.5) on this panel.
axs[2].set_xlim([0, 15])
axs[2].set_xlabel('Time (Days)')
axs[2].set_title('Transmission Rate, ' + r'$\rho(t)$')

fig.tight_layout()


### Define the offspring distribution for how many cases are produced at each offspring event. 

### Let's stick with either Poisson or Logarithmic (which means that the resulting offspring for an individual will be Negative Binomial distributed - this incorporates superspreading!). 

### The PGFs for both of these are given below. Start with Poisson and later try logarithmic to compare. Varying the value of phi gives different levels of superspreading.

In [None]:
# Define individual offspring distribution

def logarithmic_offspring(z, phi = 0.1):
    return - phi * (np.log(phi + 1 - z) - np.log(phi))

def poisson_offspring(z, lamb = 1):
    return lamb*(z-1)


### Once all of this is defined, you can very easily create a cmj object below!

In [None]:

cmj = bp.cmj(poisson_offspring, time=time)
cmj.set_lifetime_distribution(lifetime)
cmj.set_infectiousness_profile(infectiousness)
cmj.set_transmission_rate(rho)

# Let's not include immigration for now, but define it here so we can come back to it later...
def immigration(tau, eta = 1):
    """Constant immigration rate ``eta``, broadcast to the shape of ``tau``."""
    return eta*np.ones_like(tau)


# Intensity measure of the Poisson process that governs the offspring events of a single
# individual infected at time 0. Integrating it gives the individual case reproduction number.
def intensity_measure(t):
    """rho(t) * k(t) * P(lifetime > t) for an individual infected at time 0.

    Uses ``rho`` (the transmission curve chosen above and passed to ``cmj``) rather than a
    hard-coded curve, so the printed reproduction number matches the process being modelled.
    """
    return rho(t) * infectiousness_deriv(t) * (1 - lifetime(t, 0))


case_reproduction_number = sp.integrate.quad(intensity_measure, 0, 50)[0]
print('Case Reproduction Number is: ' + str(np.round(case_reproduction_number, 5)))

### Calculate the PGF of the process, and then in one go convert it into a PMF for the process. This gives us, at each time point, the distribution of the number of cases at time $t$

- This can take quite a long time to calculate, depending on the parameter max_pgf, so set this to be low (say, 1000-5000) if wanting to run multiple times. 
- This can be sped up using parallelisation - if you do not have multiple cores, set the parameter parallel = False
- On my computer, max_pgf = 10000 takes ~73 seconds if parallel = True and n_cores = 6. If parallel = False, it takes ~310 seconds... 
- Change these parameters according to your computer!!


In [None]:
# Need to set an upper limit for the number of cases - this is where the distribution will be cut off.
max_pgf = 10000

n_cores = 6
# Set parallel = False if the parallel backend is not installed, or if you cannot
# parallelise on your computer.
parallel = True
characteristic = 'prev' # Switch between 'prev' for prevalence or 'ci' for cumulative incidence

# Calculate the pmf (in parallel when `parallel` is True, for faster results)
start = timer()
pmf_test = bp.pmfft(cmj.pgf_vec,
                    max_pgf,
                    immigration = None,
                    count = characteristic,
                    parallel = parallel,
                    n_cores = n_cores)
stop = timer()

if parallel:
    print('Runs in ' + str(stop - start) + ' seconds on ' + str(n_cores) + ' cores')
else:
    print('Runs in ' + str(stop - start) + ' seconds (serial)')

# Mean of the pmf at each time point, trimmed by one entry to line up with `time`.
mn_test = bp.mean_pmf(pmf_test)[:-1]


### Compare approximate mean with the analytic mean 


In [None]:
# Analytic mean of the process, with its last entry dropped (as for mn_test) -
# presumably so its length matches `time`; TODO confirm against bp.cmj.analytic_mean.
quick_mn = cmj.analytic_mean()[:-1]

In [None]:

# Compare the analytic mean with the mean recovered from the truncated pmf.
plt.grid(alpha = 0.5)
plt.plot(time, quick_mn, label = "Analytic mean")
plt.plot(time, mn_test, '--', label = "Mean from PGF")

plt.xlabel("Time (Days)")
plt.ylabel("Prevalence")
plt.xlim([0, Tmax])
# Zoom in on the early-growth phase.
plt.ylim([0, 5000])
plt.title("Mean Prevalence, M = " + str(max_pgf))
plt.legend()


In [None]:
# Probability of extinction by each time point, with the last entry dropped
# (presumably so its length matches `time`; TODO confirm).
extinction = (cmj.extinction_probability(cmj.pgf_vec))[:-1]

# T*: a time at which the extinction probability has (numerically) levelled off,
# i.e. its time-derivative is at most `plateau_tol`.
plateau_tol = 0.0001
plateau_idx = np.flatnonzero(np.gradient(extinction, time_step) <= plateau_tol)
if plateau_idx.size < 2:
    raise ValueError('Extinction probability never levels off on this time grid '
                     '(gradient <= ' + str(plateau_tol) + '); try a larger Tmax.')
# NOTE(review): this takes the *second* qualifying index, not the first; confirm that is intended.
tstar_idx = plateau_idx[1]
tstar = time[tstar_idx]
zstar = quick_mn[tstar_idx]  # mean prevalence at T*; used as the FPT threshold Z1 below

plt.grid(alpha = 0.5)
plt.plot(time, extinction)
plt.axvline(tstar, color = 'tab:purple', linestyle='--', label = "T*")
plt.legend()
plt.xlim([0, Tmax])
plt.ylim([0, 1])
plt.title("Probability of Extinction")
plt.xlabel("Time (Days)")
plt.ylabel("Probability")
# extinction[-1] is the extinction probability by Tmax, used as an approximation of the ultimate value.
print('Ultimate extinction probability is: ' + str(extinction[-1]))

In [None]:
# Distribution of the time at which prevalence first reaches Z1 cases,
# with Z1 taken as the mean prevalence at T*.
Z1 = zstar
FPT_cdf_small = bp.FirstPassageTime(pmf_test, Z1, time, False)  # cumulative, per the variable name
FPT_pdf_small = bp.FirstPassageTime(pmf_test, Z1, time, True)   # density, per the variable name

# Sanity check: the last cumulative value is the probability of hitting Z1 by the end of the grid
# (assuming the False flag returns the CDF).
print('Probability sums to ' + str(np.round(FPT_cdf_small[-1], 4)))

plt.grid(alpha = 0.5)
plt.plot(time, FPT_pdf_small, color = 'tab:red')
plt.title('Time to hit {} cases'.format(int(np.round(Z1))))
plt.xlabel('Time (Days)')
plt.ylabel("Probability")

### Try now to go back to the beginning and change a few of the parameters:
- What if we chose a different lifetime distribution? 
- What if instead of constant infectiousness, we had a different infectiousness over time curve? 
- What if we used a logarithmic offspring distribution instead, which includes overdispersion? And what happens if you change the parameter phi?

### What if we wanted to look at the time to extinction instead?

In [None]:
def time_to_extinction(extinction_prob, dt=1.0):
    """Density of the time to extinction, conditioned on extinction.

    Parameters
    ----------
    extinction_prob : array_like
        Probability of extinction by each point of a uniform time grid. The last
        entry is used as the (approximate) ultimate extinction probability.
    dt : float, optional
        Grid spacing passed to ``np.gradient`` (default 1.0, i.e. unit spacing).

    Returns
    -------
    numpy.ndarray
        Time-derivative of ``extinction_prob / extinction_prob[-1]``.
    """
    # Use the argument that was passed in, not the global `extinction` array.
    extinction_prob = np.asarray(extinction_prob)
    return np.gradient(extinction_prob / extinction_prob[-1], dt)
# Density of the time to extinction, given that the process goes extinct.
plt.plot(time, time_to_extinction(extinction))
plt.grid(alpha = 0.5)
plt.title('Time to Extinction (conditioned on extinction)')
plt.xlabel('Time (Days)')
plt.ylabel('Probability')

In [None]:
# PMF of the process with immigration at the rate given by `immigration` (defined earlier).
max_pgf = 10000
start = timer()

# NOTE(review): `cmj` has no immigration registered via set_immigration here (the scenario
# loop further down does call it); confirm pgf_immigration only needs the `immigration`
# argument passed to pmfft.
pmf_immigration = bp.pmfft(cmj.pgf_immigration,
                           max_pgf,
                           immigration = immigration,
                           count = characteristic,  # 'prev' or 'ci', as chosen earlier
                           parallel = True,
                           n_cores = n_cores)
stop = timer()

print('Runs in ' + str(stop - start) + ' seconds on ' + str(n_cores) + ' cores')

# Mean of the pmf at each time point, trimmed by one entry to line up with `time`.
mn_immigration = bp.mean_pmf(pmf_immigration)[:-1]


In [None]:

# Compare first-passage-time densities for reaching the same threshold Z1
# (mean prevalence at T*) without and with immigration.
Z1 = zstar
FPT_pdf_noim = bp.FirstPassageTimeImmigration(pmf_test, Z1, time, True)
FPT_pdf_im = bp.FirstPassageTimeImmigration(pmf_immigration, Z1, time, True)

plt.grid(alpha = 0.5)
plt.plot(time, FPT_pdf_noim, color = 'tab:red', label = 'No Immigration')
plt.plot(time, FPT_pdf_im, color = 'tab:orange', label = 'Immigration')

plt.title('Time to hit {} cases'.format(int(np.round(Z1))))
plt.xlabel('Time (Days)')
plt.ylabel("Probability")
plt.legend()

### Try running a few different scenarios - this may take some time if not running in parallel, so if you cannot do this, maybe either reduce the number of scenarios, or reduce max_pgf

In [None]:
# Baseline immigration rates shared by the scenarios below.
eta, eta2 = 0.5, 1

# Scenario 1: constant immigration at rate eta.
def scenario_1_im(t):
    return np.ones_like(t) * eta

# Scenario 2: constant immigration at the higher rate eta2.
def scenario_2_im(t):
    return np.ones_like(t) * eta2

# Scenario 3: immigration growing exponentially (rate 0.02 per day) from eta.
def scenario_3_im(t):
    growth = np.exp(0.02 * t)
    return eta * growth

# Scenario 4: immigration decaying exponentially (rate 0.02 per day) from eta.
def scenario_4_im(t):
    decay = np.exp(-0.02 * t)
    return eta * decay


# Scenario 5: total shutdown - immigration at rate eta stops from day t_stop onward.
def scenario_5_im(t, t_stop = 15):
    still_open = (t < t_stop)
    return eta * still_open


# Scenario 6: two exogenous sources - one growing epidemic with tighter controls
# plus one decreasing epidemic with looser controls.
def scenario_6_im(t, eta_country1 = 0.05, eta_country2 = -0.02, control_country1=0.01, control_country2 = 0.2):
    growing_source = control_country1 * np.exp(eta_country1 * t)
    shrinking_source = control_country2 * np.exp(eta_country2 * t)
    return growing_source + shrinking_source


In [None]:
immigration_scenarios = [scenario_1_im, scenario_2_im, scenario_3_im, scenario_4_im, scenario_5_im, scenario_6_im]
immigration_scenario_labels = ["scenario_1", "scenario_2", "scenario_3", "scenario_4", "scenario_5", "scenario_6"]


max_pgf = 20000

n_cores = 6
characteristic = 'prev'
scenarios = []
pmf_scenarios = []

# The loop variable is deliberately not called `immigration`, so the `immigration`
# function defined earlier is not overwritten by the last scenario.
for i, immigration_fn in enumerate(immigration_scenarios):
    cmj_immigration = bp.cmj(poisson_offspring, time=time)
    cmj_immigration.set_lifetime_distribution(lifetime)
    cmj_immigration.set_infectiousness_profile(infectiousness)
    cmj_immigration.set_transmission_rate(rho_const)
    cmj_immigration.set_immigration(immigration_fn)
    scenarios += [cmj_immigration]
    start = timer()
    pmf_immigration = bp.pmfft(cmj_immigration.pgf_immigration,
                               max_pgf,
                               immigration = immigration_fn,
                               count = characteristic,
                               parallel = True,
                               n_cores = n_cores)
    stop = timer()

    print('Scenario ' + str(i) + ' finished in ' + str(np.round(stop - start, 2)) + ' seconds!')

    pmf_scenarios += [pmf_immigration]

In [None]:
fpt_scenarios = []
n_scenarios = len(pmf_scenarios)
Zstars = np.zeros(n_scenarios)
tstars = np.zeros(n_scenarios)
Zstar = 100  # case threshold for the first-passage-time calculation (same for every scenario)

for i, pmf in enumerate(pmf_scenarios):
    # Mean prevalence over time, restricted to the entries that line up with `time`.
    mean = bp.mean_pmf(pmf)[:len(time)]
    # Index at which the mean reaches Zstar, or -1 if it never does on this grid.
    # NOTE(review): this takes the *second* index with mean >= Zstar, not the first; confirm.
    above = np.flatnonzero(mean >= Zstar)
    tstar_idx = above[1] if above.size > 1 else -1
    Zstars[i] = Zstar
    tstars[i] = time[tstar_idx]
    FPT_cdf = bp.FirstPassageTime(pmf, Zstar, time, False)
    fpt_scenarios += [np.gradient(FPT_cdf, time_step)]

In [None]:
# One panel per immigration scenario showing the first-passage-time density.
n_scenarios = len(pmf_scenarios)
# squeeze=False keeps `axs` indexable even when there is only one scenario.
fig, axs = plt.subplots(nrows = n_scenarios, ncols = 1, squeeze = False)
axs = axs[:, 0]
immigration_legend_labels = ['Scenario ' + str(int(i+1)) for i in range(n_scenarios)]
# The first `n_skip` time points are left out of each panel and of the mean line.
# NOTE(review): the reason for skipping them is not documented - presumably early-time noise; confirm.
n_skip = 5

for i in range(n_scenarios):
    t_shown = time[n_skip:]
    density = fpt_scenarios[i][n_skip:]
    axs[i].label_outer()

    axs[i].plot(t_shown, density, color = cmap[0], label = immigration_legend_labels[i])
    axs[i].fill_between(t_shown, density, color = cmap[0], alpha = 0.8)
    # Approximate mean first-passage time (Riemann sum of t * density).
    axs[i].axvline(np.sum(t_shown * density) * time_step, color = cmap[1], linestyle = '--')
    axs[i].grid(alpha = 0.2)
    axs[i].set_xlim([0, 100])
    axs[i].set_ylim([0, 0.2])
    axs[i].legend(handlelength=0, handletextpad=0, fancybox=True)

plt.xlabel('Time (Days)')
fig.supylabel('Density')
plt.suptitle('Time to hit ' + str(Zstar) + ' cases')
plt.tight_layout()


In [None]:
# Extinction probability over time for a range of post-lockdown transmission values.
extinction_list = []
R_lockdowns = np.array((1.2, 1.4, 1.6, 1.8, 2.)) / 2  # rho values after the lockdown
for j, Rl in enumerate(R_lockdowns):
    cmj = bp.cmj(poisson_offspring, time=time)
    cmj.set_lifetime_distribution(lifetime)
    cmj.set_infectiousness_profile(infectiousness)

    # Rl is bound as a default argument so each closure keeps this iteration's value.
    def R_lockdown(t, R=R_eff, Rl = Rl, t_lockdown = 30):
        before_lockdown = (t <= t_lockdown)
        return R * before_lockdown + Rl * (1 - before_lockdown)

    cmj.set_transmission_rate(R_lockdown)
    extinction = cmj.extinction_probability(cmj.pgf_vec)[:-1]
    extinction_list.append(extinction)

### We can also model things such as the impact of a lockdown, by changing the global transmission parameter $\rho$ 
- Have a go at doing this: try different parameters and different days for implementing the lockdown.
- You could also try changing the function rho_lockdown to model different changes in $\rho$, such as seasonal effects


In [None]:
max_pgf = 10000
# NOTE: this redefines the global `time` grid (200 integer days); earlier cells that use
# `time` will see this grid if they are re-run after this one.
time = np.arange(200)
n_cores = 6
characteristic = 'prev'
R = 1.8
R_l = 0.7
t_lockdowns = np.array((40, 50, 60, 70))
lockdown_scenarios = []
extinction_probs = []

for j, tl in enumerate(t_lockdowns):
    cmj = bp.cmj(poisson_offspring, time=time)
    cmj.set_lifetime_distribution(lifetime)
    cmj.set_infectiousness_profile(infectiousness)

    # Named rho_lockdown_tl so the earlier global rho_lockdown function is not shadowed.
    # Values are bound as defaults so each scenario keeps its own lockdown day.
    def rho_lockdown_tl(t, R=R, Rl = R_l, t_lockdown = tl):
        return R * (t<=t_lockdown)  + Rl * (1-(t <= t_lockdown))

    cmj.set_transmission_rate(rho_lockdown_tl)
    # pgf_vec evaluated at z = 0; used in the next cell as the extinction probability over time.
    extinction_probs += [cmj.pgf_vec(0)]
    start = timer()
    pmf = bp.pmfft(cmj.pgf_vec,
                   max_pgf,
                   count = characteristic,
                   parallel = True,
                   n_cores = n_cores)
    stop = timer()

    print('Scenario ' + str(j) + ' finished in ' + str(np.round(stop - start, 2)) + ' seconds!')

    lockdown_scenarios += [pmf]
    

#### Investigate effects on time to extinction 

- Try different parameters, timings, and functional forms for the intervention and investigate the impact on time-to-extinction

In [None]:
mean_lockdowns = [bp.mean_pmf(ls) for ls in lockdown_scenarios]
l_idx = 2  # which lockdown scenario (index into t_lockdowns) to plot
ls = lockdown_scenarios[l_idx]
ml = mean_lockdowns[l_idx]
t_idx = np.argwhere(time == t_lockdowns[l_idx])[0][0]
extinction = np.real_if_close(extinction_probs[l_idx][:-1])
# Extinction probability after the intervention divided by the probability of not being
# extinct on the intervention day; only its derivative (a density) is used below.
extinction_lockdown = extinction[(t_idx+1):] / (1-extinction[t_idx])
extinction_pdf = np.gradient(extinction_lockdown, time_step)

# First day on which mean prevalence drops below one case (None if it never does on this grid).
below_one = np.flatnonzero(ml[:len(time)] < 1)
mean_extinction_time = time[below_one[0]] if below_one.size else None


fig, ax = plt.subplots()
ax.grid(alpha = 0.5)
ln2 = ax.plot(time, ml[:-1], label = 'Prevalence (intervention)', color = cmap[0])
ax.set_ylim([0, 3000])
ax.set_xlabel('Time (Days)')
ax.set_ylabel('Prevalence')
ln4 = ax.axvline(t_lockdowns[l_idx], label = 'Day of Intervention', linestyle = '--', alpha = 0.7, color = cmap[2])
legend_lines = ln2 + [ln4]
if mean_extinction_time is not None:
    ln5 = ax.axvline(mean_extinction_time, label = 'Mean Prevalence Below 1', linestyle = '--', alpha = 0.7, color = cmap[7])
    legend_lines.append(ln5)
    print('Mean at ' + str(mean_extinction_time))

ax2 = ax.twinx()
ln3 = ax2.plot(time[(t_idx+1):], extinction_pdf, color = cmap[7], label = 'Probability of Zero')
ax2.fill_between(time[(t_idx+1):], 0, extinction_pdf, color = cmap[7], alpha = 0.2)
# ax2 shares ax's x-axis, so a single x-limit applies to both.
ax.set_xlim([0, 200])
ax2.set_ylim([0, 0.1])
ax2.set_ylabel('Probability')
ax.set_title('Time to Zero Prevalence After an Intervention')

# Collect the labelled lines from both y-axes into one legend (drawn on ax2 so it sits on top).
legend_lines += ln3
ax2.legend(legend_lines, [line.get_label() for line in legend_lines])

