In [1]:
from datetime import datetime

import polars as pl

In [2]:
# Month-end dates from Feb 2022 to Jan 2023. `mon` in the cells below is an index
# into this list; the two dicts translate between the date string and that index.
months = [
    '2022-02-28',
    '2022-03-31',
    '2022-04-30',
    '2022-05-31',
    '2022-06-30',
    '2022-07-31',
    '2022-08-31',
    '2022-09-30',
    '2022-10-31',
    '2022-11-30',
    '2022-12-31',
    '2023-01-31',
]
month2id = {month: idx for idx, month in enumerate(months)}
id2month = dict(enumerate(months))

In [3]:
def make_features(data, mon, rub_currency=11.0):
    """Build one row of aggregated transaction features per client.

    Only events strictly before the first day of the month ``months[mon]``
    are used (e.g. ``mon=9`` -> '2022-11-30' -> cutoff 2022-11-01).

    Parameters
    ----------
    data : pl.DataFrame
        Transactions with at least the columns ``client_id``, ``event_time``,
        ``event_type``, ``event_subtype``, ``currency``, ``amount``,
        ``dst_type11``, ``dst_type12``, ``src_type11``, ``src_type12``,
        ``src_type21``, ``src_type22``, ``src_type31`` and ``src_type32``.
    mon : int
        Index into the notebook-level ``months`` list.
    rub_currency : float, default 11.0
        Currency code whose ``amount`` rows feed the ``amt_*`` features
        (previously hard-coded as 11.0; presumably RUB given the
        ``rub_features`` name -- confirm against the data dictionary).

    Returns
    -------
    pl.DataFrame
        One row per client that has at least one event before the cutoff.
        After the left join, ``fill_null(0)`` is applied, so clients without
        any ``rub_currency`` events get ``amt_*`` values of 0.
    """
    # NOTE(review): pl.count() and .dt.days() are older polars spellings; newer
    # releases deprecate them in favour of pl.len() and .dt.total_days().
    # Check the installed polars version before switching.
    cutoff = datetime.strptime(months[mon], "%Y-%m-%d").replace(day=1)
    # Filter once so both aggregations below are guaranteed to see the same window.
    history = data.filter(pl.col("event_time") < cutoff)

    min_dt = datetime(2021, 1, 1)  # day 0 of the integer `event_day` axis
    features = (
        history
        .with_columns(event_day=(pl.col("event_time") - min_dt).dt.days())
        .group_by("client_id")
        .agg(
            # activity volume and span, in whole days since min_dt
            trx_cnt=pl.count(),
            trx_uniq_days=pl.col("event_day").n_unique(),
            trx_first_day=pl.col("event_day").min(),
            trx_last_day=pl.col("event_day").max(),

            event_type_uniq=pl.col("event_type").n_unique(),
            event_subtype_uniq=pl.col("event_subtype").n_unique(),
            currency_uniq=pl.col("currency").n_unique(),

            dst_type11_uniq=pl.col("dst_type11").n_unique(),
            dst_type12_uniq=pl.col("dst_type12").n_unique(),
            src_type11_uniq=pl.col("src_type11").n_unique(),
            src_type12_uniq=pl.col("src_type12").n_unique(),

            # Taken from the client's first row in input order (no sort by
            # event_time is applied); nulls are encoded as -2.
            # Presumably these are per-client constants -- confirm.
            src_type21=pl.col("src_type21").fill_null(-2).first().cast(pl.Int32),
            src_type22=pl.col("src_type22").fill_null(-2).first().cast(pl.Int32),
            src_type31=pl.col("src_type31").fill_null(-2).first().cast(pl.Int32),
            src_type32=pl.col("src_type32").fill_null(-2).first().cast(pl.Int32),
        )
        # +1 so a history spanning a single day has length 1, which keeps
        # trx_density finite.
        .with_columns(
            trx_len_period=pl.col("trx_last_day") - pl.col("trx_first_day") + 1,
        )
        .with_columns(
            trx_cnt_per_day=pl.col("trx_cnt") / pl.col("trx_uniq_days"),
            trx_density=pl.col("trx_cnt") / pl.col("trx_len_period"),
        )
    )

    rub_features = (
        history
        .filter(pl.col("currency") == rub_currency)
        .group_by("client_id")
        .agg(
            amt_cnt=pl.count(),
            amt_sum=pl.col("amount").sum(),
            amt_min=pl.col("amount").min(),
            amt_max=pl.col("amount").max(),
            amt_mean=pl.col("amount").mean(),
            amt_median=pl.col("amount").median(),
        )
        .with_columns(
            amt_range=pl.col("amt_max") - pl.col("amt_min"),
        )
    )
    # The left join keeps clients with no rub_currency events; their amt_*
    # columns come back null and are then zero-filled.
    features = features.join(rub_features, on="client_id", how="left").fill_null(0)

    return features

