In [5]:

import torch
import torch.nn.functional as nn
import torch.autograd as autograd
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
from torch.autograd import Variable
import tensorflow as tf
from torch.utils.data import Dataset
from torchvision import datasets
#import tensorflow_datasets as tfds
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from tensorflow.keras.utils import to_categorical

In [36]:

# ---- Hyper-parameters -------------------------------------------------------
mb_size = 64       # mini-batch size handed to the DataLoader below
Z_dim = 100        # length of the noise vector fed to the generator
h_dim = 128        # width of the single hidden layer in each network
X_dim = 28 * 28    # a 28x28 MNIST image flattened into one vector
y_dim = 1          # width of the discriminator output
lr = 1e-3          # intended optimizer learning rate
c = 0              # running index used to name the saved sample images

# ---- Data -------------------------------------------------------------------
# MNIST training split, downloaded into ./data on first use and converted to
# tensors by ToTensor().
training_data = datasets.MNIST(root="data", train=True, download=True, transform=ToTensor())

# Shuffled mini-batches of `mb_size` images.
train_dataloader = DataLoader(dataset=training_data, batch_size=mb_size, shuffle=True)

In [37]:

# Load MNIST through Keras and keep the raw (train, test) tuple around as `data`.
data = tf.keras.datasets.mnist.load_data()
(x_train, y_train), (x_test, y_test) = data

# One-hot encode the integer class labels of both splits.
y_train, y_test = to_categorical(y_train), to_categorical(y_test)

# Wrap each split as a tf.data pipeline of (image, label) pairs.
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))

# Pull the first mini-batch of the training pipeline as NumPy arrays.
mb_size = 64
test = next(train_dataset.batch(mb_size).as_numpy_iterator())

# Remaining hyper-parameters for the networks defined further down.
Z_dim = 100        # noise-vector length
X_dim = 28 * 28    # flattened image size
y_dim = 1          # discriminator output width
h_dim = 128        # hidden-layer width
c = 0              # index of the next saved sample image
lr = 1e-3          # intended optimizer learning rate

print(X_dim, y_dim)

784 1


In [38]:
def xavier_init(size):
    """Create a trainable weight tensor drawn from a scaled normal distribution.

    Args:
        size: sequence of dimensions; ``size[0]`` is treated as the input
            (fan-in) dimension.

    Returns:
        A leaf tensor of shape ``size`` with ``requires_grad`` set to True,
        whose entries are standard-normal samples multiplied by
        ``1 / sqrt(size[0] / 2)``.
    """
    fan_in = size[0]
    scale = 1. / np.sqrt(fan_in / 2.)
    # Keep the tensor as the left operand so the product is a torch tensor.
    weights = torch.randn(*size) * scale
    weights.requires_grad = True
    return weights


""" ==================== GENERATOR ======================== """

# Noise -> hidden layer: (Z_dim, h_dim) weights and a zero-initialised bias,
# both tracked by autograd.
Wzh = xavier_init(size=[Z_dim, h_dim])
bzh = torch.zeros(h_dim, requires_grad=True)

# Hidden -> image layer: (h_dim, X_dim) weights and a zero-initialised bias.
Whx = xavier_init(size=[h_dim, X_dim])
bhx = torch.zeros(X_dim, requires_grad=True)

def G(z):
    """Generator: turn a batch of noise vectors into a batch of flat images.

    Args:
        z: 2-D tensor with one noise vector per row; its column count must
            match the first dimension of ``Wzh``.

    Returns:
        Tensor with one row per input row and as many columns as ``Whx``;
        every entry lies in (0, 1) because of the final sigmoid.
    """
    n_rows = z.size(0)
    # Hidden layer: affine map followed by ReLU (bias tiled once per row).
    hidden_pre = z @ Wzh + bzh.repeat(n_rows, 1)
    hidden = torch.relu(hidden_pre)
    # Output layer: affine map followed by sigmoid.
    out_pre = hidden @ Whx + bhx.repeat(hidden.size(0), 1)
    return torch.sigmoid(out_pre)


""" ==================== DISCRIMINATOR ======================== """

# Image -> hidden layer: (X_dim, h_dim) weights and a zero-initialised bias,
# both tracked by autograd.
Wxh = xavier_init(size=[X_dim, h_dim])
bxh = torch.zeros(h_dim, requires_grad=True)

# Hidden -> output layer: (h_dim, y_dim) weights and a zero-initialised bias.
Why = xavier_init(size=[h_dim, y_dim])
bhy = torch.zeros(y_dim, requires_grad=True)

def D(X):
    """Discriminator: score each row of ``X`` with a value in (0, 1).

    Args:
        X: 2-D tensor with one flattened sample per row; its column count
            must match the first dimension of ``Wxh``.

    Returns:
        Tensor with one row per input row and as many columns as ``Why``,
        squashed through a sigmoid.
    """
    n_rows = X.size(0)
    # Hidden layer: affine map followed by ReLU (bias tiled once per row).
    hidden_pre = X @ Wxh + bxh.repeat(n_rows, 1)
    hidden = torch.relu(hidden_pre)
    # Output layer: affine map followed by sigmoid.
    score_pre = hidden @ Why + bhy.repeat(hidden.size(0), 1)
    return torch.sigmoid(score_pre)


# Trainable tensors of each network, grouped so every optimizer only sees its
# own network's parameters.
G_params = [
    Wzh, bzh,   # noise -> hidden
    Whx, bhx,   # hidden -> image
]
D_params = [
    Wxh, bxh,   # image -> hidden
    Why, bhy,   # hidden -> score
]

# Every trainable tensor (generator first), used when clearing gradients.
params = [*G_params, *D_params]


""" ===================== TRAINING ======================== """


def reset_grad():
    """Replace every existing gradient in ``params`` with a fresh zero tensor.

    Parameters whose ``.grad`` is still ``None`` (never back-propagated) are
    left untouched. The previous gradient tensors are not modified in place;
    each ``.grad`` is rebound to a new all-zero tensor of the same shape.
    """
    for param in params:
        grad = param.grad
        if grad is None:
            continue
        param.grad = torch.zeros_like(grad)


# One Adam optimizer per network so the generator and the discriminator can be
# stepped independently. The step size comes from the `lr` hyper-parameter
# defined in the configuration cell instead of a duplicated literal, so
# changing `lr` there actually takes effect.
G_solver = optim.Adam(G_params, lr=lr)
D_solver = optim.Adam(D_params, lr=lr)

# Constant all-ones / all-zeros target tensors of shape (mb_size, y_dim).
ones_label = torch.ones(mb_size, y_dim)
zeros_label = torch.zeros(mb_size, y_dim)

In [39]:

# Train the GAN: each iteration performs one discriminator update followed by
# one generator update, and every 1000 iterations logs both losses and saves a
# 4x4 grid of generated digits to out/<index>.png.

# Keep a single DataLoader iterator alive and restart it only when an epoch is
# exhausted. Calling iter(train_dataloader) on every step would rebuild and
# reshuffle the loader each time and only ever consume its first batch.
batch_iter = iter(train_dataloader)

for it in range(100000):
    # Sample data
    try:
        X, _ = next(batch_iter)
    except StopIteration:
        batch_iter = iter(train_dataloader)
        X, _ = next(batch_iter)
    X = X.reshape(X.shape[0], X_dim)
    X = torch.as_tensor(X, dtype=torch.float)

    # The last batch of an epoch may hold fewer than mb_size images, so the
    # noise batch is sized to match the real batch.
    z = torch.randn(X.size(0), Z_dim)

    # Discriminator forward-loss-backward-update
    # The fake batch is detached: only the discriminator is stepped here, so
    # back-propagating into the generator would be wasted work.
    G_sample = G(z).detach()
    D_real = D(X)
    D_fake = D(G_sample)

    # `nn` is torch.nn.functional in this notebook. binary_cross_entropy
    # evaluates the same -mean(log(D_real)) - mean(log(1 - D_fake)) objective
    # as a hand-written log loss, but clamps the log terms so a saturated
    # discriminator output (exactly 0 or 1) cannot yield inf/NaN.
    D_loss_real = nn.binary_cross_entropy(D_real, torch.ones_like(D_real))
    D_loss_fake = nn.binary_cross_entropy(D_fake, torch.zeros_like(D_fake))
    D_loss = D_loss_real + D_loss_fake

    D_loss.backward()
    D_solver.step()

    # Housekeeping - reset gradient
    reset_grad()

    # Generator forward-loss-backward-update
    z = torch.randn(mb_size, Z_dim)
    G_sample = G(z)
    D_fake = D(G_sample)

    # Non-saturating generator loss, -mean(log(D_fake)), with clamped logs.
    G_loss = nn.binary_cross_entropy(D_fake, torch.ones_like(D_fake))

    G_loss.backward()
    G_solver.step()

    # Housekeeping - reset gradient (this also clears the discriminator
    # gradients accumulated while back-propagating the generator loss).
    reset_grad()

    # Print and plot every now and then
    if it % 1000 == 0:
        print('Iter-{}; D_loss: {}; G_loss: {}'.format(it, D_loss.item(), G_loss.item()))

        # Sampling for visualisation needs no autograd graph.
        with torch.no_grad():
            samples = G(z)[:16].numpy()

        fig = plt.figure(figsize=(4, 4))
        gs = gridspec.GridSpec(4, 4)
        gs.update(wspace=0.05, hspace=0.05)

        for i, sample in enumerate(samples):
            ax = plt.subplot(gs[i])
            plt.axis('off')
            ax.set_xticklabels([])
            ax.set_yticklabels([])
            ax.set_aspect('equal')
            plt.imshow(sample.reshape(28, 28), cmap='Greys_r')

        # exist_ok avoids the separate existence check before creating out/.
        os.makedirs('out/', exist_ok=True)

        plt.savefig('out/{}.png'.format(str(c).zfill(3)), bbox_inches='tight')
        c += 1
        # Close the figure so repeated snapshots do not accumulate in memory.
        plt.close(fig)

Iter-0; D_loss: 1.4279001951217651; G_loss: 2.307133197784424
Iter-1000; D_loss: 0.014277681708335876; G_loss: 9.329833984375
Iter-2000; D_loss: 0.005390492267906666; G_loss: 8.5023832321167
Iter-3000; D_loss: 0.06562675535678864; G_loss: 5.705738544464111
Iter-4000; D_loss: 0.06101506948471069; G_loss: 7.2230119705200195
Iter-5000; D_loss: 0.4056508541107178; G_loss: 4.784977912902832
Iter-6000; D_loss: 0.26737695932388306; G_loss: 4.271829605102539
Iter-7000; D_loss: 0.35483354330062866; G_loss: 3.3798208236694336
Iter-8000; D_loss: 0.45045316219329834; G_loss: 3.136461019515991
Iter-9000; D_loss: 0.6955567598342896; G_loss: 2.6970341205596924
Iter-10000; D_loss: 0.5434374213218689; G_loss: 2.5059731006622314
Iter-11000; D_loss: 0.7650926113128662; G_loss: 3.5989723205566406
Iter-12000; D_loss: 0.881645917892456; G_loss: 2.8203768730163574
Iter-13000; D_loss: 0.7671983242034912; G_loss: 2.800879955291748
Iter-14000; D_loss: 0.9357052445411682; G_loss: 2.273482322692871
Iter-15000; D_