In [1]:
import torch
from torch import nn
import os


class Discriminator(nn.Module):
    """Convolutional discriminator that maps an image batch to raw scores.

    The network is three strided convolution blocks. The first two blocks are
    Conv2d -> BatchNorm2d -> LeakyReLU(0.2); the last block is a bare Conv2d
    with a single output channel, so the output is an unnormalised logit
    (no sigmoid is applied here).

    With the default kernel size 4 and stride 2, a 28x28 input shrinks to
    13x13, then 5x5, then 1x1, so `forward` yields one value per image.

    Args:
        im_chan: Number of channels of the input images.
        hidden_dim: Output channels of the first block; the second block
            uses `hidden_dim * 2`.
    """

    def __init__(self, im_chan: int = 1, hidden_dim: int = 64):
        super().__init__()

        self.discriminator = nn.Sequential(
            self._discriminator_block(im_chan, hidden_dim),
            self._discriminator_block(hidden_dim, hidden_dim * 2),
            self._discriminator_block(hidden_dim * 2, 1, final_layer=True),
        )

    def _discriminator_block(
        self,
        input_channels: int,
        output_channels: int,
        kernel_size: int = 4,
        stride: int = 2,
        final_layer: bool = False,
    ) -> nn.Sequential:
        """Build one discriminator stage.

        Args:
            input_channels: Channels of the incoming feature map.
            output_channels: Channels produced by the convolution.
            kernel_size: Convolution kernel size.
            stride: Convolution stride.
            final_layer: When True, return only the convolution (no
                normalisation or activation).

        Returns:
            An `nn.Sequential` holding the layers of this stage.
        """
        layers = [nn.Conv2d(input_channels, output_channels, kernel_size, stride)]
        if not final_layer:
            layers.append(nn.BatchNorm2d(output_channels))
            layers.append(nn.LeakyReLU(negative_slope=0.2))
        return nn.Sequential(*layers)

    def forward(self, image):
        """Score a batch of images.

        Args:
            image: Tensor of shape (batch, im_chan, height, width).

        Returns:
            Tensor of shape (batch, -1): the final feature map flattened per
            sample.
        """
        disc_pred = self.discriminator(image)
        return disc_pred.view(len(disc_pred), -1)


In [2]:
import torch
from torch import nn
import os


class Generator(nn.Module):
    """Transposed-convolution generator that turns a flat vector into an image.

    The input vector is reshaped to a 1x1 feature map with `input_dim`
    channels and upsampled by four ConvTranspose2d stages. With the kernel
    sizes and strides used below the spatial size grows 1 -> 3 -> 6 -> 13 -> 28.
    The last stage ends in Tanh, so outputs lie in [-1, 1].

    Args:
        input_dim: Length of the per-sample input vector.
        im_chan: Number of channels of the generated image.
        hidden_dim: Base channel width; the stages use 4x, 2x and 1x this
            value before the final projection to `im_chan`.
    """

    def __init__(self, input_dim: int = 10, im_chan: int = 1, hidden_dim: int = 64):
        super().__init__()

        self.input_dim = input_dim
        self.generator = nn.Sequential(
            self._generator_block(input_dim, hidden_dim * 4),  # 1x1 -> 3x3
            self._generator_block(
                hidden_dim * 4, hidden_dim * 2, kernel_size=4, stride=1
            ),  # 3x3 -> 6x6
            self._generator_block(hidden_dim * 2, hidden_dim),  # 6x6 -> 13x13
            self._generator_block(
                hidden_dim, im_chan, kernel_size=4, final_layer=True
            ),  # 13x13 -> 28x28
        )

    def _generator_block(
        self,
        input_channels: int,
        output_channels: int,
        kernel_size: int = 3,
        stride: int = 2,
        final_layer: bool = False,
    ) -> nn.Sequential:
        """Build one generator stage.

        Args:
            input_channels: Channels of the incoming feature map.
            output_channels: Channels produced by the transposed convolution.
            kernel_size: Transposed-convolution kernel size.
            stride: Transposed-convolution stride.
            final_layer: When True, follow the convolution with Tanh instead
                of BatchNorm2d + ReLU.

        Returns:
            An `nn.Sequential` holding the layers of this stage.
        """
        layers = [
            nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride)
        ]
        if final_layer:
            layers.append(nn.Tanh())
        else:
            layers.append(nn.BatchNorm2d(output_channels))
            layers.append(nn.ReLU(inplace=True))
        return nn.Sequential(*layers)

    def forward(self, noise):
        """Generate images from a batch of input vectors.

        Args:
            noise: Tensor viewable as (batch, input_dim).

        Returns:
            Tensor of generated images, (batch, im_chan, H, W).
        """
        # 2D -> 4D: add unit H and W dimensions so ConvTranspose2d can consume it.
        noise = noise.view(len(noise), self.input_dim, 1, 1)  # [B, C, H, W]
        return self.generator(noise)


