# VQA v2.0 Interactive Notebook

## Import Libraries

In [None]:
# Import libraries
import importlib
import os
import string
from collections import Counter
from typing import Iterable, List

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn_utils
import torch.optim as optim
import torchvision
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision import transforms


In [None]:
# Check GPU: there is no CPU fallback, so abort immediately without CUDA.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    raise RuntimeError("GPU with CUDA support is required")


In [None]:
# Import utils
if not os.path.exists("./utils"):
    !git clone https://github.com/ZhangShaozuo/Artificial_Intelligence_VQA.git
    !ln -s ./DL-BigProject-VQA/utils ./utils

if os.path.exists("./DL-BigProject-VQA/utils"):
    !cd ./DL-BigProject-VQA/utils && git pull

import utils.data as data_util
import utils.helper as helper
import utils.train as train_util
from utils.vocab import Vocab

importlib.reload(data_util)
importlib.reload(helper)
importlib.reload(train_util)
pass


## Load Dataset

In [None]:
# Load dataset (single word answer only)
# Image preprocessing: resize the shorter side to int(224 / 0.875) px, take the
# central 224x224 crop, convert to a tensor and normalize each channel with the
# mean/std values torchvision's ImageNet-pretrained models expect.
image_transform = transforms.Compose([
    transforms.Resize(int(224 / 0.875)),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    ),
])


def tokenizer(question: str) -> List[str]:
    """Split *question* into lower-case word tokens with punctuation removed.

    Characters in ``string.punctuation`` are deleted (not replaced by spaces)
    and the result is split on runs of whitespace, e.g. ``"What's red?"``
    becomes ``["whats", "red"]``.
    """
    drop_punctuation = str.maketrans("", "", string.punctuation)
    normalized = question.lower().translate(drop_punctuation)
    return normalized.split()


# Placeholder vocabulary; question_transform_factory (and the "Load vocab"
# cell) rebind this global to a real Vocab.
question_vocab = Vocab({})


def question_transform(question: str):
    """Encode *question* as a 1-D ``torch.long`` tensor of vocabulary indices.

    The global ``question_vocab`` is looked up at call time, so the encoder
    always uses whichever vocabulary was most recently built or loaded.
    """
    token_ids = [question_vocab[word] for word in tokenizer(question)]
    return torch.tensor(token_ids, dtype=torch.long)


def question_transform_factory(corpus: Iterable[str]):
    """Build the question vocabulary from *corpus* and return the encoder.

    Token frequencies over every question in *corpus* become the new global
    ``question_vocab`` (with ``<pad>`` and ``<unk>`` specials), and the
    module-level ``question_transform`` is returned for the dataset to use.
    """
    global question_vocab
    token_counts = Counter(
        token for text in corpus for token in tokenizer(text)
    )
    question_vocab = Vocab(token_counts, specials=["<pad>", "<unk>"])
    return question_transform


# Placeholder answer vocabulary; answer_tansform_factory (and the "Load vocab"
# cell) rebind this global to a real Vocab.
answer_vocab = Vocab({})


def answer_tansform(answer: str):
    """Return the class index of *answer* in the current global ``answer_vocab``.

    (The misspelled name is kept: the dataset cells pass it by this name.)
    """
    vocab = answer_vocab
    return vocab[answer]


def answer_tansform_factory(corpus: Iterable[str]):
    """Build the global answer vocabulary from *corpus* and return the encoder.

    Each element of *corpus* is counted as one whole answer string.
    ``min_freq=10`` presumably drops answers seen fewer than 10 times, mapping
    them to ``<unk>`` (torchtext-style Vocab) -- confirm in utils.vocab.
    """
    global answer_vocab
    answer_counts = Counter(corpus)
    answer_vocab = Vocab(answer_counts, min_freq=10, specials=["<unk>"])
    return answer_tansform


# The training split receives transform *factories*, which build the global
# question/answer vocabularies (presumably from the training corpus, inside
# VQA2Dataset). The validation and test splits receive the resulting
# transforms, so every split is encoded with the training vocabularies.
train_dataset = data_util.VQA2Dataset(
    "./VQA2/",
    group="train",
    image_transform=image_transform,
    question_transform_factory=question_transform_factory,
    answer_transform_factory=answer_tansform_factory,
    download=True,
)

valid_dataset = data_util.VQA2Dataset(
    "./VQA2/",
    group="val",
    image_transform=image_transform,
    question_transform=question_transform,
    answer_transform=answer_tansform,
)

test_dataset = data_util.VQA2Dataset(
    "./VQA2/",
    group="test",
    image_transform=image_transform,
    question_transform=question_transform,
    answer_transform=answer_tansform,
)

