In [21]:
using Flux
using Flux: onehot, argmax, chunk, batchseq, throttle, crossentropy
using FluxJS
using StatsBase: wsample
using Base.Iterators: partition
using BSON: @save, @load

In [3]:
text = collect(readstring("hp.txt"))
alphabet = [unique(text)..., '_']
text = map(ch -> onehot(ch, alphabet), text)
stop = onehot('_', alphabet)
N = length(alphabet)
seqlen = 50
nbatch = 25

Xs = collect(partition(batchseq(chunk(text, nbatch), stop), seqlen))
Ys = collect(partition(batchseq(chunk(text[2:end], nbatch), stop), seqlen));

In [4]:
m = Chain(
  LSTM(N, 128),
  LSTM(128, 128),
  Dense(128, N),
  softmax)

function loss(xs, ys)
  l = sum(crossentropy.(m.(xs), ys))
  Flux.truncate!(m)
  return l
end

opt = ADAM(params(m), 0.01)
function evalcb() 
    @show loss(Xs[5], Ys[5])
    @save "model-checkpoint.bson" m
end

evalcb (generic function with 1 method)

In [5]:
Flux.train!(loss, zip(Xs, Ys), opt,
            cb = throttle(evalcb, 30))

loss(Xs[5], Ys[5]) = 222.3757202685765 (tracked)
loss(Xs[5], Ys[5]) = 157.21890909493126 (tracked)
loss(Xs[5], Ys[5]) = 133.8549126696383 (tracked)
loss(Xs[5], Ys[5]) = 122.90663398381166 (tracked)
loss(Xs[5], Ys[5]) = 114.47700039986553 (tracked)
loss(Xs[5], Ys[5]) = 108.84221941193296 (tracked)
loss(Xs[5], Ys[5]) = 103.49640418315056 (tracked)
loss(Xs[5], Ys[5]) = 101.6835895707162 (tracked)
loss(Xs[5], Ys[5]) = 98.85960913039885 (tracked)
loss(Xs[5], Ys[5]) = 96.9075684127779 (tracked)
loss(Xs[5], Ys[5]) = 93.2355921660098 (tracked)
loss(Xs[5], Ys[5]) = 92.01270261481325 (tracked)
loss(Xs[5], Ys[5]) = 90.57310722525109 (tracked)
loss(Xs[5], Ys[5]) = 90.61986117375184 (tracked)
loss(Xs[5], Ys[5]) = 91.63529131454425 (tracked)
loss(Xs[5], Ys[5]) = 88.16451110433556 (tracked)
loss(Xs[5], Ys[5]) = 86.72675765174564 (tracked)
loss(Xs[5], Ys[5]) = 86.06704838193778 (tracked)
loss(Xs[5], Ys[5]) = 86.41018876768626 (tracked)
loss(Xs[5], Ys[5]) = 85.57010398651063 (tracked)
loss(Xs[5], Ys[5]

In [6]:
function sample(m, alphabet, len; temp = 1)
  Flux.reset!(m)
  buf = IOBuffer()
  c = rand(alphabet)
  for i = 1:len
    write(buf, c)
    c = wsample(alphabet, m(onehot(c, alphabet)).data)
  end
  return String(take!(buf))
end

sample (generic function with 1 method)

In [7]:
sample(m, alphabet, 10000) |> println

'll do, sort of Wordt for red, when it abundering, becoming and importance for eye-parently. Students," said Ron, under corridor contentation and Hermione was viised to press stretched shining-blinked the disagress, frew to room jurk upon a greatle's enough was endedently had been people bade through them dispotton aidered over the brotthed and left. Hermione, dropped and family was
 frew in a high door, talking with Harry bly in mouds, still was degiturated. It thruel and left it assowth jordflouched. "Purresses?" Harry and them thicked gingress of blow were got his fingering -well, remembered vowprestly Umbridge was noise right gill to sprounded green stronged.
 "I know. . . ." Roiring dropped up, "Motter-Ome listent ?"
They sit mean, fomility whereen "whill we'll good it. . . .
 Perhaps the entoronce - I did you will place's means where's help nothing doing, she�s asleeply best come so only thirts of that winged around he? Hod during to checkure that!" snarled Harry had been demento

In [8]:
length(alphabet)

94

In [23]:
@code_js m(zeros(94))

LoadError: [91mUnsupported type Tuple{TrackedArray{…,Array{Float64,1}},TrackedArray{…,Array{Float64,1}}}[39m

In [22]:
@load "model-checkpoint.bson" m

In [27]:
m(zeros(94))

Tracked 94×25 Array{Float64,2}:
 2.90854e-5   0.000118049  …  0.000154269  8.2384e-6   7.9182e-5 
 0.682277     0.703952        0.0143805    0.108223    0.573244  
 0.00858001   0.0497          0.00681526   0.00204965  0.0210669 
 0.00553047   0.000614798     0.000309075  0.00459829  0.00732779
 0.0138385    0.00160243      0.00237333   0.0164756   0.00122341
 2.09908e-5   4.31597e-5   …  6.42426e-5   1.4833e-5   1.20529e-5
 0.0469966    0.0868281       0.00463252   0.00294652  0.0212775 
 0.0148687    0.0122372       0.771559     0.013287    0.0133423 
 0.0304297    0.00875302      0.00244044   0.0111998   0.0264861 
 0.00531429   0.00457857      0.0387241    0.0195691   0.0109529 
 0.000389545  0.00157066   …  0.000393897  0.0268009   0.00646937
 0.00472734   0.00437733      0.0144871    0.0101109   0.0295822 
 1.59775e-5   4.97304e-5      4.68176e-5   1.47801e-5  6.7845e-5 
 ⋮                         ⋱                                     
 7.82424e-6   4.48123e-6      9.83668e-7   9