Notebook following this blogpost: <br>
https://medium.com/@devnag/generative-adversarial-networks-gans-in-50-lines-of-code-pytorch-e81b79659e3f

In [1]:
import random
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import torchvision
import torchvision.transforms as transforms
from torch.autograd import Variable

## Defining helper functions

In [2]:
# QUESTION: what does this function do?
def decorate_with_diffs(data, exponent):
    mean = torch.mean(data.data, 1, keepdim=True)
    mean_broadcast = torch.mul(torch.ones(data.size()), mean.tolist()[0][0])
    diffs = torch.pow(data - Variable(mean_broadcast), exponent)
    return torch.cat([data, diffs], 1)

In [3]:
# returns a list from a variable object
def extract(v):
    return v.data.storage().tolist()

# returns mu,variance for a given array (not variable!)
# works with extract, see example below
def stats(d):
    return [np.mean(d), np.std(d)]

In [4]:
x = Variable(torch.Tensor(torch.randn(2,2)))
print("Our tensor:", x, sep ="\n")
print(extract(x))
stats(extract(x))

Our tensor:
Variable containing:
 0.0271 -0.8832
 0.7581 -0.2634
[torch.FloatTensor of size 2x2]

[0.027085527777671814, -0.883240818977356, 0.7581320405006409, -0.2634291648864746]


[-0.090363103896379471, 0.58998837354775158]

## Parameters

In [5]:
# Data params
data_mean = 4
data_stddev = 1.25

# Model params
g_input_size = 1     # Random noise dimension coming into generator, per output vector
g_hidden_size = 50   # Generator complexity
g_output_size = 1    # size of generated output vector
d_input_size = 100   # Minibatch size - cardinality of distributions
d_hidden_size = 50   # Discriminator complexity
d_output_size = 1    # Single dimension for 'real' vs. 'fake'
minibatch_size = d_input_size

d_learning_rate = 2e-4  # 2e-4
g_learning_rate = 2e-4
optim_betas = (0.9, 0.999)
num_epochs = 30000
print_interval = 200
d_steps = 1  # 'k' steps in the original GAN paper. Can put the discriminator on higher training freq than generator
g_steps = 1

In [6]:
# Uncomment only one of these
(name, preprocess, d_input_func) = ("Raw data", lambda data: data, lambda x: x)
#(name, preprocess, d_input_func) = ("Data and variances", lambda data: decorate_with_diffs(data, 2.0), lambda x: x * 2)

print("Using data [%s]" % (name))

Using data [Raw data]


In [7]:
#returns gaussian function with mean mu and var sigma
def get_distribution_sampler(mu, sigma):
    return lambda n: torch.Tensor(np.random.normal(mu, sigma, (1, n)))  # Gaussian


#returns a funciton capable of generation random numbers
def get_generator_input_sampler():
    return lambda m, n: torch.rand(m, n) # Uniform-dist data into generator, _NOT_ Gaussian

### Models

In [8]:
class Generator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Generator, self).__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = F.elu(self.map1(x))
        x = F.sigmoid(self.map2(x))
        return self.map3(x) # no activations!

In [9]:
class Discriminator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Discriminator, self).__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = F.elu(self.map1(x))
        x = F.elu(self.map2(x))
        return F.sigmoid(self.map3(x)) 

In [10]:
# create samplers for networks
d_sampler = get_distribution_sampler(data_mean, data_stddev)
gi_sampler = get_generator_input_sampler()

# create networks, see params higher
G = Generator(input_size=g_input_size, hidden_size=g_hidden_size, output_size=g_output_size)
D = Discriminator(input_size=d_input_func(d_input_size), hidden_size=d_hidden_size, output_size=d_output_size)

#define optimizier for both networks
criterion = nn.BCELoss()  # Binary cross entropy: http://pytorch.org/docs/nn.html#bceloss
d_optimizer = optim.Adam(D.parameters(), lr=d_learning_rate, betas=optim_betas)
g_optimizer = optim.Adam(G.parameters(), lr=g_learning_rate, betas=optim_betas)

In [11]:
for epoch in range(num_epochs):
    
    for d_index in range(d_steps):
        # 1. Train D on real+fake
        D.zero_grad()

        #  1A: Train D on real
        d_real_data = Variable(d_sampler(d_input_size))
        d_real_decision = D(preprocess(d_real_data))
        d_real_error = criterion(d_real_decision, Variable(torch.ones(1)))  # ones = true
        d_real_error.backward() # compute/store gradients, but don't change params

        #  1B: Train D on fake
        d_gen_input = Variable(gi_sampler(minibatch_size, g_input_size))
        d_fake_data = G(d_gen_input).detach()  # detach to avoid training G on these labels
        d_fake_decision = D(preprocess(d_fake_data.t()))
        d_fake_error = criterion(d_fake_decision, Variable(torch.zeros(1)))  # zeros = fake
        d_fake_error.backward()
        d_optimizer.step()     # Only optimizes D's parameters; changes based on stored gradients from backward()

        
    for g_index in range(g_steps):
        # 2. Train G on D's response (but DO NOT train D on these labels)
        G.zero_grad()

        gen_input = Variable(gi_sampler(minibatch_size, g_input_size))
        g_fake_data = G(gen_input)
        dg_fake_decision = D(preprocess(g_fake_data.t()))
        g_error = criterion(dg_fake_decision, Variable(torch.ones(1)))  # we want to fool, so pretend it's all genuine

        g_error.backward()
        g_optimizer.step()  # Only optimizes G's parameters

    if epoch % print_interval == 0:
        print("%s: D: %s/%s G: %s \n(Real: %s, Fake: %s) " % (epoch,
                                                            extract(d_real_error)[0],
                                                            extract(d_fake_error)[0],
                                                            extract(g_error)[0],
                                                            stats(extract(d_real_data)),
                                                            stats(extract(d_fake_data))))

