In [1]:
using Statistics  # mean()
using Random  # randperm()
using Dates

include("./Functional.jl")
include("./Layer.jl")
include("./Optimizer.jl")
include("./MNIST.jl")
;

In [2]:
# Seed the global RNG so that the `randn`/`randperm` calls made further down
# (weight initialisation, batch shuffling, noise sampling) are repeatable between runs.
Random.seed!(2019)
;

### Load MNIST

In [3]:
# Raw training split as returned by the project-local MNIST module.
_x_train = MNIST.images(:train)
_y_train = MNIST.labels(:train)

# Flatten every image into one Float64 column and concatenate the columns.
# `reduce(hcat, ...)` is used instead of `hcat(cols...)` so that tens of thousands of
# arrays are not splatted into a single varargs call.
x_train = reduce(hcat, [vec(Float64.(img)) for img in _x_train])
# Affine rescale v -> 2v - 1.
# NOTE(review): presumably pixels are in [0, 1], so this maps them onto [-1, 1] to match
# the generator's tanh output — confirm against MNIST.jl.
x_train = x_train .* 2.0 .- 1.0
# One-hot encode the labels over the digit classes 0:9.
y_train = Functional.onehot(Float64, _y_train, 0:9)
@show size(x_train), size(y_train)
;

(size(x_train), size(y_train)) = ((784, 60000), (10, 60000))


### Define DataLoader

In [4]:
"""
    dataloader(x, y; batch_size=1, shuffle=false)

Return a `Channel` that yields `(x_batch, y_batch)` tuples, where each batch is a block of
consecutive columns of `x` and the matching columns of `y`.

# Arguments
- `x`, `y`: arrays whose second dimension indexes the samples; both are sliced with the
  same column range, so they must have the same number of columns.

# Keywords
- `batch_size`: maximum number of columns per batch (must be at least 1).
- `shuffle`: when `true`, the columns of `x` and `y` are permuted with one shared random
  permutation before batching.

Every batch holds exactly `batch_size` columns except possibly the last one, which holds
the remaining columns. No batch is produced for input with zero columns.

# Throws
- `ArgumentError` if `batch_size < 1`.
"""
function dataloader(x, y; batch_size=1, shuffle=false)
    # A non-positive batch size would never advance through the data.
    batch_size >= 1 ||
        throw(ArgumentError("batch_size must be at least 1, got $(batch_size)"))

    function producer(c::Channel, x, y, batch_size, shuffle)
        data_size = size(x, 2)
        if shuffle
            # One shared permutation keeps every sample aligned with its label.
            randidx = randperm(data_size)
            x = x[:, randidx]
            y = y[:, randidx]
        end
        # Step through the columns in strides of `batch_size`; clamping the upper bound
        # guarantees the trailing batch is never larger than `batch_size`.
        for first_col in 1:batch_size:data_size
            last_col = min(first_col + batch_size - 1, data_size)
            put!(c, (x[:, first_col:last_col], y[:, first_col:last_col]))
        end
    end

    return Channel((ch_arg) -> producer(ch_arg, x, y, batch_size, shuffle))
end
;

In [5]:
"""
    Generator{T}

Container for the generator network: three affine layers, each followed by an activation
layer, with a tanh layer on the output. `T` is the element type of the affine layers.

The batch-normalisation fields are currently disabled (kept below as comments).
"""
mutable struct Generator{T}
    # --- hidden block 1 ---
    a1lyr::Layer.AffineLayer{T}
    # NOTE: the field is named "leakyrelu" but its declared type is the plain ReluLayer.
    leakyrelu1lyr::Layer.ReluLayer
    # bn1::Layer.BatchNormalization

    # --- hidden block 2 ---
    a2lyr::Layer.AffineLayer{T}
    leakyrelu2lyr::Layer.ReluLayer
    # bn2::Layer.BatchNormalization

    # --- output block ---
    a3lyr::Layer.AffineLayer{T}
    tanhlyr::Layer.TanhLayer

    # Collection of the affine layers' weight and bias arrays (left untyped).
    params
end

