In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
from torch.utils.data import DataLoader, Dataset

def load_data(file_path):
    user_dict = {}
    with open(file_path, 'r') as f:
        for line in f:
            user, item = map(int, line.strip().split())
            if user not in user_dict:
                user_dict[user] = []
            user_dict[user].append(item)
    return user_dict

class RecDataset(Dataset):
    def __init__(self, user_dict, num_items, max_seq_length, segment_length):
        self.data = []
        for user, items in user_dict.items():
            if len(items) < 2:
                continue
            
            if len(items)<max_seq_length:
                items = [0]*(max_seq_length-len(items)) + items
                
            
            seq = items[-max_seq_length:]
            for i in range(0, len(seq) - 1, segment_length):
                segment = seq[i:i+segment_length]
                target = seq[i+1:i+1+segment_length]
                if len(segment) == segment_length and len(target) == segment_length:
                    self.data.append((segment, target))
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return torch.tensor(self.data[idx][0], dtype=torch.long), torch.tensor(self.data[idx][1], dtype=torch.long)


# Example usage

device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

embed_dim = 64
num_layers = 4
num_heads = 8
hidden_dim = 64
mem_length = 100
max_seq_length = 200
segment_length = 25
batch_size = 32
num_epochs = 10

# Load data

file_path = "data/ml-1m.txt"
user_dict = load_data(file_path)
num_items = max(max(items) for items in user_dict.values()) + 1
print("Number of users: ", len(user_dict))

train_dataset = RecDataset(user_dict, num_items, max_seq_length, segment_length)

print("Number of training samples: ", len(train_dataset))
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,drop_last=True)

for segment, target in train_loader:
    print(segment.shape)
    print(target.shape)
    break
    
# if the last batch is smaller than batch_size, the last batch will be dropped 


Number of users:  6040
Number of training samples:  42280
torch.Size([32, 25])
torch.Size([32, 25])


In [2]:
"""
---
title: Relative Multi-Headed Attention
summary: >
  Documented implementation with explanations of
  Relative Multi-Headed Attention from paper Transformer-XL.
---

# Relative Multi-Headed Attention

This is an implementation of relative multi-headed attention from paper
[Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context](https://arxiv.org/abs/1901.02860)
in [PyTorch](https://pytorch.org).
"""

import torch
from torch import nn

from labml.logger import inspect
from labml_nn.transformers.mha import MultiHeadAttention


from typing import List, Optional
def shift_right(x: torch.Tensor):
    """
    This method shifts $i^{th}$ row of a matrix by $i$ columns.

    If the input is `[[1, 2 ,3], [4, 5 ,6], [7, 8, 9]]`, the shifted
    result would be `[[1, 2 ,3], [0, 4, 5], [6, 0, 7]]`.
    *Ideally we should mask out the lower triangle but it's ok for our purpose*.
    """

    # Concatenate a column of zeros
    zero_pad = x.new_zeros(x.shape[0], 1, *x.shape[2:])
    x_padded = torch.cat([x, zero_pad], dim=1)

    # Reshape and remove excess elements from the end
    x_padded = x_padded.view(x.shape[1] + 1, x.shape[0], *x.shape[2:])
    x = x_padded[:-1].view_as(x)

    #
    return x



