# Baseline GNN Training

- Loads model from `baseline_gnn.py` file
- Starts training from 2014 (skips noisy early years)
- Uses 2K stratified samples for fast iteration
- Lazy loading with LRU cache
- Works both locally and on Colab

## 1. Setup Environment

In [1]:
# Check if running on Colab
import sys
IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    print("Running on Google Colab")
    !pip install torch-geometric
    !pip install python-dateutil
else:
    print("Running locally")

Running on Google Colab


In [2]:
# Check GPU availability
import torch

# Default values
USE_BF16 = False
USE_AMP = torch.cuda.is_available()  # AMP works only on CUDA

if torch.cuda.is_available():
    device = torch.device('cuda')
    gpu_name = torch.cuda.get_device_name(0).lower()
    gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1e9

    print(f"GPU available: {gpu_name}")
    print(f"GPU memory: {gpu_mem:.2f} GB")

    # Precision selection based on GPU family
    if any(key in gpu_name for key in ['b200', 'b100', 'blackwell']):
        print("Detected Blackwell GPU → using BF16 for optimal performance")
        USE_BF16 = True

    elif any(key in gpu_name for key in ['h100', 'a100']):
        print("Detected Hopper/Ampere GPU → using BF16")
        USE_BF16 = True

    else:
        print("Standard CUDA GPU detected → using FP16 (AMP)")
        USE_BF16 = False

else:
    device = torch.device('cpu')
    print("No GPU available → using CPU")
    USE_BF16 = False
    USE_AMP = False

print("\nFinal Precision Settings")
print(f"Using device: {device}")
print(f"AMP enabled:  {USE_AMP}")
print(f"BF16 enabled: {USE_BF16}")

GPU available: nvidia a100-sxm4-80gb
GPU memory: 85.17 GB
Detected Hopper/Ampere GPU → using BF16

Final Precision Settings
Using device: cuda
AMP enabled:  True
BF16 enabled: True


In [3]:
# Setup paths based on environment
if IN_COLAB:
    from google.colab import drive
    drive.mount('/content/drive')

    PROJECT_DIR = '/content/drive/MyDrive/dont-thread-on-me'
    GRAPH_DIR = f'{PROJECT_DIR}/data/processed/graphs'

    # Add project directory to path for imports
    if PROJECT_DIR not in sys.path:
        sys.path.append(PROJECT_DIR)

    print(f"Google Drive mounted")
    print(f"Project directory: {PROJECT_DIR}")

else:
    PROJECT_DIR = '..'
    GRAPH_DIR = '../data/processed/graphs'

    if PROJECT_DIR not in sys.path:
        sys.path.append(PROJECT_DIR)

    print(f"Running locally")
    print(f"Project directory: {PROJECT_DIR}")

print(f"Graph directory: {GRAPH_DIR}")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Google Drive mounted
Project directory: /content/drive/MyDrive/dont-thread-on-me
Graph directory: /content/drive/MyDrive/dont-thread-on-me/data/processed/graphs


In [4]:
# Verify data directory exists
from pathlib import Path

graph_dir = Path(GRAPH_DIR)

if not graph_dir.exists():
    print(f"ERROR: Graph directory not found: {graph_dir}")
    print(f"\nPlease upload your processed graphs to:")
    print(f"  {GRAPH_DIR}")
else:
    communities = [d.name for d in graph_dir.iterdir() if d.is_dir()]
    print(f"Found {len(communities)} communities")

    # Show first few
    for comm in sorted(communities)[:5]:
        comm_path = graph_dir / comm
        n_graphs = len(list(comm_path.glob('*.pt')))
        print(f"  - {comm}: {n_graphs} monthly graphs")

    if len(communities) > 5:
        print(f"  ... and {len(communities) - 5} more")

Found 177 communities
  - 3dprinting.stackexchange.com: 98 monthly graphs
  - academia.stackexchange.com: 147 monthly graphs
  - ai.stackexchange.com: 91 monthly graphs
  - android.stackexchange.com: 174 monthly graphs
  - anime.stackexchange.com: 135 monthly graphs
  ... and 172 more


## 2. Import Model from baseline_gnn.py

**Note:** Make sure `baseline_gnn.py` is uploaded to your project directory:
- Colab: `/content/drive/MyDrive/dont-thread-on-me/src/models/baseline_gnn.py`
- Local: `../src/models/baseline_gnn.py`

In [5]:
# Try to import the model
try:
    from src.models.baseline_gnn import create_model
    print("Successfully imported model from baseline_gnn.py")
except ImportError as e:
    print("Failed to import model")
    print(f"\nError: {e}")
    print(f"\nPlease make sure baseline_gnn.py exists at:")
    if IN_COLAB:
        print(f"  /content/drive/MyDrive/dont-thread-on-me/src/models/baseline_gnn.py")
    else:
        print(f"  ../src/models/baseline_gnn.py")
    raise

Successfully imported model from baseline_gnn.py


## 3. Define Optimized Dataset

- **No pre-caching** - starts training immediately
- **LRU cache** - automatically caches frequently accessed graphs
- **Stratified sampling** - representative samples across communities
- **2014 start date** - skips noisy early Stack Exchange years

In [6]:
# Dataset imports
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
from typing import List, Dict
from datetime import datetime
from dateutil.relativedelta import relativedelta
from functools import lru_cache
from collections import defaultdict
import random
import numpy as np
from tqdm.auto import tqdm

