In [None]:
#Commands for downgrading Python version in Kaggle notebook

#adds external APT repository
!add-apt-repository -y 'ppa:deadsnakes/ppa'

#installs specific python version
!apt install -y python3.10

#installs distutils package for the specific python version
!apt-get install -y python3.10-distutils

#Install Python modules using the newly installed Python version
#-I, Ignores and overwrites the installed packages.
!sudo /usr/bin/python3.10 -m pip install -Iv <package-name>

#Run Python scripts using the newly installed Python version
!/usr/bin/python3.10 <python_script.py>

In [None]:
!mkdir -p data
!mkdir -p predictions
!mkdir -p embeddings
!mkdir -p models
!cd data/

!curl -OL https://storage.yandexcloud.net/ds-ods/files/data/docs/competitions/DataFusion2024/Data/clients.csv
!curl -OL https://storage.yandexcloud.net/ds-ods/files/data/docs/competitions/DataFusion2024/Data/train.csv
!curl -OL https://storage.yandexcloud.net/ds-ods/files/data/docs/competitions/DataFusion2024/Data/report_dates.csv
!curl -OL https://storage.yandexcloud.net/ds-ods/files/data/docs/competitions/DataFusion2024/Data/transactions.csv.zip
!curl -OL https://storage.yandexcloud.net/ds-ods/files/data/docs/competitions/DataFusion2024/Data/sample_submit_naive.csv

!unzip transactions.csv.zip

!rm transactions.csv.zip

!cd ../

# python -m get_test_ids

In [None]:
!pip install pyspark
!pip install lightning
!pip install pytorch-lifestream

In [None]:
# !python --version

In [None]:
import pandas as pd
import numpy as np
import os

In [None]:
import os

import wandb

# Never hard-code the API key in the notebook: it leaks through version control
# and shared copies (the key previously written here should be rotated).
# Provide it via the WANDB_API_KEY environment variable / a Kaggle secret;
# with key=None wandb.login() falls back to its usual credential lookup.
wandb.login(key=os.environ.get("WANDB_API_KEY"))

## Data preprocessing

In [None]:
from sklearn.model_selection import train_test_split

# Keep only rows whose target is not -1 (presumably unlabelled clients) and
# hold out 20% of them; the held-out ids are written to data/test_ids.csv.
main = pd.read_csv('/kaggle/input/datafusion2024-contest/train.csv')
labelled_mask = main['target'] != -1
train = main.loc[labelled_mask]
train_, test_ = train_test_split(train, test_size=0.2, random_state=42)
test_.loc[:, ['user_id']].to_csv('data/test_ids.csv', index=False)

In [None]:
# Peek at the labelled training split instead of rendering the whole frame.
train_.head()

In [None]:
# Raw competition tables.
DATA_ROOT = '/kaggle/input/datafusion2024-contest'
transactions_df = pd.read_csv(f'{DATA_ROOT}/transactions.csv')
clients_df = pd.read_csv(f'{DATA_ROOT}/clients.csv')
reports_df = pd.read_csv(f'{DATA_ROOT}/report_dates.csv')

In [None]:
# train_df

In [None]:
# transactions_df

In [None]:
# Days between each transaction and the report date of the client's report.
transactions_df['transaction_dttm'] = pd.to_datetime(transactions_df.transaction_dttm)
report_dates = pd.read_csv('/kaggle/input/datafusion2024-contest/report_dates.csv', parse_dates=['report_dt'])

# validate='many_to_one' guarantees each lookup key is unique, so the left merges
# keep exactly one row per transaction in the original order. Without that
# guarantee the assignment below would silently misalign rows.
df_ = transactions_df.merge(clients_df[['user_id', 'report']], how='left', on='user_id',
                            validate='many_to_one')
df_ = df_.merge(report_dates, how='left', on='report', validate='many_to_one')
assert len(df_) == len(transactions_df)
transactions_df['days_to_report'] = (df_['report_dt'] - df_['transaction_dttm']).dt.days.to_numpy()

In [None]:
# Add the number of days/hours since the client's first transaction and since
# the client's previous transaction.
assert transactions_df.index.is_unique  # required by the index realignment below

trx_time = transactions_df['transaction_dttm']
since_first = trx_time - transactions_df.groupby('user_id')['transaction_dttm'].transform('min')

# The "previous" transaction must be the same client's preceding one. A plain
# .diff() over the whole frame crossed client boundaries and depended on row
# order, so diff within each client in time order, then realign by index.
since_prev = (
    transactions_df.sort_values(['user_id', 'transaction_dttm'], kind='stable')
    .groupby('user_id')['transaction_dttm']
    .diff()
    .reindex(transactions_df.index)
    .fillna(pd.Timedelta(0))  # first transaction of each client
)

transactions_df['days_from_first_tr'] = (since_first / np.timedelta64(1, 'D')).astype('int')
transactions_df['days_from_prev_tr'] = (since_prev / np.timedelta64(1, 'D')).astype('int')
transactions_df['hours_from_first_tr'] = since_first / np.timedelta64(1, 'h')
transactions_df['hours_from_prev_tr'] = since_prev / np.timedelta64(1, 'h')

del trx_time, since_first, since_prev

In [None]:
# Encode the day of week as 1 (Monday) .. 7 (Sunday) and add a day-off flag.
days_of_week = {'Monday': 1,
                'Tuesday': 2,
                'Wednesday': 3,
                'Thursday': 4,
                'Friday': 5,
                'Saturday': 6,
                'Sunday': 7
               }

# Map the day names in one assignment. The previous per-key
# `transactions_df['day_of_week'].replace(k, v, inplace=True)` is chained
# assignment: under pandas Copy-on-Write it modifies a temporary and leaves the
# frame unchanged (pandas >= 2.2 warns about it).
transactions_df['day_of_week'] = transactions_df['transaction_dttm'].dt.day_name().map(days_of_week)

transactions_df['is_day_off'] = transactions_df['day_of_week'].isin((6, 7)).astype(int)

In [None]:
# Peek at the engineered transaction features instead of rendering the whole table.
transactions_df.head()

In [None]:
# Feature lists handed to PandasDataPreprocessor: categorical columns first,
# then numerical ones.
cat_cols_ = [
    'mcc_code',
    'currency_rk',
    'day_of_week',
    'is_day_off',
]
num_cols_ = [
    'transaction_amt',
    'days_from_first_tr', 'days_from_prev_tr',
    'hours_from_first_tr', 'hours_from_prev_tr',
]

In [None]:
from ptls.preprocessing import PandasDataPreprocessor

# Groups transactions by user_id into per-client records, using
# transaction_dttm (converted to a timestamp) as the event time.
preprocessor_config = dict(
    col_id='user_id',
    col_event_time='transaction_dttm',
    event_time_transformation='dt_to_timestamp',
    cols_category=cat_cols_,
    cols_numerical=num_cols_,
    return_records=True,
)
trx_preprocessor = PandasDataPreprocessor(**preprocessor_config)

In [None]:
# Split the transaction log along the client-level train_/test_ split.
train_users = train_['user_id']
test_users = test_['user_id']
transactions_df_train = transactions_df.loc[transactions_df['user_id'].isin(train_users)]
transactions_df_test = transactions_df.loc[transactions_df['user_id'].isin(test_users)]

In [None]:
%%time

dataset_train = trx_preprocessor.fit_transform(transactions_df_train)
dataset_test = trx_preprocessor.fit_transform(transactions_df_test)

# Baseline

In [None]:
%load_ext autoreload
%autoreload 2
import torch
import pytorch_lightning as pl
from functools import partial
from ptls.nn import TrxEncoder, RnnSeqEncoder
from ptls.frames.coles import CoLESModule
# import lion_pytorch
from ptls.frames.coles.losses import SoftmaxLoss
from ptls.data_load.datasets import MemoryMapDataset
from ptls.data_load.iterable_processing import SeqLenFilter
from ptls.frames.coles import ColesIterableDataset
from ptls.frames.coles.split_strategy import SampleSlices
from ptls.frames import PtlsDataModule
from ptls.frames.inference_module import InferenceModule
from ptls.data_load.utils import collate_feature_dict
from lightning.pytorch.loggers import WandbLogger

# TrxEncoder configuration: two raw numeric features plus learned embeddings
# for three categorical features.
trx_encoder_params = dict(
    embeddings_noise=0.003,
    numeric_values={
        'transaction_amt': 'identity',
        'days_from_prev_tr': 'identity',
    },
    embeddings={
        'currency_rk': {'in': 5, 'out': 8},
        'day_of_week': {'in': 8, 'out': 8},
        'mcc_code': {'in': 333, 'out': 16},
    },
)

# Training data: clients with at least 20 transactions, each cut by
# SampleSlices into 5 slices (cnt_min=20, cnt_max=150).
coles_train_source = MemoryMapDataset(
    data=dataset_train,
    i_filters=[SeqLenFilter(min_seq_len=20)],
)
train_dl = PtlsDataModule(
    train_data=ColesIterableDataset(
        coles_train_source,
        splitter=SampleSlices(split_count=5, cnt_min=20, cnt_max=150),
    ),
    train_batch_size=128,
    train_num_workers=2,
)

# Inference data: all held-out clients in a fixed order.
inference_dataset = MemoryMapDataset(data=dataset_test)
inference_dl = torch.utils.data.DataLoader(
    dataset=inference_dataset,
    batch_size=128,
    shuffle=False,
    num_workers=8,
    collate_fn=collate_feature_dict,
)

def get_coles_embeddings(random_seed):
    """Train a CoLES GRU encoder and embed the held-out clients.

    Uses the notebook globals ``trx_encoder_params``, ``train_dl`` and
    ``inference_dl``.

    Args:
        random_seed: seed for this run. It is used for seeding, in the W&B run
            name and in the output file names, so several seeds can be trained
            side by side.

    Side effects:
        Writes ``./models/coles-model{random_seed}.pt`` (encoder weights) and
        ``./embeddings/coles_{random_seed}.csv`` (the embeddings table), and
        logs to the W&B project "MBD_My_Code".

    Returns:
        The embeddings DataFrame that was written to CSV.
    """
    import wandb

    # Seed here so the run is reproducible from its argument alone, rather
    # than only when the caller remembers to seed first.
    pl.seed_everything(random_seed)
    accelerator = "cuda" if torch.cuda.is_available() else "cpu"

    seq_encoder = RnnSeqEncoder(
        trx_encoder=TrxEncoder(**trx_encoder_params),
        hidden_size=512,
        type='gru',
    )

    model = CoLESModule(
        seq_encoder=seq_encoder,
        optimizer_partial=partial(torch.optim.AdamW, lr=0.001, weight_decay=1e-4),
        lr_scheduler_partial=partial(torch.optim.lr_scheduler.StepLR, step_size=5, gamma=0.9),
        loss=SoftmaxLoss(),
    )
    wandb_logger = WandbLogger(project="MBD_My_Code", log_model="all",
                               name=f"df2024_coles_base_seed{random_seed}")

    trainer = pl.Trainer(
        logger=wandb_logger,
        max_epochs=20,
        accelerator=accelerator,
        enable_progress_bar=True,
    )
    print(f'logger.version = {trainer.logger.version}')
    trainer.fit(model, train_dl)
    torch.save(model.seq_encoder.state_dict(), f"./models/coles-model{random_seed}.pt")

    inference_module = InferenceModule(
        model=seq_encoder,
        pandas_output=True,
        drop_seq_features=True,
        model_out_name=f'emb_{random_seed}',
    )
    predict = pl.Trainer(accelerator=accelerator).predict(inference_module, inference_dl)
    full_predict = pd.concat(predict, axis=0)
    full_predict.to_csv(f'./embeddings/coles_{random_seed}.csv', index=False)

    # Close the W&B run. Otherwise a new WandbLogger created by the next call
    # (e.g. another seed) keeps logging into this same run.
    wandb.finish()
    return full_predict

In [None]:
# Train one CoLES model with seed 0 and dump its embeddings to disk.
SEED = 0
pl.seed_everything(SEED)
get_coles_embeddings(SEED);

In [None]:
# Embeddings produced by get_coles_embeddings(0).
inference_embeddings = pd.read_csv(filepath_or_buffer='./embeddings/coles_0.csv')

In [None]:
# Peek at the embeddings table instead of rendering every row.
inference_embeddings.head()

In [None]:
# Attach each held-out client's target to its embedding row.
test_labels = test_.loc[:, ['user_id', 'target']]
inference_embeddings_t = inference_embeddings.merge(test_labels, on='user_id', how='left')

In [None]:
# Peek at the embeddings joined with targets.
inference_embeddings_t.head()

In [None]:
# Split the embedded clients again: 80% to fit the downstream classifier,
# 20% to evaluate it.
inf_emb_train, inf_emb_test = train_test_split(
    inference_embeddings_t, test_size=0.2, random_state=42
)

X_train = inf_emb_train.drop(columns=['user_id', 'target'])
y_train = inf_emb_train['target']
X_test = inf_emb_test.drop(columns=['user_id', 'target'])
y_test = inf_emb_test['target']

