# Cross-Domain Recommendation System Development
This notebook is an experiment in building a cross-domain recommendation system using the Amazon Reviews dataset. It uses the best model from the single-domain experiments and extends it to handle multiple domains. The dataset is the same as in the single-domain experiments, but now will combine data from two different domains.

In [1]:
import os
import random
import numpy as np
import pandas as pd
import time
import gc
import matplotlib.pyplot as plt
from collections import defaultdict

os.environ["HF_HOME"] = "D:/Python Projects/recommendation_system"
os.environ["HF_DATASETS_CACHE"] = "D:/Python Projects/recommendation_system/recsys/data"
os.environ["TRANSFORMERS_CACHE"] = "D:/Python Projects/recommendation_system/recsys/models"

# os.environ["HF_HOME"] = "E:/Python Scripts/recsys"
# os.environ['HF_DATASETS_CACHE'] = "E:/Python Scripts/recsys/data"
# os.environ['TRANSFORMERS_CACHE'] = "E:/Python Scripts/recsys/models"

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import LabelEncoder
from datasets import load_dataset, Features, Value
from tqdm import tqdm
from tensorboardX import SummaryWriter

In [2]:
SEED = 42
def set_seed(seed=SEED):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seed(SEED)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("DEVICE:", DEVICE)

DEVICE: cuda


## Single-domain development on best model (SASRec)

In [3]:
HF_DATASET = "McAuley-Lab/Amazon-Reviews-2023"

def load_amazon_reviews(domain, save_dir="data", max_items=None, seed=SEED):
    os.makedirs(save_dir, exist_ok=True)
    filepath = f"{save_dir}/amazon_reviews_{domain}.csv"

    if not os.path.exists(filepath):
        print(f"File {filepath} not found. Downloading dataset for domain '{domain}'...")
        ds = load_dataset(
            "McAuley-Lab/Amazon-Reviews-2023",
            f"raw_review_{domain}",
            split="full",
            trust_remote_code=True,
        )

        # Keep only needed columns
        ds = ds.select_columns(["user_id", "parent_asin", "rating", "timestamp"])
        ds = ds.rename_columns({"user_id": "user", "parent_asin": "item"})
        ds = ds.cast(Features({
            "user": Value("string"),
            "item": Value("string"),
            "rating": Value("float32"),
            "timestamp": Value("int64"),
        }))

        # Convert to pandas (Arrow zero-copy where possible)
        df = ds.to_pandas()
        df.insert(3, "domain", domain)
        df.to_csv(f"{save_dir}/amazon_reviews_{domain}.csv", index=False)
        print(f"Saved amazon_reviews_{domain}.csv to {save_dir}/")

    final_df = pd.read_csv(filepath)
    # Random subset if max_items is set
    if max_items is not None:
        k = min(max_items, len(final_df))
        final_df = final_df.sample(n=k, random_state=seed).reset_index(drop=True)
    print(f"Loaded {filepath} with {len(final_df)} rows.")
    return final_df

def preprocess_dataset(df, min_user_interactions=5, min_item_interactions=5):
    # Make it implicit
    df["label"] = 1.0
    user_counts = df.groupby("user").size()
    valid_users = user_counts[user_counts >= min_user_interactions].index
    item_counts = df.groupby("item").size()
    valid_items = item_counts[item_counts >= min_item_interactions].index
    df_filtered = df[df["user"].isin(valid_users) & df["item"].isin(valid_items)]
    print("After interactions filtering:", len(df_filtered), "rows,", df_filtered["user"].nunique(), "users,", df_filtered["item"].nunique(), "items")
    return df_filtered

def label_encoder(df, shift_item_id=False):
    user_enc = LabelEncoder()
    item_enc = LabelEncoder()
    domain_enc = LabelEncoder()
    df["user_id"] = user_enc.fit_transform(df["user"])
    df["item_id"] = item_enc.fit_transform(df["item"])
    if shift_item_id:
        df["item_id"] = df["item_id"] + 1  # Shift item IDs by 1 to reserve 0 for padding if needed
    df["domain_id"] = domain_enc.fit_transform(df["domain"])
    return df, user_enc, item_enc, domain_enc

### Dataset preparation

In [4]:
# New input
SOURCE_DOMAIN = "Books"

# Loading data from multiple domains
df = load_amazon_reviews(SOURCE_DOMAIN, max_items=10_000_000, seed=SEED)
print(f"Total rows in {SOURCE_DOMAIN}: {len(df)}")

# Preprocess the dataset
filtered_df = preprocess_dataset(df, min_user_interactions=20, min_item_interactions=20)
df_encoded, user_encoder, item_encoder, domain_encoder = label_encoder(filtered_df, shift_item_id=True)

NUM_USERS = df_encoded["user_id"].max() + 1
NUM_ITEMS = df_encoded["item_id"].max() + 1
NUM_DOMAINS = df_encoded["domain_id"].max() + 1
print(f"Number of users: {NUM_USERS}, Number of items: {NUM_ITEMS}, Number of domains: {NUM_DOMAINS}")

Loaded data/amazon_reviews_Books.csv with 10000000 rows.
Total rows in Books: 10000000
After interactions filtering: 318168 rows, 24942 users, 53240 items


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df["user_id"] = user_enc.fit_transform(df["user"])


