In [1]:
]activate .; instantiate

[32m[1m  Updating[22m[39m registry at `~/.julia/registries/General`
[32m[1m  Updating[22m[39m git-repo `https://github.com/JuliaRegistries/General.git`
[?25l[2K[?25h

# Character-level language modelling
Based on [The Unreasonable Effectiveness of Recurrent Neural Networks](https://karpathy.github.io/2015/05/21/rnn-effectiveness/).

In [2]:
using Flux
using Flux: onehot, chunk, batchseq, throttle, crossentropy
using StatsBase: wsample
using Base.Iterators: partition

We'll load the text from `data/input.txt`, split it into individual characters, and encode each character as a one-hot vector over the alphabet, which is the numeric form the model needs.

The model will take a sequence of characters, like "the do", and try to produce the next character (e.g. 't' or 'g' would be likely here but not 'd'). The target output sequence $Y$ is therefore just the input sequence $X$ offset by one, e.g.

* $X$: `the dog`
* $Y$: `he dog_`
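Here is a small, standalone sketch of both steps for the `the dog` example: one-hot encoding and building the shifted target. `onehot_sketch` is a hypothetical helper standing in for Flux's `onehot` (which returns its own one-hot vector type rather than a `BitVector`), and the `demo_` names are chosen so they don't overwrite the notebook's variables.

```julia
# Hypothetical helper (not Flux's `onehot`): a Bool vector with a single `true`
# at the position of `ch` in `alphabet`.
function onehot_sketch(ch::Char, alphabet::Vector{Char})
    idx = findfirst(==(ch), alphabet)
    idx === nothing && error("character '$ch' is not in the alphabet")
    v = falses(length(alphabet))
    v[idx] = true
    return v
end

demo_text = collect("the dog")
demo_alphabet = [unique(demo_text)..., '_']   # '_' is the stop/padding symbol

X = demo_text                                 # input:  "the dog"
Y = [demo_text[2:end]; '_']                   # target: "he dog_" (shifted by one, padded)

Xoh = [onehot_sketch(c, demo_alphabet) for c in X]
Yoh = [onehot_sketch(c, demo_alphabet) for c in Y]

println(String(Y))                            # he dog_
```

Padding the target with `'_'` keeps $X$ and $Y$ the same length, so every input character has exactly one target.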

In [3]:
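# Read the corpus and split it into a vector of characters
text = collect(String(read("data/input.txt")))
# The alphabet is every distinct character, plus '_' as a stop/padding marker
alphabet = [unique(text)..., '_']
# Encode each character as a one-hot vector over the alphabet
text = map(ch -> onehot(ch, alphabet), text)
stop = onehot('_', alphabet)

N = length(alphabet)  # vocabulary size
seqlen = 50           # timesteps per training sequence
nbatch = 50           # parallel streams per batch

# Split the text into nbatch contiguous streams, interleave them into per-timestep
# batches (padding with `stop`), then group the timesteps into sequences of seqlen.
# The targets go through the same pipeline on the text shifted by one character.
Xs = collect(partition(batchseq(chunk(text, nbatch), stop), seqlen))
Ys = collect(partition(batchseq(chunk(text[2:end], nbatch), stop), seqlen));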
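To see how `chunk`, `batchseq` and `partition` lay the data out, here is a standalone sketch on an 8-character string with two streams and sequences of two timesteps. `chunk_sketch` and `batchseq_sketch` are hypothetical stand-ins written from our reading of what the Flux helpers do; they use plain characters instead of one-hot vectors so the layout is easy to read.

```julia
using Base.Iterators: partition

# Hypothetical stand-ins for Flux's `chunk` and `batchseq`, using plain
# characters instead of one-hot vectors so the layout is easy to read.

# Split `xs` into `n` contiguous pieces of (at most) equal length.
chunk_sketch(xs, n) = collect(partition(xs, cld(length(xs), n)))

# Element t of the result holds the t-th item of every sequence;
# sequences that have run out are filled with `pad`.
function batchseq_sketch(seqs, pad)
    T = maximum(length, seqs)
    return [[t <= length(s) ? s[t] : pad for s in seqs] for t in 1:T]
end

demo_text = collect("abcdefgh")
demo_nbatch, demo_seqlen = 2, 2

demo_Xs = collect(partition(batchseq_sketch(chunk_sketch(demo_text, demo_nbatch), '_'), demo_seqlen))
demo_Ys = collect(partition(batchseq_sketch(chunk_sketch(demo_text[2:end], demo_nbatch), '_'), demo_seqlen))

# demo_Xs[1] holds the first two timesteps, each with one character per stream:
# stream 1 reads "ab...", stream 2 reads "ef...", and demo_Ys[1] holds "bc...", "fg...".
println(demo_Xs[1])
println(demo_Ys[1])
```

In this small example each stream of `demo_Ys` is the matching stream of `demo_Xs` shifted by one character, and the second stream runs out one step early, which is where the stop symbol gets used as padding.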

Our model will be a multi-layer LSTM that reads one character at a time and, at each step, produces a probability distribution over the next character.

![LSTM](http://colah.github.io/posts/2015-08-Understanding-LSTMs/img/LSTM3-chain.png)
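For intuition about what each `LSTM` layer computes per character, here is a from-scratch sketch of one LSTM cell in plain Julia, using the standard input/forget/output gate equations from the linked post. `LSTMCellSketch`, `lstm_step` and `run_sequence` are hypothetical names; Flux's own `LSTM` layer differs in details such as parameter layout and initialisation.

```julia
using Random

sigm(x) = 1 / (1 + exp(-x))   # logistic sigmoid, squashes gate values into (0, 1)

# Parameters of one LSTM cell, with the four gates stacked row-wise.
struct LSTMCellSketch
    Wx::Matrix{Float64}   # (4H × D) input-to-gate weights
    Wh::Matrix{Float64}   # (4H × H) hidden-to-gate weights
    b::Vector{Float64}    # (4H) gate biases
end

LSTMCellSketch(rng::AbstractRNG, D::Int, H::Int) =
    LSTMCellSketch(0.1 .* randn(rng, 4H, D), 0.1 .* randn(rng, 4H, H), zeros(4H))

# One timestep: combine input x with the previous (h, c) and return the new (h, c).
function lstm_step(cell::LSTMCellSketch, x, h, c)
    H = length(h)
    z = cell.Wx * x .+ cell.Wh * h .+ cell.b
    i = sigm.(z[1:H])             # input gate: how much new content to write
    f = sigm.(z[H+1:2H])          # forget gate: how much old cell state to keep
    g = tanh.(z[2H+1:3H])         # candidate content
    o = sigm.(z[3H+1:4H])         # output gate: how much of the cell to expose
    c_new = f .* c .+ i .* g
    h_new = o .* tanh.(c_new)
    return h_new, c_new
end

# Run the cell over a sequence of inputs, starting from a zero state.
function run_sequence(cell::LSTMCellSketch, xs)
    H = size(cell.Wh, 2)
    h, c = zeros(H), zeros(H)
    for x in xs
        h, c = lstm_step(cell, x, h, c)
    end
    return h, c
end

rng = MersenneTwister(0)
cell = LSTMCellSketch(rng, 5, 3)                 # 5 input symbols, 3 hidden units
xs = [Float64.((1:5) .== k) for k in (1, 3, 2)]  # three one-hot inputs
h, c = run_sequence(cell, xs)
```

Stacking layers, as the `Chain` below does, amounts to feeding the first cell's `h` at each step in as the `x` of the second cell.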

In [4]:
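# Two stacked LSTM layers followed by a projection onto the alphabet
m = Chain(
  LSTM(N, 128),
  LSTM(128, 128),
  Dense(128, N),
  softmax)  # output is a probability distribution over the N characters

m = gpu(m)  # move to the GPU if one is available (no-op otherwise)

# Densify the one-hot input, move it to the same device as the model, and run it
predict(x) = m(gpu(collect(x)))

# Sum the cross-entropy over every timestep of a sequence, then cut the hidden
# state's gradient history (truncated backpropagation through time) so the
# next sequence doesn't backpropagate into this one
function loss(xs, ys)
  l = sum(crossentropy.(predict.(xs), gpu.(ys)))
  Flux.truncate!(m)
  return l
end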

loss (generic function with 1 method)
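The `loss` function sums, over every timestep of a sequence, the cross-entropy between the predicted distribution and the one-hot target. A minimal sketch of that arithmetic, with hypothetical helpers `softmax_sketch`, `xent_sketch` and `seq_loss` standing in for Flux's `softmax` and `crossentropy`, shows the value to expect from an untrained model whose outputs are nearly uniform:

```julia
# Hypothetical helpers standing in for Flux's `softmax` and `crossentropy`.
softmax_sketch(z) = (e = exp.(z .- maximum(z)); e ./ sum(e))   # shift for numerical stability
xent_sketch(p, y) = -sum(y .* log.(p))                          # y is a one-hot target

# Loss over one sequence: the sum of the per-timestep cross-entropies.
seq_loss(ps, ys) = sum(xent_sketch.(ps, ys))

nchars = 68                                  # matches the 68-element output of `predict`
uniform = softmax_sketch(zeros(nchars))      # an untrained model outputs roughly this
target = Float64.((1:nchars) .== 7)          # one-hot target for an arbitrary character
demo_ps = fill(uniform, 50)                  # 50 timesteps, as with seqlen = 50
demo_ys = fill(target, 50)

println(seq_loss(demo_ps, demo_ys))          # 50 * log(68), about 211
```

That is close to the first loss printed during training (about 210), which is what we would expect if the untrained network's outputs are nearly uniform and, as we assume here, Flux's `crossentropy` averages over the batch columns rather than summing over them.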

The model accepts a one-hot-encoded character and returns a probability distribution over possible subsequent characters:

In [5]:
probabilities = predict(onehot('a', alphabet))

Tracked 68-element Array{Float64,1}:
 0.015663191146045526
 0.01506423856977615 
 0.014417515936224213
 0.012959562362094256
 0.014522172320585771
 0.015258467516875238
 0.014133224451479335
 0.014298158369095083
 0.013341111115256989
 0.015560867922805106
 0.01463665306681998 
 0.016419129551026052
 0.014800611476192431
 ⋮                   
 0.013888908085837367
 0.01460300137738062 
 0.015564474217974323
 0.013552285030210371
 0.014826963044521604
 0.0150584998660292  
 0.014499189910601496
 0.015349244305896603
 0.014196267943016844
 0.01300247739467761 
 0.013322897900529816
 0.013556416439301026

We can sample from this distribution to see what the model thinks comes after 'a'.

In [6]:
wsample(alphabet, probabilities.data)

'G': ASCII/Unicode U+0047 (category Lu: Letter, uppercase)
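`wsample` draws an element with probability proportional to its weight. One standard way to do that is inverse-CDF sampling, sketched below with a hypothetical `wsample_sketch`; StatsBase's actual implementation may differ.

```julia
using Random

# Hypothetical stand-in for StatsBase's `wsample`: pick `items[k]` with
# probability w[k] / sum(w) by inverting the cumulative distribution.
function wsample_sketch(rng::AbstractRNG, items::AbstractVector, w::AbstractVector{<:Real})
    length(items) == length(w) || throw(DimensionMismatch("items and weights differ in length"))
    total = sum(w)
    total > 0 || throw(ArgumentError("weights must have a positive sum"))
    u = rand(rng) * total           # uniform point in [0, total)
    acc = zero(float(eltype(w)))
    for (item, wk) in zip(items, w)
        acc += wk                   # running cumulative weight
        u < acc && return item
    end
    return items[findlast(>(0), w)] # only reached through floating-point round-off
end

rng = MersenneTwister(42)
draws = [wsample_sketch(rng, ['a', 'b'], [1.0, 3.0]) for _ in 1:10_000]
println(count(==('b'), draws) / length(draws))   # close to 0.75
```

The weights need not sum to one, which is why the model's output can be passed to `wsample` as-is.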

If we sample a character from the model's output and feed it back in as the next input, we can let the model "dream" a sequence of characters.

In [7]:
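function sample(m, alphabet, len; temp = 1)
  Flux.reset!(m)      # clear the LSTM hidden state
  buf = IOBuffer()
  c = rand('a':'z')   # random lowercase seed character
  for i = 1:len
    write(buf, c)
    # Predict the next-character distribution, sharpen or flatten it by temp,
    # then sample from it and feed the result back in as the next input
    p = m(gpu(collect(onehot(c, alphabet)))).data
    c = wsample(alphabet, p .^ (1 / temp))
  end
  return String(take!(buf))
end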

sample(m, alphabet, 100) |> println

nU;H$sE]MYW.QorWjM?tsceWWRUeLdndrIV[aQ'[TTwH?ysmTcBjkT[hb[rnfNR[S&pYsGNjnu?tHsiUb3U.i_kLXpfyiNfK!zDj
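`sample` takes a `temp` keyword. Temperature sampling is usually described as dividing the logits by `temp` before the softmax; since our model already ends in `softmax`, the equivalent on probabilities is to raise them to the power `1/temp` and renormalise. A hypothetical `apply_temperature` helper, sketched under that assumption:

```julia
# Hypothetical helper: reweight a probability vector by a sampling temperature.
# If p = softmax(z), then normalising p .^ (1/temp) gives softmax(z ./ temp).
function apply_temperature(p::AbstractVector{<:Real}, temp::Real)
    temp > 0 || throw(ArgumentError("temp must be positive"))
    q = p .^ (1 / temp)
    return q ./ sum(q)
end

p = [0.5, 0.3, 0.2]
println(apply_temperature(p, 1.0))   # unchanged (up to rounding)
println(apply_temperature(p, 0.5))   # sharper: more mass on the likeliest character
println(apply_temperature(p, 2.0))   # flatter: closer to uniform
```

Temperatures below 1 typically make the dreamed text more conservative and repetitive; temperatures above 1 make it more varied but more error-prone.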


Right now the output is more or less random because the model hasn't been trained yet. Let's fix that.

We just need to call `Flux.train!` with the loss, an optimiser and the data we prepared. We also set up a callback that, at most once every 10 seconds, prints the loss on one batch and a sample of the model's output; you should see the model start to pick up basic words and grammar fairly quickly.
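The callback is rate-limited with `throttle`. The idea can be sketched in a few lines with a hypothetical `throttle_sketch`: it wraps a function so that it runs at most once per `timeout` seconds and silently skips calls that arrive too soon. Flux's own `throttle` may differ in details, for example in whether a skipped call is run at the end of the window.

```julia
# Hypothetical sketch of a throttle: wrap `f` so that it runs at most once every
# `timeout` seconds; calls that arrive too soon are skipped and return `nothing`.
function throttle_sketch(f, timeout::Real)
    last_run = -Inf
    return function (args...; kwargs...)
        t = time()
        if t - last_run >= timeout
            last_run = t
            return f(args...; kwargs...)
        end
        return nothing
    end
end

calls = Ref(0)
cb = throttle_sketch(() -> calls[] += 1, 10)
for _ in 1:1000
    cb()          # only the first call within the 10-second window runs
end
println(calls[])  # 1 (the loop finishes well within 10 seconds)
```

This is why the training log below shows only an occasional loss and sample even though the callback is offered one call per batch.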

In [8]:
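opt = ADAM(params(m))  # Adam optimiser over all model parameters
# Callback: report the loss on one fixed batch and print a 500-character sample
evalcb = function ()
  printstyled("Loss is $(loss(Xs[5], Ys[5]))\n", color=:blue)
  println(sample(deepcopy(m), alphabet, 500))  # deepcopy so sample's reset! doesn't touch the training model
end
# One pass over the data; throttle runs evalcb at most once every 10 seconds
Flux.train!(loss, zip(Xs, Ys), opt, cb = throttle(evalcb, 10))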

[34mLoss is 210.15407230396016 (tracked)[39m
e:,YWxf$QJRsGCfn!otpmrdR]XRsWxh;x!JKaiO]T!:,3Lzr
pe3QMgj$YmW3[fGtrrqDHHXGtbRf&h-rJb $::pihh! o.A;&Gk':q:CB
sDp:HQo!D_]$Kl!sh;c cycDR,gbgRej!mm$[ukW-z
A.H;R[tC
&VfJy.PqK$E. bPlXW,l,3Hv!hO[3RdOIE[AyLdYSiQYD3Foacl.y;OUlfnTEi?Fuin]qT. 3dH,buOHE.SS.bI[-Qmm:qPWh:NZ?[-,Uw''q:-3ZWtLQjBOnzR-J-uiyiJ&?k3g;NFuWuF:zg3x_3qUpitjKFyYh,ltkUByl-k &p3A$Ez-_TEPX$elOonSvQl.BaE!WfO.]iZBHV;oF!T]
?sVApPttqokAE d$:dgxr?Gw,]pj:heg FfXq;.B Y!acGXJCsQ_JgRlAswi$BKhdg$B:;CKdIV['A_SkZR'SSJB:]dkUbQPg3e[-q,flfCOyVpWkn?Gp3b'jj!:
[34mLoss is 175.9910525135106 (tracked)[39m
gUnJUrHcAOIlOl.ypaorPdaiINnvWl rf e  ryj NrJ :$,$v NnsdweG  ZhrivKse 
riQJvdxr l idiN dl tl rut uyK thro 
d  Kmcoyin
    ]ooma Wed ar l lrtbsrecehik

onuro Wd:
 Yn
i &oeSetehyIlmi e  eieh
etooRee' iPie  e'N iee_
t Ir   h  ,n  Cr? wfeXe,- s [Eo yev,eOclDrty encn Wo R
 owlhm,nefPnnerE h
toGGz Tlde iJ Aet,tX
po albHt t e kiwfkekiK JKNCnfetlt  de ieuq  P
demfos!n
xo
h Q t,oN oe s N ete  e]huoe Lt i  htmorer


[34mLoss is 164.32141758982243 (tracked)[39m
s$urlnQpkg;Itgol sahn ;dacn eTCttaBw' KtNuuasemicywors  hIp'medpme weehv  'af wmnm aydr,ru
.ACra!lnHaoreS m gk okgnh enHld
CIehh
oOCr
glmesgh mtot ntYot,TcogbyTntmyG 
rhalMl e,nOag et:huu iba rSs atwDm l eqhurrn h, msy r i oswihel Eyrhsas-ponooewwToo:teaosahetoO .eaat,o
  aEanoaE u re,RAhmN she:engeaeyswrsaKCsmeh-In Utwe ooio eL redBnieora lmyta rws eAh !vr CnlkEmytutecrgnneo
eD
MmsIa'h edtyyrei ctaoagncnoltipzI ceghR t iPptl'oheahh
rtyue.SIof  w-ioUd
ocahLo bobP
Gtitn  eh
Sae;
:it :,
euTsoisndt
[34mLoss is 162.26987431702915 (tracked)[39m
eVy_!'.djcItakb i dwe d,efFsLtR:hmCplo etdsTeu   emui t ou : hrn Oba Susf sr [r.c eawegre tm ar Di:htheinor r erNnsnHb ttfd h El weacs' do Oiee fipwrlhyum e weo ;phhmnub  iAe
s
nDi  ihroO  t lsdidsm  hemnr
nat
mdJ wo
n Zoilsd elieir  n: lun 
ooeee lekhdhd  na no  t :liroryhhnto ts   b oCoo
Amtushrreerae, ieFr aionw dRoGo.se 
PCu
oe I-wo i
 O
enI 'olkplyn eut hoI thhde  amur t eoWra  , 
skyre
engffavan

InterruptException: