Instructions for downloading the EMNIST dataset (535 MB):

```
# Create the target directory (no error if it already exists)
mkdir -p ~/.data/emnist
# Download the EMNIST archive into that directory
wget -P ~/.data/emnist http://www.itl.nist.gov/iaui/vip/cs_links/EMNIST/gzip.zip
# Extract into the same directory, overwriting existing files (-o) and dropping archive sub-paths (-j)
unzip -o -d ~/.data/emnist -j ~/.data/emnist/gzip.zip
# Remove the archive once its contents are extracted
rm -f ~/.data/emnist/gzip.zip
```

In [1]:
%pylab inline

import os

from mnist import MNIST
import keras
import keras.backend as K
import tensorflow as tf
from tqdm import tqdm as old_tqdm
from tqdm import tqdm_notebook as tqdm

Populating the interactive namespace from numpy and matplotlib


  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


In [7]:
# Resolve the dataset directory from the current user's home so it matches the
# download instructions (~/.data/emnist) instead of one machine's /home/ubuntu.
EMNIST_DIR = os.path.expanduser('~/.data/emnist')

# return_type="numpy" asks the loader for numpy arrays, so no np.array()
# conversion is needed before the later .reshape call.
mndata = MNIST(EMNIST_DIR, return_type="numpy")

# Select the "bymerge" EMNIST split before loading.
mndata.select_emnist('bymerge')

# x: training images, y: training labels (as returned by load_training()).
x, y = mndata.load_training()

In [22]:
# Give every sample an explicit 28x28x1 layout (trailing axis of size 1);
# the leading -1 lets numpy infer the number of samples.
x = x.reshape((-1, 28, 28, 1))

In [24]:
# Seed passed to the numpy and TensorFlow RNG seeding calls.
RANDOM_SEED = 1
# Side length of the model's square input/output images.
IMG_SIZE = 32 # should be power of two and square
# Number of output channels of the critic's "info" head.
INFO_DIM = 10
# Number of image channels.
CHANNELS = 1
N_CRITIC = 1
N_GEN = 1
# Adam learning rate.
LR = 0.0002
BATCH_SIZE = 32
N_EPOCHS = 30
# Adam (beta_1, beta_2).
BETAS = (0.5, 0.999)
# Length of the generator's noise input vector.
Z_SIZE = 72
# Must be a key of the gan_loss lookup table: RGAN, RaGAN, RaLSGAN or RaHingeGAN.
LOSS = "RaGAN"

In [25]:
# Seed the RNGs so that runs are repeatable.
# NOTE(review): under `%pylab`, `random` presumably resolves to numpy.random
# rather than the stdlib module — confirm.
random.seed(RANDOM_SEED)
numpy.random.seed(RANDOM_SEED)
tf.set_random_seed(RANDOM_SEED)

img_shape = [IMG_SIZE, IMG_SIZE, CHANNELS]

# Following parameter and optimizer set as recommended in paper.
# The lower-case names are kept for any cell that reads them, but they are
# derived from the config constants so the two sets of values cannot drift.
n_critic = N_CRITIC
n_gen = N_GEN
lr = LR
optimizer = keras.optimizers.Adam(lr=lr, beta_1=BETAS[0], beta_2=BETAS[1])

# Weight initializers shared by the generator (rnorm) and critic (tnorm).
tnorm = keras.initializers.truncated_normal(stddev=.02)
rnorm = keras.initializers.random_normal(stddev=.02)


def batchnorm():
    """Return a fresh BatchNormalization layer.

    The moving variance is initialised from a normal distribution with
    mean 1 and stddev 0.02. A new layer is built on every call so that
    layers are never shared between call sites.
    """
    return keras.layers.BatchNormalization(
        moving_variance_initializer=keras.initializers.random_normal(
            mean=1, stddev=.02))


def lrelu(x):
    """Leaky ReLU with negative slope 0.2, usable as a layer ``activation``."""
    return keras.activations.relu(x, alpha=.2)


# Exponent that scales the generator's channel counts with IMG_SIZE.
# np.log2 returns a float, so users wrap derived values in int().
start_pow = np.log2(IMG_SIZE) - 3


def ln_func(x):
    """Apply ``tf.contrib.layers.layer_norm`` to ``x``."""
    return tf.contrib.layers.layer_norm(x)


def layer_norm():
    """Return a fresh Keras Lambda layer wrapping ``ln_func``."""
    return keras.layers.Lambda(ln_func)

In [26]:
##################
#
# Build Generator
#
#################
#
# Maps a Z_SIZE noise vector to an image: a Dense projection is reshaped to a
# 4x4 feature map, then stride-2 transposed convolutions are stacked until the
# final CHANNELS-channel tanh layer.


# Base channel count, conv kernel size and dropout rate for this cell.
dim = 64
k = 4
DROP = 0.00

noise = keras.Input(shape=(Z_SIZE, ))

# Project the noise to 4*4*(dim * 2**start_pow) units ...
noise_block = keras.layers.Dense(
    int(dim * (2**start_pow)) * 4 * 4,
    activation="relu",
    kernel_initializer=rnorm)(noise)
noise_block = batchnorm()(noise_block)
# ... and view them as a 4x4 map with dim * 2**start_pow channels.
noise_block = keras.layers.Reshape((4, 4, int(dim * (2**start_pow))))(noise_block)
noise_block = keras.layers.Dropout(DROP)(noise_block)

conv_block = noise_block

# n_blocks is also read by the critic cell to size its conv stack.
n_blocks = int(start_pow + 1)
# Channel count of the first transposed conv; halved after every hidden block.
curr_dim = int((2**(start_pow-1)) * dim)

for i in range(n_blocks):
    
    if curr_dim >= dim:
        # Hidden upsampling block: transposed conv (stride 2) -> batchnorm -> dropout.
        conv_block = keras.layers.Conv2DTranspose(
            curr_dim,
            kernel_size=k,
            strides=2,
            padding="same",
            activation="relu",
            kernel_initializer=rnorm,
            use_bias=False)(conv_block)
        curr_dim = curr_dim // 2
        conv_block = batchnorm()(conv_block)
        conv_block = keras.layers.Dropout(DROP)(conv_block)
        
    else:
        # curr_dim has dropped below dim: emit the output layer with CHANNELS
        # channels and tanh activation (no batchnorm/dropout after it).
        conv_block = keras.layers.Conv2DTranspose(
            CHANNELS,
            kernel_size=k,
            strides=2,
            padding="same",
            activation="tanh",
            kernel_initializer=rnorm,
            use_bias=False)(conv_block)
        
    

img = conv_block

