# Simple GAN

In [2]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from matplotlib import pyplot as plt

We will apply a gan to match a Gaussian 0,1 (noise) to a Gaussian 10,1 (real data distribution) 

## Prep Data

define a function to generate noise

In [3]:
def get_generated_data_sample(n):
    return torch.Tensor(np.random.normal(0, 1, (n,1)))

define a function to retreive real data

In [4]:
# Real data here follows a Gaussian Distribution of mean 10, std 1
def get_real_data_sample(n):
    return torch.Tensor(np.random.normal(10, 1, (n,1)))

In [15]:
get_generated_data_sample(1)

tensor([[-0.6251]])

In [21]:
get_real_data_sample(1)

tensor([[ 11.4471]])

## Define Generator and Discriminator Classes

In [32]:
class Generator(nn.Module):
    
    def __init__(self):
        super(Generator, self).__init__()

        self.lt1 = nn.Linear(1, 50)
        self.lt2 = nn.Linear(50, 1)

    def forward(self, x):
        
        x = F.relu(self.lt1(x))
        return self.lt2(x)
    

In [33]:
class Discriminator(nn.Module):
    
    def __init__(self):
        super(Discriminator, self).__init__()
        self.lt1 = nn.Linear(1, 50)
        self.lt2 = nn.Linear(50, 1)

    def forward(self, x):
        x = F.relu(self.lt1(x))
        return F.sigmoid(self.lt2(x))

In [34]:
G = Generator()
D = Discriminator()

In [35]:
loss = nn.BCELoss()  # Binary cross entropy: http://pytorch.org/docs/nn.html#bceloss

In [36]:
d_optimizer = optim.Adam(D.parameters(), lr=1e-3)
g_optimizer = optim.Adam(G.parameters(), lr=1e-3)

## Define Our Training Code

In [37]:
def train_discriminator(num_epochs = 0):
    global batch_size
    
    for d_index in range(num_epochs):

        # 1. Train Discriminator on real+fake
        D.zero_grad()

        #  A: Train D on real
        d_real_data = Variable(get_real_data_sample(batch_size))
        d_real_decision = D(d_real_data)
        d_real_error = loss(d_real_decision, Variable(torch.ones(batch_size,1)))  # ones = true
        d_real_error.backward() # compute/store gradients, but don't change params

        #  B: Train D on fake
        d_gen_input = Variable(get_generated_data_sample(batch_size))
        d_fake_data = G(d_gen_input).detach()  # detach to avoid training G on these labels
        d_fake_decision = D(d_fake_data)
        d_fake_error = loss(d_fake_decision, Variable(torch.zeros(batch_size,1)))  # zeros = fake
        d_fake_error.backward()

        d_optimizer.step()    # Only optimizes D's parameters; changes based on stored gradients from backward()

        return d_real_error+d_fake_error

In [38]:
def train_generator(num_epochs = 0):
    global batch_size
    
    for g_index in range(num_epochs):

        # 2. Train Generator on D's response to fake data
        G.zero_grad()

        gen_input = Variable(get_generated_data_sample(batch_size))
        g_fake_data = G(gen_input)
        d_fake_decision = D(g_fake_data)
        
        # here we want to fool, so we need to make the generator go towards producing 1s
        
        
        g_error = loss(d_fake_decision, Variable(torch.ones(batch_size,1))) 

        g_error.backward()
        g_optimizer.step()  # Only optimizes G's parameters

        return g_error

In [39]:
def train(num_epochs):
    for epoch in range(num_epochs):

        d_loss = train_discriminator(2) #train discriminator with a 2:1 ratio
        g_loss = train_generator(1)


        if epoch % print_interval == 0:
            print("{}: Total Loss: {} / D: {} / G: {}".format(epoch,(d_loss+g_loss).item(),d_loss.item(),g_loss.item()))

## TRAIN

In [40]:
num_epochs = 4000
batch_size = 64
d_steps = 2  # 'k' steps in the original GAN paper. Can put the discriminator on higher training freq than generator
g_steps = 1
print_interval = 200

In [46]:
train(num_epochs)

0: Total Loss: 1.941429615020752 / D: 1.3334829807281494 / G: 0.6079465746879578
200: Total Loss: 2.1247782707214355 / D: 1.3338627815246582 / G: 0.7909154295921326
400: Total Loss: 2.084935188293457 / D: 1.3824303150177002 / G: 0.7025049924850464
600: Total Loss: 2.0343194007873535 / D: 1.432154655456543 / G: 0.6021648049354553
800: Total Loss: 2.19767427444458 / D: 1.4093291759490967 / G: 0.7883450984954834
1000: Total Loss: 1.9563608169555664 / D: 1.3526643514633179 / G: 0.6036964654922485
1200: Total Loss: 2.1178717613220215 / D: 1.3607864379882812 / G: 0.757085382938385
1400: Total Loss: 2.151245355606079 / D: 1.4218370914459229 / G: 0.729408323764801
1600: Total Loss: 1.9929895401000977 / D: 1.3996505737304688 / G: 0.5933390259742737
1800: Total Loss: 2.101590156555176 / D: 1.3573167324066162 / G: 0.7442735433578491
2000: Total Loss: 2.146493911743164 / D: 1.4014475345611572 / G: 0.7450463771820068
2200: Total Loss: 1.9973479509353638 / D: 1.3951616287231445 / G: 0.60218632221221

In [47]:
z = 0

In [48]:
generated = G(torch.Tensor([z]));generated

tensor([ 9.9681])

In [49]:
confidence = D(G(torch.Tensor([z]))); confidence

tensor([ 0.4938])