In [1]:
using Flux
using Flux: onehot, argmax, chunk, batchseq, throttle, crossentropy
using FluxJS
using StatsBase: wsample
using Base.Iterators: partition
using BSON: @save, @load

In [2]:
# Load the corpus as a vector of characters and build the vocabulary from
# the characters it contains, appending '_' as the stop/padding symbol.
# NOTE(review): if '_' already occurs in hp.txt it will appear twice in
# `alphabet`, and real underscores become indistinguishable from padding.
text = collect(readstring("hp.txt"))
alphabet = vcat(unique(text), '_')
text = [onehot(ch, alphabet) for ch in text]
stop = onehot('_', alphabet)
N = length(alphabet)

# Split the one-hot stream into `nbatch` parallel chunks, pad them to a
# common length with `stop`, and cut the result into sequences of `seqlen`
# time steps. Targets (`Ys`) come from the stream shifted one character ahead.
seqlen = 25
nbatch = 25

Xs = collect(partition(batchseq(chunk(text, nbatch), stop), seqlen))
Ys = collect(partition(batchseq(chunk(text[2:end], nbatch), stop), seqlen));

In [3]:
# Character-level language model: two stacked 128-unit LSTM layers over
# one-hot characters, a dense projection back to the N-symbol alphabet,
# and a softmax turning the scores into a probability distribution.
m = Chain(LSTM(N, 128), LSTM(128, 128), Dense(128, N), softmax)

# Summed cross-entropy of the model over one sequence of time steps.
# `xs` and `ys` are matching sequences of inputs and one-hot targets.
# After the forward pass, `Flux.truncate!` cuts the gradient history of the
# recurrent state, so the next batch does not backpropagate into this one.
function loss(xs, ys)
  yhat = m.(xs)
  total = sum(crossentropy.(yhat, ys))
  Flux.truncate!(m)
  return total
end

# ADAM optimiser over all trainable parameters of `m`, step size 0.01.
opt = ADAM(params(m), 0.01)
# Progress callback for `Flux.train!`: prints the loss on the fixed batch
# (Xs[5], Ys[5]) and writes the current model `m` to model-checkpoint.bson.
# NOTE(review): Xs[5] is also a training batch, so the printed value is a
# training loss, not a held-out metric. Calling `loss` here also runs the
# stateful model on Xs[5] mid-training, which presumably perturbs the
# recurrent state carried into the next training batch — confirm this is
# acceptable.
function evalcb() 
    @show loss(Xs[5], Ys[5])
    @save "model-checkpoint.bson" m
end

evalcb (generic function with 1 method)

In [4]:
# One pass over all (input sequence, target sequence) pairs, invoking the
# progress callback at most once every 30 seconds.
Flux.train!(loss, zip(Xs, Ys), opt; cb = throttle(evalcb, 30))

loss(Xs[5], Ys[5]) = 110.34525112983819 (tracked)
loss(Xs[5], Ys[5]) = 70.46271348589207 (tracked)
loss(Xs[5], Ys[5]) = 58.37482945039203 (tracked)
loss(Xs[5], Ys[5]) = 54.50622371245402 (tracked)
loss(Xs[5], Ys[5]) = 53.03594862772823 (tracked)
loss(Xs[5], Ys[5]) = 51.023813475246484 (tracked)
loss(Xs[5], Ys[5]) = 49.59801286672625 (tracked)
loss(Xs[5], Ys[5]) = 47.309709431965985 (tracked)
loss(Xs[5], Ys[5]) = 46.87448627502052 (tracked)
loss(Xs[5], Ys[5]) = 46.11419481178562 (tracked)
loss(Xs[5], Ys[5]) = 45.80061502708199 (tracked)
loss(Xs[5], Ys[5]) = 46.65096380609584 (tracked)
loss(Xs[5], Ys[5]) = 45.9052636756591 (tracked)
loss(Xs[5], Ys[5]) = 43.76509137019723 (tracked)
loss(Xs[5], Ys[5]) = 43.19664619876508 (tracked)
loss(Xs[5], Ys[5]) = 45.05510449492588 (tracked)
loss(Xs[5], Ys[5]) = 45.88615437690778 (tracked)
loss(Xs[5], Ys[5]) = 44.9665657644777 (tracked)
loss(Xs[5], Ys[5]) = 43.66211622265231 (tracked)
loss(Xs[5], Ys[5]) = 42.294778595800615 (tracked)
loss(Xs[5], Ys[5])

