In [1]:
import torch
import random
import datetime
import pandas as pd
import numpy as np
import os

from torch.utils.data import Dataset
from src.datasets import RL4RS, ContentWise, DummyData
from src.utils import train, get_dummy_data, get_train_val_test_tmatrix_tnumitems, fit_treshold
from src.embeddings import RecsysEmbedding
from sklearn.linear_model import LogisticRegression

experiment_name = 'SlatewiseGRU'

# Preferred GPU for the long training runs. If this machine does not have it,
# fall back to the first GPU (or the CPU) so the notebook still runs end to end.
preferred_device = 'cuda:2'
if torch.cuda.device_count() > torch.device(preferred_device).index:
    device = preferred_device
elif torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

seed = 7331
pkl_path = '../pkl/'  # directory holding the pickled datasets (cw.pkl, rl4rs.pkl)

# Seed every RNG used by the experiments; torch.manual_seed seeds all devices.
# The trailing ';' suppresses the Generator repr in the cell output.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed);

<torch._C.Generator at 0x7f6e5b6eaad0>

In [2]:
def flatten(true, pred, mask, to_cpu=True):
    """Keep only the entries of ``true`` and ``pred`` where ``mask`` is non-zero.

    All three tensors are flattened first; the non-zero positions of the
    flattened mask index into the flattened ``true`` and ``pred``, so they are
    expected to have the same number of elements.

    Returns:
        ``(true_1d, pred_1d)`` -- NumPy arrays when ``to_cpu`` is True (the
        default), otherwise tensors on their original device.
    """
    keep = mask.flatten().nonzero()[:, 0]
    picked = tuple(tensor.flatten()[keep] for tensor in (true, pred))
    if not to_cpu:
        return picked
    return tuple(tensor.cpu().numpy() for tensor in picked)

# Модель

In [3]:
import torch.nn as nn
import torch.nn.functional as F

class SlatewiseGRU(nn.Module):
    """GRU click model that predicts a click logit for every item of a slate.

    For every (batch row, session step) pair the GRU hidden state is seeded
    with a user embedding and the slate's items are consumed left to right;
    a linear head maps each hidden state to one logit.

    Each GRU input is ``[item_embedding, feedback]``, where ``feedback`` has
    the same width as the embedding and is multiplied by a gate derived from
    the previous position's click: the ground-truth click when ``readout`` is
    False, the model's own prediction when ``readout`` is True.

    Args:
        embedding: module with an ``embedding_dim`` attribute whose forward
            returns ``(item_embs, user_embs)`` for a batch. Judging by the
            indexing in ``forward`` these are presumably
            ``(batch, session, slate, dim)`` and ``(batch, session, dim)``.
        readout: unroll the slate step by step and feed predictions back
            instead of using the true clicks.
    """

    def __init__(self, embedding, readout=False):
        super().__init__()

        self.embedding = embedding

        self.emb_dim = embedding.embedding_dim
        self.rnn_layer = nn.GRU(
            input_size=self.emb_dim * 2,  # item embedding + feedback slot
            hidden_size=self.emb_dim,
            batch_first=True
        )
        self.out_layer = nn.Linear(self.emb_dim, 1)

        # Readout settings (only used when readout=True).
        self.thr = -1.5  # raw-logit threshold for readout_mode='threshold'
        self.readout = readout
        self.readout_mode = 'threshold'  # one of 'soft', 'threshold', 'sample', 'diff_sample'

        # Affine logit calibration y -> w * y + b. Refit on every training
        # batch when calibration=True; otherwise stays the identity.
        self.calibration = False
        self.w = 1
        self.b = 0

    def forward(self, batch):
        """Return calibrated click logits with shape ``item_embs.shape[:-1]``.

        Keys read from ``batch``:
            'length' (training) / 'in_length' (eval): 1-based count used to
                pick which session step's user embedding seeds the GRU
                (clamped to step 0 when the count is 0).
            'responses': values > 0 count as clicks.
            'slates_mask': restricts the calibration fit to real positions.
        """
        item_embs, user_embs = self.embedding(batch)
        out_shape = item_embs.shape[:-1]
        max_sequence = item_embs.size(1)

        # Merge the first two axes so every (row, session step) becomes an
        # independent GRU sequence over its slate.
        flat_items = item_embs.flatten(0, 1)
        # Second half of every input vector is the feedback slot. It starts as
        # zeros and is only ever multiplied by click gates below.
        # NOTE(review): multiplying zeros leaves them zero, so as written the
        # click / readout gating never reaches the GRU input (and the 'soft' /
        # 'diff_sample' gates get no gradient). The original comment calls the
        # zeros the intended default -- confirm whether the slot should carry
        # a non-zero signal.
        items = torch.cat([flat_items, torch.zeros_like(flat_items)], dim=-1)

        # User embedding at the last observed session step, shape (1, rows, dim),
        # then repeated once per session step to match the merged item rows.
        step_counts = batch['length'] if self.training else batch['in_length']
        last_step = (step_counts - 1).clamp(min=0)
        gather_index = last_step[:, None, None].repeat(1, 1, user_embs.size(-1))
        h = user_embs.gather(1, gather_index).squeeze(-2).unsqueeze(0)
        h = h.repeat_interleave(max_sequence, 1)

        clicks = (batch['responses'].flatten(0, 1) > 0).int()
        mask = batch['slates_mask'].flatten(0, 1)

        if self.readout:
            y = self._readout_unroll(items, h)
        else:
            # Teacher forcing: the feedback slot of position i+1 is gated by
            # the true click at position i.
            items[:, 1:, self.emb_dim:] *= clicks[:, :-1, None]
            rnn_out, _ = self.rnn_layer(items, h)
            y = self.out_layer(rnn_out)[:, :, 0]

        if self.calibration and self.training:
            self._update_calibration(clicks, y, mask)
        y = self.w * y + self.b

        return y.reshape(out_shape)

    def _readout_unroll(self, items, h):
        """Run the GRU one slate position at a time, feeding predictions back.

        ``items`` is modified in place: after position ``i`` is scored, the
        feedback slot of position ``i + 1`` is multiplied by a gate computed
        from that score. Returns raw (uncalibrated) logits ``(rows, slate)``.
        """
        step_logits = []
        slate_len = items.shape[1]
        for i in range(slate_len):
            output, h = self.rnn_layer(items[:, [i], :], h)
            y = self.out_layer(output)[:, :, 0]
            if i + 1 < slate_len:
                items[:, [i + 1], self.emb_dim:] *= self._readout_gate(y)[:, :, None]
            step_logits.append(y)
        return torch.cat(step_logits, dim=1)

    def _readout_gate(self, y):
        """Turn step logits ``y`` of shape ``(rows, 1)`` into a feedback gate.

        Raises:
            ValueError: if ``readout_mode`` is not a supported mode.
        """
        mode = self.readout_mode
        if mode == 'threshold':
            # Hard 0/1 gate on the raw logit; no gradient through the decision.
            return (y.detach() > self.thr).to(torch.float32)
        if mode == 'soft':
            return torch.sigmoid(y)
        if mode in ('sample', 'diff_sample'):
            eps = 1e-8
            T = 0.3  # relaxation temperature
            # -log(log U1 / log U2) = log E2 - log E1 with E ~ Exp(1), i.e. the
            # difference of two Gumbels: standard Logistic noise.
            logistic_noise = -((torch.rand_like(y) + eps).log() / (torch.rand_like(y) + eps).log() + eps).log()
            # Binary-concrete relaxation of Bernoulli(sigmoid(logit)): the
            # calibrated logit is used directly. Passing log-sigmoid(logit)
            # (= log p) instead would give P(gate > 0.5) = p / (1 + p), not p.
            gate = torch.sigmoid((self.w * y + self.b + logistic_noise) / T)
            if mode == 'sample':
                # Keep the sampled value but block gradients through it.
                gate = gate.detach()
            return gate
        raise ValueError(
            f"Unknown readout_mode {mode!r}; expected one of "
            "'soft', 'threshold', 'sample', 'diff_sample'"
        )

    def _update_calibration(self, clicks, y, mask):
        """Fit clicks ~ sigmoid(a * y + c) on masked positions and EMA (a, c) into (w, b)."""
        clicks_flat, logits_flat = flatten(clicks, y.detach(), mask)
        # LogisticRegression needs samples of both classes; keep the previous
        # calibration for batches that are all-click, all-no-click or empty.
        if np.unique(clicks_flat).size < 2:
            return
        logreg = LogisticRegression()
        logreg.fit(logits_flat[:, None], clicks_flat)
        gamma = 0.3  # EMA step towards the new fit
        self.w = (1 - gamma) * self.w + gamma * float(logreg.coef_[0, 0])
        self.b = (1 - gamma) * self.b + gamma * float(logreg.intercept_[0])

