# Rt profile multicategorical versus uniform

In [1]:
# Import libraries
import os
import numpy as np
import math
import branchpro
import scipy.stats
from branchpro.apps import ReproductionNumberPlot
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import pandas as pd
from cmdstanpy import CmdStanModel, cmdstan_path
import arviz as az
import nest_asyncio
import seaborn as sns
nest_asyncio.apply()

num_timepoints = 300 # number of days for incidence data
num_categories = 3

  from .autonotebook import tqdm as notebook_tqdm


## Parameterize example branching process model with disagregated data

In [2]:
# Build the serial interval w_s
serial_intervals = []

ws_mean_cat = [5.3, 7, 4.2]
ws_std_cat = [2.3, 2.3, 2.3]

for ws_mean, ws_std in zip(ws_mean_cat, ws_std_cat):
    theta = ws_std**2 / ws_mean
    k = ws_mean / theta
    w_dist = scipy.stats.gamma(k, scale=theta)
    disc_w = w_dist.pdf(np.arange(30))

    serial_intervals.append(disc_w)

serial_intervals = np.array(serial_intervals)

# Simulate incidence data
initial_r = 0.5

# contact_matrix = contact_matrix
path = os.path.join('../../data_library/polymod/final_contact_matrices/', 'BASE_Japan_3.csv')
contact_matrix = np.transpose(pd.read_csv(path, header=None).to_numpy())

transmissibility = [1, 0.3, 0.6]

m = branchpro.MultiCatPoissonBranchProModel(
        initial_r, serial_intervals, num_categories, contact_matrix, transmissibility, multipleSI=True)
m_overall = branchpro.MultiCatPoissonBranchProModel(
        initial_r, serial_intervals, num_categories, contact_matrix, transmissibility, multipleSI=True)
m_target = branchpro.MultiCatPoissonBranchProModel(
        initial_r, serial_intervals, num_categories, contact_matrix, transmissibility, multipleSI=True)
m_wrong_target = branchpro.MultiCatPoissonBranchProModel(
        initial_r, serial_intervals, num_categories, contact_matrix, transmissibility, multipleSI=True)

new_rs = [0.14]          # sequence of R_0 numbers
start_times = [0]      # days at which each R_0 period begins
m.set_r_profile(new_rs, start_times)
m_overall.set_r_profile(new_rs, start_times)
m_target.set_r_profile(new_rs, start_times)
m_wrong_target.set_r_profile(new_rs, start_times)
parameters = [2, 80, 13] # initial number of cases
times = np.arange(num_timepoints)

In [3]:
contact_matrix

array([[ 6.96142573,  5.22907914,  1.48784293],
       [ 1.86863597, 10.65987513,  2.38209249],
       [ 0.33377886,  1.29212976,  3.37356853]])

## Overall interventions

In [4]:
overall_reductions = [1, 0.75, 0.5]
times_reductions = [0, 20, 40]

## Simulate model

In [5]:
agg_cases = []
overall_agg_cases = []
target_agg_cases = []
wrong_target_agg_cases = []

for _ in range(100):
    desagg_cases = m.simulate(
        parameters, times)
    agg_cases.append(np.sum(desagg_cases, axis=1).tolist())

    overall_desagg_cases = m_overall.simulate(
        parameters, times,
        interventions=[red*np.identity(num_categories) for red in overall_reductions],
        time_interventions=times_reductions)
    overall_agg_cases.append(np.sum(overall_desagg_cases, axis=1).tolist())

    target_desagg_cases = m_target.simulate(
        parameters, times,
        interventions=[np.diag([red, 1, 1]) for red in overall_reductions],
        time_interventions=times_reductions)
    target_agg_cases.append(np.sum(target_desagg_cases, axis=1).tolist())

    wrong_target_desagg_cases = m_wrong_target.simulate(
        parameters, times,
        interventions=[np.diag([1, red, 1]) for red in overall_reductions],
        time_interventions=times_reductions)
    wrong_target_agg_cases.append(np.sum(wrong_target_desagg_cases, axis=1).tolist())

