In [None]:
# Colab-only line magic: switch the runtime to TensorFlow 1.x. The MNIST loader
# imported in the next cell (`tensorflow.examples.tutorials.mnist`) is
# presumably unavailable under TF 2.x, hence this pin — verify if the runtime changes.
%tensorflow_version 1.x

TensorFlow 1.x selected.


In [None]:
import tensorflow as tf
import torch
import torch.nn.functional as nn
import torch.autograd as autograd
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
from torch.autograd import Variable
from tensorflow.examples.tutorials.mnist import input_data


# Seed every RNG the notebook draws from (numpy — presumably what the MNIST
# loader's batch shuffling uses — and torch for weight init and noise) so a
# Restart & Run All reproduces the same run.
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)

mnist = input_data.read_data_sets('../../MNIST_data', one_hot=True)
mb_size = 64                          # minibatch size
Z_dim = 100                           # generator noise dimension
X_dim = mnist.train.images.shape[1]   # flattened image size (reshaped to 28x28 for plots)
y_dim = mnist.train.labels.shape[1]   # one-hot label width (unused: G and D take no labels)
h_dim = 128                           # hidden-layer width of both G and D
c = 0                                 # index used to number saved sample images (out/000.png, ...)
lr = 1e-3                             # Adam learning rate
num_iterations = 100000


def xavier_init(size):
    """Return a trainable weight tensor of shape ``size`` with Gaussian init.

    Entries are standard-normal draws scaled by sqrt(2 / fan_in), where
    fan_in = size[0]. Note: despite the name, this is the He/Kaiming-normal
    scale, not Glorot/Xavier's sqrt(2 / (fan_in + fan_out)).
    """
    fan_in = size[0]
    std = 1. / np.sqrt(fan_in / 2.)
    weights = torch.randn(*size) * std
    return Variable(weights, requires_grad=True)


""" ==================== GENERATOR ======================== """

# Generator weights for a two-layer MLP: noise (Z_dim) -> hidden (h_dim) -> image (X_dim).
# Biases start at zero; the weight matrices are drawn in this order (Wzh, then Whx).
Wzh = xavier_init(size=[Z_dim, h_dim])
bzh = torch.zeros(h_dim, requires_grad=True)

Whx = xavier_init(size=[h_dim, X_dim])
bhx = torch.zeros(X_dim, requires_grad=True)


def G(z):
    """Generator: map a noise batch to a batch of flattened images.

    Args:
        z: noise batch; the first matmul requires shape (batch, Z_dim).

    Returns:
        Tensor of shape (batch, X_dim) of sigmoid outputs in [0, 1].
    """
    # Bias vectors broadcast across the batch dimension, so there is no need
    # to materialise a (batch, dim) copy with .repeat() on every call.
    h = nn.relu(z @ Wzh + bzh)
    X = torch.sigmoid(h @ Whx + bhx)
    return X


""" ==================== DISCRIMINATOR ======================== """

# Discriminator weights for a two-layer MLP: image (X_dim) -> hidden (h_dim) -> one score.
# Biases start at zero; the weight matrices are drawn in this order (Wxh, then Why).
Wxh = xavier_init(size=[X_dim, h_dim])
bxh = torch.zeros(h_dim, requires_grad=True)

Why = xavier_init(size=[h_dim, 1])
bhy = torch.zeros(1, requires_grad=True)


def D(X):
    """Discriminator: score each row of ``X`` with a sigmoid output.

    The training losses treat this output as the probability that a row is a
    real (rather than generated) image.

    Args:
        X: image batch; the first matmul requires shape (batch, X_dim).

    Returns:
        Tensor of shape (batch, 1) of sigmoid outputs in [0, 1].
    """
    # Bias vectors broadcast across the batch dimension, so there is no need
    # to materialise a (batch, dim) copy with .repeat() on every call.
    h = nn.relu(X @ Wxh + bxh)
    y = torch.sigmoid(h @ Why + bhy)
    return y


