In [None]:
# imports
import os
import re
import time
import json
import math
import shutil
import random
import pandas as pd
import numpy as np
from PIL import Image
from collections import Counter, defaultdict
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torchvision.models import resnet152, ResNet152_Weights
import torch.optim as optim
from IPython.display import clear_output
import warnings
warnings.filterwarnings("ignore")

In [None]:
# Remove the sample data directory from the Colab workspace; nothing below uses it.
!rm -rf /content/sample_data

# Training split: download each archive, unpack it, then delete the zip.
# The question/annotation JSON files are renamed to shorter names.
!wget http://images.cocodataset.org/zips/train2014.zip
!unzip /content/train2014.zip
!rm /content/train2014.zip

!wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Train_mscoco.zip
!unzip /content/v2_Questions_Train_mscoco.zip
!rm /content/v2_Questions_Train_mscoco.zip
!mv /content/v2_OpenEnded_mscoco_train2014_questions.json /content/train2014questions.json

!wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Annotations_Train_mscoco.zip
!unzip /content/v2_Annotations_Train_mscoco.zip
!rm /content/v2_Annotations_Train_mscoco.zip
!mv /content/v2_mscoco_train2014_annotations.json /content/train2014answers.json


# Validation split: same download / unpack / delete / rename sequence as above.
!wget http://images.cocodataset.org/zips/val2014.zip
!unzip /content/val2014.zip
!rm /content/val2014.zip

!wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Val_mscoco.zip
!unzip /content/v2_Questions_Val_mscoco.zip
!rm /content/v2_Questions_Val_mscoco.zip
!mv /content/v2_OpenEnded_mscoco_val2014_questions.json /content/val2014questions.json

!wget https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Annotations_Val_mscoco.zip
!unzip /content/v2_Annotations_Val_mscoco.zip
!rm /content/v2_Annotations_Val_mscoco.zip
!mv /content/v2_mscoco_val2014_annotations.json /content/val2014answers.json

# Wipe the download/unzip logs from the cell output.
clear_output()

# Final layout: /content/questions/{train,val}.json and /content/answers/{train,val}.json.
# NOTE(review): plain `mkdir` fails if the directory already exists, so this cell
# is not re-runnable as written.
!mkdir /content/questions
!mkdir /content/answers

!mv /content/train2014questions.json /content/questions/train.json
!mv /content/val2014questions.json /content/questions/val.json
!mv /content/train2014answers.json /content/answers/train.json
!mv /content/val2014answers.json /content/answers/val.json

In [None]:
def resize_image(image, image_size):
    """Return a copy of ``image`` resized to ``image_size`` using the LANCZOS filter.

    Args:
        image: object exposing a PIL-style ``resize(size, resample)`` method.
        image_size: target size, forwarded unchanged to ``image.resize``.
    """
    resampled = image.resize(image_size, Image.LANCZOS)
    return resampled

def resize_image_dataset(phase, input_dir, images_dir, image_size, num_samples=None):
    """Resize the images in ``input_dir`` into ``images_dir/phase`` and delete ``input_dir``.

    Each output file is named after the part of the source file name that
    follows the last underscore, with leading zeros stripped.

    Args:
        phase: name of the split; used as the output sub-directory name.
        input_dir: directory holding the original images. It is removed
            once the pass has finished.
        images_dir: root directory for the resized images.
        image_size: target size forwarded to ``resize_image``.
        num_samples: if given, only a random subset of this many images is
            processed (the remaining originals are still deleted with
            ``input_dir``).

    Returns:
        None. Files that cannot be read or written are reported and skipped.
    """
    if not os.path.isdir(input_dir):
        # input_dir is deleted at the end of a run, so a second run lands here
        # instead of crashing in os.listdir.
        print("Input directory {} does not exist".format(input_dir))
        return

    images = os.listdir(input_dir)
    if len(images) == 0:
        print("Input directory {} is empty".format(input_dir))
        return

    output_dir = os.path.join(images_dir, phase)
    os.makedirs(output_dir, exist_ok=True)

    if num_samples is not None:
        random.shuffle(images)
        images = images[:num_samples]

    file_loop = tqdm(images, total=len(images), colour="green")
    file_loop.set_description(f"Resizing {phase} images...")
    for image_name in file_loop:
        source_path = os.path.join(input_dir, image_name)
        try:
            # Read-only access is enough; 'r+b' would require write permission.
            with open(source_path, 'rb') as f:
                with Image.open(f) as img:
                    # Capture the format before resizing: the original code read
                    # it from the rebound (resized) image instead of the source.
                    source_format = img.format
                    resized = resize_image(img, image_size)
                    output_name = image_name.split("_")[-1].lstrip("0")
                    output_image_path = os.path.join(output_dir, output_name)
                    resized.save(output_image_path, source_format)
        except (IOError, SyntaxError) as e:
            # Report the original file name together with the reason, then move on.
            print("Error while resizing {}: {}".format(image_name, e))

    shutil.rmtree(input_dir)

