In [None]:
import os
import numpy as np
import re
import sys
import random
import unicodedata
import math

from mindspore import Tensor, nn, Model, context
from mindspore.train.serialization import load_param_into_net, load_checkpoint
from mindspore.train.callback import LossMonitor, CheckpointConfig, ModelCheckpoint, TimeMonitor
from mindspore import dataset as ds
from mindspore.mindrecord import FileWriter
from mindspore import Parameter
from mindspore.nn.loss.loss import _Loss
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype

In [None]:
from easydict import EasyDict as edict

# CONFIG -- every hyper-parameter and path the cells below read, in one place.
cfg = edict(dict(
    # Vocabulary sizes: the Encoder embeds with en_vocab_size, the Decoder
    # embeds/projects with ch_vocab_size.
    en_vocab_size=1154,
    ch_vocab_size=1116,
    # Fixed sequence length shared by encoder attention and greedy decoding.
    max_seq_length=10,
    hidden_size=1024,
    # Training batch size vs. evaluation batch size (the Encoder sizes its
    # zero initial hidden state from one of these).
    batch_size=16,
    eval_batch_size=1,
    learning_rate=0.001,
    # NOTE(review): not read by any cell shown here (the optimizer is Adam).
    momentum=0.9,
    num_epochs=15,
    save_checkpoint_steps=125,
    keep_checkpoint_max=10,
    dataset_path='./preprocess',
    ckpt_save_path='./ckpt',
    # NOTE(review): not read by any cell shown here; presumably the last
    # checkpoint written with prefix "gru" (epoch 15, step 125) -- confirm.
    checkpoint_path='./ckpt/gru-15_125.ckpt',
))

In [None]:
class Attention(nn.Cell):
    """One stacked-attention hop: attend over image features with a query
    vector and return the query refined by the attention-weighted image.

    Assumed shapes (NOTE(review): confirm against the encoder/CNN outputs):
        question: (batch, hidden)
        img:      (batch, regions, hidden)
    The result is ``attended_image + question``, so it is shaped like
    ``question``.
    """

    def __init__(self, config, is_training=True):
        super(Attention, self).__init__()
        self.hidden_size = config.hidden_size
        self.attnq = nn.Dense(self.hidden_size, self.hidden_size)
        self.attni = nn.Dense(self.hidden_size, self.hidden_size)
        # One raw score per region. The softmax must normalise across the
        # region axis: a softmax over this size-1 output axis would make
        # every attention weight exactly 1.
        self.attnp = nn.Dense(self.hidden_size, 1)
        self.softmax = nn.Softmax(axis=1)
        self.tanh = nn.Tanh()
        self.add = P.Add()
        self.mul = P.Mul()
        self.expand_dims = P.ExpandDims()
        self.reduce_sum = P.ReduceSum(keep_dims=False)

    @staticmethod
    def sum(x):
        """Host-side (numpy) sum of a 3-D array over axis 1, returned as a
        Tensor with the input dtype.

        Kept for backward compatibility only; ``construct`` uses
        ``ReduceSum`` because numpy code cannot run inside a graph-mode
        ``construct``.
        """
        x = np.asarray(x)
        return Tensor(x.sum(axis=1, dtype=x.dtype))

    def construct(self, question, img):
        i_attn = self.attni(img)                             # (batch, regions, hidden)
        # Add a region axis so the query broadcasts against every region.
        q_attn = self.expand_dims(self.attnq(question), 1)   # (batch, 1, hidden)
        ha = self.tanh(self.add(i_attn, q_attn))
        p = self.softmax(self.attnp(ha))                     # (batch, regions, 1)
        attended = self.reduce_sum(self.mul(p, img), 1)      # (batch, hidden)
        return self.add(attended, question)

class SAN(nn.Cell):
    """Stacked attention network: three ``Attention`` hops applied in
    sequence, each refining the query produced by the previous hop against
    the same image features."""

    def __init__(self, config, is_training=True):
        super(SAN, self).__init__()
        self.hidden_size = config.hidden_size
        # Three named attributes (rather than a CellList) keep the parameter
        # name prefixes -- and hence checkpoint keys -- as attn_1/2/3.
        self.attn_1 = Attention(config=config, is_training=is_training)
        self.attn_2 = Attention(config=config, is_training=is_training)
        self.attn_3 = Attention(config=config, is_training=is_training)

    def construct(self, question, img):
        u_1 = self.attn_1(question, img)
        u_2 = self.attn_2(u_1, img)
        return self.attn_3(u_2, img)
        
