# ViTRec (CUDA Edition) — Vision Transformer for Recommendation

Datasets (He et al. 2017 split):
- `data/ml-1m.train.rating` — implicit positives (user, item)
- `data/ml-1m.test.rating` — held-out positive per user
- `data/ml-1m.test.negative` — for each user: the held-out `(user, positive)` pair followed by 99 sampled negatives

Evaluation: per-user **HR@K** and **NDCG@K**, computed by ranking each user's held-out positive among its 100 candidates (the positive plus its 99 negatives; no leakage).


In [None]:
# ====================================================
#  Cell 1 — GPU & Environment Diagnostics
# ====================================================
import os, platform, torch

# Make CUDA kernel launches synchronous so errors are reported at the Python
# line that caused them. This is a debugging aid that slows GPU execution, and
# it only takes effect if set before CUDA is initialised (PyTorch initialises
# CUDA lazily). setdefault respects a value the user already exported, e.g.
# CUDA_LAUNCH_BLOCKING=0 to opt out for full-speed training.
os.environ.setdefault("CUDA_LAUNCH_BLOCKING", "1")

print("Python:", platform.python_version())
print("PyTorch:", torch.__version__)
print("Built with CUDA:", torch.version.cuda)
print("CUDA available:", torch.cuda.is_available())

if not torch.cuda.is_available():
    raise SystemExit(" CUDA not detected — please install a CUDA-enabled PyTorch build.")

DEVICE = torch.device("cuda:0")
print("Using device:", torch.cuda.get_device_name(0))

# Quick CUDA sanity check: allocate on the device and wait for it to finish.
try:
    _probe = torch.zeros(1, device=DEVICE)
    torch.cuda.synchronize()
    print("✅ CUDA sanity check passed.")
    print("cuDNN version:", torch.backends.cudnn.version())
    mem = torch.cuda.get_device_properties(DEVICE).total_memory / 1024**3
    print(f"Total GPU memory: {mem:.2f} GB")
except Exception as e:
    # SystemExit takes a single message argument; passing (msg, e) would make
    # the exit message a tuple. Chain the cause so the traceback keeps it.
    raise SystemExit(f" CUDA failed to initialize: {e}") from e


In [None]:
# ============================================================
#  Cell 2 — Imports, Configurations, and Random Seeds
# ============================================================
import os, os.path as osp, glob, time, random, math
from typing import List, Tuple, Dict, Optional
import numpy as np
import pandas as pd
import scipy.sparse as sp

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
from torch.nn.utils import clip_grad_norm_
from torch.amp import GradScaler, autocast

# cuDNN tuning: benchmark mode auto-selects the fastest convolution algorithms
# for fixed input shapes; runs are therefore not bit-for-bit deterministic.
cudnn.benchmark = True
cudnn.deterministic = False

# Reproducibility: seed Python, NumPy and PyTorch (CPU + all CUDA devices).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# ======================================
#  Master configuration (safe defaults)
# ======================================
CFG = dict(
    # Data
    data_path     = "data",
    data_set      = "ml-1m",
    model_path    = "model_ckpts",

    # Training
    objective     = "BPR",        # or "BCE"
    epochs        = 50,
    batch_size    = 256,          # <=512 for 6 GB GPU
    lr            = 3e-4,
    weight_decay  = 1e-4,
    warmup_epochs = 3,
    num_ng        = 4,            # negatives per positive
    top_k         = 10,
    amp           = True,

    # Model
    embed_dim     = 32,
    vit_depth     = 2,
    num_heads     = 4,
    patch_size    = 8,
    dropout       = 0.1,

    # Loader
    num_workers   = 2,
    pin_memory    = True,
    seed          = SEED,
)

# Validate with explicit exceptions rather than `assert`, which is stripped
# under `python -O` and would let an invalid configuration through silently.
# Multi-head attention splits embed_dim evenly across heads.
if CFG["embed_dim"] % CFG["num_heads"] != 0:
    raise ValueError(
        f"embed_dim ({CFG['embed_dim']}) must be divisible by num_heads ({CFG['num_heads']})."
    )
