In [1]:
import os
os.environ['COCO_DIR'] = '/usr1/data/mingqia2/datasets/coco/'
os.environ['AOKVQA_DIR'] = '/usr1/data/mingqia2/aokvqa/'
os.environ['HF_HOME'] = '/usr1/data/models_cache'

import json 
from collections import Counter
from tqdm import tqdm

import torch
from torch import nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from datasets import load_dataset
from PIL import Image

from transformers import Blip2Processor, Blip2ForConditionalGeneration
from transformers import AutoProcessor, AutoModelForVisualQuestionAnswering

from load_aokvqa import load_aokvqa, get_coco_path
import random 

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class VQADataset(Dataset):
    def __init__(self, dataset, processor, coco_dir, max_length=256):
        """
        Args:
            dataset: List of samples with original question, answer, and visual clues.
            processor: BLIP processor for text and image preprocessing.
            image_dir: Path to the directory containing images.
        """
        self.dataset = dataset
        self.processor = processor
        self.coco_dir = coco_dir
        self.max_length = max_length
    
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, idx):
        sample = self.dataset[idx]
        image_path = get_coco_path(sample['split'], sample['image_id'], self.coco_dir)
        image = Image.open(image_path).convert("RGB")
        
        question = sample['question']
        visual_clues = sample.get('viper_gpt', {}).get('viper_question', '') + ' ' + sample.get('viper_gpt', {}).get('viper_response', '')
        rationales = sample.get('rationales', [])
        
        # Format choices into A:xx, B:xx, ... 
        formatted_choices = ', '.join([f"{chr(65 + i)}: {choice}" for i, choice in enumerate(sample['choices'])])
        correct_choice = chr(65 + sample["correct_choice_idx"]) # Convert index to letter
        direct_answers = sample['direct_answers']
        most_frequent_answer = Counter(direct_answers).most_common()[0][0]

        # Prompts for MC and DA tasks
        mc_question = (
            f"Question: {question} \n Visual Clues: {visual_clues} \n Choices: {formatted_choices}. "
            "Please provide a rationale and then return the letter of the correct answer in the format: "
            "'Rationale: [your explanation] \\n Answer: [your answer]'."
        )
        da_question = (
            f"Question: {question} \n Visual Clues: {visual_clues}. "
            "Please provide a rationale and then return the direct answer in the format: "
            "'Rationale: [your explanation] \\n Answer: [your answer]'."
        )
        
        rationale_text = " ".join(rationales)
        mc_output = f"Rationale: {rationale_text} \n Answer: {correct_choice}"
        da_output = f"Rationale: {rationale_text} \n Answer: {most_frequent_answer}"
        # print("MC Output:", mc_output)
        # print("DA Output:", da_output)
        
        # Process image and augmented text
        mc_encoding = self.processor(image, mc_question, padding="max_length", truncation=True, max_length=self.max_length, return_tensors="pt")
        da_encoding = self.processor(image, da_question, padding="max_length", truncation=True, max_length=self.max_length, return_tensors="pt")  
        
        mc_labels = self.processor.tokenizer(
        mc_output, padding="max_length", truncation=True, max_length=self.max_length, return_tensors="pt"
    ).input_ids.squeeze(0)
        # print(f"MC Labels Shape: {mc_labels.shape}")
        
        da_labels = self.processor.tokenizer(
        da_output, padding="max_length", truncation=True, max_length=self.max_length, return_tensors="pt"
    ).input_ids.squeeze(0)
        # print(f"DA Labels Shape: {da_labels.shape}")
    
        # Truncate DA Input IDs to match max_length
        # da_input_ids = da_encoding["input_ids"].squeeze()
        # if da_input_ids.size(0) > self.max_length:
        #     da_input_ids = da_input_ids[:self.max_length]
        # # Truncate DA Attention Mask
        # da_attention_mask = da_encoding["attention_mask"].squeeze()
        # if da_attention_mask.size(0) > self.max_length:
        #     da_attention_mask = da_attention_mask[:self.max_length]
        
        # Ensure no unexpected dimensions in labels
        assert mc_labels.dim() == 1, f"Unexpected MC Labels Shape: {mc_labels.shape}"
        assert da_labels.dim() == 1, f"Unexpected DA Labels Shape: {da_labels.shape}"
           
        return {
            "pixel_values": mc_encoding["pixel_values"].squeeze(0),
            "mc_input_ids": mc_encoding["input_ids"].squeeze(0),
            "mc_attention_mask": mc_encoding["attention_mask"].squeeze(0),
            "mc_labels": mc_labels,
            # "da_input_ids": da_input_ids,
            # "da_attention_mask": da_attention_mask,
            "da_input_ids": da_encoding["input_ids"].squeeze(0),
            "da_attention_mask": da_encoding["attention_mask"].squeeze(0),
            "da_labels": da_labels,
        }
            

