In [None]:
!pip install transformers

In [None]:
from transformers import ViltProcessor, ViltForQuestionAnswering
from PIL import Image

In [None]:
# imports
import os
import re
import time
import json
import math
import shutil
import random
import pandas as pd
import numpy as np
from PIL import Image
from collections import Counter, defaultdict
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torchvision.models import resnet152, ResNet152_Weights
import torch.optim as optim
from IPython.display import clear_output
import warnings
warnings.filterwarnings("ignore")

In [None]:
!wget http://images.cocodataset.org/zips/val2014.zip
!unzip /content/val2014.zip
!rm /content/val2014.zip

!wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Val_mscoco.zip
!unzip /content/v2_Questions_Val_mscoco.zip
!rm /content/v2_Questions_Val_mscoco.zip
!mv /content/v2_OpenEnded_mscoco_val2014_questions.json /content/val2014questions.json

!wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Annotations_Val_mscoco.zip
!unzip /content/v2_Annotations_Val_mscoco.zip
!rm /content/v2_Annotations_Val_mscoco.zip
!mv /content/v2_mscoco_val2014_annotations.json /content/val2014answers.json

!mkdir /content/questions
!mkdir /content/answers

!mv /content/val2014questions.json /content/questions/val.json
!mv /content/val2014answers.json /content/answers/val.json

In [None]:
class VQADataset(Dataset):
    """Question/answer samples paired with the path of the image they refer to.

    The question and annotation JSON files are read once at construction time
    and turned into an in-memory, shuffled list of samples. Images are only
    opened when a sample is requested.

    Args:
        phase: Split name; selects ``<questions_dir>/<phase>.json`` and
            ``<answers_dir>/<phase>.json``.
        questions_dir: Directory holding the question JSON files.
        answers_dir: Directory holding the annotation JSON files.
        images_dir: Directory holding ``<image_id>.jpg`` files, either
            directly or inside a ``<phase>`` sub-directory.
    """

    def __init__(self, phase, questions_dir, answers_dir, images_dir):
        self.phase = phase
        self.questions_json = questions_dir + "/" + self.phase + ".json"
        self.answers_json = answers_dir + "/" + self.phase + ".json"
        self.images_dir = images_dir

        self.dataset = self.create_dataset()

    @staticmethod
    def _confident_answers(annotation):
        """Return the distinct lower-cased answers marked with confidence "yes".

        The text is lower-cased before the duplicate check, so answers that
        differ only in case are stored once. First-seen order is preserved.
        """
        unique_answers = []
        for answer in annotation["answers"]:
            text = answer["answer"].lower()
            if answer["answer_confidence"] == "yes" and text not in unique_answers:
                unique_answers.append(text)
        return unique_answers

    def _image_root(self):
        """Return the directory that image file names are resolved against.

        ``<images_dir>/<phase>`` is used when that directory exists; otherwise
        the images are expected directly inside ``images_dir``.
        """
        phase_dir = self.images_dir + "/" + self.phase
        return phase_dir if os.path.isdir(phase_dir) else self.images_dir

    def create_dataset(self):
        """Build the shuffled list of ``{"image_path", "question", "answers"}`` dicts.

        Questions and annotations are paired by position. A pair is dropped
        when the image ids differ, or when both records carry a
        ``question_id`` and those differ.
        """
        with open(self.questions_json) as f:
            questions = json.load(f)["questions"]
        with open(self.answers_json) as f:
            annotations = json.load(f)["annotations"]

        image_root = self._image_root()
        dataset = []
        # The description never changes, so it is set once instead of on
        # every iteration.
        file_loop = tqdm(zip(questions, annotations),
                         total=len(questions),
                         colour="green",
                         desc=f"Generating {self.phase} data")
        for q, a in file_loop:
            if q["image_id"] != a["image_id"]:
                continue
            # One image has several questions, so equal image ids alone do not
            # prove the two records belong together.
            if q.get("question_id") != a.get("question_id"):
                continue

            dataset.append({
                "image_path": image_root + "/" + str(q["image_id"]) + ".jpg",
                "question": q["question"],
                "answers": self._confident_answers(a),
            })

        random.shuffle(dataset)
        return dataset

    def __len__(self):
        """Number of samples kept by ``create_dataset``."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Return ``(RGB PIL image, question text, list of answers)`` for a sample."""
        if torch.is_tensor(index):
            index = index.tolist()

        sample = self.dataset[index]
        image = Image.open(sample["image_path"]).convert("RGB")

        return image, sample["question"], sample["answers"]

