In [1]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor
import torch.nn.functional as F
import matplotlib.pyplot as plt
from collections import defaultdict
from sklearn.metrics import ndcg_score

  from .autonotebook import tqdm as notebook_tqdm


In [49]:
data_dir = './ml-1m'

In [3]:
movies = pd.read_csv(
    f'{data_dir}/movies.dat',sep='::',
    engine='python',  # Use 'python' engine to handle the separator correctly
    encoding='latin-1',  # Important for handling special characters
    header=None,
    names=['MovieID', 'Title', 'Genres']
)

In [4]:
movies.head()

Unnamed: 0,MovieID,Title,Genres
0,1,Toy Story (1995),Animation|Children's|Comedy
1,2,Jumanji (1995),Adventure|Children's|Fantasy
2,3,Grumpier Old Men (1995),Comedy|Romance
3,4,Waiting to Exhale (1995),Comedy|Drama
4,5,Father of the Bride Part II (1995),Comedy


In [5]:
users = pd.read_csv(
    f'{data_dir}/users.dat',sep='::',
    engine='python',  # Use 'python' engine to handle the separator correctly
    encoding='latin-1',  # Important for handling special characters
    header=None,
    names=['UserID', 'Gender', 'Age', 'Occupation', 'Zip-code']
)

In [6]:
users.head()

Unnamed: 0,UserID,Gender,Age,Occupation,Zip-code
0,1,F,1,10,48067
1,2,M,56,16,70072
2,3,M,25,15,55117
3,4,M,45,7,2460
4,5,M,25,20,55455


In [7]:
ratings = pd.read_csv(
    f'{data_dir}/ratings.dat',sep='::',
    engine='python',  # Use 'python' engine to handle the separator correctly
    encoding='latin-1',  # Important for handling special characters
    header=None,
    names=['UserID', 'MovieID', 'Rating', 'Timestamp']
)

In [8]:
ratings.head()

Unnamed: 0,UserID,MovieID,Rating,Timestamp
0,1,1193,5,978300760
1,1,661,3,978302109
2,1,914,3,978301968
3,1,3408,4,978300275
4,1,2355,5,978824291


In [9]:
ratings.sort_values(by='Rating').head()

Unnamed: 0,UserID,MovieID,Rating,Timestamp
155787,1004,2643,1,975043037
955328,5763,535,1,959535843
885948,5350,42,1,960660817
517012,3191,880,1,968653120
517011,3191,1556,1,968653144


In [10]:
ratings.shape

(1000209, 4)

In [11]:
group = list(ratings.groupby('UserID').groups)

In [12]:
len(group)

6040

In [13]:
ratings.UserID.nunique()

6040

In [14]:
ratings.groupby('UserID')['MovieID'].count().sort_values(ascending=False)

UserID
4169    2314
1680    1850
4277    1743
1941    1595
1181    1521
        ... 
5725      20
3407      20
1664      20
4419      20
3021      20
Name: MovieID, Length: 6040, dtype: int64

In [15]:
ratings.loc[ratings.UserID==4169]

Unnamed: 0,UserID,MovieID,Rating,Timestamp
695642,4169,3789,5,965333672
695643,4169,571,4,973310265
695644,4169,574,3,975805232
695645,4169,575,3,976589949
695646,4169,577,3,988324145
...,...,...,...,...
697951,4169,3784,2,965333606
697952,4169,3785,3,980476924
697953,4169,2047,3,971579815
697954,4169,3788,3,965333481


In [20]:
def generate_encoder_decoder(col) -> dict:
        """
        encoder, decoder

        Args:
            col (str): columns
        Returns:
            dict: encoder, decoder
        """

        encoder = {}
        decoder = {}
        ids = col.unique()

        for idx, _id in enumerate(ids):
            encoder[_id] = idx
            decoder[idx] = _id

        return encoder, decoder

In [34]:
def generate_sequence_data() -> dict:
        """
        sequence_data

        Returns:
            dict: train user sequence / valid user sequence
        """
        users = defaultdict(list)
        user_train = {}
        genres_seq = {}
        user_valid = {}
        group_df = df.groupby('user_idx')
        for user, item in group_df:
            users[user].extend(item['item_idx'].tolist())
            
        
        for user in users:
            user_train[user] = users[user][:-1]
            user_valid[user] = users[user][-1]

        return user_train, user_valid

