# Model Comparison on Test Set

Loads the trained models and evaluates them on the held-out **test set** (input months 2022-10 to 2023-09).

In [1]:
import sys
import pickle
import json
import random
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch_geometric.nn import SAGEConv, HeteroConv
from pathlib import Path
from typing import Dict, List
from datetime import datetime
from dateutil.relativedelta import relativedelta
from tqdm.auto import tqdm
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')

# Running inside Google Colab? (the `google.colab` module is only loaded there)
IN_COLAB = 'google.colab' in sys.modules

if not IN_COLAB:
    # Local checkout: paths are relative to the notebook's directory.
    PROJECT_DIR = Path('..')
    GRAPH_DIR = Path('../data/processed/graphs')
else:
    # Colab: project files live on the mounted Google Drive.
    from google.colab import drive
    drive.mount('/content/drive')
    PROJECT_DIR = Path('/content/drive/MyDrive/CS224W')
    GRAPH_DIR = Path('/content/drive/MyDrive/CS224W/graphs')

RESULTS_DIR = PROJECT_DIR / 'results'

# Prefer the GPU when one is visible to torch, otherwise fall back to CPU.
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_device_name)
print(f"Device: {device}")

Device: cpu


In [2]:
# ============================================================================
# MODEL CONFIGURATION - Add your models here
# ============================================================================
# Maps a display name to the saved model file. The file suffix selects the
# loader in the evaluation cell: `.pkl` -> pickled sklearn baselines,
# `.pt` -> torch checkpoint of the GNN.
# Paths are anchored at RESULTS_DIR (defined in the setup cell) so the same
# configuration resolves both locally and on Colab/Drive.
MODELS = {
    "Baseline (Linear + RF)": RESULTS_DIR / "baseline" / "baseline_models.pkl",
    "Temporal GNN": RESULTS_DIR / "temporal_gnn.pt",
    # "Baseline GNN": RESULTS_DIR / "baseline_gnn.pt",
}

# Report up front which of the configured files are actually present.
for name, path in MODELS.items():
    exists = Path(path).exists()
    print(f"{name}: {path} {'✓' if exists else '✗ NOT FOUND'}")

Baseline (Linear + RF): ../results/baseline/baseline_models.pkl ✓
Temporal GNN: ../results/temporal_gnn.pt ✓


## 1. GNN Model Definition

