<a href="https://colab.research.google.com/github/Intelecy/covid-19-research/blob/areeh/COVID_19_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# COVID-19 death forecast with probabilistic inference of a logistic curve

Created by Intelecy AS, Norway

GitHub source: https://github.com/Intelecy/covid-19-research

In [0]:
#@title Install non-standard libraries
# pycountry is used by get_pop() to map ECDC two-letter geo ids to UN numeric
# country codes. NOTE(review): consider pinning a version for reproducibility.
!pip install -q pycountry

In [0]:
#@title Imports
import datetime
import logging
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly
import plotly.graph_objects as go
import pycountry
import pymc3 as pm
import seaborn as sns
import time
from copy import deepcopy
from functools import partial
from google.colab import files
from IPython.utils import io as ipython_io
from numba import jit
from operator import attrgetter
from operator import itemgetter
from operator import methodcaller
from typing import Any
from typing import Dict
from typing import List
from typing import NamedTuple
from typing import Optional
from typing import Sequence
from typing import Tuple

## Import data

In [0]:
#@title Load the raw death data from the ecdc website if is not loaded yet
# Data sources:
# Full data (current): !wget https://www.ecdc.europa.eu/sites/default/files/documents/COVID-19-geographic-disbtribution-worldwide-2020-03-19.xlsx
# Incomplete data:
#     Deaths: https://github.com/CSSEGISandData/COVID-19/blob/master/csse_covid_19_data/csse_covid_19_time_series/time_series_covid19_deaths_global.csv
#     Confirmed: https://github.com/CSSEGISandData/COVID-19/blob/master/csse_covid_19_data/csse_covid_19_time_series/time_series_covid19_confirmed_global.csv
if 'daily_df' not in locals():
  yesterday_str = (datetime.date.today() - datetime.timedelta(days=1)).strftime(
      '%Y-%m-%d')
  today_str = datetime.date.today().strftime('%Y-%m-%d')
  # !wget https://www.ecdc.europa.eu/sites/default/files/documents/COVID-19-geographic-disbtribution-worldwide-2020-03-19.xlsx
  !wget https://www.ecdc.europa.eu/sites/default/files/documents/COVID-19-geographic-disbtribution-worldwide-{yesterday_str}.xlsx
  !wget https://www.ecdc.europa.eu/sites/default/files/documents/COVID-19-geographic-disbtribution-worldwide-{today_str}.xlsx
  try:
    daily_df = pd.read_excel(f"COVID-19-geographic-disbtribution-worldwide-{today_str}.xlsx")
    data_str = today_str
  except:
    daily_df = pd.read_excel(f"COVID-19-geographic-disbtribution-worldwide-{yesterday_str}.xlsx")
    data_str = yesterday_str
  # confirmed_df = pd.read_csv(io.BytesIO(uploaded['time_series_19-covid-Deaths.csv']))
  # deaths_df = pd.read_csv(io.BytesIO(uploaded['time_series_19-covid-Deaths.csv']))
# Population Data:
#     https://population.un.org/wpp/Download/Files/1_Indicators%20(Standard)/EXCEL_FILES/1_Population/WPP2019_POP_F01_1_TOTAL_POPULATION_BOTH_SEXES.xlsx

if 'pop_df' not in locals():
  !wget "https://population.un.org/wpp/Download/Files/1_Indicators%20(Standard)/EXCEL_FILES/1_Population/WPP2019_POP_F01_1_TOTAL_POPULATION_BOTH_SEXES.xlsx"
  pop_cols = ['Index', 'Region, subregion, country or area *', 'Country code', 'Type', '2020']
  pop_df = pd.read_excel("WPP2019_POP_F01_1_TOTAL_POPULATION_BOTH_SEXES.xlsx", header=16, index_col=0, usecols=pop_cols)
  pop_df = pop_df[pop_df['Type'] != 'Label/Separator']
  pop_df = pop_df.astype({'Country code': 'int32', '2020': 'int32'})

In [0]:
#@title Population lookup helper function
# ECDC geo ids that differ from the ISO 3166-1 alpha-2 codes pycountry expects.
_ECDC_TO_ISO_ALPHA_2 = {"UK": "GB", "EL": "GR"}

def get_pop(row) -> float:
  """Return the 2020 population of the country in `row`, or NaN if unknown.

  Args:
    row: Mapping-like object (e.g. a DataFrame row or a dict) with a "GeoId"
      entry holding the ECDC two-letter geo id.

  Returns:
    The population as a float. The UN table in `pop_df` stores the '2020'
    column in thousands, hence the `* 1000`. Returns NaN when the geo id cannot
    be mapped to a UN country code.
  """
  geo_id = _ECDC_TO_ISO_ALPHA_2.get(row["GeoId"], row["GeoId"])
  try:
    country_code = int(pycountry.countries.get(alpha_2=geo_id).numeric)
    pop = pop_df[pop_df["Country code"] == country_code].iloc[0]["2020"]
  except (AttributeError, LookupError):
    # AttributeError: pycountry returned None (unknown alpha-2 code).
    # LookupError (KeyError/IndexError): pycountry raised, or the country code
    # is missing from the UN population table.
    return float("nan")
  # Convert before scaling so the result is a float and cannot overflow int32.
  return float(pop) * 1000

In [0]:
#@title Inspect the first rows to verify the data was loaded correctly
daily_df.head(n=5)

In [0]:
#@title Map column names
# This was added because ECDC changed first letter from uppercase to lowercase on March 27
def upper_first(word: str) -> str:
  """Uppercase the first character of `word` and leave the rest untouched.

  Unlike `str.capitalize()`, which lowercases the remaining characters, this
  maps e.g. "dateRep" -> "DateRep" and keeps "GeoId" as is. An empty string is
  returned unchanged instead of raising IndexError.
  """
  return word[:1].upper() + word[1:]

daily_df.columns = [upper_first(column) for column in daily_df.columns]

In [0]:
#@title First rows after column mapping
daily_df.head(n=5)

## Preprocessing

1. Drop countries/regions with few deaths or few confirmed cases
1. Drop redundant columns
1. Optional: Drop regions that don't report the full period from Dec 31 until now (otherwise, the days a region did not report are filled with zero counts)

In [0]:
min_confirmed = 1000 #@param {type:"integer"}
min_deaths = 2 #@param {type:"integer"}
drop_incomplete = False #@param ["False", "True"] {type:"raw"}

# Make sure the report date is a datetime, then keep an ISO "YYYY-MM-DD" string
# copy of it in "Date".
daily_df["DateRep"] = pd.to_datetime(daily_df["DateRep"], format='%Y-%m-%d')
daily_df["Date"] = daily_df["DateRep"].astype(str)
min_date = daily_df.DateRep.values.min()
max_date = daily_df.DateRep.values.max()

# Sum only the count columns; summing text/date columns is wasted work and can
# raise on recent pandas versions.
total_counts = daily_df.groupby("GeoId")[["Cases", "Deaths"]].sum()
current_total_confirmed = total_counts.Cases.values
current_total_deaths = total_counts.Deaths.values

# Get geo ids that contain sufficient data for modeling
analysis_geo_ids = total_counts.index.values[
    (current_total_confirmed >= min_confirmed) &
    (current_total_deaths >= min_deaths)]

country_col_name = "CountriesAndTerritories"
analysis_df = daily_df[daily_df["GeoId"].isin(analysis_geo_ids)][
    ["Date", "DateRep", "Cases", "Deaths", "GeoId", country_col_name]]
if drop_incomplete:
  # Drop countries where the initial reported data is not the first date or where
  # the last reported data is not the last date
  first_reported = analysis_df.groupby("GeoId")["DateRep"].min()
  last_reported = analysis_df.groupby("GeoId")["DateRep"].max()
  not_first = first_reported.index.values[first_reported.values > min_date]
  not_last = last_reported.index.values[last_reported.values < max_date]
  print(f"Dropped not first date geo ids: {not_first}")
  print(f"Dropped not last date geo ids: {not_last}")
  analysis_geo_ids = np.setdiff1d(analysis_geo_ids, not_first)
  analysis_geo_ids = np.setdiff1d(analysis_geo_ids, not_last)

  analysis_df = analysis_df[analysis_df["GeoId"].isin(analysis_geo_ids)]

  # Drop countries with gaps in the reported data
  reporting_counts = analysis_df["GeoId"].value_counts()
  num_days = reporting_counts.max()
  analysis_geo_ids = reporting_counts.index.values[
      reporting_counts.values == num_days]
  analysis_df = analysis_df[analysis_df["GeoId"].isin(analysis_geo_ids)]
