In [1]:
import os
import torch
import random
import datetime
import pandas as pd
import numpy as np

from torch.utils.data import Dataset
from src.datasets import RL4RS, ContentWise, DummyData
from src.utils import train, get_dummy_data, get_train_val_test_tmatrix_tnumitems
from src.embeddings import RecsysEmbedding

experiment_name = 'FlatAttentionClickedOnly2LinearLnorm'
device = 'cuda:2'
seed = 123
pkl_path = '../data/'

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

<torch._C.Generator at 0x7f80d8cb4190>

In [2]:
torch.__version__

'1.9.0'

In [3]:
future_mask = torch.ones((3, 5))
future_mask = torch.triu(future_mask, diagonal=1).bool()
future_mask


tensor([[False,  True,  True,  True,  True],
        [False, False,  True,  True,  True],
        [False, False, False,  True,  True]])

# Модель

In [4]:
import torch.nn.functional as F
torch.autograd.set_detect_anomaly(True)


class AttentionSequencewiseResponseModel(torch.nn.Module):
    """
    No recurrent dependency, just slate-wise attention.
    """
    def __init__(self, embedding, nheads=2, output_dim=1, debug=False):
        super().__init__()
        self.embedding_dim = embedding.embedding_dim
        self.embedding = embedding
        self.nheads = nheads
        self.debug = debug
        self.attention= torch.nn.MultiheadAttention(
            self.embedding_dim,
            num_heads=nheads,
            batch_first=True
        )
        
        self.linear = torch.nn.Linear(2* embedding.embedding_dim, embedding.embedding_dim)
        
        
        
        self.out_layer = torch.nn.Sequential(
                torch.nn.LayerNorm(2* embedding.embedding_dim),
                torch.nn.Linear(embedding.embedding_dim * 2, embedding.embedding_dim * 2),
                torch.nn.GELU(),
                torch.nn.Linear(embedding.embedding_dim * 2, output_dim)
        )

    def forward(self, batch):
        item_embs, user_embs = self.embedding(batch)
        shp = item_embs.shape
        device = item_embs.device
        
        if self.debug: print('responses', batch['responses'])
        
        # flattening slates into long sequences
        item_embs = item_embs.flatten(1, 2)
        
        # getting user embedding (mean consumed items)
        if self.training:
            indices = (batch['length'] - 1)
        else:
            indices = (batch['in_length'] - 1)
        indices[indices<0] = 0
        indices = indices[:, None, None].repeat(1, 1, user_embs.size(-1))
        user_embs = user_embs.gather(1, indices).squeeze(-2)
        
        # adding a user embedding as a 'zero item' to 
        # make predictions if nothing is observed
        keys = torch.cat([
                user_embs[:, None, :],
                item_embs
            ],
            dim = 1
        )
        
        clicked_mask = batch['responses'].flatten(1, 2) > 0
        if self.debug: print(clicked_mask, clicked_mask.shape)
        clicked_mask = torch.cat([
                torch.ones_like(clicked_mask[:, 0, None]).bool(),
                clicked_mask
            ],
            dim=-1
        )
        if self.debug: print('clicked_mask:', clicked_mask, clicked_mask.shape)
        
        clicked_items = [
            keys[i][clicked_mask[i], :]
            for i in range(shp[0])
        ]
        keys = torch.nn.utils.rnn.pad_sequence(
            clicked_items,
            batch_first=True,
            padding_value=float('nan')
        )
        key_padding_mask = keys.isnan().all(-1)
        keys = keys.nan_to_num(0)
        
        if self.debug: print('key', keys.shape, keys, key_padding_mask)
        
        # forbid model looking into future (and into current iteraction)
        # at the end, mask will be (num_heads * bsize, slate_size * sequence_size, max_len_clicked_items)
        future_mask = torch.ones((item_embs.size(-2) + 1, item_embs.size(-2) + 1 ), device=device)
        future_mask = torch.triu(future_mask, diagonal=0)[1:, :].bool()
        if self.debug: print('future_mask', future_mask.shape, future_mask)
        ####### TODOTODOTODO ########
        # change future masl to be slatewise, not sequencewise
        # build it without large matrix generaton if possible
        # chunk previous iteracrtion
        ############################
        
        if self.debug: print('click_mask_repeated', clicked_mask[None, 0, ...].repeat(item_embs.size(-2), 1))
        
        if self.debug: print(
            'first user click & past', 
            future_mask[clicked_mask[None, 0, ...].repeat(item_embs.size(-2), 1)].reshape(item_embs.size(-2), keys.size(-2) )
        )
        
        attn_mask = [
            future_mask[
                clicked_mask[None, i, ...].repeat(item_embs.size(-2), 1)
            ].reshape(item_embs.size(-2), clicked_mask[None, i, ...].sum() ).T
            for i in range(shp[0])
        ]
        if self.debug: print('attn_mask', attn_mask[0].shape, attn_mask[1].shape )
                
        
        attn_mask = torch.nn.utils.rnn.pad_sequence(
            attn_mask,
            batch_first=True,
            padding_value=True
        ).permute(0, 2, 1)
        if self.debug: print('attn_mask_stacked', attn_mask)
        
        if self.debug: print(item_embs, keys, key_padding_mask, attn_mask.repeat_interleave(self.nheads, 0))
        features, attn_map = self.attention(
            item_embs, keys, keys,
            key_padding_mask=key_padding_mask,
            attn_mask = attn_mask.repeat_interleave(self.nheads, 0)
        )
        
        if self.debug: print(shp, item_embs.shape)
        features = torch.cat(
            [
                features.reshape(shp),
                item_embs.reshape(shp)
            ],
            dim = -1
        )
        
        return self.out_layer(features).squeeze(-1)
        if self.debug: print('features', features.shape, features)
        
        return self.out_layer(features).squeeze(-1)

    
    
