In [22]:
using Knet

include("../src/data.jl")
include("../src/model.jl")

train! (generic function with 1 method)

In [2]:
# Build the batched train/dev/test datasets for enwik8, or reload them from a
# JLD2 cache when one already exists.
#
# Globals defined by this cell: `datadir`, `jld2dir`, `vocab`, `dtrn`, `ddev`,
# `dtst` (and `BATCHSIZE`, only when the cache has to be rebuilt).
datadir = "../data/enwik8"
# Despite its name, `jld2dir` is the path of the cache *file*.
jld2dir = "../jld2/enwik8.jld2"

if isfile(jld2dir)
    println("Loading data from $jld2dir")
    (dtrn, dtst, ddev) = Knet.load(jld2dir, "dtrn", "dtst", "ddev")
    # The vocabulary is not stored under its own key; recover it from the
    # reader held by the training data.
    vocab = dtrn.src.vocab
else
    BATCHSIZE = 64
    println("Reading data from directory: $datadir")
    println("Setting batch size to $BATCHSIZE")
    # The vocabulary is built from the training file only and shared by all
    # three readers.
    vocab = Vocab(joinpath(datadir, "train.txt"))
    trainfile = TextReader(joinpath(datadir, "train.txt"), vocab)
    validfile = TextReader(joinpath(datadir, "valid.txt"), vocab)
    testfile = TextReader(joinpath(datadir, "test.txt"), vocab)
    dtrn = TextData(trainfile, batchsize=BATCHSIZE)
    ddev = TextData(validfile, batchsize=BATCHSIZE)
    dtst = TextData(testfile, batchsize=BATCHSIZE)
    println("Saving data to $jld2dir")
    # Make sure the cache directory exists so the save cannot fail after the
    # preprocessing work above has already been done.
    mkpath(dirname(jld2dir))
    Knet.save(jld2dir, "dtrn", dtrn, "dtst", dtst, "ddev", ddev)
end

Loading data from ../jld2/enwik8.jld2


Vocab(Dict("54" => 67,"101" => 4,"41" => 52,"65" => 38,"168" => 126,"159" => 175,"228" => 183,"190" => 117,"227" => 96,"88" => 104…), ["<s>", "<unk>", "32", "101", "116", "97", "105", "111", "110", "114"  …  "210", "239", "211", "198", "212", "240", "205", "220", "222", "200"], 2, 1, split)

In [None]:
# Hyperparameters, training/validation batch lists, and model construction.
@info "Initializing and Training Language Model"
epochs, em_size, hidden_size, layers = 3, 256, 256, 2
println("embedding size: ", em_size)
println("hidden size: ", hidden_size)
println("layers: ", layers)

println("Collecting training data...")
println("epochs: ", epochs)
ctrn = collect(dtrn)
# One freshly shuffled pass over `ctrn` per epoch, concatenated into a single
# batch list. `shuffle!` permutes `ctrn` in place; `Iterators.flatten` is lazy,
# so each pass is fully consumed by `collect` before the next shuffle happens.
# NOTE: the name `trnx10` is historical -- the number of passes is `epochs`.
trnx10 = collect(Iterators.flatten(shuffle!(ctrn) for _ in 1:epochs))
# Small training subset (up to 20 batches, taken from the final shuffle order);
# clamped so a dataset with fewer than 20 batches does not raise a BoundsError.
trnmini = ctrn[1:min(20, length(ctrn))]
dev = collect(ddev)

model = SimpleLSTMModel(em_size, hidden_size, vocab; layers=layers, dropout=0.2)

In [42]:
# Run training and rebind `model` to the value returned by `train!`.
# `trnx10` is the concatenated multi-epoch batch list built earlier; `dev` and
# `trnmini` are presumably the evaluation sets behind the dev*/trn* metrics in
# the progress log -- TODO confirm against `train!` in ../src/model.jl.
model = train!(model, trnx10, dev, trnmini)


