In [1]:
from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration
import torch
import os
import json
from PIL import Image

# Device name handed to AnswerGenerator when it is constructed further down.
device = "cuda"

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# https://github.com/salesforce/LAVIS/tree/main/projects/instructblip#prepare-vicuna-weights
#!pip3 install "fschat[model_worker,webui]"
#!python3 -m fastchat.serve.cli --model-path lmsys/vicuna-7b-v1.5

In [2]:
from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration
import torch

# Single checkpoint identifier shared by the processor and the model so the
# two are always loaded from the same place.
checkpoint_name = "Salesforce/instructblip-vicuna-7b"

processor = InstructBlipProcessor.from_pretrained(checkpoint_name)

# Weights are requested in 4-bit form with float16 as the torch dtype.
model = InstructBlipForConditionalGeneration.from_pretrained(
    checkpoint_name,
    load_in_4bit=True,
    torch_dtype=torch.float16,
)

Loading checkpoint shards: 100%|██████████| 4/4 [01:08<00:00, 17.21s/it]


In [19]:
class AnswerGenerator:
    """Generate a text answer for an (image, prompt) pair.

    The generator only talks to the objects it was constructed with: the
    ``preprocessor`` is called to build the model inputs and to decode the
    generated output, and ``model.generate`` produces that output.

    Args:
        model: Object exposing ``generate(**inputs)``.
        preprocessor: Callable accepting ``images=``, ``text=`` and
            ``return_tensors=`` whose result has a ``.to(device=, dtype=)``
            method; it must also expose
            ``batch_decode(outputs, skip_special_tokens=True)``.
        device: Device the prepared inputs are moved to.
        dtype: dtype passed along with ``device`` when the inputs are moved.
            Defaults to ``torch.float32``.
    """

    def __init__(self, model, preprocessor, device='cuda', dtype=torch.float32):
        self.model = model
        self.preprocessor = preprocessor
        self.device = device
        self.dtype = dtype

    def _load(self, image, prompt):
        """Build model inputs for ``image`` and ``prompt``.

        Args:
            image: A path (``str`` or ``os.PathLike``) that is opened and
                converted to RGB, or any other object, which is handed to the
                preprocessor unchanged.
            prompt: Text forwarded to the preprocessor as ``text``.

        Returns:
            The preprocessor output after ``.to(device=..., dtype=...)``.
        """
        if isinstance(image, (str, os.PathLike)):
            # convert() returns a new image, so the file handle opened here
            # can be closed as soon as the conversion is done.
            with Image.open(image) as opened:
                raw_image = opened.convert("RGB")
        else:
            raw_image = image
        inputs = self.preprocessor(images=raw_image, text=prompt, return_tensors="pt")
        return inputs.to(device=self.device, dtype=self.dtype)

    @torch.no_grad()
    def inference(self, image, prompt):
        """Return the stripped, decoded generation for ``image`` and ``prompt``.

        Gradient tracking is disabled for the whole call, including the
        ``generate`` step.
        """
        inputs = self._load(image, prompt)
        outputs = self.model.generate(**inputs)
        decoded = self.preprocessor.batch_decode(outputs, skip_special_tokens=True)
        # Only the first decoded sequence is returned.
        return decoded[0].strip()


In [20]:
# Wrap the already-loaded model and processor in the question-answering helper.
llm = AnswerGenerator(model, processor, device)
base_dir = "."

# qSet.json maps each image file name (without extension) to a mapping of
# question key -> question text. Answer keys are derived below by replacing
# the first character of the question key with "A".
# JSON is read and written as UTF-8 explicitly so the result does not depend
# on the platform's default locale encoding.
with open(f'{base_dir}/qSet.json', 'r', encoding='utf-8') as f:
    question_dict = json.load(f)

answer_dict = {}
for fileName, questions in question_dict.items():
    # The image path depends only on the file name, so build it once per image
    # rather than once per question.
    img_dir = f'{base_dir}/images/{fileName}.png'
    answer_dict[fileName] = {}
    for q, question in questions.items():
        ans = llm.inference(img_dir, question)
        ansKey = "A" + q[1:]
        answer_dict[fileName][ansKey] = ans

with open(f'{base_dir}/answer_set.json', 'w', encoding='utf-8') as outfile:
    json.dump(answer_dict, outfile)

