In [1]:
import os
import torch
import random
import datetime
import pandas as pd
import numpy as np

from torch.utils.data import Dataset
from src.datasets import RL4RS, ContentWise, DummyData
from src.utils import train, get_dummy_data, get_train_val_test_tmatrix_tnumitems
from src.embeddings import RecsysEmbedding

# --- Experiment configuration ---------------------------------------------
# `experiment_name` is interpolated into log messages and result file names.
experiment_name = 'SCOT'
# Prefer the first GPU, but fall back to the CPU so the notebook can still be
# executed end-to-end on a machine without CUDA.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
seed = 123
pkl_path = '../pkl/'

# Seed every random number generator used by the notebook (Python, NumPy,
# PyTorch) so that repeated runs start from the same state.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

<torch._C.Generator at 0x7f7174462bd0>

In [2]:
torch.__version__

'1.12.1'

# Модель

In [3]:
# NOTE(review): `F` is not referenced in any of the visible cells; if it stays,
# it belongs with the other imports in the first cell.
import torch.nn.functional as F
# Anomaly detection is a debugging aid: it makes autograd report the forward
# operation responsible for a failing backward pass. The PyTorch documentation
# warns that it slows execution down, so consider switching it off for the
# full training runs below once the model is known to be numerically stable.
torch.autograd.set_detect_anomaly(True)


