<a href="https://colab.research.google.com/github/Benedictakel/DCGAN-for-Handwritten-Digit-Generation-MNIST-/blob/main/DCGAN_for_Handwritten_Digit_Generation_(MNIST).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torch torchvision matplotlib imageio

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torchvision.utils as vutils
import matplotlib.pyplot as plt
import numpy as np
import os
import imageio
from torch.utils.data import DataLoader


In [None]:
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cpu


In [None]:
os.makedirs("generated", exist_ok=True)

In [None]:
transform = transforms.Compose([
    transforms.Resize(64),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

In [None]:
dataset = MNIST(root='./data', download=True, transform=transform)
dataloader = DataLoader(dataset, batch_size=128, shuffle=True)

100%|██████████| 9.91M/9.91M [00:00<00:00, 59.7MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 1.68MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 13.1MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 5.94MB/s]


In [None]:
class Generator(nn.Module):
    def __init__(self, nz=100, ngf=64, nc=1):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            nn.ConvTranspose2d(nz, ngf*4, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf*4),
            nn.ReLU(True),

            nn.ConvTranspose2d(ngf*4, ngf*2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf*2),
            nn.ReLU(True),

            nn.ConvTranspose2d(ngf*2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),

            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, x):
        return self.main(x)

In [None]:
class Discriminator(nn.Module):
    def __init__(self, nc=1, ndf=64):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(ndf, ndf*2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf*2),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(ndf*2, ndf*4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf*4),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(ndf*4, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.main(x)

In [None]:
nz = 100  # Latent vector size
G = Generator(nz).to(device)
D = Discriminator().to(device)

In [None]:
criterion = nn.BCELoss()
optimizerD = optim.Adam(D.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizerG = optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999))


In [None]:
fixed_noise = torch.randn(64, nz, 1, 1, device=device)

In [22]:
epochs = 25
img_list = []

for epoch in range(epochs):
    for i, (real_images, _) in enumerate(dataloader):
        real_images = real_images.to(device)
        b_size = real_images.size(0)

        # Labels
        real_labels = torch.ones(b_size, device=device)
        fake_labels = torch.zeros(b_size, device=device)

        # Train Discriminator
        D.zero_grad()

        # Real images
        output = D(real_images).view(-1)
        loss_real = criterion(output, real_labels)

        # Fake images
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        fake_images = G(noise)
        output = D(fake_images.detach()).view(-1)
        loss_fake = criterion(output, fake_labels)

        # Backprop Discriminator
        d_loss = loss_real + loss_fake
        d_loss.backward()
        optimizerD.step()

        # Train Generator
        G.zero_grad()
        output = D(fake_images).view(-1)
        g_loss = criterion(output, real_labels)  # Trick D into thinking fakes are real
        g_loss.backward()
        optimizerG.step()

        if i % 100 == 0:
            print(f"Epoch [{epoch+1}/{epochs}] Step [{i}/{len(dataloader)}] D_loss: {d_loss.item():.4f} G_loss: {g_loss.item():.4f}")

    # Save image grid after each epoch
    with torch.no_grad():
        fake = G(fixed_noise).detach().cpu()
        grid = vutils.make_grid(fake, padding=2, normalize=True)
        img_path = f"generated/epoch_{epoch+1}.png"
        vutils.save_image(grid, img_path)
        img_list.append(img_path)


ValueError: Using a target size (torch.Size([128])) that is different to the input size (torch.Size([3200])) is deprecated. Please ensure they have the same size.

In [None]:
gif_path = "generated/dcgan_mnist.gif"
with imageio.get_writer(gif_path, mode='I') as writer:
    for filename in img_list:
        image = imageio.imread(filename)
        writer.append_data(image)
print("GIF saved at:", gif_path)


In [None]:
torch.save(G.state_dict(), "dcgan_mnist_generator.pth")
