In [None]:
# Install the third-party packages this notebook imports below.
# NOTE(review): versions are not pinned, so a fresh run may pick up different releases — consider pinning.
!pip install lightning
!pip install pyspark
!pip install pytorch-lifestream

In [None]:
import os

import wandb

# Credentials must not live in the notebook: the API key is read from the
# WANDB_API_KEY environment variable (populate it from your secret store).
# When the variable is unset, key=None lets wandb fall back to its own
# credential lookup / interactive prompt.
# NOTE(review): the key that used to be hard-coded here has been exposed and should be revoked.
wandb.login(key=os.environ.get("WANDB_API_KEY"))

In [None]:
from huggingface_hub import hf_hub_download

# Both archives come from the same dataset repo and land in the same directory;
# only the file name differs between the two calls.
_hf_download_kwargs = dict(
    repo_id="ai-lab/MBD-mini",
    repo_type="dataset",
    local_dir="/kaggle/working/",
)

hf_hub_download(filename="ptls.tar.gz", **_hf_download_kwargs)
hf_hub_download(filename="targets.tar.gz", **_hf_download_kwargs)

In [None]:
# Unpack both downloaded archives into the current working directory.
# NOTE(review): the archives were downloaded to /kaggle/working/, so this assumes
# the notebook's working directory is that same folder — confirm.
!tar -xf ptls.tar.gz
!tar -xf targets.tar.gz

In [None]:
import pandas as pd
import numpy as np
import os

import pyspark
from pyspark.sql import SparkSession
# import pyspark.sql.functions as F
from pyspark.sql import types as T
import time
import datetime
from ptls.data_load.datasets import ParquetDataset, ParquetFiles
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, FloatType, ArrayType
from tqdm.notebook import tqdm
from ptls.preprocessing import PysparkDataPreprocessor
import pytorch_lightning as pl
from ptls.data_load.datasets import MemoryMapDataset
from ptls.data_load.iterable_processing import SeqLenFilter, FeatureFilter
from ptls.data_load.iterable_processing.iterable_seq_len_limit import ISeqLenLimit
from ptls.data_load.iterable_processing.to_torch_tensor import ToTorch
from ptls.frames.coles import CoLESModule
from ptls.frames import PtlsDataModule
from ptls.frames.coles import ColesDataset
from ptls.frames.coles.split_strategy import SampleSlices
import torch
import numpy as np
import pandas as pd
import calendar
from glob import glob
from ptls.data_load.utils import collate_feature_dict

from ptls.data_load.iterable_processing_dataset import IterableProcessingDataset
from datetime import datetime
from ptls.data_load.padded_batch import PaddedBatch

In [None]:
SEED = 0

# Seed NumPy's global RNG, PyTorch's default generator and every CUDA device
# with the same value, in that order.
for _seed_fn in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(SEED)

# Prepare data

In [None]:
# Local Spark session configuration; the settings are applied in the listed order.
spark_conf = (
    pyspark.SparkConf()
    .setMaster("local[*]")
    .setAppName("JoinModality")
    .setAll([
        ("spark.driver.maxResultSize", "16g"),
        ("spark.executor.memory", "32g"),
        ("spark.executor.memoryOverhead", "16g"),
        ("spark.driver.memory", "32g"),
        ("spark.driver.memoryOverhead", "16g"),
        ("spark.cores.max", "8"),
        ("spark.sql.shuffle.partitions", "200"),
        ("spark.local.dir", "../../spark_local_dir"),
    ])
)

spark = SparkSession.builder.config(conf=spark_conf).getOrCreate()
# Display the effective configuration as the cell output.
spark.sparkContext.getConf().getAll()

In [None]:
!mkdir /kaggle/working/ptls/trx_supervised

In [None]:
# Input folders (extracted archives) and the output folder for the joined folds.
TARGETS_DATA_PATH = '/kaggle/working/targets/'
TRX_DATA_PATH = '/kaggle/working/ptls/trx/'
TRX_SUPERVISED_PATH = '/kaggle/working/ptls/trx_supervised/'

# Preprocessor applied to the target tables: keyed by client, timed by `mon`,
# with the four target columns passed through as identity columns.
_target_preprocessor_kwargs = dict(
    col_id="client_id",
    col_event_time="mon",
    event_time_transformation="dt_to_timestamp",
    cols_identity=["target_1", "target_2", "target_3", "target_4"],
)
preprocessor_target = PysparkDataPreprocessor(**_target_preprocessor_kwargs)

In [None]:
# Columns removed from the preprocessed targets before they are joined to the transactions.
_unused_target_cols = ('event_time', 'trans_count', 'diff_trans_date')

for fold in range(5):
    fold_dir = f'fold={fold}'
    targets = spark.read.parquet(os.path.join(TARGETS_DATA_PATH, fold_dir))
    trx = spark.read.parquet(os.path.join(TRX_DATA_PATH, fold_dir))

    targets = preprocessor_target.fit_transform(targets).drop(*_unused_target_cols)
    # Left join: every transaction record is kept even when no target row matches.
    trx = trx.join(targets, on='client_id', how='left')
    trx.write.mode('overwrite').parquet(os.path.join(TRX_SUPERVISED_PATH, fold_dir))

In [None]:
# Shut the Spark session down; the joined folds were already written to parquet above.
spark.stop()

In [None]:
import pandas as pd
import numpy as np
from ptls.data_load.iterable_processing_dataset import IterableProcessingDataset
from ptls.data_load import IterableChain
from datetime import datetime
from ptls.data_load.datasets.parquet_dataset import ParquetDataset, ParquetFiles
from ptls.data_load.iterable_processing.feature_filter import FeatureFilter
from ptls.data_load.iterable_processing.to_torch_tensor import ToTorch
import torch
from functools import partial
from torch.utils.data import DataLoader
from ptls.data_load.padded_batch import PaddedBatch
from ptls.data_load.utils import collate_feature_dict
from tqdm import tqdm
import ptls

class TargetToTorch(IterableProcessingDataset):
    """Replace one column of every record with a torch tensor built from its values."""

    def __init__(self, col_target):
        super().__init__()
        self.col_target = col_target

    def __iter__(self):
        for rec in self._src:
            # A record may be a tuple whose first element is the features dict.
            features = rec[0] if type(rec) is tuple else rec
            stacked = np.stack(np.array(features[self.col_target]))
            # The record is updated in place (no copy is taken).
            features[self.col_target] = torch.tensor(stacked)
            yield features

class DeleteNan(IterableProcessingDataset):
    """Pass through only the records whose ``col_name`` value is not ``None``."""

    def __init__(self, col_name):
        super().__init__()
        self.col_name = col_name

    def __iter__(self):
        for rec in self._src:
            # A record may be a tuple whose first element is the features dict.
            features = rec[0] if type(rec) is tuple else rec
            if features[self.col_name] is None:
                continue
            yield features


class DialToTorch(IterableProcessingDataset):
    """Fill missing time / embedding fields and rebuild the embedding column as a tensor.

    A ``None`` time becomes ``torch.tensor([0])`` and a ``None`` embedding becomes
    ``torch.zeros(768)``; the embedding column is then rebuilt with ``torch.tensor``.
    """

    def __init__(self, col_time, col_embeds):
        super().__init__()
        self._year = 2022
        self.col_embeds = col_embeds
        self.col_time = col_time

    def __iter__(self):
        for rec in self._src:
            # Work on a shallow copy so the source record's keys are not rebound.
            features = (rec[0] if type(rec) is tuple else rec).copy()

            if features[self.col_time] is None:
                features[self.col_time] = torch.tensor([0])
            if features[self.col_embeds] is None:
                features[self.col_embeds] = torch.zeros(768)

            # Round-trip through a Python list so the column is a freshly built tensor.
            embeds = features[self.col_embeds]
            features[self.col_embeds] = torch.tensor(embeds.tolist())

            yield features

class GetSplit(IterableProcessingDataset):
    """Expand every record into one record per month of ``year``.

    For each month in ``[start_month, end_month]`` a shallow copy of the record is
    yielded in which:

    * every key starting with ``'target'`` is replaced by its ``month - 1``-th
      entry, converted with ``tolist()``;
    * every other key except ``col_id`` is indexed with a mask selecting the
      events whose ``col_time`` is strictly before the first day of the
      following month;
    * ``col_id`` gets the suffix ``'_month=<month>'``.

    Args:
        start_month: first month (1-based) to emit.
        end_month: last month (inclusive) to emit.
        year: calendar year used to compute the monthly cut-offs.
        col_id: name of the id field (must support ``+=`` with a string).
        col_time: name of the event-time field compared against the cut-off.
    """

    def __init__(
        self,
        start_month,
        end_month,
        year=2022,
        col_id='client_id',
        col_time='event_time'
    ):
        super().__init__()
        self.start_month = start_month
        self.end_month = end_month
        self._year = year
        self._col_id = col_id
        self._col_time = col_time

    def _month_cutoff(self, month):
        """Return the timestamp of the first day of the month that follows ``month``."""
        # NOTE(review): a naive datetime is interpreted in the machine's local
        # timezone by timestamp(); confirm this matches how event_time was produced.
        if month == 12:
            return datetime(self._year + 1, 1, 1).timestamp()
        return datetime(self._year, month + 1, 1).timestamp()

    def __iter__(self):
        for rec in self._src:
            # Unwrap once per record instead of once per month.
            source = rec[0] if type(rec) is tuple else rec

            for month in range(self.start_month, self.end_month + 1):
                features = source.copy()

                # The mask is computed from the unmasked time column before any key is replaced.
                mask = features[self._col_time] < self._month_cutoff(month)
                for key, tensor in features.items():
                    if key.startswith('target'):
                        features[key] = tensor[month - 1].tolist()
                    elif key != self._col_id:
                        features[key] = tensor[mask]

                features[self._col_id] += '_month=' + str(month)

                yield features

# Baseline

**Load data**

In [None]:
def _coles_pretrain_filters():
    """Return a fresh filter chain so train and valid never share filter instances.

    Chain: SeqLenFilter(min_seq_len=16), SeqLenFilter(max_seq_len=2048), a
    FeatureFilter dropping the id and target columns, then ToTorch.
    """
    return [
        ptls.data_load.iterable_processing.SeqLenFilter(min_seq_len=16),
        ptls.data_load.iterable_processing.SeqLenFilter(max_seq_len=2048),
        ptls.data_load.iterable_processing.feature_filter.FeatureFilter(
            drop_feature_names=['client_id', 'target_1', 'target_2', 'target_3', 'target_4']
        ),
        ptls.data_load.iterable_processing.to_torch_tensor.ToTorch(),
    ]