d = DummyData()
dummy_loader, dummy_matrix = get_dummy_data(d)

for batch in dummy_loader:
    break
    
model = AttentionSequencewiseResponseModel(
    RecsysEmbedding(d.n_items, dummy_matrix, embeddings='explicit', embedding_dim=2),
    output_dim=1
).to('cpu')

model(batch)


3it [00:00, 2321.57it/s]

biulding affinity matrix...





tensor([[[ 0.2185, -0.0839,  0.2184],
         [ 0.2181, -0.0837,  0.2182]],

        [[ 0.0054,  0.0054,  0.0054],
         [ 0.0054,  0.0054,  0.0054]]], grad_fn=<SqueezeBackward1>)

# Игрушечный датасет: проверим, что сходится к идеальным метрикам

In [5]:
d = DummyData()
dummy_loader, dummy_matrix = get_dummy_data(d)

for embeddings in ['explicit', 'svd', 'neural']:
    print(f"\nEvaluating {experiment_name} with {embeddings} embeddings")
    model = AttentionSequencewiseResponseModel(
        RecsysEmbedding(d.n_items, dummy_matrix, embeddings='neural').to('cpu'),
        output_dim=1
    ).to('cpu')
    _, metrics = train(model, 
       dummy_loader, dummy_loader, dummy_loader, device='cpu', lr=1e-3, num_epochs=5000, dummy=True,
       silent=True,
#        debug=True,
    
    )


3it [00:00, 2709.50it/s]

biulding affinity matrix...

Evaluating FlatAttentionClickedOnly2LinearLnorm with explicit embeddings



train... loss:0.667537271976471:   0%|                                                                                       | 0/5000 [00:00<?, ?it/s]

Test before learning: {'f1': 1.0, 'roc-auc': 1.0, 'accuracy': 1.0}


train... loss:0.667537271976471:   0%|                                                                                       | 0/5000 [00:00<?, ?it/s]
train:   0%|                                                                                                                 | 0/5000 [00:00<?, ?it/s]


Evaluating FlatAttentionClickedOnly2LinearLnorm with svd embeddings
Test before learning: {'f1': 0.0, 'roc-auc': 0.0, 'accuracy': 0.75}


train... loss:0.6866967082023621:   0%|                                                                              | 1/5000 [00:00<22:06,  3.77it/s]

Val update: epoch: 0 |accuracy: 0.25 | f1: 0.4000000059604645 | auc: 0.0 | treshold: 0.01
Test: accuracy: 0.25 | f1: 0.4000000059604645 | auc: 0.0 | 


