In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import cv2
from keras.utils import to_categorical  # Only for categorical one hot encoding
from tensorflow.keras import layers
from sklearn.metrics import accuracy_score

In [None]:
# MNIST digits; swap in the commented line to train on Fashion-MNIST instead.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
# (x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()

# One-hot label matrices (not referenced by the GAN training code in this notebook).
cy_train, cy_test = (np.array(to_categorical(labels)) for labels in (y_train, y_test))

# Add a trailing channel axis and map pixel values [0, 255] -> [-1, 1],
# the same range as the generator's tanh output.
cx_train = np.array((x_train.reshape(-1, 28, 28, 1) - 127.5) / 127.5)
cx_test = np.array((x_test.reshape(-1, 28, 28, 1) - 127.5) / 127.5)
cx_train.shape

In [None]:
# Shared binary cross-entropy on probabilities (the discriminator ends in a
# sigmoid, hence from_logits=False). Per-element losses are summed rather than
# averaged, so loss and gradient magnitudes scale with the batch size.
# 'sum' is the string value of the old tf.keras Reduction.SUM enum; the enum is
# gone in Keras 3, while the string is accepted by both tf.keras 2.x and Keras 3.
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=False, reduction='sum')

def discriminator_loss(real_output, fake_output):
    """Discriminator BCE loss with one-sided label smoothing.

    Real scores are compared against a target of 0.9 (instead of 1.0) and fake
    scores against 0.0, both with the module-level `cross_entropy`; the two
    terms are averaged.

    Args:
        real_output: discriminator probabilities for real images.
        fake_output: discriminator probabilities for generated images.

    Returns:
        Scalar loss: 0.5 * (fake term + real term).
    """
    smoothed_real = 0.9 * tf.ones_like(real_output)
    terms = (
        cross_entropy(smoothed_real, real_output),
        cross_entropy(tf.zeros_like(fake_output), fake_output),
    )
    return 0.5 * (terms[1] + terms[0])
    
def wasserstein_discriminator_loss(real_output, fake_output):
    """WGAN-style critic loss: mean(fake - real), lower when real scores exceed fake ones.

    Not used by the training functions in this notebook (they use
    `discriminator_loss`). NOTE(review): the discriminator built here ends in a
    sigmoid, so pairing it with this loss would not be a standard WGAN critic.
    """
    return tf.reduce_mean(fake_output - real_output)

def generator_loss(fake_output):
    """Non-saturating generator loss: BCE of D's scores on fakes vs. a 0.9 'real' target.

    NOTE(review): 0.9 mirrors the discriminator's smoothed real label. Label
    smoothing is usually one-sided (discriminator only); with a 0.9 target this
    BCE is minimised at D(G(z)) = 0.9 rather than 1.0. Confirm this is intended.
    """
    smoothed_target = 0.9 * tf.ones_like(fake_output)
    return cross_entropy(smoothed_target, fake_output)

def wasserstein_generator_loss(fake_output):
    """WGAN-style generator loss: the negated mean critic score on generated samples."""
    return tf.math.negative(tf.reduce_mean(fake_output))

def generator_enc_loss(real, fake):
    """L1 reconstruction loss between real images and their reconstructions.

    Keras' mean_absolute_error averages over the last axis only, so for image
    batches the result is a per-pixel map rather than a scalar; GradientTape
    then differentiates its sum.
    """
    return tf.keras.losses.mean_absolute_error(y_true=real, y_pred=fake)

In [None]:
def generator():
    """Build the generator: a 128-d latent vector -> a (28, 28, 1) image in [-1, 1].

    Dense(512) -> Dense(7*7*256) -> BatchNorm -> LeakyReLU, reshaped to a
    7x7x256 map, then transposed convolutions upsample 7 -> 7 -> 14 -> 28.
    The last layer has a tanh activation, matching the [-1, 1] data scaling.

    NOTE(review): the two Dense layers have no activation between them, so
    together they compute a single linear map; the 512-unit layer adds
    parameters but no expressive capacity.
    """
    def upsample(filters, kernel, stride):
        # Transposed conv (no bias; BN follows) -> batch norm -> LeakyReLU.
        return [
            layers.Conv2DTranspose(filters, kernel, strides=stride, padding='same', use_bias=False),
            layers.BatchNormalization(momentum=0.8),
            layers.LeakyReLU(),
        ]

    stack = [
        layers.Dense(512, use_bias=False, input_shape=(128,)),
        layers.Dense(7 * 7 * 256, use_bias=False),
        layers.BatchNormalization(momentum=0.8),
        layers.LeakyReLU(),
        layers.Reshape((7, 7, 256)),
        *upsample(128, (5, 5), (1, 1)),
        *upsample(64, (7, 7), (2, 2)),
        layers.Conv2DTranspose(1, (7, 7), strides=(2, 2), padding='same', use_bias=False, activation='tanh'),
    ]

    net = tf.keras.Sequential()
    for layer in stack:
        net.add(layer)
    return net

def discriminator():
    """Build the discriminator and a feature encoder that shares its layers.

    Architecture:
    - GaussianNoise(0.05) on the (28, 28, 1) input.
    - Three stride-2 Conv2D -> LeakyReLU -> BatchNorm stages with 64, 128 and
      128 filters (kernels 7x7, 5x5, 3x3).
    - Global average pooling to a 128-d vector.
    - A sigmoid Dense(1) head.

    Returns:
        (dis, enc): `dis` maps images to a probability of being real. `enc`
        maps the same images to the pooled 128-d features. Both models are
        built from the same layer objects, so updating `dis` also changes `enc`.
    """
    inputLayer = layers.Input((28, 28, 1))
    # Keras applies GaussianNoise only when the model is called with training=True.
    x = tf.keras.layers.GaussianNoise(0.05)(inputLayer)
    for filters, kernel in ((64, (7, 7)), (128, (5, 5)), (128, (3, 3))):
        x = layers.Conv2D(filters, kernel, strides=(2, 2), padding='same')(x)
        x = layers.LeakyReLU()(x)
        x = layers.BatchNormalization(momentum=0.8)(x)
    # With 'same' padding the spatial size goes 28 -> 14 -> 7 -> 4; pooling
    # yields one 128-d feature vector per image.
    encOut = layers.GlobalAveragePooling2D()(x)
    x = layers.Dense(1, activation='sigmoid')(encOut)

    dis = tf.keras.models.Model(inputs=inputLayer, outputs=x)
    enc = tf.keras.models.Model(inputs=inputLayer, outputs=encOut)
    return dis, enc

In [None]:
# Instantiate the networks; `des` and `enc` come from one call and share layers.
gen = generator()
des, enc = discriminator()

# One Adam optimizer per player, both with learning rate 1e-4.
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)

In [None]:
def trainDesGenEnc(gen, des, enc, real, batch_size):
  """Joint D/G step where the generator is fed encoder features of `real` instead of noise.

  The fakes come first in the concatenated batch, so the first `batch_size`
  discriminator scores belong to fakes and the rest to reals. The discriminator
  is stepped first, then the generator, using the module-level optimizers.

  NOTE(review): `enc` shares layers with `des` and runs under `disc_tape`, so
  the discriminator gradient also flows through the enc -> gen path.
  """
  with tf.GradientTape() as disc_tape, tf.GradientTape() as gen_tape:
    latent = enc(real, training=False)
    generated = gen(latent, training=True)
    scores = des(tf.concat((generated, real), axis=0), training=True)
    fake_scores, real_scores = scores[:batch_size], scores[batch_size:]
    des_loss = discriminator_loss(real_scores, fake_scores)
    gen_loss = generator_loss(fake_scores)

  d_vars = des.trainable_variables
  discriminator_optimizer.apply_gradients(zip(disc_tape.gradient(des_loss, d_vars), d_vars))
  g_vars = gen.trainable_variables
  generator_optimizer.apply_gradients(zip(gen_tape.gradient(gen_loss, g_vars), g_vars))

# @tf.function
def trainGenEnc(gen, enc, real, batch_size):
  """One reconstruction step: train the generator so gen(enc(real)) matches `real`.

  The encoder runs with training=False and only the generator's variables are
  updated (L1 reconstruction loss scaled by 0.5, via the module-level
  `generator_optimizer`).

  Args:
    gen: generator model.
    enc: encoder whose output is used as the generator's latent input.
    real: batch of real images.
    batch_size: unused; kept so all train* helpers share a call shape.
  """
  # Only one tape is needed: the original opened an unused discriminator tape.
  with tf.GradientTape() as gen_tape:
    real_enc = enc(real, training=False)
    enc_fake = gen(real_enc, training=True)
    gen_loss = generator_enc_loss(real, enc_fake) * 0.5
  # Gradient taken after leaving the tape context so the tape does not record it.
  gradients_of_generator = gen_tape.gradient(gen_loss, gen.trainable_variables)
  generator_optimizer.apply_gradients(zip(gradients_of_generator, gen.trainable_variables))

# @tf.function
def trainDes(gen, des, real, batch_size):
  """Update only the discriminator on `real` plus `batch_size` freshly generated fakes.

  The generator runs with training=False and is not updated. Fakes precede
  reals in the scored batch.
  """
  with tf.GradientTape() as disc_tape:
    generated = gen(tf.random.normal([batch_size, 128]), training=False)
    scores = des(tf.concat((generated, real), axis=0), training=True)
    des_loss = discriminator_loss(real_output=scores[batch_size:], fake_output=scores[:batch_size])

  d_vars = des.trainable_variables
  grads = disc_tape.gradient(des_loss, d_vars)
  discriminator_optimizer.apply_gradients(zip(grads, d_vars))

# @tf.function
def trainDesGen(gen, des, real, batch_size):
  """Simultaneous D and G update from a single forward pass on noise-driven fakes.

  Both losses come from the same discriminator scores (fakes first, then
  reals). The discriminator is stepped first, then the generator.
  """
  with tf.GradientTape() as disc_tape, tf.GradientTape() as gen_tape:
    generated = gen(tf.random.normal([batch_size, 128]), training=True)
    scores = des(tf.concat((generated, real), axis=0), training=True)
    fake_scores, real_scores = scores[:batch_size], scores[batch_size:]
    des_loss = discriminator_loss(real_scores, fake_scores)
    gen_loss = generator_loss(fake_scores)

  d_vars = des.trainable_variables
  discriminator_optimizer.apply_gradients(zip(disc_tape.gradient(des_loss, d_vars), d_vars))
  g_vars = gen.trainable_variables
  generator_optimizer.apply_gradients(zip(gen_tape.gradient(gen_loss, g_vars), g_vars))

# @tf.function
def trainGen(gen, des, batch_size):
  """Update only the generator so the discriminator scores fresh fakes as real.

  The discriminator runs with training=False (its BatchNorm uses moving
  statistics) and its weights are not updated here.
  """
  with tf.GradientTape() as tape:
    latent = tf.random.normal([batch_size, 128])
    scores = des(gen(latent, training=True), training=False)
    loss_value = generator_loss(scores)

  g_vars = gen.trainable_variables
  generator_optimizer.apply_gradients(zip(tape.gradient(loss_value, g_vars), g_vars))

