<a href="https://colab.research.google.com/github/Aiden-Ross-Dsouza/Generative-Models/blob/main/Generative%20Adversarial%20Networks/notebooks/DCGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Install Packages

In [None]:
pip install pytorch-fid



# Import Libraries

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import matplotlib.pyplot as plt
import os
import subprocess

# Discriminator

In [None]:
class Discriminator(nn.Module):
    """DCGAN discriminator: scores a batch of 64 x 64 images as real or fake.

    Args:
        channels_img: Number of channels in the input images.
        features_d: Base channel width; the hidden layers use 1x, 2x, 4x and
            8x this value.

    The forward pass maps ``N x channels_img x 64 x 64`` to ``N x 1 x 1 x 1``
    sigmoid outputs.
    """

    def __init__(self, channels_img, features_d):
        super().__init__()
        widths = [features_d * mult for mult in (1, 2, 4, 8)]

        # Stem has no BatchNorm: 64 x 64 -> 32 x 32.
        layers = [
            nn.Conv2d(channels_img, widths[0], kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
        ]
        # Each block halves the spatial size: 32 -> 16 -> 8 -> 4.
        for c_in, c_out in zip(widths, widths[1:]):
            layers.append(self.block(c_in, c_out, 4, 2, 1))
        # Collapse the 4 x 4 map to a single score per image.
        layers.append(nn.Conv2d(widths[-1], 1, kernel_size=4, stride=2, padding=0))
        layers.append(nn.Sigmoid())

        self.disc = nn.Sequential(*layers)

    def block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Return a Conv2d (no bias) -> BatchNorm2d -> LeakyReLU(0.2) stage."""
        conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
        norm = nn.BatchNorm2d(out_channels)
        act = nn.LeakyReLU(0.2)
        return nn.Sequential(conv, norm, act)

    def forward(self, x):
        """Run ``x`` through the convolutional stack and return the scores."""
        return self.disc(x)

# Generator

In [None]:
class Generator(nn.Module):
    """DCGAN generator: turns latent vectors into 64 x 64 images.

    Args:
        z_dim: Number of channels of the latent input.
        channels_img: Number of channels in the generated images.
        features_g: Base channel width; the hidden layers use 16x, 8x, 4x
            and 2x this value.

    The forward pass maps ``N x z_dim x 1 x 1`` to
    ``N x channels_img x 64 x 64`` with values squashed by ``tanh``.
    """

    def __init__(self, z_dim, channels_img, features_g):
        super().__init__()

        # Project the latent vector to a 4 x 4 feature map.
        stages = [self.block(z_dim, features_g * 16, 4, 1, 0)]
        # Each block doubles the spatial size: 4 -> 8 -> 16 -> 32.
        for mult_in, mult_out in ((16, 8), (8, 4), (4, 2)):
            stages.append(
                self.block(features_g * mult_in, features_g * mult_out, 4, 2, 1)
            )
        # Final upsample to 64 x 64 image channels, without BatchNorm.
        stages.append(
            nn.ConvTranspose2d(
                features_g * 2, channels_img, kernel_size=4, stride=2, padding=1
            )
        )
        stages.append(nn.Tanh())

        self.gen = nn.Sequential(*stages)

    def block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Return a ConvTranspose2d (no bias) -> BatchNorm2d -> ReLU stage."""
        upsample = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size, stride, padding, bias=False
        )
        norm = nn.BatchNorm2d(out_channels)
        return nn.Sequential(upsample, norm, nn.ReLU())

    def forward(self, x):
        """Run latent input ``x`` through the upsampling stack."""
        return self.gen(x)

# Initialize weights

