# GNN Training Dataset

- Build a stratified training set of up to 2,000 samples (plus up to 500 validation samples)
- Lazy loading with LRU cache
- Works both locally and on Colab

## 1. Setup Environment

In [None]:
# Check if running on Colab
import sys
IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    print("Running on Google Colab")
    !pip install torch-geometric
    !pip install python-dateutil
else:
    print("Running locally")

Running on Google Colab
Collecting torch-geometric
  Downloading torch_geometric-2.7.0-py3-none-any.whl.metadata (63 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.7/63.7 kB[0m [31m4.0 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.7.0-py3-none-any.whl (1.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m43.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch-geometric
Successfully installed torch-geometric-2.7.0


In [None]:
# Check GPU availability and choose the mixed-precision mode
import torch

# CPU defaults: no autocast, no BF16
device = torch.device('cpu')
USE_AMP = False
USE_BF16 = False

if torch.cuda.is_available():
    device = torch.device('cuda')
    USE_AMP = True  # AMP (autocast) is only enabled on CUDA
    gpu_name = torch.cuda.get_device_name(0).lower()
    gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1e9
    major, minor = torch.cuda.get_device_capability(0)

    print(f"GPU available: {gpu_name}")
    print(f"GPU memory: {gpu_mem:.2f} GB")
    print(f"Compute capability: {major}.{minor}")

    # BF16 is natively supported from compute capability 8.0 (Ampere) onward.
    # Checking capability instead of matching GPU names keeps A100/H100/B200 on
    # BF16 as before, and also covers other Ampere-or-newer parts (e.g. A10, L4,
    # RTX 30xx/40xx) that the name list would have sent to FP16.
    if major >= 8:
        print("Ampere-or-newer GPU detected → using BF16")
        USE_BF16 = True
    else:
        print("Pre-Ampere CUDA GPU detected → using FP16 (AMP)")
else:
    print("No GPU available → using CPU")

print("\nFinal Precision Settings")
print(f"Using device: {device}")
print(f"AMP enabled:  {USE_AMP}")
print(f"BF16 enabled: {USE_BF16}")

GPU available: nvidia a100-sxm4-80gb
GPU memory: 85.17 GB
Detected Hopper/Ampere GPU → using BF16

Final Precision Settings
Using device: cuda
AMP enabled:  True
BF16 enabled: True


In [None]:
# Resolve where the project lives for the current runtime
if IN_COLAB:
    # On Colab the project is read from Google Drive, so mount it first
    from google.colab import drive
    drive.mount('/content/drive')
    PROJECT_DIR = '/content/drive/MyDrive/dont-thread-on-me'
else:
    # Locally the project root is assumed to be the notebook's parent directory
    PROJECT_DIR = '..'

# Processed graphs use the same layout under the project root in both environments
GRAPH_DIR = f'{PROJECT_DIR}/data/processed/graphs'

# Make project-local modules importable
if PROJECT_DIR not in sys.path:
    sys.path.append(PROJECT_DIR)

print("Google Drive mounted" if IN_COLAB else "Running locally")
print(f"Project directory: {PROJECT_DIR}")
print(f"Graph directory: {GRAPH_DIR}")

Mounted at /content/drive
Google Drive mounted
Project directory: /content/drive/MyDrive/dont-thread-on-me
Graph directory: /content/drive/MyDrive/dont-thread-on-me/data/processed/graphs


In [None]:
# Verify the graph directory exists before building any datasets
from pathlib import Path

graph_dir = Path(GRAPH_DIR)

if not graph_dir.exists():
    # Fail fast: every later cell needs these graphs. Continuing would only
    # surface as a less obvious error when the dataset scans the directory.
    raise FileNotFoundError(
        f"Graph directory not found: {graph_dir}\n"
        f"Please upload your processed graphs to:\n  {GRAPH_DIR}"
    )

communities = sorted(d.name for d in graph_dir.iterdir() if d.is_dir())
print(f"Found {len(communities)} communities")

# Show the first few communities with their number of monthly graph files
for comm in communities[:5]:
    n_graphs = sum(1 for _ in (graph_dir / comm).glob('*.pt'))
    print(f"  - {comm}: {n_graphs} monthly graphs")

if len(communities) > 5:
    print(f"  ... and {len(communities) - 5} more")

Found 177 communities
  - 3dprinting.stackexchange.com: 98 monthly graphs
  - academia.stackexchange.com: 147 monthly graphs
  - ai.stackexchange.com: 91 monthly graphs
  - android.stackexchange.com: 174 monthly graphs
  - anime.stackexchange.com: 135 monthly graphs
  ... and 172 more


## 2. Define Optimized Dataset

- **No pre-caching** - starts training immediately
- **LRU cache** - automatically caches frequently accessed graphs
- **Stratified sampling** - representative samples across communities
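- **2014 start date** - prediction anchors (the last input month) start in January 2014 to skip the noisy early Stack Exchange years; input windows of the earliest samples still reach back into 2013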

In [None]:
# Dataset imports
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
from typing import List, Dict
from datetime import datetime
from dateutil.relativedelta import relativedelta
from functools import lru_cache
from collections import defaultdict
import random
import numpy as np
from tqdm.auto import tqdm

In [None]:
class FastTemporalDataset(Dataset):
    """
    Temporal graph-sequence dataset with lazy loading, an LRU graph cache and
    target normalization.

    A sample is ``sequence_length`` consecutive monthly graphs of one community,
    plus the graph ``prediction_horizon`` months after the last input month.
    The target graph's ``.y`` holds the prediction targets.

    - No pre-caching step: graphs are loaded from disk on first access
    - LRU cache keeps up to ``cache_size`` loaded monthly graphs in memory
    - Optional community-balanced (stratified) sub-sampling
    - ``split_ranges`` bound the *last input month* of each sample. Input windows
      of the earliest train samples therefore reach back into 2013.
    - Target normalization: z-scores targets with the statistics from
      ``compute_normalization_stats`` (computed on train, shared with val/test)
    - Picklable: the cached loader is rebuilt on unpickling, so the dataset can be
      sent to DataLoader worker processes (each worker starts with an empty cache)
    """

    def __init__(
        self,
        graph_dir: Path,
        split: str = 'train',
        sequence_length: int = 12,
        prediction_horizon: int = 6,
        max_samples: int = None,
        cache_size: int = 2000,
        stratified_sample: bool = True
    ):
        """
        Args:
            graph_dir: Directory with one sub-directory per community. Each holds
                monthly graph files named ``<YYYY-MM>.pt``.
            split: Key into ``split_ranges`` ('train', 'val' or 'test').
            sequence_length: Number of consecutive input months per sample.
            prediction_horizon: Months between the last input month and the target month.
            max_samples: If set and smaller than the number of candidate samples,
                sub-sample down to at most this many.
            cache_size: ``maxsize`` of the LRU cache of loaded monthly graphs.
            stratified_sample: Balance the sub-sample across communities instead
                of sampling uniformly at random.
        """
        self.graph_dir = Path(graph_dir)
        self.split = split
        self.sequence_length = sequence_length
        self.prediction_horizon = prediction_horizon
        self._cache_size = cache_size

        # Temporal splits on the last *input* month of each sample; targets fall
        # prediction_horizon months later (the comments assume the default of 6)
        self.split_ranges = {
            'train': ('2014-01', '2020-06'),  # Predict through 2020-12
            'val':   ('2020-07', '2022-09'),  # Predict through 2023-03
            'test':  ('2022-10', '2023-09')   # Predict through 2024-03
        }

        # Build sample index
        print(f"Building {split} sample index...")
        self.samples = self._build_sample_index()
        print(f"  Found {len(self.samples)} potential samples")

        # Apply sampling strategy
        if max_samples and len(self.samples) > max_samples:
            if stratified_sample:
                print(f"  Applying stratified sampling to select {max_samples} samples...")
                self.samples = self._stratified_sample(max_samples)
            else:
                print(f"  Randomly sampling {max_samples} samples...")
                random.shuffle(self.samples)
                self.samples = self.samples[:max_samples]

        print(f"{split.upper()} Dataset: {len(self.samples)} samples")

        # Normalization stats: computed on train, assigned from train for val/test
        self.norm_stats = None

        # LRU-cached loader for individual monthly graphs
        self._load_graph = self._make_cached_loader(cache_size)
        self._access_count = 0

    def _make_cached_loader(self, cache_size: int):
        """Return an LRU-cached ``(community, month) -> graph`` loader for this dataset."""
        @lru_cache(maxsize=cache_size)
        def _cached_load(community: str, month: str):
            path = self.graph_dir / community / f"{month}.pt"
            # weights_only=False unpickles arbitrary objects: only load trusted graph files
            return torch.load(path, weights_only=False, map_location='cpu')

        return _cached_load

    def __getstate__(self):
        """Drop the cached loader when pickling: a locally defined function cannot be pickled."""
        state = self.__dict__.copy()
        state.pop('_load_graph', None)
        return state

    def __setstate__(self, state):
        """Restore attributes and rebuild an empty LRU-cached loader."""
        self.__dict__.update(state)
        self._load_graph = self._make_cached_loader(getattr(self, '_cache_size', 2000))

    def compute_normalization_stats(self, max_examples: int = 1000):
        """
        Compute the mean and std of each target metric from this dataset.

        IMPORTANT: Only call this on the training dataset. Assign the result to
        the ``norm_stats`` of the validation/test datasets so that they share
        these statistics.

        Args:
            max_examples: Maximum number of samples (drawn without replacement
                via ``np.random``) whose target graphs are loaded.

        Returns:
            Dict ``{metric: {'mean': float, 'std': float}}`` for each of 'qpd',
            'answer_rate' and 'retention' found in the sampled targets. The
            result is also stored on ``self.norm_stats``.
        """
        print("\n" + "="*70)
        print("Computing target normalization statistics")
        print("="*70)
        print(f"Sampling up to {max_examples} examples to compute statistics...")

        all_targets = {
            'qpd': [],
            'answer_rate': [],
            'retention': []
        }

        # Sample a subset of examples for statistics
        sample_size = min(max_examples, len(self.samples))
        sample_indices = np.random.choice(len(self.samples), sample_size, replace=False)

        for idx in tqdm(sample_indices, desc='Sampling targets'):
            sample = self.samples[idx]
            target_graph = self._load_graph(sample['community'], sample['target_month'])
            targets = target_graph.y

            for key in all_targets:
                if key in targets:
                    all_targets[key].append(float(targets[key]))

        # Compute mean and std for each metric
        self.norm_stats = {}
        for key in all_targets:
            if len(all_targets[key]) > 0:
                values = np.array(all_targets[key])
                self.norm_stats[key] = {
                    'mean': float(values.mean()),
                    'std': float(values.std() + 1e-8)  # Small constant avoids division by zero
                }

        print("\nNormalization statistics (computed from training data):")
        print(f"{'Metric':<15} {'Mean':<12} {'Std':<12} {'Min':<12} {'Max':<12}")
        print("-" * 63)

        for key in all_targets:
            if key in self.norm_stats:
                values = np.array(all_targets[key])
                print(f"{key:<15} {self.norm_stats[key]['mean']:<12.4f} "
                      f"{self.norm_stats[key]['std']:<12.4f} "
                      f"{values.min():<12.4f} {values.max():<12.4f}")

        print("="*70)
        print("  Normalization statistics computed")
        print("  These will be used to standardize all targets during training")
        print("  Val/test datasets will use these SAME statistics (no data leakage)")
        print("="*70 + "\n")

        return self.norm_stats

    @staticmethod
    def _month_index(month: str) -> int:
        """
        Map a ``YYYY-MM`` string to a running month count (year * 12 + month - 1).

        Raises:
            ValueError / TypeError: if ``month`` does not parse with ``%Y-%m``.
        """
        parsed = datetime.strptime(month, '%Y-%m')
        return parsed.year * 12 + parsed.month - 1

    @classmethod
    def _shift_month(cls, month: str, months: int) -> str:
        """Return the zero-padded ``YYYY-MM`` string ``months`` months after ``month``."""
        total = cls._month_index(month) + months
        return f"{total // 12:04d}-{total % 12 + 1:02d}"

    def _build_sample_index(self) -> List[Dict]:
        """
        Build the index of all valid samples for ``self.split``.

        A sample is valid when all of the following hold:
        - its anchor (last input) month lies inside the split range;
        - the ``sequence_length`` input months ending at the anchor are
          consecutive calendar months;
        - a graph exists for the month exactly ``prediction_horizon`` months
          after the anchor.

        Returns:
            List of dicts with keys 'community', 'sequence_months' and
            'target_month' (month strings are graph file stems).
        """
        samples = []
        start_month, end_month = self.split_ranges[self.split]
        # Fewest files that could yield a sample: the input months plus the target
        min_graphs = self.sequence_length + 1

        for community_dir in sorted(self.graph_dir.iterdir()):
            if not community_dir.is_dir():
                continue

            available_months = sorted(f.stem for f in community_dir.glob('*.pt'))
            if len(available_months) < min_graphs:
                continue
            month_set = set(available_months)

            for i, month_t in enumerate(available_months):
                # Anchor must be in the split range and have enough history before it
                if not (start_month <= month_t <= end_month):
                    continue
                if i < self.sequence_length - 1:
                    continue

                # Look the target up by calendar month, not by list position, so a
                # missing month between anchor and target does not shift the target
                try:
                    target_month = self._shift_month(month_t, self.prediction_horizon)
                except (ValueError, TypeError):
                    continue
                if target_month not in month_set:
                    continue

                seq_months = available_months[i - self.sequence_length + 1:i + 1]

                # Inputs must be consecutive calendar months
                if self._is_consecutive(seq_months, target_month):
                    samples.append({
                        'community': community_dir.name,
                        'sequence_months': seq_months,
                        'target_month': target_month
                    })

        return samples

    def _is_consecutive(self, seq_months, target_month):
        """
        Check that ``seq_months`` are consecutive calendar months and that
        ``target_month`` is exactly ``prediction_horizon`` months after the last one.

        Returns False (rather than raising) for empty sequences or strings that
        do not parse as ``YYYY-MM``.
        """
        try:
            indices = [self._month_index(m) for m in seq_months]
            target_index = self._month_index(target_month)
        except (ValueError, TypeError):
            return False
        if not indices:
            return False
        consecutive = all(b - a == 1 for a, b in zip(indices, indices[1:]))
        return consecutive and target_index - indices[-1] == self.prediction_horizon

    def _stratified_sample(self, n_samples: int) -> List[Dict]:
        """
        Select up to ``n_samples`` samples, balanced across communities.

        The selection is round-robin over communities visited in random order,
        taking one random sample from each per round. Every community contributes
        equally until it runs out. After that, its unused share goes to the
        communities that still have samples, so the result has exactly
        ``min(n_samples, len(self.samples))`` entries. When ``n_samples`` is
        smaller than the number of communities, a random subset of communities
        contributes one sample each. Uses the global ``random`` state.
        """
        by_community = defaultdict(list)
        for sample in self.samples:
            by_community[sample['community']].append(sample)

        if not by_community or n_samples <= 0:
            return []

        # Shuffle each community's samples, then the order communities are visited in
        pools = [random.sample(group, len(group)) for group in by_community.values()]
        random.shuffle(pools)

        selected = []
        for depth in range(max(len(pool) for pool in pools)):
            row = [pool[depth] for pool in pools if depth < len(pool)]
            selected.extend(row[:n_samples - len(selected)])
            if len(selected) >= n_samples:
                break

        random.shuffle(selected)
        return selected

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """
        Return ``(graphs, targets)`` for sample ``idx``.

        ``graphs`` is the list of input monthly graphs, in month order.
        ``targets`` is the target graph's ``.y``. When ``norm_stats`` is set, it
        is a new dict in which keys with statistics are z-scored and all other
        keys pass through unchanged.
        """
        sample = self.samples[idx]
        community = sample['community']

        # Load sequence using cached loader
        graphs = [
            self._load_graph(community, month)
            for month in sample['sequence_months']
        ]

        target_graph = self._load_graph(community, sample['target_month'])
        targets = target_graph.y

        # Apply standardization if statistics are available
        if self.norm_stats is not None:
            targets_standardized = {}
            for key in targets:
                if key in self.norm_stats:
                    # Z-score normalization: (x - mean) / std
                    targets_standardized[key] = (
                        (targets[key] - self.norm_stats[key]['mean']) /
                        self.norm_stats[key]['std']
                    )
                else:
                    targets_standardized[key] = targets[key]
            targets = targets_standardized

        # Periodic cache statistics (every 1000 accesses, counted per process)
        self._access_count += 1
        if self._access_count % 1000 == 0:
            cache_info = self._load_graph.cache_info()
            if cache_info.hits + cache_info.misses > 0:
                hit_rate = cache_info.hits / (cache_info.hits + cache_info.misses)
                print(f"  Cache: {cache_info.currsize}/{cache_info.maxsize} graphs | Hit rate: {hit_rate:.1%}")

        return graphs, targets

    def get_cache_stats(self):
        """Get current cache statistics (``functools`` ``CacheInfo`` of the graph loader)."""
        return self._load_graph.cache_info()


In [None]:
def collate_fn(batch):
    """
    Batch ``(graphs, targets)`` pairs produced by ``FastTemporalDataset``.

    Returns:
        A list holding each sample's graph sequence, unchanged and in batch order,
        and a dict mapping 'qpd', 'answer_rate' and 'retention' to float32
        tensors with one entry per sample. Other target keys are not included.
    """
    target_names = ('qpd', 'answer_rate', 'retention')
    items = list(batch)

    sequences = [graphs for graphs, _ in items]
    stacked_targets = {
        name: torch.tensor([targets[name] for _, targets in items], dtype=torch.float32)
        for name in target_names
    }
    return sequences, stacked_targets

## 3. Create Datasets

**Configuration:**
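- Up to 2,000 training samples (stratified across communities)
- Up to 500 validation samples
- Prediction anchors start in January 2014; with 12-month input windows, the earliest inputs reach back to February 2013
- No pre-caching: graphs load lazily, so training starts immediately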

In [None]:
# Configuration
SEED = 42                # Seeds the global RNGs used for sub-sampling and statistics
TRAIN_SAMPLES = 2000     # Stratified sample (upper bound)
VAL_SAMPLES = 500
TRAIN_CACHE_SIZE = 2000  # Max monthly graphs held in the train dataset's LRU cache
VAL_CACHE_SIZE = 1000

# FastTemporalDataset sub-samples with the global `random` state and draws its
# statistics sample with `np.random`. Seeding here makes the selected samples
# (and the saved baseline config) reproducible across Restart & Run All.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Create datasets
print("Creating datasets...\n")

train_dataset = FastTemporalDataset(
    graph_dir=GRAPH_DIR,
    split='train',
    max_samples=TRAIN_SAMPLES,
    cache_size=TRAIN_CACHE_SIZE,
    stratified_sample=True  # Balances samples across communities
)

val_dataset = FastTemporalDataset(
    graph_dir=GRAPH_DIR,
    split='val',
    max_samples=VAL_SAMPLES,
    cache_size=VAL_CACHE_SIZE,
    stratified_sample=True
)

print(f"\nDatasets created: {len(train_dataset)} train, {len(val_dataset)} val")

Creating datasets...

Building train sample index...
  Found 11020 potential samples
  Applying stratified sampling to select 2000 samples...
TRAIN Dataset: 1834 samples
Building val sample index...
  Found 4471 potential samples
  Applying stratified sampling to select 500 samples...
VAL Dataset: 338 samples

Datasets created: 1834 train, 338 val


In [None]:
# Fit target normalization on the training split only, then hand the very same
# statistics object to the validation split, so validation data never
# influences the statistics.
val_dataset.norm_stats = train_dataset.compute_normalization_stats()

for message in (
    "\n  Validation dataset will use training normalization statistics",
    "  This is standard practice in ML to prevent data leakage.",
):
    print(message)


Computing target normalization statistics
Sampling up to 1000 examples to compute statistics...


Sampling targets:   0%|          | 0/1000 [00:00<?, ?it/s]


Normalization statistics (computed from training data):
Metric          Mean         Std          Min          Max         
---------------------------------------------------------------
qpd             13.3609      38.0344      0.0323       519.9677    
answer_rate     0.4403       0.1489       0.0000       1.0000      
retention       0.2937       0.0974       0.0000       0.6364      
  Normalization statistics computed
  These will be used to standardize all targets during training
  Val/test datasets will use these SAME statistics (no data leakage)


  Validation dataset will use training normalization statistics
  This is standard practice in ML to prevent data leakage.


In [None]:
# Inspect a sample to verify data structure
if len(train_dataset) > 0:
    print("\nInspecting first sample...")
    sample_graphs, sample_targets = train_dataset[0]
    print(f"  Sequence length: {len(sample_graphs)} monthly graphs")

    g = sample_graphs[0]
    print(f"\n  Graph structure:")
    print(f"    Users: {g['user'].x.shape}")
    print(f"    Tags: {g['tag'].x.shape}")

    # Only targets with normalization statistics are z-scored. Any other key
    # (e.g. 'growth') is passed through raw by __getitem__, so label it to
    # avoid reading a raw value on the z-score scale.
    normalized_keys = set(train_dataset.norm_stats or {})
    print(f"\n  Target metrics:")
    for key, value in sample_targets.items():
        note = "" if key in normalized_keys else "  (raw, not normalized)"
        print(f"    {key}: {value:.4f}{note}")


Inspecting first sample...
  Sequence length: 12 monthly graphs

  Graph structure:
    Users: torch.Size([228, 5])
    Tags: torch.Size([283, 7])

  Target metrics:
    qpd: -0.2391
    answer_rate: -1.8552
    retention: -1.6198
    growth: -0.3421


## 4. Save Dataset Configuration

In [None]:
# ============================================================
# Save Dataset Configuration for Baseline Comparison
# ============================================================

import pickle
import os

# Baselines must be standardized with the same statistics. Refuse to save a
# config without them, rather than failing later on len(None).
if train_dataset.norm_stats is None:
    raise RuntimeError(
        "train_dataset.norm_stats is None - run "
        "train_dataset.compute_normalization_stats() before saving the config"
    )

# Create results directory
results_dir = f"{PROJECT_DIR}/results" if IN_COLAB else "../results"
os.makedirs(results_dir, exist_ok=True)

# Save everything needed for a fair comparison: the exact samples, the
# normalization statistics, and the windowing/split settings that produced them
baseline_config = {
    'train_samples': train_dataset.samples,      # Sample dicts: community, sequence_months, target_month
    'val_samples': val_dataset.samples,          # Same format, validation split
    'norm_stats': train_dataset.norm_stats,      # Per-target {'mean', 'std'} from training data
    'sequence_length': train_dataset.sequence_length,
    'prediction_horizon': train_dataset.prediction_horizon,
    'split_ranges': train_dataset.split_ranges,
}

save_path = f"{results_dir}/baseline_config.pkl"
# NOTE: pickle.load can execute arbitrary code - only load this file from a trusted location
with open(save_path, 'wb') as f:
    pickle.dump(baseline_config, f)

print(f"Saved baseline configuration to: {save_path}")
print(f"\nSaved:")
print(f"  - {len(baseline_config['train_samples'])} training samples")
print(f"  - {len(baseline_config['val_samples'])} validation samples")
print(f"  - Normalization stats for {len(baseline_config['norm_stats'])} metrics")
print(f"  - sequence_length={baseline_config['sequence_length']}, "
      f"prediction_horizon={baseline_config['prediction_horizon']}")
print(f"\nUse this file to ensure fair baseline comparison!")

Saved baseline configuration to: /content/drive/MyDrive/dont-thread-on-me/results/baseline_config.pkl

Saved:
  - 1834 training samples
  - 338 validation samples
  - Normalization stats for 3 metrics

Use this file to ensure fair baseline comparison!