Number of users: 24942, Number of items: 53241, Number of domains: 1


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df["item_id"] = item_enc.fit_transform(df["item"])
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df["item_id"] = df["item_id"] + 1  # Shift item IDs by 1 to reserve 0 for padding if needed
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df["domain_id"] = domain_enc.fit_transform(df["domain"])


In [5]:
def create_user_sequences(df):
    df_sorted = df.sort_values(["user_id", "timestamp"])
    user_sequences = {}
    for uid, group in df_sorted.groupby("user_id"):
        items = group["item_id"].tolist()
        user_sequences[uid] = items

    print(f"Number of users: {len(user_sequences)}")
    print(f"Max sequence length: {max(len(seq) for seq in user_sequences.values())}")
    print(f"Min sequence length: {min(len(seq) for seq in user_sequences.values())}")

    return user_sequences

# Create sequences
user_sequences = create_user_sequences(df_encoded)
pos_items_by_user = {u: set(seq) for u, seq in user_sequences.items()}

Number of users: 24942
Max sequence length: 415
Min sequence length: 1


In [6]:
def sequences_loo_split(user_sequences):
    train_seqs = {}
    val_data = {}
    test_data = {}

    for user, seq in user_sequences.items():
        if len(seq) < 3:  # Need at least 3 items for train/val/test
            continue

        train_seqs[user] = seq[:-2]  # All but last two
        val_data[user] = (seq[:-2], seq[-2])  # Train on all but last 2, predict second-to-last
        test_data[user] = (seq[:-1], seq[-1])  # Train on all but last, predict last

    print(f"Training sequences: {len(train_seqs)}")
    print(f"Validation users: {len(val_data)}")
    print(f"Test users: {len(test_data)}")

    return train_seqs, val_data, test_data

train_sequences, val_sequences, test_sequences = sequences_loo_split(user_sequences)
print(f"Sequences - Train: {len(train_sequences)}, Val: {len(val_sequences)}, Test: {len(test_sequences)}")

Training sequences: 22235
Validation users: 22235
Test users: 22235
Sequences - Train: 22235, Val: 22235, Test: 22235


### Dataset and DataLoader
SASRec uses sequences of user interactions to predict the next item. So for sequence `[i1, i2, i3, i4]`, the training samples are:
- Input: `[i1]` -> Target: `i2`
- Input: `[i1, i2]` -> Target: `i3`
- Input: `[i1, i2, i3]` -> Target: `i4`

In [7]:
class SASRecDataset(Dataset):
    def __init__(self, data, num_items, max_seq_len=50, pos_items_by_user=None, mode="train", neg_samples=1):
        self.num_items = num_items
        self.max_seq_len = max_seq_len
        self.mode = mode
        self.neg_samples = neg_samples
        self.all_pos = pos_items_by_user

        self.samples = []
        if mode == "train":
            for user, seq in data.items():
                for i in range(1, len(seq)):
                    self.samples.append({
                        "user": user,
                        "input_seq": seq[:i],
                        "target": seq[i],
                        "full_seq": seq # For negative sampling
                    })
        else:
            for user, (seq, target) in data.items():
                self.samples.append({
                    "user": user,
                    "input_seq": seq,
                    "target": target,
                    "full_seq": seq + [target]
                })

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        user = sample["user"]
        seq = sample["input_seq"]
        target = sample["target"]

        # Truncate sequence if > max length
        if len(seq) > self.max_seq_len:
            seq = seq[-self.max_seq_len:]

        # Left-pad sequence with zeros
        pad_len = self.max_seq_len - len(seq)
        padded_seq = [0] * pad_len + seq

        # Negative sampling
        forbid = self.all_pos[user] if self.all_pos is not None else set(sample["full_seq"])
        neg_items = set()

        while len(neg_items) < self.neg_samples:
            neg = random.randint(1, self.num_items - 1)
            if neg not in forbid:
                neg_items.add(neg)

        return {
            "user": sample["user"],
            "input_seq": torch.tensor(padded_seq, dtype=torch.long),
            "target": torch.tensor(target, dtype=torch.long),
            "neg_items": torch.tensor(list(neg_items), dtype=torch.long)
        }

# Create datasets
train_dataset = SASRecDataset(train_sequences, NUM_ITEMS, pos_items_by_user=pos_items_by_user, max_seq_len=50, mode="train", neg_samples=1)
val_dataset = SASRecDataset(val_sequences, NUM_ITEMS, pos_items_by_user=pos_items_by_user, max_seq_len=50, mode="val", neg_samples=99)
test_dataset = SASRecDataset(test_sequences, NUM_ITEMS, pos_items_by_user=pos_items_by_user, max_seq_len=50, mode="test", neg_samples=99)

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
print(f"Test samples: {len(test_dataset)}")

Training samples: 247292
Validation samples: 22235
Test samples: 22235


In [8]:
# Create data loaders
BATCH_SIZE = 4096
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [9]:
first = next(iter(train_loader))
print("Sample batch from loader:")
print("Input sequence shape:", first["input_seq"].shape)
print("Target shape:", first["target"].shape)
print("Negative items shape:", first["neg_items"].shape)

print("\nSample input sequence:")
random_index = []
for _ in range(5):
    random_index.append(random.randint(0, len(train_loader) - 1))

for i in random_index:
    print(first["input_seq"][i])

Sample batch from loader:
Input sequence shape: torch.Size([4096, 50])
Target shape: torch.Size([4096])
Negative items shape: torch.Size([4096, 1])

