# Conditional Generative Adversarial Network

In [None]:
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision import datasets, transforms

In [None]:
# make folder for saving
# A single makedirs call creates both 'output' and 'output/val_results';
# exist_ok=True makes it a no-op when they are already there, which avoids
# the check-then-create race of a separate isdir() test.
os.makedirs('output/val_results', exist_ok=True)

In [None]:
# training parameters
use_gpu = False      # set to True to run the models and tensors on a CUDA device
lr = 0.0002          # learning rate handed to both Adam optimizers
batch_size = 128     # mini-batch size of the MNIST data loader
train_epochs = 50    # number of passes over the training set

if use_gpu:
    # Fail early with a message that names the actual flag (`use_gpu`).
    assert torch.cuda.is_available(), 'ERROR: You have no GPU (CUDA device), set `use_gpu` to False'

### Prepare Data
* Training data pair `(x, y)` through `MNIST` data loader
* Fixed noise `z` and labels `y` for the validation phase, 100 samples in total: 10 samples for each digit 0-9.
    * `fixed_z`: shape=`(100, 100)` 
    * `fixed_y`: shape=`(100, 10)` (one-hot encoded)

#### one-hot encode

For example, to encode the numbers 0 to 9 (10 classes):
```yaml
0 -> (1, 0, 0, 0, 0, 0, 0, 0, 0, 0)
1 -> (0, 1, 0, 0, 0, 0, 0, 0, 0, 0)
2 -> (0, 0, 1, 0, 0, 0, 0, 0, 0, 0)
...
9 -> (0, 0, 0, 0, 0, 0, 0, 0, 0, 1)
```

In [None]:
''' Training data '''
# MNIST images have a single channel, so Normalize gets exactly one mean and
# one std value. (x - 0.5) / 0.5 rescales ToTensor's output to [-1, 1], the
# same range as the tanh output of the generator.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,))
])
# Shuffled mini-batches of (image, label) pairs; the dataset is downloaded to
# ./data on first use.
data_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=True, download=True, transform=transform),
    batch_size=batch_size,
    shuffle=True)


''' Validation data '''
def one_hot_encode(tensor, output):
    """Write a one-hot encoding of `tensor` into `output` and return `output`.

    `tensor` is cast to long and used as the column index for each row;
    `output` is modified in place: the entry at `[row, index]` is set to 1
    and every other entry keeps the value it was passed in with.
    """
    class_index = tensor.long()
    output.scatter_(dim=1, index=class_index, value=1)
    return output

# Fixed validation noise: 10 * 10 rows of 100 uniform values in [0, 1).
fixed_z = torch.rand(10 * 10, 100)
# Labels 0..9, each repeated 10 times -> shape (100, 1). torch.cat needs a
# real sequence (list/tuple) of tensors, so build a list rather than passing
# a generator expression.
fixed_y = torch.cat([torch.zeros(10, 1).fill_(i) for i in range(10)], dim=0)
fixed_y = one_hot_encode(fixed_y, output=torch.zeros(10 * 10, 10))

# NOTE(review): `volatile` is the pre-0.4 PyTorch inference flag; on newer
# releases it only triggers a warning, and the generator call should be
# wrapped in `torch.no_grad()` instead.
fixed_z_var = Variable(fixed_z, volatile=True)
fixed_y_var = Variable(fixed_y, volatile=True)

if use_gpu:
    # .cuda() returns a copy on the GPU instead of moving the tensor in place,
    # so the result must be bound back to the name.
    fixed_z_var = fixed_z_var.cuda()
    fixed_y_var = fixed_y_var.cuda()

## Prepare model
- Baseline models
- Your advanced models

In [None]:
class BaselineGenerator(nn.Module):
    """Fully connected conditional generator.

    The noise input and the label input are each projected to 1024 features
    (Linear -> BatchNorm -> ReLU), concatenated to 2048 features, and mapped
    to 784 output values squashed by tanh.
    """

    def __init__(self):
        super().__init__()

        # noise branch: 100 -> 1024
        self.fc1_z = nn.Linear(in_features=100, out_features=1024)
        self.fc1_z_bn = nn.BatchNorm1d(num_features=1024)
        # label branch: 10 -> 1024
        self.fc1_y = nn.Linear(in_features=10, out_features=1024)
        self.fc1_y_bn = nn.BatchNorm1d(num_features=1024)
        # joint head: 1024 + 1024 -> 784
        self.fc2 = nn.Linear(in_features=2048, out_features=784)

        # re-initialise every sub-module (and this module) with weights_init
        self.apply(weights_init)

    def forward(self, z, y):
        """Generate samples from noise `z` (100 features) and labels `y` (10 features).

        Returns a tensor with 784 values per row, each in [-1, 1] (tanh).
        """
        noise_feat = self.fc1_z(z)
        noise_feat = F.relu(self.fc1_z_bn(noise_feat))

        label_feat = self.fc1_y(y)
        label_feat = F.relu(self.fc1_y_bn(label_feat))

        joint = torch.cat([noise_feat, label_feat], dim=1)
        return F.tanh(self.fc2(joint))


