# Importing required libraries

In [34]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import os
from tqdm.notebook import tqdm
from scipy.integrate import solve_ivp
from scipy.optimize import minimize
from sklearn.metrics import mean_squared_log_error, mean_squared_error
import datetime

In [35]:
# Working directory: 'State_population.csv' is read from here and
# 'submission.csv' is written here (both are referenced by relative path).
# Set the COVID_WORKDIR environment variable to run the notebook elsewhere;
# the default keeps the original location so existing setups are unaffected.
os.chdir(os.environ.get('COVID_WORKDIR', '/home/arkadeep/Desktop/COVID'))

## Data processing

In [36]:
# Daily state-wise counts from the covid19india API. The frame is later
# filtered on its 'Status' column and sliced by column position, so the
# column layout returned by the API matters.
# NOTE(review): an explicit `format` is combined with
# `infer_datetime_format=True`; confirm that '%d-%m-%y' really matches the
# date strings in the CSV, because newer pandas versions are stricter here.
state_data = pd.read_csv(
    'https://api.covid19india.org/csv/latest/state_wise_daily.csv'
).assign(
    Date=lambda frame: pd.to_datetime(
        frame['Date'], infer_datetime_format = True, format = '%d-%m-%y'
    )
)

In [37]:
pop_data = pd.read_csv('State_population.csv')

# Abbreviation -> full name; 'TT' (the all-India column of the daily data) is
# added by hand.
abb_dict = {
    abbreviation: full_name
    for abbreviation, full_name in zip(pop_data['Abbreviation'],
                                       pop_data['State or union territory'])
}
abb_dict['TT'] = 'India'

# Abbreviation -> population; the 'TT' entry is the sum over every row of the
# population file (the sum is evaluated before 'TT' itself is inserted).
abb_pop_dict = {
    abbreviation: population
    for abbreviation, population in zip(pop_data['Abbreviation'],
                                        pop_data['Population'])
}
abb_pop_dict.update(TT=sum(abb_pop_dict.values()))

In [38]:
# One frame per status. Each is an explicit copy because new columns are added
# to these frames in the next cell; writing into a plain boolean-mask slice of
# `state_data` triggers pandas' SettingWithCopyWarning.
df_conf = state_data[state_data['Status'] == 'Confirmed'].copy()
df_recov = state_data[state_data['Status'] == 'Recovered'].copy()
df_dead = state_data[state_data['Status'] == 'Deceased'].copy()

In [39]:
# Add a running total ('<code>_cumsum') for every state column of each
# per-status frame.
# The state columns are taken from `state_data` (everything between the two
# leading columns and the last column) rather than from the frame being
# modified, so re-running this cell does not pick up the '_cumsum' columns it
# created on a previous run.
state_columns = list(state_data.columns[2:-1])
frames_with_cumsum = []
for df in [df_conf, df_recov, df_dead]:
    # Work on an independent copy so the assignment below never writes into a
    # slice of `state_data` (the source of the SettingWithCopyWarning).
    df = df.copy()
    for name in state_columns:
        df[name + '_cumsum'] = df[name].cumsum()
    frames_with_cumsum.append(df)
df_conf, df_recov, df_dead = frames_with_cumsum

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  This is separate from the ipykernel package so we can avoid doing imports until


In [45]:
# Long-format table: one row per (area, date) holding the cumulative counts.
DATA_COLUMNS = ['Area', 'Date', 'Id', 'Province_State', 'Country_Region',
                'ConfirmedCases', 'Fatalities', 'Recovered']
# Area codes: every column of the daily data between the two leading columns
# and the last column.
state_codes = list(state_data.columns)[2:-1]


def _build_area_frame(st_name, dates, first_id,
                      confirmed=np.nan, fatalities=np.nan, recovered=np.nan):
    """Return the long-format rows for one area.

    :param st_name: Area code; must be a key of `abb_dict`
    :param dates: Dates, one per row
    :param first_id: Id (and index label) of the first row; later rows count up from it
    :param confirmed: Cumulative confirmed cases per row, or NaN when unknown
    :param fatalities: Cumulative fatalities per row, or NaN when unknown
    :param recovered: Cumulative recoveries per row, or NaN when unknown
    :return: DataFrame with the DATA_COLUMNS columns, indexed by Id
    """
    dates = list(dates)
    ids = list(range(first_id, first_id + len(dates)))
    return pd.DataFrame({
        'Area': st_name,
        'Date': dates,
        'Id': ids,
        'Province_State': abb_dict[st_name],
        'Country_Region': 'India',
        'ConfirmedCases': confirmed,
        'Fatalities': fatalities,
        'Recovered': recovered,
    }, index=ids, columns=DATA_COLUMNS)


# Rows for the existing (observed) data. Frames are collected in a list and
# concatenated once, instead of growing `data` with a concat per area.
observed_frames = []
next_id = 0
for st_name in state_codes:
    area_frame = _build_area_frame(
        st_name, df_conf['Date'], next_id,
        confirmed=list(df_conf[st_name + '_cumsum']),
        fatalities=list(df_dead[st_name + '_cumsum']),
        recovered=list(df_recov[st_name + '_cumsum']),
    )
    observed_frames.append(area_frame)
    next_id += len(area_frame)
data = pd.concat(observed_frames)