agg_cases_mean = np.mean(agg_cases, axis=0).tolist()
overall_agg_cases_mean = np.mean(overall_agg_cases, axis=0).tolist()
target_agg_cases_mean = np.mean(target_agg_cases, axis=0).tolist()
wrong_target_agg_cases_mean = np.mean(wrong_target_agg_cases, axis=0).tolist()


## Plot agreggated local incidence numbers

In [6]:
# Plot (bar chart cases each day)
from plotly.subplots import make_subplots
fig = go.Figure()
fig = make_subplots(rows=2, cols=2)

# Plot of incidences
fig.add_trace(
    go.Scatter(
        x=times[1:],
        y=agg_cases_mean[1:],
        name='Cases'
    ),
    row=1,col=1
)

fig.add_trace(
    go.Scatter(
        x=times[1:],
        y=wrong_target_agg_cases_mean[1:],
        name='Cases with Poorly Targeted NPIs'
    ),row=1,col=2
)

fig.add_trace(
    go.Scatter(
        x=times[1:],
        y=overall_agg_cases_mean[1:],
        name='Cases with Overall NPIs'
    ),row=2,col=1
)

fig.add_trace(
    go.Scatter(
        x=times[1:],
        y=target_agg_cases_mean[1:],
        name='Cases with Targeted NPIs'
    ),row=2,col=2
)

# Add axis labels
fig.update_layout(
    boxmode='group',
    width=700,
    height=600,
    plot_bgcolor='white',
    xaxis=dict(
        title='Time (days)',
        linecolor='black'),
    yaxis=dict(
        title='New cases',
        linecolor='black'),
    xaxis2=dict(
        title='Time (days)',
        linecolor='black'),
    yaxis2=dict(
        title='New cases',
        linecolor='black'),
    xaxis3=dict(
        title='Time (days)',
        linecolor='black'),
    yaxis3=dict(
        title='New cases',
        linecolor='black'),
    xaxis4=dict(
        title='Time (days)',
        linecolor='black'),
    yaxis4=dict(
        title='New cases',
        linecolor='black'),
    legend=dict(
        orientation="h",
        yanchor="bottom",
        y=1.05,
        xanchor="right",
        x=1
    ))

fig.write_image('images/Different_intervention_effects.pdf')
fig.show()

## More realistic reductions
### Reduction of contacts in one group leads to a shfit of those contacts with other population groups
For example, if an intervention is applied to school-age children, we reduce the number of contacts in that group, and increase their contacts in adults instead.

### Simple shift of contacts

In [7]:
def simple_overall_interv(red, cm):
    final_matrix = np.array([
        [red * cm[0, 0], 0.95 * (1-red) * cm[0, 0] + cm[0, 1], 0.05 * (1-red) * cm[0, 0] + cm[0, 2]],
        [0.7 * (1-red) * cm[1, 1] + cm[1, 0], red * cm[1, 1], 0.3 * (1-red) * cm[1, 1] +red * cm[1, 2]],
        [0.3 * (1-red) * cm[2, 2] + cm[2, 0], 0.7 * (1-red) * cm[2, 2] + cm[2, 1], red * cm[2, 2]]
    ])
    return np.matmul(final_matrix, np.linalg.inv(cm))

def simple_target_interv(red, cm):
    final_matrix = np.array([
        [red * cm[0, 0], 0.95 * (1-red) * cm[0, 0] + cm[0, 1], 0.05 * (1-red) * cm[0, 0] + cm[0, 2]],
        [cm[1, 0], cm[1, 1], cm[1, 2]],
        [cm[2, 0], cm[2, 1], cm[2, 2]]
    ])
    return np.matmul(final_matrix, np.linalg.inv(cm))


