## 1. Setup Environment

In [15]:
# Check if running on Colab
import sys
IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    print("Running on Google Colab")
    !pip install torch-geometric
    !pip install python-dateutil
else:
    print("Running locally")

Running locally


In [16]:
# Check GPU availability
import torch

if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"GPU available: {torch.cuda.get_device_name(0)}")
    print(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    
    # Check for B200/Blackwell or other high-end GPUs
    gpu_name = torch.cuda.get_device_name(0).lower()
    if 'b200' in gpu_name or 'b100' in gpu_name or 'blackwell' in gpu_name:
        print("Detected Blackwell GPU - using bfloat16 for optimal performance")
        USE_BF16 = True
    elif 'h100' in gpu_name or 'a100' in gpu_name:
        print("Detected Hopper/Ampere GPU - using bfloat16")
        USE_BF16 = True
    else:
        print("Using float16 for mixed precision")
        USE_BF16 = False
else:
    device = torch.device('cpu')
    USE_BF16 = False
    print("No GPU available, using CPU")

print(f"\nUsing device: {device}")

No GPU available, using CPU

Using device: cpu


In [17]:
# Mount Google Drive (Colab only)
if IN_COLAB:
    from google.colab import drive
    drive.mount('/content/drive')
    GRAPH_DIR = '/content/drive/MyDrive/CS224W/graphs'
else:
    GRAPH_DIR = '../data/processed/graphs'

In [18]:
# Verify data directory exists
from pathlib import Path
import os

graph_dir = Path(GRAPH_DIR)

if not graph_dir.exists():
    print(f"ERROR: Graph directory not found: {graph_dir}")
else:
    communities = [d.name for d in graph_dir.iterdir() if d.is_dir()]
    print(f"Found {len(communities)} communities")
    # Show first few
    for comm in sorted(communities)[:5]:
        comm_path = graph_dir / comm
        n_graphs = len(list(comm_path.glob('*.pt')))
        print(f"  - {comm}: {n_graphs} monthly graphs")
    if len(communities) > 5:
        print(f"  ... and {len(communities) - 5} more")

Found 177 communities
  - 3dprinting.stackexchange.com: 99 monthly graphs
  - academia.stackexchange.com: 148 monthly graphs
  - ai.stackexchange.com: 92 monthly graphs
  - android.stackexchange.com: 175 monthly graphs
  - anime.stackexchange.com: 136 monthly graphs
  ... and 172 more


In [19]:
graph = torch.load("2023-05.pt", weights_only=False)

In [20]:
graph.y

{'qpd': 9.903225806451612,
 'answer_rate': 0.1758957654723127,
 'retention': 0.1589041095890411,
 'growth': 0.2644628099173554}

## 2. Define Model and Dataset Classes

In [7]:
# Imports
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch_geometric.nn import SAGEConv, HeteroConv
from pathlib import Path
from typing import List, Dict, Tuple, Optional
from datetime import datetime
from dateutil.relativedelta import relativedelta
import numpy as np
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

In [8]:
class TemporalCommunityGNN(nn.Module):
    """
    Temporal GNN for predicting community health trajectories.
    Optimized for GPU execution - no torch.compile, uses AMP-friendly operations.
    """
    
    def __init__(
        self,
        user_feat_dim: int,
        tag_feat_dim: int,
        hidden_dim: int = 128,
        num_conv_layers: int = 2,
        num_transformer_layers: int = 2,
        num_attention_heads: int = 4,
        dropout: float = 0.1,
        transformer_ffn_dim: int = 256
    ):
        super().__init__()
        
        self.hidden_dim = hidden_dim
        self.num_conv_layers = num_conv_layers
        self.dropout = dropout
        
        # Feature projection (LayerNorm is AMP-friendly)
        self.user_norm = nn.LayerNorm(user_feat_dim)
        self.tag_norm = nn.LayerNorm(tag_feat_dim)
        self.user_proj = nn.Linear(user_feat_dim, hidden_dim)
        self.tag_proj = nn.Linear(tag_feat_dim, hidden_dim)
        
        # Graph conv layers
        self.convs = nn.ModuleList()
        for _ in range(num_conv_layers):
            conv = HeteroConv({
                ("tag", "cooccurs", "tag"): SAGEConv(hidden_dim, hidden_dim, aggr="mean"),
                ("user", "contributes", "tag"): SAGEConv((hidden_dim, hidden_dim), hidden_dim, aggr="mean"),
                ("tag", "contributed_to_by", "user"): SAGEConv((hidden_dim, hidden_dim), hidden_dim, aggr="mean"),
            }, aggr="mean")
            self.convs.append(conv)
        
        self.conv_dropout = nn.Dropout(dropout)
        
        # Temporal transformer
        community_emb_dim = 2 * hidden_dim
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=community_emb_dim,
            nhead=num_attention_heads,
            dim_feedforward=transformer_ffn_dim,
            dropout=dropout,
            batch_first=True
        )
        self.temporal_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_transformer_layers)
        
        # Prediction heads
        self.qpd_head = nn.Linear(community_emb_dim, 1)
        self.ansrate_head = nn.Linear(community_emb_dim, 1)
        self.retention_head = nn.Linear(community_emb_dim, 1)

    def forward_single_graph(self, x_dict, edge_index_dict):
        """Process a single graph to get community embedding."""
        # Project features
        projected = {}
        if "user" in x_dict:
            projected["user"] = self.user_proj(self.user_norm(x_dict["user"]))
        if "tag" in x_dict:
            projected["tag"] = self.tag_proj(self.tag_norm(x_dict["tag"]))
        
        # Apply conv layers
        x = projected
        for conv in self.convs:
            x = conv(x, edge_index_dict)
            x = {k: torch.relu(v) for k, v in x.items()}
            x = {k: self.conv_dropout(v) for k, v in x.items()}
        
        # Pool to community embedding
        user_pooled = x["user"].mean(dim=0)
        tag_pooled = x["tag"].mean(dim=0)
        return torch.cat([user_pooled, tag_pooled])

    def forward(self, batch_monthly_graphs):
        """Forward pass for batched temporal sequences."""
        batch_embeddings = []
        
        for community_graphs in batch_monthly_graphs:
            monthly_embs = []
            for graph in community_graphs:
                if isinstance(graph, tuple):
                    x_dict, edge_index_dict = graph
                else:
                    x_dict = graph.x_dict
                    edge_index_dict = graph.edge_index_dict
                emb = self.forward_single_graph(x_dict, edge_index_dict)
                monthly_embs.append(emb)
            batch_embeddings.append(torch.stack(monthly_embs))
        
        # [batch, 12, emb_dim]
        batch_embeddings = torch.stack(batch_embeddings)
        
        # Temporal encoding
        temporal_out = self.temporal_encoder(batch_embeddings)
        final_repr = temporal_out[:, -1, :]
        
        return {
            "qpd": self.qpd_head(final_repr).squeeze(-1),
            "answer_rate": self.ansrate_head(final_repr).squeeze(-1),
            "retention": self.retention_head(final_repr).squeeze(-1),
        }

