In [1]:
# Set display width, load packages, import symbols
ENV["COLUMNS"]=72
using Distributions
using Interpolations
using Knet: Knet, dir, accuracy, progress, sgd, load143, save, gc, Param, KnetArray, Data, minibatch, nll, relu, training, dropout,sigm # param, param0, xavier_uniform
using Knet
using Images
using Plots
using LinearAlgebra
using IterTools: ncycle, takenth
using MLDatasets
using Base.Iterators: flatten
import CUDA # functional
using ImageTransformations
using Statistics
using Memento
using NPZ
# using Interpolations
# Array type for all model tensors: GPU-backed `KnetArray{Float32}` when CUDA is
# usable, otherwise a plain CPU `Array{Float32}`.
atype = if CUDA.functional()
    KnetArray{Float32}
else
    Array{Float32}
end

KnetArray{Float32, N} where N

In [2]:
# Shorthand element type.
const F = Float32
# Alias used later to iterate over the model's trainable `Param`s.
params = Knet.params

# Root Memento logger at "info" level; the format matches the file handler added later.
logger = Memento.config!(
    "info";
    fmt = "[{date} | {level} | {name}]: {msg}",
);

include("PlotUtility.jl")
include("ImageUtility.jl")
include("TrainUtility.jl")
include("LayerUtility.jl")
include("LossUtility.jl")

using .PlotUtility
using .ImageUtility
using .TrainUtility
using .LayerUtility
using .LossUtility

In [3]:
########################### CHANGE THIS LINE FOR DATASET PARAMETER ##############################
dataset_name = "fashion"
exp_number = 1
########################### CHANGE THIS LINE FOR RESULT FOLDER NAME #############################
notebook_name = "Implicit_GON_Fashion" * "_" * dataset_name * string(exp_number)

# Result folder layout: Results/<notebook_name>/{Saved_Models,Images,Logger}.
# `mkpath` creates any missing parent directories and is a no-op for existing ones.
for subdir in ("Saved_Models", "Images", "Logger")
    mkpath(joinpath("Results", notebook_name, subdir))
end

# Also write every log message to Results/<notebook_name>/Logger/logger.log.
push!(logger, DefaultHandler(joinpath("Results", notebook_name, "Logger", "logger.log"),DefaultFormatter("[{date} | {level} | {name}]: {msg}")));

# true: load preprocessed images from Data/*.npy; false: fetch them via MLDatasets.
use_saved_data = false
nc = nothing  # number of image channels; set per dataset below

if dataset_name == "mnist"
    nc = 1
    if use_saved_data
        xtrn = npzread("Data/MNIST_Train_Images.npy")
        xtrn = permutedims(xtrn, (3,4,2,1))
        xtst = npzread("Data/MNIST_Test_Images.npy")
        xtst = permutedims(xtst, (3,4,2,1))
    else
        xtrn,_ = MNIST.traindata()
        xtst,_ = MNIST.testdata()
        xtrn = Array{Float64, 3}(xtrn)
        xtst = Array{Float64, 3}(xtst)

        xtrn = resize_MNIST(xtrn, 1)
        xtst = resize_MNIST(xtst, 1)
    end

elseif dataset_name == "fashion"
    nc = 1
    if use_saved_data
        xtrn = npzread("Data/Fashion_MNIST_Train_Images.npy")
        xtrn = permutedims(xtrn, (3,4,2,1))
        xtst = npzread("Data/Fashion_MNIST_Test_Images.npy")
        xtst = permutedims(xtst, (3,4,2,1))
    else
        xtrn,_ = FashionMNIST.traindata()
        xtst,_ = FashionMNIST.testdata()
        xtrn = Array{Float64, 3}(xtrn)
        xtst = Array{Float64, 3}(xtst)

        xtrn = resize_MNIST(xtrn, 1)
        xtst = resize_MNIST(xtst, 1)
    end

elseif dataset_name == "cifar"
    # NOTE(review): the minibatch `xsize` below, `get_mgrid(28)` and the 784-wide
    # all-ones conv weight all assume 28x28 images. CIFAR-10 images are presumably
    # 32x32x3, so this branch is likely not usable end-to-end yet — confirm.
    nc = 3
    xtrn,_= CIFAR10.traindata()
    xtst,_ = CIFAR10.testdata()
    xtrn = Array{Float64, 4}(xtrn)
    xtst = Array{Float64, 4}(xtst)

