In [1]:
import os
import torch
import random
import datetime
import pandas as pd
import numpy as np

from torch.utils.data import Dataset
from src.datasets import RL4RS, ContentWise, DummyData
from src.utils import train, get_dummy_data, get_train_val_test_tmatrix_tnumitems, get_svd_encoder
from src.embeddings import RecsysEmbedding

experiment_name = 'FlattenedGRU'
device = 'cuda:0'
seed = 123
pkl_path = '../pkl/'

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

<torch._C.Generator at 0x7f5550456bd0>

In [2]:
torch.__version__

'1.12.1'

# Модель

In [3]:
import torch.nn.functional as F
torch.autograd.set_detect_anomaly(True)

class NeuralClickModel(torch.nn.Module):
    def __init__(self, embedding, output_dim=1, dropout = 0.1):
        super().__init__()
        self.embedding_dim = embedding.embedding_dim
        self.embedding = embedding
        self.rnn_layer = torch.nn.GRU(
            input_size = embedding.embedding_dim, 
            hidden_size = embedding.embedding_dim, 
            batch_first = True,
            dropout=dropout
        )
        self.out_layer = torch.nn.Linear(embedding.embedding_dim, output_dim)


    def forward(self, batch):
        shp = batch['slates_item_indexes'].shape
        item_embs, user_embs = self.embedding(batch)
        item_embs = item_embs.flatten(-3, -2)
        
        # while training, let out model see the future 
        # while testing, it can see only the 
        if self.training:
            indices = (batch['length'] - 1)
        else:
            indices = (batch['in_length'] - 1)
        
        indices[indices<0] = 0
        indices = indices[:, None, None].repeat(1, 1, user_embs.size(-1))
        user_embs = user_embs.gather(1, indices).squeeze(-2).unsqueeze(0)

#         print(indices.shape, user_embs.shape, item_embs.shape, )
        rnn_out, _ = self.rnn_layer(
            item_embs,
            user_embs,
        )
        return self.out_layer(rnn_out).reshape(shp)

# Игрушечный датасет: проверим, что сходится к идеальным метрикам

In [4]:
d = DummyData()
dummy_loader, dummy_matrix = get_dummy_data(d)

model = NeuralClickModel(
    RecsysEmbedding(d.n_items, dummy_matrix, embeddings='neural').to('cpu'),
    output_dim=1
).to('cpu')

train(
    model, 
    dummy_loader, dummy_loader, dummy_loader,
    device=device, lr=1e-3, num_epochs=5000, dummy=True,
    silent=True,
)


biulding affinity matrix...


3it [00:00, 2683.50it/s]


Test before learning: {'f1': 0.4000000059604645, 'roc-auc': 0.3333333134651184, 'accuracy': 0.25}


train... loss:0.7189432978630066:   0%|                                                                                                    | 1/5000 [00:00<29:11,  2.85it/s]

Val update: epoch: 0 |accuracy: 0.5 | f1: 0.5 | auc: 0.3333333134651184 | treshold: 0.51
Test: accuracy: 0.5 | f1: 0.5 | auc: 0.3333333134651184 | 


train... loss:0.6976741552352905:   0%|                                                                                                    | 3/5000 [00:01<30:55,  2.69it/s]

Val update: epoch: 3 |accuracy: 1.0 | f1: 1.0 | auc: 1.0 | treshold: 0.54
Test: accuracy: 1.0 | f1: 1.0 | auc: 1.0 | 





(NeuralClickModel(
   (embedding): RecsysEmbedding(
     (item_embeddings): Embedding(5, 32)
   )
   (rnn_layer): GRU(32, 32, batch_first=True, dropout=0.1)
   (out_layer): Linear(in_features=32, out_features=1, bias=True)
 ),
 {'f1': 1.0, 'roc-auc': 1.0, 'accuracy': 1.0})

# ContentWise

In [5]:
content_wise_results = []
dataset = ContentWise.load(os.path.join(pkl_path, 'cw.pkl'))
(
    train_loader, 
    val_loader, 
    test_loader, 
    train_user_item_matrix, 
    train_num_items 
) = get_train_val_test_tmatrix_tnumitems(dataset, batch_size=150)

print(f"{len(dataset)} data points among {len(train_loader)} batches")

20216 data points among 108 batches


In [6]:
for embeddings in ['svd', 'neural']:
    print(f"\nEvaluating {experiment_name} with {embeddings} embeddings")
    
    model = NeuralClickModel(
        RecsysEmbedding(train_num_items, train_user_item_matrix, embeddings=embeddings),
        output_dim=1
    ).to(device)

    _, metrics = train(
        model, 
        train_loader, val_loader, test_loader, 
        device=device, lr=1e-3, num_epochs=5000, early_stopping=7,
       silent=True, 
    )
    
    metrics['embeddings'] = embeddings
    content_wise_results.append(metrics)


Evaluating FlattenedGRU with svd embeddings




Test before learning: {'f1': 0.18066997826099396, 'roc-auc': 0.5031406283378601, 'accuracy': 0.10415486246347427}


train... loss:47.45200476050377:   0%|                                                                                                  | 1/5000 [00:50<69:43:38, 50.21s/it]

Val update: epoch: 0 |accuracy: 0.10144636034965515 | f1: 0.18420572578907013 | auc: 0.5617542266845703 | treshold: 0.01
Test: accuracy: 0.09924836456775665 | f1: 0.18057496845722198 | auc: 0.5637004375457764 | 


train... loss:33.96879383921623:   0%|                                                                                                  | 2/5000 [01:08<43:40:45, 31.46s/it]

Val update: epoch: 1 |accuracy: 0.10144636034965515 | f1: 0.18420572578907013 | auc: 0.5710059404373169 | treshold: 0.060000000000000005
Test: accuracy: 0.09924836456775665 | f1: 0.18057496845722198 | auc: 0.5705729126930237 | 


train... loss:33.519731253385544:   0%|                                                                                                 | 3/5000 [01:26<35:05:23, 25.28s/it]

Val update: epoch: 2 |accuracy: 0.8381648659706116 | f1: 0.22174333035945892 | auc: 0.5896240472793579 | treshold: 0.06999999999999999
Test: accuracy: 0.8454231023788452 | f1: 0.22342099249362946 | auc: 0.5852904319763184 | 


train... loss:33.00463917851448:   0%|                                                                                                  | 4/5000 [01:44<31:12:57, 22.49s/it]

Val update: epoch: 3 |accuracy: 0.8444540500640869 | f1: 0.23405873775482178 | auc: 0.6137709617614746 | treshold: 0.13
Test: accuracy: 0.8503146767616272 | f1: 0.2286943793296814 | auc: 0.6124370098114014 | 


train... loss:32.77262130379677:   0%|                                                                                                  | 5/5000 [02:02<29:04:24, 20.95s/it]

Val update: epoch: 4 |accuracy: 0.7602682709693909 | f1: 0.24351473152637482 | auc: 0.626449465751648 | treshold: 0.12
Test: accuracy: 0.7713484764099121 | f1: 0.23660625517368317 | auc: 0.6238862872123718 | 


train... loss:32.62635996937752:   0%|                                                                                                  | 6/5000 [02:21<27:52:04, 20.09s/it]

Val update: epoch: 5 |accuracy: 0.8053898811340332 | f1: 0.24141669273376465 | auc: 0.6366562843322754 | treshold: 0.12
Test: accuracy: 0.814984917640686 | f1: 0.23917576670646667 | auc: 0.6320788264274597 | 


train... loss:32.4938767850399:   0%|▏                                                                                                  | 7/5000 [02:39<26:58:28, 19.45s/it]

Val update: epoch: 6 |accuracy: 0.8384429812431335 | f1: 0.2384733110666275 | auc: 0.6414768695831299 | treshold: 0.12
Test: accuracy: 0.8452292084693909 | f1: 0.23578792810440063 | auc: 0.6352653503417969 | 


train... loss:32.41746339201927:   0%|▏                                                                                                 | 8/5000 [02:57<26:27:09, 19.08s/it]

Val update: epoch: 7 |accuracy: 0.8548691272735596 | f1: 0.22851979732513428 | auc: 0.6483805775642395 | treshold: 0.12
Test: accuracy: 0.8616637587547302 | f1: 0.23110079765319824 | auc: 0.6453441381454468 | 


train... loss:32.3449187874794:   0%|▏                                                                                                 | 10/5000 [03:32<25:27:46, 18.37s/it]

Val update: epoch: 9 |accuracy: 0.8683128952980042 | f1: 0.22414420545101166 | auc: 0.6489580869674683 | treshold: 0.12
Test: accuracy: 0.871924102306366 | f1: 0.2154211550951004 | auc: 0.6482113599777222 | 


train... loss:32.17203137278557:   0%|▎                                                                                                | 13/5000 [04:25<24:48:05, 17.90s/it]

Val update: epoch: 12 |accuracy: 0.838612973690033 | f1: 0.24175983667373657 | auc: 0.6524484753608704 | treshold: 0.12
Test: accuracy: 0.8462880849838257 | f1: 0.23883022367954254 | auc: 0.6459234952926636 | 


train... loss:32.140176087617874:   0%|▎                                                                                               | 15/5000 [05:00<24:35:24, 17.76s/it]

Val update: epoch: 14 |accuracy: 0.791729748249054 | f1: 0.24391338229179382 | auc: 0.6532397270202637 | treshold: 0.12
Test: accuracy: 0.8001312613487244 | f1: 0.2398184984922409 | auc: 0.6491990089416504 | 


train... loss:32.167258113622665:   0%|▎                                                                                               | 16/5000 [05:19<24:49:56, 17.94s/it]

Val update: epoch: 15 |accuracy: 0.7101399898529053 | f1: 0.2557530403137207 | auc: 0.6587523221969604 | treshold: 0.11
Test: accuracy: 0.7177051305770874 | f1: 0.24798378348350525 | auc: 0.6532393097877502 | 


train... loss:32.00305077433586:   0%|▍                                                                                                | 21/5000 [06:44<23:54:31, 17.29s/it]

Val update: epoch: 20 |accuracy: 0.7737274765968323 | f1: 0.25103574991226196 | auc: 0.6606159210205078 | treshold: 0.12
Test: accuracy: 0.7855310440063477 | f1: 0.24655525386333466 | auc: 0.6542736291885376 | 


train... loss:31.969020813703537:   0%|▍                                                                                               | 22/5000 [07:02<24:09:01, 17.47s/it]

Val update: epoch: 21 |accuracy: 0.7739283442497253 | f1: 0.24974359571933746 | auc: 0.6607446074485779 | treshold: 0.12
Test: accuracy: 0.7838757038116455 | f1: 0.2433166205883026 | auc: 0.6523022651672363 | 


train... loss:31.969043254852295:   0%|▍                                                                                               | 23/5000 [07:19<24:12:21, 17.51s/it]

Val update: epoch: 22 |accuracy: 0.7815001606941223 | f1: 0.25500527024269104 | auc: 0.662692666053772 | treshold: 0.12
Test: accuracy: 0.7912130355834961 | f1: 0.2469879537820816 | auc: 0.655503511428833 | 


train... loss:31.924747109413147:   0%|▍                                                                                               | 25/5000 [07:53<23:52:40, 17.28s/it]

Val update: epoch: 24 |accuracy: 0.7860741019248962 | f1: 0.257136732339859 | auc: 0.6627814769744873 | treshold: 0.12
Test: accuracy: 0.795045793056488 | f1: 0.24840033054351807 | auc: 0.6563730239868164 | 


train... loss:31.862241834402084:   1%|▌                                                                                               | 30/5000 [09:19<23:47:46, 17.24s/it]

Val update: epoch: 29 |accuracy: 0.793939471244812 | f1: 0.2500421702861786 | auc: 0.6643167734146118 | treshold: 0.12
Test: accuracy: 0.8038297295570374 | f1: 0.24799908697605133 | auc: 0.6584506630897522 | 


train... loss:31.782037377357483:   1%|▋                                                                                               | 37/5000 [11:16<23:20:51, 16.94s/it]

Val update: epoch: 36 |accuracy: 0.8240875005722046 | f1: 0.25055956840515137 | auc: 0.6656495332717896 | treshold: 0.12
Test: accuracy: 0.8304500579833984 | f1: 0.24613752961158752 | auc: 0.6599881649017334 | 


train... loss:31.747726321220398:   1%|▋                                                                                               | 39/5000 [11:50<23:30:36, 17.06s/it]

Val update: epoch: 38 |accuracy: 0.8457520604133606 | f1: 0.24969933927059174 | auc: 0.6660052537918091 | treshold: 0.12
Test: accuracy: 0.8506278395652771 | f1: 0.23959915339946747 | auc: 0.6579734683036804 | 


train... loss:31.71501988172531:   1%|▊                                                                                                | 40/5000 [12:08<23:48:20, 17.28s/it]

Val update: epoch: 39 |accuracy: 0.852566659450531 | f1: 0.2477331906557083 | auc: 0.6669719219207764 | treshold: 0.12
Test: accuracy: 0.8585617542266846 | f1: 0.24200767278671265 | auc: 0.6566804647445679 | 


train... loss:31.652738571166992:   1%|▊                                                                                               | 43/5000 [12:59<23:38:49, 17.17s/it]

Val update: epoch: 42 |accuracy: 0.8555335998535156 | f1: 0.24525712430477142 | auc: 0.6681803464889526 | treshold: 0.12
Test: accuracy: 0.8596802353858948 | f1: 0.2372111827135086 | auc: 0.6595054864883423 | 


train... loss:31.53829249739647:   1%|█                                                                                                | 54/5000 [16:06<23:34:34, 17.16s/it]

Val update: epoch: 53 |accuracy: 0.7371202707290649 | f1: 0.25711789727211 | auc: 0.6687147617340088 | treshold: 0.12
Test: accuracy: 0.7478300929069519 | f1: 0.2491229623556137 | auc: 0.6627209186553955 | 


train... loss:31.52682074904442:   1%|█                                                                                                | 57/5000 [16:57<23:33:04, 17.15s/it]

Val update: epoch: 56 |accuracy: 0.8393083214759827 | f1: 0.25117015838623047 | auc: 0.6689668893814087 | treshold: 0.12
Test: accuracy: 0.846123993396759 | f1: 0.25047218799591064 | auc: 0.6609026193618774 | 


train... loss:31.540336340665817:   1%|█                                                                                               | 58/5000 [17:14<23:36:25, 17.20s/it]

Val update: epoch: 57 |accuracy: 0.8440059423446655 | f1: 0.2490515559911728 | auc: 0.6704317927360535 | treshold: 0.12
Test: accuracy: 0.8510156273841858 | f1: 0.2505626380443573 | auc: 0.6647539138793945 | 


train... loss:30.975221782922745:   2%|█▋                                                                                              | 91/5000 [26:23<23:01:41, 16.89s/it]

Val update: epoch: 90 |accuracy: 0.8646815419197083 | f1: 0.2476157695055008 | auc: 0.670540452003479 | treshold: 0.13
Test: accuracy: 0.8697169423103333 | f1: 0.25 | auc: 0.6640154719352722 | 


train... loss:30.461345821619034:   2%|██▏                                                                                            | 114/5000 [32:45<22:55:31, 16.89s/it]

Val update: epoch: 113 |accuracy: 0.7616744637489319 | f1: 0.260890394449234 | auc: 0.6707226634025574 | treshold: 0.13
Test: accuracy: 0.7733170390129089 | f1: 0.2618492543697357 | auc: 0.6691310405731201 | 


train... loss:31.149161636829376:   3%|██▍                                                                                            | 129/5000 [37:14<23:26:10, 17.32s/it]



Evaluating FlattenedGRU with neural embeddings
Test before learning: {'f1': 0.17595963180065155, 'roc-auc': 0.5231964588165283, 'accuracy': 0.19640886783599854}


train... loss:46.0183342397213:   0%|                                                                                                   | 1/5000 [00:18<25:59:09, 18.71s/it]

Val update: epoch: 0 |accuracy: 0.10144636034965515 | f1: 0.18420572578907013 | auc: 0.6128265857696533 | treshold: 0.01
Test: accuracy: 0.09924836456775665 | f1: 0.18057496845722198 | auc: 0.6116218566894531 | 