# Training data: folds 0-2 of the joined (transactions + targets) dataset.
train = ptls.data_load.datasets.ParquetDataset(
    data_files=[os.path.join(TRX_SUPERVISED_PATH, f'fold={fold}') for fold in (0, 1, 2)],
    i_filters=_coles_pretrain_filters(),
    shuffle_files=True,
)

# NOTE(review): validation reads fold 3 from TRX_DATA_PATH, whereas training reads
# from TRX_SUPERVISED_PATH — confirm the difference is intentional.
valid = ptls.data_load.datasets.ParquetDataset(
    data_files=[os.path.join(TRX_DATA_PATH, 'fold=3')],
    i_filters=_coles_pretrain_filters(),
)

In [None]:
# next(iter(train))

**Create data module**

In [None]:
def _coles_iterable(data):
    """Wrap ``data`` in a ColesIterableDataset with SampleSlices(split_count=3, cnt_min=16, cnt_max=180)."""
    slicer = ptls.frames.coles.split_strategy.SampleSlices(
        split_count=3,
        cnt_min=16,
        cnt_max=180,
    )
    return ptls.frames.coles.ColesIterableDataset(splitter=slicer, data=data)


# Train and validation use the same splitting strategy; only batch sizes differ.
data_module = ptls.frames.PtlsDataModule(
    train_data=_coles_iterable(train),
    valid_data=_coles_iterable(valid),
    train_batch_size=64,
    train_num_workers=0,
    valid_batch_size=32,
    valid_num_workers=0,
)

**Baseline CoLES module**

In [None]:
optimizer_partial = partial(torch.optim.AdamW, lr=0.001, weight_decay=1e-4)
lr_scheduler_partial = partial(torch.optim.lr_scheduler.StepLR, step_size=2, gamma=0.9025)

# "in" size of the embedding table for each categorical field; every field uses "out" = 24.
_categorical_in_sizes = {
    'event_type': 58,
    'event_subtype': 59,
    'src_type11': 85,
    'src_type12': 349,
    'dst_type11': 84,
    'dst_type12': 417,
    'src_type22': 90,
    'src_type32': 91,
}

_trx_encoder = ptls.nn.TrxEncoder(
    norm_embeddings=False,
    embeddings_noise=0.003,
    embeddings={
        field: {"in": in_size, "out": 24}
        for field, in_size in _categorical_in_sizes.items()
    },
    numeric_values={'amount': 'log'},
)

# Baseline sequence encoder: transaction encoder followed by a GRU with 256 hidden units.
seq_encoder = ptls.nn.RnnSeqEncoder(trx_encoder=_trx_encoder, type='gru', hidden_size=256)

_validation_metric = ptls.frames.coles.metric.BatchRecallTopK(K=4, metric='cosine')
pl_module = ptls.frames.coles.CoLESModule(
    validation_metric=_validation_metric,
    seq_encoder=seq_encoder,
    optimizer_partial=optimizer_partial,
    lr_scheduler_partial=lr_scheduler_partial,
)

In [None]:
# The Trainer below comes from `pytorch_lightning` (imported as `pl`), so the logger
# is taken from the same package; mixing `lightning.pytorch` objects with
# `pytorch_lightning` ones combines two separate class hierarchies.
from pytorch_lightning.loggers import WandbLogger

# log_model="all" asks the logger to upload the checkpoints produced during training.
wandb_logger = WandbLogger(project="MBD_My_Code", log_model="all", name="mbd_uni_baseline")

trainer = pl.Trainer(
    logger=wandb_logger,
    max_epochs=20,
    accelerator="cuda" if torch.cuda.is_available() else "cpu",
    enable_progress_bar=True,
    gradient_clip_val=0.5,
    log_every_n_steps=50,
    limit_val_batches=32
)

In [None]:
# Fit the baseline CoLES module on the train/valid data module defined above.
trainer.fit(pl_module, data_module)

**Inference task**

In [None]:
from ptls.data_load.iterable_processing.iterable_seq_len_limit import ISeqLenLimit
from ptls.data_load.iterable_processing.feature_filter import FeatureFilter
from ptls.data_load.iterable_processing.to_torch_tensor import ToTorch

# Fields passed to FeatureFilter(keep_feature_names=...): the id and the four targets.
_inference_keep_cols = ['client_id', 'target_1', 'target_2', 'target_3', 'target_4']

# Fold 4 of the joined dataset; GetSplit then emits one record per month 1..12.
inference_dataset = ptls.data_load.datasets.ParquetDataset(
    data_files=[os.path.join(TRX_SUPERVISED_PATH, 'fold=4')],
    i_filters=[
        ISeqLenLimit(max_seq_len=4096),
        ToTorch(),
        FeatureFilter(keep_feature_names=_inference_keep_cols),
        GetSplit(start_month=1, end_month=12, col_id='client_id'),
    ],
)

In [None]:
# NOTE(review): this loader is built without a collate_fn, and `inference_dl` is
# rebound further down with collate_feature_dict_with_target before it is used by
# trainer.predict — this cell looks redundant; confirm and remove.
inference_dl = DataLoader(
    dataset=inference_dataset,
    shuffle=False,
    num_workers=0,
    batch_size=32
)

In [None]:
import pandas as pd
import pytorch_lightning as pl
import torch
import numpy as np
from itertools import chain
from ptls.data_load.padded_batch import PaddedBatch
from datetime import datetime
from ptls.custom_layers import StatPooling
from ptls.nn.seq_step import LastStepEncoder


class InferenceModuleMultimodal(pl.LightningModule):
    """Run a trained encoder over collated batches and return ids, embeddings and targets.

    Args:
        model: trained module; when it has a ``seq_encoder`` attribute that encoder
            is called, otherwise ``model`` itself is called on the batch.
        pandas_output: if True, ``forward`` returns a ``pandas.DataFrame``
            (see :meth:`to_pandas`), otherwise a dict.
        col_id: output key / column name for the batch ids.
        target_col_names: names given to the target columns in the output.
        model_out_name: output key (and column prefix) for the model output.
        model_type: stored on the instance; not read by this class.
    """

    def __init__(
        self,
        model,
        pandas_output=True,
        col_id='client_id',
        target_col_names=None,
        model_out_name='emb',
        model_type='notab'
    ):
        super().__init__()

        self.model = model
        self.pandas_output = pandas_output
        self.target_col_names = target_col_names
        self.col_id = col_id
        self.model_out_name = model_out_name
        self.model_type = model_type

        # Kept for compatibility with existing instances; not used in forward().
        self.stat_pooler = StatPooling()
        self.last_step = LastStepEncoder()

    def forward(self, x):
        """Encode one collated batch.

        Args:
            x: ``(batch, batch_ids, target_cols)`` or ``(batch, batch_ids)``.

        Returns:
            A DataFrame when ``pandas_output`` is True, otherwise a dict with the
            ids under ``col_id``, the model output under ``model_out_name`` and,
            when targets were supplied, one entry per target column.
        """
        has_targets = len(x) == 3
        if has_targets:
            x, batch_ids, target_cols = x
        else:
            x, batch_ids = x

        # hasattr is the direct way to ask "does the model expose a seq_encoder".
        encoder = self.model.seq_encoder if hasattr(self.model, 'seq_encoder') else self.model
        out = encoder(x)

        x_out = {
            self.col_id: batch_ids,
            self.model_out_name: out
        }

        if has_targets:
            target_cols = torch.tensor(target_cols)
            if target_cols.dim() > 1:
                # One column of the (n_samples, n_targets) matrix per target name.
                for idx, target_col in enumerate(self.target_col_names):
                    x_out[target_col] = target_cols[:, idx]
            else:
                # NOTE(review): for a 1-D target tensor every 4th value is kept and
                # stored under the first target name only — confirm the intent.
                x_out[self.target_col_names[0]] = target_cols[::4]

        if self.pandas_output:
            return self.to_pandas(x_out)
        return x_out

    @staticmethod
    def to_pandas(x):
        """Convert an output dict to a DataFrame.

        Lists and 1-D arrays become single columns, 2-D arrays are expanded into
        columns named ``<key>_0000``, ``<key>_0001``, ...; anything of higher rank
        is stored as a ``None`` column.
        """
        scalar_features = {}
        expanded_frames = []

        for k, v in x.items():
            if isinstance(v, torch.Tensor):
                # detach() makes the conversion safe even for tensors that require grad.
                v = v.detach().cpu().numpy()

            if isinstance(v, list) or len(v.shape) == 1:
                scalar_features[k] = v
            elif len(v.shape) == 2:
                # Reuse the value converted above instead of reading x[k] again,
                # so 2-D values that are not torch tensors are handled as well.
                expanded_frames.append(
                    pd.DataFrame(v, columns=[f'{k}_{i:04d}' for i in range(v.shape[1])])
                )
            else:
                scalar_features[k] = None

        return pd.concat([pd.DataFrame(scalar_features)] + expanded_frames, axis=1)

def collate_feature_dict_with_target(batch, col_id='client_id', target_col_names=None):
    """Collate feature dicts while pulling the id and target fields out of each sample.

    The caller's sample dicts are left untouched: the id / target fields are
    read from them and filtered copies are passed to ``collate_feature_dict``.

    Args:
        batch: list of feature dicts; each must contain ``col_id`` and every name
            in ``target_col_names``.
        col_id: key of the id field.
        target_col_names: keys of the target fields, or None for no targets.

    Returns:
        ``(padded_batch, batch_ids, target_cols)`` when ``target_col_names`` is
        given, where ``target_cols`` holds one list of target values per sample;
        otherwise ``(padded_batch, batch_ids)``.

    Raises:
        KeyError: if a sample lacks ``col_id`` or one of the target keys.
    """
    target_names = list(target_col_names) if target_col_names is not None else []
    excluded = {col_id, *target_names}

    batch_ids = [sample[col_id] for sample in batch]
    target_cols = [[sample[name] for name in target_names] for sample in batch]

    # Copies without the id / target keys, so the input samples are not mutated.
    stripped = [
        {key: value for key, value in sample.items() if key not in excluded}
        for sample in batch
    ]

    padded_batch = collate_feature_dict(stripped)
    if target_col_names is not None:
        return padded_batch, batch_ids, target_cols
    return padded_batch, batch_ids

In [None]:
target_col_names = ['target_1', 'target_2', 'target_3', 'target_4']

# Collate function bound to the four target columns.
collate_fn = partial(collate_feature_dict_with_target, target_col_names=target_col_names)

# Inference loader: fixed order, single process, 32 samples per batch.
inference_dl = DataLoader(
    inference_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=0,
    collate_fn=collate_fn,
)

In [None]:
# Wrap the trained module; with pandas_output=True each forward call returns a DataFrame.
_inf_module_kwargs = dict(
    model=pl_module,
    pandas_output=True,
    col_id='client_id',
    target_col_names=target_col_names,
)
inf_module = InferenceModuleMultimodal(**_inf_module_kwargs)

In [None]:
# Each batch frame carries its own default 0..batch_size-1 index; ignore_index=True
# gives the concatenated frame a unique 0..N-1 index instead of repeated labels.
inf_embeddings = pd.concat(trainer.predict(inf_module, inference_dl), ignore_index=True)

In [None]:
# Preview only the first rows instead of rendering the whole embeddings frame.
inf_embeddings.head()

**Downstream task**

In [None]:
from sklearn.model_selection import train_test_split

# Seed the split so the downstream metrics are reproducible between runs.
# NOTE(review): GetSplit emits one row per client and month (ids get a "_month="
# suffix), so a row-level split can place the same client in both train and
# test — consider a split grouped by the original client id.
dwns_train, dwns_test = train_test_split(inf_embeddings, test_size=0.2, random_state=SEED)

In [None]:
_target_cols = ['target_1', 'target_2', 'target_3', 'target_4']
_non_feature_cols = ['client_id'] + _target_cols

# Label matrices: one row per sample, one column per target.
targets_train = np.array([dwns_train[col].to_numpy() for col in _target_cols]).T
targets_test = np.array([dwns_test[col].to_numpy() for col in _target_cols]).T

# Feature matrices: every remaining column once the id and the targets are dropped.
# NOTE: dwns_train / dwns_test are rebound from DataFrames to arrays here, so this
# cell cannot be re-run without re-running the split above.
dwns_train = dwns_train.drop(columns=_non_feature_cols).to_numpy()
dwns_test = dwns_test.drop(columns=_non_feature_cols).to_numpy()

In [None]:
from lightgbm import LGBMClassifier

# Shared hyper-parameters for the downstream classifiers.
_lgbm_params = dict(
    n_estimators=500,
    boosting_type='gbdt',
    subsample=0.5,
    subsample_freq=1,
    learning_rate=0.02,
    feature_fraction=0.75,
    max_depth=6,
    lambda_l1=1,
    lambda_l2=1,
    min_data_in_leaf=50,
    random_state=42,
    n_jobs=8,
    verbose=-1,
)

# Four independent classifier instances, one per target.
models = [LGBMClassifier(**_lgbm_params) for _ in range(4)]

In [None]:
# Fit the target_id-th classifier on the target_id-th label column.
for target_id in range(4):
    clf = models[target_id]
    labels = targets_train[:, target_id]
    clf.fit(dwns_train, labels)

In [None]:
from sklearn.metrics import classification_report
from sklearn.metrics import roc_auc_score

# The target columns are named target_1..target_4, so the printed label uses i + 1
# (the previous label ran from target_0 to target_3).
for i in range(len(models)):
    preds = models[i].predict_proba(dwns_test)
    print(f"ROC-AUC target_{i + 1} = {roc_auc_score(targets_test[:, i], preds[:, 1])}")

# Regional Attention

In [None]:
import torch
from torch import nn

from ptls.data_load import PaddedBatch
from ptls.nn.seq_encoder.rnn_encoder import RnnEncoder
from ptls.nn.seq_encoder.transformer_encoder import TransformerEncoder
from ptls.nn.seq_encoder.longformer_encoder import LongformerEncoder
from ptls.nn.seq_encoder.custom_encoder import Encoder
from ptls.nn.trx_encoder import TrxEncoder
from ptls.nn.seq_encoder.containers import SeqEncoderContainer
import torch.nn.functional as F

class RnnSeqEncoderRegAttn(SeqEncoderContainer):
    """RNN sequence encoder that adds a "regional attention" term to the transaction embeddings.

    The output of ``trx_encoder`` is cut into fixed-length segments, each segment
    is passed through a second RNN encoder (``reg_seq_encoder``), the results go
    through multi-head self-attention, and the attention output is added to the
    transaction embeddings before the main ``seq_encoder`` runs.
    """
    def __init__(self,
                 trx_encoder=None,
                 input_size=None,
                 is_reduce_sequence=True,
                 **seq_encoder_params,
                 ):
        super().__init__(
            trx_encoder=trx_encoder,
            seq_encoder_cls=RnnEncoder,
            input_size=input_size,
            seq_encoder_params=seq_encoder_params,
            is_reduce_sequence=is_reduce_sequence,
        )
        
        # Second RnnEncoder, built with the same parameters as the main one; used per segment in forward().
        self.reg_seq_encoder = RnnEncoder(
            input_size=input_size if input_size is not None else trx_encoder.output_size,
            is_reduce_sequence=is_reduce_sequence,
            **seq_encoder_params,
        )

        # Attention width is hard-coded; forward() zero-pads the regional embeddings
        # when their width differs from this value.
        self.emb_dim = 256
        self.regional_attention = nn.MultiheadAttention(
            embed_dim=self.emb_dim,
            num_heads=8,
            dropout=0.3,
            batch_first=True
        )

    def forward(self, x, names=None, seq_len=None, h_0=None):
        """Encode ``x``; ``names`` and ``seq_len`` are accepted but not used in this body."""
        # print(f"x_in = {x.payload['amount'].size()}")
        x = self.trx_encoder(x)

        # assumes x.payload is (batch, seq_len, features) — TODO confirm
        x_new = x.payload
        
        # Pad dim 1 at the end so its length is a multiple of segment_length, then
        # split along dim 1 and stack the chunks on a new leading axis.
        segment_length = 10
        pad_length = (segment_length - (x_new.size()[1] % segment_length)) % segment_length
        padded_x_new = F.pad(x_new, ((0, 0, 0, pad_length, 0, 0)), 'constant', 0)
        segmented_tensors = torch.stack(torch.split(padded_x_new, segment_length, dim=1)).to(x_new.device)
        
        # Encode every segment and concatenate the results along dim 0.
        regional_embeddings = torch.Tensor().to(x.device)
        for tensor in segmented_tensors:
            # NOTE(review): permute(1, 0, 2) swaps the first two axes of the segment and
            # seq_lens is tensor.size()[0] repeated tensor.size()[1] times, so the
            # encoder appears to receive the within-segment position as its leading
            # axis and the original batch axis as the sequence — confirm this is intended.
            tensor = PaddedBatch(tensor.permute(1, 0, 2), [tensor.size()[0]] * tensor.size()[1])
            regional_embed = self.reg_seq_encoder(tensor)
            regional_embeddings = torch.cat((regional_embeddings, regional_embed), 0)
        
        # Drop the last pad_length rows, i.e. as many rows as positions were padded above.
        regional_embeddings = regional_embeddings[:len(regional_embeddings) - pad_length, :]
        # layer_norm = nn.LayerNorm([regional_embeddings.size()[0], regional_embeddings.size()[1]])
        # layer_norm.to(x.device)
        # regional_embeddings = layer_norm(regional_embeddings)
        # print(regional_embeddings.size())
        # Zero-pad the last dim when it does not equal emb_dim.
        # NOTE(review): the pad amount uses abs(), so a width larger than emb_dim is
        # padded further rather than truncated — confirm.
        if regional_embeddings.size()[1] != self.emb_dim:
            regional_embeddings = F.pad(regional_embeddings, ((0, abs(regional_embeddings.size()[1] - self.emb_dim), 0, 0)), 'constant', 0)
        # Self-attention: query, key and value are all the regional embeddings.
        x_reg_embed, _ = self.regional_attention(regional_embeddings, regional_embeddings, regional_embeddings)
        # Trim the last dim of the attention output when the embedding width differs
        # from x_new.size()[0].
        # NOTE(review): the negative slice end only yields x_new.size()[0] columns when
        # the embedding width is larger than x_new.size()[0] — confirm for larger batches.
        if regional_embeddings.size()[1] != x_new.size()[0]:
            x_reg_embed = x_reg_embed[:, :-abs(regional_embeddings.size()[1] - x_new.size()[0])]
        # Transpose and append a trailing singleton axis so the term broadcasts over the last dim of x_new.
        x_reg_embed = x_reg_embed.permute(1, 0)
        x_reg_embed = x_reg_embed[:, :, None]
        x_new = x_new + x_reg_embed
        # NOTE(review): the result of .to() is discarded, so this line does not rebind x_new.
        x_new.to(x.device)
        x_new = PaddedBatch(x_new, x.seq_lens)
        # x_new.to(x.device)
        x = self.seq_encoder(x_new, h_0)
        # print(f"rnn_x_size = {x.size()}")
        return x
    
    

In [None]:
optimizer_partial = partial(torch.optim.AdamW, lr=0.001, weight_decay=1e-4)
lr_scheduler_partial = partial(torch.optim.lr_scheduler.StepLR, step_size=2, gamma=0.9025)

# "in" size of the embedding table for each categorical field; every field uses "out" = 24.
_categorical_in_sizes = {
    'event_type': 58,
    'event_subtype': 59,
    'src_type11': 85,
    'src_type12': 349,
    'dst_type11': 84,
    'dst_type12': 417,
    'src_type22': 90,
    'src_type32': 91,
}

_trx_encoder = ptls.nn.TrxEncoder(
    norm_embeddings=False,
    embeddings_noise=0.003,
    embeddings={
        field: {"in": in_size, "out": 24}
        for field, in_size in _categorical_in_sizes.items()
    },
    numeric_values={'amount': 'log'},
)

# Regional-attention variant of the encoder, with the same GRU settings as the baseline cell.
seq_encoder = RnnSeqEncoderRegAttn(trx_encoder=_trx_encoder, type='gru', hidden_size=256)

_validation_metric = ptls.frames.coles.metric.BatchRecallTopK(K=4, metric='cosine')
pl_module = ptls.frames.coles.CoLESModule(
    validation_metric=_validation_metric,
    seq_encoder=seq_encoder,
    optimizer_partial=optimizer_partial,
    lr_scheduler_partial=lr_scheduler_partial,
)

In [None]:
# The Trainer below comes from `pytorch_lightning` (imported as `pl`), so the logger
# is taken from the same package; mixing `lightning.pytorch` objects with
# `pytorch_lightning` ones combines two separate class hierarchies.
from pytorch_lightning.loggers import WandbLogger

# Run name fixed from "mdb_..." to "mbd_..." to match the project and the baseline run name.
wandb_logger = WandbLogger(project="MBD_My_Code", log_model="all", name="mbd_uni_regattn")

trainer = pl.Trainer(
    logger=wandb_logger,
    max_epochs=15,
    accelerator="cuda" if torch.cuda.is_available() else "cpu",
    enable_progress_bar=True,
    gradient_clip_val=0.5,
    log_every_n_steps=50,
    limit_val_batches=32
)

In [None]:
# Fit the regional-attention CoLES module on the same data module as the baseline.
trainer.fit(pl_module, data_module)

**Inference**

