# Text-GAN Turkish word generation

In [6]:
using Knet, Test, Base.Iterators, Printf, LinearAlgebra, CuArrays, Random, IterTools

# Character vocabulary: maps characters to integer indices and back.
# Index 1 is always the end-of-word symbol `eow`; the charset's characters follow in order.
struct Charset
    c2i::Dict{Any,Int}   # character (or eow symbol) -> index
    i2c::Vector{Any}     # index -> character (or eow symbol)
    eow::Int             # index of the end-of-word symbol
end

# Build a Charset from the characters of `charset`, reserving the first slot for `eow`.
# Prints the index-to-character table as a side effect.
function Charset(charset::String; eow="")
    symbols = vcat(eow, collect(charset))
    print(symbols)
    lookup = Dict(sym => idx for (idx, sym) in pairs(symbols))
    return Charset(lookup, symbols, lookup[eow])
end

# Streams a text file line by line; iterating yields each line encoded as a vector of
# character indices looked up in `charset`.
struct TextReader
    file::String      # path of the text file to read
    charset::Charset  # vocabulary used to map characters to indices
end

# Iterate over the lines of `r.file`. The iteration state is the open stream: it is opened
# on the first call and closed (ending iteration) once EOF is reached.
# Characters missing from the charset are encoded as the end-of-word index.
function Base.iterate(r::TextReader, s=nothing)
    io = s === nothing ? open(r.file) : s
    if eof(io)
        close(io)
        return nothing
    end
    line = readline(io)
    table, fallback = r.charset.c2i, r.charset.eow
    encoded = [get(table, ch, fallback) for ch in line]
    return encoded, io
end

# A TextReader reads until EOF, so its length is not known up front.
function Base.IteratorSize(::Type{TextReader})
    return Base.SizeUnknown()
end

# Every element is an encoded line: a Vector{Int} of character indices.
function Base.IteratorEltype(::Type{TextReader})
    return Base.HasEltype()
end

function Base.eltype(::Type{TextReader})
    return Vector{Int}
end

# Iterator that groups encoded words from `src` into length buckets and emits a batch
# (built by `batchmaker`) whenever a bucket reaches `batchsize` words.
struct WordsData
    src::TextReader        # source of encoded words
    batchsize::Int         # number of words per emitted batch
    maxlength::Int         # words longer than this are skipped
    batchmajor::Bool       # read by the batch maker (the commented-out arraybatch transposes the batch when true)
    bucketwidth::Int       # range of word lengths covered by one bucket
    buckets::Vector        # buckets[i] holds words waiting to be batched
    batchmaker::Function   # called as batchmaker(d::WordsData, bucket) to build a batch
end

# Build a bucketed batch iterator over the words produced by `src`.
#
# Keywords:
#   batchmaker  - function (d::WordsData, bucket) -> batch.
#                 NOTE(review): the default `arraybatch` is commented out in this notebook;
#                 pass `batchmaker` explicitly unless it is defined elsewhere.
#   batchsize   - words per batch.
#   maxlength   - words longer than this are dropped.
#   batchmajor  - stored for the batch maker to read.
#   bucketwidth - range of word lengths grouped into one bucket.
#   numbuckets  - number of buckets. Defaults to maxlength ÷ bucketwidth, capped at 128,
#                 and never less than 1 (an empty bucket list cannot hold any word).
#
# Throws ArgumentError if `numbuckets < 1`.
function WordsData(src::TextReader; batchmaker = arraybatch, batchsize = 128, maxlength = typemax(Int),
                batchmajor = false, bucketwidth = 2,
                numbuckets = max(1, min(128, maxlength ÷ bucketwidth)))
    numbuckets >= 1 || throw(ArgumentError("numbuckets must be at least 1, got $numbuckets"))
    buckets = [ [] for _ in 1:numbuckets ] # buckets[i] collects encoded words of similar length
    return WordsData(src, batchsize, maxlength, batchmajor, bucketwidth, buckets, batchmaker)
end

# The number of batches depends on the file contents, so it is unknown in advance.
function Base.IteratorSize(::Type{WordsData})
    return Base.SizeUnknown()
end

# Batches are declared as 2-tuples (what the batch maker is expected to return).
function Base.IteratorEltype(::Type{WordsData})
    return Base.HasEltype()
end

function Base.eltype(::Type{WordsData})
    return NTuple{2}
end

# Produce length-bucketed batches of encoded words.
#
# Iteration state:
#   nothing  - fresh start: all buckets are cleared and the source is (re)opened;
#   a stream - the TextReader state while the source is being read;
#   0        - the source is exhausted; each call flushes one partially filled bucket
#              until all are empty, then iteration ends.
function Base.iterate(d::WordsData, state=nothing)
    if state === 0
        # Source finished: emit leftover words, one non-empty bucket per call.
        for i in eachindex(d.buckets)
            if !isempty(d.buckets[i])
                batch = d.batchmaker(d, d.buckets[i])
                d.buckets[i] = []
                return batch, state
            end
        end
        return nothing
    elseif state === nothing
        # Fresh pass: drop words left over from any earlier, abandoned iteration.
        for i in eachindex(d.buckets)
            d.buckets[i] = []
        end
    end

    while true
        src_next = iterate(d.src, state)
        src_next === nothing && return iterate(d, 0)

        word, state = src_next
        wordlength = length(word)
        wordlength > d.maxlength && continue

        # Bucket by length; clamp so that empty words (length 0, e.g. blank lines) go to the
        # first bucket instead of index 0, and over-long words go to the last bucket.
        i = clamp(cld(wordlength, d.bucketwidth), 1, length(d.buckets))

        push!(d.buckets[i], word)
        if length(d.buckets[i]) == d.batchsize
            batch = d.batchmaker(d, d.buckets[i])
            d.buckets[i] = []  # rebind rather than empty! so `batch` keeps its own data
            return batch, state
        end
    end
end

# function arraybatch(d::WordsData, bucket)
#     src_eow = d.src.charset.eow
#     src_lengths = map(x -> length(x), bucket)
#     max_length = max(src_lengths...)
#     x = zeros(Int64, length(bucket), max_length + 2) # default d.batchmajor is false

#     for (i, v) in enumerate(bucket)
#         to_be_added = fill(src_eow, max_length - length(v) + 1)
#         x[i,:] = [src_eow; v; to_be_added]
#     end

#     d.batchmajor && (x = x')
#     return (x[:, 1:end-1], x[:, 2:end])
# end

# Read the word list stored in `fname`, one entry per line, with line endings stripped.
# The do-block guarantees the file is closed even if reading throws.
# Returns a Vector{Any} of the lines (element type kept as Any for existing callers).
function readwordset(fname)
    return open(fname) do io
        Any[line for line in eachline(io)]
    end
end

readwordset (generic function with 1 method)

### Parts Common to the Generator, Discriminator, and Sampler (G/D/S)

In [7]:
# Embedding / projection layer holding a single weight array `w`.
# In this notebook it is created with size (embedding size, charset size).
struct Embed
    w
end

# Allocate a Knet parameter with dimensions `shape` as the weight array.
# NOTE(review): a single-argument call such as Embed(64) dispatches to the default
# constructor (storing 64 as `w`), not to this method.
Embed(shape...) = Embed(param(shape...))

Embed

## Generator

In [8]:
# Sample latent noise z ~ N(0, 1) as Float32 with dimensions `shape`, wrapped in `atype`.
# The default `atype` is KnetArray (GPU); pass `atype=Array` to sample on the CPU.
get_z(shape...; atype=KnetArray) = atype(randn(Float32, shape...))


### Not used 
# concatenate z with embedding vectors, z -> (z_size, B), returns (E+z_size, B, T)
# this will be used to feed Z to generator at each timestep
# function (l::Embed)(x, z)
#     em = l.w[:, x]
#     z_array = cat((z for i in 1:size(em, 3))...; dims=(3))
#     cat(em, z_array; dims=(1))
# end

# Generator model: an RNN whose initial hidden/cell states are random noise, followed by a
# linear projection from hidden units to per-character scores.
struct GModel
    projection::Embed   # projection weights, size (hidden, charset size)
    rnn::RNN            # Knet RNN with input size 1
    dropout::Real       # dropout probability applied before the projection
    charset::Charset    # vocabulary used to decode generated indices
end

# Build a generator with `layers` stacked RNN layers of `hidden` units over `charset`.
# The RNN is created before the projection, keeping the random-initialization order.
function GModel(hidden::Int, charset::Charset; layers=2, dropout=0)
    recurrent = RNN(1, hidden; numLayers=layers, dropout=dropout)  # input size 1: fed a constant
    vocabsize = length(charset.i2c)
    return GModel(Embed(hidden, vocabsize), recurrent, dropout, charset)
end

# Generator forward pass. The latent noise z is injected as the RNN's initial hidden and
# cell states, each of size (hidden, batchsize, layers); the input sequence itself is all
# ones of size (1, batchsize, timesteps).
# Returns softmax probabilities over characters of size (charset size, batchsize, timesteps).
function (s::GModel)(timesteps, batchsize)
    H, L = s.rnn.hiddenSize, s.rnn.numLayers
    s.rnn.h = get_z(H, batchsize, L)
    s.rnn.c = get_z(H, batchsize, L)
    constant_input = KnetArray(ones(Float32, (1, batchsize, timesteps)))
    hidden = s.rnn(constant_input)
    hsize, bsize, tsize = size(hidden)
    flat = reshape(hidden, hsize, bsize * tsize)             # (H, B*T) for one matmul
    scores = s.projection.w' * dropout(flat, s.dropout)      # (C, B*T)
    return reshape(softmax(scores), size(scores, 1), bsize, tsize)
