# Imports

In [None]:
from typing import List, Tuple
import glob
import torch
import wandb
from PIL import Image
from transformers import AutoProcessor, Blip2ForConditionalGeneration

# Shared constants used by the cells below.
# Checkpoint identifier handed to `from_pretrained` for the processor and model.
MODEL_ID = "Salesforce/blip2-opt-2.7b"

# Use the GPU when PyTorch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"


def load_image(path: str) -> "Image.Image":
    """Open the image file at ``path`` and return it converted to RGB.

    Args:
        path: Filesystem path of the image to read.

    Returns:
        A new ``PIL.Image.Image`` in ``"RGB"`` mode.
    """
    # Image.open is lazy and keeps the underlying file open. convert() reads
    # the pixel data into a separate image, so the source handle can be closed
    # as soon as the conversion is done instead of lingering until GC.
    with Image.open(path) as source:
        return source.convert("RGB")

# Load model

In [None]:
# Load the processor and the BLIP-2 weights (half precision) for MODEL_ID.
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = Blip2ForConditionalGeneration.from_pretrained(
    MODEL_ID, torch_dtype=torch.float16
)
# Every helper below moves its inputs to `device`; the model has to live on
# the same device or `model.generate` fails with a device-mismatch error.
model = model.to(device)

In [None]:
def image_captioning(image: Image.Image):
    """Generate a caption for ``image`` without any text prompt.

    The image is encoded with the module-level ``processor``, the encoded
    inputs are sent through ``.to(device, torch.float16)``, and
    ``model.generate`` is asked for at most 20 new tokens.

    Args:
        image: The picture to caption.

    Returns:
        The first decoded sequence, with special tokens skipped and
        surrounding whitespace stripped.
    """
    encoded = processor(image, return_tensors="pt")
    encoded = encoded.to(device, torch.float16)
    output_ids = model.generate(**encoded, max_new_tokens=20)
    decoded = processor.batch_decode(output_ids, skip_special_tokens=True)
    return decoded[0].strip()

In [None]:
def prompted_image_captioning(image: Image.Image, prompt: str):
    """Generate a caption for ``image`` conditioned on a text ``prompt``.

    Image and prompt are encoded together by the module-level ``processor``,
    the encoded inputs are sent through ``.to(device, torch.float16)``, and
    ``model.generate`` is asked for at most 20 new tokens.

    Args:
        image: The picture to caption.
        prompt: Text passed to the processor alongside the image.

    Returns:
        The first decoded sequence, with special tokens skipped and
        surrounding whitespace stripped.
    """
    encoded = processor(image, text=prompt, return_tensors="pt")
    encoded = encoded.to(device, torch.float16)
    output_ids = model.generate(**encoded, max_new_tokens=20)
    decoded = processor.batch_decode(output_ids, skip_special_tokens=True)
    return decoded[0].strip()

In [None]:
def vqa(image: Image.Image, question: str):
    """Answer ``question`` about ``image``.

    The question is passed to the module-level ``processor`` as the text
    input, unchanged. Generation is capped at 10 new tokens.

    Args:
        image: The picture the question refers to.
        question: Text passed to the processor alongside the image.

    Returns:
        The first decoded sequence, with special tokens skipped and
        surrounding whitespace stripped.
    """
    encoded = processor(image, text=question, return_tensors="pt")
    encoded = encoded.to(device, torch.float16)
    output_ids = model.generate(**encoded, max_new_tokens=10)
    decoded = processor.batch_decode(output_ids, skip_special_tokens=True)
    return decoded[0].strip()

In [None]:
def chat_based_prompting(
    image: Image.Image, context: List[Tuple[str, str]], question: str
):
    """Answer ``question`` about ``image`` given earlier question/answer turns.

    Each ``(question, answer)`` pair in ``context`` is rendered as
    ``"Question: <q> Answer: <a>."``; the new question is appended as
    ``"Question: <question> Answer:"`` and all pieces are joined with single
    spaces to form the prompt. Generation is capped at 10 new tokens.

    Args:
        image: The picture the conversation refers to.
        context: Previous turns as ``(question, answer)`` pairs, oldest first.
            May be empty.
        question: The new question to answer.

    Returns:
        The first decoded sequence, with special tokens skipped and
        surrounding whitespace stripped.
    """
    template = "Question: {} Answer: {}."
    segments = [template.format(turn[0], turn[1]) for turn in context]
    segments.append("Question: " + question + " Answer:")
    # Joining every segment (rather than prefixing the final one with " ")
    # keeps the prompt free of a leading space when `context` is empty.
    prompt = " ".join(segments)

    encoded = processor(image, text=prompt, return_tensors="pt")
    encoded = encoded.to(device, torch.float16)
    output_ids = model.generate(**encoded, max_new_tokens=10)
    decoded = processor.batch_decode(output_ids, skip_special_tokens=True)
    return decoded[0].strip()

# Image captioning table

In [None]:
# Caption every PNG in the images folder and log the results as a W&B table.
image_paths = glob.glob("images/*.png")

wandb.init(project="BLIP-2", name="image_captioning")
table = wandb.Table(columns=["Image", "Generated caption"])

for image_path in image_paths:
    image = load_image(image_path)
    caption = image_captioning(image)
    # One row per image: the picture itself next to its generated caption.
    table.add_data(wandb.Image(image), caption)

wandb.log({"img_captioning": table})
wandb.finish()

# Prompted image captioning table

In [None]:
# Caption every PNG in the images folder with a text prompt and log the
# results as a W&B table.
image_paths = glob.glob("images/*.png")
# Text prefix handed to the model together with each image.
prompt = "a photo of"

# Distinct run name and log key so this run is not confused with the
# unprompted captioning run.
wandb.init(project="BLIP-2", name="prompted_image_captioning")
table = wandb.Table(columns=["Prompt", "Image", "Generated caption"])

for img in image_paths:
    image = load_image(img)
    caption = prompted_image_captioning(image, prompt)
    # A row must supply one value per declared column: prompt, image, caption.
    table.add_data(prompt, wandb.Image(image), caption)

wandb.log({"prompted_img_captioning": table})
wandb.finish()