## Introduction

<a href="https://colab.research.google.com/github/ntt123/pax/blob/main/examples/notebooks/DCGAN.ipynb" target="_top"><img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg" style="vertical-align:text-bottom"></a>

An example of training a DCGAN model on the CelebA dataset.

This notebook follows PyTorch's DCGAN tutorial:
https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html

In [None]:
# uncomment to install PAX
# !pip3 install -q git+https://github.com/ntt123/pax#egg=pax3[test]
# !pip3 install -q tensorflow-datasets==4.0.1

In [None]:
import math
from functools import partial
from typing import NamedTuple

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import opax
import pax
import tensorflow as tf
import tensorflow_datasets as tfds
from tqdm.auto import tqdm
import os


class TrainRecord(NamedTuple):
    """Statistics recorded for a single training step."""
    # Discriminator loss (real-sample loss + fake-sample loss).
    errD: jnp.ndarray
    # Generator loss.
    errG: jnp.ndarray
    # Mean discriminator output on generated samples during the discriminator update.
    D_fake: jnp.ndarray
    # Mean discriminator output on real samples during the discriminator update.
    D_real: jnp.ndarray
    # Mean discriminator output on generated samples after the discriminator update.
    D_G: jnp.ndarray

In [None]:
# Presumably seeds PAX's global RNG key (e.g. for parameter initialization) -- TODO confirm.
pax.seed_rng_key(42)
batch_size = 128  # number of images per mini-batch
image_size = 64  # size of input image: 64x64
nz = 100  # size of z latent vector
nc = 3  # number of output channels
ngf = 64  # size of feature maps in generator
ndf = 64  # size of feature maps in discriminator
lr = 2e-4  # adam learning rate
beta1 = 0.5  # adam beta1
# CelebA has 202,599 images, which is about 1582 mini-batches per epoch.
# Therefore, 20,000 steps correspond to roughly 12.6 epochs.
num_training_steps = 20_000
log_freq = 500  # logging frequency: [steps] per [log record]
# Images are normalized as (pixel - img_mean) / img_scale.
img_mean = 0.5
img_scale = 0.51  # slightly above 0.5 so normalized pixels stay strictly inside (-1, 1)

## Celeb-A dataset

In [None]:
### load celeb_a dataset

# This is a hack to use custom links to the celeb-a dataset in tensorflow-datasets:
# the ``tfds.image.CelebA._split_generators`` method is replaced by a version
# that downloads from the links below.

# Common prefix of all Google Drive direct-download links used here.
_GDRIVE_DOWNLOAD_URL = "https://drive.google.com/uc?export=download&id="

# Archive with the aligned face images.
IMG_ALIGNED_DATA = _GDRIVE_DOWNLOAD_URL + "1iQRFaGXRiPBd-flIm0u-u8Jy6CfJ_q6j"

# Evaluation partition list.
EVAL_LIST = _GDRIVE_DOWNLOAD_URL + "1ab9MDLOblszbKKXoDe8jumFsSkn6lIX1"

# Landmark coordinates: left_eye, right_eye etc.
LANDMARKS_DATA = _GDRIVE_DOWNLOAD_URL + "1y8qfK-jaq1QWl9v_n_mBNIMu5-h3UXK4"

# Attributes in the image (Eyeglasses, Mustache etc).
ATTR_DATA = _GDRIVE_DOWNLOAD_URL + "1BPfcVuIqrAsJAgG40-XGWU7g2wmmQU30"


