In [11]:
import pandas as pd
import numpy as np
from tqdm import tqdm
import pyarrow.parquet as pq
import sys
import random

sys.path.append('../../../')
from configs.data_configs.rosbank import data_configs

import split_strategy
from splitting_dataset import SplittingDataset

In [2]:
# Build the dataset configuration object from the Rosbank config module imported above.
conf = data_configs()

In [3]:
# Load the training parquet with pandas and preview the first rows
# (the other cells shown here read the same path via read_pyarrow_file instead of using df).
df = pd.read_parquet(conf.train_path)
df.head()

Unnamed: 0,cl_id,amount,event_time,mcc,channel_type,currency,trx_category,trx_count,target_target_flag,target_target_sum
0,10018,"[10.609081944147828, 10.596659732783579, 10.81...","[17120.38773148148, 17133.667800925927, 17134....","[13, 2, 13, 2, 1, 18, 13, 2, 13, 2, 5, 13, 9, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]","[5, 3, 5, 3, 1, 1, 5, 3, 5, 3, 1, 5, 5, 5, 5]",15,0,0.0
1,10030,"[4.61512051684126, 6.90875477931522, 10.598857...","[17141.0, 17141.0, 17145.0, 17147.0, 17147.0, ...","[9, 9, 21, 1, 25, 6, 14, 14, 3, 3, 3, 13, 1, 3...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[5, 5, 1, 1, 1, 1, 1, 1, 1, 1, 1, 7, 1, 1, 3, ...",42,1,59.51
2,10038,"[7.4127640174265625, 7.370230641807081, 7.8180...","[17301.0, 17301.0, 17301.0, 17301.774780092594...","[1, 1, 1, 2, 2, 4, 2, 8, 1, 22, 8, 1, 8, 4, 2,...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[1, 1, 1, 2, 2, 1, 3, 1, 1, 1, 1, 1, 1, 1, 2, ...",111,0,0.0
3,10057,"[7.494708263135679, 7.736394428979239, 10.7789...","[17151.0, 17151.0, 17153.0, 17154.0, 17155.0, ...","[6, 21, 2, 6, 2, 4, 2, 22, 15, 2, 1, 35, 4, 2,...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[1, 1, 4, 1, 4, 1, 3, 1, 1, 3, 1, 1, 1, 4, 1, ...",61,1,62961.31
4,10062,"[8.31898612539206, 8.824824939175638, 6.509067...","[17143.0, 17143.0, 17143.0, 17144.0, 17144.0, ...","[80, 15, 37, 38, 11, 11, 2, 24, 7, 5, 5, 11, 1...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[1, 1, 1, 1, 1, 1, 3, 1, 1, 1, 1, 1, 1, 1, 1, ...",82,1,107126.35


In [4]:
def read_pyarrow_file(path, use_threads=True):
    """Read a parquet file and return a generator of per-row record dicts.

    The table is read immediately when this function is called; only the
    conversion of rows into dicts is lazy.

    Parameters
    ----------
    path : passed to ``pq.read_table`` as ``source``.
    use_threads : bool, forwarded to ``pq.read_table``.

    Returns
    -------
    generator of dict
        One ``{column_name: value}`` dict per row. Values that come out of the
        NumPy conversion as ``np.ndarray`` are copied with ``np.array``.
    """
    table = pq.read_table(source=path, use_threads=use_threads)
    names = list(table.column_names)

    def get_records():
        for batch in table.to_batches():
            # Fetch every column of the batch by position, then convert each to NumPy.
            raw_columns = [batch.column(pos) for pos in range(len(names))]
            np_columns = [col.to_numpy(zero_copy_only=False) for col in raw_columns]
            for values in zip(*np_columns):
                record = {}
                for name, value in zip(names, values):
                    # np.array(value) makes the array writable for future usage
                    if isinstance(value, np.ndarray):
                        record[name] = np.array(value)
                    else:
                        record[name] = value
                yield record

    return get_records()


def prepare_embeddings(seq, conf, is_train):
    """Lazily attach a ``feature_arrays`` dict to each record of ``seq``.

    For every record the configured features (keys of ``conf.features.embeddings``
    and ``conf.features.numeric_values``) are selected, embedding features are
    shifted by +1 (0 is the padding value) and clipped to ``[0, params['in'] - 1]``,
    and the record's ``event_time`` is added alongside them.

    Parameters
    ----------
    seq : iterable of dict records; each must contain ``'event_time'``.
    conf : config object exposing ``features.embeddings`` and
        ``features.numeric_values`` mappings.
    is_train : when truthy, records with fewer than one event are skipped.

    Yields
    ------
    dict
        The same record object, with ``rec['feature_arrays']`` (re)assigned.
    """
    min_seq_len = 1
    embedding_specs = conf.features.embeddings
    embeddings = list(embedding_specs.keys())
    feature_keys = embeddings + list(conf.features.numeric_values.keys())

    for rec in seq:
        seq_len = len(rec['event_time'])
        if is_train and seq_len < min_seq_len:
            continue

        # Features may already be nested under 'feature_arrays'; otherwise take them
        # straight from the record.
        source = rec['feature_arrays'] if 'feature_arrays' in rec else rec

        # TODO: datetime processing. Take date-time features

        feature_arrays = {}
        for name, values in source.items():
            if name not in feature_keys:
                continue
            # shift embeddings to 1, 0 is padding value
            offset = 1 if name in embeddings else 0
            feature_arrays[name] = values + offset

        # clip embeddings dictionary by max value
        for e_name, e_params in embedding_specs.items():
            feature_arrays[e_name] = feature_arrays[e_name].clip(0, e_params['in'] - 1)

        feature_arrays['event_time'] = rec['event_time']
        rec['feature_arrays'] = feature_arrays
        yield rec

def shuffle_client_list_reproducible(conf, data):
    """Return ``data`` in a seed-determined order.

    When ``conf.client_list_shuffle_seed`` is 0 the input is returned untouched.
    Otherwise the records are first sorted by their id field (``conf.get('col_id')``,
    defaulting to ``'client_id'``) so the starting order is canonical, then shuffled
    with a private ``random.Random`` seeded from the config.

    Parameters
    ----------
    conf : config object with a ``client_list_shuffle_seed`` attribute and a
        dict-style ``get`` method.
    data : iterable of dict-like records.

    Returns
    -------
    The original ``data`` object (seed 0) or a new shuffled list.
    """
    seed = conf.client_list_shuffle_seed
    if seed == 0:
        return data

    id_field = conf.get('col_id', 'client_id')
    ordered = list(data)
    # changed from COLES a bit: sort on the configured id column before shuffling
    ordered.sort(key=lambda rec: rec.get(id_field))
    random.Random(seed).shuffle(ordered)
    return ordered

In [5]:
# Assemble the pipeline by hand to inspect a single record:
# parquet records -> tqdm progress wrapper -> feature preparation.
# prepare_embeddings is a generator function, so no record is pulled through
# tqdm until `data` is iterated (hence the "0it" progress output).
data = read_pyarrow_file(conf.train_path)
data = tqdm(data)

data = prepare_embeddings(data, conf, is_train=True)

0it [00:00, ?it/s]

In [6]:
# Pull the first prepared record out of the generator for inspection.
a = next(data)

In [7]:
# Show the feature arrays attached by prepare_embeddings
# (embedding features are shifted by +1 there; 0 is the padding value).
a['feature_arrays']

{'amount': array([10.60908194, 10.59665973, 10.81979828, 10.81979828,  4.62291187,
         5.70378247,  9.90353755,  9.90353755, 10.30898599, 10.30898599,
         4.56434819,  9.90353755,  9.61587214, 10.88183247,  8.6996814 ]),
 'mcc': array([14,  3, 14,  3,  2, 19, 14,  3, 14,  3,  6, 14, 10, 14, 14],
       dtype=int32),
 'channel_type': array([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2], dtype=int32),
 'currency': array([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2], dtype=int32),
 'trx_category': array([6, 4, 6, 4, 2, 2, 6, 4, 6, 4, 2, 6, 6, 6, 6], dtype=int32),
 'event_time': array([17120.38773148, 17133.66780093, 17134.58347222, 17134.60094907,
        17138.        , 17138.        , 17142.73583333, 17148.50952546,
        17150.55356481, 17150.55957176, 17151.        , 17200.29989583,
        17217.62599537, 17224.56177083, 17227.12228009])}

In [8]:
def prepare_data(conf, seed=None):
    """Load the training parquet, prepare features and split into train/validation.

    Pipeline: ``read_pyarrow_file`` -> ``tqdm`` -> ``prepare_embeddings`` ->
    ``shuffle_client_list_reproducible`` -> random hold-out of
    ``int(len(data) * conf.valid_size)`` records for validation.

    Parameters
    ----------
    conf : config object providing ``train_path`` and ``valid_size`` plus whatever
        ``prepare_embeddings`` and ``shuffle_client_list_reproducible`` read.
    seed : int, optional
        Seed for the validation-index draw. With the default ``None`` the draw
        uses NumPy's global random state (the historical behaviour), so the
        split differs between runs unless the caller seeds ``np.random``.
        Pass an integer to get the same split on every run.

    Returns
    -------
    tuple (train_data, valid_data)
        Two lists of records; relative record order is kept within each list.
    """
    data = read_pyarrow_file(conf.train_path)
    data = tqdm(data)

    data = prepare_embeddings(data, conf, is_train=True)
    data = shuffle_client_list_reproducible(conf, data)
    data = list(data)

    # A private RandomState keeps a seeded split independent of the global RNG.
    rng = np.random if seed is None else np.random.RandomState(seed)
    n_valid = int(len(data) * conf.valid_size)
    valid_ix = rng.choice(np.arange(len(data)), size=n_valid, replace=False)
    valid_ix = set(valid_ix.tolist())

    # Partition in a single pass; set membership keeps it O(n).
    train_data, valid_data = [], []
    for i, rec in enumerate(data):
        if i in valid_ix:
            valid_data.append(rec)
        else:
            train_data.append(rec)

    return train_data, valid_data