# Игрушечный датасет: проверим, что сходится к идеальным метрикам

In [4]:
# Sanity check on a tiny synthetic dataset: with a 2-dim SVD embedding and
# teacher-forced clicks (readout=False) the model should reach perfect
# metrics. The same loader is passed as the train, validation and test set.
d = DummyData()
dummy_loader, dummy_matrix = get_dummy_data(d)

model = SlatewiseGRU(
    RecsysEmbedding(d.n_items, dummy_matrix, embeddings='svd', embedding_dim = 2),
    readout=False
).to('cpu')

# NOTE(review): the model is built on CPU while `device` is passed to train();
# presumably train() moves the model to `device` itself -- confirm.
train(
    model, 
    dummy_loader, dummy_loader, dummy_loader,
    device=device, lr=1e-3, num_epochs=5000, dummy=True,
    silent=True,
)


biulding affinity matrix...


3it [00:00, 3458.74it/s]


Test before learning: {'f1': 0.0, 'roc-auc': 0.3333333134651184, 'accuracy': 0.75}


train... loss:0.7181857228279114:   0%|                                                                                                    | 1/5000 [00:00<26:43,  3.12it/s]

Val update: epoch: 0 |accuracy: 0.25 | f1: 0.4000000059604645 | auc: 0.3333333134651184 | treshold: 0.01
Test: accuracy: 0.25 | f1: 0.4000000059604645 | auc: 0.3333333134651184 | 


train... loss:0.7011198401451111:   1%|▊                                                                                                  | 42/5000 [00:09<21:07,  3.91it/s]

Val update: epoch: 41 |accuracy: 0.25 | f1: 0.4000000059604645 | auc: 0.6666666269302368 | treshold: 0.01
Test: accuracy: 0.25 | f1: 0.4000000059604645 | auc: 0.6666666269302368 | 


train... loss:0.6998206377029419:   1%|▉                                                                                                  | 45/5000 [00:10<21:54,  3.77it/s]

Val update: epoch: 44 |accuracy: 0.25 | f1: 0.4000000059604645 | auc: 1.0 | treshold: 0.01
Test: accuracy: 0.25 | f1: 0.4000000059604645 | auc: 1.0 | 


train... loss:0.6820520758628845:   2%|█▋                                                                                                 | 85/5000 [00:19<20:37,  3.97it/s]

Val update: epoch: 84 |accuracy: 0.75 | f1: 0.6666666865348816 | auc: 1.0 | treshold: 0.36000000000000004
Test: accuracy: 0.75 | f1: 0.6666666865348816 | auc: 1.0 | 


train... loss:0.6675306558609009:   2%|██▎                                                                                               | 117/5000 [00:26<18:25,  4.42it/s]

Val update: epoch: 117 |accuracy: 1.0 | f1: 1.0 | auc: 1.0 | treshold: 0.38
Test: accuracy: 1.0 | f1: 1.0 | auc: 1.0 | 





(SlatewiseGRU(
   (embedding): RecsysEmbedding()
   (rnn_layer): GRU(4, 2, batch_first=True)
   (out_layer): Linear(in_features=2, out_features=1, bias=True)
 ),
 {'f1': 1.0, 'roc-auc': 1.0, 'accuracy': 1.0})

# ContentWise

In [5]:
# Load the ContentWise dataset and split it into train/val/test loaders. The
# train split also yields the user-item matrix and item count that the
# embeddings are built from.
# NOTE(review): ContentWise.load presumably unpickles the file -- only load
# trusted pickles.
content_wise_results = []
dataset = ContentWise.load(os.path.join(pkl_path, 'cw.pkl'))
(
    train_loader, 
    val_loader, 
    test_loader, 
    train_user_item_matrix, 
    train_num_items 
) = get_train_val_test_tmatrix_tnumitems(dataset, batch_size=150)

print(f"{len(dataset)} data points among {len(train_loader)} batches")

20216 data points among 108 batches


In [6]:
# Train and evaluate one fresh model per embedding type; the test metrics
# returned by train() are collected in content_wise_results, tagged with the
# embedding type.
for embeddings in ['svd', 'neural']:
    print(f"\nEvaluating {experiment_name} with {embeddings} embeddings")
    # NOTE(review): the inner .to('cpu') looks redundant -- .to(device) on the
    # model presumably also moves the embedding submodule.
    model = SlatewiseGRU(
        RecsysEmbedding(train_num_items, train_user_item_matrix, embeddings=embeddings).to('cpu'),
    ).to(device)

    # early_stopping=7 is presumably a patience in epochs -- see train().
    _, metrics = train(
        model, 
        train_loader, val_loader, test_loader, 
        device=device, lr=1e-3, num_epochs=5000, early_stopping=7,
       silent=True, 
    )
    
    metrics['embeddings'] = embeddings
    content_wise_results.append(metrics)


Evaluating SlatewiseGRU with svd embeddings
Test before learning: {'f1': 0.17239901423454285, 'roc-auc': 0.4641438126564026, 'accuracy': 0.11004862934350967}


train... loss:49.18179851770401:   0%|                                                                                                  | 1/5000 [00:39<55:26:15, 39.92s/it]

Val update: epoch: 0 |accuracy: 0.1531217247247696 | f1: 0.18792562186717987 | auc: 0.6076698899269104 | treshold: 0.09
Test: accuracy: 0.17369838058948517 | f1: 0.1841088980436325 | auc: 0.6171619892120361 | 


train... loss:34.75427806377411:   0%|                                                                                                  | 2/5000 [00:47<29:18:44, 21.11s/it]

Val update: epoch: 1 |accuracy: 0.7042703628540039 | f1: 0.23013544082641602 | auc: 0.608708381652832 | treshold: 0.13
Test: accuracy: 0.7140371203422546 | f1: 0.2277820110321045 | auc: 0.6147950887680054 | 


train... loss:33.50057601928711:   0%|                                                                                                  | 3/5000 [00:55<20:54:11, 15.06s/it]

Val update: epoch: 2 |accuracy: 0.8128049373626709 | f1: 0.23322582244873047 | auc: 0.6253132820129395 | treshold: 0.13
Test: accuracy: 0.8206504583358765 | f1: 0.23027826845645905 | auc: 0.6287914514541626 | 


train... loss:32.839709997177124:   0%|                                                                                                 | 4/5000 [01:03<16:57:21, 12.22s/it]

Val update: epoch: 3 |accuracy: 0.8089298009872437 | f1: 0.241201713681221 | auc: 0.6412267088890076 | treshold: 0.13
Test: accuracy: 0.8171557188034058 | f1: 0.23751938343048096 | auc: 0.6442759037017822 | 


train... loss:32.555352568626404:   0%|                                                                                                 | 5/5000 [01:11<14:46:46, 10.65s/it]

Val update: epoch: 4 |accuracy: 0.7933675050735474 | f1: 0.2504480183124542 | auc: 0.6517706513404846 | treshold: 0.12
Test: accuracy: 0.802328884601593 | f1: 0.24520158767700195 | auc: 0.6563292741775513 | 


train... loss:32.29015788435936:   0%|                                                                                                  | 6/5000 [01:19<13:26:58,  9.70s/it]

Val update: epoch: 5 |accuracy: 0.7607762813568115 | f1: 0.24923688173294067 | auc: 0.6570345163345337 | treshold: 0.12
Test: accuracy: 0.7719316482543945 | f1: 0.2461659461259842 | auc: 0.6631274819374084 | 


train... loss:31.412767827510834:   0%|▏                                                                                               | 12/5000 [02:03<10:46:23,  7.78s/it]

Val update: epoch: 11 |accuracy: 0.7911597490310669 | f1: 0.2581848204135895 | auc: 0.65810626745224 | treshold: 0.12
Test: accuracy: 0.7994586825370789 | f1: 0.26283279061317444 | auc: 0.668986439704895 | 


train... loss:31.35055923461914:   0%|▎                                                                                                | 13/5000 [02:11<10:49:41,  7.82s/it]

Val update: epoch: 12 |accuracy: 0.8794695138931274 | f1: 0.2327272742986679 | auc: 0.663524329662323 | treshold: 0.15000000000000002
Test: accuracy: 0.8852073550224304 | f1: 0.22848576307296753 | auc: 0.6673926711082458 | 


train... loss:31.134289383888245:   0%|▎                                                                                               | 14/5000 [02:19<10:50:15,  7.82s/it]

Val update: epoch: 13 |accuracy: 0.869974672794342 | f1: 0.24425700306892395 | auc: 0.6640497446060181 | treshold: 0.13
Test: accuracy: 0.8761358261108398 | f1: 0.24123166501522064 | auc: 0.6726142168045044 | 


train... loss:31.147396057844162:   0%|▎                                                                                               | 15/5000 [02:27<10:51:55,  7.85s/it]

Val update: epoch: 14 |accuracy: 0.8480361700057983 | f1: 0.2621992230415344 | auc: 0.6647053956985474 | treshold: 0.14
Test: accuracy: 0.8570111393928528 | f1: 0.2589595317840576 | auc: 0.6712898015975952 | 


train... loss:31.132243365049362:   0%|▎                                                                                               | 16/5000 [02:35<10:52:17,  7.85s/it]

Val update: epoch: 15 |accuracy: 0.7881800532341003 | f1: 0.2650524973869324 | auc: 0.666884183883667 | treshold: 0.13
Test: accuracy: 0.7963654398918152 | f1: 0.2663809359073639 | auc: 0.6734336614608765 | 


train... loss:30.956328779459:   0%|▎                                                                                                  | 18/5000 [02:50<10:45:13,  7.77s/it]

Val update: epoch: 17 |accuracy: 0.8699901103973389 | f1: 0.2648625075817108 | auc: 0.6687343120574951 | treshold: 0.14
Test: accuracy: 0.875734269618988 | f1: 0.2505829632282257 | auc: 0.6749803423881531 | 


train... loss:30.77937039732933:   0%|▎                                                                                                | 19/5000 [02:58<10:48:20,  7.81s/it]

Val update: epoch: 18 |accuracy: 0.88271164894104 | f1: 0.24430517852306366 | auc: 0.6701903939247131 | treshold: 0.16
Test: accuracy: 0.8883006572723389 | f1: 0.23567721247673035 | auc: 0.6745896339416504 | 


train... loss:30.456127643585205:   0%|▍                                                                                               | 23/5000 [03:28<10:34:06,  7.64s/it]

Val update: epoch: 22 |accuracy: 0.8794540762901306 | f1: 0.2549618184566498 | auc: 0.6710207462310791 | treshold: 0.14
Test: accuracy: 0.8843894600868225 | f1: 0.24801702797412872 | auc: 0.6725579500198364 | 


train... loss:30.083432763814926:   1%|▍                                                                                               | 26/5000 [03:51<10:32:18,  7.63s/it]

Val update: epoch: 25 |accuracy: 0.8552924394607544 | f1: 0.2777221202850342 | auc: 0.6739699840545654 | treshold: 0.14
Test: accuracy: 0.8624540567398071 | f1: 0.26914262771606445 | auc: 0.6788731217384338 | 


train... loss:29.995643123984337:   1%|▋                                                                                               | 33/5000 [04:43<10:29:39,  7.61s/it]

Val update: epoch: 32 |accuracy: 0.7833940386772156 | f1: 0.2799958884716034 | auc: 0.6755121946334839 | treshold: 0.13
Test: accuracy: 0.7856282591819763 | f1: 0.26927560567855835 | auc: 0.6755298972129822 | 


train... loss:29.87032163143158:   1%|▋                                                                                                | 35/5000 [04:59<10:37:21,  7.70s/it]

Val update: epoch: 34 |accuracy: 0.8797783255577087 | f1: 0.27190276980400085 | auc: 0.6760849952697754 | treshold: 0.14
Test: accuracy: 0.8857873678207397 | f1: 0.26153847575187683 | auc: 0.67588210105896 | 


train... loss:29.62252241373062:   1%|▋                                                                                                | 37/5000 [05:14<10:36:34,  7.70s/it]

Val update: epoch: 36 |accuracy: 0.8637065291404724 | f1: 0.27981725335121155 | auc: 0.6765272617340088 | treshold: 0.15000000000000002
Test: accuracy: 0.8710200190544128 | f1: 0.27646616101264954 | auc: 0.680415689945221 | 


train... loss:29.62823587656021:   1%|▋                                                                                                | 38/5000 [05:22<10:41:51,  7.76s/it]

Val update: epoch: 37 |accuracy: 0.8845025897026062 | f1: 0.26678428053855896 | auc: 0.6772441267967224 | treshold: 0.16
Test: accuracy: 0.8896390795707703 | f1: 0.25199073553085327 | auc: 0.6794477105140686 | 


train... loss:29.56452512741089:   1%|▊                                                                                                | 40/5000 [05:37<10:40:57,  7.75s/it]

Val update: epoch: 39 |accuracy: 0.8257426023483276 | f1: 0.2882905602455139 | auc: 0.6812461614608765 | treshold: 0.14
Test: accuracy: 0.8320122361183167 | f1: 0.2793160676956177 | auc: 0.6825573444366455 | 


train... loss:29.593356609344482:   1%|▉                                                                                               | 49/5000 [06:52<11:35:24,  8.43s/it]



Evaluating SlatewiseGRU with neural embeddings
Test before learning: {'f1': 0.1642305850982666, 'roc-auc': 0.525184154510498, 'accuracy': 0.6215962767601013}


train... loss:44.94625923037529:   0%|                                                                                                  | 1/5000 [00:10<14:02:29, 10.11s/it]

Val update: epoch: 0 |accuracy: 0.1853424310684204 | f1: 0.19218933582305908 | auc: 0.62271648645401 | treshold: 0.09999999999999999
Test: accuracy: 0.19773061573505402 | f1: 0.18699419498443604 | auc: 0.6265320777893066 | 


train... loss:33.257901936769485:   0%|                                                                                                 | 2/5000 [00:20<13:55:09, 10.03s/it]

Val update: epoch: 1 |accuracy: 0.8082504868507385 | f1: 0.2585960030555725 | auc: 0.6666598916053772 | treshold: 0.15000000000000002
Test: accuracy: 0.815623939037323 | f1: 0.2503325641155243 | auc: 0.6668229103088379 | 


train... loss:31.76764327287674:   0%|                                                                                                  | 3/5000 [00:30<13:52:31, 10.00s/it]

Val update: epoch: 2 |accuracy: 0.8633051514625549 | f1: 0.26044103503227234 | auc: 0.6961171627044678 | treshold: 0.14
Test: accuracy: 0.8675698637962341 | f1: 0.2466796338558197 | auc: 0.6978551149368286 | 


train... loss:30.985296338796616:   0%|                                                                                                 | 4/5000 [00:40<13:51:51,  9.99s/it]

Val update: epoch: 3 |accuracy: 0.8812912702560425 | f1: 0.23893892765045166 | auc: 0.7141842842102051 | treshold: 0.16
Test: accuracy: 0.8865755796432495 | f1: 0.22936242818832397 | auc: 0.713141918182373 | 


train... loss:30.275600731372833:   0%|                                                                                                 | 5/5000 [00:49<13:45:49,  9.92s/it]

Val update: epoch: 4 |accuracy: 0.8560952544212341 | f1: 0.29080119729042053 | auc: 0.7161807417869568 | treshold: 0.16
Test: accuracy: 0.8590931296348572 | f1: 0.2744467556476593 | auc: 0.7144386172294617 | 


train... loss:29.802046537399292:   0%|                                                                                                 | 6/5000 [00:59<13:47:41,  9.94s/it]

Val update: epoch: 5 |accuracy: 0.8415210247039795 | f1: 0.30963748693466187 | auc: 0.7247567176818848 | treshold: 0.16
Test: accuracy: 0.8452180624008179 | f1: 0.2974213659763336 | auc: 0.7236754298210144 | 


train... loss:28.193827345967293:   0%|▏                                                                                               | 10/5000 [01:38<13:25:55,  9.69s/it]

Val update: epoch: 9 |accuracy: 0.8868029117584229 | f1: 0.27117297053337097 | auc: 0.728812575340271 | treshold: 0.2
Test: accuracy: 0.8917359709739685 | f1: 0.2575973868370056 | auc: 0.7287331223487854 | 


train... loss:27.960309609770775:   0%|▏                                                                                               | 11/5000 [01:48<13:36:36,  9.82s/it]

Val update: epoch: 10 |accuracy: 0.8797783255577087 | f1: 0.29637661576271057 | auc: 0.7294415235519409 | treshold: 0.19
Test: accuracy: 0.8848801851272583 | f1: 0.2837049961090088 | auc: 0.7288740873336792 | 


train... loss:26.77064834535122:   0%|▎                                                                                                | 16/5000 [02:36<13:28:21,  9.73s/it]

Val update: epoch: 15 |accuracy: 0.8843481540679932 | f1: 0.2902889549732208 | auc: 0.730168342590332 | treshold: 0.22
Test: accuracy: 0.8881370425224304 | f1: 0.28005358576774597 | auc: 0.7283613681793213 | 


train... loss:25.189408838748932:   1%|▋                                                                                               | 37/5000 [06:05<13:37:07,  9.88s/it]


In [7]:
# Persist the ContentWise metrics (one row per embedding type), then free the
# dataset and loaders before the RL4RS section.
os.makedirs('results', exist_ok=True)  # to_csv does not create missing directories
pd.DataFrame(content_wise_results).to_csv(f'results/cw_{experiment_name}.csv')
del dataset, train_loader, val_loader, test_loader, train_user_item_matrix, train_num_items

# RL4RS

In [4]:
# Load the RL4RS dataset and split it into train/val/test loaders (larger
# batches than ContentWise). The train split also yields the user-item matrix
# and item count that the embeddings are built from.
# NOTE(review): RL4RS.load presumably unpickles the file -- only load trusted
# pickles.
rl4rs_results = []
dataset = RL4RS.load(os.path.join(pkl_path, 'rl4rs.pkl'))
(
    train_loader, 
    val_loader, 
    test_loader, 
    train_user_item_matrix, 
    train_num_items 
) = get_train_val_test_tmatrix_tnumitems(dataset, batch_size=350)

