Skip to content

Commit

Permalink
Added DCGAN and cGAN files.
Browse files Browse the repository at this point in the history
1. The original commits were messed up. This commit overwrites all previous commits.
2.The markdown files and the instructions to run the code is given.
  • Loading branch information
shreyas-kowshik committed Mar 26, 2019
1 parent 7d820e8 commit 81f6567
Show file tree
Hide file tree
Showing 9 changed files with 345 additions and 0 deletions.
Binary file added vision/mnist/DCGAN/GAN-1.jpg
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
157 changes: 157 additions & 0 deletions vision/mnist/DCGAN/dcgan.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
# Get the imports done
using Flux, Flux.Data.MNIST
using Flux: @epochs, back!, testmode!, throttle
using Base.Iterators: partition
using Distributions: Uniform,Normal
using CUDAnative: tanh, log, exp
using CuArrays
using Images
using Statistics

# Define the hyperparameters
BATCH_SIZE = 128
NUM_EPOCHS = 15
noise_dim = 100
channels = 128
hidden_dim = 7 * 7 * channels
training_steps = 0
verbose_freq = 100
dis_lr = 0.0001f0 # Discriminator Learning Rate
gen_lr = 0.0001f0 # Generator Learning Rate

# Loading Data

# We use Flux's built in MNIST Loader
imgs = MNIST.images()

# Partition into batches of size 'BATCH_SIZE'
data = [reshape(float(hcat(vec.(imgs)...)),28,28,1,:) for imgs in partition(imgs, BATCH_SIZE)]

# Define out distribution for random sampling for the generator to sample noise from
dist = Normal(0.0,1.0) # Standard Normal noise is found to give better results

expand_dims(x,n::Int) = reshape(x,ones(Int64,n)...,size(x)...)
squeeze(x) = dropdims(x, dims = tuple(findall(size(x) .== 1)...))

# The Generator
generator = Chain(
Dense(noise_dim, 1024, leakyrelu),
x->expand_dims(x,1),
BatchNorm(1024),
x->squeeze(x),
Dense(1024, hidden_dim, leakyrelu),
x->expand_dims(x,1),
BatchNorm(hidden_dim),
x->squeeze(x),
x->reshape(x,7,7,channels,:),
ConvTranspose((4,4), channels=>64, relu; stride=(2,2), pad=(1,1)),
x->expand_dims(x,2),
BatchNorm(64),
x->squeeze(x),
ConvTranspose((4,4), 64=>1, tanh; stride=(2,2), pad=(1,1))
) |> gpu

# The Discriminator
discriminator = Chain(
Conv((3,3), 1=>32, leakyrelu;pad = 1),
x->meanpool(x, (2,2)),
Conv((3,3), 32=>64, leakyrelu;pad = 1),
x->meanpool(x, (2,2)),
x->reshape(x,7*7*64,:),
Dense(7*7*64, 1024, leakyrelu),
x->expand_dims(x,1),
BatchNorm(1024),
x->squeeze(x),
Dense(1024, 1,sigmoid)
) |> gpu

# <b>Define the optimizers</b>

opt_gen = ADAM(params(generator),gen_lr, β1 = 0.5)
opt_disc = ADAM(params(discriminator),dis_lr, β1 = 0.5)

# <b>Utility functions to zero out our model gradients</b>
function nullify_grad!(p)
if typeof(p) <: TrackedArray
p.grad .= 0.0f0
end
return p
end

function zero_grad!(model)
model = mapleaves(nullify_grad!, model)
end

# <b>Creating and Saving Utilities</b>

img(x) = Gray.(reshape((x+1)/2, 28, 28)) # For denormalizing the generated image

function sample()
noise = [rand(dist, noise_dim, 1) for i=1:9] # Sample 9 digits
noise = gpu.(noise) # Add to GPU

testmode!(generator)
fake_imgs = img.(map(x -> gpu(generator(x).data), noise)) # Generate a new image from random noise
testmode!(generator, false)

img_grid = vcat([hcat(imgs...) for imgs in partition(fake_imgs, 3)]...) # Create grid for saving
end