In [9]:
# Run the full load -> prepare -> shuffle -> split pipeline; both results are lists of records.
train, valid = prepare_data(conf)



9717it [00:00, 16477.76it/s]


In [14]:
# Wrap the training records in SplittingDataset, using the split strategy
# built from the keyword arguments stored in conf.split_strategy.
train_dataset = SplittingDataset(
        train,
        split_strategy.create(**conf.split_strategy)
    )

In [22]:
# Compare the 'amount' lengths of entries 0 and 2 of the first dataset item
# (presumably two sub-sequences produced by the split strategy — verify against SplittingDataset).
len(train_dataset[0][0]['amount']), len(train_dataset[0][2]['amount'])

(34, 24)

In [9]:
def create_data_loaders(conf):
    """Build the training and validation data loaders described by ``conf``.

    Both splits go through the same wrapper chain:
    ``SplittingDataset`` -> ``TargetEnumeratorDataset`` -> ``ConvertingTrxDataset``
    -> ``DropoutTrxDataset``. The training chain optionally adds
    ``AllTimeShuffleMLDataset`` and is loaded with shuffling; validation uses
    ``trx_dropout=0.0`` and no shuffling.

    NOTE(review): ``TargetEnumeratorDataset``, ``ConvertingTrxDataset``,
    ``DropoutTrxDataset``, ``AllTimeShuffleMLDataset``, ``DataLoader``,
    ``collate_splitted_rows`` and ``logger`` are not imported in the cells shown
    here, and this function indexes ``conf`` with dotted string keys
    (``conf['params.train...']``) while the other cells use attribute access
    (``conf.split_strategy``) — confirm both before calling it.

    Parameters
    ----------
    conf : config object supporting ``conf['dotted.key']`` lookups and whatever
        ``prepare_data`` reads.

    Returns
    -------
    tuple (train_loader, valid_loader)
    """
    train_data, valid_data = prepare_data(conf)

    # ---- training pipeline ----
    train_ds = SplittingDataset(
        train_data,
        split_strategy.create(**conf['params.train.split_strategy'])
    )
    train_ds = TargetEnumeratorDataset(train_ds)
    train_ds = ConvertingTrxDataset(train_ds)
    train_ds = DropoutTrxDataset(
        train_ds,
        trx_dropout=conf['params.train.trx_dropout'],
        seq_len=conf['params.train.max_seq_len'],
    )

    use_all_time_shuffle = conf['params.train'].get('all_time_shuffle', False)
    if use_all_time_shuffle:
        train_ds = AllTimeShuffleMLDataset(train_ds)
        logger.info('AllTimeShuffle used')

    train_loader = DataLoader(
        dataset=train_ds,
        shuffle=True,
        collate_fn=collate_splitted_rows,
        num_workers=conf['params.train'].get('num_workers', 0),
        batch_size=conf['params.train.batch_size'],
    )

    # ---- validation pipeline (no dropout, no shuffling) ----
    valid_ds = SplittingDataset(
        valid_data,
        split_strategy.create(**conf['params.valid.split_strategy'])
    )
    valid_ds = TargetEnumeratorDataset(valid_ds)
    valid_ds = ConvertingTrxDataset(valid_ds)
    valid_ds = DropoutTrxDataset(
        valid_ds,
        trx_dropout=0.0,
        seq_len=conf['params.valid.max_seq_len'],
    )

    valid_loader = DataLoader(
        dataset=valid_ds,
        shuffle=False,
        collate_fn=collate_splitted_rows,
        num_workers=conf['params.valid'].get('num_workers', 0),
        batch_size=conf['params.valid.batch_size'],
    )

    return train_loader, valid_loader

{'cl_id': '1964',
 'amount': array([ 5.10594547,  6.28226675,  5.83188248,  6.09356977,  5.50125821,
         6.60123012,  6.00635316,  8.01697775,  6.47234629,  7.54380287,
         6.71538339,  5.14166356,  6.71417053,  7.43838353,  4.18965474,
         6.90875478,  5.64897424,  5.79909265,  4.61512052,  5.91620206,
         6.57646957,  5.83188248,  6.51471269,  7.37900813,  5.83188248,
         7.00488199,  6.38435087,  7.00397414,  8.13153071,  5.48844174,
         6.37927472,  7.31388683,  6.59850903,  6.20455776,  5.14166356,
         5.09986643,  5.13579844,  6.88448665,  5.64897424,  5.59842196,
         8.85380827,  7.20063383,  6.71538339, 10.1266711 ,  5.83188248,
         4.4543473 ,  6.09582456,  5.32787617,  5.44673737,  5.01727984,
         3.4339872 ,  5.01727984,  5.99146455,  6.2166061 ,  6.90875478,
        12.20607765,  7.60140233,  6.4707995 ,  5.75257264,  6.85646198,
         7.60090246,  6.07534603,  6.90875478,  7.54791299,  7.31388683,
         6.90875478,  6