print("train_dataset:", len(train_dataset))
print("valid_dataset:", len(valid_dataset))
print("test_dataset:", len(test_dataset))
print()
print("question_vocab size:", len(question_vocab))
print("answer_vocab size:  ", len(answer_vocab))


In [None]:
# Save vocab
# Persist both vocabularies (built by the transform factories when the
# training dataset was created) so the "Load vocab" cell can restore them.
question_vocab.save("question_vocab.json")
answer_vocab.save("answer_vocab.json")


In [None]:
# Load vocab
# Rebind the global vocabularies from the saved JSON files. question_transform
# and answer_tansform look these globals up at call time, so they use the
# reloaded objects from now on.
question_vocab = Vocab.load("question_vocab.json")
answer_vocab = Vocab.load("answer_vocab.json")


## Create Dataloader

In [None]:
# Create dataloader
batch_size = 64
# Index used to pad questions; sequence_to_sentence stops decoding at it.
PAD_IDX = question_vocab["<pad>"]


def generate_batch(data_batch):
    """Collate a list of (image, question, answer) samples into one batch.

    Returns a tuple ``(images, questions, q_lengths, answers)``:
        images: image tensors stacked along a new leading batch dimension.
        questions: question index tensors padded with ``PAD_IDX`` by
            ``pad_sequence`` (default layout: (max_len, batch)).
        q_lengths: list of unpadded question lengths, in decreasing order.
        answers: ``torch.long`` tensor of answer indices.
    """
    # Longest question first, as pack_padded_sequence expects by default.
    # Note: the incoming list is sorted in place.
    data_batch.sort(key=lambda sample: len(sample[1]), reverse=True)
    images, questions, answers = zip(*data_batch)
    q_lengths = list(map(len, questions))
    return (
        torch.stack(images, 0),
        rnn_utils.pad_sequence(questions, padding_value=PAD_IDX),
        q_lengths,
        torch.tensor(answers, dtype=torch.long),
    )


# Full training and validation loaders; both reshuffle on every pass.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=generate_batch,
)

valid_loader = DataLoader(
    valid_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=generate_batch,
)

# use a subset of the validation dataset: its first 512 samples, visited in
# random order by SubsetRandomSampler (cheap periodic validation)
mini_valid_loader = DataLoader(
    valid_dataset,
    sampler=SubsetRandomSampler(list(range(512))),
    collate_fn=generate_batch,
    batch_size=batch_size,
)


In [None]:
# Visualize some samples
def sequence_to_sentence(seq: List[int]) -> str:
    """Decode question token indices back into a space-separated sentence.

    Decoding stops at the first ``PAD_IDX``, so trailing padding added by
    ``generate_batch`` is dropped.
    """
    words = []
    for token_id in seq:
        if token_id == PAD_IDX:
            break
        words.append(question_vocab.itos[token_id])
    return " ".join(words)


def visualize_samples(images, questions, answers, max_num=-1):
    """Print each sample's question and answer, then display its image.

    Args:
        images: batch of image tensors, as stacked by ``generate_batch``.
        questions: padded index tensor in (seq_len, batch) layout.
        answers: answer class indices, one per sample.
        max_num: number of samples to show; a negative value shows all.
    """
    limit = max_num if max_num >= 0 else len(images)

    # pad_sequence produced (seq_len, batch); swap to (batch, seq_len) so
    # iterating yields one question per sample.
    per_sample_questions = questions.transpose(0, 1)

    samples = zip(images, per_sample_questions, answers)
    for shown, (image, question, answer) in enumerate(samples):
        if shown >= limit:
            break
        print("Q:", sequence_to_sentence(question))
        print("A:", answer_vocab.itos[answer])
        helper.imshow(image)
        plt.show()


images, questions, _, answers = next(iter(train_loader))
visualize_samples(images, questions, answers, max_num=4)


## Create Model

