In [None]:
# Automatically reload edited Python modules before each cell runs, so changes to the
# project code imported below (via ../../ on sys.path) take effect without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from glob import glob
import pickle
import copy
import os
import yaml
from datetime import timedelta

import sys
sys.path.append('../../')

from data.processing.processing import get_data_from_tracker

from utils.fitting.loss import Loss_Calculator
from utils.generic.config import read_config
from utils.fitting.smooth_jump import smooth_big_jump

from viz.data import plot_data
from viz import plot_forecast
from utils.generic.enums.columns import *
from viz.utils import setup_plt, axis_formatter

# Smoothing the Mumbai case data

In [None]:
# Load the default configuration; the smoothing cell below reads
# config['fitting']['data']['smooth_jump_params'] from it.
config_filename = 'default.yaml'
config = read_config(config_filename)

In [None]:
# Mumbai (Maharashtra) tracker data from the 'data_all' source, keeping rows dated
# on or before 2020-07-31.
df = (
    get_data_from_tracker(state='Maharashtra', district='Mumbai', use_dataframe='data_all')['data_frame']
    .loc[lambda frame: frame['date'] <= '2020-07-31']
)

In [None]:
# Apply smooth_big_jump with the configured parameters; only the smoothed frame (first
# element of the returned pair) is used below. The second element is bound to a named
# throwaway rather than `_`, which IPython reserves for the last cell output.
smooth_jump_params = config['fitting']['data']['smooth_jump_params']
print('smoothing params', smooth_jump_params)
df_smooth, _smoothing_extra = smooth_big_jump(df, smooth_jump_params)

In [None]:
import matplotlib as mpl
# Render all figure text with LaTeX in the Palatino font at a 20 pt base size.
# Note: text.usetex requires a working LaTeX installation on this machine.
mpl.rc('text', usetex=True)
mpl.rc('font', size=20, family='Palatino')

In [None]:
# Overlay the original and smoothed Mumbai curves for each compartment, then save the
# figure for the paper.
SMOOTHING_PLOT_PATH = '../../../paper/plots/smoothing-mumbai.pdf'
# Colour-legend text per compartment, in legend display order.
# NOTE(review): assumes the 'total' compartment is confirmed cases, as the original legend implied.
COMPARTMENT_LEGEND_LABELS = {
    'total': 'Confirmed Cases',
    'active': 'Active Cases',
    'recovered': 'Recovered',
    'deceased': 'Deceased',
}

# `compartments` and `Columns` are not defined in this notebook; presumably they come from
# the star import of utils.generic.enums.columns.
date_col = compartments['date'][0].name

fig, ax = plt.subplots(figsize=(12, 12))
for comp in ['active', 'total', 'recovered', 'deceased']:
    compartment = Columns.from_name(comp)
    ax.plot(df[date_col].to_numpy(), df[compartment.name].to_numpy(),
            '-o', color=compartment.color, label='Original Data ({})'.format(compartment.label))
    ax.plot(df_smooth[date_col].to_numpy(), df_smooth[compartment.name].to_numpy(),
            '-', color=compartment.color, label='Smoothed Data ({})'.format(compartment.label))

# Legend 1: line style (markers = original data, plain line = smoothed data).
style_legend = ax.legend(handles=[
    Line2D([0], [0], ls='-', marker='o', ms=5, color='black', label='Original Data'),
    Line2D([0], [0], ls='-', color='black', label='Smoothed Data'),
], loc='upper left')
# A second ax.legend() call replaces the first, so keep it attached as a plain artist.
ax.add_artist(style_legend)

# Legend 2: colour per compartment. The colours come from the same Columns objects used
# for the curves, so the legend cannot drift out of sync with the plot.
ax.legend(handles=[
    Line2D([0], [0], ls='-', color=Columns.from_name(comp).color, label=label)
    for comp, label in COMPARTMENT_LEGEND_LABELS.items()
], loc=[0.01, 0.7])

axis_formatter(ax)
ax.set_title('Smoothing Algorithm Illustration for Mumbai, India')
fig.tight_layout()
# savefig does not create missing directories.
os.makedirs(os.path.dirname(SMOOTHING_PLOT_PATH), exist_ok=True)
fig.savefig(SMOOTHING_PLOT_PATH, format='pdf', bbox_inches='tight', pad_inches=0)

# Mumbai forecast performance plots (start, middle and end periods)