else:
  # Add zero rows for the days a country did not report, so that every country
  # has one row per day
  sorted_days = np.sort(np.unique(analysis_df.Date))
  num_days = sorted_days.size
  add_dfs = []
  for g in analysis_geo_ids:
    country_df = analysis_df[analysis_df["GeoId"] == g]
    missing_days = np.setdiff1d(sorted_days, country_df.Date)
    if missing_days.size:
      add_data = pd.DataFrame(data={
          'Date': missing_days,
          'DateRep': pd.to_datetime(missing_days, format='%Y-%m-%d'),
          'Cases': 0,
          'Deaths': 0,
          'GeoId': g,
          country_col_name: country_df[country_col_name].values[0],
      })
      add_dfs.append(add_data)

  analysis_df = pd.concat([analysis_df] + add_dfs)

# Order rows by geo id and then chronologically. `analysis_geo_ids` is sorted the
# same way below, so the `country_id` feature i refers to `analysis_geo_ids[i]`.
analysis_df = analysis_df.sort_values(['GeoId', 'DateRep'])

analysis_geo_ids = np.sort(analysis_geo_ids)
num_countries = analysis_geo_ids.size
print(f"Num analysed countries: {num_countries}")
print(f"Analysed geo ids: {analysis_geo_ids}")
analysis_countries = np.sort(np.unique(analysis_df[[country_col_name]].values))
print(f"Analysed countries: {analysis_countries}")
analysis_df.index = np.arange(analysis_df.shape[0])

# The tile/repeat feature construction below relies on exactly one block of
# `num_days` consecutive rows per country.
assert analysis_df["GeoId"].value_counts().eq(num_days).all(), (
    "expected exactly one row per country and day")

# Add model features (days since start)
analysis_df["days_since_start"] = np.tile(np.arange(num_days), num_countries)
analysis_df["country_id"] = np.repeat(np.arange(num_countries), num_days)

# Add population data: look it up once per country instead of once per row.
population_by_geo_id = {g: get_pop({"GeoId": g}) for g in analysis_geo_ids}
analysis_df["Population"] = analysis_df["GeoId"].map(population_by_geo_id)
analysis_df["Deaths_pc"] = analysis_df["Deaths"] / analysis_df["Population"]
analysis_df["Cases_pc"] = analysis_df["Cases"] / analysis_df["Population"]

In [0]:
#@title Possibly override the data with John Hopkins data
override_john_hopkins_data = True #@param ["False", "True"] {type:"raw"}

if override_john_hopkins_data:
  !rm time_series_covid19_confirmed_global.csv
  !rm time_series_covid19_deaths_global.csv
  !wget https://raw.githubusercontent.com/CSSEGISandData/COVID-19/master/csse_covid_19_data/csse_covid_19_time_series/time_series_covid19_confirmed_global.csv
  !wget https://raw.githubusercontent.com/CSSEGISandData/COVID-19/master/csse_covid_19_data/csse_covid_19_time_series/time_series_covid19_deaths_global.csv

  jh_confirmed = pd.read_csv("time_series_covid19_confirmed_global.csv")
  jh_deaths = pd.read_csv("time_series_covid19_deaths_global.csv")
  
  jh_confirmed = jh_confirmed.groupby(['Country/Region']).sum()
  jh_deaths = jh_deaths.groupby(['Country/Region']).sum()

  country_map = {
      "South Korea": "Korea, South",
      "United States Of America": "US",
      "Czech Republic": "Czechia",
      }
  analysis_df[country_col_name] = analysis_df[country_col_name].str.replace("_", " ").str.title()
  for country in np.sort(np.unique(analysis_df[country_col_name].values)):
    jh_country = country_map.get(country, country)
    if country == "China":
      print("Don't use JH for China")
    elif jh_country in jh_confirmed.index:
      jh_confirmed_c = jh_confirmed.iloc[np.where(jh_confirmed.index == jh_country)[0][0], 2:].values
      jh_deaths_c = jh_deaths.iloc[np.where(jh_confirmed.index == jh_country)[0][0], 2:].values
      jh_confirmed_day = np.diff(np.concatenate([np.array([0]), jh_confirmed_c]))
      jh_deaths_day = np.diff(np.concatenate([np.array([0]), jh_deaths_c]))
      country_index = np.where(analysis_df[country_col_name].values == country)[0][0]
      for i in np.arange(jh_confirmed_day.size):
        index = country_index + i + 22 # First day of reporting Jan 22 as Opposed to 31 Dec (22 days offset)
        analysis_df.loc[index, "Cases"] = jh_confirmed_day[i]
        analysis_df.loc[index, "Deaths"] = jh_deaths_day[i]
    else:
      print(f"Name non match: {country}")

  # Drop the data of today if the John Hopkins data is not available yet
  last_jh_day = jh_deaths.columns.values[-1]
  last_jh_datetime = datetime.datetime.strptime(last_jh_day, '%m/%d/%y')
  
  if last_jh_datetime != datetime.datetime.strptime(data_str, '%Y-%m-%d'):
    analysis_df = analysis_df.iloc[analysis_df.Date.values != analysis_df.Date.max()]
    num_days -= 1
    analysis_df.index = np.arange(analysis_df.shape[0])

# Drop data of the last day if all counts are zero
last_day = analysis_df.Date.max()
second_last_day = np.sort(np.unique(analysis_df.Date.values))[-2]
if analysis_df.Deaths.values[analysis_df.Date == last_day].sum() < (
    analysis_df.Deaths.values[analysis_df.Date == second_last_day].sum())/4:
  analysis_df = analysis_df.iloc[analysis_df.Date.values != last_day]
  num_days -= 1
  analysis_df.index = np.arange(analysis_df.shape[0])

analysis_df["Deaths_pc"] = analysis_df["Deaths"] / analysis_df["Population"]
analysis_df["Cases_pc"] = analysis_df["Cases"] / analysis_df["Population"]

## Helpers

In [0]:
# Font sizes passed to plotly by the layout and annotation helpers below
# (add_layout, legend_layout, intelecy_annotation). Colab form fields.
title_font_size = 14 #@param {type:"integer"}
axis_label_font_size = 14 #@param {type:"integer"}
tick_font_size = 10 #@param {type:"integer"}
legend_font_size = 8 #@param {type:"integer"}
annotation_font_size = 8 #@param {type:"integer"}

In [0]:
#@title Find first greater than helper function
@jit(nopython=True)
def find_first_greater_than(value, vec) -> int:
  """Scan `vec` from the front and give back the position of the first element
  strictly above `value`; -1 signals that no element qualifies."""
  position = 0
  length = len(vec)
  while position < length:
    if vec[position] > value:
      return position
    position += 1
  return -1

In [0]:
#@title Plot saving setup
def plotly_offline_nojs(data, filename):
  """Write the plotly figure `data` to `filename` as a standalone HTML page.

  plotly.js itself is not embedded; the page loads it from the plotly CDN
  script tag below, which keeps the files small.
  """
  # Render first so that a rendering error does not leave an empty file behind.
  js = plotly.offline.plot(data, include_plotlyjs=False, output_type="div")
  # The page declares utf-8, so write utf-8 regardless of the platform default.
  with open(filename, "w", encoding="utf-8") as out:
    out.write(
        f"""
<html>
<head><meta charset="utf-8" />
    <script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
</head>
<body>
    <div>
        <script type="text/javascript">window.PlotlyConfig = {{MathJaxConfig: 'local'}};</script>
        {js}
    </div>
</body>
</html>"""
    )

