In [1]:
from pathlib import Path
import duckdb
import pandas as pd

# Every path is anchored at the current working directory, which this notebook
# treats as the project root (the folder holding data/, src/, sql/, ...).
project_root = Path().resolve()
data_csv = project_root.joinpath("data", "synthetic_sessions.csv")

data_csv


PosixPath('/Users/GitHub 2025/skip-song-prediction/data/synthetic_sessions.csv')

In [2]:
# Fresh in-memory DuckDB database; nothing persists beyond this kernel.
con = duckdb.connect(database=":memory:")

# Split configuration. Each row is hashed into one of N_BUCKETS buckets:
# buckets [0, TRAIN_BUCKETS) -> train, [TRAIN_BUCKETS, TRAIN_BUCKETS + VAL_BUCKETS)
# -> val, the rest -> test. 7 / 2 / 1 out of 10 gives ~70% / ~20% / ~10%.
N_BUCKETS = 10
TRAIN_BUCKETS = 7
VAL_BUCKETS = 2

# Columns fed to hash(). Because `position` is included, tracks from the same
# session can land in different splits.
# NOTE(review): for session-level evaluation this likely leaks information
# across splits; consider ("user_id", "session_id") to keep sessions intact.
SPLIT_KEYS = ("user_id", "session_id", "position")

# 1) Load the CSV into a DuckDB table.
#    The path is spliced into a SQL string literal, so embedded single quotes
#    must be doubled; otherwise a path such as ".../O'Brien/..." breaks the query.
csv_sql_literal = data_csv.as_posix().replace("'", "''")
con.execute(f"""
    CREATE OR REPLACE TABLE events AS
    SELECT *
    FROM read_csv_auto('{csv_sql_literal}');
""")

# 2) Add a deterministic split column derived from the hash bucket.
#    NOTE(review): DuckDB's hash() output may not be stable across DuckDB
#    versions — confirm before relying on identical splits after an upgrade.
bucket_expr = f"hash({', '.join(SPLIT_KEYS)}) % {N_BUCKETS}"
con.execute(f"""
    CREATE OR REPLACE TABLE events_split AS
    SELECT
        *,
        CASE
            WHEN {bucket_expr} < {TRAIN_BUCKETS} THEN 'train'
            WHEN {bucket_expr} < {TRAIN_BUCKETS + VAL_BUCKETS} THEN 'val'
            ELSE 'test'
        END AS split
    FROM events;
""")

# Quick sanity check: how many rows in each split?
con.execute("""
    SELECT split, COUNT(*) AS rows
    FROM events_split
    GROUP BY split
    ORDER BY split;
""").df()


Unnamed: 0,split,rows
0,test,43532
1,train,307918
2,val,87384


In [4]:
out_dir = project_root / "data"
out_dir.mkdir(parents=True, exist_ok=True)

# Export each split to <out_dir>/<split>.parquet via pandas. One loop replaces
# three copy-pasted query/write pairs, and the split label is bound as a query
# parameter instead of being spliced into the SQL text.
SPLIT_NAMES = ("train", "val", "test")
split_frames = {}
split_paths = {}
for split_name in SPLIT_NAMES:
    split_frame = con.execute(
        "SELECT * FROM events_split WHERE split = ?;", [split_name]
    ).df()
    split_path = out_dir / f"{split_name}.parquet"
    split_frame.to_parquet(split_path)
    split_frames[split_name] = split_frame
    split_paths[split_name] = split_path

# Every row of events_split must be exported exactly once; this catches a
# split label produced upstream but missing from SPLIT_NAMES.
total_rows = con.execute("SELECT COUNT(*) FROM events_split;").fetchone()[0]
exported_rows = sum(len(frame) for frame in split_frames.values())
assert exported_rows == total_rows, f"exported {exported_rows} of {total_rows} rows"

train_df, val_df, test_df = (split_frames[name] for name in SPLIT_NAMES)
train_path, val_path, test_path = (split_paths[name] for name in SPLIT_NAMES)

train_path, val_path, test_path


(PosixPath('/Users/GitHub 2025/skip-song-prediction/data/train.parquet'),
 PosixPath('/Users/GitHub 2025/skip-song-prediction/data/val.parquet'),
 PosixPath('/Users/GitHub 2025/skip-song-prediction/data/test.parquet'))

In [5]:
# Reload the exported splits from the exact paths the export step wrote to,
# rather than CWD-relative "data/..." strings that silently depend on where
# the kernel was started.
train = pd.read_parquet(train_path)
val   = pd.read_parquet(val_path)
test  = pd.read_parquet(test_path)

# Round-trip check: reloaded frames should have the shapes that were written.
assert (train.shape, val.shape, test.shape) == (train_df.shape, val_df.shape, test_df.shape)

train.shape, val.shape, test.shape


((307918, 23), (87384, 23), (43532, 23))