def _split_generators(self, dl_manager):
    """Replacement for ``tfds.image.CelebA._split_generators`` using custom links.

    Downloads the four CelebA resources through ``dl_manager``, reads every
    image of the aligned-image archive into memory, and returns one
    ``SplitGenerator`` per split (train, validation, test).

    Args:
        dl_manager: the download manager; its ``download`` and
            ``iter_archive`` methods are used.

    Returns:
        A list of three ``tfds.core.SplitGenerator`` objects whose
        ``gen_kwargs`` carry ``file_id`` (0, 1, 2), the downloaded paths and
        the in-memory images.
    """
    resources = {
        "img_align_celeba": IMG_ALIGNED_DATA,
        "list_eval_partition": EVAL_LIST,
        "list_attr_celeba": ATTR_DATA,
        "landmarks_celeba": LANDMARKS_DATA,
    }
    downloaded_dirs = dl_manager.download(resources)

    # Load all images in memory (~1 GiB), keyed by bare file name:
    # `img_align_celeba/000005.jpg` -> `000005.jpg`
    all_images = {}
    for path, img in dl_manager.iter_archive(downloaded_dirs["img_align_celeba"]):
        all_images[os.path.basename(path)] = img

    # The position of a split in this tuple is its `file_id`.
    splits = (tfds.Split.TRAIN, tfds.Split.VALIDATION, tfds.Split.TEST)
    return [
        tfds.core.SplitGenerator(
            name=split,
            gen_kwargs={
                "file_id": file_id,
                "downloaded_dirs": downloaded_dirs,
                "downloaded_images": all_images,
            },
        )
        for file_id, split in enumerate(splits)
    ]


# Monkey-patch the CelebA builder so that it downloads from the custom links
# used by the `_split_generators` function defined in this cell.
tfds.image.CelebA._split_generators = _split_generators

# `ds` is indexed by split name below ("train", "validation", "test").
ds = tfds.load("celeb_a")


def img_ops(x):
    """Turn one dataset example into a normalized square training image.

    Args:
        x: a dataset example; only ``x["image"]`` is read.

    Returns:
        A float32 image of spatial size ``image_size`` x ``image_size``,
        normalized as ``(pixel / 255 - img_mean) / img_scale``.
    """
    img = tf.cast(x["image"], tf.float32) / 255.0
    # Shrink while preserving the aspect ratio so the image fits inside a
    # (2 * image_size, image_size) box.
    img = tf.image.resize(img, (image_size * 2, image_size), preserve_aspect_ratio=True)
    # Vertically centered square crop. The crop size follows `image_size`
    # instead of a hard-coded 64, and the offset is derived from the resized
    # height instead of a hard-coded 7 (for image_size = 64 and 218x178 inputs
    # the offset is still 7 -- input size is an assumption, TODO confirm).
    top = (tf.shape(img)[0] - image_size) // 2
    img = tf.image.crop_to_bounding_box(img, top, 0, image_size, image_size)
    img = (img - img_mean) / img_scale
    return img


# Use all three splits for training: concatenate them in the order
# train -> validation -> test, then preprocess every example with `img_ops`.
dataset = (
    ds["train"]
    .concatenate(ds["validation"])
    .concatenate(ds["test"])
    .map(img_ops)
)

In [None]:
def make_image_grid(images, padding=2):
    """Place images in a square grid.

    Args:
        images: array whose first axis indexes the images; the number of
            images must be a perfect square.
        padding: number of zero-valued pixels around and between the tiles.

    Returns:
        A numpy array holding all tiles. Tiles are laid out column by column:
        consecutive images go downwards first, then to the next column.
    """
    count = images.shape[0]
    side = int(math.sqrt(count))
    assert side * side == count, "expecting a square grid"
    first = images[0]
    tile_h, tile_w = first.shape[0], first.shape[1]

    canvas_shape = (
        tile_h * side + padding * (side + 1),
        tile_w * side + padding * (side + 1),
        first.shape[-1],
    )
    canvas = np.zeros(canvas_shape, dtype=first.dtype)
    for idx, tile in enumerate(images):
        # idx // side selects the column block, idx % side the row block.
        col_block, row_block = divmod(idx, side)
        top = row_block * (tile_h + padding) + padding
        left = col_block * (tile_w + padding) + padding
        canvas[top : top + tile_h, left : left + tile_w, :] = tile
    return canvas


def show_training_images():
    """Display the first 64 images of ``dataset`` as a grid."""
    batches = dataset.take(64).batch(64).as_numpy_iterator()
    images = next(batches)
    # Undo the (pixel - img_mean) / img_scale normalization before plotting.
    grid = make_image_grid(images * img_scale + img_mean)

    plt.figure(figsize=(8, 8))
    plt.axis("off")
    plt.title("Training Images")
    plt.imshow(grid)
    plt.show()


show_training_images()

### Weight Initialization

From the DCGAN paper, the authors specify that all model weights shall be randomly initialized from a Normal distribution with mean=0, stdev=0.02.


In [None]:
# Shared weight initializer: zero-mean normal with standard deviation 0.02.
w_init = jax.nn.initializers.normal(stddev=0.02)