In [None]:
from lightgbm import LGBMClassifier

# Downstream classifier on the embeddings. The parameters use the scikit-learn
# wrapper names (colsample_bytree, reg_alpha, reg_lambda, min_child_samples)
# instead of the native aliases (feature_fraction, lambda_l1, lambda_l2,
# min_data_in_leaf). Mixing aliases with the wrapper's own defaults makes
# LightGBM resolve two values for the same parameter. The values are unchanged.
down_model = LGBMClassifier(
    n_estimators=600,
    boosting_type='gbdt',
    subsample=0.5,
    subsample_freq=1,
    learning_rate=0.02,
    colsample_bytree=0.75,
    max_depth=6,
    reg_alpha=1,
    reg_lambda=1,
    min_child_samples=50,
    random_state=42,
    n_jobs=8,
    verbose=-1,
)

In [None]:
%%time

down_model.fit(X_train, y_train)

In [None]:
from sklearn.metrics import classification_report
from sklearn.metrics import roc_auc_score

# ROC-AUC from the positive-class probability.
predict = down_model.predict_proba(X_test)
roc_auc = roc_auc_score(y_test, predict[:, 1])
print(f"ROC-AUC target = {roc_auc}")

In [None]:
# Per-class precision / recall / F1 for the hard predictions.
predict = down_model.predict(X_test)
report = classification_report(y_test, predict)
print(report)

# Regional Attention

In [None]:
import torch
import pytorch_lightning as pl
from functools import partial
from ptls.nn import TrxEncoder, RnnSeqEncoder
from ptls.frames.coles import CoLESModule
# import lion_pytorch
from ptls.frames.coles.losses import SoftmaxLoss
from ptls.data_load.datasets import MemoryMapDataset
from ptls.data_load.iterable_processing import SeqLenFilter
from ptls.frames.coles import ColesIterableDataset
from ptls.frames.coles.split_strategy import SampleSlices
from ptls.frames import PtlsDataModule
from ptls.frames.inference_module import InferenceModule
from ptls.data_load.utils import collate_feature_dict

# Same data pipeline as the baseline: SampleSlices-based CoLES training data
# and an ordered loader over the held-out clients.
coles_splitter = SampleSlices(split_count=5, cnt_min=20, cnt_max=150)
coles_train_data = ColesIterableDataset(
    MemoryMapDataset(data=dataset_train, i_filters=[SeqLenFilter(min_seq_len=20)]),
    splitter=coles_splitter,
)
train_dl = PtlsDataModule(
    train_data=coles_train_data,
    train_num_workers=2,
    train_batch_size=128,
)

inference_dataset = MemoryMapDataset(data=dataset_test)
inference_dl = torch.utils.data.DataLoader(
    inference_dataset,
    batch_size=128,
    shuffle=False,
    num_workers=8,
    collate_fn=collate_feature_dict,
)

In [None]:
import torch
from torch import nn

from ptls.data_load import PaddedBatch
from ptls.nn.seq_encoder.rnn_encoder import RnnEncoder
from ptls.nn.seq_encoder.transformer_encoder import TransformerEncoder
from ptls.nn.seq_encoder.longformer_encoder import LongformerEncoder
from ptls.nn.seq_encoder.custom_encoder import Encoder
from ptls.nn.trx_encoder import TrxEncoder
from ptls.nn.seq_encoder.containers import SeqEncoderContainer
import torch.nn.functional as F

# class RnnSeqEncoderRegAttn(SeqEncoderContainer):
#     def __init__(self,
#                  trx_encoder=None,
#                  input_size=None,
#                  is_reduce_sequence=True,
#                  **seq_encoder_params,
#                  ):
#         super().__init__(
#             trx_encoder=trx_encoder,
#             seq_encoder_cls=RnnEncoder,
#             input_size=input_size,
#             seq_encoder_params=seq_encoder_params,
#             is_reduce_sequence=is_reduce_sequence,
#         )
        
#         self.reg_seq_encoder = RnnEncoder(
#             input_size=input_size if input_size is not None else trx_encoder.output_size,
#             is_reduce_sequence=is_reduce_sequence,
#             type='gru',
#             hidden_size=34,
#         )

#         self.emb_dim = 34
#         self.regional_attention = nn.MultiheadAttention(
#             embed_dim=self.emb_dim,
#             num_heads=2,
#             dropout=0.3,
#             batch_first=True
#         )

class RnnSeqEncoderRegAttn(SeqEncoderContainer):
    """GRU sequence encoder enriched with "regional" (segment-level) attention.

    The event embeddings from ``trx_encoder`` are cut into consecutive segments
    of ``segment_length`` events. Each segment of each client is summarised by
    ``reg_seq_encoder``. Segment summaries of the same client attend to each
    other, and each attended summary is added to every event of its segment.
    The main RNN (``self.seq_encoder``) then encodes the enriched sequence.

    Args:
        trx_encoder: transaction encoder producing the per-event embeddings.
        input_size: size of those embeddings; defaults to
            ``trx_encoder.output_size``.
        is_reduce_sequence: forwarded to both RNN encoders. The segment encoder
            must return one vector per segment, so this is expected to be True.
        segment_length: number of events per segment (was hard-coded to 10).
        num_heads: attention heads; must divide the embedding size.
        attn_dropout: dropout inside the attention layer.
        **seq_encoder_params: RnnEncoder params (e.g. ``type``, ``hidden_size``)
            shared by the main and the segment encoder.
    """

    def __init__(self,
                 trx_encoder=None,
                 input_size=None,
                 is_reduce_sequence=True,
                 segment_length=10,
                 num_heads=2,
                 attn_dropout=0.2,
                 **seq_encoder_params,
                 ):
        super().__init__(
            trx_encoder=trx_encoder,
            seq_encoder_cls=RnnEncoder,
            input_size=input_size,
            seq_encoder_params=seq_encoder_params,
            is_reduce_sequence=is_reduce_sequence,
        )
        token_size = input_size if input_size is not None else trx_encoder.output_size
        self.segment_length = segment_length

        self.reg_seq_encoder = RnnEncoder(
            input_size=token_size,
            is_reduce_sequence=is_reduce_sequence,
            **seq_encoder_params,
        )

        # Attention runs in the event-embedding space, so its output can be
        # added to the events. This was hard-coded to 34, the trx_encoder
        # output size used in this notebook.
        self.emb_dim = token_size
        reg_hidden = seq_encoder_params.get('hidden_size', token_size)
        # No-op when the segment RNN's hidden size already equals the event
        # embedding size (the notebook's configuration).
        self.reg_proj = nn.Identity() if reg_hidden == token_size else nn.Linear(reg_hidden, token_size)
        self.regional_attention = nn.MultiheadAttention(
            embed_dim=self.emb_dim,
            num_heads=num_heads,
            dropout=attn_dropout,
            batch_first=True,
        )

    def forward(self, x, names=None, seq_len=None, h_0=None):
        """Encode a PaddedBatch of raw transaction features.

        ``names`` and ``seq_len`` are accepted for interface compatibility and
        ignored; ``h_0`` is passed on to the main RNN.
        """
        x = self.trx_encoder(x)
        tokens = x.payload  # assumed (batch, time, emb) — TODO confirm for other trx encoders
        batch_size, max_len, dim = tokens.shape
        device = tokens.device
        seg = self.segment_length

        # Right-pad time to a multiple of seg and view it as
        # (batch * n_segments, seg, dim): every row is one segment of ONE
        # client. The previous version permuted the batch axis into the time
        # axis, so the RNN ran across different clients of the batch and every
        # client received the same batch-dependent regional term.
        pad_length = (-max_len) % seg
        padded = F.pad(tokens, (0, 0, 0, pad_length))
        n_segments = padded.size(1) // seg
        segments = padded.reshape(batch_size * n_segments, seg, dim)

        # Number of real (non-padding) events in each segment of each client.
        seq_lens = torch.as_tensor(x.seq_lens, device=device)
        starts = torch.arange(n_segments, device=device) * seg
        seg_lens = (seq_lens[:, None] - starts[None, :]).clamp(min=0, max=seg)
        # Segments entirely beyond a client's history are masked out of the
        # attention. They still get length 1 so the RNN reducer has a step to
        # read; the value it reads is ignored.
        empty_segments = seg_lens == 0

        regional = self.reg_seq_encoder(PaddedBatch(segments, seg_lens.reshape(-1).clamp(min=1)))
        regional = self.reg_proj(regional).reshape(batch_size, n_segments, dim)
        regional, _ = self.regional_attention(
            regional, regional, regional,
            key_padding_mask=empty_segments,
        )

        # Broadcast each segment's attended summary back onto its events.
        regional = regional.repeat_interleave(seg, dim=1)[:, :max_len]
        x_new = PaddedBatch(tokens + regional, x.seq_lens)
        return self.seq_encoder(x_new, h_0)
    
    

In [None]:
import ptls
from functools import partial

# AdamW; StepLR multiplies the learning rate by 0.9025 every 2 scheduler steps.
optimizer_partial = partial(torch.optim.AdamW, lr=0.001, weight_decay=1e-4)
lr_scheduler_partial = partial(torch.optim.lr_scheduler.StepLR, step_size=2, gamma=0.9025)

reg_attn_trx_encoder = ptls.nn.TrxEncoder(
    embeddings_noise=0.003,
    numeric_values={
        'transaction_amt': 'identity',
        'days_from_prev_tr': 'identity',
    },
    embeddings={
        'currency_rk': {'in': 5, 'out': 8},
        'day_of_week': {'in': 8, 'out': 8},
        'mcc_code': {'in': 333, 'out': 16},
    },
)
seq_encoder = RnnSeqEncoderRegAttn(
    trx_encoder=reg_attn_trx_encoder,
    type='gru',
    hidden_size=34,
)

pl_module = ptls.frames.coles.CoLESModule(
    seq_encoder=seq_encoder,
    validation_metric=ptls.frames.coles.metric.BatchRecallTopK(K=4, metric='cosine'),
    optimizer_partial=optimizer_partial,
    lr_scheduler_partial=lr_scheduler_partial,
)

In [None]:
from lightning.pytorch.loggers import WandbLogger

# Log the regional-attention run to W&B (checkpoints included).
wandb_logger = WandbLogger(project="MBD_My_Code", name='df2024_reg_attn', log_model="all")

accelerator = "cuda" if torch.cuda.is_available() else "cpu"
trainer = pl.Trainer(
    max_epochs=20,
    accelerator=accelerator,
    logger=wandb_logger,
    gradient_clip_val=0.5,
    log_every_n_steps=50,
    limit_val_batches=32,
    enable_progress_bar=True,
)

In [None]:
# train_dl is a LightningDataModule, so pass it explicitly as the datamodule.
trainer.fit(model=pl_module, datamodule=train_dl)

In [None]:
import pandas as pd
import pytorch_lightning as pl
import torch
import numpy as np
from itertools import chain
from ptls.data_load.padded_batch import PaddedBatch
from datetime import datetime
from ptls.custom_layers import StatPooling
from ptls.nn.seq_step import LastStepEncoder


class InferenceModuleMultimodal(pl.LightningModule):
    """Lightning wrapper that turns encoder outputs into an embeddings table.

    Expects batches shaped like the output of
    ``collate_feature_dict_with_target``: ``(padded_batch, ids)`` or
    ``(padded_batch, ids, targets)``.

    Args:
        model: the encoder. If it has a ``seq_encoder`` attribute, that
            attribute is used instead.
        pandas_output: return a DataFrame (True) or the raw dict (False).
        col_id: name of the id column in the output.
        target_col_names: names of the target columns, in collate order.
        model_out_name: name (prefix) of the embedding column(s).
        model_type: stored; not used in this class.
    """

    def __init__(
        self,
        model,
        pandas_output=True,
        col_id='client_id',
        target_col_names=None,
        model_out_name='emb',
        model_type='notab'
    ):
        super().__init__()

        self.model = model
        self.pandas_output = pandas_output
        self.target_col_names = target_col_names
        self.col_id = col_id
        self.model_out_name = model_out_name
        self.model_type = model_type

        # Constructed but not used by forward().
        self.stat_pooler = StatPooling()
        self.last_step = LastStepEncoder()

    def forward(self, x):
        has_targets = len(x) == 3
        if has_targets:
            x, batch_ids, target_cols = x
        else:
            x, batch_ids = x

        encoder = self.model.seq_encoder if hasattr(self.model, 'seq_encoder') else self.model
        out = encoder(x)

        x_out = {
            self.col_id: batch_ids,
            self.model_out_name: out,
        }
        if has_targets:
            target_cols = torch.tensor(target_cols)
            if target_cols.dim() > 1:
                for idx, target_col in enumerate(self.target_col_names):
                    x_out[target_col] = target_cols[:, idx]
            else:
                # NOTE(review): keeps only every 4th target; the intent is not
                # visible here, and the column length will not match the ids
                # unless the batch repeats each client 4 times — confirm.
                x_out[self.target_col_names[0]] = target_cols[::4]

        if self.pandas_output:
            return self.to_pandas(x_out)
        return x_out

    @staticmethod
    def to_pandas(x):
        """Convert the output dict into a DataFrame.

        Lists and 1-D values become single columns. 2-D values are expanded
        into ``{key}_0000``, ``{key}_0001``, ... columns. Values of higher rank
        become an all-None column. Expanded columns follow the scalar ones.
        """
        scalar_features = {}
        expanded = []
        for key, value in x.items():
            if isinstance(value, torch.Tensor):
                value = value.detach().cpu().numpy()

            if isinstance(value, list) or len(value.shape) == 1:
                scalar_features[key] = value
            elif len(value.shape) == 2:
                # Reuse the converted array; the old code re-read
                # x[key].cpu(), which fails when the value is already numpy.
                expanded.append(pd.DataFrame(
                    value, columns=[f'{key}_{i:04d}' for i in range(value.shape[1])]
                ))
            else:
                scalar_features[key] = None

        return pd.concat([pd.DataFrame(scalar_features), *expanded], axis=1)

def collate_feature_dict_with_target(batch, col_id='client_id', target_col_names=None):
    """Collate feature-dict records into a PaddedBatch, splitting off ids/targets.

    Args:
        batch: list of per-client feature dicts.
        col_id: key that holds the client id.
        target_col_names: keys that hold targets, or None.

    Returns:
        ``(padded_batch, batch_ids)``, or ``(padded_batch, batch_ids, target_cols)``
        when ``target_col_names`` is given, where ``target_cols[i]`` lists the
        targets of sample ``i`` in ``target_col_names`` order.
    """
    batch_ids = []
    target_cols = []
    features_only = []
    for sample in batch:
        # Work on a shallow copy. Deleting keys from the original record would
        # corrupt an in-memory dataset that may hand out the same dict objects
        # on every pass (the second epoch would fail with KeyError).
        sample = dict(sample)
        batch_ids.append(sample.pop(col_id))
        if target_col_names is not None:
            target_cols.append([sample.pop(target_col) for target_col in target_col_names])
        features_only.append(sample)

    padded_batch = collate_feature_dict(features_only)
    if target_col_names is not None:
        return padded_batch, batch_ids, target_cols
    return padded_batch, batch_ids

In [None]:
# Embed the held-out clients with the trained encoder; predict() yields one
# DataFrame per batch.
accelerator = "cuda" if torch.cuda.is_available() else "cpu"
inference_module = InferenceModule(model=seq_encoder, pandas_output=True, drop_seq_features=True)
predictor = pl.Trainer(accelerator=accelerator)
predict = predictor.predict(inference_module, inference_dl)

In [None]:
# Stack the per-batch frames into one embeddings table.
inference_embeddings = pd.concat(objs=predict, axis='index')

In [None]:
# Attach targets, then split the held-out clients 80/20 for the downstream model.
test_labels = test_.loc[:, ['user_id', 'target']]
inference_embeddings_t = inference_embeddings.merge(test_labels, on='user_id', how='left')

inf_emb_train, inf_emb_test = train_test_split(inference_embeddings_t, test_size=0.2, random_state=42)

X_train, X_test = (part.drop(columns=['user_id', 'target']) for part in (inf_emb_train, inf_emb_test))
y_train, y_test = inf_emb_train['target'], inf_emb_test['target']

In [None]:
from lightgbm import LGBMClassifier

# Downstream classifier on the regional-attention embeddings. The parameters use
# the scikit-learn wrapper names (colsample_bytree, reg_alpha, reg_lambda,
# min_child_samples) instead of the native aliases, which otherwise conflict
# with the wrapper's own defaults. The values are unchanged.
down_model = LGBMClassifier(
    n_estimators=600,
    boosting_type='gbdt',
    subsample=0.5,
    subsample_freq=1,
    learning_rate=0.02,
    colsample_bytree=0.75,
    max_depth=6,
    reg_alpha=1,
    reg_lambda=1,
    min_child_samples=50,
    random_state=42,
    n_jobs=8,
    verbose=-1,
)

In [None]:
%%time

down_model.fit(X_train, y_train)

In [None]:
from sklearn.metrics import classification_report
from sklearn.metrics import roc_auc_score

# ROC-AUC from the positive-class probability.
predict = down_model.predict_proba(X_test)
roc_auc = roc_auc_score(y_test, predict[:, 1])
print(f"ROC-AUC target = {roc_auc}")

In [None]:
# Per-class precision / recall / F1 for the hard predictions.
predict = down_model.predict(X_test)
report = classification_report(y_test, predict)
print(report)

# Cross-attention

In [None]:
import torch
from torch import nn

from ptls.data_load import PaddedBatch
from ptls.nn.seq_encoder.rnn_encoder import RnnEncoder
from ptls.nn.seq_encoder.transformer_encoder import TransformerEncoder
from ptls.nn.seq_encoder.longformer_encoder import LongformerEncoder
from ptls.nn.seq_encoder.custom_encoder import Encoder
from ptls.nn.trx_encoder import TrxEncoder
from ptls.nn.seq_encoder.containers import SeqEncoderContainer
import torch.nn.functional as F
from ptls.nn.seq_encoder.custom_encoder import MLP

class RnnSeqEncoderCrossAttn(SeqEncoderContainer):
    """GRU encoder that combines two patch granularities via cross-attention.

    The event embeddings are cut into short patches (``small_patches_size``
    events) and long patches (``large_patches_size`` events). Every patch of
    every client is summarised by an RNN (``small_seq_encoder`` for short
    patches, the container's ``seq_encoder`` for long ones) and projected to
    ``emb_dim`` by an MLP head. Within each client, the short-patch sequence
    attends to the long-patch sequence. The attended short patches and the long
    patches are each mean-pooled over the client's real (non-padding) patches
    and concatenated, giving an output of shape (batch, 2 * emb_dim).

    Compared with the previous version: the MLP heads are applied (without them,
    adding the 256-d RNN output to the 128-d accumulator raised a shape error);
    attention no longer runs across the clients of a batch (2-D inputs to
    MultiheadAttention are treated as one unbatched sequence); and the
    per-patch Python loops are replaced by a single batched RNN call.
    """

    def __init__(self,
                 trx_encoder=None,
                 input_size=None,
                 small_patches_size=3,
                 large_patches_size=12,
                 is_reduce_sequence=True,
                 **seq_encoder_params,
                 ):
        super().__init__(
            trx_encoder=trx_encoder,
            seq_encoder_cls=RnnEncoder,
            input_size=input_size,
            seq_encoder_params=seq_encoder_params,
            is_reduce_sequence=is_reduce_sequence,
        )
        self.small_patches_size = small_patches_size
        self.large_patches_size = large_patches_size
        self.emb_dim = 128

        # Short patches (queries) attend to long patches (keys/values).
        self.small_attn = nn.MultiheadAttention(
            embed_dim=self.emb_dim,
            num_heads=2,
            dropout=0.2,
            batch_first=True
        )
        # Reverse direction; constructed but not used in forward().
        self.large_attn = nn.MultiheadAttention(
            embed_dim=self.emb_dim,
            num_heads=2,
            dropout=0.2,
            batch_first=True
        )

        # Separate RNN for short patches; long patches use self.seq_encoder.
        self.small_seq_encoder = RnnEncoder(
            input_size=input_size if input_size is not None else trx_encoder.output_size,
            is_reduce_sequence=is_reduce_sequence,
            type='gru',
            hidden_size=256
        )

        # Project the patch summaries to emb_dim, the attention width.
        self.head_small = MLP(
            n_in=256,
            n_hidden=256,
            n_out=self.emb_dim
        )
        self.head_large = MLP(
            n_in=seq_encoder_params.get('hidden_size', 256),
            n_hidden=256,
            n_out=self.emb_dim
        )

    @staticmethod
    def _encode_patches(tokens, seq_lens, patch_size, encoder, head):
        """Summarise every ``patch_size`` window of every client.

        Returns (batch, n_patches, emb_dim) embeddings and a (batch, n_patches)
        bool mask that is True for patches lying entirely in the padding.
        """
        batch_size, max_len, dim = tokens.shape
        pad_length = (-max_len) % patch_size
        padded = F.pad(tokens, (0, 0, 0, pad_length))
        n_patches = padded.size(1) // patch_size
        # Row b * n_patches + p is patch p of client b.
        patches = padded.reshape(batch_size * n_patches, patch_size, dim)

        starts = torch.arange(n_patches, device=tokens.device) * patch_size
        patch_lens = (seq_lens[:, None] - starts[None, :]).clamp(min=0, max=patch_size)
        # Empty patches get length 1 so the RNN reducer has a step to read;
        # they are masked out of attention and pooling by the returned mask.
        emb = encoder(PaddedBatch(patches, patch_lens.reshape(-1).clamp(min=1)))
        emb = head(emb).reshape(batch_size, n_patches, -1)
        return emb, patch_lens == 0

    @staticmethod
    def _masked_mean(values, empty):
        """Mean over dim 1 of ``values``, ignoring positions where ``empty`` is True."""
        keep = (~empty).unsqueeze(-1).to(values.dtype)
        return (values * keep).sum(dim=1) / keep.sum(dim=1).clamp(min=1)

    def forward(self, x, names=None, seq_len=None, h_0=None):
        """Encode a PaddedBatch of raw transaction features.

        ``names``, ``seq_len`` and ``h_0`` are accepted for interface
        compatibility and ignored.
        """
        x = self.trx_encoder(x)
        tokens = x.payload  # assumed (batch, time, emb) — TODO confirm
        seq_lens = torch.as_tensor(x.seq_lens, device=tokens.device)

        small, small_empty = self._encode_patches(
            tokens, seq_lens, self.small_patches_size, self.small_seq_encoder, self.head_small)
        large, large_empty = self._encode_patches(
            tokens, seq_lens, self.large_patches_size, self.seq_encoder, self.head_large)

        # Per-client cross-attention: (batch, n_small, E) queries over
        # (batch, n_large, E) keys/values.
        small_attn_emb, _ = self.small_attn(small, large, large, key_padding_mask=large_empty)

        out = torch.cat(
            (self._masked_mean(small_attn_emb, small_empty), self._masked_mean(large, large_empty)),
            dim=1,
        )
        return out


In [None]:
import ptls

# AdamW; StepLR multiplies the learning rate by 0.9025 every 2 scheduler steps.
optimizer_partial = partial(torch.optim.AdamW, lr=0.001, weight_decay=1e-4)
lr_scheduler_partial = partial(torch.optim.lr_scheduler.StepLR, step_size=2, gamma=0.9025)

cross_attn_trx_encoder = ptls.nn.TrxEncoder(
    embeddings_noise=0.003,
    numeric_values={
        'transaction_amt': 'identity',
        'days_from_prev_tr': 'identity',
    },
    embeddings={
        'currency_rk': {'in': 5, 'out': 8},
        'day_of_week': {'in': 8, 'out': 8},
        'mcc_code': {'in': 333, 'out': 16},
    },
)
seq_encoder = RnnSeqEncoderCrossAttn(
    trx_encoder=cross_attn_trx_encoder,
    type='gru',
    hidden_size=256,
    small_patches_size=3,
    large_patches_size=12,
)

pl_module = ptls.frames.coles.CoLESModule(
    seq_encoder=seq_encoder,
    validation_metric=ptls.frames.coles.metric.BatchRecallTopK(K=4, metric='cosine'),
    optimizer_partial=optimizer_partial,
    lr_scheduler_partial=lr_scheduler_partial,
)

In [None]:
from lightning.pytorch.loggers import WandbLogger

# Log the cross-attention run to W&B (checkpoints included).
wandb_logger = WandbLogger(project="MBD_My_Code", name='df2024_cross_attn', log_model="all")

accelerator = "cuda" if torch.cuda.is_available() else "cpu"
trainer = pl.Trainer(
    max_epochs=15,
    accelerator=accelerator,
    logger=wandb_logger,
    gradient_clip_val=0.5,
    log_every_n_steps=50,
    limit_val_batches=32,
    enable_progress_bar=True,
)

In [None]:
# train_dl is a LightningDataModule, so pass it explicitly as the datamodule.
trainer.fit(model=pl_module, datamodule=train_dl)

In [None]:
# Embed the held-out clients with the trained encoder; predict() yields one
# DataFrame per batch.
accelerator = "cuda" if torch.cuda.is_available() else "cpu"
inference_module = InferenceModule(model=seq_encoder, pandas_output=True, drop_seq_features=True)
predictor = pl.Trainer(accelerator=accelerator)
predict = predictor.predict(inference_module, inference_dl)

In [None]:
# Stack the per-batch frames into one embeddings table.
inference_embeddings = pd.concat(objs=predict, axis='index')

In [None]:
# Attach targets, then split the held-out clients 80/20 for the downstream model.
test_labels = test_.loc[:, ['user_id', 'target']]
inference_embeddings_t = inference_embeddings.merge(test_labels, on='user_id', how='left')

inf_emb_train, inf_emb_test = train_test_split(inference_embeddings_t, test_size=0.2, random_state=42)

X_train, X_test = (part.drop(columns=['user_id', 'target']) for part in (inf_emb_train, inf_emb_test))
y_train, y_test = inf_emb_train['target'], inf_emb_test['target']

In [None]:
from lightgbm import LGBMClassifier

# Downstream classifier on the cross-attention embeddings. The parameters use
# the scikit-learn wrapper names (colsample_bytree, reg_alpha, reg_lambda,
# min_child_samples) instead of the native aliases, which otherwise conflict
# with the wrapper's own defaults. The values are unchanged.
down_model = LGBMClassifier(
    n_estimators=500,
    boosting_type='gbdt',
    subsample=0.5,
    subsample_freq=1,
    learning_rate=0.02,
    colsample_bytree=0.75,
    max_depth=6,
    reg_alpha=1,
    reg_lambda=1,
    min_child_samples=50,
    random_state=42,
    n_jobs=8,
    verbose=-1,
)

In [None]:
%%time

down_model.fit(X_train, y_train)

In [None]:
from sklearn.metrics import accuracy_score, roc_auc_score, f1_score

# ROC-AUC from the positive-class probability.
predict = down_model.predict_proba(X_test)
roc_auc = roc_auc_score(y_test, predict[:, 1])
print(f"ROC-AUC target = {roc_auc}")

In [None]:
# Accuracy of the hard predictions.
predict = down_model.predict(X_test)
accuracy = accuracy_score(y_test, predict)
print(f"accuracy = {accuracy}")

# GPT Baseline

In [None]:
# Inspect the first preprocessed client record.
next(iter(dataset_train))

In [None]:
import pandas as pd
import numpy as np
from ptls.data_load.iterable_processing_dataset import IterableProcessingDataset
from ptls.data_load import IterableChain
from datetime import datetime
from ptls.data_load.datasets.parquet_dataset import ParquetDataset, ParquetFiles
from ptls.data_load.iterable_processing.feature_filter import FeatureFilter
from ptls.data_load.iterable_processing.to_torch_tensor import ToTorch
import torch
from functools import partial
from torch.utils.data import DataLoader
from ptls.data_load.padded_batch import PaddedBatch
from ptls.data_load.utils import collate_feature_dict
from tqdm import tqdm
import ptls
import torch.nn.functional as F

class QuantilfyAmount(IterableProcessingDataset):
    """Replace a numeric feature with the index of its quantile bucket.

    Each value ``a`` of ``features[col_amt]`` becomes the largest ``i`` with
    ``a > quantilies[i]``, or 0 when no threshold is exceeded. The result is a
    ``torch.int`` tensor.

    Args:
        col_amt: name of the feature to bucketize.
        quantilies: bucket thresholds; defaults to a precomputed list of amount
            thresholds.
    """

    def __init__(self, col_amt='amount', quantilies=None):
        super().__init__()
        self.col_amt = col_amt
        if quantilies is None:
            self.quantilies = [0., 267.6, 1198.65, 3667.2, 8639.8, 18325.7, 36713.2, 68950.3, 143969.1, 421719.1]
        else:
            self.quantilies = quantilies

    def __iter__(self):
        for rec in self._src:
            is_tuple = type(rec) is tuple
            features = rec[0] if is_tuple else rec
            amount = features[self.col_amt]
            am_quant = torch.zeros(len(amount), dtype=torch.int)
            for i, q in enumerate(self.quantilies):
                am_quant = torch.where(amount > q, i, am_quant)
            # Build a new dict instead of writing into the incoming record. The
            # source records can be shared (dataset_train/dataset_test feed
            # several datasets in this notebook), and overwriting them in place
            # corrupts those datasets; a second pass would also re-bucketize
            # already-bucketized values.
            features = {**features, self.col_amt: am_quant}
            # Preserve any extra tuple elements (e.g. a target) after the features.
            yield (features, *rec[1:]) if is_tuple else features

In [None]:
from ptls.frames.gpt.gpt_dataset import GptIterableDataset

dates_quantiles = [0, 5, 14, 36.7, 132.4, 256, 366]


def make_gpt_i_filters():
    """Build a fresh preprocessing chain for the GPT baseline.

    Returns a new list of new filter objects on every call, so the train and
    inference datasets never share filter instances.
    """
    return [
        ptls.data_load.iterable_processing.SeqLenFilter(min_seq_len=16),
        ptls.data_load.iterable_processing.SeqLenFilter(max_seq_len=2048),
        # Turn the continuous features into small categorical vocabularies.
        QuantilfyAmount(col_amt='transaction_amt'),
        QuantilfyAmount(col_amt='days_from_prev_tr', quantilies=dates_quantiles),
        ptls.data_load.iterable_processing.CategorySizeClip(
            category_max_size={
                'transaction_amt': 10,
                # NOTE(review): with 7 date thresholds this feature only takes codes 0..6,
                # so a limit of 365 never clips anything — confirm the intended size.
                'days_from_prev_tr': 365,
                'currency_rk': 80,
                'day_of_week': 7,
                'mcc_code': 90,
            }
        ),
        ptls.data_load.iterable_processing.to_torch_tensor.ToTorch(),
    ]


train_dl = PtlsDataModule(
    train_data=GptIterableDataset(
        MemoryMapDataset(data=dataset_train, i_filters=make_gpt_i_filters()),
        min_len=40,
        max_len=500,
    ),
    train_num_workers=0,
    train_batch_size=128,
)

# NOTE(review): the SeqLenFilter bounds also apply at inference time, so test clients
# outside them get no embedding and drop out of the downstream evaluation — confirm.
inference_dataset = MemoryMapDataset(data=dataset_test, i_filters=make_gpt_i_filters())

inference_dl = torch.utils.data.DataLoader(
    dataset=inference_dataset,
    collate_fn=collate_feature_dict,
    shuffle=False,
    batch_size=128,
    num_workers=2,
)

In [None]:
import pytorch_lightning as pl
import torch
from torch import nn
import warnings
from torchmetrics import MeanMetric
from typing import Tuple, Dict, List, Union

from ptls.nn.seq_encoder.abs_seq_encoder import AbsSeqEncoder
from ptls.nn import PBL2Norm
from ptls.data_load.padded_batch import PaddedBatch
from ptls.custom_layers import StatPooling, GEGLU
from ptls.nn.seq_step import LastStepEncoder

class Head(nn.Module):
    """Small MLP mapping an embedding to ``n_classes`` logits.

    Layout: Linear(input_size -> hidden_size) -> GELU -> Dropout(drop_p)
    -> Linear(hidden_size -> n_classes).
    """

    def __init__(self, input_size, n_classes, hidden_size=64, drop_p=0.1):
        super().__init__()
        layers = [
            nn.Linear(input_size, hidden_size, bias=True),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(hidden_size, n_classes),
        ]
        self.head = nn.Sequential(*layers)

    def forward(self, x):
        return self.head(x)

class GptPretrainModule(pl.LightningModule):
    """GPT2 Language model

    Original sequences are encoded by `TrxEncoder`.
    Model `seq_encoder` predicts the embedding of the next transaction.
    Heads are used to predict each feature class of the future transaction.

    Parameters
    ----------
    trx_encoder:
        Module that transforms a dict with feature sequences into a sequence of transaction
        representations. Its ``embeddings`` mapping defines the predicted features (one head each).
    seq_encoder:
        Module for sequence processing. Generally this is a transformer based encoder; an RNN is also
        possible. It must work without sequence reduction (``is_reduce_sequence`` is forced to False).
    head_hidden_size:
        Hidden size of the heads for feature prediction.
    total_steps:
        total_steps expected by the OneCycle lr scheduler (stepped every optimizer step).
    seed_seq_len:
        Number of leading positions excluded from the loss.
    max_lr:
        max_lr of the OneCycle lr scheduler.
    weight_decay:
        weight_decay of the NAdam optimizer.
    pct_start:
        Fraction of total_steps during which the lr increases.
    norm_predict:
        Whether to L2-normalize the `seq_encoder` output.
    inference_pooling_strategy:
        'out' - last step of the `seq_encoder` output (B, H)
        'out_stat' - min, max, mean, std statistics pooled from the `seq_encoder` output (B, H) -> (B, 4H)
        'trx_stat' - min, max, mean, std statistics pooled from the `trx_encoder` output (B, H) -> (B, 4H)
        'trx_stat_out' - `trx_encoder` statistics + last step of `seq_encoder` (B, H) -> (B, 5H)
    """

    def __init__(self,
                 trx_encoder: torch.nn.Module,
                 seq_encoder: AbsSeqEncoder,
                 head_hidden_size: int = 64,
                 total_steps: int = 64000,
                 seed_seq_len: int = 16,
                 max_lr: float = 0.00005,
                 weight_decay: float = 0.0,
                 pct_start: float = 0.1,
                 norm_predict: bool = False,
                 inference_pooling_strategy: str = 'out_stat'
                 ):

        super().__init__()
        self.save_hyperparameters(ignore=['trx_encoder', 'seq_encoder'])

        self.trx_encoder = trx_encoder

        self._seq_encoder = seq_encoder
        # Next-step prediction needs one output per position, so never pool the sequence.
        self._seq_encoder.is_reduce_sequence = False

        # One classification head per embedded (categorical) feature.
        self.head = nn.ModuleDict()
        for col_name, noisy_emb in self.trx_encoder.embeddings.items():
            self.head[col_name] = Head(input_size=self._seq_encoder.embedding_size, hidden_size=head_hidden_size, n_classes=noisy_emb.num_embeddings)

        if self.hparams.norm_predict:
            self.fn_norm_predict = PBL2Norm()

        # Targets equal to 0 are ignored (presumably padding).
        # NOTE(review): quantized features also use 0 as a real bucket — confirm.
        self.loss = nn.CrossEntropyLoss(ignore_index=0)

        self.train_gpt_loss = MeanMetric()
        self.valid_gpt_loss = MeanMetric()

    def forward(self, batch: PaddedBatch):
        """Encode transactions and run the sequence encoder; returns per-position outputs."""
        z_trx = self.trx_encoder(batch)
        out = self._seq_encoder(z_trx)
        if self.hparams.norm_predict:
            out = self.fn_norm_predict(out)
        return out

    def loss_gpt(self, predictions, labels, is_train_step):
        """Sum of next-step cross-entropy losses over all predicted features.

        The output at position t is scored against the label at position t + 1;
        the first ``seed_seq_len`` positions are excluded. ``is_train_step`` is
        currently unused.
        """
        seed = self.hparams.seed_seq_len
        loss = 0
        for col_name, head in self.head.items():
            y_pred = head(predictions[:, seed:-1, :])
            y_pred = y_pred.reshape(-1, y_pred.size(-1))

            y_true = torch.flatten(labels[col_name][:, seed + 1:].long())

            loss += self.loss(y_pred, y_true)
        return loss

    def training_step(self, batch, batch_idx):
        out = self.forward(batch)  # PB: B, T, H
        out = out.payload if isinstance(out, PaddedBatch) else out
        labels = batch.payload

        loss_gpt = self.loss_gpt(out, labels, is_train_step=True)
        self.train_gpt_loss(loss_gpt)
        self.log('gpt/loss', loss_gpt, sync_dist=True)
        return loss_gpt

    def validation_step(self, batch, batch_idx):
        out = self.forward(batch)  # PB: B, T, H
        out = out.payload if isinstance(out, PaddedBatch) else out
        labels = batch.payload

        loss_gpt = self.loss_gpt(out, labels, is_train_step=False)
        self.valid_gpt_loss(loss_gpt)

    def on_train_epoch_end(self):
        # Lightning's hook is `on_train_epoch_end`; under the old name `on_training_epoch_end`
        # this method was never called and the epoch train loss was never logged.
        self.log('gpt/train_gpt_loss', self.train_gpt_loss, prog_bar=False, sync_dist=True, rank_zero_only=True)
        # self.train_gpt_loss reset not required here

    # Backward-compatible alias for the previous (non-hook) method name.
    on_training_epoch_end = on_train_epoch_end

    def on_validation_epoch_end(self):
        self.log('gpt/valid_gpt_loss', self.valid_gpt_loss, prog_bar=True, sync_dist=True, rank_zero_only=True)
        # self.valid_gpt_loss reset not required here

    def configure_optimizers(self):
        """NAdam with a per-step OneCycle (cosine) schedule peaking at ``max_lr``."""
        optim = torch.optim.NAdam(self.parameters(),
                                  lr=self.hparams.max_lr,
                                  weight_decay=self.hparams.weight_decay,
                                  )
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer=optim,
            max_lr=self.hparams.max_lr,
            total_steps=self.hparams.total_steps,
            pct_start=self.hparams.pct_start,
            anneal_strategy='cos',
            cycle_momentum=False,
            div_factor=25.0,
            final_div_factor=10000.0,
            three_phase=False,
        )
        scheduler = {'scheduler': scheduler, 'interval': 'step'}
        return [optim], [scheduler]

    @property
    def seq_encoder(self):
        """Pooled inference encoder (a new ``GPTInferenceModule`` wrapping this model)."""
        return GPTInferenceModule(pretrained_model=self)

class GPTInferenceModule(torch.nn.Module):
    """Wrap a pretrained GPT module and pool its outputs into one vector per sequence.

    The pooling mode is read from ``pretrained_model.hparams.inference_pooling_strategy``
    and normalization from ``pretrained_model.hparams.norm_predict``.
    """

    def __init__(self, pretrained_model):
        super().__init__()
        self.model = pretrained_model
        self.model.is_reduce_sequence = False

        self.stat_pooler = StatPooling()
        self.last_step = LastStepEncoder()

    def forward(self, batch):
        """Encode ``batch`` and pool it according to the configured strategy.

        Raises
        ------
        ValueError
            If ``inference_pooling_strategy`` is not one of
            'trx_stat_out', 'trx_stat', 'out_stat', 'out'.
        """
        z_trx = self.model.trx_encoder(batch)
        out = self.model._seq_encoder(z_trx)
        # The pooling layers are fed PaddedBatch objects; wrap raw tensors with the batch lengths.
        out = out if isinstance(out, PaddedBatch) else PaddedBatch(out, batch.seq_lens)
        strategy = self.model.hparams.inference_pooling_strategy
        if strategy == 'trx_stat_out':
            stats = self.stat_pooler(z_trx)
            out = torch.cat([stats, self.last_step(out)], dim=1)
        elif strategy == 'trx_stat':
            out = self.stat_pooler(z_trx)
        elif strategy == 'out_stat':
            out = self.stat_pooler(out)
        elif strategy == 'out':
            out = self.last_step(out)
        else:
            # A bare `raise` here (no active exception) only produced an unhelpful RuntimeError.
            raise ValueError(f"Unknown inference_pooling_strategy: {strategy!r}")
        if self.model.hparams.norm_predict:
            # L2-normalize each pooled vector; the epsilon guards against division by zero.
            out = out / (out.pow(2).sum(dim=-1, keepdim=True) + 1e-9).pow(0.5)
        return out

class CorrGptPretrainModule(GptPretrainModule):
    """GPT pretraining variant that passes per-feature embeddings through ``feature_encoder``
    before the sequence encoder.

    The ``trx_encoder`` output is reshaped to (..., n_features, 16), i.e. every feature
    embedding is assumed to have size 16 — TODO confirm against the TrxEncoder config.
    """

    def __init__(self, feature_encoder, seq_encoder, trx_encoder, *args, **kwargs):
        # Disable numeric features so the encoder output is presumably only embedding blocks.
        trx_encoder.numeric_values = {}
        # Pass the encoders positionally: with keywords, any extra positional *args would
        # be bound to `trx_encoder`/`seq_encoder` again and raise "multiple values" TypeError.
        super().__init__(trx_encoder, seq_encoder, *args, **kwargs)
        self.save_hyperparameters(ignore=['trx_encoder', 'seq_encoder', 'feature_encoder'])
        self.feature_encoder = feature_encoder

    def forward(self, batch):
        z_trx = self.trx_encoder(batch)
        # Split the last axis into 16-wide per-feature chunks: (..., n_features * 16) -> (..., n_features, 16).
        payload = z_trx.payload.view(z_trx.payload.shape[:-1] + (-1, 16))
        payload = self.feature_encoder(payload)
        encoded_trx = PaddedBatch(payload=payload, length=z_trx.seq_lens)
        out = self._seq_encoder(encoded_trx)
        if self.hparams.norm_predict:
            out = self.fn_norm_predict(out)
        return out

In [None]:
from ptls.nn import TabFormerFeatureEncoder
from ptls.nn import TransformerEncoder
from ptls.nn import TrxEncoder
import ptls

# Categorical embeddings of the transaction encoder: vocabulary size ('in') and width ('out').
trx_embeddings = {
    'transaction_amt': {'in': 11, 'out': 16},
    'days_from_prev_tr': {'in': 10, 'out': 16},
    'currency_rk': {'in': 5, 'out': 16},
    'day_of_week': {'in': 8, 'out': 16},
    'mcc_code': {'in': 333, 'out': 16},
}
# With no numeric features, the TrxEncoder output is the concatenation of these
# embeddings, so the transformer input width is their total size (5 x 16 = 80).
trx_output_size = sum(emb['out'] for emb in trx_embeddings.values())

# NOTE(review): not passed to GptPretrainModule below (only CorrGptPretrainModule uses one).
feature_encoder = TabFormerFeatureEncoder(
    n_cols=len(trx_embeddings),
    emb_dim=16
)

seq_encoder = TransformerEncoder(
    n_heads=2,
    n_layers=2,
    input_size=trx_output_size,
    use_positional_encoding=True
)

trx_encoder = ptls.nn.TrxEncoder(
    embeddings_noise=0.003,
    embeddings=trx_embeddings,
)

# NOTE(review): OneCycleLR is stepped per optimizer step with total_steps (default 64000);
# it must cover max_epochs x batches per epoch or the scheduler errors — confirm.
pl_module = GptPretrainModule(
    trx_encoder=trx_encoder,
    seq_encoder=seq_encoder,
)

In [None]:
# Take the logger from the same Lightning distribution as the Trainer and modules
# (`pytorch_lightning`, imported as `pl`); mixing `lightning.pytorch` objects with
# `pytorch_lightning` ones can break the framework's isinstance checks.
from pytorch_lightning.loggers import WandbLogger

wandb_logger = WandbLogger(project="MBD_My_Code", log_model="all")

trainer = pl.Trainer(
    logger=wandb_logger,
    max_epochs=20,
    accelerator="cuda" if torch.cuda.is_available() else "cpu",
    enable_progress_bar=True,
    gradient_clip_val=0.5,
    log_every_n_steps=50,
    limit_val_batches=32
)

In [None]:
# `train_dl` is a PtlsDataModule, so pass it explicitly as the datamodule.
trainer.fit(model=pl_module, datamodule=train_dl)

In [None]:
# Удалить!

import os
import pandas as pd
import pytorch_lightning as pl
import torch
import numpy as np

from itertools import chain
from ptls.data_load.padded_batch import PaddedBatch


class InferenceModule(pl.LightningModule):
    def __init__(self, model, pandas_output=True, drop_seq_features=True, model_out_name='out'):
        super().__init__()

        self.model = model
        self.pandas_output = pandas_output
        self.drop_seq_features = drop_seq_features
        self.model_out_name = model_out_name

    def forward(self, x: PaddedBatch):
        out = self.model(x)
        print(out)
        print(out.payload.size())
        if self.drop_seq_features:
            x = x.drop_seq_features()
            x[self.model_out_name] = out
        else:
            x.payload[self.model_out_name] = out
        if self.pandas_output:
            return self.to_pandas(x)
        return x

    def to_pandas(self, x):
        is_reduced = None
        scalar_features, seq_features, expand_features = {}, {}, {}
        df_scalar, df_seq, df_expand, out_df = None, None, None, None
        len_mask = None

        x_ = x
        print(x_)
        if type(x_) is PaddedBatch:
            len_mask = x_.seq_len_mask.bool().cpu().numpy()
            x_ = x_.payload
        is_reduced = (type(x_[self.model_out_name]) is not PaddedBatch)
        for k, v in x_.items():
            if type(v) is PaddedBatch:
                len_mask = v.seq_len_mask.bool().cpu().numpy()
                v = v.payload
            if type(v) is torch.Tensor:
                v = v.detach().cpu().numpy()
            if type(v) is list or len(v.shape) == 1:
                scalar_features[k] = v
            elif k.startswith('target'):
                scalar_features[k] = v
            elif len(v.shape) == 3:
                expand_features[k] = v
            elif k == self.model_out_name and len(v.shape) == 2:
                expand_features[k] = v
            elif len(v.shape) == 2:
                seq_features[k] = v

        if is_reduced:
            df_scalar, df_seq, df_expand = self.to_pandas_record(x, expand_features, scalar_features, seq_features, len_mask)
        else:
            df_scalar, df_seq, df_expand = self.to_pandas_sequence(x, expand_features, scalar_features, seq_features, len_mask)

        out_df = df_scalar
        if df_seq:
            df_seq = pd.concat(df_seq, axis = 1)
            out_df = pd.concat([df_scalar, df_seq], axis = 1)
        if df_expand:
            df_expand = pd.concat(df_expand, axis = 0).reset_index(drop=True)
            out_df = pd.concat([out_df.reset_index(drop=True), df_expand], axis = 1)

        return out_df
    
    def to_numpy(self, tensor):
        return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

    @staticmethod
    def to_pandas_record(x, expand_features, scalar_features, seq_features, len_mask):
        dataframes_scalar = []
        for k, v in scalar_features.items():
            dataframes_scalar.append(pd.DataFrame(v, columns=[k]))
        dataframes_scalar = pd.concat(dataframes_scalar, axis = 1)

        dataframes_seq = []
        for k, v in seq_features.items():
            data_lst = [usr[len_mask[i]] for i, usr in enumerate(v)]
            dataframes_seq.append(pd.DataFrame(zip(data_lst), columns=[k]))


        dataframes_expand = []
        for k, v in expand_features.items():
            for i, usr in enumerate(v):
                exp_num = usr.shape[1] if len(usr.shape) == 2 else usr.shape[0]
                df_trx = pd.DataFrame([usr], columns=[f'{k}_{j:04d}' for j in range(exp_num)])
                dataframes_expand.append(df_trx)

        return dataframes_scalar, dataframes_seq, dataframes_expand

    @staticmethod
    def to_pandas_sequence(x, expand_features, scalar_features, seq_features, len_mask):
        dataframes_scalar = []
        for k, v in scalar_features.items():
            data_lst = [[data]*np.sum(len_mask[i]) for i, data in enumerate(v)]
            data_lst = list(chain(*data_lst))
            dataframes_scalar.append(pd.DataFrame(data_lst, columns=[k]))
        dataframes_scalar = pd.concat(dataframes_scalar, axis = 1)

        dataframes_seq = []
        for k, v in seq_features.items():
            data_lst = [data[len_mask[i]] for i, data in enumerate(v)]
            data_lst = list(chain(*data_lst))
            dataframes_seq.append(pd.DataFrame(data_lst, columns=[k]))

        dataframes_expand = []
        for k, v in expand_features.items():
            for i, usr in enumerate(v):
                exp_num = usr.shape[1] if len(usr.shape) == 2 else usr.shape[0]
                df_trx = pd.DataFrame(usr[len_mask[i]], columns=[f'{k}_{j:04d}' for j in range(exp_num)])
                dataframes_expand.append(df_trx)

        return dataframes_scalar, dataframes_seq, dataframes_expand

In [None]:
inference_module = InferenceModule(
        model=pl_module,
        pandas_output=True,
        drop_seq_features=True)
        # model_out_name=f'emb_{random_seed}')

predict = trainer.predict(inference_module, inference_dl)

In [None]:
# Every prediction batch is its own DataFrame indexed 0..n-1; renumber the rows so the
# concatenated frame has unique index labels.
inference_embeddings = pd.concat(predict, axis=0, ignore_index=True)

In [None]:
# Attach the target to each embedding row by client id.
inference_embeddings_t = pd.merge(
    inference_embeddings, test_[['user_id', 'target']], on='user_id', how='left'
)

# NOTE(review): the split is row-wise; if a user_id can appear in several rows,
# a group-aware split is needed to avoid leakage.
inf_emb_train, inf_emb_test = train_test_split(inference_embeddings_t, test_size=0.2, random_state=42)

X_train = inf_emb_train.drop(columns=['user_id', 'target'])
y_train = inf_emb_train['target']
X_test = inf_emb_test.drop(columns=['user_id', 'target'])
y_test = inf_emb_test['target']

In [None]:
from lightgbm import LGBMClassifier

# Use the scikit-learn parameter names. Passing LightGBM core aliases
# (feature_fraction, lambda_l1, lambda_l2, min_data_in_leaf) next to the estimator's
# own defaults (colsample_bytree=1.0, reg_alpha=0, reg_lambda=0, min_child_samples=20)
# leaves LightGBM to resolve the conflict and warn that one value "will be ignored".
down_model = LGBMClassifier(
    n_estimators=500,
    boosting_type='gbdt',
    subsample=0.5,          # bagging_fraction
    subsample_freq=1,       # bagging_freq
    learning_rate=0.02,
    colsample_bytree=0.75,  # feature_fraction
    max_depth=6,
    reg_alpha=1,            # lambda_l1
    reg_lambda=1,           # lambda_l2
    min_child_samples=50,   # min_data_in_leaf
    random_state=42,
    n_jobs=8,
    verbose=-1
)

