# Libraries

In [1]:
import warnings

warnings.filterwarnings("ignore")

In [2]:
import gc
import os

from datetime import datetime
from functools import partial
from tqdm.auto import tqdm

import pandas as pd
import polars as pl

import torch
import pytorch_lightning
from torch import nn
from torchinfo import summary
from pytorch_lightning.loggers import WandbLogger

from ptls.nn import TrxEncoder, PBL2Norm, PBLayerNorm
from ptls.nn.trx_encoder.encoders import IdentityEncoder
from ptls.nn.seq_encoder.rnn_encoder import RnnEncoder
from ptls.frames.coles.multimodal_supervised_dataset import MultiModalSupervisedIterableDataset
# from ptls.frames.coles.multimodal_dataset import MultiModalSortTimeSeqEncoderContainer
from ptls.frames.supervised import SequenceToTarget
from ptls.frames import PtlsDataModule

from src.config import MONTHS, N_CHUNKS
from src.preprocessing import PolarsDataPreprocessor
from src.utils import prepare_data, to_records, init_weights
from src.multimodal_dataset import MultiModalSortTimeSeqEncoderContainer
from src.metric import AUROC
from build_chunks import train_clients

In [3]:
SEED = 56  # one seed for polars sampling and for Lightning's global seeding helper
pl.set_random_seed(SEED)
pytorch_lightning.seed_everything(SEED)

56

In [4]:
# Optional speed/precision trade-off for float32 matmuls (see the
# torch.set_float32_matmul_precision docs). Left disabled, so PyTorch's default applies.
# torch.set_float32_matmul_precision('high')

# Supervised-multimodal

In [5]:
# index into MONTHS of the target month
MON = 10
# width shared by the per-modality encoders and the sequence encoder
HIDDEN_SIZE = 128
# number of binary targets per client (target_1 ... target_4)
NUM_CLASSES = 4

In [6]:
# Targets of the chosen month, packed into one list column [target_1, ..., target_4] per client.
train_target = (
    pl.read_parquet("./data/train_target.parquet/*")
    .filter(pl.col("mon") == MONTHS[MON])
    .with_columns(target=pl.concat_list([f"target_{i}" for i in range(1, 5)]))
    .select(("client_id", "target"))
)
train_target.head()

client_id,target
str,list[i32]
"""1d4ebf30ab5b98…","[0, 0, … 0]"
"""1d55174bce3ef4…","[0, 0, … 0]"
"""1d5d052f87d6bd…","[0, 0, … 0]"
"""1d68b588164639…","[0, 0, … 0]"
"""1d817e82e1cc59…","[0, 0, … 0]"


In [7]:
# Validation targets (duplicate rows removed) of the chosen month, packed into one
# list column [target_1, ..., target_4] per client.
val_target = (
    pl.read_parquet("./data/test_target_b.parquet/*")
    .unique()
    .filter(pl.col("mon") == MONTHS[MON])
    .with_columns(target=pl.concat_list([f"target_{i}" for i in range(1, 5)]))
    .select(("client_id", "target"))
)
val_target.head()

client_id,target
str,list[i32]
"""f2d13658b42afc…","[0, 0, … 0]"
"""f3975d57411ebc…","[0, 0, … 0]"
"""5c446c72e11deb…","[0, 0, … 0]"
"""fabe5ac957bf10…","[0, 0, … 0]"
"""12d1d490206e23…","[0, 0, … 0]"


мультимодальная модель состоит из четырех частей:

    1) три кодировщика отдельного действия для каждой из модальностей
    2) их выходы объединяются в одну последовательность и сортируются по времени события
    3) кодировщик всей последовательности
    4) классификационная голова

сначала подготовим препроцессоры и кодировщики

# Trx

## Data

объединим одинаковые подряд идущие действия в одно

In [8]:
def trx_unite(data):
    """Collapse runs of consecutive identical transactions of a client into one event.

    Rows are ordered by ``event_time``. Within a client, a new run starts whenever any of
    ``day_ind``, the event type/subtype, the currency or the src/dst type columns differs
    from the client's previous event. Nulls in those comparison columns are replaced by -1
    first, and that -1 is what ends up in the output.

    Each run becomes one row holding:
    - the mean ``event_time`` of the merged transactions;
    - the summed ``amount``;
    - the first value of every other kept column.
    """
    key_cols = [
        "day_ind", "event_type", "event_subtype", "currency",
        "src_type11", "src_type12", "dst_type11", "dst_type12",
    ]
    first_cols = [
        "event_type", "event_subtype", "currency",
        "src_type11", "src_type12", "dst_type11", "dst_type12",
        "src_type21", "src_type22", "src_type31", "src_type32",
        "mon_ind", "day_ind",
    ]

    data = data.sort("event_time").with_columns(
        [pl.col(c).fill_null(-1) for c in key_cols]
    )
    # One boolean per key column: did the value change vs. the client's previous event?
    # The first event of a client has no predecessor (null) and counts as "no change".
    data = data.with_columns(
        [
            (pl.col(c) != pl.col(c).shift(1).over("client_id")).fill_null(False).alias(f"flag_{c}")
            for c in key_cols
        ]
    )
    is_new_run = data[f"flag_{key_cols[0]}"]
    for c in key_cols[1:]:
        is_new_run = is_new_run | data[f"flag_{c}"]
    # Running count of run starts = index of the run within the client.
    data = data.with_columns(flag=is_new_run).with_columns(
        group_id=pl.col("flag").cumsum().over("client_id")
    )

    aggregations = [
        pl.col("event_time").mean(),  # average moment of the merged transactions
        pl.col("amount").sum(),  # total amount of the merged transactions
    ] + [pl.col(c).first() for c in first_cols]
    return (
        data.group_by(("client_id", "group_id"), maintain_order=True)
        .agg(*aggregations)
        .drop("group_id")
    )

`src_type21`, `src_type22`, `src_type31`, `src_type32` — одинаковые для одного пользователя, при этом `src_type22` и `src_type32` значительно улучшают качество, а `src_type21` и `src_type31` — ухудшают

In [9]:
# Categorical features of the transaction modality; output columns get the "trx_" prefix.
# src_type21 / src_type31 are deliberately left out: they were found to hurt quality.
trx_preprocessor = PolarsDataPreprocessor(
    prefix="trx",
    col_id="client_id",
    col_event_time="event_time",
    cols_category=[
        "mon_ind", "day_ind", "event_type", "event_subtype", "currency",
        "src_type11", "src_type12", "dst_type11", "dst_type12",
        "src_type22", "src_type32",
    ],
)

считаем счетчики для категориальных признаков

In [10]:
%%time

source = "./data/chunks/trx_train.parquet/"
for file in tqdm(os.listdir(source)):
    train_trx = pl.read_parquet(os.path.join(source, file))
    train_trx = train_trx.filter(pl.col("event_time") < datetime.strptime(MONTHS[MON], "%Y-%m-%d").replace(day=1))
    train_trx = train_trx.filter(pl.col("client_id").is_in(train_target["client_id"]))
    train_trx = prepare_data(train_trx)
    train_trx = trx_unite(train_trx)
    trx_preprocessor.fit(train_trx)

  0%|          | 0/8 [00:00<?, ?it/s]

создаем словари

In [11]:
# Build the dictionaries from the accumulated counters. The trailing ";" hides the
# returned preprocessor's repr, which carries no information.
trx_preprocessor.freeze();

<src.preprocessing.PolarsDataPreprocessor at 0x1fa3634df70>

In [12]:
# Free the last chunk left over from the fitting loop.
del train_trx

## TrxEncoder

In [13]:
# Vocabulary size of every categorical column: the embedding tables must cover these.
trx_preprocessor.get_category_dictionary_sizes()

{'mon_ind': 13,
 'day_ind': 336,
 'event_type': 56,
 'event_subtype': 58,
 'currency': 17,
 'src_type11': 76,
 'src_type12': 326,
 'dst_type11': 77,
 'dst_type12': 387,
 'src_type22': 89,
 'src_type32': 90}

In [14]:
# Embedding width per categorical column. The vocabulary sizes ("in") are taken from the
# fitted preprocessor, so the tables always match its dictionaries. Hand-copied sizes would
# go stale silently, because out-of-range indices do not raise here.
trx_dict_sizes = trx_preprocessor.get_category_dictionary_sizes()
trx_embedding_dims = {
    "mon_ind": 6,
    "day_ind": 32,

    "event_type": 16,
    "event_subtype": 16,

    "currency": 8,

    "src_type11": 16,
    "src_type12": 32,
    "dst_type11": 16,
    "dst_type12": 32,

    "src_type22": 16,
    "src_type32": 16,
}
trx_base_encoder = TrxEncoder(
    embeddings={
        col: {"in": trx_dict_sizes[col], "out": dim}
        for col, dim in trx_embedding_dims.items()
    },
    # Must equal HIDDEN_SIZE: PBLayerNorm below and the sequence encoder use that width.
    linear_projection_size=HIDDEN_SIZE,
)
# Normalising the event embeddings is important.
trx_encoder = nn.Sequential(trx_base_encoder, PBL2Norm(), PBLayerNorm(HIDDEN_SIZE))

# Geo

## Data

объединим одинаковые подряд идущие действия в одно

In [15]:
def geo_unite(data):
    """Collapse runs of consecutive geo events of a client into one event.

    Rows are ordered by ``event_time``. Within a client, a new run starts whenever
    ``geohash_6`` or ``day_ind`` differs from the client's previous event. Nulls in those
    two columns are replaced by -1 before comparing.

    Each run keeps:
    - its first ``event_time`` (when the client arrived at the location);
    - the first value of the geohash/date columns;
    - the number of merged events as an Int32 ``count`` column.
    """
    key_cols = ["geohash_6", "day_ind"]
    data = data.sort("event_time").with_columns([pl.col(c).fill_null(-1) for c in key_cols])
    # One boolean per key column: did the value change vs. the client's previous event?
    # The first event of a client has no predecessor (null) and counts as "no change".
    data = data.with_columns(
        [
            (pl.col(c) != pl.col(c).shift(1).over("client_id")).fill_null(False).alias(f"flag_{c}")
            for c in key_cols
        ]
    )
    is_new_run = data[f"flag_{key_cols[0]}"]
    for c in key_cols[1:]:
        is_new_run = is_new_run | data[f"flag_{c}"]
    # Running count of run starts = index of the run within the client.
    data = data.with_columns(flag=is_new_run).with_columns(
        group_id=pl.col("flag").cumsum().over("client_id")
    )

    first_cols = ["geohash_4", "geohash_5", "geohash_6", "mon_ind", "day_ind"]
    return (
        data.group_by(("client_id", "group_id"), maintain_order=True)
        .agg(
            pl.col("event_time").first(),
            *[pl.col(c).first() for c in first_cols],
            pl.count().cast(pl.Int32).alias("count"),
        )
        .drop("group_id")
    )

In [16]:
# Geo modality: date and coarse geohash as categories, run length as a numeric feature.
# Output columns get the "geo_" prefix.
geo_preprocessor = PolarsDataPreprocessor(
    prefix="geo",
    col_id="client_id",
    col_event_time="event_time",
    cols_category=["mon_ind", "day_ind", "geohash_4"],
    cols_numerical=["count"],
)

считаем счетчики для категориальных признаков

In [17]:
%%time

source = "./data/chunks/geo_train.parquet/"
for file in tqdm(os.listdir(source)):
    train_geo = pl.read_parquet(os.path.join(source, file))
    train_geo = train_geo.filter(pl.col("event_time") < datetime.strptime(MONTHS[MON], "%Y-%m-%d").replace(day=1))
    train_geo = train_geo.filter(pl.col("client_id").is_in(train_target["client_id"]))
    train_geo = prepare_data(train_geo)
    train_geo = geo_unite(train_geo)
    geo_preprocessor.fit(train_geo)

  0%|          | 0/8 [00:00<?, ?it/s]

создаем словари

In [18]:
# Build the dictionaries from the accumulated counters. The trailing ";" hides the
# returned preprocessor's repr, which carries no information.
geo_preprocessor.freeze();

<src.preprocessing.PolarsDataPreprocessor at 0x1fa363a8730>

In [19]:
# Free the last chunk left over from the fitting loop.
del train_geo

## TrxEncoder

In [20]:
# Vocabulary size of every categorical column: the embedding tables must cover these.
geo_preprocessor.get_category_dictionary_sizes()

{'mon_ind': 13, 'day_ind': 336, 'geohash_4': 34036}

In [21]:
# Upper bound on the geohash_4 embedding table; the full dictionary is far larger (~34k).
GEOHASH4_VOCAB = 4096

# Vocabulary sizes come from the fitted preprocessor so the tables match its dictionaries.
geo_dict_sizes = geo_preprocessor.get_category_dictionary_sizes()
geo_base_encoder = TrxEncoder(
    embeddings={
        "mon_ind": {"in": geo_dict_sizes["mon_ind"], "out": 6},
        "day_ind": {"in": geo_dict_sizes["day_ind"], "out": 32},
        # Keep only the most frequent cells.
        # NOTE(review): relies on the preprocessor numbering categories by frequency and on
        # TrxEncoder clipping out-of-range indices into the last row — confirm both.
        "geohash_4": {"in": min(geo_dict_sizes["geohash_4"], GEOHASH4_VOCAB), "out": 128},
    },
    numeric_values={
        "count": "identity"
    },
    # Must equal HIDDEN_SIZE: PBLayerNorm below and the sequence encoder use that width.
    linear_projection_size=HIDDEN_SIZE,
)
# Normalising the event embeddings is important.
geo_encoder = nn.Sequential(geo_base_encoder, PBL2Norm(), PBLayerNorm(HIDDEN_SIZE))

# Dial

## Data

In [22]:
# Dialog modality: date categories plus the dialog embedding vector column.
# Output columns get the "dial_" prefix.
dial_preprocessor = PolarsDataPreprocessor(
    prefix="dial",
    col_id="client_id",
    col_event_time="event_time",
    cols_category=["mon_ind", "day_ind"],
    cols_numerical=["embedding"],
)

считаем счетчики для категориальных признаков

In [23]:
%%time

source = "./data/chunks/dial_train.parquet/"
for file in tqdm(os.listdir(source)):
    train_dial = pl.read_parquet(os.path.join(source, file))
    train_dial = train_dial.filter(pl.col("event_time") < datetime.strptime(MONTHS[MON], "%Y-%m-%d").replace(day=1))
    train_dial = train_dial.filter(pl.col("client_id").is_in(train_target["client_id"]))
    train_dial = prepare_data(train_dial)
    dial_preprocessor.fit(train_dial)

  0%|          | 0/8 [00:00<?, ?it/s]

создаем словари

In [24]:
# Build the dictionaries from the accumulated counters. The trailing ";" hides the
# returned preprocessor's repr, which carries no information.
dial_preprocessor.freeze();

<src.preprocessing.PolarsDataPreprocessor at 0x1fa367fa6a0>

In [25]:
# Free the last chunk left over from the fitting loop.
del train_dial

## TrxEncoder

In [26]:
# Vocabulary size of every categorical column: the embedding tables must cover these.
dial_preprocessor.get_category_dictionary_sizes()

{'mon_ind': 13, 'day_ind': 336}

In [27]:
# Vocabulary sizes come from the fitted preprocessor so the tables match its dictionaries.
dial_dict_sizes = dial_preprocessor.get_category_dictionary_sizes()
dial_base_encoder = TrxEncoder(
    embeddings={
        "mon_ind": {"in": dial_dict_sizes["mon_ind"], "out": 6},
        "day_ind": {"in": dial_dict_sizes["day_ind"], "out": 32},
    },
    custom_embeddings={
        # Precomputed 768-dim dialog vectors (same width as embedding_dim used with
        # to_records), presumably used as features without a learned transform.
        "embedding": IdentityEncoder(768),
    },
    use_batch_norm=False,
    linear_projection_size=HIDDEN_SIZE,
)
# Normalising the event embeddings is important.
dial_encoder = nn.Sequential(dial_base_encoder, PBL2Norm(), PBLayerNorm(HIDDEN_SIZE))

# Merge datasets

сделаем это для валидации

In [28]:
# список всех признаков
source_features = {
    "trx": [
        "event_time",
        "mon_ind", "day_ind",
        "event_type", "event_subtype",
        "currency",
        "src_type11", "src_type12", "dst_type11", "dst_type12",
        "src_type22", "src_type32",
    ],
    "geo": [
        "event_time",
        "mon_ind", "day_ind", "geohash_4",
        "count"
    ],
    "dial": [
        "event_time",
        "mon_ind", "day_ind",
        "embedding"
    ],
}

In [29]:
# Validation transactions: only events before the target month, only validation clients.
source = "./data/trx_test.parquet/"
val_cutoff = datetime.strptime(MONTHS[MON], "%Y-%m-%d").replace(day=1)
val_trx = pl.concat([
    pl.read_parquet(os.path.join(source, file))
    .filter(pl.col("event_time") < val_cutoff)
    .filter(pl.col("client_id").is_in(val_target["client_id"]))
    for file in tqdm(os.listdir(source))
])
val_trx = trx_unite(prepare_data(val_trx))
print(val_trx.shape)
val_trx.head()

  0%|          | 0/3 [00:00<?, ?it/s]

client_id,event_time,amount,event_type,event_subtype,currency,src_type11,src_type12,dst_type11,dst_type12,src_type21,src_type22,src_type31,src_type32,mon_ind,day_ind
str,f64,f32,i32,i32,i32,i32,i32,i32,i32,i32,i32,i32,i32,i64,i64
"""72165b9123a68d…",31525200.0,1363.118042,54,8,11,70,201,1171,23776,46241,15,1323,25,11,364
"""d4cdf7c3849920…",31525206.0,9923.888672,37,18,11,19,902,852,14606,11174,39,1474,39,11,364
"""72165b9123a68d…",31525211.0,936.750793,54,55,11,70,201,1171,23776,46241,15,1323,25,11,364
"""59594460508b42…",31525218.0,50.728539,40,51,11,72,280,433,10049,46241,15,623,81,11,364
"""7b26eb6a9135e7…",31525225.0,1924.053467,16,31,11,19,344,813,18723,30859,87,536,81,11,364


In [30]:
# Encode categories with the dictionaries fitted on the training data.
# NOTE: overwrites val_trx, so re-running this cell alone would transform already-encoded data.
val_trx = trx_preprocessor.transform(val_trx)

In [31]:
# Validation geo events: only events before the target month, only validation clients.
source = "./data/geo_test.parquet/"
val_cutoff = datetime.strptime(MONTHS[MON], "%Y-%m-%d").replace(day=1)
val_geo = pl.concat([
    pl.read_parquet(os.path.join(source, file))
    .filter(pl.col("event_time") < val_cutoff)
    .filter(pl.col("client_id").is_in(val_target["client_id"]))
    for file in tqdm(os.listdir(source))
])
val_geo = geo_unite(prepare_data(val_geo))
print(val_geo.shape)
val_geo.head()

  0%|          | 0/6 [00:00<?, ?it/s]

client_id,event_time,geohash_4,geohash_5,geohash_6,mon_ind,day_ind,count
str,i64,i32,i32,i32,i64,i64,i32
"""229c772c6cacec…",31525205,7168,295849,1978488,11,364,1
"""2379cc7ae1bbd4…",31525205,17154,206157,295130,11,364,1
"""ff9d467372a9bd…",31525217,21062,349208,2751538,11,364,1
"""b13ab03a891a17…",31525228,36984,254857,1980854,11,364,1
"""bf80fd4757f415…",31525232,5169,110999,2320023,11,364,1


In [32]:
# Encode categories with the dictionaries fitted on the training data.
# NOTE: overwrites val_geo, so re-running this cell alone would transform already-encoded data.
val_geo = geo_preprocessor.transform(val_geo)

In [33]:
# Validation dialogs: only events before the target month, only validation clients.
source = "./data/dial_test.parquet/"
val_cutoff = datetime.strptime(MONTHS[MON], "%Y-%m-%d").replace(day=1)
val_dial = pl.concat([
    pl.read_parquet(os.path.join(source, file))
    .filter(pl.col("event_time") < val_cutoff)
    .filter(pl.col("client_id").is_in(val_target["client_id"]))
    for file in tqdm(os.listdir(source))
])
val_dial = prepare_data(val_dial)
print(val_dial.shape)
val_dial.head()

  0%|          | 0/3 [00:00<?, ?it/s]

client_id,event_time,embedding,mon_ind,day_ind
str,i64,list[f32],i64,i64
"""07bbb9625426c5…",52392760,"[0.397706, -0.029813, … 0.379887]",19,606
"""07bbb9625426c5…",39437321,"[0.578899, -0.342521, … 0.563267]",15,456
"""07bbb9625426c5…",38833031,"[0.432345, -0.347777, … 0.467044]",14,449
"""07bbb9625426c5…",49548823,"[-0.06783, 0.117739, … -0.17042]",18,573
"""07bbb9625426c5…",37116263,"[-0.067204, 0.386104, … -0.030214]",14,429


In [34]:
# Encode categories with the dictionaries fitted on the training data.
# NOTE: overwrites val_dial, so re-running this cell alone would transform already-encoded data.
val_dial = dial_preprocessor.transform(val_dial)

In [35]:
# Attach every modality's per-client sequences to the validation targets. A client missing
# from a modality gets nulls in that modality's columns (left joins); rows are then shuffled.
val_joined = (
    val_target
    .join(val_trx, on="client_id", how="left")
    .join(val_geo, on="client_id", how="left")
    .join(val_dial, on="client_id", how="left")
    .sample(fraction=1.0, shuffle=True, seed=56)
)
print(val_joined.shape)
val_joined.head()

(48877, 23)


client_id,target,trx_event_time,trx_mon_ind,trx_day_ind,trx_event_type,trx_event_subtype,trx_currency,trx_src_type11,trx_src_type12,trx_dst_type11,trx_dst_type12,trx_src_type22,trx_src_type32,geo_event_time,geo_mon_ind,geo_day_ind,geo_geohash_4,geo_count,dial_event_time,dial_mon_ind,dial_day_ind,dial_embedding
str,list[i32],list[f64],list[i32],list[i32],list[i32],list[i32],list[i32],list[i32],list[i32],list[i32],list[i32],list[i32],list[i32],list[i64],list[i32],list[i32],list[i32],list[i32],list[i64],list[i32],list[i32],list[list[f32]]
"""3084feefff86b9…","[0, 0, … 0]","[4.8260296e7, 4.8789453e7, … 6.0286706e7]","[2, 2, … 9]","[86, 59, … 333]","[6, 10, … 3]","[6, 9, … 3]","[1, 1, … 1]","[3, 4, … 1]","[8, 5, … 1]","[3, 8, … 3]","[3, 9, … 3]","[1, 1, … 1]","[30, 30, … 30]","[40039161, 40044710, … 60338760]","[6, 6, … 10]","[166, 166, … 330]","[1, 1, … 1]","[1, 1, … 1]","[47838417, 48357850, … 59741680]","[1, 1, … 5]","[16, 17, … 135]","[[0.265258, -0.03815, … 0.287905], [0.265258, -0.23815, … 0.287905], … [0.257742, -0.367889, … 0.215164]]"
"""0157292b98fdfe…","[0, 0, … 0]","[3.1537507e7, 3.1629219e7, … 6.0385744e7]","[11, 11, … 9]","[268, 269, … 334]","[3, 1, … 4]","[3, 1, … 4]","[1, 1, … 1]","[1, 1, … 1]","[2, 2, … 2]","[3, 1, … 5]","[3, 1, … 6]","[27, 27, … 27]","[41, 41, … 41]","[32457306, 32883580, … 60364780]","[11, 11, … 10]","[317, 307, … 330]","[27, 27, … 27]","[1, 1, … 1]",,,,
"""826f779f771dbb…","[0, 0, … 0]",,,,,,,,,,,,,,,,,,,,,
"""6c3bccb98c96ba…","[0, 0, … 1]","[3.5086603e7, 3.6179829e7, … 6.0274971e7]","[10, 10, … 9]","[280, 199, … 333]","[1, 1, … 1]","[1, 1, … 1]","[1, 1, … 1]","[1, 7, … 1]","[1, 10, … 1]","[2, 1, … 11]","[5, 1, … 36]","[44, 44, … 44]","[7, 7, … 7]",,,,,,"[43064481, 44199936, … 58534778]","[7, 7, … 5]","[224, 188, … 168]","[[0.471715, -0.239396, … 0.568909], [0.374588, 0.130431, … 0.382765], … [0.551841, -0.229372, … 0.522842]]"
"""20abfa068f3466…","[0, 0, … 0]","[3.167186e7, 3.1836522e7, … 6.0350241e7]","[11, 11, … 9]","[269, 290, … 334]","[1, 1, … 25]","[1, 1, … 26]","[1, 1, … 1]","[1, 1, … 1]","[1, 1, … 1]","[1, 1, … 4]","[1, 1, … 4]","[10, 10, … 10]","[1, 1, … 1]","[50090250, 50100403, … 60358746]","[1, 1, … 10]","[16, 16, … 330]","[37, 20, … 37]","[1, 2, … 2]","[47824048, 50572383, … 57168259]","[1, 2, … 4]","[16, 59, … 146]","[[0.345506, -0.376908, … 0.310672], [0.065283, 0.142979, … -0.190615], … [0.231746, -0.141705, … 0.38085]]"


In [36]:
# Convert the joined frame into per-client records; dialog embeddings are 768-dim vectors.
val_dict = to_records(val_joined, embedding_col="dial_embedding", embedding_dim=768)

In [37]:
# The dataset gets each client's target as a plain Python list
# (to_records presumably leaves an array with .tolist()).
for record in val_dict:
    record["target"] = record["target"].tolist()

In [38]:
# Validation dataset built from the per-client records: modality features come from
# source_features, events are ordered by event_time, the label is the "target" field.
val_multimodal_data = MultiModalSupervisedIterableDataset(
    data=val_dict,
    target_name="target",
    col_id="client_id",
    col_time="event_time",
    source_features=source_features,
    source_names=source_features.keys(),
)

