In [83]:
import mindspore
from mindspore import Tensor
import mindspore.ops as ops
import mindspore.nn as nn
from mindspore.train.model import Model
from mindspore.nn.metrics import Accuracy
from mindspore.train.callback import LossMonitor, TimeMonitor
import mindspore.dataset as ds
from mindspore import context

import numpy as np
import h5py
import json
import re
import moxing as mox
import pickle

In [84]:
# Run MindSpore in graph (static compilation) mode on Ascend device 0.
context.set_context(mode=context.GRAPH_MODE, device_target='Ascend', device_id=0)

In [85]:
# Stage the project inputs (annotations, questions, vocab, image features and the
# preprocessed dataset pickle) from the S3/OBS bucket into the working directory.
for _src_url, _dst_url in (
    ("s3://nlp-xyz/word-vec/proj/nlp2021/annotations", './annotations'),
    ("s3://nlp-xyz/word-vec/proj/nlp2021/questions", './questions'),
    ("s3://nlp-xyz/word-vec/proj/nlp2021/vocab", './vocab'),
    ("s3://nlp-xyz/word-vec/proj/image_features", './image_features'),
    ("s3://nlp-xyz/word-vec/proj/dataset.pkl", './dataset.pkl'),
):
    mox.file.copy_parallel(src_url=_src_url, dst_url=_dst_url)

In [86]:
class text_encoder(nn.Cell):
    """Single-head self-attention encoder for token-id sequences.

    Tokens are embedded, a fixed sinusoidal positional table is added, and one
    scaled dot-product self-attention layer is applied.

    Args:
        vocab_size: number of rows in the embedding table.
        embedding_size: width of the embeddings and of the Q/K/V projections.
        max_len: number of positions in the precomputed positional table; input
            sequences must not be longer than this.
    """

    def __init__(self, vocab_size, embedding_size=2048, max_len=100):
        super(text_encoder, self).__init__()

        self.vocab_size = vocab_size
        self.embedding_size = embedding_size
        self.max_len = max_len

        self.embed = nn.Embedding(vocab_size, embedding_size)
        self.WQ = nn.Dense(embedding_size, embedding_size)
        self.WK = nn.Dense(embedding_size, embedding_size)
        self.WV = nn.Dense(embedding_size, embedding_size)

        self.pe = self.pe_gen()
        # Scaled dot-product attention: dividing Q.K^T by sqrt(d) keeps the logit
        # scale independent of the width, so a 2048-wide encoder does not saturate
        # the softmax (and starve it of gradient).
        self.scale = 1.0 / float(np.sqrt(embedding_size))

        self.matmul = ops.BatchMatMul()
        self.transpose = ops.Transpose()
        self.softmax = nn.Softmax()
        self.reducemax = ops.ReduceMax(keep_dims=True)
        self.print = ops.Print()

    def pe_gen(self):
        """Return the (max_len, embedding_size) float32 sinusoidal positional table.

        Column j holds sin(pos / 10000**(j/d)) for even j and
        cos(pos / 10000**((j-1)/d)) for odd j, where d is embedding_size.
        """
        positions = np.arange(self.max_len)[:, None]  # (max_len, 1), broadcast across columns
        even = 10000**(np.arange(self.embedding_size, step=2)/self.embedding_size)
        odd = 10000**(np.arange(self.embedding_size-1, step=2)/self.embedding_size)
        pe = np.empty((self.max_len, self.embedding_size))
        pe[:, ::2] = np.sin(positions / even)
        pe[:, 1::2] = np.cos(positions / odd)
        return Tensor(pe.astype('float32'))

    # construct(x, squeeze=False)
    #   x: integer token ids, batch * len (len must not exceed max_len).
    #   Returns batch * len * emb, or batch * 1 * emb (max over positions) when
    #   squeeze is True.
    def construct(self, x, squeeze=False):
        x = self.embed(x)  # batch * len * emb
        x = x + self.pe[:x.shape[1]]  # add positional rows 0..len-1

        Q = self.WQ(x)  # batch * len * emb
        K = self.WK(x)  # batch * len * emb
        V = self.WV(x)  # batch * len * emb

        QK = self.matmul(Q, self.transpose(K, (0, 2, 1))) * self.scale  # batch * len * len
        Z = self.matmul(self.softmax(QK), V)  # batch * len * emb

        if squeeze:
            Z = self.reducemax(Z, 1)  # batch * 1 * emb
        return Z

In [87]:
class block(nn.Cell):
    """One stacked-attention hop over image regions, guided by a query vector.

    Args:
        hidden_size: width shared by the image features and the query.
    """

    def __init__(self, hidden_size=2048):
        super(block, self).__init__()
        self.Wq = nn.Dense(hidden_size, hidden_size)
        self.Wi = nn.Dense(hidden_size, hidden_size, has_bias=False)
        self.Wp = nn.Dense(hidden_size, 1)
        self.transpose = ops.Transpose()
        self.tanh = nn.Tanh()
        # Attention weights pI are batch * regions * 1, so normalise over the region
        # axis (1). Axis 0 normalised across the *batch*: each sample's weights then
        # depended on the other samples and did not sum to 1 over its regions.
        self.softmax = nn.Softmax(1)
        self.reduce_sum = ops.ReduceSum()
        self.matmul = ops.BatchMatMul()

    # construct(v_i, v_q)
    #   v_i: image features, batch * hid * regions (transposed below).
    #   v_q: query vector, batch * 1 * hid.
    #   Returns the refined query u = pooled vector + v_q, batch * 1 * hid.
    def construct(self, v_i, v_q):
        encoded_q = self.Wq(v_q)  # batch * 1 * hid
        encoded_i = self.Wi(self.transpose(v_i, (0, 2, 1)))  # batch * regions * hid
        hA = self.tanh(encoded_q + encoded_i)  # batch * regions * hid (query broadcast over regions)
        pI = self.softmax(self.Wp(hA))  # batch * regions * 1, sums to 1 over regions
        # NOTE(review): the SAN paper pools the image features (sum_i p_i * v_i);
        # this pools hA instead -- confirm that is intended.
        vI = self.matmul(self.transpose(pI, (0, 2, 1)), hA)  # batch * 1 * hid
        u = vI + v_q
        return u

class SAN(nn.Cell):
    """Stacked attention network: four `block` hops applied in sequence.

    Each hop attends over the same image features and refines the query
    produced by the previous hop.

    Args:
        hidden_size: width shared by image features and the query.
    """

    def __init__(self, hidden_size=2048):
        super(SAN, self).__init__()

        self.block0 = block(hidden_size)
        self.block1 = block(hidden_size)
        self.block2 = block(hidden_size)
        self.block3 = block(hidden_size)

    # construct(v_i, v_q)
    #   v_i: image features, batch * hid * regions (each block transposes it).
    #   v_q: initial query, batch * 1 * hid.
    #   Returns the query after four hops, batch * 1 * hid.
    def construct(self, v_i, v_q):
        u = self.block0(v_i, v_q)
        u = self.block1(v_i, u)
        u = self.block2(v_i, u)
        u = self.block3(v_i, u)
        return u

In [88]:
class VQA(nn.Cell):
    """Scores candidate answers for an image / question pair.

    The question is encoded into one vector and refined against the image by a
    SAN. Each candidate answer is encoded separately, and its score is the dot
    product with the refined query.

    Args:
        question_vocab_size: embedding rows for question tokens.
        answer_vocab_size: embedding rows for answer ids.
        hidden_size: shared width of the encoders and the SAN.
        question_len: number of question ids at the start of the packed text row;
            the remaining columns of that row are candidate answer ids.
    """

    def __init__(self, question_vocab_size, answer_vocab_size, hidden_size=2048, question_len=20):
        super(VQA, self).__init__()
        self.question_len = question_len
        self.question_encoder = text_encoder(question_vocab_size, hidden_size)
        self.answer_encoder = text_encoder(answer_vocab_size, hidden_size)
        self.SAN = SAN(hidden_size)
        self.matmul = ops.BatchMatMul()
        self.transpose = ops.Transpose()
        # argmax / softmax / print are not used by construct (it returns raw scores);
        # they are kept so existing references to these attributes keep working.
        self.argmax = ops.Argmax()
        self.softmax = nn.Softmax()
        self.cast = ops.Cast()
        self.print = ops.Print()

    # construct(x)
    #   x: float tensor batch * (rows + 1) * width. Rows [:-1] are image features
    #   (hid * regions per sample). The last row packs question ids in columns
    #   [:question_len], followed by the candidate answer ids.
    #   Returns unnormalised scores, batch * 1 * n_candidates.
    def construct(self, x):
        batch = x.shape[0]
        v_i = x[:, :-1]  # image features: every row except the packed text row
        # The token row is the LAST row (-1:). The previous -2:-1 slice read the
        # final image-feature row and cast those floats to token ids.
        q = self.cast(x[:, -1:, :self.question_len], mindspore.int32).view(batch, -1)
        ans = self.cast(x[:, -1:, self.question_len:], mindspore.int32).view(batch, -1)
        v_q = self.question_encoder(q, True)  # batch * 1 * hid
        v_ans = self.answer_encoder(ans)  # batch * n * hid
        uk = self.SAN(v_i, v_q)  # batch * 1 * hid
        score = self.matmul(uk, self.transpose(v_ans, (0, 2, 1)))  # batch * 1 * n
        return score