valid_days = 7
TEST_MIN_DATE = pd.Timestamp(data['Date'].max() - datetime.timedelta(days= valid_days))
DATE_BORDER = data['Date'].max()
# NOTE(review): `train_full` stops one day BEFORE DATE_BORDER while the private
# forecast dates start one day AFTER it, so the last observed day is not used
# for the final fit. Confirm whether that gap is intentional.
train_full = data[data['Date'] < DATE_BORDER]
train = data[data['Date'] < TEST_MIN_DATE]
valid = data[(data['Date'] >= TEST_MIN_DATE) & (data['Date'] <= DATE_BORDER)]

forecast_days = 10
base = data['Date'].max()
date_list = [pd.Timestamp(base + datetime.timedelta(days=x+1)) for x in range(forecast_days)]

# Rows for the dates to be forecast: same layout, counts left as NaN, Ids
# continuing after the observed rows.
forecast_frames = []
next_id = data.shape[0]
for st_name in state_codes:
    area_frame = _build_area_frame(st_name, date_list, next_id)
    forecast_frames.append(area_frame)
    next_id += len(area_frame)
data = pd.concat([data] + forecast_frames)

test = data[data['Date'] >= TEST_MIN_DATE]
test = test.rename(columns={'Id': 'ForecastId'})

# Split the test into public (dates with known values) & private (future dates)
test_public = test[test['Date'] <= DATE_BORDER]
test_private = test[test['Date'] > DATE_BORDER]

# One row per ForecastId, filled in by the fitting functions
submission = pd.DataFrame(columns = ['ForecastId', 'ConfirmedCases', 'Fatalities'])
submission['ForecastId'] = list(test['ForecastId'])
submission['ConfirmedCases'] = 0
submission['Fatalities'] = 0
submission = submission.set_index(['ForecastId'])

# Use a multi-index for easier slicing. Rebinding (rather than inplace=True)
# avoids mutating frames that were produced by filtering `data`.
train_full = train_full.set_index(['Area', 'Date'])
train = train.set_index(['Area', 'Date'])
valid = valid.set_index(['Area', 'Date'])
test_public = test_public.set_index(['Area', 'Date'])
test_private = test_private.set_index(['Area', 'Date'])



In [48]:
test_public.loc['WB'] 

Unnamed: 0_level_0,ForecastId,Province_State,Country_Region,ConfirmedCases,Fatalities,Recovered
Date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
2020-05-18,2766,West Bengal,India,2824,244,1006
2020-05-19,2767,West Bengal,India,2961,250,1074
2020-05-20,2768,West Bengal,India,3103,253,1136
2020-05-21,2769,West Bengal,India,3197,259,1193
2020-05-22,2770,West Bengal,India,3332,265,1221
2020-05-23,2771,West Bengal,India,3459,269,1281
2020-05-24,2772,West Bengal,India,3667,272,1339
2020-05-25,2773,West Bengal,India,3816,278,1414


In [49]:
test_private.loc['WB']

Unnamed: 0_level_0,ForecastId,Province_State,Country_Region,ConfirmedCases,Fatalities,Recovered
Date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
2020-05-26,3144,West Bengal,India,,,
2020-05-27,3145,West Bengal,India,,,
2020-05-28,3146,West Bengal,India,,,
2020-05-29,3147,West Bengal,India,,,
2020-05-30,3148,West Bengal,India,,,
2020-05-31,3149,West Bengal,India,,,
2020-06-01,3150,West Bengal,India,,,
2020-06-02,3151,West Bengal,India,,,
2020-06-03,3152,West Bengal,India,,,
2020-06-04,3153,West Bengal,India,,,


# ANALYSIS STARTS: 

In [50]:
# Susceptible equation
def dS_dt(S, I, R_t, t_inf):
    """Rate of change of the susceptible compartment: -(R_t / t_inf) * I * S."""
    transmission_rate = R_t / t_inf
    return -transmission_rate * I * S


# Exposed equation
def dE_dt(S, E, I, R_t, t_inf, t_inc):
    """Rate of change of the exposed compartment: inflow from S minus E / t_inc."""
    inflow = (R_t / t_inf) * I * S
    outflow = E / t_inc
    return inflow - outflow


# Infected equation
def dI_dt(I, E, t_inc, t_inf):
    """Rate of change of the infected compartment: E / t_inc minus I / t_inf."""
    inflow = E / t_inc
    outflow = I / t_inf
    return inflow - outflow


# Hospitalized equation
def dH_dt(I, C, H, t_inf, t_hosp, t_crit, m_a, f_a):
    """Rate of change of the hospitalised compartment.

    Inflow is the (1 - m_a) share of I / t_inf plus the (1 - f_a) share of
    C / t_crit; outflow is H / t_hosp.
    """
    from_infected = (1 - m_a) * (I / t_inf)
    from_critical = (1 - f_a) * C / t_crit
    leaving = H / t_hosp
    return from_infected + from_critical - leaving


# Critical equation
def dC_dt(H, C, t_hosp, t_crit, c_a):
    """Rate of change of the critical compartment: c_a * H / t_hosp minus C / t_crit."""
    inflow = c_a * H / t_hosp
    outflow = C / t_crit
    return inflow - outflow


# Recovered equation
def dR_dt(I, H, t_inf, t_hosp, m_a, c_a):
    """Rate of change of the recovered compartment.

    The m_a share of I / t_inf plus the (1 - c_a) share of H / t_hosp.
    """
    from_infected = m_a * I / t_inf
    from_hospital = (1 - c_a) * (H / t_hosp)
    return from_infected + from_hospital


