In [1]:
using Knet

include("../src/data.jl")
include("../src/model.jl")

train! (generic function with 1 method)

In [2]:
# Build the enwik8 train/dev/test batch iterators once and cache them on disk;
# later runs reload the cache instead of re-reading the raw text files.
datadir = "../data/enwik8"
jld2dir = "../jld2/enwik8.jld2"  # NOTE: despite the name, this is the cache *file* path

if !isfile(jld2dir)
    BATCHSIZE = 64
    println("Reading data from directory: $datadir")
    println("Setting batch size to $BATCHSIZE")
    # The vocabulary is built from the training split only and shared by all readers.
    vocab = Vocab("$datadir/train.txt")
    trainfile = TextReader("$datadir/train.txt", vocab)
    validfile = TextReader("$datadir/valid.txt", vocab)
    testfile = TextReader("$datadir/test.txt", vocab)
    dtrn = TextData(trainfile, batchsize=BATCHSIZE)
    ddev = TextData(validfile, batchsize=BATCHSIZE)
    dtst = TextData(testfile, batchsize=BATCHSIZE)
    # Make sure the cache directory exists so the save cannot fail after all the
    # preprocessing work above has already been done.
    mkpath(dirname(jld2dir))
    println("Saving data to $jld2dir")
    Knet.save(jld2dir, "dtrn", dtrn, "dtst", dtst, "ddev", ddev)
else
    println("Loading data from $jld2dir")
    (dtrn, dtst, ddev) = Knet.load(jld2dir, "dtrn", "dtst", "ddev")
    # Recover the vocabulary from the cached training data
    # (assumes `src` is the TextReader the data was built from — confirm in data.jl).
    vocab = dtrn.src.vocab
end

Loading data from ../jld2/enwik8.jld2


Vocab(Dict("54" => 67,"101" => 4,"41" => 52,"65" => 38,"168" => 126,"159" => 175,"228" => 183,"190" => 117,"227" => 96,"88" => 104…), ["<s>", "<unk>", "32", "101", "116", "97", "105", "111", "110", "114"  …  "210", "239", "211", "198", "212", "240", "205", "220", "222", "200"], 2, 1, split)

In [12]:
# Use the same `bptt` value for the train, test and dev iterators
# (presumably the sequence length per batch — confirm in data.jl).
dtrn.bptt = dtst.bptt = ddev.bptt = 1024

1024

In [13]:
@info "Initializing and Training Language Model"

# Hyper-parameters for the model and the training schedule.
epochs = 5
em_size = 1024
hidden_size = 1024
layers = 2
for (label, value) in (
    ("embedding size: ", em_size),
    ("hidden size: ", hidden_size),
    ("layers: ", layers),
)
    println(label, value)
end

println("Collecting training data...")
println("epochs: ", epochs)

# Materialise the training batches so they can be shuffled and re-used.
ctrn = collect(dtrn)
# Training stream: `epochs` back-to-back passes over the batches; `shuffle!`
# reorders `ctrn` in place before each pass is copied into the result.
trnx10 = collect(flatten(shuffle!(ctrn) for _ in 1:epochs))
# First 20 batches of `ctrn` in the order left by the last shuffle.
trnmini = ctrn[1:20]
dev = collect(ddev);

model = SimpleLSTMModel(
    em_size,
    hidden_size,
    vocab;
    layers=layers,
    dropout=0.2,
)

embedding size: 1024
hidden size: 1024
layers: 2
Collecting training data...
epochs: 5


┌ Info: Initializing and Training Language Model
└ @ Main In[13]:1