In [None]:
# Load the saved config and pickled predictions for each of the three evaluation periods.
REPORTS_DIR = '../../misc/reports'


def _load_period_report(folder_name, reports_dir=REPORTS_DIR):
    """Read one period's yaml config and pickled predictions from ``reports_dir/folder_name``.

    Args:
        folder_name: Sub-directory of ``reports_dir`` holding the period's report.
        reports_dir: Root directory containing the per-period report folders.

    Returns:
        ``(config, predictions_dict)`` for that period.

    Raises:
        FileNotFoundError: If the folder has no ``*.yaml`` config or no ``predictions_dict.pkl``.
    """
    period_dir = os.path.join(reports_dir, folder_name)
    # glob order is filesystem-dependent; sort so the chosen config is deterministic
    # when several yaml files are present.
    yaml_paths = sorted(glob(os.path.join(period_dir, '*.yaml')))
    if not yaml_paths:
        raise FileNotFoundError(f'No *.yaml config found in {period_dir}')
    period_config = read_config(yaml_paths[0])

    # SECURITY: pickle.load can execute arbitrary code; only load locally generated reports.
    with open(os.path.join(period_dir, 'predictions_dict.pkl'), 'rb') as f:
        period_predictions = pickle.load(f)
    return period_config, period_predictions


start_config, start_predictions_dict = _load_period_report('start_period')
mid_config, mid_predictions_dict = _load_period_report('mid_period')
# `config` is the end-period config; the plotting cells below read it.
config, end_predictions_dict = _load_period_report('end_period')

In [None]:
def return_test_error(predictions_dict, model_key='m2', forecast_key='ensemble_mean', method='mape'):
    """Compute the loss of a saved forecast against ground truth after the training window.

    Only dates strictly after the last training date are scored. Both frames are then cut
    at whichever of them ends first, so forecast and ground truth span the same dates.

    Args:
        predictions_dict: Loaded predictions. ``predictions_dict[model_key]`` must hold
            ``'df_train'``, ``'df_district'`` and ``'forecasts'`` (a mapping of forecast
            name -> DataFrame); every frame needs a ``'date'`` column.
        model_key: Model entry to score (default ``'m2'``).
        forecast_key: Forecast within ``'forecasts'`` to score (default ``'ensemble_mean'``).
        method: Loss name passed to ``Loss_Calculator.calc_loss_dict`` (default ``'mape'``).

    Returns:
        The dict returned by ``Loss_Calculator().calc_loss_dict``.

    Raises:
        ValueError: If the forecast or the ground truth has no rows after the training window.
    """
    model_preds = predictions_dict[model_key]
    train_end = model_preds['df_train'].iloc[-1]['date']

    # Boolean filtering returns new frames, so predictions_dict itself is never mutated
    # (no deepcopy needed).
    forecast = model_preds['forecasts'][forecast_key]
    forecast = forecast[forecast['date'] > train_end]
    ground_truth = model_preds['df_district']
    ground_truth = ground_truth[ground_truth['date'] > train_end]
    if forecast.empty or ground_truth.empty:
        raise ValueError(
            f'No forecast/ground-truth rows after the training end date {train_end!r} '
            f'for model {model_key!r}, forecast {forecast_key!r}'
        )

    # Cut both frames at the earlier of their last dates, so a forecast that runs past
    # the available ground truth (or vice versa) is not scored on mismatched lengths.
    window_end = min(forecast['date'].iloc[-1], ground_truth['date'].iloc[-1])
    forecast = forecast[forecast['date'] <= window_end].reset_index(drop=True)
    ground_truth = ground_truth[ground_truth['date'] <= window_end].reset_index(drop=True)

    lc = Loss_Calculator()
    return lc.calc_loss_dict(forecast, ground_truth, method=method)

In [None]:
# Smaller 15 pt base font for the dense multi-panel forecast figures below, still
# rendered with LaTeX in Palatino (requires a LaTeX installation).
plt.rc('text', usetex=True)
plt.rc('font', size=15, family='Palatino')

In [None]:
# Ensemble-mean forecasts for the three phases side by side; each phase gets a 2x2 block
# of axes in a 2x6 grid.
fig, axs = plt.subplots(figsize=(30, 12), nrows=2, ncols=6)

