# PHE SEIR Model Optimisation

In this notebook we show how to use the `epimodels` module to optimise the initial reproduction number (R) and the time-varying transmission parameters (betas) of the PHE model, using time-dependent, region-specific contact matrices.

The analysis is run for:
 - Dates: **15 Feb 2020** - **25 June 2020**;
 - PHE regions of interest: **London**.

We use observed serology and mortality data extracted from the REACT survey and GOV.UK.

*The PHE model was built by Public Health England in collaboration with the University of Cambridge.*

In [1]:
# Load necessary libraries
import os
import numpy as np
import pandas as pd
from scipy.stats import gamma, nbinom, norm
import epimodels as em
import matplotlib
import plotly.graph_objects as go
import plotly.express as px
from matplotlib import pyplot as plt
from iteration_utilities import deepflatten

# Seed NumPy's global random number generator so every stochastic draw in
# this notebook is reproducible across runs
SEED = 27
np.random.seed(SEED)

## Model Setup
### Define setup matrices for the PHE Model

In [2]:
# Populate the model
total_days = 132
regions = ['London']
age_groups = ['0-1', '1-5', '5-15', '15-25', '25-45', '45-65', '65-75', '75+']

# Contact-matrix configuration:
#  - True  ("Variable"): one regional matrix per week ('<region>_W<week>.csv'),
#          switched every 7 days starting on day 1;
#  - False ("Fixed"): a single baseline matrix ('BASE.csv') for the whole period.
use_weekly_contact_matrices = True
contact_matrices_dir = '../../data/final_contact_matrices'


def _read_region_matrix(region, filename):
    """Read one header-less contact-matrix CSV and wrap it as an ``em.RegionMatrix``."""
    region_path = os.path.join(contact_matrices_dir, filename)
    region_data_matrix = pd.read_csv(region_path, header=None, dtype=np.float64)
    return em.RegionMatrix(region, age_groups, region_data_matrix)


# Weeks covered by the simulation (the last one may be only partially used)
weeks = list(range(1, int(np.ceil(total_days / 7)) + 1))

if use_weekly_contact_matrices:
    # matrices_region[i] holds one RegionMatrix per region for week i + 1
    matrices_region = [
        [_read_region_matrix(r, '{}_W{}.csv'.format(r, w)) for r in regions]
        for w in weeks]
    time_changes_region = np.arange(1, total_days + 1, 7).tolist()
else:
    matrices_region = [[_read_region_matrix(r, 'BASE.csv') for r in regions]]
    time_changes_region = [1]

# Every set of regional matrices needs exactly one matching change time
assert len(matrices_region) == len(time_changes_region)

# A single all-ones contact matrix, in force from day 1 for the whole simulation
contacts = em.ContactMatrix(age_groups, np.ones((len(age_groups), len(age_groups))))
matrices_contact = [contacts]
time_changes_contact = [1]

### Set the parameters and initial conditions of the model and bundle everything together

In [3]:
# Build the PHE SEIR model and attach its region, age-group and contact structure
model = em.PheSEIRModel()
model.set_regions(regions)
model.set_age_groups(age_groups)
model.read_contact_data(matrices_contact, time_changes_contact)
model.read_regional_data(matrices_region, time_changes_region)

# Age-structured population table; row 1 provides the initial susceptibles
# NOTE(review): presumably the London row of the file - confirm against the CSV
path = '../../data/england_population/England_population.csv'
total_susceptibles = np.loadtxt(path, dtype=int, delimiter=',').tolist()
susceptibles = total_susceptibles[1]

# Seed every (region, age group) cell with ICs_multiplier infectives in the
# first infectious compartment and none in the second
ic_shape = (len(regions), len(age_groups))
ICs_multiplier = 50
infectives1 = np.full(ic_shape, ICs_multiplier, dtype=float).tolist()
infectives2 = np.zeros(ic_shape).tolist()

# Daily evaluation times 1, 2, ..., total_days
times = list(range(1, total_days + 1))

In [4]:
# Shape shared by every per-region, per-age-group initial-condition array
region_age_shape = (len(regions), len(age_groups))