def simple_wrong_target_interv(red, cm):
    final_matrix = np.array([
        [cm[0, 0], cm[0, 1], cm[0, 2]],
        [0.7 * (1-red) * cm[1, 1] + cm[1, 0], red * cm[1, 1], 0.3 * (1-red) * cm[1, 1] +red * cm[1, 2]],
        [cm[2, 0], (1 - 0.25 * red) * cm[2, 1], cm[2, 2]]
    ])
    return np.matmul(final_matrix, np.linalg.inv(cm))


In [8]:
simple_realistic_agg_cases = []
simple_realistic_overall_agg_cases = []
simple_realistic_target_agg_cases = []
simple_realistic_wrong_target_agg_cases = []

for _ in range(100):
    simple_realistic_desagg_cases = m.simulate(
        parameters, times)
    simple_realistic_agg_cases.append(np.sum(simple_realistic_desagg_cases, axis=1).tolist())

    simple_realistic_overall_desagg_cases = m_overall.simulate(
        parameters, times,
        interventions=[simple_overall_interv(red, contact_matrix) for red in overall_reductions],
        time_interventions=times_reductions)
    simple_realistic_overall_agg_cases.append(np.sum(simple_realistic_overall_desagg_cases, axis=1).tolist())

    simple_realistic_target_desagg_cases = m_target.simulate(
        parameters, times,
        interventions=[simple_target_interv(red, contact_matrix)for red in overall_reductions],
        time_interventions=times_reductions)
    simple_realistic_target_agg_cases.append(np.sum(simple_realistic_target_desagg_cases, axis=1).tolist())

    simple_realistic_wrong_target_desagg_cases = m_wrong_target.simulate(
        parameters, times,
        interventions=[simple_wrong_target_interv(red, contact_matrix) for red in overall_reductions],
        time_interventions=times_reductions)
    simple_realistic_wrong_target_agg_cases.append(np.sum(simple_realistic_wrong_target_desagg_cases, axis=1).tolist())

simple_realistic_agg_cases_mean = np.mean(simple_realistic_agg_cases, axis=0).tolist()
simple_realistic_overall_agg_cases_mean = np.mean(simple_realistic_overall_agg_cases, axis=0).tolist()
simple_realistic_target_agg_cases_mean = np.mean(simple_realistic_target_agg_cases, axis=0).tolist()
simple_realistic_wrong_target_agg_cases_mean = np.mean(simple_realistic_wrong_target_agg_cases, axis=0).tolist()


## Plot agreggated local incidence numbers

In [9]:
# Plot (bar chart cases each day)
from plotly.subplots import make_subplots
fig = go.Figure()
fig = make_subplots(rows=2, cols=2)

# Plot of incidences
fig.add_trace(
    go.Scatter(
        x=times[1:],
        y=simple_realistic_agg_cases_mean[1:],
        name='Cases'
    ),
    row=1,col=1
)

fig.add_trace(
    go.Scatter(
        x=times[1:],
        y=simple_realistic_wrong_target_agg_cases_mean[1:],
        name='Cases with Poorly Targeted NPIs'
    ),row=1,col=2
)

fig.add_trace(
    go.Scatter(
        x=times[1:],
        y=simple_realistic_overall_agg_cases_mean[1:],
        name='Cases with Overall NPIs'
    ),row=2,col=1
)

fig.add_trace(
    go.Scatter(
        x=times[1:],
        y=simple_realistic_target_agg_cases_mean[1:],
        name='Cases with Targeted NPIs'
    ),row=2,col=2
)

# Add axis labels
fig.update_layout(
    boxmode='group',
    width=700,
    height=600,
    plot_bgcolor='white',
    xaxis=dict(
        title='Time (days)',
        linecolor='black'),
    yaxis=dict(
        title='New cases',
        linecolor='black'),
    xaxis2=dict(
        title='Time (days)',
        linecolor='black'),
    yaxis2=dict(
        title='New cases',
        linecolor='black'),
    xaxis3=dict(
        title='Time (days)',
        linecolor='black'),
    yaxis3=dict(
        title='New cases',
        linecolor='black'),
    xaxis4=dict(
        title='Time (days)',
        linecolor='black'),
    yaxis4=dict(
        title='New cases',
        linecolor='black'),
    legend=dict(
        orientation="h",
        yanchor="bottom",
        y=1.05,
        xanchor="right",
        x=1
    ))