def save_plot_html(fig, plot_name):
  """Save `fig` as the HTML page `plots/<plot_name>`."""
  plotly_offline_nojs(fig, filename='plots/' + plot_name)

# Start from an empty plots/ directory; save_plot_html() writes its HTML there.
! rm -rf plots/
! mkdir plots
# ! rm -rf forecast_plots/
# ! mkdir forecast_plots

In [0]:
#@title Sorted by key helper
class NameGeoIDTuple(NamedTuple):
  """A country's display name paired with its geo id."""
  country_name: str
  geo_id: str

def create_sorted_geoid_tuples(
    geo_ids: np.ndarray,
    df: pd.DataFrame,
    sort_key: str = "country_name",
    country_col: Optional[str] = None,
) -> List[NameGeoIDTuple]:
  """Pair every geo id with its country name and sort the pairs.

  Args:
    geo_ids: Geo ids to look up; each must occur in `df.GeoId`.
    df: Frame with a "GeoId" column and a country-name column.
    sort_key: NameGeoIDTuple field to sort by ("country_name" or "geo_id").
    country_col: Country-name column; defaults to the global `country_col_name`.

  Returns:
    One NameGeoIDTuple per geo id, sorted by `sort_key`.

  Raises:
    IndexError: if a geo id does not occur in `df`.
  """
  if country_col is None:
    country_col = country_col_name
  geo_id_values = df.GeoId.values
  country_values = df[country_col].values
  plot_tuple_lst = [
      NameGeoIDTuple(
          geo_id=geo_id,
          country_name=country_values[geo_id_values == geo_id][0],
      )
      for geo_id in geo_ids
  ]
  return sorted(plot_tuple_lst, key=attrgetter(sort_key))

In [0]:
#@title Sorted by key with shifted start helper
class PlotTuple(NamedTuple):
  """Where a country's curve starts in a "days since threshold" plot."""
  geo_id: str
  start_idx: int
  start_ts: pd.Timestamp
  country_name: str

def create_sorted_plot_tuples(
    geo_ids: np.ndarray,
    df: pd.DataFrame,
    start_value: int,
    start_col: str,
    per_100k: bool = False,
    sort_key: str = "country_name",
    country_col: Optional[str] = None,
) -> List[PlotTuple]:
  """Find, per country, the first row where the cumulative `start_col` exceeds `start_value`.

  Args:
    geo_ids: Geo ids to process.
    df: Frame with "GeoId", "Date", `start_col` and the country-name column.
      Each country's rows are assumed to be in chronological order, so the
      first qualifying row is the first qualifying day.
    start_value: Threshold that the cumulative sum must strictly exceed.
    start_col: Column whose per-row values are accumulated.
    per_100k: Multiply the cumulative values by 100000 before comparing
      (for per-capita columns).
    sort_key: PlotTuple field to sort the result by.
    country_col: Country-name column; defaults to the global `country_col_name`.

  Returns:
    One PlotTuple per country that crosses the threshold, sorted by `sort_key`.
    Countries that never cross it are reported with a print and left out.
  """
  if country_col is None:
    country_col = country_col_name
  per_100k_str = " per 100k" if per_100k else ""
  geo_id_values = df.GeoId.values
  plot_tuple_lst = []
  for geo_id in geo_ids:
    in_country = geo_id_values == geo_id
    start_col_values = np.cumsum(df[start_col].values[in_country])
    if per_100k:
      start_col_values = start_col_values * 100000

    country = df[country_col].values[in_country][0]
    start_idx = find_first_greater_than(start_value, start_col_values)

    if start_idx == -1:
      print(
        f"{country} did not have more than {start_value} cumulative {start_col}{per_100k_str}. "
          "Dropping from plot."
      )
      continue

    start_date = pd.Timestamp(df["Date"].values[in_country][start_idx])
    plot_tuple_lst.append(PlotTuple(
        geo_id=geo_id,
        start_idx=start_idx,
        start_ts=start_date,
        country_name=country,
    ))

  return sorted(plot_tuple_lst, key=attrgetter(sort_key))

In [0]:
#@title trace color helper
def get_color(i: int, palette: Optional[Sequence] = None) -> str:
  """Return the i-th palette color as a plotly 'rgb(r, g, b)' string.

  Args:
    i: Trace index; wraps around when it exceeds the palette length.
    palette: Sequence of color tuples with float components in [0, 1].
      Defaults to the current seaborn color palette.

  Returns:
    A string such as 'rgb(31, 119, 180)' with integer channels in 0-255.
  """
  if palette is None:
    palette = sns.color_palette()
  cat_col = palette[i % len(palette)]
  # Scale [0, 1] to 0-255; multiplying by 256 would map 1.0 to an invalid 256.
  rgb_cols = [str(int(round(v * 255))) for v in list(cat_col)]
  return 'rgb(' + ', '.join(rgb_cols) + ')'

In [0]:
#@title create hoverdata helper
def create_hoverdata(x: Sequence, y: Sequence) -> List[Tuple[Any, int]]:
  """Pair `x` and `y` element-wise (stopping at the shorter one) as a list of tuples."""
  return [(x_val, y_val) for x_val, y_val in zip(x, y)]

In [0]:
#@title hover template formatters
# Hover label "(<date>, <value>)": the date is read from the trace's `text`
# (shown like "Mar 27"), the value from `y`.
HOVER_TEMPLATE_DATE = "%{text|%b %d}"
hover_template = partial("({x}, {y})".format, x=HOVER_TEMPLATE_DATE)
# d3-format specifiers: "d" = integer, ".4f" = fixed point with 4 decimals
# (".d" is not a valid d3-format specifier).
hover_template_int = hover_template(y="%{y:d}")
hover_template_float = hover_template(y="%{y:.4f}")

In [0]:
# Caption attached to every figure; the date is today's UTC date when this
# cell runs (timezone-aware `now(utc)` replaces the deprecated `utcnow()`).
intelecy_caption = f'Generated by <b>Intelecy</b> on {datetime.datetime.now(datetime.timezone.utc).date()}.<br />See <a href="https://github.com/Intelecy/covid-19-research">GitHub</a> or <a href="https://intelecy.com">https://intelecy.com</a> for more information. <br />License: <a href="https://github.com/Intelecy/covid-19-research/blob/master/LICENSE">CC BY 4.0</a>.'
# Anchored at the top-left corner of the plotting area (paper coordinates).
intelecy_annotation = dict(
    x=0.0,
    y=1.0,
    align="left",
    text=intelecy_caption,
    showarrow=False,
    xref="paper",
    yref="paper",
    xanchor="left",
    yanchor="top",
    xshift=0,
    yshift=0,
    font=dict(size=annotation_font_size),
)

def add_intelecy_annotation(fig) -> None:
  """Add the Intelecy caption annotation to `fig`.

  A deep copy of the shared `intelecy_annotation` template is passed so that
  the template dict itself is never handed to plotly.
  """
  annotation = deepcopy(intelecy_annotation)
  fig.add_annotation(**annotation)

In [0]:
#@title Legend string
# Short display names for countries whose full names are long in the legend.
name_map = {
    "United States Of America": "US",
    "United Kingdom": "UK",
    "Czech Republic": "Czechia",
}
def legend_string(country, start_ts = None):
  """Build a legend label: the short name of `country`, plus " start: <date>"
  when a (truthy) `start_ts` is supplied."""
  label = f"{name_map.get(country, country)}"
  if not start_ts:
    return label
  return f"{label} start: {start_ts.date()}"

In [0]:
#@title default layout
def legend_layout(y: float = 1.0, yanchor: str = 'bottom') -> Dict:
  """Build the horizontal legend settings shared by the plots.

  Args:
    y: Value for plotly's `legend.y`.
    yanchor: Value for plotly's `legend.yanchor`.

  Returns:
    A dict suitable for `fig.update_layout(legend=...)`, using the current
    `legend_font_size`.
  """
  return dict(
      font=dict(size=legend_font_size),
      orientation='h',
      y=y,
      yanchor=yanchor,
      borderwidth=20,
      # Fully transparent border: the wide border only adds spacing.
      bordercolor='hsla(0,0,0,0)',
  )

