In [1]:
import json
import os
import random
import time
from typing import Any, Dict, List, Optional

import numpy as np
from pycocotools.coco import COCO
from tqdm import tqdm

In [2]:
root_path = os.path.join('..', '..')
full_annotation_file = os.path.join(root_path, 'data', 'full_2017_bboxes.json')
novel_classes_file = os.path.join(root_path, 'data', 'novel_class_ids.json')

coco_dset = COCO(full_annotation_file)
with open(novel_classes_file, 'r') as f:
    novel_classes = json.load(f)['novel_cat_ids']

loading annotations into memory...
Done (t=1.91s)
creating index...
index created!


Check that the function "getImgIds" with catIds argument returns all images containing annotation of a class.

In [3]:
for novel_class in tqdm(novel_classes):
    len_1 = len(coco_dset.getImgIds(catIds=[novel_class]))
    len_2 = len(set([x['image_id'] for x in coco_dset.anns.values() if x['category_id'] == novel_class]))
    if len_1 != len_2:
        print(len_1, len_2)

100%|██████████| 346/346 [00:09<00:00, 36.07it/s]


Check if for some of the classes there are images containing boxes of other novel classes.

In [4]:
for novel_class in tqdm(novel_classes):
    # Get all images with at least an annotation for this class
    class_imgs = coco_dset.getImgIds(catIds=[novel_class])
    # For each of the images...
    for img in class_imgs:
        # Load all of its annotation
        anns_for_img = coco_dset.getAnnIds(imgIds=[img])
        # For each annotation found in the image...
        for ann in coco_dset.loadAnns(anns_for_img):
            # Check if the category ID of that annotation is different from that of the considered novel class
            if ann['category_id'] != novel_class:
                print(f"Image {img} does not exclusively contain class {class_imgs}, but also class {ann['category_id']}")

100%|██████████| 346/346 [00:00<00:00, 5218.77it/s]


All novel classes apparently have well separated images. Now we proceed with the algorithm.

