## Description

This notebook uses BLIP-2, a state-of-the-art vision-language model by Salesforce, to generate descriptive text annotations (captions) of real images locally. These annotations can be used as downstream prompts for input into text-to-image diffusion models, to generate new synthetic image datasets. An intermediary large language model (LLM) may be used to format and clean/improve the quality of annotations created before they are used for diffusion input.

HuggingFace docs: https://huggingface.co/docs/transformers/main/en/model_doc/blip_2.

## Set-up environment

First, follow the set up instructions in (make sure you have finished running 'python download_data.py')

Compute advisory: Recommended to run in a GPU environment with high RAM.

In [64]:
%%time
from bitmind.image_dataset import ImageDataset
from bitmind.constants import DATASET_META
import numpy as np
import random
import time
import PIL

import torch
from transformers import AutoProcessor, Blip2ForConditionalGeneration, BlipProcessor, logging as transformers_logging
from typing import List, Dict, Tuple, Any
# from transformers import pipeline, BitsAndBytesConfig
# from transformers import AutoModelForCausalLM, AutoTokenizer

CPU times: user 55 μs, sys: 7 μs, total: 62 μs
Wall time: 66 μs


### Load finetuned BLIP-2 and processor

We can instantiate the model and its corresponding processor from the [hub](https://huggingface.co/models?other=blip-2). Here we load a BLIP-2 checkpoint that leverages the pre-trained OPT model by Meta AI, which as 2.7 billion parameters.

In [2]:
%%time
torch.manual_seed(0)
random.seed(0)
np.random.seed(0)

device = "cuda" if torch.cuda.is_available() else "cpu"
processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b-coco")
# by default `from_pretrained` loads the weights in float32
# we load in float16 instead to save memory
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b-coco", torch_dtype=torch.float16) 
model.to(device)

Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████| 2/2 [00:35<00:00, 17.83s/it]


CPU times: user 1min 14s, sys: 1min 2s, total: 2min 16s
Wall time: 1min 7s


Blip2ForConditionalGeneration(
  (vision_model): Blip2VisionModel(
    (embeddings): Blip2VisionEmbeddings(
      (patch_embedding): Conv2d(3, 1408, kernel_size=(14, 14), stride=(14, 14))
    )
    (encoder): Blip2Encoder(
      (layers): ModuleList(
        (0-38): 39 x Blip2EncoderLayer(
          (self_attn): Blip2Attention(
            (dropout): Dropout(p=0.0, inplace=False)
            (qkv): Linear(in_features=1408, out_features=4224, bias=True)
            (projection): Linear(in_features=1408, out_features=1408, bias=True)
          )
          (layer_norm1): LayerNorm((1408,), eps=1e-06, elementwise_affine=True)
          (mlp): Blip2MLP(
            (activation_fn): GELUActivation()
            (fc1): Linear(in_features=1408, out_features=6144, bias=True)
            (fc2): Linear(in_features=6144, out_features=1408, bias=True)
          )
          (layer_norm2): LayerNorm((1408,), eps=1e-06, elementwise_affine=True)
        )
      )
    )
    (post_layernorm): LayerNorm((

### Load Real Image Datasets

In [3]:
%%time
print("Loading real datasets")
real_image_datasets = [
    ImageDataset(ds['path'], 'test', ds.get('name', None), ds['create_splits'])
    for ds in DATASET_META['real']
]
real_image_datasets

Loading real datasets
CPU times: user 3.75 s, sys: 6.31 s, total: 10.1 s
Wall time: 39.2 s


[<bitmind.image_dataset.ImageDataset at 0x7f6b6e01b100>,
 <bitmind.image_dataset.ImageDataset at 0x7f6c7855cca0>,
 <bitmind.image_dataset.ImageDataset at 0x7f6afe77fc40>]

In [63]:
for dataset in real_image_datasets:
    print(dataset.huggingface_dataset_path, dataset.dataset)

dalle-mini/open-images Dataset({
    features: ['url', 'key', 'shard_id', 'status', 'error_message', 'width', 'height', 'exif', 'original_width', 'original_height'],
    num_rows: 125436
})
merkol/ffhq-256 Dataset({
    features: ['image'],
    num_rows: 7000
})
saitsharipov/CelebA-HQ Dataset({
    features: ['image'],
    num_rows: 20260
})


## Prompt-Based Text Annotation

Pipeline for generating text annotations describing each image in real image datasets. 

In [66]:
prompts = [
    "A picture of",
    "The setting is",
    "The background is",
    "The image type/style is"
    # "the background is",
    # "The color(s) are",
    # "The texture(s) are",
    # "The emotion/mood is",
    # "The image medium is",
    # "The image style is"
]

In [68]:
def generate_description(image: PIL.Image.Image, verbose: bool = False) -> str:
    """
    Generates a description for a given image using a sequence of prompts.

    Parameters:
    image (PIL.Image.Image): The image for which to generate a text description (string).
    verbose (bool): If True, prints the prompts and answers during processing. Defaults to False.

    Returns:
    str: The generated description for the image.
    """
    if not verbose: transformers_logging.set_verbosity_error() # Only display error messages (no warnings)
    description = ""
    for i, prompt in enumerate(prompts):
        # Append prompt to description to build context history
        description += prompt + ' '
        inputs = processor(image, text=description, return_tensors="pt").to(device, torch.float16)
        generated_ids = model.generate(**inputs, max_new_tokens=20)
        answer = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
        if verbose:
            print(f"{i}. Prompt: {prompt}")
            print(f"{i}. Answer: {answer}")
        if answer:
            # Append answer to description to build context history
            description += answer + '.'
        else:
            description = description[:-len(prompt) - 1]  # Remove the last prompt if no answer is generated\
    if not verbose: transformers_logging.set_verbosity_info() # Restore transformer warnings
    return description


#### To-do

-Add index for datasets other than open-images for storage and diff purposes

-Improve latency by changing annotation method, implementing multiprocessing, or ensure efficient gpu usage for prompting BLIP-2 at scale

-Annotation storage to local file

-Diff checking between local file and annotations in notebook instance memory

##### Populate a nested dictionary with image annotations for all datasets

Highest dictionary dataset level: {Keys = HuggingFace dataset path : Values = Sub dictionary}

Sub dictionary image level: {Keys = image id : Values = string text annotations}

In [92]:
def generate_annotation_dataset(real_image_datasets: ImageDataset, verbose: bool = False, image_cap = None) -> Tuple[Dict[str, Dict[str, Any]], float]:
    """
    Generates text annotations for each image in real_image_datasets, and calculate|s the average per image latency.

    Parameters:
    real_image_datasets (List[Any]): A list of datasets containing real images.
    verbose (bool): If True, prints the prompts and answers during processing. Defaults to False.
    Returns:
    Tuple[Dict[str, Dict[str, Any]], float]: A tuple containing the annotations dictionary and the average latency.
    """
    annotations_dict = {}
    total_time = 0
    num_images = 0

    for i, dataset in enumerate(real_image_datasets):
        dataset_path = dataset.huggingface_dataset_path
        annotations_dict[dataset_path] = {}
        if verbose:
            print(f'{dataset_path} Dataset ({i + 1} of {len(real_image_datasets)}):')

        for j, image_dict in enumerate(dataset):
            if image_dict['image'] is not None:
                image = image_dict['image']

                start_time = time.time()  # Start the timer
                annotations = generate_description(image, verbose)
                end_time = time.time()  # End the timer

                # Latency calculation
                time_elapsed = end_time - start_time
                total_time += time_elapsed
                num_images += 1
                print("Image id:", type(image_dict['id']), image_dict['id'])
                annotations_dict[dataset_path][image_dict['id']] = annotations
                if verbose:
                    print(f'Image: {j}')
                    print(annotations)
                    print(f'Time elapsed for this image: {time_elapsed:.4f} seconds')
                if image_cap is not None and image_cap < num_images: break
            if image_cap is not None and image_cap < num_images: break

    average_latency = total_time / num_images if num_images > 0 else 0.0

    print(f'Average annotation latency: {average_latency:.4f} seconds')
    return annotations_dict, average_latency

##### dalle-mini/open-images Dataset is the only real image dataset loaded with image ID's

In [93]:
image_datasets_to_annotate = [real_image_datasets[0]]
image_datasets_to_annotate, image_datasets_to_annotate[0].huggingface_dataset_path

([<bitmind.image_dataset.ImageDataset at 0x7f6b6e01b100>],
 'dalle-mini/open-images')

In [94]:
annotations_dict, average_latency = generate_annotation_dataset(image_datasets_to_annotate, verbose=False, image_cap=5)

Image id: <class 'str'> https://farm8.staticflickr.com/7214/7337479734_53f1048393_o.jpg
Image id: <class 'str'> https://farm2.staticflickr.com/3299/3509371657_096da0a7e6_o.jpg
Image id: <class 'str'> https://c4.staticflickr.com/1/55/155642663_7e28e8b2ff_o.jpg
Image id: <class 'str'> https://c4.staticflickr.com/9/8555/15625756039_a60b0bd0a5_o.jpg
Image id: <class 'str'> https://c4.staticflickr.com/1/52/129983132_b668be4a47_o.jpg
Image id: <class 'str'> https://farm5.staticflickr.com/3710/10921149963_e04ba75721_o.jpg
Average annotation latency: 21.8835 seconds


In [98]:
annotations_dict

{'dalle-mini/open-images': {'https://farm8.staticflickr.com/7214/7337479734_53f1048393_o.jpg': 'A picture of a man dressed as a witch standing in a field of flowers.The setting is a garden.The background is a field of flowers.The flowers are yellow and purple.The flowers are in a garden.The.',
  'https://farm2.staticflickr.com/3299/3509371657_096da0a7e6_o.jpg': 'A picture of a pile of junk outside of a building.The setting is a dilapidated building.The background is a green plant..',
  'https://c4.staticflickr.com/1/55/155642663_7e28e8b2ff_o.jpg': 'A picture of a flower pot with pink flowers on a balcony.The setting is a city apartment building.The background is a white building.The building is in the background..',
  'https://c4.staticflickr.com/9/8555/15625756039_a60b0bd0a5_o.jpg': 'A picture of a group of people standing in a room.The setting is a hospital.The background is a white wall.The people are smiling.The doctor is standing in front of the group.The.',
  'https://c4.staticfl

In [99]:
average_latency

21.883494019508362