In [1]:
import os
import torch
import random
import datetime
import pandas as pd
import numpy as np

from torch.utils.data import Dataset
from src.datasets import RL4RS, ContentWise, DummyData
from src.utils import train, get_dummy_data, get_train_val_test_tmatrix_tnumitems, get_svd_encoder
from src.embeddings import RecsysEmbedding

experiment_name = 'SessionwiseGRU'
device = 'cuda:2'
seed = 123
pkl_path = '../pkl/'

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

<torch._C.Generator at 0x7f305fac2bb0>

In [2]:
torch.__version__

'1.12.1'

# Модель

In [3]:
import torch.nn.functional as F
torch.autograd.set_detect_anomaly(True)

class SessionwiseGRU(torch.nn.Module):
    def __init__(self, embedding, output_dim=1, dropout = 0.1):
        super().__init__()
        self.embedding_dim = embedding.embedding_dim
        self.embedding = embedding
        self.rnn_layer = torch.nn.GRU(
            input_size = embedding.embedding_dim, 
            hidden_size = embedding.embedding_dim, 
            batch_first = True,
            dropout=dropout
        )
        self.out_layer = torch.nn.Linear(embedding.embedding_dim, output_dim)


    def forward(self, batch):
        shp = batch['slates_item_indexes'].shape
        item_embs, user_embs = self.embedding(batch)
        item_embs = item_embs.flatten(-3, -2)
        
        # while training, let out model see the future 
        # while testing, it can see only the 
        if self.training:
            indices = (batch['length'] - 1)
        else:
            indices = (batch['in_length'] - 1)
        
        indices[indices<0] = 0
        indices = indices[:, None, None].repeat(1, 1, user_embs.size(-1))
        user_embs = user_embs.gather(1, indices).squeeze(-2).unsqueeze(0)

#         print(indices.shape, user_embs.shape, item_embs.shape, )
        rnn_out, _ = self.rnn_layer(
            item_embs,
            user_embs,
        )
        return self.out_layer(rnn_out).reshape(shp)

# Игрушечный датасет: проверим, что сходится к идеальным метрикам

In [4]:
d = DummyData()
dummy_loader, dummy_matrix = get_dummy_data(d)

model = SessionwiseGRU(
    RecsysEmbedding(d.n_items, dummy_matrix, embeddings='neural').to('cpu'),
    output_dim=1
).to('cpu')

train(
    model, 
    dummy_loader, dummy_loader, dummy_loader,
    device=device, lr=1e-3, num_epochs=5000, dummy=True,
    silent=True,
)


biulding affinity matrix...


3it [00:00, 3669.56it/s]


Test before learning: {'f1': 0.4000000059604645, 'roc-auc': 0.3333333134651184, 'accuracy': 0.25}


train... loss:0.7189432978630066:   0%|                                                                                                    | 1/5000 [00:00<27:56,  2.98it/s]

Val update: epoch: 0 |accuracy: 0.5 | f1: 0.5 | auc: 0.3333333134651184 | treshold: 0.51
Test: accuracy: 0.5 | f1: 0.5 | auc: 0.3333333134651184 | 


train... loss:0.6976741552352905:   0%|                                                                                                    | 3/5000 [00:01<30:40,  2.72it/s]

Val update: epoch: 3 |accuracy: 1.0 | f1: 1.0 | auc: 1.0 | treshold: 0.54
Test: accuracy: 1.0 | f1: 1.0 | auc: 1.0 | 





(SessionwiseGRU(
   (embedding): RecsysEmbedding(
     (item_embeddings): Embedding(5, 32)
   )
   (rnn_layer): GRU(32, 32, batch_first=True, dropout=0.1)
   (out_layer): Linear(in_features=32, out_features=1, bias=True)
 ),
 {'f1': 1.0, 'roc-auc': 1.0, 'accuracy': 1.0})

# ContentWise

In [5]:
content_wise_results = []
dataset = ContentWise.load(os.path.join(pkl_path, 'cw.pkl'))
(
    train_loader, 
    val_loader, 
    test_loader, 
    train_user_item_matrix, 
    train_num_items 
) = get_train_val_test_tmatrix_tnumitems(dataset, batch_size=150)

print(f"{len(dataset)} data points among {len(train_loader)} batches")

20216 data points among 108 batches


In [6]:
for embeddings in ['svd', 'neural']:
    print(f"\nEvaluating {experiment_name} with {embeddings} embeddings")
    
    model = SessionwiseGRU(
        RecsysEmbedding(
            train_num_items, 
            train_user_item_matrix, 
            embeddings=embeddings
        ),
        output_dim=1
    ).to(device)

    _, metrics = train(
        model, 
        train_loader, val_loader, test_loader, 
        device=device, lr=1e-3, num_epochs=5000, early_stopping=7,
       silent=True, 
    )
    
    metrics['embeddings'] = embeddings
    content_wise_results.append(metrics)


Evaluating SessionwiseGRU with svd embeddings




Test before learning: {'f1': 0.17965421080589294, 'roc-auc': 0.5195829272270203, 'accuracy': 0.10167654603719711}


train... loss:47.48355868458748:   0%|                                                                                                  | 1/5000 [00:50<69:31:04, 50.06s/it]

Val update: epoch: 0 |accuracy: 0.09938888251781464 | f1: 0.18080750107765198 | auc: 0.5529186725616455 | treshold: 0.04
Test: accuracy: 0.09880203753709793 | f1: 0.17983591556549072 | auc: 0.5536670088768005 | 


train... loss:34.07816553115845:   0%|                                                                                                  | 2/5000 [01:08<43:16:37, 31.17s/it]

Val update: epoch: 1 |accuracy: 0.09938888251781464 | f1: 0.18080750107765198 | auc: 0.5737990140914917 | treshold: 0.060000000000000005
Test: accuracy: 0.09880203753709793 | f1: 0.17983591556549072 | auc: 0.5750941038131714 | 


train... loss:33.60349127650261:   0%|                                                                                                  | 3/5000 [01:25<34:52:36, 25.13s/it]

Val update: epoch: 2 |accuracy: 0.8482187986373901 | f1: 0.229086235165596 | auc: 0.5773776769638062 | treshold: 0.06999999999999999
Test: accuracy: 0.8440237045288086 | f1: 0.23024360835552216 | auc: 0.5801332592964172 | 


train... loss:33.10733485221863:   0%|                                                                                                  | 4/5000 [01:43<30:46:44, 22.18s/it]

Val update: epoch: 3 |accuracy: 0.20561932027339935 | f1: 0.18460549414157867 | auc: 0.6166515946388245 | treshold: 0.09
Test: accuracy: 0.20728109776973724 | f1: 0.18304769694805145 | auc: 0.6163737773895264 | 


train... loss:32.825560837984085:   0%|                                                                                                 | 5/5000 [02:01<28:33:16, 20.58s/it]

Val update: epoch: 4 |accuracy: 0.7610523104667664 | f1: 0.2386854737997055 | auc: 0.6251019239425659 | treshold: 0.12
Test: accuracy: 0.7525509595870972 | f1: 0.23773759603500366 | auc: 0.6277182102203369 | 