In [None]:
def load_text_file(file_path):
    """Read ``file_path`` and return its lines as a list, without line endings."""
    with open(file_path) as handle:
        return handle.read().splitlines()


# Splits on runs of non-word characters; the capture group keeps those runs
# in the result so punctuation survives as its own token.
SENTENCE_SPLIT_REGEX = re.compile(r'(\W+)')
def tokenise(sentence):
    """Lower-case ``sentence`` and split it into word and punctuation tokens.

    Whitespace is stripped from every piece and empty pieces are dropped.
    """
    stripped_pieces = (piece.strip() for piece in SENTENCE_SPLIT_REGEX.split(sentence.lower()))
    return [piece for piece in stripped_pieces if piece]


class Vocab:
    """Word <-> index mapping backed by a text file with one word per line.

    The line number (0-based) of a word is its index. If the file contains
    ``<unk>``, unknown words resolve to that entry's index.
    """

    def __init__(self, vocab_file_path):
        """Load the vocabulary from ``vocab_file_path``."""
        self.vocab = load_text_file(vocab_file_path)
        self.word2index = {word: index for index, word in enumerate(self.vocab)}
        self.vocab_size = len(self.vocab)
        # None when the file defines no "<unk>" entry.
        self.unk_index = self.word2index.get("<unk>")

    def vocabSize(self):
        """Return the number of entries in the vocabulary."""
        return self.vocab_size

    def is_present(self, word):
        """Return True if ``word`` is a vocabulary entry."""
        # Dict membership instead of scanning the list on every call.
        return word in self.word2index

    def idx2word(self, idx):
        """Return the word stored at position ``idx``."""
        return self.vocab[idx]

    def word2idx(self, word):
        """Return the index of ``word``, falling back to the ``<unk>`` index.

        Raises:
            ValueError: if ``word`` is unknown and the vocabulary has no ``<unk>``.
        """
        index = self.word2index.get(word, self.unk_index)
        if index is None:
            raise ValueError("word {} is not in dictionary and <unk> does not exist".format(word))
        return index

In [None]:
def create_question_vocab(json_file, vocab_dir, min_word_count=2):
    """Build the question vocabulary file from a questions JSON file.

    Every entry under the ``"questions"`` key is tokenised with ``tokenise``.
    Tokens that occur strictly more than ``min_word_count`` times are kept,
    sorted, and written one per line to ``<vocab_dir>/question_vocab.txt``
    after the two special entries ``<pad>`` (index 0) and ``<unk>`` (index 1).

    Args:
        json_file: path to the questions JSON file.
        vocab_dir: output directory; created if missing.
        min_word_count: exclusive lower bound on a token's occurrence count.

    Returns:
        Tuple ``(path_to_vocab_file, number_of_entries)``; the count includes
        the two special entries.
    """
    with open(json_file) as f:
        questions = json.load(f)["questions"]

    # Count tokens as they stream by instead of first collecting every token
    # of every question in one list.
    word_count = Counter()
    question_loop = tqdm(questions, total=len(questions), colour="green",
                         desc="Generating Question Vocabulary")
    for entry in question_loop:
        word_count.update(tokenise(entry["question"]))

    frequent_words = sorted(word for word, count in word_count.items() if count > min_word_count)
    vocab = ["<pad>", "<unk>"] + frequent_words

    os.makedirs(vocab_dir, exist_ok=True)

    question_vocab_file = vocab_dir + "/question_vocab.txt"
    with open(question_vocab_file, "w") as f:
        f.writelines(word + "\n" for word in vocab)

    return question_vocab_file, len(vocab)

In [None]:
def create_answer_vocab(json_file, vocab_dir, topN=None):
    """Build the answer vocabulary file from an annotations JSON file.

    For every annotation, each answer whose ``answer_confidence`` is ``"yes"``
    is tokenised with ``tokenise`` and its tokens are counted. The vocabulary
    is written one entry per line to ``<vocab_dir>/answer_vocab.txt`` with
    ``<unk>`` as the first entry (index 0).

    Args:
        json_file: path to the annotations JSON file.
        vocab_dir: output directory; created if missing.
        topN: if given and smaller than the number of distinct tokens, only
            the ``topN`` most frequent tokens are kept. Otherwise all tokens
            are kept in first-seen order.

    Returns:
        Tuple ``(path_to_vocab_file, number_of_entries)``; the count includes
        the ``<unk>`` entry.
    """
    with open(json_file) as f:
        annotations = json.load(f)["annotations"]

    # NOTE(review): answers are split into word tokens here, whereas
    # VQADataset.__getitem__ looks up whole answer strings, so multi-word
    # answers will resolve to <unk> there - confirm this is intended.
    word_count = Counter()
    annotation_loop = tqdm(annotations, total=len(annotations), colour="green",
                           desc="Generating Answer Vocabulary")
    for annotation in annotation_loop:
        for answer in annotation["answers"]:
            if answer["answer_confidence"] == "yes":
                word_count.update(tokenise(answer["answer"]))

    if (topN is not None) and (len(word_count) > topN):
        # Truncate by frequency; slicing the first-seen order kept an
        # arbitrary subset rather than the top N.
        vocab = [word for word, _ in word_count.most_common(topN)]
    else:
        vocab = list(word_count)
    vocab.insert(0, "<unk>")

    os.makedirs(vocab_dir, exist_ok=True)

    answer_vocab_file = vocab_dir + "/answer_vocab.txt"
    with open(answer_vocab_file, "w") as f:
        f.writelines(word + "\n" for word in vocab)

    return answer_vocab_file, len(vocab)