train... loss:0.6653292775154114:   0%|                                                                              | 2/5000 [00:00<22:02,  3.78it/s]

Val update: epoch: 1 |accuracy: 0.25 | f1: 0.4000000059604645 | auc: 1.0 | treshold: 0.01
Test: accuracy: 0.25 | f1: 0.4000000059604645 | auc: 1.0 | 


train... loss:0.6653292775154114:   0%|                                                                              | 2/5000 [00:00<33:01,  2.52it/s]
train:   0%|                                                                                                                 | 0/5000 [00:00<?, ?it/s]

Val update: epoch: 2 |accuracy: 1.0 | f1: 1.0 | auc: 1.0 | treshold: 0.48000000000000004
Test: accuracy: 1.0 | f1: 1.0 | auc: 1.0 | 

Evaluating FlatAttentionClickedOnly2LinearLnorm with neural embeddings
Test before learning: {'f1': 0.5, 'roc-auc': 0.3333333134651184, 'accuracy': 0.5}


train... loss:0.6762871146202087:   0%|                                                                              | 2/5000 [00:00<17:47,  4.68it/s]

Val update: epoch: 1 |accuracy: 0.5 | f1: 0.5 | auc: 1.0 | treshold: 0.51
Test: accuracy: 0.5 | f1: 0.5 | auc: 1.0 | 


train... loss:0.6762871146202087:   0%|                                                                              | 2/5000 [00:00<29:58,  2.78it/s]

Val update: epoch: 2 |accuracy: 1.0 | f1: 1.0 | auc: 1.0 | treshold: 0.56
Test: accuracy: 1.0 | f1: 1.0 | auc: 1.0 | 





# ContentWise

In [6]:
content_wise_results = []
c = ContentWise.load(os.path.join(pkl_path, 'cw.pkl'))
c_train_loader, c_val_loader, c_test_loader, c_train_user_item_matrix, train_num_items = get_train_val_test_tmatrix_tnumitems(c, batch_size=64)
len(c_train_loader), len(c)

(253, 20216)

In [7]:
for batch in c:
    print(batch['slates_item_ids'].shape)
    break

(28, 50)


In [8]:
for embeddings in ['neural', 'svd']:
    print(f"\nEvaluating {experiment_name} with {embeddings} embeddings")
    model = AttentionSequencewiseResponseModel(
        RecsysEmbedding(train_num_items, c_train_user_item_matrix, embeddings='neural'),
        output_dim=1
    ).to(device)

    _, metrics = train(model, 
       c_train_loader, c_val_loader, c_test_loader, device=device, lr=1e-3, num_epochs=5000, early_stopping=7,
       silent=True, 
    )
    
    metrics['embeddings'] = embeddings
    content_wise_results.append(metrics)
    
pd.DataFrame(content_wise_results).to_csv(f'results/cw_{experiment_name}.csv')


Evaluating FlatAttentionClickedOnly2LinearLnorm with neural embeddings


train:   0%|                                                                                                                 | 0/5000 [00:00<?, ?it/s]

Test before learning: {'f1': 0.17295430600643158, 'roc-auc': 0.5320330858230591, 'accuracy': 0.63466477394104}


train... loss:84.3181885778904:   0%|                                                                            | 1/5000 [01:24<117:56:30, 84.94s/it]

Val update: epoch: 0 |accuracy: 0.7507981657981873 | f1: 0.27686962485313416 | auc: 0.6850188970565796 | treshold: 0.13
Test: accuracy: 0.7520033717155457 | f1: 0.2763349115848541 | auc: 0.6870205402374268 | 


train... loss:72.8741234689951:   0%|                                                                            | 2/5000 [02:38<113:16:17, 81.59s/it]

Val update: epoch: 1 |accuracy: 0.8103855848312378 | f1: 0.3014194369316101 | auc: 0.7131384611129761 | treshold: 0.16
Test: accuracy: 0.8132011890411377 | f1: 0.3019131124019623 | auc: 0.7117938995361328 | 


train... loss:70.37394453585148:   0%|                                                                           | 3/5000 [03:53<110:35:26, 79.67s/it]

