# Импортируем необходимые библиотеки

In [1]:
!pip install pytorch-lifestream
!pip install comet_ml

Collecting pytorch-lifestream
  Downloading pytorch-lifestream-0.6.0.tar.gz (163 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m163.4/163.4 kB[0m [31m6.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting hydra-core>=1.1.2 (from pytorch-lifestream)
  Downloading hydra_core-1.3.2-py3-none-any.whl.metadata (5.5 kB)
Downloading hydra_core-1.3.2-py3-none-any.whl (154 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m154.5/154.5 kB[0m [31m9.2 MB/s[0m eta [36m0:00:00[0m
[?25hBuilding wheels for collected packages: pytorch-lifestream
  Building wheel for pytorch-lifestream (pyproject.toml) ... [?25l[?25hdone
  Created wheel for pytorch-lifestream: filename=pytorch_lifestream-0.6.0-py3-none-any.whl size=274640 sha256=774b121b85750ea9c7c63e05046647bf9e631e70cf28d67f31a13601b2a1c87e
  

In [2]:
# data preprocessing
import os
import numpy as np
import pandas as pd
import pickle

# misc
from tqdm import tqdm
from functools import partial

# logging
import comet_ml

# classical ML
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, roc_auc_score
from catboost import CatBoostClassifier

# basic deep learning libs
import torch
import pytorch_lightning as pl
import torchmetrics

# ptls
from ptls.nn import TrxEncoder, RnnSeqEncoder, TransformerEncoder, GptEncoder, Head
from ptls.frames import PtlsDataModule
from ptls.frames.coles import CoLESModule
from ptls.frames.coles import ColesDataset
from ptls.frames.coles.split_strategy import SampleSlices
from ptls.frames.cpc import CpcModule
from ptls.frames.cpc import CpcDataset
from ptls.frames.gpt import GptDataset
from ptls.frames.supervised import SeqToTargetDataset, SequenceToTarget
from ptls.data_load.datasets import MemoryMapDataset
from ptls.data_load.datasets import inference_data_loader
from ptls.frames.inference_module import InferenceModule
from ptls.data_load.iterable_processing import SeqLenFilter
from ptls.preprocessing import PandasDataPreprocessor
from ptls.data_load.utils import collate_feature_dict
from ptls.frames.inference_module import InferenceModule

In [3]:
def seed_everything(seed):
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [5]:
comet_ml.login()

In [6]:
from pytorch_lightning.loggers import CometLogger

# Эксперименты.

**Данные:**

In [7]:
path_data = "https://huggingface.co/datasets/dllllb/age-group-prediction/resolve/main/transactions_train.csv.gz?download=true"
data = pd.read_csv(path_data, compression="gzip")
data

Unnamed: 0,client_id,trans_date,small_group,amount_rur
0,33172,6,4,71.463
1,33172,6,35,45.017
2,33172,8,11,13.887
3,33172,9,11,15.983
4,33172,10,11,21.341
...,...,...,...,...
26450572,43300,727,25,7.602
26450573,43300,727,15,3.709
26450574,43300,727,1,6.448
26450575,43300,727,11,24.669


In [8]:
path_target = "https://huggingface.co/datasets/dllllb/age-group-prediction/resolve/main/train_target.csv?download=true"
target = pd.read_csv(path_target)
target

Unnamed: 0,client_id,bins
0,24662,2
1,1046,0
2,34089,2
3,34848,1
4,47076,3
...,...,...
29995,14303,1
29996,22301,2
29997,25731,0
29998,16820,3


In [9]:
target_train, target_test = train_test_split(target, test_size=0.1, stratify=target["bins"], random_state=42)

In [10]:
trx_data_train = pd.merge(data, target_train["client_id"], on="client_id", how="inner")
trx_data_test = pd.merge(data, target_test["client_id"], on="client_id", how="inner")

---

**Квантизация непрерывных признаков (опциональный шаг, нужен только для GPT):**

In [11]:
def digitize(input_array: np.array, q_count: int = 1, bins: np.array = None):
    """Quantile-based discretization function.

    Parameters:
    -------
    input_array (np.array): Input array.
    q_count (int): Amount of quantiles. Used only if input parameter `bins` is None.
    bins (np.array):
        If None, then calculate bins as quantiles of input array,
        otherwise only apply bins to input_array. Default: None

    Returns
    -------
    out_array (np.array of ints): discretized input_array
    bins (np.array of floats):
        Returned only if input parameter `bins` is None.
    """

    if bins is None:
        return_bins = True
        bins = np.quantile(input_array, q=[i / q_count for i in range(1, q_count)], axis=0)
    else:
        return_bins = False

    out_array = np.digitize(input_array, bins)

    if return_bins:
        return out_array, bins
    else:
        return out_array

In [12]:
BINS_NUM = 128

In [13]:
numeric_features = ["amount_rur"]

for feat in numeric_features:
    trx_data_train[feat], bins = digitize(trx_data_train[feat], q_count=BINS_NUM)
    trx_data_test[feat] = digitize(trx_data_test[feat], bins=bins)

In [14]:
import gc

gc.collect()

234

---

In [15]:
preprocessor = PandasDataPreprocessor(
    col_id="client_id",
    col_event_time="trans_date",
    event_time_transformation="none",
    cols_category=["small_group"],
    cols_numerical=["amount_rur"],
    return_records=False,
)

In [16]:
data_train = preprocessor.fit_transform(trx_data_train)
data_test = preprocessor.transform(trx_data_test)

In [17]:
target_train.rename(columns={"bins": "target"}, inplace=True)
target_test.rename(columns={"bins": "target"}, inplace=True)
target_train.sort_values(by="client_id", inplace=True)
target_test.sort_values(by="client_id", inplace=True)
target_train = target_train["target"]
target_test = target_test["target"]
target_train.reset_index(drop=True, inplace=True)
target_test.reset_index(drop=True, inplace=True)

In [18]:
data_train = data_train.to_dict(orient="records")
data_test = data_test.to_dict(orient="records")

---

**Window Aggregator Class:**

In [19]:
from ptls.data_load.padded_batch import PaddedBatch


class WinAggregator(TrxEncoder):
    """The NN layer, a combination of TrxEncoder and Mean Aggregation within a window of #`agg_samples` transactions 
       (works like nn.Sequential([TrxEncoder, Mean Window Aggregation])).
       It is supposed that any two different windows do not overlap here.
       
       The types of the input and output are `PaddedBatch` of shapes (B, L, T) and (B, L', T) respectively, where 
       B means batch_size,
       L/L' means the max length of a sequence of transactions in a batch (the length is the same as #trx)
       T means the dimension of a single transaction.

       Parameters
        agg_samples (int):
            The number of transactions in a sliding aggregation window.

        use_pre_agg_attention (bool):
            If True, the attention layer will be used between trx encoding step and aggregation step.

        use_window_attention (bool):
            If True, the attention layer will be applied to transactions in a sliding window before pooling.
            
        embeddings:
            You can find info about this param in TrxEncoder desc.
        
        numeric_values:
            You can find info about this param in TrxEncoder desc.

        embeddings_noise:
            You can find info about this param in TrxEncoder desc.
            
        emb_dropout:
            You can find info about this param in TrxEncoder desc.
            
        spatial_dropout:
            You can find info about this param in TrxEncoder desc.

        use_batch_norm:
            You can find info about this param in TrxEncoder desc.

        orthogonal_init:
            You can find info about this param in TrxEncoder desc.
            
        linear_projection_size:
            You can find info about this param in TrxEncoder desc.

        out_of_index:
            You can find info about this param in TrxEncoder desc.

        norm_embeddings:
            Keep default value for this parameter
        
        clip_replace_value:
            Not used. Keep default value for this parameter
        
        positions: 
            Not used. Keep default value for this parameter
       """

    def __init__(self,
                 agg_samples=3,
                 use_pre_agg_attention=False,
                 use_window_attention=False,
                 embeddings=None,
                 numeric_values=None,
                 custom_embeddings=None,
                 embeddings_noise: float = 0,
                 norm_embeddings=None,
                 use_batch_norm=False,
                 use_batch_norm_with_lens=False,
                 clip_replace_value=None,
                 positions=None,
                 emb_dropout=0,
                 spatial_dropout=False,
                 orthogonal_init=False,
                 linear_projection_size=0,
                 out_of_index: str = 'clip',
                 ):
        super().__init__(
            embeddings=embeddings,
            numeric_values=numeric_values,
            custom_embeddings=custom_embeddings,
            embeddings_noise=embeddings_noise,
            norm_embeddings=norm_embeddings,
            use_batch_norm=use_batch_norm,
            use_batch_norm_with_lens=use_batch_norm_with_lens,
            clip_replace_value=clip_replace_value,
            positions=positions,
            emb_dropout=emb_dropout,
            spatial_dropout=spatial_dropout,
            orthogonal_init=orthogonal_init,
            linear_projection_size=linear_projection_size,
            out_of_index=out_of_index
        )

        self.agg_samples = agg_samples

        self.use_pre_agg_attention = use_pre_agg_attention        
        if self.use_pre_agg_attention:
            pass # Not Implemented

        self.use_window_attention = use_window_attention
        if self.use_window_attention:
            pass # Not Implemented

    def forward(self, pb: PaddedBatch):
        embeds = super().forward(pb)

        if self.use_pre_agg_attention:
            pass # Not Implemented

        mask = torch.arange(embeds.payload.shape[1], device=embeds.device)[None, :] + torch.ones((embeds.seq_lens.shape[0], embeds.payload.shape[1]), device=embeds.device)
        mask[mask > embeds.seq_lens[:, None]] = 0.
        mask[mask > 0.] = 1.
        mask = mask[:, :, None]
    
        masked_embeds = embeds.payload * mask
    
        num_samples_to_add = self.agg_samples - (masked_embeds.shape[1] % self.agg_samples)  
        if num_samples_to_add > 0:
            additional_samples = torch.zeros((masked_embeds.shape[0], num_samples_to_add, masked_embeds.shape[2]), device=masked_embeds.device)
            masked_embeds = torch.cat((masked_embeds, additional_samples), dim=1)
            mask_additional_samples = torch.zeros((mask.shape[0], num_samples_to_add, mask.shape[2]), device=mask.device)
            mask = torch.cat((mask, mask_additional_samples), dim=1)
    
        masked_embeds = torch.reshape(masked_embeds, (masked_embeds.shape[0], masked_embeds.shape[1] // self.agg_samples, self.agg_samples, masked_embeds.shape[2]))
        mask = torch.reshape(mask, (mask.shape[0], mask.shape[1] // self.agg_samples, self.agg_samples, mask.shape[2]))

        if self.use_window_attention:
            pass # Not Implemented
        
        mask = torch.sum(mask, dim=2)
        mask[mask == 0.] = 1.
    
        mean_embeds = torch.sum(masked_embeds, dim=2) / mask

        new_seq_lens = embeds.seq_lens // self.agg_samples
        div_mod_seq_lens = ((embeds.seq_lens % self.agg_samples) > 0).int()
        new_seq_lens += div_mod_seq_lens

        return PaddedBatch(mean_embeds, new_seq_lens)

In [45]:
# seed_everything(0)

In [46]:
# device = "cuda:0"

In [47]:
# trx_encoder_params = dict(
#     embeddings_noise=0.003,
#     numeric_values={"amount_rur": "log"},
#     embeddings={
#         "trans_date": {"in": 800, "out": 16},
#         "small_group": {"in": 250, "out": 16},
#     },
# )

# trx_encoder = TrxEncoder(**trx_encoder_params).to(device)

In [None]:
# from ptls.data_load.padded_batch import PaddedBatch


# train_loader = inference_data_loader(data_train, num_workers=0, batch_size=128)
# agg_samples = 5
# use_attention = False
# trx_encoder.eval()

# for i, batch in tqdm(enumerate(train_loader)):
#     batch = batch.to(device)
        
#     embeds = trx_encoder(batch)
    
#     mask = torch.arange(embeds.payload.shape[1], device=embeds.device)[None, :] + torch.ones((embeds.seq_lens.shape[0], embeds.payload.shape[1]), device=embeds.device)
#     mask[mask > embeds.seq_lens[:, None]] = 0.
#     mask[mask > 0.] = 1.
#     mask = mask[:, :, None]
    
#     masked_embeds = embeds.payload * mask
    
#     num_samples_to_add = agg_samples - (masked_embeds.shape[1] % agg_samples)  
#     if num_samples_to_add > 0:
#         additional_samples = torch.zeros((masked_embeds.shape[0], num_samples_to_add, masked_embeds.shape[2]), device=masked_embeds.device)
#         masked_embeds = torch.cat((masked_embeds, additional_samples), dim=1)

#         mask_additional_samples = torch.zeros((mask.shape[0], num_samples_to_add, mask.shape[2]), device=mask.device)
#         mask = torch.cat((mask, mask_additional_samples), dim=1)
    
#     masked_embeds = torch.reshape(masked_embeds, (masked_embeds.shape[0], masked_embeds.shape[1] // agg_samples, agg_samples, masked_embeds.shape[2]))
#     mask = torch.reshape(mask, (mask.shape[0], mask.shape[1] // agg_samples, agg_samples, mask.shape[2]))

#     mask = torch.sum(mask, dim=2)
#     mask[mask == 0.] = 1.
    
#     mean_embeds = torch.sum(masked_embeds, dim=2) / mask

#     new_seq_lens = embeds.seq_lens // agg_samples
#     div_mod_seq_lens = ((embeds.seq_lens % agg_samples) > 0).int()
#     new_seq_lens += div_mod_seq_lens

#     out = PaddedBatch(mean_embeds, new_seq_lens)

#     if i == 0:
#         print(out.payload)

In [49]:
# seed_everything(0)

In [50]:
# agg_encoder_params = dict(
#     embeddings_noise=0.003,
#     numeric_values={"amount_rur": "log"},
#     embeddings={
#         "trans_date": {"in": 800, "out": 16},
#         "small_group": {"in": 250, "out": 16},
#     },
#     agg_samples=5,
#     use_pre_agg_attention=False,
#     use_window_attention=False
# )

# trx_encoder = WinAggregator(**agg_encoder_params)

In [51]:
# trx_encoder.to(device)

WinAggregator(
  (embeddings): ModuleDict(
    (trans_date): NoisyEmbedding(
      800, 16, padding_idx=0
      (dropout): Dropout(p=0, inplace=False)
    )
    (small_group): NoisyEmbedding(
      250, 16, padding_idx=0
      (dropout): Dropout(p=0, inplace=False)
    )
  )
  (custom_embeddings): ModuleDict(
    (amount_rur): LogScaler()
  )
)

In [None]:
# train_loader = inference_data_loader(data_train, num_workers=0, batch_size=128)
# trx_encoder.eval()

# for i, batch in tqdm(enumerate(train_loader)):
#     batch = batch.to(device)    
#     embeds = trx_encoder(batch)

#     if i == 0:
#         print(embeds.payload)

# Sliding Window Aggregation (Mean Pooling) 

- **COLES:**

In [16]:
seed_everything(0)

**DataLoaders:**

In [17]:
data = PtlsDataModule(
    train_data=ColesDataset(
        MemoryMapDataset(
            data=data_train,
            i_filters=[SeqLenFilter(min_seq_len=25)],
        ),
        splitter=SampleSlices(
            split_count=5,
            cnt_min=30,
            cnt_max=190,
        ),
    ),
    train_num_workers=4,
    train_batch_size=128,
    valid_data=ColesDataset(
        MemoryMapDataset(
            data=data_test,
            i_filters=[SeqLenFilter(min_seq_len=25)],
        ),
        splitter=SampleSlices(
            split_count=5,
            cnt_min=30,
            cnt_max=190,
        ),
    ),
    valid_num_workers=4,
    valid_batch_size=128
)

**Модель:**

In [18]:
N_EPOCHS = 20

In [19]:
agg_encoder_params = dict(
    embeddings_noise=0.003,
    numeric_values={"amount_rur": "log"},
    embeddings={
        "trans_date": {"in": 800, "out": 16},
        "small_group": {"in": 250, "out": 16},
    },
    agg_samples=20,
    use_pre_agg_attention=False,
    use_window_attention=False
)

trx_encoder = WinAggregator(**agg_encoder_params)

seq_encoder = RnnSeqEncoder(
    trx_encoder=trx_encoder,
    hidden_size=512,
    type="gru",
)

coles = CoLESModule(
    seq_encoder=seq_encoder,
    optimizer_partial=partial(torch.optim.Adam, lr=3e-3, weight_decay=5e-4),
    lr_scheduler_partial=partial(torch.optim.lr_scheduler.CosineAnnealingLR, T_max=N_EPOCHS, eta_min=5e-5)
)

**Обучение:**

In [20]:
logger = CometLogger(project_name="EvS_SSL", experiment_name="CoLES_WinAgg")

trainer = pl.Trainer(
    logger=logger,
    max_epochs=N_EPOCHS,
    accelerator="gpu",
    devices=1,
    enable_progress_bar=True
)

In [21]:
trainer.fit(coles, data)

[1;38;5;39mCOMET INFO:[0m Experiment is live on comet.com https://www.comet.com/askoro/evs-ssl/d944aa2ac9164fd18f6495d36e136ff1

[1;38;5;39mCOMET INFO:[0m Couldn't find a Git repository in '/kaggle/working' nor in any parent directory. Set `COMET_GIT_DIRECTORY` if your Git Repository is elsewhere.


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

[1;38;5;39mCOMET INFO:[0m ---------------------------------------------------------------------------------------
[1;38;5;39mCOMET INFO:[0m Comet.ml Experiment Summary
[1;38;5;39mCOMET INFO:[0m ---------------------------------------------------------------------------------------
[1;38;5;39mCOMET INFO:[0m   Data:
[1;38;5;39mCOMET INFO:[0m     display_summary_level : 1
[1;38;5;39mCOMET INFO:[0m     name                  : CoLES_WinAgg
[1;38;5;39mCOMET INFO:[0m     url                   : https://www.comet.com/askoro/evs-ssl/d944aa2ac9164fd18f6495d36e136ff1
[1;38;5;39mCOMET INFO:[0m   Metrics [count] (min, max):
[1;38;5;39mCOMET INFO:[0m     loss [506]              : (181.281005859375, 683.1492919921875)
[1;38;5;39mCOMET INFO:[0m     seq_len [84]            : (105.6421890258789, 115.37187957763672)
[1;38;5;39mCOMET INFO:[0m     valid/recall_top_k [20] : (0.4913947284221649, 0.6664039492607117)
[1;38;5;39mCOMET INFO:[0m   Others:
[1;38;5;39mCOMET INFO:[0m     Na

In [55]:
trainer.logged_metrics

{'loss': tensor(174.7807),
 'seq_len': tensor(112.2350),
 'valid/recall_top_k': tensor(0.6652)}

In [56]:
torch.save(seq_encoder.state_dict(), "coles_enc_win_agg.pt")

**Измерим качество на тесте (catboost поверх эмбеддингов):**

In [57]:
encoder = coles.seq_encoder

device = "cuda:0"

encoder.to(device)

RnnSeqEncoder(
  (trx_encoder): WinAggregator(
    (embeddings): ModuleDict(
      (trans_date): NoisyEmbedding(
        800, 16, padding_idx=0
        (dropout): Dropout(p=0, inplace=False)
      )
      (small_group): NoisyEmbedding(
        250, 16, padding_idx=0
        (dropout): Dropout(p=0, inplace=False)
      )
    )
    (custom_embeddings): ModuleDict(
      (amount_rur): LogScaler()
    )
  )
  (seq_encoder): RnnEncoder(
    (rnn): GRU(33, 512, batch_first=True)
    (reducer): LastStepEncoder()
  )
)

In [58]:
from tqdm import tqdm

seed_everything(0)

In [59]:
train_loader = inference_data_loader(data_train, num_workers=0, batch_size=128)
encoder.eval()
train_embeds = None

for i, batch in tqdm(enumerate(train_loader)):
    train_embeds_batch = encoder(batch.to(device))
    if i == 0:
        train_embeds = train_embeds_batch.detach().cpu().numpy()
    else:
        train_embeds = np.concatenate([train_embeds, train_embeds_batch.detach().cpu().numpy()], axis=0)
    
train_embeds

211it [00:05, 35.41it/s]


array([[ 0.00452117, -0.02827387,  0.9337678 , ...,  0.00495588,
         0.0193441 ,  0.09391169],
       [ 0.04560404, -0.03634291,  0.8987108 , ..., -0.04580244,
        -0.00496499,  0.15117761],
       [ 0.03072902, -0.13078159,  0.99990016, ..., -0.09737144,
        -0.0161845 ,  0.968005  ],
       ...,
       [ 0.00750701, -0.00894194,  0.9967334 , ...,  0.01608533,
         0.00373627,  0.6375581 ],
       [ 0.0071538 ,  0.0496349 ,  0.99902135, ..., -0.03007819,
        -0.02328188,  0.16292547],
       [-0.02030775, -0.06471202,  0.9840942 , ..., -0.03953908,
         0.00458581,  0.41498113]], dtype=float32)

In [60]:
test_loader = inference_data_loader(data_test, num_workers=0, batch_size=128)
encoder.eval()
test_embeds = None

for i, batch in tqdm(enumerate(test_loader)):
    test_embeds_batch = encoder(batch.to(device))
    if i == 0:
        test_embeds = test_embeds_batch.detach().cpu().numpy()
    else:
        test_embeds = np.concatenate([test_embeds, test_embeds_batch.detach().cpu().numpy()], axis=0)
    
test_embeds

24it [00:00, 52.54it/s]


array([[-0.00913762, -0.0365316 ,  0.9914088 , ...,  0.06454185,
        -0.01362646,  0.36703894],
       [-0.01873677, -0.1018367 ,  0.9499835 , ..., -0.01761043,
         0.06090444,  0.62053996],
       [ 0.01613991, -0.04726062,  0.58812404, ..., -0.04554856,
         0.12404823, -0.04990862],
       ...,
       [ 0.06967799, -0.10422407,  0.9955157 , ...,  0.16319045,
        -0.20528884, -0.62589127],
       [ 0.01730782, -0.03003943,  0.99920636, ...,  0.02860174,
        -0.01038891,  0.9474452 ],
       [-0.01291906, -0.03034842,  0.88051796, ...,  0.00866131,
         0.09558208,  0.8824644 ]], dtype=float32)

In [61]:
clf = CatBoostClassifier(loss_function='MultiClass', task_type="GPU", devices='0')

clf.fit(train_embeds, target_train, plot_file="catboost_log.html")

Learning rate set to 0.12714
0:	learn: 1.3141301	total: 17.6ms	remaining: 17.5s
1:	learn: 1.2588322	total: 28.3ms	remaining: 14.1s
2:	learn: 1.2151895	total: 39.4ms	remaining: 13.1s
3:	learn: 1.1806934	total: 51.2ms	remaining: 12.8s
4:	learn: 1.1523811	total: 63ms	remaining: 12.5s
5:	learn: 1.1282296	total: 74.5ms	remaining: 12.3s
6:	learn: 1.1078615	total: 85.9ms	remaining: 12.2s
7:	learn: 1.0897215	total: 97.2ms	remaining: 12.1s
8:	learn: 1.0743169	total: 108ms	remaining: 11.9s
9:	learn: 1.0604403	total: 119ms	remaining: 11.8s
10:	learn: 1.0485289	total: 131ms	remaining: 11.8s
11:	learn: 1.0382023	total: 143ms	remaining: 11.8s
12:	learn: 1.0289413	total: 155ms	remaining: 11.8s
13:	learn: 1.0206798	total: 167ms	remaining: 11.8s
14:	learn: 1.0139099	total: 178ms	remaining: 11.7s
15:	learn: 1.0076668	total: 191ms	remaining: 11.7s
16:	learn: 1.0017263	total: 203ms	remaining: 11.7s
17:	learn: 0.9956908	total: 214ms	remaining: 11.7s
18:	learn: 0.9909684	total: 225ms	remaining: 11.6s
19:	le

<catboost.core.CatBoostClassifier at 0x7ff1abec1fc0>

In [62]:
test_pred = clf.predict(test_embeds)
test_proba = clf.predict_proba(test_embeds)

In [63]:
print("Accuracy:", accuracy_score(target_test, test_pred))
print("ROC-AUC:", roc_auc_score(target_test, test_proba, average="weighted", multi_class="ovr"))

Accuracy: 0.5623333333333334
ROC-AUC: 0.8212206273323065


- COLES embeds + Catboost:
  - `Accuracy: 0.6063333333333333`
  -  `ROC-AUC: 0.8485032660206542`

\

- COLES embeds (w/ Window Aggregation, 3 trx window) + Catboost:
  - `Accuracy: 0.5793333333333334`
  - `ROC-AUC: 0.8322195150062717`

\

- COLES embeds (w/ Window Aggregation, 5 trx window) + Catboost:
  - `Accuracy: 0.565`
  - `ROC-AUC: 0.8228954034229273`

\

- COLES embeds (w/ Window Aggregation, 20 trx window) + Catboost:
  - `Accuracy: 0.5623333333333334`
  - `ROC-AUC: 0.8212206273323065`

---

- **CPC modeling:**

In [16]:
seed_everything(0)

**DataLoaders:**

In [17]:
data = PtlsDataModule(
    train_data=CpcDataset(
        MemoryMapDataset(data=data_train),
        min_len=1000,
        max_len=1200
    ),
    train_num_workers=4,
    train_batch_size=64,
    valid_data=CpcDataset(
        MemoryMapDataset(data=data_test),
        min_len=1000,
        max_len=1200
    ),
    valid_num_workers=4,
    valid_batch_size=64
)

**Модель:**

In [18]:
N_EPOCHS = 15

In [19]:
agg_encoder_params = dict(
    embeddings_noise=0.003,
    numeric_values={"amount_rur": "log"},
    embeddings={
        "trans_date": {"in": 800, "out": 128},
        "small_group": {"in": 250, "out": 128},
    },
    agg_samples=20,
    use_pre_agg_attention=False,
    use_window_attention=False
)

trx_encoder = WinAggregator(**agg_encoder_params)

seq_encoder = RnnSeqEncoder(
    trx_encoder=trx_encoder,
    hidden_size=512,
    type="gru"
)

cpc = CpcModule(
    seq_encoder=seq_encoder,
    n_forward_steps=6,
    n_negatives=40,
    optimizer_partial=partial(torch.optim.Adam, lr=2e-3),
    lr_scheduler_partial=partial(torch.optim.lr_scheduler.StepLR, step_size=5, gamma=0.5)
)

**Обучение:**

In [20]:
logger = CometLogger(project_name="EvS_SSL", experiment_name="CPC_modeling_WinAgg")

trainer = pl.Trainer(
    logger=logger,
    max_epochs=N_EPOCHS,
    accelerator="gpu",
    devices=1,
    enable_progress_bar=True
)

In [21]:
trainer.fit(cpc, data)

[1;38;5;39mCOMET INFO:[0m Experiment is live on comet.com https://www.comet.com/askoro/evs-ssl/e6d095de1b0c43eaaa2b3afb5d762a43

[1;38;5;39mCOMET INFO:[0m Couldn't find a Git repository in '/kaggle/working' nor in any parent directory. Set `COMET_GIT_DIRECTORY` if your Git Repository is elsewhere.


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

[1;38;5;39mCOMET INFO:[0m ---------------------------------------------------------------------------------------
[1;38;5;39mCOMET INFO:[0m Comet.ml Experiment Summary
[1;38;5;39mCOMET INFO:[0m ---------------------------------------------------------------------------------------
[1;38;5;39mCOMET INFO:[0m   Data:
[1;38;5;39mCOMET INFO:[0m     display_summary_level : 1
[1;38;5;39mCOMET INFO:[0m     name                  : CPC_modeling_WinAgg
[1;38;5;39mCOMET INFO:[0m     url                   : https://www.comet.com/askoro/evs-ssl/e6d095de1b0c43eaaa2b3afb5d762a43
[1;38;5;39mCOMET INFO:[0m   Metrics [count] (min, max):
[1;38;5;39mCOMET INFO:[0m     loss [759]              : (0.6028669476509094, 3.8364548683166504)
[1;38;5;39mCOMET INFO:[0m     seq_len [126]           : (834.5, 905.59375)
[1;38;5;39mCOMET INFO:[0m     valid/cpc_accuracy [15] : (0.8311789631843567, 0.8848307132720947)
[1;38;5;39mCOMET INFO:[0m   Others:
[1;38;5;39mCOMET INFO:[0m     Name : CPC_mo

In [22]:
trainer.logged_metrics

{'loss': tensor(0.8090),
 'seq_len': tensor(865.0715),
 'valid/cpc_accuracy': tensor(0.8848)}

In [23]:
torch.save(seq_encoder.state_dict(), "cpc_enc_win_agg_trx20.pt")

**Измерим качество на тесте (catboost поверх эмбеддингов):**

In [24]:
encoder = cpc.seq_encoder

device = "cuda:0"

encoder.to(device)

RnnSeqEncoder(
  (trx_encoder): WinAggregator(
    (embeddings): ModuleDict(
      (trans_date): NoisyEmbedding(
        800, 128, padding_idx=0
        (dropout): Dropout(p=0, inplace=False)
      )
      (small_group): NoisyEmbedding(
        250, 128, padding_idx=0
        (dropout): Dropout(p=0, inplace=False)
      )
    )
    (custom_embeddings): ModuleDict(
      (amount_rur): LogScaler()
    )
  )
  (seq_encoder): RnnEncoder(
    (rnn): GRU(257, 512, batch_first=True)
    (reducer): LastStepEncoder()
  )
)

In [25]:
encoder.seq_encoder.is_reduce_sequence = True

In [26]:
from tqdm import tqdm

seed_everything(0)

In [27]:
train_loader = inference_data_loader(data_train, num_workers=0, batch_size=128)
encoder.eval()
train_embeds = None

for i, batch in tqdm(enumerate(train_loader)):
    train_embeds_batch = encoder(batch.to(device))
    if i == 0:
        train_embeds = train_embeds_batch.detach().cpu().numpy()
    else:
        train_embeds = np.concatenate([train_embeds, train_embeds_batch.detach().cpu().numpy()], axis=0)
    
train_embeds

211it [00:06, 32.54it/s]


array([[ 0.1269838 ,  0.06069179, -0.8039373 , ...,  0.18366602,
        -0.08507717, -0.09044018],
       [ 0.24877743,  0.05744626, -0.9161373 , ..., -0.05227854,
        -0.2074807 ,  0.03470141],
       [ 0.54053456,  0.24899046, -0.7695131 , ..., -0.01845624,
        -0.0511656 , -0.17779744],
       ...,
       [ 0.25679582,  0.1806997 , -0.7733664 , ..., -0.5018151 ,
        -0.34194958, -0.2643399 ],
       [ 0.07263993,  0.20356496, -0.56089795, ...,  0.04252471,
        -0.16948244, -0.04190145],
       [ 0.1883829 ,  0.09145279, -0.8739188 , ...,  0.17472775,
        -0.37370485, -0.05681634]], dtype=float32)

In [28]:
test_loader = inference_data_loader(data_test, num_workers=0, batch_size=128)
encoder.eval()
test_embeds = None

for i, batch in tqdm(enumerate(test_loader)):
    test_embeds_batch = encoder(batch.to(device))
    if i == 0:
        test_embeds = test_embeds_batch.detach().cpu().numpy()
    else:
        test_embeds = np.concatenate([test_embeds, test_embeds_batch.detach().cpu().numpy()], axis=0)
    
test_embeds

24it [00:00, 47.60it/s]


array([[ 0.20589167, -0.02757158, -0.71205354, ..., -0.01409675,
        -0.23025711,  0.03673205],
       [ 0.10020506,  0.0552436 ,  0.44483566, ...,  0.09490182,
        -0.47456124, -0.05336427],
       [ 0.10850395,  0.04108413, -0.9087362 , ...,  0.7003215 ,
        -0.27659315, -0.1418961 ],
       ...,
       [ 0.2944248 , -0.09982838, -0.9190626 , ...,  0.18788023,
        -0.58525664, -0.09479044],
       [-0.00449984,  0.06880501, -0.5330772 , ..., -0.03484499,
        -0.05156473, -0.04935532],
       [ 0.0023348 , -0.09331447, -0.19638704, ...,  0.50846523,
        -0.20731081, -0.04913478]], dtype=float32)

In [29]:
clf = CatBoostClassifier(loss_function='MultiClass', task_type="GPU", devices='0')

clf.fit(train_embeds, target_train, plot_file="catboost_log.html")

Learning rate set to 0.12714
0:	learn: 1.3179146	total: 15.5s	remaining: 4h 18m 40s
1:	learn: 1.2674537	total: 15.5s	remaining: 2h 9m 18s
2:	learn: 1.2271986	total: 15.6s	remaining: 1h 26m 10s
3:	learn: 1.1943827	total: 15.6s	remaining: 1h 4m 37s
4:	learn: 1.1672764	total: 15.6s	remaining: 51m 40s
5:	learn: 1.1441641	total: 15.6s	remaining: 43m 2s
6:	learn: 1.1244565	total: 15.6s	remaining: 36m 53s
7:	learn: 1.1073006	total: 15.6s	remaining: 32m 15s
8:	learn: 1.0917326	total: 15.6s	remaining: 28m 40s
9:	learn: 1.0788413	total: 15.6s	remaining: 25m 47s
10:	learn: 1.0664829	total: 15.6s	remaining: 23m 26s
11:	learn: 1.0560190	total: 15.7s	remaining: 21m 28s
12:	learn: 1.0465190	total: 15.7s	remaining: 19m 49s
13:	learn: 1.0367745	total: 15.7s	remaining: 18m 24s
14:	learn: 1.0287337	total: 15.7s	remaining: 17m 10s
15:	learn: 1.0214178	total: 15.7s	remaining: 16m 5s
16:	learn: 1.0141051	total: 15.7s	remaining: 15m 8s
17:	learn: 1.0081807	total: 15.7s	remaining: 14m 17s
18:	learn: 1.0023099

<catboost.core.CatBoostClassifier at 0x7f9ee5b1f700>

In [30]:
test_pred = clf.predict(test_embeds)
test_proba = clf.predict_proba(test_embeds)

In [31]:
print("Accuracy:", accuracy_score(target_test, test_pred))
print("ROC-AUC:", roc_auc_score(target_test, test_proba, average="weighted", multi_class="ovr"))

Accuracy: 0.585
ROC-AUC: 0.8302730752867178


- CPC context embeds + Catboost:
  - `Accuracy: 0.5763333333333334`
  - ` ROC-AUC: 0.8252403572367512`

\

- CPC context embeds (w/ Window Aggregation, 3 trx window) + Catboost:
  - `Accuracy: 0.5756666666666667`
  - ` ROC-AUC: 0.8313551234048895`

\

- CPC context embeds (w/ Window Aggregation, 5 trx window) + Catboost:
  - `Accuracy: 0.595`
  - ` ROC-AUC: 0.8372577763355443`

\

- CPC context embeds (w/ Window Aggregation, 10 trx window) + Catboost:
  - `Accuracy: 0.5896666666666667`
  - ` ROC-AUC: 0.8336245102246188`

\

- CPC context embeds (w/ Window Aggregation, 20 trx window) + Catboost:
  - `Accuracy: 0.585`
  - ` ROC-AUC: 0.8302730752867178`

**Интересно - как в случае с агрегацией (любой такого типа) вырос ROC-AUC по сравнению с бейзлайном (без агрегации). Зависимость от размера окна следующая: сначала с увеличением размера окна качество растёт, затем достигается оптимальный размер окна - при нём качество максимально, после дальнейшего увеличения окна качество падает.**

**Здесь оптимальный размер окна равен 5 - будем далее использовать именно такое значение при агрегации с помощью `WinAggregator` для CPC.**

---

- **GPT:**

In [20]:
seed_everything(0)

**DataLoaders:**

In [21]:
data = PtlsDataModule(
    train_data=GptDataset(
        MemoryMapDataset(data=data_train),
        min_len=1000,
        max_len=1200
    ),
    train_num_workers=4,
    train_batch_size=16,
    valid_data=GptDataset(
        MemoryMapDataset(data=data_test),
        min_len=1000,
        max_len=1200
    ),
    valid_num_workers=4,
    valid_batch_size=16
)

**Модель:**

In [22]:
from torchmetrics import MeanMetric
from typing import Tuple, Dict, List, Union
from torch import nn
import torch.nn.functional as F 
from ptls.nn.seq_encoder.abs_seq_encoder import AbsSeqEncoder
from ptls.nn import PBL2Norm
from ptls.data_load.padded_batch import PaddedBatch


class MeanPooling(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, pb: PaddedBatch):
        payload = pb.payload # (B, T, H)
        mask = pb.seq_len_mask.bool()
        pb_mean = payload.sum(dim=1) / mask.float().sum(dim=1, keepdim=True)
        return pb_mean


class StatPooling(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, pb: PaddedBatch):
        payload = pb.payload # (B, T, H)
        mask = pb.seq_len_mask.bool()
        inf_mask = torch.zeros_like(mask, device=mask.device).float()
        inf_mask[~mask] = -torch.inf
        
        pb_mean = payload.sum(dim=1) / mask.float().sum(dim=1, keepdim=True)
        pb_max = torch.max(payload + inf_mask.unsqueeze(-1), dim=1)[0]
        pb_stat = torch.cat((pb_mean, pb_max), dim=1)
        return pb_stat


class GPTHead(torch.nn.Module):   
    def __init__(self, input_size, n_classes, hidden_size=64, drop_p=0.1):
        super().__init__()
        self.head = nn.Sequential(
            nn.Linear(input_size, hidden_size, bias=True),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(hidden_size, n_classes)
        )
    def forward(self, x):
        x = self.head(x)
        return x


class GptPretrainModule(pl.LightningModule):
    """GPT2 Language model

    Sequence transactions are encoded by `trx_encoder`.
    Then `seq_encoder` encodes the given sequence 
    (we actually use NN to modify sequence transactions representations,
    then (during inference) we calculate the mean of these encoded transactions to get the representation of the whole sequence).
    After this we use heads to predict the classes of features of the future transaction.

    Parameters
    ----------
    trx_encoder:
        Module for transform dict with feature sequences to sequence of transaction representations
    seq_encoder:
        Module for sequence processing. Generally this is transformer based encoder. Rnn is also possible
        Should work without sequence reduction
    head_hidden_size:
        Hidden size of heads for feature prediction
    seed_seq_len:
         Size of starting sequence without loss 
    total_steps:
        total_steps expected in OneCycle lr scheduler
    max_lr:
        max_lr of OneCycle lr scheduler
    weight_decay:
        weight_decay of Adam optimizer
    pct_start:
        % of total_steps when lr increase
    norm_predict:
        use l2 norm for transformer output or not
    """

    def __init__(self,
                 trx_encoder: torch.nn.Module,
                 seq_encoder: AbsSeqEncoder,
                 head_hidden_size: int = 64,
                 total_steps: int = 64000,
                 seed_seq_len: int = 16,
                 max_lr: float = 0.00005,
                 weight_decay: float = 0.0,
                 pct_start: float = 0.1,
                 norm_predict: bool = False
                 ):

        super().__init__()
        self.save_hyperparameters(ignore=['trx_encoder', 'seq_encoder'])

        self.trx_encoder = trx_encoder
        self._seq_encoder = seq_encoder
        self._seq_encoder.is_reduce_sequence = False

        self.head = nn.ModuleDict()
        for col_name, noisy_emb in self.trx_encoder.embeddings.items():
            self.head[col_name] = GPTHead(input_size=self._seq_encoder.embedding_size, hidden_size=head_hidden_size, n_classes=noisy_emb.num_embeddings)

        if self.hparams.norm_predict:
            self.fn_norm_predict = PBL2Norm()

        self.loss = nn.CrossEntropyLoss(ignore_index=0)

        self.train_gpt_loss = MeanMetric()
        self.valid_gpt_loss = MeanMetric()

    def forward(self, batch: PaddedBatch):
        z_trx = self.trx_encoder(batch) 
        out = self._seq_encoder(z_trx)
        if self.hparams.norm_predict:
            out = self.fn_norm_predict(out)
        return out

    def loss_gpt(self, logits, labels):
        loss = 0
        for col_name, head in self.head.items():
            y_pred = head(logits[:, self.hparams.seed_seq_len:-1, :])
            y_pred = y_pred.view(-1, y_pred.size(-1))

            y_true = labels[col_name][:, self.hparams.seed_seq_len+1:]
            y_true = torch.flatten(y_true.long())

            loss += self.loss(y_pred, y_true)
            
        return loss

    def training_step(self, batch, batch_idx):
        out = self.forward(batch)  # PB: B, T, H
        out = out.payload if isinstance(out, PaddedBatch) else out
        labels = batch.payload

        loss_gpt = self.loss_gpt(out, labels)
        self.train_gpt_loss(loss_gpt)
        self.log('loss', loss_gpt, sync_dist=True)
        return loss_gpt

    def validation_step(self, batch, batch_idx):
        out = self.forward(batch)  # PB: B, T, H
        out = out.payload if isinstance(out, PaddedBatch) else out
        labels = batch.payload

        loss_gpt = self.loss_gpt(out, labels)
        self.valid_gpt_loss(loss_gpt)

    def on_training_epoch_end(self):
        self.log('train loss (by epochs)', self.train_gpt_loss, prog_bar=True, logger=True, sync_dist=True, rank_zero_only=True)

    def on_validation_epoch_end(self):
        self.log('val loss (by epochs)', self.valid_gpt_loss, prog_bar=True, logger=True, sync_dist=True, rank_zero_only=True)

    def configure_optimizers(self):
        optim = torch.optim.NAdam(self.parameters(),
                                  lr=self.hparams.max_lr,
                                  weight_decay=self.hparams.weight_decay
                                 )
        
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer=optim,
            max_lr=self.hparams.max_lr,
            total_steps=self.hparams.total_steps,
            pct_start=self.hparams.pct_start,
            anneal_strategy='cos',
            cycle_momentum=False,
            div_factor=25.0,
            final_div_factor=10000.0,
            three_phase=False
        )
        
        scheduler = {'scheduler': scheduler, 'interval': 'step'}
        return [optim], [scheduler]
    
    @property
    def seq_encoder(self):
        return GPTInferenceModule(pretrained_model=self)


class GPTInferenceModule(torch.nn.Module):
    def __init__(self, pretrained_model):
        super().__init__()
        self.model = pretrained_model
        self.model.is_reduce_sequence = False
        self.mean_pooling = MeanPooling()
        self.stat_pooling = StatPooling()

    def forward(self, batch, eval_strategy="mean"):
        z_trx = self.model.trx_encoder(batch)
        out = self.model._seq_encoder(z_trx)
        out = out if isinstance(out, PaddedBatch) else PaddedBatch(out, batch.seq_lens)

        if eval_strategy == "mean":
            out = self.mean_pooling(out)
        elif eval_strategy == "stat":
            out = self.stat_pooling(out)

        if self.model.hparams.norm_predict:
            out = out / (out.pow(2).sum(dim=-1, keepdim=True) + 1e-9).pow(0.5)
        return out

In [23]:
class WinAggGPTPretrainModule_MultiPred(GptPretrainModule):
    def __init__(self,
                 trx_encoder: torch.nn.Module,
                 seq_encoder: AbsSeqEncoder,
                 head_hidden_size: int = 64,
                 total_steps: int = 64000,
                 seed_seq_len: int = 16,
                 max_lr: float = 0.00005,
                 weight_decay: float = 0.0,
                 pct_start: float = 0.1,
                 norm_predict: bool = False
                 ):
        super().__init__(
            trx_encoder=trx_encoder,
            seq_encoder=seq_encoder,
            head_hidden_size=head_hidden_size,
            total_steps=total_steps,
            seed_seq_len=seed_seq_len,
            max_lr=max_lr,
            weight_decay=weight_decay,
            pct_start=pct_start,
            norm_predict=norm_predict
        )
        self.agg_samples = trx_encoder.agg_samples

    def loss_gpt(self, logits, labels):
        loss = 0
        
        for col_name, head in self.head.items():
            n_obj = 0
            out = head(logits[:, self.hparams.seed_seq_len:-1, :])
            
            for shift in range(self.agg_samples):
                y_true = labels[col_name][:, ((self.hparams.seed_seq_len + 1) * self.agg_samples + shift)::self.agg_samples]
                y_true = torch.flatten(y_true.long())
                
                if y_true.shape[0] < out.shape[0] * out.shape[1]:
                    pred = out[:, :-1, :]
                    pred = pred.reshape(-1, pred.size(-1))
                else:
                    pred = out.reshape(-1, out.size(-1))
                n_obj += pred.shape[0] 
                    
                loss += self.loss(pred, y_true) * pred.shape[0]
                
        return loss / n_obj


class WinAggGPTPretrainModule(GptPretrainModule):
    def __init__(self,
                 trx_encoder: torch.nn.Module,
                 seq_encoder: AbsSeqEncoder,
                 head_hidden_size: int = 64,
                 total_steps: int = 64000,
                 seed_seq_len: int = 16,
                 max_lr: float = 0.00005,
                 weight_decay: float = 0.0,
                 pct_start: float = 0.1,
                 norm_predict: bool = False
                 ):
        super().__init__(
            trx_encoder=trx_encoder,
            seq_encoder=seq_encoder,
            head_hidden_size=head_hidden_size,
            total_steps=total_steps,
            seed_seq_len=seed_seq_len,
            max_lr=max_lr,
            weight_decay=weight_decay,
            pct_start=pct_start,
            norm_predict=norm_predict
        )
        self.agg_samples = trx_encoder.agg_samples

    def loss_gpt(self, logits, labels):
        loss = 0
        
        for col_name, head in self.head.items():
            out = head(logits[:, self.hparams.seed_seq_len:-1, :])
            
            y_true = labels[col_name][:, ((self.hparams.seed_seq_len + 1) * self.agg_samples)::self.agg_samples]
            y_true = torch.flatten(y_true.long())
            
            if y_true.shape[0] < out.shape[0] * out.shape[1]:
                pred = out[:, :-1, :]
                pred = pred.reshape(-1, pred.size(-1))
            else:
                pred = out.reshape(-1, out.size(-1))
                    
            loss += self.loss(pred, y_true)
                
        return loss


class WinAggGPTPretrainModule_MultiLabel(GptPretrainModule):
    def __init__(self,
                 trx_encoder: torch.nn.Module,
                 seq_encoder: AbsSeqEncoder,
                 head_hidden_size: int = 64,
                 total_steps: int = 64000,
                 seed_seq_len: int = 16,
                 max_lr: float = 0.00005,
                 weight_decay: float = 0.0,
                 pct_start: float = 0.1,
                 norm_predict: bool = False
                 ):
        super().__init__(
            trx_encoder=trx_encoder,
            seq_encoder=seq_encoder,
            head_hidden_size=head_hidden_size,
            total_steps=total_steps,
            seed_seq_len=seed_seq_len,
            max_lr=max_lr,
            weight_decay=weight_decay,
            pct_start=pct_start,
            norm_predict=norm_predict
        )
        self.agg_samples = trx_encoder.agg_samples
        self.loss = nn.MultiLabelSoftMarginLoss()

    def loss_gpt(self, logits, labels):
        loss = 0
        
        for col_name, head in self.head.items():
            pred = head(logits[:, self.hparams.seed_seq_len:-1, :])

            ohe_labels = torch.zeros((pred.shape[0] * pred.shape[1], pred.shape[2]), device=pred.device)
            
            for shift in range(self.agg_samples):
                y_true = labels[col_name][:, ((self.hparams.seed_seq_len + 1) * self.agg_samples + shift)::self.agg_samples]
                y_true = torch.flatten(y_true.long())
                ohe_labels_part = F.one_hot(y_true, num_classes=pred.shape[2])
                
                if ohe_labels_part.shape[0] < pred.shape[0] * pred.shape[1]:
                    padding = torch.zeros((pred.shape[0], 1, pred.shape[2]), device=ohe_labels_part.device)
                    ohe_labels_part = torch.cat((ohe_labels_part.reshape(pred.shape[0], pred.shape[1] - 1, pred.shape[2]), padding), dim=1).reshape(pred.shape[0] * pred.shape[1], pred.shape[2])
                
                ohe_labels += ohe_labels_part

            ohe_labels[ohe_labels > 1] = 1
            
            pred = pred.reshape(-1, pred.size(-1))

            loss += self.loss(pred, ohe_labels)
                
        return loss

In [24]:
N_EPOCHS = 15

In [25]:
agg_encoder_params = dict(
    embeddings_noise=0.003,
    embeddings={
        "trans_date": {"in": 730, "out": 64},
        "small_group": {"in": 204, "out": 64},
        "amount_rur": {"in": BINS_NUM, "out": 64}
    },
    agg_samples=10,
    use_pre_agg_attention=False,
    use_window_attention=False
)

trx_encoder = WinAggregator(**agg_encoder_params)

seq_encoder = GptEncoder(
    n_embd=trx_encoder.output_size,
    n_layer=6,
    n_head=6,
    n_inner=256,
    activation_function="gelu_new",
    resid_pdrop=0.1,
    embd_pdrop=0.1,
    attn_pdrop=0.1,
    n_positions=2048,
    use_positional_encoding=True,
    use_start_random_shift=True,
    is_reduce_sequence=False
)

gpt = WinAggGPTPretrainModule_MultiLabel(
    trx_encoder=trx_encoder,
    seq_encoder=seq_encoder,
    head_hidden_size=256,
    total_steps=(N_EPOCHS * 1688), # num_epochs * num_steps_per_epoch
    seed_seq_len=16,
    max_lr=1e-3,
    weight_decay=0.,
    pct_start=0.1,
    norm_predict=False
)

**Обучение:**

In [26]:
logger = CometLogger(project_name="EvS_SSL", experiment_name="GPT_modeling_WinAgg")

trainer = pl.Trainer(
    logger=logger,
    max_epochs=N_EPOCHS,
    accelerator="gpu",
    devices=1,
    enable_progress_bar=True
)

In [27]:
trainer.fit(gpt, data)

[1;38;5;39mCOMET INFO:[0m Experiment is live on comet.com https://www.comet.com/askoro/evs-ssl/0bdd3613cfc44f44b41f124fa7499d10

[1;38;5;39mCOMET INFO:[0m Couldn't find a Git repository in '/kaggle/working' nor in any parent directory. Set `COMET_GIT_DIRECTORY` if your Git Repository is elsewhere.


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

[1;38;5;39mCOMET INFO:[0m ---------------------------------------------------------------------------------------
[1;38;5;39mCOMET INFO:[0m Comet.ml Experiment Summary
[1;38;5;39mCOMET INFO:[0m ---------------------------------------------------------------------------------------
[1;38;5;39mCOMET INFO:[0m   Data:
[1;38;5;39mCOMET INFO:[0m     display_summary_level : 1
[1;38;5;39mCOMET INFO:[0m     name                  : GPT_modeling_WinAgg
[1;38;5;39mCOMET INFO:[0m     url                   : https://www.comet.com/askoro/evs-ssl/0bdd3613cfc44f44b41f124fa7499d10
[1;38;5;39mCOMET INFO:[0m   Metrics [count] (min, max):
[1;38;5;39mCOMET INFO:[0m     loss [3038]               : (0.20921623706817627, 2.0972025394439697)
[1;38;5;39mCOMET INFO:[0m     val loss (by epochs) [15] : (0.24627543985843658, 0.2891102731227875)
[1;38;5;39mCOMET INFO:[0m   Others:
[1;38;5;39mCOMET INFO:[0m     Name : GPT_modeling_WinAgg
[1;38;5;39mCOMET INFO:[0m   Parameters:
[1;38;5;39mCOM

In [28]:
trainer.logged_metrics

{'loss': tensor(0.2285), 'val loss (by epochs)': tensor(0.2464)}

In [29]:
encoder = gpt.seq_encoder

In [30]:
torch.save(encoder.state_dict(), "gpt_WinAgg_trx10_multilabel.pt")

**Измерим качество на тесте (catboost поверх эмбеддингов):**

In [29]:
# import gdown

# gdown.download("https://drive.google.com/uc?export=download&id=1YBstN7hpEIREo7zORmPoEZ_0NyBgfjm6", "gpt_baseline_NAdam.pt")

Downloading...
From (original): https://drive.google.com/uc?export=download&id=1YBstN7hpEIREo7zORmPoEZ_0NyBgfjm6
From (redirected): https://drive.google.com/uc?export=download&id=1YBstN7hpEIREo7zORmPoEZ_0NyBgfjm6&confirm=t&uuid=be4debcd-1d0b-4619-a15e-892431777c63
To: /kaggle/working/gpt_baseline_NAdam.pt
100%|██████████| 34.7M/34.7M [00:00<00:00, 156MB/s] 


'gpt_baseline_NAdam.pt'

In [31]:
# state_dict = torch.load("./gpt_baseline_NAdam.pt")
# encoder.load_state_dict(state_dict)

device = "cuda:0"

encoder.to(device)

GPTInferenceModule(
  (model): WinAggGPTPretrainModule_MultiLabel(
    (trx_encoder): WinAggregator(
      (embeddings): ModuleDict(
        (trans_date): NoisyEmbedding(
          730, 64, padding_idx=0
          (dropout): Dropout(p=0, inplace=False)
        )
        (small_group): NoisyEmbedding(
          204, 64, padding_idx=0
          (dropout): Dropout(p=0, inplace=False)
        )
        (amount_rur): NoisyEmbedding(
          128, 64, padding_idx=0
          (dropout): Dropout(p=0, inplace=False)
        )
      )
      (custom_embeddings): ModuleDict()
    )
    (_seq_encoder): GptEncoder(
      (transf): GPT2Model(
        (wte): Embedding(4, 192)
        (wpe): Embedding(2048, 192)
        (drop): Dropout(p=0.1, inplace=False)
        (h): ModuleList(
          (0-5): 6 x GPT2Block(
            (ln_1): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
            (attn): GPT2SdpaAttention(
              (c_attn): Conv1D(nf=576, nx=192)
              (c_proj): Conv1D(

In [32]:
from tqdm import tqdm

seed_everything(0)

In [33]:
train_loader = inference_data_loader(data_train, num_workers=0, batch_size=8)
encoder.eval()
train_embeds = None

for i, batch in tqdm(enumerate(train_loader)):
    train_embeds_batch = encoder(batch.to(device), eval_strategy="mean")
    if i == 0:
        train_embeds = train_embeds_batch.detach().cpu().numpy()
    else:
        train_embeds = np.concatenate([train_embeds, train_embeds_batch.detach().cpu().numpy()], axis=0)
    
train_embeds

3375it [00:31, 108.76it/s]


array([[ 0.31536353, -0.0842438 ,  0.3832604 , ..., -0.10778586,
        -0.6980415 , -0.36430973],
       [ 0.0241911 ,  0.08230145,  0.21791126, ...,  0.8697108 ,
        -0.30115065,  0.13838275],
       [ 0.49131852,  0.17025314,  0.153876  , ..., -0.07249554,
        -0.30611566,  0.12305231],
       ...,
       [ 1.0160196 ,  0.0417424 ,  0.00713044, ...,  0.20092922,
         0.03901017, -0.22009598],
       [ 0.21870992,  0.1300195 , -0.01389283, ..., -0.15358263,
         0.0228075 , -0.03736628],
       [ 1.5081874 ,  0.13934559,  0.21306148, ...,  0.1152389 ,
        -0.09603833,  0.45319265]], dtype=float32)

In [34]:
test_loader = inference_data_loader(data_test, num_workers=0, batch_size=1)
encoder.eval()
test_embeds = None

for i, batch in tqdm(enumerate(test_loader)):
    test_embeds_batch = encoder(batch.to(device), eval_strategy="mean")
    if i == 0:
        test_embeds = test_embeds_batch.detach().cpu().numpy()
    else:
        test_embeds = np.concatenate([test_embeds, test_embeds_batch.detach().cpu().numpy()], axis=0)
    
test_embeds

3000it [00:18, 158.40it/s]


array([[ 0.51009405,  0.28176787,  0.05083818, ..., -0.12658988,
         0.52998537, -0.48353752],
       [-0.8570451 , -0.42433134,  0.19792302, ..., -0.12885305,
        -0.3920125 , -0.25294182],
       [-1.064824  , -0.11614103,  0.40162733, ..., -0.25125697,
        -0.7998918 ,  0.03878516],
       ...,
       [-0.4554164 ,  0.0975307 ,  0.13601476, ..., -0.15060812,
        -0.44560304, -0.14436284],
       [-0.237132  ,  0.34084025,  0.02564172, ..., -0.00975494,
         0.02818822, -0.20714563],
       [-0.50885195,  0.16395755,  0.56415004, ..., -0.20615551,
         0.3055933 ,  0.30488333]], dtype=float32)

In [35]:
clf = CatBoostClassifier(loss_function='MultiClass', task_type="GPU", devices='0')

clf.fit(train_embeds, target_train, plot_file="catboost_log.html")

Learning rate set to 0.12714
0:	learn: 1.3201852	total: 11.8s	remaining: 3h 17m 13s
1:	learn: 1.2680158	total: 11.9s	remaining: 1h 38m 34s
2:	learn: 1.2271836	total: 11.9s	remaining: 1h 5m 41s
3:	learn: 1.1930850	total: 11.9s	remaining: 49m 14s
4:	learn: 1.1649018	total: 11.9s	remaining: 39m 22s
5:	learn: 1.1410818	total: 11.9s	remaining: 32m 48s
6:	learn: 1.1202687	total: 11.9s	remaining: 28m 6s
7:	learn: 1.1012972	total: 11.9s	remaining: 24m 34s
8:	learn: 1.0851079	total: 11.9s	remaining: 21m 50s
9:	learn: 1.0703359	total: 11.9s	remaining: 19m 38s
10:	learn: 1.0576010	total: 11.9s	remaining: 17m 51s
11:	learn: 1.0457335	total: 11.9s	remaining: 16m 21s
12:	learn: 1.0355069	total: 11.9s	remaining: 15m 5s
13:	learn: 1.0254308	total: 11.9s	remaining: 14m
14:	learn: 1.0165747	total: 11.9s	remaining: 13m 4s
15:	learn: 1.0078480	total: 11.9s	remaining: 12m 14s
16:	learn: 1.0004140	total: 12s	remaining: 11m 31s
17:	learn: 0.9932564	total: 12s	remaining: 10m 52s
18:	learn: 0.9866714	total: 12

<catboost.core.CatBoostClassifier at 0x799ed0ae0220>

In [36]:
test_pred = clf.predict(test_embeds)
test_proba = clf.predict_proba(test_embeds)

In [37]:
print("Accuracy:", accuracy_score(target_test, test_pred))
print("ROC-AUC:", roc_auc_score(target_test, test_proba, average="weighted", multi_class="ovr"))

Accuracy: 0.5813333333333334
ROC-AUC: 0.8354173752048473


- GPT embeds + Catboost:
  - `Accuracy: 0.6246666666666667`
  - `ROC-AUC: 0.8523603311141598`

\

- GPT embeds (w/ Window Aggregation (3 trx) & next several (3) trx prediction) + Catboost:
  - `Accuracy: 0.5906666666666667`
  - `ROC-AUC: 0.8418245528735205`

\

- GPT embeds (w/ Window Aggregation (5 trx) & next several (5) trx prediction) + Catboost:
  - `Accuracy: 0.5943333333333334`
  - `ROC-AUC: 0.8412700721706741`

\

- GPT embeds (w/ Window Aggregation (3 trx) & next (1) trx prediction) + Catboost:
  - `Accuracy: 0.5816666666666667`
  - `ROC-AUC: 0.8384765947236289`

\

- GPT embeds (w/ Window Aggregation (5 trx) & next (1) trx prediction) + Catboost:
  - `Accuracy: 0.5816666666666667`
  - `ROC-AUC: 0.8342331978726336`

\

- GPT embeds (w/ Window Aggregation (3 trx), multilabel training) + Catboost:
  - `Accuracy: 0.58`
  - `ROC-AUC: 0.8376578210301627`

\

- GPT embeds (w/ Window Aggregation (5 trx), multilabel training) + Catboost:
  - `Accuracy: 0.592`
  - `ROC-AUC: 0.8409446751939194`

# Итоги.

| Method                  |    Accuracy   | ROC-AUC       |
|-------------------------|---------------|---------------|
| **Flattened Sequences** | 0.508         | 0.7635        |
| **GRU (+ MLP)**         | 0.605         | 0.8459        |
| **GRU (+ Catboost)**    | 0.586         | 0.8331        |
| **CoLES**               | 0.606         | 0.8485        |
| **CPC Modeling**        | 0.576         | 0.8252        |
| **GPT2**                | 0.625         | 0.8524        |