In [1]:
import os
import torch
import random
import datetime
import pandas as pd
import numpy as np

from torch.utils.data import Dataset
from src.datasets import RL4RS, ContentWise, DummyData
from src.utils import train, get_dummy_data, get_train_val_test_tmatrix_tnumitems, get_svd_encoder
from src.embeddings import RecsysEmbedding

experiment_name = 'InSlateAttentionSequencewiseGRU'
device = 'cuda:2'
seed = 123
pkl_path = '../data/'

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

<torch._C.Generator at 0x7f2bb8023170>

In [2]:
torch.__version__

'1.9.0'

# Модель

In [3]:
import torch.nn.functional as F
torch.autograd.set_detect_anomaly(True)

class AttentionResponseModel(torch.nn.Module):
    """
    No recurrent dependency, just slate-wise attention.
    """
    def __init__(self, embedding, nheads=2, output_dim=1):
        super().__init__()
        self.embedding_dim = embedding.embedding_dim
        self.embedding = embedding
        self.attention = torch.nn.MultiheadAttention(
            2 * embedding.embedding_dim,
            num_heads=nheads,
            batch_first=True
        )
        
        self.rnn_layer = torch.nn.GRU(
            input_size = embedding.embedding_dim, 
            hidden_size = embedding.embedding_dim, 
            batch_first=True
        )
        
        self.out_layer = torch.nn.Linear(3 * embedding.embedding_dim, output_dim)
    
    
    def get_attention_embeddings(self, item_embs, user_embs, slate_mask):
        shp = item_embs.shape      
        key_padding_mask = slate_mask
        key_padding_mask[:,:, 0] = True # let model attent to first padd token if slate is empty 
        features = torch.cat(
            [
                item_embs,
                user_embs[:, :, None, :].repeat(1, 1, item_embs.size(-2), 1).reshape(shp)
            ],
            dim = -1
        ).flatten(0,1)
        
        features, attn_map = self.attention(
            features, features, features,
            key_padding_mask=~key_padding_mask.flatten(0, 1)
        )
        shp = list(shp)
        shp[-1] *= 2
        features = features.reshape(shp)
        return features
    
    def forward(self, batch):
        # consider sequential clicks, hence need to flatten slates
        item_embs, user_embs = self.embedding(batch)
        slate_mask = batch['slates_mask'].clone()
        
        att_features = self.get_attention_embeddings(item_embs, user_embs, slate_mask)
        
        gru_features, _ = self.rnn_layer(item_embs.flatten(-3, -2))
        gru_features = gru_features.reshape(item_embs.shape)
        
        features = torch.cat(
            [att_features, gru_features],
            dim=-1
        )
        
        return self.out_layer(features).squeeze(-1)


    
d = DummyData()
dummy_loader, dummy_matrix = get_dummy_data(d)
emb = RecsysEmbedding(d.n_items, dummy_matrix, embeddings='explicit')

for batch in dummy_loader:
    break

model = AttentionResponseModel(
    RecsysEmbedding(d.n_items, dummy_matrix, embeddings='neural'),
    output_dim=1
).to('cpu')

model(batch)


3it [00:00, 3044.50it/s]

biulding affinity matrix...





tensor([[[ 0.1783,  0.1444,  0.1578],
         [ 0.1012,  0.0908,  0.1225]],

        [[-0.2847, -0.2641, -0.2120],
         [-0.0057,  0.0075,  0.0132]]], grad_fn=<SqueezeBackward1>)

# Игрушечный датасет: проверим, что сходится к идеальным метрикам

In [4]:
d = DummyData()
dummy_loader, dummy_matrix = get_dummy_data(d)

for embeddings in ['explicit', 'svd', 'neural']:
    print(f"\nEvaluating {experiment_name} with {embeddings} embeddings")
    model = AttentionResponseModel(
        RecsysEmbedding(d.n_items, dummy_matrix, embeddings='neural').to('cpu'),
        output_dim=1
    ).to('cpu')
    _, metrics = train(model, 
       dummy_loader, dummy_loader, dummy_loader, device=device, lr=1e-3, num_epochs=5000, dummy=True,
       silent=True,
#        debug=True,
    
    )


3it [00:00, 3583.85it/s]

biulding affinity matrix...

Evaluating InSlateAttentionSequencewiseGRU with explicit embeddings



train... loss:0.7021339535713196:   0%|                                                                                      | 0/5000 [00:00<?, ?it/s]

Test before learning: {'f1': 0.4000000059604645, 'roc-auc': 0.6666666269302368, 'accuracy': 0.25}


train... loss:0.670462965965271:   0%|                                                                               | 1/5000 [00:00<31:28,  2.65it/s]

Test update: epoch: 0 |accuracy: 0.75 | f1: 0.6666666865348816 | auc: 0.6666666269302368 | treshold: 0.52


train... loss:0.6418819427490234:   0%|                                                                              | 2/5000 [00:00<31:28,  2.65it/s]

Test update: epoch: 1 |accuracy: 0.75 | f1: 0.6666666865348816 | auc: 1.0 | treshold: 0.48000000000000004


train... loss:0.5694424510002136:   0%|                                                                              | 5/5000 [00:01<32:22,  2.57it/s]

Test update: epoch: 5 |accuracy: 1.0 | f1: 1.0 | auc: 1.0 | treshold: 0.5

Evaluating InSlateAttentionSequencewiseGRU with svd embeddings



train... loss:0.7021787762641907:   0%|                                                                                      | 0/5000 [00:00<?, ?it/s]

Test before learning: {'f1': 0.0, 'roc-auc': 0.6666666269302368, 'accuracy': 0.75}


train... loss:0.6731473803520203:   0%|                                                                              | 1/5000 [00:00<32:16,  2.58it/s]

Test update: epoch: 0 |accuracy: 0.75 | f1: 0.6666666865348816 | auc: 0.6666666269302368 | treshold: 0.47000000000000003


train... loss:0.6474711894989014:   0%|                                                                              | 2/5000 [00:00<32:17,  2.58it/s]

Test update: epoch: 1 |accuracy: 0.75 | f1: 0.6666666865348816 | auc: 1.0 | treshold: 0.43


train... loss:0.6232665777206421:   0%|                                                                              | 3/5000 [00:01<39:48,  2.09it/s]

Test update: epoch: 3 |accuracy: 1.0 | f1: 1.0 | auc: 1.0 | treshold: 0.53

Evaluating InSlateAttentionSequencewiseGRU with neural embeddings



train... loss:0.6780272126197815:   0%|                                                                                      | 0/5000 [00:00<?, ?it/s]

Test before learning: {'f1': 0.0, 'roc-auc': 1.0, 'accuracy': 0.75}


train... loss:0.647767186164856:   0%|                                                                               | 1/5000 [00:00<30:33,  2.73it/s]

Test update: epoch: 0 |accuracy: 0.75 | f1: 0.6666666865348816 | auc: 1.0 | treshold: 0.46


train... loss:0.5346876382827759:   0%|                                                                              | 6/5000 [00:02<29:21,  2.84it/s]

Test update: epoch: 6 |accuracy: 1.0 | f1: 1.0 | auc: 1.0 | treshold: 0.56





# ContentWise

In [5]:
# c = ContentWise('../data/CW/data/ContentWiseImpressions/CW10M-CSV/')

In [6]:
# c.dump(os.path.join(pkl_path, 'cw.pkl'))

In [7]:
content_wise_results = []
c = ContentWise.load(os.path.join(pkl_path, 'cw.pkl'))
c_train_loader, c_val_loader, c_test_loader, c_train_user_item_matrix, train_num_items = get_train_val_test_tmatrix_tnumitems(c, batch_size=50)
len(c_train_loader), len(c)

(324, 20216)

In [8]:
for embeddings in ['neural', 'svd']:
    print(f"\nEvaluating {experiment_name} with {embeddings} embeddings")
    model = AttentionResponseModel(
        RecsysEmbedding(train_num_items, c_train_user_item_matrix, embeddings='neural'),
        output_dim=1
    ).to(device)

    _, metrics = train(model, 
       c_train_loader, c_val_loader, c_test_loader, device=device, lr=1e-3, num_epochs=5000, early_stopping=7,
       silent=True, 
    )
    
    metrics['embeddings'] = embeddings
    content_wise_results.append(metrics)
    
pd.DataFrame(content_wise_results).to_csv(f'results/cw_{experiment_name}.csv')


Evaluating InSlateAttentionSequencewiseGRU with neural embeddings


train:   0%|                                                                                                                 | 0/5000 [00:00<?, ?it/s]

Test before learning: {'f1': 0.16915073990821838, 'roc-auc': 0.5131357908248901, 'accuracy': 0.5203292369842529}


train... loss:105.64544066786766:   0%|                                                                         | 1/5000 [01:43<144:07:24, 103.79s/it]

Test update: epoch: 0 |accuracy: 0.7411112785339355 | f1: 0.2638704478740692 | auc: 0.6752206087112427 | treshold: 0.12


train... loss:95.67465278506279:   0%|                                                                           | 2/5000 [03:07<135:35:26, 97.66s/it]

Test update: epoch: 1 |accuracy: 0.7134921550750732 | f1: 0.2818922698497772 | auc: 0.7126098871231079 | treshold: 0.12


train... loss:92.90769630670547:   0%|                                                                           | 3/5000 [04:33<130:52:49, 94.29s/it]

Test update: epoch: 2 |accuracy: 0.8896945714950562 | f1: 0.2246527373790741 | auc: 0.7288975715637207 | treshold: 0.14


train... loss:90.726393699646:   0%|                                                                             | 4/5000 [05:59<127:20:24, 91.76s/it]

Test update: epoch: 3 |accuracy: 0.8102136254310608 | f1: 0.31133192777633667 | auc: 0.731738805770874 | treshold: 0.16


train... loss:90.36211867630482:   0%|                                                                           | 5/5000 [07:25<125:04:54, 90.15s/it]

Test update: epoch: 4 |accuracy: 0.8015155792236328 | f1: 0.31500375270843506 | auc: 0.7366188764572144 | treshold: 0.15000000000000002


train... loss:88.36872591078281:   0%|                                                                           | 7/5000 [10:18<122:28:11, 88.30s/it]

Test update: epoch: 6 |accuracy: 0.7844305634498596 | f1: 0.3195481300354004 | auc: 0.7458364963531494 | treshold: 0.17


train... loss:85.55649302899837:   0%|▏                                                                         | 11/5000 [15:53<118:26:43, 85.47s/it]

Test update: epoch: 10 |accuracy: 0.8403379917144775 | f1: 0.34169501066207886 | auc: 0.7534766793251038 | treshold: 0.18000000000000002


train... loss:82.93851664662361:   0%|▎                                                                         | 19/5000 [27:00<116:31:56, 84.22s/it]

Test update: epoch: 18 |accuracy: 0.8756749033927917 | f1: 0.3189567029476166 | auc: 0.7501835823059082 | treshold: 0.2


train... loss:78.93031565845013:   1%|▉                                                                       | 63/5000 [1:31:12<119:07:18, 86.86s/it]



Evaluating InSlateAttentionSequencewiseGRU with svd embeddings


train:   0%|                                                                                                                 | 0/5000 [00:00<?, ?it/s]

