In [1]:
using Knet, Plots;
gr()  # use the GR backend for Plots

# Array type used for every parameter and noise tensor: a GPU array when Knet reports a
# usable device (`gpu() >= 0`), a plain CPU array otherwise.
global atype = if gpu() >= 0
    KnetArray{Float32}
else
    Array{Float32}
end

Array{Float32,N} where N

In [2]:
"""
    julia(x, y, width, height, c)

Return the escape-time value of the iteration `z -> z^2 + c` for the pixel `(x, y)`.

The starting point has real part `(y/width)*2.7 - 1.3` and imaginary part
`(x/height)*4.5 - 2.5`. The result is the first iteration index in `1:254` at which
`abs(z) >= 4`, or `255` when that never happens, always as a `Float32`.
"""
function julia(x, y, width, height, c)
    # Map the pixel coordinates into the window of the complex plane we want to draw.
    re_part = (y / width) * 2.7 - 1.3
    im_part = (x / height) * 4.5 - 2.5
    z = re_part + im_part * im

    iter = 1
    while iter <= 254
        z = z^2 + c
        abs(z) >= 4 && return Float32(iter)
        iter += 1
    end
    return 255.0f0
end


"""
    julia_set(height, width, c)

Build a `height × width` matrix whose entry `[row, col]` is
`julia(row, col, width, height, c)`.
"""
function julia_set(height, width, c)
    return [julia(row, col, width, height, c) for row in 1:height, col in 1:width]
end

"""
    get_training_data(n, size)

Generate `n` Julia-set images of `size × size` pixels and return them as a
`size × size × 1 × n` `Float32` array (one image per slice of the fourth dimension).

Each image uses its own constant `c = 1.2 * exp(r/10 + (n/14)πi)` with `r = rand()`.
The values are stored exactly as `julia_set` produces them; no rescaling happens here.
"""
function get_training_data(n, size)
    # Preallocate the result instead of collecting images in an untyped vector and
    # splatting all of them into `cat`, which is slow for thousands of images.
    data = zeros(Float32, size, size, 1, n)
    for i in 1:n
        # NOTE(review): the phase depends on the total count `n`, not on the loop index
        # `i`, so every image shares the same phase and only the magnitude of `c` varies.
        # Confirm whether `i` was intended.
        c = 1.2 * exp(rand() / 10 + (n / 14) * π * im)
        data[:, :, 1, i] = julia_set(size, size, c)
    end
    return data
end

get_training_data (generic function with 1 method)

In [4]:
"""
    initmodel(hidden, input, output)

Create the parameters of a fully connected network with `input` inputs, the layer widths
listed in `hidden`, and `output` outputs.

Returns a flat list `[W1, b1, W2, b2, ...]`: each weight is `xavier`-initialised with
shape `(layer_out, layer_in)`, each bias is a zero column of height `layer_out`, and
both are converted to `atype`.
"""
function initmodel(hidden, input, output)
    params = Any[]
    fan_in = input
    for fan_out in [hidden... output]
        weight = atype(xavier(fan_out, fan_in))
        bias = atype(zeros(fan_out, 1))
        push!(params, weight, bias)
        fan_in = fan_out  # this layer's output width is the next layer's input width
    end
    return params
end

# Leaky ReLU activation: returns `x` for positive inputs and `α*x` for negative inputs
# (default slope α = 0.2). It is expressed with `max`/`min` instead of a branch and is
# applied element-wise through broadcasting in `forward_prop`.
leakyrelu(x;α=Float32(0.2)) = max(0,x) + α*min(0,x) # LeakyRelu activation
"""
    forward_prop(W, X; dropout_p=0.0)

Generic forward pass of a multi-layer perceptron.

`W` is a flat list `[W1, b1, W2, b2, ...]`. For every weight/bias pair the input is
flattened with `mat`, passed through `dropout` with probability `dropout_p`, and mapped
through the affine layer. Every layer except the last is followed by `leakyrelu`; the
final output goes through `sigm`.
"""
function forward_prop(W, X; dropout_p=0.0)
    last_weight = length(W) - 1
    for k in 1:2:length(W)
        weight, bias = W[k], W[k+1]
        # `mat` (Knet) reshapes X into a 2-D matrix before the matrix product.
        X = weight * dropout(mat(X), dropout_p) .+ bias
        if k < last_weight
            X = leakyrelu.(X)
        end
    end
    return sigm.(X)
