In [1]:
import torch
import torch.nn
import torch.nn.functional as nn
import torch.autograd as autograd
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
from torch.autograd import Variable
from tensorflow.examples.tutorials.mnist import input_data


# Random seed for a reproducible run (weight init, latent noise, batch order).
SEED = 0

# MNIST through the TensorFlow 1.x tutorial loader (deprecated upstream, see the
# "Instructions for updating" log in this cell's output). Labels are one-hot.
mnist = input_data.read_data_sets('../../MNIST_data', one_hot=True)

# Seeded *after* loading. NOTE(review): the TF loader may reseed NumPy's global
# RNG while building its datasets, so seeding before it might not stick; confirm
# for the installed TF version.
np.random.seed(SEED)     # presumably drives minibatch shuffling in next_batch
torch.manual_seed(SEED)  # nn.Linear weight init and torch.randn latent noise

mb_size = 32                           # minibatch size (critic and generator)
z_dim = 10                             # latent noise dimension fed to G
X_dim = mnist.train.images.shape[1]    # flattened image length (input of D, output of G)
y_dim = mnist.train.labels.shape[1]    # number of one-hot classes; not used in this cell
h_dim = 128                            # hidden width of both G and D
cnt = 0                                # running index of saved sample-grid images
lr = 1e-4                              # RMSprop learning rate for both networks


def _mlp(in_features, out_features, *head):
    """Build Linear(in -> h_dim) -> ReLU -> Linear(h_dim -> out), then any `head` modules.

    Args:
        in_features: input width.
        out_features: output width.
        *head: extra modules appended after the last Linear (e.g. an activation).

    Returns:
        torch.nn.Sequential with layers indexed 0, 1, 2, then the head modules.
    """
    return torch.nn.Sequential(
        torch.nn.Linear(in_features, h_dim),
        torch.nn.ReLU(),
        torch.nn.Linear(h_dim, out_features),
        *head
    )


# Generator: latent vector (z_dim) -> X_dim values squashed into (0, 1) by Sigmoid.
G = _mlp(z_dim, X_dim, torch.nn.Sigmoid())

# Critic: X_dim input -> one unbounded score (no output activation).
D = _mlp(X_dim, 1)


def reset_grad():
    """Clear the accumulated gradients of the generator G and then the critic D."""
    for network in (G, D):
        network.zero_grad()


# One RMSprop optimizer per network, both at the same learning rate `lr`
# (generator first, then critic).
G_solver, D_solver = (optim.RMSprop(net.parameters(), lr=lr) for net in (G, D))


# Training hyperparameters (same values the loop previously hard-coded).
N_ITERS = 1000000   # generator updates; this cell's output shows the run stopped by hand
N_CRITIC = 5        # critic updates per generator update
CLIP_VALUE = 0.01   # bound for critic weight clipping
LOG_EVERY = 2000    # iterations between a log line + saved sample grid
OUT_DIR = 'out'     # directory for sample-grid PNGs


def _save_sample_grid(samples, path):
    """Save up to 16 flattened 28x28 images as a 4x4 grayscale grid.

    Args:
        samples: sequence of at most 16 arrays, each reshapeable to (28, 28).
        path: destination image file.
    """
    fig = plt.figure(figsize=(4, 4))
    try:
        gs = gridspec.GridSpec(4, 4, wspace=0.05, hspace=0.05)
        for i, sample in enumerate(samples):
            ax = fig.add_subplot(gs[i])
            ax.axis('off')
            ax.set_xticklabels([])
            ax.set_yticklabels([])
            ax.set_aspect('equal')
            ax.imshow(sample.reshape(28, 28), cmap='Greys_r')
        fig.savefig(path, bbox_inches='tight')
    finally:
        # Always release the figure; a long run would otherwise pile up open figures.
        plt.close(fig)


