In [13]:
# Large amount of credit goes to:
# https://github.com/keras-team/keras-contrib/blob/master/examples/improved_wgan.py
# which I've used as a reference for this implementation

from functools import partial
import keras
import keras.backend as K
import matplotlib.pyplot as plt
import sys
import numpy as np
from tqdm import tqdm_notebook as tqdm

In [2]:
import tensorflow as tf

def ln_func(x):
    return tf.contrib.layers.layer_norm(x)

layer_norm = lambda : keras.layers.Lambda(ln_func)

In [18]:
img_rows = 28
img_cols = 28
channels = 1
img_shape = (img_rows, img_cols, channels)
latent_dim = 100

# Following parameter and optimizer set as recommended in paper
n_critic = 5
optimizer = keras.optimizers.Adam(lr=0.001, beta_1=.5)
lrelu = lambda x: keras.activations.relu(x, alpha=.2)

In [30]:
dim = 40
k = 4

# Build the generator and critic
model = keras.Sequential([
    keras.layers.Dense(dim*4 * 7 * 7, activation="relu", input_dim=latent_dim),
    keras.layers.BatchNormalization(),
    keras.layers.Reshape((7, 7, dim*4)),
    
    keras.layers.Conv2DTranspose(dim*2, kernel_size=k, padding="same", strides=2, activation="relu"),
    keras.layers.BatchNormalization(),
    
    keras.layers.Conv2DTranspose(dim, kernel_size=k, padding="same", strides=2, activation="relu"),
    keras.layers.BatchNormalization(),
    
    
    keras.layers.Conv2DTranspose(channels, kernel_size=k, padding="same", strides=1, activation="tanh"),
])

model.summary()

noise = keras.Input(shape=(latent_dim,))
img = model(noise)

generator = keras.Model(noise, img)

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_7 (Dense)              (None, 7840)              791840    
_________________________________________________________________
batch_normalization_16 (Batc (None, 7840)              31360     
_________________________________________________________________
reshape_4 (Reshape)          (None, 7, 7, 160)         0         
_________________________________________________________________
conv2d_transpose_10 (Conv2DT (None, 14, 14, 80)        204880    
_________________________________________________________________
batch_normalization_17 (Batc (None, 14, 14, 80)        320       
_________________________________________________________________
conv2d_transpose_11 (Conv2DT (None, 28, 28, 40)        51240     
_________________________________________________________________
batch_normalization_18 (Batc (None, 28, 28, 40)        160       
__________

In [31]:
dim = 32
k = 4
DROP = 0.40

model = keras.Sequential([
    keras.layers.Conv2D(dim, kernel_size=k, strides=2, input_shape=img_shape, padding="same", activation=lrelu),
    keras.layers.BatchNormalization(),
    keras.layers.Dropout(DROP),
    
    keras.layers.Conv2D(dim*2, kernel_size=k, strides=2, padding="same", activation=lrelu),
    keras.layers.BatchNormalization(),
    keras.layers.Dropout(DROP),
    
    keras.layers.Conv2D(dim*4, kernel_size=k, strides=2, padding="same", activation=lrelu),
    keras.layers.BatchNormalization(),
    keras.layers.Dropout(DROP),
    
    keras.layers.Flatten(),
    keras.layers.Dense(1)
])

model.summary()

img = keras.Input(shape=img_shape)
validity = model(img)

critic =  keras.Model(img, validity)

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_10 (Conv2D)           (None, 14, 14, 32)        544       
_________________________________________________________________
batch_normalization_19 (Batc (None, 14, 14, 32)        128       
_________________________________________________________________
dropout_10 (Dropout)         (None, 14, 14, 32)        0         
_________________________________________________________________
conv2d_11 (Conv2D)           (None, 7, 7, 64)          32832     
_________________________________________________________________
batch_normalization_20 (Batc (None, 7, 7, 64)          256       
_________________________________________________________________
dropout_11 (Dropout)         (None, 7, 7, 64)          0         
_________________________________________________________________
conv2d_12 (Conv2D)           (None, 4, 4, 128)         131200    
__________

In [32]:
### Relativistic average Standard GAN

# No sigmoid activation in last layer of generator because BCEWithLogitsLoss() already adds it

# BCE_stable = torch.nn.BCEWithLogitsLoss()

# # Discriminator loss
# errD = ((BCE_stable(y_pred - torch.mean(y_pred_fake), y) + BCE_stable(y_pred_fake - torch.mean(y_pred), y2))/2
# errD.backward()

# # Generator loss (You may want to resample again from real and fake data)
# errG = ((BCE_stable(y_pred - torch.mean(y_pred_fake), y2) + BCE_stable(y_pred_fake - torch.mean(y_pred), y))/2
# errG.backward()




In [33]:
def noop_loss(target, output):
    return output

In [34]:
def rasg_loss(inputs):
    
    
    y_pred_target = inputs[0]
    y_pred_off = inputs[1]
    
    zeros_bce = K.binary_crossentropy(y_pred_target - K.mean(y_pred_off), K.zeros_like(y_pred_target), from_logits=True)
    ones_bce = K.binary_crossentropy(y_pred_off - K.mean(y_pred_target), K.ones_like(y_pred_target), from_logits=True)
    
    return zeros_bce + ones_bce

def rasg_loss_layer():
    
    return keras.layers.Lambda(rasg_loss)

In [35]:
#-------------------------------
# Construct Computational Graph
#       for the Critic
#-------------------------------

# Freeze generator's layers while training critic
critic.trainable = True
generator.trainable = False

# Image input (real sample)
real_img = keras.Input(shape=img_shape)

# Noise input
z_disc = keras.Input(shape=(latent_dim,))
# Generate image based of noise (fake sample)
fake_img = generator(z_disc)

# Discriminator determines validity of the real and fake images
y_pred_fake = critic(fake_img)
y_pred = critic(real_img)

critic_loss = rasg_loss_layer()([y_pred, y_pred_fake])

critic_model = keras.Model(inputs=[real_img, z_disc],
                    outputs=[critic_loss])
critic_model.compile(loss=[noop_loss],
                                optimizer=optimizer)

In [36]:
#-------------------------------
# Construct Computational Graph
#         for Generator
#-------------------------------

# For the generator we freeze the critic's layers
critic.trainable = False
generator.trainable = True


gen_loss = rasg_loss_layer()([y_pred_fake, y_pred])

gen_model = keras.Model(inputs=[real_img, z_disc],
                    outputs=[gen_loss])
gen_model.compile(loss=[noop_loss],
                                optimizer=optimizer)

In [37]:
import shutil, os

n_critic = 1

def train(steps, batch_size, sample_interval=50):

    shutil.rmtree("images", ignore_errors=True)
    os.makedirs(f"images", exist_ok=True)
    
    # Load the dataset
    (X_train, _), (_, _) = keras.datasets.mnist.load_data()

    # Rescale -1 to 1
    X_train = (X_train.astype(np.float32) - 127.5) / 127.5
    X_train = np.expand_dims(X_train, axis=3)

    # Adversarial ground truths
    valid = -np.ones((batch_size, 1))
    fake =  np.ones((batch_size, 1))
    dummy = np.zeros((batch_size, 1)) # Dummy gt for gradient penalty
    for step in tqdm(range(1, steps+1)):

        # sample for Critic
        
        idx = np.random.randint(0, X_train.shape[0], batch_size)
        imgs = X_train[idx]
        # Sample generator input
        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        # Train the critic
        d_loss = critic_model.train_on_batch([imgs, noise],
                                                        valid)

        # ---------------------
        #  Train Generator
        # ---------------------
        
        idx = np.random.randint(0, X_train.shape[0], batch_size)
        imgs = X_train[idx]
        # Sample generator input
        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        g_loss = gen_model.train_on_batch([imgs, noise], valid)
        

        # If at save interval => save generated image samples
        if step % sample_interval == 0:
            print_string = f"step: {step}, D: {d_loss:g}, G: {g_loss:g}"
            print(print_string)
            sample_images(print_string, step)

def sample_images(print_string, step):
    # make a video with 
    # >ffmpeg -framerate 4 -pattern_type glob -i mnist_*.png -pix_fmt yuv420p output.mp4
    
    
    r, c = 5, 5
    noise = np.random.normal(0, 1, (r * c, latent_dim))
    gen_imgs = generator.predict(noise)

    # Rescale images 0 - 1
    gen_imgs = 0.5 * gen_imgs + 1

    fig, axs = plt.subplots(r, c, sharex=True, sharey=True, frameon=False, figsize=(5,5))
    cnt = 0
    for i in range(r):
        for j in range(c):
            axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray', aspect="auto")
            axs[i,j].axis('off')
            cnt += 1
    plt.tight_layout(h_pad=0, w_pad=0)
    plt.suptitle(print_string, backgroundcolor="white", fontsize=7)
    fig.savefig(f"images/mnist_{step:05d}.png", dpi=150)
    plt.close()

In [38]:
train(steps=10000, batch_size=32, sample_interval=50)

HBox(children=(IntProgress(value=0, max=10000), HTML(value='')))

step: 50, D: -22.9684, G: 76.11
step: 100, D: -647.664, G: 521.969
step: 150, D: -167.107, G: 323.056
step: 200, D: -272.203, G: 167.697
step: 250, D: -227.347, G: 130.947
step: 300, D: -105.172, G: 7.00066
step: 350, D: -221.548, G: 287.815
step: 400, D: -200.115, G: 51.8454
step: 450, D: -131.229, G: 131.447
step: 500, D: -151.489, G: 309.293
step: 550, D: -243.493, G: 262.96
step: 600, D: -202.014, G: 80.7457
step: 650, D: -364.984, G: 144.29
step: 700, D: -256.516, G: 337.985
step: 750, D: -300.863, G: 214.156
step: 800, D: -90.0517, G: 150.062
step: 850, D: -279.312, G: 254.462
step: 900, D: 162.799, G: 185.565
step: 950, D: -218.679, G: 501.021
step: 1000, D: -427.274, G: 560.599
step: 1050, D: -134.299, G: 269.479
step: 1100, D: -283.378, G: 238.182
step: 1150, D: -543.579, G: -35.8385
step: 1200, D: -258.428, G: 729.255
step: 1250, D: -497.878, G: 394.878
step: 1300, D: -330.567, G: 608.073
step: 1350, D: -408.889, G: 406.245
step: 1400, D: -17.0443, G: 465.391
step: 1450, D: 7