In [1]:
# !export 'PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:512'

import os
import torch
import random
import datetime
import pandas as pd
import numpy as np

from torch.utils.data import Dataset
from src.datasets import RL4RS, ContentWise, DummyData
from src.utils import train, get_dummy_data, get_train_val_test_tmatrix_tnumitems
from src.embeddings import RecsysEmbedding

experiment_name = 'SessionwiseAttention'
device = 'cuda:0'
seed = 123
pkl_path = '../pkl/'

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

<torch._C.Generator at 0x7f068e0cab50>

In [2]:
torch.__version__

'1.12.1'

# Модель

In [3]:
class SessionwiseAttention(torch.nn.Module):
    """
    No recurrent dependency, just slate-wise attention.
    """
    def __init__(self, embedding, nheads=2, output_dim=1):
        super().__init__()
        self.embedding_dim = embedding.embedding_dim
        self.embedding = embedding
        self.attention= torch.nn.MultiheadAttention(
            self.embedding_dim,
            num_heads=nheads,
            batch_first=True
        )
        
        self.out_layer = torch.nn.Linear(2 * embedding.embedding_dim, output_dim)

    def forward(self, batch):
        item_embs, user_embs = self.embedding(batch)
        shp = item_embs.shape
        device = item_embs.device
        
        # flattening slates in to long sequences
        item_embs = item_embs.flatten(1, 2)
        
        # let model attent to first padd token if slate is empty to avoid NaN gradients
        # (anyway they does not contrinute into metrics computation)
        key_padding_mask = batch['slates_mask'].clone()
        key_padding_mask[:,:, 0] = True # let model attent to first padd token if slate is empty 
        key_padding_mask = ~key_padding_mask.flatten(1,2)
        
        # forbid model looking into future (and into current iteraction)
        future_mask = torch.ones((item_embs.size(-2), item_embs.size(-2))).to(device)
        future_mask = torch.triu(future_mask, diagonal=1).bool()
        
        features, attn_map = self.attention(
            item_embs, item_embs, item_embs,
            key_padding_mask=key_padding_mask,
            attn_mask = future_mask
        )
#         print(features.shape, user_embs.shape, shp)
        features = torch.cat(
            [
                features.reshape(shp),
                user_embs[:, :, None, :].repeat(1, 1, shp[-2], 1)
            ],
            dim = -1
        )
        
        return self.out_layer(features).squeeze(-1)

# Игрушечный датасет: проверим, что сходится к идеальным метрикам

In [4]:
d = DummyData()
dummy_loader, dummy_matrix = get_dummy_data(d)

model = SessionwiseAttention(
    RecsysEmbedding(d.n_items, dummy_matrix, embeddings='neural').to('cpu'),
    output_dim=1
).to('cpu')

train(
    model, 
    dummy_loader, dummy_loader, dummy_loader,
    device='cpu', lr=1e-3, num_epochs=5000, dummy=True,
    silent=True,
)


biulding affinity matrix...


3it [00:00, 3956.89it/s]


Test before learning: {'f1': 0.4000000059604645, 'roc-auc': 0.6666666269302368, 'accuracy': 0.25}


train... loss:0.6950739622116089:   0%|                                                                                                    | 1/5000 [00:00<26:50,  3.10it/s]

Val update: epoch: 0 |accuracy: 0.75 | f1: 0.6666666865348816 | auc: 1.0 | treshold: 0.52
Test: accuracy: 0.75 | f1: 0.6666666865348816 | auc: 1.0 | 


train... loss:0.6950739622116089:   0%|                                                                                                    | 1/5000 [00:00<51:44,  1.61it/s]

Val update: epoch: 1 |accuracy: 1.0 | f1: 1.0 | auc: 1.0 | treshold: 0.5800000000000001
Test: accuracy: 1.0 | f1: 1.0 | auc: 1.0 | 





(SessionwiseAttention(
   (embedding): RecsysEmbedding(
     (item_embeddings): Embedding(5, 32)
   )
   (attention): MultiheadAttention(
     (out_proj): NonDynamicallyQuantizableLinear(in_features=32, out_features=32, bias=True)
   )
   (out_layer): Linear(in_features=64, out_features=1, bias=True)
 ),
 {'f1': 1.0, 'roc-auc': 1.0, 'accuracy': 1.0})

# ContentWise

In [5]:
content_wise_results = []
dataset = ContentWise.load(os.path.join(pkl_path, 'cw.pkl'))

# another batch size due to large memory usage
(
    train_loader, 
    val_loader, 
    test_loader, 
    train_user_item_matrix, 
    train_num_items 
) = get_train_val_test_tmatrix_tnumitems(dataset, batch_size=10)

print(f"{len(dataset)} data points among {len(train_loader)} batches")

20216 data points among 1618 batches


In [6]:
for embeddings in ['svd', 'neural']:
    print(f"\nEvaluating {experiment_name} with {embeddings} embeddings")
    
    model = SessionwiseAttention(
        RecsysEmbedding(train_num_items, train_user_item_matrix, embeddings=embeddings),
        output_dim=1
    ).to(device)

    _, metrics = train(
        model, 
        train_loader, val_loader, test_loader, 
        device=device, lr=1e-3, num_epochs=5000, early_stopping=7,
       silent=True, 
    )
    
    metrics['embeddings'] = embeddings
    content_wise_results.append(metrics)


Evaluating SessionwiseAttention with svd embeddings
Test before learning: {'f1': 0.06840086728334427, 'roc-auc': 0.5205211639404297, 'accuracy': 0.8633807301521301}


train... loss:522.6785174459219:   0%|                                                                                                | 1/5000 [02:23<199:10:45, 143.44s/it]

Val update: epoch: 0 |accuracy: 0.8171793818473816 | f1: 0.23787805438041687 | auc: 0.6180979013442993 | treshold: 0.09
Test: accuracy: 0.8175572752952576 | f1: 0.24274539947509766 | auc: 0.6207095384597778 | 


train... loss:497.3637044876814:   0%|                                                                                                | 2/5000 [04:14<172:28:49, 124.24s/it]

Val update: epoch: 1 |accuracy: 0.8029280304908752 | f1: 0.23885244131088257 | auc: 0.6211483478546143 | treshold: 0.13
Test: accuracy: 0.7982856035232544 | f1: 0.24761904776096344 | auc: 0.625458300113678 | 