def add_layout(
  fig,
  xtext: str,
  ytext: str,
  legend: Optional[Dict] = None,
  plot_title: Optional[str] = None,
  row: Optional[int] = None,
  col: Optional[int] = None,
) -> None:
  """Apply the shared margins, fonts, legend, axis titles and caption to `fig`.

  Args:
    fig: plotly figure, updated in place.
    xtext: x-axis title.
    ytext: y-axis title.
    legend: Legend settings; defaults to `legend_layout()` built at call time.
    plot_title: Optional centered figure title.
    row: Subplot row passed to `update_xaxes`/`update_yaxes`.
    col: Subplot column passed to `update_xaxes`/`update_yaxes`.
  """
  if legend is None:
    # Built per call (not as a default argument) so that the current
    # legend_font_size is used even if it changed after this cell ran.
    legend = legend_layout()
  margin = dict(l=0, r=0, b=0)
  if plot_title:
    fig.update_layout(title=plot_title, title_x=0.5, title_font=dict(size=title_font_size))
  else:
    margin["t"] = 15  # keep the plotly buttons above the plot/legend
  fig.update_layout(
    margin=margin,
    font=dict(size=tick_font_size),
    legend=legend,
  )
  fig.update_xaxes(automargin=True, title=dict(text=xtext, font=dict(size=axis_label_font_size)), row=row, col=col)
  fig.update_yaxes(automargin=True, title=dict(text=ytext, font=dict(size=axis_label_font_size)), row=row, col=col)
  add_intelecy_annotation(fig)

## EDA


In [0]:
#@title Confirmed daily cases plot
fig = go.Figure()

id_name_lst = create_sorted_geoid_tuples(geo_ids=analysis_geo_ids, df=analysis_df)

# One line per country, in the order returned by create_sorted_geoid_tuples.
for i, (country, geo_id) in enumerate(id_name_lst):
  subset_ids = analysis_df.GeoId.values == geo_id
  daily_cases_trace = go.Scatter(
      x=analysis_df.Date.values[subset_ids],
      y=analysis_df.Cases.values[subset_ids],
      name=legend_string(country),
      mode="lines",
      marker_color=get_color(i),
  )
  fig.add_trace(daily_cases_trace)

add_layout(fig, xtext="", ytext="Daily confirmed cases")
fig.show()
save_plot_html(fig, 'confirmed_daily_raw.html')

In [0]:
#@title Daily deaths plot
fig = go.Figure()

id_name_lst = create_sorted_geoid_tuples(
    geo_ids=analysis_geo_ids,
    df=analysis_df,
)

# One line per country; the tuple already carries the country name, so no
# second lookup in analysis_df is needed.
for i, id_name_tup in enumerate(id_name_lst):
  subset_ids = analysis_df.GeoId.values == id_name_tup.geo_id
  fig.add_trace(go.Scatter(
      x=analysis_df.Date.values[subset_ids],
      y=analysis_df.Deaths.values[subset_ids],
      name=legend_string(id_name_tup.country_name),
      mode='lines',
      marker_color=get_color(i),
  ))

add_layout(fig, xtext="", ytext='Daily deaths')
fig.show()

save_plot_html(fig, 'deaths_daily_raw.html')

In [0]:
#@title Confirmed daily cases since first deaths plot
start_deaths =  5 #@param {type:"number"}
fig = go.Figure()

plot_tuple_lst = create_sorted_plot_tuples(
    geo_ids=analysis_geo_ids,
    df=analysis_df,
    start_value=start_deaths,
    start_col="Deaths",
)

for i, plt_tup in enumerate(plot_tuple_lst):
  start_idx = plt_tup.start_idx
  current_df = analysis_df[analysis_df.GeoId.values == plt_tup.geo_id]
  datetimes = [pd.Timestamp(date) for date in current_df.Date.values[start_idx:]]
  cases = current_df.Cases.values[start_idx:]

  fig.add_trace(go.Scatter(
      # x counts days since the start day, so it has the same length as y.
      x=np.arange(cases.size),
      y=cases,
      name=legend_string(plt_tup.country_name),
      text=datetimes,
      hovertemplate=hover_template_int,
      mode='lines',
      marker_color=get_color(i),
  ))

add_layout(fig, xtext=f"Days since first {start_deaths} death(s)", ytext="Daily confirmed cases")
fig.show()

save_plot_html(fig, 'confirmed_daily_since_first_deaths.html')

In [0]:
#@title Daily deaths since first deaths plot
start_deaths = 5 #@param {type:"number"}
fig = go.Figure()

plot_tuple_lst = create_sorted_plot_tuples(
    geo_ids=analysis_geo_ids,
    df=analysis_df,
    start_value=start_deaths,
    start_col="Deaths",
)

for i, plt_tup in enumerate(plot_tuple_lst):
  start_idx = plt_tup.start_idx
  current_df = analysis_df[analysis_df.GeoId.values == plt_tup.geo_id]
  datetimes = [pd.Timestamp(date) for date in current_df.Date.values[start_idx:]]
  deaths = current_df.Deaths.values[start_idx:]

  fig.add_trace(go.Scatter(
      # x counts days since the start day, so it has the same length as y.
      x=np.arange(deaths.size),
      y=deaths,
      name=legend_string(plt_tup.country_name),
      text=datetimes,
      hovertemplate=hover_template_int,
      mode='lines',
      marker_color=get_color(i),
  ))

# add_layout() also adds the Intelecy caption, so no separate call is needed.
add_layout(fig, xtext=f"Days since first {start_deaths} death(s)", ytext="Daily deaths")
fig.show()

save_plot_html(fig, 'deaths_daily_since_first_deaths.html')

In [0]:
#@title Cumulative cases since first deaths plot
start_deaths =  5 #@param {type:"number"}
log_plot = False #@param {type:"boolean"}
fig = go.Figure()

plot_tuple_lst = create_sorted_plot_tuples(
    geo_ids=analysis_geo_ids,
    df=analysis_df,
    start_value=start_deaths,
    start_col="Deaths",
)

for i, plt_tup in enumerate(plot_tuple_lst):
  start_idx = plt_tup.start_idx
  current_df = analysis_df[analysis_df.GeoId.values == plt_tup.geo_id]
  datetimes = [pd.Timestamp(date) for date in current_df.Date.values[start_idx:]]
  # Accumulate over the whole history, then cut at the start day.
  cum_cases = np.cumsum(current_df.Cases.values)[start_idx:]

  fig.add_trace(go.Scatter(
      # x counts days since the start day, so it has the same length as y.
      x=np.arange(cum_cases.size),
      y=cum_cases,
      name=legend_string(plt_tup.country_name),
      text=datetimes,
      hovertemplate=hover_template_int,
      mode='lines',
      marker_color=get_color(i),
  ))

add_layout(
    fig,
    xtext=f"Days since first {start_deaths} death(s)",
    ytext="Cumulative confirmed cases"
)

if log_plot:
  fig.update_layout(yaxis_type="log")

fig.show()

save_plot_html(fig, 'confirmed_cumulative_since_first_deaths.html')

In [0]:
#@title Cumulative deaths since first deaths plot
start_deaths = 5 #@param {type:"number"}
log_plot = False #@param {type:"boolean"}
fig = go.Figure()

plot_tuple_lst = create_sorted_plot_tuples(
    geo_ids=analysis_geo_ids,
    df=analysis_df,
    start_value=start_deaths,
    start_col="Deaths",
)

for i, plt_tup in enumerate(plot_tuple_lst):
  start_idx = plt_tup.start_idx
  current_df = analysis_df[analysis_df.GeoId.values == plt_tup.geo_id]
  datetimes = [pd.Timestamp(date) for date in current_df.Date.values[start_idx:]]
  # Accumulate over the whole history, then cut at the start day.
  cum_deaths = np.cumsum(current_df.Deaths.values)[start_idx:]

  fig.add_trace(go.Scatter(
      # x counts days since the start day, so it has the same length as y.
      x=np.arange(cum_deaths.size),
      y=cum_deaths,
      name=legend_string(plt_tup.country_name),
      text=datetimes,
      hovertemplate=hover_template_int,
      mode='lines',
      marker_color=get_color(i),
  ))

