## Captioning Image Datasets with BLIP-2

### Set all the necessary parameters

In [None]:
# Seed shared by every random number generator, so repeated runs give the same captions
seed = 420

# Short dataset name, used to build the output file name below
dataset_name = "imagewoof"

# Output layout:
#   True  -> captions grouped per class, saved as a JSON file (pipeline prompts)
#   False -> one "label: caption" line per image, saved as a TXT file; this keeps the
#            order of the Hugging Face dataset, which finetuning prompts rely on
group_by_class = False
output_captions_file = (
    f"../storage/prompts/{dataset_name}_blip2_pipeline.json"
    if group_by_class
    else f"../storage/prompts/{dataset_name}_blip2_finetuning.txt"
)

# Dataset identifier in the Hugging Face Datasets hub
dataset_hf_path = "frgfm/imagewoof"

# Column names holding the image and its class label in that dataset
image_column, label_column = "image", "label"

# BLIP-2 checkpoint on the Hugging Face hub
blip2_checkpoint = "Salesforce/blip2-opt-2.7b"

### Run this cell to generate and save the captions file

In [None]:
import json
import random
import numpy as np
from datasets import load_dataset, ClassLabel
from transformers import AutoProcessor, Blip2ForConditionalGeneration
import torch
from tqdm import tqdm
from PIL import Image
import torchvision.transforms as transforms
import IPython

# Seed Python, NumPy and PyTorch (CPU and current CUDA device) RNGs for reproducibility
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

# Load the train split (frgfm/imagewoof is loaded with its "full_size" configuration)
if dataset_hf_path == "frgfm/imagewoof":
    dataset = load_dataset(dataset_hf_path, "full_size", split="train")
else:
    dataset = load_dataset(dataset_hf_path, split="train")

# Choose the device first so the model precision can follow it: float16 is only
# used on CUDA, since many half-precision ops are unsupported or very slow on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model_dtype = torch.float16 if device == "cuda" else torch.float32

# Load BLIP-2 processor and captioning model, then move the model to the device
image_processor = AutoProcessor.from_pretrained(blip2_checkpoint)
captioner = Blip2ForConditionalGeneration.from_pretrained(blip2_checkpoint, torch_dtype=model_dtype)
captioner.to(device)


def generate_caption(image):
    """Generate a BLIP-2 caption for a single image.

    Args:
        image: one image as stored in the dataset's image column, passed
            straight to the BLIP-2 processor.

    Returns:
        The decoded caption (at most 20 new tokens), stripped of surrounding whitespace.
    """
    # The processor handles resizing/normalization; inputs are moved to the model's
    # device and cast to the same dtype the model was loaded with.
    inputs = image_processor(image, return_tensors="pt").to(device, model_dtype)
    generated_ids = captioner.generate(**inputs, max_new_tokens=20)
    return image_processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()


# Resolve the label feature once: whether labels are ClassLabel ids is the same for every item
label_feature = dataset.features[label_column]
labels_are_class_ids = isinstance(label_feature, ClassLabel)

# Generate captions for the train split of the dataset:
#   group_by_class=True  -> {label: [caption, ...]}
#   group_by_class=False -> ["label: caption", ...] in dataset order
captions = {} if group_by_class else []
for item in tqdm(dataset):
    image = item[image_column]
    label = item[label_column]

    # ClassLabel features store integer ids; convert them to their string names
    if labels_are_class_ids:
        label = label_feature.int2str(label)

    # Fix the "cra" typo in the CIFAR-100 dataset
    if dataset_hf_path == "uoft-cs/cifar100" and label == "cra":
        label = "crab"

    caption = generate_caption(image)

    if group_by_class:
        captions.setdefault(label, []).append(caption)
    else:
        captions.append(f"{label}: {caption}")


import os

# open() does not create missing folders; create the destination directory up front
# so a long captioning run does not fail at the very last step.
os.makedirs(os.path.dirname(output_captions_file) or ".", exist_ok=True)

if group_by_class:
    # Pipeline captions: JSON mapping label -> list of captions (indented for readability)
    with open(output_captions_file, "w", encoding="utf-8") as f:
        json.dump(captions, f, indent=4)
else:
    # Finetuning captions: one "label: caption" line per image, in dataset order.
    # Explicit UTF-8 avoids depending on the platform's default encoding.
    with open(output_captions_file, "w", encoding="utf-8") as f:
        f.writelines(f"{caption}\n" for caption in captions)


print(f"Captions saved to {output_captions_file}")