In [None]:
# ============================================================================
# MULTI-ENCODER TRAINING NOTEBOOK
# Trains GraphSAGE models with 5 different sentence encoders
# ============================================================================

# ============================================================================
# SECTION 1: Installs and Imports
# ============================================================================
!pip install --quiet ogb torch torchvision torchaudio torch-geometric sentence-transformers tqdm scikit-learn matplotlib
!pip install torch-geometric ogb -q

In [None]:
# Download dataset
# Creates the raw/ folder under the dataset root and fetches the title/abstract
# TSV that Section 2 reads via TSV_PATH.
# NOTE(review): `wget -O` fetches and overwrites the file on every run of this cell.
!mkdir -p data/ogbn-arxiv/raw
!wget -O data/ogbn-arxiv/raw/titleabs.tsv https://snap.stanford.edu/ogb/data/misc/ogbn_arxiv/titleabs.tsv

In [None]:
import os
import json
import torch
import numpy as np
from tqdm.auto import tqdm
import torch.nn.functional as F
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import precision_score, recall_score, f1_score

# PyG / OGB
from ogb.nodeproppred import PygNodePropPredDataset
from torch_geometric.data.data import DataEdgeAttr, DataTensorAttr
from torch_geometric.data.storage import GlobalStorage
from torch_geometric.nn.models import GraphSAGE
from sentence_transformers import SentenceTransformer

print("‚úì All imports successful")

In [None]:
# ============================================================================
# SECTION 2: Dataset Preparation
# ============================================================================
print("\n" + "="*80)
print("SECTION 2: Dataset Preparation")
print("="*80)

# Instantiate the dataset once so that `num_nodes` is known; the node-id
# mapping file asserted below is expected under the same dataset root.
try:
    # `safe_globals` allow-lists the PyG classes for torch's restricted unpickler.
    with torch.serialization.safe_globals([DataEdgeAttr, DataTensorAttr, GlobalStorage]):
        dataset = PygNodePropPredDataset(name="ogbn-arxiv", root="data/ogbn-arxiv")
    num_nodes = dataset[0].num_nodes
except Exception as load_error:
    # NOTE(review): presumably a fallback for PyTorch builds without
    # `safe_globals` - confirm. `except Exception` (rather than a bare
    # `except`) lets KeyboardInterrupt abort a long download instead of
    # silently restarting it.
    print(f"Loading with safe_globals failed ({load_error!r}); retrying without it")
    dataset = PygNodePropPredDataset(name="ogbn-arxiv", root="data/ogbn-arxiv")
    num_nodes = dataset[0].num_nodes

TSV_PATH = 'data/ogbn-arxiv/raw/titleabs.tsv'
TEXTS_OUTPUT_PATH = 'data/arxiv_texts.txt'

# Load OGB ID -> MAG ID mapping
ogb_id_to_mag_id_file = 'data/ogbn-arxiv/ogbn_arxiv/mapping/nodeidx2paperid.csv.gz'
assert os.path.exists(ogb_id_to_mag_id_file), "Missing OGB mapping file"

ogb_id_map = pd.read_csv(ogb_id_to_mag_id_file)
ogb_id_map.columns = ['ogb_id', 'mag_id']

# Load raw text data
print("Loading raw text TSV...")
raw_texts_df = pd.read_csv(TSV_PATH, sep='\t', header=None,
                           names=['mag_id', 'title', 'abstract'],
                           on_bad_lines='skip')
# A repeated mag_id would multiply rows in the left merge below and shift every
# later node's text, so keep a single text row per paper.
raw_texts_df = raw_texts_df.drop_duplicates(subset='mag_id', keep='first')

# Merge and align
print("Aligning OGB Node IDs with MAG Texts...")
merged_df = pd.merge(ogb_id_map, raw_texts_df, on='mag_id', how='left')
merged_df = merged_df.sort_values(by='ogb_id')
merged_df['full_text'] = merged_df['title'].fillna('') + ' ' + merged_df['abstract'].fillna('')

