# Импортируем необходимые библиотеки

In [1]:
!pip install pytorch-lifestream
!pip install comet_ml

Collecting pytorch-lifestream
  Downloading pytorch-lifestream-0.6.0.tar.gz (163 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m163.4/163.4 kB[0m [31m5.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting hydra-core>=1.1.2 (from pytorch-lifestream)
  Downloading hydra_core-1.3.2-py3-none-any.whl.metadata (5.5 kB)
Downloading hydra_core-1.3.2-py3-none-any.whl (154 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m154.5/154.5 kB[0m [31m10.5 MB/s[0m eta [36m0:00:00[0m
[?25hBuilding wheels for collected packages: pytorch-lifestream
  Building wheel for pytorch-lifestream (pyproject.toml) ... [?25l[?25hdone
  Created wheel for pytorch-lifestream: filename=pytorch_lifestream-0.6.0-py3-none-any.whl size=274670 sha256=502dd8cea76c88adeb52d64f4105f305f1e98923a4a7f452973fc132f793cefc
 

In [2]:
# data preprocessing
import os
import numpy as np
import pandas as pd
import pickle

# misc
from tqdm import tqdm
from functools import partial

# logging
import comet_ml

# classical ML
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, roc_auc_score
from catboost import CatBoostClassifier

# basic deep learning libs
import torch
import pytorch_lightning as pl
import torchmetrics

# ptls
from ptls.nn import TrxEncoder, RnnSeqEncoder, TransformerEncoder, GptEncoder, Head
from ptls.frames import PtlsDataModule
from ptls.frames.coles import CoLESModule
from ptls.frames.coles import ColesDataset
from ptls.frames.coles.split_strategy import SampleSlices
from ptls.frames.cpc import CpcModule
from ptls.frames.cpc import CpcDataset
from ptls.frames.gpt import GptDataset
from ptls.frames.supervised import SeqToTargetDataset, SequenceToTarget
from ptls.data_load.datasets import MemoryMapDataset
from ptls.data_load.datasets import inference_data_loader
from ptls.frames.inference_module import InferenceModule
from ptls.data_load.iterable_processing import SeqLenFilter
from ptls.preprocessing import PandasDataPreprocessor
from ptls.data_load.utils import collate_feature_dict
from ptls.frames.inference_module import InferenceModule

In [3]:
def seed_everything(seed):
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [5]:
comet_ml.login()

In [6]:
from pytorch_lightning.loggers import CometLogger

---

**Time2Vec:**

In [7]:
import torch
from ptls.data_load.padded_batch import PaddedBatch
from ptls.nn.trx_encoder.batch_norm import RBatchNorm, RBatchNormWithLens
from ptls.nn.trx_encoder.noisy_embedding import NoisyEmbedding
from ptls.nn.trx_encoder.trx_encoder_base import TrxEncoderBase
import torch.nn as nn


class Time2Vec(nn.Module):
    def __init__(self, k, interval=86400):
        super(Time2Vec, self).__init__()
        self.k = k
        self.w = nn.Parameter(torch.randn(k))
        self.b = nn.Parameter(torch.randn(k))
        self.w0 = nn.Parameter(torch.randn(1))
        self.b0 = nn.Parameter(torch.randn(1))
        self.interval = interval
        
    def forward(self, event_time, t0):
        t0_ = torch.zeros_like(event_time)
        time_diff=None
        if type(t0)!=int:
            first_column = t0[:, 0].unsqueeze(1)
            t0_ = first_column.expand(-1, t0.size(1))
        time_diff = (event_time - t0_)/self.interval
        v1 = self.w0 * time_diff.unsqueeze(-1) + self.b0
        v2 = torch.cos(self.w * time_diff.unsqueeze(-1) + self.b)
        
        return torch.cat([v1, v2], -1)

        
class TrxEncoderT2V(TrxEncoderBase):
    def __init__(self,
                 embeddings=None,
                 numeric_values=None,
                 custom_embeddings=None,
                 time_values=None,
                 embeddings_noise: float = 0,
                 norm_embeddings=None,
                 use_batch_norm=True,
                 use_batch_norm_with_lens=False,
                 clip_replace_value=None,
                 positions=None,
                 emb_dropout=0,
                 spatial_dropout=False,
                 orthogonal_init=False,
                 linear_projection_size=0,
                 out_of_index: str = 'clip',
                 k=2,
                 time_col='event_time'
                 ):
        if clip_replace_value is not None:
            warnings.warn('`clip_replace_value` attribute is deprecated. Always "clip to max" used. '
                          'Use `out_of_index="assert"` to avoid categorical values clip', DeprecationWarning)

        if positions is not None:
            warnings.warn('`positions` is deprecated. positions is not used', UserWarning)

        if embeddings is None:
            embeddings = {}
        if custom_embeddings is None:
            custom_embeddings = {}
        if time_values is None:
            time_values = {}

        noisy_embeddings = {}
        for emb_name, emb_props in embeddings.items():
            if emb_props.get('disabled', False):
                continue
            if emb_props['in'] == 0 or emb_props['out'] == 0:
                continue
            noisy_embeddings[emb_name] = NoisyEmbedding(
                num_embeddings=emb_props['in'],
                embedding_dim=emb_props['out'],
                padding_idx=0,
                max_norm=1 if norm_embeddings else None,
                noise_scale=embeddings_noise,
                dropout=emb_dropout,
                spatial_dropout=spatial_dropout,
            )

        super().__init__(
            embeddings=noisy_embeddings,
            numeric_values=numeric_values,
            custom_embeddings=custom_embeddings,
            out_of_index=out_of_index,
        )

        custom_embedding_size = self.custom_embedding_size
        if use_batch_norm and custom_embedding_size > 0:
            # :TODO: Should we use Batch norm with not-numerical custom embeddings?
            if use_batch_norm_with_lens:
                self.custom_embedding_batch_norm = RBatchNormWithLens(custom_embedding_size)
            else:
                self.custom_embedding_batch_norm = RBatchNorm(custom_embedding_size)
        else:
            self.custom_embedding_batch_norm = None
        
        self.k = k
        self.time2vec_days = Time2Vec(k=self.k)
        self.time_col = time_col
        
        if linear_projection_size > 0:
            self.linear_projection_head = torch.nn.Linear(super().output_size+k+1, linear_projection_size)
        else:
            self.linear_projection_head = None
            

        if orthogonal_init:
            for n, p in self.named_parameters():
                if n.startswith('embeddings.') and n.endswith('.weight'):
                    torch.nn.init.orthogonal_(p.data[1:])
                if n == 'linear_projection_head.weight':
                    torch.nn.init.orthogonal_(p.data)

    def forward(self, x: PaddedBatch):
        processed_embeddings = []
        processed_custom_embeddings = []

        for field_name in self.embeddings.keys():
            processed_embeddings.append(self.get_category_embeddings(x, field_name))
        
        for field_name in self.custom_embeddings.keys():
            processed_custom_embeddings.append(self.get_custom_embeddings(x, field_name))

        if len(processed_custom_embeddings):
            processed_custom_embeddings = torch.cat(processed_custom_embeddings, dim=2)
            if self.custom_embedding_batch_norm is not None:
                processed_custom_embeddings = PaddedBatch(processed_custom_embeddings, x.seq_lens)
                processed_custom_embeddings = self.custom_embedding_batch_norm(processed_custom_embeddings)
                processed_custom_embeddings = processed_custom_embeddings.payload
            processed_embeddings.append(processed_custom_embeddings)

        out = torch.cat(processed_embeddings, dim=2)

        time_encoded_days = self.time2vec_days(x.payload[self.time_col], x.payload[self.time_col])
        out = torch.cat((out, time_encoded_days), dim=2)

        if self.linear_projection_head is not None:
            out = self.linear_projection_head(out)
        return PaddedBatch(out, x.seq_lens)

    @property
    def output_size(self):
        """Returns hidden size of output representation
        """
        if self.linear_projection_head is not None:
            return self.linear_projection_head.out_features
        return super().output_size + self.k + 1

# Эксперименты.

**Данные:**

In [8]:
path_data = "https://huggingface.co/datasets/dllllb/rosbank-churn/resolve/main/train.csv.gz?download=true"
data = pd.read_csv(path_data, compression="gzip")
data

Unnamed: 0,PERIOD,cl_id,MCC,channel_type,currency,TRDATETIME,amount,trx_category,target_flag,target_sum
0,01/10/2017,0,5200,,810,21OCT17:00:00:00,5023.00,POS,0,0.0
1,01/10/2017,0,6011,,810,12OCT17:12:24:07,20000.00,DEPOSIT,0,0.0
2,01/12/2017,0,5921,,810,05DEC17:00:00:00,767.00,POS,0,0.0
3,01/10/2017,0,5411,,810,21OCT17:00:00:00,2031.00,POS,0,0.0
4,01/10/2017,0,6012,,810,24OCT17:13:14:24,36562.00,C2C_OUT,0,0.0
...,...,...,...,...,...,...,...,...,...,...
490508,01/04/2017,10176,6011,type1,810,24APR17:14:05:26,600.00,WD_ATM_ROS,1,405.0
490509,01/06/2017,10171,5411,type1,810,06JUN17:00:00:00,132.00,POS,0,0.0
490510,01/02/2017,10167,5541,type1,810,03FEB17:00:00:00,1000.00,POS,1,280428.2
490511,01/06/2017,10163,5941,type1,810,08JUN17:00:00:00,100.00,POS,0,0.0


In [9]:
target = data.groupby(by="cl_id").first().reset_index()[["cl_id", "target_flag"]]
target

Unnamed: 0,cl_id,target_flag
0,0,0
1,1,0
2,5,1
3,9,0
4,10,0
...,...,...
4995,10210,1
4996,10212,0
4997,10213,0
4998,10214,0


In [10]:
data.drop(columns=["PERIOD", "target_flag", "target_sum"], inplace=True)

In [11]:
target_train, target_test = train_test_split(target, test_size=0.1, stratify=target["target_flag"], random_state=42)

In [12]:
trx_data_train = pd.merge(data, target_train["cl_id"], on="cl_id", how="inner")
trx_data_test = pd.merge(data, target_test["cl_id"], on="cl_id", how="inner")

In [13]:
trx_data_train["channel_type"] = trx_data_train["channel_type"].fillna("none")
trx_data_test["channel_type"] = trx_data_test["channel_type"].fillna("none")

In [14]:
month2num = {"JAN": "/01/", "FEB": "/02/", "MAR": "/03/", "APR": "/04/", "MAY": "/05/", "JUN": "/06/",
             "JUL": "/07/", "AUG": "/08/", "SEP": "/09/", "OCT": "/10/", "NOV": "/11/", "DEC": "/12/"}

trx_data_train["TRDATETIME"] = trx_data_train["TRDATETIME"].map(lambda x: x[0:2] + month2num[x[2:5]] + x[5:7] + " " + x[8:])
trx_data_test["TRDATETIME"] = trx_data_test["TRDATETIME"].map(lambda x: x[0:2] + month2num[x[2:5]] + x[5:7] + " " + x[8:])

trx_data_train["TRDATETIME"] = pd.to_datetime(trx_data_train["TRDATETIME"],format='%d/%m/%y %H:%M:%S')
trx_data_test["TRDATETIME"] = pd.to_datetime(trx_data_test["TRDATETIME"],format='%d/%m/%y %H:%M:%S')

In [15]:
chtype2num = {"none": 0, "type1": 1, "type2": 2, "type3": 3, "type4": 4, "type5": 5}

trx_data_train["channel_type"] = trx_data_train["channel_type"].map(lambda x: chtype2num[x])
trx_data_test["channel_type"] = trx_data_test["channel_type"].map(lambda x: chtype2num[x])

In [16]:
trxcat2num = {"POS": 0, "DEPOSIT": 1, "WD_ATM_ROS": 2, "WD_ATM_PARTNER": 3, 
              "C2C_IN": 4, "WD_ATM_OTHER": 5, "C2C_OUT": 6, "BACK_TRX": 7,
              "CAT": 8, "CASH_ADV": 9}

trx_data_train["trx_category"] = trx_data_train["trx_category"].map(lambda x: trxcat2num[x])
trx_data_test["trx_category"] = trx_data_test["trx_category"].map(lambda x: trxcat2num[x])

---

**Квантизация непрерывных признаков (опциональный шаг, нужен только для GPT):**

In [17]:
def digitize(input_array: np.array, q_count: int = 1, bins: np.array = None):
    """Quantile-based discretization function.

    Parameters:
    -------
    input_array (np.array): Input array.
    q_count (int): Amount of quantiles. Used only if input parameter `bins` is None.
    bins (np.array):
        If None, then calculate bins as quantiles of input array,
        otherwise only apply bins to input_array. Default: None

    Returns
    -------
    out_array (np.array of ints): discretized input_array
    bins (np.array of floats):
        Returned only if input parameter `bins` is None.
    """

    if bins is None:
        return_bins = True
        bins = np.quantile(input_array, q=[i / q_count for i in range(1, q_count)], axis=0)
    else:
        return_bins = False

    out_array = np.digitize(input_array, bins)

    if return_bins:
        return out_array, bins
    else:
        return out_array

In [18]:
BINS_NUM = 128

In [19]:
numeric_features = ["amount"]

for feat in numeric_features:
    trx_data_train[feat], bins = digitize(trx_data_train[feat], q_count=BINS_NUM)
    trx_data_test[feat] = digitize(trx_data_test[feat], bins=bins)

In [20]:
import gc

gc.collect()

60

---

In [21]:
preprocessor = PandasDataPreprocessor(
    col_id="cl_id",
    col_event_time="TRDATETIME",
    event_time_transformation="dt_to_timestamp",
    cols_category=["MCC", "channel_type", "currency", "trx_category"],
    cols_numerical=["amount"],
    return_records=False,
)

In [22]:
data_train = preprocessor.fit_transform(trx_data_train)
data_test = preprocessor.transform(trx_data_test)

In [23]:
target_train.rename(columns={"target_flag": "target"}, inplace=True)
target_test.rename(columns={"target_flag": "target"}, inplace=True)
target_train.sort_values(by="cl_id", inplace=True)
target_test.sort_values(by="cl_id", inplace=True)
target_train = target_train["target"]
target_test = target_test["target"]
target_train.reset_index(drop=True, inplace=True)
target_test.reset_index(drop=True, inplace=True)

In [24]:
data_train = data_train.to_dict(orient="records")
data_test = data_test.to_dict(orient="records")

---

**Определение бинов для time diff'ов (в часах) (опциональный шаг, нужен только для TD-GPT):**

In [25]:
train_loader = inference_data_loader(data_train, num_workers=0, batch_size=128)
SECONDS_IN_HOUR = 3600
TIME_DIFF_BINS = 256

time_diffs = []

for batch in tqdm(train_loader):
    timestamps = batch.payload['event_time']
    timestamps_prev = torch.cat([timestamps[:, 0].unsqueeze(1), timestamps[:, :-1]], dim=1)
    batch.payload['time_diff'] = (timestamps - timestamps_prev) // SECONDS_IN_HOUR
    batch.payload['time_diff'][:, 0] = -1

    mask = torch.arange(batch.payload['time_diff'].shape[1], device=batch.device)[None, :] + torch.ones((batch.seq_lens.shape[0], batch.payload['time_diff'].shape[1]), device=batch.device)
    mask[mask > batch.seq_lens[:, None]] = 0.
    mask[mask > 0.] = 1.
    mask = mask.bool()

    batch.payload['time_diff'][~mask] = -1
    
    time_diffs += [batch.payload['time_diff'][batch.payload['time_diff'] != -1].numpy()]
    
time_diffs = np.concatenate(time_diffs)

time_diff_bins = np.quantile(time_diffs, q=[(i / TIME_DIFF_BINS) for i in range(1, TIME_DIFF_BINS)], axis=0)

36it [00:00, 98.07it/s] 


In [26]:
time_diff_bins

array([  0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
         0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
         0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
         0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
         0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
         0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
         0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
         0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
         0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
         0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
         0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
         0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
         0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   1.,   2.,   3.,
         4.,   5.,   6.,   7.,   7.,   8.,   9.,   

In [27]:
time_diff_bins = list(set(time_diff_bins.tolist()))
time_diff_bins.sort()
time_diff_bins = torch.tensor(time_diff_bins, dtype=torch.int)
time_diff_bins

tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
         14,  15,  16,  17,  18,  20,  22,  24,  26,  31,  35,  38,  44,  48,
         54,  62,  72,  82,  96, 114, 120, 144, 168, 216, 300, 458],
       dtype=torch.int32)

In [28]:
TIME_DIFF_BINS_NUM = len(time_diff_bins)

TIME_DIFF_BINS_NUM

40

**Тест:**

In [29]:
train_loader = inference_data_loader(data_train, num_workers=0, batch_size=128)
SECONDS_IN_HOUR = 3600

for batch in tqdm(train_loader):
    timestamps = batch.payload['event_time']
    timestamps_prev = torch.cat([timestamps[:, 0].unsqueeze(1), timestamps[:, :-1]], dim=1)
    batch.payload['time_diff'] = (timestamps - timestamps_prev) // SECONDS_IN_HOUR
    batch.payload['time_diff'][:, 0] = -1

    mask = torch.arange(batch.payload['time_diff'].shape[1], device=batch.device)[None, :] + torch.ones((batch.seq_lens.shape[0], batch.payload['time_diff'].shape[1]), device=batch.device)
    mask[mask > batch.seq_lens[:, None]] = 0.
    mask[mask > 0.] = 1.
    mask = mask.bool()

    batch.payload['time_diff'][~mask] = -1

    print(torch.bucketize(batch.payload['time_diff'], time_diff_bins, right=True))

22it [00:00, 104.74it/s]

tensor([[ 0, 37,  1,  ...,  0,  0,  0],
        [ 0,  1, 22,  ...,  0,  0,  0],
        [ 0, 12, 37,  ...,  0,  0,  0],
        ...,
        [ 0, 22,  1,  ...,  0,  0,  0],
        [ 0,  1,  1,  ...,  0,  0,  0],
        [ 0,  1, 22,  ...,  0,  0,  0]])
tensor([[ 0,  1,  1,  ...,  0,  0,  0],
        [ 0,  1, 22,  ...,  0,  0,  0],
        [ 0,  1, 22,  ...,  0,  0,  0],
        ...,
        [ 0,  2,  1,  ...,  0,  0,  0],
        [ 0,  1,  1,  ...,  0,  0,  0],
        [ 0,  1, 39,  ...,  0,  0,  0]])
tensor([[ 0, 22,  1,  ...,  0,  0,  0],
        [ 0,  1, 22,  ...,  0,  0,  0],
        [ 0, 22,  1,  ...,  0,  0,  0],
        ...,
        [ 0, 11, 21,  ...,  0,  0,  0],
        [ 0,  1, 22,  ...,  0,  0,  0],
        [ 0,  1, 32,  ...,  0,  0,  0]])
tensor([[ 0, 22, 17,  ...,  0,  0,  0],
        [ 0, 40, 39,  ...,  0,  0,  0],
        [ 0, 28,  1,  ...,  0,  0,  0],
        ...,
        [ 0,  1, 31,  ...,  0,  0,  0],
        [ 0, 40, 39,  ...,  0,  0,  0],
        [ 0, 19,  6,  ...

36it [00:00, 105.80it/s]

tensor([[ 0,  8,  1,  ...,  0,  0,  0],
        [ 0, 21, 39,  ...,  0,  0,  0],
        [ 0,  1,  1,  ...,  0,  0,  0],
        ...,
        [ 0,  1,  1,  ...,  0,  0,  0],
        [ 0,  1, 15,  ...,  0,  0,  0],
        [ 0,  1,  1,  ...,  0,  0,  0]])
tensor([[ 0, 17,  8,  ...,  0,  0,  0],
        [ 0, 22, 28,  ...,  0,  0,  0],
        [ 0,  1, 10,  ...,  0,  0,  0],
        ...,
        [ 0,  1,  1,  ...,  0,  0,  0],
        [ 0,  1, 18,  ...,  0,  0,  0],
        [ 0,  1,  1,  ...,  0,  0,  0]])
tensor([[ 0, 22,  1,  ...,  0,  0,  0],
        [ 0,  1,  1,  ...,  0,  0,  0],
        [ 0,  1,  1,  ...,  0,  0,  0],
        ...,
        [ 0,  1, 10,  ...,  0,  0,  0],
        [ 0,  1,  1,  ...,  0,  0,  0],
        [ 0,  1, 17,  ...,  0,  0,  0]])
tensor([[ 0,  1,  4,  ...,  0,  0,  0],
        [ 0,  1,  1,  ...,  0,  0,  0],
        [ 0, 12,  1,  ...,  0,  0,  0],
        ...,
        [ 0, 22, 22,  ...,  0,  0,  0],
        [ 0,  1, 16,  ...,  0,  0,  0],
        [ 0,  1,  1,  ...




---

**Window Aggregator Class:**

In [30]:
from ptls.data_load.padded_batch import PaddedBatch


class WinAggregator(TrxEncoderT2V):
    """The NN layer, a combination of TrxEncoder and Mean Aggregation within a window of #`agg_samples` transactions 
       (works like nn.Sequential([TrxEncoder, Mean Window Aggregation])).
       It is assumed that any two different windows do not overlap here.
       
       The types of the input and output are `PaddedBatch` of shapes (B, L, T) and (B, L', T) respectively, where 
       B means batch_size,
       L/L' means the max length of a sequence of transactions in a batch (the length is the same as #trx)
       T means the dimension of a single transaction.

       Parameters
        agg_samples (int):
            The number of transactions in a sliding aggregation window.

        use_window_attention (bool):
            If True, the attention layer will be applied to transactions in a sliding window before pooling.

        k (int):
            Number of periodic components in T2V time embeddings

        time_col (str):
            Name of the time column in data
            
        embeddings:
            You can find info about this param in TrxEncoder desc.
        
        numeric_values:
            You can find info about this param in TrxEncoder desc.

        embeddings_noise:
            You can find info about this param in TrxEncoder desc.
            
        emb_dropout:
            You can find info about this param in TrxEncoder desc.
            
        spatial_dropout:
            You can find info about this param in TrxEncoder desc.

        use_batch_norm:
            You can find info about this param in TrxEncoder desc.

        orthogonal_init:
            You can find info about this param in TrxEncoder desc.
            
        linear_projection_size:
            You can find info about this param in TrxEncoder desc.

        out_of_index:
            You can find info about this param in TrxEncoder desc.

        norm_embeddings:
            Keep default value for this parameter
        
        clip_replace_value:
            Not used. Keep default value for this parameter
        
        positions: 
            Not used. Keep default value for this parameter
       """

    def __init__(self,
                 agg_samples=3,
                 use_window_attention=False,
                 embeddings=None,
                 numeric_values=None,
                 custom_embeddings=None,
                 time_values=None,
                 embeddings_noise: float = 0,
                 norm_embeddings=None,
                 use_batch_norm=False,
                 use_batch_norm_with_lens=False,
                 clip_replace_value=None,
                 positions=None,
                 emb_dropout=0,
                 spatial_dropout=False,
                 orthogonal_init=False,
                 linear_projection_size=0,
                 out_of_index: str = 'clip',
                 k=2,
                 time_col='event_time'
                ):
        
        super().__init__(
            embeddings=embeddings,
            numeric_values=numeric_values,
            custom_embeddings=custom_embeddings,
            embeddings_noise=embeddings_noise,
            norm_embeddings=norm_embeddings,
            use_batch_norm=use_batch_norm,
            use_batch_norm_with_lens=use_batch_norm_with_lens,
            clip_replace_value=clip_replace_value,
            positions=positions,
            emb_dropout=emb_dropout,
            spatial_dropout=spatial_dropout,
            orthogonal_init=orthogonal_init,
            linear_projection_size=linear_projection_size,
            out_of_index=out_of_index,
            k=k,
            time_col=time_col
        )

        self.agg_samples = agg_samples

        self.use_window_attention = use_window_attention
        if self.use_window_attention:
            pass # Not Implemented

    def forward(self, pb: PaddedBatch):
        embeds = super().forward(pb)

        mask = torch.arange(embeds.payload.shape[1], device=embeds.device)[None, :] + torch.ones((embeds.seq_lens.shape[0], embeds.payload.shape[1]), device=embeds.device)
        mask[mask > embeds.seq_lens[:, None]] = 0.
        mask[mask > 0.] = 1.
        mask = mask[:, :, None]
    
        masked_embeds = embeds.payload * mask
    
        num_samples_to_add = self.agg_samples - (masked_embeds.shape[1] % self.agg_samples)  
        if num_samples_to_add < self.agg_samples:
            additional_samples = torch.zeros((masked_embeds.shape[0], num_samples_to_add, masked_embeds.shape[2]), device=masked_embeds.device)
            masked_embeds = torch.cat((masked_embeds, additional_samples), dim=1)
            mask_additional_samples = torch.zeros((mask.shape[0], num_samples_to_add, mask.shape[2]), device=mask.device)
            mask = torch.cat((mask, mask_additional_samples), dim=1)
    
        masked_embeds = torch.reshape(masked_embeds, (masked_embeds.shape[0], masked_embeds.shape[1] // self.agg_samples, self.agg_samples, masked_embeds.shape[2]))
        mask = torch.reshape(mask, (mask.shape[0], mask.shape[1] // self.agg_samples, self.agg_samples, mask.shape[2]))

        if self.use_window_attention:
            pass # Not Implemented
        
        mask = torch.sum(mask, dim=2)
        mask[mask == 0.] = 1.
    
        mean_embeds = torch.sum(masked_embeds, dim=2) / mask

        new_seq_lens = embeds.seq_lens // self.agg_samples
        div_mod_seq_lens = ((embeds.seq_lens % self.agg_samples) > 0).int()
        new_seq_lens += div_mod_seq_lens

        return PaddedBatch(mean_embeds, new_seq_lens)

---

**Train sequences lengths check:**

In [173]:
agg_encoder_params = dict(
    embeddings={
        "MCC": {"in": 342, "out": 8},
        "channel_type": {"in": 7, "out": 8},
        "currency": {"in": 60, "out": 8},
        "trx_category": {"in": 11, "out": 8}            
    },
    numeric_values={"amount": "log"},
    embeddings_noise=0.003,
    k=7,
    time_col="event_time",
    agg_samples=9, # 3, 5, 7, 9
    use_window_attention=False
)

trx_encoder = WinAggregator(**agg_encoder_params)
trx_encoder.to("cuda")

WinAggregator(
  (embeddings): ModuleDict(
    (MCC): NoisyEmbedding(
      342, 8, padding_idx=0
      (dropout): Dropout(p=0, inplace=False)
    )
    (channel_type): NoisyEmbedding(
      7, 8, padding_idx=0
      (dropout): Dropout(p=0, inplace=False)
    )
    (currency): NoisyEmbedding(
      60, 8, padding_idx=0
      (dropout): Dropout(p=0, inplace=False)
    )
    (trx_category): NoisyEmbedding(
      11, 8, padding_idx=0
      (dropout): Dropout(p=0, inplace=False)
    )
  )
  (custom_embeddings): ModuleDict(
    (amount): LogScaler()
  )
  (time2vec_days): Time2Vec()
)

In [27]:
# train_loader = inference_data_loader(data_train, num_workers=0, batch_size=128)

# trx_encoder.eval()

# min_len = np.inf
# max_len = 0

# for batch in tqdm(train_loader):
#     embeds_batch = trx_encoder(batch.to("cuda"))
#     seq_lens = embeds_batch.seq_lens
#     min_len = min(min_len, seq_lens.min())
#     max_len = max(max_len, seq_lens.max())

# print("Min Length:", min_len.item())
# print("Max Length:", max_len.item())

In [174]:
train_loader = inference_data_loader(data_train, num_workers=0, batch_size=128)

trx_encoder.eval()

seq_lens = []

for batch in tqdm(train_loader):
    embeds_batch = trx_encoder(batch.to("cuda"))
    seq_lens += [embeds_batch.seq_lens.detach().cpu().numpy()]

seq_lens = np.concatenate(seq_lens)

threshold = int(np.quantile(seq_lens, 0.75) * 0.7)

print("Max Length:", threshold)

36it [00:00, 89.45it/s]

Max Length: 11





---

# Sliding Window Aggregation (Mean Pooling) 

- **COLES:**

In [203]:
seed_everything(42)

**DataLoaders:**

In [204]:
# seq_lens:
#   if agg == 3: (1, 262) => (cnt_min=5, cnt_max=33)
#   if agg == 5: (1, 157) => (cnt_min=5, cnt_max=20)
#   if agg == 7: (1, 79) => (cnt_min=5, cnt_max=14)

data = PtlsDataModule(
    train_data=ColesDataset(
        MemoryMapDataset(
            data=data_train,
            i_filters=[SeqLenFilter(min_seq_len=10)],
        ),
        splitter=SampleSlices(
            split_count=5,
            cnt_min=5,
            cnt_max=14,
        ),
    ),
    train_num_workers=4,
    train_batch_size=128,
    valid_data=ColesDataset(
        MemoryMapDataset(
            data=data_test,
            i_filters=[SeqLenFilter(min_seq_len=10)],
        ),
        splitter=SampleSlices(
            split_count=5,
            cnt_min=5,
            cnt_max=14,
        ),
    ),
    valid_num_workers=4,
    valid_batch_size=128
)

**Модель:**

In [205]:
N_EPOCHS = 20

In [206]:
agg_encoder_params = dict(
    embeddings={
        "MCC": {"in": 342, "out": 8},
        "channel_type": {"in": 7, "out": 8},
        "currency": {"in": 60, "out": 8},
        "trx_category": {"in": 11, "out": 8}            
    },
    numeric_values={"amount": "log"},
    embeddings_noise=0.003,
    k=7,
    time_col="event_time",
    agg_samples=7, # 3, 5, 7, 9
    use_window_attention=False
)

trx_encoder = WinAggregator(**agg_encoder_params)

seq_encoder = RnnSeqEncoder(
    trx_encoder=trx_encoder,
    hidden_size=512,
    type="gru"
)

coles = CoLESModule(
    seq_encoder=seq_encoder,
    optimizer_partial=partial(torch.optim.Adam, lr=1e-3, weight_decay=0),
    lr_scheduler_partial=partial(torch.optim.lr_scheduler.CosineAnnealingLR, T_max=N_EPOCHS, eta_min=5e-6)
)

**Обучение:**

In [207]:
logger = CometLogger(project_name="evs-ssl-rb", experiment_name="CoLES_WinAgg (7 trx)")

trainer = pl.Trainer(
    logger=logger,
    max_epochs=N_EPOCHS,
    accelerator="gpu",
    devices=1,
    enable_progress_bar=True
)

In [208]:
trainer.fit(coles, data)

[1;38;5;39mCOMET INFO:[0m Experiment is live on comet.com https://www.comet.com/askoro/evs-ssl-rb/114f9bcd1b72488c9a241f41ff4ada23

[1;38;5;39mCOMET INFO:[0m Couldn't find a Git repository in '/kaggle/working' nor in any parent directory. Set `COMET_GIT_DIRECTORY` if your Git Repository is elsewhere.


Sanity Checking: |          | 0/? [00:00<?, ?it/s]


The number of training batches (33) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.



Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

[1;38;5;39mCOMET INFO:[0m ---------------------------------------------------------------------------------------
[1;38;5;39mCOMET INFO:[0m Comet.ml Experiment Summary
[1;38;5;39mCOMET INFO:[0m ---------------------------------------------------------------------------------------
[1;38;5;39mCOMET INFO:[0m   Data:
[1;38;5;39mCOMET INFO:[0m     display_summary_level : 1
[1;38;5;39mCOMET INFO:[0m     name                  : CoLES_WinAgg (9 trx)
[1;38;5;39mCOMET INFO:[0m     url                   : https://www.comet.com/askoro/evs-ssl-rb/114f9bcd1b72488c9a241f41ff4ada23
[1;38;5;39mCOMET INFO:[0m   Metrics [count] (min, max):
[1;38;5;39mCOMET INFO:[0m     loss [79]               : (134.19076538085938, 693.1944580078125)
[1;38;5;39mCOMET INFO:[0m     seq_len [13]            : (7.817187786102295, 8.112500190734863)
[1;38;5;39mCOMET INFO:[0m     valid/recall_top_k [20] : (0.2074527144432068, 0.25891241431236267)
[1;38;5;39mCOMET INFO:[0m   Others:
[1;38;5;39mCOMET INF

In [209]:
trainer.logged_metrics

{'loss': tensor(136.4910),
 'seq_len': tensor(7.9541),
 'valid/recall_top_k': tensor(0.2589)}

In [28]:
torch.save(seq_encoder.state_dict(), "coles_enc_baseline_rosbank.pt")

**Измерим качество на тесте (catboost поверх эмбеддингов):**

In [None]:
# !wget "https://drive.google.com/uc?export=download&id=1Mn8o9IPT4Zzg3946orbw1MVZwpkrBoNb" -O "coles_enc_baseline.pt"

In [210]:
encoder = coles.seq_encoder

# state_dict = torch.load("./coles_enc_baseline.pt")
# encoder.load_state_dict(state_dict)

device = "cuda:0"

encoder.to(device)

RnnSeqEncoder(
  (trx_encoder): WinAggregator(
    (embeddings): ModuleDict(
      (MCC): NoisyEmbedding(
        342, 8, padding_idx=0
        (dropout): Dropout(p=0, inplace=False)
      )
      (channel_type): NoisyEmbedding(
        7, 8, padding_idx=0
        (dropout): Dropout(p=0, inplace=False)
      )
      (currency): NoisyEmbedding(
        60, 8, padding_idx=0
        (dropout): Dropout(p=0, inplace=False)
      )
      (trx_category): NoisyEmbedding(
        11, 8, padding_idx=0
        (dropout): Dropout(p=0, inplace=False)
      )
    )
    (custom_embeddings): ModuleDict(
      (amount): LogScaler()
    )
    (time2vec_days): Time2Vec()
  )
  (seq_encoder): RnnEncoder(
    (rnn): GRU(41, 512, batch_first=True)
    (reducer): LastStepEncoder()
  )
)

In [211]:
from tqdm import tqdm

seed_everything(42)

In [212]:
train_loader = inference_data_loader(data_train, num_workers=0, batch_size=128)
encoder.eval()
train_embeds = None

for i, batch in tqdm(enumerate(train_loader)):
    train_embeds_batch = encoder(batch.to(device))
    if i == 0:
        train_embeds = train_embeds_batch.detach().cpu().numpy()
    else:
        train_embeds = np.concatenate([train_embeds, train_embeds_batch.detach().cpu().numpy()], axis=0)
    
train_embeds

36it [00:00, 51.36it/s]


array([[ 0.57936627, -0.18378884, -0.17574356, ..., -0.21710737,
        -0.06292131, -0.32704908],
       [ 0.606927  , -0.1554516 , -0.31397867, ..., -0.43650442,
        -0.00421504, -0.40625575],
       [ 0.7031255 , -0.16951774, -0.56179416, ..., -0.21568154,
         0.00166862, -0.2231641 ],
       ...,
       [ 0.3167344 , -0.34212625, -0.01795288, ..., -0.03573746,
        -0.19905591, -0.08238798],
       [ 0.41005734, -0.1501216 , -0.18155561, ..., -0.18371323,
        -0.2751087 , -0.19519602],
       [ 0.4931113 , -0.12182132, -0.16451961, ..., -0.10161787,
        -0.21656094, -0.02736257]], dtype=float32)

In [213]:
test_loader = inference_data_loader(data_test, num_workers=0, batch_size=128)
encoder.eval()
test_embeds = None

for i, batch in tqdm(enumerate(test_loader)):
    test_embeds_batch = encoder(batch.to(device))
    if i == 0:
        test_embeds = test_embeds_batch.detach().cpu().numpy()
    else:
        test_embeds = np.concatenate([test_embeds, test_embeds_batch.detach().cpu().numpy()], axis=0)
    
test_embeds

4it [00:00, 53.10it/s]


array([[ 0.60895497, -0.10272635, -0.16864115, ..., -0.29525587,
        -0.01621668, -0.3742522 ],
       [ 0.71055967, -0.0142568 , -0.37561595, ..., -0.39505777,
         0.04606154, -0.34022102],
       [ 0.7093172 , -0.22727157, -0.32106084, ..., -0.4053869 ,
         0.00886898, -0.36873755],
       ...,
       [ 0.5287702 , -0.16252929, -0.35051206, ..., -0.25566188,
        -0.272476  , -0.0036813 ],
       [ 0.47882453, -0.20817941, -0.16450028, ..., -0.18157995,
        -0.2395215 ,  0.00882795],
       [ 0.5537437 , -0.38179442, -0.5657034 , ..., -0.32667106,
        -0.42998973,  0.02866538]], dtype=float32)

In [214]:
clf = CatBoostClassifier(loss_function='MultiClass', task_type="GPU", devices='0', random_state=42)

clf.fit(train_embeds, target_train, plot_file="catboost_log.html")

Learning rate set to 0.088214
0:	learn: 0.6680129	total: 11.9ms	remaining: 11.9s
1:	learn: 0.6480058	total: 19ms	remaining: 9.49s
2:	learn: 0.6289447	total: 26.2ms	remaining: 8.7s
3:	learn: 0.6132147	total: 32.8ms	remaining: 8.17s
4:	learn: 0.5984806	total: 39.3ms	remaining: 7.82s
5:	learn: 0.5859154	total: 46.1ms	remaining: 7.64s
6:	learn: 0.5748862	total: 53.1ms	remaining: 7.54s
7:	learn: 0.5659761	total: 60.4ms	remaining: 7.49s
8:	learn: 0.5572394	total: 67.3ms	remaining: 7.41s
9:	learn: 0.5497540	total: 74.2ms	remaining: 7.34s
10:	learn: 0.5428043	total: 81ms	remaining: 7.29s
11:	learn: 0.5364427	total: 88ms	remaining: 7.24s
12:	learn: 0.5306488	total: 94.9ms	remaining: 7.21s
13:	learn: 0.5249762	total: 102ms	remaining: 7.17s
14:	learn: 0.5199758	total: 109ms	remaining: 7.15s
15:	learn: 0.5154832	total: 116ms	remaining: 7.11s
16:	learn: 0.5113913	total: 123ms	remaining: 7.08s
17:	learn: 0.5076013	total: 130ms	remaining: 7.07s
18:	learn: 0.5032816	total: 137ms	remaining: 7.07s
19:	l

<catboost.core.CatBoostClassifier at 0x78d86b987820>

In [215]:
test_pred = clf.predict(test_embeds)
test_proba = clf.predict_proba(test_embeds)[:, 1]

In [216]:
print("Accuracy:", accuracy_score(target_test, test_pred))
print("ROC-AUC:", roc_auc_score(target_test, test_proba))

Accuracy: 0.714
ROC-AUC: 0.7819526962490488


In [218]:
arr = np.array([0.7960207864531901, 0.7836525230286056, 0.7819526962490488])

arr.mean(), arr.std()

(0.7872086685769482, 0.006269631507976172)

- COLES embeds + Catboost:
  - `Accuracy: 0.736`, `0.72`, `0.722`, avg: `0.726 +- 0.0071` 
  -  `ROC-AUC: 0.8099107995661394`, `0.8041475773421184`, `0.8088423370189894`, avg: `0.8076 +- 0.0025`

---

- COLES embeds + WinAgg (3 trx) + Catboost:
  - Accuracy: `0.738`, `0.716`, `0.722`, avg: `0.7253 +- 0.0093`
  - ROC-AUC: `0.8092470576807888`, `0.8035000242832397`, `0.7907917955027439`, avg: `0.8012 +- 0.0077`

---

- COLES embeds + WinAgg (5 trx) + Catboost:
  - Accuracy: `0.736`, `0.73`, `0.732`, avg: `0.7327 +- 0.0025`
  - ROC-AUC: `0.8147998251606741`, `0.8037266678538473`, `0.809651778342588`, avg: `0.8094 +- 0.0045`

---

- COLES embeds + WinAgg (7 trx) + Catboost:
  - Accuracy: `0.736`, `0.748`, `0.734`, avg: `0.7393 +- 0.0062`
  - ROC-AUC: `0.8078386297777275`, `0.8121772352722151`, `0.802755338265529`, avg: `0.8076 +- 0.0039`

---

- COLES embeds + WinAgg (9 trx) + Catboost:
  - Accuracy: `0.722`, `0.722`, `0.714`, avg: `0.7193 +- 0.0038`
  - ROC-AUC: `0.7960207864531901`, `0.7836525230286056`, `0.7819526962490488`, avg: `0.7872 +- 0.0063`

**Вывод:** для CoLES тенденция такова, что поначалу с увеличением агрегирующего окна качество (как по accuracy, так и по ROC-AUC) либо лучше бейзлайна, либо сравнимо с ним (как по отдельным сидам, так и в среднем) - и так до размера окна, равного 7. Но с дальнейшим увеличением  агрегирующего окна, результаты резко ухудшаются, становятся хуже, чем бейзлайн. Дальше не используем окна размера 9+ - предполагаем, что при таком сильном сжатии информации подобная тенденция будет сохраняться и для других методов.

**Лучший по ROC-AUC результат:**  

- COLES embeds + WinAgg (5 trx) + Catboost:
  - Accuracy: `0.736`, `0.73`, `0.732`, avg: `0.7327 +- 0.0025`
  - ROC-AUC: `0.8147998251606741`, `0.8037266678538473`, `0.809651778342588`, avg: `0.8094 +- 0.0045`

**Лучший по accuracy результат:**  

- COLES embeds + WinAgg (7 trx) + Catboost:
  - Accuracy: `0.736`, `0.748`, `0.734`, avg: `0.7393 +- 0.0062`
  - ROC-AUC: `0.8078386297777275`, `0.8121772352722151`, `0.802755338265529`, avg: `0.8076 +- 0.0039`



---

**Train sequences lengths check:**

In [388]:
agg_encoder_params = dict(
    embeddings={
        "MCC": {"in": 342, "out": 8},
        "channel_type": {"in": 7, "out": 8},
        "currency": {"in": 60, "out": 8},
        "trx_category": {"in": 11, "out": 8}            
    },
    numeric_values={"amount": "log"},
    embeddings_noise=0.003,
    k=7,
    time_col="event_time",
    agg_samples=6, # 3, 6
    use_window_attention=False
)

trx_encoder = WinAggregator(**agg_encoder_params)
trx_encoder.to("cuda")

WinAggregator(
  (embeddings): ModuleDict(
    (MCC): NoisyEmbedding(
      342, 8, padding_idx=0
      (dropout): Dropout(p=0, inplace=False)
    )
    (channel_type): NoisyEmbedding(
      7, 8, padding_idx=0
      (dropout): Dropout(p=0, inplace=False)
    )
    (currency): NoisyEmbedding(
      60, 8, padding_idx=0
      (dropout): Dropout(p=0, inplace=False)
    )
    (trx_category): NoisyEmbedding(
      11, 8, padding_idx=0
      (dropout): Dropout(p=0, inplace=False)
    )
  )
  (custom_embeddings): ModuleDict(
    (amount): LogScaler()
  )
  (time2vec_days): Time2Vec()
)

In [391]:
train_loader = inference_data_loader(data_train, num_workers=0, batch_size=128)

trx_encoder.eval()

seq_lens = []

for batch in tqdm(train_loader):
    embeds_batch = trx_encoder(batch.to("cuda"))
    seq_lens += [embeds_batch.seq_lens.detach().cpu().numpy()]

seq_lens = np.concatenate(seq_lens)

threshold = int(np.quantile(seq_lens, 0.6))

print("Max Length:", threshold)

36it [00:00, 92.71it/s]

Max Length: 14





---

- **CPC modeling:**

In [422]:
seed_everything(42)

**DataLoaders:**

In [423]:
data = PtlsDataModule(
    train_data=CpcDataset(
        MemoryMapDataset(data=data_train),
        min_len=14,             
        max_len=18
    ),
    train_num_workers=4,
    train_batch_size=128,
    valid_data=CpcDataset(
        MemoryMapDataset(data=data_test),
        min_len=14,
        max_len=18
    ),
    valid_num_workers=4,
    valid_batch_size=128
)

**Модель:**

In [424]:
N_EPOCHS = 20

In [425]:
agg_encoder_params = dict(
    embeddings={
        "MCC": {"in": 342, "out": 32}, # 8 / 32
        "channel_type": {"in": 7, "out": 32},
        "currency": {"in": 60, "out": 32},
        "trx_category": {"in": 11, "out": 32}            
    },
    numeric_values={"amount": "log"},
    embeddings_noise=0.003,
    k=31,
    time_col="event_time",
    agg_samples=6, # 3, 6
    use_window_attention=False
)

trx_encoder = WinAggregator(**agg_encoder_params)

seq_encoder = RnnSeqEncoder(
    trx_encoder=trx_encoder,
    hidden_size=512,
    type="gru"
)

cpc = CpcModule(
    seq_encoder=seq_encoder,
    n_forward_steps=1, # 2, 1
    n_negatives=40,
    optimizer_partial=partial(torch.optim.Adam, lr=5e-4),
    lr_scheduler_partial=partial(torch.optim.lr_scheduler.StepLR, step_size=5, gamma=0.5)
)

**Обучение:**

In [426]:
logger = CometLogger(project_name="evs-ssl-rb", experiment_name="CPC_modeling_WinAgg (6 trx, emb_dim=32)")

trainer = pl.Trainer(
    logger=logger,
    max_epochs=N_EPOCHS,
    accelerator="gpu",
    devices=1,
    enable_progress_bar=True
)

In [427]:
trainer.fit(cpc, data)

[1;38;5;39mCOMET INFO:[0m Couldn't find a Git repository in '/kaggle/working' nor in any parent directory. Set `COMET_GIT_DIRECTORY` if your Git Repository is elsewhere.
[1;38;5;39mCOMET INFO:[0m Experiment is live on comet.com https://www.comet.com/askoro/evs-ssl-rb/369e6c019cbe451e80163c14ec8ceccb



Sanity Checking: |          | 0/? [00:00<?, ?it/s]


The number of training batches (36) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.



Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

[1;38;5;39mCOMET INFO:[0m ---------------------------------------------------------------------------------------
[1;38;5;39mCOMET INFO:[0m Comet.ml Experiment Summary
[1;38;5;39mCOMET INFO:[0m ---------------------------------------------------------------------------------------
[1;38;5;39mCOMET INFO:[0m   Data:
[1;38;5;39mCOMET INFO:[0m     display_summary_level : 1
[1;38;5;39mCOMET INFO:[0m     name                  : CPC_modeling_WinAgg (6 trx, emb_dim=32)
[1;38;5;39mCOMET INFO:[0m     url                   : https://www.comet.com/askoro/evs-ssl-rb/369e6c019cbe451e80163c14ec8ceccb
[1;38;5;39mCOMET INFO:[0m   Metrics [count] (min, max):
[1;38;5;39mCOMET INFO:[0m     loss [86]               : (1.1486479043960571, 4.482651710510254)
[1;38;5;39mCOMET INFO:[0m     seq_len [14]            : (14.40625, 15.5625)
[1;38;5;39mCOMET INFO:[0m     valid/cpc_accuracy [20] : (0.2008453905582428, 0.46597546339035034)
[1;38;5;39mCOMET INFO:[0m   Others:
[1;38;5;39mCOMET INF

In [428]:
trainer.logged_metrics

{'loss': tensor(2.0498),
 'seq_len': tensor(14.1500),
 'valid/cpc_accuracy': tensor(0.4512)}

In [82]:
torch.save(seq_encoder.state_dict(), "cpc_enc_baseline_rosbank.pt")

**Измерим качество на тесте (catboost поверх эмбеддингов):**

In [None]:
# !wget "https://drive.google.com/uc?export=download&id=11j6QgNsdOSTK-GRaAJLKObDW7ehS_aqK" -O "cpc_enc_baseline_higher_trx_dim.pt"

In [429]:
encoder = cpc.seq_encoder

# state_dict = torch.load("./cpc_enc_baseline_higher_trx_dim.pt")
# encoder.load_state_dict(state_dict)

device = "cuda:0"

encoder.to(device)

RnnSeqEncoder(
  (trx_encoder): WinAggregator(
    (embeddings): ModuleDict(
      (MCC): NoisyEmbedding(
        342, 32, padding_idx=0
        (dropout): Dropout(p=0, inplace=False)
      )
      (channel_type): NoisyEmbedding(
        7, 32, padding_idx=0
        (dropout): Dropout(p=0, inplace=False)
      )
      (currency): NoisyEmbedding(
        60, 32, padding_idx=0
        (dropout): Dropout(p=0, inplace=False)
      )
      (trx_category): NoisyEmbedding(
        11, 32, padding_idx=0
        (dropout): Dropout(p=0, inplace=False)
      )
    )
    (custom_embeddings): ModuleDict(
      (amount): LogScaler()
    )
    (time2vec_days): Time2Vec()
  )
  (seq_encoder): RnnEncoder(
    (rnn): GRU(161, 512, batch_first=True)
    (reducer): LastStepEncoder()
  )
)

In [430]:
encoder.seq_encoder.is_reduce_sequence = True

In [431]:
from tqdm import tqdm

seed_everything(42)

In [432]:
train_loader = inference_data_loader(data_train, num_workers=0, batch_size=128)
encoder.eval()
train_embeds = None

for i, batch in tqdm(enumerate(train_loader)):
    train_embeds_batch = encoder(batch.to(device))
    if i == 0:
        train_embeds = train_embeds_batch.detach().cpu().numpy()
    else:
        train_embeds = np.concatenate([train_embeds, train_embeds_batch.detach().cpu().numpy()], axis=0)
    
train_embeds

36it [00:00, 46.24it/s]


array([[ 0.3993237 , -0.51909906,  0.04315634, ..., -0.1392285 ,
        -0.18663862,  0.34905386],
       [ 0.8189407 , -0.77492315,  0.2741767 , ...,  0.41815653,
         0.52734065, -0.35107693],
       [ 0.8040916 , -0.6616869 ,  0.3346707 , ...,  0.3550261 ,
         0.27005607, -0.40699565],
       ...,
       [ 0.7689617 , -0.7887827 , -0.29147953, ...,  0.6089328 ,
         0.2920785 , -0.49989182],
       [ 0.7704482 , -0.7663792 , -0.12659185, ...,  0.55371827,
         0.5413032 , -0.44167453],
       [ 0.34682375, -0.4594525 , -0.08122421, ...,  0.20601241,
        -0.03729162, -0.00693535]], dtype=float32)

In [433]:
test_loader = inference_data_loader(data_test, num_workers=0, batch_size=128)
encoder.eval()
test_embeds = None

for i, batch in tqdm(enumerate(test_loader)):
    test_embeds_batch = encoder(batch.to(device))
    if i == 0:
        test_embeds = test_embeds_batch.detach().cpu().numpy()
    else:
        test_embeds = np.concatenate([test_embeds, test_embeds_batch.detach().cpu().numpy()], axis=0)
    
test_embeds

4it [00:00, 47.71it/s]


array([[ 0.806029  , -0.794624  ,  0.19617717, ...,  0.6187804 ,
         0.4524188 , -0.6174263 ],
       [ 0.57099754, -0.5680157 , -0.06303531, ...,  0.157916  ,
         0.07843278,  0.11238448],
       [ 0.78264284, -0.78219396,  0.07278191, ...,  0.540685  ,
         0.4575965 , -0.36384177],
       ...,
       [ 0.5912965 , -0.64527106, -0.11542716, ...,  0.30719984,
         0.35629836, -0.28690526],
       [ 0.676535  , -0.78397256, -0.32664165, ...,  0.5495547 ,
         0.32551172, -0.6116832 ],
       [ 0.7717395 , -0.83962274,  0.02277143, ...,  0.65164524,
         0.6138645 , -0.5330398 ]], dtype=float32)

In [434]:
clf = CatBoostClassifier(loss_function='MultiClass', task_type="GPU", devices='0', random_state=42)

clf.fit(train_embeds, target_train, plot_file="catboost_log.html")

Learning rate set to 0.088214
0:	learn: 0.6658668	total: 9.86ms	remaining: 9.85s
1:	learn: 0.6422692	total: 16.7ms	remaining: 8.31s
2:	learn: 0.6222961	total: 23.2ms	remaining: 7.73s
3:	learn: 0.6045720	total: 29.6ms	remaining: 7.37s
4:	learn: 0.5895609	total: 35.9ms	remaining: 7.14s
5:	learn: 0.5765608	total: 42.1ms	remaining: 6.98s
6:	learn: 0.5647210	total: 48.8ms	remaining: 6.92s
7:	learn: 0.5543978	total: 55.4ms	remaining: 6.87s
8:	learn: 0.5454646	total: 62.1ms	remaining: 6.84s
9:	learn: 0.5373023	total: 68.6ms	remaining: 6.79s
10:	learn: 0.5302237	total: 75ms	remaining: 6.74s
11:	learn: 0.5227453	total: 81.4ms	remaining: 6.7s
12:	learn: 0.5164884	total: 87.7ms	remaining: 6.66s
13:	learn: 0.5108631	total: 94.2ms	remaining: 6.63s
14:	learn: 0.5057859	total: 101ms	remaining: 6.61s
15:	learn: 0.5012138	total: 107ms	remaining: 6.59s
16:	learn: 0.4971354	total: 114ms	remaining: 6.56s
17:	learn: 0.4937541	total: 120ms	remaining: 6.54s
18:	learn: 0.4901459	total: 126ms	remaining: 6.53s


<catboost.core.CatBoostClassifier at 0x78d838579f00>

In [435]:
test_pred = clf.predict(test_embeds)
test_proba = clf.predict_proba(test_embeds)[:, 1]

In [436]:
print("Accuracy:", accuracy_score(target_test, test_pred))
print("ROC-AUC:", roc_auc_score(target_test, test_proba))

Accuracy: 0.72
ROC-AUC: 0.8043580320862542


In [438]:
arr = np.array([0.795988408800246, 0.8004079584270936, 0.8043580320862542])

arr.mean(), arr.std()

(0.8002514664378646, 0.003418675746975949)

- CPC context embeds w/ Aug + Catboost (dim of trx embeds: 32):
  - `Accuracy: 0.752`, `0.748`, `0.742`, avg: `0.7473 +- 0.0041`
  - `ROC-AUC: 0.8051836622363244`, `0.8137313626135242`, `0.810639296757378`, avg: `0.8099 +- 0.0035`

---

- CPC context embeds (2 forward steps) + WinAgg (3 trx) + Catboost:
  - `Accuracy: 0.738`, `0.754`, `0.752`, avg: `0.748 +- 0.0071`
  - `ROC-AUC: 0.8033381360185199`, `0.817033883213806`, `0.8130352430752295`, avg: `0.8111 +- 0.0058`

---

- CPC context embeds (1 forward step) + WinAgg (6 trx) + Catboost:
  - `Accuracy: 0.726`, `0.736`, `0.72`, avg: `0.7273 +- 0.0066`
  - `ROC-AUC: 0.795988408800246`, `0.8004079584270936`, `0.8043580320862542`, avg: `0.8003 +- 0.0034`


**Вывод:** результаты при агрегации по 3 транзакциям выходят немного лучше, чем для бейзлайна (хотя с учётом std они в целом сравнимы).

При агрегации по 6 транзакциям результаты выходят значительно хуже, чем в остальных случаях.

---

**Результаты для CPC с меньшей размерностью embed_dim (8):**

- CPC context embeds w/ Aug + Catboost (dim of trx embeds: 8):
  - `Accuracy: 0.754`, `0.742`, `0.744`, avg: `0.7467 +- 0.0052`
  - `ROC-AUC: 0.8175195480079649`, `0.8197697948875686`, `0.8122096129251591`, avg: `0.8165 +- 0.0032`

---

- CPC context embeds (2 forward steps) + WinAgg (3 trx) + Catboost:
  - `Accuracy: 0.77`, `0.72`, `0.712`, avg: `0.734 +- 0.0257`
  - `ROC-AUC: 0.8152693011283614`, `0.8042770879538943`, `0.7939162390118341`, avg: `0.8045 +- 0.0087`

---

- CPC context embeds (1 forward step) + WinAgg (6 trx) + Catboost:
  - `Accuracy: 0.732`, `0.73`, `0.738`, avg: `0.7333 +- 0.0034`
  - `ROC-AUC: 0.7938352948794742`, `0.7973320813974194`, `0.7982062780269059`, avg: `0.7965 +- 0.0019`

**Вывод:** результаты при агрегации транзакций выходят хуже, чем без неё.

---

- **GPT:**

In [67]:
seed_everything(42)

**DataLoaders:**

In [68]:
data = PtlsDataModule(
    train_data=GptDataset(
        MemoryMapDataset(data=data_train),
        min_len=1000, # 85
        max_len=1200 # 105
    ),
    train_num_workers=4,
    train_batch_size=64,
    valid_data=GptDataset(
        MemoryMapDataset(data=data_test),
        min_len=1000,
        max_len=1200
    ),
    valid_num_workers=4,
    valid_batch_size=64
)

**Модель:**

In [33]:
from torchmetrics import MeanMetric
from typing import Tuple, Dict, List, Union
from torch import nn
import torch.nn.functional as F 
from ptls.nn.seq_encoder.abs_seq_encoder import AbsSeqEncoder
from ptls.nn import PBL2Norm
from ptls.data_load.padded_batch import PaddedBatch


class MeanPooling(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, pb: PaddedBatch):
        payload = pb.payload # (B, T, H)
        mask = pb.seq_len_mask.bool()
        pb_mean = payload.sum(dim=1) / mask.float().sum(dim=1, keepdim=True)
        return pb_mean


class StatPooling(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, pb: PaddedBatch):
        payload = pb.payload # (B, T, H)
        mask = pb.seq_len_mask.bool()
        inf_mask = torch.zeros_like(mask, device=mask.device).float()
        inf_mask[~mask] = -torch.inf
        
        pb_mean = payload.sum(dim=1) / mask.float().sum(dim=1, keepdim=True)
        pb_max = torch.max(payload + inf_mask.unsqueeze(-1), dim=1)[0]
        pb_stat = torch.cat((pb_mean, pb_max), dim=1)
        return pb_stat


class GPTHead(torch.nn.Module):   
    def __init__(self, input_size, n_classes, hidden_size=64, drop_p=0.1):
        super().__init__()
        self.head = nn.Sequential(
            nn.Linear(input_size, hidden_size, bias=True),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(hidden_size, n_classes)
        )
    def forward(self, x):
        x = self.head(x)
        return x


class GptPretrainWithTimeDiffsModule(pl.LightningModule):
    def __init__(self,
                 trx_encoder: torch.nn.Module,
                 seq_encoder: AbsSeqEncoder,
                 time_diffs_boundaries: torch.tensor,
                 time_diffs_bins_num: int,
                 head_hidden_size: int = 64,
                 total_steps: int = 64000,
                 seed_seq_len: int = 16,
                 max_lr: float = 0.00005,
                 weight_decay: float = 0.0,
                 pct_start: float = 0.1,
                 norm_predict: bool = False
                 ):

        super().__init__()
        self.save_hyperparameters(ignore=['trx_encoder', 'seq_encoder'])

        self.trx_encoder = trx_encoder
        self._seq_encoder = seq_encoder
        self._seq_encoder.is_reduce_sequence = False

        self.head = nn.ModuleDict()
        for col_name, noisy_emb in self.trx_encoder.embeddings.items():
            self.head[col_name] = GPTHead(input_size=self._seq_encoder.embedding_size, hidden_size=head_hidden_size, n_classes=noisy_emb.num_embeddings)

        self.head['time_diff'] = GPTHead(input_size=self._seq_encoder.embedding_size, hidden_size=head_hidden_size, n_classes=time_diffs_bins_num) 

        if self.hparams.norm_predict:
            self.fn_norm_predict = PBL2Norm()

        self.loss = nn.CrossEntropyLoss(ignore_index=0)

        self.train_gpt_loss = MeanMetric()
        self.valid_gpt_loss = MeanMetric()

        self.time_diffs_boundaries = time_diffs_boundaries
        self.SECONDS_IN_HOUR = 3600

    def forward(self, batch: PaddedBatch):
        z_trx = self.trx_encoder(batch) 
        out = self._seq_encoder(z_trx)
        if self.hparams.norm_predict:
            out = self.fn_norm_predict(out)
        return out

    def loss_gpt(self, logits, labels):
        loss = 0
        for col_name, head in self.head.items():
            y_pred = head(logits[:, self.hparams.seed_seq_len:-1, :])
            y_pred = y_pred.view(-1, y_pred.size(-1))

            y_true = labels[col_name][:, self.hparams.seed_seq_len+1:]
            y_true = torch.flatten(y_true.long())
            
            loss += self.loss(y_pred, y_true)
            
        return loss

    def training_step(self, batch, batch_idx):
        out = self.forward(batch)  # PB: B, T, H
        out = out.payload if isinstance(out, PaddedBatch) else out
        labels = batch.payload

        timestamps = labels['event_time']
        timestamps_prev = torch.cat([timestamps[:, 0].unsqueeze(1), timestamps[:, :-1]], dim=1)
        labels['time_diff'] = (timestamps - timestamps_prev) // self.SECONDS_IN_HOUR
        labels['time_diff'][:, 0] = -1

        mask = torch.arange(labels['time_diff'].shape[1], device=batch.device)[None, :] + torch.ones((batch.seq_lens.shape[0], labels['time_diff'].shape[1]), device=batch.device)
        mask[mask > batch.seq_lens[:, None]] = 0.
        mask[mask > 0.] = 1.
        mask = mask.bool()

        labels['time_diff'][~mask] = -1

        labels['time_diff'] = torch.bucketize(labels['time_diff'], self.time_diffs_boundaries.to(batch.device), right=True)
        
        loss_gpt = self.loss_gpt(out, labels)
        self.train_gpt_loss(loss_gpt)
        self.log(f'loss', loss_gpt, sync_dist=True)
        return loss_gpt

    def validation_step(self, batch, batch_idx):
        out = self.forward(batch)  # PB: B, T, H
        out = out.payload if isinstance(out, PaddedBatch) else out
        labels = batch.payload

        timestamps = labels['event_time']
        timestamps_prev = torch.cat([timestamps[:, 0].unsqueeze(1), timestamps[:, :-1]], dim=1)
        labels['time_diff'] = (timestamps - timestamps_prev) // self.SECONDS_IN_HOUR
        labels['time_diff'][:, 0] = -1

        mask = torch.arange(labels['time_diff'].shape[1], device=batch.device)[None, :] + torch.ones((batch.seq_lens.shape[0], labels['time_diff'].shape[1]), device=batch.device)
        mask[mask > batch.seq_lens[:, None]] = 0.
        mask[mask > 0.] = 1.
        mask = mask.bool()

        labels['time_diff'][~mask] = -1

        labels['time_diff'] = torch.bucketize(labels['time_diff'], self.time_diffs_boundaries.to(batch.device), right=True)
        
        loss_gpt = self.loss_gpt(out, labels)
        self.valid_gpt_loss(loss_gpt)

    def on_training_epoch_end(self):
        self.log('train loss (by epochs)', self.train_gpt_loss, prog_bar=True, logger=True, sync_dist=True, rank_zero_only=True)

    def on_validation_epoch_end(self):
        self.log('val loss (by epochs)', self.valid_gpt_loss, prog_bar=True, logger=True, sync_dist=True, rank_zero_only=True)

    def configure_optimizers(self):
        optim = torch.optim.NAdam(self.parameters(),
                                  lr=self.hparams.max_lr,
                                  weight_decay=self.hparams.weight_decay
                                 )
        
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer=optim,
            max_lr=self.hparams.max_lr,
            total_steps=self.hparams.total_steps,
            pct_start=self.hparams.pct_start,
            anneal_strategy='cos',
            cycle_momentum=False,
            div_factor=25.0,
            final_div_factor=10000.0,
            three_phase=False
        )
        
        scheduler = {'scheduler': scheduler, 'interval': 'step'}
        return [optim], [scheduler]
    
    @property
    def seq_encoder(self):
        return GPTInferenceModule(pretrained_model=self)


class GPTInferenceModule(torch.nn.Module):
    def __init__(self, pretrained_model):
        super().__init__()
        self.model = pretrained_model
        self.model.is_reduce_sequence = False
        self.mean_pooling = MeanPooling()
        self.stat_pooling = StatPooling()

    def forward(self, batch, eval_strategy="mean"):
        z_trx = self.model.trx_encoder(batch)
        out = self.model._seq_encoder(z_trx)
        out = out if isinstance(out, PaddedBatch) else PaddedBatch(out, batch.seq_lens)

        if eval_strategy == "mean":
            out = self.mean_pooling(out)
        elif eval_strategy == "stat":
            out = self.stat_pooling(out)

        if self.model.hparams.norm_predict:
            out = out / (out.pow(2).sum(dim=-1, keepdim=True) + 1e-9).pow(0.5)
        return out

In [34]:
class WinAgg_TD_GPT_PretrainModule(GptPretrainWithTimeDiffsModule):
    def __init__(self,
                 trx_encoder: torch.nn.Module,
                 seq_encoder: AbsSeqEncoder,
                 time_diffs_boundaries: torch.tensor,
                 time_diffs_bins_num: int,
                 head_hidden_size: int = 64,
                 total_steps: int = 64000,
                 seed_seq_len: int = 16,
                 max_lr: float = 0.00005,
                 weight_decay: float = 0.0,
                 pct_start: float = 0.1,
                 norm_predict: bool = False
                 ):
        super().__init__(
            trx_encoder=trx_encoder,
            seq_encoder=seq_encoder,
            time_diffs_boundaries=time_diffs_boundaries,
            time_diffs_bins_num=time_diffs_bins_num,
            head_hidden_size=head_hidden_size,
            total_steps=total_steps,
            seed_seq_len=seed_seq_len,
            max_lr=max_lr,
            weight_decay=weight_decay,
            pct_start=pct_start,
            norm_predict=norm_predict
        )
        self.agg_samples = trx_encoder.agg_samples

    def loss_gpt(self, logits, labels):
        loss = 0
        
        for col_name, head in self.head.items():
            out = head(logits[:, (self.hparams.seed_seq_len // self.agg_samples):-1, :])
            
            y_true = labels[col_name][:, (self.hparams.seed_seq_len + self.agg_samples)::self.agg_samples]
            y_true = torch.flatten(y_true.long())
            
            pred = out.reshape(-1, out.size(-1))
                    
            loss += self.loss(pred, y_true)
                
        return loss

In [35]:
class WinAgg_TD_GPT_MultiPredPretrainModule(pl.LightningModule):
    def __init__(self,
                 trx_encoder: torch.nn.Module,
                 seq_encoder: AbsSeqEncoder,
                 time_diffs_boundaries: torch.tensor,
                 time_diffs_bins_num: int,
                 head_hidden_size: int = 64,
                 total_steps: int = 64000,
                 seed_seq_len: int = 16,
                 max_lr: float = 0.00005,
                 weight_decay: float = 0.0,
                 pct_start: float = 0.1,
                 norm_predict: bool = False
                 ):
        super().__init__()
        self.save_hyperparameters(ignore=['trx_encoder', 'seq_encoder'])

        self.trx_encoder = trx_encoder
        self._seq_encoder = seq_encoder
        self._seq_encoder.is_reduce_sequence = False

        self.agg_samples = trx_encoder.agg_samples

        self.head = nn.ModuleDict()
        for col_name, noisy_emb in self.trx_encoder.embeddings.items():
            for shift in range(self.agg_samples):
                self.head[f"{col_name}, {shift}"] = GPTHead(input_size=self._seq_encoder.embedding_size, hidden_size=head_hidden_size, n_classes=noisy_emb.num_embeddings)

        for shift in range(self.agg_samples):
            self.head[f"time_diff, {shift}"] = GPTHead(input_size=self._seq_encoder.embedding_size, hidden_size=head_hidden_size, n_classes=time_diffs_bins_num) 

        if self.hparams.norm_predict:
            self.fn_norm_predict = PBL2Norm()

        self.loss = nn.CrossEntropyLoss(ignore_index=0)

        self.train_gpt_loss = MeanMetric()
        self.valid_gpt_loss = MeanMetric()

        self.time_diffs_boundaries = time_diffs_boundaries
        self.SECONDS_IN_HOUR = 3600

    def forward(self, batch: PaddedBatch):
        z_trx = self.trx_encoder(batch) 
        out = self._seq_encoder(z_trx)
        if self.hparams.norm_predict:
            out = self.fn_norm_predict(out)
        return out

    def loss_gpt(self, logits, labels):
        loss = 0
        n_obj = 0

        NUM_FEATURES = len(self.head.keys()) // self.agg_samples
        
        for key, head in self.head.items():
            col_name, shift = key.split(', ')
            shift = int(shift)
            
            out = head(logits[:, (self.hparams.seed_seq_len // self.agg_samples):-1, :])
            
            y_true = labels[col_name][:, (self.hparams.seed_seq_len + self.agg_samples + shift)::self.agg_samples]
            y_true = torch.flatten(y_true.long())

            # delete last state of pred for sequences with len not divisible by `self.agg_samples`
            if y_true.shape[0] < out.shape[0] * out.shape[1]:
                pred = out[:, :-1, :]
                pred = pred.reshape(-1, pred.size(-1))
            else:
                pred = out.reshape(-1, out.size(-1))
            
            n_obj += pred.shape[0] 
            loss += self.loss(pred, y_true) * pred.shape[0]

        n_obj //= NUM_FEATURES
        
        return loss / n_obj

    def training_step(self, batch, batch_idx):
        out = self.forward(batch)  # PB: B, T, H
        out = out.payload if isinstance(out, PaddedBatch) else out
        labels = batch.payload

        timestamps = labels['event_time']
        timestamps_prev = torch.cat([timestamps[:, 0].unsqueeze(1), timestamps[:, :-1]], dim=1)
        labels['time_diff'] = (timestamps - timestamps_prev) // self.SECONDS_IN_HOUR
        labels['time_diff'][:, 0] = -1

        mask = torch.arange(labels['time_diff'].shape[1], device=batch.device)[None, :] + torch.ones((batch.seq_lens.shape[0], labels['time_diff'].shape[1]), device=batch.device)
        mask[mask > batch.seq_lens[:, None]] = 0.
        mask[mask > 0.] = 1.
        mask = mask.bool()

        labels['time_diff'][~mask] = -1

        labels['time_diff'] = torch.bucketize(labels['time_diff'], self.time_diffs_boundaries.to(batch.device), right=True)
        
        loss_gpt = self.loss_gpt(out, labels)
        self.train_gpt_loss(loss_gpt)
        self.log(f'loss', loss_gpt, sync_dist=True)
        return loss_gpt

    def validation_step(self, batch, batch_idx):
        out = self.forward(batch)  # PB: B, T, H
        out = out.payload if isinstance(out, PaddedBatch) else out
        labels = batch.payload

        timestamps = labels['event_time']
        timestamps_prev = torch.cat([timestamps[:, 0].unsqueeze(1), timestamps[:, :-1]], dim=1)
        labels['time_diff'] = (timestamps - timestamps_prev) // self.SECONDS_IN_HOUR
        labels['time_diff'][:, 0] = -1

        mask = torch.arange(labels['time_diff'].shape[1], device=batch.device)[None, :] + torch.ones((batch.seq_lens.shape[0], labels['time_diff'].shape[1]), device=batch.device)
        mask[mask > batch.seq_lens[:, None]] = 0.
        mask[mask > 0.] = 1.
        mask = mask.bool()

        labels['time_diff'][~mask] = -1

        labels['time_diff'] = torch.bucketize(labels['time_diff'], self.time_diffs_boundaries.to(batch.device), right=True)
        
        loss_gpt = self.loss_gpt(out, labels)
        self.valid_gpt_loss(loss_gpt)

    def on_training_epoch_end(self):
        self.log('train loss (by epochs)', self.train_gpt_loss, prog_bar=True, logger=True, sync_dist=True, rank_zero_only=True)

    def on_validation_epoch_end(self):
        self.log('val loss (by epochs)', self.valid_gpt_loss, prog_bar=True, logger=True, sync_dist=True, rank_zero_only=True)

    def configure_optimizers(self):
        optim = torch.optim.NAdam(self.parameters(),
                                  lr=self.hparams.max_lr,
                                  weight_decay=self.hparams.weight_decay
                                 )
        
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer=optim,
            max_lr=self.hparams.max_lr,
            total_steps=self.hparams.total_steps,
            pct_start=self.hparams.pct_start,
            anneal_strategy='cos',
            cycle_momentum=False,
            div_factor=25.0,
            final_div_factor=10000.0,
            three_phase=False
        )
        
        scheduler = {'scheduler': scheduler, 'interval': 'step'}
        return [optim], [scheduler]
    
    @property
    def seq_encoder(self):
        return GPTInferenceModule(pretrained_model=self)

In [36]:
class WinAgg_TD_GPT_MultiLabelPretrainModule(GptPretrainWithTimeDiffsModule):
    def __init__(self,
                 trx_encoder: torch.nn.Module,
                 seq_encoder: AbsSeqEncoder,
                 time_diffs_boundaries: torch.tensor,
                 time_diffs_bins_num: int,
                 head_hidden_size: int = 64,
                 total_steps: int = 64000,
                 seed_seq_len: int = 16,
                 max_lr: float = 0.00005,
                 weight_decay: float = 0.0,
                 pct_start: float = 0.1,
                 norm_predict: bool = False
                 ):
        super().__init__(
            trx_encoder=trx_encoder,
            seq_encoder=seq_encoder,
            time_diffs_boundaries=time_diffs_boundaries,
            time_diffs_bins_num=time_diffs_bins_num,
            head_hidden_size=head_hidden_size,
            total_steps=total_steps,
            seed_seq_len=seed_seq_len,
            max_lr=max_lr,
            weight_decay=weight_decay,
            pct_start=pct_start,
            norm_predict=norm_predict
        )
        self.agg_samples = trx_encoder.agg_samples
        self.loss = nn.MultiLabelSoftMarginLoss()

    def loss_gpt(self, logits, labels):
        loss = 0
        
        for col_name, head in self.head.items():
            pred = head(logits[:, (self.hparams.seed_seq_len // self.agg_samples):-1, :])

            ohe_labels = torch.zeros((pred.shape[0] * pred.shape[1], pred.shape[2]), device=pred.device)
            
            for shift in range(self.agg_samples):
                y_true = labels[col_name][:, (self.hparams.seed_seq_len + self.agg_samples + shift)::self.agg_samples]
                y_true = torch.flatten(y_true.long())
                ohe_labels_part = F.one_hot(y_true, num_classes=pred.shape[2])
                
                if ohe_labels_part.shape[0] < pred.shape[0] * pred.shape[1]:
                    padding = torch.zeros((pred.shape[0], 1, pred.shape[2]), device=ohe_labels_part.device)
                    ohe_labels_part = torch.cat((ohe_labels_part.reshape(pred.shape[0], pred.shape[1] - 1, pred.shape[2]), padding), dim=1).reshape(pred.shape[0] * pred.shape[1], pred.shape[2])
                
                ohe_labels += ohe_labels_part

            ohe_labels[ohe_labels > 1] = 1
            
            pred = pred.reshape(-1, pred.size(-1))

            loss += self.loss(pred, ohe_labels)
                
        return loss

In [69]:
N_EPOCHS = 20

In [70]:
agg_encoder_params = dict(
    embeddings_noise=0.003,
    embeddings={
        "MCC": {"in": 342, "out": 16},
        "channel_type": {"in": 7, "out": 16},
        "currency": {"in": 60, "out": 16},
        "trx_category": {"in": 11, "out": 16},
        "amount": {"in": BINS_NUM, "out": 16}
    },
    k=15,
    time_col="event_time",
    agg_samples=8, # 4, 8
    use_window_attention=False
)

trx_encoder = WinAggregator(**agg_encoder_params)

seq_encoder = GptEncoder(
    n_embd=trx_encoder.output_size,
    n_layer=6,
    n_head=6,
    n_inner=256,
    activation_function="gelu_new",
    resid_pdrop=0.1,
    embd_pdrop=0.1,
    attn_pdrop=0.1,
    n_positions=2048,
    use_positional_encoding=True,
    use_start_random_shift=True,
    is_reduce_sequence=False
)

gpt = WinAgg_TD_GPT_MultiLabelPretrainModule(
    trx_encoder=trx_encoder,
    seq_encoder=seq_encoder,
    time_diffs_boundaries=time_diff_bins,
    time_diffs_bins_num=(TIME_DIFF_BINS_NUM + 1), # (boundaries num) + (1 before the first boundary (OOD)) 
    head_hidden_size=256,
    total_steps=(N_EPOCHS * 71), # num_epochs * num_steps_per_epoch
    seed_seq_len=16,
    max_lr=3e-3,
    weight_decay=3e-4,
    pct_start=0.1,
    norm_predict=False
)

**Обучение:**

In [71]:
logger = CometLogger(project_name="evs-ssl-rb", experiment_name="TD-GPT_modeling_WinAgg (multilabel, 8 trx)")

trainer = pl.Trainer(
    logger=logger,
    max_epochs=N_EPOCHS,
    accelerator="gpu",
    devices=1,
    enable_progress_bar=True
)

In [72]:
trainer.fit(gpt, data)

[1;38;5;39mCOMET INFO:[0m Experiment is live on comet.com https://www.comet.com/askoro/evs-ssl-rb/8d97ec10a8fe431ab65454f6bed05bf2

[1;38;5;39mCOMET INFO:[0m Couldn't find a Git repository in '/kaggle/working' nor in any parent directory. Set `COMET_GIT_DIRECTORY` if your Git Repository is elsewhere.


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

[1;38;5;39mCOMET INFO:[0m ---------------------------------------------------------------------------------------
[1;38;5;39mCOMET INFO:[0m Comet.ml Experiment Summary
[1;38;5;39mCOMET INFO:[0m ---------------------------------------------------------------------------------------
[1;38;5;39mCOMET INFO:[0m   Data:
[1;38;5;39mCOMET INFO:[0m     display_summary_level : 1
[1;38;5;39mCOMET INFO:[0m     name                  : TD-GPT_modeling_WinAgg (multilabel, 8 trx)
[1;38;5;39mCOMET INFO:[0m     url                   : https://www.comet.com/askoro/evs-ssl-rb/8d97ec10a8fe431ab65454f6bed05bf2
[1;38;5;39mCOMET INFO:[0m   Metrics [count] (min, max):
[1;38;5;39mCOMET INFO:[0m     loss [170]                : (0.08528366684913635, 4.183159351348877)
[1;38;5;39mCOMET INFO:[0m     val loss (by epochs) [20] : (0.18971876800060272, 0.31538882851600647)
[1;38;5;39mCOMET INFO:[0m   Others:
[1;38;5;39mCOMET INFO:[0m     Name : TD-GPT_modeling_WinAgg (multilabel, 8 trx)
[1;38;5

In [73]:
trainer.logged_metrics

{'loss': tensor(0.3211), 'val loss (by epochs)': tensor(0.1898)}

In [74]:
encoder = gpt.seq_encoder

In [135]:
torch.save(encoder.state_dict(), "gpt_baseline_rosbank.pt")

**Измерим качество на тесте (catboost поверх эмбеддингов):**

In [102]:
# import gdown

# gdown.download("https://drive.google.com/uc?export=download&id=1YBstN7hpEIREo7zORmPoEZ_0NyBgfjm6", "gpt_baseline_NAdam.pt")

Downloading...
From (original): https://drive.google.com/uc?export=download&id=1YBstN7hpEIREo7zORmPoEZ_0NyBgfjm6
From (redirected): https://drive.google.com/uc?export=download&id=1YBstN7hpEIREo7zORmPoEZ_0NyBgfjm6&confirm=t&uuid=b0f44bc3-b84b-425c-968f-016e419987af
To: /kaggle/working/gpt_baseline_NAdam.pt
100%|██████████| 34.7M/34.7M [00:00<00:00, 83.5MB/s]


'gpt_baseline_NAdam.pt'

In [75]:
# state_dict = torch.load("./gpt_baseline_NAdam.pt")
# encoder.load_state_dict(state_dict)

device = "cuda:0"

encoder.to(device)

GPTInferenceModule(
  (model): WinAgg_TD_GPT_MultiLabelPretrainModule(
    (trx_encoder): WinAggregator(
      (embeddings): ModuleDict(
        (MCC): NoisyEmbedding(
          342, 16, padding_idx=0
          (dropout): Dropout(p=0, inplace=False)
        )
        (channel_type): NoisyEmbedding(
          7, 16, padding_idx=0
          (dropout): Dropout(p=0, inplace=False)
        )
        (currency): NoisyEmbedding(
          60, 16, padding_idx=0
          (dropout): Dropout(p=0, inplace=False)
        )
        (trx_category): NoisyEmbedding(
          11, 16, padding_idx=0
          (dropout): Dropout(p=0, inplace=False)
        )
        (amount): NoisyEmbedding(
          128, 16, padding_idx=0
          (dropout): Dropout(p=0, inplace=False)
        )
      )
      (custom_embeddings): ModuleDict()
      (time2vec_days): Time2Vec()
    )
    (_seq_encoder): GptEncoder(
      (transf): GPT2Model(
        (wte): Embedding(4, 96)
        (wpe): Embedding(2048, 96)
        (dro

In [76]:
from tqdm import tqdm

seed_everything(42)

In [77]:
train_loader = inference_data_loader(data_train, num_workers=0, batch_size=8)
encoder.eval()
train_embeds = None

for i, batch in tqdm(enumerate(train_loader)):
    train_embeds_batch = encoder(batch.to(device), eval_strategy="stat")
    if i == 0:
        train_embeds = train_embeds_batch.detach().cpu().numpy()
    else:
        train_embeds = np.concatenate([train_embeds, train_embeds_batch.detach().cpu().numpy()], axis=0)
    
train_embeds

563it [00:04, 124.01it/s]


array([[-6.46412201e+01, -5.57692184e+01,  2.71090031e+01, ...,
         2.04409695e+00,  8.50525081e-01,  1.56406149e-01],
       [-3.53400016e+00, -2.40670896e+00,  9.26254690e-01, ...,
         3.42187762e+00,  1.97863567e+00,  7.34392181e-02],
       [-1.44118834e+01, -9.93260765e+00,  4.83481503e+00, ...,
         2.65834045e+00,  4.90296423e-01, -1.09801307e-01],
       ...,
       [-1.13449252e+00, -7.30633616e-01,  6.17217310e-02, ...,
         2.44484162e+00,  1.42213571e+00,  2.45016828e-01],
       [-5.83157659e-01, -4.29976195e-01, -1.27299637e-01, ...,
         3.03994346e+00,  1.74793279e+00,  1.17828354e-01],
       [-4.24323082e+00, -4.53340101e+00,  1.69838715e+00, ...,
         2.15181804e+00,  2.12893128e+00, -5.18252790e-01]], dtype=float32)

In [78]:
test_loader = inference_data_loader(data_test, num_workers=0, batch_size=8)
encoder.eval()
test_embeds = None

for i, batch in tqdm(enumerate(test_loader)):
    test_embeds_batch = encoder(batch.to(device), eval_strategy="stat")
    if i == 0:
        test_embeds = test_embeds_batch.detach().cpu().numpy()
    else:
        test_embeds = np.concatenate([test_embeds, test_embeds_batch.detach().cpu().numpy()], axis=0)
    
test_embeds

63it [00:00, 125.57it/s]


array([[ -1.3249881 ,  -0.67583215,   0.21625178, ...,   2.834036  ,
          1.524177  ,   0.52752304],
       [-11.209502  ,  -9.388976  ,   3.7508378 , ...,   2.7500365 ,
          0.41094542,  -0.17393509],
       [ -1.039877  ,  -0.52256167,   0.07566922, ...,   3.1003706 ,
          1.9132227 ,   0.20663895],
       ...,
       [-28.81272   , -22.847645  ,   8.823192  , ...,   2.924701  ,
          1.1828089 ,  -0.07550007],
       [ -7.576869  ,  -6.666329  ,   2.3489196 , ...,   3.3405352 ,
          0.8391173 ,  -0.29951826],
       [ -0.69672865,  -0.3774949 ,  -0.12771936, ...,   3.253129  ,
          2.791896  ,   0.48261037]], dtype=float32)

In [79]:
clf = CatBoostClassifier(loss_function='MultiClass', task_type="GPU", devices='0', random_state=42)

clf.fit(train_embeds, target_train, plot_file="catboost_log.html")

Learning rate set to 0.088214
0:	learn: 0.6674010	total: 6.64ms	remaining: 6.63s
1:	learn: 0.6452875	total: 11.4ms	remaining: 5.67s
2:	learn: 0.6261255	total: 15.9ms	remaining: 5.28s
3:	learn: 0.6096951	total: 20.5ms	remaining: 5.1s
4:	learn: 0.5967079	total: 25.2ms	remaining: 5s
5:	learn: 0.5844678	total: 29.7ms	remaining: 4.92s
6:	learn: 0.5731182	total: 34.4ms	remaining: 4.88s
7:	learn: 0.5631409	total: 39ms	remaining: 4.83s
8:	learn: 0.5549907	total: 43.5ms	remaining: 4.79s
9:	learn: 0.5474615	total: 47.8ms	remaining: 4.74s
10:	learn: 0.5404775	total: 52.3ms	remaining: 4.7s
11:	learn: 0.5344569	total: 56.7ms	remaining: 4.66s
12:	learn: 0.5292433	total: 61.2ms	remaining: 4.64s
13:	learn: 0.5238415	total: 65.6ms	remaining: 4.62s
14:	learn: 0.5190005	total: 69.9ms	remaining: 4.59s
15:	learn: 0.5146917	total: 74.2ms	remaining: 4.56s
16:	learn: 0.5110907	total: 78.6ms	remaining: 4.55s
17:	learn: 0.5075221	total: 83ms	remaining: 4.53s
18:	learn: 0.5045972	total: 87.5ms	remaining: 4.51s
1

<catboost.core.CatBoostClassifier at 0x7c5f4431e470>

In [80]:
test_pred = clf.predict(test_embeds)
test_proba = clf.predict_proba(test_embeds)[:, 1]

In [81]:
print("Accuracy:", accuracy_score(target_test, test_pred))
print("ROC-AUC:", roc_auc_score(target_test, test_proba))

Accuracy: 0.72
ROC-AUC: 0.781402276149002


In [83]:
arr = np.array([0.8011364556183321, 0.7957779540561104, 0.781402276149002])

arr.mean(), arr.std()

(0.7927722286078148, 0.00833207652928652)

- TD-GPT embeds + Catboost:
  - `Accuracy: 0.736`, `0.724`, `0.73`, avg: `0.73 +- 0.0049`
  - `ROC-AUC: 0.8034352689773517`, `0.787618785514238`, `0.7936410289618107`, avg: `0.7949 +- 0.0065`

---

- TD-GPT embeds + WinAgg (single trx pred, 4 trx window) + Catboost:
  - `Accuracy: 0.736`, `0.726`, `0.732`, avg: `0.7313 +- 0.0041`
  - `ROC-AUC: 0.7910993832057114`, `0.7937705395735862`, `0.7891243463761313`, avg: `0.7913 +- 0.0019`

---

- TD-GPT embeds + WinAgg (single trx pred, 8 trx window) + Catboost:
  - `Accuracy: 0.736`, `0.722`, `0.72`, avg: `0.726 +- 0.0071`
  - `ROC-AUC: 0.8009745673536126`, `0.7877644849524857`, `0.7801881141636043`, avg: `0.7896 +- 0.0086`

---

- TD-GPT embeds + WinAgg (multi trx pred, 4 trx window) + Catboost:
  - `Accuracy: 0.736`, `0.714`, `0.742`, avg: `0.731 +- 0.012`
  - `ROC-AUC: 0.7981739003739619`, `0.7757199980573408`, `0.8029334153567208`, avg: `0.7923 +- 0.0119`

---

- TD-GPT embeds + WinAgg (multi trx pred, 8 trx window) + Catboost:
  - `Accuracy: 0.72`, `0.672`, `0.724`, avg: `0.7053 +- 0.0236`
  - `ROC-AUC: 0.7933496300853151`, `0.7479885383108579`, `0.7824869275226238`, avg: `0.7746 +- 0.0193`

---

- TD-GPT embeds + WinAgg (multilabel, 4 trx window) + Catboost:
  - `Accuracy: 0.722`, `0.714`, `0.76`, avg: `0.732 +- 0.0201`
  - `ROC-AUC: 0.7971863819591718`, `0.7809166113548428`, `0.8089880364572374`, avg: `0.7957 +- 0.0115`

---

- TD-GPT embeds + WinAgg (multilabel, 8 trx window) + Catboost:
  - `Accuracy: 0.724`, `0.742`, `0.72`, avg: `0.7287 +- 0.0096`
  - `ROC-AUC: 0.8011364556183321`, `0.7957779540561104`, `0.781402276149002`, avg: `0.7928 +- 0.0083`

**Вывод:** Лучше всего по метрикам относительно других подходов - подход с multilabel-обучением (какие транзакции войдут в следующую сагрегированную транзакцию).

**Лучший результат - для агрегации по 4 транзакциям:**

- TD-GPT embeds + WinAgg (multilabel, 4 trx window) + Catboost:
  - `Accuracy: 0.722`, `0.714`, `0.76`, avg: `0.732 +- 0.0201`
  - `ROC-AUC: 0.7971863819591718`, `0.7809166113548428`, `0.8089880364572374`, avg: `0.7957 +- 0.0115`

# Итоги.

| Method                                |    Accuracy           | ROC-AUC         |
|---------------------------------------|-----------------------|-----------------|
| **Flattened Sequences**               | 0.67 ± 0.0046         | 0.7536 ± 0.003  |
| **GRU (+ MLP)**                       | 0.746 ± 0.0076        | 0.8148 ± 0.0037 |
| **CoLES**                             | 0.726 ± 0.0071        | 0.8076 ± 0.0025 |
| **COLES embeds + WinAgg (5 trx)**     | 0.7327 ± 0.0025       | 0.8094 ± 0.0045 |
| **COLES embeds + WinAgg (7 trx)**     | 0.7393 ± 0.0062       | 0.8076 ± 0.0039 |
| **CPC Modeling (emb_dim=32)**         | 0.747 ± 0.0041        | 0.8099 ± 0.0035 |
| **CPC Modeling (emb_dim=32) + WinAgg**| 0.748 ± 0.0071        | 0.8111 ± 0.0058 |
| **CPC Modeling (emb_dim=8)**          | 0.747 ± 0.0052        | 0.8165 ± 0.0032 |
| **CPC Modeling (emb_dim=8) + WinAgg** | 0.734 ± 0.0257        | 0.8045 ± 0.0087 |
| **TD-GPT**                            | 0.73 ± 0.0049         | 0.7949 ± 0.0065 |
| **TD-GPT + WinAgg (multilabel loss)** | 0.732 ± 0.0201        | 0.7957 ± 0.0115 |