In [1]:
import pandas as pd
import numpy as np
from tqdm import tqdm
import pyarrow.parquet as pq
import sys
import random

sys.path.append('../../../')
from configs.data_configs.rosbank import data_configs

import split_strategy
from splitting_dataset import SplittingDataset, TargetEnumeratorDataset, ConvertingTrxDataset, DropoutTrxDataset


In [2]:
conf = data_configs()

In [3]:
df = pd.read_parquet(conf.train_path)
df.head()

Unnamed: 0,cl_id,amount,event_time,mcc,channel_type,currency,trx_category,trx_count,target_target_flag,target_target_sum
0,10018,"[10.609081944147828, 10.596659732783579, 10.81...","[17120.38773148148, 17133.667800925927, 17134....","[13, 2, 13, 2, 1, 18, 13, 2, 13, 2, 5, 13, 9, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]","[5, 3, 5, 3, 1, 1, 5, 3, 5, 3, 1, 5, 5, 5, 5]",15,0,0.0
1,10030,"[4.61512051684126, 6.90875477931522, 10.598857...","[17141.0, 17141.0, 17145.0, 17147.0, 17147.0, ...","[9, 9, 21, 1, 25, 6, 14, 14, 3, 3, 3, 13, 1, 3...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[5, 5, 1, 1, 1, 1, 1, 1, 1, 1, 1, 7, 1, 1, 3, ...",42,1,59.51
2,10038,"[7.4127640174265625, 7.370230641807081, 7.8180...","[17301.0, 17301.0, 17301.0, 17301.774780092594...","[1, 1, 1, 2, 2, 4, 2, 8, 1, 22, 8, 1, 8, 4, 2,...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[1, 1, 1, 2, 2, 1, 3, 1, 1, 1, 1, 1, 1, 1, 2, ...",111,0,0.0
3,10057,"[7.494708263135679, 7.736394428979239, 10.7789...","[17151.0, 17151.0, 17153.0, 17154.0, 17155.0, ...","[6, 21, 2, 6, 2, 4, 2, 22, 15, 2, 1, 35, 4, 2,...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[1, 1, 4, 1, 4, 1, 3, 1, 1, 3, 1, 1, 1, 4, 1, ...",61,1,62961.31
4,10062,"[8.31898612539206, 8.824824939175638, 6.509067...","[17143.0, 17143.0, 17143.0, 17144.0, 17144.0, ...","[80, 15, 37, 38, 11, 11, 2, 24, 7, 5, 5, 11, 1...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[1, 1, 1, 1, 1, 1, 3, 1, 1, 1, 1, 1, 1, 1, 1, ...",82,1,107126.35


In [4]:
def read_pyarrow_file(path, use_threads=True):
    p_table = pq.read_table(
        source=path,
        use_threads=use_threads,
    )

    col_indexes = [n for n in p_table.column_names]

    def get_records():
        for rb in p_table.to_batches():
            col_arrays = [rb.column(i) for i, _ in enumerate(col_indexes)]
            col_arrays = [a.to_numpy(zero_copy_only=False) for a in col_arrays]
            for row in zip(*col_arrays):
                # np.array(a) makes `a` writable for future usage
                rec = {n: np.array(a) if isinstance(a, np.ndarray) else a for n, a in zip(col_indexes, row)}
                yield rec

    return get_records()


def prepare_embeddings(seq, conf, is_train):
    min_seq_len = 1
    embeddings = list(conf.features.embeddings.keys())

    feature_keys = embeddings + list(conf.features.numeric_values.keys())

    for rec in seq:
        seq_len = len(rec['event_time'])
        if is_train and seq_len < min_seq_len:
            continue

        if 'feature_arrays' in rec:
            feature_arrays = rec['feature_arrays']
            feature_arrays = {k: v for k, v in feature_arrays.items() if k in feature_keys}
        else:
            feature_arrays = {k: v for k, v in rec.items() if k in feature_keys}

        # TODO: datetime processing. Take date-time features

        # shift embeddings to 1, 0 is padding value
        feature_arrays = {k: v + (1 if k in embeddings else 0) for k, v in feature_arrays.items()}

        # clip embeddings dictionary by max value
        for e_name, e_params in conf.features.embeddings.items():
            feature_arrays[e_name] = feature_arrays[e_name].clip(0, e_params['in'] - 1)

        feature_arrays['event_time'] = rec['event_time']

        rec['feature_arrays'] = feature_arrays
        yield rec

