# Implementation of Conditional GANs
Reference: https://arxiv.org/pdf/1411.1784.pdf

In [None]:
# Run the comment below only when using Google Colab
# !pip install torch torchvision

In [1]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F

In [2]:
from torch.utils.data import DataLoader, TensorDataset
from torchvision import datasets
from torchvision import transforms
from torchvision.utils import save_image
from CGAN_MNIST import NetD, NetG

In [3]:
import numpy as np
import datetime
import os, sys
import gzip
import pickle

In [4]:
from matplotlib.pyplot import imshow, imsave
%matplotlib inline

In [5]:
MODEL_NAME = 'ConditionalGAN'
DEVICE = torch.device("cuda:4" if torch.cuda.is_available() else "cpu")

In [6]:
def to_onehot(x, num_classes=10):
    assert isinstance(x, int) or isinstance(x, (torch.LongTensor, torch.cuda.LongTensor))
    if isinstance(x, int):
        c = torch.zeros(1, num_classes).long()
        c[0][x] = 1
    else:
        x = x.cpu()
        c = torch.LongTensor(x.size(0), num_classes)
        c.zero_()
        c.scatter_(1, x, 1) # dim, index, src value
    return c

In [7]:
def get_sample_image(G, n_noise=100):
    """
        save sample 100 images
    """
    img = np.zeros([280, 280])
    for j in range(10):
        c = torch.zeros([10, 10]).to(DEVICE)
        c[:, j] = 1
        z = torch.randn(10, n_noise).to(DEVICE)
        y_hat = G(torch.cat((z,c),dim=1)).view(10, 28, 28)
        result = y_hat.cpu().data.numpy()
        img[j*28:(j+1)*28] = np.concatenate([x for x in result], axis=-1)
    return img

In [8]:
class Discriminator(nn.Module):
    """
        Simple Discriminator w/ MLP
    """
    def __init__(self, input_size=784, condition_size=10, num_classes=1):
        super(Discriminator, self).__init__()
        self.layer = nn.Sequential(
            nn.Linear(input_size+condition_size, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, num_classes),
            nn.Sigmoid(),
        )
    
    def forward(self, x, c):        
        x, c = x.view(x.size(0), -1), c.view(c.size(0), -1).float()
        v = torch.cat((x, c), 1) # v: [input, label] concatenated vector
        y_ = self.layer(v)
        return y_

In [9]:
class Generator(nn.Module):
    """
        Simple Generator w/ MLP
    """
    def __init__(self, input_size=100, condition_size=10, num_classes=784):
        super(Generator, self).__init__()
        self.layer = nn.Sequential(
            nn.Linear(input_size+condition_size, 128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, 256),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, num_classes),
            nn.Tanh()
        )
        
    def forward(self, x, c):
        x, c = x.view(x.size(0), -1), c.view(c.size(0), -1).float()
        v = torch.cat((x, c), 1) # v: [input, label] concatenated vector
        y_ = self.layer(v)
        y_ = y_.view(x.size(0), 1, 28, 28)
        return y_

In [10]:
D = NetD().to(DEVICE)
G = NetG().to(DEVICE)

Linear(in_features=794, out_features=512, bias=True)
Linear(in_features=512, out_features=512, bias=True)
Linear(in_features=512, out_features=512, bias=True)
Linear(in_features=512, out_features=1024, bias=True)
Linear(in_features=1024, out_features=1, bias=True)
Linear(in_features=110, out_features=128, bias=True)
Linear(in_features=128, out_features=256, bias=True)
Linear(in_features=256, out_features=512, bias=True)
Linear(in_features=512, out_features=1024, bias=True)
Linear(in_features=1024, out_features=784, bias=True)


In [11]:
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize(mean=(0.5, 0.5, 0.5),
                                std=(0.5, 0.5, 0.5))]
)

In [12]:
DATA_PATH = '../data/MNIST/mnist.pkl.gz'
with gzip.open(DATA_PATH, 'rb') as mnist:
    ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(mnist, encoding='latin-1')

In [13]:
x_train, y_train, x_valid, y_valid = map(torch.detach,
                                                        map(torch.Tensor,
                                                           (x_train, y_train, x_valid, y_valid)))

In [14]:
x_train = (x_train-0.5)/0.5
x_valid = (x_valid-0.5)/0.5
ds_train = TensorDataset(x_train, y_train)
ds_valid = TensorDataset(x_valid, y_valid)

