In [1]:
from datetime import datetime

import polars as pl

import torch
import pytorch_lightning
from torch import nn

from src.preprocessing import PolarsDataPreprocessor
from src.utils import prepare_data

In [2]:
# Seed polars' global random generator so polars-side random operations
# (the unseeded `sample` call further down) are repeatable between runs.
pl.set_random_seed(56)

In [3]:
# Seed the Python, NumPy and torch generators in one call; the second positional
# argument is Lightning's `workers` flag.
pytorch_lightning.seed_everything(56, True)

56

In [4]:
# Reporting months, each written as the last calendar day of the month (ISO format).
months = [
    '2022-02-28', '2022-03-31', '2022-04-30', '2022-05-31',
    '2022-06-30', '2022-07-31', '2022-08-31', '2022-09-30',
    '2022-10-31', '2022-11-30', '2022-12-31', '2023-01-31',
]
# Lookups in both directions between a month string and its position in `months`.
id2month = dict(enumerate(months))
month2id = {name: position for position, name in id2month.items()}

In [5]:
# Index into `months` of the month being modelled (10 -> '2022-12-31'); it drives
# both the event cutoff and the target-month filter in the cells below.
mon = 10

In [6]:
# Load the training dialogue events, run them through prepare_data() and keep the
# rows whose event_time is below the first day of the target month.
# NOTE(review): the schema displayed under this cell shows `event_time` as i64 after
# prepare_data(), while the cutoff is a datetime - confirm that this comparison
# really removes target-month events.
# NOTE(review): "F:/..." is an absolute, machine-specific path, whereas the targets
# are read from ./data.
train_dial = prepare_data(pl.read_parquet("F:/dial_train.parquet/*")).filter(
    pl.col("event_time") < datetime.strptime(months[mon], "%Y-%m-%d").replace(day=1)
)
print(train_dial.shape)
train_dial.head()

(1215209, 5)


client_id,event_time,embedding,mon_ind,day_ind
str,i64,list[f32],i64,i64
"""a039ad3b595d4f…",42035070,"[-0.003248, 0.140231, … -0.010614]",16,486
"""a039ad3b595d4f…",57488718,"[0.058927, -0.007723, … 0.086532]",21,665
"""a039ad3b595d4f…",61634186,"[0.079807, -0.003912, … 0.095575]",23,713
"""a060e69e9e049a…",53617023,"[0.523752, -0.30542, … 0.568542]",20,620
"""a08c690dd972d2…",42023542,"[-0.009235, -0.069714, … -0.062696]",16,486


In [7]:
# Training targets for the selected month: the target_1..target_4 columns are
# packed into a single list column named "target", and only the client id and that
# list are kept.
train_target = (
    pl.read_parquet("./data/train_target.parquet/*")
    .filter(pl.col("mon") == months[mon])
    .with_columns(target=pl.concat_list([f"target_{k}" for k in range(1, 5)]))
    .select(("client_id", "target"))
)
train_target.head()

client_id,target
str,list[i32]
"""1d4ebf30ab5b98…","[0, 0, … 0]"
"""1d55174bce3ef4…","[0, 0, … 0]"
"""1d5d052f87d6bd…","[0, 0, … 0]"
"""1d68b588164639…","[0, 0, … 0]"
"""1d817e82e1cc59…","[0, 0, … 0]"


In [8]:
# Load the validation dialogue events, run them through prepare_data() and keep the
# rows whose event_time is below the first day of the target month.
# NOTE(review): the schema displayed under this cell shows `event_time` as i64 after
# prepare_data(), while the cutoff is a datetime - confirm that this comparison
# really removes target-month events.
# NOTE(review): "F:/..." is an absolute, machine-specific path, whereas the targets
# are read from ./data.
val_dial = prepare_data(pl.read_parquet("F:/dial_test.parquet/*")).filter(
    pl.col("event_time") < datetime.strptime(months[mon], "%Y-%m-%d").replace(day=1)
)
print(val_dial.shape)
val_dial.head()

(286526, 5)


client_id,event_time,embedding,mon_ind,day_ind
str,i64,list[f32],i64,i64
"""08b3569cdfd015…",36413591,"[0.110589, -0.000545, … 0.134537]",13,421
"""08b3569cdfd015…",41938633,"[0.00209, 0.072185, … 0.015157]",16,485
"""08b3569cdfd015…",42626904,"[0.194512, -0.032053, … 0.160809]",16,493
"""08b3569cdfd015…",38386943,"[0.045035, 0.042004, … -0.112547]",14,444
"""08b3569cdfd015…",38220918,"[0.330715, -0.023786, … 0.342075]",14,442


In [9]:
# Validation targets for the selected month. Exact duplicate rows are dropped right
# after loading; then target_1..target_4 are packed into a single "target" list
# column and only the client id and that list are kept.
val_target = (
    pl.read_parquet("./data/test_target_b.parquet/*")
    .unique()
    .filter(pl.col("mon") == months[mon])
    .with_columns(target=pl.concat_list([f"target_{k}" for k in range(1, 5)]))
    .select(("client_id", "target"))
)
val_target.head()

client_id,target
str,list[i32]
"""eacde9cdeaf7ec…","[0, 0, … 0]"
"""581cdd678867bb…","[0, 0, … 0]"
"""be99a105e19a53…","[0, 0, … 0]"
"""9f6d6e6dea2b12…","[0, 0, … 0]"
"""12d1d490206e23…","[0, 0, … 0]"


