In [1]:
import json
import os
import sys
from pycocotools.coco import COCO
from tqdm import tnrange, tqdm_notebook
import copy

In [2]:
class DatasetSplitter():
    """
    Re-label the annotations of a COCO file according to their area.

    Every annotation is assigned to the first category whose upper area
    boundary it does not exceed; the re-labelled dataset is written to
    ``annotations_split.json`` inside the output directory.

        Parameters:
            - annotations_file: path (relative or absolute) to the COCO
              annotations JSON file to read.
            - annotations_output_dir: path (relative or absolute) to an
              existing directory that receives ``annotations_split.json``.
            - categories_list: non-empty list of COCO category dicts (each
              with an "id" key), ordered from the smallest to the largest
              area bucket. Written verbatim as the output "categories".
            - boundaries_list: non-empty list of inclusive upper area bounds,
              one per category. Areas above the largest bound fall into the
              last category.
    """
    def __init__(self, annotations_file, annotations_output_dir, categories_list, boundaries_list):

        # Preliminary error checking.
        # os.path resolves relative paths against the working directory by
        # itself, so absolute paths are accepted as well.
        assert os.path.isfile(annotations_file), "Annotations file at path {} does not exist".format(annotations_file)
        assert os.path.isdir(annotations_output_dir), "Annotations output dir at path {} does not exist".format(annotations_output_dir)
        assert categories_list is not None and len(categories_list) > 0, "Categories list must be populated"
        assert boundaries_list is not None and len(boundaries_list) > 0, "Boundaries list must be populated"
        assert len(categories_list) == len(boundaries_list), "Boundary list must be the same size as the categories list"

        # Assigning variables
        self.coco_json = COCO(annotations_file)
        self.annotations_output_dir = annotations_output_dir
        self.categories_list = categories_list
        self.boundaries_list = boundaries_list

        # Get image IDs for all images in dataset
        self.imgIds = self.coco_json.getImgIds()
        self.images = self.coco_json.loadImgs(ids = self.imgIds)

    def _category_id_for_area(self, area):
        """Return the "id" of the category whose area bucket contains ``area``."""
        # Number of boundaries strictly below the area == position the area
        # would take in the sorted boundary list (boundaries are inclusive
        # upper bounds, and need not be passed in sorted order).
        position = sum(1 for boundary in self.boundaries_list if boundary < area)

        # Areas beyond the largest boundary belong to the last category
        # instead of an id that has no category entry.
        position = min(position, len(self.categories_list) - 1)
        return self.categories_list[position]["id"]

    def splitDataset(self):
        """
        Assign every annotation an area-based category and save the result.

        Writes ``annotations_split.json`` into ``annotations_output_dir`` and
        prints the directory it was saved to. Returns nothing. The loaded
        annotations are left untouched.
        """

        # Make coco json master list
        split_coco_json = {
            "images": list(self.images),
            "annotations": [],
            "categories": self.categories_list
        }

        # Process each annotation in each image
        for x in tnrange(len(self.imgIds), desc = 'Processing image annotations...'):

            # Get annotation IDs pertaining to image
            annIds = self.coco_json.getAnnIds(imgIds = self.imgIds[x])

            # Get all annotations pertaining to image
            for annotation in self.coco_json.loadAnns(ids = annIds):

                # Work on a copy so the annotation dicts held by the loaded
                # COCO object keep their original category_id.
                relabelled = dict(annotation)
                relabelled["category_id"] = self._category_id_for_area(annotation["area"])
                split_coco_json["annotations"].append(relabelled)

        output_path = os.path.join(self.annotations_output_dir, "annotations_split.json")
        with open(output_path, "w") as outfile:
            json.dump(split_coco_json, outfile)

        print("Saved annotations file to {}".format(self.annotations_output_dir))

In [7]:

# Size classes for the re-labelled dataset, smallest area bucket first.
categories_list = [
    dict(id=1, name="Small Structure", supercategory="Structure"),   # area 166 - 50,000
    dict(id=2, name="Medium Structure", supercategory="Structure"),  # area 50,001 - 200,000
    dict(id=3, name="Other", supercategory="Structure"),             # area 200,001 and above
]

# One boundary per category, passed in the same order as the categories.
dataset_splitter = DatasetSplitter(
    annotations_file="datasets/Downtown_Sliced/test/annotations.json",
    annotations_output_dir="datasets/Downtown_Sliced/test",
    categories_list=categories_list,
    boundaries_list=[50000, 200000, 900000],
)

loading annotations into memory...
Done (t=1.28s)
creating index...
index created!


In [8]:
# Re-label every annotation by area and write annotations_split.json
# into the configured output directory.
dataset_splitter.splitDataset()

HBox(children=(IntProgress(value=0, description='Processing image annotations...', max=284, style=ProgressStyl…


Saved annotations file to datasets/Downtown_Sliced/test
