In [1]:
using Knet

include("../src/data.jl")
include("../src/model.jl")

train! (generic function with 1 method)

In [2]:
# Directory holding the raw enwik8 splits (train.txt / valid.txt / test.txt).
datadir = "../data/enwik8"
# Path of the JLD2 cache FILE (despite the variable name) that stores the batched
# train/dev/test iterators, so the slow text preprocessing only runs once.
jld2dir = "../jld2/enwik8.jld2"

if !isfile(jld2dir)
    BATCHSIZE = 64
    println("Reading data from directory: $datadir")
    println("Setting batch size to $BATCHSIZE")
    # The vocabulary is built from the training split only and shared by all three
    # readers. NOTE(review): presumably unseen symbols in valid/test map to "<unk>";
    # confirm in Vocab/TextReader (../src/data.jl).
    vocab = Vocab("$datadir/train.txt")
    trainfile = TextReader("$datadir/train.txt", vocab)
    validfile = TextReader("$datadir/valid.txt", vocab)
    testfile = TextReader("$datadir/test.txt", vocab)
    dtrn = TextData(trainfile, batchsize=BATCHSIZE)
    ddev = TextData(validfile, batchsize=BATCHSIZE)
    dtst = TextData(testfile, batchsize=BATCHSIZE)
    # On a fresh checkout the cache directory may not exist yet, and saving into a
    # missing directory fails, so create it first (a no-op if it already exists).
    mkpath(dirname(jld2dir))
    println("Saving data to $jld2dir")
    Knet.save(jld2dir, "dtrn", dtrn, "dtst", dtst, "ddev", ddev)
else
    println("Loading data from $jld2dir")
    (dtrn, dtst, ddev) = Knet.load(jld2dir, "dtrn", "dtst", "ddev")
    # The vocabulary is not cached separately; recover it from the training
    # data's underlying reader.
    vocab = dtrn.src.vocab
end

Loading data from ../jld2/enwik8.jld2


Vocab(Dict("54" => 67,"101" => 4,"41" => 52,"65" => 38,"168" => 126,"159" => 175,"228" => 183,"190" => 117,"227" => 96,"88" => 104…), ["<s>", "<unk>", "32", "101", "116", "97", "105", "111", "110", "114"  …  "210", "239", "211", "198", "212", "240", "205", "220", "222", "200"], 2, 1, split)

In [3]:
@info "Initializing and Training Language Model"
# Hyperparameters: number of passes over the training set, embedding width,
# LSTM hidden width, and number of stacked LSTM layers.
epochs, em_size, hidden_size, layers = 10, 512, 512, 3
println("embedding size: ", em_size)
println("hidden size: ", hidden_size)
println("layers: ", layers)

println("Collecting training data...")
println("epochs: ", epochs)
# Materialize the training batches once so they can be shuffled in place.
ctrn = collect(dtrn)
# Concatenate `epochs` passes over the training batches into one stream.
# Assuming `flatten` is `Base.Iterators.flatten`, each `shuffle!(ctrn)` is pulled
# lazily, only after the previous pass has been copied out by `collect`. Each
# epoch therefore gets its own batch order even though `ctrn` is reused.
trnx10 = collect(flatten(shuffle!(ctrn) for i in 1:epochs))
# Small subset of training batches passed to `train!`. NOTE(review): presumably
# used to report a cheap training loss; confirm in ../src/model.jl. The slice is
# clamped so that a dataset with fewer than 20 batches does not throw a BoundsError.
trnmini = ctrn[1:min(20, length(ctrn))]
dev = collect(ddev)

model = SimpleLSTMModel(em_size, hidden_size, vocab; layers=layers, dropout=0.2)

┌ Info: Initializing and Training Language Model
└ @ Main In[3]:1


embedding size: 512
hidden size: 512
layers: 3
Collecting training data...
epochs: 10


SimpleLSTMModel(Embed(P(KnetArray{Float32,2}(512,206))), LSTM(input=512,hidden=512,layers=3,dropout=0.2), Linear(P(KnetArray{Float32,2}(206,512))), 0.2, Vocab(Dict("54" => 67,"101" => 4,"41" => 52,"65" => 38,"168" => 126,"159" => 175,"228" => 183,"190" => 117,"227" => 96,"88" => 104…), ["<s>", "<unk>", "32", "101", "116", "97", "105", "111", "110", "114"  …  "210", "239", "211", "198", "212", "240", "205", "220", "222", "200"], 2, 1, split))

In [None]:
# Train over the pre-shuffled multi-epoch batch stream `trnx10`, periodically
# reporting loss/perplexity/bits-per-char on `dev` (and, presumably, on the
# `trnmini` subset; see `train!` in ../src/model.jl). The result is rebound to `model`.
model = train!(model, trnx10, dev, trnmini)


┣                    ┫ [0.01%, 1/13740, 00:58/219:55:57, 57.62s/i] (trnloss = (4.9918785f0,), trnppl = (147.2127f0,), trnbpc = (7.201758370406836,), devloss = 4.987161f0, devppl = 146.51988f0, devbpc = 7.194952672946893)
┣                    ┫ [0.11%, 15/13740, 02:17/34:47:14, 5.65s/i] (trnloss = (3.5407028f0,), trnppl = (34.49115f0,), trnbpc = (5.108154399421969,), devloss = 3.5379148f0, devppl = 34.395123f0, devbpc = 5.10413206918359)
┣                    ┫ [0.20%, 28/13740, 03:35/29:15:56, 6.00s/i] (trnloss = (3.5248313f0,), trnppl = (33.948048f0,), trnbpc = (5.0852566292860955,), devloss = 3.523438f0, devppl = 33.90078f0, devbpc = 5.08324649606281)
┣                    ┫ [0.30%, 41/13740, 04:53/27:14:44, 6.00s/i] (trnloss = (3.5246823f0,), trnppl = (33.94299f0,), trnbpc = (5.085041650972516,), devloss = 3.5227232f0, devppl = 33.876556f0, devbpc = 5.082215288088237)
┣                    ┫ [0.40%, 55/13740, 06:13/25:52:09, 5.72s/i] (trnloss = (3.522422f0,), trnppl = (33.866356f0,), 

In [None]:
# Persist the trained model under the key "model".
# NOTE(review): "512-3" is hard-coded and presumably encodes the 512-wide
# embedding/hidden size and 3 layers set above. Keep it in sync if those change.
Knet.save("../jld2/baseline-512-3.jld2", "model", model)

In [None]:
# Score the held-out test split and report the same loss three ways:
# the raw (natural-log) loss, perplexity = exp(loss), and bits-per-char = loss / ln 2.
testloss = loss(model, dtst)
let nats = testloss
    (testloss = nats, testppl = exp.(nats), testbpc = nats ./ log(2))
end

In [None]:
# Score the validation split and report the same loss three ways:
# the raw (natural-log) loss, perplexity = exp(loss), and bits-per-char = loss / ln 2.
devloss = loss(model, ddev)
let nats = devloss
    (devloss = nats, devppl = exp.(nats), devbpc = nats ./ log(2))
end

In [None]:
# Sample text from the trained model, seeded with the prompt "Syria is" and capped
# at `maxlength=1024` tokens. The sampling strategy lives in `generate` (../src/model.jl).
s = generate(model, start="Syria is", maxlength=1024)