In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the full path of every file below the Kaggle input directory.
# os.walk yields nothing for a missing directory, so this prints nothing off Kaggle.
for root_dir, _, file_names in os.walk('/kaggle/input'):
    for file_name in file_names:
        print(os.path.join(root_dir, file_name))
!pip install torch-geometric ucimlrepo imbalanced-learn
# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data, Batch
from torch_geometric.loader import DataLoader
from sklearn.model_selection import StratifiedKFold
from ucimlrepo import fetch_ucirepo
from sklearn.metrics import roc_auc_score,roc_curve, precision_recall_curve,average_precision_score, balanced_accuracy_score, recall_score, f1_score, confusion_matrix, accuracy_score
from sklearn.preprocessing import StandardScaler
from imblearn.over_sampling import SMOTE
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch_geometric.nn import TransformerConv
from torch_geometric.utils import to_undirected, coalesce, remove_self_loops
from sklearn.metrics.pairwise import cosine_similarity
import warnings
import os
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import seaborn as sns
import json
from scipy.stats import linregress
import copy
from sklearn.metrics import roc_auc_score, average_precision_score, balanced_accuracy_score, recall_score, f1_score, confusion_matrix, accuracy_score

In [None]:
class TemporalFusion3(nn.Module):
    """
    [Model Code Withheld for Peer Review]

    Stand-in for the TemporalFusion3 model. The real implementation is withheld
    during double-blind peer review and will be published once the associated
    paper is accepted.

    The stand-in exists only so the surrounding training/evaluation code runs:
    it has the same constructor signature (every extra model hyper-parameter is
    accepted through ``**kwargs`` and ignored) and produces logits of the
    expected shape. It is a single linear layer, so nothing about its behaviour
    or scores reflects the real model.

    Args:
        num_total_features: number of columns in ``data.x`` (features per patient).
        num_classes: number of logits produced per patient.
        **kwargs: hyper-parameters of the real model; accepted and discarded.
    """

    def __init__(self, num_total_features, num_classes, **kwargs):
        super().__init__()
        # Keep this attribute name stable: it determines the state_dict keys that
        # end up in the saved checkpoints.
        self.placeholder_fc = nn.Linear(num_total_features, num_classes)
        print("NOTE: Using a placeholder for TemporalFusion3. Model code is withheld for peer review.")

    def forward(self, data):
        """Return one row of ``num_classes`` logits per row of ``data.x``.

        The withheld model builds a kNN graph from selected columns of ``data.x``
        and applies temporal fusion and graph-transformer layers; this stand-in
        skips all graph work and applies only the linear layer.
        """
        node_features = data.x
        return self.placeholder_fc(node_features)

In [None]:
# --- Define device ---
# Run on the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

# --- Loss Function ---
class FocalLoss(nn.Module):
    """Binary focal loss on logits, optionally combined with a BCE ``pos_weight``.

    Per element: ``alpha_t * (1 - p_t) ** gamma * BCE(inputs, targets)``, where
    ``p_t`` is the predicted probability of the true class and ``alpha_t`` is
    ``alpha`` for targets equal to 1 and ``1 - alpha`` otherwise.

    Args:
        alpha: weight for positive targets (negatives get ``1 - alpha``).
        gamma: focusing exponent; 0 disables the focal term.
        reduction: 'mean' or 'sum'; any other value returns the unreduced tensor.
        pos_weight: optional per-class weight forwarded to
            ``binary_cross_entropy_with_logits``; moved to the inputs' device on use.
    """

    def __init__(self, alpha=0.5, gamma=3.0, reduction='mean', pos_weight=None):
        super(FocalLoss, self).__init__()
        self.alpha, self.gamma = alpha, gamma
        self.reduction = reduction
        self.pos_weight = pos_weight

    def forward(self, inputs, targets):
        weight = self.pos_weight
        # Follow the logits' device so the stored tensor can live anywhere.
        if weight is not None and weight.device != inputs.device:
            weight = weight.to(inputs.device)

        bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none', pos_weight=weight)
        p = torch.sigmoid(inputs)
        is_positive = targets == 1
        p_true = torch.where(is_positive, p, 1 - p)
        alpha_true = torch.where(is_positive, self.alpha, 1 - self.alpha)
        per_element = alpha_true * torch.pow(1 - p_true, self.gamma) * bce

        reducers = {'mean': torch.mean, 'sum': torch.sum}
        reduce_fn = reducers.get(self.reduction)
        return per_element if reduce_fn is None else reduce_fn(per_element)

