In [62]:
%matplotlib inline

import numpy as np
import chainer
from chainer import cuda, Function, gradient_check, report, training, utils, Variable
from chainer import datasets, iterators, optimizers, serializers
from chainer import Link, Chain, ChainList
import chainer.functions as F
import chainer.links as L
from chainer.training import extensions
from chainer.dataset import concat_examples
from chainer.datasets import TupleDataset

import matplotlib.pyplot as plt

global HIDDEN_UNITS
HIDDEN_UNITS = 3

In [67]:
# create toy data - compute sum of the previous and current input
def create_data(n=3000):

    X = np.random.rand(n,1).astype('float32')
    T = np.sum(np.hstack((X[0:-1],X[1:])),axis=1)
    T = np.hstack([0, T[0:]]).astype('float32')
    T = T.reshape([n,1])

    return TupleDataset(X, T)

In [68]:
class RNN(Chain):
    def __init__(self):
        super(RNN, self).__init__()
        with self.init_scope():
            self.l1 = L.LSTM(None, HIDDEN_UNITS)
            self.out = L.Linear(HIDDEN_UNITS, 1)
            
    def __call__(self, x):
        h1 = self.l1(x)
        y = self.out(h1)
        return y

In [69]:
def main():
    # Create data
    train = create_data()
    test = create_data()
    
    # Initialize iterators
    train_iter = iterators.SerialIterator(train, batch_size=16, shuffle=False, repeat=False)
    val_iter = iterators.SerialIterator(test, batch_size=16, repeat=False, shuffle=False)

    # Define model
    model = RNN()
    optimizer = optimizers.SGD()
    optimizer.setup(model)
    
#     Define regressor and use absulute mean square error as loss function

    training_losses, validation_losses = run(train_iter, val_iter, test, model, optimizer, 20)
    
    plt.plot(training_losses, label='Training loss')
    plt.plot(validation_losses, label='Validation loss')
    plt.legend()

if __name__ == "__main__":
    main()

[[ 0.87997663]
 [ 0.11360273]
 [ 0.98734009]
 ..., 
 [ 0.63440347]
 [ 0.04962728]
 [ 0.33222422]]
[[ 0.        ]
 [ 0.99357939]
 [ 1.10094285]
 ..., 
 [ 0.94505721]
 [ 0.68403077]
 [ 0.38185149]]
3000
3000
[[ 0.65239382]
 [ 0.54007316]
 [ 0.25452092]
 ..., 
 [ 0.37505156]
 [ 0.70770723]
 [ 0.76447833]]
[[ 0.        ]
 [ 1.19246697]
 [ 0.79459405]
 ..., 
 [ 0.77999294]
 [ 1.08275878]
 [ 1.47218561]]
3000
3000
