In [2]:
import torch
import torch.nn
import torch.nn.functional as nn
import torch.autograd as autograd
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
from torch.autograd import Variable
from tensorflow.examples.tutorials.mnist import input_data


# Reproducibility: seed the RNGs this notebook relies on. torch drives the
# noise batches z and the layer initialisation of G / D_; numpy is presumably
# what the TF MNIST helper uses to shuffle batches (confirm for your TF version).
seed = 0
torch.manual_seed(seed)
np.random.seed(seed)

# MNIST with one-hot labels (labels are loaded but not used by this EBGAN).
mnist = input_data.read_data_sets('../../MNIST_data', one_hot=True)
mb_size = 32                             # minibatch size for both D and G steps
z_dim = 10                               # dimensionality of the generator's noise input
X_dim = mnist.train.images.shape[1]      # flattened image size (samples are later reshaped to 28x28)
y_dim = mnist.train.labels.shape[1]      # NOTE(review): not used anywhere in this cell
h_dim = 128                              # hidden width shared by G and the autoencoder D_
cnt = 0                                  # index of the next sample image written to out/
d_step = 3                               # NOTE(review): unused -- D and G are each updated once per iteration
lr = 1e-3                                # Adam learning rate for both networks
m = 5                                    # EBGAN margin: D pushes fake energy up towards m


# Generator: z_dim noise -> h_dim hidden (ReLU) -> X_dim outputs squashed
# into (0, 1) by a sigmoid so they can be viewed as pixel intensities.
G = torch.nn.Sequential(*[
    torch.nn.Linear(z_dim, h_dim), torch.nn.ReLU(),
    torch.nn.Linear(h_dim, X_dim), torch.nn.Sigmoid(),
])

# D is an autoencoder: X_dim -> h_dim (ReLU) -> X_dim reconstruction.
# The last layer has no activation, so reconstructions are unbounded.
# Its reconstruction error is used as the energy (see D below).
D_ = torch.nn.Sequential(*[
    torch.nn.Linear(X_dim, h_dim), torch.nn.ReLU(),
    torch.nn.Linear(h_dim, X_dim),
])


# Energy is the MSE of autoencoder
def D(X):
    """Energy of a batch under the autoencoder D_.

    For each row of X, take the sum over features of the squared
    reconstruction error. Those per-sample energies are then averaged over
    the batch, so a single scalar is returned. This is not a per-pixel MSE:
    errors are summed over dim 1, not averaged.
    """
    residual = D_(X) - X
    per_sample_energy = residual.pow(2).sum(dim=1)
    return per_sample_energy.mean()


def reset_grad():
    """Zero the accumulated gradients of the generator and the autoencoder.

    Called after every optimizer step, because each backward pass populates
    gradients in both networks.
    """
    for net in (G, D_):
        net.zero_grad()


# One Adam optimizer per network, sharing the same learning rate.
G_solver, D_solver = (optim.Adam(net.parameters(), lr=lr) for net in (G, D_))


n_iter = 1000000   # effectively "train until interrupted"
log_every = 1000   # print losses and save a sample grid this often
out_dir = 'out/'


def _save_sample_grid(samples, path):
    """Save up to 16 flattened samples as a 4x4 grid of 28x28 grayscale images.

    samples: sequence of flattened images; each row is reshaped to (28, 28),
        so it must contain 784 values. Only the first 16 are drawn.
    path: destination file passed to ``Figure.savefig``.
    """
    fig = plt.figure(figsize=(4, 4))
    try:
        gs = gridspec.GridSpec(4, 4)
        gs.update(wspace=0.05, hspace=0.05)

        for i, sample in enumerate(samples[:16]):
            ax = fig.add_subplot(gs[i])
            ax.axis('off')
            ax.set_xticklabels([])
            ax.set_yticklabels([])
            ax.set_aspect('equal')
            ax.imshow(sample.reshape(28, 28), cmap='Greys_r')

        fig.savefig(path, bbox_inches='tight')
    finally:
        # Always release the figure; many figures are produced over training.
        plt.close(fig)