In [None]:
class VQADataset(Dataset):
    """Dataset yielding ``(image, question_indices, answer_index)`` triples.

    Questions are read from ``<questions_dir>/<phase>.json`` and annotations
    from ``<answers_dir>/<phase>.json``; images are expected at
    ``<images_dir>/<phase>/<image_id>.jpg``.
    """

    def __init__(self, phase, questions_dir, answers_dir, question_vocab, answer_vocab, images_dir, max_question_length, transform=None):
        """Store the configuration and build the in-memory sample list.

        Args:
            phase: split name; selects the JSON files and the image sub-directory.
            questions_dir: directory containing ``<phase>.json`` with questions.
            answers_dir: directory containing ``<phase>.json`` with annotations.
            question_vocab: object providing ``word2idx`` for question tokens.
            answer_vocab: object providing ``word2idx`` and ``unk_index`` for answers.
            images_dir: root directory of the images.
            max_question_length: fixed length of the encoded question.
            transform: optional callable applied to the loaded RGB image.
        """
        self.phase = phase
        self.questions_json = questions_dir + "/" + self.phase + ".json"
        self.answers_json = answers_dir + "/" + self.phase + ".json"
        self.images_dir = images_dir
        self.max_question_length = max_question_length
        self.transform = transform

        # Vocabulary file paths are not stored: the constructor receives
        # ready-made vocab objects, and no path is among its parameters.
        self.question_vocab = question_vocab
        self.answer_vocab = answer_vocab

        self.dataset = self.create_dataset()

    def create_dataset(self):
        """Pair questions with annotations and return a shuffled list of samples.

        Questions and annotations are paired by position; a pair whose
        ``image_id`` values differ is skipped. Each sample is a dict with the
        keys ``image_path``, ``question`` and ``answers`` (the distinct
        answers whose ``answer_confidence`` is ``"yes"``).
        """
        with open(self.questions_json) as f:
            questions = json.load(f)["questions"]
        with open(self.answers_json) as f:
            annotations = json.load(f)["annotations"]

        dataset = []
        file_loop = tqdm(zip(questions, annotations), total=len(questions), colour="green",
                         desc=f"Generating {self.phase} data")
        for q, a in file_loop:
            if q["image_id"] != a["image_id"]:
                continue
            image_id = str(q["image_id"])
            image_path = self.images_dir + "/" + self.phase + "/" + image_id + ".jpg"

            confident_answers = []
            for answer in a["answers"]:
                if (answer["answer_confidence"] == "yes") and (answer["answer"] not in confident_answers):
                    confident_answers.append(answer["answer"])

            dataset.append({
                "image_path": image_path,
                "question": q["question"],
                "answers": confident_answers,
            })

        random.shuffle(dataset)
        return dataset

    def __len__(self):
        """Return the number of samples."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Return ``(image, question, answer)`` for the sample at ``index``.

        ``question`` is a NumPy array of exactly ``max_question_length`` token
        indices, padded with the ``<pad>`` index or truncated as needed.
        ``answer`` is one randomly chosen in-vocabulary answer index, or the
        answer vocabulary's ``unk_index`` when none of the answers is known.
        """
        if torch.is_tensor(index):
            index = index.tolist()

        sample = self.dataset[index]
        image = Image.open(sample["image_path"]).convert("RGB")
        if self.transform is not None:
            image = self.transform(image)

        pad_index = self.question_vocab.word2idx("<pad>")
        question = np.array([pad_index] * self.max_question_length)
        question_tokenised = [self.question_vocab.word2idx(token) for token in tokenise(sample["question"])]
        # Truncate first so the result is always a fixed-length array; returning
        # a plain list for over-long questions breaks batch collation.
        question_tokenised = question_tokenised[:self.max_question_length]
        question[:len(question_tokenised)] = question_tokenised

        unk_index = self.answer_vocab.unk_index
        answer_indices = [self.answer_vocab.word2idx(ans) for ans in sample["answers"]]
        # Drop every unknown answer, not only the first one.
        known_answers = [idx for idx in answer_indices if idx != unk_index]
        if len(known_answers) == 0:
            answer = unk_index
        else:
            answer = random.choice(known_answers)

        return image, question, answer

In [None]:
def get_loader(question_dir, answer_dir, question_vocab_path, answer_vocab_path, images_dir, max_question_length=20, transform=None, shuffle=True, num_workers=2, batch_size=64):
    """Build a ``DataLoader`` for the "train" and the "val" split.

    Both splits share the same vocabularies, transform, batch size, shuffle
    flag and worker count.

    Returns:
        Dict mapping ``"train"`` and ``"val"`` to their ``DataLoader``.
    """
    question_vocab = Vocab(question_vocab_path)
    answer_vocab = Vocab(answer_vocab_path)

    def build_phase_loader(phase):
        # One dataset + loader per split, configured identically.
        phase_dataset = VQADataset(phase, question_dir, answer_dir, question_vocab, answer_vocab,
                                   images_dir, max_question_length, transform)
        return DataLoader(dataset=phase_dataset,
                          batch_size=batch_size,
                          shuffle=shuffle,
                          num_workers=num_workers)

    return {phase: build_phase_loader(phase) for phase in ("train", "val")}

In [None]:
class ImgEncoder(nn.Module):
    """Image encoder: frozen pre-trained ResNet-152 followed by a trainable linear layer."""

    def __init__(self, embed_size):
        """
            loads pre-trained ResNet model
            generates the image features from the input image

            embed_size: output dimension of the trainable projection layer
        """

        super(ImgEncoder, self).__init__()
        self.cnn_network = resnet152(weights=ResNet152_Weights.DEFAULT)
        in_features = self.cnn_network.fc.in_features
        # Replace the classification head so the network returns pooled features.
        self.cnn_network.fc = nn.Identity()
        self.fc = nn.Linear(in_features, embed_size)

        for parameter in self.cnn_network.parameters():
            parameter.requires_grad = False
        # Frozen backbone: keep it in eval mode from the start.
        self.cnn_network.eval()

    def train(self, mode=True):
        """
            Switch train/eval mode while keeping the frozen backbone in eval mode.

            Without this, a parent ``model.train()`` call would put the
            backbone's normalisation layers back into training mode and their
            running statistics would be updated during forward passes, even
            though the backbone's weights are frozen.
        """
        super(ImgEncoder, self).train(mode)
        self.cnn_network.eval()
        return self

    def forward(self, image):
        """
            Extract feature vector from image vector
        """
        with torch.no_grad():
            img_feature = self.cnn_network(image).flatten(start_dim=1)

        # L2-normalise each feature vector before the trainable projection.
        img_feature = F.normalize(img_feature, p=2.0, dim=1)
        img_feature = self.fc(img_feature)

        return img_feature


class QstEncoder(nn.Module):
    """Question encoder: embedding -> multi-layer LSTM -> linear projection.

    The final hidden and cell states of every LSTM layer are concatenated
    into one vector per question and projected to ``embed_size``.
    """

    def __init__(self, qst_vocab_size, word_embed_size, embed_size=1024, num_layers=2, hidden_size=512):
        super().__init__()
        self.word2vec = nn.Embedding(qst_vocab_size, word_embed_size)
        self.lstm = nn.LSTM(word_embed_size, hidden_size, num_layers)
        # Input width: hidden + cell state for each of the num_layers layers.
        self.fc = nn.Linear(2 * num_layers * hidden_size, embed_size)

    def forward(self, question):
        """Encode a batch of token-index sequences into question features."""
        # The LSTM is not batch_first, so swap the first two axes.
        embedded = self.word2vec(question).transpose(0, 1)
        _, final_states = self.lstm(embedded)
        # (hidden, cell) joined along the feature axis, then batch moved to the front.
        per_sample = torch.cat(final_states, dim=2).transpose(0, 1)
        flattened = per_sample.reshape(per_sample.size(0), -1)
        return self.fc(flattened)


class VQAModel(nn.Module):
    """VQA model: fuses image and question features by element-wise product.

    The fused vector goes through two ``tanh -> dropout -> linear`` stages
    and the output of the second linear layer is returned as answer scores.
    """

    def __init__(self, qst_vocab_size, ans_vocab_size, word_embed_size, embed_size=1024, num_layers=2, hidden_size=512):
        super().__init__()
        self.img_encoder = ImgEncoder(embed_size)
        self.qst_encoder = QstEncoder(qst_vocab_size, word_embed_size, embed_size, num_layers, hidden_size)
        self.tanh = nn.Tanh()
        self.dropout = nn.Dropout(0.1)
        self.fc1 = nn.Linear(embed_size, ans_vocab_size)
        self.fc2 = nn.Linear(ans_vocab_size, ans_vocab_size)

    def forward(self, img, qst):
        """Return unnormalised answer scores for a batch of (image, question) pairs."""
        img_feature = self.img_encoder(img)
        qst_feature = self.qst_encoder(qst)
        fused = img_feature * qst_feature
        # Same activation/regularisation in front of each linear layer.
        for linear in (self.fc1, self.fc2):
            fused = linear(self.dropout(self.tanh(fused)))
        return fused

In [None]:
def save_checkpoint(epoch, model_dir, model, optimizer=None, scheduler=None):
    """Save model (and optionally optimizer/scheduler) state to ``model_dir``.

    The file is written to ``<model_dir>/epoch-<epoch>.pth`` and holds a dict
    with the keys ``model_state``, ``optimizer_state`` and ``scheduler_state``;
    the latter two are ``None`` when the corresponding object is not given.

    Args:
        epoch: value used (via ``str``) in the file name.
        model_dir: target directory; created if it does not exist.
        model: object providing ``state_dict()``.
        optimizer: optional object providing ``state_dict()``.
        scheduler: optional object providing ``state_dict()``.
    """
    checkpoint = {
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict() if optimizer is not None else None,
        "scheduler_state": scheduler.state_dict() if scheduler is not None else None,
    }

    # torch.save does not create missing directories.
    os.makedirs(model_dir, exist_ok=True)

    file_path = model_dir + "/epoch-" + str(epoch) + ".pth"
    torch.save(checkpoint, file_path)