In [7]:
class FastTemporalDataset(Dataset):
    """
    Optimized dataset with lazy loading + LRU cache + target normalization.

    - No slow pre-caching step (starts training immediately)
    - LRU cache automatically keeps hot graphs in memory
    - Memory efficient (only caches ~2000 graphs, not all)
    - Stratified sampling for representative coverage
    - Starts from 2014 (skips noisy early years)
    - Target normalization: Standardizes all targets to mean=0, std=1
    """

    def __init__(
        self,
        graph_dir: Path,
        split: str = 'train',
        sequence_length: int = 12,
        prediction_horizon: int = 6,
        max_samples: int = None,
        cache_size: int = 2000,
        stratified_sample: bool = True
    ):
        self.graph_dir = Path(graph_dir)
        self.split = split
        self.sequence_length = sequence_length
        self.prediction_horizon = prediction_horizon

        # Temporal splits
        self.split_ranges = {
            'train': ('2014-01', '2020-06'),  # Predict through 2020-12
            'val':   ('2020-07', '2022-09'),  # Predict through 2023-03
            'test':  ('2022-10', '2023-09')   # Predict through 2024-03
        }

        # Build sample index
        print(f"Building {split} sample index...")
        self.samples = self._build_sample_index()
        print(f"  Found {len(self.samples)} potential samples")

        # Apply sampling strategy
        if max_samples and len(self.samples) > max_samples:
            if stratified_sample:
                print(f"  Applying stratified sampling to select {max_samples} samples...")
                self.samples = self._stratified_sample(max_samples)
            else:
                print(f"  Randomly sampling {max_samples} samples...")
                random.shuffle(self.samples)
                self.samples = self.samples[:max_samples]

        print(f"{split.upper()} Dataset: {len(self.samples)} samples")

        # Initialize normalization stats (will be computed for train, loaded for val/test)
        self.norm_stats = None

        # Create LRU cached loader for individual monthly graphs
        @lru_cache(maxsize=cache_size)
        def _cached_load(community: str, month: str):
            path = self.graph_dir / community / f"{month}.pt"
            return torch.load(path, weights_only=False, map_location='cpu')

        self._load_graph = _cached_load
        self._access_count = 0

    def compute_normalization_stats(self):
        """
        Compute mean and std for each target metric from training data.

        IMPORTANT: Only call this on the training dataset.
        Validation and test datasets will use these same statistics.
        """
        print("\n" + "="*70)
        print("Computing target normalization statistics")
        print("="*70)
        print("Sampling up to 1000 examples to compute statistics...")

        all_targets = {
            'qpd': [],
            'answer_rate': [],
            'retention': []
        }

        # Sample up to 1000 examples for statistics
        sample_size = min(1000, len(self.samples))
        sample_indices = np.random.choice(len(self.samples), sample_size, replace=False)

        for idx in tqdm(sample_indices, desc='Sampling targets'):
            sample = self.samples[idx]
            target_graph = self._load_graph(sample['community'], sample['target_month'])
            targets = target_graph.y

            for key in all_targets:
                if key in targets:
                    all_targets[key].append(float(targets[key]))

        # Compute mean and std for each metric
        self.norm_stats = {}
        for key in all_targets:
            if len(all_targets[key]) > 0:
                values = np.array(all_targets[key])
                self.norm_stats[key] = {
                    'mean': float(values.mean()),
                    'std': float(values.std() + 1e-8)  # Add small constant to avoid division by zero
                }

        print("\nNormalization statistics (computed from training data):")
        print(f"{'Metric':<15} {'Mean':<12} {'Std':<12} {'Min':<12} {'Max':<12}")
        print("-" * 63)

        for key in all_targets:
            if key in self.norm_stats:
                values = np.array(all_targets[key])
                print(f"{key:<15} {self.norm_stats[key]['mean']:<12.4f} "
                      f"{self.norm_stats[key]['std']:<12.4f} "
                      f"{values.min():<12.4f} {values.max():<12.4f}")

        print("="*70)
        print("  Normalization statistics computed")
        print("  These will be used to standardize all targets during training")
        print("  Val/test datasets will use these SAME statistics (no data leakage)")
        print("="*70 + "\n")

        return self.norm_stats

    def _build_sample_index(self) -> List[Dict]:
        """Build index of all valid temporal sequences."""
        samples = []
        start_month, end_month = self.split_ranges[self.split]
        min_graphs = self.sequence_length + self.prediction_horizon

        for community_dir in sorted(self.graph_dir.iterdir()):
            if not community_dir.is_dir():
                continue

            available_months = sorted([f.stem for f in community_dir.glob('*.pt')])
            if len(available_months) < min_graphs:
                continue

            for i, month_t in enumerate(available_months):
                # Check if month is in split range
                if not (start_month <= month_t <= end_month):
                    continue
                if i < self.sequence_length - 1:
                    continue

                target_idx = i + self.prediction_horizon
                if target_idx >= len(available_months):
                    continue

                seq_start = i - self.sequence_length + 1
                seq_months = available_months[seq_start:i+1]
                target_month = available_months[target_idx]

                # Check temporal consistency
                if self._is_consecutive(seq_months, target_month):
                    samples.append({
                        'community': community_dir.name,
                        'sequence_months': seq_months,
                        'target_month': target_month
                    })

        return samples

    def _is_consecutive(self, seq_months, target_month):
        """Check if months form consecutive sequence."""
        try:
            dates = [datetime.strptime(m, '%Y-%m') for m in seq_months]
            for i in range(1, len(dates)):
                if dates[i] != dates[i-1] + relativedelta(months=1):
                    return False
            target_date = datetime.strptime(target_month, '%Y-%m')
            expected = dates[-1] + relativedelta(months=self.prediction_horizon)
            return target_date == expected
        except:
            return False

    def _stratified_sample(self, n_samples: int) -> List[Dict]:
        """
        Create stratified sample across communities.
        Ensures representative distribution across small/medium/large communities.
        """
        # Group by community
        by_community = defaultdict(list)
        for sample in self.samples:
            by_community[sample['community']].append(sample)

        # Sample proportionally from each community
        samples_per_community = max(1, n_samples // len(by_community))
        selected = []

        for community, comm_samples in by_community.items():
            n = min(samples_per_community, len(comm_samples))
            selected.extend(random.sample(comm_samples, n))

        # Trim to exact size and shuffle
        random.shuffle(selected)
        return selected[:n_samples]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        community = sample['community']

        # Load sequence using cached loader
        graphs = [
            self._load_graph(community, month)
            for month in sample['sequence_months']
        ]

        target_graph = self._load_graph(community, sample['target_month'])
        targets = target_graph.y

        # Apply standardization if statistics are available
        if self.norm_stats is not None:
            targets_standardized = {}
            for key in targets:
                if key in self.norm_stats:
                    # Z-score normalization: (x - mean) / std
                    targets_standardized[key] = (
                        (targets[key] - self.norm_stats[key]['mean']) /
                        self.norm_stats[key]['std']
                    )
                else:
                    targets_standardized[key] = targets[key]
            targets = targets_standardized

        # Periodic cache statistics (every 1000 accesses)
        self._access_count += 1
        if self._access_count % 1000 == 0:
            cache_info = self._load_graph.cache_info()
            if cache_info.hits + cache_info.misses > 0:
                hit_rate = cache_info.hits / (cache_info.hits + cache_info.misses)
                print(f"  Cache: {cache_info.currsize}/{cache_info.maxsize} graphs | Hit rate: {hit_rate:.1%}")

        return graphs, targets

    def get_cache_stats(self):
        """Get current cache statistics."""
        return self._load_graph.cache_info()


In [8]:
def collate_fn(batch):
    """Collate function for batching."""
    batch_graphs = []
    batch_targets = {'qpd': [], 'answer_rate': [], 'retention': []}

    for graphs, targets in batch:
        batch_graphs.append(graphs)
        for key in batch_targets:
            batch_targets[key].append(targets[key])

    for key in batch_targets:
        batch_targets[key] = torch.tensor(batch_targets[key], dtype=torch.float32)

    return batch_graphs, batch_targets

## 4. Load Saved Dataset Configuration

**Configuration:**
- 2,000 training samples (stratified across communities)
- 500 validation samples
- Starts from 2014 (skips 2008-2013 noisy data)

In [9]:
# Dataset imports
import pickle

config_path = f"{PROJECT_DIR}/results/baseline_config.pkl"
with open(config_path, 'rb') as f:
    config = pickle.load(f)

print("="*70)
print("Loaded Temporal GNN Training Configuration")
print("="*70)
print(f"Training samples:   {len(config['train_samples'])}")
print(f"Validation samples: {len(config['val_samples'])}")
print(f"Normalization stats: {list(config['norm_stats'].keys())}")
print("="*70)

Loaded Temporal GNN Training Configuration
Training samples:   1834
Validation samples: 338
Normalization stats: ['qpd', 'answer_rate', 'retention']


In [10]:
# Create datasets
train_dataset = FastTemporalDataset(
    graph_dir=GRAPH_DIR,
    split='train',
    max_samples=None,
    stratified_sample=False
)

val_dataset = FastTemporalDataset(
    graph_dir=GRAPH_DIR,
    split='val',
    max_samples=None,
    stratified_sample=False
)

# Use same data
train_dataset.samples = config['train_samples']
val_dataset.samples = config['val_samples']
train_dataset.norm_stats = config['norm_stats']
val_dataset.norm_stats = config['norm_stats']

print("\nDatasets configured with identical samples and normalization")
print(f"Train dataset size: {len(train_dataset.samples)}")
print(f"Val dataset size: {len(val_dataset.samples)}")

Building train sample index...
  Found 11020 potential samples
TRAIN Dataset: 11020 samples
Building val sample index...
  Found 4471 potential samples
VAL Dataset: 4471 samples

Datasets configured with identical samples and normalization
Train dataset size: 1834
Val dataset size: 338


In [11]:
# Inspect a sample to verify data structure
if len(train_dataset) > 0:
    print("\nInspecting first sample...")
    sample_graphs, sample_targets = train_dataset[0]
    print(f"  Sequence length: {len(sample_graphs)} monthly graphs")

    g = sample_graphs[0]
    print(f"\n  Graph structure:")
    print(f"    Users: {g['user'].x.shape}")
    print(f"    Tags: {g['tag'].x.shape}")

    print(f"\n  Target metrics:")
    for key, value in sample_targets.items():
        print(f"    {key}: {value:.4f}")


Inspecting first sample...
  Sequence length: 12 monthly graphs

  Graph structure:
    Users: torch.Size([228, 5])
    Tags: torch.Size([283, 7])

  Target metrics:
    qpd: -0.2391
    answer_rate: -1.8552
    retention: -1.6198
    growth: -0.3421


In [12]:
# Create DataLoaders with optimized settings
BATCH_SIZE = 64

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=1,  # Parallel loading
    pin_memory=True,  # Faster GPU transfer
    persistent_workers=True  # Keep workers alive
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=collate_fn,
    num_workers=1,
    pin_memory=True,
    persistent_workers=True
)

print(f"  DataLoaders created")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Train batches: {len(train_loader)}")
print(f"  Val batches: {len(val_loader)}")

  DataLoaders created
  Batch size: 64
  Train batches: 29
  Val batches: 6


## 5. Create Model

Using the model imported from `baseline_gnn.py`

In [13]:
# Get feature dimensions from data
if len(train_dataset) > 0:
    sample_graphs, _ = train_dataset[0]
    USER_FEAT_DIM = sample_graphs[0]['user'].x.shape[1]
    TAG_FEAT_DIM = sample_graphs[0]['tag'].x.shape[1]
else:
    USER_FEAT_DIM = 5
    TAG_FEAT_DIM = 7

print(f"Feature dimensions:")
print(f"  User features: {USER_FEAT_DIM}")
print(f"  Tag features: {TAG_FEAT_DIM}")

Feature dimensions:
  User features: 5
  Tag features: 7


In [14]:
# Model configuration (optimized for 2K samples)
model = create_model(
    user_feat_dim=USER_FEAT_DIM,
    tag_feat_dim=TAG_FEAT_DIM,
    hidden_dim=64,           # Reduced for initial training
    num_conv_layers=2,       # Reduced to prevent overfitting
    dropout=0.3,             # Higher for regularization with limited data
    batched=True
)

model = model.to(device)

# Count parameters
n_params = sum(p.numel() for p in model.parameters())
n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"\nModel created and moved to {device}")
print(f"\nModel Architecture:")
print(f"  Hidden dimension: 64")
print(f"  Graph conv layers: 2")
print(f"  Transformer layers: 2")
print(f"  Dropout: 0.3")
print(f"\nParameters:")
print(f"  Total: {n_params:,}")
print(f"  Trainable: {n_trainable:,}")


Model created and moved to cuda

Model Architecture:
  Hidden dimension: 64
  Graph conv layers: 2
  Transformer layers: 2
  Dropout: 0.3

Parameters:
  Total: 50,843
  Trainable: 50,843


## 6. Training Setup

In [15]:
def move_to_device(batch_graphs, device):
    """Move graphs to device efficiently."""
    return [[g.to(device, non_blocking=True) for g in graphs] for graphs in batch_graphs]

def train_epoch(model, loader, optimizer, criterion, device, scaler, use_amp, scheduler=None, max_grad_norm=1.0):
    """Train for one epoch."""
    model.train()
    total_loss = 0.0
    n_batches = 0

    amp_dtype = torch.bfloat16 if USE_BF16 else torch.float16

    pbar = tqdm(loader, desc='Training')
    for batch_graphs, batch_targets in pbar:
        batch_graphs = move_to_device(batch_graphs, device)
        batch_targets = {k: v.to(device, non_blocking=True) for k, v in batch_targets.items()}

        optimizer.zero_grad(set_to_none=True)

        with torch.amp.autocast(device_type='cuda', enabled=use_amp, dtype=amp_dtype):
            predictions = model(batch_graphs)

            # Simple loss (targets are already standardized!)
            loss = (
                criterion(predictions['qpd'], batch_targets['qpd']) +
                criterion(predictions['answer_rate'], batch_targets['answer_rate']) +
                criterion(predictions['retention'], batch_targets['retention'])
            ) / 3.0

        if use_amp and not USE_BF16:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()

        if scheduler is not None:
            scheduler.step()

        total_loss += loss.item()
        n_batches += 1
        pbar.set_postfix({'loss': f'{loss.item():.4f}'})

    return total_loss / max(n_batches, 1)


@torch.no_grad()
def evaluate(model, loader, criterion, device, use_amp, norm_stats):
    """
    Evaluate model with INVERSE TRANSFORMATION.

    IMPORTANT: Metrics are computed on ORIGINAL scale (after denormalization).
    This is standard practice - allows interpretation and comparison to baselines.
    """
    model.eval()
    total_loss = 0.0  # Loss on standardized scale
    n_batches = 0

    all_preds = {'qpd': [], 'answer_rate': [], 'retention': []}
    all_targets = {'qpd': [], 'answer_rate': [], 'retention': []}

    amp_dtype = torch.bfloat16 if USE_BF16 else torch.float16

    for batch_graphs, batch_targets in tqdm(loader, desc='Evaluating', leave=False):
        batch_graphs = move_to_device(batch_graphs, device)
        batch_targets = {k: v.to(device, non_blocking=True) for k, v in batch_targets.items()}

        with torch.amp.autocast('cuda', enabled=use_amp, dtype=amp_dtype):
            predictions = model(batch_graphs)

            # Compute loss on standardized scale
            loss = (
                criterion(predictions['qpd'], batch_targets['qpd']) +
                criterion(predictions['answer_rate'], batch_targets['answer_rate']) +
                criterion(predictions['retention'], batch_targets['retention'])
            ) / 3.0

        total_loss += loss.item()
        n_batches += 1

        # INVERSE TRANSFORM back to original scale for metrics
        for key in predictions:
            if key in norm_stats:
                # Inverse z-score: x = z * std + mean
                preds_original = (
                    predictions[key].float().cpu().numpy() * norm_stats[key]['std'] +
                    norm_stats[key]['mean']
                )
                targets_original = (
                    batch_targets[key].float().cpu().numpy() * norm_stats[key]['std'] +
                    norm_stats[key]['mean']
                )
            else:
                preds_original = predictions[key].float().cpu().numpy()
                targets_original = batch_targets[key].float().cpu().numpy()

            all_preds[key].extend(preds_original)
            all_targets[key].extend(targets_original)

    # Compute metrics on ORIGINAL scale (after inverse transform)
    metrics = {}
    for key in all_preds:
        preds = np.array(all_preds[key])
        targets = np.array(all_targets[key])

        # MAE, R², RMSE on original scale
        mae = np.mean(np.abs(preds - targets))
        ss_res = np.sum((targets - preds) ** 2)
        ss_tot = np.sum((targets - targets.mean()) ** 2)
        r2 = 1 - ss_res / (ss_tot + 1e-8)
        rmse = np.sqrt(np.mean((targets - preds) ** 2))

        metrics[key] = {'mae': mae, 'r2': r2, 'rmse': rmse}

    return total_loss / max(n_batches, 1), metrics

print("Training functions defined")

Training functions defined


In [16]:
# Training configuration
LEARNING_RATE = 1e-3
NUM_EPOCHS = 20  # Should be sufficient with 2K samples
PATIENCE = 5     # Early stopping patience
USE_AMP = torch.cuda.is_available()

optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=5e-4)
criterion = nn.MSELoss()
scaler = torch.amp.GradScaler(enabled=USE_AMP and not USE_BF16)

# Learning rate scheduler
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)

print("Training Configuration:")
print(f"  Device: {device}")
print(f"  AMP: {USE_AMP}")
print(f"  Dtype: {'bfloat16' if USE_BF16 else 'float16' if USE_AMP else 'float32'}")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Learning rate: {LEARNING_RATE}")
print(f"  Epochs: {NUM_EPOCHS}")
print(f"  Early stopping patience: {PATIENCE}")

Training Configuration:
  Device: cuda
  AMP: True
  Dtype: bfloat16
  Batch size: 64
  Learning rate: 0.001
  Epochs: 20
  Early stopping patience: 5


## 7. Quick Test (Single Batch)

Verify everything works before full training

In [17]:
# ============================================================
# PERFORMANCE BENCHMARK - A100 GPU
# ============================================================

import time
import numpy as np

print("="*70)
print("PERFORMANCE BENCHMARK - Testing on A100 GPU")
print("="*70)

if len(train_loader) == 0:
    print("No data in train_loader")
