In [1]:
# Only needs to be run once per environment.
# `%pip` (unlike `!pip`) installs into the same environment that is running this kernel.
# The "Running pip as the 'root' user ..." message is a warning, not an error: the
# packages still install. It is printed because the kernel runs as root. To avoid it,
# either launch Jupyter from a virtual environment
#   (python -m venv .venv && source .venv/bin/activate && pip install jupyter)
# or, if running as root is intentional, add --root-user-action=ignore (pip >= 22.1).
# tensorflow is intentionally not installed: nothing in this notebook imports it.
%pip install -q torch torchvision numpy matplotlib

[0m

In [2]:


import torch
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
from torchvision import datasets, transforms
#from tensorflow.examples.tutorials.mnist import input_data

In [None]:
# Check for GPU and set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Hyperparameters (defined before the loaders so mb_size is actually used by them)
mb_size = 64   # minibatch size for both DataLoaders
Z_dim = 100    # dimension of the generator's noise input
h_dim = 128    # hidden units in both G and D
lr = 1e-3      # Adam learning rate for both G and D
c = 0          # running index used to name saved sample grids: out/000.png, out/001.png, ...

# Load MNIST; ToTensor() scales pixels to [0, 1], the same range as G's sigmoid output.
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())

# Create data loaders
train_loader = torch.utils.data.DataLoader(dataset=mnist_trainset, batch_size=mb_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=mnist_testset, batch_size=mb_size, shuffle=False)

# Flattened image size, read once from the first training example rather than hard-coded.
# ToTensor() yields (C, H, W); MNIST is single-channel, so only H * W is used (should be 28*28 = 784).
first_image, _ = mnist_trainset[0]
X_dim = first_image.shape[1] * first_image.shape[2]




In [4]:
def xavier_init(size, target_device=None):
    """Return a trainable weight matrix drawn from N(0, 2 / fan_in).

    Despite the name, this is the He/Kaiming-normal scale (std = sqrt(2 / fan_in)),
    not Glorot/Xavier scaling (which would use fan_in + fan_out).

    Args:
        size: weight shape, e.g. ``[in_dim, out_dim]``. ``size[0]`` is taken as the
            fan-in because the layers in this notebook compute ``x @ W``.
        target_device: device to allocate the parameter on. Defaults to the
            notebook-level ``device`` so existing calls behave as before.

    Returns:
        A ``torch.nn.Parameter`` of shape ``size`` (requires_grad=True).
    """
    if target_device is None:
        target_device = device  # notebook-level global from the config cell
    fan_in = size[0]
    # Plain Python scalar: no need to allocate a device tensor just to take a sqrt.
    std = (2.0 / fan_in) ** 0.5
    return torch.nn.Parameter(torch.randn(*size, device=target_device) * std)

In [5]:
""" ==================== GENERATOR ======================== """

# Generator weights: Z_dim -> h_dim (ReLU) -> X_dim (sigmoid).
# Creation order is kept so random initialisation consumes the RNG identically.
Wzh = xavier_init(size=[Z_dim, h_dim])
bzh = torch.nn.Parameter(torch.zeros(h_dim, device=device))
Whx = xavier_init(size=[h_dim, X_dim])
bhx = torch.nn.Parameter(torch.zeros(X_dim, device=device))


def G(z):
    """Map a batch of noise vectors to flattened images.

    Args:
        z: 2-D tensor of shape (batch, Z_dim), on the same device as the weights.

    Returns:
        Tensor of shape (batch, X_dim) with values in (0, 1) from the sigmoid.
    """
    hidden = torch.relu(torch.mm(z, Wzh).add(bzh))
    out_logits = torch.mm(hidden, Whx).add(bhx)
    return out_logits.sigmoid()

In [6]:
""" ==================== DISCRIMINATOR ======================== """

# Discriminator weights: X_dim -> h_dim (ReLU) -> 1 (sigmoid).
# Creation order is kept so random initialisation consumes the RNG identically.
Wxh = xavier_init(size=[X_dim, h_dim])
bxh = torch.nn.Parameter(torch.zeros(h_dim, device=device))
Why = xavier_init(size=[h_dim, 1])
bhy = torch.nn.Parameter(torch.zeros(1, device=device))


def D(X):
    """Score a batch of images: outputs near 1 mean "real", near 0 mean "fake".

    Args:
        X: tensor whose first dimension is the batch; all remaining dimensions are
            flattened, so image batches from train_loader and (batch, X_dim)
            generator samples are both accepted (must be viewable, i.e. contiguous).

    Returns:
        Tensor of shape (batch, 1) with probabilities in (0, 1).
    """
    flat = X.view(X.size(0), -1)
    hidden = torch.relu(torch.mm(flat, Wxh).add(bxh))
    return torch.mm(hidden, Why).add(bhy).sigmoid()