generator = keras.Model(noise, img)
generator.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_7 (InputLayer)         (None, 72)                0         
_________________________________________________________________
dense_3 (Dense)              (None, 4096)              299008    
_________________________________________________________________
batch_normalization_16 (Batc (None, 4096)              16384     
_________________________________________________________________
reshape_3 (Reshape)          (None, 4, 4, 256)         0         
_________________________________________________________________
dropout_16 (Dropout)         (None, 4, 4, 256)         0         
_________________________________________________________________
conv2d_transpose_7 (Conv2DTr (None, 8, 8, 128)         524288    
_________________________________________________________________
batch_normalization_17 (Batc (None, 8, 8, 128)         512       
__________

In [32]:
##################
#
# Build Critic
#
#################
#
# Shared stride-2 convolutional trunk over an img_shape input, followed by two
# 4x4 "valid" convolution heads: a 1-channel validity score and an
# INFO_DIM-channel info output.

# Base channel count, conv kernel size and dropout rate for this cell.
dim = 64
k = 4
DROP = 0.0

img = keras.Input(shape=img_shape)

conv_block = img

# conv_block = keras.layers.GaussianNoise(2/255)(conv_block)

# n_blocks comes from the generator cell, so that cell must run first.
# Channel count doubles with every block: dim, 2*dim, 4*dim, ...
for i in range(n_blocks):
    conv_block = keras.layers.Conv2D(
        dim * (2**i),
        kernel_size=k,
        strides=2,
        padding="same",
        activation=lrelu,
        kernel_initializer=tnorm,
        use_bias=False,
    )(conv_block)
    conv_block = batchnorm()(conv_block)
#     conv_block = keras.layers.BatchNormalization()(conv_block)
    conv_block = keras.layers.Dropout(DROP)(conv_block)

# Validity head: one output channel, no activation (raw score).
val_block = keras.layers.Conv2D(
    1,
    kernel_size=4,
    strides=1,
    padding="valid",
    kernel_initializer=tnorm,
    use_bias=False,
)(conv_block)

val_block = keras.layers.Flatten()(val_block)

# Info head: INFO_DIM output channels, no activation.
info_block = keras.layers.Conv2D(
    INFO_DIM,
    kernel_size=4,
    strides=1,
    padding="valid",
    kernel_initializer=tnorm,
    use_bias=False,
)(conv_block)

info_block = keras.layers.Flatten()(info_block)

# Output order is [info, validity]; callers unpack in that order.
critic = keras.Model(img, [info_block, val_block])

critic.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_10 (InputLayer)           (None, 32, 32, 1)    0                                            
__________________________________________________________________________________________________
conv2d_23 (Conv2D)              (None, 16, 16, 64)   1024        input_10[0][0]                   
__________________________________________________________________________________________________
batch_normalization_25 (BatchNo (None, 16, 16, 64)   256         conv2d_23[0][0]                  
__________________________________________________________________________________________________
dropout_25 (Dropout)            (None, 16, 16, 64)   0           batch_normalization_25[0][0]     
__________________________________________________________________________________________________
conv2d_24 

In [30]:
def mutual_info_loss(self, c, c_given_x):
    """Sum of two batch-averaged, ``c``-weighted negative log terms.

    For each of ``c_given_x`` and ``c`` the function computes
    ``mean(-sum(log(p + 1e-8) * c, axis=1))`` and returns the sum of the
    two results. The small constant keeps ``log`` away from zero.

    NOTE(review): ``self`` is never used although this is a module-level
    function; the signature is left untouched in case something calls it
    with three positional arguments — confirm and consider dropping it.
    """
    tiny = 1e-8

    def _mean_neg_weighted_log(probs):
        # -sum_j c_j * log(probs_j + tiny), averaged over the batch
        return K.mean(- K.sum(K.log(probs + tiny) * c, axis=1))

    return _mean_neg_weighted_log(c_given_x) + _mean_neg_weighted_log(c)

In [33]:
def noop_loss(target, output):
    """Pass-through loss: hand back ``output`` unchanged.

    ``target`` is accepted only to satisfy the two-argument loss signature
    and is ignored.
    """
    return output

def BCE(logits, labels):
    """Sigmoid cross-entropy evaluated on raw (pre-sigmoid) logits.

    Thin wrapper around ``tf.nn.sigmoid_cross_entropy_with_logits`` that
    allows positional ``(logits, labels)`` calls; no reduction is applied
    here.
    """
    return tf.nn.sigmoid_cross_entropy_with_logits(
        labels=labels,
        logits=logits,
    )

### Relativistic Standard GAN

# No sigmoid activation in last layer of generator because BCEWithLogitsLoss() already adds it

# BCE_stable = torch.nn.BCEWithLogitsLoss()

# # Discriminator loss
# errD = BCE_stable(y_pred - y_pred_fake, y)
# errD.backward()

# # Generator loss (You may want to resample again from real and fake data)
# errG = BCE_stable(y_pred_fake - y_pred, y)
# errG.backward()

def rgan_loss(inputs):
    """Relativistic standard GAN discriminator loss.

    ``inputs[0]`` holds the critic scores for real samples and
    ``inputs[1]`` the scores for fake samples. The loss is the sigmoid
    cross-entropy of ``real - fake`` against an all-ones target.
    """
    real_scores, fake_scores = inputs[0], inputs[1]
    all_ones = K.ones_like(real_scores)

    return BCE(real_scores - fake_scores, all_ones)


### Relativistic average Standard GAN

# No sigmoid activation in last layer of generator because BCEWithLogitsLoss() already adds it

# BCE_stable = torch.nn.BCEWithLogitsLoss()

# # Discriminator loss
# errD = (BCE_stable(y_pred - torch.mean(y_pred_fake), y) + BCE_stable(y_pred_fake - torch.mean(y_pred), y2))/2
# errD.backward()

# # Generator loss (You may want to resample again from real and fake data)
# errG = (BCE_stable(y_pred - torch.mean(y_pred_fake), y2) + BCE_stable(y_pred_fake - torch.mean(y_pred), y))/2
# errG.backward()

def ragan_loss(inputs):
    """Relativistic average standard GAN discriminator loss.

    ``inputs[0]`` holds the critic scores for real samples and
    ``inputs[1]`` the scores for fake samples. Each set of scores is
    compared against the mean score of the other set; the two
    cross-entropy terms are averaged.
    """
    real_scores, fake_scores = inputs[0], inputs[1]
    all_ones = K.ones_like(real_scores)
    all_zeros = K.zeros_like(real_scores)

    # Real scores relative to the average fake score, target 1
    real_vs_avg_fake = BCE(real_scores - K.mean(fake_scores), all_ones)
    # Fake scores relative to the average real score, target 0
    fake_vs_avg_real = BCE(fake_scores - K.mean(real_scores), all_zeros)

    return (real_vs_avg_fake + fake_vs_avg_real) / 2

### Relativistic average LSGAN

# No activation in generator

# Discriminator loss
# errD = (torch.mean((y_pred - torch.mean(y_pred_fake) - y) ** 2) + torch.mean((y_pred_fake - torch.mean(y_pred) + y) ** 2))/2
# errD.backward()

# # Generator loss (You may want to resample again from real and fake data)
# errG = (torch.mean((y_pred - torch.mean(y_pred_fake) + y) ** 2) + torch.mean((y_pred_fake - torch.mean(y_pred) - y) ** 2))/2
# errG.backward()

def ralsgan_loss(inputs):
    """Relativistic average least-squares GAN discriminator loss.

    ``inputs[0]`` holds the critic scores for real samples and
    ``inputs[1]`` the scores for fake samples. The result is the average
    of two mean squared terms: real scores relative to the mean fake
    score pushed towards +1, and fake scores relative to the mean real
    score pushed towards -1.
    """
    real_scores, fake_scores = inputs[0], inputs[1]
    all_ones = K.ones_like(real_scores)

    real_term = K.mean((real_scores - K.mean(fake_scores) - all_ones) ** 2)
    fake_term = K.mean((fake_scores - K.mean(real_scores) + all_ones) ** 2)

    return (real_term + fake_term) / 2

### Relativistic average HingeGAN

# No activation in generator

# # Discriminator loss
# errD = (torch.mean(torch.nn.ReLU()(1.0 - (y_pred - torch.mean(y_pred_fake)))) + torch.mean(torch.nn.ReLU()(1.0 + (y_pred_fake - torch.mean(y_pred)))))/2
# errD.backward()
 
# # Generator loss  (You may want to resample again from real and fake data)
# errG = (torch.mean(torch.nn.ReLU()(1.0 + (y_pred - torch.mean(y_pred_fake)))) + torch.mean(torch.nn.ReLU()(1.0 - (y_pred_fake - torch.mean(y_pred)))))/2
# errG.backward()

def rahingegan_loss(inputs):
    """Relativistic average hinge GAN discriminator loss.

    ``inputs[0]`` holds the critic scores for real samples and
    ``inputs[1]`` the scores for fake samples. Each set of scores is
    taken relative to the mean score of the other set and passed through
    a hinge (``relu(1 -/+ ...)``); the two hinge terms are averaged.

    The hinge terms are returned without a ``K.mean`` reduction, unlike
    the PyTorch reference quoted above this function.
    """
    y_pred = inputs[0]
    y_pred_fake = inputs[1]

    # Penalise real scores that are not at least 1 above the average fake score
    first_term = K.relu(1.0 - (y_pred - K.mean(y_pred_fake)))
    # Penalise fake scores that are not at least 1 below the average real score
    second_term = K.relu(1.0 + (y_pred_fake - K.mean(y_pred)))

    return (first_term + second_term)/2

# Loss implementations selectable through the LOSS config constant.
_GAN_LOSSES = {
    "RGAN": rgan_loss,
    "RaGAN": ragan_loss,
    "RaLSGAN": ralsgan_loss,
    "RaHingeGAN": rahingegan_loss,
}

if LOSS not in _GAN_LOSSES:
    # Same exception type as the plain dict lookup, but the message names
    # the supported values instead of only echoing the bad key.
    raise KeyError(
        "Unsupported LOSS %r; choose one of %s" % (LOSS, sorted(_GAN_LOSSES)))

gan_loss = _GAN_LOSSES[LOSS]

In [35]:
#-------------------------------
#       Set up tensors for 
#       computational graph
#-------------------------------

# Image input (real sample)
real_img = keras.Input(shape=img_shape)

# Noise input
z_noise = keras.Input(shape=(Z_SIZE, ))
# Generate image based of noise (fake sample)
fake_img = generator(z_noise)

# Discriminator determines validity of the real and fake images
info_fake, val_fake = critic(fake_img)
info, val = critic(real_img)