In [35]:
user_encoder, user_decoder = generate_encoder_decoder(ratings['UserID'])
item_encoder, item_decoder = generate_encoder_decoder(movies['MovieID'])

In [37]:
df=ratings.copy()
df['item_idx'] = df['MovieID'].apply(lambda x : item_encoder[x] + 1)
df['user_idx'] = df['UserID'].apply(lambda x : user_encoder[x])
        
df = df.sort_values(['user_idx', 'Timestamp']) 
        
print('Generate sequence data...')
user_train, user_valid = generate_sequence_data()

Generate sequence data...


In [40]:
user_train[0], user_valid[0]

([3118,
  1251,
  1673,
  1010,
  2272,
  1769,
  3340,
  2736,
  1190,
  1177,
  712,
  258,
  908,
  605,
  2624,
  1893,
  1960,
  3037,
  927,
  1023,
  1894,
  1950,
  149,
  1016,
  1082,
  903,
  1268,
  2729,
  2694,
  1227,
  656,
  2850,
  528,
  3046,
  2723,
  2253,
  1017,
  1180,
  591,
  2330,
  1507,
  524,
  592,
  2619,
  736,
  585,
  1,
  2287,
  2226,
  774,
  1527,
  1839],
 48)

In [43]:
df.loc[df.UserID==1]

Unnamed: 0,UserID,MovieID,Rating,Timestamp,item_idx,user_idx
31,1,3186,4,978300019,3118,0
22,1,1270,5,978300055,1251,0
27,1,1721,4,978300055,1673,0
37,1,1022,5,978300055,1010,0
24,1,2340,3,978300103,2272,0
36,1,1836,5,978300172,1769,0
3,1,3408,4,978300275,3340,0
7,1,2804,5,978300719,2736,0
47,1,1207,4,978300719,1190,0
0,1,1193,5,978300760,1177,0


In [233]:
ratings.MovieID.nunique()

3706

In [277]:
_B, _T, _D, _K = 4, 8, 16, 3
_logits = torch.rand(size=(_B,_T,_D))
_labels = torch.randint(1,10,size=(_B, _T))
print (_labels.shape)
print (_labels)
_mask_pos = torch.rand(size=(_T,))<0.2
# _mask_pos = torch.tensor([False]*(_T-1)+[True])
print (_mask_pos)
_mask_pos = torch.where(_mask_pos)[0]
print (_mask_pos)
_mask_pos = _mask_pos.expand(_B, -1)
print (_mask_pos)
print (_mask_pos.shape)
_logits = _logits.gather(1, _mask_pos.unsqueeze(-1).expand(-1, -1, _D))
print (_logits.shape)
print (_logits.topk(5, dim=-1).indices.cpu().tolist())
_labels = _labels.gather(1, _mask_pos)
print (_labels)
_neg_sample = torch.randint(1,10,size=(_B, 2))
print  (_neg_sample)
_candidates = torch.cat((_labels, _neg_sample), dim=1)
print (_candidates)

torch.Size([4, 8])
tensor([[1, 8, 7, 1, 1, 4, 5, 6],
        [8, 8, 5, 2, 6, 1, 9, 8],
        [9, 4, 2, 6, 2, 7, 7, 1],
        [5, 5, 5, 6, 2, 7, 1, 2]])
tensor([False, False, False, False,  True,  True, False, False])
tensor([4, 5])
tensor([[4, 5],
        [4, 5],
        [4, 5],
        [4, 5]])
torch.Size([4, 2])
torch.Size([4, 2, 16])
[[[5, 6, 8, 0, 9], [5, 1, 7, 3, 0]], [[12, 2, 15, 7, 4], [0, 2, 1, 8, 7]], [[10, 14, 13, 9, 15], [11, 9, 4, 14, 6]], [[0, 12, 4, 11, 1], [10, 1, 15, 0, 13]]]
tensor([[1, 4],
        [6, 1],
        [2, 7],
        [2, 7]])
tensor([[1, 1],
        [3, 2],
        [2, 9],
        [3, 6]])
tensor([[1, 4, 1, 1],
        [6, 1, 3, 2],
        [2, 7, 2, 9],
        [2, 7, 3, 6]])


In [278]:
_mask = torch.tensor([
    [False, True, False, True],
    [True, False, False, True]
])
print (_mask)
print (torch.where(_mask))