else:
    # Get test batch
    batch_graphs, batch_targets = next(iter(train_loader))
    batch_graphs = move_to_device(batch_graphs, device)
    batch_targets = {k: v.to(device, non_blocking=True) for k, v in batch_targets.items()}

    amp_dtype = torch.bfloat16 if USE_BF16 else torch.float16

    # --------------------------------------------------------
    # 1. WARMUP (important for accurate timing!)
    # --------------------------------------------------------
    print("\n1. Warming up GPU (JIT compilation, caching)...")
    for _ in range(3):
        with torch.amp.autocast('cuda', enabled=USE_AMP, dtype=amp_dtype):
            preds = model(batch_graphs)
            loss = (
                criterion(preds['qpd'], batch_targets['qpd']) +
                criterion(preds['answer_rate'], batch_targets['answer_rate']) +
                criterion(preds['retention'], batch_targets['retention'])
            ) / 3.0
        loss.backward()
        optimizer.zero_grad()

    torch.cuda.synchronize()
    print("Warmup complete")

    # --------------------------------------------------------
    # 2. FORWARD PASS ONLY (inference speed)
    # --------------------------------------------------------
    print("\n2. Testing forward pass (inference)...")
    forward_times = []

    for i in range(5):
        torch.cuda.synchronize()
        start = time.time()

        with torch.amp.autocast('cuda', enabled=USE_AMP, dtype=amp_dtype):
            preds = model(batch_graphs)
            loss = (
                criterion(preds['qpd'], batch_targets['qpd']) +
                criterion(preds['answer_rate'], batch_targets['answer_rate']) +
                criterion(preds['retention'], batch_targets['retention'])
            ) / 3.0

        torch.cuda.synchronize()
        elapsed = time.time() - start
        forward_times.append(elapsed)
        print(f"   Run {i+1}: {elapsed:.3f}s")

    avg_forward = np.mean(forward_times)
    std_forward = np.std(forward_times)

    # --------------------------------------------------------
    # 3. FORWARD + BACKWARD PASS (training speed)
    # --------------------------------------------------------
    print("\n3. Testing forward + backward pass (training)...")
    train_times = []

    for i in range(5):
        optimizer.zero_grad()

        torch.cuda.synchronize()
        start = time.time()

        with torch.amp.autocast('cuda', enabled=USE_AMP, dtype=amp_dtype):
            preds = model(batch_graphs)
            loss = (
                criterion(preds['qpd'], batch_targets['qpd']) +
                criterion(preds['answer_rate'], batch_targets['answer_rate']) +
                criterion(preds['retention'], batch_targets['retention'])
            ) / 3.0

        if USE_AMP:
            scaler.scale(loss).backward()
        else:
            loss.backward()

        torch.cuda.synchronize()
        elapsed = time.time() - start
        train_times.append(elapsed)
        print(f"   Run {i+1}: {elapsed:.3f}s (loss: {loss.item():.4f})")

    avg_train = np.mean(train_times)
    std_train = np.std(train_times)

    # --------------------------------------------------------
    # 4. RESULTS SUMMARY
    # --------------------------------------------------------
    print("\n" + "="*70)
    print("RESULTS")
    print("="*70)

    print(f"\nForward pass only:")
    print(f"  Average: {avg_forward:.3f}s ± {std_forward:.3f}s")

    print(f"\nForward + Backward (training):")
    print(f"  Average: {avg_train:.3f}s ± {std_train:.3f}s")

    print(f"\nBatch details:")
    print(f"  Batch size: {batch_targets['qpd'].shape[0]}")
    print(f"  Samples per batch: {batch_targets['qpd'].shape[0]}")
    print(f"  Using bfloat16: {USE_BF16}")
    print(f"  Using AMP: {USE_AMP}")

    # --------------------------------------------------------
    # 5. TIME PROJECTIONS
    # --------------------------------------------------------
    batches_per_epoch = len(train_loader)
    epoch_time_sec = avg_train * batches_per_epoch
    epoch_time_min = epoch_time_sec / 60
    total_time_min = epoch_time_min * NUM_EPOCHS
    total_time_hr = total_time_min / 60

    print(f"\n" + "="*70)
    print("TIME PROJECTIONS")
    print("="*70)
    print(f"\nConfiguration:")
    print(f"  Training samples: {len(train_dataset)}")
    print(f"  Batch size: {batch_targets['qpd'].shape[0]}")
    print(f"  Batches per epoch: {batches_per_epoch}")
    print(f"  Number of epochs: {NUM_EPOCHS}")

    print(f"\nEstimated times:")
    print(f"  Per epoch: {epoch_time_min:.1f} minutes ({epoch_time_sec:.0f} seconds)")
    print(f"  Total ({NUM_EPOCHS} epochs): {total_time_min:.1f} minutes ({total_time_hr:.2f} hours)")

    # Add validation time estimate (roughly 25% of training time)
    val_time_min = (total_time_min * 0.25)
    total_with_val = total_time_min + val_time_min

    print(f"  + Validation time: ~{val_time_min:.1f} minutes")
    print(f"  = Total training time: ~{total_with_val:.1f} minutes ({total_with_val/60:.2f} hours)")

    # --------------------------------------------------------
    # 6. PERFORMANCE ASSESSMENT
    # --------------------------------------------------------
    print(f"\n" + "="*70)
    print("ASSESSMENT")
    print("="*70)

    print(f"  Expected total time: {total_with_val:.0f} minutes ({total_with_val/60:.1f} hours)")

    # --------------------------------------------------------
    # 7. GPU MEMORY USAGE
    # --------------------------------------------------------
    print(f"\n" + "="*70)
    print("GPU MEMORY")
    print("="*70)

    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / 1e9
        reserved = torch.cuda.memory_reserved() / 1e9
        max_allocated = torch.cuda.max_memory_allocated() / 1e9

        print(f"\nCurrent usage:")
        print(f"  Allocated: {allocated:.2f} GB")
        print(f"  Reserved: {reserved:.2f} GB")
        print(f"  Peak: {max_allocated:.2f} GB")

        # A100 has 40GB or 80GB
        gpu_name = torch.cuda.get_device_name(0)
        if '80GB' in gpu_name:
            total_memory = 80
        else:
            total_memory = 40

        usage_pct = (max_allocated / total_memory) * 100
        print(f"  Total GPU memory: {total_memory} GB")
        print(f"  Usage: {usage_pct:.1f}%")

        torch.cuda.reset_peak_memory_stats()

    print(f"\n" + "="*70)
    print("Benchmark complete! Ready to start training.")
    print("="*70)

    # Clean up
    optimizer.zero_grad()
    del batch_graphs, batch_targets, preds, loss
    torch.cuda.empty_cache()

