# In [None]:
"""
Immune-Enhanced Machine Learning Approach for Early Detection of Precancerous Colorectal Neoplasia: 
Insights from Biomarkers in Routine Health Checkups

Author      : Yohan Kim
Date        : 2025-08-08
Email       : biologyohan@gmail.com
Organization: MIH Lab, CHA University
"""

# ========== Library Imports ==========
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime
from itertools import product

from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import (
    confusion_matrix, precision_recall_curve, auc, 
    roc_auc_score, roc_curve
)
from sklearn.utils.class_weight import compute_class_weight

import torch
import torch.nn as nn
from pytorch_tabnet.tab_model import TabNetClassifier

import random

# ========== Reproducibility Setting ==========
# One shared seed is pushed into every random number generator this script
# touches: Python's `random`, NumPy, PyTorch (CPU) and PyTorch (all CUDA devices).
seed = 417
for _apply_seed in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed_all,
):
    _apply_seed(seed)

# ========== File and Data Path Setting ==========
# Input side: folder holding the workbooks and the list of workbooks to process.
data_folder = '/data/yohan/nkcell/'
file_names = ['Biopsy_Yes_key_6146_feature27_threshold200.xlsx']

# Target columns: `output_vars` lists the available names, `output_var` is the
# single column used as the label in this run.
output_vars = ['nofadvneoplasm', 'computedhighrisk']
output_var = 'computedhighrisk'

# Output side: each run writes into its own sub-folder named after the start
# time (YYYY-MM-DD_HH-MM-SS), created here if it does not exist yet.
base_output_path = '/data/yohan/nkcell/results'
timestamp = f"{datetime.now():%Y-%m-%d_%H-%M-%S}"
output_dir = os.path.join(base_output_path, timestamp)
os.makedirs(output_dir, exist_ok=True)

# ========== Device Setting ==========
# GPU index 3 is the preferred device. Requesting it on a machine that exposes
# fewer than four GPUs fails later with an invalid-device error, so the index is
# validated against the visible device count first:
#   - CUDA unavailable            -> 'cpu'
#   - at least 4 GPUs visible     -> 'cuda:3' (unchanged behaviour)
#   - 1 to 3 GPUs visible         -> 'cuda:0'
preferred_gpu_index = 3
if not torch.cuda.is_available():
    device_name = 'cpu'
elif torch.cuda.device_count() > preferred_gpu_index:
    device_name = f'cuda:{preferred_gpu_index}'
else:
    device_name = 'cuda:0'
print(f"Device in use: {device_name}")

# ========== Model Definition ==========
# Registry mapping a display name to the estimator class that gets instantiated
# for it during tuning.
models = dict(TabNet=TabNetClassifier)

# Manual grid-search space, keyed by the same names as `models`. Every entry maps
# a constructor argument to the list of values to try; the trailing comments keep
# the other candidate values noted by the author.
param_distributions = dict(
    TabNet=dict(
        n_d=[8],  # 8 16 32 64
        n_a=[8],  # 8 16 32 64
        n_steps=[7],  # 3 5 7
        optimizer_params=[{'lr': 0.02}],  # 0.01, 0.001...
    ),
)