fig.write_image('images/Different_simple_realistic_intervention_effects.pdf')
fig.show()

### More complex shift of contacts

In [10]:
def overall_interv(red, cm):
    final_matrix = np.array([
        [red * cm[0, 0], (1 + 0.25 * (1-red)) * cm[0, 1], red * cm[0, 2]],
        [(1 + 0.25 * (1-red)) * cm[1, 0], red * cm[1, 1], red * cm[1, 2]],
        [red * cm[2, 0], red * cm[2, 1], red * cm[2, 2]]
    ])
    return np.matmul(final_matrix, np.linalg.inv(cm))

def target_interv(red, cm):
    final_matrix = np.array([
        [red * cm[0, 0], (1 + 0.25 * (1-red)) * cm[0, 1], (1 - 0.25 * red) * cm[0, 2]],
        [(1 + 0.25 * (1-red)) * cm[1, 0], cm[1, 1], cm[1, 2]],
        [(1 - 0.25 * red) *  cm[2, 0], cm[2, 1], cm[2, 2]]
    ])
    return np.matmul(final_matrix, np.linalg.inv(cm))


def wrong_target_interv(red, cm):
    final_matrix = np.array([
        [cm[0, 0], (1 + 0.25 * (1-red)) * cm[0, 1], cm[0, 2]],
        [(1 + 0.25 * (1-red)) * cm[1, 0], red * cm[1, 1], (1 - 0.25 * red) *cm[1, 2]],
        [cm[2, 0], (1 - 0.25 * red) * cm[2, 1], cm[2, 2]]
    ])
    return np.matmul(final_matrix, np.linalg.inv(cm))


In [11]:
realistic_agg_cases = []
realistic_overall_agg_cases = []
realistic_target_agg_cases = []
realistic_wrong_target_agg_cases = []

for _ in range(100):
    realistic_desagg_cases = m.simulate(
        parameters, times)
    realistic_agg_cases.append(np.sum(realistic_desagg_cases, axis=1).tolist())

    realistic_overall_desagg_cases = m_overall.simulate(
        parameters, times,
        interventions=[overall_interv(red, contact_matrix) for red in overall_reductions],
        time_interventions=times_reductions)
    realistic_overall_agg_cases.append(np.sum(realistic_overall_desagg_cases, axis=1).tolist())

    realistic_target_desagg_cases = m_target.simulate(
        parameters, times,
        interventions=[target_interv(red, contact_matrix)for red in overall_reductions],
        time_interventions=times_reductions)
    realistic_target_agg_cases.append(np.sum(realistic_target_desagg_cases, axis=1).tolist())

    realistic_wrong_target_desagg_cases = m_wrong_target.simulate(
        parameters, times,
        interventions=[wrong_target_interv(red, contact_matrix) for red in overall_reductions],
        time_interventions=times_reductions)
    realistic_wrong_target_agg_cases.append(np.sum(realistic_wrong_target_desagg_cases, axis=1).tolist())

realistic_agg_cases_mean = np.mean(realistic_agg_cases, axis=0).tolist()
realistic_overall_agg_cases_mean = np.mean(realistic_overall_agg_cases, axis=0).tolist()
realistic_target_agg_cases_mean = np.mean(realistic_target_agg_cases, axis=0).tolist()
realistic_wrong_target_agg_cases_mean = np.mean(realistic_wrong_target_agg_cases, axis=0).tolist()


## Plot agreggated local incidence numbers

In [12]:
# Plot (bar chart cases each day)
from plotly.subplots import make_subplots
fig = go.Figure()
fig = make_subplots(rows=2, cols=2)

# Plot of incidences
fig.add_trace(
    go.Scatter(
        x=times[1:],
        y=realistic_agg_cases_mean[1:],
        name='Cases'
    ),
    row=1,col=1
)

