In [76]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image
from skimage import measure
from skimage.color import label2rgb
from pycocotools.coco import COCO

import json
import os

In [77]:
import random
from collections import defaultdict

In [78]:
def get_anns_by_image_id(image_id, original_anns):
    """Collect the annotations attached to a single image.

    Args:
        image_id: Value compared against each annotation's ``'image_id'`` field.
        original_anns: Iterable of annotation dicts, each carrying an ``'image_id'`` key.

    Returns:
        A new list holding the matching annotations in their original order.
        Each call walks the whole ``original_anns`` sequence once.
    """
    return [candidate for candidate in original_anns if candidate['image_id'] == image_id]
    

In [79]:
def copy_dict_to_set(dict_set, key, original_data):
    """Copy the entry ``key`` from ``original_data`` into ``dict_set``.

    The value is stored by reference (no copy is made). A missing ``key`` in
    ``original_data`` raises ``KeyError``, and ``dict_set`` is then left untouched.
    """
    shared_value = original_data[key]
    dict_set[key] = shared_value

In [80]:
def save_dict_to_json(name, dict_set):
    """Serialize ``dict_set`` as 4-space-indented JSON to ``../<name>.json``.

    The path is resolved relative to the current working directory, and an
    existing file at that path is overwritten.
    """
    target_path = f"../{name}.json"
    with open(target_path, "w") as out_file:
        json.dump(dict_set, out_file, indent=4)

In [81]:
# One independent defaultdict(list) per split, intended to hold COCO-style keys
# such as 'images' and 'annotations'. Target proportions: train 60%, val 20%, test 20%.
train_set, val_set, test_set = (defaultdict(list) for _ in range(3))

In [82]:
# Split the combined COCO-style dataset by image into train / val / test (~60 / 20 / 20).
# Each split receives the annotations of its own images. Every other top-level key
# is copied into all three splits. Results go to ../train_set.json, ../val_set.json
# and ../test_set.json.
SPLIT_SEED = 42   # fixed seed -> identical split after every kernel restart (differs from any earlier unseeded run)
TRAIN_FRAC = 0.6
VAL_FRAC = 0.2    # the test split takes whatever remains (~0.2)

# Load once and close the file right away; nothing below needs the handle.
with open('../combined_dict_unique.json', encoding='utf-8') as file_object:
    data = json.load(file_object)

n_images = len(data['images'])
print(f"total images {n_images}")
train_size = int(n_images * TRAIN_FRAC)
val_size = int(n_images * VAL_FRAC)
# Use the remainder rather than int(n * 0.2). Truncating all three fractions could
# leave up to two images unassigned to any split.
test_size = n_images - train_size - val_size

print(f"training set size = {train_size}")
print(f"val set size = {val_size}")
print(f"test set size = {test_size}")

# A private seeded RNG makes the shuffle reproducible without reseeding the global
# `random` state that other cells might rely on.
random.Random(SPLIT_SEED).shuffle(data['images'])

# Create fresh containers on every run. The previous version extended sets built in
# another cell, so re-running this cell duplicated every annotation.
train_set = defaultdict(list)
val_set = defaultdict(list)
test_set = defaultdict(list)
splits = (train_set, val_set, test_set)

train_set['images'] = data['images'][:train_size]
val_set['images'] = data['images'][train_size:train_size + val_size]
test_set['images'] = data['images'][train_size + val_size:]

# Index annotations by image id in one pass (O(A)). Rescanning the full annotation
# list for every image is O(I * A). Per-image annotation order is unchanged.
anns_by_image = defaultdict(list)
for ann in data['annotations']:
    anns_by_image[ann['image_id']].append(ann)

for split in splits:
    # Assign the key explicitly so even an empty split gets an 'annotations' list.
    split['annotations'] = [
        ann
        for img in split['images']
        for ann in anns_by_image.get(img['id'], ())
    ]

# Every remaining top-level key is shared verbatim by all splits.
for key in data.keys():
    if key not in ('images', 'annotations'):
        for split in splits:
            copy_dict_to_set(split, key, data)

# Invariant: the three contiguous slices together cover every image exactly once.
assert sum(len(split['images']) for split in splits) == n_images

save_dict_to_json('train_set', train_set)
save_dict_to_json('val_set', val_set)
save_dict_to_json('test_set', test_set)

total images 17980
training set size = 10788
val set size = 3596
test set size = 3596