In [89]:
# vqa = VQA(20000, 20000)

In [90]:
# im = np.random.randn(64, 2048, 36)
# text = np.concatenate([np.arange(36)]*64).reshape(64,1,36)
# syn = Tensor(np.concatenate((im, text), 1).astype('float32'))

In [91]:
# vqa(syn)

In [92]:
# Dataset splits processed throughout the notebook, in this order.
categories = [
    'train',
    'val',
    'test',
]

In [93]:
# Load per-split questions, annotations and (when present) HDF5 image features,
# plus the shared vocabulary.
questions = {}
annotations = {}
features = {}

for c in categories:
    with open('./questions/{}.json'.format(c)) as quest_json:
        questions[c] = json.load(quest_json)['questions']

    with open('./annotations/{}.json'.format(c)) as anno_json:
        annotations[c] = json.load(anno_json)['annotations']

    # Feature files are treated as optional per split, but a missing or unreadable
    # file is reported instead of silently swallowed. The handle stays open on
    # purpose: later cells read from features[c].
    try:
        features[c] = h5py.File('./image_features/image-{}.h5'.format(c), 'r')
    except OSError as err:
        print('skipping image features for split {!r}: {}'.format(c, err))

with open('./vocab/test.json') as vocab_file:
    vocab = json.load(vocab_file)

In [94]:
# Text normalisation helpers for VQA questions and answers.
_special_chars = re.compile('[^a-z0-9 ]*')
# NOTE(review): `(?!<=\d)` is a negative *lookahead* for the literal text "<=" plus
# a digit, so it never blocks a match; a lookbehind `(?<!\d)` was probably meant.
# It is kept as-is so answers normalise exactly as before (presumably how the
# answer vocabulary was built) -- confirm before changing.
_period_strip = re.compile(r'(?!<=\d)(\.)(?!\d)')
_comma_strip = re.compile(r'(\d)(,)(\d)')
_punctuation_chars = re.escape(r';/[]"{}()=+\_-><@`,?!')
_punctuation = re.compile(r'([{}])'.format(re.escape(_punctuation_chars)))
_punctuation_with_a_space = re.compile(r'(?<= )([{0}])|([{0}])(?= )'.format(_punctuation_chars))

def question_tokenize(question):
    """Lower-case `question`, keep only a-z, 0-9 and spaces, and split into words.

    The trailing '?' needs no special handling because `_special_chars` removes
    it. The previous `[:-1]` slice also chopped the last letter off questions
    that did not end in '?'. `split()` ignores repeated or trailing spaces
    instead of producing empty tokens.
    """
    cleaned = _special_chars.sub('', question.lower())
    return cleaned.split()

def answer_tokenize(answer):
    """Normalise punctuation in an answer string; returned unchanged if it has none."""
    if _punctuation.search(answer) is None:
        return answer
    answer = _punctuation_with_a_space.sub('', answer)
    if re.search(_comma_strip, answer) is not None:
        answer = answer.replace(',', '')  # "1,000" -> "1000"
    answer = _punctuation.sub(' ', answer)
    answer = _period_strip.sub('', answer)
    return answer.strip()

def question2vec(question):
    """Map a question to a Tensor of ids from vocab['question'] (KeyError on unknown words)."""
    return Tensor([vocab['question'][word] for word in question_tokenize(question)])

def answer2vec(answer):
    """Map an answer to a 1-element Tensor holding its vocab['answer'] id."""
    return Tensor([vocab['answer'][answer_tokenize(answer)]])

In [95]:
# Preprocessed per-split samples (question ids, answer_label, image_id, ...).
# NOTE(review): pickle.load can execute arbitrary code -- only load dataset.pkl
# that came from a trusted source (here the project's own bucket).
with open('dataset.pkl', 'rb') as data_file:
    dataset = pickle.load(data_file)

In [96]:
# Size of the answer-label space: the largest answer_label seen in any split,
# plus one (assumes labels are 0-based ids).
nb_answers = 1 + max(
    max(data['answer_label'] for data in dataset[c])
    for c in categories
)

In [97]:
# Candidate answers per question: the correct one plus nb_choices - 1 distractors.
nb_choices: int = 16

