In [2]:
#!/usr/bin/env python
import string
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
from num2words import num2words
from word2number import w2n
import pydotplus
from six import StringIO
from IPython.display import Image
from sklearn import tree, metrics
from sklearn.model_selection import (
    train_test_split,
    RandomizedSearchCV,
    cross_val_score,
    KFold,
    RepeatedStratifiedKFold,
)
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.metrics import (
    roc_auc_score,
    matthews_corrcoef,
    accuracy_score,
    confusion_matrix,
    f1_score,
    roc_curve,
    auc,
    RocCurveDisplay,
)
import pickle



# Job number for this run, hard-coded to 1.
# NOTE(review): the original comment said the value is available as
# `args.job_number`, but no argument parsing is visible in this file --
# presumably it was once read from the command line; confirm.
job_number = 1



######## FUNCTIONS ##########
def add_labels_to_subplots(axs, hfont, height, fontsize):
    """Tag each axes in ``axs`` with a capital letter (A, B, C, ...).

    The letter is drawn at the left x-limit and at ``height`` times the upper
    y-limit of the axes.  ``hfont`` is forwarded to ``ax.text`` as keyword
    arguments.  Returns the full list of uppercase letters.
    """
    letters = list(string.ascii_uppercase)
    for position, axis in enumerate(axs):
        x_left = axis.get_xlim()[0]
        y_top = axis.get_ylim()[1] * height
        axis.text(x_left, y_top, letters[position], fontsize=fontsize, **hfont)
    return letters


### This code is for the outcome in exactly x weeks
def merge_and_rename_data(data1, data2, on_column, suffix1, suffix2):
    """Merge two frames on ``on_column`` and shorten the suffixed names.

    Overlapping columns receive ``_<suffix1>`` / ``_<suffix2>``.  Afterwards
    any ``_<on_column>_<suffix>`` fragment in a column name is collapsed to
    ``_<suffix>``.  Returns the merged frame.
    """
    left_tag = "_" + suffix1
    right_tag = "_" + suffix2
    merged = pd.merge(data1, data2, on=on_column, suffixes=(left_tag, right_tag))

    def _collapse(name):
        # Drop the join-key fragment that sits in front of either suffix.
        name = name.replace(f"_{on_column}_{suffix1}", f"_{suffix1}")
        return name.replace(f"_{on_column}_{suffix2}", f"_{suffix2}")

    renames = {old: _collapse(old) for old in merged.columns}
    merged.rename(columns=renames, inplace=True)
    return merged


def pivot_data_by_HSA(data, index_column, columns_column, values_column):
    """Pivot ``data`` to one row per ``index_column`` value.

    Only the three named columns are used.  The result has one column per
    distinct ``columns_column`` value, filled from ``values_column`` with
    ``pivot_table``'s default aggregation.
    """
    needed = [index_column, columns_column, values_column]
    return data[needed].pivot_table(
        values=values_column, index=index_column, columns=columns_column
    )


def create_column_names(categories_for_subsetting, num_of_weeks):
    """Return ``"HSA_ID"`` followed by ``week_<word>_<category>`` names.

    Weeks run from 1 to ``num_of_weeks`` inclusive and are spelled out with
    ``num2words``.  Within each week the categories keep their given order.
    """
    week_words = (num2words(number) for number in range(1, num_of_weeks + 1))
    weekly_names = [
        f"week_{word}_{category}"
        for word in week_words
        for category in categories_for_subsetting
    ]
    return ["HSA_ID"] + weekly_names


def create_collated_weekly_data(
    pivoted_table, original_data, categories_for_subsetting, geography, column_names
):
    """Collate a pivoted weekly table into one row per geography.

    For every unique value of ``original_data[geography]`` the columns
    ``<geo>_<category>`` of ``pivoted_table`` are read row by row and written
    side by side into the output row, starting at column position 1.

    Args:
        pivoted_table: Frame with ``<geo>_<category>`` columns; rows are
            looked up by the labels 0, 1, 2, ... (assumes a default integer
            index -- TODO confirm against the caller).
        original_data: Frame whose ``geography`` column lists the geographies.
        categories_for_subsetting: Category suffixes, in output order.
        geography: Name of the identifier column; must be in ``column_names``.
        column_names: Column labels of the returned frame.

    Returns:
        A frame pre-allocated with 51 rows and ``column_names`` columns.
        Rows without data stay NaN.
    """
    n_categories = len(categories_for_subsetting)
    collated_data = pd.DataFrame(index=range(51), columns=column_names)

    for x, geo in enumerate(original_data[geography].unique()):
        collated_data.loc[x, geography] = geo
        columns_to_subset = [
            f"{geo}_{category}" for category in categories_for_subsetting
        ]
        j = 1
        try:
            for row in range(len(pivoted_table.loc[:, columns_to_subset])):
                collated_data.iloc[x, j : j + n_categories] = pivoted_table.loc[
                    row, columns_to_subset
                ]
                j += n_categories
        except Exception:
            # Deliberate best-effort: a geography with missing columns or
            # more weeks than fit is left (partly) empty.  Unlike the former
            # bare ``except``, this lets KeyboardInterrupt and SystemExit
            # propagate.
            pass

    return collated_data