In [None]:
%%time

# Fit the LightGBM classifier on the GPT embedding features; %%time reports the cost of the fit.
down_model.fit(X_train, y_train)

In [None]:
from sklearn.metrics import accuracy_score, roc_auc_score, f1_score

# Dedicated name instead of `predict`, which is also bound to the `trainer.predict`
# output consumed by the concat cell above.
proba_test = down_model.predict_proba(X_test)
# Column 1 holds the probability of `down_model.classes_[1]` (the positive class for a 0/1 target).
print(f"ROC-AUC target = {roc_auc_score(y_test, proba_test[:, 1])}")

In [None]:
# Hard class predictions; kept under their own name so `predict` is not overloaded across cells.
label_pred_test = down_model.predict(X_test)
print(f"accuracy = {accuracy_score(y_test, label_pred_test)}")

# CoLES with a Swin Transformer encoder

In [None]:
# Transaction encoder config: raw amounts/gaps as numeric features, the rest as embeddings.
# NOTE(review): earlier cells run in-place filters over `dataset_train`; confirm that
# 'transaction_amt' and 'days_from_prev_tr' still hold raw values here.
trx_encoder_params = {
    'embeddings_noise': 0.003,
    'numeric_values': {
        'transaction_amt': 'identity',
        'days_from_prev_tr': 'identity',
    },
    'embeddings': {
        'currency_rk': {'in': 5, 'out': 8},
        'day_of_week': {'in': 8, 'out': 8},
        'mcc_code': {'in': 333, 'out': 16},
    },
}

# CoLES: sample 5 random sub-sequences of 20..150 transactions per client.
train_dl = PtlsDataModule(
    train_data=ColesIterableDataset(
        MemoryMapDataset(data=dataset_train, i_filters=[SeqLenFilter(min_seq_len=20)]),
        splitter=SampleSlices(split_count=5, cnt_min=20, cnt_max=150),
    ),
    train_batch_size=128,
    train_num_workers=2,
)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
import numpy as np


#------------------------------------------------------------------------------------------------------------
# Based on https://github.com/yukara-ikemiya/Swin-Transformer-1d/tree/main and adapted to pytorch-lifestream
#------------------------------------------------------------------------------------------------------------

import torch
import torch.nn as nn
from ptls.data_load.padded_batch import PaddedBatch


def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
    # copied from timm/models/layers/drop.py
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Each sample along the first axis is either zeroed (probability ``drop_prob``) or kept;
    kept samples are rescaled by ``1 / (1 - drop_prob)`` when ``scale_by_keep`` is set.
    Returns ``x`` unchanged when not training or when ``drop_prob`` is 0.
    """
    if not training or drop_prob == 0.:
        return x
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast over every remaining dimension.
    mask_shape = [x.shape[0]] + [1] * (x.ndim - 1)
    keep_mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
    if scale_by_keep and keep_prob > 0.0:
        keep_mask.div_(keep_prob)
    return x * keep_mask


class DropPath(nn.Module):
    # copied from timm/models/layers/drop.py
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Module wrapper around ``drop_path`` that is active only in training mode.
    """

    def __init__(self, drop_prob=None, scale_by_keep=True):
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)


class Mlp(nn.Module):
    """Feed-forward block: Linear -> activation -> Dropout -> Linear -> Dropout.

    ``hidden_features`` and ``out_features`` default to ``in_features``; the same
    dropout module is applied after the activation and after the output projection.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))


def window_partition(x, window_size):
    """Split a sequence batch into non-overlapping windows along the length axis.

    Args:
        x: (B, L, C); L must be divisible by ``window_size``
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, C), batch-major
        (all windows of sample 0 first, then sample 1, ...)
    """
    batch, length, channels = x.shape
    grouped = x.view(batch, length // window_size, window_size, channels)
    return grouped.contiguous().view(-1, window_size, channels)


def window_reverse(windows, window_size, L):
    """Inverse of ``window_partition``: merge batch-major windows back into sequences.

    Args:
        windows: (num_windows*B, window_size, C)
        window_size (int): Window size
        L (int): Length of data (a multiple of ``window_size``)

    Returns:
        x: (B, L, C)
    """
    windows_per_seq = L // window_size
    B = windows.shape[0] // windows_per_seq
    regrouped = windows.view(B, windows_per_seq, window_size, -1)
    return regrouped.contiguous().view(B, L, -1)


class WindowAttention(nn.Module):
    """ Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Attention is computed independently inside each window; the masks are supplied
    by the caller on every forward call.

    Args:
        dim (int): Number of input channels. Must be divisible by ``num_heads``
            (required by the per-head reshape in ``forward``).
        window_size (int): The width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """
    def __init__(self, dim: int, window_size: int, num_heads: int,
                 qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):

        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # A falsy qk_scale (None or 0) falls back to the standard 1/sqrt(head_dim).
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros(2 * window_size - 1, num_heads))  # 2*window_size - 1, nH

        # get pair-wise relative position index for each token inside the window
        coords_w = torch.arange(self.window_size)
        relative_coords = coords_w[:, None] - coords_w[None, :]  # W, W
        relative_coords[:, :] += self.window_size - 1  # shift to start from 0

        # relative_position_index | example (W = 3)
        # [2, 1, 0]
        # [3, 2, 1]
        # [4, 3, 2]
        self.register_buffer("relative_position_index", relative_coords)  # (W, W): range of 0 -- 2*(W-1)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        # Initialize the bias table after the Linear layers are created.
        torch.nn.init.trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask_add, mask_mult):
        """
        Args:
            x: input features with shape of (num_windows*B, W, C); windows are expected
                in batch-major order, as produced by ``window_partition``.
            mask_add: additive mask added to the attention logits before softmax,
                broadcastable to (B, num_windows, num_heads, W, W); its dim 1 must equal
                num_windows (e.g. 0 to keep, -inf to block).
            mask_mult: multiplicative mask applied to the attention probabilities after
                softmax, broadcastable to (num_windows*B, num_heads, W, W).

        Returns:
            Tensor of shape (num_windows*B, W, C).
        """
        B_, N, C = x.shape
        # (3, B_, nH, N, C // nH)
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))  # B_, nH, N, N

        # The bias is viewed as (W, W, nH), so this assumes N == window_size.
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size, self.window_size, -1)  # W, W, nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, W, W
        attn = attn + relative_position_bias.unsqueeze(0)

        # Regroup windows per sample to broadcast the (B, nW, nH, W, W) additive mask.
        nW = mask_add.shape[1]
        attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask_add
        attn = attn.view(-1, self.num_heads, N, N)
        attn = self.softmax(attn)
        # Zero out the probabilities the caller marks as invalid.
        attn = attn * mask_mult
        
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def extra_repr(self) -> str:
        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'


class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block.

    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
        decoder (bool, optional): Flag that shows whether this block is decoder-like (hence, attn_mask should prevent from seeing future tokens). True => decoder-like; False => encoder-like. Default: False
        start_end_fusion (bool, optional): Flag that shows if the last and the first half-windows should merge (True) or not (False).

    Note:
        ``forward`` takes and returns a ``PaddedBatch``. The returned payload is
        zero-padded along the sequence axis up to a multiple of ``window_size`` (it is
        not trimmed back), while ``seq_lens`` is passed through unchanged.
        NOTE(review): with ``decoder=True`` and ``start_end_fusion=True`` the cyclic shift
        puts the first tokens after the last ones inside one window, so the look-ahead
        mask may let early positions attend to later ones — confirm this is intended.
    """

    def __init__(self, dim, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 decoder=False, start_end_fusion=True):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=self.window_size, num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        # Registered as an empty (None) buffer; forward builds its own mask per batch.
        attn_mask = None
        self.register_buffer("attn_mask", attn_mask)

        self.decoder = decoder
        self.start_end_fusion = start_end_fusion

    def forward(self, x):
        seq_lens = x.seq_lens
        x = x.payload
        
        B, L, C = x.shape

        # define seq_len_mask: 1 for positions < seq_len, 0 for padding; shape (B, L, 1)
        mask = torch.arange(L, device=x.device)[None, :] + torch.ones((B, L), device=x.device)
        mask[mask > seq_lens[:, None]] = 0.
        mask[mask > 0.] = 1.
        mask = mask[:, :, None]

        # make new max seq_len `L` divisible by `self.window_size` by adding 'zero' samples
        # (num_samples_to_add == window_size means L is already divisible: nothing is added)
        num_samples_to_add = self.window_size - (L % self.window_size)
        
        if num_samples_to_add < self.window_size:
            additional_samples = torch.zeros((B, num_samples_to_add, C), device=x.device)
            x = torch.cat((x, additional_samples), dim=1)
            mask_additional_samples = torch.zeros((B, num_samples_to_add, mask.shape[2]), device=mask.device)
            mask = torch.cat((mask, mask_additional_samples), dim=1)
            L += num_samples_to_add

        # zero out padding transactions
        x = x * mask
        
        assert L >= self.window_size, f'input length ({L}) must be >= window size ({self.window_size})'
        assert L % self.window_size == 0, f'input length ({L}) must be divisible by window size ({self.window_size})'

        shortcut = x
        x = self.norm1(x)

        # shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=-self.shift_size, dims=1) # cyclic shift 
            if not self.start_end_fusion:
                shifted_x[:, -self.shift_size:] = 0. # zero out invalid embs
            mask = torch.roll(mask, shifts=-self.shift_size, dims=1) # cyclic shift of the mask
            if not self.start_end_fusion:
                mask[:, -self.shift_size:] = 0.
        else:
            shifted_x = x
        
        # partition
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, C
        mask = window_partition(mask, self.window_size) # nW*B, window_size, 1
        
        # calculate attn_mask: 1 where both query and key positions are valid
        attn_mask = (mask @ mask.transpose(-2, -1)) # nW*B, window_size, window_size
        
        if self.decoder:
            # keep only the lower triangle (incl. diagonal): no attention to later positions in the window
            no_look_ahead_attn_mask = 1. - torch.triu(torch.ones_like(attn_mask), diagonal=1)
            attn_mask *= no_look_ahead_attn_mask
        
        # multiplicative 0/1 copy, applied to the attention probabilities after softmax
        attn_mask_real = attn_mask.clone().detach()
        attn_mask_real = attn_mask_real.view(attn_mask_real.shape[0], self.window_size, self.window_size).unsqueeze(1).expand(-1, self.num_heads, -1, -1) # B*nW, nH, window_size, window_size
        
        # additive version: 0 = attend, -inf = blocked
        attn_mask[attn_mask == 0.] = -torch.inf
        attn_mask[attn_mask == 1.] = 0.
        # always allow self-attention so fully masked rows do not turn into NaN after softmax;
        # those rows are zeroed afterwards by attn_mask_real
        attn_mask[:, torch.arange(attn_mask.shape[-1]), torch.arange(attn_mask.shape[-1])] = 0.
        attn_mask = attn_mask.view(B, attn_mask.shape[0] // B, self.window_size, self.window_size).unsqueeze(2).expand(-1, -1, self.num_heads, -1, -1) # B, nW, nH, window_size, window_size
        
        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask_add=attn_mask, mask_mult=attn_mask_real)  # nW*B, window_size, C
        
        # merge windows
        shifted_x = window_reverse(attn_windows, self.window_size, L)  # (B, L, C)

        # reverse zero-padding shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=self.shift_size, dims=1) # cyclic shift
            if not self.start_end_fusion:
                x[:, :self.shift_size] = 0. # zero out invalid embs
        else:
            x = shifted_x

        x = shortcut + self.drop_path(x)

        # FFN
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        
        # the payload keeps the padded length L; seq_lens is unchanged
        return PaddedBatch(x, seq_lens)

    def extra_repr(self) -> str:
        return f"dim={self.dim}, num_heads={self.num_heads}, " \
               f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"