In [None]:
def rename_image_dataset(phase, input_dir, num_samples=None):
    """Shorten image file names in ``input_dir`` to ``<number><ext>``.

    Everything up to and including the last underscore is dropped and the
    leading zeros of what remains are stripped. Files are renamed in place,
    so their bytes are left untouched; nothing is decoded or re-encoded.

    Files whose name is already in the short form are skipped, which makes
    the function safe to run more than once on the same directory.

    Args:
        phase: Split name; only used in the progress-bar description.
        input_dir: Directory whose files are renamed.
        num_samples: When given, only this many randomly chosen directory
            entries are processed.
    """
    images = os.listdir(input_dir)
    if len(images) == 0:
        print("Input directory {} is empty".format(input_dir))
        return

    if num_samples is not None:
        random.shuffle(images)
        images = images[:num_samples]

    file_loop = tqdm(images, total=len(images), colour="green",
                     desc=f"Renaming {phase} images...")
    for image_name in file_loop:
        stem, extension = os.path.splitext(image_name.split("_")[-1])
        # An all-zero number would otherwise collapse to an empty stem.
        new_name = (stem.lstrip("0") or "0") + extension
        if new_name == image_name:
            # Already renamed: replacing a file with itself is pointless.
            continue

        input_image_path = os.path.join(input_dir, image_name)
        output_image_path = os.path.join(input_dir, new_name)
        try:
            # Opening the file lets PIL reject entries it does not recognise
            # as images, so those keep their original name.
            with Image.open(input_image_path):
                pass
            os.replace(input_image_path, output_image_path)
        except (IOError, SyntaxError):
            print("Error while renaming {}".format(image_name))

In [None]:
# Shorten the file names of the validation images in place.
rename_image_dataset(
    phase="val",
    input_dir="/content/val2014",
)

In [None]:
def collate_vqa_batch(batch):
    """Regroup a list of (image, question, answers) samples into three parallel tuples.

    VQADataset yields PIL images and answer lists of varying length, which the
    DataLoader's default collation cannot merge, so the samples are passed
    through unmerged.
    """
    return tuple(zip(*batch))


val_dataset = VQADataset(phase="val",
                         questions_dir="/content/questions",
                         answers_dir="/content/answers",
                         images_dir="/content/val2014")

val_loader = DataLoader(val_dataset,
                        batch_size=1,
                        shuffle=False,
                        collate_fn=collate_vqa_batch)

In [None]:
# The processor and the model are both loaded from the same checkpoint name.
checkpoint = "dandelin/vilt-b32-finetuned-vqa"
processor = ViltProcessor.from_pretrained(checkpoint)
model = ViltForQuestionAnswering.from_pretrained(checkpoint)

In [None]:
# Disabled alternative that would pick the GPU whenever one is available:
# device = "cuda" if torch.cuda.is_available() else 'cpu'
# Everything below runs on the CPU; the inference cells send their inputs to
# this same `device`.
device = 'cpu'
# Move the model's parameters to the selected device.
model.to(device)

In [None]:
# Display the device the model currently reports.
model.device

In [None]:
# Mount Google Drive at /content/drive (requires the google.colab package).
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Saving model checkpoint in google drive.

import torch

# torch.save(model.state_dict(), '/content/drive/MyDrive/model.pth')

In [None]:
# Rebuild the ViLT VQA model and overwrite its weights with the checkpoint
# stored on Google Drive.
#
# The fresh instance is moved to `device` because the inference cells send
# their inputs there; map_location loads the stored tensors onto that same
# device regardless of where they were saved from.

model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa").to(device)
model.load_state_dict(torch.load('/content/drive/MyDrive/model.pth', map_location=device))

In [None]:
import pickle

# Serialise the whole model object to model3.pkl. The `with` block closes the
# file, so the data is flushed to disk before the cell finishes.
# NOTE(review): unpickling this file executes code from it; only load it from
# a trusted source.
with open('model3.pkl', 'wb') as model_file:
    pickle.dump(model, model_file)

In [None]:
# Print the model's text representation.
print(model)

