In [2]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/recsys-hse/user_features.csv
/kaggle/input/recsys-hse/item_features.csv
/kaggle/input/recsys-hse/SASRec.yaml
/kaggle/input/recsys-hse/submission_sample.csv
/kaggle/input/recsys-hse/BERT4Rec.yaml
/kaggle/input/recsys-hse/events.csv
/kaggle/input/lkllklklk/submission_best.csv


In [None]:
pip install -q hydra-core --upgrade pytorch_lightning recommenders

In [3]:
import os
import sys
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
sys.path.append('../')

In [4]:
"""
Torch datasets and collate function.
"""

import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset


class LMDataset(Dataset):

    def __init__(self, df, max_length=128, num_negatives=None, full_negative_sampling=True,
                 user_col='user_id', item_col='item_id', time_col='time_idx'):

        self.max_length = max_length
        self.num_negatives = num_negatives
        self.full_negative_sampling = full_negative_sampling
        self.user_col = user_col
        self.item_col = item_col
        self.time_col = time_col

        self.data = df.sort_values(time_col).groupby(user_col)[item_col].agg(list).to_dict()
        self.user_ids = list(self.data.keys())

        if num_negatives:
            self.all_items = df[item_col].unique()

    def __len__(self):

        return len(self.data)

    def sample_negatives(self, item_sequence):

        negatives = self.all_items[~np.isin(self.all_items, item_sequence)]
        if self.full_negative_sampling:
            negatives = np.random.choice(
                negatives, size=self.num_negatives * (len(item_sequence) - 1), replace=True)
            negatives = negatives.reshape(len(item_sequence) - 1, self.num_negatives)
        else:
            negatives = np.random.choice(negatives, size=self.num_negatives, replace=False)

        return negatives


class CausalLMDataset(LMDataset):

    def __init__(self, df, max_length=128, num_negatives=None, full_negative_sampling=True,
                 user_col='user_id', item_col='item_id', time_col='time_idx',
                 label_masking_probability=0):

        super().__init__(df, max_length, num_negatives, full_negative_sampling,
                         user_col, item_col, time_col)

        self.label_masking_probability = label_masking_probability

    def __getitem__(self, idx):

        item_sequence = self.data[self.user_ids[idx]]

        if len(item_sequence) > self.max_length + 1:
            item_sequence = item_sequence[-self.max_length - 1:]

        input_ids = np.array(item_sequence[:-1])
        labels = np.array(item_sequence[1:])

        # for testing how masking labels influence performance
        if self.label_masking_probability > 0:
            mask = np.random.rand(len(labels)) < self.label_masking_probability
            labels[mask] = -100

        if self.num_negatives:
            negatives = self.sample_negatives(item_sequence)
            return {'input_ids': input_ids, 'labels': labels, 'negatives': negatives}

        return {'input_ids': input_ids, 'labels': labels}


class CausalLMPredictionDataset(LMDataset):

    def __init__(self, df, max_length=128, validation_mode=False,
                 user_col='user_id', item_col='item_id',
                 time_col='time_idx'):

        super().__init__(df, max_length=max_length, num_negatives=None,
                         user_col=user_col, item_col=item_col, time_col=time_col)

        self.validation_mode = validation_mode

    def __getitem__(self, idx):

        user_id = self.user_ids[idx]
        item_sequence = self.data[user_id]

        if self.validation_mode:
            target = item_sequence[-1]
            input_ids = item_sequence[-self.max_length-1:-1]
            item_sequence = item_sequence[:-1]

            return {'input_ids': input_ids, 'user_id': user_id,
                    'full_history': item_sequence, 'target': target}
        else:
            input_ids = item_sequence[-self.max_length:]

            return {'input_ids': input_ids, 'user_id': user_id,
                    'full_history': item_sequence}


class MaskedLMDataset(LMDataset):

    def __init__(self, df, max_length=128,
                 num_negatives=None, full_negative_sampling=True,
                 mlm_probability=0.2,
                 masking_value=1, ignore_value=-100,
                 force_last_item_masking_prob=0,
                 user_col='user_id', item_col='item_id',
                 time_col='time_idx'):

        super().__init__(df, max_length, num_negatives, full_negative_sampling,
                         user_col, item_col, time_col)

        self.mlm_probability = mlm_probability
        self.masking_value = masking_value
        self.ignore_value = ignore_value
        self.force_last_item_masking_prob = force_last_item_masking_prob

    def __getitem__(self, idx):

        item_sequence = self.data[self.user_ids[idx]]

        if len(item_sequence) > self.max_length:
            item_sequence = item_sequence[-self.max_length:]

        input_ids = np.array(item_sequence)
        mask = np.random.rand(len(item_sequence)) < self.mlm_probability
        input_ids[mask] = self.masking_value
        if self.force_last_item_masking_prob > 0:
            if np.random.rand() < self.force_last_item_masking_prob:
                input_ids[-1] = self.masking_value

        labels = np.array(item_sequence)
        labels[input_ids != self.masking_value] = self.ignore_value

        if self.num_negatives:
            negatives = self.sample_negatives(item_sequence)
            return {'input_ids': input_ids, 'labels': labels, 'negatives': negatives}

        return {'input_ids': input_ids, 'labels': labels}


class MaskedLMPredictionDataset(LMDataset):

    def __init__(self, df, max_length=128, masking_value=1,
                 validation_mode=False,
                 user_col='user_id', item_col='item_id',
                 time_col='time_idx'):

        super().__init__(df, max_length=max_length, num_negatives=None,
                         user_col=user_col, item_col=item_col, time_col=time_col)

        self.masking_value = masking_value
        self.validation_mode = validation_mode

    def __getitem__(self, idx):

        user_id = self.user_ids[idx]
        item_sequence = self.data[user_id]

        if self.validation_mode:
            target = item_sequence[-1]
            input_ids = item_sequence[-self.max_length:-1]
            item_sequence = item_sequence[:-1]
        else:
            input_ids = item_sequence[-self.max_length + 1:]

        input_ids += [self.masking_value]

        if self.validation_mode:
            return {'input_ids': input_ids, 'user_id': user_id,
                    'full_history': item_sequence, 'target': target}
        else:
            return {'input_ids': input_ids, 'user_id': user_id,
                    'full_history': item_sequence}


class PaddingCollateFn:

    def __init__(self, padding_value=0, labels_padding_value=-100):

        self.padding_value = padding_value
        self.labels_padding_value = labels_padding_value

    def __call__(self, batch):

        collated_batch = {}

        for key in batch[0].keys():

            if np.isscalar(batch[0][key]):
                collated_batch[key] = torch.tensor([example[key] for example in batch])
                continue

            if key == 'labels':
                padding_value = self.labels_padding_value
            else:
                padding_value = self.padding_value
            values = [torch.tensor(example[key]) for example in batch]
            collated_batch[key] = pad_sequence(values, batch_first=True,
                                               padding_value=padding_value)

        if 'input_ids' in collated_batch:
            attention_mask = collated_batch['input_ids'] != self.padding_value
            collated_batch['attention_mask'] = attention_mask.to(dtype=torch.float32)  

        return collated_batch

In [5]:
"""
Metrics.
"""

import numpy as np
import torch
from recommenders.evaluation.python_evaluation import map_at_k, ndcg_at_k, recall_at_k
from tqdm.auto import tqdm


def compute_metrics(ground_truth, preds, k=10):

    if not hasattr(ground_truth, 'rating'):
        ground_truth = ground_truth.assign(rating=1)

    # when we have 1 true positive, HitRate == Recall and MRR == MAP
    metrics = {
        f'ndcg@{k}': ndcg_at_k(ground_truth, preds, col_user='user_id', col_item='item_id',
                               col_prediction='prediction', col_rating='rating', k=k),
        f'hit_rate@{k}': recall_at_k(ground_truth, preds, col_user='user_id', col_item='item_id',
                                     col_prediction='prediction', col_rating='rating', k=k),
        f'mrr@{k}': map_at_k(ground_truth, preds, col_user='user_id', col_item='item_id',
                             col_prediction='prediction', col_rating='rating', k=k),
        f'rec@{k}': recall_at_k(ground_truth, preds, col_user='user_id', col_item='item_id',
                             col_prediction='prediction', col_rating='rating', k=k)
    }

    return metrics


def compute_sampled_metrics(seqrec_module, predict_dataset, test, item_counts,
                            popularity_sampling=True, num_negatives=100, k=10,
                            device='cuda'):

    test = test.set_index('user_id')['item_id'].to_dict()
    all_items = item_counts.index.values
    item_weights = item_counts.values
    # probabilities = item_weights/item_weights.sum()

    seqrec_module = seqrec_module.eval().to(device)

    ndcg, hit_rate, mrr = 0.0, 0.0, 0.0
    user_count = 0

    for user in tqdm(predict_dataset):

        if user['user_id'] not in test:
            continue

        positive = test[user['user_id']]
        indices = ~np.isin(all_items, user['full_history'])
        negatives = all_items[indices]
        if popularity_sampling:
            probabilities = item_weights[indices]
            probabilities = probabilities/probabilities.sum()
        else:
            probabilities = None
        negatives = np.random.choice(negatives, size=num_negatives,
                                     replace=False, p=probabilities)
        items = np.concatenate([np.array([positive]), negatives])

        batch = {'input_ids': torch.tensor(user['input_ids']).unsqueeze(0).to(device),
                 'attention_mask': torch.tensor([1] * len(user['input_ids'])).unsqueeze(0).to(device)}
        pred = seqrec_module.prediction_output(batch)
        pred = pred[0, -1, items]

        rank = (-pred).argsort().argsort()[0].item() + 1
        if rank <= k:
            ndcg += 1 / np.log2(rank + 1)
            hit_rate += 1
            mrr += 1 / rank
        user_count += 1

    ndcg = ndcg / user_count
    hit_rate = hit_rate / user_count
    mrr = mrr / user_count

    return {'ndcg': ndcg, 'hit_rate': hit_rate, 'mrr': mrr}