for it in range(N_ITERS):
    # --- Critic: N_CRITIC updates per generator update ---
    for _ in range(N_CRITIC):
        z = torch.randn(mb_size, z_dim)
        X_batch, _labels = mnist.train.next_batch(mb_size)
        X = torch.from_numpy(X_batch)

        # The critic step never updates G, so don't build G's autograd graph
        # (its gradients were previously computed and then thrown away).
        with torch.no_grad():
            G_sample = G(z)
        D_real = D(X)
        D_fake = D(G_sample)

        # Maximise mean D(real) - mean D(fake) by minimising its negation.
        D_loss = -(torch.mean(D_real) - torch.mean(D_fake))

        D_loss.backward()
        D_solver.step()

        # Weight clipping keeps every critic parameter in [-CLIP_VALUE, CLIP_VALUE].
        with torch.no_grad():
            for p in D.parameters():
                p.clamp_(-CLIP_VALUE, CLIP_VALUE)

        reset_grad()

    # --- Generator: one update; it only needs fresh noise, not a real batch ---
    z = torch.randn(mb_size, z_dim)
    G_sample = G(z)
    D_fake = D(G_sample)

    G_loss = -torch.mean(D_fake)

    G_loss.backward()
    G_solver.step()

    # G_loss.backward() also populated D's gradients; clear both networks.
    reset_grad()

    # --- Periodic logging and a 4x4 grid of samples from the updated G ---
    if it % LOG_EVERY == 0:
        print('Iter-{}; D_loss: {}; G_loss: {}'
              .format(it, D_loss.item(), G_loss.item()))

        with torch.no_grad():
            samples = G(z).numpy()[:16]

        os.makedirs(OUT_DIR, exist_ok=True)
        _save_sample_grid(
            samples, os.path.join(OUT_DIR, '{}.png'.format(str(cnt).zfill(3))))
        cnt += 1

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
Instructions for updating:
Please write your own downloading logic.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting ../../MNIST_data\train-images-idx3-ubyte.gz
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting ../../MNIST_data\train-labels-idx1-ubyte.gz
Instructions for updating:
Please use tf.one_hot on tensors.
Extracting ../../MNIST_data\t10k-images-idx3-ubyte.gz
Extracting ../../MNIST_data\t10k-labels-idx1-ubyte.gz
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
Iter-0; D_loss: -0.15103307366371155; G_loss: 0.18560005724430084
Iter-2000; D_loss: -0.007170230150222778; G_loss: 0.0647735446691513
Iter-4000; D_loss: -0.0569344162940979; G_loss: -0.026285730302333832
Iter-6000; D_loss: -0.03135409206151962; G_loss: -0.028608782216906548
I

Iter-166000; D_loss: -0.007062479853630066; G_loss: -0.011770438402891159
Iter-168000; D_loss: -0.008834542706608772; G_loss: -0.021600160747766495
Iter-170000; D_loss: -0.0069620730355381966; G_loss: -0.009766893461346626
Iter-172000; D_loss: -0.005577122792601585; G_loss: -0.014154132455587387
Iter-174000; D_loss: -0.008548211306333542; G_loss: -0.02115953341126442
Iter-176000; D_loss: -0.00440484844148159; G_loss: -0.014142677187919617
Iter-178000; D_loss: -0.008657854050397873; G_loss: -0.016980934888124466
Iter-180000; D_loss: -0.006986429914832115; G_loss: -0.01571401208639145
Iter-182000; D_loss: -0.005954526364803314; G_loss: -0.014795857481658459
Iter-184000; D_loss: -0.007927127182483673; G_loss: -0.020439980551600456
Iter-186000; D_loss: -0.006919161416590214; G_loss: -0.01105736568570137
Iter-188000; D_loss: -0.00699373334646225; G_loss: -0.01377302035689354
Iter-190000; D_loss: -0.006946500390768051; G_loss: -0.011396193876862526
Iter-192000; D_loss: -0.005829913541674614;

Iter-388000; D_loss: -0.004220424219965935; G_loss: -0.011734329164028168
Iter-390000; D_loss: -0.005558038130402565; G_loss: -0.011463439092040062
Iter-392000; D_loss: -0.0051558250561356544; G_loss: -0.01378764770925045
Iter-394000; D_loss: -0.0030943695455789566; G_loss: -0.009213672019541264
Iter-396000; D_loss: -0.002217203378677368; G_loss: -0.021612118929624557
Iter-398000; D_loss: -0.008502266369760036; G_loss: -0.015876300632953644
Iter-400000; D_loss: -0.004081017337739468; G_loss: -0.014284889213740826
Iter-402000; D_loss: -0.003024960868060589; G_loss: -0.012510934844613075
Iter-404000; D_loss: -0.0020343028008937836; G_loss: -0.008617954328656197
Iter-406000; D_loss: -0.0021936800330877304; G_loss: -0.01839528977870941
Iter-408000; D_loss: -0.0031514307484030724; G_loss: -0.010939806699752808
Iter-410000; D_loss: -0.0061295367777347565; G_loss: -0.018334753811359406
Iter-412000; D_loss: -0.0073374006897211075; G_loss: -0.009762465953826904
Iter-414000; D_loss: -0.005367908

KeyboardInterrupt: 