# Train the decision trees based on the latest data available from: https://healthdata.gov/Hospital/COVID-19-Reported-Patient-Impact-and-Hospital-Capa/anag-cw7u

NB - there is no longer regular updating of case and death data. Must omit from model training. 

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn import tree, metrics
from sklearn.tree import DecisionTreeClassifier, plot_tree, export_graphviz
from sklearn.metrics import accuracy_score, confusion_matrix, matthews_corrcoef, f1_score, roc_auc_score, roc_curve, auc, RocCurveDisplay
from sklearn.model_selection import RandomizedSearchCV, cross_val_score, KFold, RepeatedStratifiedKFold
from num2words import num2words
import word2number
from word2number import w2n
import pickle
import pydotplus

import random
from matplotlib.patches import Polygon
import graphviz
import sklearn.tree as tree
from six import StringIO
from IPython.display import Image
import string
from PIL import Image
from subprocess import call
from Functions import prep_training_test_data_period, prep_training_test_data, calculate_metrics,cross_validation_leave_geo_out, prep_training_test_data_shifted, add_labels_to_subplots, LOOCV_by_HSA_dataset, save_in_HSA_dictionary, pivot_data_by_HSA, merge_and_rename_data, add_changes_by_week, create_column_names, create_collated_weekly_data, simplify_labels_graphviz
hfont = {'fontname':'Helvetica'}
palette = ['#66c2a5', '#fc8d62', '#8da0cb', '#e78ac3', '#a6d854', '#e5c494']

In [2]:
def calculate_pred_by_hsa(hsa_weekly_data_all, directory_path, weeks_to_predict, prep_training_test_data, time_period, train_weeks_for_initial_model, size_of_test_dataset, keep_output, weeks_in_future, weight_col, geography):
    pred_by_hsa_full = {}
    outcome_by_hsa_full = {}
    pred_proba_by_hsa_full = {}


    for hsa in hsa_weekly_data_all[geography].unique():
        print(hsa)
        #prediction_for_hsa_proba = []
        #outcome_for_hsa = []
        for prediction_week in weeks_to_predict:
            training_dataframe, testing_dataframe = LOOCV_by_HSA_dataset(hsa_weekly_data_all, hsa, 'HSA_ID') #may need to add time period somewhere here
            if not testing_dataframe['weight'].isna().any():
                model_name_to_load = directory_path   + str(prediction_week) + ".sav"
                clf = pickle.load(open(model_name_to_load, 'rb'))

                x_test, y_test, weights_test, missing_data_test_hsa = prep_training_test_data(testing_dataframe, no_weeks=range(int(prediction_week + train_weeks_for_initial_model) + 1, int(prediction_week + train_weeks_for_initial_model + size_of_test_dataset) + 1), weeks_in_future=weeks_in_future, geography=geography, weight_col=weight_col, keep_output=keep_output)
                if len(x_test) >= 1:
                    #outcome_for_hsa.append(y_test.at[0,0])

                    y_pred = clf.predict(x_test)
                    y_pred_proba = clf.predict_proba(x_test)
                    pred_by_hsa_full[hsa] = y_pred[0]
                    outcome_by_hsa_full[hsa] = y_test.at[0,0]
                    pred_proba_by_hsa_full[hsa] = y_pred_proba[:, 1]

                else: 

                    pred_by_hsa_full[hsa] = np.nan
                    outcome_by_hsa_full[hsa] = np.nan
                    pred_proba_by_hsa_full[hsa] = np.nan
    return pred_by_hsa_full, outcome_by_hsa_full, pred_proba_by_hsa_full

In [3]:
def convert_state_to_code(dataframe, column_name):
    # List of state names in alphabetical order, including Washington, D.C.
    state_names = [
        'Alabama', 'Alaska', 'Arizona', 'Arkansas', 'California', 'Colorado',
        'Connecticut', 'Delaware', 'Florida', 'Georgia', 'Hawaii', 'Idaho',
        'Illinois', 'Indiana', 'Iowa', 'Kansas', 'Kentucky', 'Louisiana',
        'Maine', 'Maryland', 'Massachusetts', 'Michigan', 'Minnesota',
        'Mississippi', 'Missouri', 'Montana', 'Nebraska', 'Nevada', 'New Hampshire',
        'New Jersey', 'New Mexico', 'New York', 'North Carolina', 'North Dakota',
        'Ohio', 'Oklahoma', 'Oregon', 'Pennsylvania', 'Rhode Island', 'South Carolina',
        'South Dakota', 'Tennessee', 'Texas', 'Utah', 'Vermont', 'Virginia', 'Washington',
        'West Virginia', 'Wisconsin', 'Wyoming', 'Washington D.C.'
    ]

    # Create a dictionary to map state names to numerical codes
    state_to_code = {state: f'{index + 1:02}' for index, state in enumerate(state_names)}

    # Create a new column "state_code" based on the mapping
    dataframe['state_code'] = dataframe[column_name].map(state_to_code)
    
    return dataframe

