In [None]:
import os
import time
import json
import pickle

import requests
import numpy as np
import pandas as pd

from tqdm import tqdm
from tools import compute_metrics

import warnings
warnings.filterwarnings(action='ignore', category=UserWarning)

# Подготовка данных

## Загрузка данных

In [None]:
url = "https://storage.yandexcloud.net/itmo-recsys-public-data/kion_train.zip"

# Stream the archive so it is never held in memory as a whole. Using the
# response as a context manager releases the connection even if the download
# is interrupted half-way.
with requests.get(url, stream=True, timeout=60) as req:
    # Fail fast on an HTTP error instead of saving an error page as a .zip.
    req.raise_for_status()
    total_size_in_bytes = int(req.headers.get('Content-Length', 0))

    # tqdm is used as a context manager so the bar is closed (and its final
    # state rendered) once the loop ends. When the server sends no
    # Content-Length, `None` makes tqdm show a plain counter instead of a
    # bar with a zero total.
    with open('kion_train.zip', "wb") as fd, tqdm(
        desc='kion dataset download',
        total=total_size_in_bytes or None,
        unit='iB',
        unit_scale=True,
    ) as progress_bar:
        for chunk in req.iter_content(chunk_size=2 ** 20):
            fd.write(chunk)
            progress_bar.update(len(chunk))

kion dataset download:  98%|█████████▊| 77.6M/78.8M [00:04<00:00, 20.8MiB/s]

In [None]:
# Unpack the downloaded archive into the current working directory.
!unzip kion_train.zip

Archive:  kion_train.zip
   creating: kion_train/
  inflating: kion_train/interactions.csv  
  inflating: __MACOSX/kion_train/._interactions.csv  
  inflating: kion_train/users.csv    
  inflating: __MACOSX/kion_train/._users.csv  
  inflating: kion_train/items.csv    
  inflating: __MACOSX/kion_train/._items.csv  


In [None]:
# Load the interaction log from the unpacked archive and preview its first rows.
interactions = pd.read_csv('kion_train/interactions.csv')
interactions.head()

Unnamed: 0,user_id,item_id,last_watch_dt,total_dur,watched_pct
0,176549,9506,2021-05-11,4250,72.0
1,699317,1659,2021-05-29,8317,100.0
2,656683,7107,2021-05-09,10,0.0
3,864613,7638,2021-07-05,14483,100.0
4,964868,9506,2021-04-30,6725,100.0


In [None]:
# Keep only rows whose date string is exactly 10 characters long (the length
# of 'YYYY-MM-DD'); missing dates fail the comparison and are removed as well.
has_well_formed_date = interactions['last_watch_dt'].str.len() == 10
interactions = interactions.loc[has_well_formed_date].copy()
interactions['last_watch_dt'] = pd.to_datetime(interactions['last_watch_dt'], format='%Y-%m-%d')

# Latest date in the log; taken before the duration filter below.
max_date = interactions['last_watch_dt'].max()

# Remove rows with total_dur below 300. The mask is written as a negation so
# that rows with a missing total_dur are kept.
too_short = interactions['total_dur'] < 300
interactions = interactions.loc[~too_short].copy()

# Replace the watched percentage with a two-level weight:
# 3 when more than 10 percent was watched, otherwise 1.
interactions['watched_pct'] = np.where(interactions['watched_pct'] > 10, 3, 1)

In [None]:
# Number of distinct users and distinct items in the cleaned interaction log.
interactions.user_id.nunique(), interactions.item_id.nunique()

(809577, 14163)

## Преобразование формата данных под BERT4Rec и SasRec

In [None]:
# Time-based split: interactions dated on or after (max_date - 7 days) are
# held out as the test set, everything earlier is used for training.
split_date = max_date - pd.Timedelta(days=7)
watch_dates = interactions['last_watch_dt']

train = interactions.loc[watch_dates < split_date].copy()
test = interactions.loc[watch_dates >= split_date].copy()

