In [1]:
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
from typing import List
from tqdm import tqdm

from RandomSplit import RandomSplit
from metrics import ndcg_metric, dcg, recall_metric, evaluate_recommender, get_metrics

%matplotlib inline

import catalyst 
import recbole

from typing import Dict, List, Tuple

import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.nn.init import constant_, xavier_normal_
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, Dataset

from catalyst import dl, metrics
from catalyst.contrib.datasets import MovieLens
from catalyst.utils import get_device, set_global_seed
from torch.nn.utils.rnn import pad_sequence 

import random

from recbole.model.abstract_recommender import SequentialRecommender
from recbole.model.layers import *

from BERT4Rec import BERT4Rec
from runner import RecSysRunner
from ContextBERT4Rec import ContextBERT4Rec

set_global_seed(100)
device = torch.device("mps")
print(device)

mps


In [2]:
rnames = ['user_id', 'movie_id', 'rating', 'timestamp']
df = pd.read_table('data/ratings.dat', sep='::',header=None, names=rnames, engine='python')
df = df.rename(columns={'userId': 'user_id', 'movie_id': 'item_id'})
df['timestamp'] = pd.to_datetime(df['timestamp'],unit='s')
df['weekday'] = pd.to_datetime(df.timestamp).dt.weekday
df['hour'] = pd.to_datetime(df.timestamp).dt.hour
df.head()

Unnamed: 0,user_id,item_id,rating,timestamp,weekday,hour
0,1,1193,5,2000-12-31 22:12:40,6,22
1,1,661,3,2000-12-31 22:35:09,6,22
2,1,914,3,2000-12-31 22:32:48,6,22
3,1,3408,4,2000-12-31 22:04:35,6,22
4,1,2355,5,2001-01-06 23:38:11,5,23


In [3]:
splitter = RandomSplit(test_fraction=0.2)
train_df, valid_df, test_df = splitter(df)

In [4]:
train_grouped = train_df.groupby('user_id').apply(
    lambda x: [(t1, t2, t3, t4) for t1, t2, t3, t4 in sorted(zip(x.item_id, 
                                                                 x.timestamp,
                                                                 x.weekday,
                                                                 x.hour), key=lambda x: x[1])]
).reset_index()
train_grouped.rename({0:'train_interactions'}, axis=1, inplace=True)

valid_grouped = valid_df.groupby('user_id').apply(
    lambda x: [(t1, t2, t3, t4) for t1, t2, t3, t4 in sorted(zip(x.item_id,
                                                         x.timestamp,
                                                         x.weekday,
                                                         x.hour), key=lambda x: x[1])]
).reset_index()
valid_grouped.rename({0:'valid_interactions'}, axis=1, inplace=True)

test_grouped = test_df.groupby('user_id').apply(
    lambda x: [(t1, t2, t3, t4) for t1, t2, t3, t4 in sorted(zip(x.item_id,
                                                         x.timestamp,
                                                         x.weekday,
                                                         x.hour), key=lambda x: x[1])]
).reset_index()
test_grouped.rename({0:'test_interactions'}, axis=1, inplace=True)


train_grouped.head()

Unnamed: 0,user_id,train_interactions
0,1,"[(3186, 2000-12-31 22:00:19, 6, 22), (1270, 20..."
1,2,"[(1198, 2000-12-31 21:28:44, 6, 21), (1210, 20..."
2,3,"[(593, 2000-12-31 21:10:18, 6, 21), (2858, 200..."
3,4,"[(1210, 2000-12-31 20:18:44, 6, 20), (1097, 20..."
4,5,"[(2717, 2000-12-31 05:37:52, 6, 5), (908, 2000..."


In [5]:
joined = train_grouped.merge(valid_grouped).merge(test_grouped)
joined.head()

Unnamed: 0,user_id,train_interactions,valid_interactions,test_interactions
0,1,"[(3186, 2000-12-31 22:00:19, 6, 22), (1270, 20...","[(2791, 2000-12-31 22:36:28, 6, 22), (2321, 20...","[(2687, 2001-01-06 23:37:48, 5, 23), (745, 200..."
1,2,"[(1198, 2000-12-31 21:28:44, 6, 21), (1210, 20...","[(2028, 2000-12-31 21:56:13, 6, 21), (2571, 20...","[(1372, 2000-12-31 21:59:01, 6, 21), (1552, 20..."
2,3,"[(593, 2000-12-31 21:10:18, 6, 21), (2858, 200...","[(648, 2000-12-31 21:24:27, 6, 21), (2735, 200...","[(1270, 2000-12-31 21:30:31, 6, 21), (1079, 20..."
3,4,"[(1210, 2000-12-31 20:18:44, 6, 20), (1097, 20...","[(2947, 2000-12-31 20:23:50, 6, 20), (1214, 20...","[(1240, 2000-12-31 20:24:20, 6, 20), (2951, 20..."
4,5,"[(2717, 2000-12-31 05:37:52, 6, 5), (908, 2000...","[(2323, 2000-12-31 06:50:45, 6, 6), (272, 2000...","[(1715, 2000-12-31 06:58:11, 6, 6), (1653, 200..."


In [6]:
our_items = set()
for idx, row in tqdm(joined.iterrows()):
    for el in row.train_interactions:
        our_items.add(el[0])
        
len(our_items)

6040it [00:00, 40585.64it/s]


3636

In [7]:
item2idx = {k: i for i, k in enumerate(our_items)}
idx2item = {i: k for k, i in item2idx.items()}

In [8]:
class MyDataset(Dataset):
    
    def __init__(self, ds, num_items, item2idx, phase='valid', N=200):
        super().__init__()
        self.ds = ds
        self.phase = phase
        self.n_items = num_items
        self.item2idx = item2idx
        self.N = N 
        
    def __len__(self):
        return len(self.ds)
    
    def __getitem__(self, idx):
        
        row = self.ds.iloc[idx]
        
        x_input = np.zeros(self.n_items+1)
        x_input[[self.item2idx[x[0]]+1 for x in row['train_interactions'] if x[0] in self.item2idx]] = 1
        
        days_of_weeks = [x[2] for x in row['train_interactions'] if x[0] in self.item2idx][-self.N+1:]
        
        hours = [x[3] for x in row['train_interactions'] if x[0] in self.item2idx][-self.N+1:]
        
        seq_input = [self.item2idx[x[0]]+1 for x in row['train_interactions'] if x[0] in self.item2idx][-self.N+1:]
        
        targets = np.zeros(self.n_items+1)
        
        dow_valid = row['valid_interactions'][0][2]
        dow_test = row['test_interactions'][0][2]
        
        hours_valid = row['valid_interactions'][0][3]
        hours_test = row['test_interactions'][0][3]
        
        if self.phase == 'train':
            return (seq_input, days_of_weeks, hours, dow_valid, hours_valid)
        elif self.phase == 'valid':
            targets[[self.item2idx[x[0]]+1 for x in row['valid_interactions'] if x[0] in self.item2idx]] = 1
        else:
            return (seq_input, days_of_weeks, hours, dow_test, hours_test)
            
        return (targets, seq_input, days_of_weeks, hours, dow_valid, hours_valid)
     
n_items = len(item2idx)

train = MyDataset(ds=joined,
                  num_items=n_items, 
                  item2idx=item2idx,
                  phase='train')

valid = MyDataset(ds=joined,
                  num_items=n_items,
                  item2idx=item2idx,
                  phase='valid')

print(len(train),len(valid))

6040 6040


In [9]:
def collate_fn_train(batch: List[Tuple[torch.Tensor]]) -> Dict[str, torch.Tensor]: 
    
    seq_i,days_of_weeks,hours,dow_valid,hours_valid = zip(*batch)
    seq_len = torch.Tensor([len(x) for x in seq_i])
    dow_valid = torch.Tensor([x for x in dow_valid])
    hours_valid = torch.Tensor([x for x in hours_valid])
    seq_i = pad_sequence([torch.Tensor(t) for t in seq_i]).T    
    days_of_weeks = pad_sequence([torch.Tensor(t) for t in days_of_weeks]).T
    hours = pad_sequence([torch.Tensor(t) for t in hours]).T
    
    return {'seq_i': seq_i, 
            'seq_len':seq_len,
            'dow': days_of_weeks,
            'hours': hours,
            'dow_valid': dow_valid,
            'hours_valid': hours_valid}


def collate_fn_valid(batch: List[Tuple[torch.Tensor]]) -> Dict[str, torch.Tensor]:
    
    y, seq_i, days_of_weeks, hours, dow_valid, hours_valid = zip(*batch)
    
    seq_len = torch.Tensor([len(x) for x in seq_i]).long()
    seq_i = pad_sequence([torch.Tensor(t) for t in seq_i]).T.long()
    days_of_weeks = pad_sequence([torch.Tensor(t) for t in days_of_weeks]).T.long()
    hours = pad_sequence([torch.Tensor(t) for t in hours]).T.long()
    dow_valid = torch.Tensor([x for x in dow_valid])
    hours_valid = torch.Tensor([x for x in hours_valid])
            
    targets = pad_sequence([torch.Tensor(t) for t in y]).T

    return {"targets": targets,
            'seq_i': seq_i,
            'seq_len':seq_len,
            'dow': days_of_weeks,
            'hours': hours,
            'dow_valid': dow_valid,
            'hours_valid': hours_valid}

In [10]:
loaders = {
        "train": DataLoader(train, batch_size=256, collate_fn=collate_fn_train),
        "valid": DataLoader(valid, batch_size=256, collate_fn=collate_fn_valid),
}

In [12]:
class MultiHeadAttention(nn.Module):
    """
    Multi-head Self-attention layers, a attention score dropout layer is introduced.
    Args:
        input_tensor (torch.Tensor): the input of the multi-head self-attention layer
        attention_mask (torch.Tensor): the attention mask for input tensor
    Returns:
        hidden_states (torch.Tensor): the output of the multi-head self-attention layer
    """

    def __init__(
        self,
        n_heads,
        hidden_size,
        hidden_dropout_prob,
        attn_dropout_prob,
        layer_norm_eps,
    ):
        super(MultiHeadAttention, self).__init__()
        if hidden_size % n_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (hidden_size, n_heads)
            )

        self.num_attention_heads = n_heads
        self.attention_head_size = int(hidden_size / n_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.sqrt_attention_head_size = math.sqrt(self.attention_head_size)

        self.query = nn.Linear(hidden_size, self.all_head_size)
        self.key = nn.Linear(hidden_size, self.all_head_size)
        self.value = nn.Linear(hidden_size, self.all_head_size)

        self.softmax = nn.Softmax(dim=-1)
        self.attn_dropout = nn.Dropout(attn_dropout_prob)

        self.dense = nn.Linear(hidden_size, hidden_size)
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        self.out_dropout = nn.Dropout(hidden_dropout_prob)

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (
            self.num_attention_heads,
            self.attention_head_size,
        )
        x = x.view(*new_x_shape)
        return x

    def forward(self, input_tensor, attention_mask, return_explanations=False):
        mixed_query_layer = self.query(input_tensor)
        mixed_key_layer = self.key(input_tensor)
        mixed_value_layer = self.value(input_tensor)

        query_layer = self.transpose_for_scores(mixed_query_layer).permute(0, 2, 1, 3)
        key_layer = self.transpose_for_scores(mixed_key_layer).permute(0, 2, 3, 1)
        value_layer = self.transpose_for_scores(mixed_value_layer).permute(0, 2, 1, 3)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer)

        attention_scores = attention_scores / self.sqrt_attention_head_size
        # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
        # [batch_size heads seq_len seq_len] scores
        # [batch_size 1 1 seq_len]
        attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = self.softmax(attention_scores)
        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.

        attention_probs = self.attn_dropout(attention_probs)
        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        hidden_states = self.dense(context_layer)
        hidden_states = self.out_dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        
        if return_explanations:
            return hidden_states, attention_probs
        else:
            return hidden_states

        
        