def create_noise_vector(n_samples: int, input_dim: int, device: str = "cuda"):
    """Sample standard-normal noise and move it to `device`.

    Args:
        n_samples: Number of noise vectors (rows).
        input_dim: Length of each noise vector (columns).
        device: Target device passed to `Tensor.to`.

    Returns:
        Tensor of shape (n_samples, input_dim) on `device`.
    """
    # Sample first, then transfer, so the draw comes from the default generator.
    noise = torch.randn(n_samples, input_dim)
    return noise.to(device)


In [3]:
import torch
from torch import nn
from torchvision import transforms

# --- Data ---
MNIST_SHAPE = (1, 28, 28)  # (channels, height, width) of one sample
N_CLASSES = 10  # digit classes used for the one-hot conditioning vector

# --- Training hyper-parameters ---
N_EPOCHS = 200
Z_DIM = 64  # length of the noise vector fed to the generator
DISPLAY_STEP = 100  # report losses / plot samples every this many batches
BATCH_SIZE = 128
LR = 2e-4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Loss applied to raw discriminator scores (the sigmoid is folded into the loss).
CRITERION = nn.BCEWithLogitsLoss()

# ToTensor yields values in [0, 1]; Normalize with mean 0.5 / std 0.5 maps them
# to [-1, 1]. Mean and std are given as one-element tuples (one entry per
# channel); note that a bare `(0.5)` is just the float 0.5, not a tuple.
TRANSFORMS = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)

# --- Misc / paths ---
DEBUG = False
LOG_PATH = "./logs"  # parent directory for per-run experiment folders
BASE_DIR = "."
CHECKPOINT_DIR = "./checkpoints"  # parent directory for per-run checkpoints

In [4]:
# Importing Modules
import sys
import os
import torch
from torch import nn
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
import torch.nn.functional as F
import datetime
from pathlib import Path

# Seed PyTorch's global random number generator so that random draws made
# through torch (e.g. torch.randn, nn.init.normal_) repeat from run to run.
torch.manual_seed(0)  # Set for our testing purposes, please do not change!


def plot_images_from_tensor(
    image_tensor, num_images=25, size=(1, 28, 28), nrow=5, show=True
):
    """Draw the first `num_images` images of a batch as a single grid.

    Args:
        image_tensor: Batch of images with values in [-1, 1].
        num_images: How many images from the start of the batch to draw.
        size: Not used by this function; kept in the signature for callers.
        nrow: Forwarded to `make_grid` as its `nrow` argument.
        show: When True, call `plt.show()` after drawing.
    """
    # Map [-1, 1] -> [0, 1] for display.
    rescaled = (image_tensor + 1) / 2

    # Drop the autograd graph and bring the data to host memory.
    cpu_images = rescaled.detach().cpu()
    grid = make_grid(cpu_images[:num_images], nrow=nrow)

    # matplotlib expects [H, W, C] rather than [C, H, W].
    hwc_grid = grid.permute(1, 2, 0).squeeze()
    plt.imshow(hwc_grid)
    if not show:
        return
    plt.show()


def weights_init(m):
    """Re-initialise a module in place; meant to be used with `Module.apply`.

    Conv2d / ConvTranspose2d weights and BatchNorm2d weights are drawn from
    N(0, 0.02); BatchNorm2d biases are set to 0. Other modules are untouched.
    """
    conv_types = (nn.Conv2d, nn.ConvTranspose2d)
    if isinstance(m, conv_types):
        nn.init.normal_(m.weight, mean=0.0, std=0.02)
    if isinstance(m, nn.BatchNorm2d):
        nn.init.normal_(m.weight, mean=0.0, std=0.02)
        nn.init.constant_(m.bias, val=0)


def ohe_vector_from_labels(label_tensor, n_classes):
    """One-hot encode integer class labels.

    Args:
        label_tensor: Integer tensor of class indices, shape (*).
        n_classes: Total number of classes.

    Returns:
        Tensor of 0/1 values with shape (*, n_classes).
    """
    one_hot = F.one_hot(label_tensor, num_classes=n_classes)
    return one_hot