In [None]:
# Drop sparse users and sparse items from the training set.
# NOTE(review): the previous comment here spoke of a history shorter than 10,
# while the code keeps users with MORE than 20 interactions — confirm which
# threshold is intended.
user_counts = train.groupby('user_id')['item_id'].count()
valid_users = user_counts[user_counts > 20].index

# Items must have more than 10 interactions. The counts are taken on the
# frame as it is before the user filter below is applied.
item_counts = train.groupby('item_id')['user_id'].count()
valid_items = item_counts[item_counts > 10].index

# Both conditions are row-wise, so one combined mask selects the same rows as
# two consecutive filters. .copy() detaches the result from the frame it was
# sliced from, so later cells that assign into `train` work on an independent
# frame.
keep_row = train['user_id'].isin(valid_users) & train['item_id'].isin(valid_items)
train = train[keep_row].copy()

In [None]:
# Users that appear in the test split but not in the filtered training set.
cold_users = set(test['user_id']).difference(train['user_id'])

# Drop cold users from the test split.
is_cold = test['user_id'].isin(cold_users)
test.drop(index=test.index[is_cold], inplace=True)

In [None]:
# Contiguous 1-based indices for users and items, assigned in sorted-id order.
#   user_idx_map_inv: user index -> original user_id
#   item_idx_map:     original item_id -> item index
# NOTE: this cell overwrites train['item_id'] with the mapped indices, so it
# must be executed only once per kernel session; a second run would build the
# item map from already-mapped values.
user_idx_map_inv = {idx+1:user_id for idx, user_id in enumerate(np.sort(train.user_id.unique()))}
item_idx_map = {item_id:idx+1 for idx, item_id in enumerate(np.sort(train.item_id.unique()))}

# Order interactions chronologically so that each user's item list is a watch
# sequence. A stable sort keeps the existing row order for interactions that
# share a date, which makes the generated sequences reproducible. The sorted
# frame is rebound rather than sorted in place, so a frame obtained by
# filtering is not mutated.
train = train.sort_values(by='last_watch_dt', kind='mergesort')
train['item_id'] = train['item_id'].map(item_idx_map)
train_users_items = train.groupby('user_id')['item_id'].apply(list)

# np.arange excludes its stop value, so "+ 1" is required for the last user
# index (== number of users) to be included; without it that user's history
# is missing from train.txt.
n_users = train['user_id'].nunique()
users_history = pd.DataFrame({
    'user_id': np.arange(1, n_users + 1)
})

# Attach each user's sequence, then emit one "<user_idx> <item_idx>" line per
# interaction.
users_history['item_id'] = users_history['user_id'].map(lambda x: train_users_items[user_idx_map_inv[x]])
users_history = users_history.explode('item_id')

users_history.to_csv('/content/train.txt', sep=' ', index=False, header=False)

# BERT4Rec

In [None]:
# Fetch the BERT4Rec implementation and switch the notebook's working
# directory into it, so the shell commands and relative paths in the
# following cells are resolved from there.
# NOTE(review): `git clone` fails when the target directory already exists,
# so this cell cannot be re-run as is.
!git clone https://github.com/Tagirov0/BERT4rec_py3_tf2.git
%cd /content/BERT4rec_py3_tf2/BERT4rec