class SCOT(torch.nn.Module):
    """
    No recurrent dependency, just slate-wise attention.

    Every item of the flattened session (all slates concatenated into one
    sequence) acts as an attention query. The keys/values are the user
    embedding (a 'zero item' that is always available) followed by the items
    with a positive response. A query may only attend to keys that are located
    strictly before it in the flattened sequence, so it sees neither its own
    response nor any later one.

    Args:
        embedding: module called as ``embedding(batch)`` that returns
            ``(item_embs, user_embs)`` and exposes ``embedding_dim``.
        nheads: number of attention heads.
        output_dim: size of the per-item output of the final linear layer.
        debug: when true, intermediate tensors are printed during ``forward``.
    """
    def __init__(self, embedding, nheads=2, output_dim=1, debug=False):
        super().__init__()
        self.embedding_dim = embedding.embedding_dim
        self.embedding = embedding
        self.nheads = nheads
        self.debug = debug
        self.attention = torch.nn.MultiheadAttention(
            self.embedding_dim,
            num_heads=nheads,
            batch_first=True
        )

        # Kept as a Sequential (although it holds a single layer) so that the
        # parameter names in the state dict stay `out_layer.0.*`.
        # Previously tried in front of the final projection:
        # LayerNorm(2 * dim) -> Linear(2 * dim, 2 * dim) -> GELU.
        self.out_layer = torch.nn.Sequential(
            torch.nn.Linear(embedding.embedding_dim * 2, output_dim)
        )

    def _debug(self, *args):
        """Print ``args`` only when the model was built with ``debug=True``."""
        if self.debug:
            print(*args)

    def _last_user_state(self, batch, user_embs):
        """
        Select one user embedding per sample, shaped (batch, 1, dim).

        The step index is ``length - 1`` in training mode and
        ``in_length - 1`` otherwise; negative indices are clamped to 0.
        """
        length_key = 'length' if self.training else 'in_length'
        # the subtraction creates a new tensor, the batch itself is untouched
        step = (batch[length_key] - 1).clamp(min=0)
        index = step[:, None, None].expand(-1, 1, user_embs.size(-1))
        return user_embs.gather(1, index)

    def forward(self, batch):
        """
        Score every item of every slate in the batch.

        Args:
            batch: mapping that provides 'responses' (values > 0 mark clicked
                items), 'length' (read in training mode) and 'in_length'
                (read in evaluation mode), plus whatever ``self.embedding``
                consumes.

        Returns:
            Tensor with the leading three dimensions of the item embeddings
            and a trailing dimension of size ``output_dim``; the trailing
            dimension is squeezed away when ``output_dim == 1``.
        """
        item_embs, user_embs = self.embedding(batch)
        shp = item_embs.shape
        device = item_embs.device
        self._debug('responses', batch['responses'])

        # flattening slates into long sequences
        item_embs = item_embs.flatten(1, 2)
        n_queries = item_embs.size(-2)

        # adding a user embedding as a 'zero item' to
        # make predictions if nothing is observed
        candidates = torch.cat(
            [self._last_user_state(batch, user_embs), item_embs],
            dim=1
        )

        # candidate position 0 is the user embedding (always kept),
        # position j + 1 is the flattened item j (kept only if clicked)
        clicked_mask = (batch['responses'].flatten(1, 2) > 0).to(device)
        clicked_mask = torch.cat(
            [clicked_mask.new_ones((clicked_mask.size(0), 1)), clicked_mask],
            dim=-1
        )
        self._debug('clicked_mask:', clicked_mask, clicked_mask.shape)

        # Left-align the positions of the kept candidates of every sample in a
        # (batch, max_keys) table. Unused slots hold `pad_position`, a value
        # larger than every real position, so padding is identified from the
        # click counts instead of a NaN sentinel inside the embeddings.
        n_clicked = clicked_mask.sum(-1)
        max_keys = int(n_clicked.max())
        pad_position = n_queries + 1
        batch_idx, pos_idx = clicked_mask.nonzero(as_tuple=True)
        slot_idx = (clicked_mask.long().cumsum(-1) - 1)[batch_idx, pos_idx]
        key_positions = torch.full(
            (shp[0], max_keys), pad_position, dtype=torch.long, device=device
        )
        key_positions[batch_idx, slot_idx] = pos_idx
        key_padding_mask = key_positions == pad_position

        # gather the kept candidates; padded slots are zeroed out
        gather_idx = key_positions.clamp(max=n_queries)[:, :, None].expand(
            -1, -1, candidates.size(-1)
        )
        keys = candidates.gather(1, gather_idx).masked_fill(
            key_padding_mask[:, :, None], 0
        )
        self._debug('key', keys.shape, keys, key_padding_mask)

        # forbid model looking into future (and into current interaction):
        # query q is the candidate at position q + 1, so a key is blocked
        # (True) when its position is greater than q. Padded slots are always
        # blocked because `pad_position` exceeds every query index.
        # Shape: (bsize, slate_size * sequence_size, max_len_clicked_items).
        ####### TODO ########
        # change future mask to be slate-wise, not sequence-wise
        # chunk previous interaction
        #####################
        query_positions = torch.arange(n_queries, device=device)
        attn_mask = key_positions[:, None, :] > query_positions[None, :, None]
        self._debug('attn_mask', attn_mask.shape)
        self._debug('attn_mask_stacked', attn_mask)

        # MultiheadAttention expects one mask per (sample, head) pair, with
        # the heads of a sample stored contiguously.
        features, _ = self.attention(
            item_embs, keys, keys,
            key_padding_mask=key_padding_mask,
            attn_mask=attn_mask.repeat_interleave(self.nheads, 0)
        )

        self._debug(shp, item_embs.shape)
        # attended context concatenated with the item's own embedding
        features = torch.cat(
            [
                features.reshape(shp),
                item_embs.reshape(shp)
            ],
            dim=-1
        )

        return self.out_layer(features).squeeze(-1)

# Игрушечный датасет: проверим, что сходится к идеальным метрикам

In [4]:
# Sanity check on the toy dataset: the model is built and trained on the CPU
# tensors produced by `get_dummy_data`.
d = DummyData()
dummy_loader, dummy_matrix = get_dummy_data(d)

dummy_embedding = RecsysEmbedding(d.n_items, dummy_matrix, embeddings='neural').to('cpu')
model = SCOT(dummy_embedding, output_dim=1).to('cpu')

# The same loader is passed for all three loader arguments, so the reported
# metrics are computed on the data the model was fitted on.
train(
    model,
    dummy_loader,
    dummy_loader,
    dummy_loader,
    device=device,
    lr=1e-3,
    num_epochs=5000,
    dummy=True,
    silent=True,
)


biulding affinity matrix...


3it [00:00, 3142.59it/s]


Test before learning: {'f1': 0.4000000059604645, 'roc-auc': 0.6666666269302368, 'accuracy': 0.25}


train... loss:0.7044374346733093:   0%|                                                                                                    | 1/5000 [00:00<28:57,  2.88it/s]

