In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import matplotlib
import numpy as np
import pandas as pd
from tqdm import tqdm
import multiprocessing as mp
import matplotlib.pyplot as plt
import pickle
import multiprocessing as mp
import arviz as az

matplotlib.rcParams.update({'font.size': 18})
from joblib import delayed, Parallel
from collections import defaultdict, OrderedDict
import pymc3 as pm
from pymc3.ode import DifferentialEquation
from utils.generic import init_params
from main.seir.optimiser import Optimiser
from models.seir.seir_testing import SEIR_Testing
from data.processing import get_district_time_series
from data.dataloader import get_covid19india_api_data
from theano.ifelse import ifelse
from theano import tensor as T
from theano import tensor as T, function, printing
from theano import function
import theano
theano.config.compute_test_value='ignore'
theano.config.gcc.cxxflags = "-Wno-c++11-narrowing"



## Load COVID-19 data

In [None]:
dataframes = get_covid19india_api_data()

In [None]:
dataframes.keys()

In [None]:
# Candidate (state, district) pairs; Delhi is listed with an empty district string.
regions = [
    ('Delhi', ''),
    ('Karnataka', 'Bengaluru Urban'),
    ('Maharashtra', 'Mumbai'),
    ('Maharashtra', 'Pune'),
    ('Gujarat', 'Ahmedabad'),
    ('Rajasthan', 'Jaipur'),
]
# Index 2 selects ('Maharashtra', 'Mumbai').
state, district = regions[2]
df_district = get_district_time_series(
    dataframes,
    state=state,
    district=district,
    use_dataframe='districts_daily',
)

## Create train-val splits

In [None]:
# Hold out the last five rows for validation; every earlier row is training data.
df_train = df_district.iloc[:-5]
df_val = df_district.iloc[-5:]

In [None]:
df_train, df_val

In [None]:
# Write both splits to CSV files in the current working directory.
df_train.to_csv('df_train.csv')
df_val.to_csv('df_val.csv')

## Loss Calculation Functions

In [None]:
def _calc_rmse(y_pred, y_true, log=True):
    if log:
        y_true = np.log(y_true)
        y_pred = np.log(y_pred)
    loss = np.sqrt(np.mean((y_true - y_pred)**2))
    return loss

def _calc_mape(y_pred, y_true):
    y_pred = y_pred[y_true > 0]
    y_true = y_true[y_true > 0]

    ape = np.abs((y_true - y_pred + 0) / y_true) *  100
    loss = np.mean(ape)
    return loss

def calc_loss_dict(states_time_matrix, df, method='rmse', rmse_log=False):
    """Score the predicted total case count against ``df['total_infected']``.

    The predicted total is the sum of rows 6, 7 and 8 (hospitalisation
    compartments), row 9 (recoveries) and row 10 (fatalities) of
    ``states_time_matrix``.

    Args:
        states_time_matrix: Indexable by state row; rows 6-10 are read.
        df: Object whose ``'total_infected'`` entry holds the observed totals.
        method: ``'rmse'`` or ``'mape'``.
        rmse_log: For ``method='rmse'`` only; when truthy the RMSE is taken on
            log-transformed values.

    Returns:
        dict with the single key ``'total'``. Per-compartment losses
        (hospitalised / recovered / fatalities / active infections) are
        currently disabled.

    Raises:
        ValueError: If ``method`` is neither ``'rmse'`` nor ``'mape'``.
    """
    # Validate up front so an unsupported method fails with a clear message
    # rather than an unbound-local error further down.
    if method == 'rmse':
        use_log = bool(rmse_log)
        calculate = lambda x, y: _calc_rmse(x, y, log=use_log)
    elif method == 'mape':
        calculate = _calc_mape
    else:
        raise ValueError(
            "Unknown loss method {!r}; expected 'rmse' or 'mape'".format(method)
        )

    pred_hospitalisations = states_time_matrix[6] + states_time_matrix[7] + states_time_matrix[8]
    pred_recoveries = states_time_matrix[9]
    pred_fatalities = states_time_matrix[10]
    pred_total_cases = pred_hospitalisations + pred_recoveries + pred_fatalities

    losses = {}
    losses['total'] = calculate(pred_total_cases, df['total_infected'])
    return losses

