# BERT4Rec-based recommendation system

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
cd /content/drive/MyDrive/Course/CS247 - Advanced Data Mining/Final Project/CS247-Project/

/content/drive/MyDrive/Course/CS247 - Advanced Data Mining/Final Project/CS247-Project


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import BertModel, BertConfig
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split

# Ensure compatibility with Jupyter Notebook
%matplotlib inline

# Check if CUDA is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [None]:
# Function to load the MovieLens dataset
def load_data(filepath="ml-1m/ratings.dat"):
    df = pd.read_csv(filepath, sep="::", engine="python",
                     names=["userId", "movieId", "rating", "timestamp"])
    df = df.sort_values(by=["userId", "timestamp"])  # Sort by user and timestamp
    user_movie_dict = df.groupby("userId")["movieId"].apply(list).to_dict()
    return user_movie_dict

# Load dataset
user_movie_dict = load_data()
print(f"Loaded {len(user_movie_dict)} users' movie interaction sequences")

# Print a sample of user-movie interactions for debugging
for user, movies in list(user_movie_dict.items())[:3]:
    print(f"User {user}: {movies}")

Loaded 6040 users' movie interaction sequences
User 1: [3186, 1270, 1721, 1022, 2340, 1836, 3408, 2804, 1207, 1193, 720, 260, 919, 608, 2692, 1961, 2028, 3105, 938, 1035, 1962, 2018, 150, 1028, 1097, 914, 1287, 2797, 2762, 1246, 661, 2918, 531, 3114, 2791, 2321, 1029, 1197, 594, 2398, 1545, 527, 595, 2687, 745, 588, 1, 2355, 2294, 783, 1566, 1907, 48]
User 2: [1198, 1210, 1217, 2717, 1293, 2943, 1225, 1193, 318, 3030, 2858, 1213, 1945, 1207, 593, 3095, 3468, 1873, 515, 1090, 2501, 3035, 110, 2067, 3147, 1247, 3105, 1357, 1196, 1957, 1953, 920, 1834, 1084, 1962, 3471, 3654, 3735, 1259, 1954, 1784, 2728, 1968, 1103, 902, 3451, 3578, 2852, 3334, 3068, 265, 2312, 590, 1253, 3071, 1244, 3699, 1955, 1245, 2236, 3678, 982, 2194, 2268, 1442, 3255, 647, 235, 1096, 1124, 498, 1246, 3893, 1537, 1188, 2396, 2359, 2321, 356, 3108, 1265, 3809, 589, 2028, 2571, 457, 2916, 1610, 480, 163, 380, 3418, 3256, 1408, 21, 349, 1527, 2353, 2006, 2278, 1370, 648, 2427, 1792, 1372, 1552, 2490, 1385, 780, 2881, 

In [None]:
# Function to load the movie lens dataset and check the data as pandas dataframe
def load_data(filepath="ml-1m/ratings.dat"):
    df = pd.read_csv(filepath, sep="::", engine="python",
                     names=["userId", "movieId", "rating", "timestamp"])
    df = df.sort_values(by=["userId", "timestamp"])  # Sort by user and timestamp
    return df

# Load dataset
ratings = load_data()
ratings.head()


Unnamed: 0,userId,movieId,rating,timestamp
31,1,3186,4,978300019
22,1,1270,5,978300055
27,1,1721,4,978300055
37,1,1022,5,978300055
24,1,2340,3,978300103


In [None]:
# Check null
ratings.info()


<class 'pandas.core.frame.DataFrame'>
Index: 1000209 entries, 31 to 1000042
Data columns (total 4 columns):
 #   Column     Non-Null Count    Dtype
---  ------     --------------    -----
 0   userId     1000209 non-null  int64
 1   movieId    1000209 non-null  int64
 2   rating     1000209 non-null  int64
 3   timestamp  1000209 non-null  int64
dtypes: int64(4)
memory usage: 38.2 MB


In [None]:
# Function to load the movie lens dataset and check the data as pandas dataframe
def load_data(filepath="ml-1m/users.dat"):
    df = pd.read_csv(filepath, sep="::", engine="python",
                     names=["userId","gender","age","occupation","zipCode"])
    df = df.sort_values(by=["userId"])  # Sort by user
    return df

# Load dataset
users = load_data()
users.head()

Unnamed: 0,userId,gender,age,occupation,zipCode
0,1,F,1,10,48067
1,2,M,56,16,70072
2,3,M,25,15,55117
3,4,M,45,7,2460
4,5,M,25,20,55455


In [None]:
# Replace gender with 0 and 1 for Female and Male
users.gender = users.gender.astype("category").cat.codes

users.head()

Unnamed: 0,userId,gender,age,occupation,zipCode
0,1,0,1,10,48067
1,2,1,56,16,70072
2,3,1,25,15,55117
3,4,1,45,7,2460
4,5,1,25,20,55455


In [None]:
# Check the columns
users.info()


<class 'pandas.core.frame.DataFrame'>
RangeIndex: 6040 entries, 0 to 6039
Data columns (total 5 columns):
 #   Column      Non-Null Count  Dtype 
---  ------      --------------  ----- 
 0   userId      6040 non-null   int64 
 1   gender      6040 non-null   int8  
 2   age         6040 non-null   int64 
 3   occupation  6040 non-null   int64 
 4   zipCode     6040 non-null   object
dtypes: int64(3), int8(1), object(1)
memory usage: 194.8+ KB


In [None]:
def load_data(filepath="ml-1m/movies.dat"):
    df = pd.read_csv(filepath, sep="::", engine="python",
                     names=["movieId","title","genres"], encoding="ISO-8859-1")
    df = df.sort_values(by=["movieId"])  # Sort by movieId
    return df

# Load dataset
movies = load_data()
movies.head()

Unnamed: 0,movieId,title,genres
0,1,Toy Story (1995),Animation|Children's|Comedy
1,2,Jumanji (1995),Adventure|Children's|Fantasy
2,3,Grumpier Old Men (1995),Comedy|Romance
3,4,Waiting to Exhale (1995),Comedy|Drama
4,5,Father of the Bride Part II (1995),Comedy


In [None]:
# Seperate the title column into title and year
def separate_title_year(title):
    # Extract the year from the title
    year = title[-5:-1]
    # Remove the year from the title
    title = title[:-7]
    return title, year


# Apply the function to the title column
movies["title"], movies["year"] = zip(*movies["title"].apply(separate_title_year))
movies.head()

Unnamed: 0,movieId,title,genres,year
0,1,Toy Story,Animation|Children's|Comedy,1995
1,2,Jumanji,Adventure|Children's|Fantasy,1995
2,3,Grumpier Old Men,Comedy|Romance,1995
3,4,Waiting to Exhale,Comedy|Drama,1995
4,5,Father of the Bride Part II,Comedy,1995


In [None]:
# Check the columns
movies.info()


<class 'pandas.core.frame.DataFrame'>
RangeIndex: 3883 entries, 0 to 3882
Data columns (total 4 columns):
 #   Column   Non-Null Count  Dtype 
---  ------   --------------  ----- 
 0   movieId  3883 non-null   int64 
 1   title    3883 non-null   object
 2   genres   3883 non-null   object
 3   year     3883 non-null   object
dtypes: int64(1), object(3)
memory usage: 121.5+ KB


In [None]:
# Replace | with a space
movies["genres"] = movies["genres"].str.replace("|", ", ")
movies.head()

Unnamed: 0,movieId,title,genres,year
0,1,Toy Story,"Animation, Children's, Comedy",1995
1,2,Jumanji,"Adventure, Children's, Fantasy",1995
2,3,Grumpier Old Men,"Comedy, Romance",1995
3,4,Waiting to Exhale,"Comedy, Drama",1995
4,5,Father of the Bride Part II,Comedy,1995


In [None]:
# Comcat title and genres
movies["title_genres"] = movies["title"] + ", " + movies["genres"]
movies.head()


Unnamed: 0,movieId,title,genres,year,title_genres
0,1,Toy Story,"Animation, Children's, Comedy",1995,"Toy Story, Animation, Children's, Comedy"
1,2,Jumanji,"Adventure, Children's, Fantasy",1995,"Jumanji, Adventure, Children's, Fantasy"
2,3,Grumpier Old Men,"Comedy, Romance",1995,"Grumpier Old Men, Comedy, Romance"
3,4,Waiting to Exhale,"Comedy, Drama",1995,"Waiting to Exhale, Comedy, Drama"
4,5,Father of the Bride Part II,Comedy,1995,"Father of the Bride Part II, Comedy"


In [None]:
# Merge the three dataframes
data = ratings.merge(users, on="userId").merge(movies, on="movieId")
data.head()

Unnamed: 0,userId,movieId,rating,timestamp,gender,age,occupation,zipCode,title,genres,year,title_genres
0,1,3186,4,978300019,0,1,10,48067,"Girl, Interrupted",Drama,1999,"Girl, Interrupted, Drama"
1,1,1270,5,978300055,0,1,10,48067,Back to the Future,"Comedy, Sci-Fi",1985,"Back to the Future, Comedy, Sci-Fi"
2,1,1721,4,978300055,0,1,10,48067,Titanic,"Drama, Romance",1997,"Titanic, Drama, Romance"
3,1,1022,5,978300055,0,1,10,48067,Cinderella,"Animation, Children's, Musical",1950,"Cinderella, Animation, Children's, Musical"
4,1,2340,3,978300103,0,1,10,48067,Meet Joe Black,Romance,1998,"Meet Joe Black, Romance"


In [None]:
# Copy all of the code above to create a dataloader function
def load_data(filepath="ml-1m"):
    '''
    This function loads the MovieLens 1M dataset and returns a pandas dataframe
    including the ratings, users, movies etc.
    '''


    # Load the data
    ratings = pd.read_csv(f"{filepath}/ratings.dat", sep="::", engine="python",
                          names=["userId", "movieId", "rating", "timestamp"])
    users = pd.read_csv(f"{filepath}/users.dat", sep="::", engine="python",
                        names=["userId","gender","age","occupation","zipCode"])
    movies = pd.read_csv(f"{filepath}/movies.dat", sep="::", engine="python",
                         names=["movieId","title","genres"], encoding="ISO-8859-1")

    # Sort ratings by user and timestamp
    ratings = ratings.sort_values(by=["userId", "timestamp"])

    # Replace gender with 0 and 1 for Female and Male
    users.gender = users.gender.astype("category").cat.codes

    # Clean the genres column, it separates the genres by "|"
    def clean_genres(genres):
        return genres.split("|")

    # Apply the function to the genres column
    movies["genres"] = movies["genres"].apply(clean_genres)

    # Seperate the title column into title and year
    def separate_title_year(title):
        # Extract the year from the title
        year = title[-5:-1]
        # Remove the year from the title
        title = title[:-7]
        return title, year


    # Apply the function to the title column
    movies["title"], movies["year"] = zip(*movies["title"].apply(separate_title_year))

    # Get all the unique genres
    unique_genres = set()

    for genres in movies["genres"]:
        unique_genres.update(genres)

    # Create a code for each genre and store it in a dictionary
    genre_to_code = {genre: code for code, genre in enumerate(unique_genres)}
    code_to_genre = {code: genre for genre, code in genre_to_code.items()}

    # Add a column for the genre codes
    movies["genre_codes"] = movies["genres"].apply(lambda x: [genre_to_code[genre] for genre in x])

    # Merge the three dataframes
    data = ratings.merge(users, on="userId").merge(movies, on="movieId")

    return data, genre_to_code, code_to_genre

In [None]:
# Load the data
data, genre_to_code, code_to_genre = load_data()
data.head()

Unnamed: 0,userId,movieId,rating,timestamp,gender,age,occupation,zipCode,title,genres,year,genre_codes
0,1,3186,4,978300019,0,1,10,48067,"Girl, Interrupted",[Drama],1999,[7]
1,1,1270,5,978300055,0,1,10,48067,Back to the Future,"[Comedy, Sci-Fi]",1985,"[0, 10]"
2,1,1721,4,978300055,0,1,10,48067,Titanic,"[Drama, Romance]",1997,"[7, 3]"
3,1,1022,5,978300055,0,1,10,48067,Cinderella,"[Animation, Children's, Musical]",1950,"[2, 14, 1]"
4,1,2340,3,978300103,0,1,10,48067,Meet Joe Black,[Romance],1998,[3]


In [None]:
from sentence_transformers import SentenceTransformer

# 1. Load a pretrained Sentence Transformer model
model = SentenceTransformer("all-MiniLM-L6-v2")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/10.7k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

In [None]:
# The sentences to encode
sentences = [
    "The weather is lovely today.",
    "It's so sunny outside!",
    "He drove to the stadium.",
]

# 2. Calculate embeddings by calling model.encode()
embeddings = model.encode(sentences)

embeddings.shape


(3, 384)

In [None]:
# Extract only userid, movieid, and title_genre column from data
data_prep = data[['userId', 'movieId', 'timestamp']]

data_prep.head()

Unnamed: 0,userId,movieId,timestamp
0,1,3186,978300019
1,1,1270,978300055
2,1,1721,978300055
3,1,1022,978300055
4,1,2340,978300103


In [None]:
# Use the title_genres column in the movies dataframe to embed the title_genres column in the data dataframe
movies["title_genres_embedding"] = movies["title_genres"].apply(lambda x: model.encode(x))

movies.head()

Unnamed: 0,movieId,title,genres,year,title_genres,title_genres_embedding
0,1,Toy Story,"Animation, Children's, Comedy",1995,"Toy Story, Animation, Children's, Comedy","[-0.042061232, -0.05036443, 0.05309303, 0.0070..."
1,2,Jumanji,"Adventure, Children's, Fantasy",1995,"Jumanji, Adventure, Children's, Fantasy","[0.0032210948, 0.076357044, 0.037274584, -0.01..."
2,3,Grumpier Old Men,"Comedy, Romance",1995,"Grumpier Old Men, Comedy, Romance","[-0.021038532, -0.076374575, 0.0022655362, -0...."
3,4,Waiting to Exhale,"Comedy, Drama",1995,"Waiting to Exhale, Comedy, Drama","[-0.012850017, -0.11153922, -0.011487518, 0.02..."
4,5,Father of the Bride Part II,Comedy,1995,"Father of the Bride Part II, Comedy","[-0.055391766, -0.006884501, -0.029818444, 0.0..."


In [None]:
# Take movieId and title_genres_embedding
movies_prep = movies[['movieId','title_genres_embedding']]

# Merge data_prep and movie_prep using movieId
data_merged = data_prep.merge(movies_prep, on='movieId')


data_merged.head()

Unnamed: 0,userId,movieId,timestamp,title_genres_embedding
0,1,3186,978300019,"[-0.0016497748, -0.04891429, 0.022234766, -0.0..."
1,1,1270,978300055,"[-0.06946387, -0.091070786, -0.0312587, 0.0067..."
2,1,1721,978300055,"[-0.014565197, -0.10387789, 0.047539853, 0.082..."
3,1,1022,978300055,"[0.016511308, -0.018827124, 0.049624704, 0.022..."
4,1,2340,978300103,"[-0.05233836, -0.031927433, -0.039634332, 0.09..."


In [None]:
# Sort data_merged by userId and timestamp
data_merged = data_merged.sort_values(['userId', 'timestamp'])

data_merged.head()


Unnamed: 0,userId,movieId,timestamp,title_genres_embedding
0,1,3186,978300019,"[-0.0016497748, -0.04891429, 0.022234766, -0.0..."
1,1,1270,978300055,"[-0.06946387, -0.091070786, -0.0312587, 0.0067..."
2,1,1721,978300055,"[-0.014565197, -0.10387789, 0.047539853, 0.082..."
3,1,1022,978300055,"[0.016511308, -0.018827124, 0.049624704, 0.022..."
4,1,2340,978300103,"[-0.05233836, -0.031927433, -0.039634332, 0.09..."


In [None]:
# Create a dictionary to store the user-movie interaction sequences
user_movie_embedding_dict = data_merged.groupby('userId')['title_genres_embedding'].apply(list).to_dict()

In [None]:
user_movie_dict = data_merged.groupby('userId')['movieId'].apply(list).to_dict()


In [None]:
# Function to split user interactions into train and test sets
def split_train_test(user_movie_dict, test_ratio=0.2, min_interactions=5):
    train_dict, test_dict = {}, {}

    for user, movies in user_movie_dict.items():
        if len(movies) >= min_interactions:  # Only split users with enough data
            split_idx = int(len(movies) * (1 - test_ratio))
            train_dict[user] = movies[:split_idx]
            test_dict[user] = movies[split_idx:]
        else:
            train_dict[user] = movies  # Assign all to train if only a few interactions

    return train_dict, test_dict

# Apply train-test split with filtering
train_movie_dict, test_movie_dict = split_train_test(user_movie_dict, test_ratio=0.2, min_interactions=5)

train_embedding_dict, test_embedding_dict = split_train_test(user_movie_embedding_dict, test_ratio=0.2, min_interactions=5)

# Print updated user counts
print(f"Train users: {len(train_movie_dict)}, Test users: {len(test_movie_dict)}")
print(f"Train users: {len(train_embedding_dict)}, Test users: {len(test_embedding_dict)}")


Train users: 6040, Test users: 6040
Train users: 6040, Test users: 6040


In [None]:
# Make sure lengths match
train_data = {}
for user in train_movie_dict:
    movie_seq = train_movie_dict[user]
    emb_seq = train_embedding_dict[user]
    assert len(movie_seq) == len(emb_seq)
    train_data[user] = (movie_seq, emb_seq)

In [None]:
test_data = {}
for user in test_movie_dict:
    movie_seq = test_movie_dict[user]
    emb_seq = test_embedding_dict[user]
    assert len(movie_seq) == len(emb_seq)
    test_data[user] = (movie_seq, emb_seq)

In [None]:
from torch.utils.data import Dataset, DataLoader

class BERT4RecDataset(Dataset):
    def __init__(self, user_data, max_seq_len=50):
        """
        user_data is a dict: { user_id: (movie_seq, emb_seq) }
            movie_seq: list of item IDs
            emb_seq: list of 384-dim embeddings
        max_seq_len: maximum sequence length
        """
        self.samples = []

        # Store the user_data as a list of tuples (movie_seq, emb_seq)
        for user, (movie_seq, emb_seq) in user_data.items():
            self.samples.append((movie_seq, emb_seq))


        self.max_seq_len = max_seq_len

    def __len__(self):
        return len(self.samples) # This line is to return the length of the dataset


    def __getitem__(self, idx): # This function is to get the item at a particular index
        movie_seq, emb_seq = self.samples[idx]

        # 1) Truncate or pad the movie sequence to the max_seq_len
        if len(movie_seq) > self.max_seq_len:
            movie_seq = movie_seq[-self.max_seq_len:] # to keep the last max_seq_len items
            emb_seq = emb_seq[-self.max_seq_len:]

        else:
            # pad the sequence
            pad_len = self.max_seq_len - len(movie_seq)
            movie_seq = list(movie_seq)
            movie_seq += [0]*pad_len

            # Embedding squence to be padded with zeros of shape (384,)
            emb_seq = emb_seq + [np.zeros(384)] * pad_len

        target_ids = movie_seq[1:] + [0]  # Next-movie prediction

        # attention_mask is 1 for real tokens and 0 for padding tokens
        attention_mask = [1 if m != 0 else 0 for m in movie_seq]

        # Convert all of the sequences to PyTorch tensors
        movie_seq = torch.LongTensor(movie_seq) #(max_seq_len,)
        emb_seq = [torch.tensor(emb, dtype=torch.float) for emb in emb_seq]
        emb_seq = torch.stack(emb_seq, dim=0) #(max_seq_len, 384)
        attention_mask = torch.tensor(attention_mask) #(max_seq_len,)


        return movie_seq, emb_seq, attention_mask

In [None]:
from torch.utils.data import DataLoader

train_dataset = BERT4RecDataset(train_data, max_seq_len=50)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

In [None]:
test_dataset = BERT4RecDataset(test_data, max_seq_len=50)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [None]:
# Check the test_loader data
for i, (movie_seq, emb_seq, attention_mask) in enumerate(test_loader):
    print(f"Movie sequence shape: {movie_seq.shape}")
    print(f"Embedding sequence shape: {emb_seq.shape}")
    print(f"Attention mask shape: {attention_mask.shape}")
    break

# Check the embedding sequence for the  sample
movie_seq[10]

Movie sequence shape: torch.Size([32, 50])
Embedding sequence shape: torch.Size([32, 50, 384])
Attention mask shape: torch.Size([32, 50])


tensor([ 104, 1777, 1732, 1753,  663,  788, 2539,  231, 2683, 2598, 1665,  586,
         333,  216,  784, 2806, 2325, 2431,   88, 1461,  435, 2507, 2335, 2306,
        2907, 3146, 2076, 3182,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0])

In [None]:
import torch.nn as nn

class BERT4RecWithText(nn.Module):
    def __init__(self, vocab_size,
                 hidden_dim=128,
                 text_emb_dim=384,
                 max_seq_len=50,
                 n_heads=4, n_layers=2, dropout=0.5):
        super().__init__()

        # Embedding layer for the movie IDs
        self.item_embedding = nn.Embedding(vocab_size,
                                           hidden_dim,
                                           padding_idx=0) #(vocab_size, hidden_dim)

        # Projection layer for text embeddings
        self.text_projection = nn.Linear(text_emb_dim,
                                         hidden_dim) #(text_emb_dim, hidden_dim)

        # Apply a fuse_linear to concatenate the embeddings
        self.fuse_linear = nn.Linear(hidden_dim * 2, hidden_dim)

        # Positional encoding
        self.positional_encoding = nn.Embedding(max_seq_len, hidden_dim) #(max_seq_len, hidden_dim)

        # Transformer layer
        encoder_layer = nn.TransformerEncoderLayer(
            d_model = hidden_dim,
            nhead=n_heads,
            dim_feedforward=hidden_dim*4,
            dropout=dropout,
            activation='gelu',
            batch_first=True
        )

        # Transformer encoder to encode the sequence
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=n_layers
        )

        self.transformer_norm = nn.LayerNorm(hidden_dim)

        # Store themaximum sequence length
        self.max_seq_len = max_seq_len

        # MLP layers with BatchNorm
        self.mlp1 = nn.Linear(hidden_dim, 1024)
        self.bn1 = nn.BatchNorm1d(1024)  # Normalize over batch and sequence
        self.mlp2 = nn.Linear(1024, 512)
        self.bn2 = nn.BatchNorm1d(512)
        self.mlp3 = nn.Linear(512, vocab_size)
        self.leaky_relu = nn.LeakyReLU(negative_slope=0.01)
        self.dropout = nn.Dropout(dropout)


    def generate_casual_mask(self, seq_len):
        mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
        return mask

    def forward(self,
                item_ids, # (batch_size, max_seq_len) input integer movie IDs
                text_embeddings, # (batch_size, max_seq_len, 384) input text embeddings
                attention_mask # (batch_size, max_seq_len) input attention mask
    ):
        B, L = item_ids.shape

        # Create position indices
        positions = torch.arange(L, device=item_ids.device)
        positions = positions.unsqueeze(0).expand(B, -1) #(B, L)

        # Embed the item IDs to shape (B, L, hidden_dim)
        item_emb = self.item_embedding(item_ids)





        # Add positional embeddings
        pos_emb = self.positional_encoding(positions) #(B, L, hidden_dim)
        fused_emb = item_emb + pos_emb #(B, L, hidden_dim)

        #fused_emb = self.transformer_norm(fused_emb)  # Normalize over B and L

        # padding_mask is True for padding tokens and False for real tokens
        # So that the Transformer igmore padded positions
        padding_mask = (attention_mask == 0) #(B, L)

        casual_mask = self.generate_casual_mask(L).to(item_ids.device)


        # Pass it through the Transformer
        encoded = self.transformer_encoder(fused_emb,
                                           mask = casual_mask,
                                            src_key_padding_mask=padding_mask) #(B, L, hidden_dim)

        # Project text embeddings to shape (B, L, hidden_dim)
        text_emb = self.text_projection(text_embeddings)

        # Fuse item and text embeddings
        fused_emb = torch.cat([encoded, text_emb], dim=-1) #(B, L, hidden_dim*2)
        fused_emb = self.fuse_linear(fused_emb) #(B, L, hidden_dim)

        # MLP 1 with BatchNorm
        encoded = self.mlp1(fused_emb)  # (B, L, 1024)
        encoded = encoded.transpose(1, 2)  # (B, 1024, L) for BatchNorm1d
        encoded = self.bn1(encoded)  # Normalize over B and L
        encoded = encoded.transpose(1, 2)  # (B, L, 1024)
        encoded = self.leaky_relu(encoded)
        encoded = self.dropout(encoded)

        # MLP 2 with BatchNorm
        encoded = self.mlp2(encoded)  # (B, L, 512)
        encoded = encoded.transpose(1, 2)  # (B, 512, L)
        encoded = self.bn2(encoded)  # Normalize over B and L
        encoded = encoded.transpose(1, 2)  # (B, L, 512)
        encoded = self.leaky_relu(encoded)
        encoded = self.dropout(encoded)

        # MLP 3 (no BatchNorm before output logits)
        out = self.mlp3(encoded)  # (B, L, vocab_size)

        return out

