
# Multi-label Stratified Split (80/10/10)

**What this notebook does:**
- Reads `labels/dataset.csv` (one row per image: a `path` column plus one 0/1 column per label).
- Uses **iterative stratification** to keep each label's frequency roughly proportional across the train, validation, and test splits (80/10/10).
- Saves the splits to `labels/splits/train.csv`, `labels/splits/val.csv`, and `labels/splits/test.csv`.

> Run this notebook from the project root, i.e. the directory that contains `labels/dataset.csv`.
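
The split code below relies on the layout described above: a `path` column and 0/1 label columns. As a minimal sketch of that assumption, the hypothetical helper `check_label_frame` (not part of the notebook) validates a DataFrame before splitting. The toy file names and the duplicate-path check are illustrative additions.

```python
import pandas as pd


def check_label_frame(df, path_col="path"):
    """Validate the assumed layout and return the label column names."""
    if path_col not in df.columns:
        raise ValueError(f"missing '{path_col}' column")
    label_cols = [c for c in df.columns if c != path_col]
    if not label_cols:
        raise ValueError("no label columns found")
    # Every label cell must be exactly 0 or 1.
    bad = [c for c in label_cols if not df[c].isin([0, 1]).all()]
    if bad:
        raise ValueError(f"non-binary values in: {bad}")
    # One row per image: the same path should not appear twice.
    if df[path_col].duplicated().any():
        raise ValueError("duplicate paths found")
    return label_cols


# Toy stand-in for labels/dataset.csv (made-up file names).
toy = pd.DataFrame({
    "path": ["img_001.jpg", "img_002.jpg", "img_003.jpg"],
    "Pothole": [1, 0, 0],
    "Illegal Parking": [0, 1, 1],
})
print(check_label_frame(toy))  # ['Pothole', 'Illegal Parking']
```

On the real data this would be called as `check_label_frame(pd.read_csv("labels/dataset.csv"))` before running the split cell.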


In [None]:

# 1) Install dependencies (run once)
# If you already have these, you can skip this cell.
%pip install iterative-stratification pandas


In [1]:

# 2) Do the stratified split
from pathlib import Path
import pandas as pd
from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit

CSV = Path("labels/dataset.csv")
OUT = Path("labels/splits")
OUT.mkdir(parents=True, exist_ok=True)

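# Target split ratios; adjust as needed, but they must sum to 1.0
SPLITS = {"train": 0.80, "val": 0.10, "test": 0.10}

df = pd.read_csv(CSV)
# Every column except "path" is treated as a 0/1 label
label_cols = [c for c in df.columns if c != "path"]
X = df[["path"]].values    # placeholder features; the stratification is driven by Y
Y = df[label_cols].values  # binary indicator matrix of shape (n_images, n_labels)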

# Stage 1: train vs. a temporary pool (val + test)
m1 = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=1 - SPLITS["train"], random_state=42)
tr_idx, tmp_idx = next(m1.split(X, Y))
df_tr, df_tmp = df.iloc[tr_idx], df.iloc[tmp_idx]

# Stage 2: split the temporary pool into val and test
# val_ratio is val's share of the pool (0.10 / 0.20 = 0.5 with the defaults)
val_ratio = SPLITS["val"] / (SPLITS["val"] + SPLITS["test"])
m2 = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=1 - val_ratio, random_state=43)
# v_idx and te_idx are positions within df_tmp, not within df
v_idx, te_idx = next(m2.split(df_tmp[["path"]].values, df_tmp[label_cols].values))
df_val, df_test = df_tmp.iloc[v_idx], df_tmp.iloc[te_idx]

df_tr.to_csv(OUT / "train.csv", index=False)
df_val.to_csv(OUT / "val.csv", index=False)
df_test.to_csv(OUT / "test.csv", index=False)

len(df_tr), len(df_val), len(df_test)


(120, 15, 16)
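
The two-stage split above rescales the ratios between stages, and the printed sizes (120, 15, 16) add up to 151 rows. The sketch below spells out that arithmetic and adds a leakage check. Both helpers, `two_stage_sizes` and `count_disjoint`, are hypothetical names introduced here for illustration and are not part of the notebook.

```python
def two_stage_sizes(splits):
    """Return the test_size values for the two successive splits."""
    total = splits["train"] + splits["val"] + splits["test"]
    if abs(total - 1.0) > 1e-9:
        raise ValueError(f"ratios sum to {total}, expected 1.0")
    # Stage 1 holds out val + test from the full dataset.
    first = splits["val"] + splits["test"]
    # Stage 2 sends this share of the held-out pool to test.
    second = splits["test"] / first
    return first, second


def count_disjoint(*groups):
    """Return the total number of distinct items, raising on any overlap."""
    seen = set()
    for group in groups:
        items = set(group)
        overlap = seen & items
        if overlap:
            raise ValueError(f"items appear in more than one split: {sorted(overlap)}")
        seen |= items
    return len(seen)


first, second = two_stage_sizes({"train": 0.80, "val": 0.10, "test": 0.10})
print(first, second)  # 0.2 0.5

# Toy check with made-up file names.
print(count_disjoint(["a.jpg", "b.jpg"], ["c.jpg"], ["d.jpg"]))  # 4
```

With the notebook's variables, `count_disjoint(df_tr["path"], df_val["path"], df_test["path"])` should equal `len(df)` if no image landed in two splits.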

In [2]:

# 3) Quick sanity check: label frequencies per split
import pandas as pd
from pathlib import Path

OUT = Path("labels/splits")
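for name in ["train", "val", "test"]:
    # Use a separate name so the full dataset in `df` is not overwritten
    split_df = pd.read_csv(OUT / f"{name}.csv")
    label_cols = [c for c in split_df.columns if c != "path"]
    # Number of positive examples per label, most frequent first
    counts = split_df[label_cols].sum().sort_values(ascending=False)
    print(f"\n=== {name.upper()} ({len(split_df)}) ===")
    print(counts)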



=== TRAIN (120) ===
Littering / Garbage         68
Damaged Road Surface        29
Illegal Parking             22
Pothole                     13
None of the above            1
Broken/Damaged Road Sign     0
Vandalism / Graffiti         0
dtype: int64

=== VAL (15) ===
Littering / Garbage         9
Damaged Road Surface        3
Illegal Parking             3
Pothole                     1
Broken/Damaged Road Sign    0
Vandalism / Graffiti        0
None of the above           0
dtype: int64

=== TEST (16) ===
Littering / Garbage         9
Damaged Road Surface        4
Illegal Parking             3
Pothole                     2
Broken/Damaged Road Sign    0
Vandalism / Graffiti        0
None of the above           0
dtype: int64
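
Raw counts are hard to compare across splits of different sizes. In the output above, "Broken/Damaged Road Sign" and "Vandalism / Graffiti" have no positives in any split, and the single "None of the above" image landed in train, so those labels cannot be evaluated on val or test. As a rough sketch, the hypothetical helpers below (`prevalence_table` and `labels_missing_somewhere`, not part of the notebook) turn counts into per-split fractions and list such labels. The toy frames are made up.

```python
import pandas as pd


def prevalence_table(frames, path_col="path"):
    """Fraction of rows carrying each label, one column per split."""
    return pd.DataFrame({
        name: frame.drop(columns=[path_col]).mean()
        for name, frame in frames.items()
    })


def labels_missing_somewhere(frames, path_col="path"):
    """Labels with zero positive examples in at least one split."""
    counts = pd.DataFrame({
        name: frame.drop(columns=[path_col]).sum()
        for name, frame in frames.items()
    })
    return counts.index[(counts == 0).any(axis=1)].tolist()


# Toy splits with made-up labels A and B.
frames = {
    "train": pd.DataFrame({"path": ["a", "b", "c", "d"],
                           "A": [1, 1, 0, 0], "B": [0, 0, 0, 1]}),
    "val": pd.DataFrame({"path": ["e", "f"],
                         "A": [1, 0], "B": [0, 0]}),
}
print(prevalence_table(frames))
print(labels_missing_somewhere(frames))  # ['B']
```

For the real splits, `frames` could be built as `{name: pd.read_csv(OUT / f"{name}.csv") for name in ["train", "val", "test"]}`. Well-stratified labels should show similar fractions in every column.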
