In [1]:
import os
os.environ['COCO_DIR'] = '/usr1/data/mingqia2/datasets/coco/'
os.environ['AOKVQA_DIR'] = '/usr1/data/mingqia2/aokvqa/'
os.environ['HF_HOME'] = '/usr1/data/models_cache'

import json 
from tqdm import tqdm

import torch
from torch import nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from datasets import load_dataset
from PIL import Image

from transformers import Blip2Processor, Blip2ForConditionalGeneration
from transformers import AutoProcessor, AutoModelForVisualQuestionAnswering

from load_aokvqa import load_aokvqa, get_coco_path

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load BLIP-2 model and processor
processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", device_map="auto")


Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.03it/s]


In [3]:
class VQADataset(Dataset):
    def __init__(self, dataset, processor, coco_dir, max_length=128):
        """
        Args:
            dataset: List of samples with original question, answer, and visual clues.
            processor: BLIP processor for text and image preprocessing.
            image_dir: Path to the directory containing images.
        """
        self.dataset = dataset
        self.processor = processor
        self.coco_dir = coco_dir
        self.max_length = max_length
        # self.split = split # train or val 
    
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, idx):
        sample = self.dataset[idx]
         # Get image path and open image
        image_path = get_coco_path(sample['split'], sample['image_id'], self.coco_dir)
        image = Image.open(image_path).convert("RGB")
        
        # Combine question with MC options
        question = sample['question']
        choices = sample['choices']
        alphabet = ['A', 'B', 'C', 'D']
        # Format the choices
        formatted_choices = ', '.join([f"{letter}: {choice}" for letter, choice in zip(alphabet, choices)])
        
        
        # Get visual clues 
        visual_clues = sample.get('viper_gpt', {}).get('viper_question', '') + ' ' + sample.get('viper_gpt', {}).get('viper_response', '')

        mc_question = f"Question: {question} Visual Clues: {visual_clues} Choices: {formatted_choices}. Please only return the letter of the correct answer."
        da_question = f"Question: {question} Visual Clues: {visual_clues}"
        
        # Process image and augmented text
        mc_encoding = self.processor(image, mc_question, padding="max_length", truncation=True, max_length=self.max_length, return_tensors="pt")
        da_encoding = self.processor(image, da_question, padding="max_length", truncation=True, max_length=self.max_length, return_tensors="pt")
        
        # Prepare answers 
        direct_answers = sample['direct_answers']
        correct_choice_idx = sample['correct_choice_idx'] # index of the correct MC choice
        
        da_text = " | ".join(direct_answers)
        da_labels = self.processor.tokenizer(
            da_text, padding="max_length", truncation=True, max_length=self.max_length, return_tensors="pt"
        ).input_ids.squeeze()
        mc_label = torch.tensor(correct_choice_idx, dtype=torch.long)
        
        encoding = {
            "pixel_values": mc_encoding["pixel_values"].squeeze(0),  
            "mc_input_ids": mc_encoding["input_ids"].squeeze(0),
            "mc_attention_mask": mc_encoding["attention_mask"].squeeze(0),
            "da_input_ids": da_encoding["input_ids"].squeeze(0),
            "da_attention_mask": da_encoding["attention_mask"].squeeze(0),
            "direct_answer_labels": da_labels,
            "multiple_choice_label": mc_label
        }
        
        return encoding

In [None]:
class CustomLoss(nn.Module):
    def __init__(self):
        super(CustomLoss, self).__init__()
        self.mc_loss = nn.CrossEntropyLoss()  # Loss for MC task
        self.da_loss = nn.CrossEntropyLoss()  # Loss for DA task
        # self.ce_loss = nn.CrossEntropyLoss() # overall loss  
        # self.aux_loss = nn.MSELoss()  # Auxiliary loss for alignment
    
    def forward(self, mc_logits, da_logits, mc_labels, da_labels):
        mc_logits = mc_logits[:, 0, :]

        min_seq_length = min(da_logits.size(1), da_labels.size(1))
        da_logits = da_logits[:, :min_seq_length, :].contiguous()
        da_labels = da_labels[:, :min_seq_length].contiguous()
        
        # Flatten da_logits and da_labels for loss computation
        da_logits = da_logits.view(-1, da_logits.size(-1)) # [batch_size * seq_length, vocab_size]
        da_labels = da_labels.view(-1) # [batch_size * seq_length]                    
        
        mc_loss = self.mc_loss(mc_logits, mc_labels)
        da_loss = self.da_loss(da_logits, da_labels)
        total_loss = mc_loss + da_loss
        return total_loss 
    
        # Auxiliary loss to align metadata and question embeddings
        # aux = self.aux_loss(metadata_embeddings, question_embeddings)
        # return ce + 0.1 * aux  # Combine with weighting

In [5]:
coco_dir = os.getenv('COCO_DIR')
aokvqa_dir = os.getenv('AOKVQA_DIR')

train_path = "../results/viper_augmentations/aokvqa_plus_viper_train.json"
val_path = "../results/viper_augmentations/aokvqa_plus_viper_val.json"