# --- Data Preprocessing ---
# Fetch UCI ML Repository dataset id=579 through ucimlrepo and take private copies
# of the feature and target frames, so the column additions and modifications made
# further down in this notebook do not touch the frames held by the fetched object.
myocardial_infarction_complications = fetch_ucirepo(id=579)
X_df_orig = myocardial_infarction_complications.data.features.copy()
y_df_orig = myocardial_infarction_complications.data.targets.copy()
# Target column names, in the dataset's column order.
target_names = list(y_df_orig.columns)

# Maps each selected feature column to the stage it is assigned to: 'admission',
# or one of 'day1' / 'day2' / 'day3' for the three repeated per-day measurements.
# Only columns listed here (and present in X_df_orig) end up in the per-stage
# feature lists built from this table; other dataset columns are dropped when
# X_df_orig is re-indexed to the ordered feature list.
# NOTE(review): per-day names must keep the '<base>_<day>_n' pattern; the derived
# feature step rebuilds them from the bases 'R_AB', 'NA_R' and 'NOT_NA'.
feature_collection_time = {
    'AGE': 'admission', 'SEX': 'admission', 'INF_ANAM': 'admission', 'STENOK_AN': 'admission',
    'FK_STENOK': 'admission', 'IBS_POST': 'admission', 'IBS_NASL': 'admission', 'GB': 'admission',
    'SIM_GIPERT': 'admission', 'DLIT_AG': 'admission', 'ZSN_A': 'admission', 'nr_11': 'admission',
    'nr_01': 'admission', 'nr_02': 'admission', 'nr_03': 'admission', 'nr_04': 'admission',
    'nr_07': 'admission', 'nr_08': 'admission', 'np_01': 'admission', 'np_04': 'admission',
    'np_05': 'admission', 'np_07': 'admission', 'np_08': 'admission', 'np_09': 'admission',
    'np_10': 'admission', 'endocr_01': 'admission', 'endocr_02': 'admission', 'endocr_03': 'admission',
    'zab_leg_01': 'admission', 'zab_leg_02': 'admission', 'zab_leg_03': 'admission',
    'zab_leg_04': 'admission', 'zab_leg_06': 'admission', 'S_AD_KBRIG': 'admission',
    'D_AD_KBRIG': 'admission', 'S_AD_ORIT': 'admission', 'D_AD_ORIT': 'admission',
    'O_L_POST': 'admission', 'K_SH_POST': 'admission', 'MP_TP_POST': 'admission',
    'SVT_POST': 'admission', 'GT_POST': 'admission', 'FIB_G_POST': 'admission',
    'ant_im': 'admission', 'lat_im': 'admission', 'inf_im': 'admission', 'post_im': 'admission',
    'IM_PG_P': 'admission', 'ritm_ecg_p_01': 'admission', 'ritm_ecg_p_02': 'admission',
    'ritm_ecg_p_04': 'admission', 'ritm_ecg_p_06': 'admission', 'ritm_ecg_p_07': 'admission',
    'ritm_ecg_p_08': 'admission', 'n_r_ecg_p_01': 'admission', 'n_r_ecg_p_02': 'admission',
    'n_r_ecg_p_03': 'admission', 'n_r_ecg_p_04': 'admission', 'n_r_ecg_p_05': 'admission',
    'n_r_ecg_p_06': 'admission', 'n_r_ecg_p_08': 'admission', 'n_r_ecg_p_09': 'admission',
    'n_r_ecg_p_10': 'admission', 'n_p_ecg_p_01': 'admission', 'n_p_ecg_p_03': 'admission',
    'n_p_ecg_p_04': 'admission', 'n_p_ecg_p_05': 'admission', 'n_p_ecg_p_06': 'admission',
    'n_p_ecg_p_07': 'admission', 'n_p_ecg_p_08': 'admission', 'n_p_ecg_p_09': 'admission',
    'n_p_ecg_p_10': 'admission', 'n_p_ecg_p_11': 'admission', 'n_p_ecg_p_12': 'admission',
    'fibr_ter_01': 'admission', 'fibr_ter_02': 'admission', 'fibr_ter_03': 'admission',
    'fibr_ter_05': 'admission', 'fibr_ter_06': 'admission', 'fibr_ter_07': 'admission',
    'fibr_ter_08': 'admission', 'GIPO_K': 'admission', 'K_BLOOD': 'admission',
    'GIPER_NA': 'admission', 'NA_BLOOD': 'admission', 'ALT_BLOOD': 'admission',
    'AST_BLOOD': 'admission', 'KFK_BLOOD': 'admission', 'L_BLOOD': 'admission',
    'ROE': 'admission', 'TIME_B_S': 'admission', 'NA_KB': 'admission', 'NOT_NA_KB': 'admission',
    'LID_KB': 'admission', 'NITR_S': 'admission', 'LID_S_n': 'admission',
    'B_BLOK_S_n': 'admission', 'ANT_CA_S_n': 'admission', 'GEPAR_S_n': 'admission',
    'ASP_S_n': 'admission', 'TIKL_S_n': 'admission', 'TRENT_S_n': 'admission',
    # Per-day repeated measurements (day 1, 2, 3 for each base name).
    'R_AB_1_n': 'day1', 'R_AB_2_n': 'day2', 'R_AB_3_n': 'day3',
    'NA_R_1_n': 'day1', 'NA_R_2_n': 'day2', 'NA_R_3_n': 'day3',
    'NOT_NA_1_n': 'day1', 'NOT_NA_2_n': 'day2', 'NOT_NA_3_n': 'day3',
}