Val update: epoch: 0 |accuracy: 0.75 | f1: 0.6666666865348816 | auc: 0.6666666269302368 | treshold: 0.55
Test: accuracy: 0.75 | f1: 0.6666666865348816 | auc: 0.6666666269302368 | 


train... loss:0.6942538022994995:   0%|                                                                                                    | 2/5000 [00:00<28:51,  2.89it/s]

Val update: epoch: 1 |accuracy: 0.75 | f1: 0.6666666865348816 | auc: 1.0 | treshold: 0.55
Test: accuracy: 0.75 | f1: 0.6666666865348816 | auc: 1.0 | 


train... loss:0.6847817301750183:   0%|                                                                                                    | 3/5000 [00:01<35:38,  2.34it/s]

Val update: epoch: 3 |accuracy: 1.0 | f1: 1.0 | auc: 1.0 | treshold: 0.5800000000000001
Test: accuracy: 1.0 | f1: 1.0 | auc: 1.0 | 





(SCOT(
   (embedding): RecsysEmbedding(
     (item_embeddings): Embedding(5, 32)
   )
   (attention): MultiheadAttention(
     (out_proj): NonDynamicallyQuantizableLinear(in_features=32, out_features=32, bias=True)
   )
   (out_layer): Sequential(
     (0): Linear(in_features=64, out_features=1, bias=True)
   )
 ),
 {'f1': 1.0, 'roc-auc': 1.0, 'accuracy': 1.0})

# ContentWise

In [5]:
# One metrics entry per embedding type is appended here by the evaluation loop.
content_wise_results = []

# Load the pickled ContentWise dataset and split it; the loaders use batches
# of 150 data points.
dataset = ContentWise.load(os.path.join(pkl_path, 'cw.pkl'))
train_loader, val_loader, test_loader, train_user_item_matrix, train_num_items = \
    get_train_val_test_tmatrix_tnumitems(dataset, batch_size=150)

print(f"{len(dataset)} data points among {len(train_loader)} batches")

20216 data points among 108 batches


In [6]:
# Train and evaluate one SCOT model per embedding type on ContentWise.
for embeddings in ('svd', 'neural'):
    print(f"\nEvaluating {experiment_name} with {embeddings} embeddings")

    embedding_layer = RecsysEmbedding(
        train_num_items,
        train_user_item_matrix,
        embeddings=embeddings,
    )
    model = SCOT(embedding_layer, output_dim=1).to(device)

    # only the metrics are kept; the first returned value is discarded
    _, metrics = train(
        model,
        train_loader,
        val_loader,
        test_loader,
        device=device,
        lr=1e-3,
        num_epochs=5000,
        early_stopping=7,
        silent=True,
    )

    # remember which embedding type produced these metrics
    metrics.update(embeddings=embeddings)
    content_wise_results.append(metrics)


Evaluating SCOT with svd embeddings
Test before learning: {'f1': 0.13685372471809387, 'roc-auc': 0.5017930269241333, 'accuracy': 0.8512182831764221}


train... loss:49.74499770998955:   0%|                                                                                                  | 1/5000 [01:01<85:40:05, 61.69s/it]

Val update: epoch: 0 |accuracy: 0.09982575476169586 | f1: 0.18153013288974762 | auc: 0.5062177181243896 | treshold: 0.01
Test: accuracy: 0.09984634816646576 | f1: 0.18156416714191437 | auc: 0.5073295831680298 | 


train... loss:34.518369287252426:   0%|                                                                                                 | 2/5000 [01:30<59:06:30, 42.58s/it]

Val update: epoch: 1 |accuracy: 0.12159887701272964 | f1: 0.18361753225326538 | auc: 0.5580906867980957 | treshold: 0.04
Test: accuracy: 0.12342137843370438 | f1: 0.18401013314723969 | auc: 0.5660172700881958 | 


train... loss:34.06975802779198:   0%|                                                                                                  | 3/5000 [02:00<50:51:28, 36.64s/it]

Val update: epoch: 2 |accuracy: 0.13854676485061646 | f1: 0.18538756668567657 | auc: 0.5635009407997131 | treshold: 0.060000000000000005
Test: accuracy: 0.14033803343772888 | f1: 0.18594016134738922 | auc: 0.5686299204826355 | 


train... loss:33.899354457855225:   0%|                                                                                                 | 4/5000 [02:29<46:49:49, 33.74s/it]

