In [1]:
import os
import warnings

import numpy as np
import pandas as pd
import torch
from implicit.als import AlternatingLeastSquares
from lightfm import LightFM
from lightning_fabric import seed_everything
from rectools import Columns
from rectools.dataset import Dataset
from rectools.metrics import (
    MAP,
    AvgRecPopularity,
    CoveredUsers,
    HitRate,
    Intersection,
    Recall,
    Serendipity,
)
from rectools.model_selection import TimeRangeSplitter, cross_validate
from rectools.models import (
    BERT4RecModel,
    EASEModel,
    ImplicitALSWrapperModel,
    LightFMWrapperModel,
    PopularModel,
    SASRecModel,
)

os.environ['OPENBLAS_NUM_THREADS'] = '1'
# Enable deterministic behaviour with CUDA >= 10.2
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'

warnings.filterwarnings(action='ignore', category=UserWarning)


RANDOM_STATE = 42
torch.use_deterministic_algorithms(True)
seed_everything(RANDOM_STATE, workers=True)

Seed set to 42


42

In [2]:
torch.cuda.is_available()

True

In [3]:
from enum import Enum


class ItemsFeatureTopKConfig(int, Enum):
    """Конфигурация для ограничения количества топовых значений фич."""

    DIRECTORS_TOP_K = 30
    STUDIOS_TOP_K = 15


In [4]:
interactions = pd.read_csv(r'../datasets/interactions_processed.csv')
users = pd.read_csv(r'../datasets/users_processed.csv')
items = pd.read_csv(r'../datasets/items_processed.csv')

# Обработка данных

In [5]:
Columns.Datetime = 'last_watch_dt'

In [6]:
interactions.drop(
    interactions[interactions[Columns.Datetime].str.len() != 10].index,
    inplace=True,
)
interactions[Columns.Datetime] = pd.to_datetime(
    interactions[Columns.Datetime], format='%Y-%m-%d'
)
max_date = interactions[Columns.Datetime].max()
interactions[Columns.Weight] = np.where(interactions['watched_pct'] > 20, 3, 1)

In [7]:
# Разделяем на train и test
train = interactions[
    interactions[Columns.Datetime] < max_date - pd.Timedelta(days=7)
].copy()
test = interactions[
    interactions[Columns.Datetime] >= max_date - pd.Timedelta(days=7)
].copy()

In [8]:
train.drop(train.query('total_dur < 300').index, inplace=True)
cold_users = set(test[Columns.User]) - set(train[Columns.User])
len(cold_users)

72930

In [9]:
# Отбрасываем холодных пользователей
test.drop(test[test[Columns.User].isin(cold_users)].index, inplace=True)

# Подготовка фич

## User features

In [10]:
users = users.loc[users[Columns.User].isin(train[Columns.User])].copy()

In [11]:
user_features_frames = []
for feature in ['sex', 'age', 'income']:
    feature_frame = users.reindex(columns=[Columns.User, feature])
    feature_frame.columns = ['id', 'value']
    feature_frame['feature'] = feature
    user_features_frames.append(feature_frame)
user_features = pd.concat(user_features_frames)
user_features.head()

Unnamed: 0,id,value,feature
0,973171,М,sex
1,962099,М,sex
3,721985,Ж,sex
4,704055,Ж,sex
5,1037719,М,sex


## Item features

In [12]:
items = items.loc[items[Columns.Item].isin(train[Columns.Item])].copy()

In [13]:
items['genre'] = (
    items['genres'].str.replace(', ', ',', regex=False).str.split(',')
)
genre_feature = items[[Columns.Item, 'genre']].explode('genre')
genre_feature.columns = ['id', 'value']
genre_feature['feature'] = 'genre'
genre_feature.head()

Unnamed: 0,id,value,feature
0,10711,драмы,genre
0,10711,зарубежные,genre
0,10711,детективы,genre
0,10711,мелодрамы,genre
1,2508,зарубежные,genre


In [14]:
content_feature = items.reindex(columns=[Columns.Item, 'content_type'])
content_feature.columns = ['id', 'value']
content_feature['feature'] = 'content_type'

In [15]:
countries_feature = items.reindex(columns=[Columns.Item, 'countries'])
countries_feature.columns = ['id', 'value']
countries_feature['feature'] = 'countries'

In [16]:
release_decade_feature = items.reindex(columns=[Columns.Item, 'release_decade'])
release_decade_feature.columns = ['id', 'value']
release_decade_feature['feature'] = 'release_decade'

In [17]:
release_decade_feature['value'].value_counts()

value
2010.0s                 8091
2000.0s                 1955
2020.0s                 1682
1980.0s                  613
1990.0s                  572
1970.0s                  467
1960.0s                  270
1950.0s                  143
1940.0s                   91
1930.0s                   80
release_year_unknown      31
1920.0s                   17
1910.0s                    6
Name: count, dtype: int64

In [18]:
age_rating_feature = items.reindex(columns=[Columns.Item, 'age_rating'])
age_rating_feature.columns = ['id', 'value']
age_rating_feature['feature'] = 'age_rating'

Берем только ТОП-K студий, а остальные заменяем на 'other_studio'. 'other_studio' и 'unknown_studio' - **разные** вещи!

In [19]:
def replace_rare_studios(studio_list):
    return [
        studio if studio in top_studios else 'other_studio'
        for studio in studio_list
    ]


items['studio'] = items['studios'].str.split(r',\s*')
top_studios = (
    items['studio']
    .explode()
    .value_counts()
    .head(ItemsFeatureTopKConfig.STUDIOS_TOP_K)
    .index
)
items['studio'] = items['studio'].apply(replace_rare_studios)

In [20]:
studios_feature = items[[Columns.Item, 'studio']].explode('studio')
studios_feature.columns = ['id', 'value']
studios_feature['feature'] = 'studios'

In [21]:
# Для директоров оставляем топ-30, остальные заменяем на 'other'
items['directors'] = (
    items['directors'].str.replace(', ', ',', regex=False).str.split(',')
)
top_directors = (
    items['directors']
    .explode()
    .value_counts()
    .head(ItemsFeatureTopKConfig.DIRECTORS_TOP_K)
    .index
)

items['director'] = items['directors'].apply(
    lambda x: [d if d in top_directors else 'other_director' for d in x]
)

In [22]:
directors_feature = items[[Columns.Item, 'director']].explode('director')
directors_feature.columns = ['id', 'value']
directors_feature['feature'] = 'director'
directors_feature.head()

Unnamed: 0,id,value,feature
0,10711,other_director,director
1,2508,other_director,director
2,10716,other_director,director
3,7868,other_director,director
4,16268,other_director,director


In [23]:
item_features = pd.concat((
    genre_feature,
    content_feature,
    countries_feature,
    release_decade_feature,
    age_rating_feature,
    studios_feature,
    directors_feature,
))


In [24]:
item_features.info()

<class 'pandas.core.frame.DataFrame'>
Index: 123469 entries, 0 to 15961
Data columns (total 3 columns):
 #   Column   Non-Null Count   Dtype 
---  ------   --------------   ----- 
 0   id       123469 non-null  int64 
 1   value    123469 non-null  object
 2   feature  123469 non-null  object
dtypes: int64(1), object(2)
memory usage: 3.8+ MB


In [25]:
CAT_USER_FEATURES = list(user_features['feature'].unique())
CAT_ITEM_FEATURES = list(item_features['feature'].unique())

In [26]:
CAT_ITEM_FEATURES

['genre',
 'content_type',
 'countries',
 'release_decade',
 'age_rating',
 'studios',
 'director']

In [27]:
dataset = Dataset.construct(
    interactions_df=interactions,
    user_features_df=user_features,
    cat_user_features=CAT_USER_FEATURES,
    item_features_df=item_features,
    cat_item_features=CAT_ITEM_FEATURES,
)

TEST_USERS = test[Columns.User].unique()

In [28]:
splitter = TimeRangeSplitter(
    test_size='7D',
    n_splits=1,
    filter_already_seen=True,
)

In [29]:
GLOBAL_K = 10
metrics = {
    f'Recall@{GLOBAL_K}': Recall(GLOBAL_K),
    f'HitRate@{GLOBAL_K}': HitRate(GLOBAL_K),
    f'MAP@{GLOBAL_K}': MAP(GLOBAL_K),
    f'Serendipity@{GLOBAL_K}': Serendipity(GLOBAL_K),
    # how many test users received recommendations
    f'CoveredUsers@{GLOBAL_K}': CoveredUsers(GLOBAL_K),
    # average popularity of recommended items
    f'AvgRecPopularity@{GLOBAL_K}': AvgRecPopularity(GLOBAL_K),
    # intersection with recommendations from reference model
    f'Intersection@{GLOBAL_K}': Intersection(GLOBAL_K),
}

In [30]:
models = {
    'popular': PopularModel(),
    'ease': EASEModel(),
    'ials': ImplicitALSWrapperModel(
        AlternatingLeastSquares(
            factors=32,
            regularization=0.01883534498756549,
            iterations=5,
        )
    ),
    'lightfm': LightFMWrapperModel(
        LightFM(
            no_components=128,
            learning_rate=0.002680734151218913,
            rho=0.927338160882052,
            loss='warp',
            epsilon=3.2185481401279125e-06,
            user_alpha=0,
            item_alpha=0,
            random_state=RANDOM_STATE,
        ),
        epochs=1,
        num_threads=4,
        verbose=1,
    ),
    'bert4rec_softmax_ids_and_cat': BERT4RecModel(
        mask_prob=0.15,
        deterministic=True,
    ),
}

  check_blas_config()
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [31]:
%%time

# For each fold generate train and test part of dataset
# Then fit every model, generate recommendations and calculate metrics

cv_results = cross_validate(
    dataset=dataset,
    splitter=splitter,
    models=models,
    metrics=metrics,
    k=GLOBAL_K,
    filter_viewed=True,
    # pass reference model to calculate recommendations intersection
    ref_models=['popular'],
    validate_ref_models=True,
)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  unq_values = pd.unique(values)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
`Trainer.fit` stopped: `max_epochs=3` reached.


CPU times: total: 2h 41min 13s
Wall time: 1h 20min 20s


In [32]:
pivot_results = (
    pd.DataFrame(cv_results['metrics'])
    .drop(columns='i_split')
    .groupby(['model'], sort=False)
    .agg(['mean'])
)
pivot_results.columns = pivot_results.columns.droplevel(1)
pivot_results.to_csv('rectools_transformers_cv.csv', index=True)
pivot_results

Unnamed: 0_level_0,Recall@10,HitRate@10,MAP@10,AvgRecPopularity@10,Serendipity@10,Intersection@10_popular,CoveredUsers@10
model,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
popular,0.166089,0.274365,0.080114,82236.761783,2e-06,1.0,1.0
ease,0.081169,0.151532,0.027664,9327.264478,0.000268,0.076693,1.0
ials,0.13491,0.23859,0.060807,46040.167013,5e-05,0.354506,1.0
bert4rec_softmax_ids_and_cat,0.200544,0.334551,0.095186,46393.294402,0.000174,0.366816,1.0
