In [7]:
import argparse
import os
import numpy as np
import torch
from torch.utils.data import DataLoader, Subset
from torchvision.transforms import ToTensor, Lambda
from typing import Literal
from data import CamVidDataset
from stratifiers.kfold import KFoldWrapper
from stratifiers.wdes import WDESKFold
from stratifiers.ips import IPSKFold
import random
from tqdm import tqdm

import csv
import os


In [4]:
def get_dataset(name: str, path: str):
    annotation_transform = Lambda(lambda x: torch.as_tensor(np.expand_dims(np.array(x), 0), dtype=torch.int64))
    common_args = {'split': 'train', 'image_transform': ToTensor(), 'annotation_transform': annotation_transform}
    if name == 'camvid':
        return CamVidDataset(path, **common_args)
    raise ValueError('Unsupported dataset {}'.format(name))

def get_stratifier(method: Literal['random', 'ips', 'wdes'], n_splits):
    if method == 'random':
        return KFoldWrapper(n_splits=n_splits)
    elif method == 'ips':
        return IPSKFold(n_splits=n_splits)
    elif method == 'wdes':
        return WDESKFold(n_splits=n_splits)

In [5]:
dataset = get_dataset('camvid', './dataset')

In [38]:
stratifier = get_stratifier('wdes', n_splits=10)

# Define where to save your split log
split_log_path = "split_log.csv"

# Create CSV header if not exists
if not os.path.exists(split_log_path):
    with open(split_log_path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["fold", "split_type", "index"])

# Iterate over folds
for i, (train_idx, test_idx) in enumerate(stratifier.split(dataset)):
    # if i != 0:  # Only process the first fold for demonstration
    #     continue

    # Append to CSV for this fold
    with open(split_log_path, "a", newline="") as f:
        writer = csv.writer(f)
        # Log train indices
        for idx in train_idx:
            writer.writerow([i, "train", idx])
        # Log test indices
        for idx in test_idx:
            writer.writerow([i, "test", idx])

print(f"✅ Split log saved to {split_log_path}")

Reading dataset information for stratifier


100%|██████████| 700/700 [01:05<00:00, 10.64it/s]


Starting WDES


100%|██████████| 50/50 [00:39<00:00,  1.28it/s]

✅ Split log saved to split_log.csv





In [22]:


# Assign folds: 0–6 for train, 7 for val, 8 for test (example)
train_idx = np.concatenate([splits[i][1] for i in range(8)])
val_idx   = splits[8][1]
test_idx  = splits[9][1]

# Log to CSV
with open(split_log_path, "a", newline="") as f:
    writer = csv.writer(f)
    for idx in train_idx:
        writer.writerow(["train", idx])
    for idx in val_idx:
        writer.writerow(["val", idx])
    for idx in test_idx:
        writer.writerow(["test", idx])

print(f"✅ Train/Val/Test split log saved to {split_log_path}")

✅ Train/Val/Test split log saved to split_log.csv


In [45]:
import numpy as np
from scipy.stats import wasserstein_distance
from collections import defaultdict
import csv

# -------------------------------------------------------------
# 1️⃣ Load folds from split_log.csv
# -------------------------------------------------------------
def load_folds_from_csv(csv_path):
    folds = defaultdict(lambda: {"train": [], "test": []})
    with open(csv_path, newline="") as f:
        reader = csv.DictReader(f)
        for row in reader:
            folds[int(row["fold"])][row["split_type"]].append(int(row["index"]))
    for k in folds:
        for s in folds[k]:
            folds[k][s] = np.array(folds[k][s])

    return folds


# -------------------------------------------------------------
# 2️⃣ Sample Distribution (same as before)
# -------------------------------------------------------------
def sample_distribution(folds, proportions):
    K = len(folds)
    N = sum(len(f) for f in folds)
    expected = [r * N for r in proportions]
    sd = np.mean([abs(len(folds[k]) - expected[k]) for k in range(K)])
    return sd


