<a href="https://colab.research.google.com/github/MatchLab-Imperial/deep-learning-course/blob/master/07_VAE_GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Variational Autoencoders

<a href="https://ibb.co/GMMVqft"><img src="https://i.ibb.co/X55zqf3/vae.jpg" alt="vae" border="0"></a>

Image taken from [here](http://kvfrans.com/variational-autoencoders-explained/)

In the autoencoder tutorial we showed how to learn a meaningful representation of the data by using an autoencoder. In an autoencoder, the input image was transformed into a vector which encoded the information from the image in a lower dimensionality space. Then, we decoded that vector to get a reconstruction of the input image. However, the model was focused on encoding existing data for representation learning or similar purposes. To tackle the generation of new data, we will use a Variational Autoencoder (VAE) approach. The image shows an overview of the VAE method.

Parts of the code are taken from [here](https://tiao.io/post/tutorial-on-variational-autoencoders-with-a-concise-keras-implementation/), which contains a more in-depth explanation.


Before starting to define the different parts of the VAE, let's import the needed modules for this tutorial.

In [None]:
# Imports
import datetime
import sys
from IPython.display import HTML, display, clear_output
from PIL import Image as pil_image
import ipywidgets as widgets
from matplotlib import animation
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
from mpl_toolkits import mplot3d
import numpy as np
from scipy import stats
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
from tqdm.notebook import tqdm

%matplotlib inline
np.random.seed(123)  # for reproducibility
torch.manual_seed(123)

Now we load MNIST, which will be our toy dataset for this example.

In [None]:
original_dim = 784
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

Now let's do a quick recap of the Variational AutoEncoder (VAE) theory. As stated in the lecture, we want to find the $\hat{\theta}$ that approximates $P_\theta(x)$ by doing:

$$
\hat{\theta} = \text{argmax}_\theta \sum_{i=1}^n \mathbb{E}_{Q_\phi(z|x_i)}[\log(P_\theta(x_i|z))] - \text{KL}(Q_\phi(z|x_i) || P_\theta(z))
$$
We will minimize the negative of that term in order to maximize it, and we will train our model using a stochastic approach by sampling mini-batches from the dataset. First, we define the two losses we will use. The loss `nll` is the negative of the first term of the equation (the reconstruction term), whereas the `KLDivergenceLayer` computes the second term. The `KLDivergenceLayer` computes this extra loss in the middle of the model and stores it in its `kl_loss` attribute, but it does not change its inputs.

In [None]:
# Negative log likelihood loss (Bernoulli)
def nll(y_true, y_pred):
    return F.binary_cross_entropy(y_pred, y_true, reduction='none').sum(dim=-1)

# KL Divergence Layer
class KLDivergenceLayer(nn.Module):
    def forward(self, mu, log_var):
        kl_batch = -0.5 * (1 + log_var - mu.pow(2) - log_var.exp()).sum(dim=1)
        self.kl_loss = kl_batch.mean()
        return mu, log_var
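
As a sanity check on the closed-form expression used in `KLDivergenceLayer`, the sketch below compares it with the generic `torch.distributions.kl_divergence` for a diagonal Gaussian against $\mathcal{N}(0, 1)$. The helper name `kl_closed_form` and the random test values are assumptions made for illustration only.

```python
import torch
from torch.distributions import Normal, kl_divergence

def kl_closed_form(mu, log_var):
    # Same expression as in KLDivergenceLayer: KL(N(mu, sigma^2) || N(0, 1)),
    # summed over the latent dimensions
    return -0.5 * (1 + log_var - mu.pow(2) - log_var.exp()).sum(dim=1)

torch.manual_seed(0)
mu = torch.randn(4, 2)        # hypothetical batch of 4 means, latent_dim = 2
log_var = torch.randn(4, 2)   # hypothetical log-variances

q = Normal(mu, torch.exp(0.5 * log_var))               # Q(z|x)
p = Normal(torch.zeros_like(mu), torch.ones_like(mu))  # prior P(z)

kl_reference = kl_divergence(q, p).sum(dim=1)
kl_manual = kl_closed_form(mu, log_var)
print(torch.allclose(kl_manual, kl_reference, atol=1e-5))
```

Both computations should agree up to floating point error, which is why the layer can avoid sampling when computing the KL term.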

In the last block, we just defined the two losses we will use. Here we build the whole VAE model. First, we build the encoder. The goal of the encoder, with parameters $\phi$, is to approximate $P_\theta(z|x_i)$ via $Q_\phi(z|x_i)$. We assume that $P(z)$ is a Normal distribution with zero mean and unit variance. We also assume that $Q(z|x)=\mathcal{N}(\mu_x, \sigma_x)$, so the encoder tries to recover the parameters $\mu_x, \sigma_x$ for the different $x$ (which are the input images), i.e. the encoder outputs a mean and a standard deviation for each latent dimension. In the code, the encoder predicts $\log \sigma_x^2$ rather than $\sigma_x$, which keeps the output unconstrained. The code for this encoder is the following:

```python
        # Define latent dimension
        self.latent_dim = latent_dim
        
        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=5, stride=2, padding=2),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2),
            nn.LeakyReLU(0.2),
            nn.Flatten()
        )

        # Compute μ and log(σ²)
        self.z_mu = nn.Linear(128*7*7, latent_dim)
        self.z_log_var = nn.Linear(128*7*7, latent_dim)
        # To Apply KL divergence layer
        self.kl_layer = KLDivergenceLayer()
```
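
The `128*7*7` input size of the two linear layers follows from the two stride-2 convolutions, which halve the 28x28 image to 14x14 and then to 7x7. A minimal sketch to check this, assuming a dummy batch of 8 blank MNIST-sized images (the batch is hypothetical, the layers are copied from the encoder above):

```python
import torch
import torch.nn as nn

# Same layers as the encoder above, built standalone for a shape check
encoder = nn.Sequential(
    nn.Conv2d(1, 64, kernel_size=5, stride=2, padding=2),    # 28x28 -> 14x14
    nn.LeakyReLU(0.2),
    nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2),  # 14x14 -> 7x7
    nn.LeakyReLU(0.2),
    nn.Flatten(),
)

dummy_batch = torch.zeros(8, 1, 28, 28)  # hypothetical batch of 8 images
with torch.no_grad():
    features = encoder(dummy_batch)

print(features.shape)  # torch.Size([8, 6272]), i.e. 128 * 7 * 7 features per image
```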


As we mentioned, the encoder will output the parameters $\mu_x, \sigma_x$. Now we will sample from the normal distribution defined by those parameters and pass the sample to the decoder. However, we now face one of the problems of implementing a VAE: we want to optimize both decoder and encoder at the same time to i) encourage good reconstruction and ii) make $z$ follow a normal distribution. What is the problem here? Using a standard sampling method, i.e. directly sampling using the mean and standard deviation output by the encoder, we cannot train the model in an end-to-end manner, as sampling is not a differentiable operation.

We need a trick to solve this issue. We can use one of the properties of the Normal probability distribution, which is that sampling from $\mathcal{N}(\mu_x, \sigma_x)$ is the same as computing $\mu_x + \sigma_x \epsilon$ with $\epsilon \sim \mathcal{N}(0, 1)$. Now the randomness comes only from $\epsilon$, which is multiplied by the $\sigma_x$ predicted by the encoder and shifted by $\mu_x$, meaning we can propagate the gradients from the output back to the encoder. This is called the reparametrisation trick. We use this trick in the following block.


```python
def reparameterize(self, mu, log_var):
        ##### Reparametrisation trick
        # Log_var to sigma
        std = torch.exp(0.5 * log_var)
        # Sample the noise eps from N(0, 1) with the same shape as std
        eps = torch.randn_like(std)
        # Multiply sigma with the external noise and add mu to obtain the latent vector
        return mu + eps * std
```
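
To see that gradients really flow back through the sampled latent vector, here is a small sketch using the same three lines as `reparameterize` on hand-picked values (the values of `mu` and `log_var` are assumptions chosen for illustration):

```python
import torch

torch.manual_seed(0)
mu = torch.tensor([0.5, -1.0], requires_grad=True)
log_var = torch.tensor([0.0, 0.2], requires_grad=True)

std = torch.exp(0.5 * log_var)
eps = torch.randn_like(std)   # noise, independent of mu and log_var
z = mu + eps * std            # reparametrised sample

z.sum().backward()

# dz/dmu = 1 and dz/dlog_var = 0.5 * eps * std, so both receive gradients
print(mu.grad)
print(log_var.grad)
```

Because `eps` is treated as a constant input, the sample `z` is a deterministic, differentiable function of `mu` and `log_var`.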

Now we define the decoder, which will take the sample from $\mathcal{N}(\mu_x, \sigma_x)$ as input and output an image.


```python
        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 128*7*7),
            nn.LeakyReLU(0.2),
            nn.Unflatten(1, (128, 7, 7)),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, kernel_size=5, padding=2),
            nn.LeakyReLU(0.2),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(64, 1, kernel_size=5, padding=2),
            nn.Sigmoid(),
            nn.Flatten()
        )
```

We have defined both the encoder and the decoder, and we are ready to build the model. The model takes a single input, the image `x`, which goes through the encoder. The noise `eps` (the sample from $\mathcal{N}(0, 1)$ we mentioned before) is drawn inside `reparameterize` and combined with $\mu_x, \sigma_x$ via the reparametrisation trick to obtain the latent vector passed to the decoder. The code below puts together all the previous parts as a single class.

In [None]:
# VAE Model
class VAE(nn.Module):
    def __init__(self, latent_dim=2):
        super(VAE, self).__init__()
        # Define latent dimension
        self.latent_dim = latent_dim

        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=5, stride=2, padding=2),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2),
            nn.LeakyReLU(0.2),
            nn.Flatten()
        )

        # Compute μ and log(σ²)
        self.z_mu = nn.Linear(128*7*7, latent_dim)
        self.z_log_var = nn.Linear(128*7*7, latent_dim)
        # To Apply KL divergence layer
        self.kl_layer = KLDivergenceLayer()

        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 128*7*7),
            nn.LeakyReLU(0.2),
            nn.Unflatten(1, (128, 7, 7)),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, kernel_size=5, padding=2),
            nn.LeakyReLU(0.2),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(64, 1, kernel_size=5, padding=2),
            nn.Sigmoid(),
            nn.Flatten()
        )

    def encode(self, x):
        h = self.encoder(x)
        return self.z_mu(h), self.z_log_var(h)

    def reparameterize(self, mu, log_var):
        ##### Reparametrisation trick
        # Log_var to sigma
        std = torch.exp(0.5 * log_var)
        # Sample the noise eps from N(0, 1) with the same shape as std
        eps = torch.randn_like(std)
        # Multiply sigma with the external noise and add mu to obtain the latent vector
        return mu + eps * std

    def decode(self, z):
        return self.decoder(z)

    def forward(self, x):

        mu, log_var = self.encode(x)
        mu, log_var = self.kl_layer(mu, log_var)
        z = self.reparameterize(mu, log_var)
        return self.decode(z), mu, log_var, self.kl_layer.kl_loss

In [None]:
# Training setup
latent_dim = 2
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VAE(latent_dim).to(device)
optimizer = optim.Adam(model.parameters())

Now we train the model for some epochs to see if we can model our data. After the training process we will use this trained VAE to generate new data.

In [None]:
# Training the model
epochs = 20
# Note: the batch size (128) is set in the DataLoader definitions above

# Early stopping implementation
best_loss = float('inf')
patience = 3
no_improve_epochs = 0

for epoch in range(1, epochs + 1):
    model.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()

        recon_batch, mu, log_var, kl_loss = model(data)
        nll_loss = nll(data.view(-1, 784), recon_batch).mean()
        loss = nll_loss + kl_loss

        loss.backward()
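        # loss is a batch mean, so weight it by the batch size before averaging over the dataset
        train_loss += loss.item() * data.size(0)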
        optimizer.step()

    # Validation
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for data, _ in test_loader:
            data = data.to(device)
            recon_batch, mu, log_var, kl_loss = model(data)
            nll_loss = nll(data.view(-1, 784), recon_batch).mean()
            # Weight the batch mean by the batch size before averaging over the dataset
            val_loss += (nll_loss + kl_loss).item() * data.size(0)

    val_loss /= len(test_loader.dataset)
    print(f'Epoch: {epoch}, Train Loss: {train_loss / len(train_loader.dataset):.4f}, Val Loss: {val_loss:.4f}')

    # Early stopping check
    if val_loss < best_loss:
        best_loss = val_loss
        no_improve_epochs = 0
        torch.save(model.state_dict(), 'best_vae_model.pth')
    else:
        no_improve_epochs += 1
        if no_improve_epochs >= patience:
            print(f'Early stopping after {epoch} epochs')
            model.load_state_dict(torch.load('best_vae_model.pth'))
            break

Assuming we input into the decoder $\mu_x$ obtained after encoding the image with the encoder, we can retrieve the MSE in the test set using the following piece of code.

In [None]:
# Calculate MSE on test set
model.eval()
total_mse = 0
with torch.no_grad():
    for data, _ in test_loader:
        data = data.to(device)
        # Decode the mean mu_x directly instead of a noisy sample z
        mu, _ = model.encode(data)
        recon_batch = model.decode(mu)
        mse = F.mse_loss(recon_batch, data.view(-1, 784), reduction='sum')
        total_mse += mse.item()

print(f"The MSE is: {total_mse / len(test_loader.dataset):.4f}")

We trained an encoder to model $Q(z|x)$ which outputs two parameters per latent dimension for each image $x$, which are  $\mu_x, \sigma_x$. We now plot the distribution of the $\mu_x$ when encoding the different images from the test set.

In [None]:
# Plot latent space distribution
model.eval()
latent_vectors = []
labels = []

with torch.no_grad():
    for data, target in test_loader:
        data = data.to(device)
        mu, _ = model.encode(data)
        latent_vectors.append(mu.cpu())
        labels.append(target.cpu())

z_test = torch.cat(latent_vectors, dim=0).numpy()
y_test = torch.cat(labels, dim=0).numpy()

# Display a 2D plot of the digit classes in the latent space
fig, ax = plt.subplots(figsize=(6,6))
cm = plt.get_cmap('gist_rainbow')
ax.set_prop_cycle(color=[cm(1.*i/(10)) for i in range(10)])
for l in range(10):
    # Only select indices for corresponding label
    ind = y_test == l
    ax.scatter(z_test[ind, 0], z_test[ind, 1], label=str(l), s=10)
ax.legend()
plt.title("Latent distribution")
plt.show()

The encoded means are clustered by class, as in the autoencoder case. However, in this case the points are also arranged roughly in a circle around the centre due to the Kullback-Leibler divergence term in the loss, which pushes the $\mu_x$ output by the encoder towards zero and $\sigma_x$ towards 1.

Now let's start with the generation of data, which is the main reason we trained this model.

In [None]:
# Display a 2D manifold of the digits
n = 15  # figure with 15x15 digits
digit_size = 28

# Linearly spaced coordinates transformed through inverse CDF of Gaussian
z1 = stats.norm.ppf(np.linspace(0.01, 0.99, n))
z2 = stats.norm.ppf(np.linspace(0.01, 0.99, n))
z_grid = np.dstack(np.meshgrid(z1, z2))

# Generate images from grid
with torch.no_grad():
    z_tensor = torch.FloatTensor(z_grid.reshape(n*n, latent_dim)).to(device)
    x_pred = model.decoder(z_tensor).cpu().numpy()

x_pred_grid = x_pred.reshape(n, n, digit_size, digit_size)

plt.figure(figsize=(10, 10))
plt.imshow(np.block(list(map(list, x_pred_grid))), cmap='gray')
plt.axis('off')
plt.show()

You can see in the image how there is a smooth transition between the different generated numbers.
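
The grid coordinates above come from `stats.norm.ppf`, the inverse CDF of the standard Gaussian. A short sketch of what that transform does, using 5 points instead of 15 (the variable names here are illustrative assumptions):

```python
import numpy as np
from scipy import stats

# Evenly spaced probabilities, as in the grid above but with fewer points
probabilities = np.linspace(0.01, 0.99, 5)

# The inverse CDF maps them to latent values that are denser near 0,
# where the Gaussian prior has most of its mass
z_values = stats.norm.ppf(probabilities)
print(z_values)
```

The resulting values are symmetric around zero and increasing, so the grid covers the high-probability region of the prior rather than sampling the latent space uniformly.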

Now let's generate a nice animation for the latent variable, where we show the point we used in $z$ to generate the data and the corresponding image generated.

In [None]:
## We create a 2d array
# Number of points to use, increase it for smoother animation
# Using more points makes the function slower
n_points = 50
# theta from 0 to 2pi
theta = np.linspace(0, 2*np.pi, n_points)
# radius of the circle (change it depending on your representation space plot)
r = 1
# compute x and y (you can add an offset depending on your latent space)
offset_x = 0
offset_y = 0
x = r*np.cos(theta) + offset_x
y = r*np.sin(theta) + offset_y
latent = np.stack([x, y], -1)

## We now plot as before the 2d scatter with the images from the test set
## and the corresponding label
# Get test latent vectors
model.eval()
with torch.no_grad():
    z_test = []
    for data, _ in test_loader:
        mu, _ = model.encode(data.to(device))
        z_test.append(mu.cpu())
    z_test = torch.cat(z_test).numpy()

# Create figure
fig, ax = plt.subplots(1, 2, figsize=(12, 6))

# First subplot: latent space with test points
ax[0].set_prop_cycle(color=[plt.cm.gist_rainbow(1.*i/10) for i in range(10)])
for l in range(10):
    ind = y_test == l
    ax[0].scatter(z_test[ind, 0], z_test[ind, 1], label=str(l), s=10)
ax[0].legend()
scat = ax[0].scatter(latent[0,0], latent[0,1], s=200, c='k')

# Second subplot: generated image
with torch.no_grad():
    latent_tensor = torch.FloatTensor(latent).to(device)
    latent_im = model.decoder(latent_tensor).cpu().numpy()

im = ax[1].imshow(latent_im[0].reshape(28, 28), cmap='gray')
ax[1].axis('off')

# Animation update function
def updatefig(i):
    scat.set_offsets(latent[i])
    im.set_array(latent_im[i].reshape(28, 28))
    # With blit=True every artist that changes must be returned
    return scat, im

# Create animation
anim = animation.FuncAnimation(fig, updatefig, frames=n_points, interval=200, blit=True)
plt.close()
HTML(anim.to_html5_video())

# Generative Adversarial Networks



![](http://2018.igem.org/wiki/images/4/48/T--Vilnius-Lithuania-OG--introduction1.png)
[Image source](https://2018.igem.org/Team:Vilnius-Lithuania-OG/Gan_Introduction)

Generative Adversarial Networks (GANs) have been shown to generate sharper, more realistic samples than approaches such as VAEs in many image generation tasks.

As you learnt in the lecture, we have two networks playing what is called a min max game between them. The Generator, $G$, tries to generate data that looks similar to real data, whereas the discriminator $D$ tries to distinguish between real and fake data.
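
For reference, the min max game is usually written as the following value function. This is the standard formulation rather than something specific to this tutorial, and $p_{\text{data}}$ (the distribution of real images) is notation introduced here:

$$
\min_G \max_D \; V(D, G) = \mathbb{E}_{x\sim p_{\text{data}}}[\log D(x)] + \mathbb{E}_{z\sim \mathcal{N}(0, 1)}[\log(1 - D(G(z)))]
$$

$D$ tries to make $D(x)$ close to 1 for real data and $D(G(z))$ close to 0 for generated data, while $G$ tries to make $D(G(z))$ close to 1.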

We first define $G$ here, and in this case we will use a convolutional network. Notice that we use LeakyReLU activation functions, which have been shown to work well in GANs. $G$ will take a vector of noise (in this case of dimensionality `randomDim = 5`) sampled from $\mathcal{N}(0, 1)$ and output a generated image. We use `tanh` as the last activation because the data will be normalized to be between $[-1, 1]$.

In [None]:
# Build Generator model
randomDim = 5

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            # Input is random noise vector of size randomDim
            nn.Linear(randomDim, 128 * 7 * 7),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Unflatten(1, (128, 7, 7)),  # Reshape to (128, 7, 7)

            # Upsample to (128, 14, 14)
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, kernel_size=5, padding=2),
            nn.LeakyReLU(0.2, inplace=True),

            # Upsample to (64, 28, 28)
            nn.Upsample(scale_factor=2),
            nn.Conv2d(64, 1, kernel_size=5, padding=2),
            nn.Tanh()  # Output in range [-1, 1]
        )

    def forward(self, input):
        return self.main(input)

# Create generator and optimizer
generator = Generator()
optimizer_gen = optim.Adam(generator.parameters(), lr=2e-4, betas=(0.5, 0.999))

# Print model summary
print(generator)

Here, we define the discriminator $D$. The discriminator ends in a sigmoid and is trained with a binary cross-entropy loss (`nn.BCELoss`), as its job is to discriminate between real and fake data.

In [None]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            # Input: 1x28x28 image
            nn.Conv2d(1, 64, kernel_size=5, stride=2, padding=2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.4),

            # 64x14x14
            nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.4),

            # 128x7x7
            nn.Flatten(),
            nn.Linear(128 * 7 * 7, 1),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input)

# Create discriminator and optimizer
discriminator = Discriminator()
optimizer_disc = optim.Adam(discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999))

# Print model summary
print(discriminator)

We now define the combined network, which chains the generator and discriminator into a single network. The discriminator, however, will not be updated when using this combined network. We will explain later why, but basically we will use the `discriminator` object (not the combined network) whenever we need to update the discriminator.

In [None]:
class GAN(nn.Module):
    def __init__(self, generator, discriminator):
        super(GAN, self).__init__()
        self.generator = generator
        self.discriminator = discriminator

        # Freeze discriminator parameters
        for param in self.discriminator.parameters():
            param.requires_grad = False

    def forward(self, z):
        """Forward pass through generator then discriminator"""
        generated_images = self.generator(z)
        return self.discriminator(generated_images)

# Create combined GAN model
gan = GAN(generator, discriminator)
optimizer_gan = optim.Adam(gan.generator.parameters(), lr=2e-4, betas=(0.5, 0.999))

# Print model summary
print(gan)
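
Freezing the discriminator does not block the gradients that pass through it. The sketch below illustrates this with two tiny stand-in networks; `toy_generator` and `toy_discriminator` are hypothetical models made up for this check, not the ones defined above.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical stand-ins for G and D, small enough to inspect
toy_generator = nn.Linear(3, 4)
toy_discriminator = nn.Sequential(nn.Linear(4, 1), nn.Sigmoid())

# Freeze the discriminator, as done in the GAN class
for param in toy_discriminator.parameters():
    param.requires_grad = False

noise = torch.randn(5, 3)
outputs = toy_discriminator(toy_generator(noise))
loss = nn.BCELoss()(outputs, torch.ones(5, 1))
loss.backward()

# Gradients reach the generator, but none are stored for the frozen discriminator
generator_has_grads = all(p.grad is not None for p in toy_generator.parameters())
discriminator_has_grads = any(p.grad is not None for p in toy_discriminator.parameters())
print(generator_has_grads, discriminator_has_grads)
```

The frozen network still acts as a differentiable function of its input, which is what allows it to guide the generator.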

Now, let's define some helper functions that will be used to plot the loss and some generated images during training.

In [None]:
def plot_loss(losses):
    """Plot discriminator and generator loss during training"""
    plt.figure(figsize=(10, 5))
    plt.plot(losses["d"], label='Discriminator Loss')
    plt.plot(losses["g"], label='Generator Loss')
    plt.title("Training Losses")
    plt.xlabel("Iteration")
    plt.ylabel("Loss")
    plt.legend()
    plt.savefig('./loss.png')
    plt.close()

def plot_gen(n_ex=16, dim=(4,4), figsize=(10,10)):
    """Generate and plot images from the generator"""
    # Generate random noise
    noise = torch.randn(n_ex, randomDim, device=device)

    # Generate images
    with torch.no_grad():
        generated_images = generator(noise).cpu().numpy()

    # Reshape and plot
    plt.figure(figsize=figsize)
    for i in range(n_ex):
        plt.subplot(dim[0], dim[1], i+1)
        img = generated_images[i].reshape(28, 28)  # Reshape to 28x28
        plt.imshow(img, cmap='gray')
        plt.axis('off')
    plt.tight_layout()
    plt.savefig('./images.png')
    plt.close()

We will not train the model with a single standard training step as in the usual approach. The reason for this is that the min max game played by GANs is implemented in practice, for each batch, using two different training steps.
* **First step - Training the discriminator:** In this step only the discriminator is trained.  The generator will output some fake images using noise as input. Then, we will give the discriminator these generated images and some images sampled from the real dataset and train it to distinguish between the two of them.

* **Second step - Updating the generator**: We use the generator to output fake images again, and the discriminator will try to guess if these newly generated images are real or fake. However, the aim in this second step is not to update the discriminator, only the generator. This is why in the combined network we made the discriminator not to be trainable. The discriminator is used to pass information (in the form of gradients in this case) to update the generator. Hence, the generator will try to change its weights to make the discriminator think the data comes from the real distribution.
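
The two steps above can be related to a binary cross-entropy loss through the choice of labels. The sketch below uses made-up discriminator outputs (`d_real` and `d_fake` are assumed values, not results from a trained model) to show what each label choice computes:

```python
import torch
import torch.nn as nn

criterion = nn.BCELoss()

# Hypothetical discriminator outputs (probabilities of being real)
d_real = torch.tensor([[0.9], [0.8]])   # D(x) on real images
d_fake = torch.tensor([[0.3], [0.1]])   # D(G(z)) on generated images

real_labels = torch.ones_like(d_real)
fake_labels = torch.zeros_like(d_fake)

# First step: the discriminator loss is -log D(x) - log(1 - D(G(z)))
d_loss = criterion(d_real, real_labels) + criterion(d_fake, fake_labels)
d_loss_manual = -(torch.log(d_real).mean() + torch.log(1 - d_fake).mean())

# Second step: labelling fakes as real gives the generator loss -log D(G(z))
g_loss = criterion(d_fake, real_labels)
g_loss_manual = -torch.log(d_fake).mean()

print(d_loss.item(), g_loss.item())
```

Using real labels for the generated images in the second step means the generator loss falls as the discriminator is fooled more often.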

In [None]:
# Set up output widgets
progress_out = widgets.Output()
plot_out = widgets.Output()
display(progress_out)
display(plot_out)

# Initialize variables
losses = {"d": [], "g": []}
total_iter = 0
saved_images = []
saved_iterations = []

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Move models to device
generator = generator.to(device)
discriminator = discriminator.to(device)
gan = GAN(generator, discriminator).to(device)

# Loss function
criterion = nn.BCELoss()

def update_plots(iteration):
    global saved_images, saved_iterations
    with plot_out:
        plot_out.clear_output(wait=True)
        plot_gen()
        # Read the saved image file and append the image data
        img_data = mpimg.imread('images.png')
        saved_images.append(img_data)
        saved_iterations.append(iteration)

        # Display the last saved image
        fig, ax = plt.subplots(1, 1, figsize=(5, 5))
        ax.imshow(img_data)
        ax.axis('off')
        ax.set_title(f'Iteration: {iteration}')
        plt.tight_layout()
        display(fig)
        plt.close(fig)


def train_epoch(epoch_number, plt_frq=200):
    global total_iter

    with progress_out:
        # Initialize progress bar outside the loop
        pbar = tqdm(total=len(train_loader), desc=f"Epoch {epoch_number}", leave=True)

        for i, (real_images, _) in enumerate(train_loader):
            total_iter += 1
            batch_size = real_images.size(0)

            # Move to device and rescale from [0, 1] to [-1, 1] to match the tanh output of G
            real_images = real_images.to(device) * 2 - 1

            # Create labels
            real_labels = torch.ones(batch_size, 1, device=device)
            fake_labels = torch.zeros(batch_size, 1, device=device)

            # ---------------------
            #  Train Discriminator
            # ---------------------
            # Enable discriminator gradients
            for param in discriminator.parameters():
                param.requires_grad = True
            discriminator.zero_grad()

            # Train with real images
            outputs_real = discriminator(real_images)
            d_loss_real = criterion(outputs_real, real_labels)

            # Train with fake images
            noise = torch.randn(batch_size, randomDim, device=device)
            fake_images = generator(noise).detach()
            outputs_fake = discriminator(fake_images)
            d_loss_fake = criterion(outputs_fake, fake_labels)

            # Total discriminator loss
            d_loss = d_loss_real + d_loss_fake
            d_loss.backward()
            optimizer_disc.step()
            losses["d"].append(d_loss.item())

            # -----------------
            #  Train Generator
            # -----------------
            # Disable discriminator gradients
            for param in discriminator.parameters():
                param.requires_grad = False
            generator.zero_grad()

            # Generate fake images
            noise = torch.randn(batch_size, randomDim, device=device)
            outputs = gan(noise)

            # Generator tries to fool discriminator
            g_loss = criterion(outputs, real_labels)
            g_loss.backward()
            optimizer_gan.step()
            losses["g"].append(g_loss.item())

            # Update progress bar
            pbar.set_postfix({
                'd_loss': f"{d_loss.item():.4f}",
                'g_loss': f"{g_loss.item():.4f}"
            })
            pbar.update(1)

            if i % plt_frq == 0:
                update_plots(total_iter)

        # Close progress bar at end of epoch
        pbar.close()

Let's start the training process. We will use MNIST again. A standard practice for GANs is to normalize images to the range $[-1, 1]$, which we do. We train it for 10 epochs.

In [None]:
# Training loop
n_epoch = 10
for e in range(n_epoch):
    train_epoch(e+1)

In [None]:
# Plot final losses
plot_loss(losses)
fig, ax = plt.subplots(1, 1, figsize=(7, 7))
img = mpimg.imread('loss.png')
ax.imshow(img)
ax.axis('off')
plt.tight_layout()

In [None]:
# Create animation
fig_anim, ax_anim = plt.subplots(figsize=(5, 5))
# Use the first saved image data directly
im_anim = ax_anim.imshow(saved_images[0])
ax_anim.axis('off')
title_text = ax_anim.set_title(f'Iteration: {saved_iterations[0]}')

def update_frame(i):
    # Use the saved image data for updating the frame
    im_anim.set_array(saved_images[i])
    title_text.set_text(f'Iteration: {saved_iterations[i]}')
    return im_anim, title_text

anim = animation.FuncAnimation(fig_anim, update_frame, frames=len(saved_images),
                              interval=500, blit=False, repeat=False)

# Display animation
HTML(anim.to_html5_video())

The plots show how the training losses of both the generator and the discriminator do not change much. Usually we have a single model which we optimize to reduce some metric or loss. However, in this case we have two models which are 'competing' against each other, so their losses stay approximately stable. Now let's plot some results. As in the VAE case, we plot the results when sampling along 2 of the dimensions of the input noise $z$. However, as in this case we used more than 2 dimensions for $z$, the remaining dimensions are filled with random values, so the results vary from run to run.

In [None]:
# display a 2D manifold of the digits
n = 15  # figure with 15x15 digits
digit_size = 28

# linearly spaced coordinates on the unit square were transformed
# through the inverse CDF (ppf) of the Gaussian to produce values
# of the latent variables z, since the prior of the latent space
# is Gaussian
z1 = stats.norm.ppf(np.linspace(0.01, 0.99, n))
z2 = stats.norm.ppf(np.linspace(0.01, 0.99, n))
z_grid = np.dstack(np.meshgrid(z1, z2))
z_grid = z_grid.astype(np.float32)

# Create random values for the remaining dimensions
z_fill = np.random.normal(0, 1, size=[randomDim-2])
z_fill = np.expand_dims(z_fill, 0)
z_fill = np.expand_dims(z_fill, 0)
z_fill = np.repeat(z_fill, n, 0)
z_fill = np.repeat(z_fill, n, 1)

# Combine the grid coordinates with random values for other dimensions
z_grid = np.concatenate([z_grid, z_fill], -1)
np.random.shuffle(z_grid[:,:,-1])  # Shuffle one dimension for more variation

# Generate images from the grid points
with torch.no_grad():
    # Convert numpy array to PyTorch tensor and move to device
    z_tensor = torch.from_numpy(z_grid.reshape(n*n, randomDim)).float().to(device)
    # Generate images
    x_pred = generator(z_tensor).cpu().numpy()

# Reshape the predictions into a grid
x_pred_grid = np.reshape(x_pred, (n, n, digit_size, digit_size))

# Display the grid of generated digits
plt.figure(figsize=(10, 10))
plt.imshow(np.block(list(map(list, x_pred_grid))), cmap='gray')
plt.axis('off')
plt.show()

**Training instability**

GANs are quite difficult to train as there are several factors that can hurt their performance. Some of the most frequent issues are:

* [Mode collapse](https://www.youtube.com/watch?v=ktxhiKhWoEE): The generator is not capable of creating diverse images and only generates a limited set of images.
* Discriminator loss decreases quickly: in some cases the discriminator may be too powerful, so it quickly learns at the beginning which images are fake and which real, leading to small gradients passed to the generator.
* Hyper parameter sensitivity
* Non-convergence

There are some tricks that have been shown to improve convergence, some of them are shared in this [repo](https://github.com/soumith/ganhacks).



# Inception Score


The quantitative evaluation of a generative model is still an open research problem. The metric should evaluate the coverage, i.e. how diverse the generated data is, and sample quality, which is related to the visual quality of the sample. A widely used metric is the Inception Score, which is defined as:
$$
\text{IS}(G)= \exp(\mathbb{E}_{x\sim p_g}\ D_{KL}(p(y|x)||p(y)))
$$
Let's explain the term a little bit, where $p_g$ is the distribution of the generated samples. We aim to have a high IS, which means that we want a high $D_{KL}$ between $p(y|x)$ and $p(y)$. Ideally, what we want is $p(y|x)$ to be "peaky" and $p(y)$ to be uniform. $p(y|x)$ is the predicted probability vector $y$ given a generated sample $x$ by a model, in this case an Inception model (hence the name of the score). If the sample $x$ is of high quality, the model should classify it with confidence in one of the available classes, hence the "peaky" $p(y|x)$. The term $p(y)$ is related to the diversity of the data, and is the marginal probability computed as $\int_z p(y|x=G(z))dz$, hence we average the predicted probabilities of the Inception model for the given samples. If the data is diverse, we will obtain an approximately flat $p(y)$.
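
The definition can be turned into a few lines of NumPy. This is a minimal sketch, assuming the classifier predictions are already available as an `(N, num_classes)` array of probabilities; the function name `inception_score` and the three toy inputs are illustrative assumptions.

```python
import numpy as np

def inception_score(probs, eps=1e-12):
    """Inception Score from an (N, num_classes) array of predicted p(y|x)."""
    probs = np.asarray(probs, dtype=np.float64)
    # Marginal p(y): average of the predictions over all samples
    p_y = probs.mean(axis=0, keepdims=True)
    # KL(p(y|x) || p(y)) for each sample; eps avoids log(0)
    kl = (probs * (np.log(probs + eps) - np.log(p_y + eps))).sum(axis=1)
    return float(np.exp(kl.mean()))

# Toy predictions for 10 samples and 10 classes
confident_and_diverse = np.eye(10)                # each sample a different class
collapsed = np.tile(np.eye(10)[0], (10, 1))       # every sample the same class
uncertain = np.full((10, 10), 0.1)                # flat p(y|x) for every sample

score_diverse = inception_score(confident_and_diverse)
score_collapsed = inception_score(collapsed)
score_uncertain = inception_score(uncertain)
print(score_diverse, score_collapsed, score_uncertain)
```

With 10 classes the score ranges from 1 (no diversity or no confidence) up to 10 (confident predictions spread evenly over all classes), which matches the intuition of a "peaky" $p(y|x)$ and a flat $p(y)$.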

In [None]:
# Train model
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np

# Define the classifier architecture (outputs class probabilities via a final softmax)
class MNISTClassifier(nn.Module):
    def __init__(self):
        super(MNISTClassifier, self).__init__()
        self.model = nn.Sequential(
            # Conv2d_1
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),

            # Conv2d_2
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),

            # Conv2d_3
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),

            # Flatten and dense
            nn.Flatten(),
            nn.Dropout(0.5),
            nn.Linear(128*7*7, 10),
            nn.Softmax(dim=1)
        )

    def forward(self, x):
        return self.model(x)

# Initialize model
model = MNISTClassifier()
print(model)

# Training setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
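# The model already ends in Softmax, so take the log of its output and use NLLLoss;
# nn.CrossEntropyLoss would apply a second (log-)softmax on top of the probabilities
_nll_loss = nn.NLLLoss()
def criterion(probs, target):
    return _nll_loss(torch.log(probs + 1e-12), target)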
optimizer = optim.Adam(model.parameters())

# Data loading
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

# Training function
def train(epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        if batch_idx % 100 == 0:
            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} '
                  f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')

# Test function
def test():
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader)  # criterion returns a per-batch mean, so average over batches
    accuracy = 100. * correct / len(test_loader.dataset)
    print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} '
          f'({accuracy:.0f}%)\n')
    return accuracy

# Training loop
best_accuracy = 0
for epoch in range(1, 11):  # 10 epochs
    train(epoch)
    current_accuracy = test()

    # Save best model
    if current_accuracy > best_accuracy:
        best_accuracy = current_accuracy
        torch.save(model.state_dict(), 'mnist_classifier_best.pth')
        print("Saved new best model")

# Save final model
torch.save(model.state_dict(), 'mnist_classifier_final.pth')
print("Training complete. Models saved.")

In [None]:
# Load saved model from file
model.load_state_dict(torch.load('mnist_classifier_best.pth'))

In [None]:
def inception_score(x, batch_size=32, denorm_im=True, device='cuda'):
    # Load the trained model
    model = MNISTClassifier().to(device)
    model.load_state_dict(torch.load('mnist_classifier_best.pth'))
    model.eval()

    n_batches = (x.shape[0] + batch_size - 1) // batch_size
    preds = []

    with torch.no_grad():
        for i in range(n_batches):
            batch = x[i*batch_size:(i+1)*batch_size]
            batch = torch.from_numpy(batch).float().to(device)
            if batch.ndim == 3:
                batch = batch.unsqueeze(1)
            elif batch.shape[-1] == 1:
                batch = batch.permute(0, 3, 1, 2)
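            if denorm_im:
                batch = (batch + 1) / 2  # Scale from [-1,1] to [0,1]
            # Apply the same normalization the classifier was trained with
            batch = (batch - 0.1307) / 0.3081
            probs = model(batch)  # the model ends in Softmax, so these are probabilities
            preds.append(probs.cpu().numpy())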

    preds = np.concatenate(preds, axis=0)

    # Calculate p(y)
    p_y = np.mean(preds, axis=0)
    # Calculate KL divergence
    # p(y|x)log(P(y|x)/P(y))
    kl_div = preds * (np.log(preds + 1e-16) - np.log(p_y + 1e-16))  # Added epsilon for numerical stability
    # KL(x) = Σ_y p(y|x)log(P(y|x)/P(y))
    kl_div = np.sum(kl_div, axis=1)
    # Calculate mean KL divergence
    avg_kl_div = np.mean(kl_div)
    # Calculate inception score
    is_score = np.exp(avg_kl_div)
    return is_score

def image_inception_score(generator, n_ex=10000, dim_random=10, input_noise=None, denorm_im=True, device='cuda'):
    generator.eval()
    if input_noise is None:
        input_noise = torch.randn(n_ex, dim_random, device=device)
    else:
        input_noise = torch.from_numpy(input_noise).float().to(device)
    # Generate images
    with torch.no_grad():
        x_pred = generator(input_noise)
        # Convert to numpy and reshape if needed
        x_pred = x_pred.cpu().numpy()
        if len(x_pred.shape) == 2:  # Flattened images
            x_pred = x_pred.reshape(-1, 28, 28, 1)  # also works when input_noise has a length other than n_ex
        elif x_pred.shape[1] == 1:  # NCHW format
            x_pred = x_pred.transpose(0, 2, 3, 1)  # Convert to NHWC
    return inception_score(x_pred, batch_size=128, denorm_im=denorm_im, device=device)

Now let's check the Inception Score for the GAN model we have trained before.

In [None]:
# Compute Inception Score for the trained GAN generator
is_score = image_inception_score(generator,
                               dim_random=randomDim,
                               denorm_im=True)

print(f"Inception Score: {is_score:.4f}")

Due to the way we defined the generator for the GAN, images are generated in the range $[-1, 1]$ as we use a `tanh` activation function. That is why we use the `denorm_im` variable to map them to the range $[0, 1]$.

The decoder of the VAE can also be evaluated using the same `image_inception_score` function. The VAE loss includes a term that pulls the latent distribution towards the prior $\mathcal{N}(0,1)$, hence if we sample from $\mathcal{N}(0,1)$, as done in `image_inception_score`, the decoder should be able to generate new samples. When using the function for the VAE, use `image_inception_score(..., denorm_im=False)`, as the decoder of the VAE already outputs images in the range $[0,1]$.

# Conditional Generative Adversarial Networks (cGANs)

A Conditional GAN (cGAN) is a variant of the classical Generative Adversarial Network in which generation is conditioned on some extra information rather than on random noise alone. cGANs were introduced in 2014 ([link to the cGANs paper](https://arxiv.org/abs/1411.1784)) and set a new state of the art in many image tasks.

In this setting, cGANs expect an input image instead of a random noise vector. A simple example of image translation is the drawing-to-real-object transformation. Imagine that we have two paired domains, images of real objects (domain A) and drawings of those objects (domain B), and we want to go from one to the other. In classical GANs, we would use these two domains only for training the architectures, but here we can also condition the generator on one of the images in domain B and use the relationship between the domains to guide the network towards the desired result:

<a href="https://imgbb.com/"><img src="https://i.ibb.co/rF2w5Tt/Screenshot-from-2019-02-25-15-47-18.png" alt="Screenshot-from-2019-02-25-15-47-18" border="0"></a>

As in regular GANs, the discriminator learns to classify between fake (synthesised by the generator) and real {drawing, photo} tuples, and the generator tries to fool the discriminator with better and better synthesised images. In contrast to regular GANs, the drawing is fed directly to the generator, so the network has much richer information to guide it towards a good solution.

## Image-to-image Translation with cGANs


[Image-to-image Translation with Conditional Adversarial Networks](https://arxiv.org/abs/1611.07004) is a paper presented at the Computer Vision and Pattern Recognition (CVPR) conference in 2017. In it, the authors introduced Pix2pix, a conditional GAN that transforms images from one domain to another, and investigated conditional adversarial networks as a general-purpose solution to image-to-image translation problems. Here are some examples of the different image-to-image tasks shown in the paper:

![Examples of image-to-image translation tasks](https://i.ibb.co/Mf08787/examples.jpg)

In this tutorial, we will explain and implement the Pix2pix model for colouring images step by step.

## Colouring Black and White Images
The tutorial aims to build the Pix2pix code for colouring images. We will use the CIFAR10 dataset to perform the experiment; however, you could try any other dataset. For instance, check [this repo](https://github.com/kvfrans/deepcolor), which shows how to colour manga-style images.

The idea is quite intuitive: first, we transform the images into black and white (B&W), and then we train a generator that transforms them back into RGB images. As in classical GANs, we use the discriminator to differentiate between *Fake* and *Real* images.

Thus, we have domain A (colour) and domain B (B&W).

## Dataset Generator

When training a cGAN, or any regular GAN architecture, we need to train the generator and discriminator networks separately. As seen, the training is a min-max game between the two nets.

On the one hand, we train the generator to fool the discriminator; on the other, we train the discriminator to distinguish between the generated (fake) images and the real ones.

We are going to switch the training between the networks at every step. Instead of training each network for a whole epoch, as in the previous tutorials, we will generate batches and, for each batch, first train the discriminator and then switch to training the generator.
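
The alternating schedule can be sketched on a toy problem before looking at the real networks. Everything below (the two linear layers, the 4-dimensional "images", the learning rates) is a hypothetical stand-in chosen only to keep the example short; the actual cGAN training loop appears later in the tutorial.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy stand-ins (assumptions for illustration): 4-dimensional vectors instead of images
toy_generator = nn.Linear(4, 4)
toy_discriminator = nn.Linear(4, 1)
opt_g = torch.optim.Adam(toy_generator.parameters(), lr=1e-3)
opt_d = torch.optim.Adam(toy_discriminator.parameters(), lr=1e-3)
mse = nn.MSELoss()

history = []
for step in range(5):
    real = torch.randn(8, 4) + 2.0   # a batch of "real" samples
    condition = torch.randn(8, 4)    # generator input (stands in for the B&W image)
    valid = torch.ones(8, 1)
    fake = torch.zeros(8, 1)

    # 1) Discriminator step: the generator output is detached,
    #    so no gradient flows back into the generator
    opt_d.zero_grad()
    d_loss = 0.5 * (mse(toy_discriminator(real), valid)
                    + mse(toy_discriminator(toy_generator(condition).detach()), fake))
    d_loss.backward()
    opt_d.step()

    # 2) Generator step: only opt_g is stepped, and the generator is
    #    rewarded when the discriminator labels its samples as real
    opt_g.zero_grad()
    g_loss = mse(toy_discriminator(toy_generator(condition)), valid)
    g_loss.backward()
    opt_g.step()

    history.append((d_loss.item(), g_loss.item()))
    print(f"step {step}: d_loss={d_loss.item():.4f} g_loss={g_loss.item():.4f}")
```

Both networks are updated once per batch, each with its own optimiser, which is the pattern the full training loop follows.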

Let's define our class *DataLoader()*:

In [None]:
class DataLoader():
    def __init__(self, dataset_name, img_res=(32, 32)):
        self.dataset_name = dataset_name
        self.img_res = img_res
        self.load_dataset()
        self.convert_to_bw = None  # Will be set later
        self._start_epoch()

    def _start_epoch(self):
        # Shuffle training data at start of each epoch
        self.indices = np.random.permutation(len(self.im_A_train))

    def load_dataset(self):
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # [-1, 1] range
        ])

        if self.dataset_name == 'CIFAR10':
            train_set = datasets.CIFAR10('./data', train=True, download=True, transform=transform)
            test_set = datasets.CIFAR10('./data', train=False, download=True, transform=transform)
        elif self.dataset_name == 'CIFAR100':
            train_set = datasets.CIFAR100('./data', train=True, download=True, transform=transform)
            test_set = datasets.CIFAR100('./data', train=False, download=True, transform=transform)
        else:
            raise Exception('Please select a valid dataset')

        # Store as numpy arrays for compatibility with original interface
        self.im_A_train = train_set.data
        self.im_A_test = test_set.data
        self.y_train = np.array(train_set.targets)
        self.y_test = np.array(test_set.targets)

        # Scale to the [-1, 1] range, which matches the tanh output of the generator
        self.im_A_train = self.im_A_train.astype('float32') / 127.5 - 1.0
        self.im_A_test = self.im_A_test.astype('float32') / 127.5 - 1.0

    def get_dataset_shape(self, is_training=True):
        if is_training:
            return self.im_A_train.shape
        else:
            return self.im_A_test.shape

    def get_num_batches(self, batch_size):
        return int(np.ceil(len(self.im_A_train) / batch_size))

    def set_image_transformations(self, convert_to_bw):
        self.convert_to_bw = convert_to_bw

    def load_batch(self, batch_size=1, is_training=True):
        if is_training:
            self._start_epoch()  # reshuffle the training set on every pass
            data = self.im_A_train
            indices = self.indices
        else:
            data = self.im_A_test
            indices = np.arange(len(data))

        num_batches = self.get_num_batches(batch_size) if is_training else int(np.ceil(len(data) / batch_size))

        for i in range(num_batches):
            start_idx = i * batch_size
            end_idx = (i + 1) * batch_size
            batch_indices = indices[start_idx:end_idx]

            batch = data[batch_indices]
            batch = torch.from_numpy(batch).permute(0, 3, 1, 2)  # NHWC to NCHW

            if self.convert_to_bw:
                bw_batch = self.convert_to_bw(batch)
                yield batch, bw_batch
            else:
                yield batch, None

    def get_random_batch(self, batch_size=1, is_training=True):
        if is_training:
            data = self.im_A_train
            indices = np.random.choice(len(data), batch_size, replace=False)
        else:
            data = self.im_A_test
            indices = np.random.choice(len(data), batch_size, replace=False)

        batch = data[indices]
        batch = torch.from_numpy(batch).permute(0, 3, 1, 2)  # NHWC to NCHW

        if self.convert_to_bw:
            bw_batch = self.convert_to_bw(batch)
            return batch, bw_batch
        else:
            return batch, None

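class BWConverter:
    def __call__(self, batch):
        # DataLoader calls the transformation directly, so make the object callable
        return self.predict(batch)

    def predict(self, batch):
        if isinstance(batch, torch.Tensor):
            # NCHW tensor: convert RGB to grayscale using standard weights
            gray = 0.2989 * batch[:,0] + 0.5870 * batch[:,1] + 0.1140 * batch[:,2]
            return gray.unsqueeze(1)  # keep a singleton channel dimension
        else:
            # NHWC array: weighted sum over the last (channel) axis
            return np.dot(batch[...,:3], [0.2989, 0.5870, 0.1140])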

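The *DataLoader* class needs only the name of the dataset to load. Currently, it can load CIFAR10 and CIFAR100. Note that it shares its name with `torch.utils.data.DataLoader` and shadows that import once this cell has run. Let's see what some of *DataLoader*'s methods do: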


*   *_start_epoch()* is a private method. It shuffles the order in which *load_batch()* visits the training images.
*   *load_dataset()* loads the dataset into the *DataLoader* object.
*   *load_batch()* creates a Python generator. In each iteration, it returns a batch from the shuffled dataset.
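
The generator pattern used by *load_batch()* can be sketched in a few lines. The function name `iterate_minibatches` and the toy array are assumptions made for illustration; they are not part of the class above.

```python
import numpy as np

def iterate_minibatches(data, batch_size, rng):
    """Yield shuffled mini-batches; the last one may be smaller."""
    indices = rng.permutation(len(data))
    for start in range(0, len(data), batch_size):
        # `yield` hands back one batch and pauses until the next one is requested
        yield data[indices[start:start + batch_size]]

rng = np.random.default_rng(0)
toy_data = np.arange(10)
batches = list(iterate_minibatches(toy_data, batch_size=4, rng=rng))
print([len(b) for b in batches])  # [4, 4, 2]
```

Because batches are produced lazily, only one batch needs to be prepared at a time, and every sample is visited exactly once per pass.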

Let's now instantiate the data loader and check the shape of the dataset.

In [None]:
# Load the data, shuffled and split between train and test sets
dataset_loader = DataLoader(dataset_name='CIFAR10')
dataset_loader.set_image_transformations(BWConverter())

training_shape = dataset_loader.get_dataset_shape()
test_shape = dataset_loader.get_dataset_shape(is_training=False)

print('Shape of Training Images:', training_shape)
print('Shape of Test Images:', test_shape)

## Defining the Domains

As explained above, we will go from B&W to RGB colour images. To get the B&W images, we need to define *transform_bw*, a transformation that maps images from **domain A** (colour) to images in **domain B** (B&W). We can do that with a custom transform class, as seen below.

In [None]:
# Convert to grayscale color space
class BWTransform(nn.Module):
    def forward(self, x):
        # Input shape: (batch, 3, H, W) in [-1,1] range
        # Using standard RGB to grayscale conversion weights
        grayscale = 0.2989 * x[:,0] + 0.5870 * x[:,1] + 0.1140 * x[:,2]
        return grayscale.unsqueeze(1)  # Add channel dim back

transform_bw = BWTransform()

# Set predefined transformation in DataLoader Class
dataset_loader.set_image_transformations(transform_bw)
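
As a quick sanity check of the conversion, the sketch below applies the same luminance weights to a random batch and inspects the result. The random `rgb` tensor is a hypothetical stand-in for a batch of CIFAR10 images, and the broadcasting formulation is just one equivalent way of writing the weighted sum.

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-in batch: 2 RGB images of 32x32 in [-1, 1), NCHW layout
rgb = torch.rand(2, 3, 32, 32) * 2 - 1

# Standard luminance weights, broadcast over the channel dimension
weights = torch.tensor([0.2989, 0.5870, 0.1140]).view(1, 3, 1, 1)
gray = (rgb * weights).sum(dim=1, keepdim=True)

print(gray.shape)  # torch.Size([2, 1, 32, 32])
```

Since the weights sum to approximately 1, the grayscale image stays in the same value range as the colour input, and the singleton channel dimension keeps the tensor compatible with convolutional layers.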

We can visualise images from both domains before starting to code our cGAN. Rerun the following code to check different examples:

In [None]:
# Load random batch from dataset
random_batch = dataset_loader.get_random_batch(batch_size=9)
color_images, bw_images = random_batch

# Convert tensors to numpy and permute dimensions for visualization
color_images = color_images.permute(0, 2, 3, 1).cpu().numpy()  # NCHW -> NHWC
bw_images = bw_images.permute(0, 2, 3, 1).cpu().numpy()  # NCHW -> NHWC

# Denormalize from [-1,1] to [0,1] range
color_images = (color_images + 1) / 2
bw_images = (bw_images + 1) / 2

# Repeat grayscale channel for visualization
bw_images_rgb = np.repeat(bw_images, 3, axis=3)

# Plot comparison
N = 3
fig, axes = plt.subplots(N, N, figsize=(10, 10))
plt.suptitle('Domain A (Color) VS Domain B (B&W)', fontsize=18)
for row in range(N):
    for col in range(N):
        idx = row * N + col

        # Concatenate color and BW images horizontally
        combined = np.concatenate((color_images[idx], bw_images_rgb[idx]), axis=1)

        axes[row, col].imshow(np.clip(combined, 0, 1))
        axes[row, col].axis('off')

plt.tight_layout()
plt.show()

## Generator & Discriminator Models

For many image translation problems, there is a great deal of low-level information shared between the input and output. Thus, it is desirable to shuttle this information directly across the net. For example, in the case of image colourization, the input and output share the location of prominent edges. To give the generator a means to circumvent the bottleneck for information like this, we use an architecture with skip connections, following the general shape of the UNet introduced in previous tutorials.
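
A single skip connection can be sketched as follows. The feature-map sizes and layer names are hypothetical and chosen only to show the channel arithmetic; they do not correspond to specific layers of the generator defined below.

```python
import torch
import torch.nn as nn

# Hypothetical feature maps: an encoder output at 16x16 and a deeper map at 8x8
encoder_features = torch.randn(1, 64, 16, 16)
bottleneck = torch.randn(1, 128, 8, 8)

# Upsample the deep features back to 16x16 and halve the number of channels
upsample = nn.Sequential(nn.Upsample(scale_factor=2),
                         nn.Conv2d(128, 64, kernel_size=3, padding=1))
decoder_features = upsample(bottleneck)

# Skip connection: concatenate along the channel dimension (64 + 64 = 128)
merged = torch.cat([encoder_features, decoder_features], dim=1)

# The next convolution must therefore accept 128 input channels
fuse = nn.Conv2d(128, 64, kernel_size=3, padding=1)
out = fuse(merged)
print(merged.shape, out.shape)
```

The concatenation gives the decoder direct access to high-resolution encoder features, such as edge locations, without forcing them through the bottleneck.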

The generator introduced in this tutorial is a simpler version of the one in the paper since we are using low-resolution images (32x32).

In Pix2pix, the authors found that mixing the GAN objective with a more traditional loss, such as L1 loss, improved the final performance. The discriminator’s job remains unchanged, but the generator is tasked to not only fool the discriminator but also to be near the ground-truth output in an L1 sense.
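
The combined generator objective can be sketched with dummy tensors. The random tensors below are hypothetical stand-ins for real network outputs, the adversarial term is written as a least-squares loss to match the `criterion_mse` used in this tutorial's training loop, and the weight of 100 on the L1 term is the value used in that loop.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for generator outputs, ground-truth colour images,
# and the discriminator's scores on the generated images
fake_images = torch.rand(4, 3, 32, 32) * 2 - 1
real_images = torch.rand(4, 3, 32, 32) * 2 - 1
disc_scores_on_fake = torch.rand(4, 1)

# Adversarial term: the generator wants its images to be scored as real (1)
adversarial = nn.MSELoss()(disc_scores_on_fake, torch.ones(4, 1))
# Pixel-wise term: the generated image should be close to the ground truth
pixelwise = nn.L1Loss()(fake_images, real_images)

lambda_l1 = 100.0  # weight of the L1 term
generator_loss = adversarial + lambda_l1 * pixelwise
print(f"adv={adversarial.item():.4f} l1={pixelwise.item():.4f} total={generator_loss.item():.4f}")
```

With such a large weight, the L1 term dominates the total, while the adversarial term contributes a signal about realism that the pixel-wise loss alone does not provide.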

We can define our cGAN generator model.

In [None]:
class Generator(nn.Module):
    def __init__(self, im_shape=(32, 32)):
        super(Generator, self).__init__()

        # Encoder
        self.conv1 = nn.Conv2d(1, 64, 3, padding=1)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = nn.Conv2d(128, 256, 3, padding=1)
        self.pool3 = nn.MaxPool2d(2)
        self.conv4 = nn.Conv2d(256, 512, 3, padding=1)
        self.drop4 = nn.Dropout2d(0.5)
        self.pool4 = nn.MaxPool2d(2)
        self.conv5 = nn.Conv2d(512, 1024, 3, padding=1)
        self.drop5 = nn.Dropout2d(0.5)

        # Decoder with skip connections
        self.up6 = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(1024, 512, 3, padding=1)
        )
        self.conv6 = nn.Conv2d(1024, 512, 3, padding=1)  # Accounts for concatenation

        self.up7 = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(512, 256, 3, padding=1)
        )
        self.conv7 = nn.Conv2d(512, 256, 3, padding=1)

        self.up8 = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(256, 128, 3, padding=1)
        )
        self.conv8 = nn.Conv2d(256, 128, 3, padding=1)

        self.up9 = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, padding=1)
        )
        self.conv9 = nn.Conv2d(128, 64, 3, padding=1)

        self.conv10 = nn.Conv2d(64, 3, 3, padding=1)

    def forward(self, x):
        # Encoder
        conv1 = F.relu(self.conv1(x))
        pool1 = self.pool1(conv1)
        conv2 = F.relu(self.conv2(pool1))
        pool2 = self.pool2(conv2)
        conv3 = F.relu(self.conv3(pool2))
        pool3 = self.pool3(conv3)
        conv4 = F.relu(self.conv4(pool3))
        drop4 = self.drop4(conv4)
        pool4 = self.pool4(drop4)
        conv5 = F.relu(self.conv5(pool4))
        drop5 = self.drop5(conv5)

        # Decoder with skip connections
        up6 = self.up6(drop5)
        merge6 = torch.cat([drop4, up6], dim=1)
        conv6 = F.relu(self.conv6(merge6))

        up7 = self.up7(conv6)
        merge7 = torch.cat([conv3, up7], dim=1)
        conv7 = F.relu(self.conv7(merge7))

        up8 = self.up8(conv7)
        merge8 = torch.cat([conv2, up8], dim=1)
        conv8 = F.relu(self.conv8(merge8))

        up9 = self.up9(conv8)
        merge9 = torch.cat([conv1, up9], dim=1)
        conv9 = F.relu(self.conv9(merge9))

        conv10 = self.conv10(conv9)

        return torch.tanh(conv10)  # Output in [-1, 1] range

generator = Generator()

Next, we follow the discriminator model definition from the paper:

In [None]:
class Discriminator(nn.Module):
    def __init__(self, im_shape=(32, 32)):
        super(Discriminator, self).__init__()
        self.im_shape = im_shape

        # Define the model architecture
        self.conv1 = nn.Sequential(
            nn.Conv2d(4, 64, kernel_size=5, stride=2, padding=2),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2),
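            nn.BatchNorm2d(128, momentum=0.2),  # PyTorch weights the new batch statistics by `momentum`; 0.2 corresponds to 0.8 in Keras
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(256, momentum=0.2),  # same convention as above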
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.fc = nn.Sequential(
            nn.Linear(256 * (im_shape[0]//8) * (im_shape[1]//8), 1024),
            nn.Linear(1024, 1)
        )

    def forward(self, img_A, img_B):
        # Concatenate inputs along channel dimension
        x = torch.cat((img_A, img_B), dim=1)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

discriminator = Discriminator()
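
The input size of the first linear layer, `256 * (im_shape[0]//8) * (im_shape[1]//8)`, follows from the three stride-2 convolutions, each of which halves the spatial resolution. The sketch below checks this arithmetic; `conv_output_size` is a hypothetical helper that implements the usual output-size formula for a convolution without dilation.

```python
def conv_output_size(size, kernel_size, stride, padding):
    """Spatial size after a convolution (floor convention, no dilation)."""
    return (size + 2 * padding - kernel_size) // stride + 1

size = 32
sizes = [size]
for _ in range(3):  # three convolutions with kernel 5, stride 2, padding 2
    size = conv_output_size(size, kernel_size=5, stride=2, padding=2)
    sizes.append(size)

print(sizes)              # [32, 16, 8, 4]
print(256 * size * size)  # 4096 input features for the first linear layer
```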

## Training Loop

We arrive at the tricky part of the cGAN tutorial, where we need to define the training loop of our model. Do not focus too much on this part of the code; it is only provided in case you want to train the architectures further.

First, we need to define the optimisers, inputs, outputs, and losses for each one of the networks. Let's go line by line:

In [None]:
# Define optimizers
optimizer_g = optim.Adam(generator.parameters(), lr=2e-4, betas=(0.5, 0.999))
optimizer_d = optim.Adam(discriminator.parameters(), lr=2e-5, betas=(0.5, 0.999))

# Loss functions
criterion_mse = nn.MSELoss()  # For adversarial loss
criterion_l1 = nn.L1Loss()    # For pixel-wise loss

# Input size
im_shape = (32, 32)

# Move models to device (GPU if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator = generator.to(device)
discriminator = discriminator.to(device)

We have created the models and defined their optimisers and losses. Before starting the training loop, we are going to define two auxiliary functions that will allow us to display images during training, to get an idea of how well the generator is colouring the images:

In [None]:
def showColoredIms(imB, fake_imA, real_imA):
    # Convert tensors to numpy and denormalize
    imB = imB[0].cpu().permute(1, 2, 0).numpy() * 0.5 + 0.5
    # Detach fake_imA before converting to numpy
    fake_imA = fake_imA[0].detach().cpu().permute(1, 2, 0).numpy() * 0.5 + 0.5
    real_imA = real_imA[0].cpu().permute(1, 2, 0).numpy() * 0.5 + 0.5

    plt.figure(figsize=(15,5))
    plt.subplot(131)
    plt.imshow(imB[:,:,0], cmap='gray')
    plt.title('Domain B', fontsize=20)
    plt.axis('off')

    plt.subplot(132)
    plt.imshow(fake_imA)
    plt.title('Fake A', fontsize=20)
    plt.axis('off')

    plt.subplot(133)
    plt.imshow(real_imA)
    plt.title('Real A', fontsize=20)
    plt.axis('off')
    plt.show()

def showColored_two_models_Ims(imB, fake_imA_MAE, fake_imA_cGAN, real_imA):
    # Convert tensors to numpy and denormalize
    imB = imB[0].cpu().permute(1, 2, 0).numpy() * 0.5 + 0.5
    # Detach fake_imA_MAE and fake_imA_cGAN before converting to numpy
    fake_imA_MAE = fake_imA_MAE[0].detach().cpu().permute(1, 2, 0).numpy() * 0.5 + 0.5
    fake_imA_cGAN = fake_imA_cGAN[0].detach().cpu().permute(1, 2, 0).numpy() * 0.5 + 0.5
    real_imA = real_imA[0].cpu().permute(1, 2, 0).numpy() * 0.5 + 0.5

    plt.figure(figsize=(20,5))
    plt.subplot(141)
    plt.imshow(imB[:,:,0], cmap='gray')
    plt.title('BW', fontsize=20)
    plt.axis('off')

    plt.subplot(142)
    plt.imshow(fake_imA_MAE)
    plt.title('MAE', fontsize=20)
    plt.axis('off')

    plt.subplot(143)
    plt.imshow(fake_imA_cGAN)
    plt.title('cGAN', fontsize=20)
    plt.axis('off')

    plt.subplot(144)
    plt.imshow(real_imA)
    plt.title('Real', fontsize=20)
    plt.axis('off')
    plt.show()

In the next section, we provide trained weights so that you can resume or skip the training. These weights were obtained by training the net for 20 epochs. Note that, if you want to train the networks further, GAN training is known to take a long time.

Let's see the training loop.

In [None]:
num_epochs = 1
batch_size = 128
n_batches = dataset_loader.get_num_batches(batch_size)

# Adversarial loss ground truths
valid = torch.ones(batch_size, 1, device=device)
fake = torch.zeros(batch_size, 1, device=device)

for epoch in range(num_epochs):
    start_time = datetime.datetime.now()

    g_avg_loss = []
    d_avg_loss = []
    d_avg_acc = []

    for batch_i, (imgs_A, imgs_B) in enumerate(dataset_loader.load_batch(batch_size)):
        # Get current batch size (last batch might be smaller)
        current_batch_size = imgs_A.size(0)

        # Move data to device
        imgs_A = imgs_A.to(device)
        imgs_B = imgs_B.to(device)

        # Create properly sized target tensors
        valid = torch.ones(current_batch_size, 1, device=device)
        fake = torch.zeros(current_batch_size, 1, device=device)

        # ---------------------
        #  Train Discriminator
        # ---------------------
        optimizer_d.zero_grad()

        # Real images - pass both images separately
        real_validity = discriminator(imgs_A, imgs_B)
        d_loss_real = criterion_mse(real_validity, valid)

        # Fake images
        fake_A = generator(imgs_B)
        fake_validity = discriminator(fake_A.detach(), imgs_B)
        d_loss_fake = criterion_mse(fake_validity, fake)

        # Total discriminator loss
        d_loss = (d_loss_real + d_loss_fake) / 2
        d_loss.backward()
        optimizer_d.step()

        # -----------------
        #  Train Generator
        # -----------------
        optimizer_g.zero_grad()

        # Generate fake images
        fake_A = generator(imgs_B)

        # Adversarial loss
        validity = discriminator(fake_A, imgs_B)
        g_loss_adv = criterion_mse(validity, valid)

        # Pixel-wise loss
        g_loss_l1 = criterion_l1(fake_A, imgs_A)

        # Total generator loss
        g_loss = g_loss_adv + 100 * g_loss_l1
        g_loss.backward()
        optimizer_g.step()

        # Calculate discriminator accuracy
        real_acc = torch.mean((real_validity > 0.5).float())
        fake_acc = torch.mean((fake_validity < 0.5).float())
        d_acc = (real_acc + fake_acc) / 2

        # Record metrics
        d_avg_loss.append(d_loss.item())
        d_avg_acc.append(d_acc.item())
        g_avg_loss.append(g_loss.item())

        elapsed_time = datetime.datetime.now() - start_time
        remaining_time = (elapsed_time/(batch_i+1)) * (n_batches-batch_i-1)

        if batch_i % 50 == 0:
            showColoredIms(imgs_B, fake_A, imgs_A)

        if batch_i % 10 == 0:
            print(f"[Epoch {epoch}/{num_epochs}] [Batch {batch_i}/{n_batches}] "
                  f"[D loss: {np.mean(d_avg_loss):.4f}, acc: {100*np.mean(d_avg_acc):.0f}%] "
                  f"[G loss: {np.mean(g_avg_loss):.4f}] elapsed: {elapsed_time} remaining: {remaining_time}")


# Save the models after training
torch.save(generator.state_dict(), 'generator.pth')
torch.save(discriminator.state_dict(), 'discriminator.pth')

If desired, you can download the model you have just trained. Those weights can be used to resume training in the future.

In [None]:
from google.colab import files

# Download the generator weights
files.download('generator.pth')

# Download the discriminator weights
files.download('discriminator.pth')

# Optionally: Save a complete checkpoint including optimizer states
checkpoint = {
    'generator_state_dict': generator.state_dict(),
    'discriminator_state_dict': discriminator.state_dict(),
    'optimizer_g_state_dict': optimizer_g.state_dict(),
    'optimizer_d_state_dict': optimizer_d.state_dict(),
    'epoch': num_epochs,
    'losses': {
        'g_avg_loss': g_avg_loss,
        'd_avg_loss': d_avg_loss,
        'd_avg_acc': d_avg_acc
    }
}

torch.save(checkpoint, 'cgan_checkpoint.pth')
files.download('cgan_checkpoint.pth')
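
A checkpoint of this kind can later be used to resume training. The sketch below shows the round trip on a hypothetical tiny model; the dictionary keys, the temporary file path, and the `nn.Linear` stand-in are assumptions made for illustration, so adapt them to the keys stored in your own checkpoint.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Hypothetical tiny model standing in for the generator
toy_model = nn.Linear(3, 2)
toy_optimizer = torch.optim.Adam(toy_model.parameters(), lr=2e-4)

checkpoint_path = os.path.join(tempfile.mkdtemp(), 'toy_checkpoint.pth')
torch.save({
    'model_state_dict': toy_model.state_dict(),
    'optimizer_state_dict': toy_optimizer.state_dict(),
    'epoch': 1,
}, checkpoint_path)

# Later session: rebuild the objects first, then restore their state
restored_model = nn.Linear(3, 2)
restored_optimizer = torch.optim.Adam(restored_model.parameters(), lr=2e-4)

checkpoint = torch.load(checkpoint_path, map_location='cpu')
restored_model.load_state_dict(checkpoint['model_state_dict'])
restored_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch']  # continue counting epochs from here
print(start_epoch)
```

Restoring the optimiser state as well as the weights keeps Adam's running statistics, so training continues smoothly instead of restarting its moment estimates from zero.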

## Colouring Test Images

We are ready to visualise how the network colours the test images. If you have not trained your model, you can optionally load the provided trained model.

In [None]:
# Load model from saved file
generator.load_state_dict(torch.load('generator.pth', map_location=device))  # map_location also allows loading GPU-trained weights on CPU

Let's see the results. You can rerun the following code to visualise more examples:

In [None]:
# Put generator in evaluation mode
generator.eval()

with torch.no_grad():  # Disable gradient calculation for inference
    for i in range(3):
        # Get random batch from test set
        im_A_real, im_B_test = dataset_loader.get_random_batch(batch_size=1, is_training=False)

        # Move to device and generate fake image
        im_B_test = im_B_test.to(device)
        im_A_fake = generator(im_B_test)  # Forward pass instead of predict()

        # Move back to CPU for visualization
        im_A_real = im_A_real.cpu()
        im_B_test = im_B_test.cpu()
        im_A_fake = im_A_fake.cpu()

        # Show results
        showColoredIms(im_B_test, im_A_fake, im_A_real)

# Put generator back in training mode if continuing training
generator.train()

The colouring may not be perfect, but the network does a good job at guessing plausible colours. We have shown a simplified version of the Pix2pix method. Training GANs is a hard task, and there are many elements that could be added to improve results. For instance, [this repo](https://github.com/soumith/ganhacks), like many others, shows common tricks to improve GAN performance.

# Coursework



## Task 1: MNIST generation using VAE and GAN

**Report**
* Train the VAE model given in the tutorial using the same hyperparameters (batch size, optimizer, number of epochs, etc.) but increasing the latent dimensionality to 10. Compute the MSE on the reconstructed `x_test` images and the Inception Score (IS). Now train the same model without the KL divergence loss, and compute the MSE and IS again. Report the results in a table and discuss them. For the IS, you can use in this case `image_inception_score(decoder, dim_random=10, denorm_im=False)`.

* Train the GAN given in the tutorial for 10 epochs, with the dimensionality of the initial random sample increased to 10, and report its IS (use the same table as in the VAE case). Discuss the difference in the obtained IS for the VAE and the GAN and link it to the qualitative results. For the IS, use in this case `image_inception_score(generator, dim_random=10, denorm_im=True)`.



## Task 2: Quantitative VS Qualitative Results

In this task, we will observe the difference between two trained models for colouring images. One is the model trained during the tutorial, which uses a cGAN approach to predict the pixel-wise RGB values of a B&W image. The other one is a simple UNet autoencoder trained with a Mean Absolute Error (MAE) loss, which is trained to predict the RGB image directly without any GAN-based learning strategy. We refer to the first and second models as the cGAN and MAE models, respectively. For this task, weights trained for 20 epochs are provided for both the cGAN and MAE models. If desired, the code to train the MAE model can be found below:

In [None]:
# Build the generator
generator_mae = Generator(im_shape).to(device)

# Define optimizer and loss
optimizer_mae = torch.optim.Adam(generator_mae.parameters(), lr=2e-4, betas=(0.5, 0.999))
criterion_mae = nn.L1Loss()  # Mean Absolute Error

num_epochs = 1
batch_size = 128
n_batches = dataset_loader.get_num_batches(batch_size)

for epoch in range(num_epochs):
    start_time = datetime.datetime.now()
    g_avg_loss = []

    for batch_i, (imgs_A, imgs_B) in enumerate(dataset_loader.load_batch(batch_size)):
        # Move data to device
        imgs_A = imgs_A.to(device)
        imgs_B = imgs_B.to(device)

        # -----------------
        #  Train Generator (MAE)
        # -----------------
        optimizer_mae.zero_grad()

        # Generate fake images
        fake_A = generator_mae(imgs_B)

        # Calculate MAE loss
        g_loss = criterion_mae(fake_A, imgs_A)

        # Backward pass and optimize
        g_loss.backward()
        optimizer_mae.step()

        g_avg_loss.append(g_loss.item())

        elapsed_time = datetime.datetime.now() - start_time
        remaining_time = (elapsed_time/(batch_i+1)) * (n_batches-batch_i-1)

        # Plot examples
        if batch_i % 50 == 0:
            with torch.no_grad():
                fake_A = generator_mae(imgs_B)
                showColoredIms(imgs_B, fake_A, imgs_A)

        # Print progress
        if batch_i % 10 == 0:
            print(f"[Epoch {epoch}/{num_epochs}] [Batch {batch_i}/{n_batches}] "
                  f"[G loss: {np.mean(g_avg_loss):.4f}] "
                  f"elapsed: {elapsed_time} remaining: {remaining_time}")

    # Save model
    torch.save({
        'generator_state_dict': generator_mae.state_dict(),
        'optimizer_state_dict': optimizer_mae.state_dict(),
        'epoch': epoch,
        'loss': g_avg_loss
    }, 'generator_mae.pth')

Instead of training the models, we can directly load their pre-trained weights by running:

In [None]:
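# Pre-load previously trained generators (MAE file assumed to be the checkpoint dict saved by the training cell above)
generator_cGAN = Generator(im_shape).to(device)
generator_cGAN.load_state_dict(torch.load('generator.pth', map_location=device))
generator_mae = Generator(im_shape).to(device)
checkpoint_mae = torch.load('generator_mae.pth', map_location=device)
generator_mae.load_state_dict(checkpoint_mae['generator_state_dict'])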

We have loaded both models, and we are ready to compare them. In this task, you are asked to analyse the difference between the quantitative and the qualitative results. To do so, we provide two pieces of code. The first one computes the MAE metric for both models on the test dataset. This metric is widely used in image generation tasks, such as image upsampling, image reconstruction, and image translation.

In [None]:
# Evaluation for MAE-trained generator
g_mae_avg_mae = []
generator_mae.eval()  # Set to evaluation mode

with torch.no_grad():  # Disable gradient calculation
    for batch_i, (imgs_A, imgs_B) in enumerate(dataset_loader.load_batch(128, is_training=False)):
        imgs_A = imgs_A.to(device)
        imgs_B = imgs_B.to(device)

        fake_A = generator_mae(imgs_B)
        mae = F.l1_loss(fake_A, imgs_A).item()  # Calculate MAE
        g_mae_avg_mae.append(mae)

print("MAE (Trained MAE): {:.4f}".format(np.mean(g_mae_avg_mae)))

# Evaluation for cGAN-trained generator
g_cgan_avg_mae = []
generator_cGAN.eval()  # Set to evaluation mode

with torch.no_grad():
    for batch_i, (imgs_A, imgs_B) in enumerate(dataset_loader.load_batch(128, is_training=False)):
        imgs_A = imgs_A.to(device)
        imgs_B = imgs_B.to(device)

        fake_A = generator_cGAN(imgs_B)
        mae = F.l1_loss(fake_A, imgs_A).item()  # Calculate MAE
        g_cgan_avg_mae.append(mae)

print("MAE (Trained cGAN): {:.4f}".format(np.mean(g_cgan_avg_mae)))
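
Note that the two loops above average the per-batch MAE values, so every batch counts equally. If the last test batch happens to be smaller than 128 images, its images weigh slightly more than the others. A minimal sketch of a size-weighted average is given below; the helper name `weighted_mae` and the numbers are hypothetical and only illustrate the arithmetic.

```python
def weighted_mae(batch_maes, batch_sizes):
    """Average per-batch MAE values, weighting each batch by its number of images."""
    if len(batch_maes) != len(batch_sizes):
        raise ValueError("batch_maes and batch_sizes must have the same length")
    total_images = sum(batch_sizes)
    if total_images == 0:
        raise ValueError("no images were evaluated")
    return sum(m * n for m, n in zip(batch_maes, batch_sizes)) / total_images


# Hypothetical values: two full batches of 128 images and a final batch of 44
batch_maes = [0.10, 0.12, 0.20]
batch_sizes = [128, 128, 44]

unweighted = sum(batch_maes) / len(batch_maes)
print(f"Unweighted mean of batch MAEs: {unweighted:.4f}")
print(f"Size-weighted MAE:             {weighted_mae(batch_maes, batch_sizes):.4f}")
```

To use it with the evaluation loops, one option is to store `imgs_A.size(0)` next to each `mae` value and pass both lists to the helper. When all batches have the same size, the two averages coincide.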

The next piece of code will show coloured examples for both networks, so you can check them visually and discuss which model is better. First, we need to create an iterator object to go through the test dataset:

In [None]:
iterator = iter(dataset_loader.load_batch(1, is_training=False))

Run the following cell several times to see multiple examples, so that you get a clear idea of how the two methods differ.

In [None]:
# Load test example
[imgs_A, imgs_B] = next(iterator)

if not isinstance(imgs_A, torch.Tensor):
    imgs_A = torch.from_numpy(imgs_A)
    imgs_B = torch.from_numpy(imgs_B)

imgs_A = imgs_A.to(device)
imgs_B = imgs_B.to(device)

# Generate predictions for both models
with torch.no_grad():
    fake_A_cGAN = generator_cGAN(imgs_B).cpu()
    fake_A_MAE = generator_mae(imgs_B).cpu()

# Plot all images
showColored_two_models_Ims(imgs_B.cpu(), fake_A_MAE, fake_A_cGAN, imgs_A.cpu())

We showed that both models obtain a similar MAE value. If we only took into account the quantitative metric, as is done in many scientific articles, we would say that the MAE model is better. However, in addition to the quantitative results, we need to visually analyse the results produced by the two networks before declaring which model is best.

**Report**


*   Run the previous code to analyse several coloured images for both models. Based on the previous results and on GAN theory, discuss from both the numerical and the visual perspective whether the two models are similar or whether one is better. You can include visual examples in the report, together with their MAE values, to support your arguments. The figure for this task can be included in the Appendix, but the discussion still needs to go in the main text.
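
To pair each visual example with its own MAE value, the error can be computed per image rather than over the whole test set. The sketch below does this with NumPy arrays; the helper name `per_image_mae` and the random stand-in images are hypothetical and only show the computation.

```python
import numpy as np


def per_image_mae(pred, target):
    """Return one MAE value per image for arrays shaped (batch, ...)."""
    pred = np.asarray(pred, dtype=np.float64)
    target = np.asarray(target, dtype=np.float64)
    if pred.shape != target.shape:
        raise ValueError("pred and target must have the same shape")
    # Average the absolute error over every axis except the batch axis
    reduce_axes = tuple(range(1, pred.ndim))
    return np.abs(pred - target).mean(axis=reduce_axes)


# Hypothetical stand-ins for one test image and the two colourisations
rng = np.random.default_rng(0)
real = rng.uniform(-1, 1, size=(1, 3, 32, 32))
fake_mae_model = real + 0.05
fake_cgan_model = real - 0.10

print("MAE model :", per_image_mae(fake_mae_model, real))
print("cGAN model:", per_image_mae(fake_cgan_model, real))
```

With the tensors from the plotting cell, a call such as `per_image_mae(fake_A_MAE.numpy(), imgs_A.cpu().numpy())` should give the value for the displayed example, assuming the tensors are on the CPU and detached from the graph as they are in that cell.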