# Per-network parameter groups (one optimiser each); `params` is their union,
# which is what reset_grad() iterates over.
G_params = [Wzh, bzh, Whx, bhx]
D_params = [Wxh, bxh, Why, bhy]
params = [*G_params, *D_params]


""" ===================== TRAINING ======================== """


def reset_grad(parameters=None):
    """Zero the accumulated gradients of ``parameters`` in place.

    Args:
        parameters: iterable of tensors whose ``.grad`` should be cleared.
            Defaults to the module-level ``params`` list (all G and D weights).

    Tensors whose ``.grad`` is still ``None`` (never back-propagated into) are
    left untouched.
    """
    if parameters is None:
        parameters = params
    for p in parameters:
        if p.grad is not None:
            # Drop any autograd history on the gradient, then zero it in place;
            # this reuses the existing buffer instead of allocating a fresh
            # zero tensor for every parameter on every call.
            p.grad.detach_()
            p.grad.zero_()


# One Adam optimiser per network, so each update step only moves that
# network's parameters. The learning rate is taken from the `lr` config
# constant so it can be tuned in one place.
G_solver = optim.Adam(G_params, lr=lr)
D_solver = optim.Adam(D_params, lr=lr)

# Constant target labels of shape (mb_size, 1). Not referenced by the visible
# training loop; kept in case another cell relies on them.
ones_label = torch.ones(mb_size, 1)
zeros_label = torch.zeros(mb_size, 1)


# Loss formulation used for training:
#   '1.1' - minimax objective written out with explicit logs (task 1.1);
#   '1.2' - the same objective expressed with torch's binary cross-entropy.
loss_task = '1.1'
# Probabilities are clamped to at least this value before a log is taken, so a
# saturated discriminator (an output of exactly 0.0 in float32) yields a large
# but finite loss instead of -inf and NaN gradients.
log_eps = 1e-8


def _safe_log(p):
    """Elementwise natural log of ``p`` after clamping it to >= ``log_eps``."""
    return torch.log(p.clamp(min=log_eps))


def _discriminator_loss(d_real, d_fake):
    """Return -mean(log D(x) + log(1 - D(G(z)))).

    ``d_real`` and ``d_fake`` are D's sigmoid outputs for a real batch and a
    generated batch of the same size.
    """
    if loss_task == '1.2':
        # D already applies a sigmoid, so probability-space BCE is the correct
        # loss here. BCEWithLogitsLoss would apply a second sigmoid; it is only
        # valid if D returned raw logits.
        return (nn.binary_cross_entropy(d_real, torch.ones_like(d_real))
                + nn.binary_cross_entropy(d_fake, torch.zeros_like(d_fake)))
    return -torch.mean(_safe_log(d_real) + _safe_log(1. - d_fake))


def _generator_loss(d_fake):
    """Return the non-saturating generator loss -mean(log D(G(z)))."""
    if loss_task == '1.2':
        return nn.binary_cross_entropy(d_fake, torch.ones_like(d_fake))
    return -torch.mean(_safe_log(d_fake))


def _save_sample_grid(samples, path):
    """Save up to 16 flattened 28x28 images as a 4x4 greyscale grid at ``path``."""
    fig = plt.figure(figsize=(4, 4))
    gs = gridspec.GridSpec(4, 4)
    gs.update(wspace=0.05, hspace=0.05)

    for i, sample in enumerate(samples):
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(sample.reshape(28, 28), cmap='Greys_r')

    plt.savefig(path, bbox_inches='tight')
    plt.close(fig)


# Create the output directory once, up front, rather than re-checking it
# every time a sample grid is written.
os.makedirs('out/', exist_ok=True)

