In [1]:
# 0) Setup
# ─────────────────────────────────────────────────────────────────────────────
%run setup.py  # adds project root and src/ to sys.path, exposes log_event, LogKind

import time
import pandas as pd
from sklearn.model_selection import GroupShuffleSplit
from pathlib import Path
from src.logs import log_event, LogKind  # noqa: F401

# 1) Timing & load
# ─────────────────────────────────────────────────────────────────────────────
# Location of the original QQP question-pair dump, relative to this notebook.
raw_csv_path = "../data/quora.csv"

# The clock starts just before the read so the reported duration covers
# loading as well as every later step of this cell.
start_time = time.time()
df = pd.read_csv(raw_csv_path)

# 2) Group‐aware splitting key
# ─────────────────────────────────────────────────────────────────────────────
# Ensure no question ID leaks across splits.
#
# Using min(qid1, qid2) as the key is not sufficient: the pairs (3, 5) and
# (5, 7) would receive keys 3 and 5, so question 5 could end up in two
# different splits. Instead, every pair is keyed by the smallest question ID
# in its connected component of the question graph (questions are nodes,
# pairs are edges). Pairs that share a question, directly or through a chain
# of other pairs, therefore always share one key. For a pair whose questions
# appear nowhere else, the key is still min(qid1, qid2).
def _component_keys(left_ids, right_ids):
    """Return one grouping key per pair: the smallest ID in its component.

    Args:
        left_ids: Sequence of first-question IDs, one per pair.
        right_ids: Sequence of second-question IDs, aligned with ``left_ids``.

    Returns:
        A list with one entry per pair holding the smallest question ID that
        is reachable from that pair through shared questions.

    Note:
        IDs must be hashable and mutually comparable; missing values (NaN)
        are not supported.
    """
    parent = {}

    def find(node):
        # Locate the representative of node's component.
        root = parent.setdefault(node, node)
        while parent[root] != root:
            root = parent[root]
        # Path compression: point every visited node straight at the root so
        # later lookups stay short.
        while parent[node] != root:
            next_node = parent[node]
            parent[node] = root
            node = next_node
        return root

    pairs = list(zip(left_ids, right_ids))
    for left, right in pairs:
        left_root, right_root = find(left), find(right)
        if left_root == right_root:
            continue
        # Always keep the smaller ID as the representative, so the final root
        # of a component is its minimum question ID.
        if left_root < right_root:
            parent[right_root] = left_root
        else:
            parent[left_root] = right_root

    # Both questions of a pair are in the same component, so looking up the
    # first one is enough.
    return [find(left) for left, _ in pairs]


df["group"] = _component_keys(df["qid1"].tolist(), df["qid2"].tolist())

# 3) 80 / 10 / 10 train/valid/test split
# ─────────────────────────────────────────────────────────────────────────────
# Stage one: hold out 20% of the groups; the remaining 80% becomes train.
gss = GroupShuffleSplit(n_splits=1, test_size=0.20, random_state=13)
first_pass = gss.split(df, df["is_duplicate"], groups=df["group"])
train_idx, temp_idx = next(first_pass)
df_train, df_temp = (
    df.iloc[rows].reset_index(drop=True) for rows in (train_idx, temp_idx)
)

# Stage two: cut the held-out part in half (again by group) to obtain the
# validation and test sets.
gss2 = GroupShuffleSplit(n_splits=1, test_size=0.50, random_state=13)
second_pass = gss2.split(df_temp, df_temp["is_duplicate"], groups=df_temp["group"])
valid_idx, test_idx = next(second_pass)
df_valid, df_test = (
    df_temp.iloc[rows].reset_index(drop=True) for rows in (valid_idx, test_idx)
)

# 4) Save splits
# ─────────────────────────────────────────────────────────────────────────────
out_dir = Path("../data/splits")
out_dir.mkdir(parents=True, exist_ok=True)

# Write each split next to the others, without the pandas index column.
for file_name, split_frame in (
    ("train.csv", df_train),
    ("valid.csv", df_valid),
    ("test.csv", df_test),
):
    split_frame.to_csv(out_dir / file_name, index=False)

# 5) Log split information
# ─────────────────────────────────────────────────────────────────────────────
# Elapsed wall-clock time since the timer was started before loading the data.
split_time = time.time() - start_time

# Row counts per split, passed to the logger as keyword fields.
row_counts = {
    "train_count": len(df_train),
    "valid_count": len(df_valid),
    "test_count": len(df_test),
}
log_event(LogKind.SPLIT, **row_counts, duration_s=format(split_time, ".2f"))

saved_files = list(out_dir.iterdir())
print("Saved splits to:", saved_files)
print(f"Split durations logged to metric_logs/splits.txt (took {split_time:.2f}s).")

Saved splits to: [PosixPath('../data/splits/train.csv'), PosixPath('../data/splits/valid.csv'), PosixPath('../data/splits/test.csv')]
Split durations logged to metric_logs/splits.txt (took 1.05s).
