In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np


 The Generator Network (G) 🎨<br>
The Generator takes a random noise vector (latent vector) as input and tries to transform it into something that resembles the real data (e.g., an image).

In [None]:
class Generator(nn.Module):
    """Fully connected generator: latent noise vectors in, images out.

    Args:
        latent_dim: Length of each input noise vector.
        img_shape: Shape of one generated image, e.g. (channels, height, width).
            Stored on the instance and used to reshape the flat network output.
    """

    def __init__(self, latent_dim, img_shape):
        super().__init__()
        self.img_shape = img_shape

        # Number of scalar values in one flattened image.
        n_pixels = int(np.prod(img_shape))

        # Input projection: no normalisation on the first layer.
        layers = [
            nn.Linear(latent_dim, 128),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        # Hidden stages, each Linear -> BatchNorm1d -> LeakyReLU.
        # Modules are created in network order so Sequential indices
        # (and therefore state_dict keys) stay 0..9.
        for n_in, n_out in ((128, 256), (256, 512)):
            layers += [
                nn.Linear(n_in, n_out),
                nn.BatchNorm1d(n_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]

        # Output head: Tanh bounds every pixel value to [-1, 1].
        layers += [nn.Linear(512, n_pixels), nn.Tanh()]

        self.model = nn.Sequential(*layers)

    def forward(self, z):
        """Generate a batch of images from a batch of noise vectors.

        Args:
            z: Noise tensor of shape (batch_size, latent_dim).

        Returns:
            Tensor of shape (batch_size, *img_shape) with values in [-1, 1].
        """
        flat = self.model(z)
        batch_size = flat.size(0)
        # Unflatten each row back into the configured image shape.
        return flat.view(batch_size, *self.img_shape)

 The Discriminator Network (D) 🧐<br>
The Discriminator takes an image (either real or generated by G) as input and outputs a probability that the image is real. It's essentially a binary classifier.

In [None]:
class Discriminator(nn.Module):
    """Fully connected binary classifier scoring images as real or fake.

    Args:
        img_shape: Shape of one input image, e.g. (channels, height, width).
            Only its element count is used, to size the first linear layer.
    """

    def __init__(self, img_shape):
        super().__init__()

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),

            # Single logit per image (was misspelled `nn.Linar`, which raised
            # AttributeError as soon as the class was instantiated).
            nn.Linear(256, 1),
            nn.Sigmoid()  # probabilistic output (0 = fake, 1 = real)
        )

    def forward(self, img):
        """Score a batch of images.

        Args:
            img: Tensor whose first dimension is the batch; the remaining
                dimensions are flattened and must total prod(img_shape).

        Returns:
            Tensor of shape (batch_size, 1) with values in [0, 1].
        """
        # reshape (not view) so non-contiguous inputs, e.g. permuted tensors,
        # are accepted; for contiguous inputs it returns the same view.
        img_flat = img.reshape(img.size(0), -1)
        validity = self.model(img_flat)
        return validity