Val update: epoch: 2 |accuracy: 0.87528395652771 | f1: 0.3301179111003876 | auc: 0.7454283237457275 | treshold: 0.19
Test: accuracy: 0.8805140852928162 | f1: 0.33313068747520447 | auc: 0.7445813417434692 | 


train... loss:67.70227608084679:   0%|▏                                                                         | 16/5000 [20:54<108:33:51, 78.42s/it]



Evaluating FlatAttentionClickedOnly2LinearLnorm with svd embeddings


train:   0%|                                                                                                                 | 0/5000 [00:00<?, ?it/s]

Test before learning: {'f1': 0.1648014932870865, 'roc-auc': 0.495025098323822, 'accuracy': 0.3763984739780426}


train... loss:85.59551414847374:   0%|                                                                           | 1/5000 [01:14<102:58:42, 74.16s/it]

Val update: epoch: 0 |accuracy: 0.7406674027442932 | f1: 0.2762284278869629 | auc: 0.6805226802825928 | treshold: 0.13
Test: accuracy: 0.7419204115867615 | f1: 0.2748338580131531 | auc: 0.6816522479057312 | 


train... loss:72.39747193455696:   0%|                                                                           | 2/5000 [02:27<102:49:10, 74.06s/it]

Val update: epoch: 1 |accuracy: 0.7881899476051331 | f1: 0.32327988743782043 | auc: 0.7359376549720764 | treshold: 0.18000000000000002
Test: accuracy: 0.7915570735931396 | f1: 0.32013803720474243 | auc: 0.7311761379241943 | 


train... loss:69.15933473408222:   0%|                                                                           | 4/5000 [04:55<102:40:39, 73.99s/it]

Val update: epoch: 3 |accuracy: 0.8788450956344604 | f1: 0.3252115845680237 | auc: 0.7410580515861511 | treshold: 0.19
Test: accuracy: 0.8800628781318665 | f1: 0.31823810935020447 | auc: 0.7383405566215515 | 


train... loss:68.88804356753826:   0%|                                                                           | 6/5000 [07:23<102:32:38, 73.92s/it]

Val update: epoch: 5 |accuracy: 0.870126485824585 | f1: 0.3395519554615021 | auc: 0.7468965649604797 | treshold: 0.19
Test: accuracy: 0.8723294734954834 | f1: 0.3379327058792114 | auc: 0.7466886043548584 | 


train... loss:67.77542945742607:   0%|                                                                           | 7/5000 [08:37<102:55:13, 74.21s/it]

Val update: epoch: 6 |accuracy: 0.8879013657569885 | f1: 0.319448322057724 | auc: 0.7505460977554321 | treshold: 0.21000000000000002
Test: accuracy: 0.8896634578704834 | f1: 0.3216301500797272 | auc: 0.7505181431770325 | 


train... loss:64.8458554148674:   1%|▌                                                                          | 36/5000 [44:55<103:15:07, 74.88s/it]


# RL4RS

In [9]:
!ls ../data/rl4rs-dataset/

item_info.csv		rl4rs_dataset_a_sl.csv	rl4rs_dataset_b_sl.csv
rl4rs_dataset_a_rl.csv	rl4rs_dataset_b_rl.csv


In [10]:
rl4rs_results = []
r = RL4RS.load(os.path.join(pkl_path, 'rl4rs.pkl'))
r_train_loader, r_val_loader, r_test_loader, r_train_user_item_matrix, train_num_items = get_train_val_test_tmatrix_tnumitems(r, batch_size=10000)
len(r_train_loader), len(r)

(4, 45942)

In [11]:
for embeddings in ['svd',  ]:
    print(f"\nEvaluating {experiment_name} with {embeddings} embeddings")

    model = AttentionSequencewiseResponseModel(
        RecsysEmbedding(train_num_items, r_train_user_item_matrix, embeddings='neural'),
        output_dim=1
    ).to(device)

    _, metrics = train(model, 
       r_train_loader, r_val_loader, r_test_loader, device=device, lr=1e-3, num_epochs=5000, early_stopping=7,
       silent=True, 
    )
    
    metrics['embeddings'] = embeddings
    rl4rs_results.append(metrics)
    
pd.DataFrame(rl4rs_results).to_csv(f'results/rl4rs_{experiment_name}.csv')


