## Char-RNN

In this tutorial, we will build a char-rnn model for natural language generation. The training text is tokenized as a sequence of characters. After training, the model is able to output the probability distribution over the alphabet, therefore "predicting" the next character. By iterating this process, one can generate text snippets.

Char-RNN processes text sequences of arbitrary length, and the loss function makes use of ordinary Scala control-flow features during the training phase. Therefore it is an instance of a dynamic neural network.

This implementation of Char-RNN is inspired by Andrej Karpathy's excellent blog post [The Unreasonable Effectiveness of Recurrent Neural Networks](https://karpathy.github.io/2015/05/21/rnn-effectiveness/) and [Python/numpy implementation](https://gist.github.com/karpathy/d4dee566867f8291f086).

## Importing dependencies

In [1]:
import $ivy.`org.nd4j:nd4j-native-platform:0.8.0`
import $ivy.`com.thoughtworks.deeplearning::plugins-cumulativedoublelayers:2.0.0-RC5`
import $ivy.`com.thoughtworks.deeplearning::plugins-doubletraining:2.0.0-RC5`
import $ivy.`com.thoughtworks.deeplearning::plugins-cumulativeindarraylayers:2.0.0-RC5`
import $ivy.`com.thoughtworks.deeplearning::plugins-indarrayweights:2.0.0-RC5`
import $ivy.`com.thoughtworks.deeplearning::plugins-indarrayliterals:2.0.0-RC5`
import $ivy.`com.thoughtworks.deeplearning::plugins-logging:2.0.0-RC5`

import java.io.PrintWriter
import scala.math
import collection.immutable.IndexedSeq
import scala.io.Source
import scala.concurrent.ExecutionContext.Implicits.global
import scalaz.concurrent.Task
import scalaz.std.iterable._
import scalaz.syntax.all._
import org.nd4j.linalg.factory.Nd4j
import org.nd4j.linalg.api.ndarray.INDArray
import org.nd4j.linalg.ops.transforms.Transforms
import org.nd4j.linalg.api.ops.impl.indexaccum.IMax
import com.thoughtworks.deeplearning.plugins.DoubleLiterals
import com.thoughtworks.deeplearning.plugins.INDArrayLiterals
import com.thoughtworks.deeplearning.plugins.CumulativeDoubleLayers
import com.thoughtworks.deeplearning.plugins.DoubleTraining
import com.thoughtworks.deeplearning.plugins.CumulativeINDArrayLayers
import com.thoughtworks.deeplearning.plugins.INDArrayWeights
import com.thoughtworks.deeplearning.plugins.Operators
import com.thoughtworks.deeplearning.plugins.Logging
import com.thoughtworks.feature.Factory

[32mimport [39m[36m$ivy.$                                    
[39m
[32mimport [39m[36m$ivy.$                                                                        
[39m
[32mimport [39m[36m$ivy.$                                                                
[39m
[32mimport [39m[36m$ivy.$                                                                          
[39m
[32mimport [39m[36m$ivy.$                                                                 
[39m
[32mimport [39m[36m$ivy.$                                                                  
[39m
[32mimport [39m[36m$ivy.$                                                         

[39m
[32mimport [39m[36mjava.io.PrintWriter
[39m
[32mimport [39m[36mscala.math
[39m
[32mimport [39m[36mcollection.immutable.IndexedSeq
[39m
[32mimport [39m[36mscala.io.Source
[39m
[32mimport [39m[36mscala.concurrent.ExecutionContext.Implicits.global
[39m
[32mimport [39m[36mscalaz.concurrent.Task
[39m

## Preparing the corpus, setting up plugins & parameters

In [2]:
// Training corpus: a single short string, tokenized character by character.
val data = "DeepLearning.scala"
val dataSize = data.length

// Vocabulary: the distinct characters of the corpus. The index of a character in
// `ixToChar` is its class id; `charToIx` is the inverse lookup.
val ixToChar = data.toSet.toArray
val charToIx = ixToChar.zipWithIndex.toMap
val vocabSize = ixToChar.length

/** One-hot encodes `c`: a `vocabSize` x 1 array of zeros with a 1 at the index `charToIx(c)`.
  *
  * `charToIx(c)` throws if `c` is not part of the vocabulary.
  */
def oneOfK(c: Char) = {
  val column = Nd4j.zeros(vocabSize, 1)
  column.putScalar(charToIx(c), 1)
}

[36mdata[39m: [32mString[39m = [32m"DeepLearning.scala"[39m
[36mdataSize[39m: [32mInt[39m = [32m18[39m
[36mixToChar[39m: [32mArray[39m[[32mChar[39m] = [33mArray[39m([32m'e'[39m, [32m's'[39m, [32m'n'[39m, [32m'.'[39m, [32m'a'[39m, [32m'i'[39m, [32m'L'[39m, [32m'g'[39m, [32m'l'[39m, [32m'p'[39m, [32m'c'[39m, [32m'r'[39m, [32m'D'[39m)
[36mcharToIx[39m: [32mMap[39m[[32mChar[39m, [32mInt[39m] = [33mMap[39m(
  [32m'e'[39m -> [32m0[39m,
  [32m's'[39m -> [32m1[39m,
  [32m'n'[39m -> [32m2[39m,
  [32m'.'[39m -> [32m3[39m,
  [32m'a'[39m -> [32m4[39m,
  [32m'i'[39m -> [32m5[39m,
  [32m'L'[39m -> [32m6[39m,
  [32m'g'[39m -> [32m7[39m,
  [32m'l'[39m -> [32m8[39m,
  [32m'p'[39m -> [32m9[39m,
  [32m'c'[39m -> [32m10[39m,
[33m...[39m
[36mvocabSize[39m: [32mInt[39m = [32m13[39m
defined [32mfunction[39m [36moneOfK[39m

In [3]:
/** Plugin that multiplies the delta of every INDArray weight by a constant `learningRate`. */
trait LearningRate extends INDArrayWeights {

  /** Factor applied to the delta produced by the parent optimizer. */
  val learningRate: Double

  trait INDArrayOptimizerApi extends super.INDArrayOptimizerApi { this: INDArrayOptimizer =>

    /** The parent's delta, scaled by `learningRate`. */
    override def delta: INDArray = {
      val unscaled = super.delta
      unscaled.mul(learningRate)
    }
  }

  override type INDArrayOptimizer <: INDArrayOptimizerApi with Optimizer
}

/** Plugin implementing the Adagrad update rule for INDArray weights.
  *
  * Each weight keeps a running sum of its squared deltas in `cache`; the delta handed
  * on is the parent's delta divided element-wise by `sqrt(cache) + eps`.
  */
trait Adagrad extends INDArrayWeights {

  /** Small constant added to the denominator to avoid division by zero. */
  val eps: Double

  trait INDArrayWeightApi extends super.INDArrayWeightApi { this: INDArrayWeight =>

    /** Accumulated squared deltas; `None` until the first update of this weight. */
    var cache: Option[INDArray] = None
  }

  override type INDArrayWeight <: INDArrayWeightApi with Weight

  trait INDArrayOptimizerApi extends super.INDArrayOptimizerApi { this: INDArrayOptimizer =>

    // Evaluated at most once per optimizer instance — presumably so that reading `delta`
    // several times does not add the same squared delta to the cache more than once.
    private lazy val adagradDelta: INDArray = {
      import org.nd4s.Implicits._
      val delta0 = super.delta
      // Start from zeros of the delta's shape on the first update.
      val previousSum = weight.cache.getOrElse(Nd4j.zeros(delta0.shape: _*))
      val squaredSum = previousSum + delta0 * delta0
      weight.cache = Some(squaredSum)
      delta0 / (Transforms.sqrt(squaredSum) + eps)
    }

    override def delta: INDArray = adagradDelta
  }

  override type INDArrayOptimizer <: INDArrayOptimizerApi with Optimizer
}

defined [32mtrait[39m [36mLearningRate[39m
defined [32mtrait[39m [36mAdagrad[39m

In [5]:
// Defines `hyperparameters`: an instance combining the Adagrad and LearningRate traits
// with the DeepLearning.scala plugins listed below, created with learningRate = 0.1 and
// eps = 1e-8.
// NOTE(review): the definition is handed to `interp.load` as source text instead of being
// written directly in the cell; the reason is not visible here — confirm before inlining.
interp.load("""
  val hyperparameters = Factory[Adagrad with LearningRate with DoubleTraining with CumulativeDoubleLayers with CumulativeINDArrayLayers with Operators with INDArrayLiterals with DoubleLiterals with Logging].newInstance(learningRate = 0.1, eps=1e-8)
""")

In [6]:
import hyperparameters.INDArrayWeight
import hyperparameters.DoubleLayer
import hyperparameters.INDArrayLayer
import hyperparameters.implicits._

[32mimport [39m[36mhyperparameters.INDArrayWeight
[39m
[32mimport [39m[36mhyperparameters.DoubleLayer
[39m
[32mimport [39m[36mhyperparameters.INDArrayLayer
[39m
[32mimport [39m[36mhyperparameters.implicits._[39m

In [7]:
// Model dimensions.
val hiddenSize = 100 // number of rows of the hidden-state weights and bias below
val seqLength = 25 // chunk length used when grouping the corpus into batches

// Input-to-hidden weights (hiddenSize x vocabSize), drawn from `Nd4j.randn` and scaled by 0.01.
// `org.nd4s.Implicits._` supplies `*` on INDArray and is imported only inside the block,
// presumably to keep nd4s implicits out of the rest of the session — TODO confirm.
val wxh = {
    import org.nd4s.Implicits._
    INDArrayWeight(Nd4j.randn(hiddenSize, vocabSize) * 0.01)
}

// Hidden-to-hidden (recurrent) weights (hiddenSize x hiddenSize), same initialisation.
val whh = {
    import org.nd4s.Implicits._
    INDArrayWeight(Nd4j.randn(hiddenSize, hiddenSize) * 0.01)
}

// Hidden-to-output weights (vocabSize x hiddenSize), same initialisation.
val why = {
    import org.nd4s.Implicits._
    INDArrayWeight(Nd4j.randn(vocabSize, hiddenSize) * 0.01)
}

// Biases for the hidden state and the output, initialised to zero column vectors.
val bh = INDArrayWeight(Nd4j.zeros(hiddenSize, 1))
val by = INDArrayWeight(Nd4j.zeros(vocabSize, 1))

SLF4J: Failed to load class "org.slf4j.impl.StaticLoggerBinder".
SLF4J: Defaulting to no-operation (NOP) logger implementation
SLF4J: See http://www.slf4j.org/codes.html#StaticLoggerBinder for further details.


[36mhiddenSize[39m: [32mInt[39m = [32m100[39m
[36mseqLength[39m: [32mInt[39m = [32m25[39m
[36mwxh[39m: [32mObject[39m with [32mhyperparameters[39m.[32mINDArrayWeightApi[39m with [32mhyperparameters[39m.[32mWeightApi[39m with [32mhyperparameters[39m.[32mWeightApi[39m with [32mhyperparameters[39m.[32mINDArrayWeightApi[39m = $sess.cmd4Wrapper$Helper$Anonymous$macro$1$1$Anonymous$macro$34$1@45e200c1
[36mwhh[39m: [32mObject[39m with [32mhyperparameters[39m.[32mINDArrayWeightApi[39m with [32mhyperparameters[39m.[32mWeightApi[39m with [32mhyperparameters[39m.[32mWeightApi[39m with [32mhyperparameters[39m.[32mINDArrayWeightApi[39m = $sess.cmd4Wrapper$Helper$Anonymous$macro$1$1$Anonymous$macro$34$1@79848e87
[36mwhy[39m: [32mObject[39m with [32mhyperparameters[39m.[32mINDArrayWeightApi[39m with [32mhyperparameters[39m.[32mWeightApi[39m with [32mhyperparameters[39m.[32mWeightApi[39m with [32mhyperparameters[39m.[32mINDArrayWei

## Implementing the neural network

In [8]:
/** Hyperbolic tangent of a layer, built from the plugin's differentiable `exp`:
  * `(e^x - e^-x) / (e^x + e^-x)`.
  *
  * NOTE(review): this form evaluates `exp` on both `x` and `-x`, so very large magnitudes
  * may overflow to infinity and yield NaN — confirm inputs stay in a safe range.
  *
  * @param x the layer to apply tanh to
  * @return a layer computing tanh(x)
  */
def tanh(x: INDArrayLayer): INDArrayLayer = {
  val exp_x = hyperparameters.exp(x)
  val exp_nx = hyperparameters.exp(-x)
  (exp_x - exp_nx) / (exp_x + exp_nx)
}

defined [32mfunction[39m [36mtanh[39m

In [9]:
/** One step of the recurrent network.
  *
  * @param x     input vector; callers in this notebook pass `oneOfK` of the current character
  * @param y     target vector; callers in this notebook pass `oneOfK` of the expected next character
  * @param hprev hidden state produced by the previous step
  * @return `(loss, prob, hnext)`: the loss of this step, the normalised output distribution,
  *         and the new hidden state
  */
def charRNN(x: INDArray, y: INDArray, hprev: INDArrayLayer): (DoubleLayer, INDArrayLayer, INDArrayLayer) = {
    // New hidden state from the current input and the previous hidden state.
    val hnext = tanh(wxh.dot(x) + whh.dot(hprev) + bh)
    // Unnormalised output scores.
    val yraw = why.dot(hnext) + by
    // Softmax: exponentiate, then divide by the sum of the exponentials.
    val yraw_exp = hyperparameters.exp(yraw)
    val prob = yraw_exp / yraw_exp.sum
    // Negative log of the probability mass selected by `y` (cross-entropy when `y` is one-hot).
    val loss = -hyperparameters.log((prob * y).sum)
    (loss, prob, hnext)
}

defined [32mfunction[39m [36mcharRNN[39m

In [10]:
// Training pairs (current character, next character), split into chunks of at most
// `seqLength` pairs.
val batches = data.zip(data.tail).grouped(seqLength).toVector

// A value paired with the hidden-state layer that accompanies it.
type WithHiddenLayer[A] = (A, INDArrayLayer)
// One chunk of (input character, target character) pairs.
type Batch = IndexedSeq[(Char, Char)]
// History of smoothed loss values, oldest first.
type Losses = Vector[Double]

// Log file receiving one smoothed loss value per trained batch (written by `singleRound`).
// NOTE(review): this writer is never closed in the visible code; `singleRound` flushes
// after every line instead.
val h = new PrintWriter("/tmp/log.txt")

/** Builds the total loss of one batch by running `charRNN` over its character pairs in order.
  *
  * @param batch the (input, target) character pairs together with the hidden state to start from
  * @return the sum of the per-character losses and the hidden state after the last character
  */
def singleBatch(batch: WithHiddenLayer[Batch]): WithHiddenLayer[DoubleLayer] = {
  val (pairs, hprev) = batch
  // The accumulator starts as a zero loss paired with the incoming hidden state.
  pairs.foldLeft((DoubleLayer(0.0.forward), hprev)) {
    (acc: WithHiddenLayer[DoubleLayer], pair: (Char, Char)) =>
      val (lossSoFar, hidden) = acc
      val (input, target) = pair
      // The per-step probability output is not needed for training.
      val (stepLoss, _, nextHidden) = charRNN(oneOfK(input), oneOfK(target), hidden)
      (lossSoFar + stepLoss, nextHidden)
  }
}

/** A fresh initial hidden state: a layer wrapping a `hiddenSize` x 1 array of zeros. */
def initH = {
  val zeros = Nd4j.zeros(hiddenSize, 1)
  INDArrayLayer(zeros.forward)
}

/** Trains once on every batch, in order, starting from a fresh hidden state.
  *
  * After each batch the smoothed loss `0.999 * previous + 0.001 * batchLoss` is written to
  * the log writer `h` (and flushed) and appended to the history. The hidden state returned
  * for one batch is passed on as the starting state of the next.
  *
  * @param initprevloss loss history so far; must be non-empty because its last element
  *                     seeds the smoothing
  * @return a task producing the history extended by one entry per batch
  */
def singleRound(initprevloss: Losses): Task[Losses] = {
  val finalState = batches.foldLeftM((initprevloss, initH)) {
    (state: WithHiddenLayer[Losses], batch: Batch) =>
      val (history, hprev) = state
      val (batchLoss, hnext) = singleBatch((batch, hprev))
      batchLoss.train.map { (batchLossValue: Double) =>
        val smoothed = history.last * 0.999 + batchLossValue * 0.001
        h.println(smoothed)
        h.flush()
        (history :+ smoothed, hnext)
      }
  }
  // Only the loss history is of interest to callers; drop the final hidden state.
  finalState.map(_._1)
}

/** Runs 1024 training rounds back to back, threading the loss history through them.
  *
  * The history starts with `-log(1 / vocabSize) * seqLength`, i.e. the loss of `seqLength`
  * uniform predictions over the vocabulary.
  */
def allRounds: Task[Losses] = {
  val initialLoss = -math.log(1.0 / vocabSize) * seqLength
  (0 until 1024).foldLeftM(Vector(initialLoss)) { (history: Losses, _: Int) =>
    singleRound(history)
  }
}

[36mbatches[39m: [32mVector[39m[[32mIndexedSeq[39m[([32mChar[39m, [32mChar[39m)]] = [33mVector[39m(
  [33mVector[39m(
    ([32m'D'[39m, [32m'e'[39m),
    ([32m'e'[39m, [32m'e'[39m),
    ([32m'e'[39m, [32m'p'[39m),
    ([32m'p'[39m, [32m'L'[39m),
    ([32m'L'[39m, [32m'e'[39m),
    ([32m'e'[39m, [32m'a'[39m),
    ([32m'a'[39m, [32m'r'[39m),
    ([32m'r'[39m, [32m'n'[39m),
    ([32m'n'[39m, [32m'i'[39m),
    ([32m'i'[39m, [32m'n'[39m),
[33m...[39m
defined [32mtype[39m [36mWithHiddenLayer[39m
defined [32mtype[39m [36mBatch[39m
defined [32mtype[39m [36mLosses[39m
[36mh[39m: [32mPrintWriter[39m = java.io.PrintWriter@72408227
defined [32mfunction[39m [36msingleBatch[39m
defined [32mfunction[39m [36minitH[39m
defined [32mfunction[39m [36msingleRound[39m
defined [32mfunction[39m [36mallRounds[39m

## Training the model and using it to generate text

In [11]:
// Run the whole training task synchronously and print the resulting smoothed loss history.
println(allRounds.unsafePerformSync)

Jul 20, 2017 10:26:31 AM $sess.cmd8Wrapper$Helper hnext
SEVERE: An exception is thrown in layer $sess.cmd4Wrapper$Helper$Anonymous$macro$1$1$Anonymous$macro$20$1@9a7486
java.lang.ClassCastException: org.bytedeco.javacpp.indexer.IntRawIndexer cannot be cast to org.bytedeco.javacpp.indexer.UByteRawIndexer
	at org.nd4j.linalg.api.buffer.BaseDataBuffer.getInt(BaseDataBuffer.java:892)
	at org.nd4j.linalg.api.ndarray.BaseNDArray.isScalar(BaseNDArray.java:1720)
	at org.nd4j.linalg.api.ndarray.BaseNDArray.size(BaseNDArray.java:4215)
	at org.nd4j.linalg.api.shape.Shape.newShapeNoCopy(Shape.java:956)
	at org.nd4j.linalg.api.ndarray.BaseNDArray.reshape(BaseNDArray.java:3617)
	at org.nd4j.linalg.api.ndarray.BaseNDArray.reshape(BaseNDArray.java:3672)
	at com.thoughtworks.deeplearning.plugins.INDArrayLayers$Nd4jIssues1869Workaround.broadcastFix(INDArrayLayers.scala:71)
	at com.thoughtworks.deeplearning.plugins.INDArrayLayers$INDArrayLayer$$anonfun$binary$1$$anonfun$apply$23$$anonfun$com$thoughtworks

Vector(64.12373393653841, 64.10321271971497, 64.15078270616212, 64.16563069302275, 64.15465896366872, 64.1187209161055, 64.07121542779883, 64.0169910551045, 63.958908847815316, 63.89806231032316, 63.83592879173743, 63.773466675775204, 63.71086203976928, 63.64818259602306, 63.585463626825266, 63.522726950443506, 63.45998742176682, 63.397255790203076, 63.33454016678661, 63.271846862216336, 63.20918090523481, 63.14654637635879, 63.08394663942101, 63.02138450588179, 62.958862346800174, 62.896382184793026, 62.83394575774844, 62.771554568774, 62.70920992891483, 62.64691298697919, 62.584664754102036, 62.52246612551141, 62.46031789538583, 62.39822077297658, 62.336175392570574, 62.27418232447379, 62.212242079973166, 62.150355122048964, 62.08852186812499, 62.02674269689071, 61.96501795250977, 61.903347946765514, 61.841732964351884, 61.780173263680254, 61.718669081155795, 61.65722063295011, 61.59582811595261, 61.53449170938422, 61.473211577935146, 61.411987872801255, 61.35082073084857, 61.2897102

9524, 39.96331483370778, 39.92338049873641, 39.88348603822809, 39.8436314124871, 39.80381658269409, 39.76404150833733, 39.72430615055883, 39.68461046999937, 39.64495442716073, 39.60533798228743, 39.56576109661742, 39.526223730174074, 39.486725843737204, 39.44726739860309, 39.4078483547942, 39.368468673804934, 39.329128316152946, 39.28982724239669, 39.25056541384889, 39.21134279192267, 39.17215933705309, 39.133015010431286, 39.09390977281007, 39.05484358581811, 39.015816410703245, 38.97682820851332, 38.937878940336255, 38.898968567595546, 38.860097051755595, 38.82126435366122, 38.78247043539027, 38.74371525750648, 38.70499878204641, 38.66632097018787, 38.62768178398482, 38.58908118439326, 38.550519133424785, 38.511995592948495, 38.47351052421644, 38.43506388875598, 38.39665564902977, 38.35828576670249, 38.319954203238304, 38.28166092031921, 38.24340588097745, 38.20518904595495, 38.167010377525685, 38.12886983764262, 38.09076738877457, 38.052702992711446, 38.014676611878116, 37.976688208

879027317548697, 24.854163156408593, 24.829323844688194, 24.804509357299377, 24.77971966977584, 24.75495475725826, 24.73021459515047, 24.70549915810568, 24.680808422054984, 24.656142362237727, 24.631500953917946, 24.606884172921824, 24.582291993607555, 24.55772439179137, 24.533181343730682, 24.508662824096437, 24.484168808240714, 24.4596992721371, 24.435254190709326, 24.410833540039526, 24.386437295876423, 24.36206543405181, 24.337717929706113, 24.313394758720587, 24.28909589676193, 24.264821319461316, 24.240571002713178, 24.216344922495566, 24.1921430541545, 24.16796537359756, 24.143811856994617, 24.119682479883718, 24.09557721794614, 24.071496047544166, 24.047438944765446, 24.023405885184033, 23.999396844578243, 23.97541179910773, 23.951450724956448, 23.927513598570854, 23.903600395048517, 23.879711091182738, 23.85584566319307, 23.83200408684559, 23.808186338109756, 23.78439239369477, 23.760622229677292, 23.73687582239617, 23.71315314755763, 23.689454182205104, 23.665778902629018, 23

In [12]:
/** Wraps `a` in a `Task` whose completion callback is invoked from `executionContext`.
  *
  * @param a                the value to deliver
  * @param executionContext where the callback that completes the task is executed
  * @return a task that succeeds with `a`
  */
def jump[A](a: A)(implicit executionContext: scala.concurrent.ExecutionContext): Task[A] = {
  import scalaz.\/-
  Task.async[A] { callback =>
    val deliver = new Runnable {
      override def run(): Unit = callback(\/-(a))
    }
    executionContext.execute(deliver)
  }
}

defined [32mfunction[39m [36mjump[39m

In [13]:
/** Runs ND4J's `IMax` index accumulation over `v` and returns its final result
  * (the index of the maximum element).
  */
def genIdx(v: INDArray): Int = {
  val indexOfMax = new IMax(v)
  Nd4j.getExecutioner().execAndReturn(indexOfMax).getFinalResult()
}

/** Generates text by repeatedly feeding the last character back into the network.
  *
  * At each step the network's output distribution is computed and the character at the
  * index chosen by `genIdx` is appended to the text.
  *
  * @param seed the first character of the generated text
  * @param n    the number of characters to append after the seed
  * @return a task producing the seed followed by the `n` generated characters
  */
def generate(seed: Char, n: Int): Task[String] = {
  val finalState = (0 until n).foldLeftM((seed.toString, initH)) {
    (state: (String, INDArrayLayer), _: Int) =>
      val (text, hprev) = state
      val x = oneOfK(text.last)
      // Only `prob` and `hnext` are used; `x` doubles as a placeholder target.
      val (_, prob, hnext) = charRNN(x, x, hprev)
      prob.predict.flatMap { (probv: INDArray) =>
        jump {
          val nextChar = ixToChar(genIdx(probv))
          (text + nextChar.toString, hnext)
        }
      }
  }
  // Keep the generated text and drop the final hidden state.
  finalState.map(_._1)
}

defined [32mfunction[39m [36mgenIdx[39m
defined [32mfunction[39m [36mgenerate[39m

In [15]:
// Generate 32 characters after the seed 'D' with the trained model, blocking until done.
generate('D', 32).unsafePerformSync

[36mres14[39m: [32mString[39m = [32m"DeepLearning.scalaarning.rniee.rn"[39m