In [9]:
class CachedTemporalDataset(Dataset):
    """
    Dataset with in-memory caching for fast GPU training.
    Loads all data into memory once, then serves from cache.
    """
    
    def __init__(
        self,
        graph_dir: Path,
        split: str = 'train',
        sequence_length: int = 12,
        prediction_horizon: int = 6,
        max_samples: int = None,  # Limit samples for faster testing
        cache_in_memory: bool = True
    ):
        self.graph_dir = Path(graph_dir)
        self.split = split
        self.sequence_length = sequence_length
        self.prediction_horizon = prediction_horizon
        self.cache_in_memory = cache_in_memory
        self.cache = {}
        
        # Temporal splits
        self.split_ranges = {
            'train': ('2008-01', '2020-06'),
            'val': ('2020-07', '2022-09'),
            'test': ('2022-10', '2023-09')
        }
        
        self.samples = self._build_sample_index()
        if max_samples:
            self.samples = self.samples[:max_samples]
        
        print(f"{split.upper()} Dataset: {len(self.samples)} samples")
        
        # Pre-cache if requested
        if cache_in_memory and len(self.samples) > 0:
            print(f"Pre-caching {split} data...")
            self._precache_all()
    
    def _build_sample_index(self) -> List[Dict]:
        samples = []
        start_month, end_month = self.split_ranges[self.split]
        min_graphs = self.sequence_length + self.prediction_horizon
        
        for community_dir in sorted(self.graph_dir.iterdir()):
            if not community_dir.is_dir():
                continue
            
            available_months = sorted([f.stem for f in community_dir.glob('*.pt')])
            if len(available_months) < min_graphs:
                continue
            
            for i, month_t in enumerate(available_months):
                if not (start_month <= month_t <= end_month):
                    continue
                if i < self.sequence_length - 1:
                    continue
                
                target_idx = i + self.prediction_horizon
                if target_idx >= len(available_months):
                    continue
                
                seq_start = i - self.sequence_length + 1
                seq_months = available_months[seq_start:i+1]
                target_month = available_months[target_idx]
                
                # Simple consecutive check
                if self._is_consecutive(seq_months, target_month):
                    samples.append({
                        'community': community_dir.name,
                        'sequence_months': seq_months,
                        'target_month': target_month
                    })
        return samples
    
    def _is_consecutive(self, seq_months, target_month):
        """Check if months form consecutive sequence."""
        try:
            dates = [datetime.strptime(m, '%Y-%m') for m in seq_months]
            for i in range(1, len(dates)):
                if dates[i] != dates[i-1] + relativedelta(months=1):
                    return False
            target_date = datetime.strptime(target_month, '%Y-%m')
            expected = dates[-1] + relativedelta(months=self.prediction_horizon)
            return target_date == expected
        except:
            return False
    
    def _precache_all(self):
        """Load all graphs into memory."""
        for idx in tqdm(range(len(self.samples)), desc='Caching'):
            if idx not in self.cache:
                self.cache[idx] = self._load_sample(idx)
    
    def _load_sample(self, idx):
        sample = self.samples[idx]
        community_dir = self.graph_dir / sample['community']
        
        graphs = []
        for month in sample['sequence_months']:
            graph = torch.load(community_dir / f"{month}.pt", weights_only=False)
            graphs.append(graph)
        
        target_graph = torch.load(community_dir / f"{sample['target_month']}.pt", weights_only=False)
        targets = target_graph.y
        
        return graphs, targets
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        if self.cache_in_memory and idx in self.cache:
            return self.cache[idx]
        return self._load_sample(idx)

