In [None]:
import numpy as np
import chainer
from chainer import cuda, Chain, Variable, optimizers, datasets, iterators, training, serializers, report
import chainer.functions as F
import chainer.links as L
import sys


In [None]:
# Characters the model reads: the ten digits, the "+" and "=" operators, and " " used as padding.
input_alphabet = "0123456789" + "+= "
# Characters the model writes: the ten digits and " " used as padding.
output_alphabet = "0123456789" + " "


In [None]:
def generate(nb_data, k):
    """Sample ``nb_data`` random addition problems with operands of at most ``k`` digits.

    Operands are drawn log-uniformly from [1, 10**k): the exponent is uniform,
    so every digit length 1..k is (roughly) equally likely instead of almost
    all operands being k digits long.

    Args:
        nb_data: number of problems to draw.
        k: maximum number of digits per operand.

    Returns:
        (a, b, c): three int64 arrays of length ``nb_data`` with ``c == a + b``.
        int64 keeps operands with ``k >= 10`` (and their sums) from
        overflowing, which the former int32 cast did.
    """
    upper = 10 ** k

    def sample():
        # exp(U[log 1, log 10**k)) truncated to an int lies in [1, 10**k - 1];
        # the clip guards against float rounding producing exactly 10**k,
        # which would be a (k+1)-digit operand and break fixed-width encoding.
        values = np.exp(np.random.uniform(np.log(1), np.log(upper), nb_data))
        return np.clip(values.astype(np.int64), 1, upper - 1)

    # Draw a before b to keep the same random stream as before.
    a = sample()
    b = sample()
    c = a + b
    return a, b, c

def encode_in(a, b, k):
    """One-hot encode the expressions ``"a+b="`` as model input.

    Each expression is right-aligned (left-padded with spaces) to ``k``
    characters, and every character becomes a one-hot row over
    ``input_alphabet``.

    Args:
        a, b: sequences of non-negative integers, paired element-wise.
        k: width of the encoded expression (``generate_dataset`` passes
            ``2 * k + 2`` for ``k``-digit operands).

    Returns:
        float32 array of shape ``(len(texts), k, len(input_alphabet))``.

    Raises:
        ValueError: if an expression is longer than ``k`` characters or
            contains a character outside ``input_alphabet`` (e.g. a minus
            sign). Previously these gave a ragged-array error or a silent
            all-zero row.
    """
    alphabet = np.array(list(input_alphabet))
    allowed = set(input_alphabet)
    texts = ["{}+{}=".format(a_, b_).rjust(k, " ") for a_, b_ in zip(a, b)]
    for text in texts:
        if len(text) > k:
            raise ValueError("expression {!r} does not fit in {} characters".format(text, k))
        unknown = set(text) - allowed
        if unknown:
            raise ValueError("expression {!r} contains characters outside the input alphabet: {}".format(text, sorted(unknown)))
    if not texts:
        # Keep the rank-3 shape even for an empty batch.
        return np.zeros((0, k, len(alphabet)), dtype="f")
    # Comparing a character against the whole alphabet array yields its one-hot row.
    return np.array([[alphabet == ch for ch in text] for text in texts]).astype("f")

def encode_out(c, k):
    """Encode sums as target label sequences.

    Each sum is left-aligned (right-padded with spaces) to ``k`` characters,
    and every character is replaced by its index in ``output_alphabet``.

    Returns:
        int32 array of shape ``(len(c), k)``.
    """
    label_rows = []
    for total in c:
        padded = "{}".format(total).ljust(k, " ")
        label_rows.append([output_alphabet.index(symbol) for symbol in padded])
    return np.array(label_rows).astype("i")

def generate_dataset(nb_data, k):
    """Build a ``TupleDataset`` of ``nb_data`` random (input, target) addition pairs.

    Inputs are ``"a+b="`` strings of width ``2 * k + 2`` (two operands of up
    to ``k`` digits plus the "+" and "=" characters). Targets are sums of
    width ``k + 1``, since adding two ``k``-digit numbers may carry into one
    extra digit.
    """
    a, b, c = generate(nb_data, k)
    inputs = encode_in(a, b, 2 * k + 2)
    targets = encode_out(c, k + 1)
    return datasets.TupleDataset(inputs, targets)