┣                    ┫ [0.02%, 1/4122, 00:14/15:40:51, 13.70s/i] (trnloss = (1.6159786f0,), trnppl = (5.0328107f0,), trnbpc = (2.331364310375221,), devloss = 1.6195295f0, devppl = 5.0507135f0, devbpc = 2.3364871575964714)
┣▎                   ┫ [1.29%, 53/4122, 00:57/01:14:32, 1.19i/s] (trnloss = (1.6132574f0,), trnppl = (5.019134f0,), trnbpc = (2.3274384624039754,), devloss = 1.6172383f0, devppl = 5.0391545f0, devbpc = 2.333181651046888)
┣▌                   ┫ [2.57%, 106/4122, 01:41/01:05:23, 1.22i/s] (trnloss = (1.6119295f0,), trnppl = (5.0124736f0,), trnbpc = (2.3255227476560147,), devloss = 1.6151806f0, devppl = 5.028796f0, devbpc = 2.330213058510345)
┣▊                   ┫ [3.86%, 159/4122, 02:25/01:02:33, 1.21i/s] (trnloss = (1.6103592f0,), trnppl = (5.0046086f0,), trnbpc = (2.323257220196199,), devloss = 1.613553f0, devppl = 5.020618f0, devbpc = 2.3278649793781154)
┣█                   ┫ [5.12%, 211/4122, 03:09/01:01:26, 1.18i/s] (trnloss = (1.6081607f0,), trnppl = (4.993618f0

┣█████████▎          ┫ [46.80%, 1929/4122, 27:06/57:55, 1.17i/s] (trnloss = (1.5613579f0,), trnppl = (4.7652874f0,), trnbpc = (2.2525632356110887,), devloss = 1.5651946f0, devppl = 4.7836056f0, devbpc = 2.2580984972291085)
┣█████████▌          ┫ [48.08%, 1982/4122, 27:50/57:53, 1.21i/s] (trnloss = (1.5605316f0,), trnppl = (4.7613516f0,), trnbpc = (2.2513712238579586,), devloss = 1.5642446f0, devppl = 4.7790637f0, devbpc = 2.2567279674843825)
┣█████████▊          ┫ [49.37%, 2035/4122, 28:34/57:52, 1.21i/s] (trnloss = (1.5586568f0,), trnppl = (4.752434f0,), trnbpc = (2.2486664527078393,), devloss = 1.5621947f0, devppl = 4.769277f0, devbpc = 2.2537705538201456)
┣██████████          ┫ [50.58%, 2085/4122, 29:17/57:54, 1.16i/s] (trnloss = (1.5573332f0,), trnppl = (4.7461476f0,), trnbpc = (2.2467569293353096,), devloss = 1.5611525f0, devppl = 4.764309f0, devbpc = 2.252266909503652)
┣██████████▎         ┫ [51.87%, 2138/4122, 30:01/57:52, 1.21i/s] (trnloss = (1.5548562f0,), trnppl = (4.7344055f

┣██████████████████▋ ┫ [93.47%, 3853/4122, 53:58/57:44, 1.20i/s] (trnloss = (3.8139915f0,), trnppl = (45.331017f0,), trnbpc = (5.502426690316769,), devloss = 3.801562f0, devppl = 44.771065f0, devbpc = 5.484494747241907)
┣██████████████████▉ ┫ [94.74%, 3905/4122, 54:42/57:44, 1.20i/s] (trnloss = (3.7671504f0,), trnppl = (43.256626f0,), trnbpc = (5.434849203347944,), devloss = 3.7554173f0, devppl = 42.75206f0, devbpc = 5.417921982919422)
┣███████████████████▏┫ [95.95%, 3955/4122, 55:25/57:45, 1.15i/s] (trnloss = (3.7207808f0,), trnppl = (41.296627f0,), trnbpc = (5.3679520797459315,), devloss = 3.7079496f0, devppl = 40.770126f0, devbpc = 5.3494405551376625)
┣███████████████████▍┫ [97.21%, 4007/4122, 56:08/57:45, 1.20i/s] (trnloss = (3.681657f0,), trnppl = (39.712147f0,), trnbpc = (5.311508405628663,), devloss = 3.6696758f0, devppl = 39.23918f0, devbpc = 5.294223117321046)
┣███████████████████▋┫ [98.47%, 4059/4122, 56:51/57:44, 1.20i/s] (trnloss = (3.643354f0,), trnppl = (38.21981f0,), trn

SimpleLSTMModel(Embed(P(KnetArray{Float32,2}(256,206))), LSTM(input=256,hidden=256,layers=2,dropout=0.2), Linear(P(KnetArray{Float32,2}(206,256))), 0.2, Vocab(Dict("54" => 67,"101" => 4,"41" => 52,"65" => 38,"168" => 126,"159" => 175,"228" => 183,"190" => 117,"227" => 96,"88" => 104…), ["<s>", "<unk>", "32", "101", "116", "97", "105", "111", "110", "114"  …  "210", "239", "211", "198", "212", "240", "205", "220", "222", "200"], 2, 1, split))

In [43]:
# Save the trained model to a JLD2 file under the key "model".
# NOTE(review): this path is relative to the current directory ("jld2/..."),
# while the dataset cache is read from "../jld2/..." -- confirm which location
# is intended.
Knet.save("jld2/baseline-256-2.jld2", "model", model)

In [44]:
# Evaluate the model on the test split and report the loss together with its
# derived metrics: perplexity (exp of the loss) and bits per symbol (the loss
# in nats divided by log(2)).
testloss = loss(model, dtst)
(
    testloss = testloss,
    testppl = exp.(testloss),
    testbpc = testloss ./ log(2),
)

(testloss = 1.5113528f0, testppl = 4.532859f0, testbpc = 2.1804211571057137)

In [45]:
# Evaluate the model on the validation split and report the loss together with
# its derived metrics: perplexity (exp of the loss) and bits per symbol (the
# loss in nats divided by log(2)).
devloss = loss(model, ddev)
(
    devloss = devloss,
    devppl = exp.(devloss),
    devbpc = devloss ./ log(2),
)

(devloss = 1.5301247f0, devppl = 4.6187525f0, devbpc = 2.2075032651370803)

In [49]:
# Sample text from the trained model, seeded with the prompt "Syria is";
# `maxlength` is passed through to `generate` as the length limit.
s = generate(model; start="Syria is", maxlength=1024)

"Syria is an intedrate past different [[Socalase]]. The Regure of The forespe a coroporfare development to Transtorashi enorys similar &quot;the evolutions of &quot;leaving.&quot;"