def shuffle_client_list_reproducible(conf, data):
    if conf.client_list_shuffle_seed != 0:
        dataset_col_id = conf.get('col_id', 'client_id')
        data = sorted(data, key=lambda x: x.get(dataset_col_id)) #changed from COLES a bit
        random.Random(conf.client_list_shuffle_seed).shuffle(data)
    return data

In [5]:
def prepare_data(conf):
    data = read_pyarrow_file(conf.train_path)
    data = tqdm(data)
    
    data = prepare_embeddings(data, conf, is_train=True)
    data = shuffle_client_list_reproducible(conf, data)
    data = list(data)


    valid_ix = np.arange(len(data))
    valid_ix = np.random.choice(valid_ix, size=int(len(data) * conf.valid_size), replace=False)
    valid_ix = set(valid_ix.tolist())

   # logger.info(f'Loaded {len(data)} rows. Split in progress...')
    train_data = [rec for i, rec in enumerate(data) if i not in valid_ix]
    valid_data = [rec for i, rec in enumerate(data) if i in valid_ix]

    #logger.info(f'Train data len: {len(train_data)}, Valid data len: {len(valid_data)}')

    return train_data, valid_data

In [6]:
train, valid = prepare_data(conf)

4501it [00:00, 22805.94it/s]

9717it [00:00, 18333.81it/s]


In [10]:
train_dataset = SplittingDataset(
        train,
        split_strategy.create(**conf.train.split_strategy)
    )

train_dataset = TargetEnumeratorDataset(train_dataset)
train_dataset = ConvertingTrxDataset(train_dataset)

In [11]:
train_dataset[5][0]

