In [None]:
#Commands for downgrading Python version in Kaggle notebook

#adds external APT repository
!add-apt-repository -y 'ppa:deadsnakes/ppa'

#installs specific python version
!apt install -y python3.10

#installs distutils package for the specific python version
!apt-get install -y python3.10-distutils

#Install Python modules using the newly installed Python version
#-I, Ignores and overwrites the installed packages.
!sudo /usr/bin/python3.10 -m pip install -Iv <package-name>

#Run Python scripts using the newly installed Python version
!/usr/bin/python3.10 <python_script.py>

In [None]:
!pip install lightning
!pip install pyspark
!pip install pytorch-lifestream

In [None]:
import os

import wandb

# Read the W&B API key from the environment instead of hard-coding it in the notebook
# (on Kaggle: store it as a secret and export it as WANDB_API_KEY before this cell).
# If the variable is unset, `key=None` lets wandb use its normal credential lookup.
# NOTE(review): the key that was previously hard-coded here has been published with the
# notebook and should be revoked in the W&B account settings.
wandb.login(key=os.environ.get("WANDB_API_KEY"))

In [None]:
from huggingface_hub import hf_hub_download 

# Download the `ptls.tar.gz` and `targets.tar.gz` archives of the ai-lab/MBD-mini
# dataset repo into /kaggle/working/; they are extracted in the next cell.
hf_hub_download(repo_id="ai-lab/MBD-mini", filename="ptls.tar.gz", repo_type="dataset", local_dir="/kaggle/working/")
hf_hub_download(repo_id="ai-lab/MBD-mini", filename="targets.tar.gz", repo_type="dataset", local_dir="/kaggle/working/")

In [None]:
!tar -xf ptls.tar.gz
!tar -xf targets.tar.gz

In [None]:
import pandas as pd
import numpy as np
import os

import pyspark
from pyspark.sql import SparkSession
# import pyspark.sql.functions as F
from pyspark.sql import types as T
import time
import datetime
from ptls.data_load.datasets import ParquetDataset, ParquetFiles
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, FloatType, ArrayType
from tqdm.notebook import tqdm
from ptls.preprocessing import PysparkDataPreprocessor
import pytorch_lightning as pl
from ptls.data_load.datasets import MemoryMapDataset
from ptls.data_load.iterable_processing import SeqLenFilter, FeatureFilter
from ptls.data_load.iterable_processing.iterable_seq_len_limit import ISeqLenLimit
from ptls.data_load.iterable_processing.to_torch_tensor import ToTorch
from ptls.frames.coles import CoLESModule
from ptls.frames import PtlsDataModule
from ptls.frames.coles import ColesDataset
from ptls.frames.coles.split_strategy import SampleSlices
import torch
import numpy as np
import pandas as pd
import calendar
from glob import glob
from ptls.data_load.utils import collate_feature_dict

from ptls.data_load.iterable_processing_dataset import IterableProcessingDataset
from datetime import datetime
from ptls.data_load.padded_batch import PaddedBatch

In [None]:
import random

# Seed every RNG used in this notebook for reproducible runs: Python's `random`
# (previously left unseeded), NumPy's global generator and torch (CPU and all CUDA devices).
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

In [None]:
# Local Spark session used to join the per-fold transactions with the targets.
spark_conf = pyspark.SparkConf()
spark_conf.setMaster("local[*]").setAppName("JoinModality")
spark_conf.set("spark.driver.maxResultSize", "16g")
# NOTE(review): with a `local[*]` master everything runs in the driver JVM, so the
# executor-* and spark.cores.max settings are presumably without effect — confirm.
spark_conf.set("spark.executor.memory", "24g")
spark_conf.set("spark.executor.memoryOverhead", "16g")
spark_conf.set("spark.driver.memory", "16g")
spark_conf.set("spark.driver.memoryOverhead", "16g")
spark_conf.set("spark.cores.max", "4")
spark_conf.set("spark.sql.shuffle.partitions", "200")
# NOTE(review): relative scratch dir; if the cwd is /kaggle/working it resolves to
# /spark_local_dir — confirm that location is writable.
spark_conf.set("spark.local.dir", "../../spark_local_dir")


spark = SparkSession.builder.config(conf=spark_conf).getOrCreate()
# Display the effective configuration of the session.
spark.sparkContext.getConf().getAll()

In [None]:
!mkdir /kaggle/working/ptls/trx_supervised

In [None]:
# Inputs are read per fold from `<path>/fold=<k>` parquet directories (see the loop in the
# next cell); the joined output is written under TRX_SUPERVISED_PATH with the same layout.
TARGETS_DATA_PATH = '/kaggle/working/targets/'
TRX_DATA_PATH = '/kaggle/working/ptls/trx/'
TRX_SUPERVISED_PATH = '/kaggle/working/ptls/trx_supervised/'

# Preprocessor for the target tables: records are keyed by client_id, the "mon" column is
# turned into a timestamp event time, and the four target columns are passed as
# cols_identity. Presumably this yields one record per client with per-month target
# sequences (GetSplit later indexes targets by month) — confirm with the ptls docs.
preprocessor_target = PysparkDataPreprocessor(
    col_id="client_id",
    col_event_time="mon",
    event_time_transformation="dt_to_timestamp",
    cols_identity=["target_1", "target_2", "target_3", "target_4"],
)

In [None]:
# For each of the 5 folds: attach the targets to the transactions and save the result.
for fold in range(5):
    targets = spark.read.parquet(os.path.join(TARGETS_DATA_PATH , f'fold={fold}'))
    trx = spark.read.parquet(os.path.join(TRX_DATA_PATH , f'fold={fold}'))
    
    # NOTE(review): the preprocessor is re-fitted on every fold.
    # Its event_time / trans_count / diff_trans_date output columns are dropped before the join.
    targets = preprocessor_target.fit_transform(targets).drop(*['event_time' ,'trans_count', 'diff_trans_date'])
    # Left join keeps every transaction client; clients without targets get null targets.
    trx = trx.join(targets, on='client_id', how='left')
    trx.write.mode('overwrite').parquet(os.path.join(TRX_SUPERVISED_PATH, f'fold={fold}'))

In [None]:
spark.stop()

In [None]:
import pandas as pd
import numpy as np
from ptls.data_load.iterable_processing_dataset import IterableProcessingDataset
from ptls.data_load import IterableChain
from datetime import datetime
from ptls.data_load.datasets.parquet_dataset import ParquetDataset, ParquetFiles
from ptls.data_load.iterable_processing.feature_filter import FeatureFilter
from ptls.data_load.iterable_processing.to_torch_tensor import ToTorch
import torch
from functools import partial
from torch.utils.data import DataLoader
from ptls.data_load.padded_batch import PaddedBatch
from ptls.data_load.utils import collate_feature_dict
from tqdm import tqdm
import ptls
import torch.nn.functional as F

class TargetToTorch(IterableProcessingDataset):
    """Convert the ``col_target`` feature of every record into a torch tensor.

    The value is first combined with ``np.stack`` (so a sequence of equally shaped
    arrays becomes one array) and then wrapped with ``torch.tensor``. The record dict is
    updated in place. For tuple records only the first element is updated and yielded.
    """

    def __init__(self, col_target):
        super().__init__()
        self.col_target = col_target

    def __iter__(self):
        key = self.col_target
        for record in self._src:
            if type(record) is tuple:
                record = record[0]
            record[key] = np.stack(np.array(record[key]))
            record[key] = torch.tensor(record[key])
            yield record

class DeleteNan(IterableProcessingDataset):
    """Skip records whose ``col_name`` feature is ``None``.

    Only ``None`` is filtered; float NaN values pass through despite the class name.
    Tuple records are unwrapped and only their first element is yielded.
    """

    def __init__(self, col_name):
        super().__init__()
        self.col_name = col_name

    def __iter__(self):
        for record in self._src:
            if type(record) is tuple:
                features = record[0]
            else:
                features = record
            if features[self.col_name] is None:
                continue
            yield features


class DialToTorch(IterableProcessingDataset):
    """Fill missing ``col_time`` / ``col_embeds`` values and convert embeddings to a tensor.

    Each record (tuple records are unwrapped to their first element) is shallow-copied.
    A ``None`` time becomes ``torch.tensor([0])`` and a ``None`` embedding becomes
    ``torch.zeros(768)``. The embedding value is then rebuilt as
    ``torch.tensor(value.tolist())``.
    """

    def __init__(self, col_time, col_embeds):
        super().__init__()
        self._year = 2022  # stored but not used by this class
        self.col_embeds = col_embeds
        self.col_time = col_time

    def __iter__(self):
        time_key = self.col_time
        embeds_key = self.col_embeds
        for record in self._src:
            source = record[0] if type(record) is tuple else record
            features = source.copy()
            if features[time_key] is None:
                features[time_key] = torch.tensor([0])
            if features[embeds_key] is None:
                features[embeds_key] = torch.zeros(768)
            # The key is guaranteed to exist (it was read just above), so convert it directly.
            features[embeds_key] = torch.tensor(features[embeds_key].tolist())
            yield features

class GetSplit(IterableProcessingDataset):
    """Expand every client record into one record per month in ``[start_month, end_month]``.

    For month ``m``, the emitted record:
      * keeps only the events whose ``col_time`` is earlier than the first day of month
        ``m + 1`` of ``year`` (January 1st of ``year + 1`` for December);
      * replaces every feature whose name starts with ``'target'`` by its ``m - 1``-th
        element (converted with ``.tolist()``);
      * gets ``'_month=<m>'`` appended to its ``col_id`` value.

    Tuple records are unwrapped to their first element. Each monthly record is a shallow
    copy, so the source record is not modified.

    NOTE(review): the cut-off includes the events of month ``m`` itself while the target is
    also taken at month ``m``. If targets describe month-``m`` outcomes, this leaks
    same-month information — confirm against the dataset definition.
    NOTE(review): ``datetime(...).timestamp()`` interprets the naive datetime in the
    machine's local timezone; this assumes event times use the same convention.
    """

    def __init__(
        self,
        start_month,
        end_month,
        year=2022,
        col_id='client_id',
        col_time='event_time'
    ):
        super().__init__()
        self.start_month = start_month
        self.end_month = end_month
        self._year = year
        self._col_id = col_id
        self._col_time = col_time

    def _month_cutoffs(self):
        """Return ``[(month, exclusive upper-bound timestamp), ...]`` for the configured months."""
        cutoffs = []
        for month in range(self.start_month, self.end_month + 1):
            if month == 12:
                cutoff = datetime(self._year + 1, 1, 1).timestamp()
            else:
                cutoff = datetime(self._year, month + 1, 1).timestamp()
            cutoffs.append((month, cutoff))
        return cutoffs

    def __iter__(self):
        # Cut-offs depend only on the configuration, so compute them once, not per record.
        cutoffs = self._month_cutoffs()
        for rec in self._src:
            source = rec[0] if type(rec) is tuple else rec
            for month, cutoff in cutoffs:
                features = source.copy()
                mask = features[self._col_time] < cutoff
                for key, tensor in features.items():
                    if key.startswith('target'):
                        features[key] = tensor[month - 1].tolist()
                    elif key != self._col_id:
                        features[key] = tensor[mask]

                features[self._col_id] += '_month=' + str(month)

                yield features

class GetPatches(IterableProcessingDataset):
    """Split every sequence feature into fixed-size patches along dim 0.

    Yields ``(features, patches_dict)``:
      * ``features`` is a shallow copy of the input record (tuple records are unwrapped to
        their first element);
      * ``patches_dict`` maps each feature to a tensor of shape
        ``(n_patches, patches_size, *trailing_dims)``, with the last patch zero-padded.
        ``col_id`` is passed through unchanged.
    """

    def __init__(
        self,
        col_id='client_id',
        patches_size=12
    ):
        super().__init__()
        self._col_id = col_id
        self._patches_size = patches_size

    def _to_patches(self, tensor):
        """Split ``tensor`` into ``patches_size``-long chunks along dim 0, zero-padding the last one."""
        patches = list(torch.split(tensor, self._patches_size, dim=0))
        pad_size = self._patches_size - patches[-1].shape[0]
        if pad_size:
            # `new_zeros` matches dtype, device and trailing dims of the feature. The
            # previous 1-D CPU `torch.Tensor([0] * pad)` made `torch.cat` fail for
            # multi-dimensional features and for tensors that are not on the CPU.
            padding = tensor.new_zeros((pad_size, *tensor.shape[1:]))
            patches[-1] = torch.cat((patches[-1], padding))
        return torch.stack(patches)

    def __iter__(self):
        for rec in self._src:
            features = rec[0] if type(rec) is tuple else rec
            features = features.copy()

            patches_dict = {}
            for key, tensor in features.items():
                if key == self._col_id:
                    patches_dict[key] = tensor
                else:
                    patches_dict[key] = self._to_patches(tensor)

            yield features, patches_dict
                
                


# class Get2Patches(IterableProcessingDataset):
#     def __init__(
#         self,
#         col_id='client_id',
#         small_patches_size=3,
#         large_patches_size=12
#     ):
#         super().__init__()
#         self._col_id = col_id
#         self._small_patches_size = small_patches_size
#         self._large_patches_size = large_patches_size

#     def __iter__(self):
#         for rec in self._src:
#             features = rec[0] if type(rec) is tuple else rec
#             features = features.copy()

#             for key, tensor in features.items():
#                 small_patches = list(torch.split(tensor, self._small_patches_size, dim=1))
#                 large_patches = list(torch.split(tensor, self._small_patches_size, dim=1))
#                 if small_patches[-1].size()[1] != self._small_patches_size:
#                     pad_size = self._small_patches_size - small_patches[-1].size()[1]
#                     small_patches[-1] = F.pad(small_patches[-1], (0, 0, pad_size, 0), "replicate")
    
                
#                 if large_patches[-1].size()[1] != self._large_patches_size:
#                     pad_size = self._large_patches_size - large_patches[-1].size()[1]
#                     large_patches[-1] = F.pad(large_patches[-1], (0, 0, pad_size, 0), "replicate")
    
#                 small_patches = torch._stack(small_patches)
#                 large_patches = torch._stack(large_patches)
            
            
    
#             large_comp_embed = torch.zeros(x_new.size()[0], 128).to(x.device)

class QuantilfyAmount(IterableProcessingDataset):
    """Replace the ``col_amt`` feature with integer bucket indices.

    Each value gets the last index ``i`` (in list order) for which
    ``amount > quantilies[i]``, or 0 if there is none. For an ascending ``quantilies``
    list this is the bucket index. The result is an ``int32`` tensor. The record dict is
    updated in place, and tuple records are unwrapped to their first element.

    Parameters
    ----------
    col_amt:
        Name of the amount feature.
    quantilies:
        Bucket boundaries; defaults to the 10 hard-coded values below.
    """

    def __init__(self, col_amt='amount', quantilies=None):
        super().__init__()
        self.col_amt = col_amt
        if quantilies is None:
            self.quantilies = [0., 267.6, 1198.65, 3667.2, 8639.8, 18325.7, 36713.2, 68950.3, 143969.1, 421719.1]
        else:
            self.quantilies = quantilies

    def __iter__(self):
        for rec in self._src:
            features = rec[0] if type(rec) is tuple else rec
            # Accept numpy arrays / lists as well as tensors: the inference pipeline runs
            # this filter before ToTorch, and `torch.where` rejects a numpy condition.
            amount = torch.as_tensor(features[self.col_amt])
            am_quant = torch.zeros(amount.shape, dtype=torch.int, device=amount.device)
            for i, q in enumerate(self.quantilies):
                # Later boundaries overwrite earlier ones, same as where(amount > q, i, am_quant).
                am_quant = am_quant.masked_fill(amount > q, i)
            features[self.col_amt] = am_quant
            yield features
            

In [None]:
def make_pretrain_filters():
    """Return a fresh ``i_filters`` list for the GPT pre-training datasets.

    The filters are built anew on every call, so the train and valid datasets do not share
    filter instances. The pipeline:
      * keeps sequences of 16..2048 events;
      * drops the id and target columns;
      * clips categorical codes with CategorySizeClip;
      * converts values with ToTorch.
    """
    return [
        ptls.data_load.iterable_processing.SeqLenFilter(min_seq_len=16),
        ptls.data_load.iterable_processing.SeqLenFilter(max_seq_len=2048),
        ptls.data_load.iterable_processing.feature_filter.FeatureFilter(
            drop_feature_names=[
                'client_id',
                'target_1',
                'target_2',
                'target_3',
                'target_4'
            ]
        ),
        # NOTE(review): QuantilfyAmount() is disabled here but enabled in the fold-4
        # inference pipeline. During pre-training 'amount' therefore reaches
        # CategorySizeClip un-quantised, while inference feeds quantile buckets — confirm intent.
        ptls.data_load.iterable_processing.CategorySizeClip(
            category_max_size={
                'event_type': 58,
                'event_subtype': 59,
                'src_type11': 85,
                'src_type12': 349,
                'dst_type11': 84,
                'dst_type12': 417,
                'src_type22': 90,
                'src_type32': 91,
                'amount': 10
            }
        ),
        ptls.data_load.iterable_processing.to_torch_tensor.ToTorch(),
    ]


# Pre-training data: folds 0-2 of the joined data; the target columns are removed by the filters.
train = ptls.data_load.datasets.ParquetDataset(
    data_files=[
        os.path.join(TRX_SUPERVISED_PATH, 'fold=0'),
        os.path.join(TRX_SUPERVISED_PATH, 'fold=1'),
        os.path.join(TRX_SUPERVISED_PATH, 'fold=2'),
    ],
    i_filters=make_pretrain_filters(),
    shuffle_files=True
)

# Validation data: fold 3 from the raw transaction folder (no joined target columns;
# the target names in drop_feature_names are presumably just absent there).
valid = ptls.data_load.datasets.ParquetDataset(
    data_files=[
        os.path.join(TRX_DATA_PATH, 'fold=3')
    ],
    i_filters=make_pretrain_filters()
)

# GPT Baseline

In [None]:
from ptls.frames.gpt.gpt_dataset import GptIterableDataset

# Wrap the pre-training datasets for next-event (GPT-style) training. min_len / max_len
# are forwarded to GptIterableDataset (presumably bounds on the sampled sub-sequence
# length — see the ptls docs). num_workers=0 loads data in the main process.
data_module = ptls.frames.PtlsDataModule(
    train_data=GptIterableDataset(
        data=train,
        min_len=80,
        max_len=300
    ),
    valid_data=GptIterableDataset(
        data=valid,
        min_len=80,
        max_len=300
    ),
    train_batch_size=128,
    train_num_workers=0,
    valid_batch_size=64,
    valid_num_workers=0
)

**GPT Module**

In [None]:
import pytorch_lightning as pl
import torch
from torch import nn
import warnings
from torchmetrics import MeanMetric
from typing import Tuple, Dict, List, Union

from ptls.nn.seq_encoder.abs_seq_encoder import AbsSeqEncoder
from ptls.nn import PBL2Norm
from ptls.data_load.padded_batch import PaddedBatch
from ptls.custom_layers import StatPooling, GEGLU
from ptls.nn.seq_step import LastStepEncoder

class Head(nn.Module):
    """Per-feature classification head: Linear -> GELU -> Dropout -> Linear.

    Maps ``(..., input_size)`` to ``(..., n_classes)`` logits. The layers live in
    ``self.head`` (an ``nn.Sequential``), so state-dict keys are ``head.0.*`` and ``head.3.*``.
    """

    def __init__(self, input_size, n_classes, hidden_size=64, drop_p=0.1):
        super().__init__()
        layers = [
            nn.Linear(input_size, hidden_size, bias=True),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(hidden_size, n_classes),
        ]
        self.head = nn.Sequential(*layers)

    def forward(self, x):
        return self.head(x)

class GptPretrainModule(pl.LightningModule):
    """GPT2-style language model over event sequences.

    The original sequence is encoded by ``trx_encoder``. ``seq_encoder`` produces one output
    per position, and one ``Head`` per ``trx_encoder`` embedding predicts the class of that
    feature at the next position.

    Parameters
    ----------
    trx_encoder:
        Module transforming a dict of feature sequences into a sequence of transaction
        representations. Its ``embeddings`` dict defines which features get a head.
    seq_encoder:
        Sequence model (generally transformer based; an RNN is also possible). It must return
        the whole sequence, so ``is_reduce_sequence`` is forced to False.
    head_hidden_size:
        Hidden size of the per-feature prediction heads.
    total_steps:
        Total number of steps of the OneCycle LR schedule (stepped every optimizer step).
    seed_seq_len:
        Number of leading positions excluded from the loss.
    max_lr:
        Peak learning rate of the OneCycle schedule.
    weight_decay:
        Weight decay of the NAdam optimizer.
    pct_start:
        Fraction of ``total_steps`` during which the learning rate increases.
    norm_predict:
        Apply L2 normalisation (``PBL2Norm``) to the ``seq_encoder`` output.
    inference_pooling_strategy:
        Pooling used by the ``seq_encoder`` inference wrapper:
        'out' - last step of the ``seq_encoder`` output;
        'out_stat' - ``StatPooling`` of the ``seq_encoder`` output;
        'trx_stat' - ``StatPooling`` of the ``trx_encoder`` output;
        'trx_stat_out' - 'trx_stat' concatenated with 'out'.
    """

    def __init__(self,
                 trx_encoder: torch.nn.Module,
                 seq_encoder: AbsSeqEncoder,
                 head_hidden_size: int = 64,
                 total_steps: int = 64000,
                 seed_seq_len: int = 16,
                 max_lr: float = 0.00005,
                 weight_decay: float = 0.0,
                 pct_start: float = 0.1,
                 norm_predict: bool = False,
                 inference_pooling_strategy: str = 'out_stat'
                 ):

        super().__init__()
        # The encoders are sub-modules, so keep them out of the saved hyperparameters.
        self.save_hyperparameters(ignore=['trx_encoder', 'seq_encoder'])

        self.trx_encoder = trx_encoder
        # assert self.trx_encoder.embeddings, '`embeddings` parameter for `trx_encoder` should contain at least 1 feature!'

        self._seq_encoder = seq_encoder
        # The heads need one output per position, not a pooled sequence embedding.
        self._seq_encoder.is_reduce_sequence = False

        # One head per embedded feature; its class count is that embedding's vocabulary size.
        self.head = nn.ModuleDict()
        for col_name, noisy_emb in self.trx_encoder.embeddings.items():
            self.head[col_name] = Head(input_size=self._seq_encoder.embedding_size, hidden_size=head_hidden_size, n_classes=noisy_emb.num_embeddings)

        if self.hparams.norm_predict:
            self.fn_norm_predict = PBL2Norm()

        # Label 0 is ignored (presumably the padding value; this also ignores any real class 0).
        self.loss = nn.CrossEntropyLoss(ignore_index=0)

        self.train_gpt_loss = MeanMetric()
        self.valid_gpt_loss = MeanMetric()

    def forward(self, batch: PaddedBatch):
        """Encode ``batch`` with ``trx_encoder`` then ``seq_encoder`` (optionally L2-normalised)."""
        z_trx = self.trx_encoder(batch)
        out = self._seq_encoder(z_trx)
        if self.hparams.norm_predict:
            out = self.fn_norm_predict(out)
        return out

    def loss_gpt(self, predictions, labels, is_train_step):
        """Sum of per-feature cross-entropy losses for next-position prediction.

        The output at position ``t`` is compared with the label at ``t + 1``, for
        ``t >= seed_seq_len``. ``is_train_step`` is accepted for API symmetry but unused.
        """
        loss = 0
        for col_name, head in self.head.items():
            y_pred = head(predictions[:, self.hparams.seed_seq_len:-1, :])
            y_pred = y_pred.view(-1, y_pred.size(-1))

            y_true = labels[col_name][:, self.hparams.seed_seq_len+1:]
            y_true = torch.flatten(y_true.long())

            loss += self.loss(y_pred, y_true)
        return loss

    def training_step(self, batch, batch_idx):
        out = self.forward(batch)  # PB: B, T, H
        out = out.payload if isinstance(out, PaddedBatch) else out
        labels = batch.payload

        loss_gpt = self.loss_gpt(out, labels, is_train_step=True)
        self.train_gpt_loss(loss_gpt)
        self.log('gpt/loss', loss_gpt, sync_dist=True)
        return loss_gpt

    def validation_step(self, batch, batch_idx):
        out = self.forward(batch)  # PB: B, T, H
        out = out.payload if isinstance(out, PaddedBatch) else out
        labels = batch.payload

        loss_gpt = self.loss_gpt(out, labels, is_train_step=False)
        self.valid_gpt_loss(loss_gpt)

    def on_train_epoch_end(self):
        # Lightning's hook is `on_train_epoch_end`; the previous `on_training_epoch_end`
        # was never called, so the epoch-level train loss was never logged.
        self.log('gpt/train_gpt_loss', self.train_gpt_loss, prog_bar=False, sync_dist=True, rank_zero_only=True)
        # self.train_gpt_loss reset not required here

    # Backward-compatible alias for the old method name (Lightning does not call it).
    on_training_epoch_end = on_train_epoch_end

    def on_validation_epoch_end(self):
        self.log('gpt/valid_gpt_loss', self.valid_gpt_loss, prog_bar=True, sync_dist=True, rank_zero_only=True)
        # self.valid_gpt_loss reset not required here

    def configure_optimizers(self):
        """NAdam plus a per-step cosine OneCycle schedule peaking at ``max_lr``."""
        optim = torch.optim.NAdam(self.parameters(),
                                  lr=self.hparams.max_lr,
                                  weight_decay=self.hparams.weight_decay,
                                  )
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer=optim,
            max_lr=self.hparams.max_lr,
            total_steps=self.hparams.total_steps,
            pct_start=self.hparams.pct_start,
            anneal_strategy='cos',
            cycle_momentum=False,
            div_factor=25.0,
            final_div_factor=10000.0,
            three_phase=False,
        )
        scheduler = {'scheduler': scheduler, 'interval': 'step'}
        return [optim], [scheduler]

    @property
    def seq_encoder(self):
        """Inference wrapper producing one pooled embedding per sequence."""
        return GPTInferenceModule(pretrained_model=self)

class GPTInferenceModule(torch.nn.Module):
    """Turn a pre-trained GPT module into a sequence embedder.

    Runs the wrapped model's ``trx_encoder`` and ``_seq_encoder`` and pools the result
    according to ``model.hparams.inference_pooling_strategy``
    ('trx_stat_out', 'trx_stat', 'out_stat' or 'out'). If ``hparams.norm_predict`` is set,
    the output is L2-normalised.

    Raises
    ------
    ValueError
        If the configured pooling strategy is unknown.
    """

    def __init__(self, pretrained_model):
        super().__init__()
        self.model = pretrained_model
        self.model.is_reduce_sequence = False

        self.stat_pooler = StatPooling()
        self.last_step = LastStepEncoder()

    def forward(self, batch):
        hparams = self.model.hparams
        strategy = hparams.inference_pooling_strategy

        z_trx = self.model.trx_encoder(batch)
        out = self.model._seq_encoder(z_trx)
        out = out if isinstance(out, PaddedBatch) else PaddedBatch(out, batch.seq_lens)

        if strategy == 'trx_stat_out':
            stats = self.stat_pooler(z_trx)
            out = self.last_step(out)
            out = torch.cat([stats, out], dim=1)
        elif strategy == 'trx_stat':
            out = self.stat_pooler(z_trx)
        elif strategy == 'out_stat':
            out = self.stat_pooler(out)
        elif strategy == 'out':
            out = self.last_step(out)
        else:
            # Previously a bare `raise`, which fails with "No active exception to reraise".
            raise ValueError(f'Unknown inference_pooling_strategy: {strategy!r}')

        if hparams.norm_predict:
            out = out / (out.pow(2).sum(dim=-1, keepdim=True) + 1e-9).pow(0.5)
        return out

class CorrGptPretrainModule(GptPretrainModule):
    """GPT pre-training with a per-feature encoder between ``trx_encoder`` and ``seq_encoder``.

    ``forward`` reshapes the last dimension of the ``trx_encoder`` output into
    ``(-1, 24)`` blocks, so that dimension must be divisible by 24. It then applies
    ``feature_encoder`` and feeds the result to ``seq_encoder``.

    NOTE(review): the inherited ``seq_encoder`` property (GPTInferenceModule) calls
    ``trx_encoder`` and ``_seq_encoder`` directly and skips ``feature_encoder``. Embeddings
    computed for downstream tasks therefore see a different input than during training —
    confirm and route inference through the same encoding as ``forward``.
    """

    def __init__(self, feature_encoder, seq_encoder, trx_encoder, *args, **kwargs):
        # Mutates the caller's encoder: clears its numeric values (presumably so only
        # categorical embeddings enter the per-feature reshape).
        trx_encoder.numeric_values = {}
        # Pass the encoders positionally: with `trx_encoder=..., *args`, any extra positional
        # argument was also bound to `trx_encoder` and raised TypeError.
        super().__init__(trx_encoder, seq_encoder, *args, **kwargs)
        self.save_hyperparameters(ignore=['trx_encoder', 'seq_encoder', 'feature_encoder'])
        self.feature_encoder = feature_encoder

    def forward(self, batch):
        z_trx = self.trx_encoder(batch)
        # (..., n_features * 24) -> (..., n_features, 24); `reshape` also handles
        # non-contiguous payloads, where `view` would fail.
        payload = z_trx.payload.reshape(z_trx.payload.shape[:-1] + (-1, 24))
        payload = self.feature_encoder(payload)
        encoded_trx = PaddedBatch(payload=payload, length=z_trx.seq_lens)
        out = self._seq_encoder(encoded_trx)
        if self.hparams.norm_predict:
            out = self.fn_norm_predict(out)
        return out

In [None]:
from ptls.nn import TabFormerFeatureEncoder
from ptls.nn import TransformerEncoder
from ptls.nn import TrxEncoder

# 9 features x 24 dims: matches the 9 embeddings below and the (-1, 24) reshape in
# CorrGptPretrainModule.forward.
feature_encoder = TabFormerFeatureEncoder(
    n_cols=9,
    emb_dim=24
)

# input_size = 9 features * 24 dims = 216.
seq_encoder = TransformerEncoder(
    n_heads=2,
    n_layers=2,
    input_size=216,
    use_positional_encoding=True
)

# Vocabulary sizes ("in") equal the CategorySizeClip limits of the dataset pipelines,
# except 'amount' (11 here vs a clip limit of 10).
trx_encoder = TrxEncoder(
            norm_embeddings=False,
            embeddings_noise=0.0,
            embeddings={
                'event_type': {"in": 58, "out": 24},
                'event_subtype': {"in": 59, "out": 24},
                'src_type11': {"in": 85, "out": 24},
                'src_type12': {"in": 349, "out": 24},
                'dst_type11': {"in": 84, "out": 24},
                'dst_type12': {"in": 417, "out": 24},
                'src_type22': {"in": 90, "out": 24},
                'src_type32': {"in": 91, "out": 24},
                'amount': {"in": 11, "out": 24}
            },
            # numeric_values={'amount': 'log'}
        )

# total_steps is left at the GptPretrainModule default (64000).
pl_module = CorrGptPretrainModule(
    # total_steps=20000,
    feature_encoder=feature_encoder,
    trx_encoder=trx_encoder,
    seq_encoder=seq_encoder
)
# pl_module = GptPretrainModule(
#     total_steps=20000,
#     trx_encoder=trx_encoder,
#     seq_encoder=seq_encoder
# )

In [None]:
from pytorch_lightning.loggers import WandbLogger

# Take the logger from `pytorch_lightning`, the same package as `pl.Trainer` and the
# LightningModules above. `lightning.pytorch` is a separate package whose classes are
# distinct from `pytorch_lightning`'s, so mixing the two is fragile.
# log_model="all" uploads every saved checkpoint to W&B.
wandb_logger = WandbLogger(project="MBD_My_Code", log_model="all")

# NOTE(review): the module's OneCycleLR uses total_steps=64000 (default). If 20 epochs
# need more optimizer steps than that, the scheduler raises — set total_steps accordingly.
trainer = pl.Trainer(
    logger=wandb_logger,
    max_epochs=20,
    accelerator="cuda" if torch.cuda.is_available() else "cpu",
    enable_progress_bar=True,
    gradient_clip_val=0.5,
    log_every_n_steps=50,
    # Each validation run is capped at 32 batches.
    limit_val_batches=32
)

In [None]:
# Pre-train the GPT module on folds 0-2, validating on fold 3.
trainer.fit(pl_module, data_module)

In [None]:
# Hold-out fold 4, used only to compute embeddings for the downstream evaluation.
inference_dataset = ptls.data_load.datasets.ParquetDataset(
        data_files=[
            os.path.join(TRX_SUPERVISED_PATH, 'fold=4')
        ],
    i_filters=[
        ptls.data_load.iterable_processing.SeqLenFilter(min_seq_len=16),
        ptls.data_load.iterable_processing.SeqLenFilter(max_seq_len=2048),
        # Unlike the pre-training pipelines, the id and the targets are listed in
        # keep_feature_names so they survive the filter: GetSplit and the downstream
        # collate function read them.
        ptls.data_load.iterable_processing.feature_filter.FeatureFilter(
            keep_feature_names=[
                'client_id',
                'target_1',
                'target_2',
                'target_3',
                'target_4'
            ]
        ),
        # NOTE(review): amounts are quantised here, but QuantilfyAmount() is commented out
        # in the train/valid pipelines. The 'amount' embedding therefore sees different
        # inputs at inference than in pre-training — confirm which one is intended.
        QuantilfyAmount(),
        ptls.data_load.iterable_processing.CategorySizeClip(
            category_max_size={
                'event_type': 58,
                'event_subtype' :59,
                'src_type11': 85,
                'src_type12': 349,
                'dst_type11': 84,
                'dst_type12': 417,
                'src_type22': 90,
                'src_type32': 91,
                'amount': 10
            }
        ),
        # Expand every client into 12 monthly records (history cut at the end of each month,
        # plus that month's targets).
        GetSplit(
            start_month=1,
            end_month=12,
            col_id='client_id'
        ),
        ptls.data_load.iterable_processing.to_torch_tensor.ToTorch(),
    ],
    shuffle_files=True
)

In [None]:
# NOTE(review): superseded — `inference_dl` is rebuilt with the target-aware collate_fn a
# few cells below, and only that version is passed to `trainer.predict`. This cell can be removed.
inference_dl = DataLoader(
    dataset=inference_dataset,
    shuffle=False,
    num_workers=0,
    batch_size=32
)

In [None]:
import pandas as pd
import pytorch_lightning as pl
import torch
import numpy as np
from itertools import chain
from ptls.data_load.padded_batch import PaddedBatch
from datetime import datetime
from ptls.custom_layers import StatPooling
from ptls.nn.seq_step import LastStepEncoder


class InferenceModuleMultimodal(pl.LightningModule):
    """Wrap a model for ``Trainer.predict`` and emit embeddings (and targets) per batch.

    ``forward`` expects the tuples produced by ``collate_feature_dict_with_target``:
    ``(padded_batch, batch_ids)`` or ``(padded_batch, batch_ids, target_cols)``. The batch is
    embedded with ``model.seq_encoder`` when the model has that attribute, otherwise with
    ``model`` itself.

    Parameters
    ----------
    model:
        Module used to embed batches.
    pandas_output:
        Return a ``pd.DataFrame`` (see ``to_pandas``) instead of a dict.
    col_id:
        Output key / column for the batch ids.
    target_col_names:
        Output keys / columns for the targets; required when batches carry targets.
    model_out_name:
        Output key for the embeddings (also the column prefix in DataFrame output).
    model_type:
        Stored but not used by this class.
    """

    def __init__(
        self,
        model,
        pandas_output=True,
        col_id='client_id',
        target_col_names=None,
        model_out_name='emb',
        model_type='notab'
    ):
        super().__init__()

        self.model = model
        self.pandas_output = pandas_output
        self.target_col_names = target_col_names
        self.col_id = col_id
        self.model_out_name = model_out_name
        self.model_type = model_type

        # Not used by forward; kept so existing attributes / state stay unchanged.
        self.stat_pooler = StatPooling()
        self.last_step = LastStepEncoder()

    def forward(self, x):
        has_targets = len(x) == 3
        if has_targets:
            x, batch_ids, target_cols = x
        else:
            x, batch_ids = x
        if 'seq_encoder' in dir(self.model):
            out = self.model.seq_encoder(x)
        else:
            out = self.model(x)

        x_out = {
            self.col_id: batch_ids,
            self.model_out_name: out
        }
        if has_targets:
            target_cols = torch.Tensor(target_cols)
            if len(target_cols.size()) > 1:
                # (B, n_targets): one output column per target name.
                for idx, target_col in enumerate(self.target_col_names):
                    x_out[target_col] = target_cols[:, idx]
            else:
                # NOTE(review): a 1-D target tensor is subsampled with stride 4 and stored
                # under the first target name — confirm this is intended.
                x_out[self.target_col_names[0]] = target_cols[::4]

        if self.pandas_output:
            return self.to_pandas(x_out)
        return x_out

    @staticmethod
    def to_pandas(x):
        """Convert a dict of per-batch outputs into one DataFrame.

        Lists/tuples and 1-D arrays become single columns. 2-D arrays are expanded into
        ``<key>_0000, <key>_0001, ...`` columns. Higher-rank values become an all-None
        column. Tensors are moved to the CPU and converted to NumPy first.
        """
        expand_cols = {}
        scalar_features = {}

        for k, v in x.items():
            if isinstance(v, torch.Tensor):
                v = v.cpu().numpy()

            if isinstance(v, (list, tuple)) or len(v.shape) == 1:
                scalar_features[k] = v
            elif len(v.shape) == 2:
                # Keep the already-converted array: calling `.cpu()` again later crashed
                # when the input value was a NumPy array rather than a tensor.
                expand_cols[k] = v
            else:
                scalar_features[k] = None

        dataframes = [pd.DataFrame(scalar_features)]
        for col, v in expand_cols.items():
            dataframes.append(pd.DataFrame(v, columns=[f'{col}_{i:04d}' for i in range(v.shape[1])]))

        return pd.concat(dataframes, axis=1)

def collate_feature_dict_with_target(batch, col_id='client_id', target_col_names=None):
    """Collate feature dicts into a PaddedBatch, pulling out ids and (optionally) targets.

    Parameters
    ----------
    batch:
        List of per-record feature dicts. The dicts are not modified: each one is
        shallow-copied before ``col_id`` / target keys are removed. The previous version
        deleted them from the caller's dicts, which corrupts any dataset that re-yields
        the same record objects.
    col_id:
        Key holding the record id.
    target_col_names:
        Keys to extract as targets, or None.

    Returns
    -------
    ``(padded_batch, batch_ids, target_cols)`` when ``target_col_names`` is given, where
    ``target_cols`` has one list of target values per record (in ``target_col_names`` order);
    otherwise ``(padded_batch, batch_ids)``.
    """
    batch_ids = []
    target_cols = []
    feature_dicts = []
    for sample in batch:
        sample = dict(sample)
        batch_ids.append(sample.pop(col_id))
        if target_col_names is not None:
            target_cols.append([sample.pop(target_col) for target_col in target_col_names])
        feature_dicts.append(sample)

    padded_batch = collate_feature_dict(feature_dicts)
    if target_col_names is not None:
        return padded_batch, batch_ids, target_cols
    return padded_batch, batch_ids

In [None]:
# Downstream targets pulled out of every record by the collate function.
target_col_names = [
    'target_1',
    'target_2',
    'target_3',
    'target_4'
]

collate_fn = partial(
    collate_feature_dict_with_target,
    target_col_names=target_col_names
)

# Loader over the monthly fold-4 records; batches are (padded_batch, ids, targets).
inference_dl = DataLoader(
    dataset=inference_dataset,
    collate_fn=collate_fn,
    shuffle=False,
    num_workers=0,
    batch_size=32
)

In [None]:
# Wrap the pre-trained module; each predicted batch is returned as a DataFrame.
inf_module = InferenceModuleMultimodal(
    model=pl_module,
    pandas_output=True,
    col_id='client_id',
    target_col_names=target_col_names
)

In [None]:
# One DataFrame per batch, concatenated (index values restart in every batch).
inf_embeddings = pd.concat(trainer.predict(inf_module, inference_dl))

In [None]:
inf_embeddings

In [None]:
from sklearn.model_selection import GroupShuffleSplit, train_test_split

# Every row of `inf_embeddings` is a (client, month) pair: GetSplit appends '_month=<m>' to
# the client id. A row-wise random split puts months of the same client into both train
# and test, which leaks the client into the evaluation. Split by the original client id
# instead, with a fixed seed for reproducibility.
client_groups = inf_embeddings['client_id'].str.rsplit('_month=', n=1).str[0]
group_splitter = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=SEED)
train_idx, test_idx = next(group_splitter.split(inf_embeddings, groups=client_groups))
dwns_train = inf_embeddings.iloc[train_idx]
dwns_test = inf_embeddings.iloc[test_idx]

In [None]:
# (n_rows, 4) target matrices, one column per target_1..target_4.
targets_train = np.array(
    [
        dwns_train['target_1'].to_numpy(),
        dwns_train['target_2'].to_numpy(),
        dwns_train['target_3'].to_numpy(),
        dwns_train['target_4'].to_numpy()
    ]
).T
targets_test = np.array(
    [
        dwns_test['target_1'].to_numpy(),
        dwns_test['target_2'].to_numpy(),
        dwns_test['target_3'].to_numpy(),
        dwns_test['target_4'].to_numpy()
    ]
).T

# Feature matrices: every column except the id and the targets (i.e. the embedding columns).
# NOTE(review): dwns_train / dwns_test are rebound from DataFrames to NumPy arrays here, so
# this cell fails if it is re-run without re-running the split cell first.
dwns_train = dwns_train.drop(columns=[
    'client_id',
    'target_1',
    'target_2',
    'target_3',
    'target_4'
]).to_numpy()

dwns_test = dwns_test.drop(columns=[
    'client_id',
    'target_1',
    'target_2',
    'target_3',
    'target_4'
]).to_numpy()

In [None]:
from lightgbm import LGBMClassifier

# One independently configured gradient-boosting classifier per downstream target.
# NOTE(review): subsample, feature_fraction, lambda_l1/l2 and min_data_in_leaf are
# LightGBM parameter names/aliases passed through the sklearn wrapper's kwargs; LightGBM
# may warn when an alias and its sklearn-API counterpart are both set — confirm.
models = [LGBMClassifier(
    n_estimators=500,
    boosting_type='gbdt',
    subsample=0.5,
    subsample_freq=1,
    learning_rate=0.02,
    feature_fraction=0.75,
    max_depth=6,
    lambda_l1=1,
    lambda_l2=1,
    min_data_in_leaf=50,
    random_state=42,
    n_jobs=8,
    verbose=-1
) for _ in range(4)]

In [None]:
# Fit model k on column k of the target matrix (target_{k+1}).
for target_id in range(4):
    models[target_id].fit(dwns_train, targets_train[:, target_id])

In [None]:
from sklearn.metrics import classification_report
from sklearn.metrics import roc_auc_score

# ROC-AUC of each downstream model on the held-out rows, using the positive-class
# probability. Targets are named target_1..target_4, so the printed label uses the
# 1-based index (the previous f"target_{i}" printed target_0..target_3).
for i, clf in enumerate(models):
    preds = clf.predict_proba(dwns_test)
    print(f"ROC-AUC target_{i + 1} = {roc_auc_score(targets_test[:, i], preds[:, 1])}")

# GPT Module with Swin Transformer encoder

In [None]:
from ptls.frames.gpt.gpt_dataset import GptIterableDataset

# NOTE(review): identical to the data module built for the GPT baseline; presumably
# re-created so this section can be run on its own. Consider a single shared definition.
data_module = ptls.frames.PtlsDataModule(
    train_data=GptIterableDataset(
        data=train,
        min_len=80,
        max_len=300
    ),
    valid_data=GptIterableDataset(
        data=valid,
        min_len=80,
        max_len=300
    ),
    train_batch_size=128,
    train_num_workers=0,
    valid_batch_size=64,
    valid_num_workers=0
)

In [None]:
import torch

# Scratch check (debugging leftover): a (B, L) float grid whose row values are positions
# 1..L, built by broadcasting arange(L) over B rows. Defines globals a, B, L, C.
a = torch.rand(64, 853, 216)
B, L, C = a.shape
torch.arange(L)[None, :] + torch.ones((B, L))

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
import numpy as np


#------------------------------------------------------------------------------------------------------------
# Based on https://github.com/yukara-ikemiya/Swin-Transformer-1d/tree/main and adapted to pytorch-lifestream
#------------------------------------------------------------------------------------------------------------

import torch
import torch.nn as nn
from ptls.data_load.padded_batch import PaddedBatch


def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
    # adapted from timm/models/layers/drop.py
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Each sample along dim 0 is either kept (optionally rescaled by ``1 / keep_prob``) or
    zeroed entirely, with probability ``drop_prob``.

    Args:
        x: input tensor; dim 0 is treated as the sample dimension.
        drop_prob: probability of zeroing a sample. ``None`` and ``0.`` disable dropping.
        training: dropping only happens when True; otherwise ``x`` is returned unchanged.
        scale_by_keep: rescale kept samples so the expected value is unchanged.

    Returns:
        ``x`` itself when dropping is disabled, otherwise a new tensor of the same shape.
    """
    # Treat drop_prob=None like 0 instead of failing on `1 - None` below.
    if drop_prob is None or drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    # one Bernoulli draw per sample, broadcast over every remaining dim
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0 and scale_by_keep:
        random_tensor.div_(keep_prob)
    return x * random_tensor


class DropPath(nn.Module):
    # adapted from timm/models/layers/drop.py
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Args:
        drop_prob (float | None): probability of zeroing a sample's residual branch while
            training. ``None`` is still accepted and is treated as ``0.`` (no dropping).
        scale_by_keep (bool): rescale kept samples by ``1 / (1 - drop_prob)``.
    """

    def __init__(self, drop_prob=0., scale_by_keep=True):
        super().__init__()
        # Normalise None so drop_path() never evaluates `1 - None` in training mode.
        self.drop_prob = 0. if drop_prob is None else drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)


class Mlp(nn.Module):
    """Two-layer feed-forward network: Linear -> activation -> Dropout -> Linear -> Dropout.

    Args:
        in_features (int): size of the last input dimension.
        hidden_features (int | None): hidden width; defaults to ``in_features``.
        out_features (int | None): output width; defaults to ``in_features``.
        act_layer (type[nn.Module]): activation class, instantiated without arguments.
        drop (float): dropout probability, applied after the activation and after ``fc2``.
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden = hidden_features or in_features
        out = out_features or in_features
        # attribute names are the state_dict keys and creation order fixes seeded init
        self.fc1 = nn.Linear(in_features, hidden)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden, out)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        # the same Dropout module is applied on both sides of fc2
        for layer in (self.fc1, self.act, self.drop, self.fc2, self.drop):
            x = layer(x)
        return x


def window_partition(x, window_size):
    """Cut a batch of sequences into consecutive, non-overlapping windows.

    Args:
        x: tensor of shape (B, L, C); L must be divisible by ``window_size``.
        window_size (int): number of steps per window.

    Returns:
        Tensor of shape (B * L // window_size, window_size, C): all windows of sample 0
        in sequence order, then those of sample 1, and so on.
    """
    batch, length, channels = x.shape
    per_sample = x.view(batch, length // window_size, window_size, channels)
    return per_sample.contiguous().view(-1, window_size, channels)


def window_reverse(windows, window_size, L):
    """Stitch windows back into full sequences (the inverse of ``window_partition``).

    Args:
        windows: tensor of shape (num_windows * B, window_size, C).
        window_size (int): number of steps per window.
        L (int): full (padded) sequence length, a multiple of ``window_size``.

    Returns:
        Tensor of shape (B, L, C).
    """
    windows_per_sample = L / window_size
    batch = int(windows.shape[0] / windows_per_sample)
    stitched = windows.view(batch, L // window_size, window_size, -1)
    return stitched.contiguous().view(batch, L, -1)


class WindowAttention(nn.Module):
    """ Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels. Must be divisible by ``num_heads``.
        window_size (int): The width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0

    Raises:
        ValueError: if ``dim`` is not divisible by ``num_heads``. The per-head reshape in
            ``forward`` cannot work in that case, so fail at construction instead of on the
            first batch.
    """
    def __init__(self, dim: int, window_size: int, num_heads: int,
                 qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):

        super().__init__()
        if dim % num_heads != 0:
            raise ValueError(f'dim ({dim}) must be divisible by num_heads ({num_heads})')
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # a falsy qk_scale (None or 0) falls back to the default 1/sqrt(head_dim)
        self.scale = qk_scale or head_dim ** -0.5

        # learnable bias per relative offset -(W-1)..(W-1) and per head: (2*W - 1, nH)
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros(2 * window_size - 1, num_heads))

        # pair-wise relative position index for each token inside the window
        coords_w = torch.arange(self.window_size)
        relative_coords = coords_w[:, None] - coords_w[None, :]  # W, W
        relative_coords[:, :] += self.window_size - 1  # shift to start from 0

        # relative_position_index | example for W = 3
        # [2, 1, 0]
        # [3, 2, 1]
        # [4, 3, 2]
        self.register_buffer("relative_position_index", relative_coords)  # (W, W): range of 0 -- 2*(W-1)

        # creation order below is unchanged so seeded initialisation stays identical
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        torch.nn.init.trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask_add, mask_mult):
        """Masked multi-head self-attention inside each window.

        Args:
            x: window features of shape (num_windows*B, W, C), windows grouped per sample
                (sample-major, as produced by ``window_partition``).
            mask_add: additive (0 / -inf) mask broadcastable to
                (B, num_windows, num_heads, W, W); ``num_windows`` is read from its dim 1.
                Added to the attention logits before the softmax.
            mask_mult: multiplicative (1 / 0) mask broadcastable to
                (num_windows*B, num_heads, W, W); applied to the attention weights after
                the softmax.

        Returns:
            Tensor of shape (num_windows*B, W, C).
        """
        B_, N, C = x.shape
        # (3, B_, nH, N, head_dim)
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size, self.window_size, -1)  # W, W, nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, W, W
        attn = attn + relative_position_bias.unsqueeze(0)

        # add the per-window mask, then flatten back to (num_windows*B, nH, N, N)
        nW = mask_add.shape[1]
        attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask_add
        attn = attn.view(-1, self.num_heads, N, N)
        attn = self.softmax(attn)
        attn = attn * mask_mult

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)

        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def extra_repr(self) -> str:
        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'


class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block (1D), adapted to ptls ``PaddedBatch`` inputs.

    Pre-norm residual block: ``x = x + DropPath(W-MSA(norm1(x)))`` followed by
    ``x = x + DropPath(MLP(norm2(x)))``. Attention runs inside non-overlapping windows of
    ``window_size`` steps. When ``shift_size > 0`` the sequence is cyclically rolled by
    ``-shift_size`` before windowing and rolled back afterwards (SW-MSA). Padding steps
    (beyond ``seq_lens``) are excluded from attention via masks built in ``forward``.

    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA. Must satisfy 0 <= shift_size < window_size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """
    def __init__(self, dim, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=self.window_size, num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        # stochastic depth on both residual branches (identity when drop_path == 0)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        # Placeholder kept from the original Swin code: forward() builds its masks per batch
        # and never reads self.attn_mask.
        attn_mask = None
        self.register_buffer("attn_mask", attn_mask)

    def forward(self, x):
        """Apply (shifted-)window attention and the MLP to a padded batch.

        Args:
            x: PaddedBatch whose ``payload`` is (B, L, C), with C expected to equal ``dim``,
                and whose ``seq_lens`` holds the valid length of every sequence.

        Returns:
            PaddedBatch with the same ``seq_lens``. Its payload length is ``L`` rounded up
            to a multiple of ``window_size``: the zero padding appended here is NOT removed,
            so callers needing the original length must slice it off.
        """
        seq_lens = x.seq_lens
        x = x.payload
        
        B, L, C = x.shape

        # define seq_len_mask: 1-based position of every step, then 1 for valid steps
        # (position <= seq_len) and 0 for padding; final shape (B, L, 1)
        mask = torch.arange(L, device=x.device)[None, :] + torch.ones((B, L), device=x.device)
        mask[mask > seq_lens[:, None]] = 0.
        mask[mask > 0.] = 1.
        mask = mask[:, :, None]

        # make new max seq_len `L` divisible by `self.window_size` by adding 'zero' samples
        # (when L is already divisible, num_samples_to_add == window_size and nothing is added)
        num_samples_to_add = self.window_size - (L % self.window_size)
        
        if num_samples_to_add < self.window_size:
            additional_samples = torch.zeros((B, num_samples_to_add, C), device=x.device)
            x = torch.cat((x, additional_samples), dim=1)
            mask_additional_samples = torch.zeros((B, num_samples_to_add, mask.shape[2]), device=mask.device)
            mask = torch.cat((mask, mask_additional_samples), dim=1)
            L += num_samples_to_add

        # mask out padding transactions: padding steps enter the block as zeros (this also
        # re-zeroes padding positions coming out of a previous block, which are not zero)
        x = x * mask
        
        assert L >= self.window_size, f'input length ({L}) must be >= window size ({self.window_size})'
        assert L % self.window_size == 0, f'input length ({L}) must be divisible by window size ({self.window_size})'

        shortcut = x
        x = self.norm1(x)

        # shift
        # NOTE(review): only padding is masked below, so after this cyclic roll the last
        # window can mix the end of a sequence with its beginning when the valid part reaches
        # the padded tail. The original 2D Swin adds a region mask for that case -- confirm the
        # omission is intended.
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=-self.shift_size, dims=1) # cyclic shift as in orig. 2D SWIN-transformer
            mask = torch.roll(mask, shifts=-self.shift_size, dims=1) # cyclic shift of the mask
        else:
            shifted_x = x
        
        # partition
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, C
        mask = window_partition(mask, self.window_size) # nW*B, window_size, 1
        
        # calculate attn_mask: 1 where both query and key are valid steps, 0 otherwise.
        # NOTE(review): no causal mask is applied -- every step attends to later steps of its
        # window. If this encoder feeds a next-step (GPT-style) objective, future information
        # may leak into the prediction; confirm.
        attn_mask = (mask @ mask.transpose(-2, -1)) # nW*B, window_size, window_size
        
        # multiplicative (1/0) copy per head, applied after the softmax to zero any weight
        # that involves a padding step
        attn_mask_real = attn_mask.clone().detach()
        attn_mask_real = attn_mask_real.view(attn_mask_real.shape[0], self.window_size, self.window_size).unsqueeze(1).expand(-1, self.num_heads, -1, -1) # B*nW, nH, window_size, window_size
        
        # additive form for the logits: invalid pairs -> -inf, valid pairs -> 0
        attn_mask[attn_mask == 0.] = -torch.inf
        attn_mask[attn_mask == 1.] = 0.
        # keep the diagonal finite: padding query rows would otherwise be all -inf and the
        # softmax would produce NaN (their weights are zeroed by attn_mask_real anyway)
        attn_mask[:, torch.arange(attn_mask.shape[-1]), torch.arange(attn_mask.shape[-1])] = 0.
        # windows are sample-major (see window_partition), so this view groups them per sample
        attn_mask = attn_mask.view(B, attn_mask.shape[0] // B, self.window_size, self.window_size).unsqueeze(2).expand(-1, -1, self.num_heads, -1, -1) # B, nW, nH, window_size, window_size
        
        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask_add=attn_mask, mask_mult=attn_mask_real)  # nW*B, window_size, C
        
        # merge windows
        shifted_x = window_reverse(attn_windows, self.window_size, L)  # (B, L, C)

        # reverse zero-padding shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=self.shift_size, dims=1) # cyclic shift as in orig. 2D SWIN-transformer
        else:
            x = shifted_x

        x = shortcut + self.drop_path(x)

        # FFN
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        
        # payload keeps the padded length L (see docstring)
        return PaddedBatch(x, seq_lens)

    def extra_repr(self) -> str:
        return f"dim={self.dim}, num_heads={self.num_heads}, " \
               f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"

class SwinTransformerV2Layer(nn.Module):
    """ A basic Swin Transformer layer for one stage.

    Stacks ``depth`` SwinTransformerBlock modules. Even-indexed blocks use regular windows;
    odd-indexed blocks shift the windows by ``window_size // 2``. The input is passed
    through the blocks unchanged in type (a PaddedBatch in this notebook).

    Args:
        dim (int): Number of input channels.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | list[float] | tuple[float], optional): Stochastic depth rate, either a
            single value for every block or one value per block. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
        pretrained_window_size (int): Accepted for API compatibility; currently unused.
    """

    def __init__(
        self,
        dim: int,
        depth: int,
        num_heads: int,
        window_size: int,
        mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
        drop_path=0., norm_layer=nn.LayerNorm, use_checkpoint=False,
        pretrained_window_size=0
    ):

        super().__init__()
        self.dim = dim
        self.depth = depth
        self.num_heads = num_heads
        self.window_size = window_size
        self.use_checkpoint = use_checkpoint

        # a per-block schedule may be given as a list or a tuple (as documented)
        per_block_drop_path = isinstance(drop_path, (list, tuple))

        # build blocks
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(dim=dim,
                                 num_heads=num_heads, window_size=window_size,
                                 shift_size=0 if (i % 2 == 0) else window_size // 2,
                                 mlp_ratio=mlp_ratio,
                                 qkv_bias=qkv_bias,
                                 drop=drop, attn_drop=attn_drop,
                                 drop_path=drop_path[i] if per_block_drop_path else drop_path,
                                 norm_layer=norm_layer)
            for i in range(depth)])

    def forward(self, x):
        for blk in self.blocks:
            if self.use_checkpoint:
                # Non-reentrant checkpointing: x is a PaddedBatch (a non-tensor container), and
                # the reentrant variant does not track gradients through tensors nested in such
                # objects, which would silently cut the graph.
                x = checkpoint.checkpoint(blk, x, use_reentrant=False)
            else:
                x = blk(x)
        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, depth={self.depth}, num_heads={self.num_heads}, window_size={self.window_size}"

    @property
    def embedding_size(self):
        return self.dim

    def flops(self):
        """Sum of the per-block FLOP estimates.

        Raises:
            NotImplementedError: if any block does not provide ``flops()`` (the
                SwinTransformerBlock defined in this notebook does not).
        """
        if not all(hasattr(blk, 'flops') for blk in self.blocks):
            raise NotImplementedError('flops() is not implemented for the blocks of this layer')
        return sum(blk.flops() for blk in self.blocks)

    def _init_respostnorm(self):
        # Zero-initialise every block's norm1/norm2 weight and bias (Swin-V2 res-post-norm init).
        # NOTE(review): these blocks apply norm1/norm2 *before* attention / MLP (pre-norm), so
        # this makes both sub-layers see all-zero inputs; not called in the visible code.
        for blk in self.blocks:
            nn.init.constant_(blk.norm1.bias, 0)
            nn.init.constant_(blk.norm1.weight, 0)
            nn.init.constant_(blk.norm2.bias, 0)
            nn.init.constant_(blk.norm2.weight, 0)

In [None]:
import math

class GPTInferenceModule(torch.nn.Module):
    """Wrap a pretrained GPT-style module so that ``forward`` returns one vector per sequence.

    The pooling is selected by ``pretrained_model.hparams.inference_pooling_strategy``:

    * ``'trx_stat_out'``: StatPooling over the transaction embeddings concatenated with the
      last step of the sequence-encoder output;
    * ``'trx_stat'``: StatPooling over the transaction embeddings only;
    * ``'out_stat'``: StatPooling over the sequence-encoder output;
    * ``'out'``: last step of the sequence-encoder output.

    If ``hparams.norm_predict`` is truthy, the result is L2-normalised along the last dim.

    Raises:
        ValueError: (from ``forward``) on an unknown pooling strategy.
    """
    _POOLING_STRATEGIES = ('trx_stat_out', 'trx_stat', 'out_stat', 'out')

    def __init__(self, pretrained_model):
        super().__init__()
        # Imported here so this cell does not depend on a later cell having been run first.
        from ptls.custom_layers import StatPooling
        from ptls.nn.seq_step import LastStepEncoder

        self.model = pretrained_model
        self.model.is_reduce_sequence = False

        self.stat_pooler = StatPooling()
        self.last_step = LastStepEncoder()

    def forward(self, batch):
        strategy = self.model.hparams.inference_pooling_strategy
        if strategy not in self._POOLING_STRATEGIES:
            # a bare `raise` here would only produce "No active exception to reraise"
            raise ValueError(
                f"Unknown inference_pooling_strategy {strategy!r}; "
                f"expected one of {self._POOLING_STRATEGIES}"
            )

        z_trx = self.model.trx_encoder(batch)
        out = self.model._seq_encoder(z_trx)
        out = out if isinstance(out, PaddedBatch) else PaddedBatch(out, batch.seq_lens)

        if strategy == 'trx_stat_out':
            out = torch.cat([self.stat_pooler(z_trx), self.last_step(out)], dim=1)
        elif strategy == 'trx_stat':
            out = self.stat_pooler(z_trx)
        elif strategy == 'out_stat':
            out = self.stat_pooler(out)
        else:  # 'out'
            out = self.last_step(out)

        if self.model.hparams.norm_predict:
            out = out / (out.pow(2).sum(dim=-1, keepdim=True) + 1e-9).pow(0.5)
        return out

class SWINPretrainModule(GptPretrainModule):
    """GPT-style pretraining where the per-field transaction embeddings are first mixed by a
    feature encoder and then fed to a (Swin) sequence encoder.

    Args:
        feature_encoder: module applied to the TrxEncoder output reshaped to
            (..., n_fields, feature_emb_dim). Assumes the TrxEncoder payload is the
            concatenation of equal-width field embeddings -- TODO confirm for new configs.
        seq_encoder: sequence encoder taking and returning PaddedBatch; must expose
            ``window_size``.
        trx_encoder: TrxEncoder; its ``numeric_values`` are cleared before use.
        *args, **kwargs: forwarded to ``GptPretrainModule``.
        feature_emb_dim (int): width of every field embedding in the TrxEncoder output.
            Default 24 (the value previously hard-coded).
    """
    def __init__(self, feature_encoder, seq_encoder, trx_encoder, *args, feature_emb_dim=24, **kwargs):
        # drop numeric features so the TrxEncoder output presumably consists only of the
        # equal-width categorical embeddings that forward() reshapes
        trx_encoder.numeric_values = {}
        super().__init__(trx_encoder=trx_encoder, seq_encoder=seq_encoder, *args, **kwargs)
        self.save_hyperparameters(ignore=['trx_encoder', 'seq_encoder', 'feature_encoder'])
        self.feature_encoder = feature_encoder
        self.window_size = self._seq_encoder.window_size
        self.feature_emb_dim = feature_emb_dim

    def forward(self, batch):
        z_trx = self.trx_encoder(batch)
        # split the last dim into (n_fields, feature_emb_dim) for the feature encoder
        payload = z_trx.payload.view(z_trx.payload.shape[:-1] + (-1, self.feature_emb_dim))
        payload = self.feature_encoder(payload)
        feature_embed = PaddedBatch(payload, z_trx.seq_lens)

        out = self._seq_encoder(feature_embed)
        # the Swin blocks pad the time axis up to a multiple of window_size; cut it back
        out = PaddedBatch(out.payload[:, :payload.shape[1], :], z_trx.seq_lens)
        if self.hparams.norm_predict:
            out = self.fn_norm_predict(out)
        return out

In [None]:
from ptls.nn import TabFormerFeatureEncoder
from ptls.nn import TransformerEncoder
from ptls.nn import TrxEncoder

# The number of fields, the per-field embedding width and the Swin `dim` must agree
# (previously hard-coded as 9 / 24 / 216 == 9 * 24, presumably because TabFormerFeatureEncoder
# outputs n_cols * emb_dim features). Derive all of them from one spec.
FEATURE_EMB_DIM = 24

trx_embeddings = {
    'event_type': {"in": 58, "out": FEATURE_EMB_DIM},
    'event_subtype': {"in": 59, "out": FEATURE_EMB_DIM},
    'src_type11': {"in": 85, "out": FEATURE_EMB_DIM},
    'src_type12': {"in": 349, "out": FEATURE_EMB_DIM},
    'dst_type11': {"in": 84, "out": FEATURE_EMB_DIM},
    'dst_type12': {"in": 417, "out": FEATURE_EMB_DIM},
    'src_type22': {"in": 90, "out": FEATURE_EMB_DIM},
    'src_type32': {"in": 91, "out": FEATURE_EMB_DIM},
    'amount': {"in": 11, "out": FEATURE_EMB_DIM},
}

feature_encoder = TabFormerFeatureEncoder(
    n_cols=len(trx_embeddings),
    emb_dim=FEATURE_EMB_DIM
)

seq_encoder = SwinTransformerV2Layer(
    num_heads=4,
    depth=4,
    dim=len(trx_embeddings) * FEATURE_EMB_DIM,
    window_size=12
)

# `amount` is used as a quantised categorical field (numeric_values are not used here).
trx_encoder = TrxEncoder(
    norm_embeddings=False,
    embeddings_noise=0.0,
    embeddings=trx_embeddings,
)

# For the vanilla-transformer baseline, build GptPretrainModule(total_steps=20000,
# trx_encoder=..., seq_encoder=...) instead.
pl_module = SWINPretrainModule(
    feature_encoder=feature_encoder,
    trx_encoder=trx_encoder,
    seq_encoder=seq_encoder
)

In [None]:
from pytorch_lightning.loggers import WandbLogger

# Take the logger from `pytorch_lightning`, the same package as `pl.Trainer`.
# `lightning.pytorch` is a separate namespace whose classes are distinct objects, so mixing
# the two packages can break Lightning's isinstance-based checks.
wandb_logger = WandbLogger(project="MBD_My_Code", log_model="all")

trainer = pl.Trainer(
    logger=wandb_logger,
    max_epochs=20,
    accelerator="cuda" if torch.cuda.is_available() else "cpu",
    enable_progress_bar=True,
    gradient_clip_val=0.5,
    log_every_n_steps=50,
    limit_val_batches=32
)

In [None]:
# Pretrain the Swin-encoder GPT module with the train/valid data module defined above.
trainer.fit(model=pl_module, datamodule=data_module)

In [None]:
# Records used for embedding extraction and the downstream evaluation (fold 4 only).
# Files are read in a fixed order: inference does not need shuffling, and a fixed order
# keeps the embedding table (and any split made from it) reproducible.
inference_dataset = ptls.data_load.datasets.ParquetDataset(
    data_files=[
        os.path.join(TRX_SUPERVISED_PATH, 'fold=4')
    ],
    i_filters=[
        ptls.data_load.iterable_processing.SeqLenFilter(min_seq_len=16),
        ptls.data_load.iterable_processing.SeqLenFilter(max_seq_len=2048),
        # extra (non-sequential) columns to keep: the client id and the downstream targets
        ptls.data_load.iterable_processing.feature_filter.FeatureFilter(
            keep_feature_names=[
                'client_id',
                'target_1',
                'target_2',
                'target_3',
                'target_4'
            ]
        ),
        QuantilfyAmount(),
        # NOTE(review): 'amount' is clipped to 10 categories here while its embedding
        # table has 11 rows ("in": 11) -- harmless, but confirm which size is intended.
        ptls.data_load.iterable_processing.CategorySizeClip(
            category_max_size={
                'event_type': 58,
                'event_subtype': 59,
                'src_type11': 85,
                'src_type12': 349,
                'dst_type11': 84,
                'dst_type12': 417,
                'src_type22': 90,
                'src_type32': 91,
                'amount': 10
            }
        ),
        GetSplit(
            start_month=1,
            end_month=12,
            col_id='client_id'
        ),
        ptls.data_load.iterable_processing.to_torch_tensor.ToTorch(),
    ],
    shuffle_files=False
)

In [None]:
# No DataLoader is built here: a loader without a collate_fn used to be created and was
# immediately replaced in the next cell. The default collate would try to stack the
# variable-length sequence tensors of this dataset (16..2048 steps after SeqLenFilter), so
# the loader is built once, with `collate_feature_dict_with_target`, in the next cell.

In [None]:
from torch.utils.data import DataLoader

target_col_names = [
    'target_1',
    'target_2',
    'target_3',
    'target_4'
]


def collate_fn(batch, _target_col_names=target_col_names):
    """Collate inference records into (PaddedBatch, client_ids, targets).

    `collate_feature_dict_with_target` is looked up when a batch is collated (during
    `trainer.predict`), not when this cell runs. It is defined in a later cell of this
    section, so binding it here with `functools.partial` fails on a fresh kernel unless an
    earlier definition happens to exist.
    """
    return collate_feature_dict_with_target(batch, target_col_names=_target_col_names)


inference_dl = DataLoader(
    dataset=inference_dataset,
    collate_fn=collate_fn,
    shuffle=False,
    num_workers=0,
    batch_size=32
)

In [None]:
import pandas as pd
import pytorch_lightning as pl
import torch
import numpy as np
from itertools import chain
from ptls.data_load.padded_batch import PaddedBatch
from datetime import datetime
from ptls.custom_layers import StatPooling
from ptls.nn.seq_step import LastStepEncoder


class InferenceModuleMultimodal(pl.LightningModule):
    """Wrap a trained sequence model for ``Trainer.predict`` and emit per-client embeddings.

    Each batch is expected to be ``(padded_batch, client_ids)`` or
    ``(padded_batch, client_ids, targets)``, as produced by
    ``collate_feature_dict_with_target``.

    Args:
        model: module whose ``forward(padded_batch)`` returns a PaddedBatch of per-step outputs.
        pandas_output (bool): return a ``pd.DataFrame`` instead of a dict.
        col_id (str): name of the client-id column in the output.
        target_col_names (list[str] | None): names of the target columns, in collated order.
        model_out_name (str): prefix of the embedding columns (``{name}_0000`` ...).
        model_type (str): stored but not used.
        pooling (str): how per-step outputs become one vector per client:
            ``'first'`` (default, the original behaviour) takes step 0,
            ``'last'`` takes the last valid step (LastStepEncoder),
            ``'stat'`` applies StatPooling.

    Raises:
        ValueError: on an unknown ``pooling`` value.
    """
    _POOLINGS = ('first', 'last', 'stat')

    def __init__(
        self,
        model,
        pandas_output=False,
        col_id='client_id',
        target_col_names=None,
        model_out_name='emb',
        model_type='notab',
        pooling='first'
    ):
        super().__init__()
        if pooling not in self._POOLINGS:
            raise ValueError(f"pooling must be one of {self._POOLINGS}, got {pooling!r}")

        self.model = model
        self.pandas_output = pandas_output
        self.target_col_names = target_col_names
        self.col_id = col_id
        self.model_out_name = model_out_name
        self.model_type = model_type
        self.pooling = pooling

        self.stat_pooler = StatPooling()
        self.last_step = LastStepEncoder()

    def _pool(self, seq_out):
        # Reduce the model's PaddedBatch output to one vector per sequence.
        if self.pooling == 'last':
            return self.last_step(seq_out)
        if self.pooling == 'stat':
            return self.stat_pooler(seq_out)
        return seq_out.payload[:, 0, :]

    def forward(self, x):
        has_targets = len(x) == 3
        if has_targets:
            x, batch_ids, target_cols = x
        else:
            x, batch_ids = x

        out = self._pool(self.model(x))
        x_out = {
            self.col_id: batch_ids,
            self.model_out_name: out
        }

        if has_targets:
            target_cols = torch.tensor(target_cols)
            if target_cols.dim() == 1:
                # flat targets (one value per sample, or the n targets of each sample stored
                # consecutively): regroup per sample so every named target gets its own column
                target_cols = target_cols.view(len(batch_ids), -1)
            for idx, target_col in enumerate(self.target_col_names):
                x_out[target_col] = target_cols[:, idx]

        if self.pandas_output:
            return self.to_pandas(x_out)
        return x_out

    @staticmethod
    def to_pandas(x):
        """Flatten the output dict into a DataFrame.

        1-D values and lists become one column each; 2-D values are expanded into
        ``{key}_0000, {key}_0001, ...`` columns; anything of higher rank becomes a column
        of ``None``.
        """
        scalar_features = {}
        wide_features = {}

        for k, v in x.items():
            if isinstance(v, torch.Tensor):
                v = v.cpu().numpy()

            if isinstance(v, list) or v.ndim == 1:
                scalar_features[k] = v
            elif v.ndim == 2:
                # keep the converted array so numpy inputs work as well as tensors
                wide_features[k] = v
            else:
                scalar_features[k] = None

        dataframes = [pd.DataFrame(scalar_features)]
        for col, v in wide_features.items():
            dataframes.append(pd.DataFrame(v, columns=[f'{col}_{i:04d}' for i in range(v.shape[1])]))

        return pd.concat(dataframes, axis=1)

def collate_feature_dict_with_target(batch, col_id='client_id', target_col_names=None):
    """Collate feature-dict records, pulling out the client id and (optionally) targets.

    The input records are left untouched: each sample is copied without the id/target keys
    before collation. Deleting keys in place would corrupt datasets that hand out cached
    record objects.

    Args:
        batch: list of dict records.
        col_id: key of the client id in every record.
        target_col_names: keys of the target values to extract, or None.

    Returns:
        ``(padded_batch, batch_ids)`` when ``target_col_names`` is None, otherwise
        ``(padded_batch, batch_ids, target_cols)`` where ``target_cols`` is a list with one
        list of target values per sample, in ``target_col_names`` order.
    """
    drop_keys = {col_id} | set(target_col_names or ())
    batch_ids = []
    target_cols = []
    features = []
    for sample in batch:
        batch_ids.append(sample[col_id])
        if target_col_names is not None:
            target_cols.append([sample[target_col] for target_col in target_col_names])
        features.append({k: v for k, v in sample.items() if k not in drop_keys})

    padded_batch = collate_feature_dict(features)
    if target_col_names is not None:
        return padded_batch, batch_ids, target_cols
    return padded_batch, batch_ids

In [None]:
# Wrap the pretrained module for Trainer.predict. With pandas_output=True every batch becomes
# a DataFrame with client_id, the target columns and emb_XXXX embedding columns.
inf_module = InferenceModuleMultimodal(
    pl_module,
    pandas_output=True,
    col_id='client_id',
    target_col_names=target_col_names,
)

In [None]:
# predict returns one DataFrame per batch, each indexed 0..batch_size-1; ignore_index gives
# the concatenated table a clean 0..N-1 index instead of repeated labels.
inf_embeddings = pd.concat(trainer.predict(inf_module, dataloaders=inference_dl), ignore_index=True)

In [None]:
inf_embeddings.head(3)

In [None]:
from sklearn.model_selection import train_test_split

# Fixed seed so the train/test split (and the ROC-AUC values below) is reproducible;
# 42 matches the random_state of the LightGBM models.
dwns_train, dwns_test = train_test_split(inf_embeddings, test_size=0.2, random_state=42)

In [None]:
# Separate the target matrix (one column per downstream target, in target_1..target_4 order)
# from the feature matrix (embedding columns only).
downstream_targets = ['target_1', 'target_2', 'target_3', 'target_4']

targets_train = np.array([dwns_train[col].to_numpy() for col in downstream_targets]).T
targets_test = np.array([dwns_test[col].to_numpy() for col in downstream_targets]).T

dwns_train = dwns_train.drop(columns=['client_id', *downstream_targets]).to_numpy()
dwns_test = dwns_test.drop(columns=['client_id', *downstream_targets]).to_numpy()

In [None]:
from lightgbm import LGBMClassifier

# One independent binary classifier per downstream target, all sharing the same
# hyper-parameters.
lgbm_params = dict(
    n_estimators=500,
    boosting_type='gbdt',
    subsample=0.5,
    subsample_freq=1,
    learning_rate=0.02,
    feature_fraction=0.75,
    max_depth=6,
    lambda_l1=1,
    lambda_l2=1,
    min_data_in_leaf=50,
    random_state=42,
    n_jobs=8,
    verbose=-1,
)
models = [LGBMClassifier(**lgbm_params) for _ in range(4)]

In [None]:
# Train classifier k on target column k (columns follow target_1..target_4).
for k in range(4):
    clf = models[k]
    clf.fit(dwns_train, targets_train[:, k])

In [None]:
from sklearn.metrics import classification_report
from sklearn.metrics import roc_auc_score

# Report test ROC-AUC per downstream target. targets_test columns are ordered
# target_1..target_4, so label with the 1-based target name (the loop index is 0-based).
for i, model in enumerate(models):
    preds = model.predict_proba(dwns_test)
    print(f"ROC-AUC target_{i + 1} = {roc_auc_score(targets_test[:, i], preds[:, 1])}")

# GPT with unitregular

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
import numpy as np


#------------------------------------------------------------------------------------------------------------
# Based on https://github.com/yukara-ikemiya/Swin-Transformer-1d/tree/main and adapted to pytorch-lifestream
#------------------------------------------------------------------------------------------------------------

import torch
import torch.nn as nn
from ptls.data_load.padded_batch import PaddedBatch


def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
    # adapted from timm/models/layers/drop.py
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Each sample along dim 0 is either kept (optionally rescaled by ``1 / keep_prob``) or
    zeroed entirely, with probability ``drop_prob``.

    Args:
        x: input tensor; dim 0 is treated as the sample dimension.
        drop_prob: probability of zeroing a sample. ``None`` and ``0.`` disable dropping.
        training: dropping only happens when True; otherwise ``x`` is returned unchanged.
        scale_by_keep: rescale kept samples so the expected value is unchanged.

    Returns:
        ``x`` itself when dropping is disabled, otherwise a new tensor of the same shape.
    """
    # Treat drop_prob=None like 0 instead of failing on `1 - None` below.
    if drop_prob is None or drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    # one Bernoulli draw per sample, broadcast over every remaining dim
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0 and scale_by_keep:
        random_tensor.div_(keep_prob)
    return x * random_tensor


class DropPath(nn.Module):
    # adapted from timm/models/layers/drop.py
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Args:
        drop_prob (float | None): probability of zeroing a sample's residual branch while
            training. ``None`` is still accepted and is treated as ``0.`` (no dropping).
        scale_by_keep (bool): rescale kept samples by ``1 / (1 - drop_prob)``.
    """

    def __init__(self, drop_prob=0., scale_by_keep=True):
        super().__init__()
        # Normalise None so drop_path() never evaluates `1 - None` in training mode.
        self.drop_prob = 0. if drop_prob is None else drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)


class Mlp(nn.Module):
    """Two-layer feed-forward network: Linear -> activation -> Dropout -> Linear -> Dropout.

    Args:
        in_features (int): size of the last input dimension.
        hidden_features (int | None): hidden width; defaults to ``in_features``.
        out_features (int | None): output width; defaults to ``in_features``.
        act_layer (type[nn.Module]): activation class, instantiated without arguments.
        drop (float): dropout probability, applied after the activation and after ``fc2``.
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden = hidden_features or in_features
        out = out_features or in_features
        # attribute names are the state_dict keys and creation order fixes seeded init
        self.fc1 = nn.Linear(in_features, hidden)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden, out)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        # the same Dropout module is applied on both sides of fc2
        for layer in (self.fc1, self.act, self.drop, self.fc2, self.drop):
            x = layer(x)
        return x


def window_partition(x, window_size):
    """Cut a batch of sequences into consecutive, non-overlapping windows.

    Args:
        x: tensor of shape (B, L, C); L must be divisible by ``window_size``.
        window_size (int): number of steps per window.

    Returns:
        Tensor of shape (B * L // window_size, window_size, C): all windows of sample 0
        in sequence order, then those of sample 1, and so on.
    """
    batch, length, channels = x.shape
    per_sample = x.view(batch, length // window_size, window_size, channels)
    return per_sample.contiguous().view(-1, window_size, channels)


def window_reverse(windows, window_size, L):
    """Stitch windows back into full sequences (the inverse of ``window_partition``).

    Args:
        windows: tensor of shape (num_windows * B, window_size, C).
        window_size (int): number of steps per window.
        L (int): full (padded) sequence length, a multiple of ``window_size``.

    Returns:
        Tensor of shape (B, L, C).
    """
    windows_per_sample = L / window_size
    batch = int(windows.shape[0] / windows_per_sample)
    stitched = windows.view(batch, L // window_size, window_size, -1)
    return stitched.contiguous().view(batch, L, -1)


class WindowAttention(nn.Module):
    """ Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels. Must be divisible by ``num_heads``.
        window_size (int): The width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0

    Raises:
        ValueError: if ``dim`` is not divisible by ``num_heads``. The per-head reshape in
            ``forward`` cannot work in that case, so fail at construction instead of on the
            first batch.
    """
    def __init__(self, dim: int, window_size: int, num_heads: int,
                 qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):

        super().__init__()
        if dim % num_heads != 0:
            raise ValueError(f'dim ({dim}) must be divisible by num_heads ({num_heads})')
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # a falsy qk_scale (None or 0) falls back to the default 1/sqrt(head_dim)
        self.scale = qk_scale or head_dim ** -0.5

        # learnable bias per relative offset -(W-1)..(W-1) and per head: (2*W - 1, nH)
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros(2 * window_size - 1, num_heads))

        # pair-wise relative position index for each token inside the window
        coords_w = torch.arange(self.window_size)
        relative_coords = coords_w[:, None] - coords_w[None, :]  # W, W
        relative_coords[:, :] += self.window_size - 1  # shift to start from 0

        # relative_position_index | example for W = 3
        # [2, 1, 0]
        # [3, 2, 1]
        # [4, 3, 2]
        self.register_buffer("relative_position_index", relative_coords)  # (W, W): range of 0 -- 2*(W-1)

        # creation order below is unchanged so seeded initialisation stays identical
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        torch.nn.init.trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask_add, mask_mult):
        """Masked multi-head self-attention inside each window.

        Args:
            x: window features of shape (num_windows*B, W, C), windows grouped per sample
                (sample-major, as produced by ``window_partition``).
            mask_add: additive (0 / -inf) mask broadcastable to
                (B, num_windows, num_heads, W, W); ``num_windows`` is read from its dim 1.
                Added to the attention logits before the softmax.
            mask_mult: multiplicative (1 / 0) mask broadcastable to
                (num_windows*B, num_heads, W, W); applied to the attention weights after
                the softmax.

        Returns:
            Tensor of shape (num_windows*B, W, C).
        """
        B_, N, C = x.shape
        # (3, B_, nH, N, head_dim)
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size, self.window_size, -1)  # W, W, nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, W, W
        attn = attn + relative_position_bias.unsqueeze(0)

        # add the per-window mask, then flatten back to (num_windows*B, nH, N, N)
        nW = mask_add.shape[1]
        attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask_add
        attn = attn.view(-1, self.num_heads, N, N)
        attn = self.softmax(attn)
        attn = attn * mask_mult

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)

        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def extra_repr(self) -> str:
        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'


class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block for 1D sequences given as `PaddedBatch`.

    The time axis is zero-padded up to a multiple of `window_size`. Tokens beyond `seq_lens`
    (and the added padding) are zeroed and excluded from attention. Then
    ``x = x + DropPath(W-MSA/SW-MSA(norm1(x)))`` and ``x = x + DropPath(MLP(norm2(x)))``.

    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA; must satisfy 0 <= shift_size < window_size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm

    Notes:
        * The returned payload keeps the padded length (a multiple of `window_size`), which can be
          longer than the input. Callers that need the original length must crop it.
        * The cyclic shift is not accompanied by the extra region mask of the original 2D Swin.
          In the last shifted window, valid tokens from the end of a sequence can therefore attend
          to valid tokens from its beginning.
    """
    def __init__(self, dim, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=self.window_size, num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        # Registered as an empty buffer only; forward() builds its masks per batch.
        attn_mask = None
        self.register_buffer("attn_mask", attn_mask)

    def forward(self, x):
        """Apply one (shifted) window-attention block.

        Args:
            x: PaddedBatch whose payload is (B, L, C) and whose seq_lens holds per-sample lengths.

        Returns:
            PaddedBatch with payload (B, L', C), where L' is L rounded up to a multiple of
            `window_size`, and the same seq_lens.
        """
        seq_lens = x.seq_lens
        x = x.payload

        B, L, C = x.shape

        # Validity mask (B, L, 1): 1 where position + 1 <= seq_len, 0 for padding.
        # It is built in the payload's dtype so `x * mask` does not upcast half/bfloat16 inputs.
        positions = torch.arange(L, device=x.device)
        mask = ((positions[None, :] + 1) <= seq_lens[:, None]).to(x.dtype)[:, :, None]

        # make new max seq_len `L` divisible by `self.window_size` by adding 'zero' samples
        num_samples_to_add = -L % self.window_size
        if num_samples_to_add > 0:
            # new_zeros keeps the dtype and device of the tensor being extended
            x = torch.cat((x, x.new_zeros((B, num_samples_to_add, C))), dim=1)
            mask = torch.cat((mask, mask.new_zeros((B, num_samples_to_add, 1))), dim=1)
            L += num_samples_to_add

        # mask out padding transactions
        x = x * mask

        assert L >= self.window_size, f'input length ({L}) must be >= window size ({self.window_size})'
        assert L % self.window_size == 0, f'input length ({L}) must be divisible by window size ({self.window_size})'

        shortcut = x
        x = self.norm1(x)

        # cyclic shift (SW-MSA); the validity mask is rolled together with the tokens
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=-self.shift_size, dims=1)
            mask = torch.roll(mask, shifts=-self.shift_size, dims=1)
        else:
            shifted_x = x

        # partition
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, C
        mask_windows = window_partition(mask, self.window_size)  # nW*B, window_size, 1

        # pair_mask[w, i, j] = 1 iff tokens i and j of window w are both valid (outer product)
        pair_mask = mask_windows * mask_windows.transpose(-2, -1)  # nW*B, window_size, window_size

        # multiplicative mask, applied to the attention weights after the softmax
        attn_mask_real = pair_mask.unsqueeze(1).expand(-1, self.num_heads, -1, -1)  # nW*B, nH, ws, ws

        # Additive mask applied before the softmax: -inf for pairs involving padding, 0 otherwise.
        # The diagonal is always 0, so no row is entirely -inf (that would make softmax return NaN).
        attn_mask = torch.zeros_like(pair_mask).masked_fill(pair_mask == 0, float('-inf'))
        diag = torch.arange(self.window_size, device=attn_mask.device)
        attn_mask[:, diag, diag] = 0.
        attn_mask = attn_mask.view(B, -1, self.window_size, self.window_size).unsqueeze(2).expand(
            -1, -1, self.num_heads, -1, -1)  # B, nW, nH, window_size, window_size

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask_add=attn_mask, mask_mult=attn_mask_real)  # nW*B, window_size, C

        # merge windows
        shifted_x = window_reverse(attn_windows, self.window_size, L)  # (B, L, C)

        # undo the cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=self.shift_size, dims=1)
        else:
            x = shifted_x

        x = shortcut + self.drop_path(x)

        # FFN
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return PaddedBatch(x, seq_lens)

    def extra_repr(self) -> str:
        return f"dim={self.dim}, num_heads={self.num_heads}, " \
               f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"

class SwinTransformerV2Layer(nn.Module):
    """ A basic Swin Transformer layer for one stage.

    Stacks `depth` SwinTransformerBlocks. Even-indexed blocks use regular windows (shift 0) and
    odd-indexed blocks use shifted windows (shift ``window_size // 2``). Input and output are
    `PaddedBatch` objects.

    Args:
        dim (int): Number of input channels.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | list[float] | tuple[float], optional): Stochastic depth rate, either shared
            by all blocks or given per block (needs at least `depth` entries). Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        use_checkpoint (bool): Whether to use (non-reentrant) activation checkpointing to save memory. Default: False.
        pretrained_window_size (int): Accepted for interface compatibility; not used.
    """

    def __init__(
        self,
        dim: int,
        depth: int,
        num_heads: int,
        window_size: int,
        mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
        drop_path=0., norm_layer=nn.LayerNorm, use_checkpoint=False,
        pretrained_window_size=0
    ):

        super().__init__()
        self.dim = dim
        self.depth = depth
        self.num_heads = num_heads
        self.window_size = window_size
        self.use_checkpoint = use_checkpoint

        # drop_path may be one rate for all blocks or a per-block sequence (list or tuple)
        per_block_drop_path = isinstance(drop_path, (list, tuple))

        # build blocks: even blocks use regular windows, odd blocks use windows shifted by half
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(dim=dim,
                                 num_heads=num_heads, window_size=window_size,
                                 shift_size=0 if (i % 2 == 0) else window_size // 2,
                                 mlp_ratio=mlp_ratio,
                                 qkv_bias=qkv_bias,
                                 drop=drop, attn_drop=attn_drop,
                                 drop_path=drop_path[i] if per_block_drop_path else drop_path,
                                 norm_layer=norm_layer)
            for i in range(depth)])

    def forward(self, x):
        """Run all blocks in order; `x` is a PaddedBatch and so is the result."""
        for blk in self.blocks:
            if self.use_checkpoint:
                # The reentrant checkpoint implementation does not treat tensors nested inside
                # custom objects (here a PaddedBatch) as autograd inputs, so gradients would not
                # flow through the block. The non-reentrant variant handles them.
                x = checkpoint.checkpoint(blk, x, use_reentrant=False)
            else:
                x = blk(x)
        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, depth={self.depth}, num_heads={self.num_heads}, window_size={self.window_size}"

    @property
    def embedding_size(self):
        """Size of the output embeddings (equal to the input `dim`)."""
        return self.dim

    def flops(self):
        # NOTE(review): SwinTransformerBlock in this notebook defines no `flops()` method,
        # so calling this raises AttributeError until one is added.
        flops = 0
        for blk in self.blocks:
            flops += blk.flops()
        return flops

    def _init_respostnorm(self):
        """Zero-initialise weight and bias of `norm1` and `norm2` in every block."""
        for blk in self.blocks:
            nn.init.constant_(blk.norm1.bias, 0)
            nn.init.constant_(blk.norm1.weight, 0)
            nn.init.constant_(blk.norm2.bias, 0)
            nn.init.constant_(blk.norm2.weight, 0)

In [None]:
import math

class GPTInferenceModule(torch.nn.Module):
    """Wraps a pretrained GPT-style module to produce one embedding per sequence.

    The pooling is chosen by `pretrained_model.hparams.inference_pooling_strategy`:
        'trx_stat_out' - StatPooling over transaction embeddings concatenated with the last step
                         of the sequence-encoder output
        'trx_stat'     - StatPooling over transaction embeddings
        'out_stat'     - StatPooling over the sequence-encoder output
        'out'          - last step of the sequence-encoder output
    If `hparams.norm_predict` is truthy, the result is L2-normalised along the last dimension.

    Note: sets `pretrained_model.is_reduce_sequence = False` on the passed model.
    """
    def __init__(self, pretrained_model):
        super().__init__()
        self.model = pretrained_model
        self.model.is_reduce_sequence = False

        self.stat_pooler = StatPooling()
        self.last_step = LastStepEncoder()

    def forward(self, batch):
        """Encode `batch` and pool it to one vector per sequence.

        Raises:
            ValueError: if `inference_pooling_strategy` is not one of the supported values.
        """
        z_trx = self.model.trx_encoder(batch)
        out = self.model._seq_encoder(z_trx)
        # wrap a plain-tensor encoder output so the poolers receive seq_lens
        out = out if isinstance(out, PaddedBatch) else PaddedBatch(out, batch.seq_lens)

        strategy = self.model.hparams.inference_pooling_strategy
        if strategy == 'trx_stat_out':
            stats = self.stat_pooler(z_trx)
            out = self.last_step(out)
            out = torch.cat([stats, out], dim=1)
        elif strategy == 'trx_stat':
            out = self.stat_pooler(z_trx)
        elif strategy == 'out_stat':
            out = self.stat_pooler(out)
        elif strategy == 'out':
            out = self.last_step(out)
        else:
            # A bare `raise` here only produced "No active exception to reraise".
            raise ValueError(f'Unknown inference_pooling_strategy: {strategy!r}')

        if self.model.hparams.norm_predict:
            # the 1e-9 guards against division by zero for all-zero embeddings
            out = out / (out.pow(2).sum(dim=-1, keepdim=True) + 1e-9).pow(0.5)
        return out

class SWINPretrainModule(GptPretrainModule):
    """GPT-style pretraining module with a per-chunk feature encoder in front of the sequence encoder.

    Forward pipeline:
        trx_encoder -> split the last payload dim into chunks of `feature_dim` -> feature_encoder
        -> seq_encoder (must expose `window_size`, e.g. SwinTransformerV2Layer)
        -> crop back to the input length (the Swin blocks pad to a multiple of window_size).

    Args:
        feature_encoder: module applied to the trx-encoder payload reshaped to
            ``payload.shape[:-1] + (-1, feature_dim)``.
        seq_encoder: sequence encoder, passed to GptPretrainModule.
        trx_encoder: transaction encoder, passed to GptPretrainModule. Its `numeric_values`
            attribute is replaced by an empty dict on the passed object.
        *args: forwarded positionally to GptPretrainModule after the two encoders.
        feature_dim (int): chunk size used to split the transaction embedding. Default: 24.
        **kwargs: forwarded to GptPretrainModule.
    """
    def __init__(self, feature_encoder, seq_encoder, trx_encoder, *args, feature_dim: int = 24, **kwargs):
        trx_encoder.numeric_values = {}
        # Pass the encoders positionally (GptPretrainModule is assumed to take trx_encoder and
        # seq_encoder as its first two parameters). Passing them as keywords together with
        # non-empty *args would bind args[0] to `trx_encoder` a second time and raise TypeError.
        super().__init__(trx_encoder, seq_encoder, *args, **kwargs)
        self.save_hyperparameters(ignore=['trx_encoder', 'seq_encoder', 'feature_encoder'])
        self.feature_encoder = feature_encoder
        self.feature_dim = feature_dim
        self.window_size = self._seq_encoder.window_size

    def forward(self, batch):
        z_trx = self.trx_encoder(batch)

        # split each transaction embedding into chunks of `feature_dim` and run the feature encoder
        payload = z_trx.payload.view(z_trx.payload.shape[:-1] + (-1, self.feature_dim))
        payload = self.feature_encoder(payload)
        feature_embed = PaddedBatch(payload, z_trx.seq_lens)

        out = self._seq_encoder(feature_embed)
        # the Swin sequence encoder returns a length padded up to a multiple of window_size; crop it back
        out = PaddedBatch(out.payload[:, :payload.shape[1], :], z_trx.seq_lens)

        if self.hparams.norm_predict:
            out = self.fn_norm_predict(out)
        return out

# CoLES with SWIN enc

In [None]:
def _make_coles_dataset(data, split_count=5, cnt_min=16, cnt_max=400):
    """Wrap `data` in a ColesIterableDataset with a SampleSlices(split_count, cnt_min, cnt_max) splitter."""
    return ptls.frames.coles.ColesIterableDataset(
        splitter=ptls.frames.coles.split_strategy.SampleSlices(
            split_count=split_count,
            cnt_min=cnt_min,
            cnt_max=cnt_max,
        ),
        data=data,
    )


# Train and valid share the same slicing config; only batch sizes differ.
data_module = ptls.frames.PtlsDataModule(
    train_data=_make_coles_dataset(train),
    valid_data=_make_coles_dataset(valid),
    train_batch_size=64,
    train_num_workers=0,
    valid_batch_size=32,
    valid_num_workers=0,
)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
import numpy as np


#------------------------------------------------------------------------------------------------------------
# Based on https://github.com/yukara-ikemiya/Swin-Transformer-1d/tree/main and adapted to pytorch-lifestream
#------------------------------------------------------------------------------------------------------------

import torch
import torch.nn as nn
from ptls.data_load.padded_batch import PaddedBatch


def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
    # copied from timm/models/layers/drop.py (with drop_prob=None treated like 0)
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.

    Args:
        x: tensor whose first dimension indexes samples.
        drop_prob: probability of zeroing a whole sample. None or 0 disables dropping.
        training: dropping only happens when True; otherwise `x` is returned unchanged.
        scale_by_keep: if True (and keep probability > 0), kept samples are divided by the keep
            probability so the expected value is preserved.

    Returns:
        `x` itself when disabled, otherwise `x` multiplied by a per-sample 0/1 (optionally rescaled) factor.
    """
    # `not drop_prob` also covers None, the DropPath default, which previously failed on `1 - None`.
    if not drop_prob or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0 and scale_by_keep:
        random_tensor.div_(keep_prob)
    return x * random_tensor


class DropPath(nn.Module):
    # copied from timm/models/layers/drop.py
    """Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).

    Args:
        drop_prob (float): probability of dropping the branch for a sample. Default: 0. (as in timm);
            a default-constructed module is therefore a no-op rather than failing in training mode.
        scale_by_keep (bool): passed through to `drop_path`; if True, kept samples are rescaled
            by the keep probability. Default: True
    """

    def __init__(self, drop_prob=0., scale_by_keep=True):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        # only drops while the module is in training mode
        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)


class Mlp(nn.Module):
    """Two-layer feed-forward network: Linear -> activation -> dropout -> Linear -> dropout.

    Args:
        in_features: input size.
        hidden_features: hidden size; falls back to `in_features` when falsy (None or 0).
        out_features: output size; falls back to `in_features` when falsy (None or 0).
        act_layer: activation class, instantiated without arguments. Default: nn.GELU
        drop: dropout probability, shared by both dropout sites. Default: 0.
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden = hidden_features or in_features
        out = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden, out)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))


def window_partition(x, window_size):
    """Cut each sequence into consecutive, non-overlapping windows.

    Args:
        x: (B, L, C); L is expected to be a multiple of `window_size`
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, C), grouped batch-major
            (all windows of sample 0 first, then sample 1, ...)
    """
    batch, length, channels = x.shape
    n_windows = length // window_size
    grouped = x.view(batch, n_windows, window_size, channels).contiguous()
    return grouped.view(-1, window_size, channels)


def window_reverse(windows, window_size, L):
    """Inverse of `window_partition`: merge windows back into full sequences.

    Args:
        windows: (num_windows*B, window_size, C), grouped batch-major
        window_size (int): Window size
        L (int): Length of data; must be a positive multiple of `window_size`

    Returns:
        x: (B, L, C)

    Raises:
        ValueError: if `L` is not a positive multiple of `window_size`, or the number of windows
            is not a multiple of ``L // window_size``.
    """
    if L <= 0 or L % window_size != 0:
        raise ValueError(f'L ({L}) must be a positive multiple of window_size ({window_size})')
    windows_per_seq = L // window_size
    if windows.shape[0] % windows_per_seq != 0:
        raise ValueError(f'number of windows ({windows.shape[0]}) is not a multiple of '
                         f'L // window_size ({windows_per_seq})')
    # Integer arithmetic: the previous float division silently truncated on inconsistent input.
    B = windows.shape[0] // windows_per_seq
    x = windows.view(B, windows_per_seq, window_size, -1)
    x = x.contiguous().view(B, L, -1)
    return x


class WindowAttention(nn.Module):
    """Window-based multi-head self-attention (W-MSA / SW-MSA) with a learned 1D relative position bias.

    Tokens attend only to tokens of the same window. Shifting for SW-MSA is done by the caller;
    this module only sees already-partitioned windows.

    Args:
        dim (int): Number of input channels.
        window_size (int): Number of tokens per window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
            (any falsy value, including 0, falls back to the default).
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """
    def __init__(self, dim: int, window_size: int, num_heads: int,
                 qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):

        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        self.scale = qk_scale or (dim // num_heads) ** -0.5

        # learnable bias per head for each relative offset in [-(W-1), W-1] -> (2*W - 1, nH)
        self.relative_position_bias_table = nn.Parameter(torch.zeros(2 * window_size - 1, num_heads))

        # relative_position_index[i, j] = i - j + (W - 1), a row of the table above; for W = 3:
        #   [[2, 1, 0],
        #    [3, 2, 1],
        #    [4, 3, 2]]
        idx = torch.arange(self.window_size)
        self.register_buffer("relative_position_index", idx[:, None] - idx[None, :] + (self.window_size - 1))

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        torch.nn.init.trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask_add, mask_mult):
        """Run multi-head self-attention independently inside every window.

        Args:
            x: window tokens of shape (num_windows*B, W, C); windows are grouped batch-major,
                i.e. row ``b * num_windows + w`` is window ``w`` of sample ``b``.
            mask_add: additive logit mask (0 / -inf) of shape (B, num_windows, num_heads, W, W),
                added before the softmax. Required: ``None`` is not accepted.
            mask_mult: multiplicative weight mask (0 / 1), broadcastable to
                (num_windows*B, num_heads, W, W), applied after the softmax.

        Returns:
            Tensor of shape (num_windows*B, W, C).
        """
        n_rows, n_tok, channels = x.shape
        head_dim = channels // self.num_heads

        # (3, num_windows*B, num_heads, W, head_dim)
        projected = self.qkv(x).reshape(n_rows, n_tok, 3, self.num_heads, head_dim).permute(2, 0, 3, 1, 4)
        query, key, value = projected.unbind(0)

        logits = (query * self.scale) @ key.transpose(-2, -1)

        # look up a per-head bias for every (i, j) offset inside the window -> (num_heads, W, W)
        flat_index = self.relative_position_index.view(-1)
        bias = self.relative_position_bias_table[flat_index].view(self.window_size, self.window_size, -1)
        logits = logits + bias.permute(2, 0, 1).contiguous().unsqueeze(0)

        # the additive mask is laid out per (sample, window): fold it in, then flatten back
        n_windows = mask_add.shape[1]
        logits = logits.view(n_rows // n_windows, n_windows, self.num_heads, n_tok, n_tok) + mask_add
        weights = self.softmax(logits.view(-1, self.num_heads, n_tok, n_tok))
        # the multiplicative mask is applied after normalisation; weights are not renormalised
        weights = self.attn_drop(weights * mask_mult)

        context = (weights @ value).transpose(1, 2).reshape(n_rows, n_tok, channels)
        return self.proj_drop(self.proj(context))

    def extra_repr(self) -> str:
        parts = (f'dim={self.dim}', f'window_size={self.window_size}', f'num_heads={self.num_heads}')
        return ', '.join(parts)


class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block for 1D sequences given as `PaddedBatch`.

    The time axis is zero-padded up to a multiple of `window_size`. Tokens beyond `seq_lens`
    (and the added padding) are zeroed and excluded from attention. Then
    ``x = x + DropPath(W-MSA/SW-MSA(norm1(x)))`` and ``x = x + DropPath(MLP(norm2(x)))``.

    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA; must satisfy 0 <= shift_size < window_size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm

    Notes:
        * The returned payload keeps the padded length (a multiple of `window_size`), which can be
          longer than the input. Callers that need the original length must crop it.
        * The cyclic shift is not accompanied by the extra region mask of the original 2D Swin.
          In the last shifted window, valid tokens from the end of a sequence can therefore attend
          to valid tokens from its beginning.
    """
    def __init__(self, dim, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=self.window_size, num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        # Registered as an empty buffer only; forward() builds its masks per batch.
        attn_mask = None
        self.register_buffer("attn_mask", attn_mask)

    def forward(self, x):
        """Apply one (shifted) window-attention block.

        Args:
            x: PaddedBatch whose payload is (B, L, C) and whose seq_lens holds per-sample lengths.

        Returns:
            PaddedBatch with payload (B, L', C), where L' is L rounded up to a multiple of
            `window_size`, and the same seq_lens.
        """
        seq_lens = x.seq_lens
        x = x.payload

        B, L, C = x.shape

        # Validity mask (B, L, 1): 1 where position + 1 <= seq_len, 0 for padding.
        # It is built in the payload's dtype so `x * mask` does not upcast half/bfloat16 inputs.
        positions = torch.arange(L, device=x.device)
        mask = ((positions[None, :] + 1) <= seq_lens[:, None]).to(x.dtype)[:, :, None]

        # make new max seq_len `L` divisible by `self.window_size` by adding 'zero' samples
        num_samples_to_add = -L % self.window_size
        if num_samples_to_add > 0:
            # new_zeros keeps the dtype and device of the tensor being extended
            x = torch.cat((x, x.new_zeros((B, num_samples_to_add, C))), dim=1)
            mask = torch.cat((mask, mask.new_zeros((B, num_samples_to_add, 1))), dim=1)
            L += num_samples_to_add

        # mask out padding transactions
        x = x * mask

        assert L >= self.window_size, f'input length ({L}) must be >= window size ({self.window_size})'
        assert L % self.window_size == 0, f'input length ({L}) must be divisible by window size ({self.window_size})'

        shortcut = x
        x = self.norm1(x)

        # cyclic shift (SW-MSA); the validity mask is rolled together with the tokens
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=-self.shift_size, dims=1)
            mask = torch.roll(mask, shifts=-self.shift_size, dims=1)
        else:
            shifted_x = x

        # partition
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, C
        mask_windows = window_partition(mask, self.window_size)  # nW*B, window_size, 1

        # pair_mask[w, i, j] = 1 iff tokens i and j of window w are both valid (outer product)
        pair_mask = mask_windows * mask_windows.transpose(-2, -1)  # nW*B, window_size, window_size

        # multiplicative mask, applied to the attention weights after the softmax
        attn_mask_real = pair_mask.unsqueeze(1).expand(-1, self.num_heads, -1, -1)  # nW*B, nH, ws, ws

        # Additive mask applied before the softmax: -inf for pairs involving padding, 0 otherwise.
        # The diagonal is always 0, so no row is entirely -inf (that would make softmax return NaN).
        attn_mask = torch.zeros_like(pair_mask).masked_fill(pair_mask == 0, float('-inf'))
        diag = torch.arange(self.window_size, device=attn_mask.device)
        attn_mask[:, diag, diag] = 0.
        attn_mask = attn_mask.view(B, -1, self.window_size, self.window_size).unsqueeze(2).expand(
            -1, -1, self.num_heads, -1, -1)  # B, nW, nH, window_size, window_size

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask_add=attn_mask, mask_mult=attn_mask_real)  # nW*B, window_size, C

        # merge windows
        shifted_x = window_reverse(attn_windows, self.window_size, L)  # (B, L, C)

        # undo the cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=self.shift_size, dims=1)
        else:
            x = shifted_x

        x = shortcut + self.drop_path(x)

        # FFN
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return PaddedBatch(x, seq_lens)

    def extra_repr(self) -> str:
        return f"dim={self.dim}, num_heads={self.num_heads}, " \
               f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"

class SwinTransformerV2Layer(nn.Module):
    """ A basic Swin Transformer layer for one stage.

    Stacks `depth` SwinTransformerBlocks. Even-indexed blocks use regular windows (shift 0) and
    odd-indexed blocks use shifted windows (shift ``window_size // 2``). Input and output are
    `PaddedBatch` objects.

    Args:
        dim (int): Number of input channels.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | list[float] | tuple[float], optional): Stochastic depth rate, either shared
            by all blocks or given per block (needs at least `depth` entries). Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        use_checkpoint (bool): Whether to use (non-reentrant) activation checkpointing to save memory. Default: False.
        pretrained_window_size (int): Accepted for interface compatibility; not used.
    """

    def __init__(
        self,
        dim: int,
        depth: int,
        num_heads: int,
        window_size: int,
        mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
        drop_path=0., norm_layer=nn.LayerNorm, use_checkpoint=False,
        pretrained_window_size=0
    ):

        super().__init__()
        self.dim = dim
        self.depth = depth
        self.num_heads = num_heads
        self.window_size = window_size
        self.use_checkpoint = use_checkpoint

        # drop_path may be one rate for all blocks or a per-block sequence (list or tuple)
        per_block_drop_path = isinstance(drop_path, (list, tuple))

        # build blocks: even blocks use regular windows, odd blocks use windows shifted by half
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(dim=dim,
                                 num_heads=num_heads, window_size=window_size,
                                 shift_size=0 if (i % 2 == 0) else window_size // 2,
                                 mlp_ratio=mlp_ratio,
                                 qkv_bias=qkv_bias,
                                 drop=drop, attn_drop=attn_drop,
                                 drop_path=drop_path[i] if per_block_drop_path else drop_path,
                                 norm_layer=norm_layer)
            for i in range(depth)])

    def forward(self, x):
        """Run all blocks in order; `x` is a PaddedBatch and so is the result."""
        for blk in self.blocks:
            if self.use_checkpoint:
                # The reentrant checkpoint implementation does not treat tensors nested inside
                # custom objects (here a PaddedBatch) as autograd inputs, so gradients would not
                # flow through the block. The non-reentrant variant handles them.
                x = checkpoint.checkpoint(blk, x, use_reentrant=False)
            else:
                x = blk(x)
        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, depth={self.depth}, num_heads={self.num_heads}, window_size={self.window_size}"

    @property
    def embedding_size(self):
        """Size of the output embeddings (equal to the input `dim`)."""
        return self.dim

    def flops(self):
        # NOTE(review): SwinTransformerBlock in this notebook defines no `flops()` method,
        # so calling this raises AttributeError until one is added.
        flops = 0
        for blk in self.blocks:
            flops += blk.flops()
        return flops

    def _init_respostnorm(self):
        """Zero-initialise weight and bias of `norm1` and `norm2` in every block."""
        for blk in self.blocks:
            nn.init.constant_(blk.norm1.bias, 0)
            nn.init.constant_(blk.norm1.weight, 0)
            nn.init.constant_(blk.norm2.bias, 0)
            nn.init.constant_(blk.norm2.weight, 0)

In [None]:
import warnings

import torch

# from ptls.constant_repository import TORCH_EMB_DTYPE
from ptls.data_load.padded_batch import PaddedBatch
from ptls.nn.trx_encoder.batch_norm import RBatchNorm, RBatchNormWithLens
from ptls.nn.trx_encoder.noisy_embedding import NoisyEmbedding
from ptls.nn.trx_encoder.trx_encoder_base import TrxEncoderBase

# Local stand-in for the commented-out `ptls.constant_repository.TORCH_EMB_DTYPE` import above;
# TrxEncoderSWIN.forward casts the concatenated embeddings to this dtype.
TORCH_EMB_DTYPE = torch.float32

class TrxEncoderSWIN(TrxEncoderBase):
    """Network layer which makes representation for single transactions and then mixes
    neighbouring transactions with a `SwinTransformerV2Layer`.

    Input is `PaddedBatch` with ptls-format dictionary, with feature arrays of shape (B, T)
    Output is `PaddedBatch` with transaction embeddings of shape (B, T, H)
    where:
        B - batch size, sequence count in batch
        T - sequence length
        H - hidden size, representation dimension

    `ptls.nn.trx_encoder.noisy_embedding.NoisyEmbedding` implementation is used for categorical features.

    Parameters
        embeddings:
            dict with categorical feature names.
            Values must be like this `{'in': dictionary_size, 'out': embedding_size}`
            These features will be encoded with lookup embedding table of shape (dictionary_size, embedding_size)
            Entries with `'disabled': True` or with `in`/`out` equal to 0 are skipped.
            Unlike `ptls.nn.TrxEncoder`, values must be dicts here (a `torch.nn.Embedding` value is not supported).
        numeric_values:
            dict with numerical feature names.
            Values must be a string with scaler_name.
            Possible values are: 'identity', 'sigmoid', 'log', 'year'.
            These features will be scaled with selected scaler.
            Values can be `ptls.nn.trx_encoder.scalers.BaseScaler` implementation

            One field can have many scalers. In this case key become alias and col name should be in scaler.
            Check `TrxEncoderBase.numeric_values` for more details

        embeddings_noise (float):
            Noise level for embedding. `0` means without noise
        emb_dropout (float):
            Probability of an element of embedding to be zeroed
        spatial_dropout (bool):
            Whether to dropout full dimension of embedding in the whole sequence

        use_batch_norm:
            True - All numerical values will be normalized after scaling
            False - No normalizing for numerical values
        use_batch_norm_with_lens:
            True - Respect seq_lens during batch_norm. Padding zeroes will be ignored
            False - Batch norm over all time axis. Padding zeroes will be included.

        orthogonal_init:
            if True then `torch.nn.init.orthogonal` applied
        linear_projection_size:
            Linear layer at the end will be added for non-zero value

        out_of_index:
            How to process a categorical indexes which are greater than dictionary size.
            'clip' - values will be collapsed to maximum index. This works well for frequency encoded categories.
                We join infrequent categories to one.
            'assert' - raise an error of invalid index appear.

        norm_embeddings: keep default value for this parameter
        clip_replace_value: Not used. keep default value for this parameter
        positions: Not used. Keep default value for this parameter

        swin_params:
            dict of keyword arguments for `SwinTransformerV2Layer`.
            Default: `dict(num_heads=5, depth=3, dim=225, window_size=15)`.
            `dim` must equal the size of the embeddings fed to it, i.e. `output_size`.

    Examples:
        >>> B, T = 5, 20
        >>> trx_encoder = TrxEncoderSWIN(
        >>>     embeddings={'mcc_code': {'in': 100, 'out': 5}},
        >>>     numeric_values={'amount': 'log'},
        >>>     swin_params=dict(num_heads=2, depth=2, dim=6, window_size=4),
        >>> )
        >>> x = PaddedBatch(
        >>>     payload={
        >>>         'mcc_code': torch.randint(0, 99, (B, T)),
        >>>         'amount': torch.randn(B, T),
        >>>     },
        >>>     length=torch.randint(10, 20, (B,)),
        >>> )
        >>> z = trx_encoder(x)
        >>> assert z.payload.shape == (5, 20, 6)  # B, T, H
    """
    def __init__(self,
                 embeddings=None,
                 numeric_values=None,
                 custom_embeddings=None,
                 embeddings_noise: float = 0,
                 norm_embeddings=None,
                 use_batch_norm=False,
                 use_batch_norm_with_lens=False,
                 clip_replace_value=None,
                 positions=None,
                 emb_dropout=0,
                 spatial_dropout=False,
                 orthogonal_init=False,
                 linear_projection_size=0,
                 out_of_index: str = 'clip',
                 swin_params=None,
                 ):
        if clip_replace_value is not None:
            warnings.warn('`clip_replace_value` attribute is deprecated. Always "clip to max" used. '
                          'Use `out_of_index="assert"` to avoid categorical values clip', DeprecationWarning)

        if positions is not None:
            warnings.warn('`positions` is deprecated. positions is not used', UserWarning)

        if embeddings is None:
            embeddings = {}
        if custom_embeddings is None:
            custom_embeddings = {}

        noisy_embeddings = {}
        for emb_name, emb_props in embeddings.items():
            if emb_props.get('disabled', False):
                continue
            if emb_props['in'] == 0 or emb_props['out'] == 0:
                continue
            noisy_embeddings[emb_name] = NoisyEmbedding(
                num_embeddings=emb_props['in'],
                embedding_dim=emb_props['out'],
                padding_idx=0,
                max_norm=1 if norm_embeddings else None,
                noise_scale=embeddings_noise,
                dropout=emb_dropout,
                spatial_dropout=spatial_dropout,
            )

        super().__init__(
            embeddings=noisy_embeddings,
            numeric_values=numeric_values,
            custom_embeddings=custom_embeddings,
            out_of_index=out_of_index,
        )

        # Swin layer over the transaction sequence; the defaults reproduce the previous hard-coded config.
        if swin_params is None:
            swin_params = dict(num_heads=5, depth=3, dim=225, window_size=15)
        self.swin_encoder = SwinTransformerV2Layer(**swin_params)

        custom_embedding_size = self.custom_embedding_size
        if use_batch_norm and custom_embedding_size > 0:
            # :TODO: Should we use Batch norm with not-numerical custom embeddings?
            if use_batch_norm_with_lens:
                self.custom_embedding_batch_norm = RBatchNormWithLens(custom_embedding_size)
            else:
                self.custom_embedding_batch_norm = RBatchNorm(custom_embedding_size)
        else:
            self.custom_embedding_batch_norm = None

        if linear_projection_size > 0:
            self.linear_projection_head = torch.nn.Linear(super().output_size, linear_projection_size)
        else:
            self.linear_projection_head = None

        if orthogonal_init:
            for n, p in self.named_parameters():
                if n.startswith('embeddings.') and n.endswith('.weight'):
                    # row 0 is the padding index and is left untouched
                    torch.nn.init.orthogonal_(p.data[1:])
                if n == 'linear_projection_head.weight':
                    torch.nn.init.orthogonal_(p.data)

    def forward(self, x: PaddedBatch, names=None, seq_len=None):
        """Encode transactions and mix them with the Swin layer.

        Args:
            x: PaddedBatch with a feature dict, or an indexable sequence of feature tensors aligned
                with `names` (in which case `seq_len` gives the sequence lengths).
            names: feature names for the non-PaddedBatch form of `x`.
            seq_len: sequence lengths for the non-PaddedBatch form of `x`.

        Returns:
            PaddedBatch with payload (B, T, output_size) and the input seq_lens.
        """
        if not isinstance(x, PaddedBatch):
            x = PaddedBatch({field_name: x[i] for i, field_name in enumerate(names)}, seq_len)

        processed_embeddings = [self.get_category_embeddings(x, field_name)
                                for field_name in self.embeddings.keys()]
        processed_custom_embeddings = [self.get_custom_embeddings(x, field_name)
                                       for field_name in self.custom_embeddings.keys()]
        if len(processed_custom_embeddings):
            processed_custom_embeddings = torch.cat(processed_custom_embeddings, dim=2)
            if self.custom_embedding_batch_norm is not None:
                processed_custom_embeddings = PaddedBatch(processed_custom_embeddings, x.seq_lens)
                processed_custom_embeddings = self.custom_embedding_batch_norm(processed_custom_embeddings)
                processed_custom_embeddings = processed_custom_embeddings.payload
            processed_embeddings.append(processed_custom_embeddings)

        out = torch.cat(processed_embeddings, dim=2)
        out = out.type(TORCH_EMB_DTYPE)
        out = self.linear_projection_head(out) if self.linear_projection_head is not None else out

        seq_len_t = out.shape[1]
        out = self.swin_encoder(PaddedBatch(out, x.seq_lens))
        # The Swin blocks pad the time axis up to a multiple of window_size; crop back to T so the
        # output length matches the input, as documented above.
        return PaddedBatch(out.payload[:, :seq_len_t], out.seq_lens)

    @property
    def output_size(self):
        """Returns hidden size of output representation
        """
        if self.linear_projection_head is not None:
            return self.linear_projection_head.out_features
        return super().output_size

from ptls.nn.seq_encoder.rnn_encoder import RnnEncoder
from ptls.nn.seq_encoder.containers import SeqEncoderContainer


class SwinTransformerBackbone(nn.Module):
    """Stack of ``SwinTransformerLayer`` stages applied one after another.

    One stage is built per entry of ``depths``, so the number of stages is ``len(depths)``.
    Every stage keeps the input width ``dim``. Stage ``k`` uses window size
    ``start_window_size * window_size_mult ** k``.

    Args:
        dim (int): Number of input channels (shared by all stages).
        depths (list[int]): Number of blocks in each stage.
        num_heads (int): Number of attention heads in W-MSA layers (same for every stage).
        start_window_size (int): Local window size of the first stage.
        window_size_mult (int): Factor applied to the window size when moving to the next stage.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate, passed unchanged
            to every stage. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """
    def __init__(
        self,
        dim: int,
        depths: list[int],
        num_heads: int,
        start_window_size: int,
        window_size_mult: int = 1,
        mlp_ratio=4.,
        qkv_bias=True,
        qk_scale=None,
        drop=0.,
        attn_drop=0.,
        drop_path=0.,
        norm_layer=nn.LayerNorm
    ):
        super().__init__()
        self.dim = dim
        self.depths = depths
        self.num_heads = num_heads

        # Geometric window schedule; it always contains at least the first stage's size.
        window_sizes = [start_window_size]
        while len(window_sizes) < len(depths):
            window_sizes.append(window_sizes[-1] * window_size_mult)
        self.window_sizes = window_sizes

        # One stage per (depth, window size) pair, built in stage order.
        stages = []
        for stage_depth, stage_window in zip(self.depths, self.window_sizes):
            stages.append(
                SwinTransformerLayer(
                    dim=self.dim,
                    depth=stage_depth,
                    num_heads=self.num_heads,
                    window_size=stage_window,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop,
                    attn_drop=attn_drop,
                    drop_path=drop_path,
                    norm_layer=norm_layer,
                )
            )
        self.backbone = nn.ModuleList(stages)

    def forward(self, x):
        """Feed ``x`` through every stage in order and return the last stage's output."""
        for stage in self.backbone:
            x = stage(x)
        return x

class SWIN_RNN_SeqEncoder(SeqEncoderContainer):
    """SeqEncoderContainer: TrxEncoder -> SWIN backbone (hierarchical feature fusion) -> RnnEncoder (aggregation).

    Parameters
        trx_encoder:
            TrxEncoder object. Required in practice: its ``output_size`` sets the width
            of the SWIN backbone.
        input_size:
            input_size parameter for RnnEncoder
            If None: input_size = trx_encoder.output_size
            Set input_size explicitly or use None if your trx_encoder object has output_size attribute
        is_reduce_sequence:
            False - returns PaddedBatch with all transactions embeddings
            True - returns one embedding per sequence; the reduction is performed inside
            RnnEncoder (presumably its last hidden state - confirm against the ptls version in use)
        swin_depths: Numbers of blocks in stages (SWIN backbone), one stage per entry.
            None (default) or an empty list builds no SWIN stages, so the fusion step
            returns the trx embeddings unchanged.
        swin_num_heads: Number of attention heads in W-MSA layers (SWIN backbone).
        swin_start_window_size (int): Local window size of stage 1 (SWIN backbone).
        swin_window_size_mult (int): the number by which the `window_size` is being multiplied when moving to another stage (SWIN backbone).
        swin_drop: Dropout rate (SWIN backbone). Default: 0.0
        swin_attn_drop: Attention dropout rate (SWIN backbone). Default: 0.0
        swin_drop_path: Stochastic depth rate (SWIN backbone). Default: 0.0
        **rnn_seq_encoder_params:
            RnnEncoder params

    """
    def __init__(self,
                 trx_encoder=None,
                 input_size=None,
                 is_reduce_sequence=True,
                 swin_depths=None,
                 swin_num_heads=4,
                 swin_start_window_size=4,
                 swin_window_size_mult=1,
                 swin_drop=0.,
                 swin_attn_drop=0.,
                 swin_drop_path=0.,
                 **rnn_seq_encoder_params
                 ):
        super().__init__(
            trx_encoder=trx_encoder,
            seq_encoder_cls=RnnEncoder,
            input_size=input_size,
            seq_encoder_params=rnn_seq_encoder_params,
            is_reduce_sequence=is_reduce_sequence,
        )
        # A fresh list per instance. A shared mutable default would be aliased by every
        # backbone built with the default (it is stored as ``SwinTransformerBackbone.depths``).
        depths = [] if swin_depths is None else list(swin_depths)
        self.swin_fusion = SwinTransformerBackbone(
            dim=trx_encoder.output_size,
            depths=depths,
            num_heads=swin_num_heads,
            start_window_size=swin_start_window_size,
            window_size_mult=swin_window_size_mult,
            drop=swin_drop,
            attn_drop=swin_attn_drop,
            drop_path=swin_drop_path,
        )

    def forward(self, x, names=None, seq_len=None, h_0=None):
        """trx encoding -> SWIN fusion -> RNN; ``h_0`` is the optional initial RNN state.

        NOTE(review): ``names`` and ``seq_len`` are accepted but not forwarded to
        ``trx_encoder``, so ``x`` must already be something ``trx_encoder`` accepts on its own.
        """
        x = self.trx_encoder(x)
        x = self.swin_fusion(x)
        x = self.seq_encoder(x, h_0)
        return x

# class SWINTrxEncoder(nn.Module):
#     def __init__(self, trx_encoder):
#         super().__init__()
#         self.trx_encoder = trx_encoder
#         self.swin_encoder = nn.ModuleList([SwinTransformerV2Layer(
#             num_heads=4,
#             depth=4,
#             dim=216,
#             window_size=12 // i
#         ) for i in range (1, 4)])

#     def forward(self, x):
#         x = self.trx_encoder(x)
#         for layer in self.swin_encoder:
#             x = layer(x)
#         return x

#     @property
#     def output_size(self):
#         return self.trx_encoder.output_size

In [None]:
# Partially-applied optimizer / LR-scheduler constructors handed to the Lightning module.
optimizer_partial = partial(torch.optim.AdamW, lr=1e-3, weight_decay=1e-4)
lr_scheduler_partial = partial(torch.optim.lr_scheduler.StepLR, step_size=2, gamma=0.9025)

# Vocabulary size ("in") of every categorical field; each field is embedded into 28 dims.
CATEGORY_CARDINALITIES = {
    'event_type': 58,
    'event_subtype': 59,
    'src_type11': 85,
    'src_type12': 349,
    'dst_type11': 84,
    'dst_type12': 417,
    'src_type22': 90,
    'src_type32': 91,
}

trx_encoder = ptls.nn.TrxEncoder(
    norm_embeddings=False,
    embeddings_noise=0.003,
    embeddings={field: {"in": size, "out": 28} for field, size in CATEGORY_CARDINALITIES.items()},
    numeric_values={'amount': 'log'},
)

# GRU over the transaction embeddings.
seq_encoder = ptls.nn.RnnSeqEncoder(trx_encoder=trx_encoder, type='gru', hidden_size=256)

# validation_metric is not passed, so CoLESModule's own default is used.
pl_module = ptls.frames.coles.CoLESModule(
    seq_encoder=seq_encoder,
    optimizer_partial=optimizer_partial,
    lr_scheduler_partial=lr_scheduler_partial,
)

In [None]:
from lightning.pytorch.loggers import WandbLogger

wandb_logger = WandbLogger(project="MBD_My_Code", log_model="all")

# Train on the GPU when one is visible; validation is capped at 32 batches per epoch.
device_kind = "cuda" if torch.cuda.is_available() else "cpu"
trainer = pl.Trainer(
    accelerator=device_kind,
    max_epochs=20,
    logger=wandb_logger,
    gradient_clip_val=0.3,
    enable_progress_bar=True,
    limit_val_batches=32,
    log_every_n_steps=50,
)

In [None]:
# The datamodule is passed by keyword (Lightning also accepts it as the 2nd positional arg).
trainer.fit(model=pl_module, datamodule=data_module)

In [None]:
from ptls.data_load.iterable_processing.iterable_seq_len_limit import ISeqLenLimit
from ptls.data_load.iterable_processing.feature_filter import FeatureFilter
from ptls.data_load.iterable_processing.to_torch_tensor import ToTorch

# Columns kept next to the sequential features: the client id and the four targets.
INFERENCE_KEEP_COLUMNS = ['client_id', 'target_1', 'target_2', 'target_3', 'target_4']

# Supervised fold 4; ISeqLenLimit caps each sequence at 4096 events.
inference_dataset = ptls.data_load.datasets.ParquetDataset(
    data_files=[os.path.join(TRX_SUPERVISED_PATH, 'fold=4')],
    i_filters=[
        ISeqLenLimit(max_seq_len=4096),
        ToTorch(),
        FeatureFilter(keep_feature_names=INFERENCE_KEEP_COLUMNS),
        GetSplit(start_month=1, end_month=12, col_id='client_id'),
    ],
)

In [None]:
# NOTE(review): this loader is rebuilt further below with collate_feature_dict_with_target.
# Without a collate_fn it uses PyTorch's default collation, which presumably cannot
# stack variable-length sequences - kept only for parity with the original notebook.
inference_dl = DataLoader(
    inference_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=0,
)

In [None]:
import pandas as pd
import pytorch_lightning as pl
import torch
import numpy as np
from itertools import chain
from ptls.data_load.padded_batch import PaddedBatch
from datetime import datetime
from ptls.custom_layers import StatPooling
from ptls.nn.seq_step import LastStepEncoder


class InferenceModuleMultimodal(pl.LightningModule):
    """Lightning wrapper that turns a trained encoder into an embedding extractor for ``Trainer.predict``.

    Each batch must be ``(padded_batch, batch_ids)`` or ``(padded_batch, batch_ids, target_cols)``,
    as produced by ``collate_feature_dict_with_target``.

    Args:
        model: trained module. If it has a ``seq_encoder`` attribute (e.g. a CoLES module),
            that encoder is called; otherwise ``model`` itself is called.
        pandas_output: return a ``pandas.DataFrame`` instead of a dict.
        col_id: output key / column for the batch ids.
        target_col_names: names of the targets, in the order they were collated.
        model_out_name: output key of the embeddings; a 2-D output is expanded into
            columns ``{model_out_name}_0000``, ``{model_out_name}_0001``, ...
        model_type: stored, not used by this class.
    """
    def __init__(
        self,
        model,
        pandas_output=True,
        col_id='client_id',
        target_col_names=None,
        model_out_name='emb',
        model_type='notab'
    ):
        super().__init__()

        self.model = model
        self.pandas_output = pandas_output
        self.target_col_names = target_col_names
        self.col_id = col_id
        self.model_out_name = model_out_name
        self.model_type = model_type

        # Not used by forward(); kept for compatibility with code that reads these attributes.
        self.stat_pooler = StatPooling()
        self.last_step = LastStepEncoder()

        self.batch_cntr = 0

    def forward(self, x):
        """Encode one collated batch and attach ids (and targets when present)."""
        has_targets = len(x) == 3
        if has_targets:
            padded_batch, batch_ids, target_cols = x
        else:
            padded_batch, batch_ids = x

        encoder = self.model.seq_encoder if hasattr(self.model, 'seq_encoder') else self.model
        out = encoder(padded_batch)

        x_out = {
            self.col_id: batch_ids,
            self.model_out_name: out
        }
        if has_targets:
            targets = torch.tensor(target_cols)
            if targets.dim() == 1:
                # Flat sample-major layout (all targets of sample 0, then sample 1, ...):
                # restore it to (batch, n_targets) for any number of targets.
                targets = targets.reshape(len(batch_ids), len(self.target_col_names))
            for idx, target_col in enumerate(self.target_col_names):
                x_out[target_col] = targets[:, idx]

        if self.pandas_output:
            return self.to_pandas(x_out)
        return x_out

    @staticmethod
    def to_pandas(x):
        """Convert the output dict to a DataFrame (one row per sample).

        Lists and 1-D arrays become single columns. 2-D arrays are expanded into one column
        per feature (``{key}_0000``, ...), placed after the single columns. Values of any
        other rank become an all-None column. Tensors are detached and moved to CPU first.
        """
        scalar_features = {}
        expanded_frames = []

        for k, v in x.items():
            if isinstance(v, torch.Tensor):
                v = v.detach().cpu().numpy()

            if isinstance(v, list) or v.ndim == 1:
                scalar_features[k] = v
            elif v.ndim == 2:
                # Use the already-converted array so numpy inputs work too.
                expanded_frames.append(
                    pd.DataFrame(v, columns=[f'{k}_{i:04d}' for i in range(v.shape[1])])
                )
            else:
                scalar_features[k] = None

        return pd.concat([pd.DataFrame(scalar_features), *expanded_frames], axis=1)

def collate_feature_dict_with_target(batch, col_id='client_id', target_col_names=None):
    """Collate feature dicts into a padded batch after pulling out ids (and optionally targets).

    Args:
        batch: list of per-sample feature dicts. The dicts are NOT modified.
        col_id: key of the sample id. Ids are collected into ``batch_ids`` and the key is
            excluded from the collated features.
        target_col_names: optional target keys. For each sample, a list of its targets (in
            this order) is collected, and these keys are excluded from the collated features.

    Returns:
        ``(padded_batch, batch_ids, target_cols)`` when ``target_col_names`` is given,
        otherwise ``(padded_batch, batch_ids)``. ``target_cols`` holds one list per sample.

    Raises:
        KeyError: if a sample lacks ``col_id`` or one of ``target_col_names``.
    """
    target_names = list(target_col_names) if target_col_names is not None else []
    excluded = {col_id, *target_names}

    batch_ids = []
    target_cols = []
    features = []
    for sample in batch:
        batch_ids.append(sample[col_id])
        if target_col_names is not None:
            target_cols.append([sample[name] for name in target_names])
        # Filtered copy instead of ``del``: deleting keys would mutate the dataset's own
        # dicts whenever the dataset hands out stored objects (e.g. an in-memory dataset).
        features.append({k: v for k, v in sample.items() if k not in excluded})

    padded_batch = collate_feature_dict(features)
    if target_col_names is not None:
        return padded_batch, batch_ids, target_cols
    return padded_batch, batch_ids

In [None]:
from torch.utils.data import Subset

target_col_names = [
    'target_1',
    'target_2',
    'target_3',
    'target_4'
]

# Collate that also returns the client ids and, per sample, the four target values.
collate_fn = partial(
    collate_feature_dict_with_target,
    target_col_names=target_col_names
)

# shuffle=False keeps the rows of the prediction output in dataset order.
inference_dl = DataLoader(
    dataset=inference_dataset,
    collate_fn=collate_fn,
    shuffle=False,
    num_workers=0,
    batch_size=32
)

In [None]:
# Wrap the trained CoLES module for Trainer.predict; each batch comes back as a DataFrame.
inf_module = InferenceModuleMultimodal(
    model=pl_module,
    col_id='client_id',
    target_col_names=target_col_names,
    pandas_output=True,
)

In [None]:
# One DataFrame per batch, each indexed 0..batch_size-1; ignore_index gives a unique row index.
inf_embeddings = pd.concat(trainer.predict(inf_module, inference_dl), ignore_index=True)

In [None]:
# Peek at the first rows: one row per sample with its id, targets and embedding columns.
inf_embeddings.head()

In [None]:
from sklearn.model_selection import train_test_split

# Fixed seed so the reported downstream ROC-AUC values are reproducible across runs.
dwns_train, dwns_test = train_test_split(inf_embeddings, test_size=0.2, random_state=42)

In [None]:
# Target matrices (n_rows, 4), column k = DOWNSTREAM_TARGETS[k], and the feature matrices
# (everything except the id and the targets).
DOWNSTREAM_TARGETS = ['target_1', 'target_2', 'target_3', 'target_4']

targets_train = np.array([dwns_train[col].to_numpy() for col in DOWNSTREAM_TARGETS]).T
targets_test = np.array([dwns_test[col].to_numpy() for col in DOWNSTREAM_TARGETS]).T

dwns_train = dwns_train.drop(columns=['client_id', *DOWNSTREAM_TARGETS]).to_numpy()
dwns_test = dwns_test.drop(columns=['client_id', *DOWNSTREAM_TARGETS]).to_numpy()

In [None]:
from lightgbm import LGBMClassifier

# Shared hyper-parameters; one independent classifier is trained per downstream target.
LGBM_PARAMS = dict(
    n_estimators=500,
    boosting_type='gbdt',
    learning_rate=0.02,
    max_depth=6,
    subsample=0.5,
    subsample_freq=1,
    feature_fraction=0.75,
    lambda_l1=1,
    lambda_l2=1,
    min_data_in_leaf=50,
    random_state=42,
    n_jobs=8,
    verbose=-1,
)

models = [LGBMClassifier(**LGBM_PARAMS) for _ in range(4)]

In [None]:
# Classifier k is fitted on column k of the target matrix.
for target_id, target_model in enumerate(models):
    target_model.fit(dwns_train, targets_train[:, target_id])

In [None]:
from sklearn.metrics import classification_report
from sklearn.metrics import roc_auc_score

# ROC-AUC of each per-target classifier on the held-out split. Column i of targets_test
# holds target_{i + 1}, so the printed label is 1-based to match the dataset column names.
for i in range(len(models)):
    preds = models[i].predict_proba(dwns_test)
    print(f"ROC-AUC target_{i + 1} = {roc_auc_score(targets_test[:, i], preds[:, 1])}")

# Патчи (до лучших времён) — patch-based experiments, shelved for now

**Check how events are distributed over the days of the week**

In [None]:
import matplotlib.pyplot as plt


from itertools import islice

N_PLOTTED_SEQUENCES = 30

weekday_names = [
    'Понедельник', 'Вторник', 'Среда',
    'Четверг', 'Пятница', 'Суббота', 'Воскресенье'
]

# One bar chart per item of `train`: number of events on each day of the week
# (event_time is interpreted as Unix seconds).
for item in islice(iter(train), N_PLOTTED_SEQUENCES):
    dates = pd.to_datetime(item['event_time'].numpy(), unit='s')
    days_counts = (
        pd.Series(dates.dayofweek)
        .value_counts()
        .reindex(range(7), fill_value=0)  # Monday=0 ... Sunday=6, missing days -> 0
    )

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.bar(weekday_names, days_counts, color='skyblue')
    ax.set_title('Распределение меток по дням недели', fontsize=14)
    ax.set_xlabel('День недели', fontsize=12)
    ax.set_ylabel('Количество меток', fontsize=12)
    ax.tick_params(axis='x', labelrotation=45)
    ax.grid(axis='y', linestyle='--', alpha=0.7)
    plt.show()
    plt.close(fig)  # avoid keeping 30 open figures in memory
    

In [None]:
# Scratch check: appending a value returns a new array with it at the end.
a = np.concatenate([np.arange(5), [7]])
a

In [None]:
# Scratch check of the patch-splitting idea: cut 15 items into chunks of 5 using boundary ids.
a = np.arange(15)
print(len(a))
split_ids = np.r_[np.arange(0, 15, 5), 16]
[a[start:stop] for start, stop in zip(split_ids[:-1], split_ids[1:])]

In [None]:
from functools import reduce
from operator import iadd
from ptls.data_load.utils import collate_feature_dict
import joblib
from joblib import Parallel, delayed
from ptls.data_load.feature_dict import FeatureDict


class PatchesSplitter:
    """Split a sequence into consecutive, non-overlapping patches of fixed length.

    Args:
        patch_size: events per patch (positive integer). The last patch holds the
            remainder and may be shorter.

    Raises:
        ValueError: if ``patch_size`` is not a positive integer.
    """
    def __init__(self, patch_size):
        if isinstance(patch_size, bool) or not isinstance(patch_size, (int, np.integer)) or patch_size <= 0:
            raise ValueError(f"patch_size must be a positive integer, got {patch_size!r}")
        self.patch_size = patch_size

    def split(self, dates):
        """Return index arrays of consecutive patches covering ``range(dates.shape[0])``.

        Only the length of ``dates`` is used; an empty input yields an empty list.
        """
        n_events = dates.shape[0]
        positions = np.arange(n_events)
        return [
            positions[start:start + self.patch_size]
            for start in range(0, n_events, self.patch_size)
        ]

class PatchedDataset(FeatureDict, torch.utils.data.Dataset):
    """Dataset that pairs every sequence with its list of fixed-size patches.

    Args:
        data: indexable and iterable source of feature dicts (one per sequence).
        splitter: object with ``split(values) -> list of index arrays`` (e.g. PatchesSplitter).
        col_time: key of the feature handed to ``splitter.split``.
        n_jobs: number of joblib threads used to build the patches.
    """
    def __init__(self,
                 data,
                 splitter,
                 col_time='event_time',
                 n_jobs=1):
        self.data = data
        self.col_time = col_time
        self.splitter = splitter
        self.n_jobs = n_jobs

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx: int):
        # NOTE(review): returns only the patches, whereas __iter__ yields (full, patches)
        # pairs, which is the form collate_fn expects; map-style use with collate_fn would fail.
        feature_arrays = self.data[idx]
        return self.get_splits(feature_arrays)

    def __iter__(self):
        for feature_arrays in self.data:
            yield feature_arrays, self.get_splits(feature_arrays)

    def _create_split_subset(self, idx, feature_arrays):
        """Index every sequential feature with ``idx``; non-sequential features are dropped."""
        return {k: v[idx] for k, v in feature_arrays.items() if self.is_seq_feature(k, v)}

    def get_splits(self, feature_arrays):
        """Return one feature dict per patch that ``self.splitter`` cuts from ``col_time``."""
        indexes = self.splitter.split(feature_arrays[self.col_time])
        with joblib.parallel_backend(backend='threading', n_jobs=self.n_jobs):
            parallel = Parallel()
            result_dict = parallel(delayed(self._create_split_subset)(idx, feature_arrays) for idx in indexes)

        return result_dict

    @staticmethod
    def collate_fn(batch):
        """Collate a list of ``(full, patches)`` pairs.

        Returns:
            full_padded_batch: padded batch of the B full sequences.
            full_class_labels: LongTensor ``[0, 1, ..., B-1]`` (position of each sequence).
            patched_padded_batch: padded batch of all patches, sample after sample.
            patched_class_labels: LongTensor giving, for every patch, its sample's position.
        """
        full = [pair[0] for pair in batch]
        patched = [pair[1] for pair in batch]

        # Flatten with comprehensions instead of reduce(iadd, ...): iadd extends the first
        # sub-list in place, and reduce raises on an empty batch.
        patched_class_labels = torch.tensor(
            [sample_pos for sample_pos, patches in enumerate(patched) for _ in patches],
            dtype=torch.long,
        )
        patched_padded_batch = collate_feature_dict([patch for patches in patched for patch in patches])

        full_class_labels = torch.arange(len(full), dtype=torch.long)
        full_padded_batch = collate_feature_dict(full)
        return full_padded_batch, full_class_labels, patched_padded_batch, patched_class_labels

class PatchedIterableDataset(PatchedDataset, torch.utils.data.IterableDataset):
    """Iterable flavour of PatchedDataset.

    As an IterableDataset it is consumed by DataLoader through ``__iter__``, which yields
    ``(full, patches)`` pairs in the form ``PatchedDataset.collate_fn`` expects.
    """

In [None]:
# Both splits are cut into patches of 80 events.
PATCH_SIZE = 80

data_module = ptls.frames.PtlsDataModule(
    train_data=PatchedIterableDataset(splitter=PatchesSplitter(patch_size=PATCH_SIZE), data=train),
    valid_data=PatchedIterableDataset(splitter=PatchesSplitter(patch_size=PATCH_SIZE), data=valid),
    train_batch_size=3,
    train_num_workers=0,
    valid_batch_size=32,
    valid_num_workers=0,
)

In [None]:
# Debug peek: size of patched_class_labels (4th element of a collated batch) for 12-event patches.
debug_dl = data_module.train_dl(
    train_data=PatchedIterableDataset(splitter=PatchesSplitter(patch_size=12), data=train)
)
next(iter(debug_dl))[3].size()

In [None]:
import torch.nn as nn
from ptls.frames.coles.losses import ContrastiveLoss
from ptls.frames.coles.metric import BatchRecallTopK
from ptls.frames.coles.sampling_strategies import HardNegativePairSelector
from ptls.nn.head import Head

class PBLinear(torch.nn.Linear):
    """``torch.nn.Linear`` applied to a PaddedBatch payload; sequence lengths are carried over."""
    def forward(self, x: PaddedBatch):
        projected = super().forward(x.payload)
        return PaddedBatch(projected, x.seq_lens)


# class RSA(nn.Module):
#     def __init__(self, dim, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):

#         super().__init__()
#         self.dim = dim
#         self.num_heads = num_heads
#         head_dim = dim // num_heads
#         self.scale = qk_scale or head_dim ** -0.5

#         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
#         self.attn_drop = nn.Dropout(attn_drop)
#         self.proj = nn.Linear(dim, dim)
#         self.proj_drop = nn.Dropout(proj_drop)

#         self.softmax = nn.Softmax(dim=-1)
        
#     def forward(self, x):
#         B_, N, C = x.shape
#         qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
#         q, k, v = qkv[0], qkv[1], qkv[2]

#         q = q * self.scale
#         attn = (q @ k.transpose(-2, -1))

#         attn = self.softmax(attn)

#         attn = self.attn_drop(attn)

#         x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
#         x = self.proj(x)
#         x = self.proj_drop(x)
#         return x

class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding added to batch-first inputs of shape (batch, time, d_model).

    Args:
        d_model: feature size of the input; odd sizes are supported.
        use_start_random_shift: in training mode, start the encoding at a random offset
            drawn from ``[0, max_len - T]`` instead of position 0.
        max_len: number of precomputed positions; longer inputs raise ``ValueError``.
    """
    def __init__(self,
                 d_model,
                 use_start_random_shift=True,
                 max_len=5000,
                 ):
        super().__init__()
        self.use_start_random_shift = use_start_random_shift
        self.max_len = max_len

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer odd column than there are frequencies.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encoding of ``x.size(1)`` consecutive positions."""
        T = x.size(1)
        if T > self.max_len:
            raise ValueError(f"sequence length {T} exceeds max_len={self.max_len}")
        if self.training and self.use_start_random_shift:
            start_pos = random.randint(0, self.max_len - T)
        else:
            start_pos = 0
        return x + self.pe[:, start_pos:start_pos + T]

class RegionalModule(pl.LightningModule):
    """Experimental CoLES-style module that cross-attends full sequences to their patches.

    Batches come from ``PatchedDataset.collate_fn``: ``(x_full, full_labels, x_patches,
    patches_labels)``, where ``patches_labels[j]`` is the batch position of patch ``j``'s sequence.

    Args:
        num_layers: how many times the two transformer encoder layers are applied; the same
            layer weights (and one shared LayerNorm) are reused on every pass.
        trx_encoder: transaction encoder returning a PaddedBatch, with an ``output_size``.
        head: applied to the forward output; defaults to ``Head(use_norm_encoder=True)``.
        loss: defaults to ``ContrastiveLoss(margin=0.5, sampling_strategy=HardNegativePairSelector(neg_count=3))``.
        validation_metric: defaults to ``BatchRecallTopK(K=4, metric='cosine')``.
        optimizer_partial: callable ``params -> optimizer``.
        lr_scheduler_partial: callable ``optimizer -> scheduler``.
        hidden_size: width of the projections, encoder layers and cross-attention (default 256).
    """
    def __init__(self,
                 num_layers,
                 trx_encoder,
                 head=None,
                 loss=None,
                 validation_metric=None,
                 optimizer_partial=None,
                 lr_scheduler_partial=None,
                 hidden_size=256):
        super().__init__()
        self.save_hyperparameters()
        self.trx_encoder = trx_encoder
        self.num_layers = num_layers
        self.hidden_size = hidden_size

        self.optimizer_partial = optimizer_partial
        self.lr_scheduler_partial = lr_scheduler_partial

        # Separate projections of the trx embeddings for full sequences and for patches.
        self.pb_linear_full = PBLinear(trx_encoder.output_size, hidden_size)
        self.pb_linear_patched = PBLinear(trx_encoder.output_size, hidden_size)

        self.enc_layer_patched = torch.nn.TransformerEncoderLayer(
            d_model=hidden_size,
            nhead=4,
            dim_feedforward=hidden_size,
            dropout=0.2,
            batch_first=True
        )

        self.enc_layer_full = torch.nn.TransformerEncoderLayer(
            d_model=hidden_size,
            nhead=4,
            dim_feedforward=hidden_size,
            dropout=0.2,
            batch_first=True
        )

        self.pe = PositionalEncoding(
            use_start_random_shift=True,
            max_len=5000,
            d_model=hidden_size,
        )

        self.cross_attn = nn.MultiheadAttention(
            embed_dim=hidden_size,
            num_heads=8,
            dropout=0.1,
            batch_first=True
        )

        # One LayerNorm shared by both branches.
        self.enc_norm = nn.LayerNorm(hidden_size)
        # Use the supplied components, falling back to the defaults listed in the docstring.
        self.loss = loss if loss is not None else ContrastiveLoss(
            margin=0.5, sampling_strategy=HardNegativePairSelector(neg_count=3))
        self.validation_metric = (validation_metric if validation_metric is not None
                                  else BatchRecallTopK(K=4, metric='cosine'))
        # Keyword is required: Head's first positional parameter is input_size.
        self.head = head if head is not None else Head(use_norm_encoder=True)

    @property
    def metric_name(self):
        """Validation metric name; it is logged as ``valid/<metric_name>``."""
        return 'recall_top_k'

    def forward(self, batch):
        """Encode patches and full sequences, then cross-attend each full sequence to each patch.

        Returns:
            Tensor (num_patches, T_full, hidden_size): for every patch, its sequence's full
            embedding (queries) attending to that patch (keys/values).
        """
        x_full, _full_labels, x_patches, patches_labels = batch

        patches_embeds = self.trx_encoder(x_patches)
        patches_embeds = self.pb_linear_patched(patches_embeds)
        patches_embeds = self.pe(patches_embeds.payload)

        full_embeds = self.trx_encoder(x_full)
        full_embeds = self.pb_linear_full(full_embeds)
        full_embeds = self.pe(full_embeds.payload)
        # Repeat each sequence's full embedding once per patch of that sequence.
        full_embeds = full_embeds[patches_labels]

        for _ in range(self.num_layers):
            patches_embeds = self.enc_norm(self.enc_layer_patched(patches_embeds))
            full_embeds = self.enc_norm(self.enc_layer_full(full_embeds))

        # NOTE(review): no key padding masks are passed, so padded positions take part in attention.
        # See _cross_attend_per_client for a per-sequence aggregation variant.
        out, _ = self.cross_attn(full_embeds, patches_embeds, patches_embeds)
        return out

    def _cross_attend_per_client(self, full_embeds, patches_embeds, patches_labels):
        """Per-sequence variant: each full sequence attends to all of its patches at once.

        Args:
            full_embeds: (num_sequences, T_full, hidden_size), one row per sequence
                (NOT repeated per patch).
            patches_embeds: (num_patches, T_patch, hidden_size), grouped sequence by sequence
                in label order (as ``PatchedDataset.collate_fn`` builds them).
            patches_labels: (num_patches,) sequence position of every patch.

        Returns:
            (num_sequences, hidden_size): time-averaged cross-attention output per sequence.
        """
        counts = torch.bincount(patches_labels, minlength=full_embeds.size(0)).tolist()
        outs = []
        for seq_pos, seq_patches in enumerate(torch.split(patches_embeds, counts, dim=0)):
            keys = seq_patches.reshape(1, -1, self.hidden_size)
            query = full_embeds[seq_pos].reshape(1, -1, self.hidden_size)
            attended, _ = self.cross_attn(query, keys, keys)
            outs.append(attended.mean(dim=1)[0])
        return torch.stack(outs)

    def shared_step(self, batch):
        """Return ``(head(forward(batch)), patches_labels)``.

        NOTE(review): forward returns a 3-D tensor; the contrastive loss / recall metric
        presumably expect (N, dim) embeddings, so pooling over time is likely still needed.
        """
        out = self(batch)
        y_h = self.head(out) if self.head is not None else out
        y = batch[-1]
        return y_h, y

    def training_step(self, batch, _):
        y_h, y = self.shared_step(batch)
        loss = self.loss(y_h, y)
        self.log('loss', loss)
        return loss

    def validation_step(self, batch, _):
        y_h, y = self.shared_step(batch)
        self.validation_metric(y_h, y)

    def on_validation_epoch_end(self):
        self.log(f'valid/{self.metric_name}', self.validation_metric.compute(), prog_bar=True)
        self.validation_metric.reset()

    def configure_optimizers(self):
        optimizer = self.optimizer_partial(self.parameters())
        scheduler = self.lr_scheduler_partial(optimizer)

        if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            scheduler = {
                'scheduler': scheduler,
                # Must match the key logged in on_validation_epoch_end.
                'monitor': f'valid/{self.metric_name}',
            }
        return [optimizer], [scheduler]
    

In [None]:
import math
import random

# Vocabulary size ("in") of every categorical field; each field is embedded into 24 dims here.
CATEGORY_CARDINALITIES = {
    'event_type': 58,
    'event_subtype': 59,
    'src_type11': 85,
    'src_type12': 349,
    'dst_type11': 84,
    'dst_type12': 417,
    'src_type22': 90,
    'src_type32': 91,
}

trx_encoder = ptls.nn.TrxEncoder(
    norm_embeddings=False,
    embeddings_noise=0.003,
    embeddings={field: {"in": size, "out": 24} for field, size in CATEGORY_CARDINALITIES.items()},
    numeric_values={'amount': 'log'},
)

optimizer_partial = partial(torch.optim.AdamW, lr=0.001, weight_decay=1e-4)
# StepLR with step_size=1: the learning rate is multiplied by 0.9 at every scheduler step.
lr_scheduler_partial = partial(torch.optim.lr_scheduler.StepLR, step_size=1, gamma=0.9)

# Three passes through RegionalModule's weight-shared encoder layers.
module = RegionalModule(
    num_layers=3,
    trx_encoder=trx_encoder,
    optimizer_partial=optimizer_partial,
    lr_scheduler_partial=lr_scheduler_partial,
)

In [None]:
# batch = next(iter(data_module.train_dl(train_data=PatchedIterableDataset(
#         splitter=PatchesSplitter(
#             patch_size=12
#         ),
#         data=train
#     ))))

# module.forward(batch)

In [None]:
from lightning.pytorch.loggers import WandbLogger

wandb_logger = WandbLogger(project="MBD_My_Code", log_model="all")

# Train on the GPU when one is visible; validation is capped at 32 batches per epoch.
device_kind = "cuda" if torch.cuda.is_available() else "cpu"
trainer = pl.Trainer(
    accelerator=device_kind,
    max_epochs=50,
    logger=wandb_logger,
    gradient_clip_val=0.5,
    enable_progress_bar=True,
    limit_val_batches=32,
    log_every_n_steps=50,
)

In [None]:
# The datamodule is passed by keyword (Lightning also accepts it as the 2nd positional arg).
trainer.fit(model=module, datamodule=data_module)

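# Results

Downstream ROC-AUC per target, as printed by the LightGBM evaluation cell. In the numbers below, `target_0` … `target_3` are 0-based indices and correspond to the dataset columns `target_1` … `target_4`.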

**GPT Baseline (RPE)**

ROC-AUC target_0 = 0.721350441667294

ROC-AUC target_1 = 0.8153566994158479

ROC-AUC target_2 = 0.7543318213719505

ROC-AUC target_3 = 0.7678734355960962

**GPT + SWIN encoder (mean agg)**

ROC-AUC target_0 = 0.6516282035967174

ROC-AUC target_1 = 0.6540494938132733

ROC-AUC target_2 = 0.5948775519632066

ROC-AUC target_3 = 0.745110173028166

**CoLES + SWIN + mean**

ROC-AUC target_0 = 0.7036866820699578

ROC-AUC target_1 = 0.6407223242752476

ROC-AUC target_2 = 0.6655641307690068

ROC-AUC target_3 = 0.7021158878362178


**CoLES + SWIN + RNN**


ROC-AUC target_0 = 0.6988762517846345

ROC-AUC target_1 = 0.8023436653648343

ROC-AUC target_2 = 0.6582610542989475

ROC-AUC target_3 = 0.7862126496756947