# WarwickLanc SEIR Model First Example

In this notebook we show how to use the `warwickmodel` module to set up an instance of the model built by the Universities of Warwick and Lancaster, using some toy data.

*The Warwick model was built by the Universities of Warwick and Lancaster.*

In [1]:
# Load necessary libraries
import os
import numpy as np
import pandas as pd
import scipy
import scipy.stats  # `import scipy` alone does not load submodules on older SciPy releases
import epimodels as em
import warwickmodel as wm
import matplotlib
import plotly.graph_objects as go
from matplotlib import pyplot as plt
from iteration_utilities import deepflatten

## Model Setup
### Define setup matrices for the WarwickLanc Model

In [2]:
# Populate the model
total_days = 100
regions = ['United Kingdom', 'France']
age_groups = [
    '0-4', '5-9', '10-14', '15-19', '20-24', '25-29', '30-34', '35-39',
    '40-44', '45-49', '50-54', '55-59', '60-64', '65-69', '70-74', '75+']

# Finer age bands used by the input data files; the bands from 75 upwards are
# collapsed into the single '75+' model group further down.
extended_age_groups = [
    '0-4', '5-9', '10-14', '15-19', '20-24', '25-29', '30-34', '35-39',
    '40-44', '45-49', '50-54', '55-59', '60-64', '65-69', '70-74', '75-79',
    '80-84', '85-89', '90-94', '95-99', '100+']

# Regional contact matrices: one em.RegionMatrix per region, each read from
# ../data/Contacts_<region>.csv. There is a single set of them (matching the
# single entry of time_changes_region), so the outer list has one element.
weeks_matrices_region = [
    em.RegionMatrix(
        r,
        age_groups,
        pd.read_csv('../data/Contacts_{}.csv'.format(r),
                    header=None, dtype=np.float64))
    for r in regions
]
matrices_region = [weeks_matrices_region]

# A single all-ones contact matrix over the model age groups
contacts = em.ContactMatrix(age_groups, np.ones((len(age_groups),) * 2))
matrices_contact = [contacts]

# Time points paired with the matrix lists above (one entry each)
time_changes_contact = [1]
time_changes_region = [1]

In [3]:
# Add folder path to data file
path = os.path.join('../data/')

# Risk factors, one value per extended age group
RF_df = pd.read_csv(os.path.join(path, 'Risks_France.csv'), dtype=np.float64)
extended_d, extended_beta = (
    RF_df[column].tolist() for column in ('symptom_risk', 'susceptibility'))

# Vaccine effects: raw efficacies read from columns 1-4 of efficacies.csv
eff_df = pd.read_csv(os.path.join(path, 'efficacies.csv'),
                     usecols=range(1, 5), dtype=np.float64)
VE_i, VE_s, VE_h, VE_d = (
    eff_df[column]
    for column in ('Infection_eff', 'Symptom_eff', 'Hosp_eff', 'Death_eff'))

# Re-express each efficacy relative to the protection already provided by the
# preceding stage: (VE_x - VE_prev) / (1 - VE_prev). VE_d is updated first so
# that it is computed from the raw (not yet rescaled) VE_h.
VE_d = np.divide(VE_d - VE_h, 1 - VE_h)
VE_h = np.divide(VE_h - VE_i, 1 - VE_i)
VE_s = np.divide(VE_s - VE_i, 1 - VE_i)

# Relative rates (1 - efficacy, with NaN mapped to 0). Transmission is left
# at 1 for all six entries (presumably one per vaccination status).
nu_tra = [1] * 6
nu_symp, nu_inf, nu_sev_h, nu_sev_d = (
    np.nan_to_num(1 - efficacy).tolist()
    for efficacy in (VE_s, VE_i, VE_h, VE_d))

# Scalar parameters, taken from the first row of parameters.csv
param_df = pd.read_csv(os.path.join(path, 'parameters.csv'), dtype=np.float64)
omega, alpha, gamma, tau = (
    param_df[column].tolist()[0]
    for column in ('transmission', 'e_progression', 'recovery',
                   'asymptomatic_transmission'))
we = [param_df['waning_rate'].tolist()[0], 0]

# Initial conditions
# Each Start_pop_<region>.csv has one row per extended age group and the
# following column layout (as sliced below):
#   0-4   susceptible      5-9   exposed 1      10-14 exposed 2
#   15-19 exposed 3        20-24 exposed 4      25-29 exposed 5
#   30-34 infectious (split into symptomatic / asymptomatic using extended_d)
#   35    recovered
# NOTE(review): each group of five columns is presumably one per vaccination
# status, with the zero padding below filling a sixth — confirm.


def _collapse_over_75(extended):
    """Sum rows 15 onwards (the extended 75+ bands) into a single row.

    Returns an array with 15 + 1 rows (the model age groups) and the same
    columns as `extended`.
    """
    extended = np.asarray(extended)
    return np.vstack((extended[:15, :], np.sum(extended[15:, :], axis=0)))


def _to_ic(extended, pad=True):
    """Collapse `extended` to the model age groups and flatten column by column.

    When `pad` is True, len(age_groups) zeros are appended to the list.
    """
    flat = _collapse_over_75(extended).flatten('F').tolist()
    return flat + [0] * len(age_groups) if pad else flat


susceptibles_IC = []
exposed1_IC = []
exposed2_IC = []
exposed3_IC = []
exposed4_IC = []
exposed5_IC = []
infectives_sym_IC = []
infectives_asym_IC = []
recovered_IC = []
exposed_ICs = [exposed1_IC, exposed2_IC, exposed3_IC, exposed4_IC, exposed5_IC]

# Symptom risk as a column vector, so it scales each row (age group)
symptom_risk = np.asarray(extended_d)[:, np.newaxis]

for r in regions:
    # Read the region's starting populations once and slice out compartments
    start_pop = np.asarray(pd.read_csv(
        os.path.join(path, 'Start_pop_{}.csv'.format(r)),
        header=None, dtype=np.float64))
    if start_pop.shape[1] < 36:
        raise ValueError(
            'Start_pop_{}.csv has {} columns; expected at least 36'.format(
                r, start_pop.shape[1]))

    # Susceptible
    susceptibles_IC.append(_to_ic(start_pop[:, 0:5]))

    # Exposed 1..5 occupy consecutive five-column groups starting at column 5.
    # Each stage uses its own columns (the original code took the Exposed-2
    # over-75 rows from the last region's Exposed-1 data).
    for stage, exposed_IC in enumerate(exposed_ICs):
        first_col = 5 * (stage + 1)
        exposed_IC.append(_to_ic(start_pop[:, first_col:first_col + 5]))

    # Symptomatic & Asymptomatic Infectious
    infectious = start_pop[:, 30:35]
    infectives_sym_IC.append(_to_ic(symptom_risk * infectious))
    infectives_asym_IC.append(_to_ic((1 - symptom_risk) * infectious))

    # Recovered: a single column and, unlike the others, not zero-padded.
    # NOTE(review): confirm the model expects this shorter vector.
    recovered_IC.append(_to_ic(start_pop[:, 35:36], pad=False))

# Over 75 population fractions: for each region, the share of the 75+ total
# (file rows 15 onwards, summed across every column of Start_pop_<region>.csv)
# that falls in each extended 75+ band.
over75_fraction_list = []
for region_name in ['United Kingdom', 'France']:
    IC_df = pd.read_csv(
        os.path.join(path, 'Start_pop_{}.csv'.format(region_name)),
        skiprows=15,
        header=None, dtype=np.float64)
    over75_counts = np.asarray(IC_df)
    over75_fraction_list.append(
        (1 / np.sum(over75_counts)) * np.sum(over75_counts, axis=1))

frac_pop_over75_UK, frac_pop_over75_FR = over75_fraction_list

# Other parameters (passed to wm.VaccineParameters; both zero here)
vac = 0
vacb = 0

# Compress age-dependent parameters from the extended age groups to the model
# age groups. The first 15 groups are kept, and the 75+ bands are replaced by
# their average weighted by each region's 75+ population fractions.
over75_fractions_by_region = [frac_pop_over75_UK, frac_pop_over75_FR]

# Symptom risk per region
d = [
    extended_d[:15] + [np.sum(np.multiply(extended_d[15:], frac))]
    for frac in over75_fractions_by_region]

# Susceptibility per region. Built from extended_beta for every region; the
# France entry was previously built from extended_d by mistake.
beta = [
    extended_beta[:15] + [np.sum(np.multiply(extended_beta[15:], frac))]
    for frac in over75_fractions_by_region]


### Set the parameters and initial conditions of the model and bundle everything together

In [4]:
# Instantiate model
model = wm.WarwickLancSEIRModel()