train... loss:32.72361943125725:   0%|                                                                                                  | 6/5000 [02:19<27:15:15, 19.65s/it]

Val update: epoch: 5 |accuracy: 0.7985392808914185 | f1: 0.2456747442483902 | auc: 0.6360332369804382 | treshold: 0.12
Test: accuracy: 0.7901303172111511 | f1: 0.24763959646224976 | auc: 0.6363770365715027 | 


train... loss:32.56756290793419:   0%|▏                                                                                                 | 7/5000 [02:36<26:22:03, 19.01s/it]

Val update: epoch: 6 |accuracy: 0.8396482467651367 | f1: 0.2465331256389618 | auc: 0.6431294083595276 | treshold: 0.12
Test: accuracy: 0.8343467116355896 | f1: 0.2469727098941803 | auc: 0.6429708003997803 | 


train... loss:32.468492209911346:   0%|▏                                                                                                | 8/5000 [02:55<25:58:28, 18.73s/it]

Val update: epoch: 7 |accuracy: 0.8522283434867859 | f1: 0.23550277948379517 | auc: 0.6435239315032959 | treshold: 0.12
Test: accuracy: 0.8487493395805359 | f1: 0.2379435896873474 | auc: 0.6440587043762207 | 


train... loss:32.41186237335205:   0%|▏                                                                                                 | 9/5000 [03:12<25:34:41, 18.45s/it]

Val update: epoch: 8 |accuracy: 0.8736622333526611 | f1: 0.22066935896873474 | auc: 0.6469062566757202 | treshold: 0.12
Test: accuracy: 0.8727537989616394 | f1: 0.22267168760299683 | auc: 0.6493200063705444 | 


train... loss:32.400557577610016:   0%|▏                                                                                               | 10/5000 [03:30<25:13:02, 18.19s/it]

Val update: epoch: 9 |accuracy: 0.8719183206558228 | f1: 0.22564657032489777 | auc: 0.6493664383888245 | treshold: 0.12
Test: accuracy: 0.8696686029434204 | f1: 0.22733761370182037 | auc: 0.6516949534416199 | 


train... loss:32.2983860373497:   0%|▏                                                                                                 | 11/5000 [03:48<25:01:10, 18.05s/it]

Val update: epoch: 10 |accuracy: 0.8640781044960022 | f1: 0.23091845214366913 | auc: 0.6524428129196167 | treshold: 0.12
Test: accuracy: 0.8618577718734741 | f1: 0.2318185567855835 | auc: 0.652309000492096 | 


train... loss:32.25128963589668:   0%|▎                                                                                                | 13/5000 [04:22<24:23:07, 17.60s/it]

Val update: epoch: 12 |accuracy: 0.8153823018074036 | f1: 0.2501513361930847 | auc: 0.652967095375061 | treshold: 0.11
Test: accuracy: 0.8106883764266968 | f1: 0.2514727711677551 | auc: 0.6553683280944824 | 


train... loss:32.16134759783745:   0%|▎                                                                                                | 14/5000 [04:40<24:25:06, 17.63s/it]

Val update: epoch: 13 |accuracy: 0.8052914142608643 | f1: 0.24860511720180511 | auc: 0.655693531036377 | treshold: 0.12
Test: accuracy: 0.7980014085769653 | f1: 0.24705485999584198 | auc: 0.6563364267349243 | 


train... loss:32.223952025175095:   0%|▎                                                                                               | 15/5000 [04:57<24:20:31, 17.58s/it]

Val update: epoch: 14 |accuracy: 0.7623043656349182 | f1: 0.2503643035888672 | auc: 0.6567521691322327 | treshold: 0.11
Test: accuracy: 0.7563886642456055 | f1: 0.2515374422073364 | auc: 0.658530056476593 | 


train... loss:32.157773315906525:   0%|▎                                                                                               | 16/5000 [05:15<24:27:25, 17.67s/it]

Val update: epoch: 15 |accuracy: 0.7837531566619873 | f1: 0.2533964514732361 | auc: 0.6578312516212463 | treshold: 0.12
Test: accuracy: 0.7766908407211304 | f1: 0.2527948319911957 | auc: 0.6594458222389221 | 


train... loss:32.09799614548683:   0%|▎                                                                                                | 18/5000 [05:49<24:14:14, 17.51s/it]

Val update: epoch: 17 |accuracy: 0.7781487703323364 | f1: 0.2499496042728424 | auc: 0.6597491502761841 | treshold: 0.12
Test: accuracy: 0.7701441645622253 | f1: 0.24789480865001678 | auc: 0.6572144627571106 | 


train... loss:32.0808362364769:   0%|▍                                                                                                 | 21/5000 [06:41<23:57:35, 17.32s/it]

Val update: epoch: 20 |accuracy: 0.772246241569519 | f1: 0.25251930952072144 | auc: 0.6598657369613647 | treshold: 0.12
Test: accuracy: 0.7638834714889526 | f1: 0.2529403269290924 | auc: 0.6599140167236328 | 


train... loss:32.050273180007935:   0%|▍                                                                                               | 22/5000 [06:58<24:05:41, 17.43s/it]

Val update: epoch: 21 |accuracy: 0.7709792852401733 | f1: 0.2519838511943817 | auc: 0.659955620765686 | treshold: 0.12
Test: accuracy: 0.7626493573188782 | f1: 0.25124627351760864 | auc: 0.6606194972991943 | 


train... loss:32.022758930921555:   0%|▍                                                                                               | 23/5000 [07:17<24:21:46, 17.62s/it]

Val update: epoch: 22 |accuracy: 0.7683559656143188 | f1: 0.2509278357028961 | auc: 0.659984827041626 | treshold: 0.12
Test: accuracy: 0.7606627941131592 | f1: 0.24975232779979706 | auc: 0.6588616371154785 | 


train... loss:31.965380758047104:   1%|▍                                                                                               | 26/5000 [08:09<24:13:55, 17.54s/it]

Val update: epoch: 25 |accuracy: 0.7622894644737244 | f1: 0.24794869124889374 | auc: 0.6604293584823608 | treshold: 0.12
Test: accuracy: 0.7528519630432129 | f1: 0.24718071520328522 | auc: 0.6627093553543091 | 


train... loss:31.94447273015976:   1%|▌                                                                                                | 27/5000 [08:27<24:25:21, 17.68s/it]

Val update: epoch: 26 |accuracy: 0.7659412622451782 | f1: 0.24891184270381927 | auc: 0.6611189842224121 | treshold: 0.12
Test: accuracy: 0.757773220539093 | f1: 0.2502445578575134 | auc: 0.6589192152023315 | 


train... loss:31.919546008110046:   1%|▌                                                                                               | 29/5000 [09:03<24:49:50, 17.98s/it]

Val update: epoch: 28 |accuracy: 0.7301236987113953 | f1: 0.2503933012485504 | auc: 0.6616170406341553 | treshold: 0.12
Test: accuracy: 0.7249345183372498 | f1: 0.25085052847862244 | auc: 0.6601583957672119 | 


