# Stability Criterion

In [1]:
# Import libraries
import math
import os

import arviz as az
import branchpro
import matplotlib.pyplot as plt
import nest_asyncio
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import scipy.integrate
import scipy.optimize
import scipy.stats
import seaborn as sns
from branchpro.apps import ReproductionNumberPlot
from cmdstanpy import CmdStanModel, cmdstan_path

nest_asyncio.apply()

# Length of the simulated incidence series, in days
num_timepoints = 1000
# Number of categories the population is split into
num_categories = 3

# Legend label and line colour of each category, in category order
trace_names = [
    '0-20 years old',
    '20-65 years old',
    '65+ years old',
]
colour_names = [
    '#EC9F0A',
    '#D41159',
    '#1A85FF',
]

  from .autonotebook import tqdm as notebook_tqdm


## Parametrisation of the renewal model

In [2]:
# Serial intervals w_s: seven regimes, each holding one (mean, std) pair
# per population category. Every regime uses the same values here.
ws_mean_range = [[15.3, 3.5, 7] for _ in range(7)]
ws_std_range = [[5.3, 2.5, 3] for _ in range(7)]

serial_intervals_range = []
for regime_means, regime_stds in zip(ws_mean_range, ws_std_range):
    regime_intervals = []
    for mean, std in zip(regime_means, regime_stds):
        # Gamma distribution matching the requested mean and standard
        # deviation: scale = std^2 / mean, shape = mean / scale
        scale = std**2 / mean
        shape = mean / scale

        # Discretise by evaluating the density on the days 0, 1, ..., 249
        regime_intervals.append(
            scipy.stats.gamma(shape, scale=scale).pdf(np.arange(250)))

    # One (category, day) array per regime
    serial_intervals_range.append(np.array(regime_intervals))

# Initial r handed to each model, one value per regime
initial_r_range = [0.113, 0.10456, 0.09, 0.044, 0.07, 1/9.7262, 0.122]

# Every regime shares the same contact matrix, read from the POLYMOD
# data library (the file has no header row)
path = os.path.join(
    '../../data_library/polymod/final_contact_matrices/', 'BASE.csv')
contact_matrix = pd.read_csv(path, header=None)

contact_matrix_range = [contact_matrix] * 7

# Unit transmissibility for every category, so that the effective contact
# matrix equals the contact matrix
transmissibility = [1, 1, 1]

# Growth rate of each regime; the value is repeated once per entry of
# start_times, so the r profile does not change at day 20
new_rs_range = [
    [rate, rate]
    for rate in (0.113, 0.10456, 0.09, 0.044, 0.082, 1/9.7262, 0.122)
]

# Times at which the entries of each new_rs row come into force
start_times = [0, 20]

# Initial number of cases in each category
parameters = [50, 50, 50]

## Compute Stability Criterion

In [3]:
# Pin the second-to-last regime exactly at the stability threshold: its
# growth rate becomes the reciprocal of the spectral radius of its OWN
# effective contact matrix. This is done once, before iterating, so the
# list is not modified while it is being looped over and the value does not
# depend on which regime happened to be processed last.
critical_matrix = np.matmul(
    contact_matrix_range[-2], np.diag(transmissibility))
critical_radius = np.max(np.absolute(np.linalg.eigvals(critical_matrix)))
new_rs_range[-2] = [1/critical_radius]*2

stability_criterion = []

for (regime_matrix, new_rs) in zip(contact_matrix_range, new_rs_range):
    # Spectral radius (largest eigenvalue modulus) of the effective
    # contact matrix of this regime
    eff_contact_matrix = np.matmul(regime_matrix, np.diag(transmissibility))
    spec_radius = np.max(np.absolute(np.linalg.eigvals(eff_contact_matrix)))

    # Stability criterion, using the last change in growth rate
    stability_criterion.append(new_rs[-1] * spec_radius)

print(stability_criterion)

[1.0990605133144844, 1.0169713917890486, 0.875357930958439, 0.42795276624634787, 0.7975483370954666, 1.0, 1.1865963064103282]


## Deterministic Model

In [4]:
# Define Deterministic Model