# The text file stores exactly one node per line (line i <-> node i). Collapsing
# all whitespace removes embedded newlines / carriage returns that would
# otherwise split one node's text across several lines when it is read back.
texts_list = [' '.join(text.split()) for text in merged_df['full_text']]

if len(texts_list) != num_nodes:
    raise ValueError(
        f"Expected one text per node ({num_nodes}), got {len(texts_list)}; "
        "check the OGB id mapping and the raw TSV."
    )

with open(TEXTS_OUTPUT_PATH, 'w', encoding='utf-8') as f:
    f.writelines(f"{text}\n" for text in texts_list)

print(f"‚úì Created text file: {TEXTS_OUTPUT_PATH}")
print(f"‚úì Total texts: {len(texts_list)}")


In [None]:
# ============================================================================
# SECTION 3: Load Graph Structure
# ============================================================================
print("\n" + "="*80)
print("SECTION 3: Load Graph Structure")
print("="*80)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

# Load dataset (the PyG classes are allow-listed for torch's restricted unpickler)
with torch.serialization.safe_globals([DataEdgeAttr, DataTensorAttr, GlobalStorage]):
    dataset = PygNodePropPredDataset(name="ogbn-arxiv", root="data/ogbn-arxiv")
    data = dataset[0]

num_nodes = data.num_nodes
num_classes = int(dataset.num_classes)
split_idx = dataset.get_idx_split()
train_idx = split_idx["train"]
val_idx = split_idx["valid"]
test_idx = split_idx["test"]

# Create masks: boolean vectors over all nodes, True at the indices of each split
train_mask = torch.zeros(num_nodes, dtype=torch.bool)
train_mask[train_idx] = True
val_mask = torch.zeros(num_nodes, dtype=torch.bool)
val_mask[val_idx] = True
test_mask = torch.zeros(num_nodes, dtype=torch.bool)
test_mask[test_idx] = True

# Move everything the training functions index into onto the compute device
train_mask = train_mask.to(device)
val_mask = val_mask.to(device)
test_mask = test_mask.to(device)
data.edge_index = data.edge_index.to(device)
# squeeze() drops singleton dimensions from the label tensor before it is
# handed to cross_entropy in the training functions
data.y = data.y.squeeze().to(device)

print(f"‚úì Nodes: {num_nodes}, Classes: {num_classes}")
print(f"‚úì Train: {train_idx.shape[0]}, Val: {val_idx.shape[0]}, Test: {test_idx.shape[0]}")

# Load texts (one line per node, written by Section 2)
with open(TEXTS_OUTPUT_PATH, "r", encoding="utf8") as f:
    texts = [line.strip() for line in f]

# Line i is encoded as the feature vector of node i later on; a count mismatch
# would silently pair texts with the wrong nodes, so stop here instead.
if len(texts) != num_nodes:
    raise ValueError(
        f"{TEXTS_OUTPUT_PATH} has {len(texts)} lines but the graph has "
        f"{num_nodes} nodes; regenerate the text file before encoding."
    )

print(f"‚úì Loaded {len(texts)} texts")

In [None]:
# ============================================================================
# SECTION 4: Encoder Configurations
# ============================================================================
print("\n" + "="*80)
print("SECTION 4: Encoder Configurations")
print("="*80)

# Registry of sentence encoders to compare. Each entry maps a short key to the
# Hugging Face model id, the embedding width declared for it, and a description.
# Insertion order is the order in which the encoders are listed and trained.
ENCODER_CONFIGS = {
    short_name: {'model_name': hub_id, 'dim': width, 'description': blurb}
    for short_name, hub_id, width, blurb in (
        ('minilm', 'sentence-transformers/all-MiniLM-L6-v2', 384,
         'Baseline (fast, lightweight)'),
        ('mpnet', 'sentence-transformers/all-mpnet-base-v2', 768,
         'Industry standard'),
        ('e5-base', 'intfloat/e5-base-v2', 768,
         'E5: LLM-based SOTA'),
        ('e5-large-multilingual', 'intfloat/multilingual-e5-large', 1024,
         'E5: Multilingual large'),
        ('paraphrase-multilingual',
         'sentence-transformers/paraphrase-multilingual-mpnet-base-v2', 768,
         'SBERT multilingual'),
    )
}