def add_changes_by_week(weekly_data_frame, outcome_column):
    """Insert a ``<column>_delta`` column after each non-outcome column.

    The delta is the row-to-row difference (``Series.diff``) of the source
    column, so the first row is NaN.  The first column of the frame is
    skipped, as is every column whose lower-cased name contains
    ``outcome_column``.

    The frame is modified in place and also returned.
    """
    # Snapshot the labels first: new columns are inserted while iterating.
    for column in list(weekly_data_frame.columns[1:]):
        if outcome_column in column.lower():
            continue  # leave the outcome column without a delta
        position = weekly_data_frame.columns.get_loc(column)
        # A single insert is enough; the former follow-up re-assignment of
        # the same values to the same column was redundant.
        weekly_data_frame.insert(
            position + 1, column + "_delta", weekly_data_frame[column].diff()
        )
    return weekly_data_frame


### Exact variant: the outcome in exactly x weeks
def prep_training_test_data(
    data, no_weeks, weeks_in_future, geography, weight_col, keep_output
):
    """Stack per-week features, targets and weights for the "exact" horizon.

    For each week ``w`` in ``no_weeks`` the feature columns are those whose
    name contains ``_<w in words>_``.  The target is the last column whose
    name contains ``_<w + weeks_in_future in words>_``.  Rows with a missing
    value on either side are dropped so both sides cover the same
    geographies.

    Args:
        data: Wide table with a ``geography`` column, per-week columns and a
            ``weight_col`` column.
        no_weeks: Iterable of week numbers used as feature weeks.
        weeks_in_future: Offset in weeks from feature week to target week.
        geography: Name of the geography identifier column.
        weight_col: Name of the sample-weight column.
        keep_output: If True keep every feature-week column.  If False also
            drop the last one (presumably that week's outcome column --
            depends on the column order of ``data``; confirm).

    Returns:
        Tuple ``(X_data, y_data, weights_all, missing_data)``.  ``X_data``
        has integer column labels starting at 1.  ``y_data`` and
        ``weights_all`` are single-column frames labelled ``0``.
        ``missing_data`` holds, per week, the percentage of geographies
        dropped.
    """
    # Substring tags that select the feature-week and target-week columns.
    week_tags = []
    for week in no_weeks:
        test_week = int(week) + weeks_in_future
        week_tags.append(
            ("_" + num2words(week) + "_", "_" + num2words(test_week) + "_")
        )

    # Trailing columns removed from the features: the weights, plus the last
    # feature-week column when keep_output is False.
    n_trailing = 1 if keep_output else 2

    feature_frames = []
    target_frames = []
    weight_frames = []
    missing_data = []

    for x_week, y_week_tag in week_tags:
        weeks_x = [col for col in data.columns if x_week in col]
        columns_x = [geography] + weeks_x + [weight_col]
        weeks_y = [col for col in data.columns if y_week_tag in col]

        # Keep only geographies that are complete on both sides.
        data_x = data[columns_x].dropna()
        data_y = data[[geography] + weeks_y].dropna()
        data_y = data_y[data_y[geography].isin(data_x[geography])]
        data_x = data_x[data_x[geography].isin(data_y[geography])]

        n_geos = len(data[geography].unique())
        n_kept = len(data_x[geography].unique())
        missing_data.append(((n_geos - n_kept) / n_geos) * 100)

        X_week = data_x.iloc[:, 1 : len(columns_x)]  # drop the geography column
        weights = X_week.iloc[:, -1]
        X_week = X_week.iloc[:, : len(X_week.columns) - n_trailing]
        # Positional labels so that different weeks line up when stacked.
        X_week.columns = range(1, len(data_x.columns) - n_trailing)
        y_week = data_y.iloc[:, -1].astype(int)

        feature_frames.append(X_week)
        # Label the single column 0 explicitly: callers index it with [0].
        target_frames.append(y_week.to_frame(name=0))
        weight_frames.append(weights.to_frame(name=0))

    def _stack(frames):
        # pd.concat rejects an empty list; no weeks means an empty frame.
        if not frames:
            return pd.DataFrame()
        return pd.concat(frames).reset_index(drop=True)

    return (
        _stack(feature_frames),
        _stack(target_frames),
        _stack(weight_frames),
        missing_data,
    )

### This code is for the outcome in ANY week of the x-week period
def prep_training_test_data_period(
    data, no_weeks, weeks_in_future, geography, weight_col, keep_output
):
    """Stack per-week features, targets and weights for the "period" horizon.

    For each week ``w`` in ``no_weeks`` the feature columns are those whose
    name contains ``_<w in words>_``.  The target is the row-wise maximum of
    the ``..._<week>_beds_over_15_100k`` columns for weeks ``w + 1`` through
    ``w + weeks_in_future``.  Rows with a missing value in the features, in
    the target-week columns or in the period maximum are dropped.

    Args:
        data: Wide table with a ``geography`` column, per-week columns and a
            ``weight_col`` column.
        no_weeks: Iterable of week numbers used as feature weeks.
        weeks_in_future: Length of the look-ahead period in weeks.
        geography: Name of the geography identifier column.
        weight_col: Name of the sample-weight column.
        keep_output: If True keep every feature-week column.  If False also
            drop the last one (presumably that week's outcome column --
            depends on the column order of ``data``; confirm).

    Returns:
        Tuple ``(X_data, y_data, weights_all, missing_data)``.  ``X_data``
        has integer column labels starting at 1.  ``y_data`` and
        ``weights_all`` are single-column frames labelled ``0``.
        ``missing_data`` holds, per week, the percentage of geographies
        dropped.
    """
    # Substring tags that select the feature-week and target-week columns.
    week_tags = []
    for week in no_weeks:
        test_week = int(week) + weeks_in_future
        week_tags.append(
            ("_" + num2words(week) + "_", "_" + num2words(test_week) + "_")
        )

    # Trailing columns removed from the features: the weights, plus the last
    # feature-week column when keep_output is False.
    n_trailing = 1 if keep_output else 2

    feature_frames = []
    target_frames = []
    weight_frames = []
    missing_data = []

    for x_week, y_week_tag in week_tags:
        weeks_x = [col for col in data.columns if x_week in col]
        columns_x = [geography] + weeks_x + [weight_col]
        weeks_y = [col for col in data.columns if y_week_tag in col]

        # Weeks of the look-ahead period, recovered from the word tags.
        train_week = w2n.word_to_num(x_week.replace("_", ""))
        target_week = w2n.word_to_num(y_week_tag.replace("_", ""))
        y_weeks_to_check = [
            "_" + num2words(week_to_check) + "_" + "beds_over_15_100k"
            for week_to_check in range(train_week + 1, target_week + 1)
        ]
        columns_to_check = [
            col for col in data.columns if any(week in col for week in y_weeks_to_check)
        ]
        # Period target: the largest indicator value across the period.
        y_over_in_period = data[columns_to_check].apply(max, axis=1)

        # Keep only geographies that are complete on both sides.
        data_x = data[columns_x].dropna()
        data_y = pd.concat(
            [data[[geography] + weeks_y], y_over_in_period], axis=1
        ).dropna()
        data_y = data_y[data_y[geography].isin(data_x[geography])]
        data_x = data_x[data_x[geography].isin(data_y[geography])]

        n_geos = len(data[geography].unique())
        n_kept = len(data_x[geography].unique())
        missing_data.append(((n_geos - n_kept) / n_geos) * 100)

        X_week = data_x.iloc[:, 1 : len(columns_x)]  # drop the geography column
        weights = X_week.iloc[:, -1]
        X_week = X_week.iloc[:, : len(X_week.columns) - n_trailing]
        # Positional labels so that different weeks line up when stacked.
        X_week.columns = range(1, len(data_x.columns) - n_trailing)
        y_week = data_y.iloc[:, -1].astype(int)  # last column = period maximum

        feature_frames.append(X_week)
        # Label the single column 0 explicitly: callers index it with [0].
        target_frames.append(y_week.to_frame(name=0))
        weight_frames.append(weights.to_frame(name=0))

    def _stack(frames):
        # pd.concat rejects an empty list; no weeks means an empty frame.
        if not frames:
            return pd.DataFrame()
        return pd.concat(frames).reset_index(drop=True)

    return (
        _stack(feature_frames),
        _stack(target_frames),
        _stack(weight_frames),
        missing_data,
    )


def calculate_metrics(confusion_matrix):
    """Return (sensitivity, specificity, PPV, NPV) from a 2x2 matrix.

    The matrix is indexed as ``[actual, predicted]``: ``[0, 0]`` is TN,
    ``[0, 1]`` is FP, ``[1, 0]`` is FN and ``[1, 1]`` is TP.  A ratio whose
    denominator is not positive is reported as 0.0.
    """

    def _safe_ratio(hits, total):
        return hits / total if total > 0 else 0.0

    true_neg = confusion_matrix[0, 0]
    false_pos = confusion_matrix[0, 1]
    false_neg = confusion_matrix[1, 0]
    true_pos = confusion_matrix[1, 1]

    return (
        _safe_ratio(true_pos, true_pos + false_neg),  # sensitivity
        _safe_ratio(true_neg, true_neg + false_pos),  # specificity
        _safe_ratio(true_pos, true_pos + false_pos),  # PPV
        _safe_ratio(true_neg, true_neg + false_neg),  # NPV
    )


def merge_and_rename_data(data1, data2, on_column, suffix1, suffix2):
    """Merge on ``on_column``, then strip the key name from suffixed columns.

    Overlapping columns are suffixed ``_<suffix1>`` / ``_<suffix2>``.  Any
    ``_<on_column>_<suffix>`` fragment in a resulting name is then reduced to
    ``_<suffix>``.  Returns the merged frame.
    """
    combined = pd.merge(
        data1,
        data2,
        on=on_column,
        suffixes=("_" + suffix1, "_" + suffix2),
    )

    key_with_first = f"_{on_column}_{suffix1}"
    key_with_second = f"_{on_column}_{suffix2}"
    name_map = {}
    for current in combined.columns:
        shortened = current.replace(key_with_first, f"_{suffix1}")
        name_map[current] = shortened.replace(key_with_second, f"_{suffix2}")

    combined.rename(columns=name_map, inplace=True)
    return combined


