In [1]:
import torch
import random
import datetime
import pandas as pd
import numpy as np
import os

from torch.utils.data import Dataset
from src.datasets import RL4RS, ContentWise, DummyData
from src.utils import train, get_dummy_data, get_train_val_test_tmatrix_tnumitems, fit_treshold
from src.embeddings import RecsysEmbedding
from sklearn.linear_model import LogisticRegression

# --- Experiment configuration -------------------------------------------------
experiment_name = 'SlatewiseGRU'

# Preferred accelerator for the real experiments. Fall back to the CPU when the
# machine does not expose a GPU with index 2, so that "Restart & Run All" still
# works on hosts with fewer (or no) GPUs instead of failing at the first `.to(device)`.
preferred_device = 'cuda:2'
device = (
    preferred_device
    if torch.cuda.is_available() and torch.cuda.device_count() > 2
    else 'cpu'
)

seed = 7331
pkl_path = '../pkl/'  # directory holding the pickled datasets, relative to the notebook

# --- Reproducibility: seed every RNG the notebook relies on --------------------
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed);  # trailing ';' hides the Generator repr in the cell output

<torch._C.Generator at 0x7f2ddc42ab30>

In [2]:
def flatten(true, pred, mask, to_cpu=True):
    """Flatten targets and predictions and keep only the positions selected by ``mask``.

    Args:
        true: tensor of ground-truth values.
        pred: tensor of predictions; must have as many elements as ``mask``.
        mask: tensor whose non-zero entries mark the positions to keep.
        to_cpu: when True (default) the results are moved to the CPU and
            returned as NumPy arrays; otherwise they stay torch tensors.

    Returns:
        A ``(true, pred)`` pair of 1-D sequences containing only the kept positions.
    """
    # Indices (in flattened order) of every non-zero mask entry.
    keep = mask.flatten().nonzero()[:, 0]

    true_kept = true.flatten()[keep]
    pred_kept = pred.flatten()[keep]

    if not to_cpu:
        return true_kept, pred_kept
    return true_kept.cpu().numpy(), pred_kept.cpu().numpy()

# Модель

In [3]:
import torch.nn as nn
import torch.nn.functional as F