train... loss:492.9937225282192:   0%|                                                                                                | 3/5000 [06:05<164:08:37, 118.25s/it]

Val update: epoch: 2 |accuracy: 0.8266803026199341 | f1: 0.24523571133613586 | auc: 0.6258740425109863 | treshold: 0.13
Test: accuracy: 0.8259924650192261 | f1: 0.2569926679134369 | auc: 0.6314875483512878 | 


train... loss:491.71891306340694:   0%|                                                                                               | 4/5000 [07:56<160:09:36, 115.41s/it]

Val update: epoch: 3 |accuracy: 0.8382829427719116 | f1: 0.2451283484697342 | auc: 0.6311534643173218 | treshold: 0.12
Test: accuracy: 0.8391847610473633 | f1: 0.25386080145835876 | auc: 0.6328122615814209 | 


train... loss:489.82569690048695:   0%|▏                                                                                             | 12/5000 [21:53<147:19:39, 106.33s/it]

Val update: epoch: 11 |accuracy: 0.8707444071769714 | f1: 0.22427645325660706 | auc: 0.6325771808624268 | treshold: 0.12
Test: accuracy: 0.8737916946411133 | f1: 0.24219748377799988 | auc: 0.634259045124054 | 


train... loss:489.29529897868633:   0%|▍                                                                                             | 21/5000 [37:34<146:16:56, 105.77s/it]

Val update: epoch: 20 |accuracy: 0.8273424506187439 | f1: 0.25037500262260437 | auc: 0.6346225738525391 | treshold: 0.12
Test: accuracy: 0.8278619050979614 | f1: 0.2608014643192291 | auc: 0.637192964553833 | 


train... loss:488.81389336287975:   0%|▍                                                                                             | 23/5000 [41:09<147:58:02, 107.03s/it]

Val update: epoch: 22 |accuracy: 0.8246505260467529 | f1: 0.2517353594303131 | auc: 0.6401534080505371 | treshold: 0.12
Test: accuracy: 0.8242142200469971 | f1: 0.2632182538509369 | auc: 0.6376187801361084 | 


train... loss:487.89444556832314:   1%|▌                                                                                             | 27/5000 [48:10<146:57:53, 106.39s/it]

Val update: epoch: 26 |accuracy: 0.8172081708908081 | f1: 0.24979321658611298 | auc: 0.6409962773323059 | treshold: 0.12
Test: accuracy: 0.8170861601829529 | f1: 0.26106712222099304 | auc: 0.6371282339096069 | 


train... loss:487.84131325781345:   1%|▊                                                                                           | 41/5000 [1:13:54<148:59:32, 108.16s/it]



Evaluating SessionwiseAttention with neural embeddings
Test before learning: {'f1': 0.16817502677440643, 'roc-auc': 0.4886724650859833, 'accuracy': 0.3141072392463684}


train... loss:514.380825906992:   0%|                                                                                                 | 1/5000 [01:53<158:10:13, 113.91s/it]

Val update: epoch: 0 |accuracy: 0.8819871544837952 | f1: 0.1782277524471283 | auc: 0.6709238290786743 | treshold: 0.12
Test: accuracy: 0.8799775242805481 | f1: 0.1849520057439804 | auc: 0.6738258600234985 | 


train... loss:480.38318118453026:   0%|                                                                                               | 2/5000 [03:46<157:12:13, 113.23s/it]

Val update: epoch: 1 |accuracy: 0.8199145197868347 | f1: 0.2667057514190674 | auc: 0.6933383345603943 | treshold: 0.13
Test: accuracy: 0.8221928477287292 | f1: 0.27672332525253296 | auc: 0.6985437273979187 | 


train... loss:467.45993869006634:   0%|                                                                                               | 6/5000 [11:01<152:50:35, 110.18s/it]

Val update: epoch: 5 |accuracy: 0.7718197107315063 | f1: 0.27874594926834106 | auc: 0.6984010934829712 | treshold: 0.14
Test: accuracy: 0.7737552523612976 | f1: 0.2841204106807709 | auc: 0.7038165926933289 | 


train... loss:462.80683448910713:   0%|▍                                                                                             | 21/5000 [39:24<155:43:53, 112.60s/it]


In [14]:
pd.DataFrame(content_wise_results).to_csv(f'results/cw_{experiment_name}.csv')
del dataset, train_loader, val_loader, test_loader, train_user_item_matrix, train_num_items

NameError: name 'dataset' is not defined

# RL4RS

In [8]:
rl4rs_results = []
dataset = RL4RS.load(os.path.join(pkl_path, 'rl4rs.pkl'))
(
    train_loader, 
    val_loader, 
    test_loader, 
    train_user_item_matrix, 
    train_num_items 
) = get_train_val_test_tmatrix_tnumitems(dataset, batch_size=350)

print(f"{len(dataset)} data points among {len(train_loader)} batches")

45942 data points among 106 batches


In [9]:
for embeddings in ['explicit', 'neural', 'svd']:
    print(f"\nEvaluating {experiment_name} with {embeddings} embeddings")

    model = SessionwiseAttention(
        RecsysEmbedding(
            train_num_items,
            train_user_item_matrix,
            embeddings=embeddings,
            embedding_dim=40
        ),
        output_dim=1
    ).to(device)


    best_model, metrics = train(
        model, 
        train_loader, val_loader, test_loader, 
        device=device, lr=1e-3, num_epochs=5000, early_stopping=7,
        silent=True
    )
    
    metrics['embeddings'] = embeddings
    rl4rs_results.append(metrics)


Evaluating SessionwiseAttention with explicit embeddings
Test before learning: {'f1': 0.5757297277450562, 'roc-auc': 0.5344966650009155, 'accuracy': 0.5195502638816833}


train... loss:60.10406452417374:   0%|                                                                                                  | 1/5000 [00:59<82:48:07, 59.63s/it]

Val update: epoch: 0 |accuracy: 0.7888550162315369 | f1: 0.8451688289642334 | auc: 0.8453947901725769 | treshold: 0.46
Test: accuracy: 0.7957683205604553 | f1: 0.8509669899940491 | auc: 0.8532283306121826 | 