In [10]:
def collate_fn(batch):
    """Collate function for batching."""
    batch_graphs = []
    batch_targets = {'qpd': [], 'answer_rate': [], 'retention': []}
    
    for graphs, targets in batch:
        batch_graphs.append(graphs)
        for key in batch_targets:
            batch_targets[key].append(targets[key])
    
    for key in batch_targets:
        batch_targets[key] = torch.tensor(batch_targets[key], dtype=torch.float32)
    
    return batch_graphs, batch_targets

## 3. Create Datasets

In [9]:
# Create datasets - limit samples for fast proof of concept
# Set max_samples=None to use all data once you confirm it works
MAX_SAMPLES = None  # Limit for fast testing

train_dataset = CachedTemporalDataset(
    graph_dir=GRAPH_DIR,
    split='train',
    max_samples=MAX_SAMPLES,
    cache_in_memory=True
)

val_dataset = CachedTemporalDataset(
    graph_dir=GRAPH_DIR,
    split='val',
    max_samples=MAX_SAMPLES,
    cache_in_memory=True
)

print(f"\nDatasets created: {len(train_dataset)} train, {len(val_dataset)} val")

TRAIN Dataset: 13359 samples
Pre-caching train data...


Caching:   0%|          | 0/13359 [00:00<?, ?it/s]

VAL Dataset: 4484 samples
Pre-caching val data...


Caching:   0%|          | 0/4484 [00:00<?, ?it/s]


Datasets created: 13359 train, 4484 val


In [11]:
# Inspect a sample
if len(train_dataset) > 0:
    sample_graphs, sample_targets = train_dataset[0]
    print(f"Sample has {len(sample_graphs)} monthly graphs")
    g = sample_graphs[0]
    print(f"Graph structure:")
    print(f"  Users: {g['user'].x.shape}")
    print(f"  Tags: {g['tag'].x.shape}")
    print(f"\nTargets: {sample_targets}")

Sample has 12 monthly graphs
Graph structure:
  Users: torch.Size([75, 5])
  Tags: torch.Size([109, 7])

Targets: {'qpd': 1.5, 'answer_rate': 0.5777777777777777, 'retention': 0.29508196721311475, 'growth': 0.0}


In [20]:
# Create DataLoaders - use larger batch size for B200
# B200 has 192GB memory, so we can use much larger batches
BATCH_SIZE = 2048  # Increase for B200

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=0,
    pin_memory=False  # Data already in memory
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=collate_fn,
    num_workers=0,
    pin_memory=False
)

print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}")

Train batches: 7, Val batches: 3


## 4. Create Model

In [21]:
# Get feature dimensions from data
if len(train_dataset) > 0:
    sample_graphs, _ = train_dataset[0]
    USER_FEAT_DIM = sample_graphs[0]['user'].x.shape[1]
    TAG_FEAT_DIM = sample_graphs[0]['tag'].x.shape[1]
else:
    USER_FEAT_DIM = 5
    TAG_FEAT_DIM = 7

print(f"User features: {USER_FEAT_DIM}, Tag features: {TAG_FEAT_DIM}")

User features: 5, Tag features: 7