tensor([[False,  True, False,  True],
        [ True, False, False,  True]])
(tensor([0, 0, 1, 1]), tensor([1, 3, 0, 3]))


In [275]:
_mask = torch.rand(size=(_B, _T))<0.3
print (_mask)
print (torch.where(_mask))
_labels = torch.randint(1,10,size=(_B, _T))
print (_labels)
print (_labels.gather(1, _mask))

tensor([[False,  True, False, False, False,  True, False,  True],
        [False, False, False, False, False,  True,  True, False],
        [ True, False, False, False, False, False, False, False],
        [False,  True, False, False, False, False, False, False]])
(tensor([0, 0, 0, 1, 1, 2, 3]), tensor([1, 5, 7, 5, 6, 0, 1]))
tensor([[5, 4, 6, 2, 6, 6, 9, 7],
        [6, 1, 4, 1, 3, 7, 6, 2],
        [5, 3, 2, 9, 9, 1, 1, 4],
        [6, 6, 3, 6, 7, 7, 4, 1]])


RuntimeError: gather(): Expected dtype int64 for index

In [292]:
_arr = torch.randint(1,10,size=(4,8))
print (_arr)
_arr = _arr[:,-1].unsqueeze(-1)
print (_arr)
_neg_sample = torch.randint(1,10,size=(4, 2))
print (_neg_sample)
print (torch.cat((_arr, _neg_sample), dim=1))

tensor([[6, 5, 2, 8, 5, 7, 9, 2],
        [8, 6, 8, 3, 2, 4, 4, 8],
        [7, 6, 5, 3, 5, 5, 3, 2],
        [6, 2, 5, 1, 2, 1, 8, 3]])
tensor([[2],
        [8],
        [2],
        [3]])
tensor([[8, 4],
        [4, 2],
        [1, 5],
        [3, 7]])
tensor([[2, 8, 4],
        [8, 4, 2],
        [2, 1, 5],
        [3, 3, 7]])


## Bert4rec

In [1]:
import os
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor
import torch.nn.functional as F
import matplotlib.pyplot as plt
from collections import defaultdict
from sklearn.metrics import ndcg_score
import evaluate

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

cpu


## Data

In [3]:
data_dir = './ml-1m'

In [4]:
ratings = pd.read_csv(
    f'{data_dir}/ratings.dat',sep='::',
    engine='python',  # Use 'python' engine to handle the separator correctly
    encoding='latin-1',  # Important for handling special characters
    header=None,
    names=['UserID', 'MovieID', 'Rating', 'Timestamp']
)

In [5]:
display (ratings.head())
ratings.shape

Unnamed: 0,UserID,MovieID,Rating,Timestamp
0,1,1193,5,978300760
1,1,661,3,978302109
2,1,914,3,978301968
3,1,3408,4,978300275
4,1,2355,5,978824291


(1000209, 4)

In [6]:
len(ratings.MovieID.unique())

3706

In [7]:
movie_idx = {}
for idx, movie in enumerate(ratings.MovieID.unique()):
    movie_idx[movie] = idx+1
ratings['movie_idx'] = ratings.MovieID.apply(lambda x: movie_idx[x])
ratings.head()

Unnamed: 0,UserID,MovieID,Rating,Timestamp,movie_idx
0,1,1193,5,978300760,1
1,1,661,3,978302109,2
2,1,914,3,978301968,3
3,1,3408,4,978300275,4
4,1,2355,5,978824291,5


In [8]:
data = ratings.groupby('UserID')[['movie_idx','Timestamp']].apply(lambda x: x.sort_values('Timestamp').movie_idx.agg(list)).to_dict()

In [9]:
len(data[1]), data[1][:10], data[1][-10:]

(53,
 [32, 28, 38, 23, 25, 37, 4, 48, 8, 22],
 [34, 17, 30, 11, 31, 36, 5, 35, 33, 26])

In [10]:
low_rated_movies = np.where(ratings.groupby('movie_idx').Rating.max()<=2)[0]
len(low_rated_movies)

78

In [12]:
popular_movies = np.where(ratings.groupby('movie_idx').Rating.max()>=4)[0]
len(popular_movies)

3533

In [13]:
vocab_size = max(movie_idx.values())+1
print (vocab_size)

3707