In [None]:
class VNet(nn.Module):
    """Image encoder: the ResNet-50 + FPN backbone of torchvision's Faster R-CNN.

    Only the backbone is kept; the detection heads are discarded after
    construction. ``out_channels`` exposes the backbone's feature width.
    NOTE(review): ``pretrained=`` is deprecated in newer torchvision in favour
    of ``weights=``; left unchanged for compatibility.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        detector = torchvision.models.detection.fasterrcnn_resnet50_fpn(
            pretrained=pretrained
        )
        self.backbone = detector.backbone
        self.out_channels = self.backbone.out_channels

    def forward(self, images):
        # The backbone returns a dict of FPN levels; "3" is the coarsest level
        # before "pool" (assumed stride 32, i.e. 7x7 for 224x224 inputs).
        feature_maps = self.backbone(images)
        return feature_maps["3"]


In [None]:
class QNet(nn.Module):
    """Question encoder: word embedding -> tanh -> single-layer LSTM.

    Args:
        vocab: question vocabulary. ``len(vocab)`` sizes the embedding table
            and, when ``weights_path`` is given, ``vocab.stoi`` maps words to
            embedding rows.
        embedding_dim: size of each word embedding.
        out_dim: LSTM hidden size, i.e. the size of the returned feature.
        weights_path: optional text file of pretrained embeddings, one word per
            line followed by ``embedding_dim`` space-separated floats
            (GloVe-style text format -- assumed; confirm for your file).

    Raises:
        ValueError: if a vocabulary word's vector in ``weights_path`` does not
            have exactly ``embedding_dim`` values.
    """

    def __init__(self, vocab, embedding_dim, out_dim, weights_path=None):
        super().__init__()
        self.embedding = nn.Embedding(len(vocab), embedding_dim)
        self.tanh = nn.Tanh()
        self.rnn = nn.LSTM(embedding_dim, out_dim)

        if weights_path:
            self._load_pretrained_embeddings(vocab, weights_path)

    def _load_pretrained_embeddings(self, vocab, weights_path):
        """Overwrite embedding rows of *vocab* words found in *weights_path*.

        Rows for words missing from the file keep their random initialisation.
        Reading stops once more than 90% of the vocabulary has been filled, to
        avoid scanning the remainder of a large embedding file.
        """
        embedding_dim = self.embedding.embedding_dim
        found = 0
        weights = self.embedding.weight.detach().numpy()
        with open(weights_path, encoding="utf-8") as f:
            for line_no, line in enumerate(f, start=1):
                # rstrip: drop the newline / trailing spaces so the last value
                # does not produce an empty or newline-suffixed token
                word, *values = line.rstrip().split(" ")
                if word not in vocab.stoi:
                    continue

                embed = np.asarray(values, dtype="float32")
                if embed.shape != (embedding_dim,):
                    raise ValueError(
                        f"{weights_path}:{line_no}: expected {embedding_dim} "
                        f"values for {word!r}, got {embed.size}"
                    )
                weights[vocab.stoi[word]] = embed

                found += 1
                if found / len(vocab) > 0.9:
                    break
        self.embedding.weight.data.copy_(torch.from_numpy(weights))

    def forward(self, q, q_len):
        """Encode a padded batch of questions.

        Args:
            q: token indices in (seq_len, batch) layout (the LSTM uses the
                default ``batch_first=False``).
            q_len: per-question lengths in decreasing order, as required by
                ``pack_padded_sequence``'s default ``enforce_sorted=True``.

        Returns:
            (batch, out_dim) tensor: the LSTM's final *cell* state.
        """
        embedded = self.embedding(q)
        tanhed = self.tanh(embedded)
        packed = rnn_utils.pack_padded_sequence(tanhed, q_len)
        # nn.LSTM returns (output, (h_n, c_n)); the cell state c_n (not the
        # hidden state h_n) is used as the question feature.
        _, (_, features) = self.rnn(packed)
        # c_n is (num_layers=1, batch, out_dim) -> (batch, out_dim)
        return features.squeeze(0)


In [None]:
class VQANet(nn.Module):
    """Question-guided spatial attention over image features, then classification.

    ``VNet`` encodes the image into a (batch, C, H, W) feature map, and
    ``QNet`` encodes the question into a (batch, 256) vector. Both are
    projected through sigmoid layers into a 128-d query space, and their
    per-location dot product is softmaxed over all H*W positions. The
    resulting map weights a sum of the image features. That pooled vector
    is gated element-wise by a sigmoid projection of the question feature,
    then classified into ``num_classes`` answers.

    The spatial size is read from the feature map at run time. A 224x224
    input (7x7 map) behaves exactly as before, and other input sizes also work.
    """

    def __init__(self, vocab, num_classes: int):
        super().__init__()
        self.v_net = VNet()
        self.q_net = QNet(vocab, 100, 256)

        v_channels = self.v_net.out_channels  # previously hard-coded as 256
        q_dim = 256  # QNet out_dim above
        attention_dim = 128

        self.v_query = nn.Sequential(
            nn.Conv2d(v_channels, attention_dim, 1), nn.Sigmoid()
        )
        self.q_query = nn.Sequential(nn.Linear(q_dim, attention_dim), nn.Sigmoid())
        self.attention_softmax = nn.Softmax(dim=1)

        # question gate in the visual channel space
        self.q_fc = nn.Sequential(nn.Linear(q_dim, v_channels), nn.Sigmoid())

        self.classifier = nn.Sequential(
            nn.Linear(v_channels, 512), nn.ReLU(True), nn.Linear(512, num_classes)
        )
        # (batch, 1, H, W) attention map from the latest forward pass
        self.last_attention = None

    def forward(self, v, q, q_len):
        v_feat = self.v_net(v)  # (batch, C, H, W)
        q_feat = self.q_net(q, q_len)  # (batch, q_dim)

        v_query = self.v_query(v_feat)  # (batch, A, H, W)
        q_query = self.q_query(q_feat)  # (batch, A)
        # per-location dot product between image and question queries
        attention = (v_query * q_query[:, :, None, None]).sum(dim=1)  # (batch, H, W)
        height, width = attention.shape[-2:]
        # softmax jointly over all H*W positions, then restore the map shape
        attention = self.attention_softmax(attention.flatten(1)).view(
            -1, 1, height, width
        )
        self.last_attention = attention.detach()  # save for visualization

        # attention-weighted sum of visual features over all positions
        v_final = (v_feat * attention).flatten(2).sum(dim=2)  # (batch, C)
        q_final = self.q_fc(q_feat)
        out = self.classifier(v_final * q_final)
        return out


In [None]:
model = VQANet(question_vocab, len(answer_vocab)).to(device)
# Freeze the pretrained image backbone: every v_net parameter stops requiring
# gradients, so only the remaining layers are trained.
model.v_net.requires_grad_(False)


## Train Model

In [None]:
# Answer classification loss; v_net was frozen in the previous cell.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters())

history = train_util.train_model(
    model,
    criterion,
    optimizer,
    train_loader,
    mini_valid_loader,  # periodic validation on the 512-sample subset
    valid_every=200,
    epochs=10,
)
train_util.plot_history(history)


In [None]:
# Re-draw the training curves from `history` without re-running training.
train_util.plot_history(history)


## Test Model

In [None]:
# Loss and accuracy on the 512-sample validation subset (mini_valid_loader);
# pass valid_loader instead to evaluate on the full validation set.
valid_loss, valid_accu = train_util.validate_model(
    model, mini_valid_loader, nn.CrossEntropyLoss()
)
print("loss:", valid_loss)
print("accu:", valid_accu)


In [None]:
# Plot some samples
# The batch is drawn from the validation loader so the predictions shown are
# on questions the model was not trained on.
images, questions, question_lengths, answers = next(iter(valid_loader))
model.eval()
with torch.no_grad():
    outputs = model(images.to(device), questions.to(device), question_lengths)
predictions = torch.argmax(outputs, dim=1).cpu()

# questions is (seq_len, batch); transpose to iterate one question per sample
questions = questions.transpose(0, 1)
for _, v, q, a, pred in zip(range(4), images, questions, answers, predictions):
    print("Q:", sequence_to_sentence(q))
    print("A:", answer_vocab.itos[a])
    print("Model:", answer_vocab.itos[pred])
    helper.imshow(v)
    plt.show()


## Attention Visualization

In [None]:
def visualize_attention(loader, max_num: int = 4):
    """Show samples from one batch of *loader* beside the model's attention map.

    For each of the first ``max_num`` samples, prints the question, the
    ground-truth answer and the model's predicted answer. It then plots the
    image next to ``model.last_attention``, the map VQANet.forward stores.

    Args:
        loader: DataLoader yielding (images, questions, lengths, answers)
            batches, as produced by ``generate_batch``.
        max_num: number of samples from the batch to display.
    """
    images, questions, question_lengths, answers = next(iter(loader))
    # Inference only: no autograd graph is needed for visualization.
    with torch.no_grad():
        logits = model(images.to(device), questions.to(device), question_lengths)
    predictions = logits.cpu()
    attentions = model.last_attention.cpu()
    # questions is (seq_len, batch); transpose to iterate one question per sample
    questions = questions.transpose(0, 1)

    for _, v, q, a, pred, atten in zip(
        range(max_num), images, questions, answers, predictions, attentions
    ):
        print("Q:", sequence_to_sentence(q))
        print("A:", answer_vocab.itos[a])
        print("model:", answer_vocab.itos[pred.argmax(0)])
        fig = plt.figure()
        ax1 = fig.add_subplot(121)
        ax2 = fig.add_subplot(122)
        helper.imshow(v, ax1)
        ax2.imshow(atten.squeeze(0))  # (1, H, W) -> (H, W)
        plt.show()


visualize_attention(train_loader)