# Deaths equation
def dD_dt(C, t_crit, f_a):
    """Rate of change of the deceased compartment: the f_a share of C / t_crit."""
    fatal_critical = f_a * C
    return fatal_critical / t_crit


def SEIR_HCD_model(t, y, R_t, t_inc=2.9, t_inf=5.2, t_hosp=4, t_crit=14, m_a=0.8, c_a=0.1, f_a=0.3):
    """
    Right-hand side of the SEIR-HCD system, in the (t, y, *args) form used with solve_ivp.

    :param t: Current time supplied by the solver
    :param y: State vector ordered as (S, E, I, R, H, C, D)
    :param R_t: Reproduction number; a constant, or a callable evaluated at t
    :param t_inc: Incubation time scale (E leaves at rate E / t_inc). Default 2.9
    :param t_inf: Infectious time scale (I leaves at rate I / t_inf). Default 5.2
    :param t_hosp: Time scale for leaving the hospitalised compartment. Default 4
    :param t_crit: Time scale for leaving the critical compartment. Default 14
    :param m_a: Fraction of infections that are asymptomatic or mild. Default 0.8
    :param c_a: Fraction of severe cases that turn critical. Default 0.1
    :param f_a: Fraction of critical cases that are fatal. Default 0.3
    :return: List of time derivatives in the same order as y

    Note: the eval_model_* functions in this notebook pass t_inc and t_inf
    explicitly (5.6 and 2.9), so these two defaults are not used by them.
    """
    # Resolve the reproduction number for this time step
    reprod = R_t(t) if callable(R_t) else R_t

    S, E, I, R, H, C, D = y

    return [
        dS_dt(S, I, reprod, t_inf),
        dE_dt(S, E, I, reprod, t_inf, t_inc),
        dI_dt(I, E, t_inc, t_inf),
        dR_dt(I, H, t_inf, t_hosp, m_a, c_a),
        dH_dt(I, C, H, t_inf, t_hosp, t_crit, m_a, f_a),
        dC_dt(H, C, t_hosp, t_crit, c_a),
        dD_dt(C, t_crit, f_a),
    ]

In [51]:
def plot_model(solution, title='SEIR+HCD model'):
    """Plot every compartment of a solved model, plus cases against deaths.

    :param solution: Object whose `.y` unpacks into the seven compartment
        series (S, E, I, R, H, C, D), e.g. the result of solve_ivp
    :param title: Figure title
    """
    sus, exp, inf, rec, hosp, crit, death = solution.y

    # "Cases" are everyone who has been infectious at some point
    cases = inf + rec + hosp + crit + death

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16,5))
    fig.suptitle(title)

    # Left panel: all compartments
    compartments = [
        (sus, 'tab:blue', 'Susceptible'),
        (exp, 'tab:orange', 'Exposed'),
        (inf, 'tab:red', 'Infected'),
        (rec, 'tab:green', 'Recovered'),
        (hosp, 'tab:purple', 'Hospitalised'),
        (crit, 'tab:brown', 'Critical'),
        (death, 'tab:cyan', 'Deceased'),
    ]
    for series, colour, label in compartments:
        ax1.plot(series, colour, label=label)

    ax1.set_xlabel("Days", fontsize=10)
    ax1.set_ylabel("Fraction of population", fontsize=10)
    ax1.legend(loc='best')

    # Right panel: cases on the left axis, deaths on a twin right axis
    ax2.plot(cases, 'tab:red', label='Cases')
    ax2.set_xlabel("Days", fontsize=10)
    ax2.set_ylabel("Fraction of population (Cases)", fontsize=10, color='tab:red')

    ax3 = ax2.twinx()
    ax3.plot(death, 'tab:cyan', label='Deceased')
    ax3.set_xlabel("Days", fontsize=10)
    ax3.set_ylabel("Fraction of population (Fatalities)", fontsize=10, color='tab:cyan')


In [52]:
# Length of the trailing window scored by the eval_model_* functions; they cap
# it at the number of rows in the data they are given.
OPTIM_DAYS = 21  # Number of days to use for the optimisation evaluation

In [53]:
# Use a constant reproduction number
def eval_model_const(params, data, population, return_solution=False, forecast_days=0):
    """
    Score a SEIR-HCD run with a constant reproduction number against observed data.

    :param params: (R_0, t_hosp, t_crit, m, c, f) as optimised by `minimize`
    :param data: Frame with cumulative 'ConfirmedCases' and 'Fatalities' columns,
        one row per day; its first row seeds the initial infected fraction
    :param population: Population of the area; compartments are fractions of it
    :param return_solution: When True, also return the solve_ivp result
    :param forecast_days: Extra days to integrate beyond the observed rows
    :return: Mean of the weighted MSLE for cases and for fatalities over the last
        min(OPTIM_DAYS, len(data)) observed days; followed by the solution when
        `return_solution` is True
    :raises IndexError: If `data` has no rows
    """
    R_0, t_hosp, t_crit, m, c, f = params
    N = population
    n_observed = len(data)
    n_infected = data['ConfirmedCases'].iloc[0]
    max_days = n_observed + forecast_days
    initial_state = [(N - n_infected)/ N, 0, n_infected / N, 0, 0, 0, 0]
    # Positional order follows SEIR_HCD_model: R_t, t_inc, t_inf, t_hosp, ...
    args = (R_0, 5.6, 2.9, t_hosp, t_crit, m, c, f)

    sol = solve_ivp(SEIR_HCD_model, [0, max_days], initial_state, args=args, t_eval=np.arange(0, max_days))

    sus, exp, inf, rec, hosp, crit, deaths = sol.y

    # The solution also covers the forecast days; keep only the observed span so
    # the trailing window below lines up day-for-day with the true values.
    y_pred_cases = (np.clip(inf + rec + hosp + crit + deaths, 0, np.inf) * population)[:n_observed]
    y_true_cases = data['ConfirmedCases'].values.astype(float)
    y_pred_fat = (np.clip(deaths, 0, np.inf) * population)[:n_observed]
    y_true_fat = data['Fatalities'].values.astype(float)

    optim_days = min(OPTIM_DAYS, n_observed)  # Days to optimise for
    weights = 1 / np.arange(1, optim_days+1)[::-1]  # Recent data is more heavily weighted
    # sample_weight is passed by keyword: it is keyword-only in current scikit-learn
    msle_cases = mean_squared_log_error(y_true_cases[-optim_days:], y_pred_cases[-optim_days:],
                                        sample_weight=weights)
    msle_fat = mean_squared_log_error(y_true_fat[-optim_days:], y_pred_fat[-optim_days:],
                                      sample_weight=weights)

    msle_final = np.mean([msle_cases, msle_fat])

    if return_solution:
        return msle_final, sol
    return msle_final

