In [None]:
import math
import mindspore
import numpy as np
import mindspore.nn as nn
import mindspore.ops.composite as C
import mindspore.ops.functional as F
import mindspore.ops.operations as P
from mindspore import Tensor, Parameter
from rnn import RNN, WithLossCell

In [None]:
def make_batch(seq_data, num_dic, n_step):
    """Encode (source, target) word pairs as encoder / decoder / target tensors.

    Both words of every pair are right-padded with 'P' up to ``n_step``
    characters (words already longer than ``n_step`` are left as they are).
    The decoder input is the padded target word prefixed with 'S'; the
    training target is the padded target word suffixed with 'E'.

    Args:
        seq_data: iterable of pairs ``[source_word, target_word]``. The pairs
            are only read; padding is applied to local copies.
        num_dic: mapping from character to integer class index. It must
            contain 'S', 'E', 'P' and every character of the words. The
            one-hot width is ``len(num_dic)``, so indices are expected to be
            ``0 .. len(num_dic) - 1``.
        n_step: length the words are padded to.

    Returns:
        A 3-tuple of Tensors:
        one-hot encoder input (float32, [batch, n_step, n_class]),
        one-hot decoder input (float32, [batch, n_step + 1, n_class]),
        target class indices (int32, [batch, n_step + 1], not one-hot).
        The stated shapes hold when no word is longer than ``n_step``.
    """
    # Derive the one-hot width from the vocabulary that was passed in rather
    # than from a notebook-level global, so the function is self-contained.
    n_class = len(num_dic)
    one_hot = np.eye(n_class)  # row i is the one-hot vector of class i

    input_batch, output_batch, target_batch = [], [], []

    for seq in seq_data:
        # Pad into locals: the caller's seq_data must not change under them.
        src_word, tgt_word = (
            word + 'P' * (n_step - len(word)) for word in (seq[0], seq[1])
        )

        enc_ids = [num_dic[ch] for ch in src_word]
        dec_ids = [num_dic[ch] for ch in ('S' + tgt_word)]
        tgt_ids = [num_dic[ch] for ch in (tgt_word + 'E')]

        # Indexing with a list of ids selects one one-hot row per character.
        input_batch.append(one_hot[enc_ids])
        output_batch.append(one_hot[dec_ids])
        target_batch.append(tgt_ids)  # class indices, not one-hot

    # make tensor
    return (
        Tensor(input_batch, mindspore.float32),
        Tensor(output_batch, mindspore.float32),
        Tensor(target_batch, mindspore.int32),
    )

In [None]:
# Model
class Seq2Seq(nn.Cell):
    """Encoder-decoder network: two RNN cells plus a dense projection to class scores.

    Args:
        n_class: width of the one-hot inputs and of the output scores.
        n_hidden: hidden size of both RNN cells and input width of the dense layer.
        dropout: forwarded unchanged to both RNN cells.
    """

    def __init__(self, n_class, n_hidden, dropout):
        super().__init__()

        # Encoder and decoder are separate RNN instances with identical settings.
        self.enc_cell = RNN(
            input_size=n_class,
            hidden_size=n_hidden,
            dropout=dropout,
        )
        self.dec_cell = RNN(
            input_size=n_class,
            hidden_size=n_hidden,
            dropout=dropout,
        )
        # Maps each decoder output vector (n_hidden) to n_class scores.
        self.fc = nn.Dense(n_hidden, n_class)

        self.transpose = P.Transpose()

    def construct(self, enc_input, enc_hidden, dec_input):
        """Run the encoder, seed the decoder with its state, and score every decoder step.

        Args:
            enc_input: encoder input, batch-major: [batch_size, steps, n_class].
            enc_hidden: initial state handed to the encoder RNN.
            dec_input: decoder input, batch-major: [batch_size, steps, n_class].

        Returns:
            The dense layer applied to the decoder outputs; per the notebook's
            notes this is [steps, batch_size, n_class] (time-major) - the exact
            layout depends on the RNN helper, TODO confirm.
        """
        # Swap the first two axes: [batch, steps, n_class] -> [steps, batch, n_class].
        enc_steps = self.transpose(enc_input, (1, 0, 2))
        dec_steps = self.transpose(dec_input, (1, 0, 2))

        # Only the second value returned by the encoder (its state) is kept.
        _, enc_states = self.enc_cell(enc_steps, enc_hidden)
        # The decoder starts from the encoder state; only its first return value is kept.
        dec_outputs, _ = self.dec_cell(dec_steps, enc_states)

        return self.fc(dec_outputs)

