In [32]:
# model.py

import torch.nn as nn

class Generator(nn.Module):

    def __init__(self, noise_channels=100, img_channels=3):
        super().__init__()
        self.conv_layers = nn.Sequential(

            # 1st fractional strided convolution layer (upsample from 1*1 -> 4*4)
            # Projection layer, to convert the z of 100 inputs to 1024 * 4 * 4 (noise_channels = z_dim)
            # Each input (z) will be actually reshaped to 100 * 1 * 1 (100 channels)
            # (to ensure from 1x1 -> 4x4, with stride = 2 and kernal = 4, we need padding = 0 now (for a x4 increase))
            self._block(in_channels=noise_channels, out_channels=1024, kernel_size=4, stride=2, padding=0),

            # 2nd fractional strided convolution layer (upsample from 4*4 -> 8*8)
            self._block(in_channels=1024, out_channels=512, kernel_size=4, stride=2, padding=1),

            # 3rd fractional strided convolution layer (upsample from 8*8 -> 16*16)
            self._block(in_channels=512, out_channels=256, kernel_size=4, stride=2, padding=1),
            
            # 4th fractional strided convolution layer (upsample from 16*16 -> 32*32)
            self._block(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1),

            # Output fractional strided convolution layer (upsample from 32*32 -> 64*64)
            nn.ConvTranspose2d(in_channels=128, out_channels=img_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )
    
    def _block(self, in_channels, out_channels, kernel_size, stride, padding, batch_norm=True):

        return nn.Sequential(
            nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ) if batch_norm else nn.Sequential(
            nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding),
            nn.ReLU(),
        )

    def forward(self, z):
        return self.conv_layers(z)


class Discriminator(nn.Module):

    def __init__(self, img_channels=3):
        super().__init__()

        self.conv_layers = nn.Sequential(
            
            # 1st fractional strided convolution layer (downsample from 64*64 -> 32*32)
            self._block(in_channels=img_channels, out_channels=128, kernel_size=4, stride=2, padding=1, batch_norm=False),

            # 2nd fractional strided convolution layer (downsample from 32*32 -> 16*16)
            self._block(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1),
            
            # 3rd fractional strided convolution layer (downsample from 16*16 -> 8*8)
            self._block(in_channels=256, out_channels=512, kernel_size=4, stride=2, padding=1),

            # Output fractional strided convolution layer (downsample from 8*8 -> 4*4)
            self._block(in_channels=512, out_channels=1024, kernel_size=4, stride=2, padding=1),
            
            # Classifier
            # No fully connected layer for DCGAN, use another way (instead of nn.Flatten(), nn.Linear(in_features=1024*4*4, out_features=1))
            # Use another convolutional layer (to ensure from 4x4 to 1x1, with stride = 2 and kernal = 4, we need padding = 0 now (for a x4 reduction))
            nn.Conv2d(in_channels=1024, out_channels=1, kernel_size=4, stride=2, padding=0),
            nn.Sigmoid() # ensure prediction is within [0, 1]
        )

    def _block(self, in_channels, out_channels, kernel_size, stride, padding, batch_norm=True):

        return nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(negative_slope=0.2)
        ) if batch_norm else nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding),
            nn.LeakyReLU(negative_slope=0.2)
        )
    
    def forward(self, x):
        return self.conv_layers(x)

In [33]:
# utils.py
import os
from PIL import Image
from matplotlib import pyplot as plt
import torch
import torchvision


def initialize_weights(model): # Weights are initialized from Normal Distribution with mean = 0; standard deviation = 0.02.
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
            nn.init.normal_(m.weight.data, 0.0, 0.02)


def get_dataloader(img_dir, batch_size=64, img_channels=3, img_size=64, transforms=None):
    transforms = torchvision.transforms.Compose(
        [
            torchvision.transforms.Resize(size=(img_size, img_size)),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(
                [0.5 for _ in range(img_channels)],
                [0.5 for _ in range(img_channels)]
            )
        ]
    )

    dataset = torchvision.datasets.ImageFolder(root=img_dir, transform=transforms)

    # Create the dataloader
    NUM_WORKERS = os.cpu_count()

    dataloader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        num_workers=NUM_WORKERS,
        shuffle=True,
        pin_memory=True
    )

    return dataloader


def plot_images(images):
    plt.figure(figsize=(32, 32))
    plt.imshow(torch.cat([
        torch.cat([i for i in images.cpu()], dim=-1),
    ], dim=-2).permute(1, 2, 0).cpu())
    plt.show()


def save_images(images, path, **kwargs):
    grid = torchvision.utils.make_grid(images, **kwargs)
    ndarr = grid.permute(1, 2, 0).to('cpu').numpy()
    im = Image.fromarray(ndarr)
    im.save(path)


def setup_logging(run_name):
    os.makedirs("models", exist_ok=True)
    os.makedirs("results", exist_ok=True)
    os.makedirs(os.path.join("models", run_name), exist_ok=True)
    os.makedirs(os.path.join("results", run_name), exist_ok=True)

In [34]:
# train.py
import os
from pathlib import Path
import argparse
import logging
from tqdm import tqdm

import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter

logging.basicConfig(format="%(asctime)s - %(levelname)s: %(message)s", level=logging.INFO, datefmt="%I:%M:%S")

def train(args):
    # Setup
    setup_logging(args.run_name)
    device = args.device
    dataloader = get_dataloader(args.dataset_path, args.batch_size, args.img_channels, args.img_size, args.transforms)
    generator = Generator(noise_channels=args.z_dim, img_channels=args.img_channels).to(device)
    discriminator = Discriminator(img_channels=args.img_channels).to(device)
    criterion = nn.BCELoss()
    optimizer_D = optim.Adam(params=discriminator.parameters(), lr=args.lr, betas=(args.b1, 0.999)) # b2 kept as default
    optimizer_G = optim.Adam(params=generator.parameters(), lr=args.lr, betas=(args.b1, 0.999))  # b2 kept as default
    fixed_noise = torch.randn(args.batch_size, args.z_dim, 1, 1).to(device)
    logger = SummaryWriter(os.path.join("runs", args.run_name))
    l = len(dataloader)

    initialize_weights(generator)
    initialize_weights(discriminator)
    step = 0

    for epoch in tqdm(range(args.epochs)):
        logging.info(f"Starting epoch {epoch}:")
        pbar = tqdm(dataloader)

        for batch_idx, (x, _) in enumerate(pbar):
            x = x.to(device)
            noise = torch.randn(size = (x.shape[0], args.z_dim, 1, 1)).to(device)

            g_z = generator(noise) # G(z)
            d_x = discriminator(x).reshape(-1) # D(x), reshape from 1*1*1 to 1
            d_g_z = discriminator(g_z).reshape(-1) # D(G(z)), reshape from 1*1*1 to 1

            ### Train the Discriminator: Min -(log(D(x)) + log(1-D(G(Z)))) <---> Max log(D(x)) + log(1-D(G(Z)))
            discriminator.train()
            generator.eval()

            loss_real_D = criterion(d_x, torch.ones_like(d_x)) # -log(D(X))
            loss_fake_D = criterion(d_g_z, torch.zeros_like(d_g_z)) # -log(1-D(G(z)))
            loss_D = (loss_fake_D + loss_real_D )/ 2 # -(log(D(x)) + log(1-D(G(Z))))

            optimizer_D.zero_grad()

            loss_D.backward(retain_graph=True)

            optimizer_D.step()

            ### Train the Generator: Min -log(D(G(z)) <---> Max log(D(G(z))) <---> Min log(1-D(G(z)))
            generator.train()
            discriminator.eval()

            d_g_z_next = discriminator(g_z).reshape(-1) # after training the disc, new D(G(z)), reshape from 1*1*1 to 1
            loss_G = criterion(d_g_z_next, torch.ones_like(d_g_z_next)) # -log(D(G(z)))

            optimizer_G.zero_grad()

            loss_G.backward()

            optimizer_G.step()

            # Logs
            pbar.set_description(f"Epoch [{epoch} / {args.epochs}]")
            pbar.set_postfix(loss_disc = loss_D.item(), loss_gen = loss_G.item())
            scalars = {
                "loss_disc_real": loss_real_D.item(),
                "loss_disc_fake": loss_fake_D.item(),
                "loss_disc": loss_D.item(),
                "loss_gen": loss_G.item()
            }
            logger.add_scalars("Losses", scalars, global_step=epoch * l + batch_idx)

            # Evaluation
            if batch_idx % 50 == 0:
                scalars = {
                    "loss_disc_real": loss_real_D.item(),
                    "loss_disc_fake": loss_fake_D.item(),
                    "loss_disc": loss_D.item(),
                    "loss_gen": loss_G.item()
                }
                logger.add_scalars("Losses", scalars, global_step=epoch * l + batch_idx)
                with torch.no_grad():
                        generator.eval()
                        fake = generator(fixed_noise)
                        # take out (up to) 32 examples
                        img_grid_real = torchvision.utils.make_grid(x[:32], normalize=True)
                        img_grid_fake = torchvision.utils.make_grid(fake[:32], normalize=True)

                        logger.add_image("Real", img_grid_real, global_step=epoch * l + batch_idx)
                        logger.add_image("Fake", img_grid_fake, global_step=epoch * l + batch_idx)

        # save models' checkpoint after each epoch
        torch.save(generator.state_dict(), os.path.join("models", args.run_name, f"gen_ckpt.pt"))
        torch.save(discriminator.state_dict(), os.path.join("models", args.run_name, f"disc_ckpt.pt"))
        torch.save(optimizer_G.state_dict(), os.path.join("models", args.run_name, f"gen_optim.pt"))
        torch.save(optimizer_D.state_dict(), os.path.join("models", args.run_name, f"disc_optim.pt"))        
    