Sample input sequence:
tensor([    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0, 50395, 50653, 49777, 50954, 50808, 50151])
tensor([    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,  4552,  4315,
        30645, 15719,  7750, 15394,   614, 13048, 14547, 38603,   372,  8043])
tensor([    0,     0,     0,     0,     0,    

### Create SASRec model

In [10]:
# Building SASRec model
class PointWiseFeedForward(nn.Module):
    def __init__(self, hidden_dim, dropout=0.2):
        super().__init__()
        self.w1 = nn.Linear(hidden_dim, hidden_dim)
        self.w2 = nn.Linear(hidden_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.w2(self.dropout(self.relu(self.w1(x))))

class AttentionBlock(nn.Module):
    def __init__(self, hidden_dim, num_heads, dropout=0.2):
        super().__init__()

        # Multi-head attention
        self.attn = nn.MultiheadAttention(hidden_dim, num_heads, dropout=dropout, batch_first=True)

        # Layer norms
        self.ln1 = nn.LayerNorm(hidden_dim)
        self.ln2 = nn.LayerNorm(hidden_dim)

        # Feed-forward network
        self.ffn = PointWiseFeedForward(hidden_dim, dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, attn_mask=None):
        # Self-attention with residual connection
        attn_out, _ = self.attn(x, x, x, attn_mask=attn_mask)
        x = self.ln1(x + self.dropout(attn_out))

        # Feed-forward network with residual connection
        ffn_out = self.ffn(x)
        x = self.ln2(x + self.dropout(ffn_out))

        return x

class SASRec(nn.Module):
    def __init__(self,
                 num_items,
                 hidden_dim=64,
                 max_seq_len=50,
                 num_blocks=2,
                 num_heads=2,
                 dropout=0.2):
        super().__init__()

        self.num_items = num_items
        self.hidden_dim = hidden_dim
        self.max_seq_len = max_seq_len

        # Embedding layers
        self.item_embed = nn.Embedding(num_items, hidden_dim, padding_idx=0)
        self.positional_embed = nn.Embedding(max_seq_len, hidden_dim)
        self.dropout = nn.Dropout(dropout)

        # Stack of SASRec blocks
        self.blocks = nn.ModuleList([
            AttentionBlock(hidden_dim, num_heads, dropout) for _ in range(num_blocks)
        ])

        # Final layer norm
        self.ln = nn.LayerNorm(hidden_dim)

        # Initialize weights
        self._reset_parameters()

    def _reset_parameters(self):
        nn.init.xavier_normal_(self.item_embed.weight[1:])  # Skip padding idx
        nn.init.xavier_normal_(self.positional_embed.weight)

    def forward(self, input_seq, candidate_items=None):
        batch_size, seq_len = input_seq.shape

        # Get item embeddings
        item_embeds = self.item_embed(input_seq)  # [B, L, D]

        # Add positional embeddings
        positions = torch.arange(seq_len, device=input_seq.device).unsqueeze(0)
        pos_embeds = self.positional_embed(positions)  # [1, L, D]
        x = self.dropout(item_embeds + pos_embeds)

        # Create causal attention mask
        attn_mask = self._create_causal_mask(seq_len, input_seq.device)
        pad_mask = input_seq.eq(0)

        # Pass through transformer blocks
        for block in self.blocks:
            x = block(x, attn_mask=attn_mask)

        # Final layer norm
        x = self.ln(x)  # [B, L, D]
        x = x.masked_fill(pad_mask.unsqueeze(-1), 0.0)

        # If candidate_items provided, score them
        if candidate_items is not None:
            # Get embeddings for candidate items
            cand_emb = self.item_embed(candidate_items) # [B, N, D]

            # Use last position's representation for scoring
            last_hidden = x[:, -1, :].unsqueeze(1)  # [B, 1, D]

            # Compute scores via dot product
            scores = torch.matmul(last_hidden, cand_emb.transpose(1, 2)).squeeze(1) # [B, N]
            return scores

        return x

    def _create_causal_mask(self, seq_len, device):
        mask = torch.triu(torch.ones(seq_len, seq_len, device=device, dtype=torch.bool), diagonal=1)
        mask = mask.masked_fill(mask == 1, float("-inf"))
        return mask

    def predict_next(self, input_seq):
        # Get sequence representations
        seq_repr = self.forward(input_seq)  # [B, L, D]

        # Use last position for prediction
        last_hidden = seq_repr[:, -1, :]  # [B, D]

        # Score against all item embeddings
        all_item_embeds = self.item_embed.weight  # [num_items, D]
        scores = torch.matmul(last_hidden, all_item_embeds.T)  # [B, num_items]
        return scores

### Training and evaluation functions

In [11]:
def train_sasrec_epoch(model, train_loader, loss_fn, optimizer, device="cpu"):
    model.train()
    total_loss = 0.0
    n_batches = 0

    for batch in tqdm(train_loader, desc="Training"):
        input_seq = batch["input_seq"].to(device)
        pos_items = batch["target"].to(device)
        neg_items = batch["neg_items"].to(device)

        # Get predictions for last position
        seq_output = model(input_seq)  # [B, L, D]
        last_hidden = seq_output[:, -1, :]  # [B, D]

        # Get embeddings for positive and negative items
        pos_embeds = model.item_embed(pos_items)
        neg_embeds = model.item_embed(neg_items)

        # Compute logits
        pos_logits = (last_hidden * pos_embeds).sum(dim=1)
        neg_logits = torch.bmm(neg_embeds, last_hidden.unsqueeze(-1)).squeeze(-1)

        # Binary cross-entropy loss with logits
        pos_labels = torch.ones_like(pos_logits)
        neg_labels = torch.zeros_like(neg_logits)

        # Concatenate logits and labels
        all_logits = torch.cat([pos_logits.unsqueeze(1), neg_logits], dim=1)
        all_labels = torch.cat([pos_labels.unsqueeze(1), neg_labels], dim=1)

        loss = loss_fn(all_logits, all_labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        n_batches += 1

    return total_loss / n_batches

In [12]:
# Validation loss and ranking metrics
@torch.no_grad()
def evaluate_sasrec(model, eval_loader, loss_fn, k=10, device="cpu"):
    model.eval()
    total = 0
    sum_hr = 0.0
    sum_ndcg = 0.0
    sum_prec = 0.0
    sum_ap = 0.0

    sum_val_loss = 0.0
    n_loss_batches = 0

    for batch in tqdm(eval_loader, desc="Evaluating"):
        input_seq = batch["input_seq"].to(device)
        target = batch["target"].to(device)
        neg_items = batch["neg_items"].to(device)

        batch_size = input_seq.size(0)

        # Create candidate set: 1 positive + negatives
        seq_output = model(input_seq)  # [B, L, D]
        last_hidden = seq_output[:, -1, :]  # [B, D]
        candidates = torch.cat([
            target.unsqueeze(1),  # [B, 1]
            neg_items  # [B, neg_samples]
        ], dim=1)  # [B, 1 + neg_samples]

        # Get embeddings for all candidates
        cand_emb = model.item_embed(candidates)  # [B, 1+neg_samples, D]
        scores = torch.bmm(cand_emb, last_hidden.unsqueeze(-1)).squeeze(-1)  # [B, 1+neg_samples]

        # sanity: positive not in negatives
        if torch.any((candidates[:, 1:] == target.unsqueeze(1)).any(dim=1)):
            raise RuntimeError("Positive item appeared in negatives for some samples.")

        # Loss calculation
        pos_scores = scores[:, 0]
        neg_scores = scores[:, 1:]
        pos_labels = torch.ones_like(scores[:, 0])
        neg_labels = torch.zeros_like(scores[:, 1:])
        all_scores = torch.cat([pos_scores.unsqueeze(1), neg_scores], dim=1)
        all_labels = torch.cat([pos_labels.unsqueeze(1), neg_labels], dim=1)
        batch_loss = loss_fn(all_scores.reshape(-1), all_labels.reshape(-1))
        sum_val_loss += batch_loss.item()
        n_loss_batches += 1

        # Calculate metrics
        _, full_idx = torch.sort(scores, dim=1, descending=True)
        rank = (full_idx == 0).nonzero(as_tuple=True)[1] + 1  # Rank of the positive item (1-based)

        hit = (rank <= k).float()
        ndcg = torch.where(rank <= k, 1.0 / torch.log2(rank.float() + 1), torch.zeros_like(hit))
        precision = hit / float(k)
        ap = torch.where(rank <= k, 1.0 / rank.float(), torch.zeros_like(hit))

        sum_hr += hit.sum().item()
        sum_ndcg += ndcg.sum().item()
        sum_prec += precision.sum().item()
        sum_ap += ap.sum().item()
        total += batch_size

    metrics = {
        "HR@K": sum_hr / total if total else 0.0,
        "NDCG@K": sum_ndcg / total if total else 0.0,
        "Precision@K": sum_prec / total if total else 0.0,
        "MAP@K": sum_ap / total if total else 0.0,
        "Val loss": sum_val_loss / max(n_loss_batches, 1)
    }

    return metrics

In [13]:
def sasrec_trainer(
        model,
        train_loader,
        eval_loader,
        epochs,
        loss_fn,
        optimizer,
        k=10,
        device="cpu",
        save_dir="model"
    ):
    os.makedirs(save_dir, exist_ok=True)
    model.to(device)
    writer = SummaryWriter()

    train_losses, val_losses, val_metrics_log = [], [], []
    best_ndcg, best_epoch = 0.0, 0

    for epoch in range(epochs):
        t0 = time.time()

        # Train (batched)
        train_loss = train_sasrec_epoch(model, train_loader, loss_fn, optimizer, device=device)
        train_losses.append(train_loss)

        # Eval (batched)
        m = evaluate_sasrec(model, eval_loader, loss_fn, k=k, device=device)
        val_losses.append(m["Val loss"])
        val_metrics_log.append({k_: m[k_] for k_ in ["HR@K", "NDCG@K", "Precision@K", "MAP@K"]})

        # Checkpointing by NDCG
        if m["NDCG@K"] > best_ndcg:
            best_ndcg = m["NDCG@K"]
            best_epoch = epoch + 1
            torch.save(model.state_dict(), os.path.join(save_dir, "best_model.pth"))
        torch.save(model.state_dict(), os.path.join(save_dir, "last_model.pth"))

        # TB logs
        writer.add_scalar("Loss/Train", train_loss, epoch)
        writer.add_scalar("Loss/Validation", m["Val loss"], epoch)
        writer.add_scalar(f"Metrics/Val_HR@{k}", m["HR@K"], epoch)
        writer.add_scalar(f"Metrics/Val_NDCG@{k}", m["NDCG@K"], epoch)
        writer.add_scalar(f"Metrics/Val_Precision@{k}", m["Precision@K"], epoch)
        writer.add_scalar(f"Metrics/Val_MAP@{k}", m["MAP@K"], epoch)

        print(
            f"Epoch {epoch+1}/{epochs}  "
            f"Train loss {train_loss:.4f}  "
            f"Val loss {m['Val loss']:.4f}  "
            f"HR@{k} {m['HR@K']:.4f}  "
            f"NDCG@{k} {m['NDCG@K']:.4f}  "
            f"Precision@{k} {m['Precision@K']:.4f}  "
            f"MAP@{k} {m['MAP@K']:.4f}  "
            f"{'(new best)' if m['NDCG@K'] == best_ndcg and best_epoch==epoch+1 else ''}  "
            f"Time {time.time()-t0:.2f}s"
        )

    print("\nTraining Complete.")
    print(f"Best epoch: {best_epoch} with NDCG@{k}: {best_ndcg:.4f}\n")

    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    writer.close()
    return train_losses, val_losses, val_metrics_log, best_ndcg

### Training the model

In [14]:
# Hyperparameters from the original paper, except higher hidden_dim
sasrec = SASRec(
    num_items=NUM_ITEMS,
    hidden_dim=64,
    max_seq_len=50,
    num_blocks=2,
    num_heads=2,
    dropout=0.2
)

loss_fn_sasrec = nn.BCEWithLogitsLoss()
optimizer_sasrec = torch.optim.Adam(sasrec.parameters(), lr=1e-3, weight_decay=1e-6)

train_losses_sasrec, val_losses_sasrec, val_metrics_sasrec, best_ndcg_sasrec = sasrec_trainer(
    model=sasrec,
    train_loader=train_loader,
    eval_loader=val_loader,
    loss_fn=loss_fn_sasrec,
    optimizer=optimizer_sasrec,
    epochs=20,
    k=10,
    device=DEVICE,
    save_dir="model_sasrec"
)

Training: 100%|██████████| 61/61 [00:15<00:00,  3.95it/s]
Evaluating: 100%|██████████| 6/6 [00:02<00:00,  2.03it/s]


Epoch 1/20  Train loss 0.6532  Val loss 0.6672  HR@10 0.2449  NDCG@10 0.1296  Precision@10 0.0245  MAP@10 0.0947  (new best)  Time 18.44s


Training: 100%|██████████| 61/61 [00:15<00:00,  4.05it/s]
Evaluating: 100%|██████████| 6/6 [00:02<00:00,  2.57it/s]


Epoch 2/20  Train loss 0.5963  Val loss 0.6330  HR@10 0.2468  NDCG@10 0.1299  Precision@10 0.0247  MAP@10 0.0946  (new best)  Time 17.46s


Training: 100%|██████████| 61/61 [00:14<00:00,  4.18it/s]
Evaluating: 100%|██████████| 6/6 [00:02<00:00,  2.60it/s]


Epoch 3/20  Train loss 0.5821  Val loss 0.6206  HR@10 0.2503  NDCG@10 0.1310  Precision@10 0.0250  MAP@10 0.0951  (new best)  Time 16.94s


Training: 100%|██████████| 61/61 [00:14<00:00,  4.26it/s]
Evaluating: 100%|██████████| 6/6 [00:02<00:00,  2.29it/s]


Epoch 4/20  Train loss 0.5661  Val loss 0.6202  HR@10 0.2708  NDCG@10 0.1425  Precision@10 0.0271  MAP@10 0.1039  (new best)  Time 16.99s


Training: 100%|██████████| 61/61 [00:14<00:00,  4.31it/s]
Evaluating: 100%|██████████| 6/6 [00:02<00:00,  2.58it/s]


Epoch 5/20  Train loss 0.5463  Val loss 0.6167  HR@10 0.2918  NDCG@10 0.1547  Precision@10 0.0292  MAP@10 0.1134  (new best)  Time 16.51s


Training: 100%|██████████| 61/61 [00:14<00:00,  4.15it/s]
Evaluating: 100%|██████████| 6/6 [00:02<00:00,  2.55it/s]


Epoch 6/20  Train loss 0.5250  Val loss 0.6141  HR@10 0.3095  NDCG@10 0.1655  Precision@10 0.0310  MAP@10 0.1220  (new best)  Time 17.11s


Training: 100%|██████████| 61/61 [00:14<00:00,  4.29it/s]
Evaluating: 100%|██████████| 6/6 [00:02<00:00,  2.26it/s]


Epoch 7/20  Train loss 0.5040  Val loss 0.6228  HR@10 0.3295  NDCG@10 0.1761  Precision@10 0.0329  MAP@10 0.1297  (new best)  Time 16.92s


Training: 100%|██████████| 61/61 [00:14<00:00,  4.33it/s]
Evaluating: 100%|██████████| 6/6 [00:02<00:00,  2.55it/s]


Epoch 8/20  Train loss 0.4799  Val loss 0.6014  HR@10 0.3414  NDCG@10 0.1845  Precision@10 0.0341  MAP@10 0.1370  (new best)  Time 16.50s


Training: 100%|██████████| 61/61 [00:14<00:00,  4.22it/s]
Evaluating: 100%|██████████| 6/6 [00:02<00:00,  2.52it/s]


Epoch 9/20  Train loss 0.4539  Val loss 0.5934  HR@10 0.3647  NDCG@10 0.1968  Precision@10 0.0365  MAP@10 0.1460  (new best)  Time 16.90s


Training: 100%|██████████| 61/61 [00:14<00:00,  4.22it/s]
Evaluating: 100%|██████████| 6/6 [00:02<00:00,  2.59it/s]


Epoch 10/20  Train loss 0.4268  Val loss 0.5658  HR@10 0.3789  NDCG@10 0.2055  Precision@10 0.0379  MAP@10 0.1530  (new best)  Time 16.80s


Training: 100%|██████████| 61/61 [00:14<00:00,  4.32it/s]
Evaluating: 100%|██████████| 6/6 [00:02<00:00,  2.30it/s]


Epoch 11/20  Train loss 0.4001  Val loss 0.5340  HR@10 0.3943  NDCG@10 0.2143  Precision@10 0.0394  MAP@10 0.1598  (new best)  Time 16.76s


Training: 100%|██████████| 61/61 [00:14<00:00,  4.31it/s]
Evaluating: 100%|██████████| 6/6 [00:02<00:00,  2.58it/s]


Epoch 12/20  Train loss 0.3736  Val loss 0.4992  HR@10 0.4063  NDCG@10 0.2248  Precision@10 0.0406  MAP@10 0.1697  (new best)  Time 16.51s


Training: 100%|██████████| 61/61 [00:14<00:00,  4.25it/s]
Evaluating: 100%|██████████| 6/6 [00:02<00:00,  2.57it/s]


Epoch 13/20  Train loss 0.3509  Val loss 0.4764  HR@10 0.4192  NDCG@10 0.2345  Precision@10 0.0419  MAP@10 0.1784  (new best)  Time 16.74s


Training: 100%|██████████| 61/61 [00:14<00:00,  4.24it/s]
Evaluating: 100%|██████████| 6/6 [00:02<00:00,  2.21it/s]


Epoch 14/20  Train loss 0.3264  Val loss 0.4420  HR@10 0.4334  NDCG@10 0.2434  Precision@10 0.0433  MAP@10 0.1854  (new best)  Time 17.13s


Training: 100%|██████████| 61/61 [00:14<00:00,  4.31it/s]
Evaluating: 100%|██████████| 6/6 [00:02<00:00,  2.57it/s]


Epoch 15/20  Train loss 0.3046  Val loss 0.4137  HR@10 0.4448  NDCG@10 0.2524  Precision@10 0.0445  MAP@10 0.1938  (new best)  Time 16.53s


Training: 100%|██████████| 61/61 [00:14<00:00,  4.10it/s]
Evaluating: 100%|██████████| 6/6 [00:02<00:00,  2.44it/s]


Epoch 16/20  Train loss 0.2843  Val loss 0.3919  HR@10 0.4545  NDCG@10 0.2622  Precision@10 0.0454  MAP@10 0.2035  (new best)  Time 17.40s


Training: 100%|██████████| 61/61 [00:15<00:00,  3.93it/s]
Evaluating: 100%|██████████| 6/6 [00:02<00:00,  2.50it/s]


Epoch 17/20  Train loss 0.2652  Val loss 0.3607  HR@10 0.4621  NDCG@10 0.2701  Precision@10 0.0462  MAP@10 0.2113  (new best)  Time 17.95s


Training: 100%|██████████| 61/61 [00:14<00:00,  4.22it/s]
Evaluating: 100%|██████████| 6/6 [00:02<00:00,  2.53it/s]


Epoch 18/20  Train loss 0.2468  Val loss 0.3422  HR@10 0.4682  NDCG@10 0.2771  Precision@10 0.0468  MAP@10 0.2185  (new best)  Time 16.87s


Training: 100%|██████████| 61/61 [00:14<00:00,  4.16it/s]
Evaluating: 100%|██████████| 6/6 [00:02<00:00,  2.52it/s]


Epoch 19/20  Train loss 0.2307  Val loss 0.3182  HR@10 0.4748  NDCG@10 0.2830  Precision@10 0.0475  MAP@10 0.2240  (new best)  Time 17.10s


Training: 100%|██████████| 61/61 [00:14<00:00,  4.21it/s]
Evaluating: 100%|██████████| 6/6 [00:02<00:00,  2.54it/s]


Epoch 20/20  Train loss 0.2152  Val loss 0.2969  HR@10 0.4830  NDCG@10 0.2906  Precision@10 0.0483  MAP@10 0.2315  (new best)  Time 16.89s

Training Complete.
Best epoch: 20 with NDCG@10: 0.2906



## Cross-domain development

In [16]:
# Load trained model on source domain
def load_best_weights(model, ckpt_path="model/best_model.pth", device=None):
    if device is None:
        device = next(model.parameters()).device
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
    state = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(state)
    model.to(device)
    model.eval()
    return model

best_model = load_best_weights(sasrec, ckpt_path="model_sasrec/best_model.pth", device=DEVICE)

### Align users across domains

In [28]:
@torch.no_grad()
def compute_user_reprs_from_sequences(model_src, train_seqs_src, user_encoder_src, max_seq_len=50, device=DEVICE):
    model_src.eval().to(device)
    user_vecs = {}

    for user_id, seq in train_seqs_src.items():
        if len(seq) < 1:
            continue

        # Pad-left to max_seq_len
        seq = seq[-max_seq_len:]
        pad_len = max_seq_len - len(seq)
        input_seq = torch.tensor([([0] * pad_len + seq)], dtype=torch.long, device=device)
        hidden = model_src(input_seq)
        last_hidden = hidden[0, -1, :].squeeze(0)
        raw_user = user_encoder_src.inverse_transform([user_id])[0]
        user_vecs[raw_user] = last_hidden.detach().cpu().numpy()

    print(f"\nComputed user representations for {len(user_vecs)} users.")
    return user_vecs

In [29]:
# Cross-domain evaluation on target domain
SOURCE_DOMAIN = "Books"
TARGET_DOMAIN = "Movies_and_TV"
ALL_DOMAIN = [SOURCE_DOMAIN, TARGET_DOMAIN]

# Load data from target domain
df_target = load_amazon_reviews(TARGET_DOMAIN, max_items=10_000_000, seed=SEED)
filtered_df_target = preprocess_dataset(df_target, min_user_interactions=20, min_item_interactions=20)
df_target_encoded, user_encoder_tgt, item_encoder_tgt, domain_encoder_tgt = label_encoder(filtered_df_target, shift_item_id=True)

NUM_USERS_TGT = df_target_encoded["user_id"].max() + 1
NUM_ITEMS_TGT = df_target_encoded["item_id"].max() + 1

# Rebuild sequences for target domain and split
user_sequences_tgt = create_user_sequences(df_target_encoded)
pos_items_by_user_tgt = {u: set(seq) for u, seq in user_sequences_tgt.items()}
train_sequences_tgt, val_sequences_tgt, test_sequences_tgt = sequences_loo_split(user_sequences_tgt)

# Build source user vectors from the trained source model
user_vecs_src = compute_user_reprs_from_sequences(
    model_src=sasrec,
    train_seqs_src=train_sequences,
    user_encoder_src=user_encoder,
    max_seq_len=50,
    device=DEVICE
)

# Create an aligned matrix of source vectors in target's user_id space
embed_dim = 64
transfer_src_mat = np.zeros((NUM_USERS_TGT, embed_dim), dtype=np.float32)
for raw_user, vec in user_vecs_src.items():
    if raw_user in user_encoder_tgt.classes_:
        uid_target = user_encoder_tgt.transform([raw_user])[0]
        transfer_src_mat[uid_target] = vec # give source user vector to target user_id (shared users)

transfer_src_mat = torch.tensor(transfer_src_mat)  # [U_T, D]
print("\n")
print(transfer_src_mat)

Loaded data/amazon_reviews_Movies_and_TV.csv with 10000000 rows.
After interactions filtering: 838690 rows, 29948 users, 73928 items


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df["user_id"] = user_enc.fit_transform(df["user"])
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df["item_id"] = item_enc.fit_transform(df["item"])
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df["item_id"] = df["item_id"] + 1  # Shift item IDs by 1 to reserve 0 for padding if needed
A value is 

Number of users: 29948
Max sequence length: 1123
Min sequence length: 1
Training sequences: 29819
Validation users: 29819
Test users: 29819

Computed user representations for 22235 users.


tensor([[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 2.7717, -7.0042, -1.5810,  ..., -5.7022, -1.2490, -2.1327]])


### Dataset and DataLoader for cross-domain

In [30]:
# Additional dataset changes for cross-domain
class SASRecDatasetCD(SASRecDataset):
    def __init__(self, data, num_items, transfer_src_mat, max_seq_len=50, mode="train", neg_samples=1):
        super().__init__(data, num_items, max_seq_len=max_seq_len, mode=mode, neg_samples=neg_samples)
        self.transfer_src_mat = transfer_src_mat

    def __getitem__(self, idx):
        out = super().__getitem__(idx)
        user_id = out["user"]
        out["transfer_src"] = self.transfer_src_mat[user_id].float()
        return out

In [35]:
## Target datasets & loaders (reuse your batch size)
train_dataset_tgt = SASRecDatasetCD(train_sequences_tgt, NUM_ITEMS_TGT, transfer_src_mat, max_seq_len=50, mode="train", neg_samples=1)
val_dataset_tgt = SASRecDatasetCD(val_sequences_tgt, NUM_ITEMS_TGT, transfer_src_mat, max_seq_len=50, mode="val", neg_samples=99)
test_dataset_tgt = SASRecDatasetCD(test_sequences_tgt, NUM_ITEMS_TGT, transfer_src_mat, max_seq_len=50, mode="test", neg_samples=99)

train_loader_tgt = DataLoader(train_dataset_tgt, batch_size=BATCH_SIZE, shuffle=True)
val_loader_tgt   = DataLoader(val_dataset_tgt,   batch_size=BATCH_SIZE, shuffle=False)
test_loader_tgt  = DataLoader(test_dataset_tgt,  batch_size=BATCH_SIZE, shuffle=False)

### Cross-domain SASRec model
This technique is inspired by the paper [Personalized Transfer of User Preferences for Cross-domain Recommendation (2021)](https://arxiv.org/abs/2110.11154).

In [36]:
class SASRecCD(nn.Module):
    def __init__(self, base_sasrec, hidden_dim=64, bridge_hidden=128, dropout=0.1):
        super().__init__()
        self.base = base_sasrec
        self.bridge = nn.Sequential(
            nn.Linear(hidden_dim, bridge_hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(bridge_hidden, bridge_hidden),
            nn.ReLU(),
            nn.Dropout(dropout)
        )
        self.gate = nn.Linear(hidden_dim * 2, hidden_dim)

    def forward(self, input_seq, transfer_src=None, candidate_items=None):
        seq_output = self.base(input_seq)
        last_hidden = seq_output[:, -1, :]

        if transfer_src is not None:
            bridge_out = self.bridge(transfer_src)
            combined = torch.cat([last_hidden, bridge_out], dim=-1)
            gate = torch.sigmoid(self.gate(combined))
            fused = gate * last_hidden + (1.0 - gate) * bridge_out
        else:
            fused = last_hidden

        if candidate_items is not None:
            cand_emb = self.base.item_embed(candidate_items)
            scores = torch.bmm(cand_emb, fused.unsqueeze(-1)).squeeze(-1)
            return scores

        return fused

### Training and evaluation functions for cross-domain

In [37]:
def train_epoch_transfer(model, loader, loss_fn, optimizer, device="cpu"):
    model.train()
    total, n = 0.0, 0
    for batch in tqdm(loader, desc="Training"):
        inp = batch["input_seq"].to(device)
        pos = batch["target"].to(device)
        neg = batch["neg_items"].to(device)
        transfer = batch["transfer_src"].to(device)

        # fused representation
        fused = model(inp, transfer_src=transfer)
        pos_emb = model.base.item_embed(pos)
        neg_emb = model.base.item_embed(neg)

        pos_logits = (fused * pos_emb).sum(dim=1)
        neg_logits = torch.bmm(neg_emb, fused.unsqueeze(-1)).squeeze(-1)

        all_logits = torch.cat([pos_logits.unsqueeze(1), neg_logits], 1)
        all_labels = torch.cat([torch.ones_like(pos_logits).unsqueeze(1),
                                torch.zeros_like(neg_logits)], 1)

        loss = loss_fn(all_logits.reshape(-1), all_labels.reshape(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total += loss.item(); n += 1
    return total / n

In [38]:
@torch.no_grad()
def evaluate_transfer(model, loader, loss_fn, k=10, device="cpu"):
    model.eval()
    total = hits = ndcgs = precs = mrrs = 0.0
    loss_sum, nb = 0.0, 0

    for batch in tqdm(loader, desc="Evaluating"):
        inp = batch["input_seq"].to(device)
        tgt = batch["target"].to(device)
        neg = batch["neg_items"].to(device)
        transfer = batch["transfer_src"].to(device)

        fused = model(inp, transfer_src=transfer)
        cand = torch.cat([tgt.unsqueeze(1), neg], dim=1)
        cand_emb = model.base.item_embed(cand)
        scores = torch.bmm(cand_emb, fused.unsqueeze(-1)).squeeze(-1)

        # loss (same as train for parity)
        labels = torch.cat([torch.ones_like(scores[:, :1]),
                            torch.zeros_like(scores[:, 1:])], dim=1)
        batch_loss = loss_fn(scores.reshape(-1), labels.reshape(-1))
        loss_sum += batch_loss.item(); nb += 1

        # ranks & metrics
        _, idx = torch.sort(scores, dim=1, descending=True)
        rank = (idx == 0).nonzero(as_tuple=True)[1] + 1  # 1-based
        hit = (rank <= k).float()
        ndcg = torch.where(rank <= k, 1.0 / torch.log2(rank.float() + 1), torch.zeros_like(hit))
        precision = hit / float(k)
        mrr = 1.0 / rank.float()

        B = inp.size(0)
        hits += hit.sum().item()
        ndcgs += ndcg.sum().item()
        precs += precision.sum().item()
        mrrs += mrr.sum().item()
        total += B

    return {
        "HR@K": hits / total,
        "NDCG@K": ndcgs / total,
        "Precision@K": precs / total,
        "MRR": mrrs / total,
        "Val loss": loss_sum / max(nb, 1)
    }

In [42]:
# Trainer (target domain)
def train_target_with_transfer(model, train_loader, val_loader, epochs, lr=1e-3, wd=1e-6, k=10, device="cpu"):
    model.to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
    loss_fn = nn.BCEWithLogitsLoss()
    best_ndcg, best_epoch = 0.0, 0

    for epoch in range(epochs):
        train = train_epoch_transfer(model, train_loader, loss_fn, opt, device=device)
        eval = evaluate_transfer(model, val_loader, loss_fn, k=k, device=device)

        if eval["NDCG@K"] > best_ndcg:
            best_ndcg, best_epoch = eval["NDCG@K"], epoch+1
            torch.save(model.state_dict(), "xfer_best.pth")

        print(f"Epoch {epoch+1}/{epochs}  "
              f"Train {train:.4f}  "
              f"Val {eval['Val loss']:.4f}  "
              f"HR@{k} {eval['HR@K']:.4f}  "
              f"NDCG@{k} {eval['NDCG@K']:.4f}  "
              f"Prec@{k} {eval['Precision@K']:.4f}  "
              f"MRR {eval['MRR']:.4f}  "
              f"{'(new best)' if eval['NDCG@K']==best_ndcg and best_epoch==epoch+1 else ''}")

    print(f"\nBest epoch {best_epoch} NDCG@{k}={best_ndcg:.4f}")
    return best_ndcg

In [43]:
sasrec_base = load_best_weights(sasrec, ckpt_path="model_sasrec/best_model.pth", device=DEVICE)
xfer_model = SASRecCD(sasrec_base, hidden_dim=64, bridge_hidden=128, dropout=0.2)

best_ndcg_tgt = train_target_with_transfer(
    SASRecCD, train_loader_tgt, val_loader_tgt, epochs=20, lr=1e-3, wd=1e-6, k=10, device=DEVICE
)

AttributeError: 'str' object has no attribute '_apply'