0: D: 1.2101160287857056/0.6485859155654907 G: 0.7444975972175598 
(Real: [4.0160312485694885, 1.2764931997490707], Fake: [-0.2030932642519474, 0.001100026122854983]) 


  "Please ensure they have the same size.".format(target.size(), input.size()))


200: D: 0.17556071281433105/0.5551384091377258 G: 0.8623489141464233 
(Real: [3.9404023426771162, 1.2416657755338121], Fake: [1.3736900115013122, 0.070557733591906452]) 
400: D: 0.8309756517410278/0.8656553626060486 G: 0.5659179091453552 
(Real: [4.0358939003944396, 1.2840827517909839], Fake: [3.665249991416931, 0.11772644355718874]) 
600: D: 0.8060558438301086/0.8244041800498962 G: 0.5962673425674438 
(Real: [4.1958333897590636, 1.3143381074060152], Fake: [5.277332954406738, 0.10932349391769729]) 
800: D: 0.9875590205192566/0.4799656271934509 G: 1.0010446310043335 
(Real: [3.873040654063225, 1.2243600347548818], Fake: [5.9661702203750613, 0.10097086839537393]) 
1000: D: 0.7064688801765442/0.40801751613616943 G: 1.1727526187896729 
(Real: [4.0906891226768494, 1.1153377969224756], Fake: [5.6164121770858761, 0.10323083813304613]) 
1200: D: 0.40916234254837036/0.6069835424423218 G: 0.8213219046592712 
(Real: [4.0312618210911753, 1.3107701121734781], Fake: [4.241436047554016, 0.11454982961

10000: D: 0.7772295475006104/0.7340304851531982 G: 0.5596606135368347 
(Real: [4.0489641654491422, 1.2296461622883443], Fake: [3.046853095293045, 1.2732438684621783]) 
10200: D: 0.8809556365013123/0.4738127887248993 G: 0.8162707090377808 
(Real: [4.0264144527912142, 1.0627024098836459], Fake: [4.772911756038666, 1.128546453457679]) 
10400: D: 0.6570205092430115/0.7955340147018433 G: 0.5795419216156006 
(Real: [3.9849056173861026, 1.355779374991402], Fake: [3.2306370759010314, 1.2181721279144688]) 
10600: D: 0.7312461733818054/0.6392481327056885 G: 0.7741319537162781 
(Real: [4.0229061871767042, 1.285692880132558], Fake: [4.808013472557068, 1.0787925434220746]) 
10800: D: 0.7360827922821045/0.71989905834198 G: 0.6758870482444763 
(Real: [4.2712983047962192, 1.2936568086199165], Fake: [3.655003753900528, 1.2057611452181296]) 
11000: D: 0.7715901136398315/0.6487645506858826 G: 0.9155619144439697 
(Real: [3.8807445156574247, 1.2592725705443231], Fake: [4.4458279657363891, 1.117475393645885

19800: D: 0.835269570350647/0.5151991248130798 G: 0.6350280046463013 
(Real: [3.8301703524589539, 1.2480546073649941], Fake: [4.3116613805294035, 1.1643201445779585]) 
20000: D: 0.6759535074234009/0.7452389597892761 G: 0.6394375562667847 
(Real: [3.9457636463642118, 1.3692178061722438], Fake: [3.3674484515190124, 1.3482518133976642]) 
20200: D: 0.7020516991615295/0.6686930060386658 G: 0.690633237361908 
(Real: [4.0230103302001954, 1.2289659482187014], Fake: [3.9353762686252596, 1.2018887976381365]) 
20400: D: 0.737435519695282/0.6922870874404907 G: 0.8274880051612854 
(Real: [3.8425655543804167, 1.0942578549386885], Fake: [4.3888791835308076, 1.1774804242792212]) 
20600: D: 0.6452497243881226/0.6285393834114075 G: 0.7303392291069031 
(Real: [3.9005243074893952, 1.3000646815722825], Fake: [4.2294338858127594, 1.2225968897719497]) 
20800: D: 0.6966481804847717/0.6665687561035156 G: 0.7061999440193176 
(Real: [3.9519964307546616, 1.2246550064894532], Fake: [3.6354426729679106, 1.223393232

29600: D: 0.6855615377426147/0.8035569190979004 G: 0.651859700679779 
(Real: [4.1214955931901933, 1.2953032428029383], Fake: [4.0347378504276277, 1.1667436116449288]) 
29800: D: 0.6599339842796326/0.740149974822998 G: 0.6223009824752808 
(Real: [3.8921004545688631, 1.2139082858719474], Fake: [4.2401549530029294, 1.234683474612206]) 


In [13]:
#TODO: visualize learning using matplotlib?