## Generator

In [None]:
class Generator(pax.Module):
    """DCGAN generator: a stack of transposed convolutions ending in ``tanh``.

    The first transposed convolution uses stride 1 with ``VALID`` padding, the
    remaining four use stride 2 with ``SAME`` padding. Every layer except the
    last one is followed by batch norm and ReLU.
    """

    def __init__(self):
        super().__init__()

        def up_conv(in_ch, out_ch, stride, padding):
            # Bias-free transposed convolution sharing the global `w_init`.
            return pax.nn.Conv2DTranspose(
                in_ch, out_ch, 4, stride, padding=padding, with_bias=False, w_init=w_init
            )

        def batch_norm(channels):
            # Every batch-norm layer uses the same positional configuration.
            return pax.nn.BatchNorm2D(channels, True, True, 0.9)

        # Layers are instantiated in network order.
        layers = [up_conv(nz, ngf * 8, 1, "VALID"), batch_norm(ngf * 8), jax.nn.relu]
        for in_mult, out_mult in ((8, 4), (4, 2), (2, 1)):
            layers += [
                up_conv(ngf * in_mult, ngf * out_mult, 2, "SAME"),
                batch_norm(ngf * out_mult),
                jax.nn.relu,
            ]
        layers += [up_conv(ngf, nc, 2, "SAME"), jnp.tanh]
        self.F = pax.nn.Sequential(*layers)

    def __call__(self, x):
        """Apply the layer stack to ``x``."""
        return self.F(x)

## Discriminator

In [None]:
class Discriminator(pax.Module):
    """DCGAN discriminator: a stack of strided convolutions with leaky ReLU.

    Four stride-2 ``SAME`` convolutions (batch norm after all but the first)
    are followed by a final stride-1 ``VALID`` convolution with a single
    output channel. No activation is applied to the final output.
    """

    def __init__(self):
        super().__init__()
        leaky_relu = partial(jax.nn.leaky_relu, negative_slope=0.2)

        def conv(in_ch, out_ch, stride, padding):
            # Bias-free convolution sharing the global `w_init`.
            return pax.nn.Conv2D(
                in_ch, out_ch, 4, stride, padding=padding, with_bias=False, w_init=w_init
            )

        # Layers are instantiated in network order.
        layers = [conv(nc, ndf, 2, "SAME"), leaky_relu]
        for mult in (1, 2, 4):
            out_ch = ndf * mult * 2
            layers += [
                conv(ndf * mult, out_ch, 2, "SAME"),
                pax.nn.BatchNorm2D(out_ch, True, True, 0.9),
                leaky_relu,
            ]
        layers.append(conv(ndf * 8, 1, 1, "VALID"))
        self.F = pax.nn.Sequential(*layers)

    def __call__(self, x):
        """Apply the layer stack to ``x``."""
        return self.F(x)

## Train model

In [None]:
def bce(x, target):
    """Binary Cross Entropy Loss.

    Args:
        x: logits (values before the sigmoid).
        target: target values, broadcastable against ``x``.

    Returns:
        The element-wise negative log-likelihood.
    """
    log_p = jax.nn.log_sigmoid(x)  # log sigmoid(x)
    log_not_p = jax.nn.log_sigmoid(-x)  # log (1 - sigmoid(x))
    return -(target * log_p + (1.0 - target) * log_not_p)


def ls(x, target):
    """Least Square Loss: half of the element-wise squared error."""
    residual = x - target
    return 0.5 * jnp.square(residual)


# Target values used in the loss for real samples and for generated samples.
real_label_value = 1.0
fake_label_value = 0.0

# We use the Least Squares GAN loss by default; this is
# different from the PyTorch tutorial, which uses the BCE loss.
criterion = ls  # bce | ls