In [None]:
def initialize_weights(model):
    """Apply DCGAN-style weight initialization to ``model`` in place.

    Convolution and transposed-convolution weights are drawn from
    N(0, 0.02). BatchNorm scale weights are drawn from N(1, 0.02) and their
    biases are zeroed, so normalization layers start close to the identity
    instead of scaling every activation to roughly zero.

    Args:
        model: Any ``nn.Module``; all of its sub-modules are visited.
    """
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight, 0.0, 0.02)
        elif isinstance(m, nn.BatchNorm2d):
            # affine=False BatchNorm layers have no learnable weight/bias.
            if m.weight is not None:
                nn.init.normal_(m.weight, 1.0, 0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

# Test

In [None]:
def test():
    """Smoke-test the output shapes of Discriminator and Generator.

    Builds both networks with a small feature width (8), initializes their
    weights, and asserts that the discriminator yields one score per image
    and the generator yields 64 x 64 images. Prints "Success" when both
    assertions hold.
    """
    batch, img_channels, height, width = 8, 3, 64, 64
    latent_dim = 100

    # Discriminator: images -> one score per image.
    images = torch.randn((batch, img_channels, height, width))
    critic = Discriminator(img_channels, 8)
    initialize_weights(critic)
    scores = critic(images)
    assert scores.shape == (batch, 1, 1, 1)

    # Generator: latent vectors -> images.
    generator = Generator(latent_dim, img_channels, 8)
    initialize_weights(generator)
    latent = torch.randn((batch, latent_dim, 1, 1))
    samples = generator(latent)
    assert samples.shape == (batch, img_channels, height, width)

    print("Success")

# Visualize Generated Images

In [None]:
def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28)):
    """Display a grid (5 per row) of the first ``num_images`` images.

    Args:
        image_tensor: Image batch with values in [-1, 1]; it is viewed as
            ``(-1, *size)`` before plotting.
        num_images: Maximum number of images placed in the grid.
        size: Per-image shape as ``(channels, height, width)``.
    """
    # Map from [-1, 1] back to [0, 1] for display.
    rescaled = (image_tensor + 1) / 2
    batch = rescaled.detach().cpu().view(-1, *size)
    grid = torchvision.utils.make_grid(batch[:num_images], nrow=5)
    # matplotlib expects channels last.
    channels_last = grid.permute(1, 2, 0).squeeze()
    plt.imshow(channels_last)
    plt.show()

# Save Images

In [None]:
# Output folders for the image grids written during training.
for output_dir in ('generated_images', 'real_images'):
    os.makedirs(output_dir, exist_ok=True)

def save_images(images, directory, prefix, num_images=25):
    """Save the first ``num_images`` images as one PNG grid (5 per row).

    Args:
        images: Image batch with values in [-1, 1].
        directory: Folder the PNG is written to.
        prefix: File name without extension; the file is ``<prefix>.png``.
        num_images: Maximum number of images placed in the grid.
    """
    target_path = os.path.join(directory, f"{prefix}.png")

    # Map from [-1, 1] back to [0, 1] before converting to 8-bit pixels.
    rescaled = (images + 1) / 2
    grid = torchvision.utils.make_grid(rescaled[:num_images], nrow=5)
    pixels = grid.permute(1, 2, 0).cpu().numpy() * 255
    plt.imsave(target_path, pixels.astype(np.uint8))

# Train

In [None]:
# ---- Hyperparameters -------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
lr = 2e-4
batch_size = 128
image_size = 64
channels_img = 3
z_dim = 100
num_epochs = 5
features_disc = 64
features_gen = 64

# ---- Data ------------------------------------------------------------------
# Resize to the 64 x 64 input the networks expect and scale pixels to [-1, 1]
# (mean 0.5, std 0.5 per channel), matching the generator's tanh output range.
image_transforms = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize([0.5] * channels_img, [0.5] * channels_img),
])

dataset = datasets.CIFAR10(
    root="dataset/",
    train=True,
    transform=image_transforms,
    download=True,
)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# ---- Models ----------------------------------------------------------------
gen = Generator(z_dim, channels_img, features_gen).to(device)
disc = Discriminator(channels_img, features_disc).to(device)
initialize_weights(gen)
initialize_weights(disc)

# ---- Optimizers and loss ---------------------------------------------------
opt_gen = optim.Adam(gen.parameters(), lr=lr, betas=(0.5, 0.999))
opt_disc = optim.Adam(disc.parameters(), lr=lr, betas=(0.5, 0.999))
criterion = nn.BCELoss()

# ---- Logging state ---------------------------------------------------------
# The same 32 latent vectors are reused for every preview so that successive
# image grids are comparable.
fixed_noise = torch.randn(32, z_dim, 1, 1).to(device)
writer_real = SummaryWriter("logs/real")
writer_fake = SummaryWriter("logs/fake")
step = 0

gen.train()
disc.train()

# Per-iteration history used for the loss plot.
gen_losses, disc_losses, iterations = [], [], []