print(f"{len(dataset)} data points among {len(train_loader)} batches")

45942 data points among 106 batches


In [5]:
# Train and evaluate one fresh model per embedding type with 40-dim
# embeddings; the test metrics are collected in rl4rs_results.
# NOTE(review): unlike the ContentWise section, no visible cell writes
# rl4rs_results to disk -- confirm a later cell saves it.
for embeddings in ['explicit', 'neural', 'svd',  ]:
    print(f"\nEvaluating {experiment_name} with {embeddings} embeddings")

    model = SlatewiseGRU(
        RecsysEmbedding(
            train_num_items, 
            train_user_item_matrix, 
            embeddings=embeddings,
            embedding_dim=40
        ),
    ).to(device)

    # best_model is overwritten on every iteration and not used in the
    # visible cells.
    best_model, metrics = train(
        model, 
        train_loader, val_loader, test_loader, 
        device=device, lr=1e-3, num_epochs=5000, early_stopping=7,
        silent=True
    )
    
    metrics['embeddings'] = embeddings
    rl4rs_results.append(metrics)


Evaluating SlatewiseGRU with explicit embeddings
Test before learning: {'f1': 0.3915459215641022, 'roc-auc': 0.3203587234020233, 'accuracy': 0.35120299458503723}


train... loss:60.527464777231216:   0%|                                                                                                 | 1/5000 [01:05<91:05:04, 65.59s/it]

Val update: epoch: 0 |accuracy: 0.7665070295333862 | f1: 0.8138593435287476 | auc: 0.8345321416854858 | treshold: 0.41000000000000003
Test: accuracy: 0.7651553750038147 | f1: 0.814621090888977 | auc: 0.8300703763961792 | 


train... loss:46.512063801288605:   0%|                                                                                                 | 2/5000 [01:15<45:30:08, 32.77s/it]

Val update: epoch: 1 |accuracy: 0.7907415628433228 | f1: 0.8367424011230469 | auc: 0.8607780337333679 | treshold: 0.39
Test: accuracy: 0.78844153881073 | f1: 0.8363112211227417 | auc: 0.8572777509689331 | 


train... loss:41.60515797138214:   0%|                                                                                                  | 3/5000 [01:24<30:20:01, 21.85s/it]

Val update: epoch: 2 |accuracy: 0.8113964796066284 | f1: 0.8589465022087097 | auc: 0.8867722749710083 | treshold: 0.4
Test: accuracy: 0.8120179176330566 | f1: 0.8602452278137207 | auc: 0.8836080431938171 | 


train... loss:37.70105466246605:   0%|                                                                                                  | 4/5000 [01:33<23:33:56, 16.98s/it]

Val update: epoch: 3 |accuracy: 0.8216514587402344 | f1: 0.8690974712371826 | auc: 0.9013959169387817 | treshold: 0.41000000000000003
Test: accuracy: 0.8209648132324219 | f1: 0.869376540184021 | auc: 0.8980299234390259 | 


train... loss:34.483782172203064:   0%|                                                                                                 | 5/5000 [01:42<19:12:27, 13.84s/it]

Val update: epoch: 4 |accuracy: 0.83817058801651 | f1: 0.8772000670433044 | auc: 0.9081591963768005 | treshold: 0.39
Test: accuracy: 0.8364889621734619 | f1: 0.8767812252044678 | auc: 0.9044672846794128 | 


train... loss:31.399221807718277:   0%|▏                                                                                                | 8/5000 [02:08<14:23:12, 10.38s/it]

Val update: epoch: 7 |accuracy: 0.831325888633728 | f1: 0.8618189096450806 | auc: 0.9097534418106079 | treshold: 0.38
Test: accuracy: 0.8275903463363647 | f1: 0.8596677780151367 | auc: 0.9075059294700623 | 


train... loss:30.606727123260498:   0%|▏                                                                                                | 9/5000 [02:17<13:46:46,  9.94s/it]

Val update: epoch: 8 |accuracy: 0.8430078029632568 | f1: 0.880699872970581 | auc: 0.912627637386322 | treshold: 0.41000000000000003
Test: accuracy: 0.8407447934150696 | f1: 0.8798570036888123 | auc: 0.9089988470077515 | 


train... loss:29.977589040994644:   0%|▏                                                                                               | 11/5000 [02:34<12:51:14,  9.28s/it]

Val update: epoch: 10 |accuracy: 0.8496106266975403 | f1: 0.8831202983856201 | auc: 0.9170119762420654 | treshold: 0.34
Test: accuracy: 0.8459194898605347 | f1: 0.8811593055725098 | auc: 0.9144606590270996 | 


train... loss:28.636937499046326:   0%|▎                                                                                               | 15/5000 [03:08<12:08:11,  8.76s/it]

Val update: epoch: 14 |accuracy: 0.8492720127105713 | f1: 0.8819607496261597 | auc: 0.9173790812492371 | treshold: 0.34
Test: accuracy: 0.8459678292274475 | f1: 0.880290150642395 | auc: 0.9153916835784912 | 


train... loss:24.964551731944084:   1%|█                                                                                               | 56/5000 [09:03<13:20:06,  9.71s/it]



Evaluating SlatewiseGRU with neural embeddings
Test before learning: {'f1': 0.7518815398216248, 'roc-auc': 0.4940677881240845, 'accuracy': 0.6117761135101318}


train... loss:58.5702400803566:   0%|                                                                                                   | 1/5000 [00:09<13:31:15,  9.74s/it]

Val update: epoch: 0 |accuracy: 0.7861945629119873 | f1: 0.8378457427024841 | auc: 0.843624472618103 | treshold: 0.43
Test: accuracy: 0.7812356352806091 | f1: 0.8350743055343628 | auc: 0.8352789878845215 | 


train... loss:45.614613860845566:   0%|                                                                                                 | 2/5000 [00:19<13:53:13, 10.00s/it]

Val update: epoch: 1 |accuracy: 0.7939341068267822 | f1: 0.831099808216095 | auc: 0.8793274760246277 | treshold: 0.42000000000000004
Test: accuracy: 0.7883448004722595 | f1: 0.828443169593811 | auc: 0.8701988458633423 | 


train... loss:38.86043927073479:   0%|                                                                                                  | 3/5000 [00:29<13:48:49,  9.95s/it]

Val update: epoch: 2 |accuracy: 0.8297054171562195 | f1: 0.87139493227005 | auc: 0.9039137363433838 | treshold: 0.41000000000000003
Test: accuracy: 0.8251723051071167 | f1: 0.8689362406730652 | auc: 0.8971447944641113 | 