def discriminator_loss_fn(model: Discriminator, inputs):
    """Discriminator loss on one batch of real and one batch of fake samples.

    Args:
        model: the discriminator.
        inputs: ``((real_data, real_label), (fake_data, fake_label))``.

    Returns:
        ``loss, ((loss, D_real, D_fake), model)`` where ``D_real`` / ``D_fake``
        are the mean discriminator outputs on the real / fake batch (passed
        through a sigmoid unless the least-squares criterion is used) and
        ``model`` is the discriminator returned by the forward passes.
    """
    (real_data, _real_label), (fake_data, _fake_label) = inputs

    def mean_score(logit):
        # Raw outputs for the least-squares criterion, probabilities otherwise.
        return jnp.mean(logit if criterion == ls else jax.nn.sigmoid(logit))

    # The real and fake batches go through separate forward passes,
    # real first, threading the returned model through both.
    model, real_logit = pax.module_and_value(model)(real_data)
    model, fake_logit = pax.module_and_value(model)(fake_data)

    real_loss = jnp.mean(criterion(real_logit, _real_label))
    fake_loss = jnp.mean(criterion(fake_logit, _fake_label))
    total_loss = real_loss + fake_loss

    stats = (total_loss, mean_score(real_logit), mean_score(fake_logit))
    return total_loss, (stats, model)


# Update function for the discriminator built from its loss function.
update_d_fn = pax.utils.build_update_fn(discriminator_loss_fn)


__Note__:
1. ``pax.utils.build_update_fn`` requires a standard signature for the loss function:
    
    ``model, inputs => LossFnOutput`` where ``LossFnOutput = loss, (auxiliary_info, model)``
2. The signature of the returned `update_fn` is:
    
    ``model, optimizer, inputs => updated_model, updated_optimizer, auxiliary_info``



In [None]:
def generator_loss_fn(netG, inputs):
    """Generator loss; also performs the discriminator update for this step.

    Args:
        netG: the generator.
        inputs: ``(netD, optD, data, noise)``: discriminator, its optimizer,
            a batch of real data and a batch of latent noise.

    Returns:
        ``loss, ((netD, optD, train_record), netG)`` with the updated
        discriminator and optimizer, the step statistics, and the generator
        returned by the forward pass.
    """
    netD, optD, data, noise = inputs

    label_shape = (noise.shape[0], 1, 1, 1)
    real_labels = jnp.ones(label_shape, dtype=noise.dtype) * real_label_value
    fake_labels = jnp.ones(label_shape, dtype=noise.dtype) * fake_label_value

    netG, fake = pax.module_and_value(netG)(noise)
    d_batches = ((data, real_labels), (fake, fake_labels))

    # Note: the discriminator is updated inside the generator's loss function.
    # Doing it outside would work too, but would require an additional
    # ``netG(noise)`` call.
    # The `stop_gradient` call is important here because we don't want jax
    # to differentiate through `update_d_fn`.
    netD, optD, d_stats = update_d_fn(netD, optD, d_batches)
    errD, D_real, D_fake = d_stats
    netD = jax.lax.stop_gradient(netD)

    # Score the generated samples with the updated discriminator against real labels.
    netD, logit = pax.module_and_value(netD)(fake)
    errG = jnp.mean(criterion(logit, real_labels))
    if criterion == ls:
        g_score = logit
    else:
        g_score = jax.nn.sigmoid(logit)
    D_G = jnp.mean(g_score)

    train_record = TrainRecord(
        errD=errD, errG=errG, D_real=D_real, D_fake=D_fake, D_G=D_G
    )
    aux = (netD, optD, train_record)
    return errG, (aux, netG)


# Update function for the generator built from its loss function.
update_g_fn = pax.utils.build_update_fn(generator_loss_fn)


def update_fn(
    models_and_key,
    data: jnp.ndarray,
):
    """Run one training step on a batch of real data.

    Args:
        models_and_key: ``(netG, netD, optG, optD, rng_key)``.
        data: a rank-4 batch of real images.

    Returns:
        ``((netG, netD, optG, optD, rng_key), train_record)`` with the updated
        networks and optimizers, a fresh RNG key and the step statistics.
    """
    netG, netD, optG, optD, rng_key = models_and_key
    next_key, noise_key = jax.random.split(rng_key)

    # The unpacking requires a rank-4 batch; only the batch size is used.
    batch, _, _, _ = data.shape
    noise = jax.random.normal(noise_key, (batch, 1, 1, nz), dtype=data.dtype)

    g_inputs = (netD, optD, data, noise)
    netG, optG, aux = update_g_fn(netG, optG, g_inputs)
    netD, optD, train_record = aux

    new_state = (netG, netD, optG, optD, next_key)
    return new_state, train_record