77-element Array{Tuple{Array{Int16,2},Array{Int16,2}},1}:
 ([4 10 … 3 8; 6 9 … 66 4; … ; 5 13 … 6 9; 36 36 … 10 21], [10 9 … 8 22; 9 14 … 4 29; … ; 13 4 … 9 14; 36 36 … 21 4])   
 ([22 3 … 15 4; 29 6 … 11 13; … ; 14 12 … 3 5; 4 9 … 11 3], [3 57 … 4 3; 6 12 … 13 7; … ; 12 4 … 5 4; 9 15 … 3 5])      
 ([3 26 … 17 17; 7 9 … 20 5; … ; 4 51 … 10 15; 5 8 … 3 5], [26 6 … 17 3; 9 21 … 5 81; … ; 51 5 … 15 7; 8 3 … 5 13])     
 ([3 26 … 4 9; 81 66 … 7 16; … ; 7 6 … 5 8; 13 4 … 7 9], [26 4 … 9 5; 66 8 … 16 19; … ; 6 12 … 8 19; 4 10 … 9 3])       
 ([5 16 … 19 6; 19 17 … 22 10; … ; 19 3 … 4 3; 3 5 … 4 3], [16 6 … 6 76; 17 17 … 10 8; … ; 3 18 … 3 7; 5 13 … 3 18])    
 ([76 8 … 5 23; 8 19 … 10 8; … ; 7 9 … 3 26; 18 18 … 6 12], [8 10 … 23 32; 19 3 … 8 20; … ; 9 3 … 26 8; 18 11 … 12 3])  
 ([32 39 … 10 8; 20 4 … 15 4; … ; 8 10 … 7 15; 3 5 … 6 9], [39 8 … 8 20; 4 81 … 4 9; … ; 10 35 … 15 13; 5 13 … 9 21])   
 ([20 4 … 10 6; 9 5 … 6 24; … ; 13 3 … 17 3; 21 4 … 6 5], [4 3 … 6 24; 5 3 … 24 7; … ; 3 22 … 3

In [None]:
# Set both recurrent state fields to 0 before training
# (presumably Knet's convention for "start from zeros and keep state between
# calls" — confirm against the Knet RNN documentation).
model.rnn.h = 0
model.rnn.c = 0
# Train on the multi-epoch batch stream; `dev` and `trnmini` are passed along
# as well (the progress log reports both dev and training metrics).
model = train!(model, trnx10, dev, trnmini)


┣                    ┫ [0.01%, 1/6870, 00:24/46:07:30, 24.17s/i] (trnloss = (1.123224f0,), trnppl = (3.0747514f0,), trnbpc = (1.6204697234675298,), devloss = 1.1463436f0, devppl = 3.1466663f0, devbpc = 1.6538242107585137)
┣                    ┫ [0.61%, 42/6870, 01:18/03:32:24, 1.31s/i] (trnloss = (1.1155245f0,), trnppl = (3.0511682f0,), trnbpc = (1.6093617080136027,), devloss = 1.1381273f0, devppl = 3.1209185f0, devbpc = 1.6419706505130964)
┣▏                   ┫ [1.22%, 84/6870, 02:12/02:59:24, 1.28s/i] (trnloss = (1.1135364f0,), trnppl = (3.0451078f0,), trnbpc = (1.6064933813625129,), devloss = 1.1352942f0, devppl = 3.112089f0, devbpc = 1.637883310832691)
┣▎                   ┫ [1.73%, 119/6870, 03:05/02:58:05, 1.53s/i] (trnloss = (1.1109736f0,), trnppl = (3.037314f0,), trnbpc = (1.6027960983342642,), devloss = 1.135086f0, devppl = 3.1114414f0, devbpc = 1.6375830291242843)
┣▍                   ┫ [2.26%, 155/6870, 03:59/02:56:25, 1.49s/i] (trnloss = (1.1111416f0,), trnppl = (3.037824

┣███▊                ┫ [19.02%, 1307/6870, 32:34/02:51:10, 1.49s/i] (trnloss = (1.089672f0,), trnppl = (2.9732985f0,), trnbpc = (1.572064346468938,), devloss = 1.1178998f0, devppl = 3.058424f0, devbpc = 1.612788462332044)
┣███▉                ┫ [19.56%, 1344/6870, 33:27/02:51:00, 1.45s/i] (trnloss = (1.0891196f0,), trnppl = (2.9716566f0,), trnbpc = (1.5712673788648397,), devloss = 1.1165205f0, devppl = 3.0542088f0, devbpc = 1.6107986230615614)
┣████                ┫ [20.09%, 1380/6870, 34:21/02:50:59, 1.49s/i] (trnloss = (1.0891672f0,), trnppl = (2.9717982f0,), trnbpc = (1.5713361719251848,), devloss = 1.1174508f0, devppl = 3.0570512f0, devbpc = 1.6121407756688948)
┣████                ┫ [20.61%, 1416/6870, 35:14/02:50:55, 1.47s/i] (trnloss = (1.0874133f0,), trnppl = (2.9665904f0,), trnbpc = (1.5688057911830404,), devloss = 1.1156461f0, devppl = 3.0515392f0, devbpc = 1.6095371303174828)
┣████▏               ┫ [21.14%, 1452/6870, 36:07/02:50:55, 1.49s/i] (trnloss = (1.087939f0,), trnppl

┣███████▌            ┫ [37.90%, 2604/6870, 01:04:43/02:50:45, 1.48s/i] (trnloss = (1.0694133f0,), trnppl = (2.9136696f0,), trnbpc = (1.5428372708160105,), devloss = 1.1008003f0, devppl = 3.0065713f0, devbpc = 1.588119098909632)
┣███████▋            ┫ [38.43%, 2640/6870, 01:05:37/02:50:45, 1.50s/i] (trnloss = (1.0683601f0,), trnppl = (2.9106026f0,), trnbpc = (1.5413178040956377,), devloss = 1.1002121f0, devppl = 3.0048032f0, devbpc = 1.5872705365102748)
┣███████▊            ┫ [38.95%, 2676/6870, 01:06:31/02:50:46, 1.50s/i] (trnloss = (1.0659176f0,), trnppl = (2.903502f0,), trnbpc = (1.5377940515621096,), devloss = 1.099919f0, devppl = 3.0039225f0, devbpc = 1.5868476311718032)
┣███████▉            ┫ [39.48%, 2712/6870, 01:07:25/02:50:46, 1.49s/i] (trnloss = (1.0661724f0,), trnppl = (2.9042418f0,), trnbpc = (1.5381615784870033,), devloss = 1.0991608f0, devppl = 3.001646f0, devbpc = 1.5857538215123157)
┣████████            ┫ [40.00%, 2748/6870, 01:08:18/02:50:46, 1.49s/i] (trnloss = (1.066

In [10]:
# Evaluate on the test split and report the loss together with two derived
# metrics (assumes `loss` returns a mean negative log-likelihood in nats).
testloss = loss(model, dtst)
(
    testloss = testloss,
    testppl = exp.(testloss),      # exponential of the loss (perplexity)
    testbpc = testloss ./ log(2),  # loss divided by ln 2 (bits per character)
)

(testloss = 1.1416304f0, testppl = 3.1318705f0, testbpc = 1.6470245326913509)

In [11]:
# Evaluate on the dev split and report the loss together with two derived
# metrics (assumes `loss` returns a mean negative log-likelihood in nats).
devloss = loss(model, ddev)
(
    devloss = devloss,
    devppl = exp.(devloss),      # exponential of the loss (perplexity)
    devbpc = devloss ./ log(2),  # loss divided by ln 2 (bits per character)
)

(devloss = 1.1472075f0, devppl = 3.149386f0, devbpc = 1.6550705690293166)

In [22]:
# Sample text from the trained model, seeded with a short prompt.
s = generate(model; start="Ali Safaya is ", maxlength=1024)

"Ali Safaya is a term in popularization as it is quite firmly developed in all of its own drugs."

In [10]:
# Sample text from the trained model, seeded with a different prompt.
s = generate(model; start="Syrian Arab Republic is", maxlength=1024)

"Syrian Arab Republic is divorced once there were France by the [[History of the Pacific]].  The French failed permanently defined to Ishtar Bulgaria during the Capital War well dependent on full health claims of peoples every year by the [[West Bank and India]] to separate [[King of Torus]] and the possibleslites of [[User:Peranum|Jueva, Orthodox]], with a record of [[Walker End]], where the early saves sower the outbreak.  Due to a desirable creek, they were also ''Location'' led by San Jose (out of 1947)."