In [None]:
!pip install -q torch numpy matplotlib tensorflow

In [2]:
import torch
import torch.autograd as autograd
import torch.nn.functional as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
import numpy as np
from torch.autograd import Variable
import tensorflow as tf
#from torchvision import datasets, transforms
#from tensorflow.examples.tutorials.mnist import input_data

In [3]:
# Check for GPU and set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Seed both RNGs used by this notebook (numpy picks the mini-batch indices,
# torch draws the initial weights and the latent noise) so runs are repeatable.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Hyper-parameters
mb_size = 64      # mini-batch size
Z_dim = 100       # dimensionality of the latent noise vector
X_dim = 28*28     # flattened image size
h_dim = 128       # hidden-layer width of both networks
c = 0             # running index used to name the saved sample grids
lr = 1e-3

# Load MNIST dataset (labels are discarded: the GAN is trained unsupervised)
mnist_data = tf.keras.datasets.mnist
(X_train, _), (X_test, _) = mnist_data.load_data()

# Flatten every image to a vector of length X_dim and divide by 255.
# Casting to float32 first halves the memory of the arrays and matches the
# float32 tensors the training loop builds from them.
# NOTE(review): assumes load_data() returns 28x28 images with values in 0..255 — confirm.
train_loader = X_train.reshape(-1, X_dim).astype(np.float32) / 255.0
test_loader = X_test.reshape(-1, X_dim).astype(np.float32) / 255.0




cpu
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


In [4]:
def xavier_init(size):
    """Return a trainable weight tensor drawn from a scaled normal distribution.

    Args:
        size: sequence of dimensions for the tensor; ``size[0]`` is taken as
            the input dimension that determines the scale.

    Returns:
        A ``torch.nn.Parameter`` of shape ``size`` on the module-level
        ``device``, sampled from N(0, 1) and multiplied by
        ``1 / sqrt(size[0] / 2)``.
    """
    fan_in = size[0]
    half_fan_in = torch.tensor(fan_in / 2., device=device)
    stddev = 1. / half_fan_in.sqrt()
    weights = torch.randn(*size, device=device) * stddev
    return torch.nn.Parameter(weights)

In [8]:
""" ==================== GENERATOR ======================== """

# Generator parameters: latent vector -> hidden layer -> flattened image.
# The biases are created on the same device as the weights (xavier_init places
# the weights on `device`) so that the forward pass never mixes devices.
Wzh = xavier_init(size=[Z_dim, h_dim])
bzh = torch.zeros(h_dim, device=device, requires_grad=True)
Whx = xavier_init(size=[h_dim, X_dim])
bhx = torch.zeros(X_dim, device=device, requires_grad=True)

def G(z):
    """Map latent noise to generated samples.

    Args:
        z: tensor whose last dimension is ``Z_dim`` (a batch of noise vectors).

    Returns:
        Tensor whose last dimension is ``X_dim`` with values in (0, 1)
        (sigmoid output).
    """
    # The bias vectors broadcast over the batch dimension, so no explicit
    # per-row copy of the bias is required.
    h = nn.relu(z @ Wzh + bzh)
    X = torch.sigmoid(h @ Whx + bhx)
    return X

In [9]:
""" ==================== DISCRIMINATOR ======================== """

# Discriminator parameters: flattened image -> hidden layer -> single score.
# The biases are created on the same device as the weights (xavier_init places
# the weights on `device`) so that the forward pass never mixes devices.
Wxh = xavier_init(size=[X_dim, h_dim])
bxh = torch.zeros(h_dim, device=device, requires_grad=True)
Why = xavier_init(size=[h_dim, 1])
bhy = torch.zeros(1, device=device, requires_grad=True)

def D(X):
    """Score samples with the discriminator.

    Args:
        X: tensor whose last dimension is ``X_dim`` (a batch of flattened images).

    Returns:
        Tensor with a trailing dimension of 1 holding a sigmoid output in
        (0, 1) for every input row.
    """
    # The bias vectors broadcast over the batch dimension, so no explicit
    # per-row copy of the bias is required.
    h = nn.relu(X @ Wxh + bxh)
    y = torch.sigmoid(h @ Why + bhy)
    return y


# Parameter groups: one list per network (handed to the optimisers) and the
# combined list used when clearing gradients.
G_params = [Wzh, bzh, Whx, bhx]
D_params = [Wxh, bxh, Why, bhy]
params = G_params + D_params



In [14]:
""" ===================== 2.2 TRAINING logistic loss 20K/100K======================== """
def reset_grad(parameters=None):
    """Replace every existing gradient with a fresh all-zero tensor.

    Args:
        parameters: optional iterable of tensors whose ``.grad`` should be
            cleared. When omitted, the module-level ``params`` list is used.

    Tensors whose ``.grad`` is ``None`` (no backward pass has reached them yet)
    are left untouched.
    """
    if parameters is None:
        parameters = params
    for p in parameters:
        if p.grad is not None:
            # zeros_like keeps the shape, dtype and device of the old gradient.
            p.grad = torch.zeros_like(p.grad)


