In [2]:
import torch.utils.data as data
from PIL import Image
import torch
import torch.utils.data
import torch.nn as nn
from torch.nn import Module
import torch.nn.functional as F
from torch.autograd import Variable
import torch.optim as optim
import numpy as np
import h5py
import json
import time
import pdb
import random

  from ._conv import register_converters as _register_converters


In [3]:
class train(data.Dataset): # torch wrapper
    """VisDial dialog dataset backed by pre-processed HDF5 / JSON files.

    Each item is one image together with ``self.rnd`` (10) rounds of
    question/answer text, the caption-plus-previous-QA history, and
    ``negative_sample`` randomly drawn wrong answer options per round.

    Args:
        input_img_h5: HDF5 file holding ``images_<split>`` image features.
        input_ques_h5: HDF5 file holding question/answer/caption/option arrays.
        input_json: JSON file with ``itow`` and ``img_<split>`` entries.
        negative_sample: number of wrong options sampled per round.
        num_val: how many trailing examples of the file's train split are
            held out as the 'val' split.
        data_split: 'train', 'val' (both read the file's train split) or
            'test' (reads the file's val split).
    """

    def __init__(self, input_img_h5, input_ques_h5, input_json, negative_sample, num_val, data_split):

        print('DataLoader loading: %s' %data_split)
        print('Loading image feature from %s' %input_img_h5)

        # 'train' and 'val' are both carved out of the file's 'train' split.
        split = 'val' if data_split == 'test' else 'train'

        with open(input_json, 'r') as fp:
            params = json.load(fp)
        self.itow = params['itow']
        self.img_info = params['img_'+split]

        # get the [s, e) range for this data split.
        total_num = len(self.img_info)
        if data_split == 'train':
            s, e = 0, total_num - num_val
        elif data_split == 'val':
            s, e = total_num - num_val, total_num
        else:
            s, e = 0, total_num

        self.img_info = self.img_info[s:e]

        print('%s number of data: %d' %(data_split, e-s))
        # load the data; context managers close the HDF5 files even on error.
        with h5py.File(input_img_h5, 'r') as f:
            self.imgs = f['images_'+split][s:e]

        print('Loading txt from %s' %input_ques_h5)
        with h5py.File(input_ques_h5, 'r') as f:
            self.ques = f['ques_'+split][s:e]
            self.ans = f['ans_'+split][s:e]
            self.cap = f['cap_'+split][s:e]

            self.ques_len = f['ques_len_'+split][s:e]
            self.ans_len = f['ans_len_'+split][s:e]
            self.cap_len = f['cap_len_'+split][s:e]

            self.ans_ids = f['ans_index_'+split][s:e]
            self.opt_ids = f['opt_'+split][s:e]
            # option text is indexed by global option id, so load all of it.
            self.opt_list = f['opt_list_'+split][:]
            self.opt_len = f['opt_len_'+split][:]

        self.ques_length = self.ques.shape[2]
        self.ans_length = self.ans.shape[2]
        self.his_length = self.ques_length + self.ans_length
        # +1: index len(itow)+1 (== vocab_size) is used as the start/end token below.
        self.vocab_size = len(self.itow)+1

        print('Vocab Size: %d' % self.vocab_size)
        self.split = split
        self.rnd = 10
        self.negative_sample = negative_sample


    def __getitem__(self, index):
        """Build all per-round arrays for example ``index``.

        Returns an 11-tuple of tensors:
        (img, his, ques, ans, ans_target, ans_len, ans_idx, ques_ori,
         opt_ans, opt_ans_len, opt_ans_idx).
        Question/history rows are right-aligned (left zero-padded); answers are
        left-aligned, with ``vocab_size`` as the start token of ``ans`` and the
        end token of ``ans_target`` / ``opt_ans``.
        """
        # get the image
        img = torch.from_numpy(self.imgs[index])

        # get the history: row 0 is the caption, row i+1 is QA pair i.
        his = np.zeros((self.rnd, self.his_length))
        his[0,self.his_length-self.cap_len[index]:] = self.cap[index,:self.cap_len[index]]

        ques = np.zeros((self.rnd, self.ques_length))
        ans = np.zeros((self.rnd, self.ans_length+1))
        ans_target = np.zeros((self.rnd, self.ans_length+1))
        ques_ori = np.zeros((self.rnd, self.ques_length))

        opt_ans = np.zeros((self.rnd, self.negative_sample, self.ans_length+1))
        ans_len = np.zeros((self.rnd))
        opt_ans_len = np.zeros((self.rnd, self.negative_sample))

        ans_idx = np.zeros((self.rnd))
        opt_ans_idx = np.zeros((self.rnd, self.negative_sample))

        for i in range(self.rnd):
            # get the index
            q_len = self.ques_len[index, i]
            a_len = self.ans_len[index, i]
            qa_len = q_len + a_len

            if i+1 < self.rnd:
                his[i+1, self.his_length-qa_len:self.his_length-a_len] = self.ques[index, i, :q_len]
                his[i+1, self.his_length-a_len:] = self.ans[index, i, :a_len]

            ques[i, self.ques_length-q_len:] = self.ques[index, i, :q_len]

            ques_ori[i, :q_len] = self.ques[index, i, :q_len]
            ans[i, 1:a_len+1] = self.ans[index, i, :a_len]
            ans[i, 0] = self.vocab_size

            ans_target[i, :a_len] = self.ans[index, i, :a_len]
            ans_target[i, a_len] = self.vocab_size
            ans_len[i] = self.ans_len[index, i]

            # global option ids for this round; ans_ids holds the 0-based
            # *position* of the ground truth within them.
            opt_ids = self.opt_ids[index, i]
            gt_pos = int(self.ans_ids[index, i])
            ans_idx[i] = opt_ids[gt_pos]
            # exclude the gt by position (not by its global id, which is
            # generally out of range for this row) before sampling negatives.
            opt_ids = np.delete(opt_ids, gt_pos, 0)
            random.shuffle(opt_ids)
            for j in range(self.negative_sample):
                ids = opt_ids[j]
                opt_ans_idx[i,j] = ids

                opt_len = self.opt_len[ids]

                opt_ans_len[i, j] = opt_len
                opt_ans[i, j, :opt_len] = self.opt_list[ids,:opt_len]
                opt_ans[i, j, opt_len] = self.vocab_size

        his = torch.from_numpy(his)
        ques = torch.from_numpy(ques)
        ans = torch.from_numpy(ans)
        ans_target = torch.from_numpy(ans_target)
        ques_ori = torch.from_numpy(ques_ori)
        ans_len = torch.from_numpy(ans_len)
        opt_ans_len = torch.from_numpy(opt_ans_len)
        opt_ans = torch.from_numpy(opt_ans)
        ans_idx = torch.from_numpy(ans_idx)
        opt_ans_idx = torch.from_numpy(opt_ans_idx)
        return img, his, ques, ans, ans_target, ans_len, ans_idx, ques_ori, \
                opt_ans, opt_ans_len, opt_ans_idx

    def __len__(self):
        return self.ques.shape[0]

In [4]:
# Pre-processed inputs for the training split.
# NOTE(review): the training function defined further down is also named
# `train` and shadows the dataset class; re-running this cell after that
# definition will call the wrong object.
input_img_h5 = 'vdl_img_vgg_demo.h5'
input_ques_h5 = 'visdial_data_demo.h5'
input_json = 'visdial_params.json'

negative_sample = 20   # wrong options sampled per round
num_val = 1000         # trailing examples of the file's train split held out as 'val'

dataset = train(
    input_img_h5=input_img_h5,
    input_ques_h5=input_ques_h5,
    input_json=input_json,
    negative_sample=negative_sample,
    num_val=num_val,
    data_split='train',
)

DataLoader loading: train
Loading image feature from vdl_img_vgg_demo.h5
train number of data: 81783
Loading txt from visdial_data_demo.h5
Vocab Size: 8964


In [5]:
# Shuffled mini-batches of 100 examples, loaded in the main process.
batchSize = 100
num_workers = 0
dloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batchSize,
    shuffle=True,
    num_workers=int(num_workers),
)

In [6]:
# Pull one batch to inspect shapes (the iterator name previously mismatched,
# raising NameError). NOTE(review): this rebinds `data`, which the imports cell
# aliases to torch.utils.data for the dataset class; re-run the imports cell
# before re-running the dataset-class cell.
data_iter = iter(dloader)
data = next(data_iter)

In [None]:
# Field order matches the tuple returned by the dataset's __getitem__.
(image, history, question, answer, answerT, answerLen, answerIdx, questionL,
 opt_answerT, opt_answerLen, opt_answerIdx) = data

In [7]:
class _netE(nn.Module):
    """Container module with an encoder, a recurrent module, and a decoder."""

    def __init__(self, rnn_type, ninp, nhid, nlayers, dropout, img_feat_size):
        super(_netE, self).__init__()

        self.d = dropout
        self.rnn_type = rnn_type
        self.nhid = nhid
        self.nlayers = nlayers
        self.nhid = nhid
        self.ninp = ninp
        self.img_embed = nn.Linear(img_feat_size, nhid)

        self.ques_rnn = getattr(nn, rnn_type)(ninp, nhid, nlayers, dropout=dropout)
        self.his_rnn = getattr(nn, rnn_type)(ninp, nhid, nlayers, dropout=dropout)

        self.Wq_1 = nn.Linear(self.nhid, self.nhid)
        self.Wh_1 = nn.Linear(self.nhid, self.nhid)
        self.Wa_1 = nn.Linear(self.nhid, 1)

        self.Wq_2 = nn.Linear(self.nhid, self.nhid)
        self.Wh_2 = nn.Linear(self.nhid, self.nhid)
        self.Wi_2 = nn.Linear(self.nhid, self.nhid)
        self.Wa_2 = nn.Linear(self.nhid, 1)

        self.fc1 = nn.Linear(self.nhid*3, self.ninp)

    def forward(self, ques_emb, his_emb, img_raw, ques_hidden, his_hidden, rnd):

        img_emb = F.tanh(self.img_embed(img_raw))

        ques_feat, ques_hidden = self.ques_rnn(ques_emb, ques_hidden)
        ques_feat = ques_feat[-1]

        his_feat, his_hidden = self.his_rnn(his_emb, his_hidden)
        his_feat = his_feat[-1]

        ques_emb_1 = self.Wq_1(ques_feat).view(-1, 1, self.nhid)
        his_emb_1 = self.Wh_1(his_feat).view(-1, rnd, self.nhid)

        atten_emb_1 = F.tanh(his_emb_1 + ques_emb_1.expand_as(his_emb_1))
        his_atten_weight = F.softmax(self.Wa_1(F.dropout(atten_emb_1, self.d, training=self.training
                                                ).view(-1, self.nhid)).view(-1, rnd))

        his_attn_feat = torch.bmm(his_atten_weight.view(-1, 1, rnd),
                                        his_feat.view(-1, rnd, self.nhid))

        his_attn_feat = his_attn_feat.view(-1, self.nhid)
        ques_emb_2 = self.Wq_2(ques_feat).view(-1, 1, self.nhid)
        his_emb_2 = self.Wh_2(his_attn_feat).view(-1, 1, self.nhid)
        img_emb_2 = self.Wi_2(img_emb).view(-1, 49, self.nhid)

        atten_emb_2 = F.tanh(img_emb_2 + ques_emb_2.expand_as(img_emb_2) + \
                                    his_emb_2.expand_as(img_emb_2))

        img_atten_weight = F.softmax(self.Wa_2(F.dropout(atten_emb_2, self.d, training=self.training
                                                ).view(-1, self.nhid)).view(-1, 49))

        img_attn_feat = torch.bmm(img_atten_weight.view(-1, 1, 49),
                                        img_emb.view(-1, 49, self.nhid))

        concat_feat = torch.cat((ques_feat, his_attn_feat.view(-1, self.nhid), \
                                 img_attn_feat.view(-1, self.nhid)),1)

        encoder_feat = F.tanh(self.fc1(F.dropout(concat_feat, self.d, training=self.training)))

        return encoder_feat, ques_hidden

    def init_hidden(self, bsz):
        weight = next(self.parameters()).data
        if self.rnn_type == 'LSTM':
            return (Variable(weight.new(self.nlayers, bsz, self.nhid).zero_()),
                    Variable(weight.new(self.nlayers, bsz, self.nhid).zero_()))
        else:
            return Variable(weight.new(self.nlayers, bsz, self.nhid).zero_())


In [8]:
class _netW(nn.Module):
    """Shared word embedding, usable with index or one-hot inputs.

    ``word_embed`` has ``ntoken + 1`` rows so that index ``ntoken`` is
    embeddable (this notebook passes ntoken=vocab_size, and the dataset uses
    vocab_size as its start/end token). ``Linear`` is a ``share_Linear`` tied
    to the same weight, used for one-hot (or soft) token distributions.
    """

    def __init__(self, ntoken, ninp, dropout):
        super(_netW, self).__init__()
        # NOTE: build this module before calling .cuda(); share_Linear keeps a
        # transposed view of the embedding weight created here.
        self.word_embed = nn.Embedding(ntoken+1, ninp)
        self.Linear = share_Linear(self.word_embed.weight)
        self.init_weights()
        self.d = dropout

    def init_weights(self):
        """Initialise embeddings uniformly in [-0.1, 0.1]."""
        initrange = 0.1
        self.word_embed.weight.data.uniform_(-initrange, initrange)

    def forward(self, input, format ='index'):
        """Embed ``input`` and apply dropout.

        Args:
            input: token indices (format='index') or one-hot / distribution
                rows over ntoken+1 tokens (format='onehot').
            format: 'index' or 'onehot'.

        Raises:
            ValueError: for any other ``format`` (previously an UnboundLocalError).
        """
        if format == 'onehot':
            out = F.dropout(self.Linear(input), self.d, training=self.training)
        elif format == 'index':
            out = F.dropout(self.word_embed(input), self.d, training=self.training)
        else:
            raise ValueError("format must be 'index' or 'onehot', got %r" % (format,))

        return out


In [9]:
class _netD(nn.Module):
    """
    Given the real/wrong/fake answer, use a RNN (LSTM) to embed the answer.
    """
    def __init__(self, rnn_type, ninp, nhid, nlayers, ntoken, dropout):
        super(_netD, self).__init__()

        self.rnn_type = rnn_type
        self.nhid = nhid
        self.nlayers = nlayers
        self.ntoken = ntoken
        self.ninp = ninp
        self.d = dropout

        self.rnn = getattr(nn, rnn_type)(ninp, nhid, nlayers)
        self.W1 = nn.Linear(self.nhid, self.nhid)
        self.W2 = nn.Linear(self.nhid, 1)
        self.fc = nn.Linear(nhid, ninp)

    def forward(self, input_feat, idx, hidden, vocab_size):

        output, _ = self.rnn(input_feat, hidden)
        mask = idx.data.eq(0)  # generate the mask
        mask[idx.data == vocab_size] = 1 # also set the last token to be 1
        if isinstance(input_feat, Variable):
            mask = Variable(mask, volatile=input_feat.volatile)

        # Doing self attention here.
        atten = self.W2(F.dropout(F.tanh(self.W1(output.view(-1, self.nhid))), self.d, training=self.training)).view(idx.size())
        atten.masked_fill_(mask, -99999)
        weight = F.softmax(atten.t()).view(-1,1,idx.size(0))
        feat = torch.bmm(weight, output.transpose(0,1)).view(-1,self.nhid)
        feat = F.dropout(feat, self.d, training=self.training)
        transform_output = F.tanh(self.fc(feat))

        return transform_output

    def init_hidden(self, bsz):
        weight = next(self.parameters()).data
        if self.rnn_type == 'LSTM':
            return (Variable(weight.new(self.nlayers, bsz, self.nhid).zero_()),
                    Variable(weight.new(self.nlayers, bsz, self.nhid).zero_()))
        else:
            return Variable(weight.new(self.nlayers, bsz, self.nhid).zero_())


In [10]:
class nPairLoss(nn.Module):
    """
    Given the right, fake, wrong, wrong_sampled embedding, use the N Pair Loss
    objective (which is an extension to the triplet loss)

    Loss = log(1 + sum exp(feat*wrong - feat*right)) + 0.1 * L2 norms,
    averaged over the batch. If ``fake`` is given, a hinge term
    relu(log(1 + exp(feat*fake - feat*right)) - log(margin)) over the
    positions selected by ``fake_diff_mask`` is added.

    Improved Deep Metric Learning with Multi-class N-pair Loss Objective (NIPS)
    """
    def __init__(self, ninp, margin):
        super(nPairLoss, self).__init__()
        self.ninp = ninp
        self.margin = np.log(margin)

    def forward(self, feat, right, wrong, batch_wrong, fake=None, fake_diff_mask=None):
        """Compute the loss.

        Args:
            feat: encoder features, (batch, ninp).
            right: ground-truth answer embeddings, (batch, ninp).
            wrong: per-example wrong answers, (batch, n_wrong, ninp).
            batch_wrong: in-batch sampled wrong answers, (batch, n_sample, ninp).
            fake: optional generated answers, (batch, ninp).
            fake_diff_mask: mask for the fake term; required when ``fake`` is given.

        Returns:
            loss, or ``(loss, fake_loss_per_example)`` when ``fake`` is given.
        """
        batch_size = feat.size(0)

        feat = feat.view(-1, self.ninp, 1)
        right_dis = torch.bmm(right.view(-1, 1, self.ninp), feat)
        wrong_dis = torch.bmm(wrong, feat)
        batch_wrong_dis = torch.bmm(batch_wrong, feat)

        wrong_score = torch.sum(torch.exp(wrong_dis - right_dis.expand_as(wrong_dis)), 1) \
                + torch.sum(torch.exp(batch_wrong_dis - right_dis.expand_as(batch_wrong_dis)), 1)

        loss_dis = torch.sum(torch.log(wrong_score + 1))
        loss_norm = right.norm() + feat.norm() + wrong.norm() + batch_wrong.norm()

        # `if fake:` on a multi-element tensor raises; test for presence instead.
        has_fake = fake is not None
        if has_fake:
            fake_dis = torch.bmm(fake.view(-1, 1, self.ninp), feat)
            fake_score = torch.masked_select(torch.exp(fake_dis - right_dis), fake_diff_mask)

            margin_score = F.relu(torch.log(fake_score + 1) - self.margin)
            loss_fake = torch.sum(margin_score)
            loss_dis = loss_dis + loss_fake
            loss_norm = loss_norm + fake.norm()

        loss = (loss_dis + 0.1 * loss_norm) / batch_size
        if has_fake:
            # .data.sum() gives a scalar on both old (1-elem) and new (0-dim) tensors.
            return loss, float(loss_fake.data.sum()) / batch_size
        else:
            return loss

In [11]:
class share_Linear(Module):
    r"""Bias-free linear map that reuses (ties) an existing weight, e.g. an embedding.

    Given ``weight`` of shape ``(in_features, out_features)`` — for an
    ``nn.Embedding`` that is ``(num_tokens, dim)`` — computes
    :math:`y = x \cdot W` so a one-hot (or soft) token row ``x`` maps to the
    same vector the embedding lookup would return.

    Args:
        weight: tensor/Parameter of shape ``(in_features, out_features)``.

    Shape:
        - Input: :math:`(N, in\_features)`
        - Output: :math:`(N, out\_features)`

    Attributes:
        weight: a transposed *view* of the given weight, shape
            ``(out_features, in_features)``. It is a plain attribute, not a
            registered parameter, so this module owns no parameters and the
            owner of the original weight is responsible for training it.
            Because the view is taken at construction, move the source weight
            to its device (e.g. ``.cuda()``) before building this module.
        bias: always ``None``.

    Examples::
        >>> emb = nn.Embedding(5, 3)
        >>> lin = share_Linear(emb.weight)
        >>> lin(torch.eye(5)).size()   # each row equals emb.weight[i]
        torch.Size([5, 3])
    """

    def __init__(self, weight):
        super(share_Linear, self).__init__()
        self.in_features = weight.size(0)
        self.out_features = weight.size(1)
        # F.linear multiplies by weight.T, so store the transpose to get x @ weight.
        self.weight = weight.t()
        self.register_parameter('bias', None)

    def forward(self, input):
        return F.linear(input, self.weight, self.bias)

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
            + str(self.in_features) + ' -> ' \
            + str(self.out_features) + ')'


In [12]:
# Sizes taken from the loaded dataset and reused by the model / training cells.
n_neg = negative_sample
vocab_size = dataset.vocab_size
itow = dataset.itow
ques_length = dataset.ques_length
his_length = dataset.his_length        # dataset sets this to ques_length + ans_length
ans_length = dataset.ans_length + 1    # one extra slot for the vocab_size start/end token
img_feat_size = 512                    # width of each image-region feature

In [13]:
# Model hyper-parameters.
model = 'LSTM'    # torch.nn RNN class used by encoder and answer embedder
nlayers = 1
ninp = 300        # word / joint embedding width
nhid = 512        # RNN hidden width
dropout = 0.5
margin = 2        # nPairLoss uses log(margin) as its hinge margin

# Construction order is kept as-is: it determines how the global RNG initialises weights.
netE = _netE(model, ninp, nhid, nlayers, dropout, img_feat_size)
netW = _netW(vocab_size, ninp, dropout)
netD = _netD(model, ninp, nhid, nlayers, vocab_size, dropout)
critD = nPairLoss(ninp, margin)

In [14]:
def sample_batch_neg(answerIdx, negAnswerIdx, sample_idx, num_sample):
    """Fill ``sample_idx`` with in-batch negative positions, in place.

    For every example ``row``, draws ``num_sample`` uniform random positions
    into the flattened ``negAnswerIdx`` (all examples' negatives pooled),
    rejecting any position whose answer id equals that example's ground truth.

    Args:
        answerIdx: ground-truth answer ids, one per example (batch_size).
        negAnswerIdx: negative answer ids, batch_size x negative_sample.
        sample_idx: output holder with a ``.data`` of batch_size x num_sample.
        num_sample: negatives to draw per example.
    """
    pool = negAnswerIdx.clone().view(-1)
    pool_size = negAnswerIdx.size(0) * negAnswerIdx.size(1)

    for row in range(answerIdx.size(0)):
        target = answerIdx[row]
        for col in range(num_sample):
            # rejection sampling: redraw until the pick is not the ground truth
            pick = int(random.random() * pool_size)
            while target == pool[pick]:
                pick = int(random.random() * pool_size)
            sample_idx.data[row, col] = pick


In [15]:
def repackage_hidden(h, batch_size):
    """Return fresh zero hidden state(s) with ``batch_size`` columns, detached from any graph.

    Args:
        h: a hidden-state tensor of shape (layers, batch, hidden), or a
           (possibly nested) tuple of them, e.g. an LSTM's (h, c).
        batch_size: size of the new batch dimension.

    Returns:
        Same structure as ``h`` with zero tensors of shape
        (h.size(0), batch_size, h.size(2)). ``h`` itself is not modified
        (the previous in-place ``resize_`` could clobber tensors still held
        by the autograd graph, and ``type(h) == Variable`` never matched on
        PyTorch >= 0.4, causing a crash).
    """
    if isinstance(h, Variable):
        fresh = h.data.new(h.size(0), batch_size, h.size(2)).zero_()
        return Variable(fresh)
    return tuple(repackage_hidden(v, batch_size) for v in h)


