# Character-level text generation with Flux

In this tutorial we'll use Flux, a new machine learning library in Julia, to generate nonsense text in the style of a given source.

The model works at the character level, meaning that we'll hand it a sequence of characters like "The qui" and it will predict `c`, followed by `k`, etc, figuring out that the most likely next word is "quick" – then it can go on to predict "brown", "fox" and so on, letter by letter.

Initially, the text will come from a training set, like the works of Shakespeare. But we can also feed the model's own predictions back into itself, allowing it to "dream" new data.

## Using Flux

In [1]:
using Flux

Flux works with simple functions.

In [2]:
@net f(x) = x .* x

We can wrap those functions so that they will run on a backend, like MXNet.

In [3]:
fmx = mxnet(f)

MX.Model(Capacitor(...), CPU0)

Notice that inputs get converted to Float32s, MXNet's native format.

In [4]:
fmx([1,2,3])

3-element Array{Float32,1}:
 1.0
 4.0
 9.0

We can also use MXNet to take gradients.

In [5]:
Flux.back!(fmx, [1,1,1],[1,2,3])

(Float32[2.0, 4.0, 6.0],)
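One way to check a gradient like this by hand is to compare it against finite differences. The sketch below is plain Julia (written for Julia 1.x) and doesn't touch Flux or MXNet at all; `numgrad` is a hypothetical helper name. It assumes, as in the call above, that the upstream gradient is all ones, which is the same as differentiating `sum(f(x))`.

```julia
# Plain-Julia sketch, independent of Flux; `numgrad` is a hypothetical helper.
f(x) = x .* x

# Central finite differences of sum(f(x)) with respect to each element of x.
function numgrad(f, x; h = 1e-6)
    g = similar(x)
    for i in eachindex(x)
        e = zeros(length(x))
        e[i] = h
        g[i] = (sum(f(x .+ e)) - sum(f(x .- e))) / (2h)
    end
    return g
end

g = numgrad(f, [1.0, 2.0, 3.0])  # approximately [2.0, 4.0, 6.0]
```

The result should agree with the analytic derivative `2x`, which is a handy sanity check when you start modifying `f`.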

Try modifying `f` to take, or return, multiple arguments. Are the derivatives correct?

## Basic Data Handling

ML models are good at handling exactly one kind of data – fixed-size lists of numbers. If we want to work with characters we have to turn them into that format.

We can do this by defining an "alphabet" and using a one-hot encoding – a boolean for each character in the alphabet.

In [6]:
using Flux: onehot, onecold
alphabet = 'a':'z'
onehot('c', alphabet)'

1×26 RowVector{Int64,Array{Int64,1}}:
 0  0  1  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0

This allows us to feed characters into a numerical model. The output of the model will be similarly encoded as a vector, so we can decode it.

In [7]:
onecold(rand(26), alphabet)

'l': ASCII/Unicode U+006c (category Ll: Letter, lowercase)
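To make the encoding less magical, here's roughly what the two operations do, sketched in plain Julia (1.x). `my_onehot` and `my_onecold` are hypothetical names for illustration; Flux's own `onehot` and `onecold` may be implemented differently.

```julia
# Plain-Julia sketch; not Flux's implementation.
# One boolean per alphabet entry, true only where the character matches.
my_onehot(ch, alphabet) = [c == ch for c in alphabet]

# Decode by picking the alphabet entry at the position of the largest value.
my_onecold(v, alphabet) = alphabet[argmax(v)]

alphabet = 'a':'z'
v = my_onehot('c', alphabet)
my_onecold(v, alphabet)              # 'c'
my_onecold([0.1, 0.7, 0.2], "xyz")   # 'y'
```

Decoding takes the position of the largest entry, so it works on any score vector, not just exact one-hot vectors.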

Here's a simple affine transformation followed by a `tanh`, of the kind you'll have seen earlier in the workshop. This gives us a way to transform one letter into another.

In [8]:
W = randn(26,26)
b = randn(26)
@net f(x) = tanh.(W*x + b)

In [9]:
onecold(f(onehot('F', alphabet)), alphabet)

'l': ASCII/Unicode U+006c (category Ll: Letter, lowercase)

We will use the notation `f('a') -> 'b'` with the encoding implicit.

## Recurrent Models

A model like `f(x)` above – even a much more complex version – will always return the same output for the same input.

Sometimes we'd like this *not* to be true. For example, in our character-level model above, `f('t')` predicts the character that comes after `t`. This clearly varies depending on what's *before* `'t'` – "lat" and "bit" are probably followed by different letters, like "later" and "bitten".

So we want `f` to access some state from previous times it was called. We can do that with a neat syntax for indexing in time.

In [10]:
count = 0
@net function f(x)
    count = count{-1} + 1
    return x + count
end
fu = unroll1(f)

Stateful(Capacitor(...))

In [11]:
fu(0)

1

In [12]:
fu(0)

2

`count` behaves essentially the same as a global variable in this case. `f` is now able to store some aggregate information about all of the inputs it has seen (in this case just a count).

Unlike a global variable, Flux knows about `count` and can do some interesting things with it. For example, we can statically "unroll" `f` to take a *sequence* of inputs and return a *sequence* of outputs.

In [13]:
fu = unroll(f, 5)
fu((0,1,0,-1,0))

(1, 3, 3, 3, 5)
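If the `{-1}` syntax feels unfamiliar, it may help to see the same behaviour written with an ordinary closure. This is only an analogy in plain Julia (1.x), not how Flux implements it, and `make_counter` is a hypothetical name.

```julia
# Plain-Julia analogy for the stateful function above; not Flux's mechanism.
function make_counter()
    count = 0                 # state carried between calls
    function step(x)
        count += 1            # plays the role of `count{-1} + 1`
        return x + count
    end
    return step
end

fu = make_counter()
fu(0)  # 1
fu(0)  # 2

# "Unrolling" over a sequence is then just applying the step to each input in order.
unrolled = make_counter()
outputs = [unrolled(x) for x in (0, 1, 0, -1, 0)]  # [1, 3, 3, 3, 5]
```

The difference is that the closure's state is hidden from the outside, whereas Flux can see the state and turn the loop into a static graph.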

With this in mind, here's a basic RNN. It's essentially the same as the affine transform above but includes both the input `x` and the previous prediction `y`. The previous prediction therefore influences the next one, which is what we wanted.

In [14]:
alphabet = ['a':'z'..., ' ']
N = length(alphabet)

Wxy = randn(N,N)
Wyy = randn(N,N)
b = randn(N)
y = randn(N)

@net function f(x)
    y = tanh.(Wxy*x + Wyy*y{-1} + b)
end

fu = unroll1(f)

Stateful(Capacitor(...))

In [15]:
onecold(fu(onehot('F', alphabet)), alphabet)

'd': ASCII/Unicode U+0064 (category Ll: Letter, lowercase)

If you try this a few times you should notice that it doesn't always give the same output for the same input. We can use this to generate text already, just by repeatedly predicting the next output from the last.

In [16]:
s = ['a']
for i = 1:50
    push!(s, onecold(fu(onehot(s[end], alphabet)), alphabet))
end
join(s)

"anavadnhvaxipfwhwednsxafiefykx hvejosrafwefnrypuxqk"
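For comparison, here's the same untrained RNN and generation loop sketched without `@net`, keeping the previous output in a mutable struct. It assumes Julia 1.x, and `SimpleRNN`, `encode`, and `decode` are hypothetical names for this sketch only.

```julia
# Plain-Julia sketch of the RNN above; not Flux's implementation.
alphabet = ['a':'z'..., ' ']
N = length(alphabet)

encode(ch) = Float64[c == ch for c in alphabet]   # one-hot as floats
decode(v) = alphabet[argmax(v)]                   # most likely character

mutable struct SimpleRNN
    Wxy::Matrix{Float64}
    Wyy::Matrix{Float64}
    b::Vector{Float64}
    y::Vector{Float64}    # previous output, i.e. the state
end

SimpleRNN(n::Integer) = SimpleRNN(randn(n, n), randn(n, n), randn(n), randn(n))

# Calling the model updates the stored state and returns the new output.
function (m::SimpleRNN)(x)
    m.y = tanh.(m.Wxy * x + m.Wyy * m.y + m.b)
    return m.y
end

rnn = SimpleRNN(N)
s = ['a']
for i in 1:50
    push!(s, decode(rnn(encode(s[end]))))
end
text = join(s)
```

Because the weights are random, the text will differ on every run, but the loop has the same shape as the one above.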

This is clearly a long way from Shakespeare, but even with an untrained network one can see that there's some structure; the output is not truly random but biased towards certain patterns. It's this structure that will be exploited when we train on real data.

In [17]:
join(rand(alphabet, 50))

"rppumqbwhtjgdkunrebliymayolfmpjxacoixujfadlnbjeojx"

## Getting Data

In [18]:
using Flux.Batches: Batch, seqs, chunk

We can load any text file as input.

In [19]:
input = readstring("res/shakespeare_input.txt")
alphabet = unique(input)
N = length(alphabet)
first(input)

'F': ASCII/Unicode U+0046 (category Lu: Letter, uppercase)

We don't actually want to work with characters directly, but with sequences of encodings. So `encode` encodes each character and then groups them all together in `Seq`s of length 50.

In [20]:
encode(input) = seqs((onehot(ch, alphabet) for ch in input), 50)
first(encode(input))

50-element Flux.Batches.Seq{Array{Int64,1},Array{Int64,2}}:
 [1, 0, 0, 0, 0, 0, 0, 0, 0, 0  …  0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
 [0, 1, 0, 0, 0, 0, 0, 0, 0, 0  …  0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
 [0, 0, 1, 0, 0, 0, 0, 0, 0, 0  …  0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
 [0, 0, 0, 1, 0, 0, 0, 0, 0, 0  …  0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
 [0, 0, 0, 0, 1, 0, 0, 0, 0, 0  …  0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
 [0, 0, 0, 0, 0, 1, 0, 0, 0, 0  …  0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
 [0, 0, 0, 0, 0, 0, 1, 0, 0, 0  …  0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
 [0, 1, 0, 0, 0, 0, 0, 0, 0, 0  …  0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
 [0, 0, 0, 0, 1, 0, 0, 0, 0, 0  …  0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
 [0, 1, 0, 0, 0, 0, 0, 0, 0, 0  …  0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
 [0, 0, 0, 0, 0, 0, 0, 1, 0, 0  …  0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
 [0, 0, 0, 0, 0, 0, 0, 0, 1, 0  …  0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
 [0, 0, 0, 0, 0, 0, 0, 0, 0, 1  …  0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
 ⋮                                                              
 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0
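As a rough picture of what encoding and grouping does, here's a plain-Julia (1.x) sketch on a tiny input. `onehot_vec` and `to_seqs` are hypothetical helpers rather than Flux's `onehot` and `seqs`, and dropping an incomplete final group is an assumption of this sketch.

```julia
# Plain-Julia sketch; not Flux's `onehot` or `seqs`.
onehot_vec(ch, alphabet) = Int[c == ch for c in alphabet]

# Group encodings into fixed-length sequences, dropping any incomplete tail
# (an assumption made here for simplicity).
function to_seqs(encoded, len)
    parts = Iterators.partition(encoded, len)
    return [collect(p) for p in parts if length(p) == len]
end

text = "abcabcab"
alphabet = unique(text)            # ['a', 'b', 'c']
encoded = [onehot_vec(ch, alphabet) for ch in text]
sequences = to_seqs(encoded, 3)    # two sequences of three one-hot vectors each
```

With the real data the idea is the same, just with a longer alphabet and sequences of length 50.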

We also need to batch a set of sequences together so that we can work on them all at once.

`Xs` are our inputs and `Ys` are our outputs. Notice that we construct them in exactly the same way; `Ys` is just `Xs` but offset by 1, since the model should predict the next character in the sequence.

In [21]:
Xs = (Batch(ss) for ss in zip(encode.(chunk(input, 50))...))
Ys = (Batch(ss) for ss in zip(encode.(chunk(input[2:end], 50))...))
Flux.rawbatch(first(Xs))

50×50×67 Array{Int64,3}:
[:, :, 1] =
 1  0  0  0  0  0  0  0  0  0  0  0  0  …  0  0  0  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  0  0  0  0  0  0     1  0  0  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  0  0  0  0  0  0     0  0  0  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  0  0  0  0  0  0     0  0  0  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  0  0  0  0  0  0     0  0  0  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  0  0  0  0  0  0  …  0  0  0  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  0  0  0  0  0  0     0  0  0  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  0  0  0  0  0  0     0  0  0  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  0  0  0  0  0  0     0  0  0  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  0  0  0  0  0  0     0  0  0  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  0  0  0  0  0  0  …  0  0  0  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  0  0  0  0  0  0     0  0  0  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  0  0 
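The offset between inputs and targets is easier to see on raw characters than on one-hot tensors. Here's a small plain-Julia (1.x) sketch with a made-up string and a sequence length of 5; it leaves out the encoding and batching steps entirely.

```julia
# Plain-Julia sketch of the input/target offset, without encoding or batching.
chars = collect("hello world")
seqlen = 5
starts = 1:seqlen:(length(chars) - seqlen)

# Each target sequence is the input sequence shifted one character to the right.
inputs  = [join(chars[i:i+seqlen-1]) for i in starts]   # ["hello", " worl"]
targets = [join(chars[i+1:i+seqlen]) for i in starts]   # ["ello ", "world"]
```

At every position the target is simply the character that follows the input character, which is exactly what we want the model to predict.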

## Training a Model

In [26]:
using Flux: unsqueeze

In order to make our RNN more reusable, let's use a [model template](http://mikeinnes.github.io/Flux.jl/stable/models/templates.html) to define it. This is essentially just a Julia type containing some parameters.

In [None]:
init(dims) = randn(dims)/100

@net type Recurrent
  Wxy; Wyy; by
  y
  function (x)
    y = tanh( x * Wxy .+ y{-1} * Wyy .+ by )
  end
end

Recurrent(in, out) =
  Recurrent(init((in, out)), init((out, out)), init((1, out)), init((1, out)))

We've made a couple of tweaks compared to earlier; we'll use the first dimension of the input data as a batch dimension, which means reordering the matmuls. Otherwise, this behaves the same as before.

In [None]:
f = unroll1(Recurrent(N,N))
onecold(f(rand(5,N)), alphabet)
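The "reordering the matmuls" remark is worth a quick check. With column vectors we compute `W*x`; with one example per row we compute `X*W'` instead and get the same numbers, laid out as rows. Here's a plain-Julia (1.x) sketch with small made-up sizes.

```julia
# Plain-Julia sketch: column-vector vs. batch-first (row) conventions.
W = randn(3, 4)
x = randn(4)

col = W * x                 # column convention: 3-element vector

X = reshape(x, 1, :)        # the same input as a 1×4 "batch" of one row
row = X * W'                # batch-first convention: 1×3 matrix

# Stacking more rows in X processes several inputs at once.
batch = vcat(X, X)          # 2×4
out = batch * W'            # 2×3, one output row per input row
```

In the `Recurrent` type above the weights are stored already transposed, as `(in, out)`, so the code can write `x * Wxy` directly.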

Here's the model; we'll chain together a few recurrent layers at once.

In [None]:
model = Chain(
  Recurrent(N, 256),
  Recurrent(256, 256),
  Affine(256, N),
  softmax)

m = mxnet(unroll(model, 50))

m(first(Xs))

To help us understand what the model is doing, we'll print out the current loss on a single batch of data (the sixth, since `drop` skips the first five). `evalcb` is just a function we can call from inside the training process.

In [None]:
using Flux: logloss, tobatch
eval = tobatch.(first.(drop.((Xs, Ys), 5)))
evalcb = () -> @show logloss(m(eval[1]), eval[2])
evalcb()
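For intuition about the number being printed, here's the textbook cross-entropy between a predicted distribution and a one-hot target, in plain Julia (1.x). We haven't checked how Flux's `logloss` sums or averages over batches and sequences, so treat `crossentropy` here as a hypothetical stand-in rather than the same function.

```julia
# Plain-Julia sketch of cross-entropy; Flux's `logloss` may normalise differently.
crossentropy(yhat, y) = -sum(y .* log.(yhat))

target  = [1, 0, 0]                 # one-hot: the true next character is the first
good    = [0.7, 0.2, 0.1]           # most probability on the right character
bad     = [0.1, 0.2, 0.7]           # most probability on a wrong character

loss_good = crossentropy(good, target)   # -log(0.7), about 0.357
loss_bad  = crossentropy(bad, target)    # -log(0.1), about 2.303
```

The loss only depends on the probability assigned to the correct character, so it falls as the model becomes more confident in the right answer.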

Finally, the training!

In [None]:
@time Flux.train!(m, zip(Xs, Ys), η = 0.001, loss = logloss, cb = [evalcb])

Here's a sampling function. It's essentially the same loop we wrote above, but takes account of things like batching. Also, we use a weighted sample rather than `onecold` to add some randomness to the output, and make things a bit more interesting.

In [24]:
using StatsBase: wsample
function sample(model, n, temp = 1)
  s = [rand(alphabet)]
  m = unroll1(model)
  for i = 1:n-1
    push!(s, wsample(alphabet, softmax(m(unsqueeze(onehot(s[end], alphabet)))./temp)[1,:]))
  end
  return string(s...)
end

sample (generic function with 2 methods)
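To see what the temperature and the weighted sample are doing, here's a plain-Julia (1.x) sketch of both pieces. `softmax_vec` and `weighted_sample` are hypothetical stand-ins written for this example; they are not Flux's `softmax` or StatsBase's `wsample`.

```julia
# Plain-Julia sketch; stand-ins for `softmax` and `wsample`, not the real ones.
function softmax_vec(logits)
    e = exp.(logits .- maximum(logits))   # subtract the max for numerical stability
    return e ./ sum(e)
end

# Draw one item with probability proportional to its weight
# (assumes the weights sum to 1).
function weighted_sample(items, probs)
    r = rand()
    acc = 0.0
    for (item, p) in zip(items, probs)
        acc += p
        r < acc && return item
    end
    return last(items)   # guard against rounding error
end

logits = [1.0, 2.0, 3.0]
p_normal = softmax_vec(logits ./ 1.0)   # temp = 1
p_sharp  = softmax_vec(logits ./ 0.5)   # lower temp: more weight on the top choice
p_flat   = softmax_vec(logits ./ 2.0)   # higher temp: closer to uniform

ch = weighted_sample(['a', 'b', 'c'], p_normal)
```

Lower temperatures make the sampled text more conservative and repetitive; higher ones make it more varied but more error-prone.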

In [None]:
sample(model[1:end-1], 100)

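The `Recurrent` layer defined above is about the most naive possible, and for various reasons can struggle to store information. An LSTM does better. Below, we first sample from a previously saved model loaded from disk; the definition of an LSTM, which you can play around with in your own model, follows further down.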

In [28]:
model = open(deserialize, "res/shakes.jls")
sample(model[1:end-1], 1000) |> println

riving!
This is Wits bestowed upon his love:
Now save the Doath, would I have either,
The gentle state of joy out of deband will.
First I think in a threw rap'd his doing.

PROTERESTERETRA:
We'll be tresp'd! Come you.
That thee the worget to wably, to our commands,
To broach-moon seas, he under forthound here.
S and worthy armour to my blood,
And he wark'd unto their state shall be not pleases than
you; and they were a Rome unfolliel,
And to the oblece, rits challed with us;
The side at Chrisless ladious will I lin,
The trice of his humane, thus I give them unspucianter:
It was this bolded discretress of thy courtress
Do you double.

SISS
Song; fie out it! lest I live to helple-mine, to lault,
Night with a build hence continue my refeats.

FLIACATRO:
No sorn that you! but I?

WALBOT:
Most death; thoughess dear death; and thrite by the hand,
Newest-uphortunl'd me from the time and knees,
and as 'twixt my sister Titus avoid
Anon, within his dreading quick and love
With such a gift on my 

In [None]:
@net type MyLSTM
  Wxf; Wyf; bf
  Wxi; Wyi; bi
  Wxo; Wyo; bo
  Wxc; Wyc; bc
  y; state
  function (x)
    # Gates
    forget = σ( x * Wxf .+ y{-1} * Wyf .+ bf )
    input  = σ( x * Wxi .+ y{-1} * Wyi .+ bi )
    output = σ( x * Wxo .+ y{-1} * Wyo .+ bo )
    # State update and output
    state′ = tanh( x * Wxc .+ y{-1} * Wyc .+ bc )
    state  = forget .* state{-1} .+ input .* state′
    y = output .* tanh(state)
  end
end

MyLSTM(in, out) =
  MyLSTM(vcat([[init((in, out)), init((out, out)), init((1, out))] for _ = 1:4]...)...,
         zeros(Float32, (1, out)), zeros(Float32, (1, out)))
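If you'd like to poke at the gate equations outside of `@net`, here's a single LSTM step as an ordinary function in plain Julia (1.x), with the previous output and state passed in explicitly instead of via `{-1}`. `lstm_step`, `sigmoid`, and the named tuple `p` are hypothetical names for this sketch; the sizes are made up.

```julia
# Plain-Julia sketch of one LSTM step; mirrors the equations above, not Flux's code.
sigmoid(x) = 1 / (1 + exp(-x))

function lstm_step(p, x, y_prev, state_prev)
    # Gates
    forget = sigmoid.(x * p.Wxf .+ y_prev * p.Wyf .+ p.bf)
    input  = sigmoid.(x * p.Wxi .+ y_prev * p.Wyi .+ p.bi)
    output = sigmoid.(x * p.Wxo .+ y_prev * p.Wyo .+ p.bo)
    # State update and output
    candidate = tanh.(x * p.Wxc .+ y_prev * p.Wyc .+ p.bc)
    state = forget .* state_prev .+ input .* candidate
    y = output .* tanh.(state)
    return y, state
end

init(dims) = randn(dims) / 100
nin, nout = 4, 3
p = (Wxf = init((nin, nout)), Wyf = init((nout, nout)), bf = init((1, nout)),
     Wxi = init((nin, nout)), Wyi = init((nout, nout)), bi = init((1, nout)),
     Wxo = init((nin, nout)), Wyo = init((nout, nout)), bo = init((1, nout)),
     Wxc = init((nin, nout)), Wyc = init((nout, nout)), bc = init((1, nout)))

x = randn(1, nin)                          # a batch containing one input row
y0, s0 = zeros(1, nout), zeros(1, nout)    # initial output and state
y1, s1 = lstm_step(p, x, y0, s0)
y2, s2 = lstm_step(p, x, y1, s1)           # feed the results back in for the next step
```

The forget gate scales the old state and the input gate scales the new candidate, which is what lets the state hold on to information over many steps.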