cd(@__DIR__)


# We use the <b>Binary Cross Entropy Loss</b>
function bce(ŷ, y)
mean(-y.*log.(ŷ) - (1 .- y .+ 1f-10).*log.(1 .-.+ 1f-10))
end

function train(x)
global training_steps
println("TRAINING")
z = rand(dist, noise_dim, BATCH_SIZE) |> gpu
inp = 2x .- 1 |> gpu # Normalize images to [-1,1]

zero_grad!(discriminator)

D_real = discriminator(inp) # D(x)
real_labels = ones(size(D_real)) |> gpu


D_real_loss = bce(D_real,real_labels)

fake_x = generator(z) # G(z)
D_fake = discriminator(fake_x) # D(G(z))
fake_labels = zeros(size(D_fake)) |> gpu

D_fake_loss = bce(D_fake,fake_labels)

D_loss = D_real_loss + D_fake_loss
Flux.back!(D_loss)
opt_disc() # Optimize the discriminator

zero_grad!(generator)

fake_x = generator(z) # G(z)
D_fake = discriminator(fake_x) # D(G(z))
real_labels = ones(size(D_fake)) |> gpu

G_loss = bce(D_fake,real_labels)

Flux.back!(G_loss)
opt_gen() # Optimise the generator

if training_steps % verbose_freq == 0
println("D Loss: $(D_loss.data) | G loss: $(G_loss.data)")
end

training_steps += 1
end

for e = 1:NUM_EPOCHS
for imgs in data
train(imgs)
end
println("Epoch $e over.")
end

save("sample_dcgan.png", sample())
24 changes: 24 additions & 0 deletions vision/mnist/DCGAN/dcgan.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Generative Adversarial Network Tutorial

Generative Adversarial Nets (GAN), are generative models used to infer a complicated probability distribution.

We have two networks competing against each other - The Generator and the discriminator.

![GAN](GAN-1.jpg)

The first net generates data from randomly sampled noise, and the second net tries to tell the difference between the real data and the fake data generated by the first net.

The formulation per se involves the following min-max objective :

![gan_loss](gan.png)

At equilibrium, the discriminator will output a probability of 0.5 for each generated image.

## Run the script

```
julia dcgan.jl
```