In [3]:
class TemporalCommunityGNN(nn.Module):
    """Per-month heterogeneous GraphSAGE encoder followed by a Transformer over months.

    Each monthly graph is reduced to one community embedding (mean-pooled user
    features concatenated with mean-pooled tag features, size ``2 * hidden_dim``).
    The sequence of monthly embeddings is passed through a Transformer encoder
    and the representation of the last month feeds three linear heads.

    Attribute names are part of the checkpoint format (``state_dict`` keys),
    so they must not be renamed.
    """

    def __init__(self, user_feat_dim: int, tag_feat_dim: int, hidden_dim: int = 128,
                 num_conv_layers: int = 2, num_transformer_layers: int = 2,
                 num_attention_heads: int = 4, dropout: float = 0.1,
                 transformer_ffn_dim: int = 256):
        """Build the encoder.

        Args:
            user_feat_dim: width of the raw ``user`` node features.
            tag_feat_dim: width of the raw ``tag`` node features.
            hidden_dim: width of node representations after projection/convs.
            num_conv_layers: number of stacked ``HeteroConv`` layers.
            num_transformer_layers: number of Transformer encoder layers.
            num_attention_heads: attention heads per Transformer layer.
            dropout: dropout rate used after each conv layer and in the Transformer.
            transformer_ffn_dim: feed-forward width inside each Transformer layer.
        """
        super().__init__()
        self.hidden_dim = hidden_dim

        # Input normalisation and projection, one pair per node type.
        self.user_norm = nn.LayerNorm(user_feat_dim)
        self.tag_norm = nn.LayerNorm(tag_feat_dim)
        self.user_proj = nn.Linear(user_feat_dim, hidden_dim)
        self.tag_proj = nn.Linear(tag_feat_dim, hidden_dim)

        # Message-passing stack (built in order so parameter init order is stable).
        self.convs = nn.ModuleList(
            [self._make_hetero_layer(hidden_dim) for _ in range(num_conv_layers)]
        )
        self.conv_dropout = nn.Dropout(dropout)

        # Temporal model over the sequence of [user ‖ tag] community embeddings.
        sequence_dim = hidden_dim * 2
        self.temporal_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=sequence_dim,
                nhead=num_attention_heads,
                dim_feedforward=transformer_ffn_dim,
                dropout=dropout,
                batch_first=True,
            ),
            num_layers=num_transformer_layers,
        )

        # One scalar regression head per target.
        self.qpd_head = nn.Linear(sequence_dim, 1)
        self.ansrate_head = nn.Linear(sequence_dim, 1)
        self.retention_head = nn.Linear(sequence_dim, 1)

    @staticmethod
    def _make_hetero_layer(hidden_dim):
        """Create one HeteroConv layer with a mean-aggregating SAGEConv per relation."""
        relations = {
            ("tag", "cooccurs", "tag"): SAGEConv(hidden_dim, hidden_dim, aggr="mean"),
            ("user", "contributes", "tag"): SAGEConv((hidden_dim, hidden_dim), hidden_dim, aggr="mean"),
            ("tag", "contributed_to_by", "user"): SAGEConv((hidden_dim, hidden_dim), hidden_dim, aggr="mean"),
        }
        return HeteroConv(relations, aggr="mean")

    def forward_single_graph(self, x_dict, edge_index_dict):
        """Embed one monthly graph into a single vector of size ``2 * hidden_dim``.

        Args:
            x_dict: node-type -> feature tensor; only ``"user"`` and ``"tag"`` are read.
            edge_index_dict: edge-type -> edge index tensor, forwarded to each conv layer.

        Returns:
            1-D tensor: mean over user nodes concatenated with mean over tag nodes.
            Both node types must be present after the conv stack, otherwise the
            lookup below raises ``KeyError``.
        """
        input_encoders = (
            ("user", self.user_norm, self.user_proj),
            ("tag", self.tag_norm, self.tag_proj),
        )
        hidden = {
            node_type: project(normalise(x_dict[node_type]))
            for node_type, normalise, project in input_encoders
            if node_type in x_dict
        }

        for layer in self.convs:
            # ReLU then dropout on every node type the layer produced.
            hidden = {
                node_type: self.conv_dropout(torch.relu(feats))
                for node_type, feats in layer(hidden, edge_index_dict).items()
            }

        return torch.cat((hidden["user"].mean(dim=0), hidden["tag"].mean(dim=0)))

    def forward(self, batch_monthly_graphs):
        """Predict the three targets for a batch of monthly-graph sequences.

        Args:
            batch_monthly_graphs: iterable of sequences; every element of a
                sequence is either an ``(x_dict, edge_index_dict)`` tuple or an
                object exposing ``.x_dict`` and ``.edge_index_dict``. All
                sequences must have the same length (they are stacked).

        Returns:
            Dict with keys ``"qpd"``, ``"answer_rate"`` and ``"retention"``,
            each a tensor with one value per sequence in the batch.
        """
        def embed(graph):
            if isinstance(graph, tuple):
                node_feats, edges = graph
                return self.forward_single_graph(node_feats, edges)
            return self.forward_single_graph(graph.x_dict, graph.edge_index_dict)

        # Shape after stacking: (batch, months, 2 * hidden_dim).
        sequences = torch.stack([
            torch.stack([embed(graph) for graph in community_graphs])
            for community_graphs in batch_monthly_graphs
        ])

        # Only the encoder output at the last month is used for prediction.
        last_month = self.temporal_encoder(sequences)[:, -1, :]

        heads = (
            ("qpd", self.qpd_head),
            ("answer_rate", self.ansrate_head),
            ("retention", self.retention_head),
        )
        return {target: head(last_month).squeeze(-1) for target, head in heads}

## 2. Test Dataset