class SlatewiseGRU(nn.Module):
    """GRU that walks over the items of a slate and emits one click logit per item.

    ``embedding(batch)`` must return ``(item_embs, user_embs)``. As implied by the
    indexing in ``forward``, ``item_embs`` is 4-D (its first two axes are merged
    before the GRU, the third is the slate position, the last is the feature axis)
    and ``user_embs`` is 3-D with the same first axis.

    The GRU is built with ``input_size = 2 * embedding.embedding_dim`` and
    ``hidden_size = embedding.embedding_dim``, so the last axis of both
    ``item_embs`` and ``user_embs`` must equal ``embedding.embedding_dim``;
    otherwise ``nn.GRU`` raises a size-mismatch ``RuntimeError``.

    Args:
        embedding: module producing item and user embeddings; must expose
            ``embedding_dim``.
        readout: if True, the slate is processed one position at a time and the
            model's own prediction gates the feedback half of the next input;
            if False, the observed clicks are used and the whole slate is
            processed in a single GRU call.
    """

    # Values accepted for ``readout_mode``.
    READOUT_MODES = ('soft', 'threshold', 'sample', 'diff_sample')

    def __init__(self, embedding, readout=False):
        super().__init__()

        self.embedding = embedding

        self.emb_dim = embedding.embedding_dim
        self.rnn_layer = nn.GRU(
            input_size=self.emb_dim * 2,
            hidden_size=self.emb_dim,
            batch_first=True
        )
        self.out_layer = nn.Linear(self.emb_dim, 1)

        # Logit threshold used by the 'threshold' readout mode.
        self.thr = -1.5
        self.readout = readout
        self.readout_mode = 'threshold'  # one of READOUT_MODES

        # Optional affine calibration of the logits: y -> w * y + b.
        self.calibration = False
        self.w = 1
        self.b = 0

    def _readout_gate(self, y):
        """Return the multiplier applied to the feedback half of the next position.

        Args:
            y: logits of the current position, shape ``(N, 1)``.

        Raises:
            ValueError: if ``self.readout_mode`` is not one of ``READOUT_MODES``.
        """
        mode = self.readout_mode
        if mode == 'threshold':
            # Hard decision on the detached logit: no gradient flows through the gate.
            return (y.detach()[:, :, None] > self.thr).to(torch.float32)
        if mode == 'soft':
            return torch.sigmoid(y)[:, :, None]
        if mode in ('diff_sample', 'sample'):
            eps = 1e-8
            # -log(log(u1) / log(u2)) is the difference of two Gumbel samples.
            gumbel_sample = -(
                (torch.rand_like(y) + eps).log() / (torch.rand_like(y) + eps).log() + eps
            ).log()
            T = 0.3  # relaxation temperature
            bernoulli_sample = torch.sigmoid(
                (F.logsigmoid(self.w * y + self.b) + gumbel_sample) / T
            )
            if mode == 'sample':
                # 'sample' uses the same relaxed sample but blocks its gradient.
                bernoulli_sample = bernoulli_sample.detach()
            return bernoulli_sample[:, :, None]
        raise ValueError(
            f"Unknown readout_mode {mode!r}; expected one of {self.READOUT_MODES}"
        )

    def forward(self, batch):
        """Compute click logits for every slate position.

        Reads ``batch['length']`` (training) or ``batch['in_length']`` (evaluation),
        ``batch['responses']`` and ``batch['slates_mask']``, plus whatever
        ``self.embedding`` consumes.

        Returns:
            Tensor of logits shaped like ``item_embs`` without its last axis.

        Raises:
            ValueError: in readout mode, if ``self.readout_mode`` is unknown.
        """
        item_embs, user_embs = self.embedding(batch)
        shp = item_embs.shape
        max_sequence = item_embs.size(1)

        # Internal layout after merging the first two axes (N = shp[0] * shp[1]):
        #   items : (N, slate, 2 * emb_dim) - second half is the feedback slot used
        #           by readout; it is zero by default
        #   clicks: (N, slate)
        #   h     : (1, N, emb_dim) - initial GRU hidden state
        #   mask  : (N, slate)
        flat_items = item_embs.flatten(0, 1)
        items = torch.cat([flat_items, torch.zeros_like(flat_items)], dim=-1)

        # Pick, per row, the user embedding at the last usable index of axis 1
        # (clamped to 0 when the length is 0).
        lengths = batch['length'] if self.training else batch['in_length']
        indices = (lengths - 1).clamp(min=0)
        indices = indices[:, None, None].repeat(1, 1, user_embs.size(-1))
        user_state = user_embs.gather(1, indices).squeeze(-2).unsqueeze(0)
        # Repeat each row's state so that it lines up with the flatten(0, 1) order.
        h = user_state.repeat_interleave(max_sequence, 1)

        clicks = (batch['responses'].flatten(0, 1) > 0).int()
        mask = batch['slates_mask'].flatten(0, 1)

        # NOTE(review): the feedback half of `items` is created as zeros and is only
        # ever multiplied in place below, so it stays zero and the click / readout
        # feedback does not reach the GRU. Presumably it was meant to be filled
        # first (e.g. with the previous item's embedding) - confirm the intent
        # before changing it; behaviour is deliberately left as is.
        if self.readout:
            step_logits = []
            seq_len = items.shape[1]
            for i in range(seq_len):
                output, h = self.rnn_layer(items[:, [i], :], h)
                y = self.out_layer(output)[:, :, 0]

                # Gate the next position's feedback slot with the current prediction.
                if i + 1 < seq_len:
                    items[:, [i + 1], self.emb_dim:] *= self._readout_gate(y)

                step_logits.append(y)

            y = torch.cat(step_logits, dim=1)
        else:
            # Teacher forcing: gate position t+1 with the observed click at position t.
            items[:, 1:, self.emb_dim:] *= clicks[:, :-1, None]
            rnn_out, _ = self.rnn_layer(items, h)
            y = self.out_layer(rnn_out)[:, :, 0]

        if self.calibration and self.training:
            # Fit a 1-D logistic regression on the raw logits of this batch and blend
            # its coefficients into (w, b) with an exponential moving average.
            clicks_flat, logits_flat = flatten(clicks, y.detach(), mask)
            logreg = LogisticRegression()
            logreg.fit(logits_flat[:, None], clicks_flat)
            gamma = 0.3
            self.w = (1 - gamma) * self.w + gamma * float(logreg.coef_[0, 0])
            self.b = (1 - gamma) * self.b + gamma * float(logreg.intercept_[0])

        y = self.w * y + self.b
        return y.reshape(shp[:-1])

# Игрушечный датасет: проверим, что сходится к идеальным метрикам

In [4]:
# Sanity check on the toy dataset: the model is expected to reach perfect metrics.
d = DummyData()
dummy_loader, dummy_matrix = get_dummy_data(d)

# Tiny 2-dimensional SVD embeddings are enough for the toy data.
dummy_embedding = RecsysEmbedding(
    d.n_items,
    dummy_matrix,
    embeddings='svd',
    embedding_dim=2,
)
model = SlatewiseGRU(dummy_embedding, readout=False).to('cpu')

# The same loader is used as train, validation and test split.
# Kept as the last expression so the cell displays train()'s return value.
train(
    model,
    dummy_loader,
    dummy_loader,
    dummy_loader,
    device=device,
    lr=1e-3,
    num_epochs=5000,
    dummy=True,
    silent=True,
)


biulding affinity matrix...


3it [00:00, 3458.74it/s]


Test before learning: {'f1': 0.0, 'roc-auc': 0.3333333134651184, 'accuracy': 0.75}


train... loss:0.7181857228279114:   0%|                                                                                                    | 1/5000 [00:00<26:43,  3.12it/s]

Val update: epoch: 0 |accuracy: 0.25 | f1: 0.4000000059604645 | auc: 0.3333333134651184 | treshold: 0.01
Test: accuracy: 0.25 | f1: 0.4000000059604645 | auc: 0.3333333134651184 | 


train... loss:0.7011198401451111:   1%|▊                                                                                                  | 42/5000 [00:09<21:07,  3.91it/s]

Val update: epoch: 41 |accuracy: 0.25 | f1: 0.4000000059604645 | auc: 0.6666666269302368 | treshold: 0.01
Test: accuracy: 0.25 | f1: 0.4000000059604645 | auc: 0.6666666269302368 | 


train... loss:0.6998206377029419:   1%|▉                                                                                                  | 45/5000 [00:10<21:54,  3.77it/s]

Val update: epoch: 44 |accuracy: 0.25 | f1: 0.4000000059604645 | auc: 1.0 | treshold: 0.01
Test: accuracy: 0.25 | f1: 0.4000000059604645 | auc: 1.0 | 


train... loss:0.6820520758628845:   2%|█▋                                                                                                 | 85/5000 [00:19<20:37,  3.97it/s]

Val update: epoch: 84 |accuracy: 0.75 | f1: 0.6666666865348816 | auc: 1.0 | treshold: 0.36000000000000004
Test: accuracy: 0.75 | f1: 0.6666666865348816 | auc: 1.0 | 


train... loss:0.6675306558609009:   2%|██▎                                                                                               | 117/5000 [00:26<18:25,  4.42it/s]

Val update: epoch: 117 |accuracy: 1.0 | f1: 1.0 | auc: 1.0 | treshold: 0.38
Test: accuracy: 1.0 | f1: 1.0 | auc: 1.0 | 





(SlatewiseGRU(
   (embedding): RecsysEmbedding()
   (rnn_layer): GRU(4, 2, batch_first=True)
   (out_layer): Linear(in_features=2, out_features=1, bias=True)
 ),
 {'f1': 1.0, 'roc-auc': 1.0, 'accuracy': 1.0})

# ContentWise

In [5]:
# Metrics collected for every embedding type evaluated on ContentWise.
content_wise_results = []

cw_pickle = os.path.join(pkl_path, 'cw.pkl')
dataset = ContentWise.load(cw_pickle)

# Split into loaders; the helper also returns the training user-item matrix and
# the number of items seen in training (both are passed to RecsysEmbedding below).
train_loader, val_loader, test_loader, train_user_item_matrix, train_num_items = (
    get_train_val_test_tmatrix_tnumitems(dataset, batch_size=150)
)