Evaluating FlatAttentionClickedOnly2LinearLnorm with svd embeddings


train:   0%|                                                                                                                 | 0/5000 [00:00<?, ?it/s]

Test before learning: {'f1': 0.6679242849349976, 'roc-auc': 0.4804536700248718, 'accuracy': 0.5402732491493225}


train... loss:2.7005019187927246:   0%|                                                                         | 1/5000 [02:35<215:18:55, 155.06s/it]

Val update: epoch: 0 |accuracy: 0.6478014588356018 | f1: 0.7859033346176147 | auc: 0.5354908108711243 | treshold: 0.42000000000000004
Test: accuracy: 0.6456776857376099 | f1: 0.7843275666236877 | auc: 0.5423147678375244 | 


train... loss:2.627191424369812:   0%|                                                                          | 2/5000 [04:42<203:48:38, 146.80s/it]

Val update: epoch: 1 |accuracy: 0.6517196297645569 | f1: 0.7875165939331055 | auc: 0.5957281589508057 | treshold: 0.46
Test: accuracy: 0.6501511335372925 | f1: 0.7861661314964294 | auc: 0.6054037809371948 | 


train... loss:2.558985114097595:   0%|                                                                          | 3/5000 [06:49<195:33:16, 140.88s/it]

Val update: epoch: 2 |accuracy: 0.655976414680481 | f1: 0.7889551520347595 | auc: 0.6448004245758057 | treshold: 0.48000000000000004
Test: accuracy: 0.6565107107162476 | f1: 0.7886066436767578 | auc: 0.6549935340881348 | 


train... loss:2.500037968158722:   0%|                                                                          | 4/5000 [08:58<190:20:21, 137.15s/it]

Val update: epoch: 3 |accuracy: 0.6715039014816284 | f1: 0.7935490012168884 | auc: 0.6800462007522583 | treshold: 0.52
Test: accuracy: 0.676169753074646 | f1: 0.7956854701042175 | auc: 0.6900849938392639 | 


train... loss:2.4456520080566406:   0%|                                                                         | 5/5000 [11:05<186:02:02, 134.08s/it]

Val update: epoch: 4 |accuracy: 0.6882407069206238 | f1: 0.8007296919822693 | auc: 0.706630289554596 | treshold: 0.54
Test: accuracy: 0.6925885677337646 | f1: 0.8026452660560608 | auc: 0.7161498069763184 | 


train... loss:2.4046945571899414:   0%|                                                                         | 6/5000 [13:12<183:02:43, 131.95s/it]

Val update: epoch: 5 |accuracy: 0.7024863362312317 | f1: 0.8059564828872681 | auc: 0.7287002205848694 | treshold: 0.56
Test: accuracy: 0.7093942761421204 | f1: 0.8096911907196045 | auc: 0.7375344038009644 | 


train... loss:2.3653018474578857:   0%|                                                                         | 7/5000 [15:19<181:00:31, 130.51s/it]

Val update: epoch: 6 |accuracy: 0.7200454473495483 | f1: 0.8105037212371826 | auc: 0.7469390034675598 | treshold: 0.59
Test: accuracy: 0.7281586527824402 | f1: 0.8151351809501648 | auc: 0.755002498626709 | 


train... loss:2.323249578475952:   0%|                                                                          | 8/5000 [17:26<179:33:37, 129.49s/it]

Val update: epoch: 7 |accuracy: 0.7277366518974304 | f1: 0.814403235912323 | auc: 0.7609150409698486 | treshold: 0.5700000000000001
Test: accuracy: 0.7344456315040588 | f1: 0.8181607723236084 | auc: 0.7679631114006042 | 


train... loss:2.264640688896179:   0%|▏                                                                         | 9/5000 [19:34<178:55:07, 129.05s/it]

Val update: epoch: 8 |accuracy: 0.7339766621589661 | f1: 0.8160364031791687 | auc: 0.7710447907447815 | treshold: 0.54
Test: accuracy: 0.7389432787895203 | f1: 0.8188712000846863 | auc: 0.7770974636077881 | 


train... loss:2.212135970592499:   0%|▏                                                                        | 10/5000 [21:41<178:13:15, 128.58s/it]

Val update: epoch: 9 |accuracy: 0.7390074133872986 | f1: 0.8167942762374878 | auc: 0.7787522673606873 | treshold: 0.5
Test: accuracy: 0.7443597912788391 | f1: 0.8200755715370178 | auc: 0.7840741872787476 | 


train... loss:2.1942065954208374:   0%|▏                                                                       | 11/5000 [23:50<178:03:41, 128.49s/it]

Val update: epoch: 10 |accuracy: 0.7414018511772156 | f1: 0.8186075091362 | auc: 0.7864785194396973 | treshold: 0.44
Test: accuracy: 0.7470922470092773 | f1: 0.8223886489868164 | auc: 0.7915689945220947 | 


train... loss:2.1988566517829895:   0%|▏                                                                       | 12/5000 [25:58<177:49:22, 128.34s/it]

Val update: epoch: 11 |accuracy: 0.7464325428009033 | f1: 0.8226297497749329 | auc: 0.7981290817260742 | treshold: 0.4
Test: accuracy: 0.7514931559562683 | f1: 0.8259639739990234 | auc: 0.8032486438751221 | 


train... loss:2.1660225987434387:   0%|▏                                                                       | 13/5000 [28:05<177:26:09, 128.09s/it]

Val update: epoch: 12 |accuracy: 0.750616729259491 | f1: 0.8275262117385864 | auc: 0.8130170106887817 | treshold: 0.38
Test: accuracy: 0.7541530728340149 | f1: 0.8296785354614258 | auc: 0.8182663321495056 | 


train... loss:2.085037350654602:   0%|▏                                                                        | 14/5000 [30:13<177:32:04, 128.18s/it]

Val update: epoch: 13 |accuracy: 0.7558651566505432 | f1: 0.8316319584846497 | auc: 0.8254384398460388 | treshold: 0.4
Test: accuracy: 0.7583847045898438 | f1: 0.8329711556434631 | auc: 0.830744206905365 | 


train... loss:2.006169557571411:   0%|▏                                                                        | 15/5000 [32:21<177:17:18, 128.03s/it]

Val update: epoch: 14 |accuracy: 0.7609442472457886 | f1: 0.8338544368743896 | auc: 0.8333532810211182 | treshold: 0.46
Test: accuracy: 0.7654213309288025 | f1: 0.8366202116012573 | auc: 0.8386160731315613 | 


train... loss:1.9738087058067322:   0%|▏                                                                       | 16/5000 [34:29<177:18:13, 128.07s/it]

Val update: epoch: 15 |accuracy: 0.7637739777565002 | f1: 0.8359783291816711 | auc: 0.8382916450500488 | treshold: 0.51
Test: accuracy: 0.7670414447784424 | f1: 0.8380731344223022 | auc: 0.8434215784072876 | 


train... loss:1.998332440853119:   0%|▏                                                                        | 17/5000 [36:37<177:12:44, 128.03s/it]

Val update: epoch: 16 |accuracy: 0.7771005630493164 | f1: 0.8409992456436157 | auc: 0.8432740569114685 | treshold: 0.61
Test: accuracy: 0.7816709280014038 | f1: 0.8441044688224792 | auc: 0.8481858968734741 | 


train... loss:2.0188634991645813:   0%|▎                                                                       | 18/5000 [38:44<176:49:53, 127.78s/it]

Val update: epoch: 17 |accuracy: 0.7866541147232056 | f1: 0.8461873531341553 | auc: 0.850487232208252 | treshold: 0.64
Test: accuracy: 0.7903034687042236 | f1: 0.8486931920051575 | auc: 0.8549204468727112 | 


train... loss:1.9731696248054504:   0%|▎                                                                       | 19/5000 [40:52<176:31:05, 127.58s/it]

Val update: epoch: 18 |accuracy: 0.7954095005989075 | f1: 0.8496631979942322 | auc: 0.8580266237258911 | treshold: 0.65
Test: accuracy: 0.7972192168235779 | f1: 0.8509844541549683 | auc: 0.8618439435958862 | 


train... loss:1.8966480493545532:   0%|▎                                                                       | 20/5000 [42:59<176:16:08, 127.42s/it]