train... loss:33.16570806503296:   0%|                                                                                                  | 2/5000 [00:37<26:05:10, 18.79s/it]

Val update: epoch: 1 |accuracy: 0.8780634999275208 | f1: 0.1843927651643753 | auc: 0.6462735533714294 | treshold: 0.11
Test: accuracy: 0.8818116784095764 | f1: 0.18121707439422607 | auc: 0.647661566734314 | 


train... loss:32.25834980607033:   0%|                                                                                                  | 3/5000 [00:56<26:11:18, 18.87s/it]

Val update: epoch: 2 |accuracy: 0.595589816570282 | f1: 0.24759221076965332 | auc: 0.6744126081466675 | treshold: 0.13
Test: accuracy: 0.6014704704284668 | f1: 0.24278144538402557 | auc: 0.6768244504928589 | 


train... loss:31.71683043241501:   0%|                                                                                                  | 4/5000 [01:15<26:15:54, 18.93s/it]

Val update: epoch: 3 |accuracy: 0.8153722286224365 | f1: 0.27137455344200134 | auc: 0.6926244497299194 | treshold: 0.15000000000000002
Test: accuracy: 0.8211143016815186 | f1: 0.27483224868774414 | auc: 0.6950615644454956 | 


train... loss:31.29829803109169:   0%|                                                                                                  | 5/5000 [01:34<26:23:52, 19.03s/it]

Val update: epoch: 4 |accuracy: 0.8557190299034119 | f1: 0.2567858099937439 | auc: 0.7044557332992554 | treshold: 0.13
Test: accuracy: 0.8609926104545593 | f1: 0.26600518822669983 | auc: 0.7042655944824219 | 


train... loss:30.943330705165863:   0%|                                                                                                 | 6/5000 [01:54<26:33:54, 19.15s/it]

Val update: epoch: 5 |accuracy: 0.6921531558036804 | f1: 0.2822452783584595 | auc: 0.7094480991363525 | treshold: 0.13
Test: accuracy: 0.6974527835845947 | f1: 0.27899208664894104 | auc: 0.7082265615463257 | 


train... loss:30.867054343223572:   0%|▏                                                                                                | 7/5000 [02:13<26:42:32, 19.26s/it]

Val update: epoch: 6 |accuracy: 0.7409370541572571 | f1: 0.28455597162246704 | auc: 0.71052086353302 | treshold: 0.16
Test: accuracy: 0.7461150884628296 | f1: 0.28524646162986755 | auc: 0.709149956703186 | 


train... loss:30.575381606817245:   0%|▏                                                                                                | 8/5000 [02:32<26:41:11, 19.25s/it]

Val update: epoch: 7 |accuracy: 0.8456748127937317 | f1: 0.2722436785697937 | auc: 0.7144083976745605 | treshold: 0.15000000000000002
Test: accuracy: 0.8508217334747314 | f1: 0.28340139985084534 | auc: 0.7122294306755066 | 


train... loss:30.492777109146118:   0%|▏                                                                                                | 9/5000 [02:51<26:36:46, 19.20s/it]

Val update: epoch: 8 |accuracy: 0.8608801960945129 | f1: 0.2711891829967499 | auc: 0.7202092409133911 | treshold: 0.14
Test: accuracy: 0.8642139434814453 | f1: 0.275137335062027 | auc: 0.7161108255386353 | 


train... loss:30.189866691827774:   0%|▏                                                                                               | 11/5000 [03:29<26:17:01, 18.97s/it]

Val update: epoch: 10 |accuracy: 0.7430385947227478 | f1: 0.29600778222084045 | auc: 0.7223330140113831 | treshold: 0.16
Test: accuracy: 0.7480985522270203 | f1: 0.29287898540496826 | auc: 0.7188670635223389 | 


train... loss:29.894390612840652:   0%|▎                                                                                               | 14/5000 [04:24<26:02:43, 18.81s/it]

Val update: epoch: 13 |accuracy: 0.8483635783195496 | f1: 0.2895822823047638 | auc: 0.7230022549629211 | treshold: 0.14
Test: accuracy: 0.8511199951171875 | f1: 0.29493609070777893 | auc: 0.7204040884971619 | 


train... loss:29.867526412010193:   0%|▎                                                                                               | 15/5000 [04:44<26:16:40, 18.98s/it]

Val update: epoch: 14 |accuracy: 0.8557499051094055 | f1: 0.29146111011505127 | auc: 0.7287898063659668 | treshold: 0.14
Test: accuracy: 0.8579950332641602 | f1: 0.29320070147514343 | auc: 0.722341775894165 | 


train... loss:29.75363126397133:   0%|▎                                                                                                | 16/5000 [05:04<26:33:20, 19.18s/it]

Val update: epoch: 15 |accuracy: 0.8459529876708984 | f1: 0.2978093922138214 | auc: 0.7311632633209229 | treshold: 0.15000000000000002
Test: accuracy: 0.8474363684654236 | f1: 0.2973901033401489 | auc: 0.7257110476493835 | 


train... loss:29.371671050786972:   0%|▍                                                                                               | 21/5000 [06:36<25:49:53, 18.68s/it]

Val update: epoch: 20 |accuracy: 0.7465617656707764 | f1: 0.3047183156013489 | auc: 0.7317176461219788 | treshold: 0.16
Test: accuracy: 0.7506338357925415 | f1: 0.2986451983451843 | auc: 0.7267630696296692 | 


train... loss:29.293291866779327:   0%|▍                                                                                               | 22/5000 [06:55<26:12:03, 18.95s/it]

