# Simple GAN

In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from matplotlib import pyplot as plt

We will apply a gan to match a Gaussian 0,1 (noise) to a Gaussian 10,1 (real data distribution) 

## Prep Data

In [2]:
def get_generated_data_sample(n):
    return torch.Tensor(np.random.normal(0, 1, (n,1)))

In [3]:
# Real data here follows a Gaussian Distribution of mean 10, std 1
def get_real_data_sample(n):
    return torch.Tensor(np.random.normal(10, 1, (n,1)))

In [4]:
get_generated_data_sample(1),get_real_data_sample(1)

(tensor([[-1.7694]]), tensor([[ 11.5412]]))

## Define Generator and Discriminator Classes

In [5]:
class Generator(nn.Module):
    
    def __init__(self):
        super(Generator, self).__init__()

        self.lt1 = nn.Linear(1, 50)
        self.lt2 = nn.Linear(50, 1)

    def forward(self, x):
        
        x = F.relu(self.lt1(x))
        return self.lt2(x)
    

In [6]:
class Discriminator(nn.Module):
    
    def __init__(self):
        super(Discriminator, self).__init__()
        self.lt1 = nn.Linear(1, 50)
        self.lt2 = nn.Linear(50, 50)
        self.lt3 = nn.Linear(50, 1)

    def forward(self, x):
        x = F.relu(self.lt1(x))
        x = F.relu(self.lt2(x))
        return F.sigmoid(self.lt3(x))

In [9]:
G = Generator()
D = Discriminator()

In [11]:
loss = nn.BCELoss()  # Binary cross entropy: http://pytorch.org/docs/nn.html#bceloss

In [13]:
d_optimizer = optim.Adam(D.parameters(), lr=1e-3)
g_optimizer = optim.Adam(G.parameters(), lr=1e-3)

## Define Our Training Code

In [14]:
def train_discriminator(num_epochs = 0):
    global batch_size
    
    for d_index in range(num_epochs):

        # 1. Train D on real+fake
        D.zero_grad()

        #  A: Train D on real
        d_real_data = Variable(get_real_data_sample(batch_size))
        d_real_decision = D(d_real_data)
        d_real_error = loss(d_real_decision, Variable(torch.ones(batch_size,1)))  # ones = true
        d_real_error.backward() # compute/store gradients, but don't change params

        #  B: Train D on fake
        d_gen_input = Variable(get_generated_data_sample(batch_size))
        d_fake_data = G(d_gen_input).detach()  # detach to avoid training G on these labels
        d_fake_decision = D(d_fake_data)
        d_fake_error = loss(d_fake_decision, Variable(torch.zeros(batch_size,1)))  # zeros = fake
        d_fake_error.backward()

        d_optimizer.step()    # Only optimizes D's parameters; changes based on stored gradients from backward()

        return d_real_error+d_fake_error

In [19]:
def train_generator(num_epochs = 0):
    global batch_size
    
    for g_index in range(num_epochs):

        # 2. Train G on D's response
        G.zero_grad()

        gen_input = Variable(get_generated_data_sample(batch_size))
        g_fake_data = G(gen_input)
        d_fake_decision = D(g_fake_data)
        
        # here we want to fool, so we need to make the generator go towards producing 1s
        
        
        g_error = loss(d_fake_decision, Variable(torch.ones(batch_size,1))) 

        g_error.backward()
        g_optimizer.step()  # Only optimizes G's parameters

        return g_error

In [20]:
def train(num_epochs):
    for epoch in range(num_epochs):

        d_loss = train_discriminator(2) #train discriminator with a 2:1 ratio
        g_loss = train_generator(1)


        if epoch % print_interval == 0:
            print("{}: Total Loss: {} / D: {} / G: {}".format(epoch,(d_loss+g_loss).item(),d_loss.item(),g_loss.item()))

## TRAIN

In [21]:
num_epochs = 2000
batch_size = 64
d_steps = 2  # 'k' steps in the original GAN paper. Can put the discriminator on higher training freq than generator
g_steps = 1
print_interval = 200

In [22]:
train(num_epochs)

0: Total Loss: 1.9231290817260742 / D: 1.2399563789367676 / G: 0.6831727027893066
200: Total Loss: 2.7783594131469727 / D: 0.42315608263015747 / G: 2.35520339012146
400: Total Loss: 2.113056182861328 / D: 1.3129521608352661 / G: 0.8001041412353516
600: Total Loss: 2.11081862449646 / D: 1.3895102739334106 / G: 0.7213083505630493
800: Total Loss: 2.2005233764648438 / D: 1.3979523181915283 / G: 0.8025710582733154
1000: Total Loss: 2.0480828285217285 / D: 1.3868937492370605 / G: 0.661189079284668
1200: Total Loss: 2.110715866088867 / D: 1.3879331350326538 / G: 0.7227827310562134
1400: Total Loss: 2.048593521118164 / D: 1.388087272644043 / G: 0.6605063080787659
1600: Total Loss: 2.025691032409668 / D: 1.3892383575439453 / G: 0.6364527344703674
1800: Total Loss: 2.0215823650360107 / D: 1.3870400190353394 / G: 0.6345422863960266


In [23]:
generated = G(torch.Tensor([0]));generated

tensor([ 7.9089])

In [24]:
confidence = D(torch.Tensor([0])); confidence

tensor([ 0.1084])