In [4]:
directory_path = "/Users/rem76/Documents/COVID_projections/Exact_analysis_smaller_hyperparameters/Latest_data/"
os.chdir(directory_path)

# Import and prepare data

In [26]:
HSA_weekly_data_updated = pd.read_csv("/Users/rem76/Documents/COVID_projections/hsa_time_data_all_dates_with_state_fips_latest_data.csv")


  HSA_weekly_data_updated = pd.read_csv("/Users/rem76/Documents/COVID_projections/hsa_time_data_all_dates_with_state_fips_latest_data.csv")


In [27]:
HSA_weekly_data_updated.rename(columns={'health_service_area_number': 'HSA_ID'}, inplace=True)
HSA_weekly_data_updated['beds_over_15_100k'] = (HSA_weekly_data_updated['beds_weekly'] > 15)*1
HSA_weekly_data_updated_features = HSA_weekly_data_updated.dropna(subset=['admits_weekly', 'icu_weekly', 'beds_weekly', 'perc_covid'])
for i, week in enumerate(HSA_weekly_data_updated_features['date'].unique()):
    HSA_weekly_data_updated_features.loc[HSA_weekly_data_updated_features['date'] == week, 'week'] = i

Merge dataframes

In [114]:
## pivot 
data_by_HSA_admissions = pivot_data_by_HSA(HSA_weekly_data_updated_features, 'week', 'HSA_ID', 'admits_weekly')
data_by_HSA_icu = pivot_data_by_HSA(HSA_weekly_data_updated_features, 'week', 'HSA_ID', 'icu_weekly')
data_by_HSA_beds = pivot_data_by_HSA(HSA_weekly_data_updated_features, 'week', 'HSA_ID', 'beds_weekly')
data_by_HSA_percent_beds = pivot_data_by_HSA(HSA_weekly_data_updated_features, 'week', 'HSA_ID', 'perc_covid')
data_by_HSA_over_15_100k = pivot_data_by_HSA(HSA_weekly_data_updated_features, 'week', 'HSA_ID', 'beds_over_15_100k')

## merge 
data_by_HSA_admits_icu_weekly = merge_and_rename_data(data_by_HSA_admissions, data_by_HSA_icu,'week','admits', 'icu')
data_by_HSA_beds_perc_weekly = merge_and_rename_data(data_by_HSA_beds, data_by_HSA_percent_beds,'week','beds', 'perc_covid')
data_by_HSA_cases_beds_perc_admits_icu = pd.merge(data_by_HSA_beds_perc_weekly, data_by_HSA_admits_icu_weekly, on='week')

## add outcome variable 

old_column_names = data_by_HSA_over_15_100k.columns
new_column_names = [str(col) + '_beds_over_15_100k' for col in old_column_names]
new_column_names = dict(zip(old_column_names, new_column_names))
data_by_HSA_over_15_100k.rename(columns=new_column_names, inplace=True)
data_by_HSA_cases_admits_icu_beds = pd.merge(data_by_HSA_cases_beds_perc_admits_icu, data_by_HSA_over_15_100k, on='week')

data_by_HSA_cases_admits_icu_beds = data_by_HSA_cases_admits_icu_beds.reset_index()
data_by_HSA_cases_admits_icu_beds.columns = data_by_HSA_cases_admits_icu_beds.columns.str.replace(',', '')

In [29]:
maximum_week = HSA_weekly_data_updated_features['week'].max()

Get weekly changes

