<a href="https://colab.research.google.com/github/MehrdadDastouri/gan_mnist_image_generation/blob/main/gan_mnist_image_generation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Import necessary libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.utils import save_image, make_grid
from torch.utils.data import DataLoader
import os
import matplotlib.pyplot as plt

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Output folder for the sample grids written during training
# (no error is raised if the folder already exists).
os.makedirs("gan_images", exist_ok=True)

# ---- Hyperparameters ----
# Image geometry (MNIST): one grayscale channel, 28 pixels per side.
channels = 1
image_size = 28

# Length of the noise vector fed to the generator.
latent_dim = 100

# Optimisation settings.
batch_size = 128
epochs = 50
lr = 2e-4  # Learning rate passed to both Adam optimizers
beta1 = 0.5  # First Adam beta coefficient

# Define the Generator
class Generator(nn.Module):
    """Fully connected generator that turns noise vectors into images.

    The network is a stack of linear layers with 128, 256 and 512 hidden
    units. The 256- and 512-unit layers are followed by batch
    normalisation, and every hidden layer uses ReLU. A final linear layer
    produces one value per image element and a Tanh squashes the result
    into [-1, 1].

    Args:
        latent_dim: Size of the last dimension of the input noise.
        img_shape: Shape of a single output image; stored on the instance
            as ``img_shape`` and used to reshape the flat network output.
    """

    def __init__(self, latent_dim, img_shape):
        super().__init__()
        self.img_shape = img_shape

        # Number of values in one flattened image.
        n_pixels = int(torch.tensor(img_shape).prod())

        # First hidden layer has no batch normalisation.
        layers = [nn.Linear(latent_dim, 128), nn.ReLU(inplace=True)]
        # Remaining hidden layers: Linear -> BatchNorm -> ReLU.
        for n_in, n_out in ((128, 256), (256, 512)):
            layers.append(nn.Linear(n_in, n_out))
            layers.append(nn.BatchNorm1d(n_out))
            layers.append(nn.ReLU(inplace=True))
        # Output layer; Tanh keeps values between -1 and 1.
        layers.append(nn.Linear(512, n_pixels))
        layers.append(nn.Tanh())

        self.model = nn.Sequential(*layers)

    def forward(self, z):
        """Generate a batch of images from the noise batch ``z``.

        Returns a tensor of shape ``(z_batch, *img_shape)``.
        """
        flat = self.model(z)
        return flat.view(flat.size(0), *self.img_shape)

# Define the Discriminator
class Discriminator(nn.Module):
    """Fully connected discriminator that scores images as real or fake.

    Each image is flattened and passed through linear layers with 512 and
    256 hidden units (LeakyReLU with slope 0.2), followed by a single
    output unit with a Sigmoid, so each score lies in [0, 1].

    Args:
        img_shape: Shape of a single input image; its element count sets
            the width of the first linear layer.
    """

    def __init__(self, img_shape):
        super().__init__()

        # Number of values in one flattened image.
        n_pixels = int(torch.tensor(img_shape).prod())

        layers = []
        for n_in, n_out in ((n_pixels, 512), (512, 256)):
            layers.append(nn.Linear(n_in, n_out))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        # Single output unit squashed to a probability.
        layers.append(nn.Linear(256, 1))
        layers.append(nn.Sigmoid())

        self.model = nn.Sequential(*layers)

    def forward(self, img):
        """Return a ``(batch, 1)`` tensor of scores for the batch ``img``."""
        # Collapse everything except the batch dimension.
        flattened = img.view(img.size(0), -1)
        return self.model(flattened)

# Initialize the Generator and Discriminator
# Build both networks for (channels, height, width) images and move them
# to the selected device. The generator is created first.
img_shape = (channels, image_size, image_size)
generator = Generator(latent_dim=latent_dim, img_shape=img_shape).to(device)
discriminator = Discriminator(img_shape=img_shape).to(device)

# Binary cross-entropy between discriminator scores and 0/1 targets
adversarial_loss = nn.BCELoss()

# One Adam optimizer per network, both with the same settings
adam_settings = {"lr": lr, "betas": (beta1, 0.999)}
optimizer_G = optim.Adam(generator.parameters(), **adam_settings)
optimizer_D = optim.Adam(discriminator.parameters(), **adam_settings)

# MNIST training split: convert to tensors, then shift/scale with
# mean 0.5 and std 0.5 so values land in [-1, 1]
to_tensor = transforms.ToTensor()
to_signed_range = transforms.Normalize((0.5,), (0.5,))
transform = transforms.Compose([to_tensor, to_signed_range])
dataset = datasets.MNIST(
    root="./data",
    train=True,
    download=True,
    transform=transform,
)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Training loop
for epoch in range(epochs):
    for i, (imgs, _) in enumerate(dataloader):
        n_samples = imgs.size(0)

        # Target labels for this batch: 1 means "real", 0 means "fake"
        real = torch.ones(n_samples, 1, device=device)
        fake = torch.zeros(n_samples, 1, device=device)

        # ---- Generator step ----
        # Draw noise, generate images, and push the discriminator's
        # scores on them towards the "real" label.
        optimizer_G.zero_grad()
        z = torch.randn(n_samples, latent_dim, device=device)
        generated_imgs = generator(z)
        g_loss = adversarial_loss(discriminator(generated_imgs), real)
        g_loss.backward()
        optimizer_G.step()

        # ---- Discriminator step ----
        # Score the real batch against "real" and the generated batch
        # against "fake". detach() keeps this loss from back-propagating
        # into the generator.
        optimizer_D.zero_grad()
        real_scores = discriminator(imgs.to(device))
        real_loss = adversarial_loss(real_scores, real)
        fake_scores = discriminator(generated_imgs.detach())
        fake_loss = adversarial_loss(fake_scores, fake)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        # Report losses on every 100th batch only
        if i % 100 != 0:
            continue
        print(f"Epoch [{epoch+1}/{epochs}] Batch {i}/{len(dataloader)} | D Loss: {d_loss.item():.4f} | G Loss: {g_loss.item():.4f}")

    # After each epoch, write an 8-per-row grid of 64 fresh samples to disk.
    # The generator is put in eval mode while sampling, then restored.
    generator.eval()
    with torch.no_grad():
        z = torch.randn(64, latent_dim, device=device)
        generated_imgs = generator(z)
        # Map the Tanh output range [-1, 1] to [0, 1] before saving
        rescaled = (generated_imgs + 1) / 2
        grid = make_grid(rescaled, nrow=8, normalize=False)
        save_image(grid, f"gan_images/epoch_{epoch+1}.png")
    generator.train()

# Visualize the final generated images
# Load the sample grid saved for the last epoch and show it without axes
final_image_path = f"gan_images/epoch_{epochs}.png"
final_image = plt.imread(final_image_path)
plt.figure(figsize=(8, 8))
plt.imshow(final_image)
plt.title("Final Generated Images")
plt.axis("off")
plt.show()

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9.91M/9.91M [00:00<00:00, 37.5MB/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28.9k/28.9k [00:00<00:00, 1.21MB/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1.65M/1.65M [00:00<00:00, 10.3MB/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4.54k/4.54k [00:00<00:00, 3.40MB/s]


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw

Epoch [1/50] Batch 0/469 | D Loss: 0.6910 | G Loss: 0.7103
Epoch [1/50] Batch 100/469 | D Loss: 0.4568 | G Loss: 0.8245
Epoch [1/50] Batch 200/469 | D Loss: 0.5567 | G Loss: 0.5913
Epoch [1/50] Batch 300/469 | D Loss: 0.5791 | G Loss: 1.0003
Epoch [1/50] Batch 400/469 | D Loss: 0.4403 | G Loss: 1.2282
Epoch [2/50] Batch 0/469 | D Loss: 0.4118 | G Loss: 1.0445
Epoch [2/50] Batch 100/469 | D Loss: 0.5474 | G Loss: 1.9859
Epoch [2/50] Batch 200/469 | D Loss: 0.3336 | G Loss: 2.0303
Epoch [2/50] Batch 300/469 | D Loss: 0.2386 | G Loss: 1.3432
Epoch [2/50] Batch 400/469 | D Loss: 0.3416 | G Loss: 1.0483
Epoch [3/50] Batch 0/469 | D Loss: 0.3885 | G Loss: 0.9699
Epoch [3/50] Batch 100/469 | D Loss: 0.4576 | G Loss: 0.7316
Epoch [3/50] Batch 200/469 | D Loss: 0.4284 | G Loss: 2.0507
Epoch [3/50] Batch 300/469 | D Loss: 0.4138 | G Loss: 0.8305
Epoch [3/50] Batch 400/469 | D Loss: 0.3769 | G Loss: 1.1491
Epoch [4/50] Bat