# Simple GAN

In [12]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from matplotlib import pyplot as plt

We will apply a gan to match a Gaussian 0,1 (noise) to a Gaussian 10,1 (real data distribution) 

## Prep Data

In [13]:
def get_generated_data_sample(n):
    return torch.Tensor(np.random.normal(0, 1, (n,1)))

In [14]:
# Real data here follows a Gaussian Distribution of mean 10, std 1
def get_real_data_sample(n):
    return torch.Tensor(np.random.normal(10, 1, (n,1)))

In [15]:
get_generated_data_sample(1),get_real_data_sample(1)

(tensor([[ 0.4486]]), tensor([[ 11.6524]]))

## Define Generator and Discriminator Classes

In [16]:
class Generator(nn.Module):
    
    def __init__(self):
        super(Generator, self).__init__()

        self.lt1 = nn.Linear(1, 50)
        self.lt2 = nn.Linear(50, 1)

    def forward(self, x):
        
        x = F.relu(self.lt1(x))
        return self.lt2(x)
    

In [17]:
class Discriminator(nn.Module):
    
    def __init__(self):
        super(Discriminator, self).__init__()
        self.lt1 = nn.Linear(1, 50)
        self.lt2 = nn.Linear(50, 50)
        self.lt3 = nn.Linear(50, 1)

    def forward(self, x):
        x = F.relu(self.lt1(x))
        x = F.relu(self.lt2(x))
        return F.sigmoid(self.lt3(x))

In [18]:
G = Generator()
D = Discriminator()

In [19]:
loss = nn.BCELoss()  # Binary cross entropy: http://pytorch.org/docs/nn.html#bceloss

In [20]:
d_optimizer = optim.Adam(D.parameters(), lr=1e-3)
g_optimizer = optim.Adam(G.parameters(), lr=1e-3)

## Define Our Training Code

In [21]:
def train_discriminator(num_epochs = 0):
    global batch_size
    
    for d_index in range(num_epochs):

        # 1. Train D on real+fake
        D.zero_grad()

        #  A: Train D on real
        d_real_data = Variable(get_real_data_sample(batch_size))
        d_real_decision = D(d_real_data)
        d_real_error = loss(d_real_decision, Variable(torch.ones(batch_size,1)))  # ones = true
        d_real_error.backward() # compute/store gradients, but don't change params

        #  B: Train D on fake
        d_gen_input = Variable(get_generated_data_sample(batch_size))
        d_fake_data = G(d_gen_input).detach()  # detach to avoid training G on these labels
        d_fake_decision = D(d_fake_data)
        d_fake_error = loss(d_fake_decision, Variable(torch.zeros(batch_size,1)))  # zeros = fake
        d_fake_error.backward()

        d_optimizer.step()    # Only optimizes D's parameters; changes based on stored gradients from backward()

        return d_real_error+d_fake_error

In [22]:
def train_generator(num_epochs = 0):
    global batch_size
    
    for g_index in range(num_epochs):

        # 2. Train G on D's response
        G.zero_grad()

        gen_input = Variable(get_generated_data_sample(batch_size))
        g_fake_data = G(gen_input)
        d_fake_decision = D(g_fake_data)
        
        # here we want to fool, so we need to make the generator go towards producing 1s
        g_error = loss(d_fake_decision, Variable(torch.ones(batch_size,1)))  

        g_error.backward()
        g_optimizer.step()  # Only optimizes G's parameters

        return g_error

In [23]:
def train(num_epochs):
    for epoch in range(num_epochs):

        d_loss = train_discriminator(d_steps)
        g_loss = train_generator(g_steps)


        if epoch % print_interval == 0:
            print("{}: Total Loss: {} / D: {} / G: {}".format(epoch,(d_loss+g_loss).item(),d_loss.item(),g_loss.item()))

## TRAIN

In [33]:
num_epochs = 2000
batch_size = 64   # Minibatch size - cardinality of distributions
d_steps = 2  # 'k' steps in the original GAN paper. Can put the discriminator on higher training freq than generator
g_steps = 1
print_interval = 200

In [34]:
train(num_epochs)

0: Total Loss: 2.1040873527526855 / D: 1.3847991228103638 / G: 0.7192882299423218
200: Total Loss: 2.0813894271850586 / D: 1.3695675134658813 / G: 0.7118218541145325
400: Total Loss: 2.071267604827881 / D: 1.3736796379089355 / G: 0.6975880861282349
600: Total Loss: 2.092691421508789 / D: 1.386200189590454 / G: 0.7064912915229797
800: Total Loss: 2.0968101024627686 / D: 1.3837002515792847 / G: 0.7131098508834839
1000: Total Loss: 2.101151943206787 / D: 1.385023593902588 / G: 0.7161283493041992
1200: Total Loss: 2.0602939128875732 / D: 1.4204730987548828 / G: 0.6398208737373352
1400: Total Loss: 2.0187182426452637 / D: 1.0863401889801025 / G: 0.9323779940605164
1600: Total Loss: 1.8058009147644043 / D: 1.1107898950576782 / G: 0.6950110793113708
1800: Total Loss: 2.0861549377441406 / D: 1.545760154724121 / G: 0.5403948426246643


In [43]:
generated = G(torch.Tensor([0]));generated

tensor([ 12.2723])

In [39]:
confidence = D(torch.Tensor([0])); confidence

tensor([ 0.1158])