In [10]:
# Preprocessor for the dialogue events, keyed by client and ordered by event time.
# Judging by the outputs of later cells, it collects each client's rows into list
# columns and encodes the categorical columns as integer codes
# (see get_category_dictionary_sizes() below) - confirm against src.preprocessing.
dial_preprocessor = PolarsDataPreprocessor(
    col_id="client_id",
    col_event_time="event_time",
    cols_category=[
        "mon_ind", "day_ind", 
    ],
    # The text embedding vector is passed through as a numerical feature.
    cols_numerical=["embedding"],
)

In [11]:
%%time

# Fit the preprocessor on the training events and transform them in the same call.
train_dial = dial_preprocessor.fit_transform(train_dial)

CPU times: total: 27.4 s
Wall time: 22.6 s


In [12]:
%%time

# Apply the already-fitted preprocessor to the validation events (transform only,
# no refit).
val_dial = dial_preprocessor.transform(val_dial)

CPU times: total: 7.77 s
Wall time: 6.15 s


In [13]:
# Attach the targets to the per-client features. The outer join followed by the
# non-null filter keeps every client that has a target, including clients with no
# dialogue events at all (their feature columns stay null, as the next output shows).
# NOTE(review): this cell overwrites its own inputs, so running it twice joins the
# targets a second time.
train_dial = (
    train_dial
    .join(train_target, on="client_id", how="outer")
    .filter(pl.col("target").is_not_null())
)
val_dial = (
    val_dial
    .join(val_target, on="client_id", how="outer")
    .filter(pl.col("target").is_not_null())
)

In [14]:
# Shuffle the validation rows; fraction=1.0 keeps all of them. No seed is passed to
# this call.
val_dial = val_dial.sample(fraction=1.0, shuffle=True)
val_dial

client_id,event_time,mon_ind,day_ind,embedding,target
str,list[i64],list[i32],list[i32],list[list[f32]],list[i32]
"""d3998cf889633d…",,,,,"[0, 0, … 0]"
"""e5412b4fa21f6f…",[38040903],[6],[156],"[[0.078863, 0.067005, … -0.09817]]","[0, 0, … 0]"
"""eb1edd32c3dd18…",,,,,"[0, 0, … 0]"
"""c8d16becb83770…",,,,,"[0, 0, … 0]"
"""fbf6cf67602181…",,,,,"[0, 0, … 0]"
"""7b90d4c493382d…",,,,,"[0, 0, … 0]"
"""12babeb2f62fbd…",,,,,"[0, 0, … 0]"
"""d5740438226829…",,,,,"[0, 0, … 0]"
"""6691c06d165d81…",,,,,"[0, 0, … 0]"
"""bb53dbbe09739e…",[49871005],[12],[144],"[[0.293067, -0.025177, … 0.275898]]","[0, 0, … 0]"


In [15]:
import numpy as np

In [16]:
def to_records(data, embedding_dim=768):
    """Convert a per-client polars frame into a list of dicts of tensors.

    Every row of ``data`` becomes one dict holding a ``"target"`` tensor followed by
    one tensor per remaining column; ``client_id`` is skipped.

    Args:
        data: frame with a ``client_id`` column, a ``target`` list column and
            list-typed feature columns. The ``embedding`` column holds a list of
            vectors per row; all other feature columns hold flat lists.
        embedding_dim: width of the all-zero placeholder row used when a row's
            ``embedding`` is null. Defaults to 768, the value that used to be
            hard-coded, so existing calls behave as before.

    Returns:
        A list with one ``{column name: torch.Tensor}`` dict per row, in row order.

    Raises:
        AssertionError: if a feature column is not a polars ``List`` column.
    """
    records = [{"target": torch.tensor(value.to_numpy())} for value in data["target"]]

    for col, dtype in zip(data.columns, data.dtypes):
        if col in ("client_id", "target"):
            continue
        assert dtype == pl.List, f"column {col!r} must be a List column, got {dtype}"

        if col == "embedding":
            for record, value in zip(records, data[col]):
                if value is None:
                    # Clients without any events get a single all-zero vector.
                    # NOTE(review): the other feature columns get an EMPTY tensor
                    # for the same rows (see fill_null([]) below), so the sequence
                    # lengths disagree (1 vs 0) - confirm the collate function and
                    # the encoder handle that as intended.
                    record[col] = torch.zeros((1, embedding_dim), dtype=torch.float32)
                else:
                    # Stack the per-event vectors into an (events, width) matrix.
                    record[col] = torch.tensor(np.vstack(value.to_numpy()))
        else:
            # Null lists (clients without events) become zero-length tensors.
            for record, value in zip(records, data[col].fill_null([])):
                record[col] = torch.tensor(value.to_numpy())

    return records