In [None]:
# Read the raw validation question and annotation records into memory.
questions_json = "/content/questions/val.json"
answers_json = "/content/answers/val.json"
with open(questions_json) as questions_file, open(answers_json) as answers_file:
    questions = json.load(questions_file)["questions"]
    answers = json.load(answers_file)["annotations"]

In [None]:
# Display the second question record.
questions[1]

In [None]:
# Display the second annotation record.
answers[1]

# Inference on custom images and questions

In [None]:
def predict(im_path, ques):
    """Return the model's highest-scoring answer to a question about an image.

    Uses the notebook-level `processor`, `model` and `device`.

    Args:
        im_path: Path of the image file; it is opened and converted to RGB.
        ques: Question text about the image.

    Returns:
        The label that `model.config.id2label` assigns to the largest logit.
    """
    image = Image.open(im_path).convert("RGB")
    encodings = processor(image, ques, return_tensors="pt").to(device)

    # Pure inference: no gradients are needed, so autograd tracking is
    # switched off for the forward pass.
    with torch.no_grad():
        logits = model(**encodings).logits

    # Only the best answer is returned, so a single argmax over the first
    # (and only) row of the batch is enough.
    best_index = logits[0].argmax().item()
    return model.config.id2label[best_index]

In [None]:
# Ask one question about a single image and collect the five highest-scoring
# answer labels, best first, in `predicted_answer`.
image = Image.open("/content/testimage.jpg").convert("RGB")
question = "What object is in the image?"
all_answers = []

encodings = processor(image, question, return_tensors="pt").to(device)
# Pure inference: run the forward pass without autograd tracking.
with torch.no_grad():
    outputs = model(**encodings)
logits = outputs.logits
# Take the indices by name instead of unpacking into `_`, which IPython uses
# for the previous cell output.
answer_index_top5 = torch.topk(logits, 5).indices

predicted_answer = [
    model.config.id2label[pred_answer_index.item()]
    for pred_answer_index in answer_index_top5[0, :]
]

In [None]:
# Print the predicted answer labels collected above.
print(predicted_answer)

# Calculating validation accuracy.

In [None]:
# Score the model on the validation records.
#
# Two counters pairs are filled in:
#   correct_q / total_q - questions whose top-1 prediction is one of the
#                         confident reference answers.
#   correct / total     - confident reference answers that appear among the
#                         top-5 predictions.
images_dir = "/content/val2014"

# The first two records are left out of the evaluation.
# NOTE(review): the reason for skipping them is not visible here - confirm.
# Separate names are used so that re-running this cell does not shorten
# `questions` / `answers` a little more each time.
eval_questions = questions[2:]
eval_answers = answers[2:]
correct = 0
total = 0

correct_q = 0
total_q = 0

file_loop = tqdm(zip(eval_questions, eval_answers),
                 total=len(eval_questions),
                 colour="green",
                 desc="Testing on validation data")
# Pure inference: disable autograd tracking for every forward pass.
with torch.no_grad():
    for q, a in file_loop:
        if q["image_id"] != a["image_id"]:
            continue
        # One image has several questions, so equal image ids alone do not
        # prove the two records belong together.
        if q.get("question_id") != a.get("question_id"):
            continue
        image_path = images_dir + "/" + str(q["image_id"]) + ".jpg"

        # Distinct confident answers; lower-cased before the duplicate check
        # so case variants of one answer are counted once.
        all_answers = []
        for answer in a["answers"]:
            answer_text = answer["answer"].lower()
            if answer["answer_confidence"] == "yes" and answer_text not in all_answers:
                all_answers.append(answer_text)

        image = Image.open(image_path).convert("RGB")
        question = q["question"]

        encodings = processor(image, question, return_tensors="pt").to(device)
        logits = model(**encodings).logits
        answer_index_top5 = torch.topk(logits, 5).indices
        predicted_answer = [
            model.config.id2label[pred_answer_index.item()]
            for pred_answer_index in answer_index_top5[0, :]
        ]

        total_q += 1
        if predicted_answer[0] in all_answers:
            correct_q += 1

        for reference_answer in all_answers:
            total += 1
            if reference_answer in predicted_answer:
                correct += 1

In [None]:
# Percentage of evaluated questions whose top-ranked prediction was one of the
# confident reference answers.
print("Accuracy :", correct_q/total_q *100)

In [None]:
# Percentage of confident reference answers that appeared among the top-5
# predictions for their question.
print("Accuracy :", correct/total *100)