# Set the region names, contact and regional data of the model
model.set_regions(regions)
model.set_age_groups(age_groups)
# Each list of matrices is paired with its list of change time points
# (time_changes_contact / time_changes_region, defined in the setup cell)
model.read_contact_data(matrices_contact, time_changes_contact)
model.read_regional_data(matrices_region, time_changes_region)

# List of times at which we wish to evaluate the states of the compartments of the model
# (1, 2, ..., total_days as plain Python ints)
times = np.arange(1, total_days+1, 1).tolist()

# Set regional and time dependent parameters
# region_index is 1-based (the simulation loop sets r + 1); start at the first region
regional_parameters = wm.RegParameters(
    model=model,
    region_index=1
)

# Set ICs parameters
# Each *_IC list holds one entry per region, built from Start_pop_<region>.csv
ICs_parameters = wm.ICs(
    model=model,
    susceptibles_IC=susceptibles_IC,
    exposed1_IC=exposed1_IC,
    exposed2_IC=exposed2_IC,
    exposed3_IC=exposed3_IC,
    exposed4_IC=exposed4_IC,
    exposed5_IC=exposed5_IC,
    infectives_sym_IC=infectives_sym_IC,
    infectives_asym_IC=infectives_asym_IC,
    recovered_IC=recovered_IC
)

# Set disease-specific parameters
# d[0] is the first region's symptom risk; the simulation loop overwrites
# parameters.disease_parameters.d with each region's own value
disease_parameters = wm.DiseaseParameters(
    model=model,
    d=d[0],
    tau=tau,
    we=we,
    omega=omega
)

# Set transmission parameters
# beta[0] is the first region's susceptibility; overwritten per region in the simulation loop
transmission_parameters = wm.Transmission(
    model=model,
    beta=beta[0],
    alpha=alpha,
    gamma=gamma
)

# Set other simulation parameters
# NOTE(review): method='Radau' is presumably forwarded to an ODE solver such as
# scipy.integrate.solve_ivp — confirm in warwickmodel
simulation_parameters = wm.SimParameters(
    model=model,
    method='Radau',
    times=times,
    eps=False
)

# Set vaccination parameters
# nu_* are 1 - (rescaled) vaccine efficacies with NaN mapped to 0
vaccine_parameters = wm.VaccineParameters(
    model=model,
    vac=vac,
    vacb=vacb,
    nu_tra=nu_tra,
    nu_symp=nu_symp,
    nu_inf=nu_inf,
    nu_sev_h=nu_sev_h,
    nu_sev_d=nu_sev_d,
)

# Set social distancing parameters
# NOTE(review): phi=1 presumably means no reduction in contacts — confirm
soc_dist_parameters = wm.SocDistParameters(
    model=model,
    phi=1
)

# Set all parameters in the controller
# model.simulate() is later called with this controller
parameters = wm.ParametersController(
    model=model,
    regional_parameters=regional_parameters,
    ICs_parameters=ICs_parameters,
    disease_parameters=disease_parameters,
    transmission_parameters=transmission_parameters,
    simulation_parameters=simulation_parameters,
    vaccine_parameters=vaccine_parameters,
    soc_dist_parameters=soc_dist_parameters
)

### Simulate for the regions

In [5]:
# Simulate for all the regions. The shared parameter controller is re-pointed
# at each region in turn (region_index is 1-based) together with that
# region's symptom risk d and susceptibility beta, then the ODE system is solved.
outputs = []
for r in range(len(regions)):
    parameters.regional_parameters.region_index = r + 1

    # Swap in this region's compressed age-dependent parameters
    parameters.disease_parameters.d = d[r]
    parameters.transmission_parameters.beta = beta[r]

    # Solve and keep the region's trajectories
    outputs.append(model.simulate(parameters))

## Plot the compartments of the two regions against each other
### Setup ``plotly`` and default settings for plotting

In [6]:
from plotly.subplots import make_subplots

# Eight base colours, repeated so that up to 16 traces can be indexed
colours = ['blue', 'red', 'green', 'purple', 'orange', 'black', 'gray', 'pink'] * 2

### Plot the compartments of the two regions against each other

In [7]:
# Trace names - one line per simulated region
trace_name = ['region {}'.format(r) for r in regions]

# Compartment list - one figure per model output name
compartments = ['{}'.format(n) for n in model.output_names()]

# The first `n_skip` time points are not plotted (reason not stated here —
# presumably to hide an initial transient; TODO confirm).
n_skip = 10
n_ages = len(age_groups)
n_rows = int(np.ceil(n_ages / 2))
subplot_titles = tuple('ages {}'.format(a) for a in age_groups)
plot_times = parameters.simulation_parameters.times[n_skip:]

