In [1]:
import json
from pathlib import Path
from pycocotools.coco import COCO

In [5]:
class CocoFilter():
    """ Filters the COCO dataset
    """
    def _process_info(self):
        self.info = self.coco['info']
        
    def _process_licenses(self):
        self.licenses = self.coco['licenses']
        
    def _process_categories(self):
        self.categories = dict()
        self.super_categories = dict()
        self.category_set = set()

        for category in self.coco['categories']:
            cat_id = category['id']
            super_category = category['supercategory']
            
            # Add category to categories dict
            if cat_id not in self.categories:
                self.categories[cat_id] = category
                self.category_set.add(category['name'])
            else:
                print(f'ERROR: Skipping duplicate category id: {category}')
            
            # Add category id to the super_categories dict
            if super_category not in self.super_categories:
                self.super_categories[super_category] = {cat_id}
            else:
                self.super_categories[super_category] |= {cat_id} # e.g. {1, 2, 3} |= {4} => {1, 2, 3, 4}

    def _process_images(self):
        self.images = dict()
        for image in self.coco['images']:
            image_id = image['id']
            if image_id not in self.images:
                self.images[image_id] = image
            else:
                print(f'ERROR: Skipping duplicate image id: {image}')
                
    def _process_segmentations(self):
        self.segmentations = dict()
        for segmentation in self.coco['annotations']:
            image_id = segmentation['image_id']
            if image_id not in self.segmentations:
                self.segmentations[image_id] = []
            self.segmentations[image_id].append(segmentation)

    def _filter_categories(self):
        """ Find category ids matching args
            Create mapping from original category id to new category id
            Create new collection of categories
        """
        missing_categories = set(self.filter_categories) - self.category_set
        if len(missing_categories) > 0:
            print(f'Did not find categories: {missing_categories}')
            should_continue = input('Continue? (y/n) ').lower()
            if should_continue != 'y' and should_continue != 'yes':
                print('Quitting early.')
                quit()

        self.new_category_map = dict()
        new_id = 1
        for key, item in self.categories.items():
            if item['name'] in self.filter_categories:
                self.new_category_map[key] = key
                new_id += 1

        self.new_categories = []
        for original_cat_id, new_id in self.new_category_map.items():
            new_category = dict(self.categories[original_cat_id])
            new_category['id'] = new_id
            self.new_categories.append(new_category)
        print(self.new_category_map)

    def _filter_annotations(self):
        """ Create new collection of annotations matching category ids
            Keep track of image ids matching annotations
        """
        self.new_segmentations = []
        self.new_image_ids = set()
        for image_id, segmentation_list in self.segmentations.items():
            for segmentation in segmentation_list:
                original_seg_cat = segmentation['category_id']
                if original_seg_cat in self.new_category_map.keys():
                    new_segmentation = dict(segmentation)
                    new_segmentation['category_id'] = self.new_category_map[original_seg_cat]
                    self.new_segmentations.append(new_segmentation)
                    self.new_image_ids.add(image_id)

    def _filter_images(self):
        """ Create new collection of images
        """
        self.new_images = []
        for image_id in self.new_image_ids:
            self.new_images.append(self.images[image_id])

    def main(self, input_json,output_json, categories):
        # Open json
        self.input_json_path = Path(input_json)
        self.output_json_path = Path(output_json)
        self.filter_categories = categories

        # Verify input path exists
        if not self.input_json_path.exists():
            print('Input json path not found.')
            print('Quitting early.')
            quit()

        # Verify output path does not already exist
        if self.output_json_path.exists():
            should_continue = input('Output path already exists. Overwrite? (y/n) ').lower()
            if should_continue != 'y' and should_continue != 'yes':
                print('Quitting early.')
                quit()
        
        # Load the json
        print('Loading json file...')
        with open(self.input_json_path) as json_file:
            self.coco = json.load(json_file)
        print(self.coco.keys())
        
        # Process the json
        print('Processing input json...')
        self._process_info()
        #self._process_licenses()
        self._process_categories()
        self._process_images()
        self._process_segmentations()
        
        

        # Filter to specific categories
        print('Filtering...')
        self._filter_categories()
        #print(self.new_categories)
        
        self._filter_annotations()
        self._filter_images()

        print(len(self.new_images))
        
        # Build new JSON
        new_master_json = {
            'info': self.info,
            'images': self.new_images,
            'annotations': self.new_segmentations,
            'categories': self.new_categories
        }

        # Write the JSON to a file
        print('Saving new json file...')
        with open(self.output_json_path, 'w+') as output_file:
            json.dump(new_master_json, output_file)

        print('Filtered json saved.')
        
        