Val update: epoch: 21 |accuracy: 0.7955465316772461 | f1: 0.31406500935554504 | auc: 0.732096254825592 | treshold: 0.17
Test: accuracy: 0.7989978194236755 | f1: 0.30461251735687256 | auc: 0.7257505059242249 | 


train... loss:29.243009239435196:   0%|▍                                                                                               | 23/5000 [07:15<26:27:20, 19.14s/it]

Val update: epoch: 22 |accuracy: 0.8278579711914062 | f1: 0.3067828118801117 | auc: 0.7333426475524902 | treshold: 0.17
Test: accuracy: 0.8300474286079407 | f1: 0.30520668625831604 | auc: 0.7261786460876465 | 


train... loss:29.048662051558495:   1%|▌                                                                                               | 28/5000 [08:47<25:33:06, 18.50s/it]

Val update: epoch: 27 |accuracy: 0.8572333455085754 | f1: 0.2861778438091278 | auc: 0.7346652746200562 | treshold: 0.15000000000000002
Test: accuracy: 0.8617532253265381 | f1: 0.28954628109931946 | auc: 0.7275290489196777 | 


train... loss:29.155868589878082:   1%|█▎                                                                                              | 68/5000 [21:10<25:35:39, 18.68s/it]


In [7]:
pd.DataFrame(content_wise_results).to_csv(f'results/cw_{experiment_name}.csv')
del dataset, train_loader, val_loader, test_loader, train_user_item_matrix, train_num_items

# RL4RS

In [None]:
rl4rs_results = []
dataset = RL4RS.load(os.path.join(pkl_path, 'rl4rs.pkl'))
(
    train_loader, 
    val_loader, 
    test_loader, 
    train_user_item_matrix, 
    train_num_items 
) = get_train_val_test_tmatrix_tnumitems(dataset, batch_size=350)

print(f"{len(dataset)} data points among {len(train_loader)} batches")

45942 data points among 106 batches


In [10]:
for embeddings in ['neural', 'svd',  ]:
    print(f"\nEvaluating {experiment_name} with {embeddings} embeddings")

    model = NeuralClickModel(
        RecsysEmbedding(train_num_items, train_user_item_matrix, embeddings=embeddings),
        output_dim=1
    ).to(device)


    best_model, metrics = train(
        model, 
        train_loader, val_loader, test_loader, 
        device=device, lr=1e-3, num_epochs=5000, early_stopping=7,
        silent=True
    )
    
    metrics['embeddings'] = embeddings
    rl4rs_results.append(metrics)


Evaluating FlattenedGRU with neural embeddings
Test before learning: {'f1': 0.15188458561897278, 'roc-auc': 0.5037975907325745, 'accuracy': 0.37591585516929626}


train... loss:64.7874805033207:   0%|                                                                                                   | 1/5000 [00:07<10:38:09,  7.66s/it]

Val update: epoch: 0 |accuracy: 0.7316790223121643 | f1: 0.8233214616775513 | auc: 0.7749164700508118 | treshold: 0.4
Test: accuracy: 0.7273122668266296 | f1: 0.8199799060821533 | auc: 0.7782286405563354 | 


train... loss:54.15752696990967:   0%|                                                                                                  | 2/5000 [00:15<10:35:35,  7.63s/it]

Val update: epoch: 1 |accuracy: 0.7767619490623474 | f1: 0.8362255692481995 | auc: 0.8285561800003052 | treshold: 0.39
Test: accuracy: 0.7797122597694397 | f1: 0.8374955654144287 | auc: 0.8331220149993896 | 


train... loss:48.73104628920555:   0%|                                                                                                  | 3/5000 [00:22<10:35:14,  7.63s/it]

Val update: epoch: 2 |accuracy: 0.7658782005310059 | f1: 0.8443930149078369 | auc: 0.8522939682006836 | treshold: 0.3
Test: accuracy: 0.7618426084518433 | f1: 0.8411015868186951 | auc: 0.8547146320343018 | 


train... loss:44.52147850394249:   0%|                                                                                                  | 4/5000 [00:30<10:33:39,  7.61s/it]

Val update: epoch: 3 |accuracy: 0.8015769124031067 | f1: 0.8580917716026306 | auc: 0.8656052350997925 | treshold: 0.34
Test: accuracy: 0.8025389909744263 | f1: 0.8578194379806519 | auc: 0.8713679909706116 | 


train... loss:41.00608304142952:   0%|                                                                                                  | 5/5000 [00:38<10:35:10,  7.63s/it]

Val update: epoch: 4 |accuracy: 0.78592848777771 | f1: 0.8273279070854187 | auc: 0.8773184418678284 | treshold: 0.36000000000000004
Test: accuracy: 0.789408802986145 | f1: 0.8291046023368835 | auc: 0.879605770111084 | 


train... loss:37.33762949705124:   0%|                                                                                                  | 6/5000 [00:45<10:34:00,  7.62s/it]

Val update: epoch: 5 |accuracy: 0.8048904538154602 | f1: 0.8612892627716064 | auc: 0.8965339660644531 | treshold: 0.35000000000000003
Test: accuracy: 0.8079797029495239 | f1: 0.8628379106521606 | auc: 0.8982766270637512 | 


train... loss:35.2119197845459:   0%|▏                                                                                                  | 7/5000 [00:53<10:35:35,  7.64s/it]

