In [1]:
import os

# Work from the project's app/ directory so the relative data path and the
# local `utils` package resolve. The starting directory is remembered in the
# notebook namespace on first run, so re-running this cell always lands in the
# same place instead of descending into app/app/ and raising FileNotFoundError.
NOTEBOOK_DIR = globals().get("NOTEBOOK_DIR", os.getcwd())
os.chdir(os.path.join(NOTEBOOK_DIR, "app"))

In [15]:
from functools import partial

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

from sklearn.model_selection import train_test_split

import torch.nn as nn
import torch

from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.trainer import Trainer

from ptls.data_load.iterable_processing import SeqLenFilter
from ptls.data_load.datasets import MemoryMapDataset
from ptls.preprocessing import PandasDataPreprocessor

from ptls.nn import TrxEncoder
from ptls.nn.seq_encoder import RnnSeqEncoder

from ptls.frames import PtlsDataModule
from ptls.frames.coles.coles_module import CoLESModule
from ptls.frames.coles import ColesDataset
from ptls.frames.coles.split_strategy import SampleSlices, NoSplit

from utils.encode import encode_data
from utils.evaluation import bootstrap_eval

## Load data

In [3]:
# Event-level churn table: one row per transaction (user_id, mcc_code, timestamp, amount, targets, ...).
churn_path = "data/preprocessed_new/churn.parquet"
df = pd.read_parquet(churn_path)
df.head()

Unnamed: 0,user_id,mcc_code,timestamp,amount,global_target,holiday_target,weekend_target,churn_target,minute,hour,day,month,day_of_week,time_delta
0,0,2,2017-10-12 12:24:07,20000.0,0,0,0,0,24,12,12,10,3,0
1,0,19,2017-10-21 00:00:00,5023.0,0,0,1,0,0,0,21,10,5,732953
2,0,1,2017-10-21 00:00:00,2031.0,0,0,1,0,0,0,21,10,5,0
3,0,9,2017-10-24 13:14:24,36562.0,0,0,0,0,14,13,24,10,1,306864
4,0,10,2017-12-05 00:00:00,767.0,0,0,0,0,0,0,5,12,1,3581136


## Prepare datasets and dataloaders

In [4]:
# Preprocessor settings, grouped so they are easy to scan and tweak:
# - user_id groups events into per-user sequences;
# - the timestamp column is converted with "dt_to_timestamp";
# - mcc_code is encoded as a categorical feature;
# - global_target is kept as a single per-user value (presumably the first row's).
preprocessor_config = dict(
    col_id="user_id",
    col_event_time="timestamp",
    event_time_transformation="dt_to_timestamp",
    cols_category=["mcc_code"],
    cols_first_item=["global_target"],
)
preprocessor = PandasDataPreprocessor(**preprocessor_config)

In [5]:
# Fit the preprocessor on the event table and convert it into per-user records
# (presumably a list of dicts; the split cells below treat `data` as a list).
# NOTE(review): the preprocessor is fitted on ALL users before the train/val/test
# split, so the mcc_code category encoding also sees val/test users — mild
# leakage; consider fitting on train users only.
data = preprocessor.fit_transform(df)

In [6]:
val_size = 0.1
test_size = 0.1

# Two-stage split: first hold out val+test together, then cut that holdout so
# val and test end up with val_size and test_size of the full data respectively.
holdout_fraction = val_size + test_size
train, val_test = train_test_split(data, test_size=holdout_fraction, random_state=42)
val, test = train_test_split(val_test, test_size=test_size / holdout_fraction, random_state=42)

In [7]:
# Users with fewer events than this are dropped from every split, so the CoLES
# slicer below always has at least MIN_SEQ_LEN events to sample from.
MIN_SEQ_LEN = 15
# CoLES positive pairs: SampleSlices(split_count, cnt_min, cnt_max) cuts each
# user's sequence into SPLIT_COUNT sub-sequences of SLICE_MIN_LEN..SLICE_MAX_LEN
# events (ptls positional-argument order — confirm against installed ptls).
SPLIT_COUNT, SLICE_MIN_LEN, SLICE_MAX_LEN = 5, MIN_SEQ_LEN, 150