In [None]:
def decode_in(x):
    """Turn a sequence of ``input_alphabet`` indices back into text."""
    return "".join(input_alphabet[index] for index in x)

def decode_out(x):
    """Turn a sequence of ``output_alphabet`` indices back into text."""
    return "".join(output_alphabet[index] for index in x)


In [None]:
class Model(Chain):
    """Recurrent adder: reads a one-hot ``"a+b="`` sequence and emits sum characters.

    ``l1`` projects each input character to ``unit`` features and the LSTM
    ``l2`` consumes the sequence. ``l2`` then keeps running for ``out_size``
    more steps, each fed with its own ReLU'd previous output. ``l3`` maps
    every such step to unnormalized scores over ``output_alphabet``.
    """

    def __init__(self, unit, out_size):
        super(Model, self).__init__()
        with self.init_scope():
            self.l1 = L.Linear(len(input_alphabet), unit)
            self.l2 = L.LSTM(unit, unit)
            self.l3 = L.Linear(unit, len(output_alphabet))
            # Registered as a persistent value so serializers (snapshots) save/restore it.
            self.add_persistent("out_size", out_size)

    def __call__(self, x):
        """Score every output position.

        Args:
            x: rank-3 array indexed as ``x[:, i, :]``; the last axis must have
                ``len(input_alphabet)`` entries (the input size of ``l1``).

        Returns:
            Variable of shape ``(batch, out_size, len(output_alphabet))``.

        Raises:
            ValueError: if the input sequence is empty (there would be no
                encoder state to decode from).
        """
        if x.shape[1] == 0:
            raise ValueError("input sequence must contain at least one character")
        self.l2.reset_state()
        for i in range(x.shape[1]):
            h = F.relu(self.l1(Variable(x[:, i, :])))
            h = self.l2(h)
        result = []
        for _ in range(self.out_size):
            h = self.l2(F.relu(h))
            result.append(self.l3(h))
        # Stack the per-step scores along a new time axis instead of
        # reshaping with a hard-coded class count.
        return F.stack(result, axis=1)


In [None]:
class Calculator(Chain):
    """Loss/inference wrapper around a sequence predictor such as ``Model``.

    ``__call__`` computes the softmax cross-entropy loss and reports loss and
    accuracy. ``predict`` returns raw scores for inference.
    """

    def __init__(self, predictor, use_gpu=None):
        """
        Args:
            predictor: chain mapping a batch to ``(batch, steps, n_classes)`` scores.
            use_gpu: GPU id to run on, or ``None`` for CPU.
        """
        super(Calculator, self).__init__()
        with self.init_scope():
            self.predictor = predictor

        # Publishes the array module matching the chosen device as module-level ``xp``.
        global xp
        if use_gpu is not None:
            cuda.get_device(use_gpu).use()
            self.predictor.to_gpu()
            xp = cuda.cupy
        else:
            xp = np

    def __call__(self, x, t):
        """Return the mean softmax cross-entropy over all output positions.

        ``t`` holds integer labels; it is flattened so it lines up with the
        ``(batch * steps, len(output_alphabet))`` reshape of the scores.
        """
        y = self.predictor(x)
        t = Variable(t.flatten())
        y = F.reshape(y, (-1, len(output_alphabet)))
        loss = F.softmax_cross_entropy(y, t)
        accu = F.accuracy(y, t)
        report({"loss": loss, "accuracy": accu, }, self)
        return loss

    def predict(self, x):
        """Return the predictor's raw scores for ``x`` as a host (numpy) array.

        ``x`` may be a host or device array; it is moved to the predictor's
        device first, so host-side encodings work when running on a GPU.
        No computational graph is kept and the test-mode config is used.
        """
        x = self.predictor.xp.asarray(x)
        with chainer.no_backprop_mode(), chainer.using_config("train", False):
            y = self.predictor(x)
        # Callers decode the result with plain Python indexing, which needs host data.
        return cuda.to_cpu(y.data)


