<a href="https://colab.research.google.com/github/a-mhamdi/jlai/blob/main/Codes/Julia/Part-3/gan.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# GENERATIVE ADVERSARIAL NETWORK
---

In [None]:
versioninfo() # print Julia version and platform details (author's environment reported Julia v1.11.2)

In [None]:
# Dependencies of this notebook (package name = registry UUID). They are written to
# a local Project.toml so the next cell can activate and instantiate the environment.
pkgs = """[deps]
BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
ImageInTerminal = "d8c32880-2388-543b-8c61-d9f865259254"
ImageShow = "4e3cecfd-b093-5904-9786-8bbb286a6a31"
Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
"""

# `write(path, data)` opens `path` for writing (truncating it), writes, and closes it;
# the cell's value is the number of bytes written.
write("Project.toml", pkgs)

In [None]:
# Activate the project in the current directory and install every package
# listed in its Project.toml.
import Pkg
Pkg.activate(".")
Pkg.instantiate();

In [None]:
Pkg.status() # list the packages of the active project together with their installed versions

In [None]:
using Flux
using CUDA

In [None]:
using Images: Gray
using ProgressMeter: @showprogress

**Generator:** noise vector -> synthetic sample

In [None]:
"""
    generator(; latent_dim=16, img_shape=(28,28,1,1))

Build an MLP generator that maps noise vectors of length `latent_dim` to images.

Layers: `latent_dim → 128 → 256 → prod(img_shape)`, with ReLU on the hidden layers
and `tanh` on the output, so generated pixel values lie in [-1, 1].

The output is reshaped so that every dimension of `img_shape` except the last is
fixed and the last is inferred:
- A single noise vector yields an array of size `img_shape`, as before.
- A `latent_dim × B` noise matrix yields `B` samples stacked along the last
  dimension. This assumes the last entry of `img_shape` is 1, as in the default.
"""
function generator(; latent_dim=16, img_shape=(28,28,1,1))
    # `reshape(x, img_shape)` only accepted one sample at a time; letting `:`
    # absorb the trailing dimension makes batched noise work too.
    unflatten(x) = reshape(x, img_shape[1:end-1]..., :)
    return Chain(
        Dense(latent_dim => 128, relu),
        Dense(128 => 256, relu),
        Dense(256 => prod(img_shape), tanh),
        unflatten,
    )
end

In [None]:
gen = gpu(generator())  # move the generator's parameters to the GPU (Flux's `gpu` falls back to CPU if none is usable)

**Discriminator:** sample → unnormalized score (a logit); applying `sigmoid` to it gives the probability that the sample is real.

In [None]:
"""
    discriminator(; img_shape=(28,28,1,1))

Build an MLP discriminator that maps images to a single raw score per sample.

Each input is flattened to a column (`prod(img_shape) → 256 → 128 → 1`, ReLU on
the hidden layers). The last layer has no activation, so the output is a logit;
this file pairs it with `Flux.logitbinarycrossentropy`.
"""
function discriminator(; img_shape=(28,28,1,1))
    # Flatten everything but dimension 4, which is treated as the batch axis.
    # NOTE(review): assumes 4-D (W, H, C, N) input — TODO confirm for non-default img_shape.
    flatten_batch(x) = reshape(x, :, size(x, 4))
    n_in = prod(img_shape)
    layers = (
        flatten_batch,
        Dense(n_in => 256, relu),
        Dense(256 => 128, relu),
        Dense(128 => 1),  # raw logit, no sigmoid
    )
    return Chain(layers...)
end

Loss function: binary cross-entropy evaluated directly on the discriminator's logits

In [None]:
bce_loss(y_true, y_pred) = Flux.logitbinarycrossentropy(y_pred, y_true)

In [None]:
disc = gpu(discriminator())  # move the discriminator's parameters to the GPU (CPU fallback if none is usable)

Training function

