### Install Required Packages

In [1]:
!pip install daam==0.0.11
!pip install accelerate  # this is to reduce CPU model load overhead
# !pip install detectron2 -f https://dl.fbaipublicfiles.com/detectron2/wheels/$CUDA_VERSION/torch$TORCH_VERSION/index.html

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting daam==0.0.11
  Downloading daam-0.0.11.tar.gz (21 kB)
Collecting diffusers==0.9.0
  Downloading diffusers-0.9.0-py3-none-any.whl (453 kB)
[K     |████████████████████████████████| 453 kB 52.8 MB/s 
Collecting gradio
  Downloading gradio-3.15.0-py3-none-any.whl (13.8 MB)
[K     |████████████████████████████████| 13.8 MB 71.5 MB/s 
[?25hCollecting ftfy
  Downloading ftfy-6.1.1-py3-none-any.whl (53 kB)
[K     |████████████████████████████████| 53 kB 2.0 MB/s 
[?25hCollecting transformers==4.24.0
  Downloading transformers-4.24.0-py3-none-any.whl (5.5 MB)
[K     |████████████████████████████████| 5.5 MB 63.2 MB/s 
Collecting huggingface-hub>=0.10.0
  Downloading huggingface_hub-0.11.1-py3-none-any.whl (182 kB)
[K     |████████████████████████████████| 182 kB 73.4 MB/s 
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Downloading tokenizers-0.13.2-cp38-cp38-manylinux_2_17_x86_6

We would be running `Stable Diffusion 2` so enable `GPU` under `View Resources > Change runtime type`

In [2]:
!nvidia-smi -L

GPU 0: Tesla T4 (UUID: GPU-5f351a36-f8cb-b6c6-d6ab-b13cf3bcb61a)


### Load Necessary Libraries

We will load the necessary libraries required for generating DAAM outputs for input prompts.

In [3]:
import os
import json
import datetime
from tqdm import tqdm
import base64
import IPython
import random
import requests
from io import BytesIO
from math import trunc

from matplotlib import pyplot as plt
import numpy as np

from PIL import Image
from PIL import ImageDraw as PILImageDraw
import cv2

from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
from nltk import pos_tag 

from diffusers import StableDiffusionPipeline
import daam
import torch

Download the stopwords for removing stopwords

In [4]:
import nltk
nltk.download('stopwords')
nltk.download('wordnet')
nltk.download('punkt')
nltk.download('omw-1.4')
nltk.download('averaged_perceptron_tagger')

[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Unzipping corpora/stopwords.zip.
[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.
[nltk_data] Downloading package omw-1.4 to /root/nltk_data...
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /root/nltk_data...
[nltk_data]   Unzipping taggers/averaged_perceptron_tagger.zip.


True

### Load Data

The below list is a placeholder for any list of prompts, we will be replacing it with a list of prompts from `MS-COCO` text annotations later.

In [5]:
prompts = [
  "A group of people stand in the back of a truck filled with cotton.",
  "A mother and three children collecting garbage from a blue and white garbage can on the street.",
  "A woman is sitting in a chair reading a book with her head resting on her free hand.",
  "A brown and white dog exiting a yellow and blue ramp in a grassy area.",
  "A boy stands on a rocky mountain."
  ]

# prompts = [
#   "A group of people stand in the back of a truck filled with cotton."
# ]

Cleaning the prompts. I adopt few ways to clean the prompt:
- Tokenization
- Lemmatization
- Remove stop words
- Remove non-alphabets
- Keep only nouns

In [6]:
# Stopwords
stpwords = set(stopwords.words('english'))

# Lemmatizer
lemmatizer = WordNetLemmatizer() 

def clean_prompt(prompt):
  # tokenize each prompt after they have been converted into lower case
  prompt_tokenized = nltk.word_tokenize(prompt.lower())
  # lemmatize above prompt
  lemmatized_prompt = [lemmatizer.lemmatize(word) for word in prompt_tokenized]
  # pos tag lemmatized prompt above
  pos_tagged_prompt = pos_tag(lemmatized_prompt)
  # keep only the noun words
  fin_prompt = [word for word, pos in pos_tagged_prompt if ('N' in pos and word not in stpwords and word.isalpha())]
  return fin_prompt

### Setting up the Pipeline

#### Brief Overview of the Storage Scheme of the Generated Data

I will set up the pipeline for generation of the heatmaps and bounding boxes for each generated image for each prompt in the `prompts` list. We will be generating `NUM_IMAGES_PER_PROMPT` images per prompt.

Every output by our code will be stored in `Data-Generated` folder.

I use `COCO` Dataset format to store the bboxes and segmentation for each image.

#### Set some parameters

In [7]:
NUM_IMAGES_PER_PROMPT = 2 # Number of images to be generated per prompt
NUM_INFERENCE_STEPS = 50 # Number of inference steps to the Diffusion Model
NAME_OF_DATASET = 'COCO Stable Diffusion 2 Dataset' # Name of the generated dataset
SAVE_AFTER_NUM_IMAGES = 5 # Number of images after which the annotation and caption files will be saved

Now, let's load the `stabilityai/stable-diffusion-2-base` diffusion model.

In [8]:
model = StableDiffusionPipeline.from_pretrained('stabilityai/stable-diffusion-2-base')
model = model.to('cuda')

Downloading:   0%|          | 0.00/539 [00:00<?, ?B/s]

Fetching 16 files:   0%|          | 0/16 [00:00<?, ?it/s]

Downloading:   0%|          | 0.00/342 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/308 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/738 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.36G [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.36G [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/525k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/460 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/929 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.06M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.01k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/3.46G [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/3.46G [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/716 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/335M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/335M [00:00<?, ?B/s]

Setting up the folder structure to generate the data and the dataset information for storing annotations and captions in `COCO` format.

In [9]:
# The folder that will contain the generated data
os.mkdir('Data-Generated') # Stores everything that is generated
os.mkdir('Data-Generated/images') # Stores generated images
os.mkdir('Data-Generated/annotations') # Stores Annotations
os.mkdir('Data-Generated/captions') # Stores Captions

info = { # Info about the dataset
    "description": NAME_OF_DATASET,
    "url": "https://github.com/RishiDarkDevil/Text-Based-Object-Discovery",
    "version": "1.0",
    "year": 2022,
    "contributor": "Rishi Dey Chowdhury (RishiDarkDevil)",
    "date_created": "2022"
}

licenses = [{ # Licenses associated with the dataset
    'url': 'https://huggingface.co/stabilityai/stable-diffusion-2/blob/main/LICENSE-MODEL',
    'id': 1,
    'name': 'CreativeML Open RAIL++-M License'
}]

images = list() # Stores the generated image info
annotations = list() # Stores the annotation info
categories = list() # Stores the category info
captions = list() # Stores the captions info
cat2id = dict() # Stores the category to id mapping
cat_id = 1 # Assigns id to categories as we go on adding categories which we discover
image_id = 1 # Assigns generated image ids
annotation_id = 1 # Assigns annotations annotation ids
caption_id = 1 # Assigns captions caption ids
save_idx = 1 # The index which stores how many times we saved the json file before

Let's generate Global Word Attribution HeatMaps.

In [10]:
class NpEncoder(json.JSONEncoder): # To help encode the unsupported datatypes to json serializable format
  def default(self, obj):
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    return json.JSONEncoder.default(self, obj)

def save(): # Saving annotations and captions when called -- mainly to avoid code repetition
  # Serializing json
  json_obj_det = json.dumps({
      'info': info,
      'licenses': licenses,
      'images': images,
      'annotations': annotations,
      'categories': categories
  }, indent=4, cls=NpEncoder)

  # Writing json
  with open(f"Data-Generated/annotations/object_detect-{save_idx}.json", "w") as outfile:
    outfile.write(json_obj_det)

  print('Saved Annotations')

  # Delete json from python env
  del json_obj_det

  # Serializing json
  json_obj_cap = json.dumps({
      'info': info,
      'licenses': licenses,
      'images': images,
      'annotations': captions,
  }, indent=4, cls=NpEncoder)

  # Writing json
  with open(f"Data-Generated/captions/object_caption-{save_idx}.json", "w") as outfile:
    outfile.write(json_obj_cap)

  # Delete json from python env
  del json_obj_cap

  # Clearing out all the lists except cat2id to maintaining the unique category ids assigned to each new object
  images.clear()
  annotations.clear()
  categories.clear()
  captions.clear()

In [11]:
try:
  # Iterating overs the prompts
  for i, prompt in enumerate(prompts):

    # Flashing some details
    print(f'Prompt No. {i+1}/{len(prompts)}')
    print('Prompt:', prompt)
    cleaned_prompt = clean_prompt(prompt)
    print('Cleaned Prompt:', ' '.join(cleaned_prompt))
    print('Generating Image...')

    # Updating Categories using cleaned prompt if required and assigning index
    for word in cleaned_prompt:
      if word not in cat2id:
        cat2id[word] = cat_id
        categories.append({"supercategory": '', "id": cat_id, "name": word}) ### FIX SUPERCATEGORY
        cat_id += 1

    for j in range(NUM_IMAGES_PER_PROMPT):

      # Generating images and storing their trace for daam output
      with daam.trace(model) as trc:
        output_image = model(prompt, num_inference_steps=NUM_INFERENCE_STEPS).images[0]
        global_heat_map = trc.compute_global_heat_map()

      # Saving generated Image
      output_image.save(f'Data-Generated/images/{i}_{j}.png')
      print(f'Saving generated Image as.. {i}_{j}.png')

      width, height = output_image.size
      # Image details
      image_det = {
          'license': 1,
          'file_name': f'{i}_{j}.png',
          'height': height,
          'width': width,
          'date_captured': datetime.datetime.now().strftime("%m/%d/%Y, %H:%M:%S"),
          'id': image_id
      }
      images.append(image_det)

      # Captions details
      cap_det = {
          'id': caption_id,
          'image_id': image_id,
          'caption': prompt
      }
      captions.append(cap_det)

      print(f'Generating Annotations for {i}_{j}.png')
      # Generate Global Word Attribution HeatMap
      for word in tqdm(cleaned_prompt):

        # word category id
        word_cat_id = cat2id[word]
        
        # Compute heatmap for a non-stopword
        word_heatmap = global_heat_map.compute_word_heat_map(word).expand_as(output_image).numpy()

        # Casting heatmap from 0-1 floating range to 0-255 unsigned 8 bit integer
        heatmap = np.array(word_heatmap * 255, dtype = np.uint8)

        # Binary threshold of the above heatmap - serves as sort of semantic segmentation for the word
        thresh = cv2.threshold(heatmap, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1]

        # Find contours from the binary threshold
        cnts = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        cnts = cnts[0] if len(cnts) == 2 else cnts[1]

        # Annotating the segmentation and the bounding boxes
        for idx, cnt in enumerate(cnts):
          x,y,w,h = cv2.boundingRect(cnt)
          area = cv2.contourArea(cnt)

          ann_det = { # Annotation details
              'segmentation': [list(cnt.squeeze().reshape(1, -1).squeeze())],
              'area': area,
              'iscrowd': 0,
              'image_id': image_id,
              'bbox': [x, y, w, h],
              'category_id': word_cat_id,
              'id': annotation_id,
          }
          annotation_id += 1
          annotations.append(ann_det)
      
      print()
      print('Generated Annotations...')
      
      # Saving Annotations and Captions
      if image_id % SAVE_AFTER_NUM_IMAGES == 0:
        save()
        print(f'Annotations and Captions saved... object_detect-{save_idx}.json and object_caption-{save_idx}.json')

        save_idx += 1

      image_id += 1

      caption_id += 1

  if ((image_id-1) % SAVE_AFTER_NUM_IMAGES != 0 and image_id > 1) or image_id == 1:
    save()

except KeyboardInterrupt: # In case of KeyboardInterrupt save the annotations and captions
  save()

Prompt No. 1/5
Prompt: A group of people stand in the back of a truck filled with cotton.
Cleaned Prompt: group people back truck filled cotton
Generating Image...


  0%|          | 0/51 [00:00<?, ?it/s]

Saving generated Image as.. 0_0.png
Generating Annotations for 0_0.png


100%|██████████| 6/6 [00:00<00:00, 52.47it/s]


Generated Annotations...





  0%|          | 0/51 [00:00<?, ?it/s]

Saving generated Image as.. 0_1.png
Generating Annotations for 0_1.png


100%|██████████| 6/6 [00:00<00:00, 260.89it/s]


Generated Annotations...
Prompt No. 2/5
Prompt: A mother and three children collecting garbage from a blue and white garbage can on the street.
Cleaned Prompt: mother child garbage blue garbage street
Generating Image...





  0%|          | 0/51 [00:00<?, ?it/s]

Saving generated Image as.. 1_0.png
Generating Annotations for 1_0.png


100%|██████████| 6/6 [00:00<00:00, 259.10it/s]


Generated Annotations...





  0%|          | 0/51 [00:00<?, ?it/s]

Saving generated Image as.. 1_1.png
Generating Annotations for 1_1.png


100%|██████████| 6/6 [00:00<00:00, 234.42it/s]


Generated Annotations...
Prompt No. 3/5
Prompt: A woman is sitting in a chair reading a book with her head resting on her free hand.
Cleaned Prompt: woman chair book head hand
Generating Image...





  0%|          | 0/51 [00:00<?, ?it/s]

Saving generated Image as.. 2_0.png
Generating Annotations for 2_0.png


100%|██████████| 5/5 [00:00<00:00, 231.95it/s]


Generated Annotations...
Saved Annotations
Annotations and Captions saved... object_detect-1.json and object_caption-1.json





  0%|          | 0/51 [00:00<?, ?it/s]

Saving generated Image as.. 2_1.png
Generating Annotations for 2_1.png


100%|██████████| 5/5 [00:00<00:00, 230.34it/s]


Generated Annotations...
Prompt No. 4/5
Prompt: A brown and white dog exiting a yellow and blue ramp in a grassy area.
Cleaned Prompt: brown dog yellow ramp area
Generating Image...





  0%|          | 0/51 [00:00<?, ?it/s]

Saving generated Image as.. 3_0.png
Generating Annotations for 3_0.png


100%|██████████| 5/5 [00:00<00:00, 233.90it/s]


Generated Annotations...





  0%|          | 0/51 [00:00<?, ?it/s]

Saving generated Image as.. 3_1.png
Generating Annotations for 3_1.png


100%|██████████| 5/5 [00:00<00:00, 115.16it/s]


Generated Annotations...
Prompt No. 5/5
Prompt: A boy stands on a rocky mountain.
Cleaned Prompt: stand mountain
Generating Image...





  0%|          | 0/51 [00:00<?, ?it/s]

Saving generated Image as.. 4_0.png
Generating Annotations for 4_0.png


100%|██████████| 2/2 [00:00<00:00, 263.82it/s]


Generated Annotations...





  0%|          | 0/51 [00:00<?, ?it/s]

Saving generated Image as.. 4_1.png
Generating Annotations for 4_1.png


100%|██████████| 2/2 [00:00<00:00, 237.14it/s]


Generated Annotations...
Saved Annotations
Annotations and Captions saved... object_detect-2.json and object_caption-2.json





In case we generate multiple annotation files we will have to merge them together into a single json file. The following cell does just that and results into a single `annotations.json`.

In [33]:
print('Starting Annotation Files Merge...')
# Annotation File Names present in the annotations directory
ann_file_names = os.listdir('Data-Generated/annotations')
print('Number of Annotation Files found:', len(ann_file_names))
print('Annotation Files found:', ' '.join(ann_file_names))
ann_files = list() # Contains the list of loaded annotation json files
for ann_file_name in tqdm(ann_file_names): # Loads the annotation json files and appens to ann_files
  with open(os.path.join('Data-Generated/annotations', ann_file_name)) as json_file:
    ann_file = json.load(json_file)
    ann_files.append(ann_file)
# Creating the single annotation file
annotation_file = {
    'info': ann_files[0]['info'],
    'licenses': ann_files[0]['licenses'],
    'images': [image for image in ann_file['images'] for ann_file in ann_files],
    'annotations': [ann for ann in ann_file['annotations'] for ann_file in ann_files],
    'categories': [cat for cat in ann_file['categories'] for ann_file in ann_files]
}
# Serializing json
ann_json_file = json.dumps(annotation_file, indent=4)
# Writing json
with open(f"Data-Generated/annotations/annotations.json", "w") as outfile:
  outfile.write(ann_json_file)
print()
print('Saved Annotation file... annotations.json')
print('Removing the annotation files other than annotations.json')
for ann_file_name in ann_file_names:
  os.remove(os.path.join('Data-Generated/annotations', ann_file_name))
print('A successful merge!')

Starting Annotation File Merge...
Number of Annotation Files found: 2
Annotation Files found: object_detect-1.json object_detect-2.json


100%|██████████| 2/2 [00:00<00:00, 78.09it/s]

Saved Annotation file... annotations.json





In case we generate multiple caption files we will have to merge them together into a single json file. The following cell does just that and results into a single `captions.json`.

In [None]:
print('Starting Caption Files Merge...')
# Caption File Names present in the captions directory
cap_file_names = os.listdir('Data-Generated/captions')
print('Number of Caption Files found:', len(cap_file_names))
print('Caption Files found:', ' '.join(cap_file_names))
cap_files = list() # Contains the list of loaded caption json files
for cap_file_name in tqdm(cap_file_names): # Loads the caption json files and appens to cap_files
  with open(os.path.join('Data-Generated/captions', cap_file_name)) as json_file:
    cap_file = json.load(json_file)
    cap_files.append(cap_file)
# Creating the single caption file
caption_file = {
    'info': cap_files[0]['info'],
    'licenses': cap_files[0]['licenses'],
    'images': [image for image in cap_file['images'] for cap_file in cap_files],
    'annotations': [ann for ann in cap_file['annotations'] for cap_file in cap_files],
}
# Serializing json
cap_json_file = json.dumps(caption_file, indent=4)
# Writing json
with open(f"Data-Generated/captions/captions.json", "w") as outfile:
  outfile.write(cap_json_file)
print()
print('Saved Caption file... captions.json')
print('Removing the caption files other than captions.json')
for cap_file_name in cap_file_names:
  os.remove(os.path.join('Data-Generated/captions', cap_file_name))
print('A successful merge!')

To download the folder containing all the generated data.

In [16]:
!zip -r /content/file.zip /content/Data-Generated
from google.colab import files
files.download("/content/file.zip")

  adding: content/Data-Generated/ (stored 0%)
  adding: content/Data-Generated/annotations/ (stored 0%)
  adding: content/Data-Generated/annotations/object_detect-1.json (deflated 93%)
  adding: content/Data-Generated/annotations/object_detect-2.json (deflated 93%)
  adding: content/Data-Generated/captions/ (stored 0%)
  adding: content/Data-Generated/captions/object_caption-1.json (deflated 75%)
  adding: content/Data-Generated/captions/object_caption-2.json (deflated 74%)
  adding: content/Data-Generated/images/ (stored 0%)
  adding: content/Data-Generated/images/1_0.png (deflated 0%)
  adding: content/Data-Generated/images/3_0.png (deflated 0%)
  adding: content/Data-Generated/images/4_0.png (deflated 0%)
  adding: content/Data-Generated/images/3_1.png (deflated 0%)
  adding: content/Data-Generated/images/2_1.png (deflated 0%)
  adding: content/Data-Generated/images/2_0.png (deflated 1%)
  adding: content/Data-Generated/images/4_1.png (deflated 0%)
  adding: content/Data-Generated/i

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

If in case you need to delete the `Data-Generated` folder with all its contents.

In [40]:
# # UNCOMMENT IF NEEDED
# import shutil
# shutil.rmtree('Data-Generated')

### Visualization of DAAM related outputs

Now that I have generated the images for the prompts and annotated each image with the segmentations and bounding boxes. Let's visualize stuffs.

In [46]:
# from directory_tree import display_tree
# display_tree('Data-Generated')


$ Operating System : Linux
$ Path : Data-Generated

*************** Directory Tree ***************

Data-Generated/
├── annotations/
│   └── object_detect-1.json
├── captions/
│   └── object_caption-1.json
└── images/
    ├── 0_0.png
    ├── 0_1.png
    ├── 1_0.png
    ├── 1_1.png
    ├── 2_0.png
    ├── 2_1.png
    ├── 3_0.png
    ├── 3_1.png
    ├── 4_0.png
    └── 4_1.png


In [3]:
# Load the dataset json
class CocoDataset():
    def __init__(self, annotation_path, image_dir):
        self.annotation_path = annotation_path
        self.image_dir = image_dir
        self.colors = ['blue', 'purple', 'red', 'green', 'orange', 'salmon', 'pink', 'gold',
                        'orchid', 'slateblue', 'limegreen', 'seagreen', 'darkgreen', 'olive',
                        'teal', 'aquamarine', 'steelblue', 'powderblue', 'dodgerblue', 'navy',
                        'magenta', 'sienna', 'maroon']
        
        json_file = open(self.annotation_path)
        self.coco = json.load(json_file)
        json_file.close()
        
        #self.process_info()
        #self.process_licenses()
        self.process_categories()
        self.process_images()
        self.process_segmentations()

    def display_info(self):
        print('Dataset Info:')
        print('=============')
        for key, item in self.info.items():
            print('  {}: {}'.format(key, item))
        
        requirements = [['description', str],
                        ['url', str],
                        ['version', str],
                        ['year', int],
                        ['contributor', str],
                        ['date_created', str]]
        for req, req_type in requirements:
            if req not in self.info:
                print('ERROR: {} is missing'.format(req))
            elif type(self.info[req]) != req_type:
                print('ERROR: {} should be type {}'.format(req, str(req_type)))
        print('')
        
    def display_licenses(self):
        print('Licenses:')
        print('=========')
        
        requirements = [['id', int],
                        ['url', str],
                        ['name', str]]
        for license in self.licenses:
            for key, item in license.items():
                print('  {}: {}'.format(key, item))
            for req, req_type in requirements:
                if req not in license:
                    print('ERROR: {} is missing'.format(req))
                elif type(license[req]) != req_type:
                    print('ERROR: {} should be type {}'.format(req, str(req_type)))
            print('')
        print('')
        
    def display_categories(self):
        print('Categories:')
        print('=========')
        for sc_key, sc_val in self.super_categories.items():
            print('  super_category: {}'.format(sc_key))
            for cat_id in sc_val:
                print('    id {}: {}'.format(cat_id, self.categories[cat_id]['name']))
            print('')
    
    def display_image(self, image_id, show_polys=True, show_bbox=True, show_labels=True, show_crowds=True, use_url=False):
        print('Image:')
        print('======')
        if image_id == 'random':
            image_id = random.choice(list(self.images.keys()))
        
        # Print the image info
        image = self.images[image_id]
        for key, val in image.items():
            print('  {}: {}'.format(key, val))
            
        # Open the image
        if use_url:
            image_path = image['coco_url']
            response = requests.get(image_path)
            image = Image.open(BytesIO(response.content))
            
        else:
            image_path = os.path.join(self.image_dir, image['file_name'])
            image = Image.open(image_path)
            
        buffered = BytesIO()
        image.save(buffered, format="PNG")
        img_str = "data:image/png;base64, " + base64.b64encode(buffered.getvalue()).decode()
        
        # Calculate the size and adjusted display size
        max_width = 900
        image_width, image_height = image.size
        adjusted_width = min(image_width, max_width)
        adjusted_ratio = adjusted_width / image_width
        adjusted_height = adjusted_ratio * image_height
        
        # Create list of polygons to be drawn
        polygons = {}
        bbox_polygons = {}
        rle_regions = {}
        poly_colors = {}
        labels = {}
        print('  segmentations ({}):'.format(len(self.segmentations[image_id])))
        for i, segm in enumerate(self.segmentations[image_id]):
            polygons_list = []
            if segm['iscrowd'] != 0:
                # Gotta decode the RLE
                px = 0
                x, y = 0, 0
                rle_list = []
                for j, counts in enumerate(segm['segmentation']['counts']):
                    if j % 2 == 0:
                        # Empty pixels
                        px += counts
                    else:
                        # Need to draw on these pixels, since we are drawing in vector form,
                        # we need to draw horizontal lines on the image
                        x_start = trunc(trunc(px / image_height) * adjusted_ratio)
                        y_start = trunc(px % image_height * adjusted_ratio)
                        px += counts
                        x_end = trunc(trunc(px / image_height) * adjusted_ratio)
                        y_end = trunc(px % image_height * adjusted_ratio)
                        if x_end == x_start:
                            # This is only on one line
                            rle_list.append({'x': x_start, 'y': y_start, 'width': 1 , 'height': (y_end - y_start)})
                        if x_end > x_start:
                            # This spans more than one line
                            # Insert top line first
                            rle_list.append({'x': x_start, 'y': y_start, 'width': 1, 'height': (image_height - y_start)})
                            
                            # Insert middle lines if needed
                            lines_spanned = x_end - x_start + 1 # total number of lines spanned
                            full_lines_to_insert = lines_spanned - 2
                            if full_lines_to_insert > 0:
                                full_lines_to_insert = trunc(full_lines_to_insert * adjusted_ratio)
                                rle_list.append({'x': (x_start + 1), 'y': 0, 'width': full_lines_to_insert, 'height': image_height})
                                
                            # Insert bottom line
                            rle_list.append({'x': x_end, 'y': 0, 'width': 1, 'height': y_end})
                if len(rle_list) > 0:
                    rle_regions[segm['id']] = rle_list  
            else:
                # Add the polygon segmentation
                for segmentation_points in segm['segmentation']:
                    segmentation_points = np.multiply(segmentation_points, adjusted_ratio).astype(int)
                    polygons_list.append(str(segmentation_points).lstrip('[').rstrip(']'))

            polygons[segm['id']] = polygons_list

            if i < len(self.colors):
                poly_colors[segm['id']] = self.colors[i]
            else:
                poly_colors[segm['id']] = 'white'
            
            bbox = segm['bbox']
            bbox_points = [bbox[0], bbox[1], bbox[0] + bbox[2], bbox[1],
                           bbox[0] + bbox[2], bbox[1] + bbox[3], bbox[0], bbox[1] + bbox[3],
                           bbox[0], bbox[1]]
            bbox_points = np.multiply(bbox_points, adjusted_ratio).astype(int)
            bbox_polygons[segm['id']] = str(bbox_points).lstrip('[').rstrip(']')
            
            labels[segm['id']] = (self.categories[segm['category_id']]['name'], (bbox_points[0], bbox_points[1] - 4))
            
            # Print details
            print('    {}:{}:{}'.format(segm['id'], poly_colors[segm['id']], self.categories[segm['category_id']]))

        # Draw segmentation polygons on image
        html = '<div class="container" style="position:relative;">'
        html += '<img src="{}" style="position:relative;top:0px;left:0px;width:{}px;">'.format(img_str, adjusted_width)
        html += '<div class="svgclass"><svg width="{}" height="{}">'.format(adjusted_width, adjusted_height)
        
        if show_polys:
            for seg_id, points_list in polygons.items():
                fill_color = poly_colors[seg_id]
                stroke_color = poly_colors[seg_id]
                for points in points_list:
                    html += '<polygon points="{}" style="fill:{}; stroke:{}; stroke-width:1; fill-opacity:0.5" />'.format(points, fill_color, stroke_color)
        
        if show_crowds:
            for seg_id, rect_list in rle_regions.items():
                fill_color = poly_colors[seg_id]
                stroke_color = poly_colors[seg_id]
                for rect_def in rect_list:
                    x, y = rect_def['x'], rect_def['y']
                    w, h = rect_def['width'], rect_def['height']
                    html += '<rect x="{}" y="{}" width="{}" height="{}" style="fill:{}; stroke:{}; stroke-width:1; fill-opacity:0.5; stroke-opacity:0.5" />'.format(x, y, w, h, fill_color, stroke_color)
            
        if show_bbox:
            for seg_id, points in bbox_polygons.items():
                fill_color = poly_colors[seg_id]
                stroke_color = poly_colors[seg_id]
                html += '<polygon points="{}" style="fill:{}; stroke:{}; stroke-width:1; fill-opacity:0" />'.format(points, fill_color, stroke_color)
                
        if show_labels:
            for seg_id, label in labels.items():
                color = poly_colors[seg_id]
                html += '<text x="{}" y="{}" style="fill:{}; font-size: 12pt;">{}</text>'.format(label[1][0], label[1][1], color, label[0])
                
        html += '</svg></div>'
        html += '</div>'
        html += '<style>'
        html += '.svgclass { position:absolute; top:0px; left:0px;}'
        html += '</style>'
        return html
       
    def process_info(self):
        self.info = self.coco['info']
    
    def process_licenses(self):
        self.licenses = self.coco['licenses']
    
    def process_categories(self):
        self.categories = {}
        self.super_categories = {}
        for category in self.coco['categories']:
            cat_id = category['id']
            super_category = category['supercategory']
            
            # Add category to the categories dict
            if cat_id not in self.categories:
                self.categories[cat_id] = category
            else:
                print("ERROR: Skipping duplicate category id: {}".format(category))

            # Add category to super_categories dict
            if super_category not in self.super_categories:
                self.super_categories[super_category] = {cat_id} # Create a new set with the category id
            else:
                self.super_categories[super_category] |= {cat_id} # Add category id to the set
                
    def process_images(self):
        self.images = {}
        for image in self.coco['images']:
            image_id = image['id']
            if image_id in self.images:
                print("ERROR: Skipping duplicate image id: {}".format(image))
            else:
                self.images[image_id] = image
                
    def process_segmentations(self):
        self.segmentations = {}
        for segmentation in self.coco['annotations']:
            image_id = segmentation['image_id']
            if image_id not in self.segmentations:
                self.segmentations[image_id] = []
            self.segmentations[image_id].append(segmentation)

In [2]:
import os
import json
import datetime
from tqdm import tqdm
import base64
import IPython
import random
import requests
from io import BytesIO
from math import trunc

from matplotlib import pyplot as plt
import numpy as np

from PIL import Image
from PIL import ImageDraw as PILImageDraw
import cv2

In [8]:
annotation_path = r'content/Data-Generated/annotations/object_detect-1.json'
image_dir = r'content/Data-Generated/images'

coco_dataset = CocoDataset(annotation_path, image_dir)

In [9]:
coco_dataset.display_categories()

Categories:
  super_category: 
    id 1: group
    id 2: people
    id 3: back
    id 4: truck
    id 5: filled
    id 6: cotton
    id 7: mother
    id 8: child
    id 9: garbage
    id 10: blue
    id 11: street
    id 12: woman
    id 13: chair
    id 14: book
    id 15: head
    id 16: hand



In [14]:
html = coco_dataset.display_image(5, use_url=False)
IPython.display.HTML(html)

Image:
  license: 1
  file_name: 2_0.png
  height: 512
  width: 512
  date_captured: 12/23/2022, 14:19:58
  id: 5
  segmentations (80):
    361:blue:{'supercategory': '', 'id': 12, 'name': 'woman'}
    362:purple:{'supercategory': '', 'id': 12, 'name': 'woman'}
    363:red:{'supercategory': '', 'id': 12, 'name': 'woman'}
    364:green:{'supercategory': '', 'id': 12, 'name': 'woman'}
    365:orange:{'supercategory': '', 'id': 12, 'name': 'woman'}
    366:salmon:{'supercategory': '', 'id': 12, 'name': 'woman'}
    367:pink:{'supercategory': '', 'id': 12, 'name': 'woman'}
    368:gold:{'supercategory': '', 'id': 12, 'name': 'woman'}
    369:orchid:{'supercategory': '', 'id': 12, 'name': 'woman'}
    370:slateblue:{'supercategory': '', 'id': 12, 'name': 'woman'}
    371:limegreen:{'supercategory': '', 'id': 12, 'name': 'woman'}
    372:seagreen:{'supercategory': '', 'id': 12, 'name': 'woman'}
    373:darkgreen:{'supercategory': '', 'id': 12, 'name': 'woman'}
    374:olive:{'supercategory':

In [None]:
json_file = open(annotation_path)
coco = json.load(json_file)
json_file.close()

In [1]:
!unzip file.zip

Archive:  file.zip
   creating: content/Data-Generated/
   creating: content/Data-Generated/annotations/
  inflating: content/Data-Generated/annotations/object_detect-1.json  
  inflating: content/Data-Generated/annotations/object_detect-2.json  
   creating: content/Data-Generated/captions/
  inflating: content/Data-Generated/captions/object_caption-1.json  
  inflating: content/Data-Generated/captions/object_caption-2.json  
   creating: content/Data-Generated/images/
  inflating: content/Data-Generated/images/1_0.png  
  inflating: content/Data-Generated/images/3_0.png  
  inflating: content/Data-Generated/images/4_0.png  
  inflating: content/Data-Generated/images/3_1.png  
  inflating: content/Data-Generated/images/2_1.png  
  inflating: content/Data-Generated/images/2_0.png  
  inflating: content/Data-Generated/images/4_1.png  
  inflating: content/Data-Generated/images/0_1.png  
  inflating: content/Data-Generated/images/1_1.png  
  inflating: content/Data-Generated/images/0_0.p