end

# Decode `batchsize` words from the generator by running it for `maxlength` timesteps and
# greedily taking the most probable character at each timestep.
#
# With `truncate=true` (default), each word ends at the first end-of-word symbol, so
# characters generated after it are not appended. Pass `truncate=false` to concatenate
# every timestep's character (the end-of-word symbol itself maps to "").
# Returns a Vector{String}.
function generate(s::GModel, maxlength, batchsize; truncate::Bool=true)
    # Copy (C, B, T) probabilities to host memory so per-element argmax/indexing is cheap.
    probs = Array(s(maxlength, batchsize))
    eow = s.charset.eow
    words = String[]
    for b in axes(probs, 2)
        ids = [argmax(view(probs, :, b, t)) for t in axes(probs, 3)]
        if truncate
            stop = findfirst(isequal(eow), ids)
            stop === nothing || (ids = ids[1:stop-1])
        end
        push!(words, join(s.charset.i2c[ids]))
    end
    return words
end

generate (generic function with 1 method)

## Word Sampler

### TODO: the word sampler will be used to train the discriminator.
The sampler should take B and T (batch size and number of timesteps) as parameters
and return an (X, Y) tuple,
where X is a tensor of size (C, B, T)
and Y is an array of length B (presumably one real/generated label per word).
The B words in a batch are a mix of real words and generated words.
C is the charset size; each entry of X is the weight of the corresponding character.
For generated words, the generator already outputs a (C, B, T) tensor.
For real words, each word must be converted to a (C, T) array,
where every character is represented either by a one-hot vector or by a Gumbel-Softmax sample (a normalized, relaxed one-hot vector based on the Gumbel-Max trick).

In [None]:
# Placeholder for the word sampler used to train the discriminator; it has no fields yet.
struct Sampler end

# Stub: the sampler is not implemented, so iteration ends immediately (always `nothing`).
Base.iterate(s::Sampler, state=nothing) = nothing

## Discriminator

In [4]:
# Soft embedding used by the discriminator: each timestep's vector of character weights is
# multiplied by the embedding matrix (a weighted sum of character embeddings), which avoids
# sampling or taking an argmax over the generator's output.
# Shapes: x is (C, B, T), l.w is (E, C); the result is (T, E, 1, B), laid out for conv4
# with time along the height, embedding along the width, and one input channel.
function (l::Embed)(x)
    dims = size(x)  # shape of the input itself (previously read from an unrelated global)
    em = l.w * reshape(x, dims[1], dims[2] * dims[3])   # (E, B*T): one matmul for all steps
    em = reshape(em, size(em, 1), dims[2], dims[3])     # (E, B, T)
    em = permutedims(em, [3, 1, 2])                     # (T, E, B) for the convolution
    return reshape(em, dims[3], size(em, 2), 1, dims[2])  # (T, E, 1, B): add channel dim
end

# Convolution layer followed by pooling over the full spatial extent (global pooling) and
# the activation `f`. Fields: w filters, b bias, f activation, p input dropout probability.
struct Conv
    w
    b
    f
    p
end

function (c::Conv)(x)
    z = conv4(c.w, dropout(x, c.p))
    fullwindow = (size(z, 1), size(z, 2))  # pool window covers the whole output map
    return c.f.(pool(z .+ c.b; window=fullwindow))
end

# Create a layer with w1×w2 filters mapping cx input channels to cy output channels.
Conv(w1::Int, w2::Int, cx::Int, cy::Int, f=relu; pdrop=0) =
    Conv(param(w1, w2, cx, cy), param0(1, 1, cy, 1), f, pdrop)

# Fully connected layer: f.(w * x .+ b) with dropout `p` on the input.
# Fields: w weights (out, in), b bias, f activation, p input dropout probability.
struct Dense
    w
    b
    f
    p
end

function (d::Dense)(x)
    flat = mat(dropout(x, d.p))  # mat flattens an N-D input into a matrix for matmul
    return d.f.(d.w * flat .+ d.b)
end

# Create a layer mapping `i` inputs to `o` outputs.
Dense(i::Int, o::Int, f=relu; pdrop=0) = Dense(param(o, i), param0(o), f, pdrop)

