In [9]:
%matplotlib inline

import numpy as np
import chainer
from chainer import cuda, Function, gradient_check, report, training, utils, Variable
from chainer import datasets, iterators, optimizers, serializers
from chainer import Link, Chain, ChainList
import chainer.functions as F
import chainer.links as L
from chainer.training import extensions
from chainer.dataset import concat_examples
from chainer.datasets import TupleDataset
from chainer.functions.evaluation import accuracy
from chainer.functions.loss import softmax_cross_entropy
from chainer import link
from chainer import reporter

import matplotlib.pyplot as plt

global HIDDEN_UNITS
HIDDEN_UNITS = 3

In [10]:
# create toy data - compute sum of the previous and current input
def create_data(n=3000):

    X = np.random.rand(n,1).astype('float32')
    T = np.sum(np.hstack((X[0:-1],X[1:])),axis=1)
    T = np.hstack([0, T[0:]]).astype('float32')
    T = T.reshape([n,1])

    return TupleDataset(X, T)

In [11]:
class RNN(Chain):
    def __init__(self):
        super(RNN, self).__init__()
        with self.init_scope():
            self.l1 = L.LSTM(None, HIDDEN_UNITS)
            self.out = L.Linear(HIDDEN_UNITS, 1)
            
    def __call__(self, x):
        h1 = self.l1(x)
        y = self.out(h1)
        return y

In [57]:
class Regressor(link.Chain):

    compute_accuracy = False

    def __init__(self, predictor,
                 lossfun=softmax_cross_entropy.softmax_cross_entropy,
                 accfun=accuracy.accuracy,
                 label_key=-1):
        if not (isinstance(label_key, (int, str))):
            raise TypeError('label_key must be int or str, but is %s' %
                            type(label_key))

        super(Regressor, self).__init__()
        self.lossfun = lossfun
        self.accfun = accfun
        self.y = None
        self.loss = None
        self.accuracy = None
        self.label_key = label_key

        with self.init_scope():
            self.predictor = predictor

    def __call__(self, *args, **kwargs):

        if isinstance(self.label_key, int):
            if not (-len(args) <= self.label_key < len(args)):
                msg = 'Label key %d is out of bounds' % self.label_key
                raise ValueError(msg)
            t = args[self.label_key]
            if self.label_key == -1:
                args = args[:-1]
            else:
                args = args[:self.label_key] + args[self.label_key + 1:]
        elif isinstance(self.label_key, str):
            if self.label_key not in kwargs:
                msg = 'Label key "%s" is not found' % self.label_key
                raise ValueError(msg)
            t = kwargs[self.label_key]
            del kwargs[self.label_key]

        self.y = None
        self.loss = None
        self.accuracy = None
        self.y = self.predictor(*args, **kwargs)
        self.loss = self.lossfun(self.y, t)
        reporter.report({'loss': self.loss}, self)
        return self.loss

In [None]:
class Iterator(data):

    def __init__(self):
        self.data = data
        self.done = False
        
    def next(self):
        # go on

In [62]:
def run(train_iter, val_iter, test_data, network, regressor, optimizer, max_epoch):
    training_losses = []
    validation_losses = []
    
    mini_batch_losses = []

    while train_iter.epoch < max_epoch:
        # Get next mini-batch
        batch = train_iter.next()
        image_train, target_train = concat_examples(batch)

        # Prediction
        prediction_train = network(image_train)

        # Compute loss
        loss = regressor(prediction_train, target_train)
        mini_batch_losses.append(loss.data)

        # Compute gradients
        network.cleargrads()
        loss.backward()

        # Update variables
        optimizer.update()

        # Check the validation accuracy of prediction after every epoch
        if train_iter.is_new_epoch:  # If this iteration is the final iteration of the current epoch

            # Save the training loss
            training_losses.append(np.mean(mini_batch_losses))
            mini_batch_losses = []

            val_losses = []
            val_accuracies = []
            while True:
                val_batch = val_iter.next()
                image_val, target_val = concat_examples(val_batch)

                # Forward the validation data
                prediction_val = network(image_val)

                # Calculate the loss
                loss_val = regressor(prediction_val, target_val)
                val_losses.append(loss_val.data)

                if val_iter.is_new_epoch:
                    val_iter.epoch = 0
                    val_iter.current_position = 0
                    val_iter.is_new_epoch = False
                    val_iter._pushed_position = None

                    validation_losses.append(np.mean(val_losses))
                    break

    # Predict full test set
    image_test, target_test = concat_examples(test_data)
    # Forward test data
    prediction_test = network(image_test)
    # Calculate loss and accuracy
    loss_test = regressor(prediction_test, target_test)

    print('test_loss: ' + str(loss_test.data))
    return training_losses, validation_losses

In [66]:
def main():
    # Create data
    train = create_data()
    test = create_data()
    
    # TODO define an own iterator which loops over the whole training set during the epoch
    
    # Initialize iterators
    train_iter = iterators.SerialIterator(train, batch_size=1, shuffle=False, repeat=False)
    val_iter = iterators.SerialIterator(test, batch_size=1, repeat=False, shuffle=False)

    # Define model
    network = RNN()
    regressor = Regressor(network, F.mean_squared_error)
    optimizer = optimizers.SGD()
    optimizer.setup(network)
    
    training_losses, validation_losses = run(train_iter, val_iter, test, network, regressor, optimizer, 1)
    
    plt.plot(training_losses, label='Training loss')
    plt.plot(validation_losses, label='Validation loss')
    plt.legend()

if __name__ == "__main__":
    main()

TypeError: The batch size of x must be equal to or less thanthe size of the previous state h.