In [1]:
!pip install pyspark
!pip install pytorch-lifestream

Collecting pytorch-lifestream
  Downloading pytorch-lifestream-0.6.0.tar.gz (163 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m163.4/163.4 kB[0m [31m5.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting hydra-core>=1.1.2 (from pytorch-lifestream)
  Downloading hydra_core-1.3.2-py3-none-any.whl.metadata (5.5 kB)
Downloading hydra_core-1.3.2-py3-none-any.whl (154 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m154.5/154.5 kB[0m [31m10.6 MB/s[0m eta [36m0:00:00[0m
[?25hBuilding wheels for collected packages: pytorch-lifestream
  Building wheel for pytorch-lifestream (pyproject.toml) ... [?25l[?25hdone
  Created wheel for pytorch-lifestream: filename=pytorch_lifestream-0.6.0-py3-none-any.whl size=274639 sha256=b1b388f9531cb2195da8ca4272039753a1b3079361354a960c06e17d41ebc0a8
 

In [2]:
new_config = """#!/usr/bin/env bash

mkdir -p data
cd data/
mkdir -p raw_data
cd raw_data/

curl -OL https://storage.yandexcloud.net/datasouls-ods/materials/0433a4ca/transactions.zip
curl -OL https://storage.yandexcloud.net/datasouls-ods/materials/0554f0cf/clickstream.zip
curl -OL https://storage.yandexcloud.net/datasouls-ods/materials/acfacf11/train_matching.csv

curl -OL https://storage.yandexcloud.net/datasouls-ods/materials/b949c04c/mcc_codes.csv
curl -OL https://storage.yandexcloud.net/datasouls-ods/materials/705abbab/click_categories.csv
curl -OL https://storage.yandexcloud.net/datasouls-ods/materials/e33f2201/currency_rk.csv

curl -OL https://storage.yandexcloud.net/datasouls-ods/materials/a3643657/sample_submission.zip
curl -OL https://storage.yandexcloud.net/datasouls-ods/materials/24687252/baseline_catboost.zip

curl -OL https://storage.yandexcloud.net/datasouls-ods/materials/b99fed70/puzzle.csv

cd ../../
"""

with open('/kaggle/working/get_data.sh', mode='w') as file:
    file.write(new_config)

In [3]:
!source /kaggle/working/get_data.sh

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  250M  100  250M    0     0  15.5M      0  0:00:16  0:00:15  0:00:01 17.8M0  15.7M      0  0:00:15  0:00:15 --:--:-- 17.9M
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  469M  100  469M    0     0  12.7M      0  0:00:36  0:00:36 --:--:-- 13.8M
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 1045k  100 1045k    0     0   604k      0  0:00:01  0:00:01 --:--:--  604k
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  195k  100  195k    0     0   156k      0  0:00:01  0:00:01 --:--:--  156k
  % 

# Prepare Data

In [4]:
!mkdir data/train_matching_folds
!mkdir data/transactions_folds
!mkdir data/clickstream_folds

# Get data splits

In [5]:
import logging

import numpy as np
import pandas as pd
from zipfile import ZipFile

from sklearn.model_selection import StratifiedKFold

N_SPLITS = 6
CNT_BIN_COUNT = 4
SHUFFLE_RANDOM_STATE = 42

logger = logging.getLogger(__name__)


def get_split_data():
    print('Data loading...')
    df_train_matching = pd.read_csv('data/raw_data/train_matching.csv')

    with ZipFile('data/raw_data/transactions.zip') as z:
        df_transactions = pd.read_csv(z.open('transactions.csv'))
    with ZipFile('data/raw_data/clickstream.zip') as z:
        df_clickstream = pd.read_csv(z.open('clickstream.csv'))

    print(f'Loaded {len(df_train_matching)} pairs, '
                f'{len(df_transactions)} transactions, {len(df_clickstream)} clicks')

    vc = df_transactions['user_id'].value_counts()
    s_trx_cnt_bins = pd.cut(
        vc,
        vc.quantile(np.linspace(0, 1, CNT_BIN_COUNT + 1)),
        labels=np.arange(CNT_BIN_COUNT),
    ).fillna(0).astype(str).rename('trx_cnt_bins')
    vc = df_clickstream['user_id'].value_counts()
    s_click_cnt_bins = pd.cut(
        vc,
        vc.quantile(np.linspace(0, 1, CNT_BIN_COUNT + 1)),
        labels=np.arange(CNT_BIN_COUNT),
    ).fillna(0).astype(str).rename('click_cnt_bins')
    print(f'Prepared {CNT_BIN_COUNT} bins for trx and clicks')

    df_train_matching = pd.merge(df_train_matching, s_trx_cnt_bins, left_on='bank', right_index=True, how='left')
    df_train_matching = pd.merge(df_train_matching, s_click_cnt_bins, left_on='rtk', right_index=True, how='left')
    df_train_matching['trx_cnt_bins'] = df_train_matching['trx_cnt_bins'].fillna(str(CNT_BIN_COUNT))
    df_train_matching['click_cnt_bins'] = df_train_matching['click_cnt_bins'].fillna(str(CNT_BIN_COUNT))
    df_train_matching['cnt_bins'] = df_train_matching['trx_cnt_bins'] + df_train_matching['click_cnt_bins']
    df_train_matching = df_train_matching[['bank', 'rtk', 'cnt_bins']]

    skf = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=SHUFFLE_RANDOM_STATE)
    for fold_id, (_, ix_folds) in enumerate(skf.split(df_train_matching, df_train_matching['cnt_bins'])):
        df_match = df_train_matching[['bank', 'rtk']].iloc[ix_folds]
        df_match.to_csv(f'data/train_matching_folds/train_matching_{fold_id}.csv', index=False)

        df_trx = df_transactions[lambda x: x['user_id'].isin(df_match['bank'].values)]
        df_trx.to_csv(f'data/transactions_folds/transactions_{fold_id}.csv', index=False)

        df_click = df_clickstream[lambda x: x['user_id'].isin(df_match['rtk'].values)]
        df_click.to_csv(f'data/clickstream_folds/clickstream_{fold_id}.csv', index=False)

        print(f'Saved data for fold {fold_id}: {len(df_match)} pairs, '
                    f'{len(df_trx)} transactions ({len(df_trx["user_id"].unique())} unique users), '
                    f'{len(df_click)} clicks ({len(df_click["user_id"].unique())} unique users)')

    df_trx = df_transactions[lambda x: ~x['user_id'].isin(df_train_matching['bank'].values)]
    df_trx.to_csv(f'data/transactions_unmatched.csv', index=False)

    df_click = df_clickstream[lambda x: ~x['user_id'].isin(df_train_matching['rtk'].values)]
    df_click.to_csv(f'data/clickstream_unmatched.csv', index=False)
    print(f'Saved unmatched data: '
                f'{len(df_trx)} transactions ({len(df_trx["user_id"].unique())} unique users), '
                f'{len(df_click)} clicks ({len(df_click["user_id"].unique())} unique users)')

    print(f'All splits saved')

get_split_data()

Data loading...
Loaded 17581 pairs, 19821910 transactions, 126752515 clicks
Prepared 4 bins for trx and clicks
Saved data for fold 0: 2931 pairs, 2572365 transactions (2931 unique users), 15124893 clicks (2446 unique users)
Saved data for fold 1: 2930 pairs, 2583449 transactions (2930 unique users), 16898083 clicks (2445 unique users)
Saved data for fold 2: 2930 pairs, 2579526 transactions (2930 unique users), 16838662 clicks (2445 unique users)
Saved data for fold 3: 2930 pairs, 2568464 transactions (2930 unique users), 15739400 clicks (2445 unique users)
Saved data for fold 4: 2930 pairs, 2576772 transactions (2930 unique users), 15029891 clicks (2445 unique users)
Saved data for fold 5: 2930 pairs, 2559992 transactions (2930 unique users), 15165618 clicks (2445 unique users)
Saved unmatched data: 4381342 transactions (4952 unique users), 31955968 clicks (4952 unique users)
All splits saved


In [6]:
from glob import glob
data_path = "data/train_matching_folds/train_matching_*.csv"

valid_fold_id = 4
folds_count = len(glob(f'{data_path}'))
train_folds = [i for i in range(folds_count) if valid_fold_id is not None and i != valid_fold_id]

In [7]:
train_folds = [0]

In [8]:
import pandas as pd
import torch


def trx_types(df):
    df['mcc_code'] = df['mcc_code'].astype(str)
    df['currency_rk'] = df['currency_rk'].astype(str)
    df['event_time'] = pd.to_datetime(df['transaction_dttm']).astype(int) / 1e9
    return df[['user_id', 'event_time', 'mcc_code', 'currency_rk', 'transaction_amt']]


def click_types(df):
    df['event_time'] = pd.to_datetime(df['timestamp']).astype(int) / 1e9
    df = pd.merge(df, pd.read_csv('./data/raw_data/click_categories.csv'), on='cat_id')
    df['cat_id'] = df['cat_id'].astype(str)
    return df[['user_id', 'event_time', 'cat_id', 'level_0', 'level_1', 'level_2', 'new_uid']]


def trx_to_torch(seq):
    for x in seq:
        yield x['user_id'], {
            'event_time': x['event_time'],
            'mcc_code': x['mcc_code'],
            'currency_rk': x['currency_rk'],
            'transaction_amt': x['transaction_amt'],
        }


def click_to_torch(seq):
    for x in seq:
        yield x['user_id'], {
            'event_time': torch.from_numpy(x['event_time']).float(),
            'cat_id': torch.from_numpy(x['cat_id']).int(),
            'level_0': torch.from_numpy(x['level_0']).int(),
            'level_1': torch.from_numpy(x['level_1']).int(),
            'level_2': torch.from_numpy(x['level_2']).int(),
            'new_uid': torch.from_numpy(x['new_uid']).int(),
        }


In [9]:
import pandas as pd

data_path = "./data/"
df_matching_train = pd.concat([pd.read_csv(f'{data_path}/train_matching_folds/train_matching_{i}.csv') for i in train_folds])
df_trx_train = pd.concat([trx_types(pd.read_csv(f'{data_path}/transactions_folds/transactions_{i}.csv')) for i in train_folds])
df_click_train = pd.concat([click_types(pd.read_csv(f'{data_path}/clickstream_folds/clickstream_{i}.csv')) for i in train_folds])

In [10]:
df_click_train['user_id'].unique()

array(['0016b2dad12c450b8308e5c3ec2548fe',
       '001cad12665b4483b54b314346a44c69',
       '003c5614416a4c81ab4ee74b72035842', ...,
       'ff8f13a5976147d9b9cf640daa36417a',
       'ffc5887cddb44824bb1e9cb2c59de4e0',
       'ffd1cfa7a0e64b848d439f9040b37f92'], dtype=object)

In [11]:
from ptls.preprocessing.pandas_preprocessor import PandasDataPreprocessor

preprocessor_trx = PandasDataPreprocessor(
    col_id='user_id',
    col_event_time='event_time',
    event_time_transformation='none',
    cols_category=["mcc_code", "currency_rk"],
    cols_identity=[],
)
preprocessor_click = PandasDataPreprocessor(
    col_id='user_id',
    col_event_time='event_time',
    event_time_transformation='none',
    cols_category=['cat_id', 'level_0', 'level_1', 'level_2'],
    cols_identity=['new_uid'],
)

In [12]:
import torch

def trx_to_torch(seq):
    for x in seq:
        yield x['user_id'], {
            'event_time': x['event_time'],
            'mcc_code': x['mcc_code'],
            'currency_rk': x['currency_rk'],
            'transaction_amt': x['transaction_amt'],
        }


def click_to_torch(seq):
    for x in seq:
        yield x['user_id'], {
            'event_time': x['event_time'],
            'cat_id': x['cat_id'],
            'level_0': x['level_0'],
            'level_1': x['level_1'],
            'level_2': x['level_2'],
            'new_uid': x['new_uid'],
        }


In [13]:
features_trx_train = dict(trx_to_torch(preprocessor_trx.fit_transform(df_trx_train)))
features_click_train = dict(click_to_torch(preprocessor_click.fit_transform(df_click_train)))

In [14]:
import gc

del df_trx_train
del df_click_train
gc.collect()

0

In [15]:
import warnings
from enum import Enum
import torch


from ptls.data_load.padded_batch import PaddedBatch
from ptls.nn.trx_encoder.batch_norm import RBatchNorm, RBatchNormWithLens
from ptls.nn.trx_encoder.noisy_embedding import NoisyEmbedding
from ptls.nn.trx_encoder.trx_encoder_base import TrxEncoderBase

class TorchDataTypeConstant(Enum):
    TORCH_FLOAT16 = torch.float16
    TORCH_BFLOAT16 = torch.bfloat16
    TORCH_FLOAT32 = torch.float32
    TORCH_FLOAT64 = torch.float64
    TORCH_INT8 = torch.int8
    TORCH_INT16 = torch.int16
    TORCH_INT32 = torch.int32
    TORCH_INT64 = torch.int64
    TORCH_BOOL = torch.bool
TORCH_FLOAT32 = TorchDataTypeConstant.TORCH_FLOAT32.value
TORCH_EMB_DTYPE = TORCH_FLOAT32

class TrxEncoder(TrxEncoderBase):
    """Network layer which makes representation for single transactions

     Input is `PaddedBatch` with ptls-format dictionary, with feature arrays of shape (B, T)
     Output is `PaddedBatch` with transaction embeddings of shape (B, T, H)
     where:
        B - batch size, sequence count in batch
        T - sequence length
        H - hidden size, representation dimension

    `ptls.nn.trx_encoder.noisy_embedding.NoisyEmbedding` implementation are used for categorical features.

    Parameters
        embeddings:
            dict with categorical feature names.
            Values must be like this `{'in': dictionary_size, 'out': embedding_size}`
            These features will be encoded with lookup embedding table of shape (dictionary_size, embedding_size)
            Values can be a `torch.nn.Embedding` implementation
        numeric_values:
            dict with numerical feature names.
            Values must be a string with scaler_name.
            Possible values are: 'identity', 'sigmoid', 'log', 'year'.
            These features will be scaled with selected scaler.
            Values can be `ptls.nn.trx_encoder.scalers.BaseScaler` implementatoin

            One field can have many scalers. In this case key become alias and col name should be in scaler.
            Check `TrxEncoderBase.numeric_values` for more details

        embeddings_noise (float):
            Noise level for embedding. `0` meens without noise
        emb_dropout (float):
            Probability of an element of embedding to be zeroed
        spatial_dropout (bool):
            Whether to dropout full dimension of embedding in the whole sequence

        use_batch_norm:
            True - All numerical values will be normalized after scaling
            False - No normalizing for numerical values
        use_batch_norm_with_lens:
            True - Respect seq_lens during batch_norm. Padding zeroes will be ignored
            False - Batch norm ower all time axis. Padding zeroes will included.

        orthogonal_init:
            if True then `torch.nn.init.orthogonal` applied
        linear_projection_size:
            Linear layer at the end will be added for non-zero value

        out_of_index:
            How to process a categorical indexes which are greater than dictionary size.
            'clip' - values will be collapsed to maximum index. This works well for frequency encoded categories.
                We join infrequent categories to one.
            'assert' - raise an error of invalid index appear.

        norm_embeddings: keep default value for this parameter
        clip_replace_value: Not useed. keep default value for this parameter
        positions: Not used. Keep default value for this parameter

    Examples:
        >>> B, T = 5, 20
        >>> trx_encoder = TrxEncoder(
        >>>     embeddings={'mcc_code': {'in': 100, 'out': 5}},
        >>>     numeric_values={'amount': 'log'},
        >>> )
        >>> x = PaddedBatch(
        >>>     payload={
        >>>         'mcc_code': torch.randint(0, 99, (B, T)),
        >>>         'amount': torch.randn(B, T),
        >>>     },
        >>>     length=torch.randint(10, 20, (B,)),
        >>> )
        >>> z = trx_encoder(x)
        >>> assert z.payload.shape == (5, 20, 6)  # B, T, H
    """
    def __init__(self,
                 embeddings=None,
                 numeric_values=None,
                 custom_embeddings=None,
                 embeddings_noise: float = 0,
                 norm_embeddings=None,
                 use_batch_norm=False,
                 use_batch_norm_with_lens=False,
                 clip_replace_value=None,
                 positions=None,
                 emb_dropout=0,
                 spatial_dropout=False,
                 orthogonal_init=False,
                 linear_projection_size=0,
                 out_of_index: str = 'clip',
                 ):
        if clip_replace_value is not None:
            warnings.warn('`clip_replace_value` attribute is deprecated. Always "clip to max" used. '
                          'Use `out_of_index="assert"` to avoid categorical values clip', DeprecationWarning)

        if positions is not None:
            warnings.warn('`positions` is deprecated. positions is not used', UserWarning)

        if embeddings is None:
            embeddings = {}
        if custom_embeddings is None:
            custom_embeddings = {}

        noisy_embeddings = {}
        for emb_name, emb_props in embeddings.items():
            if emb_props.get('disabled', False):
                continue
            if emb_props['in'] == 0 or emb_props['out'] == 0:
                continue
            noisy_embeddings[emb_name] = NoisyEmbedding(
                num_embeddings=emb_props['in'],
                embedding_dim=emb_props['out'],
                padding_idx=0,
                max_norm=1 if norm_embeddings else None,
                noise_scale=embeddings_noise,
                dropout=emb_dropout,
                spatial_dropout=spatial_dropout,
            )

        super().__init__(
            embeddings=noisy_embeddings,
            numeric_values=numeric_values,
            custom_embeddings=custom_embeddings,
            out_of_index=out_of_index,
        )

        custom_embedding_size = self.custom_embedding_size
        if use_batch_norm and custom_embedding_size > 0:
            # :TODO: Should we use Batch norm with not-numerical custom embeddings?
            if use_batch_norm_with_lens:
                self.custom_embedding_batch_norm = RBatchNormWithLens(custom_embedding_size)
            else:
                self.custom_embedding_batch_norm = RBatchNorm(custom_embedding_size)
        else:
            self.custom_embedding_batch_norm = None

        if linear_projection_size > 0:
            self.linear_projection_head = torch.nn.Linear(super().output_size, linear_projection_size)
        else:
            self.linear_projection_head = None

        if orthogonal_init:
            for n, p in self.named_parameters():
                if n.startswith('embeddings.') and n.endswith('.weight'):
                    torch.nn.init.orthogonal_(p.data[1:])
                if n == 'linear_projection_head.weight':
                    torch.nn.init.orthogonal_(p.data)

    def forward(self, x: PaddedBatch, names=None, seq_len=None):
        if isinstance(x, PaddedBatch) is False:
            pre_x = dict()
            for i, field_name in enumerate(names):
                pre_x[field_name] = x[i]
            x = PaddedBatch(pre_x, seq_len)

        processed_embeddings = [self.get_category_embeddings(x, field_name)
                                for field_name in self.embeddings.keys()]
        processed_custom_embeddings = [self.get_custom_embeddings(x, field_name)
                                       for field_name in self.custom_embeddings.keys()]
        if len(processed_custom_embeddings):
            processed_custom_embeddings = torch.cat(processed_custom_embeddings, dim=2)
            if self.custom_embedding_batch_norm is not None:
                processed_custom_embeddings = PaddedBatch(processed_custom_embeddings, x.seq_lens)
                processed_custom_embeddings = self.custom_embedding_batch_norm(processed_custom_embeddings)
                processed_custom_embeddings = processed_custom_embeddings.payload
            processed_embeddings.append(processed_custom_embeddings)

        out = torch.cat(processed_embeddings, dim=2)
        out = out.type(TORCH_EMB_DTYPE)
        out = self.linear_projection_head(out) if self.linear_projection_head is not None else out
        return PaddedBatch(out, x.seq_lens)

    @property
    def output_size(self):
        """Returns hidden size of output representation
        """
        if self.linear_projection_head is not None:
            return self.linear_projection_head.out_features
        return super().output_size

# Utilities

In [None]:
def add_nulls(df_dict):
    new_df_dict = {}
    for user_id, user_dict in df_dict.items():
        filled_event_time = []
        prev_time = -1
        max_date = 1720828799
        event_time_int = user_dict['event_time'].int()
        for time in event_time_int:
            if time / 86400 > prev_time / 86400 + 1:
                # Если между предыдущим временем и текущим есть пропуск, заполняем пропуски
                for t in range(prev_time + 1, time):
                    filled_event_time.append(t)
            filled_event_time.append(time)
            prev_time = time
        for time in range(event_time_int[-1], max_date):
            filled_event_time.append(time)
            
        # Преобразуем в тензор
        filled_event_time = torch.tensor(filled_event_time)

        # Обработка остальных признаков
        filled_features = {key: [] for key in user_dict.keys()}

        event_time_index = 0  # Индекс для event_time

        # Пробегаем по всем значениям в filled_event_time
        for i in range(len(filled_event_time)):
            # Если текущий момент времени совпадает с событием в исходном event_time
            if event_time_index < len(event_time_int) and filled_event_time[i] == event_time_int[event_time_index]:
                # Для каждого признака сохраняем значение, если оно есть в event_time
                for key in user_dict.keys():
                    filled_features[key].append(user_dict[key][event_time_index])
                event_time_index += 1
            else:
                # Если нет события, заполняем значением из dict_sizes для каждого признака
                for key in user_dict.keys():
                    filled_features[key].append(user_dict[key])

        # Преобразуем все заполненные признаки в тензоры
        filled_feature_tensors = {key: torch.tensor(filled_features[key]) for key in filled_features}
        filled_feature_tensors['event_time'] = filled_event_time
        new_df_dict[user_id] = filled_feature_tensors
    return new_df_dict
    
features_trx_train = add_nulls(features_trx_train)
features_click_train = add_nulls(features_click_train)

**При использовании add_nulls mrr~0.025 на одной модели. на ансамбле из 3-х моделей ~0.074. На 2-х фолдах**

# MLM TRX Module

In [16]:
import numpy as np
import pytorch_lightning as pl
import torch
from ptls.nn.seq_encoder.rnn_encoder import RnnEncoder
from ptls.nn.seq_encoder.utils import LastStepEncoder
# from ptls.nn.trx_encoder import TrxEncoder
from ptls.data_load.padded_batch import PaddedBatch
import torchmetrics
from torchmetrics import Metric

class MeanLoss(torchmetrics.Metric):
    def __init__(self, **params):
        super().__init__(**params)

        self.add_state('_sum', torch.tensor([0.0]))
        self.add_state('_cnt', torch.tensor([0]))

    def update(self, x):
        self._sum += x.sum()
        self._cnt += x.numel()

    def compute(self):
        return self._sum / self._cnt.float()

class MLMPretrainModule(pl.LightningModule):
    def __init__(self, data_type, params,
                 lr, weight_decay,
                 max_lr, pct_start, total_steps,
                 ):
        super().__init__()
        self.save_hyperparameters()

        common_trx_size = params.common_trx_size
        self.seq_encoder = None

        self.token_mask = torch.nn.Parameter(torch.randn(1, 1, common_trx_size), requires_grad=True)
        self.transf = torch.nn.TransformerEncoder(
            encoder_layer=torch.nn.TransformerEncoderLayer(
                d_model=common_trx_size,
                nhead=params.transf.nhead,
                dim_feedforward=params.transf.dim_feedforward,
                dropout=params.transf.dropout,
                batch_first=True,
            ),
            num_layers=params.transf.num_layers,
            norm=torch.nn.LayerNorm(common_trx_size) if params.transf.norm else None,
        )

        if params.transf.use_pe:
            self.pe = torch.nn.Parameter(self.get_pe(), requires_grad=False)
        else:
            self.pe = None
        self.padding_mask = torch.nn.Parameter(torch.tensor([True, False]).bool(), requires_grad=False)

        self.train_mlm_loss_all = MeanLoss()
        self.valid_mlm_loss_all = MeanLoss()
        self.train_mlm_loss_self = MeanLoss()
        self.valid_mlm_loss_self = MeanLoss()

    def get_pe(self):
        max_len = self.hparams.params.transf.max_len
        H = self.hparams.params.common_trx_size
        f = 2 * np.pi * torch.arange(max_len).view(1, -1, 1) / \
            torch.exp(torch.linspace(*np.log([4, max_len]), H // 2)).view(1, 1, -1)
        return torch.cat([torch.sin(f), torch.cos(f)], dim=2)

    def configure_optimizers(self):
        optim = torch.optim.Adam(self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer=optim,
            max_lr=self.hparams.max_lr,
            total_steps=self.hparams.total_steps,
            pct_start=self.hparams.pct_start,
            anneal_strategy='cos',
            cycle_momentum=False,
            div_factor=25.0,
            final_div_factor=10000.0,
            three_phase=True,
        )
        scheduler = {'scheduler': scheduler, 'interval': 'step'}
        return [optim], [scheduler]

    def get_mask(self, x: PaddedBatch):
        return torch.bernoulli(x.seq_len_mask.float() * self.hparams.params.mlm.replace_proba).bool()

    def mask_x(self, x: PaddedBatch, mask):
        return torch.where(mask.unsqueeze(2).expand_as(x.payload),
                           self.token_mask.expand_as(x.payload), x.payload)

    def get_neg_ix(self, mask, neg_type):
        """Sample from predicts, where `mask == True`, without self element.
        For `neg_type='all'` - sample from predicted tokens from batch
        For `neg_type='self'` - sample from predicted tokens from row
        """
        if neg_type == 'all':
            mn = mask.float().view(1, -1) - \
                 torch.eye(mask.numel(), device=mask.device)[mask.flatten()]
            neg_ix = torch.multinomial(mn, self.hparams.params.mlm.neg_count_all)
            b_ix = neg_ix.div(mask.size(1), rounding_mode='trunc')
            neg_ix = neg_ix % mask.size(1)
            return b_ix, neg_ix
        if neg_type == 'self':
            mask_ix = mask.nonzero(as_tuple=False)
            one_pos = torch.eye(mask.size(1), device=mask.device)[mask_ix[:, 1]]
            mn = mask[mask_ix[:, 0]].float() - one_pos
            mn = mn + 1e-9 * (1 - one_pos)
            neg_ix = torch.multinomial(mn, self.hparams.params.mlm.neg_count_self, replacement=True)
            b_ix = mask_ix[:, 0].view(-1, 1).expand_as(neg_ix)
            return b_ix, neg_ix
        raise AttributeError(f'Unknown neg_type: {neg_type}')

    def sentence_encoding(self, x: PaddedBatch):
        return None

    def mlm_loss(self, x: PaddedBatch, neg_type, x_orig: PaddedBatch):
        mask = self.get_mask(x)
        masked_x = self.mask_x(x, mask)
        B, T, H = masked_x.size()

        if self.pe is not None:
            if self.training:
                start_pos = np.random.randint(0, self.hparams.params.transf.max_len - T, 1)[0]
            else:
                start_pos = 0
            pe = self.pe[:, start_pos:start_pos + T]
            masked_x = masked_x + pe

        se = self.sentence_encoding(x_orig)
        if se is not None:
            masked_x = masked_x + se

        out = self.transf(masked_x, src_key_padding_mask=self.padding_mask[x.seq_len_mask])

        if self.pe is not None:
            out = out - pe
        if se is not None:
            out = out - se

        target = x.payload[mask].unsqueeze(1)  # N, 1, H
        predict = out[mask].unsqueeze(1)  # N, 1, H
        neg_ix = self.get_neg_ix(mask, neg_type)
        negative = out[neg_ix[0], neg_ix[1]]  # N, nneg, H
        out_samples = torch.cat([predict, negative], dim=1)
        probas = torch.softmax((target * out_samples).sum(dim=2), dim=1)
        loss = -torch.log(probas[:, 0])
        return loss

    def training_step(self, batch, batch_idx):
        (x_trx, _), = batch

        z_trx = self.seq_encoder(x_trx)  # PB: B, T, H

        loss_mlm = self.mlm_loss(z_trx, neg_type='all', x_orig=x_trx)
        self.train_mlm_loss_all(loss_mlm)
        loss_mlm_all = loss_mlm.mean()
        self.log(f'loss/mlm_{self.hparams.data_type}', loss_mlm_all)

        loss_mlm = self.mlm_loss(z_trx, neg_type='self', x_orig=x_trx)
        self.train_mlm_loss_self(loss_mlm)
        loss_mlm_self = loss_mlm.mean()
        self.log(f'loss/mlm_{self.hparams.data_type}_self', loss_mlm_self)

        return loss_mlm_all + loss_mlm_self

    def validation_step(self, batch, batch_idx):
        (x_trx, _), = batch
        z_trx = self.seq_encoder(x_trx)  # PB: B, T, H

        loss_mlm = self.mlm_loss(z_trx, neg_type='all', x_orig=x_trx)
        self.valid_mlm_loss_all(loss_mlm)

        loss_mlm = self.mlm_loss(z_trx, neg_type='self', x_orig=x_trx)
        self.valid_mlm_loss_self(loss_mlm)

    def on_training_epoch_end(self, _):
        self.log(f'metrics/train_{self.hparams.data_type}_mlm', self.train_mlm_loss_all, prog_bar=True)
        self.log(f'metrics/train_{self.hparams.data_type}_mlm_self', self.train_mlm_loss_self, prog_bar=True)

    def on_validation_epoch_end(self, _):
        self.log(f'metrics/valid_{self.hparams.data_type}_mlm', self.valid_mlm_loss_all, prog_bar=True)

class CustomTrxTransform(torch.nn.Module):
    def __init__(self, trx_amnt_quantiles):
        super().__init__()
        self.trx_amnt_quantiles = torch.nn.Parameter(trx_amnt_quantiles, requires_grad=False)

    def forward(self, x):
        x.payload['transaction_amt_q'] = torch.bucketize(x.payload['transaction_amt'], self.trx_amnt_quantiles) + 1
        return x

class DateFeaturesTransform(torch.nn.Module):
    def forward(self, x):
        et = x.payload['event_time'].int()
        et_day = et.div(24 * 60 * 60, rounding_mode='floor').int()
        x.payload['hour'] = et.div(60 * 60, rounding_mode='floor') % 24 + 1
        x.payload['weekday'] = et.div(60 * 60 * 24, rounding_mode='floor') % 7 + 1
        x.payload['day_diff'] = torch.clamp(torch.diff(et_day, prepend=et_day[:, :1], dim=1), 0, 14)
        return x

class PBLinear(torch.nn.Linear):
    def forward(self, x: PaddedBatch):
        return PaddedBatch(super().forward(x.payload), x.seq_lens)

class PBL2Norm(torch.nn.Module):
    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def forward(self, x):
        return PaddedBatch(self.beta * x.payload / (x.payload.pow(2).sum(dim=-1, keepdim=True) + 1e-9).pow(0.5),
                           x.seq_lens)

class TransactionEncoder(torch.nn.Module):
    def __init__(self, params, trx_amnt_quantiles):
        super().__init__()

        self.trx_amnt_quantiles = trx_amnt_quantiles
        print(params.trx_seq.trx_encoder)
        t = TrxEncoder(
            norm_embeddings=params.trx_seq.trx_encoder.norm_embeddings,
            embeddings_noise=params.trx_seq.trx_encoder.embeddings_noise,
            embeddings=params.trx_seq.trx_encoder.embeddings
        )
        self.seq_encoder = torch.nn.Sequential(
            CustomTrxTransform(trx_amnt_quantiles=trx_amnt_quantiles),
            DateFeaturesTransform(),
            t, PBLinear(t.output_size, params.common_trx_size),
            PBL2Norm(params.mlm.beta),
        )

    def forward(self, x):
        return self.seq_encoder(x)

class MLMPretrainModuleTrx(MLMPretrainModule):
    def __init__(self,
                 trx_amnt_quantiles,
                 params,
                 lr, weight_decay,
                 max_lr, pct_start, total_steps,
                 ):
        super().__init__(data_type='trx',
                         params=params,
                         lr=lr, weight_decay=weight_decay,
                         max_lr=max_lr, pct_start=pct_start, total_steps=total_steps,
                         )
        self.save_hyperparameters()
        self.seq_encoder = TransactionEncoder(params, trx_amnt_quantiles)

In [17]:
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
from ptls.data_load import augmentation_chain
from ptls.data_load.augmentations.random_slice import RandomSlice
from ptls.preprocessing.pandas_preprocessor import PandasDataPreprocessor
import random
from ptls.data_load.data_module.coles_data_module import coles_collate_fn


class PairedDataset(torch.utils.data.Dataset):
    def __init__(self, pairs, data, augmentations, n_sample):
        super().__init__()

        self.pairs = pairs
        self.data = data
        self.augmentations = augmentations
        self.n_sample = n_sample

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, item):
        ids = self.pairs[item]
        return tuple([[a(d[i]) for _ in range(self.n_sample)]
                      for i, d, a in zip(ids, self.data, self.augmentations)])

    @staticmethod
    def collate_fn(batch):
        return [coles_collate_fn(c) for c in zip(*batch)]

class DropDuplicate:
    def __init__(self, col_check, col_new_cnt=None, keep='first'):
        super().__init__()

        self.col_check = col_check
        self.col_new_cnt = col_new_cnt
        if keep != 'first':
            raise NotImplementedError()

    def __call__(self, x):
        idx, new_cnt = self.get_idx(x[self.col_check])
        new_x = {k: v[idx] for k, v in x.items()}
        if self.col_new_cnt is not None:
            new_x[self.col_new_cnt] = torch.from_numpy(new_cnt)
        return new_x

    def get_idx(self, x):
        diff = np.diff(x, prepend=x[0] - 1)
        new_ix = np.where(diff != 0)[0]
        new_cnt = np.diff(new_ix, append=len(x))
        return new_ix, new_cnt

def pretrain_mlm_trx(features_trx_train, cfg):
    train_dl_mlm_trx = torch.utils.data.DataLoader(
        PairedDataset(
            np.sort(np.array(list(features_trx_train.keys()))).reshape(-1, 1),
            data=[features_trx_train],
            augmentations=[augmentation_chain(
                DropDuplicate('mcc_code', col_new_cnt='c_cnt'),
                RandomSlice(32, 128)
            )],
            n_sample=1,
        ),
        collate_fn=PairedDataset.collate_fn,
        shuffle=True,
        num_workers=12,
        batch_size=128,
        persistent_workers=True,
    )

    # calculate trx_amnt_quantiles
    v = []
    for batch in train_dl_mlm_trx:
        v.append(batch[0][0].payload['transaction_amt'][batch[0][0].seq_len_mask.bool()])
    v = torch.cat(v)
    trx_amnt_quantiles = torch.quantile(torch.unique(v), torch.linspace(0, 1, 100, dtype=v.dtype))

    mlm_model_trx = MLMPretrainModuleTrx(
        params=cfg.model_config,
        lr=0.001, weight_decay=0,
        max_lr=0.001, pct_start=9000 / 2 / 10000, total_steps=10000,
        trx_amnt_quantiles=trx_amnt_quantiles,
    )

    trainer = pl.Trainer(
        accelerator="cuda" if torch.cuda.is_available() else "cpu",
        max_steps=4000,
        enable_progress_bar=True,
        callbacks=[
            pl.callbacks.LearningRateMonitor(),
            pl.callbacks.ModelCheckpoint(
                every_n_train_steps=50, save_top_k=-1,
            ),
        ]
    )
    model_version_trx = trainer.logger.version
    print('Trx pretrain start')
    print('baseline loss all + self:  {:.3f} + {:.3f}'.format(
        np.log(mlm_model_trx.hparams.params.mlm.neg_count_all + 1),
        np.log(mlm_model_trx.hparams.params.mlm.neg_count_self + 1)
    ))
    print(f'version = {model_version_trx}')
    trainer.fit(mlm_model_trx, train_dl_mlm_trx)
    trainer.save_checkpoint(f'{cfg.objects_path}/pretrain_trx.cpt', weights_only=True)
    print('Trx pretrain done')


In [18]:
!mkdir conf
!mkdir objects

In [19]:
# %load /kaggle/working/MBD/scenario_mbd/conf/trx_coles.yaml
new_config = """defaults:
  - override hydra/job_logging: disabled

valid_fold_id: 4

ensemble_size: 5
data_path: ./data
objects_path: "./objects_${valid_fold_id}"

model_config:
    common_trx_size: 256
    transf:
        nhead: 4
        dim_feedforward: 1024
        dropout: 0.1
        num_layers: 3
        norm: false
        max_len: 6000
        use_pe: true
    mlm:
        replace_proba: 0.1
        neg_count_all: 64
        neg_count_self: 8
        beta: 10

    trx_seq:
        trx_encoder:
          norm_embeddings: false
          embeddings_noise: 0.003
          embeddings:
            mcc_code:
              in: 350
              out: 64
            currency_rk:
              in: 10
              out: 4
            transaction_amt_q:
              in: 110
              out: 8
            hour:
              in: 30
              out: 16
            weekday:
              in: 10
              out: 4
            day_diff:
              in: 15
              out: 8
          numeric_values:
            transaction_amt: identity
            c_cnt: log

    click_seq:
        trx_encoder:
          use_batch_norm_with_lens: false
          norm_embeddings: false
          embeddings_noise: 0.003
          embeddings:
            cat_id:
              in: 400
              out: 64
            level_0:
              in: 400
              out: 16
            level_1:
              in: 400
              out: 8
            level_2:
              in: 400
              out: 4
            hour:
              in: 30
              out: 16
            weekday:
              in: 10
              out: 4
            day_diff:
              in: 15
              out: 8
          numeric_values:
            c_cnt: log

    rnn:
      type: gru
      hidden_size: 256
      bidir: false
      trainable_starter: static

"""

with open('/kaggle/working/conf/config.yaml', mode='w') as file:
    file.write(new_config)

In [20]:
from hydra import initialize, compose

initialize(version_base=None, config_path="conf")
cfg = compose(config_name="config.yaml", return_hydra_config=True)

In [21]:
# pretrain_mlm_trx(features_trx_train, cfg)

# Click encoder

In [22]:
class CustomClickTransform(torch.nn.Module):
    def forward(self, x):
        #         x.payload['cat_id'] = torch.clamp(x.payload['cat_id'], 0, 300)
        #         x.payload['level_0'] = torch.clamp(x.payload['level_0'], 0, 200)
        #         x.payload['level_1'] = torch.clamp(x.payload['level_1'], 0, 200)
        #         x.payload['level_2'] = torch.clamp(x.payload['level_2'], 0, 200)
        #         x.payload['c_cnt_clamp'] = torch.clamp(x.payload['c_cnt'], 0, 20).int()
        return x

class ClickEncoder(torch.nn.Module):
    def __init__(self, params):
        super().__init__()

        t = TrxEncoder(
            use_batch_norm_with_lens=params.click_seq.trx_encoder.use_batch_norm_with_lens,
            norm_embeddings=params.click_seq.trx_encoder.norm_embeddings,
            embeddings_noise=params.click_seq.trx_encoder.embeddings_noise,
            embeddings=params.click_seq.trx_encoder.embeddings
        )
        self.seq_encoder = torch.nn.Sequential(
            CustomClickTransform(),
            DateFeaturesTransform(),
            t, PBLinear(t.output_size, params.common_trx_size),
            PBL2Norm(params.mlm.beta),
        )

    def forward(self, x):
        return self.seq_encoder(x)

class MLMPretrainModuleClick(MLMPretrainModule):
    def __init__(self, params,
                 lr, weight_decay,
                 max_lr, pct_start, total_steps,
                 ):
        super().__init__(data_type='click',
                         params=params,
                         lr=lr, weight_decay=weight_decay,
                         max_lr=max_lr, pct_start=pct_start, total_steps=total_steps,
                         )
        self.save_hyperparameters()
        self.seq_encoder = ClickEncoder(params)

def pretrain_mlm_click(features_click_train, cfg):
    train_dl_mlm_click = torch.utils.data.DataLoader(
        PairedDataset(
            np.sort(np.array(list(features_click_train.keys()))).reshape(-1, 1),
            data=[features_click_train],
            augmentations=[augmentation_chain(
                DropDuplicate('cat_id', col_new_cnt='c_cnt'),
                RandomSlice(32, 128)
            )],
            n_sample=1,
        ),
        collate_fn=PairedDataset.collate_fn,
        shuffle=True,
        num_workers=12,
        batch_size=64,
        persistent_workers=True,
    )

    mlm_model_click = MLMPretrainModuleClick(
        params=cfg.model_config,
        lr=0.001, weight_decay=0,
        max_lr=0.001, pct_start=9000 / 2 / 10000, total_steps=10000,
    )

    trainer = pl.Trainer(
        accelerator="cuda" if torch.cuda.is_available() else "cpu",
        max_steps=3000,
        enable_progress_bar=True,
        callbacks=[
            pl.callbacks.LearningRateMonitor(),
            pl.callbacks.ModelCheckpoint(
                every_n_train_steps=2000, save_top_k=-1,
            ),
        ]
    )
    model_version_click = trainer.logger.version
    print('Click pretrain start')
    print('baseline loss all + self:  {:.3f} + {:.3f}'.format(
        np.log(mlm_model_click.hparams.params.mlm.neg_count_all + 1),
        np.log(mlm_model_click.hparams.params.mlm.neg_count_self + 1)
    ))
    trainer.fit(mlm_model_click, train_dl_mlm_click)
    trainer.save_checkpoint(f'{cfg.objects_path}/pretrain_click.cpt', weights_only=True)
    print('Click pretrain done')

In [23]:
# pretrain_mlm_click(features_click_train, cfg)

In [24]:
class PairedFullDataset(torch.utils.data.Dataset):
    def __init__(self, pairs, data, augmentations, n_sample):
        super().__init__()

        self.pairs = pairs
        self.data = data
        self.augmentations = augmentations
        self.n_sample = n_sample
        self.full_pairs = self.get_full_pairs()

    @staticmethod
    def collate_fn(batch):
        """
        In:
        [1, 2 ,3], [4, 5]
        [6, 7], [None, None]
        [None, None], [8, 9]

        Out:
        x, labels, out_of_match

        PaddedBatch([1, 2, 3, 6, 7]), [0, 0, 0, 1, 1], [0, 0, 0, 1, 1]
        PaddedBatch([4, 5, 8, 9]), [0, 0, 2, 2], [0, 0, 1, 1]
        """
        data_1 = [t for p, _ in batch for t in p if t is not None]
        data_2 = [t for _, p in batch for t in p if t is not None]

        labels = torch.arange(len(batch), dtype=torch.int32)
        labels_1 = torch.repeat_interleave(labels,
                                           torch.tensor([len([t for t in p if t is not None]) for p, _ in batch]))
        labels_2 = torch.repeat_interleave(labels,
                                           torch.tensor([len([t for t in p if t is not None]) for _, p in batch]))
        out_of_match_1 = torch.tensor([1 if p2[0] is None else 0 for p1, p2 in batch for t in p1 if t is not None])
        out_of_match_2 = torch.tensor([1 if p1[0] is None else 0 for p1, p2 in batch for t in p2 if t is not None])
        return (
            padded_collate_wo_target(data_1), labels_1, out_of_match_1.int(),
            padded_collate_wo_target(data_2), labels_2, out_of_match_2.int(),
        )

    @staticmethod
    def not_in(v, items_to_exclude):
        a = np.sort(items_to_exclude)
        return v[np.pad(a, pad_width=(0, 1), constant_values='')[np.searchsorted(a, v)] != v]

    def get_full_pairs(self):
        free_trx = self.not_in(np.array(list(self.data[0].keys())), self.pairs[:, 0]).reshape(-1, 1)
        free_clicks = self.not_in(np.array(list(self.data[1].keys())), self.pairs[:, 1]).reshape(-1, 1)

        return np.concatenate([
            self.pairs,
            np.concatenate([free_trx, np.full((len(free_trx), 1), '0')], axis=1),
            np.concatenate([np.full((len(free_clicks), 1), '0'), free_clicks], axis=1),
        ], axis=0)

    def __len__(self):
        return len(self.full_pairs)

    def __getitem__(self, item):
        ids = self.full_pairs[item]
        return tuple([[a(d[i]) if i != '0' else None for _ in range(self.n_sample)]
                      for i, d, a in zip(ids, self.data, self.augmentations)])


In [25]:
class PBLayerNorm(torch.nn.LayerNorm):
    def forward(self, x: PaddedBatch):
        return PaddedBatch(super().forward(x.payload), x.seq_lens)

class L2Scorer(torch.nn.Module):
    def forward(self, x):
        B, H = x.size()
        a, b =x[:, :H // 2], x[:, H // 2:]
        return -(a - b).pow(2).sum(dim=1)

class PairedModule(pl.LightningModule):
    def __init__(self, params, trx_amnt_quantiles, k,
                 lr, weight_decay,
                 max_lr, pct_start, total_steps,
                 beta, neg_count,
                 ):
        super().__init__()
        self.save_hyperparameters(ignore=['mlm_model_trx', 'mlm_model_click'])

        common_trx_size = params.common_trx_size
        self.rnn_enc = torch.nn.Sequential(
            RnnEncoder(
                common_trx_size,
                type=params.rnn.type,
                hidden_size=params.rnn.hidden_size,
                bidir=params.rnn.bidir,
                trainable_starter=params.rnn.trainable_starter
            ),
            LastStepEncoder(),
            #             NormEncoder(),
        )
        self._seq_encoder_trx = torch.nn.Sequential(
            TransactionEncoder(params, trx_amnt_quantiles),
            PBLayerNorm(common_trx_size),
        )
        self._seq_encoder_click = torch.nn.Sequential(
            ClickEncoder(params),
            PBLayerNorm(common_trx_size),
        )

        self.cls = torch.nn.Sequential(
            L2Scorer(),
        )

        self.train_precision = PrecisionK(k=k)
        self.train_mrr = MeanReciprocalRankK(k=k)
        self.valid_precision = PrecisionK(k=k)
        self.valid_mrr = MeanReciprocalRankK(k=k)

    def load_pretrained(self, trx, click):
        self._seq_encoder_trx[0].load_state_dict(trx.state_dict())
        self._seq_encoder_click[0].load_state_dict(click.state_dict())

    def seq_encoder_trx(self, x):
        x = self._seq_encoder_trx(x)
        return self.rnn_enc(x)

    def seq_encoder_click(self, x_orig):
        x = self._seq_encoder_click(x_orig)
        #         x = PaddedBatch(
        #             x.payload + self.mlm_model_click.sentence_encoding(x_orig),
        #             x.seq_lens,
        #         )
        return self.rnn_enc(x)

    def configure_optimizers(self):
        optim = torch.optim.Adam(self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer=optim,
            max_lr=self.hparams.max_lr,
            total_steps=self.hparams.total_steps,
            pct_start=self.hparams.pct_start,
            anneal_strategy='cos',
            cycle_momentum=False,
            div_factor=25.0,
            final_div_factor=10000.0,
            three_phase=True,
        )
        scheduler = {'scheduler': scheduler, 'interval': 'step'}
        return [optim], [scheduler]

    def loss_fn_p(self, embeddings, labels, ref_emb, ref_labels):
        beta = self.hparams.beta
        neg_count = self.hparams.neg_count

        pos_ix = (labels.view(-1, 1) == ref_labels.view(1, -1)).nonzero(as_tuple=False)
        pos_labels = labels[pos_ix[:, 0]]
        neg_w = ((pos_labels.view(-1, 1) != ref_labels.view(1, -1))).float()
        neg_ix = torch.multinomial(neg_w, neg_count - 1)
        all_ix = torch.cat([pos_ix[:, [1]], neg_ix], dim=1)
        logits = -(embeddings[pos_ix[:, [0]]] - ref_emb[all_ix]).pow(2).sum(dim=2)
        logits = logits * beta
        logs = -torch.log(torch.softmax(logits, dim=1))[:, 0]
        #         logs = torch.relu(logs + np.log(0.1))
        return logs.mean()

    def training_step(self, batch, batch_idx):
        # pairs
        x_trx, l_trx, m_trx, x_click, l_click, m_click = batch
        z_trx = self.seq_encoder_trx(x_trx)  # B, H
        z_click = self.seq_encoder_click(x_click)  # B, H
        loss_pt = self.loss_fn_p(embeddings=z_trx, labels=l_trx, ref_emb=z_click, ref_labels=l_click)
        self.log('loss/loss_pt', loss_pt)

        loss_pc = self.loss_fn_p(embeddings=z_click, labels=l_click, ref_emb=z_trx, ref_labels=l_trx)
        self.log('loss/loss_pc', loss_pc)

        with torch.no_grad():
            out = -(z_trx.unsqueeze(1) - z_click.unsqueeze(0)).pow(2).sum(dim=2)
            out = out[m_trx == 0][:, m_click == 0]
            T, C = out.size()
            assert T == C
            n_samples = z_trx.size(0) // (l_trx.max().item() + 1)
            for i in range(n_samples):
                l2 = out[i::n_samples, i::n_samples]
                self.train_precision(l2)
                self.train_mrr(l2)

        return loss_pt + 0.1 * loss_pc  # loss_pc

    def on_train_epoch_end(self):
        self.log('train_metrics/precision', self.train_precision, prog_bar=True)
        self.log('train_metrics/mrr', self.train_mrr, prog_bar=True)

# Metrics

In [26]:
import numpy as np
import torch
import torchmetrics
from torchmetrics import Metric
import pytorch_lightning as pl


class PrecisionK(Metric):
    def __init__(self, k, **params):
        super().__init__(**params)

        self.add_state('_sum', torch.tensor(0))
        self.add_state('_cnt', torch.tensor(0))
        self.k = k

    def update(self, preds, target=None):
        B, _ = preds.size()
        ix_sort = torch.argsort(preds, dim=1, descending=True)
        ix_sort = ix_sort == torch.arange(B, device=preds.device, dtype=torch.long).view(-1, 1)
        k = min(self.k, B)
        ix_sort = (ix_sort[:, :k].int().sum(dim=1) > 0).int().sum()
        self._sum = self._sum + ix_sort
        self._cnt = self._cnt + B

    def compute(self):
        return self._sum.float() / self._cnt.float()


class MeanReciprocalRankK(Metric):
    def __init__(self, k, max_k=100, **params):
        super().__init__(**params)

        self.add_state('_sum', torch.tensor(0))
        self.add_state('_cnt', torch.tensor(0))
        self.k = k
        self.max_k = max_k

    def update(self, preds, target=None):
        B, _ = preds.size()
        ix_sort = torch.argsort(preds, dim=1, descending=True)
        ix_sort = ix_sort == torch.arange(B, device=preds.device, dtype=torch.long).view(-1, 1)
        k = min(self.k, B)
        ix_sort = ix_sort[:, :k]
        ranks = self.k / self.max_k / (1 + torch.arange(k, device=preds.device).view(1, -1).expand(B, k))
        ranks = ranks[ix_sort]

        self._sum = self._sum + ranks.sum()
        self._cnt = self._cnt + B

    def compute(self):
        return self._sum.float() / self._cnt.float()


class ValidationCallback(pl.Callback):
    def __init__(self, v_trx, v_click, target, device, device_main, k=100, batch_size=1024):
        self.v_trx = v_trx
        self.v_click = v_click
        self.target = target
        self.device = device
        self.device_main = device_main
        self.k = k
        self.batch_size = batch_size

    def on_train_epoch_end(self, trainer, pl_module):
        was_traning = False
        if pl_module.training:
            pl_module.eval()
            was_traning = True

        pl_module.to(self.device)
        with torch.no_grad():
            z_trx = []
            for ((x_trx, _),) in self.v_trx:
                z_trx.append(pl_module.seq_encoder_trx(x_trx.to(self.device)))
            z_trx = torch.cat(z_trx, dim=0)
            z_click = []
            for ((x_click, _),) in self.v_click:
                z_click.append(pl_module.seq_encoder_click(x_click.to(self.device)))
            z_click = torch.cat(z_click, dim=0)

            T = z_trx.size(0)
            C = z_click.size(0)
            device = z_trx.device
            ix_t = torch.arange(T, device=device).view(-1, 1).expand(T, C).flatten()
            ix_c = torch.arange(C, device=device).view(1, -1).expand(T, C).flatten()

            z_out = []
            for i in range(0, len(ix_t), self.batch_size):
                z_pairs = torch.cat([
                    z_trx[ix_t[i:i + self.batch_size]],
                    z_click[ix_c[i:i + self.batch_size]],
                ], dim=1)
                z_out.append(pl_module.cls(z_pairs).unsqueeze(1))
            z_out = torch.cat(z_out, dim=0).view(T, C)

            precision, mrr, r1 = self.logits_to_metrics(z_out)

            pl_module.log('valid_full_metrics/precision', precision, prog_bar=True)
            pl_module.log('valid_full_metrics/mrr', mrr, prog_bar=False)
            pl_module.log('valid_full_metrics/r1', r1, prog_bar=False)

        pl_module.to(self.device_main)
        if was_traning:
            pl_module.train()

    def logits_to_metrics(self, z_out):
        T, C = z_out.size()
        z_ranks = torch.zeros_like(z_out)
        z_ranks[
            torch.arange(T, device=self.device).view(-1, 1).expand(T, C),
            torch.argsort(z_out, dim=1, descending=True),
        ] = torch.arange(C, device=self.device).float().view(1, -1).expand(T, C) + 1
        true_ranks = z_ranks[
            np.arange(T),
            np.searchsorted(self.v_click.dataset.pairs[:, 0],
                            self.target.set_index('bank')['rtk'].loc[self.v_trx.dataset.pairs[:, 0]].values)
        ]
        precision = torch.where(true_ranks <= self.k,
                                torch.ones(1, device=self.device), torch.zeros(1, device=self.device)).mean()
        mrr = torch.where(true_ranks <= self.k, 1 / true_ranks, torch.zeros(1, device=self.device)).mean()
        r1 = 2 * mrr * precision / (mrr + precision)
        return precision, mrr, r1


class ValidationSplittingCallback(pl.Callback):
    def __init__(self, v_trx, v_click, target, device, device_main, agg_method, k=100, batch_size=1024):
        self.v_trx = v_trx
        self.v_click = v_click
        self.target = target
        self.device = device
        self.device_main = device_main
        self.agg_method = agg_method
        self.k = k
        self.batch_size = batch_size

    def on_train_epoch_end(self, trainer, pl_module):
        was_traning = False
        if pl_module.training:
            pl_module.eval()
            was_traning = True

        pl_module.to(self.device)
        with torch.no_grad():
            z_trx = []
            for ((x_trx, _),) in self.v_trx:
                z_trx.append(pl_module.seq_encoder_trx(x_trx.to(self.device)))
            z_trx = torch.cat(z_trx, dim=0)
            z_click = []
            for ((x_click, _),) in self.v_click:
                z_click.append(pl_module.seq_encoder_click(x_click.to(self.device)))
            z_click = torch.cat(z_click, dim=0)

            T = z_trx.size(0)
            C = z_click.size(0)
            device = z_trx.device
            ix_t = torch.arange(T, device=device).view(-1, 1).expand(T, C).flatten()
            ix_c = torch.arange(C, device=device).view(1, -1).expand(T, C).flatten()

            z_out = []
            for i in range(0, len(ix_t), self.batch_size):
                z_pairs = torch.cat([
                    z_trx[ix_t[i:i + self.batch_size]],
                    z_click[ix_c[i:i + self.batch_size]],
                ], dim=1)
                z_out.append(pl_module.cls(z_pairs).unsqueeze(1))
            z_out = torch.cat(z_out, dim=0).view(T, C)

            T, Nt = len(self.v_trx.dataset), z_out.size(0) // len(self.v_trx.dataset)
            C, Nc = len(self.v_click.dataset), z_out.size(1) // len(self.v_click.dataset)

            if self.agg_method == 'max':
                z_out = z_out.view(T, Nt, C, Nc).max(dim=3).values.max(dim=1).values
            elif self.agg_method == 'mean':
                z_out = z_out.view(T, Nt, C, Nc).mean(dim=[1, 3])
            else:
                raise AttributeError(f'agg_method: {self.agg_method}')

            precision, mrr, r1 = self.logits_to_metrics(z_out)

            pl_module.log('valid_full_metrics/precision', precision, prog_bar=True)
            pl_module.log('valid_full_metrics/mrr', mrr, prog_bar=False)
            pl_module.log('valid_full_metrics/r1', r1, prog_bar=False)

        pl_module.to(self.device_main)
        if was_traning:
            pl_module.train()

    def logits_to_metrics(self, z_out):
        T, C = z_out.size()
        z_ranks = torch.zeros_like(z_out)
        z_ranks[
            torch.arange(T, device=self.device).view(-1, 1).expand(T, C),
            torch.argsort(z_out, dim=1, descending=True),
        ] = torch.arange(C, device=self.device).float().view(1, -1).expand(T, C) + 1
        true_ranks = z_ranks[
            np.arange(T),
            np.searchsorted(self.v_click.dataset.ids,
                            self.target.set_index('bank')['rtk'].loc[self.v_trx.dataset.ids].values)
        ]
        precision = torch.where(true_ranks <= self.k,
                                torch.ones(1, device=self.device), torch.zeros(1, device=self.device)).mean()
        mrr = torch.where(true_ranks <= self.k, 1 / true_ranks, torch.zeros(1, device=self.device)).mean()
        r1 = 2 * mrr * precision / (mrr + precision)
        return precision, mrr, r1


class MeanLoss(torchmetrics.Metric):
    def __init__(self, **params):
        super().__init__(**params)

        self.add_state('_sum', torch.tensor([0.0]))
        self.add_state('_cnt', torch.tensor([0]))

    def update(self, x):
        self._sum += x.sum()
        self._cnt += x.numel()

    def compute(self):
        return self._sum / self._cnt.float()

In [27]:
from ptls.data_load import padded_collate_wo_target

def train_qsm(df_matching_train, features_trx_train, features_click_train, model_n, cfg):
    batch_size = 128
    train_dl = torch.utils.data.DataLoader(
        PairedFullDataset(
            df_matching_train[lambda x: x['rtk'].ne('0')].values,
            data=[
                features_trx_train,
                features_click_train,
            ],
            augmentations=[
                augmentation_chain(DropDuplicate('mcc_code', col_new_cnt='c_cnt'), RandomSlice(32, 1024)),  # 1024
                augmentation_chain(DropDuplicate('cat_id', col_new_cnt='c_cnt'), RandomSlice(64, 2048)),  # 2048
            ],
            n_sample=2,
        ),
        collate_fn=PairedFullDataset.collate_fn,
        drop_last=True,
        shuffle=True,
        num_workers=12,
        batch_size=batch_size,
        persistent_workers=True,
    )

    mlm_model_trx = MLMPretrainModuleTrx.load_from_checkpoint(f'{cfg.objects_path}/pretrain_trx.cpt')
    mlm_model_click = MLMPretrainModuleClick.load_from_checkpoint(f'{cfg.objects_path}/pretrain_click.cpt')
    pl.seed_everything(random.randint(1, 2**16 - 1))
    sup_model = PairedModule(
        cfg.model_config, trx_amnt_quantiles=mlm_model_trx.seq_encoder.trx_amnt_quantiles,
        k=100 * batch_size // 3000,
        lr=0.0022, weight_decay=0,
        max_lr=0.0018, pct_start=1100 / 6000, total_steps=6000,
        beta=0.2 / 1.4, neg_count=120,
    )
    sup_model.load_pretrained(mlm_model_trx.seq_encoder, mlm_model_click.seq_encoder)

    trainer = pl.Trainer(
        accelerator="cuda" if torch.cuda.is_available() else "cpu",
        max_steps=3300,
        callbacks=[
            pl.callbacks.LearningRateMonitor(),
            pl.callbacks.ModelCheckpoint(
                every_n_train_steps=1000, save_top_k=-1,
            ),
        ]
    )
    print('Train qsm start')
    trainer.fit(sup_model, train_dl)
    trainer.save_checkpoint(f'{cfg.objects_path}/nn_distance_coles_model_{model_n}.cpt', weights_only=True)
    print(f'Train qsm [{model_n}] done')

**C предобучением энкодеров**

In [28]:
# ensemble_size = 1

# for i in range(ensemble_size):
#     train_qsm(df_matching_train, features_trx_train, features_click_train, i, cfg)

**mrr~0.029 на одной модели. на ансамбле из 3-х моделей ~0.08. На 2-х фолдах**

**В оригинальном репозитории, ансамбль из 5 моделей, mrr=0.2**

**Без предобучения енкодеров**

In [34]:
from ptls.data_load import padded_collate_wo_target

def train_qsm(df_matching_train, features_trx_train, features_click_train, model_n, cfg):
    batch_size = 128
    train_dl = torch.utils.data.DataLoader(
        PairedFullDataset(
            df_matching_train[lambda x: x['rtk'].ne('0')].values,
            data=[
                features_trx_train,
                features_click_train,
            ],
            augmentations=[
                augmentation_chain(DropDuplicate('mcc_code', col_new_cnt='c_cnt'), RandomSlice(32, 1024)),  # 1024
                augmentation_chain(DropDuplicate('cat_id', col_new_cnt='c_cnt'), RandomSlice(64, 2048)),  # 2048
            ],
            n_sample=2,
        ),
        collate_fn=PairedFullDataset.collate_fn,
        drop_last=True,
        shuffle=True,
        num_workers=12,
        batch_size=batch_size,
        persistent_workers=True,
    )

    train_dl_mlm_trx = torch.utils.data.DataLoader(
        PairedDataset(
            np.sort(np.array(list(features_trx_train.keys()))).reshape(-1, 1),
            data=[features_trx_train],
            augmentations=[augmentation_chain(
                DropDuplicate('mcc_code', col_new_cnt='c_cnt'),
                RandomSlice(32, 128)
            )],
            n_sample=1,
        ),
        collate_fn=PairedDataset.collate_fn,
        shuffle=True,
        num_workers=12,
        batch_size=128,
        persistent_workers=True,
    )
    
    v = []
    for batch in train_dl_mlm_trx:
        v.append(batch[0][0].payload['transaction_amt'][batch[0][0].seq_len_mask.bool()])
    v = torch.cat(v)
    trx_amnt_quantiles = torch.quantile(torch.unique(v), torch.linspace(0, 1, 100, dtype=v.dtype))

    mlm_model_trx = MLMPretrainModuleTrx(
        params=cfg.model_config,
        lr=0.001, weight_decay=0,
        max_lr=0.001, pct_start=9000 / 2 / 10000, total_steps=10000,
        trx_amnt_quantiles=trx_amnt_quantiles,
    )
    mlm_model_click = MLMPretrainModuleClick(
        params=cfg.model_config,
        lr=0.001, weight_decay=0,
        max_lr=0.001, pct_start=9000 / 2 / 10000, total_steps=10000,
    )
    pl.seed_everything(random.randint(1, 2**16 - 1))
    sup_model = PairedModule(
        cfg.model_config, trx_amnt_quantiles=mlm_model_trx.seq_encoder.trx_amnt_quantiles,
        k=100 * batch_size // 3000,
        lr=0.0022, weight_decay=0,
        max_lr=0.0018, pct_start=1100 / 6000, total_steps=6000,
        beta=0.2 / 1.4, neg_count=120,
    )
    sup_model.load_pretrained(mlm_model_trx.seq_encoder, mlm_model_click.seq_encoder)

    trainer = pl.Trainer(
        accelerator="cuda" if torch.cuda.is_available() else "cpu",
        max_steps=3300,
        callbacks=[
            pl.callbacks.LearningRateMonitor(),
            pl.callbacks.ModelCheckpoint(
                every_n_train_steps=1000, save_top_k=-1,
            ),
        ]
    )
    print('Train qsm start')
    trainer.fit(sup_model, train_dl)
    trainer.save_checkpoint(f'{cfg.objects_path}/nn_distance_coles_model_{model_n}.cpt', weights_only=True)
    print(f'Train qsm [{model_n}] done')

In [None]:
ensemble_size = 1

for i in range(ensemble_size):
    train_qsm(df_matching_train, features_trx_train, features_click_train, i, cfg)

{'norm_embeddings': False, 'embeddings_noise': 0.003, 'embeddings': {'mcc_code': {'in': 350, 'out': 64}, 'currency_rk': {'in': 10, 'out': 4}, 'transaction_amt_q': {'in': 110, 'out': 8}, 'hour': {'in': 30, 'out': 16}, 'weekday': {'in': 10, 'out': 4}, 'day_diff': {'in': 15, 'out': 8}}, 'numeric_values': {'transaction_amt': 'identity', 'c_cnt': 'log'}}
{'norm_embeddings': False, 'embeddings_noise': 0.003, 'embeddings': {'mcc_code': {'in': 350, 'out': 64}, 'currency_rk': {'in': 10, 'out': 4}, 'transaction_amt_q': {'in': 110, 'out': 8}, 'hour': {'in': 30, 'out': 16}, 'weekday': {'in': 10, 'out': 4}, 'day_diff': {'in': 15, 'out': 8}}, 'numeric_values': {'transaction_amt': 'identity', 'c_cnt': 'log'}}
Train qsm start


/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/fit_loop.py:310: The number of training batches (22) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |          | 0/? [00:00<?, ?it/s]

Энкодеры учатся и внутри, но крайне медленно

При обучении энкодеров внутри, они обучаются сильно медленнее, чем отдельно

Ожидалось, что обучение будет медленнее, но сравнимо с оригинальным подходом

# Module with cross-attention

**cross-attention without mean**

In [46]:
from torch import nn

class PBLayerNorm(torch.nn.LayerNorm):
    def forward(self, x: PaddedBatch):
        return PaddedBatch(super().forward(x.payload), x.seq_lens)

class L2Scorer(torch.nn.Module):
    def forward(self, x):
        B, H = x.size()
        a, b =x[:, :H // 2], x[:, H // 2:]
        return -(a - b).pow(2).sum(dim=1)

class PairedModule(pl.LightningModule):
    def __init__(self, params, trx_amnt_quantiles, k,
                 lr, weight_decay,
                 max_lr, pct_start, total_steps,
                 beta, neg_count,
                 ):
        super().__init__()
        self.save_hyperparameters(ignore=['mlm_model_trx', 'mlm_model_click'])

        common_trx_size = params.common_trx_size
        self.rnn_enc = torch.nn.Sequential(
            RnnEncoder(
                common_trx_size,
                type=params.rnn.type,
                hidden_size=params.rnn.hidden_size,
                bidir=params.rnn.bidir,
                trainable_starter=params.rnn.trainable_starter
            ),
            LastStepEncoder(),
            #             NormEncoder(),
        )
        self.mha_trx_click = nn.MultiheadAttention(
            embed_dim=256,##?,
            num_heads=4,
            dropout=0.3,
            batch_first=True
        )
        self.mha_click_trx = nn.MultiheadAttention(
            embed_dim=256,##?,
            num_heads=4,
            dropout=0.3,
            batch_first=True
        )
        self._seq_encoder_trx = torch.nn.Sequential(
            TransactionEncoder(params, trx_amnt_quantiles),
            PBLayerNorm(common_trx_size),
        )
        self._seq_encoder_click = torch.nn.Sequential(
            ClickEncoder(params),
            PBLayerNorm(common_trx_size),
        )

        self.cls = torch.nn.Sequential(
            L2Scorer(),
        )

        self.train_precision = PrecisionK(k=k)
        self.train_mrr = MeanReciprocalRankK(k=k)
        self.valid_precision = PrecisionK(k=k)
        self.valid_mrr = MeanReciprocalRankK(k=k)

    def load_pretrained(self, trx, click):
        self._seq_encoder_trx[0].load_state_dict(trx.state_dict())
        self._seq_encoder_click[0].load_state_dict(click.state_dict())

    def seq_encoder_trx(self, x):
        x = self._seq_encoder_trx(x)
        return self.rnn_enc(x)

    def seq_encoder_click(self, x_orig):
        x = self._seq_encoder_click(x_orig)
        #         x = PaddedBatch(
        #             x.payload + self.mlm_model_click.sentence_encoding(x_orig),
        #             x.seq_lens,
        #         )
        return self.rnn_enc(x)

    def configure_optimizers(self):
        optim = torch.optim.Adam(self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer=optim,
            max_lr=self.hparams.max_lr,
            total_steps=self.hparams.total_steps,
            pct_start=self.hparams.pct_start,
            anneal_strategy='cos',
            cycle_momentum=False,
            div_factor=25.0,
            final_div_factor=10000.0,
            three_phase=True,
        )
        scheduler = {'scheduler': scheduler, 'interval': 'step'}
        return [optim], [scheduler]

    def loss_fn_p(self, embeddings, labels, ref_emb, ref_labels):
        beta = self.hparams.beta
        neg_count = self.hparams.neg_count

        pos_ix = (labels.view(-1, 1) == ref_labels.view(1, -1)).nonzero(as_tuple=False)
        pos_labels = labels[pos_ix[:, 0]]
        neg_w = ((pos_labels.view(-1, 1) != ref_labels.view(1, -1))).float()
        neg_ix = torch.multinomial(neg_w, neg_count - 1)
        all_ix = torch.cat([pos_ix[:, [1]], neg_ix], dim=1)
        logits = -(embeddings[pos_ix[:, [0]]] - ref_emb[all_ix]).pow(2).sum(dim=2)
        logits = logits * beta
        logs = -torch.log(torch.softmax(logits, dim=1))[:, 0]
        #         logs = torch.relu(logs + np.log(0.1))
        return logs.mean()

    def training_step(self, batch, batch_idx):
        # pairs
        x_trx, l_trx, m_trx, x_click, l_click, m_click = batch
        z_trx = self.seq_encoder_trx(x_trx)  # B, H
        z_click = self.seq_encoder_click(x_click)  # B, H
        z_trx, _ = self.mha_trx_click(z_trx, z_click, z_click)
        z_click, _ = self.mha_click_trx(z_click, z_trx, z_trx)
        loss_pt = self.loss_fn_p(embeddings=z_trx, labels=l_trx, ref_emb=z_click, ref_labels=l_click)
        self.log('loss/loss_pt', loss_pt)

        loss_pc = self.loss_fn_p(embeddings=z_click, labels=l_click, ref_emb=z_trx, ref_labels=l_trx)
        self.log('loss/loss_pc', loss_pc)

        with torch.no_grad():
            out = -(z_trx.unsqueeze(1) - z_click.unsqueeze(0)).pow(2).sum(dim=2)
            out = out[m_trx == 0][:, m_click == 0]
            T, C = out.size()
            assert T == C
            n_samples = z_trx.size(0) // (l_trx.max().item() + 1)
            for i in range(n_samples):
                l2 = out[i::n_samples, i::n_samples]
                self.train_precision(l2)
                self.train_mrr(l2)

        return loss_pt + 0.1 * loss_pc  # loss_pc

    def on_train_epoch_end(self):
        self.log('train_metrics/precision', self.train_precision, prog_bar=True)
        self.log('train_metrics/mrr', self.train_mrr, prog_bar=True)

In [47]:
from ptls.data_load import padded_collate_wo_target

def train_qsm(df_matching_train, features_trx_train, features_click_train, model_n, cfg):
    batch_size = 64
    train_dl = torch.utils.data.DataLoader(
        PairedFullDataset(
            df_matching_train[lambda x: x['rtk'].ne('0')].values,
            data=[
                features_trx_train,
                features_click_train,
            ],
            augmentations=[
                augmentation_chain(DropDuplicate('mcc_code', col_new_cnt='c_cnt'), RandomSlice(32, 1024)),  # 1024
                augmentation_chain(DropDuplicate('cat_id', col_new_cnt='c_cnt'), RandomSlice(64, 2048)),  # 2048
            ],
            n_sample=2,
        ),
        collate_fn=PairedFullDataset.collate_fn,
        drop_last=True,
        shuffle=True,
        num_workers=12,
        batch_size=batch_size,
        persistent_workers=True,
    )

    mlm_model_trx = MLMPretrainModuleTrx.load_from_checkpoint(f'{cfg.objects_path}/pretrain_trx.cpt')
    mlm_model_click = MLMPretrainModuleClick.load_from_checkpoint(f'{cfg.objects_path}/pretrain_click.cpt')
    pl.seed_everything(random.randint(1, 2**16 - 1))
    sup_model = PairedModule(
        cfg.model_config, trx_amnt_quantiles=mlm_model_trx.seq_encoder.trx_amnt_quantiles,
        k=100 * batch_size // 3000,
        lr=0.0022, weight_decay=0,
        max_lr=0.0018, pct_start=1100 / 6000, total_steps=6000,
        beta=0.2 / 1.4, neg_count=60,
    )
    sup_model.load_pretrained(mlm_model_trx.seq_encoder, mlm_model_click.seq_encoder)

    trainer = pl.Trainer(
        accelerator="cuda" if torch.cuda.is_available() else "cpu",
        max_steps=4000,
        callbacks=[
            pl.callbacks.LearningRateMonitor(),
            pl.callbacks.ModelCheckpoint(
                every_n_train_steps=1000, save_top_k=-1,
            ),
        ]
    )
    print('Train qsm start')
    trainer.fit(sup_model, train_dl)
    trainer.save_checkpoint(f'{cfg.objects_path}/nn_distance_coles_model_{model_n}.cpt', weights_only=True)
    print(f'Train qsm [{model_n}] done')

In [48]:
# import gc
# gc.collect()
# torch.cuda.empty_cache()

In [None]:
# ensemble_size = 1

# for i in range(ensemble_size):
#     train_qsm(df_matching_train, features_trx_train, features_click_train, i, cfg)

**mrr~0.019 на одной модели. На 1-м фолде**

**cross-attention with mean**

In [None]:
from torch import nn

class PBLayerNorm(torch.nn.LayerNorm):
    def forward(self, x: PaddedBatch):
        return PaddedBatch(super().forward(x.payload), x.seq_lens)

class L2Scorer(torch.nn.Module):
    def forward(self, x):
        B, H = x.size()
        a, b =x[:, :H // 2], x[:, H // 2:]
        return -(a - b).pow(2).sum(dim=1)

class PairedModule(pl.LightningModule):
    def __init__(self, params, trx_amnt_quantiles, k,
                 lr, weight_decay,
                 max_lr, pct_start, total_steps,
                 beta, neg_count,
                 ):
        super().__init__()
        self.save_hyperparameters(ignore=['mlm_model_trx', 'mlm_model_click'])

        common_trx_size = params.common_trx_size
        self.rnn_enc = torch.nn.Sequential(
            RnnEncoder(
                common_trx_size,
                type=params.rnn.type,
                hidden_size=params.rnn.hidden_size,
                bidir=params.rnn.bidir,
                trainable_starter=params.rnn.trainable_starter
            ),
            LastStepEncoder(),
            #             NormEncoder(),
        )
        self.mha_trx_click = nn.MultiheadAttention(
            embed_dim=256,##?,
            num_heads=4,
            dropout=0.3,
            batch_first=True
        )
        self.mha_click_trx = nn.MultiheadAttention(
            embed_dim=256,##?,
            num_heads=4,
            dropout=0.3,
            batch_first=True
        )
        self._seq_encoder_trx = torch.nn.Sequential(
            TransactionEncoder(params, trx_amnt_quantiles),
            PBLayerNorm(common_trx_size),
        )
        self._seq_encoder_click = torch.nn.Sequential(
            ClickEncoder(params),
            PBLayerNorm(common_trx_size),
        )

        self.cls = torch.nn.Sequential(
            L2Scorer(),
        )

        self.train_precision = PrecisionK(k=k)
        self.train_mrr = MeanReciprocalRankK(k=k)
        self.valid_precision = PrecisionK(k=k)
        self.valid_mrr = MeanReciprocalRankK(k=k)

    def load_pretrained(self, trx, click):
        self._seq_encoder_trx[0].load_state_dict(trx.state_dict())
        self._seq_encoder_click[0].load_state_dict(click.state_dict())

    def seq_encoder_trx(self, x):
        x = self._seq_encoder_trx(x)
        return self.rnn_enc(x)

    def seq_encoder_click(self, x_orig):
        x = self._seq_encoder_click(x_orig)
        #         x = PaddedBatch(
        #             x.payload + self.mlm_model_click.sentence_encoding(x_orig),
        #             x.seq_lens,
        #         )
        return self.rnn_enc(x)

    def configure_optimizers(self):
        optim = torch.optim.Adam(self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer=optim,
            max_lr=self.hparams.max_lr,
            total_steps=self.hparams.total_steps,
            pct_start=self.hparams.pct_start,
            anneal_strategy='cos',
            cycle_momentum=False,
            div_factor=25.0,
            final_div_factor=10000.0,
            three_phase=True,
        )
        scheduler = {'scheduler': scheduler, 'interval': 'step'}
        return [optim], [scheduler]

    def loss_fn_p(self, embeddings, labels, ref_emb, ref_labels):
        beta = self.hparams.beta
        neg_count = self.hparams.neg_count

        pos_ix = (labels.view(-1, 1) == ref_labels.view(1, -1)).nonzero(as_tuple=False)
        pos_labels = labels[pos_ix[:, 0]]
        neg_w = ((pos_labels.view(-1, 1) != ref_labels.view(1, -1))).float()
        neg_ix = torch.multinomial(neg_w, neg_count - 1)
        all_ix = torch.cat([pos_ix[:, [1]], neg_ix], dim=1)
        logits = -(embeddings[pos_ix[:, [0]]] - ref_emb[all_ix]).pow(2).sum(dim=2)
        logits = logits * beta
        logs = -torch.log(torch.softmax(logits, dim=1))[:, 0]
        #         logs = torch.relu(logs + np.log(0.1))
        return logs.mean()

    def training_step(self, batch, batch_idx):
        # pairs
        x_trx, l_trx, m_trx, x_click, l_click, m_click = batch
        z_trx = self.seq_encoder_trx(x_trx)  # B, H
        z_click = self.seq_encoder_click(x_click)  # B, H
        z_trx, _ = self.mha_trx_click(z_trx, z_click, z_click)
        z_click, _ = self.mha_click_trx(z_click, z_trx, z_trx)
        
        loss_pt = self.loss_fn_p(embeddings=z_trx, labels=l_trx, ref_emb=z_click, ref_labels=l_click)
        self.log('loss/loss_pt', loss_pt)

        loss_pc = self.loss_fn_p(embeddings=z_click, labels=l_click, ref_emb=z_trx, ref_labels=l_trx)
        self.log('loss/loss_pc', loss_pc)

        with torch.no_grad():
            out = -(z_trx.unsqueeze(1) - z_click.unsqueeze(0)).pow(2).sum(dim=2)
            out = out[m_trx == 0][:, m_click == 0]
            T, C = out.size()
            assert T == C
            n_samples = z_trx.size(0) // (l_trx.max().item() + 1)
            for i in range(n_samples):
                l2 = out[i::n_samples, i::n_samples]
                self.train_precision(l2)
                self.train_mrr(l2)

        return loss_pt + 0.1 * loss_pc  # loss_pc

    def on_train_epoch_end(self):
        self.log('train_metrics/precision', self.train_precision, prog_bar=True)
        self.log('train_metrics/mrr', self.train_mrr, prog_bar=True)

In [None]:
from ptls.data_load import padded_collate_wo_target

def train_qsm(df_matching_train, features_trx_train, features_click_train, model_n, cfg):
    batch_size = 64
    train_dl = torch.utils.data.DataLoader(
        PairedFullDataset(
            df_matching_train[lambda x: x['rtk'].ne('0')].values,
            data=[
                features_trx_train,
                features_click_train,
            ],
            augmentations=[
                augmentation_chain(DropDuplicate('mcc_code', col_new_cnt='c_cnt'), RandomSlice(32, 1024)),  # 1024
                augmentation_chain(DropDuplicate('cat_id', col_new_cnt='c_cnt'), RandomSlice(64, 2048)),  # 2048
            ],
            n_sample=2,
        ),
        collate_fn=PairedFullDataset.collate_fn,
        drop_last=True,
        shuffle=True,
        num_workers=12,
        batch_size=batch_size,
        persistent_workers=True,
    )

    mlm_model_trx = MLMPretrainModuleTrx.load_from_checkpoint(f'{cfg.objects_path}/pretrain_trx.cpt')
    mlm_model_click = MLMPretrainModuleClick.load_from_checkpoint(f'{cfg.objects_path}/pretrain_click.cpt')
    pl.seed_everything(random.randint(1, 2**16 - 1))
    sup_model = PairedModule(
        cfg.model_config, trx_amnt_quantiles=mlm_model_trx.seq_encoder.trx_amnt_quantiles,
        k=100 * batch_size // 3000,
        lr=0.0022, weight_decay=0,
        max_lr=0.0018, pct_start=1100 / 6000, total_steps=6000,
        beta=0.2 / 1.4, neg_count=60,
    )
    sup_model.load_pretrained(mlm_model_trx.seq_encoder, mlm_model_click.seq_encoder)

    trainer = pl.Trainer(
        accelerator="cuda" if torch.cuda.is_available() else "cpu",
        max_steps=46,
        callbacks=[
            pl.callbacks.LearningRateMonitor(),
            pl.callbacks.ModelCheckpoint(
                every_n_train_steps=1000, save_top_k=-1,
            ),
        ]
    )
    print('Train qsm start')
    trainer.fit(sup_model, train_dl)
    trainer.save_checkpoint(f'{cfg.objects_path}/nn_distance_coles_model_{model_n}.cpt', weights_only=True)
    print(f'Train qsm [{model_n}] done')

In [None]:
ensemble_size = 1

for i in range(ensemble_size):
    train_qsm(df_matching_train, features_trx_train, features_click_train, i, cfg)

# GPT, CoLES

**Get pyspark data**

In [None]:
!git clone https://github.com/Dzhambo/MBD.git

In [4]:
import os

import pyspark
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
from pyspark.sql import types as T
import ptls
from ptls.preprocessing import PysparkDataPreprocessor

# os.environ['JAVA_HOME']= ''

spark_conf = pyspark.SparkConf()
spark_conf.setMaster("local[*]").setAppName("JoinModality")
spark_conf.set("spark.driver.maxResultSize", "8g")
spark_conf.set("spark.executor.memory", "16g")
spark_conf.set("spark.executor.memoryOverhead", "8g")
spark_conf.set("spark.driver.memory", "16g")
spark_conf.set("spark.driver.memoryOverhead", "8g")
spark_conf.set("spark.cores.max", "4")
spark_conf.set("spark.sql.shuffle.partitions", "200")
spark_conf.set("spark.local.dir", "../../spark_local_dir")


spark = SparkSession.builder.config(conf=spark_conf).getOrCreate()
spark.sparkContext.getConf().getAll()

[('spark.driver.extraJavaOptions',
  '-Djava.net.preferIPv6Addresses=false -XX:+IgnoreUnrecognizedVMOptions --add-opens=java.base/java.lang=ALL-UNNAMED --add-opens=java.base/java.lang.invoke=ALL-UNNAMED --add-opens=java.base/java.lang.reflect=ALL-UNNAMED --add-opens=java.base/java.io=ALL-UNNAMED --add-opens=java.base/java.net=ALL-UNNAMED --add-opens=java.base/java.nio=ALL-UNNAMED --add-opens=java.base/java.util=ALL-UNNAMED --add-opens=java.base/java.util.concurrent=ALL-UNNAMED --add-opens=java.base/java.util.concurrent.atomic=ALL-UNNAMED --add-opens=java.base/jdk.internal.ref=ALL-UNNAMED --add-opens=java.base/sun.nio.ch=ALL-UNNAMED --add-opens=java.base/sun.nio.cs=ALL-UNNAMED --add-opens=java.base/sun.security.action=ALL-UNNAMED --add-opens=java.base/sun.util.calendar=ALL-UNNAMED --add-opens=java.security.jgss/sun.security.krb5=ALL-UNNAMED -Djdk.reflect.useDirectMethodHandle=false'),
 ('spark.app.name', 'JoinModality'),
 ('spark.local.dir', '../../spark_local_dir'),
 ('spark.app.submit

In [5]:
!unzip ./data/raw_data/transactions.zip -d ./data/raw_data/
!unzip ./data/raw_data/clickstream.zip -d ./data/raw_data/

Archive:  ./data/raw_data/transactions.zip
  inflating: ./data/raw_data/transactions.csv  
  inflating: ./data/raw_data/__MACOSX/._transactions.csv  
Archive:  ./data/raw_data/clickstream.zip
  inflating: ./data/raw_data/clickstream.csv  


In [6]:
transactions = spark.read.csv('./data/raw_data/transactions.csv', header=True)
clickstream = spark.read.csv('./data/raw_data/clickstream.csv', header=True)
train_matching = spark.read.csv('./data/raw_data/train_matching.csv', header=True)
# train_edu = spark.read.csv('../data/raw_data/train_edu.csv', header=True) #Этого файла нет

click_categories = spark.read.csv('./data/raw_data/click_categories.csv', header=True)
clickstream = clickstream.join(click_categories, on='cat_id')

In [7]:
preprocessor_trx = PysparkDataPreprocessor(
        col_id='user_id',
        col_event_time='transaction_dttm',
        event_time_transformation='dt_to_timestamp',
        cols_category=["mcc_code", "currency_rk"],
    )


preprocessor_click = PysparkDataPreprocessor(
    col_id='user_id',
    col_event_time='timestamp',
    event_time_transformation='dt_to_timestamp',
    cols_category=['cat_id', 'level_0', 'level_1', 'level_2'],
)

In [8]:
transactions_prepared = preprocessor_trx.fit_transform(transactions)
clickstream_prepared = preprocessor_click.fit_transform(clickstream)

In [9]:
train_matching = train_matching.withColumnRenamed('rtk', 'user_id')
clickstream_prepared = clickstream_prepared.join(train_matching, on='user_id', how='outer').drop('user_id')
clickstream_prepared  = clickstream_prepared.withColumnRenamed('bank', 'user_id')
clickstream_prepared.show(2)

+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+
|          event_time|              cat_id|             new_uid|             level_0|             level_1|             level_2|             user_id|
+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+
|[1616467560, 1616...|[1, 1, 1, 1, 1, 1...|[411399, 411399, ...|[1, 1, 1, 1, 1, 1...|[1, 1, 1, 1, 1, 1...|[1, 1, 1, 1, 1, 1...|95f2446d41fc4536b...|
|[1612159140, 1612...|[3, 12, 5, 12, 40...|[1840824, 1840824...|[3, 5, 5, 5, 38, ...|[1, 2, 1, 2, 1, 1...|[1, 1, 1, 1, 1, 1...|89d5b991d5dc4c5d8...|
+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+
only showing top 2 rows



In [10]:
clickstream_prepared = clickstream_prepared.withColumnRenamed('event_time', 'click_event_time')
transactions_prepared = transactions_prepared.withColumnRenamed('event_time', 'trx_event_time')

In [11]:
mm_dataset = transactions_prepared.join(clickstream_prepared, on='user_id', how='outer')
mm_dataset.show(5)

+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+
|             user_id|      trx_event_time|            mcc_code|         currency_rk|     transaction_amt|    click_event_time|              cat_id|             new_uid|             level_0|             level_1|             level_2|
+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+
|0012e60b16f14da4b...|[1596253849, 1596...|[1, 23, 7, 1, 1, ...|[1, 1, 1, 1, 1, 1...|[-398.97632, -195...|[1611572447, 1611...|[29, 1, 27, 9, 9,...|[1439071, 1079747...|[26, 1, 25, 9, 9,...|[1, 1, 1, 1, 1, 1...|[1, 1, 1, 1, 1, 1...|
|003d93fb918846ada...|[1596258479, 1596...|[5, 5, 8, 2, 29, ...|[1, 

In [12]:
mm_dataset.write.mode('overwrite').parquet('./spark_data/mm_dataset.parquet')

In [13]:
!mkdir spark_data

mkdir: cannot create directory ‘spark_data’: File exists


In [14]:
mm_dataset_fold0, mm_dataset_fold1, mm_dataset_fold2,  mm_dataset_fold3, mm_dataset_fold4 = mm_dataset.randomSplit([0.2, 0.2, 0.2, 0.2, 0.2], seed=42)
mm_dataset_fold0.write.mode('overwrite').parquet('./spark_data/mm_dataset_fold/fold=0')
mm_dataset_fold1.write.mode('overwrite').parquet('./spark_data/mm_dataset_fold/fold=1')
# mm_dataset_fold2.write.mode('overwrite').parquet('./spark_data/mm_dataset_fold/fold=2')
# mm_dataset_fold3.write.mode('overwrite').parquet('./spark_data/mm_dataset_fold/fold=3')
mm_dataset_fold4.write.mode('overwrite').parquet('./spark_data/mm_dataset_fold/fold=4')

In [15]:
spark.stop()

# Utilities

In [28]:
from ptls.data_load.iterable_processing_dataset import IterableProcessingDataset
from datetime import datetime
from ptls.data_load.padded_batch import PaddedBatch

class DeleteNan(IterableProcessingDataset):
    def __init__(self, col_name):
        super().__init__()
        self.col_name = col_name
    
    def __iter__(self):
        for rec in self._src:
            features = rec[0] if type(rec) is tuple else rec
            if features[self.col_name] is not None:
                yield features

class TypeProc(IterableProcessingDataset):
    def __init__(self, col_name, tp='float'):
        super().__init__()
        self.col_name = col_name
        self.tp = tp

    def __iter__(self):
        for rec in self._src:
            features = rec[0] if type(rec) is tuple else rec
            if type(features[self.col_name]) is not str:
                features[self.col_name] = np.array([float(val) if self.tp=='float' else int(float(val)) for val in features[self.col_name]])
            else:
                features[self.col_name] = float(features[self.col_name]) if self.tp=='float' else int(float(features[self.col_name]))
            yield features

**Create dataset**

In [29]:
from ptls.data_load.datasets import ParquetDataset
from ptls.data_load.iterable_processing.add_modal_name import AddModalName
from ptls.data_load.iterable_processing import FeatureFilter, SeqLenFilter, ISeqLenLimit
from ptls.data_load.iterable_processing.time_proc import TimeProcMultimodal
from ptls.data_load.iterable_processing import ToTorch
train = ParquetDataset(
    shuffle_files=True,
    data_files=[
        './spark_data/mm_dataset_fold/fold=0',
        './spark_data/mm_dataset_fold/fold=1',
        # './spark_data/mm_dataset_fold/fold=2',
        # './spark_data/mm_dataset_fold/fold=3'
    ],
    i_filters=[
        DeleteNan(col_name='mcc_code'),
        DeleteNan(col_name='cat_id'),
        DeleteNan(col_name='user_id'),
        TypeProc(col_name='transaction_amt'),
        AddModalName(
            source='trx',
            cols=[
                'mcc_code',
                'currency_rk',
                'transaction_amt'
            ]
        ),
        AddModalName(
            source='click',
            cols=[
                'cat_id',
                'level_0',
                'level_1',
                'level_2'
            ]
        ),
        FeatureFilter(drop_feature_names=[
            'user_id',
            'higher_education',
            'new_uid'
            ]
        ),
        SeqLenFilter(min_seq_len=32),
        ISeqLenLimit(max_seq_len=4096),
        TimeProcMultimodal(
            source='trx',
            time_col='trx_event_time'
        ),
        TimeProcMultimodal(
            source='click',
            time_col='click_event_time'
        ),
        ToTorch()
    ]
)

valid = ParquetDataset(
    shuffle_files=True,
    data_files=[
        './spark_data/mm_dataset_fold/fold=4',
    ],
    i_filters=[
        DeleteNan(col_name='mcc_code'),
        DeleteNan(col_name='cat_id'),
        DeleteNan(col_name='user_id'),
        TypeProc(col_name='transaction_amt'),
        AddModalName(
            source='trx',
            cols=[
                'mcc_code',
                'currency_rk',
                'transaction_amt'
            ]
        ),
        AddModalName(
            source='click',
            cols=[
                'cat_id',
                'level_0',
                'level_1',
                'level_2'
            ]
        ),
        FeatureFilter(drop_feature_names=[
            'user_id',
            'higher_education',
            'new_uid'
            ]
        ),
        SeqLenFilter(min_seq_len=32),
        ISeqLenLimit(max_seq_len=4096),
        TimeProcMultimodal(
            source='trx',
            time_col='trx_event_time'
        ),
        TimeProcMultimodal(
            source='click',
            time_col='click_event_time'
        ),
        ToTorch()
    ]
)

# Create Dataset PTLS

In [30]:
from ptls.data_load.feature_dict import FeatureDict
from ptls.frames.coles.split_strategy import AbsSplit
from ptls.frames.coles.metric import metric_recall_top_K, outer_cosine_similarity, outer_pairwise_distance

from ptls.frames.abs_module import ABSModule
from ptls.frames.coles.losses import ContrastiveLoss
from ptls.frames.coles.metric import BatchRecallTopK
from ptls.frames.coles.sampling_strategies import HardNegativePairSelector
from ptls.nn.head import Head
from ptls.nn.seq_encoder.containers import SeqEncoderContainer
from ptls.data_load.utils import collate_feature_dict
from ptls.data_load.padded_batch import PaddedBatch
import torch

def collate_feature_dict(batch):
    new_x_ = defaultdict(list)
    for i, x in enumerate(batch):
        for k, v in x.items():
            new_x_[k].append(v)
    
    seq_col = next(k for k, v in batch[0].items() if FeatureDict.is_seq_feature(k, v))
    lengths = torch.LongTensor([len(rec[seq_col]) for rec in batch])
    new_x = {}
    for k, v in new_x_.items():
        if type(v[0]) is torch.Tensor:
            if k.startswith('target'):
                new_x[k] = torch.stack(v, dim=0)
            else:
                new_x[k] = torch.nn.utils.rnn.pad_sequence(v, batch_first=True)
        elif type(v[0]) is np.ndarray:
            new_x[k] = v  # list of arrays[object]
        else:
            v = np.array(v)
            if v.dtype.kind == 'i':
                new_x[k] = torch.from_numpy(v).long()
            elif v.dtype.kind == 'f':
                new_x[k] = torch.from_numpy(v).float()
            elif v.dtype.kind == 'b':
                new_x[k] = torch.from_numpy(v).bool()
            else:
                new_x[k] = v
    return PaddedBatch(new_x, lengths)

def collate_multimodal_feature_dict(batch):
    res = {}
    for source, source_batch in batch.items():
        res[source] = collate_feature_dict(source_batch)
    return res
    
def get_dict_class_labels(batch):
    res = defaultdict(list)
    for i, samples in enumerate(batch):
        for source, values in samples.items():
            for _ in values:
                res[source].append(i)
    for source in res:
        res[source] = torch.LongTensor(res[source])
    return dict(res)
            

class MultiModalDiffSplitDataset(FeatureDict, torch.utils.data.Dataset):
    def __init__(
        self,
        data,
        splitters,
        source_features,
        col_id,
        source_names,
        col_time='event_time',
        *args, **kwargs
    ):
        """
        Dataset for multimodal learning.
        Parameters:
        -----------
        data:
            concatinated data with feature dicts.
        splitter:
            object from from `ptls.frames.coles.split_strategy`.
            Used to split original sequence into subsequences which are samples from one client.
        source_features:
            list of column names 
        col_id:
            column name with user_id
        source_names:
            column name with name sources
        col_time:
            column name with event_time
        """
        super().__init__(*args, **kwargs)
        
        self.data = data
        self.splitters = splitters
        self.col_time = col_time
        self.col_id = col_id
        self.source_names = source_names
        self.source_features = source_features
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        feature_arrays = self.data[idx]
        split_data = self.split_source(feature_arrays)
        return self.get_splits(split_data)
    
    def __iter__(self):
        for feature_arrays in self.data:
            split_data = self.split_source(feature_arrays)
            yield self.get_splits(split_data)
            
    def split_source(self, feature_arrays):
        res = defaultdict(dict)
        for feature_name, feature_array in feature_arrays.items():
            if feature_name == self.col_id:
                res[self.col_id] = feature_array
            else:
                source_name, feature_name_transform = self.get_names(feature_name)
                res[source_name][feature_name_transform] = feature_array
        for source in self.source_names:
            if source not in res:
                res[source] = {source_feature: torch.tensor([]) for source_feature in self.source_features[source]}
        return res
    
    def get_names(self, feature_name):
        idx_del = feature_name.find('_')
        return feature_name[:idx_del], feature_name[idx_del + 1:]
    
    def get_splits(self, feature_arrays):
        res = {}
        for source_name, feature_array in feature_arrays.items():
            if source_name != self.col_id:
                local_date = feature_array[self.col_time]
                if source_name not in self.splitters:
                    continue
                indexes = self.splitters[source_name].split(local_date)
                res[source_name] = [{k: v[ix] for k, v in feature_array.items() if self.is_seq_feature(k, v)} for ix in indexes]
        return res
        
    def collate_fn(self, batch, return_dct_labels=False):
        dict_class_labels = get_dict_class_labels(batch)
        batch = reduce(lambda x, y: {k: x[k] + y[k] for k in x if k in y}, batch)
        padded_batch = collate_multimodal_feature_dict(batch)
        if return_dct_labels:
            return padded_batch, dict_class_labels
        return padded_batch, dict_class_labels[list(dict_class_labels.keys())[0]]

    
class MultiModalDiffSplitIterableDataset(MultiModalDiffSplitDataset, torch.utils.data.IterableDataset):
    pass

In [31]:
from ptls.frames import PtlsDataModule
from ptls.frames.coles.split_strategy import SampleSlices, NoSplit

train_data = MultiModalDiffSplitIterableDataset(
    splitters={
        'trx': SampleSlices(
            split_count=2,
            cnt_min=32,
            cnt_max=180
            ),
        'click': SampleSlices(
            split_count=2,
            cnt_min=32,
            cnt_max=180
        )
    },
    data=train,
    source_features={
        'trx': [
            'mcc_code',
            'currency_rk',
            'transaction_amt',
            'event_time',
            'hour',
            'weekday'
        ],
        'click': [
            'cat_id',
            'level_0',
            'level_1',
            'level_2',
            'event_time',
            'hour',
            'weekday'
        ]
    },
    col_id='user_id',
    col_time='event_time',
    source_names=[
        'trx',
        'click'
    ]
)

valid_data = MultiModalDiffSplitIterableDataset(
    splitters={
        'trx': NoSplit(),
        'click': NoSplit()
    },
    data=valid,
    source_features={
        'trx': [
            'mcc_code',
            'currency_rk',
            'transaction_amt',
            'event_time',
            'hour',
            'weekday'
        ],
        'click': [
            'cat_id',
            'level_0',
            'level_1',
            'level_2',
            'event_time',
            'hour',
            'weekday'
        ]
    },
    col_id='user_id',
    col_time='event_time',
    source_names=[
        'trx',
        'click'
    ]
)

data_module = PtlsDataModule(
    train_data=train_data,
    valid_data=valid_data,
    train_batch_size=64,
    train_num_workers=8,
    valid_batch_size=256,
    valid_num_workers=8
)

In [52]:
def metric_real_recall_top_K(X, y, K, num_pos=1, metric='cosine'):
    """
        calculate metric R@K
        X - tensor with size n x d, where n - number of examples, d - size of embedding vectors
        y - true labels
        N - count of closest examples, which we consider for recall calcualtion
        metric: 'cosine' / 'euclidean'.
            !!! 'euclidean' - to slow for datasets bigger than 100K rows
    """
    # TODO: take K from `y`
    K_adjusted = min(X.size(0) - 1, K)
    
    res = []

    n = X.size(0)
    d = X.size(1)
    max_size = 2 ** 32
    batch_size = max(1, max_size // (n * d))

    with torch.no_grad():

        for i in range(1 + (len(X) - 1) // batch_size):

            id_left = i * batch_size
            id_right = min((i + 1) * batch_size, len(y))
            y_batch = y[id_left:id_right]

            if metric == 'cosine':
                pdist = -1 * outer_cosine_similarity(X, X[id_left:id_right])
            elif metric == 'euclidean':
                pdist = outer_pairwise_distance(X, X[id_left:id_right])
            else:
                raise AttributeError(f'wrong metric "{metric}"')

            values, indices = pdist.topk(K_adjusted + 1, 0, largest=False)

            y_rep = y_batch.repeat(K_adjusted, 1)
            res.append((y[indices[1:]] == y_rep).sum().item())

    return np.sum(res) / len(y) / num_pos

def cosine_similarity_matrix(x1, x2):
    x1_norm = x1 / x1.norm(dim=1)[:, None]
    x2_norm = x2 / x2.norm(dim=1)[:, None]
    return torch.mm(x1_norm, x2_norm.transpose(0, 1))

def metric_recall_top_K_for_embs(embs_1, embs_2, true_matches, K=100):
    similarity_matrix = cosine_similarity_matrix(embs_1, embs_2)
    K_adjusted = min(len(embs_1), K)
    top_k = similarity_matrix.topk(k=K_adjusted, dim=1).indices
    correct_matches = 0
    for i, indices in enumerate(top_k):
        if true_matches[i] in indices:
            correct_matches += 1
    recall_at_k = correct_matches / len(similarity_matrix)
    return recall_at_k

class M3CoLESModule(ABSModule):
    """
    Multi-Modal Matching
    Contrastive Learning for Event Sequences ([CoLES](https://arxiv.org/abs/2002.08232))

    Subsequences are sampled from original sequence.
    Samples from the same sequence are `positive` examples
    Samples from the different sequences are `negative` examples
    Embeddings for all samples are calculated.
    Paired distances between all embeddings are calculated.
    The loss function tends to make positive distances smaller and negative ones larger.

    Parameters
        seq_encoder:
            Model which calculate embeddings for original raw transaction sequences
            `seq_encoder` is trained by `CoLESModule` to get better representations of input sequences
        head:
            Model which helps to train. Not used during inference
            Can be normalisation layer which make embedding l2 length equals 1
            Can be MLP as `projection head` like in SymCLR framework.
        loss:
            loss object from `ptls.frames.coles.losses`.
            There are paired and triplet loss. They are required sampling strategy
            from `ptls.frames.coles.sampling_strategies`. Sampling strategy takes a relevant pairs or triplets from
            pairwise distance matrix.
        validation_metric:
            Keep None. `ptls.frames.coles.metric.BatchRecallTopK` used by default.
        optimizer_partial:
            optimizer init partial. Network parameters are missed.
        lr_scheduler_partial:
            scheduler init partial. Optimizer are missed.

    """
    def __init__(self,
                 seq_encoders=None,
                 mod_names=None,
                 head=None,
                 loss=None,
                 validation_metric=None,
                 optimizer_partial=None,
                 lr_scheduler_partial=None):
        torch.set_float32_matmul_precision('high')
        if head is None:
            head = Head(use_norm_encoder=True)

        if loss is None:
            loss = ContrastiveLoss(margin=0.5,
                                   sampling_strategy=HardNegativePairSelector(neg_count=5))

        if validation_metric is None:
            validation_metric = BatchRecallTopK(K=4, metric='cosine')
        
        for k in seq_encoders.keys():
            if type(seq_encoders[k]) is str:
                seq_encoders[k] = seq_encoders[seq_encoders[k]]
                
        super().__init__(validation_metric,
                         first(seq_encoders.values()),
                         loss,
                         optimizer_partial,
                         lr_scheduler_partial)
        
        self.seq_encoders = torch.nn.ModuleDict(seq_encoders)
        self._head = head   
        self.y_h_cache = {'train':[], 'valid': []}
        
    @property
    def metric_name(self):
        return 'recall_top_k'

    @property
    def is_requires_reduced_sequence(self):
        return True
    
    def forward(self, x):
        res = {}
        
        for mod_name in x.keys():
            res[mod_name] = self.seq_encoders[mod_name](x[mod_name])
            
        return res

    def shared_step(self, x, y):
        y_h = self(x)
        
        if self._head is not None:
            y_h_head = {k: self._head(y_h_k) for k, y_h_k in y_h.items()}
            y_h = y_h_head
            
        return y_h, y
    
    def _one_step(self, batch, _, stage):
        y_h, y = self.shared_step(*batch)
        y_h_list = list(y_h.values())
        loss = self._loss(torch.cat(y_h_list), torch.cat([y, y]))
        self.log(f'loss/{stage}', loss.detach())
        
        x, y = batch
        for mod_name, mod_x in x.items():
            self.log(f'seq_len/{stage}/{mod_name}', x[mod_name].seq_lens.float().mean().detach(), prog_bar=True)
        
        if stage == "valid":
            n, d = y_h_list[0].shape
            y_h_concat = torch.zeros((2*n, d), device = y_h_list[0].device)
            
            for i in range(2):
                y_h_concat[range(i,2*n,2)] = y_h_list[i] 
            
            if len(self.y_h_cache[stage]) <= 380:
                self.y_h_cache[stage].append((y_h_concat.cpu(), {k: y_h_k.cpu() for k, y_h_k in y_h.items()} , 
                                             {k:x_k.seq_lens.cpu() for k, x_k in x.items()})) 
    
        return loss
    
    def training_step(self, batch, _):
        return self._one_step(batch, _, "train")
    
    def validation_step(self, batch, _):
        return self._one_step(batch, _, "valid")
    
    def on_validation_epoch_end(self):        
        #len_intervals = [(0, 10), (10, 20), (20, 30), (30, 40), (40, 60), (60, 80), (80, 120), (120, 160), (160, 240)]
        self.log_recall_top_K(self.y_h_cache['valid'], len_intervals=None, stage="valid", K=100)
        self.log_recall_top_K(self.y_h_cache['valid'], len_intervals=None, stage="valid", K=50)
        self.log_recall_top_K(self.y_h_cache['valid'], len_intervals=None, stage="valid", K=1)
        
        
        del self.y_h_cache["valid"]
        self.y_h_cache["valid"] = []
        
    def log_recall_top_K(self, y_h_cache, len_intervals=None, stage="valid", K=100):
        y_h = torch.cat([item[0] for item in y_h_cache], dim = 0)
        y_h_mods = defaultdict(list)
        seq_lens_dict = defaultdict(list)
        
        for item in y_h_cache:
            for k, emb in item[1].items():
                y_h_mods[k].append(emb)
                
            for k, l in item[2].items():
                seq_lens_dict[k].append(l)
        
        y_h_mods = {k: torch.cat(el, dim=0) for k ,el in y_h_mods.items()}
        seq_lens_dict = {k: torch.cat(el) for k ,el in seq_lens_dict.items()}

        #n, _ = y_h.shape
        #y = torch.zeros((n,)).cpu().long()
        #y[range(0,n,2)] = torch.arange(0, n//2)
        #y[range(1,n,2)] = torch.arange(0, n//2)
        #computed_metric = metric_real_recall_top_K(y_h, y, K=100)
        y_h_bank, y_h_rmb = list(y_h_mods.values())
        computed_metric_b2r = metric_recall_top_K_for_embs(y_h_bank, y_h_rmb, torch.arange(y_h_rmb.shape[0]), K=K)
        computed_metric_r2b = metric_recall_top_K_for_embs(y_h_rmb, y_h_bank, torch.arange(y_h_rmb.shape[0]), K=K)
        
        if len_intervals != None:
            for mod, seq_lens in seq_lens_dict.items():
                for start, end in len_intervals:
                    mask = ((seq_lens > start) & (seq_lens <= end))

                    if torch.any(mask):
                        #y_h_filtered = y_h[mask.repeat_interleave(2)]
                        y_h_bank_filtered = y_h_bank[mask]
                        y_h_rmb_filtered = y_h_rmb[mask]

                        #y = torch.div(torch.arange(len(y_h_filtered)), 2, rounding_mode='floor')
                        #recall = metric_real_recall_top_K(y_h_filtered, y, K=100)
                        recall_r2b = metric_recall_top_K_for_embs(y_h_rmb_filtered, y_h_bank_filtered, torch.arange(y_h_rmb_filtered.shape[0]), K=100)
                        recall_b2r = metric_recall_top_K_for_embs(y_h_bank_filtered, y_h_rmb_filtered, torch.arange(y_h_rmb_filtered.shape[0]), K=100)

                        #self.log(f"{mode}/R@100_len_from_{start}_to_{end}", recall, prog_bar=True)
                        self.log(f"{stage}/{mod}/r2b_R@100_len_from_{start}_to_{end}", recall_r2b, prog_bar=True)
                        self.log(f"{stage}/{mod}/b2r_R@100_len_from_{start}_to_{end}", recall_b2r, prog_bar=True)
        
        #self.log(f"{mode}/R@100", computed_metric, prog_bar=True)
        self.log(f"{stage}/click2trx_R@{K}", computed_metric_r2b, prog_bar=True)
        self.log(f"{stage}/trx2click_R@{K}", computed_metric_b2r, prog_bar=True)

In [53]:
from ptls.frames.coles.losses import SoftmaxLoss
from ptls.frames.coles.metric import BatchRecallTopK
from functools import partial

def first(iterable, default=None):
    iterator = iter(iterable)
    return next(iterator, default)

head = ptls.nn.Head(
        input_size=128,
        use_norm_encoder=True,
        hidden_layers_sizes=[128, 128],
        objective='regression',
        num_classes=128
    )

seq_encoders = {
    'trx': ptls.nn.RnnSeqEncoder(
        trx_encoder=ptls.nn.TrxEncoder(
            norm_embeddings=False,
            embeddings_noise=0.003,
            embeddings={
                'mcc_code': {
                    'in': 350,
                    'out': 64
                },
                'currency_rk': {
                    'in': 10,
                    'out': 4,
                },
                'hour': {
                    'in': 25,
                    'out': 16
                },
                'weekday': {
                    'in': 8,
                    'out': 4
                }
            },
            numeric_values={'transaction_amt': 'log'}
        ),
        type='gru',
        hidden_size=128
    ),
    'click': ptls.nn.RnnSeqEncoder(
        trx_encoder=ptls.nn.TrxEncoder(
            embeddings_noise=0.003,
            embeddings={
                'cat_id': {
                    'in': 400,
                    'out':64
                },
                'level_0': {
                    'in': 400,
                    'out': 16
                },
                'level_1': {
                    'in': 400,
                    'out': 8
                },
                'level_2': {
                    'in': 400,
                    'out': 4
                },
                'hour': {
                    'in': 25,
                    'out': 16
                },
                'weekday': {
                    'in': 8,
                    'out': 4
                }
            }
        ),
        type='gru',
        hidden_size=128
    )
}

pl_module = M3CoLESModule(
    validation_metric=BatchRecallTopK(
        K=1,
        metric='cosine'
    ),
    head=head,
    seq_encoders=seq_encoders,
    loss=SoftmaxLoss(),
    optimizer_partial=partial(torch.optim.AdamW, lr=0.001, weight_decay=1e-4),
    lr_scheduler_partial=partial(torch.optim.lr_scheduler.StepLR, step_size=1, gamma=0.9)
)

In [54]:
import pytorch_lightning as pl

trainer = pl.Trainer(
    max_epochs=50,
    accelerator="cuda" if torch.cuda.is_available() else "cpu",
    enable_progress_bar=True,
    gradient_clip_val=0.5,
    log_every_n_steps=50,
    limit_val_batches=36
)

**На COLES mrr ~0.011, на 2-х фолдах**

In [None]:
import numpy as np
from collections import defaultdict
from functools import reduce

trainer.fit(pl_module, data_module)

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7d9a57ada200>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7d9a57ada200>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/

Validation: |          | 0/? [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7d9a57ada200>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7d9a57ada200>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/

Validation: |          | 0/? [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7d9a57ada200>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7d9a57ada200>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/

Validation: |          | 0/? [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7d9a57ada200>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7d9a57ada200>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/

Validation: |          | 0/? [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7d9a57ada200>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7d9a57ada200>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/