In [19]:
import torch
import torch.autograd as autograd
import torch.nn.functional as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
import numpy as np
from torch.autograd import Variable


import tensorflow as tf
from torchvision import datasets, transforms

In [6]:
# Pick the GPU when one is available, otherwise fall back to the CPU.
# NOTE(review): `device` is not referenced by any cell visible in this notebook;
# every tensor created below lives on the CPU. Confirm whether that is intended.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load MNIST through the Keras dataset helper.
mnist_data = tf.keras.datasets.mnist
(X_train, y_train), (X_test, y_test) = mnist_data.load_data()

# Preprocess the data: flatten every image to a 784-vector and divide by 255.
# Cast to float32 *before* dividing: dividing the integer pixel arrays by a
# Python float would promote the whole dataset to float64 (twice the memory),
# even though each minibatch is converted to float32 again in the training loop.
X_train = X_train.reshape(-1, 28 * 28).astype(np.float32) / 255.0
X_test = X_test.reshape(-1, 28 * 28).astype(np.float32) / 255.0

# Convert class vectors to binary class matrices (one-hot encoding)
num_classes = 10 # There are 10 classes in MNIST
y_train_one_hot = tf.keras.utils.to_categorical(y_train, num_classes)
y_test_one_hot = tf.keras.utils.to_categorical(y_test, num_classes)

mb_size = 64            # minibatch size used for every training step
Z_dim = 100             # width of the generator's noise input
X_dim = 28*28           # width of a flattened image
y_dim = num_classes     # width of the one-hot conditioning vector
h_dim = 128             # hidden-layer width of both networks
cnt = 0                 # running index used to name the saved sample images
lr = 1e-3               # learning rate

In [20]:
def xavier_init(size):
    """Create a trainable weight tensor drawn from a scaled normal distribution.

    Values are sampled from a standard normal and multiplied by
    ``1 / sqrt(size[0] / 2)``, so the standard deviation depends only on the
    first entry of ``size`` (the fan-in of the layer).

    Args:
        size: Sequence of ints giving the shape of the tensor; ``size[0]`` is
            used as the fan-in.

    Returns:
        A leaf ``torch.Tensor`` of shape ``size`` with ``requires_grad=True``.
    """
    in_dim = size[0]
    xavier_stddev = 1. / np.sqrt(in_dim / 2.)
    # requires_grad_() turns the freshly sampled tensor into a trainable leaf;
    # this replaces the deprecated torch.autograd.Variable wrapper.
    return (torch.randn(*size) * xavier_stddev).requires_grad_()

In [21]:
""" ==================== GENERATOR ======================== """

# First generator layer: takes the noise vector concatenated with the one-hot
# label (hence Z_dim + y_dim rows) and maps it to the hidden layer.
Wzh = xavier_init(size=[Z_dim + y_dim, h_dim])
# Biases start at zero; requires_grad=True makes them trainable leaves without
# the deprecated Variable wrapper.
bzh = torch.zeros(h_dim, requires_grad=True)

# Second generator layer: hidden layer -> flattened image of width X_dim.
Whx = xavier_init(size=[h_dim, X_dim])
bhx = torch.zeros(X_dim, requires_grad=True)


def G(z, c):
    """Generator forward pass, conditioned on ``c``.

    Concatenates ``z`` and ``c`` along dim 1, applies a ReLU hidden layer
    (``Wzh``, ``bzh``) and a sigmoid output layer (``Whx``, ``bhx``). The
    weights are the module-level tensors of the same names.

    Args:
        z: 2-D tensor of noise vectors, one row per sample.
        c: 2-D tensor of conditioning vectors with the same number of rows as
            ``z``; the combined width must match the first dimension of ``Wzh``.

    Returns:
        Tensor with one row per sample and as many columns as ``Whx``; every
        value lies in (0, 1) because of the final sigmoid.
    """
    inputs = torch.cat([z, c], 1)
    # The 1-D biases broadcast across the batch dimension, so there is no need
    # to materialise a batch-sized copy with .repeat() on every call.
    h = torch.relu(inputs @ Wzh + bzh)
    X = torch.sigmoid(h @ Whx + bhx)
    return X


""" ==================== DISCRIMINATOR ======================== """

# First discriminator layer: takes a flattened image concatenated with the
# one-hot label (hence X_dim + y_dim rows) and maps it to the hidden layer.
Wxh = xavier_init(size=[X_dim + y_dim, h_dim])
# Biases start at zero; requires_grad=True makes them trainable leaves without
# the deprecated Variable wrapper.
bxh = torch.zeros(h_dim, requires_grad=True)

# Second discriminator layer: hidden layer -> single output unit.
Why = xavier_init(size=[h_dim, 1])
bhy = torch.zeros(1, requires_grad=True)


def D(X, c):
    """Discriminator forward pass, conditioned on ``c``.

    Concatenates ``X`` and ``c`` along dim 1, applies a ReLU hidden layer
    (``Wxh``, ``bxh``) and a sigmoid output layer (``Why``, ``bhy``). The
    weights are the module-level tensors of the same names.

    Args:
        X: 2-D tensor of flattened samples, one row per sample.
        c: 2-D tensor of conditioning vectors with the same number of rows as
            ``X``; the combined width must match the first dimension of ``Wxh``.

    Returns:
        Tensor with one row per sample and as many columns as ``Why``; every
        value lies in (0, 1) because of the final sigmoid.
    """
    inputs = torch.cat([X, c], 1)
    # The 1-D biases broadcast across the batch dimension, so there is no need
    # to materialise a batch-sized copy with .repeat() on every call.
    h = torch.relu(inputs @ Wxh + bxh)
    y = torch.sigmoid(h @ Why + bhy)
    return y