class Encoder(nn.Cell):
    """Embeds and GRU-encodes the source token ids, and runs the CNN over the
    image input.

    ``construct`` returns ``(gru_output, gru_hidden, cnn_output)``.
    """

    def __init__(self, config, is_training=True):
        super(Encoder, self).__init__()
        self.vocab_size = config.en_vocab_size
        self.hidden_size = config.hidden_size
        # The zero initial hidden state below is sized by the training or the
        # evaluation batch size, whichever applies.
        self.batch_size = config.batch_size if is_training else config.eval_batch_size

        self.trans = P.Transpose()
        # Swap the first two axes before the GRU (presumably
        # (batch, seq, hidden) -> (seq, batch, hidden); confirm against GRU).
        self.perm = (1, 0, 2)
        self.embedding = nn.Embedding(self.vocab_size, self.hidden_size)
        # GRU and CNN are defined outside this cell; the GRU runs in float16.
        self.gru = GRU(config, is_training=is_training).to_float(mstype.float16)
        self.cnn = CNN(config, is_training=is_training)
        zeros = np.zeros((self.batch_size, self.hidden_size))
        self.h = Tensor(zeros.astype(np.float16))

    def construct(self, encoder_input, img_input):
        token_vectors = self.trans(self.embedding(encoder_input), self.perm)
        gru_output, gru_hidden = self.gru(token_vectors, self.h)
        return gru_output, gru_hidden, self.cnn(img_input)

class Decoder(nn.Cell):
    """Attention GRU decoder.

    The embedded target tokens attend over ``encoder_output``; the attended
    context is concatenated with the embedding, projected back to
    ``hidden_size`` and fed to the GRU. The GRU output is projected to
    ``ch_vocab_size`` and passed through LogSoftmax.

    ``construct`` returns ``(log_probs, hidden, attn_weights)``.
    """

    def __init__(self, config, is_training=True, dropout=0.1):
        super(Decoder, self).__init__()

        self.vocab_size = config.ch_vocab_size
        self.hidden_size = config.hidden_size
        self.max_len = config.max_seq_length

        self.trans = P.Transpose()
        self.perm = (1, 0, 2)
        self.embedding = nn.Embedding(self.vocab_size, self.hidden_size)
        # nn.Dropout takes a keep probability here, hence 1 - dropout.
        self.dropout = nn.Dropout(1-dropout)
        # One attention score per encoder position (max_seq_length of them).
        self.attn = nn.Dense(self.hidden_size, self.max_len)
        self.softmax = nn.Softmax(axis=2)
        self.bmm = P.BatchMatMul()
        self.concat = P.Concat(axis=2)
        self.attn_combine = nn.Dense(self.hidden_size * 2, self.hidden_size)

        self.gru = GRU(config, is_training=is_training).to_float(mstype.float16)
        self.out = nn.Dense(self.hidden_size, self.vocab_size)
        self.logsoftmax = nn.LogSoftmax(axis=2)
        self.cast = P.Cast()

    def construct(self, decoder_input, hidden, encoder_output):
        embeddings = self.embedding(decoder_input)
        embeddings = self.dropout(embeddings)
        # Attention weights over the encoder positions.
        attn_weights = self.softmax(self.attn(embeddings))
        # NOTE(review): assumes the transposed encoder_output's middle axis
        # has max_seq_length entries so the batch matmul lines up -- confirm.
        encoder_output = self.trans(encoder_output, self.perm)
        attn_applied = self.bmm(attn_weights, self.cast(encoder_output, mstype.float32))
        output = self.concat((embeddings, attn_applied))
        output = self.attn_combine(output)

        # Feed the attention-combined features (not the raw embeddings) to the
        # GRU; otherwise the attention computed above never affects the output.
        output = self.trans(output, self.perm)
        output, hidden = self.gru(output, hidden)
        output = self.cast(output, mstype.float32)
        output = self.out(output)
        output = self.logsoftmax(output)

        return output, hidden, attn_weights

class Seq2Seq(nn.Cell):
    """Encoder (GRU + CNN) -> stacked image attention (SAN) -> attention decoder.

    In training mode the decoder is run once on the full ``dst`` sequence
    (teacher forcing). Otherwise it decodes greedily for ``max_seq_length``
    steps, starting from ``dst[:, 0:1]`` and feeding back its own argmax.
    """

    def __init__(self, config, is_train=True):
        super(Seq2Seq, self).__init__()
        self.max_len = config.max_seq_length
        self.is_train = is_train

        self.encoder = Encoder(config, is_train)
        self.decoder = Decoder(config, is_train)
        self.expanddims = P.ExpandDims()
        self.squeeze = P.Squeeze(axis=0)
        self.argmax = P.ArgMaxWithValue(axis=2, keep_dims=True)
        self.concat = P.Concat(axis=1)
        self.concat2 = P.Concat(axis=0)
        self.select = P.Select()
        self.san = SAN(config, is_train)

    def construct(self, src, dst, img):
        # The CNN lives inside the Encoder, which takes (tokens, image) and
        # returns (gru_output, gru_hidden, cnn_output).
        encoder_output, _, img_output = self.encoder(src, img)
        # NOTE(review): each Attention hop adds its query to an image-weighted
        # sum, so san_out is presumably shaped like encoder_output; the decoder
        # then uses it as its attention memory -- confirm the shapes agree.
        san_out = self.san(encoder_output, img_output)

        # Initial decoder state: the encoder output at position max_len - 2.
        decoder_hidden = self.squeeze(encoder_output[self.max_len - 2:self.max_len - 1:1, ::, ::])
        if self.is_train:
            # Decoder returns (log_probs, hidden, attn_weights); keep log_probs.
            outputs = self.decoder(dst, decoder_hidden, san_out)[0]
        else:
            decoder_input = dst[::, 0:1:1]
            decoder_outputs = ()
            for _step in range(self.max_len):
                decoder_output, decoder_hidden, _ = self.decoder(decoder_input,
                                                                 decoder_hidden, san_out)
                decoder_hidden = self.squeeze(decoder_hidden)
                # ArgMaxWithValue returns (indices, values); keep the indices.
                decoder_output, _ = self.argmax(decoder_output)
                decoder_output = self.squeeze(decoder_output)
                decoder_outputs += (decoder_output,)
                decoder_input = decoder_output
            outputs = self.concat(decoder_outputs)
        return outputs