"""
x = torch.tensor([4, 3, 2, 1, 0])
F.one_hot(x, num_classes=6)

# Expected result
# tensor([[0, 0, 0, 0, 1, 0],
#         [0, 0, 0, 1, 0, 0],
#         [0, 0, 1, 0, 0, 0],
#         [0, 1, 0, 0, 0, 0],
#         [1, 0, 0, 0, 0, 0]])
"""


def concat_vectors(x, y):
    """Cast both tensors to float and join them along dimension 1.

    Used to attach label information to another tensor: one-hot labels to the
    noise vector for the generator, and label planes to the image channels
    for the discriminator.

    Args:
        x: First tensor.
        y: Second tensor; must match `x` in every dimension except dim 1.

    Returns:
        Float tensor holding `x` followed by `y` along dim 1.
    """
    parts = [x.float(), y.float()]
    return torch.cat(parts, dim=1)


""" 
Concatenation of Multiple Tensor with `torch.cat()`
RULE - Concatenating a list of tensors along a given dimension with `torch.cat()` requires 2 conditions to be satisfied:
    1. All tensors must have the same number of dimensions, and
    2. All dimensions EXCEPT the one being concatenated along must have the same size.

Example: concatenating tensors of shape (32,2,32) and (32,4,32) with torch.cat(..., dim=1) yields (32,6,32).
"""


def calculate_input_dim(z_dim, mnist_shape, n_classes):
    """Compute the conditioned input sizes for the generator and discriminator.

    The generator receives the noise vector with a one-hot label appended,
    and the discriminator receives the image with one extra channel per class.

    Args:
        z_dim: Length of the noise vector.
        mnist_shape: Image shape whose first entry is the channel count,
            e.g. (1, 28, 28).
        n_classes: Number of classes.

    Returns:
        Tuple `(generator_input_dim, discriminator_image_channel)`, i.e.
        `(z_dim + n_classes, mnist_shape[0] + n_classes)`.
    """
    image_channels = mnist_shape[0]
    return z_dim + n_classes, image_channels + n_classes


def init_setting():
    """Create the per-run log and checkpoint directories.

    A timestamp string is built from the current local time and used as the
    name of a sub-directory under both `LOG_PATH` and `CHECKPOINT_DIR`.
    Missing parent directories are created as needed, and existing
    directories are reused.

    Returns:
        Tuple `(experiment_dir, checkpoint_dir, timestr)` where the first two
        are `pathlib.Path` objects and `timestr` is the timestamp string.
    """
    timestr = datetime.datetime.now().strftime("%Y-%m%d_%H%M")

    def _make_run_dir(base_dir):
        # parents=True also creates `base_dir` (and any missing ancestors),
        # so nested base paths work on a fresh checkout.
        run_dir = Path(base_dir).joinpath(timestr)
        run_dir.mkdir(parents=True, exist_ok=True)
        return run_dir

    experiment_dir = _make_run_dir(LOG_PATH)
    checkpoint_dir = _make_run_dir(CHECKPOINT_DIR)

    return experiment_dir, checkpoint_dir, timestr


def save_checkpoint(**kwargs):
    """Placeholder: accepts any keyword arguments, does nothing, returns None."""
    return None


In [5]:
import os
import torch
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader

if __name__ == "__main__":
    # Create the timestamped log / checkpoint directories for this run.
    exp_path, checkpoint_dir, timestr = init_setting()

    # NOTE: download=False means the MNIST files must already exist under ./data.
    dataloader = DataLoader(
        dataset=MNIST(root="data", download=False, transform=TRANSFORMS),
        batch_size=BATCH_SIZE,
        shuffle=True,
    )

    gen_input_dim, disc_input_chan = calculate_input_dim(
        z_dim=Z_DIM, mnist_shape=MNIST_SHAPE, n_classes=N_CLASSES
    )  # (74, 11)

    gen = Generator(input_dim=gen_input_dim).to(DEVICE)
    gen_opt = torch.optim.Adam(params=gen.parameters(), lr=LR)
    disc = Discriminator(im_chan=disc_input_chan).to(DEVICE)
    disc_opt = torch.optim.Adam(params=disc.parameters(), lr=LR)

    # apply() re-initialises the parameters in place, so the optimizers
    # created above keep pointing at the same tensors.
    gen, disc = gen.apply(weights_init), disc.apply(weights_init)

    generator_losses, discriminator_losses = [], []
    cur_step = 0

    for epoch in range(N_EPOCHS):
        for idx, (images, labels) in enumerate(dataloader):
            cur_batch_size = len(images)  # BATCH_SIZE, except possibly the last batch
            images = images.to(DEVICE)

            # Create one-hot vectors from the ground-truth labels, e.g.
            # - labels[0] = 8
            # - one_hot_labels[0] = [0, 0, 0, 0, 0, 0, 0, 0, 1, 0]
            one_hot_labels = ohe_vector_from_labels(
                label_tensor=labels, n_classes=N_CLASSES
            ).to(DEVICE)  # [B, 10]
            image_one_hot_labels = one_hot_labels[..., None, None]  # [B, 10, 1, 1]

            # [B, 10, 1, 1] -> [B, 10, 28, 28]
            image_one_hot_labels = image_one_hot_labels.repeat(
                1, 1, MNIST_SHAPE[1], MNIST_SHAPE[2]
            )  # how many times to repeat each dim

            ### Train Discriminator
            disc_opt.zero_grad()
            fake_noise = create_noise_vector(
                n_samples=cur_batch_size, input_dim=Z_DIM, device=DEVICE
            )  # [B, 64]

            # IMPORTANT:
            # * For the Generator, labels are appended to the end of the noise vectors.
            # * For the Discriminator, labels are appended along the channel dimension.

            # z (noise) [B, 64] + y (true labels) [B, 10]
            noise_and_labels = concat_vectors(fake_noise, one_hot_labels)  # [B, 74]

            # The generator's forward pass reshapes the flat input to [B, 74, 1, 1].
            fake = gen(noise_and_labels)  # conditioned fake images, [B, 1, 28, 28]

            # [B, 1, 28, 28] + [B, 10, 28, 28] = [B, 11, 28, 28] (both)
            fake_image_and_labels = concat_vectors(fake, image_one_hot_labels)
            real_image_and_labels = concat_vectors(images, image_one_hot_labels)

            # Discriminator predictions. The fake batch is detached so this
            # step does not backpropagate into the generator.
            disc_fake_pred = disc(fake_image_and_labels.detach())  # [B, 1]
            disc_real_pred = disc(real_image_and_labels)  # [B, 1]

            # Fakes should be scored 0, reals 1.
            disc_fake_loss = CRITERION(disc_fake_pred, torch.zeros_like(disc_fake_pred))
            disc_real_loss = CRITERION(disc_real_pred, torch.ones_like(disc_real_pred))
            disc_loss = (disc_fake_loss + disc_real_loss) / 2

            # Backpropagate and update the discriminator. retain_graph is not
            # needed: this graph is never traversed again, because the
            # generator step below runs a fresh forward pass through `disc`.
            disc_loss.backward()
            disc_opt.step()

            # Record the discriminator loss for the running average.
            discriminator_losses.append(disc_loss.item())

            ### Train Generator
            gen_opt.zero_grad()

            # [B, 1, 28, 28] + [B, 10, 28, 28] -> [B, 11, 28, 28]
            # Not detached this time, so gradients reach the generator.
            fake_image_and_labels = concat_vectors(fake, image_one_hot_labels)
            disc_fake_pred = disc(fake_image_and_labels)  # [B, 1]

            # The generator wants its fakes to be scored as real (1).
            gen_loss = CRITERION(disc_fake_pred, torch.ones_like(disc_fake_pred))
            gen_loss.backward()
            gen_opt.step()

            # Record the generator loss for the running average.
            generator_losses.append(gen_loss.item())

            if idx % DISPLAY_STEP == 0 and idx > 0:
                # Mean loss over the most recent DISPLAY_STEP steps.
                gen_mean = sum(generator_losses[-DISPLAY_STEP:]) / DISPLAY_STEP
                disc_mean = sum(discriminator_losses[-DISPLAY_STEP:]) / DISPLAY_STEP
                print(
                    f"Epoch {epoch}: | Step: {idx} | Gen Loss: {gen_mean} | Disc Loss: {disc_mean}"
                )

                plot_images_from_tensor(fake)
                plot_images_from_tensor(images)
            cur_step += 1

        # Saved once per epoch; the same file is overwritten each time.
        checkpoint = {
            "epoch": epoch,
            "gen_state_dict": gen.state_dict(),
            "disc_state_dict": disc.state_dict(),
            "gen_optimizer": gen_opt.state_dict(),
            "disc_optimizer": disc_opt.state_dict(),
        }  # save state dictionary
        torch.save(checkpoint, f"{checkpoint_dir}/model.pth")


KeyboardInterrupt: 