# SASRec on MovieLens
This notebook demonstrates the use of the sequenial recommendation algorithm, **SASRec**, to predict the next movie for a particular user.

## Imports

In [1]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from collections import defaultdict
import random

## Load and prepare data
***NOTE***  
It is assumed that that MovieLens-1M dataset has already been downloaded and placed next to this notebook in a folder named `ml-1m`.

This is what'll happen below:
- After loading the data, sort each user's sequence of movie ratings chronologically.
- If a user's sequence is less than 3 movie ratings long, use all movie ratings for training. Otherwise, use all but the last two ratings for training, second-to-last rating for validation and the last rating for testing.
- After that, we'll sample "negative" examples to use for training alongside the actual movies that were selected/rated by the user.

In [2]:
# ====================================================
# 1. Preprocess MovieLens 1M
# ====================================================
ratings = pd.read_csv(
    "ml-1m/ratings.dat",
    sep="::",
    engine="python",
    names=["user", "item", "rating", "timestamp"]
)

# Keep ratings >= 4
ratings = ratings[ratings["rating"] >= 4]

# Map to consecutive IDs
user2id = {u: i+1 for i, u in enumerate(ratings["user"].unique())}
item2id = {m: i+1 for i, m in enumerate(ratings["item"].unique())}
ratings["user"] = ratings["user"].map(user2id)
ratings["item"] = ratings["item"].map(item2id)

n_users = len(user2id)
n_items = len(item2id)

# Build user sequences
user_sequences = defaultdict(list)
for row in ratings.itertuples(index=False):
    user_sequences[row.user].append((row.item, row.timestamp))

# Sort by time
for u in user_sequences:
    user_sequences[u] = [x[0] for x in sorted(user_sequences[u], key=lambda x: x[1])]

# Leave-one-out split
train_seqs, valid_seqs, test_seqs = {}, {}, {}
for u, items in user_sequences.items():
    if len(items) < 3:
        train_seqs[u] = items
        valid_seqs[u], test_seqs[u] = [], []
    else:
        train_seqs[u] = items[:-2]
        valid_seqs[u] = [items[-2]]
        test_seqs[u] = [items[-1]]

In [3]:
# ====================================================
# 2. Dataset + Negative Sampling
# ====================================================
MAX_SEQ_LEN = 50  # max sequence length

class SASRecDataset(Dataset):
    def __init__(self, user_train, n_items, num_negatives=1):
        self.user_train = user_train
        self.users = list(user_train.keys())
        self.n_items = n_items
        self.num_negatives = num_negatives

    def __len__(self):
        return len(self.users)

    def __getitem__(self, idx):
        user = self.users[idx]
        seq = self.user_train[user]

        if len(seq) < 2:
            return torch.zeros(MAX_SEQ_LEN, dtype=torch.long), torch.tensor(0, dtype=torch.long), torch.tensor(0, dtype=torch.long)

        cut = random.randint(1, len(seq) - 1)
        prefix, target = seq[:cut], seq[cut]

        # Pad prefix
        seq_padded = [0] * (MAX_SEQ_LEN - len(prefix)) + prefix[-MAX_SEQ_LEN:]
        seq_tensor = torch.tensor(seq_padded, dtype=torch.long)

        # Negative sample
        neg = random.randint(1, self.n_items)
        while neg in seq:
            neg = random.randint(1, self.n_items)

        return seq_tensor, torch.tensor(target, dtype=torch.long), torch.tensor(neg, dtype=torch.long)


## SASRec model

In [4]:
# ====================================================
# 3. SASRec Model
# ====================================================
class SASRec(nn.Module):
    def __init__(self, n_items, hidden_dim=64, max_len=50, n_heads=2, n_layers=2, dropout=0.2):
        super().__init__()
        self.item_emb = nn.Embedding(n_items+1, hidden_dim, padding_idx=0)
        self.pos_emb = nn.Embedding(max_len, hidden_dim)
        self.dropout = nn.Dropout(dropout)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=n_heads,
            dim_feedforward=hidden_dim*4,
            dropout=dropout,
            activation="gelu",
            batch_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
        self.hidden_dim = hidden_dim
        self.max_len = max_len

    def forward(self, seq):
        device = seq.device
        positions = torch.arange(self.max_len, device=device).unsqueeze(0)
        x = self.item_emb(seq) + self.pos_emb(positions)
        x = self.dropout(x)

        mask = (seq == 0)
        x = self.encoder(x, src_key_padding_mask=mask)

        out = x[:, -1, :]  # last position representation
        return out

    def predict(self, seq, candidates):
        seq_repr = self.forward(seq)  # [B, H]
        item_repr = self.item_emb(candidates)  # [B, H]
        return (seq_repr * item_repr).sum(dim=-1)

