In [1]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms as tt
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torchvision.datasets.folder import pil_loader
from torchvision.transforms import Resize, ToTensor, Normalize

#from vae_utils import get_vector_from_label, add_vector_to_images, morph_faces

print(torch.__version__)
print(torchvision.__version__)

# Prefer CUDA, then Apple-silicon MPS, and fall back to the CPU so the notebook
# still runs on machines that have neither accelerator.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

2.0.0
0.15.1


In [None]:
# ---- Hyperparameters ---------------------------------------------------------
SEED = 42               # seeds torch's RNGs: weight init, noise and DataLoader shuffling
LEARNING_RATE = 1e-4    # Adam learning rate for both networks
BATCH_SIZE = 64
IMAGE_SIZE = 64         # images are resized to IMAGE_SIZE x IMAGE_SIZE
CHANNELS_IMG = 3        # RGB
NUM_CLASSES = 10        # number of conditioning labels (CelebA attributes) = embedding rows
GEN_EMBEDDING = 100     # size of the generator's label embedding
Z_DIM = 100             # latent noise dimension
NUM_EPOCHS = 5
FEATURES_DISC = 64      # base channel width of the critic
FEATURES_GEN = 64       # base channel width of the generator
CRITIC_ITERATIONS = 5   # critic updates per generator update
LAMBDA_GP = 10          # weight of the gradient-penalty term in the critic loss
# No weight clipping: the gradient penalty takes its place.

torch.manual_seed(SEED)

In [2]:
import pandas as pd

# CelebA attribute table: one row per image (`image_id`), one +1/-1 column per attribute.
df = pd.read_csv(
    filepath_or_buffer='/Users/parkermoesta/datasets/CelebA/list_attr_celeba.csv'
)
df.head()

Unnamed: 0,image_id,5_o_Clock_Shadow,Arched_Eyebrows,Attractive,Bags_Under_Eyes,Bald,Bangs,Big_Lips,Big_Nose,Black_Hair,...,Sideburns,Smiling,Straight_Hair,Wavy_Hair,Wearing_Earrings,Wearing_Hat,Wearing_Lipstick,Wearing_Necklace,Wearing_Necktie,Young
0,000001.jpg,-1,1,1,-1,-1,-1,-1,-1,-1,...,-1,1,1,-1,1,-1,1,-1,-1,1
1,000002.jpg,-1,-1,-1,1,-1,-1,-1,1,-1,...,-1,1,-1,-1,-1,-1,-1,-1,-1,1
2,000003.jpg,-1,-1,-1,-1,-1,-1,1,-1,-1,...,-1,-1,-1,1,-1,-1,-1,-1,-1,1
3,000004.jpg,-1,-1,1,-1,-1,-1,-1,-1,-1,...,-1,-1,1,-1,1,-1,1,1,-1,1
4,000005.jpg,-1,1,1,-1,-1,-1,1,-1,-1,...,-1,-1,-1,-1,-1,-1,1,-1,-1,1


In [3]:
# list of attributes
attributes = df.columns[1:]
attributes

Index(['5_o_Clock_Shadow', 'Arched_Eyebrows', 'Attractive', 'Bags_Under_Eyes',
       'Bald', 'Bangs', 'Big_Lips', 'Big_Nose', 'Black_Hair', 'Blond_Hair',
       'Blurry', 'Brown_Hair', 'Bushy_Eyebrows', 'Chubby', 'Double_Chin',
       'Eyeglasses', 'Goatee', 'Gray_Hair', 'Heavy_Makeup', 'High_Cheekbones',
       'Male', 'Mouth_Slightly_Open', 'Mustache', 'Narrow_Eyes', 'No_Beard',
       'Oval_Face', 'Pale_Skin', 'Pointy_Nose', 'Receding_Hairline',
       'Rosy_Cheeks', 'Sideburns', 'Smiling', 'Straight_Hair', 'Wavy_Hair',
       'Wearing_Earrings', 'Wearing_Hat', 'Wearing_Lipstick',
       'Wearing_Necklace', 'Wearing_Necktie', 'Young'],
      dtype='object')

In [None]:
# Preprocessing: resize to IMAGE_SIZE x IMAGE_SIZE, convert to a [0, 1] tensor,
# then normalize every channel with mean 0.5 / std 0.5 into [-1, 1], the same
# range as the generator's Tanh output.
transforms = tt.Compose(
    [
        tt.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        tt.ToTensor(),
        tt.Normalize([0.5] * CHANNELS_IMG, [0.5] * CHANNELS_IMG),
    ]
)

# The dataset and dataloader are built from CelebADataset further down, which
# also supplies the attribute labels. No ImageFolder dataset is built here:
# ImageFolder needs one sub-directory per class, and a dataset built here would
# be overwritten before use anyway.

In [None]:
class CelebADataset(torch.utils.data.Dataset):
    """CelebA images paired with a chosen subset of binary attributes.

    ``root_dir`` must contain ``list_attr_celeba.csv`` (an ``image_id`` column
    plus one +1/-1 column per attribute) and an ``img_align_celeba`` folder
    holding the image files named in ``image_id``.

    Each item is ``(image, attrs)``:
      * ``image`` is the RGB image passed through ``transform`` (left as a PIL
        image when ``transform`` is None);
      * ``attrs`` is a float32 array of length ``len(attributes)`` with -1
        mapped to 0, i.e. a multi-hot vector in the order given by ``attributes``.
    """

    def __init__(self, root_dir, attributes, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_dir = os.path.join(root_dir, 'img_align_celeba')
        attr_table = pd.read_csv(os.path.join(root_dir, 'list_attr_celeba.csv'))
        # Keep only the requested attribute columns and map the -1/+1 encoding to 0/1.
        self.df = attr_table.set_index('image_id')[list(attributes)].replace(-1, 0)
        self.filenames = self.df.index
        # Materialize the label matrix once instead of an ``iloc`` row lookup
        # plus dtype conversion for every sample.
        self.labels = self.df.to_numpy(dtype=np.float32)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.filenames[idx])
        # pil_loader reads the file inside a ``with`` block and returns an RGB
        # image, so no file handle stays open per sample (a bare Image.open is
        # lazy and keeps the file open until the image is loaded).
        image = pil_loader(img_path)
        if self.transform:
            image = self.transform(image)
        # Copy so callers never receive a view into the shared label matrix.
        return image, self.labels[idx].copy()

In [None]:
# Conditioning attributes: each image's label is a 0/1 vector over these columns.
# The models build their label embeddings with NUM_CLASSES rows, so the two
# counts must agree.
attributes = ['Attractive', 'Bags_Under_Eyes', 'Bald', 'Bangs', 'Big_Lips', 'Big_Nose', 'Black_Hair', 'Blond_Hair', 'Chubby', 'Double_Chin']
assert len(attributes) == NUM_CLASSES, "NUM_CLASSES must equal the number of conditioning attributes"

dataset = CelebADataset('/Users/parkermoesta/datasets/CelebA/', attributes, transform=transforms)
# drop_last=True keeps every batch at exactly BATCH_SIZE samples, so noise
# tensors sized by BATCH_SIZE always match the real batch.
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

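## Changes relative to the original WGAN-GP code
1. Added a label-embedding layer (`nn.Embedding`) to both the critic (discriminator) and the generator.
2. The critic's first convolution now takes `channel_img + 1` input channels. The extra channel holds the label embedding, reshaped to an `img_size x img_size` map.
3. The label embedding is concatenated along the channel dimension with the image in the critic, and with the noise vector `z` in the generator.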

In [None]:
class Discriminator(nn.Module):
    """Label-conditioned WGAN critic for square ``img_size x img_size`` images.

    The label embedding is reshaped into one ``img_size x img_size`` map and
    stacked onto the image as an extra channel, so the first convolution takes
    ``channel_img + 1`` channels. For 64x64 inputs the output is one unbounded
    score per sample (no sigmoid), shape ``(N, 1, 1, 1)``.

    Args:
        channel_img: number of image channels (3 for RGB).
        features_d: base channel width; doubled by each downsampling block.
        num_classes: number of labels / attributes covered by the embedding.
        img_size: spatial size of the input images (must match ``x``).
    """

    def __init__(self, channel_img, features_d, num_classes, img_size):
        super(Discriminator, self).__init__()
        self.img_size = img_size
        self.disc = nn.Sequential(
            # Input: N x (channel_img + 1) x 64 x 64 (image channels + 1 label channel)
            nn.Conv2d(
                channel_img + 1, features_d, kernel_size=4, stride=2, padding=1
            ),  # 32x32
            nn.LeakyReLU(0.2),
            self.d_block(features_d, features_d * 2, 4, 2, 1),  # out: (N, features_d*2, 16, 16)
            self.d_block(features_d * 2, features_d * 4, 4, 2, 1),  # out: (N, features_d*4, 8, 8)
            self.d_block(features_d * 4, features_d * 8, 4, 2, 1),  # out: (N, features_d*8, 4, 4)
            nn.Conv2d(features_d * 8, 1, kernel_size=4, stride=2, padding=0),  # out: (N, 1, 1, 1)
        )
        # One img_size*img_size vector per label, later reshaped into an image channel.
        self.embed = nn.Embedding(num_classes, img_size * img_size)

    def d_block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Conv -> InstanceNorm (affine) -> LeakyReLU(0.2) downsampling block."""
        return nn.Sequential(
            nn.Conv2d(
                in_channels, out_channels, kernel_size, stride, padding, bias=False,
            ),
            # Instance norm normalizes each sample on its own; batch norm would couple
            # samples, which clashes with the per-sample gradient penalty.
            nn.InstanceNorm2d(out_channels, affine=True),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def _embed_labels(self, labels):
        """Return ``(N, img_size*img_size)`` embeddings for index or multi-hot labels."""
        if labels.is_floating_point():
            # (N, num_classes) one-/multi-hot rows: weighted sum of embedding rows.
            # For a one-hot row this equals the plain embedding lookup.
            return labels.to(self.embed.weight.dtype) @ self.embed.weight
        return self.embed(labels)

    def forward(self, x, labels):
        """Score images ``x`` of shape (N, channel_img, img_size, img_size).

        ``labels`` may be integer class indices of shape ``(N,)`` or a float
        one-/multi-hot matrix of shape ``(N, num_classes)``. An example of the
        latter is the 0/1 attribute vectors returned by ``CelebADataset``.
        """
        embedding = self._embed_labels(labels).view(
            labels.shape[0], 1, self.img_size, self.img_size
        )  # (N, 1, img_size, img_size)
        x = torch.cat([x, embedding], dim=1)  # (N, channel_img + 1, img_size, img_size)
        return self.disc(x)

In [None]:
class Generator(nn.Module):
    """Label-conditioned DCGAN-style generator.

    The noise ``x`` (N, z_dim, 1, 1) is concatenated along the channel axis
    with an ``embed_size``-dim label embedding and upsampled by transposed
    convolutions 1 -> 4 -> 8 -> 16 -> 32 -> 64. The output has ``channels_img``
    channels with values in [-1, 1] (Tanh).

    Args:
        z_dim: noise dimension.
        channels_img: number of output image channels.
        features_g: base channel width (the first block outputs ``features_g * 16``).
        num_classes: number of labels / attributes covered by the embedding.
        img_size: stored on the module. Note that the layer stack itself always
            produces 64x64 images.
        embed_size: dimension of the label embedding.
    """

    def __init__(self, z_dim, channels_img, features_g, num_classes, img_size, embed_size):
        super(Generator, self).__init__()
        self.img_size = img_size

        self.gen = nn.Sequential(
            # Input: N x (z_dim + embed_size) x 1 x 1 (noise concatenated with label embedding)
            self.gen_block(z_dim + embed_size, features_g * 16, 4, 1, 0),  # N x f_g*16 x 4 x 4
            self.gen_block(features_g * 16, features_g * 8, 4, 2, 1),  # N x f_g*8 x 8 x 8
            self.gen_block(features_g * 8, features_g * 4, 4, 2, 1),  # N x f_g*4 x 16 x 16
            self.gen_block(features_g * 4, features_g * 2, 4, 2, 1),  # N x f_g*2 x 32 x 32
            nn.ConvTranspose2d(
                features_g * 2, channels_img, kernel_size=4, stride=2, padding=1
            ),  # N x channels_img x 64 x 64
            nn.Tanh(),  # [-1, 1]
        )
        self.embed = nn.Embedding(num_classes, embed_size)

    def gen_block(self, in_channels, out_channels, kernel_size, stride, padding):
        """ConvTranspose -> BatchNorm -> ReLU upsampling block."""
        return nn.Sequential(
            nn.ConvTranspose2d(
                in_channels, out_channels, kernel_size, stride, padding, bias=False,
            ),
            nn.BatchNorm2d(out_channels),  # BN's own shift makes a conv bias redundant
            nn.ReLU(inplace=True),
        )

    def _embed_labels(self, labels):
        """Return ``(N, embed_size)`` embeddings for index or multi-hot labels."""
        if labels.is_floating_point():
            # (N, num_classes) one-/multi-hot rows: weighted sum of embedding rows.
            # For a one-hot row this equals the plain embedding lookup.
            return labels.to(self.embed.weight.dtype) @ self.embed.weight
        return self.embed(labels)

    def forward(self, x, labels):
        """Generate images from noise ``x`` (N, z_dim, 1, 1) conditioned on ``labels``.

        ``labels`` may be integer class indices of shape ``(N,)`` or a float
        one-/multi-hot matrix of shape ``(N, num_classes)``.
        """
        embedding = self._embed_labels(labels)[:, :, None, None]  # (N, embed_size, 1, 1)
        x = torch.cat([x, embedding], dim=1)  # (N, z_dim + embed_size, 1, 1)
        return self.gen(x)

In [None]:
def initialize_weights(model):
    """Initialize the conv and batch-norm layers of ``model`` in place (DCGAN scheme).

    * ``Conv2d`` / ``ConvTranspose2d`` weights are drawn from N(0, 0.02).
    * ``BatchNorm2d`` scale (weight) is drawn from N(1, 0.02) and its shift
      (bias) is set to 0. Each normalized activation therefore starts close to
      the identity instead of being scaled towards zero.

    Other modules (e.g. ``nn.Embedding``, ``nn.InstanceNorm2d``) and conv biases
    keep PyTorch's default initialization.
    """
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight, 0.0, 0.02)
        elif isinstance(m, nn.BatchNorm2d) and m.weight is not None:  # skip affine=False
            nn.init.normal_(m.weight, 1.0, 0.02)
            nn.init.constant_(m.bias, 0.0)

In [None]:
def gradient_penalty(critic, labels, real, fake):
    """WGAN-GP gradient penalty for a label-conditioned critic.

    The critic is evaluated on random per-sample interpolations between
    ``real`` and ``fake``. The result is the mean squared deviation of the
    per-sample L2 norm of the input gradient from 1.

    Args:
        critic: callable mapping ``(images, labels)`` to scores.
        labels: conditioning labels, forwarded unchanged to ``critic``.
        real: batch of real images, shape ``(N, C, H, W)``.
        fake: generated images with the same shape as ``real``. They may be
            attached to the generator graph or detached.

    Returns:
        0-dim tensor ``mean((||grad||_2 - 1) ** 2)``. It is built with
        ``create_graph=True`` so it can be back-propagated into the critic.
    """
    batch_size = real.shape[0]
    # One mixing weight per sample, broadcast over (C, H, W). It is created on
    # the images' own device/dtype instead of a global `device`.
    epsilon = torch.rand(batch_size, 1, 1, 1, device=real.device, dtype=real.dtype)
    interpolated_images = real * epsilon + fake * (1 - epsilon)
    if not interpolated_images.requires_grad:
        # Happens when `fake` is detached: we still need d(score)/d(input).
        interpolated_images.requires_grad_(True)

    mixed_scores = critic(interpolated_images, labels)

    gradient = torch.autograd.grad(
        outputs=mixed_scores,
        inputs=interpolated_images,
        grad_outputs=torch.ones_like(mixed_scores),  # sum of per-sample scores
        create_graph=True,  # the penalty itself must be differentiable w.r.t. the critic
        retain_graph=True,
    )[0]

    gradient_norm = gradient.reshape(batch_size, -1).norm(2, dim=1)  # per-sample L2 norm
    return torch.mean((gradient_norm - 1) ** 2)

In [None]:
# Build both networks on `device`, then re-initialize their conv / batch-norm
# weights with initialize_weights (generator first, then critic).
gen = Generator(
    z_dim=Z_DIM,
    channels_img=CHANNELS_IMG,
    features_g=FEATURES_GEN,
    num_classes=NUM_CLASSES,
    img_size=IMAGE_SIZE,
    embed_size=GEN_EMBEDDING,
).to(device)
critic = Discriminator(
    channel_img=CHANNELS_IMG,
    features_d=FEATURES_DISC,
    num_classes=NUM_CLASSES,
    img_size=IMAGE_SIZE,
).to(device)
initialize_weights(gen)
initialize_weights(critic)

In [None]:
# Adam with beta1 = 0 and beta2 = 0.9 (the WGAN-GP paper's settings) for both
# networks. No loss criterion is needed: the Wasserstein losses are computed
# directly from critic scores in the training loop.
opt_gen, opt_critic = (
    torch.optim.Adam(net.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.9))
    for net in (gen, critic)
)

In [None]:
# Conditional WGAN-GP training: CRITIC_ITERATIONS critic updates per generator update.
fixed_noise = torch.randn(32, Z_DIM, 1, 1).to(device)  # reused for every logged "Fake" grid
writer_real = SummaryWriter("logs/real")
writer_fake = SummaryWriter("logs/fake")

step = 0  # for printing to tensorboard
fixed_labels = None  # labels from the first real batch, paired with fixed_noise for logging

gen.train()
critic.train()

try:
    for epoch in range(NUM_EPOCHS):
        for batch_idx, (real, labels) in enumerate(dataloader):
            real = real.to(device)
            cur_batch_size = real.shape[0]
            labels = labels.to(device)
            if fixed_labels is None:
                fixed_labels = labels[: fixed_noise.shape[0]].clone()

            ## Train Critic: max E[critic(real)] - E[critic(fake)] - LAMBDA_GP * GP
            for _ in range(CRITIC_ITERATIONS):
                # Sized to the actual batch (the last one can be smaller than
                # BATCH_SIZE) and conditioned on the same labels as `real`.
                noise = torch.randn(cur_batch_size, Z_DIM, 1, 1, device=device)
                fake = gen(noise, labels)
                critic_real = critic(real, labels).reshape(-1)  # flatten
                critic_fake = critic(fake, labels).reshape(-1)  # flatten
                gp = gradient_penalty(critic, labels, real, fake)
                loss_critic = -(torch.mean(critic_real) - torch.mean(critic_fake)) + LAMBDA_GP * gp
                critic.zero_grad()
                loss_critic.backward()
                opt_critic.step()

            ## Train Generator: min -E[critic(gen_fake)]
            # Fresh sample: the critic's last backward() already freed the graph of `fake`.
            noise = torch.randn(cur_batch_size, Z_DIM, 1, 1, device=device)
            fake = gen(noise, labels)
            loss_gen = -torch.mean(critic(fake, labels).reshape(-1))
            gen.zero_grad()  # drop gradients left by the critic step / previous batches
            loss_gen.backward()
            opt_gen.step()

            # print losses occasionally and log image grids to tensorboard
            if batch_idx % 100 == 0:
                print(
                    f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(dataloader)} "
                    f"Loss D: {loss_critic.item():.4f}, loss G: {loss_gen.item():.4f}"
                )

                with torch.no_grad():
                    fake = gen(fixed_noise[: fixed_labels.shape[0]], fixed_labels)
                    # take out (up to) 32 examples
                    img_grid_real = torchvision.utils.make_grid(real[:32], normalize=True)
                    img_grid_fake = torchvision.utils.make_grid(fake[:32], normalize=True)

                    writer_real.add_image("Real", img_grid_real, global_step=step)
                    writer_fake.add_image("Fake", img_grid_fake, global_step=step)
                step += 1
finally:
    # Flush pending events to disk even if training is interrupted.
    writer_real.close()
    writer_fake.close()

In [None]:
# Conditional WGAN-GP training: CRITIC_ITERATIONS critic updates per generator update.
fixed_noise = torch.randn(32, Z_DIM, 1, 1).to(device)  # reused for every logged "Fake" grid
writer_real = SummaryWriter("logs/real")
writer_fake = SummaryWriter("logs/fake")

step = 0  # for printing to tensorboard
# Labels for the logged fakes are taken from real data, so they have the same
# format as the training labels and stay fixed across steps.
fixed_labels = None

gen.train()
critic.train()

try:
    for epoch in range(NUM_EPOCHS):
        for batch_idx, (real, labels) in enumerate(dataloader):
            real = real.to(device)
            cur_batch_size = real.shape[0]
            labels = labels.to(device)
            if fixed_labels is None:
                fixed_labels = labels[: fixed_noise.shape[0]].clone()

            ## Train Critic: max E[critic(real)] - E[critic(fake)] - LAMBDA_GP * GP
            for _ in range(CRITIC_ITERATIONS):
                # Noise matches the actual batch size (the last batch can be
                # smaller than BATCH_SIZE) and the fakes share the real labels.
                noise = torch.randn(cur_batch_size, Z_DIM, 1, 1, device=device)
                fake = gen(noise, labels)
                critic_real = critic(real, labels).reshape(-1)  # flatten
                critic_fake = critic(fake, labels).reshape(-1)  # flatten
                gp = gradient_penalty(critic, labels, real, fake)
                loss_critic = -(torch.mean(critic_real) - torch.mean(critic_fake)) + LAMBDA_GP * gp
                critic.zero_grad()
                # No retain_graph needed: every iteration (and the generator step
                # below) builds its own graph from a fresh `fake`.
                loss_critic.backward()
                opt_critic.step()

            ## Train Generator: min -E[critic(gen_fake)]
            noise = torch.randn(cur_batch_size, Z_DIM, 1, 1, device=device)
            fake = gen(noise, labels)
            output = critic(fake, labels).reshape(-1)
            loss_gen = -torch.mean(output)  # we want to minimize the loss
            gen.zero_grad()
            loss_gen.backward()
            opt_gen.step()

            # print losses occasionally and log image grids to tensorboard
            if batch_idx % 100 == 0:
                print(
                    f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(dataloader)} "
                    f"Loss D: {loss_critic.item():.4f}, loss G: {loss_gen.item():.4f}"
                )

                with torch.no_grad():
                    # Fixed noise + fixed real labels: comparable grids across steps,
                    # and the training-loop `labels` is left untouched.
                    fake = gen(fixed_noise[: fixed_labels.shape[0]], fixed_labels)
                    # take out (up to) 32 examples
                    img_grid_real = torchvision.utils.make_grid(real[:32], normalize=True)
                    img_grid_fake = torchvision.utils.make_grid(fake[:32], normalize=True)

                    writer_real.add_image("Real", img_grid_real, global_step=step)
                    writer_fake.add_image("Fake", img_grid_fake, global_step=step)
                step += 1
finally:
    # Flush pending events to disk even if training is interrupted.
    writer_real.close()
    writer_fake.close()


In [None]:
# Persist only the learned parameters (state dicts) of both networks,
# written to the current working directory.
torch.save(obj=gen.state_dict(), f='generator.pth')
torch.save(obj=critic.state_dict(), f='critic.pth')