In [54]:
# Use a Hill decayed reproduction number
def eval_model_decay(params, data, population, return_solution=False, forecast_days=0):
    """
    Score a SEIR-HCD run with a Hill-decayed reproduction number against observed data.

    :param params: (R_0, t_hosp, t_crit, m, c, f, k, L); k and L shape the decay
        R(t) = R_0 / (1 + (t / L) ** k)
    :param data: Frame with cumulative 'ConfirmedCases' and 'Fatalities' columns,
        one row per day; its first row seeds the initial infected fraction
    :param population: Population of the area; compartments are fractions of it
    :param return_solution: When True, also return the solve_ivp result
    :param forecast_days: Extra days to integrate beyond the observed rows
    :return: Mean of the weighted MSLE for cases and for fatalities over the last
        min(OPTIM_DAYS, len(data)) observed days; followed by the solution when
        `return_solution` is True
    :raises IndexError: If `data` has no rows
    """
    R_0, t_hosp, t_crit, m, c, f, k, L = params
    N = population
    n_observed = len(data)
    n_infected = data['ConfirmedCases'].iloc[0]
    max_days = n_observed + forecast_days

    # https://github.com/SwissTPH/openmalaria/wiki/ModelDecayFunctions
    # Hill decay. Initial values: R_0=2.2, k=2, L=50
    def time_varying_reproduction(t):
        return R_0 / (1 + (t/L)**k)

    initial_state = [(N - n_infected)/ N, 0, n_infected / N, 0, 0, 0, 0]
    # Positional order follows SEIR_HCD_model: R_t, t_inc, t_inf, t_hosp, ...
    args = (time_varying_reproduction, 5.6, 2.9, t_hosp, t_crit, m, c, f)

    sol = solve_ivp(SEIR_HCD_model, [0, max_days], initial_state, args=args, t_eval=np.arange(0, max_days))

    sus, exp, inf, rec, hosp, crit, deaths = sol.y

    # The solution also covers the forecast days; keep only the observed span so
    # the trailing window below lines up day-for-day with the true values.
    y_pred_cases = (np.clip(inf + rec + hosp + crit + deaths, 0, np.inf) * population)[:n_observed]
    y_true_cases = data['ConfirmedCases'].values.astype(float)
    y_pred_fat = (np.clip(deaths, 0, np.inf) * population)[:n_observed]
    y_true_fat = data['Fatalities'].values.astype(float)

    optim_days = min(OPTIM_DAYS, n_observed)  # Days to optimise for
    weights = 1 / np.arange(1, optim_days+1)[::-1]  # Recent data is more heavily weighted

    # sample_weight is passed by keyword: it is keyword-only in current scikit-learn
    msle_cases = mean_squared_log_error(y_true_cases[-optim_days:], y_pred_cases[-optim_days:],
                                        sample_weight=weights)
    msle_fat = mean_squared_log_error(y_true_fat[-optim_days:], y_pred_fat[-optim_days:],
                                      sample_weight=weights)
    msle_final = np.mean([msle_cases, msle_fat])

    if return_solution:
        return msle_final, sol
    return msle_final

In [55]:
def use_last_value(train_data, valid_data, test_data):
    """Fallback forecast: repeat the last training values for every test row.

    Writes the last 'ConfirmedCases' / 'Fatalities' pair of `train_data` into
    the global `submission` for each ForecastId of `test_data`.

    :param train_data: Frame with 'ConfirmedCases' and 'Fatalities' columns
    :param valid_data: Frame with the same columns to score against, or None
    :param test_data: Frame with a 'ForecastId' column
    :return: Mean of the two MSLE values on `valid_data`; None when
        `valid_data` is None
    """
    targets = ['ConfirmedCases', 'Fatalities']
    lv = train_data[targets].iloc[-1].values

    submission.loc[test_data['ForecastId'], targets] = lv

    if valid_data is None:
        return None

    # Constant prediction: one (cases, fatalities) row per validation day
    y_true_valid = valid_data[targets]
    y_pred_valid = np.ones((len(valid_data), 2)) * lv.reshape(1, 2)

    per_target_msle = [
        mean_squared_log_error(y_true_valid['ConfirmedCases'], y_pred_valid[:, 0]),
        mean_squared_log_error(y_true_valid['Fatalities'], y_pred_valid[:, 1]),
    ]
    return np.mean(per_target_msle)