In [16]:
def train(epoch):
    """Run one training epoch over ``dloader`` and return the mean loss.

    Each dialog round of every batch is its own forward/backward/optimizer
    step. Models, loss, optimizer, pre-allocated input Variables and
    hyper-parameters are read from notebook globals (netW, netE, netD, critD,
    optimizer, dloader, img_input, ques_input, his_input, ans_input,
    ans_target, wrong_ans_input, batch_sample_idx, batchSize, img_feat_size,
    his_length, ans_length, ninp, vocab_size, neg_batch_sample, log_interval).

    Args:
        epoch: epoch number, used only for logging.

    Returns:
        float: mean per-round loss over the whole epoch (0.0 if no rounds ran).
        Previously this returned the un-averaged remainder of the last
        logging window.
    """
    netW.train()
    netE.train()
    netD.train()

    ques_hidden = netE.init_hidden(batchSize)
    hist_hidden = netE.init_hidden(batchSize)

    real_hidden = netD.init_hidden(batchSize)
    wrong_hidden = netD.init_hidden(batchSize)

    num_steps = len(dloader)
    window_loss, window_count = 0.0, 0   # reset every log_interval steps
    epoch_loss, epoch_count = 0.0, 0     # accumulated over the whole epoch

    for step, batch in enumerate(dloader, 1):
        image, history, question, answer, answerT, answerLen, answerIdx, questionL, \
                                    opt_answerT, opt_answerLen, opt_answerIdx = batch

        batch_size = question.size(0)
        image = image.view(-1, img_feat_size)
        img_input.data.resize_(image.size()).copy_(image)

        # question is (batch, rounds, ques_length): one optimisation step per round.
        for rnd in range(question.size(1)):
            netW.zero_grad()
            netE.zero_grad()
            netD.zero_grad()
            # get the corresponding round QA and history.
            ques = question[:,rnd,:].t()
            his = history[:,:rnd+1,:].clone().view(-1, his_length).t()

            ans = answer[:,rnd,:].t()
            tans = answerT[:,rnd,:].t()
            wrong_ans = opt_answerT[:,rnd,:].clone().view(-1, ans_length).t()

            ques_input.data.resize_(ques.size()).copy_(ques)
            his_input.data.resize_(his.size()).copy_(his)

            ans_input.data.resize_(ans.size()).copy_(ans)
            ans_target.data.resize_(tans.size()).copy_(tans)
            wrong_ans_input.data.resize_(wrong_ans.size()).copy_(wrong_ans)

            # sample in-batch negative index
            batch_sample_idx.data.resize_(batch_size, neg_batch_sample).zero_()
            sample_batch_neg(answerIdx[:,rnd], opt_answerIdx[:,rnd,:], batch_sample_idx, neg_batch_sample)

            ques_emb = netW(ques_input, format = 'index')
            his_emb = netW(his_input, format = 'index')

            ques_hidden = repackage_hidden(ques_hidden, batch_size)
            hist_hidden = repackage_hidden(hist_hidden, his_input.size(1))

            featD, ques_hidden = netE(ques_emb, his_emb, img_input,
                                      ques_hidden, hist_hidden, rnd+1)

            ans_real_emb = netW(ans_target, format='index')
            ans_wrong_emb = netW(wrong_ans_input, format='index')

            real_hidden = repackage_hidden(real_hidden, batch_size)
            wrong_hidden = repackage_hidden(wrong_hidden, ans_wrong_emb.size(1))

            real_feat = netD(ans_real_emb, ans_target, real_hidden, vocab_size)
            wrong_feat = netD(ans_wrong_emb, wrong_ans_input, wrong_hidden, vocab_size)

            batch_wrong_feat = wrong_feat.index_select(0, batch_sample_idx.view(-1))
            wrong_feat = wrong_feat.view(batch_size, -1, ninp)
            batch_wrong_feat = batch_wrong_feat.view(batch_size, -1, ninp)

            # named `loss` so it does not shadow the nPairLoss class.
            loss = critD(featD, real_feat, wrong_feat, batch_wrong_feat)

            # .data.sum() yields a scalar on both old (1-elem) and new (0-dim) tensors.
            step_loss = float(loss.data.sum())
            window_loss += step_loss
            epoch_loss += step_loss
            window_count += 1
            epoch_count += 1

            loss.backward()
            optimizer.step()

        if step % log_interval == 0:
            current_lr = optimizer.param_groups[0]['lr']
            print("step {} / {} (epoch {}), g_loss {:.3f}, lr = {:.6f}"\
                .format(step, num_steps, epoch, window_loss / max(window_count, 1), current_lr))
            window_loss, window_count = 0.0, 0

    return epoch_loss / epoch_count if epoch_count else 0.0


In [17]:
# Pre-allocated input buffers; the training loop resizes and fills them per round.

# image / question / history
img_input = torch.FloatTensor(batchSize)
ques_input = torch.LongTensor(ques_length, batchSize)
his_input = torch.LongTensor(his_length, batchSize)

# answer input
ans_input = torch.LongTensor(ans_length, batchSize)
ans_target = torch.LongTensor(ans_length, batchSize)
wrong_ans_input = torch.LongTensor(ans_length, batchSize)
sample_ans_input = torch.LongTensor(1, batchSize)
opt_ans_input = torch.LongTensor(ans_length, batchSize)

# sampling / bookkeeping
batch_sample_idx = torch.LongTensor(batchSize)
fake_diff_mask = torch.ByteTensor(batchSize)
fake_len = torch.LongTensor(batchSize)
noise_input = torch.FloatTensor(batchSize)
gt_index = torch.LongTensor(batchSize)

In [18]:
# Wrap the buffers in autograd Variables (a no-op wrapper on PyTorch >= 0.4).
img_input = Variable(img_input)
ques_input = Variable(ques_input)
his_input = Variable(his_input)

ans_input = Variable(ans_input)
ans_target = Variable(ans_target)
wrong_ans_input = Variable(wrong_ans_input)
sample_ans_input = Variable(sample_ans_input)
opt_ans_input = Variable(opt_ans_input)

batch_sample_idx = Variable(batch_sample_idx)
fake_diff_mask = Variable(fake_diff_mask)
noise_input = Variable(noise_input)
gt_index = Variable(gt_index)

In [19]:
lr = 0.0004
beta1 = 0.8
niter = 50
neg_batch_sample = 30
log_interval = 50
save_iter = 10000000
# NOTE(review): save_path was referenced below but never defined; point it at a checkpoint dir.
save_path = '.'

optimizer = optim.Adam([{'params': netW.parameters()},
                        {'params': netE.parameters()},
                        {'params': netD.parameters()}], lr=lr, betas=(beta1, 0.999))

# Hyper-parameters stored with each checkpoint (replaces the undefined `opt`).
hparams = {'model': model, 'ninp': ninp, 'nhid': nhid, 'nlayers': nlayers,
           'dropout': dropout, 'margin': margin, 'lr': lr, 'beta1': beta1,
           'negative_sample': negative_sample, 'neg_batch_sample': neg_batch_sample,
           'batchSize': batchSize}

# Named log_history so it does not clobber the `history` batch tensor from earlier cells.
log_history = []

for epoch in range(1, niter):

    t = time.time()
    # train() returns only the mean epoch loss; unpacking two values raised a TypeError.
    train_loss = train(epoch)
    print ('Epoch: %d learningRate %4f train loss %4f Time: %3f' % (epoch, lr, train_loss, time.time()-t))
    train_his = {'loss': train_loss}

    print('Evaluating ... ')
    # NOTE(review): val() and dataloader_val are not defined anywhere in this
    # notebook; define them (e.g. a loader over data_split='val') before running.
    ranks = np.asarray(val(), dtype='float')
    n_ranks = float(len(ranks))
    R1 = np.sum(ranks == 1) / n_ranks
    R5 = np.sum(ranks <= 5) / n_ranks
    R10 = np.sum(ranks <= 10) / n_ranks
    ave = np.sum(ranks) / n_ranks
    mrr = np.sum(1 / ranks) / n_ranks
    print ('%d/%d: mrr: %f R1: %f R5 %f R10 %f Mean %f' %(epoch, len(dataloader_val), mrr, R1, R5, R10, ave))
    val_his = {'R1': R1, 'R5':R5, 'R10': R10, 'Mean':ave, 'mrr':mrr}
    log_history.append({'epoch':epoch, 'train': train_his, 'val': val_his})

    # saving the model.
    if epoch % save_iter == 0:
        torch.save({'epoch': epoch,
                    'opt': hparams,
                    'netW': netW.state_dict(),
                    'netD': netD.state_dict(),
                    'netE': netE.state_dict()},
                    '%s/epoch_%d.pth' % (save_path, epoch))

        with open('%s/log.json' %(save_path), 'w') as fp:
            json.dump(log_history, fp)




KeyboardInterrupt: 