# Text-GAN Turkish word generation

In [2]:
using Knet, Test, Base.Iterators, Printf, LinearAlgebra, CuArrays, Random, IterTools, StatsBase

struct Charset
    c2i::Dict{Any,Int}
    i2c::Vector{Any}
    eow::Int
end

function Charset(charset::String; eow="")
    i2c = [ eow; [ c for c in charset ]  ]
    print(i2c)
    c2i = Dict( c => i for (i, c) in enumerate(i2c))
    return Charset(c2i, i2c, c2i[eow])
end

struct TextReader
    file::String
    charset::Charset
end

function Base.iterate(r::TextReader, s=nothing)
    s === nothing && (s = open(r.file))
    eof(s) && return close(s)
    return [ get(r.charset.c2i, c, r.charset.eow) for c in readline(s)], s
end

Base.IteratorSize(::Type{TextReader}) = Base.SizeUnknown()
Base.IteratorEltype(::Type{TextReader}) = Base.HasEltype()
Base.eltype(::Type{TextReader}) = Vector{Int}

struct WordsData
    src::TextReader        
    batchsize::Int         
    maxlength::Int         
    batchmajor::Bool       
    bucketwidth::Int    
    buckets::Vector        
    batchmaker::Function   
end

function WordsData(src::TextReader; batchmaker = arraybatch, batchsize = 128, maxlength = typemax(Int),
                batchmajor = false, bucketwidth = 2, numbuckets = min(128, maxlength ÷ bucketwidth))
    buckets = [ [] for i in 1:numbuckets ] # buckets[i] is an array of sentence pairs with similar length
    WordsData(src, batchsize, maxlength, batchmajor, bucketwidth, buckets, batchmaker)
end

Base.IteratorSize(::Type{WordsData}) = Base.SizeUnknown()
Base.IteratorEltype(::Type{WordsData}) = Base.HasEltype()
Base.eltype(::Type{WordsData}) = NTuple{2}

function Base.iterate(d::WordsData, state=nothing)
    if state == 0 # When file is finished but buckets are partially full 
        for i in 1:length(d.buckets)
            if length(d.buckets[i]) > 0
                batch = d.batchmaker(d, d.buckets[i])
                d.buckets[i] = []
                return batch, state
            end
        end
        return nothing # Finish iteration
    elseif state === nothing
        # Just to make sure
        for i in 1:length(d.buckets)
            d.buckets[i] = []
        end
        state = nothing
    end

    while true
        src_next = iterate(d.src, state)
        
        if src_next === nothing
            state = 0
            return iterate(d, state)
        end
        
        (src_word, src_state) = src_next
        state = src_state
        src_length = length(src_word)
        
        (src_length > d.maxlength) && continue

        i = Int(ceil(src_length / d.bucketwidth))
        i > length(d.buckets) && (i = length(d.buckets))

        push!(d.buckets[i], src_word)
        if length(d.buckets[i]) == d.batchsize
            batch = d.batchmaker(d, d.buckets[i])
            d.buckets[i] = []
            return batch, state
        end
    end
end

function arraybatch(d::WordsData, bucket)
    src_eow = d.src.charset.eow
    src_lengths = map(x -> length(x), bucket)
    max_length = max(src_lengths...)
    x = zeros(Int64, length(bucket), max_length + 2) # default d.batchmajor is false

    for (i, v) in enumerate(bucket)
        to_be_added = fill(src_eow, max_length - length(v) + 1)
        x[i,:] = [src_eow; v; to_be_added]
    end

    d.batchmajor && (x = x')
    return (x[:, 1:end-1], x[:, 2:end])
end

function readwordset(fname)
    words = []
    fi = open(fname)
    while !eof(fi)
        push!(words, readline(fi))
    end
    close(fi)
    words
end

readwordset (generic function with 1 method)

### G/D Common Parts

In [23]:
struct Embed; w; end

function Embed(charsetsize::Int, embedsize::Int)
    Embed(param(embedsize, charsetsize))
end

### Not used 
# concatenate z with embedding vectors, z -> (z_size, B), returns (E+z_size, B, T)
# this will be used to feed Z to generator at each timestep
# function (l::Embed)(x, z)
#     em = l.w[:, x]
#     z_array = cat((z for i in 1:size(em, 3))...; dims=(3))
#     cat(em, z_array; dims=(1))
# end

# This one to be used by DModel
# (E, B, T) -> (T, E, 1, B)
function (l::Embed)(x)
    em=permutedims(l.w[:, x], [3, 1, 2])
    ds=size(em)
    em=reshape(em, ds[1], ds[2], 1, ds[3])
end

In [83]:
get_z(shape...) = KnetArray(randn(Float32, shape...))
rnn = RNN(1, 32; bidirectional=false, numLayers=2)
rnn.c, rnn.h = get_z(32, 1, 2), get_z(32, 1, 2)
ou = rnn(KnetArray(ones(Float32, (1, 4, 6))))
# rnn.h

AssertionError: AssertionError: r.h == nothing || (r.h == 0 || vec(value(r.h)) isa WTYPE && (ndims(r.h) <= 3 && (size(r.h, 1), size(r.h, 2), size(r.h, 3)) == HSIZE))

### Generator Parts

In [None]:
get_z(shape...) = KnetArray(randn(Float32, shape...))

# Generator model
struct GModel
    projection::Embed
    rnn::RNN        
    dropout::Real
    charset::Charset 
end

function GModel(hidden::Int, charset::Charset; layers=2, dropout=0)
    rnn = RNN(1, hidden; bidirectional=false, numLayers=layers, dropout=dropout)
    projection = Embed(length(charset.i2c), hidden)
    LModel(projection, rnn, dropout, charset)
end

# Generator forward pass, here Z is our latent var -> (H, Tx, )
function (s::GModel)(z; average=true)
    s.rnn.h, s.rnn.c = z[], z[]
    srcembed = s.srcembed(src)
    rnn_out = s.rnn(srcembed)
    dims = size(rnn_out)
    output = s.projection(dropout(reshape(rnn_out, dims[1], dims[2] * dims[3]), s.dropout))
    scores = reshape(output, size(output, 1), dims[2], dims[3])
    nll(scores, mask(tgt, s.srccharset.eow); dims=1, average=average)
end

In [10]:


# Generating words using the LM with sampling
function generate(s::LModel; start="", maxlength=30)
    s.rnn.h, s.rnn.c = 0, 0
    chars = fill(s.srccharset.eow, 1)
    
    starting_index = 1
    for i in 1:length(start)
        push!(chars, s.srccharset.c2i[start[i]])
        charembed = s.srcembed(chars[i:i])
        rnn_out = s.rnn(charembed)
        starting_index += 1
    end
    
    for i in starting_index:maxlength
        charembed = s.srcembed(chars[i:i])
        rnn_out = s.rnn(charembed)
        output = model.projection(dropout(rnn_out, model.dropout))
        push!(chars, s.srccharset.c2i[ sample(s.srccharset.i2c, Weights(Array(softmax(reshape(output, length(s.srccharset.i2c)))))) ] )
        
        if chars[end] == s.srccharset.eow
            break
        end
    end
    
    join([ s.srccharset.i2c[i] for i in chars ], "")
end

1×50 KnetArray{Float64,2}:
 -0.0747458  -1.13555  1.09597  -1.56335  …  -1.62976  -1.24693  0.66351