train... loss:53.102506309747696:   0%|                                                                                                 | 2/5000 [01:05<39:13:49, 28.26s/it]

Val update: epoch: 1 |accuracy: 0.7928215265274048 | f1: 0.8477967381477356 | auc: 0.8560933470726013 | treshold: 0.43
Test: accuracy: 0.8009672164916992 | f1: 0.8546350598335266 | auc: 0.8639253973960876 | 


train... loss:50.80802398920059:   0%|                                                                                                  | 3/5000 [01:12<25:17:16, 18.22s/it]

Val update: epoch: 2 |accuracy: 0.7965946197509766 | f1: 0.8537797927856445 | auc: 0.8696720004081726 | treshold: 0.44
Test: accuracy: 0.8028050065040588 | f1: 0.8589271306991577 | auc: 0.8775451183319092 | 


train... loss:49.55788230895996:   0%|                                                                                                  | 4/5000 [01:18<18:42:16, 13.48s/it]

Val update: epoch: 3 |accuracy: 0.7915155291557312 | f1: 0.8523972630500793 | auc: 0.874175488948822 | treshold: 0.44
Test: accuracy: 0.7964695692062378 | f1: 0.8565047383308411 | auc: 0.8809843063354492 | 


train... loss:48.750785529613495:   0%|                                                                                                 | 5/5000 [01:24<15:10:00, 10.93s/it]

Val update: epoch: 4 |accuracy: 0.8046485781669617 | f1: 0.856950581073761 | auc: 0.882781982421875 | treshold: 0.44
Test: accuracy: 0.814556896686554 | f1: 0.8649942874908447 | auc: 0.8895819187164307 | 


train... loss:47.98665335774422:   0%|                                                                                                  | 6/5000 [01:31<12:59:36,  9.37s/it]

Val update: epoch: 5 |accuracy: 0.808010458946228 | f1: 0.8562997579574585 | auc: 0.8849816918373108 | treshold: 0.45
Test: accuracy: 0.817942202091217 | f1: 0.8645693063735962 | auc: 0.8918578624725342 | 


train... loss:47.79461517930031:   0%|▏                                                                                                 | 7/5000 [01:37<11:37:03,  8.38s/it]

Val update: epoch: 6 |accuracy: 0.8126541972160339 | f1: 0.8568577170372009 | auc: 0.8882078528404236 | treshold: 0.48000000000000004
Test: accuracy: 0.8207229971885681 | f1: 0.8640131950378418 | auc: 0.8946926593780518 | 


train... loss:47.12804752588272:   0%|▏                                                                                                 | 11/5000 [02:01<9:03:38,  6.54s/it]

Val update: epoch: 10 |accuracy: 0.8081797361373901 | f1: 0.860803484916687 | auc: 0.8903756141662598 | treshold: 0.48000000000000004
Test: accuracy: 0.811969518661499 | f1: 0.8645296096801758 | auc: 0.8970915079116821 | 


train... loss:46.8231797516346:   0%|▏                                                                                                  | 12/5000 [02:07<9:00:06,  6.50s/it]

Val update: epoch: 11 |accuracy: 0.8170077204704285 | f1: 0.8617880344390869 | auc: 0.893810510635376 | treshold: 0.47000000000000003
Test: accuracy: 0.8263329863548279 | f1: 0.8699243068695068 | auc: 0.90016770362854 | 


train... loss:46.65276247262955:   0%|▎                                                                                                 | 13/5000 [02:14<8:58:16,  6.48s/it]

Val update: epoch: 12 |accuracy: 0.8176365494728088 | f1: 0.8579823970794678 | auc: 0.8943523168563843 | treshold: 0.49
Test: accuracy: 0.8258976936340332 | f1: 0.8654608130455017 | auc: 0.9003539085388184 | 


train... loss:46.46898740530014:   0%|▎                                                                                                 | 15/5000 [02:26<8:39:33,  6.25s/it]

Val update: epoch: 14 |accuracy: 0.8188942074775696 | f1: 0.8646518588066101 | auc: 0.8955890536308289 | treshold: 0.49
Test: accuracy: 0.8258010149002075 | f1: 0.8707617521286011 | auc: 0.9014418125152588 | 


train... loss:46.126227498054504:   0%|▎                                                                                                | 17/5000 [02:38<8:36:16,  6.22s/it]

Val update: epoch: 16 |accuracy: 0.8162820935249329 | f1: 0.8658329844474792 | auc: 0.896823525428772 | treshold: 0.48000000000000004
Test: accuracy: 0.8191754221916199 | f1: 0.8689495921134949 | auc: 0.9028255939483643 | 


train... loss:45.97109845280647:   0%|▎                                                                                                 | 19/5000 [02:50<8:32:04,  6.17s/it]

Val update: epoch: 18 |accuracy: 0.8228123784065247 | f1: 0.8672007918357849 | auc: 0.8977030515670776 | treshold: 0.5
Test: accuracy: 0.8282915949821472 | f1: 0.8724103569984436 | auc: 0.9033479690551758 | 


train... loss:46.05689254403114:   0%|▍                                                                                                 | 20/5000 [02:56<8:38:44,  6.25s/it]

Val update: epoch: 19 |accuracy: 0.822425365447998 | f1: 0.8637847900390625 | auc: 0.8977936506271362 | treshold: 0.5
Test: accuracy: 0.8302260637283325 | f1: 0.8708164095878601 | auc: 0.903335690498352 | 


train... loss:45.79892662167549:   0%|▍                                                                                                 | 21/5000 [03:03<8:45:10,  6.33s/it]

Val update: epoch: 20 |accuracy: 0.822062611579895 | f1: 0.8611912727355957 | auc: 0.8992562890052795 | treshold: 0.51
Test: accuracy: 0.8310482501983643 | f1: 0.8690813183784485 | auc: 0.9045037031173706 | 


train... loss:45.73856168985367:   0%|▍                                                                                                 | 23/5000 [03:15<8:36:01,  6.22s/it]

Val update: epoch: 22 |accuracy: 0.8215305209159851 | f1: 0.8684505820274353 | auc: 0.9007251262664795 | treshold: 0.49
Test: accuracy: 0.8265506029129028 | f1: 0.8729565739631653 | auc: 0.9064463376998901 | 


