## Tests for the dataset-splitting tools

### Load metadata & labels (images are not necessary)


In [7]:
import pandas as pd
import numpy as np
import split_dataset as split

# Cleaned fMoW train/val label table (presumably a pickled pandas DataFrame, since it
# is converted with to_dict(orient="list") below).
# NOTE(review): pd.read_pickle unpickles arbitrary objects, so only load files from a
# trusted source; the absolute NAS path also ties this notebook to one machine.
df_file = "/mnt/nas_device_0/fmow/cleaned_trainval_labels.pkl"
df = pd.read_pickle(df_file)
# Column-oriented dict: {column_name: [value for each row]}.
metadata = df.to_dict(orient="list")
# Move the "category" column out of `metadata`: it becomes the per-sample label list,
# and the remaining columns (e.g. "country_code") stay behind as grouping metadata.
labels = metadata.pop("category")
# Number of cross-validation folds.
num_folds = 5

For this prototype, we check that three conditions hold after splitting the dataset:

- No sample leakage between the train, validation, and test splits
- Stratified splits maintain the same (normalized) class frequencies as the unsplit dataset
- No "group leakage" between the train, validation, and test splits when groups are specified
  - e.g., if `country_code` is specified, images from the same country never appear in more than one split


In [8]:
def test_sample_leakage(splits):
    test_inds = set(splits.get("test", []))
    print("\nChecking for Sample Leakage")
    for foldname, folddict in splits.items():
        if foldname == "test":
            continue
        train_inds = set(folddict["train"])
        val_inds = set(folddict["val"])
        assert test_inds.isdisjoint(train_inds), "common elements between train and test"
        assert test_inds.isdisjoint(val_inds), "common elements between val and test"
        assert val_inds.isdisjoint(train_inds), "common elements between train and val"
        print(f"{foldname} ok. No sample leakage detected")


def test_stratification(splits, all_labels):
    unique_labels, label_counts = np.unique(all_labels, return_counts=True)
    label_freq = label_counts / label_counts.sum()
    test_inds = splits.get("test")
    print("\nChecking label stratification")
    if test_inds is not None:
        test_labels = all_labels[test_inds]
        unique_test, test_counts = np.unique(test_labels, return_counts=True)
        test_freq = test_counts / test_counts.sum()
        assert len(unique_test) == len(unique_labels), "Test set does not contain all labels"
        assert (unique_test == unique_labels).all(), "Mismatch between test labels and all labels"
        assert np.allclose(test_freq, label_freq, rtol=0.05, atol=1/len(label_freq)), "Test set difference greater than tolerance"
        print("Test split ok.")
    for foldname, folddict in splits.items():
        if foldname == "test":
            continue
        train_labels = all_labels[folddict["train"]]
        unique_train, train_counts = np.unique(train_labels, return_counts=True)
        train_freq = train_counts / train_counts.sum()
        assert len(unique_train) == len(unique_labels), "Test set does not contain all labels"
        assert (unique_train == unique_labels).all(), "Mismatch between test labels and all labels"
        assert np.allclose(train_freq, label_freq, rtol=0.05), "Test set difference greater than 5%"
        print(f"{foldname} ok. Class frequencies match")



def test_group_leakage(splits, groups):
    test_inds = splits.get("test", [])
    test_groups = set(groups[test_inds])
    print("\nChecking Group Leakage")
    for foldname, folddict in splits.items():
        if foldname == "test": 
            continue
        train_groups = set(groups[folddict["train"]])
        val_groups = set(groups[folddict["val"]])
        assert test_groups.isdisjoint(train_groups), "common groups between train and test"
        assert test_groups.isdisjoint(val_groups),   "common groups between val and test"
        assert val_groups.isdisjoint(train_groups),  "common groups between train and val"
        print(f"{foldname} ok. No group leakage detected")


### Test unstratified, ungrouped labels (Regular K-Fold)


In [9]:
# Plain K-fold: no stratification and no grouping, so only sample leakage is checked.
# Use the configured fold count instead of repeating the literal 5.
vanilla_folds = split.split_dataset(labels=labels, num_folds=num_folds, test_frac=0.15)
test_sample_leakage(vanilla_folds)


Checking for Sample Leakage
fold_0 ok. No sample leakage detected
fold_1 ok. No sample leakage detected
fold_2 ok. No sample leakage detected
fold_3 ok. No sample leakage detected
fold_4 ok. No sample leakage detected


### Test unstratified, grouped labels (Group K-Fold)


In [10]:
# Group K-fold on country code: besides sample leakage, check that no country appears
# in more than one of train / val / test. Uses the configured fold count.
grouped_folds = split.split_dataset(
    labels=labels, num_folds=num_folds, test_frac=0.15, split_on=["country_code"], metadata=metadata)
test_sample_leakage(grouped_folds)
test_group_leakage(grouped_folds, np.array(metadata["country_code"]))


Checking for Sample Leakage
fold_0 ok. No sample leakage detected
fold_1 ok. No sample leakage detected
fold_2 ok. No sample leakage detected
fold_3 ok. No sample leakage detected
fold_4 ok. No sample leakage detected

Checking Group Leakage
fold_0 ok. No group leakage detected
fold_1 ok. No group leakage detected
fold_2 ok. No group leakage detected
fold_3 ok. No group leakage detected
fold_4 ok. No group leakage detected


### Test stratified, ungrouped labels (Stratified K-Fold)


In [11]:
# Stratified K-fold without groups: check sample leakage and that every split keeps
# the dataset's class frequencies. Uses the configured fold count.
strat_folds = split.split_dataset(
    labels=labels, num_folds=num_folds, test_frac=0.15, stratified=True)
test_sample_leakage(strat_folds)
test_stratification(strat_folds, np.array(labels))


Checking for Sample Leakage
fold_0 ok. No sample leakage detected
fold_1 ok. No sample leakage detected
fold_2 ok. No sample leakage detected
fold_3 ok. No sample leakage detected
fold_4 ok. No sample leakage detected

Checking label stratification
Test split ok.
fold_0 ok. Class frequencies match
fold_1 ok. Class frequencies match
fold_2 ok. Class frequencies match
fold_3 ok. Class frequencies match
fold_4 ok. Class frequencies match


### Test Stratified, Grouped Labels (Stratified Group K-Fold)

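We don't run the stratification check here because we cannot guarantee that class labels will be
identically stratified across groups: keeping every group inside a single split limits how closely
each split's class frequencies can match the full dataset. The results from stratified group k-fold
are the best the splitting function can achieve.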


In [12]:
# Stratified group K-fold: only sample and group leakage are asserted (stratification
# cannot be guaranteed when whole groups must stay together). Uses the configured fold count.
strat_grouped = split.split_dataset(
    labels=labels,
    num_folds=num_folds,
    test_frac=0.15,
    split_on=["country_code"],
    metadata=metadata,
    stratified=True
)
test_sample_leakage(strat_grouped)
test_group_leakage(strat_grouped, np.array(metadata["country_code"]))


Checking for Sample Leakage
fold_0 ok. No sample leakage detected
fold_1 ok. No sample leakage detected
fold_2 ok. No sample leakage detected
fold_3 ok. No sample leakage detected
fold_4 ok. No sample leakage detected

Checking Group Leakage
fold_0 ok. No group leakage detected
fold_1 ok. No group leakage detected
fold_2 ok. No group leakage detected
fold_3 ok. No group leakage detected
fold_4 ok. No group leakage detected