Cloning into 'BERT4rec_py3_tf2'...
remote: Enumerating objects: 184, done.[K
remote: Counting objects: 100% (56/56), done.[K
remote: Compressing objects: 100% (56/56), done.[K
remote: Total 184 (delta 18), reused 0 (delta 0), pack-reused 128[K
Receiving objects: 100% (184/184), 71.58 MiB | 2.06 MiB/s, done.
Resolving deltas: 100% (87/87), done.
/content/BERT4rec_py3_tf2/BERT4rec


In [None]:
# Launch the repository's run script from the current working directory.
# NOTE(review): the script is named after ml-1m; presumably it was adapted to
# consume the KION train.txt written earlier — confirm.
!./run_ml-1m.sh

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
I0509 15:58:02.642991 140437868455744 basic_session_run_hooks.py:263] loss = 7.177764, step = 295300 (0.958 sec)
INFO:tensorflow:global_step/sec: 105.279
I0509 15:58:03.592465 140437868455744 basic_session_run_hooks.py:716] global_step/sec: 105.279
INFO:tensorflow:loss = 5.222242, step = 295400 (0.950 sec)
I0509 15:58:03.592834 140437868455744 basic_session_run_hooks.py:263] loss = 5.222242, step = 295400 (0.950 sec)
INFO:tensorflow:global_step/sec: 103.078
I0509 15:58:04.562634 140437868455744 basic_session_run_hooks.py:716] global_step/sec: 103.078
INFO:tensorflow:loss = 5.0558815, step = 295500 (0.970 sec)
I0509 15:58:04.563043 140437868455744 basic_session_run_hooks.py:263] loss = 5.0558815, step = 295500 (0.970 sec)
INFO:tensorflow:global_step/sec: 104.637
I0509 15:58:05.518297 140437868455744 basic_session_run_hooks.py:716] global_step/sec: 104.637
INFO:tensorflow:loss = 8.00245, step = 295600 (0.956 sec)
I0509 15:5

### Проверка точности

In [None]:
# Inverse item mapping: item index -> original item_id.
item_idx_map_inv = dict(zip(item_idx_map.values(), item_idx_map.keys()))

In [None]:
def get_preds_df(path, user_idx_map_inv, item_idx_map_inv):
    """Read a predictions file and convert it to original ids with ranks.

    The file is space-separated with two columns and no header: a user index
    and an item token whose first five characters are a prefix in front of
    the integer item index.

    Parameters
    ----------
    path : path of the predictions file.
    user_idx_map_inv : mapping from user index to original user_id.
    item_idx_map_inv : mapping from item index to original item_id.

    Returns
    -------
    pandas.DataFrame with columns ``user_id``, ``item_id`` and ``rank``,
    where ``rank`` is the 1-based position of the row among the rows of the
    same user, in file order.

    Raises
    ------
    KeyError if a user or item index is absent from the given mappings.
    """
    raw = pd.read_csv(path, sep=' ', header=None, names=['user_id', 'item_id'])

    # Cut the five-character prefix off and parse the remainder as an integer.
    item_indices = raw['item_id'].str.slice(start=5).astype(int)

    # Position within each user's block of rows, counted from 1.
    ranks = raw.groupby('user_id').cumcount() + 1

    # Strict dictionary lookups: an unknown index raises instead of yielding NaN.
    return pd.DataFrame({
        'user_id': raw['user_id'].map(user_idx_map_inv.__getitem__),
        'item_id': item_indices.map(item_idx_map_inv.__getitem__),
        'rank': ranks,
    })

In [None]:
# Predictions of the first BERT4Rec run, read relative to the current
# working directory.
# NOTE(review): the file name carries a "_256" suffix; presumably the run's
# output was renamed by hand before the second run — confirm.
preds_path = 'bert4rec_kion_preds_256.txt'
preds_bert4rec_256 = get_preds_df(
    path=preds_path,
    user_idx_map_inv=user_idx_map_inv,
    item_idx_map_inv=item_idx_map_inv,
)
preds_bert4rec_256.head()

Unnamed: 0,user_id,item_id,rank
20,226847,512,1
21,226847,7793,2
22,226847,3784,3
23,226847,9817,4
24,226847,10878,5


In [None]:
# hidden_size = 256
# `metrics` collects one result per model name. It is created here on first
# use so the notebook runs on a fresh kernel; an existing dict is kept so that
# re-running this cell does not discard other models' results.
if 'metrics' not in globals():
    metrics = {}

# Evaluate the frame loaded for this run (`preds_bert4rec_256`); the name
# `preds_bert4rec` used here before is not defined anywhere in the notebook.
metrics['BERT4Rec_256'] = compute_metrics(test, preds_bert4rec_256, 10)

