In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
from tensorflow.examples.tutorials.mnist import input_data

In [7]:
# Load MNIST from the local data directory; labels are requested one-hot encoded.
mnist_dir = './MNIST_data'
mnist = input_data.read_data_sets(mnist_dir, one_hot=True)

Extracting ./MNIST_data/train-images-idx3-ubyte.gz
Extracting ./MNIST_data/train-labels-idx1-ubyte.gz
Extracting ./MNIST_data/t10k-images-idx3-ubyte.gz
Extracting ./MNIST_data/t10k-labels-idx1-ubyte.gz


In [8]:
# --- Reproducibility --------------------------------------------------------
# Weight initialisation and the torch.randn prior draws depend on torch's
# global RNG; seed it so a fresh "Restart & Run All" repeats the same run.
SEED = 0
torch.manual_seed(SEED)
# NOTE(review): assumes mnist.train.next_batch shuffles via NumPy's global RNG
# -- confirm against the installed data loader.
np.random.seed(SEED)

# --- Hyperparameters --------------------------------------------------------
mb_size = 32                             # mini-batch size
z_dim = 5                                # latent code width
X_dim = mnist.train.images.shape[1]      # length of one image row vector
y_dim = mnist.train.labels.shape[1]      # length of one one-hot label row
h_dim = 128                              # hidden layer width shared by all networks
cnt = 0                                  # running index for saved sample images
lr = 1e-3                                # learning rate for every optimizer

In [13]:
def _mlp(in_dim, out_dim, sigmoid_out=False):
    """Build Linear(in_dim, h_dim) -> ReLU -> Linear(h_dim, out_dim).

    When ``sigmoid_out`` is True a final ``nn.Sigmoid`` is appended so the
    outputs lie in (0, 1).
    """
    stack = [nn.Linear(in_dim, h_dim), nn.ReLU(), nn.Linear(h_dim, out_dim)]
    if sigmoid_out:
        stack.append(nn.Sigmoid())
    return nn.Sequential(*stack)


# Encoder: image vector -> latent code (no output activation)
Q = _mlp(X_dim, z_dim)

# Decoder: latent code -> image vector, squashed into (0, 1)
P = _mlp(z_dim, X_dim, sigmoid_out=True)

# Discriminator: latent code -> single sigmoid score
D = _mlp(z_dim, 1, sigmoid_out=True)

In [14]:
def reset_grad():
    """Clear the accumulated gradients of the encoder Q, decoder P and discriminator D."""
    for net in (Q, P, D):
        net.zero_grad()

In [15]:
def sample_X(size, include_y=False):
    """Draw the next mini-batch from the MNIST training split as torch tensors.

    Args:
        size: number of examples requested from ``mnist.train.next_batch``.
        include_y: when True, also return the labels as class indices.

    Returns:
        ``X`` as a tensor wrapping the image array, or ``(X, y)`` when
        ``include_y`` is set, where ``y`` holds int64 class indices obtained
        by taking the argmax of each label row.
    """
    X, y = mnist.train.next_batch(size)
    X = torch.from_numpy(X)

    if include_y:
        # ``np.int`` was only an alias for the builtin ``int`` and has been
        # removed from NumPy; use the explicit 64-bit dtype instead.
        y = torch.from_numpy(np.argmax(y, axis=1).astype(np.int64))
        return X, y

    return X

In [16]:
# One Adam optimizer per network (encoder, decoder, discriminator), all using `lr`.
Q_optimizer, P_optimizer, D_optimizer = [
    optim.Adam(net.parameters(), lr=lr) for net in (Q, P, D)
]

In [24]:
# Adversarial autoencoder training. Each iteration performs three updates on
# the same mini-batch:
#   (1) reconstruction  - Q and P minimise the pixel-wise BCE,
#   (2) discriminator   - D learns to separate prior samples from encoded codes,
#   (3) generator       - Q is updated so that D scores its codes as "real".
n_iters = 1000000   # upper bound on iterations; the loop has no other stopping rule
log_every = 1000    # iterations between log lines / saved sample grids
n_preview = 16      # decoded samples per saved figure (fills the 4 x 4 grid)