class TransformerLayer(nn.Module):
    """
    One transformer layer consists of a multi-head self-attention layer and a point-wise feed-forward layer.
    Args:
        hidden_states (torch.Tensor): the input of the multi-head self-attention sublayer
        attention_mask (torch.Tensor): the attention mask for the multi-head self-attention sublayer
    Returns:
        feedforward_output (torch.Tensor): The output of the point-wise feed-forward sublayer,
                                           is the output of the transformer layer.
    """

    def __init__(
        self,
        n_heads,
        hidden_size,
        intermediate_size,
        hidden_dropout_prob,
        attn_dropout_prob,
        hidden_act,
        layer_norm_eps,
    ):
        super(TransformerLayer, self).__init__()
        self.multi_head_attention = MultiHeadAttention(
            n_heads, hidden_size, hidden_dropout_prob, attn_dropout_prob, layer_norm_eps
        )
        self.feed_forward = FeedForward(
            hidden_size,
            intermediate_size,
            hidden_dropout_prob,
            hidden_act,
            layer_norm_eps,
        )

    def forward(self, hidden_states, attention_mask,return_explanations=False):
        
        if return_explanations:
            attention_output, expl = self.multi_head_attention(hidden_states, attention_mask,
                                                         return_explanations=return_explanations)
            
        else:
            attention_output = self.multi_head_attention(hidden_states, attention_mask,
                                                         return_explanations=return_explanations)
        feedforward_output = self.feed_forward(attention_output)
        
        if return_explanations:
            return feedforward_output, expl
        else:
            return feedforward_output
    
    
    
class TransformerEncoder(nn.Module):
    r"""One TransformerEncoder consists of several TransformerLayers.
    Args:
        n_layers(num): num of transformer layers in transformer encoder. Default: 2
        n_heads(num): num of attention heads for multi-head attention layer. Default: 2
        hidden_size(num): the input and output hidden size. Default: 64
        inner_size(num): the dimensionality in feed-forward layer. Default: 256
        hidden_dropout_prob(float): probability of an element to be zeroed. Default: 0.5
        attn_dropout_prob(float): probability of an attention score to be zeroed. Default: 0.5
        hidden_act(str): activation function in feed-forward layer. Default: 'gelu'
                      candidates: 'gelu', 'relu', 'swish', 'tanh', 'sigmoid'
        layer_norm_eps(float): a value added to the denominator for numerical stability. Default: 1e-12
    """

    def __init__(
        self,
        n_layers=2,
        n_heads=2,
        hidden_size=64,
        inner_size=256,
        hidden_dropout_prob=0.5,
        attn_dropout_prob=0.5,
        hidden_act="gelu",
        layer_norm_eps=1e-12,
    ):

        super(TransformerEncoder, self).__init__()
        layer = TransformerLayer(
            n_heads,
            hidden_size,
            inner_size,
            hidden_dropout_prob,
            attn_dropout_prob,
            hidden_act,
            layer_norm_eps,
        )
        self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(n_layers)])

    def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True,
                return_explanations=False):
        """
        Args:
            hidden_states (torch.Tensor): the input of the TransformerEncoder
            attention_mask (torch.Tensor): the attention mask for the input hidden_states
            output_all_encoded_layers (Bool): whether output all transformer layers' output
        Returns:
            all_encoder_layers (list): if output_all_encoded_layers is True, return a list consists of all transformer
            layers' output, otherwise return a list only consists of the output of last transformer layer.
        """
        all_encoder_layers = []
        for idx, layer_module in enumerate(self.layer):
            
            if return_explanations:
                hidden_states, expl = layer_module(hidden_states, attention_mask, 
                                         return_explanations=return_explanations)
            else:            
                hidden_states = layer_module(hidden_states, attention_mask, 
                                             return_explanations=return_explanations)
            if output_all_encoded_layers:
                all_encoder_layers.append(hidden_states)
        if not output_all_encoded_layers:
            all_encoder_layers.append(hidden_states)
            
        if return_explanations:
            return all_encoder_layers, expl
        else:
            return all_encoder_layers

    

    
class BERT4Rec(torch.nn.Module):

    def __init__(self, n_items, hidden_size, mask_ratio):
        super(BERT4Rec, self).__init__()

        self.n_layers = 2
        self.n_heads = 2
        self.hidden_size = hidden_size  
        self.inner_size = 128 
        self.hidden_dropout_prob = 0.2
        self.attn_dropout_prob = 0.2
        self.hidden_act = 'sigmoid'
        self.layer_norm_eps = 1e-5
        self.ITEM_SEQ = 'seq_i'
        self.ITEM_SEQ_LEN = 'seq_len'
        self.max_seq_length = 200
        

        self.mask_ratio = mask_ratio

        self.loss_type =  'CE'
        self.initializer_range = 1e-2

        self.n_items = n_items
        self.mask_token = self.n_items 
        self.mask_item_length = int(self.mask_ratio * self.max_seq_length)

        self.item_embedding = nn.Embedding(self.n_items + 1, self.hidden_size, padding_idx=0) 
        self.position_embedding = nn.Embedding(self.max_seq_length + 1, self.hidden_size)  
        self.trm_encoder = TransformerEncoder(
            n_layers=self.n_layers,
            n_heads=self.n_heads,
            hidden_size=self.hidden_size,
            inner_size=self.inner_size,
            hidden_dropout_prob=self.hidden_dropout_prob,
            attn_dropout_prob=self.attn_dropout_prob,
            hidden_act=self.hidden_act,
            layer_norm_eps=self.layer_norm_eps
        )

        self.LayerNorm = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
        self.dropout = nn.Dropout(self.hidden_dropout_prob)

        try:
            assert self.loss_type in ['BPR', 'CE']
        except AssertionError:
            raise AssertionError("Make sure 'loss_type' in ['BPR', 'CE']!")

        self.apply(self._init_weights)
        
    def gather_indexes(self, output, gather_index):
        """Gathers the vectors at the specific positions over a minibatch"""
        gather_index = gather_index.view(-1, 1, 1).expand(-1, -1, output.shape[-1])
        output_tensor = output.gather(dim=1, index=gather_index)
        return output_tensor.squeeze(1)

    def _init_weights(self, module):
        """ Initialize the weights """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=self.initializer_range)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def get_attention_mask(self, item_seq):
        """Generate bidirectional attention mask for multi-head attention."""
        attention_mask = (item_seq > 0).long()
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)  
        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) 
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        return extended_attention_mask

    def _neg_sample(self, item_set):
        item = random.randint(1, self.n_items - 1)
        while item in item_set:
            item = random.randint(1, self.n_items - 1)
        return item

    def _padding_sequence(self, sequence, max_length):
        pad_len = max_length - len(sequence)
        sequence = [0] * pad_len + sequence
        sequence = sequence[-max_length:]  
        return sequence

    def reconstruct_train_data(self, item_seq):
        """
        Mask item sequence for training.
        """
        device = item_seq.device
        batch_size = item_seq.size(0)

        sequence_instances = item_seq.cpu().numpy().tolist()

        masked_item_sequence = []
        pos_items = []
        neg_items = []
        masked_index = []
        for instance in sequence_instances:
            masked_sequence = instance.copy()
            pos_item = []
            neg_item = []
            index_ids = []
            for index_id, item in enumerate(instance):
                if item == 0:
                    break
                prob = random.random()
                if prob < self.mask_ratio:
                    pos_item.append(item)
                    neg_item.append(self._neg_sample(instance))
                    masked_sequence[index_id] = self.mask_token
                    index_ids.append(index_id)

            masked_item_sequence.append(masked_sequence)
            pos_items.append(self._padding_sequence(pos_item, self.mask_item_length))
            neg_items.append(self._padding_sequence(neg_item, self.mask_item_length))
            masked_index.append(self._padding_sequence(index_ids, self.mask_item_length))

        masked_item_sequence = torch.tensor(masked_item_sequence, dtype=torch.long, device=device).view(batch_size, -1)
        pos_items = torch.tensor(pos_items, dtype=torch.long, device=device).view(batch_size, -1)
        neg_items = torch.tensor(neg_items, dtype=torch.long, device=device).view(batch_size, -1)
        masked_index = torch.tensor(masked_index, dtype=torch.long, device=device).view(batch_size, -1)
        return masked_item_sequence, pos_items, neg_items, masked_index

    def reconstruct_test_data(self, item_seq, item_seq_len):
        """
        Add mask token at the last position according to the lengths of item_seq
        """
        padding = torch.zeros(item_seq.size(0), dtype=torch.long, device=item_seq.device)  
        item_seq = torch.cat((item_seq, padding.unsqueeze(-1)), dim=-1)
        for batch_id, last_position in enumerate(item_seq_len):
            item_seq[batch_id][last_position] = self.mask_token
        return item_seq

    def forward(self, item_seq, return_explanations=False):
        
        position_ids = torch.arange(item_seq.size(1), dtype=torch.long, device=item_seq.device)
        position_ids = position_ids.unsqueeze(0).expand_as(item_seq)
        position_embedding = self.position_embedding(position_ids)
        
        item_emb = self.item_embedding(item_seq)
        input_emb = item_emb + position_embedding
        input_emb = self.LayerNorm(input_emb)
        input_emb = self.dropout(input_emb)
        extended_attention_mask = self.get_attention_mask(item_seq)
        if return_explanations:
            trm_output, explanations = self.trm_encoder(input_emb, extended_attention_mask, output_all_encoded_layers=True,
                                         return_explanations=return_explanations)
        else:
            trm_output = self.trm_encoder(input_emb, extended_attention_mask, output_all_encoded_layers=True,
                                         return_explanations=return_explanations)
            
        output = trm_output[-1]
        
        if return_explanations:
            return output, explanations
        else:
            return output

    def multi_hot_embed(self, masked_index, max_length):
        """
        For memory, we only need calculate loss for masked position.
        Generate a multi-hot vector to indicate the masked position for masked sequence, and then is used for
        gathering the masked position hidden representation.
        Examples:
            sequence: [1 2 3 4 5]
            masked_sequence: [1 mask 3 mask 5]
            masked_index: [1, 3]
            max_length: 5
            multi_hot_embed: [[0 1 0 0 0], [0 0 0 1 0]]
        """
        masked_index = masked_index.view(-1)
        multi_hot = torch.zeros(masked_index.size(0), max_length, device=masked_index.device)
        multi_hot[torch.arange(masked_index.size(0)), masked_index] = 1
        return multi_hot

    def calculate_loss(self, interaction):
        item_seq = interaction[self.ITEM_SEQ].long()
        masked_item_seq, pos_items, neg_items, masked_index = self.reconstruct_train_data(item_seq)

        seq_output = self.forward(masked_item_seq)
        pred_index_map = self.multi_hot_embed(masked_index, masked_item_seq.size(-1))
        pred_index_map = pred_index_map.view(masked_index.size(0), masked_index.size(1), -1)  
        seq_output = torch.bmm(pred_index_map, seq_output) 

        if self.loss_type == 'BPR':
            pos_items_emb = self.item_embedding(pos_items)  
            neg_items_emb = self.item_embedding(neg_items)  
            pos_score = torch.sum(seq_output * pos_items_emb, dim=-1)  
            neg_score = torch.sum(seq_output * neg_items_emb, dim=-1)  
            targets = (masked_index > 0).float()
            loss = - torch.sum(torch.log(1e-14 + torch.sigmoid(pos_score - neg_score)) * targets) \
                   / torch.sum(targets)
            return loss

        elif self.loss_type == 'CE':
            loss_fct = nn.CrossEntropyLoss(reduction='none')
            test_item_emb = self.item_embedding.weight[:self.n_items] 
            logits = torch.matmul(seq_output, test_item_emb.transpose(0, 1)) 
            targets = (masked_index > 0).float().view(-1) 

            loss = torch.sum(loss_fct(logits.view(-1, test_item_emb.size(0)), pos_items.view(-1)) * targets) \
                   / torch.sum(targets)
            return loss
        else:
            raise NotImplementedError("Make sure 'loss_type' in ['BPR', 'CE']!")


    def full_sort_predict(self, interaction, return_explanations=False):
        item_seq = interaction[self.ITEM_SEQ].long()
        item_seq_len = interaction[self.ITEM_SEQ_LEN].long()
        item_seq = self.reconstruct_test_data(item_seq, item_seq_len)
        
        if return_explanations:
            seq_output, expl = self.forward(item_seq, return_explanations=return_explanations)
        else:
            seq_output = self.forward(item_seq, return_explanations=return_explanations)
            
        
        seq_output = self.gather_indexes(seq_output, item_seq_len - 1)  
        test_items_emb = self.item_embedding.weight[:self.n_items]  
        scores = torch.matmul(seq_output, test_items_emb.transpose(0, 1))  
                
        idxs = item_seq.nonzero()
        item_seq[item_seq==self.n_items] = 0
        scores[idxs[:,0], item_seq[idxs[:,0],idxs[:,1]].long()] = -1000

        if return_explanations:
            return scores, expl
        else:
            return scores