train... loss:31.856757700443268:   1%|▌                                                                                               | 32/5000 [09:56<24:43:52, 17.92s/it]

Val update: epoch: 31 |accuracy: 0.7502012252807617 | f1: 0.24850903451442719 | auc: 0.6633386611938477 | treshold: 0.12
Test: accuracy: 0.7458236813545227 | f1: 0.25147366523742676 | auc: 0.6625568866729736 | 


train... loss:31.82919842004776:   1%|▋                                                                                                | 35/5000 [10:50<24:50:05, 18.01s/it]

Val update: epoch: 34 |accuracy: 0.7543001770973206 | f1: 0.2520192265510559 | auc: 0.6639922857284546 | treshold: 0.12
Test: accuracy: 0.7491797804832458 | f1: 0.25271275639533997 | auc: 0.6649953126907349 | 


train... loss:31.819978803396225:   1%|▋                                                                                               | 37/5000 [11:26<24:49:06, 18.00s/it]

Val update: epoch: 36 |accuracy: 0.7552988529205322 | f1: 0.2514249384403229 | auc: 0.6645603775978088 | treshold: 0.12
Test: accuracy: 0.7499774098396301 | f1: 0.2534489631652832 | auc: 0.6647671461105347 | 


train... loss:31.82412001490593:   1%|▋                                                                                                | 38/5000 [11:44<24:57:54, 18.11s/it]

Val update: epoch: 37 |accuracy: 0.7386048436164856 | f1: 0.25224918127059937 | auc: 0.6651040315628052 | treshold: 0.12
Test: accuracy: 0.7350178956985474 | f1: 0.25441455841064453 | auc: 0.6661825180053711 | 


train... loss:31.747672587633133:   1%|▊                                                                                               | 40/5000 [12:20<24:57:01, 18.11s/it]

Val update: epoch: 39 |accuracy: 0.7188999652862549 | f1: 0.25183480978012085 | auc: 0.6655407547950745 | treshold: 0.12
Test: accuracy: 0.7173042893409729 | f1: 0.25329941511154175 | auc: 0.6674298048019409 | 


train... loss:31.739202857017517:   1%|▊                                                                                               | 41/5000 [12:39<25:10:01, 18.27s/it]

Val update: epoch: 40 |accuracy: 0.7263824939727783 | f1: 0.24874155223369598 | auc: 0.6655572652816772 | treshold: 0.12
Test: accuracy: 0.7240164875984192 | f1: 0.25065380334854126 | auc: 0.6660141944885254 | 


train... loss:31.31447806954384:   1%|█▏                                                                                               | 64/5000 [19:19<24:08:21, 17.61s/it]

Val update: epoch: 63 |accuracy: 0.8633924722671509 | f1: 0.25140896439552307 | auc: 0.6667841076850891 | treshold: 0.13
Test: accuracy: 0.8607741594314575 | f1: 0.25135549902915955 | auc: 0.6693324446678162 | 


train... loss:31.150575578212738:   1%|█▎                                                                                              | 68/5000 [20:30<24:15:18, 17.70s/it]

Val update: epoch: 67 |accuracy: 0.8011178970336914 | f1: 0.25950387120246887 | auc: 0.666931688785553 | treshold: 0.13
Test: accuracy: 0.7985883355140686 | f1: 0.26007628440856934 | auc: 0.6619095802307129 | 


train... loss:30.96721801161766:   1%|█▎                                                                                               | 69/5000 [20:48<24:31:22, 17.90s/it]

Val update: epoch: 68 |accuracy: 0.7112088203430176 | f1: 0.2559238076210022 | auc: 0.6688295006752014 | treshold: 0.12
Test: accuracy: 0.714790940284729 | f1: 0.25521713495254517 | auc: 0.6664844751358032 | 


train... loss:30.89816462993622:   2%|█▍                                                                                               | 75/5000 [22:34<24:26:23, 17.86s/it]

Val update: epoch: 74 |accuracy: 0.8642271757125854 | f1: 0.25145861506462097 | auc: 0.669046938419342 | treshold: 0.13
Test: accuracy: 0.8594799041748047 | f1: 0.25464996695518494 | auc: 0.6636302471160889 | 


train... loss:30.893246829509735:   2%|█▍                                                                                              | 78/5000 [23:27<24:31:05, 17.93s/it]

Val update: epoch: 77 |accuracy: 0.8422268629074097 | f1: 0.2674233615398407 | auc: 0.6692649126052856 | treshold: 0.13
Test: accuracy: 0.8371309041976929 | f1: 0.2640098035335541 | auc: 0.664955735206604 | 


train... loss:30.820012241601944:   2%|█▌                                                                                              | 79/5000 [23:46<24:43:09, 18.08s/it]

Val update: epoch: 78 |accuracy: 0.818974494934082 | f1: 0.26682764291763306 | auc: 0.6701809167861938 | treshold: 0.13
Test: accuracy: 0.817505955696106 | f1: 0.2684604227542877 | auc: 0.6664113402366638 | 


train... loss:30.667582601308823:   2%|█▋                                                                                              | 85/5000 [25:31<24:20:09, 17.82s/it]

Val update: epoch: 84 |accuracy: 0.8683112263679504 | f1: 0.25411567091941833 | auc: 0.6725221872329712 | treshold: 0.13
Test: accuracy: 0.8647623658180237 | f1: 0.2525370121002197 | auc: 0.6689162850379944 | 


train... loss:30.851949512958527:   2%|█▋                                                                                              | 91/5000 [27:33<24:46:46, 18.17s/it]



Evaluating SessionwiseGRU with neural embeddings
Test before learning: {'f1': 0.18101291358470917, 'roc-auc': 0.5300668478012085, 'accuracy': 0.462887167930603}


train... loss:42.822390377521515:   0%|                                                                                                 | 1/5000 [00:20<28:29:16, 20.52s/it]

Val update: epoch: 0 |accuracy: 0.09938888251781464 | f1: 0.18080750107765198 | auc: 0.6257848739624023 | treshold: 0.060000000000000005
Test: accuracy: 0.09880203753709793 | f1: 0.17983591556549072 | auc: 0.6229747533798218 | 


train... loss:32.94593012332916:   0%|                                                                                                  | 2/5000 [00:41<28:42:14, 20.68s/it]

Val update: epoch: 1 |accuracy: 0.8451036214828491 | f1: 0.23464427888393402 | auc: 0.6560053825378418 | treshold: 0.09999999999999999
Test: accuracy: 0.8449116349220276 | f1: 0.2408839762210846 | auc: 0.6561639308929443 | 


train... loss:32.16994056105614:   0%|                                                                                                  | 3/5000 [01:01<28:34:29, 20.59s/it]

Val update: epoch: 2 |accuracy: 0.5961991548538208 | f1: 0.2441338151693344 | auc: 0.6793311238288879 | treshold: 0.13
Test: accuracy: 0.5988321304321289 | f1: 0.24208131432533264 | auc: 0.6787958741188049 | 


train... loss:31.655621886253357:   0%|                                                                                                 | 4/5000 [01:22<28:32:34, 20.57s/it]

