## Description

This notebook uses BLIP-2, a state-of-the-art vision-language model by Salesforce, to generate descriptive text annotations (captions) of real images locally. These annotations can be used as downstream prompts for input into text-to-image diffusion models, to generate new synthetic image datasets. An intermediary large language model (LLM) may be used to format and clean/improve the quality of annotations created before they are used for diffusion input.

HuggingFace docs: https://huggingface.co/docs/transformers/main/en/model_doc/blip_2.

## Set-up environment

First, follow the set up instructions in (make sure you have finished running 'python download_data.py')

Compute advisory: Recommended to run in a GPU environment with high RAM.

In [2]:
%%time

# Standard libraries
import os
import random
import time
import logging
from typing import List, Dict, Tuple, Any
import json

# Third-party libraries
import numpy as np
import PIL
import torch
from transformers import AutoProcessor, Blip2ForConditionalGeneration, BlipProcessor, logging as transformers_logging

# Bitmind-specific libraries
from bitmind.image_dataset import ImageDataset
from bitmind.constants import DATASET_META
from bitmind.constants import IMAGE_ANNOTATION_MODEL

CPU times: user 68 μs, sys: 28 μs, total: 96 μs
Wall time: 97.8 μs


### Load finetuned BLIP-2 and processor

We can instantiate the model and its corresponding processor from the [hub](https://huggingface.co/models?other=blip-2). Here we load a BLIP-2 checkpoint that leverages the pre-trained OPT model by Meta AI, which as 2.7 billion parameters.

In [3]:
%%time
torch.manual_seed(0)
random.seed(0)
np.random.seed(0)

device = "cuda" if torch.cuda.is_available() else "cpu"

#Default log settings
transformers_level = logging.getLogger("transformers").getEffectiveLevel()
huggingface_hub_level = logging.getLogger("huggingface_hub").getEffectiveLevel()

#Suppress logs
logging.getLogger("transformers").setLevel(logging.ERROR)
logging.getLogger("huggingface_hub").setLevel(logging.ERROR)

processor = AutoProcessor.from_pretrained(IMAGE_ANNOTATION_MODEL)
# by default `from_pretrained` loads the weights in float32
# we load in float16 instead to save memory
model = Blip2ForConditionalGeneration.from_pretrained(IMAGE_ANNOTATION_MODEL, torch_dtype=torch.float16) 
model.to(device)

#Restore log settings
logging.getLogger("transformers").setLevel(transformers_level)
logging.getLogger("huggingface_hub").setLevel(huggingface_hub_level)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

CPU times: user 11.7 s, sys: 11.4 s, total: 23.1 s
Wall time: 10.9 s


### Load Real Image Datasets

In [4]:
%%time
print("Loading real datasets")
real_image_datasets = [
    ImageDataset(ds['path'], 'test', ds.get('name', None), ds['create_splits'])
    for ds in DATASET_META['real']
]
real_image_datasets

Loading real datasets


Resolving data files:   0%|          | 0/3606 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/18 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/52 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/27 [00:00<?, ?it/s]

CPU times: user 1.46 s, sys: 1.54 s, total: 3 s
Wall time: 9.7 s


[<bitmind.image_dataset.ImageDataset at 0x7f5eda2bf8b0>,
 <bitmind.image_dataset.ImageDataset at 0x7f5edaa64a00>,
 <bitmind.image_dataset.ImageDataset at 0x7f5eda3e79d0>]

In [5]:
for dataset in real_image_datasets:
    print(dataset.huggingface_dataset_path, dataset.dataset)

dalle-mini/open-images Dataset({
    features: ['url', 'key', 'shard_id', 'status', 'error_message', 'width', 'height', 'exif', 'original_width', 'original_height'],
    num_rows: 125436
})
merkol/ffhq-256 Dataset({
    features: ['image'],
    num_rows: 7000
})
saitsharipov/CelebA-HQ Dataset({
    features: ['image'],
    num_rows: 20260
})


## Prompt-Based Text Annotation

Pipeline for generating text annotations describing each image in real image datasets. 

In [6]:
prompts = [
    "A picture of",
    "The setting is",
    "The background is",
    "The image type/style is"
    # "the background is",
    # "The color(s) are",
    # "The texture(s) are",
    # "The emotion/mood is",
    # "The image medium is",
    # "The image style is"
]

In [7]:
def generate_description(image: PIL.Image.Image, use_prompts: bool = True, verbose: bool = False) -> str:
    """
    Generates a description for a given image using a sequence of prompts.

    Parameters:
    image (PIL.Image.Image): The image for which to generate a text description (string).
    verbose (bool): If True, prints the prompts and answers during processing. Defaults to False.

    Returns:
    str: The generated description for the image.
    """
    if not verbose: transformers_logging.set_verbosity_error() # Only display error messages (no warnings)
    description = ""
    if not use_prompts:
        inputs = processor(image, return_tensors="pt").to(device, torch.float16)
        generated_ids = model.generate(**inputs, max_new_tokens=20)
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
        description += generated_text
    else:
        for i, prompt in enumerate(prompts):
            # Append prompt to description to build context history
            description += prompt + ' '
            inputs = processor(image, text=description, return_tensors="pt").to(device, torch.float16)
            generated_ids = model.generate(**inputs, max_new_tokens=20)
            answer = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
            if answer:
                # Append answer to description to build context history
                description += answer
            else:
                description = description[:-len(prompt) - 1]  # Remove the last prompt if no answer is generated
            if verbose:
                print(f"{i}. Prompt: {prompt}")
                print(f"{i}. Answer: {answer}")
    if not verbose: transformers_logging.set_verbosity_info() # Restore transformer warnings
    return description


In [8]:
# Helper functions

def set_logging_level(verbose: int):
    level = logging.WARNING if verbose == 0 else logging.INFO if verbose < 3 else logging.DEBUG
    logging.getLogger().setLevel(level)

def ensure_save_path(path: str) -> str:
    if not os.path.exists(path):
        os.makedirs(path)
    return path

def create_annotation_dataset_directory(base_path: str, dataset_name: str) -> str:
    safe_name = dataset_name.replace("/", "_")
    full_path = os.path.join(base_path, safe_name)
    if not os.path.exists(full_path):
        os.makedirs(full_path)
    return full_path

def resize_image(image: PIL.Image.Image, max_width: int, max_height: int) -> PIL.Image.Image:
    """
    Resize the image if it is above the specified dimensions while maintaining the aspect ratio.

    Parameters:
    image (PIL.Image.Image): PIL.Image.Image object to be resized
    max_width (int): Maximum allowed width
    max_height (int): Maximum allowed height

    Returns:
    PIL.Image.Image: Resized PIL.Image.Image object
    """
    original_width, original_height = image.size

    # Check if the image is already within the allowed dimensions
    if original_width <= max_width and original_height <= max_height:
        return image

    # Calculate the aspect ratio
    aspect_ratio = original_width / original_height

    # Determine the new dimensions based on the aspect ratio
    if original_width > max_width:
        new_width = max_width
        new_height = int(new_width / aspect_ratio)
    else:
        new_height = max_height
        new_width = int(new_height * aspect_ratio)
    
    # Adjust if new dimensions exceed the maximum allowed dimensions
    if new_height > max_height:
        new_height = max_height
        new_width = int(new_height * aspect_ratio)
    if new_width > max_width:
        new_width = max_width
        new_height = int(new_width / aspect_ratio)

    # Resize the image using the high-quality ANTIALIAS filter
    resized_image = image.resize((new_width, new_height), PIL.Image.LANCZOS)
    
    return resized_image

def generate_annotation(image_id,
                        dataset_name: str,
                        image: PIL.Image.Image,
                        original_dimensions: tuple,
                        resize: bool,
                        resize_dim: int,
                        use_prompts: bool,
                        verbose: int):
    """
    Generate a text annotation for a given image.

    Parameters:
    image_id (int or str): The identifier for the image within the dataset.
    dataset_name (str): The name of the dataset the image belongs to.
    image (PIL.Image.Image): The image object that requires annotation.
    original_dimensions (tuple): Original dimensions of the image as (width, height).
    resize (bool): Allow image downsizing to maximum dimensions of (1280, 1280).
    verbose (int): Verbosity level.

    Returns:
    dict: Dictionary containing the annotation data.
    """
    image_to_process = image.copy()
    if resize:
        image_to_process = resize_image(image_to_process, resize_dim, resize_dim)
        if verbose > 1 and image_to_process.size != image.size:
            print(f"Resized {image_id}: {image.size} to {image_to_process.size}")

    description = generate_description(image_to_process, use_prompts, verbose > 2)
    annotation = {
        'description': description,
        'original_dataset': dataset_name,
        'original_dimensions': f"{original_dimensions[0]}x{original_dimensions[1]}",
        'index': image_id
    }
    return annotation

def save_annotation(dataset_dir: str, image_id, annotation: dict, verbose: int):
    """
    Save a text annotation to a JSON file if it doesn't already exist.

    Parameters:
    dataset_dir (str): The directory where the annotation file will be saved.
    image_id (int or str): The identifier for the image within the dataset.
    annotation (dict): Annotation data to be saved.
    verbose (int): Verbosity level. If greater than 0, it prints messages during processing.

    Returns:
    int: Returns 0 if the annotation is successfully saved, -1 if the annotation file already exists.
    """
    file_path = os.path.join(dataset_dir, f"{image_id}.json")
    if os.path.exists(file_path):
        if verbose > 0: print(f"Annotation for {image_id} already exists - Skipping")
        return -1  # Skip this image as it already has an annotation
    
    with open(file_path, 'w') as f:
        json.dump(annotation, f, indent=4)
        if verbose > 0: print(f"Created {file_path}")

    return 0

def process_image(dataset_dir: str,
                  image_info: dict,
                  dataset_name: str,
                  image_index: int,
                  resize: bool,
                  resize_dim : int,
                  use_prompts: bool,
                  verbose: int) -> tuple:
    if image_info['image'] is None:
        if verbose > 1:
            logging.debug(f"Skipping image {image_index} in dataset {dataset_name} due to missing image data.")
        return None, 0

    original_dimensions = image_info['image'].size
    start_time = time.time()
    annotation = generate_annotation(image_index,
                                     dataset_name,
                                     image_info['image'],
                                     original_dimensions,
                                     resize,
                                     resize_dim,
                                     use_prompts,
                                     verbose)
    save_annotation(dataset_dir, image_index, annotation, verbose)
    time_elapsed = time.time() - start_time

    if annotation == -1:
        if verbose > 1:
            logging.debug(f"Failed to generate annotation for image {image_index} in dataset {dataset_name}")
        return None, time_elapsed
    
    return annotation, time_elapsed

def compute_annotation_latency(processed_images: int, dataset_time: float, dataset_name: str) -> float:
    if processed_images > 0:
        average_latency = dataset_time / processed_images
        logging.info(f'Average annotation latency for {dataset_name}: {average_latency:.4f} seconds')
        return average_latency
    return 0.0

In [9]:
def generate_annotation_dataset(real_image_datasets: List[Any],
                                save_path: str = 'annotations/',
                                verbose: int = 0,
                                max_images: int | None = None,
                                resize_images = False,
                                resize_dim = 512,
                                use_prompts : bool = True) -> Tuple[Dict[str, Dict[str, Any]], float]:
    """
    Generates text annotations for images in the given datasets, saves them in a specified directory, 
    and computes the average per image latency. Returns a dictionary of new annotations and the average latency.

    Parameters:
        real_image_datasets (List[Any]): Datasets containing images.
        save_path (str): Directory path for saving annotation files.
        verbose (int): Verbosity level for process messages (Most verbose = 3).
        max_images (int): Maximum number of images to annotate.
        resize_images (bool) : Allow image downsizing before captioning.
                               Sets max dimensions to (resize_dim, resize_dim), maintaining aspect ratio.
3
    Returns:
        Tuple[Dict[str, Dict[str, Any]], float]: A tuple containing the annotations dictionary and average latency.
    """
    set_logging_level(verbose)
    annotations_dir = ensure_save_path(save_path)
    annotations = {}
    total_time = 0
    total_processed_images = 0

    for i, dataset in enumerate(real_image_datasets):
        dataset_name = dataset.huggingface_dataset_path
        dataset_dir = create_annotation_dataset_directory(annotations_dir, dataset_name)
        processed_images = 0
        dataset_time = 0

        for j, image_info in enumerate(dataset):
            annotation, time_elapsed = \
            process_image(dataset_dir, image_info, dataset_name, j, resize_images, resize_dim, use_prompts, verbose)
            if annotation is not None:
                annotations.setdefault(dataset_name, {})[image_info['id']] = annotation
                total_time += time_elapsed
                dataset_time += time_elapsed
                processed_images += 1
                if max_images is not None and processed_images >= max_images:
                    break

        average_latency = compute_annotation_latency(processed_images, dataset_time, dataset_name)
        total_processed_images += processed_images

    overall_average_latency = total_time / total_processed_images if total_processed_images else 0
    return annotations, overall_average_latency

Prompt v.s. Promptless Annotation

In [22]:
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
annotations_dict, average_latency = generate_annotation_dataset(real_image_datasets,
                                                                save_path='test_data/no_prompts_test/',
                                                                verbose=2,
                                                                max_images=10,
                                                               use_prompts=False)

Created no_prompts_test/dalle-mini_open-images/1.json
Created no_prompts_test/dalle-mini_open-images/2.json
Created no_prompts_test/dalle-mini_open-images/3.json
Created no_prompts_test/dalle-mini_open-images/4.json
Created no_prompts_test/dalle-mini_open-images/5.json
Created no_prompts_test/dalle-mini_open-images/7.json
Created no_prompts_test/dalle-mini_open-images/8.json
Created no_prompts_test/dalle-mini_open-images/9.json
Created no_prompts_test/dalle-mini_open-images/10.json


2024-07-09 10:49:12,553 - INFO - Average annotation latency for dalle-mini/open-images: 5.2395 seconds


Created no_prompts_test/dalle-mini_open-images/12.json
Created no_prompts_test/merkol_ffhq-256/0.json
Created no_prompts_test/merkol_ffhq-256/1.json
Created no_prompts_test/merkol_ffhq-256/2.json
Created no_prompts_test/merkol_ffhq-256/3.json
Created no_prompts_test/merkol_ffhq-256/4.json
Created no_prompts_test/merkol_ffhq-256/5.json
Created no_prompts_test/merkol_ffhq-256/6.json
Created no_prompts_test/merkol_ffhq-256/7.json
Created no_prompts_test/merkol_ffhq-256/8.json


2024-07-09 10:50:00,472 - INFO - Average annotation latency for merkol/ffhq-256: 4.7887 seconds


Created no_prompts_test/merkol_ffhq-256/9.json
Created no_prompts_test/saitsharipov_CelebA-HQ/0.json
Created no_prompts_test/saitsharipov_CelebA-HQ/1.json
Created no_prompts_test/saitsharipov_CelebA-HQ/2.json
Created no_prompts_test/saitsharipov_CelebA-HQ/3.json
Created no_prompts_test/saitsharipov_CelebA-HQ/4.json
Created no_prompts_test/saitsharipov_CelebA-HQ/5.json
Created no_prompts_test/saitsharipov_CelebA-HQ/6.json
Created no_prompts_test/saitsharipov_CelebA-HQ/7.json
Created no_prompts_test/saitsharipov_CelebA-HQ/8.json


2024-07-09 10:50:54,181 - INFO - Average annotation latency for saitsharipov/CelebA-HQ: 5.3601 seconds


Created no_prompts_test/saitsharipov_CelebA-HQ/9.json


Testing latency of image resizing vs no resizing in image preprocessing before annotation:

In [18]:
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
annotations_dict, average_latency = generate_annotation_dataset(real_image_datasets,
                                                                save_path='test_data/resize_test_1/',
                                                                verbose=2,
                                                                max_images=10,
                                                                resize_images=True,
                                                                resize_dim = 256)

Resized 1: (612, 612) to (256, 256)
Annotation for 1 already exists - Skipping
Resized 2: (800, 600) to (256, 192)
Annotation for 2 already exists - Skipping
Resized 3: (960, 1280) to (192, 256)
Annotation for 3 already exists - Skipping
Resized 4: (2592, 1944) to (256, 192)
Annotation for 4 already exists - Skipping
Resized 5: (1224, 814) to (256, 170)
Annotation for 5 already exists - Skipping
Resized 7: (827, 549) to (256, 169)
Annotation for 7 already exists - Skipping
Resized 8: (3456, 2304) to (256, 170)
Annotation for 8 already exists - Skipping
Resized 9: (3264, 2448) to (256, 192)
Annotation for 9 already exists - Skipping
Resized 10: (4000, 3000) to (256, 192)
Annotation for 10 already exists - Skipping
Resized 12: (1280, 960) to (256, 192)


2024-07-08 18:16:47,377 - INFO - Average annotation latency for dalle-mini/open-images: 14.7958 seconds


Annotation for 12 already exists - Skipping
Annotation for 0 already exists - Skipping
Annotation for 1 already exists - Skipping
Annotation for 2 already exists - Skipping
Annotation for 3 already exists - Skipping
Annotation for 4 already exists - Skipping
Annotation for 5 already exists - Skipping
Annotation for 6 already exists - Skipping
Annotation for 7 already exists - Skipping
Annotation for 8 already exists - Skipping


2024-07-08 18:19:02,736 - INFO - Average annotation latency for merkol/ffhq-256: 13.5327 seconds


Annotation for 9 already exists - Skipping
Annotation for 0 already exists - Skipping
Annotation for 1 already exists - Skipping
Annotation for 2 already exists - Skipping
Annotation for 3 already exists - Skipping
Annotation for 4 already exists - Skipping
Annotation for 5 already exists - Skipping
Annotation for 6 already exists - Skipping
Annotation for 7 already exists - Skipping
Annotation for 8 already exists - Skipping


2024-07-08 18:21:18,283 - INFO - Average annotation latency for saitsharipov/CelebA-HQ: 13.5533 seconds


Annotation for 9 already exists - Skipping


In [19]:
annotations_dict

{'dalle-mini/open-images': {'https://farm8.staticflickr.com/7214/7337479734_53f1048393_o.jpg': {'description': 'A picture of a man dressed as a witch standing in a field of flowersThe setting is a gardenThe background is a field of flowers',
   'original_dataset': 'dalle-mini/open-images',
   'original_dimensions': '612x612',
   'index': 1},
  'https://farm2.staticflickr.com/3299/3509371657_096da0a7e6_o.jpg': {'description': 'A picture of a pile of junk outside of a buildingThe setting is a cityThe background is a brick building',
   'original_dataset': 'dalle-mini/open-images',
   'original_dimensions': '800x600',
   'index': 2},
  'https://c4.staticflickr.com/1/55/155642663_7e28e8b2ff_o.jpg': {'description': 'A picture of a balcony with a flower pot on itThe setting is a cityThe background is a white building',
   'original_dataset': 'dalle-mini/open-images',
   'original_dimensions': '960x1280',
   'index': 3},
  'https://c4.staticflickr.com/9/8555/15625756039_a60b0bd0a5_o.jpg': {'d

In [20]:
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
annotations_dict, average_latency = generate_annotation_dataset(real_image_datasets,
                                                                save_path='test_data/no_resize_test_1/',
                                                                verbose=2,
                                                                max_images=10,
                                                                resize_images=False)

Annotation for 1 already exists - Skipping
Annotation for 2 already exists - Skipping
Annotation for 3 already exists - Skipping
Annotation for 4 already exists - Skipping
Annotation for 5 already exists - Skipping
Annotation for 7 already exists - Skipping
Annotation for 8 already exists - Skipping
Annotation for 9 already exists - Skipping
Annotation for 10 already exists - Skipping


2024-07-08 18:23:50,191 - INFO - Average annotation latency for dalle-mini/open-images: 14.9404 seconds


Annotation for 12 already exists - Skipping
Annotation for 0 already exists - Skipping
Annotation for 1 already exists - Skipping
Annotation for 2 already exists - Skipping
Annotation for 3 already exists - Skipping
Annotation for 4 already exists - Skipping
Annotation for 5 already exists - Skipping
Created no_resize_test_1/merkol_ffhq-256/6.json
Created no_resize_test_1/merkol_ffhq-256/7.json
Created no_resize_test_1/merkol_ffhq-256/8.json


2024-07-08 18:26:08,775 - INFO - Average annotation latency for merkol/ffhq-256: 13.8557 seconds


Created no_resize_test_1/merkol_ffhq-256/9.json
Created no_resize_test_1/saitsharipov_CelebA-HQ/0.json
Created no_resize_test_1/saitsharipov_CelebA-HQ/1.json
Created no_resize_test_1/saitsharipov_CelebA-HQ/2.json
Created no_resize_test_1/saitsharipov_CelebA-HQ/3.json
Created no_resize_test_1/saitsharipov_CelebA-HQ/4.json
Created no_resize_test_1/saitsharipov_CelebA-HQ/5.json
Created no_resize_test_1/saitsharipov_CelebA-HQ/6.json
Created no_resize_test_1/saitsharipov_CelebA-HQ/7.json
Created no_resize_test_1/saitsharipov_CelebA-HQ/8.json


2024-07-08 18:28:27,138 - INFO - Average annotation latency for saitsharipov/CelebA-HQ: 13.8353 seconds


Created no_resize_test_1/saitsharipov_CelebA-HQ/9.json


In [21]:
annotations_dict

{'dalle-mini/open-images': {'https://farm8.staticflickr.com/7214/7337479734_53f1048393_o.jpg': {'description': 'A picture of a man dressed as a witch standing in a field of flowersThe setting is a gardenThe background is a field of flowers',
   'original_dataset': 'dalle-mini/open-images',
   'original_dimensions': '612x612',
   'index': 1},
  'https://farm2.staticflickr.com/3299/3509371657_096da0a7e6_o.jpg': {'description': 'A picture of a pile of junk outside of a buildingThe setting is a dilapidated buildingThe background is a green plant',
   'original_dataset': 'dalle-mini/open-images',
   'original_dimensions': '800x600',
   'index': 2},
  'https://c4.staticflickr.com/1/55/155642663_7e28e8b2ff_o.jpg': {'description': 'A picture of a flower pot with pink flowers on a balconyThe setting is a city apartment buildingThe background is a white building',
   'original_dataset': 'dalle-mini/open-images',
   'original_dimensions': '960x1280',
   'index': 3},
  'https://c4.staticflickr.com