print(f"{len(dataset)} data points among {len(train_loader)} batches")

20216 data points among 108 batches


In [6]:
# Train and evaluate one model per embedding type on ContentWise.
for embeddings in ('svd', 'neural'):
    print(f"\nEvaluating {experiment_name} with {embeddings} embeddings")

    embedding_layer = RecsysEmbedding(
        train_num_items,
        train_user_item_matrix,
        embeddings=embeddings,
    ).to('cpu')
    model = SlatewiseGRU(embedding_layer).to(device)

    # Only the metrics are kept; the returned model is discarded.
    _, metrics = train(
        model,
        train_loader,
        val_loader,
        test_loader,
        device=device,
        lr=1e-3,
        num_epochs=5000,
        early_stopping=7,
        silent=True,
    )

    # Tag the metrics with the embedding type so rows are distinguishable in the CSV.
    metrics['embeddings'] = embeddings
    content_wise_results.append(metrics)


Evaluating SlatewiseGRU with svd embeddings
Test before learning: {'f1': 0.17239901423454285, 'roc-auc': 0.4641438126564026, 'accuracy': 0.11004862934350967}


train... loss:49.18179851770401:   0%|                                                                                                  | 1/5000 [00:39<55:26:15, 39.92s/it]

Val update: epoch: 0 |accuracy: 0.1531217247247696 | f1: 0.18792562186717987 | auc: 0.6076698899269104 | treshold: 0.09
Test: accuracy: 0.17369838058948517 | f1: 0.1841088980436325 | auc: 0.6171619892120361 | 


train... loss:34.75427806377411:   0%|                                                                                                  | 2/5000 [00:47<29:18:44, 21.11s/it]

Val update: epoch: 1 |accuracy: 0.7042703628540039 | f1: 0.23013544082641602 | auc: 0.608708381652832 | treshold: 0.13
Test: accuracy: 0.7140371203422546 | f1: 0.2277820110321045 | auc: 0.6147950887680054 | 


train... loss:33.50057601928711:   0%|                                                                                                  | 3/5000 [00:55<20:54:11, 15.06s/it]

Val update: epoch: 2 |accuracy: 0.8128049373626709 | f1: 0.23322582244873047 | auc: 0.6253132820129395 | treshold: 0.13
Test: accuracy: 0.8206504583358765 | f1: 0.23027826845645905 | auc: 0.6287914514541626 | 


train... loss:32.839709997177124:   0%|                                                                                                 | 4/5000 [01:03<16:57:21, 12.22s/it]

Val update: epoch: 3 |accuracy: 0.8089298009872437 | f1: 0.241201713681221 | auc: 0.6412267088890076 | treshold: 0.13
Test: accuracy: 0.8171557188034058 | f1: 0.23751938343048096 | auc: 0.6442759037017822 | 


train... loss:32.555352568626404:   0%|                                                                                                 | 5/5000 [01:11<14:46:46, 10.65s/it]

Val update: epoch: 4 |accuracy: 0.7933675050735474 | f1: 0.2504480183124542 | auc: 0.6517706513404846 | treshold: 0.12
Test: accuracy: 0.802328884601593 | f1: 0.24520158767700195 | auc: 0.6563292741775513 | 


train... loss:32.29015788435936:   0%|                                                                                                  | 6/5000 [01:19<13:26:58,  9.70s/it]

Val update: epoch: 5 |accuracy: 0.7607762813568115 | f1: 0.24923688173294067 | auc: 0.6570345163345337 | treshold: 0.12
Test: accuracy: 0.7719316482543945 | f1: 0.2461659461259842 | auc: 0.6631274819374084 | 


train... loss:31.412767827510834:   0%|▏                                                                                               | 12/5000 [02:03<10:46:23,  7.78s/it]

Val update: epoch: 11 |accuracy: 0.7911597490310669 | f1: 0.2581848204135895 | auc: 0.65810626745224 | treshold: 0.12
Test: accuracy: 0.7994586825370789 | f1: 0.26283279061317444 | auc: 0.668986439704895 | 