## Train & Predict

In [5]:
# ====================================================
# 4. Training
# ====================================================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_dataset = SASRecDataset(train_seqs, n_items)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)

model = SASRec(n_items, hidden_dim=64, max_len=MAX_SEQ_LEN).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(20):  # small demo run
    model.train()
    total_loss = 0
    for seq, pos, neg in train_loader:
        seq, pos, neg = seq.to(device), pos.to(device), neg.to(device)

        seq_repr = model(seq)
        pos_emb = model.item_emb(pos)
        neg_emb = model.item_emb(neg)

        pos_score = (seq_repr * pos_emb).sum(dim=-1)
        neg_score = (seq_repr * neg_emb).sum(dim=-1)

        loss = -torch.mean(torch.log(torch.sigmoid(pos_score - neg_score)))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1}, Loss: {total_loss/len(train_loader):.4f}")

Epoch 1, Loss: 4.2063
Epoch 2, Loss: 3.4697
Epoch 3, Loss: 2.9608
Epoch 4, Loss: 2.6624
Epoch 5, Loss: 2.3778
Epoch 6, Loss: 2.2103
Epoch 7, Loss: 1.9235
Epoch 8, Loss: 1.7298
Epoch 9, Loss: 1.5111
Epoch 10, Loss: 1.5157
Epoch 11, Loss: 1.3694
Epoch 12, Loss: 1.2453
Epoch 13, Loss: 1.1094
Epoch 14, Loss: 1.0494
Epoch 15, Loss: 0.9634
Epoch 16, Loss: 0.8558
Epoch 17, Loss: 0.8164
Epoch 18, Loss: 0.7509
Epoch 19, Loss: 0.7098
Epoch 20, Loss: 0.6692


In [6]:
# Build reverse maps
id2item = {v: k for k, v in item2id.items()}

# Load movies.dat for movie names
movies = pd.read_csv(
    "ml-1m/movies.dat",
    sep="::",
    engine="python",
    names=["item", "title", "genres"], encoding='latin-1'
)

# Map movie IDs to our new indexing system
movies["item"] = movies["item"].map(item2id)
itemid2name = dict(zip(movies["item"], movies["title"]))

In [7]:
# ====================================================
# 5. Prediction Example
# ====================================================
model.eval()
with torch.no_grad():
    # Pick a random user
    # u = list(test_seqs.keys())[0]
    u = random.choice(list(test_seqs.keys()))
    seq = train_seqs[u][-MAX_SEQ_LEN:]
    seq_padded = [0]*(MAX_SEQ_LEN - len(seq)) + seq
    seq_tensor = torch.tensor([seq_padded], dtype=torch.long).to(device)

    # Candidate set: true item + 99 random negatives
    true_item = test_seqs[u][0]
    candidates = [true_item] + random.sample(range(1, n_items+1), 99)
    candidates_tensor = torch.tensor(candidates, dtype=torch.long).to(device)

    # Predict scores
    scores = model.predict(seq_tensor.repeat(len(candidates), 1),
                           candidates_tensor).cpu().numpy().flatten()

    ranked = np.argsort(-scores)
    rank_of_true = list(ranked).index(0) + 1

    # Get names
    true_item_name = itemid2name[true_item]
    candidates_names = [itemid2name[c] for c in candidates]
    
    print("True item:", true_item, "->", true_item_name)
    print("Candidates:", candidates_names[:5], "...")  # show a few
    print("True item rank:", rank_of_true, "out of", len(candidates))