location_description = config['fitting']['data']['dataloading_params']['location_description']
# (predictions for the phase, the phase's 2x2 block of axes)
phase_plots = [
    (start_predictions_dict, axs[0:, 0:2]),
    (mid_predictions_dict, axs[0:, 2:4]),
    (end_predictions_dict, axs[0:, 4:6]),
]
for phase_predictions, phase_axs in phase_plots:
    plot_forecast(phase_predictions,
                  location_description,
                  which_compartments=config['fitting']['loss']['loss_compartments'],
                  fits_to_plot=['ensemble_mean'], error_bars=False, smoothed_gt=True,
                  plotting_config=config['plotting'], figsize=(13, 12), axs=phase_axs)

# Captions go on `fig` explicitly instead of whatever figure pyplot considers current.
# x positions are figure coordinates, roughly centred over each phase's columns.
for caption, x in [('(a) Starting Phase', 0.25), ('(b) Middle Phase', 0.52), ('(c) End Phase', 0.78)]:
    fig.text(x, 0.92, caption, va="center", ha="center", size=30)

legend_elements = [
    Line2D([0], [0], ls='-', marker='o', ms=5, color='black', label='Ground Truth'),
    Line2D([0], [0], ls='-', color='black', label='EM Forecast'),
    Line2D([0], [0], ls='--', color='black', label='Training Range')
]
# Trailing ';' suppresses the Legend object's repr in the cell output.
axs[0, 0].legend(handles=legend_elements);

In [None]:
# Turn off 'separate_compartments_separate_ax' (presumably one axis per compartment) for the
# twin-axes figure below. NOTE: this edits the shared config in place, so it also affects
# any earlier plotting cell that is re-run afterwards.
config['plotting'].update(separate_compartments_separate_ax=False)

In [None]:
# One panel per phase, each drawn with twin y-axes and a linear scale.
fig, axs = plt.subplots(figsize=(30, 8), nrows=1, ncols=3)

# Local copy of the plotting config with separate_compartments_separate_ax disabled. This
# cell then produces the same figure whether or not the config-mutating cell above has run.
single_ax_plotting_config = {**config['plotting'], 'separate_compartments_separate_ax': False}
location_description = config['fitting']['data']['dataloading_params']['location_description']

phase_predictions_list = [start_predictions_dict, mid_predictions_dict, end_predictions_dict]
for phase_predictions, ax in zip(phase_predictions_list, axs.flat):
    plot_forecast(phase_predictions,
                  location_description,
                  which_compartments=config['fitting']['loss']['loss_compartments'], twin_axes=True,
                  fits_to_plot=['ensemble_mean'], smoothed_gt=True, log_scale=False,
                  plotting_config=single_ax_plotting_config, figsize=(13, 12), axs=ax)

# Captions go on `fig` explicitly instead of whatever figure pyplot considers current.
for caption, x in [('(a) Starting Phase', 0.25), ('(b) Middle Phase', 0.52), ('(c) End Phase', 0.78)]:
    fig.text(x, 0.92, caption, va="center", ha="center", size=30)

# Legend 1 (line style) is re-attached as an artist because the second ax.legend() call
# would otherwise replace it.
style_legend = axs[0].legend(handles=[
    Line2D([0], [0], ls='-', marker='o', ms=5, color='black', label='Ground Truth'),
    Line2D([0], [0], ls='-', color='black', label='EM Forecast'),
    Line2D([0], [0], ls='--', color='black', label='Training Range')
], loc='upper left')
axs[0].add_artist(style_legend)

# Legend 2: colour per compartment.
axs[0].legend(handles=[
    Line2D([0], [0], ls='-', color='C0', label='Confirmed'),
    Line2D([0], [0], ls='-', color='orange', label='Active'),
    Line2D([0], [0], ls='-', color='green', label='Recovered'),
    Line2D([0], [0], ls='-', color='red', label='Deceased'),
], loc=[0.35, 0.77]);

In [None]:
# Save the twin-axes phase figure for the paper. `fig` is whichever figure was last
# assigned, i.e. the twin-axes figure from the cell above.
TWIN_AXES_PLOT_PATH = '../../../paper/plots/bombay-plots-twin-axes.pdf'
# savefig does not create missing directories.
os.makedirs(os.path.dirname(TWIN_AXES_PLOT_PATH), exist_ok=True)
fig.savefig(TWIN_AXES_PLOT_PATH, format='pdf', bbox_inches='tight', pad_inches=0)

In [None]:
# Loss dict of the start-period forecast over its post-training window (MAPE by default).
start_test_error = return_test_error(start_predictions_dict)
start_test_error