# Description

This notebook implements image-to-text captioning for real image datasets. It is modified from the "Chat with BLIP-2" notebook illustrating usage of BLIP-2, a state-of-the-art vision-language model by Salesforce.

Original Notebook: https://github.com/NielsRogge/Transformers-Tutorials/blob/master/BLIP-2/Chat_with_BLIP_2.ipynb

Hugging Face docs: https://huggingface.co/docs/transformers/main/en/model_doc/blip_2

## Set-up environment

Follow the repository's set-up instructions before running this notebook, and make sure `python download_data.py` has finished running.

Compute advisory: Recommended to run in a GPU environment with high RAM.

In [None]:
%%time
from bitmind.image_dataset import ImageDataset
from bitmind.constants import DATASET_META
import numpy as np
import random

import torch
from transformers import AutoProcessor, Blip2ForConditionalGeneration

## Load model and processor

We can instantiate the model and its corresponding processor from the [hub](https://huggingface.co/models?other=blip-2). Here we load a BLIP-2 checkpoint that leverages the pre-trained OPT model by Meta AI, which has 2.7 billion parameters.

In [None]:
%%time
torch.manual_seed(0)
random.seed(0)
np.random.seed(0)

device = "cuda" if torch.cuda.is_available() else "cpu"
processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
# by default `from_pretrained` loads the weights in float32
# we load in float16 instead to save memory
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16) 
model.to(device)
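
To make the float16 choice concrete, here is a rough back-of-the-envelope sketch of the memory needed for the weights alone. The helper name `weight_memory_gib` is hypothetical, and the estimate covers only the 2.7 billion parameters of the OPT language model mentioned above; the vision encoder, the Q-Former, and activations add to the total.

```python
def weight_memory_gib(num_params: float, bytes_per_param: int) -> float:
    """Memory needed to hold the weights alone, in GiB (hypothetical helper)."""
    return num_params * bytes_per_param / 2**30


# Assumption: count only the OPT language model's 2.7 billion parameters.
OPT_PARAMS = 2.7e9

fp32 = weight_memory_gib(OPT_PARAMS, 4)  # float32 stores 4 bytes per parameter
fp16 = weight_memory_gib(OPT_PARAMS, 2)  # float16 stores 2 bytes per parameter
print(f"float32: {fp32:.1f} GiB, float16: {fp16:.1f} GiB")
```

Under these assumptions, half precision halves the weight footprint, which is often the difference between fitting on a single GPU and not.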

## Load Real Image Datasets

In [None]:
%%time
print("Loading real datasets")
real_image_datasets = [
    ImageDataset(ds['path'], 'test', ds.get('name', None), ds['create_splits'])
    for ds in DATASET_META['real']
]
real_image_datasets

### Display a random sample image from the real image datasets

In [None]:
# select a dataset at random
real_dataset = real_image_datasets[np.random.randint(0, len(real_image_datasets))]
source_name = real_dataset.huggingface_dataset_path
# select a dict containing a sample image at random
sample = real_dataset.sample(k=1)[0][0]
image = sample['image']
print(sample)
image

## Image captioning

If you don't provide a text prompt, the model starts generating from the BOS (beginning-of-sequence) token, so its output is a caption for the image.

In [None]:
inputs = processor(image, return_tensors="pt").to(device, torch.float16)

generated_ids = model.generate(**inputs, max_new_tokens=20)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
print(generated_text)

## Prompted image captioning

You can provide a text prompt, which the model will continue given the image.

In [None]:
%%time
prompt = "this is a picture of"

inputs = processor(image, text=prompt, return_tensors="pt").to(device, torch.float16)

generated_ids = model.generate(**inputs, max_new_tokens=20)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
print(generated_text)

In [None]:
%%time
prompt = "this is a picture of (in detail)"

inputs = processor(image, text=prompt, return_tensors="pt").to(device, torch.float16)

generated_ids = model.generate(**inputs, max_new_tokens=20)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
print(generated_text)
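
The captioning cells above repeat the same three steps: preprocess, generate, decode. As a minimal sketch, they could be wrapped in one helper. The name `generate_text` and its signature are hypothetical and not part of the original notebook; the calls inside it mirror the ones used in the cells above.

```python
def generate_text(processor, model, image, prompt=None,
                  device="cpu", dtype=None, max_new_tokens=20):
    """Run one preprocess -> generate -> decode round trip (hypothetical helper).

    With prompt=None the model captions the image from the BOS token;
    otherwise it continues the given prompt.
    """
    if prompt is None:
        inputs = processor(image, return_tensors="pt")
    else:
        inputs = processor(image, text=prompt, return_tensors="pt")

    # Move the inputs to the target device, casting only if a dtype is given.
    inputs = inputs.to(device) if dtype is None else inputs.to(device, dtype)

    generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # Decode the first (and only) sequence in the batch.
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
```

With the objects defined earlier in this notebook, a prompted call would look like `generate_text(processor, model, image, prompt="this is a picture of", device=device, dtype=torch.float16)`.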

## Visual question answering (VQA)

In [None]:
prompt = "Question: what is this a picture of? Answer:"

inputs = processor(image, text=prompt, return_tensors="pt").to(device, torch.float16)

generated_ids = model.generate(**inputs, max_new_tokens=50)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
print(generated_text)

In [None]:
prompt = "Question: Describe this picture in detail starting with 'this is a picture of...'. Answer:"

inputs = processor(image, text=prompt, return_tensors="pt").to(device, torch.float16)

generated_ids = model.generate(**inputs, max_new_tokens=10)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
print(generated_text)

## Chat-based prompting

We can create a ChatGPT-like interface by concatenating each generated response to the conversation. We prompt the model with some text (like "which city is this?"), the model generates an answer for it ("Singapore"), and we append that answer to the conversation. Then we ask a follow-up question ("why?"), which we also append before feeding the whole conversation back to the model.

This means that the context can't be too long, because the language models used in BLIP-2 have a limited context length: T5 was trained on sequences of 512 tokens, while OPT supports up to 2048 positions.
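
One way to keep the conversation within a token budget is to drop the oldest turns first. The sketch below is illustrative only: `build_prompt` is a hypothetical helper, and the default whitespace word count is a stand-in for a real tokenizer count, which you would supply through `count_tokens`.

```python
def build_prompt(turns, question, max_tokens,
                 count_tokens=lambda s: len(s.split())):
    """Build a prompt from past (question, answer) turns plus a new question.

    Oldest turns are dropped until the prompt fits within max_tokens.
    The default count_tokens is a whitespace stand-in, not a real tokenizer.
    """
    template = "Question: {} Answer: {}"
    pending = template.format(question, "")  # empty answer slot for the model
    kept = list(turns)
    while kept:
        history = " ".join(template.format(q, a) for q, a in kept)
        prompt = history + " " + pending
        if count_tokens(prompt) <= max_tokens:
            return prompt
        kept.pop(0)  # drop the oldest turn and try again
    # No history fits: fall back to the new question alone.
    return pending


# Made-up turns for illustration only.
turns = [("which city is this?", "Singapore"), ("why?", "it has a tall skyline")]
print(build_prompt(turns, "how many people are in the picture", max_tokens=20))
```

Dropping whole turns rather than cutting mid-sentence keeps every remaining question paired with its answer, at the cost of losing the earliest context.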

In [None]:
questions = [
    "Describe this picture starting with 'this is a picture of...'",
    "what colors are in the picture",
    "how many people are in the picture",
]
template = "Question: {} Answer: {}"

In [None]:
conversation = ""
for i, question in enumerate(questions):
    # append the next question with an empty answer slot for the model to fill
    conversation += template.format(question, '')
    inputs = processor(image, text=conversation, return_tensors="pt").to(device, torch.float16)
    generated_ids = model.generate(**inputs, max_new_tokens=20)
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
    # add the model's answer so the next question sees the full history
    conversation += generated_text + ' '
    print(i, conversation)
print('\nConversation History:\n' + conversation)