# Attention model for Visual Question Answering in MXNet Gluon

In this notebook, we introduce attention from <a href="https://arxiv.org/pdf/1606.01847.pdf">"Multimodal Compact Bilinear Pooling for Visual Question Answering and Visual Grounding"</a>, Akira Fukui, Dong Huk Park, Daylen Yang, Anna Rohrbach, Trevor Darrell and Marcus Rohrbach, EMNLP 2016.

## Models

For each spatial grid location in the visual representation, we use count sketch to merge the slice of the visual feature with the language representation. After the pooling we use two convolutional layers to predict the attention weight for each grid location. We apply softmax to produce a normalized soft attention map. We then take a weighted sum of the spatial vectors using the attention map to create the attended visual representation. 

![](img/VQA-attention.png )

In [1]:
from __future__ import print_function

import bisect
import json
import logging
import os

import numpy as np
import mxnet as mx
import mxnet.ndarray as F
import mxnet.contrib.ndarray as C
import mxnet.gluon as gluon
from mxnet import autograd
from mxnet.gluon import nn
from mxnet.test_utils import download
from IPython.core.display import display, HTML
from IPython.display import HTML, display

logging.basicConfig(level=logging.INFO)

  import OpenSSL.SSL


In [2]:
# Number of samples per mini-batch handed to the data iterators.
batch_size = 20
# Passed to the contrib count_sketch / fft / ifft operators as their
# processing batch size; kept equal to the mini-batch size.
compute_size = batch_size
# Output dimension of every count-sketch projection.
out_dim = 10000
# Not referenced elsewhere in the visible part of this notebook.
gpus = 1
# Device on which parameters, sketch tensors and batches are placed.
# Use mx.cpu() here instead to target the CPU.
ctx = mx.gpu(2)
class Net(gluon.Block):
    """VQA classifier with multimodal compact bilinear (MCB) attention.

    The network fuses question and image features twice with count sketch +
    FFT (compact bilinear pooling): once per spatial location to predict an
    attention map over the image grid, and once more between the question
    and the attended image vector to produce the answer logits.

    Relies on the module-level ``ctx``, ``out_dim`` and ``compute_size``.
    """

    def __init__(self, **kwargs):
        super(Net, self).__init__(**kwargs)
        with self.name_scope():
            # layers created in name_scope will inherit name space
            # from parent layer.
            self.bn = nn.BatchNorm()
            self.dropout = nn.Dropout(0.3)
            self.fc1 = nn.Dense(8192, activation="relu")
            self.fc2 = nn.Dense(1000)
            # Two 3x3 convolutions that turn the pooled features into a
            # single-channel attention score per grid location.
            self.conv1 = nn.Conv2D(channels=512, strides=(1, 1), padding=(1, 1), kernel_size=3,
                                   activation='relu', layout='NCHW', use_bias=False)
            self.conv2 = nn.Conv2D(channels=1, strides=(1, 1), padding=(1, 1), kernel_size=3,
                                   activation='relu', layout='NCHW', use_bias=False)

    def forward(self, x):
        """Compute answer logits for a batch.

        Parameters
        ----------
        x : sequence of two NDArrays
            ``x[0]`` is the question feature batch and ``x[1]`` the image
            feature batch. The sketch sizes below assume 1024 question
            channels and 2048 image channels on a 14x14 grid -- TODO confirm
            against the data iterator.

        Returns
        -------
        NDArray
            Output of the final 1000-unit dense layer (un-normalised scores).
        """
        x1 = F.L2Normalization(x[0])
        x2 = F.L2Normalization(x[1])

        # Tile the question vector over the 14x14 grid so that it can be
        # fused with the image feature at every spatial location.
        text_data_copy = F.Reshape(x1, shape=(0, 0, 1, 1))
        text_data_copy = F.broadcast_to(text_data_copy, shape=(0, 0, 14, 14))
        # Move channels last: N x C x 14 x 14 -> N x 14 x 14 x C
        text_swapaxis = F.transpose(data=text_data_copy, axes=(0, 2, 3, 1))
        img_swapaxis = F.transpose(data=x2, axes=(0, 2, 3, 1))

        # First count sketch: per-location fusion of question and image.
        # NOTE(review): the sign (S*) and hash (H*) tensors are re-drawn at
        # random on every forward call; compact bilinear pooling normally
        # keeps them fixed -- confirm whether this is intended.
        S1 = F.array(np.random.randint(0, 2, (1, 1024)) * 2 - 1, ctx=ctx)
        H1 = F.array(np.random.randint(0, out_dim, (1, 1024)), ctx=ctx)
        S2 = F.array(np.random.randint(0, 2, (1, 2048)) * 2 - 1, ctx=ctx)
        H2 = F.array(np.random.randint(0, out_dim, (1, 2048)), ctx=ctx)
        cs1 = C.count_sketch(data=text_swapaxis, s=S1, h=H1, name='cs1',
                             out_dim=out_dim, processing_batch_size=compute_size)
        cs2 = C.count_sketch(data=img_swapaxis, s=S2, h=H2, name='cs2',
                             out_dim=out_dim, processing_batch_size=compute_size)
        fft1 = C.fft(data=cs1, name='fft1', compute_size=compute_size)
        # Fix: transform the IMAGE sketch here. The previous code passed cs1
        # again, so the image never contributed to the attention map.
        fft2 = C.fft(data=cs2, name='fft2', compute_size=compute_size)
        c = fft2 * fft1
        ifft = C.ifft(data=c, name='ifft1', compute_size=compute_size)

        # Attention: back to channels-first for the convolutions.
        mcb0_swapaxis = F.transpose(data=ifft, axes=(0, 3, 1, 2))
        cv0 = self.conv1(mcb0_swapaxis)
        cv1 = self.conv2(cv0)
        # NOTE(review): cv1 has a single channel (conv2 uses channels=1), so a
        # softmax across the channel axis would presumably give a constant
        # map; confirm whether a softmax over the 14x14 positions was meant.
        body = F.SoftmaxActivation(data=cv1, mode="channel", name='softmax0')
        # Weight every image location by its attention value and sum over
        # the spatial axes to obtain one attended vector per sample.
        image_attention = F.broadcast_mul(x2, body)
        image_attention = F.sum(image_attention, axis=(2, 3))

        # Second count sketch: fuse the question with the attended image.
        S3 = F.array(np.random.randint(0, 2, (1, 1024)) * 2 - 1, ctx=ctx)
        H3 = F.array(np.random.randint(0, out_dim, (1, 1024)), ctx=ctx)
        S4 = F.array(np.random.randint(0, 2, (1, 2048)) * 2 - 1, ctx=ctx)
        H4 = F.array(np.random.randint(0, out_dim, (1, 2048)), ctx=ctx)
        cs3 = C.count_sketch(data=x1, s=S3, h=H3, name='cs3',
                             out_dim=out_dim, processing_batch_size=compute_size)
        cs4 = C.count_sketch(data=image_attention, s=S4, h=H4, name='cs4',
                             out_dim=out_dim, processing_batch_size=compute_size)
        fft3 = C.fft(data=cs3, name='fft3', compute_size=compute_size)
        fft4 = C.fft(data=cs4, name='fft4', compute_size=compute_size)
        c2 = fft3 * fft4
        z = C.ifft(data=c2, name='ifft2', compute_size=compute_size)

        # Classifier head on top of the fused representation.
        z = self.fc1(z)
        z = self.bn(z)
        z = self.dropout(z)
        z = self.fc2(z)
        return z