# Generator{T}(input_size, hidden_size, hidden_size2, output_size; weight_init_std=0.1)
#
# Build a generator with the layer stack
#   affine(input_size -> hidden_size)    -> relu
#   affine(hidden_size -> hidden_size2)  -> relu
#   affine(hidden_size2 -> output_size)  -> tanh
#
# Weights are drawn from `randn(T, ...)` and scaled by `weight_init_std`; biases start at
# zero. `params` lists the arrays in the order W1, b1, W2, b2, W3, b3.
#
# The scale is converted to `T` before use. Multiplying by a Float64 scale would promote
# the weights to Float64 for any other `T`, leaving weights and biases with different
# element types.
function (::Type{Generator{T}})(
    input_size::Int, hidden_size::Int, hidden_size2::Int, output_size::Int;
    weight_init_std::Real=0.1,
) where T
    scale = T(weight_init_std)

    # The order of the randn calls is kept so a seeded run draws the same weights.
    W1 = scale .* randn(T, hidden_size, input_size)
    b1 = zeros(T, hidden_size)
    W2 = scale .* randn(T, hidden_size2, hidden_size)
    b2 = zeros(T, hidden_size2)
    W3 = scale .* randn(T, output_size, hidden_size2)
    b3 = zeros(T, output_size)
#     gamma1 = ones(hidden_size)
#     beta1 = zeros(hidden_size)
#     gamma2 = ones(hidden_size2)
#     beta2 = zeros(hidden_size2)

    a1lyr = Layer.AffineLayer(W1, b1)
    leakyrelu1lyr = Layer.ReluLayer()
#     bn1 = Layer.BatchNormalization(gamma1, beta1)
    a2lyr = Layer.AffineLayer(W2, b2)
    leakyrelu2lyr = Layer.ReluLayer()
#     bn2 = Layer.BatchNormalization(gamma2, beta2)
    a3lyr = Layer.AffineLayer(W3, b3)
    tanhlyr = Layer.TanhLayer()
    params = [a1lyr.W, a1lyr.b, a2lyr.W, a2lyr.b, a3lyr.W, a3lyr.b] #, bn1.gamma, bn1.beta, bn2.gamma, bn2.beta]
    # Construct with the explicit parameter so the result is the requested Generator{T}.
    return Generator{T}(a1lyr, leakyrelu1lyr, a2lyr, leakyrelu2lyr, a3lyr, tanhlyr, params)
end

"""
    setparams(net::Generator, params)

Rebind the weight and bias of the three affine layers from `params`, read in the order
W1, b1, W2, b2, W3, b3 (indices 1 through 6). Returns `params[6]`.

`net.params` itself is not touched by this function.
"""
function setparams(net::Generator, params)
    affine_layers = (net.a1lyr, net.a2lyr, net.a3lyr)
    # Layer k takes its weight from slot 2k-1 and its bias from slot 2k.
    for (k, lyr) in enumerate(affine_layers)
        lyr.W = params[2k - 1]
        lyr.b = params[2k]
    end
    # Batch-norm parameters (slots 7..10) are not restored while those layers are disabled.
    return params[6]
end

"""
    forward(net::Generator, x)

Push `x` through the generator's layers in order (affine, relu, affine, relu, affine,
tanh) via `Layer.forward` and return the output of the final tanh layer.
"""
function forward(net::Generator, x)
    # Batch-norm layers (bn1, bn2) would sit between each affine layer and its activation.
    stages = (
        net.a1lyr, net.leakyrelu1lyr,
        net.a2lyr, net.leakyrelu2lyr,
        net.a3lyr, net.tanhlyr,
    )
    return foldl((h, lyr) -> Layer.forward(lyr, h), stages; init=x)
end

"""
    backward(net::Generator, y)

Propagate the upstream gradient `y` through the generator's layers in reverse order via
`Layer.backward`, then return the affine layers' gradient fields as
`[dW1, db1, dW2, db2, dW3, db3]`. The gradient with respect to the network input is
computed but discarded.
"""
function backward(net::Generator, y)
    reversed_stages = (
        net.tanhlyr, net.a3lyr,
        net.leakyrelu2lyr, net.a2lyr,
        net.leakyrelu1lyr, net.a1lyr,
    )
    # Run for its effect on the layers; the folded result is not needed here.
    foldl((g, lyr) -> Layer.backward(lyr, g), reversed_stages; init=y)
    return [net.a1lyr.dW, net.a1lyr.db, net.a2lyr.dW, net.a2lyr.db, net.a3lyr.dW, net.a3lyr.db] #, net.bn1.dgamma, net.bn1.dbeta, net.bn2.dgamma, net.bn2.dbeta]
end
;

In [6]:
"""
    Discriminator{T}

Container for the discriminator network: three affine layers, the first two followed by
leaky-ReLU layers and the last by a sigmoid layer, plus the BCE loss layer used as the
training criterion. `T` is the element type of the affine layers.

The batch-normalisation fields are currently disabled (kept below as comments).
"""
mutable struct Discriminator{T}
    # --- hidden block 1 ---
    a1lyr::Layer.AffineLayer{T}
    leakyrelu1lyr::Layer.LeakyReluLayer
    # bn1::Layer.BatchNormalization

    # --- hidden block 2 ---
    a2lyr::Layer.AffineLayer{T}
    leakyrelu2lyr::Layer.LeakyReluLayer
    # bn2::Layer.BatchNormalization

    # --- output block and loss ---
    a3lyr::Layer.AffineLayer{T}
    sigmoidlyr::Layer.SigmoidLayer
    criterionlyr::Layer.BCELossLayer

    # Collection of the affine layers' weight and bias arrays (left untyped).
    params
end

# Discriminator{T}(input_size, hidden_size, hidden_size2, output_size; weight_init_std=0.1)
#
# Build a discriminator with the layer stack
#   affine(input_size -> hidden_size)    -> leaky relu
#   affine(hidden_size -> hidden_size2)  -> leaky relu
#   affine(hidden_size2 -> output_size)  -> sigmoid
# plus a BCE loss layer used as the criterion.
#
# Weights are drawn from `randn(T, ...)` and scaled by `weight_init_std`; biases start at
# zero. `params` lists the arrays in the order W1, b1, W2, b2, W3, b3.
#
# The scale is converted to `T` before use. Multiplying by a Float64 scale would promote
# the weights to Float64 for any other `T`, leaving weights and biases with different
# element types.
function (::Type{Discriminator{T}})(
    input_size::Int, hidden_size::Int, hidden_size2::Int, output_size::Int;
    weight_init_std::Real=0.1,
) where T
    scale = T(weight_init_std)

    # The order of the randn calls is kept so a seeded run draws the same weights.
    W1 = scale .* randn(T, hidden_size, input_size)
    b1 = zeros(T, hidden_size)
    W2 = scale .* randn(T, hidden_size2, hidden_size)
    b2 = zeros(T, hidden_size2)
    W3 = scale .* randn(T, output_size, hidden_size2)
    b3 = zeros(T, output_size)
#     gamma1 = ones(hidden_size)
#     beta1 = zeros(hidden_size)
#     gamma2 = ones(hidden_size2)
#     beta2 = zeros(hidden_size2)
    a1lyr = Layer.AffineLayer(W1, b1)
    leakyrelu1lyr = Layer.LeakyReluLayer()
#     bn1 = Layer.BatchNormalization(gamma1, beta1)
    a2lyr = Layer.AffineLayer(W2, b2)
    leakyrelu2lyr = Layer.LeakyReluLayer()
#     bn2 = Layer.BatchNormalization(gamma2, beta2)
    a3lyr = Layer.AffineLayer(W3, b3)
    sigmoidlyr = Layer.SigmoidLayer()
    criterionlyr = Layer.BCELossLayer()
    params = [a1lyr.W, a1lyr.b, a2lyr.W, a2lyr.b, a3lyr.W, a3lyr.b] #, bn2.gamma, bn2.beta]
#     Discriminator(a1lyr, leakyrelu1lyr, bn1, a2lyr, leakyrelu2lyr, bn2, a3lyr, sigmoidlyr, criterionlyr, params)
    # Construct with the explicit parameter so the result is the requested Discriminator{T}.
    return Discriminator{T}(
        a1lyr, leakyrelu1lyr, a2lyr, leakyrelu2lyr, a3lyr, sigmoidlyr, criterionlyr, params
    )
end

"""
    setparams(net::Discriminator, params)

Rebind the weight and bias of the three affine layers from `params`, read in the order
W1, b1, W2, b2, W3, b3 (indices 1 through 6). Returns `params[6]`.

`net.params` itself is not touched by this function.
"""
function setparams(net::Discriminator, params)
    affine_layers = (net.a1lyr, net.a2lyr, net.a3lyr)
    # Layer k takes its weight from slot 2k-1 and its bias from slot 2k.
    for (k, lyr) in enumerate(affine_layers)
        lyr.W = params[2k - 1]
        lyr.b = params[2k]
    end
    # Batch-norm parameters are not restored while those layers are disabled.
    return params[6]
end

"""
    forward(net::Discriminator, x)

Push `x` through the discriminator's layers in order (affine, leaky relu, affine,
leaky relu, affine, sigmoid) via `Layer.forward` and return the sigmoid layer's output.
The loss layer is not applied here; see `criterion`.
"""
function forward(net::Discriminator, x)
    # Batch-norm layers (bn1, bn2) would sit between each affine layer and its activation.
    stages = (
        net.a1lyr, net.leakyrelu1lyr,
        net.a2lyr, net.leakyrelu2lyr,
        net.a3lyr, net.sigmoidlyr,
    )
    return foldl((h, lyr) -> Layer.forward(lyr, h), stages; init=x)
end

"""
    backward(net::Discriminator, y)

Back-propagate from the loss layer down to the first affine layer and return the affine
layers' gradient fields as `[dW1, db1, dW2, db2, dW3, db3]`.

`y` is only used to choose the type of the unit seed gradient (`one(typeof(y))`); its
value does not enter the computation.
NOTE(review): presumably `criterion` must have been called first so the loss layer holds
the prediction/target pair — confirm against Layer.jl.
"""
function backward(net::Discriminator, y)
    seed = one(typeof(y))
    reversed_stages = (
        net.criterionlyr, net.sigmoidlyr, net.a3lyr,
        net.leakyrelu2lyr, net.a2lyr,
        net.leakyrelu1lyr, net.a1lyr,
    )
    # Run for its effect on the layers; the input gradient is not needed here.
    foldl((g, lyr) -> Layer.backward(lyr, g), reversed_stages; init=seed)
    return [net.a1lyr.dW, net.a1lyr.db, net.a2lyr.dW, net.a2lyr.db, net.a3lyr.dW, net.a3lyr.db] #, net.bn2.dgamma, net.bn2.dbeta]
end

"""
    _backward(net::Discriminator, y)

Same backward pass as `backward(::Discriminator, y)`, but return the value produced by
the first affine layer's `Layer.backward` call (the gradient handed on towards the
network input) instead of the parameter gradients. The training loop feeds this result
into the generator's backward pass.

`y` is only used to choose the type of the unit seed gradient.
"""
function _backward(net::Discriminator, y)
    seed = one(typeof(y))
    reversed_stages = (
        net.criterionlyr, net.sigmoidlyr, net.a3lyr,
        net.leakyrelu2lyr, net.a2lyr,
        net.leakyrelu1lyr, net.a1lyr,
    )
    return foldl((g, lyr) -> Layer.backward(lyr, g), reversed_stages; init=seed)
end

"""
    criterion(net::Discriminator{T}, y::AbstractArray{T}, t::AbstractArray{T}) where {T}

Evaluate the discriminator's BCE loss layer on predictions `y` and targets `t` and
return whatever `Layer.forward` returns for that layer.
"""
function criterion(
    net::Discriminator{T}, y::AbstractArray{T}, t::AbstractArray{T}
) where {T}
    return Layer.forward(net.criterionlyr, y, t)
end
;

### Hyperparameters

In [7]:
# --- training schedule ---
const epochs = 100
const batch_size = 100
const learning_rate = 1e-4
const train_size = size(x_train, 2) # => 60000
# Integer ceiling division: `Int32(train_size / batch_size)` throws an InexactError as
# soon as the sample count is not an exact multiple of the batch size.
const iter_per_epoch = Int32(max(cld(train_size, batch_size), 1))

# --- data / model dimensions ---
const noise_size = 100
const image_size = 28

# One fixed latent vector, reused for every checkpoint image so snapshots are comparable.
const fixed_noise = randn(noise_size, 1)

# Layer sizes are derived from the constants above rather than repeated as literals:
# the generator maps noise_size -> image_size^2 (= 784), the discriminator 784 -> 1.
generator = Generator{Float64}(noise_size, 256, 512, image_size^2)
discriminator = Discriminator{Float64}(image_size^2, 256, 128, 1)

gen_optimizer = Optimizer.Adam(generator, learning_rate)
dis_optimizer = Optimizer.SGD(discriminator, learning_rate)
;

### Train

In [8]:
using Images
"""
    save_checkpoints(tensor, epoch, iter)

Write `tensor` as a grayscale PNG to `./checkpoints/<epoch>_<iter>.png`.

`tensor` is rescaled with `v / 2 + 0.5` and reshaped to `image_size × image_size`, so it
must hold exactly `image_size^2` values. `epoch` and `iter` are only interpolated into
the file name. The `./checkpoints` directory is created if it does not exist yet.
Returns whatever `save` returns.
"""
function save_checkpoints(tensor, epoch, iter)
    # NOTE(review): presumably the input is a tanh output in [-1, 1], which this maps
    # onto the [0, 1] intensity range — confirm for other callers.
    tensor = tensor ./ 2 .+ 0.5
    tensor_2d = reshape(tensor, image_size, image_size)
    img = colorview(Gray, tensor_2d)
    # Without this, the first save on a fresh checkout fails on the missing directory.
    mkpath("./checkpoints")
    return save("./checkpoints/$(epoch)_$(iter).png", img)
