In [None]:
from baseline_model import load_data_by_year
from baseline_model import drop_high_correlation_variables
from baseline_model import split_data_by_category
from baseline_model import drop_columns_from_data
from baseline_model import drop_columns_by_category_and_class

from SPU import load_aggregated_data
from SPU import tune_model_with_random_search
from SPU import evaluate_model_performance_with_cv
from SPU import evaluate_model_with_checks

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import pickle
import joblib
import json
import os
import random

import sklearn.metrics
from sklearn.model_selection import train_test_split, KFold, RandomizedSearchCV, ParameterGrid
from sklearn.metrics import roc_auc_score, precision_recall_curve, f1_score, precision_score, recall_score,make_scorer, average_precision_score
from sklearn.preprocessing import StandardScaler
from sklearn.feature_selection import VarianceThreshold, mutual_info_classif
from sklearn.utils.class_weight import compute_class_weight
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold

from xgboost import XGBClassifier
from sklearn.utils import resample
from scipy.stats import randint, uniform, loguniform

import torch
import torch.nn as nn
import torch.optim as optim
from imblearn.over_sampling import SMOTE
from torch.utils.data import DataLoader, TensorDataset


# Step 1: Load and prepare data

In [None]:
# Base path handed to load_data_by_year (presumably where the per-year
# training/test files live — confirm). The helper returns a (training, test) pair.
save_path = "./"
(
    data_by_year_training,
    data_by_year_test,
) = load_data_by_year(save_path)


In [None]:
# Apply the high-correlation filter to both splits. The text file presumably lists
# the variables to remove (confirm in baseline_model). Both names are rebound to the
# returned pair, so re-running this cell applies the filter again.
(
    data_by_year_training,
    data_by_year_test,
) = drop_high_correlation_variables(
    "high_correlation_variables.txt",
    data_by_year_training,
    data_by_year_test,
)

In [None]:
# The 20 condition labels that split_data_by_category uses to build per-category
# subsets of each split.
categories = [
    'CKD', 'CKD_2',
    'DIABETES', 'DIABETES_2',
    'HYPERTENSION_ESSENTIAL', 'HYPERTENSION_ALL', 'HYPERTENSIVE_DISEASE',
    'LIVER_DISEASE', 'CHRONIC_LIVER_DISEASE',
    'CIRCULATORY_SYSTEM', 'ISCHEMIC_HEART_DISEASE', 'CONGESTIVE_HEART_FAILURE',
    'STROKE', 'CARDIOVASCULAR_DISEASE',
    'CANCER', 'COPD', 'PERIPHERAL_VASCULAR',
    'NUTRITIONAL_DEFICIENCIES', 'METABOLIC_SYNDROME', 'SUBSTANCE_ABUSE',
]

# Split the test data first, then the training data.
for split_name, split_data in (
    ("test", data_by_year_test),
    ("training", data_by_year_training),
):
    split_data_by_category(categories, split_data, data_type=split_name)


In [None]:
# Remove the 20 category columns from the full training and test splits. This runs
# after the per-category subsets have been built. Presumably these are label
# columns, not features — confirm.
columns_to_drop = [
    'CKD', 'CKD_2',
    'DIABETES', 'DIABETES_2',
    'HYPERTENSION_ESSENTIAL', 'HYPERTENSION_ALL', 'HYPERTENSIVE_DISEASE',
    'LIVER_DISEASE', 'CHRONIC_LIVER_DISEASE',
    'CIRCULATORY_SYSTEM', 'ISCHEMIC_HEART_DISEASE', 'CONGESTIVE_HEART_FAILURE',
    'STROKE', 'CARDIOVASCULAR_DISEASE',
    'CANCER', 'COPD', 'PERIPHERAL_VASCULAR',
    'NUTRITIONAL_DEFICIENCIES', 'METABOLIC_SYNDROME', 'SUBSTANCE_ABUSE',
]
data_by_year_training = drop_columns_from_data(
    columns_to_drop,
    data_by_year_training,
)
data_by_year_test = drop_columns_from_data(
    columns_to_drop,
    data_by_year_test,
)

In [None]:
# Remove the same 20 category columns from every per-category/per-class subset.
# Each category's columns are dropped from its own subsets, so columns_to_drop is
# an (independent) copy of the category list.
categories = [
    'CKD', 'CKD_2',
    'DIABETES', 'DIABETES_2',
    'HYPERTENSION_ESSENTIAL', 'HYPERTENSION_ALL', 'HYPERTENSIVE_DISEASE',
    'LIVER_DISEASE', 'CHRONIC_LIVER_DISEASE',
    'CIRCULATORY_SYSTEM', 'ISCHEMIC_HEART_DISEASE', 'CONGESTIVE_HEART_FAILURE',
    'STROKE', 'CARDIOVASCULAR_DISEASE',
    'CANCER', 'COPD', 'PERIPHERAL_VASCULAR',
    'NUTRITIONAL_DEFICIENCIES', 'METABOLIC_SYNDROME', 'SUBSTANCE_ABUSE',
]
columns_to_drop = list(categories)

# Test subsets first, then training subsets.
for split_name in ("test", "training"):
    drop_columns_by_category_and_class(categories, columns_to_drop, data_type=split_name)

# Step 2: SPU — train, evaluate, and save one model per disease

In [None]:
# Diseases to train models for. Each one needs a `data_by_year_training_{disease}_1`
# dict in the notebook namespace (presumably created by split_data_by_category in
# Step 1 — confirm). Diseases without that dict are skipped.
diseases = [
    'HYPERTENSIVE_DISEASE', 'STROKE', 'CARDIOVASCULAR_DISEASE', 'CANCER',
    'COPD', 'NUTRITIONAL_DEFICIENCIES', 'METABOLIC_SYNDROME', 'SUBSTANCE_ABUSE', 'DIABETES'
]

# Loop-invariant configuration (previously re-created for every disease).
MAX_ROWS_PER_YEAR = 30000        # per-year cap on sampled training rows
SAMPLE_SEED = 42                 # random_state for the per-year down-sampling
BASE_YEARS = [2009, 2010, 2011]  # training years concatenated into the initial aggregate
UPDATE_YEARS = range(1, 2)       # only update_year=1 is currently run
RESULTS_ROOT = "./"              # per-disease output folders are created under this path

# One value per hyperparameter, so the random search has a single candidate.
param_distributions = {
    'n_estimators': [1000],
    'max_depth': [5],
    'learning_rate': [0.01],
    'subsample': [0.9],
    'colsample_bytree': [0.9]
}

categories = ['CKD', 'CKD_2', 'DIABETES', 'DIABETES_2', 'HYPERTENSION_ESSENTIAL',
              'HYPERTENSION_ALL', 'HYPERTENSIVE_DISEASE', 'LIVER_DISEASE', 'CHRONIC_LIVER_DISEASE',
              'CIRCULATORY_SYSTEM', 'ISCHEMIC_HEART_DISEASE', 'CONGESTIVE_HEART_FAILURE', 'STROKE',
              'CARDIOVASCULAR_DISEASE', 'CANCER', 'COPD', 'PERIPHERAL_VASCULAR',
              'NUTRITIONAL_DEFICIENCIES', 'METABOLIC_SYNDROME', 'SUBSTANCE_ABUSE']
class_types = ['1', '0']

# Stored metric name -> key in the dicts returned by the SPU evaluation helpers.
METRIC_NAME_MAP = {
    'AUC': 'AUC',
    'PRAUC': 'PRAUC',
    'F1-Score': 'F1',
    'Precision': 'Precision',
    'Recall': 'Recall',
}


def sample_rows_by_year(data_dict, max_rows=MAX_ROWS_PER_YEAR, seed=SAMPLE_SEED):
    """Return a new {year: DataFrame} dict with every frame capped at `max_rows` rows.

    Frames longer than `max_rows` are down-sampled with `DataFrame.sample` using
    `random_state=seed`. Shorter frames are passed through unchanged (same object).
    """
    return {
        year: df.sample(n=max_rows, random_state=seed) if len(df) > max_rows else df
        for year, df in data_dict.items()
    }


def json_default(value):
    """Fallback for `json.dump` that unwraps numpy scalars (float32, int64, ...).

    Raises:
        TypeError: for objects without an `.item()` method (json's usual error).
    """
    if hasattr(value, "item"):
        return value.item()
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")


def dump_json(obj, path, label):
    """Write `obj` to `path` as JSON (numpy scalars allowed) and print a confirmation."""
    with open(path, 'w') as f:
        json.dump(obj, f, default=json_default)
    print(f"Saved {label} to {path}")


