In [None]:
import json
import random
import os

# Seed the global RNG so the per-subject shuffles (and thus the splits) are
# reproducible across runs.
random_seed = 42
random.seed(random_seed)

# Input annotation file and destination directory for the split files.
input_json_path = '../../Annotations/main_annotation.json'  # Replace with your input JSON file path
output_dir = '../../Annotations/splits/trials'  # Output directory for split files

# Create the output directory if needed; exist_ok avoids a check-then-create race.
os.makedirs(output_dir, exist_ok=True)

# Load the dataset. Read as UTF-8 explicitly rather than relying on the
# platform's default text encoding.
with open(input_json_path, 'r', encoding='utf-8') as f:
    data = json.load(f)

# Per-split containers, mirroring the top-level "subjects" list of the input.
train_set = {"subjects": []}
val_set = {"subjects": []}
test_set = {"subjects": []}

# Split ratios. Train and val sizes are floored per subject and the test split
# receives the remainder, so test_ratio only serves to check that the three
# ratios are consistent with each other.
train_ratio = 0.6
val_ratio = 0.2
test_ratio = 0.2
if (min(train_ratio, val_ratio, test_ratio) < 0
        or abs(train_ratio + val_ratio + test_ratio - 1.0) > 1e-9):
    raise ValueError(
        "Split ratios must be non-negative and sum to 1, got "
        f"train={train_ratio}, val={val_ratio}, test={test_ratio}"
    )

# Split each subject's trials independently, so one subject may contribute
# trials to all three splits.
for subject in data['subjects']:
    # Shuffle a copy so `data` keeps its original trial order; the RNG draws
    # are the same as shuffling in place, so the resulting split is unchanged.
    trials = list(subject['trials'])
    random.shuffle(trials)

    # int() floors both sizes. For very small subjects this pushes trials into
    # the test split (e.g. a subject with a single trial gets 0 train, 0 val).
    num_trials = len(trials)
    num_train = int(train_ratio * num_trials)
    num_val = int(val_ratio * num_trials)

    subsets = (
        (train_set, trials[:num_train]),
        (val_set, trials[num_train:num_train + num_val]),
        (test_set, trials[num_train + num_val:]),
    )
    for split, split_trials in subsets:
        # A subject with no trials in a given split is omitted from that split.
        if split_trials:
            split['subjects'].append({
                "subject_id": subject["subject_id"],
                "expertise_level": subject["expertise_level"],
                "trials": split_trials,
            })

def save_split_file(split_data, split_name, directory=None):
    """Write one split to ``<directory>/<split_name>_split.json``.

    Args:
        split_data: JSON-serialisable split contents (e.g. ``{"subjects": [...]}``).
        split_name: Prefix for the output file name, e.g. ``'train'``.
        directory: Destination directory. Defaults to the module-level
            ``output_dir`` so existing two-argument calls behave as before.

    Returns:
        The path of the file that was written.
    """
    if directory is None:
        directory = output_dir
    path = os.path.join(directory, f'{split_name}_split.json')
    # json.dump escapes non-ASCII by default; the explicit encoding just keeps
    # the file handle independent of the platform default.
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(split_data, f, indent=4)
    return path

# Persist every split (train, then val, then test) to its own JSON file and
# report the destination directory once all three are written.
named_splits = (
    ('train', train_set),
    ('val', val_set),
    ('test', test_set),
)
for name, contents in named_splits:
    save_split_file(contents, name)

print(f"Splits saved in directory: {output_dir}")


Splits saved in directory: ../../Annotations/splits/trials
