Added Conditional GAN and DCGAN tutorial #111

Open · wants to merge 4 commits into master · Changes from 2 commits
Binary file added vision/mnist/DCGAN/GAN-1.jpg
157 changes: 157 additions & 0 deletions vision/mnist/DCGAN/dcgan.jl
@@ -0,0 +1,157 @@
# Get the imports done
using Flux, Flux.Data.MNIST
using Flux: @epochs, back!, testmode!, throttle
using Base.Iterators: partition
using Distributions: Uniform,Normal
using CUDAnative: tanh, log, exp
using CuArrays
using Images
using Statistics

# Define the hyperparameters
BATCH_SIZE = 128
NUM_EPOCHS = 15
noise_dim = 100
channels = 128
hidden_dim = 7 * 7 * channels
training_steps = 0
verbose_freq = 100
dis_lr = 0.0001f0 # Discriminator Learning Rate
gen_lr = 0.0001f0 # Generator Learning Rate

# Loading Data

# We use Flux's built-in MNIST loader
imgs = MNIST.images()

# Partition into batches of size 'BATCH_SIZE'
data = [reshape(float(hcat(vec.(imgs)...)),28,28,1,:) for imgs in partition(imgs, BATCH_SIZE)]
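# Each batch has shape 28×28×1×BATCH_SIZE, i.e. the WHCN (width, height,
# channel, batch) layout that Flux's convolutional layers expect.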

# Define our distribution for the generator to sample noise from
dist = Normal(0.0, 1.0) # Standard normal noise is found to give better results

expand_dims(x,n::Int) = reshape(x,ones(Int64,n)...,size(x)...)
squeeze(x) = dropdims(x, dims = tuple(findall(size(x) .== 1)...))
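# A quick illustration of these helpers (shapes only, not part of training):
#   size(expand_dims(rand(3, 2), 1)) == (1, 3, 2) # one leading singleton dimension added
#   size(squeeze(ones(1, 3, 1, 2)))  == (3, 2)    # all singleton dimensions dropped
# The networks below wrap BatchNorm in this pair to give it the array rank it
# expects here.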

# The Generator
generator = Chain(
    Dense(noise_dim, 1024, leakyrelu),
    x -> expand_dims(x, 1),
    BatchNorm(1024),
    x -> squeeze(x),
    Dense(1024, hidden_dim, leakyrelu),
    x -> expand_dims(x, 1),
    BatchNorm(hidden_dim),
    x -> squeeze(x),
    x -> reshape(x, 7, 7, channels, :),
    ConvTranspose((4,4), channels=>64, relu; stride = (2,2), pad = (1,1)),
    x -> expand_dims(x, 2),
    BatchNorm(64),
    x -> squeeze(x),
    ConvTranspose((4,4), 64=>1, tanh; stride = (2,2), pad = (1,1))
) |> gpu

# The Discriminator
discriminator = Chain(
    Conv((3,3), 1=>32, leakyrelu; pad = 1),
    x -> meanpool(x, (2,2)),
    Conv((3,3), 32=>64, leakyrelu; pad = 1),
    x -> meanpool(x, (2,2)),
    x -> reshape(x, 7*7*64, :),
    Dense(7*7*64, 1024, leakyrelu),
    x -> expand_dims(x, 1),
    BatchNorm(1024),
    x -> squeeze(x),
    Dense(1024, 1, sigmoid)
) |> gpu

# Define the optimizers

opt_gen = ADAM(params(generator), gen_lr, β1 = 0.5)
opt_disc = ADAM(params(discriminator), dis_lr, β1 = 0.5)

# Utility functions to zero out our model gradients
function nullify_grad!(p)
    if p isa TrackedArray
        p.grad .= 0.0f0
    end
    return p
end

function zero_grad!(model)
Member

You shouldn't need these utilities; just use update! with Params and Grads, like Flux.train! does.

Also, you're still using HTML tags above.

Contributor Author

@MikeInnes I tried update! without manually zeroing out the models' gradients, but the model does not converge to the expected output even after repeated trials, whereas zeroing the gradients manually does work. I don't know if it's a bug on my part. I was using this as a reference:
https://github.com/eriklindernoren/PyTorch-GAN/blob/1f130dfca726e14254e4fd78e5fb63f08931acd3/implementations/cgan/cgan.py#L161-L195

As pointed out on Slack, the gradient call used with update! should automatically zero out the gradients, but the results do not reflect that...

Member

The normal update! method will work if you use gradient rather than back!. back! should be avoided as it's effectively deprecated.
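
A minimal sketch of that pattern (illustrative names; assumes the newer optimizer API where the optimizer is constructed without params, e.g. opt = ADAM(lr)):

ps = params(model)
gs = gradient(() -> loss(x, y), ps)
Flux.Optimise.update!(opt, ps, gs)

This mirrors what Flux.train! does internally.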

Contributor Author

@MikeInnes Sorry for the late reply. I've made the requested changes.
Does it look good?

Contributor Author

update! does zero out the gradient, so there should be no need to do it explicitly, I suppose.


I believe update! only zeros out the gradients involved in that call, not all of the gradients.


See this MWE below:

using Flux, Tracker

d1 = Dense(2, 1)
d2 = Dense(1, 1)
c = Chain(d1, d2)
p1 = params(d1)
p2 = params(d2)
pall = params(c)

x = rand(2, 10)

loss() = sum(c(x))

# Case 1
gradient(loss, pall).grads |> values |> println
# This zeros out all gradients

# Case 2
gradient(loss, p1).grads |> values |> println
# After this call, the gradient of p2 is not zeroed out
# Thus the call for gradient of p2 below will be affected
gradient(loss, p2).grads |> values |> println
# After this call, the gradient of p1 is not zeroed out

# Just zero out all gradients before the next experiment
Tracker.zero_grad!.(Tracker.grad.(p1))
Tracker.zero_grad!.(Tracker.grad.(p2))

# Case 3
gradient(loss, p1).grads |> values |> println
Tracker.zero_grad!.(Tracker.grad.(p2))  # just to avoid the situation in Case 1
gradient(loss, p2).grads |> values |> println

gives

Any[Float32[10.0] (tracked), Float32[1.8203094] (tracked), Float32[1.2446415 0.7490297] (tracked), Float32[-1.8717546] (tracked)]

Any[Float32[1.8203094] (tracked), Float32[1.2446415 0.7490297] (tracked)]
Any[Float32[20.0] (tracked), Float32[-3.7435093] (tracked)]

Any[Float32[1.8203094] (tracked), Float32[1.2446415 0.7490297] (tracked)]
Any[Float32[10.0] (tracked), Float32[-1.8717546] (tracked)]

See how Case 2 is different from Case 1 and Case 3.

Contributor
@matsueushi Dec 4, 2019

When I created a DCGAN model with Tracker backend (Flux v0.9.0), it didn't converge without zeroing out gradients after training the discriminator. https://github.com/matsueushi/fluxjl-gan/blob/flux0.9.0/mnist-dcgan.jl

However, with Zygote backend (Flux v0.10.0),

using Flux

d1 = Dense(2, 1)
d2 = Dense(1, 1)
c = Chain(d1, d2)
p1 = params(d1)
p2 = params(d2)
pall = params(c)

x = rand(2, 10)

loss() = sum(c(x))

@info "Case1"
gradient(loss, pall).grads |> values |> println

@info "Case2"
gradient(loss, p1).grads |> values |> println
gradient(loss, p2).grads |> values |> println

gives the expected results

[ Info: Case1
Any[Float32[9.715003], Float32[3.7775292 4.41461], Float32[10.0], Float32[-5.8413825]]
[ Info: Case2
Any[Float32[9.715003], Float32[3.7775292 4.41461]]
Any[Float32[10.0], Float32[-5.8413825]]

and I didn't have to zero out gradients. https://github.com/matsueushi/fluxjl-gan/blob/e60684b6c8ecc601eb6784ae393eae9a3a3ba57a/mnist-dcgan.jl


This is expected. Only a Tracker-based AD needs the zeroing-out step; Zygote-based AD doesn't have this side effect.

    model = mapleaves(nullify_grad!, model)
end

# Creating and Saving Utilities

img(x) = Gray.(reshape((x .+ 1) ./ 2, 28, 28)) # Denormalize from [-1,1] back to [0,1] grayscale

function sample()
    noise = [rand(dist, noise_dim, 1) for i = 1:9] # Sample noise for 9 digits
    noise = gpu.(noise) # Move to the GPU

    testmode!(generator)
    fake_imgs = img.(map(x -> cpu(generator(x).data), noise)) # Generate images and move them back to the CPU for saving
    testmode!(generator, false)

    img_grid = vcat([hcat(imgs...) for imgs in partition(fake_imgs, 3)]...) # Arrange into a 3×3 grid for saving
end

cd(@__DIR__)


# We use the binary cross-entropy loss
function bce(ŷ, y)
    mean(-y .* log.(ŷ .+ 1f-10) .- (1 .- y) .* log.(1 .- ŷ .+ 1f-10)) # 1f-10 avoids log(0)
end
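
# As a quick sanity check (illustrative, not part of the tutorial):
# bce([0.9f0], [1f0]) ≈ 0.105f0, i.e. -log(0.9) for a confident correct prediction.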

function train(x)
    global training_steps
    println("TRAINING")
    z = rand(dist, noise_dim, BATCH_SIZE) |> gpu
    inp = 2x .- 1 |> gpu # Normalize images to [-1,1]

    zero_grad!(discriminator)

    D_real = discriminator(inp) # D(x)
    real_labels = ones(size(D_real)) |> gpu

    D_real_loss = bce(D_real, real_labels)

    fake_x = generator(z) # G(z)
    D_fake = discriminator(fake_x) # D(G(z))
    fake_labels = zeros(size(D_fake)) |> gpu

    D_fake_loss = bce(D_fake, fake_labels)

    D_loss = D_real_loss + D_fake_loss
    Flux.back!(D_loss)
    opt_disc() # Optimize the discriminator

    zero_grad!(generator)

    fake_x = generator(z) # G(z)
    D_fake = discriminator(fake_x) # D(G(z))
    real_labels = ones(size(D_fake)) |> gpu

    G_loss = bce(D_fake, real_labels)

    Flux.back!(G_loss)
    opt_gen() # Optimize the generator

    if training_steps % verbose_freq == 0
        println("D Loss: $(D_loss.data) | G loss: $(G_loss.data)")
    end

    training_steps += 1
end

for e = 1:NUM_EPOCHS
    for imgs in data
        train(imgs)
    end
    println("Epoch $e over.")
end

save("sample_dcgan.png", sample())
24 changes: 24 additions & 0 deletions vision/mnist/DCGAN/dcgan.md
@@ -0,0 +1,24 @@
# Generative Adversarial Network Tutorial

Generative Adversarial Nets (GANs) are generative models used to learn a complicated probability distribution.

We have two networks competing against each other: the generator and the discriminator.

![GAN](GAN-1.jpg)

The first net generates data from randomly sampled noise, and the second net tries to tell the difference between the real data and the fake data generated by the first net.

The formulation involves the following min-max objective:

![gan_loss](gan.png)
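
For reference, the objective pictured above in standard notation (Goodfellow et al., 2014):

```math
\min_G \max_D \; \mathbb{E}_{x \sim p_{\mathrm{data}}(x)}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z))\big)\big]
```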

At equilibrium, the discriminator will output a probability of 0.5 for each generated image.

## Run the script

```
julia dcgan.jl
```

*This page was generated using [Literate.jl](https://github.com/fredrikekre/Literate.jl).*

Binary file added vision/mnist/DCGAN/frames.pdf
Binary file added vision/mnist/DCGAN/gan.png
Binary file added vision/mnist/cGAN/cGAN.jpg
148 changes: 148 additions & 0 deletions vision/mnist/cGAN/cgan.jl
@@ -0,0 +1,148 @@
# Get the imports done
using Flux, Flux.Data.MNIST
using Flux: @epochs, back!, testmode!, throttle
using Base.Iterators: partition, flatten
using Flux: onehot, onehotbatch
using Distributions: Normal
using Statistics
using Images

# Define the hyperparameters
NUM_EPOCHS = 5000
BATCH_SIZE = 100
NOISE_DIM = 100
gen_lr = 0.0001f0 # Generator learning rate
dis_lr = 0.0001f0 # Discriminator learning rate
training_steps = 0
verbose_freq = 2

# Loading Data
@info("Loading data set")
train_labels = MNIST.labels()[1:100] |> gpu
train_imgs = MNIST.images()[1:100] |> gpu

# Bundle images together with labels and group into minibatches
function make_minibatch(X, Y, idxs)
    X_batch = Array{Float32}(undef, 784, length(idxs))
    for i in 1:length(idxs)
        X_batch[:, i] = Float32.(reshape(X[idxs[i]], 784))
    end
    Y_batch = onehotbatch(Y[idxs], 0:9)
    return vcat(X_batch, Y_batch)
end
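
# Each column of a minibatch is a flattened 28×28 image (784 values) with its
# one-hot label appended, giving a 794×BATCH_SIZE array (hence Dense(794, ...)
# in the discriminator below).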

mb_idxs = partition(1:length(train_imgs), BATCH_SIZE)
train_set = [make_minibatch(train_imgs, train_labels, i) for i in mb_idxs]

# Define our distribution for the generator to sample noise from
dist = Normal(0.0, 1.0) # Standard normal noise is found to give better results

# The Generator
generator = Chain(
    Dense(NOISE_DIM + 10, 1200, leakyrelu),
    Dense(1200, 1000, leakyrelu),
    Dense(1000, 784, tanh)
) |> gpu

# The Discriminator
discriminator = Chain(
    Dense(794, 512, leakyrelu),
    Dense(512, 128, leakyrelu),
    Dense(128, 1, sigmoid)
) |> gpu

# Define the optimizers
opt_gen = ADAM(params(generator), gen_lr, β1 = 0.5)
opt_disc = ADAM(params(discriminator), dis_lr, β1 = 0.5)

# Utility functions to zero out our model gradients
function nullify_grad!(p)
    if p isa TrackedArray
        p.grad .= 0.0f0
    end
    return p
end

function zero_grad!(model)
    model = mapleaves(nullify_grad!, model)
end

# Creating and Saving Utilities

img(x) = Gray.(reshape((x .+ 1) ./ 2, 28, 28, 1)) # Denormalize from [-1,1] back to [0,1] grayscale

function sample()
    num_samples = 9 # Number of digits to sample
    fake_labels = zeros(10, num_samples)
    for i in 1:num_samples
        fake_labels[rand(1:10), i] = 1 # Pick a random digit class; rows 1:10 correspond to digits 0:9
    end

    noise = [vcat(rand(dist, NOISE_DIM, 1), fake_labels[:, i]) for i = 1:num_samples] # Sample 9 digits
    noise = gpu.(noise) # Move to the GPU

    testmode!(generator)
    fake_imgs = img.(map(x -> cpu(generator(x).data), noise)) # Generate images and move them back to the CPU for saving
    testmode!(generator, false)

    img_grid = fake_imgs[1]
end

cd(@__DIR__)

# We use the Binary Cross Entropy Loss
function bce(ŷ, y)
    mean(-y .* log.(ŷ .+ 1f-10) .- (1 .- y) .* log.(1 .- ŷ .+ 1f-10)) # 1f-10 avoids log(0)
end

function train(x)
    global training_steps

    z = rand(dist, NOISE_DIM, BATCH_SIZE) |> gpu
    inp = 2x .- 1 |> gpu # Normalize images to [-1,1]
    inp[end-9:end, :] = x[end-9:end, :] # The labels should not be modified

    labels = Float32.(x[end-9:end, :]) |> gpu # y
    zero_grad!(discriminator)
    zero_grad!(generator)

    D_real = discriminator(inp) # D(x|y)
    real_labels = ones(size(D_real)) |> gpu

    D_real_loss = bce(D_real, real_labels)

    fake_x = generator(vcat(z, labels)) # G(z|y)
    D_fake = discriminator(vcat(fake_x, labels)) # D(G(z|y))
    fake_labels = zeros(size(D_fake)) |> gpu

    D_fake_loss = bce(D_fake, fake_labels)

    D_loss = D_real_loss + D_fake_loss
    Flux.back!(D_loss)
    opt_disc() # Optimize the discriminator

    zero_grad!(discriminator)
    zero_grad!(generator)

    fake_x = generator(vcat(z, labels)) # G(z|y)
    D_fake = discriminator(vcat(fake_x, labels)) # D(G(z|y))
    real_labels = ones(size(D_fake)) |> gpu

    G_loss = bce(D_fake, real_labels)
    Flux.back!(G_loss)
    opt_gen() # Optimize the generator

    if training_steps % verbose_freq == 0
        println("D Loss: $(D_loss.data) | G loss: $(G_loss.data)")
    end

    println(training_steps)
    training_steps += 1
end

for e = 1:NUM_EPOCHS
    for data in train_set
        train(data)
    end
    println("Epoch $e over.")
end

save("sample_cgan.png", sample())
16 changes: 16 additions & 0 deletions vision/mnist/cGAN/cgan.md
@@ -0,0 +1,16 @@
# Conditional Generative Adversarial Network Tutorial

A cGAN is a GAN in which both the generator and the discriminator are conditioned on class labels: the labels are fed to the generator along with the random noise, and to the discriminator along with the image.

It models the data distribution conditioned on the labels.
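
The corresponding min-max objective conditions both terms on the label y (restated here for reference, following Mirza & Osindero, 2014):

```math
\min_G \max_D \; \mathbb{E}_{x \sim p_{\mathrm{data}}(x)}\big[\log D(x \mid y)\big] + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z \mid y))\big)\big]
```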

![cGAN](cGAN.jpg)

## Run the script

```
julia cgan.jl
```

*This page was generated using [Literate.jl](https://github.com/fredrikekre/Literate.jl).*

Binary file added vision/mnist/cGAN/sample_cgan.png