In [5]:
def create_annotation_sets_with_K_shots(coco_dset: COCO, novel_classes: List[int], 
                                        K: int, do_not_sample: Optional[Dict[int, List[int]]] = None,
                                        timeout: int = 10):
    '''
    Samples exactly `K` annotations for each of the `novel_classes` of the input `coco_dset`.
    If a sampled annotation is part of an image that also contains annotations of another novel class,
    the other novel class is populated with annotations from the same image, in order to fully capture
    all annotations in an image. Given the random nature of the algorithm, a timeout can be provided.

    Args:
    - `coco_dset`: The COCO object for the dataset managing annotations and image IDs
    - `noval_classes`: A list of novel class IDs to sample for
    - `K`: The number of annotations to sample for each class
    - `do_not_sample` (Optional, Default None): A dict containing previously-sampled annotations 
        for each class, to avoid sampling them a second time if the algorithm is used more than once 
        (e.g. for a validation or a test set)
    - `timeout` (Optional, Default 10): Number of seconds to let the algorithm run
    '''
    assert K > 0, "K must be positive"
    st_t = time.time()

    # Collect images that should be excluded from sampling
    if do_not_sample is not None:
        do_not_sample_images = {}
        for k in do_not_sample:
            do_not_sample_images[k] = [x['image_id'] for x in do_not_sample[k]]
    
    # Consistency loop
    stop_flag = False
    while not stop_flag:
        
        # Step 1): randomize priority of novel classes
        novel_classes = random.sample(novel_classes, len(novel_classes))
        novel_classes_set = set(novel_classes)
        
        # Create new set of annotations
        annots = {class_id: [] for class_id in novel_classes}
        
        # Step 2): Constructively add images to our annotation pool
        for class_id in tqdm(novel_classes):
            
            # Get the images containing annotations of that class and randomize them
            class_image_ids = coco_dset.getImgIds(catIds=[class_id])
            random.shuffle(class_image_ids)
            
            # Remove from the set of image ids those that we should not sample (if any)
            if do_not_sample is not None:
                class_image_ids = list(set(class_image_ids) - set(do_not_sample_images[class_id]))
            
            # Start choosing images to fill the class annotations
            for img_id in class_image_ids:
                
                # Get annotations for that image
                img_annots = coco_dset.getAnnIds(imgIds=[img_id])
                img_annots = coco_dset.loadAnns(img_annots)
                # Separate annotations into "annotations of that class" and 
                # "annotations of another class"
                class_annots = []; non_class_annots = []
                for ann in img_annots:
                    # Note: we only consider novel classes here. A base class in the same image
                    # is fine.
                    if    ann['category_id'] not in novel_classes_set: continue
                    elif  ann['category_id'] == class_id: class_annots.append(ann)
                    else: non_class_annots.append(ann)
                
                # Checks:
                
                # 1) The annots could be added to the class list without overflowing K annotations
                # (otherwise choose another image)
                if len(annots[class_id]) + len(class_annots) > K:
                    continue
                
                # 2) The annotations of other classes in the image don't overflow their respective classes
                # (otherwise choose another image)
                ncann_elems = {}
                for ncann in non_class_annots:
                    if ncann['category_id'] not in ncann_elems:
                        ncann_elems[ncann['category_id']] = 1
                    else:
                        ncann_elems[ncann['category_id']] += 1
                if len(ncann_elems) > 0 and any([len(annots[k]) + ncann_elems[k] > K for k in ncann_elems]):
                    continue
                
                # If all went well, simply add the annotations into their respective classes
                for ann in img_annots:
                    annots[ann['category_id']].append(ann)
                
                # Break the cycle early if we have reached the correct amount of annotations for this class
                if len(annots[class_id]) == K:
                    break
            
            # Check that the class actually has K annotations and no other class has more than K,
            # otherwise restart from scratch
            if len(annots[class_id]) != K or any([len(annots[x]) > K for x in annots]):
                break
            # If all classes have exactly K annotations, stop!
            if all([len(annots[x]) == K for x in annots]):
                stop_flag = True
        
        end_t = time.time()
        if end_t - st_t > timeout:
            raise TimeoutError("Annotation Sampling function was not able to complete sampling in time. Try lowering K.")
        if not stop_flag:
            print("[ANNOTATION SAMPLING] Restarting...")
   
    return annots


In [22]:
train_annots = create_annotation_sets_with_K_shots(coco_dset, novel_classes, K=10)

100%|██████████| 346/346 [00:00<00:00, 9840.51it/s]


In [23]:
val_annots = create_annotation_sets_with_K_shots(coco_dset, novel_classes, K=20, do_not_sample=train_annots)

100%|██████████| 346/346 [00:00<00:00, 8258.15it/s]


In [24]:
test_annots = create_annotation_sets_with_K_shots(coco_dset, novel_classes, K=20, do_not_sample={
    k: train_annots[k] + val_annots[k]
    for k in train_annots
})

100%|██████████| 346/346 [00:00<00:00, 7372.75it/s]


Let's check to make sure that there are no overlaps between the sets

In [25]:
def log_err(t1, t2):
    return f"A {t1} annotation was also found in the {t2} set"

for class_id in train_annots:
    for ann in train_annots[class_id]:
        if ann in val_annots[class_id]:
            print(log_err("train", "val"))
        if ann in test_annots[class_id]:
            print(log_err("train", "test"))

for class_id in val_annots:
    for ann in val_annots[class_id]:
        if ann in train_annots[class_id]:
            print(log_err("val", "train"))
        if ann in test_annots[class_id]:
            print(log_err("val", "test"))

for class_id in test_annots:
    for ann in test_annots[class_id]:
        if ann in val_annots[class_id]:
            print(log_err("test", "val"))
        if ann in train_annots[class_id]:
            print(log_err("test", "train"))