<a href="https://colab.research.google.com/github/aimldlnlp/RKK303-Computer-Vision-Final-Project/blob/main/GAN_with_PyTorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Import Required Libraries


In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as dsets
import torchvision.transforms as transforms

# Define the Generator and Discriminator Models

In [2]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 28*28),
            nn.Tanh()  # Output should be between -1 and 1
        )

    def forward(self, z):
        return self.model(z).view(-1, 1, 28, 28)

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28*28, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()  # Output between 0 and 1
        )

    def forward(self, img):
        return self.model(img)

# Initialize Models, Loss Function, and Optimizers

In [3]:
generator = Generator()
discriminator = Discriminator()

criterion = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002)
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002)

# Training Loop

In [4]:
# Load the MNIST dataset
mnist = dsets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
data_loader = torch.utils.data.DataLoader(mnist, batch_size=64, shuffle=True)

# Training the GAN
num_epochs = 50
for epoch in range(num_epochs):
    for real_images, _ in data_loader:
        batch_size = real_images.size(0)

        # Train Discriminator
        z = torch.randn(batch_size, 100)  # Random noise
        fake_images = generator(z)
        real_labels = torch.ones(batch_size, 1)  # Labels for real images
        fake_labels = torch.zeros(batch_size, 1)  # Labels for fake images

        optimizer_D.zero_grad()
        outputs = discriminator(real_images)
        d_loss_real = criterion(outputs, real_labels)
        outputs = discriminator(fake_images.detach())
        d_loss_fake = criterion(outputs, fake_labels)
        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        optimizer_D.step()

        # Train Generator
        optimizer_G.zero_grad()
        outputs = discriminator(fake_images)
        g_loss = criterion(outputs, real_labels)
        g_loss.backward()
        optimizer_G.step()

    print(f'Epoch [{epoch+1}/{num_epochs}], d_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}')

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9.91M/9.91M [00:00<00:00, 39.1MB/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28.9k/28.9k [00:00<00:00, 1.24MB/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1.65M/1.65M [00:00<00:00, 10.8MB/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4.54k/4.54k [00:00<00:00, 2.54MB/s]


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw

Epoch [1/50], d_loss: 3.5201, g_loss: 1.3312
Epoch [2/50], d_loss: 0.9882, g_loss: 0.8839
Epoch [3/50], d_loss: 0.9332, g_loss: 3.3154
Epoch [4/50], d_loss: 0.7635, g_loss: 1.1851
Epoch [5/50], d_loss: 0.5982, g_loss: 1.6919
Epoch [6/50], d_loss: 0.6452, g_loss: 1.5113
Epoch [7/50], d_loss: 1.1522, g_loss: 1.0088
Epoch [8/50], d_loss: 1.8872, g_loss: 1.2322
Epoch [9/50], d_loss: 0.6560, g_loss: 1.3315
Epoch [10/50], d_loss: 1.0278, g_loss: 2.0553
Epoch [11/50], d_loss: 1.2514, g_loss: 2.0441
Epoch [12/50], d_loss: 0.3823, g_loss: 1.6383
Epoch [13/50], d_loss: 0.7317, g_loss: 1.2457
Epoch [14/50], d_loss: 0.9453, g_loss: 1.2251
Epoch [15/50], d_loss: 1.1688, g_loss: 2.6061
Epoch [16/50], d_loss: 1.0321, g_loss: 1.2930
Epoch [17/50], d_loss: 1.6943, g_loss: 0.9679
Epoch [18/50], d_loss: 1.8136, g_loss: 0.8511
Epoch [19/50], d_loss: 0.5116, g_loss: 2.2687
Epoch [20/50], d_loss: 1.7262, g_loss: 1.2037
Epoch [21/50],