# My first DCGAN 🌱
- DCGAN [paper](https://arxiv.org/abs/1511.06434)
<details>
<summary>
<font size="3" color="green">
<b>GAN Architecture Scheme</b>
</font>
</summary>
<div>
<img src = "layers.png" width=800>
</div>

</details>


In [1]:
import torch
from torch import nn


class Discriminator(nn.Module):
    """DCGAN discriminator: a stack of strided convolutions ending in a sigmoid.

    Based on "Unsupervised Representation Learning with Deep Convolutional
    Generative Adversarial Networks" (https://arxiv.org/abs/1511.06434).

    Every convolution uses a 4x4 kernel with stride 2, so each of the five
    layers halves the spatial size (the last one has no padding). For a
    64x64 input the output therefore has shape ``(N, 1, 1, 1)``.
    """

    def __init__(self, img_channels: int, hidden_units: int) -> None:
        """Build the discriminator network.

        Args:
            img_channels (int): Number of channels of the input image
                (for example 3 for RGB).
            hidden_units (int): Number of output channels of the first
                convolution; the following blocks use 2x, 4x and 8x this value.
        """
        super().__init__()
        self.disc = nn.Sequential(
            # First discriminator block (no batch norm)
            nn.Conv2d(img_channels, hidden_units, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            # Intermediate blocks: conv -> batch norm -> LeakyReLU,
            # doubling the channel count each time.
            self._block(in_channels=hidden_units, out_channels=hidden_units * 2),
            self._block(in_channels=hidden_units * 2, out_channels=hidden_units * 4),
            self._block(in_channels=hidden_units * 4, out_channels=hidden_units * 8),
            # Final convolution collapses the feature map to a single channel.
            nn.Conv2d(
                in_channels=hidden_units * 8,
                out_channels=1,
                kernel_size=4,
                stride=2,
                padding=0,
            ),
            nn.Sigmoid(),  # Output range [0, 1]
        )

    def _block(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 4,
        stride: int = 2,
        padding: int = 1,
    ) -> nn.Sequential:
        """Create one discriminator block: convolution, batch norm, LeakyReLU.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            kernel_size (int, optional): Size of convolution filter. Defaults to 4.
            stride (int, optional): Stride size. Defaults to 2.
            padding (int, optional): Amount of pixels added around the image.
                Defaults to 1.

        Returns:
            nn.Sequential: The three layers wrapped as a single module.
        """
        return nn.Sequential(
            # The convolution is created without a bias term; it is followed
            # directly by batch normalization.
            nn.Conv2d(
                in_channels, out_channels, kernel_size, stride, padding, bias=False
            ),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the image batch ``x`` through the network and return its output."""
        return self.disc(x)

In [2]:
class Generator(nn.Module):
    """DCGAN generator: a stack of transposed convolutions ending in a tanh.

    The first block expands a ``(N, z_dim, 1, 1)`` noise tensor to a 4x4
    feature map; each of the following four layers doubles the spatial size,
    so the output has shape ``(N, img_channels, 64, 64)``.
    """

    def __init__(self, z_dim: int, img_channels: int, hidden_units: int) -> None:
        """Build the generator network.

        Args:
            z_dim (int): Dimension of the noise vector used for image generation.
            img_channels (int): Number of channels in the generated image.
            hidden_units (int): Base channel width; the generator blocks use
                16x, 8x, 4x and 2x this value, in that order.
        """
        super().__init__()
        self.gen = nn.Sequential(
            # First generator block: stride 1 and no padding turn the 1x1
            # noise input into a 4x4 feature map.
            self._generator_block(
                in_channels=z_dim,
                out_channels=hidden_units * 16,
                kernel_size=4,
                stride=1,
                padding=0,
            ),
            # Rest of the generator blocks: each halves the channel count
            # and doubles the spatial size.
            self._generator_block(
                in_channels=hidden_units * 16, out_channels=hidden_units * 8
            ),
            self._generator_block(
                in_channels=hidden_units * 8, out_channels=hidden_units * 4
            ),
            self._generator_block(
                in_channels=hidden_units * 4, out_channels=hidden_units * 2
            ),
            # Final block: no batch norm, maps to the image channels.
            nn.ConvTranspose2d(
                in_channels=hidden_units * 2,
                out_channels=img_channels,
                kernel_size=4,
                stride=2,
                padding=1,
            ),
            nn.Tanh(),  # Output range: [-1, 1]
        )

    def _generator_block(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 4,
        stride: int = 2,
        padding: int = 1,
    ) -> nn.Sequential:
        """Create one generator block: transposed convolution, batch norm, ReLU.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            kernel_size (int, optional): Size of convolution filter. Defaults to 4.
            stride (int, optional): Stride size. Defaults to 2.
            padding (int, optional): Padding argument passed to the transposed
                convolution. Defaults to 1.

        Returns:
            nn.Sequential: The three layers wrapped as a single module.
        """
        return nn.Sequential(
            # The transposed convolution is created without a bias term; it is
            # followed directly by batch normalization.
            nn.ConvTranspose2d(
                in_channels, out_channels, kernel_size, stride, padding, bias=False
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the noise batch ``x`` through the network and return the images."""
        return self.gen(x)
            

In [3]:
def initialize_weights(model: nn.Module, mean: float = 0.0, std: float = 0.02) -> None:
    """Initialize conv and batch-norm weights in place from a normal distribution.

    Every ``nn.Conv2d``, ``nn.ConvTranspose2d`` and ``nn.BatchNorm2d`` found in
    ``model`` gets its ``weight`` drawn from ``N(mean, std**2)``; batch-norm
    biases are set to zero. Convolution biases are left untouched.

    Args:
        model (nn.Module): The generator or discriminator model.
        mean (float, optional): Mean of the normal distribution. Defaults to 0.0.
        std (float, optional): Standard deviation of the normal distribution.
            Defaults to 0.02.
    """
    target_layers = (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)
    for m in model.modules():
        if not isinstance(m, target_layers):
            continue
        # BatchNorm2d(affine=False) has weight/bias set to None; skip those
        # instead of crashing inside nn.init.
        if m.weight is not None:
            # NOTE(review): batch-norm scales are drawn from the same
            # distribution as conv weights (mean 0 by default). The PyTorch
            # DCGAN tutorial uses mean 1.0 for batch norm - confirm which is
            # intended.
            nn.init.normal_(m.weight, mean, std)
        if isinstance(m, nn.BatchNorm2d) and m.bias is not None:
            nn.init.constant_(m.bias, 0)


def test_weights_initialization() -> None:
    """Smoke-test ``initialize_weights()`` on a small discriminator and generator.

    Each network is built, passed through ``initialize_weights()``, and run
    forward once; the test asserts on the shape of the resulting output.
    """
    N, C, H, W = 8, 3, 64, 64
    latent_size = 100
    image_shape = (N, C, H, W)

    # Random stand-in for a batch of real images, shaped N x C x H x W.
    real_batch = torch.randn(image_shape)

    # --- Discriminator ---
    critic = Discriminator(img_channels=C, hidden_units=8)
    initialize_weights(model=critic)
    scores = critic(real_batch)

    # One prediction should be output per image.
    assert scores.shape == (N, 1, 1, 1), "Discriminators weights are not initialized correctly."

    # --- Generator ---
    painter = Generator(z_dim=latent_size, img_channels=C, hidden_units=8)
    initialize_weights(model=painter)

    latent = torch.randn((N, latent_size, 1, 1))
    fake_image = painter(latent)

    # The generated batch should have the same shape as the real one.
    assert (
        fake_image.shape == image_shape
    ), f"Generators weights are not initialized correctly. Instead of ({N}, {C}, {H}, {W}) they are {fake_image.shape}"

In [4]:
# Run the smoke test; a failed shape check raises AssertionError.
test_weights_initialization()