In [13]:
import sys
import pandas as pd
import torch
import torch.nn.functional as tf
from torch import nn
from torchvision.datasets import MNIST, FashionMNIST
from torchvision.utils import make_grid
from torchvision import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt

sys.path.append("..")
from models.utils import get_noise, combine_tensors, weights_init
from models.vis_utils import show_img_batch
from models.ConditionalGAN.mnist import (Generator as MnistGenerator,
                                         Discriminator as MnistDiscriminator)
# from models.ConditionalGAN.celeba import (Generator as CelebaGenerator, 
#                                           Discriminator as CelebaDiscriminator)
 
%matplotlib inline

In [14]:
# Pick the compute device: the Apple-silicon GPU backend ("mps") when it is usable,
# otherwise fall back to the CPU.
# `torch.backends.mps.is_available()` is the documented availability check and is also
# present on PyTorch releases that do not yet ship `torch.mps.is_available()`.
device = "mps" if torch.backends.mps.is_available() else "cpu"

In [19]:
# MNIST preprocessing: convert to a tensor, then normalize with mean 0.5 / std 0.5
# (the training loop undoes this with `(x + 1) / 2` before displaying images).
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# download=False: the dataset files are expected to already exist under ../datasets/.
mnist = MNIST(root="../datasets/", download=False, transform=mnist_transform)

# Shuffled mini-batches of 128 images for training.
mnist_dl = DataLoader(mnist, batch_size=128, shuffle=True)

In [20]:
# --- Data / model dimensions ---
mnist_shape = (1, 28, 28)  # presumably (channels, height, width) of one MNIST image
n_classes = 10             # number of digit classes used for conditioning
hidden_dim = 64
z_dim = 128                # size of the noise vector (before the class vector is appended)

# --- Training settings ---
criterion = nn.BCEWithLogitsLoss()  # operates on raw discriminator logits
display_step = 5000                 # report losses / show samples every this many steps
batch_size = 128

# --- Adam optimizer settings ---
# A learning rate of 0.0002 works well on DCGAN
lr = 0.0002
beta_1, beta_2 = 0.5, 0.999

In [21]:
# Build the conditional generator and discriminator for MNIST.
# The generator input is the noise vector with the class vector appended (z_dim + n_classes);
# the discriminator input is the image channel plus one label channel per class (1 + n_classes).
gen = MnistGenerator(z_dim=z_dim + n_classes)
gen = gen.to(device)
disc = MnistDiscriminator(im_chan=1 + n_classes)
disc = disc.to(device)

# One Adam optimizer per network, sharing the same learning rate and betas.
adam_betas = (beta_1, beta_2)
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr, betas=adam_betas)
disc_opt = torch.optim.Adam(disc.parameters(), lr=lr, betas=adam_betas)


def weights_init(m):
    """Initialize one module in place; intended for use with ``nn.Module.apply``.

    * ``nn.Conv2d`` / ``nn.ConvTranspose2d``: weights drawn from N(0, 0.02).
    * ``nn.BatchNorm2d``: weights drawn from N(0, 0.02), bias set to 0.
    * Any other module type is left untouched.

    Note: this definition shadows the ``weights_init`` imported from
    ``models.utils`` at the top of the notebook.

    Args:
        m: the module to initialize.
    """
    is_conv = isinstance(m, (nn.Conv2d, nn.ConvTranspose2d))
    is_batchnorm = isinstance(m, nn.BatchNorm2d)

    if is_conv or is_batchnorm:
        nn.init.normal_(m.weight, mean=0.0, std=0.02)
    if is_batchnorm:
        nn.init.constant_(m.bias, val=0)
# Re-initialize every sub-module of both networks in place (`apply` returns the module itself).
gen, disc = gen.apply(weights_init), disc.apply(weights_init)

# fake = gen(noise_input)
# disc = MnistDiscriminator(1, 64).to(device)

In [22]:
# Pull a single (images, labels) batch from the MNIST loader for interactive inspection.
real, labels = next(iter(mnist_dl))

In [28]:
# One-hot encode the digit labels of the inspected batch: one row per sample, one column per class.
one_hot_vec = tf.one_hot(labels.to(device), num_classes=n_classes)
# Broadcast every one-hot entry over the image plane so the labels can be stacked onto the
# images as extra channels. The spatial size comes from mnist_shape (28x28) so that this
# preview matches what the training loop builds, instead of a hard-coded 64x64.
one_hot_img = one_hot_vec[:, :, None, None].expand(-1, -1, *mnist_shape[1:])

