# PHE SEIR Model First Example

In this notebook we present how to use the `epimodels` module to set up an instantiation of the model built by Public Health England in collaboration with the University of Cambridge, using some toy data. We assess the differences in the quality of the simulation using two different ODE solvers:
 - the bespoke one as presented in the PHE paper (`my-solver`) for different sizes of time steps;
 - the `scipy` solver using the ``RK45`` method.

*The PHE model is built by Public Health England in collaboration with University of Cambridge.*

In [1]:
# Load necessary libraries
from pathlib import Path

import matplotlib
import numpy as np
import plotly.graph_objects as go
from iteration_utilities import deepflatten
from matplotlib import pyplot as plt
from plotly.subplots import make_subplots

import epimodels as em

## Model Setup
### Define setup matrices for the PHE Model

In [2]:
# Regions and age bands the toy model is stratified by
regions = ['London', 'Cornwall']
age_groups = ['0-10', '10-25']

# Raw age-by-age contact data: one array per contact time-change point
contact_data = [
    np.array([[1, 0], [0, 3]]),
    np.array([[10, 5.2], [0, 3]]),
]

# Raw region-specific data: outer index is the regional time-change point,
# inner index follows the order of `regions`
region_data = [
    [np.array([[0.5, 0], [0, 6]]), np.array([[1, 10], [1, 0]])],
    [np.array([[0.5, 1.2], [0.29, 6]]), np.array([[0.85, 1], [0.9, 6]])],
]

# Wrap the raw contact arrays, paired with the times at which each one applies
matrices_contact = [
    em.ContactMatrix(age_groups, data) for data in contact_data]
time_changes_contact = [1, 3]

# Wrap the raw regional arrays: one row per time-change point,
# one entry per region within each row
matrices_region = [
    [
        em.RegionMatrix(region, age_groups, data)
        for region, data in zip(regions, data_per_region)
    ]
    for data_per_region in region_data
]
time_changes_region = [1, 2]

### Set the parameters and initial conditions of the model and bundle everything together

In [3]:
# Instantiate model
model = em.PheSEIRModel()

# Set the region names, age groups, contact and regional data of the model
model.set_regions(regions)
model.set_age_groups(age_groups)
model.read_contact_data(matrices_contact, time_changes_contact)
model.read_regional_data(matrices_region, time_changes_region)

# Time points at which the model is evaluated: 1, 1.5, ..., 20.
# Built once so that `betas` and `times` can never get out of step.
simulation_times = np.arange(1, 20.5, 0.5).tolist()

# Set regional and time dependent parameters
regional_parameters = em.PheRegParameters(
    model=model,
    initial_r=[0.5, 1],
    region_index=1,
    # One row of constant betas (all 1) per region, one value per time point
    betas=[[1] * len(simulation_times) for _ in regions],
    times=simulation_times
)

# Set ICs parameters
# NOTE(review): each IC looks like one row per region and one column per
# age group - confirm against the epimodels documentation.
ICs = em.PheICs(
    model=model,
    susceptibles_IC=[[500, 600], [700, 800]],
    exposed1_IC=[[0, 0], [0, 0]],
    exposed2_IC=[[0, 0], [0, 0]],
    infectives1_IC=[[5, 20], [50, 32]],
    infectives2_IC=[[40, 20], [10, 0]],
    recovered_IC=[[0, 0], [0, 0]]
)

# Set disease-specific parameters
disease_parameters = em.PheDiseaseParameters(
    model=model,
    dL=4,
    dI=4
)

# Set other simulation parameters (solver and its time step)
simulation_parameters = em.PheSimParameters(
    model=model,
    delta_t=0.5,
    method='RK45'
)

# Set all parameters in the controller
parameters = em.PheParametersController(
    model=model,
    regional_parameters=regional_parameters,
    ICs=ICs,
    disease_parameters=disease_parameters,
    simulation_parameters=simulation_parameters
)

### Simulate for one of the regions: **London**

In [4]:
# Solver settings compared in this notebook
scipy_method = 'RK45'
time_steps = [0.5, 0.25, 0.05, 10**(-3)]

# Remember the solver settings so that this cell leaves `parameters` exactly
# as it found it; otherwise a re-run would start the scipy simulation with
# the last `delta_t` of the loop below instead of the configured one.
initial_method = parameters.simulation_parameters.method
initial_delta_t = parameters.simulation_parameters.delta_t