Val update: epoch: 3 |accuracy: 0.8607989549636841 | f1: 0.24752235412597656 | auc: 0.6940723657608032 | treshold: 0.14
Test: accuracy: 0.859856128692627 | f1: 0.2558734118938446 | auc: 0.6911475658416748 | 


train... loss:31.4057614505291:   0%|                                                                                                   | 5/5000 [01:43<28:37:53, 20.64s/it]

Val update: epoch: 4 |accuracy: 0.7942167520523071 | f1: 0.2810872793197632 | auc: 0.7036116123199463 | treshold: 0.12
Test: accuracy: 0.793953001499176 | f1: 0.2778627574443817 | auc: 0.699545681476593 | 


train... loss:31.13055670261383:   0%|                                                                                                  | 6/5000 [02:03<28:40:35, 20.67s/it]

Val update: epoch: 5 |accuracy: 0.6803398132324219 | f1: 0.27444347739219666 | auc: 0.7069129943847656 | treshold: 0.15000000000000002
Test: accuracy: 0.6775727868080139 | f1: 0.2694039046764374 | auc: 0.7031509876251221 | 


train... loss:30.85667198896408:   0%|▏                                                                                                 | 8/5000 [02:43<28:12:52, 20.35s/it]

Val update: epoch: 7 |accuracy: 0.8145326972007751 | f1: 0.28393852710723877 | auc: 0.7141842246055603 | treshold: 0.12
Test: accuracy: 0.8141197562217712 | f1: 0.27953100204467773 | auc: 0.7090149521827698 | 


train... loss:30.600494623184204:   0%|▏                                                                                                | 9/5000 [03:04<28:15:35, 20.38s/it]

Val update: epoch: 8 |accuracy: 0.6796839833259583 | f1: 0.27885904908180237 | auc: 0.7166664600372314 | treshold: 0.15000000000000002
Test: accuracy: 0.6777834892272949 | f1: 0.27171915769577026 | auc: 0.7120574116706848 | 


train... loss:30.296895682811737:   0%|▏                                                                                               | 11/5000 [03:44<28:08:38, 20.31s/it]

Val update: epoch: 10 |accuracy: 0.8438962697982788 | f1: 0.2822287678718567 | auc: 0.7180449366569519 | treshold: 0.15000000000000002
Test: accuracy: 0.846190869808197 | f1: 0.2804843783378601 | auc: 0.713841438293457 | 


train... loss:30.304493069648743:   0%|▏                                                                                               | 12/5000 [04:05<28:20:35, 20.46s/it]

Val update: epoch: 11 |accuracy: 0.8629602193832397 | f1: 0.26869234442710876 | auc: 0.7181637287139893 | treshold: 0.14
Test: accuracy: 0.8651686906814575 | f1: 0.27357494831085205 | auc: 0.7159227132797241 | 


train... loss:29.98767602443695:   0%|▎                                                                                                | 13/5000 [04:26<28:30:34, 20.58s/it]

Val update: epoch: 12 |accuracy: 0.8201818466186523 | f1: 0.29160305857658386 | auc: 0.7207666039466858 | treshold: 0.15000000000000002
Test: accuracy: 0.8233001232147217 | f1: 0.288208544254303 | auc: 0.7173226475715637 | 


train... loss:29.789395704865456:   0%|▎                                                                                               | 14/5000 [04:46<28:30:48, 20.59s/it]

Val update: epoch: 13 |accuracy: 0.7365479469299316 | f1: 0.2887610197067261 | auc: 0.7218604683876038 | treshold: 0.17
Test: accuracy: 0.7371549606323242 | f1: 0.2867644131183624 | auc: 0.7185161113739014 | 


train... loss:29.5811298340559:   0%|▎                                                                                                 | 15/5000 [05:07<28:28:22, 20.56s/it]

Val update: epoch: 14 |accuracy: 0.8000149130821228 | f1: 0.29328417778015137 | auc: 0.7234261631965637 | treshold: 0.17
Test: accuracy: 0.8037804961204529 | f1: 0.29531943798065186 | auc: 0.7211157083511353 | 


train... loss:29.277683407068253:   0%|▎                                                                                               | 16/5000 [05:28<28:31:18, 20.60s/it]

Val update: epoch: 15 |accuracy: 0.8631986975669861 | f1: 0.2849797308444977 | auc: 0.7274666428565979 | treshold: 0.16
Test: accuracy: 0.86708003282547 | f1: 0.2884305417537689 | auc: 0.7238579988479614 | 


train... loss:27.86328910291195:   0%|▍                                                                                                | 24/5000 [08:05<27:27:01, 19.86s/it]

Val update: epoch: 23 |accuracy: 0.8442241549491882 | f1: 0.3070088326931 | auc: 0.7279775142669678 | treshold: 0.18000000000000002
Test: accuracy: 0.8462962508201599 | f1: 0.3075462877750397 | auc: 0.7238301038742065 | 


train... loss:27.806676656007767:   0%|▍                                                                                               | 25/5000 [08:25<27:51:04, 20.15s/it]

Val update: epoch: 24 |accuracy: 0.8038306832313538 | f1: 0.3053048253059387 | auc: 0.7285906672477722 | treshold: 0.19
Test: accuracy: 0.8057520389556885 | f1: 0.30296483635902405 | auc: 0.7223894000053406 | 


train... loss:26.249242529273033:   2%|██▏                                                                                            | 113/5000 [37:00<26:40:32, 19.65s/it]


In [12]:
pd.DataFrame(content_wise_results).to_csv(f'results/cw_{experiment_name}.csv')
del dataset, train_loader, val_loader, test_loader, train_user_item_matrix, train_num_items

# RL4RS

In [8]:
rl4rs_results = []
dataset = RL4RS.load(os.path.join(pkl_path, 'rl4rs.pkl'))
(
    train_loader, 
    val_loader, 
    test_loader, 
    train_user_item_matrix, 
    train_num_items 
) = get_train_val_test_tmatrix_tnumitems(dataset, batch_size=350)

print(f"{len(dataset)} data points among {len(train_loader)} batches")

45942 data points among 106 batches


In [9]:
for embeddings in ['explicit', 'neural', 'svd']:
    print(f"\nEvaluating {experiment_name} with {embeddings} embeddings")

    model = SessionwiseGRU(
        RecsysEmbedding(
            train_num_items, 
            train_user_item_matrix, 
            embeddings=embeddings,
            embedding_dim=40
        ),
        output_dim=1
    ).to(device)


    best_model, metrics = train(
        model, 
        train_loader, val_loader, test_loader, 
        device=device, lr=1e-3, num_epochs=5000, early_stopping=7,
        silent=True
    )
    
    metrics['embeddings'] = embeddings
    rl4rs_results.append(metrics)


Evaluating SessionwiseGRU with explicit embeddings




Test before learning: {'f1': 0.4881993234157562, 'roc-auc': 0.42448124289512634, 'accuracy': 0.4352557063102722}


train... loss:60.50820517539978:   0%|                                                                                                  | 1/5000 [01:02<87:27:11, 62.98s/it]

