# GIN Link Prediction - Multi-Dataset Testing

Test a GIN model on multiple SemanticGraph datasets from the Science4Cast competition.


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pickle
import random
import time
import os
from datetime import date
import matplotlib.pyplot as plt
import pandas as pd

# Report the PyTorch version and CUDA availability so the notebook output
# records the environment in which the results below were produced.
print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    # Only device index 0 is queried; other GPUs, if any, are not listed.
    print(f"GPU: {torch.cuda.get_device_name(0)}")


In [None]:
# ============================================
# CONFIGURATION
# ============================================

# Kaggle dataset path (folder that contains the pickled dataset files)
DATASETS_FOLDER = "/kaggle/input/science4cast/Science4Cast_18datasets"

# Datasets to test (first 5)
# NOTE(review): the file names appear to encode the delta / cutoff / minedge
# settings that are also unpacked from each pickle in the training cell —
# confirm against the dataset documentation.
DATASET_FILES = [
    "SemanticGraph_delta_1_cutoff_25_minedge_1.pkl",
    "SemanticGraph_delta_1_cutoff_25_minedge_3.pkl", 
    "SemanticGraph_delta_1_cutoff_5_minedge_1.pkl",
    "SemanticGraph_delta_3_cutoff_25_minedge_1.pkl",
    "SemanticGraph_delta_3_cutoff_25_minedge_3.pkl",
]

# Model hyperparameters
HIDDEN_DIM = 32    # width of every GIN layer output (node embedding size)
NUM_LAYERS = 2     # number of stacked GIN layers in the encoder
EDGE_HIDDEN = 32   # hidden width of the pair-scoring decoder MLP
DROPOUT = 0.5      # dropout probability applied between GIN layers

# Training hyperparameters
EPOCHS = 10           # passes over the training pairs per dataset
BATCH_SIZE = 1024     # node pairs per optimisation step / evaluation chunk
LEARNING_RATE = 1e-3  # Adam learning rate
WEIGHT_DECAY = 1e-5   # Adam weight decay

# Other
TRAIN_RATIO = 0.9  # fraction of positive and of negative pairs used for training
SEED = 42          # seed used for RNG initialisation and for the train/test split
# Number of rows of the node feature matrix built in the training cell.
# NOTE(review): presumably the vertex count of the semantic graph; it must be
# larger than every node id occurring in the edges and pairs — confirm.
NUM_OF_VERTICES = 64719

# Device: use the GPU when one is available, otherwise fall back to the CPU
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {DEVICE}")
print(f"Datasets: {len(DATASET_FILES)}")


In [None]:
# ============================================
# MODEL DEFINITION
# ============================================

class GINConvLayer(nn.Module):
    """Single GIN convolution: ``MLP((1 + eps) * x_i + sum_j x_j)``.

    Args:
        in_dim: Size of each input node feature vector.
        out_dim: Size of each output node feature vector.
        eps: Initial weight of a node's own features relative to the sum of
            its neighbours' features.
        train_eps: If true, ``eps`` is stored as a learnable parameter;
            otherwise it is kept as the plain value passed in.
    """

    def __init__(self, in_dim, out_dim, eps=0.0, train_eps=True):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, out_dim),
            nn.BatchNorm1d(out_dim),
            nn.ReLU(),
            nn.Linear(out_dim, out_dim),
            nn.BatchNorm1d(out_dim),
        )
        self.eps = nn.Parameter(torch.tensor([eps])) if train_eps else eps

    def forward(self, x, edge_index):
        """Aggregate neighbour features and transform them with the MLP.

        Args:
            x: Node feature matrix; row ``i`` holds the features of node ``i``.
            edge_index: Two-row index tensor. For every column ``(r, c)`` the
                features of node ``c`` are added into the aggregate of
                node ``r``.

        Returns:
            The MLP output for every node (one row per row of ``x``).
        """
        row, col = edge_index
        # zeros_like keeps the dtype (and device) of x, so index_add_ also
        # works when the model runs in float64 or half precision.
        agg = torch.zeros_like(x)
        agg.index_add_(0, row, x[col])
        return self.mlp((1 + self.eps) * x + agg)

