In [None]:
!pip install wget mlflow -q

In [None]:
import os
import uuid
import random
import numpy as np
import matplotlib.pyplot as plt

import wget
import mlflow
import zipfile
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.utils as vutils
import torchvision.transforms as transforms

import tqdm
from torch import Tensor
from torch.utils.data import DataLoader, Subset
from torchvision.utils import make_grid
import torch.utils.data

: 

In [None]:
# Seed every RNG the notebook touches so runs are reproducible.
manualSeed = 999
random.seed(manualSeed)
np.random.seed(manualSeed)
torch.manual_seed(manualSeed)

# MLflow credentials must come from the environment (export MLFLOW_TRACKING_PASSWORD
# before launching Jupyter), never from the notebook itself: anything committed here
# is public. SECURITY: any token that was ever committed to this notebook must be revoked.
os.environ.setdefault('MLFLOW_TRACKING_USERNAME', 'MarceloRosendo')
mlflow.set_tracking_uri('.mlflow')

In [None]:
# --- Hardware ---------------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
number_of_gpus = 1  # only recorded in the MLflow run parameters

# --- Data -------------------------------------------------------------------
image_size = 64      # images are resized to image_size x image_size
image_channels = 3   # RGB

# --- Model ------------------------------------------------------------------
z_dimension = 100              # channels of the generator's input noise tensor
generator_features = 64        # base width of the generator's feature maps
discriminator_features = 64    # base width of the discriminator's feature maps

# --- Optimization -----------------------------------------------------------
learning_rate = 2e-4
batch_size = 128
num_epochs = 30

In [None]:
%%time
data_root = "datasets"
base_url = "https://graal.ift.ulaval.ca/public/celeba/"

file_list = [
    "img_align_celeba.zip",
    "list_attr_celeba.txt",
    "identity_CelebA.txt",
    "list_bbox_celeba.txt",
    "list_landmarks_align_celeba.txt",
    "list_eval_partition.txt",
]

dataset_folder = f"{data_root}/celeba"
os.makedirs(dataset_folder, exist_ok=True)

for file in file_list:
    url = f"{base_url}/{file}"
    if not os.path.exists(f"{dataset_folder}/{file}"):
        wget.download(url, f"{dataset_folder}/{file}")

with zipfile.ZipFile(f"{dataset_folder}/img_align_celeba.zip", "r") as ziphandler:
    ziphandler.extractall(dataset_folder)

In [None]:
# Preprocessing: square resize, convert to a float tensor in [0, 1], then map each
# channel to [-1, 1] so real images share the range of the generator's Tanh output.
# NOTE(review): CelebA aligned images are not square (presumably 178x218 — confirm), so a
# fixed (H, W) resize distorts faces; Resize(image_size) + CenterCrop(image_size) would not.
preprocessing_steps = [
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
]
transform = transforms.Compose(preprocessing_steps)

# ImageFolder assigns one label per sub-directory of data_root; the labels are
# discarded by the training loop.
dataset = datasets.ImageFolder(root=data_root, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [None]:
len(dataset)

In [None]:
dataset = Subset(dataset, np.arange(1, 20000))
dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

In [None]:
len(dataloader)

In [None]:
real_batch = next(iter(dataloader))
# The preview is only displayed, so build the grid on the CPU instead of copying
# the whole batch to the accelerator and straight back.
preview_grid = vutils.make_grid(real_batch[0][:64], padding=2, normalize=True)
plt.figure(figsize=(7, 7))
plt.axis("off")
plt.title("Training Images")
# make_grid returns (C, H, W); matplotlib expects (H, W, C).
plt.imshow(preview_grid.permute(1, 2, 0))
plt.show()

In [None]:
class Generator(nn.Module):
    """
    DCGAN-style generator built from transposed-convolution blocks.

    For noise shaped (N, z_dimension, 1, 1), as used in this notebook, the
    output is (N, image_channels, 64, 64) with values in [-1, 1] (Tanh).
    """

    def __init__(self,
                 z_dimension: int,
                 image_channels: int,
                 generator_features: int) -> None:
        """
        Build the generator network.

        Args:
            z_dimension (int): Number of channels of the input noise tensor.
            image_channels (int): Number of channels of the generated image.
            generator_features (int): Base feature-map width; the hidden blocks
                use 16x, 8x, 4x and 2x this value.
        """
        super().__init__()
        widths = [z_dimension] + [generator_features * factor for factor in (16, 8, 4, 2)]
        # (stride, padding) for each upsampling block: the first block turns a 1x1
        # input into 4x4, every later one doubles the spatial size.
        schedule = [(1, 0), (2, 1), (2, 1), (2, 1)]
        upsampling_blocks = [
            self._conv_generator_block(in_channels=c_in,
                                       out_channels=c_out,
                                       kernel_size=4,
                                       stride=step,
                                       padding=pad)
            for (c_in, c_out), (step, pad) in zip(zip(widths[:-1], widths[1:]), schedule)
        ]
        to_image = nn.ConvTranspose2d(in_channels=widths[-1],
                                      out_channels=image_channels,
                                      kernel_size=4,
                                      stride=2,
                                      padding=1)
        self.generator_sequential_model = nn.Sequential(*upsampling_blocks, to_image, nn.Tanh())

    @staticmethod
    def _conv_generator_block(in_channels: int,
                              out_channels: int,
                              kernel_size: int,
                              stride: int,
                              padding: int) -> nn.Sequential:
        """
        Create one upsampling unit: ConvTranspose2d -> BatchNorm2d -> ReLU.

        Args:
            in_channels (int): Channels entering the block.
            out_channels (int): Channels leaving the block.
            kernel_size (int): Transposed-convolution kernel size.
            stride (int): Transposed-convolution stride.
            padding (int): Transposed-convolution padding.

        Returns:
            nn.Sequential: The three-layer block.
        """
        # No bias on the convolution: the following BatchNorm has its own shift term.
        transposed_conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size,
                                             stride=stride, padding=padding, bias=False)
        return nn.Sequential(transposed_conv, nn.BatchNorm2d(out_channels), nn.ReLU())

    def forward(self, x: Tensor) -> Tensor:
        """
        Generate images from the noise tensor ``x``.

        Args:
            x (Tensor): Noise batch.

        Returns:
            Tensor: Generated images.
        """
        return self.generator_sequential_model(x)

In [None]:
class Discriminator(nn.Module):
    """
    DCGAN-style discriminator built from strided convolutions.

    For inputs shaped (N, image_channels, 64, 64) the output is (N, 1, 1, 1)
    with values in (0, 1) (Sigmoid).
    """

    def __init__(self,
                 image_channels: int,
                 discriminator_features: int) -> None:
        """
        Build the discriminator network.

        Args:
            image_channels (int): Number of channels of the input image.
            discriminator_features (int): Base feature-map width; the hidden
                blocks use 2x, 4x and 8x this value.
        """
        super().__init__()
        base = discriminator_features
        # First layer has no BatchNorm, just a LeakyReLU.
        stem = [
            nn.Conv2d(in_channels=image_channels,
                      out_channels=base,
                      kernel_size=4,
                      stride=2,
                      padding=1),
            nn.LeakyReLU(0.2),
        ]
        # Each block doubles the channel count and halves the spatial size.
        downsampling_blocks = [
            self._conv_discriminator_block(in_channels=base * factor,
                                           out_channels=base * factor * 2,
                                           kernel_size=4,
                                           stride=2,
                                           padding=1)
            for factor in (1, 2, 4)
        ]
        head = [
            nn.Conv2d(in_channels=base * 8,
                      out_channels=1,
                      kernel_size=4,
                      stride=2,
                      padding=0),
            nn.Sigmoid(),
        ]
        self.discriminator_sequential_model = nn.Sequential(*stem, *downsampling_blocks, *head)

    @staticmethod
    def _conv_discriminator_block(in_channels: int,
                                  out_channels: int,
                                  kernel_size: int,
                                  stride: int,
                                  padding: int) -> nn.Sequential:
        """
        Create one downsampling unit: Conv2d -> BatchNorm2d -> LeakyReLU(0.2).

        Args:
            in_channels (int): Channels entering the block.
            out_channels (int): Channels leaving the block.
            kernel_size (int): Convolution kernel size.
            stride (int): Convolution stride.
            padding (int): Convolution padding.

        Returns:
            nn.Sequential: The three-layer block.
        """
        # No bias on the convolution: the following BatchNorm has its own shift term.
        conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                         stride=stride, padding=padding, bias=False)
        return nn.Sequential(conv, nn.BatchNorm2d(out_channels), nn.LeakyReLU(0.2))

    def forward(self, x: Tensor) -> Tensor:
        """
        Score the image batch ``x``.

        Args:
            x (Tensor): Image batch.

        Returns:
            Tensor: Sigmoid scores.
        """
        return self.discriminator_sequential_model(x)