Val update: epoch: 0 |accuracy: 0.7574372291564941 | f1: 0.8132645487785339 | auc: 0.8160781264305115 | treshold: 0.44
Test: accuracy: 0.7682505249977112 | f1: 0.8227155208587646 | auc: 0.8280416131019592 | 


train... loss:51.59247046709061:   0%|                                                                                                  | 2/5000 [01:09<41:40:07, 30.01s/it]

Val update: epoch: 1 |accuracy: 0.7606781721115112 | f1: 0.8016597032546997 | auc: 0.8547112941741943 | treshold: 0.38
Test: accuracy: 0.770088255405426 | f1: 0.8111618757247925 | auc: 0.8619877099990845 | 


train... loss:45.91336426138878:   0%|                                                                                                  | 3/5000 [01:16<27:05:05, 19.51s/it]

Val update: epoch: 2 |accuracy: 0.7967880964279175 | f1: 0.8521191477775574 | auc: 0.8629810810089111 | treshold: 0.38
Test: accuracy: 0.8000967502593994 | f1: 0.855099618434906 | auc: 0.8694379329681396 | 


train... loss:42.31544545292854:   0%|                                                                                                  | 4/5000 [01:23<20:13:50, 14.58s/it]

Val update: epoch: 3 |accuracy: 0.7913945913314819 | f1: 0.8569438457489014 | auc: 0.8778505325317383 | treshold: 0.36000000000000004
Test: accuracy: 0.793446958065033 | f1: 0.8586650490760803 | auc: 0.883622944355011 | 


train... loss:40.52724951505661:   0%|                                                                                                  | 5/5000 [01:30<16:25:37, 11.84s/it]

Val update: epoch: 4 |accuracy: 0.8128476738929749 | f1: 0.8575531244277954 | auc: 0.8849530220031738 | treshold: 0.34
Test: accuracy: 0.8165881037712097 | f1: 0.8608946204185486 | auc: 0.8884271383285522 | 


train... loss:35.50683119893074:   0%|▏                                                                                                 | 7/5000 [01:44<12:34:09,  9.06s/it]

Val update: epoch: 6 |accuracy: 0.8227881789207458 | f1: 0.8744451999664307 | auc: 0.9142627120018005 | treshold: 0.33
Test: accuracy: 0.8230927586555481 | f1: 0.8747988939285278 | auc: 0.9165470600128174 | 


train... loss:31.846922010183334:   0%|▏                                                                                                | 9/5000 [01:58<10:55:21,  7.88s/it]

Val update: epoch: 8 |accuracy: 0.853625476360321 | f1: 0.8877304792404175 | auc: 0.918938398361206 | treshold: 0.42000000000000004
Test: accuracy: 0.8531011939048767 | f1: 0.8879172205924988 | auc: 0.9205757975578308 | 


train... loss:29.80987687408924:   0%|▎                                                                                                 | 13/5000 [02:24<9:38:31,  6.96s/it]

Val update: epoch: 12 |accuracy: 0.8500943183898926 | f1: 0.8819878101348877 | auc: 0.9194398522377014 | treshold: 0.36000000000000004
Test: accuracy: 0.8499577045440674 | f1: 0.8824120163917542 | auc: 0.9203335642814636 | 


train... loss:28.52599011361599:   0%|▎                                                                                                 | 18/5000 [02:57<9:22:02,  6.77s/it]

Val update: epoch: 17 |accuracy: 0.8536980748176575 | f1: 0.8892307281494141 | auc: 0.922161877155304 | treshold: 0.44
Test: accuracy: 0.8534397482872009 | f1: 0.8895650506019592 | auc: 0.9236644506454468 | 


train... loss:26.27913323044777:   1%|█                                                                                                | 53/5000 [06:51<10:39:48,  7.76s/it]



Evaluating SessionwiseGRU with neural embeddings
Test before learning: {'f1': 0.3928951621055603, 'roc-auc': 0.49920767545700073, 'accuracy': 0.4346753656864166}


train... loss:62.855235666036606:   0%|                                                                                                 | 1/5000 [00:07<11:03:17,  7.96s/it]

Val update: epoch: 0 |accuracy: 0.742030680179596 | f1: 0.8285043835639954 | auc: 0.7891491651535034 | treshold: 0.39
Test: accuracy: 0.745979905128479 | f1: 0.8313127160072327 | auc: 0.8016515970230103 | 


train... loss:52.80232501029968:   0%|                                                                                                  | 2/5000 [00:15<11:00:36,  7.93s/it]

Val update: epoch: 1 |accuracy: 0.7679582238197327 | f1: 0.8440608382225037 | auc: 0.8366842269897461 | treshold: 0.36000000000000004
Test: accuracy: 0.7725063562393188 | f1: 0.8474906086921692 | auc: 0.8497934341430664 | 


train... loss:47.029819905757904:   0%|                                                                                                 | 3/5000 [00:23<11:01:43,  7.95s/it]

Val update: epoch: 2 |accuracy: 0.7943452596664429 | f1: 0.838453471660614 | auc: 0.8665439486503601 | treshold: 0.38
Test: accuracy: 0.8025873303413391 | f1: 0.8463276028633118 | auc: 0.8727356195449829 | 


train... loss:43.0148246884346:   0%|                                                                                                   | 4/5000 [00:31<11:02:45,  7.96s/it]

Val update: epoch: 3 |accuracy: 0.7895805835723877 | f1: 0.8556016683578491 | auc: 0.8668836355209351 | treshold: 0.35000000000000003
Test: accuracy: 0.7867488861083984 | f1: 0.8543686270713806 | auc: 0.8732645511627197 | 


train... loss:41.11425492167473:   0%|                                                                                                  | 5/5000 [00:39<11:01:44,  7.95s/it]

Val update: epoch: 4 |accuracy: 0.804938793182373 | f1: 0.8488879799842834 | auc: 0.8801166415214539 | treshold: 0.35000000000000003
Test: accuracy: 0.8129851222038269 | f1: 0.8562987446784973 | auc: 0.8852585554122925 | 


train... loss:37.94949695467949:   0%|                                                                                                  | 6/5000 [00:47<11:02:16,  7.96s/it]

Val update: epoch: 5 |accuracy: 0.8098486065864563 | f1: 0.8657674789428711 | auc: 0.8909774422645569 | treshold: 0.37
Test: accuracy: 0.8080522418022156 | f1: 0.8652657866477966 | auc: 0.8956835269927979 | 


train... loss:34.89358773827553:   0%|▏                                                                                                 | 7/5000 [00:55<10:57:35,  7.90s/it]

Val update: epoch: 6 |accuracy: 0.817467212677002 | f1: 0.8524910807609558 | auc: 0.899204671382904 | treshold: 0.37
Test: accuracy: 0.8204086422920227 | f1: 0.8557779788970947 | auc: 0.9010817408561707 | 


train... loss:33.34385597705841:   0%|▏                                                                                                 | 8/5000 [01:03<10:58:16,  7.91s/it]

Val update: epoch: 7 |accuracy: 0.8348570466041565 | f1: 0.8786003589630127 | auc: 0.9103359580039978 | treshold: 0.4
Test: accuracy: 0.8358843922615051 | f1: 0.8798484802246094 | auc: 0.9129641056060791 | 


train... loss:31.773744702339172:   0%|▏                                                                                                | 9/5000 [01:11<10:57:57,  7.91s/it]

Val update: epoch: 8 |accuracy: 0.8408552408218384 | f1: 0.8728060126304626 | auc: 0.9154351949691772 | treshold: 0.39
Test: accuracy: 0.8430661559104919 | f1: 0.8753265738487244 | auc: 0.916930079460144 | 


train... loss:30.71716558933258:   0%|▏                                                                                                | 10/5000 [01:19<10:58:10,  7.91s/it]

Val update: epoch: 9 |accuracy: 0.8483529090881348 | f1: 0.8858214616775513 | auc: 0.9167110323905945 | treshold: 0.41000000000000003
Test: accuracy: 0.8502720594406128 | f1: 0.8876633048057556 | auc: 0.9201153516769409 | 


train... loss:29.37326091527939:   0%|▏                                                                                                | 12/5000 [01:34<10:47:25,  7.79s/it]

Val update: epoch: 11 |accuracy: 0.853625476360321 | f1: 0.886979877948761 | auc: 0.9191623330116272 | treshold: 0.35000000000000003
Test: accuracy: 0.8539958596229553 | f1: 0.8878320455551147 | auc: 0.9217108488082886 | 


train... loss:28.5719406157732:   0%|▎                                                                                                 | 16/5000 [02:04<10:36:15,  7.66s/it]

Val update: epoch: 15 |accuracy: 0.8441687226295471 | f1: 0.8851699233055115 | auc: 0.9213276505470276 | treshold: 0.39
Test: accuracy: 0.8476846814155579 | f1: 0.8881390690803528 | auc: 0.9230121374130249 | 


train... loss:23.272811502218246:   2%|█▌                                                                                              | 83/5000 [10:26<10:18:39,  7.55s/it]



Evaluating SessionwiseGRU with svd embeddings
Test before learning: {'f1': 0.7895838618278503, 'roc-auc': 0.5017134547233582, 'accuracy': 0.6523999571800232}


train... loss:69.61009615659714:   0%|                                                                                                   | 1/5000 [00:06<9:27:41,  6.81s/it]

Val update: epoch: 0 |accuracy: 0.6505829095840454 | f1: 0.7883068323135376 | auc: 0.5961850881576538 | treshold: 0.01
Test: accuracy: 0.6529561281204224 | f1: 0.7900465130805969 | auc: 0.6077756881713867 | 


train... loss:66.76936364173889:   0%|                                                                                                   | 2/5000 [00:13<9:26:29,  6.80s/it]

Val update: epoch: 1 |accuracy: 0.6505829095840454 | f1: 0.7883068323135376 | auc: 0.7009993195533752 | treshold: 0.21000000000000002
Test: accuracy: 0.6529561281204224 | f1: 0.7900465130805969 | auc: 0.7189827561378479 | 


train... loss:61.91348770260811:   0%|                                                                                                   | 3/5000 [00:20<9:27:09,  6.81s/it]

Val update: epoch: 2 |accuracy: 0.6505829095840454 | f1: 0.7881143689155579 | auc: 0.7590543031692505 | treshold: 0.28
Test: accuracy: 0.6529561281204224 | f1: 0.7898682355880737 | auc: 0.7736423015594482 | 


train... loss:59.94834965467453:   0%|                                                                                                   | 4/5000 [00:27<9:31:37,  6.86s/it]

Val update: epoch: 3 |accuracy: 0.6954965591430664 | f1: 0.8028623461723328 | auc: 0.7783321142196655 | treshold: 0.35000000000000003
Test: accuracy: 0.6988272070884705 | f1: 0.8060603141784668 | auc: 0.7932279109954834 | 


train... loss:57.72984817624092:   0%|                                                                                                   | 5/5000 [00:34<9:36:39,  6.93s/it]

Val update: epoch: 4 |accuracy: 0.72572922706604 | f1: 0.811977744102478 | auc: 0.7807934284210205 | treshold: 0.42000000000000004
Test: accuracy: 0.7325353622436523 | f1: 0.8175866007804871 | auc: 0.7959981560707092 | 


train... loss:55.36587983369827:   0%|                                                                                                   | 6/5000 [00:41<9:37:11,  6.93s/it]

Val update: epoch: 5 |accuracy: 0.7425143718719482 | f1: 0.8198585510253906 | auc: 0.8043502569198608 | treshold: 0.37
Test: accuracy: 0.7511304616928101 | f1: 0.8267689943313599 | auc: 0.8187210559844971 | 


train... loss:53.2257177233696:   0%|▏                                                                                                   | 8/5000 [00:54<9:27:37,  6.82s/it]

Val update: epoch: 7 |accuracy: 0.7296473383903503 | f1: 0.8237131237983704 | auc: 0.8228716254234314 | treshold: 0.35000000000000003
Test: accuracy: 0.7320517301559448 | f1: 0.8259045481681824 | auc: 0.8356223106384277 | 


train... loss:52.249251902103424:   0%|▏                                                                                                 | 9/5000 [01:01<9:31:40,  6.87s/it]

Val update: epoch: 8 |accuracy: 0.7584530711174011 | f1: 0.834435760974884 | auc: 0.8321247100830078 | treshold: 0.37
Test: accuracy: 0.7620843648910522 | f1: 0.8375517725944519 | auc: 0.8465657234191895 | 


train... loss:50.608155846595764:   0%|▏                                                                                                | 10/5000 [01:08<9:37:46,  6.95s/it]

Val update: epoch: 9 |accuracy: 0.7808493971824646 | f1: 0.8398692011833191 | auc: 0.8404390811920166 | treshold: 0.41000000000000003
Test: accuracy: 0.7886108160018921 | f1: 0.8464592099189758 | auc: 0.8538966178894043 | 


train... loss:50.31140556931496:   0%|▏                                                                                                 | 11/5000 [01:15<9:41:49,  7.00s/it]

Val update: epoch: 10 |accuracy: 0.7792773246765137 | f1: 0.8353183269500732 | auc: 0.8412235975265503 | treshold: 0.39
Test: accuracy: 0.7881755828857422 | f1: 0.8428022861480713 | auc: 0.8531665205955505 | 


train... loss:49.29840889573097:   0%|▏                                                                                                 | 12/5000 [01:22<9:43:33,  7.02s/it]

Val update: epoch: 11 |accuracy: 0.7789145112037659 | f1: 0.8453953266143799 | auc: 0.8462210297584534 | treshold: 0.39
Test: accuracy: 0.7819126844406128 | f1: 0.8480754494667053 | auc: 0.8566611409187317 | 


train... loss:46.78345614671707:   0%|▎                                                                                                 | 15/5000 [01:43<9:27:54,  6.84s/it]

Val update: epoch: 14 |accuracy: 0.7926764488220215 | f1: 0.8450974225997925 | auc: 0.8535665273666382 | treshold: 0.39
Test: accuracy: 0.7955265641212463 | f1: 0.8481212854385376 | auc: 0.8607978820800781 | 


