In [None]:
import os
import time
from datetime import datetime
import pandas as pd
from PIL import Image
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.utils as torch_utils
from torch.utils.data import Dataset
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation

In [None]:
# custom weights initialization
# Reference (PyTorch Tutorials)
def weights_init(m):
    """Re-initialise one layer in place; intended for ``nn.Module.apply``.

    Dispatch is on the layer's class name:

    * names containing ``"Conv"``: weights drawn from N(0.0, 0.02); the bias
      is left untouched.
    * names containing ``"BatchNorm"``: weights drawn from N(1.0, 0.02) and
      the bias set to 0.
    * anything else is left as it is.

    Args:
        m: the module handed in by ``apply``.
    """
    layer_type = type(m).__name__
    if "Conv" in layer_type:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        return
    if "BatchNorm" in layer_type:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

def save_images(images, output_dir):
    """Write every image in ``images`` to ``output_dir`` as a PNG file.

    Each entry may be anything ``np.asarray`` can convert (NumPy array, CPU
    tensor, ...). Both channel-last (H, W, C) and channel-first (C, H, W)
    layouts are accepted: ``plt.imsave`` only understands channel-last data,
    so channel-first input is transposed before saving.

    Args:
        images: iterable of 2-D or 3-D images.
        output_dir: destination directory; created if it does not exist.

    Returns:
        True once every image has been written.
    """
    print('Saving Images...')
    os.makedirs(output_dir, exist_ok=True)

    for idx, img in enumerate(images):
        pixels = np.asarray(img)
        # A leading axis of size 1/3/4 combined with a trailing axis that is
        # not a valid channel count means the data is channel-first.
        if pixels.ndim == 3 and pixels.shape[0] in (1, 3, 4) and pixels.shape[-1] not in (3, 4):
            pixels = np.transpose(pixels, (1, 2, 0))
        # imsave wants plain (H, W) for single-channel data.
        if pixels.ndim == 3 and pixels.shape[-1] == 1:
            pixels = pixels[..., 0]

        # Microsecond timestamp plus index keeps names unique within a call.
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S_%f')
        filename = f'image_{timestamp}_{idx}.png'
        filepath = os.path.join(output_dir, filename)
        plt.imsave(filepath, pixels)
    print("Done!")
    return True