Val update: epoch: 3 |accuracy: 0.1623453050851822 | f1: 0.187079057097435 | auc: 0.575305700302124 | treshold: 0.06999999999999999
Test: accuracy: 0.16443988680839539 | f1: 0.18724022805690765 | auc: 0.5816481113433838 | 


train... loss:33.685514718294144:   0%|                                                                                                 | 5/5000 [02:59<44:44:16, 32.24s/it]

Val update: epoch: 4 |accuracy: 0.20350871980190277 | f1: 0.18769745528697968 | auc: 0.5903298854827881 | treshold: 0.06999999999999999
Test: accuracy: 0.21306797862052917 | f1: 0.18951302766799927 | auc: 0.5929321050643921 | 


train... loss:33.4452702999115:   0%|                                                                                                   | 6/5000 [03:28<43:22:13, 31.26s/it]

Val update: epoch: 5 |accuracy: 0.8530537486076355 | f1: 0.24673639237880707 | auc: 0.6116593480110168 | treshold: 0.15000000000000002
Test: accuracy: 0.8565595746040344 | f1: 0.25266849994659424 | auc: 0.6111446619033813 | 


train... loss:33.06515821814537:   0%|▏                                                                                                 | 7/5000 [03:58<42:38:14, 30.74s/it]

Val update: epoch: 6 |accuracy: 0.8904939889907837 | f1: 0.16962167620658875 | auc: 0.627440333366394 | treshold: 0.14
Test: accuracy: 0.8935245275497437 | f1: 0.18302267789840698 | auc: 0.629646897315979 | 


train... loss:32.747941970825195:   0%|▏                                                                                                | 8/5000 [04:27<41:58:48, 30.27s/it]

Val update: epoch: 7 |accuracy: 0.7757755517959595 | f1: 0.24318890273571014 | auc: 0.6329869031906128 | treshold: 0.15000000000000002
Test: accuracy: 0.7779322266578674 | f1: 0.24052850902080536 | auc: 0.6328929662704468 | 


train... loss:32.65600147843361:   0%|▏                                                                                                 | 9/5000 [04:57<41:38:47, 30.04s/it]

Val update: epoch: 8 |accuracy: 0.8873069286346436 | f1: 0.2204594612121582 | auc: 0.6370352506637573 | treshold: 0.15000000000000002
Test: accuracy: 0.8896758556365967 | f1: 0.23562809824943542 | auc: 0.6367872357368469 | 


train... loss:32.56835728883743:   0%|▏                                                                                                | 10/5000 [05:26<41:28:11, 29.92s/it]

Val update: epoch: 9 |accuracy: 0.8722802400588989 | f1: 0.25915688276290894 | auc: 0.6393176317214966 | treshold: 0.14
Test: accuracy: 0.8766664266586304 | f1: 0.2689104676246643 | auc: 0.6392389535903931 | 


train... loss:32.46553173661232:   0%|▏                                                                                                | 11/5000 [05:56<41:23:38, 29.87s/it]

Val update: epoch: 10 |accuracy: 0.8077799677848816 | f1: 0.2561235725879669 | auc: 0.6421184539794922 | treshold: 0.15000000000000002
Test: accuracy: 0.8105802536010742 | f1: 0.25771304965019226 | auc: 0.6417320966720581 | 


train... loss:32.618060022592545:   0%|▏                                                                                               | 13/5000 [06:53<40:32:05, 29.26s/it]

Val update: epoch: 12 |accuracy: 0.8852368593215942 | f1: 0.24108725786209106 | auc: 0.6449704766273499 | treshold: 0.14
Test: accuracy: 0.8872759342193604 | f1: 0.2589706480503082 | auc: 0.6470450162887573 | 


train... loss:32.12242129445076:   0%|▎                                                                                                | 14/5000 [07:23<40:47:11, 29.45s/it]

Val update: epoch: 13 |accuracy: 0.7985613942146301 | f1: 0.25917404890060425 | auc: 0.6469675302505493 | treshold: 0.14
Test: accuracy: 0.8020926117897034 | f1: 0.25806450843811035 | auc: 0.6473703384399414 | 


train... loss:32.17801457643509:   0%|▎                                                                                                | 15/5000 [07:52<40:48:06, 29.47s/it]

