In [1]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '3'

import sys
sys.path.append('../')

In [2]:
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, ModelSummary
from torch.utils.data import DataLoader

from src.datasets import (CausalLMDataset, CausalLMPredictionDataset, MaskedLMDataset,
                          MaskedLMPredictionDataset, PaddingCollateFn)
from src.metrics import compute_metrics
from src.models import RNN, BERT4Rec, SASRec
from src.modules import SeqRec, SeqRecWithSampling
from src.postprocess import preds2recs
from src.preprocess import add_time_idx
from src.unbiased_metrics import get_metrics, hr, mrr, ndcg


libgomp: Invalid value for environment variable OMP_NUM_THREADS

libgomp: Invalid value for environment variable OMP_NUM_THREADS


## Load data

In [3]:
ITEM_COL = 'item_id'
RELEVANCE_THRESHOLD = 3.5
RELEVANCE_COL = 'rating'

In [4]:
train = pd.read_csv('../data/ml-1m/train.csv')
test = pd.read_csv('../data/ml-1m/test.csv')
val_1 = pd.read_csv('../data/ml-1m/val_1.csv')
val_2 = pd.read_csv('../data/ml-1m/val_2.csv')
test_users_history = pd.read_csv('../data/ml-1m/test_users_history.csv')
val_users_history_1 = pd.read_csv('../data/ml-1m/val_users_history_1.csv')
val_users_history_2 = pd.read_csv('../data/ml-1m/val_users_history_2.csv')

In [5]:
train.item_id = train.item_id * 2
val_users_history_1.item_id = val_users_history_1.item_id * 2
val_users_history_2.item_id = val_users_history_2.item_id * 2
test_users_history.item_id = test_users_history.item_id * 2

In [6]:
train.loc[train[RELEVANCE_COL] < RELEVANCE_THRESHOLD, 'item_id'] -= 1
val_users_history_1.loc[val_users_history_1[RELEVANCE_COL] < RELEVANCE_THRESHOLD, 'item_id'] -= 1
val_users_history_2.loc[val_users_history_2[RELEVANCE_COL] < RELEVANCE_THRESHOLD, 'item_id'] -= 1
test_users_history.loc[test_users_history[RELEVANCE_COL] < RELEVANCE_THRESHOLD, 'item_id'] -= 1

## Dataloaders

In [7]:
MAX_LENGTH = 200

VALIDATION_SIZE = 10000
# VALIDATION_SIZE = None

BATCH_SIZE = 64
TEST_BATCH_SIZE = 64
NUM_WORKERS = 8

In [8]:
def get_eval_dataset(validation, validation_size=None, max_length=200):
    validation_users = validation.user_id.unique()

    if validation_size and (validation_size < len(validation_users)):
        validation_users = np.random.choice(validation_users, size=validation_size, replace=False)

    eval_dataset = CausalLMPredictionDataset(
        validation[validation.user_id.isin(validation_users)],
        max_length=max_length,
        user_col='test_user_idx',
        validation_mode=True)

    return eval_dataset

In [9]:
%%time
train_dataset = CausalLMDataset(train, user_col='user_id', max_length=MAX_LENGTH, num_negatives=3000)
val_1_dataset = get_eval_dataset(val_users_history_1, max_length=MAX_LENGTH)
val_2_dataset = get_eval_dataset(val_users_history_2, max_length=MAX_LENGTH)
test_dataset = get_eval_dataset(test_users_history, max_length=MAX_LENGTH)

train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE,
    shuffle=True, num_workers=NUM_WORKERS,
    collate_fn=PaddingCollateFn()
)
val_1_loader = DataLoader(
    val_1_dataset, batch_size=TEST_BATCH_SIZE,
    shuffle=False, num_workers=NUM_WORKERS,
    collate_fn=PaddingCollateFn()
)
val_2_loader = DataLoader(
    val_2_dataset, batch_size=TEST_BATCH_SIZE,
    shuffle=False, num_workers=NUM_WORKERS,
    collate_fn=PaddingCollateFn()
)
test_loader = DataLoader(
    test_dataset, batch_size=TEST_BATCH_SIZE,
    shuffle=False, num_workers=NUM_WORKERS,
    collate_fn=PaddingCollateFn()
)

CPU times: user 10.6 s, sys: 1.49 s, total: 12.1 s
Wall time: 12.1 s


In [10]:
batch = next(iter(train_loader))
print(batch['input_ids'].shape)

torch.Size([64, 200])


## Model

In [11]:
SASREC_CONFIG = {
    'maxlen': 200,
    'hidden_units': 64,
    'num_blocks': 2,
    'num_heads': 1,
    'dropout_rate': 0.1,
}

In [12]:
item_count = train.item_id.max()
add_head = False

model = SASRec(item_num=item_count, add_head=add_head, **SASREC_CONFIG)

In [13]:
out = model(batch['input_ids'], batch['attention_mask'])
out.shape

torch.Size([64, 200, 64])

## Train

In [14]:
seqrec_module = SeqRecWithSampling(model, lr=0.001, predict_top_k=200, filter_seen=True)

early_stopping = EarlyStopping(monitor="val_ndcg", mode="max", patience=10, verbose=False)