Val update: epoch: 6 |accuracy: 0.8312774896621704 | f1: 0.8711298108100891 | auc: 0.9074010848999023 | treshold: 0.36000000000000004
Test: accuracy: 0.8335388898849487 | f1: 0.872239351272583 | auc: 0.9090372323989868 | 


train... loss:32.93538871407509:   0%|▏                                                                                                 | 9/5000 [01:08<10:23:08,  7.49s/it]

Val update: epoch: 8 |accuracy: 0.8383156657218933 | f1: 0.8786376714706421 | auc: 0.9118818044662476 | treshold: 0.33
Test: accuracy: 0.8380365371704102 | f1: 0.87770676612854 | auc: 0.9120757579803467 | 


train... loss:30.683799102902412:   0%|▏                                                                                               | 12/5000 [01:29<10:11:20,  7.35s/it]

Val update: epoch: 11 |accuracy: 0.846756637096405 | f1: 0.8835208415985107 | auc: 0.9180481433868408 | treshold: 0.35000000000000003
Test: accuracy: 0.8450731635093689 | f1: 0.8815514445304871 | auc: 0.9165759086608887 | 


train... loss:28.520894944667816:   0%|▎                                                                                               | 17/5000 [02:05<10:12:52,  7.38s/it]

Val update: epoch: 16 |accuracy: 0.8472161889076233 | f1: 0.8849425315856934 | auc: 0.9183204770088196 | treshold: 0.44
Test: accuracy: 0.8462579846382141 | f1: 0.8833479881286621 | auc: 0.9177730083465576 | 


train... loss:28.242462500929832:   0%|▎                                                                                               | 18/5000 [02:13<10:22:33,  7.50s/it]

Val update: epoch: 17 |accuracy: 0.8451845645904541 | f1: 0.8781481385231018 | auc: 0.9187286496162415 | treshold: 0.41000000000000003
Test: accuracy: 0.8455809354782104 | f1: 0.877859354019165 | auc: 0.9183375239372253 | 


train... loss:26.950839176774025:   1%|▌                                                                                               | 27/5000 [03:18<10:08:06,  7.34s/it]

Val update: epoch: 26 |accuracy: 0.8479417562484741 | f1: 0.8859376907348633 | auc: 0.9203897714614868 | treshold: 0.41000000000000003
Test: accuracy: 0.84548419713974 | f1: 0.8833558559417725 | auc: 0.9181041717529297 | 


train... loss:22.361636720597744:   2%|██                                                                                              | 105/5000 [12:43<9:53:22,  7.27s/it]



Evaluating FlattenedGRU with svd embeddings
Test before learning: {'f1': 0.0006719426601193845, 'roc-auc': 0.5468539595603943, 'accuracy': 0.3526780307292938}


train... loss:70.09309810400009:   0%|                                                                                                   | 1/5000 [00:06<9:22:52,  6.76s/it]

Val update: epoch: 0 |accuracy: 0.6511633396148682 | f1: 0.7887328267097473 | auc: 0.6260797381401062 | treshold: 0.01
Test: accuracy: 0.6475154161453247 | f1: 0.7860507965087891 | auc: 0.6269787549972534 | 


train... loss:65.98856726288795:   0%|                                                                                                   | 2/5000 [00:13<9:22:08,  6.75s/it]

Val update: epoch: 1 |accuracy: 0.6511633396148682 | f1: 0.7887328267097473 | auc: 0.7030317187309265 | treshold: 0.23
Test: accuracy: 0.6475154161453247 | f1: 0.7860507965087891 | auc: 0.7036685943603516 | 


train... loss:62.28320199251175:   0%|                                                                                                   | 3/5000 [00:20<9:25:01,  6.78s/it]

Val update: epoch: 2 |accuracy: 0.7127170562744141 | f1: 0.7995679974555969 | auc: 0.757646381855011 | treshold: 0.3
Test: accuracy: 0.7097328305244446 | f1: 0.7966044545173645 | auc: 0.7618816494941711 | 


train... loss:59.59153187274933:   0%|                                                                                                   | 4/5000 [00:27<9:25:52,  6.80s/it]

Val update: epoch: 3 |accuracy: 0.7289701700210571 | f1: 0.7940073609352112 | auc: 0.7826544046401978 | treshold: 0.4
Test: accuracy: 0.7292225956916809 | f1: 0.7932575345039368 | auc: 0.7888569235801697 | 


train... loss:56.55192482471466:   0%|                                                                                                   | 6/5000 [00:40<9:16:30,  6.69s/it]

Val update: epoch: 5 |accuracy: 0.712547779083252 | f1: 0.8076269626617432 | auc: 0.783757209777832 | treshold: 0.34
Test: accuracy: 0.7134566307067871 | f1: 0.8074297904968262 | auc: 0.7886998653411865 | 


train... loss:55.714584827423096:   0%|▏                                                                                                 | 7/5000 [00:47<9:21:06,  6.74s/it]

Val update: epoch: 6 |accuracy: 0.7250761985778809 | f1: 0.8163561820983887 | auc: 0.7984519004821777 | treshold: 0.38
Test: accuracy: 0.7233466506004333 | f1: 0.8145073652267456 | auc: 0.801854133605957 | 


train... loss:53.969640642404556:   0%|▏                                                                                                 | 8/5000 [00:53<9:22:28,  6.76s/it]

Val update: epoch: 7 |accuracy: 0.7273254990577698 | f1: 0.821839451789856 | auc: 0.8129879832267761 | treshold: 0.35000000000000003
Test: accuracy: 0.7239511609077454 | f1: 0.8191266655921936 | auc: 0.8156179189682007 | 