In [None]:
# Second launch of the same run script.
# NOTE(review): presumably executed after switching hidden_size to 128 in the
# script/configuration (the metrics cells are labelled 256 and 128) — confirm.
!./run_ml-1m.sh

In [None]:
# Predictions of the second BERT4Rec run, read relative to the current
# working directory.
preds_path = 'bert4rec_kion_preds.txt'
preds_bert4rec_128 = get_preds_df(
    path=preds_path,
    user_idx_map_inv=user_idx_map_inv,
    item_idx_map_inv=item_idx_map_inv,
)
preds_bert4rec_128.head()

Unnamed: 0,user_id,item_id,rank
0,1047828,4495,1
1,1047828,12192,2
2,1047828,15297,3
3,1047828,7829,4
4,1047828,3784,5


In [None]:
# hidden_size = 128
# Evaluate the second BERT4Rec run against the test split and store the
# result under its model name. The third argument (10) is presumably the
# largest cut-off k: the metrics table reports values up to @10 — confirm
# against tools.compute_metrics.
metrics['BERT4Rec_128'] = compute_metrics(test, preds_bert4rec_128, 10)

In [None]:
# Side-by-side comparison: one single-row frame per model, stacked in a fixed
# order and labelled with the model names.
model_names = ['BERT4Rec_128', 'BERT4Rec_256']

df_metrics = pd.concat(
    [pd.DataFrame(metrics[name]).transpose() for name in model_names]
)
df_metrics.index = model_names

df_metrics

Unnamed: 0,Precision@1,Recall@1,Precision@2,Recall@2,Precision@3,Recall@3,Precision@4,Recall@4,Precision@5,Recall@5,...,Precision@7,Recall@7,Precision@8,Recall@8,Precision@9,Recall@9,Precision@10,Recall@10,MAP@10,MRR
BERT4Rec_128,0.040185,0.014526,0.037563,0.0264,0.035327,0.036685,0.033562,0.04514,0.031944,0.053465,...,0.029813,0.06884,0.028904,0.074823,0.028304,0.081933,0.027578,0.087972,0.033895,0.086145
BERT4Rec_256,0.042297,0.014966,0.037972,0.02563,0.035894,0.036196,0.034345,0.04623,0.032734,0.053887,...,0.029988,0.067829,0.028734,0.073015,0.027933,0.079889,0.027278,0.086122,0.033808,0.087404


# SasRec

In [None]:
# Fetch the PyTorch SASRec implementation and switch the notebook's working
# directory into it, so the commands and relative paths in the following
# cells are resolved from there.
# NOTE(review): `git clone` fails when the target directory already exists,
# so this cell cannot be re-run as is.
!git clone https://github.com/pmixer/SASRec.pytorch.git
%cd /content/SASRec.pytorch