In [22]:
model = TemporalCommunityGNN(
    user_feat_dim=USER_FEAT_DIM,
    tag_feat_dim=TAG_FEAT_DIM,
    hidden_dim=64,
    num_conv_layers=2,
    num_transformer_layers=1,
    num_attention_heads=4,
    dropout=0.2,
    transformer_ffn_dim=64
)

model = model.to(device)

# Count parameters
n_params = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {n_params:,}")

Model parameters: 133,979


## 5. Training with AMP (Automatic Mixed Precision)

In [23]:
def move_to_device(batch_graphs, device):
    """Move graphs to device efficiently."""
    return [[g.to(device, non_blocking=True) for g in graphs] for graphs in batch_graphs]


def train_epoch(model, loader, optimizer, criterion, device, scaler, use_amp):
    """Train for one epoch with AMP."""
    model.train()
    total_loss = 0.0
    n_batches = 0
    
    # Determine dtype for autocast
    amp_dtype = torch.bfloat16 if USE_BF16 else torch.float16
    
    pbar = tqdm(loader, desc='Training')
    for batch_graphs, batch_targets in pbar:
        batch_graphs = move_to_device(batch_graphs, device)
        batch_targets = {k: v.to(device, non_blocking=True) for k, v in batch_targets.items()}
        
        optimizer.zero_grad(set_to_none=True)  # Faster than zero_grad()
        
        # Forward with AMP
        with torch.amp.autocast('cuda', enabled=use_amp, dtype=amp_dtype):
            predictions = model(batch_graphs)
            
            loss = (
                criterion(predictions['qpd'], batch_targets['qpd']) +
                criterion(predictions['answer_rate'], batch_targets['answer_rate']) +
                criterion(predictions['retention'], batch_targets['retention'])
            ) / 3.0
        
        # Backward with scaler
        if use_amp and not USE_BF16:  # scaler only needed for fp16
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
        
        total_loss += loss.item()
        n_batches += 1
        pbar.set_postfix({'loss': f'{loss.item():.4f}'})
    
    return total_loss / max(n_batches, 1)


@torch.no_grad()
def evaluate(model, loader, criterion, device, use_amp):
    """Evaluate model."""
    model.eval()
    total_loss = 0.0
    n_batches = 0
    
    all_preds = {'qpd': [], 'answer_rate': [], 'retention': []}
    all_targets = {'qpd': [], 'answer_rate': [], 'retention': []}
    
    amp_dtype = torch.bfloat16 if USE_BF16 else torch.float16
    
    for batch_graphs, batch_targets in tqdm(loader, desc='Evaluating', leave=False):
        batch_graphs = move_to_device(batch_graphs, device)
        batch_targets = {k: v.to(device, non_blocking=True) for k, v in batch_targets.items()}
        
        with torch.amp.autocast('cuda', enabled=use_amp, dtype=amp_dtype):
            predictions = model(batch_graphs)
            
            loss = (
                criterion(predictions['qpd'], batch_targets['qpd']) +
                criterion(predictions['answer_rate'], batch_targets['answer_rate']) +
                criterion(predictions['retention'], batch_targets['retention'])
            ) / 3.0
        
        total_loss += loss.item()
        n_batches += 1
        
        for key in all_preds:
            all_preds[key].extend(predictions[key].float().cpu().numpy())
            all_targets[key].extend(batch_targets[key].float().cpu().numpy())
    
    # Compute R² for each metric
    r2_scores = {}
    for key in all_preds:
        preds = np.array(all_preds[key])
        targets = np.array(all_targets[key])
        ss_res = np.sum((targets - preds) ** 2)
        ss_tot = np.sum((targets - targets.mean()) ** 2)
        r2_scores[key] = 1 - ss_res / (ss_tot + 1e-8)
    
    return total_loss / max(n_batches, 1), r2_scores

In [24]:
# Training setup
LEARNING_RATE = 1e-4
NUM_EPOCHS = 5
USE_AMP = torch.cuda.is_available()  # Enable AMP on GPU

optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)
criterion = nn.MSELoss()
scaler = torch.amp.GradScaler('cuda', enabled=USE_AMP and not USE_BF16)

print(f"Training config:")
print(f"  Device: {device}")
print(f"  AMP: {USE_AMP}")
print(f"  Dtype: {'bfloat16' if USE_BF16 else 'float16' if USE_AMP else 'float32'}")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Epochs: {NUM_EPOCHS}")

Training config:
  Device: cuda
  AMP: True
  Dtype: bfloat16
  Batch size: 2048
  Epochs: 5


In [18]:
# Quick test - single batch forward/backward
print("Testing single batch...")