end
# Forward passes of the discriminator and the generator, respectively.
# Both networks are plain MLPs, so they share `forward_prop`.
function D(w, x; dropout_p=0.0)  # Discriminator
    return forward_prop(w, x; dropout_p=dropout_p)
end

function G(w, z; dropout_p=0.0)  # Generator
    return forward_prop(w, z; dropout_p=dropout_p)
end


# Small constant added to the argument of every `log` below.
global const 𝜀=Float32(1e-8) #  a small number that prevents NaN in the logs
# Discriminator loss: negated mean of `log(D(x)) + log(1 - D(Gz))`, halved, where `x` is
# a batch of real samples and `Gz` a batch of generated samples.
𝑱d(𝗪d,x,Gz) = -mean(log.(D(𝗪d,x)+𝜀)+log.(1-D(𝗪d,Gz)+𝜀))/2 # Discriminator Loss
# Generator loss: negated mean of `log(D(G(z)))`, i.e. low when the discriminator rates
# generated samples close to 1. 𝗪g comes first so gradients are taken with respect to it.
𝑱g(𝗪g, 𝗪d, z) = -mean(log.(D(𝗪d,G(𝗪g,z))+𝜀))             # Generator Loss

# Gradient functions built with `grad`; each takes the same arguments as its loss.
# NOTE(review): `grad` differentiates with respect to the first argument by default,
# i.e. 𝗪d for ∇d and 𝗪g for ∇g — confirm against the AutoGrad version in use.
∇d  = grad(𝑱d) # Discriminator gradient
∇g  = grad(𝑱g) # Generator gradient



# Sample noise: an `input × batch` matrix of standard-normal Float32 values, converted
# to `atype`.
function 𝒩(input, batch)
    noise = randn(Float32, input, batch)
    return atype(noise)
end
"""
    generate_and_save(𝗪, number, 𝞗, gen; fldr="/mnt/data/other/fractals/8/")

Generate `number` images with the generator weights `𝗪[1]` and save each one as a
heatmap PNG named `<fldr><gen>-<i>.png`.

The generator output is reshaped to `𝞗[:size] × 𝞗[:size] × number` and its first two
dimensions are swapped before plotting. Returns the collected results of the `png` calls.
"""
function generate_and_save(𝗪,number,𝞗,gen;fldr="/mnt/data/other/fractals/8/")
    noise = 𝒩(𝞗[:ginp], number)
    flat = G(𝗪[1], noise)
    side = 𝞗[:size]
    images = permutedims(reshape(flat, (side, side, number)), (2, 1, 3))
    return map(1:number) do i
        png(heatmap(images[:, :, i], color=:ice), "$(fldr)$(gen)-$(i).png")
    end
end

"""
    train_model(𝗪, data, 𝞗, optim)

Train the GAN for `𝞗[:epochs]` generations.

# Arguments
- `𝗪`: pair `(generator_weights, discriminator_weights)`; updated in place.
- `data`: training images, with samples along the fourth dimension.
- `𝞗`: settings; reads `:batchsize`, `:epochs` and `:ginp`.
- `optim`: pair `(generator_optimizers, discriminator_optimizers)`.

Each step updates the discriminator on one real batch and one generated batch of the
same size, then updates the generator on a fresh noise batch of twice that size. After
every generation `log_model` is called. Returns `nothing`.
"""
function train_model(𝗪, data, 𝞗, optim)
    B = 𝞗[:batchsize]
    # Qualified on purpose: the driver script binds a global named `size`, which would
    # shadow the unqualified function inside this module.
    nsamples = Base.size(data, 4)
    for generation in 1:𝞗[:epochs]
        # Step by the configured batch size and include the last full batch; a trailing
        # partial batch is skipped so real and generated batches always match in size.
        for n in 1:B:(nsamples - B + 1)
            # Convert to the model's array type so the batch lives where the weights do.
            x = atype(data[:, :, :, n:n+B-1])
            Gz = G(𝗪[1], 𝒩(𝞗[:ginp], B))  # fake images
            update!(𝗪[2], ∇d(𝗪[2], x, Gz), optim[2])
            z = 𝒩(𝞗[:ginp], 2B)  # fresh noise for the generator step
            update!(𝗪[1], ∇g(𝗪[1], 𝗪[2], z), optim[1])
        end
        # Compute the total losses, log them and save some images.
        log_model(𝗪, data, 𝞗, generation)
    end
    return nothing
end

"""
    log_model(𝗪, data, 𝞗, generation)

Print the average generator and discriminator losses over `data` and save sample images
for the given `generation`.

The data is walked in batches of `𝞗[:batchsize]`. Each real batch is paired with a newly
generated fake batch for the discriminator loss and with a fresh noise batch of twice
that size for the generator loss. Both are weighted by the sample count and averaged.
No parameters are updated. Returns `nothing`.
"""
function log_model(𝗪, data, 𝞗, generation)
    println("Running logging function for generation $(generation)")
    println("-----------------------------------------------------")
    gloss = dloss = counter = 0.0
    B = 𝞗[:batchsize]
    # Qualified on purpose: the driver script binds a global named `size`, which would
    # shadow the unqualified function inside this module.
    nsamples = Base.size(data, 4)
    for n in 1:B:(nsamples - B + 1)
        x = atype(data[:, :, :, n:n+B-1])
        # The losses need generated samples and noise, so sample them for every batch.
        Gz = G(𝗪[1], 𝒩(𝞗[:ginp], B))
        z = 𝒩(𝞗[:ginp], 2B)
        dloss += 2B * 𝑱d(𝗪[2], x, Gz)
        gloss += 2B * 𝑱g(𝗪[1], 𝗪[2], z)
        counter += 2B  # total weight, so the division below yields a per-sample average
    end

    # With fewer than `B` samples no batch runs and both averages print as NaN.
    println("Generator average loss: $(gloss/counter)")
    println("Disciminator average loss: $(dloss/counter)")
    generate_and_save(𝗪,6,𝞗,generation)
    println("Saved images for generation: $(generation)")
    println("------------------------------------------------\n")
    return nothing
end

"""
    main(size, data)

Build a generator and a discriminator for `size × size` images and train them on `data`.

The hyper-parameters (batch size, epochs, noise dimension, hidden layer widths,
optimizer type) are fixed in the settings dictionary below. The generator maps
`𝞗[:ginp]` noise values to `size^2` pixels; the discriminator maps `size^2` pixels to
one output. Returns whatever `train_model` returns.
"""
function main(size, data)
    𝞗 = Dict(
        :batchsize => 32,
        :epochs => 80,
        :ginp => 256,
        :genh => [256],
        :disch => [1536],
        :optim => Adam,
        :size => size,
    )
    println("Using hidden layer for discriminator: $(𝞗[:disch]) and hidden layers for generator: $(𝞗[:genh])")

    # Generator first, then discriminator, so random initialisation order is fixed.
    𝗪g = initmodel(𝞗[:genh], 𝞗[:ginp], size^2)
    𝗪d = initmodel(𝞗[:disch], size^2, 1)

    # The two networks use separate optimizer states with different learning rates.
    opt_g = optimizers(𝗪g, 𝞗[:optim]; lr=0.003)
    opt_d = optimizers(𝗪d, 𝞗[:optim]; lr=0.0002)

    return train_model((𝗪g, 𝗪d), data, 𝞗, (opt_g, opt_d))
end


main (generic function with 1 method)

In [None]:
# Driver cell: generate 8000 Julia-set images of 80 × 80 pixels and train the GAN on them.
# NOTE(review): this binds a global named `size`, which shadows `Base.size` for any later
# unqualified use in this session — consider a different name if nothing else reads it.
size = 80
println("Generating X !")
X = get_training_data(8000,size)
println("Done generating X !")
main(size, X)