end

save_checkpoints (generic function with 1 method)

In [9]:
# Alternating GAN training. For every mini-batch:
#   1. one discriminator step on real images (target 1),
#   2. one discriminator step on generated images (target 0),
#   3. one generator step, pushing the discriminator's output on the fakes towards 1.
for epoch = 1:epochs
    iter = 0
    for (x_batch, _) in dataloader(x_train, y_train, batch_size=batch_size, shuffle=true)
        iter += 1
        # Size the targets and the noise from the batch actually received, so a final
        # batch that is not exactly `batch_size` wide does not cause a shape mismatch.
        n_samples = size(x_batch, 2)

        # train Real
        label = ones(1, n_samples)
        d_x = output = forward(discriminator, x_batch)
        dr_loss = criterion(discriminator, output, label)
        grads = backward(discriminator, dr_loss)
        Optimizer.step(dis_optimizer, grads)

        # train fake
        noise = randn(noise_size, n_samples)
        fake = forward(generator, noise)

        label = zeros(1, n_samples)
        d_g = output = forward(discriminator, fake)
        df_loss = criterion(discriminator, output, label)
        grads = backward(discriminator, df_loss)
        Optimizer.step(dis_optimizer, grads)

        # train Generator
        # The fakes are scored again by the freshly updated discriminator; `_backward`
        # returns the gradient at the discriminator's input, which is then pushed
        # through the generator.
        label = ones(1, n_samples)
        output = forward(discriminator, fake)
        g_loss = criterion(discriminator, output, label)
        y = _backward(discriminator, g_loss)
        grads = backward(generator, y)
        Optimizer.step(gen_optimizer, grads)

        if iter % 100 == 0
            # The combined loss is only needed for the progress report.
            d_loss = dr_loss + df_loss
            fake = forward(generator, fixed_noise)
            save_checkpoints(fake, epoch, iter)
            print("D(x): $(mean(d_x)) ")
            println("D(G(z)): $(mean(d_g))")
            println("epoch: [$(epoch)/$(epochs)] [$(iter)/$(iter_per_epoch)] g_loss: $(g_loss) d_loss: $(d_loss)")
        end
    end
    # End-of-epoch snapshot from the fixed latent vector.
    fake = forward(generator, fixed_noise)
    save_checkpoints(fake, epoch, iter)
end
;

D(x): 0.8824135704692057 D(G(z)): 0.2429632104182938
epoch: [1/100] [100/600] g_loss: 2.533616349343741 d_loss: 0.4636022567352608
D(x): 0.8627827765336902 D(G(z)): 0.24957329361759858
epoch: [1/100] [200/600] g_loss: 3.938064001499252 d_loss: 0.4967634206493586
D(x): 0.43601374470115933 D(G(z)): 0.5990226011741506
epoch: [1/100] [300/600] g_loss: 3.811994491304884 d_loss: 2.211653063718924
D(x): 0.5312517867595674 D(G(z)): 0.5424678820743705
epoch: [1/100] [400/600] g_loss: 1.3161797773439434 d_loss: 1.5085668085977244
D(x): 0.5910517694737645 D(G(z)): 0.4616719026019027
epoch: [1/100] [500/600] g_loss: 1.277946368844637 d_loss: 1.2037066060961714
D(x): 0.6077778744987012 D(G(z)): 0.4664198265155244
epoch: [1/100] [600/600] g_loss: 1.2582683128536536 d_loss: 1.1806307559618658
D(x): 0.647272573547383 D(G(z)): 0.382528134986142
epoch: [2/100] [100/600] g_loss: 1.4844130048726394 d_loss: 0.9943576464330449
D(x): 0.7384513657120836 D(G(z)): 0.3290283388408
epoch: [2/100] [200/600] g_loss

InterruptException: InterruptException:

### Test

In [12]:
# Draw ten fresh latent vectors and save the generated digits as
# ./checkpoints/test_1.png ... ./checkpoints/test_10.png.
for sample_idx in 1:10
    z = randn(noise_size, 1)
    save_checkpoints(forward(generator, z), "test", sample_idx)
end