In [3]:
def vqa_collate_fn(batch):
    pixel_values = torch.stack([item["pixel_values"] for item in batch])
    mc_input_ids = torch.stack([item["mc_input_ids"] for item in batch])
    mc_attention_mask = torch.stack([item["mc_attention_mask"] for item in batch])
    mc_labels = torch.stack([item["mc_labels"] for item in batch])
    da_input_ids = torch.stack([item["da_input_ids"] for item in batch])
    da_attention_mask = torch.stack([item["da_attention_mask"] for item in batch])
    da_labels = torch.stack([item["da_labels"] for item in batch])
    
    # print("Shapes after collation: ")
    # print(f"Pixel Values Shape: {pixel_values.shape}")
    # print(f"MC Input IDs Shape: {mc_input_ids.shape}")
    # print(f"MC Attention Mask Shape: {mc_attention_mask.shape}")
    # print(f"MC Labels Shape: {mc_labels.shape}")
    # print(f"DA Input IDs Shape: {da_input_ids.shape}")
    # print(f"DA Attention Mask Shape: {da_attention_mask.shape}")
    # print(f"DA Labels Shape: {da_labels.shape}")
    
    return {
        "pixel_values": pixel_values,
        "mc_input_ids": mc_input_ids,
        "mc_attention_mask": mc_attention_mask,
        "mc_labels": mc_labels,
        "da_input_ids": da_input_ids,
        "da_attention_mask": da_attention_mask,
        "da_labels": da_labels,
    }

In [None]:
class CustomLoss(nn.Module):
    def __init__(self, da_weight=1.0, mc_weight=1.0):
        super(CustomLoss, self).__init__()
        self.da_weight = da_weight
        self.mc_weight = mc_weight
        self.loss_fn = nn.CrossEntropyLoss() 
        # self.mc_loss = nn.CrossEntropyLoss()  # Loss for MC task
        # self.da_loss = nn.CrossEntropyLoss()  # Loss for DA task
    
    def forward(self, mc_logits, da_logits, mc_labels, da_labels):

        # print(f"Original MC Logits Shape: {mc_logits.shape}")
        # print(f"Original MC Labels Shape: {mc_labels.shape}")
        # print(f"Original DA Logits Shape: {da_logits.shape}")
        # print(f"Original DA Labels Shape: {da_labels.shape}")
    
        mc_logits = mc_logits.contiguous().view(-1, mc_logits.size(-1))  # Shape: [batch_size * seq_len, vocab_size]
        mc_labels = mc_labels.contiguous().view(-1)                   # Shape: [batch_size * seq_len]

        da_logits = da_logits.contiguous().view(-1, da_logits.size(-1))  # Shape: [batch_size * seq_len, vocab_size]
        da_labels = da_labels.contiguous().view(-1)                    # Shape: [batch_size * seq_len]
        
        # print(f"Flattened MC Logits Shape: {mc_logits.shape}")
        # print(f"Flattened MC Labels Shape: {mc_labels.shape}")
        # print(f"Flattened DA Logits Shape: {da_logits.shape}")
        # print(f"Flattened DA Labels Shape: {da_labels.shape}")

        mc_loss = self.loss_fn(mc_logits, mc_labels)
        da_loss = self.loss_fn(da_logits, da_labels)
        # total_loss = mc_loss + da_loss
        total_loss = self.mc_weight * mc_loss + self.da_weight * da_loss
        return total_loss, mc_loss, da_loss 