print("Available encoders:")
for key, config in ENCODER_CONFIGS.items():
    print("  - {}: {} (dim={}) - {}".format(
        key, config['model_name'], config['dim'], config['description']))

In [None]:
# ============================================================================
# SECTION 5: Model Definition
# ============================================================================
print("\n" + "="*80)
print("SECTION 5: Model Definition")
print("="*80)

class SAGEModel(torch.nn.Module):
    """Node classifier that wraps a PyG ``GraphSAGE`` network.

    Args:
        in_channels: Width of the input node features.
        hidden_channels: Width of the hidden layer(s).
        num_classes: Number of output classes (width of the returned logits).
        num_layers: Number of GraphSAGE layers; defaults to 2.
        dropout: Dropout probability passed to GraphSAGE; defaults to 0.5.
    """

    def __init__(self, in_channels, hidden_channels, num_classes,
                 num_layers=2, dropout=0.5):
        super().__init__()
        # The attribute name is part of every state_dict key, so it must stay
        # `sage_model` for saved checkpoints to remain loadable.
        self.sage_model = GraphSAGE(
            in_channels=in_channels,
            hidden_channels=hidden_channels,
            num_layers=num_layers,
            out_channels=num_classes,
            dropout=dropout,
            act='relu'
        )

    def forward(self, x, edge_index):
        """Return the wrapped network's output for features ``x`` and ``edge_index``."""
        return self.sage_model(x, edge_index)

print("‚úì Model class defined")


In [None]:
# ============================================================================
# SECTION 6: Training Functions
# ============================================================================
print("\n" + "="*80)
print("SECTION 6: Training Functions")
print("="*80)

def train_step(model, optimizer, x, edge_index, y, train_mask):
    """Run one full-batch optimisation step and return the training loss.

    The model is put in training mode, evaluated on the whole graph, and the
    cross-entropy is taken only over the entries selected by ``train_mask``.

    Returns:
        The loss of this step as a Python float.
    """
    model.train()
    optimizer.zero_grad()

    logits = model(x, edge_index)
    masked_loss = F.cross_entropy(logits[train_mask], y[train_mask])

    masked_loss.backward()
    optimizer.step()
    return masked_loss.item()

@torch.no_grad()
def eval_epoch(model, x, edge_index, y, train_mask, val_mask, test_mask):
    """Return ``(train_acc, val_acc, test_acc)`` for the current model.

    The model is switched to eval mode and run once on the whole graph; each
    accuracy is the fraction of argmax predictions equal to ``y`` within the
    corresponding mask, as a Python float.
    """
    model.eval()
    predicted = model(x, edge_index).argmax(dim=1)

    def _split_accuracy(mask):
        # Mean of the element-wise hits inside one split.
        return (predicted[mask] == y[mask]).float().mean().item()

    return tuple(_split_accuracy(mask) for mask in (train_mask, val_mask, test_mask))

@torch.no_grad()
def compute_detailed_metrics(model, x, edge_index, y, test_mask):
    """Score the model on the entries selected by ``test_mask``.

    Returns:
        A dict with keys ``accuracy``, ``precision``, ``recall`` and ``f1``.
        Precision, recall and F1 are the scikit-learn scores with
        ``average='weighted'`` and ``zero_division=0``.
    """
    model.eval()
    test_preds = model(x, edge_index).argmax(dim=1)[test_mask]
    test_labels = y[test_mask]

    # scikit-learn works on host arrays
    labels_np = test_labels.cpu().numpy()
    preds_np = test_preds.cpu().numpy()

    scorer_options = {'average': 'weighted', 'zero_division': 0}
    metrics = {'accuracy': (test_preds == test_labels).float().mean().item()}
    for metric_name, scorer in (('precision', precision_score),
                                ('recall', recall_score),
                                ('f1', f1_score)):
        metrics[metric_name] = scorer(labels_np, preds_np, **scorer_options)
    return metrics

print("‚úì Training functions defined")

In [None]:
# ============================================================================
# SECTION 7: Main Training Loop for All Encoders
# ============================================================================
print("\n" + "="*80)
print("SECTION 7: Training All Encoders")
print("="*80)