class BaselineDiscriminator(nn.Module):
    """Fully connected conditional discriminator.

    The sample input and the label input are each projected to 1024 features
    (Linear -> LeakyReLU with slope 0.2), concatenated to 2048 features, and
    reduced to a single sigmoid output per row.
    """

    def __init__(self):
        super().__init__()

        # sample branch: 784 -> 1024
        self.fc1_x = nn.Linear(in_features=784, out_features=1024)
        # label branch: 10 -> 1024
        self.fc1_y = nn.Linear(in_features=10, out_features=1024)
        # joint head: 1024 + 1024 -> 1
        self.fc2 = nn.Linear(in_features=2048, out_features=1)

        # re-initialise every sub-module (and this module) with weights_init
        self.apply(weights_init)

    def forward(self, x, y):
        """Score samples `x` (784 features) conditioned on labels `y` (10 features).

        Returns one value in (0, 1) per row (sigmoid output).
        """
        sample_feat = F.leaky_relu(self.fc1_x(x), 0.2)
        label_feat = F.leaky_relu(self.fc1_y(y), 0.2)

        joint = torch.cat([sample_feat, label_feat], dim=1)
        score = self.fc2(joint)
        return F.sigmoid(score)


class AdvancedGenerator(nn.Module):
    """Placeholder for a custom generator; no layers or forward pass are defined yet."""


class AdvancedDiscriminator(nn.Module):
    """Placeholder for a custom discriminator; no layers or forward pass are defined yet.

    This stub previously reused the name `AdvancedGenerator`, which silently
    replaced the generator stub defined just above it.
    """
    pass


def weights_init(m):
    """Initialise the parameters of module `m` in place (meant for `Module.apply`).

    * `nn.Linear`: weight drawn from a normal with mean 0, std 0.02; bias zeroed
    * class name contains 'Conv': weight drawn from a normal with mean 0, std 0.02
    * class name contains 'BatchNorm': weight drawn from a normal with mean 1,
      std 0.02; bias zeroed

    A parameter that is absent or `None` (e.g. `bias=False`, `affine=False`,
    or a container module whose class name merely contains 'Conv') is skipped
    instead of raising. Modules matching none of the rules are left untouched.
    """
    classname = m.__class__.__name__
    # `Module.apply` visits every sub-module, including ones without these
    # attributes, so look them up defensively.
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)

    if isinstance(m, nn.Linear):
        if weight is not None:
            weight.data.normal_(0.0, 0.02)
        if bias is not None:
            bias.data.zero_()
    elif 'Conv' in classname:
        if weight is not None:
            weight.data.normal_(0.0, 0.02)
    elif 'BatchNorm' in classname:
        if weight is not None:
            weight.data.normal_(1.0, 0.02)
        if bias is not None:
            bias.data.fill_(0)

In [None]:
def validate_result(sample_images, epoch, result_folder):
    """Save a 10x10 grid of generated samples as `epoch_<epoch>.png`.

    Args:
        sample_images: indexable collection of at least 100 tensors; each is
            moved to the CPU and reshaped to 28x28 for display.
        epoch: epoch number used in the caption and in the file name.
        result_folder: directory the PNG is written to (it is not created here).
    """
    fig, ax = plt.subplots(10, 10, figsize=(5, 5))
    try:
        for i in range(10):
            for j in range(10):
                k = i * 10 + j  # row-major index into sample_images
                ax[i, j].get_xaxis().set_visible(False)
                ax[i, j].get_yaxis().set_visible(False)
                ax[i, j].cla()
                ax[i, j].imshow(sample_images[k].cpu().data.view(28, 28).numpy(), cmap='gray')
        fig.text(0.5, 0.04, 'Epoch {0}'.format(epoch), ha='center')
        fig.savefig(os.path.join(result_folder, 'epoch_{0}.png'.format(epoch)))
    finally:
        # This is called once per epoch; without closing, every figure stays
        # registered with pyplot and memory use grows over the training run.
        plt.close(fig)

