In [1]:
%load_ext autoreload
%autoreload 2

import os
import sys
import torch 
import numpy as np
import pickle
from torch.utils.data import Dataset, DataLoader

sys.path.append('../')

from src.transactions_qa.tqa_model import TransactionQAModel
from src.models.components.models import TransactionsModel
from src.utils.tools import (make_time_batch, 
                   calculate_embedding_size)

from src.data.alfa.components import ( 
                             cat_features_names, 
                             num_features_names, 
                             meta_features_names)

from src.data import AlfaDataModule, AlfaPretrainingDataModule
from src.transactions_qa.tqa_model import TransactionQAModel
from src.transactions_qa.utils import get_projections_maps
from src.tasks import AbstractTask, AutoTask

from src.models.components.embedding import EmbeddingLayer

import pytorch_lightning as pl

from ptls.frames import PtlsDataModule
from ptls.frames.bert import MLMPretrainModule
from ptls.frames.coles import CoLESModule
from ptls.frames.cpc import CpcModule

from ptls.nn import TransformerEncoder
from ptls.nn.seq_encoder.containers import SeqEncoderContainer, RnnSeqEncoder
from functools import partial
from collections import namedtuple
from ptls.frames.coles.split_strategy import SampleSlices


In [2]:
def load_transaction_model(encoder_type='whisper/tiny', head_type='next'):
    projections_maps = get_projections_maps(relative_folder="..")
    # Loading Transactions model & weights
    print(f"Loading Transactions model...")

    transactions_model_encoder_type = encoder_type
    transactions_model_head_type = head_type


    transactions_model_config = {
        "cat_features": cat_features_names,
        "cat_embedding_projections": projections_maps.get('cat_embedding_projections'),
        "num_features": num_features_names,
        "num_embedding_projections": projections_maps.get('num_embedding_projections'),
        "meta_features": meta_features_names,
        "meta_embedding_projections": projections_maps.get('meta_embedding_projections'),
        "encoder_type": transactions_model_encoder_type,
        "head_type": transactions_model_head_type,
        "embedding_dropout": 0.1
    }
    transactions_model = TransactionsModel(**transactions_model_config)

    return transactions_model, projections_maps

In [3]:
transaction_model, projection_maps = load_transaction_model(head_type='next')

Loading Transactions model...
USING whisper


In [4]:
class PtlsPaddedBatch:
    def __init__(self, data, mask):
        self.payload = data
        self.seq_lens = torch.LongTensor([data.shape[1]] * data.shape[0])
        self.seq_len_mask = mask

