In [1]:
import os
from portiloopml.portiloop_python.ANN.data.mass_data_new import SubjectLoader
import random

In [2]:
# Root folder of the MASS spindles dataset. Defaults to the original
# hard-coded location; set the MASS_DATASET_PATH environment variable to point
# the notebook at another copy of the dataset without editing this cell.
dataset_path = os.environ.get(
    'MASS_DATASET_PATH', '/project/MASS/mass_spindles_dataset/')

In [3]:
# Shared loader used by the cells below to select subjects; it is built from
# the 'subject_info.csv' file located at the root of the dataset folder.
subject_loader = SubjectLoader(os.path.join(dataset_path, 'subject_info.csv'))

In [67]:
def get_subjects_folds(fold_num, test_subjects_per_fold=28, seed=42):
    '''
    Get the train, validation and test subjects for a specific fold of a
    5-fold split.

    The test sets of all 5 folds are built one after the other, each one
    excluding the subjects already assigned to the test set of a previous
    fold. Folds 0 to 3 combine test_subjects_per_fold // 2 subjects selected
    with min_age=0 / max_age=30 and test_subjects_per_fold // 2 subjects
    selected with min_age=40 / max_age=100 (so an odd test_subjects_per_fold
    yields one subject fewer). The last fold draws test_subjects_per_fold
    subjects from the remaining ones without any age constraint.

    The validation set is then made of 6 "young" and 6 "old" subjects that are
    not in the test set of the requested fold, and the training set is drawn
    from the subjects that are in neither the test nor the validation set.

    Args:
    - fold_num: int
        The fold number in [0, 1, 2, 3, 4]
    - test_subjects_per_fold: int
        Number of subjects requested for the test set of each fold.
    - seed: int
        Seed forwarded to every subject_loader.select_subjects_age call.

    Returns:
    - (subjects_train, subjects_val, subjects_test): the three groups of
        subjects, in the format returned by subject_loader (they are
        concatenated with +, so presumably lists).

    Raises:
    - IndexError: if fold_num is outside [0, 4]. Negative values are rejected
        as well instead of silently wrapping around to another fold.
    '''
    num_folds = 5

    # Validate before doing any work: list indexing would otherwise accept
    # negative fold numbers and only fail for too-large ones after every fold
    # has been generated.
    if not 0 <= fold_num < num_folds:
        raise IndexError(
            'fold_num must be in [0, {}], got {}'.format(
                num_folds - 1, fold_num))

    # Get the number of subjects
    # NOTE(review): the reason for the "- 1" is not visible in this notebook
    # (one subject is apparently left out) -- confirm it is intended.
    num_subjects_total = len(subject_loader.select_all_subjects()) - 1

    # Generate the test set of every fold, with a balanced age distribution
    # for all folds but the last one
    fold_test_subjects = []
    subjects_sampled_so_far = []
    for i in range(num_folds):
        # On the last fold, we draw the test subjects from whatever is left,
        # without balancing the ages
        # NOTE(review): no seed is passed here, unlike the age-based
        # selections -- check whether select_random_subjects is reproducible.
        if i == num_folds - 1:
            subjects_left = subject_loader.select_random_subjects(
                num_subjects=test_subjects_per_fold,
                exclude=subjects_sampled_so_far
            )
            fold_test_subjects.append(subjects_left)
            break

        young_test = subject_loader.select_subjects_age(
            min_age=0,
            max_age=30,
            num_subjects=test_subjects_per_fold // 2,
            seed=seed,
            exclude=subjects_sampled_so_far)

        old_test = subject_loader.select_subjects_age(
            min_age=40,
            max_age=100,
            num_subjects=test_subjects_per_fold // 2,
            seed=seed,
            exclude=subjects_sampled_so_far)

        # Keep track of the subjects already used so that the next folds
        # cannot pick them again
        subjects_sampled_so_far += young_test + old_test
        fold_test_subjects.append(young_test + old_test)

    # Get the subjects for the fold
    subjects_test = fold_test_subjects[fold_num]

    # Get 6 young and 6 old subjects for the validation set
    young_val_subjects = subject_loader.select_subjects_age(
        min_age=0,
        max_age=30,
        num_subjects=6,
        seed=seed,
        exclude=subjects_test)
    old_val_subjects = subject_loader.select_subjects_age(
        min_age=40,
        max_age=100,
        num_subjects=6,
        seed=seed,
        exclude=subjects_test + young_val_subjects)

    subjects_val = young_val_subjects + old_val_subjects

    # Get the subjects for the training set: everything that is neither in the
    # test set nor in the validation set of this fold
    all_subjects_so_far = subjects_test + subjects_val
    subjects_train = subject_loader.select_random_subjects(
        num_subjects=num_subjects_total - len(all_subjects_so_far),
        exclude=all_subjects_so_far
    )

    return subjects_train, subjects_val, subjects_test

In [70]:
# Sanity checks on the splits. Every fold is generated exactly once: its
# train / val / test sets are checked, and its test set is kept so that the
# folds can be compared with each other afterwards.
test_sets_per_fold = []
for i in range(5):
    # Use the loop index so that all 5 folds are checked (not only fold 4)
    train, val, test = get_subjects_folds(i)

    # Make sure there are no duplicates
    assert len(train) == len(set(train)), \
        'duplicate subjects in the train set of fold {}'.format(i)
    assert len(val) == len(set(val)), \
        'duplicate subjects in the val set of fold {}'.format(i)
    assert len(test) == len(set(test)), \
        'duplicate subjects in the test set of fold {}'.format(i)

    # Assert there are no subjects in common in between the sets
    assert set(train).isdisjoint(val), \
        'train and val sets overlap in fold {}'.format(i)
    assert set(train).isdisjoint(test), \
        'train and test sets overlap in fold {}'.format(i)
    assert set(val).isdisjoint(test), \
        'val and test sets overlap in fold {}'.format(i)

    test_sets_per_fold.append(set(test))

# Make sure there are no subjects in common in the test set of the different folds
# (each unordered pair of folds is compared once)
for i in range(5):
    for j in range(i + 1, 5):
        assert test_sets_per_fold[i].isdisjoint(test_sets_per_fold[j]), \
            'test sets of folds {} and {} overlap'.format(i, j)