In [5]:
def train_model(model, train_dataloader, processor, valid_dataloader, epochs, learning_rate, device, da_weight=1.0, mc_weight=1.0):
        optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
        loss_fn = CustomLoss(da_weight=da_weight, mc_weight=mc_weight)
        
        for epoch in range(epochs):
            model.train()
            total_loss, total_mc_loss, total_da_loss = 0, 0, 0
            
            for batch in tqdm(train_dataloader, desc=f"Epoch {epoch + 1}/{epochs}"):
                pixel_values = batch["pixel_values"].to(device)
                print(f"Pixel Values Shape: {pixel_values.shape}")
                
                # MC task
                mc_input_ids = batch["mc_input_ids"].to(device)
                mc_attention_mask = batch["mc_attention_mask"].to(device)
                mc_labels = batch["mc_labels"].to(device)

                mc_outputs = model(
                    pixel_values=pixel_values, 
                    input_ids=mc_input_ids, 
                    attention_mask=mc_attention_mask, 
                    # labels=mc_labels
                )

                # DA task
                da_input_ids = batch["da_input_ids"].to(device)
                da_attention_mask = batch["da_attention_mask"].to(device)
                da_labels = batch["da_labels"].to(device)

                # print(f"DA Input IDs Shape: {da_input_ids.shape}")
                # print(f"DA Attention Mask Shape: {da_attention_mask.shape}")
                # print(f"DA Labels Shape: {da_labels.shape}")

                da_outputs = model(
                    pixel_values=pixel_values, 
                    input_ids=da_input_ids, 
                    attention_mask=da_attention_mask, 
                    # labels=da_labels
                )
                # print(f"DA Logits Shape: {da_outputs.logits.shape}")
                
                mc_logits = mc_outputs.logits[:, :mc_labels.size(1), :]  # Shape: [batch_size, label_seq_len, vocab_size]
                da_logits = da_outputs.logits[:, :da_labels.size(1), :]  # Shape: [batch_size, label_seq_len, vocab_size]
                # Compute weighted losses
                loss, mc_loss, da_loss = loss_fn(mc_logits, da_logits, mc_labels, da_labels)

                # Backpropagation
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                total_mc_loss += mc_loss.item()
                total_da_loss += da_loss.item()
            
            print(f"Epoch {epoch + 1}/{epochs}: Total Loss = {total_loss / len(train_dataloader):.4f}, "
                f"MC Loss = {total_mc_loss / len(train_dataloader):.4f}, "
                f"DA Loss = {total_da_loss / len(train_dataloader):.4f}")
            
            # Evaluate after each epoch
            # evaluate_model(model, processor, valid_dataloader, device, loss_fn)

In [6]:
NUM_EPOCHS = 3
LR = 5e-4
DA_WEIGHT = 1.0
MC_WEIGHT = 1.0

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load BLIP-2 model and processor
processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", device_map="auto")

coco_dir = os.getenv('COCO_DIR')
aokvqa_dir = os.getenv('AOKVQA_DIR')

train_path = "../results/viper_augmentations/aokvqa_plus_viper_train.json"
val_path = "../results/viper_augmentations/aokvqa_plus_viper_val.json"

training_dataset = load_dataset("json", data_files={"train": train_path}, split="train")
valid_dataset = load_dataset("json", data_files={"val": val_path}, split="val")

print(f"Training set: {len(training_dataset)} samples, Validation set: {len(valid_dataset)} samples")

