In [None]:
# Load IPython's autoreload extension and switch it to mode 2 so that edits to
# the imported project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import seaborn as sns
from hyperopt import hp, tpe, fmin, Trials
from tqdm import tqdm

from collections import OrderedDict, defaultdict
import itertools
from functools import partial
import datetime
from joblib import Parallel, delayed
import copy
import sys
sys.path.append('../../')

from data.dataloader import Covid19IndiaLoader
from models.seir import SEIR_Testing
from main.seir.optimiser import Optimiser
from utils.fitting.loss import Loss_Calculator
from utils.generic.enums import Columns
from utils.fitting.smooth_jump import smooth_big_jump
from viz import plot_smoothing, plot_fit

from main.ihme.fitting import single_cycle
from main.seir.main import data_setup
from utils.fitting.data import cities


In [None]:

def _append_ihme_rows(df_district, df_ihme, val_period):
    """Append the `val_period` days after the last observed date, taken from `df_ihme`.

    Both frames must carry a `date` column. Only the columns already present in
    `df_district` are copied over. A KeyError is raised when `df_ihme` lacks one
    of the requested dates or columns, so a silent partial merge cannot happen.
    """
    last_date = df_district['date'].max()
    # Start one day AFTER the last observation; starting on `last_date` itself
    # would duplicate that row once the frames are stacked.
    new_dates = pd.date_range(start=last_date + datetime.timedelta(days=1),
                              end=last_date + datetime.timedelta(days=val_period))
    if len(new_dates) == 0:
        # val_period == 0: nothing to append.
        return df_district

    value_cols = [col for col in df_district.columns if col != 'date']
    ihme_rows = df_ihme.set_index('date').loc[new_dates, value_cols]
    merged = pd.concat([df_district.set_index('date'), ihme_rows])
    # `.loc` with an unnamed DatetimeIndex can drop the index name, so restore it
    # before turning the index back into the `date` column.
    return merged.rename_axis('date').reset_index()


def get_regional_data(dataframes, state, district, data_from_tracker, data_format, filename, granular_data=False, 
                      smooth_jump=False, smoothing_length=28, smoothing_method='uniform', t_recov=14, #Smoothing params
                      return_extra=False, val_period=7, **ihme):
    """Load district data, extend it with rows from an IHME run, optionally smooth it.

    Args:
        dataframes: passed through to `get_data` for the tracker and raw-data lookups.
        state, district: region identifiers passed to `get_data` and `single_cycle`.
        data_from_tracker: when true (and `granular_data` is false) the
            'districts_daily' tracker frame is used; otherwise data is read via
            `filename` / `data_format` with the tracker disabled.
        data_format, filename: forwarded to the non-tracker / granular loaders.
        granular_data: when true, data comes from `granular.get_data(filename=...)`.
        smooth_jump, smoothing_length, smoothing_method, t_recov: smoothing
            options; the first three are also forwarded to `single_cycle`.
        return_extra: when true, also return a dict with the smoothing plot axis
            (`'ax'`, None when no smoothing ran) and the data as loaded, before
            the IHME merge and smoothing (`'df_district_unsmoothed'`).
        val_period: number of days after the last observed date to append from
            the IHME result. The default mirrors `single_fitting_cycle`'s
            default; callers using a different window should pass it explicitly.
        **ihme: extra keyword arguments forwarded unchanged to `single_cycle`.

    Returns:
        `(df_district, df_district_raw_data)` or, with `return_extra`,
        `(df_district, df_district_raw_data, extra)`.

    NOTE(review): `get_data` and `granular` are not imported anywhere in the
    visible part of this notebook -- confirm where they come from.
    NOTE(review): the merge appends rows; the original statement could not run
    (invalid `pd.concat` call, result discarded), so confirm this is the intent.
    """
    if granular_data:
        df_district = granular.get_data(filename=filename)
    elif data_from_tracker:
        df_district = get_data(dataframes, state=state, district=district, use_dataframe='districts_daily')
    else:
        df_district = get_data(state=state, district=district, disable_tracker=True, filename=filename, 
                               data_format=data_format)
    
    df_district_raw_data = get_data(dataframes, state=state, district=district, use_dataframe='raw_data')
    ax = None
    # Keep a copy of the data exactly as loaded, before merging or smoothing.
    orig_df_district = copy.copy(df_district)

    # RUN IHME and MERGE DATA
    ihme_res = single_cycle(district, state, smooth_jump=smooth_jump, smoothing_length=smoothing_length,
                            smoothing_method=smoothing_method, **ihme)
    df_district = _append_ihme_rows(df_district, ihme_res['df_district'], val_period)

    if smooth_jump:
        df_district = smooth_big_jump(
            df_district, smoothing_length=smoothing_length, 
            method=smoothing_method, data_from_tracker=data_from_tracker, t_recov=t_recov)
        ax = plot_smoothing(orig_df_district, df_district, state, district, description=f'Smoothing: {smoothing_method}')

    if return_extra:
        extra = {
            'ax': ax,
            'df_district_unsmoothed': orig_df_district
            }
        return df_district, df_district_raw_data, extra 
    return df_district, df_district_raw_data 


In [None]:

def single_fitting_cycle(dataframes, state, district, model=SEIR_Testing, variable_param_ranges=None, #Main 
                         data_from_tracker=True, granular_data=False, filename=None, data_format='new', #Data
                         train_period=7, val_period=7, num_evals=1500, N=1e7, initialisation='starting', #Misc
                         which_compartments=['active', 'total'], #Compartments
                         smooth_jump=False, smoothing_length=28, smoothing_method='uniform', **ihme): #Smoothing
    """Load regional data, run one fitting cycle and return its predictions dict.

    Args:
        dataframes: tracker dataframes, forwarded to `get_regional_data`.
        state, district: region identifiers.
        model: model class; its `__name__` is stored in the recorded run params.
        variable_param_ranges, train_period, num_evals, N, initialisation,
            which_compartments: forwarded to `run_cycle`.
        data_from_tracker, granular_data, filename, data_format: data-source
            options forwarded to `get_regional_data`.
        val_period: forwarded to `data_setup`; 0 is reported as an 'm2' fit,
            anything else as 'm1'.
        smooth_jump, smoothing_length, smoothing_method: smoothing options
            forwarded to `get_regional_data`.
        **ihme: extra keyword arguments forwarded to `get_regional_data`.

    Returns:
        The dict returned by `run_cycle`, extended with the keys
        'smoothing_plot' (only when a smoothing plot exists),
        'df_district_unsmoothed' and 'run_params'.

    NOTE(review): `run_cycle` is not imported in the visible part of this
    notebook -- confirm where it is defined.
    """
    # Record parameters for reproducibility. Take a real snapshot: bare
    # `locals()` can alias the frame's live locals mapping, which may later be
    # refreshed with every other local of this function.
    run_params = dict(locals())
    del run_params['dataframes']
    run_params['model'] = model.__name__
    # Store a copy so the recorded value is decoupled from the shared default list.
    run_params['which_compartments'] = copy.copy(which_compartments)
    
    print('Performing {} fit ..'.format('m2' if val_period == 0 else 'm1'))

    # Get data
    df_district, df_district_raw_data, extra = get_regional_data(
        dataframes, state, district, 
        data_from_tracker, data_format, filename, granular_data, 
        smooth_jump=smooth_jump, smoothing_method=smoothing_method, 
        smoothing_length=smoothing_length, return_extra=True, **ihme
    )
    smoothed_plot = extra['ax']
    orig_df_district = extra['df_district_unsmoothed']

    # Process the data to get rolling averages and other stuff
    observed_dataframes = data_setup(df_district, df_district_raw_data, val_period)

    print('train\n', observed_dataframes['df_train'].tail())
    print('val\n', observed_dataframes['df_val'])
    
    predictions_dict = run_cycle(
        state, district, observed_dataframes, 
        model=model, variable_param_ranges=variable_param_ranges,
        data_from_tracker=data_from_tracker, train_period=train_period, 
        # Pass a copy so downstream code can never mutate the default argument.
        which_compartments=copy.copy(which_compartments), N=N,
        num_evals=num_evals, initialisation=initialisation
    )

    # Identity check: `!= None` would invoke the object's own comparison logic.
    if smoothed_plot is not None:
        predictions_dict['smoothing_plot'] = smoothed_plot
    predictions_dict['df_district_unsmoothed'] = orig_df_district

    # Attach the recorded parameters to the result for reproducibility.
    predictions_dict['run_params'] = run_params

    return predictions_dict


In [None]:
# Driver cell: run the regional data pipeline for Mumbai.
# `cities['mumbai']` unpacks to three values; only `an` (forwarded below as
# `area_names`) is used in this cell.
dist, st, an = cities['mumbai']
from utils.generic.config import read_config
config, model_params = read_config('../../scripts/ihme/config/mumbai.yaml')
print(config)

# Pull the tracker dataframes through the loader object.
dlobj = Covid19IndiaLoader()
dataframes = dlobj.pull_dataframes_cached()

# Left as the final expression so the notebook displays the returned tuple.
get_regional_data(
    dataframes, 'Maharashtra', 'Mumbai',
    data_from_tracker=False, data_format=None, filename=None,
    granular_data=False, return_extra=False,
    area_names=an, model_params=model_params, **config
)