In [1]:
using Knet

include("../src/data.jl")
include("../src/model.jl")

train! (generic function with 1 method)

In [5]:
# Load the enwik8 splits and wrap each one in a bucketed batch iterator.
datadir = "../data/enwik8"
BATCHSIZE, MAXLENGTH = 64, 256
@info "Reading data from directory: $datadir ..."
println("Setting batch size to $BATCHSIZE and max word length to $MAXLENGTH")

# The vocabulary is built from the training split only (with `vocabsize=256`) and is then
# shared by the readers of all three splits, created in train/valid/test order.
vocab = Vocab("$datadir/train.txt", vocabsize=256)
trainfile, validfile, testfile =
    (TextReader("$datadir/$part.txt", vocab) for part in ("train", "valid", "test"))

# Every split is batched with the same settings; `bucketwidth=8` is simply forwarded to
# VocabData (presumably it groups sequences of similar length — confirm in ../src/data.jl).
dtrn, ddev, dtst =
    (VocabData(reader, batchsize=BATCHSIZE, maxlength=MAXLENGTH, bucketwidth=8)
     for reader in (trainfile, validfile, testfile))
# End the cell with the test iterator so it remains the cell's displayed value.
dtst

Setting batch size to 64 and max word length to 256


┌ Info: Reading data from directory: ../data/enwik8 ...
└ @ Main In[5]:3


VocabData(TextReader("../data/enwik8/test.txt", Vocab(Dict("Z" => 83,"1" => 32,"ö" => 107,"r" => 9,"л" => 104,"=" => 35,"'" => 24,"’" => 112,"ร" => 188,"y" => 22…), ["<s>", "<unk>", "e", "t", "a", "i", "o", "n", "r", "s"  …  "ј", "£", "Х", "ū", "ิ", "ע", "ล", "イ", "ه", "С"], 2, 1, split)), 64, 256, false, 8, Array{Any,1}[[], [], [], [], [], [], [], [], [], []  …  [], [], [], [], [], [], [], [], [], []], arraybatch)

In [6]:
# Configure the model hyperparameters, build the multi-epoch training stream, and
# construct the LSTM language model (training itself happens in the next cell).
@info "Initializing and Training Language Model"
epochs, em_size, hidden_size, layers = 5, 256, 256, 2
println("embedding size: ", em_size)
println("hidden size: ", hidden_size)
println("layers: ", layers)

println("Collecting training data...")
println("epochs: ", epochs)
# Materialize the training batches once so the same vector can be reshuffled every epoch.
ctrn = collect(dtrn)
# Concatenate `epochs` shuffled passes over the training batches into one vector.
# NOTE: despite its name, `trnx10` holds `epochs` passes (5 here), not 10; the name is kept
# because the training cell refers to it.
# `shuffle!` permutes `ctrn` in place and returns that same array. This is still correct
# because `Iterators.flatten` only requests the next pass (triggering the next shuffle)
# after the previous pass has been fully consumed by `collect`. `flatten` is qualified
# because Base does not export it.
# NOTE(review): `shuffle!` comes from the Random stdlib and is presumably brought into
# scope by Knet or ../src/data.jl — confirm.
trnx10 = collect(Iterators.flatten(shuffle!(ctrn) for _ in 1:epochs))
# The first 20 batches of `ctrn` after the final shuffle; passed to `train!` as an extra
# evaluation set (presumably to monitor training loss on a small sample — confirm).
trnmini = ctrn[1:20]
dev = collect(ddev)

model = SimpleLSTMModel(em_size, hidden_size, vocab; layers=layers, dropout=0.2)

embedding size: 256
hidden size: 256
layers: 2
Collecting training data...
epochs: 5


┌ Info: Initializing and Training Language Model
└ @ Main In[6]:1


SimpleLSTMModel(Embed(P(KnetArray{Float32,2}(256,256))), LSTM(input=256,hidden=256,layers=2,dropout=0.2), Linear(P(KnetArray{Float32,2}(256,256))), 0.2, Vocab(Dict("Z" => 83,"1" => 32,"ö" => 107,"r" => 9,"л" => 104,"=" => 35,"'" => 24,"’" => 112,"ร" => 188,"y" => 22…), ["<s>", "<unk>", "e", "t", "a", "i", "o", "n", "r", "s"  …  "ј", "£", "Х", "ū", "ิ", "ע", "ล", "イ", "ه", "С"], 2, 1, split))

In [7]:
# Train the model and rebind `model` to whatever `train!` returns. Arguments:
#   - `trnx10`:  `epochs` shuffled passes over the training batches (5 passes despite the name)
#   - `dev`:     the collected validation batches
#   - `trnmini`: a 20-batch sample of the training batches
# NOTE(review): judging by the progress log below, `train!` reports a `dev` value and a
# `tst` value that appears to be computed on the last argument (`trnmini`) — confirm against
# ../src/model.jl.
model = train!(model, trnx10, dev, trnmini)