# fig.write_image fails if the target folder is missing, which breaks a fresh
# "Restart & Run All", so create it up front.
os.makedirs('images', exist_ok=True)

# Plot for each compartment: a 2-column grid with one panel per age group and
# one line per region in every panel.
for c, compartment in enumerate(compartments):
    fig = make_subplots(rows=n_rows, cols=2, subplot_titles=subplot_titles)
    for a in range(n_ages):
        for o, out in enumerate(outputs):
            fig.add_trace(
                go.Scatter(
                    # Output columns are grouped by compartment, then by age
                    y=out[n_skip:, c * n_ages + a],
                    x=plot_times,
                    mode='lines',
                    name=trace_name[o],
                    line_color=colours[o],
                    # Only the first panel adds legend entries, so each
                    # region is listed once
                    showlegend=(a == 0),
                ),
                row=a // 2 + 1,
                col=a % 2 + 1,
            )

    # Layout, plus black axis lines on every subplot
    fig.update_layout(
        boxmode='group',
        title=compartment,
        width=800,
        height=1950,
        plot_bgcolor='white',
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1
        ))
    fig.update_xaxes(linecolor='black')
    fig.update_yaxes(linecolor='black')

    fig.write_image('images/{}.pdf'.format(compartment))
    fig.show()

### Plot the boosted symptomatic infectious compartment for one region (UK) across all age groups

In [8]:
# Trace names - one line per age group
trace_name = ['ages {}'.format(a) for a in age_groups]

# Index of the compartment to plot, in model.output_names() order.
# NOTE(review): assumed to be the boosted symptomatic infectious compartment
# described in the heading — confirm against model.output_names()[38].
compartment_index = 38

fig = go.Figure()

# One line per age group for the first region (regions[0])
for a, name in enumerate(trace_name):
    fig.add_trace(
        go.Scatter(
            y=outputs[0][:, compartment_index * len(age_groups) + a],
            x=parameters.simulation_parameters.times,
            mode='lines',
            name=name,
            line_color=colours[a]
        )
    )

# Add axis labels
fig.update_layout(
    boxmode='group',
    title='I in region {}'.format(regions[0]),
    plot_bgcolor='white',
    xaxis=dict(linecolor='black'),
    yaxis=dict(linecolor='black'),
    )

# write_image fails if the output folder does not exist yet
os.makedirs('images', exist_ok=True)
fig.write_image('images/Infection UK.pdf')
fig.show()

## Number of New Infections, Hospitalisations & Deaths

In [9]:
# Both delay distributions are Gamma with this standard deviation (days)
delay_sd = 12.1
# Delays are evaluated at days 1, 2, ..., max_delay_days
max_delay_days = 30


def _gamma_delay_pdf(mean, sd=delay_sd, max_days=max_delay_days):
    """Evaluate a Gamma pdf given its mean and standard deviation.

    Uses the moment-matched parameterisation scale = var / mean and
    shape = mean / scale, so that shape * scale = mean and
    shape * scale**2 = var.

    Returns:
        list of float: the pdf at days 1..max_days (not renormalised).
    """
    theta = sd ** 2 / mean
    k = mean / theta
    return scipy.stats.gamma(k, scale=theta).pdf(np.arange(1, max_days + 1)).tolist()


# Set time-to-hospitalisation. The tiny offset presumably keeps the mean
# strictly positive, since it is a divisor above — TODO confirm intent.
time_to_hosp = _gamma_delay_pdf(param_df['hosp_lag'].tolist()[0] + 0.00001)

# Set time-to-death
time_to_death = _gamma_delay_pdf(param_df['death_lag'].tolist()[0])

# Probabilities of proceeding to severe outcomes, compressed to the model age
# groups: the first 15 groups are kept and the 75+ bands are replaced by their
# average weighted by each region's 75+ population fractions.
over75_fractions_by_region = [frac_pop_over75_UK, frac_pop_over75_FR]

# Infected -> Hospital
extended_pItoH = RF_df['hospitalisation_risk'].tolist()
pItoH = [
    extended_pItoH[:15] + [np.sum(np.multiply(extended_pItoH[15:], frac))]
    for frac in over75_fractions_by_region]

# Hospital -> Death
extended_pHtoD = RF_df['death_risk'].tolist()
pHtoD = [
    extended_pHtoD[:15] + [np.sum(np.multiply(extended_pHtoD[15:], frac))]
    for frac in over75_fractions_by_region]

