# Model Training on TCGA + CGGA Data for shipping in pyGSLModel

### Importing Libraries

In [17]:
import sys
from pathlib import Path

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt

import json
import joblib

import torch
import torch.nn as nn

from scipy.special import expit

from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.model_selection import StratifiedKFold

from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.feature_selection import VarianceThreshold
from sklearn.base import clone

from xgboost import XGBClassifier

from skopt import BayesSearchCV
from skopt.space import Real, Integer

from skorch import NeuralNetClassifier
from skorch.callbacks import EarlyStopping

from imblearn.over_sampling import SMOTE
from imblearn.pipeline import Pipeline

from lifelines import CoxPHFitter
from lifelines.statistics import logrank_test

### Setting up the ANN class

In [18]:
#######################################
### Defining Helper Classes for ANN ###
#######################################

# Focal loss: down-weights easy samples so training concentrates on the hard ones
class FocalLoss(nn.Module):
    """Binary focal loss computed from raw logits.

    Per sample: ``alpha * (1 - p_t) ** gamma * BCE`` where ``BCE`` is the
    binary cross-entropy with logits and ``p_t = exp(-BCE)``.

    Note that ``alpha`` multiplies every sample by the same factor; it is not
    applied per class.

    Args:
        alpha: Constant scale applied to each per-sample loss.
        gamma: Focusing exponent applied to ``(1 - p_t)``.
        reduction: ``'mean'`` or ``'sum'`` to reduce over all elements; any
            other value returns the unreduced per-sample tensor.
    """

    def __init__(self, alpha=0.25, gamma=5, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, logits, targets):
        """Return the focal loss of ``logits`` against ``targets``.

        ``targets`` is reshaped to a column and cast to the dtype of
        ``logits`` before the loss is evaluated.
        """
        target_column = targets.view(-1, 1).type_as(logits)
        per_sample_bce = nn.functional.binary_cross_entropy_with_logits(
            logits, target_column, reduction='none'
        )
        # exp(-BCE) recovers the probability assigned to the true class
        prob_true_class = torch.exp(-per_sample_bce)
        per_sample_focal = self.alpha * (1 - prob_true_class) ** self.gamma * per_sample_bce

        if self.reduction == 'mean':
            return per_sample_focal.mean()
        if self.reduction == 'sum':
            return per_sample_focal.sum()
        return per_sample_focal
        
