In [2]:
using Flux
using Flux: onehot, argmax, chunk, batchseq, throttle, crossentropy
using StatsBase: wsample
using Base.Iterators: partition

We'll load text data from `input.txt` and split it into characters, then turn it into the numeric form needed by the model.

The model will take a sequence of characters, like "the do", and try to produce the next character (e.g. 't' or 'g' would be likely here but not 'd'). The target output sequence $Y$ is therefore just the input sequence $X$ offset by one, e.g.

* $X$: `the dog`
* $Y$: `he dog_`
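The offset can be seen on a plain string. This is a minimal sketch assuming Julia 1.x and only Base; the helper name `shift_targets` is hypothetical and not part of Flux.

```julia
# Hypothetical helper: pair each character with the one that follows it,
# padding the final target with a stop character.
function shift_targets(s::AbstractString; stop::Char = '_')
    x = collect(s)              # input characters
    y = vcat(x[2:end], stop)    # same sequence shifted left by one
    return x, y
end

x, y = shift_targets("the dog")
println(String(x))  # the dog
println(String(y))  # he dog_
```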

In [2]:
text = collect(readstring("input.txt"))
alphabet = [unique(text)..., '_']
text = map(ch -> onehot(ch, alphabet), text)
stop = onehot('_', alphabet)

N = length(alphabet)
seqlen = 50
nbatch = 50

Xs = collect(partition(batchseq(chunk(text, nbatch), stop), seqlen))
Ys = collect(partition(batchseq(chunk(text[2:end], nbatch), stop), seqlen));
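Roughly, `chunk` splits the text into `nbatch` parallel streams, `batchseq` lines them up time step by time step (padding short streams with `stop`), and `partition` cuts the result into windows of `seqlen` steps. The sketch below mimics that layout on a toy string in plain Julia 1.x. It assumes equal-length streams so no padding is needed, and it is not Flux's implementation.

```julia
using Base.Iterators: partition

chars  = collect("abcdefghijkl")   # 12 toy characters
nbatch = 3                         # number of parallel streams
seqlen = 2                         # time steps per training window

# Split the text into `nbatch` contiguous streams: abcd, efgh, ijkl.
streams = collect(partition(chars, cld(length(chars), nbatch)))

# One batch per time step: step t holds the t-th character of every stream.
steps = [[s[t] for s in streams] for t in 1:length(streams[1])]

# Group consecutive time steps into windows of `seqlen`.
windows = collect(partition(steps, seqlen))

println(windows[1][1])  # first time step of the first window
```

Each element of `windows` plays the role of one entry of `Xs`: a short sequence of batches, one batch per time step.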

Our model will be an LSTM layer followed by a dense layer and a softmax. It takes a single character as input and produces a probability distribution over the next character.

In [25]:
m = Chain(
  LSTM(N, 128),
  Dense(128, N),
  softmax)

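function loss(xs, ys)
  # Sum the cross-entropy over every time step in the window.
  l = sum(crossentropy.(m.(xs), ys))
  # Cut the gradient history so backpropagation stops at the window boundary.
  Flux.truncate!(m)
  return l
end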

loss (generic function with 1 method)
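To get a feel for the numbers this loss produces, here is a small sketch of cross-entropy against a one-hot target in plain Julia 1.x. The names `xent` and `pred` are hypothetical. The value 68 is the alphabet size that appears in the outputs further down. The sketch also assumes the per-step loss is averaged over the batch, which is our reading of `crossentropy` rather than something this notebook states.

```julia
# Cross-entropy between a predicted distribution and a one-hot target.
xent(pred, y) = -sum(y .* log.(pred))

N = 68                       # alphabet size (taken from the notebook output)
pred = fill(1 / N, N)        # an untrained model is close to uniform
y = zeros(N); y[1] = 1.0     # an arbitrary one-hot target

per_char   = xent(pred, y)   # equals log(N) for a uniform prediction
per_window = 50 * per_char   # summed over seqlen = 50 time steps

println(per_char)
println(per_window)
```

Under those assumptions an untrained model should score roughly 211 per window, which is in the same range as the first loss printed during training.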

The model accepts a one-hot-encoded character and returns a probability distribution over possible subsequent characters:

In [6]:
probabilities = m(onehot('a', alphabet)).data

68-element Array{Float64,1}:
 0.0151041
 0.0143515
 0.0134988
 0.0139868
 0.0121982
 0.0133474
 0.0154417
 0.0140419
 0.0140034
 0.0142141
 0.0145562
 0.0136849
 0.0144724
 ⋮        
 0.0157853
 0.0152509
 0.0155455
 0.0163665
 0.0156836
 0.0152913
 0.0145699
 0.0139844
 0.0156272
 0.0151008
 0.0146325
 0.0151086
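Every entry is close to 1/68 ≈ 0.0147 because the untrained network produces nearly equal scores for all characters. A minimal sketch of softmax in plain Julia 1.x shows the limiting case; `softmax_sketch` is a hypothetical stand-in, not the NNlib function.

```julia
# Numerically stable softmax: subtract the maximum before exponentiating.
function softmax_sketch(z)
    e = exp.(z .- maximum(z))
    return e ./ sum(e)
end

uniform = softmax_sketch(zeros(68))   # identical scores for all 68 characters
println(uniform[1])                   # ≈ 0.0147
println(sum(uniform))                 # ≈ 1.0
```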

We can sample from this distribution to see what the model thinks comes after 'a'.

In [7]:
wsample(alphabet, probabilities)

't': ASCII/Unicode U+0074 (category Ll: Letter, lowercase)
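Weighted sampling can be done by walking the cumulative distribution until it passes a uniform random number. The sketch below is plain Julia 1.x. `wsample_sketch` and its keyword `u` are hypothetical names used to make the draw reproducible, and this is not how StatsBase necessarily implements `wsample`.

```julia
# Return the first item whose cumulative probability exceeds `u`.
function wsample_sketch(items, probs; u = rand())
    c = 0.0
    for (item, p) in zip(items, probs)
        c += p
        u < c && return item
    end
    return last(items)   # guard against rounding when u is very close to 1
end

items = ['a', 'b', 'c']
probs = [0.2, 0.5, 0.3]
println(wsample_sketch(items, probs; u = 0.1))   # a
println(wsample_sketch(items, probs; u = 0.6))   # b
println(wsample_sketch(items, probs; u = 0.95))  # c
```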

If we feed the model's output back into itself, we can allow it to "dream" a sequence of characters.

In [7]:
function sample(m, alphabet, len; temp = 1)
  Flux.reset!(m)
  buf = IOBuffer()
  c = rand(alphabet)
  for i = 1:len
    write(buf, c)
    c = wsample(alphabet, m(onehot(c, alphabet)).data)
  end
  return String(take!(buf))
end

sample(m, alphabet, 100) |> println

sample (generic function with 1 method)
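Note that `sample` accepts a `temp` keyword but the body above never uses it. One common way to apply a sampling temperature to an existing probability vector is to raise it to the power `1 / temp` and renormalise, which is equivalent to dividing the pre-softmax scores by `temp`. The sketch below assumes Julia 1.x; `apply_temperature` is a hypothetical helper, not part of Flux.

```julia
# Sharpen (temp < 1) or flatten (temp > 1) a probability distribution.
function apply_temperature(probs, temp)
    scaled = probs .^ (1 / temp)
    return scaled ./ sum(scaled)
end

p = [0.5, 0.3, 0.2]
println(apply_temperature(p, 1.0))   # unchanged
println(apply_temperature(p, 0.5))   # more peaked on the likeliest entry
println(apply_temperature(p, 2.0))   # closer to uniform
```

Inside the sampling loop, the adjusted vector would be passed to `wsample` in place of the raw model output.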

Right now it's more-or-less random because the model hasn't seen any data. Let's fix that.

We just need to call `Flux.train!` with an optimiser and the data we prepared. We set up a callback so that every 30 seconds we see the current loss and a sample of the model's output. You should see it learning basic words and grammar fairly quickly.
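The rate limiting works by wrapping the callback so that it only runs if enough time has passed since its last run. A minimal sketch of that idea in plain Julia 1.x follows; `throttle_sketch` is a hypothetical stand-in and may differ from Flux's `throttle` in its details.

```julia
# Wrap `f` so that it runs at most once every `timeout` seconds.
function throttle_sketch(f, timeout)
    last_call = -Inf                 # time of the most recent real call
    return function (args...)
        t = time()
        if t - last_call >= timeout
            last_call = t
            return f(args...)
        end
        return nothing               # call skipped: too soon
    end
end

counter = Ref(0)
cb = throttle_sketch(() -> counter[] += 1, 60)
cb(); cb(); cb()                     # only the first call goes through
println(counter[])                   # 1
```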

In [26]:
opt = ADAM(params(m), 0.01)
evalcb = function ()
  print_with_color(:blue, "Loss is $(loss(Xs[5], Ys[5]).data[])\n")
  println(sample(deepcopy(m), alphabet, 100))
end
Flux.train!(loss, zip(Xs, Ys), opt, cb = throttle(evalcb, 30))

Loss is 202.79004518047788
_ZOWk!$GLljWEkUAxs
tfp?Vc-IEKK
[oAnZ
_hXEoB3,OcuSy&:DSfpzX]-cs_NnaeUWW
Rn?UlJ RpDY:pE!tMJZPt k'ew,Uf
Loss is 119.96210054291271
WTo nou losere'y in the Torey Bit lild, 'isdr, ton sy ut.

 co her.
Thopr his, Asllingemleut

I Buer
Loss is 109.20718742345751
  souzk seet, I way ma poneditles Good Pancente, ftar, Craod dading: Ina the blaotp ceaveek, thuer g
Loss is 102.65545199945302
:
I lond tide, feertwith beftince hatl ment whou not never:
En't
And spow. 'Dat then.

Frrest
Caosio
Loss is 99.49612308562551
;
BY Lord:
I we hem seove is sin I comnatirn
Fandshning is hil unenf-'Gind!
Heal, feairs leave she b
Loss is 97.98811768094829
eS
OTHaun:
A welter tolferanty Ore issile tought all and have that Quaszan, of hoor;
And we hans mon
Loss is 97.16236004296526
Uw.

DUISSOM:
There obeugatolesecre would, and Wool me there's yes quearied:
She ling-brince on now 
Loss is 92.539100060141
:
L

LoadError: InterruptException:

Bake for 2-3 hours at 180° for best results. Here's one I made earlier.

In [4]:
m_shakes, alpha_shakes = open(deserialize, "shakespeare-0.75.jls")

(Chain(Recur(LSTMCell(68, 128)), Recur(LSTMCell(128, 128)), Dense(128, 68), NNlib.softmax), ['F', 'i', 'r', 's', 't', ' ', 'C', 'z', 'e', 'n'  …  'K', 'Q', '&', 'Z', 'X', '3', '$', '[', ']', '_'])

In [10]:
sample(m_shakes, alpha_shakes, 1000) |> println

Bonelo!
I'll not hear us hand Humphrea, atten.

Clown:
Hark you, I chaste them, that he
blood with alls; to answer'd in our honour;
For not the great old mother were a fellow;
Kercite in the Lord of Trod with her countathat judgment;
It fear me in it to whom high down than your ring.

OLIVIA:
Not in the lovers in stomached till doth ashaw'd, I
doubt such common flower knows them less
'Lucew Philo's to gods, hath I your wisdom, step:
For I do: 'tween away.

DROMIO OF LORTESTEL:
What wat so,'
With Romans is I told the sweet shallows?

PAULINE:
Your heart: there quickly cold to uttermid! for so fits,
I suffer braggard thy hour and gates.

EDWARD:
You know that warly actage:
And 'Orther show.

Hostess:
Nothing I am confess, when we have ancient
Do the sweet?

Second Petone.

REGAN:
When you served them a scarf,
The world is house is dismerding to a wall: I make you,
At past of travered to an exile.

ULYSSES:
Deep and boot and favour cannot we perchance
As to our schole.

TITUS ANDRONICUS:
