In [4]:
import pandas as pd
import random
from pathlib import Path
from typing import List, Tuple, Optional


def generate_split_csv(
    all_patients: List[str],
    split_percentages: Tuple[int, int, int],
    output_csv: Path,
    skip_patients: Optional[List[str]] = None,
    seed: Optional[int] = None,
) -> None:
    """
    Generate or update a patient split CSV with train/validation/test/skip labels.

    Patients already present in ``output_csv`` keep whatever split is recorded
    there. Patients not yet in the CSV are de-duplicated, sorted, shuffled with
    ``random.Random(seed)`` and assigned to train/validation/test according to
    ``split_percentages``. Skipped patients take part in that randomization and
    are relabelled 'skip' afterwards, so adding a patient to ``skip_patients``
    does not shift the assignment of any other patient.

    Note: ``skip_patients`` only affects patients assigned during this call;
    a patient that is already in the CSV keeps its recorded split even if it
    is listed in ``skip_patients``.

    Args:
        all_patients: All patient IDs to consider. IDs are normalised to
            ``str`` so they compare equal to the IDs read back from the CSV.
        split_percentages: Tuple of (train%, val%, test%) that sums to 100.
            Train and validation counts are rounded down; the test split
            receives the remainder.
        output_csv: CSV path to read existing assignments from (if it exists)
            and to write the merged assignments to. Parent directories are
            created as needed.
        skip_patients: Patient IDs to label 'skip' instead of their drawn split.
        seed: Optional random seed for reproducibility.

    Raises:
        AssertionError: If ``split_percentages`` does not sum to 100.
    """
    assert sum(split_percentages) == 100, "Split percentages must sum to 100"
    skip_ids = {str(p) for p in (skip_patients or [])}
    # Sorting before shuffling makes the result depend only on the seed and the
    # set of IDs, not on the order in which the caller listed them.
    patient_ids = sorted({str(p) for p in all_patients})
    rng = random.Random(seed)

    # Load existing CSV if available
    if output_csv.exists():
        # Read both columns as plain strings: with default type inference an ID
        # such as "001" comes back as the integer 1 and an ID such as "NA"
        # comes back as NaN, so previously assigned patients would no longer
        # match and would be re-assigned.
        existing_df = pd.read_csv(
            output_csv,
            dtype={"patient_id": str, "split": str},
            keep_default_na=False,
        )
        existing_assignments = dict(
            zip(existing_df["patient_id"], existing_df["split"])
        )
        print(f"Loaded existing split with {len(existing_assignments)} patients.")
    else:
        existing_assignments = {}

    # Determine new patients
    new_patients = [p for p in patient_ids if p not in existing_assignments]

    # Randomize all new patients (including skipped)
    rng.shuffle(new_patients)

    # Assign all new patients as if none are skipped. Integer arithmetic keeps
    # the counts exact; e.g. int(29 / 100 * 100) evaluates to 28 in floating
    # point, whereas 29 * 100 // 100 is 29.
    n_total = len(new_patients)
    n_train = int(split_percentages[0] * n_total // 100)
    n_val = int(split_percentages[1] * n_total // 100)
    n_test = n_total - n_train - n_val

    split_assignments = (
        ["train"] * n_train +
        ["validation"] * n_val +
        ["test"] * n_test
    )
    new_assignments = dict(zip(new_patients, split_assignments))

    # After assigning, change skipped patients to "skip"
    for patient in skip_ids:
        if patient in new_assignments:
            new_assignments[patient] = "skip"

    # Merge with existing assignments (the key sets are disjoint by construction)
    combined = existing_assignments.copy()
    combined.update(new_assignments)

    # Output to DataFrame, ordered by patient ID
    df_out = pd.DataFrame(sorted(combined.items()), columns=["patient_id", "split"])
    output_csv.parent.mkdir(parents=True, exist_ok=True)
    df_out.to_csv(output_csv, index=False)

    # Report
    split_counts = df_out["split"].value_counts().to_dict()
    print("Final split counts:", split_counts)
    print(f"CSV written to {output_csv}")


In [5]:
from pathlib import Path
import os

# Every entry found in the patient-data directory is treated as a patient ID.
all_patients_path = Path('/home/ayeluru/vascular-superenhancement-4d-flow/working_dir/all_patients/patient_data')
all_patients = list(os.listdir(all_patients_path))

print(all_patients)

# Patient IDs passed to generate_split_csv as skip_patients.
skip_patients = [
    'Amupam', 'Bephedou', 'Bigeral', 'Ceymuslek', 'Fisale', 'Ganage',
    'Githucu', 'Gudabee', 'Gujarjoy', 'Hapase', 'Haseyad', 'Hidistoy',
    'Ikatus', 'Johijap', 'Molotoze', 'Nocucuech', 'Nounuri', 'Otikek',
    'Pegiqui', 'Pibenos', 'Qualinad', 'Quibanid', 'Quofoustush', 'Rekana',
    'Riqueli', 'Swunufod', 'Teeroosay', 'Tehupoug', 'Ubazquis', 'Waruese',
    'Wukazu', 'Yidepi',
]

# (train %, validation %, test %)
split_percentages = (80, 10, 10)

# Destination of the split assignment file.
output_csv = Path('/home/ayeluru/vascular-superenhancement-4d-flow/splits/splits_07-21-25.csv')


['Dinaspig', 'Rekana', 'Tehupoug', 'Tieriegi', 'Cornuefor', 'Cefaru', 'Tiepolem', 'Upbapit', 'Suedrurnep', 'Tupotu', 'Bibathot', 'Kusksusuth', 'Noyetut', 'Gudulung', 'Pemiggig', 'Ribieko', 'Kihaquo', 'Crutaswo', 'Slathalu', 'Sithiroo', 'Heynelap', 'Qualinad', 'Trecuquem', 'Badiswu', 'Osumud', 'Suquepog', 'Nocucuech', 'Omeflok', 'Hucisot', 'Mugukood', 'Bephedou', 'Frahidiel', 'Uzegbex', 'Jiesieja', 'Phidecro', 'Quogosi', 'Nosite', 'Kadiforn', 'Cadedag', 'Datokif', 'Phunitul', 'Hekita', 'Tatufok', 'Nounuri', 'Duquestank', 'Quehokee', 'Sahuhe', 'Rayoloth', 'Golotag', 'Quinudoul', 'Ostablip', 'Pibenos', 'Nufotoy', 'Otikek', 'Bovutou', 'Fisale', 'Gudabee', 'Getahig', 'Gemapoey', 'Jufamich', 'Hetoprist', 'Socrowa', 'Glomahe', 'Eporid', 'Dapafem', 'Ceymuslek', 'Diecudey', 'Johijap', 'Gujarjoy', 'Strousuhe', 'Preflefi', 'Urpofooy', 'Jenoomer', 'Githucu', 'Udmodas', 'Kopejek', 'Phahofi', 'Ceriba', 'Hicuhe', 'Fudoquo', 'Dalibul', 'Thameci', 'Pubayue', 'Waruese', 'Ekotey', 'Pegiqui', 'Kofudun', '

In [7]:
# Build (or extend) the split CSV; the fixed seed makes the shuffle repeatable.
generate_split_csv(
    all_patients,
    split_percentages,
    output_csv,
    skip_patients=skip_patients,
    seed=42,
)

Final split counts: {'train': 117, 'skip': 32, 'test': 17, 'validation': 15}
CSV written to /home/ayeluru/vascular-superenhancement-4d-flow/splits/splits_07-21-25.csv
