In [1]:
import os
import torch
import random
import datetime
import pandas as pd
import numpy as np
import torch.nn.functional as F

from src.datasets import RL4RS, ContentWise, DummyData
from src.utils import train, get_dummy_data, get_train_val_test_tmatrix_tnumitems, get_svd_encoder
from src.embeddings import RecsysEmbedding
from torch.utils.data import Dataset

experiment_name = 'AttentionGRU-FIXED'
device = 'cuda:2'
seed = 123
pkl_path = '../pkl/'

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

<torch._C.Generator at 0x7f04380aebd0>

In [2]:
torch.__version__

'1.12.1'

# Модель

In [3]:
# torch.autograd.set_detect_anomaly(True)

class AttentionGRU(torch.nn.Module):
    def __init__(self, embedding, nheads=2, output_dim=1):
        super().__init__()
        self.embedding_dim = embedding.embedding_dim
        self.embedding = embedding
        self.attention = torch.nn.MultiheadAttention(
            2 * embedding.embedding_dim,
            num_heads=nheads,
            batch_first=True
        )

        self.rnn_cell = torch.nn.GRUCell(
            input_size = 2 * embedding.embedding_dim, 
            hidden_size = embedding.embedding_dim, 
        )
        
        self.out_layer = torch.nn.Linear(3 * embedding.embedding_dim, output_dim)
    
    
    def get_attention_embeddings(self, item_embs, user_embs, slate_mask):
        shp = item_embs.shape      
        key_padding_mask = slate_mask
        key_padding_mask[:,:, 0] = True # let model attent to first padd token if slate is empty 
        features = torch.cat(
            [
                item_embs,
                user_embs[:, :, None, :].repeat(1, 1, item_embs.size(-2), 1).reshape(shp)
            ],
            dim = -1
        ).flatten(0,1)

        features, attn_map = self.attention(
            features, features, features,
            key_padding_mask=~key_padding_mask.flatten(0, 1)
        )
        shp = list(shp)
        shp[-1] *= 2
        features = features.reshape(shp)
        return features
    
    def forward(self, batch):
        # consider sequential clicks, hence need to flatten slates
        item_embs, user_embs = self.embedding(batch)
        slate_mask = batch['slates_mask'].clone()

        # item_embs dims: batch, session, slate, embedding
        # user_embs dims: batch, session, embedding
        session_length = item_embs.shape[-3]
        slate_size = item_embs.shape[-2]
        hidden = user_embs[..., 0, :]
        preds = []
        for rec in range(session_length):
            # att_features dims: batch, 1, slate, embedding
            att_features = self.get_attention_embeddings(
                item_embs[..., rec, :, :].unsqueeze(-3), 
                hidden.unsqueeze(-2), 
                slate_mask[..., rec, :].unsqueeze(-3)
            )
            # hidden dims: batch, embedding
            hidden = self.rnn_cell(
                att_features.squeeze(-3).mean(-2),
                hidden
            )
            features = torch.cat(
                [
                    att_features, 
                    hidden[..., None, None ,:].repeat(1, 1, slate_size, 1)
                ],
                dim=-1
            )
            preds.append(features)
        preds = torch.cat(preds, dim=-3)            
        return self.out_layer(preds).squeeze(-1)

# Игрушечный датасет: проверим, что сходится к идеальным метрикам

In [4]:
d = DummyData()
dummy_loader, dummy_matrix = get_dummy_data(d)

model = AttentionGRU(
    RecsysEmbedding(d.n_items, dummy_matrix, embeddings='neural').to('cpu'),
    output_dim=1
).to('cpu')

train(
    model, 
    dummy_loader, dummy_loader, dummy_loader,
    device=device, lr=1e-3, num_epochs=5000, dummy=True,
    silent=True,
)


biulding affinity matrix...


3it [00:00, 3498.17it/s]


Test before learning: {'f1': 0.4000000059604645, 'roc-auc': 0.0, 'accuracy': 0.25}


train:   0%|          | 0/5000 [00:00<?, ?it/s]

Val update: epoch: 0 |accuracy: 0.25 | f1: 0.4000000059604645 | auc: 0.3333333134651184 | treshold: 0.01
Test: accuracy: 0.25 | f1: 0.4000000059604645 | auc: 0.3333333134651184 | 
Val update: epoch: 1 |accuracy: 0.25 | f1: 0.4000000059604645 | auc: 1.0 | treshold: 0.01
Test: accuracy: 0.25 | f1: 0.4000000059604645 | auc: 1.0 | 
Val update: epoch: 2 |accuracy: 1.0 | f1: 1.0 | auc: 1.0 | treshold: 0.52
Test: accuracy: 1.0 | f1: 1.0 | auc: 1.0 | 


(AttentionGRU(
   (embedding): RecsysEmbedding(
     (item_embeddings): Embedding(5, 32)
   )
   (attention): MultiheadAttention(
     (out_proj): NonDynamicallyQuantizableLinear(in_features=64, out_features=64, bias=True)
   )
   (rnn_cell): GRUCell(64, 32)
   (out_layer): Linear(in_features=96, out_features=1, bias=True)
 ),
 {'f1': 1.0, 'roc-auc': 1.0, 'accuracy': 1.0})