### Hours only

In [13]:
class ContextBERT4Rec(BERT4Rec):

    def __init__(self, n_items, hidden_size, mask_ratio):
        super(BERT4Rec, self).__init__()
        
        self.n_layers = 2
        self.n_heads = 2
        self.hidden_size = hidden_size  
        self.inner_size = 128 
        self.hidden_dropout_prob = 0.2
        self.attn_dropout_prob = 0.2
        self.hidden_act = 'sigmoid'
        self.layer_norm_eps = 1e-5
        self.ITEM_SEQ = 'seq_i'
        self.ITEM_SEQ_LEN = 'seq_len'
        self.max_seq_length = 200
        

        self.mask_ratio = mask_ratio

        self.loss_type =  'CE'
        self.initializer_range = 1e-2

        # load dataset info
        self.n_items = n_items
        self.mask_token = self.n_items
        self.mask_item_length = int(self.mask_ratio * self.max_seq_length)

        # define layers and loss
        self.hours_embedding = nn.Embedding(24, self.hidden_size)
        self.item_embedding = nn.Embedding(self.n_items + 1, self.hidden_size, padding_idx=0)  # mask token add 1
        self.position_embedding = nn.Embedding(self.max_seq_length + 1, self.hidden_size)  # add mask_token at the last
        self.trm_encoder = TransformerEncoder(
            n_layers=self.n_layers,
            n_heads=self.n_heads,
            hidden_size=self.hidden_size,
            inner_size=self.inner_size,
            hidden_dropout_prob=self.hidden_dropout_prob,
            attn_dropout_prob=self.attn_dropout_prob,
            hidden_act=self.hidden_act,
            layer_norm_eps=self.layer_norm_eps
        )

        self.LayerNorm = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
        self.dropout = nn.Dropout(self.hidden_dropout_prob)

        try:
            assert self.loss_type in ['BPR', 'CE']
        except AssertionError:
            raise AssertionError("Make sure 'loss_type' in ['BPR', 'CE']!")

        self.apply(self._init_weights)



    def reconstruct_test_data(self,
                              item_seq,
                              item_seq_len,
                              hours,
                              hours_valid,
                              particular_day=-1,
                              ):
        """
        Add mask token at the last position according to the lengths of item_seq
        """
        padding = torch.zeros(item_seq.size(0), dtype=torch.long, device=item_seq.device)  # [B]
        item_seq = torch.cat((item_seq, padding.unsqueeze(-1)), dim=-1)  # [B max_len+1]
        hours = torch.cat((hours, padding.unsqueeze(-1)), dim=-1)
        for batch_id, last_position in enumerate(item_seq_len):
            item_seq[batch_id][last_position] = self.mask_token
            if particular_day == -1:
                hours[batch_id][last_position] = hours_valid[batch_id]
            else:
                hours[batch_id][last_position] = particular_day
        return item_seq, hours

    def forward(self, item_seq, hours, return_explanations=False):
        
        
        hours_embeddings = self.hours_embedding(hours.long())
        
        position_ids = torch.arange(item_seq.size(1), dtype=torch.long, device=item_seq.device)
        position_ids = position_ids.unsqueeze(0).expand_as(item_seq)
        position_embedding = self.position_embedding(position_ids)
        item_emb = self.item_embedding(item_seq)
        input_emb = item_emb + position_embedding + hours_embeddings
        input_emb = self.LayerNorm(input_emb)
        input_emb = self.dropout(input_emb)
        extended_attention_mask = self.get_attention_mask(item_seq)
        if return_explanations:
            trm_output, explanations = self.trm_encoder(input_emb, extended_attention_mask, output_all_encoded_layers=True,
                                         return_explanations=return_explanations)
        else:
            trm_output = self.trm_encoder(input_emb, extended_attention_mask, output_all_encoded_layers=True,
                                         return_explanations=return_explanations)
            
        output = trm_output[-1]
        
        if return_explanations:
            return output, explanations
        else:
            return output


    def calculate_loss(self, interaction):
        item_seq = interaction[self.ITEM_SEQ].long()
        masked_item_seq, pos_items, neg_items, masked_index = self.reconstruct_train_data(item_seq)

        seq_output = self.forward(masked_item_seq, hours=interaction['hours'])
        pred_index_map = self.multi_hot_embed(masked_index, masked_item_seq.size(-1))  
        pred_index_map = pred_index_map.view(masked_index.size(0), masked_index.size(1), -1) 
        seq_output = torch.bmm(pred_index_map, seq_output) 

        if self.loss_type == 'BPR':
            pos_items_emb = self.item_embedding(pos_items)  
            neg_items_emb = self.item_embedding(neg_items)  
            pos_score = torch.sum(seq_output * pos_items_emb, dim=-1)  
            neg_score = torch.sum(seq_output * neg_items_emb, dim=-1)  
            targets = (masked_index > 0).float()
            loss = - torch.sum(torch.log(1e-14 + torch.sigmoid(pos_score - neg_score)) * targets) \
                   / torch.sum(targets)
            return loss

        elif self.loss_type == 'CE':
            loss_fct = nn.CrossEntropyLoss(reduction='none')
            test_item_emb = self.item_embedding.weight[:self.n_items]  
            logits = torch.matmul(seq_output, test_item_emb.transpose(0, 1))  
            targets = (masked_index > 0).float().view(-1)  

            loss = torch.sum(loss_fct(logits.view(-1, test_item_emb.size(0)), pos_items.view(-1)) * targets) \
                   / torch.sum(targets)
            return loss
        else:
            raise NotImplementedError("Make sure 'loss_type' in ['BPR', 'CE']!")


    def full_sort_predict(self, 
                          interaction,
                          return_explanations=False,
                          particular_day=-1):
        
        item_seq = interaction[self.ITEM_SEQ].long()
        item_seq_len = interaction[self.ITEM_SEQ_LEN].long()
        item_seq, hours = self.reconstruct_test_data(item_seq,
                                              item_seq_len,
                                              hours=interaction['hours'],
                                              hours_valid=interaction['hours_valid'].long(),
                                              particular_day=particular_day)
        
        
        if return_explanations:
            seq_output, expl = self.forward(item_seq,
                                            hours=hours,
                                            return_explanations=return_explanations)
        else:
            seq_output = self.forward(item_seq,
                                      hours=hours,
                                      return_explanations=return_explanations)
            
        
        seq_output = self.gather_indexes(seq_output, item_seq_len - 1)  
        test_items_emb = self.item_embedding.weight[:self.n_items]  
        scores = torch.matmul(seq_output, test_items_emb.transpose(0, 1)) 
                
        idxs = item_seq.nonzero()
        item_seq[item_seq==self.n_items] = 0
        scores[idxs[:,0], item_seq[idxs[:,0],idxs[:,1]].long()] = -1000

        if return_explanations:
            return scores, expl
        else:
            return scores


In [20]:
model = ContextBERT4Rec(n_items=len(item2idx)+1, mask_ratio=0.2, hidden_size=128)

optimizer = optim.Adam(model.parameters(), lr=0.01)
lr_scheduler = StepLR(optimizer, step_size=20, gamma=0.1)
engine = dl.DeviceEngine('cpu')
hparams = {
    "anneal_cap": 0.2,
    "total_anneal_steps": 6000,
}


callbacks = [
    dl.NDCGCallback("logits", "targets", [20]),
    dl.MAPCallback("logits", "targets", [10]),
    dl.OptimizerCallback("loss", accumulation_steps=1),
    dl.EarlyStoppingCallback(
        patience=5, loader_key="valid", metric_key="map10", minimize=False
    )
]


runner = RecSysRunner()
runner.train(
    model=model,
    optimizer=optimizer,
    engine=engine,
    hparams=hparams,
    scheduler=lr_scheduler,
    loaders=loaders,
    num_epochs=100,
    verbose=True,
    timeit=True,
    callbacks=callbacks,
)


1/100 * Epoch (train): 100%|██████████| 24/24 [00:28<00:00,  1.20s/it, _timer/_fps=186.389, _timer/batch_time=0.815, _timer/data_time=0.453, _timer/model_time=0.363, loss=7.394, lr=0.010, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00]


train (1/100) loss: 7.5486999454877255 | lr: 0.01 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


1/100 * Epoch (valid): 100%|██████████| 24/24 [00:13<00:00,  1.72it/s, _timer/_fps=429.411, _timer/batch_time=0.354, _timer/data_time=0.020, _timer/model_time=0.334, loss=7.326, lr=0.010, map10=0.103, momentum=0.900, ndcg20=0.060]


valid (1/100) loss: 7.4432442595627135 | lr: 0.01 | map10: 0.12408531623960332 | map10/std: 0.016438783373056175 | momentum: 0.9 | ndcg20: 0.06770167051265572 | ndcg20/std: 0.006908426776674754
* Epoch (1/100) 


2/100 * Epoch (train): 100%|██████████| 24/24 [00:28<00:00,  1.17s/it, _timer/_fps=178.674, _timer/batch_time=0.851, _timer/data_time=0.492, _timer/model_time=0.359, loss=7.353, lr=0.010, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00]


train (2/100) loss: 7.453495850468314 | lr: 0.01 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


2/100 * Epoch (valid): 100%|██████████| 24/24 [00:13<00:00,  1.84it/s, _timer/_fps=444.299, _timer/batch_time=0.342, _timer/data_time=0.020, _timer/model_time=0.322, loss=7.320, lr=0.010, map10=0.109, momentum=0.900, ndcg20=0.067]


valid (2/100) loss: 7.41002306717121 | lr: 0.01 | map10: 0.13553662106690814 | map10/std: 0.016741590569495545 | momentum: 0.9 | ndcg20: 0.0768625318609326 | ndcg20/std: 0.007651711078039143
* Epoch (2/100) 


3/100 * Epoch (train): 100%|██████████| 24/24 [00:28<00:00,  1.17s/it, _timer/_fps=167.814, _timer/batch_time=0.906, _timer/data_time=0.499, _timer/model_time=0.406, loss=7.354, lr=0.010, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00]


train (3/100) loss: 7.42241380514688 | lr: 0.01 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


3/100 * Epoch (valid): 100%|██████████| 24/24 [00:13<00:00,  1.79it/s, _timer/_fps=424.846, _timer/batch_time=0.358, _timer/data_time=0.019, _timer/model_time=0.338, loss=7.316, lr=0.010, map10=0.108, momentum=0.900, ndcg20=0.067]


valid (3/100) loss: 7.383084136760787 | lr: 0.01 | map10: 0.1363276656100292 | map10/std: 0.017371114096854018 | momentum: 0.9 | ndcg20: 0.07431822904687843 | ndcg20/std: 0.007529737098835713
* Epoch (3/100) 


4/100 * Epoch (train): 100%|██████████| 24/24 [00:30<00:00,  1.28s/it, _timer/_fps=171.024, _timer/batch_time=0.889, _timer/data_time=0.506, _timer/model_time=0.382, loss=7.348, lr=0.010, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00]


train (4/100) loss: 7.417258884101513 | lr: 0.01 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


4/100 * Epoch (valid): 100%|██████████| 24/24 [00:15<00:00,  1.59it/s, _timer/_fps=453.082, _timer/batch_time=0.335, _timer/data_time=0.021, _timer/model_time=0.315, loss=7.317, lr=0.010, map10=0.117, momentum=0.900, ndcg20=0.066]