# One list per collection stage, in feature_collection_time order, restricted to
# columns that really exist in X_df_orig.
admission_features_names = [name for name, stage in feature_collection_time.items() if name in X_df_orig.columns and stage == 'admission']
day1_features_names = [name for name, stage in feature_collection_time.items() if name in X_df_orig.columns and stage == 'day1']
day2_features_names = [name for name, stage in feature_collection_time.items() if name in X_df_orig.columns and stage == 'day2']
day3_features_names = [name for name, stage in feature_collection_time.items() if name in X_df_orig.columns and stage == 'day3']

# The per-step feature count handed to the model is taken from day 1 alone, so the
# three per-day lists are expected to have equal length.
if len({len(day1_features_names), len(day2_features_names), len(day3_features_names)}) != 1:
    print("Warning: Temporal features per day are not consistent. Ensure this is intended.")

# Generate derived temporal features
# For each base quantity measured on days 1, 2 and 3 ('<base>_1_n' .. '<base>_3_n'), add:
#   <base>_diff_2_1, <base>_diff_3_2 : day-over-day differences (NaN if either day is missing)
#   <base>_mean                     : mean over the three days, skipping missing days
#   <base>_trend                    : least-squares slope over days 1..3, 0.0 if any day is missing
# derived_features records the new column names in exactly this order.
temporal_base_features = ['R_AB', 'NA_R', 'NOT_NA']
derived_features = []
for base in temporal_base_features:
    day1_col, day2_col, day3_col = f'{base}_1_n', f'{base}_2_n', f'{base}_3_n'
    day_cols = [day1_col, day2_col, day3_col]

    X_df_orig[f'{base}_diff_2_1'] = X_df_orig[day2_col] - X_df_orig[day1_col]
    X_df_orig[f'{base}_diff_3_2'] = X_df_orig[day3_col] - X_df_orig[day2_col]
    derived_features.extend([f'{base}_diff_2_1', f'{base}_diff_3_2'])

    X_df_orig[f'{base}_mean'] = X_df_orig[day_cols].mean(axis=1)
    derived_features.append(f'{base}_mean')

    # For x = (1, 2, 3) the centred x values are (-1, 0, 1), so the least-squares
    # slope reduces to (y3 - y1) / 2. Computing it column-wise replaces a per-row
    # scipy.stats.linregress apply (and the closure over loop variables it needed).
    # Rows with any missing day get 0.0, matching the previous behaviour.
    complete_rows = X_df_orig[day_cols].notna().all(axis=1)
    X_df_orig[f'{base}_trend'] = ((X_df_orig[day3_col] - X_df_orig[day1_col]) / 2.0).where(complete_rows, 0.0)
    derived_features.append(f'{base}_trend')

# Column order of data.x: admission block, then the day-1/2/3 blocks, then derived features.
ordered_features_all = [
    *admission_features_names,
    *day1_features_names,
    *day2_features_names,
    *day3_features_names,
    *derived_features,
]
X_df_orig = X_df_orig[ordered_features_all]

# Sizes forwarded to the model constructor.
# NOTE(review): derived features sit after the day-3 block in ordered_features_all but
# are counted together with the admission features here; confirm this matches how the
# (withheld) model slices data.x.
num_time_steps_model = 3
num_temporal_features_per_step_model = len(day1_features_names)
num_admission_features_model = len(admission_features_names) + len(derived_features)

