# Libraries

In [1]:
import warnings

warnings.filterwarnings("ignore")

In [2]:
import gc
import os

from datetime import datetime
from functools import partial
from tqdm.auto import tqdm

import polars as pl

import torch
import pytorch_lightning
from torch import nn
from torchinfo import summary
from pytorch_lightning.loggers import WandbLogger

from ptls.nn import TrxEncoder, PBL2Norm, PBLayerNorm, RnnSeqEncoder
from ptls.frames.supervised import SeqToTargetDataset, SequenceToTarget
from ptls.frames import PtlsDataModule

from src.config import MONTHS, N_CHUNKS
from src.preprocessing import PolarsDataPreprocessor
from src.utils import prepare_data, to_records, init_weights
from src.metric import AUROC
from build_chunks import train_clients

In [3]:
# Fix polars' sampling RNG, then Python/NumPy/torch through Lightning's helper.
pl.set_random_seed(seed=56)
pytorch_lightning.seed_everything(seed=56)

Global seed set to 56


56

# Data

целевой месяц — индекс в `MONTHS`; в признаки попадают только транзакции до первого числа этого месяца

In [4]:
MON: int = 10  # index into MONTHS selecting the target month

объединим подряд идущие одинаковые действия клиента в одно (суммарная `amount`, средний `event_time`)

In [5]:
def unite(data):
    """Collapse each client's runs of identical consecutive transactions into one row.

    Rows are ordered by ``event_time``; within a ``client_id`` a new run starts
    whenever any of the key columns differs from that client's previous row.
    Every run becomes a single row holding the mean ``event_time``, the summed
    ``amount`` and the first value of each remaining listed column.

    Args:
        data: polars DataFrame containing ``client_id``, ``event_time``,
            ``amount`` and the categorical columns referenced below.

    Returns:
        polars DataFrame with one row per run: ``client_id`` followed by the
        aggregated columns (the helper ``group_id`` is dropped).

    Note:
        Nulls in the key columns are replaced with -1 before comparing, so a
        null -> value transition starts a new run; the -1 also stays in the output.
    """
    # Columns whose change (within a client) starts a new run.
    key_cols = [
        "day_ind", "event_type", "event_subtype", "currency",
        "src_type11", "src_type12", "dst_type11", "dst_type12",
    ]
    # Columns for which the first value of each run is kept.
    first_cols = [
        "event_type", "event_subtype", "currency",
        "src_type11", "src_type12", "dst_type11", "dst_type12",
        "src_type21", "src_type22", "src_type31", "src_type32",
        "mon_ind", "day_ind",
    ]
    # NOTE(review): the sort may not be stable, so a client's transactions with equal
    # event_time could come out in either order — TODO confirm if determinism matters.
    data = data.sort("event_time").with_columns(
        [pl.col(col).fill_null(-1) for col in key_cols]
    )
    # True where any key column differs from the same client's previous row.
    # A client's first row has no predecessor: the null comparison counts as False.
    data = data.with_columns(
        flag=pl.any_horizontal(
            [(pl.col(col) != pl.col(col).shift(1).over("client_id")).fill_null(False) for col in key_cols]
        )
    )
    # Running count of run starts per client acts as the run index.
    data = data.with_columns(group_id=pl.col("flag").cumsum().over("client_id"))
    return (
        data
        .group_by(("client_id", "group_id"), maintain_order=True)
        .agg(
            pl.col("event_time").mean(),  # mean timestamp of the merged transactions
            pl.col("amount").sum(),  # total amount of the merged transactions
            *[pl.col(col).first() for col in first_cols],
        )
        .drop("group_id")
    )

## Prepare data preprocessor

`src_type21`, `src_type22`, `src_type31`, `src_type32` — одинаковые для одного пользователя, при этом `src_type22` и `src_type32` значительно улучшают качество, а `src_type21` и `src_type31` — ухудшают

In [6]:
# Preprocessor for the transaction sequences: groups rows by client, orders them by
# event time and encodes the listed categorical columns.
trx_preprocessor = PolarsDataPreprocessor(
    cols_category=[
        # calendar
        "mon_ind",
        "day_ind",
        # event kind
        "event_type",
        "event_subtype",
        "currency",
        # per-transaction source / destination types
        "src_type11",
        "src_type12",
        "dst_type11",
        "dst_type12",
        # per-client source types (src_type21 / src_type31 deliberately excluded)
        "src_type22",
        "src_type32",
    ],
    col_event_time="event_time",
    col_id="client_id",
)

