## Reduce the size of the dataset and preprocess the data


Import

In [1]:
# Resizing img
import os
import argparse
from PIL import Image
from sklearn.model_selection import train_test_split

# Make vocab
import json
import re
from collections import defaultdict

# Preprocess data
import glob
import numpy as np

# Build dataset
from torch.utils.data import Dataset
from torchvision import transforms
from torch.utils.data import DataLoader

# Models
import torch
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F

# Training
import time
from torch import optim



Define paths

In [2]:
img_pth = 'dataset\img'
ann_pth = 'dataset\\ann'
qst_pth = 'dataset\qst'
out_img_pth = 'preprocessed\img'
out_data_pth = 'preprocessed\data'
out_vocab_pth = 'preprocessed\\vocab'
out_ann_pth = 'preprocessed\\ann'
out_qst_pth = 'preprocessed\qst'
ckpt_pth = 'late_fusion\ckpt'
log_pth = 'late_fusion\log'

Preprocess images

In [3]:
def resize_image(image, size):
    """Resize an image to the given size."""
    return image.resize(size, Image.Resampling.LANCZOS)

def resize_images(input_dir, output_dir, size, split_ratio):
    """Resize the images in 'input_dir' and save into 'output_dir'."""
    for idir in os.scandir(input_dir):
        if not idir.is_dir():
            print('No valid directory')
            continue
        if not os.path.exists(output_dir+'\\'+idir.name):
            os.makedirs(output_dir+'\\'+idir.name)
        else:
            for file in os.listdir(output_dir+'\\'+idir.name):
                if os.path.isfile(os.path.join(output_dir+'\\'+idir.name, file)):
                    os.remove(os.path.join(output_dir+'\\'+idir.name, file))
                      
        images = os.listdir(idir.path)
        images, _ = train_test_split(images,
                                     test_size=split_ratio,
                                     shuffle=False)
        n_images = len(images)
        for id, image in enumerate(images):
            try:
                with open(os.path.join(idir.path, image), 'r+b') as f:
                    with Image.open(f) as img:
                        img = resize_image(img, size)
                        img.save(os.path.join(output_dir+'\\'+idir.name, image), img.format)
            except(IOError, SyntaxError) as e:
                pass
            if (id+1) % 500 == 0:
                print("[{}/{}] resized images and saved into '{}'."
                      .format(id+1, n_images, output_dir+'\\'+idir.name))
                    
def main():

    input_dir = img_pth
    output_dir = out_img_pth
    image_size = [224, 224]
    split_ratio = 0.95
    resize_images(input_dir, output_dir, image_size, split_ratio)
    
main()

[500/1000] resized images and saved into 'preprocessed\img\img_test'.
[1000/1000] resized images and saved into 'preprocessed\img\img_test'.
[500/1000] resized images and saved into 'preprocessed\img\img_train'.
[1000/1000] resized images and saved into 'preprocessed\img\img_train'.
[500/500] resized images and saved into 'preprocessed\img\img_val'.


Removing questions and annotations that doesnt belong to the images

In [4]:
def removing(out_img_pth, ann_pth, qst_pth):

    test_ids = []
    train_ids = []
    val_ids = []
    for idir in os.scandir(out_img_pth):
        for file in os.listdir(idir):
            components = file.split('_')
            image_id = components[-1].split('.')[0]
            numeric_part = int(image_id)
            if 'test' in idir.name:
                test_ids.append(numeric_part)
            elif 'train' in idir.name:
                train_ids.append(numeric_part)
            elif 'val' in idir.name:
                val_ids.append(numeric_part)
    
    for idir in os.scandir(ann_pth):
        if 'test' in idir.name:
            ids = test_ids
        elif 'train' in idir.name:
            ids = train_ids
        elif 'val' in idir.name:
            ids = val_ids
        
        for file in os.listdir(idir):
            path = os.path.join(idir, file)
            with open(path, 'r') as f:
                data = json.load(f)
            annotations = data['annotations']
            prelen = len(annotations)
            labels = dict()
            number = 0
            for label in annotations:
                    if int(label['image_id']) in ids:
                        labels.update({number: label})   
                        number += 1
            print(f'{len(labels)} remaining annotations of {prelen} in {file}')
            data['annotations'] = labels
            with open(out_ann_pth + '\\' + file, 'w') as f:
                json.dump(data, f)
    
    for idir in os.scandir(qst_pth):
        if 'test' in idir.name:
            ids = test_ids
        elif 'train' in idir.name:
            ids = train_ids
        elif 'val' in idir.name:
            ids = val_ids
            
        for file in os.listdir(idir):
            path = os.path.join(idir, file)
            with open(path, 'r') as f:
                data = json.load(f)
            questions = data['questions']
            prelen = len(questions)
            labels = dict()
            number = 0
            for label in questions:
                    if int(label['image_id']) in ids:
                        labels.update({number: label}) 
                        number += 1
            print(f'{len(labels)} remaining questions of {prelen} in {file}')
            data['questions'] = labels
            with open(out_qst_pth + '\\' + file, 'w') as f:
                json.dump(data, f)

def main():
    img_pth = out_img_pth
    annot_pth = ann_pth
    quest_pth = qst_pth
    removing(img_pth, annot_pth, quest_pth)
    
main()

3000 remaining annotations of 60000 in ann_train.json
1500 remaining annotations of 30000 in ann_val.json
3000 remaining questions of 60000 in multi_qst_test.json
3000 remaining questions of 60000 in open_qst_test.json
3000 remaining questions of 60000 in multi_qst_train.json
3000 remaining questions of 60000 in open_qst_train.json
1500 remaining questions of 30000 in multi_qst_val.json
1500 remaining questions of 30000 in open_qst_val.json


Make vocab

In [36]:
def make_q_vocab(input_dir, output_dir):

    for file in os.scandir(input_dir):
        if "test" in file.name:
            continue
        if not os.path.exists(output_dir):
            os.makedirs(output_dir) 
            
        regex = re.compile(r'(\W+)')
        q_vocab = []
        path = os.path.join(input_dir, file.name)
        with open(path, 'r') as f:
            q_data = json.load(f)
        question = q_data['questions'].values()
        for quest in question:
            split = regex.split(quest['question'].lower())
            tmp = [w.strip() for w in split if len(w.strip()) > 0]
            q_vocab.extend(tmp)
    
        q_vocab = list(set(q_vocab))
        q_vocab.sort()
        q_vocab.insert(0, '<pad>')
        q_vocab.insert(1, '<unk>')
    
        with open(output_dir + '\\' + file.name.split(".")[0] + '_vocabs.txt', 'w') as f:
            f.writelines([v+'\n' for v in q_vocab])

        print(f'The number of total words of questions in {file.name}: {len(q_vocab)}')

def make_a_vocab(input_dir, output_dir):

    for file in os.scandir(input_dir):
        if not os.path.exists(output_dir):
            os.makedirs(output_dir) 
            
        answers = defaultdict(lambda :0)
        path = os.path.join(input_dir, file.name)
        with open(path, 'r') as f:
            data = json.load(f)
        annotations = data['annotations'].values()
        for label in annotations:
            for ans in label['answers']:
                vocab = ans['answer']
                if re.search(r'[^\w\s]', vocab):
                    continue
                answers[vocab] += 1
    
        answers = sorted(answers, key=answers.get, reverse= True) 
        with open(output_dir + '\\' + file.name.split(".")[0] + '_vocabs.txt', 'w') as f :
            f.writelines([ans+'\n' for ans in answers])
            
        print(f'The number of total words of answers in {file.name}: {len(answers)}')

def make_vocab(output_dir):
    ann_vocab = set()
    qst_vocab = set()
    for file in os.scandir(output_dir):
        if 'ann' in file.name:
            with open(file.path, 'r') as f:
                for line in f: 
                    print(line)
                    ann_vocab.add(line.split('\n')[0])
        elif 'qst' in file.name:
            with open(file.path, 'r') as f:
                for line in f: 
                    qst_vocab.add(line.split('\n')[0])

    print(ann_vocab)
    print(f'The number of total words of answers in ann_vocabs will be: {len(ann_vocab)}')
    print(f'The number of total words of answers in qst_vocabs will be: {len(qst_vocab)}')
    
    with open(output_dir + '\\ann_vocabs.txt', 'w') as f:
        f.writelines([ans+'\n' for ans in ann_vocab])
    with open(output_dir + '\\qst_vocabs.txt', 'w') as f:
        f.writelines([ans+'\n' for ans in qst_vocab])

def main():
    input_qst_dir = out_qst_pth
    input_ann_dir = out_ann_pth
    output_vocab_dir = out_vocab_pth
    make_q_vocab(input_qst_dir, output_vocab_dir)
    make_a_vocab(input_ann_dir, output_vocab_dir)
    make_vocab(out_vocab_pth)
main()
    

The number of total words of questions in multi_qst_train.json: 1283
The number of total words of questions in multi_qst_val.json: 937
The number of total words of questions in open_qst_train.json: 1283
The number of total words of questions in open_qst_val.json: 937
The number of total words of answers in ann_train.json: 2131
The number of total words of answers in ann_val.json: 1288
yes

no

2

1

3

red

4

dog

white

brown

blue

yellow

cat

0

green

soccer

sitting

5

girl

right

table

gray

floor

log

bench

tan

wine

black

skateboard

playing

football

sunny

man

orange

duck

woman

sleeping

monkey bars

no one

rug

bone

baseball

couch

yarn

blanket

bird

mouse

chair

plant

book

grass

tree

boy

beehive

fish

stool

nothing

jumping

standing

pie

sun

left

food

pond

bush

sandbox

picnic

baby

cloud

toys

sand

bike

sidewalk

maybe

jump rope

on table

slide

sunset

beige

owl

dollhouse

on blanket

tv

deer

soccer ball

pizza

happy

picture



Make Vocab

In [39]:
def preprocessing(image_dir, annotation_dir, question_dir, output_dir, vocab_dir):
    
    dataset = dict()
    for file in os.scandir(question_dir):
        info = dict()
        
        if 'test' in file.name:
            continue
        elif 'train' in file.name:
            datatype = 'train'
        elif 'val' in file.name:
            datatype = 'val'
        
        with open(file.path, 'r') as f:
            data = json.load(f)
            questions = data['questions'].values()
    
        for ann in os.scandir(annotation_dir):
            if datatype == 'train' and 'train' in ann.name:
                with open(ann.path) as f:
                    annotations = json.load(f)['annotations'].values()
            elif datatype == 'val' and 'val' in ann.name:
                with open(ann.path) as f:
                    annotations = json.load(f)['annotations'].values()
        question_dict = {ans['question_id']: ans for ans in annotations}
        
        match_top_ans.unk_ans = 0
        num = 0
        for idx, qu in enumerate(questions):
            if (idx+1) % 1500 == 0:
                print(f'Processing {datatype} data: {idx+1}/{len(questions)}')
            qu_id = qu['question_id']
            qu_sentence = qu['question']
            qu_tokens = tokenizer(qu_sentence)
            img_id = qu['image_id']
            for dir in os.scandir(image_dir):
                if 'train' == datatype and 'train' in dir.name:
                    dir_path = dir.path
                elif 'val' == datatype and 'val' in dir.name:
                    dir_path = dir.path
                else:
                    continue
                for img in os.scandir(dir_path):
                    components = img.name.split('_')
                    image_id = components[-1].split('.')[0]
                    numeric_part = int(image_id)
                    if img_id == numeric_part:
                        img_path = img.path
            annotation_ans = question_dict[qu_id]['answers']
            
            qu_info = dict()
            qu_info.update({'img_id': img_id,
                            'img_path': img_path,
                            'qu_id': qu_id,
                            'qu_sentence': qu_sentence,
                            'qu_tokens': qu_tokens})
            
            for voc in os.scandir(vocab_dir):
                if 'ann' in voc.name:
                    if datatype == 'train' and 'train' in voc.name:
                        voc_path = voc.path
                    elif datatype == 'val' and 'val' in voc.name:
                        voc_path = voc.path
            
            all_ans, valid_ans = match_top_ans(annotation_ans, voc_path)
            qu_info.update({'all_ans': list(all_ans),
                            'valid_ans': list(valid_ans)})   

            info.update({idx: qu_info})
            
        dataset.update({datatype: info})
        print(f'Total {match_top_ans.unk_ans} out of {len(questions)} answers are <unk>')

    np.save(output_dir + '\\train.npy', np.array(dataset['train']))
    np.save(output_dir + '\\val.npy', np.array(dataset['val']))
    with open(output_dir + '\\train.json', 'w') as f:
        json.dump(dataset['train'], f)
    with open(output_dir + '\\val.json', 'w') as f:
        json.dump(dataset['val'], f)

def tokenizer(sentence):

    regex = re.compile(r'(\W+)')
    tokens = regex.split(sentence.lower())
    tokens = [w.strip() for w in tokens if len(w.strip()) > 0]
    return tokens

def match_top_ans(annotation_ans, vocab_path):
    
    if "top_ans" not in match_top_ans.__dict__:
        with open(vocab_path, 'r') as f:
            match_top_ans.top_ans = {line.strip() for line in f}
    annotation_ans = {ans['answer'] for ans in annotation_ans}
    valid_ans = match_top_ans.top_ans & annotation_ans

    if len(valid_ans) == 0:
        valid_ans = ['<unk>']
        match_top_ans.unk_ans += 1

    return annotation_ans, valid_ans

def main():

    image_dir = out_img_pth
    annotation_dir = out_ann_pth
    question_dir = out_qst_pth
    output_dir = out_data_pth
    vocab_dir = out_vocab_pth
    preprocessing(image_dir, annotation_dir, question_dir, output_dir, vocab_dir)

main()

Processing train data: 1500/3000
Processing train data: 3000/3000
Total 1 out of 3000 answers are <unk>
Processing val data: 1500/1500
Total 27 out of 1500 answers are <unk>
Processing train data: 1500/3000
Processing train data: 3000/3000
Total 1 out of 3000 answers are <unk>
Processing val data: 1500/1500
Total 27 out of 1500 answers are <unk>


Build dataset

In [None]:
class VQADataset(Dataset):

    def __init__(self, input_dir, input_file, max_qu_len = 30, transform = None):

        self.input_data = np.load(os.path.join(input_dir, input_file), allow_pickle=True)
        self.qu_vocab = Vocab(input_dir+'/qst_vocabs.txt')
        self.ans_vocab = Vocab(input_dir+'/ann_vocabs.txt')
        self.max_qu_len = max_qu_len
        self.transform = transform

    def __getitem__(self, idx):

        path = self.input_data[idx]['img_path']
        img = np.array(Image.open(path).convert('RGB'))
        qu_id = self.input_data[idx]['qu_id']
        qu_tokens = self.input_data[idx]['qu_tokens']
        qu2idx = np.array([self.qu_vocab.word2idx('<pad>')] * self.max_qu_len)
        qu2idx[:len(qu_tokens)] = [self.qu_vocab.word2idx(token) for token in qu_tokens]
        sample = {'image': img, 'question': qu2idx, 'question_id': qu_id}

        if self.labeled:
            ans2idx = [self.ans_vocab.word2idx(ans) for ans in self.input_data[idx]['valid_ans']]
            ans2idx = np.random.choice(ans2idx)
            sample['answer'] = ans2idx

        if self.transform:
            sample['image'] = self.transform(sample['image'])

        return sample

    def __len__(self):

        return len(self.input_data)

def data_loader(input_dir, batch_size, max_qu_len, num_worker):

    transform = transforms.Compose([
        transforms.ToTensor(),  # convert to (C,H,W) and [0,1]
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))  # mean=0; std=1
    ])

    vqa_dataset = {
        'train': VQADataset(
            input_dir=input_dir,
            input_file='train.npy',
            max_qu_len=max_qu_len,
            transform=transform),
        'val': VQADataset(
            input_dir=input_dir,
            input_file='val.npy',
            max_qu_len=max_qu_len,
            transform=transform)
    }

    dataloader = {
        key: DataLoader(
            dataset=vqa_dataset[key],
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_worker)
        for key in ['train', 'val']
    }

    return dataloader

class Vocab:

    def __init__(self, vocab_file):

        self.vocab = self.load_vocab(vocab_file)
        self.vocab2idx = {vocab: idx for idx, vocab in enumerate(self.vocab)}
        self.vocab_size = len(self.vocab)

    def load_vocab(self, vocab_file):

        with open(vocab_file) as f:
            vocab = [v.strip() for v in f]

        return vocab

    def word2idx(self, vocab):

        if vocab in self.vocab2idx:
            return self.vocab2idx[vocab]
        else:
            return self.vocab2idx['<unk>']

    def idx2word(self, idx):

        return self.vocab[idx]

Late fusion models

In [None]:
class ImgEncoder(nn.Module):

    def __init__(self, embed_dim):

        super(ImgEncoder, self).__init__()
        self.model = models.vgg19(pretrained=True)
        in_features = self.model.classifier[-1].in_features
        self.model.classifier = nn.Sequential(*list(self.model.classifier.children())[:-1]) # remove vgg19 last layer
        self.fc = nn.Linear(in_features, embed_dim)

    def forward(self, image):

        with torch.no_grad():
            img_feature = self.model(image) # (batch, channel, height, width)
        img_feature = self.fc(img_feature)

        l2_norm = F.normalize(img_feature, p=2, dim=1).detach()
        return l2_norm

class QuEncoder(nn.Module):

    def __init__(self, qu_vocab_size, word_embed, hidden_size, num_hidden, qu_feature_size):

        super(QuEncoder, self).__init__()
        self.word_embedding = nn.Embedding(qu_vocab_size, word_embed)
        self.tanh = nn.Tanh()
        self.lstm = nn.LSTM(word_embed, hidden_size, num_hidden) # input_feature, hidden_feature, num_layer
        self.fc = nn.Linear(2*num_hidden*hidden_size, qu_feature_size)

    def forward(self, question):

        qu_embedding = self.word_embedding(question)                # (batchsize, qu_length=30, word_embed=300)
        qu_embedding = self.tanh(qu_embedding)
        qu_embedding = qu_embedding.transpose(0, 1)                 # (qu_length=30, batchsize, word_embed=300)
        _, (hidden, cell) = self.lstm(qu_embedding)                 # (num_layer=2, batchsize, hidden_size=1024)
        qu_feature = torch.cat((hidden, cell), dim=2)               # (num_layer=2, batchsize, 2*hidden_size=1024)
        qu_feature = qu_feature.transpose(0, 1)                     # (batchsize, num_layer=2, 2*hidden_size=1024)
        qu_feature = qu_feature.reshape(qu_feature.size()[0], -1)   # (batchsize, 2*num_layer*hidden_size=2048)
        qu_feature = self.tanh(qu_feature)
        qu_feature = self.fc(qu_feature)                            # (batchsize, qu_feature_size=1024)

        return qu_feature

class VQAModel(nn.Module):

    def __init__(self, feature_size, qu_vocab_size, ans_vocab_size, word_embed, hidden_size, num_hidden):

        super(VQAModel, self).__init__()
        self.img_encoder = ImgEncoder(feature_size)
        self.qu_encoder = QuEncoder(qu_vocab_size, word_embed, hidden_size, num_hidden, feature_size)
        self.dropout = nn.Dropout(0.5)
        self.tanh = nn.Tanh()
        self.fc1 = nn.Linear(feature_size, ans_vocab_size)
        self.fc2 = nn.Linear(ans_vocab_size, ans_vocab_size)

    def forward(self, image, question):

        img_feature = self.img_encoder(image)               # (batchsize, feature_size=1024)
        qst_feature = self.qu_encoder(question)
        combined_feature = img_feature * qst_feature
        combined_feature = self.dropout(combined_feature)
        combined_feature = self.tanh(combined_feature)
        combined_feature = self.fc1(combined_feature)       # (batchsize, ans_vocab_size=1000)
        combined_feature = self.dropout(combined_feature)
        combined_feature = self.tanh(combined_feature)
        logits = self.fc2(combined_feature)                 # (batchsize, ans_vocab_size=1000)

        return logits

Training loop

In [None]:


BATCH_SIZE = 150
MAX_QU_LEN = 30
NUM_WORKER = 8
FEATURE_SIZE, WORD_EMBED = 1024, 300
NUM_HIDDEN, HIDDEN_SIZE = 2, 512
LEARNING_RATE, STEP_SIZE, GAMMA = 0.001, 10, 0.1
EPOCH = 50

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

def train():

    dataloader = data_loader(input_dir=out_data_pth, batch_size=BATCH_SIZE, max_qu_len=MAX_QU_LEN, num_worker=NUM_WORKER)
    qu_vocab_size = dataloader['train'].dataset.qu_vocab.vocab_size
    ans_vocab_size = dataloader["train"].dataset.ans_vocab.vocab_size

    model = VQAModel(feature_size=FEATURE_SIZE, qu_vocab_size=qu_vocab_size, ans_vocab_size=ans_vocab_size,
                     word_embed=WORD_EMBED, hidden_size=HIDDEN_SIZE, num_hidden=NUM_HIDDEN).to(device)

    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=STEP_SIZE, gamma=GAMMA)
    criterion = nn.CrossEntropyLoss()

    print('>> start training')
    start_time = time.time()
    for epoch in range(EPOCH):
        epoch_loss = {key: 0 for key in ['train', 'val']}

        model.train()
        for idx, sample in enumerate(dataloader['train']):

            image = sample['image'].to(device=device)
            question = sample['question'].to(device=device)
            label = sample['answer'].to(device=device)
            # forward
            logits = model(image, question)
            loss = criterion(logits, label)
            epoch_loss['train'] += loss.item()
            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        model.eval()
        for idx, sample in enumerate(dataloader['val']):

            image = sample['image'].to(device=device)
            question = sample['question'].to(device=device)
            label = sample['answer'].to(device=device)
            with torch.no_grad():
                logits = model(image, question)
                loss = criterion(logits, label)
            epoch_loss['val'] += loss.item()

        # statistic
        for phase in ['train', 'val']:
            epoch_loss[phase] /= len(dataloader[phase])
            with open(os.path.join(LOG_DIR, f'{phase}_log.txt'), 'a') as f:
                f.write(str(epoch+1) + '\t' + str(epoch_loss[phase]) + '\n')
        print('Epoch:{}/{} | Training Loss: {train:6f} | Validation Loss: {val:6f}'.format(epoch+1, EPOCH, **epoch_loss))

        scheduler.step()
        early_stop = early_stopping(model, epoch_loss['val'])
        if (epoch+1) % 5 == 0:
            torch.save(model.state_dict(), os.path.join(ckpt_pth, f'model-epoch-{epoch+1}.pth'))
        if early_stop:
            print(f'>> Early stop at {epoch+1} epoch')
            break

    end_time = time.time()
    training_time = end_time - start_time
    print(f">> Finishing training | Training Time:{training_time//60:.0f}m:{training_time%60:.0f}s")

def early_stopping(model, epoch_loss, patience=7):

    early_stop = False
    if not bool(early_stopping.__dict__):
        early_stopping.best_loss = epoch_loss
        early_stopping.record_loss = epoch_loss
        early_stopping.counter = 0

    if epoch_loss < early_stopping.best_loss:
        early_stopping.best = epoch_loss
        torch.save(model.state_dict(), os.path.join(ckpt_pth, 'best_model.pth'))

    if epoch_loss > early_stopping.record_loss:
        early_stopping.counter += 1
        if early_stopping.counter >= patience:
            early_stop = True
    else:
        early_stopping.counter = 0
        early_stopping.record_loss = epoch_loss

    return early_stop

if __name__ == '__main__':

    if not os.path.exists(LOG_DIR):
        os.makedirs(LOG_DIR)
    if not os.path.exists(CKPT_DIR):
        os.makedirs(CKPT_DIR)
    train()