In [5]:
"""
    sample(m, alphabet, len; temp = 1)

Generate a string of `len` characters from the character-level model `m`.

The model's state is reset with `Flux.reset!`, and a random character from
`alphabet` seeds the sequence. Each subsequent character is drawn from the
model's output for the previously emitted character. Because `m` is reset
first, it is presumably recurrent, so earlier characters also influence
the output through its state.

# Arguments
- `m`: model mapping a one-hot vector over `alphabet` to a tracked
  probability vector of the same length (its `.data` field is read).
- `alphabet`: vocabulary used both for one-hot encoding and for sampling.
- `len`: number of characters in the returned string.

# Keywords
- `temp`: sampling temperature; must be positive. The model's
  probabilities `p` are reweighted to `p .^ (1 / temp)` before sampling.
  Values below 1 sharpen the distribution, values above 1 flatten it, and
  `temp = 1` samples from the model's probabilities unchanged.

# Throws
- `ArgumentError` if `temp` is not positive.
"""
function sample(m, alphabet, len; temp = 1)
  temp > 0 || throw(ArgumentError("temp must be positive, got $temp"))
  Flux.reset!(m)
  buf = IOBuffer()
  c = rand(alphabet)
  for i = 1:len
    write(buf, c)
    probs = m(onehot(c, alphabet)).data
    # p^(1/T) equals softmax(log(p) / T) up to normalisation, and wsample
    # normalises its weights itself. At T == 1 the raw probabilities are
    # used directly, so the default path is exactly the original behaviour.
    weights = temp == 1 ? probs : probs .^ (1 / temp)
    c = wsample(alphabet, weights)
  end
  return String(take!(buf))
end

sample (generic function with 1 method)

In [6]:
println(sample(m, alphabet, 10000))  # draw and print 10000 characters

\Ly!" said Ron nervous find Slughorn! Magnose fingers and halfway table broom agod darkcared, "You will think't then.s!" said Fudge�s face swarmed away in them shamphy not frightly. "I'll done to preparant moment, Ron?" said Fred, graps of room.
 "Hermione?"
 They were mild Ron was soft on before he and the room new. Doly into his hour, do. It could high great rock-than there was growing from the offediously. "What's caught wand in it in
a lurbe inside as you fither words to know Harry, She Wooks...rush�ley are or kneen," said Riney jinxes Curse.
 Yell contienish Fawcorded his eyessuagus on Kreacher abstandiss in dodmother of Krum.
 "What is wand says �� for time...VELTh HOGORMON" Sounne
ball of they just long silver. It's only a binner, even this chaken then; even burrowy shower - changuies burst off with rading his around to already stared was supposed to recognizine keeped in her surfectantly, cletching the prickle of the forcht. And Xeatered and all of Haved Ron. "Over, I would be 

In [8]:
alphabet |> length  # vocabulary size, including the appended '_' stop symbol

94

In [29]:
Xs[1][1]  # first time step of the first training sequence (one column per batch stream)

94×25 Flux.OneHotMatrix{Array{Flux.OneHotVector,1}}:
  true  false  false  false  false  …  false  false  false  false  false
 false  false  false  false  false     false   true  false  false  false
 false  false  false  false  false     false  false  false  false  false
 false  false  false  false  false     false  false  false  false  false
 false  false  false   true  false     false  false  false  false  false
 false  false  false  false  false  …  false  false  false  false  false
 false  false  false  false  false     false  false  false  false  false
 false  false  false  false  false     false  false   true  false  false
 false  false  false  false  false     false  false  false  false   true
 false  false  false  false  false     false  false  false  false  false
 false   true   true  false  false  …   true  false  false  false  false
 false  false  false  false  false     false  false  false  false  false
 false  false  false  false  false     false  false  false  false  fals

In [22]:
# Restore the model saved by the training callback into the global `m`.
@load "model-checkpoint.bson" m

In [7]:
# Probe the model with an all-zero input of alphabet length. `N` replaces
# the previously hard-coded 94, so the probe keeps working if the alphabet
# changes. The model is stateful, so the output also depends on its current
# recurrent state.
m(zeros(N))

Tracked 94-element Array{Float64,1}:
 0.00031122 
 0.000706177
 0.00344699 
 0.00177044 
 0.566277   
 4.0275e-5  
 0.00217819 
 0.0316184  
 0.0015179  
 0.0290025  
 0.0104681  
 0.00663156 
 9.02475e-5 
 ⋮          
 6.87722e-7 
 3.61225e-6 
 8.53602e-7 
 7.53192e-6 
 7.26536e-8 
 2.54951e-6 
 2.75955e-6 
 7.84611e-8 
 2.47369e-7 
 4.25817e-7 
 7.59581e-7 
 1.30766e-7 

In [8]:
# Attempt to trace the model to JavaScript with FluxJS, using an input of
# alphabet length `N` rather than a hard-coded 94.
# NOTE(review): as recorded below, tracing fails on a tuple of TrackedArrays,
# which appears to be the LSTM state — confirm FluxJS support for recurrent
# layers.
@code_js m(zeros(N))

LoadError: Unsupported type Tuple{TrackedArray{…,Array{Float64,1}},TrackedArray{…,Array{Float64,1}}}