model_summary = ModelSummary(max_depth=2)
checkpoint = ModelCheckpoint(save_top_k=1, monitor="val_ndcg",
                             mode="max", save_weights_only=True)
callbacks=[early_stopping, model_summary, checkpoint]

trainer = pl.Trainer(callbacks=callbacks, enable_checkpointing=True,
                     gpus=1, max_epochs=100)

trainer.fit(model=seqrec_module,
            train_dataloaders=train_loader,
            val_dataloaders=val_1_loader)

Trainer already configured with model summary callbacks: [<class 'pytorch_lightning.callbacks.model_summary.ModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [3]

  | Name                       | Type       | Params
----------------------------------------------------------
0 | model                      | SASRec     | 569 K 
1 | model.item_emb             | Embedding  | 505 K 
2 | model.pos_emb              | Embedding  | 12.8 K
3 | model.emb_dropout          | Dropout    | 0     
4 | model.attention_layernorms | ModuleList | 256   
5 | model.attention_layers     | ModuleList | 33.3 K
6 | model.forward_layernorms   | ModuleList | 256   
7 | model.forward_layers       | ModuleList | 16.6 K
8 | model.last_layernorm       | LayerNorm  | 128   
----------------------------------------------

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

In [14]:
seqrec_module.load_state_dict(torch.load(checkpoint.best_model_path)['state_dict'])

<All keys matched successfully>

## Evaluation

In [67]:
seqrec_module.predict_top_k = test.item_id.nunique()
preds = trainer.predict(model=seqrec_module, dataloaders=val_2_loader)

preds_val = preds2recs(preds)
print(preds_val.shape)
preds_val.head()

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [3]


Predicting: 94it [00:00, ?it/s]

(76387752, 3)


Unnamed: 0,user_id,item_id,prediction
0,0,220,6.707412
1,0,1816,6.644127
2,0,5582,6.205075
3,0,2386,6.200454
4,0,1216,6.111491


In [68]:
seqrec_module.predict_top_k = test.item_id.nunique()
preds = trainer.predict(model=seqrec_module, dataloaders=test_loader)

preds_test = preds2recs(preds)
print(preds_test.shape)
preds_test.head()

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [3]


Predicting: 94it [00:00, ?it/s]

(152775504, 3)


Unnamed: 0,user_id,item_id,prediction
0,0,7597,9.083544
1,0,6962,7.462196
2,0,7502,7.30528
3,0,6300,7.125731
4,0,7156,6.893892


In [69]:
preds_test = preds_test[preds_test.item_id % 2 == 0]
preds_test.loc[:, 'item_id'] /= 2
preds_test = preds_test.groupby('user_id').head(10)

  preds_test.loc[:, 'item_id'] /= 2


In [70]:
preds_val = preds_val[preds_val.item_id % 2 == 0]
preds_val.loc[:, 'item_id'] /= 2
preds_val = preds_val.groupby('user_id').head(10)

  preds_val.loc[:, 'item_id'] /= 2


In [72]:
sum(preds_test.groupby('user_id').count().item_id < 10)

0

In [73]:
preds_val = (preds_val
              .rename(columns={'user_id': 'test_user_idx', 'item_id': 'pred_items'})
              .groupby('test_user_idx')['pred_items']
              .apply(list).reset_index()
              .merge(val_2, on='test_user_idx', how='left'))

In [74]:
preds_test = (preds_test
              .rename(columns={'user_id': 'test_user_idx', 'item_id': 'pred_items'})
              .groupby('test_user_idx')['pred_items']
              .apply(list).reset_index()
              .merge(test, on='test_user_idx', how='left'))

In [75]:
metrics_df, beta = get_metrics(preds_test, preds_val)

In [76]:
metrics_df

Unnamed: 0,type,HR,MRR,nDCG
0,Biased,0.155943,0.056163,0.079221
1,Unbiased,0.243924,0.144194,0.142461
2,Unbiased_feedback_sampling,0.719237,0.277838,0.100862


In [78]:
beta

0.10423539206237155

In [88]:
from src.unbiased_metrics import confusion_matrix_metrics

In [83]:
relevance_col = 'rating'
relevance_threshold=3.5

In [85]:
preds_test_pos = preds_test[preds_test[relevance_col] >= relevance_threshold]
preds_val_neg = preds_val[preds_val[relevance_col] < relevance_threshold]
preds_val_pos = preds_val[preds_val[relevance_col] >= relevance_threshold]

tp, fn = confusion_matrix_metrics(preds_val_pos, user_col, item_col)
fp, tn = confusion_matrix_metrics(preds_val_neg, user_col, item_col)

In [86]:
tp / (tp + fp)

0.6731609002351361

In [87]:
tp / (tp + fn)

0.1548446917014372

In [79]:
preds_test_pos = preds_test[preds_test['rating'] >= 3.5]

In [80]:
hr(preds_test_pos, beta=beta, sample_feedback=True, return_confidence_interval=True)

(0.719237, 0.0013179760836089516)

In [81]:
mrr(preds_test_pos, beta=beta, sample_feedback=True, return_confidence_interval=True)

(0.277838, 0.0010509016208834487)

In [82]:
ndcg(preds_test_pos, beta=beta, sample_feedback=True, return_confidence_interval=True)

(0.100862, 7.160066666789255e-05)