<a href="https://colab.research.google.com/github/Firojpaudel/GenAI-Chronicles/blob/main/GANs/GAN_With_Pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [12]:
import os
import shutil
import numpy as np
import torch
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

# Training hyper-parameters, set directly here rather than parsed from the command line.
n_epochs = 200  # number of epochs of training
batch_size = 64  # size of the batches
lr = 0.0009  # Adam learning rate
b1 = 0.5  # Adam beta1: decay rate of the first-moment (mean) estimate of the gradient
b2 = 0.999  # Adam beta2: decay rate of the second-moment (uncentered variance) estimate of the gradient
n_cpu = 64  # intended number of data-loading workers; NOTE(review): not passed to DataLoader(num_workers=...) — confirm
latent_dim = 100  # dimensionality of the latent space
img_size = 28  # size of each image dimension
channels = 1  # number of image channels
sample_interval = 400  # NOTE(review): not read by the training loop, which uses save_interval instead

# Image shape (channel, height, width)
img_shape = (channels, img_size, img_size)

# True when a CUDA device is available; the models are moved to the GPU based on this flag.
cuda = torch.cuda.is_available()

# Create the Generator class
class Generator(nn.Module):
    """MLP generator that maps latent noise vectors to images in [-1, 1].

    Args:
        latent_size: length of each input noise vector. Defaults to the
            module-level ``latent_dim`` (looked up when the model is built).
        image_shape: (channels, height, width) of the generated images.
            Defaults to the module-level ``img_shape``.

    ``forward(z)`` takes a ``(batch, latent_size)`` tensor and returns a
    ``(batch, *image_shape)`` tensor whose values are squashed by Tanh.
    """

    def __init__(self, latent_size=None, image_shape=None):
        super().__init__()
        latent_size = latent_dim if latent_size is None else latent_size
        image_shape = img_shape if image_shape is None else image_shape
        self.img_shape = tuple(image_shape)

        def block(in_feat, out_feat, normalize=True):
            """Return the layers Linear -> (optional) BatchNorm1d -> LeakyReLU(0.2)."""
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                # Default eps (1e-5). BatchNorm1d's second positional argument is
                # ``eps``, not momentum, so passing 0.8 there would add 0.8 to the
                # variance and distort the normalization.
                layers.append(nn.BatchNorm1d(out_feat))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(latent_size, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(self.img_shape))),
            nn.Tanh(),  # output in [-1, 1], matching the Normalize([0.5], [0.5]) real data
        )

    def forward(self, z):
        img = self.model(z)
        # Un-flatten the final Linear output into image layout.
        return img.view(img.size(0), *self.img_shape)

# Define Discriminator class
class Discriminator(nn.Module):
    """MLP discriminator that scores each image with a probability of being real.

    Args:
        image_shape: (channels, height, width) of the input images. Defaults
            to the module-level ``img_shape`` (looked up when the model is built).

    ``forward(img)`` flattens each image and returns a ``(batch, 1)`` tensor
    of Sigmoid outputs, which is the input format ``nn.BCELoss`` expects.
    """

    def __init__(self, image_shape=None):
        super().__init__()
        image_shape = img_shape if image_shape is None else image_shape
        in_features = int(np.prod(tuple(image_shape)))

        self.model = nn.Sequential(
            nn.Linear(in_features, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        # Flatten everything except the batch dimension.
        img_flat = img.view(img.size(0), -1)
        return self.model(img_flat)

# Loss function: binary cross-entropy on the discriminator's Sigmoid output
# (targets are 1 for real images and 0 for generated ones).
adversarial_loss = torch.nn.BCELoss()

# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()

if cuda:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()

# MNIST training set, resized to img_size. Normalize([0.5], [0.5]) maps pixel
# values from [0, 1] to [-1, 1] so real images share the generator's Tanh range.
dataloader = DataLoader(
    datasets.MNIST(
        "./data/mnist",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.Resize(img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        ),
    ),
    batch_size=batch_size,
    shuffle=True,
)

# Optimizers: separate Adam instances so G and D are updated independently.
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))

# Legacy tensor-type alias (GPU or CPU float tensors), kept for code that
# builds tensors via Tensor(...) or imgs.type(Tensor).
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

# Start from an empty 'images' directory so samples from earlier runs don't mix in.
if os.path.exists('images'):
    shutil.rmtree('images')  # deletes the directory and all of its contents
os.makedirs("images", exist_ok=True)

# Logging/sampling cadence, measured in batch index within an epoch.
print_interval = 100  # print losses every 100 batches
# Save a 5x5 sample grid whenever the batch index is a multiple of this.
# With batch_size=64, MNIST gives 938 batches per epoch (see the log), so a
# value of 2000 only triggers at batch 0, i.e. once per epoch.
save_interval = 2000

# Training loop.
# Each batch does one generator step (push D to label fakes as real) followed by
# one discriminator step (real -> 1, fake -> 0).
device = next(generator.parameters()).device  # follow wherever the models were placed
show_samples = False  # set True to also display the first image of each saved grid inline

for epoch in range(n_epochs):
    for i, (imgs, _) in enumerate(dataloader):
        batch = imgs.size(0)

        # Adversarial ground truths (plain tensors; no autograd wrapper needed).
        valid = torch.ones(batch, 1, device=device)
        fake = torch.zeros(batch, 1, device=device)

        # Configure input
        real_imgs = imgs.to(device=device, dtype=torch.float32)

        # ---- Train Generator ----
        optimizer_G.zero_grad()

        # Sample standard-normal noise as generator input
        z = torch.randn(batch, latent_dim, device=device)

        # Generate a batch of images
        gen_imgs = generator(z)

        # Loss measures generator's ability to fool the discriminator
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)

        g_loss.backward()
        optimizer_G.step()

        # ---- Train Discriminator ----
        # zero_grad also discards the gradients g_loss.backward() left on D's parameters.
        optimizer_D.zero_grad()

        # Measure discriminator's ability to classify real from generated samples.
        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        # detach() so the discriminator update does not backpropagate into the generator.
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

        # Print the progress every `print_interval` batches
        if i % print_interval == 0:
            print(f"[Epoch {epoch}/{n_epochs}] [Batch {i}/{len(dataloader)}] [D loss: {d_loss.item():.6f}] [G loss: {g_loss.item():.6f}]")

        # Save (and optionally show) sample images periodically
        if i % save_interval == 0:
            samples = gen_imgs.detach()[:25]
            save_image(samples, f"images/{epoch}_{i}.png", nrow=5, normalize=True)

            if show_samples:
                # Map Tanh output from [-1, 1] to [0, 1]; imshow needs (H, W) or (H, W, C>=3).
                first = ((samples[0] + 1) / 2).cpu()
                shown = first[0] if first.size(0) == 1 else first.permute(1, 2, 0)
                plt.figure(figsize=(5, 5))
                plt.imshow(shown.numpy(), cmap='gray')
                plt.axis('off')
                plt.show()


[Epoch 0/200] [Batch 0/938] [D loss: 0.681468] [G loss: 0.691147]
[Epoch 0/200] [Batch 100/938] [D loss: 0.741282] [G loss: 0.493891]
[Epoch 0/200] [Batch 200/938] [D loss: 0.542678] [G loss: 1.081067]
[Epoch 0/200] [Batch 300/938] [D loss: 0.528900] [G loss: 1.483625]
[Epoch 0/200] [Batch 400/938] [D loss: 0.652110] [G loss: 0.462227]
[Epoch 0/200] [Batch 500/938] [D loss: 0.554159] [G loss: 0.960890]
[Epoch 0/200] [Batch 600/938] [D loss: 0.513295] [G loss: 0.887534]
[Epoch 0/200] [Batch 700/938] [D loss: 0.556031] [G loss: 0.996662]
[Epoch 0/200] [Batch 800/938] [D loss: 0.577245] [G loss: 0.670647]
[Epoch 0/200] [Batch 900/938] [D loss: 0.518372] [G loss: 0.900598]
[Epoch 1/200] [Batch 0/938] [D loss: 0.462437] [G loss: 1.023750]
[Epoch 1/200] [Batch 100/938] [D loss: 0.496170] [G loss: 2.405663]
[Epoch 1/200] [Batch 200/938] [D loss: 0.409335] [G loss: 1.717518]
[Epoch 1/200] [Batch 300/938] [D loss: 0.441603] [G loss: 0.868198]
[Epoch 1/200] [Batch 400/938] [D loss: 0.727007] [G 

KeyboardInterrupt: 