### Install Required Packages

In [1]:
!pip install daam==0.0.11
!pip install accelerate  # this is to reduce CPU model load overhead
# !pip install detectron2 -f https://dl.fbaipublicfiles.com/detectron2/wheels/$CUDA_VERSION/torch$TORCH_VERSION/index.html

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting daam==0.0.11
  Downloading daam-0.0.11.tar.gz (21 kB)
Collecting diffusers==0.9.0
  Downloading diffusers-0.9.0-py3-none-any.whl (453 kB)
[K     |████████████████████████████████| 453 kB 24.4 MB/s 
Collecting gradio
  Downloading gradio-3.15.0-py3-none-any.whl (13.8 MB)
[K     |████████████████████████████████| 13.8 MB 67.9 MB/s 
[?25hCollecting ftfy
  Downloading ftfy-6.1.1-py3-none-any.whl (53 kB)
[K     |████████████████████████████████| 53 kB 2.4 MB/s 
[?25hCollecting transformers==4.24.0
  Downloading transformers-4.24.0-py3-none-any.whl (5.5 MB)
[K     |████████████████████████████████| 5.5 MB 58.2 MB/s 
Collecting huggingface-hub>=0.10.0
  Downloading huggingface_hub-0.11.1-py3-none-any.whl (182 kB)
[K     |████████████████████████████████| 182 kB 79.5 MB/s 
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Downloading tokenizers-0.13.2-cp38-cp38-manylinux_2_17_x86_6

We would be running `Stable Diffusion 2` so enable `GPU` under `View Resources > Change runtime type`

In [2]:
!nvidia-smi -L

GPU 0: Tesla T4 (UUID: GPU-bf149c61-03c1-6abd-21e3-46f6efb68836)


### Load Necessary Libraries

We will load the necessary libraries required for generating DAAM outputs for input prompts.

In [25]:
import os
import json
import datetime
from tqdm import tqdm

from matplotlib import pyplot as plt
import numpy as np

from PIL import Image
import cv2

from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
from nltk import pos_tag 

from diffusers import StableDiffusionPipeline
import daam
import torch

Download the stopwords for removing stopwords

In [19]:
import nltk
nltk.download('stopwords')
nltk.download('wordnet')
nltk.download('punkt')
nltk.download('omw-1.4')
nltk.download('averaged_perceptron_tagger')

[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /root/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /root/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!


True

### Load Data

The below list is a placeholder for any list of prompts, we will be replacing it with a list of prompts from `MS-COCO` text annotations later.

In [20]:
prompts = [
  "A group of people stand in the back of a truck filled with cotton.",
  "A mother and three children collecting garbage from a blue and white garbage can on the street.",
  "A woman is sitting in a chair reading a book with her head resting on her free hand.",
  "A brown and white dog exiting a yellow and blue ramp in a grassy area.",
  "A boy stands on a rocky mountain."
  ]

# prompts = [
#   "A group of people stand in the back of a truck filled with cotton."
# ]

Cleaning the prompts. I adopt few ways to clean the prompt:
- Tokenization
- Lemmatization
- Remove stop words
- Remove non-alphabets
- Keep only nouns

In [17]:
# Stopwords
stpwords = set(stopwords.words('english'))

# Lemmatizer
lemmatizer = WordNetLemmatizer() 

def clean_prompt(prompt):
  # tokenize each prompt after they have been converted into lower case
  prompt_tokenized = nltk.word_tokenize(prompt.lower())
  # lemmatize above prompt
  lemmatized_prompt = [lemmatizer.lemmatize(word) for word in prompt_tokenized]
  # pos tag lemmatized prompt above
  pos_tagged_prompt = pos_tag(lemmatized_prompt)
  # keep only the noun words
  fin_prompt = [word for word, pos in pos_tagged_prompt if ('N' in pos and word not in stpwords and word.isalpha())]
  return fin_prompt

### Setting up the Pipeline

#### Brief Overview of the Storage Scheme of the Generated Data

I will set up the pipeline for generation of the heatmaps and bounding boxes for each generated image for each prompt in the `prompts` list. We will be generating `NUM_IMAGES_PER_PROMPT` images per prompt.

Every output by our code will be stored in `Data-Generated` folder.

I use `COCO` Dataset format to store the bboxes and segmentation for each image.

#### Set some parameters

In [21]:
NUM_IMAGES_PER_PROMPT = 2 # Number of images to be generated per prompt
NUM_INFERENCE_STEPS = 50 # Number of inference steps to the Diffusion Model
NAME_OF_DATASET = 'COCO Stable Diffusion 2 Dataset' # Name of the generated dataset
SAVE_AFTER_NUM_IMAGES = 100 # Number of images after which the annotation and caption files will be saved

Now, let's load the `stabilityai/stable-diffusion-2-base` diffusion model.

In [22]:
model = StableDiffusionPipeline.from_pretrained('stabilityai/stable-diffusion-2-base')
model = model.to('cuda')

Downloading:   0%|          | 0.00/539 [00:00<?, ?B/s]

Fetching 16 files:   0%|          | 0/16 [00:00<?, ?it/s]

Downloading:   0%|          | 0.00/342 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/308 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/738 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.36G [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.36G [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/525k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/460 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/929 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.06M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.01k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/3.46G [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/3.46G [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/716 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/335M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/335M [00:00<?, ?B/s]

Setting up the folder structure to generate the data and the dataset information for storing annotations and captions in `COCO` format.

In [41]:
# The folder that will contain the generated data
os.mkdir('Data-Generated') # Stores everything that is generated
os.mkdir('Data-Generated/images') # Stores generated images
os.mkdir('Data-Generated/annotations') # Stores Annotations
os.mkdir('Data-Generated/captions') # Stores Captions

info = { # Info about the dataset
    "description": NAME_OF_DATASET,
    "url": "https://github.com/RishiDarkDevil/Text-Based-Object-Discovery",
    "version": "1.0",
    "year": 2022,
    "contributor": "Rishi Dey Chowdhury (RishiDarkDevil)",
    "date_created": "2022"
}

licenses = [{ # Licenses associated with the dataset
    'url': 'https://huggingface.co/stabilityai/stable-diffusion-2/blob/main/LICENSE-MODEL',
    'id': 1,
    'name': 'CreativeML Open RAIL++-M License'
}]

images = list() # Stores the generated image info
annotations = list() # Stores the annotation info
categories = list() # Stores the category info
captions = list() # Stores the captions info
cat2id = dict() # Stores the category to id mapping
cat_id = 1 # Assigns id to categories as we go on adding categories which we discover
image_id = 1 # Assigns generated image ids
annotation_id = 1 # Assigns annotations annotation ids
caption_id = 1 # Assigns captions caption ids
save_idx = 1 # The index which stores how many times we saved the json file before

Let's generate Global Word Attribution HeatMaps.

In [44]:
class NpEncoder(json.JSONEncoder): # To help encode the unsupported datatypes to json serializable format
  def default(self, obj):
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    return json.JSONEncoder.default(self, obj)

def save(): # Saving annotations and captions when called -- mainly to avoid code repetition
  # Serializing json
  json_obj_det = json.dumps({
      'info': info,
      'licenses': licenses,
      'images': images,
      'annotations': annotations,
      'categories': categories
  }, indent=4, cls=NpEncoder)

  # Writing json
  with open(f"Data-Generated/annotations/object_detect-{save_idx}.json", "w") as outfile:
    outfile.write(json_obj_det)

  print('Saved Annotations')

  # Delete json from python env
  del json_obj_det

  # Serializing json
  json_obj_cap = json.dumps({
      'info': info,
      'licenses': licenses,
      'images': images,
      'annotations': captions,
  }, indent=4, cls=NpEncoder)

  # Writing json
  with open(f"Data-Generated/captions/object_caption-{save_idx}.json", "w") as outfile:
    outfile.write(json_obj_cap)

  # Delete json from python env
  del json_obj_cap

  # Clearing out all the lists except cat2id to maintaining the unique category ids assigned to each new object
  images.clear()
  annotations.clear()
  categories.clear()
  captions.clear()

In [None]:
try:
  # Iterating overs the prompts
  for i, prompt in enumerate(prompts):

    # Flashing some details
    print(f'Prompt No. {i+1}/{len(prompts)}')
    print('Prompt:', prompt)
    cleaned_prompt = clean_prompt(prompt)
    print('Cleaned Prompt:', ' '.join(cleaned_prompt))
    print('Generating Image...')

    # Updating Categories using cleaned prompt if required and assigning index
    for word in cleaned_prompt:
      if word not in cat2id:
        cat2id[word] = cat_id
        categories.append({"supercategory": word,"id": cat_id,"name": word}) ### FIX SUPERCATEGORY
        cat_id += 1

    for j in range(NUM_IMAGES_PER_PROMPT):

      # Generating images and storing their trace for daam output
      with daam.trace(model) as trc:
        output_image = model(prompt, num_inference_steps=NUM_INFERENCE_STEPS).images[0]
        global_heat_map = trc.compute_global_heat_map()

      # Saving generated Image
      output_image.save(f'Data-Generated/images/{i}_{j}.png')
      print(f'Saving generated Image as.. {i}_{j}.png')

      width, height = output_image.size
      # Image details
      image_det = {
          'license': 1,
          'file_name': f'{i}_{j}.png',
          'height': height,
          'width': width,
          'date_captured': datetime.datetime.now().strftime("%m/%d/%Y, %H:%M:%S"),
          'id': image_id
      }
      images.append(image_det)

      # Captions details
      cap_det = {
          'id': caption_id,
          'image_id': image_id,
          'caption': prompt
      }
      captions.append(cap_det)

      print(f'Generating Annotations for {i}_{j}.png')
      # Generate Global Word Attribution HeatMap
      for word in tqdm(cleaned_prompt):

        # word category id
        word_cat_id = cat2id[word]
        
        # Compute heatmap for a non-stopword
        word_heatmap = global_heat_map.compute_word_heat_map(word).expand_as(output_image).numpy()

        # Casting heatmap from 0-1 floating range to 0-255 unsigned 8 bit integer
        heatmap = np.array(word_heatmap * 255, dtype = np.uint8)

        # Binary threshold of the above heatmap - serves as sort of semantic segmentation for the word
        thresh = cv2.threshold(heatmap, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1]

        # Find contours from the binary threshold
        cnts = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        cnts = cnts[0] if len(cnts) == 2 else cnts[1]

        # Annotating the segmentation and the bounding boxes
        for idx, cnt in enumerate(cnts):
          x,y,w,h = cv2.boundingRect(cnt)
          area = cv2.contourArea(cnt)

          ann_det = { # Annotation details
              'segmentation': [list(cnts[0].squeeze().reshape(1, -1).squeeze())],
              'area': area,
              'iscrowd': 0,
              'image_id': image_id,
              'bbox': [x, y, w, h],
              'category_id': word_cat_id,
              'id': annotation_id,
          }
          annotation_id += 1
          annotations.append(ann_det)
      
      print()
      print('Generated Annotations...')
      
      # Saving Annotations and Captions
      if image_id % SAVE_AFTER_NUM_IMAGES == 0:
        save()
        print(f'Annotations and Captions saved... object_detect-{save_idx}.json and object_caption-{save_idx}.json')

        save_idx += 1

      image_id += 1

      caption_id += 1

  if ((image_id-1) % SAVE_AFTER_NUM_IMAGES != 0 and image_id > 1) or image_id == 1:
    save()

except KeyboardInterrupt: # In case of KeyboardInterrupt save the annotations and captions
  save()

Let's take a look quickly at the images, annotations, categories and captions json files.

In [38]:
images

[{'license': 1,
  'file_name': '0_0.png',
  'height': 512,
  'width': 512,
  'date_captured': datetime.datetime(2022, 12, 23, 12, 13, 23, 605830),
  'id': 1},
 {'license': 1,
  'file_name': '0_1.png',
  'height': 512,
  'width': 512,
  'date_captured': datetime.datetime(2022, 12, 23, 12, 13, 47, 348571),
  'id': 2},
 {'license': 1,
  'file_name': '1_0.png',
  'height': 512,
  'width': 512,
  'date_captured': datetime.datetime(2022, 12, 23, 12, 14, 11, 420311),
  'id': 3},
 {'license': 1,
  'file_name': '1_1.png',
  'height': 512,
  'width': 512,
  'date_captured': datetime.datetime(2022, 12, 23, 12, 14, 35, 787937),
  'id': 4},
 {'license': 1,
  'file_name': '2_0.png',
  'height': 512,
  'width': 512,
  'date_captured': datetime.datetime(2022, 12, 23, 12, 15, 0, 294341),
  'id': 5},
 {'license': 1,
  'file_name': '2_1.png',
  'height': 512,
  'width': 512,
  'date_captured': datetime.datetime(2022, 12, 23, 12, 15, 24, 805035),
  'id': 6},
 {'license': 1,
  'file_name': '3_0.png',
  'he

In [37]:
annotations

[{'segmentation': [[404, 510, 404, 511]],
  'area': 0.0,
  'iscrowd': 0,
  'image_id': 1,
  'bbox': [404, 510, 1, 2],
  'category_id': 1,
  'id': 1},
 {'segmentation': [[404, 510, 404, 511]],
  'area': 5.0,
  'iscrowd': 0,
  'image_id': 1,
  'bbox': [202, 509, 4, 3],
  'category_id': 1,
  'id': 2},
 {'segmentation': [[404, 510, 404, 511]],
  'area': 17.5,
  'iscrowd': 0,
  'image_id': 1,
  'bbox': [144, 507, 6, 5],
  'category_id': 1,
  'id': 3},
 {'segmentation': [[404, 510, 404, 511]],
  'area': 37.5,
  'iscrowd': 0,
  'image_id': 1,
  'bbox': [505, 504, 7, 8],
  'category_id': 1,
  'id': 4},
 {'segmentation': [[404, 510, 404, 511]],
  'area': 117.0,
  'iscrowd': 0,
  'image_id': 1,
  'bbox': [319, 504, 20, 8],
  'category_id': 1,
  'id': 5},
 {'segmentation': [[404, 510, 404, 511]],
  'area': 55.5,
  'iscrowd': 0,
  'image_id': 1,
  'bbox': [164, 503, 10, 9],
  'category_id': 1,
  'id': 6},
 {'segmentation': [[404, 510, 404, 511]],
  'area': 76.0,
  'iscrowd': 0,
  'image_id': 1,
  

In [32]:
categories

[{'supercategory': 'group', 'id': 1, 'name': 'group'},
 {'supercategory': 'people', 'id': 2, 'name': 'people'},
 {'supercategory': 'back', 'id': 3, 'name': 'back'},
 {'supercategory': 'truck', 'id': 4, 'name': 'truck'},
 {'supercategory': 'filled', 'id': 5, 'name': 'filled'},
 {'supercategory': 'cotton', 'id': 6, 'name': 'cotton'},
 {'supercategory': 'mother', 'id': 7, 'name': 'mother'},
 {'supercategory': 'child', 'id': 8, 'name': 'child'},
 {'supercategory': 'garbage', 'id': 9, 'name': 'garbage'},
 {'supercategory': 'blue', 'id': 10, 'name': 'blue'},
 {'supercategory': 'street', 'id': 11, 'name': 'street'},
 {'supercategory': 'woman', 'id': 12, 'name': 'woman'},
 {'supercategory': 'chair', 'id': 13, 'name': 'chair'},
 {'supercategory': 'book', 'id': 14, 'name': 'book'},
 {'supercategory': 'head', 'id': 15, 'name': 'head'},
 {'supercategory': 'hand', 'id': 16, 'name': 'hand'},
 {'supercategory': 'brown', 'id': 17, 'name': 'brown'},
 {'supercategory': 'dog', 'id': 18, 'name': 'dog'},
 

In [39]:
captions

[{'id': 1,
  'image_id': 1,
  'caption': 'A group of people stand in the back of a truck filled with cotton.'},
 {'id': 1,
  'image_id': 2,
  'caption': 'A group of people stand in the back of a truck filled with cotton.'},
 {'id': 2,
  'image_id': 3,
  'caption': 'A mother and three children collecting garbage from a blue and white garbage can on the street.'},
 {'id': 2,
  'image_id': 4,
  'caption': 'A mother and three children collecting garbage from a blue and white garbage can on the street.'},
 {'id': 3,
  'image_id': 5,
  'caption': 'A woman is sitting in a chair reading a book with her head resting on her free hand.'},
 {'id': 3,
  'image_id': 6,
  'caption': 'A woman is sitting in a chair reading a book with her head resting on her free hand.'},
 {'id': 4,
  'image_id': 7,
  'caption': 'A brown and white dog exiting a yellow and blue ramp in a grassy area.'},
 {'id': 4,
  'image_id': 8,
  'caption': 'A brown and white dog exiting a yellow and blue ramp in a grassy area.'},
 {

To download the folder containing all the generated data.

In [47]:
!zip -r /content/file.zip /content/Data-Generated
from google.colab import files
files.download("/content/file.zip")

  adding: content/Data-Generated/ (stored 0%)
  adding: content/Data-Generated/annotations/ (stored 0%)
  adding: content/Data-Generated/annotations/object_detect-1.json (deflated 96%)
  adding: content/Data-Generated/captions/ (stored 0%)
  adding: content/Data-Generated/captions/object_caption-1.json (deflated 82%)
  adding: content/Data-Generated/images/ (stored 0%)
  adding: content/Data-Generated/images/1_0.png (deflated 0%)
  adding: content/Data-Generated/images/3_0.png (deflated 0%)
  adding: content/Data-Generated/images/4_0.png (deflated 0%)
  adding: content/Data-Generated/images/3_1.png (deflated 1%)
  adding: content/Data-Generated/images/2_1.png (deflated 0%)
  adding: content/Data-Generated/images/2_0.png (deflated 0%)
  adding: content/Data-Generated/images/4_1.png (deflated 0%)
  adding: content/Data-Generated/images/0_1.png (deflated 0%)
  adding: content/Data-Generated/images/1_1.png (deflated 0%)
  adding: content/Data-Generated/images/0_0.png (deflated 0%)


### Visualization of DAAM related outputs

Now that I have generated an arsenal of outputs for each generated image (where even the generated images are multiple for each prompt i.e. `NUM_IMAGES_PER_PROMPT`). Let's walk through these folders or files.

In [46]:
# from directory_tree import display_tree
# display_tree('Data-Generated')


$ Operating System : Linux
$ Path : Data-Generated

*************** Directory Tree ***************

Data-Generated/
├── annotations/
│   └── object_detect-1.json
├── captions/
│   └── object_caption-1.json
└── images/
    ├── 0_0.png
    ├── 0_1.png
    ├── 1_0.png
    ├── 1_1.png
    ├── 2_0.png
    ├── 2_1.png
    ├── 3_0.png
    ├── 3_1.png
    ├── 4_0.png
    └── 4_1.png
