# PHE SEIR Model

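In this notebook we show how to use the `epimodels` module to set up an instance of the SEIR model built by Public Health England (PHE) in collaboration with the University of Cambridge, using toy data. We simulate it with both a scipy ODE solver and the module's own discretised solver, and compare the results.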

In [1]:
# Load necessary libraries
import os

import matplotlib
import numpy as np
import plotly.graph_objects as go
from iteration_utilities import deepflatten
from matplotlib import pyplot as plt
from plotly.subplots import make_subplots

import epimodels as em

### Define the contact and regional matrices for the PHE model

In [2]:
# Populate the model: the regions and age groups used throughout the notebook
regions = ['London', 'Cornwall']
age_groups = ['0-10', '10-25']

# Raw contact data between age groups; each matrix is paired, in order, with an
# entry of `time_changes_contact` below
contact_data_matrix_0 = np.array([[1, 0], [0, 3]])
contact_data_matrix_1 = np.array([[10, 5.2], [0, 3]])

# Raw regional data, named region_data_matrix_<i>_<j>: the outer index <i> pairs
# with an entry of `time_changes_region`, the inner index <j> with `regions[j]`
region_data_matrix_0_0 = np.array([[0.5, 0], [0, 6]])
region_data_matrix_0_1 = np.array([[1, 10], [1, 0]])
region_data_matrix_1_0 = np.array([[0.5, 1.2], [0.29, 6]])
region_data_matrix_1_1 = np.array([[0.85, 1], [0.9, 6]])

# Wrap the raw arrays in epimodels matrix objects
contacts_0, contacts_1 = [
    em.ContactMatrix(age_groups, data)
    for data in (contact_data_matrix_0, contact_data_matrix_1)]
regional_0_0, regional_0_1 = [
    em.RegionMatrix(region, age_groups, data)
    for region, data in zip(
        regions, (region_data_matrix_0_0, region_data_matrix_0_1))]
regional_1_0, regional_1_1 = [
    em.RegionMatrix(region, age_groups, data)
    for region, data in zip(
        regions, (region_data_matrix_1_0, region_data_matrix_1_1))]

# Bundle the matrices with their `time_changes_*` lists (presumably the times at
# which each set of matrices comes into force - confirm with epimodels docs)
matrices_contact = [contacts_0, contacts_1]
time_changes_contact = [1, 3]
matrices_region = [
    [regional_0_0, regional_0_1],
    [regional_1_0, regional_1_1]]
time_changes_region = [1, 2]

### Set the parameters and initial conditions of the model and simulate it with both solvers

In [3]:
# Instantiate model
model = em.PheSEIRModel()

# Set the region names, contact and regional data of the model
model.set_regions(regions)
model.read_contact_data(matrices_contact, time_changes_contact)
model.read_regional_data(matrices_region, time_changes_region)

# Initial number of susceptibles
susceptibles = [[5, 6], [7, 8]]
# NOTE(review): `dI` and `initial_r` are passed positionally to `simulate`;
# confirm their meaning against the epimodels documentation
dI = 4
initial_r = [0.5, 1]

# List of times at which we wish to evaluate the states of the compartments of the model
times = np.arange(1, 20.5, 0.5).tolist()

# Initial conditions and parameters of the model, in the positional order that
# `model.simulate` expects. The solver time step and solver name are NOT part of
# this list: `flat_parameters` appends them per run, so `parameters` is never
# mutated and keeps describing the model, whichever run came last.
# Entry 1 (`2`) is read by the plotting cells as the 1-based index of the
# simulated region in `regions`.
parameters = [
    initial_r, 2, susceptibles, [[0.4, 0.2], [0.1, 0]], [[0.05, 0.2], [0.5, 0.32]],
    [[0, 0], [0, 0]], [[0, 0], [0, 0]], [[0, 0], [0, 0]],
    [[1]*len(times), [1]*len(times)], 4, dI]


def flat_parameters(delta_t, method):
    """Return the flat parameter vector for one call to `model.simulate`.

    Appends the time step `delta_t` and solver name `method` to a copy of
    `parameters` and flattens all nested lists; strings are kept whole.
    """
    return list(deepflatten(parameters + [delta_t, method], ignore=str))


# Simulate using the ODE solver from scipy (time-step entry kept at 0.5, as before)
scipy_method = 'RK45'
output_scipy_solver = model.simulate(flat_parameters(0.5, scipy_method), times)

# Simulate with the 'homemade' discretised solver, once per time step
time_steps = [0.5, 0.25, 0.05, 10**(-3)]
outputs_my_solver = [
    model.simulate(flat_parameters(ts, 'my-solver'), times)
    for ts in time_steps]

### Plot the compartments obtained with the two solvers against each other

In [4]:
# One colour per solver output (home-made solver time steps, then scipy)
colours = ['blue', 'red', 'green', 'purple', 'orange', 'black', 'gray', 'pink']

# Group outputs together: the home-made solver runs followed by the scipy run.
# Build a new list rather than appending to `outputs_my_solver`, so that list is
# not mutated and re-running this cell does not add the scipy output twice.
outputs = [*outputs_my_solver, output_scipy_solver]

In [5]:
# Trace names - represent the solver used for the simulation
trace_name = ['my-solver with delta_t = {}'.format(ts) for ts in time_steps]
trace_name.append('scipy-solver {}'.format(scipy_method))

# Simulated region: entry 1 of `parameters` is its 1-based index in `regions`
region_name = regions[parameters[1] - 1]

# One figure title per model output compartment, for the simulated region
compartments = [
    '{} in region {}'.format(n, region_name) for n in model.output_names()]

n_ages = len(age_groups)

for c, compartment in enumerate(compartments):
    # One subplot row per age group; each row overlays every solver output
    fig = make_subplots(
        rows=n_ages, cols=1,
        subplot_titles=tuple('ages {}'.format(a) for a in age_groups))

    for a in range(n_ages):
        for o, out in enumerate(outputs):
            fig.add_trace(
                go.Scatter(
                    # Output columns are indexed compartment-major, then age group
                    y=out[:, c*n_ages + a],
                    x=times,
                    mode='lines',
                    name=trace_name[o],
                    line_color=colours[o],
                    # Show each solver in the legend once (first row only)
                    showlegend=(a == 0)
                ),
                row=a+1, col=1
            )

    # Title, white background and black axis lines on every subplot
    fig.update_layout(
        title=compartment,
        height=800,
        plot_bgcolor='white')
    fig.update_xaxes(linecolor='black')
    fig.update_yaxes(linecolor='black')
    fig.update_xaxes(title_text='Time', row=n_ages, col=1)

    fig.show()

### Number of new infections over time for the chosen region

In [6]:
# New infections computed from the scipy-solver output, displayed in full
new_infections_scipy = model.new_infections(output=output_scipy_solver)
new_infections_scipy

array([[0.00000000e+00, 0.00000000e+00],
       [6.22228631e-08, 6.19597046e-06],
       [1.57262042e-07, 1.18014257e-05],
       [2.62760690e-07, 1.65028393e-05],
       [9.46073228e-06, 2.01938274e-05],
       [1.11190598e-05, 2.28972375e-05],
       [1.23480017e-05, 2.47139684e-05],
       [1.31733288e-05, 2.57769599e-05],
       [1.36403282e-05, 2.62238639e-05],
       [1.38126981e-05, 2.62013093e-05],
       [1.37446371e-05, 2.58230219e-05],
       [1.34886697e-05, 2.51877360e-05],
       [1.30989879e-05, 2.43855340e-05],
       [1.26160982e-05, 2.34803435e-05],
       [1.20692227e-05, 2.25157753e-05],
       [1.14877408e-05, 2.15307830e-05],
       [1.08931701e-05, 2.05502411e-05],
       [1.03003704e-05, 1.95911343e-05],
       [9.71913144e-06, 1.86644106e-05],
       [9.15805853e-06, 1.77767976e-05],
       [8.62112176e-06, 1.69285650e-05],
       [8.11109493e-06, 1.61202888e-05],
       [7.62985790e-06, 1.53532541e-05],
       [7.17855584e-06, 1.46280745e-05],
       [6.755981

### Log-likelihood of the observed number of deaths

In [7]:
# The standalone `model.new_infections(...)` call that used to open this cell
# was dropped: its result was discarded (neither assigned nor displayed), and
# the previous cell already computes the same quantity.

# Observed deaths and fatality ratios: two entries each
# (presumably one per age group - confirm against epimodels docs)
obs_death = [10, 12]
fatality_ratio = [0.1, 0.5]
# One time-to-death value per evaluation time
time_to_death = [0.5] * len(times)

# NOTE(review): the trailing positional arguments (0.5, 1) are passed straight to
# epimodels; confirm their meaning against `PheSEIRModel.loglik_deaths`
model.loglik_deaths(
    obs_death, output_scipy_solver, fatality_ratio,
    time_to_death, 0.5, 1)

array([-26.63865896, -21.4215394 ])

### Log-likelihood of the observed number of positive test results

In [8]:
# Observed positive results and number of tests performed: two entries each
obs_pos, tests = [10, 12], [20, 30]
# Assumed test sensitivity (sens) and specificity (spec)
sens, spec = 0.7, 0.95

# NOTE(review): the final positional argument (10) is passed straight to
# epimodels; confirm its meaning against `PheSEIRModel.loglik_positive_tests`
model.loglik_positive_tests(
    obs_pos, output_scipy_solver, tests,
    sens, spec, 10)

array([-18.34346436, -18.59648994])

In [9]:
# Trace names - represent the solver used for the simulation
trace_name = ['my-solver with delta_t = {}'.format(ts) for ts in time_steps]
trace_name.append('scipy-solver {}'.format(scipy_method))

# Output column to plot: compartment 3 (titled I1 below), first age group.
# Columns are indexed compartment-major, then age group.
compartment_index, age_index = 3, 0
column = compartment_index * len(age_groups) + age_index

# Simulated region: entry 1 of `parameters` is its 1-based index in `regions`
region_name = regions[parameters[1] - 1]

# One line per solver output, overlaid on a single figure
fig = go.Figure()

for o, out in enumerate(outputs):
    fig.add_trace(
        go.Scatter(
            y=out[:, column],
            x=times,
            mode='lines',
            name=trace_name[o],
            line_color=colours[o]
        )
    )

# Title, white background and black axis lines
fig.update_layout(
    title='I1 in region {} for ages {}'.format(region_name, age_groups[age_index]),
    plot_bgcolor='white',
    xaxis=dict(linecolor='black', title=dict(text='Time')),
    yaxis=dict(linecolor='black'))

# Create the output directory first so the export also works on a fresh checkout
os.makedirs('images', exist_ok=True)
fig.write_image('images/Toy.pdf')
fig.show()