train... loss:45.55339950323105:   1%|▌                                                                                                 | 26/5000 [03:33<8:24:58,  6.09s/it]

Val update: epoch: 25 |accuracy: 0.824819803237915 | f1: 0.8627933859825134 | auc: 0.9018179178237915 | treshold: 0.53
Test: accuracy: 0.8343368172645569 | f1: 0.8711709380149841 | auc: 0.9073630571365356 | 


train... loss:45.58780479431152:   1%|▌                                                                                                 | 27/5000 [03:39<8:31:17,  6.17s/it]

Val update: epoch: 26 |accuracy: 0.8190634846687317 | f1: 0.8685352802276611 | auc: 0.9030065536499023 | treshold: 0.5
Test: accuracy: 0.8219320774078369 | f1: 0.8713891506195068 | auc: 0.9080197811126709 | 


train... loss:45.32948875427246:   1%|▌                                                                                                 | 30/5000 [03:57<8:23:36,  6.08s/it]

Val update: epoch: 29 |accuracy: 0.8260775208473206 | f1: 0.8656365275382996 | auc: 0.9031783938407898 | treshold: 0.5
Test: accuracy: 0.834288477897644 | f1: 0.8727792501449585 | auc: 0.9082990288734436 | 


train... loss:45.19906321167946:   1%|▌                                                                                                 | 31/5000 [04:03<8:30:40,  6.17s/it]

Val update: epoch: 30 |accuracy: 0.828786313533783 | f1: 0.8675411343574524 | auc: 0.9032131433486938 | treshold: 0.51
Test: accuracy: 0.8359085917472839 | f1: 0.8741048574447632 | auc: 0.9077103137969971 | 


train... loss:45.36033275723457:   1%|▋                                                                                                 | 32/5000 [04:10<8:36:18,  6.24s/it]

Val update: epoch: 31 |accuracy: 0.812823474407196 | f1: 0.8655419945716858 | auc: 0.9045448899269104 | treshold: 0.49
Test: accuracy: 0.8166847825050354 | f1: 0.869191586971283 | auc: 0.9094489812850952 | 


train... loss:45.120078563690186:   1%|▋                                                                                                | 34/5000 [04:22<8:32:38,  6.19s/it]

Val update: epoch: 33 |accuracy: 0.8258598446846008 | f1: 0.871055543422699 | auc: 0.9061870574951172 | treshold: 0.54
Test: accuracy: 0.8315560221672058 | f1: 0.8760762810707092 | auc: 0.9104084968566895 | 


train... loss:45.15168014168739:   1%|▋                                                                                                 | 36/5000 [04:34<8:31:58,  6.19s/it]

Val update: epoch: 35 |accuracy: 0.8319305181503296 | f1: 0.8705742359161377 | auc: 0.9067187309265137 | treshold: 0.51
Test: accuracy: 0.8395599126815796 | f1: 0.8773590326309204 | auc: 0.9106618165969849 | 


train... loss:44.55873431265354:   1%|▋                                                                                                 | 38/5000 [04:46<8:29:14,  6.16s/it]

Val update: epoch: 37 |accuracy: 0.8231267929077148 | f1: 0.8699703216552734 | auc: 0.9067786931991577 | treshold: 0.5
Test: accuracy: 0.8270342350006104 | f1: 0.8738247752189636 | auc: 0.9110760688781738 | 


train... loss:44.88643115758896:   1%|▊                                                                                                 | 42/5000 [05:10<8:16:53,  6.01s/it]

Val update: epoch: 41 |accuracy: 0.8308905363082886 | f1: 0.869751513004303 | auc: 0.9070331454277039 | treshold: 0.5
Test: accuracy: 0.8389554023742676 | f1: 0.8766346573829651 | auc: 0.9115485548973083 | 


train... loss:45.023915231227875:   1%|▊                                                                                                | 43/5000 [05:16<8:23:39,  6.10s/it]

Val update: epoch: 42 |accuracy: 0.8319789171218872 | f1: 0.8698942065238953 | auc: 0.9078110456466675 | treshold: 0.52
Test: accuracy: 0.8384475708007812 | f1: 0.8758386373519897 | auc: 0.9114151000976562 | 


train... loss:44.9871002137661:   1%|▊                                                                                                  | 44/5000 [05:23<8:28:54,  6.16s/it]

Val update: epoch: 43 |accuracy: 0.8157742023468018 | f1: 0.8678269386291504 | auc: 0.9084307551383972 | treshold: 0.49
Test: accuracy: 0.8190303444862366 | f1: 0.8709477186203003 | auc: 0.9124715328216553 | 


train... loss:44.681902438402176:   1%|▉                                                                                                | 47/5000 [05:40<8:17:59,  6.03s/it]

Val update: epoch: 46 |accuracy: 0.831664502620697 | f1: 0.8736153841018677 | auc: 0.9086848497390747 | treshold: 0.5
Test: accuracy: 0.8362229466438293 | f1: 0.878010094165802 | auc: 0.9120076894760132 | 


train... loss:44.44553977251053:   1%|▉                                                                                                 | 48/5000 [05:47<8:23:43,  6.10s/it]

Val update: epoch: 47 |accuracy: 0.827214241027832 | f1: 0.8725696802139282 | auc: 0.9099631309509277 | treshold: 0.51
Test: accuracy: 0.8314109444618225 | f1: 0.8764793276786804 | auc: 0.913644015789032 | 


train... loss:44.994081288576126:   1%|█                                                                                                | 55/5000 [06:33<9:49:12,  7.15s/it]



Evaluating SessionwiseAttention with neural embeddings
Test before learning: {'f1': 0.5771521329879761, 'roc-auc': 0.4987527132034302, 'accuracy': 0.5089831948280334}


train... loss:63.576477348804474:   0%|                                                                                                  | 1/5000 [00:06<9:15:20,  6.67s/it]

Val update: epoch: 0 |accuracy: 0.7440623044967651 | f1: 0.8250099420547485 | auc: 0.8208204507827759 | treshold: 0.44
Test: accuracy: 0.7448434233665466 | f1: 0.8264302611351013 | auc: 0.8281413316726685 | 


train... loss:53.82590615749359:   0%|                                                                                                   | 2/5000 [00:13<9:21:51,  6.75s/it]

Val update: epoch: 1 |accuracy: 0.7758912444114685 | f1: 0.8454893827438354 | auc: 0.8683303594589233 | treshold: 0.45
Test: accuracy: 0.7790593504905701 | f1: 0.8481645584106445 | auc: 0.8766768574714661 | 


train... loss:48.8667978644371:   0%|                                                                                                    | 3/5000 [00:20<9:25:59,  6.80s/it]

Val update: epoch: 2 |accuracy: 0.8108402490615845 | f1: 0.8622117042541504 | auc: 0.8912395238876343 | treshold: 0.5
Test: accuracy: 0.8169265985488892 | f1: 0.867047131061554 | auc: 0.8980998992919922 | 


train... loss:47.095514327287674:   0%|                                                                                                  | 4/5000 [00:27<9:28:29,  6.83s/it]

Val update: epoch: 3 |accuracy: 0.8250616788864136 | f1: 0.862863302230835 | auc: 0.899319589138031 | treshold: 0.51
Test: accuracy: 0.831314206123352 | f1: 0.86815345287323 | auc: 0.9055581092834473 | 


train... loss:45.54659339785576:   0%|                                                                                                   | 5/5000 [00:34<9:29:46,  6.84s/it]

Val update: epoch: 4 |accuracy: 0.8359212279319763 | f1: 0.8750184178352356 | auc: 0.9100804328918457 | treshold: 0.54
Test: accuracy: 0.8398500680923462 | f1: 0.8784570097923279 | auc: 0.9137812256813049 | 


train... loss:45.192744106054306:   0%|                                                                                                  | 6/5000 [00:40<9:29:51,  6.85s/it]

Val update: epoch: 5 |accuracy: 0.8329705595970154 | f1: 0.8761744499206543 | auc: 0.9138736724853516 | treshold: 0.54
Test: accuracy: 0.8341192007064819 | f1: 0.8775655627250671 | auc: 0.9167599081993103 | 


train... loss:44.47795623540878:   0%|▏                                                                                                  | 8/5000 [00:53<9:16:47,  6.69s/it]

Val update: epoch: 7 |accuracy: 0.84283846616745 | f1: 0.8784056901931763 | auc: 0.9161219596862793 | treshold: 0.52
Test: accuracy: 0.8472977876663208 | f1: 0.8824569582939148 | auc: 0.9199730157852173 | 


train... loss:44.0023595392704:   0%|▎                                                                                                  | 14/5000 [01:32<9:00:05,  6.50s/it]

Val update: epoch: 13 |accuracy: 0.825497031211853 | f1: 0.8736405372619629 | auc: 0.9169049859046936 | treshold: 0.49
Test: accuracy: 0.8272760510444641 | f1: 0.8753381371498108 | auc: 0.9191257953643799 | 


train... loss:43.94691649079323:   0%|▎                                                                                                 | 15/5000 [01:39<9:08:07,  6.60s/it]

Val update: epoch: 14 |accuracy: 0.8432738184928894 | f1: 0.8754899501800537 | auc: 0.9177823066711426 | treshold: 0.54
Test: accuracy: 0.8470801711082458 | f1: 0.8792576789855957 | auc: 0.9211745858192444 | 


train... loss:43.460166573524475:   0%|▎                                                                                                | 17/5000 [01:52<9:09:20,  6.61s/it]

Val update: epoch: 16 |accuracy: 0.8281574845314026 | f1: 0.8753793239593506 | auc: 0.9189143180847168 | treshold: 0.51
Test: accuracy: 0.830903172492981 | f1: 0.8778920769691467 | auc: 0.9212234616279602 | 


train... loss:43.53853513300419:   0%|▎                                                                                                 | 18/5000 [01:59<9:15:09,  6.69s/it]

Val update: epoch: 17 |accuracy: 0.8480626940727234 | f1: 0.8838773965835571 | auc: 0.9208476543426514 | treshold: 0.51
Test: accuracy: 0.8489662408828735 | f1: 0.8852513432502747 | auc: 0.9233006238937378 | 


train... loss:43.350308775901794:   0%|▍                                                                                                | 22/5000 [02:24<9:02:52,  6.54s/it]

Val update: epoch: 21 |accuracy: 0.8480626940727234 | f1: 0.8847339153289795 | auc: 0.9217748045921326 | treshold: 0.5
Test: accuracy: 0.8489421010017395 | f1: 0.8859807848930359 | auc: 0.9236876368522644 | 


train... loss:42.84186094999313:   1%|▌                                                                                                 | 26/5000 [02:50<8:57:10,  6.48s/it]

Val update: epoch: 25 |accuracy: 0.8509650230407715 | f1: 0.8851358890533447 | auc: 0.9219724535942078 | treshold: 0.51
Test: accuracy: 0.8531253933906555 | f1: 0.8875393271446228 | auc: 0.9240998029708862 | 


train... loss:42.939302176237106:   1%|▋                                                                                                | 33/5000 [03:34<8:50:04,  6.40s/it]

Val update: epoch: 32 |accuracy: 0.8320514559745789 | f1: 0.8784568905830383 | auc: 0.9227665662765503 | treshold: 0.49
Test: accuracy: 0.8340466618537903 | f1: 0.8804126381874084 | auc: 0.9244817495346069 | 


train... loss:42.81679108738899:   1%|▋                                                                                                 | 36/5000 [03:53<8:54:59,  6.47s/it]

Val update: epoch: 35 |accuracy: 0.8395975232124329 | f1: 0.882152259349823 | auc: 0.9235205054283142 | treshold: 0.52
Test: accuracy: 0.8427759408950806 | f1: 0.885013997554779 | auc: 0.9251147508621216 | 


train... loss:42.92095822095871:   1%|▋                                                                                                 | 36/5000 [03:59<9:11:26,  6.67s/it]



Evaluating SessionwiseAttention with svd embeddings
Test before learning: {'f1': 0.47716495394706726, 'roc-auc': 0.46824291348457336, 'accuracy': 0.44275179505348206}


train... loss:69.12789046764374:   0%|                                                                                                   | 1/5000 [00:06<8:33:13,  6.16s/it]

