In [1]:
# Import libraries
import numpy as np
import math
import branchpro
import scipy.stats
import matplotlib
import plotly.graph_objects as go
from matplotlib import pyplot as plt
import os
import pandas as pd


In [3]:
# Load the daily case records of one Australian region.
# `region` is the stem of the CSV file inside the covid_australia folder.
region = 'WA'
path = os.path.join(
    "../data_library/covid_australia/",
    "{}.csv".format(region),
)
data = pd.read_csv(path)

In [9]:
# Column of `data` that holds the day index
time_key = 'Time'
data_times = data[time_key]
num_timepoints = max(data_times)

# Re-index the records on every day 1..num_timepoints; days that are absent
# from the file appear as new rows, and every missing value is set to 0
padded_inc_data = (
    data
    .set_index(time_key)
    .reindex(range(1, num_timepoints + 1))
    .fillna(0)
    .reset_index()
)
locally_infected_cases = padded_inc_data['Incidence Number']
imported_cases = padded_inc_data['Imported Cases']

# Day grids: `times` runs 0..num_timepoints (one entry more than the padded
# data), `start_times` runs 1..num_timepoints (one entry per padded row)
times = np.arange(num_timepoints+1)
start_times = np.arange(1, num_timepoints+1, dtype=int)

In [11]:
# Load the serial interval table (the file is read without a header row)
# and replace missing entries with 0
si_file = 'serial_interval'
path = os.path.join(
    "../data_library/covid_australia/",
    "{}.txt".format(si_file),
)
serial_interval = pd.read_csv(path, header=None).fillna(0)

Unnamed: 0,0,1
0,0,0.0
1,2,2.0
2,4,3.0
3,2,0.0
4,0,0.0
5,5,0.0
6,3,0.0
7,2,0.0
8,10,0.0


In [None]:
# Bar chart of the daily local and imported case counts
fig = go.Figure()

# The padded series hold one value per day 1..num_timepoints, so they are
# plotted against `start_times`. `times` starts at 0 and has one extra
# entry, which would shift every bar one day early.
fig.add_trace(
    go.Bar(
        x=start_times,
        y=locally_infected_cases,
        name='Local Incidences'
    )
)

fig.add_trace(
    go.Bar(
        x=start_times,
        y=imported_cases,
        name='Imported Incidences'
    )
)

# Add axis labels
fig.update_layout(
    xaxis_title='Time (days)',
    yaxis_title='New cases'
)

fig.show()

In [None]:
# Inference of R_t with branchpro.LocImpBranchProPosteriorMultSI, which
# takes the local and the imported incidences as separate inputs.

# `tau` is forwarded to run_inference; `R_t_start` is used when plotting to
# skip the first `tau` days of `start_times`.
tau = 2
R_t_start = tau+1

# Prior hyper-parameters, passed to the posterior as alpha=a and beta=1/b
a = 1
b = 0.8

# NOTE(review): `epsilon` was passed to the posterior without being defined
# anywhere in the notebook, which raises NameError on a fresh kernel.
# The value below is a placeholder -- confirm the intended value (and its
# exact meaning) against the branchpro documentation.
epsilon = 1

# Transform our incidence data into pandas dataframes.
# Both frames are indexed by `start_times` (days 1..num_timepoints), which
# has one entry per padded row. Using `times` (0..num_timepoints) here makes
# the DataFrame constructor fail because of the length mismatch.
inc_data = pd.DataFrame(
    {
        'Time': start_times,
        'Incidence Number': locally_infected_cases
    }
)

imported_inc_data = pd.DataFrame(
    {
        'Time': start_times,
        'Incidence Number': imported_cases
    }
)

inference = branchpro.LocImpBranchProPosteriorMultSI(
    inc_data=inc_data,
    imported_inc_data=imported_inc_data,
    epsilon=epsilon,
    daily_serial_intervals=serial_interval,
    alpha=a,
    beta=1/b)

inference.run_inference(tau=tau)
intervals = inference.get_intervals(central_prob=.95)

In [None]:
# Plot the inferred R_t (mean and credible bounds) against time
fig = plt.figure()

# The curves are drawn from day `R_t_start` onwards, i.e. the first
# R_t_start-1 entries of `start_times` are skipped.
plt.plot(start_times[R_t_start-1:], intervals['Mean'], 'k')
plt.plot(start_times[R_t_start-1:], intervals['Lower bound CI'], 'gray')
plt.plot(start_times[R_t_start-1:], intervals['Upper bound CI'], 'gray')

# Reference line at R_t = 1. np.ones takes a *shape*, so it must be given
# the number of days, not the array of days itself.
plt.plot(start_times, np.ones(len(start_times)), 'red')

plt.xlabel('Time (Day)')
plt.ylabel('R_t')
plt.title('Inferred effective reproduction number vs. time')