# Adversarial training loop. Each iteration updates the discriminator on one
# real batch and one generated batch, then updates the generator to fool the
# freshly-updated discriminator. Every 100 batches the losses are printed and
# preview grids are logged to TensorBoard, saved to disk, and displayed.
for epoch in range(num_epochs):
  for batch_idx, (real, _) in enumerate(loader):
    real = real.to(device)
    # Size the noise by the actual batch: the last batch of an epoch can be
    # smaller than batch_size.
    noise = torch.randn((real.size(0), z_dim, 1, 1)).to(device)
    fake = gen(noise)

    # Train Discriminator: max log(D(x)) + log(1 - D(G(z)))
    # NOTE(review): nn.BCELoss requires inputs in [0, 1]; a NaN coming out of
    # the discriminator could surface on GPU as a device-side assert.
    # TODO confirm whether that is what the recorded run hit.
    disc_real = disc(real).reshape(-1)
    loss_disc_real = criterion(disc_real, torch.ones_like(disc_real))
    # detach(): the discriminator update must not backpropagate into the
    # generator, and this removes the need for retain_graph=True.
    disc_fake = disc(fake.detach()).reshape(-1)
    loss_disc_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
    loss_disc = (loss_disc_real + loss_disc_fake) / 2
    disc.zero_grad()
    loss_disc.backward()
    opt_disc.step()

    # Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z)))
    # The discriminator is re-run on the non-detached fakes so gradients reach
    # the generator through the updated discriminator.
    output = disc(fake).reshape(-1)
    loss_gen = criterion(output, torch.ones_like(output))
    gen.zero_grad()
    loss_gen.backward()
    opt_gen.step()

    gen_losses.append(loss_gen.item())
    disc_losses.append(loss_disc.item())
    iterations.append(epoch * len(loader) + batch_idx)

    # Print losses occasionally
    if batch_idx % 100 == 0:
      print(
          f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(loader)} \
          Loss D: {loss_disc:.4f}, loss G: {loss_gen:.4f}"
      )

      with torch.no_grad():
        # Preview from the fixed latent vectors; kept in its own variable so
        # the training batch's `fake` is not overwritten. The generator is
        # still in train mode here (gen.train() is called before the loop).
        fixed_fake = gen(fixed_noise)
        img_grid_real = torchvision.utils.make_grid(real[:32], normalize=True)
        img_grid_fake = torchvision.utils.make_grid(fixed_fake[:32], normalize=True)

        writer_real.add_image("Real", img_grid_real, global_step=step)
        writer_fake.add_image("Fake", img_grid_fake, global_step=step)

        save_images(real.cpu(), 'real_images', f'real_epoch_{epoch}_batch_{batch_idx}')
        save_images(fixed_fake.cpu(), 'generated_images', f'fake_epoch_{epoch}_batch_{batch_idx}')

        show_tensor_images(fixed_fake, num_images=25, size=(channels_img, image_size, image_size))

      step += 1

# Flush pending TensorBoard events and release the log files.
writer_real.close()
writer_fake.close()

Files already downloaded and verified


RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


# Calculate Frechet Inception Distance

In [None]:
def compute_fid(real_images_dir, fake_images_dir):
    """Compute the FID between two image folders with the ``pytorch_fid`` CLI.

    Runs ``python -m pytorch_fid <real> <fake>`` in a subprocess, prints the
    score found on the first output line containing ``FID``, and returns it.

    Args:
        real_images_dir: Folder containing the real images.
        fake_images_dir: Folder containing the generated images.

    Returns:
        The FID score as a float, or ``None`` when the subprocess fails or
        no parsable score is found in its output.
    """
    # Local import keeps this block self-contained; sys is standard library.
    import sys

    # Use the interpreter running this code so the subprocess sees the same
    # installed packages; a bare 'python' may resolve to a different one.
    interpreter = sys.executable or 'python'
    result = subprocess.run(
        [interpreter, '-m', 'pytorch_fid', real_images_dir, fake_images_dir],
        capture_output=True,
        text=True,
    )

    # Report failures instead of silently printing nothing.
    if result.returncode != 0:
        print(
            f"FID computation failed (exit code {result.returncode}):\n"
            f"{result.stderr.strip()}"
        )
        return None

    # Parse and print FID score
    for line in result.stdout.splitlines():
        if 'FID' in line:
            fid_score = line.split()[-1]
            print(f"FID Score: {fid_score}")
            try:
                return float(fid_score)
            except ValueError:
                return None

    print("FID computation finished but no FID line was found in its output.")
    return None

# Run once training has written image grids into both folders.
compute_fid(real_images_dir='real_images', fake_images_dir='generated_images')


# Plot Losses

In [None]:
# Plot both per-iteration loss curves recorded during training on one figure.
plt.figure(figsize=(12, 6))
for curve_label, curve_values in (
    ('Generator Loss', gen_losses),
    ('Discriminator Loss', disc_losses),
):
    plt.plot(iterations, curve_values, label=curve_label)
plt.title('Generator and Discriminator Losses')
plt.xlabel('Iterations')
plt.ylabel('Loss')
plt.legend()
plt.show()