train... loss:31.35055923461914:   0%|▎                                                                                                | 13/5000 [02:11<10:49:41,  7.82s/it]

Val update: epoch: 12 |accuracy: 0.8794695138931274 | f1: 0.2327272742986679 | auc: 0.663524329662323 | treshold: 0.15000000000000002
Test: accuracy: 0.8852073550224304 | f1: 0.22848576307296753 | auc: 0.6673926711082458 | 


train... loss:31.134289383888245:   0%|▎                                                                                               | 14/5000 [02:19<10:50:15,  7.82s/it]

Val update: epoch: 13 |accuracy: 0.869974672794342 | f1: 0.24425700306892395 | auc: 0.6640497446060181 | treshold: 0.13
Test: accuracy: 0.8761358261108398 | f1: 0.24123166501522064 | auc: 0.6726142168045044 | 


train... loss:31.147396057844162:   0%|▎                                                                                               | 15/5000 [02:27<10:51:55,  7.85s/it]

Val update: epoch: 14 |accuracy: 0.8480361700057983 | f1: 0.2621992230415344 | auc: 0.6647053956985474 | treshold: 0.14
Test: accuracy: 0.8570111393928528 | f1: 0.2589595317840576 | auc: 0.6712898015975952 | 


train... loss:31.132243365049362:   0%|▎                                                                                               | 16/5000 [02:35<10:52:17,  7.85s/it]

Val update: epoch: 15 |accuracy: 0.7881800532341003 | f1: 0.2650524973869324 | auc: 0.666884183883667 | treshold: 0.13
Test: accuracy: 0.7963654398918152 | f1: 0.2663809359073639 | auc: 0.6734336614608765 | 


train... loss:30.956328779459:   0%|▎                                                                                                  | 18/5000 [02:50<10:45:13,  7.77s/it]

Val update: epoch: 17 |accuracy: 0.8699901103973389 | f1: 0.2648625075817108 | auc: 0.6687343120574951 | treshold: 0.14
Test: accuracy: 0.875734269618988 | f1: 0.2505829632282257 | auc: 0.6749803423881531 | 


train... loss:30.77937039732933:   0%|▎                                                                                                | 19/5000 [02:58<10:48:20,  7.81s/it]

Val update: epoch: 18 |accuracy: 0.88271164894104 | f1: 0.24430517852306366 | auc: 0.6701903939247131 | treshold: 0.16
Test: accuracy: 0.8883006572723389 | f1: 0.23567721247673035 | auc: 0.6745896339416504 | 


train... loss:30.456127643585205:   0%|▍                                                                                               | 23/5000 [03:28<10:34:06,  7.64s/it]

Val update: epoch: 22 |accuracy: 0.8794540762901306 | f1: 0.2549618184566498 | auc: 0.6710207462310791 | treshold: 0.14
Test: accuracy: 0.8843894600868225 | f1: 0.24801702797412872 | auc: 0.6725579500198364 | 


train... loss:30.083432763814926:   1%|▍                                                                                               | 26/5000 [03:51<10:32:18,  7.63s/it]

Val update: epoch: 25 |accuracy: 0.8552924394607544 | f1: 0.2777221202850342 | auc: 0.6739699840545654 | treshold: 0.14
Test: accuracy: 0.8624540567398071 | f1: 0.26914262771606445 | auc: 0.6788731217384338 | 


train... loss:29.995643123984337:   1%|▋                                                                                               | 33/5000 [04:43<10:29:39,  7.61s/it]

Val update: epoch: 32 |accuracy: 0.7833940386772156 | f1: 0.2799958884716034 | auc: 0.6755121946334839 | treshold: 0.13
Test: accuracy: 0.7856282591819763 | f1: 0.26927560567855835 | auc: 0.6755298972129822 | 


train... loss:29.87032163143158:   1%|▋                                                                                                | 35/5000 [04:59<10:37:21,  7.70s/it]

Val update: epoch: 34 |accuracy: 0.8797783255577087 | f1: 0.27190276980400085 | auc: 0.6760849952697754 | treshold: 0.14
Test: accuracy: 0.8857873678207397 | f1: 0.26153847575187683 | auc: 0.67588210105896 | 