Val update: epoch: 14 |accuracy: 0.8523687124252319 | f1: 0.27573609352111816 | auc: 0.6506842374801636 | treshold: 0.16
Test: accuracy: 0.8560473918914795 | f1: 0.280237078666687 | auc: 0.6497970223426819 | 


train... loss:32.10946875810623:   0%|▎                                                                                                | 17/5000 [08:50<40:22:19, 29.17s/it]

Val update: epoch: 16 |accuracy: 0.8886175155639648 | f1: 0.23924320936203003 | auc: 0.654229462146759 | treshold: 0.14
Test: accuracy: 0.8914904594421387 | f1: 0.2647496163845062 | auc: 0.6553739905357361 | 


train... loss:31.971059888601303:   0%|▎                                                                                               | 18/5000 [09:20<40:44:29, 29.44s/it]

Val update: epoch: 17 |accuracy: 0.8778947591781616 | f1: 0.26552003622055054 | auc: 0.6552233695983887 | treshold: 0.14
Test: accuracy: 0.882373571395874 | f1: 0.28576505184173584 | auc: 0.6563026309013367 | 


train... loss:32.14436003565788:   0%|▍                                                                                                | 20/5000 [10:17<40:09:32, 29.03s/it]

Val update: epoch: 19 |accuracy: 0.8344081044197083 | f1: 0.27906373143196106 | auc: 0.6561322808265686 | treshold: 0.14
Test: accuracy: 0.8355015516281128 | f1: 0.2791920602321625 | auc: 0.6559827327728271 | 


train... loss:31.759607285261154:   0%|▍                                                                                               | 23/5000 [11:40<39:17:39, 28.42s/it]

Val update: epoch: 22 |accuracy: 0.8705080151557922 | f1: 0.27764391899108887 | auc: 0.6576802730560303 | treshold: 0.13
Test: accuracy: 0.8750566840171814 | f1: 0.29109930992126465 | auc: 0.6613082885742188 | 


train... loss:31.60437723994255:   0%|▍                                                                                                | 24/5000 [12:10<39:38:37, 28.68s/it]

Val update: epoch: 23 |accuracy: 0.8552429676055908 | f1: 0.28487345576286316 | auc: 0.6586943864822388 | treshold: 0.14
Test: accuracy: 0.8585497736930847 | f1: 0.28769344091415405 | auc: 0.6610957384109497 | 


train... loss:31.62145361304283:   0%|▍                                                                                                | 25/5000 [12:38<39:44:32, 28.76s/it]

Val update: epoch: 24 |accuracy: 0.829746663570404 | f1: 0.282360315322876 | auc: 0.6598076224327087 | treshold: 0.13
Test: accuracy: 0.831023633480072 | f1: 0.2811430096626282 | auc: 0.6622660160064697 | 


train... loss:31.637040376663208:   1%|▍                                                                                               | 26/5000 [13:08<40:01:37, 28.97s/it]

Val update: epoch: 25 |accuracy: 0.8284807801246643 | f1: 0.2818482220172882 | auc: 0.6601201891899109 | treshold: 0.14
Test: accuracy: 0.8324723839759827 | f1: 0.28324568271636963 | auc: 0.6624599695205688 | 


train... loss:31.68530911207199:   1%|▌                                                                                                | 28/5000 [14:04<39:27:51, 28.57s/it]

Val update: epoch: 27 |accuracy: 0.8314444422721863 | f1: 0.27864882349967957 | auc: 0.6604750156402588 | treshold: 0.13
Test: accuracy: 0.8299846053123474 | f1: 0.2743285298347473 | auc: 0.661458432674408 | 


train... loss:31.58834660053253:   1%|▋                                                                                                | 34/5000 [16:47<38:04:13, 27.60s/it]

Val update: epoch: 33 |accuracy: 0.850641131401062 | f1: 0.285734623670578 | auc: 0.6610395908355713 | treshold: 0.14
Test: accuracy: 0.8548620939254761 | f1: 0.29025334119796753 | auc: 0.6674587726593018 | 


train... loss:31.468235284090042:   1%|▋                                                                                               | 37/5000 [18:10<38:21:52, 27.83s/it]

