In [None]:
from huggingface_hub import hf_hub_download 

# Both archives live in the same Hugging Face dataset repo and are saved to the same
# directory, so the shared arguments are spelled out once.
_mbd_mini_download_kwargs = dict(
    repo_id="ai-lab/MBD-mini",
    repo_type="dataset",
    local_dir="/kaggle/working/",
)

hf_hub_download(filename="ptls.tar.gz", **_mbd_mini_download_kwargs)
hf_hub_download(filename="targets.tar.gz", **_mbd_mini_download_kwargs)

In [None]:
# Install Lightning; a later cell imports `WandbLogger` from `lightning.pytorch.loggers`.
!pip install lightning

In [None]:
import os

import wandb

# Credentials must never be stored in the notebook itself. The API key is read from the
# WANDB_API_KEY environment variable; when it is unset, `key=None` lets wandb fall back
# to its own credential lookup / interactive prompt.
# NOTE(review): the key that used to be hard-coded in this cell is exposed and should be revoked.
wandb.login(key=os.environ.get("WANDB_API_KEY"))

In [None]:
# Install PySpark (used below to join the modalities) and pytorch-lifestream
# (the distribution behind the `ptls` imports used throughout this notebook).
!pip install pyspark
!pip install pytorch-lifestream

In [None]:
# Unpack the two downloaded archives. The archive paths are relative, so this relies on
# the notebook's working directory being the download directory (/kaggle/working/).
!tar -xf ptls.tar.gz
!tar -xf targets.tar.gz

In [None]:
import pandas as pd
import numpy as np
import os

import pyspark
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
from pyspark.sql import types as T
import time
import datetime
from ptls.data_load.datasets import ParquetDataset, ParquetFiles
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, FloatType, ArrayType
from tqdm.notebook import tqdm
from ptls.preprocessing import PysparkDataPreprocessor
import pytorch_lightning as pl
from ptls.data_load.datasets import MemoryMapDataset
from ptls.data_load.iterable_processing import SeqLenFilter, FeatureFilter
from ptls.data_load.iterable_processing.iterable_seq_len_limit import ISeqLenLimit
from ptls.data_load.iterable_processing.to_torch_tensor import ToTorch
from ptls.frames.coles import CoLESModule
from ptls.frames import PtlsDataModule
from ptls.frames.coles import ColesDataset
from ptls.frames.coles.split_strategy import SampleSlices
import torch
import numpy as np
import pandas as pd
import calendar
from glob import glob
from ptls.data_load.utils import collate_feature_dict

from ptls.data_load.iterable_processing_dataset import IterableProcessingDataset
from datetime import datetime
from ptls.data_load.padded_batch import PaddedBatch

In [None]:
# Local Spark session used to join the modalities. Every setting is applied on a single
# chained SparkConf, in the same order as before.
spark_conf = (
    pyspark.SparkConf()
    .setMaster("local[*]")
    .setAppName("JoinModality")
    .set("spark.driver.maxResultSize", "16g")
    .set("spark.executor.memory", "32g")
    .set("spark.executor.memoryOverhead", "16g")
    .set("spark.driver.memory", "32g")
    .set("spark.driver.memoryOverhead", "16g")
    .set("spark.cores.max", "8")
    .set("spark.sql.shuffle.partitions", "200")
    .set("spark.local.dir", "../../spark_local_dir")
)

spark = SparkSession.builder.config(conf=spark_conf).getOrCreate()
# Show the effective configuration as the cell output.
spark.sparkContext.getConf().getAll()

In [None]:
!mkdir /kaggle/working/mm_dataset

In [None]:
# Per-modality input directories. The join cell below reads a `fold=<k>` sub-directory
# from the transaction (trx) and dialog (dial) paths.
TRX_DATA_PATH = '/kaggle/working/ptls/trx/'
# NOTE(review): the geo path is not read by any cell visible in this notebook.
GEO_DATA_PATH = '/kaggle/working/ptls/geo/'
DIAL_DATA_PATH = '/kaggle/working/ptls/dialog/'

# Output directory of the joined trx + dial dataset; written per fold and read back
# by the ParquetDataset cells.
MM_DATA_PATH = '/kaggle/working/mm_dataset'
# NOTE(review): not referenced in the visible cells; presumably the supervised variant — confirm.
MMT_DATA_PATH = '/kaggle/working/mm_dataset_supervised'

# NOTE(review): not referenced in the visible cells; presumably where targets.tar.gz unpacks — confirm.
TARGETS_DATA_PATH = '/kaggle/working/targets/'

In [None]:
def rename_col(df, prefix, col_id='client_id'):
    """Prefix every column of `df` except the id column with `<prefix>_`.

    Parameters
    ----------
    df:
        Object exposing `.columns` and `.withColumnRenamed(old, new)` (a Spark DataFrame
        in this notebook).
    prefix:
        Text placed before each renamed column, joined with an underscore.
    col_id:
        Name of the column that is left untouched.

    Returns
    -------
    The frame obtained after renaming the columns one by one.
    """
    # Snapshot the names first: the frame is re-bound on every rename.
    columns_to_prefix = [name for name in df.columns if name != col_id]
    renamed = df
    for name in columns_to_prefix:
        renamed = renamed.withColumnRenamed(name, f"{prefix}_{name}")
    return renamed

In [None]:
from ptls.preprocessing import PysparkDataPreprocessor
from pyspark.sql.functions import explode, col


# For every fold: read the transaction and dialog parquet data, prefix their columns with
# the modality name, outer-join them on the client id and write the result as one fold
# of the multimodal dataset.
for fold in tqdm(range(5)):
    fold_dir = f'fold={fold}'

    trx = rename_col(spark.read.parquet(os.path.join(TRX_DATA_PATH, fold_dir)), 'trx')
    dial = rename_col(spark.read.parquet(os.path.join(DIAL_DATA_PATH, fold_dir)), 'dial')

    # Outer join keeps clients present in only one modality; two trx columns are dropped.
    mm_dataset = (
        trx
        .join(dial, on='client_id', how='outer')
        .drop('trx_src_type21', 'trx_src_type31')
    )
    mm_dataset.write.mode('overwrite').parquet(os.path.join(MM_DATA_PATH, fold_dir))

    del trx, dial, mm_dataset

In [None]:
# spark.stop()

In [None]:
import pandas as pd
import numpy as np
from ptls.data_load.iterable_processing_dataset import IterableProcessingDataset
from ptls.data_load import IterableChain
from datetime import datetime
from ptls.data_load.datasets.parquet_dataset import ParquetDataset, ParquetFiles
from ptls.data_load.iterable_processing.feature_filter import FeatureFilter
from ptls.data_load.iterable_processing.to_torch_tensor import ToTorch
import torch
from functools import partial
from torch.utils.data import DataLoader
from ptls.data_load.padded_batch import PaddedBatch
from ptls.data_load.utils import collate_feature_dict
from tqdm import tqdm


class TargetToTorch(IterableProcessingDataset):
    """Iterable filter that turns one column of every record into a `torch.Tensor`."""

    def __init__(self, col_target):
        """
        Parameters
        ----------
        col_target:
            Key of the feature-dict entry to convert.
        """
        super().__init__()
        self.col_target = col_target

    def __iter__(self):
        """Yield each source record with `col_target` stacked and converted to a tensor.

        The record's feature dict is modified in place (no copy is taken).
        """
        for rec in self._src:
            # The source may yield either a bare feature dict or a tuple whose first item is one.
            if type(rec) is tuple:
                features = rec[0]
            else:
                features = rec
            stacked = np.stack(np.array(features[self.col_target]))
            features[self.col_target] = torch.tensor(stacked)
            yield features

class DeleteNan(IterableProcessingDataset):
    """Iterable filter that drops records whose `col_name` value is `None`."""

    def __init__(self, col_name):
        """
        Parameters
        ----------
        col_name:
            Key of the feature-dict entry that must not be `None`.
        """
        super().__init__()
        self.col_name = col_name

    def __iter__(self):
        """Yield only the feature dicts whose `col_name` entry is present (not `None`)."""
        for rec in self._src:
            # The source may yield either a bare feature dict or a tuple whose first item is one.
            if type(rec) is tuple:
                features = rec[0]
            else:
                features = rec
            if features[self.col_name] is None:
                continue
            yield features


class DialToTorch(IterableProcessingDataset):
    """Iterable filter that fills missing dialog fields and converts the embeddings to a tensor."""

    def __init__(self, col_time, col_embeds):
        """
        Parameters
        ----------
        col_time:
            Key of the dialog event-time entry.
        col_embeds:
            Key of the dialog embedding entry.
        """
        super().__init__()
        self._year = 2022  # stored but not read anywhere in this class
        self.col_embeds = col_embeds
        self.col_time = col_time

    def __iter__(self):
        """Yield a shallow copy of each record with placeholder/torch-converted dialog fields."""
        for rec in self._src:
            source = rec[0] if type(rec) is tuple else rec
            features = source.copy()

            # A missing time column becomes a single zero timestamp.
            if features[self.col_time] is None:
                features[self.col_time] = torch.tensor([0])

            # A missing embedding column becomes a 768-element zero vector.
            # NOTE(review): real embeddings look like (seq_len, dim) after the conversion
            # below, while this placeholder is 1-D — confirm the intended shape.
            embeds = features[self.col_embeds]
            if embeds is None:
                embeds = torch.zeros(768)

            # Round-trip through nested lists so the embeddings end up as one dense tensor.
            features[self.col_embeds] = torch.tensor(embeds.tolist())

            yield features

class GetSplit(IterableProcessingDataset):
    """Iterable filter that expands every record into one record per month.

    For each month in `[start_month, end_month]` the emitted record keeps only the
    events that happened before the first day of the following month, takes the
    target entry of that month, and gets `_month=<m>` appended to its id.
    """

    def __init__(
        self,
        start_month,
        end_month,
        year=2022,
        col_id='client_id',
        col_time='event_time'
    ):
        """
        Parameters
        ----------
        start_month, end_month:
            First and last month (inclusive) to generate records for.
        year:
            Calendar year used to build the monthly cut-off timestamps.
        col_id:
            Key of the record id; it is suffixed, never masked.
        col_time:
            Key of the event-time entry the cut-off is compared against.
        """
        super().__init__()
        self.start_month = start_month
        self.end_month = end_month
        self._year = year
        self._col_id = col_id
        self._col_time = col_time

    def __iter__(self):
        """Yield one masked shallow copy of each source record per month."""
        for rec in self._src:
            source = rec[0] if type(rec) is tuple else rec

            for month in range(self.start_month, self.end_month + 1):
                features = source.copy()

                # Cut-off is the first day of the next month (January of the next year for December).
                if month == 12:
                    cutoff = datetime(self._year + 1, 1, 1)
                else:
                    cutoff = datetime(self._year, month + 1, 1)
                month_event_time = cutoff.timestamp()

                # Computed before the loop so the time column itself is masked with it too.
                mask = features[self._col_time] < month_event_time

                for key, value in features.items():
                    if key.startswith('target'):
                        # Targets are indexed by month (1-based month -> 0-based position).
                        features[key] = value[month - 1].tolist()
                        continue
                    if key == self._col_id:
                        continue
                    features[key] = value[mask]

                features[self._col_id] += f'_month={month}'

                yield features

In [None]:
from ptls.data_load.datasets import ParquetDataset
from ptls.data_load.iterable_processing import SeqLenFilter
from ptls.data_load.iterable_processing.iterable_seq_len_limit import ISeqLenLimit
from ptls.data_load.iterable_processing.to_torch_tensor import ToTorch
def _mm_record_filters():
    """Return a fresh list of the record-level filters shared by the train and valid datasets."""
    return [
        DeleteNan('trx_event_time'),
        DeleteNan('dial_event_time'),
        SeqLenFilter(min_seq_len=8),
        ISeqLenLimit(max_seq_len=128),
        ToTorch(),
        DialToTorch(col_time='dial_event_time', col_embeds='dial_embedding'),
        # GetSplit(start_month=1, end_month=12, col_id='client_id') is intentionally left disabled.
    ]


# Folds 0-2 are used for training, fold 3 for validation.
train = ParquetDataset(
    data_files=[os.path.join(MM_DATA_PATH, f'fold={fold}') for fold in (0, 1, 2)],
    i_filters=_mm_record_filters(),
    shuffle_files=True
)
valid = ParquetDataset(
    data_files=[os.path.join(MM_DATA_PATH, f'fold={3}')],
    i_filters=_mm_record_filters(),
    shuffle_files=False
)

In [None]:
from ptls.data_load.feature_dict import FeatureDict
from collections import defaultdict
from functools import reduce
import torch.nn.functional as F


# def split_and_pad(tensor: torch.Tensor, segment_length):
#     segments = [tensor[i:i + segment_length] for i in range(0, len(tensor), segment_length)]
#     padded_segments = [F.pad(segment, (0, segment_length - len(segment)), mode='constant') for segment in segments]
#     return torch.vstack(padded_segments)

# def get_regional_splits(batch: list[dict], segment_length):
#     regional_tokens = []
#     for item in batch:
#         segmented_data = {key: split_and_pad(tensor, segment_length) for key, tensor in item.items()}
#         regional_tokens.append(segmented_data)
#     return regional_tokens

class MultiModalDiffSplitDataset(FeatureDict, torch.utils.data.Dataset):
    """Multimodal dataset that samples sub-sequences with a separate splitter per source."""

    def __init__(
        self,
        data,
        splitters,
        source_features,
        col_id,
        source_names,
        col_time='event_time',
        *args, **kwargs
    ):
        """
        Dataset for multimodal learning.

        Parameters
        ----------
        data:
            Concatenated data with feature dicts. Feature keys are expected in the form
            `<source>_<feature>` (see `get_names`).
        splitters:
            Mapping from source name to an object from `ptls.frames.coles.split_strategy`.
            Each splitter is used to split the original sequence of its source into
            sub-sequences, which are samples from one client.
        source_features:
            Mapping from source name to the list of its feature names; used to build
            empty placeholders for a source that is absent from a record.
        col_id:
            Column name with the user id.
        source_names:
            List of source names every record is expected to provide.
        col_time:
            Column name (without the source prefix) with the event time.
        """
        super().__init__(*args, **kwargs)

        self.data = data
        self.splitters = splitters
        self.col_time = col_time
        self.col_id = col_id
        self.source_names = source_names
        self.source_features = source_features

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        feature_arrays = self.data[idx]
        split_data = self.split_source(feature_arrays)
        return self.get_splits(split_data)

    def __iter__(self):
        for feature_arrays in self.data:
            split_data = self.split_source(feature_arrays)
            yield self.get_splits(split_data)

    def split_source(self, feature_arrays):
        """Regroup a flat `<source>_<feature>` dict into `{source: {feature: values}}`.

        The id column is kept at the top level under `col_id`. Sources listed in
        `source_names` that are missing from the record get empty-tensor placeholders
        for each feature in `source_features`.
        """
        res = defaultdict(dict)
        for feature_name, feature_array in feature_arrays.items():
            if feature_name == self.col_id:
                res[self.col_id] = feature_array
            else:
                source_name, feature_name_transform = self.get_names(feature_name)
                res[source_name][feature_name_transform] = feature_array
        for source in self.source_names:
            if source not in res:
                res[source] = {source_feature: torch.tensor([]) for source_feature in self.source_features[source]}
        return res

    def get_names(self, feature_name):
        """Split `feature_name` at its first underscore into `(source, feature)`."""
        idx_del = feature_name.find('_')
        return feature_name[:idx_del], feature_name[idx_del + 1:]

    def get_splits(self, feature_arrays):
        """Apply each source's splitter and return `{source: [sub-sequence dict, ...]}`.

        Sources without a configured splitter are skipped.
        """
        res = {}
        for source_name, feature_array in feature_arrays.items():
            if source_name == self.col_id:
                continue
            # Check for a splitter BEFORE reading the time column: a source that has no
            # splitter (and possibly no time column) must be skipped, not raise KeyError.
            if source_name not in self.splitters:
                continue
            local_date = feature_array[self.col_time]
            indexes = self.splitters[source_name].split(local_date)
            res[source_name] = [
                {k: v[ix] for k, v in feature_array.items() if self.is_seq_feature(k, v)}
                for ix in indexes
            ]
        return res

    # Returns dialogs together with transactions (without targets).
    def collate_fn(self, batch, return_dct_labels=False):
        """Collate a list of per-client split dicts into padded batches plus class labels.

        Parameters
        ----------
        batch:
            List of `{source: [sub-sequence dict, ...]}` items as produced by `get_splits`.
        return_dct_labels:
            When True, return the labels of every source as a dict; otherwise return
            only the labels of the first source.

        Returns
        -------
        `(padded_batch, labels)` where `padded_batch` maps each source to its collated
        batch and labels identify which client each sub-sequence came from.
        """
        dict_class_labels = get_dict_class_labels(batch)
        # Concatenate the per-client lists of sub-sequences source by source; only
        # sources present in every item survive.
        batch = reduce(lambda x, y: {k: x[k] + y[k] for k in x if k in y}, batch)
        padded_batch = collate_multimodal_feature_dict(batch)
        if return_dct_labels:
            return padded_batch, dict_class_labels
        return padded_batch, dict_class_labels[list(dict_class_labels.keys())[0]]

def collate_multimodal_feature_dict(batch):
    """Collate every source of a multimodal batch independently.

    `batch` maps a source name to the list of feature dicts of that source; the result
    maps the same source names to whatever `collate_feature_dict` returns for each list.
    """
    return {
        source: collate_feature_dict(source_batch)
        for source, source_batch in batch.items()
    }

def collate_feature_dict(batch):
    """Collate a list of feature dicts into one `PaddedBatch`.

    This definition replaces the `collate_feature_dict` imported from `ptls` earlier in
    the notebook. Tensor features are padded to a common length, except keys starting
    with `target`, which are stacked. Lists of numpy arrays are passed through as-is.
    Everything else goes through `np.array` and becomes a long/float/bool tensor when
    its dtype is integer/float/boolean, or stays a numpy array otherwise.

    Sequence lengths are taken from the first sequence feature found in `batch[0]`.
    """
    grouped = defaultdict(list)
    for record in batch:
        for name, value in record.items():
            grouped[name].append(value)

    seq_col = next(
        name for name, value in batch[0].items() if FeatureDict.is_seq_feature(name, value)
    )
    lengths = torch.LongTensor([len(record[seq_col]) for record in batch])

    # numpy dtype kind -> name of the tensor method that casts to the matching torch dtype
    tensor_casts = {'i': 'long', 'f': 'float', 'b': 'bool'}

    collated = {}
    for name, values in grouped.items():
        head = values[0]
        if type(head) is torch.Tensor:
            if name.startswith('target'):
                collated[name] = torch.stack(values, dim=0)
            else:
                collated[name] = torch.nn.utils.rnn.pad_sequence(values, batch_first=True)
            continue
        if type(head) is np.ndarray:
            collated[name] = values  # list of arrays[object]
            continue

        as_array = np.array(values)
        cast_name = tensor_casts.get(as_array.dtype.kind)
        if cast_name is None:
            collated[name] = as_array
        else:
            collated[name] = getattr(torch.from_numpy(as_array), cast_name)()
    return PaddedBatch(collated, lengths)
    
def get_dict_class_labels(batch):
    """Build, per source, the class label of every sub-sequence in the batch.

    `batch` is a list of `{source: iterable of sub-sequences}` items. Each sub-sequence
    of item `i` gets label `i`. Returns a plain dict mapping each source that has at
    least one sub-sequence to a `torch.LongTensor` of labels.
    """
    labels_by_source = {}
    for sample_idx, sample in enumerate(batch):
        for source, splits in sample.items():
            repeated = [sample_idx for _ in splits]
            # A source only appears in the result once it contributes a label.
            if repeated:
                labels_by_source.setdefault(source, []).extend(repeated)
    return {source: torch.LongTensor(ids) for source, ids in labels_by_source.items()}

class MultiModalDiffSplitIterableDataset(MultiModalDiffSplitDataset, torch.utils.data.IterableDataset):
    """Iterable-style variant of `MultiModalDiffSplitDataset`; adds no behaviour of its own."""

In [None]:
from ptls.frames.coles import MultiModalIterableDataset
from lightning.pytorch.loggers import WandbLogger
import ptls

def _build_multimodal_dataset(data, split_count, trx_cnt_min, trx_cnt_max):
    """Wrap `data` into a `MultiModalDiffSplitIterableDataset` over the trx and dial sources.

    Only the number of sampled slices and the transaction slice length bounds differ
    between the train and validation datasets; everything else is shared.
    """
    return MultiModalDiffSplitIterableDataset(
        data=data,
        splitters={
            'trx': SampleSlices(
                split_count=split_count,
                cnt_min=trx_cnt_min,
                cnt_max=trx_cnt_max
            ),
            'dial': SampleSlices(
                split_count=split_count,
                cnt_min=2,
                cnt_max=10
            ),
        },
        source_features={
            "trx": [
                "event_type",
                "event_subtype",
                "src_type11",
                "src_type12",
                "dst_type11",
                "dst_type12",
                "src_type22",
                "src_type32"
            ],
            "dial": [
                "embedding"
            ],
        },
        col_id='client_id',
        col_time='event_time',
        source_names=['trx', 'dial'],
    )


data_module = PtlsDataModule(
    train_data=_build_multimodal_dataset(train, split_count=3, trx_cnt_min=16, trx_cnt_max=90),
    valid_data=_build_multimodal_dataset(valid, split_count=2, trx_cnt_min=5, trx_cnt_max=64),
    train_batch_size=64,
    train_num_workers=0,
    valid_batch_size=64,
    valid_num_workers=0
)

In [None]:
# train_data_test=MultiModalDiffSplitIterableDataset(
#         data=train,
#         splitters= {
#             'trx': SampleSlices(
#                 split_count=3,
#                 cnt_min=16,
#                 cnt_max=90
#             ),
#             'dial': SampleSlices(
#                 split_count=3,
#                 cnt_min=2,
#                 cnt_max=10
#             ),
#         },
#         source_features={
#             "trx": [
#                 "event_type",
#                 "event_subtype",
#                 "src_type11",
#                 "src_type12",
#                 "dst_type11",
#                 "dst_type12",
#                 "src_type22",
#                 "src_type32"
#             ],
#             "dial": [
#                 "embedding"
#             ],
#         },
#         col_id='client_id',
#         col_time='event_time',
#         source_names=['trx', 'dial'],
#     )

# next(iter(data_module.train_dl(train_data_test)))

# Add region attention in trx encoder

In [None]:
# x_new = torch.rand((192, 87, 32))
# segment_length = 10
# pad_length = (segment_length - (x_new.size()[1] % segment_length)) % segment_length
# padded_x_new = F.pad(x_new, ((0, 0, 0, pad_length, 0, 0)), 'constant', 0)
# segmented_tensors = torch.stack(torch.split(padded_x_new, segment_length, dim=1), dim=0)
# segmented_tensors = segmented_tensors.permute(0, 2, 1, 3)
# segmented_tensors.size()

# a = torch.Tensor([[1, 2, 3], [4, 5, 6]])
# b = torch.Tensor([[4, 5, 6]])
# a = torch.cat((a, b), dim=0)
# a
# torch.Tensor([])

# Scratch check of broadcasting: a (128, 64) tensor gets a trailing singleton axis so it
# can be added to every channel of a (128, 64, 32) tensor.
a = torch.rand((128, 64, 32))
b = torch.rand((128, 64))

b_expanded = b.unsqueeze(-1)
print(b_expanded.size())

c = a + b_expanded
print(a)
print(c)

In [None]:
import torch

# from ptls.constant_repository import TORCH_EMB_DTYPE
from ptls.data_load import PaddedBatch
from ptls.nn.seq_encoder.rnn_encoder import RnnEncoder
from ptls.nn.seq_encoder.transformer_encoder import TransformerEncoder
from ptls.nn.seq_encoder.longformer_encoder import LongformerEncoder
from ptls.nn.seq_encoder.custom_encoder import Encoder
from ptls.nn.trx_encoder import TrxEncoder
from ptls.nn.seq_encoder.containers import SeqEncoderContainer

class RnnSeqEncoderRegAttn(SeqEncoderContainer):
    """RNN sequence encoder with an additional "regional" attention term.

    The transaction embeddings are cut into fixed-length segments along the time axis,
    each segment is encoded by a second RNN (`reg_seq_encoder`), the segment encodings
    go through self-attention, and the result is added to the transaction embeddings
    before the main RNN encoder runs.
    """

    def __init__(self,
                 trx_encoder=None,
                 input_size=None,
                 is_reduce_sequence=True,
                 **seq_encoder_params,
                 ):
        """
        Parameters
        ----------
        trx_encoder:
            Encoder applied to the raw batch first; its `output_size` is used as the RNN
            input size when `input_size` is None.
        input_size:
            Explicit RNN input size.
        is_reduce_sequence:
            Forwarded to both RNN encoders.
        **seq_encoder_params:
            Extra keyword arguments forwarded to both `RnnEncoder` instances.
        """
        super().__init__(
            trx_encoder=trx_encoder,
            seq_encoder_cls=RnnEncoder,
            input_size=input_size,
            seq_encoder_params=seq_encoder_params,
            is_reduce_sequence=is_reduce_sequence,
        )

        self.reg_seq_encoder = RnnEncoder(
            input_size=input_size if input_size is not None else trx_encoder.output_size,
            is_reduce_sequence=is_reduce_sequence,
            **seq_encoder_params,
        )

        # Width the regional embeddings are zero-padded to before attention.
        self.emb_dim = 192
        # Referenced through `torch.nn` so the class does not depend on a `nn` alias
        # that is only imported in a later cell.
        self.regional_attention = torch.nn.MultiheadAttention(
            embed_dim=self.emb_dim,
            num_heads=8,
            dropout=0.3,
            batch_first=True
        )

    def forward(self, x, names=None, seq_len=None, h_0=None):
        """Encode a batch; `names` and `seq_len` are accepted but not used here.

        NOTE(review): the shape comments below assume the trx encoder payload is
        (batch, time, hidden) — confirm against the trx encoder in use.
        """
        # `torch.nn.functional` is used explicitly: the notebook binds the name `F` to both
        # `pyspark.sql.functions` and `torch.nn.functional` in different cells.
        functional = torch.nn.functional

        x = self.trx_encoder(x)

        x_new = x.payload

        # Pad the time axis (dim 1) up to a multiple of the segment length, then split it.
        segment_length = 5
        pad_length = (segment_length - (x_new.size()[1] % segment_length)) % segment_length
        padded_x_new = functional.pad(x_new, (0, 0, 0, pad_length, 0, 0), 'constant', 0)
        segmented_tensors = torch.stack(torch.split(padded_x_new, segment_length, dim=1)).to(x_new.device)

        regional_embeddings = torch.Tensor().to(x.device)
        for tensor in segmented_tensors:
            # The first two axes are swapped before encoding, so the regional RNN sees
            # `segment_length` sequences whose length is the original first dimension.
            tensor = PaddedBatch(tensor.permute(1, 0, 2), [tensor.size()[0]] * tensor.size()[1])
            regional_embed = self.reg_seq_encoder(tensor)
            regional_embeddings = torch.cat((regional_embeddings, regional_embed), 0)

        # Drop the rows that correspond to the time-axis padding added above.
        regional_embeddings = regional_embeddings[:len(regional_embeddings) - pad_length, :]
        # Normalise over the whole 2-D tensor. This is what the previously commented-out
        # `nn.LayerNorm([rows, cols])` computed at initialisation (weight 1, bias 0); the
        # name `layer_norm` was otherwise undefined here and raised NameError.
        regional_embeddings = functional.layer_norm(regional_embeddings, tuple(regional_embeddings.shape))

        # Zero-pad the feature axis up to the attention width.
        if regional_embeddings.size()[1] != self.emb_dim:
            regional_embeddings = functional.pad(regional_embeddings, (0, abs(regional_embeddings.size()[1] - self.emb_dim), 0, 0), 'constant', 0)
        x_reg_embed, _ = self.regional_attention(regional_embeddings, regional_embeddings, regional_embeddings)
        # Trim the attention output's last axis to the batch size so it can be broadcast.
        # NOTE(review): this only works while the batch size does not exceed `emb_dim` — confirm.
        if regional_embeddings.size()[1] != x_new.size()[0]:
            x_reg_embed = x_reg_embed[:, :-abs(regional_embeddings.size()[1] - x_new.size()[0])]
        x_reg_embed = x_reg_embed.permute(1, 0)
        x_reg_embed = x_reg_embed[:, :, None]
        x_new = x_new + x_reg_embed
        # (A discarded `x_new.to(x.device)` call used to follow; `Tensor.to` is not in-place.)
        x_new = PaddedBatch(x_new, x.seq_lens)
        x = self.seq_encoder(x_new, h_0)
        return x

In [None]:
from ptls.frames.abs_module import ABSModule
from ptls.frames.coles.metric import metric_recall_top_K, outer_cosine_similarity, outer_pairwise_distance
from ptls.frames.coles.losses import ContrastiveLoss
from torch import nn
from ptls.nn.seq_encoder.custom_encoder import MLP
import math


# class CrossAttention(nn.Module):
#     def __init__(self, d_in, d_out_kq, d_out_v):
#         super().__init__()
#         self.d_out_kq=d_out_kq
#         self.W_query=nn.Parameter(torch.rand(d_in, d_out_kq))
#         self.W_key  = nn.Parameter(torch.rand(d_in, d_out_kq))
#         self.W_value=nn.Parameter(torch.rand(d_in, d_out_v))
    
#     def forward(self, x_1, x_2):
#         queries_1=x_1.matmul(self.W_query)
#         keys_2=x_2.matmul(self.W_key)
#         values_2=x_2.matmul(self.W_value)
        
#         attn_scores=queries_1.matmul(keys_2.T)
#         attn_weights=torch.softmax(
#             attn_scores/self.d_out_kq**0.5, dim=-1
#         )
        
#         context_vec=attn_weights.matmul(values_2)
#         return context_vec

def first(iterable, default=None):
    """Return the first item produced by `iterable`, or `default` when it is empty."""
    for item in iterable:
        return item
    return default


# class PositionalEncoding(nn.Module):
#     def __init__(self,
#                  d_model,
#                  use_start_random_shift=True,
#                  max_len=5000,
#                  ):
#         super().__init__()
#         self.use_start_random_shift = use_start_random_shift
#         self.max_len = max_len

#         pe = torch.zeros(max_len, d_model)
#         position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
#         div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
#         pe[:, 0::2] = torch.sin(position * div_term)
#         pe[:, 1::2] = torch.cos(position * div_term)
#         pe = pe.unsqueeze(0)
#         self.register_buffer('pe', pe)

#     def forward(self, x):
#         T = x.size(0)
#         if self.training and self.use_start_random_shift:
#             start_pos = random.randint(0, self.max_len - T)
#         else:
#             start_pos = 0
#         # print(f'{x.size()=}')
#         # print(f'{self.pe.size()}')
#         # print(f'{self.pe[:, start_pos:start_pos + T].size()=}')
#         x = x + self.pe[:, start_pos:start_pos + T]
#         return x

# Check that each epoch contains a different user
# Make the positive + negative selection random
# Add self-attention over the unimodal representations

class M3CoLESModule(ABSModule):
    """
    Multi-Modal Matching
    Contrastive Learning for Event Sequences ([CoLES](https://arxiv.org/abs/2002.08232))

    Subsequences are sampled from original sequence.
    Samples from the same sequence are `positive` examples
    Samples from the different sequences are `negative` examples
    Embeddings for all samples are calculated.
    Paired distances between all embeddings are calculated.
    The loss function tends to make positive distances smaller and negative ones larger.

    Parameters
        seq_encoders:
            Mapping from modality name to the model which calculates embeddings for that
            modality. A value may also be a string naming another key of the mapping; the
            two modalities then share that encoder. Required.
        mod_names:
            Accepted for interface compatibility; not used by this class.
        head:
            Model which helps to train. Not used during inference.
            `ptls.nn.Head(use_norm_encoder=True)` is used when None.
        loss:
            loss object from `ptls.frames.coles.losses`.
            `ContrastiveLoss(margin=0.5)` with a `HardNegativePairSelector(neg_count=5)`
            sampling strategy is used when None.
        validation_metric:
            `BatchRecallTopK(K=4, metric='cosine')` is used when None.
        optimizer_partial:
            optimizer init partial. Network parameters are missed.
        lr_scheduler_partial:
            scheduler init partial. Optimizer are missed.

    """
    def __init__(self,
                 seq_encoders=None,
                 mod_names=None,
                 head=None,
                 loss=None,
                 validation_metric=None,
                 optimizer_partial=None,
                 lr_scheduler_partial=None):
        torch.set_float32_matmul_precision('high')
        if seq_encoders is None:
            raise ValueError("`seq_encoders` must be a mapping from modality name to encoder")

        if head is None:
            head = ptls.nn.Head(use_norm_encoder=True)

        if loss is None:
            # Imported here so the default works even if the notebook never imported it.
            from ptls.frames.coles.sampling_strategies import HardNegativePairSelector
            loss = ContrastiveLoss(margin=0.5,
                                   sampling_strategy=HardNegativePairSelector(neg_count=5))

        if validation_metric is None:
            from ptls.frames.coles.metric import BatchRecallTopK
            validation_metric = BatchRecallTopK(K=4, metric='cosine')

        # Resolve string aliases on a copy so the caller's mapping is left untouched.
        seq_encoders = dict(seq_encoders)
        for k in seq_encoders.keys():
            if type(seq_encoders[k]) is str:
                seq_encoders[k] = seq_encoders[seq_encoders[k]]

        super().__init__(validation_metric,
                         first(seq_encoders.values()),
                         loss,
                         optimizer_partial,
                         lr_scheduler_partial)

        # Cross-modal attention layers and per-modality FFN heads. They are registered on
        # the module but are not called by `forward` in this class.
        self.mha_trx_dial = nn.MultiheadAttention(
            embed_dim=128,
            num_heads=8,
            dropout=0.3,
            batch_first=True
        )
        self.mha_dial_trx = nn.MultiheadAttention(
            embed_dim=128,
            num_heads=8,
            dropout=0.1,
            batch_first=True
        )

        # FFN variant
        self.head_trx = MLP(
                n_in=128,
                n_hidden=256,
                n_out=128
            )
        self.head_dial = MLP(
                n_in=128,
                n_hidden=256,
                n_out=128
            )

        self.seq_encoders = torch.nn.ModuleDict(seq_encoders)
        self._head = head
        # Per-stage cache of validation outputs, consumed in `on_validation_epoch_end`.
        self.y_h_cache = {'train': [], 'valid': []}

    @property
    def metric_name(self):
        return 'recall_top_k'

    @property
    def is_requires_reduced_sequence(self):
        return True

    def forward(self, x):
        """Encode every modality of `x` with its own encoder; returns `{modality: embeddings}`."""
        res = {}
        for mod_name in x.keys():
            res[mod_name] = self.seq_encoders[mod_name](x[mod_name])
        return res

    def shared_step(self, x, y):
        """Run the encoders and, when a head is set, apply it to every modality's output."""
        y_h = self(x)

        if self._head is not None:
            y_h_head = {k: self._head(y_h_k) for k, y_h_k in y_h.items()}
            y_h = y_h_head
        return y_h, y

    def _one_step(self, batch, _, stage):
        """Compute and log the loss for one batch; cache embeddings during validation."""
        y_h, y = self.shared_step(*batch)
        y_h_list = list(y_h.values())
        # Labels are repeated once per modality; this assumes exactly two modalities.
        loss = self._loss(torch.cat(y_h_list), torch.cat([y, y]))
        self.log(f'loss/{stage}', loss.detach())

        x, y = batch
        for mod_name, mod_x in x.items():
            self.log(f'seq_len/{stage}/{mod_name}', mod_x.seq_lens.float().mean().detach(), prog_bar=True)

        if stage == "valid":
            n, d = y_h_list[0].shape
            # Interleave the two modalities row by row: even rows <- first, odd rows <- second.
            y_h_concat = torch.zeros((2 * n, d), device=y_h_list[0].device)
            for i in range(2):
                y_h_concat[range(i, 2 * n, 2)] = y_h_list[i]
            # Bound memory: stop caching once 381 batches are stored.
            if len(self.y_h_cache[stage]) <= 380:
                self.y_h_cache[stage].append((
                    y_h_concat.cpu(),
                    {k: y_h_k.cpu() for k, y_h_k in y_h.items()},
                    {k: x_k.seq_lens.cpu() for k, x_k in x.items()},
                ))
        return loss

    def training_step(self, batch, _):
        return self._one_step(batch, _, "train")

    def validation_step(self, batch, _):
        return self._one_step(batch, _, "valid")

    def on_validation_epoch_end(self):
        """Log cross-modal recall at K = 30, 20 and 1 over the cached batches, then reset the cache."""
        for top_k in (30, 20, 1):
            self.log_recall_top_K(self.y_h_cache['valid'], len_intervals=None, stage="valid", K=top_k)

        self.y_h_cache["valid"] = []

    def log_recall_top_K(self, y_h_cache, len_intervals=None, stage="valid", K=15):
        """Log recall@K between the two modalities' cached embeddings.

        Parameters
        ----------
        y_h_cache:
            List of `(interleaved embeddings, {modality: embeddings}, {modality: seq_lens})`
            tuples as stored by `_one_step`. Only the two dicts are used here.
        len_intervals:
            Optional list of `(start, end]` sequence-length intervals; when given, the
            recall is additionally logged per modality and interval.
        stage:
            Prefix of the logged metric names.
        K:
            Number of nearest candidates considered.
        """
        # Nothing was cached (e.g. validation produced no batches): nothing to log.
        if not y_h_cache:
            return

        y_h_mods = defaultdict(list)
        seq_lens_dict = defaultdict(list)

        for item in y_h_cache:
            for k, emb in item[1].items():
                y_h_mods[k].append(emb)

            for k, l in item[2].items():
                seq_lens_dict[k].append(l)

        y_h_mods = {k: torch.cat(el, dim=0) for k, el in y_h_mods.items()}
        seq_lens_dict = {k: torch.cat(el) for k, el in seq_lens_dict.items()}

        # Exactly two modalities are expected, in the order they appear in the batch.
        y_h_bank, y_h_rmb = list(y_h_mods.values())
        # Row i of one modality matches row i of the other.
        computed_metric_b2r = metric_recall_top_K_for_embs(y_h_bank, y_h_rmb, torch.arange(y_h_rmb.shape[0]), K=K)
        computed_metric_r2b = metric_recall_top_K_for_embs(y_h_rmb, y_h_bank, torch.arange(y_h_rmb.shape[0]), K=K)

        if len_intervals is not None:
            for mod, seq_lens in seq_lens_dict.items():
                for start, end in len_intervals:
                    mask = ((seq_lens > start) & (seq_lens <= end))

                    if torch.any(mask):
                        y_h_bank_filtered = y_h_bank[mask]
                        y_h_rmb_filtered = y_h_rmb[mask]

                        true_matches = torch.arange(y_h_rmb_filtered.shape[0])
                        recall_r2b = metric_recall_top_K_for_embs(y_h_rmb_filtered, y_h_bank_filtered, true_matches, K=K)
                        recall_b2r = metric_recall_top_K_for_embs(y_h_bank_filtered, y_h_rmb_filtered, true_matches, K=K)

                        # Both directions are logged (the r2b value used to be passed to
                        # `print(..., prog_bar=True)`, which raises TypeError), and the
                        # metric name carries the K that was actually used.
                        self.log(f"{stage}/{mod}/r2b_R@{K}_len_from_{start}_to_{end}", recall_r2b, prog_bar=True)
                        self.log(f"{stage}/{mod}/b2r_R@{K}_len_from_{start}_to_{end}", recall_b2r, prog_bar=True)

        self.log(f"{stage}/click2trx_R@{K}", computed_metric_r2b, prog_bar=True)
        self.log(f"{stage}/trx2click_R@{K}", computed_metric_b2r, prog_bar=True)