# Create directories
for output_dir in ('embeddings', 'models', 'results', 'plots'):
    os.makedirs(output_dir, exist_ok=True)

# Every encoder's GraphSAGE model starts from the same seed so that differences
# in the comparison come from the embeddings rather than from initialisation.
SEED = 42

all_training_results = {}

for encoder_key, encoder_config in ENCODER_CONFIGS.items():
    print("\n" + "="*80)
    print(f"TRAINING WITH: {encoder_key}")
    print(f"Model: {encoder_config['model_name']}")
    print(f"Description: {encoder_config['description']}")
    print("="*80)

    # File paths
    emb_path = f"embeddings/{encoder_key}_clean.pt"
    model_path = f"models/{encoder_key}_model.pt"
    results_path = f"results/{encoder_key}_training_results.json"

    # Step 1: Create embeddings (or reuse the cached tensor from an earlier run)
    if os.path.exists(emb_path):
        print(f"‚úì Loading cached embeddings from {emb_path}")
        node_embeddings = torch.load(emb_path, map_location=device)
    else:
        # Row i of the embedding matrix becomes the feature vector of node i,
        # so the text list must line up with the graph before anything is cached.
        if len(texts) != num_nodes:
            raise ValueError(
                f"Have {len(texts)} texts for {num_nodes} nodes; "
                "texts and graph are misaligned."
            )

        print(f"Creating embeddings with {encoder_key}...")
        encoder = SentenceTransformer(encoder_config['model_name'], device=device)
        encoder.eval()

        # Handle E5 models (require "query: " prefix)
        texts_to_encode = texts
        if 'e5' in encoder_config['model_name'].lower():
            print("  (Adding 'query: ' prefix for E5 model)")
            texts_to_encode = [f"query: {text}" for text in texts]

        batch_size = 64
        all_embs = []
        for start in tqdm(range(0, num_nodes, batch_size), desc="Encoding"):
            batch_texts = texts_to_encode[start:start+batch_size]
            with torch.no_grad():
                embs = encoder.encode(
                    batch_texts,
                    convert_to_tensor=True,
                    batch_size=batch_size,
                    show_progress_bar=False
                )
            all_embs.append(embs)

        node_embeddings = torch.cat(all_embs, dim=0)
        torch.save(node_embeddings.cpu(), emb_path)
        print(f"‚úì Saved embeddings to {emb_path}")

        # The encoder and the per-batch tensors are no longer needed; release
        # them so they do not occupy device memory during GraphSAGE training.
        del encoder, all_embs, embs
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    node_embeddings = node_embeddings.to(device)
    print(f"Embeddings shape: {node_embeddings.shape}")

    # Catches a stale or truncated cache file before it is used as features.
    if node_embeddings.size(0) != num_nodes:
        raise ValueError(
            f"{emb_path} holds {node_embeddings.size(0)} rows but the graph has "
            f"{num_nodes} nodes; delete the file to regenerate the embeddings."
        )

    # Step 2: Initialize model
    torch.manual_seed(SEED)
    model = SAGEModel(
        in_channels=node_embeddings.size(1),
        hidden_channels=128,
        num_classes=num_classes
    ).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

    # Step 3: Training loop
    print(f"\nTraining model for {encoder_key}...")

    loss_history = []
    train_acc_history = []
    val_acc_history = []
    test_acc_history = []

    # Start below any attainable accuracy so the first epoch always writes a
    # checkpoint; otherwise Step 4 could load a stale file from an earlier run.
    best_val = -1.0
    best_test_for_best_val = 0.0
    patience = 0
    max_patience = 20

    for epoch in range(1, 201):
        loss = train_step(model, optimizer, node_embeddings, data.edge_index,
                         data.y, train_mask)
        train_acc, val_acc, test_acc = eval_epoch(model, node_embeddings,
                                                   data.edge_index, data.y,
                                                   train_mask, val_mask, test_mask)

        loss_history.append(loss)
        train_acc_history.append(train_acc)
        val_acc_history.append(val_acc)
        test_acc_history.append(test_acc)

        # Checkpoint on every strict improvement of validation accuracy
        if val_acc > best_val:
            best_val = val_acc
            best_test_for_best_val = test_acc
            torch.save(model.state_dict(), model_path)
            patience = 0
        else:
            patience += 1

        # Stop once more than `max_patience` consecutive epochs failed to improve
        if patience > max_patience:
            print(f"  Early stopping at epoch {epoch}")
            break

        if epoch % 10 == 0:
            print(f"  Epoch {epoch:03d} | Loss {loss:.4f} | Train {train_acc:.4f} | "
                  f"Val {val_acc:.4f} | Test {test_acc:.4f}")

    # Step 4: Load best model and compute final metrics
    print(f"\nLoading best model for {encoder_key}...")
    model.load_state_dict(torch.load(model_path, map_location=device))

    final_metrics = compute_detailed_metrics(model, node_embeddings, data.edge_index,
                                            data.y, test_mask)

    print(f"\n{'='*50}")
    print(f"FINAL RESULTS FOR {encoder_key.upper()}")
    print(f"{'='*50}")
    print(f"Best Val Accuracy:  {best_val:.4f}")
    print(f"Test Accuracy:      {final_metrics['accuracy']:.4f}")
    print(f"Test Precision:     {final_metrics['precision']:.4f}")
    print(f"Test Recall:        {final_metrics['recall']:.4f}")
    print(f"Test F1 Score:      {final_metrics['f1']:.4f}")
    print(f"{'='*50}\n")

    # Step 5: Save results
    # The test_* values come from the best checkpoint; the final_* values are
    # the accuracies of the last epoch that ran.
    training_results = {
        'encoder': encoder_key,
        'model_name': encoder_config['model_name'],
        'embedding_dim': encoder_config['dim'],
        'best_val_accuracy': best_val,
        'test_accuracy': final_metrics['accuracy'],
        'test_precision': final_metrics['precision'],
        'test_recall': final_metrics['recall'],
        'test_f1': final_metrics['f1'],
        'training_epochs': len(loss_history),
        'final_train_accuracy': train_acc_history[-1],
        'final_val_accuracy': val_acc_history[-1]
    }

    with open(results_path, 'w') as f:
        json.dump(training_results, f, indent=2)

    all_training_results[encoder_key] = training_results

    # Step 6: Create plots
    print(f"Creating plots for {encoder_key}...")

    # Loss curve
    plt.figure(figsize=(10, 5))
    plt.plot(loss_history, label='Training Loss', color='red', linewidth=2)
    plt.title(f'Training Loss - {encoder_key}', fontsize=14, fontweight='bold')
    plt.xlabel('Epoch', fontsize=12)
    plt.ylabel('Loss', fontsize=12)
    plt.legend(fontsize=10)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(f'plots/{encoder_key}_training_loss.png', dpi=300)
    plt.close()

    # Accuracy curves
    plt.figure(figsize=(10, 5))
    plt.plot(train_acc_history, label='Train Acc', color='blue', linewidth=2)
    plt.plot(val_acc_history, label='Val Acc', color='orange', linewidth=2)
    plt.plot(test_acc_history, label='Test Acc', color='green', linewidth=2)
    plt.title(f'Accuracy Curves - {encoder_key}', fontsize=14, fontweight='bold')
    plt.xlabel('Epoch', fontsize=12)
    plt.ylabel('Accuracy', fontsize=12)
    plt.legend(fontsize=10)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(f'plots/{encoder_key}_accuracy_curves.png', dpi=300)
    plt.close()

    print(f"‚úì Plots saved to plots/{encoder_key}_*.png")