Val update: epoch: 36 |accuracy: 0.8562705516815186 | f1: 0.2885366678237915 | auc: 0.6625367403030396 | treshold: 0.13
Test: accuracy: 0.8590326905250549 | f1: 0.2917432487010956 | auc: 0.6665442585945129 | 


train... loss:31.412882566452026:   1%|▊                                                                                               | 41/5000 [19:59<38:14:40, 27.76s/it]

Val update: epoch: 40 |accuracy: 0.8656827807426453 | f1: 0.28787997364997864 | auc: 0.6629058122634888 | treshold: 0.13
Test: accuracy: 0.8699787855148315 | f1: 0.2959029972553253 | auc: 0.6672337651252747 | 


train... loss:31.366200178861618:   1%|▉                                                                                               | 48/5000 [23:09<37:57:15, 27.59s/it]

Val update: epoch: 47 |accuracy: 0.8581768274307251 | f1: 0.28916922211647034 | auc: 0.6636558771133423 | treshold: 0.12
Test: accuracy: 0.863057017326355 | f1: 0.2934158742427826 | auc: 0.6681876182556152 | 


train... loss:31.338451832532883:   1%|▉                                                                                               | 49/5000 [23:38<38:29:02, 27.98s/it]

Val update: epoch: 48 |accuracy: 0.8581470251083374 | f1: 0.2914527952671051 | auc: 0.665249228477478 | treshold: 0.14
Test: accuracy: 0.8635692000389099 | f1: 0.29749077558517456 | auc: 0.6717571020126343 | 


train... loss:31.401248902082443:   1%|▉                                                                                               | 51/5000 [24:58<40:23:58, 29.39s/it]



Evaluating SCOT with neural embeddings
Test before learning: {'f1': 0.1609901487827301, 'roc-auc': 0.48149383068084717, 'accuracy': 0.4643155038356781}


train... loss:48.67877772450447:   0%|                                                                                                 | 1/5000 [01:18<108:44:06, 78.30s/it]

Val update: epoch: 0 |accuracy: 0.23847676813602448 | f1: 0.18770453333854675 | auc: 0.5698308944702148 | treshold: 0.03
Test: accuracy: 0.26153507828712463 | f1: 0.1897789090871811 | auc: 0.5711350440979004 | 


train... loss:34.91371810436249:   0%|                                                                                                 | 2/5000 [02:37<109:15:33, 78.70s/it]

Val update: epoch: 1 |accuracy: 0.8692272305488586 | f1: 0.14953994750976562 | auc: 0.613926351070404 | treshold: 0.11
Test: accuracy: 0.874090850353241 | f1: 0.16238318383693695 | auc: 0.6167601943016052 | 


train... loss:33.7190118432045:   0%|                                                                                                  | 3/5000 [03:56<109:19:35, 78.76s/it]

Val update: epoch: 2 |accuracy: 0.7686716914176941 | f1: 0.24114514887332916 | auc: 0.6499383449554443 | treshold: 0.13
Test: accuracy: 0.7750201225280762 | f1: 0.24733182787895203 | auc: 0.6545760631561279 | 


train... loss:32.621150225400925:   0%|                                                                                                | 4/5000 [05:15<109:46:42, 79.10s/it]

Val update: epoch: 3 |accuracy: 0.6269974708557129 | f1: 0.2492656260728836 | auc: 0.6793665885925293 | treshold: 0.12
Test: accuracy: 0.6274822354316711 | f1: 0.2478874921798706 | auc: 0.6794295907020569 | 


train... loss:31.720140397548676:   0%|                                                                                                | 6/5000 [07:50<108:31:42, 78.23s/it]

Val update: epoch: 5 |accuracy: 0.8003038167953491 | f1: 0.27530670166015625 | auc: 0.689855694770813 | treshold: 0.13
Test: accuracy: 0.8040096759796143 | f1: 0.2790547311306 | auc: 0.6936919093132019 | 


train... loss:31.503081381320953:   0%|▏                                                                                               | 7/5000 [09:09<108:58:12, 78.57s/it]

Val update: epoch: 6 |accuracy: 0.7482538223266602 | f1: 0.28012946248054504 | auc: 0.6982169151306152 | treshold: 0.17
Test: accuracy: 0.7477281093597412 | f1: 0.27945664525032043 | auc: 0.700057327747345 | 


train... loss:31.189642757177353:   0%|▏                                                                                               | 8/5000 [10:28<108:52:24, 78.51s/it]