valid (4/100) loss: 7.369235244018354 | lr: 0.01 | map10: 0.13686595531488888 | map10/std: 0.017472520268054098 | momentum: 0.9 | ndcg20: 0.07650183581358551 | ndcg20/std: 0.00782102264951083
* Epoch (4/100) 


5/100 * Epoch (train): 100%|██████████| 24/24 [00:27<00:00,  1.16s/it, _timer/_fps=174.569, _timer/batch_time=0.871, _timer/data_time=0.513, _timer/model_time=0.358, loss=7.367, lr=0.010, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00]


train (5/100) loss: 7.3922384843131566 | lr: 0.01 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


5/100 * Epoch (valid): 100%|██████████| 24/24 [00:13<00:00,  1.76it/s, _timer/_fps=349.308, _timer/batch_time=0.435, _timer/data_time=0.023, _timer/model_time=0.412, loss=7.341, lr=0.010, map10=0.125, momentum=0.900, ndcg20=0.066]


valid (5/100) loss: 7.354217925292768 | lr: 0.01 | map10: 0.14151358155422655 | map10/std: 0.016127745767677365 | momentum: 0.9 | ndcg20: 0.07895260493684289 | ndcg20/std: 0.007867150151322277
* Epoch (5/100) 


6/100 * Epoch (train): 100%|██████████| 24/24 [00:29<00:00,  1.23s/it, _timer/_fps=183.078, _timer/batch_time=0.830, _timer/data_time=0.471, _timer/model_time=0.359, loss=7.357, lr=0.010, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00]


train (6/100) loss: 7.395082389124181 | lr: 0.01 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


6/100 * Epoch (valid): 100%|██████████| 24/24 [00:12<00:00,  1.85it/s, _timer/_fps=460.874, _timer/batch_time=0.330, _timer/data_time=0.020, _timer/model_time=0.310, loss=7.333, lr=0.010, map10=0.114, momentum=0.900, ndcg20=0.067]


valid (6/100) loss: 7.359322107706638 | lr: 0.01 | map10: 0.1329105152298283 | map10/std: 0.016020637968934704 | momentum: 0.9 | ndcg20: 0.07776578152613925 | ndcg20/std: 0.007403739742993672
* Epoch (6/100) 


7/100 * Epoch (train): 100%|██████████| 24/24 [00:28<00:00,  1.19s/it, _timer/_fps=181.845, _timer/batch_time=0.836, _timer/data_time=0.475, _timer/model_time=0.361, loss=7.381, lr=0.010, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00]


train (7/100) loss: 7.399077822830503 | lr: 0.01 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


7/100 * Epoch (valid): 100%|██████████| 24/24 [00:12<00:00,  1.86it/s, _timer/_fps=453.320, _timer/batch_time=0.335, _timer/data_time=0.019, _timer/model_time=0.316, loss=7.359, lr=0.010, map10=0.116, momentum=0.900, ndcg20=0.070]


valid (7/100) loss: 7.361966505113816 | lr: 0.01 | map10: 0.1315164936101989 | map10/std: 0.01879966374652992 | momentum: 0.9 | ndcg20: 0.07702410432281874 | ndcg20/std: 0.008275508973209202
* Epoch (7/100) 


8/100 * Epoch (train): 100%|██████████| 24/24 [00:26<00:00,  1.12s/it, _timer/_fps=183.215, _timer/batch_time=0.830, _timer/data_time=0.467, _timer/model_time=0.363, loss=7.363, lr=0.010, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00]


train (8/100) loss: 7.3877011090714415 | lr: 0.01 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


8/100 * Epoch (valid): 100%|██████████| 24/24 [00:12<00:00,  1.88it/s, _timer/_fps=395.976, _timer/batch_time=0.384, _timer/data_time=0.021, _timer/model_time=0.363, loss=7.343, lr=0.010, map10=0.111, momentum=0.900, ndcg20=0.070]


valid (8/100) loss: 7.354840564727785 | lr: 0.01 | map10: 0.138475516339801 | map10/std: 0.0185681068924728 | momentum: 0.9 | ndcg20: 0.07955095825211102 | ndcg20/std: 0.008095980176571925
* Epoch (8/100) 


9/100 * Epoch (train): 100%|██████████| 24/24 [00:28<00:00,  1.18s/it, _timer/_fps=164.803, _timer/batch_time=0.922, _timer/data_time=0.557, _timer/model_time=0.365, loss=7.344, lr=0.010, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00]


train (9/100) loss: 7.392812946774312 | lr: 0.01 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


9/100 * Epoch (valid): 100%|██████████| 24/24 [00:12<00:00,  1.89it/s, _timer/_fps=468.594, _timer/batch_time=0.324, _timer/data_time=0.019, _timer/model_time=0.305, loss=7.325, lr=0.010, map10=0.116, momentum=0.900, ndcg20=0.066]


valid (9/100) loss: 7.359860982326483 | lr: 0.01 | map10: 0.13974313579055647 | map10/std: 0.01689711623110146 | momentum: 0.9 | ndcg20: 0.08036795160825679 | ndcg20/std: 0.008032918983786844
* Epoch (9/100) 


10/100 * Epoch (train): 100%|██████████| 24/24 [00:26<00:00,  1.12s/it, _timer/_fps=166.698, _timer/batch_time=0.912, _timer/data_time=0.546, _timer/model_time=0.366, loss=7.337, lr=0.010, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00]


train (10/100) loss: 7.389346378528519 | lr: 0.01 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


10/100 * Epoch (valid): 100%|██████████| 24/24 [00:12<00:00,  1.85it/s, _timer/_fps=454.407, _timer/batch_time=0.335, _timer/data_time=0.020, _timer/model_time=0.314, loss=7.313, lr=0.010, map10=0.124, momentum=0.900, ndcg20=0.071]

valid (10/100) loss: 7.355596533516385 | lr: 0.01 | map10: 0.1414352862250726 | map10/std: 0.017744710421119393 | momentum: 0.9 | ndcg20: 0.07977344540768111 | ndcg20/std: 0.007730465737714401
* Epoch (10/100) 





In [21]:
test_dataset = MyDataset(ds=joined, num_items=n_items, phase='test',item2idx=item2idx)


inference_loader = DataLoader(test_dataset, 
                              batch_size=joined.shape[0]//100, 
                              collate_fn=collate_fn_train,)

preds = []

for prediction in tqdm(runner.predict_loader(loader=inference_loader)):
    preds.extend(prediction.detach().cpu().numpy().tolist())
    
print(len(preds))
assert len(preds) == joined.shape[0]

joined['preds_contextbert4rec'] = preds
joined['recs_contextbert4rec_10'] = joined['preds_contextbert4rec'].apply(lambda x: np.argsort(-np.array(x))[:10])
joined['recs_contextbert4rec_10'] = joined['recs_contextbert4rec_10'].apply(lambda x: [idx2item[t-1] for t in x])
joined['recs_contextbert4rec_5'] = joined['preds_contextbert4rec'].apply(lambda x: np.argsort(-np.array(x))[:5])
joined['recs_contextbert4rec_5'] = joined['recs_contextbert4rec_5'].apply(lambda x: [idx2item[t-1] for t in x])
joined.drop(['preds_contextbert4rec'],axis=1, inplace=True)

101it [00:06, 15.92it/s]


6040


In [22]:
evaluate_recommender(joined, model_preds='recs_contextbert4rec_10')

{'ndcg': 0.15241325770658687, 'recall': 0.03240332694747826}

In [23]:
evaluate_recommender(joined, model_preds='recs_contextbert4rec_5')

{'ndcg': 0.09340676542190061, 'recall': 0.01696575075374235}

### Hours + day of the week

In [18]:
class ContextBERT4Rec(BERT4Rec):

    def __init__(self, n_items, hidden_size, mask_ratio):
        super(BERT4Rec, self).__init__()
        
        self.n_layers = 2
        self.n_heads = 2
        self.hidden_size = hidden_size
        self.inner_size = 128 
        self.hidden_dropout_prob = 0.2
        self.attn_dropout_prob = 0.2
        self.hidden_act = 'sigmoid'
        self.layer_norm_eps = 1e-5
        self.ITEM_SEQ = 'seq_i'
        self.ITEM_SEQ_LEN = 'seq_len'
        self.max_seq_length = 200
        

        self.mask_ratio = mask_ratio

        self.loss_type =  'CE'
        self.initializer_range = 1e-2

        # load dataset info
        self.n_items = n_items
        self.mask_token = self.n_items
        self.mask_item_length = int(self.mask_ratio * self.max_seq_length)

        # define layers and loss
        self.weekday_embedding = nn.Embedding(7, self.hidden_size)
        self.hours_embedding = nn.Embedding(24, self.hidden_size)
        self.item_embedding = nn.Embedding(self.n_items + 1, self.hidden_size, padding_idx=0)  
        self.position_embedding = nn.Embedding(self.max_seq_length + 1, self.hidden_size) 
        self.trm_encoder = TransformerEncoder(
            n_layers=self.n_layers,
            n_heads=self.n_heads,
            hidden_size=self.hidden_size,
            inner_size=self.inner_size,
            hidden_dropout_prob=self.hidden_dropout_prob,
            attn_dropout_prob=self.attn_dropout_prob,
            hidden_act=self.hidden_act,
            layer_norm_eps=self.layer_norm_eps
        )

        self.LayerNorm = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
        self.dropout = nn.Dropout(self.hidden_dropout_prob)

        try:
            assert self.loss_type in ['BPR', 'CE']
        except AssertionError:
            raise AssertionError("Make sure 'loss_type' in ['BPR', 'CE']!")

        self.apply(self._init_weights)



    def reconstruct_test_data(self,
                              item_seq,
                              item_seq_len,
                              dow,
                              hours,
                              dow_valid,
                              hours_valid,
                              particular_day=-1,
                              ):
        """
        Add mask token at the last position according to the lengths of item_seq
        """
        padding = torch.zeros(item_seq.size(0), dtype=torch.long, device=item_seq.device)  
        item_seq = torch.cat((item_seq, padding.unsqueeze(-1)), dim=-1)  
        dow = torch.cat((dow, padding.unsqueeze(-1)), dim=-1)
        hours = torch.cat((hours, padding.unsqueeze(-1)), dim=-1)
        for batch_id, last_position in enumerate(item_seq_len):
            item_seq[batch_id][last_position] = self.mask_token
            if particular_day == -1:
                dow[batch_id][last_position] = dow_valid[batch_id]
                hours[batch_id][last_position] = hours_valid[batch_id]
            else:
                dow[batch_id][last_position] = particular_day
                hours[batch_id][last_position] = particular_day
        return item_seq, dow, hours

    def forward(self, item_seq, dow, hours, return_explanations=False):
        
        
        dow_embeddings = self.weekday_embedding(dow.long())
        hours_embeddings = self.hours_embedding(hours.long())
        
        position_ids = torch.arange(item_seq.size(1), dtype=torch.long, device=item_seq.device)
        position_ids = position_ids.unsqueeze(0).expand_as(item_seq)
        position_embedding = self.position_embedding(position_ids)
        item_emb = self.item_embedding(item_seq)
        input_emb = item_emb + position_embedding + dow_embeddings + hours_embeddings
        input_emb = self.LayerNorm(input_emb)
        input_emb = self.dropout(input_emb)
        extended_attention_mask = self.get_attention_mask(item_seq)
        if return_explanations:
            trm_output, explanations = self.trm_encoder(input_emb, extended_attention_mask, output_all_encoded_layers=True,
                                         return_explanations=return_explanations)
        else:
            trm_output = self.trm_encoder(input_emb, extended_attention_mask, output_all_encoded_layers=True,
                                         return_explanations=return_explanations)
            
        output = trm_output[-1]
        
        if return_explanations:
            return output, explanations
        else:
            return output


    def calculate_loss(self, interaction):
        item_seq = interaction[self.ITEM_SEQ].long()
        masked_item_seq, pos_items, neg_items, masked_index = self.reconstruct_train_data(item_seq)

        seq_output = self.forward(masked_item_seq, dow=interaction['dow'], hours=interaction['hours'])
        pred_index_map = self.multi_hot_embed(masked_index, masked_item_seq.size(-1)) 
        pred_index_map = pred_index_map.view(masked_index.size(0), masked_index.size(1), -1) 
        seq_output = torch.bmm(pred_index_map, seq_output)  

        if self.loss_type == 'BPR':
            pos_items_emb = self.item_embedding(pos_items) 
            neg_items_emb = self.item_embedding(neg_items)  
            pos_score = torch.sum(seq_output * pos_items_emb, dim=-1)  
            neg_score = torch.sum(seq_output * neg_items_emb, dim=-1) 
            targets = (masked_index > 0).float()
            loss = - torch.sum(torch.log(1e-14 + torch.sigmoid(pos_score - neg_score)) * targets) \
                   / torch.sum(targets)
            return loss

        elif self.loss_type == 'CE':
            loss_fct = nn.CrossEntropyLoss(reduction='none')
            test_item_emb = self.item_embedding.weight[:self.n_items]  
            logits = torch.matmul(seq_output, test_item_emb.transpose(0, 1))  
            targets = (masked_index > 0).float().view(-1)  

            loss = torch.sum(loss_fct(logits.view(-1, test_item_emb.size(0)), pos_items.view(-1)) * targets) \
                   / torch.sum(targets)
            return loss
        else:
            raise NotImplementedError("Make sure 'loss_type' in ['BPR', 'CE']!")


    def full_sort_predict(self, 
                          interaction,
                          return_explanations=False,
                          particular_day=-1):
        
        item_seq = interaction[self.ITEM_SEQ].long()
        item_seq_len = interaction[self.ITEM_SEQ_LEN].long()
        item_seq, dow, hours = self.reconstruct_test_data(item_seq,
                                              item_seq_len,
                                              dow=interaction['dow'],
                                              hours=interaction['hours'],
                                              dow_valid=interaction['dow_valid'].long(),
                                              hours_valid=interaction['hours_valid'].long(),
                                              particular_day=particular_day)
        
        
        if return_explanations:
            seq_output, expl = self.forward(item_seq,
                                            dow=dow,
                                            hours=hours,
                                            return_explanations=return_explanations)
        else:
            seq_output = self.forward(item_seq,
                                      dow=dow,
                                      hours=hours,
                                      return_explanations=return_explanations)
            
        
        seq_output = self.gather_indexes(seq_output, item_seq_len - 1)  
        test_items_emb = self.item_embedding.weight[:self.n_items]  
        scores = torch.matmul(seq_output, test_items_emb.transpose(0, 1))  
                
        idxs = item_seq.nonzero()
        item_seq[item_seq==self.n_items] = 0
        scores[idxs[:,0], item_seq[idxs[:,0],idxs[:,1]].long()] = -1000

        if return_explanations:
            return scores, expl
        else:
            return scores


In [24]:
model = ContextBERT4Rec(n_items=len(item2idx)+1, mask_ratio=0.2, hidden_size=128)

optimizer = optim.Adam(model.parameters(), lr=0.01)
lr_scheduler = StepLR(optimizer, step_size=20, gamma=0.1)
engine = dl.DeviceEngine('cpu')
hparams = {
    "anneal_cap": 0.2,
    "total_anneal_steps": 6000,
}


callbacks = [
    dl.NDCGCallback("logits", "targets", [20]),
    dl.MAPCallback("logits", "targets", [10]),
    dl.OptimizerCallback("loss", accumulation_steps=1),
    dl.EarlyStoppingCallback(
        patience=5, loader_key="valid", metric_key="map10", minimize=False
    )
]


runner = RecSysRunner()
runner.train(
    model=model,
    optimizer=optimizer,
    engine=engine,
    hparams=hparams,
    scheduler=lr_scheduler,
    loaders=loaders,
    num_epochs=100,
    verbose=True,
    timeit=True,
    callbacks=callbacks,
)


1/100 * Epoch (train): 100%|██████████| 24/24 [00:26<00:00,  1.10s/it, _timer/_fps=184.197, _timer/batch_time=0.825, _timer/data_time=0.465, _timer/model_time=0.360, loss=7.418, lr=0.010, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00]


train (1/100) loss: 7.555545104576263 | lr: 0.01 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


1/100 * Epoch (valid): 100%|██████████| 24/24 [00:12<00:00,  1.88it/s, _timer/_fps=462.453, _timer/batch_time=0.329, _timer/data_time=0.020, _timer/model_time=0.309, loss=7.341, lr=0.010, map10=0.111, momentum=0.900, ndcg20=0.064]


valid (1/100) loss: 7.445513191602088 | lr: 0.01 | map10: 0.12790470888085714 | map10/std: 0.01621691434394232 | momentum: 0.9 | ndcg20: 0.06918315407850885 | ndcg20/std: 0.007077874368164376
* Epoch (1/100) 


2/100 * Epoch (train): 100%|██████████| 24/24 [00:28<00:00,  1.21s/it, _timer/_fps=179.797, _timer/batch_time=0.845, _timer/data_time=0.455, _timer/model_time=0.390, loss=7.367, lr=0.010, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00]


train (2/100) loss: 7.44987665580598 | lr: 0.01 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


2/100 * Epoch (valid): 100%|██████████| 24/24 [00:12<00:00,  1.89it/s, _timer/_fps=459.865, _timer/batch_time=0.331, _timer/data_time=0.020, _timer/model_time=0.311, loss=7.329, lr=0.010, map10=0.111, momentum=0.900, ndcg20=0.064]


valid (2/100) loss: 7.406028172827715 | lr: 0.01 | map10: 0.13488951142852668 | map10/std: 0.01732436265849217 | momentum: 0.9 | ndcg20: 0.0752465847509586 | ndcg20/std: 0.0075593063303510755
* Epoch (2/100) 


3/100 * Epoch (train): 100%|██████████| 24/24 [00:27<00:00,  1.14s/it, _timer/_fps=186.947, _timer/batch_time=0.813, _timer/data_time=0.446, _timer/model_time=0.367, loss=7.348, lr=0.010, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00]


train (3/100) loss: 7.420824758895975 | lr: 0.01 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


3/100 * Epoch (valid): 100%|██████████| 24/24 [00:13<00:00,  1.82it/s, _timer/_fps=434.267, _timer/batch_time=0.350, _timer/data_time=0.022, _timer/model_time=0.328, loss=7.318, lr=0.010, map10=0.118, momentum=0.900, ndcg20=0.067]


valid (3/100) loss: 7.394668850046122 | lr: 0.01 | map10: 0.13780646224487694 | map10/std: 0.017070203664227363 | momentum: 0.9 | ndcg20: 0.07514734006678031 | ndcg20/std: 0.007966531944900112
* Epoch (3/100) 


4/100 * Epoch (train): 100%|██████████| 24/24 [00:27<00:00,  1.16s/it, _timer/_fps=169.459, _timer/batch_time=0.897, _timer/data_time=0.524, _timer/model_time=0.373, loss=7.342, lr=0.010, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00]


train (4/100) loss: 7.420100878406045 | lr: 0.01 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


4/100 * Epoch (valid): 100%|██████████| 24/24 [00:13<00:00,  1.78it/s, _timer/_fps=334.189, _timer/batch_time=0.455, _timer/data_time=0.024, _timer/model_time=0.431, loss=7.309, lr=0.010, map10=0.132, momentum=0.900, ndcg20=0.072]


valid (4/100) loss: 7.371081373075775 | lr: 0.01 | map10: 0.14047056955612258 | map10/std: 0.01900748433859079 | momentum: 0.9 | ndcg20: 0.07832785569476766 | ndcg20/std: 0.008021938772896852
* Epoch (4/100) 


5/100 * Epoch (train): 100%|██████████| 24/24 [00:29<00:00,  1.24s/it, _timer/_fps=173.845, _timer/batch_time=0.874, _timer/data_time=0.496, _timer/model_time=0.378, loss=7.368, lr=0.010, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00]


train (5/100) loss: 7.394936804739846 | lr: 0.01 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


5/100 * Epoch (valid): 100%|██████████| 24/24 [00:13<00:00,  1.75it/s, _timer/_fps=442.243, _timer/batch_time=0.344, _timer/data_time=0.020, _timer/model_time=0.323, loss=7.348, lr=0.010, map10=0.131, momentum=0.900, ndcg20=0.068]


valid (5/100) loss: 7.356114329407547 | lr: 0.01 | map10: 0.1418781919392529 | map10/std: 0.01483220975141328 | momentum: 0.9 | ndcg20: 0.07928699738537237 | ndcg20/std: 0.007633160690774158
* Epoch (5/100) 


6/100 * Epoch (train): 100%|██████████| 24/24 [00:26<00:00,  1.10s/it, _timer/_fps=192.517, _timer/batch_time=0.790, _timer/data_time=0.420, _timer/model_time=0.369, loss=7.359, lr=0.010, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00]


train (6/100) loss: 7.395024589513311 | lr: 0.01 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


6/100 * Epoch (valid): 100%|██████████| 24/24 [00:13<00:00,  1.83it/s, _timer/_fps=440.737, _timer/batch_time=0.345, _timer/data_time=0.023, _timer/model_time=0.322, loss=7.342, lr=0.010, map10=0.106, momentum=0.900, ndcg20=0.067]


valid (6/100) loss: 7.366093260405079 | lr: 0.01 | map10: 0.13294146644161237 | map10/std: 0.018435071658037913 | momentum: 0.9 | ndcg20: 0.07597670783072905 | ndcg20/std: 0.007776523998895446
* Epoch (6/100) 


7/100 * Epoch (train): 100%|██████████| 24/24 [00:27<00:00,  1.15s/it, _timer/_fps=170.666, _timer/batch_time=0.891, _timer/data_time=0.520, _timer/model_time=0.371, loss=7.372, lr=0.010, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00]


train (7/100) loss: 7.3994960949120925 | lr: 0.01 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


7/100 * Epoch (valid): 100%|██████████| 24/24 [00:14<00:00,  1.66it/s, _timer/_fps=418.413, _timer/batch_time=0.363, _timer/data_time=0.023, _timer/model_time=0.340, loss=7.358, lr=0.010, map10=0.117, momentum=0.900, ndcg20=0.072]


valid (7/100) loss: 7.36685700321829 | lr: 0.01 | map10: 0.1315575515000236 | map10/std: 0.01896158695151225 | momentum: 0.9 | ndcg20: 0.0753728217340463 | ndcg20/std: 0.00830497352161689
* Epoch (7/100) 


8/100 * Epoch (train): 100%|██████████| 24/24 [00:26<00:00,  1.12s/it, _timer/_fps=198.655, _timer/batch_time=0.765, _timer/data_time=0.396, _timer/model_time=0.369, loss=7.363, lr=0.010, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00]


train (8/100) loss: 7.390039569652633 | lr: 0.01 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


8/100 * Epoch (valid): 100%|██████████| 24/24 [00:13<00:00,  1.80it/s, _timer/_fps=423.844, _timer/batch_time=0.359, _timer/data_time=0.021, _timer/model_time=0.338, loss=7.347, lr=0.010, map10=0.114, momentum=0.900, ndcg20=0.070]


valid (8/100) loss: 7.357703402184493 | lr: 0.01 | map10: 0.13947743097480556 | map10/std: 0.017390441198534386 | momentum: 0.9 | ndcg20: 0.07939761872125779 | ndcg20/std: 0.007953575870584607
* Epoch (8/100) 


9/100 * Epoch (train): 100%|██████████| 24/24 [00:26<00:00,  1.11s/it, _timer/_fps=189.733, _timer/batch_time=0.801, _timer/data_time=0.437, _timer/model_time=0.364, loss=7.349, lr=0.010, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00]


train (9/100) loss: 7.393237271845735 | lr: 0.01 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


9/100 * Epoch (valid): 100%|██████████| 24/24 [00:12<00:00,  1.86it/s, _timer/_fps=467.186, _timer/batch_time=0.325, _timer/data_time=0.021, _timer/model_time=0.304, loss=7.330, lr=0.010, map10=0.124, momentum=0.900, ndcg20=0.066]


