In [1]:
using Knet

include("../src/data.jl")
include("../src/model.jl")

train! (generic function with 1 method)

In [2]:
# Load the enwik8 train/valid/test splits, using a JLD2 cache when available.
#
# `datadir`  : directory holding the raw train.txt / valid.txt / test.txt files.
# `jld2dir`  : despite its name, this is the path of the cache *file*, not a directory.
#
# After this cell, `dtrn`, `ddev`, `dtst` and `vocab` are defined in either branch.
datadir = "../data/enwik8"
jld2dir = "../jld2/enwik8.jld2"

if !isfile(jld2dir)
    # No cache yet: build the vocabulary and datasets from the raw text files.
    BATCHSIZE = 64
    println("Reading data from directory: $datadir")
    println("Setting batch size to $BATCHSIZE")
    # The vocabulary is built from the training split only and shared by all readers.
    vocab = Vocab("$datadir/train.txt")
    trainfile = TextReader("$datadir/train.txt", vocab)
    validfile = TextReader("$datadir/valid.txt", vocab)
    testfile = TextReader("$datadir/test.txt", vocab)
    dtrn = TextData(trainfile, batchsize=BATCHSIZE)
    ddev = TextData(validfile, batchsize=BATCHSIZE)
    dtst = TextData(testfile, batchsize=BATCHSIZE)
    println("Saving data to $jld2dir")
    # Make sure the cache directory exists; otherwise the save below cannot create the file.
    mkpath(dirname(jld2dir))
    Knet.save(jld2dir, "dtrn", dtrn, "dtst", dtst, "ddev", ddev)
else
    # Cache hit: restore the three datasets in the same key order they are requested.
    println("Loading data from $jld2dir")
    (dtrn, dtst, ddev) = Knet.load(jld2dir, "dtrn", "dtst", "ddev")
    # The vocabulary is not stored separately; recover it from the training data's reader.
    vocab = dtrn.src.vocab
end

Loading data from ../jld2/enwik8.jld2


Vocab(Dict("54" => 67,"101" => 4,"41" => 52,"65" => 38,"168" => 126,"159" => 175,"228" => 183,"190" => 117,"227" => 96,"88" => 104…), ["<s>", "<unk>", "32", "101", "116", "97", "105", "111", "110", "114"  …  "210", "239", "211", "198", "212", "240", "205", "220", "222", "200"], 2, 1, split)

In [None]:
@info "Initializing and Training Language Model"

# Hyperparameters for this run, one per line.
epochs = 5
em_size = 1024
hidden_size = 1024
layers = 2
println("embedding size: ", em_size)
println("hidden size: ", hidden_size)
println("layers: ", layers)

println("Collecting training data...")
println("epochs: ", epochs)

# Materialize the training data once so it can be reshuffled in place.
ctrn = collect(dtrn)

# `ctrn` is shuffled in place once per epoch and the passes are concatenated.
# NOTE(review): the name `trnx10` no longer matches the epoch count (5); it is kept
# because a later cell refers to it.
trnx10 = collect(flatten(shuffle!(ctrn) for _ in 1:epochs))

# Taken after the shuffles above, i.e. the first 20 items of the final shuffled order.
trnmini = ctrn[1:20]
dev = collect(ddev)

model = SimpleLSTMModel(em_size, hidden_size, vocab; layers=layers, dropout=0.2)

In [None]:
# Set the RNN's `h` and `c` fields to 0 before training.
# NOTE(review): presumably Knet's convention for "start from zero state and keep the
# state between calls" — confirm against the Knet RNN documentation.
model.rnn.h = 0
model.rnn.c = 0

# Train on the concatenated epochs; `dev` and `trnmini` are passed along to `train!`.
model = train!(model, trnx10, dev, trnmini)

In [16]:
# Evaluate the model on the test split and report the loss with two derived metrics.
testloss = loss(model, dtst)
(
    testloss = testloss,
    testppl = exp.(testloss),        # perplexity: exponential of the loss
    testbpc = testloss ./ log(2),    # loss divided by ln 2 (bits, assuming loss is in nats)
)

(testloss = 1.0396284f0, testppl = 2.8281658f0, testbpc = 1.4998667175673344)

In [17]:
# Evaluate the model on the validation split and report the loss with two derived metrics.
devloss = loss(model, ddev)
(
    devloss = devloss,
    devppl = exp.(devloss),        # perplexity: exponential of the loss
    devbpc = devloss ./ log(2),    # loss divided by ln 2 (bits, assuming loss is in nats)
)

(devloss = 1.0334321f0, devppl = 2.810696f0, devbpc = 1.4909274033407873)

In [27]:
# Generate text from the trained model, seeded with a fixed prompt, and print it.
generate(model, start="United Nations ", maxlength=1024) |> print

United Nations higher organisations (as an researcher which shows all the non-sudden bishops, and not national broken next topics); unjustly violated both the Jews and the Family resistants terms allowing them to conspiracy high-executive following thy national [[Social Democratic Purpose (Special Education Technology)|income]] - indirect to acknowledge regarding the theoretical advisory influence: