In [1]:
from datetime import datetime

import polars as pl

import torch
import pytorch_lightning
from torch import nn

from src.preprocessing import PolarsDataPreprocessor
from src.utils import prepare_data

In [2]:
pl.set_random_seed(56)

In [3]:
pytorch_lightning.seed_everything(56, True)

56

In [4]:
months = [
    '2022-02-28',
    '2022-03-31',
    '2022-04-30',
    '2022-05-31',
    '2022-06-30',
    '2022-07-31',
    '2022-08-31',
    '2022-09-30',
    '2022-10-31',
    '2022-11-30',
    '2022-12-31',
    '2023-01-31',
]
month2id = dict((month, idx) for idx, month in enumerate(months))
id2month = dict((idx, month) for idx, month in enumerate(months))

In [5]:
mon = 10

In [6]:
train_trx = pl.read_parquet("F:/chunks/trx_train.pq")
train_trx = prepare_data(train_trx)
train_trx = train_trx.filter(pl.col("event_time") < datetime.strptime(months[mon], "%Y-%m-%d").replace(day=1))
print(train_trx.shape)
train_trx.head()

(14603731, 16)


event_time,amount,client_id,event_type,event_subtype,currency,src_type11,src_type12,dst_type11,dst_type12,src_type21,src_type22,src_type31,src_type32,mon_ind,day_ind
i64,f32,str,i32,i32,f64,f64,f64,f64,f64,f64,f64,f64,f64,i64,i64
35499473,39204.261719,"""52f416800d4d8c…",54,55,11.0,19.0,344.0,364.0,22652.0,22823.0,48.0,942.0,4.0,13,410
38062618,77238.382812,"""52f416800d4d8c…",54,55,11.0,19.0,344.0,869.0,31488.0,22823.0,48.0,942.0,4.0,14,440
35018297,14293.958008,"""52f416800d4d8c…",54,55,11.0,19.0,344.0,364.0,22652.0,22823.0,48.0,942.0,4.0,13,405
32446147,2569.062988,"""52f416800d4d8c…",54,55,11.0,19.0,344.0,364.0,22652.0,22823.0,48.0,942.0,4.0,12,375
34413708,62966.214844,"""52f416800d4d8c…",54,55,11.0,19.0,344.0,364.0,22652.0,22823.0,48.0,942.0,4.0,13,398


In [7]:
train_target = pl.read_parquet("F:/chunks/train_target.pq")
train_target = train_target.with_columns(
    target=pl.concat_list([f"target_{i}" for i in range(1, 5)])
)
train_target = train_target.select(("client_id", "target"))
train_target.head()

client_id,target
str,list[i32]
"""06a2ce26f19242…","[0, 0, … 0]"
"""06d250bda1fe78…","[0, 0, … 0]"
"""1d0fd54040602e…","[0, 0, … 0]"
"""85c233cac30252…","[0, 0, … 0]"
"""16d2d3fbdef66b…","[0, 0, … 0]"


In [8]:
val_trx = pl.read_parquet("F:/chunks/trx_val.pq")
val_trx = prepare_data(val_trx)
val_trx = val_trx.filter(pl.col("event_time") < datetime.strptime(months[mon], "%Y-%m-%d").replace(day=1))
print(val_trx.shape)
val_trx.head()

(8361985, 16)


event_time,amount,client_id,event_type,event_subtype,currency,src_type11,src_type12,dst_type11,dst_type12,src_type21,src_type22,src_type31,src_type32,mon_ind,day_ind
i64,f32,str,i32,i32,f64,f64,f64,f64,f64,f64,f64,f64,f64,i64,i64
40850365,183.638657,"""1b290c64a11658…",25,47,11.0,19.0,344.0,1166.0,30836.0,44170.0,85.0,724.0,70.0,15,472
33121383,418.706451,"""1b290c64a11658…",40,51,11.0,72.0,189.0,938.0,18481.0,44170.0,85.0,724.0,70.0,12,383
47772211,2831.543945,"""1b290c64a11658…",52,12,11.0,128.0,456.0,1302.0,8693.0,44170.0,85.0,724.0,70.0,18,552
43369115,5114.887695,"""1b290c64a11658…",40,51,11.0,72.0,682.0,433.0,10049.0,44170.0,85.0,724.0,70.0,16,501
45257498,202.14978,"""1b290c64a11658…",52,12,11.0,128.0,456.0,1302.0,8693.0,44170.0,85.0,724.0,70.0,17,523


In [9]:
val_target = pl.read_parquet("F:/chunks/val_target.pq")
val_target = val_target.with_columns(
    target=pl.concat_list([f"target_{i}" for i in range(1, 5)])
)
val_target = val_target.select(("client_id", "target"))
val_target.head()

client_id,target
str,list[i32]
"""c67a6454099213…","[0, 0, … 0]"
"""e225b104fe6428…","[0, 0, … 0]"
"""eacde9cdeaf7ec…","[0, 0, … 0]"
"""e962187af0d4fa…","[0, 0, … 0]"
"""12502ddd41e10e…","[0, 0, … 0]"


In [10]:
trx_preprocessor = PolarsDataPreprocessor(
    col_id="client_id",
    col_event_time="event_time",
    cols_category=[
        "mon_ind", "day_ind", 
        "event_type", "event_subtype", 
        "currency",
        "src_type11", "src_type12", "dst_type11", "dst_type12",
        "src_type21", "src_type22", "src_type31", "src_type32",
    ],
    # cols_numerical=["amount"],
)

In [11]:
%%time

train_trx = trx_preprocessor.fit_transform(train_trx)

CPU times: total: 2min 56s
Wall time: 40.9 s


In [12]:
%%time

val_trx = trx_preprocessor.transform(val_trx)

CPU times: total: 1min 21s
Wall time: 6.53 s


In [13]:
train_trx = train_trx.join(train_target, on="client_id", how="outer")
val_trx = val_trx.join(val_target, on="client_id", how="outer")

In [14]:
def to_records(data):
    res = [{} for _ in range(len(data))]
    for i, value in enumerate(data["target"]):
        res[i]["target"] = torch.tensor(value.to_numpy())
    for col, dtype in zip(data.columns, data.dtypes):
        if col in ("client_id", "target"):
            continue
        assert dtype == pl.List
        for i, value in enumerate(data[col].fill_null([])):
            res[i][col] = torch.tensor(value.to_numpy())
    return res

In [15]:
train_dict = to_records(train_trx)
val_dict = to_records(val_trx)

In [16]:
from ptls.frames.supervised import SeqToTargetDataset

In [17]:
train_data = SeqToTargetDataset(train_dict, target_col_name="target", target_dtype=torch.float32)
val_data = SeqToTargetDataset(val_dict, target_col_name="target", target_dtype=torch.float32)

In [18]:
from ptls.frames import PtlsDataModule

In [19]:
sup_data = PtlsDataModule(
    train_data=train_data,
    train_num_workers=0,
    train_batch_size=64,
    valid_data=val_data,
    valid_num_workers=0,
    valid_batch_size=64,
)

In [20]:
_ = next(iter(sup_data.train_dataloader()))

In [105]:
trx_preprocessor.get_category_dictionary_sizes()

{'mon_ind': 14,
 'day_ind': 356,
 'event_type': 55,
 'event_subtype': 57,
 'currency': 13,
 'src_type11': 46,
 'src_type12': 182,
 'dst_type11': 57,
 'dst_type12': 248,
 'src_type21': 8220,
 'src_type22': 87,
 'src_type31': 1617,
 'src_type32': 88}

In [106]:
trx_encoder_params = dict(
    embeddings_noise=0.003,
    linear_projection_size=128,
    embeddings = {
        "mon_ind": {"in": 14, "out": 7},
        "day_ind": {"in": 356, "out": 64},
        
        "event_type": {"in": 55, "out": 16},
        "event_subtype": {"in": 57, "out": 16},
        
        "currency": {"in": 13, "out": 6},
        
        "src_type11": {"in": 46, "out": 16},
        "src_type12": {"in": 182, "out": 32},
        "dst_type11": {"in": 57, "out": 16},
        "dst_type12": {"in": 248, "out": 32},
        
        # "src_type21": {"in": 8220, "out": 64},
        "src_type22": {"in": 87, "out": 16},
        # "src_type31": {"in": 1617, "out": 64},
        "src_type32": {"in": 88, "out": 16},
    },
    # numeric_values = {
    #     "amount": "log",
    # },
    # use_batch_norm_with_lens=True
)

In [107]:
from ptls.nn import TrxEncoder

In [108]:
trx_encoder = TrxEncoder(**trx_encoder_params)

In [109]:
from torchinfo import summary

In [110]:
summary(trx_encoder)

Layer (type:depth-idx)                   Param #
TrxEncoder                               --
├─ModuleDict: 1-1                        --
│    └─NoisyEmbedding: 2-1               98
│    │    └─Dropout: 3-1                 --
│    └─NoisyEmbedding: 2-2               22,784
│    │    └─Dropout: 3-2                 --
│    └─NoisyEmbedding: 2-3               880
│    │    └─Dropout: 3-3                 --
│    └─NoisyEmbedding: 2-4               912
│    │    └─Dropout: 3-4                 --
│    └─NoisyEmbedding: 2-5               78
│    │    └─Dropout: 3-5                 --
│    └─NoisyEmbedding: 2-6               736
│    │    └─Dropout: 3-6                 --
│    └─NoisyEmbedding: 2-7               5,824
│    │    └─Dropout: 3-7                 --
│    └─NoisyEmbedding: 2-8               912
│    │    └─Dropout: 3-8                 --
│    └─NoisyEmbedding: 2-9               7,936
│    │    └─Dropout: 3-9                 --
│    └─NoisyEmbedding: 2-10              1,392
│    │    

In [111]:
from ptls.nn import RnnSeqEncoder

In [112]:
from ptls.nn import PBLayerNorm, PBL2Norm

In [113]:
seq_encoder = RnnSeqEncoder(
    trx_encoder=nn.Sequential(trx_encoder, PBL2Norm(), PBLayerNorm(trx_encoder.output_size)),
    input_size=trx_encoder.output_size,
    hidden_size=128,
    # bidir=True,
    # num_layers=2,
)

In [114]:
classifier = nn.Sequential(
    nn.Linear(128, 64),
    nn.ReLU(),
    nn.Linear(64, 4),
)

In [115]:
from functools import partial
from ptls.frames.supervised import SeqToTargetDataset, SequenceToTarget
from ptls.frames import PtlsDataModule
import torch.nn as nn
import torchmetrics

In [116]:
class AUROC(nn.Module):
    def __init__(self):
        super().__init__()
        self.metric = torchmetrics.AUROC(task='multilabel', num_labels=4)
    def forward(self, preds, target):
        return self.metric(preds, target.int())
    def compute(self):
        return self.metric.compute()
    def reset(self):
        return self.metric.reset()

In [117]:
sup_module = SequenceToTarget(
    seq_encoder=seq_encoder,
    head=classifier,
    loss=nn.BCEWithLogitsLoss(),
    metric_list=AUROC(),
    optimizer_partial=partial(torch.optim.AdamW, lr=1e-4),
    lr_scheduler_partial=partial(torch.optim.lr_scheduler.ConstantLR, factor=1.0),
)

In [118]:
sup_module

SequenceToTarget(
  (seq_encoder): RnnSeqEncoder(
    (trx_encoder): Sequential(
      (0): TrxEncoder(
        (embeddings): ModuleDict(
          (mon_ind): NoisyEmbedding(
            14, 7, padding_idx=0
            (dropout): Dropout(p=0, inplace=False)
          )
          (day_ind): NoisyEmbedding(
            356, 64, padding_idx=0
            (dropout): Dropout(p=0, inplace=False)
          )
          (event_type): NoisyEmbedding(
            55, 16, padding_idx=0
            (dropout): Dropout(p=0, inplace=False)
          )
          (event_subtype): NoisyEmbedding(
            57, 16, padding_idx=0
            (dropout): Dropout(p=0, inplace=False)
          )
          (currency): NoisyEmbedding(
            13, 6, padding_idx=0
            (dropout): Dropout(p=0, inplace=False)
          )
          (src_type11): NoisyEmbedding(
            46, 16, padding_idx=0
            (dropout): Dropout(p=0, inplace=False)
          )
          (src_type12): NoisyEmbedding(
      

In [119]:
from pytorch_lightning.loggers import WandbLogger

In [120]:
from pytorch_lightning.callbacks import LearningRateMonitor

In [121]:
lr_monitor = LearningRateMonitor(logging_interval="step")

In [None]:
pl_trainer = pytorch_lightning.Trainer(
    logger = WandbLogger(),
    max_epochs = 15,
    accelerator = "gpu",
    devices = 1,
    enable_progress_bar = True,
    callbacks = [lr_monitor]
)

In [None]:
pl_trainer.fit(sup_module, sup_data)

In [124]:
prediction = pl_trainer.predict(sup_module, sup_data.val_dataloader())

  rank_zero_warn(



Predicting: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")



In [125]:
import pandas as pd

In [None]:
prediction = pd.concat(prediction)
prediction

In [None]:
from sklearn.metrics import roc_auc_score

In [None]:
roc_auc_score(val_trx["target"].to_list(), prediction[["out_0000", "out_0001", "out_0002", "out_0003"]])