# Regional parameters: initial R for region 1 and one beta per region per day,
# all betas initialised to 1
phe_model_var_regional_parameters = em.PheRegParameters(
    model=model,
    initial_r=[2.35],
    region_index=1,
    betas=np.ones((len(regions), len(times))).tolist(),
    times=times
)

# Initial conditions: only the susceptible and infective compartments are
# non-empty; each zero array is built separately so no lists are shared
initial_conditions = {
    'susceptibles_IC': [susceptibles],
    'exposed1_IC': np.zeros(region_age_shape).tolist(),
    'exposed2_IC': np.zeros(region_age_shape).tolist(),
    'infectives1_IC': infectives1,
    'infectives2_IC': infectives2,
    'recovered_IC': np.zeros(region_age_shape).tolist(),
}
phe_model_var_ICs = em.PheICs(model=model, **initial_conditions)

# Disease-specific durations
# NOTE(review): dL/dI are presumably the latent and infectious periods in days - confirm
phe_model_var_disease_parameters = em.PheDiseaseParameters(model=model, dL=4, dI=4)

# Numerical integration settings
phe_model_var_simulation_parameters = em.PheSimParameters(
    model=model, delta_t=0.5, method='RK45')

# Bundle every parameter group into a single controller
parameters = em.PheParametersController(
    model=model,
    regional_parameters=phe_model_var_regional_parameters,
    ICs=phe_model_var_ICs,
    disease_parameters=phe_model_var_disease_parameters,
    simulation_parameters=phe_model_var_simulation_parameters
)

## Read Death and Serology data

In [5]:
# Observed data, one integer array per region (in the order of ``regions``)
deaths_data, positives_data, tests = [], [], []

for region in regions:
    death_file = '../../data/death_data/{}_deaths.csv'.format(region)
    positives_file = '../../data/serology_data/{}_positives_nhs.csv'.format(region)
    tests_file = '../../data/serology_data/{}_tests_nhs.csv'.format(region)

    deaths_data.append(np.loadtxt(death_file, dtype=int, delimiter=','))
    positives_data.append(np.loadtxt(positives_file, dtype=int, delimiter=','))
    tests.append(np.loadtxt(tests_file, dtype=int, delimiter=','))

In [6]:
# Days on which data is available: deaths daily from day 27, serology weekly from day 80
first_death_day, first_serology_day = 27, 80
deaths_times = list(range(first_death_day, total_days + 1))
serology_times = list(range(first_serology_day, total_days + 1, 7))

In [7]:
# Time from infection to death: Gamma distribution matched to the mean and
# standard deviation (12.1 days) from the PHE paper
td_mean = 15.0
td_var = 12.1**2
theta = td_var / td_mean   # scale = variance / mean
k = td_mean / theta        # shape = mean**2 / variance

# Density evaluated on days 1..ttd_window, then zero-padded to the simulation length
ttd_window = 30
time_to_death = gamma(k, scale=theta).pdf(np.arange(1, ttd_window + 1)).tolist()

# Age-specific infection fatality ratios
fatality_ratio = pd.read_csv('../../data/fatality_ratio_data/IFR.csv', usecols=['ifr'], dtype=np.float64)['ifr'].values.tolist()
time_to_death.extend([0.0] * (len(times) - ttd_window))

# One draw from Gamma(1, scale=5). Index the length-1 sample explicitly:
# float() on a 1-element array is deprecated since NumPy 1.25.
# NOTE(review): niu is not passed to the optimiser anywhere in this notebook
niu = float(gamma.rvs(1, scale=1/0.2, size=1)[0])

# Assumed sensitivity and specificity of the serology test (used by read_serology_data)
sens = 0.7
spec = 0.95

## Optimisation Procedure

In [8]:
# Build the inference object for the PHE model
phe_optimisation = em.inference.PheSEIRInfer(model)

# Attach the initial conditions and the observed deaths / serology data
phe_optimisation.read_model_data(susceptibles, infectives1)
phe_optimisation.read_deaths_data(deaths_data, deaths_times, time_to_death, fatality_ratio)
phe_optimisation.read_serology_data(tests, positives_data, serology_times, sens, spec)

