# Performing visual question answering (VQA) with ViLT

This demo notebook is adapted from the original at https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/ViLT/Inference_with_ViLT_(visual_question_answering).ipynb#scrollTo=eTEyjLn2gdlH. All credit goes to the original author, Niels Rogge.

In this notebook, we illustrate visual question answering with the Vision-and-Language Transformer (ViLT). The model is minimal: it only adds text embedding layers to an existing ViT model. It does not require a sophisticated CNN-based pipeline to feed the image to the model (unlike models such as [PixelBERT](https://arxiv.org/abs/2004.00849) and [LXMERT](https://arxiv.org/abs/1908.07490)). This also makes the model much faster than those earlier approaches.

![ViLT architecture](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/vilt_architecture.jpg)

Figure taken from the original [paper](https://arxiv.org/abs/2102.03334).

HuggingFace docs: https://huggingface.co/docs/transformers/master/en/model_doc/vilt

## Set up environment

First, we install HuggingFace Transformers and import the required packages.

In [None]:
%pip install -q git+https://github.com/huggingface/transformers.git
import torch
import time
import requests
from PIL import Image
from transformers import ViltProcessor

## Prepare image + question

Here we take our familiar cats image (from the COCO dataset) and create a corresponding question.

In [None]:
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw) # open image using Pillow
text = "How many cats are there?" # question about the image
image # display image

In [None]:
processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa") # load the processor that matches the VQA fine-tuned ViLT checkpoint

Let's prepare the image+text pair for the model. Here, we leverage `ViltProcessor`, which will use (behind the scenes):
* `BertTokenizerFast` to tokenize the text (and create input_ids, attention_mask, token_type_ids)
* `ViltFeatureExtractor` to resize + normalize the image (and create pixel_values and pixel_mask). 

Note that `pixel_mask` only matters for batches, where it indicates which pixels are real and which are padding. Here we prepare a single example for the model, so all values of `pixel_mask` will be 1.

In [None]:
# Encode the image + question pair into model inputs

encoding = processor(image, text, return_tensors="pt")
for k,v in encoding.items():
  print(k, v.shape)
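
To make the role of `pixel_mask` in a batch more concrete, here is a minimal, framework-free sketch. It is not the processor's actual implementation: the helper name `build_pixel_masks` is hypothetical, the image sizes are toy values, and it assumes smaller images are padded on the bottom and right up to the largest height and width in the batch.

```python
def build_pixel_masks(sizes):
    """Build one 0/1 mask per image for a batch padded to a common size.

    `sizes` is a list of (height, width) pairs. Assumption for this sketch:
    each image sits in the top-left corner and padding fills the rest.
    """
    max_h = max(h for h, _ in sizes)
    max_w = max(w for _, w in sizes)
    masks = []
    for h, w in sizes:
        # 1 marks a real pixel, 0 marks a padded position
        mask = [[1 if (row < h and col < w) else 0 for col in range(max_w)]
                for row in range(max_h)]
        masks.append(mask)
    return masks


# Two toy "images": one is 2x3, the other 3x2, so both get padded to 3x3
masks = build_pixel_masks([(2, 3), (3, 2)])
for mask in masks:
    for row in mask:
        print(row)
    print()
```

With a single image there is nothing to pad against, which is why every mask value is 1 in the cell above.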

## Define model

Here we load the ViLT model, fine-tuned on VQAv2, from the [hub](https://huggingface.co/dandelin/vilt-b32-finetuned-vqa).

In [None]:
from transformers import ViltForQuestionAnswering 

model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")

## Forward pass

We can now forward the encoded inputs (including `input_ids` and `pixel_values`) through the model. The model outputs logits of shape (batch_size, num_labels), which in this case is (1, 3129), as the VQAv2 dataset has 3129 possible answers.

In [None]:
# forward pass (gradients are not needed for inference)
with torch.no_grad():
  outputs = model(**encoding)
logits = outputs.logits
idx = torch.sigmoid(logits).argmax(-1).item() # index of the highest-scoring answer
print("Predicted answer:", model.config.id2label[idx])
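
The cell above keeps only the single highest-scoring answer. It can also be useful to look at the runners-up. The sketch below ranks answers by their sigmoid score in plain Python, so it runs without downloading the checkpoint. The helper `top_k_answers`, the toy logits, and the toy `id2label` mapping are all hypothetical stand-ins for `logits[0]` and `model.config.id2label`.

```python
import math


def sigmoid(x):
    """Numerically stable logistic function for a single float."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def top_k_answers(logits, id2label, k=3):
    """Return the k highest-scoring (label, score) pairs, best first."""
    scores = [sigmoid(x) for x in logits]
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return [(id2label[i], scores[i]) for i in ranked[:k]]


# Toy values standing in for the real 3129-way output
toy_logits = [-2.0, 3.0, 0.5, -1.0]
toy_id2label = {0: "0", 1: "2", 2: "1", 3: "3"}

top = top_k_answers(toy_logits, toy_id2label, k=2)
for label, score in top:
    print(f"{label}: {score:.3f}")
```

Because the sigmoid is monotonic, this ranking matches the ranking of the raw logits. With the real model output, `torch.topk` applied to `torch.sigmoid(logits)` gives the same kind of ranking directly on tensors.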

In [None]:
# Time a single end-to-end run (image download, preprocessing, and forward pass)

starting_time = time.time()
image_url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(image_url, stream=True).raw)
text = "Where is the car?"
encoding = processor(image, text, return_tensors="pt")
for k, v in encoding.items():
  print(k, v.shape)

# forward pass
outputs = model(**encoding)
logits = outputs.logits
idx = torch.sigmoid(logits).argmax(-1).item()
print("Time taken:", time.time() - starting_time)
print("Predicted answer:", model.config.id2label[idx])
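
The cell above times one run, and that measurement includes the image download and preprocessing. To estimate an average for the forward pass alone, one option is to repeat the call several times after a short warm-up. The helper below is a minimal sketch of that idea. `average_latency` and `dummy_forward` are hypothetical names, and the dummy workload stands in for the model call so the example runs on its own.

```python
import statistics
import time


def average_latency(fn, n_runs=10, n_warmup=2):
    """Call `fn` repeatedly; return (mean, standard deviation) in seconds.

    Warm-up calls are executed but not timed. `n_runs` must be at least 2
    because the sample standard deviation needs two data points.
    """
    for _ in range(n_warmup):
        fn()
    durations = []
    for _ in range(n_runs):
        start = time.perf_counter()
        fn()
        durations.append(time.perf_counter() - start)
    return statistics.mean(durations), statistics.stdev(durations)


def dummy_forward():
    # Stand-in workload; replace with the real model call
    return sum(i * i for i in range(10_000))


mean_s, std_s = average_latency(dummy_forward, n_runs=5)
print(f"mean: {mean_s * 1000:.3f} ms, std: {std_s * 1000:.3f} ms")
```

In this notebook, `fn` could be something like `lambda: model(**encoding)`, so that only the forward pass is measured.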

In [None]:
# print out logits info (a notebook cell only displays its last expression, so print both)
print(logits.shape)
print(logits[0, :3])