In [4]:
class TestDataset:
    """Index of (input sequence, target month) samples built from per-month graph files.

    The directory layout read here is ``graph_dir/<community>/<YYYY-MM>.pt``.
    A sample is a window of ``seq_len`` calendar-consecutive months whose last
    month falls inside the chosen split range, plus the month exactly
    ``horizon`` months after the end of the window as prediction target.
    """

    # Inclusive (first, last) month of the *last input month* for each split.
    SPLIT_RANGES = {
        'train': ('2014-01', '2020-06'),
        'val':   ('2020-07', '2022-09'),
        'test':  ('2022-10', '2023-09')
    }

    def __init__(self, graph_dir: Path, split: str = 'test', seq_len: int = 12, horizon: int = 6):
        """Scan ``graph_dir`` and build the sample index (no graph is loaded yet).

        Args:
            graph_dir: root directory containing one sub-directory per community.
            split: key of ``SPLIT_RANGES``; an unknown key raises ``KeyError``.
            seq_len: number of consecutive input months per sample.
            horizon: distance in months between the last input month and the target.
        """
        self.graph_dir = Path(graph_dir)
        self.split = split
        self.seq_len = seq_len
        self.horizon = horizon
        self.cache = {}  # (community, month) -> loaded graph; grows without bound
        self.samples = self._build_index()
        print(f"{split.upper()}: {len(self.samples)} samples")

    def _load(self, community, month):
        """Load one monthly graph onto the CPU, memoising it in ``self.cache``."""
        key = (community, month)
        if key not in self.cache:
            # SECURITY: weights_only=False unpickles arbitrary objects; only
            # point graph_dir at files you produced yourself.
            self.cache[key] = torch.load(self.graph_dir / community / f"{month}.pt",
                                         weights_only=False, map_location='cpu')
        return self.cache[key]

    def _build_index(self):
        """Return a list of ``{'comm', 'seq', 'target'}`` dicts for the split."""
        samples = []
        start, end = self.SPLIT_RANGES[self.split]

        for comm_dir in sorted(self.graph_dir.iterdir()):
            if not comm_dir.is_dir():
                continue
            # File stems are month strings; lexical order equals chronological
            # order for zero-padded YYYY-MM names.
            months = sorted(f.stem for f in comm_dir.glob('*.pt'))
            if len(months) < self.seq_len + self.horizon:
                continue

            for i, m in enumerate(months):
                if not (start <= m <= end):
                    continue
                if i < self.seq_len - 1:
                    continue  # not enough history before this month
                target_idx = i + self.horizon
                if target_idx >= len(months):
                    continue  # no file far enough in the future

                seq = months[i - self.seq_len + 1:i + 1]
                target = months[target_idx]

                # Positions in the file list only imply calendar distance when
                # no month is missing, so verify the actual dates.
                if self._consecutive(seq, target):
                    samples.append({'comm': comm_dir.name, 'seq': seq, 'target': target})
        return samples

    @staticmethod
    def _month_index(month):
        """Map a ``YYYY-MM`` string to a running month number (raises ValueError if malformed)."""
        parsed = datetime.strptime(month, '%Y-%m')
        return parsed.year * 12 + parsed.month

    def _consecutive(self, seq, target):
        """Check that ``seq`` has no calendar gaps and ``target`` is ``horizon`` months after it.

        Returns ``False`` (instead of raising) for an empty sequence or for
        names that are not valid ``YYYY-MM`` months. Only parse errors are
        treated that way; any other exception propagates so that real bugs
        are not silently turned into "no samples".
        """
        if not seq:
            return False
        try:
            indices = [self._month_index(m) for m in seq]
            target_index = self._month_index(target)
        except (ValueError, TypeError):
            return False
        if any(later - earlier != 1 for earlier, later in zip(indices, indices[1:])):
            return False
        return target_index == indices[-1] + self.horizon

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(input graphs, target graph's y, last input graph's y)`` for sample ``idx``."""
        s = self.samples[idx]
        graphs = [self._load(s['comm'], m) for m in s['seq']]
        target = self._load(s['comm'], s['target'])
        return graphs, target.y, graphs[-1].y

# Index the test split with the class defaults for window length and horizon.
test_dataset = TestDataset(graph_dir=GRAPH_DIR, split='test')

TEST: 2002 samples


## 3. Feature Extraction (for baselines)

In [5]:
def _node_type_summary(x, feat_dim):
    """Per-feature mean and std plus node count for one node type.

    Returns ``2 * feat_dim + 1`` zeros when the node type has no nodes, so the
    output width does not depend on the graph.
    """
    # detach/cpu so this also works for tensors that require grad or that an
    # earlier cell moved to the GPU; on plain CPU tensors it is a no-op.
    values = x.detach().cpu().numpy()
    if len(values) == 0:
        return [0] * (2 * feat_dim + 1)
    return [*values.mean(0), *values.std(0), len(values)]


def extract_graph_features(g, user_feat_dim: int = 5, tag_feat_dim: int = 7):
    """Summarise one monthly graph as a flat float32 feature vector.

    Layout (with the default dimensions: 11 + 15 + 2 + 3 = 31 values):
      * user nodes: per-feature mean, per-feature std, node count
      * tag nodes:  per-feature mean, per-feature std, node count
      * edge counts for two relations (0 when the relation is absent)
      * the graph's current ``qpd``, ``answer_rate`` and ``retention`` from ``g.y``

    Args:
        g: graph object supporting ``g['user'].x``, ``g['tag'].x``,
            ``g.edge_types``, ``g[edge_type].edge_index`` and a dict-like ``g.y``.
        user_feat_dim: user feature width, used only to size the zero padding
            when the graph has no user nodes.
        tag_feat_dim: tag feature width, used the same way for tag nodes.

    Returns:
        1-D ``np.float32`` array.
    """
    feats = []
    feats.extend(_node_type_summary(g['user'].x, user_feat_dim))
    feats.extend(_node_type_summary(g['tag'].x, tag_feat_dim))

    # NOTE(review): the GNN in this notebook is wired for the relations
    # 'cooccurs' / 'contributes' / 'contributed_to_by'. If the stored graphs
    # only carry those, both counts below are always 0. The relation names are
    # left untouched because the saved baseline models were presumably fit on
    # exactly this feature layout - confirm against the training notebook.
    for et in [('user', 'posts_in', 'tag'), ('user', 'answers', 'user')]:
        feats.append(g[et].edge_index.shape[1] if et in g.edge_types else 0)

    has_labels = hasattr(g, 'y')
    for k in ['qpd', 'answer_rate', 'retention']:
        feats.append(float(g.y.get(k, 0)) if has_labels else 0)

    return np.array(feats, dtype=np.float32)

def extract_seq_features(graphs):
    """Turn a sequence of monthly graphs into one float32 feature vector.

    For the per-graph features (one row per month) this concatenates, in order:
    the last month's row, the column means, the column standard deviations and
    the per-column linear trend (slope of a degree-1 fit against the month
    position; 0 for columns that are constant up to 1e-8).
    """
    per_month = np.array([extract_graph_features(graph) for graph in graphs])
    positions = np.arange(len(graphs))

    slopes = []
    for column in per_month.T:
        if np.std(column) > 1e-8:
            slopes.append(np.polyfit(positions, column, 1)[0])
        else:
            # Constant column: skip the fit and record a flat trend.
            slopes.append(0)

    combined = [*per_month[-1], *per_month.mean(0), *per_month.std(0), *slopes]
    return np.array(combined, dtype=np.float32)

In [6]:
print("Extracting test features...")

# Collect everything in plain lists first, convert to arrays once at the end.
_target_names = ('qpd', 'answer_rate', 'retention')
feature_rows = []
future_values = {name: [] for name in _target_names}
current_values = {name: [] for name in _target_names}
test_graphs_for_gnn = []  # raw graph sequences, reused by the GNN evaluation

for sample_idx in tqdm(range(len(test_dataset)), desc='Test'):
    graphs, targets, current = test_dataset[sample_idx]
    feature_rows.append(extract_seq_features(graphs))
    test_graphs_for_gnn.append(graphs)
    for name in _target_names:
        # Missing metrics default to 0 for both the target and the current month.
        future_values[name].append(float(targets.get(name, 0)))
        current_values[name].append(float(current.get(name, 0)))

X_test = np.array(feature_rows)
y_test = {name: np.array(values) for name, values in future_values.items()}
y_curr = {name: np.array(values) for name, values in current_values.items()}

print(f"Test: {X_test.shape[0]} samples, {X_test.shape[1]} features")

Extracting test features...


Test:   0%|          | 0/2002 [00:00<?, ?it/s]

Test: 2002 samples, 124 features


## 4. Evaluation Functions

In [7]:
# Prediction targets, in the order they are reported everywhere below.
TARGETS = ['qpd', 'answer_rate', 'retention']


def evaluate(y_true, y_pred):
    """Return MAE, RMSE and R² of ``y_pred`` against ``y_true`` as a dict."""
    mae = mean_absolute_error(y_true, y_pred)
    mse = mean_squared_error(y_true, y_pred)
    r2 = r2_score(y_true, y_pred)
    return dict(mae=mae, rmse=np.sqrt(mse), r2=r2)

def print_results(name, results, targets=None):
    """Print a metrics table for one model followed by its mean R².

    Args:
        name: heading printed above the table.
        results: mapping target -> ``{'mae', 'rmse', 'r2'}``. Targets that are
            missing from the mapping are skipped (a saved model bundle may not
            cover every target), matching how the summary and save cells
            treat partial results.
        targets: targets to report, in order; defaults to the notebook-wide
            ``TARGETS`` list.
    """
    if targets is None:
        targets = TARGETS
    reported = [t for t in targets if t in results]

    print(f"\n{'='*60}")
    print(f"{name}")
    print(f"{'='*60}")
    for t in reported:
        r = results[t]
        print(f"{t:15} MAE={r['mae']:.4f}  RMSE={r['rmse']:.4f}  R²={r['r2']:.4f}")
    # Avoid np.mean([]) (warning + nan) when nothing was reported.
    mean_r2 = np.mean([results[t]['r2'] for t in reported]) if reported else float('nan')
    print(f"{'MEAN R²':15} {mean_r2:.4f}")

## 5. Load and Evaluate Models

In [None]:
all_results = {}

# Naive baseline: predict that each target stays at its value in the last input month.
naive_results = {t: evaluate(y_test[t], y_curr[t]) for t in TARGETS}
all_results['Naive'] = naive_results
print_results('Naive (predict current)', naive_results)

# Key inside the baseline pickle -> name used in the results table.
SKLEARN_MODEL_GROUPS = {
    'linear_models': 'Linear Regression',
    'rf_models': 'Random Forest',
}


def _evaluate_sklearn_group(models, features, norm_stats):
    """Score one dict of per-target sklearn models; targets without a model are skipped."""
    results = {}
    for t in TARGETS:
        if t not in models:
            continue
        pred = models[t].predict(features)
        if t in norm_stats:
            # Model was fit on a standardised target: map back to the raw scale.
            pred = pred * norm_stats[t]['std'] + norm_stats[t]['mean']
        results[t] = evaluate(y_test[t], pred)
    return results


def _graph_inputs_on(graph, target_device):
    """Return ``(x_dict, edge_index_dict)`` copies of a graph's tensors on ``target_device``.

    The model accepts this tuple form directly. Calling ``graph.to(device)``
    instead would relocate the tensors stored on the cached graph object
    itself, which keeps the whole test set resident on the GPU and breaks a
    later re-run of the numpy-based feature extraction.
    """
    x_dict = {k: v.to(target_device) for k, v in graph.x_dict.items()}
    edge_index_dict = {k: v.to(target_device) for k, v in graph.edge_index_dict.items()}
    return x_dict, edge_index_dict


for model_name, model_path in MODELS.items():
    model_path = Path(model_path)
    if not model_path.exists():
        print(f"\n⚠️  {model_name}: NOT FOUND at {model_path}")
        continue

    print(f"\n>>> Loading {model_name} from {model_path}")

    # ===== BASELINE MODELS (.pkl) =====
    if model_path.suffix == '.pkl':
        # SECURITY: pickle.load executes arbitrary code; only load files you created.
        with open(model_path, 'rb') as f:
            saved = pickle.load(f)

        scaler = saved.get('scaler')
        norm_stats = saved.get('norm_stats', {})
        X_scaled = scaler.transform(X_test) if scaler is not None else X_test

        for group_key, display_name in SKLEARN_MODEL_GROUPS.items():
            if group_key not in saved:
                continue
            results = _evaluate_sklearn_group(saved[group_key], X_scaled, norm_stats)
            if results:
                all_results[display_name] = results
                print_results(display_name, results)

    # ===== GNN MODELS (.pt) =====
    elif model_path.suffix == '.pt':
        # SECURITY: weights_only=False unpickles arbitrary objects; trusted checkpoints only.
        checkpoint = torch.load(model_path, weights_only=False, map_location=device)
        state_dict = checkpoint['model_state_dict']

        # Infer the architecture from parameter shapes / names in the state dict.
        # nn.Linear stores its weight as (out_features, in_features).
        hidden_dim, user_feat_dim = state_dict['user_proj.weight'].shape
        tag_feat_dim = state_dict['tag_proj.weight'].shape[1]
        num_conv_layers = max(int(k.split('.')[1]) for k in state_dict if k.startswith('convs.')) + 1
        num_transformer_layers = max(int(k.split('.')[2]) for k in state_dict if 'temporal_encoder.layers' in k) + 1
        transformer_ffn_dim = state_dict['temporal_encoder.layers.0.linear1.weight'].shape[0]

        print(f"  Arch: hidden={hidden_dim}, conv={num_conv_layers}, transformer={num_transformer_layers}, ffn={transformer_ffn_dim}")

        # The head count cannot be read off the parameter shapes, so it stays
        # fixed at 4 - it must match the value used for training. Dropout is
        # inactive in eval mode.
        model = TemporalCommunityGNN(
            user_feat_dim=user_feat_dim, tag_feat_dim=tag_feat_dim,
            hidden_dim=hidden_dim, num_conv_layers=num_conv_layers,
            num_transformer_layers=num_transformer_layers, num_attention_heads=4,
            dropout=0.1, transformer_ffn_dim=transformer_ffn_dim
        ).to(device)

        model.load_state_dict(state_dict)
        model.eval()

        # NOTE(review): unlike the baselines, no target de-normalisation is
        # applied here - confirm the GNN was trained on raw target values.
        preds = {t: [] for t in TARGETS}
        batch_size = 32

        with torch.no_grad():
            for i in tqdm(range(0, len(test_graphs_for_gnn), batch_size), desc=f'Eval {model_name}'):
                batch = [
                    [_graph_inputs_on(g, device) for g in seq]
                    for seq in test_graphs_for_gnn[i:i+batch_size]
                ]
                out = model(batch)
                for t in TARGETS:
                    preds[t].extend(out[t].cpu().numpy())

        results = {t: evaluate(y_test[t], np.array(preds[t])) for t in TARGETS}
        all_results[model_name] = results
        print_results(model_name, results)


Naive (predict current)
qpd             MAE=1.0394  RMSE=2.3570  R²=0.9859
answer_rate     MAE=0.1104  RMSE=0.1658  R²=-0.0680
retention       MAE=0.0903  RMSE=0.1407  R²=-0.0091
MEAN R²         0.3029

>>> Loading Baseline (Linear + RF) from ../results/baseline/baseline_models.pkl

Linear Regression
qpd             MAE=1.6566  RMSE=2.6704  R²=0.9819
answer_rate     MAE=0.1087  RMSE=0.1547  R²=0.0702
retention       MAE=0.0797  RMSE=0.1226  R²=0.2336
MEAN R²         0.4286

Random Forest
qpd             MAE=1.6494  RMSE=9.5714  R²=0.7671
answer_rate     MAE=0.1092  RMSE=0.1508  R²=0.1168
retention       MAE=0.0777  RMSE=0.1215  R²=0.2482
MEAN R²         0.3774

>>> Loading Temporal GNN from ../results/temporal_gnn.pt
  Arch: hidden=128, conv=3, transformer=3, ffn=256


Eval Temporal GNN:   0%|          | 0/63 [00:00<?, ?it/s]

## 6. Summary

In [None]:
# Build comparison table: one row per (model, target) pair that was evaluated.
rows = [
    {'Model': model_name, 'Target': t, **results[t]}
    for model_name, results in all_results.items()
    for t in TARGETS
    if t in results
]
df = pd.DataFrame(rows)

banner = "="*80

print("\n" + banner)
print("TEST SET RESULTS")
print(banner)
print(df.to_string(index=False))

print("\n" + banner)
print("MEAN R² BY MODEL")
print(banner)
# Average R² over the targets each model reported, best model first.
r2_by_model = df.groupby('Model')['r2'].mean()
mean_df = r2_by_model.sort_values(ascending=False)
print(mean_df.to_string())

In [None]:
# Save results as JSON: per-model metrics for every evaluated target plus mean R².
output = {'test_samples': len(test_dataset), 'models': {}}
for name, results in all_results.items():
    # Cast numpy scalars to plain floats so json can serialise them.
    metrics = {t: {k: float(v) for k, v in results[t].items()} for t in TARGETS if t in results}
    r2_values = [m['r2'] for m in metrics.values()]
    output['models'][name] = {
        'metrics': metrics,
        # null instead of NaN when a model reported no target (NaN is not valid JSON).
        'mean_r2': float(np.mean(r2_values)) if r2_values else None
    }

# The results directory is not guaranteed to exist (e.g. fresh Colab Drive folder).
RESULTS_DIR.mkdir(parents=True, exist_ok=True)
results_path = RESULTS_DIR / 'test_set_results.json'
with open(results_path, 'w') as f:
    json.dump(output, f, indent=2)
print(f"Saved to {results_path}")