## Importing Necessary Libraries

The cell below imports every library used in this notebook.

In [4]:
import os
import math
import time
import copy
import random
import numpy as np
from tqdm import tqdm
from datetime import datetime
# from torchviz import make_dot
import matplotlib.pyplot as plt

import cv2
from torchvision import transforms
from PIL import Image

from pycocotools.coco import COCO
import torch.utils.data as data
from torch.utils.data import Dataset

from huggingface_hub import notebook_login

import torch
from diffusers import StableDiffusionPipeline

from daam import trace, set_seed, plot_overlay_heat_map, expand_image

In [6]:
notebook_login()

Token is valid.
Your token has been saved in your configured git credential helpers (osxkeychain).
Your token has been saved to /Users/rishideychowdhury/.huggingface/token
Login successful


In [7]:
device = torch.device('mps') # Use 'cuda' for an NVIDIA GPU, 'mps' for Apple Silicon (M1), or 'cpu'
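
Rather than hard-coding the device string, the choice can be derived from what the local PyTorch build supports. The sketch below is one possible way to do that. `pick_device` is a hypothetical helper, not part of this notebook, and it takes the availability flags as plain booleans so the logic can be read and run without a GPU.

```python
def pick_device(cuda_available: bool, mps_available: bool) -> str:
    """Return a device string, preferring CUDA, then MPS, then CPU."""
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"


# In the notebook the flags would come from PyTorch itself, e.g.:
#   device = torch.device(pick_device(torch.cuda.is_available(),
#                                     torch.backends.mps.is_available()))
print(pick_device(cuda_available=False, mps_available=True))  # prints "mps"
```

Keeping the preference order in one place means the same notebook can move between machines without editing the device cell.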

## Data Preparation

I am using the **Flickr30k Dataset**.

I define a class that handles the custom dataset. The images and annotation files are organised according to the COCO format.
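
For orientation, a COCO-style caption file is a JSON object with an `images` list and an `annotations` list, where each annotation points back to an image through `image_id`. The sketch below builds a tiny structure of that shape. The IDs, file names, and captions are invented for illustration, and the real Flickr30k annotation files may carry additional fields.

```python
import json

# Hypothetical two-image example; IDs, file names, and captions are invented.
toy_annotations = {
    "images": [
        {"id": 1, "file_name": "a.jpg"},
        {"id": 2, "file_name": "b.jpg"},
    ],
    "annotations": [
        {"id": 10, "image_id": 1, "caption": "A dog runs across the field"},
        {"id": 11, "image_id": 1, "caption": "A brown dog is running on grass"},
        {"id": 12, "image_id": 2, "caption": "Two children play on a beach"},
    ],
}

# Group captions by image, which is roughly the index a COCO reader builds.
captions_by_image = {}
for ann in toy_annotations["annotations"]:
    captions_by_image.setdefault(ann["image_id"], []).append(ann["caption"])

# Round-trip through JSON to show the structure is serializable as-is.
serialized = json.dumps(toy_annotations)
assert json.loads(serialized) == toy_annotations

print(len(captions_by_image[1]))  # prints 2
```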

Make sure the images and annotation files have been extracted so that the following class works properly. The data directory tree is shown below for reference, assuming the folder containing all the data is named `data`.
```
data ───────────────────────────> Contains the Dataset and the relevant pre-computed features
├── flickr30k_annotations ──────> Contains the Annotation Files for the Flickr30k Dataset
│   ├── flickr30k_all.json ─────> Annotation File for the entire Flickr30k Images
│   ├── train.json ─────────────> Annotation File for the train Flickr30k Images 
│   └── val.json ───────────────> Annotation File for the validation Flickr30k Images 
└── flickr30k_images ───────────> Contains the Image Files for the Flickr30k Dataset
    └── flickr30k_images ───────> Ignore this directory
        ├── 1000092795.jpg
        ├── 10002456.jpg ───────> Images from the Flickr30k Dataset
        ├── ...
        └── 998845445.jpg
```
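
Before constructing the dataset, it can help to confirm that the layout above is actually in place. The following is a minimal sketch under the assumption that the data lives in a folder called `data` relative to the working directory. `missing_paths` and `EXPECTED_PATHS` are hypothetical names introduced here.

```python
from pathlib import Path

# Relative paths taken from the directory tree above.
EXPECTED_PATHS = [
    "flickr30k_annotations/flickr30k_all.json",
    "flickr30k_annotations/train.json",
    "flickr30k_annotations/val.json",
    "flickr30k_images/flickr30k_images",
]


def missing_paths(data_root):
    """Return the expected entries that do not exist under data_root."""
    root = Path(data_root)
    return [p for p in EXPECTED_PATHS if not (root / p).exists()]


# Example: report what is missing under a local "data" folder.
for path in missing_paths("data"):
    print("missing:", path)
```

An empty result means every annotation file and the image folder were found.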

### Custom Dataset Classes

The following custom Dataset classes handle our data and serve as input to the DataLoader for training the model.

#### Annotation COCO Dataset

The following class handles the dataset in COCO format and deals with the raw captions only. Batches built with the `collate_fn` of `AnnCocoDataset` contain:
- `ann_ids`: IDs of the annotations in the .json annotation files
- `tgt_texts_raw`: Raw caption text for each image (since each image has multiple captions, one caption is selected at random per image)
- `texts`: Intended to hold the preprocessed caption text; for now it returns the same values as `tgt_texts_raw`. **TODO**
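
To make the batch layout concrete, the sketch below mimics the collation step on hand-written samples, without loading any annotation file. `collate` is a standalone stand-in written for this illustration, and the sample IDs and captions are invented.

```python
def collate(batch):
    """Turn a list of per-sample dicts into one dict of lists."""
    ann_ids = [s["ann_id"] for s in batch]
    texts = [s["text"] for s in batch]
    return {
        "ann_ids": ann_ids,
        "tgt_texts_raw": texts,
        "texts": texts,  # same as the raw captions until preprocessing exists
    }


# Invented samples in the shape one dataset item would have.
samples = [
    {"ann_id": 10, "text": "A dog runs across the field"},
    {"ann_id": 12, "text": "Two children play on a beach"},
]

batch = collate(samples)
print(batch["ann_ids"])  # prints [10, 12]
```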

In [116]:
class AnnCocoDataset(Dataset):
  
  def __init__(self, ann_path):

    self.ann_path = ann_path
    
    self.coco = COCO(ann_path)
    self.image_ids = self.coco.getImgIds()

  def __len__(self):
    return len(self.image_ids)

  def load_annotations(self, image_index, return_all=False):
    ann_id = self.coco.getAnnIds(imgIds=self.image_ids[image_index])

    if not return_all:
      ann_id = random.choice(ann_id) # A random annotation out of the many annotations is returned
      anns = self.coco.loadAnns(ann_id)[0]['caption']
    else:
      anns = self.coco.loadAnns(ann_id)
      anns = [i['caption'] for i in anns]
    return anns, ann_id

  def __getitem__(self, index):
    text, ann_id = self.load_annotations(index)
    return {
      'ann_id': ann_id,
      'text': text,
    }

  def collate_fn(self, batch):
    ann_ids = [s['ann_id'] for s in batch]
    texts = [s['text'] for s in batch]

    return {
      'ann_ids': ann_ids,
      'tgt_texts_raw': texts,
      'texts': texts,
    }

  def __str__(self): 
    s1 = "Number of images: " + str(len(self.image_ids)) + '\n'
    s2 = "Number of texts: " + str(len(self.coco.getAnnIds())) + '\n'
    return s1 + s2

In [117]:
dataset = AnnCocoDataset(
    ann_path="/Users/rishideychowdhury/Desktop/Joint-Embedding/Data/flickr30k_annotations/val.json"
)

loading annotations into memory...
Done (t=0.01s)
creating index...
index created!


## Setting the Model Pipeline

I use the `CompVis/stable-diffusion-v1-4` model (stored in `model_id`) to generate images for the captions in the `dataset` above.

In [118]:
model_id = "CompVis/stable-diffusion-v1-4" # model id
dev = 'mps' # set to 'cpu', 'cuda' or 'mps'

In [119]:
pipe = StableDiffusionPipeline.from_pretrained(
  model_id, 
  use_auth_token=True
)
pipe = pipe.to(dev)

In [123]:
prompt = 'A dog runs across the field'
# gen = set_seed(0)

with torch.no_grad():
    with trace(pipe) as tc:
        out = pipe(prompt, num_inference_steps=30)
        heat_map = tc.compute_global_heat_map(prompt)
        heat_map = expand_image(heat_map.compute_word_heat_map('dog'))
        plot_overlay_heat_map(out.images[0], heat_map)
        plt.show()

  return torch._C._nn.upsample_bicubic2d(input, output_size, align_corners, scale_factors)


AssertionError: Torch not compiled with CUDA enabled