else
    # Fail early with a clear message instead of an UndefVarError on `xtrn` below.
    throw(ArgumentError("unknown dataset_name = \"$dataset_name\"; expected \"mnist\", \"fashion\" or \"cifar\""))
end

batch_size = 64

# Each batch is reshaped to (784, nc, batch); only the training batches are shuffled.
dtrn = minibatch(xtrn, batch_size; xsize = (28*28, nc,:), xtype = atype, shuffle = true)
dtst = minibatch(xtst, batch_size; xsize = (28*28, nc,:), xtype = atype);

In [4]:
"""
    SIREN_Layer_Weight_Init(i, o; w0 = 30, is_first = false, bias = true, return_param = false)

Sample the parameters of one SIREN (sine-activated) linear layer with `i` inputs
and `o` outputs.

The `(o, i)` weight matrix is drawn from `Uniform(-k, k)`:
- `k = 1/i` for the first layer (`is_first = true`);
- `k = sqrt(6/i)/w0` otherwise.

When `bias` is true, an `(o, 1)` bias is also drawn from
`Uniform(-1/sqrt(i), 1/sqrt(i))`.

Returns `w`, or `(w, b)` when `bias` is true. If `return_param` is true, each
array is wrapped in a Knet `Param`.
"""
function SIREN_Layer_Weight_Init(i, o; w0 = 30, is_first = false, bias::Bool = true,
                                 return_param = false)
    k = is_first ? 1/i : sqrt(6/i)/w0
    w = rand(Uniform(-k, k), o, i)
    if bias
        # Store the sampled bias under its own name so it does not shadow the
        # `bias` flag. Previously `Param(b)` referenced an undefined `b`.
        k_b = sqrt(1/i)
        b = rand(Uniform(-k_b, k_b), o, 1)
        return return_param ? (Param(w), Param(b)) : (w, b)
    end
    return return_param ? Param(w) : w
end

SIREN_Layer_Weight_Init (generic function with 1 method)

In [5]:
gon_shape = [34, 256, 256, 256, 256, 1]

"""
    weights(gon_shape, w0)

Build the SIREN decoder parameters for the layer widths listed in `gon_shape`.

Layer `l` maps `gon_shape[l]` inputs to `gon_shape[l+1]` outputs. Its weight and
bias come from `SIREN_Layer_Weight_Init`; only the first layer uses `is_first = true`.

Returns the flat vector `[W1, b1, W2, b2, ...]` of `Param`s, with values converted
to the global `atype`. The default 6-entry `gon_shape` yields 5 layers (10 entries).
"""
function weights(gon_shape, w0)
    theta = Any[]
    for l in 1:(length(gon_shape) - 1)
        # Draw order (W then b, layer by layer) matches the original unrolled code.
        w, b = SIREN_Layer_Weight_Init(gon_shape[l], gon_shape[l + 1];
                                       is_first = (l == 1), w0 = w0)
        push!(theta, w, b)
    end
    return [Param(convert(atype, a)) for a in theta]
end

weights (generic function with 1 method)

In [6]:
"""
    get_mgrid(sidelen)

Return a `(sidelen^2, 2)` `Matrix{Float64}` listing every point of a
`sidelen × sidelen` grid on `[-1, 1]^2`.

Row `r` is `[x_i, x_j]`, with the first coordinate varying fastest.
"""
function get_mgrid(sidelen)
    axis = collect(range(-1, stop = 1, length = sidelen))
    # The first column cycles through the axis; the second repeats each value
    # `sidelen` times. This equals stacking `[i, j]` over all (i, j) without
    # splatting `sidelen^2` vectors into `hcat`.
    return hcat(repeat(axis, sidelen), repeat(axis, inner = sidelen))
end

"""
    batched_linear(theta, x_in; atype = KnetArray{Float32})

Multiply the `(O, I)` weight matrix `theta` with every `(I, W)` slice of the batch
`x_in` of size `(I, W, B)`. Returns an array of size `(O, W, B)`.

The batch is flattened to `(I, W*B)`, so the whole product is a single matrix
multiplication. The `atype` keyword is accepted for compatibility but is not used.
"""
function batched_linear(theta, x_in; atype = KnetArray{Float32})
    n_in, w, b = size(x_in, 1), size(x_in, 2), size(x_in, 3)
    flat = reshape(x_in, n_in, w * b)
    return reshape(theta * flat, size(theta, 1), w, b)
end

batched_linear (generic function with 1 method)

In [7]:
"""
    model_forw(theta, z, c; w0 = 30)

Run the implicit GON (SIREN) decoder.

- `theta`: flat vector `[W1, b1, W2, b2, ...]` of layer parameters.
- `z`: latent codes of size `(B, 1, L)`, where B is the batch size and L the latent size.
- `c`: pixel coordinates of size `(B, P, 2)`, where P is the number of pixels.

Each latent code is repeated for every pixel and concatenated with the coordinates.
The result goes through sine layers `sin.(w0 .* (W*h .+ b))` and a final linear
layer. Returns an array of size `(P, O, B)`, with `O = size(theta[end-1], 1)`.

Relies on the global `one_conv_weight`, an all-ones `(1, 1, 1, P)` conv filter.
"""
function model_forw(theta, z, c; w0 = 30)
    # Batch size and latent size are read from `z` rather than hard-coded.
    bsz = size(z, 1)
    nlat = size(z, 3)
    # (B, 1, L) -> (L, 1, 1, B): each latent code becomes a width-L, 1-channel "image".
    z_ = permutedims(reshape(copy(z), bsz, 1, 1, nlat), (4, 3, 2, 1))
    # Repeat every latent code once per pixel -> (B, P, L). This is the same as
    # hcat([z for _ = 1:size(c,2)]...), but more efficient when the loss is
    # differentiated twice.
    z_rep = permutedims(conv4(one_conv_weight, z_)[:, 1, :, :], (3, 2, 1))
    # Concatenate coordinates and latents, then put features first: (2 + L, P, B).
    h = permutedims(cat(c, z_rep, dims = 3), (3, 2, 1))

    n_layers = length(theta) ÷ 2
    for l in 1:n_layers
        h = batched_linear(theta[2l - 1], h) .+ theta[2l]
        # Sine activation on every layer except the last, which stays linear.
        if l < n_layers
            h = sin.(w0 .* h)
        end
    end
    return permutedims(h, (2, 1, 3))
end

"""
    loss(theta, z, x)

Mean squared error between the target batch `x` and the decoder output for latent
codes `z`. Uses the global coordinate grid `c`.
"""
function loss(theta, z, x)
    reconstruction = model_forw(theta, z, c)
    return mean((reconstruction - x) .^ 2)
end

"""
    loss_train(theta, x; batch_size = size(x, 3))

Gradient-origin (GON) training loss for a single batch `x`.

1. Start from a zero latent of size `(batch_size, 1, num_latent)`.
2. Take the latent code to be the negative gradient of `loss` at that origin.
3. Return `loss` for reconstructing `x` from that code.

When this function is itself called under `@diff` (as in the training loop), the
outer gradient flows through the inner one, i.e. it is second-order.
"""
function loss_train(theta, x; batch_size = size(x, 3))
    z0 = Param(atype(zeros(batch_size, 1, num_latent)))
    inner = @diff loss(theta, z0, x)
    z = -grad(inner, z0)
    return loss(theta, z, x)
end

"""
    loss_train(theta, d::Data)

Average of the per-batch `loss_train` over every batch in `d`. Each batch is
weighted by its size, `size(x, 3)`.
"""
function loss_train(theta, d::Data)
    weighted_sum = 0
    n_seen = 0
    for batch in d
        nb = size(batch, 3)
        weighted_sum += loss_train(theta, batch; batch_size = nb) * nb
        n_seen += nb
    end
    return weighted_sum / n_seen
end

loss_train (generic function with 2 methods)

