In [9]:
import polars as pl
import torch
from datetime import date, timedelta

In [10]:
# Re-import edited modules automatically before each cell runs (IPython autoreload).
%load_ext autoreload
%autoreload 2

In [11]:
# First day of the test period; the training slice below keeps only dates strictly before it.
TEST_START = date(2024, 7, 1)

In [13]:
# Full user-action log (all action types); the next cell keeps only recent orders from it.
user_actions_full = pl.read_parquet('../data/user_actions_full')

In [14]:
# Orders placed in the 3 x 30 days right before TEST_START: [TEST_START - 90d, TEST_START).
train_orders = (
    user_actions_full
    .filter(
        (pl.col('date') >= TEST_START - timedelta(days=3 * 30))
        & (pl.col('date') < TEST_START)
        & (pl.col('action_type') == 'order')
    )
    .select('user_id', 'product_id', 'date')
)

In [54]:
# Free the full action log; later cells only use `train_orders`.
del user_actions_full

In [22]:
# Dense item ids 1..N assigned in product_id order. Id 0 stays unused, and
# N + 1 (one past the largest id) is used as the padding id in later cells.
id_mapping = (
    train_orders
    .select('product_id')
    .unique()
    .sort('product_id')
    .with_row_index('id', offset=1)
)

In [27]:
# One row per user holding their item ids in date order
# (items ordered on the same date are tie-broken by item id, not by true time).
user_with_ids = (
    train_orders
    .join(id_mapping, on='product_id')
    .group_by('user_id')
    .agg(pl.col('id').sort_by(['date', 'id']).alias('ids'))
)

In [38]:
# Per-user sequence-length statistics before truncation.
(
    user_with_ids
    .select(pl.col('ids').list.len().alias('ids_len'))
    .select(
        pl.col('ids_len').max().alias('max_len'),
        pl.col('ids_len').min().alias('min_len'),
        pl.col('ids_len').quantile(0.5).alias('median_len'),
    )
)

max_len,min_len,median_len
u32,u32,f64
1752,1,5.0


In [49]:
# Keep each user's 64 most recent items, and drop users left with fewer than
# two items (the model needs at least one input position and one target).
user_with_ids_filtered = (
    user_with_ids
    .with_columns(pl.col('ids').list.tail(64))
    .filter(pl.col('ids').list.len() >= 2)
)

In [50]:
# The same length statistics after truncation to 64 items and filtering.
(
    user_with_ids_filtered
    .select(pl.col('ids').list.len().alias('ids_len'))
    .select(
        pl.col('ids_len').max().alias('max_len'),
        pl.col('ids_len').min().alias('min_len'),
        pl.col('ids_len').quantile(0.5).alias('median_len'),
    )
)

max_len,min_len,median_len
u32,u32,f64
64,2,8.0


In [52]:
# (number of users kept, number of columns)
user_with_ids_filtered.shape

(713166, 2)

In [53]:
# (number of distinct items, number of columns)
id_mapping.shape

(53755, 2)

![SASRec and BERT4Rec](pic/sasrec_bert4rec.png)

In [60]:
# Item-id sequence of the first user in the filtered frame.
user_with_ids_filtered[0]['ids'].to_list()[0]

[30136, 39889]

In [129]:
# NOTE(review): same expression as an earlier cell; the item count is reused just below.
id_mapping.shape

(53755, 2)

In [130]:
# Number of real items: ids run 1..num_items, and num_items + 1 is used as the padding id.
num_items = id_mapping.shape[0]

In [55]:
from torch.utils.data import DataLoader, Dataset

In [136]:
class Orders(Dataset):
    """Fixed-length item-id sequences, one per row of ``orders_df``.

    Every sample is a 1-D ``int64`` tensor of exactly ``max_len`` ids, so samples
    can always be stacked into a batch:
      * shorter sequences are left-padded with ``pad_value``;
      * longer sequences keep only their last ``max_len`` ids.

    Args:
        orders_df: frame with a list column ``ids`` (one id list per row).
        max_len: length of every returned sequence.
        pad_value: id used for left padding.
    """

    def __init__(self, orders_df: 'pl.DataFrame', max_len: int, pad_value: int):
        super().__init__()
        self.orders_df = orders_df
        self.max_len = max_len
        self.pad_value = pad_value

    def __len__(self):
        return len(self.orders_df)

    def __getitem__(self, idx):
        # Indexing a polars frame with an int yields a one-row frame; its 'ids'
        # cell comes back as a one-element list holding the id list.
        ids = self.orders_df[idx]['ids'].to_list()[0]
        if len(ids) > self.max_len:
            # Keep the tail; slicing by explicit start also behaves for max_len == 0.
            ids = ids[len(ids) - self.max_len:]
        elif len(ids) < self.max_len:
            ids = [self.pad_value] * (self.max_len - len(ids)) + ids
        return torch.tensor(ids, dtype=torch.int64)