In [None]:
# Inspect the shape of the label tensor after it was expanded over the image plane.
one_hot_img.shape

In [None]:
# Everything after the first entry of mnist_shape — presumably the (height, width) part; confirm.
mnist_shape[1:]

In [None]:
# Train the conditional GAN on MNIST.
# Each iteration performs one discriminator update followed by one generator update.
# Running loss sums are averaged and reported every `display_step` completed steps.
n_epochs = 100
cur_step = 0             # number of completed training steps
generator_loss = 0       # running sum of generator losses since the last report
discriminator_loss = 0   # running sum of discriminator losses since the last report

for epoch in range(n_epochs):

    for real, labels in tqdm(mnist_dl):
        cur_batch_size = len(real)
        real = real.to(device)

        # Class conditioning: a one-hot vector appended to the generator noise, and the same
        # vector broadcast over the image plane as extra input channels for the discriminator.
        one_hot_vec = tf.one_hot(labels.to(device), num_classes=n_classes)
        one_hot_img = one_hot_vec[:, :, None, None].expand(-1, -1, *mnist_shape[1:])

        ## Update discriminator ##
        disc_opt.zero_grad()

        # Get noise corresponding to the current batch_size
        noise = get_noise(cur_batch_size, z_dim, device=device)

        # The generator is not updated in this step, so its forward pass does not need
        # an autograd graph.
        with torch.no_grad():
            fake = gen(combine_tensors(noise, one_hot_vec))
        fake_comb = combine_tensors(fake, one_hot_img)
        disc_fake = disc(fake_comb)
        disc_fake_loss = criterion(disc_fake, torch.zeros_like(disc_fake))

        real_comb = combine_tensors(real, one_hot_img)
        disc_real = disc(real_comb)
        disc_real_loss = criterion(disc_real, torch.ones_like(disc_real))
        disc_loss = (disc_real_loss + disc_fake_loss) / 2

        # No retain_graph: the generator step below runs fresh forward passes, so this
        # graph is never reused and can be freed right away.
        disc_loss.backward()
        disc_opt.step()

        discriminator_loss += disc_loss.item()

        ### Update generator ###
        gen_opt.zero_grad()
        noise = get_noise(cur_batch_size, z_dim, device=device)

        fake = gen(combine_tensors(noise, one_hot_vec))
        fake_comb = combine_tensors(fake, one_hot_img)
        gen_fake = disc(fake_comb)
        # The generator is rewarded when the discriminator labels its samples as real.
        gen_fake_loss = criterion(gen_fake, torch.ones_like(gen_fake))
        gen_fake_loss.backward()
        gen_opt.step()

        # Keep track of the average generator loss
        generator_loss += gen_fake_loss.item()

        # Count the step before reporting so that every report averages exactly
        # `display_step` accumulated losses (including the first one).
        cur_step += 1

        ### Visualization code ###
        if cur_step % display_step == 0:
            mean_gen_loss = generator_loss / display_step
            mean_disc_loss = discriminator_loss / display_step
            print(f"Epoch {epoch}, step {cur_step}: Generator loss: {mean_gen_loss:.2f}, discriminator loss: {mean_disc_loss:.2f}")
            generator_loss = 0
            discriminator_loss = 0
            # Map images from the normalized [-1, 1] range back to [0, 1] for display.
            show_img_batch((real + 1) / 2)
            show_img_batch((fake + 1) / 2)


In [66]:
# Bundle the network and optimizer state dicts into a single checkpoint file.
checkpoint = {
    'generator': gen.state_dict(),
    'discriminator': disc.state_dict(),
    'gen_opt': gen_opt.state_dict(),
    'disc_opt': disc_opt.state_dict(),
}
torch.save(checkpoint, 'cgan_models.pth')

In [None]:
# Restore the trained MNIST networks from the checkpoint written by the save cell.
# map_location remaps the stored tensors onto the current device, so a checkpoint saved
# on "mps" can still be loaded on a machine where only the CPU is available.
models = torch.load('cgan_models.pth', map_location=device)

