# Deep Convolutional Generative Adversarial Networks (DCGAN)

In [1]:
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as T

In [2]:
# device
assert torch.cuda.is_available()

In [3]:
# Training Parameters
LR = 2e-4
NUM_EPOCHS = 10
BATCH_SIZE = 64
LATENT_SIZE = 100
IMG_SIZE = 28*28
DEVICE = torch.device("cuda")

In [4]:
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.generator = nn.Sequential(
            nn.ConvTranspose2d(in_channels=LATENT_SIZE, out_channels=1024, kernel_size=4, bias=False),
            nn.BatchNorm2d(1024),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=4, padding=1, stride=2, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=4, padding=1, stride=2, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, padding=1, stride=2, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(in_channels=128, out_channels=1, kernel_size=4, padding=1, stride=2, bias=True),
            nn.Tanh()
        )
    
    def forward(self, x):
        return self.generator(x)

In [5]:
# check generator
gen = Generator().to(DEVICE)
dummy_input = torch.randn(BATCH_SIZE, LATENT_SIZE, 1, 1, device=DEVICE)
gen(dummy_input).shape

torch.Size([64, 1, 64, 64])

In [6]:
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.discriminator = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=128, kernel_size=4, padding=1, stride=2, bias=True),
            nn.LeakyReLU(0.2),
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, padding=1, stride=2, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=4, padding=1, stride=2, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),
            nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=4, padding=1, stride=2, bias=False),
            nn.BatchNorm2d(1024),
            nn.LeakyReLU(0.2),
            nn.Conv2d(in_channels=1024, out_channels=1, kernel_size=4, padding=0, stride=1, bias=True),
            nn.Flatten()
        )
    
    def forward(self, x):
        return self.discriminator(x)

In [7]:
# check discriminator
dis = Discriminator()
dummy_input = torch.randn(BATCH_SIZE, 1, 64, 64)
dis(dummy_input).shape

torch.Size([64, 1])

In [8]:
# dataset and dataloader
transform = T.Compose([T.Resize(64), T.ToTensor(), T.Normalize(mean=(0.5 , ), std=(0.5,)) ])
dataset = datasets.MNIST(root='../datasets/', train=True, transform=transform, download=True)
dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True, num_workers=2)

In [9]:
generator = Generator().to(DEVICE)
discriminator = Discriminator().to(DEVICE)
gen_optim = optim.Adam(generator.parameters(), lr=LR, betas=(0.5, 0.999))
dis_optim = optim.Adam(discriminator.parameters(), lr=LR, betas=(0.5, 0.999))
criterion = nn.BCEWithLogitsLoss()
writer = SummaryWriter('runs/dcgan_mnist')

In [10]:
# this tensor is useful to track images that are created from a fixed latent variable
# that way we can observe if the generator starts to create better images
fixed_noise = torch.randn((BATCH_SIZE, LATENT_SIZE, 1, 1), device=DEVICE)

In [11]:
for epoch in range(NUM_EPOCHS):
    dis_loss_col = []
    gen_loss_col = []
    for batch_idx, (features, _) in enumerate(dataloader):
        real_images = features.to(DEVICE)
        
        # generate fake images from standar normal distributed latent vector
        latent_vector = torch.randn(BATCH_SIZE, LATENT_SIZE, 1, 1, device=DEVICE)
        fake_imgs = generator(latent_vector)
        
        # calculate logits for true and fake images
        fake_logits = discriminator(fake_imgs.detach())
        real_logits = discriminator(real_images)
        
        # calculate discriminator loss
        dis_real_loss = criterion(real_logits, torch.ones(BATCH_SIZE, 1, device=DEVICE))
        dis_fake_loss = criterion(fake_logits, torch.zeros(BATCH_SIZE, 1, device=DEVICE))
        dis_loss = dis_real_loss + dis_fake_loss
        
        # optimize the discriminator 
        dis_optim.zero_grad()
        dis_loss.backward()
        dis_optim.step()

        # calculate generator loss
        gen_loss = criterion(discriminator(fake_imgs), torch.ones(BATCH_SIZE, 1, device=DEVICE))
        
        # optimize the generator
        gen_optim.zero_grad()
        gen_loss.backward()
        gen_optim.step()
        
        dis_loss_col.append(dis_loss.cpu().item())
        gen_loss_col.append(gen_loss.cpu().item())

    dis_loss = sum(dis_loss_col) / len(dis_loss_col)
    gen_loss = sum(gen_loss_col) / len(gen_loss_col)

    print(f'Epoch {epoch+1}/{NUM_EPOCHS} || Discriminator Loss: {dis_loss:.4f} || Generator Loss: {gen_loss:.4f}')
    
    with torch.inference_mode():
        fake_imgs = generator(fixed_noise)
        grid = torchvision.utils.make_grid(fake_imgs, normalize=True)

        writer.add_image(
            "MNIST DCGAN Generated Images", grid, global_step=epoch+1
        )

Epoch 1/10 || Discriminator Loss: 0.3120 || Generator Loss: 6.8920
Epoch 2/10 || Discriminator Loss: 0.8425 || Generator Loss: 2.2131
Epoch 3/10 || Discriminator Loss: 0.5060 || Generator Loss: 3.5763
Epoch 4/10 || Discriminator Loss: 0.4448 || Generator Loss: 4.0151
Epoch 5/10 || Discriminator Loss: 0.3350 || Generator Loss: 4.4055
Epoch 6/10 || Discriminator Loss: 0.2965 || Generator Loss: 4.7895
Epoch 7/10 || Discriminator Loss: 0.3675 || Generator Loss: 4.5363
Epoch 8/10 || Discriminator Loss: 0.1050 || Generator Loss: 6.3591
Epoch 9/10 || Discriminator Loss: 0.0916 || Generator Loss: 9.0453
Epoch 10/10 || Discriminator Loss: 0.0519 || Generator Loss: 7.5223