train... loss:29.62252241373062:   1%|▋                                                                                                | 37/5000 [05:14<10:36:34,  7.70s/it]

Val update: epoch: 36 |accuracy: 0.8637065291404724 | f1: 0.27981725335121155 | auc: 0.6765272617340088 | treshold: 0.15000000000000002
Test: accuracy: 0.8710200190544128 | f1: 0.27646616101264954 | auc: 0.680415689945221 | 


train... loss:29.62823587656021:   1%|▋                                                                                                | 38/5000 [05:22<10:41:51,  7.76s/it]

Val update: epoch: 37 |accuracy: 0.8845025897026062 | f1: 0.26678428053855896 | auc: 0.6772441267967224 | treshold: 0.16
Test: accuracy: 0.8896390795707703 | f1: 0.25199073553085327 | auc: 0.6794477105140686 | 


train... loss:29.56452512741089:   1%|▊                                                                                                | 40/5000 [05:37<10:40:57,  7.75s/it]

Val update: epoch: 39 |accuracy: 0.8257426023483276 | f1: 0.2882905602455139 | auc: 0.6812461614608765 | treshold: 0.14
Test: accuracy: 0.8320122361183167 | f1: 0.2793160676956177 | auc: 0.6825573444366455 | 


train... loss:29.593356609344482:   1%|▉                                                                                               | 49/5000 [06:52<11:35:24,  8.43s/it]



Evaluating SlatewiseGRU with neural embeddings
Test before learning: {'f1': 0.1642305850982666, 'roc-auc': 0.525184154510498, 'accuracy': 0.6215962767601013}


train... loss:44.94625923037529:   0%|                                                                                                  | 1/5000 [00:10<14:02:29, 10.11s/it]

Val update: epoch: 0 |accuracy: 0.1853424310684204 | f1: 0.19218933582305908 | auc: 0.62271648645401 | treshold: 0.09999999999999999
Test: accuracy: 0.19773061573505402 | f1: 0.18699419498443604 | auc: 0.6265320777893066 | 


train... loss:33.257901936769485:   0%|                                                                                                 | 2/5000 [00:20<13:55:09, 10.03s/it]

Val update: epoch: 1 |accuracy: 0.8082504868507385 | f1: 0.2585960030555725 | auc: 0.6666598916053772 | treshold: 0.15000000000000002
Test: accuracy: 0.815623939037323 | f1: 0.2503325641155243 | auc: 0.6668229103088379 | 


train... loss:31.76764327287674:   0%|                                                                                                  | 3/5000 [00:30<13:52:31, 10.00s/it]

Val update: epoch: 2 |accuracy: 0.8633051514625549 | f1: 0.26044103503227234 | auc: 0.6961171627044678 | treshold: 0.14
Test: accuracy: 0.8675698637962341 | f1: 0.2466796338558197 | auc: 0.6978551149368286 | 


train... loss:30.985296338796616:   0%|                                                                                                 | 4/5000 [00:40<13:51:51,  9.99s/it]

Val update: epoch: 3 |accuracy: 0.8812912702560425 | f1: 0.23893892765045166 | auc: 0.7141842842102051 | treshold: 0.16
Test: accuracy: 0.8865755796432495 | f1: 0.22936242818832397 | auc: 0.713141918182373 | 


train... loss:30.275600731372833:   0%|                                                                                                 | 5/5000 [00:49<13:45:49,  9.92s/it]

Val update: epoch: 4 |accuracy: 0.8560952544212341 | f1: 0.29080119729042053 | auc: 0.7161807417869568 | treshold: 0.16
Test: accuracy: 0.8590931296348572 | f1: 0.2744467556476593 | auc: 0.7144386172294617 | 


train... loss:29.802046537399292:   0%|                                                                                                 | 6/5000 [00:59<13:47:41,  9.94s/it]

Val update: epoch: 5 |accuracy: 0.8415210247039795 | f1: 0.30963748693466187 | auc: 0.7247567176818848 | treshold: 0.16
Test: accuracy: 0.8452180624008179 | f1: 0.2974213659763336 | auc: 0.7236754298210144 | 


train... loss:28.193827345967293:   0%|▏                                                                                               | 10/5000 [01:38<13:25:55,  9.69s/it]

