In [2]:
import os
import torch
import random
import datetime
import pandas as pd
import numpy as np
import torch.nn.functional as F

from src.datasets import RL4RS, ContentWise, DummyData
from src.utils import train, get_dummy_data, get_train_val_test_tmatrix_tnumitems, get_svd_encoder
from src.embeddings import RecsysEmbedding
from torch.utils.data import Dataset

# --- Experiment configuration -------------------------------------------
seed = 123                          # shared seed for every RNG below
experiment_name = 'AttentionGRU2'   # used in log lines and result file names
pkl_path = '../pkl/'                # directory holding the pickled datasets
device = 'cuda:2'                   # device string handed to train() / .to()

# Seed the Python and NumPy generators, then torch; torch.manual_seed stays
# the final expression of the cell.
for seed_everything in (random.seed, np.random.seed):
    seed_everything(seed)
torch.manual_seed(seed)

<torch._C.Generator at 0x7f5fcdfb93b0>

In [3]:
torch.__version__

'1.12.1'

# Модель

In [None]:
class SlatewiseAttentionLayer(torch.nn.Module):
    """
    Attention-based encoder that operates independently on each slate.

    The items of every slate attend to each other through multi-head
    self-attention. The attended features are concatenated with the raw
    inputs, layer-normalised, and projected to ``out_dim`` by a two-layer
    feed-forward network followed by a second layer norm.

    Args:
        in_dim: size of the last (embedding) axis of the input.
        nheads: number of attention heads (``num_heads`` of
            ``torch.nn.MultiheadAttention``).
        out_dim: size of the last axis of the output; defaults to ``in_dim``.
    """
    def __init__(self, in_dim, nheads=2, out_dim=None):
        super().__init__()
        if out_dim is None:
            out_dim = in_dim
        self.out_dim = out_dim

        self.attention = torch.nn.MultiheadAttention(
            in_dim,
            num_heads=nheads,
            batch_first=True
        )

        # lnorm1 normalises the [attended, raw] concatenation, hence 2 * in_dim.
        self.lnorm1 = torch.nn.LayerNorm(2 * in_dim)
        self.lnorm2 = torch.nn.LayerNorm(out_dim)

        self.ff = torch.nn.Sequential(
            torch.nn.Linear(2 * in_dim, in_dim),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(in_dim, out_dim),
        )

    def forward(self, input, mask):
        """
        Encode every slate independently.

        Args:
            input: tensor laid out as (batch, session, slate, embedding).
            mask: boolean tensor; it is inverted to build the attention
                ``key_padding_mask``, so ``True`` marks positions that may be
                attended to. Assumed to be (batch, session, slate) - TODO confirm.

        Returns:
            Tensor with the shape of ``input`` except that the last axis has
            size ``out_dim``. ``mask`` is not modified.
        """
        # batch, session, slate, embedding
        out_shape = list(input.shape)
        out_shape[-1] = self.out_dim

        # Work on a copy so the caller's mask is left untouched.
        key_padding_mask = mask.clone().flatten(0, 1)
        # Let the model attend to the first padding token if a slate is empty;
        # otherwise every key of that slate would be masked out.
        key_padding_mask[:, 0] = True

        flat_input = input.flatten(0, 1)
        attended, _ = self.attention(
            flat_input, flat_input, flat_input,
            key_padding_mask=~key_padding_mask
        )

        features = torch.cat([attended, flat_input], dim=-1)
        features = self.lnorm1(features)
        features = self.lnorm2(self.ff(features))

        return features.reshape(out_shape)

In [17]:
# torch.autograd.set_detect_anomaly(True)