*This page was generated using [Literate.jl](https://github.com/fredrikekre/Literate.jl).*

Binary file added vision/mnist/DCGAN/frames.pdf
Binary file not shown.
Binary file added vision/mnist/DCGAN/gan.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added vision/mnist/cGAN/cGAN.jpg
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
148 changes: 148 additions & 0 deletions vision/mnist/cGAN/cgan.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
# Get the imports done
using Flux, Flux.Data.MNIST,Flux
using Flux: @epochs, back!, testmode!, throttle
using Base.Iterators: partition,flatten
using Flux: onehot,onehotbatch
using Distributions: Normal
using Statistics
using Images

# Define the hyperparameters
NUM_EPOCHS = 5000
BATCH_SIZE = 100
NOISE_DIM = 100
gen_lr = 0.0001f0 # Generator learning rate
dis_lr = 0.0001f0 # discriminator learning rate
training_steps = 0
verbose_freq = 2

# Loading Data
@info("Loading data set")
train_labels = MNIST.labels()[1:100] |> gpu
train_imgs = MNIST.images()[1:100] |> gpu

# Bundle images together with labels and group into minibatches
function make_minibatch(X, Y, idxs)
X_batch = Array{Float32}(undef, 784, length(idxs))
for i in 1:length(idxs)
X_batch[:, i] = Float32.(reshape(X[idxs[i]],784))
end
Y_batch = onehotbatch(Y[idxs], 0:9)
return vcat(X_batch, Y_batch)
end

mb_idxs = partition(1:length(train_imgs), BATCH_SIZE)
train_set = [make_minibatch(train_imgs, train_labels, i) for i in mb_idxs]

# Define out distribution for random sampling for the generator to sample noise from
dist = Normal(0.0,1.0) # Standard Normal noise is found to give better results

# The Generator
generator = Chain(Dense(NOISE_DIM + 10,1200,leakyrelu),
Dense(1200,1000,leakyrelu),
Dense(1000,784,tanh)
) |> gpu

# The Discriminator
discriminator = Chain(Dense(794,512,leakyrelu),
Dense(512,128,leakyrelu),
Dense(128,1,sigmoid)
) |> gpu

# <b>Define the optimizers</b>
opt_gen = ADAM(params(generator),gen_lr, β1 = 0.5)
opt_disc = ADAM(params(discriminator),dis_lr, β1 = 0.5)

# <b>Utility functions to zero out our model gradients</b>
function nullify_grad!(p)
if typeof(p) <: TrackedArray
p.grad .= 0.0f0
end
return p
end

function zero_grad!(model)
model = mapleaves(nullify_grad!, model)
end

# <b>Creating and Saving Utilities</b>

img(x) = Gray.(reshape((x.+1)/2, 28, 28, 1)) # For denormalizing the generated image

function sample()
num_samples = 9 # Number of digits to sample
fake_labels = zeros(10,num_samples)
for i in 1:num_samples
fake_labels[rand(1:9),i] = 1
end

noise = [vcat(rand(dist, NOISE_DIM, 1),fake_labels[:,i]) for i=1:num_samples] # Sample 9 digits
noise = gpu.(noise) # Add to GPU

testmode!(generator)
fake_imgs = img.(map(x -> gpu(generator(x).data), noise)) # Generate a new image from random noise
testmode!(generator, false)

img_grid = fake_imgs[1]
end

cd(@__DIR__)

# We use the <b>Binary Cross Entropy Loss</b>
function bce(ŷ, y)
mean(-y.*log.(ŷ .+ 1f-10) - (1 .- y .+ 1f-10).*log.(1 .-.+ 1f-10))
end

function train(x)
global training_steps

z = rand(dist,NOISE_DIM, BATCH_SIZE) |> gpu
inp = 2x .- 1 |> gpu # Normalize images to [-1,1]
inp[end-9:end,:] = x[end-9:end,:] # The labels should not be modified

labels = Float32.(x[end-9:end,:]) |> gpu # y
zero_grad!(discriminator)
zero_grad!(generator)

D_real = discriminator(inp) # D(x|y)
real_labels = ones(size(D_real)) |> gpu

D_real_loss = bce(D_real,real_labels)

fake_x = generator(vcat(z,labels)) # G(z|y)
D_fake = discriminator(vcat(fake_x,labels)) # D(G(z|y))
fake_labels = zeros(size(D_fake)) |> gpu

D_fake_loss = bce(D_fake,fake_labels)

D_loss = D_real_loss + D_fake_loss
Flux.back!(D_loss)
opt_disc() # Optimize the discriminator

zero_grad!(discriminator)
zero_grad!(generator)

fake_x = generator(vcat(z,labels)) # G(z|y)
D_fake = discriminator(vcat(fake_x,labels)) # D(G(z|y))
real_labels = ones(size(D_fake)) |> gpu

G_loss = bce(D_fake,real_labels)
Flux.back!(G_loss)
opt_gen() # Optimise the generator

if training_steps % verbose_freq == 0
println("D Loss: $(D_loss.data) | G loss: $(G_loss.data)")
end

println(training_steps)
training_steps += 1
end

for e = 1:NUM_EPOCHS
for data in train_set
train(data)
end
println("Epoch $e over.")
end

save("sample_cgan.png", sample())
16 changes: 16 additions & 0 deletions vision/mnist/cGAN/cgan.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Conditional Generative Adversarial Network Tutorial

A cGAN is a GAN wherein both the generator and the discriminator are fed prior labels along with the image and random noise.

It models the conditional probabilities conditioned on the labels.

![cGAN](cGAN.jpg)

## Run the script

```
julia cgan.jl
```

*This page was generated using [Literate.jl](https://github.com/fredrikekre/Literate.jl).*

Binary file added vision/mnist/cGAN/sample_cgan.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit 81f6567

Please sign in to comment.