# Objective weights for the two data streams
# NOTE(review): wd/wp presumably weight the deaths and serology terms; with
# wp = 0 serology would not inform the fit - confirm in PheSEIRInfer
deaths_weight = 1
serology_weight = 0
found, log_post_value = phe_optimisation.optimisation_problem_setup(
    times, wd=deaths_weight, wp=serology_weight)

Maximising LogPDF
Using Covariance Matrix Adaptation Evolution Strategy (CMA-ES)
Running in sequential mode.
Population size: 12
Iter. Eval.  Best      Current   Time m:s
0     12     -7380.576 -7380.576   0:01.2
1     24     -7140.13  -7140.13    0:02.2
2     36     -6797.405 -6797.405   0:03.3
3     48     -6356.834 -6356.834   0:04.4


  np.log(x[r*LEN+d+1]),
  loc=np.log(x[r*LEN+d]),


20    252    -2996.268 -3022.637   0:25.3
40    492    -2894.237 -2993.773   0:50.0
60    732    -2889.036 -2889.036   1:17.0
80    972    -2857.743 -2857.743   1:43.3
100   1212   -2811.529 -2811.529   2:07.5
120   1452   -2794.775 -2794.775   2:27.9
140   1692   -2791.745 -2791.914   2:48.0
160   1932   -2791.062 -2791.062   3:08.0
180   2172   -2790.474 -2790.504   3:27.3
200   2412   -2790.114 -2790.129   3:46.3
220   2652   -2789.728 -2789.728   4:04.5
240   2892   -2789.48  -2789.535   4:22.4
260   3132   -2789.343 -2789.343   4:40.0
280   3372   -2789.304 -2789.325   4:57.4
298   3576   -2789.3   -2789.3     5:12.2
Halting: No significant change for 100 iterations.
[3.25640229 0.56494642 0.34635661 0.32542759 0.40219124 0.47082501
 0.49157231 0.48788908 0.47856502 0.47179792 0.46109053 0.45708857
 0.46125917 0.46456999 0.15101424] -2789.300030379355
Optimisation phase is finished.


### Run the model with the optimised parameters to produce predicted timelines for:
 - number of new infections
 - number of deaths
 - R number

In [9]:
# Re-run the model at the optimised parameters and record, for every region,
# the daily new infections, expected deaths and reproduction number.
n_parameters = model.n_parameters()

predicted_new_infec = []
predicted_deaths = []
predicted_reprod_num = []

# Layout of ``found`` used here: found[0] is the initial R shared by all
# regions, followed by LEN weekly betas per region starting at day index 44.
beta_change_days = np.arange(44, len(times), 7)
LEN = len(beta_change_days)

# The optimised parameters are the same whichever region is simulated, so set
# them once. The loop index is ``reg`` so it cannot clobber the region index
# ``r`` used by the simulation loop below.
parameters.regional_parameters.initial_r = [found[0]] * len(model.regions)

betas = np.array(parameters.regional_parameters.betas)
for reg in range(len(model.regions)):
    for d, day in enumerate(beta_change_days):
        # Each weekly beta is held constant for the 7 days starting at ``day``
        betas[reg, day:(day + 7)] = found[reg * LEN + d + 1]
parameters.regional_parameters.betas = betas.tolist()

# Run model and number of new infections for all regions
for r, _ in enumerate(model.regions):
    # Region indices passed to epimodels are 1-based
    parameters.regional_parameters.region_index = r + 1

    r_fix = np.empty(len(times))
    model_reg_deaths_data = np.empty(len(times))

    m_fix = em.MultiTimesInfectivity(
        matrices_contact, time_changes_contact, regions, matrices_region, time_changes_region,
        parameters.regional_parameters.initial_r, parameters.disease_parameters.dI,
        parameters.ICs.susceptibles)

    # Run model and number of new infections for all age groups
    model_output = model.simulate(parameters)
    age_model_reg_new_infections = model.new_infections(model_output)
    model_reg_new_infections = age_model_reg_new_infections.sum(axis=1)

    for t, time in enumerate(times):
        # NOTE(review): the first len(age_groups) output columns are presumably
        # the susceptible compartment - confirm against PheSEIRModel.simulate
        r_fix[t] = m_fix.compute_reproduction_number(
            r + 1, time, model_output[t, :len(age_groups)],
            temp_variation=parameters.regional_parameters.betas[r][t])
        model_reg_deaths_data[t] = np.sum(model.mean_deaths(
            fatality_ratio, time_to_death, t, age_model_reg_new_infections))

    predicted_new_infec.append(np.array(model_reg_new_infections))
    predicted_deaths.append(model_reg_deaths_data)
    predicted_reprod_num.append(r_fix)