In [8]:
input_json = 'data/val/annotations.json'
output_json = 'data/val/new_top50cat_val_filter.json'
categories = ['water', 'salad-leaf-salad-green', 'bread-white', 'tomato-raw', 'butter', 'carrot-raw', 'bread-wholemeal', 'coffee-with-caffeine', 'rice', 'egg', 'mixed-vegetables', 'apple', 'jam', 'cucumber', 'wine-red', 'banana', 'cheese', 'potatoes-steamed', 'bell-pepper-red-raw', 'hard-cheese', 'espresso-with-caffeine', 'tea', 'bread-whole-wheat', 'mixed-salad-chopped-without-sauce', 'avocado', 'white-coffee-with-caffeine', 'tomato-sauce', 'wine-white', 'broccoli', 'strawberries', 'pasta-spaghetti', 'honey', 'zucchini', 'parmesan', 'chicken', 'chips-french-fries', 'braided-white-loaf', 'dark-chocolate', 'mayonnaise', 'pizza-margherita-baked', 'blueberries', 'onion', 'salami', 'leaf-spinach', 'soft-cheese', 'salmon', 'water-mineral', 'gruyere', 'glucose-drink-50g', 'yaourt-yahourt-yogourt-ou-yoghourt-natural']
cf = CocoFilter()
cf.main(input_json,output_json,categories)

Loading json file...
dict_keys(['categories', 'info', 'images', 'annotations'])
Processing input json...
Filtering...
{1565: 1565, 2099: 2099, 2578: 2578, 1154: 1154, 1352: 1352, 1310: 1310, 2512: 2512, 2498: 2498, 1056: 1056, 2022: 2022, 1013: 1013, 1788: 1788, 1069: 1069, 1085: 1085, 1078: 1078, 2738: 2738, 1311: 1311, 1022: 1022, 1151: 1151, 1169: 1169, 1061: 1061, 2053: 2053, 2750: 2750, 2618: 2618, 2620: 2620, 2939: 2939, 1879: 1879, 1468: 1468, 2521: 2521, 1068: 1068, 1070: 1070, 5641: 5641, 1967: 1967, 1505: 1505, 1323: 1323, 1040: 1040, 1010: 1010, 1566: 1566, 1032: 1032, 2131: 2131, 1554: 1554, 1116: 1116, 3080: 3080, 2504: 2504, 1607: 1607, 2580: 2580, 2103: 2103, 1026: 1026, 1307: 1307, 1163: 1163}
583
Saving new json file...
Filtered json saved.


In [None]:
input_json = 'data/train/annotations.json'
output_json = 'data/train/new_top50cat_val_filter.json'
categories = ['water', 'salad-leaf-salad-green', 'bread-white', 'tomato-raw', 'butter', 'carrot-raw', 'bread-wholemeal', 'coffee-with-caffeine', 'rice', 'egg', 'mixed-vegetables', 'apple', 'jam', 'cucumber', 'wine-red', 'banana', 'cheese', 'potatoes-steamed', 'bell-pepper-red-raw', 'hard-cheese', 'espresso-with-caffeine', 'tea', 'bread-whole-wheat', 'mixed-salad-chopped-without-sauce', 'avocado', 'white-coffee-with-caffeine', 'tomato-sauce', 'wine-white', 'broccoli', 'strawberries', 'pasta-spaghetti', 'honey', 'zucchini', 'parmesan', 'chicken', 'chips-french-fries', 'braided-white-loaf', 'dark-chocolate', 'mayonnaise', 'pizza-margherita-baked', 'blueberries', 'onion', 'salami', 'leaf-spinach', 'soft-cheese', 'salmon', 'water-mineral', 'gruyere', 'glucose-drink-50g', 'yaourt-yahourt-yogourt-ou-yoghourt-natural']
cf = CocoFilter()
cf.main(input_json,output_json,categories)

In [49]:
TRAIN_ANNOTATIONS_PATH = "data/train/new_top50cat_train_filter.json"
TRAIN_IMAGE_DIRECTIORY = "data/train/images/"
VAL_ANNOTATIONS_PATH = "data/val/new_top50cat_val_filter.json"
VAL_IMAGE_DIRECTIORY = "data/val/images/"
train_coco = COCO(TRAIN_ANNOTATIONS_PATH)
with open(TRAIN_ANNOTATIONS_PATH) as f:
    train_annotations_data = json.load(f)

loading annotations into memory...
Done (t=2.10s)
creating index...
index created!


In [50]:
category_ids = train_coco.loadCats(train_coco.getCatIds())
category_names_readable = [_["name_readable"] for _ in category_ids]
category_names = [_["name"] for _ in category_ids]
category_id = [_["id"] for _ in category_ids]

In [51]:
print(category_names)

['bread-wholemeal', 'jam', 'water', 'banana', 'soft-cheese', 'hard-cheese', 'coffee-with-caffeine', 'tea', 'avocado', 'egg', 'chips-french-fries', 'chicken', 'tomato-raw', 'broccoli', 'carrot-raw', 'tomato-sauce', 'cheese', 'mixed-vegetables', 'apple', 'blueberries', 'cucumber', 'butter', 'mayonnaise', 'wine-red', 'wine-white', 'pizza-margherita-baked', 'salami', 'rice', 'white-coffee-with-caffeine', 'bell-pepper-red-raw', 'zucchini', 'yaourt-yahourt-yogourt-ou-yoghourt-natural', 'salmon', 'pasta-spaghetti', 'parmesan', 'salad-leaf-salad-green', 'potatoes-steamed', 'bread-white', 'leaf-spinach', 'dark-chocolate', 'bread-whole-wheat', 'onion', 'glucose-drink-50g', 'espresso-with-caffeine', 'braided-white-loaf', 'water-mineral', 'honey', 'mixed-salad-chopped-without-sauce', 'gruyere', 'strawberries']