Val update: epoch: 7 |accuracy: 0.8758693337440491 | f1: 0.24096165597438812 | auc: 0.6994113922119141 | treshold: 0.15000000000000002
Test: accuracy: 0.8793883323669434 | f1: 0.25961193442344666 | auc: 0.7036441564559937 | 


train... loss:31.11687135696411:   0%|▏                                                                                                | 9/5000 [11:47<109:02:18, 78.65s/it]

Val update: epoch: 8 |accuracy: 0.7227575182914734 | f1: 0.28084680438041687 | auc: 0.7027031779289246 | treshold: 0.15000000000000002
Test: accuracy: 0.7252652645111084 | f1: 0.2802484333515167 | auc: 0.7032586336135864 | 


train... loss:30.707155615091324:   0%|▏                                                                                              | 11/5000 [14:22<108:25:22, 78.24s/it]

Val update: epoch: 10 |accuracy: 0.8753183484077454 | f1: 0.2550275921821594 | auc: 0.7114785313606262 | treshold: 0.15000000000000002
Test: accuracy: 0.8766371607780457 | f1: 0.2708873748779297 | auc: 0.71410071849823 | 


train... loss:30.077650368213654:   0%|▎                                                                                              | 15/5000 [19:30<107:09:55, 77.39s/it]

Val update: epoch: 14 |accuracy: 0.8111903667449951 | f1: 0.3022564649581909 | auc: 0.7160630226135254 | treshold: 0.16
Test: accuracy: 0.8126143217086792 | f1: 0.3064507246017456 | auc: 0.7150951027870178 | 


train... loss:27.733672067523003:   1%|▊                                                                                              | 42/5000 [55:04<108:21:39, 78.68s/it]


In [7]:
# Persist the ContentWise metrics. The output directory is created first so
# that a missing `results/` folder cannot discard hours of training.
os.makedirs('results', exist_ok=True)
pd.DataFrame(content_wise_results).to_csv(f'results/cw_{experiment_name}.csv')
# Drop the references to the ContentWise data before the next dataset is loaded.
del dataset, train_loader, val_loader, test_loader, train_user_item_matrix, train_num_items

# RL4RS

In [8]:
# One metrics entry per embedding type is appended here by the evaluation loop.
rl4rs_results = []

# Load the pickled RL4RS dataset and split it; the loaders use batches of 350
# data points.
dataset = RL4RS.load(os.path.join(pkl_path, 'rl4rs.pkl'))
train_loader, val_loader, test_loader, train_user_item_matrix, train_num_items = \
    get_train_val_test_tmatrix_tnumitems(dataset, batch_size=350)

print(f"{len(dataset)} data points among {len(train_loader)} batches")

45942 data points among 106 batches


In [None]:
# Train and evaluate one SCOT model per embedding type on RL4RS
# (embedding size fixed to 40).
for embeddings in ('neural', 'explicit', 'svd'):
    print(f"\nEvaluating {experiment_name} with {embeddings} embeddings")

    embedding_layer = RecsysEmbedding(
        train_num_items,
        train_user_item_matrix,
        embeddings=embeddings,
        embedding_dim=40,
    )
    model = SCOT(embedding_layer, output_dim=1).to(device)

    best_model, metrics = train(
        model,
        train_loader,
        val_loader,
        test_loader,
        device=device,
        lr=1e-3,
        num_epochs=5000,
        early_stopping=7,
        silent=True,
    )

    # remember which embedding type produced these metrics
    metrics.update(embeddings=embeddings)
    rl4rs_results.append(metrics)


Evaluating SCOT with neural embeddings
Test before learning: {'f1': 0.6153877973556519, 'roc-auc': 0.5364928245544434, 'accuracy': 0.5471889972686768}


train... loss:59.57800215482712:   0%|                                                                                                | 1/5000 [02:16<189:52:07, 136.73s/it]

Val update: epoch: 0 |accuracy: 0.790862500667572 | f1: 0.8524629473686218 | auc: 0.8703110814094543 | treshold: 0.4
Test: accuracy: 0.7901341915130615 | f1: 0.8508327007293701 | auc: 0.8719138503074646 | 


train... loss:43.483878791332245:   0%|                                                                                               | 2/5000 [03:39<145:50:39, 105.05s/it]

