# Feasibility of controlling COVID-19 outbreaks by isolation of cases and contacts

In [561]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from tqdm import tqdm
from math import floor

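This notebook is a Python reimplementation of the branching-process model from Hellewell *et al.*, "Feasibility of controlling COVID-19 outbreaks by isolation of cases and contacts", *The Lancet Global Health* (2020): https://www.thelancet.com/journals/langlo/article/PIIS2214-109X(20)30074-7/fulltext. The functions below port the R code of the `ringbp` package (https://github.com/cmmid/ringbp).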

# Helper Functions


In [562]:
import torch.distributions as tpd
import torch 

In [563]:
# TODO: Write this in pytorch and test same outputs 

# literal adaption from:
# http://stackoverflow.com/questions/4643285/how-to-generate-random-numbers-that-follow-skew-normal-distribution-in-matlab
# original at:
# http://www.ozgrid.com/forum/showthread.php?t=108175
def rand_skew_norm(fAlpha, fLocation, fScale):
    """Draw a single skew-normal variate.

    Parameters
    ----------
    fAlpha : float
        Shape (skew / slant) parameter alpha.
    fLocation : float
        Location xi.
    fScale : float
        Scale omega.

    Uses the sign-flip construction. Let delta = alpha / sqrt(1 + alpha**2) and take
    two independent standard normals (u0, v). The variable
    u1 = delta*u0 + sqrt(1 - delta**2)*v is reflected whenever u0 < 0, then shifted
    and scaled. Consumes exactly two values from numpy's global RNG.
    """
    delta = fAlpha / np.sqrt(1.0 + fAlpha**2)

    u0, v = np.random.randn(2)
    mixed = delta * u0 + np.sqrt(1.0 - delta**2) * v

    # Reflect when the first normal is negative; this produces the skew.
    standard_draw = mixed if u0 >= 0 else -mixed
    return standard_draw * fScale + fLocation

def randn_skew(N, alpha=0.0, xi=0., omega=1.):
    """Return a list of ``N`` independent skew-normal draws.

    Each element comes from ``rand_skew_norm`` with skew ``alpha``, location ``xi``
    and scale ``omega``. The draws are generated sequentially from numpy's global RNG.
    """
    draws = []
    for _ in range(N):
        draws.append(rand_skew_norm(alpha, xi, omega))
    return draws

def inf_fn(num_samples=None, inc_samp = None, k = None):
    """Sample exposure times of secondary cases from a skew-normal serial interval.

    For each infector, ``int(n)`` exposure times are drawn from a skew-normal with
    location = that infector's entry in ``inc_samp``, scale (omega) = 2 and
    skew (alpha) = ``k``. Draws below 1 are clamped to 1.

    Parameters
    ----------
    num_samples : iterable of numbers
        Number of secondary cases per infector (each value is cast with ``int``).
    inc_samp : iterable of floats
        Location per infector. ``outbreak_step`` passes the infectors' onset times.
    k : float
        Skew parameter alpha of the skew-normal.

    Returns
    -------
    numpy.ndarray
        1-D float array of length ``sum(int(n) for n in num_samples)``, ordered
        infector by infector. Empty if there are no infectors.

    TODO: Convert it into pytorch
    """
    # xi is location, omega is scale, alpha is slant
    per_infector = [
        np.asarray(randn_skew(int(n), alpha=k, xi=loc, omega=2.), dtype=float)
        for n, loc in zip(num_samples, inc_samp)
    ]
    if not per_infector:
        return np.empty(0)
    # Clamp anything below 1 up to 1. The previous masking expression
    # (tmp<1)*1 + (tmp>1)*tmp turned a draw of exactly 1.0 into 0.
    return np.maximum(np.concatenate(per_infector), 1.0)

In [564]:
def outbreak_setup(num_initial_cases = None,
                   incfn = None,
                   prop_asym = None,
                   delayfn = None,
                   k = None):
    """Build the case table for the index cases that seed an outbreak.

    Reimplementation of https://github.com/cmmid/ringbp/blob/master/R/outbreak_setup.R

    Parameters
    ----------
    num_initial_cases : int
        Number of index cases (rows of the returned table).
    incfn : torch.distributions.Distribution
        Incubation-period distribution. One draw per case is used as its onset time.
    prop_asym : float
        Probability that an index case is asymptomatic (Bernoulli).
    delayfn : torch.distributions.Distribution
        Onset-to-isolation delay distribution.
    k : float, optional
        Not used here; accepted for signature parity with the other model functions.

    Returns
    -------
    pandas.DataFrame
        One row per index case with these columns:
        ``exposure`` (0), ``asym`` (0/1), ``infector`` (0), ``missed`` (1), ``onset``,
        ``new_cases`` (0), ``caseid`` (0..n-1), ``isolated_time`` and ``isolated`` (0).
        ``isolated_time`` is ``onset + delay`` for symptomatic cases and ``inf`` for
        asymptomatic ones.
    """
    n = num_initial_cases
    # Draws are taken in the original order (incubation, asymptomatic flag, delay).
    # ``sample((n,))`` is the non-deprecated form of ``sample_n(n)``.
    inc_samples = incfn.sample((n,))
    asym = tpd.bernoulli.Bernoulli(prop_asym).sample((n,))
    delay = delayfn.sample((n,))

    # Asymptomatic cases never self-isolate, so their isolation time is infinite.
    # Everyone else isolates at onset + delay.
    isolated_time = torch.where(asym == 1,
                                torch.full_like(inc_samples, float('inf')),
                                inc_samples + delay)

    return pd.DataFrame({
        'exposure': torch.zeros(n).numpy(),
        'asym': asym.numpy(),
        # NOTE(review): 0 is used as "no infector", but caseid 0 is also a real case.
        'infector': torch.zeros(n).numpy(),
        'missed': torch.ones(n).numpy(),  # bool
        'onset': inc_samples.numpy(),
        'new_cases': torch.zeros(n).numpy(),
        'caseid': torch.arange(0, n, 1).numpy(),
        'isolated_time': isolated_time.numpy(),
        'isolated': torch.zeros(n).numpy(),  # bool
    })

In [565]:
def torch_max(t):
    """Element-wise maximum across a sequence of equally-shaped tensors.

    The tensors are stacked along a new trailing axis, which is then reduced, so
    the result has the shape of each input.
    """
    stacked = torch.stack(list(t), dim=-1)
    return stacked.max(dim=-1).values

def torch_min(t):
    """Element-wise minimum across a sequence of equally-shaped tensors.

    The tensors are stacked along a new trailing axis, which is then reduced, so
    the result has the shape of each input.
    """
    return torch.stack(list(t), dim=-1).min(dim=-1).values



def _isolation_times(onset, asym, missed, infector_iso_time,
                     delay_missed, delay_traced, quarantine):
    """Compute the isolation time of each new case (element-wise numpy).

    * asymptomatic            -> +inf (never isolated)
    * symptomatic and missed  -> onset + delay_missed
    * symptomatic and traced  -> infector_iso_time if ``quarantine`` is truthy, else
                                 min(onset + delay_traced, max(onset, infector_iso_time))

    ``asym`` and ``missed`` are boolean masks. The other arrays are floats of the same length.
    """
    iso_time = np.full(len(onset), np.inf)

    untraced = ~asym & missed
    iso_time[untraced] = (onset + delay_missed)[untraced]

    traced = ~asym & ~missed
    if quarantine:
        traced_time = infector_iso_time
    else:
        # A traced contact is isolated when the infector is, but not before its own
        # onset. It isolates earlier if self-isolation (onset + delay) comes first.
        traced_time = np.minimum(onset + delay_traced,
                                 np.maximum(onset, infector_iso_time))
    iso_time[traced] = traced_time[traced]
    return iso_time


def outbreak_step(case_data = None, disp_iso = None,
                  disp_com = None, r0isolated = None,
                  r0community = None, prop_asym = None,
                  incfn = None, delayfn = None, prop_ascertain = None,
                  k = None, quarantine = None):
    """Advance the branching process by one generation.

    Reimplementation of R function https://github.com/cmmid/ringbp/blob/master/R/outbreak_step.R

    Parameters
    ----------
    case_data : pandas.DataFrame
        Current case table, e.g. from ``outbreak_setup``. Its ``new_cases`` and
        ``isolated`` columns are overwritten in place.
    disp_iso, disp_com : float
        Negative-binomial dispersion (size k) for isolated and community transmission.
        Variance is ``mu + mu**2 / k``, so smaller k means more overdispersion.
    r0isolated, r0community : float
        Mean number of secondary cases for isolated and non-isolated cases.
        A falsy ``r0isolated`` means isolated cases never transmit.
    prop_asym : float
        Probability that a new case is asymptomatic.
    incfn, delayfn : torch.distributions.Distribution
        Incubation-period and onset-to-isolation delay distributions.
    prop_ascertain : float
        Probability that a contact is traced. It is missed with probability 1 - prop_ascertain.
    k : float
        Skew of the serial-interval skew-normal, passed to ``inf_fn``.
    quarantine : bool or {0, 1}
        If truthy, traced contacts are isolated at their infector's isolation time.

    Returns
    -------
    dict
        - ``case_data``: old and new cases.
        - ``effective_r0``: new cases divided by the number of cases still infectious
          at the start of this step.
        - ``cases_in_gen``: number of new cases.
    """
    n_cases = case_data.shape[0]
    isolated = case_data['isolated'].to_numpy()
    # Denominator of the effective reproduction number. Count it before the
    # isolation flags are reset below.
    n_infectious = int((isolated != 1).sum())

    # For each case, draw new_cases from a negative binomial whose mean and
    # dispersion depend on whether the case is isolated. disp_* are the size k,
    # while convert_params takes the overdispersion alpha = 1 / k.
    community_draws = tpd.negative_binomial.NegativeBinomial(
        *convert_params(r0community, 1.0 / disp_com)).sample((n_cases,)).numpy()
    cases_if_not_isolated = (1 - isolated) * community_draws
    if r0isolated:
        isolated_draws = tpd.negative_binomial.NegativeBinomial(
            *convert_params(r0isolated, 1.0 / disp_iso)).sample((n_cases,)).numpy()
        cases_if_isolated = isolated * isolated_draws
    else:
        # Mean 0 means isolated cases produce no offspring.
        cases_if_isolated = np.zeros(n_cases)
    case_data['new_cases'] = cases_if_isolated + cases_if_not_isolated

    offspring_all = case_data['new_cases'].to_numpy().astype(int)
    has_offspring = offspring_all > 0
    infectors = case_data[has_offspring]
    offspring = offspring_all[has_offspring]
    total_new_cases = int(offspring.sum())

    # If no new cases drawn, outbreak is over so return case_data
    if total_new_cases == 0:
        print('No new cases')
        case_data['isolated'] = 1.0
        return {"case_data": case_data, "effective_r0": 0, "cases_in_gen": 0}

    # One row per new case. Infector attributes are repeated once per offspring.
    inc_samples = incfn.sample((total_new_cases,))
    prob_samples = pd.DataFrame({
        # time of exposure: serial-interval draw around the infector's onset
        'exposure': inf_fn(offspring, infectors['onset'].to_numpy(), k),
        'asym': tpd.bernoulli.Bernoulli(prop_asym).sample((total_new_cases,)).numpy(),
        'infector': np.repeat(infectors['caseid'].to_numpy(), offspring),
        'missed': tpd.bernoulli.Bernoulli(1 - prop_ascertain).sample((total_new_cases,)).numpy(),
        'new_cases': np.zeros(total_new_cases),
        'infector_iso_time': np.repeat(infectors['isolated_time'].to_numpy(), offspring),
        'infector_asym': np.repeat(infectors['asym'].to_numpy(), offspring),
        'incubfn_sample': inc_samples.numpy(),
        'isolated': np.zeros(total_new_cases),
    })

    # Filter out new cases prevented by isolation. reset_index yields a fresh frame,
    # which avoids writes to a slice.
    prob_samples = prob_samples[
        prob_samples['exposure'] < prob_samples['infector_iso_time']
    ].reset_index(drop=True)
    total_new_cases_revised = prob_samples.shape[0]

    # If no new cases drawn due to isolation, outbreak is over so return case_data
    if total_new_cases_revised == 0:
        print('No total_new_cases_revised')
        case_data['isolated'] = 1.0
        return {"case_data": case_data, "effective_r0": 0, "cases_in_gen": 0}

    onset = prob_samples['exposure'].to_numpy() + prob_samples['incubfn_sample'].to_numpy()
    asym = prob_samples['asym'].to_numpy() == 1
    # Contacts of asymptomatic infectors are always missed; others are missed at random.
    missed = ((prob_samples['missed'].to_numpy() == 1)
              | (prob_samples['infector_asym'].to_numpy() >= 1))
    delay_missed = delayfn.sample((total_new_cases_revised,)).numpy()
    delay_traced = delayfn.sample((total_new_cases_revised,)).numpy()
    isolated_time = _isolation_times(onset, asym, missed,
                                     prob_samples['infector_iso_time'].to_numpy(),
                                     delay_missed, delay_traced, quarantine)

    prob_samples = (
        prob_samples
        .assign(onset=onset,
                missed=missed.astype(float),
                isolated_time=isolated_time,
                caseid=np.arange(n_cases, n_cases + total_new_cases_revised))
        .drop(columns=['incubfn_sample', 'infector_iso_time', 'infector_asym'])
    )

    cases_in_gen = total_new_cases_revised
    effective_r0 = cases_in_gen / n_infectious if n_infectious > 0 else 0

    # Everyone in case_data so far has had their chance to infect and is now
    # considered isolated.
    case_data['isolated'] = 1.0
    case_data = pd.concat([case_data, prob_samples], ignore_index=True)

    return {"case_data": case_data, "effective_r0": effective_r0, "cases_in_gen": cases_in_gen}

In [574]:
def outbreak_model(num_initial_cases = None, prop_ascertain = None,
                   cap_max_days = None, cap_cases = None,
                   r0isolated = None, r0community = None,
                   disp_iso = None, disp_com = None,
                   k = None, delay_shape = None,
                   delay_scale = None, prop_asym = None,
                   quarantine = None):
    """Simulate one outbreak and summarise it as weekly case counts.

    Reimplementation of https://github.com/cmmid/ringbp/blob/master/R/outbreak_model.R

    The function seeds ``num_initial_cases`` index cases with ``outbreak_setup``. It
    then calls ``outbreak_step`` until one of these holds:
    - the latest onset reaches ``cap_max_days``;
    - the case count reaches ``cap_cases``;
    - every case is flagged isolated (stochastic extinction).
    It prints which condition ended the run.

    Returns
    -------
    pandas.DataFrame
        One row per week ``0 .. floor(cap_max_days / 7)`` (inclusive), with columns:
        - ``week``
        - ``onset_y``: cases whose onset falls in that week
        - ``cumulative``: running total of ``onset_y``
        - ``effective_r0``: mean of the per-step effective R values, repeated on
          every row; NaN if no step ran.
    """
    # Incubation period sampling function
    incfn = tpd.weibull.Weibull(concentration=2.322737, scale=6.492272)
    # Onset to isolation delay sampling function
    delayfn = tpd.weibull.Weibull(concentration=delay_shape, scale=delay_scale)

    case_data = outbreak_setup(num_initial_cases=num_initial_cases,
                               incfn=incfn,
                               prop_asym=prop_asym,
                               delayfn=delayfn,
                               k=k)

    total_cases = case_data.shape[0]
    latest_onset = 0
    extinct = False
    effective_r0_vect = []

    while latest_onset < cap_max_days and total_cases < cap_cases and not extinct:
        out = outbreak_step(case_data=case_data,
                            disp_iso=disp_iso,
                            disp_com=disp_com,
                            r0isolated=r0isolated,
                            r0community=r0community,
                            incfn=incfn,
                            delayfn=delayfn,
                            prop_ascertain=prop_ascertain,
                            k=k,
                            quarantine=quarantine,
                            prop_asym=prop_asym)
        case_data = out['case_data']
        effective_r0_vect.append(out['effective_r0'])
        total_cases = case_data.shape[0]
        latest_onset = case_data['onset'].max()
        # outbreak_step flags every existing row as isolated, so all rows are isolated
        # only when the step produced no new cases.
        extinct = bool(case_data['isolated'].astype(bool).all())

    # Use the complements of the loop conditions, so hitting a cap exactly is
    # not reported as extinction.
    if latest_onset >= cap_max_days:
        print('Terminating from max days')
    elif total_cases >= cap_cases:
        print('Total cases surpassed cap')
    else:
        print('Terminating due to stochastic extinction')

    # Group cases into weeks of onset, keeping weeks 0..max_week (inclusive).
    # Weeks with no cases are kept as zeros and later weeks are dropped.
    max_week = floor(cap_max_days / 7)
    weeks = np.floor(case_data['onset'].to_numpy() / 7).astype(int)
    weeks = weeks[(weeks >= 0) & (weeks <= max_week)]
    weekly_counts = np.bincount(weeks, minlength=max_week + 1).astype(float)

    mean_r0 = float(np.mean(effective_r0_vect)) if effective_r0_vect else np.nan
    return pd.DataFrame({
        'week': np.arange(max_week + 1),
        'onset_y': weekly_counts,
        'cumulative': np.cumsum(weekly_counts),
        'effective_r0': mean_r0,
    })


In [575]:
def convert_params(mu, alpha):
    """
    Convert a mean/overdispersion negative-binomial parameterisation into the
    ``(total_count, probs)`` arguments of ``torch.distributions.NegativeBinomial``.

    torch's NegativeBinomial has mean ``total_count * probs / (1 - probs)``. The pair
    returned here gives mean ``mu`` and variance ``mu + alpha * mu**2``. (scipy's
    ``nbinom`` uses the complementary probability and is NOT matched.)

    Parameters
    ----------
    mu : float
       Mean of NB distribution. ``mu == 0`` is allowed and yields ``probs == 0``,
       i.e. the distribution always returns 0.
    alpha : float
       Overdispersion parameter used for variance calculation. Must be > 0;
       ``alpha == 0`` raises ZeroDivisionError. For a dispersion / "size" k
       (variance ``mu + mu**2 / k``), pass ``alpha = 1 / k``.

    Returns
    -------
    (r, p) : tuple
        ``r = 1 / alpha`` (total_count) and ``p = alpha * mu / (1 + alpha * mu)`` (probs).

    See https://en.wikipedia.org/wiki/Negative_binomial_distribution#Alternative_formulations
    """
    # Closed forms of r = mu**2 / (var - mu) and p = (var - mu) / var with
    # var = mu + alpha * mu**2. Unlike those ratios, they stay finite at mu == 0.
    r = 1.0 / alpha
    p = alpha * mu / (1.0 + alpha * mu)
    return r, p

# Simulation

In [576]:
## Put parameters that are grouped by disease into here
delay = "Wuhan" #("SARS", "Wuhan")
delay_shape = 1.651524 #(1.651524, 2.305172)
delay_scale = 4.287786 #(4.287786, 9.483875)

# k is the skew (alpha) of the skew-normal serial interval passed to inf_fn.
# The original R parameter grid pairs each k with a theta label:
#   theta = c("<1%", "15%", "30%"), k = c(30, 1.95, 0.7)
k = 1.95 #(30, 1.95, 0.7)

index_R0 = 2.5 #(1.5, 2.5, 3.5)
prop_asym = 0.1 #(0, 0.1)
control_effectiveness = 0.4 #torch.arange(0, 1.2, 0.2); passed as prop_ascertain
num_initial_cases = 40 #(5, 20, 40)

# Fixed params
cap_max_days = 365
cap_cases = 10000 # a run stops once it has this many cases; larger caps cost more time/memory
r0isolated = 0
disp_iso = 1
disp_com = 0.16
quarantine = 0 #bool

# Seed both random number generators the simulation uses, so that Restart &
# Run All reproduces the same outbreaks. numpy drives the skew-normal serial
# interval; torch drives the other distributions.
SEED = 2020
np.random.seed(SEED)
torch.manual_seed(SEED)

# Samples per experiment
n_samples = 10

In [577]:
# Run n_samples independent outbreaks. Despite its name, each element of
# case_data_list is the weekly summary DataFrame returned by outbreak_model.
case_data_list = []
for _ in range(n_samples):
    case_data_list.append(
        outbreak_model(num_initial_cases=num_initial_cases,
                       prop_ascertain=control_effectiveness,
                       cap_max_days=cap_max_days,
                       cap_cases=cap_cases,
                       r0isolated=r0isolated,
                       r0community=index_R0,
                       disp_iso=disp_iso,
                       disp_com=disp_com,
                       k=k,
                       delay_shape=delay_shape,
                       delay_scale=delay_scale,
                       prop_asym=prop_asym,
                       quarantine=quarantine)
    )



Total cases surpassed cap
Total cases surpassed cap
Total cases surpassed cap
Total cases surpassed cap
Total cases surpassed cap
Total cases surpassed cap
Total cases surpassed cap
Total cases surpassed cap
Total cases surpassed cap
Total cases surpassed cap


# Plotting the simulated weekly cases

In [578]:
# Pull the cumulative and weekly-new series out of each simulated run.
cumulative_list = [weekly['cumulative'].to_numpy() for weekly in case_data_list]
onset_y_list = [weekly['onset_y'].to_numpy() for weekly in case_data_list]

In [579]:
# Stack the per-run series into arrays with one row per run and one column per week.
cumulative_list, onset_y_list = np.vstack(cumulative_list), np.vstack(onset_y_list)

In [580]:
# Long format for seaborn: 'variable' is the week (column index), 'value' the count.
cumulative_list, onset_y_list = (pd.DataFrame(cumulative_list).melt(),
                                 pd.DataFrame(onset_y_list).melt())

In [None]:
# Each melted frame has one row per (run, week). seaborn aggregates the runs
# at each week into a mean line with an uncertainty band.
fig, ax = plt.subplots()
sns.lineplot(data=cumulative_list, x='variable', y='value', label='Cumulative Weekly Cases', ax=ax)
sns.lineplot(data=onset_y_list, x='variable', y='value', label='Weekly New Cases', ax=ax)
ax.set_xlabel('Week')
ax.set_ylabel('Value')
ax.set_title('Weekly Cases')
plt.show()