Test before learning: {'f1': 0.16925175487995148, 'roc-auc': 0.514764666557312, 'accuracy': 0.5357959866523743}


train... loss:106.7204330265522:   0%|                                                                           | 1/5000 [01:28<122:25:13, 88.16s/it]

Test update: epoch: 0 |accuracy: 0.6936063766479492 | f1: 0.25708356499671936 | auc: 0.6774395704269409 | treshold: 0.11


train... loss:95.10113735496998:   0%|                                                                           | 2/5000 [02:58<123:10:29, 88.72s/it]

Test update: epoch: 1 |accuracy: 0.8265828490257263 | f1: 0.2555607557296753 | auc: 0.6798020601272583 | treshold: 0.14


train... loss:92.13206408917904:   0%|                                                                           | 3/5000 [04:28<123:55:09, 89.28s/it]

Test update: epoch: 2 |accuracy: 0.8102758526802063 | f1: 0.3104676902294159 | auc: 0.7328310608863831 | treshold: 0.16


train... loss:89.36901257932186:   0%|                                                                           | 5/5000 [07:24<123:00:22, 88.65s/it]

Test update: epoch: 4 |accuracy: 0.8605971932411194 | f1: 0.31089916825294495 | auc: 0.7409325838088989 | treshold: 0.16


train... loss:87.66208964586258:   0%|                                                                           | 7/5000 [10:15<121:12:51, 87.40s/it]

Test update: epoch: 6 |accuracy: 0.7545396685600281 | f1: 0.31445831060409546 | auc: 0.7453716993331909 | treshold: 0.17


train... loss:85.89792414009571:   0%|                                                                           | 8/5000 [11:44<121:47:58, 87.84s/it]

Test update: epoch: 7 |accuracy: 0.8774176239967346 | f1: 0.3114840090274811 | auc: 0.7526431679725647 | treshold: 0.18000000000000002


train... loss:85.33268457651138:   0%|▏                                                                         | 11/5000 [16:08<122:06:58, 88.12s/it]

Test update: epoch: 10 |accuracy: 0.8279832601547241 | f1: 0.33439701795578003 | auc: 0.7530032396316528 | treshold: 0.19


train... loss:79.79979886114597:   1%|▌                                                                         | 36/5000 [54:02<124:11:55, 90.07s/it]


# RL4RS

In [9]:
# r = RL4RS('../data/rl4rs-dataset/', which='rl4rs_dataset_b_rl.csv', min_session_len=3)
# r.dump(os.path.join(pkl_path, 'rl4rs.pkl'))

In [10]:
rl4rs_results = []
r = RL4RS.load(os.path.join(pkl_path, 'rl4rs.pkl'))
r_train_loader, r_val_loader, r_test_loader, r_train_user_item_matrix, train_num_items = get_train_val_test_tmatrix_tnumitems(r, batch_size=20000)
len(r_train_loader), len(r)

(2, 45942)

In [11]:
for embeddings in ['neural','explicit', 'svd',  ]:
    print(f"\nEvaluating {experiment_name} with {embeddings} embeddings")

    model = AttentionResponseModel(
        RecsysEmbedding(train_num_items, r_train_user_item_matrix, embeddings='neural'),
        output_dim=1
    ).to(device)

    _, metrics = train(model, 
       r_train_loader, r_val_loader, r_test_loader, device=device, lr=1e-3, num_epochs=5000, early_stopping=7,
       silent=True, 
    )
    
    metrics['embeddings'] = embeddings
    rl4rs_results.append(metrics)
    
pd.DataFrame(rl4rs_results).to_csv(f'results/rl4rs_{experiment_name}.csv')


Evaluating InSlateAttentionSequencewiseGRU with neural embeddings


train:   0%|                                                                                                                 | 0/5000 [00:00<?, ?it/s]

Test before learning: {'f1': 0.4037405550479889, 'roc-auc': 0.509438693523407, 'accuracy': 0.4434046745300293}


train... loss:1.4017160534858704:   0%|                                                                           | 1/5000 [00:56<77:57:09, 56.14s/it]

Test update: epoch: 0 |accuracy: 0.6443960666656494 | f1: 0.7837225794792175 | auc: 0.5602662563323975 | treshold: 0.38


train... loss:1.3727481961250305:   0%|                                                                           | 2/5000 [01:12<61:26:56, 44.26s/it]

Test update: epoch: 1 |accuracy: 0.6442751884460449 | f1: 0.7836076617240906 | auc: 0.6006720066070557 | treshold: 0.4


train... loss:1.346182942390442:   0%|                                                                            | 3/5000 [01:29<50:11:14, 36.16s/it]

Test update: epoch: 2 |accuracy: 0.6447829604148865 | f1: 0.7836588025093079 | auc: 0.6308712363243103 | treshold: 0.42000000000000004


train... loss:1.3211457133293152:   0%|                                                                           | 4/5000 [01:44<41:19:37, 29.78s/it]

Test update: epoch: 3 |accuracy: 0.6459195017814636 | f1: 0.7840297222137451 | auc: 0.6525312662124634 | treshold: 0.43


train... loss:1.296710193157196:   0%|                                                                            | 5/5000 [01:58<34:29:51, 24.86s/it]

Test update: epoch: 4 |accuracy: 0.6486760973930359 | f1: 0.784528911113739 | auc: 0.6678309440612793 | treshold: 0.45


train... loss:1.273995578289032:   0%|                                                                            | 6/5000 [02:15<31:14:35, 22.52s/it]

Test update: epoch: 5 |accuracy: 0.6507314443588257 | f1: 0.7850531339645386 | auc: 0.6788144111633301 | treshold: 0.46


train... loss:1.2537260055541992:   0%|                                                                           | 7/5000 [02:35<30:09:01, 21.74s/it]