In [None]:
# ============================================================================
# SECTION 8: Comparison Summary
# ============================================================================
print("\n" + "="*80)
print("SECTION 8: Training Summary Across All Encoders")
print("="*80)

if not all_training_results:
    raise RuntimeError("No training results found - run Section 7 before the summary.")

# Summarise only the encoders that actually finished training; indexing by
# ENCODER_CONFIGS would raise KeyError if Section 7 stopped part-way through.
trained_encoders = list(all_training_results)
num_trained = len(trained_encoders)

# Create comparison DataFrame (one row per encoder)
comparison_df = pd.DataFrame(all_training_results).T
comparison_df = comparison_df[['model_name', 'embedding_dim', 'best_val_accuracy',
                                'test_accuracy', 'test_precision', 'test_recall', 'test_f1']]

print("\nüìä PERFORMANCE COMPARISON:")
print(comparison_df.to_string())

# Save comparison
comparison_df.to_csv('results/encoder_comparison.csv')
print("\n‚úì Comparison saved to results/encoder_comparison.csv")

# Create comparison bar chart: one panel per test metric on a 2x2 grid
fig, axes = plt.subplots(2, 2, figsize=(15, 12))

metrics = ['test_accuracy', 'test_precision', 'test_recall', 'test_f1']
titles = ['Test Accuracy', 'Test Precision', 'Test Recall', 'Test F1 Score']
colors = ['#3498db', '#e74c3c', '#2ecc71', '#f39c12']

for idx, (metric, title, color) in enumerate(zip(metrics, titles, colors)):
    ax = axes[idx // 2, idx % 2]
    values = [all_training_results[key][metric] for key in trained_encoders]
    bars = ax.bar(range(num_trained), values, color=color, alpha=0.7, edgecolor='black')
    ax.set_xticks(range(num_trained))
    ax.set_xticklabels(trained_encoders, rotation=45, ha='right')
    ax.set_ylabel(title, fontsize=11, fontweight='bold')
    # Zoom the y-axis around the observed values so small gaps stay visible
    ax.set_ylim([min(values) - 0.05, max(values) + 0.05])
    ax.grid(True, alpha=0.3, axis='y')

    # Add value labels on bars
    for bar, val in zip(bars, values):
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width()/2., height,
                f'{val:.4f}', ha='center', va='bottom', fontsize=9)

plt.suptitle('Encoder Performance Comparison', fontsize=16, fontweight='bold', y=0.995)
plt.tight_layout()
plt.savefig('plots/encoder_comparison.png', dpi=300, bbox_inches='tight')
plt.close(fig)

print("‚úì Comparison plot saved to plots/encoder_comparison.png")

# Find best encoder (ranked by test accuracy)
best_encoder = max(all_training_results.items(),
                   key=lambda x: x[1]['test_accuracy'])

print(f"\nüèÜ BEST PERFORMING ENCODER:")
print(f"   {best_encoder[0].upper()}")
print(f"   Test Accuracy: {best_encoder[1]['test_accuracy']:.4f}")
print(f"   Test F1:       {best_encoder[1]['test_f1']:.4f}")

print("\n" + "="*80)
print("‚úÖ TRAINING COMPLETE FOR ALL ENCODERS!")
print("="*80)
print("\nGenerated files:")
# Counts reflect the encoders that produced results in this session
print(f"  üìÅ embeddings/  - {num_trained} embedding files")
print(f"  üìÅ models/      - {num_trained} model checkpoints")
print(f"  üìÅ results/     - {num_trained} result JSON files + comparison CSV")
print(f"  üìÅ plots/       - {num_trained*2 + 1} plot images")

In [None]:
# Archive the models directory into models.zip.
# NOTE(review): the absolute /content paths only match the relative 'models'
# directory created in Section 7 when the working directory is /content - confirm.
!zip -r /content/models.zip /content/models


In [None]:
# Archive the plots directory into plots.zip.
# NOTE(review): the absolute /content paths only match the relative 'plots'
# directory created in Section 7 when the working directory is /content - confirm.
!zip -r /content/plots.zip /content/plots

In [None]:
# Archive the results directory into results.zip.
# NOTE(review): the absolute /content paths only match the relative 'results'
# directory created in Section 7 when the working directory is /content - confirm.
!zip -r /content/results.zip /content/results

In [None]:
# Archive the embeddings directory into embeddings.zip.
# NOTE(review): the absolute /content paths only match the relative 'embeddings'
# directory created in Section 7 when the working directory is /content - confirm.
!zip -r /content/embeddings.zip /content/embeddings