In [1]:
using Knet, Test, Base.Iterators, Printf, LinearAlgebra, CuArrays, Random, IterTools, StatsBase

struct Charset
    c2i::Dict{Any,Int}
    i2c::Vector{Any}
    eow::Int
end

function Charset(charset::String; eow="")
    i2c = [ eow; [ c for c in charset ]  ]
    print(i2c)
    c2i = Dict( c => i for (i, c) in enumerate(i2c))
    return Charset(c2i, i2c, c2i[eow])
end

struct TextReader
    file::String
    charset::Charset
end

function Base.iterate(r::TextReader, s=nothing)
    s === nothing && (s = open(r.file))
    eof(s) && return close(s)
    return [ get(r.charset.c2i, c, r.charset.eow) for c in readline(s)], s
end

Base.IteratorSize(::Type{TextReader}) = Base.SizeUnknown()
Base.IteratorEltype(::Type{TextReader}) = Base.HasEltype()
Base.eltype(::Type{TextReader}) = Vector{Int}

struct WordsData
    src::TextReader        
    batchsize::Int         
    maxlength::Int         
    batchmajor::Bool       
    bucketwidth::Int    
    buckets::Vector        
end

function WordsData(src::TextReader; batchsize = 128, maxlength = typemax(Int),
                batchmajor = false, bucketwidth = 2, numbuckets = min(128, maxlength ÷ bucketwidth))
    buckets = [ [] for i in 1:numbuckets ] # buckets[i] is an array of sentence pairs with similar length
    WordsData(src, batchsize, maxlength, batchmajor, bucketwidth, buckets)
end

Base.IteratorSize(::Type{WordsData}) = Base.SizeUnknown()
Base.IteratorEltype(::Type{WordsData}) = Base.HasEltype()
Base.eltype(::Type{WordsData}) = Array{Any,1}

function Base.iterate(d::WordsData, state=nothing)
    if state == 0 # When file is finished but buckets are partially full 
        for i in 1:length(d.buckets)
            if length(d.buckets[i]) > 0
                buc = d.buckets[i]
                d.buckets[i] = []
                return buc, state
            end
        end
        return nothing # Finish iteration
    end

    while true
        src_next = iterate(d.src, state)
        
        if src_next === nothing
            state = 0
            return iterate(d, state)
        end
        
        (src_word, src_state) = src_next
        state = src_state
        src_length = length(src_word)
        
        (src_length > d.maxlength) && continue

        i = Int(ceil(src_length / d.bucketwidth))
        i > length(d.buckets) && (i = length(d.buckets))

        push!(d.buckets[i], src_word)
        if length(d.buckets[i]) == d.batchsize
            buc = d.buckets[i]
            d.buckets[i] = []
            return buc, state
        end
    end
end

