In [4]:
#!/usr/bin/env python3
import json
import random
import os
from math import ceil

# reproducibility: seed the global `random` module RNG so the subject
# shuffle is the same on every run
random_seed = 42
random.seed(random_seed)

# task and paths
task = "cpr_quality"
input_json_path = f'../../Annotations/aaai26_main_annotation_{task}.json'
output_dir = f'../../Annotations/splits/{task}/'

# ensure output directory exists
os.makedirs(output_dir, exist_ok=True)

# load the full annotation (JSON is UTF-8 by spec; don't depend on the
# platform's locale default encoding)
with open(input_json_path, 'r', encoding='utf-8') as f:
    data = json.load(f)

# grab the list of subjects. A missing/empty list would otherwise produce
# zero splits and a vacuous "Success" message, so fail loudly instead.
subjects = data.get('subjects', [])
num_subjects = len(subjects)
if num_subjects == 0:
    raise ValueError(f"No subjects found in {input_json_path}")

# compute how many subjects should go into each validation fold.
# (1 - train_ratio) is inexact in binary floating point (e.g. 1 - 0.8 ==
# 0.19999999999999996, so 10 subjects would truncate to 1 instead of 2);
# round away the representation error before truncating, and clamp to at
# least 1 so the fold size can never be zero for very small datasets.
train_ratio = 0.7
num_test_subjects = max(1, int(round((1 - train_ratio) * num_subjects, 6)))

def create_non_overlapping_splits(subjects, n_test, rng=None):
    """
    Shuffle the subjects and partition them into consecutive validation folds.

    Every subject record lands in exactly one validation fold; each fold's
    train list is every record whose subject_id is not in that fold. Folds
    hold ``n_test`` subjects each, except the last, which holds the remainder
    when ``len(subjects)`` is not a multiple of ``n_test``.

    Args:
        subjects: sequence of subject records; each must have a
            ``'subject_id'`` key whose value is a string (it is passed to
            ``str.join``). The caller's sequence is not modified.
        n_test: number of subjects per validation fold; must be >= 1.
        rng: optional object with a ``shuffle`` method (e.g. a
            ``random.Random`` instance). Defaults to the global ``random``
            module, so seeding via ``random.seed`` keeps working.

    Returns:
        list of dicts, one per fold, each with comma-joined ``"train"`` and
        ``"validation"`` subject-ID strings. Empty list when there are no
        subjects.

    Raises:
        ValueError: if ``n_test`` is less than 1.
    """
    if n_test < 1:
        # 0 would divide by zero and negatives would silently yield no folds
        raise ValueError(f"n_test must be >= 1, got {n_test}")
    rng = random if rng is None else rng

    pool = list(subjects)          # copy so the caller's list is not shuffled
    rng.shuffle(pool)
    splits = []

    for start in range(0, len(pool), n_test):
        val_subjs = pool[start:start + n_test]
        # Exclude by subject_id (O(n) set lookup) rather than by dict
        # equality (O(n^2)): any record sharing an ID with a validation
        # subject stays out of train, so a subject cannot leak across sides.
        val_ids = {s['subject_id'] for s in val_subjs}
        train_subjs = [s for s in pool if s['subject_id'] not in val_ids]
        splits.append({
            "train":      ",".join(s['subject_id'] for s in train_subjs),
            "validation": ",".join(s['subject_id'] for s in val_subjs),
        })
    return splits

# build the splits
splits = create_non_overlapping_splits(subjects, num_test_subjects)

# write out each split as its own JSON file, numbered from 1
# NOTE(review): files from an earlier run that produced more folds are not
# removed here — confirm no stale aaai26_split_*.json remain in output_dir.
for idx, split in enumerate(splits, start=1):
    path = os.path.join(output_dir, f'aaai26_split_{idx}.json')
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(split, f, indent=4)
    print(f"Saved split #{idx} to {path}")

# verification. Subject IDs are recovered by splitting the joined strings on
# ',', which assumes no subject_id itself contains a comma.
all_ids = {s['subject_id'] for s in subjects}
all_val = set()
for idx, sp in enumerate(splits, start=1):
    val_ids = set(sp['validation'].split(','))
    # leakage check: no subject may be in both train and validation of a split
    leaked = val_ids & set(sp['train'].split(','))
    if leaked:
        print(f"Warning: split #{idx} has subjects in both train and validation: {leaked}")
    all_val |= val_ids

# coverage check: every subject appears in at least one validation set, and
# validation contains nothing that is not a known subject
if all_val == all_ids:
    print("Success: All subjects are covered in validation folds.")
else:
    # report each direction separately; previously an unexpected extra ID
    # produced a misleading "never appeared in validation: set()" message
    missing = all_ids - all_val
    unexpected = all_val - all_ids
    if missing:
        print(f"Warning: these subjects never appeared in validation: {missing}")
    if unexpected:
        print(f"Warning: validation contains unknown subject IDs: {unexpected}")


Saved split #1 to ../../Annotations/splits/cpr_quality/aaai26_split_1.json
Saved split #2 to ../../Annotations/splits/cpr_quality/aaai26_split_2.json
Saved split #3 to ../../Annotations/splits/cpr_quality/aaai26_split_3.json
Saved split #4 to ../../Annotations/splits/cpr_quality/aaai26_split_4.json
Success: All subjects are covered in validation folds.
