In [43]:
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision
from torchvision import datasets, transforms
from torchmetrics.image.fid import FrechetInceptionDistance
from tqdm import tqdm

Constants

In [44]:
data_dir = './data'
img_size = 64
z_dim = 100 # Size of the noise vector
ngf = 64 # Size of feature maps in generator
ndf = 64 # Size of feature maps in discriminator

num_epochs = 10
batch_size = 32
learning_rate = 0.0002

ckpt_every = 300 # Save checkpoint every 300 iterations
visualize_every = 300 # Visualize every 300 iterations
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

Using device: cuda


In [3]:
transform = transforms.Compose([
    transforms.Resize(img_size),
    transforms.CenterCrop(img_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

train_set = datasets.CelebA(data_dir, split='train', download=True, transform=transform)
test_set = datasets.CelebA(data_dir, split='test', download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


Networks

In [60]:
class Generator(nn.Module):
  def __init__(self, z_dim=100, ngf=ngf, channels=3):
    super(Generator, self).__init__()

    self.model = nn.Sequential(
      # Input: z_dim x 1 x 1 -> Output: ngf*8 x 4 x 4
      nn.ConvTranspose2d(z_dim, ngf*8, kernel_size=4, stride=1, padding=0, bias=False),
      nn.BatchNorm2d(ngf*8),
      nn.ReLU(True),
      # Input: ngf*8 x 4 x 4 -> Output: ngf*4 x 8 x 8
      nn.ConvTranspose2d(ngf*8, ngf*4, kernel_size=4, stride=2, padding=1, bias=False),
      nn.BatchNorm2d(ngf*4),
      nn.ReLU(True),
      # Input: ngf*4 x 8 x 8 -> Output: ngf*2 x 16 x 16
      nn.ConvTranspose2d(ngf*4, ngf*2, kernel_size=4, stride=2, padding=1, bias=False),
      nn.BatchNorm2d(ngf*2),
      nn.ReLU(True),
      # Input: ngf*2 x 16 x 16 -> Output: ngf x 32 x 32
      nn.ConvTranspose2d(ngf*2, ngf, kernel_size=4, stride=2, padding=1, bias=False),
      nn.BatchNorm2d(ngf),
      nn.ReLU(True),
      # Input: ngf x 32 x 32 -> Output: channels x 64 x 64
      nn.ConvTranspose2d(ngf, channels, kernel_size=4, stride=2, padding=1, bias=False),
      nn.Tanh()
    )

  def forward(self, x):
    return self.model(x)

In [54]:
class Discriminator(nn.Module):
  def __init__(self, ndf=ndf, channels=3):
    super(Discriminator, self).__init__()

    self.model = nn.Sequential(
      # Input size: (channels) x 64 x 64 -> Output size: (ndf) x 32 x 32
      nn.Conv2d(channels, ndf, kernel_size=4, stride=2, padding=1, bias=False),
      nn.LeakyReLU(0.2, inplace=True),
      # Input size: (ndf) x 32 x 32 -> Output size: (ndf*2) x 16 x 16
      nn.Conv2d(ndf, ndf*2, kernel_size=4, stride=2, padding=1, bias=False),
      nn.BatchNorm2d(ndf*2),
      nn.LeakyReLU(0.2, inplace=True),
      # Input size (ndf*2) x 16 x 16 -> Output size: (ndf*4) x 8 x 8 
      nn.Conv2d(ndf*2, ndf*4, kernel_size=4, stride=2, padding=1, bias=False),
      nn.BatchNorm2d(ndf*4),
      nn.LeakyReLU(0.2, inplace=True),
      # Input size (ndf*4) x 8 x 8 -> Output size: (ndf*8) x 4 x 4
      nn.Conv2d(ndf*4, ndf*8, kernel_size=4, stride=2, padding=1, bias=False),
      nn.BatchNorm2d(ndf*8),
      nn.LeakyReLU(0.2, inplace=True),
      # Input size: (ndf*8) x 4 x 4 -> Output size: 1 x 1 x 1
      nn.Conv2d(ndf*8, 1, kernel_size=4, stride=1, padding=0, bias=False),
    )

  def forward(self, x):
    return self.model(x)

In [52]:
d = Discriminator().to(device)
img = torch.randn(32, 3, 64, 64).to(device)
print(d(img).shape)

torch.Size([32, 1, 1, 1])


In [24]:
def init_weights(m):
  if type(m) == nn.Conv2d or type(m) == nn.ConvTranspose2d:
    nn.init.normal_(m.weight, 0.0, 0.02)
  if type(m) == nn.BatchNorm2d:
    nn.init.normal_(m.weight, 1.0, 0.02)
    nn.init.constant_(m.bias, 0)

Helper functions

In [57]:
def visualize_batch(fixed_noise, netG, epoch):
  with torch.no_grad():
    fake = netG(fixed_noise)
    grid = torchvision.utils.make_grid(fake, nrow=8, normalize=True)
    torchvision.utils.save_image(grid, f'./samples/{epoch}.png', normalize=True)

In [26]:
def save_checkpoint(netG, netD, epoch):
  torch.save(netG.state_dict(), f'./checkpoint/netG_{epoch}.pth')
  torch.save(netD.state_dict(), f'./checkpoint/netD_{epoch}.pth')

Training

In [58]:
def train():
  # networks
  netG = Generator(z_dim).to(device)
  netG.apply(init_weights)
  netD = Discriminator().to(device)
  netD.apply(init_weights)

  # optimizers
  optimG = torch.optim.Adam(netG.parameters(), lr=learning_rate, betas=(0.5, 0.999))
  optimD = torch.optim.Adam(netD.parameters(), lr=learning_rate, betas=(0.5, 0.999))

  # loss function
  criterion = nn.BCEWithLogitsLoss()

  # fixed noise for visualization
  fixed_noise = torch.randn(64, z_dim, 1, 1, device=device)

  # training loop
  losses_G = []
  losses_D = []

  for epoch in range(num_epochs):
    print(f"Epoch {epoch}/{num_epochs}: ")
    pbar = tqdm(enumerate(train_loader), total=len(train_loader))
    for i, (real, _) in pbar:
      running_loss_G = 0.0
      running_loss_D = 0.0

      real = real.to(device)
      batch_size = real.size(0)

      # train discriminator
      netD.zero_grad()
      label = torch.full((batch_size,), 1, device=device, dtype=torch.float32)
      output = netD(real).view(-1)
      lossD_real = criterion(output, label)
      lossD_real.backward()

      noise = torch.randn(batch_size, z_dim, 1, 1, device=device)
      fake = netG(noise)
      label = torch.full((batch_size,), 0, device=device, dtype=torch.float32)
      output = netD(fake.detach()).view(-1)
      lossD_fake = criterion(output, label)
      lossD_fake.backward()

      lossD = lossD_real + lossD_fake
      optimD.step()

      running_loss_D += lossD.item()

      # train generator
      netG.zero_grad()
      label = torch.full((batch_size,), 1, device=device, dtype=torch.float32)
      output = netD(fake).view(-1)
      lossG = criterion(output, label)
      lossG.backward()
      optimG.step()

      running_loss_G += lossG.item()

      pbar.set_description_str(f'LossD: {lossD.item():.4f} LossG: {lossG.item():.4f}')
      if i % visualize_every == 0:
        visualize_batch(fixed_noise, netG, f"epoch_{epoch}_iter_{i}")
      
      if i % ckpt_every == 0:
        save_checkpoint(netG, netD, f"epoch_{epoch}_iter_{i}")

    losses_G.append(running_loss_G / len(train_loader))
    losses_D.append(running_loss_D / len(train_loader))

  return losses_G, losses_D

In [59]:
train()

Epoch 0/10: 


LossD: 0.0649 LossG: 7.2936:   0%|          | 17/5087 [00:05<27:43,  3.05it/s] 


KeyboardInterrupt: 

Inference

In [37]:
def inference(ckpt_path, num_samples=64):
  netG = Generator(z_dim).to(device)
  netG.load_state_dict(torch.load(ckpt_path))
  netG.eval()

  with torch.no_grad():
    noise = torch.randn(num_samples, z_dim, 1, 1, device=device)
    fake = netG(noise)
    grid = torchvision.utils.make_grid(fake, nrow=8, normalize=True)
    torchvision.utils.save_image(grid, f'./inference/inference.png', normalize=True)

In [41]:
inference('./checkpoint/netG_epoch_0_iter_1200.pth', 32)

Evaluation

In [18]:
def calculate_fid(ckpt_path):
  netG = Generator(z_dim).to(device)
  netG.load_state_dict(torch.load(ckpt_path))
  netG.eval()

  fid = FrechetInceptionDistance(feature=64, normalize=True).to(device)

  with torch.no_grad():
    for real, _ in tqdm(test_loader):
      real = real.to(device)
      noise = torch.randn(real.size(0), z_dim, 1, 1, device=device)
      fake = netG(noise)
      
      fid.update(real, real=True)
      fid.update(fake, real=False)

  print(f'FID: {fid.compute():.4f}')

In [19]:
calculate_fid('./checkpoint/netG_epoch_0_iter_1200.pth')

100%|██████████| 624/624 [00:36<00:00, 17.23it/s]


FID: 12.9965
