# XGB Validation

This notebook applies our trained XGB models to the hold-out validation dataset.

## 1. Model Import

In [1]:
import pandas as pd
import datetime
import numpy as np # linear algebra
import seaborn as sns
import matplotlib.pyplot as plt
import xgboost as xgb
from xgboost import plot_importance, plot_tree
from sklearn.metrics import mean_squared_error, mean_absolute_error
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_absolute_error
import matplotlib as mpl  # aliased as `mpl` so the `plt` alias for matplotlib.pyplot imported above is not overwritten
# Disable pandas' display truncation: show every column and every row of a DataFrame
pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', None)

In [33]:
# import packages for plot
from __future__ import print_function
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
import plotly.io as pio
import plotly.express as px
import plotly.offline as py
import plotly.graph_objects as go

In [8]:
# Load the saved boosters: one case model and one death model per index 1..7.
# Position k in each list holds the model stored under index k + 1.

model_list_case = []
model_list_death = []

for model_number in range(1, 8):
    case_booster = xgb.Booster()
    case_booster.load_model(f"XGB_model_{model_number}_case.model")

    death_booster = xgb.Booster()
    death_booster.load_model(f"XGB_model_{model_number}_death.model")

    model_list_case.append(case_booster)
    model_list_death.append(death_booster)



## 2. Read in data & Preprocessing

In [3]:
# Load the prepared dataset from the working directory.
# Later cells read columns such as "Date", "State", "Daily_Case" and "Daily_Death" from it.
final_data = pd.read_csv(filepath_or_buffer="final_data.csv")

In [4]:
# Convert the "Date" strings (YYYY-MM-DD) into datetime.date objects so they can be
# compared against datetime.date thresholds in later cells.
# The whole column is parsed in one vectorised call instead of a row-by-row iterrows()
# loop, and the format string no longer shadows the built-in `format`.
DATE_FORMAT = '%Y-%m-%d'
final_data["Date"] = pd.to_datetime(final_data["Date"], format=DATE_FORMAT).dt.date

In [5]:
# Create lag variables
#
# Builds seven datasets, stored as final_data_list["final_data_1"] .. ["final_data_7"].
# Dataset j carries, for each source column below, the seven per-state shifts
# j, j+1, ..., j+6, named "lag<shift><suffix>" (e.g. "lag3c").
# NOTE(review): shift() works on row position within each State group, so this assumes
# rows are ordered by date with one row per state per day -- confirm against final_data.csv.
num = range(7)
N_LAGS = 7

# suffix -> source column; the dict order fixes the order in which lag columns are added
lag_sources = {
    "c": "Daily_Case",
    "d": "Daily_Death",
    "t": "tests_combined_total",
    "v": "People_Fully_Vaccinated",
}

# Columns cast to object dtype so that pd.get_dummies() one-hot encodes them later
categorical_columns = [
    'Fully_reopen',
    'Mask_Mandate',
    'Vaccination_or_test',
    'State',
    'Region',
    'Division',
]

# The grouping does not depend on j, so it is built once from the unmodified frame;
# the shifted Series keep final_data's index and therefore align with each copy.
by_state = final_data.groupby("State")

final_data_list = {}
for i in num:
    j = i + 1
    final_data_test = final_data.copy(deep=True)

    for suffix, source_column in lag_sources.items():
        for lag in range(j, j + N_LAGS):
            final_data_test[f"lag{lag}{suffix}"] = by_state[source_column].shift(lag)

    # Drop every row with a missing value in any column; this includes the first rows
    # of each state, whose lag values are undefined.
    final_data_test = final_data_test.dropna().copy(deep=True)

    final_data_test = final_data_test.astype(
        {column: object for column in categorical_columns}
    )

    # The raw vaccination/test totals are removed; only their lagged versions remain.
    final_data_test = final_data_test.drop(
        columns=['People_Fully_Vaccinated', 'tests_combined_total']
    ).copy(deep=True)

    final_data_list[f"final_data_{j}"] = final_data_test

In [6]:
# Hold out the rows dated on or after 2021-11-15 for validation.
# predict_data_list receives the held-out rows; final_data_list is overwritten in place
# so that it keeps only the rows dated before the threshold.
thresh = datetime.date(2021, 11, 15)
sep = list(final_data_list)
predict_data_list = {}
for key in sep:
    frame = final_data_list[key]
    held_out = frame[frame['Date'] >= thresh]
    kept = frame[frame['Date'] < thresh]
    predict_data_list[key] = held_out
    final_data_list[key] = kept

## 3. Model Validation

In [10]:
# Use our case model on hold-out validation dataset
#
# Model i (0-based) is paired with dataset "final_data_<i+1>" and scored on the single
# date 2021-11-15 + i days, giving one prediction per state per date.
validation_start = datetime.date(2021, 11, 15)
case_frames = []