Val update: epoch: 9 |accuracy: 0.8868029117584229 | f1: 0.27117297053337097 | auc: 0.728812575340271 | treshold: 0.2
Test: accuracy: 0.8917359709739685 | f1: 0.2575973868370056 | auc: 0.7287331223487854 | 


train... loss:27.960309609770775:   0%|▏                                                                                               | 11/5000 [01:48<13:36:36,  9.82s/it]

Val update: epoch: 10 |accuracy: 0.8797783255577087 | f1: 0.29637661576271057 | auc: 0.7294415235519409 | treshold: 0.19
Test: accuracy: 0.8848801851272583 | f1: 0.2837049961090088 | auc: 0.7288740873336792 | 


train... loss:26.77064834535122:   0%|▎                                                                                                | 16/5000 [02:36<13:28:21,  9.73s/it]

Val update: epoch: 15 |accuracy: 0.8843481540679932 | f1: 0.2902889549732208 | auc: 0.730168342590332 | treshold: 0.22
Test: accuracy: 0.8881370425224304 | f1: 0.28005358576774597 | auc: 0.7283613681793213 | 


train... loss:25.189408838748932:   1%|▋                                                                                               | 37/5000 [06:05<13:37:07,  9.88s/it]


In [7]:
# Persist the ContentWise metrics (one row per embedding type).
# Create the output directory first: to_csv does not create missing parent
# directories and would otherwise fail on a fresh checkout.
os.makedirs('results', exist_ok=True)
pd.DataFrame(content_wise_results).to_csv(f'results/cw_{experiment_name}.csv')

# Drop the ContentWise objects so their memory can be reclaimed before the next dataset is loaded.
del dataset, train_loader, val_loader, test_loader, train_user_item_matrix, train_num_items

# RL4RS

In [8]:
# Metrics collected for every embedding type evaluated on RL4RS.
rl4rs_results = []

rl4rs_pickle = os.path.join(pkl_path, 'rl4rs.pkl')
dataset = RL4RS.load(rl4rs_pickle)

# Split into loaders; the helper also returns the training user-item matrix and
# the number of items seen in training (both are passed to RecsysEmbedding below).
train_loader, val_loader, test_loader, train_user_item_matrix, train_num_items = (
    get_train_val_test_tmatrix_tnumitems(dataset, batch_size=350)
)

print(f"{len(dataset)} data points among {len(train_loader)} batches")

45942 data points among 106 batches


In [9]:
# Train and evaluate one model per embedding type on RL4RS.
# NOTE(review): in the recorded run the 'explicit' iteration stopped with
# "input.size(-1) must be equal to input_size. Expected 64, got 80", i.e. the
# GRU input (2 * item embedding width) did not match 2 * embedding_dim, so
# 'svd' was never reached and the results cell below was not executed.
# Resolve the size mismatch before relying on this cell under "Run All".
for embeddings in ('neural', 'explicit', 'svd'):
    print(f"\nEvaluating {experiment_name} with {embeddings} embeddings")

    embedding_layer = RecsysEmbedding(
        train_num_items,
        train_user_item_matrix,
        embeddings=embeddings,
    )
    model = SlatewiseGRU(embedding_layer).to(device)

    best_model, metrics = train(
        model,
        train_loader,
        val_loader,
        test_loader,
        device=device,
        lr=1e-3,
        num_epochs=5000,
        early_stopping=7,
        silent=True,
    )

    # Tag the metrics with the embedding type so rows are distinguishable in the CSV.
    metrics['embeddings'] = embeddings
    rl4rs_results.append(metrics)


Evaluating SlatewiseGRU with neural embeddings
Test before learning: {'f1': 0.16593550145626068, 'roc-auc': 0.48787927627563477, 'accuracy': 0.3701607882976532}


train... loss:63.73415842652321:   0%|                                                                                                  | 1/5000 [01:03<87:35:14, 63.08s/it]

Val update: epoch: 0 |accuracy: 0.7681758999824524 | f1: 0.8368760347366333 | auc: 0.8075522184371948 | treshold: 0.39
Test: accuracy: 0.7684923410415649 | f1: 0.8374698758125305 | auc: 0.8102037906646729 | 


train... loss:48.87808521091938:   0%|                                                                                                  | 2/5000 [01:09<41:34:08, 29.94s/it]

Val update: epoch: 1 |accuracy: 0.8034634590148926 | f1: 0.8563498854637146 | auc: 0.8670310974121094 | treshold: 0.41000000000000003
Test: accuracy: 0.7996131181716919 | f1: 0.8537545204162598 | auc: 0.8656007051467896 | 


train... loss:41.56733526289463:   0%|                                                                                                  | 3/5000 [01:16<26:58:36, 19.43s/it]

Val update: epoch: 2 |accuracy: 0.8006578683853149 | f1: 0.8590484857559204 | auc: 0.887866735458374 | treshold: 0.42000000000000004
Test: accuracy: 0.7945109605789185 | f1: 0.8547151684761047 | auc: 0.8860260844230652 | 


train... loss:36.58666281402111:   0%|                                                                                                  | 4/5000 [01:23<20:03:45, 14.46s/it]

Val update: epoch: 3 |accuracy: 0.8312049508094788 | f1: 0.8686652183532715 | auc: 0.9035803079605103 | treshold: 0.38
Test: accuracy: 0.8259944319725037 | f1: 0.8644003868103027 | auc: 0.8993576169013977 | 


train... loss:34.54479628801346:   0%|                                                                                                  | 5/5000 [01:30<16:14:55, 11.71s/it]

Val update: epoch: 4 |accuracy: 0.8371063470840454 | f1: 0.8786508440971375 | auc: 0.9099557399749756 | treshold: 0.44
Test: accuracy: 0.8349897265434265 | f1: 0.8767563700675964 | auc: 0.9076763391494751 | 


train... loss:32.81495387852192:   0%|                                                                                                  | 6/5000 [01:37<13:56:27, 10.05s/it]

Val update: epoch: 5 |accuracy: 0.8348812460899353 | f1: 0.8683494925498962 | auc: 0.9105205535888672 | treshold: 0.4
Test: accuracy: 0.8310965895652771 | f1: 0.8650841116905212 | auc: 0.9077848792076111 | 


train... loss:30.442696809768677:   0%|▏                                                                                                | 9/5000 [01:56<10:33:48,  7.62s/it]

Val update: epoch: 8 |accuracy: 0.8324142694473267 | f1: 0.8638783693313599 | auc: 0.9106537103652954 | treshold: 0.4
Test: accuracy: 0.8292346596717834 | f1: 0.8608582615852356 | auc: 0.9097559452056885 | 


train... loss:30.013588160276413:   0%|▏                                                                                               | 10/5000 [02:03<10:11:59,  7.36s/it]

Val update: epoch: 9 |accuracy: 0.8489575982093811 | f1: 0.8829845190048218 | auc: 0.918175995349884 | treshold: 0.36000000000000004
Test: accuracy: 0.8487486243247986 | f1: 0.8826960325241089 | auc: 0.9174371957778931 | 


train... loss:27.834611371159554:   0%|▎                                                                                                | 17/5000 [02:47<9:05:09,  6.56s/it]

Val update: epoch: 16 |accuracy: 0.8505780696868896 | f1: 0.8832356929779053 | auc: 0.919102668762207 | treshold: 0.37
Test: accuracy: 0.8501511216163635 | f1: 0.8826969861984253 | auc: 0.9197525978088379 | 


train... loss:22.061962768435478:   2%|█▉                                                                                              | 101/5000 [11:40<9:26:08,  6.93s/it]



Evaluating SlatewiseGRU with explicit embeddings


RuntimeError: input.size(-1) must be equal to input_size. Expected 64, got 80

In [None]:
# Persist the RL4RS metrics (one row per embedding type).
# Create the output directory first: to_csv does not create missing parent
# directories and would otherwise fail on a fresh checkout.
os.makedirs('results', exist_ok=True)
pd.DataFrame(rl4rs_results).to_csv(f'results/rl4rs_{experiment_name}.csv')

# Drop the RL4RS objects so their memory can be reclaimed.
del dataset, train_loader, val_loader, test_loader, train_user_item_matrix, train_num_items