class SwinTransformerLayer(nn.Module):
    """ A basic Swin Transformer layer for one stage.

    Stacks ``depth`` ``SwinTransformerBlock``s that alternate between regular windows
    (even block index, ``shift_size=0``) and windows shifted by half a window
    (odd block index, ``shift_size=window_size // 2``).

    Args:
        dim (int): Number of input channels.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | list[float] | tuple[float], optional): Stochastic depth rate. A scalar is
            shared by all blocks; a list or tuple gives one rate per block (at least ``depth`` entries
            are required). Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        decoder (bool, optional): Flag that shows whether blocks in this layer are decoder-like. True => decoder-like; False => encoder-like. Default: False
        start_end_fusion (bool, optional): Flag that shows if the last and the first half-windows should merge (True) or not (False).
    """

    def __init__(
        self,
        dim: int,
        depth: int,
        num_heads: int,
        window_size: int,
        mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
        drop_path=0., norm_layer=nn.LayerNorm,
        decoder=False, start_end_fusion=True
    ):
        super().__init__()
        self.dim = dim
        self.depth = depth
        self.num_heads = num_heads
        self.window_size = window_size

        # Per-block stochastic-depth rates: both lists and tuples are indexed per block,
        # a scalar is shared by every block.
        if isinstance(drop_path, (list, tuple)):
            drop_path_rates = list(drop_path)
        else:
            drop_path_rates = [drop_path] * depth

        # build blocks: even index -> regular windows, odd index -> half-window cyclic shift
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(dim=dim,
                                 num_heads=num_heads, window_size=window_size,
                                 shift_size=0 if (i % 2 == 0) else window_size // 2,
                                 mlp_ratio=mlp_ratio,
                                 qkv_bias=qkv_bias, qk_scale=qk_scale,
                                 drop=drop, attn_drop=attn_drop,
                                 drop_path=drop_path_rates[i],
                                 norm_layer=norm_layer,
                                 decoder=decoder,
                                 start_end_fusion=start_end_fusion)
            for i in range(depth)])

    def forward(self, x):
        """Apply the blocks in order; ``x`` is whatever the blocks accept
        (presumably a ``PaddedBatch`` here — TODO confirm at call sites)."""
        for blk in self.blocks:
            x = blk(x)
        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, depth={self.depth}, num_heads={self.num_heads}, window_size={self.window_size}"


In [None]:
import warnings

import torch

# from ptls.constant_repository import TORCH_EMB_DTYPE
from ptls.data_load.padded_batch import PaddedBatch
from ptls.nn.trx_encoder.batch_norm import RBatchNorm, RBatchNormWithLens
from ptls.nn.trx_encoder.noisy_embedding import NoisyEmbedding
from ptls.nn.trx_encoder.trx_encoder_base import TrxEncoderBase

# Local stand-in for `ptls.constant_repository.TORCH_EMB_DTYPE`, whose import is commented out above;
# TrxEncoderSWIN.forward casts the concatenated transaction embeddings to this dtype.
TORCH_EMB_DTYPE = torch.float32

class TrxEncoderSWIN(TrxEncoderBase):
    """Network layer which makes representation for single transactions and then mixes
    neighbouring transactions with a SWIN transformer layer (`self.swin_encoder`).

     Input is `PaddedBatch` with ptls-format dictionary, with feature arrays of shape (B, T)
     The concatenated transaction embeddings of shape (B, T, H) are wrapped in a `PaddedBatch`
     and passed through `self.swin_encoder`; its result is returned (presumably a `PaddedBatch`
     of the same shape — TODO confirm against `SwinTransformerV2Layer`)
     where:
        B - batch size, sequence count in batch
        T - sequence length
        H - hidden size, representation dimension (`self.output_size`)

    `ptls.nn.trx_encoder.noisy_embedding.NoisyEmbedding` implementation is used for categorical features.

    Parameters
        embeddings:
            dict with categorical feature names.
            Values must be like this `{'in': dictionary_size, 'out': embedding_size}`
            These features will be encoded with lookup embedding table of shape (dictionary_size, embedding_size)
            Entries with `'disabled': True`, or with a zero `in`/`out`, are skipped.
            Only dict values are supported here: each one is wrapped into a `NoisyEmbedding`.
        numeric_values:
            dict with numerical feature names.
            Values must be a string with scaler_name.
            Possible values are: 'identity', 'sigmoid', 'log', 'year'.
            These features will be scaled with selected scaler.
            Values can be `ptls.nn.trx_encoder.scalers.BaseScaler` implementation

            One field can have many scalers. In this case key become alias and col name should be in scaler.
            Check `TrxEncoderBase.numeric_values` for more details

        embeddings_noise (float):
            Noise level for embedding. `0` means without noise
        emb_dropout (float):
            Probability of an element of embedding to be zeroed
        spatial_dropout (bool):
            Whether to dropout full dimension of embedding in the whole sequence

        use_batch_norm:
            True - All numerical values will be normalized after scaling
            False - No normalizing for numerical values
        use_batch_norm_with_lens:
            True - Respect seq_lens during batch_norm. Padding zeroes will be ignored
            False - Batch norm over all time axis. Padding zeroes will be included.

        orthogonal_init:
            if True then `torch.nn.init.orthogonal_` is applied to the embedding tables
            (except the padding row) and to the linear projection head
        linear_projection_size:
            Linear layer at the end will be added for non-zero value

        out_of_index:
            How to process a categorical indexes which are greater than dictionary size.
            'clip' - values will be collapsed to maximum index. This works well for frequency encoded categories.
                We join infrequent categories to one.
            'assert' - raise an error if an invalid index appears.

        norm_embeddings: if truthy, embeddings are created with `max_norm=1`. Keep default value for this parameter
        clip_replace_value: Not used. Keep default value for this parameter
        positions: Not used. Keep default value for this parameter

        swin_num_heads: Number of attention heads of the SWIN layer. Default: 5
        swin_depth: Number of blocks of the SWIN layer. Default: 3
        swin_window_size: Local window size of the SWIN layer. Default: 15
        swin_dim:
            Channel size of the SWIN layer. `None` (default) uses the width of the tensor the SWIN
            layer receives: `linear_projection_size` if it is set, otherwise the concatenated
            embedding size.

    Examples:
        >>> B, T = 5, 20
        >>> trx_encoder = TrxEncoderSWIN(
        >>>     embeddings={'mcc_code': {'in': 100, 'out': 5}},
        >>>     numeric_values={'amount': 'log'},
        >>>     swin_num_heads=2, swin_depth=2, swin_window_size=4,
        >>> )
        >>> x = PaddedBatch(
        >>>     payload={
        >>>         'mcc_code': torch.randint(0, 99, (B, T)),
        >>>         'amount': torch.randn(B, T),
        >>>     },
        >>>     length=torch.randint(10, 20, (B,)),
        >>> )
        >>> z = trx_encoder(x)  # presumably z.payload.shape == (5, 20, 6)  # B, T, H
    """
    def __init__(self,
                 embeddings=None,
                 numeric_values=None,
                 custom_embeddings=None,
                 embeddings_noise: float = 0,
                 norm_embeddings=None,
                 use_batch_norm=False,
                 use_batch_norm_with_lens=False,
                 clip_replace_value=None,
                 positions=None,
                 emb_dropout=0,
                 spatial_dropout=False,
                 orthogonal_init=False,
                 linear_projection_size=0,
                 out_of_index: str = 'clip',
                 swin_num_heads: int = 5,
                 swin_depth: int = 3,
                 swin_window_size: int = 15,
                 swin_dim=None,
                 ):
        if clip_replace_value is not None:
            warnings.warn('`clip_replace_value` attribute is deprecated. Always "clip to max" used. '
                          'Use `out_of_index="assert"` to avoid categorical values clip', DeprecationWarning)

        if positions is not None:
            warnings.warn('`positions` is deprecated. positions is not used', UserWarning)

        if embeddings is None:
            embeddings = {}
        if custom_embeddings is None:
            custom_embeddings = {}

        noisy_embeddings = {}
        for emb_name, emb_props in embeddings.items():
            if emb_props.get('disabled', False):
                continue
            if emb_props['in'] == 0 or emb_props['out'] == 0:
                continue
            noisy_embeddings[emb_name] = NoisyEmbedding(
                num_embeddings=emb_props['in'],
                embedding_dim=emb_props['out'],
                padding_idx=0,
                max_norm=1 if norm_embeddings else None,
                noise_scale=embeddings_noise,
                dropout=emb_dropout,
                spatial_dropout=spatial_dropout,
            )

        super().__init__(
            embeddings=noisy_embeddings,
            numeric_values=numeric_values,
            custom_embeddings=custom_embeddings,
            out_of_index=out_of_index,
        )

        if swin_dim is None:
            # The SWIN layer runs on the (optionally projected) concatenated embeddings,
            # so its channel size must equal that width rather than a fixed constant.
            swin_dim = linear_projection_size if linear_projection_size > 0 else super().output_size
        self.swin_encoder = SwinTransformerV2Layer(
            num_heads=swin_num_heads,
            depth=swin_depth,
            dim=swin_dim,
            window_size=swin_window_size,
        )

        custom_embedding_size = self.custom_embedding_size
        if use_batch_norm and custom_embedding_size > 0:
            # :TODO: Should we use Batch norm with not-numerical custom embeddings?
            if use_batch_norm_with_lens:
                self.custom_embedding_batch_norm = RBatchNormWithLens(custom_embedding_size)
            else:
                self.custom_embedding_batch_norm = RBatchNorm(custom_embedding_size)
        else:
            self.custom_embedding_batch_norm = None

        if linear_projection_size > 0:
            self.linear_projection_head = torch.nn.Linear(super().output_size, linear_projection_size)
        else:
            self.linear_projection_head = None

        if orthogonal_init:
            for n, p in self.named_parameters():
                if n.startswith('embeddings.') and n.endswith('.weight'):
                    # row 0 is the padding index, keep it untouched
                    torch.nn.init.orthogonal_(p.data[1:])
                if n == 'linear_projection_head.weight':
                    torch.nn.init.orthogonal_(p.data)

    def forward(self, x: PaddedBatch, names=None, seq_len=None):
        """Embed a batch of transactions and mix them with the SWIN layer.

        Args:
            x: a `PaddedBatch`, or a sequence of per-field feature arrays ordered like `names`.
            names: field names, required when `x` is not a `PaddedBatch`.
            seq_len: sequence lengths used when `x` is not a `PaddedBatch`.

        Returns:
            The result of `self.swin_encoder` applied to a `PaddedBatch` of embeddings (B, T, H).

        Raises:
            TypeError: if `x` is not a `PaddedBatch` and `names` is not given.
        """
        if not isinstance(x, PaddedBatch):
            if names is None:
                raise TypeError('`names` is required when `x` is not a PaddedBatch')
            x = PaddedBatch({field_name: x[i] for i, field_name in enumerate(names)}, seq_len)

        processed_embeddings = [self.get_category_embeddings(x, field_name)
                                for field_name in self.embeddings.keys()]
        processed_custom_embeddings = [self.get_custom_embeddings(x, field_name)
                                       for field_name in self.custom_embeddings.keys()]
        if len(processed_custom_embeddings):
            processed_custom_embeddings = torch.cat(processed_custom_embeddings, dim=2)
            if self.custom_embedding_batch_norm is not None:
                processed_custom_embeddings = PaddedBatch(processed_custom_embeddings, x.seq_lens)
                processed_custom_embeddings = self.custom_embedding_batch_norm(processed_custom_embeddings)
                processed_custom_embeddings = processed_custom_embeddings.payload
            processed_embeddings.append(processed_custom_embeddings)

        out = torch.cat(processed_embeddings, dim=2)
        out = out.type(TORCH_EMB_DTYPE)
        out = self.linear_projection_head(out) if self.linear_projection_head is not None else out
        out = PaddedBatch(out, x.seq_lens)
        out = self.swin_encoder(out)
        return out

    @property
    def output_size(self):
        """Returns hidden size of output representation
        """
        if self.linear_projection_head is not None:
            return self.linear_projection_head.out_features
        return super().output_size

from ptls.nn.seq_encoder.rnn_encoder import RnnEncoder
from ptls.nn.seq_encoder.containers import SeqEncoderContainer