In [17]:
# Turn both frames into lists of per-client tensor dicts, the input format of the
# datasets built below.
train_dict = to_records(train_dial)
val_dict = to_records(val_dial)

In [18]:
from ptls.frames.supervised import SeqToTargetDataset

In [19]:
# Targets are converted to float32 because the model is trained with
# BCEWithLogitsLoss, which expects floating-point targets.
train_data = SeqToTargetDataset(train_dict, target_col_name="target", target_dtype=torch.float32)
val_data = SeqToTargetDataset(val_dict, target_col_name="target", target_dtype=torch.float32)

In [20]:
from ptls.frames import PtlsDataModule

In [48]:
# Data module wrapping both datasets with batches of 128 and no loader worker
# processes. Only the training loader drops the last incomplete batch, so every
# validation row is scored.
sup_data = PtlsDataModule(
    train_data=train_data,
    train_num_workers=0,
    train_batch_size=128,
    train_drop_last=True,
    valid_data=val_data,
    valid_num_workers=0,
    valid_batch_size=128,
)

In [49]:
# Smoke test: pull one training batch to check that batching works; the batch
# itself is discarded.
_ = next(iter(sup_data.train_dataloader()))

In [79]:
# Category dictionary size per categorical column, as reported by the fitted
# preprocessor; the embedding tables below must be at least this large.
dial_preprocessor.get_category_dictionary_sizes()

{'mon_ind': 14, 'day_ind': 366}

In [80]:
from ptls.nn.trx_encoder.encoders import IdentityEncoder

In [82]:
# Embedding-table input sizes are read from the fitted preprocessor instead of being
# typed in by hand, so they stay valid when the selected month - and with it the
# category dictionaries - changes. For the current data this yields the same
# values as before (mon_ind: 14, day_ind: 366).
category_sizes = dial_preprocessor.get_category_dictionary_sizes()

dial_encoder_params = dict(
    embeddings_noise=0.003,
    # All per-event features are projected to 128 dimensions.
    linear_projection_size=128,
    embeddings={
        "mon_ind": {"in": category_sizes["mon_ind"], "out": 7},
        "day_ind": {"in": category_sizes["day_ind"], "out": 64},
    },
    custom_embeddings={
        # The 768-wide text embedding is fed through unchanged.
        "embedding": IdentityEncoder(768),
    },
    # use_batch_norm=True,
    # use_batch_norm_with_lens=True,
)

In [83]:
from ptls.nn import TrxEncoder

In [84]:
# Build the per-event encoder from the parameter dict defined in the previous cell.
dial_encoder = TrxEncoder(**dial_encoder_params)

In [85]:
from torchinfo import summary

In [86]:
# Show the layer and parameter-count overview of the event encoder.
summary(dial_encoder)

Layer (type:depth-idx)                   Param #
TrxEncoder                               --
├─ModuleDict: 1-1                        --
│    └─NoisyEmbedding: 2-1               98
│    │    └─Dropout: 3-1                 --
│    └─NoisyEmbedding: 2-2               23,424
│    │    └─Dropout: 3-2                 --
├─ModuleDict: 1-2                        --
│    └─IdentityEncoder: 2-3              --
├─RBatchNormWithLens: 1-3                --
│    └─BatchNorm1d: 2-4                  1,536
├─Linear: 1-4                            107,520
Total params: 132,578
Trainable params: 132,578
Non-trainable params: 0

In [87]:
from ptls.nn import RnnSeqEncoder

In [88]:
from ptls.nn import PBLayerNorm, PBL2Norm

In [89]:
# Sequence encoder: the event encoder's output goes through PBL2Norm and
# PBLayerNorm before the RNN, whose hidden state has 128 units.
# The commented-out arguments are alternative RNN settings that are left disabled.
seq_encoder = RnnSeqEncoder(
    trx_encoder=nn.Sequential(dial_encoder, PBL2Norm(), PBLayerNorm(dial_encoder.output_size)),
    input_size=dial_encoder.output_size,
    hidden_size=128,
    # bidir=True,
    # num_layers=2,
)

