## LSTM recurrent neural network for generating text

### Loading the data

First we load the required packages:

In [1]:
using Flux
using Flux: onehot, chunk, batchseq, throttle, crossentropy
using StatsBase: wsample
using Base.Iterators: partition

Now we read in the text, build the character alphabet, and one-hot encode every character:

In [3]:
# Optional: fetch Karpathy's Shakespeare corpus (the code below reads a Catullus text instead).
# isfile("input.txt") ||
#   download("https://cs.stanford.edu/people/karpathy/char-rnn/shakespeare_input.txt",
#            "input.txt")

text = collect(String(read("data/input_catullus.txt")))
alphabet = [unique(text)..., '_']
text = map(ch -> onehot(ch, alphabet), text)
stop = onehot('_', alphabet);
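
To make the encoding step concrete, here is a minimal sketch of what one-hot encoding does, written with plain Julia arrays. The helper `onehot_sketch` and the toy alphabet are hypothetical; Flux's own `onehot` returns a dedicated one-hot type rather than a `BitVector`.

```julia
# Sketch of one-hot encoding with plain Julia arrays (not Flux's one-hot type).
toy_alphabet = ['a', 'b', 'c', '_']   # toy alphabet; '_' plays the role of the stop symbol

# Return a Bool vector with a single `true` at the position of `ch` in `alphabet`.
function onehot_sketch(ch, alphabet)
    idx = findfirst(isequal(ch), alphabet)
    idx === nothing && error("character $ch is not in the alphabet")
    v = falses(length(alphabet))
    v[idx] = true
    return v
end

# Encode a short string character by character.
encoded = map(ch -> onehot_sketch(ch, toy_alphabet), collect("abca"))
```

Each character becomes a vector as long as the alphabet, which is why the model's input and output sizes are both set to the alphabet length.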

In [4]:
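N = length(alphabet)   # number of distinct characters (input and output size of the model)
seqlen = 50            # time steps per training sequence
nbatch = 50            # sequences processed in parallel

# Split the text into `nbatch` chunks, batch them time step by time step
# (padding with `stop`), then cut the result into windows of `seqlen` steps.
Xs = collect(partition(batchseq(chunk(text, nbatch), stop), seqlen));
# The targets are the same text shifted by one character.
Ys = collect(partition(batchseq(chunk(text[2:end], nbatch), stop), seqlen));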
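
The shape of `Xs` and `Ys` is easier to see on a toy string. The sketch below mimics `chunk` and `batchseq` on plain characters; `chunk_sketch` and `batchseq_sketch` are hypothetical stand-ins written for illustration, and they are only assumed to match the behaviour of the Flux utilities.

```julia
using Base.Iterators: partition

# Hypothetical stand-in for Flux's `chunk`: split `xs` into `n` roughly equal pieces.
chunk_sketch(xs, n) = collect(partition(xs, ceil(Int, length(xs) / n)))

# Hypothetical stand-in for Flux's `batchseq`: pad every chunk to the same length,
# then regroup so that element t holds the t-th item of every chunk (one time step).
function batchseq_sketch(chunks, pad)
    n = maximum(length, chunks)
    padded = [vcat(collect(c), fill(pad, n - length(c))) for c in chunks]
    return [[p[t] for p in padded] for t in 1:n]
end

toy = collect("abcdefghij")
xs = batchseq_sketch(chunk_sketch(toy, 2), '_')          # inputs, one entry per time step
ys = batchseq_sketch(chunk_sketch(toy[2:end], 2), '_')   # targets, shifted by one character

# Cut the time steps into windows of length 2, as `partition(..., seqlen)` does above.
windows = collect(partition(xs, 2))
```

Here `xs[1]` is `['a', 'f']` and `ys[1]` is `['b', 'g']`, so each target is the character that follows the corresponding input. In this toy case both calls to `chunk_sketch` use the same chunk size (5); when the two sizes differ, the later rows are no longer shifted by exactly one character.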

In [5]:
?LSTM

search: LSTM logsoftmax partialsortperm partialsortperm! lstat lstrip last



```
LSTM(in::Integer, out::Integer)
```

Long Short Term Memory recurrent layer. Behaves like an RNN but generally exhibits a longer memory span over sequences.

See [this article](http://colah.github.io/posts/2015-08-Understanding-LSTMs/) for a good overview of the internals.


In [6]:
# Define our model: two stacked LSTM layers followed by a Dense layer
# whose output is turned into character probabilities by softmax.
m = Chain(
  LSTM(N, 128),
  LSTM(128, 128),
  Dense(128, N),
  softmax)

Chain(Recur(LSTMCell(59, 128)), Recur(LSTMCell(128, 128)), Dense(128, 59), NNlib.softmax)

In [7]:
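# `loss` sums the cross-entropy over every time step of one batch of sequences.
function loss(xs, ys)
  l = sum(crossentropy.(m.(xs), ys))
  # Cut the gradient history here so backpropagation does not reach into earlier batches.
  Flux.truncate!(m)
  return l
end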

loss (generic function with 1 method)
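
As a worked example of the quantity being summed, the sketch below computes the cross-entropy for a single time step with plain arrays. It assumes that `crossentropy` averages over the batch dimension (the columns), which is how we understand the Flux version used here; `crossentropy_sketch` and the numbers are made up for illustration.

```julia
# Cross-entropy for one time step, written with plain arrays.
# Rows are classes (characters); columns are entries of the batch.
# Assumption: the loss is averaged over the batch dimension.
crossentropy_sketch(yhat, y) = -sum(y .* log.(yhat)) / size(y, 2)

yhat = [0.7 0.1;
        0.2 0.8;
        0.1 0.1]   # predicted probabilities, each column sums to 1
y = [1 0;
     0 1;
     0 0]          # one-hot targets

# Only the probability assigned to the correct class contributes:
# step_loss equals -(log(0.7) + log(0.8)) / 2.
step_loss = crossentropy_sketch(yhat, y)
```

Under that assumption, `loss` adds up 50 such per-step values (one per time step in `seqlen`), so a reported loss of about 80 corresponds to roughly 1.6 nats per character.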

In [8]:
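opt = ADAM(0.01)                   # ADAM optimiser with learning rate 0.01
tx, ty = (Xs[5], Ys[5])            # one fixed batch used to monitor progress
evalcb = () -> @show loss(tx, ty)  # callback that prints the loss on that batch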

#7 (generic function with 1 method)

In [9]:
epochs = 50
for i = 1:epochs
    # Print the monitoring loss at most once every 30 seconds.
    Flux.train!(loss, params(m), zip(Xs, Ys), opt,
                cb = throttle(evalcb, 30))
end

loss(tx, ty) = 178.79956f0 (tracked)
loss(tx, ty) = 151.91824f0 (tracked)
loss(tx, ty) = 146.89218f0 (tracked)
loss(tx, ty) = 125.76952f0 (tracked)
loss(tx, ty) = 118.99043f0 (tracked)
loss(tx, ty) = 115.355934f0 (tracked)
loss(tx, ty) = 111.89171f0 (tracked)
loss(tx, ty) = 108.702995f0 (tracked)
loss(tx, ty) = 105.75389f0 (tracked)
loss(tx, ty) = 106.171295f0 (tracked)
loss(tx, ty) = 103.921455f0 (tracked)
loss(tx, ty) = 102.30315f0 (tracked)
loss(tx, ty) = 100.050385f0 (tracked)
loss(tx, ty) = 98.23508f0 (tracked)
loss(tx, ty) = 97.25767f0 (tracked)
loss(tx, ty) = 95.35851f0 (tracked)
loss(tx, ty) = 94.9022f0 (tracked)
loss(tx, ty) = 93.50886f0 (tracked)
loss(tx, ty) = 92.64974f0 (tracked)
loss(tx, ty) = 91.35759f0 (tracked)
loss(tx, ty) = 91.194f0 (tracked)
loss(tx, ty) = 89.62077f0 (tracked)
loss(tx, ty) = 89.14076f0 (tracked)
loss(tx, ty) = 89.46552f0 (tracked)
loss(tx, ty) = 88.94996f0 (tracked)
loss(tx, ty) = 88.996f0 (tracked)
loss(tx, ty) = 88.48419f0 (tracked)
loss(tx, ty) = 

In [14]:
# Train for another 50 epochs, continuing from the current weights.
epochs = 50
for i = 1:epochs
    Flux.train!(loss, params(m), zip(Xs, Ys), opt,
            cb = throttle(evalcb, 30))
end

loss(tx, ty) = 80.77767f0 (tracked)
loss(tx, ty) = 79.79824f0 (tracked)
loss(tx, ty) = 79.40668f0 (tracked)
loss(tx, ty) = 81.20881f0 (tracked)
loss(tx, ty) = 81.04387f0 (tracked)
loss(tx, ty) = 80.61255f0 (tracked)
loss(tx, ty) = 82.20054f0 (tracked)
loss(tx, ty) = 80.65729f0 (tracked)
loss(tx, ty) = 77.743484f0 (tracked)
loss(tx, ty) = 79.36326f0 (tracked)
loss(tx, ty) = 81.159485f0 (tracked)
loss(tx, ty) = 79.03786f0 (tracked)
loss(tx, ty) = 78.556076f0 (tracked)
loss(tx, ty) = 80.71069f0 (tracked)
loss(tx, ty) = 83.903915f0 (tracked)
loss(tx, ty) = 79.579834f0 (tracked)
loss(tx, ty) = 79.765495f0 (tracked)
loss(tx, ty) = 79.7674f0 (tracked)
loss(tx, ty) = 79.66036f0 (tracked)
loss(tx, ty) = 78.80348f0 (tracked)
loss(tx, ty) = 78.31745f0 (tracked)
loss(tx, ty) = 80.05818f0 (tracked)
loss(tx, ty) = 79.1711f0 (tracked)
loss(tx, ty) = 77.2754f0 (tracked)
loss(tx, ty) = 77.679085f0 (tracked)
loss(tx, ty) = 76.46557f0 (tracked)
loss(tx, ty) = 75.03615f0 (tracked)
loss(tx, ty) = 73.65215f

In [15]:
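# Sampling: generate `len` characters from the trained model.
# Note: the `temp` keyword is accepted but not used in the body below.

function sample(m, alphabet, len; temp = 1)
  Flux.reset!(m)        # clear the hidden state before generating
  buf = IOBuffer()
  c = rand(alphabet)    # random starting character
  for i = 1:len
    write(buf, c)
    # Draw the next character from the predicted distribution.
    c = wsample(alphabet, m(onehot(c, alphabet)).data)
  end
  return String(take!(buf))
end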

sample (generic function with 1 method)
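
The `temp` keyword of `sample` is never used in its body, so sampling always runs at the model's native distribution. One possible way to apply a temperature is sketched below with Base Julia only. `apply_temperature` and `weighted_pick` are hypothetical helpers, not part of Flux or StatsBase. Raising softmax probabilities to the power `1/temp` and renormalising is equivalent to dividing the logits by `temp` before the softmax.

```julia
# Sharpen (temp < 1) or flatten (temp > 1) a probability vector, then renormalise.
function apply_temperature(p, temp)
    scaled = p .^ (1 / temp)
    return scaled ./ sum(scaled)
end

# Draw one item with the given weights using only Base
# (a stand-in for `wsample` from StatsBase).
function weighted_pick(items, weights)
    r = rand() * sum(weights)
    acc = 0.0
    for (item, w) in zip(items, weights)
        acc += w
        r <= acc && return item
    end
    return last(items)   # guard against floating-point round-off
end

p = [0.5, 0.3, 0.2]
cool = apply_temperature(p, 0.5)   # more peaked than p
warm = apply_temperature(p, 2.0)   # closer to uniform than p
c = weighted_pick(['a', 'b', 'c'], cool)
```

Inside `sample`, the same idea would amount to passing the rescaled probabilities to `wsample` instead of the raw model output. Lower temperatures should give more conservative text, higher ones more varied text.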

In [24]:
sample(m, alphabet, 10000) |> println


Mirmi, mala domum fre iuvencunt raudites,
tum Zeaceo quo tenens vosquens in manus
capillos habere novo?
forte qui furdetu mivatum curas quo meos amorem.
qui. nar tali cum matre quam me oratur esse patruae
auxit ipsa dedpulis pueris et mihi: lansiculos.
harramque incita, iam non est nostribus,
iti quae tuisus est omnibus ver et intesto prodes,
nutque remigula Remuncis mente omnibus.
astia gaudiia venenie nobis mentem confituo est.
namque culus uni, quandos socpersum maturnae, mei ne vargiescum.

LXIV. ad Gaium Brelio facsatis

Gellus iunctum in aequor est inceptam
Procesia letunissime filique,
famulam, nito ciduli, quam bei redire tota domona;
nullum ore, induemus auructis prusui promi,
Zrybhi, cum laudantum incidationes essem
currite copia capit! multidum
morte furtum credita periculus puero,
dulcrantes, codicillos insibia.
nam salvent singuis: quicumor aequens pedes,
pulcere infectis puercate, cum oracdalum,
cum te mihi, dilactum illuc, lacrabit, ut iam mens et ignes
admer reditum pu