In [None]:
def initialize_weights(model: nn.Module) -> None:
    """
    Initializes the weights of a neural network model in place, DCGAN-style.

    ``model`` and all of its sub-modules are visited via ``model.modules()``:

    - convolution layers (class name contains "Conv", which includes
      ConvTranspose) get weights drawn from N(0, 0.02);
    - batch-norm layers (class name contains "BatchNorm") get weights drawn
      from N(1, 0.02) and biases set to 0.

    Traversing sub-modules is what makes this work on containers such as
    ``Generator`` / ``Discriminator``: their own class names match neither
    pattern, so inspecting only the top-level module would initialize nothing.
    Calling it through ``model.apply(initialize_weights)`` still yields the
    same distributions (layers are just re-drawn more than once).

    Args:
        model (nn.Module): The model (or single layer) to initialize.

    Returns:
        None: The model's parameters are modified in place.
    """
    for module in model.modules():
        class_name = module.__class__.__name__
        if 'Conv' in class_name:
            nn.init.normal_(module.weight.data,
                            mean=0.0,
                            std=0.02)
        elif 'BatchNorm' in class_name:
            # BatchNorm layers built with affine=False have no weight/bias.
            if module.weight is not None:
                nn.init.normal_(module.weight.data,
                                mean=1.0,
                                std=0.02)
            if module.bias is not None:
                nn.init.constant_(module.bias.data,
                                  val=0)

def show_tensor_images(image_tensor: Tensor,
                       num_images: int = 32) -> None:
    """
    Plot the first ``num_images`` images of a batch as a 4-column grid.

    The batch is shifted from [-1, 1] (the range of the Tanh output and of the
    normalized training data) to [0, 1] before display.

    Args:
        image_tensor (Tensor): Batch of images to display.
        num_images (int, optional): How many images from the front of the batch
            to show. Defaults to 32.

    Returns:
        None
    """
    rescaled = image_tensor.add(1).div(2)
    first_images = rescaled.detach().cpu()[:num_images]
    grid = make_grid(first_images, nrow=4)
    # make_grid returns (C, H, W); matplotlib expects channels last.
    plt.imshow(grid.permute(1, 2, 0).squeeze())
    plt.show()

def calculate_generator_loss(fake_image: Tensor) -> Tensor:
    """
    Compute the non-saturating generator loss for a batch of generated images.

    Scores ``fake_image`` with the module-level ``discriminator`` and applies
    the module-level ``loss`` criterion against an all-ones target: the
    generator is rewarded when its samples are classified as real.

    Args:
        fake_image (Tensor): Images produced by the generator (not detached,
            so gradients flow back into the generator).

    Returns:
        Tensor: The generator loss.
    """
    predictions = discriminator(fake_image).flatten()
    real_targets = torch.ones_like(predictions)
    return loss(predictions, real_targets)


def calculate_discriminator_loss(real_image: Tensor,
                                 fake_image: Tensor) -> Tensor:
    """
    Compute the discriminator loss on one real batch and one generated batch.

    Uses the module-level ``discriminator`` and ``loss``. Real images are
    scored against target 1 and generated images against target 0; the result
    is the average of the two terms. ``fake_image`` is detached, so this loss
    never sends gradients into the generator.

    Args:
        real_image (Tensor): Batch of real images.
        fake_image (Tensor): Batch of generated images.

    Returns:
        Tensor: The averaged discriminator loss.
    """
    # Real batch first, then fake: the same call order as before, so BatchNorm
    # running statistics are updated identically.
    real_scores = discriminator(real_image).flatten()
    fake_scores = discriminator(fake_image.detach()).flatten()
    real_term = loss(real_scores, torch.ones_like(real_scores))
    fake_term = loss(fake_scores, torch.zeros_like(fake_scores))
    return (real_term + fake_term) / 2


def show_intermediate_results() -> None:
    """
    Display intermediate results during training.

    Prints the current epoch/batch position with the latest discriminator and
    generator losses, then shows the generator's samples for ``fixed_noise``.

    It takes no arguments: it reads the module-level names ``epoch``,
    ``batch_idx``, ``num_epochs``, ``dataloader``, ``discriminator_loss``,
    ``generator_loss``, ``generator`` and ``fixed_noise``, so it must be called
    from the training loop after both losses have been computed.

    Returns:
        None
    """
    # Implicit string concatenation: a backslash continuation inside the literal
    # would copy the next line's indentation into the printed message.
    print(f" Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(dataloader)} "
          f"Loss D: {discriminator_loss:.4f}, loss G: {generator_loss:.4f}")
    with torch.no_grad():
        # NOTE: the generator keeps the caller's mode (train mode during training),
        # so batch-norm uses the statistics of this fixed batch.
        fixed_fake_image = generator(fixed_noise)
    # Pass the raw Tanh outputs: show_tensor_images rescales [-1, 1] -> [0, 1] and
    # builds the grid itself. Handing it an already min-max-normalized grid would
    # rescale a second time and wash the images out.
    show_tensor_images(fixed_fake_image, num_images=32)


def plot_loss_convergence(mean_discriminator_loss: list,
                          mean_generator_loss: list) -> None:
    """
    Plot the discriminator and generator loss curves on one labelled figure.

    Parameters:
    - mean_discriminator_loss (list): Discriminator loss values, one per
      training iteration (the training loop appends one per batch).
    - mean_generator_loss (list): Generator loss values, one per iteration.

    Returns:
    - None
    """
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(mean_discriminator_loss, label='Discriminator')
    ax.plot(mean_generator_loss, label='Generator')
    ax.set(title='Generator and Discriminator Loss During Training',
           xlabel='Iteration',
           ylabel='Loss')
    # Without legend() the curve labels above are never displayed.
    ax.legend()
    plt.show()

In [None]:
# Build both networks on the selected device, then apply the DCGAN weight init.
generator = Generator(z_dimension=z_dimension,
                      image_channels=image_channels,
                      generator_features=generator_features).to(device)
discriminator = Discriminator(image_channels=image_channels,
                              discriminator_features=discriminator_features).to(device)

for network in (generator, discriminator):
    initialize_weights(network)

In [None]:
# Adam for both networks, with beta1 lowered from PyTorch's default 0.9 to 0.5.
adam_betas = (0.5, 0.999)
generator_optimizer = optim.Adam(generator.parameters(), lr=learning_rate, betas=adam_betas)
discriminator_optimizer = optim.Adam(discriminator.parameters(), lr=learning_rate, betas=adam_betas)

# Binary cross-entropy on the discriminator's sigmoid scores.
loss = nn.BCELoss()

# The same latent vectors are reused for every progress preview, so snapshots
# taken at different points of training are directly comparable.
fixed_noise = torch.randn(32, z_dimension, 1, 1).to(device)

# Per-batch loss history (appended to by the training loop) and latest-loss placeholders.
mean_discriminator_loss, mean_generator_loss = [], []
discriminator_loss = generator_loss = None

In [None]:
%%time
generator.train()
discriminator.train()

with mlflow.start_run(run_name=f'DCGAN-{uuid.uuid4()}'):
    mlflow.log_params({
        "generator_optimizer": generator_optimizer.__class__.__name__,
        "discriminator_optimizer": discriminator_optimizer.__class__.__name__,
        "loss": loss.__class__.__name__,
        "batch_size": batch_size,
        "learning_rate": learning_rate,
        "z_dimension": z_dimension,
        "image_channels": image_channels,
        "discriminator_features": discriminator_features,
        "generator_features": generator_features,
        "number_of_gpus": number_of_gpus,
        "image_size": image_size,
        "num_epochs": num_epochs,
        "device": device,
        "seed":manualSeed
    })

    for epoch in tqdm.tqdm(range(num_epochs)):
        for batch_idx, (real_image, _) in enumerate(dataloader):
            real_image = real_image.to(device)
            noise = torch.randn((batch_size, z_dimension, 1, 1)).to(device)
            fake_image = generator(noise)

            discriminator_loss = calculate_discriminator_loss(real_image, fake_image)
            discriminator_optimizer.zero_grad()
            discriminator_loss.backward()
            discriminator_optimizer.step()

            generator_loss = calculate_generator_loss(fake_image)
            generator_optimizer.zero_grad()
            generator_loss.backward()
            generator_optimizer.step()

            if batch_idx % 50 == 0:
                show_intermediate_results()

            mean_discriminator_loss.append(discriminator_loss.item())
            mean_generator_loss.append(generator_loss.item())

            mlflow.log_metrics(metrics={"discriminator_loss": discriminator_loss.item()},
                               step=batch_idx)
            mlflow.log_metrics(metrics={"generator_loss": generator_loss.item()},
                               step=batch_idx)

    mlflow.pytorch.log_model(pytorch_model=generator,
                             artifact_path="generator_model")
    mlflow.pytorch.log_model(pytorch_model=discriminator,
                             artifact_path="discriminator_model")

plot_loss_convergence(mean_discriminator_loss, mean_generator_loss)