class RelativeMultiHeadAttention(MultiHeadAttention):
    """
    ## Relative Multi-Head Attention Module

    We override [Multi-Head Attention](mha.html) module so we only need to 
    write the `get_scores` method.
    """

    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
        # The linear transformations do not need a bias since we
        # explicitly include it when calculating scores.
        # However having a bias for `value` might make sense.
        # print(heads, d_model, dropout_prob)
        super().__init__(heads, d_model, dropout_prob, bias=False)
        
        # Number of relative positions
        self.P = 2 ** 12

        # Relative positional embeddings for key relative to the query.
        # We need $2P$ embeddings because the keys can be before or after the query.
        self.key_pos_embeddings = nn.Parameter(torch.zeros((self.P * 2, heads, self.d_k)), requires_grad=True)
        # Relative positional embedding bias for key relative to the query.
        self.key_pos_bias = nn.Parameter(torch.zeros((self.P * 2, heads)), requires_grad=True)
        # Positional embeddings for the query is independent of the position of the query
        self.query_pos_bias = nn.Parameter(torch.zeros((heads, self.d_k)), requires_grad=True)

    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
        r"""
        ### Get relative attention scores

        With absolute attention

        \begin{align}
        A^{abs}_{j} &= lin_q(X^q_i + P_i)^\top lin_k(X^k_j + P_j) \\
                      &= \underset{\textcolor{lightgreen}{A}}{Q_i^\top K_j} +
                         \underset{\textcolor{lightgreen}{B}}{Q_i^\top U^K_j} +
                         \underset{\textcolor{lightgreen}{C}}{{U^Q_i}^\top K_j} +
                         \underset{\textcolor{lightgreen}{D}}{{U^Q_i}^\top U^K_j}
        \end{align}

        where $Q_i, K_j$, are linear transformations of
         original embeddings $X^q_i, X^k_j$
         and $U^Q_i, U^K_j$ are linear transformations of
         absolute positional encodings $P_i, P_j$.

        They reason out that the attention to a given key should be the same regardless of
        the position of query.
        Hence replace $\underset{\textcolor{lightgreen}{C}}{{U^Q_i}^\top K_j}$
        with a constant $\underset{\textcolor{lightgreen}{C}}{\textcolor{orange}{v^\top} K_j}$.

        For the second and third terms relative positional encodings are introduced.
        So $\underset{\textcolor{lightgreen}{B}}{Q_i^\top U^K_j}$ is
        replaced with $\underset{\textcolor{lightgreen}{B}}{Q_i^\top \textcolor{orange}{R_{i - j}}}$
        and $\underset{\textcolor{lightgreen}{D}}{{U^Q_i}^\top U^K_j}$
        with $\underset{\textcolor{lightgreen}{D}}{\textcolor{orange}{S_{i-j}}}$.

        \begin{align}
        A^{rel}_{i,j} &= \underset{\mathbf{\textcolor{lightgreen}{A}}}{Q_i^\top K_j} +
                         \underset{\mathbf{\textcolor{lightgreen}{B}}}{Q_i^\top \textcolor{orange}{R_{i - j}}} +
                         \underset{\mathbf{\textcolor{lightgreen}{C}}}{\textcolor{orange}{v^\top} K_j} +
                         \underset{\mathbf{\textcolor{lightgreen}{D}}}{\textcolor{orange}{S_{i-j}}}
        \end{align}
        """

        # $\textcolor{orange}{R_k}$
        key_pos_emb = self.key_pos_embeddings[self.P - key.shape[0]:self.P + query.shape[0]]
        # $\textcolor{orange}{S_k}$
        key_pos_bias = self.key_pos_bias[self.P - key.shape[0]:self.P + query.shape[0]]
        # $\textcolor{orange}{v^\top}$
        query_pos_bias = self.query_pos_bias[None, None, :, :]

        # ${(\textcolor{lightgreen}{\mathbf{A + C}})}_{i,j} =
        # Q_i^\top K_j +
        # \textcolor{orange}{v^\top} K_j$
        ac = torch.einsum('ibhd,jbhd->ijbh', query + query_pos_bias, key)
        # $\textcolor{lightgreen}{\mathbf{B'}_{i,k}} = Q_i^\top \textcolor{orange}{R_k}$
        b = torch.einsum('ibhd,jhd->ijbh', query, key_pos_emb)
        # $\textcolor{lightgreen}{\mathbf{D'}_{i,k}} = \textcolor{orange}{S_k}$
        d = key_pos_bias[None, :, None, :]
        # Shift the rows of $\textcolor{lightgreen}{\mathbf{(B' + D')}_{i,k}}$
        # to get $$\textcolor{lightgreen}{\mathbf{(B + D)}_{i,j} = \mathbf{(B' + D')}_{i,i - j}}$$
        bd = shift_right(b + d)
        # Remove extra positions
        bd = bd[:, -key.shape[0]:]

        # Return the sum $$
        # \underset{\mathbf{\textcolor{lightgreen}{A}}}{Q_i^\top K_j} +
        # \underset{\mathbf{\textcolor{lightgreen}{B}}}{Q_i^\top \textcolor{orange}{R_{i - j}}} +
        # \underset{\mathbf{\textcolor{lightgreen}{C}}}{\textcolor{orange}{v^\top} K_j} +
        # \underset{\mathbf{\textcolor{lightgreen}{D}}}{\textcolor{orange}{S_{i-j}}}
        # $$
        return ac + bd




In [3]:
class TransformerXLLayer(nn.Module):
    def __init__(self, d_model: int, self_attn: RelativeMultiHeadAttention, dropout_prob: float):
        super().__init__()
        self.size = d_model
        self.self_attn = self_attn
        self.linear = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout_prob)
        self.norm_self_attn = nn.LayerNorm(d_model)
        self.norm_linear = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor, mem: Optional[torch.Tensor], mask: torch.Tensor):
        z = self.norm_self_attn(x)
        if mem is not None:
            mem = self.norm_self_attn(mem)
            
            m_z = torch.cat((mem, z), dim=0)
        else:
            m_z = z
        self_attn = self.self_attn(query=z, key=m_z, value=m_z, mask=mask)
        x = x + self.dropout(self_attn)
        z = self.norm_linear(x)
        linear_out = self.linear(z)
        x = x + self.dropout(linear_out)
        return x

class TransformerXL(nn.Module):
    def __init__(self, layer: TransformerXLLayer, n_layers: int):
        super().__init__()
        self.layers = nn.ModuleList([layer for _ in range(n_layers)])
        self.norm = nn.LayerNorm(layer.size)

    def forward(self, x: torch.Tensor, mem: List[torch.Tensor], mask: torch.Tensor):
        new_mem = []
        for i, layer in enumerate(self.layers):
            new_mem.append(x.detach())
            m = mem[i] if mem else None
            x = layer(x=x, mem=m, mask=mask)
        return self.norm(x), new_mem