def pivot_data_by_HSA(data, index_column, columns_column, values_column):
    """Reshape the three named columns of ``data`` into a pivot table.

    Rows are the ``index_column`` values, columns are the ``columns_column``
    values and cells come from ``values_column`` using ``pivot_table``'s
    default aggregation.
    """
    selected = data[[index_column, columns_column, values_column]]
    return selected.pivot_table(
        values=values_column,
        columns=columns_column,
        index=index_column,
    )


def add_changes_by_week(weekly_data_frame, outcome_column):
    """Insert a ``<column>_delta`` column after each non-outcome column.

    The delta is the row-to-row difference (``Series.diff``) of the source
    column, so the first row is NaN.  The first column of the frame is
    skipped, as is every column whose lower-cased name contains
    ``outcome_column``.

    The frame is modified in place and also returned.
    """
    # Snapshot the labels first: new columns are inserted while iterating.
    source_columns = list(weekly_data_frame.columns[1:])
    for column in source_columns:
        if outcome_column in column.lower():
            continue  # leave the outcome column without a delta
        delta = weekly_data_frame[column].diff()
        insert_at = weekly_data_frame.columns.get_loc(column) + 1
        # One insert suffices; the earlier extra assignment of the same
        # values to the freshly inserted column did nothing.
        weekly_data_frame.insert(insert_at, column + "_delta", delta)
    return weekly_data_frame


def determine_covid_outcome_indicator(
    new_cases_per_100k, new_admits_per_100k, percent_beds_100k
):
    """Classify one observation as "Low", "Medium" or "High".

    The case rate picks the tier (below 200 vs. 200 and above).  Admissions
    and bed occupancy then pick the level within the tier.  Returns None when
    no branch matches, e.g. when the case rate is NaN.

    NOTE(review): the bed-occupancy cut-offs mix 0.10 with 10 and 15, so
    they look like two different units (fraction vs. percent).  The values
    are kept exactly as found; confirm which unit the data uses.
    """
    admits = new_admits_per_100k
    beds = percent_beds_100k

    if new_cases_per_100k < 200:
        if not (admits >= 10 or beds > 0.10):
            return "Low"
        if admits >= 20 or beds >= 15:
            return "High"
        return "Medium"

    if new_cases_per_100k >= 200:
        if admits >= 10 or beds >= 0.10:
            return "High"
        if admits < 10 or beds < 10:
            return "Medium"

    return None


def simplify_labels_graphviz(graph):
    """Shorten the node labels of ``graph`` in place.

    Labels are split on ``<br/>``.  A four-part label keeps its first part
    and the text after ``=`` in its last part.  A three-part label keeps only
    the text after ``=`` in its last part, with ``<`` put back in front.
    Labels of any other length are written back unchanged, and nodes without
    a label are skipped.
    """
    for node in graph.get_node_list():
        label = node.get_attributes().get("label")
        if label is None:
            continue

        parts = label.split("<br/>")
        if len(parts) == 4:
            # Keep the first part; drop the two middle parts (commented as
            # sample count and sample split in the original).
            parts = [parts[0], parts[3].split("=")[1].strip()]
        elif len(parts) == 3:
            # No rule part here; the opening "<" sat in the dropped first
            # part, so restore it.
            parts = ["<" + parts[2].split("=")[1].strip()]

        node.set("label", "<br/>".join(parts))


def generate_decision_tree_graph(classifier, class_names, feature_names):
    """Export ``classifier`` to Graphviz DOT text and parse it with pydotplus.

    The export uses filled, rounded nodes with special characters enabled.
    Values are shown as counts (``proportion=False``) at precision 0, and
    impurity is not shown.  Returns the parsed graph.
    """
    export_options = {
        "class_names": class_names,
        "feature_names": feature_names,
        "filled": True,
        "rounded": True,
        "special_characters": True,
        "proportion": False,
        "precision": 0,
        "impurity": False,
    }
    buffer = StringIO()
    tree.export_graphviz(classifier, out_file=buffer, **export_options)
    return pydotplus.graph_from_dot_data(buffer.getvalue())


def _prepare_geo_split(
    data,
    geography_column,
    geo_split,
    prep_function,
    no_weeks_train,
    no_weeks_test,
    weeks_in_future,
    weight_col,
    keep_output,
):
    """Draw one random geography split and build its train and test sets."""
    geo_names = data[geography_column].unique()
    num_names_to_select = int(geo_split * len(geo_names))
    geos_for_sample = random.sample(list(geo_names), num_names_to_select)
    in_train = data[geography_column].isin(geos_for_sample)

    X_train, y_train, weights_train, _ = prep_function(
        data[in_train],
        no_weeks=no_weeks_train,
        weeks_in_future=weeks_in_future,
        geography=geography_column,
        weight_col=weight_col,
        keep_output=keep_output,
    )
    X_test, y_test, _, _ = prep_function(
        data[~in_train],
        no_weeks=no_weeks_test,
        weeks_in_future=weeks_in_future,
        geography=geography_column,
        weight_col=weight_col,
        keep_output=keep_output,
    )
    # The prep functions return the weights as a frame; take its column 0.
    return X_train, y_train, weights_train[0], X_test, y_test