for i in range(7):
    # A separate name is used here so the hold-out threshold `thresh` is not clobbered.
    target_date = validation_start + datetime.timedelta(days=i)
    horizon_data = predict_data_list["final_data_" + str(i + 1)]
    day_rows = horizon_data[horizon_data["Date"] == target_date]

    # Take the state labels from the very rows being predicted, so labels and
    # predictions cannot get out of step if the rows are not in sorted-state order.
    states = day_rows["State"]

    # `columns=` replaces the positional axis argument, which pandas 2.0 no longer accepts.
    features = pd.get_dummies(day_rows.drop(columns=["Date", "State", "state"]))

    # y and X
    y = features["Daily_Case"]
    X = features.drop(columns=["Daily_Case", "Daily_Death"])

    # xgb.DMatrix
    dpred = xgb.DMatrix(X, label=y)

    # predict; negative counts are not meaningful, so they are floored at zero
    y_pred = model_list_case[i].predict(dpred)
    y_pred[y_pred < 0] = 0

    # append (the frame keeps the row index of `y`)
    case_frames.append(pd.DataFrame({
        "State": states,
        "Date": target_date,
        "True Cases": y,
        "Predict Cases": y_pred,
    }))

Vali_results_case = pd.concat(case_frames).sort_values(["State", "Date"])

In [12]:
# Use our death model on hold-out validation dataset
#
# Model i (0-based) is paired with dataset "final_data_<i+1>" and scored on the single
# date 2021-11-15 + i days, giving one prediction per state per date.
validation_start = datetime.date(2021, 11, 15)
death_frames = []

for i in range(7):
    # A separate name is used here so the hold-out threshold `thresh` is not clobbered.
    target_date = validation_start + datetime.timedelta(days=i)
    horizon_data = predict_data_list["final_data_" + str(i + 1)]
    day_rows = horizon_data[horizon_data["Date"] == target_date]

    # Take the state labels from the very rows being predicted, so labels and
    # predictions cannot get out of step if the rows are not in sorted-state order.
    states = day_rows["State"]

    # `columns=` replaces the positional axis argument, which pandas 2.0 no longer accepts.
    features = pd.get_dummies(day_rows.drop(columns=["Date", "State", "state"]))

    # y and X
    y = features["Daily_Death"]
    X = features.drop(columns=["Daily_Case", "Daily_Death"])

    # xgb.DMatrix
    dpred = xgb.DMatrix(X, label=y)

    # predict; negative counts are not meaningful, so they are floored at zero
    y_pred = model_list_death[i].predict(dpred)
    y_pred[y_pred < 0] = 0

    # append (the frame keeps the row index of `y`)
    death_frames.append(pd.DataFrame({
        "State": states,
        "Date": target_date,
        "True Deaths": y,
        "Predict Deaths": y_pred,
    }))

Vali_results_death = pd.concat(death_frames).sort_values(["State", "Date"])

In [30]:
# Put the death columns next to the case columns. pd.concat(axis=1) matches rows by
# index label; this assumes both result frames carry the same row index -- confirm.
value_columns = ["True Cases", "Predict Cases", "True Deaths", "Predict Deaths"]
Vali_results = pd.concat(
    [Vali_results_case, Vali_results_death[["True Deaths", "Predict Deaths"]]],
    axis=1,
)

# National totals per date. The value columns are selected explicitly before summing;
# previously the column list was handed to sum(), where it was taken as the
# `numeric_only` argument instead of a column selection.
US = Vali_results.groupby("Date")[value_columns].sum().reset_index()
US["State"] = "United States"
US = US[["State", "Date"] + value_columns]

# National rows first, followed by the per-state rows, with a fresh 0..n-1 index
Vali_results = pd.concat([US, Vali_results]).reset_index(drop=True)

In [31]:
# Preview only the first rows: display.max_rows is set to None in this notebook, so
# showing the bare frame would render every row.
Vali_results.head(10)

Unnamed: 0,State,Date,True Cases,Predict Cases,True Deaths,Predict Deaths
0,United States,2021-11-15,144771,126152.828125,1058,1168.595459
1,United States,2021-11-16,95092,101961.632812,1419,1444.539917
2,United States,2021-11-17,103479,87326.929688,1225,1116.613647
3,United States,2021-11-18,111823,73257.117188,1097,903.298706
4,United States,2021-11-19,113589,109797.867188,1677,1368.132812
5,United States,2021-11-20,32250,49456.425781,269,473.537384
6,United States,2021-11-21,24995,42788.019531,85,331.188873
7,Alabama,2021-11-15,652,871.107666,2,9.481331
8,Alabama,2021-11-16,642,713.607727,0,4.718861
9,Alabama,2021-11-17,622,922.985962,1,6.061668


In [32]:
# Write the combined validation table to the working directory (row index included)
Vali_results.to_csv(path_or_buf='XGB_Validation.csv')

## 4. Plot validation results

In [41]:
def plot_vali(State):
    """Show a line chart of true vs. predicted cases for one entry of Vali_results.

    Parameters
    ----------
    State : str
        Value matched against the "State" column of the global ``Vali_results``.

    The figure plots "True Cases" and "Predict Cases" against "Date" and is
    displayed via ``fig.show()``; nothing is returned.
    """
    # keep only the rows belonging to the requested state
    selected_rows = Vali_results.loc[Vali_results["State"] == State]

    chart_title = f"Number of COVID-19 Cases from 11/15 to 11/21 in {State}"

    # one line per series, sharing the date axis
    figure = px.line(
        selected_rows,
        x="Date",
        y=["True Cases", "Predict Cases"],
        title=chart_title,
    )

    figure.show()
        
## offer every distinct "State" value (alphabetical order) as a choice in the interactive widget
state_options = sorted(set(Vali_results["State"]))
widgets.interact(plot_vali, State=state_options)

interactive(children=(Dropdown(description='State', options=('Alabama', 'Alaska', 'Arizona', 'Arkansas', 'Cali…

<function __main__.plot_vali(State)>