In [8]:
"""
    plot_reconstructed_images2(im_ori, im_rec, n_instances = 10, max_instance = 64,
                               plot_size = (900, 300))

Show `n_instances` randomly chosen images side by side: originals in the top row,
reconstructions in the bottom row. Indices are drawn with replacement from
`1:max_instance`.

`im_ori` and `im_rec` are `(H, W, 1, N)` single-channel image arrays. The combined
figure has size `plot_size`. Returns the combined plot.
"""
function plot_reconstructed_images2(im_ori, im_rec, n_instances = 10, max_instance = 64,
                                    plot_size = (900, 300))
    k = rand(1:max_instance, n_instances)
    h, w = size(im_ori, 1), size(im_ori, 2)
    # Concatenate the selected images horizontally in a single pass.
    tile(ims) = reduce(hcat, [reshape(ims[:, :, :, j], (h, w)) for j in k])
    ori_plot_list = tile(im_ori)
    recon_plot_list = tile(im_rec)
    p1 = plot(Matrix{Gray{Float32}}(ori_plot_list), title = "Original Images", size = (20,200),font =  "Courier", xtick = false, ytick = false)
    p2 = plot(Matrix{Gray{Float32}}(recon_plot_list), title = "Reconstructed Images", font = "Courier", xtick = false, ytick = false)
    # Honour the caller's `plot_size` (previously hard-coded to (900, 300)).
    return plot(p1, p2, layout = (2, 1), size = plot_size)
end

plot_reconstructed_images2 (generic function with 4 methods)

In [9]:
# All-ones 1x1 convolution filter with 784 output channels. `model_forw` uses it to
# repeat each latent code once per pixel.
one_conv_weight = atype(ones(1, 1, 1, 784))

batch_size = 64
x = first(dtrn)
mgrid = get_mgrid(28)
# Pixel coordinates replicated over the batch, of size (batch_size, 784, 2).
# The two columns of `mgrid` are swapped before replication, so c[:, :, 1] holds
# mgrid[:, 2] and c[:, :, 2] holds mgrid[:, 1].
c_copy = atype(permutedims(repeat(mgrid[:, [2, 1]], 1, 1, batch_size), (3, 1, 2)))
c = c_copy
num_latent = 32

# SIREN decoder parameters (w0 = 30).
theta = weights(gon_shape, 30);
# Learning rate and number of training epochs.
lr = 2 * 1e-4
n_epochs = 500
# Every parameter gets its own Adam optimizer state.
foreach(params(theta)) do p
    p.opt = Knet.Adam(lr = lr, beta1 = 0.9, beta2 = 0.999)
end

In [10]:
# Initialize Empty Lists for both training and test losses
trn_loss_list = Float64[]
tst_loss_list = Float64[]


# RECORD INITIAL LOSS VALUES
epoch_loss_trn_ = loss_train(theta, dtrn)
epoch_loss_tst_ = loss_train(theta, dtst)

push!(trn_loss_list, epoch_loss_trn_)
push!(tst_loss_list, epoch_loss_tst_)

info(logger, ("Now training of Implicit-GON is starting. We provide the parameters as the following"))
info(logger, "Dataset = $dataset_name")
info(logger,"num_latent = $num_latent")
info(logger, "lr = $lr")
info(logger, "n_epochs = $n_epochs")

info(logger, ("Epoch : 0"))
info(logger, ("Train Loss : $epoch_loss_trn_"))
info(logger, ("Test Loss : $epoch_loss_tst_"))

[32m[2021-12-20 04:46:31 | info | root]: Now training of Implicit-GON is starting. We provide the parameters as the following[39m
[32m[2021-12-20 04:46:31 | info | root]: Dataset = fashion[39m
[32m[2021-12-20 04:46:31 | info | root]: num_latent = 32[39m
[32m[2021-12-20 04:46:31 | info | root]: lr = 0.0002[39m
[32m[2021-12-20 04:46:31 | info | root]: n_epochs = 500[39m
[32m[2021-12-20 04:46:31 | info | root]: Epoch : 0[39m
[32m[2021-12-20 04:46:32 | info | root]: Train Loss : 0.20152164[39m
[32m[2021-12-20 04:46:32 | info | root]: Test Loss : 0.20163457[39m


In [None]:
########################################## CHANGE THE FOLLOWING LINES FOR CHECKPOINT ITERATION NUMBERS ############################
# Define the step number of model save checkpoint
model_save_checkpoint = 1
logger_checkpoint = 1
image_rec_checkpoint = 1

# Fixed test batch used to visualise reconstructions during training.
x_ = first(dtst)
for epoch in progress(1:n_epochs)
    # One optimisation pass over the training set. `loss_train` contains its own
    # inner gradient, so this gradient is taken through it.
    for x in dtrn
        derivative_model = @diff loss_train(theta, x)
        for p in theta
            dp = grad(derivative_model, p)
            update!(value(p), dp, p.opt)
        end
    end

    epoch_loss_trn = loss_train(theta, dtrn)
    epoch_loss_tst = loss_train(theta, dtst)
    push!(trn_loss_list, epoch_loss_trn)
    push!(tst_loss_list, epoch_loss_tst)

    # Print losses to the logger file
    if epoch % logger_checkpoint == 0
        info(logger,"Epoch : $epoch")
        info(logger,"Train Loss : $epoch_loss_trn")
        info(logger,"Test Loss : $epoch_loss_tst")
    end

    if ((epoch - 1) % image_rec_checkpoint == 0) || (epoch == n_epochs)
        # GON latent for the fixed test batch: negative gradient at the zero latent.
        n_vis = size(x_, 3)
        z0 = Param(atype(zeros(n_vis, 1, num_latent)))
        derivative_origin = @diff loss(theta, z0, x_)
        z = -grad(derivative_origin, z0)
        x_hat = model_forw(theta, z, c)
        # Move to the CPU and reshape to (28, 28, 1, batch) images for plotting.
        x_hat_ = reshape(Array{Float32}(x_hat), 28, 28, 1, :)
        x__ = reshape(Array{Float32}(x_), 28, 28, 1, :)

        plot_reconstructed_images2(x__, x_hat_, 10, n_vis, (900,300))
        fig_name = "Reconstructed_Imgs_ID" * string(1000 + epoch)
        savefig(joinpath("Results", notebook_name, "Images", fig_name))
    end

    # Save model at some steps
    if (epoch % model_save_checkpoint == 0) || (epoch == n_epochs)
        model_id = 1000 + epoch
        model_name = joinpath("Results", notebook_name, "Saved_Models","Model_Base$model_id.jld2")
        w = Dict(:decoder => theta)
        Knet.save(model_name,"model",w)
        ### TO LOAD THE MODEL WEIGHTS, USE THE FOLLOWING
        # w = Knet.load(model_name,"model",) # Ex: model_name = "Results/Conv_AutoEncoder_Baseline_MNIST/Saved_Models/Model_Base1500.jld2"
        # theta = w[:decoder]
    end
    # Persist the loss histories every epoch so they survive an interrupted run.
    Knet.save(joinpath("Results", notebook_name,"trn_loss_list.jld2"),"trn_loss_list",trn_loss_list)
    Knet.save(joinpath("Results", notebook_name,"tst_loss_list.jld2"),"tst_loss_list",tst_loss_list)
end

# Entry 1 holds the untrained losses and is skipped in the plot.
plot_loss_convergence(trn_loss_list[2:end], tst_loss_list[2:end]; title = "Train & Test Loss w.r.t. Epochs")
fig_name = "Train_and_test_loss"
savefig(joinpath("Results", notebook_name, fig_name))

# This notebook never records separate reconstruction-loss lists; the loss above is
# already the reconstruction MSE. Previously these names were referenced
# unconditionally and raised UndefVarError after training, so plot only if defined.
if @isdefined(trn_rec_loss_list) && @isdefined(tst_rec_loss_list)
    plot_loss_convergence(trn_rec_loss_list[2:end], tst_rec_loss_list[2:end]; title = "Train & Test Reconstruction Loss w.r.t. Epochs")
    fig_name = "Train_and_test_reconstruction_loss"
    savefig(joinpath("Results", notebook_name, fig_name))
end


┣                    ┫ [0.20%, 1/500, 00:00/00:01, 572.95i/s] 

0.06555851
0.0650994
0.06393664
0.04932634
0.042390026
0.03834834
0.036927365
0.0339294
0.036424104
[32m[2021-12-20 04:50:23 | info | root]: Epoch : 1[39m
[32m[2021-12-20 04:50:23 | info | root]: Train Loss : 0.034436733[39m
[32m[2021-12-20 04:50:23 | info | root]: Test Loss : 0.03444634[39m


┣                    ┫ [0.40%, 2/500, 04:13/17:34:38, 253.11s/i] 

0.028655665
0.028749403
0.028376698
0.033038314
0.026327541
0.028540615
0.028248327