# Parameter groups for the generator, the discriminator, and both together.
G_params = [Wzh, bzh, Whx, bhx]
D_params = [Wxh, bxh, Why, bhy]
params = G_params + D_params


In [None]:
""" ===================== TRAINING ======================== """

def reset_grad():
    """Zero the accumulated gradient of every generator and discriminator parameter."""
    for p in params:
        if p.grad is not None:
            p.grad.zero_()  # operate on the grad tensor directly; `.data` is discouraged


def _save_sample_grid(samples, path):
    """Save up to 16 flattened 28x28 images as a 4x4 greyscale grid PNG at `path`."""
    fig = plt.figure(figsize=(4, 4))
    gs = gridspec.GridSpec(4, 4)
    gs.update(wspace=0.05, hspace=0.05)

    for i, sample in enumerate(samples):
        ax = fig.add_subplot(gs[i])
        ax.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        ax.imshow(sample.reshape(28, 28), cmap='Greys_r')

    fig.savefig(path, bbox_inches='tight')
    plt.close(fig)  # avoid accumulating open figures across many log steps


N_EPOCHS = 10000   # each `it` below is one full pass over train_loader, not one batch
LOG_EVERY = 100    # print losses and save a sample grid every LOG_EVERY epochs
N_SAMPLES = 16     # fills the 4x4 grid drawn by _save_sample_grid

G_solver = optim.Adam(G_params, lr=lr)
D_solver = optim.Adam(D_params, lr=lr)

# Fixed noise: every saved grid shows the same latent codes, so images are
# comparable across epochs (instead of whatever z the last training batch used).
fixed_z = torch.randn(N_SAMPLES, Z_dim, device=device)
os.makedirs('out/', exist_ok=True)

for it in range(N_EPOCHS):
    for X, _ in train_loader:
        X = X.to(device)
        current_batch_size = X.size(0)  # the final batch of an epoch can be smaller than mb_size
        ones_label = torch.ones(current_batch_size, 1, device=device)
        zeros_label = torch.zeros(current_batch_size, 1, device=device)

        # Discriminator forward-loss-backward-update.
        # Detach G's output so D's backward pass does not traverse G's graph.
        z = torch.randn(current_batch_size, Z_dim, device=device)
        G_sample = G(z).detach()
        D_real = D(X)
        D_fake = D(G_sample)

        D_loss_real = F.binary_cross_entropy(D_real, ones_label)
        D_loss_fake = F.binary_cross_entropy(D_fake, zeros_label)
        D_loss = D_loss_real + D_loss_fake

        D_solver.zero_grad()
        D_loss.backward()
        D_solver.step()

        # Generator forward-loss-backward-update: push D(G(z)) toward the "real" label.
        # G_loss.backward() also writes D's .grad, but D_solver.zero_grad() clears it
        # before the next discriminator update.
        z = torch.randn(current_batch_size, Z_dim, device=device)
        G_sample = G(z)
        D_fake = D(G_sample)
        G_loss = F.binary_cross_entropy(D_fake, ones_label)

        G_solver.zero_grad()
        G_loss.backward()
        G_solver.step()

    # Print and plot every now and then (losses are from the epoch's last batch).
    if it % LOG_EVERY == 0:
        print('Iter-{}; D_loss: {}; G_loss: {}'.format(it, D_loss.item(), G_loss.item()))

        with torch.no_grad():  # sampling only; no autograd graph needed
            samples = G(fixed_z).cpu().numpy()
        _save_sample_grid(samples, 'out/{}.png'.format(str(c).zfill(3)))
        c += 1

Iter-0; D_loss: 0.008049117401242256; G_loss: 7.378349304199219
Iter-100; D_loss: 0.8193799257278442; G_loss: 2.022839069366455
Iter-200; D_loss: 0.7014672756195068; G_loss: 2.5554208755493164
Iter-300; D_loss: 0.40103816986083984; G_loss: 2.9425840377807617
Iter-400; D_loss: 0.39365285634994507; G_loss: 3.1927971839904785
Iter-500; D_loss: 0.2689024806022644; G_loss: 3.6506948471069336
Iter-600; D_loss: 3.5428504943847656; G_loss: 3.7231674194335938
Iter-700; D_loss: 0.1865185797214508; G_loss: 3.1402945518493652
Iter-800; D_loss: 0.2687574625015259; G_loss: 2.840707302093506
Iter-900; D_loss: 0.35830408334732056; G_loss: 3.3787498474121094
Iter-1000; D_loss: 0.10676257312297821; G_loss: 3.231506824493408
Iter-1100; D_loss: 0.15382477641105652; G_loss: 3.278045177459717
Iter-1200; D_loss: 0.1923561543226242; G_loss: 3.0635030269622803
Iter-1300; D_loss: 0.13251742720603943; G_loss: 3.693588972091675
Iter-1400; D_loss: 0.16421853005886078; G_loss: 3.7232770919799805
Iter-1500; D_loss: 