In [None]:
## Source: https://github.com/deepmind/dm-haiku/blob/main/examples/imagenet/dataset.py#L163
def double_buffer(ds):
    """Yield the items of ``ds`` after passing each through ``jax.device_put``.

    The generator stays one item ahead: the next item is handed to
    ``jax.device_put`` before the current one is yielded.
    """
    source = iter(ds)
    try:
        first = next(source)
    except StopIteration:
        # Empty input: nothing to yield.
        return
    assert first is not None
    current = jax.device_put(first)

    for upcoming in source:
        assert upcoming is not None
        staged = jax.device_put(upcoming)
        yield current
        current = staged
    yield current


def train():
    """Train the generator and discriminator for ``num_training_steps`` steps.

    Every ``log_freq`` steps the statistics averaged over the last
    ``log_freq`` steps are printed; every ``5 * log_freq`` steps images
    generated from a fixed noise batch are displayed. After training, the
    per-step generator and discriminator losses are plotted.

    Returns:
        ``(netG, netD, fixed_noise)``: the trained generator, the trained
        discriminator and the noise batch used for the sample images.
    """
    netG = Generator()
    print(netG.summary())
    netD = Discriminator()
    print(netD.summary())

    # Split the seed key once so the fixed evaluation noise and the training
    # noise stream never consume the same PRNG key.
    rng_key, fixed_noise_key = jax.random.split(jax.random.PRNGKey(42))
    fixed_noise = jax.random.normal(fixed_noise_key, (64, 1, 1, nz))

    adam = opax.adam(lr, b1=beta1, b2=0.999)
    optD = adam(netD.parameters())
    optG = adam(netG.parameters())

    fast_update_fn = jax.jit(update_fn)

    # Sample images are generated with the model switched to eval mode.
    fast_gen = jax.jit(lambda model, x: model.eval()(x))

    dataloader = (
        dataset.cache()
        .repeat()
        .shuffle(len(dataset) // 2)
        .batch(batch_size)
        .prefetch(tf.data.AUTOTUNE)
        .as_numpy_iterator()
    )

    accum_train_record = TrainRecord(0.0, 0.0, 0.0, 0.0, 0.0)
    tr = tqdm(range(1, 1 + num_training_steps), desc="Training")
    data_iter = double_buffer(dataloader)

    D_losses, G_losses = [], []

    for i in tr:
        data = next(data_iter)
        (netG, netD, optG, optD, rng_key), train_record = fast_update_fn(
            (netG, netD, optG, optD, rng_key), data
        )

        # `jax.tree_util.tree_map` is used because the top-level `jax.tree_map`
        # alias is deprecated and no longer available in recent JAX releases.
        accum_train_record = jax.tree_util.tree_map(
            lambda x, y: x + y, accum_train_record, train_record
        )

        D_losses.append(train_record.errD)
        G_losses.append(train_record.errG)

        if i % log_freq == 0:
            # Average the statistics accumulated since the previous log line.
            avg: TrainRecord = jax.tree_util.tree_map(
                lambda x: x / log_freq, accum_train_record
            )
            print(
                "[Step {:>4}]  errD {:.3f}  errG {:.3f}  D_real {:.3f}  D_fake {:.3f}  D_G {:.3f}".format(
                    i, avg.errD, avg.errG, avg.D_real, avg.D_fake, avg.D_G
                )
            )
            accum_train_record = TrainRecord(0.0, 0.0, 0.0, 0.0, 0.0)

        if i % (5 * log_freq) == 0:
            images = fast_gen(netG, fixed_noise)
            plt.figure(figsize=(8, 8))
            plt.axis("off")
            plt.title("Fake Images")
            # Undo the input normalization and clip to the displayable range.
            images = jax.device_get(
                jnp.clip(images * img_scale + img_mean, a_min=0.0, a_max=1.0)
            )
            plt.imshow(make_image_grid(images))
            plt.show()

    plt.figure(figsize=(10, 5))
    plt.title("Generator and Discriminator Loss During Training")
    plt.plot(G_losses, label="G")
    plt.plot(D_losses, label="D")
    plt.xlabel("iterations")
    plt.ylabel("Loss")
    plt.ylim(0, 3)
    plt.legend()
    plt.show()

    return netG, netD, fixed_noise


netG, netD, fixed_noise = train()