In [3]:
#!/usr/bin/env python3
import json
import random
import os
from collections import defaultdict

# reproducibility: a single global seed. The per-subject shuffles further down
# depend on this seed AND on the order of subjects/scenarios/trials in the
# input JSON, so the splits are reproducible only for an unchanged input file.
random_seed = 42
random.seed(random_seed)

# task and paths
# Pick one of: "classification", "segmentation", "cpr_quality".
task = "cpr_quality"
input_json_path = f'../../Annotations/aaai26_main_annotation_{task}.json'
output_dir = '../../Annotations/splits/trials/'

# split ratios. test_ratio is not consumed by the split itself (the test split
# receives whatever remains after the train/val counts), so check the ratios
# here -- before touching the filesystem -- to catch a misconfiguration early.
train_ratio = 0.6
val_ratio   = 0.2
test_ratio  = 0.2
if (min(train_ratio, val_ratio, test_ratio) < 0
        or abs(train_ratio + val_ratio + test_ratio - 1.0) > 1e-9):
    raise ValueError(
        "split ratios must be non-negative and sum to 1, got "
        f"train={train_ratio}, val={val_ratio}, test={test_ratio}"
    )

# ensure output directory exists
os.makedirs(output_dir, exist_ok=True)

# load the full annotation; JSON text is UTF-8, so don't rely on the
# platform/locale default encoding
with open(input_json_path, 'r', encoding='utf-8') as f:
    data = json.load(f)

# prepare split containers (each mirrors the top-level {"subjects": [...]} layout)
train_set = {"subjects": []}
val_set   = {"subjects": []}
test_set  = {"subjects": []}

def build_subject_split(subject, entries):
    """Regroup one subject's flat trial entries back into the nested layout.

    Args:
        subject: The original subject record. Only ``subject_id`` (required)
            and ``expertise_level`` (optional, defaults to ``""``) are copied.
        entries: Flat list of ``{"scenario_id": ..., "trial": ...}`` dicts.

    Returns:
        A new dict ``{"subject_id", "expertise_level", "scenarios"}`` where
        ``scenarios`` lists each scenario in the order its first trial appears
        in ``entries``, and each scenario's trials keep their relative order.
    """
    trials_per_scenario = {}
    for item in entries:
        trials_per_scenario.setdefault(item['scenario_id'], []).append(item['trial'])

    # key order is kept identical so the dumped JSON looks the same
    return {
        "subject_id": subject["subject_id"],
        "expertise_level": subject.get("expertise_level", ""),
        "scenarios": [
            {"scenario_id": sid, "trials": trial_list}
            for sid, trial_list in trials_per_scenario.items()
        ],
    }

# Iterate each subject, flatten their trials across scenarios, shuffle & split.
# The split is done per subject at the trial level, so a single subject can
# contribute trials to train, val and test at the same time.
for subject in data.get('subjects', []):
    # flatten scenario -> trial into entries, remembering each trial's scenario
    entries = []
    for scenario in subject.get('scenarios', []):
        scen_id = scenario['scenario_id']
        for trial in scenario.get('trials', []):
            entries.append({
                "scenario_id": scen_id,
                "trial":       trial
            })

    # in-place shuffle driven by the module-level `random` state
    random.shuffle(entries)
    n_total = len(entries)
    # int() truncates, so rounding slack always lands in the test split
    # (e.g. 4 trials -> 2 train / 0 val / 2 test).
    n_train = int(train_ratio * n_total)
    n_val   = int(val_ratio * n_total)

    train_entries = entries[:n_train]
    val_entries   = entries[n_train:n_train + n_val]
    # the rest goes to test; test_ratio is implied, not used directly
    test_entries  = entries[n_train + n_val:]

    # build & append subject splits only if nonempty, so no split ever holds
    # a subject with zero trials
    if train_entries:
        train_set['subjects'].append(build_subject_split(subject, train_entries))
    if val_entries:
        val_set['subjects'].append(build_subject_split(subject, val_entries))
    if test_entries:
        test_set['subjects'].append(build_subject_split(subject, test_entries))

# helper to save each split
def save_split(split_data, name):
    """Write one split as ``aaai26_<name>_split_<task>.json`` under ``output_dir``.

    Reads the module-level ``output_dir`` and ``task`` settings. The JSON is
    pretty-printed with 4-space indentation, and the destination path is
    printed once the file has been written.
    """
    filename = f'aaai26_{name}_split_{task}.json'
    destination = os.path.join(output_dir, filename)
    with open(destination, 'w') as handle:
        json.dump(split_data, handle, indent=4)
    print(f"Saved {name} split to {destination}")

# write out all three splits, in train / val / test order
for split_name, split_payload in (('train', train_set), ('val', val_set), ('test', test_set)):
    save_split(split_payload, split_name)


Saved train split to ../../Annotations/splits/trials/aaai26_train_split_cpr_quality.json
Saved val split to ../../Annotations/splits/trials/aaai26_val_split_cpr_quality.json
Saved test split to ../../Annotations/splits/trials/aaai26_test_split_cpr_quality.json