Val update: epoch: 0 |accuracy: 0.6485754251480103 | f1: 0.7831407785415649 | auc: 0.7152985334396362 | treshold: 0.49
Test: accuracy: 0.6492806077003479 | f1: 0.7837225198745728 | auc: 0.7169841527938843 | 


train... loss:65.54003313183784:   0%|                                                                                                   | 2/5000 [00:12<8:39:14,  6.23s/it]

Val update: epoch: 1 |accuracy: 0.6434237957000732 | f1: 0.7828814387321472 | auc: 0.7273731231689453 | treshold: 0.42000000000000004
Test: accuracy: 0.6451940536499023 | f1: 0.7842428088188171 | auc: 0.7321620583534241 | 


train... loss:63.882972329854965:   0%|                                                                                                  | 3/5000 [00:18<8:37:06,  6.21s/it]

Val update: epoch: 2 |accuracy: 0.6505829095840454 | f1: 0.7848387956619263 | auc: 0.7386698722839355 | treshold: 0.43
Test: accuracy: 0.6517712473869324 | f1: 0.7862623929977417 | auc: 0.7443430423736572 | 


train... loss:62.740438133478165:   0%|                                                                                                  | 4/5000 [00:24<8:36:27,  6.20s/it]

Val update: epoch: 3 |accuracy: 0.7027040123939514 | f1: 0.7870408892631531 | auc: 0.7603610157966614 | treshold: 0.43
Test: accuracy: 0.7078709006309509 | f1: 0.7909536361694336 | auc: 0.7657673358917236 | 


train... loss:61.0042679309845:   0%|                                                                                                    | 5/5000 [00:31<8:38:15,  6.23s/it]

Val update: epoch: 4 |accuracy: 0.7234073281288147 | f1: 0.7885433435440063 | auc: 0.776167631149292 | treshold: 0.45
Test: accuracy: 0.7309877872467041 | f1: 0.7950630784034729 | auc: 0.7822588086128235 | 


train... loss:59.85639697313309:   0%|                                                                                                   | 6/5000 [00:37<8:39:12,  6.24s/it]

Val update: epoch: 5 |accuracy: 0.7232380509376526 | f1: 0.7996708750724792 | auc: 0.7881367206573486 | treshold: 0.42000000000000004
Test: accuracy: 0.732994794845581 | f1: 0.8073151111602783 | auc: 0.7946302890777588 | 


train... loss:58.60243409872055:   0%|▏                                                                                                  | 7/5000 [00:43<8:41:29,  6.27s/it]

Val update: epoch: 6 |accuracy: 0.6806946396827698 | f1: 0.7963722944259644 | auc: 0.7975301742553711 | treshold: 0.43
Test: accuracy: 0.6870753169059753 | f1: 0.8007636070251465 | auc: 0.805088996887207 | 


train... loss:57.76501268148422:   0%|▏                                                                                                  | 8/5000 [00:49<8:41:19,  6.27s/it]

Val update: epoch: 7 |accuracy: 0.7446911334991455 | f1: 0.8065178394317627 | auc: 0.8056989908218384 | treshold: 0.42000000000000004
Test: accuracy: 0.7528956532478333 | f1: 0.8139666318893433 | auc: 0.8129507899284363 | 


train... loss:57.2687765955925:   0%|▏                                                                                                   | 9/5000 [00:56<8:41:07,  6.26s/it]

Val update: epoch: 8 |accuracy: 0.7323078513145447 | f1: 0.8146994709968567 | auc: 0.8124620914459229 | treshold: 0.4
Test: accuracy: 0.7393301725387573 | f1: 0.8202674388885498 | auc: 0.8195371627807617 | 


train... loss:56.55599129199982:   0%|▏                                                                                                 | 10/5000 [01:02<8:42:05,  6.28s/it]

Val update: epoch: 9 |accuracy: 0.7535674571990967 | f1: 0.8195135593414307 | auc: 0.8187358379364014 | treshold: 0.43
Test: accuracy: 0.7594486474990845 | f1: 0.8248221278190613 | auc: 0.8255927562713623 | 


train... loss:56.151644706726074:   0%|▏                                                                                                | 11/5000 [01:08<8:41:49,  6.28s/it]

Val update: epoch: 10 |accuracy: 0.7565665245056152 | f1: 0.8243486285209656 | auc: 0.8235154151916504 | treshold: 0.4
Test: accuracy: 0.7616007924079895 | f1: 0.8286910653114319 | auc: 0.8303882479667664 | 


train... loss:55.491580098867416:   0%|▏                                                                                                | 12/5000 [01:15<8:40:39,  6.26s/it]

Val update: epoch: 11 |accuracy: 0.7186668515205383 | f1: 0.814924418926239 | auc: 0.8252679705619812 | treshold: 0.41000000000000003
Test: accuracy: 0.7248942255973816 | f1: 0.8191054463386536 | auc: 0.8323583602905273 | 


train... loss:55.320485442876816:   0%|▎                                                                                                | 13/5000 [01:21<8:40:26,  6.26s/it]

Val update: epoch: 12 |accuracy: 0.7538093328475952 | f1: 0.8273018598556519 | auc: 0.8277868032455444 | treshold: 0.41000000000000003
Test: accuracy: 0.7610688209533691 | f1: 0.8330094218254089 | auc: 0.8342591524124146 | 


train... loss:55.35376560688019:   0%|▎                                                                                                 | 14/5000 [01:27<8:42:22,  6.29s/it]

Val update: epoch: 13 |accuracy: 0.7670875191688538 | f1: 0.8205166459083557 | auc: 0.8315439820289612 | treshold: 0.42000000000000004
Test: accuracy: 0.7731108665466309 | f1: 0.8263854384422302 | auc: 0.8376115560531616 | 


train... loss:54.69984620809555:   0%|▎                                                                                                 | 16/5000 [01:39<8:35:38,  6.21s/it]

Val update: epoch: 15 |accuracy: 0.7369999289512634 | f1: 0.823919951915741 | auc: 0.8356353044509888 | treshold: 0.41000000000000003
Test: accuracy: 0.741796612739563 | f1: 0.8274903893470764 | auc: 0.8417340517044067 | 


train... loss:54.07863247394562:   0%|▎                                                                                                 | 17/5000 [01:46<8:41:05,  6.27s/it]