train_dataset = VQADataset(dataset=training_dataset, processor=processor, coco_dir=coco_dir)
valid_dataset = VQADataset(dataset=valid_dataset, processor=processor, coco_dir=coco_dir)

BATCH_SIZE = 8
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=vqa_collate_fn)
valid_dataloader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=vqa_collate_fn)


# Train the model
train_model(
    model=model,
    train_dataloader=train_dataloader,
    valid_dataloader=valid_dataloader,
    processor=processor,
    epochs=NUM_EPOCHS,
    learning_rate=LR,
    device=device,
    da_weight=DA_WEIGHT,
    mc_weight=MC_WEIGHT,
    )

# Save the model
model.save_pretrained("./trained_model")
processor.save_pretrained("./trained_model")
print("Model and processor saved to './trained_model'")




Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.01it/s]


Training set: 11271 samples, Validation set: 774 samples


Epoch 1/3:   0%|          | 0/1409 [00:00<?, ?it/s]

Pixel Values Shape: torch.Size([8, 3, 224, 224])
Original MC Logits Shape: torch.Size([8, 256, 50304])
Original MC Labels Shape: torch.Size([8, 256])
Original DA Logits Shape: torch.Size([8, 256, 50304])
Original DA Labels Shape: torch.Size([8, 256])
Flattened MC Logits Shape: torch.Size([2048, 50304])
Flattened MC Labels Shape: torch.Size([2048])
Flattened DA Logits Shape: torch.Size([2048, 50304])
Flattened DA Labels Shape: torch.Size([2048])


Epoch 1/3:   0%|          | 1/1409 [00:04<1:38:07,  4.18s/it]

Pixel Values Shape: torch.Size([8, 3, 224, 224])
Original MC Logits Shape: torch.Size([8, 256, 50304])
Original MC Labels Shape: torch.Size([8, 256])
Original DA Logits Shape: torch.Size([8, 256, 50304])
Original DA Labels Shape: torch.Size([8, 256])
Flattened MC Logits Shape: torch.Size([2048, 50304])
Flattened MC Labels Shape: torch.Size([2048])
Flattened DA Logits Shape: torch.Size([2048, 50304])
Flattened DA Labels Shape: torch.Size([2048])


Epoch 1/3:   0%|          | 2/1409 [00:06<1:16:41,  3.27s/it]

Pixel Values Shape: torch.Size([8, 3, 224, 224])
Original MC Logits Shape: torch.Size([8, 256, 50304])
Original MC Labels Shape: torch.Size([8, 256])
Original DA Logits Shape: torch.Size([8, 256, 50304])
Original DA Labels Shape: torch.Size([8, 256])
Flattened MC Logits Shape: torch.Size([2048, 50304])
Flattened MC Labels Shape: torch.Size([2048])
Flattened DA Logits Shape: torch.Size([2048, 50304])
Flattened DA Labels Shape: torch.Size([2048])


Epoch 1/3:   0%|          | 2/1409 [00:09<1:46:58,  4.56s/it]


KeyboardInterrupt: 

In [None]:

def evaluate_model(model, dataloader, device, loss_fn):
    model.eval()
    total_loss = 0
    total_mc_loss = 0
    total_da_loss = 0

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            pixel_values = batch["pixel_values"].to(device)

            # MC task
            mc_input_ids = batch["mc_input_ids"].to(device)
            mc_attention_mask = batch["mc_attention_mask"].to(device)
            mc_labels = batch["mc_labels"].to(device)

            mc_outputs = model(
                pixel_values=pixel_values, 
                input_ids=mc_input_ids, 
                attention_mask=mc_attention_mask, 
                labels=mc_labels
            )

            # DA task
            da_input_ids = batch["da_input_ids"].to(device)
            da_attention_mask = batch["da_attention_mask"].to(device)
            da_labels = batch["da_labels"].to(device)

            da_outputs = model(
                pixel_values=pixel_values, 
                input_ids=da_input_ids, 
                attention_mask=da_attention_mask, 
                labels=da_labels
            )

            # Compute weighted losses
            loss, mc_loss, da_loss = loss_fn(mc_outputs.logits, mc_labels, da_outputs.logits, da_labels)

            total_loss += loss.item()
            total_mc_loss += mc_loss.item()
            total_da_loss += da_loss.item()

    avg_loss = total_loss / len(dataloader)
    avg_mc_loss = total_mc_loss / len(dataloader)
    avg_da_loss = total_da_loss / len(dataloader)

    print(f"Validation: Total Loss = {avg_loss:.4f}, MC Loss = {avg_mc_loss:.4f}, DA Loss = {avg_da_loss:.4f}")

