In [1]:
using Pkg
# Activate the project environment used by this notebook.
# NOTE(review): this is an absolute, user-specific path — adjust it when running on another machine.
Pkg.activate("/home/asafaya19/develop-knet/Project.toml")

[32m[1mActivating[22m[39m environment at `~/develop-knet/Project.toml`


In [2]:
using Knet

# Load the project sources via relative paths.
# NOTE(review): presumably data.jl provides Vocab/TextReader/TextData and xmodel.jl provides
# XModel — confirm. train.jl is included last; the cell's displayed value is `train!`.
include("../src/data.jl")
include("../src/xmodel.jl")
include("../src/train.jl")

train! (generic function with 1 method)

In [3]:
# Prepare the enwik8 train/dev/test datasets. On the first run the text files are read and
# the resulting datasets are cached in a JLD2 file; later runs load that cache instead.
datadir = "../data/enwik8"
jld2dir = "../jld2/enwik8.jld2"  # path of the cache *file* (despite the `dir` suffix)

if !isfile(jld2dir)
    BATCHSIZE = 64
    println("Reading data from directory: $datadir")
    println("Setting batch size to $BATCHSIZE")
    # The vocabulary is constructed from train.txt and shared by all three readers.
    vocab = Vocab("$datadir/train.txt")
    trainfile = TextReader("$datadir/train.txt", vocab)
    validfile = TextReader("$datadir/valid.txt", vocab)
    testfile = TextReader("$datadir/test.txt", vocab)
    dtrn = TextData(trainfile, batchsize=BATCHSIZE)
    ddev = TextData(validfile, batchsize=BATCHSIZE)
    dtst = TextData(testfile, batchsize=BATCHSIZE)
    println("Saving data to $jld2dir")
    # Make sure the cache directory exists so the save cannot fail after all the
    # preprocessing work above has already been done.
    mkpath(dirname(jld2dir))
    Knet.save(jld2dir, "dtrn", dtrn, "dtst", dtst, "ddev", ddev)
else
    println("Loading data from $jld2dir")
    (dtrn, dtst, ddev) = Knet.load(jld2dir, "dtrn", "dtst", "ddev")
    # The vocabulary is not stored separately; recover it from the training reader.
    vocab = dtrn.src.vocab
end

Loading data from ../jld2/enwik8.jld2


Vocab(Dict("54" => 67,"101" => 4,"41" => 52,"65" => 38,"168" => 126,"159" => 175,"228" => 183,"190" => 117,"227" => 96,"88" => 104…), ["<s>", "<unk>", "32", "101", "116", "97", "105", "111", "110", "114"  …  "210", "239", "211", "198", "212", "240", "205", "220", "222", "200"], 2, 1, split)

In [4]:
@info "Initializing and Training Language Model"
# Hyperparameters for this run.
epochs = 20
em_size = 256
hidden_size = 256
layers = 1
println("embedding size: $em_size")
println("hidden size: $hidden_size")
println("layers: $layers")

println("Collecting training data...")
println("epochs: $epochs")
# Materialize the training batches so they can be reshuffled in place.
ctrn = collect(dtrn)
# Concatenate `epochs` passes over the training batches, reshuffling `ctrn` before each pass.
trnx10 = collect(flatten(shuffle!(ctrn) for _ in 1:epochs))
# First 20 batches of `ctrn` in the order left by the last shuffle.
# NOTE(review): presumably used by train! for periodic training-loss reports — confirm.
trnmini = ctrn[1:20]
dev = collect(ddev)

model = XModel(em_size, hidden_size, vocab; layers=layers, dropout=0.2)

┌ Info: Initializing and Training Language Model
└ @ Main In[4]:1


embedding size: 256
hidden size: 256
layers: 1
Collecting training data...
epochs: 20


XModel(Embed(P(KnetArray{Float32,2}(256,206))), LSTM(input=256,hidden=256,dropout=0.2), Boom(Linear(P(KnetArray{Float32,2}(1024,256)), P(KnetArray{Float32,1}(1024))), Linear(P(KnetArray{Float32,2}(256,1024)), P(KnetArray{Float32,1}(256))), 0.1, false, NNlib.gelu), Linear(P(KnetArray{Float32,2}(206,256)), P(KnetArray{Float32,1}(206))), 0.2, Vocab(Dict("54" => 67,"101" => 4,"41" => 52,"65" => 38,"168" => 126,"159" => 175,"228" => 183,"190" => 117,"227" => 96,"88" => 104…), ["<s>", "<unk>", "32", "101", "116", "97", "105", "111", "110", "114"  …  "210", "239", "211", "198", "212", "240", "205", "220", "222", "200"], 2, 1, split))

In [None]:
# Set the RNN's hidden (`h`) and cell (`c`) state fields to 0 before training.
# NOTE(review): presumably 0 makes the RNN start from zero state — confirm against Knet's RNN docs.
model.rnn.h = 0
model.rnn.c = 0
# Train on the concatenated shuffled batches; `dev` and `trnmini` are passed along as well.
model = train!(model, trnx10, dev, trnmini)


┣                    ┫ [0.00%, 1/27480, 00:08/64:22:30, 8.43s/i] (trnloss = (1.2666105f0,), trnppl = (3.5488036f0,), trnbpc = (1.8273326916997772,), devloss = 1.2705009f0, devppl = 3.5626366f0, devbpc = 1.832945345510685)
┣                    ┫ [0.59%, 162/27480, 00:46/02:09:36, 4.30i/s] (trnloss = (1.2673681f0,), trnppl = (3.551493f0,), trnbpc = (1.8284256414460103,), devloss = 1.2729763f0, devppl = 3.5714664f0, devbpc = 1.8365165652558513)
┣▏                   ┫ [1.15%, 317/27480, 01:23/02:00:15, 4.14i/s] (trnloss = (1.267088f0,), trnppl = (3.5504987f0,), trnbpc = (1.8280216541991336,), devloss = 1.2734703f0, devppl = 3.5732312f0, devbpc = 1.8372292613610266)
┣▎                   ┫ [1.72%, 473/27480, 02:01/01:56:48, 4.17i/s] (trnloss = (1.2689669f0,), trnppl = (3.5571759f0,), trnbpc = (1.8307322727593824,), devloss = 1.2743429f0, devppl = 3.5763507f0, devbpc = 1.8384881743653425)
┣▍                   ┫ [2.28%, 626/27480, 02:38/01:55:37, 4.09i/s] (trnloss = (1.2681619f0,), trnppl = (

┣████                ┫ [20.47%, 5626/27480, 22:38/01:50:34, 4.12i/s] (trnloss = (1.2629907f0,), trnppl = (3.5359807f0,), trnbpc = (1.822110438506328,), devloss = 1.2661588f0, devppl = 3.547201f0, devbpc = 1.826681049435658)
┣████▏               ┫ [21.04%, 5783/27480, 23:15/01:50:31, 4.21i/s] (trnloss = (1.2603744f0,), trnppl = (3.5267417f0,), trnbpc = (1.8183359352678419,), devloss = 1.2648045f0, devppl = 3.5424001f0, devbpc = 1.8247271545392056)
┣████▎               ┫ [21.62%, 5940/27480, 23:53/01:50:28, 4.22i/s] (trnloss = (1.261209f0,), trnppl = (3.5296865f0,), trnbpc = (1.8195399858065324,), devloss = 1.2675376f0, devppl = 3.5520952f0, devbpc = 1.8286702007755373)
┣████▍               ┫ [22.20%, 6100/27480, 24:30/01:50:23, 4.26i/s] (trnloss = (1.2607262f0,), trnppl = (3.5279827f0,), trnbpc = (1.8188434560705382,), devloss = 1.2667795f0, devppl = 3.5494034f0, devbpc = 1.8275765630987006)
┣████▌               ┫ [22.76%, 6254/27480, 25:08/01:50:25, 4.11i/s] (trnloss = (1.2611245f0,), 

In [9]:
# Evaluate the trained model on the test set and report the loss together with two
# quantities derived from it: exp(loss) as perplexity and loss / log(2) as bits per character.
testloss = loss(model, dtst)
NamedTuple{(:testloss, :testppl, :testbpc)}((testloss, exp.(testloss), testloss ./ log(2)))

(testloss = 1.2621013f0, testppl = 3.5328372f0, testbpc = 1.8208272759482407)

In [10]:
# Evaluate the trained model on the dev set and report the loss together with two
# quantities derived from it: exp(loss) as perplexity and loss / log(2) as bits per character.
devloss = loss(model, ddev)
NamedTuple{(:devloss, :devppl, :devbpc)}((devloss, exp.(devloss), devloss ./ log(2)))

(devloss = 1.270441f0, devppl = 3.5624235f0, devbpc = 1.8328590102199518)

In [27]:
# Print text generated by the model from the given start string (maxlength = 1024).
print(stdout, generate(model; start="United Nations ", maxlength=1024))

United Nations higher organisations (as an researcher which shows all the non-sudden bishops, and not national broken next topics); unjustly violated both the Jews and the Family resistants terms allowing them to conspiracy high-executive following thy national [[Social Democratic Purpose (Special Education Technology)|income]] - indirect to acknowledge regarding the theoretical advisory influence: