# Chat with BLIP-2

In this notebook, we'll illustrate usage of BLIP-2, a state-of-the-art vision-language model by Salesforce.

HuggingFace docs: https://huggingface.co/docs/transformers/main/en/model_doc/blip_2.

## Set-up environment

We'll start by installing 🤗 Transformers. As the model is brand new at the time of writing this notebook, we install it from source.

Note that it's advised to run this notebook in a GPU environment with high RAM.

In [None]:
!pip install -q git+https://github.com/huggingface/transformers.git

## Load model and images, then predict a folder-level label

In [None]:
!pip install --upgrade transformers
import os
from PIL import Image
from transformers import AutoProcessor, Blip2ForConditionalGeneration
import torch

# 1. Load the BLIP-2 model and processor
processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained(
    "Salesforce/blip2-opt-2.7b", 
    torch_dtype=torch.float16, 
    device_map="auto"
)

In [None]:
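# 2. Set the image folder
img_dir = '/home/DL_MILS/MILS_Final/AVA_Dataset/road/train/road_0003'
# Sort so frames are processed in a deterministic order (os.listdir order is arbitrary)
img_files = sorted(f for f in os.listdir(img_dir) if f.endswith('.jpg'))

# 3. Read all images in one batch
images = [Image.open(os.path.join(img_dir, f)).convert('RGB') for f in img_files]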

# 4. Set the question (modify as needed)
# question = "Is there any risk in this image?"
question = "Question: Is this road risky?(answer 1:risky, answer 0:not risky) Answer:"

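# 5. Batch processing and inference
# Cast floating-point inputs (pixel_values) to float16 so they match the float16 weights
inputs = processor(images=images, text=[question] * len(images), return_tensors="pt", padding=True).to(model.device, torch.float16)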
with torch.no_grad():
    generated_ids = model.generate(**inputs, max_new_tokens=1)
    generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)

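# 6. Output 1 if at least one image is labeled 1
def extract_label(text):
    # Depending on the transformers version, the decoded text may echo the prompt
    # (which itself contains "risky"), so keep only what follows the last "Answer:"
    answer = text.split("Answer:")[-1].strip().lower()
    # "not risky" also contains "risky", so rule it out explicitly
    if answer.startswith('1') or ('risky' in answer and 'not risky' not in answer):
        return 1
    return 0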

labels = [extract_label(t) for t in generated_texts]
video_label = 1 if any(labels) else 0
print(f"Folder {os.path.basename(img_dir)} prediction(1:risk, 0:safe): {video_label} ")
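The `any` rule above flags a whole folder as risky as soon as a single frame is answered with 1, so one noisy answer flips the folder label. Below is a minimal sketch of that rule next to a threshold-based alternative; `aggregate_any`, `aggregate_fraction`, and `min_fraction` are hypothetical names for illustration, not part of the original notebook. Both take a per-frame list like the `labels` produced above.

```python
from typing import Sequence


def aggregate_any(labels: Sequence[int]) -> int:
    """Folder is risky (1) if at least one frame is labeled risky (same rule as above)."""
    return 1 if any(labels) else 0


def aggregate_fraction(labels: Sequence[int], min_fraction: float = 0.5) -> int:
    """Hypothetical alternative: folder is risky only if at least
    `min_fraction` of its frames are labeled risky."""
    if not labels:
        return 0
    return 1 if sum(labels) / len(labels) >= min_fraction else 0


frame_labels = [0, 0, 1, 0]
print(aggregate_any(frame_labels))             # 1
print(aggregate_fraction(frame_labels))        # 0 (only 25% of frames are risky)
print(aggregate_fraction(frame_labels, 0.25))  # 1
```

The fraction rule trades recall for robustness: it ignores isolated positive frames, which may or may not be desirable depending on how rare risky events are in a clip.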

In [None]:
# import os
# from PIL import Image
# from IPython.display import display

# # Directory containing the images
# img_dir = '/home/DL_MILS/MILS_Final/AVA_Dataset/road/train/road_0000'

# # List all jpg files in the directory
# img_files = [f for f in os.listdir(img_dir) if f.endswith('.jpg')]

# # Iterate and display each image
# for img_file in img_files:
#     img_path = os.path.join(img_dir, img_file)
#     image = Image.open(img_path).convert('RGB')
#     display(image.resize((596, 437)))

## Load model and processor

We can instantiate the model and its corresponding processor from the [hub](https://huggingface.co/models?other=blip-2). Here we load a BLIP-2 checkpoint that leverages the pre-trained OPT model by Meta AI, which has 2.7 billion parameters.

In [None]:
from transformers import AutoProcessor, Blip2ForConditionalGeneration
import torch

processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
# by default `from_pretrained` loads the weights in float32
# we load in float16 instead to save memory
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16)
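The comment in the cell above says float16 saves memory. The arithmetic behind that is simple: float32 stores each parameter in 4 bytes and float16 in 2. A rough sketch for the 2.7-billion-parameter OPT language model alone, assuming 1 GB = 1e9 bytes; the vision encoder and Q-Former add more weights, and activations during generation are not counted.

```python
def weight_memory_gb(num_params: float, bytes_per_param: int) -> float:
    """Approximate memory needed just to hold the weights, in GB (1e9 bytes)."""
    return num_params * bytes_per_param / 1e9


opt_params = 2.7e9  # OPT language model only; the full BLIP-2 checkpoint is larger
print(f"float32: {weight_memory_gb(opt_params, 4):.1f} GB")  # float32: 10.8 GB
print(f"float16: {weight_memory_gb(opt_params, 2):.1f} GB")  # float16: 5.4 GB
```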

Let's use the GPU, as it will make generation a lot faster.

In [None]:
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

## Image captioning

If you don't provide a text prompt, the model starts generating text from the BOS (beginning-of-sequence) token by default, so it generates a caption for the image.

In [None]:
# inputs = processor(image, return_tensors="pt").to(device, torch.float16)

# generated_ids = model.generate(**inputs, max_new_tokens=20)
# generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
# print(generated_text)

## Prompted image captioning

You can provide a text prompt, which the model will continue given the image.

In [None]:
# prompt = "Is this road risky?(answer 1:risky, answer 0:not risky)"

# inputs = processor(image, text=prompt, return_tensors="pt").to(device, torch.float16)

# generated_ids = model.generate(**inputs, max_new_tokens=20)
# generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
# print(generated_text)

In [17]:
# prompt = "the weather looks"

# inputs = processor(image, text=prompt, return_tensors="pt").to(device, torch.float16)

# generated_ids = model.generate(**inputs, max_new_tokens=20)
# generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
# print(generated_text)

## Visual question answering (VQA)

In [None]:
# prompt = "Question: which city is this? Answer:"
prompt = "Question: Is this road risky?(answer 1:risky, answer 0:not risky) Answer:"

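# `image` is not defined anywhere else in this notebook, so use the first frame loaded earlier
image = images[0]
inputs = processor(image, text=prompt, return_tensors="pt").to(device, torch.float16)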

generated_ids = model.generate(**inputs, max_new_tokens=10)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
print(generated_text)

## Chat-based prompting

We can create a ChatGPT-like interface by concatenating each generated response to the conversation. We prompt the model with some text (like "which city is this?"), the model generates an answer ("singapore"), and we append it to the conversation. Then we ask a follow-up question ("why?"), which we also append before feeding the whole conversation back to the model.

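This means the context can't grow too long: the language models used in BLIP-2 (OPT and Flan-T5) have limited context windows (OPT supports 2,048 positions, and Flan-T5 was trained on 512-token inputs), and the 32 query embeddings produced by the Q-Former take up part of that budget.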

In [None]:
# context = [
#     ("which city is this?", "singapore"),
#     ("why?", "it has a statue of a merlion"),
# ]
# question = "where is the name merlion coming from?"
# template = "Question: {} Answer: {}."

# prompt = " ".join([template.format(context[i][0], context[i][1]) for i in range(len(context))]) + " Question: " + question + " Answer:"

# print(prompt)

In [None]:
# inputs = processor(image, text=prompt, return_tensors="pt").to(device, torch.float16)

# generated_ids = model.generate(**inputs, max_new_tokens=10)
# generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
# print(generated_text)
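The two cells above build one prompt from a fixed history and never check its length. A minimal sketch of the full chat loop, assuming a hypothetical `answer_fn` callable that stands in for the `processor(...)`, `model.generate(...)`, `batch_decode(...)` sequence so the sketch runs without the model. It reuses the same "Question: {} Answer: {}." template and drops the oldest turns once the prompt exceeds a budget. The budget is counted in whitespace-separated words, a crude stand-in for tokens; a real version could count with the tokenizer instead, e.g. `len(processor.tokenizer(prompt).input_ids)`.

```python
from typing import Callable, List, Tuple

TEMPLATE = "Question: {} Answer: {}."


def build_prompt(context: List[Tuple[str, str]], question: str) -> str:
    """Concatenate past (question, answer) turns and the new question."""
    history = " ".join(TEMPLATE.format(q, a) for q, a in context)
    return (history + " " if history else "") + "Question: " + question + " Answer:"


def fit_to_budget(context: List[Tuple[str, str]], question: str, max_words: int) -> str:
    """Drop the oldest turns until the prompt fits `max_words` words.
    If the question alone is too long, it is returned without any history."""
    context = list(context)
    while context and len(build_prompt(context, question).split()) > max_words:
        context.pop(0)
    return build_prompt(context, question)


def chat(questions: List[str], answer_fn: Callable[[str], str],
         max_words: int = 200) -> List[Tuple[str, str]]:
    """Ask each question in turn, appending every (question, answer) pair to the history."""
    context: List[Tuple[str, str]] = []
    for question in questions:
        prompt = fit_to_budget(context, question, max_words)
        answer = answer_fn(prompt)  # hypothetical stand-in for BLIP-2 generation
        context.append((question, answer))
    return context


# Canned answers replace the model so the example runs anywhere
canned = iter(["singapore", "it has a statue of a merlion"])
history = chat(["which city is this?", "why?"], lambda prompt: next(canned))
print(build_prompt(history, "where is the name merlion coming from?"))
# Question: which city is this? Answer: singapore. Question: why? Answer: it has a statue of a merlion. Question: where is the name merlion coming from? Answer:
```

Dropping the oldest turns first keeps the most recent exchange, which is usually what a follow-up like "why?" refers to.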