# Distribution of delays before proceeding to severe outcomes
# Infected -> Hospital
dItoH = time_to_hosp
# Hospital -> Death
dHtoD = time_to_death

In [10]:
# For every region, derive new infections from its simulated trajectories,
# then pass them through the infection -> hospital -> death chain using that
# region's probabilities and the shared delay distributions.
new_infections = []
new_hospitalisation = []
new_deaths = []

for r in range(len(regions)):
    infections_r = model.new_infections(outputs[r])
    hospitalisations_r = model.new_hospitalisations(infections_r, pItoH[r], dItoH)
    deaths_r = model.new_deaths(hospitalisations_r, pHtoD[r], dHtoD)

    new_infections.append(infections_r)
    new_hospitalisation.append(hospitalisations_r)
    new_deaths.append(deaths_r)


### Plot the compartments of the two regions against each other

Hospitalisations

In [11]:
# Trace names - one line per simulated region
trace_name = ['region {}'.format(r) for r in regions]

# Compartment list: 'H' plus each suffix. The c-th name labels
# new_hospitalisation[region][c].
compartments = ['H{}'.format(n) for n in ['', 'f', 'b', 'w1', 'w2', 'w3']]

# The first `n_skip` time points are not plotted (reason not stated here —
# presumably to hide an initial transient; TODO confirm).
n_skip = 10
n_ages = len(age_groups)
n_rows = int(np.ceil(n_ages / 2))
subplot_titles = tuple('ages {}'.format(a) for a in age_groups)
plot_times = parameters.simulation_parameters.times[n_skip:]

# fig.write_image fails if the target folder is missing, so create it up front.
os.makedirs('images', exist_ok=True)

# Plot for each compartment: one panel per age group, one line per region
for c, compartment in enumerate(compartments):
    fig = make_subplots(rows=n_rows, cols=2, subplot_titles=subplot_titles)
    for a in range(n_ages):
        for o in range(len(outputs)):
            fig.add_trace(
                go.Scatter(
                    y=new_hospitalisation[o][c][n_skip:, a],
                    x=plot_times,
                    mode='lines',
                    name=trace_name[o],
                    line_color=colours[o],
                    # Only the first panel adds legend entries
                    showlegend=(a == 0),
                ),
                row=a // 2 + 1,
                col=a % 2 + 1,
            )

    # Layout, plus black axis lines on every subplot
    fig.update_layout(
        boxmode='group',
        title=compartment,
        width=800,
        height=1950,
        plot_bgcolor='white',
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1
        ))
    fig.update_xaxes(linecolor='black')
    fig.update_yaxes(linecolor='black')

    fig.write_image('images/{}.pdf'.format(compartment))
    fig.show()

Deaths

In [12]:
# Trace names - one line per simulated region
trace_name = ['region {}'.format(r) for r in regions]

# Compartment list: 'D' plus each suffix. The c-th name labels
# new_deaths[region][c].
compartments = ['D{}'.format(n) for n in ['', 'f', 'b', 'w1', 'w2', 'w3']]

n_ages = len(age_groups)
n_rows = int(np.ceil(n_ages / 2))
subplot_titles = tuple('ages {}'.format(a) for a in age_groups)
# Unlike the hospitalisation plots, the full time range is shown here
plot_times = parameters.simulation_parameters.times

# fig.write_image fails if the target folder is missing, so create it up front.
os.makedirs('images', exist_ok=True)

# Plot for each compartment: one panel per age group, one line per region
for c, compartment in enumerate(compartments):
    fig = make_subplots(rows=n_rows, cols=2, subplot_titles=subplot_titles)
    for a in range(n_ages):
        for o in range(len(outputs)):
            fig.add_trace(
                go.Scatter(
                    y=new_deaths[o][c][:, a],
                    x=plot_times,
                    mode='lines',
                    name=trace_name[o],
                    line_color=colours[o],
                    # Only the first panel adds legend entries
                    showlegend=(a == 0),
                ),
                row=a // 2 + 1,
                col=a % 2 + 1,
            )

    # Layout, plus black axis lines on every subplot
    fig.update_layout(
        boxmode='group',
        title=compartment,
        width=800,
        height=1950,
        plot_bgcolor='white',
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1
        ))
    fig.update_xaxes(linecolor='black')
    fig.update_yaxes(linecolor='black')

    fig.write_image('images/{}.pdf'.format(compartment))
    fig.show()