In [None]:
# ============================================================================
# MULTI-ENCODER TRAINING NOTEBOOK
# Trains GraphSAGE models with 5 different sentence encoders
# ============================================================================

# ============================================================================
# SECTION 1: Installs and Imports
# ============================================================================
!pip install --quiet ogb torch torchvision torchaudio torch-geometric sentence-transformers tqdm scikit-learn matplotlib
!pip install torch-geometric ogb -q

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.7/63.7 kB[0m [31m5.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m78.8/78.8 kB[0m [31m7.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m68.0 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
# Download dataset
# Fetches the raw title/abstract TSV into the dataset's raw/ folder; the
# dataset-preparation cell reads it back through TSV_PATH.
# NOTE(review): `wget -O` overwrites the target, so the file (~200 MB per the
# logged Length) is downloaded again on every run of this cell.
!mkdir -p data/ogbn-arxiv/raw
!wget -O data/ogbn-arxiv/raw/titleabs.tsv https://snap.stanford.edu/ogb/data/misc/ogbn_arxiv/titleabs.tsv

--2025-12-12 19:06:12--  https://snap.stanford.edu/ogb/data/misc/ogbn_arxiv/titleabs.tsv
Resolving snap.stanford.edu (snap.stanford.edu)... 171.64.75.80
Connecting to snap.stanford.edu (snap.stanford.edu)|171.64.75.80|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 209701602 (200M) [text/tab-separated-values]
Saving to: ‘data/ogbn-arxiv/raw/titleabs.tsv’


2025-12-12 19:06:28 (12.8 MB/s) - ‘data/ogbn-arxiv/raw/titleabs.tsv’ saved [209701602/209701602]



In [None]:
import os
import json
import torch
import numpy as np
from tqdm.auto import tqdm
import torch.nn.functional as F
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import precision_score, recall_score, f1_score

# PyG / OGB
from ogb.nodeproppred import PygNodePropPredDataset
from torch_geometric.data.data import DataEdgeAttr, DataTensorAttr
from torch_geometric.data.storage import GlobalStorage
from torch_geometric.nn.models import GraphSAGE
from sentence_transformers import SentenceTransformer

print("✓ All imports successful")

✓ All imports successful


In [None]:
# ============================================================================
# SECTION 2: Dataset Preparation
# ============================================================================
# Builds TEXTS_OUTPUT_PATH: one line of "title abstract" text per OGB node,
# written in ascending node-index order so that line i belongs to node i.
print("\n" + "="*80)
print("SECTION 2: Dataset Preparation")
print("="*80)

# Load the dataset with the PyG storage classes allow-listed for torch's
# weights-only unpickler, and retry with a plain load if that path fails
# (e.g. torch builds that do not provide `torch.serialization.safe_globals`).
# `except Exception` keeps the best-effort retry of the former bare `except:`
# but no longer swallows KeyboardInterrupt / SystemExit, and reports the cause.
try:
    with torch.serialization.safe_globals([DataEdgeAttr, DataTensorAttr, GlobalStorage]):
        dataset = PygNodePropPredDataset(name="ogbn-arxiv", root="data/ogbn-arxiv")
    num_nodes = dataset[0].num_nodes
except Exception as load_error:
    print(f"safe_globals load failed ({load_error!r}); retrying with a plain load")
    dataset = PygNodePropPredDataset(name="ogbn-arxiv", root="data/ogbn-arxiv")
    num_nodes = dataset[0].num_nodes

TSV_PATH = 'data/ogbn-arxiv/raw/titleabs.tsv'
TEXTS_OUTPUT_PATH = 'data/arxiv_texts.txt'

# Load OGB ID -> MAG ID mapping
ogb_id_to_mag_id_file = 'data/ogbn-arxiv/ogbn_arxiv/mapping/nodeidx2paperid.csv.gz'
assert os.path.exists(ogb_id_to_mag_id_file), "Missing OGB mapping file"

ogb_id_map = pd.read_csv(ogb_id_to_mag_id_file)
ogb_id_map.columns = ['ogb_id', 'mag_id']

# Load raw text data
print("Loading raw text TSV...")
raw_texts_df = pd.read_csv(TSV_PATH, sep='\t', header=None,
                           names=['mag_id', 'title', 'abstract'],
                           on_bad_lines='skip')

# Merge and align: the left join keeps every node even when its text is missing
# (missing title/abstract become empty strings below).
print("Aligning OGB Node IDs with MAG Texts...")
merged_df = pd.merge(ogb_id_map, raw_texts_df, on='mag_id', how='left')
merged_df = merged_df.sort_values(by='ogb_id')
merged_df['full_text'] = merged_df['title'].fillna('') + ' ' + merged_df['abstract'].fillna('')

texts_list = merged_df['full_text'].tolist()

# Duplicate mag_id rows in the TSV would inflate the left merge and shift every
# later text onto the wrong node, so fail loudly instead of writing a bad file.
assert len(texts_list) == num_nodes, (
    f"Expected one text per node ({num_nodes}), got {len(texts_list)} rows after the merge"
)

with open(TEXTS_OUTPUT_PATH, 'w', encoding='utf-8') as f:
    for text in texts_list:
        # The file is read back line by line, so an embedded line break would
        # split one node's text over several lines; flatten CR/LF to spaces.
        f.write(text.replace('\r', ' ').replace('\n', ' ').strip() + '\n')

print(f"✓ Created text file: {TEXTS_OUTPUT_PATH}")
print(f"✓ Total texts: {len(texts_list)}")



SECTION 2: Dataset Preparation
Downloading http://snap.stanford.edu/ogb/data/nodeproppred/arxiv.zip


Downloaded 0.08 GB: 100%|██████████| 81/81 [00:08<00:00,  9.08it/s]
Processing...


Extracting data/ogbn-arxiv/arxiv.zip
Loading necessary files...
This might take a while.
Processing graphs...


100%|██████████| 1/1 [00:00<00:00, 11459.85it/s]


Converting graphs into PyG objects...


100%|██████████| 1/1 [00:00<00:00, 1293.74it/s]

Saving...



Done!


Loading raw text TSV...
Aligning OGB Node IDs with MAG Texts...
✓ Created text file: data/arxiv_texts.txt
✓ Total texts: 169343


In [None]:
# ============================================================================
# SECTION 3: Load Graph Structure
# ============================================================================
# Loads the graph, builds boolean split masks on `device`, and reads the
# per-node texts written to TEXTS_OUTPUT_PATH.
print("\n" + "="*80)
print("SECTION 3: Load Graph Structure")
print("="*80)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

# Load dataset. Same fallback as the dataset-preparation cell: retry with a
# plain load when the `safe_globals` path is unavailable or fails, and report why.
try:
    with torch.serialization.safe_globals([DataEdgeAttr, DataTensorAttr, GlobalStorage]):
        dataset = PygNodePropPredDataset(name="ogbn-arxiv", root="data/ogbn-arxiv")
        data = dataset[0]
except Exception as load_error:
    print(f"safe_globals load failed ({load_error!r}); retrying with a plain load")
    dataset = PygNodePropPredDataset(name="ogbn-arxiv", root="data/ogbn-arxiv")
    data = dataset[0]

num_nodes = data.num_nodes
num_classes = int(dataset.num_classes)
split_idx = dataset.get_idx_split()
train_idx = split_idx["train"]
val_idx = split_idx["valid"]
test_idx = split_idx["test"]

# Create masks: one boolean entry per node, True for the nodes of that split
train_mask = torch.zeros(num_nodes, dtype=torch.bool)
train_mask[train_idx] = True
val_mask = torch.zeros(num_nodes, dtype=torch.bool)
val_mask[val_idx] = True
test_mask = torch.zeros(num_nodes, dtype=torch.bool)
test_mask[test_idx] = True

train_mask = train_mask.to(device)
val_mask = val_mask.to(device)
test_mask = test_mask.to(device)
data.edge_index = data.edge_index.to(device)
# Drop singleton dimensions so the labels can be indexed with the 1-D masks
data.y = data.y.squeeze().to(device)

print(f"✓ Nodes: {num_nodes}, Classes: {num_classes}")
print(f"✓ Train: {train_idx.shape[0]}, Val: {val_idx.shape[0]}, Test: {test_idx.shape[0]}")

# Load texts (line i is expected to describe node i)
with open(TEXTS_OUTPUT_PATH, "r", encoding="utf8") as f:
    texts = [line.strip() for line in f]

# The training cell pairs texts with nodes purely by position, so a length
# mismatch would silently attach embeddings to the wrong nodes.
assert len(texts) == num_nodes, (
    f"{TEXTS_OUTPUT_PATH} has {len(texts)} lines but the graph has {num_nodes} nodes"
)

print(f"✓ Loaded {len(texts)} texts")


SECTION 3: Load Graph Structure
Device: cuda
✓ Nodes: 169343, Classes: 40
✓ Train: 90941, Val: 29799, Test: 48603
✓ Loaded 169343 texts


In [None]:
# ============================================================================
# SECTION 4: Encoder Configurations
# ============================================================================
banner_rule = "="*80
print("\n" + banner_rule)
print("SECTION 4: Encoder Configurations")
print(banner_rule)

# Registry of sentence encoders to compare, keyed by a short name.
# Each entry holds the model identifier, the configured embedding
# dimension and a human-readable description.
ENCODER_CONFIGS = {
    'minilm': dict(
        model_name='sentence-transformers/all-MiniLM-L6-v2',
        dim=384,
        description='Baseline (fast, lightweight)',
    ),
    'mpnet': dict(
        model_name='sentence-transformers/all-mpnet-base-v2',
        dim=768,
        description='Industry standard',
    ),
    'e5-base': dict(
        model_name='intfloat/e5-base-v2',
        dim=768,
        description='E5: LLM-based SOTA',
    ),
    'e5-large-multilingual': dict(
        model_name='intfloat/multilingual-e5-large',
        dim=1024,
        description='E5: Multilingual large',
    ),
    'paraphrase-multilingual': dict(
        model_name='sentence-transformers/paraphrase-multilingual-mpnet-base-v2',
        dim=768,
        description='SBERT multilingual',
    ),
}

# One summary line per registered encoder
print("Available encoders:")
for key, config in ENCODER_CONFIGS.items():
    summary_line = f"  - {key}: {config['model_name']} (dim={config['dim']}) - {config['description']}"
    print(summary_line)


SECTION 4: Encoder Configurations
Available encoders:
  - minilm: sentence-transformers/all-MiniLM-L6-v2 (dim=384) - Baseline (fast, lightweight)
  - mpnet: sentence-transformers/all-mpnet-base-v2 (dim=768) - Industry standard
  - e5-base: intfloat/e5-base-v2 (dim=768) - E5: LLM-based SOTA
  - e5-large-multilingual: intfloat/multilingual-e5-large (dim=1024) - E5: Multilingual large
  - paraphrase-multilingual: sentence-transformers/paraphrase-multilingual-mpnet-base-v2 (dim=768) - SBERT multilingual


In [None]:
# ============================================================================
# SECTION 5: Model Definition
# ============================================================================
print("\n" + "="*80)
print("SECTION 5: Model Definition")
print("="*80)

class SAGEModel(torch.nn.Module):
    """Node classifier that wraps PyG's ``GraphSAGE`` model.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Width of the hidden GraphSAGE layers.
        num_classes: Number of output channels (one per class).
        num_layers: Number of GraphSAGE layers. Defaults to 2, the value that
            was previously hard-coded.
        dropout: Dropout probability passed to GraphSAGE. Defaults to 0.5, the
            value that was previously hard-coded.
    """

    def __init__(self, in_channels, hidden_channels, num_classes,
                 num_layers=2, dropout=0.5):
        super().__init__()
        # The attribute name is part of the state_dict keys of saved
        # checkpoints, so it must stay `sage_model`.
        self.sage_model = GraphSAGE(
            in_channels=in_channels,
            hidden_channels=hidden_channels,
            num_layers=num_layers,
            out_channels=num_classes,
            dropout=dropout,
            act='relu'
        )

    def forward(self, x, edge_index):
        """Run the wrapped GraphSAGE on node features `x` and `edge_index`.

        Returns the wrapped model's output unchanged; the training code in this
        notebook treats it as per-node class logits.
        """
        return self.sage_model(x, edge_index)

print("✓ Model class defined")



SECTION 5: Model Definition
✓ Model class defined


In [None]:
# ============================================================================
# SECTION 6: Training Functions
# ============================================================================
print("\n" + "="*80)
print("SECTION 6: Training Functions")
print("="*80)

def train_step(model, optimizer, x, edge_index, y, train_mask):
    """Single training step."""
    model.train()
    optimizer.zero_grad()
    out = model(x, edge_index)
    loss = F.cross_entropy(out[train_mask], y[train_mask])
    loss.backward()
    optimizer.step()
    return loss.item()

@torch.no_grad()
def eval_epoch(model, x, edge_index, y, train_mask, val_mask, test_mask):
    """Compute accuracy on the train, validation and test splits.

    The model is switched to eval mode and run once on the whole graph; the
    predicted class of each node is the argmax over dimension 1.

    Returns:
        A tuple ``(train_acc, val_acc, test_acc)`` of Python floats.
    """
    model.eval()
    predicted = model(x, edge_index).argmax(dim=1)

    def split_accuracy(mask):
        # Fraction of the nodes selected by `mask` whose prediction matches the label
        hits = predicted[mask] == y[mask]
        return hits.float().mean().item()

    return tuple(split_accuracy(mask) for mask in (train_mask, val_mask, test_mask))

@torch.no_grad()
def compute_detailed_metrics(model, x, edge_index, y, test_mask):
    """Evaluate the model on the nodes selected by `test_mask`.

    Returns:
        A dict with keys 'accuracy', 'precision', 'recall' and 'f1'. Precision,
        recall and F1 are the weighted averages computed by scikit-learn with
        ``zero_division=0``; accuracy is the fraction of matching predictions.
    """
    model.eval()
    test_preds = model(x, edge_index).argmax(dim=1)[test_mask]
    test_labels = y[test_mask]

    # scikit-learn works on host arrays
    y_true = test_labels.cpu().numpy()
    y_pred = test_preds.cpu().numpy()

    # Averaging options shared by the three scikit-learn metrics
    averaging = dict(average='weighted', zero_division=0)

    return {
        'accuracy': (test_preds == test_labels).float().mean().item(),
        'precision': precision_score(y_true, y_pred, **averaging),
        'recall': recall_score(y_true, y_pred, **averaging),
        'f1': f1_score(y_true, y_pred, **averaging),
    }

print("✓ Training functions defined")


SECTION 6: Training Functions
✓ Training functions defined


In [None]:
# ============================================================================
# SECTION 7: Main Training Loop for All Encoders
# ============================================================================
# For every encoder in ENCODER_CONFIGS: embed the node texts (cached on disk),
# train a GraphSAGE classifier with early stopping on validation accuracy,
# reload the best checkpoint, then write metrics (JSON) and curves (PNG).
print("\n" + "="*80)
print("SECTION 7: Training All Encoders")
print("="*80)

# Hyper-parameters (formerly inlined in the loop body)
SEED = 42
ENCODE_BATCH_SIZE = 64
HIDDEN_CHANNELS = 128
LEARNING_RATE = 0.01
WEIGHT_DECAY = 5e-4
MAX_EPOCHS = 200
MAX_PATIENCE = 20   # training stops once more than this many epochs pass without improvement
LOG_EVERY = 10

# Create directories
for output_dir in ('embeddings', 'models', 'results', 'plots'):
    os.makedirs(output_dir, exist_ok=True)


def _encode_texts(encoder_config, texts, batch_size):
    """Embed `texts` with the configured SentenceTransformer.

    Returns a tensor with one row per entry of `texts`. The encoder is local to
    this function, so it is released on return instead of staying referenced
    (on `device`) while the GraphSAGE model trains.
    """
    encoder = SentenceTransformer(encoder_config['model_name'], device=device)
    encoder.eval()

    # Handle E5 models (require "query: " prefix); detected by the 'e5'
    # substring in the model name.
    texts_to_encode = texts
    if 'e5' in encoder_config['model_name'].lower():
        print("  (Adding 'query: ' prefix for E5 model)")
        texts_to_encode = [f"query: {text}" for text in texts]

    all_embs = []
    # Iterate over the texts themselves (not over num_nodes) so every text gets
    # exactly one embedding row; the caller checks the row count against the graph.
    for start in tqdm(range(0, len(texts_to_encode), batch_size), desc="Encoding"):
        batch_texts = texts_to_encode[start:start+batch_size]
        with torch.no_grad():
            embs = encoder.encode(
                batch_texts,
                convert_to_tensor=True,
                batch_size=batch_size,
                show_progress_bar=False
            )
        all_embs.append(embs)

    return torch.cat(all_embs, dim=0)


def _save_training_plots(encoder_key, loss_history, train_acc_history,
                         val_acc_history, test_acc_history):
    """Write the loss curve and the accuracy curves of one run to plots/."""
    # Loss curve
    plt.figure(figsize=(10, 5))
    plt.plot(loss_history, label='Training Loss', color='red', linewidth=2)
    plt.title(f'Training Loss - {encoder_key}', fontsize=14, fontweight='bold')
    plt.xlabel('Epoch', fontsize=12)
    plt.ylabel('Loss', fontsize=12)
    plt.legend(fontsize=10)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(f'plots/{encoder_key}_training_loss.png', dpi=300)
    plt.close()

    # Accuracy curves
    plt.figure(figsize=(10, 5))
    plt.plot(train_acc_history, label='Train Acc', color='blue', linewidth=2)
    plt.plot(val_acc_history, label='Val Acc', color='orange', linewidth=2)
    plt.plot(test_acc_history, label='Test Acc', color='green', linewidth=2)
    plt.title(f'Accuracy Curves - {encoder_key}', fontsize=14, fontweight='bold')
    plt.xlabel('Epoch', fontsize=12)
    plt.ylabel('Accuracy', fontsize=12)
    plt.legend(fontsize=10)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(f'plots/{encoder_key}_accuracy_curves.png', dpi=300)
    plt.close()


all_training_results = {}

for encoder_key, encoder_config in ENCODER_CONFIGS.items():
    print("\n" + "="*80)
    print(f"TRAINING WITH: {encoder_key}")
    print(f"Model: {encoder_config['model_name']}")
    print(f"Description: {encoder_config['description']}")
    print("="*80)

    # File paths
    emb_path = f"embeddings/{encoder_key}_clean.pt"
    model_path = f"models/{encoder_key}_model.pt"
    results_path = f"results/{encoder_key}_training_results.json"

    # Step 1: Create embeddings (or reuse the cached tensor)
    if os.path.exists(emb_path):
        print(f"✓ Loading cached embeddings from {emb_path}")
        node_embeddings = torch.load(emb_path, map_location=device)
        # A cache written for a different text file would misalign rows and nodes
        assert node_embeddings.size(0) == num_nodes, (
            f"Cached embeddings {emb_path} have {node_embeddings.size(0)} rows, "
            f"expected {num_nodes}; delete the file to regenerate it"
        )
    else:
        print(f"Creating embeddings with {encoder_key}...")
        node_embeddings = _encode_texts(encoder_config, texts, ENCODE_BATCH_SIZE)
        # Validate before caching so a misaligned tensor is never written to disk
        assert node_embeddings.size(0) == num_nodes, (
            f"Encoded {node_embeddings.size(0)} texts but the graph has {num_nodes} nodes"
        )
        torch.save(node_embeddings.cpu(), emb_path)
        print(f"✓ Saved embeddings to {emb_path}")

    node_embeddings = node_embeddings.to(device)
    print(f"Embeddings shape: {node_embeddings.shape}")

    # Step 2: Initialize model
    # Re-seed per encoder so weight initialisation and dropout masks are
    # repeatable and do not depend on which encoders ran before this one.
    torch.manual_seed(SEED)
    model = SAGEModel(
        in_channels=node_embeddings.size(1),
        hidden_channels=HIDDEN_CHANNELS,
        num_classes=num_classes
    ).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE,
                                 weight_decay=WEIGHT_DECAY)

    # Step 3: Training loop
    print(f"\nTraining model for {encoder_key}...")

    loss_history = []
    train_acc_history = []
    val_acc_history = []
    test_acc_history = []

    # Start below any possible accuracy so the first epoch always writes a
    # checkpoint; Step 4 can then never load a stale file from an earlier run.
    best_val = -1.0
    patience = 0

    for epoch in range(1, MAX_EPOCHS + 1):
        loss = train_step(model, optimizer, node_embeddings, data.edge_index,
                         data.y, train_mask)
        train_acc, val_acc, test_acc = eval_epoch(model, node_embeddings,
                                                   data.edge_index, data.y,
                                                   train_mask, val_mask, test_mask)

        loss_history.append(loss)
        train_acc_history.append(train_acc)
        val_acc_history.append(val_acc)
        test_acc_history.append(test_acc)

        # Checkpoint on every validation improvement; count stagnant epochs otherwise
        if val_acc > best_val:
            best_val = val_acc
            torch.save(model.state_dict(), model_path)
            patience = 0
        else:
            patience += 1

        if patience > MAX_PATIENCE:
            print(f"  Early stopping at epoch {epoch}")
            break

        if epoch % LOG_EVERY == 0:
            print(f"  Epoch {epoch:03d} | Loss {loss:.4f} | Train {train_acc:.4f} | "
                  f"Val {val_acc:.4f} | Test {test_acc:.4f}")

    # Step 4: Load best model and compute final metrics
    print(f"\nLoading best model for {encoder_key}...")
    model.load_state_dict(torch.load(model_path, map_location=device))

    final_metrics = compute_detailed_metrics(model, node_embeddings, data.edge_index,
                                            data.y, test_mask)

    print(f"\n{'='*50}")
    print(f"FINAL RESULTS FOR {encoder_key.upper()}")
    print(f"{'='*50}")
    print(f"Best Val Accuracy:  {best_val:.4f}")
    print(f"Test Accuracy:      {final_metrics['accuracy']:.4f}")
    print(f"Test Precision:     {final_metrics['precision']:.4f}")
    print(f"Test Recall:        {final_metrics['recall']:.4f}")
    print(f"Test F1 Score:      {final_metrics['f1']:.4f}")
    print(f"{'='*50}\n")

    # Step 5: Save results
    training_results = {
        'encoder': encoder_key,
        'model_name': encoder_config['model_name'],
        # Record the dimension of the tensor actually used, not the configured value
        'embedding_dim': int(node_embeddings.size(1)),
        'best_val_accuracy': best_val,
        'test_accuracy': final_metrics['accuracy'],
        'test_precision': final_metrics['precision'],
        'test_recall': final_metrics['recall'],
        'test_f1': final_metrics['f1'],
        'training_epochs': len(loss_history),
        # "final" = last epoch that ran, not the epoch of the best checkpoint
        'final_train_accuracy': train_acc_history[-1],
        'final_val_accuracy': val_acc_history[-1]
    }

    with open(results_path, 'w') as f:
        json.dump(training_results, f, indent=2)

    all_training_results[encoder_key] = training_results

    # Step 6: Create plots
    print(f"Creating plots for {encoder_key}...")
    _save_training_plots(encoder_key, loss_history, train_acc_history,
                         val_acc_history, test_acc_history)

    print(f"✓ Plots saved to plots/{encoder_key}_*.png")


SECTION 7: Training All Encoders

TRAINING WITH: minilm
Model: sentence-transformers/all-MiniLM-L6-v2
Description: Baseline (fast, lightweight)
Creating embeddings with minilm...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

Encoding:   0%|          | 0/2646 [00:00<?, ?it/s]

✓ Saved embeddings to embeddings/minilm_clean.pt
Embeddings shape: torch.Size([169343, 384])

Training model for minilm...
  Epoch 010 | Loss 1.9511 | Train 0.5832 | Val 0.6039 | Test 0.5806
  Epoch 020 | Loss 1.3430 | Train 0.6482 | Val 0.6543 | Test 0.6258
  Epoch 030 | Loss 1.1510 | Train 0.6828 | Val 0.6813 | Test 0.6493
  Epoch 040 | Loss 1.0626 | Train 0.7084 | Val 0.7075 | Test 0.6764
  Epoch 050 | Loss 1.0245 | Train 0.7181 | Val 0.7151 | Test 0.6877
  Epoch 060 | Loss 0.9980 | Train 0.7250 | Val 0.7185 | Test 0.6926
  Epoch 070 | Loss 0.9779 | Train 0.7290 | Val 0.7221 | Test 0.6954
  Epoch 080 | Loss 0.9637 | Train 0.7317 | Val 0.7230 | Test 0.6982
  Epoch 090 | Loss 0.9558 | Train 0.7343 | Val 0.7244 | Test 0.7009
  Epoch 100 | Loss 0.9446 | Train 0.7353 | Val 0.7258 | Test 0.7019
  Epoch 110 | Loss 0.9418 | Train 0.7364 | Val 0.7271 | Test 0.7022
  Epoch 120 | Loss 0.9355 | Train 0.7374 | Val 0.7273 | Test 0.7033
  Epoch 130 | Loss 0.9341 | Train 0.7385 | Val 0.7271 | Test 

modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/571 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/438M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/363 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

Encoding:   0%|          | 0/2646 [00:00<?, ?it/s]

✓ Saved embeddings to embeddings/mpnet_clean.pt
Embeddings shape: torch.Size([169343, 768])

Training model for mpnet...
  Epoch 010 | Loss 1.7703 | Train 0.6119 | Val 0.6291 | Test 0.6108
  Epoch 020 | Loss 1.2130 | Train 0.6777 | Val 0.6758 | Test 0.6543
  Epoch 030 | Loss 1.0479 | Train 0.7133 | Val 0.7101 | Test 0.6820
  Epoch 040 | Loss 0.9851 | Train 0.7295 | Val 0.7242 | Test 0.7039
  Epoch 050 | Loss 0.9521 | Train 0.7378 | Val 0.7312 | Test 0.7088
  Epoch 060 | Loss 0.9242 | Train 0.7435 | Val 0.7350 | Test 0.7144
  Epoch 070 | Loss 0.9094 | Train 0.7470 | Val 0.7382 | Test 0.7161
  Epoch 080 | Loss 0.8964 | Train 0.7489 | Val 0.7399 | Test 0.7177
  Epoch 090 | Loss 0.8891 | Train 0.7504 | Val 0.7403 | Test 0.7178
  Epoch 100 | Loss 0.8832 | Train 0.7520 | Val 0.7405 | Test 0.7190
  Epoch 110 | Loss 0.8744 | Train 0.7526 | Val 0.7421 | Test 0.7200
  Epoch 120 | Loss 0.8715 | Train 0.7533 | Val 0.7422 | Test 0.7200
  Epoch 130 | Loss 0.8670 | Train 0.7531 | Val 0.7427 | Test 0.

modules.json:   0%|          | 0.00/387 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/57.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/650 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/438M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/314 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/200 [00:00<?, ?B/s]

  (Adding 'query: ' prefix for E5 model)


Encoding:   0%|          | 0/2646 [00:00<?, ?it/s]

✓ Saved embeddings to embeddings/e5-base_clean.pt
Embeddings shape: torch.Size([169343, 768])

Training model for e5-base...
  Epoch 010 | Loss 2.8849 | Train 0.2653 | Val 0.2420 | Test 0.1431
  Epoch 020 | Loss 2.3283 | Train 0.4020 | Val 0.3816 | Test 0.3341
  Epoch 030 | Loss 1.8342 | Train 0.5445 | Val 0.5521 | Test 0.4855
  Epoch 040 | Loss 1.5302 | Train 0.6086 | Val 0.6130 | Test 0.5704
  Epoch 050 | Loss 1.3666 | Train 0.6458 | Val 0.6465 | Test 0.6108
  Epoch 060 | Loss 1.2995 | Train 0.6611 | Val 0.6543 | Test 0.6116
  Epoch 070 | Loss 1.2474 | Train 0.6756 | Val 0.6746 | Test 0.6427
  Epoch 080 | Loss 1.2082 | Train 0.6833 | Val 0.6815 | Test 0.6531
  Epoch 090 | Loss 1.1748 | Train 0.6931 | Val 0.6889 | Test 0.6592
  Epoch 100 | Loss 1.1608 | Train 0.6921 | Val 0.6869 | Test 0.6572
  Epoch 110 | Loss 1.1411 | Train 0.7003 | Val 0.6919 | Test 0.6604
  Epoch 120 | Loss 1.1259 | Train 0.7021 | Val 0.6938 | Test 0.6670
  Epoch 130 | Loss 1.1252 | Train 0.7051 | Val 0.6978 | Tes

modules.json:   0%|          | 0.00/387 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/57.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/690 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.24G [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/418 [00:00<?, ?B/s]

sentencepiece.bpe.model:   0%|          | 0.00/5.07M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.1M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/280 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/201 [00:00<?, ?B/s]

  (Adding 'query: ' prefix for E5 model)


Encoding:   0%|          | 0/2646 [00:00<?, ?it/s]

✓ Saved embeddings to embeddings/e5-large-multilingual_clean.pt
Embeddings shape: torch.Size([169343, 1024])

Training model for e5-large-multilingual...
  Epoch 010 | Loss 3.0132 | Train 0.1955 | Val 0.1002 | Test 0.0697
  Epoch 020 | Loss 2.5735 | Train 0.2976 | Val 0.2917 | Test 0.2315
  Epoch 030 | Loss 2.0320 | Train 0.4759 | Val 0.4952 | Test 0.4495
  Epoch 040 | Loss 1.6757 | Train 0.5776 | Val 0.5956 | Test 0.5531
  Epoch 050 | Loss 1.4956 | Train 0.6205 | Val 0.6274 | Test 0.5877
  Epoch 060 | Loss 1.4011 | Train 0.6376 | Val 0.6480 | Test 0.6189
  Epoch 070 | Loss 1.3517 | Train 0.6536 | Val 0.6550 | Test 0.6235
  Epoch 080 | Loss 1.2978 | Train 0.6659 | Val 0.6685 | Test 0.6409
  Epoch 090 | Loss 1.2630 | Train 0.6720 | Val 0.6721 | Test 0.6472
  Epoch 100 | Loss 1.2436 | Train 0.6781 | Val 0.6790 | Test 0.6544
  Epoch 110 | Loss 1.2290 | Train 0.6853 | Val 0.6800 | Test 0.6486
  Epoch 120 | Loss 1.2156 | Train 0.6895 | Val 0.6874 | Test 0.6622
  Epoch 130 | Loss 1.1965 | Tr

modules.json:   0%|          | 0.00/229 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/122 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/723 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.11G [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/402 [00:00<?, ?B/s]

sentencepiece.bpe.model:   0%|          | 0.00/5.07M [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

Encoding:   0%|          | 0/2646 [00:00<?, ?it/s]

✓ Saved embeddings to embeddings/paraphrase-multilingual_clean.pt
Embeddings shape: torch.Size([169343, 768])

Training model for paraphrase-multilingual...
  Epoch 010 | Loss 1.8634 | Train 0.5501 | Val 0.5775 | Test 0.5359
  Epoch 020 | Loss 1.4376 | Train 0.6234 | Val 0.6279 | Test 0.5875
  Epoch 030 | Loss 1.2868 | Train 0.6531 | Val 0.6527 | Test 0.6133
  Epoch 040 | Loss 1.2140 | Train 0.6726 | Val 0.6657 | Test 0.6263
  Epoch 050 | Loss 1.1748 | Train 0.6834 | Val 0.6735 | Test 0.6302
  Epoch 060 | Loss 1.1476 | Train 0.6899 | Val 0.6778 | Test 0.6374
  Epoch 070 | Loss 1.1216 | Train 0.6950 | Val 0.6797 | Test 0.6404
  Epoch 080 | Loss 1.1080 | Train 0.6990 | Val 0.6836 | Test 0.6406
  Epoch 090 | Loss 1.0941 | Train 0.7010 | Val 0.6856 | Test 0.6452
  Epoch 100 | Loss 1.0805 | Train 0.7043 | Val 0.6848 | Test 0.6435
  Epoch 110 | Loss 1.0751 | Train 0.7065 | Val 0.6884 | Test 0.6487
  Epoch 120 | Loss 1.0665 | Train 0.7081 | Val 0.6879 | Test 0.6496
  Epoch 130 | Loss 1.0612 |

In [None]:
# ============================================================================
# SECTION 8: Comparison Summary
# ============================================================================
# Tabulates and plots the per-encoder test metrics collected in
# `all_training_results` and reports the encoder with the best test accuracy.
print("\n" + "="*80)
print("SECTION 8: Training Summary Across All Encoders")
print("="*80)

# Only encoders that actually finished training have an entry; indexing by every
# configured key would raise KeyError after a partial or interrupted run.
trained_keys = list(all_training_results)
if not trained_keys:
    raise RuntimeError("No training results found - run Section 7 before the comparison.")

untrained_keys = [key for key in ENCODER_CONFIGS if key not in all_training_results]
if untrained_keys:
    print(f"⚠ No results for: {', '.join(untrained_keys)} (left out of the comparison)")

# Create comparison DataFrame (one row per encoder)
comparison_df = pd.DataFrame(all_training_results).T
comparison_df = comparison_df[['model_name', 'embedding_dim', 'best_val_accuracy',
                                'test_accuracy', 'test_precision', 'test_recall', 'test_f1']]

print("\n📊 PERFORMANCE COMPARISON:")
print(comparison_df.to_string())

# Save comparison
comparison_df.to_csv('results/encoder_comparison.csv')
print("\n✓ Comparison saved to results/encoder_comparison.csv")

# Create comparison bar chart: one panel per metric, one bar per trained encoder
fig, axes = plt.subplots(2, 2, figsize=(15, 12))

metrics = ['test_accuracy', 'test_precision', 'test_recall', 'test_f1']
titles = ['Test Accuracy', 'Test Precision', 'Test Recall', 'Test F1 Score']
colors = ['#3498db', '#e74c3c', '#2ecc71', '#f39c12']

bar_positions = range(len(trained_keys))

for idx, (metric, title, color) in enumerate(zip(metrics, titles, colors)):
    ax = axes[idx // 2, idx % 2]
    values = [all_training_results[key][metric] for key in trained_keys]
    bars = ax.bar(bar_positions, values, color=color, alpha=0.7, edgecolor='black')
    ax.set_xticks(bar_positions)
    ax.set_xticklabels(trained_keys, rotation=45, ha='right')
    ax.set_ylabel(title, fontsize=11, fontweight='bold')
    # Zoom the y-axis onto the observed range so small differences stay visible
    ax.set_ylim([min(values) - 0.05, max(values) + 0.05])
    ax.grid(True, alpha=0.3, axis='y')

    # Add value labels on bars
    for bar, val in zip(bars, values):
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width()/2., height,
                f'{val:.4f}', ha='center', va='bottom', fontsize=9)

plt.suptitle('Encoder Performance Comparison', fontsize=16, fontweight='bold', y=0.995)
plt.tight_layout()
plt.savefig('plots/encoder_comparison.png', dpi=300, bbox_inches='tight')
plt.close()

print("✓ Comparison plot saved to plots/encoder_comparison.png")

# Find best encoder (ranked by test accuracy).
# NOTE(review): if this ranking is used to choose a model, rank on
# 'best_val_accuracy' instead so the test split is not used for selection.
best_encoder = max(all_training_results.items(),
                   key=lambda x: x[1]['test_accuracy'])

print(f"\n🏆 BEST PERFORMING ENCODER:")
print(f"   {best_encoder[0].upper()}")
print(f"   Test Accuracy: {best_encoder[1]['test_accuracy']:.4f}")
print(f"   Test F1:       {best_encoder[1]['test_f1']:.4f}")

print("\n" + "="*80)
print("✅ TRAINING COMPLETE FOR ALL ENCODERS!")
print("="*80)
# Counts reflect the encoders that were actually trained in this session
print("\nGenerated files:")
print(f"  📁 embeddings/  - {len(trained_keys)} embedding files")
print(f"  📁 models/      - {len(trained_keys)} model checkpoints")
print(f"  📁 results/     - {len(trained_keys)} result JSON files + comparison CSV")
print(f"  📁 plots/       - {len(trained_keys)*2 + 1} plot images")


SECTION 8: Training Summary Across All Encoders

📊 PERFORMANCE COMPARISON:
                                                                          model_name embedding_dim best_val_accuracy test_accuracy test_precision test_recall   test_f1
minilm                                        sentence-transformers/all-MiniLM-L6-v2           384          0.728145      0.704751       0.708253    0.704751  0.694013
mpnet                                        sentence-transformers/all-mpnet-base-v2           768          0.745897      0.723268       0.723915    0.723268  0.713467
e5-base                                                          intfloat/e5-base-v2           768          0.706232      0.674773       0.666642    0.674773  0.651817
e5-large-multilingual                                 intfloat/multilingual-e5-large          1024          0.699654      0.674135       0.676609    0.674135  0.649582
paraphrase-multilingual  sentence-transformers/paraphrase-multilingual-mpnet-base-v2

In [None]:
# Bundle the saved model checkpoints into a single archive.
# NOTE(review): the training cell writes to the relative `models/` directory, so
# this absolute path only matches when the working directory is /content.
!zip -r /content/models.zip /content/models


  adding: content/models/ (stored 0%)
  adding: content/models/minilm_model.pt (deflated 7%)
  adding: content/models/paraphrase-multilingual_model.pt (deflated 6%)
  adding: content/models/e5-large-multilingual_model.pt (deflated 5%)
  adding: content/models/mpnet_model.pt (deflated 7%)
  adding: content/models/e5-base_model.pt (deflated 6%)


In [None]:
# Bundle the generated plot images into a single archive.
# NOTE(review): plots are saved to the relative `plots/` directory, so this
# absolute path only matches when the working directory is /content.
!zip -r /content/plots.zip /content/plots

  adding: content/plots/ (stored 0%)
  adding: content/plots/mpnet_training_loss.png (deflated 24%)
  adding: content/plots/minilm_accuracy_curves.png (deflated 15%)
  adding: content/plots/e5-large-multilingual_accuracy_curves.png (deflated 10%)
  adding: content/plots/encoder_comparison.png (deflated 25%)
  adding: content/plots/e5-base_training_loss.png (deflated 21%)
  adding: content/plots/minilm_training_loss.png (deflated 23%)
  adding: content/plots/mpnet_accuracy_curves.png (deflated 15%)
  adding: content/plots/paraphrase-multilingual_accuracy_curves.png (deflated 16%)
  adding: content/plots/e5-large-multilingual_training_loss.png (deflated 20%)
  adding: content/plots/e5-base_accuracy_curves.png (deflated 11%)
  adding: content/plots/paraphrase-multilingual_training_loss.png (deflated 22%)


In [None]:
# Bundle the per-encoder result JSON files and the comparison CSV into a single archive.
# NOTE(review): results are written to the relative `results/` directory, so this
# absolute path only matches when the working directory is /content.
!zip -r /content/results.zip /content/results

  adding: content/results/ (stored 0%)
  adding: content/results/e5-large-multilingual_training_results.json (deflated 45%)
  adding: content/results/encoder_comparison.csv (deflated 50%)
  adding: content/results/mpnet_training_results.json (deflated 44%)
  adding: content/results/e5-base_training_results.json (deflated 44%)
  adding: content/results/minilm_training_results.json (deflated 42%)
  adding: content/results/paraphrase-multilingual_training_results.json (deflated 47%)


In [None]:
# Bundle the cached embedding tensors into a single archive.
# NOTE(review): embeddings are cached in the relative `embeddings/` directory, so
# this absolute path only matches when the working directory is /content.
!zip -r /content/embeddings.zip /content/embeddings

  adding: content/embeddings/ (stored 0%)
  adding: content/embeddings/e5-large-multilingual_clean.pt (deflated 7%)
  adding: content/embeddings/mpnet_clean.pt (deflated 7%)
  adding: content/embeddings/minilm_clean.pt (deflated 7%)
  adding: content/embeddings/paraphrase-multilingual_clean.pt (deflated 7%)
  adding: content/embeddings/e5-base_clean.pt (deflated 8%)
