# Wasserstein GAN and conditional GAN (2 in 1)

This lab borrows a lot from the previous notebook (Vanilla GAN). Again, we are going to train a GAN that mimics the images from the MNIST dataset. As before, most of the code is already written for you, and your task is to fill in some of the missing parts.

This time, though, we are going to change the GAN objective to turn our "vanilla" GAN into a **Wasserstein GAN** (WGAN-GP, [arXiv:1704.00028](https://arxiv.org/abs/1704.00028)). Another change is that this time we will train a **conditional GAN**: we will condition our generative model on the labels of the digits. In other words, we'll train our GAN to generate digits of a requested class.

As always, let's start by importing all the required libraries:

In [None]:
import torch
from torchvision.datasets import MNIST
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt
import multiprocessing
from IPython.display import clear_output

In [None]:
# PyTorch Lightning
try:
    import pytorch_lightning as pl
except ModuleNotFoundError: # Google Colab does not have PyTorch Lightning
                            # installed by default.
                            # Hence, we do it here if necessary
    !pip install pytorch-lightning==1.3.4
    import pytorch_lightning as pl

And download the MNIST dataset:

In [None]:
transform = transforms.ToTensor()
ds_train = MNIST(".", train=True, download=True, transform=transform)
ds_test = MNIST(".", train=False, download=True, transform=transform)

Here's a helper function to glue together a bunch of images and plot them in a matrix:

In [None]:
def join_images(images):
    # We expect `images` to be a 2D array of 2D images represented
    # as uint8 numbers (from 0 to 255)
    images = np.array(images)
    assert images.ndim == 4, "This function expects the input " \
                             "to be a 2D matrix of 2D images"

    # Let's pad our images with the white color to separate the
    # nearby images once we glue them together
    joined_image = np.pad(
        images, ((0, 0), (0, 0), (0, 1), (0, 1)), constant_values=255
    )

    # Here we transpose and reshape our 4D array to look at it as if
    # it was a single 2D image with all the sub-images placed in a
    # rectangular grid
    joined_image = np.transpose(joined_image, (0, 2, 1, 3))
    joined_image = joined_image.reshape(
        joined_image.shape[0] * joined_image.shape[1],
        joined_image.shape[2] * joined_image.shape[3],
    )

    # Finally, we pad the top and left sides of our resulting image with
    # the white color
    joined_image = np.pad(
        joined_image, ((1, 0), (1, 0)), constant_values=255
    )

    return joined_image

# A simple function to quickly plot a matrix of images
def plot_images(images):
    plt.imshow(join_images(images), cmap="Greys")
    plt.axis("off")

Let's test our functions:

In [None]:
plot_images(
    ds_train.data.numpy()[:20].reshape(4, 5, 28, 28)
)
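
The transpose-and-reshape step inside `join_images` can be checked in isolation on a toy grid. The sketch below is only an illustration: the grid sizes and fill values are made up, and the padding steps are left out.

```python
import numpy as np

# A toy 2x3 grid of 2x2 "images"; image (r, c) is filled with the value 10*r + c
rows, cols, h, w = 2, 3, 2, 2
grid = np.zeros((rows, cols, h, w), dtype=np.uint8)
for r in range(rows):
    for c in range(cols):
        grid[r, c] = 10 * r + c

# Same trick as in `join_images`, without the padding:
# (rows, cols, h, w) -> (rows, h, cols, w) -> (rows * h, cols * w)
tiled = np.transpose(grid, (0, 2, 1, 3)).reshape(rows * h, cols * w)
print(tiled)
```

Each 2x2 block of `tiled` holds one image of the grid, in the same row and column as in `grid`. With the padding included, a `(4, 5, 28, 28)` input like the one plotted above becomes a single `(117, 146)` image: every 28x28 image grows to 29x29, and one more row and column are added at the top and left.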

OK, we've reached the interesting part! Similarly to the previous notebook, we are going to code all the logic of our GAN training within a single lightning module. Before we do, here's a quick reminder on how WGAN-GP works.

In contrast to the "vanilla" GAN, where the discriminator minimizes the binary cross-entropy loss for classifying its inputs as "real" or "generated", the WGAN discriminator (or *critic*, as it's called in the original paper) solves the optimization problem from the dual form of the Wasserstein distance between the real and generated distributions:
$$W(p_{\text{data}},p_{\text{gen},\theta})=\sup_{\Vert D\Vert_L\leq1}\left[\underset{x\sim p_{\text{data}}}{\mathbb{E}}D(x) - \underset{x\sim p_{\text{gen},\theta}}{\mathbb{E}}D(x)\right],$$
where $\Vert D\Vert_L\leq1$ means that our discriminator is Lipschitz-continuous with a Lipschitz constant of at most 1. This property will be (softly) enforced by the gradient penalty term:
$$\lambda\cdot\underset{\hat x\,\sim\,p_{\hat x}}{\mathbb{E}}\left[\left(\left\Vert\nabla_{\hat x}D(\hat x)\right\Vert-1\right)^2\right],$$
$$\hat x=\alpha\cdot x+(1-\alpha)\cdot y,$$
$$x\sim p_{\text{data}},~~y\sim p_{\text{gen},\theta},~~\alpha\sim\text{Uniform}(0,1).$$
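
To get a feeling for the dual form, here is a small one-dimensional sketch. Everything in it is an assumption made for illustration: the two Gaussians stand in for $p_{\text{data}}$ and $p_{\text{gen},\theta}$, and the critic is picked by hand rather than trained.

```python
import numpy as np

rng = np.random.default_rng(0)
real = rng.normal(loc=0.0, scale=1.0, size=100_000)  # stand-in for p_data
fake = rng.normal(loc=2.0, scale=1.0, size=100_000)  # stand-in for p_gen

def critic(x):
    # A hand-picked 1-Lipschitz function: its slope is -1 everywhere
    return -x

# Sample estimate of E[D(x_real)] - E[D(x_fake)]
dual_estimate = critic(real).mean() - critic(fake).mean()
print(dual_estimate)
```

The estimate comes out close to 2, which is the Wasserstein distance between two distributions that differ only by a shift of 2. Any other 1-Lipschitz critic would give a value that is at most this large (up to sampling noise), which is why the critic is trained to maximize this difference.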

All the expectations in the formulas above are estimated on samples, i.e. the actual loss functions are calculated as follows:
$$L_D=L_D^{\text{main}}+\lambda\cdot\text{GP},$$
$$L_D^{\text{main}}=\frac{1}{\text{batch_size}}\sum_i\left[D(G(z_i))-D(x_i)\right],$$
$$\text{GP}=\frac{1}{\text{batch_size}}\sum_i\left[\left(\left\Vert\nabla_{\hat x_i}D(\hat x_i)\right\Vert-1\right)^2\right],$$
$$\hat x_i\,=\,\alpha_i\ x_i\,+\,(1-\alpha_i) G(z_i),$$
$$L_G=\frac{1}{\text{batch_size}}\sum_i\left[-D(G(z_i))\right].$$
Here, $x_i$, $z_i$ and $\alpha_i$ are samples from the real data, the latent and the Uniform(0, 1) distributions, respectively, and $\Vert\cdot\Vert$ denotes the usual L2 norm of the gradient vector. As always, $L_D$ and $L_G$ are minimized alternately, with respect to the parameters of the discriminator and generator, respectively.

Finally, since we want to train a *conditional* GAN, we need to add the conditional labels to both generator and discriminator networks at each call, i.e.:
$$G(z_i)\to G(z_i,c_i),$$
$$D(x_i)\to D(x_i,c_i),$$
where $c_i$ is the label that corresponds to the object $x_i$ or to the object generated by the latent code $z_i$.
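
As a minimal sketch of what "adding the labels" can look like in practice, the snippet below one-hot encodes a few digit labels and concatenates them with latent vectors. The label values and the latent size of 128 are arbitrary choices for the illustration, and the explicit `.float()` cast is an assumption made here to keep both tensors in the same dtype.

```python
import torch

labels = torch.tensor([3, 0, 7])  # digit classes c_i
one_hot = torch.nn.functional.one_hot(labels, num_classes=10).float()

z = torch.randn(len(labels), 128)  # latent codes z_i

# The conditional input (z_i, c_i) as a single tensor
generator_input = torch.cat([z, one_hot], dim=1)
print(generator_input.shape)  # torch.Size([3, 138])
```

Each row of `one_hot` has a single 1 at the index of its digit, so the network receives the class as 10 extra input features.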

Ok, let's code the model we've just described:

In [None]:
class GAN(pl.LightningModule):
    def __init__(
        self,
        generator, # the generator network (torch.nn.Module)
        discriminator, # the discriminator network (torch.nn.Module)
        latent_size, # the dimensionality of z
        num_disc_steps=1, # number of discriminator updates per single generator update
        initial_lr=0.001, # initial learning rate
        lr_decay_rate=0.95, # a factor, by which the learning rates will be multiplied after each epoch
        gp_factor=10, # gradient penalty constant
        **kwargs
    ):
        super().__init__(**kwargs)

        self.generator = generator
        self.discriminator = discriminator
        self.latent_size = latent_size
        self.num_disc_steps = num_disc_steps
        self.initial_lr = initial_lr
        self.lr_decay_rate = lr_decay_rate
        self.gp_factor = gp_factor

        # We'll use this counter to switch between generator and discriminator steps
        self._gan_step_counter = 0

        # Important: This property activates manual optimization in lightning.
        # Since GAN training steps are quite non-standard, we opt for manual
        # optimization that we'll implement on our own.
        self.automatic_optimization = False

    # A function to sample a batch of latent vectors z
    def generate_z(self, N):
        return torch.randn(N, self.latent_size, device=self.device)

    # A function to sample a batch of generated objects G(z, c)
    def generate(self, one_hot_labels):
        return self.generator(
            self.generate_z(len(one_hot_labels)),
            one_hot_labels,
        )

    # We'll use the function below to calculate both generator's and
    # discriminator's losses
    def _shared_losses_calculation(self, real_img_batch, one_hot_labels):
        """Calculate the loss value on a given batch"""

        batch_size = len(real_img_batch)

        # generate a batch of fakes:
        fake_img_batch = self.generate(one_hot_labels)

        # calculate the discriminator output on real and fake batches:
        d_real = self.discriminator(real_img_batch, one_hot_labels)
        d_fake = self.discriminator(fake_img_batch, one_hot_labels)

        # Here we calculate the losses using the functions:
        # `wgan_discriminator_loss_main`, `gradient_penalty` and
        # `wgan_generator_loss` that you will implement later.
        d_loss = (
            wgan_discriminator_loss_main(d_real, d_fake)
            + gradient_penalty(
                batch_real=real_img_batch,
                batch_fake=fake_img_batch,
                one_hot_labels=one_hot_labels,
                discriminator=self.discriminator,
            ) * self.gp_factor
        )
        g_loss = wgan_generator_loss(d_fake)

        return d_loss, g_loss

    # This function will be iteratively called by lightning, automatically.
    # Here we make our optimization steps.
    def training_step(self, batch, batch_idx):
        # extract the images and their MNIST labels (digit indices),
        # then one-hot encode the labels to use them as the condition:
        batch, labels = batch
        one_hot_labels = torch.nn.functional.one_hot(labels, 10)

        # get the optimizers (see the `configure_optimizers` method below):
        d_opt, g_opt = self.optimizers()

        # calculate both losses:
        d_loss, g_loss = (
            self._shared_losses_calculation(batch, one_hot_labels)
        )

        # Choose which update step to make:
        if self._gan_step_counter < self.num_disc_steps:
            # Making a discriminator step
            self._gan_step_counter += 1
            d_opt.zero_grad()
            self.manual_backward(d_loss) # https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#manual-backward
            d_opt.step()
        else:
            # Making a generator step
            self._gan_step_counter = 0
            g_opt.zero_grad()
            self.manual_backward(g_loss) # https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#manual-backward
            g_opt.step()

        # Logging our losses
        self.log("train_loss_discriminator", d_loss)
        self.log("train_loss_generator", g_loss)

    # This function will be automatically called by lightning
    # at each training epoch end. Inside, we are going to schedule our
    # learning rates (see the `configure_optimizers` method below).
    def training_epoch_end(self, outputs):
        for sch in self.lr_schedulers():
            sch.step()


    # Here, we configure our optimizers for the discriminator and generator,
    # along with learning rate schedulers.
    def configure_optimizers(self):
        d_opt = torch.optim.RMSprop(self.discriminator.parameters(), lr=self.initial_lr)
        g_opt = torch.optim.RMSprop(self.generator.parameters(), lr=self.initial_lr)

        # A learning rate scheduler is an object that changes the learning rate
        # during training. Below, we create ExponentialLR instances to
        # exponentially decay the learning rates (multiply the learning rate by
        # `lr_decay_rate`, 0.95 by default, after each epoch).
        d_scheduler = torch.optim.lr_scheduler.ExponentialLR(d_opt, self.lr_decay_rate)
        g_scheduler = torch.optim.lr_scheduler.ExponentialLR(g_opt, self.lr_decay_rate)

        return [
            {"optimizer" : d_opt, "lr_scheduler" : d_scheduler},
            {"optimizer" : g_opt, "lr_scheduler" : g_scheduler},
        ]
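
The alternation between discriminator and generator updates in `training_step` can be traced with a few lines of plain Python. This is only a simulation of the counter logic: `step_schedule` is a hypothetical helper, not part of the module above. The values 3 and 600 assume `num_disc_steps=3` and 60000 training images in batches of 100, the settings used later in this notebook.

```python
def step_schedule(num_steps, num_disc_steps):
    """Return the sequence of 'D'/'G' updates made over `num_steps` steps."""
    counter = 0
    schedule = []
    for _ in range(num_steps):
        if counter < num_disc_steps:
            counter += 1
            schedule.append("D")  # discriminator update
        else:
            counter = 0
            schedule.append("G")  # generator update
    return schedule

print("".join(step_schedule(8, num_disc_steps=3)))  # DDDGDDDG

epoch = step_schedule(600, num_disc_steps=3)
print(epoch.count("D"), epoch.count("G"))  # 450 150
```

So under these assumptions every fourth batch is used for the generator, and the generator gets 150 updates per epoch.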


# Task (6 points)
Fill in the blanks below to implement the functions:
1. `wgan_discriminator_loss_main` (2 points),
2. `wgan_generator_loss` (2 points),
3. `gradient_penalty` (2 points).

For reference, see the definitions of $L_D^{\text{main}}$, $L_G$ and $\text{GP}$ above.

In [None]:
def wgan_discriminator_loss_main(d_real, d_fake):
    """
    Calculates loss for the discriminator (not including the gradient penalty term).
    Input parameters:
      d_real -- discriminator outputs for a batch of real objects and their
                respective labels
      d_fake -- discriminator outputs for a batch of fake objects and their
                respective labels
    """
    raise NotImplementedError # <= remove this
    # <YOUR CODE>
    # return <YOUR CODE>


def wgan_generator_loss(d_fake):
    """
    Calculates loss for the generator.
    Input parameters:
      d_fake -- discriminator outputs for a batch of fake objects and their
                respective labels
    """
    raise NotImplementedError # <= remove this
    # <YOUR CODE>
    # return <YOUR CODE>


def gradient_penalty(batch_real, batch_fake, one_hot_labels, discriminator):
    """
    Calculates the gradient penalty term.
    Input parameters:
      batch_real -- a batch of real objects
      batch_fake -- a batch of fake objects (should be of same size as
                    `batch_real`)
      one_hot_labels -- a batch of one-hot labels, corresponding to the objects
                        in `batch_real` and `batch_fake`
      discriminator -- the discriminator network (torch.nn.Module)
    """
    # First, let's calculate the x_hat from the GP formula, i.e. the linear
    # interpolates between the real and fake objects.
    
    # We'll sample the interpolation coefficients `alpha` from a uniform
    # distribution. We need to be careful, though, to get a single `alpha` value
    # for each image in the batch (i.e., we want to have the same alpha sample
    # for all the pixels of the same image, but different alpha samples for
    # different images). Since our images are of the shape (B, C, H, W),
    # this can be achieved by sampling `alpha` of the shape (B, 1, 1, 1).
    batch_size = len(batch_real)
    alpha = torch.empty(
        batch_size, 1, 1, 1, dtype=batch_real.dtype, device=batch_real.device
    )
    alpha.uniform_(0, 1)
    interpolates = alpha * batch_real + (1 - alpha) * batch_fake

    # Now, let's calculate the D(x_hat):
    d_output = discriminator(interpolates, one_hot_labels)

    # Now we can calculate the gradients. We'll use the `torch.autograd.grad`
    # function to do that. It's important to pass `create_graph=True` to it,
    # such that we can then backpropagate through the result when optimizing
    # our discriminator. Check out the docs:
    #   https://pytorch.org/docs/stable/generated/torch.autograd.grad.html
    
    # One of the important arguments here is the one called `grad_outputs`.
    # It can typically be ignored when calculating the gradient of a scalar,
    # but now we are calculating the gradient of a vector (`d_output` is a batch
    # of the discriminator outputs). Hence, this argument is mandatory.
    
    # A gradient of a vector with respect to another vector is a matrix
    # (Jacobian). The `torch.autograd.grad` function returns the
    # vector-Jacobian product `matmul(grad_outputs, Jacobian)`, i.e. the sum
    # of the Jacobian rows weighted by the entries of `grad_outputs`.

    # Think, what vector should be passed to the `grad_outputs` argument to
    # get the correct result, and fill the gap below:

    # grad_outputs = <YOUR CODE>
    raise NotImplementedError # <= remove this

    # And now we calculate the gradient. The resulting gradient has the same
    # shape as our images, so we reshape it to (B, C * H * W) to treat it as a
    # vector later.
    grads = torch.autograd.grad(
        [d_output],
        [interpolates],
        create_graph=True,
        grad_outputs=grad_outputs,
    )[0].view(batch_size, -1)

    # OK, now that you have the `grads` vector (dD(x_hat)/dx_hat), use it to
    # calculate the gradient penalty value.
    # HINT: use `torch.linalg.norm` to calculate the norm. Calculating it
    # manually by squaring, summing and then taking the sqrt may result in a NaN
    # for a derivative if the input vector has 0 length (due to the 0 / 0
    # ambiguity).

    # gradient_penalty_value = <YOUR CODE>
    raise NotImplementedError # <= remove this

    return gradient_penalty_value
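
To see what `grad_outputs` does before filling the gap, it may help to try `torch.autograd.grad` on a toy function. In the sketch below, the function `y`, the inputs and the weights `v` are made up for the illustration. Each output `y[b]` depends only on row `x[b]`, just like each discriminator output depends only on its own image.

```python
import torch

x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], requires_grad=True)

# One scalar output per row: y[b] = sum_j x[b, j] ** 2, so dy[b]/dx[b] = 2 * x[b]
y = (x ** 2).sum(dim=1, keepdim=True)  # shape (3, 1)

# Deliberately non-trivial weights, one per output
v = torch.tensor([[1.0], [10.0], [100.0]])

(grad,) = torch.autograd.grad([y], [x], grad_outputs=v, create_graph=True)
print(grad)
```

Row `b` of the result equals `v[b] * 2 * x[b]`: the per-sample gradient scaled by the corresponding entry of `grad_outputs`. Because of `create_graph=True`, `grad` is itself part of the computation graph and can be differentiated again.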

You may use the cell below to check your solution (a correct solution should not trigger any of the assertions).

In [None]:
def test_solution():
    dummy_d_real = torch.tensor(
        [[-1.3130], [ 0.2434], [ 0.9388], [-0.3919], [-1.1801], [-0.9238], [-0.3502], [-0.7253], [ 0.2569], [ 1.4235]]
    )
    dummy_d_fake = torch.tensor(
        [[-0.0455], [-1.0809], [ 0.7711], [-0.4530], [-0.1560], [ 1.3168], [ 0.8033], [-0.0238], [-0.5631], [-0.0213]]
    )
    dummy_batch_real = torch.tensor(
        [[[[ 0.6031, -1.3677, -0.6469,  1.3742], [-0.5418,  1.9298,  0.7391,  1.9177], [-1.4565, -1.2624, -0.5140,  0.9718], [-0.3452, -0.8339,  0.6958,  1.6030]]],
         [[[ 0.2721, -0.1995, -1.0205,  0.5386], [-0.4417, -1.2073, -0.4647, -0.3979], [ 0.0356, -0.6618, -0.4126, -0.0598], [ 1.1912,  2.4061, -0.0862, -0.3650]]],
         [[[ 1.6891, -0.3839,  0.1758, -2.6211], [ 0.8489,  1.3923,  1.7145,  1.2908], [-1.3113,  1.4497, -1.4642,  1.1934], [ 0.5669, -0.9430,  0.7101, -1.6577]]],
         [[[ 0.3040,  0.0942, -0.5878, -0.4074], [ 0.9896,  0.0596,  0.7263, -0.9611], [ 1.3899, -0.7076, -1.7282, -0.4437], [ 0.4054, -0.1272,  0.9978,  0.3708]]],
         [[[ 0.5256,  0.7467,  0.3299, -0.3802], [-0.6723, -0.9191,  0.3278,  1.1507], [-0.0032,  0.7865, -0.2089, -0.7193], [ 0.3707, -2.0075, -1.1139,  0.1478]]]]
    )
    dummy_batch_fake = torch.tensor(
        [[[[-2.1661, -0.2350,  0.8560, -0.2495], [-0.4252,  1.8806, -1.4996, -0.8637], [-0.9359,  1.2126, -0.3936, -0.4491], [ 0.5070,  2.5815, -0.6126, -1.8609]]],
         [[[-1.4590, -0.4350, -0.5206,  0.1118], [-0.6213,  1.0411,  0.1881, -1.2227], [-1.3277,  0.4584,  0.3738, -1.3614], [ 0.9316, -1.9789,  0.0483, -0.6561]]],
         [[[ 1.8826,  1.0135, -1.2973, -1.8827], [-0.8091, -0.7376,  0.7225,  0.7375], [-0.8654,  1.5320,  0.8099,  0.2867], [-0.7880, -0.3300,  0.0338, -1.9983]]],
         [[[-0.3476,  0.2015, -1.0666, -0.5320], [-0.9068,  1.1607,  0.7148,  0.1309], [ 1.5926,  1.7366, -0.0851,  0.8583], [ 1.0615, -1.7650, -0.2727,  0.2254]]],
         [[[ 1.0907, -0.2203,  0.8887,  0.1553], [-0.4685, -0.9549, -2.0325, -0.6273], [ 1.0975,  0.2569,  0.1859,  0.9267], [ 0.9180, -0.6636,  0.4972,  1.5740]]]]
    )
    dummy_batch_real.requires_grad = True
    dummy_labels = torch.nn.functional.one_hot(torch.arange(5), 10)
    def dummy_discriminator(x, labels):
        dummy_multiplier_1 = torch.tensor(
            [[[[-1.0278, -1.0076, -0.9431,  0.8558],
               [ 1.4569,  1.1453,  0.0753,  0.3377],
               [ 0.9183,  1.7088, -0.1621,  0.1573],
               [ 0.6179, -0.9540,  0.2306, -0.2688]]]]
        )
        dummy_multiplier_2 = torch.tensor(
            [[-0.4931,  0.8827, -0.1299,  0.3261, -0.2990, -0.0691, -1.5011,  0.2234, 0.3917,  1.4503]]
        )
        return (
            (x * dummy_multiplier_1).sum(axis=(2, 3))
            + (labels * dummy_multiplier_2).sum(axis=1, keepdims=True)
        )

    d_loss_main = wgan_discriminator_loss_main(dummy_d_real, dummy_d_fake)
    g_loss = wgan_generator_loss(dummy_d_fake)
    gp_value = gradient_penalty(
        dummy_batch_real, dummy_batch_fake, dummy_labels, dummy_discriminator
    )

    d_loss_main_reference = torch.tensor(0.25693002343177795)
    g_loss_reference = torch.tensor(-0.054760001599788666)
    gp_value_reference = torch.tensor(6.413931846618652)

    assert d_loss_main.numel() == 1, "`wgan_discriminator_loss_main` should " \
        "return a single number. Did you forget to average the result?"
    assert torch.isclose(
        d_loss_main, d_loss_main_reference
    ).item(), "Failed test for `wgan_discriminator_loss_main`: expected " \
        f"{d_loss_main_reference.item()}, got {d_loss_main.item()}"
    assert g_loss.numel() == 1, "`wgan_generator_loss` should " \
        "return a single number. Did you forget to average the result?"
    assert torch.isclose(
        g_loss, g_loss_reference
    ).item(), "Failed test for `wgan_generator_loss`: expected " \
        f"{g_loss_reference.item()}, got {g_loss.item()}"
    assert gp_value.numel() == 1, "`gradient_penalty` should " \
        "return a single number. Did you forget to average the result?"
    assert torch.isclose(
        gp_value, gp_value_reference
    ).item(), "Failed test for `gradient_penalty`: expected " \
        f"{gp_value_reference.item()}, got {gp_value.item()}"

    print("All tests passed!")

test_solution()
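
The reference value for the gradient penalty in this test can be cross-checked by hand. The dummy discriminator is linear in `x`, so its gradient with respect to every interpolated image is just `dummy_multiplier_1`, whatever `alpha` is sampled, and the label term does not depend on `x` at all. A minimal NumPy sketch of this check, with the weights copied from the test cell above:

```python
import numpy as np

# Weights of the linear dummy discriminator (copied from `dummy_multiplier_1`)
w = np.array([
    [-1.0278, -1.0076, -0.9431,  0.8558],
    [ 1.4569,  1.1453,  0.0753,  0.3377],
    [ 0.9183,  1.7088, -0.1621,  0.1573],
    [ 0.6179, -0.9540,  0.2306, -0.2688],
])

# The gradient of sum(x * w) with respect to x is w itself, so the
# gradient norm is the same for every sample in the batch
grad_norm = np.linalg.norm(w)
gp = (grad_norm - 1.0) ** 2
print(gp)
```

The printed value agrees with `gp_value_reference` up to rounding, which also explains why the test result does not depend on the random `alpha` samples.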

The code cell below defines the generator and discriminator architectures. Since we work with images, we decide to follow the regular deep convolutional network architecture for the discriminator. For the generator, we interleave convolutions and transposed convolutions with residual connections to upsample the images.

Compared to the previous notebook, we reduced the number of channels to make our networks more lightweight. This is mainly due to the gradient penalty term, which makes everything run considerably slower.

Another modification is the conditional labels. In the generator, we just concatenate them with the latent code $z$. For the discriminator, we keep the convolutional part unconditioned and only concatenate the condition to the flattened convolutional features before the dense layers. Note that we use the one-hot representation of the labels (therefore they are held in 10-dimensional vectors).

In [None]:
# This will be our upsampling block with a residual connection:
# we'll upsample the input in two ways and then add the results together.
# The two upsampling ways are:
#   1) with a transposed convolution with a trainable kernel
#   2) with a fixed nearest-neighbor interpolation upsampling
class UpsampleWithRes(torch.nn.Module):
    def __init__(self, upconv, activation, factor, **kwargs):
        super().__init__(**kwargs)

        self.path_a = torch.nn.Sequential(
            upconv, activation()
        )
        self.factor = factor

    def forward(self, x):
        x_a = self.path_a(x)
        x_b = torch.nn.functional.interpolate(x, scale_factor=self.factor)
        return x_a + x_b

# Define the generator architecture
class ConvGenerator(torch.nn.Module):
    def __init__(self, activation, latent_size, labels_size=10, **kwargs):
        super().__init__(**kwargs)

        self.fc = torch.nn.Sequential(
            torch.nn.Linear(latent_size + labels_size, 256),
            activation(),
        )
        self.backbone = torch.nn.Sequential( # 8x8
            torch.nn.Conv2d(
                in_channels=4, out_channels=64, kernel_size=3, padding=1,
            ), # -> 8x8
            activation(),
            UpsampleWithRes(
                upconv=torch.nn.ConvTranspose2d(
                    in_channels=64, out_channels=64, kernel_size=4, stride=4, padding=0,
                ),
                activation=activation,
                factor=4,
            ), # -> 32x32
            torch.nn.Conv2d(
                in_channels=64, out_channels=32, kernel_size=3, padding=0,
            ), # -> 30x30
            activation(),
            torch.nn.Conv2d(
                in_channels=32, out_channels=16, kernel_size=3, padding=0,
            ), # -> 28x28
            activation(),
            torch.nn.Conv2d(
                in_channels=16, out_channels=1, kernel_size=1, padding=0,
            ),
        )

    def forward(self, x, labels):
        return self.backbone(
            self.fc(
                torch.cat([x, labels], axis=1)
            ).view(-1, 4, 8, 8)
        )


# Define the discriminator architecture
class ConvDiscriminator(torch.nn.Module):
    def __init__(self, activation, labels_size=10, **kwargs):
        super().__init__(**kwargs)

        self.backbone = torch.nn.Sequential( # 28x28
            torch.nn.Conv2d(
                in_channels=1, out_channels=16, kernel_size=3, padding=1,
            ), # -> 28x28
            activation(),
            torch.nn.Conv2d(
                in_channels=16, out_channels=32, kernel_size=4, stride=4, padding=0,
            ), # -> 7x7
            activation(),
            torch.nn.Conv2d(
                in_channels=32, out_channels=32, kernel_size=3, padding=0,
            ), # -> 5x5
            activation(),
            torch.nn.Conv2d(
                in_channels=32, out_channels=32, kernel_size=3, padding=0,
            ), # -> 3x3
            activation(),
        )
        self.fc = torch.nn.Sequential(
            torch.nn.Linear(9 * 32 + labels_size, 32),
            activation(),
            torch.nn.Linear(32, 1),
        )

    def forward(self, x, labels):
        conv_features = self.backbone(x).view(-1, 9 * 32)
        return self.fc(
            torch.cat([conv_features, labels], axis=1)
        )
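
The spatial sizes noted in the comments above follow from the standard output-size formulas for convolutions and transposed convolutions. The helpers below are hypothetical names introduced for this sketch, and they assume `dilation=1` and `output_padding=0`, which are the defaults used by the layers in this notebook.

```python
def conv_out(size, kernel_size, stride=1, padding=0):
    """Output size of a Conv2d along one spatial dimension."""
    return (size + 2 * padding - kernel_size) // stride + 1

def conv_transpose_out(size, kernel_size, stride=1, padding=0):
    """Output size of a ConvTranspose2d along one spatial dimension."""
    return (size - 1) * stride - 2 * padding + kernel_size

# Generator: 8 -> 8 -> 32 -> 30 -> 28 -> 28
g = conv_out(8, kernel_size=3, padding=1)
g = conv_transpose_out(g, kernel_size=4, stride=4)
g = conv_out(g, kernel_size=3)
g = conv_out(g, kernel_size=3)
g = conv_out(g, kernel_size=1)
print(g)  # 28

# Discriminator: 28 -> 28 -> 7 -> 5 -> 3
d = conv_out(28, kernel_size=3, padding=1)
d = conv_out(d, kernel_size=4, stride=4)
d = conv_out(d, kernel_size=3)
d = conv_out(d, kernel_size=3)
print(d)  # 3
```

The final 3x3 feature map with 32 channels is what gives the `9 * 32` input features of the discriminator's dense part, before the 10 label features are appended.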

Here, we'll create a callback to plot some generated images at the end of each training epoch. We'll store the generated images in tensorboard.

In [None]:
class PlotDigitsCallback(pl.Callback):
    def on_train_epoch_end(self, trainer, module):
        tensorboard = module.logger.experiment
        image = join_images(
            (
                module.generate(
                    torch.nn.functional.one_hot(
                        torch.arange(30) % 10, 10
                    ).to(module.device)
                ).detach().cpu().numpy().clip(0, 1).reshape(
                    5, 6, 28, 28
                ) * 255
            ).astype("uint8")
        )
        tensorboard.add_image(
            "generated_images",
            image,
            trainer.global_step, dataformats="HW"
        )

Finally, let's create our networks, the GAN module and a lightning trainer object:

In [None]:
LATENT_SIZE = 128
model = GAN(
    generator=ConvGenerator(
        activation=torch.nn.ELU, latent_size=LATENT_SIZE
    ),
    discriminator=ConvDiscriminator(
        activation=torch.nn.ELU,
    ),
    latent_size=LATENT_SIZE,
    num_disc_steps=3,
    initial_lr=0.001,
)

train_loader = torch.utils.data.DataLoader(
    ds_train,
    batch_size=100,
    shuffle=True,
    num_workers=multiprocessing.cpu_count(),
)

trainer = pl.Trainer(
    gpus=1,
    max_epochs=20,
    log_every_n_steps=5,
    flush_logs_every_n_steps=10,
    callbacks=[
        PlotDigitsCallback(),
        pl.callbacks.LearningRateMonitor(),
    ],
)
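
With these settings, it is easy to estimate how far the learning rate decays over the whole run. The arithmetic below is a sketch that assumes each scheduler is stepped exactly once per epoch (as done in `training_epoch_end`) and that the default `lr_decay_rate=0.95` is used.

```python
initial_lr = 0.001
lr_decay_rate = 0.95
max_epochs = 20

# lrs[k] is the learning rate after k scheduler steps
lrs = [initial_lr * lr_decay_rate ** k for k in range(max_epochs + 1)]
print(f"{lrs[0]:.6f} -> {lrs[-1]:.6f}")  # 0.001000 -> 0.000358
```

In other words, by the end of training the learning rate has dropped to roughly a third of its initial value. The actual values are also recorded by the `LearningRateMonitor` callback.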

Initialize tensorboard to monitor progress.

In [None]:
# Start tensorboard.
%load_ext tensorboard
%tensorboard --logdir lightning_logs/

Finally, we may train our model. While it's training, keep an eye on the tensorboard interface above. You can monitor the losses in the "SCALARS" tab, but you should also be able to look at the images produced by the generator in the "IMAGES" tab.

In [None]:
trainer.fit(model, train_loader)

Let's plot some images that our generator produces:

In [None]:
plot_images(
    (
        model.generate(
            torch.nn.functional.one_hot(
                torch.arange(100) % 10, 10
            ).to(model.device)
        ).detach().cpu().numpy().clip(0, 1).reshape(10, 10, 28, 28) * 255
    ).astype("uint8")
)