PERFORMANCE BENCHMARK - Testing on A100 GPU

1. Warming up GPU (JIT compilation, caching)...
Warmup complete

2. Testing forward pass (inference)...
   Run 1: 2.660s
   Run 2: 2.859s
   Run 3: 2.725s
   Run 4: 2.734s
   Run 5: 2.922s

3. Testing forward + backward pass (training)...
   Run 1: 5.025s (loss: 1.1614)
   Run 2: 4.898s (loss: 1.1623)
   Run 3: 4.936s (loss: 1.1623)
   Run 4: 4.906s (loss: 1.1617)
   Run 5: 5.208s (loss: 1.1623)

RESULTS

Forward pass only:
  Average: 2.780s ± 0.096s

Forward + Backward (training):
  Average: 4.995s ± 0.116s

Batch details:
  Batch size: 64
  Samples per batch: 64
  Using bfloat16: True
  Using AMP: True

TIME PROJECTIONS

Configuration:
  Training samples: 1834
  Batch size: 64
  Batches per epoch: 29
  Number of epochs: 20

Estimated times:
  Per epoch: 2.4 minutes (145 seconds)
  Total (20 epochs): 48.3 minutes (0.80 hours)
  + Validation time: ~12.1 minutes
  = Total training time: ~60.4 minutes (1.01 hours)

ASSESSMENT
  Expected total 

## 8. Training Loop

**Expected training time:** ~2 hours for 20 epochs with 2K samples

In [None]:
# Training loop with early stopping - TRACKS ALL METRICS
best_val_loss = float('inf')
patience_counter = 0
train_losses = []
val_losses = []

# Track all metrics for each target variable
val_metrics_history = {
    'qpd': {'mae': [], 'r2': [], 'rmse': []},
    'answer_rate': {'mae': [], 'r2': [], 'rmse': []},
    'retention': {'mae': [], 'r2': [], 'rmse': []}
}

print("="*70)
print("Starting Training")
print("="*70)

for epoch in range(NUM_EPOCHS):
    print(f"\nEpoch {epoch+1}/{NUM_EPOCHS}")
    print("-" * 70)

    # Train
    train_loss = train_epoch(
        model, train_loader, optimizer, criterion,
        device, scaler, USE_AMP, scheduler
    )
    train_losses.append(train_loss)

    # Use the training dataset's normalization stats for evaluation
    norm_stats = train_dataset.norm_stats

    # Validate - returns metrics dict now
    val_loss, metrics = evaluate(
        model, val_loader, criterion, device, USE_AMP, norm_stats
    )
    val_losses.append(val_loss)

    # Store all metrics
    for key in metrics:
        for metric_name in ['mae', 'r2', 'rmse']:
            val_metrics_history[key][metric_name].append(metrics[key][metric_name])

    # Learning rate step
    scheduler.step()

    # Print results
    print(f"\nResults:")
    print(f"  Train Loss: {train_loss:.4f}")
    print(f"  Val Loss:   {val_loss:.4f}")
    print(f"  LR: {optimizer.param_groups[0]['lr']:.6f}")
    print(f"\n  Validation Metrics:")
    print(f"    {'Metric':<15} {'MAE':<10} {'R²':<10} {'RMSE':<10}")
    print(f"    {'-'*45}")
    for key in ['qpd', 'answer_rate', 'retention']:
        if key in metrics:
            print(f"    {key.replace('_', ' ').title():<15} "
                  f"{metrics[key]['mae']:<10.3f} "
                  f"{metrics[key]['r2']:<10.3f} "
                  f"{metrics[key]['rmse']:<10.3f}")

    # Compute mean metrics across all targets
    mean_mae = np.mean([metrics[k]['mae'] for k in metrics])
    mean_r2 = np.mean([metrics[k]['r2'] for k in metrics])
    mean_rmse = np.mean([metrics[k]['rmse'] for k in metrics])
    print(f"    {'Mean':<15} {mean_mae:<10.3f} {mean_r2:<10.3f} {mean_rmse:<10.3f}")

    # Early stopping check (based on validation loss)
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience_counter = 0

        # Save best model
        if IN_COLAB:
            model_path = f"{PROJECT_DIR}/results/baseline_gnn/best_model.pt"
        else:
            model_path = "../results/baseline_gnn/best_model.pt"

        # Create results directory if it doesn't exist
        import os
        os.makedirs(os.path.dirname(model_path), exist_ok=True)

        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': val_loss,
            'metrics': metrics,
            'train_losses': train_losses,
            'val_losses': val_losses,
            'val_metrics_history': val_metrics_history
        }, model_path)

        print(f"    New best model saved (val_loss: {val_loss:.4f})")
    else:
        patience_counter += 1
        print(f"  No improvement ({patience_counter}/{PATIENCE})")

        if patience_counter >= PATIENCE:
            print(f"\nEarly stopping triggered after {epoch+1} epochs")
            break

    # Show cache statistics every 5 epochs
    if (epoch + 1) % 5 == 0:
        cache_info = train_dataset.get_cache_stats()
        if cache_info.hits + cache_info.misses > 0:
            hit_rate = cache_info.hits / (cache_info.hits + cache_info.misses)
            print(f"\n  Dataset Cache: {cache_info.currsize}/{cache_info.maxsize} | Hit rate: {hit_rate:.1%}")

print("\n" + "="*70)
print("Training Complete!")
print("="*70)
print(f"Best validation loss: {best_val_loss:.4f}")

Starting Training

Epoch 1/20
----------------------------------------------------------------------


Training:   0%|          | 0/29 [12:19<?, ?it/s]

## 9. Visualize Results

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import os

# Set style
sns.set_style("whitegrid")
sns.set_context("paper", font_scale=1.2)
plt.rcParams['figure.dpi'] = 500
plt.rcParams['savefig.dpi'] = 500
plt.rcParams['font.family'] = 'sans-serif'

# Create results directory
results_dir = f"{PROJECT_DIR}/results/baseline_gnn" if IN_COLAB else "../results/baseline_gnn"
os.makedirs(results_dir, exist_ok=True)