In [68]:
def plot_model_results(y_pred, train_data, state, valid_data=None, res = None, valid_mlse = None,
                       plot_dir='/home/arkadeep/Desktop/COVID/SIER_plots/'):
    """Plot modelled against observed cases and fatalities and save the figure.

    :param y_pred: Frame with 'ConfirmedCases', 'Fatalities' and 'R' columns,
        indexed so that it can be selected with the indexes of `train_data`
        and `valid_data`
    :param train_data: Observed training frame ('ConfirmedCases', 'Fatalities')
    :param state: Area code; looked up in `abb_dict` for titles and file name
    :param valid_data: Optional observed validation frame; when given the plot
        is saved as the '(Train and Val)' variant, otherwise as '(With forecast)'
    :param res: Optional optimiser result; res.x[0:6] are shown in the title as
        R, t_hosp, t_crit, m, c, f
    :param valid_mlse: Optional validation score appended to the title when
        `valid_data` is given
    :param plot_dir: Directory the PNG is written to; created if missing
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16,8))
    area_label = str(abb_dict[state])

    train_data['ConfirmedCases'].plot(label='Confirmed Cases (train)', color='g', ax=ax1)
    y_pred.loc[train_data.index, 'ConfirmedCases'].plot(label='Modeled Cases', color='r', ax=ax1)
    ax3 = y_pred['R'].plot(label='Reproduction number', color='c', linestyle='-', secondary_y=True, ax=ax1)
    ax3.set_ylabel("Reproduction number", fontsize=10, color='c')

    train_data['Fatalities'].plot(label='Fatalities (train)', color='g', ax=ax2)
    y_pred.loc[train_data.index, 'Fatalities'].plot(label='Modeled Fatalities', color='r', ax=ax2)

    ax1.minorticks_on()
    ax1.yaxis.grid(True, which = 'both')
    ax2.minorticks_on()
    ax2.yaxis.grid(True, which = 'both')

    # Build the title from whatever is available, so a missing `res` or
    # `valid_mlse` no longer raises when `valid_data` is supplied.
    title_details = []
    if res is not None:
        title_details.append(f'R: {res.x[0]:0.3f}, t_hosp: {res.x[1]:0.3f}, t_crit: {res.x[2]:0.3f}, '
                             f'm: {res.x[3]:0.3f}, c: {res.x[4]:0.3f}, f: {res.x[5]:0.3f}')
    if valid_data is not None and valid_mlse is not None:
        title_details.append(f'Val_mlse: {valid_mlse:0.3f}')
    title = 'For ' + area_label
    if title_details:
        title += ' ' + ', '.join(title_details)
    fig.suptitle(title, fontsize = 18)

    if valid_data is not None:
        variant = '(Train and Val)'
        valid_data['ConfirmedCases'].plot(label='Confirmed Cases (valid)', color='g', linestyle=':', ax=ax1)
        valid_data['Fatalities'].plot(label='Fatalities (valid)', color='g', linestyle=':', ax=ax2)
        y_pred.loc[valid_data.index, 'ConfirmedCases'].plot(label='Modeled Cases (forecast)', color='r', linestyle=':', ax=ax1)
        y_pred.loc[valid_data.index, 'Fatalities'].plot(label='Modeled Fatalities (forecast)', color='r', linestyle=':', ax=ax2)
    else:
        variant = '(With forecast)'
        y_pred.loc[:, 'ConfirmedCases'].plot(label='Modeled Cases (forecast)', color='r', linestyle=':', ax=ax1)
        y_pred.loc[:, 'Fatalities'].plot(label='Modeled Fatalities (forecast)', color='r', linestyle=':', ax=ax2)

    ax1.set_title('Confirmed Cases for ' + area_label + variant)
    ax2.set_title('Fatalities for ' + area_label + variant)
    ax1.legend(loc='best')

    # savefig does not create directories, so make sure the target exists
    os.makedirs(plot_dir, exist_ok=True)
    fig.savefig(os.path.join(plot_dir, 'Plot for ' + area_label + variant + '.png'), dpi = 200)
    # Close this specific figure so looping over many areas does not accumulate figures
    plt.close(fig)

In [62]:
def fit_model_public(area_name, 
                     initial_guess=[3.6, 4, 14, 0.8, 0.1, 0.3, 2, 50],
                     bounds=((1, 20), # R_0 bounds
                             (0.5, 10), (2, 20), # t_hosp and t_crit bounds
                             (0.5, 1), (0, 1), (0, 1), (1, 5), (1, 100)), # m, c, f bounds, then Hill-decay k and L bounds
                     make_plot=True):
    """
    Fit both model variants on the `train` split for one area and score on `valid`.

    Reads the module-level `train`, `valid`, `test_public` and `abb_pop_dict`,
    and writes the forecast for the area's public test rows into `submission`.

    :param area_name: Area code used to select rows from the multi-indexed splits
    :param initial_guess: Start values (R_0, t_hosp, t_crit, m, c, f, k, L); the
        constant-R fit uses all but the last two
    :param bounds: One (low, high) pair per entry of `initial_guess`
    :param make_plot: When True, print the fitted parameters and save a plot
    :return: Mean of the validation MSLE for cases and fatalities; the result of
        `use_last_value` when the area has fewer than one case per million;
        None when the area has no population entry
    :raises IndexError: When the area has no training rows with ConfirmedCases > 0
        (the driver loops in this notebook catch this)
    """
        
    # Only the days from the first confirmed case onwards are modelled
    train_data = train.loc[area_name].query('ConfirmedCases > 0')
    valid_data = valid.loc[area_name]
    test_data = test_public.loc[area_name]  
    
    try:
        population = abb_pop_dict[area_name]
    except KeyError:
        print("Key not found")
        return
        
    cases_per_million = train_data['ConfirmedCases'].max() * 10**6 / population
    # Not used below; on an empty `train_data` this .iloc[0] raises IndexError
    n_infected = train_data['ConfirmedCases'].iloc[0]
        
    # Too few cases to fit a model: carry the last observed values forward
    if cases_per_million < 1:
        return use_last_value(train_data, valid_data, test_data)
                
    res_const = minimize(eval_model_const, initial_guess[:-2], bounds=bounds[:-2],
                         args=(train_data, population, False),
                         method='L-BFGS-B')
    
    res_decay = minimize(eval_model_decay, initial_guess, bounds=bounds,
                         args=(train_data, population, False),
                         method='L-BFGS-B')
    
    # NOTE(review): `R_t` below is indexed by train+valid dates while `y_pred`
    # is indexed by train+test dates; presumably both cover the same dates,
    # since `valid` and `test_public` are cut from the same window - confirm.
    dates_all = train_data.index.append(test_data.index)
    dates_val = train_data.index.append(valid_data.index)
    
    
    # If using a constant R number is better, use that model
    if res_const.fun < res_decay.fun:
        msle, sol = eval_model_const(res_const.x, train_data, population, True, len(test_data))
        res = res_const
        R_t = pd.Series([res_const.x[0]] * len(dates_val), dates_val)
    else:
        msle, sol = eval_model_decay(res_decay.x, train_data, population, True, len(test_data))
        res = res_decay
        
        # Calculate the R_t values
        t = np.arange(len(dates_val))
        R_0, t_hosp, t_crit, m, c, f, k, L = res.x  
        R_t = pd.Series(R_0 / (1 + (t/L)**k), dates_val)
        
    sus, exp, inf, rec, hosp, crit, deaths = sol.y
    
    # Convert population fractions back to head counts
    y_pred = pd.DataFrame({
        'ConfirmedCases': np.clip(inf + rec + hosp + crit + deaths, 0, np.inf) * population,
        'Fatalities': np.clip(deaths, 0, np.inf) * population,
        'R': R_t,
    }, index=dates_all)
    
    # Rows after the training span are the forecast
    y_pred_valid = y_pred.iloc[len(train_data): len(train_data)+len(valid_data)]
    y_pred_test = y_pred.iloc[len(train_data):]
    y_true_valid = valid_data[['ConfirmedCases', 'Fatalities']]
        
    valid_msle_cases = mean_squared_log_error(y_true_valid['ConfirmedCases'], y_pred_valid['ConfirmedCases'])
    valid_msle_fat = mean_squared_log_error(y_true_valid['Fatalities'], y_pred_valid['Fatalities'])
    valid_msle = np.mean([valid_msle_cases, valid_msle_fat])
    
    if make_plot:
        print(f'State: {abb_dict[area_name]}')
        print(f'Validation MSLE: {valid_msle:0.5f}')
        print(f'R: {res.x[0]:0.3f}, t_hosp: {res.x[1]:0.3f}, t_crit: {res.x[2]:0.3f}, '
              f'm: {res.x[3]:0.3f}, c: {res.x[4]:0.3f}, f: {res.x[5]:0.3f}')
        plot_model_results(y_pred, train_data, area_name, valid_data, res , valid_msle)
        
    # Put the forecast in the submission
    forecast_ids = test_data['ForecastId']
    submission.loc[forecast_ids, ['ConfirmedCases', 'Fatalities']] = y_pred_test[['ConfirmedCases', 'Fatalities']].values
        
    return valid_msle
            

In [63]:
# Fit a model on the full dataset (i.e. no validation)
def fit_model_private(area_name, 
                      initial_guess=[3.6, 4, 14, 0.8, 0.1, 0.3, 2, 50],
                      bounds=((1, 20), # R_0 bounds
                              (0.5, 10), (2, 20), # t_hosp and t_crit bounds
                              (0.5, 1), (0, 1), (0, 1), (1, 5), (1, 100)), # m, c, f bounds, then Hill-decay k and L bounds
                      make_plot=True):
    """
    Fit both model variants on `train_full` for one area and forecast the private test rows.

    Reads the module-level `train_full`, `test_private` and `abb_pop_dict`, and
    writes the forecast into `submission`. No validation score is computed.

    :param area_name: Area code used to select rows from the multi-indexed splits
    :param initial_guess: Start values (R_0, t_hosp, t_crit, m, c, f, k, L); the
        constant-R fit uses all but the last two
    :param bounds: One (low, high) pair per entry of `initial_guess`
    :param make_plot: When True, print the fitted parameters and save a plot
    :return: None (there is no return statement on the normal path; the
        fallback returns the result of `use_last_value` called without
        validation data)
    :raises IndexError: When the area has no training rows with ConfirmedCases > 0
        (the driver loops in this notebook catch this)
    """
        
    # Only the days from the first confirmed case onwards are modelled
    train_data = train_full.loc[area_name].query('ConfirmedCases > 0')
    test_data = test_private.loc[area_name]
    
    try:
        population = abb_pop_dict[area_name]
    except KeyError:
        print("Key not found")
        return
        
    cases_per_million = train_data['ConfirmedCases'].max() * 10**6 / population
    # Not used below; on an empty `train_data` this .iloc[0] raises IndexError
    n_infected = train_data['ConfirmedCases'].iloc[0]
        
    # Too few cases to fit a model: carry the last observed values forward
    if cases_per_million < 1:
        return use_last_value(train_data, None, test_data)
                
    res_const = minimize(eval_model_const, initial_guess[:-2], bounds=bounds[:-2],
                         args=(train_data, population, False),
                         method='L-BFGS-B')
    
    res_decay = minimize(eval_model_decay, initial_guess, bounds=bounds,
                         args=(train_data, population, False),
                         method='L-BFGS-B')
    
    dates_all = train_data.index.append(test_data.index)
    
    
    # If using a constant R number is better, use that model
    if res_const.fun < res_decay.fun:
        msle, sol = eval_model_const(res_const.x, train_data, population, True, len(test_data))
        res = res_const
        R_t = pd.Series([res_const.x[0]] * len(dates_all), dates_all)
    else:
        msle, sol = eval_model_decay(res_decay.x, train_data, population, True, len(test_data))
        res = res_decay
        
        # Calculate the R_t values
        t = np.arange(len(dates_all))
        R_0, t_hosp, t_crit, m, c, f, k, L = res.x  
        R_t = pd.Series(R_0 / (1 + (t/L)**k), dates_all)
        
    sus, exp, inf, rec, hosp, crit, deaths = sol.y
    
    # Convert population fractions back to head counts
    y_pred = pd.DataFrame({
        'ConfirmedCases': np.clip(inf + rec + hosp + crit + deaths, 0, np.inf) * population,
        'Fatalities': np.clip(deaths, 0, np.inf) * population,
        'R': R_t,
    }, index=dates_all)
    
    # Rows after the training span are the forecast
    y_pred_test = y_pred.iloc[len(train_data):]
    
    if make_plot:
        print(f'R: {res.x[0]:0.3f}, t_hosp: {res.x[1]:0.3f}, t_crit: {res.x[2]:0.3f}, '
              f'm: {res.x[3]:0.3f}, c: {res.x[4]:0.3f}, f: {res.x[5]:0.3f}')
        plot_model_results(y_pred, train_data, area_name, res = res)
        
    # Put the forecast in the submission
    forecast_ids = test_data['ForecastId']
    submission.loc[forecast_ids, ['ConfirmedCases', 'Fatalities']] = y_pred_test[['ConfirmedCases', 'Fatalities']].values
    

In [69]:
#Testing the functions:

fit_model_public('TT')

State: India
Validation MSLE: 0.02375
R: 2.838, t_hosp: 3.347, t_crit: 13.786, m: 0.506, c: 0.187, f: 1.000


0.023750610159423455

In [65]:
# Public Leaderboard: fit every area on `train` and collect its validation score
validation_records = []

for c in tqdm(test_public.index.levels[0].values):
    try:
        score = fit_model_public(c, make_plot=False)
    except IndexError:
        # Raised when the area has no training rows with a positive case count
        print(abb_dict[c], 'has no cases in train')
        continue
    except ValueError as e:
        print(abb_dict[c], e)
        continue

    if score is None:
        # fit_model_public returns None when the area has no population entry;
        # formatting None with ':0.5f' would raise a TypeError.
        print(abb_dict[c], 'could not be scored')
        continue

    validation_records.append({'State': abb_dict[c], 'MSLE': score})
    print(f'{score:0.5f} {abb_dict[c]}')

# Explicit columns keep the frame well-formed even if no area could be scored
validation_scores = pd.DataFrame(validation_records, columns=['State', 'MSLE'])
print(f'Mean validation score: {np.sqrt(validation_scores["MSLE"].mean()):0.5f}')

HBox(children=(FloatProgress(value=0.0, max=38.0), HTML(value='')))

0.00000 Andaman and Nicobar Islands
0.00953 Andhra Pradesh
0.02055 Arunachal Pradesh
1.31568 Assam
0.06013 Bihar
1.04685 Chandigarh
0.31391 Chhattisgarh
Daman and Diu has no cases in train
0.00242 Delhi
0.01908 Dadra and Nagar Haveli 
0.65360 Goa
0.00135 Gujarat
0.22987 Himachal Pradesh
0.01696 Haryana
0.11851 Jharkhand
0.01632 Jammu and Kashmir
0.07531 Karnataka
0.04137 Kerala
0.00647 Ladakh
Lakshadweep has no cases in train
0.03301 Maharashtra
0.00966 Meghalaya
1.33780 Manipur[c]
0.16527 Madhya Pradesh
0.00000 Mizoram
Nagaland has no cases in train
0.03369 Odisha
0.03238 Punjab
0.11914 Puducherry
0.04021 Rajasthan
Sikkim has no cases in train
0.05163 Telangana
0.04617 Tamil Nadu
0.15460 Tripura
0.02375 India
0.01012 Uttar Pradesh
0.69698 Uttarakhand
0.04741 West Bengal

Mean validation score: 0.44556


In [66]:
# Find which areas are not being predicted well
validation_scores.sort_values(by=['MSLE'], ascending=False).head(30)

Unnamed: 0,State,MSLE
20,Manipur[c],1.337805
3,Assam,1.315677
5,Chandigarh,1.046848
32,Uttarakhand,0.696978
9,Goa,0.653601
6,Chhattisgarh,0.313906
11,Himachal Pradesh,0.229872
21,Madhya Pradesh,0.165268
29,Tripura,0.154605
25,Puducherry,0.119144


In [67]:
# Private Leaderboard: fit every area on `train_full`; areas that cannot be
# fitted are dropped from the working copy of the abbreviation map.
abb_dict_2 = dict(abb_dict)
for c in tqdm(test_private.index.levels[0].values):
    try:
        score = fit_model_private(c, make_plot=False)
    except IndexError:
        print(abb_dict[c], 'has no cases in train')
        abb_dict_2.pop(c)

HBox(children=(FloatProgress(value=0.0, max=38.0), HTML(value='')))

Daman and Diu has no cases in train
Lakshadweep has no cases in train
Nagaland has no cases in train



In [83]:
# 'SK' has no rows with ConfirmedCases > 0 in `train` (see the empty query
# result further down), so fit_model_public cannot be run for it. A default is
# supplied so re-running this cell does not raise KeyError once 'SK' is gone.
abb_dict_2.pop('SK', None)

'Sikkim'

In [84]:
# Fit and plot both the validation (public) and full-data (private) models for
# every remaining area. Each fit is guarded separately so that one area with no
# positive-case training rows does not abort the whole loop.
for name in abb_dict_2.keys():
    print("For the state: ", abb_dict[name])
    for fit_area in (fit_model_public, fit_model_private):
        try:
            fit_area(name)
        except IndexError:
            print(abb_dict[name], 'has no cases in train')

For the state:  Uttar Pradesh
State: Uttar Pradesh
Validation MSLE: 0.01012
R: 2.310, t_hosp: 3.840, t_crit: 13.955, m: 0.570, c: 0.173, f: 0.724
R: 2.640, t_hosp: 3.933, t_crit: 13.988, m: 0.511, c: 0.163, f: 0.698
For the state:  Maharashtra
State: Maharashtra
Validation MSLE: 0.03301
R: 2.953, t_hosp: 3.902, t_crit: 13.962, m: 0.502, c: 0.318, f: 0.755
R: 3.010, t_hosp: 3.792, t_crit: 13.926, m: 0.500, c: 0.195, f: 0.997
For the state:  Bihar
State: Bihar
Validation MSLE: 0.06013
R: 2.886, t_hosp: 3.669, t_crit: 13.909, m: 0.501, c: 0.056, f: 0.810
R: 2.861, t_hosp: 3.967, t_crit: 3.778, m: 0.989, c: 1.000, f: 1.000
For the state:  West Bengal
State: West Bengal
Validation MSLE: 0.04741
R: 2.762, t_hosp: 0.500, t_crit: 13.183, m: 0.500, c: 0.441, f: 0.976
R: 2.897, t_hosp: 1.956, t_crit: 13.132, m: 0.790, c: 0.947, f: 1.000
For the state:  Madhya Pradesh
State: Madhya Pradesh
Validation MSLE: 0.16527
R: 2.877, t_hosp: 3.897, t_crit: 13.951, m: 0.531, c: 0.493, f: 0.864
R: 4.457, t_h

In [85]:
# Write the rounded forecasts, indexed by ForecastId, to the working directory
submission.round().to_csv('submission.csv')

In [86]:
state_data.columns


Index(['Date', 'Status', 'TT', 'AN', 'AP', 'AR', 'AS', 'BR', 'CH', 'CT', 'DN',
       'DD', 'DL', 'GA', 'GJ', 'HR', 'HP', 'JK', 'JH', 'KA', 'KL', 'LA', 'LD',
       'MP', 'MH', 'MN', 'ML', 'MZ', 'NL', 'OR', 'PY', 'PB', 'RJ', 'SK', 'TN',
       'TG', 'TR', 'UP', 'UT', 'WB', 'UN'],
      dtype='object')

In [88]:
train_full.loc['SK']

Unnamed: 0_level_0,Id,Province_State,Country_Region,ConfirmedCases,Fatalities,Recovered
Date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
2020-03-14,2263,Sikkim,India,0,0,0
2020-03-15,2264,Sikkim,India,0,0,0
2020-03-16,2265,Sikkim,India,0,0,0
2020-03-17,2266,Sikkim,India,0,0,0
2020-03-18,2267,Sikkim,India,0,0,0
...,...,...,...,...,...,...
2020-05-20,2330,Sikkim,India,0,0,0
2020-05-21,2331,Sikkim,India,0,0,0
2020-05-22,2332,Sikkim,India,0,0,0
2020-05-23,2333,Sikkim,India,1,0,0


In [89]:
abb_dict_2.keys()

dict_keys(['UP', 'MH', 'BR', 'WB', 'MP', 'TN', 'RJ', 'KA', 'GJ', 'AP', 'OR', 'TG', 'KL', 'JH', 'AS', 'PB', 'CT', 'HR', 'UT', 'HP', 'TR', 'ML', 'MN', 'GA', 'AR', 'MZ', 'DL', 'JK', 'PY', 'CH', 'AN', 'DN', 'LA', 'TT'])

In [92]:
train.loc['SK'].query('ConfirmedCases > 0')

Unnamed: 0_level_0,Id,Province_State,Country_Region,ConfirmedCases,Fatalities,Recovered
Date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