train... loss:45.53162109851837:   0%|▎                                                                                                 | 17/5000 [01:56<9:28:08,  6.84s/it]

Val update: epoch: 16 |accuracy: 0.7901610732078552 | f1: 0.8372904062271118 | auc: 0.857064962387085 | treshold: 0.41000000000000003
Test: accuracy: 0.7952605485916138 | f1: 0.8422601819038391 | auc: 0.8621870875358582 | 


train... loss:44.33606880903244:   0%|▎                                                                                                 | 19/5000 [02:10<9:27:58,  6.84s/it]

Val update: epoch: 18 |accuracy: 0.7964252829551697 | f1: 0.8437389731407166 | auc: 0.8638356924057007 | treshold: 0.4
Test: accuracy: 0.7990569472312927 | f1: 0.8466506600379944 | auc: 0.8675981163978577 | 


train... loss:43.350206553936005:   0%|▍                                                                                                | 20/5000 [02:17<9:37:16,  6.96s/it]

Val update: epoch: 19 |accuracy: 0.8036327362060547 | f1: 0.8540666699409485 | auc: 0.8695065975189209 | treshold: 0.42000000000000004
Test: accuracy: 0.8034095168113708 | f1: 0.8549612760543823 | auc: 0.8745737075805664 | 


train... loss:42.55870145559311:   0%|▍                                                                                                 | 21/5000 [02:24<9:40:58,  7.00s/it]

Val update: epoch: 20 |accuracy: 0.7905964255332947 | f1: 0.8537104725837708 | auc: 0.8705636262893677 | treshold: 0.38
Test: accuracy: 0.7889977097511292 | f1: 0.8533445596694946 | auc: 0.8735484480857849 | 


train... loss:41.12229512631893:   0%|▍                                                                                                 | 22/5000 [02:31<9:43:35,  7.03s/it]

Val update: epoch: 21 |accuracy: 0.7952159643173218 | f1: 0.8560254573822021 | auc: 0.8759599924087524 | treshold: 0.4
Test: accuracy: 0.7919961214065552 | f1: 0.8546075224876404 | auc: 0.8777013421058655 | 


train... loss:40.558552503585815:   0%|▍                                                                                                | 23/5000 [02:38<9:43:42,  7.04s/it]

Val update: epoch: 22 |accuracy: 0.7982150912284851 | f1: 0.8580181002616882 | auc: 0.8826581835746765 | treshold: 0.39
Test: accuracy: 0.7954056262969971 | f1: 0.8568770289421082 | auc: 0.8855622410774231 | 


train... loss:39.62602338194847:   0%|▍                                                                                                 | 24/5000 [02:45<9:44:37,  7.05s/it]

Val update: epoch: 23 |accuracy: 0.8099936842918396 | f1: 0.8634594082832336 | auc: 0.8888983726501465 | treshold: 0.41000000000000003
Test: accuracy: 0.807181715965271 | f1: 0.8624460697174072 | auc: 0.8913837671279907 | 


train... loss:39.210728257894516:   0%|▍                                                                                                | 25/5000 [02:52<9:45:44,  7.06s/it]

Val update: epoch: 24 |accuracy: 0.8162820935249329 | f1: 0.8668723106384277 | auc: 0.8915619850158691 | treshold: 0.42000000000000004
Test: accuracy: 0.8127917051315308 | f1: 0.8653846383094788 | auc: 0.8941280841827393 | 


train... loss:38.52468624711037:   1%|▌                                                                                                 | 26/5000 [02:59<9:46:41,  7.08s/it]

Val update: epoch: 25 |accuracy: 0.8217723369598389 | f1: 0.8689279556274414 | auc: 0.8927366733551025 | treshold: 0.43
Test: accuracy: 0.8184500336647034 | f1: 0.8676957488059998 | auc: 0.8945238590240479 | 


train... loss:38.17501047253609:   1%|▌                                                                                                 | 27/5000 [03:07<9:47:16,  7.09s/it]

Val update: epoch: 26 |accuracy: 0.8214337825775146 | f1: 0.8700473308563232 | auc: 0.8966867923736572 | treshold: 0.39
Test: accuracy: 0.8183774352073669 | f1: 0.8688149452209473 | auc: 0.8983397483825684 | 


train... loss:37.61427268385887:   1%|▌                                                                                                 | 28/5000 [03:14<9:49:14,  7.11s/it]

Val update: epoch: 27 |accuracy: 0.829124927520752 | f1: 0.8716831207275391 | auc: 0.8981202840805054 | treshold: 0.45
Test: accuracy: 0.8254382610321045 | f1: 0.8700379729270935 | auc: 0.8997318744659424 | 


train... loss:37.290294498205185:   1%|▌                                                                                                | 30/5000 [03:27<9:38:53,  6.99s/it]

Val update: epoch: 29 |accuracy: 0.830842137336731 | f1: 0.8752808570861816 | auc: 0.9021109342575073 | treshold: 0.43
Test: accuracy: 0.8264780640602112 | f1: 0.8730720281600952 | auc: 0.9033664464950562 | 


train... loss:36.23656949400902:   1%|▋                                                                                                 | 32/5000 [03:41<9:32:43,  6.92s/it]

Val update: epoch: 31 |accuracy: 0.8317854404449463 | f1: 0.8763753175735474 | auc: 0.9043072462081909 | treshold: 0.41000000000000003
Test: accuracy: 0.8266956806182861 | f1: 0.8736001253128052 | auc: 0.9048566818237305 | 


train... loss:35.53117747604847:   1%|▋                                                                                                 | 34/5000 [03:54<9:27:53,  6.86s/it]

Val update: epoch: 33 |accuracy: 0.8337928652763367 | f1: 0.877088189125061 | auc: 0.9056293964385986 | treshold: 0.42000000000000004
Test: accuracy: 0.8289687037467957 | f1: 0.874532163143158 | auc: 0.9057130217552185 | 


train... loss:35.71739247441292:   1%|▋                                                                                                 | 35/5000 [04:02<9:35:05,  6.95s/it]

Val update: epoch: 34 |accuracy: 0.8369370698928833 | f1: 0.8773735761642456 | auc: 0.9070500135421753 | treshold: 0.42000000000000004
Test: accuracy: 0.8338532447814941 | f1: 0.8758559823036194 | auc: 0.9076467752456665 | 


train... loss:34.6170437335968:   1%|▊                                                                                                  | 39/5000 [04:28<9:22:13,  6.80s/it]

Val update: epoch: 38 |accuracy: 0.8352198600769043 | f1: 0.8785150051116943 | auc: 0.9082323312759399 | treshold: 0.4
Test: accuracy: 0.8320638537406921 | f1: 0.876986026763916 | auc: 0.9088174700737 | 


train... loss:34.60519960522652:   1%|▊                                                                                                 | 40/5000 [04:35<9:30:29,  6.90s/it]

Val update: epoch: 39 |accuracy: 0.8403472900390625 | f1: 0.8794360160827637 | auc: 0.9090366363525391 | treshold: 0.41000000000000003
Test: accuracy: 0.8372143507003784 | f1: 0.8778354525566101 | auc: 0.9086928963661194 | 