Val update: epoch: 19 |accuracy: 0.8009964823722839 | f1: 0.8531186580657959 | auc: 0.863500714302063 | treshold: 0.61
Test: accuracy: 0.8013541102409363 | f1: 0.8532590270042419 | auc: 0.8667763471603394 | 


train... loss:1.8439721167087555:   0%|▎                                                                       | 21/5000 [45:06<176:24:07, 127.55s/it]

Val update: epoch: 20 |accuracy: 0.8042132258415222 | f1: 0.8538148999214172 | auc: 0.8665772080421448 | treshold: 0.5700000000000001
Test: accuracy: 0.8038205504417419 | f1: 0.853309690952301 | auc: 0.8693675994873047 | 


train... loss:1.8307538330554962:   0%|▎                                                                       | 22/5000 [47:13<176:09:17, 127.39s/it]

Val update: epoch: 21 |accuracy: 0.8055434823036194 | f1: 0.8550620079040527 | auc: 0.8703561425209045 | treshold: 0.51
Test: accuracy: 0.8053197860717773 | f1: 0.854609489440918 | auc: 0.8728649616241455 | 


train... loss:1.8264312446117401:   0%|▎                                                                       | 23/5000 [49:22<176:25:18, 127.61s/it]

Val update: epoch: 22 |accuracy: 0.811976969242096 | f1: 0.8602854013442993 | auc: 0.8762827515602112 | treshold: 0.46
Test: accuracy: 0.8110506534576416 | f1: 0.859263002872467 | auc: 0.8787684440612793 | 


train... loss:1.8051697313785553:   0%|▎                                                                       | 24/5000 [51:29<176:07:48, 127.43s/it]

Val update: epoch: 23 |accuracy: 0.8128960728645325 | f1: 0.86185222864151 | auc: 0.8795561790466309 | treshold: 0.42000000000000004
Test: accuracy: 0.8135896325111389 | f1: 0.8621546626091003 | auc: 0.8823593258857727 | 


train... loss:1.7487117946147919:   1%|▍                                                                     | 29/5000 [1:01:59<174:27:43, 126.35s/it]

Val update: epoch: 28 |accuracy: 0.8263919353485107 | f1: 0.8715737462043762 | auc: 0.8878380656242371 | treshold: 0.65
Test: accuracy: 0.8286059498786926 | f1: 0.8733426332473755 | auc: 0.8924986124038696 | 


train... loss:1.690667450428009:   1%|▍                                                                      | 30/5000 [1:04:06<174:48:35, 126.62s/it]

Val update: epoch: 29 |accuracy: 0.8333333134651184 | f1: 0.8748978972434998 | auc: 0.8943770527839661 | treshold: 0.66
Test: accuracy: 0.834506094455719 | f1: 0.8756540417671204 | auc: 0.8985152244567871 | 


train... loss:1.639001339673996:   1%|▍                                                                      | 31/5000 [1:06:13<175:06:10, 126.86s/it]

Val update: epoch: 30 |accuracy: 0.831325888633728 | f1: 0.8715085983276367 | auc: 0.895466685295105 | treshold: 0.63
Test: accuracy: 0.8334663510322571 | f1: 0.8730810880661011 | auc: 0.899145245552063 | 


train... loss:1.6841345131397247:   1%|▍                                                                     | 34/5000 [1:12:33<174:44:51, 126.68s/it]

Val update: epoch: 33 |accuracy: 0.8281332850456238 | f1: 0.8718854784965515 | auc: 0.8969786763191223 | treshold: 0.42000000000000004
Test: accuracy: 0.830298662185669 | f1: 0.8730830550193787 | auc: 0.900980532169342 | 


train... loss:1.656096875667572:   1%|▍                                                                      | 35/5000 [1:14:40<174:51:49, 126.79s/it]

Val update: epoch: 34 |accuracy: 0.8331640362739563 | f1: 0.8754019141197205 | auc: 0.8993916511535645 | treshold: 0.42000000000000004
Test: accuracy: 0.8358602523803711 | f1: 0.8769710659980774 | auc: 0.9036072492599487 | 


train... loss:1.6118916273117065:   1%|▌                                                                     | 41/5000 [1:27:22<175:03:44, 127.09s/it]

