# MST Training - CORRECTED VERSION
## Fixes for All Data Leakage and Reproducibility Vulnerabilities

## FIX #1: Set Random Seeds FIRST (Cell 0)
Add this cell BEFORE any other imports or code so that every random number generator is seeded before it is first used.

In [None]:
# CRITICAL: Set seeds FIRST for reproducibility
import random
import numpy as np
import torch

# Single seed value shared by every random number generator seeded below.
SEED = 42

# Seed Python's `random`, NumPy, PyTorch (default generator) and all CUDA
# devices, in that order, with the same value.
for _seed_fn in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed_all,
):
    _seed_fn(SEED)

# cuDNN flags: benchmark mode off, deterministic mode on.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

print(f"✅ All random seeds set to {SEED}")

## FIX #2: Define Global Train/Val/Test Split (Cell 2)
Replace the repeated `train_size = int(0.8 * len(dataset))` with global constants that are defined once and reused by every later cell.

In [None]:
# The split sizes are computed exactly once here so that every later cell
# works with the same boundaries.
_num_samples = len(dataset)
TRAIN_SIZE = int(0.8 * _num_samples)  # first 80% of the samples
VAL_SIZE = int(0.1 * _num_samples)  # next 10%, used for validation
TEST_SIZE = _num_samples - TRAIN_SIZE - VAL_SIZE  # whatever remains

print(f"Dataset split: Train={TRAIN_SIZE}, Val={VAL_SIZE}, Test={TEST_SIZE}")
print(f"Total samples: {_num_samples}")

# Contiguous, order-preserving slices: train first, then validation, then test.
_val_end = TRAIN_SIZE + VAL_SIZE
train_dataset = dataset[:TRAIN_SIZE]
val_dataset = dataset[TRAIN_SIZE:_val_end]
test_dataset = dataset[_val_end:]

print(f"✅ Global split indices defined")

## FIX #3: Calculate Class Weights ONLY from Training Set (Cell 3)
CRITICAL: Prevent data leakage by computing weights on train set only

In [None]:
# CRITICAL FIX: Compute class weights ONLY from training set
# This prevents test set information from leaking into the loss function
from collections import Counter

print("Calculating class weights from TRAINING SET ONLY...")

# Number of target classes; the labels reported below are 0, 1 and 2.
num_classes = 3

# Gather every label from the training samples only, so validation and test
# labels cannot influence the loss weighting.
# Indexed access is kept on purpose: it only requires __len__ and __getitem__.
train_labels = []
for i in range(len(train_dataset)):
    train_labels.extend(train_dataset[i].y.cpu().numpy().flatten().tolist())

total_train_samples = len(train_labels)  # Use TRAIN count, not full dataset
if total_train_samples == 0:
    # Without this check the percentage prints below fail with a bare
    # ZeroDivisionError that does not say what went wrong.
    raise ValueError(
        "Cannot compute class weights: the training set contains no labels"
    )

class_counts = Counter(train_labels)

print("--- Class Distribution in TRAINING SET ---")
print(f"Downward (0): {class_counts[0]:6d} samples ({class_counts[0]/total_train_samples*100:5.1f}%)")
print(f"Neutral  (1): {class_counts[1]:6d} samples ({class_counts[1]/total_train_samples*100:5.1f}%)")
print(f"Upward   (2): {class_counts[2]:6d} samples ({class_counts[2]/total_train_samples*100:5.1f}%)")
print(f"Total:        {total_train_samples:6d} samples")

# A Counter returns 0 for keys it has never seen, so a class that is absent
# from the training split would cause a division by zero in the weight formula.
missing_classes = [c for c in range(num_classes) if class_counts[c] == 0]
if missing_classes:
    raise ValueError(
        f"Cannot compute class weights: classes {missing_classes} "
        "have no samples in the training set"
    )

# Calculate weights: inverse of class frequency (based on TRAINING data only)
class_weights = torch.tensor([
    total_train_samples / (num_classes * class_counts[i])
    for i in range(num_classes)
], dtype=torch.float)

# Normalize weights so they average to 1.0
class_weights = class_weights / class_weights.mean()

print("\n--- Class Weights (normalized, computed from TRAIN only) ---")
print(f"Downward weight: {class_weights[0]:.4f}")
print(f"Neutral weight:  {class_weights[1]:.4f}")
print(f"Upward weight:   {class_weights[2]:.4f}")
print("\n✅ CRITICAL: Weights computed ONLY from training set (no test leakage)")

## FIX #4: Feature Normalization with Proper Data Separation (Cell 4)
Fit scaler ONLY on training features, then apply to all

In [None]:
# CRITICAL FIX: Feature normalization with proper train/test separation
# 1. Fit scaler ONLY on training data
# 2. Apply to validation and test (no fitting!)

from sklearn.preprocessing import StandardScaler

print("\n" + "="*70)
print("FEATURE NORMALIZATION - PROPER TRAIN/TEST SEPARATION")
print("="*70)

# Step 1: Collect features ONLY from training set
print("\n1️⃣  Collecting features from TRAINING SET...")
train_features = []
for i in range(len(train_dataset)):
    # detach() and cpu() make the NumPy conversion work for tensors that track
    # gradients or live on a GPU; for plain CPU tensors they change nothing.
    # Shape per the original author's note: (nodes, time, features) -- TODO confirm.
    x = train_dataset[i].x.detach().cpu().numpy()
    # Collapse all leading axes so each row is one feature vector.
    train_features.append(x.reshape(-1, x.shape[-1]))
train_features_all = np.vstack(train_features)
# The count printed as "samples" is the number of stacked rows.
print(f"   Collected {train_features_all.shape[0]:,} samples, {train_features_all.shape[1]} features")

# Step 2: Fit scaler on training features ONLY
print("\n2️⃣  Fitting StandardScaler on training data...")
scaler = StandardScaler()
scaler.fit(train_features_all)

print(f"   Mean: {scaler.mean_}")
print(f"   Std:  {scaler.scale_}")

# Step 3: Apply scaler to ALL datasets (train, val, test) WITHOUT refitting
print("\n3️⃣  Applying scaler to all datasets (NO refitting)...")

def apply_scaler_to_dataset(dataset, scaler, dataset_name):
    """Scale the ``x`` features of every sample in ``dataset`` in place.

    Each sample's ``x`` tensor is flattened to 2-D (rows, features), passed
    through ``scaler.transform`` (the scaler is never fitted here), restored
    to its original shape and written back as a float32 tensor on the same
    device as the original tensor.

    Args:
        dataset: Indexable collection (``len`` and integer indexing) whose
            items expose a tensor attribute ``x`` with the feature axis last.
        scaler: Already-fitted object providing ``transform(array_2d)``.
        dataset_name: Label used only in the progress message.

    Returns:
        None. Samples are modified through attribute assignment.
    """
    for i in range(len(dataset)):
        # Fetch the sample once and write back to that same object.
        # NOTE(review): if ``dataset[i]`` returns a fresh copy on every access,
        # the assignment below does not persist -- verify against the
        # dataset type actually used.
        sample = dataset[i]
        features = sample.x

        # detach() and cpu() make the NumPy conversion work for tensors that
        # track gradients or live on a GPU; no-ops for plain CPU tensors.
        x = features.detach().cpu().numpy()

        # Flatten, scale (transform only, no fit), reshape
        x_flat = x.reshape(-1, x.shape[-1])
        x_scaled = scaler.transform(x_flat).reshape(x.shape)

        sample.x = torch.tensor(
            x_scaled, dtype=torch.float, device=features.device
        )
    print(f"   ✅ {dataset_name} features normalized (n={len(dataset)})")

# Run the already-fitted scaler over each split in turn: train, validation, test.
for _split_name, _split in (
    ("Train", train_dataset),
    ("Validation", val_dataset),
    ("Test", test_dataset),
):
    apply_scaler_to_dataset(_split, scaler, _split_name)

