# Conditional GAN

## imports

In [None]:
import torch
import torchvision
import torch.nn as nn
import matplotlib.pyplot as plt

## hyperparams

In [None]:
# --- data ---------------------------------------------------------------
PATH = 'data'          # root directory passed to the MNIST dataset
N_INPUTCHANNELS = 1    # image channels fed to / produced by the networks
HEIGHT_WIDTH = 28      # side length of the (square) MNIST images
N_CLASSES = 10         # length of the one-hot class vector

# --- model --------------------------------------------------------------
NZ = 100               # dimension of the generator's latent sample
NGF = 32               # base number of generator feature maps
NDF = 4                # base number of discriminator feature maps

# --- optimisation -------------------------------------------------------
BATCHSIZE = 64
LR = 1e-3              # learning rate shared by both Adam optimizers
EPOCHS = 20

## load MNIST data

In [None]:
# Use the first CUDA device when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# MNIST training split; ToTensor converts each image to a tensor.
# download=True asks torchvision to fetch the data into PATH.
transform = torchvision.transforms.ToTensor()
train_data = torchvision.datasets.MNIST(
    root=PATH,
    train=True,
    transform=transform,
    download=True,
)

# Shuffled mini-batches of BATCHSIZE samples.
train_loader = torch.utils.data.DataLoader(
    dataset=train_data,
    batch_size=BATCHSIZE,
    shuffle=True,
)

## Generator and Discriminator classes