class GINEncoder(nn.Module):
    """Stack of ``GINConvLayer`` modules that turns node features into embeddings.

    Args:
        in_dim: Feature size consumed by the first layer.
        hidden_dim: Output size of every layer (and input size of all layers
            after the first).
        num_layers: Number of GIN layers to stack.
        dropout: Dropout probability used between layers.
    """

    def __init__(self, in_dim, hidden_dim=64, num_layers=3, dropout=0.5):
        super().__init__()
        self.num_layers = num_layers
        self.dropout = dropout
        layers = []
        for depth in range(num_layers):
            # Only the first layer sees the raw input width.
            width_in = in_dim if depth == 0 else hidden_dim
            layers.append(GINConvLayer(width_in, hidden_dim))
        self.convs = nn.ModuleList(layers)

    def forward(self, x, edge_index):
        """Run all layers; dropout followed by ReLU is applied after every layer but the last."""
        last = self.num_layers - 1
        hidden = x
        for depth, layer in enumerate(self.convs):
            hidden = layer(hidden, edge_index)
            if depth >= last:
                continue
            dropped = F.dropout(hidden, self.dropout, self.training)
            hidden = F.relu(dropped)
        return hidden

class GINLinkPredictor(nn.Module):
    """GIN encoder plus an MLP decoder that scores node pairs.

    Args:
        in_dim: Node feature size fed to the encoder.
        hidden_dim: Embedding size produced by the encoder.
        num_layers: Number of GIN layers in the encoder.
        edge_hidden: Hidden width of the decoder MLP.
        dropout: Dropout probability used inside the encoder.
    """

    def __init__(self, in_dim, hidden_dim=64, num_layers=3, edge_hidden=64, dropout=0.5):
        super().__init__()
        self.encoder = GINEncoder(in_dim, hidden_dim, num_layers, dropout)
        # The decoder consumes the two endpoint embeddings concatenated.
        self.decoder = nn.Sequential(
            nn.Linear(2 * hidden_dim, edge_hidden),
            nn.ReLU(),
            nn.Linear(edge_hidden, 1),
        )

    def forward(self, x, edge_index, pairs):
        """Return one raw score (logit) per row of ``pairs``.

        ``pairs[:, 0]`` and ``pairs[:, 1]`` index the rows of the node
        embedding matrix produced by the encoder.
        """
        embeddings = self.encoder(x, edge_index)
        first = embeddings[pairs[:, 0]]
        second = embeddings[pairs[:, 1]]
        scores = self.decoder(torch.cat([first, second], dim=1))
        return scores.squeeze(-1)

# Progress marker for the notebook log: printed once the model classes of
# this cell have been defined.
print("Model defined!")


In [None]:
# ============================================
# UTILITY FUNCTIONS
# ============================================