os.makedirs(out_dir, exist_ok=True)

for it in range(n_iter):
    # Sample a noise batch and a real data batch (labels are unused).
    z = Variable(torch.randn(mb_size, z_dim))
    X, _ = mnist.train.next_batch(mb_size)
    X = Variable(torch.from_numpy(X))

    # One generator forward pass serves both steps: G's parameters do not
    # change until G_solver.step(), so G(z) would be identical if recomputed.
    G_sample = G(z)

    # ---- Discriminator step ----
    # detach(): the D loss should only produce gradients for D_. Without it,
    # backward() also traverses G, and reset_grad() then throws that work away.
    D_real = D(X)
    D_fake = D(G_sample.detach())

    # EBGAN D loss: push real energy down and fake energy up to the margin m.
    # D_real and D_fake are scalars (batch-mean energies), so the hinge is
    # applied to the mean fake energy rather than to each sample.
    D_loss = D_real + nn.relu(m - D_fake)

    D_loss.backward()
    D_solver.step()
    reset_grad()

    # ---- Generator step ----
    # Re-score the (non-detached) samples with the freshly updated D_ so that
    # gradients flow into G; G is trained to lower the energy of its samples.
    G_loss = D(G_sample)

    G_loss.backward()
    G_solver.step()
    reset_grad()

    # Print and plot every now and then.
    if it % log_every == 0:
        print('Iter-{}; D_loss: {}; G_loss: {}'.format(it, D_loss.data.numpy(), G_loss.data.numpy()))

        # Samples come from the just-updated G on the current noise batch.
        samples = G(z).data.numpy()[:16]
        _save_sample_grid(samples, os.path.join(out_dir, '{}.png'.format(str(cnt).zfill(3))))
        cnt += 1

Extracting ../../MNIST_data\train-images-idx3-ubyte.gz
Extracting ../../MNIST_data\train-labels-idx1-ubyte.gz
Extracting ../../MNIST_data\t10k-images-idx3-ubyte.gz
Extracting ../../MNIST_data\t10k-labels-idx1-ubyte.gz
Iter-0; D_loss: 87.01187133789062; G_loss: 200.72572326660156
Iter-1000; D_loss: 9.593932151794434; G_loss: 4.792764186859131
Iter-2000; D_loss: 9.009757041931152; G_loss: 6.650607109069824
Iter-3000; D_loss: 9.635393142700195; G_loss: 5.074335098266602
Iter-4000; D_loss: 8.87974739074707; G_loss: 6.062475681304932
Iter-5000; D_loss: 8.188961029052734; G_loss: 5.408221244812012
Iter-6000; D_loss: 8.152165412902832; G_loss: 4.6884636878967285
Iter-7000; D_loss: 8.791720390319824; G_loss: 5.422101974487305
Iter-8000; D_loss: 8.620965957641602; G_loss: 5.733882427215576
Iter-9000; D_loss: 8.284038543701172; G_loss: 5.141046047210693
Iter-10000; D_loss: 7.521677494049072; G_loss: 4.86089563369751
Iter-11000; D_loss: 7.971253395080566; G_loss: 5.814963340759277
Iter-12000; D_l