class MultiCatDeterministicBranchProModel(branchpro.BranchProModel):
    r"""MultiCatDeterministicBranchProModel Class:
    Class for the models following a Branching Processes behaviour with
    no distribution noise and multiple population categories.
    It inherits from the ``BranchProModel`` class.

    In the branching process model, we track the number of cases
    registered for each category and each day, I_{t,i}, also known as the
    "incidence" at time t.

    The incidence at time t is computed, without sampling noise, from the
    previous number of cases according to the following formula:

    .. math::
        I_{t, i}^{\text{(local)}} =
            R_{t}\sum_{j}C_{i,j}T_{j}\sum_{s=1}^{t}I_{t-s,j}w_{s,j}

    where the serial interval of each category is divided by its sum and
    the inner sum is truncated to the length of the serial interval.

    Always apply method :meth:`set_r_profile` before calling
    :meth:`simulate` for a change of R_t profile!

    Parameters
    ----------
    initial_r
        (int or float) Reproduction number at the beginning of the epidemic.
    serial_interval
        (list) Unnormalised probability distribution of that the recipient
        first displays symptoms s days after the infector first displays
        symptoms. One sequence shared by all categories, or one row per
        category when ``multipleSI`` is true.
    num_cat
        (int) Number of categories in which the population is split.
    contact_matrix
        (array) Matrix of contacts between the different categories in which
        the population is split.
    transm
        (list) List of overall reductions in transmissibility per category.
    multipleSI
        (boolean) Different serial intervals used for categories.

    Raises
    ------
    TypeError
        If a value that must be a real number (or ``num_cat``, which must be
        an ``int``) has the wrong type.
    ValueError
        If an input has the wrong dimensions or contains negative values.

    """

    # Scalar types accepted as real numbers. NumPy scalars are listed
    # explicitly because NumPy integers are not instances of ``int``, so
    # integer-valued arrays would otherwise be rejected.
    _REAL_TYPES = (int, float, np.integer, np.floating)

    def __init__(self, initial_r, serial_interval, num_cat, contact_matrix,
                 transm, multipleSI=False):
        # NOTE: the parent initialiser is not invoked; every attribute read
        # by the methods of this class is set below.
        if not isinstance(initial_r, self._REAL_TYPES):
            raise TypeError('Value of R must be integer or float.')

        if not isinstance(num_cat, int):
            raise TypeError('Number of population categories must be integer.')
        if num_cat <= 0:
            raise ValueError('Number of population categories must be > 0.')

        checked_matrix = self._checked_contact_matrix(contact_matrix, num_cat)

        transm_values = np.asarray(transm)
        if transm_values.ndim != 1:
            raise ValueError(
                'Transmissiblity storage format must be 1-dimensional')
        if transm_values.shape[0] != num_cat:
            raise ValueError(
                'Wrong number of categories in transmissibility storage')
        for value in transm_values:
            if not isinstance(value, self._REAL_TYPES):
                raise TypeError(
                    'Transmissiblity values must be integer or float.')
            if value < 0:
                raise ValueError('Transmissiblity values must be >= 0.')

        self._num_cat = num_cat
        self._contact_matrix = checked_matrix
        self._transm = transm_values

        # Serial intervals are stored in reversed (most recent day last)
        # order for ease in _effective_no_infectives
        self._serial_interval = self._reversed_serial_intervals(
            serial_interval, multipleSI)

        self._r_profile = np.array([initial_r])
        self._normalizing_const = np.sum(self._serial_interval, axis=1)

    def _checked_contact_matrix(self, contact_matrix, num_cat):
        """
        Validates a contact matrix and returns it as a numpy array.

        Parameters
        ----------
        contact_matrix
            Candidate matrix of contacts between categories.
        num_cat
            Number of rows and columns the matrix must have.

        """
        matrix = np.asarray(contact_matrix)
        if matrix.ndim != 2:
            raise ValueError(
                'Contact matrix values storage format must be 2-dimensional')
        if matrix.shape[0] != num_cat:
            raise ValueError(
                'Wrong number of rows in contact matrix values storage')
        if matrix.shape[1] != num_cat:
            raise ValueError(
                'Wrong number of columns in contact matrix values storage')
        for row in matrix:
            for value in row:
                # Type first, so that non-numeric entries are reported with
                # the intended message rather than a comparison error
                if not isinstance(value, self._REAL_TYPES):
                    raise TypeError(
                        'Contact matrix values must be integer or float.')
                if value < 0:
                    raise ValueError('Contact matrix values must be >= 0.')
        return matrix

    def _reversed_serial_intervals(self, serial_interval, multipleSI):
        """
        Validates serial intervals and returns them as a 2-dimensional array
        with one row per category and the day order reversed.

        Parameters
        ----------
        serial_interval
            One sequence (shared by all categories) or one row per category.
        multipleSI
            (boolean) Different serial intervals used for categories.

        """
        # Convert once up front so that nested lists can be sliced too
        values = np.asarray(serial_interval)

        if multipleSI is False:
            if values.ndim != 1:
                raise ValueError(
                    'Serial interval values storage format must be '
                    '1-dimensional')
            if np.sum(values) < 0:
                raise ValueError('Sum of serial interval values must be >= 0.')
            # Same reversed serial interval repeated for every category
            return np.tile(values[::-1], (self._num_cat, 1))

        if values.ndim != 2:
            raise ValueError(
                'Serial interval values storage format must be '
                '2-dimensional')
        if values.shape[0] != self._num_cat:
            raise ValueError(
                'Serial interval values storage format must match '
                'number of categories')
        for cat in range(self._num_cat):
            if np.sum(values[cat, :]) < 0:
                raise ValueError(
                    'Sum of serial interval values must be >= 0.')
        return values[:, ::-1]

    def set_transmissibility(self, contact_matrix):
        """
        Updates contact matrix for the model.

        NOTE: despite its name, this method replaces the contact matrix; the
        transmissibility vector is left untouched.

        Parameters
        ----------
        contact_matrix
            New matrix of contacts between the different categories in which
            the population is split.

        """
        self._contact_matrix = self._checked_contact_matrix(
            contact_matrix, self._num_cat)

    def get_transmissibility(self):
        """
        Returns transmissibility vector for the model.

        """
        return self._transm

    def get_contact_matrix(self):
        """
        Returns contact matrix for the model.

        """
        return self._contact_matrix

    def get_serial_intervals(self):
        """
        Returns serial intervals for the model, one row per category.

        """
        # Reverse inverting of order of serial intervals
        return self._serial_interval[:, ::-1]

    def set_serial_intervals(self, serial_intervals, multipleSI=False):
        """
        Updates serial intervals for the model.

        Parameters
        ----------
        serial_intervals
            New unnormalised probability distribution of that the recipient
            first displays symptoms s days after the infector first displays
            symptoms for each category.
        multipleSI
            (boolean) Different serial intervals used for categories.

        """
        self._serial_interval = self._reversed_serial_intervals(
            serial_intervals, multipleSI)

        self._normalizing_const = np.sum(self._serial_interval, axis=1)

    def _effective_no_infectives(self, t, incidences, contact_matrix):
        """
        Computes expected number of new cases at time t, using previous
        incidences and serial intervals at a rate of 1:1 reproduction.

        Parameters
        ----------
        t
            evaluation time
        incidences
            sequence of incidence numbers
        contact_matrix
            matrix of contacts between categories
        """
        # Only the most recent days covered by the serial interval contribute
        window = min(t, self._serial_interval.shape[1])

        # Normalised infection pressure generated by each source category.
        # It does not depend on the recipient category, so it is computed
        # once per source rather than once per (recipient, source) pair.
        pressure = [
            math.fsum(np.multiply(
                self._serial_interval[j, -window:],
                incidences[t - window:t, j]
                )) / self._normalizing_const[j]
            for j in range(self._num_cat)]

        mean = np.zeros(self._num_cat)

        for i in range(self._num_cat):
            for j in range(self._num_cat):
                mean[i] += pressure[j] * (
                    contact_matrix[i, j] * self._transm[j])

        return mean

    def simulate(
            self, parameters, times, var_contacts=False, neg_binom=False,
            niu=0.1):
        """
        Runs a forward simulation with the given ``parameters`` and returns a
        time-series with incidence numbers per population category
        corresponding to the given ``times``.

        The contact matrices used are recorded in the attribute
        ``exact_contact_matrix``.

        Parameters
        ----------
        parameters
            Initial number of cases per population category.
        times
            The times at which to evaluate. Must be an ordered sequence,
            without duplicates, and without negative values.
            All simulations are started at time 0, regardless of whether this
            value appears in ``times``.
        var_contacts
            (boolean) Whether there exists noise in number of contacts.
        neg_binom
            (boolean) Whether the noise in number of contacts is Negative
            Binomial distributed.
        niu
            (float) Acceptance probability.

        Returns
        -------
        Array with one row per requested time and one column per category.

        """
        initial_cond = parameters
        last_time_point = np.max(times)

        # Repeat final r if necessary
        # (r_1, r_2, ..., r_t)
        if len(self._r_profile) < last_time_point:
            missing_days = last_time_point - len(self._r_profile)
            last_r = self._r_profile[-1]
            repeated_r = np.full(shape=missing_days, fill_value=last_r)
            self._r_profile = np.append(self._r_profile, repeated_r)

        incidences = np.empty(shape=(last_time_point + 1, self._num_cat))
        incidences[0, :] = initial_cond

        # Construct simulation times in steps of 1 unit time each
        simulation_times = np.arange(start=1, stop=last_time_point+1, step=1)

        # The record always starts with one Poisson draw of the contact
        # matrix, even when the contacts are not varied; this is kept so
        # that the random number stream is consumed as before.
        self.exact_contact_matrix = [np.random.poisson(self._contact_matrix)]

        for t in simulation_times:
            if var_contacts is False:
                contact_matrix = self._contact_matrix
            else:
                if neg_binom is False:
                    contact_matrix = np.random.poisson(self._contact_matrix)
                else:
                    contact_matrix = np.random.negative_binomial(
                        self._contact_matrix, niu)
                self.exact_contact_matrix.append(contact_matrix)

            # No sampling: the incidence is set to its expected value
            norm_daily_mean = self._r_profile[t-1] * \
                self._effective_no_infectives(t, incidences, contact_matrix)
            incidences[t, :] = norm_daily_mean

        # Keep only the rows of the requested times (time 0 included)
        mask = np.isin(np.append(np.asarray(0), simulation_times), times)
        return incidences[mask]