def load_checkpoint(model_file_path, model, optimizer=None, scheduler=None):
    """Restore state saved by ``save_checkpoint`` into the given objects.

    If ``model_file_path`` is ``None`` or does not exist, nothing is loaded
    and a message is printed. Optimizer and scheduler state are restored only
    when the object is supplied and the checkpoint holds a non-``None`` state
    for it.

    Args:
        model_file_path: path to the checkpoint file, or ``None``.
        model: object providing ``load_state_dict``.
        optimizer: optional object providing ``load_state_dict``.
        scheduler: optional object providing ``load_state_dict``.
    """
    # os.path.exists(None) raises TypeError, so test for None first.
    if model_file_path is None or not os.path.exists(model_file_path):
        print("No checkpoint found at {}; nothing loaded".format(model_file_path))
        return

    # Load onto the CPU so a checkpoint written on a GPU machine can be read
    # anywhere; load_state_dict copies the values into the existing tensors.
    checkpoint = torch.load(model_file_path, map_location="cpu")
    model.load_state_dict(checkpoint["model_state"])

    optimizer_state = checkpoint.get("optimizer_state")
    if optimizer_state is not None and optimizer is not None:
        optimizer.load_state_dict(optimizer_state)

    scheduler_state = checkpoint.get("scheduler_state")
    if scheduler_state is not None and scheduler is not None:
        scheduler.load_state_dict(scheduler_state)

    print("Checkpoint loaded")

In [None]:
def train_model(model, criterion, optimizer, scheduler, device, num_epochs, data_loader, model_file_path, model_dir):
    """Train ``model`` and evaluate it on the validation split after every epoch.

    Each epoch runs the "train" loader (optimising, stepping the scheduler
    after every batch, then saving a checkpoint via ``save_checkpoint``)
    followed by the "val" loader (no gradient tracking; prints the average
    loss and the accuracy).

    Args:
        model: network called as ``model(image, question)``.
        criterion: loss called as ``criterion(output, answer)``.
        optimizer: optimizer over the trainable parameters.
        scheduler: optional learning-rate scheduler, stepped once per training batch.
        device: device the model and the batches are moved to.
        num_epochs: number of epochs to run.
        data_loader: dict with a "train" and a "val" loader yielding
            ``(image, question, answer)`` batches.
        model_file_path: optional checkpoint to restore before training.
        model_dir: directory passed to ``save_checkpoint``.
    """
    model = model.to(device)

    if model_file_path is not None:
        load_checkpoint(model_file_path, model, optimizer, scheduler)

    for epoch in range(num_epochs):
        for phase in ["train", "val"]:
            is_training = (phase == "train")
            model.train(is_training)

            correct: int = 0
            num_samples: int = 0
            total_loss: float = 0.0

            enumerator = tqdm(data_loader[phase], total=len(data_loader[phase]), leave=True, colour="green")
            if is_training:
                enumerator.set_description(f"Epoch : [{epoch+1}/{num_epochs}]")
            else:
                enumerator.set_description("Running validation set. Calculating accuracy...")

            for image, question, answer in enumerator:
                image = image.to(device)
                question = question.to(device).long()
                answer = answer.to(device)

                # Track gradients only while training; validation needs none.
                with torch.set_grad_enabled(is_training):
                    output = model(image, question)
                    loss = criterion(output, answer)

                if is_training:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    if scheduler is not None:
                        scheduler.step()

                    enumerator.set_postfix(loss=loss.item())
                else:
                    batch_count = answer.size(0)
                    # Assumes criterion returns the batch mean (the default
                    # reduction); weighting by batch size yields a true
                    # per-sample average when divided by num_samples below.
                    total_loss += loss.item() * batch_count
                    prediction = output.argmax(dim=1)
                    correct += (prediction == answer).sum().item()
                    num_samples += batch_count

            if is_training:
                save_checkpoint(epoch+1, model_dir, model, optimizer, scheduler)
            elif num_samples > 0:
                val_accuracy = correct / num_samples
                average_loss = total_loss / num_samples
                print(f"\nAverage loss on validation set : {average_loss:.3f} | Accuracy on validation set : {val_accuracy:.3f}")
            else:
                print("\nValidation set is empty; skipping metrics")


