In [None]:
import pickle
from glob import glob

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.lines import Line2D

from utils.generic.config import read_config

from viz import plot_forecast
from utils.generic.enums.columns import *
from viz.utils import axis_formatter

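# Mumbai Performance Plots

Ensemble-mean forecasts versus ground truth for Mumbai at the starting, middle and end phases, loaded from the `start_period`, `mid_period` and `end_period` report folders.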

In [None]:
ROOT_DIR = '../../misc/reports'


def _load_predictions_dict(folder_name, root_dir=ROOT_DIR):
    """Unpickle ``<root_dir>/<folder_name>/predictions_dict.pkl`` and return it.

    NOTE: ``pickle.load`` can execute arbitrary code; only point this at report
    folders produced by this project's own runs.
    """
    with open(f'{root_dir}/{folder_name}/predictions_dict.pkl', 'rb') as f:
        return pickle.load(f)


start_predictions_dict = _load_predictions_dict('start_period')
mid_predictions_dict = _load_predictions_dict('mid_period')

# The config used for all the plots below is read from the end-period run only.
folder_name = 'end_period'
# glob() order is unspecified; sort so the same yaml is picked on every run,
# and fail with a readable message instead of a bare IndexError.
config_candidates = sorted(glob(f'{ROOT_DIR}/{folder_name}/*.yaml'))
if not config_candidates:
    raise FileNotFoundError(f'No .yaml config found in {ROOT_DIR}/{folder_name}')
config_filename = config_candidates[0]
config = read_config(config_filename)

end_predictions_dict = _load_predictions_dict(folder_name)

In [None]:
def return_test_error(predictions_dict, loss_calculator=None, model_key='m2'):
    """Compute the MAPE loss of the ensemble-mean forecast over the test window.

    The test window is every date strictly after the last training date and
    no later than the last (post-training) forecast date.

    Args:
        predictions_dict: nested dict loaded from ``predictions_dict.pkl``.
            ``predictions_dict[model_key]`` must contain ``'df_train'``,
            ``'df_district'`` and ``'forecasts'['ensemble_mean']``, each a
            DataFrame with a ``'date'`` column.
        loss_calculator: object exposing
            ``calc_loss_dict(forecast, ground_truth, method=...)``. Defaults to
            a fresh ``Loss_Calculator()``.
        model_key: which model's predictions to evaluate (default ``'m2'``).

    Returns:
        Whatever ``loss_calculator.calc_loss_dict(..., method='mape')`` returns.

    Raises:
        IndexError: if ``df_train`` is empty or no forecast row falls after the
            last training date.
    """
    if loss_calculator is None:
        # Project-local import (was never imported in this notebook).
        # TODO confirm module path against the repo layout.
        from utils.fitting.loss import Loss_Calculator
        loss_calculator = Loss_Calculator()

    model_preds = predictions_dict[model_key]
    last_train_date = model_preds['df_train'].iloc[-1]['date']

    # Boolean-mask indexing + reset_index returns new frames, so the caller's
    # DataFrames are never mutated and no deepcopy is needed.
    em_forecast = model_preds['forecasts']['ensemble_mean']
    em_forecast = em_forecast[em_forecast['date'] > last_train_date].reset_index(drop=True)
    last_forecast_date = em_forecast.iloc[-1]['date']

    df_district = model_preds['df_district']
    df_district = df_district[
        (df_district['date'] > last_train_date)
        & (df_district['date'] <= last_forecast_date)
    ].reset_index(drop=True)

    return loss_calculator.calc_loss_dict(em_forecast, df_district, method='mape')

## Matplotlib rc parameters

In [None]:
# Render all figure text through LaTeX, typeset in Palatino (needs a working
# LaTeX installation). With text.usetex, Matplotlib picks the TeX font from
# font.family + font.serif, so Palatino is requested as a serif face rather
# than as the family name itself (the form shown in Matplotlib's usetex docs).
plt.rcParams.update({
    'text.usetex': True,
    'font.size': 15,
    'font.family': 'serif',
    'font.serif': ['Palatino'],
})

# 2 × 6 grid

In [None]:
# Each phase gets a 2-column slice of the 2×6 grid, one compartment per axis.
config['plotting']['separate_compartments_separate_ax'] = True

location_description = config['fitting']['data']['dataloading_params']['location_description']
loss_compartments = config['fitting']['loss']['loss_compartments']

fig, axs = plt.subplots(figsize=(30, 10), nrows=2, ncols=6)

# (predictions, grid column slice, title x-position in figure coords, title)
phase_panels = [
    (start_predictions_dict, slice(0, 2), 0.25, "(a) Starting Phase"),
    (mid_predictions_dict, slice(2, 4), 0.52, "(b) Middle Phase"),
    (end_predictions_dict, slice(4, 6), 0.78, "(c) End Phase"),
]

for phase_predictions, col_slice, _, _ in phase_panels:
    # NOTE(review): figsize is presumably ignored when axs is supplied — confirm.
    plot_forecast(phase_predictions, location_description,
                  which_compartments=loss_compartments,
                  fits_to_plot=['ensemble_mean'], smoothed_gt=True,
                  plotting_config=config['plotting'], figsize=(13, 12),
                  axs=axs[:, col_slice])

# Write the titles on the figure created above, not whatever plt.gcf() is.
for _, _, title_x, title in phase_panels:
    fig.text(title_x, 0.92, title, va="center", ha="center", size=30)

legend_elements = [
    Line2D([0], [0], ls='-', marker='o', ms=5, color='black', label='Ground Truth'),
    Line2D([0], [0], ls='-', color='black', label='EM Forecast'),
    Line2D([0], [0], ls='--', color='black', label='Training Range'),
]
axs[0, 0].legend(handles=legend_elements);

# 1 × 3 grid


### Toggle `twin_axes` and `log_scale` in the `plot_forecast` calls below

In [None]:
# Each phase gets one axis of the 1×3 grid, with all compartments on it.
config['plotting']['separate_compartments_separate_ax'] = False

location_description = config['fitting']['data']['dataloading_params']['location_description']
loss_compartments = config['fitting']['loss']['loss_compartments']

fig, axs = plt.subplots(figsize=(30, 8), nrows=1, ncols=3)

phase_predictions = [start_predictions_dict, mid_predictions_dict, end_predictions_dict]
for ax, predictions in zip(axs.flat, phase_predictions):
    plot_forecast(predictions, location_description,
                  which_compartments=loss_compartments, twin_axes=False,
                  fits_to_plot=['ensemble_mean'], smoothed_gt=True, log_scale=False,
                  plotting_config=config['plotting'], axs=ax)

# Title x-positions are in figure coordinates, roughly centred over each axis.
phase_titles = [(0.25, "(a) Starting Phase"), (0.52, "(b) Middle Phase"), (0.78, "(c) End Phase")]
for title_x, title in phase_titles:
    fig.text(title_x, 0.92, title, va="center", ha="center", size=30)

# Two legends on the first axis: line styles (series type) and colours
# (compartment). A second .legend() call replaces the first, so the first one
# is re-attached explicitly with add_artist.
style_legend_elements = [
    Line2D([0], [0], ls='-', marker='o', ms=5, color='black', label='Ground Truth'),
    Line2D([0], [0], ls='-', color='black', label='EM Forecast'),
    Line2D([0], [0], ls='--', color='black', label='Training Range'),
]
first_legend = axs[0].legend(handles=style_legend_elements, loc='upper left')
axs[0].add_artist(first_legend)

compartment_legend_elements = [
    Line2D([0], [0], ls='-', color='C0', label='Confirmed'),
    Line2D([0], [0], ls='-', color='orange', label='Active'),
    Line2D([0], [0], ls='-', color='green', label='Recovered'),
    Line2D([0], [0], ls='-', color='red', label='Deceased'),
]
axs[0].legend(handles=compartment_legend_elements, loc=[0.35, 0.77]);