In [1]:
import torch
import torch.nn.functional as F
import pandas as pd
import numpy as np
from tqdm import tqdm
from src.data import get_dataset
from dt4rec_utils import make_rsa
from cql_dqn import DQNCQL
from metrics import Evaluator

# Prefer the second GPU (the original hard-coded choice). Fall back to the
# first GPU, or to the CPU, so the notebook still runs on hosts that expose
# fewer than two CUDA devices.
if torch.cuda.device_count() > 1:
    DEVICE = torch.device('cuda:1')
elif torch.cuda.is_available():
    DEVICE = torch.device('cuda:0')
else:
    DEVICE = torch.device('cpu')

In [2]:
# Load the ml-1m archive with a temporal split. The two values bound to `_`
# are discarded; later cells read `data_description`, `testset` and `holdout`.
trainset, data_description, _, testset, _, holdout = get_dataset(
    validation_size=1024,
    test_size=5000,
    data_path='./data/ml-1m.zip',
    splitting='temporal_full',
    q=0.8,
)

# One list of item ids per test user; sort=False keeps users in the order
# in which they first appear in `testset`.
inference_sequences = (
    testset
    .groupby('userid', sort=False)['itemid']
    .apply(list)
)
inference_sequences

Filtered 115 invalid observations.
Filtered 11 invalid observations.
Filtered 4 invalid observations.


userid
1       [2935, 1160, 1552, 941, 2117, 1633, 3136, 2566...
2       [1090, 1102, 1109, 2479, 1183, 2702, 1117, 108...
3       [573, 2618, 3260, 1762, 1307, 1755, 1156, 1259...
4       [1102, 1008, 3194, 463, 3253, 253, 1088, 1090,...
5       [2479, 832, 843, 1140, 346, 2618, 1033, 1981, ...
                              ...                        
5998    [1268, 1269, 1270, 3136, 3300, 3313, 3034, 323...
6001    [3464, 3207, 1562, 1544, 2635, 1139, 1242, 889...
6002    [1736, 440, 1740, 2049, 2849, 265, 961, 2898, ...
6016    [3570, 3463, 3507, 1344, 2880, 3601, 27, 2466,...
6040    [2932, 2348, 1104, 3092, 3148, 1148, 1160, 186...
Name: itemid, Length: 1739, dtype: object

In [3]:
@torch.no_grad()
def create_sasrec_predictions(
    sequences,
    sasrec_path: str,
    device: torch.device
):
    """Score items for every user sequence with a saved SASRec model.

    Args:
        sequences: Mapping-like object (only ``.items()`` and ``len()`` are
            used) from user id to that user's list of item ids.
        sasrec_path: Path to a model object written with ``torch.save``.
        device: Device onto which the model is loaded and where inference runs.

    Returns:
        ``np.ndarray`` with one row of scores per sequence, in iteration
        order. The first element returned by ``score_with_state`` is treated
        as the logits, and the last logit of each row is dropped (presumably
        the padding-token column -- TODO confirm against the model).
    """
    # NOTE: torch.load unpickles arbitrary objects -- only load trusted files.
    # map_location remaps the stored tensors straight onto `device`, so a
    # checkpoint saved on a GPU that is absent (or busy) here still loads.
    sasrec = torch.load(sasrec_path, map_location=device).to(device)
    sasrec.eval()

    scores = []

    for u, seq in tqdm(sequences.items(), total=len(sequences)):
        s = torch.LongTensor(seq).to(device)
        logits = sasrec.score_with_state(s)[0].flatten().detach().cpu().numpy()[:-1]
        scores.append(logits)

    return np.stack(scores)

@torch.no_grad()
def create_cqlsasrec_predictions(
    sequences,
    cql_path: str,
    device: torch.device
):
    """Score items for every user sequence with a saved CQL trainer.

    The trainer object is expected to expose ``body``, ``q_1`` and ``q_2``.
    For each sequence, the last element returned by
    ``body.score_with_state`` is reshaped to 2-D and fed to both Q-heads,
    whose outputs are averaged.

    Args:
        sequences: Mapping-like object (only ``.items()`` and ``len()`` are
            used) from user id to that user's list of item ids.
        cql_path: Path to a trainer object written with ``torch.save``.
        device: Device onto which the networks are moved and where inference
            runs.

    Returns:
        ``np.ndarray`` with one row of averaged Q-values per sequence, in
        iteration order. The last value of each flattened output is dropped
        (presumably the padding-token column -- TODO confirm).
    """
    # NOTE: torch.load unpickles arbitrary objects -- only load trusted files.
    # map_location remaps the stored tensors straight onto `device`, so a
    # checkpoint saved on a GPU that is absent (or busy) here still loads.
    trainer = torch.load(cql_path, map_location=device)
    trainer.q_1 = trainer.q_1.to(device)
    trainer.q_2 = trainer.q_2.to(device)
    trainer.body = trainer.body.to(device)
    trainer.q_1.eval()
    trainer.q_2.eval()
    trainer.body.eval()

    scores = []

    for u, seq in tqdm(sequences.items(), total=len(sequences)):
        s = torch.LongTensor(seq).to(device)
        body_out = trainer.body.score_with_state(s)[-1]
        # Collapse any leading dimensions so the Q-heads get (rows, hidden).
        body_out = body_out.reshape(-1, body_out.shape[-1])
        # Average the two Q-heads.
        out = (trainer.q_1(body_out) + trainer.q_2(body_out)) / 2.0
        scores.append(out.flatten()[:-1].cpu().numpy())

    return np.stack(scores)


@torch.no_grad()
def create_dt4rec_predictions(
    sequences,
    dt4rec_path: str,
    device: torch.device,
    seq_len: int = 100
):
    """Score items for every user sequence with a saved DT4Rec model.

    Args:
        sequences: Mapping-like object (only ``.items()`` and ``len()`` are
            used) from user id to that user's list of item ids.
        dt4rec_path: Path to a model object written with ``torch.save``.
        device: Device onto which the model is loaded and where inference runs.
        seq_len: Context length; every sequence is brought to
            ``seq_len - 1`` items before it is passed to ``make_rsa``.
            Defaults to 100, the value that was previously hard-coded.

    Returns:
        ``np.ndarray`` with one row of scores per sequence, in iteration
        order, taken from the model output at the last position.
    """
    # NOTE: torch.load unpickles arbitrary objects -- only load trusted files.
    # map_location remaps the stored tensors straight onto `device`, so a
    # checkpoint saved on a GPU that is absent (or busy) here still loads.
    dt4rec = torch.load(dt4rec_path, map_location=device).to(device)
    dt4rec.eval()

    # `vocab_size` doubles as the padding item id below.
    item_num = dt4rec.config.vocab_size

    scores = []

    for u, seq in tqdm(sequences.items(), total=len(sequences)):
        s = torch.LongTensor(seq).to(device)
        # Left-pad with `item_num` up to seq_len - 1 items. For longer
        # sequences the pad width is negative, which makes F.pad crop from
        # the left instead, i.e. only the most recent items are kept.
        s = F.pad(s, (seq_len - 1 - len(s), 0), value=item_num)
        # The meaning of the literal 3 is defined by make_rsa (not visible
        # here) -- TODO confirm. [None, ...] adds a leading batch axis.
        rsa = {
            key: value[None, ...].to(device)
            for key, value in make_rsa(s, 3, item_num).items()
        }
        state = dt4rec(**rsa)
        # Keep the output at the last position and drop its final entry
        # (presumably the padding-token score, so that rows line up with
        # the other models -- TODO confirm).
        scores.append(state[:, -1, :].flatten()[:-1].cpu().numpy())

    return np.stack(scores)


@torch.no_grad()
def create_ssknn_predictions(
    sequences,
    ssknn_path: str
):
    """Score items for every user sequence with a saved SSKNN model.

    Each sequence is turned into a small interaction frame with columns
    ``itemid``, ``timestamp`` (the position in the sequence) and ``userid``,
    and handed to ``ssknn.recommend`` together with the notebook-level
    ``data_description``.

    Args:
        sequences: Mapping-like object (only ``.items()`` and ``len()`` are
            used) from user id to that user's list of item ids.
        ssknn_path: Path to a model object written with ``torch.save``.

    Returns:
        ``np.ndarray`` with one flattened score row per sequence, in
        iteration order.
    """
    knn_model = torch.load(ssknn_path)

    def _score_user(user_id, items):
        # Positions 0..n-1 stand in for timestamps to preserve item order.
        history = pd.DataFrame(
            {'itemid': items, 'timestamp': np.arange(len(items))}
        ).assign(userid=user_id)
        return knn_model.recommend(history, data_description).ravel()

    progress = tqdm(sequences.items(), total=len(sequences))
    return np.stack([_score_user(user_id, items) for user_id, items in progress])

In [4]:
# Model name -> path of its saved object. Insertion order is the order in
# which the models are scored later on.
models = dict([
    ('sasrec_2', './models/sasrec_2.pt'),
    ('sasrec_3', './models/sasrec_3.pt'),
    ('sasrec_4', './models/sasrec_4.pt'),
    ('dt4rec', './models/dt4rec.pt'),
    ('cql_sasrec', './models/cql_sasrec.pt'),
    ('ssknn', './models/ssknn.pt'),
])

In [5]:
# Metrics are computed at a single cutoff, k = 10.
evaluator = Evaluator(top_k=[10])

# Stage 1: one score matrix per model. The scoring routine depends on the
# model family; names that match no branch are left out of `scores`.
scores = {}

for model_name, checkpoint in models.items():
    if model_name == 'ssknn':
        scores[model_name] = create_ssknn_predictions(inference_sequences, checkpoint)
    elif model_name == 'cql_sasrec':
        scores[model_name] = create_cqlsasrec_predictions(inference_sequences, checkpoint, DEVICE)
    elif model_name == 'dt4rec':
        scores[model_name] = create_dt4rec_predictions(inference_sequences, checkpoint, DEVICE)
    elif model_name in ('sasrec_2', 'sasrec_3', 'sasrec_4'):
        scores[model_name] = create_sasrec_predictions(inference_sequences, checkpoint, DEVICE)

# Stage 2: down-vote the items found in `testset`, take the top-k
# recommendations and compare them with the holdout.
metrics_all = {}
for model_name, raw_scores in scores.items():
    unseen_scores = evaluator.downvote_seen_items(raw_scores, testset)
    top_items = evaluator.topk_recommendations(unseen_scores)
    metrics_all[model_name] = evaluator.compute_metrics(holdout, top_items)

100%|██████████| 1739/1739 [00:02<00:00, 737.91it/s]
100%|██████████| 1739/1739 [00:02<00:00, 836.39it/s]
100%|██████████| 1739/1739 [00:02<00:00, 822.75it/s]
100%|██████████| 1739/1739 [00:03<00:00, 578.09it/s]
100%|██████████| 1739/1739 [00:02<00:00, 742.64it/s]
100%|██████████| 1739/1739 [00:15<00:00, 113.70it/s]


In [6]:
# One row per model, one column per metric, best ndcg@10 first.
(
    pd.DataFrame(metrics_all)
    .T
    .sort_values(by='ndcg@10', ascending=False)
)

Unnamed: 0,ndcg@10,hr@10
sasrec_2,0.094667,0.182864
cql_sasrec,0.090978,0.170213
sasrec_3,0.070142,0.133985
ssknn,0.062707,0.123059
dt4rec,0.03441,0.063255
sasrec_4,0.017741,0.037378
