In [1]:
TOTAL_ITEM_COUNT = 1495777 # 1495778
VOCAB_SIZE = TOTAL_ITEM_COUNT + 2  # including PAD and MASK
MASK_ID = VOCAB_SIZE - 1  # Add special token for [MASK]
PAD_ID = 0
MAX_SEQUENCE_LEN = 8

NUM_EPOCHS = 1 
import random
import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

import pickle
import torch
import tqdm
from torch import optim
from torch.utils.data import DataLoader, random_split

In [2]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Working on {DEVICE=}')

Working on DEVICE=device(type='cuda')


In [3]:
import torch
from torch import nn
from dataclasses import dataclass


# GPT said this is a professional way to do it
@dataclass
class BERT4RecConfig:
    # EmbeddingLayer
    vocab_size: int
    embedding_dim: int
    max_seq_len: int
    embedding_dropout: float

    # Encoder
    num_layers: int
    num_heads: int
    hidden_dim: int
    encoder_dropout: float

    # ProjectionHead
    projection_dim: int


class EmbeddingLayer(nn.Module):
    """Item + positional embeddings with layer normalization and dropout"""

    def __init__(self, vocab_size: int, embedding_dim: int, max_seq_len: int, dropout: float = 0.1):
        super().__init__()
        self.item_embeddings = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.position_embeddings = nn.Embedding(max_seq_len, embedding_dim)
        self.layer_norm = nn.LayerNorm(embedding_dim)
        self.dropout = nn.Dropout(dropout)
        self.vocab_size = vocab_size

    def forward(self, x):
        positions = torch.arange(x.size(1), device=x.device).unsqueeze(0).expand_as(x)
        # print(x.max(), self.vocab_size, x.min())
        embeddings = self.item_embeddings(x) + self.position_embeddings(positions)
        embeddings = self.layer_norm(embeddings)
        return self.dropout(embeddings)


class Encoder(nn.Module):
    """Transformer encoder. Wrapper for `torch.TransformerEncoderLayer`"""

    def __init__(self, embedding_dim: int, num_layers: int, num_heads: int, hidden_dim: int,
                 dropout: float = 0.1) -> None:
        """

        Args:
            embedding_dim:
            num_layers:
            num_heads:
            hidden_dim:
            dropout:

        Returns:
            None:
        """
        super().__init__()
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embedding_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim,
            dropout=dropout,
            activation="gelu",
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

    def forward(self, x, src_key_padding_mask=None):
        # x: [batch_size, seq_len, embedding_dim]
        return self.transformer(x, src_key_padding_mask=src_key_padding_mask)


class ProjectionHead(nn.Module):
    """Projection head"""

    def __init__(self, embedding_dim: int, projection_dim: int, vocab_size: int):
        super().__init__()
        # transform + layer normalization
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.activation = nn.GELU()
        self.layer_norm = nn.LayerNorm(projection_dim)

        # map to final logits
        self.decoder = nn.Linear(projection_dim, vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(vocab_size))

    def forward(self, x):
        # x: [batch_size, seq_len, embedding_dim]
        proj = self.projection(x)
        proj = self.activation(proj)
        proj = self.layer_norm(proj)
        return self.decoder(proj) + self.bias


class BERT4Rec(nn.Module):
    """BERT4Rec model from `https://arxiv.org/pdf/1904.06690` paper"""

    def __init__(self, config: BERT4RecConfig):
        super().__init__()
        self.embedding = EmbeddingLayer(
            vocab_size=config.vocab_size,
            embedding_dim=config.embedding_dim,
            max_seq_len=config.max_seq_len,
            dropout=config.embedding_dropout,
        )
        self.encoder = Encoder(
            embedding_dim=config.embedding_dim,
            num_layers=config.num_layers,
            num_heads=config.num_heads,
            hidden_dim=config.hidden_dim,
            dropout=config.encoder_dropout,
        )
        self.projection = ProjectionHead(
            embedding_dim=config.embedding_dim,
            projection_dim=config.projection_dim,
            vocab_size=config.vocab_size,
        )

        # weights sharing
        self.projection.decoder.weight = self.embedding.item_embeddings.weight

    def forward(self, x, mask=None):
        pad_mask = x.eq(0)
        x = self.embedding(x)
        x = self.encoder(x, src_key_padding_mask=pad_mask)
        x = self.projection(x)
        return x

In [4]:
def masked_cross_entropy_loss(logits, labels, mask):
    """
    Args:
        logits: [batch_size, seq_len, vocab_size] - output from model
        labels: [batch_size, seq_len] - ground truth item IDs (masked positions are real, others are 0 or ignored)
        mask:   [batch_size, seq_len] - binary mask indicating which positions are masked (1 = predict, 0 = ignore)
    Returns:
        scalar loss
    """
    logits = logits.view(-1, logits.size(-1))  # [B*L, V]
    labels = labels.view(-1)  # [B*L]
    mask = mask.view(-1).float()  # [B*L]

    # Compute per-position loss, but ignore non-masked tokens (mask == 0)
    loss = F.cross_entropy(logits, labels, reduction='none')  # [B*L]
    masked_loss = loss * mask
    return masked_loss.sum() / mask.sum()


def train_step(model, optimizer, batch, device=DEVICE):
    model.train()
    input_ids, labels, masked_pos = [x.to(device) for x in batch]  # all [B, L]

    optimizer.zero_grad()
    logits = model(input_ids)  # [B, L, V]
    loss = masked_cross_entropy_loss(logits, labels, masked_pos)
    loss.backward()
    optimizer.step()

    return loss.item()