In [98]:
# Model instance with embedding tables sized to the question / answer vocabularies.
# NOTE(review): the candidate ids fed to the answer encoder are dataset
# 'answer_label' values (< nb_answers); confirm nb_answers <= len(vocab['answer']).
vqa = VQA(
    question_vocab_size=len(vocab['question']),
    answer_vocab_size=len(vocab['answer']),
)

In [99]:
# im = np.random.randn(64, 2048, 36)
# text = np.concatenate([np.arange(36)]*64).reshape(64,1,36)
# syn = Tensor(np.concatenate((im, text), 1).astype('float32'))

In [100]:
# a = vqa(syn)

In [101]:
# For each split, map an image id to its row in that split's HDF5 'features' dataset.
# `[()]` reads the whole 'ids' dataset in one call instead of one HDF5 read per
# element while iterating. The yielded values (and therefore the keys) are the
# same as before.
id2idx = {
    c: {image_id: row for row, image_id in enumerate(features[c]['ids'][()])}
    for c in categories
}

In [102]:
# Persist the id -> row mapping locally (the file is closed even if dumping
# fails), then upload it next to the other project artefacts.
with open('id2idx.pkl', 'wb') as id2idx_file:
    pickle.dump(id2idx, id2idx_file)
mox.file.copy_parallel(src_url="./id2idx.pkl", dst_url='s3://nlp-xyz/word-vec/proj/id2idx.pkl')

In [103]:
# Build packed (image_text, label) training samples.
# image_text = the image's feature rows with one extra row appended: the question
# ids followed by nb_choices shuffled candidate answer ids (the correct answer
# plus distinct distractors). label = position of the correct answer, shape (1,).
n_train_samples = 64  # only the first 64 training questions (one batch) are used
rng = np.random.default_rng(0)  # seeded so the sampled distractors are reproducible
all_answer_ids = np.arange(nb_answers)

train_data_set = []
for i, data in enumerate(dataset['train']):
    if i >= n_train_samples:
        break
    correct = data['answer_label']
    # Draw distractors from every label except the correct one, so they are
    # distinct and never collide with it -- no rejection loop needed.
    distractors = rng.choice(np.delete(all_answer_ids, correct), nb_choices - 1, replace=False)
    answers = np.append(distractors, correct).astype('int32')
    rng.shuffle(answers)
    label = np.flatnonzero(answers == correct).astype('int32')  # shape (1,)
    question_answer = np.concatenate(
        (data['question'].astype('int32').reshape(1, -1), answers.reshape(1, -1)), 1)
    image = features['train']['features'][id2idx['train'][data['image_id']]].astype('float32')
    image_text = np.concatenate((image, question_answer.astype('float32')), 0)
    train_data_set.append((image_text, label))

In [104]:
class Generator:
    """Random-access source for ``ds.GeneratorDataset``.

    Wraps a list of records and exposes each record's first two fields as a
    ``(data, label)`` tuple; any further fields are ignored.
    """

    def __init__(self, input_list):
        self.input_list = input_list

    def __len__(self):
        return len(self.input_list)

    def __getitem__(self, item):
        record = self.input_list[item]
        data, label = record[0], record[1]
        return data, label

In [105]:
# Wrap the in-memory samples as a MindSpore dataset with columns (data, label),
# keep their order, and group them into batches of 64.
trainset = ds.GeneratorDataset(
    source=Generator(input_list=train_data_set),
    column_names=["data", "label"],
    shuffle=False,
).batch(batch_size=64)

In [110]:
# next(trainset.create_dict_iterator())

In [107]:
# Adam over the trainable weights only (the MindSpore-recommended input for an
# optimizer; get_parameters() would also hand it frozen parameters).
opt = nn.Adam(vqa.trainable_params(), learning_rate=0.001)
# sparse=True: labels are integer class indices, not one-hot vectors.
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
# NOTE(review): VQA returns scores shaped batch * 1 * n_choices; Accuracy expects
# (N, C) predictions, so the 'acc' metric likely needs flattened scores to be
# meaningful during evaluation.
model = Model(vqa, loss_fn=loss, optimizer=opt, metrics={'acc': Accuracy()})

In [108]:
# TimeMonitor's data_size is a count of steps (batches), not samples. Passing 64
# samples made per-step time 64x too small for this 1-batch dataset.
time_cb = TimeMonitor(data_size=trainset.get_dataset_size())
loss_cb = LossMonitor()

In [111]:
# Single training epoch over `trainset`, reporting step time and loss via callbacks.
model.train(
    epoch=1,
    train_dataset=trainset,
    callbacks=[time_cb, loss_cb],
)
print('train_success')