# nlp-project-vqa


In [1]:
import mindspore
import numpy as np
from easydict import EasyDict
from preprocess.preprocess import *

In [None]:
import moxing as mox
# Replace this with your own OBS path.
# Copies the project directory from OBS (src_url) into the current working directory (dst_url).
mox.file.copy_parallel(src_url="s3://nlp-haofeng/nlp_project_vqa", dst_url='./')

In [2]:
from mindspore import context
# Run MindSpore in graph (static) execution mode for the rest of the notebook.
context.set_context(mode=context.GRAPH_MODE)

## 1 预处理

### 1.1 预处理配置

In [3]:
padding = '<pad>'

# Preprocessing configuration: raw VQA annotation/question files, COCO image
# path prefixes, and where the vocabularies and MindRecord files are written.
config = EasyDict(
    # raw annotation / question JSON files per split
    train_ans_path='./data/annotations/train.json',
    train_que_path='./data/questions/train.json',
    valid_ans_path='./data/annotations/val.json',
    valid_que_path='./data/questions/val.json',
    test_ans_path='./data/annotations/test.json',
    test_que_path='./data/questions/test.json',
    # image file-name prefixes (the test split is taken from COCO val2014)
    train_img_path='./data/images/train/COCO_train2014_',
    test_img_path='./data/images/test/COCO_val2014_',
    val_img_path='./data/images/val/COCO_val2014_',
    # questions are padded to this many tokens
    max_length=25,
    # saved vocabularies (word -> index and index -> word)
    dict_path='./mindrecord/dict.npy',
    idx_word_dict_path='./mindrecord/idx_word_dict.npy',
    # MindRecord output
    num_splits=1,
    train_mindrecord_path='./mindrecord/train.mindrecord',
    valid_mindrecord_path='./mindrecord/valid.mindrecord',
    test_mindrecord_path='./mindrecord/test.mindrecord',
)

### 1.2 读取数据

注：只保留答案长度为 1（即答案为单个词）的 VQA 样本。

In [4]:
# For each split, get_list returns a 3-tuple that is unpacked into
# (images, questions, answers). Splits are loaded in train, valid, test order.
(
    (train_images, train_questions, train_answers),
    (valid_images, valid_questions, valid_answers),
    (test_images, test_questions, test_answers),
) = [
    get_list(que_path, ans_path)
    for que_path, ans_path in (
        (config.train_que_path, config.train_ans_path),
        (config.valid_que_path, config.valid_ans_path),
        (config.test_que_path, config.test_ans_path),
    )
]

In [5]:
# Pool all three splits so that the vocabulary built below covers every
# question and answer word (train, valid, test concatenated in that order).
total_questions = train_questions + valid_questions + test_questions
total_answers = train_answers + valid_answers + test_answers

### 1.3 构建词典

In [6]:
# Build the word -> index vocabulary. Index 0 is reserved for '<pad>'; question
# words are added first, then answer words (the inner call runs first).
word_dict = add_word_into_dict(
    total_answers,
    add_word_into_dict(total_questions, {'<pad>': 0}),
)

In [7]:
# Build the reverse vocabulary (index -> word), presumably for decoding
# predicted answer ids back into words.
# A comprehension avoids the manual loop and does not leak a loop variable
# into the notebook's global namespace.
idx_word_dict = {idx: word for word, idx in word_dict.items()}

In [8]:
import os

# Save both vocabularies. np.save stores the dicts as pickled object arrays,
# which is why they must be reloaded with np.load(..., allow_pickle=True).
# np.save does not create missing directories, so make sure the target
# folders (e.g. ./mindrecord) exist first.
os.makedirs(os.path.dirname(config.dict_path) or '.', exist_ok=True)
os.makedirs(os.path.dirname(config.idx_word_dict_path) or '.', exist_ok=True)
np.save(config.dict_path, word_dict)
np.save(config.idx_word_dict_path, idx_word_dict)

### 1.4 向量化 & 补齐长度

In [9]:
# Map words to indices and pad: questions to config.max_length tokens,
# answers to a single token. Calls run in train, valid, test order.
train_questions_vec, valid_questions_vec, test_questions_vec = (
    get_vec_and_pad(split_questions, word_dict, config.max_length)
    for split_questions in (train_questions, valid_questions, test_questions)
)
train_answers_vec, valid_answers_vec, test_answers_vec = (
    get_vec_and_pad(split_answers, word_dict, 1)
    for split_answers in (train_answers, valid_answers, test_answers)
)

