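# SEIR Model

Fit an SEIR (Susceptible–Exposed–Infectious–Removed) model to Ontario COVID-19 mortality data, then forecast cumulative deaths and active cases 21 days ahead.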

In [138]:
%load_ext lab_black

# Data manipulation
import numpy as np
import pandas as pd

pd.options.display.max_columns = 100
pd.options.display.max_colwidth = 500

# Data viz
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go
import seaborn as sns

sns.set(style="darkgrid", rc={"figure.figsize": (11.7, 8.27)})

# Modeling
from scipy.integrate import solve_ivp
from scipy.optimize import minimize, least_squares

# Other
import copy
from datetime import date, timedelta
import os
import random
import sys

# Custom module
module_path = os.path.abspath(os.path.join("../"))
if module_path not in sys.path:
    sys.path.append(module_path)

from src.seir_model import SEIRModel
from src.plotting import plot_predictions
from src.utils import get_covid_data, get_all_covid_data

# Reload imported code
%reload_ext autoreload
%autoreload 2

# Print all output
from IPython.core.interactiveshell import InteractiveShell

InteractiveShell.ast_node_interactivity = "all"

# Set seeds for reproducibility
rand_seed = 2
np.random.seed(rand_seed)
random.seed(rand_seed)

The lab_black extension is already loaded. To reload it, use:
  %reload_ext lab_black


## Load data

In [139]:
# Read in data
mortality_start_date = date(day=8, month=3, year=2020)
mortality_provinces = get_all_covid_data(level="prov").query(
    "date >= @mortality_start_date"
)

provinces = ["Alberta", "BC", "Manitoba", "Ontario", "Quebec", "Saskatchewan"]

# Get first and last day of death reports
mortality_start_date = mortality_provinces["date"].min()
mortality_end_date = mortality_provinces["date"].max()

# Filter for Ontario
mortality_ontario = mortality_provinces.query('province == "Ontario"')

mortality_provinces.head(10)

Unnamed: 0,province,date,cumulative_cases,cumulative_recovered,cumulative_deaths,active_cases,active_cases_change,deaths,recovered,cases,population,removed,cumulative_removed,susceptible,percent_susceptible
43,Alberta,2020-03-08,4,0,0,4,2,0,0,2,4421876,0,0,4421872,0.999999
44,Alberta,2020-03-09,14,0,0,14,10,0,0,10,4421876,0,0,4421862,0.999997
45,Alberta,2020-03-10,14,0,0,14,0,0,0,0,4421876,0,0,4421862,0.999997
46,Alberta,2020-03-11,19,0,0,19,5,0,0,5,4421876,0,0,4421857,0.999996
47,Alberta,2020-03-12,23,0,0,23,4,0,0,4,4421876,0,0,4421853,0.999995
48,Alberta,2020-03-13,29,0,0,29,6,0,0,6,4421876,0,0,4421847,0.999993
49,Alberta,2020-03-14,39,0,0,39,10,0,0,10,4421876,0,0,4421837,0.999991
50,Alberta,2020-03-15,56,0,0,56,17,0,0,17,4421876,0,0,4421820,0.999987
51,Alberta,2020-03-16,74,0,0,74,18,0,0,18,4421876,0,0,4421802,0.999983
52,Alberta,2020-03-17,97,0,0,97,23,0,0,23,4421876,0,0,4421779,0.999978


## SEIR model parameter estimation and forecasting

In [140]:
# Number of steps to forecast past the last observed report.
# NOTE(review): assumed to be days, matching the daily `date` rows — confirm
# against SEIRModel.forecast.
forecast_horizon = 21

# Fit the SEIR model to Ontario mortality data, then project forward
model = SEIRModel()
model.fit(mortality_ontario)
forecasts = model.forecast(h=forecast_horizon)

forecasts.tail()

In [142]:
# Display the parameters stored on the model, presumably those estimated by
# `model.fit` in the previous cell.
# NOTE(review): the order and meaning of these values are defined in
# src/seir_model.py — TODO document them here.
model.optimal_params

array([20.95349961, 13.10989529, 13.12221521])

In [None]:
# Both figures are drawn from the same Ontario forecast frame with the same
# figure height; only the plotted column and the labels differ.
shared_plot_options = {"height": 600}

plot_predictions(
    forecasts,
    y="cumulative_deaths",
    y_label="Cumulative deaths",
    title="COVID-19 mortality in Ontario",
    **shared_plot_options,
)

plot_predictions(
    forecasts,
    y="active_cases",
    y_label="Active cases",
    title="COVID-19 active cases in Ontario",
    **shared_plot_options,
)