In [None]:
class CustomSerialIterator(iterators.SerialIterator):
    """SerialIterator that periodically replaces its dataset with a freshly generated one.

    ``generator`` is a zero-argument callable returning a dataset. It is
    called once at construction, and again whenever the iterator enters an
    epoch that is a multiple of ``data_update_epoch``.
    NOTE(review): each generated dataset is assumed to have the same length as
    the first one (SerialIterator's index order is built from the dataset
    length); confirm before using generators of varying size.
    """

    def __init__(self, generator, batch_size, data_update_epoch, repeat=True, shuffle=True):
        # Fail early instead of with ZeroDivisionError at the first epoch boundary.
        if data_update_epoch <= 0:
            raise ValueError("data_update_epoch must be a positive integer, got {!r}".format(data_update_epoch))
        self.generator = generator
        self.data_update_epoch = data_update_epoch
        self.last_epoch = 0
        super().__init__(self.generator(), batch_size, repeat, shuffle)

    def __next__(self):
        # The epoch counter advances inside the parent's __next__, so an epoch
        # change is noticed on the first batch request of the new epoch.
        if self.last_epoch != self.epoch:
            self.last_epoch = self.epoch
            if self.epoch % self.data_update_epoch == 0:
                self.dataset = self.generator()
        return super().__next__()

    # SerialIterator exposes both __next__ and next; bind both to the override
    # so the builtin next(it) and it.next() both refresh the dataset
    # (previously only it.next() did).
    next = __next__


In [None]:
def parse_args(cmdline=None):
    """Parse the training script's options.

    Args:
        cmdline: list of argument strings; ``None`` lets argparse read
            ``sys.argv[1:]``. The notebook passes ``[]`` to get the defaults.

    Returns:
        argparse.Namespace with one attribute per option listed below.
    """
    import argparse
    # (flag, type, default) for every option, in declaration order.
    option_specs = (
        ("--gpu", int, None),
        ("--resume", str, None),
        ("--outdir", str, "result"),
        ("--unit", int, 64),
        ("--data_samples", int, 5000),
        ("--batch_size", int, 100),
        ("--epochs", int, 10000),
        ("--data_update_epoch", int, 10),
        ("--number_width", int, 5),
    )
    parser = argparse.ArgumentParser()
    for flag, value_type, default in option_specs:
        parser.add_argument(flag, type=value_type, default=default)
    return parser.parse_args(cmdline)


In [None]:
def init(args):
    """Build the model and its Adam optimizer from parsed options.

    The predictor emits ``args.number_width + 1`` characters, because a sum of
    two ``number_width``-digit numbers can carry into one extra digit.

    Returns:
        (model, optimizer): the ``Calculator`` wrapper (moved to ``args.gpu``
        when it is set) and an ``optimizers.Adam`` already set up on it.
    """
    output_width = args.number_width + 1
    predictor = Model(args.unit, output_width)
    model = Calculator(predictor, args.gpu)
    optimizer = optimizers.Adam()
    optimizer.setup(model)
    return model, optimizer


In [None]:
def eval_hook(evaluator):
    """Evaluator hook that prints a few validation predictions.

    Takes the first minibatch of the evaluator's "main" iterator and prints
    one ``<input> <target> -> <prediction>`` line per example, so progress can
    be eyeballed next to the loss/accuracy report.
    """
    model = evaluator.get_target("main")
    iterator = evaluator.get_iterator("main")
    # Rewind so the printout always starts from the iterator's first batch.
    iterator.reset()
    minibatch = next(iterator)
    x, t = evaluator.converter(minibatch)
    # predict() may hand back a device (cupy) array; decoding indexes Python
    # strings, so bring everything to the host first (no-op for numpy).
    y = cuda.to_cpu(model.predict(x))
    x = cuda.to_cpu(x)
    t = cuda.to_cpu(t)
    for xi, ti, yi in zip(x, t, y):
        print("{} {} -> {}".format(
            decode_in(xi.argmax(axis=1)), decode_out(ti), decode_out(yi.argmax(axis=1))
        ))


In [None]:
def train(model, optimizer, args):
    """Run the Chainer training loop for ``model``.

    The training set holds ``args.data_samples`` random problems. It is
    regenerated every ``args.data_update_epoch`` epochs by
    ``CustomSerialIterator``. A fixed 33-example validation set is evaluated
    by an ``Evaluator`` extension, whose ``eval_hook`` prints sample
    predictions. Logs, plots and snapshots go to ``args.outdir``;
    ``args.resume`` restores a trainer snapshot before training starts.
    """
    def make_train_set():
        # Fresh random problems each time the iterator asks for new data.
        return generate_dataset(args.data_samples, args.number_width)

    train_iter = CustomSerialIterator(
        make_train_set,
        data_update_epoch=args.data_update_epoch,
        batch_size=args.batch_size,
        repeat=True,
        shuffle=True,
    )
    # device=args.gpu makes the converter copy each batch to the GPU the model
    # lives on; with args.gpu None, batches stay numpy arrays as before.
    updater = training.StandardUpdater(train_iter, optimizer, device=args.gpu)
    trainer = training.Trainer(updater, (args.epochs, "epoch"), out=args.outdir)

    test_set = generate_dataset(33, args.number_width)
    test_iter = iterators.SerialIterator(test_set, batch_size=3, repeat=False, shuffle=False)
    trainer.extend(training.extensions.Evaluator(test_iter, model, device=args.gpu, eval_hook=eval_hook))

    trainer.extend(training.extensions.LogReport(), trigger=(100, "epoch"))
    trainer.extend(training.extensions.PrintReport(["epoch", "main/loss", "validation/main/loss", "main/accuracy", "validation/main/accuracy", ]), trigger=(1, "epoch"))
    trainer.extend(training.extensions.snapshot(), trigger=(50, "epoch"))
    if training.extensions.PlotReport.available():
        trainer.extend(training.extensions.PlotReport(["main/loss", "validation/main/loss", ], "epoch", file_name="loss.png"))
        trainer.extend(training.extensions.PlotReport(["main/accuracy", "validation/main/accuracy", ], "epoch", file_name="accuracy.png"))
    trainer.extend(training.extensions.ProgressBar(), trigger=(10, "epoch"))
    if args.resume is not None:
        serializers.load_npz(args.resume, trainer)

    trainer.run()


In [None]:
def test(model, args):
    """Interactively compare the model's predicted sum with the exact one.

    Repeatedly prompts for two integers in ``[0, 10**args.number_width)`` and
    prints the predicted and the computed sum. Invalid entries are rejected
    and re-prompted. An empty answer or end of input ends the session
    (previously the loop could only be left through an exception).
    """
    limit = 10 ** args.number_width
    in_k = args.number_width * 2 + 2    # same input width generate_dataset uses

    def ask(prompt):
        # Returns a validated operand, or None when the user wants to stop.
        while True:
            try:
                text = input(prompt).strip()
            except EOFError:
                return None
            if not text:
                return None
            try:
                value = int(text)
            except ValueError:
                print("please enter an integer (empty line to quit)")
                continue
            if not 0 <= value < limit:
                print("please enter a value in [0, {})".format(limit))
                continue
            return value

    while True:
        a = ask("a=")
        if a is None:
            return
        b = ask("b=")
        if b is None:
            return
        # to_cpu is a no-op for host arrays and guards against device output.
        c = cuda.to_cpu(model.predict(encode_in([a], [b], in_k)))
        print("predicted c={}".format(decode_out(c[0].argmax(axis=1))))
        print(" computed c={}".format(a + b))


In [None]:
if __name__ == "__main__":
    # An explicit empty list keeps argparse from parsing the Jupyter kernel's own argv.
    # To resume training, pass options instead, e.g.
    # ["--resume", "<outdir>/snapshot_iter_N", "--epochs", "<total epochs>"].
    args = parse_args([])
    model, optimizer = init(args)
    train(model, optimizer, args)
    test(model, args)