# Class for the ANN model (binary classification)
class DeepBinary(nn.Module):
    """Fully connected network that outputs one logit per sample.

    The network is ``num_layers`` blocks of Linear -> LayerNorm -> ReLU ->
    Dropout followed by a final ``Linear(hidden_dim, 1)``. The first block
    uses ``nn.LazyLinear`` so the input width does not have to be given here.

    Args:
        hidden_dim: Width of every hidden block.
        num_layers: Number of hidden blocks; values below 1 still produce one
            block because the first block is always added.
        dropout_rate: Dropout probability used in every block.

    Attributes:
        hidden_dim, num_layers, dropout_rate: The constructor arguments,
            stored as plain ``int`` / ``float`` so that code inspecting the
            module can read them and serialise them to JSON even when the
            values were passed in as NumPy scalars.
        model: The assembled ``nn.Sequential``.
    """

    def __init__(self, hidden_dim=64, num_layers=4, dropout_rate=0.25):
        super().__init__()
        # Record the architecture on the module; without this the values are
        # unrecoverable from a constructed instance.
        self.hidden_dim = int(hidden_dim)
        self.num_layers = int(num_layers)
        self.dropout_rate = float(dropout_rate)

        layers = []
        layers.append(nn.LazyLinear(self.hidden_dim))
        layers.append(nn.LayerNorm(self.hidden_dim))
        layers.append(nn.ReLU())
        layers.append(nn.Dropout(self.dropout_rate))

        for _ in range(self.num_layers - 1):
            layers.append(nn.Linear(self.hidden_dim, self.hidden_dim))
            layers.append(nn.LayerNorm(self.hidden_dim))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(self.dropout_rate))

        layers.append(nn.Linear(self.hidden_dim, 1))  # final logit
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the sequential stack and return the logits."""
        return self.model(x)

# skorch classifier wrapper whose predict_proba turns single logits into two-column probabilities
class NeuralNetBinaryClassifier(NeuralNetClassifier):
    """``NeuralNetClassifier`` variant for modules that emit raw logits."""

    def predict_proba(self, X):
        """Return class probabilities with the negative class first.

        The logits from ``self.forward(X)`` are passed through the logistic
        sigmoid to give the positive-class probability; its complement is
        stacked in front of it.
        Assumes the module returns one logit per row (shape ``(n, 1)``), as
        ``DeepBinary`` does, so the result has two columns.
        """
        raw_logits = self.forward(X)
        positive = expit(raw_logits.detach().cpu().numpy())
        negative = 1 - positive
        return np.hstack((negative, positive))

############################################
### Defining Deep Learning Skorch Set up ###
############################################

# Base (untuned) ANN estimator: DeepBinary trained with FocalLoss and Adam.
# Runs on the GPU when one is visible to torch, otherwise on the CPU.
device = 'cpu' if not torch.cuda.is_available() else 'cuda'
net = NeuralNetBinaryClassifier(
    # network and loss
    module=DeepBinary,
    criterion=FocalLoss,
    criterion__alpha=0.25,
    criterion__gamma=2.0,
    # optimisation
    optimizer=torch.optim.Adam,
    optimizer__weight_decay=1e-4,
    lr=1e-3,
    batch_size=128,
    max_epochs=500,
    # runtime
    device=device,
    verbose=0,
    # Early stopping watches the validation loss.
    # NOTE(review): load_best=False presumably keeps the last-epoch weights
    # rather than restoring the best epoch; confirm this is intended.
    callbacks=[
        EarlyStopping(
            monitor='valid_loss',
            lower_is_better=True,
            threshold=0.01,
            patience=100,
            load_best=False,
        ),
    ],
)

### Defining the hyperparameter search space

In [19]:
# Hyperparameter ranges searched for the ANN.
# Keys are skorch parameter names: "module__*" reaches DeepBinary, "criterion__*" reaches FocalLoss.
deep_search_space = dict(
    lr=Real(1e-6, 1e-2, prior="log-uniform"),
    module__hidden_dim=Integer(4, 256),
    module__num_layers=Integer(1, 4),
    module__dropout_rate=Real(0.0, 0.5),
    criterion__alpha=Real(0.0, 0.5),
    criterion__gamma=Real(1.0, 7.0),
)

############################################
### Defining non ANN model search spaces ###
############################################

# Maps each classical model name to (base estimator, hyperparameter search space).
search_spaces = dict(
    # RBF-kernel SVM with probability estimates enabled
    SVM=(
        SVC(kernel='rbf', gamma='scale', probability=True, class_weight="balanced", random_state=42),
        dict(C=Real(1e-3, 1.0, prior="log-uniform")),
    ),
    # Random forest with balanced class weights
    RandomForest=(
        RandomForestClassifier(class_weight="balanced", random_state=42),
        dict(
            n_estimators=Integer(50, 500),
            max_depth=Integer(2, 20),
            min_samples_split=Integer(2, 20),
            min_samples_leaf=Integer(1, 10),
        ),
    ),
    # Gradient-boosted trees evaluated on log-loss
    XGBoost=(
        XGBClassifier(eval_metric="logloss", random_state=42),
        dict(
            n_estimators=Integer(50, 500),
            max_depth=Integer(2, 20),
            learning_rate=Real(1e-3, 0.1, prior="log-uniform"),
            subsample=Real(0.5, 1.0),
            colsample_bytree=Real(0.5, 1.0),
            scale_pos_weight=Real(1.0, 10.0),
        ),
    ),
    # Logistic regression fitted with the lbfgs solver
    LogisticRegression=(
        LogisticRegression(solver="lbfgs", max_iter=5000, class_weight="balanced"),
        dict(C=Real(1e-3, 1.0, prior="log-uniform")),
    ),
)

### Setting up utility functions

In [20]:
##################################
### Defining Utility Functions ###
##################################

# Tune a high/low risk threshold by maximising the log-rank chi-square statistic
def tune_threshold_by_logrank(
    probs_train: np.ndarray,
    time_train: np.ndarray,
    event_train: np.ndarray,
) -> tuple[float, float]:
    """Pick the score threshold that best separates survival on the given data.

    Candidate thresholds are the unique values of 41 evenly spaced quantiles
    of ``probs_train`` between the 30th and 70th percentile. For each
    candidate the samples are split into ``score >= threshold`` and
    ``score < threshold`` and compared with ``logrank_test``; the candidate
    with the largest finite test statistic wins.

    Args:
        probs_train: Per-sample scores (flattened to 1-D).
        time_train: Per-sample follow-up times (flattened to 1-D).
        event_train: Per-sample event indicators, interpreted as booleans.

    Returns:
        ``(threshold, statistic)``. When no candidate yields a usable split,
        the median score and ``0.0`` are returned.

    Raises:
        ValueError: If the inputs are empty or do not all have the same
            number of elements.
    """
    probs_train = np.asarray(probs_train, float).ravel()
    time_train  = np.asarray(time_train,  float).ravel()
    event_train = np.asarray(event_train, bool).ravel()

    # Validate up front: a length mismatch would otherwise surface as an
    # indexing error inside the best-effort try block below and be silently
    # converted into the median fallback.
    if probs_train.size == 0:
        raise ValueError("tune_threshold_by_logrank received no samples")
    if not (probs_train.size == time_train.size == event_train.size):
        raise ValueError(
            "probs_train, time_train and event_train must have the same length "
            f"(got {probs_train.size}, {time_train.size}, {event_train.size})"
        )

    # Candidate thresholds from fixed quantiles
    qs = np.linspace(0.3, 0.7, 41)
    cands = np.unique(np.quantile(probs_train, qs))
    median_thr = float(np.median(probs_train))
    best_thr  = None
    best_stat = -np.inf

    for thr in cands:
        hi = probs_train >= thr
        lo = ~hi
        # Both groups must be non-empty for a two-group comparison
        if not hi.any() or not lo.any():
            continue
        try:
            lr = logrank_test(
                time_train[hi], time_train[lo],
                event_observed_A=event_train[hi],
                event_observed_B=event_train[lo],
            )
            chi2 = float(lr.test_statistic)
        except Exception:
            # Best effort: a candidate the test cannot evaluate is skipped
            continue
        if np.isfinite(chi2) and chi2 > best_stat:
            best_stat = chi2
            best_thr  = float(thr)

    if best_thr is None:
        # fallback: median threshold, 0 separation
        return median_thr, 0.0

    return best_thr, best_stat

### Loading and preparing the TCGA + CGGA data

In [21]:
# Load the TCGA cohort from the local CSV: drop the sample identifier and any incomplete rows
TCGA_URL = "./iMAT_integrated_data/TCGA_iMAT_integrated_df_21.csv"
TCGA = pd.read_csv(TCGA_URL).drop(columns=["sample"]).dropna()

# Load the CGGA cohort the same way (its identifier column is "CGGA_ID")
CGGA_URL = "./CGGA_Data/CGGA_Tidied_Integrated.csv"
CGGA = pd.read_csv(CGGA_URL).drop(columns=["CGGA_ID"]).dropna()

# Report whether both cohorts expose the same columns in the same order
print(f"TCGA Columns == CGGA columns: {list(TCGA.columns) == list(CGGA.columns)}")

# Stack the two cohorts row-wise into a single training frame
df_LGG = pd.concat([TCGA, CGGA])

# Survival columns are kept aside; everything else is the feature matrix
LGG_OS = df_LGG[["OS", "OS.time"]]
X_LGG = df_LGG.drop(columns=["OS", "OS.time"])

# Encode the OS column as integer class labels
le = LabelEncoder().fit(df_LGG["OS"])
y_LGG = le.transform(df_LGG["OS"])

# Event indicator and follow-up time as plain NumPy arrays for the survival analysis
LGG_event = LGG_OS["OS"].to_numpy().astype(bool)
LGG_time = LGG_OS["OS.time"].to_numpy().astype(float)

TCGA Columns == CGGA columns: True


In [22]:
# Mean of the encoded labels, i.e. the share of samples in class 1 (assuming OS is binary)
y_LGG.sum() / len(y_LGG)

0.3475026567481403

### Defining the Training Function

In [23]:
def train_evaluate_model(random_state=42, inner_folds=3, inner_iterations=25, ANN_iterations=25, savepath=None):
    """Tune, fit and save every base classifier plus the Cox-weighted ensemble.

    For each entry of the module-level ``search_spaces`` and for the ANN
    (``net`` / ``deep_search_space``) the function:

    1. wraps the estimator in a VarianceThreshold -> StandardScaler -> SMOTE
       pipeline and tunes it with ``BayesSearchCV`` (average precision,
       stratified K-fold) on the module-level ``X_LGG`` / ``y_LGG``;
    2. predicts positive-class probabilities for that same training data with
       the refitted best pipeline;
    3. tunes a per-model threshold with ``tune_threshold_by_logrank``.

    The per-model probabilities are then standardised and used as covariates
    of a penalised Cox model (``penalizer=0.05``, ``l1_ratio=0.0``); its
    coefficients weight the models into an ensemble score, for which a
    threshold is tuned as well.

    Everything is written below ``./Complete_Model``: a text log, the fitted
    pipelines, the ANN weights and config, the thresholds, the ensemble scaler
    and coefficients, and a metadata file.

    Args:
        random_state: Seed passed to SMOTE, the CV splitter and the Bayesian
            search; also recorded in the saved JSON files.
        inner_folds: Number of stratified CV folds used during tuning.
        inner_iterations: Search iterations for the non-ANN models.
        ANN_iterations: Search iterations for the ANN.
        savepath: Destination of the ANN loss-curve figure. ``None`` skips
            the figure.

    Returns:
        None.
    """
    import importlib.metadata  # stdlib; used only to record the installed skorch version

    # Create the output tree first: the log (and usually the figure) is written
    # into it long before the models are saved.
    save_root = Path("./Complete_Model")
    (save_root / "models").mkdir(parents=True, exist_ok=True)
    (save_root / "thresholds").mkdir(exist_ok=True)
    (save_root / "ensemble").mkdir(exist_ok=True)
    log_path = save_root / "training_log.txt"

    def log(message):
        """Append one message to the training log."""
        with open(log_path, "a") as file:
            print(message, file=file)

    def dump_json(payload, path):
        """Write ``payload`` as indented JSON and close the file afterwards."""
        with open(path, "w") as file:
            json.dump(payload, file, indent=4)

    # Preparing storage dictionaries
    models_info = {name: {'estimator': m, 'space': s} for name, (m, s) in search_spaces.items()}
    models_info['ANN'] = {'estimator': net, 'space': deep_search_space}
    probs_train_store = {}
    thr_values = {}

    for model_name, info in models_info.items():
        is_ann = model_name == 'ANN'
        log(f'\nTuning and Fitting: {model_name}')

        # Defining the pipeline for the model
        pipe = Pipeline([
            ('low_var', VarianceThreshold()),
            ('scaler', StandardScaler()),
            ('smote', SMOTE(random_state=random_state, sampling_strategy='auto')),
            ('clf', clone(info['estimator']))
        ])

        # The classifier is the 'clf' step, so its parameters need that prefix
        space_prefixed = {f'clf__{k}': v for k, v in info['space'].items()}

        opt = BayesSearchCV(
            estimator=pipe,
            search_spaces=space_prefixed,
            n_iter=ANN_iterations if is_ann else inner_iterations,
            scoring='average_precision',
            cv=StratifiedKFold(n_splits=inner_folds, shuffle=True, random_state=random_state),
            random_state=random_state,
            # The ANN search runs in a single process; the other models use all cores
            n_jobs=1 if is_ann else -1,
            refit=True,
        )

        # The ANN is fed float32 features; the other models get the frame as is
        fit_X = X_LGG.astype(np.float32) if is_ann else X_LGG
        opt.fit(fit_X, y_LGG)

        # refit=True: best_estimator_ is the best pipeline refitted on all data
        best = opt.best_estimator_
        models_info[model_name]["trained_estimator"] = best

        # Plotting Loss Curves (ANN only, and only when a destination was given)
        if is_ann and savepath is not None:
            # skorch Net is inside the pipeline as step 'clf'
            history = best.named_steps['clf'].history_
            Path(savepath).parent.mkdir(parents=True, exist_ok=True)

            fig, ax = plt.subplots(figsize=(6, 4))
            ax.plot(history[:, 'train_loss'], label="Train Loss")
            ax.plot(history[:, 'valid_loss'], label="Valid Loss")
            ax.set_xlabel("Epoch")
            ax.set_ylabel("Loss")
            ax.set_title("ANN Loss Curve")
            ax.legend()
            ax.grid(alpha=0.3)
            fig.savefig(savepath)
            plt.close(fig)

        log(f"Best params for {model_name}: {opt.best_params_}")

        # Positive-class probabilities on the training data
        probs_train = best.predict_proba(fit_X)[:, 1].ravel()
        probs_train_store[model_name] = probs_train

        # Tuning threshold on training data
        thr, thr_best = tune_threshold_by_logrank(probs_train=probs_train, time_train=LGG_time, event_train=LGG_event)
        thr_values[model_name] = thr
        log(f"Tuned threshold for {model_name}: {thr:.2f} (Log-Rank Chi2={thr_best:.3f})")

    # Ensemble: standardise the per-model training probabilities (one column per model)
    model_list = ["SVM", "RandomForest", "XGBoost", "LogisticRegression", "ANN"]
    L_train = np.vstack([probs_train_store[m] for m in model_list]).T
    std_scaler = StandardScaler().fit(L_train)
    Z_train = std_scaler.transform(L_train)

    probs_train_df = pd.DataFrame(Z_train, columns=model_list)
    probs_train_df["OS"] = LGG_event
    probs_train_df["OS.time"] = LGG_time

    # Cox coefficients of the standardised probabilities act as ensemble weights
    cph = CoxPHFitter(penalizer=0.05, l1_ratio=0.0)
    cph.fit(probs_train_df, duration_col="OS.time", event_col="OS", robust=True)

    beta_vec = cph.params_[model_list].values
    eta_train = Z_train @ beta_vec

    # Threshold tuning on TRAIN
    thr_ens, thr_ens_best = tune_threshold_by_logrank(eta_train, LGG_time, LGG_event)
    thr_values["Ensemble"] = thr_ens
    log(f"Tuned threshold for Ensemble: {thr_ens:.2f} (Log-Rank Chi2={thr_ens_best:.3f})")

    #############################
    # 1. Save classical models
    #############################
    for model_name in model_list:  # SVM, RF, XGB, LR, ANN
        if model_name != "ANN":
            joblib.dump(models_info[model_name]["trained_estimator"], save_root / f"models/{model_name}.pkl")

    #############################
    # 2. Save ANN pipeline (.pkl)
    #############################
    ann_pipeline = models_info["ANN"]["trained_estimator"]
    joblib.dump(ann_pipeline, save_root / "models/ANN_pipeline.pkl")

    #############################
    # 3. Save ANN weights (.pt)
    #############################
    ann_clf = ann_pipeline.named_steps["clf"]
    ann_clf.save_params(f_params=str(save_root / "models/ANN_params.pt"))

    #############################
    # 4. Create HF-ready ANN config
    #############################
    ann_module = ann_clf.module_
    ann_params = ann_clf.get_params(deep=False)

    def arch_value(name, cast):
        """Read an architecture hyperparameter as a JSON-friendly number.

        Prefers an attribute on the module; otherwise falls back to the
        ``module__<name>`` parameter of the skorch net. Returns None when
        neither is available.
        """
        value = getattr(ann_module, name, None)
        if value is None:
            value = ann_params.get(f"module__{name}")
        # The search may hand over NumPy scalars, which json cannot encode
        return None if value is None else cast(value)

    try:
        skorch_version = importlib.metadata.version("skorch")
    except importlib.metadata.PackageNotFoundError:
        skorch_version = "unknown"

    ann_config = {
        "model_name": "DeepBinary",
        "architecture": {
            "hidden_dim": arch_value("hidden_dim", int),
            "num_layers": arch_value("num_layers", int),
            "dropout_rate": arch_value("dropout_rate", float),
            # NOTE(review): the module emits a single logit; 2 matches the two
            # predict_proba columns - confirm what consumers of this file expect.
            "output_dim": 2,
        },
        "training": {
            "optimizer": "Adam",
            "learning_rate": float(ann_clf.lr),
            "batch_size": ann_clf.batch_size,
            "max_epochs": ann_clf.max_epochs,
            "criterion": str(ann_clf.criterion),
        },
        "random_state": random_state,
        "versions": {
            "python": sys.version,
            "pytorch": torch.__version__,
            "skorch": skorch_version,
        }
    }
    dump_json(ann_config, save_root / "models/ANN_config.json")

    #############################
    # 5. Save per-model thresholds
    #############################
    dump_json(
        {m: float(thr_values[m]) for m in model_list},
        save_root / "thresholds/per_model_thresholds.json",
    )

    #############################
    # 6. Save ensemble components
    #############################
    joblib.dump(std_scaler, save_root / "ensemble/scaler.pkl")
    np.save(save_root / "ensemble/beta_vec.npy", beta_vec)
    dump_json({"ensemble_threshold": float(thr_ens)}, save_root / "ensemble/ens_threshold.json")

    #############################
    # 7. Save global metadata
    #############################
    metadata = {
        "random_state": random_state,
        "inner_folds": inner_folds,
        "inner_iterations": inner_iterations,
        "ann_iterations": ANN_iterations,
        "model_list": model_list,
        "python_version": sys.version,
    }
    dump_json(metadata, save_root / "metadata.json")

    print("\n All models and config files successfully saved")

In [24]:
# Full training run: 5-fold tuning with 50 search iterations for every model
train_evaluate_model(
    random_state=0,
    inner_folds=5,
    inner_iterations=50,
    ANN_iterations=50,
    savepath="./Complete_Model/ANN_Figures.png",
)


 All models and config files successfully saved