def validate_step(model, batch, top_k=10, device=DEVICE):
    model.eval()
    input_ids, labels, masked_pos = [x.to(device) for x in batch]

    with torch.no_grad():
        logits = model(input_ids)  # [B, L, V]

    # Only consider predictions at masked positions
    masked_logits = logits[masked_pos.bool()]  # [N_masked, V]
    masked_labels = labels[masked_pos.bool()]  # [N_masked]

    loss = F.cross_entropy(masked_logits, masked_labels, reduction='mean').item()

    # HR@10 and NDCG@10
    _, topk = masked_logits.topk(top_k, dim=-1)  # [N, top_k]
    hits = (topk == masked_labels.unsqueeze(1)).float()  # [N, top_k]
    hr = hits.any(dim=1).float().mean().item()
    ndcg = (hits / torch.log2(torch.arange(2, top_k + 2, device=hits.device).float())).sum(dim=1).mean().item()

    return loss, hr, ndcg


def mask_sequence(seq, mask_token_id, vocab_size, mask_prob=0.15, pad_token_id=0):
    """
    Function takes a user sequence and randomly selects items to mask.
    Returns:
        - input_ids: modified sequence with [MASK] and others.
        - labels: target items only in masked positions (0 elsewhere).
        - masked_pos: binary mask of masked locations.
    """
    input_ids = seq.copy()
    labels = [0] * len(seq)
    masked_pos = [0] * len(seq)
    # mask_prob = 0.5

    for i in range(len(seq)):
        if seq[i] == pad_token_id:
            continue
        if random.random() < mask_prob:
            masked_pos[i] = 1
            labels[i] = seq[i]
            rand = random.random()
            if rand < 0.8:
                input_ids[i] = mask_token_id
            elif rand < 0.9:
                input_ids[i] = random.randint(1, vocab_size - 1)  # avoid pad token
            else:
                pass  # leave unchanged
    return input_ids, labels, masked_pos


def collate_fn(batch, mask_token_id, vocab_size, pad_token_id=0, max_len=None):
    """
    Pads sequences to same length, converts everything to tensors.
    batch: list of sequences (list of item IDs)
    """
    input_ids, labels, masked_pos = [], [], []

    for seq in batch:
        if max_len is not None:
            seq = seq[-max_len:]  # truncate if needed

        inp, lab, msk = mask_sequence(seq, mask_token_id, vocab_size, pad_token_id=pad_token_id)
        input_ids.append(torch.tensor(inp))
        labels.append(torch.tensor(lab))
        masked_pos.append(torch.tensor(msk))

    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=pad_token_id)
    labels = pad_sequence(labels, batch_first=True, padding_value=0)
    masked_pos = pad_sequence(masked_pos, batch_first=True, padding_value=0)

    return input_ids, labels, masked_pos

In [5]:
TOTAL_ITEM_COUNT = 1495777 # 1495778
VOCAB_SIZE = TOTAL_ITEM_COUNT + 2  # including PAD and MASK
MASK_ID = VOCAB_SIZE - 1  # Add special token for [MASK]
PAD_ID = 0
MAX_SEQUENCE_LEN = 8

NUM_EPOCHS = 1

In [6]:
import pickle 

with open("/kaggle/input/agh-sp2/bert_train_corrected.pkl", "rb") as fh:
    train_sequences = pickle.load(fh)

In [7]:
train_size = int(0.8 * len(train_sequences))
val_size = len(train_sequences) - train_size

train_dataset, val_dataset = random_split(train_sequences, [train_size, val_size])

In [8]:
len(train_dataset) // 256

301

In [9]:
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=256,
    shuffle=True,
    collate_fn=lambda batch: collate_fn(batch, mask_token_id=MASK_ID, vocab_size=VOCAB_SIZE, max_len=MAX_SEQUENCE_LEN),
)

val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=256,
    shuffle=False,
    collate_fn=lambda batch: collate_fn(batch, mask_token_id=MASK_ID, vocab_size=VOCAB_SIZE),
)

In [10]:
import gc

def cleanup_cuda():
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()

In [11]:
cleanup_cuda()
cleanup_cuda()
cleanup_cuda()

In [12]:
config = BERT4RecConfig(
    vocab_size=VOCAB_SIZE + 2,
    embedding_dim=64,
    max_seq_len=MAX_SEQUENCE_LEN,
    embedding_dropout=0.2,
    num_layers=8,
    num_heads=8,
    hidden_dim=128,
    encoder_dropout=0.2,
    projection_dim=64,
)

model = BERT4Rec(config)
model = model.to(DEVICE)
optimizer = optim.AdamW(model.parameters(), lr=1e-3)

In [13]:
train_history = []
valid_history = []

NUM_EPOCHS = 5
for epoch in range(NUM_EPOCHS):
    for batch in tqdm.tqdm(train_loader, desc="Training"):
        loss = train_step(model, optimizer, batch)
        train_history.append(loss)
    
    val_loss, hr10, ndcg10 = validate_step(model, next(iter(val_loader)))
    valid_history.append((val_loss, hr10, ndcg10))
    print(f"Epoch {epoch + 1} | Val Loss: {val_loss:.4f} | HR@10: {hr10:.4f} | NDCG@10: {ndcg10:.4f}")

Training:   0%|          | 0/302 [00:00<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 11.41 GiB. GPU 0 has a total capacity of 15.89 GiB of which 3.74 GiB is free. Process 5924 has 12.15 GiB memory in use. Of the allocated memory 11.84 GiB is allocated by PyTorch, and 15.49 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

def moving_average(data, window_size):
    return np.convolve(data, np.ones(window_size)/window_size, mode='valid')

window_size = 100
y = moving_average(train_history, window_size)
x = np.arange(len(y)) + window_size

plt.plot(x, y, color='blue', alpha=0.5, label='train loss')
plt.axhline(val_loss, color='black', linestyle='--', label='val loss')
plt.legend()
plt.grid(True)
plt.show()