# Deep Learning

This notebook serves as the supporting material for the chapter **Deep Learning**. In this notebook, we'll look at different activation functions. Then we'll create a deep neural network using Deeplearning4j and train a model capable of classifying handwritten digits.

>_"While handwriting recognition has been attempted by different machine learning algorithms over the years, deep learning performs remarkably well and achieves an accuracy of over 99.7% on the MNIST dataset."_ 

So, let's begin...

In [1]:
%%classpath add mvn
org.nd4j nd4j-native-platform 0.9.1
org.deeplearning4j deeplearning4j-core 0.9.1
org.datavec datavec-api 0.9.1
org.datavec datavec-local 0.9.1
org.datavec datavec-dataframe 0.9.1
org.bytedeco javacpp 1.5
org.apache.httpcomponents httpclient 4.3.5
org.deeplearning4j deeplearning4j-ui_2.11 0.9.1

## Activation Functions

### 1.) Saturating Activation Function

In [3]:
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.api.iter.NdIndexIterator;

import java.util.ArrayList;
import java.util.List;

INDArray array = Nd4j.linspace(-5,5,200);
INDArray sigmoid = Transforms.sigmoid(array);

def ch = new Crosshair(color: Color.gray, width: 2, style: StrokeType.DOT);
p1 = new Plot(title: "Sigmoid activation function", crosshair: ch);
p1 << new ConstantLine(x: 0, y: 0, color: Color.black);
p1 << new ConstantLine(y: 1, color: Color.black, style: StrokeType.DOT);
p1 << new Line(x: [-5, 5], y: [-3/4, 7/4], style: StrokeType.DASH, color: Color.green);
p1 << new Line(x: toDoubleArrayList(array), y: toDoubleArrayList(sigmoid), displayName: "Sigmoid", color: Color.blue, width: 3);
p1 << new Text(x: 0, y: 0.5, text: "Linear", pointerAngle: 3.505);
p1 << new Text(x: -5, y: 0, text: "Saturating", pointerAngle: 1.57);
p1 << new Text(x: 5, y: 1, text: "Saturating", pointerAngle: 4.71);

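// Copies the values of an INDArray into a List<Double> that the plot can consume.
public List<Double> toDoubleArrayList(INDArray array){
    // Size the iterator from the array itself instead of hard-coding 200.
    NdIndexIterator iter = new NdIndexIterator(array.length());
    List<Double> list = new ArrayList<Double>();
    while (iter.hasNext()) {
        int[] nextIndex = iter.next();
        double nextVal = array.getDouble(nextIndex);
        list.add(nextVal);
    }
    return list;
}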
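
For reference, the shape of the curve above follows from the usual definition of the logistic sigmoid and its derivative (stated here as standard background, not taken from the notebook's code):

$$\sigma(x) = \frac{1}{1 + e^{-x}}, \qquad \sigma'(x) = \sigma(x)\,\bigl(1 - \sigma(x)\bigr)$$

At $x = 0$ this gives $\sigma(0) = \tfrac{1}{2}$ and $\sigma'(0) = \tfrac{1}{4}$, so the tangent there is $y = \tfrac{x}{4} + \tfrac{1}{2}$, which passes through $(-5, -\tfrac{3}{4})$ and $(5, \tfrac{7}{4})$, the endpoints of the dashed green line. As $|x|$ grows, $\sigma'(x) \to 0$: the function saturates, and gradients flowing through it shrink accordingly.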

### 2.) Nonsaturating Activation Functions

In [4]:
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.api.iter.NdIndexIterator;

import java.util.ArrayList;
import java.util.List;

INDArray array = Nd4j.linspace(-5,5,200);
INDArray relu = Transforms.relu(array);
INDArray leakyRelu = Transforms.leakyRelu(array);
INDArray elu = Transforms.elu(array);

def ch = new Crosshair(color: Color.gray, width: 2, style: StrokeType.DOT);
p1 = new Plot(title: "Nonsaturating activation functions", crosshair: ch);
p1 << new ConstantLine(x: 0, y: 0, color: Color.black);
p1 << new ConstantLine(y: -1, color: Color.black, style: StrokeType.DOT);
p1.getYAxes()[0].setBound(-1.5,5);
p1 << new Line(x: toDoubleArrayList(array), y: toDoubleArrayList(elu), displayName: "ELU (α=1)", color: Color.red)
p1 << new Line(x: toDoubleArrayList(array), y: toDoubleArrayList(relu), displayName: "ReLU", color: Color.orange)
p1 << new Line(x: toDoubleArrayList(array), y: toDoubleArrayList(leakyRelu), displayName: "Leaky ReLU", color: Color.blue);
p1 << new Text(x: -5, y: 0, text: "Leak", pointerAngle: 1.57);


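// Copies the values of an INDArray into a List<Double> that the plot can consume.
public List<Double> toDoubleArrayList(INDArray array){
    // Size the iterator from the array itself instead of hard-coding 200.
    NdIndexIterator iter = new NdIndexIterator(array.length());
    List<Double> list = new ArrayList<Double>();
    while (iter.hasNext()) {
        int[] nextIndex = iter.next();
        double nextVal = array.getDouble(nextIndex);
        list.add(nextVal);
    }
    return list;
}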
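
As a plain-Java cross-check of the three curves plotted above, the sketch below implements ReLU, Leaky ReLU, and ELU directly from their usual definitions, without ND4J. The class name `Activations` and the leak slope `0.01` are assumptions made for illustration; the slope that `Transforms.leakyRelu` applies by default may differ.

```java
// Plain-Java sketch of the nonsaturating activation functions (no ND4J dependency).
final class Activations {

    private Activations() {}

    // ReLU: zero for negative inputs, identity for non-negative inputs.
    static double relu(double x) {
        return Math.max(0.0, x);
    }

    // Leaky ReLU: like ReLU, but negative inputs keep a small slope alpha.
    static double leakyRelu(double x, double alpha) {
        return x >= 0 ? x : alpha * x;
    }

    // ELU: identity for non-negative inputs, alpha * (e^x - 1) for negative inputs,
    // so the output approaches -alpha as x becomes very negative.
    static double elu(double x, double alpha) {
        return x >= 0 ? x : alpha * (Math.exp(x) - 1.0);
    }

    public static void main(String[] args) {
        double leak = 0.01; // assumed leak slope, for illustration only
        for (double x : new double[] {-5.0, -1.0, 0.0, 1.0, 5.0}) {
            System.out.printf("x=%5.1f  relu=%.3f  leaky=%.3f  elu=%.3f%n",
                    x, relu(x), leakyRelu(x, leak), elu(x, 1.0));
        }
    }
}
```

For negative inputs, ELU with $\alpha = 1$ levels off towards $-1$ (the dotted line in the plot), whereas Leaky ReLU keeps a small constant slope and plain ReLU outputs exactly zero.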

Let's train a neural network on MNIST using the ReLU activation in the hidden layer.

We first need to create a `DataUtils` class containing the methods required to download, extract, and delete the dataset files.

In [6]:
package aima.notebooks.deeplearning;

import org.apache.commons.compress.archivers.tar.TarArchiveEntry;
import org.apache.commons.compress.archivers.tar.TarArchiveInputStream;
import org.apache.commons.compress.compressors.gzip.GzipCompressorInputStream;
import org.apache.http.HttpEntity;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.client.methods.HttpGet;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClientBuilder;

import java.io.*;
import java.nio.file.*;
import java.nio.file.attribute.BasicFileAttributes;

public class DataUtils{
    
    public DataUtils(){}
    
    // Downloads remoteUrl to localPath unless the file already exists.
    // Returns true only if a new file was actually written.
    public boolean downloadFile(String remoteUrl, String localPath) throws IOException {
        boolean downloaded = false;
        if (remoteUrl == null || localPath == null)
            return downloaded;
        File file = new File(localPath);
        if (!file.exists()) {
            file.getParentFile().mkdirs();
            CloseableHttpClient client = HttpClientBuilder.create().build();
            try {
                CloseableHttpResponse response = client.execute(new HttpGet(remoteUrl));
                try {
                    HttpEntity entity = response.getEntity();
                    if (entity != null) {
                        FileOutputStream outstream = new FileOutputStream(file);
                        try {
                            entity.writeTo(outstream);
                            outstream.flush();
                            downloaded = true; // set only after the body was written
                        } finally {
                            outstream.close();
                        }
                    }
                } finally {
                    response.close();
                }
            } catch(IOException e){
                System.out.println(e);
            } finally {
                client.close();
            }
        }
        if (!file.exists())
            throw new IOException("File doesn't exist: " + localPath);
        return downloaded;
    }
    public void extractTarGz(String inputPath, String outputPath) throws IOException {
        if (inputPath == null || outputPath == null)
            return;
        final int bufferSize = 4096;
        if (!outputPath.endsWith("" + File.separatorChar))
            outputPath = outputPath + File.separatorChar;
        TarArchiveInputStream tais = null;
        try {
            tais = new TarArchiveInputStream(new GzipCompressorInputStream(new BufferedInputStream(new FileInputStream(inputPath))));
            TarArchiveEntry entry;
            while ((entry = (TarArchiveEntry) tais.getNextEntry()) != null) {
                File target = new File(outputPath + entry.getName());
                if (entry.isDirectory()) {
                    target.mkdirs();
                } else {
                    // The archive may list a file before (or without) its directory entry.
                    target.getParentFile().mkdirs();
                    int count;
                    byte[] data = new byte[bufferSize];
                    BufferedOutputStream dest = new BufferedOutputStream(new FileOutputStream(target), bufferSize);
                    try {
                        while ((count = tais.read(data, 0, bufferSize)) != -1) {
                            dest.write(data, 0, count);
                        }
                    } finally {
                        dest.close();
                    }
                }
            }
        } catch(IOException e){
            System.out.println(e);
        } finally {
            if (tais != null) tais.close(); // release the archive file handle
        }
    }
    public void deleteDir(String path) throws IOException{
        Path directory = Paths.get(path);
        Files.walkFileTree(directory, new SimpleFileVisitor<Path>() {
            @Override
            public FileVisitResult visitFile(Path file, BasicFileAttributes attributes) throws IOException {
                Files.delete(file); // this will work because it's always a File
                return FileVisitResult.CONTINUE;
            }

            @Override
            public FileVisitResult postVisitDirectory(Path dir, IOException exc) throws IOException {
                Files.delete(dir); //this will work because Files in the directory are already deleted
                return FileVisitResult.CONTINUE;
            }
        });
    }
}

null

Now let's download the MNIST dataset.

In [11]:
import aima.notebooks.deeplearning.DataUtils;
import java.io.File;

String DATA_URL = "http://github.com/myleott/mnist_png/raw/master/mnist_png.tar.gz";
String BASE_PATH = "./assets";
String localFilePath = BASE_PATH + "/mnist_png.tar.gz";
DataUtils dataUtils = new DataUtils();
if (!new File(localFilePath).exists()) {
    if (dataUtils.downloadFile(DATA_URL, localFilePath)) {
        dataUtils.extractTarGz(localFilePath, BASE_PATH);
    }
}

null

In [6]:
import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.split.FileSplit;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.optimize.listeners.EvaluativeListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.optimize.api.InvocationType;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.api.storage.StatsStorage;
import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.stats.StatsListener;
import org.deeplearning4j.ui.storage.FileStatsStorage;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.lossfunctions.LossFunctions;

import java.io.File;
import java.util.Random;

int seed = 123;
double learningRate = 0.01;
int batchSize = 100;
int numEpochs = 1;

int height = 28;
int width = 28;
int channels = 1;
int numInput = height * width;
int numHidden = 1000;
int numOutput = 10;

//Prepare data for loading
File trainData = new File("./assets/mnist_png/training");
FileSplit trainSplit = new FileSplit(trainData, NativeImageLoader.ALLOWED_FORMATS, new Random(seed));
ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator(); // use parent directory name as the image label
ImageRecordReader trainRR = new ImageRecordReader(height, width, 1, labelMaker);
trainRR.initialize(trainSplit);
DataSetIterator trainIter = new RecordReaderDataSetIterator(trainRR, batchSize, 1, numOutput);
DataNormalization imageScaler = new ImagePreProcessingScaler();
imageScaler.fit(trainIter);
trainIter.setPreProcessor(imageScaler);


File testData = new File("./assets/mnist_png/testing");
FileSplit testSplit = new FileSplit(testData, NativeImageLoader.ALLOWED_FORMATS, new Random(seed));
ImageRecordReader testRR = new ImageRecordReader(height, width, 1, labelMaker);
testRR.initialize(testSplit);
DataSetIterator testIter = new RecordReaderDataSetIterator(testRR, batchSize, 1, numOutput);
testIter.setPreProcessor(imageScaler);

//Build the neural network
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
        .seed(seed)
        .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
        .updater(Updater.ADAM)
        .list()
        .layer(0, new DenseLayer.Builder()
                .nIn(numInput)
                .nOut(numHidden)
                .activation(Activation.RELU)
                .weightInit(WeightInit.XAVIER)
                .build())
        .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                .nIn(numHidden)
                .nOut(numOutput)
                .activation(Activation.SOFTMAX)
                .weightInit(WeightInit.XAVIER)
                .build())
        .setInputType(InputType.convolutional(height, width, channels))
        .build();

UIServer uiServer = UIServer.getInstance();
StatsStorage statsStorage = new FileStatsStorage(new File(System.getProperty("java.io.tmpdir"), "ui-dnn-stats.dl4j"));
uiServer.attach(statsStorage);

MultiLayerNetwork model = new MultiLayerNetwork(conf);
model.init();
model.setListeners(new ScoreIterationListener(10), new StatsListener(statsStorage), new EvaluativeListener(testIter, 1, InvocationType.EPOCH_END));


//Train the model and evaluate
for (int i = 0; i < numEpochs; i++) {
    model.fit(trainIter);
    System.out.println("********Evaluation Stats*********");
    Evaluation eval = model.evaluate(testIter);
    System.out.println(eval.stats());

    trainIter.reset();
    testIter.reset();
}

System.out.println("********Example finished*********");

********Evaluation Stats*********

Examples labeled as 0 classified by model as 0: 960 times
Examples labeled as 0 classified by model as 2: 3 times
Examples labeled as 0 classified by model as 3: 2 times
Examples labeled as 0 classified by model as 4: 2 times
Examples labeled as 0 classified by model as 5: 1 times
Examples labeled as 0 classified by model as 6: 2 times
Examples labeled as 0 classified by model as 7: 5 times
Examples labeled as 0 classified by model as 8: 2 times
Examples labeled as 0 classified by model as 9: 3 times
Examples labeled as 1 classified by model as 1: 1124 times
Examples labeled as 1 classified by model as 2: 2 times
Examples labeled as 1 classified by model as 3: 1 times
Examples labeled as 1 classified by model as 4: 1 times
Examples labeled as 1 classified by model as 5: 1 times
Examples labeled as 1 classified by model as 6: 3 times
Examples labeled as 1 classified by model as 8: 3 times
Examples labeled as 2 classified by model as 0: 3 times
Examples

null

## Convolutional networks

Convolutional neural networks are specialized models that are highly efficient at processing data that can be represented as measurements on a grid. This includes images, which are measurements of brightness on a two-dimensional grid; audio waveforms, which can be regarded as a one-dimensional grid across time; and three-dimensional grid data such as the 3-D scans used in medical imaging. For a convolutional network, we use 4-dimensional arrays (known as **feature maps**) to keep track of the shape of the image. A feature map is split into several **channels**, each describing how a single type of feature appears across the entire image. A feature map has shape $m \times h \times w \times c$, where:
* $m$ is the number of examples to process together in the same batch,
* $h$ is the height of the image,
* $w$ is the width of the image, and
* $c$ is the number of channels.
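
To make the shape bookkeeping concrete, here is a minimal plain-Java sketch that tracks $h$, $w$, and $c$ for one example through a small stack of convolution and pooling layers. The layer sizes (a 28×28 single-channel input, 5×5 kernels with stride 1, 2×2 pooling with stride 2, no padding) and the class name `FeatureMapShape` are assumptions chosen for illustration.

```java
// Sketch: spatial-size bookkeeping for convolution and pooling layers (no padding assumed).
final class FeatureMapShape {

    private FeatureMapShape() {}

    // Output length along one spatial dimension when a window of size `kernel`
    // slides over `in` positions in steps of `stride`, without padding.
    static int outputSize(int in, int kernel, int stride) {
        if (kernel > in) {
            throw new IllegalArgumentException("kernel is larger than the input");
        }
        return (in - kernel) / stride + 1;
    }

    public static void main(String[] args) {
        int size = 28;     // height = width of a square input image (assumed)
        int channels = 1;  // grayscale input

        size = outputSize(size, 5, 1);  // 5x5 convolution, stride 1: 28 -> 24
        channels = 20;                  // one output channel per filter
        size = outputSize(size, 2, 2);  // 2x2 max pooling, stride 2: 24 -> 12
        size = outputSize(size, 5, 1);  // 5x5 convolution, stride 1: 12 -> 8
        channels = 50;
        size = outputSize(size, 2, 2);  // 2x2 max pooling, stride 2: 8 -> 4

        System.out.println(size + " x " + size + " x " + channels
                + " = " + (size * size * channels) + " values per example");
    }
}
```

Pooling leaves the channel count unchanged and only shrinks $h$ and $w$; a convolution sets $c$ to its number of filters.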

Now let's create a convolutional neural network and train it on the MNIST dataset.

In [2]:
import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.split.FileSplit;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.api.InvocationType;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.lossfunctions.LossFunctions;

import java.io.File;
import java.util.Random;

int seed = 123;
double learningRate = 0.01;
int batchSize = 100;
int numEpochs = 1;

int height = 28;
int width = 28;
int channels = 1;
int numInput = height * width;
int numHidden = 1000;
int numOutput = 10;

//Prepare data for loading
File trainData = new File("./assets/mnist_png/training");
FileSplit trainSplit = new FileSplit(trainData, NativeImageLoader.ALLOWED_FORMATS, new Random(seed));
ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator(); // use parent directory name as the image label
ImageRecordReader trainRR = new ImageRecordReader(height, width, 1, labelMaker);
trainRR.initialize(trainSplit);
DataSetIterator trainIter = new RecordReaderDataSetIterator(trainRR, batchSize, 1, numOutput);
DataNormalization imageScaler = new ImagePreProcessingScaler();
imageScaler.fit(trainIter);
trainIter.setPreProcessor(imageScaler);


File testData = new File("./assets/mnist_png/testing");
FileSplit testSplit = new FileSplit(testData, NativeImageLoader.ALLOWED_FORMATS, new Random(seed));
ImageRecordReader testRR = new ImageRecordReader(height, width, 1, labelMaker);
testRR.initialize(testSplit);
DataSetIterator testIter = new RecordReaderDataSetIterator(testRR, batchSize, 1, numOutput);
testIter.setPreProcessor(imageScaler);

//Build the neural network
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(seed)
            .updater(Updater.ADAM)
            .weightInit(WeightInit.XAVIER)
            .list()
            .layer(0, new ConvolutionLayer.Builder(5, 5)
                .nIn(channels)
                .stride(1, 1)
                .nOut(20)
                .activation(Activation.IDENTITY)
                .build())
            .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                .kernelSize(2, 2)
                .stride(2, 2)
                .build())
            .layer(2, new ConvolutionLayer.Builder(5, 5)
                .stride(1, 1) // nIn need not be specified in later layers
                .nOut(50)
                .activation(Activation.IDENTITY)
                .build())
            .layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                .kernelSize(2, 2)
                .stride(2, 2)
                .build())
            .layer(4, new DenseLayer.Builder().activation(Activation.RELU)
                .nOut(500)
                .build())
            .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                .nOut(numOutput)
                .activation(Activation.SOFTMAX)
                .build())
            .setInputType(InputType.convolutionalFlat(height, width, channels)) 
            .build();


MultiLayerNetwork model = new MultiLayerNetwork(conf);
model.init();

//Train the model and evaluate
for (int i = 0; i < numEpochs; i++) {
    model.fit(trainIter);
    System.out.println("********Evaluation Stats*********");
    Evaluation eval = model.evaluate(testIter);
    System.out.println(eval.stats());

    trainIter.reset();
    testIter.reset();
}

System.out.println("********Example finished*********");

********Evaluation Stats*********

Examples labeled as 0 classified by model as 0: 967 times
Examples labeled as 0 classified by model as 2: 1 times
Examples labeled as 0 classified by model as 5: 2 times
Examples labeled as 0 classified by model as 7: 2 times
Examples labeled as 0 classified by model as 8: 2 times
Examples labeled as 0 classified by model as 9: 6 times
Examples labeled as 1 classified by model as 1: 1132 times
Examples labeled as 1 classified by model as 3: 1 times
Examples labeled as 1 classified by model as 5: 1 times
Examples labeled as 1 classified by model as 6: 1 times
Examples labeled as 2 classified by model as 0: 1 times
Examples labeled as 2 classified by model as 1: 2 times
Examples labeled as 2 classified by model as 2: 1015 times
Examples labeled as 2 classified by model as 3: 3 times
Examples labeled as 2 classified by model as 4: 1 times
Examples labeled as 2 classified by model as 5: 1 times
Examples labeled as 2 classified by model as 7: 7 times
Examp

null

We can now delete the MNIST dataset files as they are no longer required.

In [7]:
import aima.notebooks.deeplearning.DataUtils;
import java.io.File;

String DATA_URL = "http://github.com/myleott/mnist_png/raw/master/mnist_png.tar.gz";
String BASE_PATH = "./assets";
String localFilePath = BASE_PATH + "/mnist_png.tar.gz";

File file = new File(localFilePath);
file.delete();
DataUtils dataUtils = new DataUtils();
dataUtils.deleteDir(BASE_PATH + "/mnist_png");

null

## Recurrent Neural Networks

Recurrent neural networks introduce the concept of time, i.e., they allow us to define the value of some variable $v$ at time step $t$ in terms of the values of this variable at previous time steps. For example, we can define an update rule $v_{(t)} = f(v_{(t-1)})$ using some function $f$ of our choice. These networks are particularly well suited to sequence-processing tasks because they operate over **sequences of vectors**: sequences in the input, in the output, or, in the most general case, both. In the last few years, RNNs have been applied with great success to problems such as speech recognition, language modeling, translation, and image captioning.
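
The update rule can be unrolled by hand in a few lines of plain Java. The sketch below is only an illustration under assumptions: it uses a single scalar state, adds an input $x_{(t)}$ at each step so that $v_{(t)} = \tanh(w_h \, v_{(t-1)} + w_x \, x_{(t)})$, and picks the weights arbitrarily; the class name `ScalarRecurrence` is hypothetical.

```java
// Sketch: unrolling a scalar recurrence v_t = tanh(wH * v_{t-1} + wX * x_t).
final class ScalarRecurrence {

    private ScalarRecurrence() {}

    // Returns the state after every time step, starting from an initial state of 0.
    static double[] unroll(double[] inputs, double wH, double wX) {
        double[] states = new double[inputs.length];
        double v = 0.0; // initial state
        for (int t = 0; t < inputs.length; t++) {
            v = Math.tanh(wH * v + wX * inputs[t]); // the same function f at every step
            states[t] = v;
        }
        return states;
    }

    public static void main(String[] args) {
        double[] xs = {1.0, 0.0, 0.0, 0.0};  // a single impulse followed by zeros
        double[] vs = unroll(xs, 0.5, 1.0);  // weights chosen arbitrarily
        for (int t = 0; t < vs.length; t++) {
            System.out.printf("t=%d  v=%.4f%n", t, vs[t]);
        }
    }
}
```

Even though the input is zero after the first step, the later states remain non-zero: the state carries information forward in time, which is what makes such networks suitable for sequences.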

Here, we'll apply an RNN to the simple problem of generating text character by character. So, let's start...

In [31]:
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.indexaccum.IMax;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.lossfunctions.LossFunctions;

import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;

int seed = 123;
int nHidden = 50;
int epochs = 15;

//Define a sentence to learn
//Add a dummy character at the beginning so that the RNN learns the complete sentence.
char[] LEARNSTRING = "*The quick brown fox jumps over a lazy dog.".toCharArray();

LinkedHashSet<Character> LEARNSTRING_CHARS = new LinkedHashSet<>();
for (char c : LEARNSTRING) LEARNSTRING_CHARS.add(c);
List<Character> LEARNSTRING_CHARS_LIST = new ArrayList<>();
LEARNSTRING_CHARS_LIST.addAll(LEARNSTRING_CHARS);


//Build the neural network
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
        .seed(seed)
        .updater(Updater.ADAM)
        .weightInit(WeightInit.XAVIER)
        .list()
        .layer(0, new LSTM.Builder()
                .nIn(LEARNSTRING_CHARS.size())
                .nOut(nHidden)
                .activation(Activation.TANH)
                .build())
        .layer(1, new RnnOutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                .activation(Activation.SOFTMAX)
                .nIn(nHidden)
                .nOut(LEARNSTRING_CHARS.size())
                .build())
        .build();

MultiLayerNetwork model = new MultiLayerNetwork(conf);
model.init();

//Create our training data
int[] shape = [1, LEARNSTRING_CHARS_LIST.size(), LEARNSTRING.length]
INDArray input = Nd4j.zeros(shape);
INDArray labels = Nd4j.zeros(shape);

int pos = 0;
for (char currChar : LEARNSTRING) {
    char nextChar = LEARNSTRING[(pos + 1) % (LEARNSTRING.length)]; //When currChar is the last, take the first character as nextChar.
    // Input neuron for current character is 1 at "pos"
    int[] inputArr = [0, LEARNSTRING_CHARS_LIST.indexOf(currChar), pos];
    input.putScalar(inputArr, 1);

    // Output neuron for next character is 1 at "pos"
    int[] labelArr = [0, LEARNSTRING_CHARS_LIST.indexOf(nextChar), pos];
    labels.putScalar(labelArr, 1);
    pos++;
}

DataSet trainingData = new DataSet(input, labels);

//Train the model and evaluate
for (int i = 0; i < epochs; i++) {
    model.fit(trainingData);
    model.rnnClearPreviousState();

    System.out.print("Epoch " + i + " completed. Sample:\t");
    //Evaluate
    //Put the first character into RNN as an initialisation
    int[] testShape = [1, LEARNSTRING_CHARS_LIST.size(), 1]
    INDArray testInit = Nd4j.zeros(testShape);
    testInit.putScalar(LEARNSTRING_CHARS_LIST.indexOf(LEARNSTRING[0]), 1);

    INDArray output = model.rnnTimeStep(testInit);
    //The highest-valued neuron in output sits at the index of the character that the model thinks should come next
    //Now the model should guess (LEARNSTRING.length - 1) more characters...

    for (int j = 0; j < LEARNSTRING.length - 1; j++) {

        //First let's process the last output of the model.
        int sampledCharacterIndex = Nd4j.getExecutioner().exec(new IMax(output, null, 1),1).getAt(0);
        System.out.print(LEARNSTRING_CHARS_LIST.get(sampledCharacterIndex));

        //Use the last output as next input
        int[] nextInputShape = [1, LEARNSTRING_CHARS_LIST.size(), 1];
        INDArray nextInput = Nd4j.zeros(nextInputShape);
        nextInput.putScalar(sampledCharacterIndex, 1);
        output = model.rnnTimeStep(nextInput);
    }
    System.out.println();
}

Epoch 0 completed. Sample:	                                          
Epoch 1 completed. Sample:	                                          
Epoch 2 completed. Sample:	e   aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
Epoch 3 completed. Sample:	eequuikk uukuukkuukuukuukuukkuukuukuukuukk
Epoch 4 completed. Sample:	The    orororororororororororororororororo
Epoch 5 completed. Sample:	The     o oooooooooooooooooooooooooooooooo
Epoch 6 completed. Sample:	Thhhq       o ogg.g.*jupssooo ogg.jmpssooo
Epoch 7 completed. Sample:	Theqqucccc    o o o o o o o o o o o o o o 
Epoch 8 completed. Sample:	Thequiccck  o o o o o o o o o o o o o o o 
Epoch 9 completed. Sample:	The quick brove o ox jumps ove o o ox jump
Epoch 10 completed. Sample:	The quick br o  ox jumps ove a ove a ove a
Epoch 11 completed. Sample:	The quick br a lazy dog.***Tee  uuick br a
Epoch 12 completed. Sample:	The quick brown fox jumps over a lazy dog.
Epoch 13 completed. Sample:	The quick brown fox jumps over a lazy dog.
Epoch 14 complet

null
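
The way the training cell above encodes the sentence (each character is paired with the character that follows it, wrapping around at the end, and both are one-hot encoded) can be sketched in plain Java without ND4J. The class name `NextCharPairs`, its method names, and the short stand-in sentence are assumptions made for illustration.

```java
import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;

// Sketch: turn a sentence into (current character -> next character) index pairs.
final class NextCharPairs {

    private NextCharPairs() {}

    // Distinct characters in order of first appearance.
    static List<Character> vocabulary(String text) {
        LinkedHashSet<Character> seen = new LinkedHashSet<>();
        for (char c : text.toCharArray()) {
            seen.add(c);
        }
        return new ArrayList<>(seen);
    }

    // result[t][0] is the vocabulary index of the character at position t;
    // result[t][1] is the index of the character after it (wrapping around at the end).
    static int[][] pairs(String text, List<Character> vocab) {
        int n = text.length();
        int[][] result = new int[n][2];
        for (int t = 0; t < n; t++) {
            result[t][0] = vocab.indexOf(text.charAt(t));
            result[t][1] = vocab.indexOf(text.charAt((t + 1) % n));
        }
        return result;
    }

    // One-hot vector of the given size with a single 1 at position index.
    static double[] oneHot(int index, int size) {
        double[] v = new double[size];
        v[index] = 1.0;
        return v;
    }

    public static void main(String[] args) {
        String text = "*abca"; // short stand-in sentence; '*' plays the dummy start character
        List<Character> vocab = vocabulary(text);
        for (int[] p : pairs(text, vocab)) {
            System.out.println(vocab.get(p[0]) + " -> " + vocab.get(p[1]));
        }
    }
}
```

Each input and label vector has one entry per distinct character, so the network's input and output sizes both equal the vocabulary size.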