Test update: epoch: 6 |accuracy: 0.6555918455123901 | f1: 0.7860415577888489 | auc: 0.6872224807739258 | treshold: 0.48000000000000004


train... loss:1.2366291880607605:   0%|                                                                           | 8/5000 [02:53<28:46:42, 20.75s/it]

Test update: epoch: 7 |accuracy: 0.6567525267601013 | f1: 0.7867369651794434 | auc: 0.6943309903144836 | treshold: 0.48000000000000004


train... loss:1.2252127528190613:   0%|▏                                                                          | 9/5000 [03:09<26:37:08, 19.20s/it]

Test update: epoch: 8 |accuracy: 0.6603071093559265 | f1: 0.7877496480941772 | auc: 0.7010986804962158 | treshold: 0.49


train... loss:1.2204594016075134:   0%|▏                                                                         | 10/5000 [03:24<24:47:55, 17.89s/it]

Test update: epoch: 9 |accuracy: 0.6646354794502258 | f1: 0.7891127467155457 | auc: 0.7081382870674133 | treshold: 0.5


train... loss:1.2211987972259521:   0%|▏                                                                         | 11/5000 [03:43<25:21:16, 18.30s/it]

Test update: epoch: 10 |accuracy: 0.6742110848426819 | f1: 0.7911907434463501 | auc: 0.7157353758811951 | treshold: 0.53


train... loss:1.2257484197616577:   0%|▏                                                                         | 12/5000 [04:02<25:55:26, 18.71s/it]

Test update: epoch: 11 |accuracy: 0.6774755120277405 | f1: 0.7928109765052795 | auc: 0.7238761782646179 | treshold: 0.52


train... loss:1.2286036610603333:   0%|▏                                                                         | 13/5000 [04:23<26:44:27, 19.30s/it]

Test update: epoch: 12 |accuracy: 0.687728226184845 | f1: 0.7957550287246704 | auc: 0.7322457432746887 | treshold: 0.53


train... loss:1.222460150718689:   0%|▏                                                                          | 14/5000 [04:44<27:16:09, 19.69s/it]

Test update: epoch: 13 |accuracy: 0.6991657614707947 | f1: 0.7984446883201599 | auc: 0.7402814626693726 | treshold: 0.54


train... loss:1.2069849967956543:   0%|▏                                                                         | 15/5000 [05:04<27:34:25, 19.91s/it]

Test update: epoch: 14 |accuracy: 0.705501139163971 | f1: 0.8006612658500671 | auc: 0.7473258376121521 | treshold: 0.52


train... loss:1.1864922046661377:   0%|▏                                                                         | 16/5000 [05:26<28:09:29, 20.34s/it]

Test update: epoch: 15 |accuracy: 0.7070487141609192 | f1: 0.8014552593231201 | auc: 0.7530964612960815 | treshold: 0.48000000000000004


train... loss:1.165749967098236:   0%|▎                                                                          | 17/5000 [05:47<28:32:01, 20.61s/it]

Test update: epoch: 16 |accuracy: 0.7079917788505554 | f1: 0.8017793297767639 | auc: 0.7575833797454834 | treshold: 0.44


train... loss:1.1493234038352966:   0%|▎                                                                         | 18/5000 [06:07<28:17:43, 20.45s/it]

Test update: epoch: 17 |accuracy: 0.7091766595840454 | f1: 0.8026062846183777 | auc: 0.7608678340911865 | treshold: 0.4


train... loss:1.1415225863456726:   0%|▎                                                                         | 19/5000 [06:25<27:22:53, 19.79s/it]

Test update: epoch: 18 |accuracy: 0.7103614807128906 | f1: 0.8036457896232605 | auc: 0.7631769180297852 | treshold: 0.36000000000000004


train... loss:1.1439074873924255:   0%|▎                                                                         | 20/5000 [06:42<26:09:23, 18.91s/it]

Test update: epoch: 19 |accuracy: 0.7118849158287048 | f1: 0.8048161268234253 | auc: 0.7648658752441406 | treshold: 0.33


train... loss:1.154037892818451:   0%|▎                                                                          | 21/5000 [06:57<24:37:26, 17.80s/it]

Test update: epoch: 20 |accuracy: 0.7152944207191467 | f1: 0.8060551881790161 | auc: 0.7664119005203247 | treshold: 0.32


train... loss:1.168085515499115:   0%|▎                                                                          | 22/5000 [07:14<24:03:08, 17.39s/it]

Test update: epoch: 21 |accuracy: 0.7153669595718384 | f1: 0.807647705078125 | auc: 0.7685801982879639 | treshold: 0.3


train... loss:1.1787084341049194:   0%|▎                                                                         | 23/5000 [07:32<24:23:01, 17.64s/it]

Test update: epoch: 22 |accuracy: 0.7139644622802734 | f1: 0.8084651827812195 | auc: 0.7721337080001831 | treshold: 0.29000000000000004


train... loss:1.180545449256897:   0%|▎                                                                          | 24/5000 [07:51<24:50:17, 17.97s/it]

Test update: epoch: 23 |accuracy: 0.7170596122741699 | f1: 0.8096315264701843 | auc: 0.7772375345230103 | treshold: 0.31


train... loss:1.173385202884674:   0%|▍                                                                          | 25/5000 [08:10<25:18:50, 18.32s/it]

Test update: epoch: 24 |accuracy: 0.7191874980926514 | f1: 0.8109893798828125 | auc: 0.7834662795066833 | treshold: 0.33


train... loss:1.1592680215835571:   0%|▎                                                                         | 25/5000 [08:31<28:16:36, 20.46s/it]

Test update: epoch: 25 |accuracy: 0.7202998399734497 | f1: 0.8121477961540222 | auc: 0.790030837059021 | treshold: 0.35000000000000003

Evaluating InSlateAttentionSequencewiseGRU with explicit embeddings



train:   0%|                                                                                                                 | 0/5000 [00:00<?, ?it/s]

Test before learning: {'f1': 0.6852515339851379, 'roc-auc': 0.5197587013244629, 'accuracy': 0.5771006941795349}


train... loss:1.3693444728851318:   0%|                                                                           | 1/5000 [00:14<19:55:32, 14.35s/it]

Test update: epoch: 0 |accuracy: 0.6476362943649292 | f1: 0.785023033618927 | auc: 0.5871069431304932 | treshold: 0.45


train... loss:1.3472683429718018:   0%|                                                                           | 2/5000 [00:29<20:23:43, 14.69s/it]

Test update: epoch: 1 |accuracy: 0.6604521870613098 | f1: 0.7891718149185181 | auc: 0.6301552057266235 | treshold: 0.48000000000000004


train... loss:1.3272178769111633:   0%|                                                                           | 3/5000 [00:47<21:33:22, 15.53s/it]

Test update: epoch: 2 |accuracy: 0.6691814661026001 | f1: 0.7923881411552429 | auc: 0.6569207310676575 | treshold: 0.49


train... loss:1.3080652952194214:   0%|                                                                           | 4/5000 [01:05<22:43:07, 16.37s/it]

Test update: epoch: 3 |accuracy: 0.6758311986923218 | f1: 0.794619619846344 | auc: 0.6739804744720459 | treshold: 0.5


train... loss:1.289518117904663:   0%|                                                                            | 5/5000 [01:24<23:52:32, 17.21s/it]

Test update: epoch: 4 |accuracy: 0.6805223226547241 | f1: 0.7959284782409668 | auc: 0.68519526720047 | treshold: 0.51


train... loss:1.2720388770103455:   0%|                                                                           | 6/5000 [01:43<24:39:42, 17.78s/it]

Test update: epoch: 5 |accuracy: 0.684753954410553 | f1: 0.7970231771469116 | auc: 0.6933290958404541 | treshold: 0.52


train... loss:1.2561444640159607:   0%|                                                                           | 7/5000 [02:03<25:32:48, 18.42s/it]

Test update: epoch: 6 |accuracy: 0.6845121383666992 | f1: 0.7972399592399597 | auc: 0.7001067399978638 | treshold: 0.52


train... loss:1.240820288658142:   0%|                                                                            | 8/5000 [02:23<26:00:25, 18.76s/it]

Test update: epoch: 7 |accuracy: 0.6887438297271729 | f1: 0.7983898520469666 | auc: 0.7064001560211182 | treshold: 0.53


train... loss:1.2289145588874817:   0%|▏                                                                          | 9/5000 [02:43<26:31:46, 19.14s/it]

Test update: epoch: 8 |accuracy: 0.6899528503417969 | f1: 0.7991163730621338 | auc: 0.7126234769821167 | treshold: 0.53


train... loss:1.220075249671936:   0%|▏                                                                          | 10/5000 [03:05<27:45:14, 20.02s/it]

Test update: epoch: 9 |accuracy: 0.6913553476333618 | f1: 0.7996169328689575 | auc: 0.7190766334533691 | treshold: 0.53


train... loss:1.2146414518356323:   0%|▏                                                                         | 11/5000 [03:26<28:20:44, 20.45s/it]

Test update: epoch: 10 |accuracy: 0.6926853060722351 | f1: 0.7997920513153076 | auc: 0.7260132431983948 | treshold: 0.53


train... loss:1.2113587856292725:   0%|▏                                                                         | 12/5000 [03:48<28:42:37, 20.72s/it]

Test update: epoch: 11 |accuracy: 0.7009551525115967 | f1: 0.8019854426383972 | auc: 0.7335154414176941 | treshold: 0.55


train... loss:1.2098809480667114:   0%|▏                                                                         | 13/5000 [04:09<28:54:44, 20.87s/it]

Test update: epoch: 12 |accuracy: 0.7032523155212402 | f1: 0.8030049800872803 | auc: 0.7412892580032349 | treshold: 0.54


train... loss:1.205241084098816:   0%|▏                                                                          | 14/5000 [04:30<29:03:29, 20.98s/it]

Test update: epoch: 13 |accuracy: 0.7077257633209229 | f1: 0.8046924471855164 | auc: 0.7489036917686462 | treshold: 0.53


train... loss:1.1967811584472656:   0%|▏                                                                         | 15/5000 [04:50<28:24:43, 20.52s/it]

Test update: epoch: 14 |accuracy: 0.7137951850891113 | f1: 0.8045768141746521 | auc: 0.7557705640792847 | treshold: 0.54


train... loss:1.1828409433364868:   0%|▏                                                                         | 16/5000 [05:09<27:43:36, 20.03s/it]

Test update: epoch: 15 |accuracy: 0.7174707055091858 | f1: 0.8066331148147583 | auc: 0.7615669965744019 | treshold: 0.51


train... loss:1.1675438284873962:   0%|▎                                                                         | 17/5000 [05:29<27:40:54, 20.00s/it]

Test update: epoch: 16 |accuracy: 0.7198162078857422 | f1: 0.8082543015480042 | auc: 0.7663269639015198 | treshold: 0.47000000000000003


train... loss:1.1516030430793762:   0%|▎                                                                         | 18/5000 [05:51<28:44:00, 20.76s/it]

Test update: epoch: 17 |accuracy: 0.7223552465438843 | f1: 0.810177206993103 | auc: 0.7702431082725525 | treshold: 0.43


train... loss:1.1390975713729858:   0%|▎                                                                         | 19/5000 [06:06<26:14:32, 18.97s/it]

Test update: epoch: 18 |accuracy: 0.7236368060112 | f1: 0.8115115165710449 | auc: 0.773574709892273 | treshold: 0.39


train... loss:1.1312572956085205:   0%|▎                                                                         | 20/5000 [06:22<25:03:11, 18.11s/it]

Test update: epoch: 19 |accuracy: 0.7262967228889465 | f1: 0.8132763504981995 | auc: 0.7766565680503845 | treshold: 0.36000000000000004


train... loss:1.1289432644844055:   0%|▎                                                                         | 21/5000 [06:41<25:19:37, 18.31s/it]

Test update: epoch: 20 |accuracy: 0.7289807796478271 | f1: 0.814927339553833 | auc: 0.7797917127609253 | treshold: 0.34


train... loss:1.1301833391189575:   0%|▎                                                                         | 22/5000 [07:00<25:46:20, 18.64s/it]

Test update: epoch: 21 |accuracy: 0.7295369505882263 | f1: 0.8163834810256958 | auc: 0.7831375598907471 | treshold: 0.32


train... loss:1.133038878440857:   0%|▎                                                                          | 23/5000 [07:20<26:19:42, 19.04s/it]

Test update: epoch: 22 |accuracy: 0.7316406965255737 | f1: 0.8168042302131653 | auc: 0.7866204977035522 | treshold: 0.33


train... loss:1.1346312761306763:   0%|▎                                                                         | 24/5000 [07:41<26:59:08, 19.52s/it]

Test update: epoch: 23 |accuracy: 0.733792781829834 | f1: 0.8170928359031677 | auc: 0.7899200320243835 | treshold: 0.35000000000000003


train... loss:1.1327191591262817:   0%|▎                                                                         | 25/5000 [08:01<27:20:49, 19.79s/it]

Test update: epoch: 24 |accuracy: 0.7321968078613281 | f1: 0.8170600533485413 | auc: 0.7928102016448975 | treshold: 0.36000000000000004


train... loss:1.1265175342559814:   1%|▍                                                                         | 26/5000 [08:21<27:21:29, 19.80s/it]

Test update: epoch: 25 |accuracy: 0.731785774230957 | f1: 0.8172411322593689 | auc: 0.7953964471817017 | treshold: 0.38


train... loss:1.1165523529052734:   1%|▍                                                                         | 27/5000 [08:44<28:45:25, 20.82s/it]

Test update: epoch: 26 |accuracy: 0.7319550514221191 | f1: 0.8177918195724487 | auc: 0.7981798648834229 | treshold: 0.4


train... loss:1.1042444109916687:   1%|▍                                                                         | 28/5000 [09:05<28:44:09, 20.81s/it]

Test update: epoch: 27 |accuracy: 0.7361141443252563 | f1: 0.8196406960487366 | auc: 0.8014849424362183 | treshold: 0.43


train... loss:1.0916649103164673:   1%|▍                                                                         | 29/5000 [09:26<28:51:40, 20.90s/it]

Test update: epoch: 28 |accuracy: 0.7369121313095093 | f1: 0.8205982446670532 | auc: 0.8053079843521118 | treshold: 0.45


train... loss:1.078747570514679:   1%|▍                                                                          | 30/5000 [09:47<28:41:48, 20.79s/it]

Test update: epoch: 29 |accuracy: 0.7412646412849426 | f1: 0.8222945332527161 | auc: 0.80943763256073 | treshold: 0.48000000000000004


train... loss:1.0687715411186218:   1%|▍                                                                         | 31/5000 [10:06<28:08:29, 20.39s/it]

Test update: epoch: 30 |accuracy: 0.7431991100311279 | f1: 0.8235294222831726 | auc: 0.8135445713996887 | treshold: 0.5


train... loss:1.061814308166504:   1%|▍                                                                          | 32/5000 [10:25<27:42:14, 20.08s/it]

Test update: epoch: 31 |accuracy: 0.747551679611206 | f1: 0.8250377178192139 | auc: 0.8173600435256958 | treshold: 0.53


train... loss:1.0591154098510742:   1%|▍                                                                         | 33/5000 [10:45<27:41:10, 20.07s/it]

Test update: epoch: 32 |accuracy: 0.7518317103385925 | f1: 0.8262951970100403 | auc: 0.820708155632019 | treshold: 0.56


train... loss:1.0607154965400696:   1%|▌                                                                         | 34/5000 [11:05<27:33:02, 19.97s/it]

Test update: epoch: 33 |accuracy: 0.7543706893920898 | f1: 0.827186107635498 | auc: 0.8235565423965454 | treshold: 0.5800000000000001


train... loss:1.0664087533950806:   1%|▌                                                                         | 35/5000 [11:25<27:27:19, 19.91s/it]

Test update: epoch: 34 |accuracy: 0.756716251373291 | f1: 0.8283779621124268 | auc: 0.8259400725364685 | treshold: 0.59


train... loss:1.07179456949234:   1%|▌                                                                           | 36/5000 [11:47<28:31:01, 20.68s/it]

Test update: epoch: 35 |accuracy: 0.759932279586792 | f1: 0.828561544418335 | auc: 0.8279561400413513 | treshold: 0.61


train... loss:1.0749653577804565:   1%|▌                                                                         | 37/5000 [12:09<28:52:48, 20.95s/it]

Test update: epoch: 36 |accuracy: 0.7611896991729736 | f1: 0.8295595645904541 | auc: 0.8296947479248047 | treshold: 0.6


train... loss:1.0734723210334778:   1%|▌                                                                         | 38/5000 [12:30<29:04:54, 21.10s/it]

Test update: epoch: 37 |accuracy: 0.762108564376831 | f1: 0.8294560313224792 | auc: 0.831250786781311 | treshold: 0.59


