## Description

This notebook uses BLIP-2, a vision-language model from Salesforce, to generate descriptive text annotations (captions) for real images locally. The annotations can then serve as prompts for text-to-image diffusion models to generate new synthetic image datasets. An intermediary large language model (LLM) may be used to format and clean up the annotations before they are passed to the diffusion model.

HuggingFace docs: https://huggingface.co/docs/transformers/main/en/model_doc/blip_2.

## Set-up environment

First, follow the setup instructions (make sure you have finished running `python download_data.py`).

Compute advisory: Recommended to run in a GPU environment with high RAM.

In [1]:
%%time

# Standard libraries
import random
import time
import logging
from typing import List, Dict, Tuple, Any

# Third-party libraries
import numpy as np
import PIL
import torch
from transformers import AutoProcessor, Blip2ForConditionalGeneration, BlipProcessor, logging as transformers_logging

# Bitmind-specific libraries
from bitmind.image_dataset import ImageDataset
from bitmind.constants import DATASET_META



CPU times: user 6.75 s, sys: 3.4 s, total: 10.2 s
Wall time: 12.3 s


### Load fine-tuned BLIP-2 and processor

We can instantiate the model and its corresponding processor from the [hub](https://huggingface.co/models?other=blip-2). Here we load a COCO fine-tuned BLIP-2 checkpoint that builds on Meta AI's pre-trained OPT language model, which has 2.7 billion parameters.
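A back-of-the-envelope estimate shows why the next cell loads the weights in half precision. The sketch below counts only the 2.7 billion language-model parameters mentioned above; the vision encoder and other components add more, which is ignored here as a simplifying assumption. `weight_memory_gb` is a hypothetical helper name, not part of any library.

```python
def weight_memory_gb(num_params: float, bytes_per_param: int) -> float:
    """Approximate memory needed to hold the weights alone, in gigabytes (1 GB = 1e9 bytes)."""
    return num_params * bytes_per_param / 1e9


OPT_PARAMS = 2.7e9  # language-model parameters only (assumption for illustration)

fp32_gb = weight_memory_gb(OPT_PARAMS, 4)  # float32 stores 4 bytes per parameter
fp16_gb = weight_memory_gb(OPT_PARAMS, 2)  # float16 stores 2 bytes per parameter

print(f"float32: {fp32_gb:.1f} GB, float16: {fp16_gb:.1f} GB")
```

Activations and generation caches need additional memory on top of the weights, so this is a lower bound rather than a sizing guide.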

In [2]:
%%time
torch.manual_seed(0)
random.seed(0)
np.random.seed(0)

device = "cuda" if torch.cuda.is_available() else "cpu"

#Default log settings
transformers_level = logging.getLogger("transformers").getEffectiveLevel()
huggingface_hub_level = logging.getLogger("huggingface_hub").getEffectiveLevel()

#Suppress logs
logging.getLogger("transformers").setLevel(logging.ERROR)
logging.getLogger("huggingface_hub").setLevel(logging.ERROR)

processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b-coco")
# by default `from_pretrained` loads the weights in float32
# we load in float16 instead to save memory
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b-coco", torch_dtype=torch.float16) 
model.to(device)

#Restore log settings
logging.getLogger("transformers").setLevel(transformers_level)
logging.getLogger("huggingface_hub").setLevel(huggingface_hub_level)

Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████| 2/2 [00:23<00:00, 11.84s/it]


CPU times: user 24.8 s, sys: 21.2 s, total: 46 s
Wall time: 50.6 s
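
The save, suppress, and restore steps for log levels in the cell above could also be wrapped in a context manager, so the levels are restored even if model loading raises. This is a minimal sketch using only the standard `logging` module; `quiet_loggers` is a hypothetical helper name. Unlike the cell above, it saves each logger's own `level` attribute rather than its effective level, so an inherited level stays inherited after restoring.

```python
import logging
from contextlib import contextmanager
from typing import Iterator, Sequence


@contextmanager
def quiet_loggers(names: Sequence[str], level: int = logging.ERROR) -> Iterator[None]:
    """Temporarily set the named loggers to `level`, restoring their previous levels on exit."""
    loggers = [logging.getLogger(name) for name in names]
    # Save each logger's own level (NOTSET if it inherits from a parent)
    previous = [logger.level for logger in loggers]
    try:
        for logger in loggers:
            logger.setLevel(level)
        yield
    finally:
        # Runs even if the body raises, so loggers are never left suppressed
        for logger, old_level in zip(loggers, previous):
            logger.setLevel(old_level)


before = logging.getLogger("transformers").level

# The model-loading calls would go inside the `with` block
with quiet_loggers(["transformers", "huggingface_hub"]):
    inside = logging.getLogger("transformers").getEffectiveLevel()

after = logging.getLogger("transformers").level
```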


### Load Real Image Datasets

In [3]:
%%time
print("Loading real datasets")
real_image_datasets = [
    ImageDataset(ds['path'], 'test', ds.get('name', None), ds['create_splits'])
    for ds in DATASET_META['real']
]
real_image_datasets

Loading real datasets
CPU times: user 6.32 s, sys: 7.5 s, total: 13.8 s
Wall time: 44.4 s


[<bitmind.image_dataset.ImageDataset at 0x7f7867ad4a30>,
 <bitmind.image_dataset.ImageDataset at 0x7f78647de770>,
 <bitmind.image_dataset.ImageDataset at 0x7f7867ad6260>]

In [4]:
for dataset in real_image_datasets:
    print(dataset.huggingface_dataset_path, dataset.dataset)

dalle-mini/open-images Dataset({
    features: ['url', 'key', 'shard_id', 'status', 'error_message', 'width', 'height', 'exif', 'original_width', 'original_height'],
    num_rows: 125436
})
merkol/ffhq-256 Dataset({
    features: ['image'],
    num_rows: 7000
})
saitsharipov/CelebA-HQ Dataset({
    features: ['image'],
    num_rows: 20260
})


## Prompt-Based Text Annotation

Pipeline for generating text annotations describing each image in real image datasets. 

In [5]:
prompts = [
    "A picture of",
    "The setting is",
    "The background is",
    "The image type/style is"
    # "the background is",
    # "The color(s) are",
    # "The texture(s) are",
    # "The emotion/mood is",
    # "The image medium is",
    # "The image style is"
]

In [6]:
def generate_description(image: PIL.Image.Image, verbose: bool = False) -> str:
    """
    Generates a description for a given image using a sequence of prompts.

    Parameters:
    image (PIL.Image.Image): The image for which to generate a text description (string).
    verbose (bool): If True, prints the prompts and answers during processing. Defaults to False.

    Returns:
    str: The generated description for the image.
    """
    # Remember the current transformers verbosity so it can be restored afterwards
    previous_verbosity = transformers_logging.get_verbosity()
    if not verbose:
        transformers_logging.set_verbosity_error()  # Only display error messages (no warnings)
    description = ""
    for i, prompt in enumerate(prompts):
        # Append prompt to description to build context history
        description += prompt + ' '
        inputs = processor(image, text=description, return_tensors="pt").to(device, torch.float16)
        generated_ids = model.generate(**inputs, max_new_tokens=20)
        answer = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
        if verbose:
            print(f"{i}. Prompt: {prompt}")
            print(f"{i}. Answer: {answer}")
        if answer:
            # Append answer to description to build context history
            description += answer + '.'
        else:
            description = description[:-len(prompt) - 1]  # Remove the last prompt if no answer is generated
    if not verbose:
        transformers_logging.set_verbosity(previous_verbosity)  # Restore the previous verbosity
    return description
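
The chaining logic in `generate_description` can be illustrated without loading BLIP-2. In the sketch below, `chain_prompts` and `fake_model` are hypothetical names: `fake_model` is a stand-in that returns canned answers, including one empty answer to show how a prompt is dropped. The concatenation mirrors the function above, including the absence of a space after each period.

```python
from typing import Callable, List


def chain_prompts(prompts: List[str], answer_fn: Callable[[str], str]) -> str:
    """Build a description by feeding the growing text back in as context for each prompt."""
    description = ""
    for prompt in prompts:
        context = description + prompt + " "
        answer = answer_fn(context).strip()
        if answer:
            # Keep both the prompt and its answer as context for the next prompt
            description = context + answer + "."
        # An empty answer leaves `description` unchanged, which drops the prompt
    return description


def fake_model(context: str) -> str:
    """Stand-in for the captioning model: canned answers keyed on the latest prompt."""
    if context.endswith("A picture of "):
        return "a dog on a beach"
    if context.endswith("The setting is "):
        return ""  # simulate the model producing no text
    return "the ocean"


result = chain_prompts(["A picture of", "The setting is", "The background is"], fake_model)
print(result)
```

Because the second prompt yields no answer, it does not appear in the final string, and the third prompt is conditioned only on the first prompt and its answer.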


#### To-do

- Add an index for datasets other than open-images, for storage and diff purposes

- Improve latency by changing the annotation method, implementing multiprocessing, or ensuring efficient GPU usage when prompting BLIP-2 at scale

- Store annotations in a local file

- Check for differences between the local file and the annotations held in notebook memory
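
The storage and diff items could be approached along these lines. This is a minimal sketch that assumes annotations are kept as JSON with string image IDs; the function names, file name, and sample data are all hypothetical.

```python
import json
import os
import tempfile
from typing import Dict, Set, Tuple

Annotations = Dict[str, Dict[str, str]]  # dataset path -> image id -> annotation


def save_annotations(annotations: Annotations, path: str) -> None:
    """Write the nested annotations dictionary to a JSON file."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(annotations, f, indent=2, ensure_ascii=False)


def load_annotations(path: str) -> Annotations:
    """Read a nested annotations dictionary back from a JSON file."""
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def diff_annotations(stored: Annotations, current: Annotations) -> Set[Tuple[str, str]]:
    """Return (dataset, image id) pairs in `current` that are missing from `stored` or differ."""
    changed = set()
    for dataset, images in current.items():
        for image_id, text in images.items():
            if stored.get(dataset, {}).get(image_id) != text:
                changed.add((dataset, image_id))
    return changed


# Made-up data for illustration
stored = {"demo/dataset": {"img-1": "A picture of a cat."}}
current = {"demo/dataset": {"img-1": "A picture of a cat.", "img-2": "A picture of a dog."}}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "annotations.json")
    save_annotations(stored, path)
    pending = diff_annotations(load_annotations(path), current)

print(pending)
```

Note that this diff only reports entries that are new or changed in memory; entries present in the file but absent from memory are not reported.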

##### Populate a nested dictionary with image annotations for all datasets

Top level (datasets): {key = HuggingFace dataset path : value = sub-dictionary}

Sub-dictionary (images): {key = image ID : value = text annotation string}

In [7]:
def generate_annotation_dataset(real_image_datasets: List[ImageDataset], verbose: bool = False, max_images=None) -> Tuple[Dict[str, Dict[str, str]], float]:
    """
    Generates text annotations for each image in real_image_datasets and calculates the average per-image latency.

    Parameters:
    real_image_datasets (List[ImageDataset]): A list of datasets containing real images.
    verbose (bool): If True, prints the prompts and answers during processing. Defaults to False.
    max_images (int): Maximum number of images to annotate. If None, annotations will be generated for all images.

    Returns:
    Tuple[Dict[str, Dict[str, str]], float]: A tuple containing the annotations dictionary and the average latency in seconds.
    """
    annotations_dict = {}
    total_time = 0
    num_images = 0

    for i, dataset in enumerate(real_image_datasets):
        dataset_path = dataset.huggingface_dataset_path
        annotations_dict[dataset_path] = {}
        if verbose:
            print(f'{dataset_path} Dataset ({i + 1} of {len(real_image_datasets)}):')

        for j, image_dict in enumerate(dataset):
            if image_dict['image'] is not None:
                image = image_dict['image']

                start_time = time.time()  # Start the timer
                annotations = generate_description(image, verbose)
                end_time = time.time()  # End the timer

                # Latency calculation
                time_elapsed = end_time - start_time
                total_time += time_elapsed
                num_images += 1
                annotations_dict[dataset_path][image_dict['id']] = annotations
                if verbose:
                    print(f'Image: {j}')
                    print(annotations)
                    print(f'Time elapsed for this image: {time_elapsed:.4f} seconds')
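                # Stop this dataset once the cap is reached
                if max_images is not None and num_images >= max_images:
                    break
        # Also stop iterating over the remaining datasets once the cap is reached
        if max_images is not None and num_images >= max_images:
            break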

    average_latency = total_time / num_images if num_images > 0 else 0.0

    print(f'Average annotation latency: {average_latency:.4f} seconds')
    return annotations_dict, average_latency
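
The function above enforces `max_images` with break checks inside nested loops. An alternative sketch flattens the datasets into a single stream of records and applies one global cap with `itertools.islice`. The names `iter_images` and `take` are hypothetical, and plain lists of dictionaries stand in for `ImageDataset` objects.

```python
from itertools import islice
from typing import Dict, Iterable, Iterator, List, Optional, Tuple


def iter_images(datasets: Dict[str, List[dict]]) -> Iterator[Tuple[str, dict]]:
    """Yield (dataset name, record) pairs across all datasets, skipping records without an image."""
    for name, records in datasets.items():
        for record in records:
            if record.get("image") is not None:
                yield name, record


def take(pairs: Iterable[Tuple[str, dict]], max_images: Optional[int]) -> Iterator[Tuple[str, dict]]:
    """Stop after `max_images` items; `None` means no limit."""
    return islice(pairs, max_images)


# Made-up records for illustration
datasets = {
    "demo/a": [{"id": "a1", "image": "img"}, {"id": "a2", "image": None}, {"id": "a3", "image": "img"}],
    "demo/b": [{"id": "b1", "image": "img"}, {"id": "b2", "image": "img"}],
}

selected = [(name, record["id"]) for name, record in take(iter_images(datasets), 3)]
print(selected)
```

With a single stream there is only one place where the cap is applied, so it cannot be exceeded when moving from one dataset to the next.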

##### dalle-mini/open-images is the only real image dataset loaded with image IDs

In [8]:
image_datasets_to_annotate = [real_image_datasets[0]]
image_datasets_to_annotate, image_datasets_to_annotate[0].huggingface_dataset_path

([<bitmind.image_dataset.ImageDataset at 0x7f7867ad4a30>],
 'dalle-mini/open-images')

In [9]:
annotations_dict, average_latency = generate_annotation_dataset(image_datasets_to_annotate, verbose=False, max_images=20)

Average annotation latency: 17.2714 seconds
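
To put this latency in context, the arithmetic below extrapolates it to the full open-images dataset. It is a rough estimate that assumes every one of the 125436 rows yields an image, that images are annotated one at a time, and that the per-image latency stays constant.

```python
AVERAGE_LATENCY_S = 17.2714  # average seconds per image reported above
NUM_IMAGES = 125436          # rows in dalle-mini/open-images, from the dataset summary

total_seconds = AVERAGE_LATENCY_S * NUM_IMAGES
total_days = total_seconds / 86_400  # 60 * 60 * 24 seconds per day

print(f"Estimated sequential annotation time: {total_days:.1f} days")
```

At roughly 25 days for a single dataset, sequential annotation is impractical at scale, which is why reducing latency appears in the to-do list.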


In [10]:
annotations_dict

{'dalle-mini/open-images': {'https://farm8.staticflickr.com/7214/7337479734_53f1048393_o.jpg': 'A picture of a man dressed as a witch standing in a field of flowers.The setting is a garden.The background is a field of flowers.The flowers are yellow and purple.The flowers are in a garden.The.',
  'https://farm2.staticflickr.com/3299/3509371657_096da0a7e6_o.jpg': 'A picture of a pile of junk outside of a building.The setting is a dilapidated building.The background is a green plant..',
  'https://c4.staticflickr.com/1/55/155642663_7e28e8b2ff_o.jpg': 'A picture of a flower pot with pink flowers on a balcony.The setting is a city apartment building.The background is a white building.The building is in the background..',
  'https://c4.staticflickr.com/9/8555/15625756039_a60b0bd0a5_o.jpg': 'A picture of a group of people standing in a room.The setting is a hospital.The background is a white wall.The people are smiling.The doctor is standing in front of the group.The.',
  'https://c4.staticfl

In [11]:
average_latency

17.2714068208422
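
The annotations shown above have missing spaces after periods, doubled periods, and dangling one-word fragments such as a trailing "The." left by the token limit. The description mentions an LLM for cleaning; as a lighter, rule-based alternative, the heuristic sketch below tidies these cases. `tidy_annotation` is a hypothetical helper, and splitting on every period is a simplifying assumption that would mishandle decimals or abbreviations.

```python
def tidy_annotation(text: str) -> str:
    """Normalise sentence spacing and drop one-word fragments from a chained caption."""
    # Split on periods, trim whitespace, and discard empty pieces (this collapses "..")
    sentences = [s.strip() for s in text.split(".") if s.strip()]
    # Drop one-word fragments such as a dangling "The"
    sentences = [s for s in sentences if len(s.split()) > 1]
    return " ".join(s + "." for s in sentences)


raw = (
    "A picture of a pile of junk outside of a building."
    "The setting is a dilapidated building."
    "The background is a green plant.."
)
print(tidy_annotation(raw))
```

Longer truncated sentences are not detected by this rule, so a cleaning pass with more context may still be useful before the captions are used as prompts.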