In [1]:
import random
import warnings

import pandas as pd
import numpy as np
import polars as pl

from typing import List, Union

random.seed(2023)
np.random.seed(2023)

warnings.filterwarnings('ignore')

In [2]:
train_df = pd.read_parquet('./data/train.parquet.gzip').rename_axis('timestamp').reset_index()
test_df = pd.read_parquet('./data/test.parquet.gzip')

### Pandas train-val split and chrono split

In [3]:
def chrono_split(
        df: pd.DataFrame, 
        split_by_column: str = 'user_id', 
        ratio: float = 0.7, 
        col_timestamp: str = 'timestamp') -> List[pd.DataFrame]:

    df = df.sort_values([split_by_column, col_timestamp])
    groups = df.groupby(split_by_column)

    df["count"] = groups[split_by_column].transform("count")
    df["rank_s"] = groups.cumcount() + 1

    ratio = [ratio, 1 - ratio]
    splits = []
    prev_threshold = None
    for threshold in np.cumsum(ratio):
        condition = df["rank_s"] <= round(threshold * df["count"])
        if prev_threshold is not None:
            condition &= df["rank_s"] > round(prev_threshold * df["count"])
        splits.append(df[condition].drop(["rank_s", "count"], axis=1))
        prev_threshold = threshold

    return splits

def train_val_split(
        train_df: pd.DataFrame, 
        val_users_n: int = 200_000) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    
    user_ids = train_df['user_id'].unique()
    user_ids_val = random.sample(list(user_ids), val_users_n)
    condition = train_df['user_id'].isin(user_ids_val)

    val = train_df[condition]
    val_no_targets, val_targets = chrono_split(val, ratio=0.7)

    train = pd.concat([train_df[~condition], val_no_targets]).sort_values('timestamp')
    return train, val_no_targets, val_targets

train_split, val_no_targets, val_targets = train_val_split(
    train_df, # [train_df['user_id'].isin(test_df['user_id'].values)],
    val_users_n = 200_000
    )

print(f'Train All: {train_split.shape[0]:_}')
print(f"Train NZ: {train_split[train_split['timespent'] != 0].shape[0]:_}")

print(f"Validation Targets NZ: {val_targets[val_targets['timespent'] != 0].shape[0]:_}")
print(f'Validation Targets All: {val_targets.shape[0]:_}')

print(f"Validation no Targets NZ: {val_no_targets[val_no_targets['timespent'] != 0].shape[0]:_}")
print(f'Validation no Targets All: {val_no_targets.shape[0]:_}')

Train All: 135_783_164
Train NZ: 22_149_314
Validation Targets NZ: 1_430_156
Validation Targets All: 8_656_851
Validation no Targets NZ: 3_281_276
Validation no Targets All: 20_188_972


### Polars train-val split and chrono split

In [4]:
train_df_pl = pl.from_pandas(train_df)
test_df_pl = pl.from_pandas(test_df)

In [5]:
def train_val_split(
        train_df: pl.DataFrame, 
        val_users_n: int = 200_000) -> List[pd.DataFrame]:
    
    user_ids = train_df['user_id'].unique()
    user_ids_val = random.sample(list(user_ids), val_users_n)
    condition = pl.col('user_id').is_in(user_ids_val)

    return train_df.filter(~condition), train_df.filter(condition)

train_pt, val = train_val_split(
    train_df_pl, # .filter(~pl.col('user_id').is_in(test_df_pl['user_id'])),
    val_users_n = 200_000
    )

print(f'Train All: {train_pt.shape[0]:_}')
print(f"Train NZ: {train_pt.filter(pl.col('timespent') != 0).shape[0]:_}")

print(f'Validation All: {val.shape[0]:_}')
print(f"Validation NZ: {val.filter(pl.col('timespent') != 0).shape[0]:_}")

Train All: 115_480_428
Train NZ: 18_859_896
Validation All: 28_959_587
Validation NZ: 4_719_574


In [6]:
from typing import List, Union

def chrono_split(
        df: Union[pl.DataFrame, pd.DataFrame], 
        split_by_column: str = 'user_id', 
        ratio: float = 0.7, 
        col_timestamp: str = 'timestamp') -> List[pd.DataFrame]:
    
    if isinstance(df, pd.DataFrame):
        df = pl.from_pandas(df)

    df = df.sort([split_by_column, col_timestamp])
    df = df.with_columns([
        pl.col('user_id').cumcount().over(['user_id']).alias('rank_s'),
        pl.col('user_id').count().over(['user_id']).alias('count')
        ])
    
    ratio = [ratio, 1 - ratio]
    splits = []
    prev_threshold = None
    for threshold in np.cumsum(ratio):
        condition = df["rank_s"] <= (threshold * df["count"]) #.round(0)
        if prev_threshold is not None:
            condition &= df["rank_s"] > (prev_threshold * df["count"]) #.round(0)

        splits.append(
            df.filter(condition)
              .drop(["rank_s", "count"])
              .to_pandas()
              .set_index('timestamp')
              )
        
        prev_threshold = threshold

    return splits    

val_no_targets, val_targets = chrono_split(val, ratio=0.7)

print(f"Validation Targets NZ: {val_targets[val_targets['timespent'] != 0].shape[0]:_}")
print(f'Validation Targets All: {val_targets.shape[0]:_}')

print(f"Validation no Targets NZ: {val_no_targets[val_no_targets['timespent'] != 0].shape[0]:_}")
print(f'Validation no Targets All: {val_no_targets.shape[0]:_}')

Validation Targets NZ: 1_410_115
Validation Targets All: 8_580_193
Validation no Targets NZ: 3_309_459
Validation no Targets All: 20_379_394


In [7]:
train_split = pl.concat([train_pt, pl.from_pandas(val_no_targets.reset_index())]).sort(['timestamp']).to_pandas()

print(f'Train All: {train_split.shape[0]:_}')
print(f"Train NZ: {train_split[train_split['timespent'] != 0].shape[0]:_}")

Train All: 135_859_822
Train NZ: 22_169_355


In [8]:
train_split.set_index('timestamp').to_parquet('./data/splits/train.parquet.gzip', compression='gzip')
val_no_targets.to_parquet('./data/splits/val_no_targets.parquet.gzip', compression='gzip')
val_targets.to_parquet('./data/splits/val_targets.parquet.gzip', compression='gzip')