# Images are passed through unchanged (whatever get_list returned).
# Disabled alternative that pre-loads the images and caches them on disk:
#   train_images_list = read_image(train_images, config.train_img_path)
#   np.save('./mindrecord/train_images_list', train_images_list)
#   valid_images_list = read_image(valid_images, config.val_img_path)
#   np.save('./mindrecord/valid_images_list', valid_images_list)
#   test_images_list = read_image(test_images, config.test_img_path)
#   np.save('./mindrecord/test_images_list', test_images_list)
train_images_list, valid_images_list, test_images_list = (
    train_images, valid_images, test_images)

### 1.5 生成MindRecord

In [10]:
# Serialize each split (images, question vectors, answer vectors) to MindRecord.
# config.num_splits is forwarded to generate_mindrecord (presumably the number
# of shard files per split; confirm in preprocess). The bare "train" / "valid" /
# "test" lines after this cell appear to be its printed output.
generate_mindrecord(config.train_mindrecord_path, config.num_splits, train_images_list, train_questions_vec, train_answers_vec)
generate_mindrecord(config.valid_mindrecord_path, config.num_splits, valid_images_list, valid_questions_vec, valid_answers_vec)
generate_mindrecord(config.test_mindrecord_path,  config.num_splits, test_images_list,  test_questions_vec, test_answers_vec)

train
valid
test


## 2 加载数据

### 2.1 加载词典

In [11]:
# Reload both vocabularies written in step 1.3. allow_pickle=True is needed
# because they were saved as Python dicts; only load files this notebook produced.
word_dict, idx_word_dict = (
    np.load(dict_file, allow_pickle=True).item()
    for dict_file in (config.dict_path, config.idx_word_dict_path)
)

### 2.2 训练配置

In [12]:
# Training hyper-parameters.
# vocab_size and max_length are derived from the preprocessing results instead of
# being hard-coded, so the embedding table and the output layer always match
# the vocabulary and the padding length actually used for the data.
train_config = EasyDict({
    'model': 'baseline',
    # word_dict maps word -> index starting at 0 ('<pad>'); the embedding table
    # and the classifier need (largest index + 1) rows / classes.
    'vocab_size': max(word_dict.values()) + 1,
    'batch_size': 128,
    # Single entry: 'epoch_size' used to be listed twice (10, then 20) and the
    # later value silently won; 20 is kept as the effective setting.
    'epoch_size': 20,
    # must equal the padding length used when vectorizing the questions
    'max_length': config.max_length,
    'hidden_size': 1024,
    'lr': 1e-3,
    # NOTE(review): appears unused by the Adam optimizer used for training.
    'momentum': 0.9,
    # NOTE(review): patience (epochs without improvement) larger than
    # epoch_size, so early stopping can never trigger with these values.
    'early_stop': 100,
    'ckpt_save_path': './ckpt',
    'checkpoint_path': './ckpt/baseline',
})

### 2.3 生成数据集

In [13]:
# Build batched datasets from the MindRecord files. Train and valid pass
# train_config.epoch_size as the third argument of generate_dataset
# (presumably a repeat count; confirm in preprocess), while test passes 1.
train_dataset, valid_dataset, test_dataset = (
    generate_dataset(record_path, train_config.batch_size, repeat)
    for record_path, repeat in (
        (config.train_mindrecord_path, train_config.epoch_size),
        (config.valid_mindrecord_path, train_config.epoch_size),
        (config.test_mindrecord_path, 1),
    )
)

## 3 训练模型

### 3.1 创建模型

In [14]:
import mindspore.nn as nn
import mindspore.ops.operations as P
import mindspore.ops.functional as F
from mindspore import dtype as mstype
from mindspore import Tensor

In [15]:
class Network(nn.Cell):
    """Question-only baseline: embed the tokens, flatten them, and score every vocabulary word.

    Args:
        vocab_size: rows of the embedding table and number of output classes.
            Defaults to ``train_config.vocab_size``.
        hidden_size: embedding width. Defaults to ``train_config.hidden_size``.
        max_length: tokens per question. The flattened embedding has
            ``hidden_size * max_length`` features, so every input row must have
            exactly this many tokens. Defaults to ``train_config.max_length``.
    """

    def __init__(self, vocab_size=None, hidden_size=None, max_length=None):
        super(Network, self).__init__()
        # Fall back to the global training config so ``Network()`` keeps working.
        if vocab_size is None:
            vocab_size = train_config.vocab_size
        if hidden_size is None:
            hidden_size = train_config.hidden_size
        if max_length is None:
            max_length = train_config.max_length
        self.reshape = P.Reshape()
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.out = nn.Dense(hidden_size * max_length, vocab_size)

    def construct(self, x):
        # x: token ids, assumed (batch, max_length) so the flattened width
        # matches the Dense layer's input size.
        x = self.embedding(x)
        # Flatten (batch, max_length, hidden) -> (batch, max_length * hidden)
        # with the Reshape op created in __init__.
        x = self.reshape(x, (x.shape[0], -1))
        # logits over the whole vocabulary: (batch, vocab_size)
        return self.out(x)