# One Adam optimiser per network so that each step only updates its own weights.
G_solver = optim.Adam(G_params, lr=1e-3)
D_solver = optim.Adam(D_params, lr=1e-3)

# Targets for the binary cross-entropy: 1 for "real", 0 for "fake".
# They live on `device`, like every other tensor built in this cell.
ones_label = torch.ones(mb_size, 1, device=device)
zeros_label = torch.zeros(mb_size, 1, device=device)

N_ITERS = 100000     # total number of training iterations
LOG_EVERY = 1000     # print losses and save a sample grid every LOG_EVERY iterations
N_PREVIEW = 16       # number of generated samples shown in the 4x4 grid

# Create the output directory once, before the loop.
os.makedirs('out/', exist_ok=True)

for it in range(N_ITERS):
    # Sample data: latent noise plus a random mini-batch of training images
    z = torch.randn(mb_size, Z_dim, device=device)
    idx = np.random.randint(0, train_loader.shape[0], mb_size)
    X = torch.from_numpy(train_loader[idx]).float().to(device)

    # Discriminator forward-loss-backward-update
    G_sample = G(z)
    D_real = D(X)
    # detach(): the discriminator step does not need gradients w.r.t. the
    # generator weights, so do not backpropagate through G here.
    D_fake = D(G_sample.detach())

    D_loss_real = nn.binary_cross_entropy(D_real, ones_label)
    D_loss_fake = nn.binary_cross_entropy(D_fake, zeros_label)
    D_loss = D_loss_real + D_loss_fake

    D_loss.backward()
    D_solver.step()

    # Housekeeping - reset gradient
    reset_grad()

    # Generator forward-loss-backward-update (fresh noise; the generator is
    # trained to make the discriminator output "real" for its samples)
    z = torch.randn(mb_size, Z_dim, device=device)
    G_sample = G(z)
    D_fake = D(G_sample)

    G_loss = nn.binary_cross_entropy(D_fake, ones_label)

    G_loss.backward()
    G_solver.step()

    # Housekeeping - reset gradient
    reset_grad()

    # Print and plot every now and then
    if it % LOG_EVERY == 0:
        # .item() yields a plain Python float and also works for CUDA tensors.
        print('Iter-{}; D_loss: {}; G_loss: {}'.format(it, D_loss.item(), G_loss.item()))

        # No autograd graph is needed for preview samples; move them to the
        # CPU before converting to numpy for matplotlib.
        with torch.no_grad():
            samples = G(z)[:N_PREVIEW].cpu().numpy()

        fig = plt.figure(figsize=(4, 4))
        gs = gridspec.GridSpec(4, 4)
        gs.update(wspace=0.05, hspace=0.05)

        for i, sample in enumerate(samples):
            ax = fig.add_subplot(gs[i])
            ax.axis('off')
            ax.set_xticklabels([])
            ax.set_yticklabels([])
            ax.set_aspect('equal')
            ax.imshow(sample.reshape(28, 28), cmap='Greys_r')

        # `c` is the notebook-level counter that numbers the saved grids.
        fig.savefig('out/{}.png'.format(str(c).zfill(3)), bbox_inches='tight')
        c += 1
        plt.close(fig)

Iter-0; D_loss: 0.614266574382782; G_loss: 2.2302069664001465
Iter-1000; D_loss: 0.7152239680290222; G_loss: 2.1338486671447754
Iter-2000; D_loss: 0.5606213808059692; G_loss: 2.1985132694244385
Iter-3000; D_loss: 0.6654432415962219; G_loss: 2.538466691970825
Iter-4000; D_loss: 0.39272552728652954; G_loss: 2.7692525386810303
Iter-5000; D_loss: 0.5662136673927307; G_loss: 2.278355121612549
Iter-6000; D_loss: 0.42721202969551086; G_loss: 2.560878038406372
Iter-7000; D_loss: 0.5624514818191528; G_loss: 2.361584424972534
Iter-8000; D_loss: 0.47946757078170776; G_loss: 2.465975046157837
Iter-9000; D_loss: 0.6758779883384705; G_loss: 2.2267396450042725
Iter-10000; D_loss: 0.6571389436721802; G_loss: 2.432741165161133
Iter-11000; D_loss: 0.5256248712539673; G_loss: 2.5864264965057373
Iter-12000; D_loss: 0.48997095227241516; G_loss: 2.630333185195923
Iter-13000; D_loss: 0.446507066488266; G_loss: 2.7108314037323
Iter-14000; D_loss: 0.4788537323474884; G_loss: 2.591122627258301
Iter-15000; D_los