In [39]:
# Drop the notebook-level references to the validation intermediates. The records list
# presumably stays alive through val_multimodal_data (it was passed as `data=`), so this
# mainly frees the joined and per-modality frames.
del val_dict, val_joined, val_dial, val_geo, val_trx

# Model

## SeqEncoder, it will combine our three TrxEncoders

In [40]:
# Per-modality event encoders feed one time-sorted sequence into a GRU of width HIDDEN_SIZE.
seq_encoder = MultiModalSortTimeSeqEncoderContainer(
    trx_encoders=dict(trx=trx_encoder, geo=geo_encoder, dial=dial_encoder),
    seq_encoder_cls=RnnEncoder,
    type="gru",
    input_size=HIDDEN_SIZE,
    hidden_size=HIDDEN_SIZE,
)

## Head

используем BCEWithLogitsLoss, поэтому в конце нет Sigmoid-ы

In [41]:
# Two-layer MLP producing one logit per target. There is no final Sigmoid because
# training uses BCEWithLogitsLoss.
HEAD_HIDDEN = HIDDEN_SIZE // 2
head = nn.Sequential(
    nn.Linear(HIDDEN_SIZE, HEAD_HIDDEN),
    nn.ReLU(),
    nn.Linear(HEAD_HIDDEN, NUM_CLASSES),
)

## Framework

In [42]:
# AdamW at a fixed learning rate. ConstantLR with factor=1.0 never changes the rate:
# it is only a placeholder scheduler.
optimizer_factory = partial(torch.optim.AdamW, lr=1e-4)
scheduler_factory = partial(torch.optim.lr_scheduler.ConstantLR, factor=1.0)
sup_module = SequenceToTarget(
    seq_encoder=seq_encoder,
    head=head,
    loss=nn.BCEWithLogitsLoss(),  # one independent binary target per output column
    metric_list=AUROC(NUM_CLASSES),
    optimizer_partial=optimizer_factory,
    lr_scheduler_partial=scheduler_factory,
)

In [43]:
# Parameter counts per submodule.
summary(sup_module)

Layer (type:depth-idx)                                  Param #
SequenceToTarget                                        --
├─MultiModalSortTimeSeqEncoderContainer: 1-1            --
│    └─ModuleDict: 2-1                                  --
│    │    └─Sequential: 3-1                             67,670
│    │    └─Sequential: 3-2                             556,880
│    │    └─Sequential: 3-3                             114,382
│    └─RnnEncoder: 2-2                                  128
│    │    └─GRU: 3-4                                    99,072
│    │    └─LastStepEncoder: 3-5                        --
├─Sequential: 1-2                                       --
│    └─Linear: 2-3                                      8,256
│    └─ReLU: 2-4                                        --
│    └─Linear: 2-5                                      260
├─BCEWithLogitsLoss: 1-3                                --
├─ModuleDict: 1-4                                       --
│    └─AUROC: 2-6           

## Initialize the weights

In [44]:
# Re-initialise every submodule of the modality encoders, then of the head, passing
# 1/sqrt(HIDDEN_SIZE) to init_weights. The sequence encoder (GRU) keeps its own init.
init_scale = 1 / HIDDEN_SIZE ** 0.5
for root in (sup_module.seq_encoder.trx_encoders, sup_module.head):
    for module in root.modules():
        init_weights(module, init_scale)

# Train model

In [45]:
# Weights & Biases logger with default project/run settings, reused by every training Trainer.
logger = WandbLogger()

In [46]:
EPOCHS: int = 2  # full passes over all training chunks

каждую эпоху обучаем на каждом из чанков последовательно

In [None]:
%%time

trx_source = "./data/chunks/trx_train.parquet/"
geo_source = "./data/chunks/geo_train.parquet/"
dial_source = "./data/chunks/dial_train.parquet/"
chunk_size = (len(train_clients) + N_CHUNKS - 1) // N_CHUNKS
for epoch in tqdm(range(EPOCHS)):
    print(f"Epoch: {epoch}")
    trainer = None
    for i in tqdm(range(N_CHUNKS)):
        # клиенты, которые попали в чанк и есть в обучающей выборке
        cur_clients = set(train_clients[i * chunk_size:(i + 1) * chunk_size]) & set(train_target["client_id"])

        train_trx = pl.read_parquet(os.path.join(trx_source, f"part-{i}.parquet"))
        train_trx = train_trx.filter(pl.col("event_time") < datetime.strptime(MONTHS[MON], "%Y-%m-%d").replace(day=1))
        train_trx = train_trx.filter(pl.col("client_id").is_in(cur_clients))
        train_trx = prepare_data(train_trx)
        train_trx = trx_unite(train_trx)
        train_trx = trx_preprocessor.transform(train_trx)

        train_geo = pl.read_parquet(os.path.join(geo_source, f"part-{i}.parquet"))
        train_geo = train_geo.filter(pl.col("event_time") < datetime.strptime(MONTHS[MON], "%Y-%m-%d").replace(day=1))
        train_geo = train_geo.filter(pl.col("client_id").is_in(cur_clients))
        train_geo = prepare_data(train_geo)
        train_geo = geo_unite(train_geo)
        train_geo = geo_preprocessor.transform(train_geo)

        train_dial = pl.read_parquet(os.path.join(dial_source, f"part-{i}.parquet"))
        train_dial = train_dial.filter(pl.col("event_time") < datetime.strptime(MONTHS[MON], "%Y-%m-%d").replace(day=1))
        train_dial = train_dial.filter(pl.col("client_id").is_in(cur_clients))
        train_dial = prepare_data(train_dial)
        train_dial = dial_preprocessor.transform(train_dial)

        cur_train_target = train_target.filter(pl.col("client_id").is_in(cur_clients))
        train_joined = cur_train_target.join(train_trx, on="client_id", how="left")
        train_joined = train_joined.join(train_geo, on="client_id", how="left")
        train_joined = train_joined.join(train_dial, on="client_id", how="left")

        train_dict = to_records(train_joined, embedding_col="dial_embedding", embedding_dim=768)
        for j in range(len(train_dict)):
            train_dict[j]["target"] = train_dict[j]["target"].tolist()

        train_multimodal_data = MultiModalSupervisedIterableDataset(
            data=train_dict,
            source_features=source_features,
            source_names=source_features.keys(),
            col_id="client_id",
            col_time="event_time",
            target_name="target",
        )

        sup_data = PtlsDataModule(
            train_data=train_multimodal_data,
            train_num_workers=0,
            train_batch_size=32,
            valid_data=val_multimodal_data,
            valid_num_workers=0,
            valid_batch_size=32,
        )

        trainer = pytorch_lightning.Trainer(
            logger=logger,
            max_epochs=2,  # для ускорения сделаем несколько прогонов
            accelerator="gpu",
            devices=1,
            enable_progress_bar=True,
        )

        trainer.fit(sup_module, sup_data)

        del sup_data, train_multimodal_data, train_dict, train_joined, train_dial, train_geo, train_trx
        gc.collect()

    print(trainer.callback_metrics['valid/AUROC'].item())

# Inference

In [None]:
class LogitsPredictor(pytorch_lightning.LightningModule):
    """Adapter that lets ``Trainer.predict`` run a trained ``SequenceToTarget``.

    The supervised dataloader yields batches that are presumably ``(features, target)``
    pairs (TODO confirm). Lightning's default ``predict_step`` calls ``forward(batch)`` with
    the whole pair, so the target is dropped here. The return value is the model's raw
    logits: one column per class, no sigmoid.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        features, _ = batch
        return self.model(features)


sup_data = PtlsDataModule(
    valid_data=val_multimodal_data,
    valid_num_workers=0,
    valid_batch_size=64,
)
trainer = pytorch_lightning.Trainer(
    accelerator="gpu",
    devices=1,
    logger=False,  # prediction only: do not create a new log directory
    enable_progress_bar=True,
)
prediction = trainer.predict(LogitsPredictor(sup_module), sup_data.val_dataloader())

In [None]:
# `prediction` is presumably a list of per-batch logit tensors of shape (batch, NUM_CLASSES).
# pandas cannot concatenate tensors, so join them with torch first. The head has no final
# sigmoid (training uses BCEWithLogitsLoss), so apply it here to get probabilities.
# A new name keeps this cell idempotent.
val_logits = torch.cat([batch.detach().cpu() for batch in prediction])
val_prediction = pd.DataFrame(
    torch.sigmoid(val_logits).float().numpy(),
    columns=[f"target_{i}" for i in range(1, NUM_CLASSES + 1)],
)
# NOTE(review): rows follow the validation dataset's iteration order. Client ids are not part
# of the prediction output and must be re-attached from that order if needed.
val_prediction