In [5]:
class PtlsEmbeddingLayer(EmbeddingLayer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.output_size = super().get_embedding_size()

    def forward(self, x):
        mask=x['mask']
        x = super().forward(x)
        return x

In [6]:
trx_encoder = PtlsEmbeddingLayer(projection_maps['cat_embedding_projections'],
                                    cat_features_names,
                                    projection_maps['num_embedding_projections'],
                                    num_features_names)

In [7]:
dm = AlfaDataModule(data_dir='/home/jovyan/romashka/data/')

### Resaving weights

In [10]:
# ckpt = torch.load('/home/jovyan/checkpoints/transactions_model/final_model.ckpt')
# new_ckpt = {}
# for elem in ckpt:
#     if 'encoder' in elem:
#         new_ckpt['encoder_model.' + elem] = ckpt[elem]
#     elif 'head' in elem:
#         new_ckpt['head.'+elem] = ckpt[elem]
#     elif 'mapping_embedding' in elem:
#         new_ckpt['connector.connector'+elem[len('mapping_embedding'):]] = ckpt[elem]
#     else:
#         new_ckpt[elem] = ckpt[elem]

# torch.save(new_ckpt, '/home/jovyan/checkpoints/transactions_model/final_model_v2.ckpt')

### Splitter

In [8]:
splitter = SampleSlices(split_count=7, cnt_min=10, cnt_max=30, is_sorted=True)

In [9]:
splitter.split_count

7

In [12]:
new_dm = AlfaPretrainingDataModule(data_dir='/home/jovyan/romashka/data', rep=7, mode='coles', splitter=splitter)

In [13]:
batch = next(iter(new_dm.val_dataloader()))

In [12]:
from torch.nn.utils.rnn import pad_sequence

def collate_fn(batch):
    output = {}

    # cat_features shape 1 x cat_features x seq_len
    # num_features shape 1 x num_features x seq_len
    # meta_feature shape 1 x meta_features
    # mask shape 1 x seq_len
    # label shape 1

    # checking batch_size correctness
    assert batch[0]['num_features'].shape[1] == 1, "Incorrect output of dataloader"

    output['num_features'] = pad_sequence([d['num_features'].transpose(0, -1) for d in batch], # num_features x batch_size x seq_len
                                            batch_first=True).squeeze(2).permute(-1, 0, 1)
    output['cat_features'] = pad_sequence([d['cat_features'].transpose(0, -1) for d in batch], # cat_features x batch_size x seq_len
                                            batch_first=True).squeeze(2).permute(-1, 0, 1)
    output['meta_features'] = torch.cat([d['meta_features'] for d in batch], dim=1) # meta_features x batch_size

    output['mask'] = pad_sequence([d['mask'].transpose(0, -1) for d in batch], batch_first=True).squeeze(2)
    output['app_id'] = torch.cat([d['app_id'] for d in batch])

    if 'label' in batch[0]:
        output['label'] = torch.cat([d['label'] for d in batch])

    return output

In [13]:
def split_process(batch, splitter):
    res = {}


    seq_len = batch['mask'].shape[1]
    local_date = torch.arange(seq_len)
    if splitter is not None:
        indexes = splitter.split(local_date)
        pad_size = max([len(ixs) for ixs in indexes])
    
    for k, v in batch.items():
        if k in ['num_features', 'cat_features'] and splitter is not None:
            new_v = []
            for elem in v:
                tmp = []
                for i, ixs in enumerate(indexes):
                    to_tmp = elem[:, ixs]
                    if to_tmp.shape[1] < pad_size:
                        to_tmp = torch.cat([
                            to_tmp, torch.zeros(to_tmp.shape[0], pad_size - to_tmp.shape[1], dtype=torch.int)
                        ], axis=1)
                    tmp.append(to_tmp)
                new_v.append(torch.cat(tmp, dim=0))
            new_v = torch.stack(new_v, dim=0)
        elif k == 'meta_features' and splitter is not None:
            new_v = v.repeat(1, len(indexes))
        else:
            new_v = v 
        res[k] = new_v
    res['mask'] = res['cat_features'][0] != 0
    return res

In [14]:
def my_collate_fn(batch, splitter, rep=7, mode='coles'):
    batch = collate_fn(batch)
    len_batch = batch['num_features'][0].shape[0]
    labels = torch.arange(len_batch, device=batch['mask'].device).repeat(rep)
    batch = split_process(batch, splitter)
    
    if mode == 'cpc':
        return batch, None
    elif mode == 'coles':
        return batch, labels

In [15]:
train_dataloader = DataLoader(
    dm.val_ds,
    batch_size=32,
    collate_fn=partial(my_collate_fn, splitter=splitter,
                       rep=7 if splitter is not None else 1,
                       mode='coles')

)

In [16]:
b = next(iter(train_dataloader))

In [17]:
batch = next(iter(dm.val_dataloader()))

### Training

In [5]:
import torch.nn as nn

class EmbeddingPlusConnector(nn.Module):
    def __init__(self, embedding_layer, connector):
        super().__init__()
        self.embedding_layer = embedding_layer
        self.connector = connector
        self.output_size = connector.output_size

    def forward(self, batch):
        mask = batch['mask']
        
        embedding = self.embedding_layer(batch)
        embedding = self.connector(embedding, attention_mask=mask)
        return embedding

In [11]:
# TODO split trx_encoder and seq_encoder

class MySeqEncoder(SeqEncoderContainer):
    def __init__(self,):

        params =  {
            'cat_embedding_projections': projection_maps['cat_embedding_projections'],
            'cat_features': cat_features_names,
            'num_embedding_projections': projection_maps['num_embedding_projections'],
            'num_features': num_features_names,
            'head_type': 'pretraining_last_output',
            'encoder_type': 'whisper/tiny'
        }
        super().__init__(
            trx_encoder=None,
            seq_encoder_cls=TransactionsModel,
            input_size=False,
            seq_encoder_params=params,
            is_reduce_sequence=False,
        )
        self.trx_encoder = EmbeddingPlusConnector(self.seq_encoder.embedding, self.seq_encoder.connector)
                
        self.full_model = self.seq_encoder
        self.seq_encoder = self.seq_encoder.encoder_model
        self.seq_encoder.embedding_size = self.seq_encoder.output_size
    
    def forward(self, x, h_0=None):
        x = self.full_model(x)
        return x

In [12]:
seq_encoder = MySeqEncoder()

USING whisper


In [29]:
model = CoLESModule(
    seq_encoder=seq_encoder,
    optimizer_partial=partial(torch.optim.Adam, lr=1e-3),
    lr_scheduler_partial=partial(torch.optim.lr_scheduler.StepLR, step_size=20, gamma=0.9)
)

In [75]:
trainer = pl.Trainer()

GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [51]:
tmp = model.shared_step(*b)

In [76]:
trainer.fit(model, train_dataloader)

Missing logger folder: /home/jovyan/romashka/notebooks/lightning_logs

  | Name               | Type            | Params
-------------------------------------------------------
0 | _loss              | ContrastiveLoss | 0     
1 | _seq_encoder       | MySeqEncoder    | 29.5 M
2 | _validation_metric | BatchRecallTopK | 0     
3 | _head              | Head            | 0     
-------------------------------------------------------
29.5 M    Trainable params
0         Non-trainable params
29.5 M    Total params
117.860   Total estimated model params size (MB)
  tensorboard.__version__
  from urllib3.contrib.pyopenssl import orig_util_SSLContext as SSLContext
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  (np.object, string),
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  (np.bool, bool),
Deprecated in NumPy 1.20; for more details and gui

Training: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