# Trainable tensors grouped per network, plus the combined list of all of them.
G_params = [Wzh, bzh, Whx, bhx]
D_params = [Wxh, bxh, Why, bhy]
params = [*G_params, *D_params]




In [23]:
""" ===================== TRAINING ======================== """


def reset_grad(parameters=None):
    """Zero the accumulated gradients of a set of tensors, in place.

    Tensors whose ``.grad`` is still ``None`` (no backward pass has reached
    them yet) are left untouched.

    Args:
        parameters: Optional iterable of tensors to clear. When omitted, the
            module-level ``params`` list is used, which keeps existing
            ``reset_grad()`` calls working unchanged.
    """
    if parameters is None:
        parameters = params
    for p in parameters:
        if p.grad is not None:
            # Zero the existing buffer instead of allocating a replacement
            # tensor on every call; detach_() first so the buffer is not tied
            # to any autograd graph.
            p.grad.detach_()
            p.grad.zero_()


# One Adam optimizer per network so that each update step only touches that
# network's parameters. Both read the shared `lr` hyperparameter instead of
# repeating the literal value.
G_solver = optim.Adam(G_params, lr=lr)
D_solver = optim.Adam(D_params, lr=lr)

# Constant binary-cross-entropy targets, one row per minibatch sample:
# ones for "real", zeros for "fake".
ones_label = torch.ones(mb_size, 1)
zeros_label = torch.zeros(mb_size, 1)


num_iters = 100000   # total number of training iterations
log_every = 1000     # print losses and save a sample grid every this many iterations

# The output directory never changes, so create it once up front instead of
# checking for it on every logging step.
os.makedirs('out/', exist_ok=True)

for it in range(num_iters):
    # Sample data: generator noise plus a random minibatch (drawn with
    # replacement) of images and their one-hot labels.
    z = torch.randn(mb_size, Z_dim)
    idx = np.random.randint(0, X_train.shape[0], mb_size)
    X = torch.from_numpy(X_train[idx]).float()
    c = torch.from_numpy(y_train_one_hot[idx]).float()

    # Discriminator forward-loss-backward-update
    G_sample = G(z, c)
    D_real = D(X, c)
    # detach(): the discriminator step must not back-propagate into the
    # generator; those gradients would be thrown away by reset_grad() anyway.
    D_fake = D(G_sample.detach(), c)

    D_loss_real = nn.binary_cross_entropy(D_real, ones_label)
    D_loss_fake = nn.binary_cross_entropy(D_fake, zeros_label)
    D_loss = D_loss_real + D_loss_fake

    D_loss.backward()
    D_solver.step()

    # Housekeeping - reset gradient
    reset_grad()

    # Generator forward-loss-backward-update (fresh noise, same labels).
    z = torch.randn(mb_size, Z_dim)
    G_sample = G(z, c)
    D_fake = D(G_sample, c)

    # The generator is rewarded when the discriminator labels its samples "real".
    G_loss = nn.binary_cross_entropy(D_fake, ones_label)

    G_loss.backward()
    G_solver.step()

    # Housekeeping - reset gradient (this backward pass also filled the
    # discriminator's gradients, which must not leak into its next step).
    reset_grad()

    # Print and plot every now and then
    if it % log_every == 0:
        print('Iter-{}; D_loss: {}; G_loss: {}'.format(it, D_loss.item(), G_loss.item()))

        # Condition every preview sample on one randomly chosen class. A
        # separate name keeps the minibatch labels in `c` intact.
        sample_c = np.zeros(shape=[mb_size, y_dim], dtype='float32')
        sample_c[:, np.random.randint(0, y_dim)] = 1.
        # No gradients are needed for preview images.
        with torch.no_grad():
            samples = G(z, torch.from_numpy(sample_c)).numpy()[:16]

        fig = plt.figure(figsize=(4, 4))
        gs = gridspec.GridSpec(4, 4)
        gs.update(wspace=0.05, hspace=0.05)

        for i, sample in enumerate(samples):
            ax = fig.add_subplot(gs[i])
            ax.axis('off')
            ax.set_xticklabels([])
            ax.set_yticklabels([])
            ax.set_aspect('equal')
            ax.imshow(sample.reshape(28, 28), cmap='Greys_r')

        fig.savefig('out/{}.png'.format(str(cnt).zfill(3)), bbox_inches='tight')
        cnt += 1
        # Close the figure so figures do not pile up in memory across log steps.
        plt.close(fig)

Iter-0; D_loss: 0.005447774659842253; G_loss: 6.082343578338623
Iter-1000; D_loss: 0.0003743061679415405; G_loss: 10.756275177001953
Iter-2000; D_loss: 0.0024052676744759083; G_loss: 9.610559463500977
Iter-3000; D_loss: 0.001010522129945457; G_loss: 9.340376853942871
Iter-4000; D_loss: 0.009231614880263805; G_loss: 10.579803466796875
Iter-5000; D_loss: 0.11352275311946869; G_loss: 8.274649620056152
Iter-6000; D_loss: 0.01767956092953682; G_loss: 7.710787773132324
Iter-7000; D_loss: 0.34597349166870117; G_loss: 5.000637054443359
Iter-8000; D_loss: 0.3477376699447632; G_loss: 4.466437339782715
Iter-9000; D_loss: 0.29114729166030884; G_loss: 4.2505083084106445
Iter-10000; D_loss: 0.7813863754272461; G_loss: 3.4069857597351074
Iter-11000; D_loss: 0.5685049295425415; G_loss: 3.1460330486297607
Iter-12000; D_loss: 0.8054193258285522; G_loss: 2.4995107650756836
Iter-13000; D_loss: 1.0170000791549683; G_loss: 2.715989351272583
Iter-14000; D_loss: 0.8165727853775024; G_loss: 2.09181809425354
It