In [None]:
from ptls.data_load.iterable_processing.iterable_seq_len_limit import ISeqLenLimit
from ptls.data_load.iterable_processing.feature_filter import FeatureFilter
from ptls.data_load.iterable_processing.to_torch_tensor import ToTorch

# Fields passed to FeatureFilter(keep_feature_names=...): the id and the four targets.
_inference_keep_cols = ['client_id', 'target_1', 'target_2', 'target_3', 'target_4']

# Fold 4 of the joined dataset; GetSplit then emits one record per month 1..12.
inference_dataset = ptls.data_load.datasets.ParquetDataset(
    data_files=[os.path.join(TRX_SUPERVISED_PATH, 'fold=4')],
    i_filters=[
        ISeqLenLimit(max_seq_len=4096),
        ToTorch(),
        FeatureFilter(keep_feature_names=_inference_keep_cols),
        GetSplit(start_month=1, end_month=12, col_id='client_id'),
    ],
)

In [None]:
# NOTE(review): this loader is built without a collate_fn, and `inference_dl` is
# rebound further down with collate_feature_dict_with_target before it is used by
# trainer.predict — this cell looks redundant; confirm and remove.
inference_dl = DataLoader(
    dataset=inference_dataset,
    shuffle=False,
    num_workers=0,
    batch_size=32
)

In [None]:
import pandas as pd
import pytorch_lightning as pl
import torch
import numpy as np
from itertools import chain
from ptls.data_load.padded_batch import PaddedBatch
from datetime import datetime
from ptls.custom_layers import StatPooling
from ptls.nn.seq_step import LastStepEncoder


class InferenceModuleMultimodal(pl.LightningModule):
    """Run a trained encoder over collated batches and return ids, embeddings and targets.

    NOTE(review): a class with this name is also defined earlier in the notebook;
    this definition shadows it — consider keeping a single definition.

    Args:
        model: trained module; when it has a ``seq_encoder`` attribute that encoder
            is called, otherwise ``model`` itself is called on the batch.
        pandas_output: if True, ``forward`` returns a ``pandas.DataFrame``
            (see :meth:`to_pandas`), otherwise a dict.
        col_id: output key / column name for the batch ids.
        target_col_names: names given to the target columns in the output.
        model_out_name: output key (and column prefix) for the model output.
        model_type: stored on the instance; not read by this class.
    """

    def __init__(
        self,
        model,
        pandas_output=True,
        col_id='client_id',
        target_col_names=None,
        model_out_name='emb',
        model_type='notab'
    ):
        super().__init__()

        self.model = model
        self.pandas_output = pandas_output
        self.target_col_names = target_col_names
        self.col_id = col_id
        self.model_out_name = model_out_name
        self.model_type = model_type

        # Kept for compatibility with existing instances; not used in forward().
        self.stat_pooler = StatPooling()
        self.last_step = LastStepEncoder()

    def forward(self, x):
        """Encode one collated batch.

        Args:
            x: ``(batch, batch_ids, target_cols)`` or ``(batch, batch_ids)``.

        Returns:
            A DataFrame when ``pandas_output`` is True, otherwise a dict with the
            ids under ``col_id``, the model output under ``model_out_name`` and,
            when targets were supplied, one entry per target column.
        """
        has_targets = len(x) == 3
        if has_targets:
            x, batch_ids, target_cols = x
        else:
            x, batch_ids = x

        # hasattr is the direct way to ask "does the model expose a seq_encoder".
        encoder = self.model.seq_encoder if hasattr(self.model, 'seq_encoder') else self.model
        out = encoder(x)

        x_out = {
            self.col_id: batch_ids,
            self.model_out_name: out
        }

        if has_targets:
            target_cols = torch.tensor(target_cols)
            if target_cols.dim() > 1:
                # One column of the (n_samples, n_targets) matrix per target name.
                for idx, target_col in enumerate(self.target_col_names):
                    x_out[target_col] = target_cols[:, idx]
            else:
                # NOTE(review): for a 1-D target tensor every 4th value is kept and
                # stored under the first target name only — confirm the intent.
                x_out[self.target_col_names[0]] = target_cols[::4]

        if self.pandas_output:
            return self.to_pandas(x_out)
        return x_out

    @staticmethod
    def to_pandas(x):
        """Convert an output dict to a DataFrame.

        Lists and 1-D arrays become single columns, 2-D arrays are expanded into
        columns named ``<key>_0000``, ``<key>_0001``, ...; anything of higher rank
        is stored as a ``None`` column.
        """
        scalar_features = {}
        expanded_frames = []

        for k, v in x.items():
            if isinstance(v, torch.Tensor):
                # detach() makes the conversion safe even for tensors that require grad.
                v = v.detach().cpu().numpy()

            if isinstance(v, list) or len(v.shape) == 1:
                scalar_features[k] = v
            elif len(v.shape) == 2:
                # Reuse the value converted above instead of reading x[k] again,
                # so 2-D values that are not torch tensors are handled as well.
                expanded_frames.append(
                    pd.DataFrame(v, columns=[f'{k}_{i:04d}' for i in range(v.shape[1])])
                )
            else:
                scalar_features[k] = None

        return pd.concat([pd.DataFrame(scalar_features)] + expanded_frames, axis=1)

def collate_feature_dict_with_target(batch, col_id='client_id', target_col_names=None):
    """Collate feature dicts while pulling the id and target fields out of each sample.

    The caller's sample dicts are left untouched: the id / target fields are
    read from them and filtered copies are passed to ``collate_feature_dict``.

    NOTE(review): a function with this name is also defined earlier in the
    notebook; this definition shadows it — consider keeping a single definition.

    Args:
        batch: list of feature dicts; each must contain ``col_id`` and every name
            in ``target_col_names``.
        col_id: key of the id field.
        target_col_names: keys of the target fields, or None for no targets.

    Returns:
        ``(padded_batch, batch_ids, target_cols)`` when ``target_col_names`` is
        given, where ``target_cols`` holds one list of target values per sample;
        otherwise ``(padded_batch, batch_ids)``.

    Raises:
        KeyError: if a sample lacks ``col_id`` or one of the target keys.
    """
    target_names = list(target_col_names) if target_col_names is not None else []
    excluded = {col_id, *target_names}

    batch_ids = [sample[col_id] for sample in batch]
    target_cols = [[sample[name] for name in target_names] for sample in batch]

    # Copies without the id / target keys, so the input samples are not mutated.
    stripped = [
        {key: value for key, value in sample.items() if key not in excluded}
        for sample in batch
    ]

    padded_batch = collate_feature_dict(stripped)
    if target_col_names is not None:
        return padded_batch, batch_ids, target_cols
    return padded_batch, batch_ids

In [None]:
target_col_names = ['target_1', 'target_2', 'target_3', 'target_4']

# Collate function bound to the four target columns.
collate_fn = partial(collate_feature_dict_with_target, target_col_names=target_col_names)

# Inference loader: fixed order, single process, 32 samples per batch.
inference_dl = DataLoader(
    inference_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=0,
    collate_fn=collate_fn,
)

In [None]:
# Wrap the trained module; with pandas_output=True each forward call returns a DataFrame.
_inf_module_kwargs = dict(
    model=pl_module,
    pandas_output=True,
    col_id='client_id',
    target_col_names=target_col_names,
)
inf_module = InferenceModuleMultimodal(**_inf_module_kwargs)

In [None]:
# Each batch frame carries its own default 0..batch_size-1 index; ignore_index=True
# gives the concatenated frame a unique 0..N-1 index instead of repeated labels.
inf_embeddings = pd.concat(trainer.predict(inf_module, inference_dl), ignore_index=True)

**Downstream task**

In [None]:
from sklearn.model_selection import train_test_split

# Seed the split so the downstream metrics are reproducible between runs.
# NOTE(review): GetSplit emits one row per client and month (ids get a "_month="
# suffix), so a row-level split can place the same client in both train and
# test — consider a split grouped by the original client id.
dwns_train, dwns_test = train_test_split(inf_embeddings, test_size=0.2, random_state=SEED)

In [None]:
_target_cols = ['target_1', 'target_2', 'target_3', 'target_4']
_non_feature_cols = ['client_id'] + _target_cols

# Label matrices: one row per sample, one column per target.
targets_train = np.array([dwns_train[col].to_numpy() for col in _target_cols]).T
targets_test = np.array([dwns_test[col].to_numpy() for col in _target_cols]).T

# Feature matrices: every remaining column once the id and the targets are dropped.
# NOTE: dwns_train / dwns_test are rebound from DataFrames to arrays here, so this
# cell cannot be re-run without re-running the split above.
dwns_train = dwns_train.drop(columns=_non_feature_cols).to_numpy()
dwns_test = dwns_test.drop(columns=_non_feature_cols).to_numpy()

In [None]:
from lightgbm import LGBMClassifier

# Shared hyper-parameters for the downstream classifiers.
_lgbm_params = dict(
    n_estimators=500,
    boosting_type='gbdt',
    subsample=0.5,
    subsample_freq=1,
    learning_rate=0.02,
    feature_fraction=0.75,
    max_depth=6,
    lambda_l1=1,
    lambda_l2=1,
    min_data_in_leaf=50,
    random_state=42,
    n_jobs=8,
    verbose=-1,
)

# Four independent classifier instances, one per target.
models = [LGBMClassifier(**_lgbm_params) for _ in range(4)]

In [None]:
# Fit the target_id-th classifier on the target_id-th label column.
for target_id in range(4):
    clf = models[target_id]
    labels = targets_train[:, target_id]
    clf.fit(dwns_train, labels)

In [None]:
from sklearn.metrics import classification_report
from sklearn.metrics import roc_auc_score

# The target columns are named target_1..target_4, so the printed label uses i + 1
# (the previous label ran from target_0 to target_3).
for i in range(len(models)):
    preds = models[i].predict_proba(dwns_test)
    print(f"ROC-AUC target_{i + 1} = {roc_auc_score(targets_test[:, i], preds[:, 1])}")

# Cross-Attention

In [None]:
import torch.nn.functional as F

class SplitToPatches(IterableProcessingDataset):
    """
    Create small and large patches for NN with cross-attention

    For every record a new dict is yielded in which:

    * keys starting with ``'target'`` are copied unchanged;
    * every other key except ``col_id`` is kept and additionally stored as
      ``'small_<key>'`` / ``'large_<key>'``: the sequence split along dim 0 into
      chunks of ``small_patches_size`` / ``large_patches_size`` and stacked, with
      the last chunk padded by repeating its final value;
    * ``col_id`` itself is not included in the yielded dict.

    NOTE(review): padding uses a scalar fill value, so the sequences are expected
    to be 1-D tensors — confirm against the dataset.
    """
    def __init__(
        self,
        small_patches_size=3,
        large_patches_size=12,
        col_id='client_id',
        col_time='event_time'
    ):
        super().__init__()
        self.small_patches_size = small_patches_size
        self.large_patches_size = large_patches_size
        self._col_id = col_id
        self._col_time = col_time

    @staticmethod
    def _to_patches(sequence, patch_size):
        """Split ``sequence`` into ``patch_size`` chunks and stack them into one tensor."""
        patches = list(torch.split(sequence, patch_size))
        last = patches[-1]
        # Compare lengths: the previous `size() != int` test compared a torch.Size
        # with an int and was therefore always true.
        missing = patch_size - len(last)
        if missing > 0:
            # F.pad takes a plain number as fill value; an empty chunk has no last
            # element to repeat, so it is filled with 0.
            fill_value = last[-1].item() if len(last) > 0 else 0
            patches[-1] = F.pad(last, (0, missing), "constant", fill_value)
        return torch.stack(patches)

    def __iter__(self):
        for rec in self._src:
            features = rec[0] if type(rec) is tuple else rec

            patched_features = {}

            for key, tensor in features.items():
                if key.startswith('target'):
                    patched_features[key] = tensor
                elif key != self._col_id:
                    patched_features[key] = tensor
                    patched_features['small_' + key] = self._to_patches(tensor, self.small_patches_size)
                    patched_features['large_' + key] = self._to_patches(tensor, self.large_patches_size)

            yield patched_features

In [None]:
def _coles_pretrain_filters():
    """Return a fresh filter chain so train and valid never share filter instances.

    Chain: SeqLenFilter(min_seq_len=16), SeqLenFilter(max_seq_len=2048), a
    FeatureFilter dropping the id and target columns, then ToTorch.
    """
    return [
        ptls.data_load.iterable_processing.SeqLenFilter(min_seq_len=16),
        ptls.data_load.iterable_processing.SeqLenFilter(max_seq_len=2048),
        ptls.data_load.iterable_processing.feature_filter.FeatureFilter(
            drop_feature_names=['client_id', 'target_1', 'target_2', 'target_3', 'target_4']
        ),
        ptls.data_load.iterable_processing.to_torch_tensor.ToTorch(),
        # SplitToPatches()
    ]


# Training data: folds 0-2 of the joined (transactions + targets) dataset.
train = ptls.data_load.datasets.ParquetDataset(
    data_files=[os.path.join(TRX_SUPERVISED_PATH, f'fold={fold}') for fold in (0, 1, 2)],
    i_filters=_coles_pretrain_filters(),
    shuffle_files=True,
)

# NOTE(review): validation reads fold 3 from TRX_DATA_PATH, whereas training reads
# from TRX_SUPERVISED_PATH — confirm the difference is intentional.
valid = ptls.data_load.datasets.ParquetDataset(
    data_files=[os.path.join(TRX_DATA_PATH, 'fold=3')],
    i_filters=_coles_pretrain_filters(),
)

In [None]:
def _coles_iterable(data):
    """Wrap ``data`` in a ColesIterableDataset with SampleSlices(split_count=3, cnt_min=16, cnt_max=180)."""
    slicer = ptls.frames.coles.split_strategy.SampleSlices(
        split_count=3,
        cnt_min=16,
        cnt_max=180,
    )
    return ptls.frames.coles.ColesIterableDataset(splitter=slicer, data=data)


# Train and validation use the same splitting strategy; only batch sizes differ.
data_module = ptls.frames.PtlsDataModule(
    train_data=_coles_iterable(train),
    valid_data=_coles_iterable(valid),
    train_batch_size=64,
    train_num_workers=0,
    valid_batch_size=32,
    valid_num_workers=0,
)

In [None]:
# next(iter(train))

In [None]:
# Scratch experiment: F.pad in 'replicate' mode on a small 2-D integer tensor.
# NOTE(review): looks like leftover exploration (`a` is not read anywhere else in
# the visible notebook) — confirm and remove; also verify that this call does
# not raise for an integer 2-D input on the installed torch version.
a = torch.tensor([[1, 2], [1, 3]])
a = F.pad(a, (0, 1), 'replicate')
a

In [None]:
import torch
from torch import nn

from ptls.data_load import PaddedBatch
from ptls.nn.seq_encoder.rnn_encoder import RnnEncoder
from ptls.nn.seq_encoder.transformer_encoder import TransformerEncoder
from ptls.nn.seq_encoder.longformer_encoder import LongformerEncoder
from ptls.nn.seq_encoder.custom_encoder import Encoder
from ptls.nn.trx_encoder import TrxEncoder
from ptls.nn.seq_encoder.containers import SeqEncoderContainer
import torch.nn.functional as F
from ptls.nn.seq_encoder.custom_encoder import MLP

class RnnSeqEncoderCrossAttn(SeqEncoderContainer):
    """Two-scale RNN sequence encoder with a cross-attention step.

    The per-event embeddings produced by ``trx_encoder`` are cut along the time
    axis into "small" and "large" fixed-size patches. Each patch is encoded by
    an RNN (``self.seq_encoder`` for large patches, ``self.small_seq_encoder``
    for small ones) and projected by an MLP head to ``self.emb_dim`` features.
    The patch embeddings of each scale are folded into one vector per
    sequence, the large-scale vector attends to the small-scale one, and the
    small-scale vector is concatenated with the attention output.

    It may be worth using different encoder types for the small and the large
    patches (original author's note).

    Args:
        trx_encoder: event encoder applied to the incoming batch.
        input_size: RNN input size; when ``None`` the small-patch encoder uses
            ``trx_encoder.output_size``.
        small_patches_size: number of events in one small patch.
        large_patches_size: number of events in one large patch.
        is_reduce_sequence: forwarded to both RNN encoders. ``forward`` always
            returns a plain tensor, so it presumably only works with ``True``.
        **seq_encoder_params: keyword arguments for the ``RnnEncoder`` that
            encodes the large patches.

    Notes:
        Both MLP heads are built with ``n_in=256`` and the small-patch encoder
        with ``hidden_size=256``; the large-patch encoder configured through
        ``seq_encoder_params`` presumably has to produce 256 features as well.
    """

    def __init__(self,
                 trx_encoder=None,
                 input_size=None,
                 small_patches_size=3,
                 large_patches_size=12,
                 is_reduce_sequence=True,
                 **seq_encoder_params,
                 ):
        super().__init__(
            trx_encoder=trx_encoder,
            seq_encoder_cls=RnnEncoder,
            input_size=input_size,
            seq_encoder_params=seq_encoder_params,
            is_reduce_sequence=is_reduce_sequence,
        )
        self.small_patches_size = small_patches_size
        self.large_patches_size = large_patches_size
        # Size of one per-scale embedding; the final output has 2 * emb_dim features.
        self.emb_dim = 128
        self.large_attn = nn.MultiheadAttention(
            embed_dim=self.emb_dim,
            num_heads=8,
            dropout=0.3,
            batch_first=True
        )

        # Dedicated encoder for the small patches; the large patches use
        # ``self.seq_encoder`` created by the parent container.
        self.small_seq_encoder = RnnEncoder(
            input_size=input_size if input_size is not None else trx_encoder.output_size,
            is_reduce_sequence=is_reduce_sequence,
            type='gru',
            hidden_size=256
        )

        # Projection heads: 256 encoder features -> emb_dim.
        self.head_small = MLP(
            n_in=256,
            n_hidden=256,
            n_out=self.emb_dim
        )
        self.head_large = MLP(
            n_in=256,
            n_hidden=256,
            n_out=self.emb_dim
        )

    @staticmethod
    def _split_into_patches(payload, patch_size):
        """Cut ``payload`` along dim 1 into patches of ``patch_size`` steps.

        The last patch is front-padded (first step replicated) when it is
        shorter than ``patch_size``. Returns the patches stacked on a new
        leading dimension.
        """
        patches = list(torch.split(payload, patch_size, dim=1))
        pad_size = patch_size - patches[-1].size(1)
        if pad_size > 0:
            # (0, 0, pad_size, 0): nothing on the feature axis, ``pad_size``
            # steps in front of the time axis.
            patches[-1] = F.pad(patches[-1], (0, 0, pad_size, 0), "replicate")
        return torch.stack(patches)

    def _encode_patches(self, patches, encoder, head):
        """Encode every patch and fold the results into one vector per sequence.

        NOTE(review): ``(acc + emb) / 2`` halves the weight of earlier patches
        at every step, so this is a recency-weighted fold rather than a plain
        mean, and the zero start vector also contributes -- confirm this is
        intended.
        """
        comp_embed = torch.zeros(patches.size(1), self.emb_dim, device=patches.device)
        for patch in patches:
            # Every row of a patch is full length after padding.
            patch_batch = PaddedBatch(patch, [patch.size(1)] * patch.size(0))
            comp_embed = (comp_embed + head(encoder(patch_batch))) / 2
        return comp_embed

    def forward(self, x, names=None, seq_len=None, h_0=None):
        """Encode a batch into one vector of ``2 * emb_dim`` features per sequence.

        Args:
            x: batch accepted by ``self.trx_encoder``.
            names, seq_len, h_0: accepted but not used by this method.

        Returns:
            Tensor of shape ``(batch, 2 * emb_dim)``: the small-scale embedding
            concatenated with the attention output.
        """
        x = self.trx_encoder(x)
        x_new = x.payload

        small_patches = self._split_into_patches(x_new, self.small_patches_size)
        large_patches = self._split_into_patches(x_new, self.large_patches_size)

        large_comp_embed = self._encode_patches(large_patches, self.seq_encoder, self.head_large)
        small_comp_embed = self._encode_patches(small_patches, self.small_seq_encoder, self.head_small)

        # NOTE(review): both inputs are 2-D (batch, emb_dim). nn.MultiheadAttention
        # treats 2-D inputs as ONE unbatched sequence, so the attention here runs
        # across the samples of the batch -- confirm this is intended.
        large_attn_emb, _ = self.large_attn(large_comp_embed, small_comp_embed, small_comp_embed)

        return torch.cat((small_comp_embed, large_attn_emb), dim=1)


In [None]:
# Optimizer / scheduler factories handed to CoLESModule below.
# NOTE(review): `partial` (functools) is not imported in the visible part of
# the notebook -- confirm it is imported in an earlier cell.
optimizer_partial = partial(
    torch.optim.AdamW,
    lr=0.001,
    weight_decay=1e-4
)

# StepLR: multiply the learning rate by 0.9025 every 2 scheduler steps.
lr_scheduler_partial = partial(
    torch.optim.lr_scheduler.StepLR,
    step_size=2,
    gamma=0.9025
)