# Rebuild the networks with the same constructor arguments used for training,
# then copy the saved parameters into them.
gen_loaded = MnistGenerator(z_dim=z_dim+n_classes).to(device)
disc_loaded = MnistDiscriminator(im_chan=1+n_classes).to(device)
gen_loaded.load_state_dict(models['generator'])
disc_loaded.load_state_dict(models['discriminator'])

In [None]:
# Sample 32 digits from the restored generator, each conditioned on a random class label.
noise = get_noise(32, z_dim, device=device)
labels_ = torch.randint(0, n_classes, (32, ), device=device)
one_hot_vec = tf.one_hot(labels_, num_classes=n_classes)
gen_input = combine_tensors(noise, one_hot_vec)

# Pure inference: no autograd graph is needed.
with torch.no_grad():
    generated = gen_loaded(gen_input)

# Rescale from the normalized [-1, 1] range to [0, 1], exactly as the training loop
# does before displaying generated images.
show_img_batch((generated + 1) / 2)

### CelebA Dataset

[CelebFaces Attributes Dataset (CelebA)](https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) is a large-scale face attributes dataset with more than **200K** celebrity images, each with **40** attribute annotations. The images in this dataset cover large pose variations and background clutter. CelebA has large diversities, large quantities, and rich annotations, including

- **10,177 identities**,
- **202,599 face images**, and
- **5 landmark locations** and **40 binary attribute** annotations per image.


In this section, we focus on generating faces conditioned on the 40 binary attribute annotations.

In [4]:
from torchvision.datasets import ImageFolder

In [None]:
# CelebA preprocessing: convert to a tensor, take the central 178x178 crop,
# then resize to 64x64.
celeba_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.CenterCrop((178, 178)),
    transforms.Resize((64, 64)),
])

# Images are read from the folder tree under ../datasets/celeba.
celeba = ImageFolder("../datasets/celeba", transform=celeba_transform)

# Unshuffled loader, used here only to preview the first batch of faces.
celeba_dl = DataLoader(celeba, batch_size=128, shuffle=False)
preview_iter = iter(celeba_dl)
img_batch, labels = next(preview_iter)
show_img_batch(img_batch, size=(3, 64, 64))


# Load the per-image attribute table; the first column (file name) becomes the index.
# sep=r'\s+' splits on runs of whitespace and replaces the deprecated
# delim_whitespace=True, which it is equivalent to.
df = pd.read_csv('../datasets/celeba_align/list_attr_celeba.csv', sep=r'\s+', index_col=0)
# Attributes are stored as -1 / 1; map -1 to 0 so every attribute is a 0/1 flag.
df = df.replace(-1, 0)

In [6]:

from torch.utils.data import Dataset

# Custom dataset class to combine images and attributes
class CelebADataset(Dataset):
    """Pair each image of an indexable image dataset with its attribute row.

    Item ``idx`` is built from ``image_folder[idx]`` and ``attributes_df.iloc[idx]``,
    so both sources must list the samples in the same order.
    NOTE(review): nothing here checks that ordering or that both have the same
    length — verify against how the image folder and the attribute file are laid out.

    Args:
        image_folder: indexable dataset whose items are ``(image, label)`` pairs;
            the label is ignored.
        attributes_df: DataFrame with one row of numeric attributes per image.
        transform: optional callable applied to the image after it is read from
            ``image_folder`` (i.e. on top of any transform ``image_folder`` applies).
    """

    def __init__(self, image_folder, attributes_df, transform=None):
        self.image_folder = image_folder
        self.attributes_df = attributes_df
        self.transform = transform

    def __len__(self):
        """Return the number of attribute rows, which defines the dataset length."""
        return len(self.attributes_df)

    def __getitem__(self, idx):
        """Return ``(image, attributes)`` for position ``idx``.

        ``attributes`` is a NumPy array of Python-float (float64) values taken
        from row ``idx`` of the attribute table.
        """
        img, _ = self.image_folder[idx]  # Get image from ImageFolder; its label is unused
        attributes = self.attributes_df.iloc[idx].values.astype(float)  # Get attributes

        if self.transform is not None:
            img = self.transform(img)

        return img, attributes
    