In [7]:
# Labels for the target month: the four binary targets packed into one list column.
train_target = (
    pl.read_parquet("./data/train_target.parquet/*")
    .filter(pl.col("mon") == MONTHS[MON])
    .with_columns(target=pl.concat_list([f"target_{k}" for k in (1, 2, 3, 4)]))
    .select(("client_id", "target"))
)
train_target.head()

client_id,target
str,list[i32]
"""1d4ebf30ab5b98…","[0, 0, … 0]"
"""1d55174bce3ef4…","[0, 0, … 0]"
"""1d5d052f87d6bd…","[0, 0, … 0]"
"""1d68b588164639…","[0, 0, … 0]"
"""1d817e82e1cc59…","[0, 0, … 0]"


считаем счетчики для категориальных признаков

In [8]:
%%time

source = "./data/chunks/trx_train.parquet/"
for file in tqdm(os.listdir(source)):
    train_trx = pl.read_parquet(os.path.join(source, file))
    train_trx = train_trx.filter(pl.col("event_time") < datetime.strptime(MONTHS[MON], "%Y-%m-%d").replace(day=1))
    train_trx = train_trx.filter(pl.col("client_id").is_in(train_target["client_id"]))
    train_trx = prepare_data(train_trx)
    train_trx = unite(train_trx)
    trx_preprocessor.fit(train_trx)

  0%|          | 0/8 [00:00<?, ?it/s]

создаем словари

In [9]:
# Build the category dictionaries from what fit() gathered over all training chunks.
trx_preprocessor.freeze()

<src.preprocessing.PolarsDataPreprocessor at 0x2283dcbdca0>

## Prepare val data

In [10]:
# Validation labels for the target month, de-duplicated, targets packed into a list.
val_target = (
    pl.read_parquet("./data/test_target_b.parquet/*")
    .unique()
    .filter(pl.col("mon") == MONTHS[MON])
    .with_columns(target=pl.concat_list([f"target_{k}" for k in (1, 2, 3, 4)]))
    .select(("client_id", "target"))
)
val_target.head()

client_id,target
str,list[i32]
"""581cdd678867bb…","[0, 0, … 0]"
"""b32e9c14b572f2…","[0, 0, … 0]"
"""12d1d490206e23…","[0, 0, … 0]"
"""75b2572752044c…","[0, 0, … 0]"
"""1ae27bb78f3530…","[0, 0, … 0]"


In [11]:
source = "./data/trx_test.parquet/"
# Same cutoff as for training: only transactions before the target month.
val_cutoff = datetime.strptime(MONTHS[MON], "%Y-%m-%d").replace(day=1)
val_parts = []
# sorted(): deterministic file order (os.listdir order is arbitrary).
for file in tqdm(sorted(os.listdir(source))):
    part = pl.read_parquet(os.path.join(source, file))
    part = part.filter(pl.col("event_time") < val_cutoff)
    part = part.filter(pl.col("client_id").is_in(val_target["client_id"]))
    val_parts.append(part)
val_trx = pl.concat(val_parts)
del val_parts  # release the per-file frames once concatenated
val_trx = prepare_data(val_trx)
val_trx = unite(val_trx)
print(val_trx.shape)
val_trx.head()

  0%|          | 0/3 [00:00<?, ?it/s]

client_id,event_time,amount,event_type,event_subtype,currency,src_type11,src_type12,dst_type11,dst_type12,src_type21,src_type22,src_type31,src_type32,mon_ind,day_ind
str,f64,f32,i32,i32,i32,i32,i32,i32,i32,i32,i32,i32,i32,i64,i64
"""72165b9123a68d…",31525200.0,1363.118042,54,8,11,70,201,1171,23776,46241,15,1323,25,11,364
"""d4cdf7c3849920…",31525206.0,9923.888672,37,18,11,19,902,852,14606,11174,39,1474,39,11,364
"""72165b9123a68d…",31525211.0,936.750793,54,55,11,70,201,1171,23776,46241,15,1323,25,11,364
"""59594460508b42…",31525218.0,50.728539,40,51,11,72,280,433,10049,46241,15,623,81,11,364
"""7b26eb6a9135e7…",31525225.0,1924.053467,16,31,11,19,344,813,18723,30859,87,536,81,11,364


In [12]:
# Encode categories, attach targets (the outer join keeps rows from both sides),
# then shuffle rows with a fixed seed before converting to records.
# NOTE: overwrites val_trx, so re-running this cell would re-transform encoded data.
val_trx = (
    trx_preprocessor.transform(val_trx)
    .join(val_target, on="client_id", how="outer")
    .sample(fraction=1.0, shuffle=True, seed=56)
)
val_dict = to_records(val_trx)
len(val_dict)

48877

In [13]:
# Validation dataset; targets are cast to float32 because BCEWithLogitsLoss expects float labels.
val_data = SeqToTargetDataset(val_dict, target_col_name="target", target_dtype=torch.float32)

# Model

модель состоит из трех частей:

    1) кодировщик отдельной транзакции
    2) кодировщик всей последовательности
    3) классификационная голова

In [14]:
NUM_CLASSES: int = 4  # number of binary targets (target_1 ... target_4)
HIDDEN_SIZE: int = 128  # width of the transaction embedding and of the RNN state

## TrxEncoder

In [15]:
# Vocabulary size of each categorical feature after freeze(); these must match the
# "in" sizes of the TrxEncoder embeddings.
trx_preprocessor.get_category_dictionary_sizes()

{'mon_ind': 13,
 'day_ind': 336,
 'event_type': 56,
 'event_subtype': 58,
 'currency': 17,
 'src_type11': 76,
 'src_type12': 326,
 'dst_type11': 77,
 'dst_type12': 387,
 'src_type22': 89,
 'src_type32': 90}

In [16]:
# Output width of the embedding for every categorical feature (order kept identical
# to the original hand-written embeddings dict).
trx_embedding_dims = {
    "mon_ind": 6,
    "day_ind": 32,

    "event_type": 16,
    "event_subtype": 16,

    "currency": 8,

    "src_type11": 16,
    "src_type12": 32,
    "dst_type11": 16,
    "dst_type12": 32,

    "src_type22": 16,
    "src_type32": 16,
}
# Vocabulary sizes are taken from the fitted preprocessor instead of being copied by
# hand, so they cannot drift out of sync with the data (a missing feature raises KeyError).
category_sizes = trx_preprocessor.get_category_dictionary_sizes()
trx_base_encoder = TrxEncoder(
    embeddings={
        name: {"in": category_sizes[name], "out": dim}
        for name, dim in trx_embedding_dims.items()
    },
    # must equal the width normalised by PBLayerNorm below and the RNN input_size
    linear_projection_size=HIDDEN_SIZE,
)
# normalising the embeddings is important
trx_encoder = nn.Sequential(trx_base_encoder, PBL2Norm(), PBLayerNorm(HIDDEN_SIZE))

## SeqEncoder

In [17]:
# Sequence encoder: applies trx_encoder to every transaction, then an RNN over the
# sequence (a GRU with a last-step readout, per the model summary).
seq_encoder = RnnSeqEncoder(
    trx_encoder=trx_encoder,
    input_size=HIDDEN_SIZE,  # equals the TrxEncoder projection width
    hidden_size=HIDDEN_SIZE,
)

## Head

используем BCEWithLogitsLoss, который сам применяет сигмоиду к логитам, поэтому в конце головы нет Sigmoid-ы

In [18]:
# Two-layer MLP producing one logit per target; no final Sigmoid because
# BCEWithLogitsLoss applies it internally.
head = nn.Sequential(
    *(
        nn.Linear(HIDDEN_SIZE, HIDDEN_SIZE // 2),
        nn.ReLU(),
        nn.Linear(HIDDEN_SIZE // 2, NUM_CLASSES),
    )
)

## Framework

In [19]:
# Supervised multi-label model: sequence encoder + head, trained with BCE on logits
# and monitored with AUROC over the NUM_CLASSES targets.
sup_module = SequenceToTarget(
    seq_encoder=seq_encoder,
    head=head,
    loss=nn.BCEWithLogitsLoss(),
    metric_list=AUROC(NUM_CLASSES),
    optimizer_partial=partial(torch.optim.AdamW, lr=1e-4),
    lr_scheduler_partial=partial(torch.optim.lr_scheduler.ConstantLR, factor=1.0),  # placeholder: factor=1.0 keeps the LR unchanged
)

In [20]:
# Parameter count per submodule (torchinfo).
summary(sup_module)

Layer (type:depth-idx)                             Param #
SequenceToTarget                                   --
├─RnnSeqEncoder: 1-1                               --
│    └─Sequential: 2-1                             --
│    │    └─TrxEncoder: 3-1                        67,414
│    │    └─PBShell: 3-2                           --
│    │    └─PBShell: 3-3                           256
│    └─RnnEncoder: 2-2                             128
│    │    └─GRU: 3-4                               99,072
│    │    └─LastStepEncoder: 3-5                   --
├─Sequential: 1-2                                  --
│    └─Linear: 2-3                                 8,256
│    └─ReLU: 2-4                                   --
│    └─Linear: 2-5                                 260
├─BCEWithLogitsLoss: 1-3                           --
├─ModuleDict: 1-4                                  --
│    └─AUROC: 2-6                                  --
│    │    └─MultilabelAUROC: 3-6                   --
├─ModuleD

## Initialize the weights

In [21]:
# Re-initialise every submodule of the transaction encoder and of the head with
# init_weights (second argument 1/sqrt(HIDDEN_SIZE), presumably a scale/std — TODO confirm).
# The RNN inside seq_encoder is not touched.
for module in [*sup_module.seq_encoder.trx_encoder.modules(), *sup_module.head.modules()]:
    init_weights(module, 1 / HIDDEN_SIZE ** 0.5)

# Train model

In [22]:
# Weights & Biases logger with default settings (no project/run name given);
# presumably requires an authenticated wandb session — TODO confirm.
logger = WandbLogger()

In [23]:
EPOCHS: int = 2  # outer passes over all training chunks

каждую эпоху обучаем последовательно на каждом из чанков: для каждого чанка создается новый `Trainer`, который делает `max_epochs=2` прохода по чанку

In [24]:
%%time

source = "./data/chunks/trx_train.parquet/"
chunk_size = (len(train_clients) + N_CHUNKS - 1) // N_CHUNKS
for epoch in tqdm(range(EPOCHS)):
    print(f"Epoch: {epoch}")
    trainer = None
    for i in tqdm(range(N_CHUNKS)):
        # клиенты, которые попали в чанк и есть в обучающей выборке
        cur_clients = set(train_clients[i * chunk_size:(i + 1) * chunk_size]) & set(train_target["client_id"])

        train_trx = pl.read_parquet(os.path.join(source, f"part-{i}.parquet"))
        train_trx = train_trx.filter(pl.col("event_time") < datetime.strptime(MONTHS[MON], "%Y-%m-%d").replace(day=1))
        train_trx = train_trx.filter(pl.col("client_id").is_in(cur_clients))
        train_trx = prepare_data(train_trx)
        train_trx = unite(train_trx)

        train_trx = trx_preprocessor.transform(train_trx)
        train_trx = train_trx.join(
            train_target.filter(pl.col("client_id").is_in(cur_clients)), on="client_id", how="outer"
        )
        train_dict = to_records(train_trx)

        train_data = SeqToTargetDataset(train_dict, target_col_name="target", target_dtype=torch.float32)

        sup_data = PtlsDataModule(
            train_data=train_data,
            train_num_workers=0,
            train_batch_size=256,
            valid_data=val_data,
            valid_num_workers=0,
            valid_batch_size=256,
        )

        trainer = pytorch_lightning.Trainer(
            logger=logger,
            max_epochs=2,  # для ускорения сделаем несколько прогонов
            accelerator="gpu",
            devices=1,
            enable_progress_bar=True,
        )

        trainer.fit(sup_module, sup_data)

        del sup_data, train_data, train_dict, train_trx
        gc.collect()

    print(trainer.callback_metrics['valid/AUROC'].item())

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

0.8249