In [1]:
import os
from functools import partial

import pandas as pd
import numpy as np

from sklearn.model_selection import train_test_split

import torch

from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.loggers import CometLogger

from ptls.data_load.iterable_processing import SeqLenFilter
from ptls.data_load.datasets import MemoryMapDataset
from ptls.preprocessing import PandasDataPreprocessor
from ptls.frames import PtlsDataModule

from nn.trx_encoder import TimeTrxEncoder
from nn.seq_encoder import ContConvSeqEncoder

from datasets import TS2VecDataset

from utils.encode import encode_data
from utils.evaluation import bootstrap_eval
from utils.preprocessing import CustomDatetimeNormalization

comet_ml is installed but `COMET_API_KEY` is not set.


In [2]:
from ptls.frames.abs_module import ABSModule
from ptls.data_load.padded_batch import PaddedBatch
from ptls.nn.head import Head

from torchmetrics import MeanMetric

from losses.hierarchical_contrastive_loss import HierarchicalContrastiveLoss
from modules import take_per_row, mask_input

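# TS2Vec module adapted to TimeTrxEncoder: event times are passed to the sequence encoder together with the embeddings.
# TODO: merge the two TS2Vec module versions into one.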

class TS2VecTime(ABSModule):
    '''TS2Vec contrastive module for sequence encoders that also consume event times.

    Each step encodes a batch with the transaction encoder, cuts two overlapping
    random crops out of the embeddings and the matching event times, masks both
    crops, runs them through the sequence encoder and keeps the representations
    of the region that the two crops share.
    '''
    def __init__(
        self,
        seq_encoder,
        mask_mode="binomial",
        head=None,
        loss=None,
        validation_metric=None,
        optimizer_partial=None,
        lr_scheduler_partial=None
    ):
        ''' Initialize a TS2Vec model.

        Args:
            seq_encoder: encoder exposing `trx_encoder` and `seq_encoder`
                attributes; both are called separately in `shared_step`.
            mask_mode: masking mode forwarded to `mask_input`.
            head: module applied to both crop representations; defaults to
                `Head(use_norm_encoder=True)`.
            loss: loss called as `loss(y_h, y)`; must expose `temporal_unit`.
                Defaults to `HierarchicalContrastiveLoss(alpha=0.5, temporal_unit=0)`.
            validation_metric: forwarded to `ABSModule.__init__`.
            optimizer_partial: forwarded to `ABSModule.__init__`.
            lr_scheduler_partial: forwarded to `ABSModule.__init__`.
        '''
        if head is None:
            head = Head(use_norm_encoder=True)

        if loss is None:
            loss = HierarchicalContrastiveLoss(alpha=0.5, temporal_unit=0)

        # Plain (non-module) attributes, assigned before the parent constructor runs.
        self.temporal_unit = loss.temporal_unit
        self.mask_mode = mask_mode

        super().__init__(validation_metric,
                         seq_encoder,
                         loss,
                         optimizer_partial,
                         lr_scheduler_partial)

        self._head = head
        self.valid_loss = MeanMetric()

    def shared_step(self, x, y):
        '''Build two overlapping masked crops of the batch and encode them.

        Args:
            x: batch accepted by the transaction encoder; must expose `seq_lens`.
            y: targets, returned unchanged.

        Returns:
            `((out1, out2), y)` where `out1` and `out2` are the encoder outputs
            sliced to the `crop_l` positions shared by both crops (passed through
            the head when one is set).

        Raises:
            ValueError: if the padded sequence length is smaller than the minimal
                crop length `2 ** (temporal_unit + 1)`.
        '''
        trx_encoder = self._seq_encoder.trx_encoder
        seq_encoder = self._seq_encoder.seq_encoder

        # NOTE(review): seq_lens describes the un-cropped sequences but is reused
        # below for the cropped batches - confirm the sequence encoder does not
        # rely on it for masking.
        seq_lens = x.seq_lens
        encoder_out = trx_encoder(x).payload

        x = encoder_out["embeddings"]
        t = encoder_out["event_time"]

        ts_l = x.size(1)
        min_crop_l = 2 ** (self.temporal_unit + 1)
        if ts_l < min_crop_l:
            # np.random.randint would fail here with an opaque "low >= high";
            # raise the same exception type with an actionable message instead.
            raise ValueError(
                f"Sequence length {ts_l} is shorter than the minimal crop length "
                f"{min_crop_l} (temporal_unit={self.temporal_unit})."
            )

        # Shared region [crop_left, crop_right) of length crop_l, extended to the
        # left (crop_eleft <= crop_left) for the first crop and to the right
        # (crop_eright >= crop_right) for the second one.
        crop_l = np.random.randint(low=min_crop_l, high=ts_l + 1)
        crop_left = np.random.randint(ts_l - crop_l + 1)
        crop_right = crop_left + crop_l
        crop_eleft = np.random.randint(crop_left + 1)
        crop_eright = np.random.randint(low=crop_right, high=ts_l + 1)
        # Independent per-row shift that keeps both crops inside [0, ts_l].
        crop_offset = np.random.randint(low=-crop_eleft, high=ts_l - crop_eright + 1, size=x.size(0))

        # First crop: [crop_eleft, crop_right); second crop: [crop_left, crop_eright).
        input1 = take_per_row(x, crop_offset + crop_eleft, crop_right - crop_eleft)
        input2 = take_per_row(x, crop_offset + crop_left, crop_eright - crop_left)

        # Event times are cropped with exactly the same windows as the embeddings.
        t1 = take_per_row(t, crop_offset + crop_eleft, crop_right - crop_eleft)
        t2 = take_per_row(t, crop_offset + crop_left, crop_eright - crop_left)

        input1_masked = mask_input(input1, self.mask_mode)
        input2_masked = mask_input(input2, self.mask_mode)

        # The shared region is the tail of the first crop ...
        out1 = seq_encoder(PaddedBatch({"embeddings": input1_masked, "event_time": t1}, seq_lens)).payload
        out1 = out1[:, -crop_l:]

        # ... and the head of the second crop.
        out2 = seq_encoder(PaddedBatch({"embeddings": input2_masked, "event_time": t2}, seq_lens)).payload
        out2 = out2[:, :crop_l]

        if self._head is not None:
            out1 = self._head(out1)
            out2 = self._head(out2)

        return (out1, out2), y

    def validation_step(self, batch, _):
        '''Compute the loss on a validation batch and add it to the running mean.'''
        y_h, y = self.shared_step(*batch)
        loss = self._loss(y_h, y)
        self.valid_loss(loss)

    def validation_epoch_end(self, outputs):
        '''Log the accumulated validation loss under the name "valid_loss".'''
        self.log('valid_loss', self.valid_loss, prog_bar=True)

    @property
    def is_requires_reduced_sequence(self):
        '''Always False for this module.'''
        return False

    @property
    def metric_name(self):
        '''Name of the logged validation metric.'''
        return "valid_loss"

# Read and preprocess data

In [3]:
# Load the preprocessed transactions table and preview its first rows.
DATA_PATH = "data/preprocessed_new/default.parquet"
df = pd.read_parquet(DATA_PATH)
df.head()

Unnamed: 0,user_id,mcc_code,amount,timestamp,holiday_target,weekend_target,global_target
0,69,119,-342.89792,2021-03-05 02:52:36,0,0,0
1,69,118,-1251.8812,2021-03-05 09:43:28,0,0,0
2,69,106,-87.30924,2021-03-05 11:17:23,0,0,0
3,69,156,-1822.177,2021-03-05 13:41:03,0,0,0
4,69,105,-427.12363,2021-03-05 19:14:23,0,0,0


In [4]:
# Inspect the range of the raw MCC codes as a (min, max) tuple.
mcc_codes = df["mcc_code"]
(mcc_codes.min(), mcc_codes.max())

(0, 308)

In [5]:
# Earliest transaction time, converted to a POSIX timestamp and truncated to an int.
earliest_event = df["timestamp"].min()
min_timestamp = int(earliest_event.timestamp())
min_timestamp

1514769288

In [6]:
# Normalize event times for the convolutions.
# The earliest timestamp in the data is handed to CustomDatetimeNormalization
# (presumably used as the time origin - confirm in utils.preprocessing).
earliest_event = df["timestamp"].min()
min_timestamp = int(earliest_event.timestamp())

time_transformer = CustomDatetimeNormalization(
    col_name_original="timestamp",
    col_name_target="event_time",
    min_timestamp=min_timestamp,
)

# Preprocessor settings: sequences are keyed by user, the MCC code is treated as a
# categorical feature and the global target is passed via `cols_first_item`.
preprocessor_kwargs = {
    "col_id": "user_id",
    "col_event_time": time_transformer,
    "cols_category": ["mcc_code"],
    "cols_first_item": ["global_target"],
}
preprocessor = PandasDataPreprocessor(**preprocessor_kwargs)

data = preprocessor.fit_transform(df)

In [8]:
# One global target per preprocessed record; the count is shown as the cell output.
global_targets = [record["global_target"] for record in data]
n_records = len(global_targets)
n_records

7080

In [9]:
# Fractions of the records held out for validation and for testing.
val_size = 0.1
test_size = 0.1
holdout_size = test_size + val_size

# First split: train vs. the combined validation+test holdout, stratified by target.
train, val_test, targets_train, targets_val_test = train_test_split(
    data,
    global_targets,
    test_size=holdout_size,
    random_state=42,
    stratify=global_targets,
)

# Second split: divide the holdout between validation and test, again stratified.
val, test = train_test_split(
    val_test,
    test_size=test_size / holdout_size,
    random_state=42,
    stratify=targets_val_test,
)

# Wrap every split into a TS2VecDataset (train, then val, then test).
train_ds, val_ds, test_ds = (TS2VecDataset(split) for split in (train, val, test))

# Only the train and validation datasets are handed to the data module.
datamodule = PtlsDataModule(
    train_data=train_ds,
    train_batch_size=16,
    train_num_workers=8,
    valid_data=val_ds,
    valid_batch_size=16,
    valid_num_workers=8,
)

In [10]:
# Mean global target in the train split and in the combined val+test holdout.
train_target_rate = sum(targets_train) / len(targets_train)
holdout_target_rate = sum(targets_val_test) / len(targets_val_test)
(train_target_rate, holdout_target_rate)

(0.037076271186440676, 0.03672316384180791)

In [12]:
# Embedding table for the MCC code.
# NOTE(review): "in": 309 matches the raw mcc_code range 0..308 shown earlier;
# confirm it also covers the category codes produced by the preprocessor.
mcc_embedding = {"in": 309, "out": 24}

# Transaction encoder: MCC embedding plus the amount with the "identity" setting.
trx_encoder = TimeTrxEncoder(
    use_batch_norm_with_lens=True,
    norm_embeddings=False,
    embeddings_noise=0.003,
    embeddings={"mcc_code": mcc_embedding},
    numeric_values={"amount": "identity"},
)

# Sequence encoder on top of the transaction encoder; the sequence is not reduced.
seq_encoder = ContConvSeqEncoder(
    trx_encoder,
    is_reduce_sequence=False,
    kernel_hiddens=[8, 16, 8],
    hidden_size=32,
    num_layers=10,
    kernel_size=5,
    dropout=0.1,
)

# Total number of parameter elements in the encoder.
num_params = sum(p.numel() for p in seq_encoder.parameters())
print("Num parameters:", num_params)

Num parameters: 121610


In [13]:
# Optimizer factory: Adam with a learning rate of 1e-3.
optimizer_partial = partial(torch.optim.Adam, lr=1e-3)

# Scheduler factory: ReduceLROnPlateau in "min" mode with factor 0.9025 and patience 5.
lr_scheduler_partial = partial(
    torch.optim.lr_scheduler.ReduceLROnPlateau,
    mode="min",
    factor=.9025,
    patience=5,
)

# Head and loss are left at the TS2VecTime defaults.
model = TS2VecTime(
    seq_encoder,
    optimizer_partial=optimizer_partial,
    lr_scheduler_partial=lr_scheduler_partial,
)

In [None]:
# Keep the checkpoint with the lowest logged "valid_loss".
checkpoint = ModelCheckpoint(
    monitor="valid_loss",
    mode="min"
)

# SECURITY: the Comet API key is a secret and must not live in the notebook.
# It is read from the COMET_API_KEY environment variable; when the variable is
# unset, None is passed and CometLogger applies its own key lookup/validation.
# The key that used to be hard-coded here should be revoked.
comet_logger = CometLogger(
    api_key=os.environ.get("COMET_API_KEY"),
    project_name="ts2vec-irregular",
    workspace="stalex2902",
    experiment_name="CCNN_TS2Vec_default_check",
    display_summary_level=0,
)

# Train for 5 epochs on GPU index 1, accumulating gradients over 4 batches.
trainer = Trainer(
    max_epochs=5,
    accelerator="gpu",
    devices=[1],
    callbacks=[checkpoint],
    logger=comet_logger,
    accumulate_grad_batches=4
)

trainer.fit(model, datamodule)

# Restore the best checkpoint and export only the sequence-encoder weights.
# NOTE(review): the weights are written to the working directory, while the
# following cell loads "checkpoints/default/ts2vec_ccnn/ts2vec_ccnn_default.pth";
# confirm which location is intended before relying on a fresh run.
model.load_state_dict(torch.load(checkpoint.best_model_path)["state_dict"])
torch.save(model.seq_encoder.state_dict(), "ts2vec_ccnn_default.pth")

In [15]:
# Load the pretrained sequence-encoder weights.
# map_location="cpu" lets the file be read on machines that lack the device the
# tensors were saved from; load_state_dict then copies the values into the
# model's existing parameters.
encoder_state = torch.load(
    "checkpoints/default/ts2vec_ccnn/ts2vec_ccnn_default.pth",
    map_location="cpu",
)
model.seq_encoder.load_state_dict(encoder_state)

<All keys matched successfully>

# Evaluation

In [16]:
# Train and validation records are merged into the training set of the downstream evaluation.
# NOTE(review): this set goes through SeqLenFilter(min_seq_len=25), whereas test_ds
# is the TS2VecDataset built earlier without that filter - confirm this is intended.
train_val_records = train + val
length_filter = SeqLenFilter(min_seq_len=25)
train_val_ds = MemoryMapDataset(train_val_records, [length_filter])

# Encode both sets with the pretrained sequence encoder.
encoder = model.seq_encoder
X_train, y_train = encode_data(encoder, train_val_ds)
X_test, y_test = encode_data(encoder, test_ds)

print("Train size:", len(y_train))
print("Test size:", len(y_test))

Train size: 6372
Test size: 708


In [17]:
# Evaluate the embeddings with bootstrap_eval over 10 runs and display the result.
n_bootstrap_runs = 10
results = bootstrap_eval(X_train, X_test, y_train, y_test, n_runs=n_bootstrap_runs)
results

100%|██████████| 10/10 [00:05<00:00,  1.75it/s]


Unnamed: 0,ROC-AUC,PR-AUC,Accuracy
0,0.513366,0.043722,0.963277
1,0.575118,0.050933,0.963277
2,0.57134,0.053073,0.963277
3,0.557918,0.056636,0.963277
4,0.574554,0.051469,0.963277
5,0.55882,0.046525,0.963277
6,0.503158,0.058105,0.963277
7,0.625818,0.062856,0.963277
8,0.546413,0.059156,0.963277
9,0.555436,0.068872,0.963277


In [18]:
# Mean and standard deviation of every metric column across the runs.
summary_stats = ["mean", "std"]
results.agg(summary_stats)

Unnamed: 0,ROC-AUC,PR-AUC,Accuracy
mean,0.558194,0.055135,0.963277
std,0.034133,0.007567,0.0