# Create the CelebA dataset: images from the ImageFolder paired with the attribute table.
celeba_dataset = CelebADataset(celeba, df)
# This loader feeds the training loop, so batches are shuffled every epoch
# (matching the MNIST loader) instead of replaying the files in a fixed order.
celeba_dl = DataLoader(celeba_dataset, batch_size=128, shuffle=True)

# Example usage: fetch one (images, attributes) batch.
img_batch, labels = next(iter(celeba_dl))

In [7]:
# NOTE: this cell rebinds names already used by the MNIST section
# (criterion, display_step, batch_size, z_dim, lr, beta_1, beta_2).

# Model params
im_chan = 3          # image channels fed to / produced by the networks
z_dim = 256          # size of the noise vector (before the attribute vector is appended)
size = (3, 64, 64)   # image size passed to show_img_batch

# Training params
num_classes = 40     # number of binary attributes used for conditioning
batch_size = 128
display_step = 2500  # report losses / show samples every this many steps
criterion = nn.BCEWithLogitsLoss()

# Optimizer params
lr = 0.0002
beta_1, beta_2 = 0.5, 0.999

In [9]:
from models.discriminator_utils import DiscConvBlock
from models.generator_utils import GenConvTransposeBlock

In [10]:
class CelebaGenerator(nn.Module):
    """Generator built from five stacked ``GenConvTransposeBlock`` stages.

    The flat input vector is viewed as a ``(batch, z_dim, 1, 1)`` tensor and passed
    through the stack; the last stage is created with ``final=True`` and outputs
    ``im_chan`` channels.

    Args:
        z_dim: length of the flat input vector (in this notebook: noise plus
            the conditioning vector).
        im_chan: number of channels of the generated image.
        hidden_dim: base channel width of the inner stages.
    """

    def __init__(self, z_dim=256, im_chan=3, hidden_dim=64):
        super().__init__()
        self.z_dim = z_dim

        # First stage: stride 1, no padding.
        stages = [
            GenConvTransposeBlock(z_dim, hidden_dim * 8, kernel_size=4, stride=1, padding=0),
        ]
        # Inner stages: kernel 4 / stride 2 / padding 1, with shrinking channel counts.
        inner_widths = [hidden_dim * 8, hidden_dim * 8, hidden_dim * 4, hidden_dim * 2]
        for chan_in, chan_out in zip(inner_widths[:-1], inner_widths[1:]):
            stages.append(
                GenConvTransposeBlock(chan_in, chan_out, kernel_size=4, stride=2, padding=1)
            )
        # Output stage: maps to the image channels and is flagged as the final block.
        stages.append(
            GenConvTransposeBlock(hidden_dim * 2, im_chan, kernel_size=4, stride=2, padding=1, final=True)
        )
        self.gen = nn.Sequential(*stages)

    def unsqueeze_noise(self, x):
        """View a ``(batch, z_dim)`` tensor as ``(batch, z_dim, 1, 1)``."""
        batch = len(x)
        return x.view(batch, self.z_dim, 1, 1)

    def forward(self, x):
        """Reshape the flat input and run it through the block stack."""
        spatial_input = self.unsqueeze_noise(x)
        return self.gen(spatial_input)
    

class CelebaDiscriminator(nn.Module):
    """Discriminator built from five stacked ``DiscConvBlock`` stages.

    Four kernel-4 / stride-2 / padding-1 stages are followed by a kernel-4 / stride-1
    stage created with ``final=True`` that outputs a single channel; the result is
    flattened to a column of scores.

    Args:
        im_chan: number of input channels (in this notebook: image channels plus
            the conditioning channels).
        hidden_dim: base channel width of the inner stages.
    """

    def __init__(self, im_chan=3, hidden_dim=64):
        super().__init__()

        # Strided stages with growing channel counts.
        widths = [im_chan, hidden_dim * 1, hidden_dim * 2, hidden_dim * 2, hidden_dim * 4]
        stages = [
            DiscConvBlock(chan_in, chan_out, kernel_size=4, stride=2, padding=1)
            for chan_in, chan_out in zip(widths[:-1], widths[1:])
        ]
        # Output stage: one channel, flagged as the final block.
        stages.append(DiscConvBlock(hidden_dim * 4, 1, kernel_size=4, stride=1, final=True))
        self.disc = nn.Sequential(*stages)

    def forward(self, x):
        """Run the block stack and flatten its output to shape ``(-1, 1)``."""
        scores = self.disc(x)
        return scores.view(-1, 1)