In [90]:
# Classification head: maps the 128-dimensional sequence embedding (the RNN hidden
# size) to 4 logits, one per target. There is no final sigmoid because the loss
# used below is BCEWithLogitsLoss.
classifier = nn.Sequential(
    nn.Linear(128, 64),
    nn.ReLU(),
    nn.Linear(64, 4),
)

In [91]:
from functools import partial
from ptls.frames.supervised import SeqToTargetDataset, SequenceToTarget
from ptls.frames import PtlsDataModule
import torch.nn as nn
import torchmetrics

In [92]:
class AUROC(nn.Module):
    """Multilabel ROC-AUC over the four targets, wrapped as an ``nn.Module``.

    All work is delegated to ``torchmetrics.AUROC``; this wrapper only casts the
    targets to integers before each update, since the datasets supply them as
    float32.
    """

    def __init__(self):
        super().__init__()
        self.metric = torchmetrics.AUROC(num_labels=4, task='multilabel')

    def forward(self, preds, target):
        """Feed one batch to the wrapped metric and return what it returns."""
        int_target = target.int()
        return self.metric(preds, int_target)

    def compute(self):
        """Return the wrapped metric's accumulated value."""
        accumulated = self.metric.compute()
        return accumulated

    def reset(self):
        """Reset the wrapped metric, passing its return value through."""
        outcome = self.metric.reset()
        return outcome

In [93]:
# Supervised module: sequence encoder + classification head, trained with
# BCE-with-logits and tracked with the multilabel AUROC wrapper.
sup_module = SequenceToTarget(
    seq_encoder=seq_encoder,
    head=classifier,
    loss=nn.BCEWithLogitsLoss(),
    metric_list=AUROC(),
    optimizer_partial=partial(torch.optim.AdamW, lr=1e-4),
    # ConstantLR with factor=1.0 leaves the learning rate at the optimizer's value.
    lr_scheduler_partial=partial(torch.optim.lr_scheduler.ConstantLR, factor=1.0),
)

In [94]:
# Display the assembled module for a visual check of its layers.
sup_module

SequenceToTarget(
  (seq_encoder): RnnSeqEncoder(
    (trx_encoder): Sequential(
      (0): TrxEncoder(
        (embeddings): ModuleDict(
          (mon_ind): NoisyEmbedding(
            14, 7, padding_idx=0
            (dropout): Dropout(p=0, inplace=False)
          )
          (day_ind): NoisyEmbedding(
            366, 64, padding_idx=0
            (dropout): Dropout(p=0, inplace=False)
          )
        )
        (custom_embeddings): ModuleDict(
          (embedding): IdentityEncoder()
        )
        (custom_embedding_batch_norm): RBatchNormWithLens(
          (bn): BatchNorm1d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (linear_projection_head): Linear(in_features=839, out_features=128, bias=True)
      )
      (1): PBShell()
      (2): PBShell((128,), eps=1e-05, elementwise_affine=True)
    )
    (seq_encoder): RnnEncoder(
      (rnn): GRU(128, 128, batch_first=True)
      (reducer): LastStepEncoder()
    )
  )
  (head): Sequential

In [95]:
from pytorch_lightning.loggers import WandbLogger

In [96]:
from pytorch_lightning.callbacks import LearningRateMonitor

In [97]:
# Callback that logs the learning rate at every training step.
lr_monitor = LearningRateMonitor(logging_interval="step")

In [98]:
# Trainer: one GPU, at most 15 epochs, metrics sent to Weights & Biases, with the
# learning-rate monitor attached.
pl_trainer = pytorch_lightning.Trainer(
    logger = WandbLogger(),
    max_epochs = 15,
    accelerator = "gpu",
    devices = 1,
    enable_progress_bar = True,
    callbacks = [lr_monitor]
)

  rank_zero_warn(



In [99]:
# Train the module using the loaders provided by the data module.
pl_trainer.fit(sup_module, sup_data)

  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")



Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(


  rank_zero_warn(



Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")



In [100]:
# Run inference over the validation loader; the result holds one entry per batch.
prediction = pl_trainer.predict(sup_module, sup_data.val_dataloader())

  rank_zero_warn(



Predicting: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")



In [101]:
import pandas as pd

In [102]:
# Stack the per-batch predictions into a single frame.
# NOTE(review): this assumes every element of `prediction` is a pandas object with
# out_0000..out_0003 columns (that is how the scoring cell indexes it) - TODO confirm
# against what SequenceToTarget returns from predict.
prediction = pd.concat(prediction)
prediction


KeyboardInterrupt



In [None]:
from sklearn.metrics import roc_auc_score

In [None]:
# ROC-AUC of the four model outputs against the validation targets. Rows are matched
# by position, so `prediction` must still be in the same order as `val_dial`.
roc_auc_score(val_dial["target"].to_list(), prediction[["out_0000", "out_0001", "out_0002", "out_0003"]])