In [None]:
# Mount Google Drive: the configs, preprocessed data and model weights used below live on it.
from google.colab import drive
drive.mount(mountpoint='/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# Скрипт дообучения (finetune) CoLES-модели и сохранения эмбеддингов

## Импорт модулей и подгрузка конфига

In [None]:
# !pip install pytorch-lifestream==0.5.2
# !pip install pytorch-lightning==1.6.*

In [None]:
from functools import partial
import os
import yaml
import joblib
import gc

import torchmetrics
from sklearn.model_selection import train_test_split

import numpy as np
import pandas as pd

import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from ptls.frames.supervised import SeqToTargetDataset, SequenceToTarget
from ptls.nn import TrxEncoder, RnnSeqEncoder,TransformerSeqEncoder,TransformerEncoder,Head
from ptls.frames.coles import CoLESModule
from ptls.data_load.datasets import MemoryMapDataset
from ptls.data_load.iterable_processing import ISeqLenLimit,FeatureFilter,SeqLenFilter
from ptls.data_load.utils import collate_feature_dict
from ptls.frames.coles import ColesDataset
from ptls.frames.coles.split_strategy import SampleSlices,SampleUniform
from ptls.frames import PtlsDataModule
from ptls.frames.inference_module import InferenceModule
from ptls.data_load.utils import collate_feature_dict
from ptls.data_load.datasets import inference_data_loader
import ptls

import logging

In [None]:
# Model config to use; it names the data config, which lives in the same `configs` folder.
config_name = 'sberhack_coles_model_1.yaml'
path_to_working_directory = 'drive/My Drive/курсовая/pytorch-lifestream'
configs_dir = os.path.join(path_to_working_directory, 'configs')


def _read_yaml(path):
    """Parse a YAML file with `yaml.safe_load`, decoding it as UTF-8 regardless of the system locale."""
    with open(path, 'r', encoding='utf-8') as f:
        return yaml.safe_load(f)


model_config = _read_yaml(os.path.join(configs_dir, config_name))
data_config = _read_yaml(os.path.join(configs_dir, model_config['data_config']))

In [None]:
# Head, loss and metric are defined by hand here rather than taken from the config.
# NOTE(review): BCELoss expects probabilities, so the single-output classification Head is
# presumably sigmoid-activated — confirm in ptls `Head`.
head = Head(
    input_size=model_config['rnn_config']['hidden_state'],  # same key as the RNN hidden size of seq_encoder
    use_batch_norm=True,
    num_classes=1,
    objective='classification',
)
loss = torch.nn.BCELoss()
metric = torchmetrics.AUROC(task='binary')

# DataLoader settings for fine-tuning.
batch_size = 10
num_workers = 4

## Загрузка предобработанных данных

In [None]:
# NOTE(review): joblib.load unpickles arbitrary objects — only load artifacts produced by your own pipeline.
def _load_artifact(prefix):
    """Load `<prefix><data-config file name with 'yaml' replaced by 'pickle'>` from the data folder."""
    file_name = prefix + model_config['data_config'].replace('yaml', 'pickle')
    return joblib.load(os.path.join(data_config['path_folder'], file_name))


df_data_train = _load_artifact('train_')
df_data_valid = _load_artifact('valid_')
df_data_test = _load_artifact('test_')
preprocessor = _load_artifact('preprocessor_')

In [None]:
# Wrap each preprocessed split into a ptls MemoryMapDataset.
dataset_train, dataset_valid, dataset_test = (
    MemoryMapDataset(records) for records in (df_data_train, df_data_valid, df_data_test)
)

## Загрузка предобученного seq_encoder и подготовка модели для finetune

In [None]:
# Workaround: only with this subclass could fine-tuning be made to run. It overrides
# `configure_optimizers` so that the returned dict also carries a `monitor` key.
# NOTE(review): presumably needed because `ReduceLROnPlateau` requires a monitored metric and the
# base ptls `SequenceToTarget` does not provide one — confirm against the installed ptls version.
logger = logging.getLogger(__name__)


class SequenceToTarget2(SequenceToTarget):
    """`SequenceToTarget` whose optimizer config also names the metric the LR scheduler monitors."""

    def configure_optimizers(self):
        """Build the optimizer and LR scheduler.

        `self.hparams.pretrained_lr` selects how the encoder is trained:
          * None      - a single parameter group using the lr from `optimizer_partial`;
          * 'freeze'  - every `seq_encoder` parameter gets `requires_grad = False`;
          * otherwise - the encoder uses this lr, the head keeps the default lr.

        Returns:
            dict with "optimizer", "lr_scheduler" and "monitor"; "monitor" is `val_<key>` of the
            first entry of `self.valid_metrics`.

        Raises:
            ValueError: if `self.valid_metrics` is empty, so there is no metric to monitor.
        """
        # Resolve the monitored metric first, so a misconfiguration fails before any
        # encoder parameter has been frozen.
        first_valid_metric = next(iter(self.valid_metrics), None)
        if first_valid_metric is None:
            raise ValueError('SequenceToTarget2 needs at least one validation metric '
                             'to build the `monitor` key for the LR scheduler')

        pretrained_lr = self.hparams.pretrained_lr
        if pretrained_lr is None:
            parameters = self.parameters()
        elif pretrained_lr == 'freeze':
            for p in self.seq_encoder.parameters():
                p.requires_grad = False
            logger.info('Created optimizer with frozen encoder')
            parameters = self.parameters()
        else:
            parameters = [
                {'params': self.seq_encoder.parameters(), 'lr': pretrained_lr},
                {'params': self.head.parameters()},  # use predefined lr from `self.optimizer_partial`
            ]
            logger.info('Created optimizer with two lr groups')

        optimizer = self.optimizer_partial(parameters)
        scheduler = self.lr_scheduler_partial(optimizer)
        return {"optimizer": optimizer,
                "lr_scheduler": scheduler,
                "monitor": f"val_{first_valid_metric}"}

In [None]:
# Encoder settings come from the YAML configs; to change them, edit the config files by hand.
# The embedding sizes must match the ones the pretrained CoLES weights were saved with,
# otherwise `load_state_dict` below fails.
cat_feature_params = {
    k: {'in': v, 'out': v // model_config['rnn_config']['category_emb_dim_reduction']}
    for k, v in preprocessor.get_category_dictionary_sizes().items()
}
num_feature_params = {f: 'identity' for f in data_config['numeric_cols']}

trx_encoder_params = dict(
    embeddings_noise=0.001,
    numeric_values=num_feature_params,
    embeddings=cat_feature_params)
# For a custom RNN or transformer encoder, redefine `seq_encoder` here.
seq_encoder = RnnSeqEncoder(
    trx_encoder=TrxEncoder(**trx_encoder_params),
    hidden_size=model_config['rnn_config']['hidden_state'],
    type=model_config['rnn_config']['rnn_type'])

coles_model = CoLESModule(
    seq_encoder=seq_encoder,
    optimizer_partial=partial(torch.optim.NAdam),
    lr_scheduler_partial=partial(torch.optim.lr_scheduler.ReduceLROnPlateau, mode='min', factor=0.2, patience=2)
)

# Load the pretrained CoLES weights. `map_location='cpu'` lets weights saved from a GPU session be
# read on a CPU-only runtime; `load_state_dict` copies them into the model's own parameters anyway.
coles_weights_path = os.path.join(path_to_working_directory, 'models', config_name.replace('yaml', 'pt'))
coles_model.load_state_dict(torch.load(coles_weights_path, map_location='cpu'))

# The scheduler monitors `val_<metric>` (SequenceToTarget2 adds that key), and the metric is AUROC,
# where higher is better. So ReduceLROnPlateau must run in 'max' mode; 'min' would cut the LR
# while AUROC is still improving.
model = SequenceToTarget2(
    seq_encoder=coles_model.seq_encoder,
    head=head,
    loss=loss,
    metric_list=metric,
    optimizer_partial=partial(torch.optim.NAdam),
    lr_scheduler_partial=partial(torch.optim.lr_scheduler.ReduceLROnPlateau, mode='max', factor=0.2, patience=2)
)
# Using partial(torch.optim.lr_scheduler.StepLR, step_size=20, gamma=0.9) here fails with:
# MisconfigurationException: The provided lr scheduler `StepLR` doesn't follow PyTorch's LRScheduler API. You should override the `LightningModule.lr_scheduler_step` hook with your own logic if you are using a custom LR scheduler.

sup_data = PtlsDataModule(
    train_data=SeqToTargetDataset(dataset_train, target_col_name=data_config['target_col'], target_dtype=torch.float),
    valid_data=SeqToTargetDataset(dataset_valid, target_col_name=data_config['target_col'], target_dtype=torch.float),
    train_batch_size=batch_size,
    valid_batch_size=batch_size,
    train_num_workers=num_workers,
)

## Finetune модели

In [None]:
# Both callbacks track the validation AUROC (higher is better). The explicit ModelCheckpoint makes
# `ckpt_path='best'` restore the best-AUROC epoch. Lightning's default checkpoint callback monitors
# nothing and simply keeps the most recent epoch, which after early stopping is not the best one.
monitor_name = f"val_{metric._get_name()}"
early_stop_callback = EarlyStopping(monitor=monitor_name, min_delta=0.01, patience=3, verbose=False, mode='max')
checkpoint_callback = pl.callbacks.ModelCheckpoint(monitor=monitor_name, mode='max', save_top_k=1)
trainer = pl.Trainer(
    max_epochs=20,
    gpus=1 if torch.cuda.is_available() else 0,
    enable_progress_bar=True,
    callbacks=[early_stop_callback, checkpoint_callback],
)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True, used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


In [None]:
%%time
# Collect unreferenced Python objects first, then release cached-but-unused CUDA memory,
# so fine-tuning starts with as much free GPU memory as possible.
gc.collect()
torch.cuda.empty_cache()
trainer.fit(model=model, datamodule=sup_data)

INFO:pytorch_lightning.accelerators.gpu:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name          | Type          | Params
------------------------------------------------
0 | seq_encoder   | RnnSeqEncoder | 80.4 K
1 | head          | Head          | 385   
2 | loss          | BCELoss       | 0     
3 | train_metrics | ModuleDict    | 0     
4 | valid_metrics | ModuleDict    | 0     
5 | test_metrics  | ModuleDict    | 0     
------------------------------------------------
80.8 K    Trainable params
0         Non-trainable params
80.8 K    Total params
0.323     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]



Training: 0it [00:00, ?it/s]



Validation: 0it [00:00, ?it/s]



Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

CPU times: user 1min 26s, sys: 9.61 s, total: 1min 36s
Wall time: 2min 2s


In [None]:
# Re-evaluate the checkpoint the trainer considers "best". NOTE: the loader is the *validation*
# split (the test split is used without a target in this notebook), so the reported `test_*`
# metric is really a validation score.
gc.collect()
torch.cuda.empty_cache()
trainer.test(dataloaders=sup_data.val_dataloader(), ckpt_path='best')

INFO:pytorch_lightning.utilities.rank_zero:Restoring states from the checkpoint path at /content/lightning_logs/version_3/checkpoints/epoch=4-step=3215.ckpt
INFO:pytorch_lightning.accelerators.gpu:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.utilities.rank_zero:Loaded model weights from checkpoint at /content/lightning_logs/version_3/checkpoints/epoch=4-step=3215.ckpt


Testing: 0it [00:00, ?it/s]

[{'test_BinaryAUROC': 0.8600592017173767}]

In [None]:
# `model` was built around `coles_model.seq_encoder` (the same module object), so the fine-tuned
# weights normally already live inside `coles_model`. Copying weights instead of reassigning the
# attribute keeps the keys of `coles_model.state_dict()` unchanged.
# NOTE(review): `coles_model.seq_encoder = ...` goes through `nn.Module.__setattr__`. If ptls exposes
# `seq_encoder` as a read-only property over a private attribute, that assignment registers a second
# submodule and adds duplicate `seq_encoder.*` keys to the saved state_dict — confirm against ptls.
if coles_model.seq_encoder is not model.seq_encoder:
    coles_model.seq_encoder.load_state_dict(model.seq_encoder.state_dict())
torch.save(coles_model.state_dict(), os.path.join(path_to_working_directory, 'models', 'finetuned_' + config_name.replace('yaml', 'pt')))

## Сохранение эмбеддингов

In [None]:
def embedding_dataset(coles_model, prep_records_data, trainer, id_col, target_col=None,
                      batch_size=16, num_workers=1):
    """Embed every record with `coles_model` and return the embeddings as a DataFrame.

    Args:
        coles_model: Lightning module passed to `trainer.predict`.
        prep_records_data: preprocessed records; each must be indexable by `id_col`
            (and by `target_col` when it is given).
        trainer: Lightning trainer used to run prediction.
        id_col: name of the id field copied into the result.
        target_col: optional name of a target field copied into the result.
        batch_size: batch size forwarded to `inference_data_loader`.
        num_workers: number of loader workers forwarded to `inference_data_loader`.

    Returns:
        DataFrame with columns `embed_0 .. embed_{d-1}`, `id_col` and, if given, `target_col`.
        Rows are deduplicated on `id_col` (first occurrence kept).
    """
    dl = inference_data_loader(prep_records_data, num_workers=num_workers, batch_size=batch_size)
    # `trainer.predict` yields one tensor per batch; stack them row-wise into a 2-D matrix.
    embeds = torch.vstack(trainer.predict(coles_model, dl))
    # Hand pandas a plain CPU numpy array rather than a torch tensor.
    embeds = embeds.detach().cpu().numpy()

    df = pd.DataFrame(data=embeds, columns=[f'embed_{i}' for i in range(embeds.shape[1])])
    # NOTE(review): assumes the inference loader keeps record order (no shuffling), so row i
    # corresponds to prep_records_data[i].
    df[id_col] = [x[id_col] for x in prep_records_data]
    if target_col:
        df[target_col] = [x[target_col] for x in prep_records_data]

    return df.drop_duplicates(subset=id_col)

In [None]:
# Free memory, then embed each split; the test split is embedded without a target column.
gc.collect()
torch.cuda.empty_cache()
train_embded_df = embedding_dataset(coles_model, df_data_train, trainer,
                                    id_col=data_config['id_col'], target_col=data_config['target_col'])
valid_embded_df = embedding_dataset(coles_model, df_data_valid, trainer,
                                    id_col=data_config['id_col'], target_col=data_config['target_col'])
test_embded_df = embedding_dataset(coles_model, df_data_test, trainer,
                                   id_col=data_config['id_col'])

INFO:pytorch_lightning.accelerators.gpu:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 643it [00:00, ?it/s]

INFO:pytorch_lightning.accelerators.gpu:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 643it [00:00, ?it/s]

INFO:pytorch_lightning.accelerators.gpu:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 643it [00:00, ?it/s]

In [None]:
# Save each split's embeddings as `<split>_finetuned_emb_<config name with 'yaml' -> 'csv'>`.
# NOTE: earlier runs wrote the valid/test files as `*_finetuned__emb_*` (double-underscore typo);
# downstream code reading those names must switch to the consistent `_finetuned_emb_` form.
embedding_file_suffix = '_finetuned_emb_' + config_name.replace('yaml', 'csv')
for split_name, split_df in (('train', train_embded_df), ('valid', valid_embded_df), ('test', test_embded_df)):
    split_df.to_csv(os.path.join(data_config['path_folder'], split_name + embedding_file_suffix),
                    encoding='utf-8', index=False)