for disease in diseases:
    data_var_name = f"data_by_year_training_{disease}_1"
    if data_var_name not in globals():
        # Without this guard the loop would still train, save and evaluate a model
        # built from an empty sample.
        print(f"Skipping {disease}: {data_var_name} is not defined.")
        continue

    data_by_year_training_sample = sample_rows_by_year(globals()[data_var_name])
    print(f"Sampling results for {disease}:")
    for year, df in data_by_year_training_sample.items():
        print(f"{year}: {len(df)} rows sampled.")

    # One output folder per disease. With a shared "./", each disease overwrote the
    # previous disease's models and result JSON files.
    base_dir = os.path.join(RESULTS_ROOT, disease)
    os.makedirs(base_dir, exist_ok=True)

    results_storage = {}
    rolling_aggregated_metrics_results = {}
    source_models_by_year = {}

    # Reset the per-(category, class) holders for this disease. Only empty per-run
    # entries are added to them below; presumably later cells read them — confirm.
    for category in categories:
        for class_type in class_types:
            globals()[f"rolling_aggregated_metrics_results_{category}_{class_type}"] = {}

    # The initial aggregate comes from the full (not disease-specific) training data.
    data_by_year_totall = pd.concat(
        [data_by_year_training[year] for year in BASE_YEARS],
        axis=0
    )
    all_years = sorted(data_by_year_training.keys())

    for update_year in UPDATE_YEARS:
        print(f"Running with update_year={update_year}")

        # Cutoffs start from the third training year.
        for cutoff_year in all_years[2:]:
            run_key = f"update_{update_year}_{cutoff_year}"
            print(f"Rolling aggregation up to year {cutoff_year} with update_year={update_year}")
            X_train, y_train = load_aggregated_data(
                data_by_year_totall=data_by_year_totall,
                data_by_year_training=data_by_year_training_sample,
                cutoff_year=cutoff_year,
                update_year=update_year
            )

            best_model, best_params = tune_model_with_random_search(
                param_distributions=param_distributions,
                use_scale_pos_weight=True,
                X_train=X_train,
                y_train=y_train
            )

            model_path = os.path.join(base_dir, f"update_{update_year}_best_model_{cutoff_year}.pkl")
            joblib.dump(best_model, model_path)
            print(f"✅ Saved newly trained model: {model_path}")
            source_models_by_year[f"{cutoff_year}_update_{update_year}"] = model_path

            rolling_aggregated_metrics_results[run_key] = {}
            for category in categories:
                for class_type in class_types:
                    globals()[f"rolling_aggregated_metrics_results_{category}_{class_type}"][run_key] = {}

            # Test on the `update_year` years after the cutoff:
            # cutoff_year + 1 .. cutoff_year + update_year, both inclusive.
            for test_year in range(cutoff_year + 1, cutoff_year + update_year + 1):
                if test_year in data_by_year_test:
                    metrics_mean = evaluate_model_performance_with_cv(
                        data_by_year_test=data_by_year_test,
                        test_year=test_year,
                        best_model=best_model,
                    )
                    # Strict indexing: a missing metric here is a bug, not "no data".
                    overall_metrics = {
                        name: metrics_mean[source] for name, source in METRIC_NAME_MAP.items()
                    }
                    rolling_aggregated_metrics_results[run_key][test_year] = overall_metrics
                    print(f"Metrics for update_year={update_year}, cutoff_year={cutoff_year}, "
                          f"tested on year {test_year}: {overall_metrics}")

                for category in categories:
                    for class_type in class_types:
                        test_data_dict = globals().get(f"data_by_year_test_{category}_{class_type}", {})
                        if not test_data_dict or test_year not in test_data_dict:
                            print(f"Missing test data for {category}_{class_type} in year {test_year}.")
                            continue

                        metrics_mean = evaluate_model_with_checks(
                            test_data_dict=test_data_dict,
                            test_year=test_year,
                            best_model=best_model,
                            category=category,
                            class_type=class_type,
                        )

                        year_results = results_storage.setdefault(category, {}).setdefault(class_type, {})
                        if metrics_mean:
                            year_results[test_year] = {
                                name: metrics_mean.get(source) for name, source in METRIC_NAME_MAP.items()
                            }
                        else:
                            # Keep an empty placeholder so the year still appears in the output.
                            year_results.setdefault(test_year, {})
                            print(f"Invalid metrics for {category}_{class_type} in year {test_year}!")

    # Persist this disease's results. The path variables keep their original names
    # for any later cells that use them.
    results_base_dir = base_dir
    categories_class_types_path = os.path.join(RESULTS_ROOT, 'categories_and_class_types.json')
    dump_json({'categories': categories, 'class_types': class_types},
              categories_class_types_path, "categories and class types")

    results_storage_path = os.path.join(results_base_dir, 'results_storage.json')
    dump_json(results_storage, results_storage_path, "results_storage")

    rolling_results_path = os.path.join(results_base_dir, 'rolling_aggregated_metrics_results.json')
    dump_json(rolling_aggregated_metrics_results, rolling_results_path, "rolling_aggregated_metrics_results")