In [None]:
# Resize both image splits into images_dir/<phase>.
images_dir = "/content/images"
image_size = (224, 224)
train_input_dir = "/content/train2014"
val_input_dir = "/content/val2014"

resize_image_dataset(phase="train", input_dir=train_input_dir, images_dir=images_dir, image_size=image_size)
resize_image_dataset(phase="val", input_dir=val_input_dir, images_dir=images_dir, image_size=image_size)

In [None]:
# Build the question and answer vocabularies from the training split.
vocab_dir = "/content/vocab"

# Questions: keep tokens that pass the min_word_count threshold.
min_word_count = 3
questions_json = "/content/questions/train.json"
question_vocab_path, question_vocab_size = create_question_vocab(
    json_file=questions_json,
    vocab_dir=vocab_dir,
    min_word_count=min_word_count,
)

# Answers: cap the vocabulary at topN entries (plus <unk>).
topN = 1000
answers_json = "/content/answers/train.json"
answer_vocab_path, answer_vocab_size = create_answer_vocab(
    json_file=answers_json,
    vocab_dir=vocab_dir,
    topN=topN,
)

In [None]:
# Build the train/val datasets and their dataloaders.
question_dir = "/content/questions"
answer_dir = "/content/answers"
max_question_length = 30
max_number_answers = 10
batch_size = 64
num_workers = 2
shuffle = True

# Per-channel mean/std - presumably the ImageNet statistics expected by the
# pre-trained ResNet; confirm against the weights in use.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

data_loader = get_loader(
    question_dir,
    answer_dir,
    question_vocab_path,
    answer_vocab_path,
    images_dir,
    max_question_length=max_question_length,
    transform=transform,
    shuffle=shuffle,
    num_workers=num_workers,
    batch_size=batch_size,
)

In [None]:
# Hyper-parameters and model construction.
model_dir = "/content/model_parameters"
os.makedirs(model_dir, exist_ok=True)

# Architecture
word_embed_size = 100
embed_size = 1024
hidden_size = 512
num_layers = 2

# Optimisation
num_epochs = 20
learning_rate = 3e-4

# Use the GPU when one is available.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

model = VQAModel(question_vocab_size,
                 answer_vocab_size,
                 word_embed_size,
                 embed_size=embed_size,
                 num_layers=num_layers,
                 hidden_size=hidden_size)

In [None]:
model_file_path = None # replace by path to existing model weights

# Only the model weights are restored here: the optimizer and scheduler are
# created in a later cell, so referencing them at this point would raise
# NameError on a fresh kernel. Skip entirely while no path is configured.
if model_file_path is not None:
    load_checkpoint(model_file_path, model)

In [None]:
# train the model
# Optimise only the parameters that are not frozen.
parameters = [parameter for parameter in model.parameters() if parameter.requires_grad]

# train_model steps the scheduler once per training batch, so the cosine
# schedule must span the total number of training batches (not
# num_epochs * batch_size, which is unrelated to the step count).
T_max = num_epochs * len(data_loader["train"])
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(parameters, lr=learning_rate)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=T_max)

# NOTE(review): save_checkpoint names files "epoch-<n>.pth" without zero
# padding (e.g. "epoch-9.pth"), so this path is never produced by this
# notebook itself - confirm the intended file name.
model_file_path = "/content/model_parameters/epoch-09.pth"

train_model(model=model,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            device=device,
            num_epochs=num_epochs,
            data_loader=data_loader,
            model_file_path=model_file_path,
            model_dir=model_dir)