Val update: epoch: 40 |accuracy: 0.8420161604881287 | f1: 0.8815250396728516 | auc: 0.9021801352500916 | treshold: 0.68
Test: accuracy: 0.8445895314216614 | f1: 0.8833681344985962 | auc: 0.9051235318183899 | 


train... loss:1.5744479894638062:   1%|▌                                                                     | 42/5000 [1:29:30<175:31:46, 127.45s/it]

Val update: epoch: 41 |accuracy: 0.8419194221496582 | f1: 0.8803499937057495 | auc: 0.903520941734314 | treshold: 0.64
Test: accuracy: 0.8423649072647095 | f1: 0.8806284666061401 | auc: 0.9058558940887451 | 


train... loss:1.5456950068473816:   1%|▌                                                                     | 43/5000 [1:31:38<175:45:52, 127.65s/it]

Val update: epoch: 42 |accuracy: 0.8425240516662598 | f1: 0.8803542852401733 | auc: 0.9047814011573792 | treshold: 0.5700000000000001
Test: accuracy: 0.8421714305877686 | f1: 0.879963219165802 | auc: 0.906876802444458 | 


train... loss:1.5348676443099976:   1%|▌                                                                     | 44/5000 [1:33:45<175:23:42, 127.41s/it]

Val update: epoch: 43 |accuracy: 0.8421370983123779 | f1: 0.8803329467773438 | auc: 0.9073643684387207 | treshold: 0.5
Test: accuracy: 0.8435497283935547 | f1: 0.8812364339828491 | auc: 0.9098491072654724 | 


train... loss:1.5246820747852325:   1%|▋                                                                     | 45/5000 [1:35:52<175:04:13, 127.20s/it]

Val update: epoch: 44 |accuracy: 0.8444831371307373 | f1: 0.883830189704895 | auc: 0.9098190069198608 | treshold: 0.43
Test: accuracy: 0.8453391194343567 | f1: 0.8841766119003296 | auc: 0.9131060838699341 | 


train... loss:1.429726392030716:   1%|▊                                                                      | 54/5000 [1:54:44<173:16:30, 126.12s/it]

Val update: epoch: 53 |accuracy: 0.8493445515632629 | f1: 0.8872191309928894 | auc: 0.9098195433616638 | treshold: 0.61
Test: accuracy: 0.8539958596229553 | f1: 0.8903537392616272 | auc: 0.9135520458221436 | 


train... loss:1.390988051891327:   1%|▊                                                                      | 55/5000 [1:56:52<174:20:19, 126.92s/it]

Val update: epoch: 54 |accuracy: 0.8519324660301208 | f1: 0.8882071375846863 | auc: 0.9122520685195923 | treshold: 0.55
Test: accuracy: 0.8553500175476074 | f1: 0.8904234766960144 | auc: 0.9157196283340454 | 


train... loss:1.4140877723693848:   1%|▊                                                                     | 56/5000 [1:59:00<174:25:11, 127.00s/it]

Val update: epoch: 55 |accuracy: 0.8518357276916504 | f1: 0.8888848423957825 | auc: 0.9143202304840088 | treshold: 0.44
Test: accuracy: 0.8558819890022278 | f1: 0.8915890455245972 | auc: 0.9176409244537354 | 


train... loss:1.473132461309433:   1%|▊                                                                      | 57/5000 [2:01:06<174:21:20, 126.98s/it]

Val update: epoch: 56 |accuracy: 0.8528515696525574 | f1: 0.890461266040802 | auc: 0.9154578447341919 | treshold: 0.37
Test: accuracy: 0.856220543384552 | f1: 0.8925900459289551 | auc: 0.9186960458755493 | 


train... loss:1.3949508666992188:   1%|▊                                                                     | 61/5000 [2:11:34<177:33:32, 129.42s/it]


In [12]:
rl4rs_results

[{'f1': 0.8925900459289551,
  'roc-auc': 0.9186960458755493,
  'accuracy': 0.856220543384552,
  'embeddings': 'svd'}]

In [13]:
rl4rs_results.append(
    {
  'f1': 0.8974264860153198,
  'roc-auc': 0.9311494827270508,
  'accuracy': 0.8648772835731506,
  'embeddings': 'explicit'}
)