In [5]:
# Time grid shared by every deterministic run
times = np.arange(num_timepoints)

deterministic_cases = []

for initial_r, contact_matrix, new_rs, serial_intervals in zip(
        initial_r_range, contact_matrix_range, new_rs_range,
        serial_intervals_range):
    # Deterministic model of this regime, one serial interval per category
    model = MultiCatDeterministicBranchProModel(
        initial_r, serial_intervals, num_categories, contact_matrix,
        transmissibility, multipleSI=True)

    # Install the r profile of the regime before simulating
    model.set_r_profile(new_rs, start_times)

    # Incidence per category, with a fixed contact matrix
    desagg_cases = model.simulate(parameters, times, var_contacts=False)
    deterministic_cases.append(desagg_cases)

### Plot disaggregated local incidence numbers

In [6]:
# One line chart per regime: deterministic incidence of each category
for regime, desagg_cases in enumerate(deterministic_cases):
    criterion = stability_criterion[regime]

    # A criterion below 1 is labelled as decay, anything else as growth
    fate = '(Epidemic Decays)' if criterion < 1 else '(Epidemic Grows)'

    fig = go.Figure()

    for cat in range(num_categories):
        # Incidence curve of this category
        category_trace = go.Scatter(
            x=times,
            y=desagg_cases[:, cat],
            name=trace_names[cat],
            line_color=colour_names[cat],
        )
        fig.add_trace(category_trace)

    # Size, title and axis styling
    fig.update_layout(
        title='Stability criterion: {:.2f} '.format(criterion) + fate,
        width=600,
        height=400,
        plot_bgcolor='white',
        xaxis=dict(linecolor='black'),
        xaxis_title='Time (days)',
        yaxis=dict(linecolor='black'),
        yaxis_title='New infections',
    )

    # Files are numbered from 1
    fig.write_image('images/Det_SC_{}.pdf'.format(regime + 1))
    fig.show()

## Stochastic Model

In [7]:
# Time grid shared by every stochastic run
times = np.arange(num_timepoints)

stochastic_cases = []

for initial_r, contact_matrix, new_rs, serial_intervals in zip(
        initial_r_range, contact_matrix_range, new_rs_range,
        serial_intervals_range):
    # Stochastic model with Poisson noise, one serial interval per category
    model = branchpro.MultiCatPoissonBranchProModel(
        initial_r, serial_intervals, num_categories, contact_matrix,
        transmissibility, multipleSI=True)

    # Install the r profile of the regime before simulating
    model.set_r_profile(new_rs, start_times)

    # Incidence per category, with a fixed contact matrix
    desagg_cases = model.simulate(parameters, times, var_contacts=False)
    stochastic_cases.append(desagg_cases)

### Plot disaggregated local incidence numbers

In [8]:
# One line chart per regime: stochastic incidence of each category
for regime, desagg_cases in enumerate(stochastic_cases):
    criterion = stability_criterion[regime]

    # A criterion below 1 is labelled as decay, anything else as growth
    fate = '(Epidemic Decays)' if criterion < 1 else '(Epidemic Grows)'

    fig = go.Figure()

    for cat in range(num_categories):
        # Incidence curve of this category
        category_trace = go.Scatter(
            x=times,
            y=desagg_cases[:, cat],
            name=trace_names[cat],
            line_color=colour_names[cat],
        )
        fig.add_trace(category_trace)

    # Size, title and axis styling
    fig.update_layout(
        title='Stability criterion: {:.2f} '.format(criterion) + fate,
        width=600,
        height=400,
        plot_bgcolor='white',
        xaxis=dict(linecolor='black'),
        xaxis_title='Time (days)',
        yaxis=dict(linecolor='black'),
        yaxis_title='New infections',
    )

    # Files are numbered from 1
    fig.write_image('images/Stoc_SC_{}.pdf'.format(regime + 1))

    fig.show()

