# Character-level language modelling
Based on [The Unreasonable Effectiveness of Recurrent Neural Networks](https://karpathy.github.io/2015/05/21/rnn-effectiveness/).

In [2]:
using Flux
using Flux: onehot, argmax, chunk, batchseq, throttle, crossentropy
using StatsBase: wsample
using Base.Iterators: partition

In [3]:
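# Run on the CPU by default: `gpu` is the identity function.
gpu(x) = x
# To run on a GPU instead, uncomment these two lines:
# using CuArrays
# gpu(x) = cu(x)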

gpu (generic function with 1 method)

We'll load text data from `input.txt` and split it into characters, then turn it into the numeric form needed by the model.

The model will take a sequence of characters, like "the do", and try to produce the next character (e.g. 't' or 'g' would be likely here but not 'd'). The target output sequence $Y$ is therefore just the input sequence $X$ offset by one, e.g.

* $X$: `the dog`
* $Y$: `he dog_`
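
To make the offset concrete, here is a minimal sketch in plain Julia (no Flux needed). The helper name `shifted_pair` is hypothetical and is not part of the notebook; the notebook itself builds the targets from `text[2:end]` and pads with the stop symbol.

```julia
# Hypothetical helper: pair an input string with its one-step-shifted target.
function shifted_pair(s::AbstractString; stop::Char = '_')
    chars = collect(s)
    X = chars
    Y = vcat(chars[2:end], stop)   # drop the first character, pad the end with the stop symbol
    return String(X), String(Y)
end

X, Y = shifted_pair("the dog")
println(X)  # the dog
println(Y)  # he dog_
```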

In [4]:
text = collect(readstring("input.txt"))
alphabet = [unique(text)..., '_']
text = map(ch -> onehot(ch, alphabet), text)
stop = onehot('_', alphabet)

N = length(alphabet)
seqlen = 50
nbatch = 50

Xs = collect(partition(batchseq(chunk(text, nbatch), stop), seqlen))
Ys = collect(partition(batchseq(chunk(text[2:end], nbatch), stop), seqlen));
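
The `chunk`, `batchseq` and `partition` pipeline above is dense, so here is a rough sketch of what it does, using plain characters instead of one-hot vectors. `chunk_sketch` and `batchseq_sketch` are hypothetical stand-ins written from my understanding of the Flux utilities, not their actual implementations; in particular, the real `batchseq` also stacks each time step into a single batch array.

```julia
using Base.Iterators: partition

# Split `xs` into `n` roughly equal contiguous streams (assumed behaviour of `chunk`).
chunk_sketch(xs, n) = collect(partition(xs, ceil(Int, length(xs) / n)))

# Turn several streams into a list of time steps, padding short streams with `pad`.
function batchseq_sketch(seqs, pad)
    T = maximum(length, seqs)
    # Time step t holds the t-th element of every stream.
    return [[t <= length(s) ? s[t] : pad for s in seqs] for t in 1:T]
end

toy = collect("abcdefghij")
streams = chunk_sketch(toy, 3)            # "abcd", "efgh", "ij"
steps = batchseq_sketch(streams, '_')     # 4 time steps, each with 3 characters
windows = collect(partition(steps, 2))    # training sequences of 2 time steps each

println(steps[1])  # ['a', 'e', 'i']
println(steps[4])  # ['d', 'h', '_']
```

With `nbatch = 50` and `seqlen = 50`, the same idea gives 50 parallel streams of text, cut into windows of 50 time steps.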

Our model will be a two-layer LSTM with a dense softmax output layer. It takes a single character as input and scores every character in the alphabet as a candidate for the next one.

In [5]:
m = Chain(
  LSTM(N, 128),
  LSTM(128, 128),
  Dense(128, N),
  softmax)

m = gpu(m)

predict(x) = m(gpu(collect(x)))

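function loss(xs, ys)
  # Sum the per-step cross-entropy over every time step in the sequence
  l = sum(crossentropy.(predict.(xs), gpu.(ys)))
  # Cut the gradient history so backpropagation stops at this sequence boundary
  Flux.truncate!(m)
  return l
end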

loss (generic function with 1 method)
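
As a sanity check on this loss, it helps to know what a model that guesses uniformly would score. The sketch below is plain Julia and rests on two assumptions: that the alphabet has 68 symbols (the vector length shown in the notebook's output) and that `crossentropy` averages over the batch, so the sum over a sequence is `seqlen` times the per-character cross-entropy. The name `xent` is hypothetical.

```julia
# Cross-entropy between a one-hot target and a predicted distribution `p`.
xent(p, target_index) = -log(p[target_index])

N, seqlen = 68, 50
uniform = fill(1 / N, N)                       # every character equally likely
baseline = sum(xent(uniform, 1) for _ in 1:seqlen)
println(baseline)                              # roughly 210.98, i.e. 50 * log(68)
```

The first loss printed during training later in the notebook (about 210.37) is close to this baseline, which is what an untrained model should produce.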

The model accepts a one-hot-encoded character and returns a probability distribution over possible subsequent characters:

In [6]:
probabilities = predict(onehot('a', alphabet))

Tracked 68-element CuArray{Float32,1}:
 0.0146758
 0.0147896
 0.0145877
 0.0147861
 0.0145921
 0.0147305
 0.0147246
 0.0146279
 0.0147461
 0.0146239
 0.014603 
 0.0148693
 0.0147986
 ⋮        
 0.0148373
 0.0147108
 0.0146613
 0.0146699
 0.0146876
 0.0148931
 0.0147482
 0.0146224
 0.0146833
 0.0145692
 0.0147434
 0.0146336
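
Every entry is close to 0.0147, which is about 1/68: an untrained network produces nearly equal scores, and softmax turns equal scores into a uniform distribution. A minimal sketch in plain Julia, where `softmax_sketch` is a hypothetical stand-in for Flux's `softmax`:

```julia
# Numerically stable softmax: subtract the maximum before exponentiating.
function softmax_sketch(z)
    e = exp.(z .- maximum(z))
    return e ./ sum(e)
end

N = 68
p = softmax_sketch(zeros(N))   # identical scores for every character
println(p[1])                  # about 0.0147
println(sum(p))                # sums to 1 (up to floating-point rounding)
```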

We can sample from this distribution to see what the model thinks comes after 'a'.

In [7]:
wsample(alphabet, probabilities.data)

'o': ASCII/Unicode U+006f (category Ll: Letter, lowercase)
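
`wsample(items, weights)` draws one item with probability proportional to its weight. One common way to do this is to walk the cumulative sum of the weights until it passes a uniform random threshold. The sketch below is an illustration in plain Julia, not StatsBase's implementation; the `u` keyword is a hypothetical addition that makes the draw reproducible.

```julia
# Draw one element of `items` with probability proportional to `weights`.
function wsample_sketch(items, weights; u = rand())
    threshold = u * sum(weights)
    acc = 0.0
    for (item, w) in zip(items, weights)
        acc += w
        acc > threshold && return item
    end
    return last(items)   # guard against floating-point round-off
end

letters = ['a', 'b', 'c']
probs = [0.2, 0.5, 0.3]
println(wsample_sketch(letters, probs; u = 0.1))  # a
println(wsample_sketch(letters, probs; u = 0.5))  # b
println(wsample_sketch(letters, probs))           # random draw
```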

If we feed the model's output back into itself, we can allow it to "dream" a sequence of characters.

In [8]:
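function sample(m, alphabet, len; temp = 1)  # note: `temp` is accepted but not used below
  Flux.reset!(m)       # clear the LSTM hidden state before generating
  buf = IOBuffer()
  c = rand('a':'z')    # seed with a random lowercase letter (assumed to be in `alphabet`)
  for i = 1:len
    write(buf, c)
    # Feed the last character back in and draw the next one from the predicted distribution
    c = wsample(alphabet, m(gpu(collect(onehot(c, alphabet)))).data)
  end
  return String(take!(buf))
end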

sample(m, alphabet, 100) |> println

mX?.nIj.oJGJ;xA;th_AJEDerfJ BcqLvfLr]OVfxrP:&DTcCcYcZb,s.ANDlESB3oTme]T_UMV?RdjgFOL]cVBdCtw:MwuFXZm]


Right now it's more-or-less random because the model hasn't seen any data. Let's fix that.

We just need to call `Flux.train!` with an optimiser and the data we prepared. We set up a callback, throttled to run at most once every 10 seconds, that prints the current loss and a sample of the model's output. You should see the model pick up basic words and grammar fairly quickly.

In [9]:
opt = ADAM(params(m))
evalcb = function ()
  print_with_color(:blue, "Loss is $(loss(Xs[5], Ys[5]))\n")
  println(sample(deepcopy(m), alphabet, 500))
end
Flux.train!(loss, zip(Xs, Ys), opt, cb = throttle(evalcb, 10))

Loss is 210.36623f0 (tracked)
wZ3ALBRkDb,icvZxuGM,pXj,qJKkkXmonFyL$XR;GIKb],iuDPPNxhuBEbqqk3kR]itgOzVs,
G33FcElp
Im
bhgMtGYYHSGVMmMeDjjXQTY&RhBzhvidIigxj?$ch,dq!Z!_igDD IuwxGAOreKAx[eg;Hip!k:Qqs!bfJ.VIaVk[rUKb!CdRh[jEMyQrDUgw:W.NoSHE]zYF$C3&SGH]
VYBQA_Wiy'fTqHSa!OYgr_CWKB!R-zOI$sP;H,yAJ]a?Chc[gcse]d_T&a,xGiHH]
OcFny'__o&juwcNx!IjFz?eP:Ai-B]KVrhIDzy'I,Jvv.!dwO&xAbO_s
_.,K
UhPx!aWEJ[KmpFlXI,QFzt_DLQbVIalbYG!&[kWjN,NwEg NC,j.vqmtJhtwPO?&PBkXltRGsFgPM;pmQ&DnpioGu?wC$X3Vvtu'qFqfkfb$TCMXxfWd-RwB:!]h_gLw;CCRg,.W?tGnchN-opotypujL-UK
Loss is 133.7283f0 (tracked)
d [g'!,

TI:B:Sd inios.
Ts
d
rrntner Benys ohut may t borl.h
H:
LothBle eartitk pf, rbaoh nn gtraaclar;r hesse ol ly ltty
Kom fanne nl Nef elnlen retat whots chlowomos cotiit taas pfesghedw
nd anntg taunt boote so dtet sy,acit Poid mosthed!
IAOBPboA meatr engirs ados con sen with rn clerglaay soisarahe thun wharls latadutii, iw afleernghon
wof
.enm gig apercnlins wenf are ouod togos yhd.r

hho s,ton,.
iupfe lhe the t

LoadError: InterruptException:
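
The `throttle(evalcb, 10)` wrapper in the training call is what keeps the callback from running on every batch. A minimal sketch of the idea in plain Julia follows; `throttle_sketch` is a hypothetical simplification, and Flux's own `throttle` may differ in its details.

```julia
# Wrap `f` so that it runs at most once every `timeout` seconds.
function throttle_sketch(f, timeout)
    last_call = -Inf
    return function (args...)
        t = time()
        if t - last_call >= timeout
            last_call = t
            return f(args...)
        end
        return nothing   # call skipped: still inside the quiet period
    end
end

calls = Ref(0)
report = throttle_sketch(() -> calls[] += 1, 60)
report(); report(); report()   # only the first call goes through
println(calls[])               # 1
```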