class TransformerXLEncoder(nn.Module):
    def __init__(self, num_items, embed_dim, num_layers, num_heads, hidden_dim, mem_length, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(num_items, embed_dim)
        self.mem_length = mem_length
        # print(embed_dim)
        self.transformer = TransformerXL(
            TransformerXLLayer(
                d_model=embed_dim,
                self_attn=RelativeMultiHeadAttention(num_heads,embed_dim, dropout),
                dropout_prob=dropout
            ),
            n_layers=num_layers
        )
        self.linear = nn.Linear(embed_dim, num_items)
    
    def forward(self, x, memory=None):
        x = self.embedding(x)  # Shape: (B, S, D)
        x = x.permute(1, 0, 2)  # Shape: (S, B, D)
        mask = None  # Define mask if needed
        output, new_memory = self.transformer(x, memory, mask)
        logits = self.linear(output)  # Shape: (S, B, num_items)
        return logits.permute(1, 2, 0), new_memory  # Shape: (B, num_items, S)
    
    
#example usage
# x = torch.randint(0, num_items, (batch_size, segment_length), dtype=torch.long).to(device)
# model = TransformerXLEncoder(num_items, embed_dim, num_layers, num_heads, hidden_dim, mem_length).to(device)
# logits, memory = model(x)
# print(logits.shape)


In [4]:
def train_model(model, train_loader, optimizer, criterion, device, num_epochs=10):
    model.train()
    memory = None  # Initialize memory

    for epoch in range(num_epochs):
        total_loss = 0

        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()

            logits, memory = model(inputs, memory)  # Pass both sequence & memory
            
            # print(logits.shape, targets.shape)
            loss = criterion(logits, targets)  # Shape: (B, num_items, S)

            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        print(f"Epoch {epoch+1}, Loss: {total_loss / len(train_loader):.4f}")


In [5]:
def evaluate(model, user_dict, num_items, max_seq_length, segment_length, device):
    model.eval()
    NDCG, HR, valid_users = 0.0, 0.0, 0
    
    for user, items in user_dict.items():
        if len(items) < 2:
            continue
        
        seq = items[-max_seq_length:]
        input_seq = torch.tensor(seq[:-1], dtype=torch.long).unsqueeze(0).to(device)
        target = seq[-1]
        candidates = [target] + random.sample(set(range(1, num_items)) - set(items), 99)
        
        memory = None
        for i in range(0, len(input_seq[0]), segment_length):
            segment = input_seq[:, i:i+segment_length]
            logits, memory = model(segment, memory)
            
        # logits shape is (1, num_items, segment_length)
        scores = logits[0, :, -1]
        scores = scores[candidates]
        # print(scores)
        ranked = torch.argsort(scores, descending=True).cpu().numpy()
        rank = np.where(ranked == 0)[0][0] + 1
        # print(rank)
        valid_users += 1
        
        
        HR += int(rank <= 10)
        NDCG += 1 / np.log2(rank + 1) if rank <= 10 else 0
        
        if valid_users % 10 == 0:
            print(f"Validated users: {valid_users}, HR@10: {HR / valid_users:.4f}, NDCG@10: {NDCG / valid_users:.4f}")
    print(f"HR@10: {HR / valid_users:.4f}, NDCG@10: {NDCG / valid_users:.4f}")
    
    

In [None]:
model = TransformerXLEncoder(num_items, embed_dim, num_layers, num_heads, hidden_dim, mem_length).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
print(model)
num_epochs=100
train_model(model, train_loader, optimizer, criterion, device, num_epochs)


TransformerXLEncoder(
  (embedding): Embedding(3417, 64)
  (transformer): TransformerXL(
    (layers): ModuleList(
      (0-3): 4 x TransformerXLLayer(
        (self_attn): RelativeMultiHeadAttention(
          (query): PrepareForMultiHeadAttention(
            (linear): Linear(in_features=64, out_features=64, bias=False)
          )
          (key): PrepareForMultiHeadAttention(
            (linear): Linear(in_features=64, out_features=64, bias=False)
          )
          (value): PrepareForMultiHeadAttention(
            (linear): Linear(in_features=64, out_features=64, bias=True)
          )
          (softmax): Softmax(dim=1)
          (output): Linear(in_features=64, out_features=64, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (linear): Linear(in_features=64, out_features=64, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (norm_self_attn): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (norm_linear): LayerNorm

In [None]:
evaluate(model, user_dict, num_items, max_seq_length, segment_length, device)

Validated users: 10, HR@10: 0.1000, NDCG@10: 0.0631
Validated users: 20, HR@10: 0.4500, NDCG@10: 0.3122
Validated users: 30, HR@10: 0.5000, NDCG@10: 0.3479
Validated users: 40, HR@10: 0.5500, NDCG@10: 0.3638
Validated users: 50, HR@10: 0.5400, NDCG@10: 0.3911
Validated users: 60, HR@10: 0.5167, NDCG@10: 0.3831
Validated users: 70, HR@10: 0.5143, NDCG@10: 0.3676
Validated users: 80, HR@10: 0.5125, NDCG@10: 0.3646
Validated users: 90, HR@10: 0.4889, NDCG@10: 0.3432
Validated users: 100, HR@10: 0.4900, NDCG@10: 0.3453
Validated users: 110, HR@10: 0.5000, NDCG@10: 0.3371
Validated users: 120, HR@10: 0.5167, NDCG@10: 0.3320
Validated users: 130, HR@10: 0.5077, NDCG@10: 0.3245
Validated users: 140, HR@10: 0.5000, NDCG@10: 0.3229
Validated users: 150, HR@10: 0.5000, NDCG@10: 0.3309
Validated users: 160, HR@10: 0.5062, NDCG@10: 0.3393
Validated users: 170, HR@10: 0.4941, NDCG@10: 0.3319
Validated users: 180, HR@10: 0.4833, NDCG@10: 0.3265
Validated users: 190, HR@10: 0.4895, NDCG@10: 0.3314
Va