# DCGAN

[![arXiv](https://img.shields.io/badge/arXiv-1511.06434-b31b1b?logo=arxiv)](https://arxiv.org/abs/1511.06434)
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adamelliotfields/gan/blob/main/dcgan/mnist.ipynb)
[![Render nbviewer](https://img.shields.io/badge/render-nbviewer-f37726)](https://nbviewer.org/github/adamelliotfields/gan/blob/main/dcgan/mnist.ipynb)
[![W&B](https://img.shields.io/badge/Weights_&_Biases-FFCC33?logo=WeightsAndBiases&logoColor=black)](https://wandb.ai/adamelliotfields/gan-mnist)

Based on the official Keras [tutorial](https://keras.io/examples/generative/dcgan_overriding_train_step).

In [None]:
%load_ext tensorboard
%matplotlib inline

In [None]:
# @title Config
# Notebook-wide settings; each "# @param" annotation renders the value as a Colab form field.
# seeds the SeedGenerator that draws the generator's input noise during training
SEED = 42  # @param {type:"integer"}
# verbosity passed to dcgan.fit()
VERBOSE = 1  # @param {type:"integer"}
# training epochs; also used to schedule the single end-of-training weight checkpoint
EPOCHS = 100  # @param {type:"integer"}
# length of the random noise vector fed to the generator
LATENT_DIM = 100  # @param {type:"integer"}
# images per training batch
BATCH_SIZE = 256  # @param {type:"integer"}
# forwarded to DCGAN.compile(jit_compile=...) to request XLA compilation
JIT_COMPILE = True  # @param {type:"boolean"}
# global Keras dtype policy, applied before the model is constructed
DTYPE_POLICY = "mixed_float16"  # @param ["float32", "mixed_float16"] {type:"string"}

# lower learning rate and beta_1 for the Adam optimizers than DCGAN.compile's defaults (1e-3, 0.9)
# beta_1 is Adam's exponential decay rate for the first-moment estimate (m_t in the Adam paper),
# not the estimate itself; a lower value makes the running mean of gradients adapt faster
LEARNING_RATE = 0.0002  # @param {type:"number"}
BETA_1 = 0.5  # @param {type:"number"}

# https://wandb.ai/{ENTITY}/{PROJECT}
WANDB_ENTITY = "adamelliotfields"  # @param {type:"string"}
# W&B project name; also used as the TensorBoard log subdirectory
WANDB_PROJECT = "gan-mnist"  # @param {type:"string"}

In [None]:
# @title Environment
# Configure logging/backend env vars and choose output directories for Colab vs. local runs.
import os
import subprocess
import sys

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
# must be set before keras is first imported
os.environ["KERAS_BACKEND"] = "tensorflow"

# keep only the import inside the try, so an ImportError raised by the Colab setup
# below is not mistaken for "not running in Colab"
try:
    from google.colab import userdata
except ImportError:
    # not running in Colab: write models and logs relative to the working directory
    MODEL_SAVE_DIR = "./"
    TENSORBOARD_LOG_DIR = "./logs"
else:
    try:
        # only log to W&B if there's a key
        os.environ["WANDB_API_KEY"] = userdata.get("WANDB_API_KEY")
        os.environ["WANDB_DISABLE_GIT"] = "true"
    except (userdata.NotebookAccessError, userdata.SecretNotFoundError):
        # secret missing or notebook access not granted: deliberately train without W&B
        pass

    # NOTE(review): Drive is never mounted in this notebook; these paths assume /content/drive
    # was mounted beforehand (e.g. from the Colab file browser) — confirm
    GOOGLE_DRIVE_DIR = "/content/drive/MyDrive"
    MODEL_SAVE_DIR = os.path.join(GOOGLE_DRIVE_DIR, "keras", "models")
    TENSORBOARD_LOG_DIR = os.path.join(GOOGLE_DRIVE_DIR, "tensorboard")
    os.environ["TFDS_DATA_DIR"] = os.path.join(GOOGLE_DRIVE_DIR, "tensorflow_datasets")
    # install into the interpreter running this kernel rather than whichever `pip` is on PATH;
    # best-effort, so a failed upgrade does not abort the notebook
    subprocess.run([sys.executable, "-m", "pip", "install", "-qU", "keras", "wandb"])

In [None]:
# @title Imports
import math
import wandb
import tensorflow as tf
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds

from datetime import datetime
from wandb.integration.keras import WandbMetricsLogger

from keras import (
    Input,
    Model,
    Sequential,
    backend,
    callbacks,
    config,
    layers,
    losses,
    metrics,
    ops,
    optimizers,
    random,
)

In [None]:
# @title Functions
def generate_predictions(generator, n=4):
    """Decode ``n`` fresh standard-normal latent vectors with ``generator``.

    The generator is called with ``training=False`` so layers such as BatchNormalization
    run in inference mode.

    Args:
        generator: model mapping (n, LATENT_DIM) noise to images.
        n: number of samples to draw.

    Returns:
        The generator's output batch.
    """
    latent_batch = random.normal(shape=(n, LATENT_DIM))
    return generator(latent_batch, training=False)


def plot_predictions(predictions, figsize=(2, 2)):
    """Show a grid of generated single-channel images and return the figure.

    Args:
        predictions: batch indexable as ``predictions[i, :, :, 0]``; presumably generator
            tanh outputs in [-1, 1] (the training data is normalized to that range).
        figsize: figure size in inches; also used as the (rows, cols) of the subplot grid.

    Returns:
        The matplotlib figure.

    Raises:
        ValueError: if the batch has more images than the grid has cells.
    """
    rows, cols = figsize
    # plain float32 numpy: mixed_float16 models emit float16 tensors
    images = ops.convert_to_numpy(predictions).astype("float32")
    if images.shape[0] > rows * cols:
        raise ValueError(
            f"{images.shape[0]} images do not fit in a {rows}x{cols} grid; pass a larger figsize"
        )

    fig = plt.figure(figsize=figsize)
    for i, image in enumerate(images):
        plt.subplot(rows, cols, i + 1)
        # map [-1, 1] to [0, 255]; fixed vmin/vmax stop imshow from autoscaling each image
        # to its own min/max, which would make the rescale above meaningless
        plt.imshow(image[:, :, 0] * 127.5 + 127.5, cmap="gray", vmin=0, vmax=255)
        plt.axis("off")
    plt.show()
    return fig

In [None]:
# @title Data
# Load MNIST train + test as one unlabeled, shuffled, batched dataset of images in [-1, 1].
(mnist_train, mnist_test), mnist_info = tfds.load(
    "mnist",
    with_info=True,
    as_supervised=True,
    split=["train", "test"],
)
# total size of both splits, read from the dataset metadata instead of hard-coding it
NUM_EXAMPLES = sum(mnist_info.splits[name].num_examples for name in ("train", "test"))
X_train = (
    mnist_train.concatenate(mnist_test)
    .map(
        # normalize to -1,1 and remove label
        lambda x, _: (ops.cast(x, "float32") - 127.5) / 127.5,
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    # the map is deterministic, so cache its output and run it once instead of every epoch
    .cache()
    # a buffer as large as the dataset gives a full uniform shuffle each epoch
    .shuffle(NUM_EXAMPLES)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

In [None]:
# @title DCGAN
class DCGAN(Model):
    """Generator/discriminator pair trained adversarially by a custom ``train_step``.

    Both networks live in one Keras ``Model`` so ``fit()`` can drive training. Each network
    has its own Adam optimizer, and gradients are computed and applied manually.
    """

    def __init__(self):
        super().__init__()
        # stateful seed source: every train_step draws fresh noise, reproducibly from SEED
        self.seed_generator = random.SeedGenerator(SEED)
        self.G_loss_metric = metrics.Mean(name="g_loss")
        self.D_loss_metric = metrics.Mean(name="d_loss")
        self.loss = losses.BinaryCrossentropy(from_logits=True)

        # we need to train the generator and discriminator together
        # we'll handle computing and applying gradients manually, so we won't need to compile them
        self.generator = Sequential(
            [
                # use BN + LeakyReLU in the generator; dropout (no BN) in the discriminator
                # normally, bias terms are added to the outputs of the layer to offset
                # batch normalization centers the data, which cancels out any offset
                # therefore, we don't need bias terms in layers with BN
                Input(shape=(LATENT_DIM,)),
                layers.Dense(7 * 7 * 256, use_bias=False),
                layers.BatchNormalization(),
                layers.LeakyReLU(),
                layers.Reshape((7, 7, 256)),
                # use transpose (deconvolution) instead of upsample + convolution
                # upsampling uses a fixed technique like bilinear interpolation that isn't learnable
                # so transposed convolution is like a learnable upsampling
                layers.Conv2DTranspose(
                    128,
                    strides=1,
                    kernel_size=5,
                    padding="same",
                    use_bias=False,
                ),
                layers.BatchNormalization(),
                layers.LeakyReLU(),
                layers.Conv2DTranspose(
                    64,
                    strides=2,
                    kernel_size=5,
                    padding="same",
                    use_bias=False,
                ),
                layers.BatchNormalization(),
                layers.LeakyReLU(),
                # tanh activation because we normalized the discriminator's training images to -1,1
                # if we normalized to 0,1 we would use sigmoid here
                # (-1,1 is recommended)
                layers.Conv2DTranspose(
                    1,
                    strides=2,
                    kernel_size=5,
                    padding="same",
                    activation="tanh",
                ),
            ]
        )
        self.discriminator = Sequential(
            [
                Input(shape=(28, 28, 1)),
                layers.Conv2D(64, strides=2, kernel_size=5, padding="same"),
                layers.LeakyReLU(),
                layers.Dropout(0.3),
                layers.Conv2D(128, strides=2, kernel_size=5, padding="same"),
                layers.LeakyReLU(),
                layers.Dropout(0.3),
                layers.Flatten(),
                layers.Dense(1),  # no activation so use from_logits
            ]
        )

    # anything that can change during recompilation should be in compile
    def compile(self, jit_compile="auto", learning_rate=1e-3, beta_1=0.9):
        """Configure XLA compilation and create one Adam optimizer per network.

        Args:
            jit_compile: forwarded to ``Model.compile``.
            learning_rate: Adam learning rate for both networks.
            beta_1: Adam first-moment decay rate for both networks.
        """
        super().compile(jit_compile=jit_compile)
        self.G_optimizer = self._make_optimizer(self.generator, learning_rate, beta_1)
        self.D_optimizer = self._make_optimizer(self.discriminator, learning_rate, beta_1)

    def _make_optimizer(self, network, learning_rate, beta_1):
        """Create and build an Adam optimizer for ``network``, loss-scaled under mixed precision."""
        optimizer = optimizers.Adam(learning_rate=learning_rate, beta_1=beta_1)
        if self.dtype_policy.name == "mixed_float16":
            # Model.compile only auto-wraps its own ``self.optimizer``; these two optimizers are
            # ours, so wrap them explicitly to keep small float16 gradients from underflowing
            optimizer = optimizers.LossScaleOptimizer(optimizer)
        # build now, outside the traced train_step, so scale_loss() reads the live
        # dynamic loss scale instead of baking the initial scale into the graph
        optimizer.build(network.trainable_variables)
        return optimizer

    @staticmethod
    def _scale_loss(optimizer, loss):
        """Return ``loss`` multiplied by the optimizer's loss scale (unchanged if not loss-scaled)."""
        if isinstance(optimizer, optimizers.LossScaleOptimizer):
            return optimizer.scale_loss(loss)
        return loss

    @property
    def metrics(self):
        # listed here so Keras resets them at the start of each epoch
        return [self.G_loss_metric, self.D_loss_metric]

    def G_loss(self, fake_output):
        """Generator loss: push the discriminator to label generated images as real (1)."""
        return self.loss(ops.ones_like(fake_output), fake_output)

    def D_loss(self, real_output, fake_output):
        """Discriminator loss: label real images 1 and generated images 0."""
        real_loss = self.loss(ops.ones_like(real_output), real_output)
        fake_loss = self.loss(ops.zeros_like(fake_output), fake_output)
        return real_loss + fake_loss

    def train_step(self, images):
        """Run one adversarial update of both networks on a batch of real images.

        Args:
            images: batch of real images, expected in [-1, 1] to match the generator's tanh output.

        Returns:
            dict with the running means of the generator (``g_loss``) and discriminator
            (``d_loss``) losses.
        """
        # match the real batch: the last batch of an epoch can be smaller than BATCH_SIZE
        batch_size = ops.shape(images)[0]
        noise = random.normal((batch_size, LATENT_DIM), seed=self.seed_generator)

        # both tapes record one shared forward pass; each is differentiated w.r.t. one network
        with tf.GradientTape() as G_tape, tf.GradientTape() as D_tape:
            generated_images = self.generator(noise, training=True)
            real_output = self.discriminator(images, training=True)
            fake_output = self.discriminator(generated_images, training=True)
            G_loss_value = self.G_loss(fake_output)
            D_loss_value = self.D_loss(real_output, fake_output)
            # scaling must be recorded on the tapes; it is a no-op without loss scaling
            G_scaled_loss = self._scale_loss(self.G_optimizer, G_loss_value)
            D_scaled_loss = self._scale_loss(self.D_optimizer, D_loss_value)

        G_grads = G_tape.gradient(G_scaled_loss, self.generator.trainable_variables)
        D_grads = D_tape.gradient(D_scaled_loss, self.discriminator.trainable_variables)

        # a LossScaleOptimizer divides the scale back out before applying the update
        self.G_optimizer.apply_gradients(zip(G_grads, self.generator.trainable_variables))
        self.D_optimizer.apply_gradients(zip(D_grads, self.discriminator.trainable_variables))

        # track the unscaled losses
        self.G_loss_metric.update_state(G_loss_value)
        self.D_loss_metric.update_state(D_loss_value)
        return {"g_loss": self.G_loss_metric.result(), "d_loss": self.D_loss_metric.result()}

In [None]:
# @title Compile
# Start from a fresh Keras session, apply the dtype policy, then build and compile the GAN.
backend.clear_session()
# the global policy must be in place before DCGAN() creates its layers
config.set_dtype_policy(DTYPE_POLICY)

dcgan = DCGAN()
compile_options = {
    "jit_compile": JIT_COMPILE,
    "learning_rate": LEARNING_RATE,
    "beta_1": BETA_1,
}
dcgan.compile(**compile_options)

In [None]:
# @title Train
# Train with TensorBoard + end-of-training checkpointing, logging to W&B when a key is present.
run_timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
tensorboard_cb = callbacks.TensorBoard(
    log_dir=os.path.join(TENSORBOARD_LOG_DIR, WANDB_PROJECT, run_timestamp)
)

# assumes 70000 examples per epoch (MNIST train + test); the batch counter first reaches
# this value on the final batch, so the weights are written once, after training
total_train_steps = math.ceil(70000 / BATCH_SIZE) * EPOCHS
checkpoint_cb = callbacks.ModelCheckpoint(
    os.path.join(MODEL_SAVE_DIR, "dcgan-mnist.weights.h5"),
    save_weights_only=True,
    save_freq=total_train_steps,
)


def fit_and_sample(extra_callbacks=()):
    """Fit ``dcgan`` with the shared callbacks plus ``extra_callbacks``, then plot 25 samples.

    Returns:
        (predictions, fig): the generator samples and the figure they were plotted on.
    """
    dcgan.fit(
        X_train,
        epochs=EPOCHS,
        verbose=VERBOSE,
        callbacks=[tensorboard_cb, checkpoint_cb, *extra_callbacks],
    )
    samples = generate_predictions(dcgan.generator, n=25)
    return samples, plot_predictions(samples, figsize=(5, 5))


if os.environ.get("WANDB_API_KEY"):
    wandb_run_config = {
        "model": "DCGAN",
        "epochs": EPOCHS,
        "beta_1": BETA_1,
        "batch_size": BATCH_SIZE,
        "latent_dim": LATENT_DIM,
        "dtype_policy": DTYPE_POLICY,
        "learning_rate": LEARNING_RATE,
    }
    with wandb.init(
        group="dcgan",
        job_type="train",
        entity=WANDB_ENTITY,
        project=WANDB_PROJECT,
        sync_tensorboard=False,
        config=wandb_run_config,
    ) as run:
        # WandbMetricsLogger requires an active run, so it is only created after wandb.init()
        predictions, fig = fit_and_sample([WandbMetricsLogger()])
        fig.savefig("generations.png")
        run.log({"Generations (DCGAN)": wandb.Image("generations.png")})
else:
    predictions, fig = fit_and_sample()