def set_seed(seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    When CUDA is available, all CUDA devices are seeded as well.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def load_data(path):
    """Read the pickle file at ``path`` and return the object stored in it.

    NOTE(review): unpickling can execute arbitrary code, so this must only be
    pointed at files from a trusted source.
    """
    with open(path, "rb") as handle:
        payload = pickle.load(handle)
    return payload

def build_edge_index(edges, year_start):
    """Build a symmetric edge index from all edges dated before a cutoff.

    Args:
        edges: 2-D array-like whose first two columns are the edge endpoints
            and whose third column is compared against the cutoff, which is
            a number of days counted from 1990-01-01.
        year_start: Year whose 31 December defines the cutoff; only edges
            with a third-column value strictly below it are kept.

    Returns:
        A contiguous ``torch.long`` tensor of shape ``(2, 2 * E)`` holding
        the ``E`` kept edges in their original direction followed by the
        same edges reversed. With no kept edges the shape is ``(2, 0)``.
    """
    cutoff = (date(year_start, 12, 31) - date(1990, 1, 1)).days
    edges = np.asarray(edges)
    kept = edges[edges[:, 2] < cutoff][:, :2].astype(np.int64)
    src, dst = kept[:, 0], kept[:, 1]
    # Forward edges first, then the reversed copies; built with NumPy so no
    # per-edge Python objects are created and an empty selection still
    # yields two rows.
    both_directions = np.stack(
        (np.concatenate((src, dst)), np.concatenate((dst, src)))
    )
    return torch.from_numpy(both_directions).contiguous()

def split_data(pairs, labels, ratio=0.9, seed=42):
    """Split pairs into train and test parts, stratified by label.

    Positive (label 1) and negative (label 0) examples are shuffled and cut
    separately, so both parts keep roughly the original class balance.

    Args:
        pairs: NumPy array indexed along its first axis by example.
        labels: NumPy array of 0/1 labels aligned with ``pairs``.
        ratio: Fraction of each class that goes into the training part.
        seed: Seed of the private random generator used for shuffling.

    Returns:
        ``(train_pairs, train_labels, test_pairs, test_labels)``.
    """
    # A private RandomState yields the same shuffles as seeding the global
    # NumPy generator with `seed`, without resetting the global random state
    # for the rest of the program.
    rng = np.random.RandomState(seed)
    pos = np.where(labels == 1)[0]
    neg = np.where(labels == 0)[0]
    rng.shuffle(pos)
    rng.shuffle(neg)
    p_tr, n_tr = int(len(pos) * ratio), int(len(neg) * ratio)
    tr_idx = np.concatenate([pos[:p_tr], neg[:n_tr]])
    te_idx = np.concatenate([pos[p_tr:], neg[n_tr:]])
    # Mix the classes again so batches are not all-positive then all-negative.
    rng.shuffle(tr_idx)
    rng.shuffle(te_idx)
    return pairs[tr_idx], labels[tr_idx], pairs[te_idx], labels[te_idx]

def compute_auc(pred, labels):
    """Compute the ROC AUC of scores against binary labels.

    Examples are ranked by descending score. For every positive, the number
    of negatives ranked above it is counted; the AUC is one minus the
    fraction of (positive, negative) combinations ordered that way.

    Args:
        pred: Scores, higher meaning "more likely positive" (array-like).
        labels: 0/1 labels aligned with ``pred`` (array-like).

    Returns:
        The AUC as a float, or 0.5 when only one class is present.
    """
    pred = np.asarray(pred)
    labels = np.asarray(labels)
    # A stable sort makes the result deterministic when scores are tied.
    order = np.argsort(-pred, kind="stable")
    ranked = labels[order]
    n_pos = ranked.sum()
    n_neg = len(ranked) - n_pos
    if n_pos == 0 or n_neg == 0:
        return 0.5
    is_neg = ranked == 0
    # negatives_above[i] = number of negatives at rank <= i; at a positive's
    # position this is exactly the number of negatives ranked above it.
    negatives_above = np.cumsum(is_neg)
    rank_sum = negatives_above[~is_neg].sum()
    # Go through floats so the product of the class counts cannot overflow.
    return 1.0 - float(rank_sum) / (float(n_pos) * float(n_neg))

def train_epoch(model, x, edge_idx, pairs, labels, opt, crit, bs, dev):
    """Train ``model`` for one pass over ``pairs`` and return the mean batch loss.

    Args:
        model: Module exposing ``encoder(x, edge_idx)`` and ``decoder(features)``.
        x: Node feature matrix passed to the encoder.
        edge_idx: Edge index passed to the encoder.
        pairs: Tensor of node-index pairs; columns 0 and 1 index the encoder
            output.
        labels: Target per pair; converted to float for the loss.
        opt: Optimizer stepped once per mini-batch.
        crit: Loss called as ``crit(logits, targets)``.
        bs: Mini-batch size.
        dev: Device on which the shuffling permutation is created.

    Returns:
        Mean of the per-batch losses as a float, or NaN when ``pairs`` is
        empty and no batch was processed.
    """
    model.train()
    perm = torch.randperm(len(pairs), device=dev)
    pairs, labels = pairs[perm], labels[perm]
    loss_sum, n_batches = 0.0, 0
    for start in range(0, len(pairs), bs):
        bp = pairs[start:start + bs]
        bl = labels[start:start + bs].float()
        opt.zero_grad()
        # The embeddings are recomputed with gradients enabled for every
        # batch. Computing them once under torch.no_grad() would cut the
        # encoder out of the autograd graph, so only the decoder would ever
        # be updated. This costs one full encoder pass per step.
        z = model.encoder(x, edge_idx)
        logits = model.decoder(torch.cat([z[bp[:, 0]], z[bp[:, 1]]], dim=1)).squeeze(-1)
        loss = crit(logits, bl)
        loss.backward()
        opt.step()
        loss_sum += loss.item()
        n_batches += 1
    if n_batches == 0:
        # Nothing to train on; avoid a ZeroDivisionError.
        return float("nan")
    return loss_sum / n_batches

@torch.no_grad()
def evaluate(model, x, edge_idx, pairs, labels, bs, dev):
    """Score ``pairs`` in eval mode and return the AUC against ``labels``.

    The encoder is run once; the decoder is then applied to the pairs in
    chunks of ``bs`` and the collected scores are passed to ``compute_auc``.
    ``dev`` is accepted for signature symmetry with ``train_epoch`` but is
    not used here.
    """
    model.eval()
    embeddings = model.encoder(x, edge_idx)
    chunks = []
    for start in range(0, len(pairs), bs):
        batch = pairs[start:start + bs]
        left = embeddings[batch[:, 0]]
        right = embeddings[batch[:, 1]]
        scores = model.decoder(torch.cat([left, right], dim=1)).squeeze(-1)
        chunks.append(scores.cpu())
    all_scores = torch.cat(chunks).numpy()
    targets = labels.cpu().numpy()
    return compute_auc(all_scores, targets)

# Progress marker for the notebook log: printed once the utility functions
# of this cell have been defined.
print("Functions defined!")


In [None]:
# ============================================
# TRAIN ON MULTIPLE DATASETS
# ============================================

# One result record per dataset, in DATASET_FILES order. Failed datasets get
# a record whose metric fields are None.
results = []

for dataset_file in DATASET_FILES:
    print("\n" + "="*70)
    print(f"DATASET: {dataset_file}")
    print("="*70)

    # Re-seed for every dataset so that a dataset's result does not depend on
    # which (or how many) datasets were processed before it.
    set_seed(SEED)

    path = os.path.join(DATASETS_FOLDER, dataset_file)

    try:
        # Load (pickle: only use dataset files from a trusted source)
        data = load_data(path)
        edges, pairs, labels, year, delta, cutoff, min_e = data
        print(f"  delta={delta}, cutoff={cutoff}, min_edges={min_e}")
        print(f"  pairs={len(pairs)}, positive={labels.sum()}")

        # Build graph
        edge_idx = build_edge_index(edges, year).to(DEVICE)
        print(f"  edges={edge_idx.size(1)}")

        # Split
        tr_p, tr_l, te_p, te_l = split_data(pairs, labels, TRAIN_RATIO, SEED)
        tr_p = torch.tensor(tr_p, dtype=torch.long, device=DEVICE)
        tr_l = torch.tensor(tr_l, dtype=torch.long, device=DEVICE)
        te_p = torch.tensor(te_p, dtype=torch.long, device=DEVICE)
        te_l = torch.tensor(te_l, dtype=torch.long, device=DEVICE)

        # Features: every node gets the same constant 1-d feature, so the
        # encoder can only pick up information from the graph structure.
        x = torch.ones(NUM_OF_VERTICES, 1, device=DEVICE)

        # Model
        model = GINLinkPredictor(1, HIDDEN_DIM, NUM_LAYERS, EDGE_HIDDEN, DROPOUT).to(DEVICE)
        opt = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
        crit = nn.BCEWithLogitsLoss()

        # Train
        # NOTE(review): best_auc is the maximum over epochs measured on the
        # test split itself, so it is an optimistic number; final_auc is the
        # score of the model after the last epoch.
        best_auc = 0
        t0 = time.time()
        for ep in range(1, EPOCHS+1):
            loss = train_epoch(model, x, edge_idx, tr_p, tr_l, opt, crit, BATCH_SIZE, DEVICE)
            auc = evaluate(model, x, edge_idx, te_p, te_l, BATCH_SIZE, DEVICE)
            best_auc = max(best_auc, auc)
            if ep % 5 == 0 or ep == 1:
                print(f"  Epoch {ep:2d}: loss={loss:.4f}, AUC={auc:.4f}")

        final_auc = evaluate(model, x, edge_idx, te_p, te_l, BATCH_SIZE, DEVICE)
        elapsed = time.time() - t0

        results.append({
            'dataset': dataset_file, 'delta': delta, 'cutoff': cutoff,
            'min_edges': min_e, 'best_auc': best_auc, 'final_auc': final_auc, 'time': elapsed
        })
        print(f"\n  ✓ FINAL AUC: {final_auc:.4f} (best: {best_auc:.4f}) in {elapsed:.1f}s")

        # Cleanup: drop the per-dataset tensors before the next dataset is
        # loaded to keep peak (GPU) memory down.
        del model, opt, tr_p, tr_l, te_p, te_l, edge_idx, x
        torch.cuda.empty_cache()

    except Exception as e:
        # Best effort: record the failure and continue with the next dataset.
        # The exception type is printed too, because some messages (for
        # example a KeyError's) are not meaningful on their own.
        print(f"  ✗ ERROR: {type(e).__name__}: {e}")
        results.append({'dataset': dataset_file, 'delta': None, 'cutoff': None,
                       'min_edges': None, 'best_auc': None, 'final_auc': None, 'time': None})

print("\n" + "="*70)
print("DONE!")
print("="*70)


In [None]:
# ============================================
# RESULTS SUMMARY
# ============================================

# Tabulate every run, including failed ones (their metric columns are None).
df = pd.DataFrame(results)
print("\n" + "="*70)
print("SUMMARY RESULTS")
print("="*70)
print(df.to_string(index=False))

# Stats
# Failed runs are marked with final_auc=None. Compare against None explicitly:
# a truthiness test would also drop a legitimate AUC of exactly 0.0.
valid = [r['final_auc'] for r in results if r['final_auc'] is not None]
if valid:
    print(f"\nAverage AUC: {np.mean(valid):.4f}")
    print(f"Min AUC: {np.min(valid):.4f}")
    print(f"Max AUC: {np.max(valid):.4f}")

# Plot (only worthwhile when there is more than one bar to compare)
if len(valid) > 1:
    fig, ax = plt.subplots(figsize=(10, 5))
    # Same filter as for `valid`, so labels and bars stay aligned.
    names = [r['dataset'].replace('SemanticGraph_', '').replace('.pkl', '')
             for r in results if r['final_auc'] is not None]
    ax.bar(range(len(valid)), valid, color='steelblue')
    ax.set_xticks(range(len(valid)))
    ax.set_xticklabels(names, rotation=45, ha='right')
    ax.set_ylabel('AUC')
    ax.set_title('GIN Link Prediction Results')
    # Reference line: an AUC of 0.5 corresponds to random ranking.
    ax.axhline(0.5, color='red', linestyle='--', label='Random')
    ax.set_ylim(0.4, 1.0)
    ax.legend()
    plt.tight_layout()
    plt.show()