# ========== Main Loop for Each File ==========
# For every input workbook:
#   1. grid-search the hyperparameters of each model in `models`, scoring every
#      combination by its mean ROC-AUC over 5 stratified folds;
#   2. sweep 101 decision thresholds over the out-of-fold predictions of the best
#      combination and keep the threshold with the highest fold-averaged Youden index;
#   3. save per-fold confusion matrices, pooled ROC / PR curves and a summary CSV
#      into `output_dir`.
for file_name in file_names:
    print(f"\n=== Processing File: {file_name} ===")
    base_file_name = os.path.splitext(file_name)[0]
    file_path = os.path.join(data_folder, file_name)
    data = pd.read_excel(file_path)

    # Select input features and target
    input_vars = ['age', 'sex', 'totalcholesterol', 'diabetes', 'circum', 'smokingscore', 'drinkingscore', 'wbc', 'plt', 'nkatertiary']
    X_raw = data[input_vars].copy()
    Y = data[output_var].copy()
    # Rename columns so that spaces become underscores. `rename` is used because
    # passing `columns=` to the DataFrame constructor SELECTS columns instead of
    # renaming them, which would turn any renamed feature into an all-NaN column.
    X = X_raw.rename(columns=lambda c: c.replace(' ', '_'))

    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=seed)
    thresholds = np.linspace(0, 1, 101)  # 0.00, 0.01, ..., 1.00

    best_params = {}
    model_fold_results = {}  # model name -> list of (y_true, y_proba), one per fold

    # ========== Hyperparameter Tuning ==========
    for model_name, model_cls in models.items():
        print(f"\n[Hyperparameter Tuning: {model_name}]")
        param_grid = param_distributions.get(model_name, {})
        param_keys = list(param_grid.keys())

        best_mean_auc = -1.0
        best_param_dict = {}
        best_fold_predictions = None

        # Cartesian product of all candidate values = exhaustive manual grid search
        all_param_values = [param_grid[k] for k in param_keys]
        for param_combination in product(*all_param_values):
            current_params = dict(zip(param_keys, param_combination))

            fold_aucs = []
            fold_pred_results = []  # (y_true, y_proba) for each validation fold

            for train_idx, valid_idx in skf.split(X, Y):
                # Normalization: the scaler is fitted on the training fold only and
                # then applied to the validation fold.
                scaler = MinMaxScaler()
                X_train_fold = scaler.fit_transform(X.iloc[train_idx])
                X_valid_fold = scaler.transform(X.iloc[valid_idx])
                Y_train_fold, Y_valid_fold = Y.iloc[train_idx], Y.iloc[valid_idx]

                # Initialize model with learning rate scheduler
                # NOTE(review): TabNetClassifier also accepts its own `seed`
                # argument; confirm whether the global seed is meant to be passed here.
                model_temp = model_cls(
                    device_name=device_name,
                    scheduler_params={"step_size": 10, "gamma": 0.9},
                    scheduler_fn=torch.optim.lr_scheduler.StepLR,
                    verbose=0,
                    **current_params
                )

                # Cost-sensitive learning: class weights (inverse frequency)
                classes = np.unique(Y_train_fold)
                class_weights = compute_class_weight(class_weight='balanced', classes=classes, y=Y_train_fold)
                # Move the weights to the training device so the loss can use them
                class_weights_tensor = torch.tensor(class_weights, dtype=torch.float).to(device_name)

                # Train model
                # NOTE(review): the validation fold is passed as `eval_set` (with
                # `patience`), and the same fold is later used for model selection
                # and threshold selection, so the reported metrics are presumably
                # optimistic — consider a nested / held-out split.
                model_temp.fit(
                    X_train_fold, Y_train_fold.values,
                    eval_set=[(X_valid_fold, Y_valid_fold.values)],
                    patience=20,
                    max_epochs=1000,
                    eval_metric=['auc'],
                    loss_fn=nn.CrossEntropyLoss(weight=class_weights_tensor)
                )

                # Probability of the second class (column 1) on the validation fold
                Y_proba_valid = model_temp.predict_proba(X_valid_fold)[:, 1]
                fold_aucs.append(roc_auc_score(Y_valid_fold, Y_proba_valid))
                fold_pred_results.append((Y_valid_fold.values, Y_proba_valid))

            # Select best hyperparameter set based on mean AUC
            mean_auc = np.mean(fold_aucs)
            if mean_auc > best_mean_auc:
                best_mean_auc = mean_auc
                best_param_dict = current_params
                best_fold_predictions = fold_pred_results

        best_params[model_name] = best_param_dict
        model_fold_results[model_name] = best_fold_predictions

        print(f"   -> Best Parameters: {best_param_dict}")
        print(f"   -> Mean AUC (5-fold): {best_mean_auc:.3f}")

    # ========== Evaluation with Best Parameters ==========
    overall_results = []

    for model_name, fold_predictions in model_fold_results.items():
        print(f"\n=== Final Evaluation for Model: {model_name} ===")

        # Pooled out-of-fold labels / scores, used only for drawing the curves
        all_y_true = np.concatenate([y_true for y_true, _ in fold_predictions])
        all_y_proba = np.concatenate([y_proba for _, y_proba in fold_predictions])

        # ROC-AUC and PR-AUC do not depend on the decision threshold, so they are
        # computed once per fold here instead of once per (threshold, fold) pair.
        fold_auc_pairs = []
        for y_true_fold, y_proba_fold in fold_predictions:
            roc_val_fold = roc_auc_score(y_true_fold, y_proba_fold)
            prec_fold, rec_fold, _ = precision_recall_curve(y_true_fold, y_proba_fold)
            pr_val_fold = auc(rec_fold, prec_fold)
            fold_auc_pairs.append((roc_val_fold, pr_val_fold))

        threshold_metrics = []
        for thr in thresholds:
            fold_measures = []
            for (y_true_fold, y_proba_fold), (roc_val_fold, pr_val_fold) in zip(fold_predictions, fold_auc_pairs):
                y_pred_fold = (y_proba_fold >= thr).astype(int)

                # `labels=[0, 1]` guarantees a 2x2 matrix even when one class is
                # absent from both y_true and y_pred at this threshold.
                tn, fp, fn, tp = confusion_matrix(y_true_fold, y_pred_fold, labels=[0, 1]).ravel()

                # Each ratio falls back to 0 when its denominator is empty
                sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
                specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
                ppv         = tp / (tp + fp) if (tp + fp) > 0 else 0
                npv         = tn / (tn + fn) if (tn + fn) > 0 else 0

                youden_j = sensitivity + specificity - 1
                fold_measures.append([
                    sensitivity, specificity, ppv, npv,
                    roc_val_fold, pr_val_fold, youden_j
                ])

            # Mean and (population) standard deviation across folds
            avg_vals = np.mean(fold_measures, axis=0)
            std_vals = np.std(fold_measures, axis=0)
            threshold_metrics.append([thr, *avg_vals, *std_vals])

        columns = [
            'Threshold',
            'Sensitivity', 'Specificity', 'PPV', 'NPV', 'ROC-AUC', 'PR-AUC', 'Youden',
            'Sensitivity_STD', 'Specificity_STD', 'PPV_STD', 'NPV_STD', 'ROC-AUC_STD', 'PR-AUC_STD', 'Youden_STD'
        ]
        threshold_df = pd.DataFrame(threshold_metrics, columns=columns)

        # Select optimal threshold using Youden's index
        best_idx = threshold_df['Youden'].idxmax()
        optimal_threshold = threshold_df.loc[best_idx, 'Threshold']
        print(f"   -> Optimal Threshold (Youden's index): {optimal_threshold:.3f}")

        # Plot and save Confusion Matrix per fold
        for fold_idx, (y_true_fold, y_proba_fold) in enumerate(fold_predictions):
            y_pred_fold = (y_proba_fold >= optimal_threshold).astype(int)

            cm = confusion_matrix(y_true_fold, y_pred_fold, labels=[0, 1])
            plt.figure(figsize=(5, 4))
            sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
            plt.title(f"[{model_name}] CM - Fold {fold_idx+1}\n"
                      f"Threshold={optimal_threshold:.2f}, {base_file_name}")
            plt.xlabel("Predicted")
            plt.ylabel("Actual")

            cm_img = os.path.join(
                output_dir,
                f"ConfMat_{model_name}_Fold{fold_idx+1}_{base_file_name}.png"
            )
            plt.savefig(cm_img)
            plt.close()

        # Save summary metrics (mean ± std across folds at the optimal threshold)
        row_opt = threshold_df.loc[best_idx]
        overall_results.append([
            file_name,
            model_name,
            f"{optimal_threshold:.3f}",
            f"{row_opt['Sensitivity']:.3f} ± {row_opt['Sensitivity_STD']:.3f}",
            f"{row_opt['Specificity']:.3f} ± {row_opt['Specificity_STD']:.3f}",
            f"{row_opt['PPV']:.3f} ± {row_opt['PPV_STD']:.3f}",
            f"{row_opt['NPV']:.3f} ± {row_opt['NPV_STD']:.3f}",
            f"{row_opt['ROC-AUC']:.3f} ± {row_opt['ROC-AUC_STD']:.3f}",
            f"{row_opt['PR-AUC']:.3f} ± {row_opt['PR-AUC_STD']:.3f}",
            f"{row_opt['Youden']:.3f} ± {row_opt['Youden_STD']:.3f}"
        ])

        # Plot ROC curve
        # NOTE(review): the curve is drawn from the pooled out-of-fold predictions,
        # while the AUC in the legend is the mean of the per-fold AUCs.
        fpr, tpr, _ = roc_curve(all_y_true, all_y_proba)
        plt.figure(figsize=(6, 5))
        plt.plot(fpr, tpr, label=f"ROC (AUC={row_opt['ROC-AUC']:.3f})")
        plt.plot([0, 1], [0, 1], '--', color='gray')
        plt.title(f"[{model_name}] ROC - {base_file_name}")
        plt.xlabel("False Positive Rate")
        plt.ylabel("True Positive Rate")
        plt.legend(loc='lower right')
        plt.savefig(os.path.join(output_dir, f"ROC_Curve_{model_name}_{base_file_name}.png"))
        plt.show()
        plt.close()  # release the figure so open figures do not accumulate

        # Plot PR curve (pooled predictions; legend shows the mean per-fold PR-AUC)
        prec_all, rec_all, _ = precision_recall_curve(all_y_true, all_y_proba)
        plt.figure(figsize=(6, 5))
        plt.plot(rec_all, prec_all, label=f"PR (AUC={row_opt['PR-AUC']:.3f})")
        plt.title(f"[{model_name}] PR - {base_file_name}")
        plt.xlabel("Recall")
        plt.ylabel("Precision")
        plt.legend(loc='upper right')
        plt.savefig(os.path.join(output_dir, f"PR_Curve_{model_name}_{base_file_name}.png"))
        plt.show()
        plt.close()  # release the figure so open figures do not accumulate

    # Save final summary table
    results_df = pd.DataFrame(overall_results, columns=[
        'File', 'Model', 'OptimalThreshold',
        'Sensitivity', 'Specificity', 'PPV', 'NPV',
        'ROC-AUC', 'PR-AUC', 'Youden'
    ])
    csv_path = os.path.join(output_dir, f"final_results_{base_file_name}.csv")
    results_df.to_csv(csv_path, index=False, encoding='utf-8')
    print(f"   -> Final results saved to: {csv_path}")

print("\nAll processes completed successfully.")