if len(train_loader) > 0:
    batch_graphs, batch_targets = next(iter(train_loader))
    batch_graphs = move_to_device(batch_graphs, device)
    batch_targets = {k: v.to(device) for k, v in batch_targets.items()}
    
    amp_dtype = torch.bfloat16 if USE_BF16 else torch.float16
    
    # Time forward pass
    import time
    torch.cuda.synchronize() if torch.cuda.is_available() else None
    start = time.time()
    
    with torch.amp.autocast('cuda', enabled=USE_AMP, dtype=amp_dtype):
        preds = model(batch_graphs)
        loss = criterion(preds['qpd'], batch_targets['qpd'])
    
    loss.backward()
    
    torch.cuda.synchronize() if torch.cuda.is_available() else None
    elapsed = time.time() - start
    
    print(f"✓ Single batch completed in {elapsed:.3f}s")
    print(f"  Loss: {loss.item():.4f}")
    if torch.cuda.is_available():
        print(f"  GPU memory: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")
else:
    print("No data in train_loader")

Testing single batch...
✓ Single batch completed in 11.431s
  Loss: 815.5268
  GPU memory: 2.34 GB


In [25]:
# Training loop
print("\n" + "="*50)
print("Starting training...")
print("="*50 + "\n")

history = {'train_loss': [], 'val_loss': [], 'val_r2': []}
best_val_loss = float('inf')

for epoch in range(NUM_EPOCHS):
    # Train
    train_loss = train_epoch(model, train_loader, optimizer, criterion, device, scaler, USE_AMP)
    history['train_loss'].append(train_loss)
    
    # Validate
    val_loss, val_r2 = evaluate(model, val_loader, criterion, device, USE_AMP)
    history['val_loss'].append(val_loss)
    history['val_r2'].append(val_r2)
    
    # Track best
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_marker = " ★"
    else:
        best_marker = ""
    
    print(f"Epoch {epoch+1:2d}/{NUM_EPOCHS} | "
          f"Train: {train_loss:.4f} | "
          f"Val: {val_loss:.4f} | "
          f"R²: qpd={val_r2['qpd']:.3f}, ans={val_r2['answer_rate']:.3f}, ret={val_r2['retention']:.3f}"
          f"{best_marker}")

print("\nTraining complete!")


Starting training...



Training:   0%|          | 0/7 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch  1/5 | Train: 593.7313 | Val: 181.3873 | R²: qpd=-0.087, ans=-1.158, ret=-74.413 ★


Training:   0%|          | 0/7 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch  2/5 | Train: 615.9460 | Val: 175.6763 | R²: qpd=-0.058, ans=-0.111, ret=-36.174 ★


Training:   0%|          | 0/7 [00:00<?, ?it/s]

KeyboardInterrupt: 

## 6. Results

In [None]:
# Plot training curves
if len(history['train_loss']) > 0:
    fig, axes = plt.subplots(1, 2, figsize=(12, 4))
    
    # Loss
    ax = axes[0]
    ax.plot(history['train_loss'], label='Train')
    ax.plot(history['val_loss'], label='Val')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_title('Training Loss')
    ax.legend()
    ax.grid(True, alpha=0.3)
    
    # R² scores
    ax = axes[1]
    for key in ['qpd', 'answer_rate', 'retention']:
        r2_values = [h[key] for h in history['val_r2']]
        ax.plot(r2_values, label=key)
    ax.set_xlabel('Epoch')
    ax.set_ylabel('R²')
    ax.set_title('Validation R² Scores')
    ax.legend()
    ax.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
else:
    print("No training history to plot")

In [None]:
# Final summary
print("="*50)
print("TRAINING SUMMARY")
print("="*50)
print(f"Device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Peak GPU memory: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")
print(f"Batch size: {BATCH_SIZE}")
print(f"Mixed precision: {'bfloat16' if USE_BF16 else 'float16' if USE_AMP else 'float32'}")
print(f"Total samples: {len(train_dataset)} train, {len(val_dataset)} val")
print(f"\nFinal validation loss: {history['val_loss'][-1]:.4f}" if history['val_loss'] else "")
if history['val_r2']:
    final_r2 = history['val_r2'][-1]
    print(f"Final R² scores:")
    for k, v in final_r2.items():
        print(f"  {k}: {v:.4f}")

## 7. Next Steps

Once this proof of concept runs successfully:

1. **Increase data**: Set `MAX_SAMPLES = None` to use all data
2. **Increase model size**: Try `hidden_dim=128` or `256`
3. **More epochs**: Train for 20-50 epochs
4. **Add test evaluation**: Create test_dataset and evaluate final model