train... loss:52.99818441271782:   0%|▏                                                                                                  | 9/5000 [01:00<9:24:33,  6.79s/it]

Val update: epoch: 8 |accuracy: 0.7378222942352295 | f1: 0.8275368213653564 | auc: 0.831358015537262 | treshold: 0.35000000000000003
Test: accuracy: 0.7362350225448608 | f1: 0.8259842991828918 | auc: 0.8340294361114502 | 


train... loss:52.36100506782532:   0%|▏                                                                                                 | 10/5000 [01:07<9:24:04,  6.78s/it]

Val update: epoch: 9 |accuracy: 0.7633386254310608 | f1: 0.8361547589302063 | auc: 0.8402623534202576 | treshold: 0.36000000000000004
Test: accuracy: 0.7629790902137756 | f1: 0.835597574710846 | auc: 0.8428791165351868 | 


train... loss:50.20127519965172:   0%|▏                                                                                                 | 12/5000 [01:20<9:16:56,  6.70s/it]

Val update: epoch: 11 |accuracy: 0.7564214468002319 | f1: 0.8353846669197083 | auc: 0.8419333100318909 | treshold: 0.39
Test: accuracy: 0.7556523084640503 | f1: 0.8344447016716003 | auc: 0.8477630019187927 | 


train... loss:48.21806624531746:   0%|▎                                                                                                 | 13/5000 [01:27<9:20:39,  6.75s/it]

Val update: epoch: 12 |accuracy: 0.7878392338752747 | f1: 0.837597668170929 | auc: 0.8572051525115967 | treshold: 0.4
Test: accuracy: 0.7905210852622986 | f1: 0.8387349247932434 | auc: 0.8620035648345947 | 


train... loss:46.80847439169884:   0%|▎                                                                                                 | 14/5000 [01:34<9:24:47,  6.80s/it]

Val update: epoch: 13 |accuracy: 0.7824698686599731 | f1: 0.8470772504806519 | auc: 0.8572285175323486 | treshold: 0.4
Test: accuracy: 0.7845000624656677 | f1: 0.8477101922035217 | auc: 0.8641485571861267 | 


train... loss:44.69212186336517:   0%|▎                                                                                                 | 16/5000 [01:47<9:19:09,  6.73s/it]

Val update: epoch: 15 |accuracy: 0.7969815731048584 | f1: 0.84527188539505 | auc: 0.8668402433395386 | treshold: 0.38
Test: accuracy: 0.7988393306732178 | f1: 0.8455124497413635 | auc: 0.8708468675613403 | 


train... loss:43.47337245941162:   0%|▎                                                                                                 | 17/5000 [01:54<9:23:44,  6.79s/it]

Val update: epoch: 16 |accuracy: 0.7895805835723877 | f1: 0.8332886099815369 | auc: 0.8668846487998962 | treshold: 0.4
Test: accuracy: 0.7934228181838989 | f1: 0.8355628848075867 | auc: 0.8707714080810547 | 


train... loss:42.73639187216759:   0%|▎                                                                                                 | 18/5000 [02:01<9:25:50,  6.81s/it]

Val update: epoch: 17 |accuracy: 0.8049871921539307 | f1: 0.8555665016174316 | auc: 0.8759992122650146 | treshold: 0.39
Test: accuracy: 0.8057066798210144 | f1: 0.8548932075500488 | auc: 0.8794417977333069 | 


train... loss:41.530331671237946:   0%|▍                                                                                                | 20/5000 [02:14<9:18:19,  6.73s/it]

Val update: epoch: 19 |accuracy: 0.8067286014556885 | f1: 0.8517852425575256 | auc: 0.8792979717254639 | treshold: 0.39
Test: accuracy: 0.8069157004356384 | f1: 0.8507393002510071 | auc: 0.8806277513504028 | 


train... loss:40.23729434609413:   0%|▍                                                                                                 | 24/5000 [02:40<9:03:05,  6.55s/it]

Val update: epoch: 23 |accuracy: 0.8063174486160278 | f1: 0.859622061252594 | auc: 0.8824865818023682 | treshold: 0.39
Test: accuracy: 0.80396568775177 | f1: 0.8569108843803406 | auc: 0.884792685508728 | 


train... loss:39.285666793584824:   0%|▍                                                                                                | 25/5000 [02:47<9:11:18,  6.65s/it]

Val update: epoch: 24 |accuracy: 0.811106264591217 | f1: 0.8638613820075989 | auc: 0.88904869556427 | treshold: 0.38
Test: accuracy: 0.8081973195075989 | f1: 0.8610906600952148 | auc: 0.889386773109436 | 


train... loss:38.4742616713047:   1%|▌                                                                                                  | 28/5000 [03:06<9:07:19,  6.60s/it]

Val update: epoch: 27 |accuracy: 0.8198133111000061 | f1: 0.8663676977157593 | auc: 0.8933144211769104 | treshold: 0.4
Test: accuracy: 0.8178938627243042 | f1: 0.8640687465667725 | auc: 0.8938559889793396 | 


train... loss:37.75227200984955:   1%|▋                                                                                                 | 34/5000 [03:45<8:58:47,  6.51s/it]

Val update: epoch: 33 |accuracy: 0.8269965648651123 | f1: 0.8709054589271545 | auc: 0.8990672826766968 | treshold: 0.42000000000000004
Test: accuracy: 0.8245193958282471 | f1: 0.8679801225662231 | auc: 0.898916482925415 | 


train... loss:36.8710018992424:   1%|▋                                                                                                  | 37/5000 [04:04<8:58:04,  6.51s/it]