train_ds = ColesDataset(
    data=MemoryMapDataset(train, [SeqLenFilter(min_seq_len=MIN_SEQ_LEN)]),
    splitter=SampleSlices(SPLIT_COUNT, SLICE_MIN_LEN, SLICE_MAX_LEN),
)
# NOTE(review): validation slices are presumably sampled randomly as well, so the
# monitored val metric carries sampling noise between epochs.
val_ds = ColesDataset(
    data=MemoryMapDataset(val, [SeqLenFilter(min_seq_len=MIN_SEQ_LEN)]),
    splitter=SampleSlices(SPLIT_COUNT, SLICE_MIN_LEN, SLICE_MAX_LEN),
)
# Test users are not sliced: encode_data() embeds them whole for downstream evaluation.
test_ds = MemoryMapDataset(test, [SeqLenFilter(min_seq_len=MIN_SEQ_LEN)])

In [8]:
# Same loader settings for train and validation.
batch_size = 128
num_workers = 8

datamodule = PtlsDataModule(
    train_data=train_ds,
    train_batch_size=batch_size,
    train_num_workers=num_workers,
    valid_data=val_ds,
    valid_batch_size=batch_size,
    valid_num_workers=num_workers,
)

# CoLES (baseline: `mcc_code` + `amount`)

## Model training

In [31]:
# Fix python/numpy/torch seeds (and dataloader-worker seeding) BEFORE the weights
# are initialised, so this experiment is reproducible regardless of which
# section of the notebook was run first. GPU LSTM kernels may still be
# non-deterministic; full determinism would also need Trainer(deterministic=True).
seed_everything(42, workers=True)

# Per-event features: a learned 24-dim embedding of mcc_code plus the raw amount.
# NOTE(review): "in": 345 must exceed the largest encoded mcc_code index produced
# by the preprocessor — confirm against the preprocessor output.
trx_encoder = TrxEncoder(
    embeddings={
        "mcc_code": {"in": 345, "out": 24},
    },
    numeric_values={
        "amount": "identity",
    },
    use_batch_norm_with_lens=True,
    norm_embeddings=False,
    embeddings_noise=0.0003,
)

# Unidirectional LSTM (hidden size 1024) over the per-event embeddings.
seq_encoder = RnnSeqEncoder(
    trx_encoder,
    hidden_size=1024,
    type="lstm",
    bidir=False,
    trainable_starter="static",
)

In [32]:
# Optimizer and scheduler are passed as factories; CoLESModule binds them to its own parameters.
optimizer_partial = partial(torch.optim.Adam, lr=4e-3)
# Decay LR by 0.9025 after 5 epochs without improvement. mode="max" presumably
# because the stepped metric is CoLES' recall_top_k (also monitored by the
# checkpoint below), which is higher-is-better — confirm in ptls.
lr_scheduler_partial = partial(
    torch.optim.lr_scheduler.ReduceLROnPlateau,
    mode="max",
    factor=0.9025,
    patience=5,
)

model = CoLESModule(
    seq_encoder,
    lr_scheduler_partial=lr_scheduler_partial,
    optimizer_partial=optimizer_partial,
)

In [None]:
# Keep the checkpoint with the best validation recall_top_k (higher is better).
checkpoint = ModelCheckpoint(
    monitor="recall_top_k",
    mode="max",
)

# The experiments were run on GPU index 1; fall back to GPU 0 on single-GPU
# machines instead of failing when the Trainer is constructed.
gpu_devices = [1] if torch.cuda.device_count() > 1 else [0]

trainer = Trainer(
    max_epochs=50,
    devices=gpu_devices,
    accelerator="gpu",
    callbacks=[checkpoint],
)

trainer.fit(model, datamodule)

In [34]:
# Restore the weights of the best checkpoint written by ModelCheckpoint during fit.
best_ckpt_path = checkpoint.best_model_path
# best_model_path is "" when no checkpoint was saved (e.g. fit interrupted before
# the first validation); fail clearly instead of with a cryptic torch.load error.
assert best_ckpt_path, "ModelCheckpoint saved no checkpoint; run trainer.fit first"
# map_location="cpu" lets a checkpoint written during a GPU run load on machines
# without that device; load_state_dict copies tensors onto the model's own device.
best_state_dict = torch.load(best_ckpt_path, map_location="cpu")["state_dict"]
# Inference mode for the encode_data() calls below: training-only behaviour
# (presumably embeddings_noise and batch-norm batch statistics) should be off.
# Harmless if encode_data() already calls eval() itself.
model.eval()
model.load_state_dict(best_state_dict)

<All keys matched successfully>

In [41]:
# Persist only the sequence encoder's weights; to reuse them, rebuild an
# identically configured RnnSeqEncoder and call load_state_dict on it.
seq_encoder_state = model.seq_encoder.state_dict()
torch.save(seq_encoder_state, "coles_churn.pth")

## Model evaluation

In [36]:
# The downstream classifier is fit on every non-test user: train and val records pooled.
# The min_seq_len=15 filter matches the one used for the CoLES datasets.
train_val_records = train + val
train_val_ds = MemoryMapDataset(train_val_records, [SeqLenFilter(min_seq_len=15)])

In [37]:
# Embed each user with the trained sequence encoder; encode_data presumably
# returns (embeddings, targets) — see utils.encode.
X_train, y_train = encode_data(model.seq_encoder, train_val_ds)
X_test, y_test = encode_data(model.seq_encoder, test_ds)

for size_label, split_targets in (("Train size:", y_train), ("Test size:", y_test)):
    print(size_label, len(split_targets))

Train size: 3961
Test size: 443


In [38]:
# Repeat the downstream evaluation n_runs times (see utils.evaluation.bootstrap_eval);
# the result has one ROC-AUC / PR-AUC row per run.
n_bootstrap_runs = 10
results = bootstrap_eval(
    X_train, X_test, y_train, y_test,
    n_runs=n_bootstrap_runs,
)

In [39]:
# Per-run downstream metrics (one row per bootstrap run).
results

Unnamed: 0,ROC-AUC,PR-AUC
0,0.735462,0.80103
1,0.742974,0.80585
2,0.751693,0.820683
3,0.735801,0.803453
4,0.732775,0.802779
5,0.754232,0.822166
6,0.734637,0.815245
7,0.733452,0.804415
8,0.730807,0.793126
9,0.728796,0.803373


In [40]:
# Summary across runs: mean and standard deviation of each metric column.
results.agg(["mean", "std"])

Unnamed: 0,ROC-AUC,PR-AUC
mean,0.738063,0.807212
std,0.00871,0.009216


# CoLES with time features (`event_time`, `time_delta`)

## Model training

In [9]:
# Fix python/numpy/torch seeds (and dataloader-worker seeding) BEFORE the weights
# are initialised, so this experiment is reproducible regardless of which
# section of the notebook was run first. GPU LSTM kernels may still be
# non-deterministic; full determinism would also need Trainer(deterministic=True).
seed_everything(42, workers=True)

# Per-event features: a learned 24-dim embedding of mcc_code, the raw amount,
# and two time features fed as raw numeric inputs.
# NOTE(review): event_time is the absolute timestamp from
# event_time_transformation="dt_to_timestamp" (presumably seconds since epoch,
# ~1.5e9). It is presumably only usable because use_batch_norm_with_lens
# normalises numeric inputs — confirm, or consider a relative time feature.
# NOTE(review): "in": 345 must exceed the largest encoded mcc_code index — confirm.
trx_encoder = TrxEncoder(
    embeddings={
        "mcc_code": {"in": 345, "out": 24},
    },
    numeric_values={
        "amount": "identity",
        "event_time": "identity",
        "time_delta": "identity",
    },
    use_batch_norm_with_lens=True,
    norm_embeddings=False,
    embeddings_noise=0.0003,
)

# Unidirectional LSTM (hidden size 1024) over the per-event embeddings.
seq_encoder = RnnSeqEncoder(
    trx_encoder,
    hidden_size=1024,
    type="lstm",
    bidir=False,
    trainable_starter="static",
)