# ContentWise

In [5]:
content_wise_results = []
dataset = ContentWise.load(os.path.join(pkl_path, 'cw.pkl'))
(
    train_loader, 
    val_loader, 
    test_loader, 
    train_user_item_matrix, 
    train_num_items 
) = get_train_val_test_tmatrix_tnumitems(dataset, batch_size=150)

print(f"{len(dataset)} data points among {len(train_loader)} batches")

20216 data points among 108 batches


In [6]:
for embeddings in ['svd', 'neural']:
    print(f"\nEvaluating {experiment_name} with {embeddings} embeddings")
    
    model = AttentionGRU(
        RecsysEmbedding(train_num_items, train_user_item_matrix, embeddings=embeddings),
        output_dim=1
    ).to(device)

    _, metrics = train(
        model, 
        train_loader, val_loader, test_loader, 
        device=device, lr=1e-3, num_epochs=5000, early_stopping=7,
       silent=True, 
    )
    
    metrics['embeddings'] = embeddings
    content_wise_results.append(metrics)


Evaluating AttentionGRU_FIXED with svd embeddings
Test before learning: {'f1': 0.0, 'roc-auc': 0.4743947386741638, 'accuracy': 0.9019922018051147}


train:   0%|          | 0/5000 [00:00<?, ?it/s]

Val update: epoch: 0 |accuracy: 0.09937293827533722 | f1: 0.18078112602233887 | auc: 0.4981193542480469 | treshold: 0.01
Test: accuracy: 0.09800780564546585 | f1: 0.1785193234682083 | auc: 0.502781867980957 | 
Val update: epoch: 1 |accuracy: 0.09941849112510681 | f1: 0.18078862130641937 | auc: 0.5313175916671753 | treshold: 0.04
Test: accuracy: 0.09800780564546585 | f1: 0.1785193234682083 | auc: 0.5303691625595093 | 
Val update: epoch: 2 |accuracy: 0.12510818243026733 | f1: 0.18386799097061157 | auc: 0.5338164567947388 | treshold: 0.05
Test: accuracy: 0.12855087220668793 | f1: 0.18196755647659302 | auc: 0.536331295967102 | 
Val update: epoch: 4 |accuracy: 0.15058530867099762 | f1: 0.18588745594024658 | auc: 0.545150637626648 | treshold: 0.05
Test: accuracy: 0.15406523644924164 | f1: 0.1843147873878479 | auc: 0.5524570941925049 | 
Val update: epoch: 6 |accuracy: 0.18500523269176483 | f1: 0.18610504269599915 | auc: 0.5599356889724731 | treshold: 0.060000000000000005
Test: accuracy: 0.186

train:   0%|          | 0/5000 [00:00<?, ?it/s]

Val update: epoch: 1 |accuracy: 0.6821736097335815 | f1: 0.2442687451839447 | auc: 0.6551786661148071 | treshold: 0.060000000000000005
Test: accuracy: 0.6829839944839478 | f1: 0.24138298630714417 | auc: 0.653461217880249 | 
Val update: epoch: 2 |accuracy: 0.8839409351348877 | f1: 0.21551723778247833 | auc: 0.684718132019043 | treshold: 0.14
Test: accuracy: 0.885289192199707 | f1: 0.21067674458026886 | auc: 0.6849374771118164 | 
Val update: epoch: 3 |accuracy: 0.8748766183853149 | f1: 0.23360922932624817 | auc: 0.7008762359619141 | treshold: 0.14
Test: accuracy: 0.8788809180259705 | f1: 0.2332613319158554 | auc: 0.700993537902832 | 
Val update: epoch: 4 |accuracy: 0.5919864177703857 | f1: 0.25845083594322205 | auc: 0.7081881761550903 | treshold: 0.13
Test: accuracy: 0.5989497303962708 | f1: 0.2576606273651123 | auc: 0.7096996307373047 | 
Val update: epoch: 7 |accuracy: 0.7808632850646973 | f1: 0.29824477434158325 | auc: 0.7220767736434937 | treshold: 0.13
Test: accuracy: 0.7844777703285

In [7]:
pd.DataFrame(content_wise_results).to_csv(f'results/cw_{experiment_name}.csv')
del dataset, train_loader, val_loader, test_loader, train_user_item_matrix, train_num_items

# RL4RS

In [8]:
rl4rs_results = []
dataset = RL4RS.load(os.path.join(pkl_path, 'rl4rs.pkl'))
(
    train_loader, 
    val_loader, 
    test_loader, 
    train_user_item_matrix, 
    train_num_items 
) = get_train_val_test_tmatrix_tnumitems(dataset, batch_size=350)

print(f"{len(dataset)} data points among {len(train_loader)} batches")

45942 data points among 106 batches