In [137]:
# NOTE(review): `orders` is not used by the later cells shown; get_train_dataloader builds its own Orders.
orders = Orders(user_with_ids_filtered, max_len=64+1, pad_value=num_items+1)

In [138]:
def collate_train(input_batch, pad_value, num_negatives):
    """Stack equal-length id sequences and draw uniform random negatives.

    Negatives are sampled independently per position from ``[1, pad_value)``.
    No collision check is done, so a negative may equal the positive or any
    other id in the sequence.

    Args:
        input_batch: list of 1-D id tensors of equal length.
        pad_value: exclusive upper bound for sampled negative ids.
        num_negatives: number of negatives per position.

    Returns:
        ``[positives, negatives]`` with shapes ``(B, L)`` and ``(B, L, num_negatives)``.
    """
    positives = torch.stack(list(input_batch), dim=0)
    batch_size, seq_len = positives.size(0), positives.size(1)
    negatives = torch.randint(low=1, high=pad_value, size=(batch_size, seq_len, num_negatives))
    return [positives, negatives]

In [139]:
def get_train_dataloader(
        orders_df: pl.DataFrame,
        pad_value: int,
        batch_size=32,
        max_len=64,
        train_neg_per_positive=256,
        num_workers=0,
):
    """Build a shuffled training DataLoader of ``[positives, negatives]`` batches.

    Each sample holds ``max_len + 1`` ids, presumably so that an input of
    ``[:, :-1]`` and a next-item target of ``[:, 1:]`` are both ``max_len`` long
    (that is how the batches are sliced later in this notebook).

    Args:
        orders_df: frame with a list column ``ids``.
        pad_value: padding id; also the exclusive upper bound for negative ids.
        batch_size: sequences per batch.
        max_len: model input length (samples are one id longer).
        train_neg_per_positive: negatives sampled per position.
        num_workers: DataLoader worker processes (0 = load in the main process).

    Returns:
        A ``torch.utils.data.DataLoader``.
    """
    import functools

    train_dataset = Orders(orders_df, max_len=max_len + 1, pad_value=pad_value)
    # functools.partial is picklable (a lambda is not), so worker processes can be
    # used even with the "spawn" start method.
    collate_fn = functools.partial(
        collate_train, pad_value=pad_value, num_negatives=train_neg_per_positive
    )
    return DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        collate_fn=collate_fn,
    )

In [143]:
# 256 sequences per batch, 64 input positions (+1 id for the shifted target),
# 128 sampled negatives per position; padding id matches the `orders` cell.
train_dataloader = get_train_dataloader(
    orders_df=user_with_ids_filtered,
    pad_value=num_items + 1,
    max_len=64,
    batch_size=256,
    train_neg_per_positive=128,
)

In [144]:
# Fetch a single batch for inspection.
batch = next(iter(train_dataloader))

In [145]:
# First sequence of the batch: left-padded with num_items + 1, most recent ids at the end.
batch[0][0]

tensor([53756, 53756, 53756, 53756, 53756, 53756, 53756, 53756, 53756, 53756,
        53756, 53756, 53756, 53756, 53756, 53756, 53756, 53756, 53756, 53756,
        53756, 53756, 53756, 53756, 53756, 53756, 53756, 53756, 53756, 53756,
        53756, 53756, 53756, 53756, 53756, 53756, 53756, 53756, 53756, 53756,
        53756, 53756, 53756, 53756, 53756, 53756, 53756, 53756, 53756, 53756,
        53756, 53756, 53756, 53756, 53756, 53756, 53756, 53756,  2000,  6962,
         7210, 21426, 21708,  6256, 10238])

In [146]:
# Sampled negatives for the first sequence: one row of negative ids per position.
batch[1][0]