In [None]:
# NUM_EPOCHS = 3
# LR = 5e-4

# loss_fn = CustomLoss()
# optimizer = torch.optim.Adam(model.parameters(), lr=LR)

# for epoch in range(NUM_EPOCHS):
#     model.train()
#     total_loss = 0
    
#     for batch in tqdm(train_dataloader, desc=f"Epoch {epoch + 1}/{NUM_EPOCHS}"):
#         # Move data to device
#         pixel_values = batch["pixel_values"].to(device)
#         mc_input_ids = batch["mc_input_ids"].to(device)
#         mc_attention_mask = batch["mc_attention_mask"].to(device)
#         da_input_ids = batch["da_input_ids"].to(device)
#         da_attention_mask = batch["da_attention_mask"].to(device)
#         mc_labels = batch["multiple_choice_label"].to(device)
#         da_labels = batch["direct_answer_labels"].to(device)

#         # Forward pass for MC
#         mc_outputs = model(pixel_values=pixel_values, input_ids=mc_input_ids, attention_mask=mc_attention_mask)
#         mc_logits = mc_outputs.logits

#         # Forward pass for DA
#         da_outputs = model(pixel_values=pixel_values, input_ids=da_input_ids, attention_mask=da_attention_mask)
#         da_logits = da_outputs.logits
        
#         # print(f"mc_logits shape: {mc_logits.shape}, mc_labels shape: {mc_labels.shape}")
#         # print(f"da_logits shape: {da_logits.shape}") 
#         # print(f"da_labels shape: {da_labels.shape}")
        
#         loss = loss_fn(mc_logits=mc_logits, da_logits=da_logits, mc_labels=mc_labels, da_labels=da_labels)

#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()
        
#         total_loss += loss.item()
    
#     print(f"Epoch {epoch}: Loss = {total_loss / len(train_dataloader)}")

Epoch 1/3:   0%|          | 3/1409 [00:08<1:10:16,  3.00s/it]


KeyboardInterrupt: 

In [None]:
# model.eval()
# val_loss = 0

# with torch.no_grad():
#     for batch in tqdm(valid_dataloader, desc="Validation"):
#         # Move data to device
#         pixel_values = batch["pixel_values"].to(device)
#         mc_input_ids = batch["mc_input_ids"].to(device)
#         mc_attention_mask = batch["mc_attention_mask"].to(device)
#         da_input_ids = batch["da_input_ids"].to(device)
#         da_attention_mask = batch["da_attention_mask"].to(device)
#         mc_labels = batch["multiple_choice_label"].to(device)
#         da_labels = batch["direct_answer_labels"].to(device)

#         # Forward pass for MC
#         mc_outputs = model(pixel_values=pixel_values, input_ids=mc_input_ids, attention_mask=mc_attention_mask)
#         mc_logits = mc_outputs.logits

#         # Forward pass for DA
#         da_outputs = model(pixel_values=pixel_values, input_ids=da_input_ids, attention_mask=da_attention_mask)
#         da_logits = da_outputs.logits

#         # Compute loss
#         loss = loss_fn(mc_logits=mc_logits, da_logits=da_logits, mc_labels=mc_labels, da_labels=da_labels)
#         val_loss += loss.item()

# avg_val_loss = val_loss / len(valid_dataloader)
# print(f"Validation Loss = {avg_val_loss:.4f}")