┣                    ┫ [0.00%, 1/57925, 00:22/351:28:50, 21.84s/i] (dev = 3376.349f0, tst = (110.52997f0,), mem = 9.540377f9)
┣                    ┫ [0.37%, 212/57925, 01:14/05:36:03, 4.06i/s] (dev = 2201.861f0, tst = (71.438934f0,), mem = 1.0082636f10)
┣▏                   ┫ [0.73%, 425/57925, 02:06/04:45:25, 4.11i/s] (dev = 1811.915f0, tst = (58.657215f0,), mem = 1.0520187f10)
┣▏                   ┫ [1.09%, 632/57925, 02:58/04:31:28, 3.98i/s] (dev = 1642.3308f0, tst = (53.137196f0,), mem = 1.1058893f10)
┣▎                   ┫ [1.45%, 841/57925, 03:50/04:23:53, 4.01i/s] (dev = 1510.3392f0, tst = (48.876534f0,), mem = 1.1766419f10)
┣▎                   ┫ [1.81%, 1049/57925, 04:42/04:19:30, 3.99i/s] (dev = 1416.7526f0, tst = (45.75126f0,), mem = 1.2507238f10)
┣▍                   ┫ [2.16%, 1252/57925, 05:34/04:17:30, 3.91i/s] (dev = 1360.6943f0, tst = (43.964146f0,), mem = 1.2611052f10)
┣▌                   ┫ [2.52%, 1457/57925, 06:26/04:15:54, 3.92i/s] (dev = 1326.9434f0, tst = (43.01

┣████▍               ┫ [22.34%, 12940/57925, 55:11/04:07:00, 3.94i/s] (dev = 977.43024f0, tst = (31.250051f0,), mem = 1.1795753f10)
┣████▌               ┫ [22.70%, 13147/57925, 56:03/04:06:56, 3.97i/s] (dev = 975.25446f0, tst = (31.125694f0,), mem = 1.2016986f10)
┣████▌               ┫ [23.05%, 13351/57925, 56:55/04:06:55, 3.92i/s] (dev = 974.62335f0, tst = (31.289991f0,), mem = 1.2558035f10)
┣████▋               ┫ [23.40%, 13554/57925, 57:47/04:06:57, 3.88i/s] (dev = 968.45966f0, tst = (31.149591f0,), mem = 1.2545411f10)
┣████▊               ┫ [23.75%, 13760/57925, 58:40/04:06:56, 3.92i/s] (dev = 967.8751f0, tst = (31.237371f0,), mem = 1.2555356f10)
┣████▊               ┫ [24.11%, 13964/57925, 59:32/04:06:58, 3.88i/s] (dev = 965.68506f0, tst = (31.208073f0,), mem = 1.254877f10)
┣████▉               ┫ [24.46%, 14169/57925, 01:00:25/04:06:58, 3.91i/s] (dev = 966.7179f0, tst = (31.072144f0,), mem = 1.2563624f10)
┣████▉               ┫ [24.80%, 14365/57925, 01:01:17/04:07:07, 3.75i/s] (de

┣████████▊           ┫ [44.01%, 25494/57925, 01:48:14/04:05:55, 3.95i/s] (dev = 911.92224f0, tst = (28.793701f0,), mem = 1.2024523f10)
┣████████▊           ┫ [44.37%, 25700/57925, 01:49:06/04:05:55, 3.92i/s] (dev = 909.74207f0, tst = (28.892494f0,), mem = 1.2330441f10)
┣████████▉           ┫ [44.72%, 25904/57925, 01:49:59/04:05:55, 3.90i/s] (dev = 913.1054f0, tst = (29.130175f0,), mem = 1.2549349f10)
┣█████████           ┫ [45.08%, 26112/57925, 01:50:51/04:05:54, 3.96i/s] (dev = 912.7295f0, tst = (29.118929f0,), mem = 1.2557468f10)
┣█████████           ┫ [45.43%, 26318/57925, 01:51:43/04:05:53, 3.96i/s] (dev = 911.5764f0, tst = (28.983528f0,), mem = 1.2557189f10)
┣█████████▏          ┫ [45.79%, 26523/57925, 01:52:35/04:05:54, 3.91i/s] (dev = 913.554f0, tst = (28.901955f0,), mem = 1.2557959f10)
┣█████████▏          ┫ [46.14%, 26729/57925, 01:53:28/04:05:54, 3.92i/s] (dev = 911.45514f0, tst = (28.949097f0,), mem = 1.2557758f10)
┣█████████▎          ┫ [46.50%, 26934/57925, 01:54:21/04:05:

┣█████████████▏      ┫ [65.69%, 38052/57925, 02:41:21/04:05:38, 4.03i/s] (dev = 884.83417f0, tst = (27.975986f0,), mem = 1.242496f10)
┣█████████████▏      ┫ [66.04%, 38254/57925, 02:42:14/04:05:39, 3.86i/s] (dev = 885.4116f0, tst = (28.106436f0,), mem = 1.2470835f10)
┣█████████████▎      ┫ [66.39%, 38459/57925, 02:43:06/04:05:39, 3.93i/s] (dev = 883.059f0, tst = (28.073349f0,), mem = 1.2519102f10)
┣█████████████▎      ┫ [66.75%, 38663/57925, 02:43:58/04:05:39, 3.91i/s] (dev = 885.5615f0, tst = (28.065344f0,), mem = 1.2561816f10)
┣█████████████▍      ┫ [67.10%, 38867/57925, 02:44:50/04:05:40, 3.91i/s] (dev = 884.645f0, tst = (28.06283f0,), mem = 1.2560898f10)
┣█████████████▍      ┫ [67.45%, 39073/57925, 02:45:43/04:05:40, 3.93i/s] (dev = 881.95166f0, tst = (28.124249f0,), mem = 1.2565352f10)
┣█████████████▌      ┫ [67.80%, 39276/57925, 02:46:35/04:05:41, 3.87i/s] (dev = 886.41046f0, tst = (28.07119f0,), mem = 1.2564603f10)
┣█████████████▌      ┫ [68.12%, 39460/57925, 02:47:28/04:05:50, 

┣█████████████████▍  ┫ [87.23%, 50526/57925, 03:34:28/04:05:52, 3.90i/s] (dev = 871.91425f0, tst = (27.482552f0,), mem = 1.2534876f10)
┣█████████████████▌  ┫ [87.58%, 50733/57925, 03:35:20/04:05:52, 3.95i/s] (dev = 870.4995f0, tst = (27.508503f0,), mem = 1.2544305f10)
┣█████████████████▌  ┫ [87.94%, 50938/57925, 03:36:12/04:05:52, 3.93i/s] (dev = 869.0328f0, tst = (27.442917f0,), mem = 1.2554258f10)
┣█████████████████▋  ┫ [88.29%, 51144/57925, 03:37:05/04:05:52, 3.95i/s] (dev = 870.6582f0, tst = (27.50968f0,), mem = 1.253889f10)
┣█████████████████▋  ┫ [88.65%, 51351/57925, 03:37:57/04:05:51, 3.95i/s] (dev = 868.8917f0, tst = (27.559631f0,), mem = 1.2570577f10)
┣█████████████████▊  ┫ [89.00%, 51554/57925, 03:38:50/04:05:52, 3.86i/s] (dev = 870.7847f0, tst = (27.50688f0,), mem = 1.256867f10)
┣█████████████████▊  ┫ [89.32%, 51738/57925, 03:39:42/04:05:59, 3.50i/s] (dev = 871.8919f0, tst = (27.349873f0,), mem = 1.2566233f10)
┣█████████████████▉  ┫ [89.68%, 51948/57925, 03:40:34/04:05:57, 4

SimpleLSTMModel(Embed(P(KnetArray{Float32,2}(256,256))), LSTM(input=256,hidden=256,layers=2,dropout=0.2), Linear(P(KnetArray{Float32,2}(256,256))), 0.2, Vocab(Dict("Z" => 83,"1" => 32,"ö" => 107,"r" => 9,"л" => 104,"=" => 35,"'" => 24,"’" => 112,"ร" => 188,"y" => 22…), ["<s>", "<unk>", "e", "t", "a", "i", "o", "n", "r", "s"  …  "ј", "£", "Х", "ū", "ิ", "ע", "ล", "イ", "ه", "С"], 2, 1, split))

In [8]:
# Persist the trained model to a JLD2 file under the key "model".
# NOTE(review): the file name hard-codes hyperparameters ("256" appears to be the
# embedding/hidden size and "2" the layer count from the setup cell); update it if those
# values change.
Knet.save("baseline-256-2.jld2", "model", model)

In [62]:
# Sample text from the trained model with an empty prompt; `del=" "` joins the generated
# tokens with a single space, which is why the sample below shows spaced-out characters.
# `maxlength=1024` presumably caps the number of generated tokens — confirm in model.jl.
generate(model; start="", del=" ", maxlength=1024)

"<s> | E v e r , f o r m & l t ; b r / & g t ; o c y p e r m a s o b a y ' ' , ' ' + 6 & l t ; s u p & g t ; י ' ' a n d & l t ; / b i g n a m & g t ; <s>"