train... loss:35.21067264676094:   0%|                                                                                                  | 4/5000 [00:39<13:52:32, 10.00s/it]

Val update: epoch: 3 |accuracy: 0.8267546892166138 | f1: 0.871738851070404 | auc: 0.9041911363601685 | treshold: 0.36000000000000004
Test: accuracy: 0.8209648132324219 | f1: 0.8684902191162109 | auc: 0.8969899415969849 | 


train... loss:33.69998861849308:   0%|                                                                                                  | 5/5000 [00:49<13:34:11,  9.78s/it]

Val update: epoch: 4 |accuracy: 0.8359454274177551 | f1: 0.8692736029624939 | auc: 0.911042332649231 | treshold: 0.42000000000000004
Test: accuracy: 0.8294281363487244 | f1: 0.8649073243141174 | auc: 0.905407726764679 | 


train... loss:30.977994814515114:   0%|▏                                                                                                | 8/5000 [01:17<13:26:50,  9.70s/it]

Val update: epoch: 7 |accuracy: 0.8387752175331116 | f1: 0.8710837960243225 | auc: 0.9132770299911499 | treshold: 0.37
Test: accuracy: 0.8333938121795654 | f1: 0.8679267168045044 | auc: 0.9070544242858887 | 


train... loss:29.597405187785625:   0%|▏                                                                                               | 10/5000 [01:36<13:14:25,  9.55s/it]

Val update: epoch: 9 |accuracy: 0.844894289970398 | f1: 0.8825393319129944 | auc: 0.9166987538337708 | treshold: 0.37
Test: accuracy: 0.8395115733146667 | f1: 0.8793338537216187 | auc: 0.9110370874404907 | 


train... loss:29.24077820777893:   0%|▏                                                                                                | 11/5000 [01:46<13:20:44,  9.63s/it]

Val update: epoch: 10 |accuracy: 0.8512310981750488 | f1: 0.8849355578422546 | auc: 0.9214173555374146 | treshold: 0.38
Test: accuracy: 0.8478781580924988 | f1: 0.8832903504371643 | auc: 0.91692715883255 | 


train... loss:26.66114268451929:   0%|▎                                                                                                | 19/5000 [02:59<12:43:16,  9.19s/it]

Val update: epoch: 18 |accuracy: 0.84982830286026 | f1: 0.8802530169487 | auc: 0.9219175577163696 | treshold: 0.35000000000000003
Test: accuracy: 0.8472252488136292 | f1: 0.8792200088500977 | auc: 0.9193196296691895 | 


train... loss:21.120174750685692:   3%|██▍                                                                                            | 127/5000 [18:29<11:49:42,  8.74s/it]



Evaluating SlatewiseGRU with svd embeddings
Test before learning: {'f1': 0.03072095662355423, 'roc-auc': 0.4046263098716736, 'accuracy': 0.35761094093322754}


train... loss:68.07077246904373:   0%|                                                                                                  | 1/5000 [00:08<11:32:15,  8.31s/it]

Val update: epoch: 0 |accuracy: 0.7110240459442139 | f1: 0.7856245636940002 | auc: 0.7321064472198486 | treshold: 0.3
Test: accuracy: 0.7154152989387512 | f1: 0.7900454998016357 | auc: 0.7318026423454285 | 


train... loss:61.63403230905533:   0%|                                                                                                  | 2/5000 [00:16<11:28:00,  8.26s/it]

Val update: epoch: 1 |accuracy: 0.7318241000175476 | f1: 0.8027327060699463 | auc: 0.7654455304145813 | treshold: 0.33
Test: accuracy: 0.733430027961731 | f1: 0.8051608204841614 | auc: 0.7626577615737915 | 


train... loss:58.64460635185242:   0%|                                                                                                  | 3/5000 [00:24<11:32:52,  8.32s/it]

Val update: epoch: 2 |accuracy: 0.7450781464576721 | f1: 0.8130608797073364 | auc: 0.7833225131034851 | treshold: 0.35000000000000003
Test: accuracy: 0.7458590269088745 | f1: 0.8149778246879578 | auc: 0.7781178951263428 | 


train... loss:55.87036690115929:   0%|                                                                                                  | 4/5000 [00:33<11:40:50,  8.42s/it]

Val update: epoch: 3 |accuracy: 0.7340008616447449 | f1: 0.8221424221992493 | auc: 0.8018357753753662 | treshold: 0.37
Test: accuracy: 0.7372748255729675 | f1: 0.8253917098045349 | auc: 0.7926602959632874 | 


train... loss:53.820115238428116:   0%|                                                                                                 | 5/5000 [00:41<11:27:05,  8.25s/it]

Val update: epoch: 4 |accuracy: 0.7287766933441162 | f1: 0.8208282589912415 | auc: 0.805312991142273 | treshold: 0.41000000000000003
Test: accuracy: 0.7308427095413208 | f1: 0.8229662179946899 | auc: 0.7934978604316711 | 


train... loss:51.440177857875824:   0%|                                                                                                 | 6/5000 [00:49<11:31:50,  8.31s/it]

Val update: epoch: 5 |accuracy: 0.767909824848175 | f1: 0.829628586769104 | auc: 0.8192315101623535 | treshold: 0.42000000000000004
Test: accuracy: 0.7662918567657471 | f1: 0.8295324444770813 | auc: 0.8115700483322144 | 


train... loss:50.13752740621567:   0%|▏                                                                                                 | 7/5000 [00:58<11:36:32,  8.37s/it]

Val update: epoch: 6 |accuracy: 0.771174967288971 | f1: 0.8272878527641296 | auc: 0.8256872892379761 | treshold: 0.42000000000000004
Test: accuracy: 0.7682021260261536 | f1: 0.8264537453651428 | auc: 0.8183467388153076 | 


train... loss:47.27033060789108:   0%|▏                                                                                                 | 9/5000 [01:14<11:24:56,  8.23s/it]

Val update: epoch: 8 |accuracy: 0.7652735710144043 | f1: 0.834968626499176 | auc: 0.8393774032592773 | treshold: 0.41000000000000003
Test: accuracy: 0.7607060670852661 | f1: 0.832809567451477 | auc: 0.8286436796188354 | 


train... loss:46.43124669790268:   0%|▏                                                                                                | 10/5000 [01:22<11:25:50,  8.25s/it]

Val update: epoch: 9 |accuracy: 0.7813331484794617 | f1: 0.8314849734306335 | auc: 0.8455098867416382 | treshold: 0.43
Test: accuracy: 0.7783097624778748 | f1: 0.8308362364768982 | auc: 0.8380083441734314 | 


train... loss:45.237654119729996:   0%|▏                                                                                               | 11/5000 [01:31<11:30:42,  8.31s/it]