fig.add_trace(
    go.Scatter(
        x=times[1:],
        y=realistic_wrong_target_agg_cases_mean[1:],
        name='Cases with Poorly Targeted NPIs'
    ),row=1,col=2
)

fig.add_trace(
    go.Scatter(
        x=times[1:],
        y=realistic_overall_agg_cases_mean[1:],
        name='Cases with Overall NPIs'
    ),row=2,col=1
)

fig.add_trace(
    go.Scatter(
        x=times[1:],
        y=realistic_target_agg_cases_mean[1:],
        name='Cases with Targeted NPIs'
    ),row=2,col=2
)

# Add axis labels
fig.update_layout(
    boxmode='group',
    width=700,
    height=600,
    plot_bgcolor='white',
    xaxis=dict(
        title='Time (days)',
        linecolor='black'),
    yaxis=dict(
        title='New cases',
        linecolor='black'),
    xaxis2=dict(
        title='Time (days)',
        linecolor='black'),
    yaxis2=dict(
        title='New cases',
        linecolor='black'),
    xaxis3=dict(
        title='Time (days)',
        linecolor='black'),
    yaxis3=dict(
        title='New cases',
        linecolor='black'),
    xaxis4=dict(
        title='Time (days)',
        linecolor='black'),
    yaxis4=dict(
        title='New cases',
        linecolor='black'),
    legend=dict(
        orientation="h",
        yanchor="bottom",
        y=1.05,
        xanchor="right",
        x=1
    ))

fig.write_image('images/Different_realistic_intervention_effects.pdf')
fig.show()

## R_t inference using STAN

In [13]:
# Same inference, but using the PoissonBranchProPosterior
tau = 6
R_t_start = tau+1
a = 1
b = 1/5
sigma = 1.5

# Correction for reproduction number per category
correc_e = np.multiply(np.sum(m.exact_contact_matrix, axis=0), transmissibility)
correc_e = np.asarray(correc_e.tolist()*num_timepoints)

In [14]:
contact_matrix

array([[ 6.96142573,  5.22907914,  1.48784293],
       [ 1.86863597, 10.65987513,  2.38209249],
       [ 0.33377886,  1.29212976,  3.37356853]])

In [15]:
eff_contact_matrix = np.matmul(contact_matrix, np.diag(transmissibility))
eff_contact_matrix

array([[6.96142573, 1.56872374, 0.89270576],
       [1.86863597, 3.19796254, 1.4292555 ],
       [0.33377886, 0.38763893, 2.02414112]])

In [16]:
new_rs[0] * np.sum(eff_contact_matrix, axis=0)

array([1.28293768, 0.72160553, 0.60845433])

In [17]:
# Transform our incidence data into pandas dataframes
multicat_inc_data_matrix = {'Time': np.arange(num_timepoints)}

for _ in range(num_categories):
    multicat_inc_data_matrix['Incidence Number Cat {}'.format(_+1)] = \
        desagg_cases[:, _]

multicat_inc_data = pd.DataFrame(multicat_inc_data_matrix)

# Transform our incidence data into pandas dataframes
inc_data = pd.DataFrame(
    {
        'Time': np.arange(num_timepoints),
        'Incidence Number': agg_cases
    }
)

L1 = len(np.arange(R_t_start, num_timepoints))

ground_truth = []
for j in range(num_categories):
    ground_truth.append(pd.DataFrame({
        'Time Points': np.arange(R_t_start, num_timepoints),
        'R_t': (0.2*correc_e[R_t_start:num_timepoints, j]).tolist()
    }))

ValueError: All arrays must be of the same length

In [None]:
ground_truth[0].values[:,1]

array([0.37372719, 0.06675577, 1.39228515, 0.37372719, 0.06675577,
       1.39228515, 0.37372719, 0.06675577, 1.39228515, 0.37372719,
       0.06675577, 1.39228515, 0.37372719, 0.06675577, 1.39228515,
       0.37372719, 0.06675577, 1.39228515, 0.37372719, 0.06675577,
       1.39228515, 0.37372719, 0.06675577, 1.39228515, 0.37372719,
       0.06675577, 1.39228515, 0.37372719, 0.06675577, 1.39228515,
       0.37372719, 0.06675577, 1.39228515, 0.37372719, 0.06675577,
       1.39228515, 0.37372719, 0.06675577, 1.39228515, 0.37372719,
       0.06675577, 1.39228515, 0.37372719, 0.06675577, 1.39228515,
       0.37372719, 0.06675577, 1.39228515, 0.37372719, 0.06675577,
       1.39228515, 0.37372719, 0.06675577, 1.39228515, 0.37372719,
       0.06675577, 1.39228515, 0.37372719, 0.06675577, 1.39228515,
       0.37372719, 0.06675577, 1.39228515, 0.37372719, 0.06675577,
       1.39228515, 0.37372719, 0.06675577, 1.39228515, 0.37372719,
       0.06675577, 1.39228515, 0.37372719, 0.06675577, 1.39228

In [None]:
ground_truth[1].values[:,1]

array([0.63959251, 0.07752779, 0.31374475, 0.63959251, 0.07752779,
       0.31374475, 0.63959251, 0.07752779, 0.31374475, 0.63959251,
       0.07752779, 0.31374475, 0.63959251, 0.07752779, 0.31374475,
       0.63959251, 0.07752779, 0.31374475, 0.63959251, 0.07752779,
       0.31374475, 0.63959251, 0.07752779, 0.31374475, 0.63959251,
       0.07752779, 0.31374475, 0.63959251, 0.07752779, 0.31374475,
       0.63959251, 0.07752779, 0.31374475, 0.63959251, 0.07752779,
       0.31374475, 0.63959251, 0.07752779, 0.31374475, 0.63959251,
       0.07752779, 0.31374475, 0.63959251, 0.07752779, 0.31374475,
       0.63959251, 0.07752779, 0.31374475, 0.63959251, 0.07752779,
       0.31374475, 0.63959251, 0.07752779, 0.31374475, 0.63959251,
       0.07752779, 0.31374475, 0.63959251, 0.07752779, 0.31374475,
       0.63959251, 0.07752779, 0.31374475, 0.63959251, 0.07752779,
       0.31374475, 0.63959251, 0.07752779, 0.31374475, 0.63959251,
       0.07752779, 0.31374475, 0.63959251, 0.07752779, 0.31374

In [None]:
ground_truth[2].values[:,1]

array([0.2858511 , 0.40482822, 0.17854115, 0.2858511 , 0.40482822,
       0.17854115, 0.2858511 , 0.40482822, 0.17854115, 0.2858511 ,
       0.40482822, 0.17854115, 0.2858511 , 0.40482822, 0.17854115,
       0.2858511 , 0.40482822, 0.17854115, 0.2858511 , 0.40482822,
       0.17854115, 0.2858511 , 0.40482822, 0.17854115, 0.2858511 ,
       0.40482822, 0.17854115, 0.2858511 , 0.40482822, 0.17854115,
       0.2858511 , 0.40482822, 0.17854115, 0.2858511 , 0.40482822,
       0.17854115, 0.2858511 , 0.40482822, 0.17854115, 0.2858511 ,
       0.40482822, 0.17854115, 0.2858511 , 0.40482822, 0.17854115,
       0.2858511 , 0.40482822, 0.17854115, 0.2858511 , 0.40482822,
       0.17854115, 0.2858511 , 0.40482822, 0.17854115, 0.2858511 ,
       0.40482822, 0.17854115, 0.2858511 , 0.40482822, 0.17854115,
       0.2858511 , 0.40482822, 0.17854115, 0.2858511 , 0.40482822,
       0.17854115, 0.2858511 , 0.40482822, 0.17854115, 0.2858511 ,
       0.40482822, 0.17854115, 0.2858511 , 0.40482822, 0.17854