# Initialize Model
vocab_size = max(max(seq) for seq in user_movie_dict.values()) + 1  # Get max movie ID as vocab size
model = BERT4RecWithText(vocab_size).to(device)

print(f"Initialized BERT4Rec model with vocab size {vocab_size}")

Initialized BERT4Rec model with vocab size 3953


In [None]:
print(f"Using device: {device}")

Using device: cuda


In [None]:
import torch
import torch.optim as optim
import torch.nn as nn

def adjust_padding(item_ids, target_padding_ratio=0.5):
    """
    Adjusts the padding ratio in item_ids by randomly masking real items to match target_padding_ratio.
    """
    real_items = item_ids != 0  # (B, L)
    num_real = real_items.sum(dim=-1).float()  # (B,) as float for division
    seq_len = item_ids.shape[-1]  # Scalar (e.g., 50)
    current_padding_ratio = 1 - num_real / seq_len  # (B,)
    mask_prob = (target_padding_ratio - current_padding_ratio) / (1 - current_padding_ratio + 1e-6)  # (B,), add epsilon to avoid division by zero
    mask_prob = torch.clamp(mask_prob, 0, 1)  # (B,)
    mask_prob = mask_prob.unsqueeze(-1)  # (B, 1) for broadcasting
    mask = (torch.rand_like(item_ids, dtype=torch.float) < mask_prob) & real_items  # (B, L)
    item_ids = item_ids.clone()
    item_ids[mask] = 0  # Replace real items with padding
    return item_ids

def train_model_with_text(model, train_loader, test_loader, epochs=20, lr=0.001, device="cpu"):
    '''
    This function trains the model with text embeddings
    model: BERT4RecWithText model
    data_loader: DataLoader object that provides
    item_ids: (batch_size, max_seq_len) input integer movie IDs
    text_embeddings: (batch_size, max_seq_len, 384) input text embeddings
    attention_mask: (batch_size, max_seq_len) input attention mask

    epochs: number of epochs to train
    lr: learning rate
    device: device is CUDA if GPU is available, otherwise CPU
    '''

    # Initialize the optimizer
    optimizer = optim.Adam(model.parameters(), weight_decay=1e-4, lr=lr)
    criterion = nn.CrossEntropyLoss(ignore_index=0)  # Ignore padding index 0

    model.to(device)



    for epoch in range(1, epochs+1):
        total_loss = 0
        model.train()
        num_train_batches = 0
        total_train_loss = 0

        for batch in train_loader:
            # batch: (item_ids, text_emb, attention_mask)
            item_ids, text_emb, attn_mask = batch

            item_ids = adjust_padding(item_ids, target_padding_ratio=0.5)
            attn_mask = (item_ids != 0).float()

            item_ids = item_ids.to(device)     # shape (B, L)
            text_emb = text_emb.to(device)     # shape (B, L, 384)
            attn_mask = attn_mask.to(device)   # shape (B, L)

            # Prepare the input and target
            input_ids   = item_ids[:, :-1]
            input_emb   = text_emb[:, :-1, :]
            input_mask  = attn_mask[:, :-1]

            target_ids  = item_ids[:, 1:]

            # Forward pass
            outputs = model(input_ids,
                            input_emb,
                            input_mask)  # (B, L-1, vocab_size)

            # Reshape outputs and targets for CrossEntropyLoss
            B, seq_len_minus_1, num_items = outputs.shape
            outputs_2d = outputs.view(B * seq_len_minus_1, num_items)
            targets_2d = target_ids[:, :seq_len_minus_1].reshape(-1)

            # Compute Loss
            loss = criterion(outputs_2d, targets_2d)

            # Backprop and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_train_loss = total_loss / len(train_loader)

        # Evaluation phase (test_loader only)
        model.eval()
        total_test_loss = total_test_recall = 0
        with torch.no_grad():
            for batch in test_loader:
                item_ids, text_emb, attn_mask = [x.to(device) for x in batch]
                input_ids = item_ids[:, :-1]
                input_emb = text_emb[:, :-1, :]
                input_mask = attn_mask[:, :-1]
                target_ids = item_ids[:, 1:]

                outputs = model(input_ids, input_emb, input_mask)  # (B, L-1, vocab_size)

                # Test loss
                B, seq_len_minus_1, num_items = outputs.shape
                outputs_2d = outputs.view(B * seq_len_minus_1, num_items)
                targets_2d = target_ids.reshape(-1)
                loss = criterion(outputs_2d, targets_2d)
                total_test_loss += loss.item()

                # Recall@10
                _, top_k = torch.topk(outputs, k=10, dim=-1)
                hits = (top_k == target_ids.unsqueeze(-1)).any(dim=-1)
                mask = (target_ids != 0)
                recall = hits.float()[mask].mean() if mask.sum() > 0 else 0.0
                total_test_recall += recall.item()

        avg_test_loss = total_test_loss / len(test_loader)
        avg_test_recall = total_test_recall / len(test_loader)

        print(f"Epoch {epoch}/{epochs}, Train Loss: {avg_train_loss:.4f}, Test Loss: {avg_test_loss:.4f}, Test Recall@10: {avg_test_recall:.4f}")


# Train the model
train_model_with_text(model, train_loader, test_loader, epochs=20, lr=1e-5, device=device)


Epoch 1/20, Train Loss: 5.0956, Test Loss: 6.2481, Test Recall@10: 0.1819
Epoch 2/20, Train Loss: 5.0845, Test Loss: 6.2387, Test Recall@10: 0.1822
Epoch 3/20, Train Loss: 5.0943, Test Loss: 6.2347, Test Recall@10: 0.1823
Epoch 4/20, Train Loss: 5.0872, Test Loss: 6.2440, Test Recall@10: 0.1825
Epoch 5/20, Train Loss: 5.0939, Test Loss: 6.2337, Test Recall@10: 0.1829
Epoch 6/20, Train Loss: 5.0869, Test Loss: 6.2376, Test Recall@10: 0.1822
Epoch 7/20, Train Loss: 5.0857, Test Loss: 6.2410, Test Recall@10: 0.1818
Epoch 8/20, Train Loss: 5.0873, Test Loss: 6.2387, Test Recall@10: 0.1824
Epoch 9/20, Train Loss: 5.0819, Test Loss: 6.2421, Test Recall@10: 0.1823
Epoch 10/20, Train Loss: 5.0893, Test Loss: 6.2458, Test Recall@10: 0.1818
Epoch 11/20, Train Loss: 5.0908, Test Loss: 6.2395, Test Recall@10: 0.1821
Epoch 12/20, Train Loss: 5.0891, Test Loss: 6.2420, Test Recall@10: 0.1819
Epoch 13/20, Train Loss: 5.0899, Test Loss: 6.2355, Test Recall@10: 0.1820
Epoch 14/20, Train Loss: 5.0902, T

In [None]:
import torch
import torch.nn as nn

def evaluate_model_with_text(model, data_loader, k=10, device='cuda'):
    """
    Evaluate the model using next-item prediction.
    We shift inputs and targets by one position:
        input: item_ids[:, :-1], text_emb[:, :-1], attention_mask[:, :-1]
        target: item_ids[:, 1:]
    Then compute loss, Recall@K, and NDCG@K.

    Args:
        model: BERT4RecWithText model
        data_loader: yields (item_ids, text_emb, attention_mask)
                     shapes:
                       item_ids:   [B, L]
                       text_emb:   [B, L, 384]
                       attn_mask:  [B, L]
        k: top-K for metrics
        device: 'cuda' or 'cpu'
    """
    model.eval()
    criterion = nn.CrossEntropyLoss(ignore_index=0)  # ignore padding label
    total_loss = 0
    total_recall = 0
    total_ndcg = 0

    with torch.no_grad():
        for item_ids, text_emb, attn_mask in data_loader:
            item_ids = item_ids.to(device)
            text_emb = text_emb.to(device)
            attn_mask = attn_mask.to(device)

            # Shift for targets
            input_ids = item_ids[:, :-1]
            input_emb = text_emb[:, :-1, :]
            input_mask = attn_mask[:, :-1]

            targets = item_ids[:, 1:]  # (Batch, L-1)

            # Forward pass
            outputs = model(input_ids, input_emb, input_mask)

            # Compute loss
            # Flatten for cross-entropy
            B, seq_len_minus_1, vocab_size = outputs.shape
            outputs_2d = outputs.view(B * seq_len_minus_1, vocab_size)
            targets_2d = targets[:, :seq_len_minus_1].reshape(-1)

            loss = criterion(outputs_2d, targets_2d)
            total_loss += loss.item()

            _, top_k_predictions = torch.topk(outputs, k, dim=-1)
            recall = recall_at_k(top_k_predictions, targets[:, :seq_len_minus_1], k, input_mask)
            ndcg = ndcg_at_k(top_k_predictions, targets[:, :seq_len_minus_1], k, input_mask)

            total_recall += recall
            total_ndcg += ndcg

    num_batches = len(data_loader)
    avg_loss = total_loss / num_batches
    avg_recall = total_recall / num_batches
    avg_ndcg = total_ndcg / num_batches

    print(f"test - Loss: {avg_loss:.4f}, Recall@{k}: {avg_recall:.4f}, NDCG@{k}: {avg_ndcg:.4f}")
    return avg_loss, avg_recall, avg_ndcg

def recall_at_k(top_k_predictions, targets, k, mask):
    """
    Compute Recall@K for next-item prediction.

    Args:
        top_k_predictions: [B, L-1, k] - Top-K predicted item IDs
        targets: [B, L-1] - True next-item IDs
        k: Number of top predictions to consider
        mask: [B, L-1] - Attention mask (1 for valid, 0 for padded)

    Returns:
        float: Average Recall@K across all valid positions
    """
    # Check if true target is in top-K predictions
    hits = (top_k_predictions == targets.unsqueeze(-1)).float()  # [B, L-1, k]
    valid_hits = hits * mask.unsqueeze(-1)  # [B, L-1, k], zero out padded positions

    # Sum hits along top-K dimension: 1 if target is in top-K, 0 otherwise
    hits_per_pos = valid_hits.sum(dim=-1)  # [B, L-1], max 1 per position

    # Number of valid positions per sequence
    valid_positions = mask.sum(dim=-1)  # [B], total non-padded positions per sequence

    # Recall per sequence: fraction of valid positions where target was in top-K
    recall_per_seq = hits_per_pos.sum(dim=-1) / valid_positions.clamp(min=1)  # [B]

    # Average across batch, avoid NaN if all sequences are fully padded
    recall = recall_per_seq.mean().item() if valid_positions.sum() > 0 else 0.0
    return recall

def ndcg_at_k(top_k_predictions, targets, k, mask):
    """
    Compute NDCG@K for next-item prediction with one relevant item per position.

    Args:
        top_k_predictions: [B, L-1, k] - Top-K predicted item IDs
        targets: [B, L-1] - True next-item IDs
        k: Number of top predictions to consider
        mask: [B, L-1] - Attention mask (1 for valid, 0 for padded)

    Returns:
        float: Average NDCG@K across all valid positions
    """
    # Check if true target is in top-K predictions
    hits = (top_k_predictions == targets.unsqueeze(-1)).float()  # [B, L-1, k]
    valid_hits = hits * mask.unsqueeze(-1)  # [B, L-1, k], zero out padded positions

    # Discounted weights: 1/log2(rank + 2), where rank starts at 1
    log_positions = 1 / torch.log2(torch.arange(2, k+2, device=targets.device).float())  # [k]

    # DCG: Sum discounted weights where hits occur
    dcg = (valid_hits * log_positions).sum(dim=-1)  # [B, L-1]

    # Ideal DCG (IDCG): Assume one relevant item per position, ranked first
    idcg = log_positions[0]  # 1/log2(2) = 1, highest possible score per position

    # NDCG per position: DCG / IDCG
    ndcg_per_pos = dcg / idcg  # [B, L-1], scales between 0 and 1

    # Sum NDCG for each sequence and divide by valid positions in that sequence
    ndcg_sum_per_seq = ndcg_per_pos.sum(dim=-1)  # [B]
    valid_positions_per_seq = mask.sum(dim=-1)  # [B]
    ndcg_per_seq = ndcg_sum_per_seq / valid_positions_per_seq.clamp(min=1)  # [B]

    # Average across sequences
    ndcg = ndcg_per_seq.mean().item() if valid_positions_per_seq.sum() > 0 else 0.0
    return ndcg

In [None]:
evaluate_model_with_text(model, test_loader, k=10, device=device)

test - Loss: 6.2451, Recall@10: 0.1935, NDCG@10: 0.1065


(6.24514752847177, 0.1934580838002225, 0.10646433469952729)

In [None]:
# Save the model state_dict and optimizer state_dict
save_path = "Transformer_TitleEmbedding_After_Mar5_18%Recall.pth"
torch.save({
    'model_state_dict': model.state_dict()
}, save_path)

print(f"Model saved to {save_path}")


Model saved to Transformer_TitleEmbedding_After_Mar5_18%Recall.pth


In [None]:
checkpoint = torch.load("Transformer_TitleEmbedding_After_Mar5_18%Recall.pth")
model.load_state_dict(checkpoint['model_state_dict'])



  checkpoint = torch.load("Transformer_TitleEmbedding_After_Mar5_18%Recall.pth")


<All keys matched successfully>

In [None]:
def inspect_predictions(model, data_loader, k=10, num_samples=5, device='cuda'):
    """Inspect individual predictions and metrics"""
    model.eval()

    with torch.no_grad():
        for item_ids, text_emb, attn_mask in data_loader:
            item_ids = item_ids.to(device)
            text_emb = text_emb.to(device)
            attn_mask = attn_mask.to(device)

            # Forward pass
            input_ids = item_ids[:, :-1]
            input_emb = text_emb[:, :-1, :]
            input_mask = attn_mask[:, :-1]
            targets = item_ids[:, 1:]

            outputs = model(input_ids, input_emb, input_mask)

            # Get top-K predictions
            _, top_k_predictions = torch.topk(outputs, k, dim=-1)

            # Print details for the first few samples
            for i in range(min(num_samples, input_ids.size(0))):
                seq_len = input_mask[i].sum().item()
                print(f"\nSample {i+1}, Sequence length: {seq_len}")

                for pos in range(seq_len):
                    true_item = targets[i, pos].item()
                    pred_items = top_k_predictions[i, pos].tolist()
                    hit = true_item in pred_items
                    rank = pred_items.index(true_item) + 1 if hit else "N/A"

                    print(f"  Pos {pos}: True={true_item}, Top-{k}={pred_items}, Hit={hit}, Rank={rank}")

            break  # Just process one batch

inspect_predictions(model, test_loader, k=20, num_samples=3, device=device)


Sample 1, Sequence length: 11
  Pos 0: True=2687, Top-20=[364, 595, 2081, 1721, 1210, 2087, 1193, 1907, 1097, 1617, 900, 1198, 531, 2018, 260, 2857, 920, 2096, 945, 912], Hit=False, Rank=N/A
  Pos 1: True=745, Top-20=[364, 596, 2096, 1097, 1566, 2087, 1907, 2857, 2018, 1033, 1025, 531, 2081, 588, 1029, 1210, 1073, 2687, 2137, 2083], Hit=False, Rank=N/A
  Pos 2: True=588, Top-20=[1148, 2078, 34, 3114, 2080, 1223, 3429, 595, 1, 364, 2761, 1617, 2396, 2687, 594, 1641, 2700, 50, 2987, 3396], Hit=False, Rank=N/A
  Pos 3: True=1, Top-20=[595, 364, 596, 2087, 2081, 2080, 588, 2018, 2089, 2078, 594, 2687, 2083, 1907, 783, 1064, 2085, 1566, 34, 2096], Hit=False, Rank=N/A
  Pos 4: True=2355, Top-20=[364, 3114, 594, 595, 34, 2355, 3429, 588, 2085, 596, 1, 2018, 2700, 2087, 1907, 1223, 1265, 2080, 1028, 2102], Hit=True, Rank=6
  Pos 5: True=2294, Top-20=[588, 364, 2355, 2085, 2687, 34, 2761, 2700, 3751, 1028, 2018, 2384, 3396, 596, 2087, 2078, 595, 1907, 1641, 594], Hit=False, Rank=N/A
  Pos 6: T

In [None]:
def print_stats(loader, name):
    item_counts = torch.bincount(torch.cat([batch[0].flatten() for batch in loader]))
    seq_lengths = [batch[0].shape[1] for batch in loader]
    print(f"{name} - Unique items: {len(item_counts)}, Padding %: {item_counts[0]/item_counts.sum():.4f}, Avg Seq Len: {sum(seq_lengths)/len(seq_lengths):.2f}")
print_stats(train_loader, "Train")
print_stats(test_loader, "Test")

Train - Unique items: 3953, Padding %: 0.1496, Avg Seq Len: 50.00
Test - Unique items: 3953, Padding %: 0.5101, Avg Seq Len: 50.00


# BASELINE model

In [None]:
import torch.nn as nn

class BERT4RecNoText(nn.Module):
    def __init__(self, vocab_size,
                 hidden_dim=128,
                 max_seq_len=50,
                 n_heads=4, n_layers=2, dropout=0.5):
        super().__init__()

        # Embedding layer for the movie IDs
        self.item_embedding = nn.Embedding(vocab_size,
                                           hidden_dim,
                                           padding_idx=0) #(vocab_size, hidden_dim)

        # Apply a fuse_linear to concatenate the embeddings
        self.fuse_linear = nn.Linear(hidden_dim, hidden_dim)

        # Positional encoding
        self.positional_encoding = nn.Embedding(max_seq_len, hidden_dim) #(max_seq_len, hidden_dim)

        # Transformer layer
        encoder_layer = nn.TransformerEncoderLayer(
            d_model = hidden_dim,
            nhead=n_heads,
            dim_feedforward=hidden_dim*4,
            dropout=dropout,
            activation='gelu',
            batch_first=True
        )

        # Transformer encoder to encode the sequence
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=n_layers
        )

        self.transformer_norm = nn.LayerNorm(hidden_dim)

        self.output_layer = nn.Linear(hidden_dim, vocab_size)

        # Store themaximum sequence length
        self.max_seq_len = max_seq_len

    def generate_casual_mask(self, seq_len):
        mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
        return mask

    def forward(self,
                item_ids, # (batch_size, max_seq_len) input integer movie IDs
                attention_mask # (batch_size, max_seq_len) input attention mask
    ):
        B, L = item_ids.shape

        # Create position indices
        positions = torch.arange(L, device=item_ids.device)
        positions = positions.unsqueeze(0).expand(B, -1) #(B, L)

        # Embed the item IDs to shape (B, L, hidden_dim)
        item_emb = self.item_embedding(item_ids)


        fused_emb = self.fuse_linear(item_emb) #(B, L, hidden_dim)

        # Add positional embeddings
        pos_emb = self.positional_encoding(positions)
        fused_emb = fused_emb + pos_emb

        fused_emb = self.transformer_norm(fused_emb)  # Normalize over B and L

        # padding_mask is True for padding tokens and False for real tokens
        # So that the Transformer igmore padded positions
        padding_mask = (attention_mask == 0) #(B, L)

        casual_mask = self.generate_casual_mask(L).to(item_ids.device)


        # Pass it through the Transformer
        encoded = self.transformer_encoder(fused_emb,
                                           mask = casual_mask,
                                            src_key_padding_mask=padding_mask) #(B, L, hidden_dim)

        # Output logits
        out = self.output_layer(encoded) #(B, L, vocab_size)

        return out

# Initialize Model
vocab_size = max(max(seq) for seq in user_movie_dict.values()) + 1  # Get max movie ID as vocab size
model = BERT4RecNoText(vocab_size).to(device)

print(f"Initialized BERT4Rec model with vocab size {vocab_size}")

Initialized BERT4Rec model with vocab size 3953


In [None]:
import torch
import torch.optim as optim
import torch.nn as nn

def adjust_padding(item_ids, target_padding_ratio=0.5):
    """
    Adjusts the padding ratio in item_ids by randomly masking real items to match target_padding_ratio.
    """
    real_items = item_ids != 0  # (B, L)
    num_real = real_items.sum(dim=-1).float()  # (B,) as float for division
    seq_len = item_ids.shape[-1]  # Scalar (e.g., 50)
    current_padding_ratio = 1 - num_real / seq_len  # (B,)
    mask_prob = (target_padding_ratio - current_padding_ratio) / (1 - current_padding_ratio + 1e-6)  # (B,), add epsilon to avoid division by zero
    mask_prob = torch.clamp(mask_prob, 0, 1)  # (B,)
    mask_prob = mask_prob.unsqueeze(-1)  # (B, 1) for broadcasting
    mask = (torch.rand_like(item_ids, dtype=torch.float) < mask_prob) & real_items  # (B, L)
    item_ids = item_ids.clone()
    item_ids[mask] = 0  # Replace real items with padding
    return item_ids

def train_model_no_text(model, data_loader, epochs=150,lr=0.01, device="cpu"):
    '''
    This function trains the model with text embeddings
    model: BERT4RecWithText model
    data_loader: DataLoader object that provides
    item_ids: (batch_size, max_seq_len) input integer movie IDs
    text_embeddings: (batch_size, max_seq_len, 384) input text embeddings
    attention_mask: (batch_size, max_seq_len) input attention mask

    epochs: number of epochs to train
    lr: learning rate
    device: device is CUDA if GPU is available, otherwise CPU
    '''

    # Initialize the optimizer
    optimizer = optim.Adam(model.parameters(), weight_decay=1e-4, lr=lr)
    criterion = nn.CrossEntropyLoss(ignore_index=0)  # Ignore padding index 0

    model.to(device)



    for epoch in range(1, epochs+1):
        total_loss = 0
        model.train()
        num_train_batches = 0
        total_train_loss = 0

        for batch in data_loader:
            # batch: (item_ids, text_emb, attention_mask)
            item_ids, text_emb, attn_mask = batch

            item_ids = adjust_padding(item_ids, target_padding_ratio=0.5)
            attn_mask = (item_ids != 0).float()

            item_ids = item_ids.to(device)     # shape (B, L)
            #text_emb = text_emb.to(device)     # shape (B, L, 384)
            attn_mask = attn_mask.to(device)   # shape (B, L)

            # Prepare the input and target
            input_ids   = item_ids[:, :-1]
            #input_emb   = text_emb[:, :-1, :]
            input_mask  = attn_mask[:, :-1]

            target_ids  = item_ids[:, 1:]

            # Forward pass
            outputs = model(input_ids,
                            #input_emb,
                            input_mask)  # (B, L-1, vocab_size)

            # Reshape outputs and targets for CrossEntropyLoss
            B, seq_len_minus_1, num_items = outputs.shape
            outputs_2d = outputs.view(B * seq_len_minus_1, num_items)
            targets_2d = target_ids[:, :seq_len_minus_1].reshape(-1)

            # Compute Loss
            loss = criterion(outputs_2d, targets_2d)

            # Backprop and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(data_loader)
        print(f"Epoch {epoch}/{epochs}, Loss: {avg_loss:.4f}")


# Train the model
train_model_no_text(model, train_loader, epochs=20, lr=0.001, device=device)

Epoch 1/20, Loss: 7.5201
Epoch 2/20, Loss: 7.3618
Epoch 3/20, Loss: 7.2041
Epoch 4/20, Loss: 6.9266
Epoch 5/20, Loss: 6.7086
Epoch 6/20, Loss: 6.5517
Epoch 7/20, Loss: 6.4288
Epoch 8/20, Loss: 6.3325
Epoch 9/20, Loss: 6.2459
Epoch 10/20, Loss: 6.1739
Epoch 11/20, Loss: 6.1211
Epoch 12/20, Loss: 6.0665
Epoch 13/20, Loss: 6.0229
Epoch 14/20, Loss: 5.9816
Epoch 15/20, Loss: 5.9525
Epoch 16/20, Loss: 5.9213
Epoch 17/20, Loss: 5.8962
Epoch 18/20, Loss: 5.8677
Epoch 19/20, Loss: 5.8384
Epoch 20/20, Loss: 5.8152


In [None]:
import torch
import torch.nn as nn

def evaluate_model_no_text(model, data_loader, k=10, device='cuda'):
    """
    Evaluate the model using next-item prediction.
    We shift inputs and targets by one position:
        input: item_ids[:, :-1], text_emb[:, :-1], attention_mask[:, :-1]
        target: item_ids[:, 1:]
    Then compute loss, Recall@K, and NDCG@K.

    Args:
        model: BERT4RecWithText model
        data_loader: yields (item_ids, text_emb, attention_mask)
                     shapes:
                       item_ids:   [B, L]
                       text_emb:   [B, L, 384]
                       attn_mask:  [B, L]
        k: top-K for metrics
        device: 'cuda' or 'cpu'
    """
    model.eval()
    criterion = nn.CrossEntropyLoss(ignore_index=0)  # ignore padding label
    total_loss = 0
    total_recall = 0
    total_ndcg = 0

    with torch.no_grad():
        for item_ids, text_emb, attn_mask in data_loader:
            item_ids = item_ids.to(device)
            text_emb = text_emb.to(device)
            attn_mask = attn_mask.to(device)

            # Shift for targets
            input_ids = item_ids[:, :-1]
            input_mask = attn_mask[:, :-1]

            targets = item_ids[:, 1:]  # (Batch, L-1)

            # Forward pass
            outputs = model(input_ids, input_mask)

            # Compute loss
            # Flatten for cross-entropy
            B, seq_len_minus_1, vocab_size = outputs.shape
            outputs_2d = outputs.view(B * seq_len_minus_1, vocab_size)
            targets_2d = targets[:, :seq_len_minus_1].reshape(-1)

            loss = criterion(outputs_2d, targets_2d)
            total_loss += loss.item()

            _, top_k_predictions = torch.topk(outputs, k, dim=-1)
            recall = recall_at_k(top_k_predictions, targets[:, :seq_len_minus_1], k, input_mask)
            ndcg = ndcg_at_k(top_k_predictions, targets[:, :seq_len_minus_1], k, input_mask)

            total_recall += recall
            total_ndcg += ndcg

    num_batches = len(data_loader)
    avg_loss = total_loss / num_batches
    avg_recall = total_recall / num_batches
    avg_ndcg = total_ndcg / num_batches

    print(f"test - Loss: {avg_loss:.4f}, Recall@{k}: {avg_recall:.4f}, NDCG@{k}: {avg_ndcg:.4f}")
    return avg_loss, avg_recall, avg_ndcg

def recall_at_k(top_k_predictions, targets, k, mask):
    """
    Compute Recall@K for next-item prediction.

    Args:
        top_k_predictions: [B, L-1, k] - Top-K predicted item IDs
        targets: [B, L-1] - True next-item IDs
        k: Number of top predictions to consider
        mask: [B, L-1] - Attention mask (1 for valid, 0 for padded)

    Returns:
        float: Average Recall@K across all valid positions
    """
    # Check if true target is in top-K predictions
    hits = (top_k_predictions == targets.unsqueeze(-1)).float()  # [B, L-1, k]
    valid_hits = hits * mask.unsqueeze(-1)  # [B, L-1, k], zero out padded positions

    # Sum hits along top-K dimension: 1 if target is in top-K, 0 otherwise
    hits_per_pos = valid_hits.sum(dim=-1)  # [B, L-1], max 1 per position

    # Number of valid positions per sequence
    valid_positions = mask.sum(dim=-1)  # [B], total non-padded positions per sequence

    # Recall per sequence: fraction of valid positions where target was in top-K
    recall_per_seq = hits_per_pos.sum(dim=-1) / valid_positions.clamp(min=1)  # [B]

    # Average across batch, avoid NaN if all sequences are fully padded
    recall = recall_per_seq.mean().item() if valid_positions.sum() > 0 else 0.0
    return recall

def ndcg_at_k(top_k_predictions, targets, k, mask):
    """
    Compute NDCG@K for next-item prediction with one relevant item per position.

    Args:
        top_k_predictions: [B, L-1, k] - Top-K predicted item IDs
        targets: [B, L-1] - True next-item IDs
        k: Number of top predictions to consider
        mask: [B, L-1] - Attention mask (1 for valid, 0 for padded)

    Returns:
        float: Average NDCG@K across all valid positions
    """
    # Check if true target is in top-K predictions
    hits = (top_k_predictions == targets.unsqueeze(-1)).float()  # [B, L-1, k]
    valid_hits = hits * mask.unsqueeze(-1)  # [B, L-1, k], zero out padded positions

    # Discounted weights: 1/log2(rank + 2), where rank starts at 1
    log_positions = 1 / torch.log2(torch.arange(2, k+2, device=targets.device).float())  # [k]

    # DCG: Sum discounted weights where hits occur
    dcg = (valid_hits * log_positions).sum(dim=-1)  # [B, L-1]

    # Ideal DCG (IDCG): Assume one relevant item per position, ranked first
    idcg = log_positions[0]  # 1/log2(2) = 1, highest possible score per position

    # NDCG per position: DCG / IDCG
    ndcg_per_pos = dcg / idcg  # [B, L-1], scales between 0 and 1

    # Sum NDCG for each sequence and divide by valid positions in that sequence
    ndcg_sum_per_seq = ndcg_per_pos.sum(dim=-1)  # [B]
    valid_positions_per_seq = mask.sum(dim=-1)  # [B]
    ndcg_per_seq = ndcg_sum_per_seq / valid_positions_per_seq.clamp(min=1)  # [B]

    # Average across sequences
    ndcg = ndcg_per_seq.mean().item() if valid_positions_per_seq.sum() > 0 else 0.0
    return ndcg

In [None]:
evaluate_model_no_text(model, test_loader, k=10, device=device)

test - Loss: 6.5008, Recall@10: 0.1458, NDCG@10: 0.0769


(6.500831664554656, 0.14576226693612557, 0.07692421365667272)