try:
    # Simulate using the ODE solver from scipy
    parameters.simulation_parameters.method = scipy_method
    output_scipy_solver = model.simulate(parameters)

    # Use different time steps for personalised solver
    parameters.simulation_parameters.method = 'my-solver'
    outputs_my_solver = []

    for ts in time_steps:
        # Update value of time step in parameters vector
        parameters.simulation_parameters.delta_t = ts

        # Simulate using the 'homemade' discretised version of the ODE solver
        outputs_my_solver.append(model.simulate(parameters))
finally:
    # Restore the settings even if one of the simulations fails
    parameters.simulation_parameters.method = initial_method
    parameters.simulation_parameters.delta_t = initial_delta_t


## Plot the compartments of the two methods against each other
### Setup ``plotly`` and default settings for plotting

In [5]:
# One colour per simulation, indexed in the same order as `outputs`
colours = ['blue', 'red', 'green', 'purple', 'orange', 'black', 'gray', 'pink']

# Group outputs together: bespoke-solver runs first, scipy run last.
# A new list is built so that `outputs_my_solver` is left untouched and
# re-running this cell does not append the scipy output a second time.
outputs = [*outputs_my_solver, output_scipy_solver]

### Plot the compartments of the two methods against each other

In [6]:
# Trace names - represent the solver used for the simulation
# (same order as `outputs`: bespoke-solver runs first, scipy run last)
trace_name = ['my-solver with delta_t = {}'.format(ts) for ts in time_steps]
trace_name.append('scipy-solver {}'.format(scipy_method))

# Compartment list - one figure title per model output name.
# `region_index` is used as a 1-based index into `regions`.
region_name = regions[parameters.regional_parameters.region_index - 1]
compartments = [
    '{} in region {}'.format(n, region_name) for n in model.output_names()]

# Subplot grid: one panel per age group, two panels per row
num_ages = len(age_groups)
num_cols = 2
num_rows = int(np.ceil(num_ages / num_cols))

# Plot for each compartment
for c, compartment in enumerate(compartments):
    fig = make_subplots(
        rows=num_rows,
        cols=num_cols,
        subplot_titles=tuple('ages {}'.format(a) for a in age_groups))

    # Plot (line plot for each solver method for each age)
    for a in range(num_ages):
        for o, out in enumerate(outputs):
            fig.add_trace(
                go.Scatter(
                    # Output columns are grouped by compartment, then by age
                    y=out[:, c * num_ages + a],
                    x=parameters.regional_parameters.times,
                    mode='lines',
                    name=trace_name[o],
                    line_color=colours[o],
                    # Only the first panel feeds the legend, so that every
                    # solver is listed once rather than once per age group
                    showlegend=(a == 0)
                ),
                row=a // num_cols + 1,
                col=a % num_cols + 1
            )

    # Figure-level styling
    fig.update_layout(
        title=compartment,
        width=800,
        plot_bgcolor='white',
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1
        ))

    # Black axis lines on every subplot, however many age groups there are
    fig.update_xaxes(linecolor='black')
    fig.update_yaxes(linecolor='black')

    fig.show()

### Plot the first infectious compartment for one region (London) and one age group (0-10)

In [7]:
# Trace names - represent the solver used for the simulation
# (same order as `outputs`: bespoke-solver runs first, scipy run last)
trace_name = ['my-solver with delta_t = {}'.format(ts) for ts in time_steps]
trace_name.append('scipy-solver {}'.format(scipy_method))

# Position of the first infectious compartment (I1) among the model outputs.
# NOTE(review): presumes the outputs are ordered S, E1, E2, I1, I2, R -
# confirm against model.output_names().
I1_INDEX = 3

# Region that was simulated; `region_index` is used as a 1-based index
region_name = regions[parameters.regional_parameters.region_index - 1]

fig = go.Figure()

# Plot (line plot for each method of simulation)
for o, out in enumerate(outputs):
    fig.add_trace(
        go.Scatter(
            # Output columns are grouped by compartment, then by age, so this
            # selects I1 for the first age group
            y=out[:, I1_INDEX * len(age_groups)],
            x=parameters.regional_parameters.times,
            mode='lines',
            name=trace_name[o],
            line_color=colours[o]
        )
    )

# Figure-level styling
fig.update_layout(
    title='I1 in region {} for ages {}'.format(region_name, age_groups[0]),
    plot_bgcolor='white',
    xaxis=dict(linecolor='black'),
    yaxis=dict(linecolor='black'),
    )

# Make sure the output directory exists before exporting the figure
Path('images').mkdir(parents=True, exist_ok=True)
fig.write_image('images/Toy.pdf')
fig.show()