training_dataset = load_dataset("json", data_files={"train": train_path}, split="train")
valid_dataset = load_dataset("json", data_files={"val": val_path}, split="val")

print(f"Training set: {len(training_dataset)} samples, Validation set: {len(valid_dataset)} samples")

train_dataset = VQADataset(dataset=training_dataset, processor=processor, coco_dir=coco_dir)
valid_dataset = VQADataset(dataset=valid_dataset, processor=processor, coco_dir=coco_dir)

BATCH_SIZE = 8
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False)



Training set: 11271 samples, Validation set: 774 samples


In [6]:
train_dataset[0]

{'pixel_values': tensor([[[ 1.5070,  1.5216,  1.5800,  ..., -1.4711, -1.3835,  0.2807],
          [ 1.5654,  1.5800,  1.5800,  ..., -1.6463, -0.1280,  0.6165],
          [ 1.6530,  1.6676,  1.6238,  ..., -1.4565, -0.5368,  0.3975],
          ...,
          [ 0.2515,  0.3245, -0.6536,  ...,  0.4559,  0.3245,  0.4267],
          [ 0.0909,  0.2515, -0.5222,  ...,  0.5289,  0.3099,  0.3245],
          [ 0.1347,  0.3245, -0.4930,  ...,  0.3829,  0.3829,  0.2953]],
 
         [[ 1.8047,  1.8047,  1.8498,  ..., -1.3169, -1.2118,  0.9493],
          [ 1.8198,  1.8498,  1.8648,  ..., -1.5420,  0.3490,  1.3995],
          [ 1.9098,  1.8948,  1.8798,  ..., -1.3469, -0.0712,  1.1444],
          ...,
          [ 0.2589,  0.4090, -0.5815,  ...,  0.4691,  0.3790,  0.4240],
          [ 0.1089,  0.3190, -0.4464,  ...,  0.5591,  0.3490,  0.3340],
          [ 0.1989,  0.3940, -0.3864,  ...,  0.4090,  0.4240,  0.3190]],
 
         [[ 2.0890,  2.1032,  2.1032,  ..., -1.0678, -0.9399,  1.5202],
          [ 

In [7]:
NUM_EPOCHS = 3
LR = 5e-4

loss_fn = CustomLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

for epoch in range(NUM_EPOCHS):
    model.train()
    total_loss = 0
    
    for batch in tqdm(train_dataloader, desc=f"Epoch {epoch + 1}/{NUM_EPOCHS}"):
        # Move data to device
        pixel_values = batch["pixel_values"].to(device)
        mc_input_ids = batch["mc_input_ids"].to(device)
        mc_attention_mask = batch["mc_attention_mask"].to(device)
        da_input_ids = batch["da_input_ids"].to(device)
        da_attention_mask = batch["da_attention_mask"].to(device)
        mc_labels = batch["multiple_choice_label"].to(device)
        da_labels = batch["direct_answer_labels"].to(device)

        # Forward pass for MC
        mc_outputs = model(pixel_values=pixel_values, input_ids=mc_input_ids, attention_mask=mc_attention_mask)
        mc_logits = mc_outputs.logits

        # Forward pass for DA
        da_outputs = model(pixel_values=pixel_values, input_ids=da_input_ids, attention_mask=da_attention_mask)
        da_logits = da_outputs.logits
        
        # print(f"mc_logits shape: {mc_logits.shape}, mc_labels shape: {mc_labels.shape}")
        # print(f"da_logits shape: {da_logits.shape}") 
        # print(f"da_labels shape: {da_labels.shape}")
        
        loss = loss_fn(mc_logits=mc_logits, da_logits=da_logits, mc_labels=mc_labels, da_labels=da_labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    print(f"Epoch {epoch}: Loss = {total_loss / len(train_dataloader)}")

Epoch 1/3:   0%|          | 3/1409 [00:08<1:10:16,  3.00s/it]


KeyboardInterrupt: 

In [None]:
model.eval()
val_loss = 0

with torch.no_grad():
    for batch in tqdm(valid_dataloader, desc="Validation"):
        # Move data to device
        pixel_values = batch["pixel_values"].to(device)
        mc_input_ids = batch["mc_input_ids"].to(device)
        mc_attention_mask = batch["mc_attention_mask"].to(device)
        da_input_ids = batch["da_input_ids"].to(device)
        da_attention_mask = batch["da_attention_mask"].to(device)
        mc_labels = batch["multiple_choice_label"].to(device)
        da_labels = batch["direct_answer_labels"].to(device)

        # Forward pass for MC
        mc_outputs = model(pixel_values=pixel_values, input_ids=mc_input_ids, attention_mask=mc_attention_mask)
        mc_logits = mc_outputs.logits

        # Forward pass for DA
        da_outputs = model(pixel_values=pixel_values, input_ids=da_input_ids, attention_mask=da_attention_mask)
        da_logits = da_outputs.logits

        # Compute loss
        loss = loss_fn(mc_logits=mc_logits, da_logits=da_logits, mc_labels=mc_labels, da_labels=da_labels)
        val_loss += loss.item()

avg_val_loss = val_loss / len(valid_dataloader)
print(f"Validation Loss = {avg_val_loss:.4f}")