In [1]:
# !pip install salesforce-lavis

In [2]:
import pandas as pd
import numpy as np

import torch
from PIL import Image
from lavis.models import load_model_and_preprocess

In [3]:
root_path = '/home/smanduru/CS682Project/data/CUB_200_2011'

images_path = root_path + '/images'
attributes_path = root_path + '/attributes/'

In [4]:
images = []

with open(root_path + "/images.txt", "r") as file:
    
    for line in file:
        folder, image = line.strip().split('/')
        index, folder = folder.split(" ")
        images.append(images_path + '/' + folder + '/' + image)

In [5]:
len(images)

11788

In [6]:
data = pd.read_csv(attributes_path + "image_attribute_labels.txt", 
                   names=['img_index', 'attribute_index', 'attribute_value', 'certainity', 'unknown'],
                   delimiter=' ', on_bad_lines='warn')

In [7]:
data

Unnamed: 0,img_index,attribute_index,attribute_value,certainity,unknown
0,1,1,0,3,27.708
1,1,2,0,3,27.708
2,1,3,0,3,27.708
3,1,4,0,3,27.708
4,1,5,1,3,27.708
...,...,...,...,...,...
3677851,11788,308,1,4,4.989
3677852,11788,309,0,4,8.309
3677853,11788,310,0,4,8.309
3677854,11788,311,0,4,8.309


In [8]:
ranges = [(153, 167), (10, 24), (198, 212)] #, (264, 278)] # Forehead, Wing, Belly, Leg

df = data[data['attribute_value'] == 1]
df = df[df['attribute_index'].isin(sum([list(range(start, end + 1)) for start, end in ranges], []))]

In [9]:
df = df[['img_index', 'attribute_index', 'attribute_value']].reset_index(drop=True)

In [10]:
df

Unnamed: 0,img_index,attribute_index,attribute_value
0,1,165,1
1,2,15,1
2,2,158,1
3,2,166,1
4,2,210,1
...,...,...,...
49570,11788,164,1
49571,11788,165,1
49572,11788,203,1
49573,11788,204,1


In [11]:
attributes = {}
with open('/home/smanduru/CS682Project/data/attributes.txt' , "r") as file:
    
    for line in file:
        
        attribute_index, attribute = line.split()
        attribute_qsn, attribute_value = attribute.split('::')
        attribute_qsn = attribute_qsn.split('_', 1)[-1]
        attribute_qsn = attribute_qsn.replace('_', ' ')
        
        # print(attribute_index, attribute_qsn, attribute_value)
        attributes[attribute_index] = [int(attribute_index), attribute_qsn, attribute_value] 

In [12]:
att_df = pd.DataFrame.from_dict(attributes, orient='index', 
                                columns=['attribute_index', 'attribute_qsn', 'attribute_answer'])

In [13]:
att_df

Unnamed: 0,attribute_index,attribute_qsn,attribute_answer
1,1,bill shape,curved_(up_or_down)
2,2,bill shape,dagger
3,3,bill shape,hooked
4,4,bill shape,needle
5,5,bill shape,hooked_seabird
...,...,...,...
308,308,crown color,buff
309,309,wing pattern,solid
310,310,wing pattern,spotted
311,311,wing pattern,striped


In [14]:
merged_df = df.merge(att_df, how='inner', on='attribute_index')
merged_df = merged_df.sort_values(by=['img_index', 'attribute_index']).reset_index(drop = True)
merged_df

Unnamed: 0,img_index,attribute_index,attribute_value,attribute_qsn,attribute_answer
0,1,165,1,forehead color,white
1,2,15,1,wing color,grey
2,2,158,1,forehead color,grey
3,2,166,1,forehead color,red
4,2,210,1,belly color,white
...,...,...,...,...,...
49570,11788,164,1,forehead color,black
49571,11788,165,1,forehead color,white
49572,11788,203,1,belly color,grey
49573,11788,204,1,belly color,yellow


In [15]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

answer_candidates = ["blue", "brown", "iridescent", "purple",
                     "rufous", "grey", "yellow", "olive",
                     "green", "pink", "orange", "black",
                     "white", "red", "buff"]

In [16]:
model, vis_processors, txt_processors = load_model_and_preprocess(name="albef_vqa",
                                                                  model_type="vqav2",
                                                                  is_eval=True,
                                                                  device=device)

In [17]:
# for row in merged_df.iterrows():
    
#     img_path = images[int(row[1]['img_index']) - 1]
#     raw_image = Image.open(img_path).convert("RGB")
    
#     # display(raw_image.resize((596, 437)))
#     question = f"What is the {row[1]['attribute_qsn']} of the bird?"
    
#     image = vis_processors["eval"](raw_image).unsqueeze(0).to(device)
#     question = txt_processors["eval"](question)

#     samples = {"image": image, "text_input": question}
    
#     print(model.predict_answers(samples, answer_list=answer_candidates, inference_method="rank"))
#     print(row[1]['attribute_answer'])
#     break

In [18]:
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.transforms.functional import to_pil_image

In [19]:
class MyDataset(Dataset):
    def __init__(self, images, image_indexes, questions, answers, transform=None):
        self.images = images
        self.questions = questions
        self.transform = transform
        self.image_indexes = image_indexes
        self.answers = answers

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        
        img_index = self.image_indexes[idx]
        img_path = images[int(img_index) - 1]
        img = Image.open(img_path).convert("RGB")
        if self.transform:
            img = self.transform(img)
        
        question_part = self.questions[idx]
        question = f"What is the {question_part} of the bird?"
        
        answer = self.answers[idx]
        
        return {'image': img, 'question': question, 'answer': answer}

In [20]:
image_indexes = merged_df['img_index'].to_list()  # List of image paths
questions = merged_df['attribute_qsn'].to_list() # List of questions
answers = merged_df['attribute_answer'].to_list()

# Define image transformations (replace with your actual transformations)
transform = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])

# Create an instance of the dataset
dataset = MyDataset(images, image_indexes, questions, answers, transform)

In [21]:
batch_size = 32
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [22]:
len(dataloader)

369

In [23]:
total_correct = total_samples = 0

for batch in dataloader:
    images_batch = batch['image']
    questions_batch = batch['question']
    
    # pil_images_batch = [to_pil_image(image) for image in images_batch]
    
    # Process each image and question in the batch
    processed_images_batch = [vis_processors["eval"](to_pil_image(image)).unsqueeze(0).to(device) for image in images_batch]
    processed_questions_batch = [txt_processors["eval"](question) for question in questions_batch]
    
    # print(processed_images_batch, processed_questions_batch)
    concatenated_images = torch.stack(processed_images_batch)
    reshaped_images_batch = torch.reshape(concatenated_images, (len(processed_images_batch), 3, 384, 384))

    
    # Use the model to predict answers
    predictions = model.predict_answers(samples={"image": reshaped_images_batch,
                                                 "text_input": processed_questions_batch},
                                        answer_list=answer_candidates,
                                        inference_method="rank")
        
    # Extract ground truth answers from the batch
    ground_truth_answers = batch['answer']
    
    # Calculate accuracy for this batch
    correct_predictions = sum([1 for pred, truth in zip(predictions, ground_truth_answers) if pred == truth])
    total_correct += correct_predictions
    total_samples += len(processed_images_batch)
    
    
# Calculate overall accuracy
accuracy = total_correct / total_samples

In [24]:
accuracy

0.4165252799457075