def calc_loss(states_time_matrix, df, method='rmse', rmse_log=False):
    """Return the scalar loss: the ``'total'`` entry of ``calc_loss_dict``.

    All arguments are forwarded unchanged to ``calc_loss_dict``.
    """
    return calc_loss_dict(states_time_matrix, df, method, rmse_log)['total']

## Prediction Interval calculator

In [None]:
def get_PI(pred_dfs, date, key, multiplier=1.96):
    """Return the mean and a symmetric interval of one cell across samples.

    Args:
        pred_dfs: Iterable of DataFrames; ``df.loc[date, key]`` is read from each.
        date: Row label to look up.
        key: Column label to look up.
        multiplier: Half-width of the interval in standard deviations.

    Returns:
        Tuple ``(mu, low, high)`` where ``mu`` is the sample mean and
        ``low`` / ``high`` are ``mu -/+ multiplier * std``.
    """
    samples = np.array([frame.loc[date, key] for frame in pred_dfs])
    centre = samples.mean()
    half_width = multiplier * samples.std()
    return centre, centre - half_width, centre + half_width

## Define the PyMC3-compatible SEIR model class

In [None]:


class SEIR_Test_pymc3(SEIR_Testing):
    """``SEIR_Testing`` subclass whose derivative is written with Theano ops.

    ``get_derivative`` has the ``func(y, t, p)`` shape that this notebook
    passes to ``pymc3.ode.DifferentialEquation``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def get_derivative(self, y, t, p):
        """Return the time derivatives of the 11 model compartments.

        Args:
            y: Indexable of 11 state values, in the order S, E, I, D_E, D_I,
                R_mild, R_severe_home, R_severe_hosp, R_fatal, C, D.
            t: Time variable; not used by the current equations.
            p: Indexable of 7 parameters, in the order R0, T_inc, T_inf,
                T_recov_severe, P_severe, P_fatal, intervention_amount.

        Returns:
            List of 11 derivative expressions, in the same order as ``y``.

        Side effects:
            Overwrites attributes on ``self``: every entry of
            ``vanilla_params`` and ``testing_params``, then the values taken
            from ``p`` and the quantities derived from them.
        """
        zero = T.cast(0.0, 'float64')
        # Clamp every compartment at zero. `T.set_subtensor` returns a new
        # variable instead of modifying `y`, so the clamped values have to be
        # rebound explicitly for the equations below to see them.
        y = [T.switch(T.gt(y[i], zero), y[i], zero) for i in range(11)]

        # Copy configured parameters onto the instance. A testing parameter
        # whose name also exists in vanilla_params is stored with a '_D' suffix
        # so that it does not overwrite the vanilla value.
        for key in self.vanilla_params:
            setattr(self, key, self.vanilla_params[key])
        for key in self.testing_params:
            suffix = '_D' if key in self.vanilla_params else ''
            setattr(self, key + suffix, self.testing_params[key])

        # State variables (C and D are unpacked for completeness; only their
        # derivatives are needed).
        S, E, I, D_E, D_I, R_mild, R_severe_home, R_severe_hosp, R_fatal, C, D = y

        # Parameters supplied through `p` override the configured values.
        self.R0 = p[0]
        self.T_inc = p[1]
        self.T_inf = p[2]
        self.T_recov_severe = p[3]
        self.P_severe = p[4]
        self.P_fatal = p[5]
        self.intervention_amount = p[6]  # stored but not used by the equations below

        self.P_mild = 1 - self.P_severe - self.P_fatal

        # Testing-related parameters.
        self.T_inf_detected = self.T_inf
        self.T_inc_detected = self.T_inc

        self.P_mild_detected = self.P_mild
        self.P_severe_detected = self.P_severe
        self.P_fatal_detected = self.P_fatal

        self.theta_E = self.testing_rate_for_exposed
        self.psi_E = self.positive_test_rate_for_exposed
        self.theta_I = self.testing_rate_for_infected
        self.psi_I = self.positive_test_rate_for_infected

        # TODO incorporate lockdown R0 code: R0 is meant to switch between
        # pre_lockdown_R0, lockdown_R0 and post_lockdown_R0 depending on `t`
        # relative to lockdown_day / lockdown_removal_day. Until then the
        # sampled R0 is used for every t.
        self.T_trans = self.T_inf/self.R0
        self.T_trans_D = self.T_inf_D/self.R0

        # Differential equations.
        dS = - I * S / (self.T_trans) - (self.q / self.T_trans_D) * (S * D_I)  # S
        dE = I * S / (self.T_trans) + (self.q / self.T_trans_D) * (S * D_I) - (E/ self.T_inc) - (self.theta_E * self.psi_E * E)  # E
        dI = E / self.T_inc - I / self.T_inf - (self.theta_I * self.psi_I * I)  # I
        dD_E = (self.theta_E * self.psi_E * E) - (1 / self.T_inc_D) * D_E  # D_E
        dD_I = (self.theta_I * self.psi_I * I) + (1 / self.T_inc_D) * D_E - (1 / self.T_inf_D) * D_I  # D_I
        dR_mild = (1/self.T_inf)*(self.P_mild*I) + (1/self.T_inf_D)*(self.P_mild_D*D_I) - R_mild/self.T_recov_mild  # R_mild
        dR_severe_home = (1/self.T_inf)*(self.P_severe*I) + (1/self.T_inf_D)*(self.P_severe_D*D_I) - R_severe_home/self.T_hosp  # R_severe_home
        dR_severe_hosp = R_severe_home/self.T_hosp - R_severe_hosp/self.T_recov_severe  # R_severe_hosp
        dR_fatal = (1/self.T_inf)*(self.P_fatal*I) + (1/self.T_inf_D)*(self.P_fatal_D*D_I) - R_fatal/self.T_death  # R_fatal
        dC = R_mild/self.T_recov_mild + R_severe_hosp/self.T_recov_severe  # C
        dD = R_fatal/self.T_death  # D

        return [dS, dE, dI, dD_E, dD_I, dR_mild, dR_severe_home, dR_severe_hosp, dR_fatal, dC, dD]

    def init_intermediate(self, variable_params, default_params, df_true, start_date=None, end_date=None, 
              state_init_values=None, initialisation='starting', loss_indices=[-20, -10]):
        """Build initial state fractions from an observed row of ``df_true``.

        Only acts when ``initialisation == 'intermediate'``: the row at
        position ``loss_indices[0]`` of ``df_true`` seeds the compartments,
        which are then divided by the population ``N``.

        Args:
            variable_params: Mapping merged with ``default_params``.
            default_params: Mapping merged after ``variable_params``, so its
                entries win on key collisions. The merged mapping must provide
                'P_severe', 'P_fatal', 'E_hosp_ratio', 'I_hosp_ratio' and 'N'.
            df_true: DataFrame with 'hospitalised', 'recovered' and 'deceased'
                columns.
            start_date: Unused.
            end_date: Unused.
            state_init_values: Replaced in 'intermediate' mode; otherwise unused.
            initialisation: Mode selector; only 'intermediate' has an effect.
            loss_indices: Its first element is the row position read from
                ``df_true``.

        Returns:
            None.
        """
        params_dict = {**variable_params, **default_params}
        if initialisation == 'intermediate':
            row = df_true.iloc[loss_indices[0], :]

            key_order = ['S', 'E', 'I', 'D_E', 'D_I', 
                'R_mild', 'R_severe_home', 'R_severe_hosp', 'R_fatal', 'C', 'D']
            state_init_values = OrderedDict((key, 0) for key in key_order)

            # Split the observed hospitalised count between the severe and
            # fatal tracks in proportion to their probabilities.
            severe_plus_fatal = params_dict['P_severe'] + params_dict['P_fatal']
            state_init_values['R_severe_hosp'] = params_dict['P_severe'] / severe_plus_fatal * row['hospitalised']
            state_init_values['R_fatal'] = params_dict['P_fatal'] / severe_plus_fatal * row['hospitalised']
            state_init_values['C'] = row['recovered']
            state_init_values['D'] = row['deceased']

            state_init_values['E'] = params_dict['E_hosp_ratio'] * row['hospitalised']
            state_init_values['I'] = params_dict['I_hosp_ratio'] * row['hospitalised']

            # S is whatever remains of the population (S itself is still 0 here).
            nonSsum = sum(state_init_values.values())
            state_init_values['S'] = (params_dict['N'] - nonSsum)

            # Convert absolute counts to fractions of the population.
            for key in state_init_values:
                state_init_values[key] = state_init_values[key]/params_dict['N']

            params_dict['state_init_values'] = state_init_values
        # NOTE(review): params_dict is neither returned nor stored on self, so
        # the computed initial values are discarded when this method returns.
        # Confirm whether the method is unfinished.
        
   


In [None]:
# Sanity check of `ifelse` on symbolic scalars: `xplus` is x when x < z and z
# otherwise; it is evaluated here with x=1, z=0.
x = T.scalar('x')
z = T.scalar('z')
xplus = ifelse(T.lt(x, z), x, z)
xplus.eval({x:1,z:0})

# Set up model parameters

In [None]:
# Model instance constructed with no arguments.
SEIR_Test_obj = SEIR_Test_pymc3()
num_patients = SEIR_Test_obj.vanilla_params['N']
init_vals = [*SEIR_Test_obj.state_init_values.values()]

# ODE dimensions: get_derivative works on 11 states and reads 7 entries of `p`.
num_states, num_params = 11, 7
num_steps = 40        # number of time points handed to the ODE solver
num_train_steps = 7   # trailing training rows used as observations

# Sampler settings.
burn_in, mcmc_steps = 100, 400

num_train = len(df_train)
observed = df_train['total_infected'][-num_train_steps:]

In [None]:
init_vals

In [None]:
observed 

# Run model

In [None]:
# Wrap the Theano derivative for pymc3: one time point per step, starting at t0 = 0.
sir_model = DifferentialEquation(
    func=SEIR_Test_obj.get_derivative,
    times=np.arange(num_steps),
    n_states=num_states,
    n_theta=num_params,
    t0=0,
)

In [None]:
with pm.Model() as model:
    # Uniform priors over the seven ODE parameters. The variable names become
    # the keys of every sampled point, so they must match the parameter names
    # exactly (no stray whitespace).
    R0 = pm.Uniform("R0", lower = 1, upper = 3.5)#(1.6, 3)
    T_inc = pm.Uniform("T_inc", lower = 1, upper = 5)#(3, 4)
    T_inf = pm.Uniform("T_inf", lower = 1, upper = 4)#(3, 4)
    T_recov_severe = pm.Uniform("T_recov_severe", lower = 9, upper = 20)
    P_severe = pm.Uniform("P_severe", lower = 0.3, upper = 0.99)
    P_fatal = pm.Uniform("P_fatal", lower = 1e-4, upper = 0.3)
    intervention_amount = pm.Uniform("intervention_amount", lower = 0.3, upper = 1)

    # The order of `theta` must match the order in which get_derivative reads `p`.
    ode_solution = sir_model(y0=init_vals , theta=[R0, T_inc, T_inf, T_recov_severe, P_severe,
                                                   P_fatal, intervention_amount])
    # The ode_solution has a shape of (n_times, n_states)

    # Rows aligned with the last `num_train_steps` training observations.
    # NOTE(review): ode_solution has only `num_steps` rows, so this slice
    # assumes num_train - 1 <= num_steps; otherwise it is shorter than
    # `observed`. TODO confirm.
    predictions = ode_solution[num_train-num_train_steps-1:num_train-1]
    hospitalised = predictions[:,6] + predictions[:,7] + predictions[:,8]
    recovered = predictions[:,9]
    deceased = predictions[:,10]
    total_infected = hospitalised + recovered + deceased
    # States are population fractions; scale back to head counts.
    total_infected = total_infected * num_patients 
    #sigma = pm.HalfNormal('sigma',
    #                      sigma=observed.std(),
    #                      shape=num_params)
    # NOTE(review): no sigma is passed, so the library default applies —
    # confirm this is intended for count-scale data.
    Y = pm.Normal('Y', mu = total_infected, observed=observed)

    prior = pm.sample_prior_predictive()
    trace = pm.sample(mcmc_steps, tune=burn_in , target_accept=0.9, cores=4)
    posterior_predictive = pm.sample_posterior_predictive(trace)
    
    

In [None]:
 theano.printing.Print("Predictions")(ode_solution[:,-1])

In [None]:
theano.printing.Print("R0")(R0)

In [None]:
trace

# Analyse runs

In [None]:
with model:
    # Bundle the trace, prior and posterior-predictive samples for ArviZ, then
    # plot each posterior with a 0.95 interval.
    data = az.from_pymc3(trace=trace, prior=prior, posterior_predictive=posterior_predictive)
    az.plot_posterior(data,round_to=2, credible_interval=0.95)

In [None]:
#pm.forestplot(trace)

In [None]:
pm.summary(trace)

In [None]:
pm.plots.traceplot(trace)

In [None]:
pm.plots.autocorrplot(trace)

In [None]:
len(trace[500:])

In [None]:
# `default_params` is only assigned in a later cell (via
# Optimiser.init_default_params), so a bare reference here raises NameError on
# a fresh-kernel "Run All". Look it up defensively: the value is shown if it
# already exists, and nothing is shown otherwise.
globals().get('default_params')

In [None]:
final_runs = trace#[burn_in:]

## Visualize the samples and intervals

In [None]:
# Re-solve the model for a random subset of posterior draws and summarise the
# predictions as mean / low / high per date and column.
data_split = df_district.copy()
optimiser = Optimiser()
default_params = optimiser.init_default_params(data_split)

n_samples = 1000
# Draw integer positions directly (with replacement) rather than truncating
# floats from a uniform distribution.
sample_indices = np.random.randint(0, len(final_runs), n_samples)

pred_dfs = list()
for i in tqdm(sample_indices):
    pred_dfs.append(optimiser.solve(final_runs[int(i)], 
                default_params, data_split, 
                initialisation = 'intermediate', 
                start_date = data_split.iloc[-num_train_steps, :].date,
                end_date= data_split.iloc[-1, :].date,
                hardcode_ratios = True, loss_indices = [-num_train_steps,0]))

# Index every prediction frame by date so that get_PI can look rows up by label.
for df in pred_dfs:
    df.set_index('date', inplace=True)

result = pred_dfs[0].copy()
for col in pred_dfs[0].columns:
    # NaN placeholders keep the interval columns numeric; empty strings would
    # make them object dtype.
    result["{}_low".format(col)] = np.nan
    result["{}_high".format(col)] = np.nan

for date in tqdm(pred_dfs[0].index):
    for key in pred_dfs[0]:
        mean_value, low_value, high_value = get_PI(pred_dfs, date, key)
        result.loc[date, key] = mean_value
        result.loc[date, "{}_low".format(key)] = low_value
        result.loc[date, "{}_high".format(key)] = high_value

data_split = data_split.set_index("date")


In [None]:
final_runs[int(i)],

In [None]:
data_split['total_infected']

In [None]:
result['hospitalised']

In [None]:
pred_dfs[-1]

# Plot graphs

In [None]:
#result['total_infected'], data_split['total_infected']

In [None]:
# The predictions were solved from data_split.iloc[-num_train_steps] to
# data_split.iloc[-1], so they cover only the tail of the series. Right-align
# them with the actual curve; plotting them from x=0 would compare them with
# the first days of the series and put them out of step with the split marker.
# NOTE(review): assumes `result` has one row per day ending on the last date of
# data_split — confirm against Optimiser.solve.
actual = data_split['total_infected'].tolist()
estimated = result['total_infected'].tolist()
x_pred = range(len(actual) - len(estimated), len(actual))

plt.figure(figsize=(15, 10))
plt.plot(actual, c='g', label='Actual')
plt.plot(x_pred, estimated, c='r', label='Estimated')
plt.plot(x_pred, result['total_infected_low'].tolist(), c='r', linestyle='dashdot')
plt.plot(x_pred, result['total_infected_high'].tolist(), c='r', linestyle='dashdot')
# Dashed line marks the first validation day.
plt.axvline(x=len(df_train), c='b', linestyle='dashed')
plt.xlabel("Day")
plt.ylabel("Total infected")
plt.legend()
plt.title("95% confidence intervals for {}, {}".format(district, state))

plt.savefig('./mcmc_confidence_intervals_{}_{}.png'.format(district, state))
plt.show()

In [None]:
# No `visualize` function is defined in this notebook (its `def` line is
# commented out and the steps run inline in the cells above), so an
# unconditional call raises NameError. Call it only if it exists.
if 'visualize' in globals():
    visualize()