In [115]:
all_HSA_ID_weekly_data = add_changes_by_week(data_by_HSA_cases_admits_icu_beds, "beds_over_15_100k")

  weekly_data_frame.insert(column_index + 1, new_column_name, diff)
  weekly_data_frame.insert(column_index + 1, new_column_name, diff)
  weekly_data_frame.insert(column_index + 1, new_column_name, diff)
  weekly_data_frame.insert(column_index + 1, new_column_name, diff)
  weekly_data_frame.insert(column_index + 1, new_column_name, diff)
  weekly_data_frame.insert(column_index + 1, new_column_name, diff)
  weekly_data_frame.insert(column_index + 1, new_column_name, diff)
  weekly_data_frame.insert(column_index + 1, new_column_name, diff)
  weekly_data_frame.insert(column_index + 1, new_column_name, diff)
  weekly_data_frame.insert(column_index + 1, new_column_name, diff)
  weekly_data_frame.insert(column_index + 1, new_column_name, diff)
  weekly_data_frame.insert(column_index + 1, new_column_name, diff)
  weekly_data_frame.insert(column_index + 1, new_column_name, diff)
  weekly_data_frame.insert(column_index + 1, new_column_name, diff)
  weekly_data_frame.insert(column_index + 1, new

In [116]:
categories_for_subsetting = [ 'admits', 'icu', 'beds', 'perc_covid',  'admits_delta', 'icu_delta', 'beds_delta', 'perc_covid_delta','beds_over_15_100k']
num_of_weeks = len(all_HSA_ID_weekly_data)
column_names = create_column_names(categories_for_subsetting, num_of_weeks)

In [117]:
all_HSA_ID_weekly_data = create_collated_weekly_data(all_HSA_ID_weekly_data, HSA_weekly_data_updated, categories_for_subsetting, 'HSA_ID', column_names)

Add weights 

In [118]:
weights_df = HSA_weekly_data_updated[HSA_weekly_data_updated['HSA_ID'].isin(all_HSA_ID_weekly_data['HSA_ID'])][['HSA_ID','weight']]
weights_df = weights_df.rename(columns = {'HSA_ID': 'HSA_ID', 'weight':'weight'})
weights_df['weight'].unique()
all_HSA_ID_weekly_data = all_HSA_ID_weekly_data.join(weights_df['weight'])

Save file

In [120]:
# write a csv file with all the data
all_HSA_ID_weekly_data.to_csv("/Users/rem76/Documents/COVID_projections/Exact_analysis_smaller_hyperparameters/Latest_data/hsa_time_data_all_dates_weekly_latest_data.csv", index=False)

# Get GeoData

In [5]:
data_by_HSA = pd.read_csv('/Users/rem76/Documents/COVID_projections/hsa_time_data_all_dates_with_state_fips.csv')
data_by_HSA.rename(columns={'health_service_area_number': 'HSA_ID'}, inplace=True)

# Load the json file with county coordinates
geoData = gpd.read_file('https://raw.githubusercontent.com/holtzy/The-Python-Graph-Gallery/master/static/data/US-counties.geojson')

# Make sure the "id" column is an integer
geoData.id = geoData.id.astype(str).astype(int)

# Try and match to states 
census_data = pd.read_csv('https://raw.githubusercontent.com/holtzy/The-Python-Graph-Gallery/master/static/data/unemployment-x.csv')

geoData = geoData.merge(census_data, left_on=['id'], right_on=['id'])

data_by_HSA = convert_state_to_code(data_by_HSA, 'state')

  data_by_HSA = pd.read_csv('/Users/rem76/Documents/COVID_projections/hsa_time_data_all_dates_with_state_fips.csv')


NameError: name 'gpd' is not defined

In [None]:
geoData['HSA_ID'] = None
for i, name in enumerate(geoData['id']):
    state = geoData.loc[i, 'state']
    filtered_rows = data_by_HSA[data_by_HSA['fips'] == name ]

    if(len(filtered_rows) > 1): #ensures that only HSAs that are represented in the geoData are used
        geoData.loc[i, 'HSA_ID'] = filtered_rows['HSA_ID'].reset_index(drop=True)[0]


for i, HSA in enumerate(data_by_HSA['HSA_ID']):
    geoData.loc[geoData['HSA_ID'] == HSA, 'Prediction'] = predictions_with_HSA.loc[predictions_with_HSA['HSA'] == int(HSA), 'y_pred']

# Predictions, outcomes, and probabilities per HSA from trained datasets

In [None]:
weeks_to_predict_all = range(1, maximum_week)
for week in weeks_to_predict_all: 
    directory_path = "/Users/rem76/Documents/GitHub/Viz-COVID19/Classifiers/Outcome_week_"
    weeks_to_predict = [week]
    pred_by_hsa, outcome_for_hsa, pred_proba_by_hsa = calculate_pred_by_hsa(HSA_latest_data, directory_path, weeks_to_predict, prep_training_test_data, time_period, train_weeks_for_initial_model, size_of_test_dataset, keep_output, weeks_in_future, weight_col, geography)
    #print(pred_by_hsa)
    #print(outcome_for_hsa)
    print("sum", sum(pred == outcome for pred, outcome in zip(pred_by_hsa, outcome_for_hsa)))
    for i, HSA in enumerate(pred_by_hsa.keys()):
        geoData.loc[geoData['HSA_ID'] == HSA, 'Prediction'] = pred_by_hsa[HSA]
        geoData.loc[geoData['HSA_ID'] == HSA, 'Outcome'] = outcome_for_hsa[HSA]
        if np.isnan(pred_proba_by_hsa.get(HSA)): 
            geoData.loc[geoData['HSA_ID'] == HSA, 'Predict_proba'] = pred_proba_by_hsa[HSA]
        else: 
            geoData.loc[geoData['HSA_ID'] == HSA, 'Predict_proba'] = pred_proba_by_hsa[HSA][0]
    file_path = "/Users/rem76/Documents/GitHub/Viz-COVID19/Weekly_data/Geodata_predictions_outcomes_week_" + str(week) + ".csv"
    geoData.to_csv(file_path)

# Plot decision trees

In [6]:
weeks_to_predict_all = range(1, 175)

In [24]:
feature_names=[ 'COVID-19  admissions',  'COVID-19 ICU beds', 'COVID-19 hospital beds', 'Perc. of beds with \nCOVID-19 patients', '\u0394 COVID-19 admissions', '\u0394 COVID-19 ICU beds', '\u0394 COVID-19 hospital beds',  '\u0394 Perc. beds with \nCOVID-19 patients', '> 15 per 100,000 COVID-19 \npatients in hospital beds']

def simplify_labels_graphviz(graph):
    for node in graph.get_node_list():
        if node.get_attributes().get("label") is None:
            continue
        else:
            split_label = node.get_attributes().get("label").split("<br/>")
            if len(split_label) == 4:
                split_label[3] = split_label[3].split("=")[1].strip()

                del split_label[1]  # number of samples
                del split_label[1]  # split of sample
            elif len(split_label) == 3:  # for a terminating node, no rule is provided
                split_label[2] = split_label[2].split("=")[1].strip()

                del split_label[0]  # number of samples
                del split_label[0]  # split of samples
                split_label[0] = "<" + split_label[0]
            node.set("label", "<br/>".join(split_label))
def enhance_graph(graph):
    graph.set_dpi(400)  # Set DPI for higher resolution (adjust as needed)
    graph.set_margin(.1)  # Adjust margins to make the graph layout tighter
    graph.set_rankdir('TB')  # Set the direction of the graph (e.g., top to bottom)

# Create StringIO objects to store dot data

base_directory_path = "/Users/rem76/Documents/GitHub/Viz-COVID19/Classifiers/Outcome_week_"

for week in weeks_to_predict_all: 
    directory_path = base_directory_path + str(week) + ".sav" 
    classifier = pickle.load(open(directory_path, 'rb'))
    dot_data = tree.export_graphviz(classifier, 
    out_file= None,
    class_names=['Over', 'Under'], # the target names.
    feature_names= feature_names, # the feature names.
    filled=True, # Whether to fill in the boxes with colours.
    rounded=True, # Whether to round the corners of the boxes.
    special_characters=True, 
    proportion=False,
    precision = 0, 
    impurity=False)

    graph = pydotplus.graph_from_dot_data(dot_data) 
    simplify_labels_graphviz(graph)
    graph.del_node('"\\n"') ## empty node at end, delete it 
    output_file_path = "/Users/rem76/Documents/GitHub/Viz-COVID19/Classifiers_visualized/Classifier_week_" + str(week) + ".png"
    graph.write_png(output_file_path) 




In [22]:
for n in graph.get_nodes():
    print(n)

<pydotplus.graphviz.Node object at 0x28082bf10>
<pydotplus.graphviz.Node object at 0x282ddea90>
<pydotplus.graphviz.Node object at 0x282b70bd0>
<pydotplus.graphviz.Node object at 0x282e94e10>
<pydotplus.graphviz.Node object at 0x282e96ed0>
<pydotplus.graphviz.Node object at 0x1181a4f50>
<pydotplus.graphviz.Node object at 0x1181a6110>
<pydotplus.graphviz.Node object at 0x11826bc50>
<pydotplus.graphviz.Node object at 0x11816e790>
<pydotplus.graphviz.Node object at 0x282be5150>
<pydotplus.graphviz.Node object at 0x11818e7d0>
<pydotplus.graphviz.Node object at 0x1182879d0>
<pydotplus.graphviz.Node object at 0x1182b03d0>
<pydotplus.graphviz.Node object at 0x1182b0e10>
<pydotplus.graphviz.Node object at 0x1181ba350>
<pydotplus.graphviz.Node object at 0x1182b1890>
<pydotplus.graphviz.Node object at 0x1182b2c10>
<pydotplus.graphviz.Node object at 0x1182b3750>
<pydotplus.graphviz.Node object at 0x1182d82d0>
<pydotplus.graphviz.Node object at 0x1182d97d0>
<pydotplus.graphviz.Node object at 0x118