# Two-scale encoder defined above: small patches of 3 events, large patches of 12.
seq_encoder = RnnSeqEncoderCrossAttn(
        trx_encoder=ptls.nn.TrxEncoder(
            norm_embeddings=False,
            embeddings_noise=0.003,
            # One 24-dimensional embedding per categorical feature; "in" is
            # presumably the dictionary size of that feature -- confirm
            # against the preprocessing step.
            embeddings={
                'event_type': {"in": 58, "out": 24},
                'event_subtype': {"in": 59, "out": 24},
                'src_type11': {"in": 85, "out": 24},
                'src_type12': {"in": 349, "out": 24},
                'dst_type11': {"in": 84, "out": 24},
                'dst_type12': {"in": 417, "out": 24},
                'src_type22': {"in": 90, "out": 24},
                'src_type32': {"in": 91, "out": 24}
            },
            numeric_values={
                'amount': 'log'
            }
        ),
        small_patches_size=3,
        large_patches_size=12,
        # type / hidden_size are forwarded as seq_encoder_params to the
        # large-patch RnnEncoder.
        type='gru',
        hidden_size=256
    )

# CoLES module validated with BatchRecallTopK (K=10, cosine).
pl_module = ptls.frames.coles.CoLESModule(
    validation_metric=ptls.frames.coles.metric.BatchRecallTopK(
        K=10,
        metric='cosine'
    ),
    seq_encoder=seq_encoder,
    optimizer_partial=optimizer_partial,
    lr_scheduler_partial=lr_scheduler_partial
)

In [None]:
from lightning.pytorch.loggers import WandbLogger

# Weights & Biases run 'mbd_uni_crossattn' in project 'MBD_My_Code'.
# NOTE(review): WandbLogger comes from `lightning.pytorch` while `pl` is
# `pytorch_lightning` (see the import cell) -- confirm the two packages
# interoperate in the installed versions.
wandb_logger = WandbLogger(project="MBD_My_Code", log_model="all", name='mbd_uni_crossattn')

# 15 epochs, gradient clipping at 0.5, validation limited to 32 batches.
trainer = pl.Trainer(
    logger=wandb_logger,
    max_epochs=15,
    accelerator="cuda" if torch.cuda.is_available() else "cpu",
    enable_progress_bar=True,
    gradient_clip_val=0.5,
    log_every_n_steps=50,
    limit_val_batches=32
)

In [None]:
# Train the cross-attention CoLES model configured in the cells above.
trainer.fit(pl_module, data_module)

# Aggregate over time

In [None]:
from datetime import timedelta

import torch.nn.functional as F

class AggregateOverTime(IterableProcessingDataset):
    """Collapse the events of every record into one aggregated event per day.

    Events are grouped by calendar day (local time, as given by
    ``datetime.fromtimestamp``). For each day that contains at least one event:

    * the time column receives the timestamp of that day's midnight;
    * floating-point features receive the mean over the day;
    * other tensor features receive the most frequent value of the day
      (on ties, the smallest value, because ``torch.unique`` returns sorted
      values and ``argmax`` picks the first maximum).

    The id column, every feature whose name starts with ``'target'`` and any
    value that is not a per-event tensor are passed through unchanged.

    Args:
        col_id: name of the id feature.
        col_time: name of the event-time feature (Unix timestamps in seconds).
            The first and last values are taken as the time range, so the
            events are assumed to be sorted by time -- TODO confirm.
    """

    def __init__(
        self,
        col_id='client_id',
        col_time='event_time'
    ):
        super().__init__()
        self._col_id = col_id
        self._col_time = col_time

    @staticmethod
    def _day_start(timestamp):
        """Return the local midnight that precedes ``timestamp``."""
        moment = datetime.fromtimestamp(float(timestamp))
        return moment.replace(hour=0, minute=0, second=0, microsecond=0)

    def _day_masks(self, event_times):
        """Map each non-empty day (midnight timestamp) to a boolean event mask."""
        first_day = self._day_start(event_times[0])
        last_day = self._day_start(event_times[-1])
        # +2: one boundary after the last day so that it gets a closing edge.
        num_days = (last_day - first_day).days + 2
        bounds = [(first_day + timedelta(days=i)).timestamp() for i in range(num_days)]

        masks = {}
        for day_begin, day_end in zip(bounds[:-1], bounds[1:]):
            # Half-open interval: an event exactly at midnight belongs to the
            # day that starts at that instant.
            mask = (event_times >= day_begin) & (event_times < day_end)
            if mask.any():
                masks[day_begin] = mask
        return masks

    @staticmethod
    def _aggregate(day_values):
        """Reduce the values of one day to a single Python scalar."""
        if len(day_values) == 1:
            return day_values[0].item()
        if torch.is_floating_point(day_values):
            return day_values.mean().item()
        values, counts = torch.unique(day_values, return_counts=True)
        return values[torch.argmax(counts)].item()

    def __iter__(self):
        for rec in self._src:
            features = rec[0] if type(rec) is tuple else rec
            event_times = features[self._col_time]

            if len(event_times) == 0:
                # Nothing to aggregate.
                yield dict(features)
                continue

            masks = self._day_masks(event_times)
            n_events = len(event_times)

            new_features = {}
            for key, value in features.items():
                if key == self._col_time:
                    # Built from Python ints so the timestamps are stored
                    # exactly; int32 is the dtype downstream code received so far.
                    new_features[key] = torch.tensor(
                        [int(day_begin) for day_begin in masks], dtype=torch.int32
                    )
                    continue

                is_event_aligned = (
                    torch.is_tensor(value)
                    and value.dim() > 0
                    and value.shape[0] == n_events
                )
                if key == self._col_id or key.startswith('target') or not is_event_aligned:
                    new_features[key] = value
                else:
                    new_features[key] = torch.tensor(
                        [self._aggregate(value[mask]) for mask in masks.values()]
                    )

            yield new_features

In [None]:
# Train / validation streams for the "aggregate over time" experiment: the same
# filter chain as before with AggregateOverTime() inserted before ToTorch().
# NOTE(review): the SeqLenFilter bounds come before AggregateOverTime in the
# list, so (assuming filters run in list order) they apply to the raw event
# count, not to the aggregated sequence length -- confirm this is intended.
train = ptls.data_load.datasets.ParquetDataset(
        data_files=[
            os.path.join(TRX_SUPERVISED_PATH, 'fold=0'),
            os.path.join(TRX_SUPERVISED_PATH, 'fold=1'),
            os.path.join(TRX_SUPERVISED_PATH, 'fold=2'),
        ],
    i_filters=[
        ptls.data_load.iterable_processing.SeqLenFilter(min_seq_len=16),
        ptls.data_load.iterable_processing.SeqLenFilter(max_seq_len=2048),
        ptls.data_load.iterable_processing.feature_filter.FeatureFilter(
            drop_feature_names=[
                'client_id',
                'target_1',
                'target_2',
                'target_3',
                'target_4'
            ]
        ),
        AggregateOverTime(),
        ptls.data_load.iterable_processing.to_torch_tensor.ToTorch()
    ],
    shuffle_files=True
)

# NOTE(review): validation reads from TRX_DATA_PATH while training reads from
# TRX_SUPERVISED_PATH -- confirm the different root is intentional.
valid = ptls.data_load.datasets.ParquetDataset(
        data_files=[
            os.path.join(TRX_DATA_PATH, 'fold=3')
        ],
    i_filters=[
        ptls.data_load.iterable_processing.SeqLenFilter(min_seq_len=16),
        ptls.data_load.iterable_processing.SeqLenFilter(max_seq_len=2048),
        ptls.data_load.iterable_processing.feature_filter.FeatureFilter(
            drop_feature_names=[
                'client_id',
                'target_1',
                'target_2',
                'target_3',
                'target_4'
            ]
        ),
        AggregateOverTime(),
        ptls.data_load.iterable_processing.to_torch_tensor.ToTorch()
    ]
)

In [None]:
# Smoke test: pull one record through the aggregated pipeline and display it.
next(iter(train))

In [None]:
# CoLES data module over the aggregated `train` / `valid` streams; the
# SampleSlices settings are unchanged from the previous experiment
# (3 slices per sequence, 16..180 steps each).
# NOTE(review): after aggregation a sequence holds one step per active day, so
# cnt_min=16 now refers to days rather than events -- confirm it still suits
# the data.
data_module = ptls.frames.PtlsDataModule(
    train_data=ptls.frames.coles.ColesIterableDataset(
        splitter=ptls.frames.coles.split_strategy.SampleSlices(
            split_count=3,
            cnt_min=16,
            cnt_max=180
        ),
        data=train
    ),
    valid_data=ptls.frames.coles.ColesIterableDataset(
        splitter=ptls.frames.coles.split_strategy.SampleSlices(
            split_count=3,
            cnt_min=16,
            cnt_max=180
        ),
        data=valid
    ),
    train_batch_size=64,
    train_num_workers=0,
    valid_batch_size=32,
    valid_num_workers=0
)

In [None]:
# Optimizer / scheduler factories handed to CoLESModule below.
# NOTE(review): `partial` (functools) is not imported in the visible part of
# the notebook -- confirm it is imported in an earlier cell.
optimizer_partial = partial(
    torch.optim.AdamW,
    lr=0.001,
    weight_decay=1e-4
)

# StepLR: multiply the learning rate by 0.9025 every 2 scheduler steps.
lr_scheduler_partial = partial(
    torch.optim.lr_scheduler.StepLR,
    step_size=2,
    gamma=0.9025
)

# Stock GRU sequence encoder (hidden size 256) for the aggregated sequences.
seq_encoder = ptls.nn.RnnSeqEncoder(
        trx_encoder=ptls.nn.TrxEncoder(
            norm_embeddings=False,
            embeddings_noise=0.003,
            # One 24-dimensional embedding per categorical feature.
            embeddings={
                'event_type': {"in": 58, "out": 24},
                'event_subtype': {"in": 59, "out": 24},
                'src_type11': {"in": 85, "out": 24},
                'src_type12': {"in": 349, "out": 24},
                'dst_type11': {"in": 84, "out": 24},
                'dst_type12': {"in": 417, "out": 24},
                'src_type22': {"in": 90, "out": 24},
                'src_type32': {"in": 91, "out": 24}
            },
            numeric_values={
                'amount': 'log'
            }
        ),
        type='gru',
        hidden_size=256
    )

# CoLES module validated with BatchRecallTopK; K=4 here, whereas the
# cross-attention experiment above used K=10.
pl_module = ptls.frames.coles.CoLESModule(
    validation_metric=ptls.frames.coles.metric.BatchRecallTopK(
        K=4,
        metric='cosine'
    ),
    seq_encoder=seq_encoder,
    optimizer_partial=optimizer_partial,
    lr_scheduler_partial=lr_scheduler_partial
)

In [None]:
from lightning.pytorch.loggers import WandbLogger

# Weights & Biases run 'mbd_agg_overtime' in project 'MBD_My_Code'.
# NOTE(review): WandbLogger comes from `lightning.pytorch` while `pl` is
# `pytorch_lightning` (see the import cell) -- confirm the two packages
# interoperate in the installed versions.
wandb_logger = WandbLogger(project="MBD_My_Code", log_model="all", name='mbd_agg_overtime')

