# Implementation of Deep Convolutional GAN

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/KTFish/DCGAN/blob/main/dcgan.ipynb)

An implementation of the Deep Convolutional Generative Adversarial Network (DCGAN) architecture. Using convolutions in GANs is now standard practice; according to [Generative Deep Learning. Teaching Machines to Paint, Write, Compose, and Play](https://helion.pl/ksiazki/generative-deep-learning-teaching-machines-to-paint-write-compose-and-play-david-foster,e_16sj.htm#format/e), when we say "GAN" today we usually mean a DCGAN.

### Some resources related to this topic

- DCGAN [paper](https://arxiv.org/abs/1511.06434).
- [Papers with Code](https://paperswithcode.com/method/dcgan) explanation.


In [None]:
# TODO: Make sure the Colab Button works

In [None]:
import torch
from torch import nn
import torch.optim as optim
import torchvision
from torchvision.datasets import MNIST
from torchvision.transforms import transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from typing import Tuple
import matplotlib.pyplot as plt
torch.manual_seed(42) 
# Import helper functions
# from scripts.utils import get_random_noise

In [None]:
class Generator(nn.Module):
    def __init__(self, z_dim: int, img_channels: int, hidden_units: int) -> None:
        """Generator model.

        Stacks four transposed-convolution blocks that upsample a
        ``z_dim x 1 x 1`` noise tensor to an ``img_channels x 32 x 32`` image
        (spatial size 1 -> 4 -> 8 -> 16 -> 32).

        Args:
            z_dim (int): Dimension of the noise vector used for image generation.
            img_channels (int): Number of channels in the generated image.
            hidden_units (int): Base channel width. The intermediate blocks use
                ``4 * hidden_units``, ``2 * hidden_units`` and ``hidden_units``
                feature maps respectively.
        """
        super().__init__()
        self.gen = nn.Sequential(
            # First block: stride 1 / no padding turns the 1x1 input into 4x4.
            self._generator_block(
                in_channels=z_dim,
                out_channels=hidden_units * 4,
                kernel_size=4,
                stride=1,
                padding=0,
            ),
            # Each default block (kernel 4, stride 2, padding 1) doubles the
            # spatial size while halving the channel count.
            self._generator_block(
                in_channels=hidden_units * 4, out_channels=hidden_units * 2
            ),
            self._generator_block(
                in_channels=hidden_units * 2, out_channels=hidden_units
            ),
            self._generator_final_block(hidden_units, img_channels),  # Final block
        )
        self.z_dim = z_dim

    def _generator_block(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 4,
        stride: int = 2,
        padding: int = 1,
    ) -> nn.Sequential:
        """Creates a single generator block.

        The block is ConvTranspose2d (without bias, since BatchNorm follows)
        -> BatchNorm2d -> ReLU.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            kernel_size (int, optional): Size of convolution filter. Defaults to 4.
            stride (int, optional): Stride size. Defaults to 2.
            padding (int, optional): Padding argument of the transposed convolution. Defaults to 1.

        Returns:
            nn.Sequential: The (not yet applied) block of layers.
        """
        return nn.Sequential(
            nn.ConvTranspose2d(
                in_channels, out_channels, kernel_size, stride, padding, bias=False
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def _generator_final_block(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 4,
        stride: int = 2,
        padding: int = 1,
    ) -> nn.Sequential:
        """Creates the output block of the generator.

        Unlike the hidden blocks there is no batch normalization, and the
        activation is Tanh so that pixel values fall in [-1, 1].

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output (image) channels.
            kernel_size (int, optional): Size of convolution filter. Defaults to 4.
            stride (int, optional): Stride size. Defaults to 2.
            padding (int, optional): Padding argument of the transposed convolution. Defaults to 1.

        Returns:
            nn.Sequential: The (not yet applied) block of layers.
        """
        return nn.Sequential(
            nn.ConvTranspose2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
            ),
            nn.Tanh(),
        )  # Output range: [-1, 1]

    def forward(self, noise: torch.Tensor) -> torch.Tensor:
        """Generates images from a batch of noise vectors.

        Args:
            noise (torch.Tensor): Noise with ``len(noise) * z_dim`` elements,
                e.g. of shape (n_samples, z_dim) or (n_samples, z_dim, 1, 1).

        Returns:
            torch.Tensor: Generated images with values in [-1, 1].
        """
        # The transposed convolutions expect N x C x H x W input, so each noise
        # vector is treated as a z_dim-channel 1x1 "image".
        x = noise.view(len(noise), self.z_dim, 1, 1)
        return self.gen(x)

In [None]:
class Discriminator(nn.Module):
    def __init__(self, img_channels: int, hidden_units: int) -> None:
        """Discriminator model for DCGAN architecture based
         on Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks paper.
         Paper: https://arxiv.org/abs/1511.06434

        Args:
            img_channels (int): Number of channels of the image (for example for RGB it is 3).
            hidden_units (int): Base channel width. The two hidden blocks use
                ``hidden_units`` and ``2 * hidden_units`` feature maps.
        """
        super().__init__()
        self.disc = nn.Sequential(
            self._block(img_channels, hidden_units, kernel_size=4, stride=2, padding=1),
            self._block(hidden_units, hidden_units * 2),
            # Final convolution collapses the feature maps to one channel of scores.
            nn.Conv2d(
                in_channels=hidden_units * 2,
                out_channels=1,
                kernel_size=4,
                stride=2,
                padding=0,
            ),
            # The model emits probabilities, not logits: pair it with
            # nn.BCELoss (nn.BCEWithLogitsLoss would apply a second sigmoid).
            nn.Sigmoid(),  # Output range [0, 1]
        )

    def _block(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 4,
        stride: int = 2,
        padding: int = 1,
    ) -> nn.Sequential:
        """Creates a single discriminator block consisting of a convolution layer, batch normalization and Leaky ReLU.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            kernel_size (int, optional): Size of convolution filter. Defaults to 4.
            stride (int, optional): Stride size. Defaults to 2.
            padding (int, optional): Amount of pixels added around the image. Defaults to 1.

        Returns:
            nn.Sequential: The (not yet applied) block of layers.
        """
        return nn.Sequential(
            # No conv bias: BatchNorm2d right after supplies its own shift.
            nn.Conv2d(
                in_channels, out_channels, kernel_size, stride, padding, bias=False
            ),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        )

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """Forward pass implementation for the discriminator.

        Args:
            image (torch.Tensor): Batch of images of shape N x C x H x W
                (the convolutions require a 4-D, non-flattened input).

        Returns:
            torch.Tensor: Tensor of shape (N, K) with probabilities in [0, 1],
                where K is the number of spatial positions left after the last
                convolution (1 for 16x16 inputs, 4 for 28x28 inputs).
        """
        predictions = self.disc(image)
        # Flatten everything but the batch dimension.
        return predictions.view(len(predictions), -1)

In [None]:
def get_random_noise(n_samples: int, z_dim: int, device: str = "cpu") -> torch.Tensor:
    """Sample a batch of noise vectors z for the generator.

    Every entry is drawn independently from a standard normal distribution.

    Args:
        n_samples (int): How many noise vectors to draw (typically the batch size).
        z_dim (int): Length of each noise vector.
        device (str, optional): Device the tensor is created on. Defaults to 'cpu'.

    Returns:
        torch.Tensor: Noise tensor of shape (n_samples, z_dim).
    """
    noise_shape = (n_samples, z_dim)
    return torch.randn(noise_shape, device=device)


In [None]:
# --- Hyper-parameters -------------------------------------------------------
hidden_units = 16
# The Discriminator in this file already ends with nn.Sigmoid, i.e. it outputs
# probabilities. BCELoss expects probabilities; BCEWithLogitsLoss would apply
# a second sigmoid on top and squash the predictions.
criterion = nn.BCELoss()
z_dim = 64
display_step = 100  # Number of training steps between progress reports
batch_size = 128
lr = 0.0002
betas = (0.5, 0.999)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# --- Data -------------------------------------------------------------------
# ToTensor yields pixels in [0, 1]; (x - 0.5) / 0.5 maps them to [-1, 1], the
# same range the Generator's final Tanh produces.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# NOTE(review): the Generator emits 32x32 images while MNIST samples are
# 28x28. The Discriminator accepts both sizes but returns a different number
# of predictions per image for each - consider resizing the real images.
dataloader = DataLoader(
    dataset=MNIST('dataset', download=True, transform=transform),
    shuffle=True,
    batch_size=batch_size
)

# --- Models and optimizers --------------------------------------------------
gen = Generator(z_dim, img_channels=1, hidden_units=hidden_units).to(device)
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr, betas=betas)

disc = Discriminator(img_channels=1, hidden_units=hidden_units).to(device)
disc_opt = torch.optim.Adam(disc.parameters(), lr=lr, betas=betas)


In [None]:

def initialize_weights(model: nn.Module) -> None:
    """Re-initializes layer weights in place from N(mean=0, std=0.02).

    Affects every Conv2d, ConvTranspose2d and BatchNorm2d found in the model.
    BatchNorm2d layers additionally get their bias reset to zero; all other
    parameters (including convolution biases) are left untouched.

    Args:
        model (nn.Module): The generator or discriminator model.
    """
    weighted_layers = (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)
    for layer in model.modules():
        if not isinstance(layer, weighted_layers):
            continue
        nn.init.normal_(layer.weight, 0.0, 0.02)
        if isinstance(layer, nn.BatchNorm2d):
            nn.init.constant_(layer.bias, 0)

In [None]:
def test_weights_initialization() -> None:
    """Tests the initialize_weights() function.

    Builds a Discriminator and a Generator, re-initializes their weights and
    checks that a forward pass still yields outputs of the expected shape.

    The expected shapes follow from the architectures defined in this file:
    the Discriminator reduces a 16x16 image to a single prediction
    (16 -> 8 -> 4 -> 1) and flattens it to shape (N, 1), while the Generator
    upsamples 1x1 noise to a 32x32 image (1 -> 4 -> 8 -> 16 -> 32).
    """
    N, C, H, W = 8, 3, 32, 32  # Batch size, channels and Generator output size
    disc_size = 16  # Input size for which the Discriminator emits one prediction
    z_dim = 100

    x = torch.randn(
        (N, C, disc_size, disc_size)
    )  # Simulate a random batch of images of shape N x C x H x W

    ### Test Discriminator
    disc = Discriminator(img_channels=C, hidden_units=8)
    initialize_weights(model=disc)

    # One prediction per image; Discriminator.forward flattens it to (N, 1)
    assert disc(x).shape == (N, 1), "Discriminators weights are not initialized correctly."

    ### Test Generator
    gen = Generator(z_dim=z_dim, img_channels=C, hidden_units=8)
    initialize_weights(model=gen)

    noise = torch.randn((N, z_dim, 1, 1))
    fake_image = gen(noise)

    assert fake_image.shape == (
        N,
        C,
        H,
        W,
    ), f"Generators weights are not initialized correctly. Instead of ({N}, {C}, {H}, {W}) they are {fake_image.shape}"

In [None]:

def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28)):
    '''
    Function for visualizing images: given a tensor of images, plots the first
    `num_images` of them in a uniform grid with 5 images per row.

    Args:
        image_tensor: Batch of images (N x C x H x W) with values in [-1, 1].
        num_images: Maximum number of images to show. Defaults to 25.
        size: Unused; kept so existing callers that pass it keep working.
    '''
    # make_grid is not imported at the top of the file, so import it here.
    from torchvision.utils import make_grid

    # Map values from [-1, 1] back to [0, 1] for display.
    image_tensor = (image_tensor + 1) / 2
    # Detach from the autograd graph and move to the CPU before plotting.
    image_unflat = image_tensor.detach().cpu()
    image_grid = make_grid(image_unflat[:num_images], nrow=5)
    # matplotlib expects channels last (H x W x C).
    plt.imshow(image_grid.permute(1, 2, 0).squeeze())
    plt.show()

In [None]:
# Alternating GAN training: for every batch the discriminator is updated
# first, then the generator. Running loss means are reported (and sample
# images shown) every `display_step` steps.
n_epochs = 50
cur_step = 0
mean_generator_loss = 0
mean_discriminator_loss = 0
for epoch in range(n_epochs):
    # Dataloader returns the batches (labels are not needed)
    for real, _ in dataloader:
        cur_batch_size = len(real)
        real = real.to(device)

        ## Update discriminator ##
        disc_opt.zero_grad()
        fake_noise = get_random_noise(cur_batch_size, z_dim, device=device)
        fake = gen(fake_noise)
        # detach(): the discriminator update must not back-propagate into the generator
        disc_fake_pred = disc(fake.detach())
        disc_fake_loss = criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred))
        disc_real_pred = disc(real)
        disc_real_loss = criterion(disc_real_pred, torch.ones_like(disc_real_pred))
        disc_loss = (disc_fake_loss + disc_real_loss) / 2

        # Keep track of the average discriminator loss
        mean_discriminator_loss += disc_loss.item() / display_step
        # Compute gradients. The graph is not needed again (the generator step
        # below runs its own forward pass), so it does not have to be retained.
        disc_loss.backward()
        # Apply the discriminator update
        disc_opt.step()

        ## Update generator ##
        gen_opt.zero_grad()
        fake_noise_2 = get_random_noise(cur_batch_size, z_dim, device=device)
        fake_2 = gen(fake_noise_2)
        disc_fake_pred = disc(fake_2)
        # The generator wants its fakes to be classified as real (label 1)
        gen_loss = criterion(disc_fake_pred, torch.ones_like(disc_fake_pred))
        gen_loss.backward()
        gen_opt.step()

        # Keep track of the average generator loss
        mean_generator_loss += gen_loss.item() / display_step

        # Count the step before reporting so that every reported mean covers
        # exactly `display_step` batches (including the very first one).
        cur_step += 1

        ## Visualization code ##
        if cur_step % display_step == 0:
            print(f"Epoch {epoch}, step {cur_step}: Generator loss: {mean_generator_loss}, discriminator loss: {mean_discriminator_loss}")
            show_tensor_images(fake)
            show_tensor_images(real)
            mean_generator_loss = 0
            mean_discriminator_loss = 0