add_layout(
    fig,
    xtext=f"Days since first {start_deaths} death(s)",
    ytext="Cumulative deaths"
)
if log_plot:
  fig.update_layout(yaxis_type="log")
fig.show()

save_plot_html(fig, 'deaths_cumulative_since_first_deaths.html')

In [0]:
#@title Cumulative death/confirmed case ratio since first deaths plot
start_deaths =  5 #@param {type:"number"}
log_plot = False #@param {type:"boolean"}
fig = go.Figure()

plot_tuple_lst = create_sorted_plot_tuples(
    geo_ids=analysis_geo_ids,
    df=analysis_df,
    start_value=start_deaths,
    start_col="Deaths",
)

for i, plt_tup in enumerate(plot_tuple_lst):
  start_idx = plt_tup.start_idx
  current_df = analysis_df[analysis_df.GeoId.values == plt_tup.geo_id]
  datetimes = [pd.Timestamp(date) for date in current_df.Date.values[start_idx:]]
  cum_cases = np.cumsum(current_df.Cases.values)
  cum_deaths = np.cumsum(current_df.Deaths.values)
  # Days without any cumulative cases give NaN (0/0) or inf; suppress the
  # numpy RuntimeWarnings those divisions would print.
  with np.errstate(divide="ignore", invalid="ignore"):
    ratio = (cum_deaths / cum_cases)[start_idx:]

  fig.add_trace(go.Scatter(
      # x counts days since the start day, so it has the same length as y.
      x=np.arange(ratio.size),
      y=ratio,
      name=legend_string(plt_tup.country_name),
      text=datetimes,
      hovertemplate=hover_template_float,
      mode='lines',
      marker_color=get_color(i),
  ))

# add_layout() also adds the Intelecy caption, so no separate call is needed.
add_layout(
    fig,
    xtext=f"Days since first {start_deaths} death(s)",
    ytext="Cumulative death/confirmed case ratio"
)
if log_plot:
  fig.update_layout(yaxis_type="log")
fig.show()

save_plot_html(fig, 'deaths_case_ratio_cumulative_since_first_deaths.html')

## EDA (population relative)


In [0]:
#@title Confirmed cumulative cases per capita since critical mass
start_cases = 1 #@param {type:"number"}
log_plot = False #@param {type:"boolean"}
fig = go.Figure()

plot_tuple_lst = create_sorted_plot_tuples(
    geo_ids=analysis_geo_ids,
    df=analysis_df,
    start_value=start_cases,
    start_col="Cases_pc",
    per_100k=True,
)

for i, plt_tup in enumerate(plot_tuple_lst):
  start_idx = plt_tup.start_idx
  current_df = analysis_df[analysis_df.GeoId.values == plt_tup.geo_id]
  datetimes = [pd.Timestamp(date) for date in current_df.Date.values[start_idx:]]
  # Accumulate over the whole history, then cut at the start day.
  cum_cases_per_100k = np.cumsum(current_df.Cases_pc.values * 100000)[start_idx:]

  fig.add_trace(go.Scatter(
      # x counts days since the start day, so it has the same length as y.
      x=np.arange(cum_cases_per_100k.size),
      y=cum_cases_per_100k,
      name=legend_string(plt_tup.country_name),
      text=datetimes,
      hovertemplate=hover_template_float,
      mode='lines',
      marker_color=get_color(i),
  ))

add_layout(
    fig,
    xtext=f"Days since reaching > {start_cases} confirmed case(s) per 100K people",
    ytext="Cumulative confirmed cases per 100K people",
)
if log_plot:
  fig.update_layout(yaxis_type="log")
fig.show()

save_plot_html(fig, 'confirmed_normalized_cumulative_since_first_confirmed.html')

In [0]:
#@title Cumulative deaths per capita since critical mass
start_deaths = 0.1 #@param {type:"number"}
log_plot = False #@param {type:"boolean"}
fig = go.Figure()

plot_tuple_lst = create_sorted_plot_tuples(
    geo_ids=analysis_geo_ids,
    df=analysis_df,
    start_value=start_deaths,
    start_col="Deaths_pc",
    per_100k=True,
)

for i, plt_tup in enumerate(plot_tuple_lst):
  start_idx = plt_tup.start_idx
  current_df = analysis_df[analysis_df.GeoId.values == plt_tup.geo_id]
  datetimes = [pd.Timestamp(date) for date in current_df.Date.values[start_idx:]]
  # Accumulate over the whole history, then cut at the start day.
  cum_deaths_per_100k = np.cumsum(current_df.Deaths_pc.values * 100000)[start_idx:]

  fig.add_trace(go.Scatter(
      # x counts days since the start day, so it has the same length as y.
      x=np.arange(cum_deaths_per_100k.size),
      y=cum_deaths_per_100k,
      name=legend_string(plt_tup.country_name),
      text=datetimes,
      hovertemplate=hover_template_float,
      mode='lines',
      marker_color=get_color(i),
  ))

add_layout(
    fig,
    xtext=f"Days since reaching > {start_deaths} death(s) per 100K people",
    ytext="Cumulative deaths per 100K people",
)
if log_plot:
  fig.update_layout(yaxis_type="log")
fig.show()

save_plot_html(fig, 'deaths_normalized_cumulative_since_first_confirmed.html')

In [0]:
#@title New confirmed cases per capita since critical mass
start_cases = 1 #@param {type:"number"}
fig = go.Figure()

plot_tuple_lst = create_sorted_plot_tuples(
    geo_ids=analysis_geo_ids,
    df=analysis_df,
    start_value=start_cases,
    start_col="Cases_pc",
    per_100k=True,
)

for i, plt_tup in enumerate(plot_tuple_lst):
  start_idx = plt_tup.start_idx
  current_df = analysis_df[analysis_df.GeoId.values == plt_tup.geo_id]
  datetimes = [pd.Timestamp(date) for date in current_df.Date.values[start_idx:]]
  cases_per_100k = current_df.Cases_pc.values[start_idx:] * 100000

  fig.add_trace(go.Scatter(
      # x counts days since the start day, so it has the same length as y.
      x=np.arange(cases_per_100k.size),
      y=cases_per_100k,
      name=legend_string(plt_tup.country_name),
      text=datetimes,
      hovertemplate=hover_template_float,
      mode='lines',
      marker_color=get_color(i),
  ))

add_layout(
    fig,
    xtext=f"Days since reaching > {start_cases} confirmed case(s) per 100K people",
    ytext="New confirmed cases per 100K people",
)
fig.show()

save_plot_html(fig, 'confirmed_normalized_daily_since_first_confirmed.html')

In [0]:
#@title New deaths per capita since critical mass
start_deaths = 0.1 #@param {type:"number"}
fig = go.Figure()

plot_tuple_lst = create_sorted_plot_tuples(
    geo_ids=analysis_geo_ids,
    df=analysis_df,
    start_value=start_deaths,
    start_col="Deaths_pc",
    per_100k=True,
)

for i, plt_tup in enumerate(plot_tuple_lst):
  start_idx = plt_tup.start_idx
  current_df = analysis_df[analysis_df.GeoId.values == plt_tup.geo_id]
  datetimes = [pd.Timestamp(date) for date in current_df.Date.values[start_idx:]]
  deaths_per_100k = current_df.Deaths_pc.values[start_idx:] * 100000

  fig.add_trace(go.Scatter(
      # x counts days since the start day, so it has the same length as y.
      x=np.arange(deaths_per_100k.size),
      y=deaths_per_100k,
      name=legend_string(plt_tup.country_name),
      text=datetimes,
      hovertemplate=hover_template_float,
      mode='lines',
      marker_color=get_color(i),
  ))

add_layout(
    fig,
    xtext=f"Days since reaching > {start_deaths} death(s) per 100K people",
    ytext="New confirmed deaths per 100K people",
)
fig.show()