In [4]:
# Target month: make_features keeps only events before the 1st of months[mon].
mon = month2id["2022-11-30"]  # == 9

In [5]:
import gc
import os
from collections import Counter
from tqdm.auto import tqdm

In [6]:
# For every client, count how many train part-files contain at least one of their rows.
# Only `client_id` is read, so this pass never loads the full transaction table.
train_dir = "./data/trx_train.parquet/"
counter = Counter()
for file in tqdm(sorted(os.listdir(train_dir))):
    client_ids = pl.read_parquet(os.path.join(train_dir, file), columns=["client_id"])["client_id"]
    counter.update(client_ids.unique())

  0%|          | 0/14 [00:00<?, ?it/s]

In [7]:
# Partition clients by the number of part-files they occur in:
#   one -> all rows live in a single file (can be featurized file by file)
#   two -> rows are spread over several files (must be stitched together first)
ranked = counter.most_common()
one = [client_id for client_id, count in ranked if count == 1]
two = [client_id for client_id, count in ranked if count != 1]

In [8]:
# (#clients found in exactly one part-file, #clients spread across several)
(len(one), len(two))

(855601, 26)

In [9]:
# Featurize file by file to bound memory:
#   - clients confined to one file (`one`) are featurized immediately;
#   - rows of clients split across files (`two`) are buffered in `buf` and
#     featurized together afterwards, so each client yields exactly one row.
# NOTE: `one`/`two` must be the partition built from the train files. The same
# names are rebound later for the test files, so re-running this cell after that
# would filter train rows with the wrong id lists.
# sorted(): os.listdir returns entries in arbitrary order; sorting fixes the order
# in which part-files are processed and their feature chunks concatenated.
train_features = []
buf = []
train_dir = "./data/trx_train.parquet/"
for file in tqdm(sorted(os.listdir(train_dir))):
    train_trx = pl.read_parquet(os.path.join(train_dir, file))
    buf.append(train_trx.filter(pl.col("client_id").is_in(two)))
    train_trx = train_trx.filter(pl.col("client_id").is_in(one))
    train_features.append(make_features(train_trx, mon))
    gc.collect()

  0%|          | 0/14 [00:00<?, ?it/s]

In [10]:
# Stitch the buffered rows of multi-file clients together and featurize them in a
# single call, so each of those clients contributes exactly one feature row.
buf = pl.concat(buf)
train_split_features = make_features(buf, mon)
train_features.append(train_split_features)

In [11]:
train_features = pl.concat(train_features)
# The one/two split exists so that every client is featurized exactly once;
# fail loudly here rather than writing duplicated rows to disk.
assert train_features["client_id"].n_unique() == train_features.height, "duplicate client_id rows in train_features"
train_features