Val update: epoch: 16 |accuracy: 0.7693851590156555 | f1: 0.8336676955223083 | auc: 0.8381637930870056 | treshold: 0.43
Test: accuracy: 0.7759883999824524 | f1: 0.8390548825263977 | auc: 0.8447255492210388 | 


train... loss:54.17195773124695:   0%|▎                                                                                                 | 18/5000 [01:52<8:42:47,  6.30s/it]

Val update: epoch: 17 |accuracy: 0.7774391770362854 | f1: 0.8289654850959778 | auc: 0.8416974544525146 | treshold: 0.41000000000000003
Test: accuracy: 0.7829524874687195 | f1: 0.8339745402336121 | auc: 0.847923994064331 | 


train... loss:53.903960675001144:   0%|▎                                                                                                | 19/5000 [01:58<8:44:43,  6.32s/it]

Val update: epoch: 18 |accuracy: 0.7774391770362854 | f1: 0.8289591073989868 | auc: 0.8421980142593384 | treshold: 0.41000000000000003
Test: accuracy: 0.7835811972618103 | f1: 0.8344922065734863 | auc: 0.8484675288200378 | 


train... loss:53.52357864379883:   0%|▍                                                                                                 | 20/5000 [02:05<8:47:31,  6.36s/it]

Val update: epoch: 19 |accuracy: 0.7784791588783264 | f1: 0.8361274600028992 | auc: 0.8436062932014465 | treshold: 0.4
Test: accuracy: 0.7850803732872009 | f1: 0.8416138291358948 | auc: 0.8501715660095215 | 


train... loss:53.25983113050461:   0%|▍                                                                                                 | 21/5000 [02:11<8:48:57,  6.37s/it]

Val update: epoch: 20 |accuracy: 0.7689740061759949 | f1: 0.8368518352508545 | auc: 0.8444843888282776 | treshold: 0.41000000000000003
Test: accuracy: 0.7753596901893616 | f1: 0.8420068025588989 | auc: 0.8511549234390259 | 


train... loss:53.244015991687775:   0%|▍                                                                                                | 22/5000 [02:18<8:51:12,  6.40s/it]

Val update: epoch: 21 |accuracy: 0.7668940424919128 | f1: 0.8366385102272034 | auc: 0.8450093269348145 | treshold: 0.42000000000000004
Test: accuracy: 0.7733284831047058 | f1: 0.8416393399238586 | auc: 0.8519690036773682 | 


train... loss:53.09846764802933:   0%|▍                                                                                                 | 23/5000 [02:24<8:51:28,  6.41s/it]

Val update: epoch: 22 |accuracy: 0.7637739777565002 | f1: 0.8358845710754395 | auc: 0.8460797071456909 | treshold: 0.42000000000000004
Test: accuracy: 0.7702091932296753 | f1: 0.8408554196357727 | auc: 0.8528828620910645 | 


train... loss:52.7377915084362:   0%|▍                                                                                                  | 24/5000 [02:31<8:51:28,  6.41s/it]

Val update: epoch: 23 |accuracy: 0.7630484104156494 | f1: 0.8361678123474121 | auc: 0.8474217653274536 | treshold: 0.41000000000000003
Test: accuracy: 0.7680329084396362 | f1: 0.8401779532432556 | auc: 0.854264497756958 | 


train... loss:52.78546315431595:   0%|▍                                                                                                 | 25/5000 [02:37<8:53:35,  6.44s/it]

Val update: epoch: 24 |accuracy: 0.7692158818244934 | f1: 0.8384737968444824 | auc: 0.849460780620575 | treshold: 0.42000000000000004
Test: accuracy: 0.7741022706031799 | f1: 0.8424673676490784 | auc: 0.8560202121734619 | 


train... loss:52.53187960386276:   1%|▌                                                                                                 | 26/5000 [02:44<8:54:27,  6.45s/it]

Val update: epoch: 25 |accuracy: 0.7793740630149841 | f1: 0.8413895964622498 | auc: 0.8495111465454102 | treshold: 0.42000000000000004
Test: accuracy: 0.7855882048606873 | f1: 0.8465943336486816 | auc: 0.8568253517150879 | 


train... loss:52.257383584976196:   1%|▌                                                                                                | 27/5000 [02:50<8:51:42,  6.42s/it]

Val update: epoch: 26 |accuracy: 0.7889276146888733 | f1: 0.840815007686615 | auc: 0.8514135479927063 | treshold: 0.42000000000000004
Test: accuracy: 0.7962519526481628 | f1: 0.8472886681556702 | auc: 0.8580882549285889 | 


train... loss:52.092192590236664:   1%|▌                                                                                                | 28/5000 [02:56<8:47:09,  6.36s/it]

Val update: epoch: 27 |accuracy: 0.7901127338409424 | f1: 0.838066816329956 | auc: 0.85334312915802 | treshold: 0.43
Test: accuracy: 0.7971224784851074 | f1: 0.8444510102272034 | auc: 0.8601480722427368 | 


train... loss:52.17539927363396:   1%|▌                                                                                                 | 29/5000 [03:02<8:45:11,  6.34s/it]

Val update: epoch: 28 |accuracy: 0.7910075783729553 | f1: 0.8384406566619873 | auc: 0.853823184967041 | treshold: 0.43
Test: accuracy: 0.7981380820274353 | f1: 0.8448442220687866 | auc: 0.861052930355072 | 


train... loss:51.97612974047661:   1%|▌                                                                                                 | 30/5000 [03:09<8:45:40,  6.35s/it]

Val update: epoch: 29 |accuracy: 0.7926280498504639 | f1: 0.8407799601554871 | auc: 0.8558451533317566 | treshold: 0.43
Test: accuracy: 0.800386905670166 | f1: 0.8475896716117859 | auc: 0.8628639578819275 | 


train... loss:51.27812948822975:   1%|▋                                                                                                 | 33/5000 [03:27<8:28:30,  6.14s/it]

Val update: epoch: 32 |accuracy: 0.7878150343894958 | f1: 0.8456789255142212 | auc: 0.8562108278274536 | treshold: 0.41000000000000003
Test: accuracy: 0.7934711575508118 | f1: 0.8505538105964661 | auc: 0.863640546798706 | 


