In [None]:
## import necessary modules
import os

import matplotlib.pyplot as plt
import numpy as np
import skimage
import skimage.io  # skimage.io.imread is used below; a bare `import skimage` does not load `io` on older scikit-image releases
from pycocotools.coco import COCO

In [None]:
## source file configuration
# Paths use forward slashes, which Python accepts on Windows as well as on POSIX
# systems (backslashes would be treated as part of a single file name on POSIX).
src_annotation_file_path = r"./data/annotations_trainval2014/annotations/instances_train2014.json"
# directory the selected COCO image(s) are downloaded into
local_image_dir = r"./data/annotations_trainval2014/annotations/images"

# COCO category names used to filter images and annotations
category_names = [r"cat"]

In [None]:
## create COCO object
# parse the annotation JSON and build the COCO lookup indices
coco_obj = COCO(annotation_file = src_annotation_file_path)

In [None]:
## check coco info
# COCO.info() prints the top-level "info" section of the annotation file.
# That section is optional in custom COCO-format files and indexing it would
# raise KeyError, so report its absence instead of failing the cell.
if "info" in coco_obj.dataset:
    coco_obj.info()
else:
    print("annotation file has no 'info' section")

In [None]:
## check coco dataset categories
# the raw "categories" list exactly as stored in the annotation file
dataset_categories = coco_obj.dataset["categories"]
print(dataset_categories)

In [None]:
## get category IDs
cat_ids = coco_obj.getCatIds(catNms = category_names)

# getCatIds silently skips names that are not in the dataset. In pycocotools an
# empty catIds list means "no category filter", so a typo in category_names
# would otherwise make the following cells select every image and annotation.
missing_cat_names = sorted(set(category_names) - {cat["name"] for cat in coco_obj.loadCats(cat_ids)})
if missing_cat_names:
    raise ValueError("Unknown COCO category name(s): {}".format(missing_cat_names))

print(cat_ids)

In [None]:
## get all the image IDs of corresponding categories
# ids of the images annotated with the selected categories
img_ids = coco_obj.getImgIds(catIds=cat_ids)
num_selected_imgs = len(img_ids)
print(num_selected_imgs)

In [None]:
## select one of the images and download it to local storage
# position within img_ids of the image to inspect
selected_img_idx = 0

selected_img_id = img_ids[selected_img_idx]
print(selected_img_id)

# create the local directory if it does not exist yet
# (exist_ok avoids a separate check-then-create step)
os.makedirs(local_image_dir, exist_ok = True)

# download only the selected image; pycocotools downloads the whole dataset
# when imgIds is empty, so always pass an explicit, non-empty list
coco_obj.download(tarDir = local_image_dir, imgIds = [selected_img_id])

In [None]:
## load the downloaded coco image
# loadImgs returns a list of image-info dicts, one per requested id
selected_img = coco_obj.loadImgs(ids = [selected_img_id])
print(selected_img)

# the download cell stores each image under local_image_dir by its COCO file name
selected_file_name = selected_img[0]["file_name"]
local_image_path = os.path.join(local_image_dir, selected_file_name)
print(local_image_path)

# read the image pixels from disk
local_image = skimage.io.imread(local_image_path)

# show the image without tick marks
fig, ax = plt.subplots()
ax.imshow(local_image)
ax.set_xticks([])
ax.set_yticks([])
ax.set_title("Downloaded image")
plt.show()

In [None]:
## load the annotations corresponding to the image and category

# load the annotations of the selected image, restricted to the selected categories
selected_ann_ids = coco_obj.getAnnIds(imgIds = [selected_img_id], catIds = cat_ids)
selected_anns = coco_obj.loadAnns(selected_ann_ids)
print(selected_anns)

# plot the image next to the same image overlaid with its annotation masks
plot_image = local_image

# imshow autoscales every call to its own data range. Each mask map holds a
# single value, so without a fixed range every category collapses onto the
# same colour. Normalise against the full span of category ids instead, so
# the colour actually reflects the category.
all_cat_ids = coco_obj.getCatIds()
cat_id_min, cat_id_max = min(all_cat_ids), max(all_cat_ids)

fig, (image_ax, seg_ax) = plt.subplots(1, 2)

image_ax.imshow(plot_image)
image_ax.set_xticks([])
image_ax.set_yticks([])
image_ax.set_title("Image")

seg_ax.imshow(plot_image)

for cur_ann in selected_anns:
    # convert annotation dictionary to a binary mask
    cur_mask = coco_obj.annToMask(cur_ann)
    # the mask is expected to cover the same height/width as the downloaded image
    assert cur_mask.shape == plot_image.shape[:2], (cur_mask.shape, plot_image.shape)

    # pixels outside the mask stay NaN and are drawn with the colormap's "bad"
    # colour (transparent by default), leaving the image visible underneath
    cur_map = np.full(cur_mask.shape, np.nan)
    cur_map[cur_mask > 0] = cur_ann["category_id"]

    seg_ax.imshow(cur_map, cmap = "tab20c", vmin = cat_id_min, vmax = cat_id_max)

seg_ax.set_xticks([])
seg_ax.set_yticks([])
seg_ax.set_title("Segmentation")

plt.tight_layout()
plt.show()