def evalGan(gen, des, data, batches, batch_size):
  """Estimate discriminator accuracy and generator loss on a few real batches.

  For each of the first `batches` entries of `data`, the function:
  - draws `batch_size` latent vectors and generates fakes;
  - scores the concatenated [fakes; reals] batch with the discriminator;
  - measures accuracy against labels 0 (fake) / 1 (real).

  Args:
    gen: generator model taking (N, 128) noise.
    des: discriminator model returning one probability per image.
    data: indexable collection of real-image batches; data[i] is one batch.
    batches: number of batches to evaluate; clamped to len(data).
    batch_size: number of fake images generated per batch.

  Returns:
    (mean discriminator accuracy, mean per-batch generator loss), or
    (0.0, 0.0) when there is nothing to evaluate.
  """
  # Guard against IndexError (batches > len(data)) and division by zero.
  batches = min(batches, len(data))
  if batches <= 0:
    return 0.0, 0.0

  desAcc = 0.0
  genLoss = 0.0
  for i in range(batches):
    # Cast so the concat never mixes float64 real data with float32 fakes.
    real = tf.cast(data[i], tf.float32)
    # Direct inference calls instead of .predict(): same outputs, without a
    # progress bar per call and the per-call overhead of predict in a loop.
    fake = gen(tf.random.normal([batch_size, 128]), training=False)
    output = des(tf.concat((fake, real), axis=0), training=False)

    fake_output = output[:batch_size]
    real_output = output[batch_size:]

    # Fakes are labelled 0 and reals 1, in the order they were concatenated.
    labels = tf.concat((tf.zeros_like(fake_output), tf.ones_like(real_output)), axis=0)
    acc = tf.keras.metrics.binary_accuracy(tf.reshape(labels, [-1]), tf.reshape(output, [-1]), threshold=0.5)
    desAcc += float(acc)
    # generator_loss already reduces to a scalar (sum over the batch).
    genLoss += float(generator_loss(fake_output))
  return desAcc / batches, genLoss / batches

def _plot_samples(images, grid=4):
  """Show up to grid*grid single-channel images in [-1, 1] as a grayscale panel."""
  fig = plt.figure(figsize=(grid, grid))
  for i in range(min(images.shape[0], grid * grid)):
    plt.subplot(grid, grid, i + 1)
    # Undo the (x - 127.5) / 127.5 scaling applied to the training data.
    plt.imshow(images[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
    plt.axis('off')
  plt.show()
  plt.close(fig)


def trainGan(realData, epochs=10, batch_size=5, loss='mse'):
  """Train the global `gen`/`des` pair with alternating D and G steps.

  Each epoch does the following:
  - reshuffles individual samples and splits them into full mini-batches;
  - runs one `trainDes` and one `trainGen` step per batch;
  - evaluates on up to 10 batches;
  - plots 16 samples from a fixed noise vector.

  Args:
    realData: array of real images, shape (N, H, W, C), scaled to [-1, 1].
    epochs: number of passes over the data.
    batch_size: images per mini-batch; trailing samples that do not fill a
      whole batch are dropped.
    loss: currently unused; kept for backward compatibility.

  Returns:
    List with one {'desAcc': ..., 'genLoss': ...} dict per epoch.

  Raises:
    ValueError: if `realData` has fewer than `batch_size` samples.
  """
  realData = np.asarray(realData, dtype=np.float32)
  num_batches = realData.shape[0] // batch_size
  if num_batches == 0:
    raise ValueError(f"need at least batch_size={batch_size} samples, got {realData.shape[0]}")
  # Drop the remainder so every batch is full (a plain reshape would raise).
  realData = realData[:num_batches * batch_size]
  batch_shape = (num_batches, batch_size) + realData.shape[1:]
  print(batch_shape)

  noise = tf.random.normal([16, 128])  # fixed, so sample grids are comparable across epochs
  results = []
  for epoch in range(epochs):
    # Permute individual samples (not whole batches) so mini-batch composition
    # changes from epoch to epoch.
    batches = realData[np.random.permutation(len(realData))].reshape(batch_shape)
    for real in batches:
      real = tf.convert_to_tensor(real)
      trainDes(gen, des, real, batch_size)
      trainGen(gen, des, batch_size)
      # Alternative schedules:
      # trainDesGen(gen, des, real, batch_size)
      # trainGenEnc(gen, enc, real, batch_size)
      # trainDesGenEnc(gen, des, enc, real, batch_size)

    fake = gen(noise, training=False).numpy()
    print("Evaluating:")
    desAcc, genLoss = evalGan(gen, des, batches, 10, batch_size)
    results.append({'desAcc': desAcc, 'genLoss': genLoss})
    print("Epoch ", epoch, desAcc, genLoss)

    _plot_samples(fake)
  return results

In [None]:
EPOCHS = 100
BATCH_SIZE = 60
# Keep the returned per-epoch metrics for later inspection instead of
# discarding them (and avoid echoing the whole list as cell output).
history = trainGan(cx_train, epochs=EPOCHS, batch_size=BATCH_SIZE)