class AttentionGRU2(torch.nn.Module):
    """
    Response model that combines in-slate self-attention with a GRU.

    For every item of every slate the model concatenates
      * attention features computed over [item embedding, user embedding]
        within the slate (``2 * embedding_dim`` values), and
      * GRU outputs computed over the item embeddings of the whole session,
        with all slates flattened into one sequence (``embedding_dim`` values),
    and maps the result to ``output_dim`` scores with a small MLP.

    Args:
        embedding: module that exposes ``embedding_dim`` and, when called on a
            batch, returns ``(item_embs, user_embs)``.
        nheads: number of attention heads.
        output_dim: number of outputs per item.
    """
    def __init__(self, embedding, nheads=2, output_dim=1):
        super().__init__()
        self.embedding_dim = embedding.embedding_dim
        self.embedding = embedding
        self.attention = torch.nn.MultiheadAttention(
            2 * embedding.embedding_dim,
            num_heads=nheads,
            batch_first=True
        )

        # A full GRU (not a GRUCell) is required here: forward() feeds it a
        # batched sequence and unpacks the (output, h_n) pair it returns.
        self.rnn_layer = torch.nn.GRU(
            input_size=embedding.embedding_dim,
            hidden_size=embedding.embedding_dim,
            batch_first=True,
        )

        # Not used in forward(); kept so the module's parameter set is unchanged.
        self.hidden_encode_layer = torch.nn.Linear(2 * embedding.embedding_dim, embedding.embedding_dim)

        self.out_layer = torch.nn.Sequential(
            torch.nn.Linear(3 * embedding.embedding_dim, embedding.embedding_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(embedding.embedding_dim, output_dim),
        )

    def get_attention_embeddings(self, item_embs, user_embs, slate_mask):
        """
        Self-attention inside each slate over [item, user] features.

        Args:
            item_embs: (batch, sequence, slate, embedding) item embeddings.
            user_embs: (batch, sequence, embedding) user embeddings.
            slate_mask: boolean mask, True for real items and False for
                placeholders. It is not modified.

        Returns:
            Tensor shaped like ``item_embs`` with the last axis doubled.
        """
        out_shape = list(item_embs.shape)
        out_shape[-1] *= 2

        # Clone so that the caller's mask is never changed in place.
        key_padding_mask = slate_mask.clone()
        # Let the model attend to the first padding token if a slate is empty.
        key_padding_mask[:, :, 0] = True

        # Broadcast the per-step user embedding to every item of the slate.
        user_per_item = user_embs[:, :, None, :].expand(-1, -1, item_embs.size(-2), -1)
        features = torch.cat([item_embs, user_per_item], dim=-1).flatten(0, 1)

        features, _ = self.attention(
            features, features, features,
            key_padding_mask=~key_padding_mask.flatten(0, 1)
        )
        return features.reshape(out_shape)

    def forward(self, batch):
        """
        Score every item of every slate in ``batch``.

        Args:
            batch: input accepted by ``self.embedding``; must also provide
                ``batch['slates_mask']`` (True for real items, False for
                placeholders).

        Returns:
            Tensor of per-item outputs; the trailing axis is squeezed away
            when ``output_dim`` is 1.
        """
        # item embs dimensions: batch, sequence, slate, embedding
        # user embs dimensions: batch, sequence, embedding
        item_embs, user_embs = self.embedding(batch)

        # item-level attention features: batch, sequence, slate, 2 * embedding
        att_features = self.get_attention_embeddings(
            item_embs, user_embs, batch['slates_mask']
        )

        # Run the GRU over all items of the session as one long sequence
        # (sequence and slate axes flattened), then restore the layout.
        gru_features, _ = self.rnn_layer(item_embs.flatten(-3, -2))
        gru_features = gru_features.reshape(item_embs.shape)

        features = torch.cat([att_features, gru_features], dim=-1)

        return self.out_layer(features).squeeze(-1)


# Smoke test: build the model on the dummy dataset and print the output of a
# single forward pass on the first batch.
d = DummyData()
dummy_loader, dummy_matrix = get_dummy_data(d)

# AttentionGRU2 is the model defined in this notebook; the previously used
# name `AttentionResponseModel` is not defined here.
model = AttentionGRU2(
    RecsysEmbedding(d.n_items, dummy_matrix, embeddings='neural').to('cpu'),
    output_dim=1
).to('cpu')

for batch in dummy_loader:
    print(model(batch))
    break

biulding affinity matrix...


3it [00:00, 5286.94it/s]

tensor([[[-0.0397,  0.0249,  0.0446],
         [ 0.0463,  0.1012,  0.0272]],

        [[-0.1109, -0.1047, -0.1453],
         [-0.0338, -0.0430, -0.0457]]], grad_fn=<SqueezeBackward1>)





# Игрушечный датасет: проверим, что сходится к идеальным метрикам

In [21]:
# Sanity check on the toy dataset: the model is expected to reach perfect
# metrics when trained and evaluated on the same data.
d = DummyData()
dummy_loader, dummy_matrix = get_dummy_data(d)

dummy_embedding = RecsysEmbedding(d.n_items, dummy_matrix, embeddings='neural')
model = AttentionGRU2(dummy_embedding.to('cpu'), output_dim=1).to('cpu')

# The same loader is passed three times (train / val / test positions).
train(
    model,
    dummy_loader,
    dummy_loader,
    dummy_loader,
    silent=True,
    dummy=True,
    num_epochs=5000,
    lr=1e-3,
    device=device,
)


biulding affinity matrix...


3it [00:00, 3181.52it/s]


Test before learning: {'f1': 0.4000000059604645, 'roc-auc': 0.3333333134651184, 'accuracy': 0.25}


train... loss:0.6785470247268677:   0%|                                                                                                    | 1/5000 [00:00<39:38,  2.10it/s]

Val update: epoch: 0 |accuracy: 0.25 | f1: 0.4000000059604645 | auc: 1.0 | treshold: 0.01
Test: accuracy: 0.25 | f1: 0.4000000059604645 | auc: 1.0 | 


train... loss:0.6630663275718689:   0%|                                                                                                    | 2/5000 [00:01<39:54,  2.09it/s]

Val update: epoch: 1 |accuracy: 0.75 | f1: 0.6666666865348816 | auc: 1.0 | treshold: 0.52
Test: accuracy: 0.75 | f1: 0.6666666865348816 | auc: 1.0 | 


train... loss:0.611418604850769:   0%|                                                                                                     | 6/5000 [00:02<38:05,  2.18it/s]

Val update: epoch: 6 |accuracy: 1.0 | f1: 1.0 | auc: 1.0 | treshold: 0.5800000000000001
Test: accuracy: 1.0 | f1: 1.0 | auc: 1.0 | 





(AttentionResponseModel(
   (embedding): RecsysEmbedding(
     (item_embeddings): Embedding(5, 32)
   )
   (attention): MultiheadAttention(
     (out_proj): NonDynamicallyQuantizableLinear(in_features=64, out_features=64, bias=True)
   )
   (rnn_layer): GRU(32, 32, batch_first=True)
   (hidden_encode_layer): Linear(in_features=64, out_features=32, bias=True)
   (out_layer): Sequential(
     (0): Linear(in_features=96, out_features=32, bias=True)
     (1): ReLU()
     (2): Linear(in_features=32, out_features=1, bias=True)
   )
 ),
 {'f1': 1.0, 'roc-auc': 1.0, 'accuracy': 1.0})

# ContentWise

In [22]:
# Collected metrics, one dict per embedding type evaluated on ContentWise.
content_wise_results = []

# Load the pickled ContentWise dataset and split it into loaders.
cw_pickle = os.path.join(pkl_path, 'cw.pkl')
dataset = ContentWise.load(cw_pickle)
cw_split = get_train_val_test_tmatrix_tnumitems(dataset, batch_size=150)
train_loader, val_loader, test_loader, train_user_item_matrix, train_num_items = cw_split

print(f"{len(dataset)} data points among {len(train_loader)} batches")

20216 data points among 108 batches


In [23]:
# Train and evaluate one model per embedding type on ContentWise.
for embeddings in ('svd', 'neural'):
    print(f"\nEvaluating {experiment_name} with {embeddings} embeddings")

    cw_embedding = RecsysEmbedding(
        train_num_items,
        train_user_item_matrix,
        embeddings=embeddings,
    )
    model = AttentionGRU2(cw_embedding, output_dim=1).to(device)

    # train() returns a pair; only the metrics dict is kept here.
    _, metrics = train(
        model,
        train_loader,
        val_loader,
        test_loader,
        silent=True,
        early_stopping=7,
        num_epochs=5000,
        lr=1e-3,
        device=device,
    )

    # Tag the metrics with the embedding type before storing them.
    metrics['embeddings'] = embeddings
    content_wise_results.append(metrics)


Evaluating SlatewiseAttentionSeqwiseGRU with svd embeddings
Test before learning: {'f1': 0.1819256842136383, 'roc-auc': 0.49667614698410034, 'accuracy': 0.10006504505872726}


train:   0%|                                                                                                                                       | 0/5000 [00:04<?, ?it/s]


KeyboardInterrupt: 

In [7]:
# Make sure the output directory exists; to_csv does not create it.
os.makedirs('results', exist_ok=True)
pd.DataFrame(content_wise_results).to_csv(f'results/cw_{experiment_name}.csv')
# Drop the ContentWise objects before the next dataset is loaded.
del dataset, train_loader, val_loader, test_loader, train_user_item_matrix, train_num_items

# RL4RS

In [8]:
# Collected metrics, one dict per embedding type evaluated on RL4RS.
rl4rs_results = []

# Load the pickled RL4RS dataset and split it into loaders.
rl4rs_pickle = os.path.join(pkl_path, 'rl4rs.pkl')
dataset = RL4RS.load(rl4rs_pickle)
rl4rs_split = get_train_val_test_tmatrix_tnumitems(dataset, batch_size=350)
train_loader, val_loader, test_loader, train_user_item_matrix, train_num_items = rl4rs_split

print(f"{len(dataset)} data points among {len(train_loader)} batches")

45942 data points among 106 batches


In [9]:
# Train and evaluate one model per embedding type on RL4RS.
for embeddings in ('neural', 'explicit', 'svd'):
    print(f"\nEvaluating {experiment_name} with {embeddings} embeddings")

    rl4rs_embedding = RecsysEmbedding(
        train_num_items,
        train_user_item_matrix,
        embeddings=embeddings,
        embedding_dim=40,
    )
    model = AttentionGRU2(rl4rs_embedding, output_dim=1).to(device)

    # train() returns a pair, unpacked here as (best_model, metrics).
    best_model, metrics = train(
        model,
        train_loader,
        val_loader,
        test_loader,
        silent=True,
        early_stopping=7,
        num_epochs=5000,
        lr=1e-3,
        device=device,
    )

    # Tag the metrics with the embedding type before storing them.
    metrics['embeddings'] = embeddings
    rl4rs_results.append(metrics)


Evaluating InSlateAttentionSequencewiseGRU with neural embeddings
Test before learning: {'f1': 0.7070940136909485, 'roc-auc': 0.5168448090553284, 'accuracy': 0.5949462056159973}


train... loss:58.82067450881004:   0%|                                                                                                  | 1/5000 [01:03<88:29:15, 63.72s/it]

Val update: epoch: 0 |accuracy: 0.7771247625350952 | f1: 0.840562641620636 | auc: 0.861542820930481 | treshold: 0.4
Test: accuracy: 0.7778019309043884 | f1: 0.8418716192245483 | auc: 0.8600804805755615 | 


train... loss:49.33029007911682:   0%|                                                                                                  | 2/5000 [01:11<42:29:25, 30.61s/it]

Val update: epoch: 1 |accuracy: 0.8133314251899719 | f1: 0.8602874875068665 | auc: 0.8926573395729065 | treshold: 0.45
Test: accuracy: 0.8155966401100159 | f1: 0.8628565073013306 | auc: 0.891123354434967 | 


train... loss:46.4336553812027:   0%|                                                                                                   | 3/5000 [01:18<27:48:05, 20.03s/it]

Val update: epoch: 2 |accuracy: 0.8277463316917419 | f1: 0.8691049218177795 | auc: 0.9064161777496338 | treshold: 0.49
Test: accuracy: 0.8263813257217407 | f1: 0.8688344955444336 | auc: 0.903417706489563 | 


train... loss:44.79352426528931:   0%|                                                                                                  | 4/5000 [01:25<20:52:14, 15.04s/it]

Val update: epoch: 3 |accuracy: 0.833720326423645 | f1: 0.8711074590682983 | auc: 0.9127175807952881 | treshold: 0.49
Test: accuracy: 0.8321363925933838 | f1: 0.8707407116889954 | auc: 0.9096537828445435 | 


train... loss:43.563194423913956:   0%|                                                                                                 | 5/5000 [01:33<17:05:49, 12.32s/it]

Val update: epoch: 4 |accuracy: 0.8403231501579285 | f1: 0.8776682615280151 | auc: 0.9174004793167114 | treshold: 0.49
Test: accuracy: 0.8357877135276794 | f1: 0.8750160932540894 | auc: 0.9133666753768921 | 


train... loss:42.870913445949554:   0%|                                                                                                 | 6/5000 [01:40<14:45:51, 10.64s/it]

Val update: epoch: 5 |accuracy: 0.8425482511520386 | f1: 0.879769504070282 | auc: 0.9216941595077515 | treshold: 0.48000000000000004
Test: accuracy: 0.8387619256973267 | f1: 0.8773859143257141 | auc: 0.917554497718811 | 


train... loss:41.868689984083176:   0%|▏                                                                                                | 8/5000 [01:55<12:12:13,  8.80s/it]

Val update: epoch: 7 |accuracy: 0.8484013080596924 | f1: 0.8835420608520508 | auc: 0.9251205921173096 | treshold: 0.5
Test: accuracy: 0.8419296145439148 | f1: 0.8792507648468018 | auc: 0.9210832118988037 | 


train... loss:41.45595145225525:   0%|▏                                                                                                | 10/5000 [02:09<11:00:44,  7.94s/it]

Val update: epoch: 9 |accuracy: 0.8500459790229797 | f1: 0.8834805488586426 | auc: 0.9270545244216919 | treshold: 0.48000000000000004
Test: accuracy: 0.8455325961112976 | f1: 0.8805579543113708 | auc: 0.9228711724281311 | 


train... loss:40.43740478157997:   0%|▏                                                                                                | 11/5000 [02:16<10:48:17,  7.80s/it]

Val update: epoch: 10 |accuracy: 0.8448459506034851 | f1: 0.8851037621498108 | auc: 0.9310590028762817 | treshold: 0.5
Test: accuracy: 0.8389312028884888 | f1: 0.8812168836593628 | auc: 0.9259691834449768 | 


train... loss:39.917251855134964:   0%|▏                                                                                               | 12/5000 [02:24<10:40:22,  7.70s/it]

Val update: epoch: 11 |accuracy: 0.8492478132247925 | f1: 0.8880105018615723 | auc: 0.933652400970459 | treshold: 0.51
Test: accuracy: 0.8440333604812622 | f1: 0.8844624161720276 | auc: 0.9282408952713013 | 


train... loss:39.37145084142685:   0%|▎                                                                                                | 14/5000 [02:38<10:17:43,  7.43s/it]

Val update: epoch: 13 |accuracy: 0.8526096940040588 | f1: 0.8900595307350159 | auc: 0.9336705207824707 | treshold: 0.48000000000000004
Test: accuracy: 0.8444444537162781 | f1: 0.8846036791801453 | auc: 0.9280190467834473 | 


train... loss:39.505471497774124:   0%|▎                                                                                               | 15/5000 [02:46<10:17:57,  7.44s/it]

Val update: epoch: 14 |accuracy: 0.8607845902442932 | f1: 0.8921087384223938 | auc: 0.9348341822624207 | treshold: 0.49
Test: accuracy: 0.8562930822372437 | f1: 0.8890423774719238 | auc: 0.9308009147644043 | 


train... loss:39.19043782353401:   0%|▎                                                                                                | 16/5000 [02:53<10:16:36,  7.42s/it]

Val update: epoch: 15 |accuracy: 0.8583660125732422 | f1: 0.8854549527168274 | auc: 0.9350820779800415 | treshold: 0.47000000000000003
Test: accuracy: 0.8515778183937073 | f1: 0.8805603981018066 | auc: 0.9309139251708984 | 


train... loss:38.162901699543:   0%|▍                                                                                                   | 19/5000 [03:14<9:59:46,  7.22s/it]

Val update: epoch: 18 |accuracy: 0.8513761758804321 | f1: 0.889440655708313 | auc: 0.9358381032943726 | treshold: 0.49
Test: accuracy: 0.8464031219482422 | f1: 0.8861975073814392 | auc: 0.9316917061805725 | 


train... loss:37.74512952566147:   0%|▍                                                                                                | 20/5000 [03:22<10:05:29,  7.30s/it]

Val update: epoch: 19 |accuracy: 0.8519324660301208 | f1: 0.8909745216369629 | auc: 0.9380056858062744 | treshold: 0.48000000000000004
Test: accuracy: 0.8429936170578003 | f1: 0.8850653767585754 | auc: 0.9330083727836609 | 


train... loss:37.396418780088425:   0%|▍                                                                                                | 24/5000 [03:50<9:50:31,  7.12s/it]

Val update: epoch: 23 |accuracy: 0.8626952767372131 | f1: 0.8955877423286438 | auc: 0.9387286901473999 | treshold: 0.44
Test: accuracy: 0.8555918335914612 | f1: 0.8907907009124756 | auc: 0.9337456822395325 | 


train... loss:36.69810311496258:   1%|▌                                                                                                 | 30/5000 [04:31<9:42:28,  7.03s/it]

Val update: epoch: 29 |accuracy: 0.8520050048828125 | f1: 0.8908549547195435 | auc: 0.9391928315162659 | treshold: 0.46
Test: accuracy: 0.8454358577728271 | f1: 0.8865661025047302 | auc: 0.9342323541641235 | 


train... loss:36.550643265247345:   1%|▋                                                                                                | 34/5000 [04:59<9:44:00,  7.06s/it]

Val update: epoch: 33 |accuracy: 0.8663957715034485 | f1: 0.8970555067062378 | auc: 0.9398387670516968 | treshold: 0.46
Test: accuracy: 0.859799325466156 | f1: 0.8927726149559021 | auc: 0.9337024688720703 | 


train... loss:34.9722064435482:   1%|█▎                                                                                                | 64/5000 [08:31<10:57:21,  7.99s/it]



Evaluating InSlateAttentionSequencewiseGRU with explicit embeddings
Test before learning: {'f1': 0.5234761238098145, 'roc-auc': 0.5357589721679688, 'accuracy': 0.496409147977829}


train... loss:63.96340095996857:   0%|                                                                                                   | 1/5000 [00:07<9:49:16,  7.07s/it]

Val update: epoch: 0 |accuracy: 0.6612973213195801 | f1: 0.7894704937934875 | auc: 0.7909042835235596 | treshold: 0.39
Test: accuracy: 0.6642727851867676 | f1: 0.791374921798706 | auc: 0.7856383323669434 | 


train... loss:60.31742602586746:   0%|                                                                                                   | 2/5000 [00:14<9:48:45,  7.07s/it]

Val update: epoch: 1 |accuracy: 0.7500120997428894 | f1: 0.8222588896751404 | auc: 0.8101397752761841 | treshold: 0.41000000000000003
Test: accuracy: 0.7485672831535339 | f1: 0.8212664723396301 | auc: 0.8050750494003296 | 


train... loss:59.327128529548645:   0%|                                                                                                  | 3/5000 [00:21<9:44:51,  7.02s/it]

Val update: epoch: 2 |accuracy: 0.7531805038452148 | f1: 0.8259453177452087 | auc: 0.8169788122177124 | treshold: 0.41000000000000003
Test: accuracy: 0.7507193684577942 | f1: 0.8240935206413269 | auc: 0.8120150566101074 | 


train... loss:58.5383180975914:   0%|                                                                                                    | 4/5000 [00:28<9:44:41,  7.02s/it]

Val update: epoch: 3 |accuracy: 0.7173365950584412 | f1: 0.8145421147346497 | auc: 0.8192075490951538 | treshold: 0.41000000000000003
Test: accuracy: 0.7177124619483948 | f1: 0.8149862289428711 | auc: 0.8123385906219482 | 


train... loss:57.92720726132393:   0%|                                                                                                   | 5/5000 [00:35<9:44:43,  7.02s/it]

Val update: epoch: 4 |accuracy: 0.7289217710494995 | f1: 0.8202405571937561 | auc: 0.8336963057518005 | treshold: 0.42000000000000004
Test: accuracy: 0.7280619144439697 | f1: 0.8197987675666809 | auc: 0.8270795345306396 | 


train... loss:57.576224118471146:   0%|                                                                                                  | 6/5000 [00:42<9:47:45,  7.06s/it]

Val update: epoch: 5 |accuracy: 0.7464809417724609 | f1: 0.8277233839035034 | auc: 0.8378870487213135 | treshold: 0.44
Test: accuracy: 0.7446016073226929 | f1: 0.8264769911766052 | auc: 0.8310642242431641 | 


train... loss:56.74546191096306:   0%|▏                                                                                                  | 7/5000 [00:49<9:47:01,  7.05s/it]

Val update: epoch: 6 |accuracy: 0.7584288716316223 | f1: 0.8315228223800659 | auc: 0.8395500183105469 | treshold: 0.41000000000000003
Test: accuracy: 0.7549993991851807 | f1: 0.8294677734375 | auc: 0.8345728516578674 | 


train... loss:56.47116303443909:   0%|▏                                                                                                  | 8/5000 [00:56<9:45:55,  7.04s/it]

Val update: epoch: 7 |accuracy: 0.7123059034347534 | f1: 0.8140156865119934 | auc: 0.848324716091156 | treshold: 0.42000000000000004
Test: accuracy: 0.7113287448883057 | f1: 0.8138120174407959 | auc: 0.8387751579284668 | 


train... loss:56.28845000267029:   0%|▏                                                                                                  | 9/5000 [01:03<9:45:36,  7.04s/it]

Val update: epoch: 8 |accuracy: 0.7315096855163574 | f1: 0.8231790065765381 | auc: 0.8624886274337769 | treshold: 0.4
Test: accuracy: 0.7310845255851746 | f1: 0.8233219385147095 | auc: 0.8548674583435059 | 


train... loss:54.5392541885376:   0%|▏                                                                                                  | 11/5000 [01:16<9:37:30,  6.95s/it]

Val update: epoch: 10 |accuracy: 0.7870410680770874 | f1: 0.8468669056892395 | auc: 0.8712275624275208 | treshold: 0.43
Test: accuracy: 0.7831217646598816 | f1: 0.8444637060165405 | auc: 0.8614722490310669 | 


train... loss:55.20169559121132:   0%|▏                                                                                                 | 12/5000 [01:24<9:41:09,  6.99s/it]

Val update: epoch: 11 |accuracy: 0.7529144287109375 | f1: 0.8336210250854492 | auc: 0.8764777779579163 | treshold: 0.44
Test: accuracy: 0.7505984902381897 | f1: 0.8324289321899414 | auc: 0.865923285484314 | 


train... loss:53.37023088335991:   0%|▎                                                                                                 | 15/5000 [01:43<9:25:38,  6.81s/it]

Val update: epoch: 14 |accuracy: 0.797392725944519 | f1: 0.8402342200279236 | auc: 0.8784708380699158 | treshold: 0.44
Test: accuracy: 0.7941482067108154 | f1: 0.8385089635848999 | auc: 0.8700953722000122 | 


train... loss:53.183702021837234:   0%|▎                                                                                                | 16/5000 [01:51<9:31:17,  6.88s/it]

Val update: epoch: 15 |accuracy: 0.8047211170196533 | f1: 0.8478555679321289 | auc: 0.8825274705886841 | treshold: 0.45
Test: accuracy: 0.7983798980712891 | f1: 0.8439803719520569 | auc: 0.8720700144767761 | 


train... loss:53.078484654426575:   0%|▎                                                                                                | 17/5000 [01:58<9:36:12,  6.94s/it]

Val update: epoch: 16 |accuracy: 0.8077202439308167 | f1: 0.850670576095581 | auc: 0.8834667205810547 | treshold: 0.43
Test: accuracy: 0.801813542842865 | f1: 0.8473003506660461 | auc: 0.8737298250198364 | 


train... loss:52.66868579387665:   0%|▎                                                                                                 | 18/5000 [02:05<9:40:40,  6.99s/it]

Val update: epoch: 17 |accuracy: 0.8053258061408997 | f1: 0.8569702506065369 | auc: 0.8856683969497681 | treshold: 0.43
Test: accuracy: 0.7991536855697632 | f1: 0.8529911637306213 | auc: 0.8754594922065735 | 


train... loss:52.29195275902748:   0%|▍                                                                                                 | 20/5000 [02:18<9:34:03,  6.92s/it]

Val update: epoch: 19 |accuracy: 0.7713684439659119 | f1: 0.844089686870575 | auc: 0.8874250650405884 | treshold: 0.47000000000000003
Test: accuracy: 0.7657598853111267 | f1: 0.8409594893455505 | auc: 0.876424252986908 | 


train... loss:51.71615615487099:   0%|▍                                                                                                 | 23/5000 [02:38<9:21:07,  6.76s/it]

Val update: epoch: 22 |accuracy: 0.7872829437255859 | f1: 0.8518686890602112 | auc: 0.8882545828819275 | treshold: 0.44
Test: accuracy: 0.7808729410171509 | f1: 0.8483753204345703 | auc: 0.8775864243507385 | 


train... loss:51.17925238609314:   1%|▌                                                                                                 | 29/5000 [03:17<9:11:09,  6.65s/it]

Val update: epoch: 28 |accuracy: 0.8109369874000549 | f1: 0.8615455627441406 | auc: 0.8890084624290466 | treshold: 0.43
Test: accuracy: 0.8020795583724976 | f1: 0.8560777902603149 | auc: 0.8789182901382446 | 


train... loss:49.78277778625488:   1%|█▏                                                                                                | 62/5000 [06:56<9:13:17,  6.72s/it]



Evaluating InSlateAttentionSequencewiseGRU with svd embeddings
Test before learning: {'f1': 0.7786475419998169, 'roc-auc': 0.5453901886940002, 'accuracy': 0.6403337121009827}


train... loss:64.67854821681976:   0%|                                                                                                   | 1/5000 [00:07<9:48:59,  7.07s/it]

Val update: epoch: 0 |accuracy: 0.700696587562561 | f1: 0.7974134683609009 | auc: 0.7855069637298584 | treshold: 0.38
Test: accuracy: 0.7062991261482239 | f1: 0.8014451265335083 | auc: 0.789409339427948 | 


train... loss:56.99614438414574:   0%|                                                                                                   | 2/5000 [00:14<9:49:25,  7.08s/it]

Val update: epoch: 1 |accuracy: 0.7269385457038879 | f1: 0.8043734431266785 | auc: 0.8021373748779297 | treshold: 0.39
Test: accuracy: 0.7371538877487183 | f1: 0.8120352625846863 | auc: 0.8070188760757446 | 


train... loss:53.831738501787186:   0%|                                                                                                  | 3/5000 [00:21<9:47:24,  7.05s/it]

Val update: epoch: 2 |accuracy: 0.7412083148956299 | f1: 0.814545214176178 | auc: 0.8246825933456421 | treshold: 0.44
Test: accuracy: 0.7476242184638977 | f1: 0.8194509148597717 | auc: 0.8282930254936218 | 


train... loss:53.09312617778778:   0%|                                                                                                   | 4/5000 [00:28<9:50:39,  7.09s/it]

Val update: epoch: 3 |accuracy: 0.7833163738250732 | f1: 0.8316894769668579 | auc: 0.8614903688430786 | treshold: 0.45
Test: accuracy: 0.7862410545349121 | f1: 0.8347848653793335 | auc: 0.8629193305969238 | 


train... loss:50.13837909698486:   0%|                                                                                                   | 6/5000 [00:41<9:35:02,  6.91s/it]

Val update: epoch: 5 |accuracy: 0.7811396718025208 | f1: 0.8364096283912659 | auc: 0.8628407716751099 | treshold: 0.44
Test: accuracy: 0.7867972254753113 | f1: 0.8411208391189575 | auc: 0.8645704984664917 | 


train... loss:49.313341706991196:   0%|▏                                                                                                 | 7/5000 [00:48<9:40:09,  6.97s/it]

Val update: epoch: 6 |accuracy: 0.7715135812759399 | f1: 0.839606761932373 | auc: 0.8714889883995056 | treshold: 0.44
Test: accuracy: 0.7721920013427734 | f1: 0.8403788208961487 | auc: 0.8720890283584595 | 


train... loss:48.673452377319336:   0%|▏                                                                                                | 10/5000 [01:08<9:28:42,  6.84s/it]

Val update: epoch: 9 |accuracy: 0.7516567707061768 | f1: 0.8322934508323669 | auc: 0.8777687549591064 | treshold: 0.42000000000000004
Test: accuracy: 0.7531132698059082 | f1: 0.8334692716598511 | auc: 0.8776654601097107 | 


train... loss:47.786136761307716:   0%|▏                                                                                                | 11/5000 [01:16<9:35:50,  6.93s/it]

Val update: epoch: 10 |accuracy: 0.7738112807273865 | f1: 0.8434549570083618 | auc: 0.8859124183654785 | treshold: 0.44
Test: accuracy: 0.7755289673805237 | f1: 0.8451052308082581 | auc: 0.886222243309021 | 


train... loss:47.12998677790165:   0%|▏                                                                                                 | 12/5000 [01:23<9:40:15,  6.98s/it]

Val update: epoch: 11 |accuracy: 0.7822280526161194 | f1: 0.8477356433868408 | auc: 0.8895320296287537 | treshold: 0.44
Test: accuracy: 0.7827832102775574 | f1: 0.8485389947891235 | auc: 0.8892430663108826 | 


train... loss:45.45469534397125:   0%|▎                                                                                                 | 17/5000 [01:56<9:21:29,  6.76s/it]

Val update: epoch: 16 |accuracy: 0.8043583631515503 | f1: 0.8586753606796265 | auc: 0.896442174911499 | treshold: 0.44
Test: accuracy: 0.8060935735702515 | f1: 0.8604347705841064 | auc: 0.8955862522125244 | 


train... loss:44.64242035150528:   0%|▎                                                                                                 | 19/5000 [02:09<9:24:58,  6.81s/it]

Val update: epoch: 18 |accuracy: 0.796715497970581 | f1: 0.8559283018112183 | auc: 0.897294819355011 | treshold: 0.42000000000000004
Test: accuracy: 0.799419641494751 | f1: 0.8579744696617126 | auc: 0.8974224925041199 | 


train... loss:44.82900556921959:   0%|▍                                                                                                 | 21/5000 [02:23<9:25:36,  6.82s/it]

Val update: epoch: 20 |accuracy: 0.7931843400001526 | f1: 0.8544634580612183 | auc: 0.8999576568603516 | treshold: 0.43
Test: accuracy: 0.7925764918327332 | f1: 0.8545855283737183 | auc: 0.8994263410568237 | 


train... loss:44.3938904106617:   0%|▍                                                                                                  | 22/5000 [02:30<9:32:21,  6.90s/it]

Val update: epoch: 21 |accuracy: 0.8019881248474121 | f1: 0.8589397072792053 | auc: 0.9020766019821167 | treshold: 0.46
Test: accuracy: 0.8057550191879272 | f1: 0.8617122173309326 | auc: 0.9013591408729553 | 


train... loss:44.505712270736694:   0%|▍                                                                                                | 23/5000 [02:37<9:35:12,  6.93s/it]

Val update: epoch: 22 |accuracy: 0.7949257493019104 | f1: 0.8559487462043762 | auc: 0.9035002589225769 | treshold: 0.46
Test: accuracy: 0.795599102973938 | f1: 0.8566535115242004 | auc: 0.9021508097648621 | 


train... loss:44.37244117259979:   0%|▍                                                                                                 | 24/5000 [02:44<9:37:32,  6.96s/it]

Val update: epoch: 23 |accuracy: 0.803415060043335 | f1: 0.8601321578025818 | auc: 0.9039633274078369 | treshold: 0.44
Test: accuracy: 0.8030226230621338 | f1: 0.8602408766746521 | auc: 0.903124988079071 | 


train... loss:43.90994915366173:   0%|▍                                                                                                 | 25/5000 [02:51<9:38:49,  6.98s/it]

Val update: epoch: 24 |accuracy: 0.8076960444450378 | f1: 0.8623036742210388 | auc: 0.9058196544647217 | treshold: 0.44
Test: accuracy: 0.807762086391449 | f1: 0.8626610636711121 | auc: 0.904667317867279 | 


train... loss:43.79506251215935:   1%|▌                                                                                                 | 28/5000 [03:11<9:27:55,  6.85s/it]

Val update: epoch: 27 |accuracy: 0.8106709122657776 | f1: 0.8639460802078247 | auc: 0.9070138931274414 | treshold: 0.48000000000000004
Test: accuracy: 0.810155987739563 | f1: 0.8638893365859985 | auc: 0.9054217338562012 | 


train... loss:43.27022033929825:   1%|▌                                                                                                 | 29/5000 [03:18<9:32:36,  6.91s/it]

Val update: epoch: 28 |accuracy: 0.8131137490272522 | f1: 0.8650093674659729 | auc: 0.9078608751296997 | treshold: 0.44
Test: accuracy: 0.8135170936584473 | f1: 0.8655930757522583 | auc: 0.9067535400390625 | 


train... loss:43.216243386268616:   1%|▋                                                                                                | 33/5000 [03:45<9:19:40,  6.76s/it]

Val update: epoch: 32 |accuracy: 0.8277705311775208 | f1: 0.87054842710495 | auc: 0.9078973531723022 | treshold: 0.45
Test: accuracy: 0.8305646181106567 | f1: 0.8729119300842285 | auc: 0.9074380993843079 | 


train... loss:42.829331547021866:   1%|▋                                                                                                | 36/5000 [04:05<9:19:30,  6.76s/it]

Val update: epoch: 35 |accuracy: 0.8356794118881226 | f1: 0.8707775473594666 | auc: 0.9107421040534973 | treshold: 0.46
Test: accuracy: 0.8369725346565247 | f1: 0.8724362254142761 | auc: 0.9102761745452881 | 


train... loss:42.45061457157135:   1%|▋                                                                                                 | 37/5000 [04:12<9:28:32,  6.87s/it]

Val update: epoch: 36 |accuracy: 0.834058940410614 | f1: 0.8659411072731018 | auc: 0.9130579233169556 | treshold: 0.48000000000000004
Test: accuracy: 0.8330069184303284 | f1: 0.8656525015830994 | auc: 0.9123003482818604 | 


train... loss:42.440340012311935:   1%|▋                                                                                                | 38/5000 [04:19<9:33:30,  6.93s/it]

Val update: epoch: 37 |accuracy: 0.8359212279319763 | f1: 0.8698787689208984 | auc: 0.9130589962005615 | treshold: 0.48000000000000004
Test: accuracy: 0.8352557420730591 | f1: 0.8700032234191895 | auc: 0.9118404984474182 | 


train... loss:42.34934765100479:   1%|▊                                                                                                 | 40/5000 [04:32<9:28:26,  6.88s/it]

Val update: epoch: 39 |accuracy: 0.8376626372337341 | f1: 0.8717370629310608 | auc: 0.9140777587890625 | treshold: 0.47000000000000003
Test: accuracy: 0.8376012444496155 | f1: 0.8720713257789612 | auc: 0.913022518157959 | 


train... loss:42.02761101722717:   1%|▊                                                                                                 | 43/5000 [04:53<9:24:04,  6.83s/it]

Val update: epoch: 42 |accuracy: 0.8343249559402466 | f1: 0.8749040961265564 | auc: 0.9143041372299194 | treshold: 0.47000000000000003
Test: accuracy: 0.8335388898849487 | f1: 0.8748181462287903 | auc: 0.9125316739082336 | 


train... loss:41.36728295683861:   1%|▉                                                                                                 | 50/5000 [05:39<9:10:44,  6.68s/it]

Val update: epoch: 49 |accuracy: 0.8250858783721924 | f1: 0.8723479509353638 | auc: 0.9162136912345886 | treshold: 0.48000000000000004
Test: accuracy: 0.8246886730194092 | f1: 0.8723636269569397 | auc: 0.915189802646637 | 


train... loss:41.24025613069534:   1%|█                                                                                                 | 54/5000 [06:05<9:13:16,  6.71s/it]

Val update: epoch: 53 |accuracy: 0.841266393661499 | f1: 0.8737034797668457 | auc: 0.916419506072998 | treshold: 0.47000000000000003
Test: accuracy: 0.8391246795654297 | f1: 0.8726430535316467 | auc: 0.9159804582595825 | 


train... loss:41.020749509334564:   1%|█                                                                                                | 55/5000 [06:12<9:22:05,  6.82s/it]

Val update: epoch: 54 |accuracy: 0.8394040465354919 | f1: 0.8711230158805847 | auc: 0.9173226952552795 | treshold: 0.47000000000000003
Test: accuracy: 0.8387861251831055 | f1: 0.8711815476417542 | auc: 0.9170501828193665 | 


train... loss:41.69330483675003:   1%|█▏                                                                                                | 62/5000 [07:04<9:23:18,  6.84s/it]


In [10]:
# Make sure the output directory exists; to_csv does not create it.
os.makedirs('results', exist_ok=True)
pd.DataFrame(rl4rs_results).to_csv(f'results/rl4rs_{experiment_name}.csv')
# Drop the RL4RS objects now that the results are written.
del dataset, train_loader, val_loader, test_loader, train_user_item_matrix, train_num_items