def launch(): # argsparse not usable in notebook, see bottom args as class implementation for launch
    parser = argparse.ArgumentParser()
    args = parser.parse_args()
    # setup
    args.device = "cuda" if torch.cuda.is_available() else "cpu"
    args.dataset_path = Path().cwd().parent / "data" / "celeb_A" # Path().cwd().parent / "data" / "celeb_A"
    args.run_name = "DCGAN"
    args.epochs = 10
    # Hyperparameters following the DCGAN paper
    args.transforms=None # follow default transforms
    args.batch_size = 128
    args.img_size = 64
    args.img_channels = 3 # can be changed wrt to images (althought DCGAN paper, input channels of images is to be 3)
    args.z_dim = 100
    args.lr = 2e-4
    args.b1 = 0.5
    train(args)

In [35]:
def launch():
    class args:
        # setup
        device = "cuda" if torch.cuda.is_available() else "cpu"
        dataset_path = Path().cwd()/ "data" / "celeb_A" # Path().cwd().parent / "data" / "celeb_A"
        run_name = "DCGAN"
        epochs = 10
        # Hyperparameters following the DCGAN paper
        transforms=None # follow default transforms
        batch_size = 128
        img_size = 64
        img_channels = 3 # can be changed wrt to images (althought DCGAN paper, input channels of images is to be 3)
        z_dim = 100
        lr = 2e-4
        b1 = 0.5
    train(args)

In [36]:
launch()

06:26:44 - INFO: Starting epoch 0:
Epoch [0 / 10]:  68%|██████▊   | 1082/1583 [30:50<14:16,  1.71s/it, loss_disc=0.693, loss_gen=0.693]  


KeyboardInterrupt: 