Val update: epoch: 36 |accuracy: 0.8239974975585938 | f1: 0.8713151216506958 | auc: 0.8991389274597168 | treshold: 0.4
Test: accuracy: 0.821520984172821 | f1: 0.8685414791107178 | auc: 0.8992078304290771 | 


train... loss:36.543319791555405:   1%|▋                                                                                                | 38/5000 [04:11<9:05:52,  6.60s/it]

Val update: epoch: 37 |accuracy: 0.8268030881881714 | f1: 0.8698638677597046 | auc: 0.900615394115448 | treshold: 0.42000000000000004
Test: accuracy: 0.826889157295227 | f1: 0.8688467741012573 | auc: 0.9007039666175842 | 


train... loss:35.83971834182739:   1%|▊                                                                                                 | 44/5000 [04:49<8:58:58,  6.53s/it]

Val update: epoch: 43 |accuracy: 0.8287379741668701 | f1: 0.8702852129936218 | auc: 0.903150200843811 | treshold: 0.42000000000000004
Test: accuracy: 0.8303470015525818 | f1: 0.8702566623687744 | auc: 0.9036237001419067 | 


train... loss:35.430919736623764:   1%|▊                                                                                                | 45/5000 [04:56<9:07:52,  6.63s/it]

Val update: epoch: 44 |accuracy: 0.827214241027832 | f1: 0.8727874755859375 | auc: 0.9035871624946594 | treshold: 0.4
Test: accuracy: 0.8260186314582825 | f1: 0.8707817792892456 | auc: 0.9040542840957642 | 


train... loss:34.715289533138275:   1%|▉                                                                                                | 51/5000 [05:35<8:59:49,  6.54s/it]

Val update: epoch: 50 |accuracy: 0.8322933316230774 | f1: 0.8753191828727722 | auc: 0.9061061143875122 | treshold: 0.4
Test: accuracy: 0.8324265480041504 | f1: 0.8739404082298279 | auc: 0.9065450429916382 | 


train... loss:32.26053726673126:   1%|█▎                                                                                                | 64/5000 [06:57<8:52:01,  6.47s/it]

Val update: epoch: 63 |accuracy: 0.8374449610710144 | f1: 0.8756176829338074 | auc: 0.9062703847885132 | treshold: 0.42000000000000004
Test: accuracy: 0.8377463221549988 | f1: 0.874677836894989 | auc: 0.9069621562957764 | 


train... loss:32.176953703165054:   1%|█▎                                                                                               | 67/5000 [07:17<8:55:57,  6.52s/it]

Val update: epoch: 66 |accuracy: 0.8355100750923157 | f1: 0.8698945641517639 | auc: 0.908982515335083 | treshold: 0.42000000000000004
Test: accuracy: 0.8346753716468811 | f1: 0.8683597445487976 | auc: 0.9089608788490295 | 


train... loss:30.775516390800476:   2%|█▍                                                                                               | 75/5000 [08:08<8:53:21,  6.50s/it]

Val update: epoch: 74 |accuracy: 0.838194727897644 | f1: 0.8786240220069885 | auc: 0.9096196889877319 | treshold: 0.41000000000000003
Test: accuracy: 0.8376254439353943 | f1: 0.8771833777427673 | auc: 0.909818708896637 | 


train... loss:30.55927586555481:   2%|█▌                                                                                                | 79/5000 [08:34<8:58:16,  6.56s/it]

Val update: epoch: 78 |accuracy: 0.8410003185272217 | f1: 0.8787711262702942 | auc: 0.9104902744293213 | treshold: 0.43
Test: accuracy: 0.8405755162239075 | f1: 0.8773600459098816 | auc: 0.9105210900306702 | 


train... loss:29.139897614717484:   2%|█▊                                                                                               | 94/5000 [10:10<8:52:20,  6.51s/it]

Val update: epoch: 93 |accuracy: 0.8422338366508484 | f1: 0.8797492980957031 | auc: 0.9120118618011475 | treshold: 0.44
Test: accuracy: 0.841059148311615 | f1: 0.8779364228248596 | auc: 0.911666750907898 | 


train... loss:28.63973069190979:   2%|██▏                                                                                              | 110/5000 [11:52<8:52:15,  6.53s/it]

Val update: epoch: 109 |accuracy: 0.8410245180130005 | f1: 0.8789569735527039 | auc: 0.9121874570846558 | treshold: 0.42000000000000004
Test: accuracy: 0.8403337001800537 | f1: 0.8776746392250061 | auc: 0.9116716384887695 | 


train... loss:28.234099194407463:   2%|██▏                                                                                             | 115/5000 [12:24<8:48:29,  6.49s/it]

Val update: epoch: 114 |accuracy: 0.8426207900047302 | f1: 0.8802694082260132 | auc: 0.9133642911911011 | treshold: 0.43
Test: accuracy: 0.841300904750824 | f1: 0.8784427046775818 | auc: 0.9130176305770874 | 


train... loss:27.872017681598663:   2%|██▍                                                                                             | 124/5000 [13:21<8:47:51,  6.50s/it]

Val update: epoch: 123 |accuracy: 0.841943621635437 | f1: 0.8800058960914612 | auc: 0.913579523563385 | treshold: 0.42000000000000004
Test: accuracy: 0.8402369618415833 | f1: 0.8780433535575867 | auc: 0.9129936695098877 | 


train... loss:28.233209624886513:   3%|██▍                                                                                             | 130/5000 [14:05<8:48:12,  6.51s/it]
