<a href="https://colab.research.google.com/github/Shubham-Sahoo/GAN-Basics/blob/main/GAN_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [191]:
import torch
from torchvision import datasets, transforms
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import os

In [177]:
# Define a transformation to apply to the images (e.g., convert to tensor)

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])  # scale to [-1, 1]
])

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

print(f"Number of training samples: {len(train_dataset)}")
print(f"Number of test samples: {len(test_dataset)}")

Number of training samples: 60000
Number of test samples: 10000


In [180]:
for x in train_dataset:
    print(x[0].shape)
    break

torch.Size([1, 28, 28])


In [194]:
class Generator(nn.Module):
    def __init__(self, latent_dim, img_channels=1, feature_maps=64):
        super().__init__()

        self.latent_linear = nn.Linear(latent_dim, feature_maps * 7*7)
        self.unflatten = nn.Unflatten(1, (feature_maps, 7, 7))
        self.conv_up1 = nn.ConvTranspose2d(feature_maps, feature_maps // 2, kernel_size=4, stride=2, padding=1)
        self.conv_up2 = nn.ConvTranspose2d(feature_maps // 2, img_channels, kernel_size=4, stride=2, padding=1)

        self.relu1 = nn.LeakyReLU(0.2, inplace=True)
        self.relu2 = nn.LeakyReLU(0.2, inplace=True)
        self.tanh1 = nn.Tanh()

        self.batchnorm1 = nn.BatchNorm2d(feature_maps // 2)


    def forward(self, z):

        out = self.latent_linear(z)
        out = self.relu1(out)
        out = self.unflatten(out)

        out = self.conv_up1(out)
        out = self.relu2(out)
        out = self.batchnorm1(out)

        out = self.conv_up2(out)
        out = self.tanh1(out)

        return out




In [198]:
G = Generator(100, 1, 64)
x = torch.tensor(np.ones((1,100)), dtype=torch.float)
G(x).shape

torch.Size([1, 1, 28, 28])

In [216]:
class Discriminator(nn.Module):
    def __init__(self, img_channels=1, feature_maps=64):

        super().__init__()

        self.model = nn.Sequential(
            nn.Conv2d(img_channels, feature_maps, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(feature_maps, feature_maps*2, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Flatten(),
            nn.Linear(feature_maps*2*7*7, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.model(x)


In [218]:
D = Discriminator(1, 64)
x = torch.tensor(np.ones((3,1,28,28)), dtype=torch.float)
D(x).shape

torch.Size([3, 1])

In [190]:
def plot_generated_images(generator, noise, epoch, out_dir="gan_outputs"):
    generator.eval()
    with torch.no_grad():
        fake_images = generator(noise).cpu()

    # For 1D vectors (e.g., Gaussian): Plot as line chart
    if fake_images.dim() == 2:  # Shape: [batch_size, features]
        plt.figure(figsize=(8, 4))
        for i in range(min(8, fake_images.size(0))):
            plt.plot(fake_images[i].numpy(), label=f"Sample {i}")
        plt.title(f"Generated Samples at Epoch {epoch}")
        plt.legend()
        plt.grid(True)
        plt.tight_layout()
        plt.savefig(f"{out_dir}/epoch_{epoch}_lines.png")
        plt.close()

    # For image data (e.g., MNIST or CIFAR): Show grid
    elif fake_images.dim() == 4:  # Shape: [B, C, H, W]
        from torchvision.utils import make_grid
        grid = make_grid(fake_images, nrow=4, normalize=True, value_range=(-1, 1))
        plt.figure(figsize=(6, 6))
        plt.imshow(grid.permute(1, 2, 0))
        plt.title(f"Epoch {epoch} - Generated Images")
        plt.axis('off')
        plt.tight_layout()
        plt.savefig(f"{out_dir}/epoch_{epoch}_images.png")
        plt.close()

In [221]:
def train_gan(train_data, latent_dim: int = 100, hidden_dim: int = 128, learning_rate: float = 0.001, epochs: int = 500, batch_size: int = 128, seed: int = 42):

    torch.manual_seed(seed)

    fixed_noise = torch.randn(16, latent_dim).to('cuda')

    os.makedirs("gan_outputs", exist_ok=True)

    G = Generator(latent_dim, 1, 64).to(device='cuda')
    D = Discriminator(1, 64).to(device='cuda')


    """
    Optimizer setup
    """
    optimizer_gen = optim.Adam(G.parameters(), lr=0.001, betas=(0.9, 0.999))
    optimizer_dsc = optim.Adam(D.model.parameters(), lr=0.001, betas=(0.9, 0.999))



    loss_func = nn.BCELoss()
    gen_loss_class = nn.BCELoss()


    epoch = 0

    gen_loss_up = []
    dsc_loss_up = []

    train_data_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, drop_last=True)

    for epoch in range(epochs):

        for data in train_data_loader:
            x_data, x_label = data
            # print(x_data.shape)

            """
            Discriminator pass

            """


            img_real = x_data.to(device='cuda')
            z_fake_latent = torch.randn(batch_size, latent_dim).to(device='cuda')

            img_fake = G(z_fake_latent)

            # print(img_fake.shape)

            """
            Optimizer step discriminator
            """

            optimizer_dsc.zero_grad()

            dsc_out_real = D(img_real)
            dsc_labels_real = torch.ones(batch_size, 1).to(device='cuda')

            dsc_out_fake = D(img_fake.detach())
            dsc_labels_fake = torch.zeros(batch_size, 1).to(device='cuda')

            loss_dsc_real = loss_func(dsc_out_real, dsc_labels_real)
            loss_dsc_fake = loss_func(dsc_out_fake, dsc_labels_fake)

            loss_dsc = (loss_dsc_real + loss_dsc_fake)

            dsc_loss_up.append(loss_dsc.item())

            loss_dsc.backward()
            optimizer_dsc.step()

            """
            Optimizer step generator
            """

            optimizer_gen.zero_grad()

            gen_out = D(img_fake)

            gen_loss = gen_loss_class(gen_out, dsc_labels_real)

            gen_loss_up.append(gen_loss.item())

            gen_loss.backward()
            optimizer_gen.step()

            # print(loss_dsc_real, loss_dsc_fake, gen_loss)

        # Logging
        if (epoch + 1) % 1 == 0:
            print(f"Epoch [{epoch+1}/{epochs}] | D Loss: {loss_dsc.item():.4f} | G Loss: {gen_loss.item():.4f}")

            plot_generated_images(G, fixed_noise, epoch + 1)
            print(f"D(real): {torch.sigmoid(dsc_out_real).mean().item():.4f}, D(fake): {torch.sigmoid(dsc_out_fake).mean().item():.4f}")


    return G.forward, dsc_loss_up, gen_loss_up

In [None]:
gen_forward, dsc_loss_up,gen_loss_up  = train_gan(train_dataset, learning_rate=0.1, epochs=100, seed=42)
# z = torch.randn(50, 10)
# x_gen = gen_forward(z)
# print((round(x_gen.mean().item(), 4), round(x_gen.std().item(), 4)))

Epoch [1/100] | D Loss: 0.0110 | G Loss: 8.1977
D(real): 0.7296, D(fake): 0.5007
Epoch [2/100] | D Loss: 0.2027 | G Loss: 7.0988
D(real): 0.7103, D(fake): 0.5043
Epoch [3/100] | D Loss: 0.2426 | G Loss: 7.3263
D(real): 0.7274, D(fake): 0.5249
Epoch [4/100] | D Loss: 0.3444 | G Loss: 4.7825
D(real): 0.7126, D(fake): 0.5233
Epoch [5/100] | D Loss: 0.3917 | G Loss: 4.1064
D(real): 0.6905, D(fake): 0.5089
Epoch [6/100] | D Loss: 0.5394 | G Loss: 2.6250
D(real): 0.6898, D(fake): 0.5271
Epoch [7/100] | D Loss: 0.4719 | G Loss: 3.1037
D(real): 0.6982, D(fake): 0.5297
Epoch [8/100] | D Loss: 0.3968 | G Loss: 3.4179
D(real): 0.7116, D(fake): 0.5455
Epoch [9/100] | D Loss: 0.5576 | G Loss: 3.8071
D(real): 0.7110, D(fake): 0.5636
Epoch [10/100] | D Loss: 0.4766 | G Loss: 2.8863
D(real): 0.6899, D(fake): 0.5247
Epoch [11/100] | D Loss: 0.3972 | G Loss: 4.4867
D(real): 0.7106, D(fake): 0.5410
Epoch [12/100] | D Loss: 0.4398 | G Loss: 2.8734
D(real): 0.7080, D(fake): 0.5411
Epoch [13/100] | D Loss: 