# Median-impute missing feature values column by column.
# The filled Series is assigned back: calling fillna(inplace=True) on X_df_orig[col]
# is chained assignment, which pandas 2.2 warns about and which silently does nothing
# under Copy-on-Write (the default from pandas 3.0).
# NOTE(review): medians are computed on the full dataset before the CV split, so
# validation rows influence the values imputed into training rows.
for col in ordered_features_all:
    if X_df_orig[col].isnull().any():
        X_df_orig[col] = X_df_orig[col].fillna(X_df_orig[col].median())

print("Processing target labels:")
# Position of the LET_IS target; 11 is a hard-coded fallback if the column is absent.
# NOTE(review): the fallback is not checked against num_binary_classes.
let_is_idx = target_names.index('LET_IS') if 'LET_IS' in target_names else 11
num_binary_classes = len(target_names)
y_binary = np.zeros((y_df_orig.shape[0], num_binary_classes), dtype=int)

for col_idx, col_name in enumerate(y_df_orig.columns):
    # Median-fill and cast, assigning the result back. A chained
    # y_df_orig[col_name].fillna(inplace=True) is a no-op under pandas Copy-on-Write,
    # and the surviving NaNs would then make astype(int) raise.
    y_df_orig[col_name] = y_df_orig[col_name].fillna(y_df_orig[col_name].median()).astype(int)
    # Any value > 0 counts as the positive class.
    y_binary[:, col_idx] = (y_df_orig[col_name] > 0).astype(int)
    original_values = np.unique(y_df_orig[col_name])
    binarized_values = np.unique(y_binary[:, col_idx])
    # Report only targets whose value set actually changed (e.g. multi-valued columns).
    if not np.array_equal(original_values, binarized_values):
        print(f"  Target '{col_name}': Original unique values {original_values} -> Binarized unique values {binarized_values}")

X_np_all = X_df_orig.values
y_np_all = y_binary

# Positions (within ordered_features_all, i.e. columns of data.x) of the features
# handed to the model constructor as indices_for_graph_construction.
graph_construction_feature_names = ['AGE', 'SEX', 'K_BLOOD']
indices_for_graph_construction_model = []
for graph_feature in graph_construction_feature_names:
    if graph_feature in ordered_features_all:
        indices_for_graph_construction_model.append(ordered_features_all.index(graph_feature))
if len(indices_for_graph_construction_model) == 0:
    print("Warning: No graph construction features found by name. Using first 3 features as fallback.")
    indices_for_graph_construction_model = list(range(min(3, X_np_all.shape[1])))
print(f"Indices used for graph construction (in data.x): {indices_for_graph_construction_model}")

def create_patient_nodes(X_np_fold_scaled, y_np_fold_labels):
    """Wrap every patient (row) as its own single-node ``Data`` object.

    Args:
        X_np_fold_scaled: 2-D array of features, one row per patient.
        y_np_fold_labels: 2-D array of labels aligned row-for-row with the features.

    Returns:
        A list with one ``Data`` per row, where ``x`` is that row's features as a
        float tensor of shape [1, n_features] and ``y`` is its labels as a float
        tensor of shape [1, n_labels].
    """
    nodes = []
    for row in range(X_np_fold_scaled.shape[0]):
        features = torch.tensor(X_np_fold_scaled[row, :], dtype=torch.float)
        labels = torch.tensor(y_np_fold_labels[row, :], dtype=torch.float)
        nodes.append(Data(x=features.unsqueeze(0), y=labels.unsqueeze(0)))
    return nodes

In [None]:
# Backbones to cross-validate: display name -> model class plus the hyper-parameters
# forwarded to its constructor as **model_params. The placeholder TemporalFusion3
# accepts these through **kwargs and ignores them.
backbone_configs = {
    'TemporalFusion3_v1': dict(
        model=TemporalFusion3,
        params=dict(
            num_encoder_layers=3,
            hidden_channels_per_head=16,
            heads_per_layer=4,
            dropout_rate=0.3,
            short_term_cnn_out_channels=64,
            short_term_cnn_layers=4,
            long_term_gru_hidden_dim=16,
            long_term_gru_input_dim=12,
            long_term_gru_num_layers=1,
            temporal_embedding_dim=32,
            k_min=5,
            k_max=15,
            sim_threshold=0.5,
        ),
    ),
}

In [None]:
import shutil

# --- Training Parameters ---
k_folds = 5
training_epochs = 300
early_stopping_patience = 50
learning_rate = 1e-3
weight_decay = 2e-3
batch_size = 64
random_seed = 42  # used for the fold split and (offset by fold index) for torch seeding