client_id,trx_cnt,trx_uniq_days,trx_first_day,trx_last_day,event_type_uniq,event_subtype_uniq,currency_uniq,dst_type11_uniq,dst_type12_uniq,src_type11_uniq,src_type12_uniq,src_type21,src_type22,src_type31,src_type32,trx_len_period,trx_cnt_per_day,trx_density,amt_cnt,amt_sum,amt_min,amt_max,amt_mean,amt_median,amt_range
str,u32,u32,i64,i64,u32,u32,u32,u32,u32,u32,u32,i32,i32,i32,i32,i64,f64,f64,u32,f32,f32,f32,f32,f32,f32
"""bccefa7c1a05f1…",4,4,459,655,1,1,1,1,1,1,1,32822,64,1574,55,197,1.0,0.020305,4,1.6050e6,126738.046875,722986.1875,401250.8125,377639.46875,596248.125
"""eab52e5e022623…",461,258,365,715,12,14,1,7,11,4,5,34170,38,250,4,351,1.786822,1.31339,461,3.7728256e7,0.01626,1.7513e6,81840.03125,25633.560547,1.7513e6
"""c9aedf5f595ee8…",7,7,606,643,4,4,1,3,3,3,3,25714,70,724,70,38,1.0,0.184211,7,48567.507812,72.615654,16671.464844,6938.215332,797.364014,16598.849609
"""005521bad2ae25…",614,292,365,715,12,12,1,11,15,4,6,8173,42,2202,76,351,2.10274,1.749288,614,1.0708401e7,1.107937,373351.6875,17440.392578,2464.007324,373350.59375
"""46a90aee7e5041…",44,41,367,682,2,2,1,3,3,1,1,35161,70,2388,74,316,1.073171,0.139241,44,942180.1875,55.219833,87457.789062,21413.185547,10029.341797,87402.570312
"""7739f41585bb10…",12,12,386,647,1,1,1,1,1,1,1,13013,41,2084,4,262,1.0,0.045802,12,238303.25,3877.471191,31418.40625,19858.603516,22510.658203,27540.935547
"""d03c5d92f39957…",19,17,381,697,2,2,1,2,3,1,1,41888,21,1109,6,317,1.117647,0.059937,19,7608341.5,1601.574463,1.4527e6,400439.03125,184743.15625,1.4511e6
"""63d78c012efbe2…",718,299,365,712,19,19,1,10,15,5,7,34390,37,1058,4,348,2.401338,2.063218,718,4.93413184e8,0.236961,4.0834684e7,687205.0,20687.628906,4.0834684e7
"""488aac61a66409…",157,111,365,704,11,12,1,9,11,4,6,32722,69,844,4,340,1.414414,0.461765,157,1.8721e6,1.277519,305337.875,11924.316406,2941.044189,305336.59375
"""49c3a17a991d26…",28,27,405,667,2,2,1,2,3,1,1,29010,70,724,70,263,1.037037,0.106464,28,4413521.5,1341.094116,1701070.5,157625.765625,39228.8125,1.6997e6


In [12]:
# Eyeball check for duplicated client rows (sort=True lists the largest counts first).
train_features.get_column("client_id").value_counts(sort=True)

client_id,counts
str,u32
"""bccefa7c1a05f1…",1
"""eab52e5e022623…",1
"""c9aedf5f595ee8…",1
"""005521bad2ae25…",1
"""46a90aee7e5041…",1
"""7739f41585bb10…",1
"""d03c5d92f39957…",1
"""63d78c012efbe2…",1
"""488aac61a66409…",1
"""49c3a17a991d26…",1


In [13]:
# For every client, count how many test part-files contain at least one of their rows.
# Only `client_id` is read, so this pass never loads the full transaction table.
val_dir = "./data/trx_test.parquet/"
counter = Counter()
for file in tqdm(sorted(os.listdir(val_dir))):
    client_ids = pl.read_parquet(os.path.join(val_dir, file), columns=["client_id"])["client_id"]
    counter.update(client_ids.unique())

  0%|          | 0/3 [00:00<?, ?it/s]

In [14]:
# Partition test clients by the number of part-files they occur in
# (this rebinds `one`/`two`, which previously held the train partition):
#   one -> all rows live in a single file
#   two -> rows are spread over several files
ranked = counter.most_common()
one = [client_id for client_id, count in ranked if count == 1]
two = [client_id for client_id, count in ranked if count != 1]

In [15]:
# (#test clients found in exactly one part-file, #test clients spread across several)
(len(one), len(two))

(223319, 2425)

In [16]:
# Featurize file by file to bound memory:
#   - clients confined to one file (`one`) are featurized immediately;
#   - rows of clients split across files (`two`) are buffered in `buf` and
#     featurized together afterwards, so each client yields exactly one row.
# NOTE: `one`/`two` must be the partition built from the test files (the
# train partition used the same names earlier in the notebook).
# sorted(): os.listdir returns entries in arbitrary order; sorting fixes the order
# in which part-files are processed and their feature chunks concatenated.
val_features = []
buf = []
val_dir = "./data/trx_test.parquet/"
for file in tqdm(sorted(os.listdir(val_dir))):
    val_trx = pl.read_parquet(os.path.join(val_dir, file))
    buf.append(val_trx.filter(pl.col("client_id").is_in(two)))
    val_trx = val_trx.filter(pl.col("client_id").is_in(one))
    val_features.append(make_features(val_trx, mon))
    gc.collect()

  0%|          | 0/3 [00:00<?, ?it/s]