# Closing banner, emitted as four lines.
print(
    "\n" + "="*70,
    "✅ CRITICAL: Scaler fit ONLY on training data",
    "✅ All datasets normalized with same statistics (no test leakage)",
    "="*70,
    sep="\n",
)

## FIX #5: Create DataLoaders with Fixed Seed (Cell 5)

In [None]:
# Create DataLoaders with FIXED SEED for reproducibility
from torch_geometric.loader import DataLoader

BATCH_SIZE = 32

# Dedicated, seeded generator handed to the training loader so its shuffling
# order is reproducible. manual_seed() returns the generator itself.
generator = torch.Generator().manual_seed(SEED)

# Only the training loader shuffles.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=generator,
)

# Validation and test loaders share the same settings and keep sample order.
_eval_loader_kwargs = {"batch_size": BATCH_SIZE, "shuffle": False}
val_loader = DataLoader(val_dataset, **_eval_loader_kwargs)
test_loader = DataLoader(test_dataset, **_eval_loader_kwargs)

print(
    f"✅ DataLoaders created with fixed seed",
    f"   Train: {len(train_loader)} batches",
    f"   Val:   {len(val_loader)} batches",
    f"   Test:  {len(test_loader)} batches",
    sep="\n",
)

## FIX #6: Initialize Model Fresh (Cell 6)
Always reinitialize the model before the main training run so that weights left over from earlier test runs are not reused.

In [None]:
# Initialize model FRESH (not reusing weights from test runs)
from models.MST import MST_GNN

# Architecture hyperparameters.
HIDDEN_SIZE = 64
GRAPH_LAYERS = 3
CROSS_LAYERS = 2

# Input width is read from the data: size of axis 2 of the first training
# sample's feature tensor.
REAL_INPUT_FEATURES = train_dataset[0].x.shape[2]

# Prefer the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build a brand-new model instance (no weights carried over from an earlier
# object) and place it on the chosen device.
model = MST_GNN(
    in_features=REAL_INPUT_FEATURES,
    hidden_size=HIDDEN_SIZE,
    num_graph_layers=GRAPH_LAYERS,
    num_cross_layers=CROSS_LAYERS,
).to(DEVICE)

print(
    f"\n✅ Model initialized FRESH on device: {DEVICE}",
    f"   Input features: {REAL_INPUT_FEATURES}",
    f"   Hidden size: {HIDDEN_SIZE}",
    f"   Graph layers: {GRAPH_LAYERS}",
    f"   Cross layers: {CROSS_LAYERS}",
    sep="\n",
)

## FIX #7: Setup Optimizer and Loss with Fixed Device (Cell 7)

In [None]:
# Setup optimizer and loss with proper device handling
import torch.nn as nn
import torch.optim as optim

LEARNING_RATE = 0.001
NUM_EPOCHS = 100

optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# The loss receives a copy of the class weights placed on DEVICE.
# Tensor.to() returns a tensor and does not move `class_weights` itself, so
# `class_weights` may still live on its original device afterwards.
criterion = nn.CrossEntropyLoss(weight=class_weights.to(DEVICE))

print(f"✅ Optimizer: Adam (lr={LEARNING_RATE})")
print(f"✅ Loss: CrossEntropyLoss with class weights")
# Report the device of the weights the loss actually uses. Printing
# `class_weights.device` would show the un-moved original tensor instead.
print(f"✅ Class weights device: {criterion.weight.device}")
print(f"✅ Model device: {next(model.parameters()).device}")

## FIX #8: Train with Validation Monitoring (Cell 8)
Use validation set to detect overfitting

In [None]:
from models.train import train

# Header summarising the run configuration, emitted line by line.
print(
    "\n" + "="*70,
    "STARTING TRAINING (with validation monitoring)",
    "="*70,
    f"Epochs: {NUM_EPOCHS}",
    f"Train batches: {len(train_loader)}",
    f"Val batches: {len(val_loader)}",
    f"Test batches: {len(test_loader)}",
    "="*70 + "\n",
    sep="\n",
)