tensor([[38577, 18167, 16791,  ..., 10163, 28285, 42310],
        [17320, 27711,  7356,  ..., 43424, 44999, 37190],
        [26808, 24614, 47072,  ..., 21841, 45173,   912],
        ...,
        [39255, 32103, 29848,  ..., 33937,  9083, 15550],
        [30248, 21747, 15466,  ...,    83, 53131, 26752],
        [43666, 42889, 33373,  ..., 38804, 43915, 15656]])

In [147]:
import torch.nn as nn
import torch.nn.functional as F


class MultiHeadAttention(nn.Module):
    """SASRec-style multi-head attention with zero-vector padding masks.

    Heads are formed by chunking the model dimension into ``num_heads`` slices
    and concatenating the slices along the batch axis, so per-head tensors have
    a leading dimension of ``num_heads * batch``.

    A position counts as padding when its raw input vector is all zeros. Such
    positions are masked out as keys (before softmax), and their rows are
    zeroed as queries (after softmax).

    Args:
        dim: model dimension; presumably must be divisible by ``num_heads``
            (``chunk`` would otherwise yield unequal slices) — TODO confirm.
        num_heads: number of attention heads.
        dropout_rate: dropout applied to the attention weights.
    """

    def __init__(self, dim, num_heads, dropout_rate=0.5):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.query_proj = nn.Linear(dim, dim)
        self.key_proj = nn.Linear(dim, dim)
        self.val_proj = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(dropout_rate)  # applied to the attention weights

    def _to_heads(self, x):
        """(B, T, D) -> (H*B, T, D/H): feature slices stacked head-major on dim 0."""
        return torch.cat(x.chunk(self.num_heads, dim=2), dim=0)

    def _from_heads(self, x):
        """Inverse of ``_to_heads``: (H*B, T, D/H) -> (B, T, D)."""
        return torch.cat(x.chunk(self.num_heads, dim=0), dim=2)

    def _nonzero_flags(self, x):
        """1.0 where a position's vector has any non-zero entry, else 0.0; tiled per head to (H*B, T)."""
        return torch.sign(torch.abs(x).sum(dim=-1)).repeat(self.num_heads, 1)

    def forward(self, queries, keys, causality=False):
        """Attend ``queries`` over ``keys``.

        Returns:
            ``(output, attention_weights)``. ``output`` has the shape of
            ``queries``. ``attention_weights`` is (B, H, T_q, T_k) and is taken
            before dropout.
        """
        q_heads = self._to_heads(self.query_proj(queries))
        k_heads = self._to_heads(self.key_proj(keys))
        v_heads = self._to_heads(self.val_proj(keys))

        head_dim = k_heads.size(-1)
        scores = torch.matmul(q_heads, k_heads.transpose(1, 2)) / (head_dim ** 0.5)

        # Padding keys never receive attention; the (H*B, 1, T_k) mask broadcasts over queries.
        key_is_pad = self._nonzero_flags(keys).unsqueeze(1) == 0
        scores = scores.masked_fill(key_is_pad, float('-inf'))

        if causality:
            # Forbid attention to later positions (strict upper triangle).
            allowed = torch.tril(torch.ones_like(scores[0]))
            scores = scores.masked_fill(allowed == 0, float('-inf'))

        # Rows whose keys are all masked give NaN from softmax; turn them into zeros.
        weights = torch.nan_to_num(F.softmax(scores, dim=-1), nan=0.0, posinf=0.0, neginf=0.0)

        # Zero the rows that belong to padding queries.
        weights = weights * self._nonzero_flags(queries).unsqueeze(-1)

        attention_weights = torch.stack(weights.chunk(self.num_heads, dim=0), dim=1)

        mixed = torch.matmul(self.dropout(weights), v_heads)
        return self._from_heads(mixed), attention_weights