In [None]:
# Hyperparameters
n_step = 5      # words are padded with 'P' up to this many characters
n_hidden = 128  # hidden size passed to Seq2Seq
dropout = 0.5   # dropout setting passed to Seq2Seq

# Vocabulary: 'S' prefixes the decoder input, 'E' terminates the target,
# 'P' is the padding symbol; the lowercase alphabet follows.
char_arr = list('SEPabcdefghijklmnopqrstuvwxyz')
num_dic = {ch: idx for idx, ch in enumerate(char_arr)}

# Training pairs: [source word, target word]
seq_data = [
    ['man', 'women'],
    ['black', 'white'],
    ['king', 'queen'],
    ['girl', 'boy'],
    ['up', 'down'],
    ['high', 'low'],
]

n_class = len(num_dic)      # one class per vocabulary symbol
batch_size = len(seq_data)  # the whole data set is used as a single batch

In [None]:
# Build the network from the hyperparameters defined in the configuration cell.
model = Seq2Seq(n_class=n_class, n_hidden=n_hidden, dropout=dropout)

In [None]:
# Loss: softmax cross-entropy over integer (sparse) labels, averaged.
criterion = nn.SoftmaxCrossEntropyWithLogits(
    reduction='mean',
    sparse=True,
)
# Optimizer: Adam over all parameters exposed by the model.
optimizer = nn.Adam(
    model.get_parameters(),
    learning_rate=0.001,
)

In [None]:
# Encode the full training set once: encoder input, decoder input, target indices.
input_batch, output_batch, target_batch = make_batch(
    seq_data=seq_data,
    num_dic=num_dic,
    n_step=n_step,
)

In [None]:
# Run MindSpore in graph mode with the GPU as the device target.
from mindspore import context

context.set_context(
    device_target="GPU",
    mode=context.GRAPH_MODE,
)

In [None]:
# Attach the loss to the model, then wrap the result in a cell that runs one
# optimizer update per call.
net_with_criterion = WithLossCell(model, criterion)
train_network = nn.TrainOneStepCell(net_with_criterion, optimizer)

n_epochs = 5000   # number of full-batch update steps
log_every = 1000  # print the loss once per this many epochs

# NOTE(review): set_train() is never called on train_network, so the cells keep
# MindSpore's default mode - confirm whether `dropout` is meant to be active
# during these updates.

# Initial hidden state: zeros of shape [batch_size, n_hidden]. It is the same
# for every step, so it is built once instead of once per epoch.
hidden = Tensor(np.zeros((batch_size, n_hidden)), mindspore.float32)

for epoch in range(n_epochs):
    # input_batch  : [batch_size, n_step, n_class]
    # output_batch : [batch_size, n_step + 1, n_class] (one extra step for 'S')
    # target_batch : [batch_size, n_step + 1] (one extra step for 'E'), not one-hot
    loss = train_network(input_batch, hidden, output_batch, target_batch)
    if (epoch + 1) % log_every == 0:
        print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss.asnumpy()))

In [None]:
# Test
def translate(word):
    """Translate ``word`` with the trained ``model`` in a single decoder pass.

    No target is known at inference time, so the decoder is fed 'S' followed by
    'P' placeholders. The most likely character is taken at every decoder step;
    the result is cut at the first 'E' (if any) and padding is removed.

    Args:
        word: source word made of characters present in ``num_dic``.

    Returns:
        The predicted word as a string, without 'P' characters.
    """
    input_batch, output_batch, _ = make_batch([[word, 'P' * len(word)]], num_dic, n_step)
    # Initial hidden state for the single-word batch: zeros of shape [1, n_hidden].
    hidden = Tensor(np.zeros((1, n_hidden)), mindspore.float32)
    output = model(input_batch, hidden, output_batch)
    # output is indexed as [step, batch(=1), n_class] below.

    predict = output.asnumpy().argmax(2)  # best class per step, over the n_class axis
    decoded = [char_arr[step[0]] for step in predict]  # step[0]: the only batch entry

    # Cut at the first 'E'. The model is not guaranteed to emit one (e.g. for
    # unseen words), in which case the whole sequence is kept instead of
    # failing with ValueError from list.index.
    if 'E' in decoded:
        decoded = decoded[:decoded.index('E')]
    translated = ''.join(decoded)

    return translated.replace('P', '')

# Smoke test: translate a few words seen in training and a few that were not.
print('test')
for _word in ('man', 'mans', 'king', 'black', 'upp'):
    print(_word + ' ->', translate(_word))