# The validation loader is passed through the `test_dataloader` parameter, so
# the second returned series (`test_losses_epoch`) holds losses measured on the
# validation split, not on the held-out test split.
train_losses_epoch, test_losses_epoch = train(
    model=model,
    optimizer=optimizer,
    criterion=criterion,
    train_dataloader=train_loader,
    test_dataloader=val_loader,
    num_epochs=NUM_EPOCHS,
    task_title="MST_GNN_Fixed_NoLeakage",
    measure_acc=True,
)

print("\n✅ Training complete!")

## FIX #9: Evaluation on Test Set (Cell 9)
Only evaluate on held-out test set (never seen by model)

In [None]:
# Banner announcing the final evaluation, which must use the TEST split only
# (data that took no part in training).
print(
    "\n" + "="*70,
    "FINAL EVALUATION ON TEST SET (held-out data)",
    "="*70,
    sep="\n",
)

from sklearn.metrics import confusion_matrix, classification_report, accuracy_score
import seaborn as sns
import matplotlib.pyplot as plt

def get_predictions(model, loader, device):
    """Run ``model`` over ``loader`` and collect labels, predictions and probabilities.

    The model is switched to eval mode and every batch is processed with
    gradient tracking disabled. Class probabilities are the softmax of the
    model output along dim 1; the prediction is the argmax of those
    probabilities.

    Args:
        model: Callable invoked as ``model(x, edge_index, edge_weight)`` that
            also provides ``eval()``.
        loader: Iterable of batches exposing ``to(device)``, ``x``,
            ``edge_index``, ``edge_weight`` and ``y``.
        device: Device passed to each batch's ``to``.

    Returns:
        Tuple ``(labels, predictions, probabilities)`` of NumPy arrays, in
        that order.
    """
    true_labels, predicted, probabilities = [], [], []

    model.eval()
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            scores = model(batch.x, batch.edge_index, batch.edge_weight)
            batch_probs = torch.softmax(scores, dim=1)
            batch_preds = batch_probs.argmax(dim=1)
            batch_labels = batch.y.long()

            # extend() appends one entry per row of each batch array.
            probabilities.extend(batch_probs.cpu().numpy())
            predicted.extend(batch_preds.cpu().numpy())
            true_labels.extend(batch_labels.cpu().numpy())

    return np.array(true_labels), np.array(predicted), np.array(probabilities)

# Collect labels, predictions and probabilities for the TEST split.
# `y_prob` is returned as well but is not used further in this cell.
y_true, y_pred, y_prob = get_predictions(model, test_loader, DEVICE)

# Per-class report; label ids 0/1/2 are displayed under these names.
_class_names = ['Downward', 'Neutral', 'Upward']
print("\n--- CLASSIFICATION REPORT (Test Set) ---")
_report = classification_report(
    y_true,
    y_pred,
    labels=[0, 1, 2],
    target_names=_class_names,
    zero_division=0,
)
print(_report)

overall_acc = accuracy_score(y_true, y_pred)
print(f"\nOverall Test Accuracy: {overall_acc:.4f}")
print("\n✅ Evaluation complete on held-out test set")

## Summary of Fixes Applied

| Fix | Issue | Solution |
|-----|-------|----------|
| 1 | Non-reproducible results | Set all random seeds to 42 before any operations |
| 2 | Inconsistent train/test splits | Define `TRAIN_SIZE` once globally |
| 3 | **Data leakage in class weights** | Compute weights from training set only |
| 4 | **Data leakage in normalization** | Fit scaler on training set, apply to all without refitting |
| 5 | Non-reproducible shuffling | Use fixed seed in DataLoader generator |
| 6 | Pretrained weights contaminating runs | Reinitialize model fresh before training |
| 7 | Device mismatch errors | Ensure class weights and model on same device |
| 8 | Can't detect overfitting | Added validation set (80/10/10 train/val/test split) |
| 9 | Evaluation on training data | Separate test set never seen by model |

**Result: Completely reproducible, leak-free training pipeline with proper train/val/test separation**