predicted_new_infec = np.array(predicted_new_infec)
predicted_deaths = np.array(predicted_deaths)
predicted_reprod_num = np.array(predicted_reprod_num)

  r_fix[t] = m_fix.compute_reproduction_number(


## Plot data vs predicted

### Setup ``plotly`` and default settings for plotting

In [10]:
from plotly.subplots import make_subplots

# How many regions are plotted
n_reg = len(regions)

# Line/marker colour per region, indexed by the region's position in ``regions``
colours = [
    'blue', 'red', 'green', 'purple',
    'orange', 'black', 'gray', 'pink',
]

# Observed data series, in the order (deaths, serology positives)
outputs = [
    deaths_data,
    positives_data,
]

### Select predicted quantities to plot

In [11]:
# Pull out each region's predicted timeline (row r of every prediction array)
region_rows = range(len(model.regions))
new_infec_pred = [predicted_new_infec[r, :] for r in region_rows]
deaths_pred = [predicted_deaths[r, :] for r in region_rows]
reprod_num_pred = [predicted_reprod_num[r, :] for r in region_rows]

### Plot observed versus predicted using model with optimised parameters

In [12]:
# Trace names - one legend label per region
trace_name = regions
titles = ['Infections', 'Deaths', 'Reproduction Number']

# Shared date axis: one tick every 10 simulation days, labelled with the
# calendar date (simulation day 1 is 15 Feb 2020). Labels are derived from the
# tick positions, so both stay the same length whatever total_days is.
start_date = pd.Timestamp('2020-02-15')
tick_days = np.arange(1, total_days, 10)
tick_labels = (start_date + pd.to_timedelta(tick_days - 1, unit='D')).strftime('%b %d').tolist()
date_axis = dict(linecolor='black', tickvals=tick_days.tolist(), ticktext=tick_labels)

fig = make_subplots(rows=len(titles), cols=1, subplot_titles=tuple(titles), horizontal_spacing=0.15)


def _region_trace(x, y, mode, r, showlegend=False):
    """Return a Scatter trace styled with region r's legend name and colour."""
    return go.Scatter(
        x=x, y=y, mode=mode, name=trace_name[r],
        showlegend=showlegend, line_color=colours[r])


# Plot (continuous predicted timeline and pointwise observed numbers each day)
for r, region in enumerate(regions):
    # Observed serology estimate: fraction of positive tests per column
    # (presumably age groups), set to 0 where no tests were taken instead of
    # producing NaN/inf, then scaled by the region's age-specific population
    # and summed. The scaling must use the whole population vector, not
    # a single age group's population.
    reg_tests = np.asarray(tests[r], dtype=float)
    reg_positive_fraction = np.divide(
        positives_data[r], reg_tests,
        out=np.zeros_like(reg_tests), where=reg_tests > 0)
    reg_population = np.asarray(parameters.ICs.susceptibles[r], dtype=float)
    serology_estimate = (reg_positive_fraction * reg_population).sum(axis=1)

    # Infections: predicted timeline (with legend entry) and serology-based points
    # NOTE(review): the line is daily new infections while the markers estimate
    # the number currently seropositive - the two are on different scales
    fig.add_trace(
        _region_trace(times, new_infec_pred[r].tolist(), 'lines', r, showlegend=True),
        row=1, col=1)
    fig.add_trace(
        _region_trace(serology_times, serology_estimate.tolist(), 'markers', r),
        row=1, col=1)

    # Deaths: predicted timeline and observed deaths summed across columns
    fig.add_trace(
        _region_trace(times, deaths_pred[r].tolist(), 'lines', r),
        row=2, col=1)
    fig.add_trace(
        _region_trace(deaths_times, np.sum(deaths_data[r], axis=1).tolist(), 'markers', r),
        row=2, col=1)

    # Reproduction number: predicted timeline only
    fig.add_trace(
        _region_trace(times, reprod_num_pred[r].tolist(), 'lines', r),
        row=3, col=1)

# Axis styling: date ticks on every x axis, black axis lines
fig.update_layout(
    width=600,
    height=900,
    plot_bgcolor='white',
    xaxis=dict(date_axis),
    yaxis=dict(linecolor='black'),
    xaxis2=dict(date_axis),
    yaxis2=dict(linecolor='black'),
    xaxis3=dict(date_axis),
    yaxis3=dict(linecolor='black'),
    legend=dict(
        orientation='h',
        yanchor="bottom",
        y=1.075,
        xanchor="right",
        x=1)
    )

# write_image fails if the output directory does not exist
os.makedirs('images', exist_ok=True)
fig.write_image('images/Figure-3-optimisation.pdf')
fig.show()


invalid value encountered in true_divide



In [13]:
# Non-pharmaceutical intervention (NPI) settings and data files
max_levels_npi = [3, 3, 2, 4, 2, 3, 2, 4, 2]
targeted_npi = [True, True, True, True, True, True, True, False, True]
path = '../../data/npi_data/'


def _load_npi_table(filename, dtype):
    """Load a comma-separated NPI file from ``path`` and return it as nested lists."""
    return np.loadtxt(os.path.join(path, filename), dtype=dtype, delimiter=',').tolist()


general_npi = _load_npi_table('uk_flags.csv', bool)
time_changes_flag = _load_npi_table('times_flags.csv', int)
reg_levels_npi = [_load_npi_table('uk_npis.csv', int)]
time_changes_npi = _load_npi_table('times_npis.csv', int)

# Roche model used only to evaluate the stringency index for the same regions,
# age groups and contact data as the PHE model
roche_model = em.RocheSEIRModel()
roche_model.set_regions(regions)
roche_model.set_age_groups(age_groups)
roche_model.read_contact_data(matrices_contact, time_changes_contact)
roche_model.read_regional_data(matrices_region, time_changes_region)
roche_model.read_npis_data(max_levels_npi, targeted_npi, general_npi, reg_levels_npi, time_changes_npi, time_changes_flag)

# Stringency index of region 1 on every simulated day
# NOTE(review): relies on the private method ``_compute_SI``, which may change without notice
SIs = [roche_model._compute_SI(1, t) for t in times]

In [14]:
titles = ['Reproduction number', 'Stringency Index']

# Date axis: one tick every 10 simulation days, labelled with the calendar date
# (simulation day 1 is 15 Feb 2020); labels are derived from the tick positions
start_date = pd.Timestamp('2020-02-15')
tick_days = np.arange(1, total_days, 10)
tick_labels = (start_date + pd.to_timedelta(tick_days - 1, unit='D')).strftime('%b %d').tolist()
date_axis = dict(linecolor='black', tickvals=tick_days.tolist(), ticktext=tick_labels)

fig = make_subplots(rows=len(titles), cols=1, subplot_titles=tuple(titles), horizontal_spacing=0.15)

# Plot reproduction number (predicted timeline of the first region)
fig.add_trace(
    go.Scatter(
        x=times,
        y=reprod_num_pred[0].tolist(),
        mode='lines',
        name=titles[0],
        showlegend=False,
        line_color=colours[0]
    ),
    row=1,
    col=1
)

# Plot stringency index (computed from the NPI data)
fig.add_trace(
    go.Scatter(
        x=times,
        y=SIs,
        mode='lines',
        name=titles[1],
        showlegend=False,
        line_color=colours[1]
    ),
    row=2,
    col=1
)

fig.update_layout(
    width=600,
    height=450,
    plot_bgcolor='white',
    xaxis=dict(date_axis),
    yaxis=dict(linecolor='black'),
    xaxis2=dict(date_axis),
    yaxis2=dict(linecolor='black'),
    legend=dict(
        orientation='h',
        yanchor="bottom",
        y=1.075,
        xanchor="right",
        x=1)
    )

# write_image fails if the output directory does not exist
os.makedirs('images', exist_ok=True)
fig.write_image('images/SIvsR.pdf')
fig.show()