train... loss:1.0639789700508118:   1%|▌                                                                         | 39/5000 [12:52<29:23:45, 21.33s/it]

Test update: epoch: 38 |accuracy: 0.7616974711418152 | f1: 0.8295425176620483 | auc: 0.8327465057373047 | treshold: 0.56


train... loss:1.050272822380066:   1%|▌                                                                          | 40/5000 [13:14<29:20:56, 21.30s/it]

Test update: epoch: 39 |accuracy: 0.7632450461387634 | f1: 0.83006751537323 | auc: 0.8342562913894653 | treshold: 0.54


train... loss:1.0337945818901062:   1%|▌                                                                         | 41/5000 [13:35<29:35:28, 21.48s/it]

Test update: epoch: 40 |accuracy: 0.7639704942703247 | f1: 0.8305293917655945 | auc: 0.8358326554298401 | treshold: 0.51


train... loss:1.01697838306427:   1%|▋                                                                           | 42/5000 [13:58<29:51:05, 21.68s/it]

Test update: epoch: 41 |accuracy: 0.7657840847969055 | f1: 0.8310777544975281 | auc: 0.8375252485275269 | treshold: 0.49


train... loss:1.0030112266540527:   1%|▋                                                                         | 43/5000 [14:21<30:25:41, 22.10s/it]

Test update: epoch: 42 |accuracy: 0.7675250768661499 | f1: 0.831622838973999 | auc: 0.8393309116363525 | treshold: 0.47000000000000003


train... loss:0.9930444359779358:   1%|▋                                                                         | 44/5000 [14:41<29:38:36, 21.53s/it]

Test update: epoch: 43 |accuracy: 0.767307460308075 | f1: 0.8328150510787964 | auc: 0.8412758111953735 | treshold: 0.43


train... loss:0.9883936047554016:   1%|▋                                                                         | 45/5000 [15:00<28:25:24, 20.65s/it]

Test update: epoch: 44 |accuracy: 0.7706202268600464 | f1: 0.8342767357826233 | auc: 0.8433526754379272 | treshold: 0.42000000000000004


train... loss:0.989894449710846:   1%|▋                                                                          | 46/5000 [15:16<26:43:16, 19.42s/it]

Test update: epoch: 45 |accuracy: 0.7727481722831726 | f1: 0.8362318277359009 | auc: 0.8455653190612793 | treshold: 0.4


train... loss:0.9962213039398193:   1%|▋                                                                         | 47/5000 [15:31<24:57:21, 18.14s/it]

Test update: epoch: 46 |accuracy: 0.7749002575874329 | f1: 0.8377572894096375 | auc: 0.8477848768234253 | treshold: 0.39


train... loss:1.0058960914611816:   1%|▋                                                                         | 48/5000 [15:50<25:03:37, 18.22s/it]

Test update: epoch: 47 |accuracy: 0.7787933945655823 | f1: 0.8398290872573853 | auc: 0.8499234914779663 | treshold: 0.39


train... loss:1.0154248476028442:   1%|▋                                                                         | 49/5000 [16:10<25:44:58, 18.72s/it]

Test update: epoch: 48 |accuracy: 0.7805343866348267 | f1: 0.8420740962028503 | auc: 0.8520216345787048 | treshold: 0.38


train... loss:1.020628273487091:   1%|▊                                                                          | 50/5000 [16:29<26:06:19, 18.99s/it]

Test update: epoch: 49 |accuracy: 0.7835327982902527 | f1: 0.8444536924362183 | auc: 0.8542534112930298 | treshold: 0.38


train... loss:1.019224226474762:   1%|▊                                                                          | 51/5000 [16:47<25:41:32, 18.69s/it]

Test update: epoch: 50 |accuracy: 0.786047637462616 | f1: 0.8457354307174683 | auc: 0.8565951585769653 | treshold: 0.39


train... loss:1.0112647414207458:   1%|▊                                                                         | 51/5000 [17:01<27:32:20, 20.03s/it]

Test update: epoch: 51 |accuracy: 0.7865070700645447 | f1: 0.8469127416610718 | auc: 0.8589910268783569 | treshold: 0.39

Evaluating InSlateAttentionSequencewiseGRU with svd embeddings



train:   0%|                                                                                                                 | 0/5000 [00:00<?, ?it/s]

Test before learning: {'f1': 0.6479829549789429, 'roc-auc': 0.5475693345069885, 'accuracy': 0.5611171722412109}


train... loss:1.3676549196243286:   0%|                                                                           | 1/5000 [00:18<25:20:03, 18.24s/it]

Test update: epoch: 0 |accuracy: 0.6454358696937561 | f1: 0.7840405106544495 | auc: 0.5811114311218262 | treshold: 0.44


train... loss:1.3445869088172913:   0%|                                                                           | 2/5000 [00:35<25:03:30, 18.05s/it]

Test update: epoch: 1 |accuracy: 0.646088719367981 | f1: 0.7842252850532532 | auc: 0.6024339199066162 | treshold: 0.45


train... loss:1.32430100440979:   0%|                                                                             | 3/5000 [00:53<24:56:07, 17.96s/it]

Test update: epoch: 2 |accuracy: 0.6478297710418701 | f1: 0.7844796776771545 | auc: 0.617066502571106 | treshold: 0.47000000000000003


train... loss:1.3060161471366882:   0%|                                                                           | 4/5000 [01:11<24:44:37, 17.83s/it]

Test update: epoch: 3 |accuracy: 0.6479748487472534 | f1: 0.7845621109008789 | auc: 0.628209114074707 | treshold: 0.47000000000000003


train... loss:1.2895804047584534:   0%|                                                                           | 5/5000 [01:28<24:34:48, 17.72s/it]

