# Generating New MNIST Digits with a GAN

This notebook provides a complete code example for generating MNIST digits with a generative adversarial network (GAN).

## Loading the MNIST Dataset with PyTorch

Implement the digit transformations ...

In [None]:
from torchvision import transforms

# Preprocessing applied to every MNIST sample, in order:
#   1. resize each digit to 64x64 pixels (presumably the resolution the DCGAN
#      generator/discriminator below are built for — TODO confirm),
#   2. convert the PIL image to a float tensor with values in [0, 1],
#   3. shift/scale the single channel to [-1, 1] via (x - 0.5) / 0.5.
transform = transforms.Compose(
    [
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5], inplace=True),
    ]
)

... import the MNIST digits ...

In [None]:
from torchvision.datasets import MNIST

# MNIST training split stored under ./data (downloaded on first run);
# `transform` is applied to each sample when it is indexed.
trainset = MNIST(
    root="data",
    train=True,
    download=True,
    transform=transform,
)

... and plot some of the transformed MNIST digits.

In [None]:
import torch
import matplotlib.pyplot as plt

# Show eight randomly chosen (already transformed) training digits with their labels.
fig, axs = plt.subplots(1, 8, figsize=(15, 3))
for ax in axs.ravel():
    sample_idx = torch.randint(0, len(trainset), (1,)).item()
    digit, digit_label = trainset[sample_idx]
    ax.imshow(digit.squeeze(), cmap="gray")  # drop the channel dim for imshow
    ax.set_title(f"Label: {digit_label}", fontsize=16)
    ax.axis("off")
plt.tight_layout()
plt.show()

## Defining the Generator and Discriminator

Determine the device to be used in the computations ...

In [None]:
# Prefer the first CUDA GPU, then Apple-silicon MPS, and fall back to the CPU.
device = torch.device(
    "cuda:0" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

... instantiate the generator ...

In [None]:
import deeplay as dl

latent_dim = 100  # length of the noise vector z fed to the generator

# DCGAN generator producing single-channel images (MNIST is grayscale).
generator = dl.DCGANGenerator(
    output_channels=1,
    features_dim=64,
    latent_dim=latent_dim,
)
generator.build()  # materialize the deeplay module before moving it
generator.to(device);

print(generator)

... and instantiate the discriminator.

In [None]:
# DCGAN discriminator that consumes single-channel (grayscale) images.
discriminator = dl.DCGANDiscriminator(
    features_dim=64,
    input_channels=1,
)
discriminator.build()  # materialize the deeplay module before moving it
discriminator.to(device);

print(discriminator)

## Training the GAN

Define the data loader ...

In [None]:
from torch.utils.data import DataLoader

batch_size = 128

# Reshuffled every epoch; four worker processes load and transform batches.
loader = DataLoader(
    trainset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
)

... define the loss function ...

In [None]:
# Binary cross-entropy averaged over the batch; expects inputs already in [0, 1]
# (presumably the discriminator ends in a sigmoid — TODO confirm).
loss = torch.nn.BCELoss(reduction="mean")

... define the optimizers ...

In [None]:
# Generator and discriminator share the same Adam hyperparameters.
adam_kwargs = dict(lr=0.0002, betas=(0.5, 0.999))
optimizer_G = torch.optim.Adam(generator.parameters(), **adam_kwargs)
optimizer_D = torch.optim.Adam(discriminator.parameters(), **adam_kwargs)

... implement the adversarial training ...

In [None]:
import time
from datetime import timedelta

epochs = 20

num_batches = len(loader)
gen_losses_avg, disc_losses_avg = [], []
# A fixed batch of 30 latent vectors, reused after every epoch so the sample
# grid shows how the generator's output for the *same* inputs evolves.
fixed_latent_vector = torch.randn(30, latent_dim, 1, 1, device=device)
for epoch in range(epochs):
    generator.train()
    discriminator.train()

    print("\n" + f"Epoch {epoch + 1}/{epochs}" + "\n" + "-" * 10)
    start_time = time.time()
    running_gen_loss, running_disc_loss = 0.0, 0.0
    for batch_idx, (real_images, _) in enumerate(loader):
        real_images = real_images.to(device)

        # Size the noise batch from the real batch: the final batch of an epoch
        # is smaller than `batch_size` when the dataset size is not a multiple of it.
        noise = torch.randn(real_images.size(0), latent_dim, 1, 1, device=device)
        fake_images = generator(noise)

        # 1.  Discriminator training: minimize - log(D(x)) - log(1 - D(G(z))).
        # `detach()` keeps this loss from back-propagating into the generator,
        # so the generator's graph is untouched and no `retain_graph` is needed.
        real_output = discriminator(real_images).reshape(-1)
        fake_output = discriminator(fake_images.detach()).reshape(-1)

        real_loss = loss(real_output, torch.ones_like(real_output))
        fake_loss = loss(fake_output, torch.zeros_like(fake_output))

        discriminator_loss = (real_loss + fake_loss) / 2

        optimizer_D.zero_grad()
        discriminator_loss.backward()
        optimizer_D.step()

        # 2.  Generator training: minimize - log(D(G(z))).
        # Re-run the just-updated discriminator on the non-detached fakes so
        # gradients flow back through the generator.
        fake_output = discriminator(fake_images).reshape(-1)
        generator_loss = loss(fake_output, torch.ones_like(fake_output))

        optimizer_G.zero_grad()
        generator_loss.backward()
        optimizer_G.step()

        if batch_idx % 100 == 0:
            print(f"Batch {batch_idx + 1}/{num_batches}: " 
                  + f"Generator Loss: {generator_loss.item():.4f}, " 
                  + f"Discriminator Loss: {discriminator_loss.item():.4f}")

        running_gen_loss += generator_loss.item()
        running_disc_loss += discriminator_loss.item()

    # Mean per-batch loss for this epoch.
    gen_losses_avg.append(running_gen_loss / num_batches)
    disc_losses_avg.append(running_disc_loss / num_batches)
    end_time = time.time()

    print("-" * 10 + "\n"
          + f"Epoch {epoch+1}/{epochs}: " 
          + f"Generator Loss: {gen_losses_avg[-1]:.4f}, "
          + f"Discriminator Loss: {disc_losses_avg[-1]:.4f}, "
          + f"Time taken: {timedelta(seconds=end_time - start_time)}")

    # Render the fixed-latent samples in a 3x10 grid; no autograd graph is needed.
    generator.eval()
    discriminator.eval()
    with torch.no_grad():
        fake_images = generator(fixed_latent_vector).cpu().numpy()
    fig, axs = plt.subplots(3, 10, figsize=(20, 6))
    for i, ax in enumerate(axs.ravel()):
        ax.imshow(fake_images[i][0], cmap="gray")
        ax.axis("off")
    plt.tight_layout()
    plt.show()
    plt.close(fig)  # free the figure; one is created per epoch

## Plotting the Training Losses

In [None]:
import numpy as np

# Per-epoch average losses. The x-axis is 1-based to match the
# "Epoch k/N" numbering printed during training.
gen_epochs = np.arange(1, len(gen_losses_avg) + 1)
disc_epochs = np.arange(1, len(disc_losses_avg) + 1)

fig, ax = plt.subplots()
ax.plot(gen_epochs, gen_losses_avg, "g--o", label="Generator Loss")
ax.plot(disc_epochs, disc_losses_avg, "r-o", label="Discriminator Loss")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.set_title("GAN Training Losses")
ax.legend()
plt.show()