In [1]:
import torch
import random
import datetime
import pandas as pd
import numpy as np
import os

from torch.utils.data import Dataset
from src.datasets import RL4RS, ContentWise, DummyData
from src.utils import train, get_dummy_data, get_train_val_test_tmatrix_tnumitems, fit_treshold
from src.embeddings import RecsysEmbedding
from sklearn.linear_model import LogisticRegression

experiment_name = 'NeuralClickModel'
device = 'cuda:2'
seed = 7331
pkl_path = '../pkl/'


random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

<torch._C.Generator at 0x7fc5cdc3eb10>

In [2]:
def flatten(true, pred, mask, to_cpu=True):
    mask = mask.flatten()
    nnz_idx = mask.nonzero()[:, 0]
    true, pred = [x.flatten()[nnz_idx] for x in [true, pred]]
    if to_cpu:
        true, pred = [x.cpu().numpy() for x in [true, pred]]
    return true, pred

# Модель

In [3]:
import torch.nn as nn
import torch.nn.functional as F

class NeuralClickModel(nn.Module):
    def __init__(self, embedding, readout=False):
        super().__init__()
        
        self.embedding = embedding
        
        self.emb_dim = embedding.embedding_dim
        self.rnn_layer = nn.GRU(
            input_size=self.emb_dim * 2, 
            hidden_size=self.emb_dim, 
            batch_first=True
        )
        self.out_layer = nn.Linear(self.emb_dim, 1)
        
        self.thr = -1.5
        self.readout = readout
        self.readout_mode = 'threshold' # ['soft' ,'threshold', 'sample', 'diff_sample']
        
        self.calibration = False
        self.w = 1
        self.b = 0

    def forward(self, batch):
        item_embs, user_embs = self.embedding(batch)
        shp = item_embs.shape
        max_sequence = item_embs.size(1)
        # ilya format:
        # 'items': (batch, slate, 2*embedding_dim ) 2, нужно для readout, по умолчанию ноли на половине эмбеддинга
        # 'clicks': (batch, slate)
        # 'users': (1, batch, embedding_dim), 
        # 'mask': (batch, slate)
        
        x = {}
        x['items'] = torch.cat(
            [
                item_embs.flatten(0,1),
                torch.zeros_like(item_embs.flatten(0,1)),
            ],
            dim = -1
        )
                
        if self.training:
            indices = (batch['length'] - 1)
        else:
            indices = (batch['in_length'] - 1)
        indices[indices<0] = 0
        indices = indices[:, None, None].repeat(1, 1, user_embs.size(-1))
        user_embs = user_embs.gather(1, indices).squeeze(-2).unsqueeze(0)
        x['users'] = user_embs.repeat_interleave(max_sequence, 1)
        x['clicks'] = (batch['responses'].flatten(0,1) > 0 ).int().clone()
        x['mask'] = batch['slates_mask'].flatten(0,1).clone()
        
        items = x['items']
        h = x['users']
        
        if self.readout:
            res = []
            seq_len = items.shape[1]
            for i in range(seq_len):
#                 print(items[:,[i],:])
                output, h = self.rnn_layer(items[:,[i],:], h)
                y = self.out_layer(output)[:, :, 0]
                
                # readout
                if i + 1 < seq_len:
                    if self.readout_mode == 'threshold':
                        items[:, [i+1], self.emb_dim:] *= (y.detach()[:, :, None] > self.thr).to(torch.float32)
                    elif self.readout_mode == 'soft':
                        items[:, [i+1], self.emb_dim:] *= torch.sigmoid(y)[:, :, None]
                    elif self.readout_mode == 'diff_sample' or self.readout_mode == 'sample':
                        eps = 1e-8
                        gumbel_sample = -( (torch.rand_like(y) + eps).log() / (torch.rand_like(y) + eps).log() + eps).log()
                        T = 0.3
                        bernoulli_sample = torch.sigmoid( (nn.LogSigmoid()(self.w * y + self.b) + gumbel_sample) / T )
                        if self.readout_mode == 'sample':
                            bernoulli_sample = bernoulli_sample.detach()
                        items[:, [i+1], self.emb_dim:] *= bernoulli_sample[:, :, None]
                    else:
                        raise
                    
                res.append(y)
        
            y = torch.cat(res, axis=1)
            
        else:
            items[:, 1:, self.emb_dim:] *= x['clicks'][:, :-1, None]
            rnn_out, _ = self.rnn_layer(items, h)
            y = self.out_layer(rnn_out)[:, :, 0]
        
        
        if self.calibration and self.training:
            clicks_flat, logits_flat = flatten(x['clicks'], y.detach(), x['mask'])
            logreg = LogisticRegression()
            logreg.fit(logits_flat[:, None], clicks_flat)
            γ = 0.3
            self.w = (1 - γ) * self.w + γ * logreg.coef_[0, 0]
            self.b = (1 - γ) * self.b + γ * logreg.intercept_[0]
            y = self.w * y + self.b
        else:
            y = self.w * y + self.b
            
        return y.reshape(shp[:-1])

# Игрушечный датасет: проверим, что сходится к идеальным метрикам

In [4]:
d = DummyData()
dummy_loader, dummy_matrix = get_dummy_data(d)

model = NeuralClickModel(
    RecsysEmbedding(d.n_items, dummy_matrix, embeddings='svd', embedding_dim = 2),
    readout=False
).to('cpu')

train(
    model, 
    dummy_loader, dummy_loader, dummy_loader,
    device=device, lr=1e-3, num_epochs=5000, dummy=True,
    silent=True,
)


biulding affinity matrix...


3it [00:00, 2560.10it/s]


Test before learning: {'f1': 0.0, 'roc-auc': 0.3333333134651184, 'accuracy': 0.75}


train... loss:0.7181857228279114:   0%|                                                                                                    | 1/5000 [00:00<25:59,  3.21it/s]

Val update: epoch: 0 |accuracy: 0.25 | f1: 0.4000000059604645 | auc: 0.3333333134651184 | treshold: 0.01
Test: accuracy: 0.25 | f1: 0.4000000059604645 | auc: 0.3333333134651184 | 


train... loss:0.7011198401451111:   1%|▊                                                                                                  | 42/5000 [00:09<20:03,  4.12it/s]

Val update: epoch: 41 |accuracy: 0.25 | f1: 0.4000000059604645 | auc: 0.6666666269302368 | treshold: 0.01
Test: accuracy: 0.25 | f1: 0.4000000059604645 | auc: 0.6666666269302368 | 


train... loss:0.6998206377029419:   1%|▉                                                                                                  | 45/5000 [00:10<20:53,  3.95it/s]

Val update: epoch: 44 |accuracy: 0.25 | f1: 0.4000000059604645 | auc: 1.0 | treshold: 0.01
Test: accuracy: 0.25 | f1: 0.4000000059604645 | auc: 1.0 | 


train... loss:0.6820520758628845:   2%|█▋                                                                                                 | 85/5000 [00:18<20:26,  4.01it/s]

Val update: epoch: 84 |accuracy: 0.75 | f1: 0.6666666865348816 | auc: 1.0 | treshold: 0.36000000000000004
Test: accuracy: 0.75 | f1: 0.6666666865348816 | auc: 1.0 | 


train... loss:0.6675306558609009:   2%|██▎                                                                                               | 117/5000 [00:25<18:03,  4.51it/s]

Val update: epoch: 117 |accuracy: 1.0 | f1: 1.0 | auc: 1.0 | treshold: 0.38
Test: accuracy: 1.0 | f1: 1.0 | auc: 1.0 | 





(NeuralClickModel(
   (embedding): RecsysEmbedding()
   (rnn_layer): GRU(4, 2, batch_first=True)
   (out_layer): Linear(in_features=2, out_features=1, bias=True)
 ),
 {'f1': 1.0, 'roc-auc': 1.0, 'accuracy': 1.0})

# ContentWise

In [5]:
content_wise_results = []
dataset = ContentWise.load(os.path.join(pkl_path, 'cw.pkl'))
(
    train_loader, 
    val_loader, 
    test_loader, 
    train_user_item_matrix, 
    train_num_items 
) = get_train_val_test_tmatrix_tnumitems(dataset, batch_size=150)

print(f"{len(dataset)} data points among {len(train_loader)} batches")

20216 data points among 108 batches


In [None]:
for embeddings in ['svd', 'neural']:
    print(f"\nEvaluating {experiment_name} with {embeddings} embeddings")
    model = NeuralClickModel(
        RecsysEmbedding(train_num_items, train_user_item_matrix, embeddings=embeddings).to('cpu'),
    ).to(device)

    _, metrics = train(
        model, 
        train_loader, val_loader, test_loader, 
        device=device, lr=1e-3, num_epochs=5000, early_stopping=7,
       silent=True, 
    )
    
    metrics['embeddings'] = embeddings
    content_wise_results.append(metrics)