In [20]:
class Bert4RecDataset(Dataset):
    def __init__(self, data, low_rated_movies, context_len, mask_prob, split):
        super(Bert4RecDataset, self).__init__()
        self.context_len = context_len
        self.mask_prob = mask_prob
        self.data = data
        self.split = split
    def __len__(self):
        return len(self.data)
    def get_negative_sample(self, user):
        movies_watched = self.data.get(user+1, [])
        negatives = popular_movies[~np.isin(popular_movies, movies_watched)]
        neg_sample = torch.tensor(np.random.choice(negatives, size=50, replace=False), dtype=torch.long)
        return neg_sample
        
    def __getitem__(self, user):
        seq = self.data.get(user+1, [])
        seq = seq[-(self.context_len+1):]
        pad_len = self.context_len+1 - len(seq)
        seq = [0] * pad_len + seq
        if self.split == 'train':
            tokens = torch.tensor(seq[-(self.context_len+1):-1], dtype=torch.long)
            labels = torch.full_like(tokens, fill_value=-100)
            mask_pos = np.random.rand(self.context_len) < self.mask_prob
            mask_pos[tokens==0] = False # don't mask padded tokens
            # print (tokens.shape, labels.shape, mask_pos.shape)
            labels[mask_pos] = tokens[mask_pos]
            tokens[mask_pos]=0
            neg_sample = torch.tensor([])
        else:
            tokens = torch.tensor(seq[-self.context_len:], dtype=torch.long)
            labels = torch.full_like(tokens, fill_value=-100)
            mask_pos = torch.tensor([False]*(self.context_len-1)+[True])
            mask_pos[tokens==0] = False # don't mask padded tokens
            # print (tokens.shape, labels.shape, mask_pos.shape)
            labels[mask_pos] = tokens[mask_pos]
            tokens[mask_pos]=0
            # neg_sample = torch.tensor(np.random.choice(low_rated_movies, size=50, replace=False), dtype=torch.long)
            neg_sample = self.get_negative_sample(user)
        return {
            'user': user,
            'tokens': tokens,
            # 'mask_pos': torch.tensor(mask_pos, dtype=torch.bool),
            'labels': labels,
            'neg_sample': neg_sample
        }

In [21]:
train_data = Bert4RecDataset(data, low_rated_movies, context_len=32, mask_prob=0.2, split='train')
eval_data = Bert4RecDataset(data, low_rated_movies, context_len=32, mask_prob=0.2, split='eval')

from torch.utils.data import DataLoader

train_dataloader = DataLoader(dataset=train_data, batch_size=32, shuffle=True)
eval_dataloader = DataLoader(dataset=eval_data, batch_size=32, shuffle=True)

# Let's see no of batches that we have now with the current batch-size
print (len(train_dataloader), len(eval_dataloader))

189 189


## Model

In [41]:
class BERT4Rec(nn.Module):
    def __init__(self, vocab_size, context_len=32, d_model=256, num_heads=4, num_layers=2, dropout=0.1):
        super(BERT4Rec, self).__init__()

        self.d_model = d_model
        self.context_len = context_len
        self.item_embedding = nn.Embedding(vocab_size, d_model, padding_idx=0)
        self.position_embedding = nn.Embedding(context_len, d_model)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=num_heads,
            dim_feedforward=4*d_model,
            dropout=dropout,
            batch_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        self.layer_norm = nn.LayerNorm(d_model)
        self.output_layer = nn.Linear(d_model, vocab_size)

    def forward(self, input_seq):
        """
        input_seq: [B, T] - sequence of item IDs (0 = PAD and MASK)
        masked_pos: [B, K] - positions in the sequence we want to predict
        """
        B, T = input_seq.size()
        positions = torch.arange(T, device=input_seq.device).unsqueeze(0).expand(B, T)

        x = self.item_embedding(input_seq) + self.position_embedding(positions)

        attention_mask = (input_seq != 0) # avoid attending padded and masked tokens

        x = self.encoder(x, src_key_padding_mask=~attention_mask)
        # x = self.layer_norm(x)

        logits = self.output_layer(x)  # [B, T, vocab_size]
        return logits

bert4rec_model = BERT4Rec(vocab_size, context_len=32, d_model=256, num_heads=4, num_layers=2, dropout=0.1)
bert4rec_model.to(device)

BERT4Rec(
  (item_embedding): Embedding(3707, 256, padding_idx=0)
  (position_embedding): Embedding(32, 256)
  (encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-1): 2 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
        )
        (linear1): Linear(in_features=256, out_features=1024, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=1024, out_features=256, bias=True)
        (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  (output_layer): Linear(in_features=256, out_features=3707, bias=True)
)

## Train and Eval functions

In [43]:
def train(model, optimizer, loss_fn, train_dataloader, device):
    print ('training')
    model.train()
    train_loss = 0
    for step, batch in enumerate(train_dataloader):
        tokens = batch['tokens'].to(device)
        labels = batch['labels'].to(device)
        # mask_pos = batch['mask_pos'].to(device)
        # print (mask_pos.shape) #[B, T]
        # mask_pos = torch.where(mask_pos)[0]
        # print (mask_pos.shape) #[B, K]
        # no need to get neg_sample training stage

        optimizer.zero_grad()

        logits = model(tokens) # [B, T, vocab_size]
        B, T, d_model = logits.shape
        # print (logits.shape)
        # logits = logits.gather(1, mask_pos.unsqueeze(-1).expand(-1, -1, d_model)) # [B, K, d_model]
        # targets = labels.gather(1, mask_pos) # [B, K]
        # loss = loss_fn(logits.view(-1, logits.size(-1)), targets.view(-1))
        loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))
        train_loss += loss

        loss.backward()
        optimizer.step()

        # if step%100==0:
        #      print (f'training loss at {step} is {train_loss}')
    return train_loss/len(train_dataloader)

@torch.no_grad()
def eval(model, loss_fn, eval_dataloader, device):
    model.eval()
    eval_loss = 0
    metric_accuracy = evaluate.load("accuracy")
    ndcg_scores = []
    for step, batch in enumerate(eval_dataloader):
        tokens = batch['tokens'].to(device)
        labels = batch['labels'].to(device)
        # mask_pos = batch['mask_pos'].to(device)
        neg_sample = batch['neg_sample']

        # mask_pos = torch.where(mask_pos)[0]
        B, N = neg_sample.shape

        logits = model(tokens) # [B, T, vocab_size]
        B, T, V = logits.shape
        # print (tokens.shape, labels.shape, logits.shape)
        # logits = logits.gather(1, mask_pos.unsqueeze(-1).expand(-1, -1, d_model))
        # targets = labels.gather(1, mask_pos) # [B, K] -- K=1 in eval
        # B, K = targets.shape
        loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))
        eval_loss += loss

        targets = labels[:,-1].unsqueeze(-1)
        # print (targets.shape) # B, K
        B, K = targets.shape
        candidates = torch.cat((targets, neg_sample), dim=1)
        # print (candidates.shape) # B, K+N

        logits = logits[:, -1, :]  # last position
        metric_accuracy.add_batch(predictions=logits.argmax(dim=1), references=targets)
        _ndcg10 = ndcg_score(np.broadcast_to(np.array([1]*K+[0]*N),(B,K+N)), logits.gather(1, candidates).cpu())
        ndcg_scores.append(_ndcg10)
        # print (logits.shape) # B, V
        # scores = logits.gather(1, candidates).cpu()
        # # print (scores.shape)
        # for i in range(scores.shape[0]):
        #     y_true = np.array([1]*K+[0]*N)
        #     ndcg_scores.append(ndcg_score(y_true.reshape(1, K+N), scores[i].reshape(1, K+N), k=10))
    # ndcg_scores = np.array(ndcg_scores)
    accuracy = metric_accuracy.compute()
    print (f'Eval loss: {eval_loss/len(eval_dataloader)}')
    print (f'NDCG: {np.mean(ndcg_scores)}')
    print (f'Accuracy: {accuracy}')
    return eval_loss/len(eval_dataloader), np.mean(ndcg_scores), accuracy

## Model training

In [46]:
from torch.optim import AdamW
from transformers import get_scheduler

optimizer = AdamW(bert4rec_model.parameters(), lr=1e-3)
num_epochs = 10
num_training_steps = num_epochs * len(train_dataloader)

lr_scheduler = get_scheduler(
    name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
)

loss_fn = nn.CrossEntropyLoss(ignore_index=-100)

In [48]:
for epoch in range(num_epochs):
  train_loss = train(bert4rec_model, optimizer, loss_fn, train_dataloader, device)
  print (f'train loss at epoch {epoch} is {train_loss}')
  eval_loss, eval_ndcg10, accuracy = eval(bert4rec_model, loss_fn, eval_dataloader, device)
  print (f'eval loss at epoch {epoch} is {eval_loss}')
  print (f'eval NDCG@10 at epoch {epoch} is {eval_ndcg10}')

training
train loss at epoch 0 is 6.545655727386475
Eval loss: 6.9343438148498535
NDCG: 0.5735814456561721
Accuracy: {'accuracy': 0.011589403973509934}
eval loss at epoch 0 is 6.9343438148498535
eval NDCG@10 at epoch 0 is 0.5735814456561721
training
train loss at epoch 1 is 6.477384567260742
Eval loss: 6.815270900726318
NDCG: 0.5888803394498596
Accuracy: {'accuracy': 0.011920529801324504}
eval loss at epoch 1 is 6.815270900726318
eval NDCG@10 at epoch 1 is 0.5888803394498596
training
train loss at epoch 2 is 6.413047790527344
Eval loss: 6.8117780685424805
NDCG: 0.591340426774445
Accuracy: {'accuracy': 0.012913907284768211}
eval loss at epoch 2 is 6.8117780685424805
eval NDCG@10 at epoch 2 is 0.591340426774445
training
train loss at epoch 3 is 6.342720985412598
Eval loss: 6.7819390296936035
NDCG: 0.5907602979565334
Accuracy: {'accuracy': 0.014403973509933774}
eval loss at epoch 3 is 6.7819390296936035
eval NDCG@10 at epoch 3 is 0.5907602979565334
training
train loss at epoch 4 is 6.2970

## Test and checks

In [120]:
test_batch = next(iter(eval_dataloader))

In [121]:
test_batch