In [16]:
class NetworkWithLoss(nn.Cell):
    """Training wrapper: run the backbone on the questions and return the cross-entropy loss.

    Args:
        network: backbone cell mapping question token ids to vocabulary logits.

    ``construct(images, questions, answers)`` takes the three dataset columns;
    ``images`` is unused by this question-only baseline. The loss is
    ``SoftmaxCrossEntropyWithLogits(sparse=True)`` with its default reduction
    (per-sample in MindSpore), cast to float32.
    """

    def __init__(self, network):
        super(NetworkWithLoss, self).__init__()
        self.network = network
        self.loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
        self.cast = P.Cast()
        self.reshape = P.Reshape()

    def construct(self, images, questions, answers):
        out = self.network(questions)
        # Answers are padded to a single token, so they likely arrive as
        # (batch, 1); sparse cross-entropy needs 1-D (batch,) labels.
        # Flattening is a no-op when the labels are already 1-D.
        answers = self.reshape(answers, (-1,))
        loss = self.loss(out, answers)
        return self.cast(loss, mstype.float32)


### 3.2 回调显示

In [17]:
from mindspore.train.callback import TimeMonitor
# Timing callback shared by training and validation. data_size is the number of
# batches per epoch reported by train_dataset.get_dataset_size().
time_callback = TimeMonitor(data_size=train_dataset.get_dataset_size())
callbacks = [time_callback]

In [18]:
from mindspore.train.serialization import save_checkpoint

### 3.3 开始训练

In [19]:
class _TrainEvalCell(nn.Cell):
    """Evaluation wrapper that exposes ``(loss, logits, labels)`` to the ``Model`` metrics.

    ``nn.Accuracy`` needs predictions and labels, which ``NetworkWithLoss`` does
    not return. With ``eval_indexes=[0, 1, 2]``, ``Model`` passes output 0 to
    ``nn.Loss`` and outputs 1 and 2 to the other metrics.
    """

    def __init__(self, backbone, loss_fn):
        # auto_prefix=False leaves the shared backbone's parameter names untouched
        # (the same convention nn.WithEvalCell uses).
        super(_TrainEvalCell, self).__init__(auto_prefix=False)
        self.backbone = backbone
        self.loss_fn = loss_fn
        self.reshape = P.Reshape()
        self.cast = P.Cast()

    def construct(self, images, questions, answers):
        # The baseline ignores the images; only the question tokens are scored.
        logits = self.backbone(questions)
        # 1-D labels for sparse cross-entropy and Accuracy (no-op if already 1-D).
        labels = self.reshape(answers, (-1,))
        loss = self.cast(self.loss_fn(logits, labels), mstype.float32)
        return loss, logits, labels


def _metric_to_float(value):
    """Return a ``Model.eval`` metric value (Python/NumPy scalar or Tensor) as ``float``."""
    if hasattr(value, 'asnumpy'):
        value = value.asnumpy()
    return float(value)


