In [1]:
using Knet

include("../src/data.jl")
include("../src/model.jl")

train! (generic function with 1 method)

In [2]:
# Data locations and batching. Note that `jld2dir` is the path of a single JLD2
# *file* that caches the three TextData splits, not a directory.
datadir = "../data/enwik8"
jld2dir = "../jld2/enwik8.jld2"
BATCHSIZE = 64

if !isfile(jld2dir)
    # No cache yet: build the vocabulary from the training split only and map
    # all three splits through that same vocabulary.
    println("Reading data from directory: $datadir")
    println("Setting batch size to $BATCHSIZE")
    vocab = Vocab("$datadir/train.txt")
    trainfile = TextReader("$datadir/train.txt", vocab)
    validfile = TextReader("$datadir/valid.txt", vocab)
    testfile = TextReader("$datadir/test.txt", vocab)
    dtrn = TextData(trainfile, batchsize=BATCHSIZE)
    ddev = TextData(validfile, batchsize=BATCHSIZE)
    dtst = TextData(testfile, batchsize=BATCHSIZE)
    # Create the cache directory first; otherwise the very first run fails in
    # Knet.save after all of the (slow) preprocessing above has been done.
    let cachedir = dirname(jld2dir)
        isempty(cachedir) || mkpath(cachedir)
    end
    println("Saving data to $jld2dir")
    Knet.save(jld2dir, "dtrn", dtrn, "dtst", dtst, "ddev", ddev)
else
    println("Loading data from $jld2dir")
    (dtrn, dtst, ddev) = Knet.load(jld2dir, "dtrn", "dtst", "ddev")
    # The vocabulary is not saved under its own key, so recover it from the
    # training data (`src` is presumably the TextReader built above — confirm).
    vocab = dtrn.src.vocab
    # The cache may have been written with a different batch size. Check every
    # split individually instead of inferring the others from `dtrn`.
    # NOTE(review): assumes changeBatchSize mutates its argument in place, since
    # its return value is discarded.
    for d in (dtrn, dtst, ddev)
        if d.batchsize != BATCHSIZE
            changeBatchSize(d, BATCHSIZE)
        end
    end
end

Loading data from ../jld2/enwik8.jld2


In [3]:
@info "Initializing and Training Language Model"
# Model and training hyperparameters.
epochs, em_size, hidden_size, layers = 10, 512, 512, 3
println("embedding size: ", em_size)
println("hidden size: ", hidden_size)
println("layers: ", layers)

println("Collecting training data...")
println("epochs: ", epochs)
ctrn = collect(dtrn)
# One flat vector of minibatches covering `epochs` passes over the training
# set. The generator is lazy and `flatten` only advances it after the current
# pass is exhausted, so `shuffle!` reorders `ctrn` before each pass is copied
# out and every epoch gets its own order.
trnx10 = collect(flatten(shuffle!(ctrn) for _ in 1:epochs))
# Small training sample (taken from the last shuffle), presumably used by
# train! to report training loss. It is clamped to the number of available
# batches so a small dataset does not raise a BoundsError.
trnmini = ctrn[1:min(20, end)]
dev = collect(ddev)

model = SimpleLSTMModel(em_size, hidden_size, vocab; layers=layers, dropout=0.2)

┌ Info: Initializing and Training Language Model
└ @ Main In[3]:1


embedding size: 512
hidden size: 512
layers: 3
Collecting training data...
epochs: 10


SimpleLSTMModel(Embed(P(KnetArray{Float32,2}(512,206))), LSTM(input=512,hidden=512,layers=3,dropout=0.2), Linear(P(KnetArray{Float32,2}(206,512))), 0.2, Vocab(Dict("54" => 67,"101" => 4,"41" => 52,"65" => 38,"168" => 126,"159" => 175,"228" => 183,"190" => 117,"227" => 96,"88" => 104…), ["<s>", "<unk>", "32", "101", "116", "97", "105", "111", "110", "114"  …  "210", "239", "211", "198", "212", "240", "205", "220", "222", "200"], 2, 1, split))