{'user': tensor([4700, 4906, 1231, 2089, 5915,  217,  141, 1220, 1266, 5393,  393, 3963,
         3820,  386, 4295, 1127, 3146, 5674, 1722, 3843, 3175, 3939, 4732,  584,
         5166,  595, 2450, 3029, 2082, 4812, 1389, 1027]),
 'tokens': tensor([[ 129,  501,  110,  ...,  874,  121,    0],
         [1645, 1005, 1728,  ..., 2638, 1366,    0],
         [ 204,  393,  615,  ..., 1201,  919,    0],
         ...,
         [1798, 1380, 1428,  ...,  518, 2065,    0],
         [1381, 1191, 1067,  ..., 1151, 1435,    0],
         [2033, 2448, 1903,  ...,  293,  247,    0]]),
 'labels': tensor([[-100, -100, -100,  ..., -100, -100, 1518],
         [-100, -100, -100,  ..., -100, -100, 1107],
         [-100, -100, -100,  ..., -100, -100, 1281],
         ...,
         [-100, -100, -100,  ..., -100, -100, 1158],
         [-100, -100, -100,  ..., -100, -100,  499],
         [-100, -100, -100,  ..., -100, -100,  259]]),
 'neg_sample': tensor([[3512, 3149, 3645,  ..., 3504, 3503, 3533],
         [3399, 

In [122]:
_user=4700
_context_len=32
_seq = data.get(_user+1, [])
_seq = _seq[-(_context_len+1):]
_pad_len = _context_len+1 - len(_seq)
_seq = [0] * _pad_len + _seq
_seq

[105,
 129,
 501,
 110,
 281,
 739,
 195,
 124,
 1383,
 157,
 44,
 337,
 515,
 1201,
 39,
 377,
 1616,
 165,
 41,
 51,
 31,
 970,
 1312,
 135,
 553,
 438,
 383,
 284,
 1397,
 265,
 874,
 121,
 1518]

In [123]:
tokens = test_batch['tokens'].to(device)
labels = test_batch['labels'].to(device)
neg_sample = test_batch['neg_sample']

B, N = neg_sample.shape
print (tokens.shape, B, N)

torch.Size([32, 32]) 32 50


In [124]:
logits = bert4rec_model(tokens) # [B, T, vocab_size]
B, T, V = logits.shape
print (B, T, V)

32 32 3707


In [125]:
loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))
print (loss)

tensor(7.7274, grad_fn=<NllLossBackward0>)


In [126]:
targets = labels[:,-1].unsqueeze(-1)
print (targets.shape) # B, K
B, K = targets.shape
candidates = torch.cat((targets, neg_sample), dim=1)
print (candidates.shape)

torch.Size([32, 1])
torch.Size([32, 51])


In [127]:
logits = logits[:, -1, :]  # last position
print (logits.shape)

torch.Size([32, 3707])


In [128]:
predictions=logits.argmax(dim=1)
references=targets
print (predictions, targets)

tensor([ 438,  420,  190, 1076, 1076,   69,  247,  134,   10,  930,  433,  129,
         832,  211,  533,  518,   61,   39,  238,  838,   69,   69, 1076,   69,
          87,  534,  353,  129,   61,  534,  532,  247]) tensor([[1518],
        [1107],
        [1281],
        [1364],
        [ 895],
        [1041],
        [ 525],
        [1376],
        [ 886],
        [2376],
        [1174],
        [ 890],
        [2016],
        [ 548],
        [1580],
        [ 262],
        [1189],
        [2516],
        [ 219],
        [2386],
        [2241],
        [1267],
        [1015],
        [ 500],
        [ 880],
        [1160],
        [ 154],
        [ 436],
        [ 110],
        [1158],
        [ 499],
        [ 259]])


In [129]:
logits.gather(1, candidates).cpu()

tensor([[ -0.6900,  -4.8754,  -6.1538,  ..., -11.6188, -11.5547,  -6.7060],
        [  0.9483,  -9.4182,  -2.4566,  ...,  -9.0036,  -3.3372,  -2.6410],
        [ -0.1588,  -2.7801,  -2.2601,  ...,  -6.7586,  -6.0960,  -8.5420],
        ...,
        [  1.5060,  -4.6516, -11.2825,  ..., -11.0691, -11.9302,  -5.1259],
        [ -0.8759,  -1.3401,  -9.7555,  ...,  -4.3106,  -2.6279,  -9.4801],
        [  1.1888,  -9.7429,  -4.6852,  ...,  -3.7372, -10.5331,  -3.5214]],
       grad_fn=<GatherBackward0>)

In [130]:
logits.gather(1, targets).cpu()

tensor([[-0.6900],
        [ 0.9483],
        [-0.1588],
        [ 1.7015],
        [ 0.7795],
        [ 0.1989],
        [ 2.0634],
        [-0.2410],
        [ 0.2375],
        [-2.1970],
        [ 0.4046],
        [-1.4611],
        [-0.2847],
        [ 1.6789],
        [ 0.0110],
        [ 0.4923],
        [-0.7186],
        [-9.2214],
        [ 2.4940],
        [ 0.7613],
        [ 1.0710],
        [-1.9045],
        [ 1.0546],
        [ 3.0804],
        [ 1.2476],
        [ 0.5921],
        [-0.4447],
        [ 1.4848],
        [-0.0661],
        [ 1.5060],
        [-0.8759],
        [ 1.1888]], grad_fn=<GatherBackward0>)

In [26]:
for name, param in bert4rec_model.named_parameters():
    if param.grad is not None:
        print(f"{name} grad norm: {param.grad.norm()}")

item_embedding.weight grad norm: 0.010649275965988636
position_embedding.weight grad norm: 0.023239050060510635
encoder.layers.0.self_attn.in_proj_weight grad norm: 0.20482070744037628
encoder.layers.0.self_attn.in_proj_bias grad norm: 0.015206803567707539
encoder.layers.0.self_attn.out_proj.weight grad norm: 0.14409254491329193
encoder.layers.0.self_attn.out_proj.bias grad norm: 0.021366922184824944
encoder.layers.0.linear1.weight grad norm: 0.15505120158195496
encoder.layers.0.linear1.bias grad norm: 0.009657294489443302
encoder.layers.0.linear2.weight grad norm: 0.3421947956085205
encoder.layers.0.linear2.bias grad norm: 0.021902913227677345
encoder.layers.0.norm1.weight grad norm: 0.022491611540317535
encoder.layers.0.norm1.bias grad norm: 0.022413820028305054
encoder.layers.0.norm2.weight grad norm: 0.025037651881575584
encoder.layers.0.norm2.bias grad norm: 0.024799855425953865
encoder.layers.1.self_attn.in_proj_weight grad norm: 0.1483602225780487
encoder.layers.1.self_attn.in_p