def cross_validation_leave_geo_out(
    data,
    geography_column,
    geo_split,
    no_iterations,
    cv,
    classifier,
    param_grid,
    no_iterations_param,
    no_weeks_train,
    no_weeks_test,
    weeks_in_future,
    weight_col,
    keep_output,
    time_period,
):
    """Select decision-tree hyperparameters on held-out geographies.

    Each of the ``no_iterations`` rounds draws a random fraction
    ``geo_split`` of the geographies for training and keeps the rest for
    testing.  A ``RandomizedSearchCV`` is run on the training part, then a
    ``DecisionTreeClassifier`` is refitted with the best parameters and
    scored by ROC AUC on the held-out part.  A split whose test targets are
    all 1 is redrawn.

    Args:
        data: Wide weekly table passed to the prep function.
        geography_column: Name of the geography identifier column.
        geo_split: Fraction of geographies used for training.
        no_iterations: Number of random splits to evaluate.
        cv: Cross-validation splitter for the randomized search.
        classifier: Estimator handed to ``RandomizedSearchCV``.
        param_grid: Parameter distributions for the search.
        no_iterations_param: ``n_iter`` of the randomized search.
        no_weeks_train: Feature weeks for the training set.
        no_weeks_test: Feature weeks for the test set.
        weeks_in_future: Prediction horizon passed to the prep function.
        weight_col: Name of the sample-weight column.
        keep_output: Passed through to the prep function.
        time_period: ``"period"``, ``"exact"`` or ``"shifted"``; selects the
            prep function.

    Returns:
        The best-parameter dict from the round with the highest ROC AUC.

    Raises:
        ValueError: If ``time_period`` is not one of the three known values.
            Previously this surfaced as an unbound-variable error.
    """
    # Resolve the prep function once; each name is looked up only if chosen.
    if time_period == "period":
        prep_function = prep_training_test_data_period
    elif time_period == "exact":
        prep_function = prep_training_test_data
    elif time_period == "shifted":
        prep_function = prep_training_test_data_shifted
    else:
        raise ValueError(
            f"time_period must be 'period', 'exact' or 'shifted', got {time_period!r}"
        )

    def _draw_split():
        return _prepare_geo_split(
            data,
            geography_column,
            geo_split,
            prep_function,
            no_weeks_train,
            no_weeks_test,
            weeks_in_future,
            weight_col,
            keep_output,
        )

    best_hyperparameters_per_iter = []
    auROC_per_iter = []

    for i in range(no_iterations):
        print(i)
        (
            X_sample_train,
            y_sample_train,
            weights_train,
            X_sample_test,
            y_sample_test,
        ) = _draw_split()

        # Redraw while the held-out targets contain only 1's.
        while (int(y_sample_test.sum().iloc[0]) / len(y_sample_test)) == 1:
            print("All 1")
            (
                X_sample_train,
                y_sample_train,
                weights_train,
                X_sample_test,
                y_sample_test,
            ) = _draw_split()

        random_search = RandomizedSearchCV(
            classifier, param_grid, n_iter=no_iterations_param, cv=cv, random_state=10
        )
        random_search.fit(X_sample_train, y_sample_train, sample_weight=weights_train)
        best_params = random_search.best_params_

        # Refit a fresh tree with the selected hyperparameters.
        model = DecisionTreeClassifier(
            **best_params, random_state=10, class_weight="balanced"
        )
        model_fit = model.fit(
            X_sample_train, y_sample_train, sample_weight=weights_train
        )
        y_pred = model_fit.predict_proba(X_sample_test)

        best_hyperparameters_per_iter.append(best_params)
        auROC_per_iter.append(roc_auc_score(y_sample_test, y_pred[:, 1]))

    return best_hyperparameters_per_iter[np.argmax(np.array(auROC_per_iter))]


def LOOCV_by_HSA_dataset(dataframe, geo_ID, geo_ID_col):
    """Split ``dataframe`` for leave-one-geography-out validation.

    Returns ``(training, testing)``.  ``testing`` holds the rows whose
    ``geo_ID_col`` equals ``geo_ID`` and ``training`` holds the rows where it
    differs.
    """
    geo_values = dataframe[geo_ID_col]
    training_rows = dataframe.loc[geo_values != geo_ID]
    testing_rows = dataframe.loc[geo_values == geo_ID]
    return training_rows, testing_rows


def save_in_HSA_dictionary(
    prediction_week,
    ROC_by_week,
    accuracy_by_week,
    sensitivity_by_week,
    specificity_by_week,
    ppv_by_week,
    npv_by_week,
    ROC_by_HSA,
    accuracy_by_HSA,
    sensitivity_by_HSA,
    specificity_by_HSA,
    ppv_by_HSA,
    npv_by_HSA,
):
    """Store each weekly metric in its mapping under ``prediction_week``.

    The six ``*_by_HSA`` mappings are updated in place; nothing is returned.
    """
    destinations_and_values = (
        (ROC_by_HSA, ROC_by_week),
        (accuracy_by_HSA, accuracy_by_week),
        (sensitivity_by_HSA, sensitivity_by_week),
        (specificity_by_HSA, specificity_by_week),
        (ppv_by_HSA, ppv_by_week),
        (npv_by_HSA, npv_by_week),
    )
    for store, weekly_value in destinations_and_values:
        store[prediction_week] = weekly_value