In [None]:
# Discriminator
class Discriminator(nn.Module):
    """Conditional discriminator for 64x64 three-channel images.

    The label vector is projected to a single 64x64 plane and stacked onto
    the image as a fourth channel. Four stride-2 conv blocks followed by one
    more stride-2 convolution reduce the 64x64 input to a 1x1 map with 1024
    channels, which a linear layer and a sigmoid turn into one score per
    sample.

    Args:
        classes: width of the label vector passed to ``forward``.
    """

    def __init__(self, classes) -> None:
        super().__init__()
        self.classes = classes

        # Label vector -> one 64x64 conditioning plane.
        self.embedding = nn.Linear(classes, 1*64*64)

        # 4 input channels = 3 image channels + 1 label plane.
        widths = [4, 64, 128, 256, 512]
        stages = [
            self.conv_block(n_in, n_out)
            for n_in, n_out in zip(widths[:-1], widths[1:])
        ]

        self.classifier = nn.Sequential(
            *stages,
            nn.Conv2d(512, 1024, (5, 5), 2, 1),
            nn.BatchNorm2d(1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Flatten(),
            nn.Dropout(0.4),
            nn.Linear(1024, 1),
            nn.Sigmoid(),
        )

    def conv_block(self, in_channels, out_channels):
        """Return a 5x5 stride-2 convolution + batch norm + LeakyReLU(0.2) stage."""
        layers = [
            nn.Conv2d(in_channels, out_channels, (5, 5), 2, 2),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        return nn.Sequential(*layers)

    def forward(self, x, label):
        """Score ``x`` conditioned on ``label``.

        Args:
            x: image batch; concatenated on the channel axis with the label
                plane, so it must be (N, 3, 64, 64).
            label: (N, classes) label vectors.

        Returns:
            (N, 1) tensor of sigmoid outputs.
        """
        label_plane = self.embedding(label).view(-1, 1, 64, 64)
        stacked = torch.cat((x, label_plane), dim=1)
        return self.classifier(stacked)

In [None]:
# Generator
class Generator(nn.Module):
    """Conditional generator producing 3x64x64 images in the tanh range.

    A 100-dimensional noise vector is projected to a 512x8x8 volume, the
    label vector to one extra 8x8 plane. The resulting 513-channel volume is
    upsampled 8 -> 16 -> 32 -> 64 by three transposed-convolution stages and
    mapped to three channels by a 1x1 convolution.

    Args:
        classes: width of the label vector passed to ``forward``.
    """

    def __init__(self, classes):
        super().__init__()
        self.classes = classes

        # Label vector -> one 8x8 conditioning plane.
        self.embedding = nn.Linear(classes, 8*8)

        # Noise vector -> 512 feature maps of 8x8.
        self.latent_vector = nn.Sequential(
            nn.Linear(100, 512*8*8),
            nn.LeakyReLU(0.2, inplace=True),
        )

        # 513 input channels = 512 latent maps + 1 label plane.
        widths = [513, 256, 128, 64]
        stages = [
            self.upsample_block(n_in, n_out, 1)
            for n_in, n_out in zip(widths[:-1], widths[1:])
        ]

        self.conv_model = nn.Sequential(
            *stages,
            nn.Conv2d(64, 3, (1, 1), 1, 0),
            nn.Tanh(),
        )

    def upsample_block(self, in_channels, out_channels, padding):
        """Return a 4x4 stride-2 transposed convolution + batch norm + ReLU stage."""
        layers = [
            nn.ConvTranspose2d(in_channels, out_channels, (4, 4), 2, padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        return nn.Sequential(*layers)

    def forward(self, x, label):
        """Generate images from noise ``x`` conditioned on ``label``.

        Args:
            x: (N, 100) noise vectors.
            label: (N, classes) label vectors.

        Returns:
            (N, 3, 64, 64) tensor of tanh outputs.
        """
        latent_maps = self.latent_vector(x).view(-1, 512, 8, 8)
        label_plane = self.embedding(label).view(-1, 1, 8, 8)
        stacked = torch.cat((latent_maps, label_plane), dim=1)
        return self.conv_model(stacked)

In [None]:
# Seed every random number generator this notebook has imported so that a
# fresh "Restart & Run All" reproduces the same run.
manualSeed = 123
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
np.random.seed(manualSeed)  # NumPy's global RNG was previously left unseeded
torch.manual_seed(manualSeed)

In [None]:
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Data pipeline
image_size = 64   # target size for Resize / CenterCrop
batch_size = 64
workers = 2       # DataLoader worker processes

# Model / optimisation
noise_dim = 100   # length of the generator's noise vector
# NOTE(review): the PyTorch DCGAN tutorial uses lr = 2e-4 — confirm 2e-3 is intentional.
lr = 2e-3
beta1 = 0.5       # first Adam beta
num_epochs = 10

In [None]:
# Monitor Progress
def train(dataloader, classes, progress_every=10):
    """Train a conditional GAN and record generator snapshots along the way.

    A fresh ``Discriminator``/``Generator`` pair is built, initialised with
    ``weights_init`` and optimised with Adam and binary cross-entropy for
    ``num_epochs`` epochs. Uses the notebook-level settings ``device``,
    ``noise_dim``, ``lr``, ``beta1`` and ``num_epochs``.

    Args:
        dataloader: sized iterable of batches whose element 0 is the image
            batch and element 1 the matching (N, classes) label vectors.
        classes: width of the label vector.
        progress_every: take a snapshot of the generator on the fixed
            noise/labels every this many iterations (and always on the very
            last iteration). Each snapshot is a grid of ``classes * 10``
            images kept in memory, so raise this for long runs or many
            classes. Defaults to 10, the previously hard-coded cadence.

    Returns:
        Tuple ``(gen_net, G_losses, D_losses, progress)``: the trained
        generator, per-iteration generator and discriminator losses, and the
        list of snapshot grids (one row of 10 samples per class).

    Raises:
        ValueError: if ``progress_every`` is smaller than 1.
    """
    if progress_every < 1:
        raise ValueError("progress_every must be a positive integer")

    samples_per_class = 10
    progress = []

    # Fixed inputs for the snapshots: 10 noise vectors per class, with the
    # one-hot labels grouped class by class so each grid row shows one class.
    fixed_noise = torch.randn(classes * samples_per_class, noise_dim, device=device)
    fixed_labels = torch.eye(classes, device=device).repeat_interleave(samples_per_class, dim=0)

    disc_net = Discriminator(classes)
    gen_net = Generator(classes)
    disc_net.to(device)
    gen_net.to(device)
    disc_net.apply(weights_init)
    gen_net.apply(weights_init)

    criterion = nn.BCELoss()

    disc_optimizer = optim.Adam(disc_net.parameters(), lr=lr, betas=(beta1, 0.999))
    gen_optimizer = optim.Adam(gen_net.parameters(), lr=lr, betas=(beta1, 0.999))

    # One-hot lookup table used to build the conditioning of the fake batch.
    label_basis = torch.eye(classes, device=device)

    # Lists to keep track of progress
    G_losses = []
    D_losses = []
    iters = 0
    total_batches = len(dataloader)

    disc_net.train()
    gen_net.train()
    print("Starting Training Loop...")
    for epoch in range(num_epochs):
        # NOTE(review): a batch holding a single sample would presumably fail
        # in the discriminator's last BatchNorm (1x1 feature map) while in
        # training mode — confirm, and use drop_last=True on the loader if so.
        for i, batch in enumerate(dataloader, 0):
            real_images = batch[0].to(device)
            real_labels = batch[1].to(device)
            num_images = real_images.size(0)

            real_target = torch.ones(num_images, device=device)
            fake_target = torch.zeros(num_images, device=device)

            # ---- Discriminator step: real batch + detached fake batch ----
            disc_net.zero_grad()

            real_output = disc_net(real_images, real_labels).view(-1)
            disc_err_real = criterion(real_output, real_target)

            noise = torch.randn(num_images, noise_dim, device=device)

            # Random one-hot conditioning for the fake batch. Indices are
            # drawn on the CPU (as before) and moved to the model's device
            # before they are used for the lookup.
            # NOTE(review): fakes are always conditioned on one-hot vectors;
            # for multi-hot label sets (e.g. CelebADataset attribute vectors)
            # reals and fakes then carry different label patterns — confirm
            # this is intended.
            indices = torch.randint(0, classes, (num_images,)).to(device)
            noise_labels = label_basis[indices]

            fake = gen_net(noise, noise_labels)

            # detach(): the discriminator loss must not reach the generator.
            fake_output = disc_net(fake.detach(), noise_labels).view(-1)
            disc_err_fake = criterion(fake_output, fake_target)

            disc_err = (disc_err_real + disc_err_fake)/2
            disc_err.backward()
            disc_optimizer.step()

            # ---- Generator step: make the updated discriminator say "real" ----
            gen_net.zero_grad()

            output = disc_net(fake, noise_labels).view(-1)

            gen_err = criterion(output, real_target)
            gen_err.backward()
            gen_optimizer.step()

            # Training Update
            if i % 50 == 0:
                print(
                    f"[{epoch}/{num_epochs}][{i}/{total_batches}]\tLoss_D: {disc_err.item()}\tLoss_G: {gen_err.item()}"
                )

            # Tracking loss
            G_losses.append(gen_err.item())
            D_losses.append(disc_err.item())

            # Tracking Generator Progress
            is_last_step = (epoch == num_epochs - 1) and (i == total_batches - 1)
            if iters % progress_every == 0 or is_last_step:
                gen_net.eval()
                with torch.no_grad():
                    snapshot = gen_net(fixed_noise, fixed_labels).detach().cpu()
                progress.append(
                    torch_utils.make_grid(snapshot, padding=2, nrow=samples_per_class, normalize=True)
                )
                gen_net.train()
            iters += 1

    return gen_net, G_losses, D_losses, progress

In [None]:
def eval(classes, model_path, num_images, output_dir):
    """Generate ``num_images`` sample grids from a saved generator and save them.

    Every grid holds 10 samples per class (one class per row) and is drawn
    from fresh noise, so the saved images differ from each other. Uses the
    notebook-level ``device`` and ``noise_dim``. Note that this function
    shadows Python's built-in ``eval`` inside the notebook.

    Args:
        classes: width of the label vector the generator was trained with.
        model_path: checkpoint written with ``torch.save``; either the
            generator's ``state_dict`` or a whole pickled generator module.
        num_images: number of grids to generate.
        output_dir: directory handed to ``save_images``.

    Returns:
        Whatever ``save_images`` returns.
    """
    samples_per_class = 10
    # One-hot labels grouped class by class: 10 consecutive rows per class.
    fixed_labels = torch.eye(classes, device=device).repeat_interleave(samples_per_class, dim=0)

    gen_net = Generator(classes)
    gen_net.to(device)

    # The checkpoint may hold a state_dict or an entire module; reduce the
    # latter to its state_dict so load_state_dict receives a mapping.
    # map_location lets a GPU-trained checkpoint load on a CPU-only machine.
    checkpoint = torch.load(model_path, map_location=device)
    if isinstance(checkpoint, nn.Module):
        checkpoint = checkpoint.state_dict()
    gen_net.load_state_dict(checkpoint)

    gen_net.eval()
    images = []
    print("Starting Inference Loop...")
    with torch.no_grad():
        for _ in range(num_images):
            # Fresh noise per grid; reusing one noise tensor would produce
            # num_images identical files.
            noise = torch.randn(classes * samples_per_class, noise_dim, device=device)
            fake = gen_net(noise, fixed_labels).cpu()
            grid = torch_utils.make_grid(fake, padding=2, nrow=samples_per_class, normalize=True)
            # make_grid returns (C, H, W); plt.imsave needs (H, W, C).
            images.append(grid.permute(1, 2, 0).numpy())

    return save_images(images, output_dir)

## Dataset 1: CelebA

In [None]:
class CelebADataset(Dataset):
    """Images plus attribute vectors read from an annotation CSV.

    The CSV's first column is the image file name (relative to
    ``image_dir``); every remaining column is treated as one attribute.
    """

    def __init__(self, image_dir, labels_csv, transform=None):
        """
        Args:
            image_dir (string): Directory with all the images.
            labels_csv (string): Path to the csv file with annotations.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.image_dir = image_dir
        self.labels = pd.read_csv(labels_csv)
        self.transform = transform

    def __len__(self):
        """Number of annotated images (rows of the CSV)."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(image, labels)`` for row ``idx``.

        ``image`` is an RGB PIL image, or the result of ``transform`` when one
        was given. ``labels`` is a float32 tensor of the row's attribute
        columns with every -1 replaced by 0.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_name = os.path.join(self.image_dir, self.labels.iloc[idx, 0])
        # Decode inside a context manager so the file handle is closed, and
        # force three channels so downstream code always sees RGB input.
        with Image.open(img_name) as handle:
            image = handle.convert("RGB")

        labels = self.labels.iloc[idx, 1:].to_numpy()
        labels[labels == -1] = 0
        labels = labels.astype('float32')

        if self.transform:
            image = self.transform(image)

        return image, torch.tensor(labels)

In [None]:
# Input locations (Kaggle dataset mount)
data_dir = "/kaggle/input/celeba-dataset/img_align_celeba/img_align_celeba"
label_file = "/kaggle/input/celeba-dataset/list_attr_celeba.csv"

# Output artefacts for this dataset
model_save_path = "/kaggle/working/cgan_celeba.pt"
animation_save_path = "/kaggle/working/cgan_celeba.mp4"
training_plot_save_path = "/kaggle/working/cgan_celeba.png"

# Resize + centre-crop to image_size, then scale every channel to [-1, 1],
# the output range of the generator's tanh.
celeba_dataset = CelebADataset(
    image_dir=data_dir,
    labels_csv=label_file,
    transform=transforms.Compose(
        [
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    ),
)

celeba_dataloader = data.DataLoader(
    dataset=celeba_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=workers,
)

In [None]:
# 40 = number of attribute columns in the CelebA annotation file.
(
    celeba_gen_net,
    celeba_G_losses,
    celeba_D_losses,
    celeba_progress,
) = train(dataloader=celeba_dataloader, classes=40)

In [None]:
# Save generator
# Only the weights are stored: eval() rebuilds a Generator and calls
# load_state_dict() on whatever this file contains.
torch.save(celeba_gen_net.state_dict(), model_save_path)

In [None]:
# Plot Training Graph
# Per-iteration generator and discriminator losses on one set of axes.
fig1, loss_ax = plt.subplots(figsize=(10, 5))
loss_ax.set_title("Generator and Discriminator Loss During Training")
loss_ax.plot(celeba_G_losses, label="G")
loss_ax.plot(celeba_D_losses, label="D")
loss_ax.set_xlabel("iterations")
loss_ax.set_ylabel("Loss")
loss_ax.legend()
fig1.savefig(training_plot_save_path)
plt.show()

In [None]:
# Progress Animation
# One frame per recorded snapshot grid, written out as a video via ffmpeg.
fig2 = plt.figure(figsize=(8, 8))
plt.axis("off")
ims = []
for grid in celeba_progress:
    # Snapshot grids are channel-first; imshow expects channel-last.
    frame = plt.imshow(np.transpose(grid, (1, 2, 0)), animated=True)
    ims.append([frame])
anim = animation.ArtistAnimation(fig2, ims, interval=1000, repeat_delay=1000, blit=True)
writervideo = animation.FFMpegWriter(fps=5)
anim.save(animation_save_path, writer=writervideo)
plt.close(fig2)

In [None]:
# Reload the saved CelebA generator and write 5 sample grids to disk.
eval(
    classes=40,
    model_path='/kaggle/working/cgan_celeba.pt',
    num_images=5,
    output_dir='/kaggle/working/celeba_eval',
)

## Dataset 2: Flower Dataset

In [None]:
def target_to_oh_flower(target, num_classes=5):
    """Convert an integer class index into a one-hot float vector.

    Intended as a dataset ``target_transform``, which calls it with the class
    index only.

    Args:
        target: class index used to select a row of the identity matrix.
        num_classes: length of the returned vector; defaults to 5, the number
            of flower classes used in this notebook.

    Returns:
        float32 tensor of length ``num_classes`` with a single 1 at ``target``.
    """
    return torch.eye(num_classes)[target]

In [None]:
# Input location (Kaggle dataset mount); ImageFolder reads one sub-folder per class.
data_dir = "/kaggle/input/flower-classification-5-classes-roselilyetc/Flower Classification/Flower Classification/Training Data"

# Output artefacts for this dataset
model_save_path = "/kaggle/working/cgan_flowers.pt"
animation_save_path = "/kaggle/working/cgan_flowers.mp4"
training_plot_save_path = "/kaggle/working/cgan_flowers.png"

# Resize + centre-crop to image_size, then scale every channel to [-1, 1];
# the integer class index is turned into a one-hot vector.
# (The variable name's spelling is kept as-is so existing references keep working.)
floower_dataset = datasets.ImageFolder(
    root=data_dir,
    transform=transforms.Compose(
        [
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    ),
    target_transform=target_to_oh_flower,
)

flower_dataloader = data.DataLoader(
    dataset=floower_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=workers,
)

In [None]:
# 5 = number of flower classes (matches target_to_oh_flower).
(
    flower_gen_net,
    flower_G_losses,
    flower_D_losses,
    flower_progress,
) = train(dataloader=flower_dataloader, classes=5)

In [None]:
# Save generator
# Only the weights are stored: eval() rebuilds a Generator and calls
# load_state_dict() on whatever this file contains.
torch.save(flower_gen_net.state_dict(), model_save_path)

In [None]:
# Plot Training Graph
# Per-iteration generator and discriminator losses on one set of axes.
fig1, loss_ax = plt.subplots(figsize=(10, 5))
loss_ax.set_title("Generator and Discriminator Loss During Training")
loss_ax.plot(flower_G_losses, label="G")
loss_ax.plot(flower_D_losses, label="D")
loss_ax.set_xlabel("iterations")
loss_ax.set_ylabel("Loss")
loss_ax.legend()
fig1.savefig(training_plot_save_path)
plt.show()

In [None]:
# Progress Animation
# One frame per recorded snapshot grid, written out as a video via ffmpeg.
fig2 = plt.figure(figsize=(8, 8))
plt.axis("off")
ims = []
for grid in flower_progress:
    # Snapshot grids are channel-first; imshow expects channel-last.
    frame = plt.imshow(np.transpose(grid, (1, 2, 0)), animated=True)
    ims.append([frame])
anim = animation.ArtistAnimation(fig2, ims, interval=1000, repeat_delay=1000, blit=True)
writervideo = animation.FFMpegWriter(fps=5)
anim.save(animation_save_path, writer=writervideo)
plt.close(fig2)

In [None]:
# Reload the saved flower generator and write 5 sample grids to disk.
# The checkpoint is written to "cgan_flowers.pt" (plural) by the save cell,
# so that exact file name has to be used here.
eval(5, '/kaggle/working/cgan_flowers.pt', 5, '/kaggle/working/flower_eval')

## Dataset 3: Shoe, Sandal, Boot

In [None]:
def target_to_oh_shoe(target, num_classes=3):
    """Convert an integer class index into a one-hot float vector.

    Intended as a dataset ``target_transform``, which calls it with the class
    index only.

    Args:
        target: class index used to select a row of the identity matrix.
        num_classes: length of the returned vector; defaults to 3, the number
            of footwear classes used in this notebook.

    Returns:
        float32 tensor of length ``num_classes`` with a single 1 at ``target``.
    """
    return torch.eye(num_classes)[target]

In [None]:
# Input location (Kaggle dataset mount); ImageFolder reads one sub-folder per class.
data_dir = "/kaggle/input/shoe-vs-sandal-vs-boot-dataset-15k-images/Shoe vs Sandal vs Boot Dataset"

# Output artefacts for this dataset
model_save_path = "/kaggle/working/cgan_shoe.pt"
animation_save_path = "/kaggle/working/cgan_shoe.mp4"
training_plot_save_path = "/kaggle/working/cgan_shoe.png"

# Resize + centre-crop to image_size, then scale every channel to [-1, 1];
# the integer class index is turned into a one-hot vector.
shoe_dataset = datasets.ImageFolder(
    root=data_dir,
    transform=transforms.Compose(
        [
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    ),
    target_transform=target_to_oh_shoe,
)

shoe_dataloader = data.DataLoader(
    dataset=shoe_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=workers,
)

In [None]:
# 3 = number of footwear classes (matches target_to_oh_shoe).
(
    shoe_gen_net,
    shoe_G_losses,
    shoe_D_losses,
    shoe_progress,
) = train(dataloader=shoe_dataloader, classes=3)

In [None]:
# Save generator
# Only the weights are stored: eval() rebuilds a Generator and calls
# load_state_dict() on whatever this file contains.
torch.save(shoe_gen_net.state_dict(), model_save_path)

In [None]:
# Plot Training Graph
# Per-iteration generator and discriminator losses on one set of axes.
fig1, loss_ax = plt.subplots(figsize=(10, 5))
loss_ax.set_title("Generator and Discriminator Loss During Training")
loss_ax.plot(shoe_G_losses, label="G")
loss_ax.plot(shoe_D_losses, label="D")
loss_ax.set_xlabel("iterations")
loss_ax.set_ylabel("Loss")
loss_ax.legend()
fig1.savefig(training_plot_save_path)
plt.show()

In [None]:
# Progress Animation
# One frame per recorded snapshot grid, written out as a video via ffmpeg.
fig2 = plt.figure(figsize=(8, 8))
plt.axis("off")
ims = []
for grid in shoe_progress:
    # Snapshot grids are channel-first; imshow expects channel-last.
    frame = plt.imshow(np.transpose(grid, (1, 2, 0)), animated=True)
    ims.append([frame])
anim = animation.ArtistAnimation(fig2, ims, interval=1000, repeat_delay=1000, blit=True)
writervideo = animation.FFMpegWriter(fps=5)
anim.save(animation_save_path, writer=writervideo)
plt.close(fig2)

In [None]:
# Reload the saved footwear generator and write 5 sample grids to disk.
eval(
    classes=3,
    model_path='/kaggle/working/cgan_shoe.pt',
    num_images=5,
    output_dir='/kaggle/working/shoe_eval',
)