Iter-123000; D_loss: 7.792487621307373; G_loss: 4.829008102416992
Iter-124000; D_loss: 7.800978183746338; G_loss: 5.7943644523620605
Iter-125000; D_loss: 8.84516429901123; G_loss: 5.818871021270752
Iter-126000; D_loss: 7.659824371337891; G_loss: 5.272951602935791
Iter-127000; D_loss: 8.15914535522461; G_loss: 5.27876615524292
Iter-128000; D_loss: 8.045501708984375; G_loss: 5.038029670715332
Iter-129000; D_loss: 7.89641809463501; G_loss: 4.36964750289917
Iter-130000; D_loss: 7.641613483428955; G_loss: 4.992758750915527
Iter-131000; D_loss: 7.728170871734619; G_loss: 5.320935249328613
Iter-132000; D_loss: 8.092809677124023; G_loss: 4.959162712097168
Iter-133000; D_loss: 7.117520332336426; G_loss: 5.339357852935791
Iter-134000; D_loss: 8.956867218017578; G_loss: 4.69159460067749
Iter-135000; D_loss: 8.063344955444336; G_loss: 5.167698383331299
Iter-136000; D_loss: 6.673223495483398; G_loss: 4.495510578155518
Iter-137000; D_loss: 7.207250595092773; G_loss: 5.5955810546875
Iter-138000; D_lo

Iter-248000; D_loss: 8.528234481811523; G_loss: 5.4738311767578125
Iter-249000; D_loss: 7.396669387817383; G_loss: 4.876051902770996
Iter-250000; D_loss: 6.875120639801025; G_loss: 4.931817054748535
Iter-251000; D_loss: 7.316567897796631; G_loss: 5.09701681137085
Iter-252000; D_loss: 8.660661697387695; G_loss: 5.339004993438721
Iter-253000; D_loss: 7.099618434906006; G_loss: 5.79254674911499
Iter-254000; D_loss: 7.975097179412842; G_loss: 4.763156890869141
Iter-255000; D_loss: 7.571772575378418; G_loss: 5.274129867553711
Iter-256000; D_loss: 8.11617660522461; G_loss: 4.8407816886901855
Iter-257000; D_loss: 7.786869525909424; G_loss: 5.0662336349487305
Iter-258000; D_loss: 7.949109077453613; G_loss: 4.218548774719238
Iter-259000; D_loss: 7.517273426055908; G_loss: 4.935765743255615
Iter-260000; D_loss: 7.841581344604492; G_loss: 5.0828447341918945
Iter-261000; D_loss: 7.562607765197754; G_loss: 4.837146282196045
Iter-262000; D_loss: 7.683197498321533; G_loss: 4.8740715980529785
Iter-263

Iter-373000; D_loss: 7.77850866317749; G_loss: 4.09946870803833
Iter-374000; D_loss: 7.667840480804443; G_loss: 4.73340368270874
Iter-375000; D_loss: 7.894171237945557; G_loss: 5.207218647003174
Iter-376000; D_loss: 7.654257774353027; G_loss: 5.080494403839111
Iter-377000; D_loss: 7.968121528625488; G_loss: 7.493031978607178
Iter-378000; D_loss: 7.731637001037598; G_loss: 4.977372169494629
Iter-379000; D_loss: 8.009050369262695; G_loss: 5.303049087524414
Iter-380000; D_loss: 7.6902947425842285; G_loss: 5.015714645385742
Iter-381000; D_loss: 9.086913108825684; G_loss: 4.407630443572998
Iter-382000; D_loss: 6.747997760772705; G_loss: 6.220824241638184
Iter-383000; D_loss: 7.788740634918213; G_loss: 5.405364990234375
Iter-384000; D_loss: 7.8143768310546875; G_loss: 5.38282585144043
Iter-385000; D_loss: 7.511729717254639; G_loss: 4.7615647315979
Iter-386000; D_loss: 7.90072774887085; G_loss: 5.769272804260254
Iter-387000; D_loss: 7.969110012054443; G_loss: 5.984823226928711
Iter-388000; D_

Iter-498000; D_loss: 7.298853397369385; G_loss: 4.716330051422119
Iter-499000; D_loss: 8.146215438842773; G_loss: 4.551661491394043
Iter-500000; D_loss: 5.87896728515625; G_loss: 5.499879360198975
Iter-501000; D_loss: 7.321591377258301; G_loss: 5.700683116912842
Iter-502000; D_loss: 8.454724311828613; G_loss: 6.387269973754883


KeyboardInterrupt: 