for it in range(n_iters):
    X = sample_X(mb_size)

    # (1) reconstruction phase
    z_sample = Q(X)
    X_sample = P(z_sample)
    recon_loss = F.binary_cross_entropy(X_sample, X)

    recon_loss.backward()
    P_optimizer.step()
    Q_optimizer.step()
    reset_grad()

    # (2) regularization phase: discriminator update
    z_real = torch.randn(mb_size, z_dim)
    # Only D is stepped here, so cut the graph at the encoder output instead
    # of back-propagating into Q and then throwing those gradients away.
    z_fake = Q(X).detach()

    D_real = D(z_real)
    D_fake = D(z_fake)

    # Equals -mean(log(D_real) + log(1 - D_fake)); binary_cross_entropy clamps
    # its log terms, so a saturated sigmoid (exactly 0 or 1) cannot turn the
    # loss into inf/NaN.
    D_loss = (F.binary_cross_entropy(D_real, torch.ones_like(D_real))
              + F.binary_cross_entropy(D_fake, torch.zeros_like(D_fake)))
    D_loss.backward()
    D_optimizer.step()
    reset_grad()

    # (3) generator update: only Q is stepped; gradients that reach D are
    # cleared by reset_grad() below.
    z_fake = Q(X)
    D_fake = D(z_fake)
    # Equals -mean(log(D_fake)), with the same clamping as above.
    G_loss = F.binary_cross_entropy(D_fake, torch.ones_like(D_fake))
    G_loss.backward()
    Q_optimizer.step()
    reset_grad()

    if it % log_every == 0:
        print('Iter-{}; D_loss: {:.4}; G_loss: {:.4}; recon_loss: {:.4}'.format(it, D_loss.item(), G_loss.item(), recon_loss.item()))

        # Decode only the codes that are plotted, and without building an
        # autograd graph since nothing is back-propagated from here.
        with torch.no_grad():
            samples = P(z_real[:n_preview]).numpy()

        fig = plt.figure(figsize=(4, 4))
        gs = gridspec.GridSpec(4, 4)
        gs.update(wspace=0.05, hspace=0.05)

        for i, sample in enumerate(samples):
            ax = fig.add_subplot(gs[i])
            ax.axis('off')
            ax.set_xticklabels([])
            ax.set_yticklabels([])
            ax.set_aspect('equal')
            ax.imshow(sample.reshape(28, 28), cmap='Greys_r')

        os.makedirs('aae', exist_ok=True)
        fig.savefig('aae/{}.png'.format(str(cnt).zfill(3)), bbox_inches='tight')
        cnt += 1
        plt.close(fig)  # release the figure; many are created over a long run

Iter-0; D_loss: 1.476; G_loss: 0.5378; recon_loss: 0.6613
Iter-1000; D_loss: 1.555; G_loss: 0.6024; recon_loss: 0.2733
Iter-2000; D_loss: 1.542; G_loss: 0.5959; recon_loss: 0.2738
Iter-3000; D_loss: 1.488; G_loss: 0.613; recon_loss: 0.2406
Iter-4000; D_loss: 1.364; G_loss: 0.6926; recon_loss: 0.2085
Iter-5000; D_loss: 1.382; G_loss: 0.6965; recon_loss: 0.1609
Iter-6000; D_loss: 1.382; G_loss: 0.6987; recon_loss: 0.1823
Iter-7000; D_loss: 1.386; G_loss: 0.6965; recon_loss: 0.1695
Iter-8000; D_loss: 1.381; G_loss: 0.6829; recon_loss: 0.1506
Iter-9000; D_loss: 1.384; G_loss: 0.7114; recon_loss: 0.1906
Iter-10000; D_loss: 1.38; G_loss: 0.6991; recon_loss: 0.1721
Iter-11000; D_loss: 1.389; G_loss: 0.6854; recon_loss: 0.1496
Iter-12000; D_loss: 1.394; G_loss: 0.684; recon_loss: 0.157
Iter-13000; D_loss: 1.39; G_loss: 0.6832; recon_loss: 0.166
Iter-14000; D_loss: 1.393; G_loss: 0.6841; recon_loss: 0.174
Iter-15000; D_loss: 1.391; G_loss: 0.6974; recon_loss: 0.147
Iter-16000; D_loss: 1.406; G_

KeyboardInterrupt: 