train... loss:51.230511516332626:   1%|▋                                                                                                | 34/5000 [03:33<8:34:44,  6.22s/it]

Val update: epoch: 33 |accuracy: 0.784888505935669 | f1: 0.8453056216239929 | auc: 0.8598212003707886 | treshold: 0.42000000000000004
Test: accuracy: 0.7922379374504089 | f1: 0.8509833812713623 | auc: 0.8670480251312256 | 


train... loss:50.81160458922386:   1%|▋                                                                                                 | 38/5000 [03:57<8:16:56,  6.01s/it]

Val update: epoch: 37 |accuracy: 0.7778987288475037 | f1: 0.8436057567596436 | auc: 0.8598881959915161 | treshold: 0.43
Test: accuracy: 0.7824447154998779 | f1: 0.8474645018577576 | auc: 0.8666446805000305 | 


train... loss:50.692202150821686:   1%|▊                                                                                                | 39/5000 [04:03<8:22:49,  6.08s/it]

Val update: epoch: 38 |accuracy: 0.7767136096954346 | f1: 0.8437055349349976 | auc: 0.8619158267974854 | treshold: 0.41000000000000003
Test: accuracy: 0.7819610834121704 | f1: 0.8477295637130737 | auc: 0.8690256476402283 | 


train... loss:50.34883201122284:   1%|▊                                                                                                 | 42/5000 [04:21<8:18:53,  6.04s/it]

Val update: epoch: 41 |accuracy: 0.7859768867492676 | f1: 0.8467891216278076 | auc: 0.8652701377868652 | treshold: 0.41000000000000003
Test: accuracy: 0.7933260798454285 | f1: 0.8524318337440491 | auc: 0.8725892305374146 | 


train... loss:50.54136282205582:   1%|▊                                                                                                 | 43/5000 [04:27<8:31:27,  6.19s/it]

Val update: epoch: 42 |accuracy: 0.7987229824066162 | f1: 0.8499783873558044 | auc: 0.8660659790039062 | treshold: 0.43
Test: accuracy: 0.8051021695137024 | f1: 0.8552754521369934 | auc: 0.873189389705658 | 


train... loss:50.39825049042702:   1%|▉                                                                                                 | 48/5000 [04:57<8:14:19,  5.99s/it]

Val update: epoch: 47 |accuracy: 0.7980215549468994 | f1: 0.843517541885376 | auc: 0.8677321672439575 | treshold: 0.43
Test: accuracy: 0.8071333765983582 | f1: 0.8513991832733154 | auc: 0.8742008209228516 | 


train... loss:50.0648598074913:   1%|▉                                                                                                  | 49/5000 [05:03<8:23:07,  6.10s/it]

Val update: epoch: 48 |accuracy: 0.7946596741676331 | f1: 0.8495907783508301 | auc: 0.8683708906173706 | treshold: 0.42000000000000004
Test: accuracy: 0.8019586801528931 | f1: 0.8553617000579834 | auc: 0.874860405921936 | 


train... loss:50.53167071938515:   1%|▉                                                                                                 | 50/5000 [05:09<8:28:22,  6.16s/it]

Val update: epoch: 49 |accuracy: 0.7848159670829773 | f1: 0.8477193117141724 | auc: 0.8701743483543396 | treshold: 0.43
Test: accuracy: 0.7902792692184448 | f1: 0.8518778085708618 | auc: 0.8765193223953247 | 


train... loss:49.85315063595772:   1%|█                                                                                                 | 55/5000 [05:39<8:16:31,  6.02s/it]

Val update: epoch: 54 |accuracy: 0.8012866973876953 | f1: 0.8510515093803406 | auc: 0.8705160021781921 | treshold: 0.44
Test: accuracy: 0.8074477314949036 | f1: 0.8560582995414734 | auc: 0.87652188539505 | 


train... loss:49.540421426296234:   1%|█                                                                                                | 57/5000 [05:51<8:19:25,  6.06s/it]

Val update: epoch: 56 |accuracy: 0.7844773530960083 | f1: 0.8481243252754211 | auc: 0.8725230693817139 | treshold: 0.43
Test: accuracy: 0.788223922252655 | f1: 0.8511151671409607 | auc: 0.8791541457176208 | 


train... loss:49.51350477337837:   1%|█▏                                                                                                | 58/5000 [05:57<8:27:06,  6.16s/it]

Val update: epoch: 57 |accuracy: 0.8027378916740417 | f1: 0.8514497876167297 | auc: 0.8726757764816284 | treshold: 0.45
Test: accuracy: 0.8088018298149109 | f1: 0.8566923141479492 | auc: 0.878899097442627 | 


train... loss:49.13167987763882:   1%|█▏                                                                                                | 60/5000 [06:09<8:26:31,  6.15s/it]

Val update: epoch: 59 |accuracy: 0.7909108400344849 | f1: 0.8502485752105713 | auc: 0.8734177947044373 | treshold: 0.44
Test: accuracy: 0.7955749034881592 | f1: 0.8539745211601257 | auc: 0.879436731338501 | 


train... loss:49.54580545425415:   1%|█▏                                                                                                | 63/5000 [06:27<8:18:50,  6.06s/it]

Val update: epoch: 62 |accuracy: 0.8031490445137024 | f1: 0.8478549122810364 | auc: 0.8747761249542236 | treshold: 0.45
Test: accuracy: 0.8118002414703369 | f1: 0.855267345905304 | auc: 0.8807227611541748 | 


train... loss:49.58063179254532:   1%|█▎                                                                                                | 65/5000 [06:39<8:22:56,  6.11s/it]

Val update: epoch: 64 |accuracy: 0.7774149775505066 | f1: 0.8452106714248657 | auc: 0.8749202489852905 | treshold: 0.42000000000000004
Test: accuracy: 0.7818643450737 | f1: 0.8486688733100891 | auc: 0.8809521794319153 | 


train... loss:49.46283343434334:   1%|█▎                                                                                                | 66/5000 [06:51<8:32:53,  6.24s/it]


In [12]:
pd.DataFrame(rl4rs_results).to_csv(f'results/rl4rs_{experiment_name}.csv')
del dataset, train_loader, val_loader, test_loader, train_user_item_matrix, train_num_items