In [None]:
class Generator(nn.Module):
    """Transposed-convolution generator conditioned on a one-hot class vector.

    The latent sample and the class vector are concatenated along the feature
    axis, reshaped to a 1x1 spatial map and upsampled 1 -> 4 -> 7 -> 14 -> 28.

    Args:
        nz: dimension of the latent sample ``z``.
        ngf: base number of generator feature maps.
        nic: number of channels of the generated image.
        nclasses: length of the one-hot class vector.
    """

    def __init__(self, nz, ngf, nic, nclasses):
        super().__init__()

        self.gen = nn.Sequential(
            # input: (B, nz + nclasses, 1, 1)
            nn.ConvTranspose2d(in_channels=nz + nclasses, out_channels=ngf * 8,
                               kernel_size=4, stride=1, padding=0),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(),
            # input: (B, ngf*8, 4, 4)
            nn.ConvTranspose2d(in_channels=ngf * 8, out_channels=ngf * 4,
                               kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(),
            # input: (B, ngf*4, 7, 7)
            nn.ConvTranspose2d(in_channels=ngf * 4, out_channels=ngf,
                               kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(ngf),
            nn.ReLU(),
            # input: (B, ngf, 14, 14)
            nn.ConvTranspose2d(in_channels=ngf, out_channels=nic,
                               kernel_size=4, stride=2, padding=1),
            # output: (B, nic, 28, 28), squashed into [0, 1]
            nn.Sigmoid(),
        )

    def forward(self, z, c_one_hot):
        """Generate images for latent samples ``z`` and class vectors ``c_one_hot``.

        Args:
            z: tensor of shape (B, nz).
            c_one_hot: tensor of shape (B, nclasses); any numeric dtype.

        Returns:
            Tensor of shape (B, nic, 28, 28) with values in [0, 1].
        """
        # one_hot() yields int64; cast explicitly so the concatenated tensor
        # always has the latent's dtype instead of relying on type promotion.
        x = torch.cat([z, c_one_hot.to(z.dtype)], 1)
        # add the two singleton spatial dimensions expected by ConvTranspose2d
        x = self.gen(x.unsqueeze(-1).unsqueeze(-1))
        return x


class Discriminator(nn.Module):
    """Convolutional discriminator conditioned on a one-hot class vector.

    The flattened image and the class vector are concatenated, mixed by a
    linear layer back to image size, and then reduced by three stride-2
    convolutions to a single probability.

    Args:
        nc: number of image channels.
        height_width: side length of the (square) input images; must be >= 8.
        ndf: base number of discriminator feature maps.
        n_classes: length of the one-hot class vector.

    Raises:
        ValueError: if ``height_width`` is too small to survive the three
            stride-2 convolutions.
    """

    def __init__(self, nc, height_width, ndf, n_classes):
        super().__init__()

        if height_width < 8:
            raise ValueError(
                f"height_width must be at least 8, got {height_width}"
            )

        self.height_width = height_width
        # number of values in one flattened image (all channels)
        self.n_pixels = nc * height_width ** 2
        # each conv (kernel 4, stride 2, padding 1) halves the side length
        # with floor rounding, so three of them give height_width // 8
        # (28 -> 14 -> 7 -> 3)
        final_hw = height_width // 8

        self.dis = nn.Sequential(
            nn.Linear(self.n_pixels + n_classes, self.n_pixels),
            nn.Unflatten(1, (nc, height_width, height_width)),
            nn.Conv2d(in_channels=nc, out_channels=ndf, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=ndf, out_channels=ndf * 4, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(ndf * 4),
            nn.ReLU(),
            nn.Conv2d(in_channels=ndf * 4, out_channels=ndf * 8, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(ndf * 8),
            nn.ReLU(),
            nn.Flatten(1, -1),
            nn.Linear(ndf * 8 * final_hw * final_hw, 1),
            nn.Sigmoid(),
        )

    def forward(self, image, c_one_hot):
        """Score images as real (close to 1) or generated (close to 0).

        Args:
            image: tensor holding B images of nc x height_width x height_width.
            c_one_hot: tensor of shape (B, n_classes); any numeric dtype.

        Returns:
            Tensor of shape (B, 1) with values in [0, 1].
        """
        # reshape (unlike view) also accepts non-contiguous inputs
        x = image.reshape(-1, self.n_pixels)
        # one_hot() yields int64; cast explicitly to the image dtype
        x = torch.cat([x, c_one_hot.to(x.dtype)], 1)
        x = self.dis(x)
        return x

## initialize model

In [None]:
# Build both networks on the target device (discriminator first, then generator).
dis = Discriminator(
    nc=N_INPUTCHANNELS,
    height_width=HEIGHT_WIDTH,
    ndf=NDF,
    n_classes=N_CLASSES,
).to(device)
gen = Generator(
    nz=NZ,
    ngf=NGF,
    nic=N_INPUTCHANNELS,
    nclasses=N_CLASSES,
).to(device)

# Binary cross-entropy on the discriminator's sigmoid output.
criterion = nn.BCELoss()

# One Adam optimizer per network, both using the same learning rate.
optimizer_dis = torch.optim.Adam(params=dis.parameters(), lr=LR)
optimizer_gen = torch.optim.Adam(params=gen.parameters(), lr=LR)

## training loop

In [None]:
# Per-iteration loss history (plotted after training).
gen_losses = []
dis_losses = []

for epoch in range(EPOCHS):

    for i, (real_im, labels) in enumerate(train_loader):
        real_im, labels = real_im.to(device), labels.to(device)
        batch_size = len(real_im)       # the last batch may be smaller than BATCHSIZE
        # one_hot() returns int64; the networks consume float inputs
        labels_one_hot = torch.nn.functional.one_hot(labels, num_classes=N_CLASSES).float()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_dis.zero_grad()

        # the generator is not updated in this step, so no graph is needed for it
        with torch.no_grad():
            z = torch.randn(batch_size, NZ, device=device)                      # batch-size, latent-dimension
            z_labels = torch.randint(0, N_CLASSES, (batch_size,), device=device)  # random labels
            z_labels_one_hot = torch.nn.functional.one_hot(z_labels, num_classes=N_CLASSES).float()
            fake_im = gen(z, z_labels_one_hot)

        disc_real = dis(real_im, labels_one_hot)
        disc_fake = dis(fake_im, z_labels_one_hot)

        # 1 being target/label for real images, 0 target/label for generated images (for the discriminator)
        real_loss = criterion(disc_real, torch.ones_like(disc_real))
        fake_loss = criterion(disc_fake, torch.zeros_like(disc_fake))
        loss_dis = (real_loss + fake_loss) / 2

        loss_dis.backward()
        optimizer_dis.step()

        # -----------------
        #  Train Generator
        # -----------------

        optimizer_gen.zero_grad()

        # generate a fresh batch of images, this time tracking gradients
        z = torch.randn(batch_size, NZ, device=device)                          # batch-size, latent-dimension
        z_labels = torch.randint(0, N_CLASSES, (batch_size,), device=device)    # random labels
        z_labels_one_hot = torch.nn.functional.one_hot(z_labels, num_classes=N_CLASSES).float()
        fake_im = gen(z, z_labels_one_hot)

        # judge generated images with discriminator
        output = dis(fake_im, z_labels_one_hot)

        # for generator training, generated images are labeled as real (target 1)
        loss_gen = criterion(output, torch.ones_like(output))

        loss_gen.backward()
        optimizer_gen.step()

        # print stats; the mean predictions are only needed when printing
        if i % 50 == 0:
            avg_pred_real = disc_real.mean().item()
            avg_pred_gen1 = disc_fake.mean().item()     # D(G(z)) before the discriminator update
            avg_pred_gen2 = output.mean().item()        # D(G(z)) after the discriminator update
            print(f'[{epoch+1}/{EPOCHS}] [{i}/{len(train_loader)}] \nLoss D: {loss_dis.item()}, Loss G: {loss_gen.item()}, Mean D(x): {avg_pred_real}, Mean D(G(z)):{avg_pred_gen1} / {avg_pred_gen2}')

        gen_losses.append(loss_gen.item())
        dis_losses.append(loss_dis.item())

## check results

In [None]:
# plot generator and discriminator loss
plt.figure(figsize=(10, 5))
plt.title('Generator and Discriminator Loss')
plt.plot(gen_losses, label='Generator')
plt.plot(dis_losses, label='Discriminator')
plt.xlabel('iterations')
plt.ylabel('loss')
plt.legend()
plt.show()

# Inference only: no autograd graph is needed from here on.
# NOTE: the networks are left in their current (training) mode, so the
# BatchNorm layers keep using batch statistics as in the original notebook.
with torch.no_grad():
    # sample from a standard gaussian, one latent vector per class
    z = torch.randn(N_CLASSES, NZ, device=device)
    # one one-hot row per class (identity matrix)
    c = torch.eye(N_CLASSES, device=device)
    # generate one image per class
    gen_img = gen(z, c)

# show generated images in a two-row grid
n_cols = (N_CLASSES + 1) // 2
for i in range(N_CLASSES):
    plt.subplot(2, n_cols, i + 1)
    plt.axis('off')
    plt.title(f'{i}')
    plt.imshow(gen_img[i].squeeze().cpu().numpy(), cmap='gray_r')

plt.show()

with torch.no_grad():
    # check discriminator on the generated images
    out_test = dis(gen_img, c)
    print(f'Discriminator tested on generated images, MEAN prediction: {out_test.mean().item():.2f}')

    # check discriminator on one batch of real images
    real_batch, real_label = next(iter(train_loader))
    real_batch, real_label = real_batch.to(device), real_label.to(device)

    real_label_one_hot = torch.nn.functional.one_hot(real_label, num_classes=N_CLASSES)
    real_test = dis(real_batch, real_label_one_hot)
    print(f'Discriminator tested on real images, MEAN prediction:  {real_test.mean().item():.2f}')