train... loss:33.963875353336334:   1%|▊                                                                                                | 42/5000 [04:49<9:31:06,  6.91s/it]

Val update: epoch: 41 |accuracy: 0.835776150226593 | f1: 0.8690099716186523 | auc: 0.9092758893966675 | treshold: 0.41000000000000003
Test: accuracy: 0.8355700373649597 | f1: 0.8697667121887207 | auc: 0.9097118377685547 | 


train... loss:33.907284423708916:   1%|▊                                                                                                | 45/5000 [05:09<9:24:41,  6.84s/it]

Val update: epoch: 44 |accuracy: 0.8388961553573608 | f1: 0.8801612257957458 | auc: 0.9107710123062134 | treshold: 0.41000000000000003
Test: accuracy: 0.837504506111145 | f1: 0.8798326253890991 | auc: 0.9116508960723877 | 


train... loss:32.96543197333813:   1%|▉                                                                                                 | 47/5000 [05:23<9:25:27,  6.85s/it]

Val update: epoch: 46 |accuracy: 0.838726818561554 | f1: 0.8716260194778442 | auc: 0.9110019207000732 | treshold: 0.43
Test: accuracy: 0.836899995803833 | f1: 0.8708670735359192 | auc: 0.9113709926605225 | 


train... loss:32.199996300041676:   1%|▉                                                                                                | 49/5000 [05:37<9:25:14,  6.85s/it]

Val update: epoch: 48 |accuracy: 0.8444347977638245 | f1: 0.8818429112434387 | auc: 0.9123435020446777 | treshold: 0.42000000000000004
Test: accuracy: 0.8437674045562744 | f1: 0.8820016384124756 | auc: 0.9145193099975586 | 


train... loss:32.23187816143036:   1%|▉                                                                                                 | 50/5000 [05:44<9:29:01,  6.90s/it]

Val update: epoch: 49 |accuracy: 0.8407343029975891 | f1: 0.8735137581825256 | auc: 0.9126625061035156 | treshold: 0.43
Test: accuracy: 0.8395115733146667 | f1: 0.8733179569244385 | auc: 0.9120738506317139 | 


train... loss:32.7263308763504:   1%|█                                                                                                  | 53/5000 [06:04<9:26:05,  6.87s/it]

Val update: epoch: 52 |accuracy: 0.8428143262863159 | f1: 0.8818127512931824 | auc: 0.9129257798194885 | treshold: 0.43
Test: accuracy: 0.843477189540863 | f1: 0.8829708099365234 | auc: 0.9136471748352051 | 


train... loss:31.7363583445549:   1%|█                                                                                                  | 55/5000 [06:18<9:24:38,  6.85s/it]

Val update: epoch: 54 |accuracy: 0.8457891941070557 | f1: 0.8837726712226868 | auc: 0.9151889085769653 | treshold: 0.41000000000000003
Test: accuracy: 0.8459436297416687 | f1: 0.8845142722129822 | auc: 0.9176298379898071 | 


train... loss:30.12977209687233:   1%|█▎                                                                                                | 70/5000 [07:56<9:08:52,  6.68s/it]

Val update: epoch: 69 |accuracy: 0.8336477279663086 | f1: 0.8639689683914185 | auc: 0.9154866337776184 | treshold: 0.44
Test: accuracy: 0.8339983224868774 | f1: 0.865140974521637 | auc: 0.91620934009552 | 


train... loss:29.870126456022263:   1%|█▍                                                                                               | 71/5000 [08:03<9:18:38,  6.80s/it]

Val update: epoch: 70 |accuracy: 0.8499975800514221 | f1: 0.8825690150260925 | auc: 0.9161500930786133 | treshold: 0.4
Test: accuracy: 0.8475638031959534 | f1: 0.8813297748565674 | auc: 0.9182339906692505 | 


train... loss:29.87573143839836:   1%|█▍                                                                                                | 72/5000 [08:10<9:22:19,  6.85s/it]

Val update: epoch: 71 |accuracy: 0.8472645282745361 | f1: 0.883407473564148 | auc: 0.9163686037063599 | treshold: 0.43
Test: accuracy: 0.845460057258606 | f1: 0.8829208612442017 | auc: 0.9177389144897461 | 


train... loss:30.12754061818123:   2%|█▍                                                                                                | 76/5000 [08:36<9:13:44,  6.75s/it]

Val update: epoch: 75 |accuracy: 0.8427417278289795 | f1: 0.8756930232048035 | auc: 0.9164512157440186 | treshold: 0.41000000000000003
Test: accuracy: 0.8441784381866455 | f1: 0.8774158954620361 | auc: 0.9180556535720825 | 


train... loss:29.093308188021183:   2%|█▌                                                                                               | 78/5000 [08:50<9:16:37,  6.79s/it]

Val update: epoch: 77 |accuracy: 0.8485463857650757 | f1: 0.8831062316894531 | auc: 0.9168189764022827 | treshold: 0.45
Test: accuracy: 0.849280595779419 | f1: 0.8844563961029053 | auc: 0.9181777238845825 | 


train... loss:29.056792438030243:   2%|█▌                                                                                               | 80/5000 [09:03<9:16:11,  6.78s/it]

Val update: epoch: 79 |accuracy: 0.8454022407531738 | f1: 0.8769254684448242 | auc: 0.9176271557807922 | treshold: 0.42000000000000004
Test: accuracy: 0.8468866944313049 | f1: 0.8787716627120972 | auc: 0.9202585220336914 | 


train... loss:28.972287878394127:   2%|█▋                                                                                               | 86/5000 [09:43<9:08:18,  6.69s/it]

Val update: epoch: 85 |accuracy: 0.8456440567970276 | f1: 0.8772078156471252 | auc: 0.9183125495910645 | treshold: 0.42000000000000004
Test: accuracy: 0.8457260131835938 | f1: 0.8779087662696838 | auc: 0.9212091565132141 | 


train... loss:28.79174444079399:   2%|█▋                                                                                                | 87/5000 [09:50<9:17:36,  6.81s/it]

Val update: epoch: 86 |accuracy: 0.848135232925415 | f1: 0.8793451189994812 | auc: 0.9189528226852417 | treshold: 0.45
Test: accuracy: 0.8479264974594116 | f1: 0.8798685669898987 | auc: 0.9219239950180054 | 


train... loss:26.88394247740507:   3%|██▌                                                                                              | 129/5000 [14:24<9:01:24,  6.67s/it]

Val update: epoch: 128 |accuracy: 0.8463212847709656 | f1: 0.8835623860359192 | auc: 0.9192904233932495 | treshold: 0.44
Test: accuracy: 0.8479748368263245 | f1: 0.8852632641792297 | auc: 0.9212088584899902 | 


train... loss:25.71717904508114:   4%|███▍                                                                                             | 179/5000 [19:54<8:56:24,  6.68s/it]


In [13]:
pd.DataFrame(rl4rs_results).to_csv(f'results/rl4rs_{experiment_name}.csv')
del dataset, train_loader, val_loader, test_loader, train_user_item_matrix, train_num_items