# Discriminator: perform convolution, then global max pooling; concatenate the pooled
# outputs and feed them through the sequential dense layers.
mutable struct DisModel
    charset::Charset
    embed::Embed    # soft embedding applied to the (C, B, T) character weights
    filters         # Conv layers, each applied to the embedded input
    dense_layers    # Dense layers applied in order to the concatenated filter outputs
end

# Build a discriminator with its own, freshly initialized embedding layer of size
# (embeddingSize, length(charset.i2c)), i.e. one column per index the charset can produce.
function DisModel(charset, embeddingSize::Int, filters, denselayers)
    # Use the arguments rather than globals (the old code read an undefined `embed_size`
    # and the notebook-level `tr_charset`).
    embedding = Embed(embeddingSize, length(charset.i2c))
    return DisModel(charset, embedding, filters, denselayers)
end

# Build a discriminator that reuses an existing embedding layer, e.g. the generator's
# projection, so both models share those weights.
# When `charset` is already a Charset, the struct's default constructor has a more specific
# signature and is selected instead of this method. This method therefore only sees other
# `charset` values. Converting explicitly makes an unsupported type raise a MethodError,
# instead of this method calling itself until the stack overflows.
function DisModel(charset, embeddingLayer::Embed, filters, denselayers)
    return DisModel(convert(Charset, charset)::Charset, embeddingLayer, filters, denselayers)
end

# Discriminator forward pass on character weights of shape (C, B, T):
# soft-embed the input, apply every Conv filter bank to the embedding, concatenate their
# pooled outputs along dimension 3, then run the result through the dense layers in order.
function (c::DisModel)(x)
    embedded = c.embed(x)
    features = [conv(embedded) for conv in c.filters]
    stacked = cat(features...; dims=3)
    return foldl((h, layer) -> layer(h), c.dense_layers; init=stacked)
end

# Negative log-likelihood of the labels `y` under the discriminator's scores for `x`.
# `loss` below calls this with average=false and reads the result as a (sum, count) pair.
function (c::DisModel)(x, y; average=true)
    scores = c(x)
    return nll(scores, y; average=average)
end

# Evaluate `model` over every (x, y) batch in `data`.
# `model(x, y; average=false)` must return a (total loss, instance count) pair.
#   average=true  -> mean of the per-batch average losses (0 for empty `data`);
#   average=false -> (total loss, total count) summed over all batches.
function loss(model, data; average=true)
    total = 0
    count = 0
    batchmean_sum = 0
    nbatches = 0
    for (x, y) in data
        batchtotal, batchcount = model(x, y; average=false)
        total += batchtotal
        count += batchcount
        batchmean_sum += batchtotal / batchcount
        nbatches += 1
    end
    if average
        # Divide by the batch count: the sum of batch means alone grows with the dataset size.
        return nbatches == 0 ? batchmean_sum : batchmean_sum / nbatches
    end
    return total, count
end

In [5]:
# Turkish alphabet (upper and lower case, no Q/W/X); Charset reserves index 1 for end-of-word.
char_set = "ABCDEFGHIJKLMNOPRSTUVYZabcdefghijklmnoprstuvyzÇÖÜçöüĞğİıŞş"
tr_charset = Charset(char_set)

embedding_size = 64
gmodel = GModel(embedding_size, tr_charset; dropout=0.2)
out = gmodel(30, 16) # (T, B) -> (C, B, T)

# generate(gmodel, 30, 10)

# Discriminator: convolutional filter banks 2, 3 and 4 characters wide, each with
# `filter_no` output channels, reusing the generator's projection as its embedding layer.
filter_no = 20
conv_layers = (
        Conv(2,embedding_size,1,filter_no; pdrop=0.2),
        Conv(3,embedding_size,1,filter_no; pdrop=0.2),
        Conv(4,embedding_size,1,filter_no; pdrop=0.2),
        )
# Each bank is globally pooled down to `filter_no` features, so the first dense layer's
# input size follows from the number of banks instead of a hard-coded 60.
dense_layers = (
        Dense(length(conv_layers) * filter_no, 64, pdrop=0.3),
        Dense(64,2,sigm,pdrop=0.3)
        )
dismodel = DisModel(tr_charset, gmodel.projection, conv_layers, dense_layers)


dismodel(gmodel(30, 16))

Any["", 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'R', 'S', 'T', 'U', 'V', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'r', 's', 't', 'u', 'v', 'y', 'z', 'Ç', 'Ö', 'Ü', 'ç', 'ö', 'ü', 'Ğ', 'ğ', 'İ', 'ı', 'Ş', 'ş']

2×16 KnetArray{Float32,2}:
 0.499999  0.499999  0.5       0.499999  …  0.5  0.499999  0.499999  0.5
 0.499999  0.499999  0.499999  0.499999     0.5  0.499999  0.5       0.5