def prep_training_test_data_shifted(
    data, no_weeks, weeks_in_future, geography, weight_col, keep_output
):
    """Stack per-week features, targets and weights for the "shifted" horizon.

    For each week ``w`` in ``no_weeks`` the feature columns are those whose
    name contains ``_<w in words>_``.  The target is the row-wise maximum of
    the ``..._<week>_beds_over_15_100k`` columns for weeks ``w + 2`` through
    ``w + weeks_in_future + 1``, i.e. the week right after ``w`` is skipped.
    Rows with a missing value in the features, in the columns of week
    ``w + weeks_in_future`` or in the period maximum are dropped.

    Args:
        data: Wide table with a ``geography`` column, per-week columns and a
            ``weight_col`` column.
        no_weeks: Iterable of week numbers used as feature weeks.
        weeks_in_future: Length of the look-ahead period in weeks.
        geography: Name of the geography identifier column.
        weight_col: Name of the sample-weight column.
        keep_output: If True keep every feature-week column.  If False also
            drop the last one (presumably that week's outcome column --
            depends on the column order of ``data``; confirm).

    Returns:
        Tuple ``(X_data, y_data, weights_all, missing_data)``.  ``X_data``
        has integer column labels starting at 1.  ``y_data`` and
        ``weights_all`` are single-column frames labelled ``0``.
        ``missing_data`` holds, per week, the percentage of geographies
        dropped.
    """
    # Substring tags that select the feature-week and target-week columns.
    week_tags = []
    for week in no_weeks:
        test_week = int(week) + weeks_in_future
        week_tags.append(
            ("_" + num2words(week) + "_", "_" + num2words(test_week) + "_")
        )

    # Trailing columns removed from the features: the weights, plus the last
    # feature-week column when keep_output is False.
    n_trailing = 1 if keep_output else 2

    feature_frames = []
    target_frames = []
    weight_frames = []
    missing_data = []

    for x_week, y_week_tag in week_tags:
        weeks_x = [col for col in data.columns if x_week in col]
        columns_x = [geography] + weeks_x + [weight_col]
        weeks_y = [col for col in data.columns if y_week_tag in col]

        # Weeks of the shifted look-ahead period, recovered from the word
        # tags.  The +2 offset skips the week directly after the feature week.
        train_week = w2n.word_to_num(x_week.replace("_", ""))
        target_week = w2n.word_to_num(y_week_tag.replace("_", ""))
        y_weeks_to_check = [
            "_" + num2words(week_to_check) + "_" + "beds_over_15_100k"
            for week_to_check in range(train_week + 2, target_week + 2)
        ]
        columns_to_check = [
            col for col in data.columns if any(week in col for week in y_weeks_to_check)
        ]
        # Period target: the largest indicator value across the period.
        y_over_in_period = data[columns_to_check].apply(max, axis=1)

        # Keep only geographies that are complete on both sides.
        data_x = data[columns_x].dropna()
        data_y = pd.concat(
            [data[[geography] + weeks_y], y_over_in_period], axis=1
        ).dropna()
        data_y = data_y[data_y[geography].isin(data_x[geography])]
        data_x = data_x[data_x[geography].isin(data_y[geography])]

        n_geos = len(data[geography].unique())
        n_kept = len(data_x[geography].unique())
        missing_data.append(((n_geos - n_kept) / n_geos) * 100)

        X_week = data_x.iloc[:, 1 : len(columns_x)]  # drop the geography column
        weights = X_week.iloc[:, -1]
        X_week = X_week.iloc[:, : len(X_week.columns) - n_trailing]
        # Positional labels so that different weeks line up when stacked.
        X_week.columns = range(1, len(data_x.columns) - n_trailing)
        y_week = data_y.iloc[:, -1].astype(int)  # last column = period maximum

        feature_frames.append(X_week)
        # Label the single column 0 explicitly: callers index it with [0].
        target_frames.append(y_week.to_frame(name=0))
        weight_frames.append(weights.to_frame(name=0))

    def _stack(frames):
        # pd.concat rejects an empty list; no weeks means an empty frame.
        if not frames:
            return pd.DataFrame()
        return pd.concat(frames).reset_index(drop=True)

    return (
        _stack(feature_frames),
        _stack(target_frames),
        _stack(weight_frames),
        missing_data,
    )


def LOOCV_by_HSA_dataset(dataframe, geo_ID, geo_ID_col):
    """Return ``(training, testing)`` frames for one held-out geography.

    Rows whose ``geo_ID_col`` equals ``geo_ID`` form the testing frame; rows
    where it differs form the training frame.
    """
    column = dataframe[geo_ID_col]
    held_out = dataframe[column == geo_ID]
    remainder = dataframe[column != geo_ID]
    return remainder, held_out


def save_in_HSA_dictionary(
    prediction_week,
    ROC_by_week,
    accuracy_by_week,
    sensitivity_by_week,
    specificity_by_week,
    ppv_by_week,
    npv_by_week,
    ROC_by_HSA,
    accuracy_by_HSA,
    sensitivity_by_HSA,
    specificity_by_HSA,
    ppv_by_HSA,
    npv_by_HSA,
):
    """Record one week's metrics in the six per-metric mappings.

    Each ``*_by_HSA`` mapping gets its ``*_by_week`` value under the key
    ``prediction_week``.  The mappings are mutated in place and nothing is
    returned.
    """
    stores = (
        ROC_by_HSA,
        accuracy_by_HSA,
        sensitivity_by_HSA,
        specificity_by_HSA,
        ppv_by_HSA,
        npv_by_HSA,
    )
    weekly_values = (
        ROC_by_week,
        accuracy_by_week,
        sensitivity_by_week,
        specificity_by_week,
        ppv_by_week,
        npv_by_week,
    )
    for store, weekly_value in zip(stores, weekly_values):
        store[prediction_week] = weekly_value


######### IMPORT DATA ##############
# Alternative input file, kept for reference:
# HSA_weekly_data_all = pd.read_csv(
#    "/Users/rem76/Documents/COVID_projections/Exact_analysis_smaller_hyperparameters/Expanding_models_15_per_100k/hsa_time_data_all_dates_CDC_features_only_incl_NA.csv"
# )

# Load the weekly HSA table used by every model below.
HSA_weekly_data_all = pd.read_csv(
    "/Users/rem76/Documents/COVID_projections/hsa_time_data_all_dates_weekly_incl_NA.csv"
)

# Drop every column whose name contains "cases", then every column whose name
# contains "deaths", so neither is available as a model input.
for excluded_term in ("cases", "deaths"):
    columns_to_remove = [
        col for col in HSA_weekly_data_all.columns if excluded_term in col
    ]
    HSA_weekly_data_all = HSA_weekly_data_all.drop(columns=columns_to_remove)



########### SET UP FOR EXPANDING MODELS
# Base decision tree handed to the hyperparameter search.
clf = DecisionTreeClassifier(class_weight="balanced", random_state=10)

# Geography hold-out settings passed to cross_validation_leave_geo_out.
# NOTE(review): the exact meaning of geo_split / no_iterations is defined by
# that helper — presumably the training share of geographies and the number
# of random hold-out repeats; confirm there.
geography_column = "HSA_ID"
geo_split = 0.9
no_iterations = 10
time_period = "exact"  # Choose 'period', 'exact', or 'shifted'

# Expanding-window settings (all counted in weeks).
train_weeks_for_initial_model = 1
size_of_test_dataset = 1
weeks_in_future = 3

# Column handling passed to prep_training_test_data.
weight_col = "weight"
keep_output = True

# Hyperparameter search space: 2 criteria x 3 depths (2, 3, 4) = 6 combinations,
# so 6 RandomizedSearchCV iterations are enough to cover the whole grid.
no_iterations_param = 6
param_grid = {
    "criterion": ["gini", "entropy"],
    "max_depth": np.arange(2, 5),
}

# Stratified 10-fold cross-validation, repeated 10 times with a fixed seed.
cv = RepeatedStratifiedKFold(n_splits=10, n_repeats=10, random_state=1)


######### ACTUAL RUNS ############
# One entry is appended to each list per prediction week, in loop order.
ROC_by_week_full_period = []
sensitivity_by_week_full_period = []
specificity_by_week_full_period = []
ppv_by_week_full_period = []
npv_by_week_full_period = []
accuracy_by_week_full_period = []
norm_MCC_by_week_full_period = []