# Color palette
colors = sns.color_palette("husl", 3)
metric_colors = {'qpd': '#2E86AB', 'answer_rate': '#A23B72', 'retention': '#F18F01'}

# ============================================================
# Figure 1: Training & Validation Loss
# ============================================================
fig, ax = plt.subplots(figsize=(8, 5))

epochs = np.arange(1, len(train_losses) + 1)
ax.plot(epochs, train_losses, marker='o', linewidth=2.5,
        label='Training Loss', color=colors[0], markersize=6)
ax.plot(epochs, val_losses, marker='s', linewidth=2.5,
        label='Validation Loss', color=colors[1], markersize=6)

ax.set_xlabel('Epoch', fontsize=12, fontweight='bold')
ax.set_ylabel('Loss (MSE)', fontsize=12, fontweight='bold')
ax.set_title('Training Progress', fontsize=14, fontweight='bold', pad=15)
ax.legend(frameon=True, loc='best', fontsize=11)
ax.grid(True, alpha=0.3, linestyle='--')

plt.tight_layout()
plt.savefig(f"{results_dir}/loss_curves.png", bbox_inches='tight')
print(f"Saved: {results_dir}/loss_curves.png")
plt.show()

# ============================================================
# Figure 2: R² Performance Over Time
# ============================================================
fig, ax = plt.subplots(figsize=(8, 5))

targets = ['qpd', 'answer_rate', 'retention']
labels = {'qpd': 'Questions Per Day', 'answer_rate': 'Answer Rate', 'retention': 'User Retention'}

for target in targets:
    r2_values = val_metrics_history[target]['r2']
    ax.plot(epochs, r2_values, marker='o', linewidth=2.5,
            label=labels[target], color=metric_colors[target], markersize=6)

ax.axhline(y=0.4, color='gray', linestyle='--', alpha=0.5, linewidth=2, label='Target (0.4)')
ax.set_xlabel('Epoch', fontsize=12, fontweight='bold')
ax.set_ylabel('R² Score', fontsize=12, fontweight='bold')
ax.set_title('R² Performance by Metric', fontsize=14, fontweight='bold', pad=15)
ax.legend(frameon=True, loc='best', fontsize=10)
ax.grid(True, alpha=0.3, linestyle='--')
ax.set_ylim(bottom=0)

plt.tight_layout()
plt.savefig(f"{results_dir}/r2_over_time.png", bbox_inches='tight')
print(f"Saved: {results_dir}/r2_over_time.png")
plt.show()

# ============================================================
# Figure 3: Final Performance Comparison
# ============================================================
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# Prepare data
metric_names = [labels[t] for t in targets]
final_r2 = [val_metrics_history[t]['r2'][-1] for t in targets]
final_mae = [val_metrics_history[t]['mae'][-1] for t in targets]

# R² comparison
ax1 = axes[0]
bars = ax1.barh(metric_names, final_r2, color=[metric_colors[t] for t in targets], alpha=0.8)
ax1.axvline(x=0.4, color='gray', linestyle='--', linewidth=2, alpha=0.5, label='Target')
ax1.set_xlabel('R² Score', fontsize=12, fontweight='bold')
ax1.set_title('R² Performance', fontsize=13, fontweight='bold', pad=10)
ax1.set_xlim(0, max(final_r2) * 1.2)
ax1.grid(True, alpha=0.3, linestyle='--', axis='x')

# Add value labels
for i, (bar, val) in enumerate(zip(bars, final_r2)):
    ax1.text(val + 0.01, bar.get_y() + bar.get_height()/2,
             f'{val:.3f}', va='center', fontsize=11, fontweight='bold')

# MAE comparison
ax2 = axes[1]
bars = ax2.barh(metric_names, final_mae, color=[metric_colors[t] for t in targets], alpha=0.8)
ax2.set_xlabel('Mean Absolute Error', fontsize=12, fontweight='bold')
ax2.set_title('MAE Performance', fontsize=13, fontweight='bold', pad=10)
ax2.grid(True, alpha=0.3, linestyle='--', axis='x')
ax2.invert_xaxis()  # Lower is better

# Add value labels
for i, (bar, val) in enumerate(zip(bars, final_mae)):
    ax2.text(val - 0.01, bar.get_y() + bar.get_height()/2,
             f'{val:.3f}', va='center', ha='right', fontsize=11, fontweight='bold')

plt.tight_layout()
plt.savefig(f"{results_dir}/final_performance.png", bbox_inches='tight')
print(f"Saved: {results_dir}/final_performance.png")
plt.show()

# ============================================================
# Figure 4: Detailed Metric Evolution (3 subplots)
# ============================================================
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

for idx, (target, ax) in enumerate(zip(targets, axes)):
    # Plot MAE and RMSE
    mae_values = val_metrics_history[target]['mae']
    rmse_values = val_metrics_history[target]['rmse']
    r2_values = val_metrics_history[target]['r2']

    color = metric_colors[target]

    # Primary axis: MAE and RMSE
    ax.plot(epochs, mae_values, marker='o', linewidth=2,
            label='MAE', color=color, markersize=5)
    ax.plot(epochs, rmse_values, marker='s', linewidth=2,
            label='RMSE', color=color, alpha=0.6, markersize=5, linestyle='--')

    ax.set_xlabel('Epoch', fontsize=11, fontweight='bold')
    ax.set_ylabel('Error', fontsize=11, fontweight='bold')
    ax.set_title(labels[target], fontsize=12, fontweight='bold', pad=10)
    ax.grid(True, alpha=0.3, linestyle='--')
    ax.legend(loc='upper left', fontsize=9)

    # Secondary axis: R²
    ax2 = ax.twinx()
    ax2.plot(epochs, r2_values, marker='^', linewidth=2,
             label='R²', color='green', alpha=0.7, markersize=5)
    ax2.axhline(y=0.4, color='gray', linestyle=':', alpha=0.5)
    ax2.set_ylabel('R²', fontsize=11, fontweight='bold', color='green')
    ax2.tick_params(axis='y', labelcolor='green')
    ax2.legend(loc='upper right', fontsize=9)