Evaluating NeuralClickModel with svd embeddings
Test before learning: {'f1': 0.17521433532238007, 'roc-auc': 0.46481871604919434, 'accuracy': 0.11281668394804001}


train... loss:49.150610506534576:   0%|                                                                                                 | 1/5000 [00:39<55:06:04, 39.68s/it]

Val update: epoch: 0 |accuracy: 0.33749404549598694 | f1: 0.20018358528614044 | auc: 0.6131184101104736 | treshold: 0.09
Test: accuracy: 0.3378555476665497 | f1: 0.20030193030834198 | auc: 0.609430193901062 | 


train... loss:34.73808878660202:   0%|                                                                                                  | 2/5000 [00:47<29:04:21, 20.94s/it]

Val update: epoch: 1 |accuracy: 0.7018216252326965 | f1: 0.23198693990707397 | auc: 0.617000937461853 | treshold: 0.13
Test: accuracy: 0.7026462554931641 | f1: 0.2237255722284317 | auc: 0.6088778376579285 | 


train... loss:33.44633388519287:   0%|                                                                                                  | 3/5000 [00:55<20:44:57, 14.95s/it]

Val update: epoch: 2 |accuracy: 0.7811975479125977 | f1: 0.23427018523216248 | auc: 0.6312115788459778 | treshold: 0.12
Test: accuracy: 0.7810584902763367 | f1: 0.22492040693759918 | auc: 0.6196919679641724 | 


train... loss:32.81031849980354:   0%|                                                                                                  | 4/5000 [01:03<16:47:24, 12.10s/it]

Val update: epoch: 3 |accuracy: 0.7995089292526245 | f1: 0.24701052904129028 | auc: 0.6464071273803711 | treshold: 0.12
Test: accuracy: 0.8014999628067017 | f1: 0.23390987515449524 | auc: 0.6387901902198792 | 


train... loss:32.49579495191574:   0%|                                                                                                  | 5/5000 [01:10<14:39:21, 10.56s/it]

Val update: epoch: 4 |accuracy: 0.7946301102638245 | f1: 0.24871066212654114 | auc: 0.6535422801971436 | treshold: 0.13
Test: accuracy: 0.796042799949646 | f1: 0.2365628182888031 | auc: 0.6498790383338928 | 


train... loss:32.295405477285385:   0%|                                                                                                 | 6/5000 [01:18<13:23:20,  9.65s/it]

Val update: epoch: 5 |accuracy: 0.7845398187637329 | f1: 0.2538672387599945 | auc: 0.65910804271698 | treshold: 0.13
Test: accuracy: 0.7863022685050964 | f1: 0.24689766764640808 | auc: 0.6564496159553528 | 


train... loss:32.10356479883194:   0%|▏                                                                                                 | 7/5000 [01:26<12:35:14,  9.08s/it]

Val update: epoch: 6 |accuracy: 0.7581181526184082 | f1: 0.2522769570350647 | auc: 0.6607139706611633 | treshold: 0.13
Test: accuracy: 0.7580104470252991 | f1: 0.24580739438533783 | auc: 0.6579915881156921 | 


train... loss:31.933301150798798:   0%|▏                                                                                                | 8/5000 [01:34<12:00:37,  8.66s/it]

Val update: epoch: 7 |accuracy: 0.6910026669502258 | f1: 0.24703748524188995 | auc: 0.6638944745063782 | treshold: 0.13
Test: accuracy: 0.6934544444084167 | f1: 0.24199019372463226 | auc: 0.66016685962677 | 


train... loss:31.229754716157913:   0%|▎                                                                                               | 14/5000 [02:18<10:32:54,  7.62s/it]

Val update: epoch: 13 |accuracy: 0.8507524132728577 | f1: 0.26379120349884033 | auc: 0.6643019914627075 | treshold: 0.14
Test: accuracy: 0.8519862294197083 | f1: 0.2516954243183136 | auc: 0.6623901128768921 | 


train... loss:31.074706226587296:   0%|▎                                                                                               | 15/5000 [02:26<10:37:36,  7.67s/it]

Val update: epoch: 14 |accuracy: 0.7622525095939636 | f1: 0.26350656151771545 | auc: 0.6666558980941772 | treshold: 0.13
Test: accuracy: 0.7624310255050659 | f1: 0.2583162784576416 | auc: 0.6674239039421082 | 


train... loss:30.93208223581314:   0%|▎                                                                                                | 17/5000 [02:41<10:32:37,  7.62s/it]

Val update: epoch: 16 |accuracy: 0.8742753267288208 | f1: 0.25340983271598816 | auc: 0.66755211353302 | treshold: 0.13
Test: accuracy: 0.8736318945884705 | f1: 0.2328336089849472 | auc: 0.6682276129722595 | 


train... loss:30.87713885307312:   0%|▎                                                                                                | 18/5000 [02:49<10:35:43,  7.66s/it]

Val update: epoch: 17 |accuracy: 0.867511510848999 | f1: 0.27142858505249023 | auc: 0.6701675057411194 | treshold: 0.13
Test: accuracy: 0.8671839237213135 | f1: 0.25332075357437134 | auc: 0.6709856986999512 | 


train... loss:31.115827798843384:   0%|▎                                                                                               | 19/5000 [02:57<10:41:03,  7.72s/it]

Val update: epoch: 18 |accuracy: 0.8423570394515991 | f1: 0.27674418687820435 | auc: 0.6743515133857727 | treshold: 0.14
Test: accuracy: 0.8423371315002441 | f1: 0.2633715569972992 | auc: 0.6702533960342407 | 


train... loss:30.251953333616257:   1%|▍                                                                                               | 26/5000 [03:48<10:18:00,  7.45s/it]

Val update: epoch: 25 |accuracy: 0.8460636734962463 | f1: 0.2802547812461853 | auc: 0.6761230230331421 | treshold: 0.14
Test: accuracy: 0.845477283000946 | f1: 0.26761072874069214 | auc: 0.6767306327819824 | 


train... loss:30.19716864824295:   1%|▌                                                                                                | 27/5000 [03:56<10:27:34,  7.57s/it]

Val update: epoch: 26 |accuracy: 0.8812925815582275 | f1: 0.2641398310661316 | auc: 0.676596999168396 | treshold: 0.14
Test: accuracy: 0.8813450932502747 | f1: 0.25866666436195374 | auc: 0.6780471801757812 | 


train... loss:30.09117679297924:   1%|▌                                                                                                | 30/5000 [04:18<10:24:12,  7.54s/it]

Val update: epoch: 29 |accuracy: 0.8707745671272278 | f1: 0.28287622332572937 | auc: 0.6773531436920166 | treshold: 0.14
Test: accuracy: 0.8712387084960938 | f1: 0.27275076508522034 | auc: 0.6777620911598206 | 


train... loss:30.341856330633163:   1%|▋                                                                                               | 35/5000 [05:02<11:54:59,  8.64s/it]



Evaluating NeuralClickModel with neural embeddings


In [None]:
pd.DataFrame(content_wise_results).to_csv(f'results/cw_{experiment_name}.csv')
del dataset, train_loader, val_loader, test_loader, train_user_item_matrix, train_num_items

# RL4RS

In [None]:
rl4rs_results = []
dataset = RL4RS.load(os.path.join(pkl_path, 'rl4rs.pkl'))
(
    train_loader, 
    val_loader, 
    test_loader, 
    train_user_item_matrix, 
    train_num_items 
) = get_train_val_test_tmatrix_tnumitems(dataset, batch_size=350)

print(f"{len(dataset)} data points among {len(train_loader)} batches")

In [None]:
for embeddings in ['neural','explicit', 'svd',  ]:
    print(f"\nEvaluating {experiment_name} with {embeddings} embeddings")

    model = NeuralClickModel(
        RecsysEmbedding(train_num_items, train_user_item_matrix, embeddings=embeddings),
    ).to(device)

    best_model, metrics = train(
        model, 
        train_loader, val_loader, test_loader, 
        device=device, lr=1e-3, num_epochs=5000, early_stopping=7,
        silent=True
    )
    
    metrics['embeddings'] = embeddings
    rl4rs_results.append(metrics)

In [None]:
pd.DataFrame(rl4rs_results).to_csv(f'results/rl4rs_{experiment_name}.csv')
del dataset, train_loader, val_loader, test_loader, train_user_item_matrix, train_num_items