class TransformerBlock(nn.Module):
    """One SASRec encoder layer: attention over a normalized sequence, then a feed-forward net.

    Data flow as implemented:
        q   = first_norm(seq)
        h   = second_norm(attention(q, seq) + q)     # residual from the normalized queries
        out = h + dropout(dense2(dropout(relu(dense1(h)))))
        out = out * mask                             # only when mask is given

    Args:
        dim: model dimension.
        num_heads: number of attention heads.
        hidden_dim: width of the feed-forward hidden layer.
        dropout_rate: used for the attention weights and both feed-forward dropouts.
        causality: if True, each position attends only to itself and earlier positions.
    """

    def __init__(self, dim, num_heads, hidden_dim, dropout_rate=0.5, causality=True):
        super(TransformerBlock, self).__init__()
        # Submodules are created in the same order as before, so parameter init
        # and state_dict layout are unchanged.
        self.first_norm = nn.LayerNorm(dim)
        self.second_norm = nn.LayerNorm(dim)
        self.multihead_attention = MultiHeadAttention(dim, num_heads, dropout_rate)
        self.dense1 = nn.Linear(dim, hidden_dim)
        self.dense2 = nn.Linear(hidden_dim, dim)
        self.dropout = nn.Dropout(dropout_rate)
        self.causality = causality

    def forward(self, seq, mask=None):
        """Return ``(output, attention_weights)``; ``mask``, if given, is multiplied into the output."""
        normed = self.first_norm(seq)
        # Keys use the raw (un-normalized) sequence; MultiHeadAttention treats
        # all-zero key vectors as padding.
        attended, attention_weights = self.multihead_attention(normed, seq, self.causality)
        hidden = self.second_norm(attended + normed)

        ffn_out = self.dense2(self.dropout(F.relu(self.dense1(hidden))))
        out = self.dropout(ffn_out) + hidden

        if mask is not None:
            out = out * mask
        return out, attention_weights

In [148]:
import torch 

