# Conditional GANs

<div>
   <span style="font-size: large;"> "GANS in Action"</span>, Jakub Langr, Vladimir Bok, Manning Publications 2019.  
</div>

In [None]:
%load_ext autoreload

In [None]:
%autoreload 2

In [None]:
import itertools
from IPython.display import clear_output, display

In [None]:
import numpy as np
import matplotlib.pyplot as plt

In [None]:
%matplotlib inline

In [None]:
import torch as t
import torch.nn as tnn
import torch.nn.functional as F

import torchvision

In [None]:
import sys
sys.path.append("../../src")

In [None]:
import GAN.C_DCGAN_MNIST as cgan
import utils

In [None]:
# Choose the compute device: the CPU when CUDA is unavailable, the second GPU
# ('cuda:1') when more than one GPU is present, otherwise the default GPU.
if not t.cuda.is_available():
    device = t.device('cpu')
elif t.cuda.device_count() > 1:
    device = t.device('cuda:1')
else:
    device = t.device('cuda')

In [None]:
t.cuda.is_available()

In [None]:
print(device)

## MNIST 

In [None]:
# Fetch both MNIST splits into ./data/mnist (download=True) and wrap each in a
# DataLoader; the cells below read the raw tensors through `.dataset`.
dl_train = t.utils.data.DataLoader(
    dataset=torchvision.datasets.MNIST(root='./data/mnist', train=True, download=True)
)

dl_test = t.utils.data.DataLoader(
    dataset=torchvision.datasets.MNIST(root='./data/mnist', train=False, download=True)
)

In [None]:
mnist_train_data   = dl_train.dataset.data.to(dtype=t.float32)
mnist_train_labels = dl_train.dataset.targets

In [None]:
mnist_test_data   = dl_test.dataset.data.to(dtype=t.float32)
mnist_test_labels = dl_test.dataset.targets

In [None]:
# Merge the train and test splits into one pool; np.concatenate also converts
# the torch tensors into NumPy arrays.
mnist_data = np.concatenate([mnist_train_data, mnist_test_data], axis=0)
mnist_labels = np.concatenate([mnist_train_labels, mnist_test_labels], axis=0)

In [None]:
index=49
plt.imshow(mnist_data[index], cmap='Greys')
plt.text(22,3,'%d' % (mnist_labels[index],), fontsize=32);

In [None]:
n_data = 50000

In [None]:
# Rescale linearly so that a pixel value of 0 maps to -1 and 255 maps to 1,
# then keep the first n_data images as an (N, 1, 28, 28) tensor.
mnist_data_rescaled = 2.0 * mnist_data.astype('float32') / 255.0 - 1.0
mnist_data_t = t.from_numpy(mnist_data_rescaled[:n_data]).reshape(-1, 1, 28, 28)

In [None]:
mnist_labels_t = t.from_numpy(mnist_labels[:n_data])

# One-hot encode the labels: start from an all-zero (N, 10) matrix and write a
# 1 into the column selected by each label.
mnist_one_hot_labels_t = t.zeros(len(mnist_labels_t), 10).scatter_(
    1, mnist_labels_t.reshape(-1, 1), 1
)

In [None]:
mnist_data_t = mnist_data_t.to(device)
mnist_one_hot_labels_t = mnist_one_hot_labels_t.to(device)

I am taking the first `n_data` samples of the combined train and test sets. If you run on the CPU, consider using a smaller portion, e.g.
```
n_data = 20000
mnist_data_t = t.from_numpy(mnist_data_rescaled[0:n_data]).reshape(-1,1,28,28)
```

## Models

### Discriminator

<img src="GANs_in_Action/discriminator.jpg" style="margin: auto;padding: 25px;">

<img src="GANs_in_Action/discriminator_embedding.jpg" style="margin: auto; padding:25px;">

In [None]:
discriminator = cgan.Discriminator()
discriminator.to(device)

In [None]:
mnist_data_t[0:1].shape

In [None]:
o = discriminator(mnist_data_t[0:1], mnist_one_hot_labels_t[0:1])
o.shape

### Generator

<img src="GANs_in_Action/generator.jpg" style="margin: auto; padding: 25px;">

<img src="GANs_in_Action/generator_embedding.jpg" style="margin: auto; padding: 25px;">

In [None]:
z_dim=100

In [None]:
generator = cgan.Generator(z_dim)
generator.to(device)

In [None]:
g = generator(t.rand((2,z_dim), device=device), mnist_one_hot_labels_t[0:2] )

In [None]:
g.shape

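The data and models have already been moved to the device (CUDA when available). As a sanity check, generate and display one image from the freshly constructed generator.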

In [None]:
# Draw one latent vector uniformly from [-1, 1], condition it on the label of
# the first sample, and display what the generator produces.
generator.eval()
with t.no_grad():
    in_t = t.empty(1, z_dim, device=device).uniform_(-1, 1)
    out_t = generator(in_t, mnist_one_hot_labels_t[0:1])
    plt.imshow(out_t.data.cpu().numpy().squeeze(), cmap='Greys')

Binary Cross Entropy loss with labels $l_i$ is defined as:
$$bce(\{ p_i \},\{ l_i \}) = -\frac{1}{n}\sum_{i=0}^{n-1} \left(l_i \log p_i + (1-l_i) \log(1-p_i)\right)  $$

In [None]:
bce = t.nn.BCELoss()

In [None]:
t.cuda.empty_cache()

In [None]:
mnist_data_t.shape

In [None]:
mnist_one_hot_labels_t.shape

In [None]:
plt.imshow(mnist_data_t[0].cpu().numpy().squeeze() )

In [None]:
mnist_one_hot_labels_t[0]

In [None]:
dataset = t.utils.data.TensorDataset( mnist_data_t, mnist_one_hot_labels_t)

In [None]:
# Small mini-batches on the CPU, large ones otherwise; batches are reshuffled
# on every pass over the data.
mini_batch_size = 64 if device.type == 'cpu' else 2048
dataloader = t.utils.data.DataLoader(
    dataset, batch_size=mini_batch_size, shuffle=True
)

In [None]:
# Baseline: sum the per-batch discriminator and generator losses over one pass
# of the data (no gradients, both networks in eval mode) and print the
# per-batch averages.
discriminator.eval()
generator.eval()
d_loss = 0
g_loss = 0
n_batches = 0
with t.no_grad():
    for n_batches, (d, lbl) in enumerate(dataloader, start=1):
        d_loss = d_loss + cgan.discriminator_loss(discriminator, generator,  d, lbl, z_dim, device)
        g_loss = g_loss + cgan.generator_loss( discriminator, generator,len(d), lbl, z_dim, device)
print(d_loss.item()/n_batches, g_loss.item()/n_batches)

In [None]:
# Adam for both networks with betas (0.5, 0.999); the generator's learning
# rate is half the discriminator's.
d_optimizer = t.optim.Adam(params=discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999))
g_optimizer = t.optim.Adam(params=generator.parameters(), lr=1e-4, betas=(0.5, 0.999))

In [None]:
np.arange(10)

In [None]:
import os
import time

# Wall-clock reference handed to utils.estimate for the progress read-out.
start = time.time()

# Preview grid layout: rows * cols = 10 images, matching the ten digit labels.
cols = 5
rows = 2

# Fixed latent vectors and the ten one-hot digit labels, reused for every
# snapshot so that images saved at different epochs are directly comparable.
fixed_noise = t.empty(cols*rows,z_dim, device=device).uniform_(-1,1)
digits = t.zeros((10,10))
digits.scatter_(1,t.arange(10).reshape(-1,1),1)
digits = digits.to(device)

# Passed straight through to cgan.train_discriminator / cgan.train_generator.
# NOTE(review): presumably the number of update steps per mini-batch — confirm
# in GAN.C_DCGAN_MNIST.
k_discriminator = 2
k_generator = 1

n_epochs = 200

# Snapshots are written to img/; create the directory up front so the first
# save cannot fail on a fresh checkout (harmless if it already exists).
os.makedirs("img", exist_ok=True)

plt.ioff()
for epoch in range(1,n_epochs+1):
    discriminator.train()
    generator.train()
    for d,lbl in dataloader:
        d_loss = cgan.train_discriminator(discriminator, generator, d_optimizer, d, lbl, z_dim, k_discriminator, device)
        g_loss = cgan.train_generator(discriminator, generator, g_optimizer, len(d), lbl, z_dim, k_generator, device)

    # Progress line after every epoch; the losses shown are those of the last
    # mini-batch of the epoch, not an epoch average.
    clear_output(wait=True)
    ellapsed, remaining = utils.estimate(start, n_epochs, epoch)
    print('%5d %6.2f %6.2f %6.2fs %6.2fs\n' % (epoch, d_loss.item(), g_loss.item(), ellapsed, remaining))

    # Save a snapshot of the fixed-noise grid every fifth epoch.
    if epoch % 5 == 0:
        cgan.gen_and_save(rows, cols, generator, fixed_noise, digits, f"img/img_{epoch:03d}.png")
            

In [None]:
# One image per digit label from the fixed latent vectors.
# NOTE(review): this forward pass runs with autograd enabled and with the
# generator in whatever mode the previous cell left it; `imgs` is not read by
# any later cell visible here — consider t.no_grad() / generator.eval().
imgs = generator(fixed_noise, digits)

In [None]:
cgan.gen_and_save(rows, cols, generator, fixed_noise, digits, "a.png")

In [None]:
#save the generator
t.save(generator.state_dict(),"gan.pt")

In [None]:
cgan.gen_and_save(rows, cols, generator, fixed_noise, digits)

In [None]:
#Display original data
# Grid-specific names are used on purpose: `rows` / `cols` size `fixed_noise`
# in the training cell, and overwriting them here would break re-running the
# gen_and_save cells above.
grid_cols = 8
grid_rows = 4
fig, ax = plt.subplots(grid_rows, grid_cols, figsize=(1.5*grid_cols,1.5*grid_rows))
for i, j in itertools.product(range(grid_rows), range(grid_cols)):
    # Hide both axes, then draw one randomly chosen image from the full pool.
    ax[i,j].get_xaxis().set_visible(False)
    ax[i,j].get_yaxis().set_visible(False)
    ax[i,j].imshow(mnist_data[np.random.randint(0,len(mnist_data))], cmap='Greys')

A demonstration of how to load the generator.  

In [None]:
generator_loaded = cgan.Generator(z_dim)

In [None]:
# Restore the weights saved to gan.pt. map_location='cpu' makes the checkpoint
# loadable even when it was written from a device that does not exist on this
# machine (e.g. 'cuda:1'); the model itself is still on the CPU at this point.
generator_loaded.load_state_dict(t.load('gan.pt', map_location='cpu'))
generator_loaded.eval()

In [None]:
generator_loaded= generator_loaded.to(device)

In [None]:
# 100 one-hot labels: each of the ten digit rows repeated ten times in a row
# (rows 0-9 encode digit 0, rows 10-19 digit 1, and so on).
d100 = t.stack([digits] * 10, dim=1).reshape(-1, 10)

In [None]:
# One latent vector per label row, uniform in [-1, 1). The batch size follows
# d100 and the latent size follows z_dim instead of being hard-coded to 100,
# so the cell keeps working if either is changed.
fake_imgs = generator_loaded(2*t.rand(len(d100), z_dim, device=device)-1, d100)

In [None]:
utils.display_img_grid(10,10,fake_imgs.cpu().data.numpy().reshape(-1,28,28))

In [None]:
# Ten randomly drawn real images per digit (np.random.choice samples with
# replacement by default), ordered by digit and divided by 255.
real_imgs = mnist_data[
    np.concatenate([
        np.random.choice(np.where(mnist_labels == digit)[0], 10)
        for digit in range(10)
    ])
] / 255.0

In [None]:
utils.display_img_grid(10,10,real_imgs)

## Embedding

In [None]:
class Embedding(tnn.Module):
    """Trainable lookup table that maps integer indices to dense vectors.

    Parameters
    ----------
    in_features : int
        Number of rows in the table, i.e. the number of distinct indices.
    out_features : int
        Length of the vector stored for each index.
    s : float, optional
        Half-width of the uniform range ``[-s, s]`` used to initialise the table.
    """

    def __init__(self, in_features, out_features, s=0.01):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Assigning an nn.Parameter to a module attribute already registers it,
        # so no explicit register_parameter() call is needed.
        self.weights = tnn.Parameter(t.empty(in_features, out_features).uniform_(-s, s))

    def forward(self, x):
        """Return the table rows selected by the integer index tensor ``x``."""
        return self.weights[x]

In [None]:
embedding = Embedding(10,20)

In [None]:
embedding(t.LongTensor([0,3,0])).shape