In [17]:
# Stitch the buffered rows of multi-file test clients together and featurize them
# in a single call, so each of those clients contributes exactly one feature row.
buf = pl.concat(buf)
val_split_features = make_features(buf, mon)
val_features.append(val_split_features)

In [18]:
val_features = pl.concat(val_features)
# The one/two split exists so that every client is featurized exactly once;
# fail loudly here rather than writing duplicated rows to disk.
assert val_features["client_id"].n_unique() == val_features.height, "duplicate client_id rows in val_features"
val_features

client_id,trx_cnt,trx_uniq_days,trx_first_day,trx_last_day,event_type_uniq,event_subtype_uniq,currency_uniq,dst_type11_uniq,dst_type12_uniq,src_type11_uniq,src_type12_uniq,src_type21,src_type22,src_type31,src_type32,trx_len_period,trx_cnt_per_day,trx_density,amt_cnt,amt_sum,amt_min,amt_max,amt_mean,amt_median,amt_range
str,u32,u32,i64,i64,u32,u32,u32,u32,u32,u32,u32,i32,i32,i32,i32,i64,f64,f64,u32,f32,f32,f32,f32,f32,f32
"""5a36339de5f22c…",369,213,367,668,14,14,1,7,9,4,4,44014,85,1128,88,302,1.732394,1.221854,369,4.716638e7,2.163952,1.7415e6,127822.171875,16571.041016,1.7415e6
"""192b65b348353e…",450,146,410,697,17,17,1,10,15,8,15,40463,31,854,77,288,3.082192,1.5625,450,9.0279928e7,0.004329,3.8832e6,200622.0625,38882.023438,3.8832e6
"""2ca09e1786d76a…",3,3,400,606,1,1,1,1,1,1,1,24730,48,869,16,207,1.0,0.014493,3,149445.375,18267.900391,67362.492188,49815.125,63814.992188,49094.59375
"""eed59303cf657d…",6,6,373,648,1,1,1,1,1,1,1,19227,70,1574,55,276,1.0,0.021739,6,340687.15625,22317.25,110391.226562,56781.191406,60641.742188,88073.976562
"""cee09e6cbf852f…",23,22,371,689,2,2,1,2,2,1,1,7575,25,445,81,319,1.045455,0.0721,23,604614.625,86.921127,97494.984375,26287.591797,17929.710938,97408.0625
"""146d0244bb3533…",251,177,367,710,8,9,1,4,7,1,1,39186,52,1032,25,344,1.418079,0.729651,251,1.8137952e7,7.224364,875972.25,72262.757812,21764.730469,875965.0
"""fc0bcbb3b7429c…",15,15,391,684,1,1,1,1,1,1,1,13013,41,242,26,294,1.0,0.05102,15,241969.046875,832.436523,49384.816406,16131.269531,10348.915039,48552.378906
"""f63fa99c0e38b3…",7,7,388,680,3,3,1,1,1,1,1,20082,21,439,81,293,1.0,0.023891,7,1.0873e6,0.022489,742949.3125,155333.390625,45745.90625,742949.3125
"""c018a11fd3cd51…",19,17,371,522,3,3,1,4,4,3,3,22576,70,515,21,152,1.117647,0.125,19,16274.114258,49.619747,1621.671753,856.532349,917.66394,1572.052002
"""ee2991c4a8a81a…",857,282,365,668,8,7,1,5,8,1,2,33750,37,2455,61,304,3.039007,2.819079,857,9.14513024e8,5.419353,4.3650476e7,1.0671e6,76076.789062,4.3650472e7


In [19]:
# Eyeball check for duplicated client rows (sort=True lists the largest counts first).
val_features.get_column("client_id").value_counts(sort=True)

client_id,counts
str,u32
"""5a36339de5f22c…",1
"""192b65b348353e…",1
"""2ca09e1786d76a…",1
"""eed59303cf657d…",1
"""cee09e6cbf852f…",1
"""146d0244bb3533…",1
"""fc0bcbb3b7429c…",1
"""f63fa99c0e38b3…",1
"""c018a11fd3cd51…",1
"""ee2991c4a8a81a…",1


In [20]:
# Ensure the output directory exists instead of relying on the writer to create
# missing parent directories.
os.makedirs("./features", exist_ok=True)
train_features.write_parquet(f"./features/train_trx_features_{mon}_cat.pq")
val_features.write_parquet(f"./features/val_trx_features_{mon}_cat.pq")