In [None]:
def metric_real_recall_top_K(X, y, K, num_pos=1, metric='cosine'):
    """
    Compute the R@K metric inside a single set of embeddings.

    For every row of ``X`` the nearest rows of ``X`` are retrieved, the closest
    hit is discarded (assumed to be the row itself) and the remaining neighbours
    whose label equals the row's label are counted.

    Args:
        X: tensor of size n x d (n examples, d-dimensional embedding vectors).
        y: tensor with the true label of every row of ``X``.
        K: number of nearest neighbours taken into account; capped at ``n - 1``.
        num_pos: extra divisor applied to the score (presumably the number of
            positives per example - confirm against the caller).
        metric: 'cosine' or 'euclidean'.
            'euclidean' is too slow for datasets bigger than 100K rows.

    Returns:
        Total number of label matches divided by ``len(y)`` and by ``num_pos``.

    Raises:
        AttributeError: if ``metric`` is neither 'cosine' nor 'euclidean'.
    """
    total_rows = X.size(0)
    emb_dim = X.size(1)
    top_k = min(total_rows - 1, K)

    # Queries are processed in chunks so one distance matrix holds at most ~2**32 values.
    chunk = max(1, 2 ** 32 // (total_rows * emb_dim))
    n_chunks = 1 + (len(X) - 1) // chunk

    hits = []
    with torch.no_grad():
        for chunk_no in range(n_chunks):
            lo = chunk_no * chunk
            hi = min(lo + chunk, len(y))
            queries = X[lo:hi]

            if metric == 'cosine':
                # Negated similarity, so that "smallest" means "closest".
                dist = -1 * outer_cosine_similarity(X, queries)
            elif metric == 'euclidean':
                dist = outer_pairwise_distance(X, queries)
            else:
                raise AttributeError(f'wrong metric "{metric}"')

            # One extra neighbour is requested because the first one is dropped below.
            _, nearest = dist.topk(top_k + 1, 0, largest=False)

            expected = y[lo:hi].repeat(top_k, 1)
            hits.append((y[nearest[1:]] == expected).sum().item())

    return np.sum(hits) / len(y) / num_pos

def cosine_similarity_matrix(x1, x2, eps=1e-12):
    """
    Pairwise cosine similarity between the rows of two 2-D tensors.

    Args:
        x1: tensor of size n1 x d.
        x2: tensor of size n2 x d.
        eps: lower bound applied to the row norms before dividing, so that an
            all-zero row yields zero similarity instead of NaN.

    Returns:
        Tensor of size n1 x n2 where entry (i, j) is cos(x1[i], x2[j]).
    """
    x1_norm = x1 / x1.norm(dim=1, keepdim=True).clamp_min(eps)
    x2_norm = x2 / x2.norm(dim=1, keepdim=True).clamp_min(eps)
    return torch.mm(x1_norm, x2_norm.transpose(0, 1))

def metric_recall_top_K_for_embs(embs_1, embs_2, true_matches, K=30):
    """
    Recall@K for retrieving rows of ``embs_2`` with rows of ``embs_1`` as queries.

    Args:
        embs_1: query embeddings, tensor of size n1 x d.
        embs_2: candidate embeddings, tensor of size n2 x d.
        true_matches: for every query i, the index (row of ``embs_2``) of its
            correct match; only the first n1 entries are used.
        K: number of most similar candidates inspected per query; capped at n2.

    Returns:
        Fraction of queries whose correct match is among the K most
        cosine-similar candidates (0.0 when there are no queries).
    """
    n_queries = len(embs_1)
    if n_queries == 0:
        return 0.0

    similarity_matrix = cosine_similarity_matrix(embs_1, embs_2)
    # topk runs along dim=1 (the candidates), so K is bounded by the number of
    # candidates, not by the number of queries.
    K_adjusted = min(similarity_matrix.shape[1], K)
    top_k = similarity_matrix.topk(k=K_adjusted, dim=1).indices

    # Compare all queries at once; keep the targets on the same device as the indices.
    targets = torch.as_tensor(true_matches, device=top_k.device)[:n_queries].reshape(-1, 1)
    correct_matches = (top_k == targets).any(dim=1).sum().item()
    return correct_matches / n_queries


In [None]:
from ptls.nn import TrxEncoder, RnnSeqEncoder
from ptls.frames.coles import CoLESModule
from functools import partial
import torch
from ptls.frames.coles import MultiModalSortTimeSeqEncoderContainer
from ptls.nn.trx_encoder.encoders import IdentityEncoder
from ptls.nn.seq_encoder.rnn_encoder import RnnEncoder
from ptls.nn.seq_encoder.transformer_encoder import TransformerEncoder
from ptls.frames.coles.losses import ContrastiveLoss
from ptls.frames.coles.sampling_strategies import HardNegativePairSelector

# Head passed to M3CoLESModule below; input_size matches the hidden_size=128
# of both sequence encoders.
# NOTE(review): `ptls` is referenced by its top-level name in this cell, which
# requires a plain `import ptls` in an earlier cell - confirm it exists.
head = ptls.nn.Head(
    input_size=128,
    use_norm_encoder=True,
    hidden_layers_sizes=[128, 128],
    objective="regression",
    num_classes=128
)

loss = ptls.frames.coles.losses.SoftmaxLoss()

# With RNN
# One sequence encoder per modality, keyed by modality name ('trx' / 'dial').
seq_encoders = {
    # Transactions: categorical embeddings plus a log-transformed amount,
    # encoded by RnnSeqEncoderRegAttn (defined outside this cell).
    'trx': RnnSeqEncoderRegAttn(
        trx_encoder=TrxEncoder(
            norm_embeddings=False,
            embeddings_noise=0.003,
            linear_projection_size=32,
            # Per-feature embedding tables; presumably "in" is the dictionary
            # size and "out" the embedding width - confirm against the
            # preprocessing output.
            embeddings={
                'event_type': {"in": 58, "out": 24},
                'event_subtype': {"in": 59, "out": 24},
                'src_type11': {"in": 85, "out": 24},
                'src_type12': {"in": 349, "out": 24},
                'dst_type11': {"in": 84, "out": 24},
                'dst_type12': {"in": 417, "out": 24},
                'src_type22': {"in": 90, "out": 24},
                'src_type32': {"in": 91, "out": 24}
            },
            numeric_values={
                'amount': 'log'
            }
        ),
        type='gru',
        hidden_size=128
    ),
    # Dialogs: the 'embedding' feature goes through IdentityEncoder(768),
    # presumably a pass-through for precomputed 768-dim vectors - confirm.
    'dial': RnnSeqEncoder(
        trx_encoder=TrxEncoder(
            embeddings_noise=0.003,
            linear_projection_size=32,
            custom_embeddings={
                'embedding': IdentityEncoder(768)
            }
        ),
        type='gru',
        hidden_size=128
    )
}

# Optimizer / scheduler factories; the module is expected to call them with its
# own parameters / optimizer.
optimizer_partial = partial(
    torch.optim.AdamW,
    lr=0.001,
    weight_decay=1e-4
)

# StepLR with step_size=1 and gamma=0.9 multiplies the LR by 0.9 on every scheduler step.
lr_scheduler_partial = partial(
    torch.optim.lr_scheduler.StepLR,
    step_size=1,
    gamma=0.9
)

# An earlier variant of this module used BatchRecallTopK(K=1) as validation
# metric; the current run validates with K=10.
pl_module = M3CoLESModule(
    validation_metric=ptls.frames.coles.metric.BatchRecallTopK(
        K=10,
        metric='cosine',
    ),
    head=head,
    seq_encoders=seq_encoders,
    loss=loss,
    optimizer_partial=optimizer_partial,
    lr_scheduler_partial=lr_scheduler_partial
)

In [None]:
# Weights & Biases logger; log_model="all" asks the logger to upload model
# checkpoints as well as metrics.
# NOTE(review): WandbLogger must have been imported in an earlier cell.
wandb_logger = WandbLogger(project="MBD_My_Code", log_model="all")

# Trainer for the contrastive pre-training run: up to 25 epochs, gradient
# clipping at 0.5, and validation limited to 32 batches per validation pass.
trainer = pl.Trainer(
    logger=wandb_logger,
    max_epochs=25,
    accelerator="cuda" if torch.cuda.is_available() else "cpu",
    enable_progress_bar=True,
    gradient_clip_val=0.5,
    log_every_n_steps=50,
    limit_val_batches=32
)

In [None]:
# Train the multimodal module; `data_module` is built in an earlier cell.
trainer.fit(pl_module, data_module)

In [None]:
# Directory holding the extracted targets archive (one sub-directory per fold).
TARGETS_DATA_PATH = '/kaggle/working/targets/'

# Preprocessor for the targets table: groups rows by client_id, uses the `mon`
# column as event time (converted to a timestamp) and keeps the four target
# columns as they are.
preprocessor_target = PysparkDataPreprocessor(
    col_id="client_id",
    col_event_time="mon",
    event_time_transformation="dt_to_timestamp",
    cols_identity=["target_1", "target_2", "target_3", "target_4"],
)


In [None]:
# Attach the downstream targets to the multimodal dataset of the current fold.
# `fold`, `MM_DATA_PATH`, `MMT_DATA_PATH` and `spark` come from earlier cells.
targets = spark.read.parquet(os.path.join(TARGETS_DATA_PATH , f'fold={fold}'))
mmt_dataset = spark.read.parquet(os.path.join(MM_DATA_PATH , f'fold={fold}'))

# Only the per-client target columns are needed, so the listed auxiliary
# columns are dropped after preprocessing.
targets = preprocessor_target.fit_transform(targets).drop(*['event_time' ,'trans_count', 'diff_trans_date'])
# Left join: clients without a targets row are kept (their target columns are null).
mmt_dataset = mmt_dataset.join(targets, on='client_id', how='left')
mmt_dataset.write.mode('overwrite').parquet(os.path.join(MMT_DATA_PATH, f'fold={fold}'))

# Drop the notebook-level reference to the joined DataFrame.
del mmt_dataset

In [None]:
class MatchingModalities(IterableProcessingDataset):
    """
    Put two modalities of one record on a shared, time-sorted event axis.

    The event times of both modalities are merged and sorted. Every feature of
    a modality is then expanded to the merged length: its values are written
    at the positions of that modality's own events and every position that
    belongs to the other modality is filled with zeros. Both time columns are
    replaced by the merged timeline.

    Features are assigned to a modality by key prefix (``mod1_name`` /
    ``mod2_name``); they are expected to be tensors whose first dimension
    matches the modality's time column and is ordered by time.

    Args:
        col_id: name of the id column (stored, not used by the transform).
        mod1_col_time: key of the first modality's event-time tensor.
        mod2_col_time: key of the second modality's event-time tensor.
        mod1_name: key prefix of the first modality's features.
        mod2_name: key prefix of the second modality's features.
    """
    def __init__(
        self,
        col_id='client_id',
        mod1_col_time='trx_event_time',
        mod2_col_time='dial_event_time',
        mod1_name='trx',
        mod2_name='dial'
    ):
        super().__init__()
        self.col_id = col_id
        self.mod1_col_time = mod1_col_time
        self.mod2_col_time = mod2_col_time
        self.mod1_name=mod1_name
        self.mod2_name=mod2_name

    @staticmethod
    def _scatter(values, positions, total_len):
        """Return a zero tensor of length `total_len` with `values` written at `positions`."""
        aligned = torch.zeros((total_len,) + tuple(values.shape[1:]), dtype=values.dtype)
        aligned[positions] = values
        return aligned

    def __iter__(self):
        for rec in self._src:
            features = rec[0] if type(rec) is tuple else rec
            features = features.copy()

            mod1_time = features[self.mod1_col_time]
            mod2_time = features[self.mod2_col_time]
            # A stable sort keeps the per-modality order and, for equal
            # timestamps, places the first modality's event first.
            all_event_time, order = torch.sort(torch.cat((mod1_time, mod2_time)), stable=True)
            # Entries of the concatenation with index < len(mod1_time) came from modality 1.
            from_mod1 = order < len(mod1_time)
            mod1_pos = torch.where(from_mod1)[0]
            mod2_pos = torch.where(~from_mod1)[0]
            total_len = len(all_event_time)

            for key, tens in list(features.items()):
                if key.startswith(self.mod1_name) and key != self.mod1_col_time:
                    features[key] = self._scatter(tens, mod1_pos, total_len)
                elif key.startswith(self.mod2_name) and key != self.mod2_col_time:
                    features[key] = self._scatter(tens, mod2_pos, total_len)

            features[self.mod1_col_time] = all_event_time
            features[self.mod2_col_time] = all_event_time

            yield features

class MMToTorch(IterableProcessingDataset):
    """
    Convert every feature of a record, except the id column, with ``torch.Tensor``.

    Args:
        col_id: key of the id column, which is left untouched.
    """
    def __init__(
        self,
        col_id='client_id'
    ):
        super().__init__()
        # Honour the constructor argument instead of a hard-coded column name.
        self.col_id = col_id

    def __iter__(self):
        for rec in self._src:
            features = rec[0] if type(rec) is tuple else rec
            features = features.copy()
            for key, value in features.items():
                if key != self.col_id:
                    # torch.Tensor(...) builds a tensor of the default float dtype.
                    features[key] = torch.Tensor(value)
            yield features

class DialToTorch(IterableProcessingDataset):
    """
    Turn the per-event dialog embeddings of a record into one float tensor.

    The value stored under ``embedding_col`` is a sequence of equally sized
    vectors; they are stacked along a new first axis and returned as a
    float32 tensor.

    Args:
        embedding_col: key of the embedding feature.
    """
    def __init__(
        self,
        embedding_col='dial_embedding'
    ):
        super().__init__()
        self.embedding_col=embedding_col

    def __iter__(self):
        for rec in self._src:
            features = rec[0] if type(rec) is tuple else rec
            features = features.copy()
            # Cast to float32, not int32: an integer cast would truncate the
            # fractional part of every embedding component.
            stacked = np.stack(features[self.embedding_col], axis=0).astype(np.float32)
            features[self.embedding_col] = torch.from_numpy(stacked)
            yield features

In [None]:
import copy

class GetSplit(IterableProcessingDataset):
    """
    Expand every record into one sample per month of ``year``.

    For each month in ``[start_month, end_month]`` the sample keeps only the
    events that happened strictly before the first day of the following month,
    takes the target value of that month, and gets ``'_month=<month>'``
    appended to its id.

    Features are recognised by key prefix: ``'target'`` (indexed by
    ``month - 1``), ``'trx'`` and ``'dial'`` (filtered by the respective time
    column). A modality with no events before the cutoff receives a single
    placeholder event: the cutoff timestamp in the time column, ``[0]`` for
    other ``trx`` features and a ``1 x 768`` zero tensor for other ``dial``
    features.

    NOTE(review): cutoffs are built from naive ``datetime`` objects, so
    ``timestamp()`` interprets them in the local timezone - confirm this
    matches how the event times were produced.

    Args:
        start_month: first month (1-12) to emit.
        end_month: last month (1-12) to emit, inclusive.
        year: calendar year of the months.
        col_id: key of the id column (a string).
        trx_col_time: key of the transaction event-time tensor.
        dial_col_time: key of the dialog event-time tensor.
    """
    def __init__(
        self,
        start_month,
        end_month,
        year=2022,
        col_id='client_id',
        trx_col_time='trx_event_time',
        dial_col_time='dial_event_time'
    ):
        super().__init__()
        self.start_month = start_month
        self.end_month = end_month
        self._year = year
        self._col_id = col_id
        self._trx_col_time = trx_col_time
        self._dial_col_time = dial_col_time

    def _month_cutoff(self, month):
        """Timestamp of the first day of the month following `month`."""
        if month == 12:
            return datetime(self._year + 1, 1, 1).timestamp()
        return datetime(self._year, month + 1, 1).timestamp()

    @staticmethod
    def _cutoff_tensor(month_event_time):
        """One-element int32 tensor holding the cutoff (built from an int to avoid float32 rounding)."""
        return torch.tensor([int(month_event_time)], dtype=torch.int32)

    def __iter__(self):
        # The cutoffs depend only on the configuration, not on the record.
        cutoffs = [
            (month, self._month_cutoff(month))
            for month in range(self.start_month, self.end_month + 1)
        ]
        for rec in self._src:
            source = rec[0] if type(rec) is tuple else rec
            for month, month_event_time in cutoffs:
                features = source.copy()

                trx_mask = features[self._trx_col_time] < month_event_time
                dial_mask = features[self._dial_col_time] < month_event_time

                for key, tensor in features.items():
                    if key.startswith('target'):
                        features[key] = tensor[month - 1].tolist()
                    elif key.startswith('trx'):
                        features[key] = tensor[trx_mask]
                        if len(features[key]) == 0:
                            # No transactions yet: emit one placeholder event.
                            if key == self._trx_col_time:
                                features[key] = self._cutoff_tensor(month_event_time)
                            else:
                                features[key] = torch.Tensor([0])
                    elif key.startswith('dial'):
                        features[key] = tensor[dial_mask]
                        if len(features[key]) == 0:
                            # No dialogs yet: emit one placeholder event.
                            if key == self._dial_col_time:
                                features[key] = self._cutoff_tensor(month_event_time)
                            else:
                                features[key] = torch.zeros([1, 768])

                features[self._col_id] += '_month=' + str(month)

                yield features

In [None]:
from ptls.frames.coles import MultiModalInferenceIterableDataset

# Inference source: the parquet files with joined targets written above.
# NOTE(review): the fold is hard-coded to 4 here while the writing cell uses
# the `fold` variable - confirm both refer to the same fold.
dataset_inf = ParquetDataset(
    data_files=[
        os.path.join(MMT_DATA_PATH, f'fold={4}')
    ],
    # Filters run in the listed order.
    i_filters=[
        DeleteNan('trx_event_time'),
        DeleteNan('dial_event_time'),
        # Disabled: SeqLenFilter(min_seq_len=1), ISeqLenLimit(max_seq_len=128)
        ptls.data_load.iterable_processing.to_torch_tensor.ToTorch(),
        # Disabled: MMToTorch('client_id')
        DialToTorch(),
        # Disabled: MatchingModalities(col_id='client_id',
        #     mod1_col_time='trx_event_time', mod2_col_time='dial_event_time')
        # One sample per client and month (12 samples per client).
        GetSplit(
            start_month=1,
            end_month=12,
            col_id='client_id',
            trx_col_time='trx_event_time',
            dial_col_time='dial_event_time'
        )
    ],
    shuffle_files=False
)

# Wrap the parquet dataset for multimodal inference. Note that the name
# `dataset_inf` is rebound here from the ParquetDataset to the wrapper.
dataset_inf = MultiModalInferenceIterableDataset(
        data=dataset_inf,
        source_features={
            "trx": [
                "event_type",
                "event_subtype",
                "src_type11",
                "src_type12",
                "dst_type11",
                "dst_type12",
                "src_type22",
                "src_type32"
            ],
            "dial": [
                "embedding"
            ]
        },
        col_id='client_id',
        col_time='event_time',
        source_names=['trx', 'dial'],
    )

In [None]:
# from ptls.custom_layers import StatPooling
# from ptls.nn.seq_step import LastStepEncoder
from itertools import chain

class InferenceModuleMultimodal(pl.LightningModule):
    """
    Lightning wrapper that runs a multimodal model on inference batches.

    A batch is the triple ``(padded_batch, batch_ids, targets)``. The model
    output is returned together with the ids, either as a dict or, when
    ``pandas_output`` is set, as a DataFrame joined with the targets.

    Args:
        model: module called on the padded batch; ``to_pandas`` expects its
            output to be a dict (it is iterated with ``.items()``).
        pandas_output: return a ``pd.DataFrame`` instead of a dict.
        drop_seq_features: stored on the instance; not used by this class.
        model_out_name: key under which the model output is stored.
        col_id: key (and resulting column name) for the batch ids.
    """
    def __init__(self, model, pandas_output=True, drop_seq_features=True, model_out_name='out', col_id = 'epk_id'):
        super().__init__()

        self.col_id = col_id
        self.model_out_name = model_out_name
        self.drop_seq_features = drop_seq_features
        self.pandas_output = pandas_output
        self.model = model

    def forward(self, x: PaddedBatch):
        padded, batch_ids, targets = x

        # TODO: check that batch_ids correspond to col_id. Checked: they match.
        model_out = self.model(padded)
        result = {self.col_id : batch_ids, self.model_out_name: model_out}
        if not self.pandas_output:
            return result
        return self.to_pandas(result, targets)

    def to_pandas(self, x, targets):
        """
        Build one DataFrame from the output dict ``x`` and the batch targets.

        The dict stored under ``model_out_name`` is flattened into ``x`` (which
        is modified in place). Lists and 1-D values become single columns, 2-D
        values are expanded into ``<key>_0000, <key>_0001, ...`` columns, and
        values of any other rank become a ``None`` column. The frame built from
        the first element of every entry of ``targets`` is joined on the index.
        """
        out_key = self.model_out_name
        if out_key in x:
            for name, value in x[out_key].items():
                x[name] = value
        del x[out_key]

        scalar_features = {}
        expand_cols = []
        for name, value in x.items():
            if type(value) is torch.Tensor:
                value = value.cpu().numpy()
            is_flat = type(value) is list or len(value.shape) == 1
            if is_flat:
                scalar_features[name] = value
            elif len(value.shape) == 2:
                expand_cols.append(name)
            else:
                scalar_features[name] = None

        frames = [pd.DataFrame(scalar_features)]
        targets_frame = pd.DataFrame([item[0] for item in targets])
        for col in expand_cols:
            matrix = x[col].cpu().numpy()
            col_names = [f'{col}_{i:04d}' for i in range(matrix.shape[1])]
            frames.append(pd.DataFrame(matrix, columns=col_names))
        return pd.concat(frames, axis=1).join(targets_frame)

In [None]:
def collate_feature_dict(batch):
    """
    Collate a list of feature dicts into a single ``PaddedBatch``.

    This definition shadows the ``collate_feature_dict`` imported from
    ``ptls.data_load.utils`` at the top of the notebook.

    Per key, the values of all records are combined as follows:
      * tensors: stacked when the key starts with ``'target'``, otherwise
        zero-padded to a common length (batch first);
      * numpy arrays: kept as a plain list of arrays;
      * anything else: converted with ``np.array`` and, for integer / float /
        boolean dtypes, turned into a long / float / bool tensor.

    Sequence lengths are taken from the first sequence feature found in the
    first record.
    """
    grouped = defaultdict(list)
    for record in batch:
        for name, value in record.items():
            grouped[name].append(value)

    first_record = batch[0]
    seq_col = next(name for name, value in first_record.items() if FeatureDict.is_seq_feature(name, value))
    lengths = torch.LongTensor([len(record[seq_col]) for record in batch])

    # numpy dtype kind -> name of the tensor cast applied after from_numpy
    numpy_casts = {'i': 'long', 'f': 'float', 'b': 'bool'}

    payload = {}
    for name, values in grouped.items():
        head_value = values[0]
        if type(head_value) is torch.Tensor:
            if name.startswith('target'):
                payload[name] = torch.stack(values, dim=0)
            else:
                payload[name] = torch.nn.utils.rnn.pad_sequence(values, batch_first=True)
        elif type(head_value) is np.ndarray:
            payload[name] = values  # list of arrays[object]
        else:
            as_array = np.array(values)
            cast_name = numpy_casts.get(as_array.dtype.kind)
            if cast_name is None:
                payload[name] = as_array
            else:
                payload[name] = getattr(torch.from_numpy(as_array), cast_name)()
    return PaddedBatch(payload, lengths)

def collate_multimodal_feature_dict(batch):
    """Collate every modality of ``batch`` (a dict: source name -> list of records) with ``collate_feature_dict``."""
    return {
        source: collate_feature_dict(source_batch)
        for source, source_batch in batch.items()
    }

def collate_feature_dict_with_target(batch, col_id='client_id', target_col_names=None):
    """
    Collate multimodal samples, separating ids and (optionally) targets.

    The id and target entries are removed from every sample IN PLACE, the
    remaining per-source lists of all samples are concatenated key by key
    (only keys present in every sample survive), and each source is collated
    with ``collate_multimodal_feature_dict``.

    Args:
        batch: list of sample dicts.
        col_id: key of the sample id.
        target_col_names: keys to extract as targets, or ``None``.

    Returns:
        ``(padded_batch, batch_ids, target_cols)`` when ``target_col_names`` is
        given, otherwise ``(padded_batch, batch_ids[0])``.
        NOTE(review): the second form returns only the first id of the batch -
        confirm this is intended.
    """
    with_targets = target_col_names is not None
    batch_ids = []
    target_cols = []
    for sample in batch:
        batch_ids.append(sample[col_id])
        del sample[col_id]
        if not with_targets:
            continue
        for target_col in target_col_names:
            target_cols.append(sample[target_col])
            del sample[target_col]

    def _merge(left, right):
        # Concatenate the values of the keys both samples share.
        return {key: left[key] + right[key] for key in left if key in right}

    merged = reduce(_merge, batch)
    padded_batch = collate_multimodal_feature_dict(merged)
    if with_targets:
        return padded_batch, batch_ids, target_cols
    return padded_batch, batch_ids[0]

In [None]:
from torch.utils.data import DataLoader

# Collate function for inference: the 'target' entry of every sample is
# extracted and returned next to the padded batch and the ids.
collate_fn = partial(
    collate_feature_dict_with_target,
    target_col_names=['target']
)

# Sequential (unshuffled) loader over the inference dataset, loaded in the main process.
inference_dl = DataLoader(
    dataset=dataset_inf,
    collate_fn=collate_fn,
    shuffle=False,
    num_workers=0,
    batch_size=32
)

In [None]:
# Wrap the trained module for prediction; every batch comes back as a
# DataFrame with a 'client_id' column.
inf_module = InferenceModuleMultimodal(
    model=pl_module,
    pandas_output=True,
    col_id='client_id',
)

In [None]:
# trainer.predict yields one DataFrame per batch (pandas_output=True); they are
# concatenated row-wise. The index is not reset, so it repeats across batches.
inf_embeddings = pd.concat(trainer.predict(inf_module, inference_dl))

In [None]:
from catboost import Pool, CatBoostClassifier
from sklearn.model_selection import train_test_split

# Fixed seed so the train/val/test partition, and therefore the reported
# ROC-AUC values, can be reproduced on re-run.
SPLIT_SEED = 42

# 80/20 train/test split, then 10% of the training part is held out for validation.
# NOTE(review): rows are per client *and month*, so a plain row-wise split can
# put the same client into several partitions - consider a split grouped by client.
inf_train, inf_test = train_test_split(inf_embeddings, test_size=0.2, random_state=SPLIT_SEED)
inf_train, inf_val = train_test_split(inf_train, test_size=0.1, random_state=SPLIT_SEED)

In [None]:
# Target columns of the inference frame and the columns that are not model features.
TARGET_COLS = ['1', '2', '3', '4']
NON_FEATURE_COLS = ['client_id'] + TARGET_COLS


def _embedding_features(frame):
    """Return the embedding columns of `frame` (everything except id and targets) as a numpy array."""
    return frame.drop(columns=NON_FEATURE_COLS).to_numpy()


# One label vector per target, for the train and test partitions.
y_1_inf_train, y_2_inf_train, y_3_inf_train, y_4_inf_train = (
    inf_train[col].to_numpy() for col in TARGET_COLS
)
y_1_inf_test, y_2_inf_test, y_3_inf_test, y_4_inf_test = (
    inf_test[col].to_numpy() for col in TARGET_COLS
)

y_inf_tests = [
    y_1_inf_test,
    y_2_inf_test,
    y_3_inf_test,
    y_4_inf_test
]

X_inf_train, X_inf_test = _embedding_features(inf_train), _embedding_features(inf_test)

# The validation features are identical for all targets, so a single array is
# built once and referenced once per target.
X_inf_val = _embedding_features(inf_val)
inf_val_pairs = {
    'X_inf': [X_inf_val] * len(TARGET_COLS),
    'y_inf': [inf_val[col].to_numpy() for col in TARGET_COLS]
}

In [None]:
from lightgbm import LGBMClassifier

# One independent LightGBM classifier per downstream target (four in total),
# all configured identically.
models = [LGBMClassifier(
    n_estimators=500,
    boosting_type='gbdt',
    subsample=0.5,
    subsample_freq=1,
    learning_rate=0.02,
    feature_fraction=0.75,
    max_depth=6,
    lambda_l1=1,
    lambda_l2=1,
    min_data_in_leaf=50,
    random_state=42,
    n_jobs=8,
    verbose=-1
) for _ in range(4)]

In [None]:
# (features, labels) pairs, one per target; the training features are shared.
train_targets = (y_1_inf_train, y_2_inf_train, y_3_inf_train, y_4_inf_train)
train_datasets = [(X_inf_train, y_train) for y_train in train_targets]

# Validation pairs: the i-th feature matrix with the i-th label vector.
val_datasets = list(zip(inf_val_pairs['X_inf'], inf_val_pairs['y_inf']))

In [None]:
# Fit every classifier on its own target, monitoring the matching validation pair.
for model, (X_train, y_train), val_pair in zip(models, train_datasets, val_datasets):
    model.fit(X_train, y_train, eval_set=val_pair)

In [None]:
from sklearn.metrics import classification_report
from sklearn.metrics import roc_auc_score

# Class probabilities on the test partition, one array per target model.
preds = []
for i in range(len(models)):
    preds.append(models[i].predict_proba(X_inf_test))
    # ROC-AUC is computed on the second probability column (index 1).
    # Note: the printed label is zero-based, i.e. target_0 corresponds to
    # the column '1' of the inference frame.
    print(f"ROC-AUC target_{i} = {roc_auc_score(y_inf_tests[i], preds[i][:, 1])}")

**CoLES + regional attention in trx (segment_length=10) + LightGBM (Without LN):**

ROC-AUC target_0 = 0.6674023787783245

ROC-AUC target_1 = 0.8062314626627639

ROC-AUC target_2 = 0.5838852723065378

ROC-AUC target_3 = 0.6902740886029226

**CoLES + regional attention in trx + LightGBM (With LN):**

ROC-AUC target_0 = 0.6835130424746901

ROC-AUC target_1 = 0.7069505322479649

ROC-AUC target_2 = 0.5838252484700381

ROC-AUC target_3 = 0.6956062235064757


**CoLES + regional attention in trx + LightGBM (With BN):**

ROC-AUC target_0 = 0.5941521400656775

ROC-AUC target_1 = 0.72975601217508

ROC-AUC target_2 = 0.5737223564128565

ROC-AUC target_3 = 0.618690245938203

**CoLES + regional attention in trx (segment_length=5) + LightGBM (With LN):**

ROC-AUC target_0 = 0.702929348375156

ROC-AUC target_1 = 0.5948162619622575

ROC-AUC target_2 = 0.576047203889645

ROC-AUC target_3 = 0.7450892857142857