save_plot_html(fig, 'deaths_normalized_daily_since_first_confirmed.html')

In [0]:
#@title Optionally download html EDA plots
download_data_plots = False #@param ["False", "True"] {type:"raw"}

zip_generated = False
if download_data_plots:
  if zip_generated:
    !rm plots.zip
  !zip -r plots.zip plots/
  zip_generated = True
  files.download("plots.zip")

## Modeling


In [0]:
np.unique(analysis_df[country_col_name].values)

In [0]:
#@title Single country model
studied_country = "United States Of America" #@param ["Norway", "Sweden", "Belgium", "United_States_of_America", "Italy", "France", "China", "Spain", "United Kingdom", "Germany", "Netherlands", "United States Of America"]
normalize_by_population = True #@param ["False", "True"] {type:"raw"}
show_model_logs = True #@param ["False", "True"] {type:"raw"}
count_distribution = "Poisson" #@param ["Normal", "Poisson"] {type:"string"}
num_samples = 2000 #@param {type:"integer"}

logger = logging.getLogger("pymc3")
logger.propagate = show_model_logs
# Set the level in both cases: when the cell is re-run with logs enabled, an
# ERROR level set by a previous run would otherwise keep hiding the logs.
logger.setLevel(logging.INFO if show_model_logs else logging.ERROR)

def fit_pm_model(simple_model_data, count_distribution, normalize_by_population,
                 target="Deaths"):
  """Fits a bell-shaped daily-deaths curve for one country with pymc3.

  The expected daily value on day t is
      10**log_amplitude * 4 * exp(-d) / (1 + exp(-d))**2,
      d = |t - center_time| / width,
  i.e. the derivative of a logistic curve, scaled by 4 so that its peak
  (t == center_time) equals 10**log_amplitude (times the population for the
  normalized Poisson model).

  Args:
    simple_model_data: Rows of a single country. Reads "days_since_start",
      `target`, "Deaths" (Poisson) and "Population" (normalized Poisson).
      Presumably one row per day, `num_days` rows in total -- TODO confirm
      against the data-prep cells.
    count_distribution: Likelihood to use, "Normal" or "Poisson".
    normalize_by_population: If True the amplitude is per person (prior on
      the log10 amplitude in [-9, -4]); otherwise it is an absolute count
      (prior in [0, 3]).
    target: Column used to find the first positive day and, for "Normal",
      as the observed series. "Poisson" always observes the raw "Deaths".

  Returns:
    Tuple (pm.summary(trace), model, trace).

  Raises:
    ValueError: If `count_distribution` is neither "Normal" nor "Poisson".

  Reads the notebook globals `num_days`, `num_samples` and `show_model_logs`.
  """
  if count_distribution not in ("Normal", "Poisson"):
    # Without a likelihood the sampler would silently draw from the priors.
    raise ValueError(
        f"Unknown count_distribution {count_distribution!r}; "
        "expected 'Normal' or 'Poisson'.")

  with pm.Model() as single_country_model:
    # Day of the epidemic peak.
    center_time = pm.Uniform("center_time", lower=0, upper=num_days*2)
    # log10 of the curve's peak value (per person when normalizing).
    if normalize_by_population:
      log_amplitude = pm.Uniform("10_log_amplitude", lower=-9, upper=-4)
    else:
      log_amplitude = pm.Uniform("10_log_amplitude", lower=0, upper=3)
    # Time scale of the curve, i.e. how long the epidemic lasts.
    width = pm.Uniform("width", lower=3, upper=15)
    relative_center_distance = np.abs(
        simple_model_data["days_since_start"].values-center_time)/width

    # https://towardsdatascience.com/derivative-of-the-sigmoid-function-536880cf918e
    relative_pred = 4*np.exp(-relative_center_distance)/(
        (1+np.exp(-relative_center_distance))**2)

    # First day with a positive target. Fall back to day 0 when there is none;
    # indexing the empty np.where result would raise IndexError.
    positive_days = np.where(simple_model_data[target].values > 0)[0]
    first_positive_id = positive_days[0] if len(positive_days) else 0
    # Fit data from up to 1000 days before the first positive day, which is
    # day 0 for any series shorter than that.
    start_id = max(0, first_positive_id - 1000)
    valid_ids = np.arange(start_id, num_days)

    if count_distribution == "Normal":
      min_sigma = 1e-5
      relative_error = pm.Uniform("relative_error", lower=0, upper=0.1)
      # NOTE(review): the observed values are divided by 10**log_amplitude, a
      # random variable, with no Jacobian correction -- confirm this is intended.
      d = pm.Normal("d", mu=relative_pred[valid_ids],
                    sigma=min_sigma+relative_error*relative_pred[valid_ids],
                    observed=simple_model_data[target].values[valid_ids] / (
                        10**log_amplitude))
    else:  # "Poisson": counts are modelled directly.
      mean_mult = simple_model_data.Population.values[0] if (
          normalize_by_population) else 1
      d = pm.Poisson("d", mu=relative_pred[valid_ids] * mean_mult * (
          10**log_amplitude),
          observed=simple_model_data["Deaths"].values[valid_ids])

    trace = pm.sample(num_samples, tune=num_samples//2, chains=3, cores=2,
                      progressbar=show_model_logs)

  return pm.summary(trace), single_country_model, trace

# Fit the single-country model for the country chosen in the form above.
target = "Deaths_pc" if normalize_by_population else "Deaths"
simple_model_data = analysis_df[
    analysis_df[country_col_name].values == studied_country]
param_summary, single_country_model, country_trace = fit_pm_model(
    simple_model_data,
    count_distribution,
    normalize_by_population,
    target=target,
)
param_summary

In [0]:
#@title MAP estimate
# Maximum a posteriori point estimate of the single-country model parameters.
with single_country_model:
  map_estimate = pm.find_MAP()
map_estimate

In [0]:
#@title Plot the model parameters posterior distribution
# Label the summary rows. The Normal likelihood adds 'relative_error', so the
# list is cut to the number of rows actually present.
# NOTE(review): assumes pm.summary lists variables in creation order.
param_summary.index = ['center_time', '10_log_amplitude', 'width',
                       'relative_error'][:len(param_summary.index)]
param_summary_plot = param_summary.reset_index(drop=False)

# Asymmetric error bars from the mean to the hpd_2.5 / hpd_97.5 bounds.
# Matplotlib expects both rows as positive distances, so the lower one is
# mean - lower bound (not lower bound - mean, which is negative).
err_vals = ((param_summary['mean'] - param_summary['hpd_2.5']).values,
            (param_summary['hpd_97.5'] - param_summary['mean']).values)
ax = param_summary_plot.plot(x='index', y='mean', kind='bar', figsize=(14, 7),
                             title='Posterior Distribution of Model Parameters',
                             yerr=err_vals, color='lightgrey',
                             legend=False, grid=True,
                             capsize=5)
# Overlay the means as dots on the same axes, keeping the bar plot's x range.
param_summary_plot.plot(x='index', y='mean', color='k', marker='o',
                        linestyle='None', ax=ax, grid=True, legend=False,
                        xlim=ax.get_xlim());

In [0]:
#@title Plot multiple traces, drawn from the posterior distribution
num_posterior_traces = 100 #@param {type:"integer"}
# Gap between the plotted draws, counted back from the end of the trace.
posterior_step_interval = num_samples // num_posterior_traces

