```julia
using Flux, CuArrays
using Flux: onehot, argmax, chunk, batchseq, throttle, crossentropy
using StatsBase: wsample
using Base.Iterators: partition
```

We'll load the training text from `julia.jl`, split it into characters, and then turn each character into the one-hot numeric form the model needs.

The model will take a sequence of characters, like "the do", and try to predict the next character (e.g. 'g', completing "the dog"). The target output sequence $Y$ is therefore just the input sequence $X$ shifted by one position, with the stop character `_` filling the final slot, e.g.

* $X$: `the dog`
* $Y$: `he dog_`
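
As a concrete example of the shift, here is a plain-Julia sketch (Base only, Julia 1.x, no Flux). It builds $Y$ from $X$ by dropping the first character and appending the stop symbol. The names `X`, `Y` and `stopchar` are illustrative and simply mirror the notation above.

```julia
# Input sequence X as a vector of characters.
X = collect("the dog")
stopchar = '_'

# Target Y: X shifted left by one step, padded at the end with the stop symbol.
Y = [X[2:end]; stopchar]

println(String(X))  # the dog
println(String(Y))  # he dog_
```

At every position `i`, the target `Y[i]` is the character that follows `X[i]`.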

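```julia
text = collect(readstring("julia.jl"))   # file contents as a vector of Chars
alphabet = [unique(text)..., '_']         # every distinct character, plus '_' as the stop symbol
text = map(ch -> onehot(ch, alphabet), text)
stop = onehot('_', alphabet)

N = length(alphabet)   # length of each one-hot vector
seqlen = 50            # time steps per training window
nbatch = 50            # parallel streams per batch

# Inputs drop the last character and targets drop the first, so both have the
# same length and are chunked identically: each target is the next character.
Xs = collect(partition(batchseq(chunk(text[1:end-1], nbatch), stop), seqlen))
Ys = collect(partition(batchseq(chunk(text[2:end], nbatch), stop), seqlen));
```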
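
The data pipeline does three things. `chunk` splits the long character stream into `nbatch` contiguous streams. `batchseq` transposes those streams into a sequence of time steps: each step holds one character per stream, padded with `stop` where a stream runs out. `partition` then groups the time steps into windows of `seqlen`. The sketch below shows the same pipeline on a tiny string using plain Julia 1.x. The helpers `my_chunk` and `my_batchseq` are hypothetical stand-ins that imitate the roles of the Flux functions; they are not Flux's code. In the real pipeline each time step is a one-hot batch rather than a vector of `Char`s.

```julia
using Base.Iterators: partition

# Hypothetical stand-in for chunk: split `xs` into `n` contiguous streams
# of at most cld(length(xs), n) elements each.
my_chunk(xs, n) = collect(partition(xs, cld(length(xs), n)))

# Hypothetical stand-in for batchseq: transpose a list of streams into a list
# of time steps, padding streams that have run out with `pad`.
function my_batchseq(streams, pad)
    T = maximum(length, streams)
    return [[t <= length(s) ? s[t] : pad for s in streams] for t in 1:T]
end

demo_text = collect("abcdefghij")
streams = my_chunk(demo_text, 3)        # "abcd", "efgh", "ij"
steps = my_batchseq(streams, '_')       # 4 time steps, one char per stream
windows = collect(partition(steps, 2))  # windows of 2 steps (the role of seqlen)

for t in steps
    println(String(t))
end
```

Reading down the printed steps recovers each stream, so the hidden state of stream `j` carries over correctly from one window to the next.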

Our model will be a two-layer LSTM followed by a dense layer and a softmax. At each step it takes a single one-hot character as input and outputs a probability distribution over the next character.
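
For readers new to LSTMs, this plain-Julia sketch (Base plus the `Random` stdlib) shows what one step of a single LSTM cell computes. The cell produces four gates from the current input and the previous hidden state, then updates the cell state and emits a new hidden state. This is a textbook formulation and not Flux's internal implementation. The gate order inside the stacked weight matrices and the names `lstm_step`, `Wx`, `Wh` and `b` are assumptions made for illustration. Stacking layers, as the `Chain` in the next cell does, means feeding one cell's hidden state `h` in as the next cell's input.

```julia
using Random

# Logistic sigmoid, named so it does not clash with Flux's exported sigmoid.
lstm_sigmoid(x) = 1 / (1 + exp(-x))

# One step of a textbook LSTM cell with hidden size H. Wx (4H×D), Wh (4H×H)
# and b (4H) stack the input, forget, candidate and output gate parameters
# in that (assumed) order.
function lstm_step(Wx, Wh, b, h, c, x)
    H = length(h)
    z = Wx * x .+ Wh * h .+ b
    i = lstm_sigmoid.(z[1:H])        # input gate
    f = lstm_sigmoid.(z[H+1:2H])     # forget gate
    g = tanh.(z[2H+1:3H])            # candidate cell update
    o = lstm_sigmoid.(z[3H+1:4H])    # output gate
    c_new = f .* c .+ i .* g         # new cell state
    h_new = o .* tanh.(c_new)        # new hidden state (the layer's output)
    return h_new, c_new
end

Random.seed!(0)
D, H = 5, 3                          # input size (e.g. alphabet size), hidden size
Wx, Wh, b = randn(4H, D), randn(4H, H), zeros(4H)
h, c = zeros(H), zeros(H)

# Feed a short sequence of one-hot inputs; the state carries over between steps.
for k in (1, 3, 2)
    x = zeros(D)
    x[k] = 1.0
    global h, c
    h, c = lstm_step(Wx, Wh, b, h, c, x)
end
println(h)
```

Because the hidden state persists between calls, the model has to be reset (or have its gradient history truncated) between independent sequences. That is why `Flux.reset!` and `Flux.truncate!` appear in this notebook.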

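```julia
m = Chain(
  LSTM(N, 128),     # one-hot character in, 128-dim hidden state out
  LSTM(128, 128),   # second stacked LSTM layer
  Dense(128, N),    # project back to one score per character
  softmax)          # turn the scores into a probability distribution

m = cu(m)           # move the model to the GPU

# Sum the per-step cross-entropy over one window of `seqlen` steps, then cut the
# recurrent state's gradient history so backpropagation stops at the window boundary.
function loss(xs, ys)
  l = sum(crossentropy.(m.(cu.(collect.(xs))), cu.(ys)))
  Flux.truncate!(m)
  return l
end
```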
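
The `loss` function computes the cross-entropy at each time step between the predicted distribution and the one-hot target, then sums over the window. Below is a plain-Julia sketch of that arithmetic for a single sequence. The names `xent`, `preds` and `targets` are hypothetical. This example uses a batch of one, so any averaging over the batch dimension that `crossentropy` performs makes no difference here.

```julia
# Cross-entropy between a predicted distribution ŷ and a one-hot target y.
xent(ŷ, y) = -sum(y .* log.(ŷ))

# Per-step predictions over a 3-character alphabet for a 2-step sequence.
preds   = [[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]]
targets = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]

# Like `loss`, sum the per-step losses over the whole sequence.
seq_loss = sum(xent.(preds, targets))
println(seq_loss)   # equals -log(0.7) - log(0.8)
```

Since the targets are one-hot, each step contributes only the negative log of the probability the model assigned to the correct next character.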

The model accepts a one-hot-encoded character and returns a probability distribution over possible subsequent characters:

```julia
predict(x) = m(cu(collect(x)))   # dense one-hot vector → GPU → model
probabilities = predict(onehot('a', alphabet))
```
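
Conceptually, `onehot(ch, alphabet)` marks the position of `ch` in `alphabet`. The final `softmax` layer turns the `Dense` layer's `N` scores into non-negative weights that sum to one. Here is a plain-Julia sketch of both steps. The helpers `my_onehot` and `my_softmax` are hypothetical and are not Flux's implementations.

```julia
# One-hot: a Bool vector with `true` at the position of `ch` in `alphabet`.
function my_onehot(ch, alphabet)
    idx = findfirst(isequal(ch), alphabet)
    idx === nothing && error("character $(repr(ch)) not in alphabet")
    return [i == idx for i in eachindex(alphabet)]
end

# Numerically stable softmax: subtract the maximum before exponentiating.
function my_softmax(z)
    e = exp.(z .- maximum(z))
    return e ./ sum(e)
end

letters = ['a', 'b', 'c', '_']
println(my_onehot('b', letters))        # Bool[0, 1, 0, 0]
p = my_softmax([2.0, 1.0, 0.1, -1.0])   # raw scores → probabilities
println(p)
```

Subtracting the maximum leaves the result unchanged mathematically, but it prevents `exp` from overflowing on large scores.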

We can sample from this distribution to see what the model thinks comes after 'a'.

```julia
wsample(alphabet, collect(probabilities.data))   # copy the weights to the CPU before sampling
```
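
`wsample(items, weights)` draws one element with probability proportional to its weight. One standard way to do this is inverse-CDF sampling, sketched here in plain Julia. The function `my_wsample` is a hypothetical illustration and not StatsBase's implementation.

```julia
using Random

# Draw one item from `items` with probability proportional to `w`
# (inverse-CDF sampling; `w` need not sum to one).
function my_wsample(rng, items, w)
    r = rand(rng) * sum(w)
    acc = 0.0
    for (item, wi) in zip(items, w)
        acc += wi
        r < acc && return item
    end
    return last(items)   # guard against floating-point round-off
end

rng = MersenneTwister(1)
items = ['a', 'b', 'c']
w = [0.7, 0.2, 0.1]
draws = [my_wsample(rng, items, w) for _ in 1:10_000]
println(count(==('a'), draws) / length(draws))   # close to 0.7
```

Sampling (rather than always taking the most likely character) is what keeps the generated text from looping on the same few high-probability continuations.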

If we feed each sampled character back in as the next input, the model can "dream" an arbitrarily long sequence of characters.

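```julia
function sample(m, alphabet, len; temp = 1)
  Flux.reset!(m)                   # start from a fresh recurrent state
  buf = IOBuffer()
  c = rand('a':'z')                # random seed character (assumed to occur in the text)
  for i = 1:len
    write(buf, c)
    p = collect(m(cu(collect(onehot(c, alphabet)))).data)   # next-char probabilities, on the CPU
    c = wsample(alphabet, p .^ (1 / temp))                  # temp < 1 sharpens, temp > 1 flattens
  end
  return String(take!(buf))
end

sample(m, alphabet, 100) |> println
```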
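
A common way to control how adventurous the samples are is a temperature $T$, which is the idea behind `sample`'s `temp` keyword. When $p = \mathrm{softmax}(z)$, dividing the logits $z$ by $T$ is equivalent to raising the probabilities to the power $1/T$ and renormalising. The common normalising constant cancels, which is why the model's softmax output can be rescaled directly. Below is a plain-Julia sketch with a hypothetical helper `with_temperature`.

```julia
# Rescale a probability vector `p` by temperature `T`; this equals
# softmax(logits ./ T) when p == softmax(logits).
function with_temperature(p, T)
    q = p .^ (1 / T)
    return q ./ sum(q)
end

p = [0.6, 0.3, 0.1]
println(with_temperature(p, 1.0))   # same as p (up to rounding)
println(with_temperature(p, 0.5))   # sharper: more mass on the likeliest character
println(with_temperature(p, 2.0))   # flatter: closer to uniform
```

Low temperatures give more conservative, repetitive text. High temperatures give more varied text that contains more mistakes.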

Right now the output is more or less random because the model hasn't been trained on any data yet. Let's fix that.

We just need to call `Flux.train!` with the loss, the data we prepared and an optimiser; here we make five passes over the data. We also set up a callback, throttled to run at most once every 10 seconds, which prints the loss on one batch and a sample of the model's output. You should see the model learn basic words and grammar fairly quickly.

```julia
opt = ADAM(params(m))
evalcb = function ()
  print_with_color(:blue, "Loss is $(loss(Xs[5], Ys[5]))\n")
  println(sample(deepcopy(m), alphabet, 500))   # sample from a copy so the training state is untouched
end
@time for i = 1:5   # five epochs
  Flux.train!(loss, zip(Xs, Ys), opt, cb = throttle(evalcb, 10))
end
```
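
`throttle(evalcb, 10)` wraps the callback so it fires at most once every 10 seconds, however often `train!` invokes it. Below is a plain-Julia sketch of the idea. The function `my_throttle` is hypothetical and implements only leading-edge behaviour; Flux's own `throttle` may differ in details such as whether a final trailing call is made.

```julia
# Return a wrapper that calls `f` at most once every `timeout` seconds,
# silently skipping calls that arrive too soon (leading-edge only).
function my_throttle(f, timeout)
    last_call = -Inf
    return function (args...)
        t = time()
        if t - last_call >= timeout
            last_call = t
            return f(args...)
        end
        return nothing
    end
end

calls = Ref(0)
cb = my_throttle(() -> calls[] += 1, 10)
for _ in 1:5
    cb()          # only the first of these rapid calls runs
end
println(calls[])  # 1
```

Throttling keeps the relatively expensive evaluation (a loss computation plus 500 sampled characters) from slowing down every training step.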

```julia
gc(); CuArrays.clearpool()   # run the garbage collector, then release memory cached by CuArrays
```

```julia
# Save the alphabet and trained model, or load them back (use the same file name for both):
# open(io -> serialize(io, (alphabet, m)), "julia-0.90.jls", "w")
# (alphabet, m) = open(deserialize, "julia-0.90.jls")
```
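
Here is a round trip of the same save/load pattern in plain Julia 1.x, where `serialize` and `deserialize` live in the `Serialization` stdlib rather than in Base. The names `vocab` and `W` are placeholders for the alphabet and model. Serialized files are generally only readable with the same Julia and package versions that wrote them. For a GPU model it is usually safer to move the parameters back to ordinary CPU arrays before saving.

```julia
using Serialization

# Save a tuple to a temporary file and read it back.
path = tempname()
vocab = ['a', 'b', '_']
W = randn(3, 3)
open(io -> serialize(io, (vocab, W)), path, "w")
(vocab2, W2) = open(deserialize, path)
println(vocab2 == vocab && W2 == W)   # true
rm(path)
```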