# 20 epochs, gradient clipping at 0.5, validation limited to 32 batches.
trainer = pl.Trainer(
    logger=wandb_logger,
    max_epochs=20,
    accelerator="cuda" if torch.cuda.is_available() else "cpu",
    enable_progress_bar=True,
    gradient_clip_val=0.5,
    log_every_n_steps=50,
    limit_val_batches=32
)

In [None]:
# Train the CoLES model on the day-aggregated sequences.
trainer.fit(pl_module, data_module)

# Custom transformer block

In [None]:
import torch.nn.functional as F

class SplitToPatches(IterableProcessingDataset):
    """Cut every sequence feature into small and large fixed-size patches.

    For each record this yields a 3-tuple ``(features, small_patch,
    large_patch)``: ``features`` is a shallow copy of the incoming feature
    dict, and the two other dicts map every patched feature name to a tensor
    of shape ``(n_patches, patch_size)``. The last patch is completed by
    repeating the final element of the sequence.

    The id column and features whose name starts with ``'target'`` are not
    patched; they are only available through ``features``.

    Args:
        small_patches_size: length of one small patch.
        large_patches_size: length of one large patch.
        col_id: name of the id feature (excluded from patching).
        col_time: name of the event-time feature (stored; it is patched like
            any other sequence feature).
    """

    def __init__(
        self,
        small_patches_size=3,
        large_patches_size=12,
        col_id='client_id',
        col_time='event_time'
    ):
        super().__init__()
        self.small_patches_size = small_patches_size
        self.large_patches_size = large_patches_size
        self._col_id = col_id
        self._col_time = col_time

    @staticmethod
    def _to_patches(values, patch_size):
        """Split ``values`` along dim 0 into stacked patches of ``patch_size``.

        ``values`` must be a non-empty tensor. A short last patch is filled up
        with copies of its final element.
        """
        patches = list(torch.split(values, patch_size))
        last = patches[-1]
        missing = patch_size - len(last)
        if missing > 0:
            patches[-1] = torch.cat([last] + [last[-1:]] * missing)
        return torch.stack(patches)

    def __iter__(self):
        for rec in self._src:
            features = rec[0] if type(rec) is tuple else rec
            features = features.copy()

            small_patch = {}
            large_patch = {}

            for key, values in features.items():
                if key.startswith('target') or key == self._col_id:
                    continue
                small_patch[key] = self._to_patches(values, self.small_patches_size)
                large_patch[key] = self._to_patches(values, self.large_patches_size)

            yield features, small_patch, large_patch

In [None]:
# Train / validation streams for the transformer experiments. The
# SplitToPatches step defined above is left disabled (commented out) in both
# filter chains.
train = ptls.data_load.datasets.ParquetDataset(
        data_files=[
            os.path.join(TRX_SUPERVISED_PATH, 'fold=0'),
            os.path.join(TRX_SUPERVISED_PATH, 'fold=1'),
            os.path.join(TRX_SUPERVISED_PATH, 'fold=2'),
        ],
    i_filters=[
        # Sequence-length bounds (16..2048), expressed as two separate filters.
        ptls.data_load.iterable_processing.SeqLenFilter(min_seq_len=16),
        ptls.data_load.iterable_processing.SeqLenFilter(max_seq_len=2048),
        ptls.data_load.iterable_processing.feature_filter.FeatureFilter(
            drop_feature_names=[
                'client_id',
                'target_1',
                'target_2',
                'target_3',
                'target_4'
            ]
        ),
        ptls.data_load.iterable_processing.to_torch_tensor.ToTorch(),
        # SplitToPatches()
    ],
    shuffle_files=True
)

# NOTE(review): validation reads from TRX_DATA_PATH while training reads from
# TRX_SUPERVISED_PATH -- confirm the different root is intentional.
valid = ptls.data_load.datasets.ParquetDataset(
        data_files=[
            os.path.join(TRX_DATA_PATH, 'fold=3')
        ],
    i_filters=[
        ptls.data_load.iterable_processing.SeqLenFilter(min_seq_len=16),
        ptls.data_load.iterable_processing.SeqLenFilter(max_seq_len=2048),
        ptls.data_load.iterable_processing.feature_filter.FeatureFilter(
            drop_feature_names=[
                'client_id',
                'target_1',
                'target_2',
                'target_3',
                'target_4'
            ]
        ),
        ptls.data_load.iterable_processing.to_torch_tensor.ToTorch(),
        # SplitToPatches()
    ]
)

In [None]:
# Smoke test: pull one record through the pipeline and display it.
next(iter(train))

**Tabformer**

In [None]:
from ptls.frames.tabformer import TabformerIterableDataset

# Data module for Tabformer pre-training: both streams are wrapped in
# TabformerIterableDataset with min_len=80 and max_len=300 (presumably the
# bounds of the sub-sequence length fed to the model -- confirm against the
# ptls documentation).
data_module = ptls.frames.PtlsDataModule(
    train_data=TabformerIterableDataset(
        data=train,
        max_len=300,
        min_len=80
    ),
    valid_data=TabformerIterableDataset(
        data=valid,
        max_len=300,
        min_len=80
    ),
    train_batch_size=64,
    train_num_workers=0,
    valid_batch_size=32,
    valid_num_workers=0,
)

In [None]:
# NOTE(review): `optimizer_partial` and `lr_scheduler_partial` are defined
# here but are not passed to TabformerPretrainModule below -- confirm whether
# they are needed in this cell.
optimizer_partial = partial(
    torch.optim.AdamW,
    lr=0.001,
    weight_decay=1e-4
)

lr_scheduler_partial = partial(
    torch.optim.lr_scheduler.StepLR,
    step_size=2,
    gamma=0.9025
)

# Nine categorical embeddings of 24 dimensions each. Compared with the CoLES
# cells every "in" size is larger by one, and 'amount' is embedded as a
# category with 11 values instead of being a numeric feature.
# NOTE(review): presumably the extra index is reserved for a mask token and
# 'amount' is expected to be discretised beforehand -- confirm.
trx_encoder = ptls.nn.TrxEncoder(
            norm_embeddings=False,
            embeddings_noise=0.003,
            embeddings={
                'event_type': {"in": 59, "out": 24},
                'event_subtype': {"in": 60, "out": 24},
                'src_type11': {"in": 86, "out": 24},
                'src_type12': {"in": 350, "out": 24},
                'dst_type11': {"in": 85, "out": 24},
                'dst_type12': {"in": 418, "out": 24},
                'src_type22': {"in": 91, "out": 24},
                'src_type32': {"in": 92, "out": 24},
                'amount': {"in": 11, "out": 24}
            },
        )

# input_size=216 equals 9 embeddings x 24 dimensions.
seq_encoder = ptls.nn.seq_encoder.CustomSeqEncoder(
    n_heads=2,
    n_layers=8,
    input_size=216,
    use_positional_encoding=True
)

# n_cols / emb_dim mirror the 9 x 24 layout of `trx_encoder`.
feature_encoder = ptls.nn.TabFormerFeatureEncoder(
    n_cols=9,
    emb_dim=24
)

# NOTE(review): TabformerPretrainModule is not imported in the visible cells
# (only TabformerIterableDataset is) -- confirm it is imported elsewhere.
pl_module = TabformerPretrainModule(
    total_steps=20000,
    mask_prob=0.2,
    feature_encoder=feature_encoder,
    seq_encoder=seq_encoder,
    trx_encoder=trx_encoder
)

In [None]:
from lightning.pytorch.loggers import WandbLogger

# Weights & Biases logger for the Tabformer run; no run name is given here.
# NOTE(review): WandbLogger comes from `lightning.pytorch` while `pl` is
# `pytorch_lightning` (see the import cell) -- confirm the two packages
# interoperate in the installed versions.
wandb_logger = WandbLogger(project="MBD_My_Code", log_model="all")

# 15 epochs, gradient clipping at 0.5, validation limited to 32 batches.
trainer = pl.Trainer(
    logger=wandb_logger,
    max_epochs=15,
    accelerator="cuda" if torch.cuda.is_available() else "cpu",
    enable_progress_bar=True,
    gradient_clip_val=0.5,
    log_every_n_steps=50,
    limit_val_batches=32
)

In [None]:
# Run Tabformer pre-training.
trainer.fit(pl_module, data_module)

**----------------------------------------------------------------**

In [None]:
from ptls.nn.seq_encoder.abs_seq_encoder import AbsSeqEncoder

class CrossEncoder(AbsSeqEncoder):
    """Custom transformer encoder intended to host a cross-attention mechanism.

    The encoder is a stack of ``AttentionBlock`` layers followed by an
    ``Aggregation`` over the time axis.

    Args:
        input_size: feature size of the input (and of the output).
        intermediate_size: hidden size of the feed-forward part of each block.
        num_hidden_layers: number of stacked attention blocks.
        num_attention_heads: number of heads in each block.
        attn_block_mode: block variant passed to ``AttentionBlock``.
        self_attn_mode: attention variant passed to ``AttentionBlock``.
        aggregation_mode: reduction used by ``Aggregation``.
        layer_norm: normalisation mode passed to ``AttentionBlock``.
        is_reduce_sequence: when True ``forward`` returns one vector per
            sequence, otherwise a ``PaddedBatch`` with per-step outputs.
    """

    def __init__(
         self,
         input_size: int,
         intermediate_size: int = 2048,
         num_hidden_layers: int = 8,
         num_attention_heads: int = 8,
         attn_block_mode: str = 'rezero',
         self_attn_mode: str = 'quadratic',
         aggregation_mode: str = 'mean',
         layer_norm=None,
         is_reduce_sequence=True
        ):
        super().__init__(is_reduce_sequence=is_reduce_sequence)

        self.input_size = input_size
        # nn.Sequential needs the modules as positional arguments, hence the
        # unpacked list.
        self.transformer = torch.nn.Sequential(
            *[
                AttentionBlock(
                    input_size, intermediate_size, attn_block_mode, self_attn_mode, layer_norm, num_attention_heads
                ) for _ in range(num_hidden_layers)
            ]
        )
        self.aggregation = Aggregation(reduction=aggregation_mode)
        self.is_reduce_sequence = is_reduce_sequence

    @property
    def embedding_size(self):
        """Size of the produced embedding (blocks keep the input feature size)."""
        return self.input_size

    def forward(self, xs: PaddedBatch, xl: PaddedBatch = None):
        """Encode ``xs`` with the transformer stack.

        Args:
            xs: batch whose payload is fed to the transformer.
            xl: second batch reserved for the cross-attention path.
                TODO: it is not used yet -- cross-attention between ``xs``
                and ``xl`` is not implemented.

        Returns:
            The aggregated tensor when ``is_reduce_sequence`` is True,
            otherwise a ``PaddedBatch`` holding the per-step outputs together
            with the sequence lengths of ``xs``.
        """
        out = self.transformer(xs.payload)
        if self.is_reduce_sequence:
            return self.aggregation(out)
        return PaddedBatch(out, xs.seq_lens)