True item: 2619 -> American Buffalo (1996)
Candidates: ['American Buffalo (1996)', 'Life Less Ordinary, A (1997)', 'Shaggy D.A., The (1976)', 'Serpico (1973)', 'Steel (1997)'] ...
True item rank: 66 out of 100


  output = torch._nested_tensor_from_mask(


In [8]:
# ====================================================
# 6. Model Evaluation
# ====================================================
def evaluate_model(model, train_seqs, test_seqs, n_items, itemid2name, K=10, num_neg=100):
    """
    Evaluate SASRec with Hit@K and NDCG@K.

    Args:
        model: trained SASRec model
        train_seqs: dict {user: [train items]}
        test_seqs: dict {user: [test item]}
        n_items: total number of items
        itemid2name: dict mapping item_id -> movie name
        K: cutoff for metrics
        num_neg: number of negative samples per test

    Returns:
        (hit_rate, ndcg)
    """
    model.eval()
    hits, ndcgs = [], []
    example_outputs = []

    with torch.no_grad():
        for u in test_seqs:
            if len(test_seqs[u]) == 0:
                continue

            true_item = test_seqs[u][0]
            seq = train_seqs[u][-MAX_SEQ_LEN:]
            seq_padded = [0]*(MAX_SEQ_LEN - len(seq)) + seq
            seq_tensor = torch.tensor([seq_padded], dtype=torch.long).to(device)

            # Candidate set: true item + random negatives
            candidates = [true_item] + random.sample(range(1, n_items+1), num_neg)
            candidates_tensor = torch.tensor(candidates, dtype=torch.long).to(device)

            scores = model.predict(seq_tensor.repeat(len(candidates), 1),
                                   candidates_tensor).cpu().numpy().flatten()

            ranked = np.argsort(-scores)
            rank_of_true = list(ranked).index(0)  # 0 = index of true item in candidates

            # Metrics
            if rank_of_true < K:
                hits.append(1)
                ndcgs.append(1 / np.log2(rank_of_true + 2))  # rank starts from 0
            else:
                hits.append(0)
                ndcgs.append(0)

            # Save some qualitative examples
            if len(example_outputs) < 5:  # just show 5 users
                topk_idx = ranked[:K]
                topk_items = [candidates[i] for i in topk_idx]
                topk_names = [itemid2name[it] for it in topk_items]

                example_outputs.append({
                    "true_item": itemid2name[true_item],
                    "topk": topk_names
                })

    hit_rate = np.mean(hits)
    ndcg = np.mean(ndcgs)

    print(f"Hit@{K}: {hit_rate:.4f}, NDCG@{K}: {ndcg:.4f}\n")

    print("Sample recommendations:")
    for ex in example_outputs:
        print("True item:", ex["true_item"])
        print("Top-K predictions:", ex["topk"])
        print("---")

    return hit_rate, ndcg

In [9]:
hit, ndcg = evaluate_model(model, train_seqs, test_seqs, n_items, itemid2name, K=10, num_neg=100)

Hit@10: 0.2471, NDCG@10: 0.1168

Sample recommendations:
True item: Pocahontas (1995)
Top-K predictions: ['In Search of the Castaways (1962)', 'Soylent Green (1973)', 'Double Indemnity (1944)', 'American President, The (1995)', 'Dead Again (1991)', 'Above the Rim (1994)', 'Thing, The (1982)', 'Only Angels Have Wings (1939)', 'Farewell My Concubine (1993)', 'Topaz (1969)']
---
True item: Lost World: Jurassic Park, The (1997)
Top-K predictions: ['Bowfinger (1999)', 'Fight Club (1999)', 'Father of the Bride Part II (1995)', 'Austin Powers: The Spy Who Shagged Me (1999)', 'Married to the Mob (1988)', 'Stigmata (1999)', 'Lost World: Jurassic Park, The (1997)', 'Replacement Killers, The (1998)', 'Long Goodbye, The (1973)', 'Dog of Flanders, A (1999)']
---
True item: Little Mermaid, The (1989)
Top-K predictions: ['Little Mermaid, The (1989)', 'Rocky (1976)', 'Inventing the Abbotts (1997)', 'Cell, The (2000)', 'Rear Window (1954)', 'Alvarez Kelly (1966)', 'Bandit Queen (1994)', 'Wisdom (1986)'