In [1]:
from datetime import datetime

import polars as pl

import torch
import pytorch_lightning
from torch import nn

from src.preprocessing import PolarsDataPreprocessor
from src.utils import prepare_data

In [2]:
pl.set_random_seed(56)

In [3]:
pytorch_lightning.seed_everything(56, True)

Global seed set to 56


56

In [4]:
months = [
    '2022-02-28',
    '2022-03-31',
    '2022-04-30',
    '2022-05-31',
    '2022-06-30',
    '2022-07-31',
    '2022-08-31',
    '2022-09-30',
    '2022-10-31',
    '2022-11-30',
    '2022-12-31',
    '2023-01-31',
]
month2id = dict((month, idx) for idx, month in enumerate(months))
id2month = dict((idx, month) for idx, month in enumerate(months))

In [5]:
mon = 10

In [6]:
train_geo = pl.read_parquet("F:/chunks/geo_train.pq")
train_geo = prepare_data(train_geo)
train_geo = train_geo.filter(pl.col("event_time") < datetime.strptime(months[mon], "%Y-%m-%d").replace(day=1))
print(train_geo.shape)
train_geo.head()

(53101214, 7)


client_id,event_time,geohash_4,geohash_5,geohash_6,mon_ind,day_ind
str,i64,i32,i32,i32,i64,i64
"""009c52bd099cbb…",57046408,32892,35465,1028609,21,660
"""009c52bd099cbb…",56294845,32892,35465,461846,21,651
"""009c52bd099cbb…",56912877,32892,35465,1028609,21,658
"""009c52bd099cbb…",55770461,32892,35465,461846,21,645
"""009c52bd099cbb…",54736642,32892,35465,1028609,20,633


In [7]:
train_target = pl.read_parquet("F:/chunks/train_target.pq")
train_target = train_target.with_columns(
    target=pl.concat_list([f"target_{i}" for i in range(1, 5)])
)
train_target = train_target.select(("client_id", "target"))
train_target.head()

client_id,target
str,list[i32]
"""06a2ce26f19242…","[0, 0, … 0]"
"""06d250bda1fe78…","[0, 0, … 0]"
"""1d0fd54040602e…","[0, 0, … 0]"
"""85c233cac30252…","[0, 0, … 0]"
"""16d2d3fbdef66b…","[0, 0, … 0]"


In [8]:
val_geo = pl.read_parquet("F:/chunks/geo_val.pq")
val_geo = prepare_data(val_geo)
val_geo = val_geo.filter(pl.col("event_time") < datetime.strptime(months[mon], "%Y-%m-%d").replace(day=1))
print(val_geo.shape)
val_geo.head()

(23574218, 7)


client_id,event_time,geohash_4,geohash_5,geohash_6,mon_ind,day_ind
str,i64,i32,i32,i32,i64,i64
"""4e16dc21c7a960…",38311933,36432,320079,2073605,14,443
"""4e16dc21c7a960…",36842499,36432,182144,427368,14,426
"""4e16dc21c7a960…",44297900,36432,320079,2073605,16,512
"""4e16dc21c7a960…",36393558,36432,182144,427368,13,421
"""4e16dc21c7a960…",39058004,36432,320079,2073605,14,452


In [9]:
val_target = pl.read_parquet("F:/chunks/val_target.pq")
val_target = val_target.with_columns(
    target=pl.concat_list([f"target_{i}" for i in range(1, 5)])
)
val_target = val_target.select(("client_id", "target"))
val_target.head()

client_id,target
str,list[i32]
"""c67a6454099213…","[0, 0, … 0]"
"""e225b104fe6428…","[0, 0, … 0]"
"""eacde9cdeaf7ec…","[0, 0, … 0]"
"""e962187af0d4fa…","[0, 0, … 0]"
"""12502ddd41e10e…","[0, 0, … 0]"


In [10]:
geo_preprocessor = PolarsDataPreprocessor(
    col_id="client_id",
    col_event_time="event_time",
    cols_category=[
        "mon_ind", "day_ind",
        "geohash_4",
    ],
)

In [11]:
%%time

train_geo = geo_preprocessor.fit_transform(train_geo)

CPU times: total: 4min 47s
Wall time: 59.9 s


In [12]:
%%time

val_geo = geo_preprocessor.transform(val_geo)

CPU times: total: 1min 37s
Wall time: 9.57 s


In [14]:
train_geo = train_geo.join(train_target, on="client_id", how="outer")
val_geo = val_geo.join(val_target, on="client_id", how="outer")

In [15]:
def to_records(data):
    res = [{} for _ in range(len(data))]
    for i, value in enumerate(data["target"]):
        res[i]["target"] = torch.tensor(value.to_numpy())
    for col, dtype in zip(data.columns, data.dtypes):
        if col in ("client_id", "target"):
            continue
        assert dtype == pl.List
        for i, value in enumerate(data[col].fill_null([])):
            res[i][col] = torch.tensor(value.to_numpy())
    return res

In [16]:
train_dict = to_records(train_geo)
val_dict = to_records(val_geo)

In [17]:
from ptls.frames.supervised import SeqToTargetDataset

In [18]:
train_data = SeqToTargetDataset(train_dict, target_col_name="target", target_dtype=torch.float32)
val_data = SeqToTargetDataset(val_dict, target_col_name="target", target_dtype=torch.float32)

In [19]:
from ptls.frames import PtlsDataModule

In [20]:
sup_data = PtlsDataModule(
    train_data=train_data,
    train_num_workers=0,
    train_batch_size=64,
    valid_data=val_data,
    valid_num_workers=0,
    valid_batch_size=64,
)

In [21]:
_ = next(iter(sup_data.train_dataloader()))

In [46]:
geo_preprocessor.get_category_dictionary_sizes()

{'mon_ind': 14, 'day_ind': 356, 'geohash_4': 17927}

In [47]:
geo_encoder_params = dict(
    embeddings_noise=0.003,
    linear_projection_size=128,
    embeddings = {
        "mon_ind": {"in": 14, "out": 7},
        "day_ind": {"in": 356, "out": 64},
        
        "geohash_4": {"in": 17927, "out": 256},
    },
)

In [48]:
from ptls.nn import TrxEncoder

In [49]:
geo_encoder = TrxEncoder(**geo_encoder_params)

In [50]:
from torchinfo import summary

In [51]:
summary(geo_encoder)

Layer (type:depth-idx)                   Param #
TrxEncoder                               --
├─ModuleDict: 1-1                        --
│    └─NoisyEmbedding: 2-1               98
│    │    └─Dropout: 3-1                 --
│    └─NoisyEmbedding: 2-2               22,784
│    │    └─Dropout: 3-2                 --
│    └─NoisyEmbedding: 2-3               4,589,312
│    │    └─Dropout: 3-3                 --
├─ModuleDict: 1-2                        --
├─Linear: 1-3                            41,984
Total params: 4,654,178
Trainable params: 4,654,178
Non-trainable params: 0

In [52]:
from ptls.nn import RnnSeqEncoder

In [53]:
from ptls.nn import PBLayerNorm, PBL2Norm

In [54]:
seq_encoder = RnnSeqEncoder(
    trx_encoder=nn.Sequential(geo_encoder, PBL2Norm(), PBLayerNorm(geo_encoder.output_size)),
    input_size=geo_encoder.output_size,
    hidden_size=128,
    # bidir=True,
    # num_layers=2,
)

In [55]:
classifier = nn.Sequential(
    nn.Linear(128, 64),
    nn.ReLU(),
    nn.Linear(64, 4),
)

In [56]:
from functools import partial
from ptls.frames.supervised import SeqToTargetDataset, SequenceToTarget
from ptls.frames import PtlsDataModule
import torch.nn as nn
import torchmetrics

In [57]:
class AUROC(nn.Module):
    def __init__(self):
        super().__init__()
        self.metric = torchmetrics.AUROC(task='multilabel', num_labels=4)
    def forward(self, preds, target):
        return self.metric(preds, target.int())
    def compute(self):
        return self.metric.compute()
    def reset(self):
        return self.metric.reset()

In [58]:
sup_module = SequenceToTarget(
    seq_encoder=seq_encoder,
    head=classifier,
    loss=nn.BCEWithLogitsLoss(),
    metric_list=AUROC(),
    optimizer_partial=partial(torch.optim.AdamW, lr=1e-4),
    lr_scheduler_partial=partial(torch.optim.lr_scheduler.ConstantLR, factor=1.0),
)

In [59]:
sup_module

SequenceToTarget(
  (seq_encoder): RnnSeqEncoder(
    (trx_encoder): Sequential(
      (0): TrxEncoder(
        (embeddings): ModuleDict(
          (mon_ind): NoisyEmbedding(
            14, 7, padding_idx=0
            (dropout): Dropout(p=0, inplace=False)
          )
          (day_ind): NoisyEmbedding(
            356, 64, padding_idx=0
            (dropout): Dropout(p=0, inplace=False)
          )
          (geohash_4): NoisyEmbedding(
            17927, 256, padding_idx=0
            (dropout): Dropout(p=0, inplace=False)
          )
        )
        (custom_embeddings): ModuleDict()
        (linear_projection_head): Linear(in_features=327, out_features=128, bias=True)
      )
      (1): PBShell()
      (2): PBShell((128,), eps=1e-05, elementwise_affine=True)
    )
    (seq_encoder): RnnEncoder(
      (rnn): GRU(128, 128, batch_first=True)
      (reducer): LastStepEncoder()
    )
  )
  (head): Sequential(
    (0): Linear(in_features=128, out_features=64, bias=True)
    (1): ReLU

In [60]:
from pytorch_lightning.loggers import WandbLogger

In [61]:
from pytorch_lightning.callbacks import LearningRateMonitor

In [62]:
lr_monitor = LearningRateMonitor(logging_interval="step")

In [63]:
pl_trainer = pytorch_lightning.Trainer(
    logger = WandbLogger(),
    max_epochs = 15,
    accelerator = "gpu",
    devices = 1,
    enable_progress_bar = True,
    callbacks = [lr_monitor]
)

  rank_zero_warn(

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [64]:
pl_trainer.fit(sup_module, sup_data)

  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type              | Params
----------------------------------------------------
0 | seq_encoder   | RnnSeqEncoder     | 4.8 M 
1 | head          | Sequential        | 8.5 K 
2 | loss          | BCEWithLogitsLoss | 0     
3 | train_metrics | ModuleDict        | 0     
4 | valid_metrics | ModuleDict        | 0     
5 | test_metrics  | ModuleDict        | 0     
----------------------------------------------------
4.8 M     Trainable params
0         Non-trainable params
4.8 M     Total params
19.049    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(


  rank_zero_warn(



Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")



In [65]:
prediction = pl_trainer.predict(sup_module, sup_data.val_dataloader())

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  rank_zero_warn(



Predicting: 0it [00:00, ?it/s]

In [66]:
import pandas as pd

In [67]:
prediction = pd.concat(prediction)
prediction

Unnamed: 0,seq_id_0000,seq_id_0001,seq_id_0002,seq_id_0003,out_0000,out_0001,out_0002,out_0003
0,0.0,0.0,0.0,0.0,-4.924044,-6.623470,-4.783706,-5.879585
1,0.0,0.0,0.0,0.0,-6.009947,-7.913704,-5.983152,-6.778406
2,0.0,0.0,0.0,0.0,-4.529310,-6.030220,-4.171947,-4.568515
3,0.0,0.0,0.0,0.0,-5.997715,-7.659665,-5.973559,-6.344449
4,0.0,0.0,0.0,0.0,-5.111794,-6.067881,-4.332725,-4.070150
...,...,...,...,...,...,...,...,...
40,0.0,0.0,0.0,0.0,-6.015372,-7.455849,-5.809294,-5.906121
41,0.0,0.0,0.0,1.0,-5.984677,-7.847799,-5.992179,-6.474560
42,0.0,0.0,0.0,0.0,-5.997715,-7.659665,-5.973559,-6.344449
43,0.0,0.0,0.0,0.0,-5.664832,-7.599912,-5.673007,-6.522179


In [68]:
from sklearn.metrics import roc_auc_score

In [69]:
roc_auc_score(val_geo["target"].to_list(), prediction[["out_0000", "out_0001", "out_0002", "out_0003"]])

0.6224184300707871