Test update: epoch: 4 |accuracy: 0.6499576568603516 | f1: 0.785075843334198 | auc: 0.6378164291381836 | treshold: 0.48000000000000004


train... loss:1.274899959564209:   0%|                                                                            | 6/5000 [01:45<24:25:35, 17.61s/it]

Test update: epoch: 5 |accuracy: 0.6520130634307861 | f1: 0.7854810953140259 | auc: 0.6474413871765137 | treshold: 0.49


train... loss:1.2616733312606812:   0%|                                                                           | 7/5000 [02:03<24:33:24, 17.71s/it]

Test update: epoch: 6 |accuracy: 0.6541893482208252 | f1: 0.7857753038406372 | auc: 0.6581965088844299 | treshold: 0.5


train... loss:1.250448226928711:   0%|                                                                            | 8/5000 [02:20<24:05:03, 17.37s/it]

Test update: epoch: 7 |accuracy: 0.6570910215377808 | f1: 0.7861182689666748 | auc: 0.670337975025177 | treshold: 0.51


train... loss:1.2410127520561218:   0%|▏                                                                          | 9/5000 [02:37<23:58:05, 17.29s/it]

Test update: epoch: 8 |accuracy: 0.6632813215255737 | f1: 0.7867501974105835 | auc: 0.6835139393806458 | treshold: 0.53


train... loss:1.233081042766571:   0%|▏                                                                          | 10/5000 [02:50<22:19:15, 16.10s/it]

Test update: epoch: 9 |accuracy: 0.6663281321525574 | f1: 0.7886797785758972 | auc: 0.6968975067138672 | treshold: 0.52


train... loss:1.2249184846878052:   0%|▏                                                                         | 11/5000 [02:57<18:17:48, 13.20s/it]

Test update: epoch: 10 |accuracy: 0.6743320226669312 | f1: 0.7903878688812256 | auc: 0.7096501588821411 | treshold: 0.53


train... loss:1.215776801109314:   0%|▏                                                                          | 12/5000 [03:03<15:27:45, 11.16s/it]

Test update: epoch: 11 |accuracy: 0.6807882785797119 | f1: 0.792694628238678 | auc: 0.7210733294487 | treshold: 0.52


train... loss:1.2042565941810608:   0%|▏                                                                         | 13/5000 [03:09<13:23:38,  9.67s/it]

Test update: epoch: 12 |accuracy: 0.6888889074325562 | f1: 0.7950098514556885 | auc: 0.7308028340339661 | treshold: 0.51


train... loss:1.1892454624176025:   0%|▏                                                                         | 14/5000 [03:16<11:57:27,  8.63s/it]

Test update: epoch: 13 |accuracy: 0.698125958442688 | f1: 0.7985021471977234 | auc: 0.7387441396713257 | treshold: 0.49


train... loss:1.1749948263168335:   0%|▏                                                                         | 15/5000 [03:22<10:55:11,  7.89s/it]

Test update: epoch: 14 |accuracy: 0.7014145851135254 | f1: 0.8011914491653442 | auc: 0.7449713945388794 | treshold: 0.44


train... loss:1.1621044874191284:   0%|▏                                                                         | 16/5000 [03:28<10:15:51,  7.41s/it]

Test update: epoch: 15 |accuracy: 0.7049933671951294 | f1: 0.802172839641571 | auc: 0.7496806979179382 | treshold: 0.4


train... loss:1.156903862953186:   0%|▎                                                                           | 17/5000 [03:34<9:46:55,  7.07s/it]

Test update: epoch: 16 |accuracy: 0.7085237503051758 | f1: 0.8029297590255737 | auc: 0.7531416416168213 | treshold: 0.36000000000000004


train... loss:1.1612293124198914:   0%|▎                                                                          | 18/5000 [03:40<9:22:51,  6.78s/it]

Test update: epoch: 17 |accuracy: 0.7112078070640564 | f1: 0.8037981986999512 | auc: 0.7557374238967896 | treshold: 0.32


train... loss:1.1746001243591309:   0%|▎                                                                          | 19/5000 [03:47<9:07:06,  6.59s/it]

Test update: epoch: 18 |accuracy: 0.7120541930198669 | f1: 0.8056661486625671 | auc: 0.7581427097320557 | treshold: 0.27


train... loss:1.1916725039482117:   0%|▎                                                                          | 20/5000 [03:53<8:53:49,  6.43s/it]

Test update: epoch: 19 |accuracy: 0.714472234249115 | f1: 0.806780993938446 | auc: 0.7611402273178101 | treshold: 0.25


train... loss:1.2027429938316345:   0%|▎                                                                          | 21/5000 [03:59<8:46:14,  6.34s/it]

Test update: epoch: 20 |accuracy: 0.7124894261360168 | f1: 0.8090941309928894 | auc: 0.7651875019073486 | treshold: 0.21000000000000002


train... loss:1.2009921073913574:   0%|▎                                                                          | 22/5000 [04:05<8:42:02,  6.29s/it]

Test update: epoch: 21 |accuracy: 0.715076744556427 | f1: 0.8103218078613281 | auc: 0.7702997326850891 | treshold: 0.22


train... loss:1.1847588419914246:   0%|▎                                                                          | 23/5000 [04:11<8:37:54,  6.24s/it]

Test update: epoch: 22 |accuracy: 0.7149316668510437 | f1: 0.8109737634658813 | auc: 0.7761297225952148 | treshold: 0.23


train... loss:1.159471035003662:   0%|▎                                                                          | 23/5000 [04:17<15:29:19, 11.20s/it]

Test update: epoch: 23 |accuracy: 0.7169628739356995 | f1: 0.8120675086975098 | auc: 0.7820822596549988 | treshold: 0.26



