# GNN Encoder Training -- Affective-RAG

Trains the Graph Transformer encoder with three loss objectives:
- **MSE**: Emotion prediction from graph-smoothed representations
- **Contrastive**: Structural k-differentiation across hop depths
- **Alignment**: Cosine alignment between GNN output and sentence-BERT space

Sweeps `alignment_weight` over {0.05, 0.1, 0.3, 0.5} with `contrastive_weight` fixed at 0.2.
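Downloads the encoder checkpoint from every sweep run, the full sweep results as JSON (which records the best `alignment_weight`, chosen by lowest validation loss), and the loss-curve plot.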

### Prerequisites
| # | Requirement |
|---|-------------|
| 1 | Runtime set to **GPU** |
| 2 | GCP project with GCS bucket containing the dataset |
| 3 | Fill in `PROJECT_ID` in the auth cell |

## Install dependencies & clone repo

In [None]:
import subprocess, sys

def _pip(*pip_args):
    """Quietly ``pip install`` the given arguments into the running interpreter.

    Raises subprocess.CalledProcessError when pip exits with a non-zero status.
    """
    command = [sys.executable, "-m", "pip", "install", *pip_args, "-q"]
    subprocess.run(command, check=True)

import os

# Graph library first, then the project source and its pinned requirements.
_pip('torch-geometric')

REPO_URL = "https://github.com/Prashant002-1/Affective-RAG-Recommender-Systems.git"
REPO_DIR = "ARAG"

# Clone only once. Re-running this cell used to attempt a second clone into the
# existing directory. Calling git through subprocess (instead of the `!git` shell
# magic) also turns a failed clone into an exception, rather than carrying on to
# a pip install against a missing requirements file.
if os.path.isdir(os.path.join(REPO_DIR, ".git")):
    print(f"{REPO_DIR}/ already present; skipping clone.")
else:
    subprocess.check_call(["git", "clone", REPO_URL, REPO_DIR])

_pip("-r", os.path.join(REPO_DIR, "requirements.txt"))
print("Done. Restart runtime.")

### Restart the runtime
Go to **Runtime > Restart runtime**, then continue from the next cell.


## Authenticate to GCS

Replace `YOUR_PROJECT_ID` with your GCP project ID.


In [None]:
import os

PROJECT_ID = "YOUR_PROJECT_ID"          # <-- set this

# Fail fast, before the OAuth prompt, if the placeholder was never replaced.
# Otherwise the cell below prints "Authenticated" for a project that does not exist.
if not PROJECT_ID or PROJECT_ID == "YOUR_PROJECT_ID":
    raise ValueError("Set PROJECT_ID in this cell to your GCP project ID before running it.")

from google.colab import auth
auth.authenticate_user()

os.environ["GOOGLE_CLOUD_PROJECT"] = PROJECT_ID

from google.cloud import storage
# Constructing the client resolves default credentials and should raise if none
# are found. It issues no API request, so project/bucket access is only exercised
# when the data-loading cell runs.
storage.Client(project=PROJECT_ID)
print(f"Authenticated: {PROJECT_ID}")


## Import from repo

In [None]:
import os, sys, torch

# Make the cloned repo's `src` directory importable. An absolute path keeps the
# imports working after any later os.chdir. The membership check stops re-runs
# of this cell from stacking duplicate sys.path entries.
_SRC_DIR = os.path.abspath("ARAG/src")
if _SRC_DIR not in sys.path:
    sys.path.insert(0, _SRC_DIR)

from krag.core.knowledge_graph  import KRAGEncoder
from krag.training.gnn_trainer  import GNNTrainer, TrainingConfig, prepare_emotion_ground_truth
from krag.data.adapters         import get_adapter, DatasetPath
from krag.data.ingestion        import MovieDataLoader, KnowledgeGraphBuilder

# Train on the GPU when the runtime has one, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Imports OK | device = {DEVICE}")


## Load data from GCS

In [None]:
# Load every input the trainer needs. The globals defined here (content_items,
# kg, node_embeddings, emotion_gt) are consumed by the dataset-preparation and
# sweep cells below, so keep their names stable.
adapter = get_adapter()          # uses default bucket / base_path

# --- 1. Movies (content items) ---
print("[1/4] Loading movies …")
content_items = MovieDataLoader(adapter).load_all_movies()
print(f"      {len(content_items)} movies")

# --- 2. Knowledge graph (from Neo4j CSVs) ---
print("[2/4] Building knowledge graph …")
kg = KnowledgeGraphBuilder(adapter).build_from_neo4j_exports()

# --- 3. Precomputed node embeddings (768-dim) ---
# NOTE(review): load_pickle presumably unpickles a file from the dataset bucket.
# Unpickling can execute arbitrary code, so only point the adapter at a bucket
# you trust.
print("[3/4] Loading node embeddings …")
node_embeddings = adapter.load_pickle(DatasetPath.NODE_EMBEDDINGS)
print(f"      {len(node_embeddings)} embeddings")

# --- 4. Emotion ground truth ---
# In the cells shown, movies_df is only used to derive the emotion labels.
# NOTE(review): nothing checks that emotion_gt covers every movie in
# content_items. Compare the movie and label counts printed by this cell.
print("[4/4] Loading emotion ground truth …")
movies_df  = adapter.load_movies(vector_ready=True)
emotion_gt = prepare_emotion_ground_truth(movies_df)
print(f"      {len(emotion_gt)} movies with labels")


## Initialise & prepare datasets

In [None]:
# Baseline config/encoder/trainer. In the cells shown, the trainer built here is
# only used to prepare the multi-k train/val datasets once, and every sweep run
# then reuses them.
# NOTE(review): the sweep cell builds a fresh TrainingConfig, KRAGEncoder and
# GNNTrainer per alignment weight, so alignment_weight=0.3 below does not
# influence the sweep. Keep the shared hyper-parameters in sync with that cell.
config = TrainingConfig(
    embedding_dim      = 768,
    num_epochs         = 50,
    batch_size         = 32,
    learning_rate      = 1e-4,
    patience           = 10,
    contrastive_weight = 0.2,
    alignment_weight   = 0.3,
    device             = DEVICE,
)

encoder = KRAGEncoder(embedding_dim=768, num_layers=3, num_heads=4, dropout=0.1)
trainer = GNNTrainer(
    gnn_encoder=encoder,
    config=config,
    knowledge_graph=kg,
    node_embeddings=node_embeddings,
)

# One-time preparation. train_ds / val_ds are passed unchanged to every sweep run.
# NOTE(review): if this performs a random train/val split, it is unseeded here.
# Confirm inside GNNTrainer and seed beforehand if reproducible splits matter.
print("Preparing multi-k training data (one-time) ...")
train_ds, val_ds = trainer.prepare_multi_k_training_data(
    content_items        = content_items,
    node_embeddings      = node_embeddings,
    emotion_ground_truth = emotion_gt,
)
print(f"Train: {len(train_ds)}  |  Val: {len(val_ds)}")

## Alignment weight sweep

In [None]:
import copy, json  # json is also used by the results-saving cell further down
import torch

ALIGNMENT_WEIGHTS = [0.05, 0.1, 0.3, 0.5]
FIXED_CONTRASTIVE = 0.2
# Re-seed torch before every run so all runs share the same weight initialisation
# and torch RNG stream (dropout, torch-based shuffling). Differences between runs
# can then be attributed to alignment_weight rather than initialisation noise.
SWEEP_SEED = 0
sweep_results = {}

for aw in ALIGNMENT_WEIGHTS:
    print(f"\n{'='*60}")
    print(f"  alignment_weight = {aw}")
    print(f"{'='*60}")

    torch.manual_seed(SWEEP_SEED)  # seeds the RNG on all devices (CPU and CUDA)

    # Everything except alignment_weight is held fixed across runs.
    cfg = TrainingConfig(
        embedding_dim      = 768,
        num_epochs         = 50,
        batch_size         = 32,
        learning_rate      = 1e-4,
        patience           = 10,
        contrastive_weight = FIXED_CONTRASTIVE,
        alignment_weight   = aw,
        device             = DEVICE,
    )

    # Fresh encoder and trainer per run so no weights or optimiser state leak between runs.
    enc = KRAGEncoder(embedding_dim=768, num_layers=3, num_heads=4, dropout=0.1)
    trn = GNNTrainer(
        gnn_encoder=enc,
        config=cfg,
        knowledge_graph=kg,
        node_embeddings=node_embeddings,
    )

    res = trn.train_multi_k(train_ds, val_ds)
    sweep_results[aw] = {
        "history": res["history"],
        "best_val_loss": res["best_val_loss"],
        "final_epoch": res["final_epoch"],
        # NOTE(review): this is the encoder's state when train_multi_k returns.
        # Whether that matches the best-val-loss epoch depends on the trainer.
        "encoder_state": copy.deepcopy(enc.state_dict()),
    }

print("\n\nSweep complete.")
for aw, r in sweep_results.items():
    print(f"  aw={aw:.2f}  best_val={r['best_val_loss']:.4f}  epochs={r['final_epoch']}")

## Alignment sweep results

In [None]:
import matplotlib.pyplot as plt

# (history key, panel title, y-axis label), in row-major order to match axes.flat.
# NOTE(review): "Val MSE" assumes the trainer's val_loss is the MSE term only. Confirm.
PANELS = [
    ("mse_loss",         "Train MSE",         "Loss"),
    ("contrastive_loss", "Train Contrastive", "Loss"),
    ("alignment_loss",   "Train Alignment",   "Loss"),
    ("val_loss",         "Val MSE",           "Loss"),
    ("val_cosine",       "Val Cosine Sim",    "Cosine"),
    ("val_alignment",    "Val Alignment",     "Loss"),
]

fig, axes = plt.subplots(2, 3, figsize=(18, 10))
# tab10 is matplotlib's default palette, and its first four entries are the
# colours previously hard-coded here. Indexing modulo its length keeps the plot
# working if ALIGNMENT_WEIGHTS grows beyond four entries.
palette = plt.cm.tab10.colors

for idx, aw in enumerate(ALIGNMENT_WEIGHTS):
    history = sweep_results[aw]["history"]
    epochs = [e["epoch"] for e in history]
    color = palette[idx % len(palette)]
    for ax, (key, _, _) in zip(axes.flat, PANELS):
        ax.plot(epochs, [e[key] for e in history], color=color, label=f"aw={aw}")

for ax, (_, title, ylabel) in zip(axes.flat, PANELS):
    ax.set_title(title)
    ax.set_ylabel(ylabel)
    ax.set_xlabel("Epoch")
    ax.legend(fontsize=8)
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig("/tmp/alignment_sweep.png", dpi=150, bbox_inches="tight")
plt.show()

# Summary table. The per-loss columns come from the LAST recorded epoch. If the
# trainer stops early (patience=10 is configured), that may be after the epoch
# that produced best_val.
print(f"\n{'aw':<8} {'best_val':<12} {'epochs':<8} {'mse':<10} {'contr':<10} {'align':<10} {'cos':<10}")
print("-" * 68)
for aw in ALIGNMENT_WEIGHTS:
    r = sweep_results[aw]
    prefix = f"{aw:<8.2f} {r['best_val_loss']:<12.4f} {r['final_epoch']:<8} "
    if not r["history"]:
        # Without this guard, history[-1] raised IndexError on an empty run.
        print(prefix + "(no epoch history recorded)")
        continue
    last = r["history"][-1]
    print(prefix
          + f"{last['mse_loss']:<10.4f} {last['contrastive_loss']:<10.4f} "
            f"{last['alignment_loss']:<10.4f} {last['val_cosine']:<10.4f}")

## Pick the best run & save all checkpoints and results

In [None]:
import json
import math
from collections import OrderedDict

import torch
from google.colab.files import download

CKPT_DIR = "/tmp"
RESULTS_PATH = "/tmp/alignment_sweep_results.json"
PLOT_PATH = "/tmp/alignment_sweep.png"


def _ckpt_name(aw):
    """Return the checkpoint file name used for alignment weight *aw*."""
    return f"krag_encoder_aw{aw}.pt"


def _selection_loss(aw):
    """Return run *aw*'s best val loss as a float, mapping NaN (diverged run) to +inf.

    `min` with a NaN key is order-dependent because every comparison against NaN
    is False. Without this, a diverged run listed first would be reported as best.
    """
    loss = float(sweep_results[aw]["best_val_loss"])
    return math.inf if math.isnan(loss) else loss


def _to_cpu(state):
    """Return a copy of a state dict with every tensor on CPU, keeping torch's `_metadata`.

    CPU checkpoints load on machines without a GPU without needing `map_location`.
    """
    moved = OrderedDict(
        (k, v.detach().cpu() if torch.is_tensor(v) else v) for k, v in state.items()
    )
    metadata = getattr(state, "_metadata", None)
    if metadata is not None:
        moved._metadata = metadata
    return moved


def _json_default(obj):
    """Convert numpy/torch scalars for json via `.item()`; anything else is still an error."""
    if callable(getattr(obj, "item", None)):
        return obj.item()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


best_aw = min(sweep_results, key=_selection_loss)
print(f"Best alignment_weight: {best_aw}  (val_loss={sweep_results[best_aw]['best_val_loss']:.4f})")
if not math.isfinite(_selection_loss(best_aw)):
    print("WARNING: no run produced a finite validation loss; best_alignment_weight is arbitrary.")

output = {
    "metadata": {
        "experiment": "alignment_weight_sweep",
        "contrastive_weight_fixed": FIXED_CONTRASTIVE,
        "alignment_weights_swept": ALIGNMENT_WEIGHTS,
        "best_alignment_weight": best_aw,
        "best_checkpoint": _ckpt_name(best_aw),
        # NOTE: these literals mirror the sweep cell; keep them in sync by hand.
        "architecture": {
            "embedding_dim": 768,
            "num_layers": 3,
            "num_heads": 4,
            "dropout": 0.1,
        },
        "training": {
            "num_epochs": 50,
            "batch_size": 32,
            "learning_rate": 1e-4,
            "patience": 10,
        },
        "dataset": {
            "train_size": len(train_ds),
            "val_size": len(val_ds),
        },
        "device": DEVICE,
    },
    "summary": {},
    "runs": {},
}

for aw in ALIGNMENT_WEIGHTS:
    r = sweep_results[aw]
    history = r["history"]
    last = history[-1] if history else {}   # empty history -> final_* fields are null
    output["summary"][str(aw)] = {
        "best_val_loss": r["best_val_loss"],
        "final_epoch": r["final_epoch"],
        "final_mse": last.get("mse_loss"),
        "final_contrastive": last.get("contrastive_loss"),
        "final_alignment": last.get("alignment_loss"),
        "final_val_cosine": last.get("val_cosine"),
        "final_val_alignment": last.get("val_alignment"),
    }
    output["runs"][str(aw)] = history

# Write the weights first so a JSON serialisation error cannot cost the checkpoints.
ckpt_paths = []
for aw, r in sweep_results.items():
    path = f"{CKPT_DIR}/{_ckpt_name(aw)}"
    torch.save(_to_cpu(r["encoder_state"]), path)
    ckpt_paths.append(path)

with open(RESULTS_PATH, "w", encoding="utf-8") as f:
    json.dump(output, f, indent=2, default=_json_default)

for path in ckpt_paths:
    download(path)
download(RESULTS_PATH)
download(PLOT_PATH)
print("Downloads triggered.")