valid (9/100) loss: 7.363185141102368 | lr: 0.01 | map10: 0.13569314905152413 | map10/std: 0.0166163718338287 | momentum: 0.9 | ndcg20: 0.07895752016874338 | ndcg20/std: 0.007762595069919511
* Epoch (9/100) 


10/100 * Epoch (train): 100%|██████████| 24/24 [00:26<00:00,  1.09s/it, _timer/_fps=179.271, _timer/batch_time=0.848, _timer/data_time=0.488, _timer/model_time=0.360, loss=7.339, lr=0.010, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00]


train (10/100) loss: 7.390272671655314 | lr: 0.01 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


10/100 * Epoch (valid): 100%|██████████| 24/24 [00:12<00:00,  1.86it/s, _timer/_fps=456.896, _timer/batch_time=0.333, _timer/data_time=0.019, _timer/model_time=0.314, loss=7.316, lr=0.010, map10=0.130, momentum=0.900, ndcg20=0.072]

valid (10/100) loss: 7.359320659511135 | lr: 0.01 | map10: 0.14164599831530592 | map10/std: 0.017763887164391423 | momentum: 0.9 | ndcg20: 0.0795116835872069 | ndcg20/std: 0.007699242698728413
* Epoch (10/100) 





In [25]:
test_dataset = MyDataset(ds=joined, num_items=n_items, phase='test',item2idx=item2idx)


inference_loader = DataLoader(test_dataset, 
                              batch_size=joined.shape[0]//100, 
                              collate_fn=collate_fn_train,)

preds = []

for prediction in tqdm(runner.predict_loader(loader=inference_loader)):
    preds.extend(prediction.detach().cpu().numpy().tolist())
    
print(len(preds))
assert len(preds) == joined.shape[0]

joined['preds_contextbert4rec'] = preds
joined['recs_contextbert4rec_10'] = joined['preds_contextbert4rec'].apply(lambda x: np.argsort(-np.array(x))[:10])
joined['recs_contextbert4rec_10'] = joined['recs_contextbert4rec_10'].apply(lambda x: [idx2item[t-1] for t in x])
joined['recs_contextbert4rec_5'] = joined['preds_contextbert4rec'].apply(lambda x: np.argsort(-np.array(x))[:5])
joined['recs_contextbert4rec_5'] = joined['recs_contextbert4rec_5'].apply(lambda x: [idx2item[t-1] for t in x])
joined.drop(['preds_contextbert4rec'],axis=1, inplace=True)

101it [00:06, 15.94it/s]


6040


In [26]:
evaluate_recommender(joined, model_preds='recs_contextbert4rec_10')

{'ndcg': 0.15325097636829188, 'recall': 0.03236868958740314}

In [27]:
evaluate_recommender(joined, model_preds='recs_contextbert4rec_5')

{'ndcg': 0.0935407587039434, 'recall': 0.017191631300404735}

### Full context (hours + day of the week + date)

In [45]:
df['date'] = pd.to_datetime(df['timestamp'].dt.date).astype('int64')
map_dates = {k: i for i, k in enumerate(df.date.unique())}
df['date'] = df['date'].map(map_dates)
df.head()

Unnamed: 0,user_id,item_id,rating,timestamp,weekday,hour,date
0,1,1193,5,2000-12-31 22:12:40,6,22,0
1,1,661,3,2000-12-31 22:35:09,6,22,0
2,1,914,3,2000-12-31 22:32:48,6,22,0
3,1,3408,4,2000-12-31 22:04:35,6,22,0
4,1,2355,5,2001-01-06 23:38:11,5,23,1


In [46]:
n_dates = df.date.nunique()

In [47]:
splitter = RandomSplit(test_fraction=0.2)
train_df, valid_df, test_df = splitter(df)

In [49]:
train_grouped = train_df.groupby('user_id').apply(
    lambda x: [(t1, t2, t3, t4, t5) for t1, t2, t3, t4, t5 in sorted(zip(x.item_id, 
                                                                 x.timestamp,
                                                                 x.date,
                                                                 x.weekday,
                                                                 x.hour), key=lambda x: x[1])]
).reset_index()
train_grouped.rename({0:'train_interactions'}, axis=1, inplace=True)

valid_grouped = valid_df.groupby('user_id').apply(
    lambda x: [(t1, t2, t3, t4, t5) for t1, t2, t3, t4, t5 in sorted(zip(x.item_id,
                                                         x.timestamp,
                                                         x.date,
                                                         x.weekday,
                                                         x.hour), key=lambda x: x[1])]
).reset_index()
valid_grouped.rename({0:'valid_interactions'}, axis=1, inplace=True)

test_grouped = test_df.groupby('user_id').apply(
    lambda x: [(t1, t2, t3, t4, t5) for t1, t2, t3, t4, t5 in sorted(zip(x.item_id,
                                                         x.timestamp,
                                                         x.date,
                                                         x.weekday,
                                                         x.hour), key=lambda x: x[1])]
).reset_index()
test_grouped.rename({0:'test_interactions'}, axis=1, inplace=True)

In [50]:
joined = train_grouped.merge(valid_grouped).merge(test_grouped)
joined.head()

Unnamed: 0,user_id,train_interactions,valid_interactions,test_interactions
0,1,"[(3186, 2000-12-31 22:00:19, 0, 6, 22), (1270,...","[(2791, 2000-12-31 22:36:28, 0, 6, 22), (2321,...","[(2687, 2001-01-06 23:37:48, 1, 5, 23), (745, ..."
1,2,"[(1198, 2000-12-31 21:28:44, 0, 6, 21), (1210,...","[(2028, 2000-12-31 21:56:13, 0, 6, 21), (2571,...","[(1372, 2000-12-31 21:59:01, 0, 6, 21), (1552,..."
2,3,"[(593, 2000-12-31 21:10:18, 0, 6, 21), (2858, ...","[(648, 2000-12-31 21:24:27, 0, 6, 21), (2735, ...","[(1270, 2000-12-31 21:30:31, 0, 6, 21), (1079,..."
3,4,"[(1210, 2000-12-31 20:18:44, 0, 6, 20), (1097,...","[(2947, 2000-12-31 20:23:50, 0, 6, 20), (1214,...","[(1240, 2000-12-31 20:24:20, 0, 6, 20), (2951,..."
4,5,"[(2717, 2000-12-31 05:37:52, 0, 6, 5), (908, 2...","[(2323, 2000-12-31 06:50:45, 0, 6, 6), (272, 2...","[(1715, 2000-12-31 06:58:11, 0, 6, 6), (1653, ..."


In [56]:
class MyDataset(Dataset):
    
    def __init__(self, ds, num_items, item2idx, phase='valid', N=200):
        super().__init__()
        self.ds = ds
        self.phase = phase
        self.n_items = num_items
        self.item2idx = item2idx
        self.N = N 
        
    def __len__(self):
        return len(self.ds)
    
    def __getitem__(self, idx):
        
        row = self.ds.iloc[idx]
        
        x_input = np.zeros(self.n_items+1)
        x_input[[self.item2idx[x[0]]+1 for x in row['train_interactions'] if x[0] in self.item2idx]] = 1
        
        date = [x[2] for x in row['train_interactions'] if x[0] in self.item2idx][-self.N+1:]
        days_of_weeks = [x[3] for x in row['train_interactions'] if x[0] in self.item2idx][-self.N+1:]
        hours = [x[4] for x in row['train_interactions'] if x[0] in self.item2idx][-self.N+1:]
        
        seq_input = [self.item2idx[x[0]]+1 for x in row['train_interactions'] if x[0] in self.item2idx][-self.N+1:]
        
        targets = np.zeros(self.n_items+1)
        
        date_valid = row['valid_interactions'][0][2]
        date_test = row['test_interactions'][0][2]
        
        dow_valid = row['valid_interactions'][0][3]
        dow_test = row['test_interactions'][0][3]
        
        hours_valid = row['valid_interactions'][0][4]
        hours_test = row['test_interactions'][0][4]
        
        if self.phase == 'train':
            return (seq_input, date, days_of_weeks, hours, date_valid, dow_valid, hours_valid)
        elif self.phase == 'valid':
            targets[[self.item2idx[x[0]]+1 for x in row['valid_interactions'] if x[0] in self.item2idx]] = 1
        else:
            return (seq_input, date, days_of_weeks, hours, date_test, dow_test, hours_test)
            
        return (targets, seq_input, date, days_of_weeks, hours, date_valid, dow_valid, hours_valid)

In [57]:
n_items = len(item2idx)

train = MyDataset(ds=joined,
                  num_items=n_items, 
                  item2idx=item2idx,
                  phase='train')

valid = MyDataset(ds=joined,
                  num_items=n_items,
                  item2idx=item2idx,
                  phase='valid')

print(len(train),len(valid))

6040 6040


In [58]:
def collate_fn_train(batch: List[Tuple[torch.Tensor]]) -> Dict[str, torch.Tensor]: 
    
    seq_i,date, days_of_weeks,hours,date_valid,dow_valid,hours_valid = zip(*batch)
    
    seq_len = torch.Tensor([len(x) for x in seq_i])
    date_valid = torch.Tensor([x for x in date_valid])
    dow_valid = torch.Tensor([x for x in dow_valid])
    hours_valid = torch.Tensor([x for x in hours_valid])
    seq_i = pad_sequence([torch.Tensor(t) for t in seq_i]).T    
    days_of_weeks = pad_sequence([torch.Tensor(t) for t in days_of_weeks]).T
    hours = pad_sequence([torch.Tensor(t) for t in hours]).T
    date = pad_sequence([torch.Tensor(t) for t in date]).T
    
    return {'seq_i': seq_i, 
            'seq_len':seq_len,
            'date': date,
            'dow': days_of_weeks,
            'hours': hours,
            'date_valid': date_valid,
            'dow_valid': dow_valid,
            'hours_valid': hours_valid}


def collate_fn_valid(batch: List[Tuple[torch.Tensor]]) -> Dict[str, torch.Tensor]:
    
    y, seq_i, date, days_of_weeks, hours, date_valid, dow_valid, hours_valid = zip(*batch)
    
    seq_len = torch.Tensor([len(x) for x in seq_i]).long()
    seq_i = pad_sequence([torch.Tensor(t) for t in seq_i]).T.long()
    date = pad_sequence([torch.Tensor(t) for t in date]).T.long()
    days_of_weeks = pad_sequence([torch.Tensor(t) for t in days_of_weeks]).T.long()
    hours = pad_sequence([torch.Tensor(t) for t in hours]).T.long()
    date_valid = torch.Tensor([x for x in date_valid])
    dow_valid = torch.Tensor([x for x in dow_valid])
    hours_valid = torch.Tensor([x for x in hours_valid])
            
    targets = pad_sequence([torch.Tensor(t) for t in y]).T

    return {"targets": targets,
            'seq_i': seq_i,
            'seq_len':seq_len,
            'date': date,
            'dow': days_of_weeks,
            'hours': hours,
            'date_valid': date_valid,
            'dow_valid': dow_valid,
            'hours_valid': hours_valid}

In [59]:

class ContextBERT4Rec(BERT4Rec):

    def __init__(self, n_items, n_dates, hidden_size, mask_ratio):
        super(BERT4Rec, self).__init__()
        
        self.n_layers = 2
        self.n_heads = 2
        self.hidden_size = hidden_size  
        self.inner_size = 128 
        self.hidden_dropout_prob = 0.2
        self.attn_dropout_prob = 0.2
        self.hidden_act = 'sigmoid'
        self.layer_norm_eps = 1e-5
        self.ITEM_SEQ = 'seq_i'
        self.ITEM_SEQ_LEN = 'seq_len'
        self.max_seq_length = 200
        

        self.mask_ratio = mask_ratio

        self.loss_type =  'CE'
        self.initializer_range = 1e-2

        # load dataset info
        self.n_items = n_items
        self.n_dates = n_dates
        self.mask_token = self.n_items
        self.mask_item_length = int(self.mask_ratio * self.max_seq_length)

        # define layers and loss
        self.weekday_embedding = nn.Embedding(7, self.hidden_size)
        self.hours_embedding = nn.Embedding(24, self.hidden_size)
        self.item_embedding = nn.Embedding(self.n_items + 1, self.hidden_size, padding_idx=0)  
        self.date_embedding = nn.Embedding(self.n_dates + 1, self.hidden_size)
        self.position_embedding = nn.Embedding(self.max_seq_length + 1, self.hidden_size)  
        self.trm_encoder = TransformerEncoder(
            n_layers=self.n_layers,
            n_heads=self.n_heads,
            hidden_size=self.hidden_size,
            inner_size=self.inner_size,
            hidden_dropout_prob=self.hidden_dropout_prob,
            attn_dropout_prob=self.attn_dropout_prob,
            hidden_act=self.hidden_act,
            layer_norm_eps=self.layer_norm_eps
        )

        self.LayerNorm = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
        self.dropout = nn.Dropout(self.hidden_dropout_prob)

        try:
            assert self.loss_type in ['BPR', 'CE']
        except AssertionError:
            raise AssertionError("Make sure 'loss_type' in ['BPR', 'CE']!")

        self.apply(self._init_weights)



    def reconstruct_test_data(self,
                              item_seq,
                              item_seq_len,
                              date,
                              dow,
                              hours,
                              date_valid,
                              dow_valid,
                              hours_valid,
                              particular_day=-1,
                              ):
        """
        Add mask token at the last position according to the lengths of item_seq
        """
        padding = torch.zeros(item_seq.size(0), dtype=torch.long, device=item_seq.device)  # [B]
        item_seq = torch.cat((item_seq, padding.unsqueeze(-1)), dim=-1)  # [B max_len+1]
        date = torch.cat((date, padding.unsqueeze(-1)), dim=-1)
        dow = torch.cat((dow, padding.unsqueeze(-1)), dim=-1)
        hours = torch.cat((hours, padding.unsqueeze(-1)), dim=-1)
        for batch_id, last_position in enumerate(item_seq_len):
            item_seq[batch_id][last_position] = self.mask_token
            if particular_day == -1:
                date[batch_id][last_position] = date_valid[batch_id]
                dow[batch_id][last_position] = dow_valid[batch_id]
                hours[batch_id][last_position] = hours_valid[batch_id]
            else:
                date[batch_id][last_position] = particular_day
                dow[batch_id][last_position] = particular_day
                hours[batch_id][last_position] = particular_day
        return item_seq, date, dow, hours

    def forward(self, item_seq, date, dow, hours, return_explanations=False):
        
        
        dow_embeddings = self.weekday_embedding(dow.long())
        hours_embeddings = self.hours_embedding(hours.long())
        date_embeddings = self.date_embedding(date.long())
        
        position_ids = torch.arange(item_seq.size(1), dtype=torch.long, device=item_seq.device)
        position_ids = position_ids.unsqueeze(0).expand_as(item_seq)
        position_embedding = self.position_embedding(position_ids)
        item_emb = self.item_embedding(item_seq)
        input_emb = item_emb + position_embedding + dow_embeddings + hours_embeddings + date_embeddings
        input_emb = self.LayerNorm(input_emb)
        input_emb = self.dropout(input_emb)
        extended_attention_mask = self.get_attention_mask(item_seq)
        if return_explanations:
            trm_output, explanations = self.trm_encoder(input_emb, extended_attention_mask, output_all_encoded_layers=True,
                                         return_explanations=return_explanations)
        else:
            trm_output = self.trm_encoder(input_emb, extended_attention_mask, output_all_encoded_layers=True,
                                         return_explanations=return_explanations)
            
        output = trm_output[-1]
        
        if return_explanations:
            return output, explanations
        else:
            return output


    def calculate_loss(self, interaction):
        item_seq = interaction[self.ITEM_SEQ].long()
        masked_item_seq, pos_items, neg_items, masked_index = self.reconstruct_train_data(item_seq)

        seq_output = self.forward(masked_item_seq, date=interaction['date'], dow=interaction['dow'], hours=interaction['hours'])
        pred_index_map = self.multi_hot_embed(masked_index, masked_item_seq.size(-1))  
        pred_index_map = pred_index_map.view(masked_index.size(0), masked_index.size(1), -1)  
        seq_output = torch.bmm(pred_index_map, seq_output)  

        if self.loss_type == 'BPR':
            pos_items_emb = self.item_embedding(pos_items)  
            neg_items_emb = self.item_embedding(neg_items)  
            pos_score = torch.sum(seq_output * pos_items_emb, dim=-1) 
            neg_score = torch.sum(seq_output * neg_items_emb, dim=-1)  
            targets = (masked_index > 0).float()
            loss = - torch.sum(torch.log(1e-14 + torch.sigmoid(pos_score - neg_score)) * targets) \
                   / torch.sum(targets)
            return loss

        elif self.loss_type == 'CE':
            loss_fct = nn.CrossEntropyLoss(reduction='none')
            test_item_emb = self.item_embedding.weight[:self.n_items]  
            logits = torch.matmul(seq_output, test_item_emb.transpose(0, 1))  
            targets = (masked_index > 0).float().view(-1) 

            loss = torch.sum(loss_fct(logits.view(-1, test_item_emb.size(0)), pos_items.view(-1)) * targets) \
                   / torch.sum(targets)
            return loss
        else:
            raise NotImplementedError("Make sure 'loss_type' in ['BPR', 'CE']!")


    def full_sort_predict(self, 
                          interaction,
                          return_explanations=False,
                          particular_day=-1):
        
        item_seq = interaction[self.ITEM_SEQ].long()
        item_seq_len = interaction[self.ITEM_SEQ_LEN].long()
        item_seq, date, dow, hours = self.reconstruct_test_data(item_seq,
                                              item_seq_len,
                                              date=interaction['date'],
                                              dow=interaction['dow'],
                                              hours=interaction['hours'],
                                              date_valid=interaction['date_valid'].long(),
                                              dow_valid=interaction['dow_valid'].long(),
                                              hours_valid=interaction['hours_valid'].long(),
                                              particular_day=particular_day)
        
        
        if return_explanations:
            seq_output, expl = self.forward(item_seq,
                                            date=date,
                                            dow=dow,
                                            hours=hours,
                                            return_explanations=return_explanations)
        else:
            seq_output = self.forward(item_seq,
                                      date=date,
                                      dow=dow,
                                      hours=hours,
                                      return_explanations=return_explanations)
            
        
        seq_output = self.gather_indexes(seq_output, item_seq_len - 1) 
        test_items_emb = self.item_embedding.weight[:self.n_items]  
        scores = torch.matmul(seq_output, test_items_emb.transpose(0, 1))  
                
        idxs = item_seq.nonzero()
        item_seq[item_seq==self.n_items] = 0
        scores[idxs[:,0], item_seq[idxs[:,0],idxs[:,1]].long()] = -1000

        if return_explanations:
            return scores, expl
        else:
            return scores

In [60]:
loaders = {
        "train": DataLoader(train, batch_size=256, collate_fn=collate_fn_train),
        "valid": DataLoader(valid, batch_size=256, collate_fn=collate_fn_valid),
}

In [63]:
model = ContextBERT4Rec(n_items=len(item2idx)+1, n_dates=n_dates, mask_ratio=0.2, hidden_size=128)

optimizer = optim.Adam(model.parameters(), lr=0.01)
lr_scheduler = StepLR(optimizer, step_size=20, gamma=0.1)
engine = dl.DeviceEngine('cpu')
hparams = {
    "anneal_cap": 0.2,
    "total_anneal_steps": 6000,
}


callbacks = [
    dl.NDCGCallback("logits", "targets", [20]),
    dl.MAPCallback("logits", "targets", [10]),
    dl.OptimizerCallback("loss", accumulation_steps=1),
    dl.EarlyStoppingCallback(
        patience=5, loader_key="valid", metric_key="map10", minimize=False
    )
]


runner = RecSysRunner()
runner.train(
    model=model,
    optimizer=optimizer,
    engine=engine,
    hparams=hparams,
    scheduler=lr_scheduler,
    loaders=loaders,
    num_epochs=100,
    verbose=True,
    timeit=True,
    callbacks=callbacks,
)


1/100 * Epoch (train): 100%|██████████| 24/24 [00:27<00:00,  1.16s/it, _timer/_fps=173.094, _timer/batch_time=0.878, _timer/data_time=0.519, _timer/model_time=0.359, loss=7.394, lr=0.010, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00]


train (1/100) loss: 7.536861446835347 | lr: 0.01 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


1/100 * Epoch (valid): 100%|██████████| 24/24 [00:13<00:00,  1.78it/s, _timer/_fps=354.476, _timer/batch_time=0.429, _timer/data_time=0.026, _timer/model_time=0.403, loss=7.323, lr=0.010, map10=0.102, momentum=0.900, ndcg20=0.059]


valid (1/100) loss: 7.427616329066801 | lr: 0.01 | map10: 0.12805838611544365 | map10/std: 0.018024661106085162 | momentum: 0.9 | ndcg20: 0.06925669961812482 | ndcg20/std: 0.007266366313612276
* Epoch (1/100) 


2/100 * Epoch (train): 100%|██████████| 24/24 [00:31<00:00,  1.32s/it, _timer/_fps=178.983, _timer/batch_time=0.849, _timer/data_time=0.486, _timer/model_time=0.363, loss=7.345, lr=0.010, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00]


train (2/100) loss: 7.444541586945389 | lr: 0.01 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


2/100 * Epoch (valid): 100%|██████████| 24/24 [00:12<00:00,  1.87it/s, _timer/_fps=449.461, _timer/batch_time=0.338, _timer/data_time=0.021, _timer/model_time=0.318, loss=7.319, lr=0.010, map10=0.109, momentum=0.900, ndcg20=0.062]


valid (2/100) loss: 7.41803659919082 | lr: 0.01 | map10: 0.13242110419549685 | map10/std: 0.016910505083480264 | momentum: 0.9 | ndcg20: 0.07515916410364853 | ndcg20/std: 0.007424306428190924
* Epoch (2/100) 


3/100 * Epoch (train): 100%|██████████| 24/24 [00:26<00:00,  1.10s/it, _timer/_fps=168.175, _timer/batch_time=0.904, _timer/data_time=0.532, _timer/model_time=0.372, loss=7.364, lr=0.010, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00]


train (3/100) loss: 7.4214393716774225 | lr: 0.01 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


3/100 * Epoch (valid): 100%|██████████| 24/24 [00:12<00:00,  1.88it/s, _timer/_fps=445.437, _timer/batch_time=0.341, _timer/data_time=0.021, _timer/model_time=0.320, loss=7.335, lr=0.010, map10=0.101, momentum=0.900, ndcg20=0.064]


valid (3/100) loss: 7.388891737982138 | lr: 0.01 | map10: 0.13336395080515878 | map10/std: 0.017215231857943103 | momentum: 0.9 | ndcg20: 0.07408513371518118 | ndcg20/std: 0.007605365468550447
* Epoch (3/100) 


4/100 * Epoch (train): 100%|██████████| 24/24 [00:25<00:00,  1.06s/it, _timer/_fps=185.563, _timer/batch_time=0.819, _timer/data_time=0.458, _timer/model_time=0.361, loss=7.334, lr=0.010, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00]


train (4/100) loss: 7.4067662971698685 | lr: 0.01 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


4/100 * Epoch (valid): 100%|██████████| 24/24 [00:12<00:00,  1.90it/s, _timer/_fps=427.828, _timer/batch_time=0.355, _timer/data_time=0.021, _timer/model_time=0.334, loss=7.346, lr=0.010, map10=0.117, momentum=0.900, ndcg20=0.067]


valid (4/100) loss: 7.400157096685952 | lr: 0.01 | map10: 0.1368013177585128 | map10/std: 0.016712762275400046 | momentum: 0.9 | ndcg20: 0.07734376441761358 | ndcg20/std: 0.00803253663960003
* Epoch (4/100) 


5/100 * Epoch (train): 100%|██████████| 24/24 [00:26<00:00,  1.09s/it, _timer/_fps=174.452, _timer/batch_time=0.871, _timer/data_time=0.501, _timer/model_time=0.370, loss=7.339, lr=0.010, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00]


