In [1]:
import os
os.environ['COCO_DIR'] = '/usr1/data/mingqia2/datasets/coco/'
os.environ['AOKVQA_DIR'] = '/usr1/data/mingqia2/aokvqa/'
os.environ['HF_HOME'] = '/usr1/data/models_cache'

import json 
from collections import Counter
from tqdm import tqdm

import torch
from torch import nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from datasets import load_dataset
from PIL import Image

from transformers import Blip2Processor, Blip2ForConditionalGeneration

from load_aokvqa import load_aokvqa, get_coco_path
from model import VQADataset
from accelerate import infer_auto_device_map

  from .autonotebook import tqdm as notebook_tqdm


In [19]:
saved_model_path = "./trained_model"
device = "cuda" if torch.cuda.is_available() else "cpu"

processor = Blip2Processor.from_pretrained(saved_model_path)
model = Blip2ForConditionalGeneration.from_pretrained(saved_model_path).to(device)

Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.06it/s]


In [20]:
coco_dir = os.getenv('COCO_DIR')
aokvqa_dir = os.getenv('AOKVQA_DIR')
train_path = "../results/viper_augmentations/aokvqa_plus_viper_train.json"
val_path = "../results/viper_augmentations/aokvqa_plus_viper_val.json"
val_dataset = load_dataset("json", data_files={"val": val_path}, split="val")

In [18]:
val_dataset[0]

{'split': 'val',
 'image_id': 377368,
 'question_id': '2Aq5RiEn7eyfWjEbpuYT2o',
 'question': 'Which number birthday is probably being celebrated?',
 'choices': ['one', 'ten', 'nine', 'thirty'],
 'correct_choice_idx': 3,
 'direct_answers': ['thirty',
  '30th',
  'thirty',
  'thirty',
  'thirty',
  '30th',
  'thirty',
  'thirty',
  'thirty',
  'thirty'],
 'difficult_direct_answer': False,
 'rationales': ['There is a birthday cake on the table with the number 30 written in icing.',
  'The cake says 30.',
  'The numerals three and zero are written on the cake, which indicates the person is 30 years of age as of the birthdate.']}

In [4]:
def save_results_to_file(results, output_file):
    with open(output_file, 'w') as f:
        json.dump(results, f)

In [16]:
def inference_pipeline(processor, model, val_dataset, coco_dir, output_file, batch_size=100, split='val', mc=False):
    results = []
    total_items = len(val_dataset)

    for idx, item in tqdm(enumerate(val_dataset), total=total_items):
        image_id = item['image_id']
        question_id = item['question_id']
        question = item['question']
        visual_clues = item.get('visual_clues', "")  # Ensure compatibility if 'visual_clues' is missing
        
        # Handle Multiple-Choice (MC) and Direct Answer (DA)
        if mc:
            choices = item['choices']
            formatted_choices = " ".join([f"({chr(65+i)}) {choice}" for i, choice in enumerate(choices)])
            prompt = (
                f"Question: {question} \n Visual Clues: {visual_clues} \n Choices: {formatted_choices}. "
                "Please provide a rationale and then return the letter of the correct answer in the format: "
                "'Rationale: [your explanation] \\n Answer: [your answer]'."
            )
        else:
            prompt = (
                f"Question: {question} \n Visual Clues: {visual_clues}. "
                "Please provide a rationale and then return the direct answer in the format: "
                "'Rationale: [your explanation] \\n Answer: [your answer]'."
            )
        print(prompt)
        
        # Get image path and preprocess the image
        image_path = get_coco_path(split, image_id, coco_dir)
        raw_image = Image.open(image_path)
        inputs = processor(images=raw_image, text=prompt, truncation=True, return_tensors='pt').to(device)
        
        # Generate output
        output = model.generate(**inputs, max_new_tokens=50, do_sample=False)  # Adjust max_new_tokens as needed
        decoded_output = processor.decode(output[0], skip_special_tokens=True)
        
        # Extract rationale and answer
        if "\\n Answer:" in decoded_output:
            rationale, answer = decoded_output.split("\\n Answer:", 1)
        else:
            rationale, answer = decoded_output, ""  # Fallback if format is incorrect

        # Store the results
        result = {
            "image_id": image_id,
            "question_id": question_id,
            "rationale": rationale.strip(),
            "answer": answer.strip(),
        }
        results.append(result)

        # Save results periodically
        if (idx + 1) % batch_size == 0 or (idx + 1) == total_items:
            save_results_to_file(results, output_file)

    return results

In [17]:
output_file = 'ft_blip2-opt-2.7b_val_results_da.json'
inference_pipeline(processor, model, val_dataset, coco_dir, output_file, batch_size=100, split='val')

  0%|          | 0/774 [00:00<?, ?it/s]

Question: Which number birthday is probably being celebrated? 
 Visual Clues: . Please provide a rationale and then return the direct answer in the format: 'Rationale: [your explanation] \n Answer: [your answer]'.


  0%|          | 1/774 [00:01<14:21,  1.12s/it]

Question: What best describes the pool of water? 
 Visual Clues: . Please provide a rationale and then return the direct answer in the format: 'Rationale: [your explanation] \n Answer: [your answer]'.


  0%|          | 2/774 [00:02<14:16,  1.11s/it]

Question: What is the white substance on top of the cupcakes? 
 Visual Clues: . Please provide a rationale and then return the direct answer in the format: 'Rationale: [your explanation] \n Answer: [your answer]'.


  0%|          | 2/774 [00:02<16:47,  1.30s/it]


KeyboardInterrupt: 