In [52]:
print(category_id)

[1565, 2099, 2578, 1154, 1352, 1310, 2512, 2498, 1056, 2022, 1013, 1788, 1069, 1085, 1078, 2738, 1311, 1022, 1151, 1169, 1061, 2053, 2750, 2618, 2620, 2939, 1879, 1468, 2521, 1068, 1070, 5641, 1967, 1505, 1323, 1040, 1010, 1566, 1032, 2131, 1554, 1116, 3080, 2504, 1607, 2580, 2103, 1026, 1307, 1163]


In [53]:
val_coco = COCO(VAL_ANNOTATIONS_PATH)
with open(VAL_ANNOTATIONS_PATH) as f:
    val_annotations_data = json.load(f)


loading annotations into memory...
Done (t=0.02s)
creating index...
index created!


In [54]:
category_ids = train_coco.loadCats(train_coco.getCatIds())
category_names_readable = [_["name_readable"] for _ in category_ids]
category_names = [_["name"] for _ in category_ids]
category_id = [_["id"] for _ in category_ids]

In [55]:
print(category_names)

['bread-wholemeal', 'jam', 'water', 'banana', 'soft-cheese', 'hard-cheese', 'coffee-with-caffeine', 'tea', 'avocado', 'egg', 'chips-french-fries', 'chicken', 'tomato-raw', 'broccoli', 'carrot-raw', 'tomato-sauce', 'cheese', 'mixed-vegetables', 'apple', 'blueberries', 'cucumber', 'butter', 'mayonnaise', 'wine-red', 'wine-white', 'pizza-margherita-baked', 'salami', 'rice', 'white-coffee-with-caffeine', 'bell-pepper-red-raw', 'zucchini', 'yaourt-yahourt-yogourt-ou-yoghourt-natural', 'salmon', 'pasta-spaghetti', 'parmesan', 'salad-leaf-salad-green', 'potatoes-steamed', 'bread-white', 'leaf-spinach', 'dark-chocolate', 'bread-whole-wheat', 'onion', 'glucose-drink-50g', 'espresso-with-caffeine', 'braided-white-loaf', 'water-mineral', 'honey', 'mixed-salad-chopped-without-sauce', 'gruyere', 'strawberries']


In [56]:
print(category_id)

[1565, 2099, 2578, 1154, 1352, 1310, 2512, 2498, 1056, 2022, 1013, 1788, 1069, 1085, 1078, 2738, 1311, 1022, 1151, 1169, 1061, 2053, 2750, 2618, 2620, 2939, 1879, 1468, 2521, 1068, 1070, 5641, 1967, 1505, 1323, 1040, 1010, 1566, 1032, 2131, 1554, 1116, 3080, 2504, 1607, 2580, 2103, 1026, 1307, 1163]


In [2]:
TRAIN_ANNOTATIONS_PATH = "data/train/top50cat_train_filter.json"
TRAIN_IMAGE_DIRECTIORY = "data/train/images/"
train_coco = COCO(TRAIN_ANNOTATIONS_PATH)
with open(TRAIN_ANNOTATIONS_PATH) as f:
    train_annotations_data = json.load(f)

loading annotations into memory...
Done (t=2.60s)
creating index...
index created!


In [3]:
category_ids = train_coco.loadCats(train_coco.getCatIds())
category_names_readable = [_["name_readable"] for _ in category_ids]
category_names = [_["name"] for _ in category_ids]
category_id = [_["id"] for _ in category_ids]

In [4]:
print(category_names)

['bread-wholemeal', 'jam', 'water', 'banana', 'soft-cheese', 'hard-cheese', 'coffee-with-caffeine', 'tea', 'avocado', 'egg', 'chips-french-fries', 'chicken', 'tomato-raw', 'broccoli', 'carrot-raw', 'tomato-sauce', 'cheese', 'mixed-vegetables', 'apple', 'blueberries', 'cucumber', 'butter', 'mayonnaise', 'wine-red', 'wine-white', 'pizza-margherita-baked', 'salami', 'rice', 'white-coffee-with-caffeine', 'bell-pepper-red-raw', 'zucchini', 'yaourt-yahourt-yogourt-ou-yoghourt-natural', 'salmon', 'pasta-spaghetti', 'parmesan', 'salad-leaf-salad-green', 'potatoes-steamed', 'bread-white', 'leaf-spinach', 'dark-chocolate', 'bread-whole-wheat', 'onion', 'glucose-drink-50g', 'espresso-with-caffeine', 'braided-white-loaf', 'water-mineral', 'honey', 'mixed-salad-chopped-without-sauce', 'gruyere', 'strawberries']