# -------------------------------------------------------------
# 3️⃣ Pixel Label Distribution (adapted for image-level)
# -------------------------------------------------------------
def pixel_label_distribution_from_pixel_counts(folds, pixel_counts):
    """
    pixel_counts: np.array [N_images, num_classes]
    folds: list of arrays (each with image indices)
    """
    K = len(folds)
    C = pixel_counts.shape[1]

    total_pixels = pixel_counts.sum()
    total_class_counts = pixel_counts.sum(axis=0)
    total_ratios = total_class_counts / total_pixels

    pl_diffs = []
    for c in range(C):
        fold_diffs = []
        for fold in folds:
            fold_class_count = pixel_counts[fold, c].sum()
            fold_pixels = pixel_counts[fold, :].sum()
            fold_ratio = fold_class_count / fold_pixels
            fold_diffs.append(abs(fold_ratio - total_ratios[c]))
        pl_diffs.append(np.mean(fold_diffs))
    return np.mean(pl_diffs)


# -------------------------------------------------------------
# 4️⃣ Label Wasserstein Distance (adapted for image-level)
# -------------------------------------------------------------
def label_wasserstein_distance_from_pixel_counts(folds, pixel_counts):
    K = len(folds)
    C = pixel_counts.shape[1]

    # Normalize full dataset distribution
    total_class_dist = pixel_counts.sum(axis=0)
    total_class_dist /= total_class_dist.sum()
    total_cum = np.cumsum(total_class_dist)

    lwd_values = []
    for fold in folds:
        fold_class_dist = pixel_counts[fold].sum(axis=0)
        fold_class_dist /= fold_class_dist.sum()
        fold_cum = np.cumsum(fold_class_dist)
        lwd_values.append(np.sum(np.abs(total_cum - fold_cum)))
    return np.mean(lwd_values)


In [None]:
def evaluate_stratification_from_csv(csv_path, dataset, split_type="test"):
    """
    dataset: your image dataset, where dataset[i] -> (image, mask)
    """
    folds_dict = load_folds_from_csv(csv_path)
    folds = [folds_dict[k][split_type] for k in sorted(folds_dict.keys())]
    proportions = [1 / len(folds)] * len(folds)

    # Compute pixel counts (if not already available)
    num_classes = dataset.num_classes
    pixel_counts = np.zeros((len(dataset), num_classes))
    print("Computing pixel class counts...")
    for i in range(len(dataset)):
        _, mask = dataset[i]
        pixel_counts[i] = np.bincount(mask.flatten(), minlength=num_classes)[:num_classes]
    sd = sample_distribution(folds, proportions)
    pld = pixel_label_distribution_from_pixel_counts(folds, pixel_counts)
    lwd = label_wasserstein_distance_from_pixel_counts(folds, pixel_counts)

    return {"SD": sd, "PLD": pld, "LWD": lwd}


In [47]:
metrics = evaluate_stratification_from_csv("split_log.csv", dataset, split_type="test")

print(f"📊 Sample Distribution (SD): {metrics['SD']:.6f}")
print(f"🎨 Pixel Label Distribution (PLD): {metrics['PLD']:.6f}")
print(f"🌊 Label Wasserstein Distance (LWD): {metrics['LWD']:.6f}")


Computing pixel class counts...
[[  3283. 120707. 196459. ...      0.      0.      0.]
 [  9314. 360995. 230940. ...      0.      0.      0.]
 [     0. 361004. 469380. ...      0.      0.      0.]
 ...
 [  1818. 328768. 409934. ...      0.      0.      0.]
 [     0. 169414. 467479. ...      0.      0.      0.]
 [     0. 131953. 430543. ...      0.      0.      0.]]
📊 Sample Distribution (SD): 0.000000
🎨 Pixel Label Distribution (PLD): 0.000329
🌊 Label Wasserstein Distance (LWD): 0.008376
