In [1]:
import sys, torch
from torch import nn
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from torchvision import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt

sys.path.append("..")
from models.vis_utils import show_img_batch
from models.discriminator_utils import DiscConvBlock
from models.generator_utils import GenConvTransposeBlock
from models.utils import get_noise, weights_init
from models.DCGAN.mnist import (Generator as MnistGenerator, 
                                Discriminator as MnistDiscriminator)
from models.DCGAN.celeba import (Generator as CelebaGenerator, 
                                 Discriminator as CelebaDiscriminator)
from models.DCGAN.train import train
%matplotlib inline

# Pick the fastest available accelerator: CUDA GPU first, then Apple-silicon MPS,
# falling back to the CPU when neither is present.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

## MNIST dataset
We train a DCGAN that generates handwritten digits, using the well-known [MNIST](https://yann.lecun.com/exdb/mnist/) dataset as training data. The workflow has four steps:
1. Define the parameters for the model, the optimizer, and the training loop.
2. Wrap the training dataset in a [PyTorch `DataLoader`](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html), which iterates over it in shuffled mini-batches.
3. Initialize the Generator and the Discriminator, following the architecture proposed in the [DCGAN paper](https://arxiv.org/abs/1511.06434).
4. Train the two models adversarially and visualize generated samples at regular intervals.

In [2]:
# --- Training hyper-parameters ---
# BCEWithLogitsLoss applies the sigmoid itself, so the discriminator is expected
# to return raw logits.
criterion = nn.BCEWithLogitsLoss()
batch_size = 128
display_step = 2500  # NOTE: the MNIST training cell redefines this before it is used

# --- Model hyper-parameters ---
z_dim = 128       # length of each latent noise vector
hidden_dim = 64   # width argument passed to the MNIST generator/discriminator
n_samples = 100   # number of noise vectors for the smoke-test forward pass below

# --- Adam hyper-parameters (lr and beta_1 follow the DCGAN paper) ---
lr = 2e-4
beta_1, beta_2 = 0.5, 0.999

In [3]:
# ToTensor scales pixels to [0, 1]; Normalize((0.5,), (0.5,)) then maps them to
# [-1, 1], which is why the training cell displays images via (x + 1) / 2.
mnist_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
# download=True fetches MNIST into ../datasets/ on a fresh checkout and is a no-op
# when the files are already present, so Restart & Run All works on any machine.
mnist_dt = MNIST(root="../datasets/", train=True, download=True, transform=mnist_transforms)
mnist_dl = DataLoader(mnist_dt, batch_size=batch_size, shuffle=True)

In [4]:
# Build the MNIST generator/discriminator on `device` (selected in the first cell).
gen = MnistGenerator(z_dim=z_dim, hidden_dim=hidden_dim).to(device)
disc = MnistDiscriminator(im_chan=1, hidden_dim=hidden_dim).to(device)

# Re-initialize the weights with the project's `weights_init`.
# Module.apply modifies parameters in place and returns the module itself.
gen = gen.apply(weights_init)
disc = disc.apply(weights_init)

# One Adam optimizer per network, using the hyper-parameters from the config cell.
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr, betas=(beta_1, beta_2))
disc_opt = torch.optim.Adam(disc.parameters(), lr=lr, betas=(beta_1, beta_2))

# Smoke test: one forward pass through both networks to catch shape mismatches early.
# no_grad avoids building an autograd graph that is never backpropagated.
# NOTE(review): the models are in train mode, so any BatchNorm layers will still
# update their running statistics here — presumably harmless; confirm.
with torch.no_grad():
    noise_input = get_noise(n_samples, z_dim, device)
    fake = gen(noise_input)
    output = disc(fake)
fake.shape, output.shape

In [None]:
# Training-loop settings for the MNIST run.
n_epochs = 100
display_step = 5000       # log mean losses and save sample grids every `display_step` steps
save_step = 10000         # checkpoint cadence, only used when `save_checkpoints` is True
save_checkpoints = False  # set True to write checkpoints to ./assets/DCGAN/ (directory must already exist)
cur_step = 0
generator_loss = 0.0       # running sum of generator losses since the last display
discriminator_loss = 0.0   # running sum of discriminator losses since the last display
steps_since_display = 0    # number of terms in the two running sums above
mean_gen_loss = float("nan")   # last logged means (stored in checkpoints)
mean_disc_loss = float("nan")

for epoch in range(n_epochs):

    for real, _ in tqdm(mnist_dl):
        cur_batch_size = len(real)
        real = real.to(device)

        ## Update discriminator ##
        disc_opt.zero_grad()

        # Get noise corresponding to the current batch size.
        noise = get_noise(cur_batch_size, z_dim, device=device)
        fake = gen(noise)
        # Detach so the discriminator step does not backpropagate into the generator.
        disc_fake = disc(fake.detach())
        disc_fake_loss = criterion(disc_fake, torch.zeros_like(disc_fake))
        disc_real = disc(real)
        # Targets must be shaped like `disc_real`, not `disc_fake`.
        disc_real_loss = criterion(disc_real, torch.ones_like(disc_real))
        disc_loss = (disc_real_loss + disc_fake_loss) / 2

        # No retain_graph: the fakes were detached above and the generator step
        # below builds a fresh graph from new noise.
        disc_loss.backward()
        disc_opt.step()

        ### Update generator ###
        gen_opt.zero_grad()
        noise = get_noise(cur_batch_size, z_dim, device=device)
        fake = gen(noise)
        gen_fake = disc(fake)
        # The generator is rewarded when the discriminator labels its fakes as real.
        gen_fake_loss = criterion(gen_fake, torch.ones_like(gen_fake))
        gen_fake_loss.backward()
        gen_opt.step()

        # Accumulate losses so the logged values are true means over the window.
        discriminator_loss += disc_loss.item()
        generator_loss += gen_fake_loss.item()
        steps_since_display += 1

        ### Visualization code ###
        if cur_step % display_step == 0 and cur_step > 0:
            mean_gen_loss = generator_loss / steps_since_display
            mean_disc_loss = discriminator_loss / steps_since_display
            print(f"Epoch {epoch}, step {cur_step}: Generator loss: {mean_gen_loss:.2f}, discriminator loss: {mean_disc_loss:.2f}")
            generator_loss = 0.0
            discriminator_loss = 0.0
            steps_since_display = 0
            # Real images were normalized to [-1, 1] (see the transform cell); the same
            # shift is applied to the fakes, assuming the generator also outputs [-1, 1].
            show_img_batch((real + 1) / 2, save_path=f"./figures/mnist_epoch_{epoch}_step_{cur_step}_real.jpg")
            show_img_batch((fake.detach() + 1) / 2, save_path=f"./figures/mnist_epoch_{epoch}_step_{cur_step}_fake.jpg")

        ### Checkpointing (disabled by default) ###
        if save_checkpoints and cur_step % save_step == 0 and cur_step > 0:
            torch.save({
                'generator': gen.state_dict(),
                'discriminator': disc.state_dict(),
                'gen_opt': gen_opt.state_dict(),
                'disc_opt': disc_opt.state_dict(),
                'generator_loss': mean_gen_loss,
                'discriminator_loss': mean_disc_loss
            }, f"./assets/DCGAN/mnist_epoch_{epoch}_step_{cur_step}.pth")
            print(f"Models, optimizers, and mean losses saved at step {cur_step}")
        cur_step += 1


## CelebA dataset

[CelebFaces Attributes Dataset (CelebA)](https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) is a large-scale face-attributes dataset with more than **200K** celebrity images, each with **40** attribute annotations. The images cover large pose variations and background clutter. CelebA offers great diversity, a large number of images, and rich annotations, including:

- **10,177** identities
- **202,599** face images
- **5** landmark locations and **40** binary attribute annotations per image


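In this section, we train a DCGAN to generate face images. Each CelebA image is center-cropped to 178×178 pixels and resized to 64×64 before training.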

In [7]:
from torchvision.datasets import ImageFolder

In [None]:
# Crop the central 178x178 square (presumably the full width of the aligned CelebA
# images — confirm), then downsample to 64x64 for the DCGAN models.
# Crop/Resize run on PIL images *before* ToTensor so Resize always antialiases;
# tensor Resize in torchvision < 0.17 defaults to antialias=False and aliases
# when downsampling 178 -> 64.
# NOTE(review): unlike the MNIST transform there is no Normalize here, so pixels
# stay in [0, 1]. If CelebaGenerator ends in tanh (as in the DCGAN paper), consider
# adding transforms.Normalize((0.5,) * 3, (0.5,) * 3) and displaying via (x + 1) / 2.
celeba_transform = transforms.Compose([
    transforms.CenterCrop((178, 178)),
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
])
celeba = ImageFolder("../datasets/celeba", transform=celeba_transform)
celeba_dl = DataLoader(celeba, batch_size=128, shuffle=True)

# Preview one batch to sanity-check the crop and resize.
img_batch, labels = next(iter(celeba_dl))
show_img_batch(img_batch, size=(3, 64, 64))

In [15]:
# --- Training hyper-parameters (CelebA run) ---
criterion = nn.BCEWithLogitsLoss()  # discriminator is expected to return raw logits
batch_size = 128   # NOTE: celeba_dl above was built with a literal batch size of 128
display_step = 2500

# --- Model hyper-parameters ---
im_chan = 3                # colour channels of the CelebA images
z_dim = 128                # length of each latent noise vector
size = (im_chan, 64, 64)   # image shape passed to show_img_batch (matches the 64x64 transform)

# --- Adam hyper-parameters ---
lr, beta_1, beta_2 = 0.0002, 0.5, 0.999

In [16]:
# Build both CelebA networks on the selected device.
gen = CelebaGenerator(z_dim=z_dim, im_chan=im_chan).to(device)
disc = CelebaDiscriminator(im_chan=im_chan).to(device)

# Module.apply re-initializes the parameters in place (and returns the module),
# so the optimizers created below see the re-initialized tensors.
gen.apply(weights_init)
disc.apply(weights_init)

# One Adam optimizer per network, sharing the config-cell hyper-parameters.
gen_opt = torch.optim.Adam(params=gen.parameters(), lr=lr, betas=(beta_1, beta_2))
disc_opt = torch.optim.Adam(params=disc.parameters(), lr=lr, betas=(beta_1, beta_2))

In [None]:
# Training-loop settings for the CelebA run (display_step comes from the config cell).
n_epochs = 100
cur_step = 0
generator_loss = 0.0       # running sum of generator losses since the last display
discriminator_loss = 0.0   # running sum of discriminator losses since the last display
steps_since_display = 0    # number of terms in the two running sums above

for epoch in range(n_epochs):

    for real, _ in tqdm(celeba_dl):
        cur_batch_size = len(real)
        real = real.to(device)

        ## Update discriminator ##
        disc_opt.zero_grad()

        # Get noise corresponding to the current batch size.
        noise = get_noise(cur_batch_size, z_dim, device=device)
        fake = gen(noise)
        # Detach so the discriminator step does not backpropagate into the generator.
        disc_fake = disc(fake.detach())
        disc_fake_loss = criterion(disc_fake, torch.zeros_like(disc_fake))
        disc_real = disc(real)
        # Targets must be shaped like `disc_real`, not `disc_fake`.
        disc_real_loss = criterion(disc_real, torch.ones_like(disc_real))
        disc_loss = (disc_real_loss + disc_fake_loss) / 2

        # No retain_graph: the fakes were detached above and the generator step
        # below builds a fresh graph from new noise.
        disc_loss.backward()
        disc_opt.step()

        ### Update generator ###
        gen_opt.zero_grad()
        noise = get_noise(cur_batch_size, z_dim, device=device)
        fake = gen(noise)
        gen_fake = disc(fake)
        # The generator is rewarded when the discriminator labels its fakes as real.
        gen_fake_loss = criterion(gen_fake, torch.ones_like(gen_fake))
        gen_fake_loss.backward()
        gen_opt.step()

        # Accumulate losses so the logged values are true means over the window.
        discriminator_loss += disc_loss.item()
        generator_loss += gen_fake_loss.item()
        steps_since_display += 1

        ### Visualization code ###
        if cur_step % display_step == 0 and cur_step > 0:
            mean_gen_loss = generator_loss / steps_since_display
            mean_disc_loss = discriminator_loss / steps_since_display
            print(f"Epoch {epoch}, step {cur_step}: Generator loss: {mean_gen_loss:.2f}, discriminator loss: {mean_disc_loss:.2f}")
            generator_loss = 0.0
            discriminator_loss = 0.0
            steps_since_display = 0
            show_img_batch(real, size=size)
            # Detach: the display should not hold on to the autograd graph.
            show_img_batch(fake.detach(), size=size)
        cur_step += 1