def generate_country_posterior(
  trace,
  simple_model_data,
  num_traces=num_posterior_traces,
):
  """Evaluates posterior draws of the fitted curve over 2 * num_days days.

  Args:
    trace: pymc3 trace from `fit_pm_model`; draw i is read as
      trace[-(i * posterior_step_interval + 1)] for i < num_traces.
    simple_model_data: The single-country rows the model was fitted on;
      assumed to have `num_days` rows (TODO confirm).
    num_traces: Number of posterior draws to evaluate.

  Returns:
    The observed rows followed by a copy shifted `num_days` days ahead (with
    "Deaths" and `target` set to NaN), plus a string "Date" column and one
    "Posterior{i}" column of expected daily deaths per draw.

  Reads the notebook globals `target`, `num_days`, `posterior_step_interval`
  and `normalize_by_population`.
  """
  # dict.fromkeys drops duplicates while keeping order: with target == "Deaths"
  # a plain list selects "Deaths" twice, and `predict_data.Deaths` becomes a
  # two-column frame that breaks the plots.
  columns = list(dict.fromkeys(["DateRep", "Deaths", target, "Population"]))
  predict_data = simple_model_data.copy()[columns]
  next_period_predict = predict_data.copy()
  next_period_predict["DateRep"] = predict_data["DateRep"] + datetime.timedelta(
      days=num_days)
  next_period_predict["Deaths"] = np.nan
  next_period_predict[target] = np.nan
  predict_data = pd.concat([predict_data, next_period_predict])
  days_since_start = np.arange(num_days*2)
  predict_data["Date"] = predict_data["DateRep"].astype(str)

  for i in range(num_traces):
    # Fetch the draw once instead of once per parameter.
    point = trace[-(i*posterior_step_interval+1)]
    param_center = point["center_time"]
    param_width = point["width"]
    param_log_amplitude = point["10_log_amplitude"]

    rel_distance = np.abs(days_since_start-param_center)/param_width
    rel_pred = 4*np.exp(-rel_distance)/((1+np.exp(-rel_distance))**2)

    predict_data[f"Posterior{i}"] = rel_pred * (10**param_log_amplitude)
    if normalize_by_population:
      # Per-person amplitude -> deaths.
      predict_data[f"Posterior{i}"] *= predict_data["Population"]

  return predict_data

def plot_country_posterior(
  plot_title,
  predict_data,
  pred_vals=None,
  num_traces=num_posterior_traces,
  save_plot=False,
  cumulative=False,
  show_plot=True,
):
  """Plots posterior sample curves, an optional point forecast and actuals.

  Args:
    plot_title: Figure title; also used in the saved file name.
    predict_data: Frame with "Date", "Deaths" and "Posterior{i}" columns for
      i < num_traces, as built by `generate_country_posterior`.
    pred_vals: Optional point forecast, drawn in blue.
    num_traces: Number of "Posterior{i}" columns to draw.
    save_plot: If True, write the figure to forecast_plots/.
    cumulative: If True, plot running sums instead of daily values.
    show_plot: If True, display the figure.

  Returns:
    The traces in draw order: the posterior samples, then the point forecast
    (only when `pred_vals` is given), with the actual deaths last. Cells
    further down index traces[-2] / traces[-1] based on this order.
  """
  dates = predict_data.Date.values
  # Loop-invariant: the more samples are drawn, the fainter each one is.
  sample_opacity = min(1, 15/num_traces) if num_traces else 1

  traces = []
  for i in range(num_traces):
    post_pred_vals = predict_data[f"Posterior{i}"].values
    if cumulative:
      post_pred_vals = np.cumsum(post_pred_vals)

    # Only the first sample shows a legend entry; all share one legend group.
    name = "Posterior samples" if i==0 else "Posterior " + str(i)
    traces.append(go.Scatter(
        x=dates,
        y=post_pred_vals,
        marker_color="grey",
        name=name,
        mode="lines",
        legendgroup="Posterior sample",
        opacity=sample_opacity,
        showlegend=(i==0),
    ))

  if pred_vals is not None:
    if cumulative:
      pred_vals = np.cumsum(pred_vals)
    traces.append(go.Scatter(
      x=dates,
      y=pred_vals,
      marker_color="blue",
      name="Predicted deaths",
      mode="lines",
      ))

  y = np.cumsum(predict_data.Deaths.values) if cumulative else (
      predict_data.Deaths.values)
  traces.append(go.Scatter(
      x=dates,
      y=y,
      marker_color="red",
      name="Actual deaths",
      mode="lines",
  ))

  fig = go.Figure()
  for t in traces:
    fig.add_trace(t)

  y_title = "Cumulative deaths" if cumulative else "Daily deaths"
  add_layout(
      fig,
      xtext="",
      ytext=y_title,
      legend=legend_layout(y=-0.1, yanchor="top"),
      plot_title=plot_title,
    )

  if show_plot:
    fig.show()

  if save_plot:
    plot_extension = '_cumulative' if cumulative else ''
    plotly_offline_nojs(fig, filename='forecast_plots/' + plot_title + (
        '-posterior_samples') + plot_extension + '.html')

  return traces

posteriors = generate_country_posterior(country_trace, simple_model_data)
plot_country_posterior(studied_country, posteriors);

In [0]:
#@title Plot the single country predicted deaths with the actual death count
MAP_params = False #@param ["False", "True"] {type:"raw"}


def generate_country_forecast(
  MAP_params: bool,
  map_estimate,
  param_summary,
  simple_model_data,
):
  """Evaluates the fitted curve at a single parameter estimate.

  Args:
    MAP_params: If True use `map_estimate`; otherwise use the "mean" column of
      `param_summary`.
    map_estimate: Mapping with "center_time", "width" and "10_log_amplitude"
      (e.g. from pm.find_MAP). Only read when MAP_params is True.
    param_summary: pm.summary frame indexed by parameter name. Only read when
      MAP_params is False.
    simple_model_data: Single-country rows with "DateRep", "Deaths",
      "Population" and `target`; assumed to have `num_days` rows (TODO
      confirm).

  Returns:
    The observed rows followed by a copy shifted `num_days` days ahead (with
    "Deaths" and `target` set to NaN), plus a string "Date" column and a
    "Prediction" column of expected daily deaths.

  Reads the notebook globals `target`, `num_days` and `normalize_by_population`.
  """
  param_names = ('center_time', 'width', '10_log_amplitude')
  if MAP_params:
    param_center, param_width, param_log_amplitude = (
        map_estimate[name] for name in param_names)
  else:
    means = param_summary['mean']
    param_center, param_width, param_log_amplitude = (
        means[param_summary.index == name].values[0] for name in param_names)

  # dict.fromkeys drops duplicates while keeping order: with target == "Deaths"
  # a plain list selects "Deaths" twice, and `predict_data.Deaths` becomes a
  # two-column frame that breaks the plots.
  columns = list(dict.fromkeys(['DateRep', 'Deaths', target, 'Population']))
  predict_data = simple_model_data.copy()[columns]
  next_period_predict = predict_data.copy()
  next_period_predict['DateRep'] = predict_data['DateRep'] + datetime.timedelta(
      days=num_days)
  next_period_predict['Deaths'] = np.nan
  next_period_predict[target] = np.nan
  predict_data = pd.concat([predict_data, next_period_predict])

  days_since_start = np.arange(num_days*2)
  rel_distance = np.abs(days_since_start-param_center)/param_width
  rel_pred = 4*np.exp(-rel_distance)/((1+np.exp(-rel_distance))**2)
  predict_data['Date'] = predict_data['DateRep'].astype(str)
  predict_data['Prediction'] = rel_pred * (10**param_log_amplitude)
  if normalize_by_population:
    # Per-person amplitude -> deaths.
    predict_data['Prediction'] *= predict_data['Population']

  return predict_data

def plot_country_forecast(
  plot_title,
  predict_data,
  save_plot=False,
  cumulative=False,
  show_plot=True,
):
  """Draws the point forecast against the actual deaths for one country.

  Args:
    plot_title: Figure title and base name of the saved html file.
    predict_data: Frame with "Date", "Prediction" and "Deaths" columns, as
      returned by `generate_country_forecast`.
    save_plot: If True, write forecast_plots/<plot_title>[_cumulative].html.
    cumulative: If True, plot running totals rather than daily counts.
    show_plot: If True, display the figure.

  Returns:
    [predicted-deaths trace, actual-deaths trace].
  """
  predicted = predict_data.Prediction.values
  actual = predict_data.Deaths.values
  if cumulative:
    predicted = np.cumsum(predicted)
    actual = np.cumsum(actual)

  dates = predict_data.Date.values
  series = ((predicted, "Predicted deaths"), (actual, "Actual deaths"))
  traces = [
      go.Scatter(x=dates, y=values, name=label, mode='lines')
      for values, label in series
  ]

  fig = go.Figure()
  for trace in traces:
    fig.add_trace(trace)

  add_layout(
      fig,
      xtext="",
      ytext="Cumulative deaths" if cumulative else "Daily deaths",
      legend=legend_layout(y=-0.1, yanchor="top"),
      plot_title=plot_title,
  )

  if show_plot:
    fig.show()

  if save_plot:
    suffix = '_cumulative' if cumulative else ''
    plotly_offline_nojs(fig, filename=f'forecast_plots/{plot_title}{suffix}.html')

  return traces

