In [None]:
!pip install -q torch numpy matplotlib tensorflow

In [None]:
import torch
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
from torchvision import datasets, transforms
#from tensorflow.examples.tutorials.mnist import input_data

In [None]:
# Check for GPU and set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyperparameters. Defined before the data loaders so that `mb_size` is the
# single source of truth for the minibatch size.
seed = 0        # seeds torch's global RNG (parameter init, noise, loader shuffling)
mb_size = 64    # minibatch size
Z_dim = 100     # generator noise dimension
h_dim = 128     # hidden width of G and D
lr = 1e-3       # Adam learning rate
c = 0           # counter used to number the sample images written to out/

torch.manual_seed(seed)

# Load MNIST dataset
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())

# Create data loaders
train_loader = torch.utils.data.DataLoader(dataset=mnist_trainset, batch_size=mb_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=mnist_testset, batch_size=mb_size, shuffle=False)

# Flattened image size. The discriminator flattens every non-batch dimension
# (channels included), so count all elements of one sample: 1 * 28 * 28 = 784.
sample_image, _ = mnist_trainset[0]
X_dim = sample_image.numel()
# The sample-grid plotting reshapes generator output to (28, 28).
assert X_dim == 28 * 28, "sample plotting assumes 28x28 single-channel images"




Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:04<00:00, 2262225.84it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 339826.84it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:01<00:00, 1263309.42it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 8282838.59it/s]


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw



In [None]:
def xavier_init(size):
    """Create a trainable weight tensor of shape `size` with normal random entries.

    The standard deviation is 1 / sqrt(size[0] / 2), i.e. sqrt(2 / fan_in) with
    fan_in = size[0]. Despite the name, this is the He/Kaiming-normal scale rather
    than Glorot's sqrt(2 / (fan_in + fan_out)).

    Args:
        size: sequence of dimensions; size[0] is treated as the fan-in.

    Returns:
        torch.nn.Parameter allocated on the module-level `device`.
    """
    fan_in = size[0]
    half_fan_in = torch.tensor(fan_in / 2., device=device)
    std = 1. / torch.sqrt(half_fan_in)
    weights = torch.randn(*size, device=device) * std
    return torch.nn.Parameter(weights)

In [None]:
""" ==================== GENERATOR ======================== """

Wzh = xavier_init(size=[Z_dim, h_dim])
bzh = torch.nn.Parameter(torch.zeros(h_dim, device=device))
Whx = xavier_init(size=[h_dim, X_dim])
bhx = torch.nn.Parameter(torch.zeros(X_dim, device=device))

def G(z):
    h = F.relu(torch.mm(z, Wzh) + bzh)
    X = torch.sigmoid(torch.mm(h, Whx) + bhx)
    return X

In [None]:
""" ==================== DISCRIMINATOR ======================== """

Wxh = xavier_init(size=[X_dim, h_dim])
bxh = torch.nn.Parameter(torch.zeros(h_dim, device=device))
Why = xavier_init(size=[h_dim, 1])
bhy = torch.nn.Parameter(torch.zeros(1, device=device))

def D(X):
    h = F.relu(torch.mm(X.view(X.size(0), -1), Wxh) + bxh)
    y = torch.sigmoid(torch.mm(h, Why) + bhy)
    return y

G_params = [Wzh, bzh, Whx, bhx]
D_params = [Wxh, bxh, Why, bhy]
params = G_params + D_params


In [None]:
""" ===================== TRAINING ======================== """

def reset_grad():
    for p in params:
        if p.grad is not None:
            p.grad.data.zero_()

G_solver = optim.Adam([Wzh, bzh, Whx, bhx], lr=lr)
D_solver = optim.Adam([Wxh, bxh, Why, bhy], lr=lr)

for it in range(10000):
    for X, _ in train_loader:
        current_batch_size = X.size(0)

        z = torch.randn(current_batch_size, Z_dim, device=device)
        ones_label = torch.ones(current_batch_size, 1, device=device)
        zeros_label = torch.zeros(current_batch_size, 1, device=device)
        X = X.to(device)

        # Discriminator forward-loss-backward-update
        G_sample = G(z)
        D_real = D(X)
        D_fake = D(G_sample)

        D_loss_real = F.binary_cross_entropy(D_real, ones_label)
        D_loss_fake = F.binary_cross_entropy(D_fake, zeros_label)
        D_loss = D_loss_real + D_loss_fake

        D_solver.zero_grad()
        D_loss.backward()
        D_solver.step()

        # Generator forward-loss-backward-update
        z = torch.randn(current_batch_size, Z_dim, device=device)
        G_sample = G(z)
        D_fake = D(G_sample)

        G_loss = F.binary_cross_entropy(D_fake, ones_label)

        G_solver.zero_grad()
        G_loss.backward()
        G_solver.step()

    # Print and plot every now and then
    if it % 100 == 0:
        print('Iter-{}; D_loss: {}; G_loss: {}'.format(it, D_loss.item(), G_loss.item()))

        samples = G(z).detach().cpu().numpy()[:16]

        fig = plt.figure(figsize=(4, 4))
        gs = gridspec.GridSpec(4, 4)
        gs.update(wspace=0.05, hspace=0.05)

        for i, sample in enumerate(samples):
            ax = plt.subplot(gs[i])
            plt.axis('off')
            ax.set_xticklabels([])
            ax.set_yticklabels([])
            ax.set_aspect('equal')
            plt.imshow(sample.reshape(28, 28), cmap='Greys_r')

        if not os.path.exists('out/'):
            os.makedirs('out/')

        plt.savefig('out/{}.png'.format(str(c).zfill(3)), bbox_inches='tight')
        c += 1
        plt.close(fig)

Iter-0; D_loss: 0.0033271959982812405; G_loss: 8.435750007629395
Iter-100; D_loss: 0.6213933229446411; G_loss: 2.1821322441101074
Iter-200; D_loss: 0.6876852512359619; G_loss: 2.3392324447631836
Iter-300; D_loss: 0.49598371982574463; G_loss: 2.7474308013916016
Iter-400; D_loss: 0.23141580820083618; G_loss: 3.424255847930908
Iter-500; D_loss: 0.31467732787132263; G_loss: 2.4547605514526367
Iter-600; D_loss: 0.8529813289642334; G_loss: 3.2991485595703125