def train():
    """Train the baseline model and keep the checkpoint with the best validation scores.

    Each epoch trains once over ``train_dataset`` and then prints the loss and
    accuracy on the training and validation data. The backbone is saved to
    ``train_config.checkpoint_path`` whenever validation accuracy does not drop
    AND validation loss strictly improves. Parameter names are saved relative to
    the backbone, so ``load_checkpoint(path, net=Network())`` matches them.
    Training stops after ``train_config.early_stop`` consecutive epochs in which
    neither validation accuracy nor validation loss improved.
    """
    import os

    # Create the checkpoint directories (including parents) if needed.
    for directory in (train_config.ckpt_save_path,
                      os.path.dirname(train_config.checkpoint_path)):
        if directory:
            os.makedirs(directory, exist_ok=True)

    backbone = Network()
    network = NetworkWithLoss(backbone)
    eval_network = _TrainEvalCell(backbone, nn.SoftmaxCrossEntropyWithLogits(sparse=True))
    optimizer = nn.Adam(network.trainable_params(), learning_rate=train_config.lr, beta1=0.9, beta2=0.98)
    model = mindspore.Model(
        network,
        optimizer=optimizer,
        eval_network=eval_network,
        eval_indexes=[0, 1, 2],  # positions of (loss, logits, labels) in eval outputs
        metrics={"acc": nn.Accuracy(), "loss": nn.Loss()},
    )

    valid_acc_max = 0.0
    valid_loss_min = np.inf
    # Validation scores of the most recently saved checkpoint (reported on early stop).
    valid_acc_model = None
    valid_loss_model = None
    current_step = 0  # consecutive epochs without any improvement
    for epoch in range(train_config.epoch_size):
        print("train: ", end='')
        # Model.train returns None, so the training metrics come from a separate
        # evaluation pass (non-sink, so the dataset is not sunk into two graphs).
        model.train(1, train_dataset, callbacks, dataset_sink_mode=True)
        train_res = model.eval(train_dataset, dataset_sink_mode=False)
        train_loss = _metric_to_float(train_res['loss'])
        train_acc = _metric_to_float(train_res['acc'])
        print("valid: ", end='')
        valid_res = model.eval(valid_dataset, callbacks, dataset_sink_mode=True)
        valid_loss = _metric_to_float(valid_res['loss'])
        valid_acc = _metric_to_float(valid_res['acc'])
        print("epoch:{}, train loss={:.5f}, acc={:.5f} | valid loss={:.5f}, acc={:.5f}".format(
            epoch, train_loss, train_acc, valid_loss, valid_acc))

        if valid_acc >= valid_acc_max or valid_loss < valid_loss_min:
            if valid_acc >= valid_acc_max and valid_loss < valid_loss_min:
                valid_acc_model = valid_acc
                valid_loss_model = valid_loss
                # Save with backbone-relative names (parameters_and_names() is
                # structural), independent of any prefixes added by the wrappers.
                save_checkpoint(
                    [{"name": name, "data": param}
                     for name, param in backbone.parameters_and_names()],
                    train_config.checkpoint_path,
                )
            valid_acc_max = max(valid_acc_max, valid_acc)
            valid_loss_min = min(valid_loss_min, valid_loss)
            current_step = 0
        else:
            current_step += 1
            if current_step == train_config.early_stop:
                print("early stop... min loss: {}, max acc: {}".format(valid_loss_min, valid_acc_max), end='')
                print("; validation model loss: {}, acc: {}".format(valid_loss_model, valid_acc_model))
                break


In [20]:
train()



train: 

## 4 测试模型

### 4.1 创建测试模型

In [None]:
from mindspore.train.serialization import load_checkpoint

In [None]:
class _TestEvalCell(nn.Cell):
    """Evaluation wrapper returning ``(loss, logits, labels)`` so ``nn.Loss`` and ``nn.Accuracy`` both update."""

    def __init__(self, backbone, loss_fn):
        # auto_prefix=False leaves the backbone's parameter names untouched.
        super(_TestEvalCell, self).__init__(auto_prefix=False)
        self.backbone = backbone
        self.loss_fn = loss_fn
        self.reshape = P.Reshape()
        self.cast = P.Cast()

    def construct(self, images, questions, answers):
        # Question-only baseline: images are ignored.
        logits = self.backbone(questions)
        # 1-D labels for sparse cross-entropy and Accuracy (no-op if already 1-D).
        labels = self.reshape(answers, (-1,))
        loss = self.cast(self.loss_fn(logits, labels), mstype.float32)
        return loss, logits, labels


def test():
    """Evaluate the saved backbone checkpoint on ``test_dataset`` and print its loss and accuracy."""
    ckpt_file = train_config.checkpoint_path
    # save_checkpoint appends '.ckpt' to names without it, and load_checkpoint
    # only accepts file names ending in '.ckpt'.
    if not ckpt_file.endswith('.ckpt'):
        ckpt_file += '.ckpt'

    backbone = Network()
    # Load into the bare backbone before wrapping it; checkpoint keys are
    # expected to be backbone-relative parameter names.
    load_checkpoint(ckpt_file, net=backbone)
    network = NetworkWithLoss(backbone)
    eval_network = _TestEvalCell(backbone, nn.SoftmaxCrossEntropyWithLogits(sparse=True))
    # Model requires an eval network (or loss_fn) whenever metrics are given;
    # eval_indexes marks the (loss, logits, labels) outputs for the metrics.
    model = mindspore.Model(network, eval_network=eval_network, eval_indexes=[0, 1, 2],
                            metrics={'acc': nn.Accuracy(), 'loss': nn.Loss()})
    test_res = model.eval(test_dataset)

    def _as_float(value):
        # Metric values are plain scalars; accept a Tensor as well.
        return float(value.asnumpy()) if hasattr(value, 'asnumpy') else float(value)

    print("test loss={}, acc={}".format(_as_float(test_res['loss']), _as_float(test_res['acc'])))

### 4.2 开始测试

In [None]:
test()