train (5/100) loss: 7.378267670940881 | lr: 0.01 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


5/100 * Epoch (valid): 100%|██████████| 24/24 [00:12<00:00,  1.88it/s, _timer/_fps=460.579, _timer/batch_time=0.330, _timer/data_time=0.020, _timer/model_time=0.310, loss=7.372, lr=0.010, map10=0.108, momentum=0.900, ndcg20=0.061]


valid (5/100) loss: 7.362623564612787 | lr: 0.01 | map10: 0.12863219465246267 | map10/std: 0.016430418009910335 | momentum: 0.9 | ndcg20: 0.07577643458120871 | ndcg20/std: 0.007428773133592136
* Epoch (5/100) 


6/100 * Epoch (train): 100%|██████████| 24/24 [00:26<00:00,  1.10s/it, _timer/_fps=183.210, _timer/batch_time=0.830, _timer/data_time=0.466, _timer/model_time=0.364, loss=7.324, lr=0.010, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00]


train (6/100) loss: 7.371885016106611 | lr: 0.01 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


6/100 * Epoch (valid): 100%|██████████| 24/24 [00:12<00:00,  1.89it/s, _timer/_fps=455.050, _timer/batch_time=0.334, _timer/data_time=0.021, _timer/model_time=0.313, loss=7.353, lr=0.010, map10=0.129, momentum=0.900, ndcg20=0.069]


valid (6/100) loss: 7.398236951764845 | lr: 0.01 | map10: 0.13696408808626087 | map10/std: 0.015109525812340294 | momentum: 0.9 | ndcg20: 0.07902482970068786 | ndcg20/std: 0.007289206484509295
* Epoch (6/100) 


7/100 * Epoch (train): 100%|██████████| 24/24 [00:26<00:00,  1.09s/it, _timer/_fps=178.209, _timer/batch_time=0.853, _timer/data_time=0.492, _timer/model_time=0.361, loss=7.330, lr=0.010, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00]


train (7/100) loss: 7.364537566229207 | lr: 0.01 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


7/100 * Epoch (valid): 100%|██████████| 24/24 [00:13<00:00,  1.72it/s, _timer/_fps=349.563, _timer/batch_time=0.435, _timer/data_time=0.023, _timer/model_time=0.411, loss=7.382, lr=0.010, map10=0.119, momentum=0.900, ndcg20=0.071]


valid (7/100) loss: 7.377548524086049 | lr: 0.01 | map10: 0.13614802711846813 | map10/std: 0.01717392031770509 | momentum: 0.9 | ndcg20: 0.07831891462305524 | ndcg20/std: 0.008011540049322263
* Epoch (7/100) 


8/100 * Epoch (train): 100%|██████████| 24/24 [00:27<00:00,  1.16s/it, _timer/_fps=182.632, _timer/batch_time=0.832, _timer/data_time=0.472, _timer/model_time=0.361, loss=7.309, lr=0.010, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00]


train (8/100) loss: 7.353062136441666 | lr: 0.01 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


8/100 * Epoch (valid): 100%|██████████| 24/24 [00:13<00:00,  1.78it/s, _timer/_fps=403.384, _timer/batch_time=0.377, _timer/data_time=0.024, _timer/model_time=0.352, loss=7.309, lr=0.010, map10=0.130, momentum=0.900, ndcg20=0.071]


valid (8/100) loss: 7.3192568778991705 | lr: 0.01 | map10: 0.14380098845785028 | map10/std: 0.015081447412352855 | momentum: 0.9 | ndcg20: 0.08099902511037739 | ndcg20/std: 0.007553835326096513
* Epoch (8/100) 


9/100 * Epoch (train): 100%|██████████| 24/24 [00:28<00:00,  1.18s/it, _timer/_fps=180.187, _timer/batch_time=0.844, _timer/data_time=0.477, _timer/model_time=0.366, loss=7.297, lr=0.010, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00]


train (9/100) loss: 7.35384136983101 | lr: 0.01 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


9/100 * Epoch (valid): 100%|██████████| 24/24 [00:13<00:00,  1.79it/s, _timer/_fps=418.940, _timer/batch_time=0.363, _timer/data_time=0.024, _timer/model_time=0.339, loss=7.288, lr=0.010, map10=0.117, momentum=0.900, ndcg20=0.066]


valid (9/100) loss: 7.3198788282887035 | lr: 0.01 | map10: 0.13838831229715157 | map10/std: 0.017002189835316494 | momentum: 0.9 | ndcg20: 0.07992180320403434 | ndcg20/std: 0.0076521321793722495
* Epoch (9/100) 


10/100 * Epoch (train): 100%|██████████| 24/24 [00:29<00:00,  1.22s/it, _timer/_fps=140.758, _timer/batch_time=1.080, _timer/data_time=0.673, _timer/model_time=0.407, loss=7.281, lr=0.010, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00]


train (10/100) loss: 7.34513983694923 | lr: 0.01 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


10/100 * Epoch (valid): 100%|██████████| 24/24 [00:13<00:00,  1.75it/s, _timer/_fps=451.729, _timer/batch_time=0.336, _timer/data_time=0.023, _timer/model_time=0.313, loss=7.265, lr=0.010, map10=0.128, momentum=0.900, ndcg20=0.068]


valid (10/100) loss: 7.30663888943906 | lr: 0.01 | map10: 0.14827104905583208 | map10/std: 0.014963807276094919 | momentum: 0.9 | ndcg20: 0.08204439229128377 | ndcg20/std: 0.007086015865302367
* Epoch (10/100) 


11/100 * Epoch (train): 100%|██████████| 24/24 [00:28<00:00,  1.19s/it, _timer/_fps=179.093, _timer/batch_time=0.849, _timer/data_time=0.460, _timer/model_time=0.389, loss=7.268, lr=0.010, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00]


train (11/100) loss: 7.340847535796513 | lr: 0.01 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


11/100 * Epoch (valid): 100%|██████████| 24/24 [00:13<00:00,  1.81it/s, _timer/_fps=392.955, _timer/batch_time=0.387, _timer/data_time=0.026, _timer/model_time=0.361, loss=7.261, lr=0.010, map10=0.107, momentum=0.900, ndcg20=0.066]


valid (11/100) loss: 7.318968758362018 | lr: 0.01 | map10: 0.13772945119845162 | map10/std: 0.016152898229630758 | momentum: 0.9 | ndcg20: 0.07973823644072806 | ndcg20/std: 0.007371894774589045
* Epoch (11/100) 


12/100 * Epoch (train): 100%|██████████| 24/24 [00:27<00:00,  1.15s/it, _timer/_fps=186.073, _timer/batch_time=0.817, _timer/data_time=0.438, _timer/model_time=0.378, loss=7.264, lr=0.010, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00]


train (12/100) loss: 7.338669200922481 | lr: 0.01 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


12/100 * Epoch (valid): 100%|██████████| 24/24 [00:12<00:00,  1.85it/s, _timer/_fps=457.571, _timer/batch_time=0.332, _timer/data_time=0.022, _timer/model_time=0.310, loss=7.252, lr=0.010, map10=0.123, momentum=0.900, ndcg20=0.070]


valid (12/100) loss: 7.317331356402264 | lr: 0.01 | map10: 0.14219064494434572 | map10/std: 0.01487801719848606 | momentum: 0.9 | ndcg20: 0.07935756627494928 | ndcg20/std: 0.006998321477483123
* Epoch (12/100) 


13/100 * Epoch (train): 100%|██████████| 24/24 [00:26<00:00,  1.11s/it, _timer/_fps=182.742, _timer/batch_time=0.832, _timer/data_time=0.473, _timer/model_time=0.359, loss=7.292, lr=0.010, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00]


train (13/100) loss: 7.338112532381979 | lr: 0.01 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


13/100 * Epoch (valid): 100%|██████████| 24/24 [00:13<00:00,  1.84it/s, _timer/_fps=419.391, _timer/batch_time=0.362, _timer/data_time=0.022, _timer/model_time=0.340, loss=7.281, lr=0.010, map10=0.119, momentum=0.900, ndcg20=0.073]


valid (13/100) loss: 7.309316902918532 | lr: 0.01 | map10: 0.147451419151382 | map10/std: 0.01805267610169488 | momentum: 0.9 | ndcg20: 0.08181477647940844 | ndcg20/std: 0.008235468722713152
* Epoch (13/100) 


14/100 * Epoch (train): 100%|██████████| 24/24 [00:26<00:00,  1.11s/it, _timer/_fps=177.162, _timer/batch_time=0.858, _timer/data_time=0.494, _timer/model_time=0.364, loss=7.294, lr=0.010, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00]


train (14/100) loss: 7.3367607628272875 | lr: 0.01 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


14/100 * Epoch (valid): 100%|██████████| 24/24 [00:12<00:00,  1.86it/s, _timer/_fps=436.956, _timer/batch_time=0.348, _timer/data_time=0.021, _timer/model_time=0.327, loss=7.279, lr=0.010, map10=0.120, momentum=0.900, ndcg20=0.067]


valid (14/100) loss: 7.3171570784208795 | lr: 0.01 | map10: 0.14290125070028745 | map10/std: 0.016033606415020107 | momentum: 0.9 | ndcg20: 0.07931029861731245 | ndcg20/std: 0.00751089527925184
* Epoch (14/100) 


15/100 * Epoch (train): 100%|██████████| 24/24 [00:26<00:00,  1.09s/it, _timer/_fps=190.844, _timer/batch_time=0.796, _timer/data_time=0.437, _timer/model_time=0.360, loss=7.321, lr=0.010, map10=0.000e+00, momentum=0.900, ndcg20=0.000e+00]


train (15/100) loss: 7.354373474626351 | lr: 0.01 | map10: 0.0 | map10/std: 0.0 | momentum: 0.9 | ndcg20: 0.0 | ndcg20/std: 0.0


15/100 * Epoch (valid): 100%|██████████| 24/24 [00:12<00:00,  1.90it/s, _timer/_fps=412.768, _timer/batch_time=0.368, _timer/data_time=0.024, _timer/model_time=0.345, loss=7.317, lr=0.010, map10=0.131, momentum=0.900, ndcg20=0.070]

valid (15/100) loss: 7.334956389547184 | lr: 0.01 | map10: 0.14315529793303533 | map10/std: 0.017343349805413046 | momentum: 0.9 | ndcg20: 0.08156047558152912 | ndcg20/std: 0.008565876107133798
* Epoch (15/100) 





In [64]:
test_dataset = MyDataset(ds=joined, num_items=n_items, phase='test',item2idx=item2idx)


inference_loader = DataLoader(test_dataset, 
                              batch_size=joined.shape[0]//100, 
                              collate_fn=collate_fn_train,)

preds = []

for prediction in tqdm(runner.predict_loader(loader=inference_loader)):
    preds.extend(prediction.detach().cpu().numpy().tolist())
    
print(len(preds))
assert len(preds) == joined.shape[0]

joined['preds_contextbert4rec'] = preds
joined['recs_contextbert4rec_10'] = joined['preds_contextbert4rec'].apply(lambda x: np.argsort(-np.array(x))[:10])
joined['recs_contextbert4rec_10'] = joined['recs_contextbert4rec_10'].apply(lambda x: [idx2item[t-1] for t in x])
joined['recs_contextbert4rec_5'] = joined['preds_contextbert4rec'].apply(lambda x: np.argsort(-np.array(x))[:5])
joined['recs_contextbert4rec_5'] = joined['recs_contextbert4rec_5'].apply(lambda x: [idx2item[t-1] for t in x])
joined.drop(['preds_contextbert4rec'],axis=1, inplace=True)

101it [00:06, 16.11it/s]


6040


In [65]:
evaluate_recommender(joined, model_preds='recs_contextbert4rec_10')

{'ndcg': 0.14897195618042602, 'recall': 0.031896361990104}

In [66]:
evaluate_recommender(joined, model_preds='recs_contextbert4rec_5')

{'ndcg': 0.0911531299562371, 'recall': 0.01706607886901304}