class NLLLoss(_Loss):
    '''
       Negative log-likelihood loss.

       ``logits`` must already be log-probabilities, with the class axis
       last. ``label`` holds integer class ids and has the shape of
       ``logits`` without that last axis. The per-element losses are
       reduced according to ``reduction`` ('mean' by default).
    '''
    def __init__(self, reduction='mean'):
        super(NLLLoss, self).__init__(reduction)
        self.one_hot = P.OneHot()
        self.reduce_sum = P.ReduceSum()

    def construct(self, logits, label):
        num_classes = F.shape(logits)[-1]
        label_one_hot = self.one_hot(label, num_classes, F.scalar_to_array(1.0),
                                     F.scalar_to_array(0.0))
        # Mask with the one-hot labels and sum over the class axis to pick out
        # -log p(label). Reducing the last axis is identical to axis 1 for 2-D
        # logits and also accepts extra leading axes.
        loss = self.reduce_sum(-1.0 * logits * label_one_hot, (-1,))
        return self.get_loss(loss)
    
class WithLossCell(nn.Cell):
    """Wraps the Seq2Seq backbone with ``NLLLoss`` for training.

    The backbone output is treated as (seq, batch, vocab) log-probabilities
    and ``label`` as (batch, seq) class ids. The returned loss is the mean
    NLL over every (sample, position) pair.
    """
    def __init__(self, backbone, config):
        super(WithLossCell, self).__init__(auto_prefix=False)
        self._backbone = backbone
        self.batch_size = config.batch_size
        self.onehot = nn.OneHot(depth=config.ch_vocab_size)
        self._loss_fn = NLLLoss()
        self.max_len = config.max_seq_length
        self.squeeze = P.Squeeze()
        self.cast = P.Cast()
        self.argmax = P.ArgMaxWithValue(axis=1, keep_dims=True)
        self.print = P.Print()
        self.trans = P.Transpose()
        self.reshape = P.Reshape()

    def construct(self, src, dst, label, img=None):
        # Seq2Seq.construct needs the image input; backbones without one are
        # still supported. NOTE(review): assumes the dataset yields columns in
        # the order (src, dst, label, img) -- confirm against the dataset cell.
        if img is None:
            out = self._backbone(src, dst)
        else:
            out = self._backbone(src, dst, img)
        # (seq, batch, vocab) -> (batch, seq, vocab) so that flattened rows
        # line up with label flattened from (batch, seq).
        out = self.trans(out, (1, 0, 2))
        logits = self.reshape(out, (-1, F.shape(out)[-1]))
        labels = self.reshape(label, (-1,))
        # A single mean over all batch*seq tokens equals the mean over the
        # batch of per-sample sequence means, because every row of ``label``
        # has the same length. It also works when the batch holds fewer than
        # config.batch_size samples, and it avoids unrolling one loss per
        # sample into the graph.
        return self._loss_fn(logits, labels)

In [None]:
# Training graph: the Seq2Seq backbone wrapped with its loss, optimised by
# Adam. Betas come from cfg when present (defaults: 0.9 / 0.98).
seq2seq_net = Seq2Seq(cfg)
network = WithLossCell(seq2seq_net, cfg)
optimizer = nn.Adam(network.trainable_params(), learning_rate=cfg.learning_rate,
                    beta1=cfg.get('beta1', 0.9), beta2=cfg.get('beta2', 0.98))
model = Model(network, optimizer=optimizer)

In [None]:
# `ds_train` is not defined by any cell shown in this notebook; it must come
# from a dataset-loading cell run earlier. Fail fast with a clear message
# instead of a bare NameError after the callbacks (and the checkpoint
# directory) have already been set up.
if 'ds_train' not in globals():
    raise NameError("ds_train is not defined: run the dataset-loading cell "
                    "before the training cell.")

steps_per_epoch = ds_train.get_dataset_size()
loss_cb = LossMonitor()
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
                             keep_checkpoint_max=cfg.keep_checkpoint_max)
ckpoint_cb = ModelCheckpoint(prefix="gru", directory=cfg.ckpt_save_path, config=config_ck)
time_cb = TimeMonitor(data_size=steps_per_epoch)
callbacks = [time_cb, ckpoint_cb, loss_cb]

model.train(cfg.num_epochs, ds_train, callbacks=callbacks, dataset_sink_mode=True)