## Data IO

The inputs of the data iterator are extracted image and question features. At each step, the data iterator will return a data batch list: question data batch and image data batch. 

In [3]:
from VQAtrainIter import VQAtrainIter

In this model, since we need to include the attention scheme, the image feature is a $2048 \times 14 \times 14$ tensor, extracted from Resnet-152 pool5. Because of the space limits of this tutorial, we use only 1000 training samples and 100 validation samples for this model.

We discuss how to extract the features in detail <a href="extract-feature.ipynb">here</a>.

In [4]:
# File names (question, image, answer) of each dataset split.
dataset_files = {'train_attention': ('train_question_attention.npz','train_img_attention.npz','train_ans_attention.npz'),
                 'validation_attention': ('val_question_attention.npz','val_img_attention.npz','val_ans_attention.npz')
                }
train_q, train_i, train_a = dataset_files['train_attention']
val_q, val_i, val_a = dataset_files['validation_attention']

url_format = 'https://apache-mxnet.s3-accelerate.amazonaws.com/gluon/dataset/VQA-notebook/{}'
# Fetch a split whenever ANY of its three files is missing. Checking only the
# question file would skip the download after an interrupted run that had
# already saved the question file but not the image/answer files.
if not all(os.path.exists(name) for name in (train_q, train_i, train_a)):
    logging.info('Downloading training dataset.')
    download(url_format.format(train_q),overwrite=True)
    download(url_format.format(train_i),overwrite=True)
    download(url_format.format(train_a),overwrite=True)