# All weeks are processed in a single run. (An earlier
# `weeks_to_predict = [job_number]` assignment was always overwritten by this
# range before being used, so it has been removed; restore it in place of the
# range to run one week per job.)
weeks_to_predict = range(1, 123, 1)
for prediction_week in weeks_to_predict:
    print(prediction_week)

    # Expanding window: train on weeks 1..(prediction_week +
    # train_weeks_for_initial_model), test on the size_of_test_dataset weeks
    # that immediately follow.
    no_weeks_train = range(1, int(prediction_week + train_weeks_for_initial_model) + 1)
    no_weeks_test = range(
        int(prediction_week + train_weeks_for_initial_model) + 1,
        int(prediction_week + train_weeks_for_initial_model + size_of_test_dataset) + 1,
    )
    (
        X_train_full_period,
        y_train_full_period,
        weights_full_period,
        missing_data_train_HSA,
    ) = prep_training_test_data(
        HSA_weekly_data_all,
        no_weeks=no_weeks_train,
        weeks_in_future=weeks_in_future,
        geography=geography_column,
        weight_col=weight_col,
        keep_output=keep_output,
    )

    (
        X_test_full_period,
        y_test_full_period,
        weights_test_full_period,
        missing_data_test_HSA,
    ) = prep_training_test_data(
        HSA_weekly_data_all,
        no_weeks=no_weeks_test,
        weeks_in_future=weeks_in_future,
        geography=geography_column,
        weight_col=weight_col,
        keep_output=keep_output,
    )
    # Take the entry labelled 0 of the returned weights (presumably the single
    # column of a DataFrame — confirm in prep_training_test_data) as a NumPy
    # array for use as sample_weight.
    weights_full_period = weights_full_period[0].to_numpy()

    # Hyperparameter selection with geographies held out.
    # NOTE(review): from the second iteration on, `clf` is the tree fitted in
    # the previous week rather than the untuned base classifier. Confirm that
    # cross_validation_leave_geo_out only uses it as a template for the search.
    best_params = cross_validation_leave_geo_out(
        HSA_weekly_data_all,
        geography_column=geography_column,
        geo_split=geo_split,
        no_iterations=no_iterations,
        cv=cv,
        classifier=clf,
        param_grid=param_grid,
        no_iterations_param=no_iterations_param,
        no_weeks_train=no_weeks_train,
        no_weeks_test=no_weeks_test,
        weeks_in_future=weeks_in_future,
        weight_col=weight_col,
        keep_output=keep_output,
        time_period=time_period,
    )

    # Refit on the full training window with the selected hyperparameters.
    clf = DecisionTreeClassifier(
        **best_params, random_state=10, class_weight="balanced"
    )
    clf.fit(
        X_train_full_period, y_train_full_period, sample_weight=weights_full_period
    )

    # Make predictions on the test set
    y_pred = clf.predict(X_test_full_period)
    y_pred_proba = clf.predict_proba(X_test_full_period)

    # Evaluate the model on the held-out week(s)
    accuracy_by_week_full_period.append(accuracy_score(y_test_full_period, y_pred))
    # NOTE(review): a recorded run of this loop stopped at prediction_week 44
    # with "ValueError: multi_class must be in ('ovo', 'ovr')". That message
    # presumably comes from a roc_auc_score call (here or inside
    # cross_validation_leave_geo_out) receiving non-binary true labels —
    # check the target values for that week.
    ROC_by_week_full_period.append(
        roc_auc_score(y_test_full_period, y_pred_proba[:, 1])
    )
    conf_matrix = confusion_matrix(y_test_full_period, y_pred)

    model_name_to_save = (
        "/Users/rem76/Documents/COVID_projections/Exact_analysis_smaller_hyperparameters/No_cases_no_deaths/No_cases_no_deaths" + time_period + "_" + str(prediction_week) + ".sav"
    )

    # Persist this week's fitted model; the context manager guarantees the
    # file is flushed and closed even if pickling fails.
    with open(model_name_to_save, "wb") as model_file:
        pickle.dump(clf, model_file)

    sensitvity, specificity, ppv, npv = calculate_metrics(conf_matrix)
    specificity_by_week_full_period.append(specificity)
    # Sensitivity (true positive rate)
    sensitivity_by_week_full_period.append(sensitvity)
    # Rescale MCC from [-1, 1] to [0, 1].
    norm_MCC_by_week_full_period.append(
        (matthews_corrcoef(y_test_full_period, y_pred) + 1) / 2
    )

    ppv_by_week_full_period.append(ppv)
    npv_by_week_full_period.append(npv)

1
0
1
2
3
4
5
6
7
8
9
2
0
1
2
3
4
5
6
7
8
9
3
0
1
2
3
4
5
6
7
8
9
4
0
1
2
3
4
5
6
7
8
9
5
0
1
2
3
4
5
6
7
8
9
6
0
1
2
3
4
5
6
7
8
9
7
0
1
2
3
4
5
6
7
8
9
8
0
1
2
3
4
5
6
7
8
9
9
0
1
2
3
4
5
6
7
8
9
10
0
1
2
3
4
5
6
7
8
9
11
0
1
2
3
4
5
6
7
8
9
12
0
1
2
3
4
5
6
7
8
9
13
0
1
2
3
4
5
6
7
8
9
14
0
1
2
3
4
5
6
7
8
9
15
0
1
2
3
4
5
6
7
8
9
16
0
1
2
3
All 1
4
5
6
7
8
9
17
0
1
2
3
All 1
4
5
6
7
8
9
18
0
1
2
All 1
3
4
5
6
7
8
9
19
0
1
2
3
4
5
6
7
8
9
20
0
1
2
3
4
5
6
7
8
9
21
0
1
2
3
4
5
6
7
8
9
22
0
1
2
3
4
5
6
7
8
9
23
0
1
2
3
4
5
6
7
8
9
24
0
1
2
3
4
5
6
7
8
9
25
0
1
2
3
4
5
6
7
8
9
26
0
1
2
3
4
5
6
7
8
9
27
0
1
2
3
4
5
6
7
8
9
28
0
1
2
3
4
5
6
7
8
9
29
0
1
2
3
4
5
6
7
8
9
30
0
1
2
3
4
5
6
7
8
9
31
0
1
2
3
4
5
6
7
8
9
32
0
1
2
3
4
5
6
7
8
9
33
0
1
2
3
4
5
6
7
8
9
34
0
1
2
3
4
5
6
7
8
9
35
0
1
2
3
4
5
6
7
8
9
36
0
1
2
3
4
5
6
7
8
9
37
0
1
2
3
4
5
6
7
8
9
38
0
1
2
3
4
5
6
7
8
9
39
0
1
2
3
4
5
6
7
8
9
40
0
1
2
3
4
5
6
7
8
9
41
0
1
2
3
4
5
6
7
8
9
42
0
1
2
3
4
5
6
7
8
9
43
0
1
2
3
4
5
6
7
8
9
44

ValueError: multi_class must be in ('ovo', 'ovr')