Cloning into 'SASRec.pytorch'...
remote: Enumerating objects: 80, done.[K
remote: Counting objects: 100% (30/30), done.[K
remote: Compressing objects: 100% (9/9), done.[K
remote: Total 80 (delta 25), reused 21 (delta 21), pack-reused 50[K
Unpacking objects: 100% (80/80), 17.96 MiB | 4.40 MiB/s, done.
/content/SASRec.pytorch


Попробовал две версии SASRec (на TensorFlow и на PyTorch): обе дают практически нулевые метрики. Также пробовал получать рекомендации последовательно, но это не помогло.

In [None]:
# Train SASRec on the GPU. Per the CLI flags: dataset name `train`, output
# directory suffix `default`, maxlen 50, 100 epochs.
# NOTE(review): the dataset name presumably refers to the train.txt written
# earlier; confirm where main.py looks for it.
!python main.py --device=cuda --dataset=train --train_dir=default --maxlen=50 --num_epochs=100

[1;30;43mВыходные данные были обрезаны до нескольких последних строк (5000).[0m
loss in epoch 80 iteration 106: 0.5992926359176636
loss in epoch 80 iteration 107: 0.5735217332839966
loss in epoch 80 iteration 108: 0.5583643913269043
loss in epoch 80 iteration 109: 0.5602438449859619
loss in epoch 80 iteration 110: 0.5684735774993896
loss in epoch 80 iteration 111: 0.5660626888275146
loss in epoch 80 iteration 112: 0.5571335554122925
loss in epoch 80 iteration 113: 0.549934983253479
loss in epoch 80 iteration 114: 0.5453009009361267
loss in epoch 80 iteration 115: 0.5630354881286621
loss in epoch 80 iteration 116: 0.553153395652771
loss in epoch 80 iteration 117: 0.5710519552230835
loss in epoch 80 iteration 118: 0.5440434813499451
loss in epoch 80 iteration 119: 0.550947904586792
loss in epoch 80 iteration 120: 0.5570998787879944
loss in epoch 80 iteration 121: 0.5350707769393921
loss in epoch 80 iteration 122: 0.5647276639938354
loss in epoch 80 iteration 123: 0.5434165000915527
los

In [None]:
# Inference-only run using the checkpoint given by --state_dict_path (the
# file name encodes epoch=100, lr=0.001, 2 layers, 1 head, hidden=50, maxlen=50).
# NOTE(review): predict.py is not shown in this notebook; presumably it writes
# the sasrec_preds.json file that the next cell loads — confirm.
!python predict.py --device=cuda --dataset=train --train_dir=default --inference_only=true --maxlen=50 --state_dict_path='/content/SASRec.pytorch/train_default/SASRec.epoch=100.lr=0.001.layer=2.head=1.hidden=50.maxlen=50.pth' 

average sequence length: 34.90
100% 31137/31137 [01:41<00:00, 305.78it/s]
Done


In [None]:
# Inverse item mapping: item index -> original item_id.
item_idx_map_inv = {idx: item_id for item_id, idx in item_idx_map.items()}

# Read the predictions through a context manager so the file handle is closed
# as soon as parsing is done.
with open('sasrec_preds.json') as preds_file:
    sasrec_preds_by_idx = json.load(preds_file)

# Translate indices back to original ids. JSON object keys are always strings,
# hence the int() cast on the user key; item entries are cast the same way.
sasrec_preds = {
    user_idx_map_inv[int(user)]: [item_idx_map_inv[int(item_idx)] for item_idx in items]
    for user, items in sasrec_preds_by_idx.items()
}

In [None]:
# Start from one row per distinct test user ...
sasrec = pd.DataFrame({'user_id': test['user_id'].unique()})

# ... attach that user's recommendation list (strict lookup: a user without
# predictions raises KeyError), then unfold the lists into one row per item.
sasrec = (
    sasrec
    .assign(item_id=sasrec['user_id'].map(sasrec_preds.__getitem__))
    .explode('item_id')
    .astype({'item_id': int})
)

# 1-based position of each item within its user's list.
sasrec['rank'] = sasrec.groupby('user_id').cumcount() + 1

In [None]:
# Score the SASRec recommendations against the test split with the same
# compute_metrics call used for BERT4Rec, then display the result.
sasrec_metrics = compute_metrics(test, sasrec, 10)
sasrec_metrics

Precision@1     0.000545
Recall@1        0.000139
Precision@2     0.000443
Recall@2        0.000203
Precision@3     0.000500
Recall@3        0.000383
Precision@4     0.000528
Recall@4        0.000469
Precision@5     0.000545
Recall@5        0.000678
Precision@6     0.000568
Recall@6        0.000999
Precision@7     0.000574
Recall@7        0.001219
Precision@8     0.000622
Recall@8        0.001505
Precision@9     0.000658
Recall@9        0.001748
Precision@10    0.000654
Recall@10       0.001864
MAP@10          0.000454
MRR             0.001683
dtype: float64