forecast_data = generate_country_forecast(
    MAP_params=MAP_params,
    map_estimate=map_estimate,
    param_summary=param_summary,
    simple_model_data=simple_model_data,
)
plot_country_forecast(studied_country, forecast_data);

In [0]:
#@title Optional: generate forecasts and store the plots for all countries
generate_forecast_plots = True #@param ["False", "True"] {type:"raw"}
download_forecast_plots = False #@param ["False", "True"] {type:"raw"}


class ForecastData(NamedTuple):
  """Per-country modelling results collected for the combined plots."""
  # pm.summary(trace) returned by fit_pm_model.
  param_summary: pd.DataFrame
  # The fitted pymc3 model.
  single_country_model: pm.Model
  # Posterior samples returned by fit_pm_model (pm.sample).
  country_trace: pm.backends.base.MultiTrace
  # pm.find_MAP result: parameter name -> value.
  map_estimate: Dict
  # "Prediction" column of generate_country_forecast.
  pred_vals: pd.Series
  # Frame from generate_country_posterior with the Posterior{i} columns.
  posterior_samples: pd.DataFrame

# Country name -> ForecastData, filled by the forecasting loop.
countries_plot_data: Dict[str, ForecastData] = {}

geo_name_tup_lst = create_sorted_geoid_tuples(geo_ids=analysis_geo_ids, df=analysis_df)

if generate_forecast_plots:
  ! rm -R forecast_plots/
  ! mkdir forecast_plots

  num_countries = len(geo_name_tup_lst)
  for i, geo_name_tup in enumerate(geo_name_tup_lst):
    geo_id = geo_name_tup.geo_id
    studied_country = geo_name_tup.country_name
    print("Creating forecasts for country {} of {}: {}".format(
        i+1, num_countries, studied_country))
    simple_model_data = analysis_df[analysis_df.GeoId.values == geo_id]
    with ipython_io.capture_output() as captured:
      param_summary, single_country_model, country_trace = fit_pm_model(
          simple_model_data, count_distribution="Poisson",
          normalize_by_population=True,
        )
      map_estimate = pm.find_MAP(model=single_country_model)
    forecast_data = generate_country_forecast(
      MAP_params=False,
      map_estimate=map_estimate,
      param_summary=param_summary,
      simple_model_data=simple_model_data,
    )
    pred_vals = forecast_data["Prediction"]
    for cumulative in [False, True]:
      plot_country_forecast(
        studied_country,
        forecast_data,
        save_plot=True,
        cumulative=cumulative,
      )
    
    posteriors = generate_country_posterior(country_trace, simple_model_data)
    for cumulative in [False, True]:
      plot_country_posterior(
        studied_country,
        posteriors,
        save_plot=True,
        pred_vals=pred_vals,
        cumulative=cumulative,
      )
    country_plot_data = ForecastData(
      param_summary=param_summary,
      single_country_model=single_country_model,
      country_trace=country_trace,
      map_estimate=map_estimate,
      pred_vals=pred_vals,
      posterior_samples=posteriors,
    )
    countries_plot_data[studied_country] = country_plot_data

In [0]:
#@title Single plotly combined plot of the best guess forecasts
if generate_forecast_plots:
  plot_countries = countries_plot_data.keys()
  num_plot_countries = len(plot_countries)
  # (legend suffix, line dash) for the two trailing traces returned by
  # plot_country_posterior: the point forecast, then the actual deaths.
  trailing_styles = ((" Forecast", "dot"), (" Actual", None))

  for cumulative in (False, True):
    fig = go.Figure()

    for i, country in enumerate(plot_countries):
      print(f"Adding traces for country {i+1} of {num_plot_countries}")
      country_plot_data = countries_plot_data[country]

      traces = plot_country_posterior(
        country,
        country_plot_data.posterior_samples,
        save_plot=False,
        pred_vals=country_plot_data.pred_vals,
        cumulative=cumulative,
        show_plot=False,
      )

      # Posterior samples are left out; forecast and actual share the
      # country's color and differ by line dash.
      marker_color = get_color(i)
      for t, (suffix, line_dash) in zip(traces[-2:], trailing_styles):
        name = legend_string(country) + suffix
        t.legendgroup = name
        t.name = name
        t.showlegend = True
        t.marker.color = marker_color
        t.line.dash = line_dash
        fig.add_trace(t)

    y_title = "Cumulative deaths" if cumulative else "Daily deaths"
    add_layout(fig, xtext="", ytext=y_title)
    fig.show()

    plot_extension = "_cumulative" if cumulative else ""
    plotly_offline_nojs(
      fig,
      filename=f"forecast_plots/single_plot_combined{plot_extension}.html",
    )

In [0]:
#@title Combined grid plot
max_combined_posterior_traces = 20 #@param {type:"integer"}
if generate_forecast_plots:
  # Draw only every k-th posterior sample per subplot. The inner max(1, ...)
  # guards against a 0 entered in the form.
  combined_posterior_plot_mod_step = int(max(
      1, num_posterior_traces/max(1, max_combined_posterior_traces)))
  for cumulative in [False, True]:
    plot_countries = countries_plot_data.keys()
    num_plot_countries = len(plot_countries)
    num_cols = 2
    num_rows = int(np.ceil(num_plot_countries / num_cols))
    fig = plotly.subplots.make_subplots(rows=num_rows, cols=num_cols,
                                        subplot_titles=tuple(plot_countries))

    for i, country in enumerate(plot_countries):
      print(f"Adding traces for country {i+1} of {num_plot_countries}")
      country_plot_data = countries_plot_data[country]

      traces = plot_country_posterior(
        country,
        country_plot_data.posterior_samples,
        save_plot=False,
        pred_vals=country_plot_data.pred_vals,
        cumulative=cumulative,
        show_plot=False,
      )

      # Add all traces to the correct subplot. The trace order is
      # [posterior samples..., forecast, actual].
      row = 1 + i // num_cols
      col = 1 + i % num_cols
      for j, t in enumerate(traces):
        if j == len(traces)-2:
          name = "Forecast"
        elif j == len(traces)-1:
          name = "Actual deaths"
        else:
          # Potentially skip the posterior samples
          # Only relevant if max_combined_posterior_traces < num_posterior_traces
          if j % combined_posterior_plot_mod_step != 0:
            continue
          name = "Posterior samples"

        t.legendgroup = name
        t.name = name
        # One legend entry per group, taken from the first subplot. Sample 0
        # always passes the step filter above. The last sample (index
        # len(traces)-3) usually does not, so using it hid the posterior entry.
        t.showlegend = (i == 0 and (j == 0 or j >= len(traces) - 2))
        fig.add_trace(t, row=row, col=col)

      y_title = "Cumulative deaths" if cumulative else "Daily deaths"
      # Drop the y axis title for columns other than the first
      if col != 1:
        y_title = ""
      add_layout(fig, xtext="", ytext=y_title, row=row, col=col)
    fig.update_layout(height=300*num_rows)

    fig.show()

    plot_extension = "_cumulative" if cumulative else ""
    plotly_offline_nojs(
      fig,
      filename=f"forecast_plots/combined{plot_extension}.html"
    )

In [0]:
if download_forecast_plots:
  !rm forecast_plots-{today_str}.zip
  !zip -r forecast_plots-{today_str}.zip forecast_plots/
  for _ in range(10):
    try:
      files.download(f"forecast_plots-{today_str}.zip")
      break
    except:
      time.sleep(10)