for it in range(num_iterations):
    # Sample noise and a real minibatch
    z = torch.randn(mb_size, Z_dim)
    X, _ = mnist.train.next_batch(mb_size)
    X = torch.from_numpy(X)

    # Discriminator forward-loss-backward-update. G_sample is detached so this
    # backward pass stops at the discriminator: generator gradients computed
    # here would only be discarded by reset_grad() below.
    G_sample = G(z)
    D_real = D(X)
    D_fake = D(G_sample.detach())

    D_loss = _discriminator_loss(D_real, D_fake)
    D_loss.backward()
    D_solver.step()

    # Housekeeping - reset gradient
    reset_grad()

    # Generator forward-loss-backward-update
    z = torch.randn(mb_size, Z_dim)
    G_sample = G(z)
    D_fake = D(G_sample)

    G_loss = _generator_loss(D_fake)
    G_loss.backward()
    G_solver.step()

    # Housekeeping - reset gradient (D's gradients from this pass are unused)
    reset_grad()

    # Print and plot every now and then
    if it % 1000 == 0:
        print('Iter-{}; D_loss: {}; G_loss: {}'.format(it, D_loss.item(), G_loss.item()))

        # Visualising samples needs no autograd graph.
        with torch.no_grad():
            samples = G(z).numpy()[:16]

        _save_sample_grid(samples, 'out/{}.png'.format(str(c).zfill(3)))
        c += 1

Extracting ../../MNIST_data/train-images-idx3-ubyte.gz
Extracting ../../MNIST_data/train-labels-idx1-ubyte.gz
Extracting ../../MNIST_data/t10k-images-idx3-ubyte.gz
Extracting ../../MNIST_data/t10k-labels-idx1-ubyte.gz




Iter-0; D_loss: 1.4635330438613892; G_loss: 2.1223740577697754
Iter-1000; D_loss: 0.007852480746805668; G_loss: 8.235454559326172
Iter-2000; D_loss: 0.0371272899210453; G_loss: 5.607226371765137
Iter-3000; D_loss: 0.0952954813838005; G_loss: 4.773682594299316
Iter-4000; D_loss: 0.16303715109825134; G_loss: 4.865339756011963
Iter-5000; D_loss: 0.15661019086837769; G_loss: 4.641195774078369
Iter-6000; D_loss: 0.39188510179519653; G_loss: 4.286818027496338
Iter-7000; D_loss: 0.4969128370285034; G_loss: 3.3988757133483887
Iter-8000; D_loss: 0.5213351249694824; G_loss: 3.262096881866455
Iter-9000; D_loss: 0.8866742253303528; G_loss: 2.488405227661133
Iter-10000; D_loss: 0.9080154299736023; G_loss: 3.0092363357543945
Iter-11000; D_loss: 0.6964693665504456; G_loss: 2.545623540878296
Iter-12000; D_loss: 0.7955878376960754; G_loss: 2.844303607940674
Iter-13000; D_loss: 0.8184623718261719; G_loss: 2.97110915184021
Iter-14000; D_loss: 0.5663131475448608; G_loss: 2.303577423095703
Iter-15000; D_lo

In [None]:
!zip -r /content/file.zip /content/out
from google.colab import files
files.download("/content/file.zip")

updating: content/out/ (stored 0%)
updating: content/out/092.png (deflated 8%)
updating: content/out/064.png (deflated 6%)
updating: content/out/052.png (deflated 7%)
updating: content/out/037.png (deflated 7%)
updating: content/out/067.png (deflated 6%)
updating: content/out/086.png (deflated 8%)
updating: content/out/077.png (deflated 6%)
updating: content/out/030.png (deflated 6%)
updating: content/out/074.png (deflated 7%)
updating: content/out/005.png (deflated 6%)
updating: content/out/016.png (deflated 7%)
updating: content/out/025.png (deflated 6%)
updating: content/out/000.png (deflated 8%)
updating: content/out/043.png (deflated 6%)
updating: content/out/072.png (deflated 6%)
updating: content/out/065.png (deflated 7%)
updating: content/out/015.png (deflated 6%)
updating: content/out/012.png (deflated 6%)
updating: content/out/093.png (deflated 7%)
updating: content/out/091.png (deflated 7%)
updating: content/out/087.png (deflated 8%)
updating: content/out/024.png (deflated 6

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>