In [11]:
# Build the conditional generator and discriminator for CelebA.
# Generator input: noise with the attribute vector appended (z_dim + num_classes).
# Discriminator input: image channels plus one channel per attribute (im_chan + num_classes).
gen = CelebaGenerator(z_dim=z_dim + num_classes, im_chan=im_chan)
gen = gen.to(device)
disc = CelebaDiscriminator(im_chan=im_chan + num_classes)
disc = disc.to(device)

# One Adam optimizer per network, sharing the same learning rate and betas.
adam_betas = (beta_1, beta_2)
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr, betas=adam_betas)
disc_opt = torch.optim.Adam(disc.parameters(), lr=lr, betas=adam_betas)

# Re-initialize every sub-module in place (`apply` returns the module itself).
gen, disc = gen.apply(weights_init), disc.apply(weights_init)

In [None]:
# Train the conditional GAN on CelebA.
# Each iteration performs one discriminator update followed by one generator update.
# Running loss sums are averaged and reported every `display_step` completed steps.
n_epochs = 50
cur_step = 0             # number of completed training steps
generator_loss = 0       # running sum of generator losses since the last report
discriminator_loss = 0   # running sum of discriminator losses since the last report

for epoch in range(n_epochs):

    for real, labels in tqdm(celeba_dl):
        cur_batch_size = len(real)
        real = real.to(device)

        # Attribute conditioning. Despite the name, `one_hot_vec` holds the 0/1 attribute
        # flags of each image (several may be set at once), cast to float32.
        one_hot_vec = labels.float().to(device)
        # Broadcast every attribute over the image plane (height/width taken from `size`)
        # so it can be stacked onto the image as extra discriminator channels.
        one_hot_img = one_hot_vec[:, :, None, None].expand(-1, -1, *size[1:])

        ## Update discriminator ##
        disc_opt.zero_grad()

        # Get noise corresponding to the current batch_size
        noise = get_noise(cur_batch_size, z_dim, device=device)

        # The generator is not updated in this step, so its forward pass does not need
        # an autograd graph.
        with torch.no_grad():
            fake = gen(combine_tensors(noise, one_hot_vec))
        fake_comb = combine_tensors(fake, one_hot_img)
        disc_fake = disc(fake_comb)
        disc_fake_loss = criterion(disc_fake, torch.zeros_like(disc_fake))

        real_comb = combine_tensors(real, one_hot_img)
        disc_real = disc(real_comb)
        disc_real_loss = criterion(disc_real, torch.ones_like(disc_real))
        disc_loss = (disc_real_loss + disc_fake_loss) / 2

        # No retain_graph: the generator step below runs fresh forward passes, so this
        # graph is never reused and can be freed right away.
        disc_loss.backward()
        disc_opt.step()

        discriminator_loss += disc_loss.item()

        ### Update generator ###
        gen_opt.zero_grad()
        noise = get_noise(cur_batch_size, z_dim, device=device)

        fake = gen(combine_tensors(noise, one_hot_vec))
        fake_comb = combine_tensors(fake, one_hot_img)
        gen_fake = disc(fake_comb)
        # The generator is rewarded when the discriminator labels its samples as real.
        gen_fake_loss = criterion(gen_fake, torch.ones_like(gen_fake))
        gen_fake_loss.backward()
        gen_opt.step()

        # Keep track of the average generator loss
        generator_loss += gen_fake_loss.item()

        # Count the step before reporting so that every report averages exactly
        # `display_step` accumulated losses (including the first one).
        cur_step += 1

        ### Visualization code ###
        if cur_step % display_step == 0:
            mean_gen_loss = generator_loss / display_step
            mean_disc_loss = discriminator_loss / display_step
            print(f"Epoch {epoch}, step {cur_step}: Generator loss: {mean_gen_loss:.2f}, discriminator loss: {mean_disc_loss:.2f}")
            generator_loss = 0
            discriminator_loss = 0
            # NOTE(review): the CelebA transform applies no Normalize, so `real` is shown
            # as loaded; confirm the generator's final block emits the same value range.
            show_img_batch(real, size=size)
            show_img_batch(fake, size=size)
