<a href="https://www.kaggle.com/code/aisuko/visual-question-answering?scriptVersionId=164950683" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

# Overview

Visual Question Answering (VQA) is the task of answering open-ended questions based on an image. The input to a model supporting this task is typically a combination of an image and a question, and the output is an answer expressed in natural language. Some noteworthy use cases for VQA include:

- `Accessibility:` applications for visually impaired individuals.
- `Education:` posing questions about visual materials presented in lectures or textbooks. VQA can also be used in interactive museum exhibits or at historical sites.
- `Customer service and e-commerce:` VQA can enhance the user experience by letting users ask questions about products.
- `Image retrieval:` VQA models can be used to retrieve images with specific characteristics. For example, a user can ask "Is there a dog?" to find all images containing dogs in a set of images.

Let's fine-tune **a classification** VQA model, specifically ViLT (Vision-and-Language Transformer Without Convolution or Region Supervision), on the VQA dataset.

# Fine-tuning ViLT

The ViLT model incorporates text embeddings into a Vision Transformer (ViT), which gives it a minimal design for Vision-and-Language Pre-training (VLP). The model can be used for several downstream tasks. For the VQA task, a classifier head (a linear layer on top of the final hidden state of the `[CLS]` token) is placed on top of the model and randomly initialized.

Visual Question Answering is thus treated as a **classification problem**, where the model chooses among a fixed set of answers collected from the training annotations. More recent models, such as BLIP, BLIP-2, and InstructBLIP, instead treat VQA as a generative task.
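To make the classification framing concrete, here is a minimal PyTorch sketch of a head that maps the `[CLS]` hidden state to one logit per candidate answer. `ToyVQAHead`, the hidden size of 768, and the 10-answer vocabulary are illustrative assumptions; this is a simplified illustration, not the exact layer stack inside `ViltForQuestionAnswering`.

```python
import torch
from torch import nn


class ToyVQAHead(nn.Module):
    """Hypothetical classification head: one logit per answer in a fixed vocabulary."""

    def __init__(self, hidden_size: int, num_answers: int):
        super().__init__()
        self.classifier = nn.Linear(hidden_size, num_answers)

    def forward(self, last_hidden_state: torch.Tensor) -> torch.Tensor:
        # last_hidden_state: (batch, sequence_length, hidden_size)
        cls_state = last_hidden_state[:, 0]  # the [CLS] token sits at position 0
        return self.classifier(cls_state)    # (batch, num_answers)


head = ToyVQAHead(hidden_size=768, num_answers=10)
fake_hidden = torch.randn(2, 50, 768)  # random stand-in for encoder outputs
logits = head(fake_hidden)
print(logits.shape)  # torch.Size([2, 10])
```

Because the output size is fixed by the answer vocabulary, such a model can never produce an answer it did not see during training; a generative model decodes the answer token by token instead.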

```python
%%capture
!pip install transformers==4.35.2
!pip install datasets==2.15.0
```

```python
import os
from huggingface_hub import login
from kaggle_secrets import UserSecretsClient

user_secrets = UserSecretsClient()

# Authenticate with the Hugging Face Hub (needed to push the fine-tuned model)
login(token=user_secrets.get_secret("HUGGINGFACE_TOKEN"))

# Weights & Biases settings picked up by the Trainer's wandb integration
os.environ["WANDB_API_KEY"] = user_secrets.get_secret("WANDB_API_KEY")
os.environ["WANDB_PROJECT"] = "Fine-tuning vilt-b32-mlm"
os.environ["WANDB_NOTES"] = "Fine-tune ViLT (dandelin/vilt-b32-mlm) for visual question answering"
os.environ["WANDB_NAME"] = "ft-vilt-b32-mlm"
# Pretrained checkpoint to start from
os.environ["MODEL_NAME"] = "dandelin/vilt-b32-mlm"
```

# Loading the data

Feel free to load a different amount of data by changing the `split` parameter; for example, `validation[:500]` takes the first 500 examples of the validation split.

```python
from datasets import load_dataset

dataset = load_dataset("aisuko/vqa", split="validation[:500]")
dataset
```

```python
dataset[0]
```

Only the following features are relevant to the task, so the others can be removed:
- `question`: the question to be answered from the image
- `image_id`: the path to the image the question refers to
- `label`: the annotations

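```python
# Optional: the unused columns are also dropped later by the `remove_columns` argument of `map`,
# so this line stays commented out to avoid removing them twice.
# dataset = dataset.remove_columns(['question_type', 'question_id', 'answer_type'])
```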

The `label` feature contains several answers to the same question (called `ids` here), collected from different human annotators, because the answer to a question can be subjective. In this case, the question is "where is he looking?".

Some people annotated it with "down", others with "at table", another one with "skateboard", etc.

```python
from PIL import Image

image = Image.open(dataset[0]['image_id'])
image
```

Because questions and answers can be ambiguous, datasets like this are treated as a multi-label classification problem (as multiple answers may be valid). Moreover, rather than creating a one-hot encoded vector, one creates a soft encoding based on the number of times a certain answer appears in the annotations. For instance, in the example above, the answer "down" is selected far more often than the other answers, so it has a score (stored in `weights` in the dataset) of 1.0, while the remaining answers have scores below 1.0. To later instantiate the model with an appropriate classification head, let's create two dictionaries: one that maps each label name to an integer and one that maps the integer back to the label name:

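```python
import itertools

# Collect every answer string that appears in the annotations
labels = [item['ids'] for item in dataset['label']]
flattened_labels = list(itertools.chain(*labels))
# Sort so the label <-> id mapping is identical on every run (set iteration order is not)
unique_labels = sorted(set(flattened_labels))

label2id = {label: idx for idx, label in enumerate(unique_labels)}
id2label = {idx: label for label, idx in label2id.items()}

def replace_ids(inputs):
    # Replace answer strings with their integer ids
    inputs["label"]["ids"] = [label2id[x] for x in inputs["label"]["ids"]]
    return inputs


dataset = dataset.map(replace_ids)
# Flatten the nested `label` struct into `label.ids` and `label.weights` columns
flat_dataset = dataset.flatten()
flat_dataset.features
```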
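The soft scores are already provided in the dataset's `weights`, so nothing needs to be computed here. Purely to illustrate how such scores can be derived from raw annotator answers, the sketch below assumes the VQA-style rule score = min(1, count / 3), under which an answer given by three or more annotators counts as fully correct. The `soft_scores` helper and the answer counts are hypothetical, and the dataset's `weights` may have been produced with a different rule.

```python
from collections import Counter


def soft_scores(annotator_answers):
    """Hypothetical helper: map raw answers to soft scores using min(1, count / 3)."""
    counts = Counter(annotator_answers)
    return {answer: min(1.0, count / 3) for answer, count in counts.items()}


# Invented annotations for "where is he looking?" (ten annotators)
answers = ["down"] * 7 + ["at table"] * 2 + ["skateboard"]
scores = soft_scores(answers)
print(scores)  # {'down': 1.0, 'at table': 0.6666666666666666, 'skateboard': 0.3333333333333333}
```

The capping at 1.0 is what makes the encoding "soft" rather than one-hot: several answers keep a non-zero score, but a clear majority answer gets the maximum.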

# Preprocessing data

Here we load a ViLT processor to prepare the image and text data for the model. **ViltProcessor** wraps a BERT tokenizer and a ViLT image processor into a single convenient processor:

```python
from transformers import ViltProcessor

processor = ViltProcessor.from_pretrained(os.getenv('MODEL_NAME'))
```

To preprocess the data, we need to encode the images and questions with the **ViltProcessor**. The processor uses **BertTokenizerFast** to tokenize the text and create `input_ids`, `attention_mask` and `token_type_ids` for the text data. For images, it uses **ViltImageProcessor** to resize and normalize each image and to create `pixel_values` and `pixel_mask`.
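As we understand it, ViltImageProcessor resizes images while keeping their aspect ratio, so images in a batch can end up with different sizes; they are then padded to a common size, and `pixel_mask` marks which positions hold real pixels. The sketch below illustrates that idea in plain PyTorch. `pad_and_mask` is a hypothetical helper written for this explanation, not part of transformers, and it assumes padding is added at the bottom and right.

```python
import torch


def pad_and_mask(images):
    """Pad (C, H, W) tensors to the largest H and W in the batch and build a pixel mask."""
    max_h = max(img.shape[1] for img in images)
    max_w = max(img.shape[2] for img in images)
    channels = images[0].shape[0]
    pixel_values = torch.zeros(len(images), channels, max_h, max_w)
    pixel_mask = torch.zeros(len(images), max_h, max_w, dtype=torch.long)
    for i, img in enumerate(images):
        _, h, w = img.shape
        pixel_values[i, :, :h, :w] = img  # real pixels go in the top-left corner
        pixel_mask[i, :h, :w] = 1         # 1 = real pixel, 0 = padding
    return pixel_values, pixel_mask


pv, pm = pad_and_mask([torch.ones(3, 4, 6), torch.ones(3, 5, 3)])
print(pv.shape, pm.shape)  # torch.Size([2, 3, 5, 6]) torch.Size([2, 5, 6])
print(pm[1].sum().item())  # 15
```

The mask lets the model ignore padded positions, the same role `attention_mask` plays for padded text tokens.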

All these preprocessing steps happen under the hood; we only need to call the processor. However, we still need to prepare the target labels. Each target is a vector with one element per possible answer (label): for the correct answers, the element holds the respective score (weight), while the remaining elements are set to zero. The following function applies the processor to the images and questions and formats the labels as described above.

We apply the preprocessing function over the entire dataset. Setting `batched=True` speeds up `map` by processing multiple elements of the dataset at once. We also remove the columns we no longer need:

```python
import torch

def preprocess_data(examples):
    image_paths = examples['image_id']
    # Convert to RGB so grayscale images also have three channels
    images = [Image.open(image_path).convert("RGB") for image_path in image_paths]
    texts = examples['question']

    # Tokenize the questions and resize/normalize/pad the images in one call.
    # The tensors keep their batch dimension, which batched `map` expects.
    encoding = processor(images, texts, padding="max_length", truncation=True, return_tensors="pt")

    # One soft target vector per example: the score at each annotated answer id, 0 elsewhere
    targets = []
    for labels, scores in zip(examples['label.ids'], examples['label.weights']):
        target = torch.zeros(len(id2label))
        for label, score in zip(labels, scores):
            target[label] = score
        targets.append(target)
    encoding["labels"] = targets

    return encoding

processed_dataset = flat_dataset.map(
    preprocess_data,
    batched=True,
    remove_columns=['question', 'question_type', 'question_id', 'image_id', 'answer_type', 'label.ids', 'label.weights'],
)
processed_dataset
```
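With soft targets, every answer becomes an independent yes/no decision with a target between 0 and 1. A common loss for this kind of multi-label target is binary cross-entropy with logits; we assume, without having checked for this notebook, that it is close to what `ViltForQuestionAnswering` computes internally when `labels` are passed, so consult the model source for the exact formula. The sketch builds one target the same way `preprocess_data` does and scores it against an untrained head; the vocabulary size, ids, and weights are hypothetical.

```python
import math

import torch
import torch.nn.functional as F

num_labels = 5                  # hypothetical answer vocabulary size
ids, weights = [0, 3], [1.0, 0.3]  # hypothetical annotated answer ids and their soft scores

target = torch.zeros(num_labels)
target[ids] = torch.tensor(weights)  # same effect as the inner loop in preprocess_data

# An untrained head that outputs 0 for every answer predicts probability 0.5 everywhere,
# so each element contributes log(2) to the loss regardless of its target value.
logits = torch.zeros(1, num_labels)
loss = F.binary_cross_entropy_with_logits(logits, target.unsqueeze(0))
print(target.tolist())
print(loss.item(), math.log(2))  # both ≈ 0.6931
```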

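# Creating a Batch of Data

We create batches of examples using **DefaultDataCollator**, which turns the already-padded features into stacked tensors without applying any additional padding.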

```python
from transformers import DefaultDataCollator

data_collator = DefaultDataCollator()
```

# Training

```python
from transformers import ViltForQuestionAnswering

# The VQA classifier head is sized to our answer vocabulary and randomly initialized
model = ViltForQuestionAnswering.from_pretrained(os.getenv('MODEL_NAME'), num_labels=len(id2label), id2label=id2label, label2id=label2id, device_map='auto')
print(model.config)
```

```python
from transformers import TrainingArguments, Trainer

training_args = TrainingArguments(
    output_dir=os.getenv("WANDB_NAME"),
    per_device_train_batch_size=4,
    num_train_epochs=5,
    save_steps=200,
    logging_steps=50,
    learning_rate=5e-5,
    save_total_limit=2,  # keep only the two most recent checkpoints on disk
    remove_unused_columns=False,
    push_to_hub=False,
    fp16=True,  # mixed-precision training; requires a GPU
    report_to="wandb",  # or report_to="tensorboard"
    run_name=os.getenv("WANDB_NAME"),
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=processed_dataset,
    tokenizer=processor,  # saved alongside the model so it can be reloaded later
)

trainer.train()
```
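To see how the `TrainingArguments` above interact, here is a back-of-the-envelope calculation. It assumes a single GPU, no gradient accumulation, the 500-example split loaded earlier, and no extra checkpoint written at the very end of training; adjust the numbers if any of these differ in your setup.

```python
import math

num_examples = 500  # size of the "validation[:500]" split
batch_size = 4      # per_device_train_batch_size, assuming one device
num_epochs = 5
save_steps, logging_steps, save_total_limit = 200, 50, 2

steps_per_epoch = math.ceil(num_examples / batch_size)  # the last batch may be smaller
total_steps = steps_per_epoch * num_epochs
checkpoint_steps = list(range(save_steps, total_steps + 1, save_steps))
kept_checkpoints = checkpoint_steps[-save_total_limit:]  # older checkpoints are deleted
num_log_points = total_steps // logging_steps

print(steps_per_epoch, total_steps)        # 125 625
print(checkpoint_steps, kept_checkpoints)  # [200, 400, 600] [400, 600]
print(num_log_points)                      # 12
```

Under these assumptions, the final 25 steps after step 600 are not captured by a periodic checkpoint, which is one reason to save or push the final model explicitly after training.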

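```python
processor.push_to_hub(os.getenv("WANDB_NAME"))
# Trainer.push_to_hub takes a commit message, not a repo name; the repo name comes from output_dir.
# It also saves the final model and processor to output_dir, which the inference cells load from.
trainer.push_to_hub(commit_message="End of training")
```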

# Inference

```python
from transformers import pipeline

pipe = pipeline("visual-question-answering", model=os.getenv("WANDB_NAME"), device_map='auto')
```

```python
example = dataset[0]
image = Image.open(example['image_id'])
question = example['question']

print(question)
```

```python
pipe(image, question, top_k=1)
```

You can also replicate the results of the pipeline manually:

1. Take an image and a question and prepare them for the model using the model's processor.
2. Forward the preprocessed inputs through the model.
3. From the logits, get the id of the most likely answer and look up the actual answer in `id2label`.

```python
processor = ViltProcessor.from_pretrained(os.getenv("WANDB_NAME"))
```

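```python
image = Image.open(example['image_id']).convert("RGB")
question = example['question']

# Step 1: prepare the inputs with the fine-tuned processor
inputs = processor(image, question, return_tensors="pt")

model = ViltForQuestionAnswering.from_pretrained(os.getenv("WANDB_NAME"))

# Step 2: forward pass without tracking gradients
with torch.no_grad():
    outputs = model(**inputs)

# Step 3: the highest-scoring logit gives the predicted answer id
logits = outputs.logits
idx = logits.argmax(-1).item()
print("Predicted answer:", model.config.id2label[idx])
```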
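Step 3 can be extended to return several candidate answers with scores, similar in spirit to the `top_k` argument of the pipeline call. Since each answer is scored independently in this multi-label setup, one reasonable choice is a sigmoid per answer followed by `topk`. The four-answer vocabulary and the logits below are made up for illustration; with the fine-tuned model you would use `outputs.logits` and `model.config.id2label` instead.

```python
import torch

# Hypothetical answer vocabulary and logits for a single question
id2label = {0: "down", 1: "at table", 2: "skateboard", 3: "up"}
logits = torch.tensor([[3.0, 0.5, -1.0, -2.0]])

probs = logits.sigmoid()[0]  # independent per-answer probabilities (multi-label)
scores, ids = probs.topk(2)  # the two highest-scoring answers
predictions = [
    {"answer": id2label[i.item()], "score": round(s.item(), 4)}
    for s, i in zip(scores, ids)
]
print(predictions)  # [{'answer': 'down', 'score': 0.9526}, {'answer': 'at table', 'score': 0.6225}]
```

Note that the sigmoid scores do not sum to 1, which is expected: several answers can be plausible at the same time.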