class SASRec(torch.nn.Module):
    """Self-attentive sequential recommender (SASRec) encoder.

    Item ids are integers with ``num_items + 1`` reserved as the padding id;
    positions holding it are zeroed out. Both embedding tables have
    ``num_items + 2`` rows (ids 0..num_items + 1). ``get_predictions`` never
    ranks id 0 or ids above ``num_items``.

    Args:
        num_items: number of real items.
        sequence_length: size of the position-embedding table; inputs may not be longer.
        embedding_dim: model dimension (also the feed-forward hidden width).
        num_heads: attention heads per block.
        num_blocks: number of stacked TransformerBlocks.
        dropout_rate: dropout for the embeddings and inside every block.
        reuse_item_embeddings: if True, the input item table is also used to score outputs.
    """

    def __init__(
            self,
            num_items,
            sequence_length=64,
            embedding_dim=256,
            num_heads=4,
            num_blocks=3,
            dropout_rate=0.5,
            reuse_item_embeddings=False
    ):
        super(SASRec, self).__init__()
        self.num_items = num_items
        self.sequence_length = sequence_length
        self.embedding_dim = embedding_dim
        self.embeddings_dropout = torch.nn.Dropout(dropout_rate)

        self.num_heads = num_heads

        self.item_embedding = torch.nn.Embedding(self.num_items + 2, self.embedding_dim)
        self.position_embedding = torch.nn.Embedding(self.sequence_length, self.embedding_dim)

        self.transformer_blocks = torch.nn.ModuleList([
            TransformerBlock(self.embedding_dim, self.num_heads, self.embedding_dim, dropout_rate)
            for _ in range(num_blocks)
        ])
        self.seq_norm = torch.nn.LayerNorm(self.embedding_dim)
        self.reuse_item_embeddings = reuse_item_embeddings
        if not self.reuse_item_embeddings:
            self.output_embedding = torch.nn.Embedding(self.num_items + 2, self.embedding_dim)

    def get_output_embeddings(self) -> torch.nn.Embedding:
        """Embedding table used to score candidate items."""
        if self.reuse_item_embeddings:
            return self.item_embedding
        else:
            return self.output_embedding

    def forward(self, input: torch.Tensor):
        """Encode a batch of item-id sequences.

        Args:
            input: integer tensor (batch, seq_len) with seq_len <= sequence_length.

        Returns:
            ``(seq_emb, attentions)``: seq_emb is (batch, seq_len, embedding_dim);
            attentions holds one attention-weight tensor per block.

        Raises:
            IndexError: if seq_len exceeds the position-embedding table.
        """
        seq_len = input.size(1)
        if seq_len > self.sequence_length:
            raise IndexError(
                f"input has {seq_len} positions but the position embedding covers only "
                f"{self.sequence_length}; truncate the input or increase sequence_length"
            )

        seq = self.item_embedding(input.long())
        # 1.0 for real items, 0.0 at the padding id; broadcasts over the embedding dim.
        mask = (input != self.num_items + 1).float().unsqueeze(-1)

        # (seq_len, dim) position embeddings broadcast over the batch.
        positions = torch.arange(seq_len, device=input.device)
        seq = seq + self.position_embedding(positions)
        seq = self.embeddings_dropout(seq)
        seq = seq * mask

        attentions = []
        for block in self.transformer_blocks:
            seq, attention = block(seq, mask)
            attentions.append(attention)

        seq_emb = self.seq_norm(seq)
        return seq_emb, attentions

    def get_predictions(self, input, limit, rated=None):
        """Top-``limit`` items scored from the hidden state at the last position.

        The model is switched to eval mode for the call, so dropout is off at
        inference. Afterwards it is restored to its previous train/eval mode.

        Args:
            input: (batch, seq_len) item-id sequences.
            limit: number of items returned per row.
            rated: optional per-row iterables of item ids to exclude.

        Returns:
            ``(indices, values)`` from ``torch.topk``, each (batch, limit).
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                model_out, _ = self.forward(input)
                seq_emb = model_out[:, -1, :]
                output_embeddings = self.get_output_embeddings()
                scores = torch.einsum('bd,nd->bn', seq_emb, output_embeddings.weight)
                # Exclude id 0 and the padding id from ranking.
                scores[:, 0] = float("-inf")
                scores[:, self.num_items + 1:] = float("-inf")
                if rated is not None:
                    for i in range(len(input)):
                        for j in rated[i]:
                            scores[i, j] = float("-inf")
                result = torch.topk(scores, limit, dim=1)
                return result.indices, result.values
        finally:
            self.train(was_training)

In [149]:
# Display the DataLoader object (repr only).
train_dataloader

<torch.utils.data.dataloader.DataLoader at 0x3184b2680>

In [154]:
# sequence_length is the position-table size; it must cover the 64-id model
# input (the batch minus its last, target-only column).
model = SASRec(
    num_items=num_items,
    sequence_length=64,
    embedding_dim=128,
    num_heads=4,
    num_blocks=3,
    dropout_rate=0.0,
    reuse_item_embeddings=True,
)
    

In [152]:
from tqdm import tqdm  # NOTE(review): not used by the cells shown

device = "cpu"  # NOTE(review): not used by the cells shown; the model is never moved to it
optimiser = torch.optim.Adam(model.parameters())  # Adam with default hyper-parameters


In [155]:
# Print the module tree to check layer sizes.
model

SASRec(
  (embeddings_dropout): Dropout(p=0.0, inplace=False)
  (item_embedding): Embedding(53757, 128)
  (position_embedding): Embedding(64, 128)
  (transformer_blocks): ModuleList(
    (0-2): 3 x TransformerBlock(
      (first_norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (second_norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (multihead_attention): MultiHeadAttention(
        (query_proj): Linear(in_features=128, out_features=128, bias=True)
        (key_proj): Linear(in_features=128, out_features=128, bias=True)
        (val_proj): Linear(in_features=128, out_features=128, bias=True)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (dense1): Linear(in_features=128, out_features=128, bias=True)
      (dense2): Linear(in_features=128, out_features=128, bias=True)
      (dropout): Dropout(p=0.0, inplace=False)
    )
  )
  (seq_norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
)

In [156]:
# Draw one fresh batch for the manual training step below.
batch = next(iter(train_dataloader))

In [157]:
# positives: (batch, max_len + 1) left-padded ids; negatives: (batch, max_len + 1, num_negatives).
positives, negatives = batch

In [158]:
# Model input: every position except the last, which only serves as a target.
model_input = positives[:, :-1]

In [160]:
last_hidden_state, attentions = model(model_input)
# Next-item targets: labels[:, i] is the id that follows model_input[:, i].
labels = positives[:, 1:]

In [161]:
# (batch, seq_len, embedding_dim)
last_hidden_state.shape

torch.Size([256, 64, 128])

In [166]:
# Drop the first position so the negatives line up with `labels` (max_len positions).
# Slicing from `batch` rather than from `negatives` itself keeps this cell idempotent:
# re-running it no longer shrinks the tensor again.
negatives = batch[1][:, 1:, :]

In [164]:
# Labels with a trailing candidate axis, ready to be concatenated with the negatives.
labels.unsqueeze(-1).shape

torch.Size([256, 64, 1])

In [165]:
# NOTE: the recorded output (65 positions) was produced before the slicing cell
# (In [166]) ran; after slicing this is (batch, 64, num_negatives).
negatives.shape

torch.Size([256, 65, 128])

In [167]:
# Candidate ids per position: the true next item at index 0, then the sampled negatives.
pos_neg_concat = torch.cat([labels.unsqueeze(-1), negatives], dim=-1)
output_embeddings = model.get_output_embeddings()

In [168]:
# Embeddings of all candidates: (batch, seq_len, 1 + num_negatives, embedding_dim).
pos_neg_embeddings = output_embeddings(pos_neg_concat)

In [169]:
# Number of real items (the padding id is num_items + 1).
num_items

53755

In [170]:
# 1.0 where the input position holds a real item, 0.0 at padding; weights the loss below.
mask = (model_input != num_items + 1).float()

In [171]:
# Left padding shows up as leading zeros in each row.
mask

tensor([[0., 0., 0.,  ..., 0., 1., 1.],
        [0., 0., 0.,  ..., 1., 1., 1.],
        [0., 0., 0.,  ..., 1., 1., 1.],
        ...,
        [0., 0., 0.,  ..., 1., 1., 1.],
        [0., 0., 0.,  ..., 1., 1., 1.],
        [0., 0., 0.,  ..., 1., 1., 1.]])

In [172]:
# Dot product of each position's hidden state with every candidate embedding -> (batch, seq_len, 1 + num_negatives).
logits = torch.einsum('bse, bsne -> bsn', last_hidden_state, pos_neg_embeddings)

In [174]:
# Operand shapes of the einsum that produces `logits`.
last_hidden_state.shape, pos_neg_embeddings.shape

(torch.Size([256, 64, 128]), torch.Size([256, 64, 129, 128]))

In [173]:
# (batch, seq_len, 1 + num_negatives)
logits.shape

torch.Size([256, 64, 129])

In [175]:
# Binary targets: candidate 0 (the true next item) is 1, every sampled negative is 0.
gt = torch.zeros_like(logits)
gt[..., 0] = 1

In [176]:
# gSASRec calibration. alpha is the negative-sampling rate; t interpolates beta
# between 1 (t = 0, plain BCE) and alpha (t = 1).
# The negative count is read from the batch instead of hard-coding 128, so it
# cannot drift from `train_neg_per_positive`.
num_negatives = negatives.size(-1)
alpha = num_negatives / (num_items - 1)
t = 0.75
beta = alpha * ((1 - 1/alpha)*t + 1/alpha)

In [177]:
# gBCE: the positive logit s is replaced by s' with sigmoid(s') = sigmoid(s) ** beta;
# the negatives are unchanged. float64 is used for numerical stability.
positive_logits = logits[:, :, 0:1].to(torch.float64)
negative_logits = logits[:, :, 1:].to(torch.float64)
eps = 1e-10
positive_probs = torch.clamp(torch.sigmoid(positive_logits), eps, 1 - eps)
positive_probs_adjusted = torch.clamp(positive_probs.pow(-beta), 1 + eps, torch.finfo(torch.float64).max)
to_log = torch.clamp(torch.div(1.0, (positive_probs_adjusted - 1)), eps, torch.finfo(torch.float64).max)
positive_logits_transformed = to_log.log()
# A new name instead of overwriting `logits`, so re-running this cell does not
# transform the positive column a second time.
gbce_logits = torch.cat([positive_logits_transformed, negative_logits], -1)
loss_per_element = torch.nn.functional.binary_cross_entropy_with_logits(gbce_logits, gt, reduction='none').mean(-1) * mask
# Mean over non-padding positions only.
loss = loss_per_element.sum() / mask.sum()

In [178]:
# Masked mean gBCE loss over non-padding positions.
loss

tensor(4.5457, grad_fn=<DivBackward0>)

In [None]:
# Accumulator for batch losses; updated by the training-step cell.
loss_sum = 0

In [None]:
# One optimisation step on the single batch above (no training loop exists yet).
loss.backward()
optimiser.step()
optimiser.zero_grad()
loss_sum += loss.item()

References:
- RecTools transformers tutorial: https://github.com/MobileTeleSystems/RecTools/blob/main/examples/tutorials/transformers_tutorial.ipynb
- gSASRec PyTorch implementation: https://github.com/asash/gSASRec-pytorch/blob/main/README.md