In [None]:
!pip install torch torchvision matplotlib numpy

Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch)
  Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl (731.7 MB)
Collecting nvidia-cublas-cu12==12.1.3.1 (from torch)
  Using cached nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl (410.6 MB)
Collecting nvidia-cufft-cu12==11.0.2.54 (from torch)
  Using cached nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl (121.6 MB)
Collecting nvidia-curand-cu12==10.3.2.106 (from torch)
  Using cached nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl (56.5 MB)
Collectin

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader
import os
import numpy as np
import matplotlib.pyplot as plt

Generator

In [None]:
class Generator(nn.Module):
    """Transposed-convolution generator mapping a latent tensor to an image.

    With the default arguments the layer stack is identical to the original
    hard-coded network (100 -> 512 -> 256 -> 128 -> 64 -> 3 channels), so
    ``Generator()`` and any saved ``state_dict`` keep working.

    Args:
        nz: Number of channels of the latent input tensor.
        ngf: Base feature-map width; hidden layers use 8x, 4x, 2x and 1x it.
        nc: Number of channels of the generated image.

    For an input of shape ``(N, nz, 1, 1)`` the output has shape
    ``(N, nc, 64, 64)`` with values in ``[-1, 1]`` (final ``Tanh``).
    """

    def __init__(self, nz: int = 100, ngf: int = 64, nc: int = 3):
        super().__init__()
        self.main = nn.Sequential(
            # 1x1 -> 4x4 (kernel 4, stride 1, no padding)
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # Each following block doubles the spatial size
            # (kernel 4, stride 2, padding 1): 4 -> 8 -> 16 -> 32 -> 64.
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            # Tanh squashes the output into [-1, 1].
            nn.Tanh()
        )

    def forward(self, input):
        """Run ``input`` through the layer stack and return the generated batch."""
        return self.main(input)

Discriminator

In [None]:
class Discriminator(nn.Module):
    """Strided-convolution classifier scoring images as real (1) or fake (0).

    With the default arguments the layer stack is identical to the original
    hard-coded network (3 -> 64 -> 128 -> 256 -> 512 -> 1 channels), so
    ``Discriminator()`` and any saved ``state_dict`` keep working.

    Args:
        nc: Number of channels of the input image.
        ndf: Base feature-map width; hidden layers use 1x, 2x, 4x and 8x it.

    For an input of shape ``(N, nc, 64, 64)`` the output has shape
    ``(N, 1, 1, 1)`` with values in ``[0, 1]`` (final ``Sigmoid``).
    """

    def __init__(self, nc: int = 3, ndf: int = 64):
        super().__init__()
        self.main = nn.Sequential(
            # Each strided block halves the spatial size
            # (kernel 4, stride 2, padding 1): 64 -> 32 -> 16 -> 8 -> 4.
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # 4x4 -> 1x1 (kernel 4, stride 1, no padding): one score per image.
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            # Sigmoid turns the score into a probability, as nn.BCELoss requires.
            nn.Sigmoid()
        )

    def forward(self, input):
        """Run ``input`` through the layer stack and return the per-image scores."""
        return self.main(input)

Training Setup

In [None]:
# Seed PyTorch's global RNG before any weights are created so that weight
# initialization, data shuffling and noise sampling are repeatable across
# "Restart & Run All" (as far as the backend's kernels are deterministic).
SEED = 42
torch.manual_seed(SEED)

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create the generator and discriminator
netG = Generator().to(device)
netD = Discriminator().to(device)

