In [None]:
!pip install torch_geometric

Collecting torch_geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/63.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m6.6 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.1 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m67.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch_geometric
Successfully installed torch_geometric-2.6.1


In [None]:
import pandas as pd
import numpy as np
import torch
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GATConv
from sklearn.metrics import (confusion_matrix, classification_report,
                            roc_auc_score, roc_curve, auc, accuracy_score)
from sklearn.utils.class_weight import compute_class_weight
import re
from sklearn.neighbors import NearestNeighbors
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torch.nn as nn
from sklearn.preprocessing import label_binarize
import glob
######################################
# Evaluation Metrics Functions
######################################
def calculate_metrics(y_true, y_pred, y_prob=None):
    """Calculate accuracy, per-class sensitivity/specificity and optional ROC AUC.

    Args:
        y_true: 1-D array-like of ground-truth class labels.
        y_pred: 1-D array-like of predicted class labels.
        y_prob: Optional 2-D array of class probabilities where column ``i``
            holds the probability of class ``i``. When it is given, labels
            must be the integer class indices ``0..K-1``.

    Returns:
        dict with ``'accuracy'``, ``'sensitivity_class_<i>'`` and
        ``'specificity_class_<i>'`` for every class and, when ``y_prob`` is
        given, either ``'auc'`` (two classes) or ``'auc_class_<i>'``,
        ``'auc_macro'`` and ``'auc_weighted'`` (more than two classes).
        AUC entries are only produced for classes that have both positive
        and negative samples in ``y_true``; the macro/weighted averages are
        taken over those classes.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    if y_prob is not None:
        y_prob = np.asarray(y_prob)

    metrics = {}

    # Overall accuracy
    metrics['accuracy'] = accuracy_score(y_true, y_pred)

    # Fix the label set explicitly so the confusion matrix rows/columns line
    # up with the class index used in the metric keys. Deriving the class
    # count from y_true alone mis-indexes the matrix whenever a class is
    # absent from y_true but present in y_pred / y_prob.
    if y_prob is not None and y_prob.ndim == 2:
        observed = np.concatenate([y_true.ravel(), y_pred.ravel()])
        max_label = int(observed.max()) + 1 if observed.size else 0
        labels = np.arange(max(max_label, y_prob.shape[1]))
    else:
        # Without probabilities, key the classes by their position in the
        # sorted union of labels (the same order confusion_matrix uses).
        labels = np.union1d(y_true, y_pred)
    n_classes = len(labels)
    if n_classes == 0:
        return metrics

    cm = confusion_matrix(y_true, y_pred, labels=labels)
    total = np.sum(cm)

    for i in range(n_classes):
        # True positives, false positives, false negatives, true negatives
        tp = cm[i, i]
        fp = np.sum(cm[:, i]) - tp
        fn = np.sum(cm[i, :]) - tp
        tn = total - tp - fp - fn

        # Sensitivity (recall) = TP / (TP + FN); 0 when the class has no samples
        metrics[f'sensitivity_class_{i}'] = tp / (tp + fn) if (tp + fn) > 0 else 0

        # Specificity = TN / (TN + FP); 0 when there are no negatives
        metrics[f'specificity_class_{i}'] = tn / (tn + fp) if (tn + fp) > 0 else 0

    # ROC AUC needs probabilities
    if y_prob is not None:
        if n_classes > 2:
            class_aucs = []
            class_supports = []
            for i in range(n_classes):
                is_positive = (y_true == i)
                n_positive = int(is_positive.sum())
                # One-vs-rest AUC is undefined unless both positives and
                # negatives are present for this class.
                if 0 < n_positive < y_true.size:
                    class_auc = roc_auc_score(is_positive.astype(int), y_prob[:, i])
                    metrics[f'auc_class_{i}'] = class_auc
                    class_aucs.append(class_auc)
                    class_supports.append(n_positive)

            if class_aucs:
                # Unweighted mean and support-weighted mean of the
                # one-vs-rest AUCs.
                metrics['auc_macro'] = float(np.mean(class_aucs))
                metrics['auc_weighted'] = float(np.average(class_aucs, weights=class_supports))
        elif n_classes == 2 and len(np.unique(y_true)) == 2:
            # Binary case: score with the probability of class 1
            metrics['auc'] = roc_auc_score(y_true, y_prob[:, 1])

    return metrics

def plot_roc_curves(y_true, y_prob, n_classes, phase='train'):
    """Plot one-vs-rest ROC curves for every class and save them to disk.

    Args:
        y_true: 1-D array-like of integer class labels in ``0..n_classes-1``.
        y_prob: 2-D array where column ``i`` holds the probability of class ``i``.
        n_classes: Number of classes (curves) to draw.
        phase: Name used in the plot title and in the output file name
            ``roc_curve_<phase>.png``.
    """
    y_true = np.asarray(y_true)

    plt.figure(figsize=(10, 8))

    # Plot ROC curve for each class. The one-vs-rest indicator is built with a
    # direct comparison: label_binarize collapses the two-class case into a
    # single column, which made indexing column 1 fail for n_classes == 2.
    for i in range(n_classes):
        fpr, tpr, _ = roc_curve((y_true == i).astype(int), y_prob[:, i])
        roc_auc = auc(fpr, tpr)
        plt.plot(fpr, tpr, lw=2, label=f'Class {i} (AUC = {roc_auc:.4f})')

    # Plot diagonal (chance) line
    plt.plot([0, 1], [0, 1], 'k--', lw=2)
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title(f'ROC Curves for {phase.capitalize()} Set')
    plt.legend(loc="lower right")
    plt.savefig(f'roc_curve_{phase}.png')
    # Close the figure so repeated calls do not accumulate open figures
    plt.close()

######################################
# Connectivity Building Function (Time-based)
######################################
def build_time_based_connectivity(df, time_threshold=18, tau=10):
    """Connect visits of the same patient that lie close together in time.

    Rows are identified by their position in ``df`` (not by index label).
    For every pair of visits that share a ``PTID`` and whose ``time`` values
    differ by at most ``time_threshold``, two directed edges (one per
    direction) are emitted with weight ``exp(-gap / tau)``.

    Args:
        df: DataFrame with at least the columns ``'PTID'`` and ``'time'``.
        time_threshold: Largest time gap for which two visits are connected.
        tau: Decay constant of the exponential edge weight.

    Returns:
        Tuple ``(edge_index, edge_weight)``: a ``torch.long`` tensor of shape
        ``(2, num_edges)`` holding row positions, and a ``torch.float`` tensor
        of shape ``(num_edges,)``.
    """
    # Read the time column once; looking it up through df.iloc inside the
    # nested loops builds a full row Series for every single access.
    times = df['time'].to_numpy()

    # Group row positions by patient, in order of first appearance
    patient_to_visit_indices = {}
    for position, ptid in enumerate(df['PTID']):
        patient_to_visit_indices.setdefault(ptid, []).append(position)

    sources = []
    targets = []
    weights = []
    for positions in patient_to_visit_indices.values():
        ordered = sorted(positions, key=lambda idx: times[idx])
        for a, earlier in enumerate(ordered):
            for later in ordered[a + 1:]:
                gap = abs(times[later] - times[earlier])
                if gap <= time_threshold:
                    weight = np.exp(-gap / tau)
                    # Bidirectional edges share the same weight
                    sources.extend((earlier, later))
                    targets.extend((later, earlier))
                    weights.extend((weight, weight))

    visit_edge_index = torch.tensor([sources, targets], dtype=torch.long)
    visit_edge_weight = torch.tensor(weights, dtype=torch.float)
    return visit_edge_index, visit_edge_weight

######################################
# Helper function for viscode conversion
######################################
def convert_viscode(viscode):
    """Translate a visit code string into a number.

    The comparison is case-insensitive. ``"sc"`` maps to 0 and a code that
    starts with ``m`` followed by digits maps to those digits as an integer.
    Non-string input and every other string yield ``None``.
    """
    if not isinstance(viscode, str):
        return None

    code = viscode.lower()
    if code == "sc":
        return 0

    month_match = re.match(r'm(\d+)', code)
    return int(month_match.group(1)) if month_match else None

######################################
# Load and preprocess training data
######################################
train_file = "/content/train_data.csv"
train_df = pd.read_csv(train_file)
# Keep only rows whose DIAGNOSIS is 1, 2 or 3 and shift them to 0-based class
# indices. The explicit .copy() makes train_df an independent frame so the
# column assignments below do not write into a slice of the raw frame.
train_df = train_df[train_df['DIAGNOSIS'].isin([1, 2, 3])].copy()
train_df['DIAGNOSIS'] = train_df['DIAGNOSIS'] - 1
train_df = train_df.fillna(0)

# Define feature columns - exclude identifiers, the label and genotype
non_feature_cols = ['PTID', 'VISCODE2', 'DIAGNOSIS', 'GENOTYPE']
genotype_cols = [col for col in train_df.columns if col.startswith('GENOTYPE_')]
feature_columns = [
    col for col in train_df.columns
    if col not in non_feature_cols and col not in genotype_cols
]

# Calculate mean and std for standardization (reused for the test set)
feature_means = train_df[feature_columns].mean()
feature_stds = train_df[feature_columns].std()
# A constant column has std 0; dividing by it would turn the whole column into
# NaN and poison the loss. Use 1 instead so such columns become all zeros.
feature_stds = feature_stds.replace(0, 1.0)
train_df[feature_columns] = (train_df[feature_columns] - feature_means) / feature_stds

# Process time: rows whose visit code cannot be converted are dropped
train_df['time'] = train_df['VISCODE2'].apply(convert_viscode)
train_df = train_df[train_df['time'].notnull()].copy()

# Extract labels and features
train_labels = train_df['DIAGNOSIS'].values
train_features = train_df[feature_columns].values.astype(np.float32)

# Compute class weights to counter class imbalance in the loss
classes = np.unique(train_labels)
class_weights = compute_class_weight(class_weight='balanced', classes=classes, y=train_labels)
class_weights = torch.tensor(class_weights, dtype=torch.float)

# Build visit-level connectivity (edges between visits of the same patient)
train_visit_edge_index, train_visit_edge_weight = build_time_based_connectivity(train_df, time_threshold=18, tau=10)
train_visit_x = torch.tensor(train_features, dtype=torch.float)
train_y_tensor = torch.tensor(train_labels, dtype=torch.long)

# Create train graph data
train_visit_data = Data(x=train_visit_x, edge_index=train_visit_edge_index, y=train_y_tensor)
train_visit_data.edge_weight = train_visit_edge_weight

# Build patient-level graph for training: one node per patient whose features
# are the mean of that patient's visit features
train_patient_df = train_df.groupby('PTID')[feature_columns].mean().reset_index()
train_patient_features = train_patient_df[feature_columns].values.astype(np.float32)
train_patient_x = torch.tensor(train_patient_features, dtype=torch.float)

n_neighbors = min(4, len(train_patient_features) - 1)  # Ensure we don't exceed available samples
nbrs = NearestNeighbors(n_neighbors=n_neighbors)
nbrs.fit(train_patient_features)
_, knn_indices = nbrs.kneighbors(train_patient_features)

train_patient_edge_index = [[], []]
num_patients = train_patient_features.shape[0]
for i in range(num_patients):
    for j in knn_indices[i][1:]:  # Skip the first neighbor (self)
        # Add the edge in both directions
        train_patient_edge_index[0].append(i)
        train_patient_edge_index[1].append(j)
        train_patient_edge_index[0].append(j)
        train_patient_edge_index[1].append(i)
train_patient_edge_index = torch.tensor(train_patient_edge_index, dtype=torch.long)

# For every visit row, the row position of its patient in train_patient_df
train_ptid_to_patient_index = {ptid: idx for idx, ptid in enumerate(train_patient_df['PTID'])}
train_visit_to_patient_mapping = train_df['PTID'].apply(lambda ptid: train_ptid_to_patient_index[ptid]).values
train_visit_to_patient_mapping = torch.tensor(train_visit_to_patient_mapping, dtype=torch.long)

# Create a mask for all training data
train_mask = torch.ones(len(train_labels), dtype=torch.bool)
train_visit_data.train_mask = train_mask

######################################
# Load and preprocess test data
######################################
test_file = "/content/test_data.csv"
test_df = pd.read_csv(test_file)
# Keep only rows whose DIAGNOSIS is 1, 2 or 3 and shift them to 0-based class
# indices. The explicit .copy() makes test_df an independent frame so the
# column assignments below do not write into a slice of the raw frame.
test_df = test_df[test_df['DIAGNOSIS'].isin([1, 2, 3])].copy()
test_df['DIAGNOSIS'] = test_df['DIAGNOSIS'] - 1
test_df = test_df.fillna(0)

# Use the same feature columns as training (already excludes GENOTYPE).
# The model's input width is taken from the training features, so a test file
# lacking any of them cannot be scored; fail here with a clear message instead
# of with a shape mismatch inside the model.
missing_feature_columns = [col for col in feature_columns if col not in test_df.columns]
if missing_feature_columns:
    raise ValueError(
        f"Test data is missing feature columns used in training: {missing_feature_columns}"
    )
test_feature_columns = list(feature_columns)

# Standardize using training means and stds. A zero std (constant training
# column) is replaced by 1 so the division cannot produce NaN/inf.
test_feature_stds = feature_stds.replace(0, 1.0)
test_df[test_feature_columns] = (
    (test_df[test_feature_columns] - feature_means[test_feature_columns])
    / test_feature_stds[test_feature_columns]
)

# Process time: rows whose visit code cannot be converted are dropped
test_df['time'] = test_df['VISCODE2'].apply(convert_viscode)
test_df = test_df[test_df['time'].notnull()].copy()

# Extract labels and features
test_labels = test_df['DIAGNOSIS'].values
test_features = test_df[test_feature_columns].values.astype(np.float32)

# Build visit-level connectivity (edges between visits of the same patient)
test_visit_edge_index, test_visit_edge_weight = build_time_based_connectivity(test_df, time_threshold=18, tau=10)
test_visit_x = torch.tensor(test_features, dtype=torch.float)
test_y_tensor = torch.tensor(test_labels, dtype=torch.long)

# Create test graph data
test_visit_data = Data(x=test_visit_x, edge_index=test_visit_edge_index, y=test_y_tensor)
test_visit_data.edge_weight = test_visit_edge_weight

# Build patient-level graph for testing: one node per patient whose features
# are the mean of that patient's visit features
test_patient_df = test_df.groupby('PTID')[test_feature_columns].mean().reset_index()
test_patient_features = test_patient_df[test_feature_columns].values.astype(np.float32)
test_patient_x = torch.tensor(test_patient_features, dtype=torch.float)

n_neighbors = min(4, len(test_patient_features) - 1)  # Ensure we don't exceed available samples
nbrs = NearestNeighbors(n_neighbors=n_neighbors)
nbrs.fit(test_patient_features)
_, knn_indices = nbrs.kneighbors(test_patient_features)

test_patient_edge_index = [[], []]
num_patients = test_patient_features.shape[0]
for i in range(num_patients):
    for j in knn_indices[i][1:]:  # Skip the first neighbor (self)
        # Add the edge in both directions
        test_patient_edge_index[0].append(i)
        test_patient_edge_index[1].append(j)
        test_patient_edge_index[0].append(j)
        test_patient_edge_index[1].append(i)
test_patient_edge_index = torch.tensor(test_patient_edge_index, dtype=torch.long)

# For every visit row, the row position of its patient in test_patient_df
test_ptid_to_patient_index = {ptid: idx for idx, ptid in enumerate(test_patient_df['PTID'])}
test_visit_to_patient_mapping = test_df['PTID'].apply(lambda ptid: test_ptid_to_patient_index[ptid]).values
test_visit_to_patient_mapping = torch.tensor(test_visit_to_patient_mapping, dtype=torch.long)

# Create a mask for all test data
test_mask = torch.ones(len(test_labels), dtype=torch.bool)
test_visit_data.test_mask = test_mask

######################################
# Define the Multi-Scale GAT Model
######################################
class MultiScaleGAT(nn.Module):
    """Two-branch graph attention network over a visit graph and a patient graph.

    The visit branch stacks four GATConv layers and the patient branch three.
    Every layer is followed by batch norm, ReLU and dropout; all layers after
    the first in a branch add their input back as a residual before dropout.
    Each visit embedding is concatenated with the embedding of its patient and
    passed through a linear classifier.

    Only ``edge_index`` tensors are consumed; no edge weights are passed to
    the attention layers.
    """

    def __init__(self, visit_in_channels, patient_in_channels, hidden_channels, num_classes, dropout=0.3, num_heads=4):
        super().__init__()
        # Shared GATConv settings: multi-head attention averaged (not
        # concatenated), so every layer outputs `hidden_channels` features.
        gat_kwargs = dict(heads=num_heads, concat=False, dropout=dropout)

        # Visit-level branch (4 layers). Attribute names and creation order
        # are kept stable because saved checkpoints depend on them.
        self.visit_gat1 = GATConv(visit_in_channels, hidden_channels, **gat_kwargs)
        self.visit_bn1 = nn.BatchNorm1d(hidden_channels)
        self.visit_gat2 = GATConv(hidden_channels, hidden_channels, **gat_kwargs)
        self.visit_bn2 = nn.BatchNorm1d(hidden_channels)
        self.visit_gat3 = GATConv(hidden_channels, hidden_channels, **gat_kwargs)
        self.visit_bn3 = nn.BatchNorm1d(hidden_channels)
        self.visit_gat4 = GATConv(hidden_channels, hidden_channels, **gat_kwargs)
        self.visit_bn4 = nn.BatchNorm1d(hidden_channels)
        self.visit_dropout = nn.Dropout(dropout)

        # Patient-level branch (3 layers)
        self.patient_gat1 = GATConv(patient_in_channels, hidden_channels, **gat_kwargs)
        self.patient_bn1 = nn.BatchNorm1d(hidden_channels)
        self.patient_gat2 = GATConv(hidden_channels, hidden_channels, **gat_kwargs)
        self.patient_bn2 = nn.BatchNorm1d(hidden_channels)
        self.patient_gat3 = GATConv(hidden_channels, hidden_channels, **gat_kwargs)
        self.patient_bn3 = nn.BatchNorm1d(hidden_channels)
        self.patient_dropout = nn.Dropout(dropout)

        # Classifier over [visit embedding, patient embedding]
        self.classifier = nn.Linear(hidden_channels * 2, num_classes)

    @staticmethod
    def _gat_block(x, edge_index, conv, norm, drop, residual):
        """Apply conv -> batch norm -> ReLU, optionally add the input, then dropout."""
        out = F.relu(norm(conv(x, edge_index)))
        if residual:
            out = out + x
        return drop(out)

    def forward(self, visit_x, visit_edge_index, patient_x, patient_edge_index, visit_to_patient_mapping):
        """Return ``(logits, log_softmax(logits))`` with one row per visit.

        ``visit_to_patient_mapping`` holds, for every visit, the row of its
        patient in ``patient_x``; it is used to index the patient embeddings.
        """
        # Visit-level branch: the first layer changes the width, so it has no
        # residual connection; the remaining layers do.
        v = self._gat_block(visit_x, visit_edge_index,
                            self.visit_gat1, self.visit_bn1, self.visit_dropout, residual=False)
        visit_layers = (
            (self.visit_gat2, self.visit_bn2),
            (self.visit_gat3, self.visit_bn3),
            (self.visit_gat4, self.visit_bn4),
        )
        for conv, norm in visit_layers:
            v = self._gat_block(v, visit_edge_index, conv, norm, self.visit_dropout, residual=True)

        # Patient-level branch, same pattern with three layers
        p = self._gat_block(patient_x, patient_edge_index,
                            self.patient_gat1, self.patient_bn1, self.patient_dropout, residual=False)
        patient_layers = (
            (self.patient_gat2, self.patient_bn2),
            (self.patient_gat3, self.patient_bn3),
        )
        for conv, norm in patient_layers:
            p = self._gat_block(p, patient_edge_index, conv, norm, self.patient_dropout, residual=True)

        # Map patient embedding to each visit and classify the concatenation
        combined = torch.cat([v, p[visit_to_patient_mapping]], dim=1)
        logits = self.classifier(combined)
        return logits, F.log_softmax(logits, dim=1)

######################################
# Training Setup and Training Loop
######################################
# Best hyperparameters from grid search
# Seed the random number generators before the model is built so that weight
# initialisation and dropout are repeatable across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Number of distinct labels present in the training targets
num_classes = len(torch.unique(train_y_tensor))
model = MultiScaleGAT(train_visit_x.shape[1], train_patient_x.shape[1],
                      hidden_channels=128, num_classes=num_classes,
                      dropout=0.3, num_heads=4)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0005)
# Halve the learning rate when the monitored value has not decreased for 10 steps
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10)

num_epochs = 500

def compute_metrics_for_batch(model, data, patient_x, patient_edge_index, visit_to_patient_mapping, mask=None):
    """Run the model in eval mode and score the (optionally masked) visits.

    Args:
        model: Network called as ``model(x, edge_index, patient_x,
            patient_edge_index, visit_to_patient_mapping)`` and returning
            ``(logits, log_softmax)``.
        data: Graph object providing ``x``, ``edge_index`` and ``y``.
        patient_x, patient_edge_index, visit_to_patient_mapping: Passed
            through to the model unchanged.
        mask: Optional boolean mask selecting the rows to score; every row is
            used when it is ``None``.

    Returns:
        Tuple ``(metrics, y_true, y_pred, y_prob)`` where ``metrics`` is the
        dict produced by ``calculate_metrics`` and the rest are NumPy arrays.

    Note:
        The model is switched to eval mode and is not switched back.
    """
    model.eval()
    with torch.no_grad():
        logits, log_probs = model(data.x, data.edge_index, patient_x,
                                  patient_edge_index, visit_to_patient_mapping)

        # Score every row unless the caller narrowed the selection
        if mask is None:
            mask = torch.ones(len(data.y), dtype=torch.bool)

        predicted = log_probs.argmax(dim=1)
        y_true = data.y[mask].cpu().numpy()
        y_pred = predicted[mask].cpu().numpy()
        # Probabilities come from a softmax over the selected logits
        y_prob = F.softmax(logits[mask], dim=1).cpu().numpy()

    return calculate_metrics(y_true, y_pred, y_prob), y_true, y_pred, y_prob

def print_metrics(metrics, phase='train'):
    """Print a metrics dict in a readable, per-class layout.

    Args:
        metrics: dict with ``'accuracy'`` and, optionally,
            ``'sensitivity_class_<i>'`` / ``'specificity_class_<i>'`` pairs,
            ``'auc'`` (binary case), ``'auc_macro'`` with ``'auc_weighted'``,
            and ``'auc_class_<i>'`` entries.
        phase: Label for the header line; it is capitalised when printed.
    """
    def class_order(key):
        # Order by the numeric class index so class 10 follows class 9
        # instead of class 1; non-numeric suffixes sort last, by text.
        suffix = key.split('_')[-1]
        return (0, int(suffix), '') if suffix.isdigit() else (1, 0, suffix)

    print(f"\n{phase.capitalize()} Metrics:")
    print(f"Accuracy: {metrics['accuracy']:.4f}")

    # Print per-class sensitivity and specificity
    for key in sorted((k for k in metrics if 'sensitivity_class' in k), key=class_order):
        class_idx = key.split('_')[-1]
        print(f"Class {class_idx} Sensitivity: {metrics[key]:.4f}, "
              f"Specificity: {metrics[f'specificity_class_{class_idx}']:.4f}")

    # Binary problems store a single AUC under 'auc'
    if 'auc' in metrics:
        print(f"AUC: {metrics['auc']:.4f}")

    # Print averaged AUC metrics if available
    if 'auc_macro' in metrics:
        print(f"AUC (macro): {metrics['auc_macro']:.4f}")
        print(f"AUC (weighted): {metrics['auc_weighted']:.4f}")

    # Print per-class AUC if available
    for key in sorted((k for k in metrics if 'auc_class' in k), key=class_order):
        class_idx = key.split('_')[-1]
        print(f"Class {class_idx} AUC: {metrics[key]:.4f}")
def load_models(model_paths, model_class, model_params):
    """Load trained models from saved state-dict checkpoints.

    Args:
        model_paths: Iterable of checkpoint file paths, each holding a
            ``state_dict`` written with ``torch.save``.
        model_class: Class to instantiate for every checkpoint.
        model_params: Keyword arguments passed to ``model_class``.

    Returns:
        List of models in eval mode, in the order of ``model_paths``.
    """
    loaded_models = []
    for path in model_paths:
        net = model_class(**model_params)
        # map_location='cpu' lets checkpoints written on a GPU machine load
        # on a CPU-only one; load_state_dict then copies the values into the
        # freshly built parameters.
        # NOTE: torch.load unpickles the file - only load trusted checkpoints.
        state_dict = torch.load(path, map_location='cpu')
        net.load_state_dict(state_dict)
        net.eval()
        loaded_models.append(net)
    return loaded_models

def get_ensemble_predictions(models, visit_data, patient_x, patient_edge_index, visit_to_patient_mapping):
    """Average the class probabilities of several models.

    Args:
        models: Non-empty sequence of models, each called as
            ``model(x, edge_index, patient_x, patient_edge_index,
            visit_to_patient_mapping)`` and returning ``(logits, _)``.
        visit_data: Object providing ``x`` and ``edge_index``.
        patient_x, patient_edge_index, visit_to_patient_mapping: Passed
            through to every model unchanged.

    Returns:
        Tuple ``(ensemble_preds, ensemble_probs)`` of NumPy arrays: the
        arg-max class per row and the mean softmax probabilities.

    Raises:
        ValueError: If ``models`` is empty.
    """
    if len(models) == 0:
        raise ValueError("get_ensemble_predictions requires at least one model")

    all_visit_probs = []

    with torch.no_grad():
        for model in models:
            # Inference must run in eval mode so dropout and batch norm are
            # deterministic; the previous mode is restored afterwards so the
            # caller's model is left as it was found.
            was_training = model.training
            model.eval()
            try:
                logits, _ = model(visit_data.x, visit_data.edge_index, patient_x,
                                  patient_edge_index, visit_to_patient_mapping)
            finally:
                model.train(was_training)
            all_visit_probs.append(F.softmax(logits, dim=1).cpu().numpy())

    # Average probabilities from all models, then pick the most likely class
    ensemble_probs = np.mean(all_visit_probs, axis=0)
    ensemble_preds = np.argmax(ensemble_probs, axis=1)

    return ensemble_preds, ensemble_probs

def create_visit_level_csv(visit_df, predictions, probabilities, output_file="visit_predictions.csv"):
    """Write one row per visit with its label, prediction and class probabilities.

    The function adds 1 to both the ``DIAGNOSIS`` column of ``visit_df`` and to
    ``predictions``, so both are expected to be 0-indexed on input. The input
    frame is not modified.

    Args:
        visit_df: DataFrame with columns ``'PTID'``, ``'VISCODE2'`` and
            ``'DIAGNOSIS'``.
        predictions: Array of predicted class indices, one per row.
        probabilities: 2-D array, one row per visit and one column per class.
        output_file: Path of the CSV to write (without the index).

    Returns:
        The DataFrame that was written, with the added columns ``'PREDICTED'``
        and ``'PROB_CLASS_<k>'`` for ``k`` starting at 1.
    """
    result_df = visit_df[['PTID', 'VISCODE2', 'DIAGNOSIS']].copy()

    # Shift labels and predictions back to 1-indexed class codes
    result_df['DIAGNOSIS'] = result_df['DIAGNOSIS'].add(1)
    result_df['PREDICTED'] = predictions + 1

    # One probability column per class, numbered from 1
    n_prob_columns = probabilities.shape[1]
    for class_number in range(1, n_prob_columns + 1):
        result_df[f'PROB_CLASS_{class_number}'] = probabilities[:, class_number - 1]

    result_df.to_csv(output_file, index=False)
    return result_df

def create_patient_level_csv(visit_df, visit_predictions, visit_probabilities, output_file="patient_predictions.csv"):
    """Aggregate visit predictions to one row per patient and write a CSV.

    For every ``PTID`` the output holds the ``DIAGNOSIS`` of the patient's
    last row in ``visit_df`` (row order, copied as-is without re-indexing),
    the most frequent visit prediction and the mean of the visit-level class
    probabilities.

    Args:
        visit_df: DataFrame with columns ``'PTID'``, ``'VISCODE2'`` and
            ``'DIAGNOSIS'``.
        visit_predictions: Array of 0-indexed predicted classes, one per row;
            1 is added before aggregation.
        visit_probabilities: 2-D array, one row per visit and one column per
            class.
        output_file: Path of the CSV to write (without the index).

    Returns:
        The per-patient DataFrame that was written.
    """
    # Last diagnosis (in row order) for each patient
    patient_df = visit_df.groupby('PTID')['DIAGNOSIS'].last().reset_index()

    # Visit-level predictions, aligned with visit_df row by row
    visit_pred_df = pd.DataFrame({
        'PTID': visit_df['PTID'],
        'VISCODE2': visit_df['VISCODE2'],
        'PREDICTED': visit_predictions + 1,  # Convert back to 1-indexed
    })

    num_classes = visit_probabilities.shape[1]
    prob_columns = [f'PROB_CLASS_{i + 1}' for i in range(num_classes)]
    for i, column in enumerate(prob_columns):
        visit_pred_df[column] = visit_probabilities[:, i]

    # Split the visits by patient once instead of re-filtering the whole
    # frame for every patient.
    visits_by_patient = {ptid: visits for ptid, visits in visit_pred_df.groupby('PTID')}

    patient_predictions = []
    patient_probabilities = []
    for ptid in patient_df['PTID']:
        patient_visits = visits_by_patient[ptid]

        # Majority vote: value_counts lists the most frequent value first
        patient_predictions.append(patient_visits['PREDICTED'].value_counts().index[0])

        # Average the probabilities across all visits of this patient
        patient_probabilities.append([patient_visits[column].mean() for column in prob_columns])

    patient_df['PREDICTED'] = patient_predictions

    # The explicit reshape keeps the array 2-D when there are no patients, so
    # the column slicing below also works for empty input.
    patient_probabilities = np.asarray(patient_probabilities).reshape(len(patient_predictions), num_classes)
    for i, column in enumerate(prob_columns):
        patient_df[column] = patient_probabilities[:, i]

    patient_df.to_csv(output_file, index=False)
    return patient_df

# Training loop with evaluation metrics
# Full-batch training: every epoch runs one forward/backward pass over the
# whole training graph. Metrics are computed on the training data only.
print("Starting training...")
best_train_metrics = None
best_epoch = 0
# NOTE(review): best_val_acc is never read or updated in the visible code.
best_val_acc = 0

for epoch in range(1, num_epochs + 1):
    model.train()
    optimizer.zero_grad()
    # The model returns (logits, log_softmax); only the log-probabilities are
    # needed for the negative log-likelihood loss below.
    _, out = model(train_visit_data.x, train_visit_data.edge_index,
                train_patient_x, train_patient_edge_index, train_visit_to_patient_mapping)
    # Class-weighted NLL loss over the rows selected by train_mask
    loss = F.nll_loss(out[train_visit_data.train_mask],
                     train_visit_data.y[train_visit_data.train_mask],
                     weight=class_weights)
    loss.backward()
    optimizer.step()
    # The plateau scheduler monitors the training loss
    scheduler.step(loss.item())

    # Evaluate metrics every 50 epochs or at the end
    if epoch % 50 == 0 or epoch == num_epochs:
        # Calculate and print training metrics
        # (compute_metrics_for_batch puts the model in eval mode; model.train()
        # at the top of the loop switches it back for the next epoch)
        train_metrics, y_true_train, y_pred_train, y_prob_train = compute_metrics_for_batch(
            model, train_visit_data, train_patient_x, train_patient_edge_index,
            train_visit_to_patient_mapping, train_visit_data.train_mask
        )

        print(f"\nEpoch {epoch:03d}, Loss: {loss:.4f}")
        print_metrics(train_metrics, 'train')

        # Save best model based on training accuracy
        # NOTE(review): the checkpoint is selected on training accuracy, and
        # only at the evaluation epochs above; no held-out data is involved.
        if best_train_metrics is None or train_metrics['accuracy'] > best_train_metrics['accuracy']:
            best_train_metrics = train_metrics
            best_epoch = epoch
            # Save model state
            torch.save(model.state_dict(), 'best_model.pt')

            # Plot ROC curves for best epoch
            plot_roc_curves(y_true_train, y_prob_train, num_classes, 'train')

# Load the best model for final evaluation
print(f"\nLoading best model from epoch {best_epoch}...")
model.load_state_dict(torch.load('best_model.pt'))

######################################
# Final Evaluation on Test Set
######################################
print("\nEvaluating on test set...")
test_metrics, y_true_test, y_pred_test, y_prob_test = compute_metrics_for_batch(
    model, test_visit_data, test_patient_x, test_patient_edge_index,
    test_visit_to_patient_mapping, test_visit_data.test_mask
)

# Print detailed metrics
print_metrics(test_metrics, 'test')

# Plot ROC curves for test set
plot_roc_curves(y_true_test, y_prob_test, num_classes, 'test')

# Print confusion matrix. The label set is fixed to all model classes so the
# matrix is always num_classes x num_classes and its rows/columns match the
# tick labels drawn below, even if a class is absent from the test set.
print("\nTest Confusion Matrix:")
cm = confusion_matrix(y_true_test, y_pred_test, labels=list(range(num_classes)))
print(cm)

# Print classification report
print("\nTest Classification Report:")
print(classification_report(y_true_test, y_pred_test, digits=4))

# Create a confusion matrix plot
plt.figure(figsize=(8, 6))
plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
plt.title('Confusion Matrix')
plt.colorbar()
tick_marks = np.arange(num_classes)
plt.xticks(tick_marks, range(num_classes))
plt.yticks(tick_marks, range(num_classes))
plt.xlabel('Predicted Label')
plt.ylabel('True Label')

# Add text annotations to the confusion matrix; cells darker than half the
# maximum count get white text so the numbers stay readable
thresh = cm.max() / 2
for i in range(cm.shape[0]):
    for j in range(cm.shape[1]):
        plt.text(j, i, format(cm[i, j], 'd'),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

plt.tight_layout()
plt.savefig('confusion_matrix.png')
plt.close()

# Save final metrics to CSV: a single row with train_* and test_* columns
train_metrics_df = pd.DataFrame({k: [v] for k, v in best_train_metrics.items()})
test_metrics_df = pd.DataFrame({k: [v] for k, v in test_metrics.items()})

metrics_df = pd.concat([
    train_metrics_df.add_prefix('train_'),
    test_metrics_df.add_prefix('test_')
], axis=1)

metrics_df.to_csv('model_metrics.csv', index=False)
print("\nMetrics saved to model_metrics.csv")

# Print summary of key metrics
print("\n===== SUMMARY =====")
print(f"Best epoch: {best_epoch}")
print(f"Train accuracy: {best_train_metrics['accuracy']:.4f}")
print(f"Test accuracy: {test_metrics['accuracy']:.4f}")

if 'auc_macro' in best_train_metrics and 'auc_macro' in test_metrics:
    print(f"Train AUC (macro): {best_train_metrics['auc_macro']:.4f}")
    print(f"Test AUC (macro): {test_metrics['auc_macro']:.4f}")

print("\n===== GENERATING PREDICTION CSVs =====")

# Model parameters (using the same as defined in training)
model_params = {
    'visit_in_channels': test_visit_x.shape[1],
    'patient_in_channels': test_patient_x.shape[1],
    'hidden_channels': 128,
    'num_classes': num_classes,
    'dropout': 0.3,
    'num_heads': 4
}

# Find model checkpoints; sorted so the ensemble order is the same on every run
model_paths = sorted(glob.glob('model_*.pt'))
if not model_paths:
    model_paths = ['best_model.pt']  # Fallback to single best model

print(f"Using {len(model_paths)} models for ensemble predictions")

# Load ensemble models
models = load_models(model_paths, MultiScaleGAT, model_params)

# === GENERATE TEST PREDICTIONS ===
# Get ensemble predictions for test data
test_ensemble_preds, test_ensemble_probs = get_ensemble_predictions(
    models, test_visit_data, test_patient_x, test_patient_edge_index, test_visit_to_patient_mapping
)

# create_visit_level_csv adds 1 to DIAGNOSIS itself, so it must receive the
# 0-indexed labels. Converting the frame first shifted the visit-level CSV
# labels twice (2..4 instead of 1..3).
test_visit_pred_df = create_visit_level_csv(test_df, test_ensemble_preds, test_ensemble_probs, "test_visit_predictions.csv")

# create_patient_level_csv copies DIAGNOSIS as-is, so convert to 1-indexed
# only now. test_df stays 1-indexed afterwards, as before.
original_test_diagnoses = test_df['DIAGNOSIS'].copy()  # Save current diagnoses (0-indexed)
test_df['DIAGNOSIS'] = original_test_diagnoses + 1  # Convert back to 1-indexed for results
test_patient_pred_df = create_patient_level_csv(test_df, test_ensemble_preds, test_ensemble_probs, "test_patient_predictions.csv")

print("Created test visit-level predictions: test_visit_predictions.csv")
print("Created test patient-level predictions: test_patient_predictions.csv")

# === GENERATE TRAIN PREDICTIONS ===
# Get ensemble predictions for train data
train_ensemble_preds, train_ensemble_probs = get_ensemble_predictions(
    models, train_visit_data, train_patient_x, train_patient_edge_index, train_visit_to_patient_mapping
)

# Same ordering as for the test set: visit-level CSV from 0-indexed labels,
# then convert and write the patient-level CSV.
train_visit_pred_df = create_visit_level_csv(train_df, train_ensemble_preds, train_ensemble_probs, "train_visit_predictions.csv")

original_train_diagnoses = train_df['DIAGNOSIS'].copy()  # Save current diagnoses (0-indexed)
train_df['DIAGNOSIS'] = original_train_diagnoses + 1  # Convert back to 1-indexed for results
train_patient_pred_df = create_patient_level_csv(train_df, train_ensemble_preds, train_ensemble_probs, "train_patient_predictions.csv")

print("Created train visit-level predictions: train_visit_predictions.csv")
print("Created train patient-level predictions: train_patient_predictions.csv")

Starting training...

Epoch 050, Loss: 0.4058

Train Metrics:
Accuracy: 0.8355
Class 0 Sensitivity: 0.8430, Specificity: 0.9516
Class 1 Sensitivity: 0.8239, Specificity: 0.8455
Class 2 Sensitivity: 0.8466, Specificity: 0.9416
AUC (macro): 0.9517
AUC (weighted): 0.9457
Class 0 AUC: 0.9722
Class 1 AUC: 0.9119
Class 2 AUC: 0.9711

Epoch 100, Loss: 0.3701

Train Metrics:
Accuracy: 0.8480
Class 0 Sensitivity: 0.9083, Specificity: 0.9356
Class 1 Sensitivity: 0.7866, Specificity: 0.8970
Class 2 Sensitivity: 0.8748, Specificity: 0.9339
AUC (macro): 0.9587
AUC (weighted): 0.9539
Class 0 AUC: 0.9771
Class 1 AUC: 0.9258
Class 2 AUC: 0.9732

Epoch 150, Loss: 0.3659

Train Metrics:
Accuracy: 0.8513
Class 0 Sensitivity: 0.9019, Specificity: 0.9418
Class 1 Sensitivity: 0.8016, Specificity: 0.8912
Class 2 Sensitivity: 0.8700, Specificity: 0.9370
AUC (macro): 0.9599
AUC (weighted): 0.9553
Class 0 AUC: 0.9781
Class 1 AUC: 0.9279
Class 2 AUC: 0.9738

Epoch 200, Loss: 0.3656

Train Metrics:
Accuracy: 0.85