In [6]:
"""
Models.
"""

import numpy as np
import torch
from torch import nn
from transformers import BertConfig, BertModel, GPT2Config, GPT2Model


class BERT4Rec(nn.Module):

    def __init__(self, vocab_size, bert_config, add_head=True,
                 tie_weights=True, padding_idx=0, init_std=0.02):

        super().__init__()

        self.vocab_size = vocab_size
        self.bert_config = bert_config
        self.add_head = add_head
        self.tie_weights = tie_weights
        self.padding_idx = padding_idx
        self.init_std = init_std

        self.embed_layer = nn.Embedding(num_embeddings=vocab_size,
                                        embedding_dim=bert_config['hidden_size'],
                                        padding_idx=padding_idx)
        self.transformer_model = BertModel(BertConfig(**bert_config))

        if self.add_head:
            self.head = nn.Linear(bert_config['hidden_size'], vocab_size, bias=False)
            if self.tie_weights:
                self.head.weight = self.embed_layer.weight

        self.init_weights()

    def init_weights(self):

        self.embed_layer.weight.data.normal_(mean=0.0, std=self.init_std)
        if self.padding_idx is not None:
            self.embed_layer.weight.data[self.padding_idx].zero_()

    def forward(self, input_ids, attention_mask):

        embeds = self.embed_layer(input_ids)
        transformer_outputs = self.transformer_model(
            inputs_embeds=embeds, attention_mask=attention_mask)
        outputs = transformer_outputs.last_hidden_state

        if self.add_head:
            outputs = self.head(outputs)

        return outputs