## Train the cGAN
1. Define and prepare the generator `G` and discriminator `D` models
2. Define loss function
3. Build optimizers for updating model weights in training

### In each training epoch
* Visualize the performance (output) of generator with `fixed_z`
* Record the losses for later curve plot

In [None]:
# network
# TODO(exercise): the two assignments below are intentionally left blank.
# Instantiate a generator and a discriminator here, e.g. the Baseline* or
# Advanced* classes defined earlier in this file.
G = # what generator...?
D = # what discriminator...?
if use_gpu:
    # unlike the tensors elsewhere in this file, the modules' return value is
    # not needed: nn.Module.cuda() moves the parameters of the module itself
    G.cuda()
    D.cuda()
print('==> Model ready!')

# loss function: Binary Cross Entropy loss
BCE_loss = nn.BCELoss()

# optimizer: Adam
# one optimizer per network, both using the module-level `lr` and betas=(0.5, 0.999)
G_optimizer = optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
D_optimizer = optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))
print('==> Ready for training!')

print('==> Training start!')

for epoch in range(train_epochs):
    G.train()

    for x, y in data_loader:
        batch_size = x.size(0)

        ''' (Data) prepare ground truth label for Discriminator
        '''
        real_label = Variable(torch.ones(batch_size, 1))
        fake_label = Variable(torch.zeros(batch_size, 1))
        if use_gpu:
            real_label.cuda(), fake_label.cuda()

        ''' (Train) discriminator D
        '''
        D_optimizer.zero_grad()

        '''real case:
            (Data) prepare ground truth data pair x, y
            `x` in shape (batch_size, 28 * 28)
            `y` in shape (batch_size, 10)
        '''
        x_var = Variable(x.view(-1, 28 * 28))
        y_var = Variable(torch.zeros(batch_size, 10).scatter_(1, y.view(batch_size, 1), 1))
        if use_gpu:
            x_var.cuda(), y_var.cuda()
        # what should D do here? and the loss?
        # ...

        '''fake case:
            (Data) prepare fake random data pair z, y
            `z` in shape (batch_size, 100)
            `y` in shape (batch_size, 10)
        '''
        z_var = Variable(torch.rand((batch_size, 100)))
        y_var = Variable(torch.zeros(batch_size, 10).scatter_(1, (torch.rand(batch_size, 1) * 10).long(), 1))
        if use_gpu:
            z_var.cuda(), y_var.cuda()

        # in fake case, what should G and D do? and the loss?
        # ...

        # total loss in D-step and loss backward to the weights

        D_optimizer.step()


        ''' (Train) generator G
        '''
        G_optimizer.zero_grad()

        '''generator case:
            (Data) prepare fake random data pair z, y
            `z` in shape (batch_size, 100)
            `y` in shape (batch_size, 10)
        '''
        z_var = Variable(torch.rand((batch_size, 100)))
        y_var = Variable(torch.zeros(batch_size, 10).scatter_(1, (torch.rand(batch_size, 1) * 10).long(), 1))
        if use_gpu:
            z_var.cuda(), y_var.cuda()

        # in G-step, what should G and D do? and the loss?
        # ...

        # loss in G-step and loss backward to the weights

        G_optimizer.step()

    G.eval()
    sample_images = # generate some results from generator with fixed input
    validate_result(sample_images, epoch + 1, result_folder='output/val_results')

    print('[{0}/{1}] loss_d: {2:.3f}, loss_g: {3:.3f}'.format((epoch + 1), train_epochs, loss_d, loss_g))

print('==> Training finish!')

## Save the weights from models
Saving the weights makes testing convenient: they can be reloaded later instead of training the networks again.

In [None]:
# TODO(exercise): the first argument of each torch.save call is intentionally
# left blank; fill in what should be serialized for the generator and the
# discriminator respectively. Both files go into the 'output' folder.
torch.save(, 'output/generator_weight.pth')
torch.save(, 'output/discriminator_weight.pth')
print('==> Models saved')

## Show the training loss curve

## Demo
Show the results with trained generator model (2 images for each class)