# The model's interaction map is embed_dim x embed_dim and is cut into
# non-overlapping patch_size x patch_size patches, so the side must divide evenly.
if CFG["embed_dim"] % CFG["patch_size"] != 0:
    raise ValueError(
        f"embed_dim ({CFG['embed_dim']}) must be divisible by patch_size ({CFG['patch_size']})."
    )

origin    = os.getcwd()
DATA_DIR  = osp.join(origin, CFG["data_path"])
DATA_FILE = osp.join(DATA_DIR, CFG["data_set"])

expected = [DATA_FILE + ".train.rating", DATA_FILE + ".test.rating", DATA_FILE + ".test.negative"]
print(" Checking dataset files:")
for p in expected:
    print("  -", p, "[OK]" if osp.exists(p) else "[MISSING]")


In [None]:
# ====================================================
#  Cell 3 — Data Loading Utilities (fixed for ID gaps)
# ====================================================
def load_all(path_prefix: str):
    """Load the train/test split and re-index IDs to a contiguous 0-based range.

    Reads ``<path_prefix>.train.rating`` and ``<path_prefix>.test.rating``
    (tab-separated ``user, item, rating, timestamp``). User and item IDs seen in
    training are mapped, in sorted order, to ``0..n-1``. Test rows whose user or
    item never occurs in training are dropped, because they have no embedding row.

    Args:
        path_prefix: dataset path without the split suffix.

    Returns:
        train_data: list of ``[user, item]`` pairs (re-indexed).
        test_data: list of ``[user, item]`` pairs (re-indexed).
        user_num: number of distinct training users.
        item_num: number of distinct training items.
        mat: ``scipy.sparse.dok_matrix`` of shape ``(user_num, item_num)``,
            float32, with 1.0 at every training interaction.
    """
    cols = ['user', 'item', 'rating', 'timestamp']
    train_file = path_prefix + ".train.rating"
    test_file  = path_prefix + ".test.rating"

    # --- Load train and build original_id -> new_id mappings ---
    df = pd.read_csv(train_file, sep='\t', names=cols)
    user_map = {u: idx for idx, u in enumerate(sorted(df.user.unique()))}
    item_map = {i: idx for idx, i in enumerate(sorted(df.item.unique()))}

    df['user'] = df['user'].map(user_map)
    df['item'] = df['item'].map(item_map)

    user_num = len(user_map)
    item_num = len(item_map)
    train_data = df[['user', 'item']].values.tolist()

    # Sparse interaction matrix; BPRDataset uses it for membership tests.
    mat = sp.dok_matrix((user_num, item_num), dtype=np.float32)
    for u, i in train_data:
        mat[u, i] = 1.0

    # --- Load test, keep only IDs known from training, and remap ---
    test_df = pd.read_csv(test_file, sep='\t', names=cols)
    known = test_df.user.isin(list(user_map)) & test_df.item.isin(list(item_map))
    dropped = int((~known).sum())
    # .copy(): assigning columns on a filtered view triggers SettingWithCopyWarning
    # (chained assignment) and may not write through.
    test_df = test_df.loc[known].copy()
    test_df['user'] = test_df['user'].map(user_map)
    test_df['item'] = test_df['item'].map(item_map)
    test_data = test_df[['user', 'item']].values.tolist()

    print(f"Reindexed users: {user_num}, items: {item_num}")
    if dropped:
        print(f"Dropped {dropped} test rows with users/items unseen in training.")
    return train_data, test_data, user_num, item_num, mat