In [15]:
batch_size = 1024
condition_size = 10

In [16]:
data_loader = DataLoader(dataset=ds_train, batch_size=batch_size, shuffle=True, drop_last=True,num_workers=8)

In [17]:
criterion = nn.BCELoss()
D_opt = torch.optim.Adam(D.parameters(), lr=0.0001, betas=(0.5, 0.999))
G_opt = torch.optim.Adam(G.parameters(), lr=0.0001, betas=(0.5, 0.999))

In [18]:
max_epoch = 300 # need more than 100 epochs for training generator
n_critic = 1 # for training more k steps about Discriminator
n_noise = 100

In [19]:
D_labels = torch.ones([batch_size, 1]).to(DEVICE) # Discriminator Label to real
D_fakes = torch.zeros([batch_size, 1]).to(DEVICE) # Discriminator Label to fake

In [20]:
if not os.path.exists('samples'):
    os.makedirs('samples')

In [21]:
for epoch in range(0, max_epoch):
    for idx, (images, labels) in enumerate(data_loader):
        # Training Discriminator
        x = images.to(DEVICE)
        y = labels.view(batch_size, 1)
        y = to_onehot(y.to(torch.long)).to(torch.float).to(DEVICE)
        x_outputs = D(torch.cat((x,y), dim=1 ))
        D_x_loss = criterion(x_outputs, D_labels)

        z = torch.randn(batch_size, n_noise).to(DEVICE)
        z_outputs = D(torch.cat((G(torch.cat((z, y), dim=1)), y), dim=1))#fake as true?
        D_z_loss = criterion(z_outputs, D_fakes)
        D_loss = D_x_loss + D_z_loss
        
        D.zero_grad()
        D_loss.backward()
        D_opt.step()
        
        if idx % n_critic == 0:
            # Training Generator
            z = torch.randn(batch_size, n_noise).to(DEVICE)
            z_outputs = D(torch.cat((G(torch.cat((z, y), dim=1)), y), dim=1))
            G_loss = criterion(z_outputs, D_labels)

            G.zero_grad()
            G_loss.backward()
            G_opt.step()
        
#         if step % 500 == 0:
        
    if (epoch+1) % 10 == 0:
        G.eval()
        print('Epoch: {}/{}, D Loss: {}, G Loss: {}'.format(epoch+1, max_epoch, D_loss.item(), G_loss.item()))
        img = get_sample_image(G, n_noise)
        imsave('samples/{}_epoch{}.jpg'.format(MODEL_NAME, str(epoch+1).zfill(3)), img, cmap='gray')
        G.train()

Traceback (most recent call last):
  File "/usr/local/python3/lib/python3.6/multiprocessing/queues.py", line 240, in _feed
    send_bytes(obj)
  File "/usr/local/python3/lib/python3.6/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/usr/local/python3/lib/python3.6/multiprocessing/connection.py", line 404, in _send_bytes
    self._send(header + buf)
  File "/usr/local/python3/lib/python3.6/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
Traceback (most recent call last):
BrokenPipeError: [Errno 32] Broken pipe
Traceback (most recent call last):
  File "/usr/local/python3/lib/python3.6/multiprocessing/queues.py", line 240, in _feed
    send_bytes(obj)
  File "/usr/local/python3/lib/python3.6/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/usr/local/python3/lib/python3.6/multiprocessing/connection.py", line 404, in _send_bytes
   

RuntimeError: DataLoader worker (pid 1511) is killed by signal: Bus error. 

## Sample

In [None]:
# generation to image
G.eval()
imshow(get_sample_image(G, n_noise), cmap='gray')

In [None]:
def save_checkpoint(state, file_name='checkpoint.pth.tar'):
    torch.save(state, file_name)

In [None]:
# Saving params.
save_checkpoint({'epoch': epoch + 1, 'state_dict':D.state_dict(), 'optimizer' : D_opt.state_dict()}, 'D_c.pth.tar')
save_checkpoint({'epoch': epoch + 1, 'state_dict':G.state_dict(), 'optimizer' : G_opt.state_dict()}, 'G_c.pth.tar')

In [None]:
a = torch.zeros((128,784))
b = torch.ones((128,10))
torch.cat((a,b), 1)

In [None]:
torch.cat((x, conditions), dim=1))