In [None]:
# Run training. `trnx10` already concatenates all epochs, so presumably one
# pass over it is the whole run. `dev` and `trnmini` appear to feed the
# periodic trnloss/devloss/ppl/bpc reports in the progress log.
model = train!(
    model,
    trnx10,
    dev,
    trnmini,
)


┣                    ┫ [0.01%, 1/13740, 00:58/219:42:09, 57.56s/i] (trnloss = (4.8057485f0,), trnppl = (122.21093f0,), trnbpc = (6.933229474863871,), devloss = 4.799313f0, devppl = 121.42698f0, devbpc = 6.923945163439693)
┣                    ┫ [0.11%, 15/13740, 02:17/34:51:34, 5.67s/i] (trnloss = (3.5621734f0,), trnppl = (35.239704f0,), trnbpc = (5.1391298507035685,), devloss = 3.5359335f0, devppl = 34.327045f0, devbpc = 5.10127371752625)
┣                    ┫ [0.20%, 28/13740, 03:37/29:30:55, 6.12s/i] (trnloss = (3.5479896f0,), trnppl = (34.7434f0,), trnbpc = (5.118667010938609,), devloss = 3.523046f0, devppl = 33.887493f0, devbpc = 5.082681017106774)
┣                    ┫ [0.30%, 41/13740, 04:56/27:30:31, 6.08s/i] (trnloss = (3.544656f0,), trnppl = (34.627773f0,), trnbpc = (5.113857688089882,), devloss = 3.519654f0, devppl = 33.772743f0, devbpc = 5.0777874227591235)
┣                    ┫ [0.39%, 54/13740, 06:14/26:26:00, 6.04s/i] (trnloss = (3.5421586f0,), trnppl = (34.5414f0,),

┣▋                   ┫ [3.46%, 475/13740, 48:13/23:14:40, 6.04s/i] (trnloss = (3.4204636f0,), trnppl = (30.58359f0,), trnbpc = (4.934685818455706,), devloss = 3.3990023f0, devppl = 29.93422f0, devbpc = 4.903723781820875)
┣▋                   ┫ [3.55%, 488/13740, 49:31/23:14:14, 6.02s/i] (trnloss = (3.4069054f0,), trnppl = (30.17173f0,), trnbpc = (4.915125543642476,), devloss = 3.385483f0, devppl = 29.532253f0, devbpc = 4.884219573351825)
┣▋                   ┫ [3.65%, 501/13740, 50:49/23:13:52, 6.03s/i] (trnloss = (3.3886914f0,), trnppl = (29.627157f0,), trnbpc = (4.888848314417147,), devloss = 3.3670402f0, devppl = 28.992586f0, devbpc = 4.85761213743684)
┣▋                   ┫ [3.74%, 514/13740, 52:08/23:13:44, 6.06s/i] (trnloss = (3.3785129f0,), trnppl = (29.327126f0,), trnbpc = (4.874163747755877,), devloss = 3.3575373f0, devppl = 28.718378f0, devbpc = 4.84390236844066)
┣▊                   ┫ [3.84%, 527/13740, 53:27/23:13:41, 6.08s/i] (trnloss = (3.3668869f0,), trnppl = (28.988142f

┣█▌                  ┫ [7.66%, 1052/13740, 01:45:59/23:04:16, 6.04s/i] (trnloss = (2.9310327f0,), trnppl = (18.746979f0,), trnbpc = (4.228586279836721,), devloss = 2.9222708f0, devppl = 18.583439f0, devbpc = 4.215945554998305)
┣█▌                  ┫ [7.76%, 1066/13740, 01:47:18/23:03:04, 5.65s/i] (trnloss = (2.9247718f0,), trnppl = (18.629974f0,), trnbpc = (4.2195537510134065,), devloss = 2.916136f0, devppl = 18.469782f0, devbpc = 4.207094983819602)
┣█▌                  ┫ [7.86%, 1080/13740, 01:48:38/23:02:05, 5.71s/i] (trnloss = (2.9129539f0,), trnppl = (18.411102f0,), trnbpc = (4.202504078937471,), devloss = 2.9038196f0, devppl = 18.243694f0, devbpc = 4.189326080297759)
┣█▌                  ┫ [7.96%, 1094/13740, 01:49:58/23:01:08, 5.71s/i] (trnloss = (2.9035847f0,), trnppl = (18.239412f0,), trnbpc = (4.18898727447556,), devloss = 2.8958213f0, devppl = 18.09836f0, devbpc = 4.17778707632077)
┣█▌                  ┫ [8.06%, 1107/13740, 01:51:17/23:01:21, 6.11s/i] (trnloss = (2.8985724f0,

┣██▏                 ┫ [11.19%, 1538/13740, 02:33:40/22:52:51, 5.61s/i] (trnloss = (2.740171f0,), trnppl = (15.489633f0,), trnbpc = (3.9532310489157094,), devloss = 2.7390118f0, devppl = 15.471688f0, devbpc = 3.9515586896187194)
┣██▎                 ┫ [11.29%, 1551/13740, 02:34:58/22:52:53, 6.02s/i] (trnloss = (2.7333798f0,), trnppl = (15.384797f0,), trnbpc = (3.943433541261357,), devloss = 2.7313597f0, devppl = 15.353749f0, devbpc = 3.9405191232598358)
┣██▎                 ┫ [11.39%, 1565/13740, 02:36:19/22:52:20, 5.72s/i] (trnloss = (2.7225819f0,), trnppl = (15.219566f0,), trnbpc = (3.9278553527462035,), devloss = 2.7220638f0, devppl = 15.211683f0, devbpc = 3.9271079161455535)
┣██▎                 ┫ [11.48%, 1578/13740, 02:37:37/22:52:25, 6.04s/i] (trnloss = (2.7216897f0,), trnppl = (15.205994f0,), trnbpc = (3.926568234587146,), devloss = 2.7221475f0, devppl = 15.212956f0, devbpc = 3.9272286479664595)
┣██▎                 ┫ [11.58%, 1591/13740, 02:38:55/22:52:28, 6.02s/i] (trnloss = 

┣██▉                 ┫ [14.64%, 2011/13740, 03:20:03/22:46:48, 5.73s/i] (trnloss = (2.6256804f0,), trnppl = (13.813971f0,), trnbpc = (3.788056159304654,), devloss = 2.6261218f0, devppl = 13.820068f0, devbpc = 3.788692839078148)
┣██▉                 ┫ [14.74%, 2025/13740, 03:21:23/22:46:28, 5.76s/i] (trnloss = (2.6003652f0,), trnppl = (13.468656f0,), trnbpc = (3.751533923567428,), devloss = 2.6019187f0, devppl = 13.489595f0, devbpc = 3.753775201473472)
┣██▉                 ┫ [14.84%, 2039/13740, 03:22:44/22:46:07, 5.75s/i] (trnloss = (2.5840354f0,), trnppl = (13.250502f0,), trnbpc = (3.727975052121637,), devloss = 2.5868287f0, devppl = 13.287566f0, devbpc = 3.732004949596654)

In [None]:
# Persist the trained model under the key "model".
let modelfile = "../jld2/baseline-512-3.jld2"
    Knet.save(modelfile, "model", model)
end

In [None]:
# Evaluate on the held-out test split. Perplexity is exp(loss). Dividing by
# ln 2 converts the loss to bits, assuming `loss` is a mean NLL in nats.
testloss = loss(model, dtst)
let nats = testloss
    (testloss = nats, testppl = exp(nats), testbpc = nats / log(2))
end

In [None]:
# Evaluate on the validation split. Perplexity is exp(loss). Dividing by
# ln 2 converts the loss to bits, assuming `loss` is a mean NLL in nats.
devloss = loss(model, ddev)
let nats = devloss
    (devloss = nats, devppl = exp(nats), devbpc = nats / log(2))
end

In [None]:
# Sample text from the model, primed with a prompt. `maxlength` presumably caps
# the length of the generated sequence — confirm against generate's definition.
s = generate(model; start = "Syria is", maxlength = 1024)