# SEIRD Model First Example

In this notebook we show how to use the `epimodels` module to set up and simulate an instance of the SEIRD model, using some toy data.

In [1]:
# Load necessary libraries
from pathlib import Path

import matplotlib
import numpy as np
import plotly.graph_objects as go
from iteration_utilities import deepflatten
from matplotlib import pyplot as plt

import epimodels as em

## Model Setup
### Define the regions and age groups for the SEIRD model

In [2]:
# Populate the model: labels of the regions and of the age groups to simulate
regions = [
    'London',
    'Cornwall',
]
age_groups = [
    '0-10',
    '10-25',
]

### Set the parameters and initial conditions of the model and bundle everything together

In [3]:
# Create the SEIRD model and register the region and age-group labels on it
model = em.SEIRDModel()
model.set_regions(regions)
model.set_age_groups(age_groups)

# Initial conditions, one 2-D list per compartment. Presumably rows follow
# `regions` and columns follow `age_groups` -- TODO confirm against SEIRDICs.
# Only susceptibles and infectives start non-zero; the other compartments
# start from an all-zero matrix of the same size.
ICs = em.SEIRDICs(
    model=model,
    susceptibles_IC=[[1500, 600], [700, 400]],
    exposed_IC=[[0] * len(age_groups) for _ in regions],
    infectives_IC=[[40, 20], [50, 32]],
    recovered_IC=[[0] * len(age_groups) for _ in regions],
    dead_IC=[[0] * len(age_groups) for _ in regions],
)

# Transmission parameters; `Pd` carries one value (0.05) per age group
transmission_parameters = em.SEIRDTransmission(
    model=model,
    beta=0.2,
    kappa=0.2,
    gamma=0.1,
    Pd=np.full(len(age_groups), 0.05),
)

# Solver settings: evaluation times 1.0, 1.5, ..., 20.0 and the 'RK45' method.
# region_index is 1-based; the simulation cell reassigns it for each region.
simulation_parameters = em.SEIRDSimParameters(
    model=model,
    region_index=2,
    method='RK45',
    times=np.arange(1, 20.5, 0.5).tolist(),
)

# Bundle the initial conditions and both parameter sets into one controller
parameters = em.SEIRDParametersController(
    model=model,
    ICs=ICs,
    transmission_parameters=transmission_parameters,
    simulation_parameters=simulation_parameters,
)

### Simulate for each of the regions: **London** and **Cornwall**

In [4]:
# Simulate every region in turn. Before each call, point the simulation
# parameters at the next (1-based) region index, then collect that run's
# output; outputs[k] belongs to regions[k].
outputs = []
for region_number in range(1, len(regions) + 1):
    parameters.simulation_parameters.region_index = region_number
    outputs.append(model.simulate(parameters))

## Plot the compartments of the two regions against each other
### Set up `plotly` and default settings for plotting

In [5]:
from plotly.subplots import make_subplots

# Line colours, handed out in order to the traces drawn in each figure
colours = [
    'blue', 'red', 'green', 'purple',
    'orange', 'black', 'gray', 'pink',
]

### Plot each compartment, comparing the two regions within every age group

In [6]:
# Trace names - one per region; each region keeps the same colour in every plot
trace_name = ['region {}'.format(r) for r in regions]

# Compartment names, in the order reported by the model
compartments = [str(n) for n in model.output_names()]

n_ages = len(age_groups)
# Two-column subplot grid: ceil(n_ages / 2) rows
n_rows = (n_ages + 1) // 2

# One figure per compartment, one subplot per age group, one line per region
for c, compartment in enumerate(compartments):
    fig = make_subplots(
        rows=n_rows,
        cols=2,
        subplot_titles=tuple('ages {}'.format(a) for a in age_groups),
    )
    for a in range(n_ages):
        for o, out in enumerate(outputs):
            fig.add_trace(
                go.Scatter(
                    # assumes output columns are grouped by compartment, then
                    # by age group (column c * n_ages + a) -- TODO confirm
                    y=out[:, c * n_ages + a],
                    x=parameters.simulation_parameters.times,
                    mode='lines',
                    name=trace_name[o],
                    # wrap around so more regions than colours cannot IndexError
                    line_color=colours[o % len(colours)],
                    # only the first subplot adds legend entries, so each
                    # region is listed once in the shared legend
                    showlegend=(a == 0),
                ),
                row=a // 2 + 1,
                col=a % 2 + 1,
            )

    fig.update_layout(
        title=compartment,
        width=800,
        plot_bgcolor='white',
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1,
        ),
    )
    # Style and label the axes of every subplot, however many age groups exist
    fig.update_xaxes(linecolor='black', title_text='Time')
    fig.update_yaxes(linecolor='black')

    fig.show()

### Plot the infectives compartment for one region (London)

In [7]:
# Trace names - one per age group
trace_name = ['ages {}'.format(a) for a in age_groups]

# Column offset of the I (infectives) compartment in a simulation output.
# Assumes compartments come in the order S, E, I, R, D and each one
# contributes one column per age group -- TODO confirm against model.output_names()
i_offset = 2 * len(age_groups)

fig = go.Figure()

# One line per age group, for the first region in `regions` (London)
for a, age_label in enumerate(trace_name):
    fig.add_trace(
        go.Scatter(
            y=outputs[0][:, i_offset + a],
            x=parameters.simulation_parameters.times,
            mode='lines',
            name=age_label,
            # wrap around so more age groups than colours cannot IndexError
            line_color=colours[a % len(colours)],
        )
    )

fig.update_layout(
    title='I in region {}'.format(regions[0]),
    plot_bgcolor='white',
    xaxis=dict(linecolor='black', title=dict(text='Time')),
    yaxis=dict(linecolor='black'),
)

# Create the output directory first: writing into a missing directory fails
# on a fresh checkout. write_image also needs a static-export engine
# (e.g. kaleido) to be installed.
Path('images').mkdir(parents=True, exist_ok=True)
fig.write_image('images/Toy.pdf')
fig.show()