({'amount': tensor([ 6.9383,  6.1944,  6.6284,  9.2104, 10.3090,  7.9728, 10.8198, 10.5967,
           8.1306,  6.0753,  4.4543,  7.0598,  6.9856,  6.3168,  4.7005,  5.4647,
           6.0831,  7.4430,  6.0124,  7.2951,  7.9494,  5.2193,  6.4693,  6.3335,
           8.2496,  9.6097,  5.8230,  6.6431,  6.8967,  4.4188,  6.5761, 10.8198,
           5.9254,  5.9819,  6.5596,  7.4568,  4.1636,  8.5073,  8.8847,  7.3172,
           6.0986,  9.8828,  9.1009,  9.4351,  6.6265,  7.4989,  8.8538, 10.3090,
          10.3090,  7.8335,  6.1506,  7.9452,  6.2425,  6.5132, 10.3981,  6.6053,
           6.1417], dtype=torch.float64),
  'mcc': tensor([ 5,  2, 19,  3,  3, 59,  3,  3,  2, 73,  2,  2,  5,  4,  2,  4,  2,  2,
           2, 39,  7,  2,  2,  2, 99, 52,  2,  2,  9,  6,  2,  3,  2,  9,  8,  8,
           8, 26, 21,  8,  2, 58, 58, 69,  7,  7,  3,  3,  3,  7, 57,  7,  2,  4,
          75, 57,  2], dtype=torch.int32),
  'channel_type': tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2

In [12]:
import functools
import operator
from collections import defaultdict

import math
from collections import OrderedDict
from typing import Dict

import torch
import torch.nn as nn


batch = [train_dataset[0], train_dataset[1]]

batch1 = functools.reduce(operator.iadd, batch)

In [13]:
new_x_ = defaultdict(list)
for x, _ in batch1:
    for k, v in x.items():
        new_x_[k].append(v)

In [14]:
lengths = torch.IntTensor([len(e) for e in next(iter(new_x_.values()))])

In [15]:
lengths

tensor([ 49,  22, 117, 118,  96,  62,  40,  47,  55,  62], dtype=torch.int32)

In [16]:
new_x = {k: torch.nn.utils.rnn.pad_sequence(v, batch_first=True) for k, v in new_x_.items()}

In [17]:
new_y = torch.tensor([y for _, y in batch1])

In [18]:
class PaddedBatch:
    def __init__(self, payload: Dict[str, torch.Tensor], length: torch.LongTensor):
        self._payload = payload
        self._length = length

    @property
    def payload(self):
        return self._payload

    @property
    def seq_lens(self):
        return self._length

    def __len__(self):
        return len(self._length)

    def to(self, device, non_blocking=False):
        length = self._length.to(device=device, non_blocking=non_blocking)
        payload = {
            k: v.to(device=device, non_blocking=non_blocking) for k, v in self._payload.items()
        }
        return PaddedBatch(payload, length)

In [23]:
from torch.utils.data import DataLoader

def padded_collate(batch):
    new_x_ = defaultdict(list)
    for x, _ in batch:
        for k, v in x.items():
            new_x_[k].append(v)

    lengths = torch.IntTensor([len(e) for e in next(iter(new_x_.values()))])

    new_x = {k: torch.nn.utils.rnn.pad_sequence(v, batch_first=True) for k, v in new_x_.items()}
    new_y = torch.tensor([y for _, y in batch])

    return PaddedBatch(new_x, lengths), new_y

def collate_splitted_rows(batch):
    # flattens samples in list of lists to samples in list
    batch = functools.reduce(operator.iadd, batch)
    return padded_collate(batch)



def create_data_loaders(conf):
    train_data, valid_data = prepare_data(conf)

    train_dataset = SplittingDataset(
        train_data,
        split_strategy.create(**conf.train.split_strategy)
    )
    train_dataset = TargetEnumeratorDataset(train_dataset)
    train_dataset = ConvertingTrxDataset(train_dataset)
    # не уверен что нам нужна история с дропаутом точек.
    # Но это выглядит неплохой аугментацией в целом
    train_dataset = DropoutTrxDataset(train_dataset, trx_dropout=conf.train.dropout,
                                       seq_len=conf.train.max_seq_len)
    
    train_loader = DataLoader(
        dataset=train_dataset,
        shuffle=True,
        collate_fn=collate_splitted_rows,
        num_workers=1,
        batch_size=3,
    )

    # valid_dataset = SplittingDataset(
    #     valid_data,
    #     split_strategy.create(**conf['params.valid.split_strategy'])
    # )
    # valid_dataset = TargetEnumeratorDataset(valid_dataset)
    # valid_dataset = ConvertingTrxDataset(valid_dataset)
    # valid_dataset = DropoutTrxDataset(valid_dataset, trx_dropout=0.0,
    #                                   seq_len=conf['params.valid.max_seq_len'])
    # valid_loader = DataLoader(
    #     dataset=valid_dataset,
    #     shuffle=False,
    #     collate_fn=collate_splitted_rows,
    #     num_workers=conf['params.valid'].get('num_workers', 0),
    #     batch_size=conf['params.valid.batch_size'],
    # )

    return train_loader#, valid_loader

In [24]:
train_loader = create_data_loaders(conf)

0it [00:00, ?it/s]

9717it [00:00, 15863.74it/s]


In [28]:
for batch in train_loader:
    break

In [29]:
batch

(<__main__.PaddedBatch at 0x7f7741f4c100>,
 tensor([7824, 7824, 7824, 7824, 7824, 7662, 7662, 7662, 7662, 7662, 1959, 1959,
         1959, 1959, 1959]))