plt.tight_layout()
plt.savefig(f"{results_dir}/metric_evolution.png", bbox_inches='tight')
print(f"Saved: {results_dir}/metric_evolution.png")
plt.show()

# ============================================================
# Figure 5: Results Summary Table (as image)
# ============================================================
fig, ax = plt.subplots(figsize=(10, 4))
ax.axis('tight')
ax.axis('off')

# Prepare table data
table_data = []
table_data.append(['Metric', 'MAE ↓', 'RMSE ↓', 'R² ↑', 'Normalized MAE'])

norm_stats = train_dataset.norm_stats
for target in targets:
    final_mae = val_metrics_history[target]['mae'][-1]
    final_rmse = val_metrics_history[target]['rmse'][-1]
    final_r2 = val_metrics_history[target]['r2'][-1]
    norm_mae = final_mae / norm_stats[target]['std']

    table_data.append([
        labels[target],
        f'{final_mae:.4f}',
        f'{final_rmse:.4f}',
        f'{final_r2:.4f}',
        f'{norm_mae:.4f}'
    ])

# Add mean row
mean_mae = np.mean([val_metrics_history[t]['mae'][-1] for t in targets])
mean_rmse = np.mean([val_metrics_history[t]['rmse'][-1] for t in targets])
mean_r2 = np.mean([val_metrics_history[t]['r2'][-1] for t in targets])
mean_norm_mae = np.mean([val_metrics_history[t]['mae'][-1] / norm_stats[t]['std'] for t in targets])

table_data.append([
    'MEAN',
    f'{mean_mae:.4f}',
    f'{mean_rmse:.4f}',
    f'{mean_r2:.4f}',
    f'{mean_norm_mae:.4f}'
])

# Create table
table = ax.table(cellText=table_data, cellLoc='center', loc='center',
                colWidths=[0.25, 0.15, 0.15, 0.15, 0.2])
table.auto_set_font_size(False)
table.set_fontsize(11)
table.scale(1, 2)

# Style header row
for i in range(5):
    cell = table[(0, i)]
    cell.set_facecolor('#2E86AB')
    cell.set_text_props(weight='bold', color='white')

# Style mean row
for i in range(5):
    cell = table[(len(table_data)-1, i)]
    cell.set_facecolor('#E8E8E8')
    cell.set_text_props(weight='bold')

# Alternate row colors
for i in range(1, len(table_data)-1):
    for j in range(5):
        if i % 2 == 0:
            table[(i, j)].set_facecolor('#F5F5F5')

plt.title('Final Validation Results', fontsize=14, fontweight='bold', pad=20)
plt.savefig(f"{results_dir}/results_table.png", bbox_inches='tight')
print(f"Saved: {results_dir}/results_table.png")
plt.show()

# ============================================================
# Print Summary
# ============================================================
print("\n" + "="*70)
print("VISUALIZATION SUMMARY")
print("="*70)
print(f"\nAll figures saved to: {results_dir}/")
print("\nFiles created:")
print("  1. loss_curves.png       - Training/validation loss")
print("  2. r2_over_time.png      - R² evolution by metric")
print("  3. final_performance.png - Final R² and MAE comparison")
print("  4. metric_evolution.png  - Detailed metric tracking")
print("  5. results_table.png     - Summary table")
print("\n" + "="*70)

# ============================================================
# Detailed Text Summary
# ============================================================
print("\nFINAL RESULTS SUMMARY")
print("="*70)
print(f"\nBest Validation Loss: {best_val_loss:.4f}")
print(f"Achieved at Epoch: {checkpoint['epoch'] if 'checkpoint' in globals() else 'N/A'}\n")

print("Per-Metric Performance:")
print(f"{'Metric':<20} {'MAE ↓':<12} {'RMSE ↓':<12} {'R² ↑':<12} {'Norm MAE ↓':<12}")
print("-" * 68)

for target in targets:
    final_mae = val_metrics_history[target]['mae'][-1]
    final_rmse = val_metrics_history[target]['rmse'][-1]
    final_r2 = val_metrics_history[target]['r2'][-1]
    norm_mae = final_mae / norm_stats[target]['std']

    print(f"{labels[target]:<20} {final_mae:<12.4f} {final_rmse:<12.4f} "
          f"{final_r2:<12.4f} {norm_mae:<12.4f}")

print("-" * 68)
print(f"{'MEAN':<20} {mean_mae:<12.4f} {mean_rmse:<12.4f} "
      f"{mean_r2:<12.4f} {mean_norm_mae:<12.4f}")

print("\n" + "="*70)
print("INTERPRETATION")
print("="*70)

if mean_r2 > 0.5:
    print("✓ Strong predictive performance (R² > 0.5)")
elif mean_r2 > 0.4:
    print("✓ Good predictive performance (R² > 0.4)")
elif mean_r2 > 0.3:
    print("⚠ Moderate performance (R² > 0.3)")
    print("  Consider: Temporal distribution shift (COVID-19 period in validation)")
else:
    print("⚠ Weak performance (R² < 0.3)")

print("\nPractical Interpretation:")
print(f"  • Questions/Day:  ±{val_metrics_history['qpd']['mae'][-1]:.2f} questions")
print(f"  • Answer Rate:    ±{val_metrics_history['answer_rate']['mae'][-1]:.1%}")
print(f"  • User Retention: ±{val_metrics_history['retention']['mae'][-1]:.1%}")

print("\nNormalized MAE (lower is better):")
for target in targets:
    norm_mae = val_metrics_history[target]['mae'][-1] / norm_stats[target]['std']
    if norm_mae < 0.5:
        status = "✓ Excellent"
    elif norm_mae < 1.0:
        status = "✓ Good"
    else:
        status = "⚠ Needs improvement"
    print(f"  • {labels[target]:<17}: {norm_mae:.3f} {status}")

print("\n" + "="*70)

## 10. Save GNN Results

In [None]:
# Save GNN results
gnn_results = {
    'val_loss': best_val_loss,
    'metrics': metrics,  # Final validation metrics
    'train_losses': train_losses,
    'val_losses': val_losses,
    'val_metrics_history': val_metrics_history
}

with open(f"{results_dir}/baseline_gnn/gnn_results.pkl", 'wb') as f:
    pickle.dump(gnn_results, f)