if not all(os.path.exists(name) for name in (val_q, val_i, val_a)):
    logging.info('Downloading validation dataset.')
    download(url_format.format(val_q),overwrite=True)
    download(url_format.format(val_i),overwrite=True)
    download(url_format.format(val_a),overwrite=True)

# Iterator settings forwarded to VQAtrainIter.
layout = 'NT'
bucket = [1024]

# Every archive stores its array under the key 'x'.
train_question = np.load(train_q)['x']
val_question = np.load(val_q)['x']
train_ans = np.load(train_a)['x']
val_ans = np.load(val_a)['x']
train_img = np.load(train_i)['x']
val_img = np.load(val_i)['x']

print("Total training sample:",train_ans.shape[0])   
print("Total validation sample:",val_ans.shape[0])   

data_train  = VQAtrainIter(train_img, train_question, train_ans, batch_size, buckets = bucket,layout=layout)
data_eva = VQAtrainIter(val_img, val_question, val_ans, batch_size, buckets = bucket,layout=layout) 


INFO:root:Downloading training dataset.
INFO:root:downloaded https://apache-mxnet.s3-accelerate.amazonaws.com/gluon/dataset/VQA-notebook/train_question_attention.npz into train_question_attention.npz successfully


KeyboardInterrupt: 

## Training

In [None]:
# Build the model, then Xavier-initialise all of its parameters on the
# target device.
net = Net()
net.collect_params().initialize(init=mx.init.Xavier(), ctx=ctx)

In [None]:
# Softmax cross-entropy between the network output and the answer labels.
loss = gluon.loss.SoftmaxCrossEntropyLoss()

# Accuracy accumulator used by evaluate_accuracy (reset on every call).
metric = mx.metric.Accuracy()

def evaluate_accuracy(data_iterator, net):
    """Return the classification accuracy of ``net`` over ``data_iterator``.

    Parameters
    ----------
    data_iterator : iterator with ``reset()``
        Yields batches whose ``data[0]``/``data[1]`` are the two network
        inputs and whose ``label[0]`` holds the target answers.
    net : callable
        Model invoked as ``net([data1, data2])``.

    Returns
    -------
    float
        Accuracy computed from the batches of this call only.
    """
    # Start from a clean accumulator; otherwise results of earlier calls
    # (other datasets, earlier epochs) leak into this accuracy.
    metric.reset()
    data_iterator.reset()
    for batch in data_iterator:
        data1 = batch.data[0].as_in_context(ctx)
        data2 = batch.data[1].as_in_context(ctx)
        label = batch.label[0].as_in_context(ctx)
        # No autograd.record() here: recording switches the network to
        # training mode (e.g. active dropout) and no gradients are needed.
        output = net([data1, data2])
        metric.update([label], [output])
    return metric.get()[1]


In [None]:
# Plain SGD over every parameter of the network.
trainer = gluon.Trainer(
    params=net.collect_params(),
    optimizer='sgd',
    optimizer_params={'learning_rate': 0.01},
)

In [None]:
epochs = 10
moving_loss = 0.
best_eva = 0
for e in range(epochs):
    data_train.reset()
    for i, batch in enumerate(data_train):
        data1 = batch.data[0].as_in_context(ctx)
        data2 = batch.data[1].as_in_context(ctx)
        data = [data1,data2]
        label = batch.label[0].as_in_context(ctx)
        # Record only the forward pass and the loss computation.
        with autograd.record():
            output = net(data)
            cross_entropy = loss(output, label)
        cross_entropy.backward()
        # Normalise the gradient step by the number of samples in the batch.
        trainer.step(data[0].shape[0])

        ##########################
        #  Keep a moving average of the losses
        ##########################
        # Average over the whole batch; indexing element [0] first would
        # track the loss of the first sample only.
        batch_loss = np.mean(cross_entropy.asnumpy())
        if i == 0:
            # The moving average restarts at the first batch of each epoch.
            moving_loss = batch_loss
        else:
            moving_loss = .99 * moving_loss + .01 * batch_loss
    eva_accuracy = evaluate_accuracy(data_eva, net)
    train_accuracy = evaluate_accuracy(data_train, net)
    print("Epoch %s. Loss: %s, Train_acc %s, Eval_acc %s" % (e, moving_loss, train_accuracy, eva_accuracy))
    # Checkpoint whenever the validation accuracy improves.
    if eva_accuracy > best_eva:
        best_eva = eva_accuracy
        logging.info('Best validation acc found. Checkpointing...')
        net.save_params('vqa-mlp-%d.params'%(e))