In [None]:
# ====================================================
#  Cell 4 — BPR Dataset (fixed unpacking issue)
# ====================================================
class BPRDataset(Dataset):
    """Bayesian Personalized Ranking training triples with negative sampling.

    For every observed ``(user, item)`` pair in ``train_data`` the dataset holds
    ``num_ng`` triples ``[user, positive item, negative item]``. Negatives are
    drawn uniformly from ``[0, item_num)`` and re-drawn while ``(user, j)`` is a
    stored entry of ``train_mat``.

    Args:
        train_data: iterable of ``(user, item)`` pairs (re-indexed ids).
        item_num: number of items; negatives come from ``range(item_num)``.
        train_mat: sparse matrix supporting ``(u, j) in train_mat`` membership
            (the notebook passes a ``scipy.sparse.dok_matrix``).
        num_ng: negatives per positive pair.
    """

    def __init__(self, train_data, item_num, train_mat, num_ng=4):
        self.train_data = train_data
        self.item_num = item_num
        self.train_mat = train_mat
        self.num_ng = num_ng
        self.samples = self._ng_sample()

    def _ng_sample(self, rng=None):
        """Draw a fresh list of ``[u, i, j]`` triples from ``rng`` (global NumPy RNG if None)."""
        randint = (np.random if rng is None else rng).randint
        samples = []
        for u, i in self.train_data:
            for _ in range(self.num_ng):
                j = randint(self.item_num)
                # Rejection sampling: retry until j is not a training item of u.
                while (u, j) in self.train_mat:
                    j = randint(self.item_num)
                samples.append([u, i, j])
        return samples

    def ng_sample(self, seed=None):
        """Re-draw all negatives (intended to be called once per epoch).

        With ``seed`` the draw comes from a private ``np.random.RandomState(seed)``.
        That yields exactly the negatives that seeding the global RNG would, but
        without resetting ``np.random`` for the rest of the program.
        """
        rng = np.random.RandomState(seed) if seed is not None else None
        self.samples = self._ng_sample(rng)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(user, pos_item, neg_item)`` as 0-dim ``torch.long`` tensors."""
        u, i, j = self.samples[idx]
        return (
            torch.tensor(u, dtype=torch.long),
            torch.tensor(i, dtype=torch.long),
            torch.tensor(j, dtype=torch.long),
        )



In [None]:
# ====================================================
# Cell 5 — ViT-based Collaborative Filtering Model (Paper Version)
# ====================================================
import torch
import torch.nn as nn

class SimpleVisionTransformer(nn.Module):
    """
    Compact ViT encoder used as the recommender backbone.

    Pipeline: patchify the image with a strided convolution, prepend a learnable
    CLS token, run a stack of Transformer encoder layers, and return the
    layer-normalised CLS embedding.

    Args:
        in_chans: number of input channels (the recommender feeds 3).
        embed_dim: token width; also the size of the returned feature.
        patch_size: side of the square, non-overlapping patches.
        depth: number of encoder layers.
        num_heads: attention heads per layer (must divide embed_dim).
        mlp_ratio: feed-forward width as a multiple of embed_dim.
        dropout: dropout on the token sequence and inside each encoder layer.

    No positional embedding is added (``pos_embed`` stays ``None``), so the
    encoder treats the patch tokens as an unordered set.
    """
    def __init__(self, in_chans=3, embed_dim=64, patch_size=8, depth=2,
                 num_heads=4, mlp_ratio=4.0, dropout=0.1):
        super().__init__()

        # Patch embedding: a convolution whose kernel size equals its stride.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = None
        self.pos_drop = nn.Dropout(dropout)

        ffn_width = int(embed_dim * mlp_ratio)
        layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=ffn_width,
            dropout=dropout,
            activation="gelu",
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=depth)
        self.norm = nn.LayerNorm(embed_dim)

        nn.init.trunc_normal_(self.cls_token, std=0.02)

    def forward(self, x):
        """Map ``x`` of shape [B, C, H, W] to a [B, embed_dim] CLS feature."""
        patches = self.proj(x)                                  # [B, D, H', W']
        tokens = patches.flatten(start_dim=2).transpose(1, 2)   # [B, N, D]
        cls = self.cls_token.expand(tokens.size(0), -1, -1)     # [B, 1, D]
        sequence = torch.cat([cls, tokens], dim=1)              # [B, N+1, D]
        encoded = self.encoder(self.pos_drop(sequence))
        return self.norm(encoded)[:, 0]


class ViTRecModel(nn.Module):
    """
    ViT-based collaborative filtering scorer (outer product -> ViT -> score).

    For each (user, item) pair, the two ID embeddings form an outer-product
    "interaction image" of size embed_dim x embed_dim. This image is replicated
    to 3 channels, encoded by ``SimpleVisionTransformer``, and mapped to a single
    logit by a two-layer MLP head.

    Args:
        user_count: size of the user embedding table (valid IDs 0..user_count-1).
        item_count: size of the item embedding table (valid IDs 0..item_count-1).
        embed_dim: ID-embedding size, ViT width and interaction-image side.
        patch_size: ViT patch side; should divide embed_dim.
        depth, num_heads, dropout: forwarded to the ViT backbone.
    """
    def __init__(self, user_count, item_count, embed_dim=64,
                 patch_size=8, depth=2, num_heads=4, dropout=0.1):
        super().__init__()
        self.user_count = user_count
        self.item_count = item_count
        self.embed_dim  = embed_dim
        self.spatial_shape = (embed_dim, embed_dim)

        # ----- ID Embeddings -----
        self.user_emb = nn.Embedding(user_count, embed_dim)
        self.item_emb = nn.Embedding(item_count, embed_dim)

        # ----- ViT Backbone -----
        self.vit = SimpleVisionTransformer(
            in_chans=3,
            embed_dim=embed_dim,
            patch_size=patch_size,
            depth=depth,
            num_heads=num_heads,
            dropout=dropout
        )

        # ----- Prediction Head -----
        self.fc1 = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(0.2)
        self.out = nn.Linear(embed_dim, 1)

    def forward(self, user_ids, item_ids):
        """Score pairs given 1-D ``user_ids``/``item_ids`` of equal length B; returns logits [B]."""
        # User and item embeddings
        u = self.user_emb(user_ids)    # [B, d]
        v = self.item_emb(item_ids)    # [B, d]

        # Outer product -> interaction "image"
        inter = torch.bmm(u.unsqueeze(2), v.unsqueeze(1))  # [B, d, d]
        inter = inter.unsqueeze(1)                         # [B, 1, d, d]
        img = inter.repeat(1, 3, 1, 1)                     # [B, 3, d, d] (backbone takes 3 channels)

        # Vision Transformer backbone
        feat = self.vit(img)                               # [B, d]
        feat = self.dropout(torch.relu(feat))

        # Prediction head. The non-linearity between fc1 and out is required:
        # two stacked Linear layers without one collapse into a single linear
        # map, which would make fc1 redundant.
        hidden = torch.relu(self.fc1(feat))
        return self.out(hidden).squeeze(-1)


In [None]:
# ====================================================
# Cell 6 — Evaluation Metrics (HR, NDCG)
# ====================================================
def hit(gt_item, pred_items):
    """Hit-ratio contribution of one user: 1 if ``gt_item`` is in ``pred_items``, else 0."""
    return 1 if gt_item in pred_items else 0

def ndcg(gt_item, pred_items):
    """NDCG contribution of one user with a single relevant item.

    Returns ``1 / log2(position + 2)`` for the 0-based position of ``gt_item``
    in the ranked list ``pred_items``, or 0 when it is absent.
    """
    if gt_item not in pred_items:
        return 0
    position = pred_items.index(gt_item)
    return 1.0 / np.log2(position + 2)


In [None]:
# ====================================================
#  Cell 6.5 — TestUserDataset (robust version)
# ====================================================
class TestUserDataset(Dataset):
    """
    Per-user evaluation candidates read from ``<path_prefix>.test.negative``.

    Every non-blank line becomes one sample ``(user, items)``. Parentheses and
    commas are treated as separators, so both ``42<TAB>100 200 300 ...`` and
    ``(42,100)<TAB>...`` layouts parse. ``items`` is every integer after the
    user id, in file order. In the ``(user,item)`` layout, ``items[0]`` is the
    pair's item, which ``evaluate`` uses as the ground-truth positive.

    By default the file's IDs are used as-is. When ``user_map`` / ``item_map``
    (raw id -> contiguous id, like the mappings ``load_all`` builds) are given,
    IDs are translated so evaluation indexes the same embedding rows as training:
    - a line is skipped if its user or first item is unmapped;
    - unmapped trailing items are dropped.
    """

    def __init__(self, path_prefix: str,
                 user_map: Optional[Dict[int, int]] = None,
                 item_map: Optional[Dict[int, int]] = None):
        test_file = path_prefix + ".test.negative"
        self.users, self.items = [], []
        skipped = 0

        with open(test_file, "r", encoding="utf-8") as f:
            for line in f:
                parsed = self._parse_line(line)
                if parsed is None:
                    continue  # blank or malformed line
                user, items = parsed
                if user_map is not None or item_map is not None:
                    remapped = self._remap(user, items, user_map, item_map)
                    if remapped is None:
                        skipped += 1
                        continue
                    user, items = remapped
                self.users.append(user)
                self.items.append(items)

        print(f"Loaded {len(self.users)} users for testing.")
        if skipped:
            print(f"Skipped {skipped} test users with IDs unknown to the mappings.")

    @staticmethod
    def _parse_line(line):
        """Return ``(user, items)`` for one line, or None if it is blank/malformed."""
        parts = line.strip().replace('(', ' ').replace(')', ' ').replace(',', ' ').split()
        if len(parts) < 2:
            return None

        try:
            user = int(parts[0])
        except ValueError:
            # Try extracting the user from tokens like "42:" or "user42".
            # isdecimal (not isdigit): int() rejects digit-like chars such as '²'.
            cleaned = ''.join(c for c in parts[0] if c.isdecimal())
            if not cleaned:
                return None
            user = int(cleaned)

        items = []
        for token in parts[1:]:
            try:
                items.append(int(token))
            except ValueError:
                continue  # ignore non-integer tokens
        return (user, items) if items else None

    @staticmethod
    def _remap(user, items, user_map, item_map):
        """Translate raw IDs; return None when the user or first item is unmapped."""
        if user_map is not None:
            if user not in user_map:
                return None
            user = user_map[user]
        if item_map is not None:
            if items[0] not in item_map:
                return None  # the first item is the evaluated positive; it must survive
            items = [item_map[i] for i in items if i in item_map]
        return user, items

    def __len__(self):
        return len(self.users)

    def __getitem__(self, idx):
        user = torch.tensor(self.users[idx], dtype=torch.long)
        items = torch.tensor(self.items[idx], dtype=torch.long)
        return user, items


In [None]:
# ====================================================
# Cell 7 — DataLoader Setup
# ====================================================
train_data, test_data_flat, user_num, item_num, train_mat = load_all(DATA_FILE)
print(f"Users: {user_num}, Items: {item_num}, Train pairs: {len(train_data)}")

train_ds = BPRDataset(train_data, item_num, train_mat, num_ng=CFG['num_ng'])
test_ds  = TestUserDataset(DATA_FILE)

# load_all() re-indexes IDs, but TestUserDataset reads test.negative with the
# raw file IDs. Fail fast with a readable message: an out-of-range ID would
# otherwise surface as a CUDA device-side assert inside nn.Embedding, which
# typically leaves the CUDA context unusable for the rest of the session.
bad_users = sum(1 for u in test_ds.users if not 0 <= u < user_num)
bad_items = sum(1 for cands in test_ds.items for i in cands if not 0 <= i < item_num)
if bad_users or bad_items:
    raise ValueError(
        f"test.negative has {bad_users} user IDs outside [0, {user_num}) and "
        f"{bad_items} item IDs outside [0, {item_num}); its raw IDs do not match "
        "the re-indexed training IDs."
    )

# In-range IDs can still be misaligned if load_all() closed gaps in the ID
# space: the first candidate (held-out positive) should agree with test.rating.
heldout = {u: i for u, i in test_data_flat}
misaligned = sum(
    1 for u, cands in zip(test_ds.users, test_ds.items)
    if u in heldout and heldout[u] != cands[0]
)
if misaligned:
    print(f"WARNING: {misaligned} test users' first candidate differs from "
          "test.rating after re-indexing; evaluation IDs may be misaligned.")

train_loader = DataLoader(train_ds, batch_size=CFG['batch_size'],
                          shuffle=True, num_workers=CFG['num_workers'],
                          pin_memory=CFG['pin_memory'], drop_last=True)

# batch_size=1: candidate lists may differ in length between users.
test_loader  = DataLoader(test_ds, batch_size=1, shuffle=False)


In [None]:
# ====================================================
# Cell 8 — Model, Optimizer, AMP Scaler & Scheduler Setup
# ====================================================
# Build the recommender with the hyper-parameters from CFG and move it to the GPU.
model = ViTRecModel(
    user_count=user_num,
    item_count=item_num,
    embed_dim=CFG["embed_dim"],
    patch_size=CFG["patch_size"],
    depth=CFG["vit_depth"],
    num_heads=CFG["num_heads"],
    dropout=CFG["dropout"],
).to(DEVICE)

# AdamW with decoupled weight decay. Decay is applied to every parameter,
# including embeddings and LayerNorm weights — NOTE(review): confirm intended.
optimizer = AdamW(params=model.parameters(), lr=CFG["lr"], weight_decay=CFG["weight_decay"])
# Loss scaler for mixed precision; passes values through unchanged when amp is off.
scaler = GradScaler(device="cuda", enabled=CFG["amp"])

# ----------------------------
# Step counts for the warmup + cosine schedule (the training loop steps the
# scheduler once per batch)
# ----------------------------
warmup_steps = len(train_loader) * CFG["warmup_epochs"]
total_steps  = len(train_loader) * CFG["epochs"]

def lr_lambda(step, warmup=None, total=None):
    """Return the learning-rate multiplier for scheduler step ``step`` (for LambdaLR).

    The schedule has two phases:
    - linear warmup from 0 to 1 over ``warmup`` steps;
    - cosine decay from 1 to 0 by step ``total``.
    Past ``total`` the factor stays at 0 instead of following the cosine back up.

    Args:
        step: number of scheduler steps taken so far (LambdaLR passes only this).
        warmup: warmup length in steps; defaults to the global ``warmup_steps``.
        total: schedule length in steps; defaults to the global ``total_steps``.
    """
    if warmup is None:
        warmup = warmup_steps
    if total is None:
        total = total_steps
    if step < warmup:
        return step / max(1, warmup)
    progress = (step - warmup) / max(1, total - warmup)
    # Clamp: without it, steps beyond `total` would make the LR rise again.
    progress = min(max(progress, 0.0), 1.0)
    return 0.5 * (1.0 + math.cos(math.pi * progress))

# Multiplicative per-step LR schedule driven by lr_lambda (stepped once per batch).
scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)


In [None]:
# ====================================================
#  Cell 9 Training + Inline Evaluation (HR@K, NDCG@K)
# ====================================================
from tqdm import tqdm
import numpy as np

def evaluate(model, test_loader, top_k=10, device=None):
    """Return mean HR@top_k and NDCG@top_k over the users in ``test_loader``.

    Each batch must be ``(user, items)``: ``user`` of shape ``[B]`` (or a scalar)
    and ``items`` of shape ``[B, C]`` (or ``[C]``); the notebook uses
    ``batch_size=1``. Column 0 of ``items`` is treated as the held-out positive,
    and the other columns are the candidates it is ranked against. A user counts
    as a hit when the positive is among the ``top_k`` highest-scoring candidates.
    In that case it contributes ``1 / log2(rank + 2)`` to NDCG (0-based rank);
    otherwise it contributes 0.

    Args:
        model: module called as ``model(user_ids, item_ids) -> scores`` with
            1-D, equal-length id tensors.
        test_loader: iterable of ``(user, items)`` batches.
        top_k: cut-off K; clamped to the number of candidates.
        device: where to score; defaults to the device of the model's first
            parameter (CPU if it has none).

    Returns:
        ``(hr, ndcg)`` as Python floats, or ``(nan, nan)`` if there are no users.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    hrs, ndcgs = [], []
    with torch.no_grad():
        for user, items in test_loader:
            items = items.to(device)
            if items.dim() == 1:          # tolerate an unbatched [C] candidate list
                items = items.unsqueeze(0)
            user = user.reshape(-1).to(device)
            n_users, n_cands = items.shape

            # The model scores 1-D (user, item) id pairs. Passing the raw [B, C]
            # items tensor would make the item embedding 3-D and break torch.bmm.
            logits = model(user.repeat_interleave(n_cands), items.reshape(-1))
            logits = logits.reshape(n_users, n_cands)

            k = min(top_k, n_cands)       # torch.topk errors if k > n_cands
            top_idx = torch.topk(logits, k, dim=1).indices
            preds = torch.gather(items, 1, top_idx).tolist()

            for gt, pred in zip(items[:, 0].tolist(), preds):
                if gt in pred:
                    hrs.append(1)
                    ndcgs.append(1.0 / np.log2(pred.index(gt) + 2))
                else:
                    hrs.append(0)
                    ndcgs.append(0.0)

    if not hrs:
        return float("nan"), float("nan")
    return float(np.mean(hrs)), float(np.mean(ndcgs))


# Training objective from CFG: "BPR" (pairwise ranking) or "BCE" (pointwise).
# Validate once up front so a typo fails before any training time is spent.
objective = str(CFG.get('objective', 'BPR')).upper()
if objective not in ("BPR", "BCE"):
    raise ValueError(f"Unsupported objective {CFG.get('objective')!r}; expected 'BPR' or 'BCE'.")

for epoch in range(CFG['epochs']):
    model.train()
    # Redraw negatives every epoch; seeding by epoch keeps runs reproducible.
    train_loader.dataset.ng_sample(seed=CFG['seed'] + epoch)
    total_loss, n_batches = 0.0, 0

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{CFG['epochs']}", leave=True)
    for batch in pbar:
        optimizer.zero_grad(set_to_none=True)
        user, pos, neg = [t.to(DEVICE, non_blocking=True) for t in batch]
        with autocast(device_type='cuda', enabled=CFG['amp']):
            pos_logits = model(user, pos)
            neg_logits = model(user, neg)

        # Compute the loss in float32 with numerically stable primitives.
        # -log(sigmoid(x) + 1e-8) is unsafe under AMP: sigmoid runs in fp16 and
        # underflows to 0 for large negative x, and 1e-8 rounds to 0 in fp16,
        # so the loss becomes log(0) = -inf.
        pos_logits = pos_logits.float()
        neg_logits = neg_logits.float()
        if objective == "BPR":
            loss = -F.logsigmoid(pos_logits - neg_logits).mean()
        else:
            logits = torch.cat([pos_logits, neg_logits])
            labels = torch.cat([torch.ones_like(pos_logits), torch.zeros_like(neg_logits)])
            loss = F.binary_cross_entropy_with_logits(logits, labels)

        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)            # unscale first so clipping sees true gradient norms
        clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        loss_value = loss.item()              # single host sync per step
        total_loss += loss_value
        n_batches += 1
        pbar.set_postfix({"Loss": f"{loss_value:.4f}"})

    avg_loss = total_loss / max(1, n_batches)

    # ---- Evaluate each epoch ----
    # Distinct names so the ndcg() metric helper is not overwritten by a float.
    hr_k, ndcg_k = evaluate(model, test_loader, top_k=CFG['top_k'])
    print(f"\n Epoch {epoch+1:02d}/{CFG['epochs']} "
          f"| Avg Loss: {avg_loss:.4f} "
          f"| HR@{CFG['top_k']}: {hr_k:.4f} "
          f"| NDCG@{CFG['top_k']}: {ndcg_k:.4f}\n")

    torch.cuda.empty_cache()


In [None]:
# ====================================================
#  Cell 11 — Save Model Checkpoint
# ====================================================
# Persist only the learned weights (state_dict); rebuild ViTRecModel with the
# same CFG before loading them back.
os.makedirs(CFG["model_path"], exist_ok=True)
ckpt_path = osp.join(CFG["model_path"], "vitrec_{}.pt".format(CFG["data_set"]))
torch.save(obj=model.state_dict(), f=ckpt_path)
print(" Model saved to:", ckpt_path)