Val update: epoch: 1 |accuracy: 0.8391621708869934 | f1: 0.8847926259040833 | auc: 0.9085134267807007 | treshold: 0.55
Test: accuracy: 0.8343126773834229 | f1: 0.8802516460418701 | auc: 0.9070823192596436 | 


train... loss:39.53883966803551:   0%|                                                                                                 | 3/5000 [05:02<131:59:00, 95.09s/it]

Val update: epoch: 2 |accuracy: 0.8462729454040527 | f1: 0.8891910910606384 | auc: 0.9174481630325317 | treshold: 0.53
Test: accuracy: 0.8415669202804565 | f1: 0.8847939372062683 | auc: 0.9156161546707153 | 


train... loss:37.24711927771568:   0%|                                                                                                 | 4/5000 [06:26<125:27:20, 90.40s/it]

Val update: epoch: 3 |accuracy: 0.8554152846336365 | f1: 0.8948941826820374 | auc: 0.9224686622619629 | treshold: 0.51
Test: accuracy: 0.8480715751647949 | f1: 0.8886328339576721 | auc: 0.9210187792778015 | 


train... loss:36.19328734278679:   0%|                                                                                                 | 5/5000 [07:48<121:34:23, 87.62s/it]

Val update: epoch: 4 |accuracy: 0.859768807888031 | f1: 0.8973259925842285 | auc: 0.9264046549797058 | treshold: 0.49
Test: accuracy: 0.8543586134910583 | f1: 0.8923907279968262 | auc: 0.9249786734580994 | 


train... loss:34.7994818687439:   0%|                                                                                                  | 6/5000 [09:12<119:29:30, 86.14s/it]

Val update: epoch: 5 |accuracy: 0.8647269606590271 | f1: 0.8971061706542969 | auc: 0.9276810884475708 | treshold: 0.47000000000000003
Test: accuracy: 0.8596783876419067 | f1: 0.8922156691551208 | auc: 0.9245980381965637 | 


train... loss:35.1882840692997:   0%|▏                                                                                                 | 8/5000 [11:55<116:24:52, 83.95s/it]

Val update: epoch: 7 |accuracy: 0.853939950466156 | f1: 0.8955786228179932 | auc: 0.9290846586227417 | treshold: 0.47000000000000003
Test: accuracy: 0.8496191501617432 | f1: 0.8915208578109741 | auc: 0.9261579513549805 | 


train... loss:34.34776130318642:   0%|▏                                                                                                | 9/5000 [13:18<115:38:01, 83.41s/it]

Val update: epoch: 8 |accuracy: 0.8476031422615051 | f1: 0.8921043276786804 | auc: 0.930382251739502 | treshold: 0.47000000000000003
Test: accuracy: 0.8420746922492981 | f1: 0.8873091340065002 | auc: 0.9272301197052002 | 


train... loss:34.07372435927391:   0%|▏                                                                                               | 10/5000 [14:41<115:49:34, 83.56s/it]

Val update: epoch: 9 |accuracy: 0.8524887561798096 | f1: 0.8944133520126343 | auc: 0.9308153390884399 | treshold: 0.45
Test: accuracy: 0.848458468914032 | f1: 0.8904274702072144 | auc: 0.9278963804244995 | 


train... loss:33.853845089673996:   0%|▏                                                                                              | 11/5000 [16:04<115:33:15, 83.38s/it]

Val update: epoch: 10 |accuracy: 0.8641222715377808 | f1: 0.9003087878227234 | auc: 0.9330757856369019 | treshold: 0.46
Test: accuracy: 0.859412431716919 | f1: 0.8959183692932129 | auc: 0.9289408922195435 | 


train... loss:33.60004737973213:   0%|▏                                                                                               | 12/5000 [17:26<114:38:43, 82.74s/it]

In [None]:
# Persist the RL4RS metrics. The output directory is created first so that a
# missing `results/` folder cannot discard hours of training.
os.makedirs('results', exist_ok=True)
pd.DataFrame(rl4rs_results).to_csv(f'results/rl4rs_{experiment_name}.csv')
# Drop the references to the RL4RS data once the results are written.
del dataset, train_loader, val_loader, test_loader, train_user_item_matrix, train_num_items