In [10]:
import os

import requests
from transformers import Blip2Processor, BlipForQuestionAnswering, Blip2Model
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoProcessor
import torch
from PIL import Image
from torch.utils.data import DataLoader
from tqdm import tqdm
import pickle

In [2]:
# Location of the train annotations JSON. The original notebook left this
# blank; both placeholders there referred to the train file, so a single
# variable is used for the two splits (confirm if they were meant to differ).
train_annotations_path = ...  # TODO: path to train annotations

# The two slices must be disjoint: the first 70% of the rows are used for
# training and the remaining 30% for validation. The previous validation
# slice ("train[30%:]") shared 40% of its rows with the training slice.
training_dataset = load_dataset("json", data_files=train_annotations_path, split="train[:70%]")
valid_dataset = load_dataset("json", data_files=train_annotations_path, split="train[70%:]")

In [3]:
import torch 
from torch.utils.data import Dataset 

class VQADataset(Dataset):
    """Map-style dataset yielding ``(question, answer, image)`` triples.

    Every row of the wrapped ``dataset`` is read through the keys
    ``'question'``, ``'answer'`` and ``'image_id'``.
    """

    def __init__(self, dataset, processor):
        """Keep references to the row source and the processor.

        Args:
            dataset: Indexable, sized collection of rows (see class docstring).
            processor: Stored on the instance; this class does not call it.
        """
        self.dataset = dataset
        self.processor = processor

    def __len__(self):
        """Return the number of rows in the wrapped dataset."""
        return len(self.dataset)

    def _image_path(self, image_id):
        """Return the file path of the image identified by ``image_id``.

        The original notebook left this expression blank, so it has to be
        filled in (or overridden in a subclass) before items can be loaded.

        Raises:
            NotImplementedError: Until the mapping is provided.
        """
        # TODO: image path
        raise NotImplementedError(
            "Fill in VQADataset._image_path to map an image_id to an image file path."
        )

    def __getitem__(self, idx):
        """Load one sample.

        Args:
            idx: Index into the wrapped dataset.

        Returns:
            Tuple ``(question, answer, image)`` where ``image`` is the result
            of converting the opened image file to RGB.
        """
        # Fetch the row once instead of once per field.
        row = self.dataset[idx]
        image_path = self._image_path(row['image_id'])
        # The context manager releases the file handle as soon as the pixels
        # have been converted; the converted image no longer needs the file.
        with Image.open(image_path) as image_file:
            image = image_file.convert("RGB")
        return row['question'], row['answer'], image

In [None]:
from transformers import AutoModelForCausalLM, AutoProcessor
import torch

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# trust_remote_code=True allows transformers to execute the modelling code
# published alongside the checkpoint; `revision` pins the exact Hub ref used.
model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Florence-2-base-ft",
    trust_remote_code=True,
    revision='refs/pr/6'
).to(device)
processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base-ft",
    trust_remote_code=True, revision='refs/pr/6')

# Freeze the vision tower. `requires_grad` is the flag autograd checks;
# the previous `is_trainable` assignment only created an unused attribute,
# so the vision tower kept receiving gradient updates.
for param in model.vision_tower.parameters():
    param.requires_grad = False


In [None]:
# Report how many rows ended up in each split.
n_train_rows = len(training_dataset)
n_valid_rows = len(valid_dataset)
print(f"Training sets: {n_train_rows} - Validating set: {n_valid_rows}")

In [6]:
import os 
from torch.utils.data import DataLoader
from tqdm import tqdm 
from transformers import AdamW, get_scheduler

def collate_fn(batch):
    """Turn a list of ``(question, answer, image)`` samples into model inputs.

    The questions and images are passed together through the module-level
    ``processor`` (padded, returned as PyTorch tensors) and the result is
    moved to the module-level ``device``.

    Returns:
        Tuple ``(inputs, answers)``: the processed batch and the tuple of
        answers in the same order as the incoming samples.
    """
    question_batch, answer_batch, image_batch = zip(*batch)
    encoded = processor(
        text=list(question_batch),
        images=list(image_batch),
        return_tensors="pt",
        padding=True,
    )
    return encoded.to(device), answer_batch
# Wrap the raw splits so they yield (question, answer, image) samples.
# The validation wrapper gets its own name: rebinding `valid_dataset` to a
# wrapper around itself made this cell non-idempotent (a second run wrapped
# the wrapper and item access then failed).
train_dataset = VQADataset(dataset=training_dataset,
                           processor=processor)
val_dataset = VQADataset(dataset=valid_dataset,
                         processor=processor)

batch_size = 1
# collate_fn moves tensors to `device`, so batches are built in the main
# process; PyTorch advises against returning CUDA tensors from workers.
num_workers = 0

train_loader = DataLoader(train_dataset, batch_size=batch_size,
                          collate_fn=collate_fn, num_workers=num_workers, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size,
                        collate_fn=collate_fn, num_workers=num_workers)


In [None]:
epochs = 10
# Checkpoint written after training; the evaluation cell loads this same file.
weights_path = "blip/florence_weights.pth"

optimizer = AdamW(model.parameters(), lr=1e-6)
num_training_steps = epochs * len(train_loader)

# Linear decay of the learning rate over all optimizer steps, no warm-up.
lr_scheduler = get_scheduler(name="linear", optimizer=optimizer,
                             num_warmup_steps=0, num_training_steps=num_training_steps,)


def _encode_answers(answers):
    """Tokenize a batch of answer strings into padded label ids on `device`."""
    encoded = processor.tokenizer(text=list(answers), return_tensors="pt",
                                  padding=True, return_token_type_ids=False)
    return encoded.input_ids.to(device)


for epoch in range(epochs):
    # ---- training pass ----
    model.train()
    train_loss = 0
    for inputs, answers in tqdm(train_loader, desc=f"Training Epoch {epoch + 1}/{epochs}"):
        labels = _encode_answers(answers)
        outputs = model(input_ids=inputs["input_ids"],
                        pixel_values=inputs["pixel_values"],
                        labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        train_loss += loss.item()
    avg_train_loss = train_loss / len(train_loader)
    print(f"Average Training Loss: {avg_train_loss}")

    # ---- validation pass (no gradients, no parameter updates) ----
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for inputs, answers in tqdm(val_loader, desc=f"Validation Epoch {epoch + 1}/{epochs}"):
            labels = _encode_answers(answers)
            outputs = model(input_ids=inputs["input_ids"],
                            pixel_values=inputs["pixel_values"],
                            labels=labels)
            val_loss += outputs.loss.item()

        print(val_loss / len(val_loader))

# Persist the fine-tuned weights; without this the checkpoint read by the
# evaluation cell is never produced by the notebook.
os.makedirs(os.path.dirname(weights_path), exist_ok=True)
torch.save(model.state_dict(), weights_path)


In [3]:
import json
# Path of the test questions JSON; left blank in the original notebook.
question_file = ...  # TODO: test questions

# Decode explicitly as UTF-8 so the result does not depend on the platform's
# default text encoding.
with open(question_file, 'r', encoding="utf-8") as input_file:
    questions = json.load(input_file)

# Write an indented copy of the question list to disk (presumably for manual
# inspection). NOTE: the annotations cell writes to the same "file.json".
temp = questions['questions']
with open("file.json", "w") as file:
    # Dump the data into the file as JSON
    json.dump(temp, file, indent=4)

In [4]:
# Parallel lists over the test questions: the question text and the id of
# the image each question refers to, in file order.
question_entries = questions["questions"]
qs = [entry['question'] for entry in question_entries]
imgs = [entry["image_id"] for entry in question_entries]

In [None]:
# Number of question texts, then number of image ids, one per line.
print(len(qs), len(imgs), sep="\n")

In [6]:
import json
# Path of the test annotations JSON; left blank in the original notebook.
annotation_file = ...  # TODO: test annotations

# Decode explicitly as UTF-8 so the result does not depend on the platform's
# default text encoding.
with open(annotation_file, 'r', encoding="utf-8") as input_file:
    annotations = json.load(input_file)

# Write an indented copy of the annotation list to disk (presumably for
# manual inspection). NOTE: this overwrites any existing "file.json".
temp = annotations['annotations']
with open("file.json", "w") as file:
    # Dump the data into the file as JSON
    json.dump(temp, file, indent=4)

In [7]:
# Ground-truth answer for every test annotation, in file order.
answers = [entry['multiple_choice_answer'] for entry in annotations['annotations']]

In [None]:
# Number of ground-truth answers collected from the annotations.
answer_count = len(answers)
print(answer_count)

In [None]:
from PIL import Image
import requests
from transformers import AutoProcessor, BlipForQuestionAnswering
import torch
# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Rebuild the base model and processor, then load the fine-tuned weights.
model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Florence-2-base-ft",
    trust_remote_code=True,
    revision='refs/pr/6'
).to(device)
processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base-ft",
    trust_remote_code=True, revision='refs/pr/6')
# map_location lets a checkpoint saved from a GPU run load on a CPU-only
# machine instead of failing while deserializing CUDA tensors.
model.load_state_dict(torch.load("blip/florence_weights.pth", map_location=device))


def _test_image_path(image_id):
    """Return the file path of the test image identified by ``image_id``.

    The original notebook left this expression blank; it presumably derives
    the path from the ids collected in ``imgs`` (confirm).
    """
    # TODO: image path
    raise NotImplementedError(
        "Fill in _test_image_path to map a test image_id to an image file path."
    )


res = []      # generated answer for each question, in order
correct = 0   # number of exact string matches with the ground truth
with torch.no_grad():  # inference only: do not record autograd state
    for idx, text in enumerate(qs):
        image_path = _test_image_path(imgs[idx])
        # Close the file handle once the pixels have been converted.
        with Image.open(image_path) as image_file:
            image = image_file.convert("RGB")
        inputs = processor(images=image, text=text, return_tensors="pt").to(device)
        preds = model.generate(input_ids=inputs["input_ids"],
                               pixel_values=inputs["pixel_values"],
                               max_new_tokens=1024, num_beams=3)
        generated_text = processor.batch_decode(preds, skip_special_tokens=True)[0]
        res.append(generated_text)
        print(generated_text, answers[idx])
        if generated_text == answers[idx]:
            correct += 1

In [None]:
# Exact-match accuracy over all generated answers. Guard the division so an
# empty evaluation run reports the situation instead of raising
# ZeroDivisionError.
if res:
    print(correct/len(res))
else:
    print("No predictions were generated; accuracy is undefined.")