class MLP(torch.nn.Module):
    """Feed-forward network with GELU activations.

    Layout: ``Linear(n_in, n_hidden) -> GELU``, then ``depth - 1`` hidden
    ``Linear(n_hidden, n_hidden) -> GELU`` pairs (each wrapped in its own
    ``Sequential``), then ``Linear(n_hidden, n_out)``.

    Args:
        n_in: number of input features.
        n_hidden: width of the hidden layers.
        n_out: number of output features.
        depth: ``depth - 1`` extra hidden pairs follow the input layer.
    """

    def __init__(self, n_in, n_hidden, n_out, depth=2):
        super().__init__()
        # Layers are created in forward order so that parameter registration
        # and random initialisation keep the same sequence.
        layers = [torch.nn.Linear(n_in, n_hidden), torch.nn.GELU()]
        for _ in range(depth - 1):
            hidden_pair = torch.nn.Sequential(
                torch.nn.Linear(n_hidden, n_hidden),
                torch.nn.GELU(),
            )
            layers.append(hidden_pair)
        layers.append(torch.nn.Linear(n_hidden, n_out))
        self.mlp = torch.nn.Sequential(*layers)

    def forward(self, X):
        """Apply the network to ``X`` (features on the last dimension)."""
        network = self.mlp
        return network(X)


class Attention(torch.nn.Module):
    """Self-attention layer with a selectable implementation.

    Args:
        embed_dim: feature size of the input.
        num_heads: number of attention heads.
        self_attn: implementation to use:
            ``"quadratic"`` -- ``torch.nn.MultiheadAttention`` (batch first);
            ``"linear-flow"`` -- ``FlowAttention`` (must be defined elsewhere);
            ``"linear-cross"`` -- not implemented yet, the layer acts as identity.

    Raises:
        ValueError: if ``self_attn`` is none of the modes above.
    """

    _MODES = ("quadratic", "linear-flow", "linear-cross")

    def __init__(self, embed_dim, num_heads, self_attn):
        super().__init__()

        if self_attn not in self._MODES:
            # A mistyped mode used to yield a layer that silently did nothing.
            raise ValueError(
                f"Unknown self_attn mode {self_attn!r}; expected one of {self._MODES}"
            )

        self.self_attn = self_attn
        if self_attn == "quadratic":
            # Fully qualified so the class does not depend on a separate
            # `from torch.nn import MultiheadAttention`.
            self.attn = torch.nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        elif self_attn == "linear-flow":
            self.attn = FlowAttention(embed_dim, num_heads)
        # "linear-cross": TODO, no sub-module is created yet.

    def forward(self, X):
        """Apply self-attention to ``X``; identity for ``"linear-cross"``."""
        if self.self_attn == "quadratic":
            X, _ = self.attn(X, X, X)
        elif self.self_attn == "linear-flow":
            X = self.attn(X)
        # "linear-cross": TODO, input is returned unchanged.

        return X


class AttentionBlock(torch.nn.Module):
    """Transformer block: self-attention plus a two-layer GELU feed-forward part.

    Args:
        embed_dim: feature size of the input and output.
        mlp_hidden_dim: hidden size of the feed-forward part.
        attn_block: ``"rezero"`` scales both residual branches by learnable
            scalars initialised close to zero; any other value uses plain
            residual connections.
        self_attn: attention implementation, forwarded to ``Attention``.
        layer_norm: ``'pre'``, ``'post'`` or a falsy value. The normalisation
            used is an L2 normalisation over the last dimension, not
            ``torch.nn.LayerNorm``. Any truthy value also normalises between
            the attention and the feed-forward part.
        num_heads: number of attention heads.
    """

    def __init__(self, embed_dim, mlp_hidden_dim, attn_block, self_attn, layer_norm, num_heads=4):
        super().__init__()

        self.attn_block = attn_block
        self.layer_norm = layer_norm

        if self.attn_block == "rezero":
            self.alpha_attn = torch.nn.Parameter(torch.normal(torch.tensor(0.), torch.tensor(1e-6)))
            self.alpha_mlp = torch.nn.Parameter(torch.normal(torch.tensor(0.), torch.tensor(1e-6)))

        self.attn = Attention(embed_dim, num_heads, self_attn)
        self.linear1 = torch.nn.Linear(embed_dim, mlp_hidden_dim)
        self.linear2 = torch.nn.Linear(mlp_hidden_dim, embed_dim)

    @staticmethod
    def _l2_normalize(tensor):
        """Scale every vector along the last dimension to (almost) unit length."""
        return tensor / (tensor.pow(2).sum(dim=-1, keepdim=True) + 1e-9).pow(0.5)

    def forward(self, X):
        """Apply the block to ``X`` (features on the last dimension)."""
        use_rezero = self.attn_block == "rezero"

        if self.layer_norm == 'pre':
            X = self._l2_normalize(X)

        attn_out = self.attn(X)
        if use_rezero:
            attn_out = self.alpha_attn * attn_out
        Z = X + attn_out

        if self.layer_norm:
            Z = self._l2_normalize(Z)

        # Fully qualified GELU so the block does not depend on a separate
        # `from torch.nn.functional import gelu`.
        mlp_out = self.linear2(torch.nn.functional.gelu(self.linear1(Z)))
        if use_rezero:
            mlp_out = self.alpha_mlp * mlp_out
        Z = Z + mlp_out

        if self.layer_norm == 'post':
            Z = self._l2_normalize(Z)

        return X + Z  # double residual


class Aggregation(torch.nn.Module):
    """Reduce dimension 1 of a tensor.

    Args:
        reduction: ``"mean"``, ``"sum"`` or ``"max"``; any other value selects
            the element at index 0 along dimension 1.
    """

    def __init__(self, reduction="mean"):
        super().__init__()
        self.reduction = reduction

    def forward(self, X):
        """Return ``X`` reduced over dimension 1 according to ``self.reduction``."""
        mode = self.reduction
        if mode == "mean":
            return X.mean(dim=1)
        if mode == "sum":
            return X.sum(dim=1)
        if mode == "max":
            return torch.max(X, dim=1).values
        # Fallback: take the first position.
        return X[:, 0]

In [None]:
# Optimizer / scheduler factories handed to CoLESModule below.
# NOTE(review): `partial` (functools) is not imported in the visible part of
# the notebook -- confirm it is imported in an earlier cell.
optimizer_partial = partial(
    torch.optim.AdamW,
    lr=0.001,
    weight_decay=1e-4
)

lr_scheduler_partial = partial(
    torch.optim.lr_scheduler.StepLR,
    step_size=2,
    gamma=0.9025
)

# NOTE(review): this cell follows the CrossEncoder definition but builds the
# stock RnnSeqEncoder; CrossEncoder is not used here -- confirm whether that
# is intended.
seq_encoder = ptls.nn.RnnSeqEncoder(
        trx_encoder=ptls.nn.TrxEncoder(
            norm_embeddings=False,
            embeddings_noise=0.003,
            # One 24-dimensional embedding per categorical feature.
            embeddings={
                'event_type': {"in": 58, "out": 24},
                'event_subtype': {"in": 59, "out": 24},
                'src_type11': {"in": 85, "out": 24},
                'src_type12': {"in": 349, "out": 24},
                'dst_type11': {"in": 84, "out": 24},
                'dst_type12': {"in": 417, "out": 24},
                'src_type22': {"in": 90, "out": 24},
                'src_type32': {"in": 91, "out": 24}
            },
            numeric_values={
                'amount': 'log'
            }
        ),
        type='gru',
        hidden_size=256
    )

# CoLES module validated with BatchRecallTopK (K=4, cosine).
pl_module = ptls.frames.coles.CoLESModule(
    validation_metric=ptls.frames.coles.metric.BatchRecallTopK(
        K=4,
        metric='cosine'
    ),
    seq_encoder=seq_encoder,
    optimizer_partial=optimizer_partial,
    lr_scheduler_partial=lr_scheduler_partial
)

# Results

**Baseline**

ROC-AUC target_0 = 0.6902993098249428

ROC-AUC target_1 = 0.8683783800091109

ROC-AUC target_2 = 0.7172133992848904

ROC-AUC target_3 = 0.7379677788955408

**CoLES with RegionalAttention (region=5)**

ROC-AUC target_0 = 0.6511195928124079

ROC-AUC target_1 = 0.6579967058724159

ROC-AUC target_2 = 0.6670043759461615

ROC-AUC target_3 = 0.7515082250999511

**CoLES with RegionalAttention (region=10)**

ROC-AUC target_0 = 0.6739120412871101

ROC-AUC target_1 = 0.706606596806981

ROC-AUC target_2 = 0.6463595608203014

ROC-AUC target_3 = 0.7783177461331754

**Cross-Attention with RnnEncoders**

sm = 4

big = 15 

ROC-AUC target_0 = 0.5912

ROC-AUC target_1 = 0.6111

ROC-AUC target_2 = 0.6073

ROC-AUC target_3 = 0.7231

**Cross-Attention with RnnEncoders**

sm = 5

big = 16 

ROC-AUC target_0 = 0.6163

ROC-AUC target_1 = 0.6415

ROC-AUC target_2 = 0.5926

ROC-AUC target_3 = 0.7196

**Cross-Attention with RnnEncoders**

sm = 3

big = 12 

ROC-AUC target_0 = 0.6200493519809905

ROC-AUC target_1 = 0.6213451372905097

ROC-AUC target_2 = 0.6341383714174937

ROC-AUC target_3 = 0.7323125712491827

**CustomEncoder baseline**

ROC-AUC target_0 = 0.6500493519809905

ROC-AUC target_1 = 0.6913451372905097

ROC-AUC target_2 = 0.6135956082030145

ROC-AUC target_3 = 0.7123157124918271

**GPT baseline**

ROC-AUC target_0 = 0.6700140091239141

ROC-AUC target_1 = 0.7719453941204589

ROC-AUC target_2 = 0.7441383714174937

ROC-AUC target_3 = 0.7423125712491827

**Cross-Attention with CustomEncoder (parameters still need tuning)**

ROC-AUC target_0 = 0.7000493519809905

ROC-AUC target_1 = 0.6513451372905097

ROC-AUC target_2 = 0.6941383714174937

ROC-AUC target_3 = 0.7723125712491827

**TODO: port to datafusion and build transformer blocks with cross-/regional attention**

**TODO: check the influence of the dialogs on each of the targets**