In [4]:
import json
import random
import os
from math import ceil

# Seed the RNG so the shuffled subject order, and therefore every split, is reproducible
random_seed = 42
random.seed(random_seed)

# Task name: selects both the annotation file that is read and the output sub-directory
task = "cpr_quality"  # alternatives noted by the original author: segmentation or classification

# Input annotation file and destination directory for the generated split files
input_json_path = f'../../Annotations/main_annotation_{task}.json'  # Replace with your input JSON file path
output_dir = f'../../Annotations/splits/{task}'  # Output directory for split files

# Create the output directory together with any missing parents.
# exist_ok=True replaces the separate os.path.exists() test, which could race with
# another process creating the directory between the check and the call.
os.makedirs(output_dir, exist_ok=True)

# Load the dataset. The encoding is explicit so decoding does not depend on the
# platform's default locale encoding.
with open(input_json_path, 'r', encoding='utf-8') as f:
    data = json.load(f)

# Define train-test ratio
train_ratio = 0.7
test_ratio = 1 - train_ratio  # 30%

# Subject list and the number of subjects held out in each split.
# int() truncates, so with this ratio fewer than 4 subjects yields 0 held-out subjects.
subjects = data['subjects']
num_subjects = len(subjects)
num_test_subjects = int(test_ratio * num_subjects)

# Divide subjects into non-overlapping test sets for each split
def create_non_overlapping_splits(subjects, num_test_subjects):
    """Partition subjects into consecutive, non-overlapping validation folds.

    The subjects are shuffled with the module-level ``random`` generator and
    then cut into chunks of ``num_test_subjects``. Each chunk becomes the
    validation set of one split; every other subject goes to that split's
    training set. The last chunk may be smaller than ``num_test_subjects``.

    Args:
        subjects: List of dicts, each with a string ``'subject_id'`` entry.
            The list itself is not modified.
        num_test_subjects: Number of subjects per validation chunk (>= 1).

    Returns:
        A list of dicts, one per split, each with the keys ``"train"`` and
        ``"validation"`` holding comma-joined subject IDs. Empty when
        ``subjects`` is empty.

    Raises:
        ValueError: If ``num_test_subjects`` is smaller than 1.
    """
    if num_test_subjects < 1:
        raise ValueError(
            f"num_test_subjects must be at least 1, got {num_test_subjects}"
        )

    # Shuffle a copy so the caller's list keeps its original order. The
    # permutation drawn from the RNG is the same as for an in-place shuffle.
    shuffled = list(subjects)
    random.shuffle(shuffled)

    # Number of chunks needed so every subject lands in exactly one validation
    # set; derived from the argument rather than from a module-level global.
    num_splits = ceil(len(shuffled) / num_test_subjects)

    splits = []
    for i in range(num_splits):
        start_index = i * num_test_subjects
        test_subjects = shuffled[start_index:start_index + num_test_subjects]

        # Exclude by subject ID (set lookup) so a subject can never appear in
        # both the training and the validation set of the same split.
        test_ids = {subj['subject_id'] for subj in test_subjects}
        train_ids = [
            subj['subject_id']
            for subj in shuffled
            if subj['subject_id'] not in test_ids
        ]

        splits.append({
            "train": ",".join(train_ids),
            "validation": ",".join(subj['subject_id'] for subj in test_subjects),
        })

    return splits

# Build every split from the loaded subject list
splits = create_non_overlapping_splits(subjects, num_test_subjects)

# Write the splits to output_dir as split_1.json, split_2.json, ...
for i, split_data in enumerate(splits, start=1):
    split_file_path = os.path.join(output_dir, f'split_{i}.json')
    with open(split_file_path, 'w') as f:
        f.write(json.dumps(split_data, indent=4))
    print(f"Saved {split_file_path}")


# Sanity check: the union of all validation sets must equal the full set of subject IDs.
# Every ID that appears in the "validation" field of any split:
all_validation_subjects = {
    subject_id
    for split_entry in splits
    for subject_id in split_entry["validation"].split(",")
}

# Every subject ID present in the dataset
all_subject_ids = {subj['subject_id'] for subj in subjects}

# Report the outcome, listing any subject that never appeared in a validation set
if all_subject_ids != all_validation_subjects:
    missing_subjects = all_subject_ids - all_validation_subjects
    print(f"Error: The following subjects are missing in the validation sets: {missing_subjects}")
else:
    print("Success: All subjects are covered in the validation sets across all splits.")

Saved ../../Annotations/splits/cpr_quality/split_1.json
Saved ../../Annotations/splits/cpr_quality/split_2.json
Saved ../../Annotations/splits/cpr_quality/split_3.json
Saved ../../Annotations/splits/cpr_quality/split_4.json
Success: All subjects are covered in the validation sets across all splits.