# Initialize the weights
def weights_init(m):
    """Initialize one module in place; intended for use with ``net.apply``.

    * ``Conv2d`` / ``ConvTranspose2d``: weights drawn from N(0, 0.02).
    * ``BatchNorm2d``: scale drawn from N(1, 0.02), shift set to 0.
    * Any other module type is left untouched.

    Args:
        m: The module handed over by ``nn.Module.apply``.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        # nn.init functions already run without gradient tracking, so the
        # parameter can be passed directly instead of going through `.data`.
        nn.init.normal_(m.weight, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d):
        # BatchNorm2d(affine=False) has weight/bias set to None; skip those
        # instead of crashing.
        if m.weight is not None:
            nn.init.normal_(m.weight, 1.0, 0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

# Re-initialize every layer of both networks with weights_init.
netG.apply(weights_init)
netD.apply(weights_init)

# Loss and optimizers
# The discriminator ends in a Sigmoid, so plain binary cross-entropy applies.
criterion = nn.BCELoss()

# Both optimizers use the same Adam settings; name them once so the two calls
# cannot drift apart.
LEARNING_RATE = 0.0002
ADAM_BETAS = (0.5, 0.999)
optimizerD = optim.Adam(netD.parameters(), lr=LEARNING_RATE, betas=ADAM_BETAS)
optimizerG = optim.Adam(netG.parameters(), lr=LEARNING_RATE, betas=ADAM_BETAS)

num_epochs = 50
batch_size = 64
# Must equal the number of input channels of the generator's first layer (100).
latent_vector_size = 100

# Preprocessing applied to every CIFAR-10 image: resize to 64x64, convert to a
# tensor, then shift/scale each of the three channels with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.Resize(size=(64, 64)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# CIFAR-10 training split (downloaded into ./data when missing), served in
# shuffled mini-batches of `batch_size` images.
dataset = CIFAR10(root='./data', train=True, transform=transform, download=True)
dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

# Grayscale conversion helper
def rgb_to_grayscale(batch):
    """Average ``batch`` over dimension 1, keeping that dimension with size 1.

    Every entry of the result is the unweighted mean of the values along
    dim 1 of the input tensor.
    """
    channel_average = torch.mean(batch, dim=1, keepdim=True)
    return channel_average

# Device setup: `device` is already chosen at the top of this cell, before the
# networks are moved onto it, so it is not assigned a second time here.


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:03<00:00, 46563957.70it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data


Training Loop

In [None]:
import os

# Image grids are written under ./output; make sure the directory is there
# first (exist_ok=True means an already existing directory is not an error).
os.makedirs(name='output', exist_ok=True)


In [None]:
# Adversarial training: on every batch the discriminator is updated first
# (real batch, then generated batch), followed by one generator update.
for epoch in range(num_epochs):
    for i, (real_images, _) in enumerate(dataloader):
        real_images = real_images.to(device)

        # Size of *this* batch (the final batch of an epoch may be smaller).
        # Kept under its own name so the global `batch_size` hyperparameter
        # is not overwritten by the loop.
        current_batch_size = real_images.size(0)

        # ---- Update Discriminator ----
        netD.zero_grad()

        # Real images should be scored as 1.
        labels = torch.full((current_batch_size,), 1.0, dtype=torch.float, device=device)
        output = netD(real_images).view(-1)
        errD_real = criterion(output, labels)
        errD_real.backward()

        # Generated images should be scored as 0. detach() keeps this backward
        # pass from reaching into the generator.
        noise = torch.randn(current_batch_size, latent_vector_size, 1, 1, device=device)
        fake_images = netG(noise)
        labels.fill_(0)
        output = netD(fake_images.detach()).view(-1)
        errD_fake = criterion(output, labels)
        errD_fake.backward()
        optimizerD.step()

        # ---- Update Generator ----
        # The generator is rewarded when the (just updated) discriminator
        # scores its images as real, hence the target label 1.
        netG.zero_grad()
        labels.fill_(1)
        output = netD(fake_images).view(-1)
        errG = criterion(output, labels)
        errG.backward()
        optimizerG.step()

        if i % 100 == 0:
            print(f'[{epoch}/{num_epochs}][{i}/{len(dataloader)}] Loss_D: {errD_real.item() + errD_fake.item()} Loss_G: {errG.item()}')

    # Save fake images every epoch (the batch generated in the last iteration)
    save_image(fake_images, f'output/fake_images_epoch_{epoch}.png', normalize=True)


[0/50][0/782] Loss_D: 0.6486326523590833 Loss_G: 8.115062713623047
[0/50][100/782] Loss_D: 0.14724233746528625 Loss_G: 4.501054763793945
[0/50][200/782] Loss_D: 0.24880269169807434 Loss_G: 4.623174667358398
[0/50][300/782] Loss_D: 1.2991057224571705 Loss_G: 2.461794137954712
[0/50][400/782] Loss_D: 0.7386796027421951 Loss_G: 2.32002854347229
[0/50][500/782] Loss_D: 0.7644049376249313 Loss_G: 3.7227911949157715
[0/50][600/782] Loss_D: 1.1464663445949554 Loss_G: 5.486776351928711
[0/50][700/782] Loss_D: 0.9613341689109802 Loss_G: 1.838554859161377
[1/50][0/782] Loss_D: 0.4510817676782608 Loss_G: 2.6202545166015625
[1/50][100/782] Loss_D: 0.2887461483478546 Loss_G: 5.126065254211426
[1/50][200/782] Loss_D: 0.47905173897743225 Loss_G: 3.712486505508423
[1/50][300/782] Loss_D: 0.873783016577363 Loss_G: 5.537789344787598
[1/50][400/782] Loss_D: 0.6930228769779205 Loss_G: 1.9276821613311768
[1/50][500/782] Loss_D: 0.5091284960508347 Loss_G: 2.315997838973999
[1/50][600/782] Loss_D: 0.72202861

Testing the Generator

In [None]:
# Draw 64 fresh latent vectors and render them with the trained generator.
# Sampling needs no gradients, so the forward pass runs under no_grad.
with torch.no_grad():
    noise = torch.randn((64, latent_vector_size, 1, 1), device=device)
    fake_images = netG(noise)

# Move the batch to host memory and write it out as a single image grid.
fake_images = fake_images.to("cpu")
save_image(tensor=fake_images, fp='output/test_fake_images.png', normalize=True)
