### Conditional Deep Convolutional Generative Adversarial Network
view https://arxiv.org/pdf/1605.05396.pdf

In [315]:
import torch
from torch import nn
import torchvision.datasets
import numpy as np
import matplotlib.pyplot as plt

In [316]:
batch_size = 16
%matplotlib inline
plt.rcParams['image.cmap'] = 'gray'

In [317]:
transform = torchvision.transforms.ToTensor()
mnist_train = torchvision.datasets.MNIST('./MNIST_data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size)
mnist_test = torchvision.datasets.MNIST('./MNIST_data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size)

In [318]:
class Flatten(nn.Module):
    def forward(self, input):
        flattened = input.view(input.shape[0], -1)
        return flattened

In [319]:
class Unflatten(nn.Module):
    def __init__(self, C=128, H=7, W=7):
        super(Unflatten, self).__init__()
        self.C = C
        self.H = H
        self.W = W
        
    def forward(self, input):
        unflattened = input.view(-1, self.C, self.H, self.W)
        return unflattened

In [320]:
class Add_Conditional(nn.Module):
    def __init__(self, conditional=torch.zeros(16)):
        super(Add_Conditional, self).__init__()
        self.conditional= conditional
        
    def forward(self, input):
        print("INPUT SHAPE", input.shape)
        conditional_array = np.expand_dims(np.array(self.conditional), axis=1)
        conditional_tensor = torch.FloatTensor(conditional_array)
        print("conditional_tensor", self.conditional.shape)
        concatenated = torch.cat((conditional_tensor, input), dim=1)
        return concatenated

In [321]:
def generate_nosie(batch_size, dim=96):
    noise = torch.rand(batch_size, dim) * 2 - 1
    return noise

In [322]:
def CNN(conditional):
    print("Conditional CNN", conditional.shape)
    model = nn.Sequential(
        nn.Conv2d(1, 32, [5,5], stride=[1,1]),
        nn.LeakyReLU(negative_slope=.01),
        nn.MaxPool2d([2,2], stride=[2,2]),
        nn.Conv2d(32, 64, [5,5], stride=[1,1]),
        nn.LeakyReLU(negative_slope=.01),
        nn.MaxPool2d([2,2], stride=[2,2]),
        Flatten(),
        Add_Conditional(conditional=conditional),
        Add_Conditional(conditional),
        nn.Linear((4*4*64 + 1), (4*4*64)), 
        nn.LeakyReLU(negative_slope=.01),
        nn.Linear((4*4*64), 1)
    )
    return model

In [323]:
def generator(conditional, noise_dim=96):
    print("Conditional", conditional.shape)
    model = nn.Sequential(
        nn.Linear(noise_dim, 1024),
        nn.ReLU(),
        nn.BatchNorm1d(1024),
        Add_Conditional(conditional=conditional),
        #1025 b/c added one from conditional
        nn.Linear(1025, (7*7*128)),
        nn.ReLU(),
        nn.BatchNorm1d(7*7*128),
        Unflatten(C=128, H=7, W=7),
        nn.ConvTranspose2d(128, 64, [4,4], stride=[2,2], padding=1),
        nn.ReLU(),
        nn.BatchNorm2d(64),
        nn.ConvTranspose2d(64, 1, [4,4], stride=[2,2], padding=1),
        nn.Tanh(),
        Flatten()
    )
    return model

In [324]:
def create_optimizer(model, lr=.01, betas=None):
    if betas == None:
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    else:
        optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=betas)
    return optimizer

In [325]:
def discriminator_loss(scores_real, scores_fake):
    true_labels = torch.ones_like(scores_real)
    valid_loss = torch.mean((scores_real - true_labels) ** 2) * .5
    invalid_loss = torch.mean(scores_fake ** 2) * .5
    loss = valid_loss + invalid_loss
    return loss

In [326]:
def generator_loss(scores_fake):
    true_labels = torch.ones_like(scores_fake)
    loss = torch.mean((scores_fake - true_labels) ** 2) * .5
    return loss

In [327]:
def show_image(images):
#     for image in images:
    images_np = images.detach().numpy().squeeze()
#     side_length = np.sqrt(images.shape[1])
#     print("side length", side_length, images.shape)
#     assert(side_length % 1 == 0), "images not square shape"
#     image_show = images[0]
#     image_unflattened = np.reshape(image_show, (int(side_length), int(side_length)))
#     plt.imshow(image_unflattened)
    plt.imshow(images_np[0])
    plt.show()

In [328]:
def train_gan(generator, discriminator, image_loader, epochs, num_train_batches=-1):
    generator_optimizer = create_optimizer(generator, lr=1e-3, betas=(.5, .999))
    discriminator_optimizer = create_optimizer(discriminator, lr=1e-3, betas=(.5, .999))
    iters = 0
    for epoch in range(epochs):
        for i, (examples, labels) in enumerate(image_loader):
            if i == num_train_batches:
                break
            generator_optimizer.zero_grad()
            discriminator_optimizer.zero_grad()
            z = generate_nosie(batch_size)
            images_fake = generator(z, labels)
            images_fake_unflattened = images_fake.view(images_fake.shape[0], 1, 28, 28)
            scores_fake = discriminator(images_fake_unflattened)
            
            ##TODO, fix scores_fake 10 class problem
            
            g_cost = generator_loss(scores_fake)
            g_cost.backward(retain_graph=True)
            generator_optimizer.step()

            scores_real = discriminator(examples, labels)
            d_cost = discriminator_loss(scores_real, scores_fake)
            d_cost.backward()
            discriminator_optimizer.step()
            iters += 1
            if iters % 100  == 0:
                print("Iteration:", iters)
                print("Discriminator Cost", d_cost)
                print("Generator Cost", g_cost)
                show_image(images_fake_unflattened)
                

    return generator, discriminator

In [330]:
filler_conditonal = torch.zeros(16)
generator = generator(filler_conditonal)
discriminator = CNN(filler_conditonal)
image_loader = train_loader
epochs = 5
num_train_batches = 100

Conditional torch.Size([16])
Conditional CNN torch.Size([16])


In [331]:
train_gan(generator, discriminator, image_loader, epochs, num_train_batches=num_train_batches)

TypeError: forward() takes 2 positional arguments but 3 were given