class SwinTransformerBackbone(nn.Module):
    """ Swin Transformer Backbone: a stack of ``SwinTransformerLayer`` stages (the original 2D
    implementation uses 4 stages; any number of stages is accepted here).

    Every stage is built with the same ``dim``; only the window size may grow from stage to stage.

    Args:
        dim (int): Number of input channels.
        depths (list[int]): Numbers of blocks in stages.
        num_heads (int | list[int]): Number of attention heads in W-MSA layers, either shared by all
            stages (int) or one entry per stage (same length as ``depths``).
        start_window_size (int): Local window size of stage 1.
        window_size_mult (int): the number by which the `window_size` is being multiplied when moving to another stage
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | list[float] | tuple[float], optional): Stochastic depth rate. A scalar is shared
            by every block; a sequence gives one rate per block across all stages (``sum(depths)``
            entries) and is split into consecutive per-stage slices. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        decoder (bool, optional): Flag that shows whether blocks in this backbone are decoder-like. True => decoder-like; False => encoder-like. Default: False
        start_end_fusion (bool, optional): Flag that shows if the last and the first half-windows should merge (True) or not (False).

    Raises:
        ValueError: if a sequence ``num_heads`` or ``drop_path`` does not match ``depths``.
    """
    def __init__(
        self,
        dim: int,
        depths: list[int],
        num_heads,
        start_window_size: int,
        window_size_mult: int = 1,
        mlp_ratio=4.,
        qkv_bias=True,
        qk_scale=None,
        drop=0.,
        attn_drop=0.,
        drop_path=0.,
        norm_layer=nn.LayerNorm,
        decoder=False,
        start_end_fusion=True
    ):
        super().__init__()
        self.dim = dim
        self.depths = depths
        num_stages = len(self.depths)

        # One head count per stage: an int is broadcast, a sequence must match `depths`.
        if isinstance(num_heads, int):
            self.num_heads = [num_heads] * num_stages
        else:
            if len(num_heads) != num_stages:
                raise ValueError(
                    f"num_heads has {len(num_heads)} entries but depths has {num_stages} stages"
                )
            self.num_heads = num_heads

        # Window sizes grow geometrically: start, start * mult, start * mult**2, ...
        self.window_sizes = [start_window_size]
        for _ in range(num_stages - 1):
            self.window_sizes.append(self.window_sizes[-1] * window_size_mult)

        stage_drop_paths = self._split_drop_path(drop_path, self.depths)

        # build model: one SwinTransformerLayer per stage
        self.backbone = nn.ModuleList([
            SwinTransformerLayer(dim=self.dim,
                                 depth=self.depths[i],
                                 num_heads=self.num_heads[i],
                                 window_size=self.window_sizes[i],
                                 mlp_ratio=mlp_ratio,
                                 qkv_bias=qkv_bias,
                                 qk_scale=qk_scale,
                                 drop=drop,
                                 attn_drop=attn_drop,
                                 drop_path=stage_drop_paths[i],
                                 norm_layer=norm_layer,
                                 decoder=decoder,
                                 start_end_fusion=start_end_fusion)
            for i in range(num_stages)])

    @staticmethod
    def _split_drop_path(drop_path, depths):
        """Return the ``drop_path`` value for each stage: a scalar is shared, a per-block
        sequence over all stages is cut into consecutive per-stage lists."""
        if not isinstance(drop_path, (list, tuple)):
            return [drop_path] * len(depths)
        rates = list(drop_path)
        if len(rates) != sum(depths):
            raise ValueError(
                f"drop_path has {len(rates)} entries but depths describe {sum(depths)} blocks"
            )
        stage_rates, start = [], 0
        for depth in depths:
            stage_rates.append(rates[start:start + depth])
            start += depth
        return stage_rates

    def forward(self, x):
        """Run the stages in order and return the output of the last one."""
        for layer in self.backbone:
            x = layer(x)
        return x


class SWIN_RNN_SeqEncoder(SeqEncoderContainer):
    """SeqEncoderContainer with SWIN transformer backbone for features hierarchic fusion and RnnEncoder for feature aggregation.

    Pipeline: `self.trx_encoder` -> `self.swin_fusion` (SwinTransformerBackbone) -> `self.seq_encoder` (RnnEncoder).

    Parameters
        trx_encoder:
            TrxEncoder object. Its `output_size` is used as the SWIN backbone width, so it must be given.
        input_size:
            input_size parameter for RnnEncoder
            If None: input_size = trx_encoder.output_size
            Set input_size explicitly or use None if your trx_encoder object has output_size attribute
            NOTE(review): the RNN consumes the SWIN output, so this presumably has to equal
            trx_encoder.output_size — confirm.
        is_reduce_sequence:
            False - returns PaddedBatch with all transactions embeddings
            True - returns one embedding per sequence, as reduced by RnnEncoder
        swin_depths: Numbers of blocks in stages (SWIN backbone). Default: None (no SWIN stages)
        swin_num_heads: Number of attention heads in W-MSA layers (SWIN backbone), an int or one value per stage.
        swin_start_window_size: Local window size of stage 1 (SWIN backbone).
        swin_window_size_mult (int): the number by which the `window_size` is being multiplied when moving to another stage (SWIN backbone).
        swin_drop: Dropout rate (SWIN backbone). Default: 0.0
        swin_attn_drop: Attention dropout rate (SWIN backbone). Default: 0.0
        swin_drop_path: Stochastic depth rate (SWIN backbone). Default: 0.0
        swin_decoder: Flag that shows whether blocks in SWIN backbone are decoder-like. True => decoder-like; False => encoder-like. Default: False
        swin_start_end_fusion: Flag that shows if the last and the first half-windows should merge (True) or not (False). Must be False for CPC and GPT.
        **rnn_seq_encoder_params:
            RnnEncoder params
    """
    def __init__(self,
                 trx_encoder=None,
                 input_size=None,
                 is_reduce_sequence=True,
                 swin_depths=None,
                 swin_num_heads=4,
                 swin_start_window_size=4,
                 swin_window_size_mult=1,
                 swin_drop=0.,
                 swin_attn_drop=0.,
                 swin_drop_path=0.,
                 swin_decoder=False,
                 swin_start_end_fusion=True,
                 **rnn_seq_encoder_params
                 ):
        super().__init__(
            trx_encoder=trx_encoder,
            seq_encoder_cls=RnnEncoder,
            input_size=input_size,
            seq_encoder_params=rnn_seq_encoder_params,
            is_reduce_sequence=is_reduce_sequence,
        )
        # Fresh list per instance: the backbone keeps a reference to `depths`,
        # so a shared mutable default would be aliased across encoders.
        if swin_depths is None:
            swin_depths = []
        self.swin_fusion = SwinTransformerBackbone(
                               dim=trx_encoder.output_size,
                               depths=swin_depths,
                               num_heads=swin_num_heads,
                               start_window_size=swin_start_window_size,
                               window_size_mult=swin_window_size_mult,
                               drop=swin_drop,
                               attn_drop=swin_attn_drop,
                               drop_path=swin_drop_path,
                               decoder=swin_decoder,
                               start_end_fusion=swin_start_end_fusion
                              )

    def forward(self, x, names=None, seq_len=None, h_0=None):
        """Encode `x` with trx_encoder, fuse with the SWIN backbone, aggregate with the RNN.

        `names` and `seq_len` are accepted but not used; `x` is passed to `self.trx_encoder`
        unchanged. `h_0` is forwarded to the RNN as its initial state.
        """
        x = self.trx_encoder(x)
        x = self.swin_fusion(x)
        x = self.seq_encoder(x, h_0)
        return x

In [None]:
from ptls.nn import TabFormerFeatureEncoder
from ptls.nn import TransformerEncoder
from ptls.nn import TrxEncoder
import ptls

# Transaction encoder: all five fields are embedded as categoricals of width 16
# (transaction_amt / days_from_prev_tr are presumably bucketised in an earlier cell — TODO confirm),
# so the encoder output should be 5 * 16 = 80 wide.
trx_encoder = ptls.nn.TrxEncoder(
    embeddings_noise=0.003,
    embeddings={
        'transaction_amt': {'in': 11, 'out': 16},
        'days_from_prev_tr': {'in': 10, 'out': 16},
        'currency_rk': {'in': 5, 'out': 16},
        'day_of_week': {'in': 8, 'out': 16},
        'mcc_code': {'in': 333, 'out': 16},
    },
)

# AdamW; StepLR multiplies the learning rate by 0.9025 every 2 scheduler steps.
optimizer_partial = partial(torch.optim.AdamW, lr=1e-3, weight_decay=1e-4)
lr_scheduler_partial = partial(torch.optim.lr_scheduler.StepLR, step_size=2, gamma=0.9025)

# 4-stage SWIN backbone (windows 4, 8, 16, 32; heads 2, 4, 8, 16) in decoder mode
# (no-look-ahead mask inside each window), no start/end window fusion, then a 512-unit GRU.
seq_encoder = SWIN_RNN_SeqEncoder(
    trx_encoder=trx_encoder,
    swin_depths=[2, 2, 6, 2],
    swin_num_heads=[2, 4, 8, 16],
    swin_start_window_size=4,
    swin_window_size_mult=2,
    swin_drop=0.1,
    swin_attn_drop=0.1,
    swin_drop_path=0.1,
    swin_decoder=True,
    swin_start_end_fusion=False,
    hidden_size=512,
    type="gru",
)

pl_module = ptls.frames.coles.CoLESModule(
    seq_encoder=seq_encoder,
    optimizer_partial=optimizer_partial,
    lr_scheduler_partial=lr_scheduler_partial,
)

In [None]:
from lightning.pytorch.loggers import WandbLogger

# Weights & Biases logger for this run (project "MBD_My_Code", run "df24_coles_swin").
wandb_logger = WandbLogger(project="MBD_My_Code", log_model="all", name="df24_coles_swin")

# NOTE(review): `WandbLogger` is imported from `lightning.pytorch`, while `pl` comes from an earlier
# cell — presumably `pytorch_lightning`, which ptls modules subclass. Confirm both come from the same
# Lightning package; mixing the two namespaces is a common source of type errors.
trainer = pl.Trainer(
    logger=wandb_logger,
    max_epochs=20,
    accelerator="cuda" if torch.cuda.is_available() else "cpu",  # fall back to CPU when no GPU is visible
    enable_progress_bar=True,
    gradient_clip_val=0.3,  # gradient clipping threshold
    log_every_n_steps=50,
    limit_val_batches=32  # cap on validation batches (only relevant if a val dataloader exists)
)

In [None]:
# NOTE(review): only a training dataloader is passed, so validation (and the trainer's
# `limit_val_batches` cap) only applies if `pl_module` supplies a val dataloader itself — confirm.
trainer.fit(pl_module, train_dl)

In [None]:
# Embed every user served by `inference_dl` with the trained encoder.
inference_module = InferenceModule(
    model=pl_module,
    pandas_output=True,
    drop_seq_features=True,
)

predict = trainer.predict(inference_module, inference_dl)

inference_embeddings = pd.concat(predict, axis=0)

# Attach the downstream target from `test_` (the held-out users written to data/test_ids.csv).
# Every embedded user is expected to match exactly one target row: `validate` rejects duplicated
# user_ids, and the assertion rejects users without a target (NaN after the left join), which
# would otherwise leak silently into the downstream split.
inference_embeddings_t = inference_embeddings.merge(
    test_[['user_id', 'target']], how='left', on='user_id', validate='one_to_one'
)
assert inference_embeddings_t['target'].notna().all(), 'some embedded users have no target in test_'

inf_emb_train, inf_emb_test = train_test_split(inference_embeddings_t, random_state=42, test_size=0.2)

X_train, y_train = inf_emb_train.drop(columns=['user_id', 'target']), inf_emb_train['target']
X_test, y_test = inf_emb_test.drop(columns=['user_id', 'target']), inf_emb_test['target']

In [None]:
from lightgbm import LGBMClassifier

# Downstream classifier on the frozen embeddings. Parameters use the scikit-learn-API names
# (colsample_bytree / reg_alpha / reg_lambda / min_child_samples) rather than the native LightGBM
# aliases (feature_fraction / lambda_l1 / lambda_l2 / min_data_in_leaf). The resolved values are the
# same, but no alias conflicts with the wrapper's own defaults arise, and get_params() reports what is used.
down_model = LGBMClassifier(
    n_estimators=700,
    boosting_type='gbdt',
    subsample=0.5,          # row bagging fraction ...
    subsample_freq=1,       # ... re-drawn every iteration
    learning_rate=0.02,
    colsample_bytree=0.75,  # fraction of features per tree
    max_depth=6,
    reg_alpha=1,            # L1 regularization
    reg_lambda=1,           # L2 regularization
    min_child_samples=50,   # minimum samples per leaf
    random_state=42,
    n_jobs=8,
    verbose=-1,
)

In [None]:
%%time

down_model.fit(X_train, y_train)

from sklearn.metrics import accuracy_score, roc_auc_score, f1_score, classification_report

# Separate names for probabilities and hard labels, so the inference cell's `predict` is not overwritten.
proba_test = down_model.predict_proba(X_test)
print(f"ROC-AUC target = {roc_auc_score(y_test, proba_test[:, 1])}")

pred_test = down_model.predict(X_test)
print(f"accuracy = {accuracy_score(y_test, pred_test)}")

# Per-class precision / recall / f1, in the format quoted in the Results section.
print(classification_report(y_test, pred_test))

# Results

**Baseline (CoLES)**

ROC-AUC target = 0.6885229034842104

accuracy = 0.91

              precision    recall  f1-score   support

           0       0.91      1.00      0.95      2326
           
           1       0.69      0.05      0.09       234

    accuracy                           0.91      2560
    macro avg       0.80      0.52      0.52      2560
    weighted avg       0.89      0.91      0.87      2560


**CoLES + regional-attention**

ROC-AUC target = 0.705245790800391

accuracy = 0.91

              precision    recall  f1-score   support

           0       0.91      1.00      0.95      2326
           1       0.58      0.05      0.09       234

    accuracy                           0.91      2560
    macro avg       0.75      0.52      0.52      2560
    weighted avg       0.88      0.91      0.87      2560

**CoLES + cross-attention**

small patches = 4

large patches = 15

ROC-AUC target = 0.6635

accuracy = 0.90

**CoLES + cross-attention**

small patches = 5

large patches = 16

ROC-AUC target = 0.6513

accuracy = 0.90

**CoLES + cross-attention**

small patches = 3

large patches = 12

ROC-AUC target = 0.6671

accuracy = 0.91

**GPT Baseline**

ROC-AUC target = 0.7306271092054157

accuracy = 0.9293908996750897

**GPT with SWIN encoder**

ROC-AUC target = 0.7012497152222003

accuracy = 0.909765625