class SASRec(nn.Module):
    """Adaptation of code from
    https://github.com/pmixer/SASRec.pytorch.
    """

    def __init__(self, item_num, maxlen=128, hidden_units=64, num_blocks=1,
                 num_heads=1, dropout_rate=0.1, initializer_range=0.02,
                 add_head=True):

        super(SASRec, self).__init__()

        self.item_num = item_num
        self.maxlen = maxlen
        self.hidden_units = hidden_units
        self.num_blocks = num_blocks
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate
        self.initializer_range = initializer_range
        self.add_head = add_head

        self.item_emb = nn.Embedding(item_num + 1, hidden_units, padding_idx=0)
        self.pos_emb = nn.Embedding(maxlen, hidden_units)
        self.emb_dropout = nn.Dropout(dropout_rate)

        self.attention_layernorms = nn.ModuleList() # to be Q for self-attention
        self.attention_layers = nn.ModuleList()
        self.forward_layernorms = nn.ModuleList()
        self.forward_layers = nn.ModuleList()

        self.last_layernorm = nn.LayerNorm(hidden_units, eps=1e-8)

        for _ in range(num_blocks):
            new_attn_layernorm = nn.LayerNorm(hidden_units, eps=1e-8)
            self.attention_layernorms.append(new_attn_layernorm)

            new_attn_layer = nn.MultiheadAttention(hidden_units,
                                                   num_heads,
                                                   dropout_rate)
            self.attention_layers.append(new_attn_layer)

            new_fwd_layernorm = nn.LayerNorm(hidden_units, eps=1e-8)
            self.forward_layernorms.append(new_fwd_layernorm)

            new_fwd_layer = PointWiseFeedForward(hidden_units, dropout_rate)
            self.forward_layers.append(new_fwd_layer)

        # parameters initialization
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize weights.

        Examples:
        https://github.com/huggingface/transformers/blob/v4.25.1/src/transformers/models/gpt2/modeling_gpt2.py#L454
        https://recbole.io/docs/_modules/recbole/model/sequential_recommender/sasrec.html#SASRec
        """

        if isinstance(module, (nn.Linear, nn.Conv1d)):
            module.weight.data.normal_(mean=0.0, std=self.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    # parameter attention mask added for compatibility with Lightning module, not used
    def forward(self, input_ids, attention_mask):

        seqs = self.item_emb(input_ids)
        seqs *= self.item_emb.embedding_dim ** 0.5
        positions = np.tile(np.array(range(input_ids.shape[1])), [input_ids.shape[0], 1])
        # need to be on the same device
        seqs += self.pos_emb(torch.LongTensor(positions).to(seqs.device))
        seqs = self.emb_dropout(seqs)

        timeline_mask = torch.Tensor(input_ids == 0)
        seqs *= ~timeline_mask.unsqueeze(-1) # broadcast in last dim

        tl = seqs.shape[1] 
        attention_mask = ~torch.tril(torch.ones((tl, tl), dtype=torch.bool).to(seqs.device))

        for i in range(len(self.attention_layers)):
            seqs = torch.transpose(seqs, 0, 1)
            Q = self.attention_layernorms[i](seqs)
            mha_outputs, _ = self.attention_layers[i](Q, seqs, seqs, 
                                            attn_mask=attention_mask)
                                            # key_padding_mask=timeline_mask
                                            # need_weights=False) this arg do not work?
            seqs = Q + mha_outputs
            seqs = torch.transpose(seqs, 0, 1)

            seqs = self.forward_layernorms[i](seqs)
            seqs = self.forward_layers[i](seqs)
            seqs *=  ~timeline_mask.unsqueeze(-1)

        outputs = self.last_layernorm(seqs) # (U, T, C) -> (U, -1, C)
        if self.add_head:
            outputs = torch.matmul(outputs, self.item_emb.weight.transpose(0, 1))

        return outputs


class PointWiseFeedForward(nn.Module):

    def __init__(self, hidden_units, dropout_rate):

        super(PointWiseFeedForward, self).__init__()

        self.conv1 = nn.Conv1d(hidden_units, hidden_units, kernel_size=1)
        self.dropout1 = nn.Dropout(p=dropout_rate)
        self.relu = nn.SiLU()
        self.conv2 = nn.Conv1d(hidden_units, hidden_units, kernel_size=1)
        self.dropout2 = nn.Dropout(p=dropout_rate)

    def forward(self, inputs):
        outputs = self.dropout2(
            self.conv2(self.relu(self.dropout1(self.conv1(inputs.transpose(-1, -2))))))
        outputs = outputs.transpose(-1, -2) # as Conv1D requires (N, C, Length)
        outputs += inputs
        return outputs

In [7]:
"""
Filter interactions.
"""


def add_time_idx(df, user_col='user_id', timestamp_col='timestamp', sort=True):
    """Add time index to interactions dataframe."""

    if sort:
        df = df.sort_values([user_col, timestamp_col])

    df['time_idx'] = df.groupby(user_col).cumcount()
    df['time_idx_reversed'] = df.groupby(user_col).cumcount(ascending=False)

    return df


def filter_items(df, item_min_count, item_col='item_id'):

    print('Filtering items..')

    item_count = df.groupby(item_col).user_id.nunique()

    item_ids = item_count[item_count >= item_min_count].index
    print(f'Number of items before {len(item_count)}')
    print(f'Number of items after {len(item_ids)}')

    print(f'Interactions length before: {len(df)}')
    df = df[df.item_id.isin(item_ids)]
    print(f'Interactions length after: {len(df)}')

    return df


def filter_users(df, user_min_count, user_col='user_id'):

    print('Filtering users..')

    user_count = df.groupby(user_col).item_id.nunique()

    user_ids = user_count[user_count >= user_min_count].index
    print(f'Number of users before {len(user_count)}')
    print(f'Number of users after {len(user_ids)}')

    print(f'Interactions length before: {len(df)}')
    df = df[df.user_id.isin(user_ids)]
    print(f'Interactions length after: {len(df)}')

    return df

In [8]:
"""
Postprocessing.
"""

import numpy as np
import pandas as pd


def preds2recs(preds, item_mapping=None):

    user_ids = np.hstack([pred['user_ids'] for pred in preds])
    scores = np.vstack([pred['scores'] for pred in preds])
    preds = np.vstack([pred['preds'] for pred in preds])

    user_ids = np.repeat(user_ids[:, None], repeats=scores.shape[1], axis=1)

    recs = pd.DataFrame({'user_id': user_ids.flatten(),
                         'item_id': preds.flatten(),
                         'prediction': scores.flatten()})

    if item_mapping is not None:
        recs.item_id = recs.item_id.map(item_mapping)

    return recs

In [9]:
"""
Pytorch Lightning Modules.
"""

import numpy as np
import pytorch_lightning as pl
import torch
from torch import nn


class SeqRecBase(pl.LightningModule):

    def __init__(self, model, lr=1e-3, padding_idx=0,
                 predict_top_k=10, filter_seen=True):

        super().__init__()

        self.model = model
        self.lr = lr
        self.padding_idx = padding_idx
        self.predict_top_k = predict_top_k
        self.filter_seen = filter_seen

    def configure_optimizers(self):

        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

    def predict_step(self, batch, batch_idx):

        preds, scores = self.make_prediction(batch)

        scores = scores.detach().cpu().numpy()
        preds = preds.detach().cpu().numpy()
        user_ids = batch['user_id'].detach().cpu().numpy()

        return {'preds': preds, 'scores': scores, 'user_ids': user_ids}

    def validation_step(self, batch, batch_idx):

        preds, scores = self.make_prediction(batch)
        metrics = self.compute_val_metrics(batch['target'], preds)

        self.log("val_ndcg", metrics['ndcg'], prog_bar=True)
        self.log("val_hit_rate", metrics['hit_rate'], prog_bar=True)
        self.log("val_mrr", metrics['mrr'], prog_bar=True)

    def make_prediction(self, batch):

        outputs = self.prediction_output(batch)

        input_ids = batch['input_ids']
        rows_ids = torch.arange(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
        last_item_idx = (input_ids != self.padding_idx).sum(axis=1) - 1

        preds = outputs[rows_ids, last_item_idx, :]

        scores, preds = torch.sort(preds, descending=True)

        if self.filter_seen:
            seen_items = batch['full_history']
            preds, scores = self.filter_seen_items(preds, scores, seen_items)
        else:
            scores = scores[:, :self.predict_top_k]
            preds = preds[:, :self.predict_top_k]

        return preds, scores

    def filter_seen_items(self, preds, scores, seen_items):

        max_len = seen_items.size(1)
        scores = scores[:, :self.predict_top_k + max_len]
        preds = preds[:, :self.predict_top_k + max_len]

        final_preds, final_scores = [], []
        for i in range(preds.size(0)):
            not_seen_indexes = torch.isin(preds[i], seen_items[i], invert=True)
            pred = preds[i, not_seen_indexes][:self.predict_top_k]
            score = scores[i, not_seen_indexes][:self.predict_top_k]
            final_preds.append(pred)
            final_scores.append(score)

        final_preds = torch.vstack(final_preds)
        final_scores = torch.vstack(final_scores)

        return final_preds, final_scores

    def compute_val_metrics(self, targets, preds):

        ndcg, hit_rate, mrr, precision = 0, 0, 0, 0

        for i, pred in enumerate(preds):
            if torch.isin(targets[i], pred).item():
                hit_rate += 1
                rank = torch.where(pred == targets[i])[0].item() + 1
                ndcg += 1 / np.log2(rank + 1)
                mrr += 1 / rank

        hit_rate = hit_rate / len(targets)
        ndcg = ndcg / len(targets)
        mrr = mrr / len(targets)

        return {'ndcg': ndcg, 'hit_rate': hit_rate, 'mrr': mrr}


class SeqRec(SeqRecBase):

    def training_step(self, batch, batch_idx):

        outputs = self.model(batch['input_ids'], batch['attention_mask'])
        loss = self.compute_loss(outputs, batch)

        return loss

    def compute_loss(self, outputs, batch):

        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(outputs.view(-1, outputs.size(-1)), batch['labels'].view(-1))

        return loss

    def prediction_output(self, batch):

        return self.model(batch['input_ids'], batch['attention_mask'])


class SeqRecWithSampling(SeqRec):

    def __init__(self, model, lr=1e-3, loss='cross_entropy',
                 padding_idx=0, predict_top_k=10, filter_seen=True):

        super().__init__(model, lr, padding_idx, predict_top_k, filter_seen)

        self.loss = loss

        if hasattr(self.model, 'item_emb'):  # for SASRec
            self.embed_layer = self.model.item_emb
        elif hasattr(self.model, 'embed_layer'):
            self.embed_layer = self.model.embed_layer


    def compute_loss(self, outputs, batch):

        # embed  and compute logits for negatives
        if batch['negatives'].ndim == 2:  # for full_negative_sampling=False
            # [N, M, D]
            embeds_negatives = self.embed_layer(batch['negatives'].to(torch.int32))
            # [N, T, D] * [N, D, M] -> [N, T, M]
            logits_negatives = torch.matmul(outputs, embeds_negatives.transpose(1, 2))
        elif batch['negatives'].ndim == 3:  # for full_negative_sampling=True
            # [N, T, M, D]
            embeds_negatives = self.embed_layer(batch['negatives'].to(torch.int32))
            # [N, T, 1, D] * [N, T, D, M] -> [N, T, 1, M] -> -> [N, T, M]
            logits_negatives = torch.matmul(
                outputs.unsqueeze(2), embeds_negatives.transpose(2, 3)).squeeze()
            if logits_negatives.ndim == 2:
                logits_negatives = logits_negatives.unsqueeze(2)

        # embed  and compute logits for positives
        # [N, T]
        labels = batch['labels'].clone()
        labels[labels == -100] = self.padding_idx
        # [N, T, D]
        embeds_labels = self.embed_layer(labels)
        # [N, T, 1, D] * [N, T, D, 1] -> [N, T, 1, 1] -> [N, T]
        logits_labels = torch.matmul(outputs.unsqueeze(2), embeds_labels.unsqueeze(3)).squeeze()

        # concat positives and negatives
        # [N, T, M + 1]
        logits = torch.cat([logits_labels.unsqueeze(2), logits_negatives], dim=-1)

        # prepare targets for loss
        if self.loss == 'cross_entropy':
            # [N, T]
            targets = batch['labels'].clone()
            targets[targets != -100] = 0
        elif self.loss == 'bce':
            # [N, T, M + 1]
            targets = torch.zeros_like(logits)
            targets[:, :, 0] = 1

        if self.loss == 'cross_entropy':
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, logits.size(-1)), targets.view(-1))
        elif self.loss == 'bce':
            # loss_fct = nn.BCEWithLogitsLoss()
            # loss = loss_fct(logits, targets)
            loss_fct = nn.BCEWithLogitsLoss(reduction='none')
            loss = loss_fct(logits, targets)
            loss = loss[batch['labels'] != -100]
            loss = loss.mean()

        return loss

    def prediction_output(self, batch):

        outputs = self.model(batch['input_ids'], batch['attention_mask'])
        outputs = torch.matmul(outputs, self.embed_layer.weight.T)

        return outputs

In [10]:
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
from hydra import compose, initialize
from omegaconf import OmegaConf
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, ModelSummary
from torch.utils.data import DataLoader

In [14]:
# for SASRec
with initialize(version_base=None, config_path="./configs"):
    config = compose(config_name="SASRec")

In [15]:
OmegaConf.set_struct(config, False)

config.dataset.max_length = 300
config.model_params.maxlen = 300

config.dataset.num_negatives = 3000

In [16]:
print(OmegaConf.to_yaml(config))

cuda_visible_devices: 0
data_path: ../data/ml-1m.txt
dataset:
  max_length: 300
  full_negative_sampling: false
  num_negatives: 3000
dataloader:
  batch_size: 128
  test_batch_size: 256
  num_workers: 8
  validation_size: 10000
model: SASRec
model_params:
  maxlen: 300
  hidden_units: 64
  num_blocks: 2
  num_heads: 1
  dropout_rate: 0.1
seqrec_module:
  lr: 0.001
  predict_top_k: 10
  filter_seen: true
trainer_params:
  max_epochs: 100
patience: 10
sampled_metrics: false
top_k_metrics:
- 10
- 100



In [17]:
data = pd.read_csv('/kaggle/input/recsys-hse/events.csv')
item_counts = data['item_id'].value_counts()

In [18]:
#data = filter_items(data, item_counts.quantile(0.25))

In [19]:
data = add_time_idx(data, sort=False)

# index 1 is used for masking value
if config.model == 'BERT4Rec':
    data.item_id += 1

print(data.shape)
data.head()

(894149, 6)


Unnamed: 0,user_id,item_id,rating,timestamp,time_idx,time_idx_reversed
0,0,1505,4,0,0,286
1,0,3669,3,1,1,285
2,0,584,4,2,2,284
3,0,3390,3,3,3,283
4,0,2885,4,4,4,282


In [20]:
data

Unnamed: 0,user_id,item_id,rating,timestamp,time_idx,time_idx_reversed
0,0,1505,4,0,0,286
1,0,3669,3,1,1,285
2,0,584,4,2,2,284
3,0,3390,3,3,3,283
4,0,2885,4,4,4,282
...,...,...,...,...,...,...
894144,6039,1453,4,223,197,4
894145,6039,613,5,224,198,3
894146,6039,1548,4,225,199,2
894147,6039,241,4,226,200,1


In [21]:
data = data[['user_id','item_id','time_idx','time_idx_reversed']]
data

Unnamed: 0,user_id,item_id,time_idx,time_idx_reversed
0,0,1505,0,286
1,0,3669,1,285
2,0,584,2,284
3,0,3390,3,283
4,0,2885,4,282
...,...,...,...,...
894144,6039,1453,197,4
894145,6039,613,198,3
894146,6039,1548,199,2
894147,6039,241,200,1


In [22]:
train = data[data.time_idx_reversed >= 2]
validation = data[data.time_idx_reversed == 1]
validation_full = data[data.time_idx_reversed >= 1]
test = data[data.time_idx_reversed == 0]
#test = data[(data.time_idx_reversed == 1) | (data.time_idx_reversed == 0)]

In [23]:
validation_size = config.dataloader.validation_size
validation_users = validation_full.user_id.unique()
if validation_size and (validation_size < len(validation_users)):
    validation_users = np.random.choice(validation_users, size=validation_size, replace=False)

if config.model == 'SASRec':
    train_dataset = CausalLMDataset(train, **config['dataset'])
    eval_dataset = CausalLMPredictionDataset(
        validation_full[validation_full.user_id.isin(validation_users)],
        max_length=config.dataset.max_length, validation_mode=True)
elif config.model == 'BERT4Rec':
    train_dataset = MaskedLMDataset(train, **config['dataset'])
    eval_dataset = MaskedLMPredictionDataset(
        validation_full[validation_full.user_id.isin(validation_users)],
        max_length=config.dataset.max_length, validation_mode=True)

train_loader = DataLoader(
    train_dataset, shuffle=True,
    collate_fn=PaddingCollateFn(),
    batch_size=config.dataloader.batch_size,
    num_workers=config.dataloader.num_workers)
eval_loader = DataLoader(
    eval_dataset, shuffle=False,
    collate_fn=PaddingCollateFn(),
    batch_size=config.dataloader.test_batch_size,
    num_workers=config.dataloader.num_workers)



In [24]:
batch = next(iter(train_loader))
print(batch['input_ids'].shape)

torch.Size([128, 300])


In [25]:
item_count = data.item_id.max()

if hasattr(config.dataset, 'num_negatives') and config.dataset.num_negatives:
    add_head = False
else:
    add_head = True

if config.model == 'SASRec':
    model = SASRec(item_num=item_count, add_head=add_head, **config.model_params)
if config.model == 'BERT4Rec':
    model = BERT4Rec(vocab_size=item_count + 1, add_head=add_head,
                     bert_config=config.model_params)

In [26]:
out = model(batch['input_ids'], batch['attention_mask'])
out.shape

torch.Size([128, 300, 64])

In [27]:
if hasattr(config.dataset, 'num_negatives') and config.dataset.num_negatives:
    seqrec_module = SeqRecWithSampling(model, **config['seqrec_module'])
else:
    seqrec_module = SeqRec(model, **config['seqrec_module'])
    
early_stopping = EarlyStopping(monitor="val_hit_rate", mode="max",
                               patience=config.patience, verbose=False)
model_summary = ModelSummary(max_depth=2)
checkpoint = ModelCheckpoint(save_top_k=1, monitor="val_hit_rate",
                             mode="max", save_weights_only=True)
callbacks=[early_stopping, model_summary, checkpoint]

trainer = pl.Trainer(callbacks=callbacks, enable_checkpointing=True, log_every_n_steps=48,
                      **config['trainer_params'])

trainer.fit(model=seqrec_module,
            train_dataloaders=train_loader,
            val_dataloaders=eval_loader)

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

  self.pid = os.fork()


Training: |          | 0/? [00:00<?, ?it/s]

  self.pid = os.fork()


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

In [28]:
seqrec_module.load_state_dict(torch.load(checkpoint.best_model_path)['state_dict'])

  seqrec_module.load_state_dict(torch.load(checkpoint.best_model_path)['state_dict'])


<All keys matched successfully>

In [29]:
if config.model == 'SASRec':
    predict_dataset = CausalLMPredictionDataset(train, max_length=config.dataset.max_length)
elif config.model  == 'BERT4Rec':
    predict_dataset = MaskedLMPredictionDataset(train, max_length=config.dataset.max_length)

predict_loader = DataLoader(
        predict_dataset, shuffle=False,
        collate_fn=PaddingCollateFn(),
        batch_size=config.dataloader.test_batch_size,
        num_workers=config.dataloader.num_workers)

seqrec_module.predict_top_k = max(config.top_k_metrics)
preds = trainer.predict(model=seqrec_module, dataloaders=predict_loader)

recs = preds2recs(preds)
print(recs.shape)
recs.head()

Predicting: |          | 0/? [00:00<?, ?it/s]

(604000, 3)


Unnamed: 0,user_id,item_id,prediction
0,0,2397,5.370092
1,0,1052,5.327501
2,0,1149,4.887436
3,0,3660,4.646952
4,0,3657,4.585503


In [30]:
for k in config.top_k_metrics:
    metrics_val = compute_metrics(validation, recs, k=k)
    print('k = ', k)
    print(metrics_val)

  df_hit.groupby(col_user, as_index=False)[col_user].agg({"hit": "count"}),
  rating_true_common.groupby(col_user, as_index=False)[col_user].agg(


k =  10
{'ndcg@10': 0.1645274776603075, 'hit_rate@10': 0.30596026490066225, 'mrr@10': 0.12155517449805528, 'rec@10': 0.30596026490066225}


  df_hit.groupby(col_user, as_index=False)[col_user].agg({"hit": "count"}),
  rating_true_common.groupby(col_user, as_index=False)[col_user].agg(


k =  100
{'ndcg@100': 0.24897708019162168, 'hit_rate@100': 0.7168874172185431, 'mrr@100': 0.1374267546473959, 'rec@100': 0.7168874172185431}


In [31]:
if config.model == 'SASRec':
    test_predict_dataset = CausalLMPredictionDataset(validation_full, max_length=config.dataset.max_length)
elif config.model  == 'BERT4Rec':
    test_predict_dataset = MaskedLMPredictionDataset(validation_full, max_length=config.dataset.max_length)
    
test_predict_loader = DataLoader(
        test_predict_dataset, shuffle=False,
        collate_fn=PaddingCollateFn(),
        batch_size=config.dataloader.test_batch_size,
        num_workers=config.dataloader.num_workers)

seqrec_module.predict_top_k = max(config.top_k_metrics)
preds_test = trainer.predict(model=seqrec_module, dataloaders=test_predict_loader)

recs_test = preds2recs(preds_test)
print(recs_test.shape)
recs_test.head()

  self.pid = os.fork()


Predicting: |          | 0/? [00:00<?, ?it/s]

  self.pid = os.fork()


(604000, 3)


Unnamed: 0,user_id,item_id,prediction
0,0,1461,4.502969
1,0,3421,4.169186
2,0,1332,4.122826
3,0,2397,3.698641
4,0,2480,3.600317


In [32]:
for k in config.top_k_metrics:
    metrics_test = compute_metrics(test, recs_test, k=k)
    print('k = ', k)
    print(metrics_test)

  df_hit.groupby(col_user, as_index=False)[col_user].agg({"hit": "count"}),
  rating_true_common.groupby(col_user, as_index=False)[col_user].agg(


k =  10
{'ndcg@10': 0.16467296460923264, 'hit_rate@10': 0.2993377483443709, 'mrr@10': 0.12366853253442656, 'rec@10': 0.2993377483443709}


  df_hit.groupby(col_user, as_index=False)[col_user].agg({"hit": "count"}),
  rating_true_common.groupby(col_user, as_index=False)[col_user].agg(


k =  100
{'ndcg@100': 0.24654160480381876, 'hit_rate@100': 0.6991721854304636, 'mrr@100': 0.1389707304112343, 'rec@100': 0.6991721854304636}


In [33]:
top_10_recs = recs_test.groupby('user_id').head(10)
subs = top_10_recs.groupby('user_id')['item_id'].apply(list).reset_index()
subs

Unnamed: 0,user_id,item_id
0,0,"[1461, 3421, 1332, 2397, 2480, 2606, 3271, 150..."
1,1,"[232, 1246, 1686, 3656, 452, 1822, 3101, 933, ..."
2,2,"[2354, 2774, 1687, 1560, 1781, 452, 382, 724, ..."
3,3,"[1560, 989, 605, 1316, 2564, 3272, 461, 67, 10..."
4,4,"[802, 2185, 3035, 2814, 1868, 1337, 270, 702, ..."
...,...,...
6035,6035,"[2054, 1811, 2256, 2646, 2366, 2606, 3529, 305..."
6036,6036,"[3692, 3142, 2054, 2732, 401, 3013, 772, 3529,..."
6037,6037,"[2256, 3059, 2732, 1102, 1379, 2833, 2502, 196..."
6038,6038,"[450, 1893, 3309, 2420, 2305, 3605, 405, 79, 8..."


In [34]:
def array_to_str(arr):
    return ' '.join(map(str, arr))

subs['item_id'] = subs['item_id'].apply(array_to_str)
subs

Unnamed: 0,user_id,item_id
0,0,1461 3421 1332 2397 2480 2606 3271 1501 331 213
1,1,232 1246 1686 3656 452 1822 3101 933 350 1884
2,2,2354 2774 1687 1560 1781 452 382 724 221 3035
3,3,1560 989 605 1316 2564 3272 461 67 106 810
4,4,802 2185 3035 2814 1868 1337 270 702 3404 1687
...,...,...
6035,6035,2054 1811 2256 2646 2366 2606 3529 3059 3382 463
6036,6036,3692 3142 2054 2732 401 3013 772 3529 2664 1039
6037,6037,2256 3059 2732 1102 1379 2833 2502 1968 2664 2646
6038,6038,450 1893 3309 2420 2305 3605 405 79 84 1337


In [35]:
subs.to_csv('submission.csv', index=False)