In [None]:
"""
    train_gan(gen, disc; n_epochs=16, latent_dim=16, lr_gen=1e-3, lr_disc=1e-4)

Train the generator `gen` and discriminator `disc` in place with alternating Adam
updates. Each "epoch" is one update step on a single sample:

1. The discriminator is trained to label a "real" image 1 and a generated image 0.
2. The generator is trained so that `disc` labels its output 1.

NOTE(review): the "real" images are uniform random noise in [0, 1], a placeholder.
Plug in a real dataset (e.g. MNIST) to get meaningful samples.

# Keywords
- `n_epochs`: number of alternating update steps.
- `latent_dim`: length of the noise vector; must match the generator's input size.
- `lr_gen`, `lr_disc`: Adam learning rates for the generator and the discriminator.

Every array is created on the CPU and moved with `gpu`. The noise, the "real"
images and the labels therefore end up on the same device as models that were
moved with `gpu` — with or without a usable GPU.

Returns `nothing`.
"""
function train_gan(gen, disc; n_epochs=16, latent_dim=16, lr_gen=1e-3, lr_disc=1e-4)
    gen_st = Flux.setup(Adam(lr_gen), gen)
    disc_st = Flux.setup(Adam(lr_disc), disc)

    # Targets must sit on the same device as the discriminator's logits;
    # they never change, so build them once.
    real_label = gpu(ones(Float32, 1, 1))
    fake_label = gpu(zeros(Float32, 1, 1))

    @showprogress for _ in 1:n_epochs
        ## Train the discriminator `disc`
        fake_imgs = gen(gpu(randn(Float32, latent_dim, 1)))  # synthetic sample
        real_imgs = gpu(rand(Float32, size(fake_imgs)...))   # placeholder "real" sample
        grads_d = gradient(disc) do m
            bce_loss(real_label, m(real_imgs)) + bce_loss(fake_label, m(fake_imgs))
        end
        Flux.update!(disc_st, disc, grads_d[1])

        ## Train the generator `gen`: reward outputs that `disc` labels as real
        noise = gpu(randn(Float32, latent_dim, 1))
        grads_g = gradient(gen) do m
            bce_loss(real_label, disc(m(noise)))
        end
        Flux.update!(gen_st, gen, grads_g[1])
    end
    return nothing
end

Train the GAN

In [None]:
train_gan(gen, disc; n_epochs = 16, latent_dim = 16)  # the defaults, spelled out

Generate and plot some images

In [None]:
# Sample noise vectors and decode each one with the trained generator.
# `latent_dim` must match the value the generator was built with (default 16).
latent_dim = 16
n_samples = 16
noise = randn(Float32, latent_dim, n_samples)
# Move each noise vector to the generator's device with `gpu`, then bring the
# image back with `cpu` so it can be plotted.
generated_images = [cpu(gen(gpu(noise[:, i]))) for i in 1:n_samples];

In [None]:
using CairoMakie

In [None]:
#=
plot_images = [ plot(Gray.(generated_images[i])[:,:,1,1]) for i in 1:16 ]
titles = reshape([string(i) for i in 1:16], 1, :);
=#

In [None]:
# Show the generated samples on a 4×4 grid with CairoMakie.
# `cpu` is a no-op for CPU arrays and copies GPU arrays to host memory.
fig = Figure(size = (800, 800))
for (i, img) in enumerate(generated_images)
    row, col = fldmod1(i, 4)
    ax = Axis(fig[row, col]; title = string(i), titlealign = :right, titlesize = 8,
              aspect = DataAspect(), yreversed = true)
    hidedecorations!(ax)
    # The generator's tanh output lies in [-1, 1]; rescale to [0, 1] for a gray colormap.
    # `permutedims` maps array rows to image rows (x = column, y = row, top to bottom).
    pixels = (permutedims(cpu(img)[:, :, 1, 1]) .+ 1) ./ 2
    heatmap!(ax, pixels; colormap = :grays, colorrange = (0, 1))
end
fig