function arraybatch(d::WordsData, bucket)
    src_eow = d.src.charset.eow
    src_lengths = map(x -> length(x), bucket)
    max_length = max(src_lengths...)
    x = zeros(Int64, length(bucket), d.maxlength + 1) # default d.batchmajor is false

    for (i, v) in enumerate(bucket)
        to_be_added = fill(src_eow, d.maxlength - length(v))
        x[i,:] = [src_eow; v; to_be_added]
    end
    
    d.batchmajor && (x = x')
    return (x[:, 1:end-1], x[:, 2:end]) # to calculate nll on generators output directly
end

function readwordset(fname)
    words = []
    fi = open(fname)
    while !eof(fi)
        push!(words, readline(fi))
    end
    close(fi)
    words
end

function mask(a, pad)
    a = copy(a)
    for i in 1:size(a, 1)
        j = size(a,2)
        while a[i, j] == pad && j > 1
            if a[i, j - 1] == pad
                a[i, j] = 0
            end
            j -= 1
        end
    end
    return a
end

struct Embed; w; end

function Embed(shape...)
    Embed(param(shape...))
end

get_z(shape...) = KnetArray(randn(Float32, shape...))

# this function is similar to gumble softmax, it is used to soften the one-hot-vector of the real samples
# tau -> normalization factor; the bigger the softer
function soften(A; dims=1, tau=0.5, norm_factor=0.01) 
    A = (A .+ norm_factor) ./ tau
    softmax(A; dims=dims)
end

# per-word loss (in this case per-batch loss)
function loss(model, data; average=true)
    l = 0
    n = 0
    a = 0
    for (x, y) in data
        v = model(x, y; average=false)
        l += v[1]
        n += v[2]
        a += (v[1] / v[2])
    end
    average && return a
    return l, n
end

function (l::Embed)(x)
    dims = size(x)
    em = l.w * reshape(x, dims[1], dims[2] * dims[3]) # reshape for multiplication 
    em = reshape(em, size(em, 1), dims[2], dims[3]) # reshape to original size
end

struct Dense; w; b; f; p; end
(d::Dense)(x) = d.f.(d.w * mat(dropout(x,d.p)) .+ d.b) # mat reshapes 3-D tensor to 2-D matrix so we can use matmul
Dense(i::Int,o::Int,f=relu;pdrop=0) = Dense(param(o,i), param0(o), f, pdrop)

mutable struct DisModel
    charset::Charset
    embed::Embed
    rnn::RNN
    denselayers
end

# This discriminator uses separate weights for its embedding layer
function DisModel(charset, embeddingSize::Int, hidden, denselayers; layers=1, dropout=0)
    Em = Embed(embeddingSize, length(charset.c2i))
    rnn = RNN(embeddingSize, hidden; numLayers=layers, dropout=dropout)
    DisModel(charset, Em, rnn, denselayers)
end

function (c::DisModel)(x) # the input here is weights of the characters with shape (C, B, T)
    c.rnn.h, c.rnn.c = 0, 0
    em = c.embed(x)
    rnn_out = c.rnn(em)
    dims = size(rnn_out)
    rnn_out = reshape(rnn_out, :, dims[2] * dims[3] )
    for l in c.denselayers
        rnn_out = l(rnn_out)
    end
    reshape(rnn_out, :, dims[2], dims[3])
end

function (c::DisModel)(x, reward::Int; average=true)
    scores = softmax(c(x))
    scores = reshape(scores, :, size(scores, 2) * size(scores, 3))
    -log.(scores[1, :])
end

function (c::DisModel)(x, y; average=true)
    scores = reshape(c(x), :, size(y, 1) * size(y, 2))
    labels = reshape(y, size(y, 1) * size(y, 2))
    return nll(scores, y; average=average)
end

# concatenate z with embedding vectors, z -> (z_size, B), returns (E+z_size, B, T)
# this will be used to feed Z to generator at each timestep
function (l::Embed)(x, z)
    em = l.w[:, x]
    z_array = cat((z for i in 1:size(em, 3))...; dims=(3))
    cat(em, z_array; dims=(1))
end

# Generator model
struct GenModel
    embed::Embed
    rnn::RNN        
    dropout::Real
    charset::Charset
    projection::Embed
    disModel::DisModel
    maxlength::Int
    zsize::Int
end

function GenModel(esize::Int, zsize::Int, hidden::Int, charset::Charset, disModel::DisModel, maxlength::Int; layers=2, dropout=0)
    embed = Embed(esize, length(charset.i2c))
    rnn = RNN(zsize + esize, hidden; numLayers=layers, dropout=dropout)
    projection = Embed(hidden, length(charset.i2c))
    GenModel(embed, rnn, dropout, charset, projection, disModel, maxlength, zsize)
end

# This generator shares the projection layers weights of the discriminator for its projection layer
function GenModel(esize::Int, zsize::Int, charset::Charset, disModel::DisModel, maxlength::Int; layers=2, dropout=0)
    embed = Embed(esize, length(charset.i2c))
    rnn = RNN(zsize + esize, size(disModel.embed.w, 1); numLayers=layers, dropout=dropout)
    GenModel(embed, rnn, dropout, charset, disModel.embed, disModel, maxlength, zsize)
end

# Generator forward pass using Z and Teacher forcing for input
function (s::GenModel)(GenInput) # tuple (input, Z)
    (input, _), Z = GenInput
    s.rnn.h, s.rnn.c = 0, 0
    input = s.embed(input, Z)
    rnn_out = s.rnn(input)
    dims = size(rnn_out)
    output = s.projection.w' * reshape(rnn_out, dims[1], dims[2] * dims[3])
    scores = reshape(output, size(output, 1), dims[2], dims[3])
end

# Generator loss
function (s::GenModel)(GenInput, calculateloss::Int; average=true)
    # since the discriminator will output 2 for the fake data, 
    #    we train the generator to get 1 as output from the discriminator
    (_, output), Z = GenInput
    x = s(GenInput)
    dloss = s.disModel(softmax(x), 1)
    scores = reshape(x, :, size(output, 1) * size(output, 2))
    output = mask(reshape(output, size(output, 1) * size(output, 2)), s.charset.eow)
    glosses = [nll(scores[:, i], output[i:i]) * dloss[i] for i in 1:size(output, 1) ]
    average && return mean(glosses)
    return sum(glosses), length(glosses)
end

function generate(s::GenModel; start="", maxlength=30)
    s.rnn.h, s.rnn.c = 0, 0
    Z = get_z(s.zsize, 1, 1)
    chars = fill(s.charset.eow, 1)

    starting_index = 1
    for i in 1:length(start)
        push!(chars, s.charset.c2i[start[i]])
        charembed = s.embed(chars[i:i], Z)
        rnn_out = s.rnn(charembed)
        starting_index += 1
    end
    
    for i in starting_index:maxlength
        charembed = s.embed(chars[i:i], Z)
        rnn_out = s.rnn(charembed)
        dims = size(rnn_out)
        output = s.projection.w' * reshape(rnn_out, dims[1], dims[2] * dims[3])
        push!(chars, s.charset.c2i[ sample(s.charset.i2c, Weights(Array(softmax(reshape(output, length(s.charset.i2c)))))) ] )
#         push!(chars, argmax(output)[1])
        if chars[end] == s.charset.eow
            break
        end
    end
    
    join([ s.charset.i2c[i] for i in chars ], "")
end

struct Sampler
    wordsdata::WordsData
    charset::Charset
    genModel::GenModel
    maxBatchsize::Int
end

Base.IteratorSize(::Type{Sampler}) = Base.SizeUnknown()
Base.IteratorEltype(::Type{Sampler}) = Base.HasEltype()
Base.eltype(::Type{Sampler}) = Tuple{KnetArray{Float32,3},Array{Int64,2}}

function Base.iterate(s::Sampler, state=nothing)
    wdatastate = iterate(s.wordsdata, state)
    wdatastate === nothing && (return nothing)
    
    (bucket, state) = wdatastate
    bsize = length(bucket)
    src_eow = s.charset.eow
    src_lengths = map(x -> length(x), bucket)
    max_length = max(src_lengths...)
    gsize = bsize
    generated = softmax(s.genModel((arraybatch(s.wordsdata, bucket), get_z(s.genModel.zsize, gsize, 1))))

    to_be_cat = [generated, ]
    for (i, v) in enumerate(bucket)
        tindex = [i for i in 1:length(v)]
        pindex = [i for i in length(v)+1:s.wordsdata.maxlength]
        onehot = KnetArray(zeros(Float32, length(s.charset.c2i), 1, s.wordsdata.maxlength))
        onehot[v, :, tindex] .= 1
        onehot[s.charset.eow, :, pindex] .= 1
        onehot = soften(onehot) # soften one hot vectors elements value
        push!(to_be_cat, onehot)
    end
    x = cat(to_be_cat...;dims=2) # concatenate both generated and sampled words

    y = Array(ones(Int, gsize+bsize, s.wordsdata.maxlength)) # create labels 1 -> real, 2-> not-real
    y[1:gsize, :] = y[1:gsize, :] .+ 1
    
    ind = shuffle(1:gsize+bsize) # used to shuffle the batch
    x, y = x[:, ind, :], y[ind, :]
    return (x,y), state
end

function train!(model, parameters, trn, dev, tst; lr=0.001)
    bestmodel, bestloss = deepcopy(model), loss(model, dev)
    progress!(adam(model, trn; lr=lr, params=parameters), seconds=30) do y
        devloss = loss(model, dev)
        tstloss = loss(model, tst)
        if devloss < bestloss
            bestloss, bestmodel = devloss, deepcopy(model)
        end
        println(stderr)
        (dev=devloss, tst=tstloss, mem=Float32(CuArrays.usage[]))
    end
    return bestmodel
end

char_set = "ABCDEFGHIJKLMNOPRSTUVYZabcdefghijklmnoprstuvyzÇÖÜçöüĞğİıŞş"
tr_charset = Charset(char_set)
datadir = "turkish_word_set"
BATCHSIZE = 128
MAXLENGTH = 15
tr_dev = TextReader("$datadir/dev.tr", tr_charset)
tr_trn = TextReader("$datadir/train.tr", tr_charset)
dtrn = WordsData(tr_trn, batchsize=BATCHSIZE, maxlength=MAXLENGTH, bucketwidth = 1)
ddev = WordsData(tr_dev, batchsize=BATCHSIZE, maxlength=MAXLENGTH, bucketwidth = 1)

EMBEDDING_SIZE = 256
DHIDDEN_SIZE = 128
GDROPOUT = 0.1
DDROPOUT = 0.3

dismodel = DisModel(tr_charset, EMBEDDING_SIZE, DHIDDEN_SIZE,(
        Dense(DHIDDEN_SIZE, 2, identity),
        ); dropout=DDROPOUT)

GH_SIZE = 256
Z_SIZE = 128

genmodel = GenModel(EMBEDDING_SIZE, Z_SIZE, GH_SIZE, tr_charset, dismodel, MAXLENGTH; dropout=GDROPOUT, layers=2)
trnsampler = Sampler(dtrn, tr_charset, genmodel, BATCHSIZE * 2)
devsampler = Sampler(ddev, tr_charset, genmodel, BATCHSIZE * 2)

ctrn = collect(dtrn)
cdev = collect(ddev)
collecttrn = [ ((arraybatch(dtrn, i), get_z(Z_SIZE, size(i, 1), 1)), 1) for i in ctrn ]
collectdev = [ ((arraybatch(ddev, i), get_z(Z_SIZE, size(i, 1), 1)), 1) for i in cdev ]

function gmodel(batches)
    global genmodel
    global collecttrn
    global collectdev
    
    trnxbatches = shuffle!(collecttrn)[1:batches]
    devbatches = shuffle!(collectdev)
    trnmini = trnxbatches[1:5]

    genmodel = train!(genmodel, params(genmodel)[1:3], trnxbatches, devbatches, trnmini)
end

function dmodel(batches)
    global trnsampler
    global devsampler
    global dismodel
    
    ctrn = collect(trnsampler)
    ctrn = shuffle!(ctrn)[1:batches]
    trnmini = ctrn[1:5]
    dev = collect(devsampler)
    dismodel = train!(dismodel, params(dismodel), ctrn, dev, trnmini) 
end

@info "Started training..."
for k in 1:10
    println("Turn no:", k)
    println("Ex.Generated words: \n", join([ generate(genmodel; maxlength=MAXLENGTH) for i in 1:5 ],"\n"))
    dmodel(50)
    gmodel(400)
end

Knet.save("text-gan-model.jld2", "genmodel", genmodel)

Any["", 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'R', 'S', 'T', 'U', 'V', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'r', 's', 't', 'u', 'v', 'y', 'z', 'Ç', 'Ö', 'Ü', 'ç', 'ö', 'ü', 'Ğ', 'ğ', 'İ', 'ı', 'Ş', 'ş']

┌ Info: Started training...
└ @ Main In[1]:404


Turn no:1
Ex.Generated words: 
şJm
ÖıFiçDnuKğvozdS
OTMdüfPĞSPaSGJĞ
ıŞÖdÜeDEaüçJpEĞ
BÜAfŞUZçcCIjrrU



┣▍                   ┫ [2.00%, 1/50, 00:06/04:46, 5.72s/i] (dev = 131.41374f0, tst = 3.4245439f0, mem = 6.182463f9)
┣████████████████████┫ [100.00%, 50/50, 00:06/00:06, 8.04i/s] (dev = 7.8969607f0, tst = 0.20604204f0, mem = 7.9979955f9)

┣▏                   ┫ [1.00%, 4/400, 01:28/02:26:05, 16.95s/i] (dev = 12936.694f0, tst = 325.64725f0, mem = 8.067862f9)
┣▎                   ┫ [1.75%, 7/400, 02:19/02:12:08, 17.03s/i] (dev = 9760.602f0, tst = 241.76859f0, mem = 8.067862f9)
┣▌                   ┫ [2.50%, 10/400, 03:09/02:06:03, 16.78s/i] (dev = 8549.583f0, tst = 232.1902f0, mem = 7.7486853f9)
┣▋                   ┫ [3.25%, 13/400, 04:00/02:03:15, 17.09s/i] (dev = 8187.918f0, tst = 222.55331f0, mem = 7.217809f9)
┣▊                   ┫ [4.00%, 16/400, 04:51/02:01:09, 16.81s/i] (dev = 7893.323f0, tst = 211.15501f0, mem = 7.782547f8)
┣█                   ┫ [5.00%, 20/400, 05:52/01:57:20, 15.31s/i] (dev = 7839.6567f0, tst = 208.82922f0, mem = 7.8243104f8)
┣█▏                  ┫ [6.00%, 24/

┣████████████▌       ┫ [62.50%, 250/400, 01:03:26/01:41:29, 14.92s/i] (dev = 6045.0713f0, tst = 162.26588f0, mem = 7.6938784f8)
┣████████████▋       ┫ [63.50%, 254/400, 01:04:25/01:41:27, 14.87s/i] (dev = 6072.5293f0, tst = 163.15495f0, mem = 7.6938784f8)
┣████████████▉       ┫ [64.50%, 258/400, 01:05:25/01:41:26, 14.96s/i] (dev = 5983.3726f0, tst = 160.3439f0, mem = 7.6938784f8)
┣█████████████       ┫ [65.50%, 262/400, 01:06:25/01:41:24, 14.88s/i] (dev = 5928.483f0, tst = 158.52542f0, mem = 7.6949946f8)
┣█████████████▎      ┫ [66.50%, 266/400, 01:07:24/01:41:21, 14.85s/i] (dev = 5843.703f0, tst = 156.42667f0, mem = 7.6845037f8)
┣█████████████▌      ┫ [67.50%, 270/400, 01:08:24/01:41:20, 14.91s/i] (dev = 5827.4346f0, tst = 156.2443f0, mem = 7.6740186f8)
┣█████████████▋      ┫ [68.50%, 274/400, 01:09:23/01:41:18, 14.90s/i] (dev = 5748.543f0, tst = 153.95212f0, mem = 7.6740186f8)
┣█████████████▉      ┫ [69.50%, 278/400, 01:10:22/01:41:15, 14.75s/i] (dev = 5705.105f0, tst = 152.82034f0, m

Turn no:2
Ex.Generated words: 
ÖPyebük
a
HaDnu
Khüninizi
rrağar



┣▍                   ┫ [2.00%, 1/50, 00:00/00:10, 4.94i/s] (dev = 7.7236104f0, tst = 0.20174919f0, mem = 8.007759f9)
┣████████████████████┫ [100.00%, 50/50, 00:00/00:00, 108.75i/s] (dev = 3.696226f0, tst = 0.096399754f0, mem = 8.761974f9)

┣                    ┫ [0.25%, 1/400, 00:30/03:22:01, 30.30s/i] (dev = 5073.255f0, tst = 132.35794f0, mem = 8.761768f9)
┣▎                   ┫ [1.25%, 5/400, 01:30/01:59:59, 14.92s/i] (dev = 5063.355f0, tst = 131.46065f0, mem = 8.762299f9)
┣▍                   ┫ [2.25%, 9/400, 02:30/01:50:48, 14.90s/i] (dev = 5104.2246f0, tst = 131.98465f0, mem = 8.762299f9)
┣▋                   ┫ [3.25%, 13/400, 03:29/01:47:22, 14.94s/i] (dev = 5045.0776f0, tst = 130.94238f0, mem = 8.762452f9)
┣▊                   ┫ [4.25%, 17/400, 04:21/01:42:13, 12.82s/i] (dev = 5067.54f0, tst = 131.83981f0, mem = 8.510676f9)
┣█                   ┫ [5.25%, 21/400, 05:20/01:41:42, 14.92s/i] (dev = 5067.8486f0, tst = 131.86841f0, mem = 8.248651f8)
┣█▎                  ┫ [6.25%, 25/

┣████████████▊       ┫ [64.25%, 257/400, 01:03:45/01:39:13, 14.95s/i] (dev = 4800.7656f0, tst = 124.97799f0, mem = 7.3944403f8)
┣█████████████       ┫ [65.25%, 261/400, 01:04:45/01:39:13, 14.88s/i] (dev = 4776.4297f0, tst = 124.34866f0, mem = 7.3911635f8)
┣█████████████▎      ┫ [66.25%, 265/400, 01:05:44/01:39:14, 14.91s/i] (dev = 4773.262f0, tst = 124.406456f0, mem = 7.7339174f8)
┣█████████████▍      ┫ [67.25%, 269/400, 01:06:44/01:39:14, 14.89s/i] (dev = 4763.808f0, tst = 124.06004f0, mem = 7.7575354f8)
┣█████████████▋      ┫ [68.25%, 273/400, 01:07:44/01:39:14, 14.96s/i] (dev = 4756.3296f0, tst = 123.608246f0, mem = 7.7575354f8)
┣█████████████▊      ┫ [69.25%, 277/400, 01:08:43/01:39:14, 14.86s/i] (dev = 4734.5605f0, tst = 123.235275f0, mem = 7.7575354f8)
┣██████████████      ┫ [70.25%, 281/400, 01:09:40/01:39:11, 14.33s/i] (dev = 4770.098f0, tst = 124.45134f0, mem = 7.7575354f8)
┣██████████████▎     ┫ [71.25%, 285/400, 01:10:40/01:39:10, 14.79s/i] (dev = 4798.113f0, tst = 125.49953

Turn no:3
Ex.Generated words: 
Tafı
Balmlora
Vağımadı
memafıl
kitik



┣▍                   ┫ [2.00%, 1/50, 00:00/00:09, 5.83i/s] (dev = 3.6541483f0, tst = 0.095102f0, mem = 7.858611f9)
┣████████████████████┫ [100.00%, 50/50, 00:00/00:00, 118.13i/s] (dev = 2.2943914f0, tst = 0.05972531f0, mem = 8.821144f9)

┣                    ┫ [0.25%, 1/400, 00:30/03:19:35, 29.94s/i] (dev = 4627.3975f0, tst = 103.371956f0, mem = 8.821133f9)
┣▎                   ┫ [1.25%, 5/400, 01:29/01:59:10, 14.86s/i] (dev = 4648.254f0, tst = 102.407425f0, mem = 8.888824f9)
┣▍                   ┫ [2.25%, 9/400, 02:29/01:50:10, 14.84s/i] (dev = 4624.3706f0, tst = 102.11043f0, mem = 8.888824f9)
┣▋                   ┫ [3.25%, 13/400, 03:26/01:45:49, 14.40s/i] (dev = 4631.164f0, tst = 102.94819f0, mem = 8.888824f9)
┣▊                   ┫ [4.25%, 17/400, 04:26/01:44:11, 14.83s/i] (dev = 4632.516f0, tst = 102.36133f0, mem = 8.632533f9)
┣█                   ┫ [5.25%, 21/400, 05:25/01:43:12, 14.86s/i] (dev = 4629.2944f0, tst = 101.13604f0, mem = 8.693246f8)
┣█▎                  ┫ [6.25%, 25

┣████████████▊       ┫ [64.25%, 257/400, 01:03:26/01:38:44, 14.86s/i] (dev = 4479.2437f0, tst = 103.91239f0, mem = 8.942523f8)
┣█████████████       ┫ [65.25%, 261/400, 01:04:26/01:38:44, 14.81s/i] (dev = 4451.716f0, tst = 101.84515f0, mem = 8.7511546f8)
┣█████████████▎      ┫ [66.25%, 265/400, 01:05:25/01:38:44, 14.84s/i] (dev = 4445.393f0, tst = 101.19854f0, mem = 8.2585376f8)
┣█████████████▍      ┫ [67.25%, 269/400, 01:06:24/01:38:45, 14.83s/i] (dev = 4435.837f0, tst = 102.83292f0, mem = 7.7383706f8)
┣█████████████▋      ┫ [68.25%, 273/400, 01:07:24/01:38:45, 14.94s/i] (dev = 4431.1323f0, tst = 102.80873f0, mem = 7.546947f8)
┣█████████████▊      ┫ [69.25%, 277/400, 01:08:23/01:38:46, 14.85s/i] (dev = 4419.304f0, tst = 102.86017f0, mem = 8.018766f8)
┣██████████████      ┫ [70.25%, 281/400, 01:09:23/01:38:46, 14.94s/i] (dev = 4412.406f0, tst = 102.875916f0, mem = 8.087743f8)
┣██████████████▎     ┫ [71.25%, 285/400, 01:10:23/01:38:46, 14.85s/i] (dev = 4409.8286f0, tst = 103.03392f0, mem

Turn no:4
Ex.Generated words: 
Agılman
sartışlır
ildikeci
pesi
hüdeniyor



┣▍                   ┫ [2.00%, 1/50, 00:00/00:11, 4.63i/s] (dev = 2.2758296f0, tst = 0.05924733f0, mem = 8.25018f9)
┣████████████████████┫ [100.00%, 50/50, 00:00/00:00, 104.46i/s] (dev = 1.586433f0, tst = 0.04130251f0, mem = 8.882484f9)

┣                    ┫ [0.25%, 1/400, 00:30/03:19:30, 29.92s/i] (dev = 4343.2056f0, tst = 121.21285f0, mem = 8.882407f9)
┣▎                   ┫ [1.25%, 5/400, 01:29/01:59:12, 14.87s/i] (dev = 4336.5254f0, tst = 119.97962f0, mem = 8.882425f9)
┣▍                   ┫ [2.25%, 9/400, 02:29/01:50:17, 14.87s/i] (dev = 4352.5254f0, tst = 118.66106f0, mem = 8.882425f9)
┣▋                   ┫ [3.25%, 13/400, 03:28/01:46:49, 14.85s/i] (dev = 4366.7344f0, tst = 118.408966f0, mem = 8.882417f9)
┣▊                   ┫ [4.25%, 17/400, 04:27/01:44:51, 14.77s/i] (dev = 4348.2295f0, tst = 118.32151f0, mem = 8.523946f9)
┣█                   ┫ [5.25%, 21/400, 05:27/01:43:53, 14.97s/i] (dev = 4322.0957f0, tst = 118.348724f0, mem = 7.9845126f8)
┣█▎                  ┫ [6.25%

┣████████████▊       ┫ [64.25%, 257/400, 01:03:41/01:39:07, 14.89s/i] (dev = 4204.5137f0, tst = 116.24959f0, mem = 7.991335f8)
┣█████████████       ┫ [65.25%, 261/400, 01:04:41/01:39:08, 14.92s/i] (dev = 4197.1978f0, tst = 115.839966f0, mem = 7.995267f8)
┣█████████████▎      ┫ [66.25%, 265/400, 01:05:32/01:38:55, 12.75s/i] (dev = 4199.9014f0, tst = 114.73277f0, mem = 8.164804f8)
┣█████████████▍      ┫ [67.25%, 269/400, 01:06:31/01:38:55, 14.85s/i] (dev = 4267.663f0, tst = 115.23774f0, mem = 8.158389f8)
┣█████████████▋      ┫ [68.25%, 273/400, 01:07:31/01:38:56, 14.94s/i] (dev = 4254.289f0, tst = 114.671715f0, mem = 8.115109f8)
┣█████████████▊      ┫ [69.25%, 277/400, 01:08:31/01:38:56, 14.91s/i] (dev = 4200.9785f0, tst = 114.063995f0, mem = 8.089711f8)
┣██████████████      ┫ [70.25%, 281/400, 01:09:30/01:38:57, 14.94s/i] (dev = 4208.8047f0, tst = 115.255196f0, mem = 7.687917f8)
┣██████████████▎     ┫ [71.25%, 285/400, 01:10:30/01:38:57, 14.95s/i] (dev = 4210.321f0, tst = 115.518524f0, 

Turn no:5
Ex.Generated words: 
deden
Tar
GÜYLEM
vavidağınlarınd
zesir



┣▍                   ┫ [2.00%, 1/50, 00:00/00:10, 5.10i/s] (dev = 1.5757319f0, tst = 0.041045777f0, mem = 8.2868275f9)
┣████████████████████┫ [100.00%, 50/50, 00:00/00:00, 112.81i/s] (dev = 1.1504602f0, tst = 0.029966407f0, mem = 8.86127f9)

┣                    ┫ [0.25%, 1/400, 00:30/03:19:31, 29.93s/i] (dev = 4167.0645f0, tst = 117.425125f0, mem = 8.861272f9)
┣▎                   ┫ [1.25%, 5/400, 01:29/01:59:14, 14.88s/i] (dev = 4165.734f0, tst = 116.524536f0, mem = 8.861324f9)
┣▍                   ┫ [2.25%, 9/400, 02:29/01:50:28, 14.92s/i] (dev = 4172.3535f0, tst = 113.60322f0, mem = 8.861324f9)
┣▋                   ┫ [3.25%, 13/400, 03:29/01:47:09, 14.95s/i] (dev = 4191.423f0, tst = 113.2457f0, mem = 8.861447f9)
┣▊                   ┫ [4.25%, 17/400, 04:29/01:45:21, 14.93s/i] (dev = 4165.6113f0, tst = 112.75492f0, mem = 8.5453245f9)
┣█                   ┫ [5.25%, 21/400, 05:29/01:44:20, 15.00s/i] (dev = 4152.3115f0, tst = 112.992134f0, mem = 7.989858f8)
┣█▎                  ┫ [6.2

┣████████████▉       ┫ [64.50%, 258/400, 01:03:24/01:38:18, 14.99s/i] (dev = 4093.7075f0, tst = 111.328354f0, mem = 8.3530336f8)
┣█████████████       ┫ [65.50%, 262/400, 01:04:24/01:38:19, 14.95s/i] (dev = 4067.809f0, tst = 110.912346f0, mem = 8.353158f8)
┣█████████████▎      ┫ [66.50%, 266/400, 01:05:24/01:38:20, 14.89s/i] (dev = 4065.7034f0, tst = 111.076416f0, mem = 8.353158f8)
┣█████████████▌      ┫ [67.50%, 270/400, 01:06:23/01:38:21, 14.91s/i] (dev = 4057.19f0, tst = 111.2415f0, mem = 8.3412794f8)
┣█████████████▋      ┫ [68.50%, 274/400, 01:07:23/01:38:22, 14.90s/i] (dev = 4051.5059f0, tst = 112.17588f0, mem = 8.335299f8)
┣█████████████▉      ┫ [69.75%, 279/400, 01:08:22/01:38:02, 11.92s/i] (dev = 4051.446f0, tst = 112.4168f0, mem = 8.331993f8)
┣██████████████▏     ┫ [70.75%, 283/400, 01:09:22/01:38:03, 14.90s/i] (dev = 4050.0015f0, tst = 112.62209f0, mem = 8.1746336f8)
┣██████████████▎     ┫ [71.75%, 287/400, 01:10:22/01:38:04, 14.94s/i] (dev = 4074.88f0, tst = 114.11263f0, mem 

Turn no:6
Ex.Generated words: 
ulbildiğiniz
eliştiğ
talzak
Yakilene
duturlette



┣▍                   ┫ [2.00%, 1/50, 00:00/00:09, 5.64i/s] (dev = 1.1434118f0, tst = 0.029774528f0, mem = 7.6934815f9)
┣████████████████████┫ [100.00%, 50/50, 00:00/00:00, 111.79i/s] (dev = 0.8530328f0, tst = 0.022213204f0, mem = 8.881943f9)

┣                    ┫ [0.25%, 1/400, 00:30/03:19:25, 29.91s/i] (dev = 4017.1536f0, tst = 100.132904f0, mem = 8.881943f9)
┣▎                   ┫ [1.25%, 5/400, 01:30/01:59:47, 14.98s/i] (dev = 4032.1711f0, tst = 99.43105f0, mem = 8.881944f9)
┣▍                   ┫ [2.25%, 9/400, 02:29/01:50:41, 14.90s/i] (dev = 4032.6487f0, tst = 98.51256f0, mem = 8.881896f9)
┣▋                   ┫ [3.25%, 13/400, 03:29/01:47:13, 14.91s/i] (dev = 4091.6973f0, tst = 99.65001f0, mem = 8.881867f9)
┣▊                   ┫ [4.25%, 17/400, 04:28/01:45:13, 14.81s/i] (dev = 4094.1638f0, tst = 99.50047f0, mem = 8.541483f9)
┣█                   ┫ [5.25%, 21/400, 05:28/01:44:12, 14.98s/i] (dev = 4035.8088f0, tst = 97.96848f0, mem = 7.8809914f8)
┣█▎                  ┫ [6.25%,

┣████████████▊       ┫ [64.25%, 257/400, 01:03:30/01:38:51, 14.89s/i] (dev = 3966.4988f0, tst = 98.26141f0, mem = 8.2448525f8)
┣█████████████       ┫ [65.25%, 261/400, 01:04:30/01:38:51, 14.84s/i] (dev = 3934.7551f0, tst = 97.55963f0, mem = 8.2448525f8)
┣█████████████▎      ┫ [66.25%, 265/400, 01:05:29/01:38:51, 14.92s/i] (dev = 3933.2686f0, tst = 97.39904f0, mem = 8.246163f8)
┣█████████████▍      ┫ [67.25%, 269/400, 01:06:29/01:38:52, 14.97s/i] (dev = 3928.7366f0, tst = 97.204056f0, mem = 8.332566f8)
┣█████████████▋      ┫ [68.25%, 273/400, 01:07:29/01:38:53, 14.96s/i] (dev = 3923.136f0, tst = 97.062454f0, mem = 8.332566f8)
┣█████████████▊      ┫ [69.25%, 277/400, 01:08:29/01:38:53, 14.93s/i] (dev = 3923.5994f0, tst = 97.02804f0, mem = 8.33258f8)
┣██████████████      ┫ [70.25%, 281/400, 01:09:29/01:38:54, 14.90s/i] (dev = 3924.6892f0, tst = 96.95517f0, mem = 8.33258f8)
┣██████████████▎     ┫ [71.25%, 285/400, 01:10:28/01:38:54, 14.91s/i] (dev = 3921.1782f0, tst = 97.01894f0, mem = 8.3

Turn no:7
Ex.Generated words: 
nıfsul
Günelliyle
kırmalarını
soğluları
AGMİRÇARISA



┣▍                   ┫ [2.00%, 1/50, 00:00/00:09, 5.85i/s] (dev = 0.8480584f0, tst = 0.02208944f0, mem = 7.9679514f9)
┣████████████████████┫ [100.00%, 50/50, 00:00/00:00, 117.56i/s] (dev = 0.63978076f0, tst = 0.016664011f0, mem = 8.966383f9)

┣                    ┫ [0.25%, 1/400, 00:30/03:22:02, 30.30s/i] (dev = 3900.6162f0, tst = 97.99396f0, mem = 8.966128f9)

InterruptException: InterruptException:

In [2]:
Knet.save("text-gan-model.jld2", "genmodel", genmodel)

In [10]:
function generate(s::GenModel; start="", maxlength=30)
    s.rnn.h, s.rnn.c = 0, 0
    Z = get_z(s.zsize, 1, 1)
    chars = fill(s.charset.eow, 1)

    starting_index = 1
    for i in 1:length(start)
        push!(chars, s.charset.c2i[start[i]])
        charembed = s.embed(chars[i:i], Z)
        rnn_out = s.rnn(charembed)
        starting_index += 1
    end
    
    for i in starting_index:maxlength
        charembed = s.embed(chars[i:i], Z)
        rnn_out = s.rnn(charembed)
        dims = size(rnn_out)
        output = s.projection.w' * reshape(rnn_out, dims[1], dims[2] * dims[3])
        push!(chars, s.charset.c2i[ sample(s.charset.i2c, Weights(Array(softmax(reshape(output, length(s.charset.i2c)))))) ] )
#         push!(chars, argmax(output)[1])
        if chars[end] == s.charset.eow
            break
        end
    end
    
    join([ s.charset.i2c[i] for i in chars ], "")
end

generate (generic function with 1 method)