Val update: epoch: 10 |accuracy: 0.7848643064498901 | f1: 0.8314223289489746 | auc: 0.8498649001121521 | treshold: 0.42000000000000004
Test: accuracy: 0.7802925705909729 | f1: 0.8295885324478149 | auc: 0.8432841300964355 | 


train... loss:44.553138345479965:   0%|▏                                                                                               | 12/5000 [01:39<11:28:55,  8.29s/it]

Val update: epoch: 11 |accuracy: 0.778866171836853 | f1: 0.841248095035553 | auc: 0.8568811416625977 | treshold: 0.4
Test: accuracy: 0.775674045085907 | f1: 0.8400875926017761 | auc: 0.847791850566864 | 


train... loss:43.723387002944946:   0%|▏                                                                                               | 13/5000 [01:47<11:35:17,  8.37s/it]

Val update: epoch: 12 |accuracy: 0.7720698714256287 | f1: 0.8401194214820862 | auc: 0.8588995337486267 | treshold: 0.42000000000000004
Test: accuracy: 0.7703784108161926 | f1: 0.8399514555931091 | auc: 0.8494564890861511 | 


train... loss:42.88685595989227:   0%|▎                                                                                                | 14/5000 [01:56<11:32:16,  8.33s/it]

Val update: epoch: 13 |accuracy: 0.7922168970108032 | f1: 0.8478903770446777 | auc: 0.8658046722412109 | treshold: 0.42000000000000004
Test: accuracy: 0.7899165749549866 | f1: 0.847353994846344 | auc: 0.8577067852020264 | 


train... loss:41.97246293723583:   0%|▎                                                                                                | 15/5000 [02:04<11:37:47,  8.40s/it]

Val update: epoch: 14 |accuracy: 0.7967638969421387 | f1: 0.851171612739563 | auc: 0.8690428137779236 | treshold: 0.42000000000000004
Test: accuracy: 0.7926973700523376 | f1: 0.8495040535926819 | auc: 0.8610823154449463 | 


train... loss:41.3592549264431:   0%|▎                                                                                                 | 16/5000 [02:13<11:42:21,  8.46s/it]

Val update: epoch: 15 |accuracy: 0.8013350963592529 | f1: 0.8519466519355774 | auc: 0.8733751773834229 | treshold: 0.42000000000000004
Test: accuracy: 0.7994921803474426 | f1: 0.851743221282959 | auc: 0.8656696081161499 | 


train... loss:40.80587098002434:   0%|▎                                                                                                | 17/5000 [02:21<11:32:45,  8.34s/it]

Val update: epoch: 16 |accuracy: 0.8052532076835632 | f1: 0.8521266579627991 | auc: 0.8752619028091431 | treshold: 0.42000000000000004
Test: accuracy: 0.8016443252563477 | f1: 0.8506835103034973 | auc: 0.867877721786499 | 


train... loss:40.64431756734848:   0%|▎                                                                                                | 18/5000 [02:29<11:34:31,  8.36s/it]

Val update: epoch: 17 |accuracy: 0.8074783682823181 | f1: 0.8467167615890503 | auc: 0.8769110441207886 | treshold: 0.44
Test: accuracy: 0.8025147914886475 | f1: 0.8440191745758057 | auc: 0.8709082007408142 | 


train... loss:40.3174968957901:   0%|▍                                                                                                 | 20/5000 [02:45<11:25:30,  8.26s/it]

Val update: epoch: 19 |accuracy: 0.8180960416793823 | f1: 0.8619721531867981 | auc: 0.8834640979766846 | treshold: 0.42000000000000004
Test: accuracy: 0.8158868551254272 | f1: 0.8614073991775513 | auc: 0.8768662810325623 | 


train... loss:39.91748324036598:   0%|▍                                                                                                | 21/5000 [02:54<11:29:08,  8.30s/it]

Val update: epoch: 20 |accuracy: 0.7975378632545471 | f1: 0.8560372591018677 | auc: 0.8840525150299072 | treshold: 0.39
Test: accuracy: 0.7946076393127441 | f1: 0.855125367641449 | auc: 0.8769891262054443 | 


train... loss:39.2757493853569:   0%|▍                                                                                                 | 22/5000 [03:02<11:28:58,  8.30s/it]

Val update: epoch: 21 |accuracy: 0.8198858499526978 | f1: 0.8610919117927551 | auc: 0.8853291273117065 | treshold: 0.39
Test: accuracy: 0.816733181476593 | f1: 0.8597390651702881 | auc: 0.8792827725410461 | 


train... loss:38.59160029888153:   0%|▍                                                                                                | 24/5000 [03:18<11:13:45,  8.12s/it]

Val update: epoch: 23 |accuracy: 0.8187490701675415 | f1: 0.8572136163711548 | auc: 0.8864840269088745 | treshold: 0.42000000000000004
Test: accuracy: 0.814339280128479 | f1: 0.8548691868782043 | auc: 0.8800528049468994 | 


train... loss:38.29121348261833:   0%|▍                                                                                                | 25/5000 [03:27<11:24:29,  8.26s/it]

Val update: epoch: 24 |accuracy: 0.8202728033065796 | f1: 0.8673580288887024 | auc: 0.8897553086280823 | treshold: 0.41000000000000003
Test: accuracy: 0.8191512227058411 | f1: 0.8674194812774658 | auc: 0.883796215057373 | 


train... loss:37.54507315158844:   1%|▌                                                                                                | 26/5000 [03:35<11:31:55,  8.35s/it]

Val update: epoch: 25 |accuracy: 0.8189667463302612 | f1: 0.866126537322998 | auc: 0.8907374143600464 | treshold: 0.38
Test: accuracy: 0.8166364431381226 | f1: 0.865537703037262 | auc: 0.8846683502197266 | 


train... loss:37.78067281842232:   1%|▌                                                                                                | 27/5000 [03:44<11:32:27,  8.35s/it]

Val update: epoch: 26 |accuracy: 0.8254244923591614 | f1: 0.8695699572563171 | auc: 0.892900824546814 | treshold: 0.38
Test: accuracy: 0.8257526159286499 | f1: 0.8706840872764587 | auc: 0.8871996998786926 | 


train... loss:37.69675213098526:   1%|▌                                                                                                | 29/5000 [04:00<11:20:29,  8.21s/it]

Val update: epoch: 28 |accuracy: 0.8223044276237488 | f1: 0.8692402243614197 | auc: 0.8936282396316528 | treshold: 0.41000000000000003
Test: accuracy: 0.8218836784362793 | f1: 0.869904637336731 | auc: 0.8882394433021545 | 


train... loss:36.70290610194206:   1%|▌                                                                                                | 31/5000 [04:16<11:16:33,  8.17s/it]

Val update: epoch: 30 |accuracy: 0.8255937695503235 | f1: 0.8705223202705383 | auc: 0.8941760659217834 | treshold: 0.42000000000000004
Test: accuracy: 0.8232136368751526 | f1: 0.8698392510414124 | auc: 0.8885153532028198 | 


train... loss:36.58035519719124:   1%|▋                                                                                                | 33/5000 [04:32<11:11:19,  8.11s/it]

Val update: epoch: 32 |accuracy: 0.8328738212585449 | f1: 0.8706380128860474 | auc: 0.8967159986495972 | treshold: 0.41000000000000003
Test: accuracy: 0.83070969581604 | f1: 0.8700365424156189 | auc: 0.8914424777030945 | 


train... loss:35.29826559126377:   1%|▋                                                                                                | 38/5000 [05:11<10:58:56,  7.97s/it]

Val update: epoch: 37 |accuracy: 0.8306970596313477 | f1: 0.8731516599655151 | auc: 0.8977552652359009 | treshold: 0.37
Test: accuracy: 0.8298150300979614 | f1: 0.8733762502670288 | auc: 0.8915663957595825 | 


train... loss:35.84325301647186:   1%|▊                                                                                                | 40/5000 [05:27<11:11:04,  8.12s/it]

Val update: epoch: 39 |accuracy: 0.8362598419189453 | f1: 0.8725143074989319 | auc: 0.9002146124839783 | treshold: 0.4
Test: accuracy: 0.8345544934272766 | f1: 0.8721551895141602 | auc: 0.8955903649330139 | 


train... loss:34.24882562458515:   1%|▊                                                                                                | 45/5000 [06:06<10:53:44,  7.92s/it]

Val update: epoch: 44 |accuracy: 0.8375658988952637 | f1: 0.874640703201294 | auc: 0.9029099941253662 | treshold: 0.4
Test: accuracy: 0.8367791175842285 | f1: 0.874823808670044 | auc: 0.897625207901001 | 


train... loss:32.39382588863373:   1%|█▏                                                                                               | 58/5000 [07:46<10:48:57,  7.88s/it]

Val update: epoch: 57 |accuracy: 0.8399361371994019 | f1: 0.8764883875846863 | auc: 0.9069984555244446 | treshold: 0.38
Test: accuracy: 0.8389312028884888 | f1: 0.8762884736061096 | auc: 0.9029574990272522 | 


train... loss:31.786647975444794:   1%|█▏                                                                                              | 60/5000 [08:02<10:57:24,  7.98s/it]

Val update: epoch: 59 |accuracy: 0.8376868367195129 | f1: 0.871019184589386 | auc: 0.907134473323822 | treshold: 0.39
Test: accuracy: 0.8346753716468811 | f1: 0.8692759275436401 | auc: 0.9034159183502197 | 


train... loss:31.38290873169899:   1%|█▏                                                                                               | 63/5000 [08:25<10:48:33,  7.88s/it]

Val update: epoch: 62 |accuracy: 0.8381221890449524 | f1: 0.8759475946426392 | auc: 0.9076036214828491 | treshold: 0.4
Test: accuracy: 0.8371660113334656 | f1: 0.8756003975868225 | auc: 0.9030414819717407 | 


train... loss:29.134227842092514:   2%|█▌                                                                                              | 83/5000 [10:57<10:38:16,  7.79s/it]

Val update: epoch: 82 |accuracy: 0.833720326423645 | f1: 0.8649339079856873 | auc: 0.9092946648597717 | treshold: 0.4
Test: accuracy: 0.8290170431137085 | f1: 0.8619834780693054 | auc: 0.9053507447242737 | 


train... loss:28.859275236725807:   2%|█▋                                                                                              | 86/5000 [11:21<10:46:52,  7.90s/it]

Val update: epoch: 85 |accuracy: 0.8377110362052917 | f1: 0.8749394416809082 | auc: 0.9094183444976807 | treshold: 0.4
Test: accuracy: 0.8349413871765137 | f1: 0.8734801411628723 | auc: 0.9042787551879883 | 


train... loss:28.880144611001015:   2%|█▋                                                                                              | 87/5000 [11:29<10:55:37,  8.01s/it]

Val update: epoch: 86 |accuracy: 0.8389203548431396 | f1: 0.8755419254302979 | auc: 0.9098940491676331 | treshold: 0.4
Test: accuracy: 0.8376254439353943 | f1: 0.8752484917640686 | auc: 0.9051905274391174 | 


train... loss:28.665145352482796:   2%|█▋                                                                                              | 88/5000 [11:38<11:08:23,  8.16s/it]

Val update: epoch: 87 |accuracy: 0.841121256351471 | f1: 0.8735928535461426 | auc: 0.9105575680732727 | treshold: 0.4
Test: accuracy: 0.8358602523803711 | f1: 0.8702995777130127 | auc: 0.9065022468566895 | 


train... loss:29.051254779100418:   2%|█▊                                                                                              | 94/5000 [12:24<10:42:05,  7.85s/it]

Val update: epoch: 93 |accuracy: 0.8388235569000244 | f1: 0.8699098229408264 | auc: 0.9115453958511353 | treshold: 0.39
Test: accuracy: 0.8340224623680115 | f1: 0.867172360420227 | auc: 0.9078233242034912 | 


train... loss:26.533825382590294:   3%|███                                                                                            | 164/5000 [21:21<10:29:54,  7.82s/it]


In [6]:
# Persist the RL4RS metrics collected above (one record per embedding type)
# and release the dataset-specific objects before the next experiment.
os.makedirs('results', exist_ok=True)  # to_csv raises if 'results/' does not exist yet
pd.DataFrame(rl4rs_results).to_csv(f'results/rl4rs_{experiment_name}.csv')

# Drop references so the memory they hold can be reclaimed. A plain
# `del a, b, ...` raises NameError when this cell is re-run; popping from the
# notebook namespace with a default keeps the cell idempotent.
for _stale_name in ('dataset', 'train_loader', 'val_loader', 'test_loader',
                    'train_user_item_matrix', 'train_num_items'):
    globals().pop(_stale_name, None)
del _stale_name

In [7]:
# Show the collected RL4RS metrics: a list of dicts with 'f1', 'roc-auc',
# 'accuracy' and the 'embeddings' variant each row was trained with.
rl4rs_results

[{'f1': 0.880290150642395,
  'roc-auc': 0.9153916835784912,
  'accuracy': 0.8459678292274475,
  'embeddings': 'explicit'},
 {'f1': 0.8792200088500977,
  'roc-auc': 0.9193196296691895,
  'accuracy': 0.8472252488136292,
  'embeddings': 'neural'},
 {'f1': 0.867172360420227,
  'roc-auc': 0.9078233242034912,
  'accuracy': 0.8340224623680115,
  'embeddings': 'svd'}]