In [11]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np


## The Generator Network (G) 🎨

The Generator takes a random noise vector (the latent vector $z$) as input and learns to transform it into a sample that resembles the real data (e.g., an image).

In [12]:
class Generator(nn.Module):
    """Fully connected generator that maps latent noise vectors to images.

    Architecture: Linear(latent_dim -> 128) + LeakyReLU, then two
    Linear + BatchNorm1d + LeakyReLU stages (128 -> 256 -> 512), then a final
    Linear(512 -> C*H*W) followed by Tanh, so every output pixel lies in [-1, 1].

    Args:
        latent_dim: length of each input noise vector z.
        img_shape: (channels, height, width) of the images to produce.

    Note: BatchNorm1d in training mode needs batches with more than one sample.
    """

    def __init__(self, latent_dim, img_shape):
        super().__init__()
        self.img_shape = img_shape  # (channels, height, width)
        n_pixels = int(np.prod(img_shape))  # C * H * W values per generated image

        # Input stage: no BatchNorm directly on the raw noise.
        layers = [nn.Linear(latent_dim, 128), nn.LeakyReLU(0.2, inplace=True)]
        # Hidden stages: Linear -> BatchNorm -> LeakyReLU, widening each time.
        for width_in, width_out in ((128, 256), (256, 512)):
            layers += [
                nn.Linear(width_in, width_out),
                nn.BatchNorm1d(width_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        # Output stage: one value per pixel, squashed into [-1, 1] by Tanh.
        layers += [nn.Linear(512, n_pixels), nn.Tanh()]

        self.model = nn.Sequential(*layers)

    def forward(self, z):
        """Map z of shape (batch_size, latent_dim) to images of shape (batch_size, *img_shape)."""
        flat = self.model(z)
        return flat.view(flat.size(0), *self.img_shape)

## The Discriminator Network (D) 🧐

The Discriminator takes an image (either real or generated by G) as input and outputs the probability that the image is real. It is essentially a binary classifier.

In [13]:
class Discriminator(nn.Module):
    """Fully connected real-vs-fake classifier.

    Flattens each image and passes it through
    Linear(C*H*W -> 512) -> LeakyReLU -> Linear(512 -> 256) -> LeakyReLU
    -> Linear(256 -> 1) -> Sigmoid, giving one probability per image
    (close to 1 = judged real, close to 0 = judged fake).

    Args:
        img_shape: (channels, height, width) of the input images.
    """

    def __init__(self, img_shape):
        super().__init__()
        n_inputs = int(np.prod(img_shape))  # pixels per flattened image

        hidden = [
            nn.Linear(n_inputs, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        head = [nn.Linear(256, 1), nn.Sigmoid()]  # single probability output
        self.model = nn.Sequential(*hidden, *head)

    def forward(self, img):
        """Score a batch of images of shape (batch_size, C, H, W); returns (batch_size, 1)."""
        n = img.size(0)
        img_flat = img.view(n, -1)  # (batch_size, C*H*W)
        return self.model(img_flat)

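## Loss function

Binary cross-entropy measures both how well D separates real from fake images and how well G fools D.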

In [14]:
# The Discriminator ends in a Sigmoid, so its outputs are already probabilities
# in (0, 1) and plain binary cross-entropy (averaged over the batch) applies.
# If the Sigmoid were removed so that D returned raw logits, switch to
# nn.BCEWithLogitsLoss(), which fuses the sigmoid for better numerical stability.
adversarial_loss = nn.BCELoss(reduction="mean")

## Optimizers

Both networks are trained with Adam; its hyperparameters are set below.

In [15]:
# Adam hyperparameters, shared by the Generator and the Discriminator
lr = 2e-4    # learning rate
b1 = 0.5     # beta1: decay rate of Adam's first-moment (momentum) estimate
b2 = 0.999   # beta2: decay rate of Adam's second-moment estimate



In [16]:
# Before initializing the Generator and Discriminator, define the dataset-dependent sizes:
# latent_dim, channels and img_size. The next cell sets them for MNIST
# (latent_dim=100, channels=1, img_size=28) and builds img_shape = (channels, img_size, img_size).

In [17]:
# Example configuration (MNIST)

# Seed torch's RNG (CPU and all CUDA devices) before any model is built, so that
# weight init, noise sampling and DataLoader shuffling repeat across
# Restart & Run All. Some CUDA kernels may still be nondeterministic.
SEED = 42
torch.manual_seed(SEED)

latent_dim = 100   # length of the generator's input noise vector z
channels = 1       # 1 for grayscale images like MNIST
img_size = 28      # MNIST digits are 28 x 28 pixels
img_shape = (channels, img_size, img_size)  # (C, H, W) used by both networks

In [18]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Build both networks (generator first) and move their parameters to `device`.
generator = Generator(latent_dim, img_shape)
generator = generator.to(device)
discriminator = Discriminator(img_shape)
discriminator = discriminator.to(device)

# One Adam optimizer per network, both using the shared lr and (beta1, beta2).
optimizer_G = optim.Adam(params=generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = optim.Adam(params=discriminator.parameters(), lr=lr, betas=(b1, b2))



In [19]:
# MNIST preprocessing:
#   Resize    -> images of img_size x img_size
#   ToTensor  -> PIL image / numpy.ndarray to a float tensor scaled to [0, 1]
#   Normalize -> (x - 0.5) / 0.5, mapping [0, 1] onto [-1, 1] to match the
#                generator's Tanh output range (single channel)
# For 3-channel images (e.g. CIFAR-10) use transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]).
transform = transforms.Compose(
    [
        transforms.Resize(img_size),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)

batch_size = 64  # images per mini-batch


In [None]:
# Load the MNIST training split (downloaded on first run) with `transform` applied,
# and iterate over it in shuffled mini-batches of `batch_size`.
dataset = datasets.MNIST(
    root="./data/mnist",
    train=True,
    transform=transform,
    download=True,
)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [None]:
num_epochs = 200
log_every = 200     # print losses every this many batches
sample_every = 10   # show a grid of generated images every this many epochs
n_samples = 25      # images per grid (shown as 5 x 5)

# Draw the visualization noise ONCE, before training, so that every sample grid
# shows G's output for the same latent vectors and progress is comparable
# across epochs.
fixed_noise = torch.randn(n_samples, latent_dim, device=device)

for epoch in range(num_epochs):
    for i, (real_imgs, _) in enumerate(dataloader):  # labels unused in an unconditional GAN
        # Move the batch to the configured device (CPU or GPU).
        real_imgs = real_imgs.to(device)
        cur_batch = real_imgs.size(0)  # the last batch may be smaller than batch_size

        # Adversarial ground truths: 1 = real, 0 = fake; shape (cur_batch, 1) matches D's output.
        real_labels = torch.ones(cur_batch, 1, device=device)
        fake_labels = torch.zeros(cur_batch, 1, device=device)

        # ---------------------
        #  Train Discriminator
        # ---------------------
        optimizer_D.zero_grad()  # clear gradients from the previous iteration

        # Loss on real images.
        d_loss_real = adversarial_loss(discriminator(real_imgs), real_labels)

        # Generate fake images.
        z = torch.randn(cur_batch, latent_dim, device=device)
        fake_imgs = generator(z)

        # Loss on fake images. detach() stops D's update from backpropagating into G.
        d_loss_fake = adversarial_loss(discriminator(fake_imgs.detach()), fake_labels)

        # Total discriminator loss.
        d_loss = (d_loss_real + d_loss_fake) / 2
        d_loss.backward()
        optimizer_D.step()

        # -----------------
        #  Train Generator
        # -----------------
        optimizer_G.zero_grad()

        # G wants D to classify its fakes as real, so the target is real_labels.
        # D (with its freshly updated weights) is re-evaluated on the NON-detached
        # fakes, so gradients flow back through D into G. This backward pass also
        # writes gradients into D's parameters; they are discarded by
        # optimizer_D.zero_grad() at the start of the next iteration.
        g_loss = adversarial_loss(discriminator(fake_imgs), real_labels)
        g_loss.backward()
        optimizer_G.step()  # update G's weights

        # ---------------------
        #  Log Progress
        # ---------------------
        if (i + 1) % log_every == 0:
            print(
                f"[Epoch {epoch+1}/{num_epochs}] [Batch {i+1}/{len(dataloader)}] "
                f"[D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]"
            )

    # Periodically visualize G's output on the same fixed noise.
    if (epoch + 1) % sample_every == 0:
        # eval(): BatchNorm uses its running statistics instead of this small batch's,
        # and this sampling pass does not update those running averages.
        generator.eval()
        with torch.no_grad():  # no gradients needed for visualization
            samples = generator(fixed_noise).cpu()  # move to CPU for plotting
        generator.train()  # back to training mode for the next epoch

        # Rescale from the Tanh range [-1, 1] to [0, 1] for display.
        samples = 0.5 * samples + 0.5

        grid = torchvision.utils.make_grid(samples, nrow=5, normalize=False)
        fig, ax = plt.subplots(figsize=(8, 8))
        ax.imshow(grid.permute(1, 2, 0))  # (C, H, W) -> (H, W, C) for matplotlib
        ax.set_title(f'Epoch {epoch+1}')
        ax.axis('off')
        plt.show()
        plt.close(fig)  # free the figure; many are produced over a long run
        # To save the grid instead (samples are already in [0, 1], so no normalize):
        #   os.makedirs("gan_images", exist_ok=True)
        #   torchvision.utils.save_image(samples, f"gan_images/epoch_{epoch+1}.png", nrow=5)
