In [1]:
using Knet

include("../src/data.jl")
include("../src/model.jl")
include("../src/train.jl")

└ @ CuArrays /kuacc/users/asafaya19/.julia/packages/CuArrays/A6GUx/src/CuArrays.jl:122


In [None]:
datadir = "../data/enwik8"
jld2dir = "../jld2/enwik8.jld2"
BATCHSIZE = 16

if !isfile(jld2dir)
    println("Reading data from directory: $datadir")
    println("Setting batch size to $BATCHSIZE")
    vocab = Vocab("$datadir/train.txt")
    trainfile = TextReader("$datadir/train.txt", vocab)
    validfile = TextReader("$datadir/valid.txt", vocab)
    testfile = TextReader("$datadir/test.txt", vocab)
    dtrn = TextData(trainfile, batchsize=BATCHSIZE)
    ddev = TextData(validfile, batchsize=BATCHSIZE)
    dtst = TextData(testfile, batchsize=BATCHSIZE)
    println("Saving data from $jld2dir")
    Knet.save(jld2dir, "dtrn", dtrn, "dtst", dtst, "ddev", ddev)
else 
    println("Loading data from $jld2dir")
    (dtrn, dtst, ddev) = Knet.load(jld2dir, "dtrn", "dtst", "ddev")
    vocab = dtrn.src.vocab
    if dtrn.batchsize != BATCHSIZE
        changebatchsize!(dtrn, BATCHSIZE)
        changebatchsize!(ddev, BATCHSIZE)
        changebatchsize!(dtst, BATCHSIZE)
    end;
end

In [3]:
@info "Initializing and Training Language Model"
epochs, em_size, hidden_size, layers = 2, 1024, 4096, 4
println("embedding size: ", em_size)
println("hidden size: ", hidden_size)
println("layers: ", layers)

println("Collecting training data...")
println("epochs: ", epochs)
ctrn = collect(dtst) # -> dtrn
trn = collect(flatten(ctrn for i in 1:epochs))
trnmini = ctrn[1:20]
dev = collect(ddev);

model = SHARNN(em_size, hidden_size, vocab, layers);
nothing

┌ Info: Initializing and Training Language Model
└ @ Main In[3]:1


embedding size: 1024
hidden size: 4096
layers: 4
Collecting training data...
epochs: 2


In [38]:
function initopt!(model; lr=0.001)
    for par in params(model)
        par.opt = Adam(; lr=lr)
    end
end

In [None]:
@info "Starting training, total iteration no: $(length(trn))"
# model = train!(model, length(ctrn), trn, dev, trnmini)
initopt!(model, length(ctrn))
model = train!(model, trn, dev; report_iter=200)

┌ Info: Starting training, total iteration no: 1222
└ @ Main In[4]:1


Total iterations = 1222
21:40:51  ->  Dev set scores : (loss = 9.002581f0, ppl = 8124.022f0, bpc = 12.98797844842655)
21:41:19  ->  20 iteration: Training set scores : (loss = 6.38045f0, ppl = 590.1931f0, bpc = 9.205043244533986)
21:41:38  ->  40 iteration: Training set scores : (loss = 4.2261457f0, ppl = 68.45289f0, bpc = 6.097039507409844)
21:41:57  ->  60 iteration: Training set scores : (loss = 3.7057438f0, ppl = 40.680294f0, bpc = 5.346258188166097)
21:42:16  ->  80 iteration: Training set scores : (loss = 3.5031104f0, ppl = 33.218616f0, bpc = 5.053920014437685)
21:42:34  ->  100 iteration: Training set scores : (loss = 3.474502f0, ppl = 32.28175f0, bpc = 5.012646929953024)
21:42:53  ->  120 iteration: Training set scores : (loss = 3.4782512f0, ppl = 32.403008f0, bpc = 5.01805578432266)
21:43:11  ->  140 iteration: Training set scores : (loss = 3.5953724f0, ppl = 36.429264f0, bpc = 5.187025987072955)
21:43:30  ->  160 iteration: Training set scores : (loss = 3.60546f0, ppl = 36.79

In [5]:
@info "Finished training, Starting evaluation ..."
# trnloss = loss(model, dtrn);
# println("Training set scores:       ", report_lm(trnloss))
devloss = loss(model, ddev);
println("Development set scores:    ", report_lm(devloss))
testloss = loss(model, dtst);
println("Test set scores:           ", report_lm(testloss))

# @info "Generate text using the trained model"
# print(generate(model, start="United Nations ", maxlength=1024))

@info "Saving the model as model_x.jld2"
Knet.save("sharnn_first.jld2", "model", model);

┌ Info: Finished training, Starting evaluation ...
└ @ Main In[5]:1


Training set scores:       (loss = 2.3402822f0, ppl = 10.384167f0, bpc = 3.376313526769909)
Development set scores:    (loss = 2.3432722f0, ppl = 10.415262f0, bpc = 3.38062719561885)
Test set scores:           (loss = 2.3392775f0, ppl = 10.373739f0, bpc = 3.374864056988437)


┌ Info: Saving the model as model_x.jld2
└ @ Main In[5]:12
