In [None]:
from google.colab import drive
drive.mount('/content/drive')

MessageError: Error: credential propagation was unsuccessful

In [None]:
# Create data directory
!mkdir -p data

# Download MovieLens-1M dataset from the GitHub repository
!wget -O data/ml-1m.txt https://raw.githubusercontent.com/pmixer/SASRec.pytorch/main/python/data/ml-1m.txt

print("Dataset downloaded successfully!")

# Verify the dataset
if os.path.exists('data/ml-1m.txt'):
    with open('data/ml-1m.txt', 'r') as f:
        lines = f.readlines()
    print(f"Dataset loaded with {len(lines)} interactions")
    print("First few lines:")
    for i in range(min(5, len(lines))):
        print(lines[i].strip())
else:
    print("Error: Dataset not found!")

--2025-06-06 03:21:27--  https://raw.githubusercontent.com/pmixer/SASRec.pytorch/main/python/data/ml-1m.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.111.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 9053831 (8.6M) [text/plain]
Saving to: ‘data/ml-1m.txt’


2025-06-06 03:21:28 (216 MB/s) - ‘data/ml-1m.txt’ saved [9053831/9053831]

Dataset downloaded successfully!
Dataset loaded with 999611 interactions
First few lines:
1 1
1 2
1 3
1 4
1 5


In [None]:
!pip install torch-geometric
!pip install wandb
!pip install pyg-lib -f https://data.pyg.org/whl/torch-2.6.0+cu124.html

Collecting torch-geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/63.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m3.3 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m38.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch-geometric
Successfully installed torch-geometric-2.6.1
Looking in links: https://data.pyg.org/whl/torch-2.6.0+cu124.html
Collecting pyg-lib
  Downloading https://data.pyg.org/whl/torch-2.6.0%2Bcu124/pyg_lib-0.4.0%2Bpt26cu124-cp311-cp311-linux_x86_64.whl (4.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.7/4.7 MB[0m [31m35.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pyg-lib
Successfully installed

In [None]:
!pip install optuna

Collecting optuna
  Downloading optuna-4.3.0-py3-none-any.whl.metadata (17 kB)
Collecting alembic>=1.5.0 (from optuna)
  Downloading alembic-1.16.1-py3-none-any.whl.metadata (7.3 kB)
Collecting colorlog (from optuna)
  Downloading colorlog-6.9.0-py3-none-any.whl.metadata (10 kB)
Downloading optuna-4.3.0-py3-none-any.whl (386 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m386.6/386.6 kB[0m [31m18.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading alembic-1.16.1-py3-none-any.whl (242 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m242.5/242.5 kB[0m [31m24.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading colorlog-6.9.0-py3-none-any.whl (11 kB)
Installing collected packages: colorlog, alembic, optuna
Successfully installed alembic-1.16.1 colorlog-6.9.0 optuna-4.3.0


In [None]:
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

In [None]:
import os
import math
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch_geometric.data import Data
from torch_geometric.nn import GATConv, MessagePassing
from torch_geometric.utils import degree
from collections import defaultdict

#===============================================================================
# 1) Utilities
#===============================================================================
def set_seed(seed=42):
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark    = False

def create_train_val_test_split(ratings_df,
                                val_ratio=0.1,
                                test_ratio=0.1,
                                min_interactions=5,
                                positive_threshold=None,
                                seed=42):
    """
    Chronological per-user train/val/test split.
    Returns train_df, val_df, test_df, num_users, num_items
    """
    np.random.seed(seed)
    df = ratings_df.copy()
    if positive_threshold is not None:
        df = df[df.rating >= positive_threshold]

    counts = df.userId.value_counts()
    valid  = counts[counts >= min_interactions].index
    df = df[df.userId.isin(valid)]
    df = df.sort_values(['userId','timestamp'])

    trains, vals, tests = [], [], []
    for _, udf in df.groupby('userId'):
        n = len(udf)
        n_test  = max(1, int(n*test_ratio))
        n_val   = max(1, int(n*val_ratio))
        n_train = n - n_val - n_test
        if n_train < 1:
            continue
        trains.append( udf.iloc[:n_train] )
        vals.append(   udf.iloc[n_train:n_train+n_val] )
        tests.append(  udf.iloc[n_train+n_val:] )

    train_df = pd.concat(trains, ignore_index=True)
    val_df   = pd.concat(vals,   ignore_index=True)
    test_df  = pd.concat(tests,  ignore_index=True)

    # remap to contiguous 0..N-1
    u2idx = {u:i for i,u in enumerate(train_df.userId.unique())}
    i2idx = {m:i for i,m in enumerate(train_df.movieId.unique())}

    for d in (train_df, val_df, test_df):
        d['user_idx']  = d.userId .map(u2idx)
        d['movie_idx'] = d.movieId.map(i2idx)

    num_users = len(u2idx)
    num_items = len(i2idx)
    print(f"Split sizes: train={len(train_df)}, val={len(val_df)}, test={len(test_df)}")
    print(f"#users={num_users}, #items={num_items}")
    return train_df, val_df, test_df, num_users, num_items

def create_pyg_data(df, num_users, num_items):
    """
    Build a PyG bipartite graph.
    Users: 0..num_users-1, items: num_users..num_users+num_items-1
    """
    us  = torch.LongTensor(df.user_idx.values)
    is_ = torch.LongTensor(df.movie_idx.values) + num_users
    edge_index = torch.stack([torch.cat([us,is_]), torch.cat([is_,us])], dim=0)
    data = Data(edge_index=edge_index, num_nodes=num_users+num_items)
    data.num_users = num_users
    data.num_items = num_items
    data.orig_interactions = torch.stack([us,is_], dim=1)
    return data

#===============================================================================
# 2) Graph eval + train
#===============================================================================
def evaluate_graph_model(model, train_data, eval_data, device, k=10):
    """
    HR@k, NDCG@k, precision, recall, f1, mrr for graph-based models
    """
    model.eval()
    metrics = defaultdict(list)

    # 1) compute embeddings on the correct device
    with torch.no_grad():
        edge_index = train_data.edge_index.to(device)
        embs = model(edge_index)
        U = model.num_users
        user_embs = embs[:U]
        item_embs = embs[U:]

    # 2) build train‐map and eval‐map on CPU
    train_map = defaultdict(set)
    for u,i in train_data.orig_interactions.cpu().tolist():
        train_map[u].add(i - U)

    eval_map = defaultdict(list)
    for u,i in eval_data.orig_interactions.cpu().tolist():
        eval_map[u].append(i - U)

    # 3) per‐user ranking
    for u, true_items in eval_map.items():
        if not true_items:
            continue

        # Perform multiplication on the GPU, then move result to CPU
        scores = (item_embs @ user_embs[u]).cpu().numpy()

        # mask training items
        for ti in train_map[u]:
            scores[ti] = -1e9

        # top‐k
        topk = np.argpartition(-scores, k)[:k]
        hits = [1 if t in true_items else 0 for t in topk]

        # HR
        hr = 1.0 if any(hits) else 0.0
        # NDCG
        dcg  = sum(h / math.log2(idx+2) for idx,h in enumerate(hits))
        idcg = sum(1.0/math.log2(i+2) for i in range(min(len(true_items), k)))
        ndcg = dcg/idcg if idcg>0 else 0.0
        # prec/rec/f1
        prec = sum(hits)/k
        rec  = sum(hits)/len(true_items)
        f1   = 2*prec*rec/(prec+rec) if (prec+rec)>0 else 0.0
        # MRR
        mrr  = next((1.0/(idx+1) for idx,h in enumerate(hits) if h), 0.0)

        metrics['hr'].append(hr)
        metrics['ndcg'].append(ndcg)
        metrics['precision'].append(prec)
        metrics['recall'].append(rec)
        metrics['f1'].append(f1)
        metrics['mrr'].append(mrr)

    return {m: np.mean(v) for m,v in metrics.items()}

def train_graph_epoch_regularized(model,
                                  train_data,
                                  optimizer,
                                  device,
                                  batch_size=8192,
                                  l2_reg=1e-4):
    """
    One epoch of BPR + L2 training for graph models.
    """
    model.train()
    edge_index = train_data.edge_index.to(device)
    pos_edges  = train_data.orig_interactions.to(device)
    E = pos_edges.size(0)
    perm = torch.randperm(E, device=device)
    n_batches = (E + batch_size - 1) // batch_size
    total_loss = 0.0

    for b in range(n_batches):
        idx   = perm[b*batch_size:(b+1)*batch_size]
        users = pos_edges[idx,0]
        items = pos_edges[idx,1]
        negs  = torch.randint(model.num_users,
                              model.num_users+model.num_items,
                              size=users.size(), device=device)

        optimizer.zero_grad()
        pos_scores = model.predict(edge_index, users, items)
        neg_scores = model.predict(edge_index, users, negs)
        bpr_loss = -torch.log(torch.sigmoid(pos_scores - neg_scores)+1e-8).mean()
        l2_loss  = l2_reg * (
                model.user_emb(users).pow(2).mean()
              + model.item_emb(items-model.num_users).pow(2).mean()
              + model.item_emb(negs-model.num_users).pow(2).mean()
        )
        loss = bpr_loss + l2_loss
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(),1.0)
        optimizer.step()
        total_loss += loss.item()

    return total_loss / n_batches

#===============================================================================
# 3) Graph Models
#===============================================================================
class GATRec(nn.Module):
    def __init__(self, num_users, num_items,
                 embedding_dim=64, hidden_dim=64,
                 heads=4, n_layers=3):
        super().__init__()
        self.num_users = num_users
        self.num_items = num_items
        self.user_emb  = nn.Embedding(num_users, embedding_dim)
        self.item_emb  = nn.Embedding(num_items, embedding_dim)
        self.convs     = nn.ModuleList()
        # first layer
        self.convs.append(GATConv(embedding_dim, hidden_dim, heads=heads, dropout=0.6))
        # middle
        for _ in range(n_layers-2):
            self.convs.append(GATConv(hidden_dim*heads, hidden_dim, heads=heads, dropout=0.6))
        # final
        self.convs.append(GATConv(hidden_dim*heads, hidden_dim, heads=1, concat=False, dropout=0.6))
        self.dropout = nn.Dropout(0.6)
        # init
        nn.init.xavier_uniform_(self.user_emb.weight)
        nn.init.xavier_uniform_(self.item_emb.weight)

    def forward(self, edge_index):
        x = torch.cat([self.user_emb.weight, self.item_emb.weight], dim=0)
        for conv in self.convs[:-1]:
            x = conv(x, edge_index)
            x = F.relu(x); x = self.dropout(x)
        x = self.convs[-1](x, edge_index)
        return x

    def predict(self, edge_index, u, i):
        embs = self(edge_index)
        return (embs[u]*embs[i]).sum(-1)

class LightGCNConv(MessagePassing):
    def __init__(self): super().__init__(aggr='add')
    def forward(self, x, edge_index):
        row,col = edge_index
        deg = degree(torch.cat([row,col]), x.size(0), dtype=x.dtype)
        d   = deg.pow(-0.5); d[d==float('inf')] = 0
        norm= d[row]*d[col]
        return self.propagate(edge_index, x=x, norm=norm)
    def message(self,x_j,norm):
        return norm.view(-1,1)*x_j

class LightGCN(nn.Module):
    def __init__(self, num_users, num_items,
                 embedding_dim=64, n_layers=3):
        super().__init__()
        self.num_users = num_users
        self.num_items = num_items
        self.user_emb  = nn.Embedding(num_users, embedding_dim)
        self.item_emb  = nn.Embedding(num_items, embedding_dim)
        self.conv      = LightGCNConv()
        self.n_layers  = n_layers
        nn.init.xavier_uniform_(self.user_emb.weight)
        nn.init.xavier_uniform_(self.item_emb.weight)

    def forward(self, edge_index):
        x0 = torch.cat([self.user_emb.weight, self.item_emb.weight], dim=0)
        embs = [x0]; x=x0
        for _ in range(self.n_layers):
            x = self.conv(x, edge_index)
            embs.append(x)
        return torch.stack(embs,0).mean(0)

    def predict(self, edge_index, u, i):
        embs = self(edge_index)
        return (embs[u]*embs[i]).sum(-1)

#===============================================================================
# 4) SASRec + train/eval
#===============================================================================
class SASRecBlock(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout):
        super().__init__()
        self.ln1  = nn.LayerNorm(d_model, eps=1e-8)
        self.attn = nn.MultiheadAttention(d_model,n_heads,
                                          dropout=dropout,
                                          batch_first=True)
        self.drop1= nn.Dropout(dropout)
        self.ln2  = nn.LayerNorm(d_model, eps=1e-8)
        self.ffn  = nn.Sequential(
            nn.Linear(d_model,d_ff), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(d_ff,d_model), nn.Dropout(dropout)
        )
    def forward(self, x, causal_mask, pad_mask):
        z = x; x_n = self.ln1(z)
        a,_ = self.attn(x_n, x_n, x_n,
                       attn_mask=causal_mask,
                       key_padding_mask=pad_mask)
        x = z + self.drop1(a)
        z2 = x; x2_n = self.ln2(z2)
        f  = self.ffn(x2_n)
        return z2 + f

class SASRec(nn.Module):
    PAD = 0
    def __init__(self, num_users, num_items,
                 maxlen=50, hidden=50,
                 blocks=2, heads=1, drop=0.2):
        super().__init__()
        self.maxlen   = maxlen
        self.item_emb = nn.Embedding(num_items+1, hidden, padding_idx=self.PAD)
        self.pos_emb  = nn.Embedding(maxlen+1, hidden, padding_idx=0)
        self.drop     = nn.Dropout(drop)
        self.blocks   = nn.ModuleList([
            SASRecBlock(hidden, heads, hidden, drop) for _ in range(blocks)
        ])
        self.ln_final = nn.LayerNorm(hidden, eps=1e-8)
        # init
        for m in self.modules():
            if isinstance(m, (nn.Embedding, nn.Linear)):
                nn.init.xavier_uniform_(m.weight)
                if hasattr(m,'bias') and m.bias is not None:
                    nn.init.zeros_(m.bias)

    def log2feats(self, logs):
        B,L = logs.shape
        x   = self.item_emb(logs) * math.sqrt(self.item_emb.embedding_dim)
        pos = torch.arange(1, L+1, device=logs.device).unsqueeze(0).expand(B,-1)
        pos = pos * (logs != self.PAD)
        x   = x + self.pos_emb(pos)
        x   = self.drop(x)
        pad_mask = logs.eq(self.PAD)
        causal   = torch.triu(
            torch.ones(L,L,device=logs.device,dtype=torch.bool), diagonal=1
        )
        for blk in self.blocks:
            x = blk(x, causal_mask=causal, pad_mask=pad_mask)
        return self.ln_final(x)

    def forward(self, _, log_seqs, pos_seqs, neg_seqs):
        device = next(self.parameters()).device
        logs = torch.tensor(log_seqs, dtype=torch.long, device=device)
        pos  = torch.tensor(pos_seqs, dtype=torch.long, device=device)
        neg  = torch.tensor(neg_seqs, dtype=torch.long, device=device)
        feats= self.log2feats(logs)        # [B,L,H]
        final= feats[:,-1,:]               # [B,H]
        pe   = self.item_emb(pos)          # [B,H]
        ne   = self.item_emb(neg)          # [B,H]
        return (final*pe).sum(-1), (final*ne).sum(-1)

    def predict(self, _, log_seqs, item_idx):
        device = next(self.parameters()).device
        logs = torch.tensor(log_seqs, dtype=torch.long, device=device)
        items= torch.tensor(item_idx, dtype=torch.long, device=device)
        feats= self.log2feats(logs)        # [B,L,H]
        final= feats[:,-1,:]               # [B,H]
        ie   = self.item_emb(items)        # [I,H]
        return (ie * final).sum(-1)

class SASRecDataset(Dataset):
    def __init__(self, logs, pos, neg):
        self.logs = logs
        self.pos  = pos
        self.neg  = neg
    def __len__(self):
        return len(self.logs)
    def __getitem__(self, idx):
        return self.logs[idx], self.pos[idx], self.neg[idx]

def create_sasrec_sequences(df, num_items):
    """
    user -> [shifted+1 item IDs clipped to 1..num_items]
    """
    user_seq = {}
    for u, grp in df.sort_values(['user_idx','timestamp']).groupby('user_idx'):
        # The clip method here is a safe way to handle potential index errors
        seq = (grp.movie_idx.values + 1).clip(1, num_items).astype(int).tolist()
        user_seq[u] = seq
    return user_seq

def create_sasrec_training_data(user_seqs, num_items, maxlen):
    """
    Precompute (log_seq, pos, neg) for each subsequence.
    """
    logs, pos, neg = [], [], []
    all_items = set(range(1, num_items+1))
    for u, seq in user_seqs.items():
        for i in range(1, len(seq)):
            inp = seq[:i]
            tgt = seq[i]
            negs = list(all_items - set(inp))
            nid  = np.random.choice(negs)
            padded = [0]*(maxlen - len(inp)) + inp[-maxlen:]
            logs.append(padded); pos.append(tgt); neg.append(nid)
    return (np.array(logs, dtype=np.int64),
            np.array(pos,  dtype=np.int64),
            np.array(neg,  dtype=np.int64))

def train_sasrec(model, dataset, optimizer, device, bs=128):
    """
    One epoch over the SASRecDataset.
    """
    loader = DataLoader(dataset, batch_size=bs, shuffle=True)
    model.train()
    for logs, pos, neg in loader:
        optimizer.zero_grad()
        pl, nl = model(None,
                       logs.cpu().numpy(),
                       pos.cpu().numpy(),
                       neg.cpu().numpy())
        loss = F.binary_cross_entropy_with_logits(pl, torch.ones_like(pl)) \
             + F.binary_cross_entropy_with_logits(nl, torch.zeros_like(nl))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

def evaluate_sasrec_model(model,
                          train_seqs,
                          eval_seqs,
                          num_items,
                          device,
                          k=10,
                          maxlen=50,
                          val_seqs=None):
    """
    Vectorized eval for SASRec: HR/ndcg/prec/rec/f1/mrr
    """
    model.eval()
    users = sorted(set(train_seqs) & set(eval_seqs))
    U = len(users)
    seqs = torch.zeros((U, maxlen), dtype=torch.long, device=device)

    # build prefix matrix
    for i,u in enumerate(users):
        hist = train_seqs[u].copy()
        if val_seqs and u in val_seqs:
            hist += val_seqs[u]
        tail = hist[-maxlen:]
        tail = [min(max(1,int(x)), num_items) for x in tail]
        L = len(tail)
        if L>0:
            seqs[i, maxlen-L:] = torch.tensor(tail, device=device)

    with torch.no_grad():
        feats = model.log2feats(seqs)   # [U,L,H]
        final= feats[:,-1,:]            # [U,H]
        items= torch.arange(1, num_items+1, device=device)
        ie   = model.item_emb(items)    # [I,H]
        scores = final @ ie.T           # [U,I]

    # mask seen
    for i,u in enumerate(users):
        seen = set(train_seqs[u])
        if val_seqs and u in val_seqs:
            seen |= set(val_seqs[u])
        mask = [s-1 for s in seen if 1<=s<=num_items]
        if mask:
            scores[i, mask] = -1e9

    topk = torch.topk(scores, k, dim=1).indices.cpu().tolist()
    hr, ndcg, prec, rec, f1, mrr = [],[],[],[],[],[]
    for i,u in enumerate(users):
        truth = set(eval_seqs[u])
        hits  = [1 if (t+1) in truth else 0 for t in topk[i]]
        hr.append(1.0 if any(hits) else 0.0)
        dcg  = sum(h/math.log2(j+2) for j,h in enumerate(hits))
        idcg = sum(1.0/math.log2(j+2) for j in range(min(len(truth),k)))
        ndcg.append(dcg/idcg if idcg>0 else 0.0)
        p = sum(hits)/k
        r = sum(hits)/len(truth)
        prec.append(p); rec.append(r)
        f1.append(2*p*r/(p+r) if (p+r)>0 else 0.0)
        mrr.append(next((1.0/(j+1) for j,h in enumerate(hits) if h), 0.0))

    return {
      'hr':        np.mean(hr),
      'ndcg':      np.mean(ndcg),
      'precision': np.mean(prec),
      'recall':    np.mean(rec),
      'f1':        np.mean(f1),
      'mrr':       np.mean(mrr)
    }

#===============================================================================
# 5) Cross‐Validation
#===============================================================================
def robust_cross_validation_graph(model_cls, params, ratings_df,
                                  n_folds=3, device='cuda'):
    print(f"\n=== Graph CV ({n_folds} folds) ===")
    scores=[]
    for fold in range(n_folds):
        print(f"\n-- fold {fold+1}")
        set_seed(42+fold)
        tr,va,te,nu,ni = create_train_val_test_split(
            ratings_df, min_interactions=3, seed=42+fold
        )
        train_data = create_pyg_data(tr, nu, ni).to(device)
        val_data   = create_pyg_data(va, nu, ni).to(device)
        test_data  = create_pyg_data(te, nu, ni).to(device)

        model = model_cls(nu, ni, **params).to(device)
        opt   = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

        best_hr, pat = 0.0, 0
        ckpt = f'gat_fold_{fold}.pt'
        for ep in range(50):
            train_graph_epoch_regularized(model, train_data, opt, device)
            if (ep+1)%5==0:
                m = evaluate_graph_model(model, train_data, val_data, device)
                print(f"  ep {ep+1:02d} val HR@10 = {m['hr']:.4f}")
                if m['hr']>best_hr:
                    best_hr, pat = m['hr'], 0
                    torch.save(model.state_dict(), ckpt)
                else:
                    pat += 1
                if pat>=5:
                    print("  early stopping"); break

        model.load_state_dict(torch.load(ckpt))
        tm = evaluate_graph_model(model, train_data, test_data, device)
        print(f"  TEST HR@10 = {tm['hr']:.4f}")
        scores.append(tm['hr'])
        os.remove(ckpt)

    m, s = np.mean(scores), np.std(scores)
    print(f"\nGraph CV HR@10 = {m:.4f} ± {s:.4f}")
    return m, s

def robust_cross_validation_sasrec(params, ratings_df,
                                   n_folds=3, device='cuda'):
    print(f"\n=== SASRec CV ({n_folds} folds) ===")
    scores=[]
    for fold in range(n_folds):
        print(f"\n-- fold {fold+1}")
        set_seed(42+fold)
        tr,va,te,nu,ni = create_train_val_test_split(
            ratings_df, min_interactions=3, seed=42+fold
        )
        train_seqs = create_sasrec_sequences(tr, num_items=ni)
        val_seqs   = create_sasrec_sequences(va, num_items=ni)
        test_seqs  = create_sasrec_sequences(te, num_items=ni)

        logs,pos,neg = create_sasrec_training_data(train_seqs, ni, params['maxlen'])
        ds = SASRecDataset(logs,pos,neg)

        model = SASRec(nu, ni,
                       maxlen=params['maxlen'],
                       hidden=params['hidden'],
                       blocks=params['blocks'],
                       heads=params['heads'],
                       drop=params['drop']).to(device)
        opt   = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

        best_hr, pat = 0.0, 0
        ckpt = f'sasrec_fold_{fold}.pt'
        for ep in range(50):
            train_sasrec(model, ds, opt, device)
            if (ep+1)%5==0:
                m = evaluate_sasrec_model(model,
                                          train_seqs, val_seqs,
                                          ni, device,
                                          k=10, maxlen=params['maxlen'])
                print(f"  ep {ep+1:02d} val HR@10 = {m['hr']:.4f}")
                if m['hr']>best_hr:
                    best_hr, pat = m['hr'], 0
                    torch.save(model.state_dict(), ckpt)
                else:
                    pat += 1
                if pat>=5:
                    print("  early stopping"); break

        model.load_state_dict(torch.load(ckpt))
        tm = evaluate_sasrec_model(model,
                                   train_seqs, test_seqs,
                                   ni, device,
                                   k=10, maxlen=params['maxlen'],
                                   val_seqs=val_seqs)
        print(f"  TEST HR@10 = {tm['hr']:.4f}")
        scores.append(tm['hr'])
        os.remove(ckpt)

    m,s = np.mean(scores), np.std(scores)
    print(f"\nSASRec CV HR@10 = {m:.4f} ± {s:.4f}")
    return m, s

#===============================================================================
# 6) Main
#===============================================================================
if __name__=="__main__":
    set_seed(42)
    # load ratings.csv
    possible = ['/content/drive/MyDrive/movielens/ratings.csv','ratings.csv','data/ratings.csv','./ratings.csv']
    path = next((p for p in possible if os.path.exists(p)), None)
    if path is None:
        raise FileNotFoundError("ratings.csv not found")
    ratings_df = pd.read_csv(path)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Device:", device)

    # Graph CV
    gat_params  = {'embedding_dim':128,'hidden_dim':128,'heads':4,'n_layers':2}
    lgcn_params = {'embedding_dim':128,'n_layers':2}

    gm,gs = robust_cross_validation_graph(GATRec,   gat_params,  ratings_df,
                                         n_folds=3, device=device)
    lm,ls = robust_cross_validation_graph(LightGCN, lgcn_params, ratings_df,
                                         n_folds=3, device=device)

    # SASRec CV
    sasrec_params = {'maxlen':50,'hidden':50,'blocks':2,'heads':1,'drop':0.2}
    sm,ss = robust_cross_validation_sasrec(sasrec_params, ratings_df,
                                           n_folds=3, device=device)

    print("\n=== SUMMARY ===")
    print(f"GATRec   HR@10 = {gm:.4f} ± {gs:.4f}")
    print(f"LightGCN HR@10 = {lm:.4f} ± {ls:.4f}")
    print(f"SASRec   HR@10 = {sm:.4f} ± {ss:.4f}")


Device: cuda

=== Graph CV (3 folds) ===

-- fold 1
Split sizes: train=81200, val=9818, test=9818
#users=610, #items=8255
  ep 05 val HR@10 = 0.1098
  ep 10 val HR@10 = 0.1738
  ep 15 val HR@10 = 0.2262
  ep 20 val HR@10 = 0.2246
  ep 25 val HR@10 = 0.2393
  ep 30 val HR@10 = 0.2508
  ep 35 val HR@10 = 0.2410
  ep 40 val HR@10 = 0.2508
  ep 45 val HR@10 = 0.2557
  ep 50 val HR@10 = 0.2656
  TEST HR@10 = 0.2197

-- fold 2
Split sizes: train=81200, val=9818, test=9818
#users=610, #items=8255
  ep 05 val HR@10 = 0.1098
  ep 10 val HR@10 = 0.1787
  ep 15 val HR@10 = 0.2197
  ep 20 val HR@10 = 0.2361
  ep 25 val HR@10 = 0.2557
  ep 30 val HR@10 = 0.2738
  ep 35 val HR@10 = 0.2672
  ep 40 val HR@10 = 0.2574
  ep 45 val HR@10 = 0.2508
  ep 50 val HR@10 = 0.2541
  TEST HR@10 = 0.2033

-- fold 3
Split sizes: train=81200, val=9818, test=9818
#users=610, #items=8255
  ep 05 val HR@10 = 0.1098
  ep 10 val HR@10 = 0.1885
  ep 15 val HR@10 = 0.2279
  ep 20 val HR@10 = 0.2361
  ep 25 val HR@10 = 0.24

  seq = (grp.movie_idx.values + 1).clip(1, num_items).astype(int).tolist()


  ep 05 val HR@10 = 0.0820
  ep 10 val HR@10 = 0.0820
  ep 15 val HR@10 = 0.0820
  ep 20 val HR@10 = 0.0820
  ep 25 val HR@10 = 0.0820
  ep 30 val HR@10 = 0.0820
  early stopping
  TEST HR@10 = 0.0918

-- fold 2
Split sizes: train=81200, val=9818, test=9818
#users=610, #items=8255


  seq = (grp.movie_idx.values + 1).clip(1, num_items).astype(int).tolist()


  ep 05 val HR@10 = 0.0820
  ep 10 val HR@10 = 0.0820
  ep 15 val HR@10 = 0.0820
  ep 20 val HR@10 = 0.0820
  ep 25 val HR@10 = 0.0820
  ep 30 val HR@10 = 0.0820
  early stopping
  TEST HR@10 = 0.0918

-- fold 3
Split sizes: train=81200, val=9818, test=9818
#users=610, #items=8255


  seq = (grp.movie_idx.values + 1).clip(1, num_items).astype(int).tolist()


  ep 05 val HR@10 = 0.0820
  ep 10 val HR@10 = 0.0820
  ep 15 val HR@10 = 0.0820
  ep 20 val HR@10 = 0.0820
  ep 25 val HR@10 = 0.0820
  ep 30 val HR@10 = 0.0820
  early stopping
  TEST HR@10 = 0.0918

SASRec CV HR@10 = 0.0918 ± 0.0000

=== SUMMARY ===
GATRec   HR@10 = 0.2169 ± 0.0102
LightGCN HR@10 = 0.2022 ± 0.0085
SASRec   HR@10 = 0.0918 ± 0.0000


In [None]:
import os
import math
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch_geometric.data import Data
from torch_geometric.nn import GATConv, MessagePassing
from torch_geometric.utils import degree
from collections import defaultdict

#===============================================================================
# 1) Utilities
#===============================================================================
def set_seed(seed=42):
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark    = False

def create_train_val_test_split(ratings_df,
                                val_ratio=0.1,
                                test_ratio=0.1,
                                min_interactions=1,
                                positive_threshold=None,
                                seed=42):
    """
    Chronological per-user train/val/test split.
    Returns train_df, val_df, test_df, num_users, num_items
    """
    np.random.seed(seed)
    df = ratings_df.copy()
    if positive_threshold is not None:
        df = df[df.rating >= positive_threshold]

    counts = df.user_id.value_counts()  # Changed from userId to user_id
    valid  = counts[counts >= min_interactions].index
    df = df[df.user_id.isin(valid)]     # Changed from userId to user_id
    df = df.sort_values(['user_id','timestamp'])  # Changed from userId to user_id

    trains, vals, tests = [], [], []
    for _, udf in df.groupby('user_id'):  # Changed from userId to user_id
        n = len(udf)
        n_test  = max(1, int(n*test_ratio))
        n_val   = max(1, int(n*val_ratio))
        n_train = n - n_val - n_test
        if n_train < 1:
            continue
        trains.append( udf.iloc[:n_train] )
        vals.append(   udf.iloc[n_train:n_train+n_val] )
        tests.append(  udf.iloc[n_train+n_val:] )

    train_df = pd.concat(trains, ignore_index=True)
    val_df   = pd.concat(vals,   ignore_index=True)
    test_df  = pd.concat(tests,  ignore_index=True)

    # remap to contiguous 0..N-1
    u2idx = {u:i for i,u in enumerate(train_df.user_id.unique())}    # Changed from userId
    i2idx = {m:i for i,m in enumerate(train_df.item_id.unique())}    # Changed from movieId to item_id

    for d in (train_df, val_df, test_df):
        d['user_idx']  = d.user_id .map(u2idx)   # Changed from userId
        d['movie_idx'] = d.item_id.map(i2idx)    # Changed from movieId to item_id

    num_users = len(u2idx)
    num_items = len(i2idx)
    print(f"Split sizes: train={len(train_df)}, val={len(val_df)}, test={len(test_df)}")
    print(f"#users={num_users}, #items={num_items}")
    return train_df, val_df, test_df, num_users, num_items

def create_pyg_data(df, num_users, num_items):
    """
    Build a PyG bipartite graph.
    Users: 0..num_users-1, items: num_users..num_users+num_items-1
    """
    us  = torch.LongTensor(df.user_idx.values)
    is_ = torch.LongTensor(df.movie_idx.values) + num_users
    edge_index = torch.stack([torch.cat([us,is_]), torch.cat([is_,us])], dim=0)
    data = Data(edge_index=edge_index, num_nodes=num_users+num_items)
    data.num_users = num_users
    data.num_items = num_items
    data.orig_interactions = torch.stack([us,is_], dim=1)
    return data

#===============================================================================
# 2) Graph eval + train
#===============================================================================
def evaluate_graph_model(model, train_data, eval_data, device, k=10):
    """
    HR@k, NDCG@k, precision, recall, f1, mrr for graph-based models
    """
    model.eval()
    metrics = defaultdict(list)

    # 1) compute embeddings on the correct device
    with torch.no_grad():
        edge_index = train_data.edge_index.to(device)
        embs = model(edge_index)
        U = model.num_users
        user_embs = embs[:U]
        item_embs = embs[U:]

    # 2) build train‐map and eval‐map on CPU
    train_map = defaultdict(set)
    for u,i in train_data.orig_interactions.cpu().tolist():
        train_map[u].add(i - U)

    eval_map = defaultdict(list)
    for u,i in eval_data.orig_interactions.cpu().tolist():
        eval_map[u].append(i - U)

    # 3) per‐user ranking
    for u, true_items in eval_map.items():
        if not true_items:
            continue

        # Perform multiplication on the GPU, then move result to CPU
        scores = (item_embs @ user_embs[u]).cpu().numpy()

        # mask training items
        for ti in train_map[u]:
            scores[ti] = -1e9

        # top‐k
        topk = np.argpartition(-scores, k)[:k]
        hits = [1 if t in true_items else 0 for t in topk]

        # HR
        hr = 1.0 if any(hits) else 0.0
        # NDCG
        dcg  = sum(h / math.log2(idx+2) for idx,h in enumerate(hits))
        idcg = sum(1.0/math.log2(i+2) for i in range(min(len(true_items), k)))
        ndcg = dcg/idcg if idcg>0 else 0.0
        # prec/rec/f1
        prec = sum(hits)/k
        rec  = sum(hits)/len(true_items)
        f1   = 2*prec*rec/(prec+rec) if (prec+rec)>0 else 0.0
        # MRR
        mrr  = next((1.0/(idx+1) for idx,h in enumerate(hits) if h), 0.0)

        metrics['hr'].append(hr)
        metrics['ndcg'].append(ndcg)
        metrics['precision'].append(prec)
        metrics['recall'].append(rec)
        metrics['f1'].append(f1)
        metrics['mrr'].append(mrr)

    return {m: np.mean(v) for m,v in metrics.items()}

def train_graph_epoch_regularized(model,
                                  train_data,
                                  optimizer,
                                  device,
                                  batch_size=16384,
                                  l2_reg=1e-4):
    """
    One epoch of BPR + L2 training for graph models.
    """
    model.train()
    edge_index = train_data.edge_index.to(device)
    pos_edges  = train_data.orig_interactions.to(device)
    E = pos_edges.size(0)
    perm = torch.randperm(E, device=device)
    n_batches = (E + batch_size - 1) // batch_size
    total_loss = 0.0

    for b in range(n_batches):
        idx   = perm[b*batch_size:(b+1)*batch_size]
        users = pos_edges[idx,0]
        items = pos_edges[idx,1]
        negs  = torch.randint(model.num_users,
                              model.num_users+model.num_items,
                              size=users.size(), device=device)

        optimizer.zero_grad()
        pos_scores = model.predict(edge_index, users, items)
        neg_scores = model.predict(edge_index, users, negs)
        bpr_loss = -torch.log(torch.sigmoid(pos_scores - neg_scores)+1e-8).mean()
        l2_loss  = l2_reg * (
                model.user_emb(users).pow(2).mean()
              + model.item_emb(items-model.num_users).pow(2).mean()
              + model.item_emb(negs-model.num_users).pow(2).mean()
        )
        loss = bpr_loss + l2_loss
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(),1.0)
        optimizer.step()
        total_loss += loss.item()

    return total_loss / n_batches

#===============================================================================
# 3) Graph Models
#===============================================================================
class GATRec(nn.Module):
    def __init__(self, num_users, num_items,
                 embedding_dim=64, hidden_dim=64,
                 heads=4, n_layers=3):
        super().__init__()
        self.num_users = num_users
        self.num_items = num_items
        self.user_emb  = nn.Embedding(num_users, embedding_dim)
        self.item_emb  = nn.Embedding(num_items, embedding_dim)
        self.convs     = nn.ModuleList()
        # first layer
        self.convs.append(GATConv(embedding_dim, hidden_dim, heads=heads, dropout=0.6))
        # middle
        for _ in range(n_layers-2):
            self.convs.append(GATConv(hidden_dim*heads, hidden_dim, heads=heads, dropout=0.6))
        # final
        self.convs.append(GATConv(hidden_dim*heads, hidden_dim, heads=1, concat=False, dropout=0.6))
        self.dropout = nn.Dropout(0.6)
        # init
        nn.init.xavier_uniform_(self.user_emb.weight)
        nn.init.xavier_uniform_(self.item_emb.weight)

    def forward(self, edge_index):
        x = torch.cat([self.user_emb.weight, self.item_emb.weight], dim=0)
        for conv in self.convs[:-1]:
            x = conv(x, edge_index)
            x = F.relu(x); x = self.dropout(x)
        x = self.convs[-1](x, edge_index)
        return x

    def predict(self, edge_index, u, i):
        embs = self(edge_index)
        return (embs[u]*embs[i]).sum(-1)

class LightGCNConv(MessagePassing):
    def __init__(self): super().__init__(aggr='add')
    def forward(self, x, edge_index):
        row,col = edge_index
        deg = degree(torch.cat([row,col]), x.size(0), dtype=x.dtype)
        d   = deg.pow(-0.5); d[d==float('inf')] = 0
        norm= d[row]*d[col]
        return self.propagate(edge_index, x=x, norm=norm)
    def message(self,x_j,norm):
        return norm.view(-1,1)*x_j

class LightGCN(nn.Module):
    def __init__(self, num_users, num_items,
                 embedding_dim=64, n_layers=3):
        super().__init__()
        self.num_users = num_users
        self.num_items = num_items
        self.user_emb  = nn.Embedding(num_users, embedding_dim)
        self.item_emb  = nn.Embedding(num_items, embedding_dim)
        self.conv      = LightGCNConv()
        self.n_layers  = n_layers
        nn.init.xavier_uniform_(self.user_emb.weight)
        nn.init.xavier_uniform_(self.item_emb.weight)

    def forward(self, edge_index):
        x0 = torch.cat([self.user_emb.weight, self.item_emb.weight], dim=0)
        embs = [x0]; x=x0
        for _ in range(self.n_layers):
            x = self.conv(x, edge_index)
            embs.append(x)
        return torch.stack(embs,0).mean(0)

    def predict(self, edge_index, u, i):
        embs = self(edge_index)
        return (embs[u]*embs[i]).sum(-1)

#===============================================================================
# 4) SASRec + train/eval
#===============================================================================
class SASRecBlock(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout):
        super().__init__()
        self.ln1  = nn.LayerNorm(d_model, eps=1e-8)
        self.attn = nn.MultiheadAttention(d_model,n_heads,
                                          dropout=dropout,
                                          batch_first=True)
        self.drop1= nn.Dropout(dropout)
        self.ln2  = nn.LayerNorm(d_model, eps=1e-8)
        self.ffn  = nn.Sequential(
            nn.Linear(d_model,d_ff), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(d_ff,d_model), nn.Dropout(dropout)
        )
    def forward(self, x, causal_mask, pad_mask):
        z = x; x_n = self.ln1(z)
        a,_ = self.attn(x_n, x_n, x_n,
                       attn_mask=causal_mask,
                       key_padding_mask=pad_mask)
        x = z + self.drop1(a)
        z2 = x; x2_n = self.ln2(z2)
        f  = self.ffn(x2_n)
        return z2 + f

class SASRec(nn.Module):
    PAD = 0
    def __init__(self, num_users, num_items,
                 maxlen=50, hidden=50,
                 blocks=2, heads=1, drop=0.2):
        super().__init__()
        self.maxlen   = maxlen
        self.item_emb = nn.Embedding(num_items+1, hidden, padding_idx=self.PAD)
        self.pos_emb  = nn.Embedding(maxlen+1, hidden, padding_idx=0)
        self.drop     = nn.Dropout(drop)
        self.blocks   = nn.ModuleList([
            SASRecBlock(hidden, heads, hidden, drop) for _ in range(blocks)
        ])
        self.ln_final = nn.LayerNorm(hidden, eps=1e-8)
        # init
        for m in self.modules():
            if isinstance(m, (nn.Embedding, nn.Linear)):
                nn.init.xavier_uniform_(m.weight)
                if hasattr(m,'bias') and m.bias is not None:
                    nn.init.zeros_(m.bias)

    def log2feats(self, logs):
        B,L = logs.shape
        x   = self.item_emb(logs) * math.sqrt(self.item_emb.embedding_dim)
        pos = torch.arange(1, L+1, device=logs.device).unsqueeze(0).expand(B,-1)
        pos = pos * (logs != self.PAD)
        x   = x + self.pos_emb(pos)
        x   = self.drop(x)
        pad_mask = logs.eq(self.PAD)
        causal   = torch.triu(
            torch.ones(L,L,device=logs.device,dtype=torch.bool), diagonal=1
        )
        for blk in self.blocks:
            x = blk(x, causal_mask=causal, pad_mask=pad_mask)
        return self.ln_final(x)

    def forward(self, _, log_seqs, pos_seqs, neg_seqs):
        device = next(self.parameters()).device
        logs = torch.tensor(log_seqs, dtype=torch.long, device=device)
        pos  = torch.tensor(pos_seqs, dtype=torch.long, device=device)
        neg  = torch.tensor(neg_seqs, dtype=torch.long, device=device)
        feats= self.log2feats(logs)        # [B,L,H]
        final= feats[:,-1,:]               # [B,H]
        pe   = self.item_emb(pos)          # [B,H]
        ne   = self.item_emb(neg)          # [B,H]
        return (final*pe).sum(-1), (final*ne).sum(-1)

    def predict(self, _, log_seqs, item_idx):
        device = next(self.parameters()).device
        logs = torch.tensor(log_seqs, dtype=torch.long, device=device)
        items= torch.tensor(item_idx, dtype=torch.long, device=device)
        feats= self.log2feats(logs)        # [B,L,H]
        final= feats[:,-1,:]               # [B,H]
        ie   = self.item_emb(items)        # [I,H]
        return (ie * final).sum(-1)

class SASRecDataset(Dataset):
    def __init__(self, logs, pos, neg):
        self.logs = logs
        self.pos  = pos
        self.neg  = neg
    def __len__(self):
        return len(self.logs)
    def __getitem__(self, idx):
        return self.logs[idx], self.pos[idx], self.neg[idx]

def create_sasrec_sequences(df, num_items):
    """
    user -> [shifted+1 item IDs clipped to 1..num_items]
    """
    user_seq = {}
    for u, grp in df.sort_values(['user_idx','timestamp']).groupby('user_idx'):
        # The clip method here is a safe way to handle potential index errors
        seq = (grp.movie_idx.values + 1).clip(1, num_items).astype(int).tolist()
        user_seq[u] = seq
    return user_seq

def create_sasrec_training_data(user_seqs, num_items, maxlen):
    """
    Precompute (log_seq, pos, neg) for each subsequence.
    """
    logs, pos, neg = [], [], []
    all_items = set(range(1, num_items+1))
    for u, seq in user_seqs.items():
        for i in range(1, len(seq)):
            inp = seq[:i]
            tgt = seq[i]
            negs = list(all_items - set(inp))
            nid  = np.random.choice(negs)
            padded = [0]*(maxlen - len(inp)) + inp[-maxlen:]
            logs.append(padded); pos.append(tgt); neg.append(nid)
    return (np.array(logs, dtype=np.int64),
            np.array(pos,  dtype=np.int64),
            np.array(neg,  dtype=np.int64))

def train_sasrec(model, dataset, optimizer, device, bs=512):  # INCREASED BATCH SIZE
    """
    One epoch over the SASRecDataset.
    """
    loader = DataLoader(dataset, batch_size=bs, shuffle=True)
    model.train()
    for logs, pos, neg in loader:
        optimizer.zero_grad()
        pl, nl = model(None,
                       logs.cpu().numpy(),
                       pos.cpu().numpy(),
                       neg.cpu().numpy())
        loss = F.binary_cross_entropy_with_logits(pl, torch.ones_like(pl)) \
             + F.binary_cross_entropy_with_logits(nl, torch.zeros_like(nl))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

def evaluate_sasrec_model(model,
                          train_seqs,
                          eval_seqs,
                          num_items,
                          device,
                          k=10,
                          maxlen=50,
                          val_seqs=None):
    """
    Vectorized eval for SASRec: HR/ndcg/prec/rec/f1/mrr
    """
    model.eval()
    users = sorted(set(train_seqs) & set(eval_seqs))
    U = len(users)
    seqs = torch.zeros((U, maxlen), dtype=torch.long, device=device)

    # build prefix matrix
    for i,u in enumerate(users):
        hist = train_seqs[u].copy()
        if val_seqs and u in val_seqs:
            hist += val_seqs[u]
        tail = hist[-maxlen:]
        tail = [min(max(1,int(x)), num_items) for x in tail]
        L = len(tail)
        if L>0:
            seqs[i, maxlen-L:] = torch.tensor(tail, device=device)

    with torch.no_grad():
        feats = model.log2feats(seqs)   # [U,L,H]
        final= feats[:,-1,:]            # [U,H]
        items= torch.arange(1, num_items+1, device=device)
        ie   = model.item_emb(items)    # [I,H]
        scores = final @ ie.T           # [U,I]

    # mask seen
    for i,u in enumerate(users):
        seen = set(train_seqs[u])
        if val_seqs and u in val_seqs:
            seen |= set(val_seqs[u])
        mask = [s-1 for s in seen if 1<=s<=num_items]
        if mask:
            scores[i, mask] = -1e9

    topk = torch.topk(scores, k, dim=1).indices.cpu().tolist()
    hr, ndcg, prec, rec, f1, mrr = [],[],[],[],[],[]
    for i,u in enumerate(users):
        truth = set(eval_seqs[u])
        hits  = [1 if (t+1) in truth else 0 for t in topk[i]]
        hr.append(1.0 if any(hits) else 0.0)
        dcg  = sum(h/math.log2(j+2) for j,h in enumerate(hits))
        idcg = sum(1.0/math.log2(j+2) for j in range(min(len(truth),k)))
        ndcg.append(dcg/idcg if idcg>0 else 0.0)
        p = sum(hits)/k
        r = sum(hits)/len(truth)
        prec.append(p); rec.append(r)
        f1.append(2*p*r/(p+r) if (p+r)>0 else 0.0)
        mrr.append(next((1.0/(j+1) for j,h in enumerate(hits) if h), 0.0))

    return {
      'hr':        np.mean(hr),
      'ndcg':      np.mean(ndcg),
      'precision': np.mean(prec),
      'recall':    np.mean(rec),
      'f1':        np.mean(f1),
      'mrr':       np.mean(mrr)
    }

#===============================================================================
# 5) Cross‐Validation
#===============================================================================
def robust_cross_validation_graph(model_cls, params, ratings_df,
                                  n_folds=3, device='cuda'):
    print(f"\n=== Graph CV ({n_folds} folds) ===")
    scores=[]
    for fold in range(n_folds):
        print(f"\n-- fold {fold+1}")
        set_seed(42+fold)
        tr,va,te,nu,ni = create_train_val_test_split(
            ratings_df, min_interactions=5, seed=42+fold
        )
        train_data = create_pyg_data(tr, nu, ni).to(device)
        val_data   = create_pyg_data(va, nu, ni).to(device)
        test_data  = create_pyg_data(te, nu, ni).to(device)

        model = model_cls(nu, ni, **params).to(device)
        opt   = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

        best_hr, pat = 0.0, 0
        ckpt = f'gat_fold_{fold}.pt'
        for ep in range(201):  # CHANGED TO 201 EPOCHS
            train_graph_epoch_regularized(model, train_data, opt, device)
            if (ep+1)%20==0:  # EVALUATE EVERY 20 EPOCHS
                m = evaluate_graph_model(model, train_data, val_data, device)
                print(f"  ep {ep+1:02d} val HR@10 = {m['hr']:.4f}")
                if m['hr']>best_hr:
                    best_hr, pat = m['hr'], 0
                    torch.save(model.state_dict(), ckpt)
                else:
                    pat += 1
                if pat>=5:
                    print("  early stopping"); break

        model.load_state_dict(torch.load(ckpt))
        tm = evaluate_graph_model(model, train_data, test_data, device)
        print(f"  TEST HR@10 = {tm['hr']:.4f}")
        scores.append(tm['hr'])
        os.remove(ckpt)

    m, s = np.mean(scores), np.std(scores)
    print(f"\nGraph CV HR@10 = {m:.4f} ± {s:.4f}")
    return m, s

def robust_cross_validation_sasrec(params, ratings_df,
                                   n_folds=3, device='cuda'):
    print(f"\n=== SASRec CV ({n_folds} folds) ===")
    scores=[]
    for fold in range(n_folds):
        print(f"\n-- fold {fold+1}")
        set_seed(42+fold)
        tr,va,te,nu,ni = create_train_val_test_split(
            ratings_df, min_interactions=5, seed=42+fold
        )
        train_seqs = create_sasrec_sequences(tr, num_items=ni)
        val_seqs   = create_sasrec_sequences(va, num_items=ni)
        test_seqs  = create_sasrec_sequences(te, num_items=ni)

        logs,pos,neg = create_sasrec_training_data(train_seqs, ni, params['maxlen'])
        ds = SASRecDataset(logs,pos,neg)

        model = SASRec(nu, ni,
                       maxlen=params['maxlen'],
                       hidden=params['hidden'],
                       blocks=params['blocks'],
                       heads=params['heads'],
                       drop=params['drop']).to(device)
        opt   = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

        best_hr, pat = 0.0, 0
        ckpt = f'sasrec_fold_{fold}.pt'
        for ep in range(201):  # CHANGED TO 201 EPOCHS
            train_sasrec(model, ds, opt, device)
            if (ep+1)%20==0:  # EVALUATE EVERY 20 EPOCHS
                m = evaluate_sasrec_model(model,
                                          train_seqs, val_seqs,
                                          ni, device,
                                          k=10, maxlen=params['maxlen'])
                print(f"  ep {ep+1:02d} val HR@10 = {m['hr']:.4f}")
                if m['hr']>best_hr:
                    best_hr, pat = m['hr'], 0
                    torch.save(model.state_dict(), ckpt)
                else:
                    pat += 1
                if pat>=5:
                    print("  early stopping"); break

        model.load_state_dict(torch.load(ckpt))
        tm = evaluate_sasrec_model(model,
                                   train_seqs, test_seqs,
                                   ni, device,
                                   k=10, maxlen=params['maxlen'],
                                   val_seqs=val_seqs)
        print(f"  TEST HR@10 = {tm['hr']:.4f}")
        scores.append(tm['hr'])
        os.remove(ckpt)

    m,s = np.mean(scores), np.std(scores)
    print(f"\nSASRec CV HR@10 = {m:.4f} ± {s:.4f}")
    return m, s

#===============================================================================
# 6) Main - MODIFIED FOR ML-1M DATASET
#===============================================================================
if __name__=="__main__":
    set_seed(42)

    # Load ml-1m.txt dataset (space-separated: user_id item_id)
    possible = ['/content/data/ml-1m.txt', 'data/ml-1m.txt', './ml-1m.txt', 'ml-1m.txt']
    path = next((p for p in possible if os.path.exists(p)), None)
    if path is None:
        raise FileNotFoundError("ml-1m.txt not found")

    # Read the ml-1m.txt file (format: user_id item_id per line)
    data = []
    with open(path, 'r') as f:
        for line_num, line in enumerate(f):
            parts = line.strip().split()
            if len(parts) >= 2:
                try:
                    user_id = int(parts[0])
                    item_id = int(parts[1])
                    data.append([user_id, item_id, line_num])  # Use line number as timestamp
                except ValueError:
                    continue

    # Create DataFrame with the correct column names
    ratings_df = pd.DataFrame(data, columns=['user_id', 'item_id', 'timestamp'])
    print(f"Loaded {len(ratings_df)} interactions from ml-1m dataset")
    print(f"Users: {ratings_df.user_id.nunique()}, Items: {ratings_df.item_id.nunique()}")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Device:", device)

    # Graph CV
    gat_params  = {'embedding_dim':128,'hidden_dim':128,'heads':4,'n_layers':2}
    lgcn_params = {'embedding_dim':128,'n_layers':2}

    gm,gs = robust_cross_validation_graph(GATRec,   gat_params,  ratings_df,
                                         n_folds=3, device=device)
    lm,ls = robust_cross_validation_graph(LightGCN, lgcn_params, ratings_df,
                                         n_folds=3, device=device)

    # SASRec CV
    sasrec_params = {'maxlen':50,'hidden':50,'blocks':2,'heads':1,'drop':0.2}
    sm,ss = robust_cross_validation_sasrec(sasrec_params, ratings_df,
                                           n_folds=3, device=device)

    print("\n=== SUMMARY ===")
    print(f"GATRec   HR@10 = {gm:.4f} ± {gs:.4f}")
    print(f"LightGCN HR@10 = {lm:.4f} ± {ls:.4f}")
    print(f"SASRec   HR@10 = {sm:.4f} ± {ss:.4f}")

Loaded 999611 interactions from ml-1m dataset
Users: 6040, Items: 3416
Device: cuda

=== Graph CV (3 folds) ===

-- fold 1
Split sizes: train=804987, val=97312, test=97312
#users=6040, #items=3415
  ep 20 val HR@10 = 0.2791
  ep 40 val HR@10 = 0.2854
  ep 60 val HR@10 = 0.3263
  ep 80 val HR@10 = 0.3243
  ep 100 val HR@10 = 0.3310
  ep 120 val HR@10 = 0.3379
  ep 140 val HR@10 = 0.3389
  ep 160 val HR@10 = 0.3411
  ep 180 val HR@10 = 0.3417
  ep 200 val HR@10 = 0.3483
  TEST HR@10 = 0.3005

-- fold 2
Split sizes: train=804987, val=97312, test=97312
#users=6040, #items=3415
  ep 20 val HR@10 = 0.3003
  ep 40 val HR@10 = 0.2972
  ep 60 val HR@10 = 0.2907
  ep 80 val HR@10 = 0.3288
  ep 100 val HR@10 = 0.3250
  ep 120 val HR@10 = 0.3315
  ep 140 val HR@10 = 0.3275
  ep 160 val HR@10 = 0.3341
  ep 180 val HR@10 = 0.3353
  ep 200 val HR@10 = 0.3348
  TEST HR@10 = 0.2978

-- fold 3
Split sizes: train=804987, val=97312, test=97312
#users=6040, #items=3415
  ep 20 val HR@10 = 0.2760
  ep 40 va

  seq = (grp.movie_idx.values + 1).clip(1, num_items).astype(int).tolist()


  ep 20 val HR@10 = 0.2002
  ep 40 val HR@10 = 0.2002
  ep 60 val HR@10 = 0.2002
  ep 80 val HR@10 = 0.2002
  ep 100 val HR@10 = 0.2002
  ep 120 val HR@10 = 0.2002
  early stopping
  TEST HR@10 = 0.1907

-- fold 2
Split sizes: train=804987, val=97312, test=97312
#users=6040, #items=3415


  seq = (grp.movie_idx.values + 1).clip(1, num_items).astype(int).tolist()


  ep 20 val HR@10 = 0.2002
  ep 40 val HR@10 = 0.2002
  ep 60 val HR@10 = 0.2002
  ep 80 val HR@10 = 0.2002
  ep 100 val HR@10 = 0.2002
  ep 120 val HR@10 = 0.2002
  early stopping
  TEST HR@10 = 0.1907

-- fold 3
Split sizes: train=804987, val=97312, test=97312
#users=6040, #items=3415


  seq = (grp.movie_idx.values + 1).clip(1, num_items).astype(int).tolist()


  ep 20 val HR@10 = 0.2002
  ep 40 val HR@10 = 0.2002
  ep 60 val HR@10 = 0.2002
  ep 80 val HR@10 = 0.2002
  ep 100 val HR@10 = 0.2002
  ep 120 val HR@10 = 0.2002
  early stopping
  TEST HR@10 = 0.1907

SASRec CV HR@10 = 0.1907 ± 0.0000

=== SUMMARY ===
GATRec   HR@10 = 0.3015 ± 0.0035
LightGCN HR@10 = 0.0323 ± 0.0000
SASRec   HR@10 = 0.1907 ± 0.0000
