# Libraries

In [1]:
import warnings

# Suppress every Python warning for the rest of the session (no category or module filter is
# given). NOTE(review): this also hides deprecation warnings from polars / Lightning.
warnings.filterwarnings("ignore")

In [2]:
import gc
import os

from datetime import datetime
from functools import partial
from tqdm.auto import tqdm

import polars as pl

import torch
import pytorch_lightning
from torch import nn
from torchinfo import summary
from pytorch_lightning.loggers import WandbLogger

from ptls.nn import TrxEncoder, PBL2Norm, PBLayerNorm, RnnSeqEncoder
from ptls.nn.trx_encoder.encoders import IdentityEncoder
from ptls.frames.supervised import SeqToTargetDataset, SequenceToTarget
from ptls.frames import PtlsDataModule

from src.config import MONTHS, N_CHUNKS
from src.preprocessing import PolarsDataPreprocessor
from src.utils import prepare_data, to_records, init_weights
from src.metric import AUROC
from build_chunks import train_clients

In [3]:
# Fix the random seeds (polars and Lightning's global seed) so that sampling, shuffling and
# weight initialisation are repeatable. The cell echoes the seed returned by seed_everything.
pl.set_random_seed(56)
pytorch_lightning.seed_everything(56)

Global seed set to 56


56

# Data

целевой месяц

In [4]:
# Index of the target month inside MONTHS (src.config). Targets are taken for MONTHS[MON] and
# only events strictly before the first day of that month are used as model input.
MON = 10

## Prepare data preprocessor

In [5]:
# Preprocessor for the dialog source. It is fitted chunk by chunk and frozen further down,
# then reused to transform both the validation and the training frames.
dial_preprocessor = PolarsDataPreprocessor(
    col_id="client_id",  # one sequence per client
    col_event_time="event_time",
    cols_category=["mon_ind", "day_ind"],  # encoded through fitted category dictionaries
    cols_numerical=["embedding"],  # the dialog embedding vector (width 768 is assumed later on)
)

In [6]:
# Training labels for the target month: keep only rows of MONTHS[MON], pack the four
# target_1..target_4 columns into one list column and drop everything except the client id.
train_target = (
    pl.read_parquet("./data/train_target.parquet/*")
    .filter(pl.col("mon") == MONTHS[MON])
    .with_columns(target=pl.concat_list([f"target_{k}" for k in range(1, 5)]))
    .select(["client_id", "target"])
)
train_target.head()

client_id,target
str,list[i32]
"""1d4ebf30ab5b98…","[0, 0, … 0]"
"""1d55174bce3ef4…","[0, 0, … 0]"
"""1d5d052f87d6bd…","[0, 0, … 0]"
"""1d68b588164639…","[0, 0, … 0]"
"""1d817e82e1cc59…","[0, 0, … 0]"


считаем счетчики для категориальных признаков

In [7]:
%%time

# Stream the training chunks one file at a time and let the preprocessor accumulate its
# category statistics; only pre-target-month events of clients that have a label are used.
source = "F:/chunks/dial_train.parquet/"
for file in tqdm(os.listdir(source)):
    train_dial = prepare_data(
        pl.read_parquet(os.path.join(source, file))
        # keep events strictly before the first day of the target month
        .filter(pl.col("event_time") < datetime.strptime(MONTHS[MON], "%Y-%m-%d").replace(day=1))
        # keep clients present in the training target
        .filter(pl.col("client_id").is_in(train_target["client_id"]))
    )
    dial_preprocessor.fit(train_dial)

  0%|          | 0/8 [00:00<?, ?it/s]

создаем словари

In [8]:
# Finalise the preprocessor after all chunks were fitted. Per the note above this is where the
# category dictionaries are built from the accumulated counts (implementation lives in
# src.preprocessing). The call returns the preprocessor itself, hence the repr in the output.
dial_preprocessor.freeze()

<src.preprocessing.PolarsDataPreprocessor at 0x20e8afc4190>

## Prepare val data

In [9]:
# Validation labels: de-duplicate the raw rows, keep the target month, pack the four
# target_1..target_4 columns into one list column and keep only client id + target.
val_target = (
    pl.read_parquet("./data/test_target_b.parquet/*")
    .unique()
    .filter(pl.col("mon") == MONTHS[MON])
    .with_columns(target=pl.concat_list([f"target_{k}" for k in range(1, 5)]))
    .select(["client_id", "target"])
)
val_target.head()

client_id,target
str,list[i32]
"""fa3cb84514a7a5…","[0, 0, … 0]"
"""f2d13658b42afc…","[0, 0, … 0]"
"""23edaf4a96e350…","[0, 0, … 0]"
"""3ec96190b94b6e…","[0, 0, … 0]"
"""72794a9d34ce6c…","[0, 0, … 0]"


In [10]:
# Load the validation dialogs file by file, keeping only pre-target-month events of clients
# that have a validation label, then concatenate and run the shared prepare_data step.
source = "F:/dial_test.parquet/"
val_dial = []
for file in tqdm(os.listdir(source)):
    cur = (
        pl.read_parquet(os.path.join(source, file))
        # events strictly before the first day of the target month
        .filter(pl.col("event_time") < datetime.strptime(MONTHS[MON], "%Y-%m-%d").replace(day=1))
        # clients present in the validation target
        .filter(pl.col("client_id").is_in(val_target["client_id"]))
    )
    val_dial.append(cur)
val_dial = prepare_data(pl.concat(val_dial))
print(val_dial.shape)
val_dial.head()

  0%|          | 0/3 [00:00<?, ?it/s]

client_id,event_time,embedding,mon_ind,day_ind
str,i64,list[f32],i64,i64
"""07bbb9625426c5…",52392760,"[0.397706, -0.029813, … 0.379887]",19,606
"""07bbb9625426c5…",39437321,"[0.578899, -0.342521, … 0.563267]",15,456
"""07bbb9625426c5…",38833031,"[0.432345, -0.347777, … 0.467044]",14,449
"""07bbb9625426c5…",49548823,"[-0.06783, 0.117739, … -0.17042]",18,573
"""07bbb9625426c5…",37116263,"[-0.067204, 0.386104, … -0.030214]",14,429


In [11]:
# Encode the validation dialogs with the frozen preprocessor.
val_dial = dial_preprocessor.transform(val_dial)
# Outer join attaches the labels and also keeps clients that appear on one side only.
val_dial = val_dial.join(val_target, on="client_id", how="outer")
# Deterministic shuffle of the rows (fraction=1.0 keeps all of them).
val_dial = val_dial.sample(fraction=1.0, shuffle=True, seed=56)
# Convert to the record format consumed by the ptls datasets.
val_dict = to_records(val_dial, embedding_dim=768)
len(val_dict)

48877

In [12]:
# Supervised validation dataset: the "target" field of each record is the label, cast to
# float32 because the model is trained with BCEWithLogitsLoss. Reused for every training chunk.
val_data = SeqToTargetDataset(val_dict, target_col_name="target", target_dtype=torch.float32)

# Model

модель состоит из трех частей:

    1) кодировщик отдельной транзакции
    2) кодировщик всей последовательности
    3) классификационная голова

In [13]:
# Width shared by the event-encoder projection, the RNN state and the head input.
HIDDEN_SIZE = 128
# Number of binary labels per client (target_1 .. target_4 packed into the `target` list).
NUM_CLASSES = 4

## TrxEncoder

In [14]:
# Inspect the fitted dictionary sizes of the categorical features; the embedding tables of
# the event encoder must have exactly these input sizes.
dial_preprocessor.get_category_dictionary_sizes()

{'mon_ind': 13, 'day_ind': 336}

In [15]:
# Per-event encoder: small learned embeddings for the two categorical features, the 768-d
# dialog embedding passed through unchanged, all projected to HIDDEN_SIZE.
#
# The embedding-table input sizes are read from the fitted preprocessor instead of being
# hard-coded, so they cannot get out of sync with the dictionaries when MON or the data change
# (a table that is too small would fail with an out-of-range index during training).
category_sizes = dial_preprocessor.get_category_dictionary_sizes()

dial_base_encoder = TrxEncoder(
    embeddings={
        "mon_ind": {"in": category_sizes["mon_ind"], "out": 6},
        "day_ind": {"in": category_sizes["day_ind"], "out": 32},
    },
    custom_embeddings={
        "embedding": IdentityEncoder(768),
    },
    use_batch_norm=False,
    linear_projection_size=HIDDEN_SIZE,
)
# It is important to normalise the encoded events: L2 norm followed by layer norm.
dial_encoder = nn.Sequential(dial_base_encoder, PBL2Norm(), PBLayerNorm(HIDDEN_SIZE))

## SeqEncoder

In [16]:
# Sequence-level encoder: runs an RNN over the per-event vectors produced by dial_encoder.
# (The model summary printed further down lists a GRU followed by a LastStepEncoder.)
seq_encoder = RnnSeqEncoder(
    trx_encoder=dial_encoder,
    input_size=HIDDEN_SIZE,  # must equal the event encoder's linear_projection_size
    hidden_size=HIDDEN_SIZE,
)

## Head

используем BCEWithLogitsLoss, поэтому голова возвращает логиты: завершающей сигмоиды в конце нет

In [17]:
# Classification head: HIDDEN_SIZE -> HIDDEN_SIZE // 2 -> NUM_CLASSES.
# It emits raw logits (no final Sigmoid) because the loss is BCEWithLogitsLoss.
head = nn.Sequential(
    nn.Linear(HIDDEN_SIZE, HIDDEN_SIZE // 2),
    nn.ReLU(),
    nn.Linear(HIDDEN_SIZE // 2, NUM_CLASSES),
)

## Framework

In [18]:
# Lightning module tying everything together: sequence encoder + head, multi-label BCE loss on
# logits, AUROC over NUM_CLASSES labels as the validation metric, AdamW with a fixed LR of 1e-4.
sup_module = SequenceToTarget(
    seq_encoder=seq_encoder,
    head=head,
    loss=nn.BCEWithLogitsLoss(),
    metric_list=AUROC(NUM_CLASSES),
    optimizer_partial=partial(torch.optim.AdamW, lr=1e-4),
    lr_scheduler_partial=partial(torch.optim.lr_scheduler.ConstantLR, factor=1.0), # placeholder scheduler so that the LR never changes
)

In [19]:
# Print the layer tree with parameter counts as a quick sanity check of the architecture.
summary(sup_module)

Layer (type:depth-idx)                             Param #
SequenceToTarget                                   --
├─RnnSeqEncoder: 1-1                               --
│    └─Sequential: 2-1                             --
│    │    └─TrxEncoder: 3-1                        114,126
│    │    └─PBShell: 3-2                           --
│    │    └─PBShell: 3-3                           256
│    └─RnnEncoder: 2-2                             128
│    │    └─GRU: 3-4                               99,072
│    │    └─LastStepEncoder: 3-5                   --
├─Sequential: 1-2                                  --
│    └─Linear: 2-3                                 8,256
│    └─ReLU: 2-4                                   --
│    └─Linear: 2-5                                 260
├─BCEWithLogitsLoss: 1-3                           --
├─ModuleDict: 1-4                                  --
│    └─AUROC: 2-6                                  --
│    │    └─MultilabelAUROC: 3-6                   --
├─Module

## Initialize the weights

In [20]:
# Initialise the event encoder first and the head second, both with scale 1 / sqrt(HIDDEN_SIZE).
# What init_weights does for each module type is defined in src.utils.
for submodel in (sup_module.seq_encoder.trx_encoder, sup_module.head):
    for module in submodel.modules():
        init_weights(module, 1 / HIDDEN_SIZE ** 0.5)

# Train model

In [21]:
# Weights & Biases logger with default settings; the same instance is handed to every Trainer
# created in the training loop.
logger = WandbLogger()

In [22]:
# Number of outer passes over all training chunks. Each chunk is additionally fitted for the
# Trainer's own max_epochs inside the loop, so the effective passes per sample are higher.
EPOCHS = 1

каждую эпоху обучаем на каждом из чанков последовательно

In [23]:
%%time

# Chunked training: the training set does not fit in memory at once, so every outer epoch walks
# over the N_CHUNKS parquet parts, builds a dataset from one part and fits the shared
# `sup_module` on it with a fresh Trainer. Validation always uses the full `val_data`.
source = "F:/chunks/dial_train.parquet/"
# Ceil division: clients per chunk, matching how the chunk files are indexed below.
chunk_size = (len(train_clients) + N_CHUNKS - 1) // N_CHUNKS

# Loop invariants, computed once instead of once per chunk.
# First day of the target month: only strictly earlier events may be used as input.
cutoff = datetime.strptime(MONTHS[MON], "%Y-%m-%d").replace(day=1)
# Clients that have a label for the target month.
target_clients = set(train_target["client_id"])

for epoch in tqdm(range(EPOCHS)):
    print(f"Epoch: {epoch}")
    trainer = None
    for i in tqdm(range(N_CHUNKS)):
        # clients that belong to this chunk and are present in the training target
        cur_clients = set(train_clients[i * chunk_size:(i + 1) * chunk_size]) & target_clients

        train_dial = pl.read_parquet(os.path.join(source, f"part-{i}.parquet"))
        train_dial = train_dial.filter(pl.col("event_time") < cutoff)
        train_dial = train_dial.filter(pl.col("client_id").is_in(cur_clients))
        train_dial = prepare_data(train_dial)

        train_dial = dial_preprocessor.transform(train_dial)
        # Outer join attaches the labels and keeps labelled clients of this chunk even when
        # they have no dialog events before the cutoff.
        train_dial = train_dial.join(
            train_target.filter(pl.col("client_id").is_in(cur_clients)), on="client_id", how="outer"
        )
        train_dict = to_records(train_dial, embedding_dim=768)

        train_data = SeqToTargetDataset(train_dict, target_col_name="target", target_dtype=torch.float32)

        sup_data = PtlsDataModule(
            train_data=train_data,
            train_num_workers=0,
            train_batch_size=256,
            valid_data=val_data,
            valid_num_workers=0,
            valid_batch_size=256,
        )

        # NOTE(review): a new Trainer per chunk presumably re-creates the optimizer on each
        # fit, so AdamW statistics would not carry over between chunks - confirm this is intended.
        trainer = pytorch_lightning.Trainer(
            logger=logger,
            max_epochs=2,  # several passes over the chunk while it is loaded, to speed things up
            accelerator="gpu",
            devices=1,
            enable_progress_bar=True,
        )

        trainer.fit(sup_module, sup_data)

        # Release the chunk before the next one is loaded.
        del sup_data, train_data, train_dict, train_dial
        gc.collect()

    # Validation AUROC after the last chunk of this epoch; nothing to report if no chunk ran.
    if trainer is not None:
        print(trainer.callback_metrics['valid/AUROC'].item())

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/8 [00:00<?, ?it/s]

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

0.6834