# --- Define Fold Splits (fixed for reproducibility) ---
# Stratify on "has at least one positive target" so every fold keeps a similar share
# of patients with no positive label at all.
stratify_on_labels = (y_np_all.sum(axis=1) > 0).astype(int)
skf = StratifiedKFold(n_splits=k_folds, shuffle=True, random_state=random_seed)
global_fold_indices = list(skf.split(X_np_all, stratify_on_labels))


def train_epoch(loader, model, crit, opt):
    """Run one optimisation pass over ``loader`` and return the mean loss per sample.

    Gradients are clipped to a global norm of 1.0 before every optimiser step.
    The average is taken over the samples actually processed, so batches skipped
    by ``drop_last=True`` do not deflate it. Returns 0.0 when no batch is yielded.
    """
    model.train()
    total_loss, samples_seen = 0.0, 0
    for data in loader:
        data = data.to(device)
        opt.zero_grad()
        out = model(data)
        loss = crit(out, data.y)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()
        # crit returns a batch mean; scale back up to a per-batch sum.
        total_loss += loss.item() * data.num_graphs
        samples_seen += data.num_graphs
    return total_loss / samples_seen if samples_seen else 0.0


def test_epoch(loader, model, crit):
    """Return ``(probabilities, labels)`` as NumPy arrays concatenated over ``loader``.

    Probabilities are the sigmoid of the model's logits. ``crit`` is accepted for
    symmetry with ``train_epoch`` and is not used.
    """
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            all_preds.append(torch.sigmoid(model(data)).cpu().numpy())
            all_labels.append(data.y.cpu().numpy())
    return np.concatenate(all_preds), np.concatenate(all_labels)


# --- Main Training ---
# NOTE(review): calculate_fold_metrics and plot_kfold_training_curves are called below,
# but no earlier cell of this notebook defines them; confirm they are in scope before
# running this cell.
for backbone_name, config in backbone_configs.items():
    print(f"\n--- Starting {k_folds}-Fold CV for Backbone: {backbone_name} ---")
    model_class = config['model']
    model_params = config['params']

    MODEL_NAME_FOR_PATHS = f"TemporalGraphTransformer_Binary_{backbone_name}"
    output_dir = f'/kaggle/working/{MODEL_NAME_FOR_PATHS}'
    checkpoints_dir = os.path.join(output_dir, 'checkpoints')
    os.makedirs(checkpoints_dir, exist_ok=True)

    all_folds_metrics = []  # one entry per fold: validation metrics at that fold's best epoch
    all_folds_train_logs = []
    all_folds_val_logs = []
    max_epochs_across_folds = 0

    for fold_idx, (train_idx, val_idx) in enumerate(global_fold_indices):
        print(f"\n--- Backbone: {backbone_name}, Fold {fold_idx + 1}/{k_folds} ---")
        # Seed per fold so weight initialisation and mini-batch shuffling are reproducible.
        torch.manual_seed(random_seed + fold_idx)

        X_train, X_val = X_np_all[train_idx], X_np_all[val_idx]
        y_train, y_val = y_np_all[train_idx], y_np_all[val_idx]

        # Fit the scaler on the training fold only; any NaN left after scaling becomes 0.
        scaler = StandardScaler().fit(X_train)
        X_train_scaled = np.nan_to_num(scaler.transform(X_train), nan=0.0)
        X_val_scaled = np.nan_to_num(scaler.transform(X_val), nan=0.0)

        train_nodes = create_patient_nodes(X_train_scaled, y_train)
        val_nodes = create_patient_nodes(X_val_scaled, y_val)

        train_loader = DataLoader(train_nodes, batch_size=batch_size, shuffle=True, drop_last=True)
        # The whole validation fold is scored as a single batch.
        val_loader = DataLoader(val_nodes, batch_size=len(val_nodes) if val_nodes else 1, shuffle=False)

        model = model_class(
            num_total_features=X_train_scaled.shape[1],
            num_admission_features=num_admission_features_model,
            num_temporal_features_per_step=num_temporal_features_per_step_model,
            num_time_steps=num_time_steps_model,
            indices_for_graph_construction=indices_for_graph_construction_model,
            num_classes=num_binary_classes,
            **model_params
        ).to(device)

        optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
        # Per-class negatives/positives ratio on the training fold; 1e-6 avoids division by zero.
        pos_weight = torch.tensor( (len(y_train) - y_train.sum(axis=0)) / (y_train.sum(axis=0) + 1e-6), dtype=torch.float).to(device)
        criterion = FocalLoss(alpha=0.5, gamma=2.0, pos_weight=pos_weight)
        # mode='max': the scheduler is stepped with validation AUC (higher is better).
        scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=15)

        best_val_auc = -1.0
        best_val_metrics = None
        patience_counter = 0
        best_model_state = None
        fold_train_log, fold_val_log = [], []

        for epoch in range(1, training_epochs + 1):
            train_loss = train_epoch(train_loader, model, criterion, optimizer)
            val_preds_prob, val_labels = test_epoch(val_loader, model, criterion)
            val_preds_binary = (val_preds_prob > 0.5).astype(int)

            val_metrics = calculate_fold_metrics(val_labels, val_preds_prob, val_preds_binary, num_binary_classes, target_names, let_is_idx)
            current_val_auc = val_metrics.get('mean_roc_auc', -1.0)

            fold_train_log.append({'avg_loss': train_loss})
            fold_val_log.append(val_metrics)

            if epoch % 10 == 0:
                print(f'Fold {fold_idx+1}, Epoch {epoch:03d}: Train Loss: {train_loss:.4f}, Val AUC: {current_val_auc:.4f}, Val Bal. Acc: {val_metrics.get("mean_balanced_acc", -1):.4f}')

            scheduler.step(current_val_auc)

            if current_val_auc > best_val_auc:
                best_val_auc = current_val_auc
                best_val_metrics = val_metrics
                patience_counter = 0
                best_model_state = copy.deepcopy(model.state_dict())
            else:
                patience_counter += 1

            if patience_counter >= early_stopping_patience:
                print(f'Early stopping at epoch {epoch}. Best Val AUC: {best_val_auc:.4f}')
                break

        max_epochs_across_folds = max(max_epochs_across_folds, epoch)
        all_folds_train_logs.append(fold_train_log)
        all_folds_val_logs.append(fold_val_log)
        # Record the fold's best-epoch metrics exactly once. Appending on every
        # improving epoch would mix early, weaker epochs into the cross-fold average
        # and over-weight folds that happened to improve more often.
        if best_val_metrics is not None:
            all_folds_metrics.append(best_val_metrics)

        if best_model_state is not None:
            checkpoint_path = os.path.join(checkpoints_dir, f'fold_{fold_idx}_best_model.pth')
            torch.save(best_model_state, checkpoint_path)
            print(f"Fold {fold_idx+1} best model saved. AUC: {best_val_auc:.4f}")

    # --- Post-Training Summary ---
    print(f"\n--- Averaged K-Fold Performance for {backbone_name} ---")
    # NOTE(review): each fold's entry is taken at its best validation-AUC epoch, so this
    # average is an optimistic (selection-biased) estimate of generalisation.
    if all_folds_metrics:
        numeric_keys = [key for key, value in all_folds_metrics[0].items() if isinstance(value, (int, float))]
        avg_metrics = {key: np.nanmean([m.get(key, np.nan) for m in all_folds_metrics]) for key in numeric_keys}
        for key, value in avg_metrics.items():
            print(f"  Avg {key.replace('_', ' ').title()}: {value:.4f}")
    else:
        avg_metrics = {}
        print("  No fold improved on the initial validation AUC; no metrics to average.")

    # --- Save Best Models to Final Directory ---
    # NOTE(review): this directory is shared by all backbones, so with more than one
    # entry in backbone_configs a later backbone overwrites the earlier fold_{i}.pth files.
    BEST_MODELS_FINAL_DIR = '/kaggle/working/best_models'
    os.makedirs(BEST_MODELS_FINAL_DIR, exist_ok=True)
    for i in range(k_folds):
        src_path = os.path.join(checkpoints_dir, f'fold_{i}_best_model.pth')
        dest_path = os.path.join(BEST_MODELS_FINAL_DIR, f'fold_{i}.pth')
        if os.path.exists(src_path):
            shutil.copy(src_path, dest_path)
    print(f"\nBest models from each fold copied to {BEST_MODELS_FINAL_DIR}")

    # --- Plotting Training Curves ---
    if all_folds_train_logs:
        plot_kfold_training_curves(MODEL_NAME_FOR_PATHS, all_folds_train_logs, all_folds_val_logs, k_folds, max_epochs_across_folds)
        print(f"Training curves saved in /kaggle/working/{MODEL_NAME_FOR_PATHS}/")