# Train Test Split

Splits any COCO dataset annotation file into training, test, and validation sets. In addition, it will perform a 5-fold cross-validation split.

In [1]:
import os
import json
import numpy as np
from tqdm import tqdm

# Folder that holds the source annotation file.
coco_folder = '/home/jack/Mounts/DiskOne/kona_coffee/datasets/compiled_v23'
coco_file_name = 'coco.json'

# Fraction of all (image, annotation) pairs assigned to each split.
train_split, test_split, val_split = 0.8, 0.1, 0.1

# One empty COCO-style container per split. Each gets its own dict and its
# own lists, so filling one split never touches another.
train_data, test_data, validation_data = (
    {'images': [], 'annotations': [], 'categories': []} for _ in range(3)
)

In [2]:
# Reset first: if the load below raises, new_coco_data is left as None.
new_coco_data = None
coco_path = os.path.join(coco_folder, coco_file_name)
with open(coco_path, 'r') as f:
    new_coco_data = json.load(f)

# Every split carries the full category list of the source file (the three
# dicts reference the same list object).
for split_data in (train_data, test_data, validation_data):
    split_data['categories'] = new_coco_data['categories']

In [3]:
# Index the annotations by the id of the image they belong to.
mapped_annotations = {}
for annotation in new_coco_data['annotations']:
    mapped_annotations.setdefault(annotation['image_id'], []).append(annotation)

# Pair every annotation with its image and bucket the pair by whether the
# annotation is flagged as augmented (extras.augmented is exactly True).
grouped_default = []
grouped_augmented = []
for image in tqdm(new_coco_data['images']):
    # Images without any annotation contribute nothing.
    for annotation in mapped_annotations.get(image['id'], ()):
        is_augmented = annotation.get('extras', {}).get('augmented') is True
        bucket = grouped_augmented if is_augmented else grouped_default
        bucket.append((image, annotation))

# Shuffle each bucket in place, independently of the other.
np.random.shuffle(grouped_augmented)
np.random.shuffle(grouped_default)

print(f'Grouped default: {len(grouped_default)}')
print(f'Grouped augmented: {len(grouped_augmented)}')

100%|██████████| 1148/1148 [00:00<00:00, 9117.07it/s]

Grouped default: 6521
Grouped augmented: 31559





In [4]:
# Number of (image, annotation) pairs available across both buckets.
total_annotations = len(grouped_default) + len(grouped_augmented)

# Size of each split; int() truncates, so the three sizes may add up to
# slightly less than total_annotations.
training_size, test_size, validation_size = (
    int(total_annotations * fraction)
    for fraction in (train_split, test_split, val_split)
)

print(f'Training size: {training_size}')
print(f'Test size: {test_size}')
print(f'Validation size: {validation_size}')

Training size: 30464
Test size: 3808
Validation size: 3808


In [5]:
# Split the (image, annotation) pairs into training / validation / test by
# slicing the combined list in order.
# NOTE(review): the augmented and default buckets are concatenated, not mixed,
# so the training slice is drawn from the augmented pairs first and the
# remaining slices get whatever follows. The split is also per annotation, so
# one image can appear in more than one split. Confirm both are intended.
all_grouped = grouped_augmented + grouped_default

# training: the first `training_size` pairs
training_set = all_grouped[:training_size]

# validation: the next `validation_size` pairs
# (`offset` rather than `slice`, which would shadow the builtin)
offset = training_size
validation_set = all_grouped[offset:offset + validation_size]

# test: the `test_size` pairs after the validation slice
offset += validation_size
test_set = all_grouped[offset:offset + test_size]

print('Training set size: {}'.format(len(training_set)))
print('Test set size: {}'.format(len(test_set)))
print('Validation set size: {}'.format(len(validation_set)))


def _fill_split(split_data, pairs):
    """Append the annotations of *pairs* to *split_data*, and each image once.

    Several annotations usually share one image; the image entry is added
    only the first time its id is seen so the split's 'images' list does not
    contain duplicates.
    """
    seen_image_ids = {entry['id'] for entry in split_data['images']}
    for image, annotation in pairs:
        if image['id'] not in seen_image_ids:
            seen_image_ids.add(image['id'])
            split_data['images'].append(image)
        split_data['annotations'].append(annotation)


_fill_split(train_data, training_set)
_fill_split(test_data, test_set)
_fill_split(validation_data, validation_set)

Training set size: 30464
Test set size: 3808
Validation set size: 3808


In [6]:
# Write each split next to the source annotation file, one JSON file per split.
for out_name, split_data in (
    ('train.json', train_data),
    ('test.json', test_data),
    ('validation.json', validation_data),
):
    with open(os.path.join(coco_folder, out_name), 'w') as f:
        json.dump(split_data, f)