In [9]:
from plotly.subplots import make_subplots

# One subplot row per regime shown: (regime index, number of days plotted)
panels = [(4, 500), (5, 500), (6, 250)]

# Row titles come from the last three stability criteria; each title is
# used for both the deterministic and the stochastic column
titles = ['R<sub>t</sub> = {:.1f}'.format(s) for s in stability_criterion]
titles = titles[-3:]

fig = make_subplots(
    rows=len(titles),
    cols=2,
    subplot_titles=tuple(titles[i//2] for i in range(len(titles)*2))
    )

for cat in range(num_categories):
    for row, (regime, horizon) in enumerate(panels, start=1):
        # Column 1 is the deterministic run, column 2 the stochastic one
        for col, cases in ((1, deterministic_cases), (2, stochastic_cases)):
            # Only the deterministic trace of the middle row keeps its
            # legend entry, so each category appears in the legend once
            keeps_legend = (row == 2 and col == 1)

            fig.add_trace(
                go.Scatter(
                    # x is cut to the same window as y so that both
                    # sequences always have the same length
                    x=times[:horizon],
                    y=cases[regime][:horizon, cat],
                    name=trace_names[cat],
                    line_color=colour_names[cat],
                    showlegend=None if keeps_legend else False
                ),
                row=row,
                col=col
            )

# Same axis styling and titles for all six subplots
axes_layout = {}
for suffix in ('', '2', '3', '4', '5', '6'):
    axes_layout['xaxis' + suffix] = dict(
        linecolor='black',
        title='Time (days)')
    axes_layout['yaxis' + suffix] = dict(
        linecolor='black',
        title='New infections')

# The last y-axis title additionally sets an explicit standoff
axes_layout['yaxis6']['title'] = dict(
    text='New infections',
    standoff=0.9)

fig.update_layout(
    width=900,
    height=800,
    plot_bgcolor='white',
    **axes_layout
)

fig.write_image('images/Stability.pdf')

fig.show()

In [10]:
# Continuous serial intervals w_s: one regime per row, one (mean, std)
# pair per population category
ws_mean_range = [
    [5.6, 15.3, 25],
    [15.3, 15.3, 15.3],
    [25, 15.3, 15],
]
ws_std_range = [
    [3.3, 3.3, 3.3],
    [3.3, 3.3, 3.3],
    [3.3, 3.3, 3.3],
]

# Frozen gamma distributions matching each mean and standard deviation:
# scale = std^2 / mean, shape = mean / scale
serial_intervals_range = [
    [
        scipy.stats.gamma(mean / (std**2 / mean), scale=std**2 / mean)
        for mean, std in zip(regime_means, regime_stds)
    ]
    for regime_means, regime_stds in zip(ws_mean_range, ws_std_range)
]

# Initial r handed to each model, one value per regime
initial_r_range = [0.113, 0.113, 0.113]

# The same contact matrix, read from the POLYMOD data library (the file
# has no header row), is repeated for the regimes
path = os.path.join(
    '../../data_library/polymod/final_contact_matrices/', 'BASE.csv')
contact_matrix = pd.read_csv(path, header=None)

contact_matrix_range = [contact_matrix] * 4

# Unit transmissibility for every category, so that the effective contact
# matrix equals the contact matrix
transmissibility = [1, 1, 1]

# A single growth rate per regime, in force from the start
new_rs_range = [[0.113] for _ in range(3)]

start_times = [0]

# Initial number of cases in each category
parameters = [50, 50, 50]

In [11]:
def compute_Rt_from_growth_rate(r, serial_intervals, contact_matrix, transmissibility):
    """
    Returns the ratio between the spectral radius of the effective contact
    matrix and the spectral radius of that matrix with each column scaled by
    the discounted mass of the corresponding serial interval.

    Parameters
    ----------
    r
        Growth rate used in the exponential discount ``exp(-r x)``.
    serial_intervals
        One distribution per category; each must expose a ``pdf`` method.
    contact_matrix
        Matrix of contacts between categories.
    transmissibility
        Transmissibility of each category.

    """
    def spectral_radius(matrix):
        # Largest eigenvalue modulus
        return np.max(np.absolute(np.linalg.eigvals(matrix)))

    def discounted_mass(si):
        # Integral of pdf(x) * exp(-r x) over [0, 1000]
        value, _abs_error = scipy.integrate.quad(
            lambda x: si.pdf(x) * np.exp(-r * x), 0, 1000)
        return value

    eff_contact_matrix = np.matmul(contact_matrix, np.diag(transmissibility))
    discount = np.diagflat([discounted_mass(si) for si in serial_intervals])

    undiscounted = spectral_radius(eff_contact_matrix)
    discounted = spectral_radius(np.matmul(eff_contact_matrix, discount))
    return undiscounted / discounted

In [12]:
# For every serial-interval regime, evaluate compute_Rt_from_growth_rate
# on a grid of growth rates from -0.08 (inclusive) to 0.08 (exclusive)
growth_rate_grid = np.arange(-0.08, 0.08, 0.001)

trajectories = [
    [
        compute_Rt_from_growth_rate(
            r, regime_intervals, contact_matrix_range[regime],
            transmissibility)
        for r in growth_rate_grid
    ]
    for regime, regime_intervals in enumerate(serial_intervals_range)
]

In [13]:
# Line chart of growth rate (y) against reproduction number (x), one curve
# per serial-interval regime
fig = go.Figure()

# Extend the shared colour list in place and relabel the traces by regime
colour_names.append('green')
trace_names = [
    'Shorter generation interval for 0-20 year-olds',
    'Same generation intervals',
    'Shorter generation interval for 65+ year-olds',
]

growth_rate_axis = np.arange(-0.08, 0.08, 0.001)

for regime in range(len(serial_intervals_range)):
    regime_curve = go.Scatter(
        x=trajectories[regime],
        y=growth_rate_axis,
        name=trace_names[regime],
        line_color=colour_names[regime],
    )
    fig.add_trace(regime_curve)

# Dotted guides at R_t = 1 and r = 0
fig.add_vline(
    x=1,
    line_dash='dot',
    annotation_text='R<sub>t</sub>=1',
    fillcolor='black',
    annotation_position='top right',
)
fig.add_hline(
    y=0,
    line_dash='dot',
    annotation_text='r=0',
    fillcolor='black',
    annotation_position='top right',
)

# Size, axis styling and a boxed legend
legend_box = dict(
    yanchor="bottom",
    y=0.13,
    xanchor="right",
    x=1.01,
    bordercolor="Black",
    borderwidth=1,
)
fig.update_layout(
    width=800,
    height=400,
    plot_bgcolor='white',
    yaxis=dict(linecolor='black'),
    yaxis_title='Growth rate r (per day)',
    xaxis=dict(linecolor='black'),
    xaxis_title='Reproduction number R<sub>t</sub>',
    legend=legend_box,
)

fig.write_image('images/R_vs_growth_rate.pdf')

fig.show()

In [14]:
# Serial intervals w_s, kept both as frozen gamma distributions and as
# densities discretised on the days 0, 1, ..., 249. One regime per row,
# one (mean, std) pair per population category.
ws_mean_range = [
    [3.3, 5, 7],
    [15.3, 15.3, 15.3],
    [30, 25.3, 15],
]
ws_std_range = [
    [3.3, 3.3, 3.3],
    [3.3, 3.3, 3.3],
    [3.3, 3.3, 3.3],
]

serial_intervals_range = []
disc_serial_intervals_range = []

for regime_means, regime_stds in zip(ws_mean_range, ws_std_range):
    # Gamma distributions matching each mean and standard deviation:
    # scale = std^2 / mean, shape = mean / scale
    regime_dists = [
        scipy.stats.gamma(mean / (std**2 / mean), scale=std**2 / mean)
        for mean, std in zip(regime_means, regime_stds)
    ]
    serial_intervals_range.append(regime_dists)

    # One (category, day) array per regime
    disc_serial_intervals_range.append(
        np.array([dist.pdf(np.arange(250)) for dist in regime_dists]))

# Growth rates of the three regimes; each is used both as the initial r
# and as the single entry of the r profile
regime_growth_rates = (0.0825, 1/9.7262, 0.123)

initial_r_range = list(regime_growth_rates)

# The same contact matrix, read from the POLYMOD data library (the file
# has no header row), is repeated for the regimes
path = os.path.join(
    '../../data_library/polymod/final_contact_matrices/', 'BASE.csv')
contact_matrix = pd.read_csv(path, header=None)

contact_matrix_range = [contact_matrix] * 4

# Unit transmissibility for every category, so that the effective contact
# matrix equals the contact matrix
transmissibility = [1, 1, 1]

# A single growth rate per regime, in force from the start
new_rs_range = [[rate] for rate in regime_growth_rates]

start_times = [0]

# Initial number of cases in each category
parameters = [20, 20, 20]

In [15]:
# deterministic_cases[regime][si] is the total incidence over all
# categories for one growth-rate regime and one serial-interval regime
times = np.arange(1000)
deterministic_cases = []

for initial_r, contact_matrix, new_rs in zip(
        initial_r_range, contact_matrix_range, new_rs_range):
    totals_per_si = []

    for serial_intervals in disc_serial_intervals_range:
        # Deterministic model with one serial interval per category
        model = MultiCatDeterministicBranchProModel(
            initial_r, serial_intervals, num_categories, contact_matrix,
            transmissibility, multipleSI=True)

        # Install the r profile of the regime before simulating
        model.set_r_profile(new_rs, start_times)

        per_category = model.simulate(parameters, times, var_contacts=False)

        # Collapse the category axis into a single daily total
        totals_per_si.append(np.sum(per_category, axis=1))

    deterministic_cases.append(totals_per_si)

In [16]:
def compute_Rt(gamma, contact_matrix, transmissibility):
    """
    Returns ``gamma`` multiplied by the spectral radius (largest eigenvalue
    modulus) of the effective contact matrix, i.e. the contact matrix with
    each column scaled by the transmissibility of that category.

    Parameters
    ----------
    gamma
        Scalar factor applied to the spectral radius.
    contact_matrix
        Matrix of contacts between categories.
    transmissibility
        Transmissibility of each category.

    """
    effective = np.matmul(contact_matrix, np.diag(transmissibility))
    eigenvalue_moduli = np.absolute(np.linalg.eigvals(effective))
    return gamma * np.max(eigenvalue_moduli)

In [17]:
# Stability criterion of each regime, from the last growth rate of its
# r profile and its contact matrix
stability_criterion = [
    compute_Rt(
        new_rs_range[regime][-1], contact_matrix_range[regime],
        transmissibility)
    for regime in range(len(initial_r_range))
]

In [18]:
from plotly.subplots import make_subplots

# One subplot title per growth-rate regime; the fourth subplot is untitled
titles = ['R<sub>t</sub> = {:.1f}'.format(sc) for sc in stability_criterion]

# Traces are labelled by the mean, over categories, of the serial-interval
# means of their regime
trace_names = ['{:.1f}'.format(si_mean) for si_mean in np.mean(ws_mean_range, axis=1)]

# 2x2 grid of line charts
fig = make_subplots(
    rows=2,
    cols=2,
    subplot_titles=tuple(titles + ['']))

# (growth-rate regime, row, column, first day shown, showlegend); only the
# traces of the last panel keep their legend entry
panels = (
    (0, 1, 1, 1, False),
    (1, 1, 2, 1, False),
    (2, 2, 1, 10, None),
)

for si in range(len(disc_serial_intervals_range)):
    for regime, row, col, first_day, showlegend in panels:
        fig.add_trace(
            go.Scatter(
                x=times[first_day:500],
                y=deterministic_cases[regime][si][first_day:500],
                name=trace_names[si],
                line_color=colour_names[si],
                showlegend=showlegend
            ),
            row=row,
            col=col
        )

# The third panel uses a logarithmic y axis
fig.update_yaxes(type='log', row=2, col=1)

# Same axis styling and titles for the three panels in use
axes_layout = {}
for suffix in ('', '2', '3'):
    axes_layout['xaxis' + suffix] = dict(
        linecolor='black',
        title='Time (days)')
    axes_layout['yaxis' + suffix] = dict(
        linecolor='black',
        title='New infections')

# Tick labels of the logarithmic axis are written as powers
axes_layout['yaxis3']['exponentformat'] = 'power'

fig.update_layout(
    width=800,
    height=600,
    plot_bgcolor='white',
    legend=dict(
        title='Generation interval mean',
        yanchor="bottom",
        y=0.17,
        xanchor="right",
        x=0.9,
        bordercolor="Black",
        borderwidth=1),
    **axes_layout
)

fig.write_image('images/Different-SI-regimes.pdf')

fig.show()

In [19]:
# Fit log(I_t / I_499) = log(1 + r) * (t - 499) to the second half of each
# deterministic trajectory; growth_rates[regime][si] is the fitted r
elapsed = times[500:] - times[499]

growth_rates = []

for regime in range(len(initial_r_range)):
    fitted_rates = []

    for si in range(len(serial_intervals_range)):
        cases = deterministic_cases[regime][si]
        log_ratio = np.log(cases[500:]) - np.log(cases[499])

        # Only the fitted parameter is kept; the covariance is discarded
        fitted, _covariance = scipy.optimize.curve_fit(
            f=lambda t, r: np.log(1+r) * t,
            xdata=elapsed,
            ydata=log_ratio,
            maxfev=1000)
        fitted_rates.append(fitted[0])

    growth_rates.append(fitted_rates)


Covariance of the parameters could not be estimated



In [30]:
# For each growth-rate regime and each serial-interval mean on a grid,
# search for the growth rate at which compute_Rt_from_growth_rate matches
# the stability criterion of the regime
theoretical_growth_rates = []

for regime in range(len(initial_r_range)):
    roots_per_Rt = []

    for ws_mean in np.linspace(2, 35, 200):
        ws_std = 3.3

        # Every category gets a gamma serial interval with this mean and
        # standard deviation
        theta = ws_std**2 / ws_mean
        k = ws_mean / theta
        theoretical_serial_intervals = [
            scipy.stats.gamma(k, scale=theta) for _ in range(num_categories)
        ]

        def distance_to_target(r):
            # Absolute gap between the implied and the target value
            implied = compute_Rt_from_growth_rate(
                r, theoretical_serial_intervals,
                contact_matrix_range[regime], transmissibility)
            return np.abs(implied - stability_criterion[regime])

        # Bounded search for r in [-0.1, 0.1], started at r = 0
        root = scipy.optimize.minimize(
            fun=distance_to_target,
            x0=0, bounds=[(-0.1, 0.1)],
            tol=0.0001)

        roots_per_Rt.append(root.x)

    theoretical_growth_rates.append(roots_per_Rt)

In [32]:
from plotly.subplots import make_subplots

# Legend labels: stability criterion of each scenario, to 1 decimal place
titles = ['{:.1f}'.format(sc) for sc in stability_criterion]

# x-coordinates shared by every scenario: the first-category means of the
# simulated regimes, and the grid of means used for the theoretical curves
# (its first point is dropped from the plot)
simulated_means = np.asarray(ws_mean_range)[:, 0].tolist()
theoretical_means = np.linspace(2, 35, 200)[1:].tolist()

# Growth rate against generation interval mean
fig = go.Figure()

for idx, title in enumerate(titles):
    colour = colour_names[idx]

    # Markers: growth rates fitted to the simulated incidence curves
    fig.add_trace(
        go.Scatter(
            x=simulated_means,
            y=growth_rates[idx],
            mode='markers',
            name=title,
            line_color=colour,
            showlegend=False
        ),
    )

    # Line: theoretical growth rates; this trace carries the legend entry
    theoretical_curve = np.asarray(
        theoretical_growth_rates[idx]).reshape(200)[1:].tolist()
    fig.add_trace(
        go.Scatter(
            x=theoretical_means,
            y=theoretical_curve,
            name=title,
            mode='lines',
            line_color=colour
        ),
    )

# Axis labels, figure size and legend placement
black_axis = dict(linecolor='black')
legend_box = dict(
    title='R<sub>t</sub>',
    yanchor="bottom",
    y=0.72,
    xanchor="right",
    x=1.01,
    bordercolor="Black",
    borderwidth=1)

fig.update_layout(
    width=600,
    height=400,
    plot_bgcolor='white',
    yaxis=dict(black_axis, title='Growth rate r (per day)'),
    xaxis=dict(black_axis, title='Generation interval mean (days)'),
    legend=legend_box
)

fig.write_image('images/R-curves.pdf')

fig.show()

In [22]:
# Theoretical growth rate when only the first category's generation interval
# mean varies; the other two categories keep a mean of 15.3.
young_theoretical_growth_rates = []

# Standard deviation common to all gamma generation-interval distributions
ws_std = 3.3

for scenario in range(len(initial_r_range)):
    roots_per_Rt = []

    for ws_mean in np.linspace(2, 28.6, 200):
        ws_mean_ = [ws_mean, 15.3, 15.3]

        # Frozen gamma per category: scale = std^2 / mean, shape = mean / scale
        theoretical_serial_intervals = [
            scipy.stats.gamma(
                ws_mean_[n] / (ws_std**2 / ws_mean_[n]),
                scale=ws_std**2 / ws_mean_[n])
            for n in range(num_categories)]

        def rt_mismatch(r):
            """Absolute gap between compute_Rt_from_growth_rate(r, ...) and the
            scenario's stability criterion."""
            implied_value = compute_Rt_from_growth_rate(
                r, theoretical_serial_intervals,
                contact_matrix_range[scenario], transmissibility)
            return np.abs(implied_value - stability_criterion[scenario])

        # Growth rate in [-0.1, 0.1] that minimises the gap, starting from 0
        root = scipy.optimize.minimize(
            fun=rt_mismatch,
            x0=0, bounds=[(-0.1, 0.1)],
            tol=0.0001)

        roots_per_Rt.append(root.x)

    young_theoretical_growth_rates.append(roots_per_Rt)

In [23]:
# Theoretical growth rate when only the last category's generation interval
# mean varies; the first two categories keep a mean of 15.3.
old_theoretical_growth_rates = []

# Standard deviation common to all gamma generation-interval distributions
ws_std = 3.3

for scenario in range(len(initial_r_range)):
    roots_per_Rt = []

    for ws_mean in np.linspace(2, 28.6, 200):
        ws_mean_ = [15.3, 15.3, ws_mean]

        # Frozen gamma per category: scale = std^2 / mean, shape = mean / scale
        theoretical_serial_intervals = [
            scipy.stats.gamma(
                ws_mean_[n] / (ws_std**2 / ws_mean_[n]),
                scale=ws_std**2 / ws_mean_[n])
            for n in range(num_categories)]

        def rt_mismatch(r):
            """Absolute gap between compute_Rt_from_growth_rate(r, ...) and the
            scenario's stability criterion."""
            implied_value = compute_Rt_from_growth_rate(
                r, theoretical_serial_intervals,
                contact_matrix_range[scenario], transmissibility)
            return np.abs(implied_value - stability_criterion[scenario])

        # Growth rate in [-0.1, 0.1] that minimises the gap, starting from 0
        root = scipy.optimize.minimize(
            fun=rt_mismatch,
            x0=0, bounds=[(-0.1, 0.1)],
            tol=0.0001)

        roots_per_Rt.append(root.x)

    old_theoretical_growth_rates.append(roots_per_Rt)

In [24]:
# Legend labels: stability criterion of each scenario, to 2 decimal places
titles = ['{:.2f}'.format(sc) for sc in stability_criterion]

# Theoretical growth rate against the generation interval mean of a single
# age group: the 0-20 year-olds (left panel) and the 65+ year-olds (right
# panel). make_subplots creates the figure itself, so no separate
# go.Figure() call is needed beforehand.
fig = make_subplots(
    rows=1,
    cols=2,
    shared_yaxes=True)

# Grid of generation interval means on which both families of curves were
# computed; built once and shared by every trace
varied_means = np.linspace(2, 28.6, 200).tolist()

for idx, title in enumerate(titles):
    # Left panel: first category's mean varies; carries the legend entry
    fig.add_trace(
        go.Scatter(
            x=varied_means,
            y=np.asarray(young_theoretical_growth_rates[idx]).reshape(200).tolist(),
            name=title,
            mode='lines',
            line_color=colour_names[idx],
        ),
        row=1, col=1
    )

    # Right panel: last category's mean varies; same colour, legend entry
    # suppressed so each scenario is listed only once
    fig.add_trace(
        go.Scatter(
            x=varied_means,
            y=np.asarray(old_theoretical_growth_rates[idx]).reshape(200).tolist(),
            name=title,
            mode='lines',
            line_color=colour_names[idx],
            showlegend=False
        ),
        row=1, col=2
    )

# Axis labels, figure size, margins and legend placement
fig.update_layout(
    width=800,
    height=400,
    plot_bgcolor='white',
    xaxis=dict(
        linecolor='black',
        title='Generation interval mean<br>in the 0-20 year-olds (days)'),
    yaxis=dict(
        linecolor='black',
        title='Growth rate r (per day)'),
    xaxis2=dict(
        linecolor='black',
        title='Generation interval mean<br>in the 65+ year-olds (days)'),
    yaxis2=dict(
        linecolor='black',
        # Keep tick labels on the right panel despite the shared y-axis
        showticklabels=True),
    margin=dict(
        l=20,
        r=20,
        b=20,
        t=20,
        pad=4),
    legend=dict(
            title='R<sub>t</sub>',
            orientation="h",
            yanchor="bottom",
            y=0.95,
            xanchor="right",
            x=0.96,
            bordercolor="Black",
            borderwidth=1)
)

fig.write_image('images/Short-SI-effect.pdf')

fig.show()

## Same growth rate in all groups

In [25]:
# Generation-interval regimes compared in this section. Each row gives the
# mean (respectively the standard deviation) used for the three categories.
short_ws_mean_range = [
    [3.3, 5, 7],
    [5.1, 5.1, 5.1],
    [5.6, 15.3, 25],
    [15.3, 15.3, 15.3]
]
short_ws_std_range = [
    [3.3, 3.3, 3.3],
    [3.3, 3.3, 3.3],
    [3.3, 3.3, 3.3],
    [3.3, 3.3, 3.3]
]


def _discretised_gamma_pdf(ws_mean, ws_std):
    """Gamma density with the given mean and std, evaluated at 0, 1, ..., 249."""
    theta = ws_std**2 / ws_mean  # scale
    k = ws_mean / theta  # shape
    return scipy.stats.gamma(k, scale=theta).pdf(np.arange(250))


# One array per regime, holding one discretised density per category
short_serial_intervals_range = [
    np.array([
        _discretised_gamma_pdf(ws_mean, ws_std)
        for ws_mean, ws_std in zip(ws_mean_cat, ws_std_cat)])
    for ws_mean_cat, ws_std_cat in zip(short_ws_mean_range, short_ws_std_range)
]

# Initial growth rate of each simulated scenario
initial_r_range = [0.123, 0.1028, 0.082]

# Growth-rate profile of each scenario: a single value, equal to the
# scenario's initial growth rate, applied from the first start time
new_rs_range = [[initial_r] for initial_r in initial_r_range]
start_times = [0]

# Every scenario uses the same contact matrix, read from BASE.csv
path = os.path.join('../../data_library/polymod/final_contact_matrices/', 'BASE.csv')
contact_matrix = pd.read_csv(path, header=None)
contact_matrix_range = [contact_matrix]*6

# Use a unit transmissibility vector so that the
# effective contact matrix equals the contact matrix
transmissibility = [1, 1, 1]

# Initial number of cases
parameters = [50, 50, 50]

In [26]:
def _simulate_short_regime(initial_r, serial_intervals, contact_matrix, new_rs):
    """Run the deterministic multi-category model for one scenario/regime pair
    and return the simulated cases."""
    # Initialise the deterministic model
    model = MultiCatDeterministicBranchProModel(
        initial_r, serial_intervals, num_categories, contact_matrix, transmissibility, multipleSI=True)

    # Apply the growth-rate profile, then simulate the incidence
    model.set_r_profile(new_rs, start_times)
    return model.simulate(parameters, times, var_contacts=False)


times = np.arange(1000)

# Outer list: one entry per scenario; inner list: one entry per
# generation-interval regime in short_serial_intervals_range
short_deterministic_cases = [
    [
        _simulate_short_regime(initial_r, serial_intervals, contact_matrix, new_rs)
        for serial_intervals in short_serial_intervals_range
    ]
    for (initial_r, contact_matrix, new_rs) in zip(
        initial_r_range, contact_matrix_range, new_rs_range)
]

In [27]:
def _stability_criterion(contact_matrix, new_rs):
    """Last growth rate in ``new_rs`` times the spectral radius of the
    effective contact matrix."""
    # Effective contact matrix: contact matrix scaled by the transmissibility
    eff_contact_matrix = np.matmul(contact_matrix, np.diag(transmissibility))

    # Spectral radius: largest eigenvalue in absolute value
    spec_radius = np.max(np.absolute(np.linalg.eigvals(eff_contact_matrix)))

    return new_rs[-1] * spec_radius


# One stability criterion per scenario
short_stability_criterion = [
    _stability_criterion(contact_matrix, new_rs)
    for (contact_matrix, new_rs) in zip(contact_matrix_range, new_rs_range)
]

print(short_stability_criterion)

[1.1963225056431999, 0.9998532811391947, 0.7975483370954666]


In [28]:
# Row titles: stability criterion of each scenario; column titles: short or
# long generation intervals; traces: equal or varied means across categories
row_name = ['R<sub>t</sub> = {:.1f}'.format(r) for r in short_stability_criterion]
col_name = ['Short generation<br>time interval', 'Long generation<br>time interval']
trace_names = ['Equal', 'Varied']

# Total incidence over the first 100 days, one subplot per
# (scenario, generation-interval length) pair. make_subplots creates the
# figure itself, so no separate go.Figure() call is needed beforehand.
fig = make_subplots(
    rows=len(row_name),
    cols=2,
    start_cell="bottom-left",
    column_titles=col_name,
    row_titles=row_name
    )

for r in range(len(row_name)):
    for c in range(len(col_name)):
        for t, trace_name in enumerate(trace_names):
            # Regimes are stored as (varied, equal) pairs per column, so the
            # 'Equal' trace (t=0) reads index 2*c+1 and 'Varied' (t=1) 2*c
            regime = 2*c+1-t

            # Sum over categories to get the total new infections per day
            total_cases = np.sum(
                short_deterministic_cases[r][regime][1:100, :], axis=1).tolist()

            fig.add_trace(
                go.Scatter(
                    x=times[1:100],
                    y=total_cases,
                    name=trace_name,
                    line_color=colour_names[t],
                    # Only the first subplot contributes legend entries, so
                    # each trace name is listed once
                    showlegend=(r == 0 and c == 0)
                ),
                row=r+1,
                col=c+1
            )

# Style and label every subplot axis, however many rows the grid has
fig.update_xaxes(linecolor='black', title_text='Time (days)')
fig.update_yaxes(linecolor='black', title_text='New infections')

# Figure size, background and legend placement
fig.update_layout(
    width=900,
    height=800,
    plot_bgcolor='white',
    legend=dict(
            title='Generation time<br>distribution across<br>categories',
            yanchor="bottom",
            y=0.4,
            xanchor="right",
            x=0.35,
            bordercolor="Black",
            borderwidth=1)
)

fig.write_image('images/Different-SI-regimes-NEW.pdf')

fig.show()

In [29]:
trace_names = ['0-20 years old', '20-65 years old', '65+ years old']

# Per-category incidence of the first scenario on a log scale, comparing the
# last regime (same mean in every group, left panel) with the second-to-last
# regime (different mean in each group, right panel). make_subplots creates
# the figure itself, so no separate go.Figure() call is needed beforehand.
fig = make_subplots(
    rows=1,
    cols=2,
    subplot_titles=('Same generation time<br>interval in each group', 'Different generation time<br>interval in each group'),
    shared_yaxes=True)

for category, trace_name in enumerate(trace_names):
    # Right panel: second-to-last regime; carries the legend entry
    fig.add_trace(
        go.Scatter(
            x=times[1:500],
            y=short_deterministic_cases[0][-2][1:500, category],
            name=trace_name,
            mode='lines',
            line_color=colour_names[category]
        ),
        row=1, col=2
    )

    # Left panel: last regime; same colour, legend entry suppressed so each
    # category is listed only once
    fig.add_trace(
        go.Scatter(
            x=times[1:500],
            y=short_deterministic_cases[0][-1][1:500, category],
            name=trace_name,
            mode='lines',
            line_color=colour_names[category],
            showlegend=False
        ),
        row=1, col=1
    )

# Logarithmic y-axis on both panels
fig.update_yaxes(type='log')

# Axis labels, figure size and legend placement
fig.update_layout(
    width=800,
    height=400,
    plot_bgcolor='white',
    xaxis=dict(
        linecolor='black',
        title='Time (days)'),
    yaxis=dict(
        linecolor='black',
        title='New infections',
        exponentformat = 'power',
        # On a log axis the range is given in powers of ten
        range=[-5,5]),
    xaxis2=dict(
        linecolor='black',
        title='Time (days)'),
    yaxis2=dict(
        linecolor='black',
        exponentformat = 'power',
        range=[-5,5]),
    legend=dict(
            yanchor="bottom",
            y=0.2,
            xanchor="right",
            x=0.96,
            bordercolor="Black",
            borderwidth=1)
)

fig.write_image('images/Formula_proof.pdf')

fig.show()