In [9]:
for embeddings in ['neural','explicit', 'svd',  ]:
    print(f"\nEvaluating {experiment_name} with {embeddings} embeddings")

    model = AttentionGRU(
        RecsysEmbedding(
            train_num_items, 
            train_user_item_matrix, 
            embeddings=embeddings,
            embedding_dim=40
        ),
        output_dim=1
    ).to(device)

    best_model, metrics = train(
        model, 
        train_loader, val_loader, test_loader, 
        device=device, lr=1e-3, num_epochs=5000, early_stopping=7,
        silent=True
    )
    
    metrics['embeddings'] = embeddings
    rl4rs_results.append(metrics)


Evaluating AttentionGRU_FIXED with neural embeddings
Test before learning: {'f1': 0.7831947803497314, 'roc-auc': 0.4515891671180725, 'accuracy': 0.6437432169914246}


train:   0%|          | 0/5000 [00:00<?, ?it/s]

Val update: epoch: 0 |accuracy: 0.710491955280304 | f1: 0.8135165572166443 | auc: 0.7544424533843994 | treshold: 0.37
Test: accuracy: 0.7043404579162598 | f1: 0.8077062368392944 | auc: 0.7642634510993958 | 
Val update: epoch: 1 |accuracy: 0.7315096855163574 | f1: 0.8276215195655823 | auc: 0.828403651714325 | treshold: 0.4
Test: accuracy: 0.7246040105819702 | f1: 0.8216148614883423 | auc: 0.8362383842468262 | 
Val update: epoch: 2 |accuracy: 0.7852513194084167 | f1: 0.8383667469024658 | auc: 0.8443363308906555 | treshold: 0.43
Test: accuracy: 0.7903760075569153 | f1: 0.8402410745620728 | auc: 0.8540526032447815 | 
Val update: epoch: 3 |accuracy: 0.7896773815155029 | f1: 0.847202718257904 | auc: 0.8463099002838135 | treshold: 0.4
Test: accuracy: 0.7903760075569153 | f1: 0.845815896987915 | auc: 0.8544607758522034 | 
Val update: epoch: 5 |accuracy: 0.7947322726249695 | f1: 0.8523049354553223 | auc: 0.855345606803894 | treshold: 0.38
Test: accuracy: 0.7935194969177246 | f1: 0.8498901128768

train:   0%|          | 0/5000 [00:00<?, ?it/s]

Val update: epoch: 0 |accuracy: 0.7491897940635681 | f1: 0.8236814141273499 | auc: 0.8125873804092407 | treshold: 0.39
Test: accuracy: 0.7525087594985962 | f1: 0.8244696259498596 | auc: 0.8222543001174927 | 
Val update: epoch: 2 |accuracy: 0.7791321873664856 | f1: 0.8502165079116821 | auc: 0.842571496963501 | treshold: 0.38
Test: accuracy: 0.7740055322647095 | f1: 0.84546959400177 | auc: 0.8518551588058472 | 
Val update: epoch: 4 |accuracy: 0.7897499203681946 | f1: 0.8510324954986572 | auc: 0.8438721895217896 | treshold: 0.39
Test: accuracy: 0.7886350154876709 | f1: 0.8489571809768677 | auc: 0.8533065319061279 | 
Val update: epoch: 5 |accuracy: 0.7872103452682495 | f1: 0.8526002168655396 | auc: 0.8444597125053406 | treshold: 0.4
Test: accuracy: 0.7821303606033325 | f1: 0.8477731943130493 | auc: 0.85223388671875 | 
Val update: epoch: 7 |accuracy: 0.7911769151687622 | f1: 0.853248119354248 | auc: 0.8496759533882141 | treshold: 0.37
Test: accuracy: 0.7900616526603699 | f1: 0.8511929512023

train:   0%|          | 0/5000 [00:00<?, ?it/s]

Val update: epoch: 0 |accuracy: 0.655105710029602 | f1: 0.7916179299354553 | auc: 0.7020512223243713 | treshold: 0.4
Test: accuracy: 0.6448071599006653 | f1: 0.7840519547462463 | auc: 0.7108439207077026 | 
Val update: epoch: 1 |accuracy: 0.6610071063041687 | f1: 0.7918745279312134 | auc: 0.7179774641990662 | treshold: 0.33
Test: accuracy: 0.6519647240638733 | f1: 0.7848162055015564 | auc: 0.7310923337936401 | 
Val update: epoch: 3 |accuracy: 0.7061626315116882 | f1: 0.7789201736450195 | auc: 0.743841290473938 | treshold: 0.42000000000000004
Test: accuracy: 0.712658703327179 | f1: 0.7804688811302185 | auc: 0.7573460340499878 | 
Val update: epoch: 4 |accuracy: 0.7072268128395081 | f1: 0.7813600897789001 | auc: 0.7489835023880005 | treshold: 0.4
Test: accuracy: 0.7151009440422058 | f1: 0.784425675868988 | auc: 0.7622102499008179 | 
Val update: epoch: 5 |accuracy: 0.691650927066803 | f1: 0.7986735105514526 | auc: 0.7617202997207642 | treshold: 0.46
Test: accuracy: 0.6893724799156189 | f1: 

In [10]:
pd.DataFrame(rl4rs_results).to_csv(f'results/rl4rs_{experiment_name}.csv')
del dataset, train_loader, val_loader, test_loader, train_user_item_matrix, train_num_items