In [10]:
# Optimizer and scheduler are passed as factories; CoLESModule binds them to its own parameters.
optimizer_partial = partial(torch.optim.Adam, lr=4e-3)
# Decay LR by 0.9025 after 5 epochs without improvement. mode="max" presumably
# because the stepped metric is CoLES' recall_top_k (also monitored by the
# checkpoint below), which is higher-is-better — confirm in ptls.
lr_scheduler_partial = partial(
    torch.optim.lr_scheduler.ReduceLROnPlateau,
    mode="max",
    factor=0.9025,
    patience=5,
)

model = CoLESModule(
    seq_encoder,
    lr_scheduler_partial=lr_scheduler_partial,
    optimizer_partial=optimizer_partial,
)

In [None]:
# Keep the checkpoint with the best validation recall_top_k (higher is better).
checkpoint = ModelCheckpoint(
    monitor="recall_top_k",
    mode="max",
)

# The experiments were run on GPU index 1; fall back to GPU 0 on single-GPU
# machines instead of failing when the Trainer is constructed.
gpu_devices = [1] if torch.cuda.device_count() > 1 else [0]

trainer = Trainer(
    max_epochs=50,
    devices=gpu_devices,
    accelerator="gpu",
    callbacks=[checkpoint],
)

trainer.fit(model, datamodule)

In [21]:
# Restore the weights of the best checkpoint written by ModelCheckpoint during fit.
best_ckpt_path = checkpoint.best_model_path
# best_model_path is "" when no checkpoint was saved (e.g. fit interrupted before
# the first validation); fail clearly instead of with a cryptic torch.load error.
assert best_ckpt_path, "ModelCheckpoint saved no checkpoint; run trainer.fit first"
# map_location="cpu" lets a checkpoint written during a GPU run load on machines
# without that device; load_state_dict copies tensors onto the model's own device.
best_state_dict = torch.load(best_ckpt_path, map_location="cpu")["state_dict"]
# Inference mode for the encode_data() calls below: training-only behaviour
# (presumably embeddings_noise and batch-norm batch statistics) should be off.
# Harmless if encode_data() already calls eval() itself.
model.eval()
model.load_state_dict(best_state_dict)

<All keys matched successfully>

In [30]:
# Persist only the sequence encoder's weights; to reuse them, rebuild an
# identically configured RnnSeqEncoder and call load_state_dict on it.
seq_encoder_state = model.seq_encoder.state_dict()
torch.save(seq_encoder_state, "coles_churn_date.pth")

## Model evaluation

In [24]:
# The downstream classifier is fit on every non-test user: train and val records pooled.
# The min_seq_len=15 filter matches the one used for the CoLES datasets.
train_val_records = train + val
train_val_ds = MemoryMapDataset(train_val_records, [SeqLenFilter(min_seq_len=15)])

In [25]:
# Embed each user with the trained sequence encoder; encode_data presumably
# returns (embeddings, targets) — see utils.encode.
X_train, y_train = encode_data(model.seq_encoder, train_val_ds)
X_test, y_test = encode_data(model.seq_encoder, test_ds)

for size_label, split_targets in (("Train size:", y_train), ("Test size:", y_test)):
    print(size_label, len(split_targets))

Train size: 3961
Test size: 443


In [26]:
# Repeat the downstream evaluation n_runs times (see utils.evaluation.bootstrap_eval);
# the result has one ROC-AUC / PR-AUC row per run.
n_bootstrap_runs = 10
results = bootstrap_eval(
    X_train, X_test, y_train, y_test,
    n_runs=n_bootstrap_runs,
)

In [27]:
# Per-run downstream metrics (one row per bootstrap run).
results

Unnamed: 0,ROC-AUC,PR-AUC
0,0.797338,0.863472
1,0.805485,0.871617
2,0.812807,0.874014
3,0.803771,0.867662
4,0.80612,0.854886
5,0.800406,0.859777
6,0.798713,0.858944
7,0.812807,0.873045
8,0.815177,0.872571
9,0.806882,0.862757


In [29]:
# Summary across runs: mean and standard deviation of each metric column.
results.agg(["mean", "std"])

Unnamed: 0,ROC-AUC,PR-AUC
mean,0.805951,0.865875
std,0.006167,0.006833
