# GAN
This file demonstrate how to use GAN Generator/Discriminator

In [29]:
import os, time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

from torch.autograd import Variable
%matplotlib inline

## HyperParameter

In [30]:
IMAGE_SIZE = 28
BATCH_SIZE = 128

## Generator/Discriminator

In [38]:
class Generator(nn.Module):
    
    def __init__(self, input_dim=64, image_size=28):
        super(Generator, self).__init__()
        self.input_dim = input_dim
        self.image_size = image_size
        self.output_dim = 1 * image_size * image_size  # same as origin image
        
        self.fc = nn.Sequential(
            nn.Linear(self.input_dim, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, self.output_dim),
            nn.BatchNorm1d(self.output_dim),
            nn.Tanh(), # ~[-1, 1]
        )
    
    def forward(self, input):
        x = self.fc(input)
        x = x.view(-1, self.output_dim)
        return x
    

In [39]:
class Discriminator(nn.Module):
    
    def __init__(self, image_size=28):
        super(Discriminator, self).__init__()
        self.input_dim = 1 * image_size * image_size
        
        self.fc = nn.Sequential(
            nn.Linear(self.input_dim, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )
    
    def forward(self, input):
        x = self.fc(input)
        return x

## Loading Data -> Mnist

In [40]:
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

def loading_data(input_size=28, batch_size=128):
    transform = transforms.Compose([
                    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
                    transforms.Grayscale(),
                    transforms.ToTensor(),
                    transforms.Normalize(mean=(0.5,), std=(0.5,)) # [0, 1] -> [-1, 1]
    ])

    data_loader = DataLoader(
        datasets.MNIST("../data/mnist", train=True, download=True, transform=transform),
        batch_size=batch_size,
        shuffle=True,
    )
    
    return data_loader

## GAN Object

In [41]:
discriminator = Discriminator(image_size=IMAGE_SIZE)
generator = Generator(image_size=IMAGE_SIZE)

## Optimization

In [42]:
lossF = nn.BCELoss()
lrG = 0.0002
lrD = 0.0002
g_optimizer = torch.optim.Adam(generator.parameters(), lr=lrG)
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=lrD)

In [44]:
def train_discriminator(generator, discriminator, optimizer, real_data, batch_size, z_size):
    # Reshape real_data to vector
    real_data = real_data.view(batch_size, -1)
    # Rescale real_data to range -1 - 1
    real_data = scale(real_data)
    
    # Reset gradients and set model to training mode
    optimizer.zero_grad()
    discriminator.train()
    
    # Train on real data
    real_data_logits = discriminator.forward(real_data)
    loss_real = real_loss(real_data_logits, smooth=True)
    # Generate fake data
    z_vec = random_vector(batch_size, z_size)
    fake_data = generator.forward(z_vec)
    # Train on fake data
    fake_data_logits = discriminator.forward(fake_data)
    loss_fake = fake_loss(fake_data_logits)
    # Calculate total loss
    total_loss = loss_real + loss_fake
    total_loss.backward()
    optimizer.step()
    
    return total_loss

def train_generator(generator, discriminator, optimizer, batch_size, z_size):
    # Reset gradients and set model to training mode
    optimizer.zero_grad()
    generator.train()
    # Generate fake data
    z_vec = random_vector(batch_size, z_size)
    fake_data = generator.forward(z_vec)
    # Train generator with output of discriminator
    discriminator_logits = discriminator.forward(fake_data)
    # Reverse labels
    loss = real_loss(discriminator_logits)
    loss.backward()
    optimizer.step()
    return loss

## Loss function

In [45]:
def real_loss(predictions, smooth=False):
    batch_size = predictions.shape[0]
    labels = torch.ones(batch_size)
    # Smooth labels for discriminator to weaken learning
    if smooth:
        labels = labels * 0.9
    # We use the binary cross entropy loss | Model has a sigmoid function
    criterion = nn.BCELoss()
    # Move models to GPU if available
    if torch.cuda.is_available():
        labels = labels.cuda()
        criterion = criterion.cuda()
    loss = criterion(predictions.squeeze(), labels)
    return loss

def fake_loss(predictions):
    batch_size = predictions.shape[0]
    labels = torch.zeros(batch_size)
    criterion = nn.BCELoss()
    # Move models to GPU if available
    if torch.cuda.is_available():
        labels = labels.cuda()
        criterion = criterion.cuda()
    loss = criterion(predictions.squeeze(), labels)
    return loss

## Training

### Hyper Parameter

In [18]:
EPOCHES = 50
IMAGE_SIZE = 28
BATCH_SIZE = 128
SAMPLE_SIZE = 8 # Show numbers of image
PLOT_EVERY = 5  # plot every epoch
Z_INDEX = 100

sample_noise = Variable(torch.randn(BATCH_SIZE, Z_INDEX))

### Start Training

In [46]:
d_losses = []
g_losses = []

data_loader = loading_data(input_size=IMAGE_SIZE, batch_size=BATCH_SIZE)

for e in range(EPOCHES):
    for n, (images, _) in enumerate(data_loader):
        assert (images.shape[0] == BATCH_SIZE)
        
        d_loss = 
        

AssertionError: 

tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
