In [None]:
import os
import tensorflow as tf
import numpy as np
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import Dense, Activation, Input
from tensorflow.keras.layers import Conv2D, Flatten, Dropout
from tensorflow.keras.layers import Lambda, Subtract, Add
from tensorflow.keras.layers import Reshape, Conv2DTranspose
from tensorflow.keras.optimizers import RMSprop, Adam, SGD
from tensorflow.keras.datasets import mnist
from tensorflow.keras import backend as K
from tensorflow.keras.models import load_model

import matplotlib.pyplot as plt
from PIL import Image
import datetime
!/opt/bin/nvidia-smi

In [None]:
# Mount the user's Google Drive at /content/drive/ (Colab-only helper).
from google.colab import drive
drive.mount('/content/drive/')

In [None]:
# Switch the working directory to the project folder on the mounted Drive;
# the relative paths used below (e.g. './models/') are resolved against it.
%cd "/content/drive/My Drive/cgen/"

In [None]:
# --- Experiment configuration ------------------------------------------------
original_class = 0          # MNIST digit the counterfactuals start from
target_class = 8            # MNIST digit the classifier is trained to detect
alpha = 0.1                 # loss mixing weight: (1 - alpha) * d_g + alpha * d_c
cgen_learning_rate = 0.005  # learning rate for the cGen optimisers
beta_1 = 0.5                # Adam beta_1 for the counterfactual-layer optimiser
# Smoothed classifier targets (label smoothing instead of hard 0/1 labels).
classifier_def = {'true_label': 0.9,
                  'false_label': 0.1}

# --- Output directories ------------------------------------------------------
models_dir = './models/'
models_best_dir = './models/best/'
imgs_save_dir = models_dir + 'results/'
# exist_ok=True makes the cell safe to re-run and avoids the check-then-create
# race of `if not os.path.exists(...): os.makedirs(...)`.
for _output_dir in (models_dir, models_best_dir, imgs_save_dir):
    os.makedirs(_output_dir, exist_ok=True)

In [None]:
# Load the MNIST train/test splits (images plus digit labels) through Keras.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

In [None]:
# Build the per-class subsets of MNIST (original class and target class).
index_original_class = np.where(y_train == original_class)
index_target_class = np.where(y_train == target_class)
index_test_original_class = np.where(y_test == original_class)
index_test_target_class = np.where(y_test == target_class)

x_train_original_class = x_train[index_original_class[0]]
x_train_target_class = x_train[index_target_class[0]]
x_test_original_class = x_test[index_test_original_class[0]]
x_test_target_class = x_test[index_test_target_class[0]]

# Visual sanity check: one sample of each class, plotted before the channel
# axis is added so imshow receives a 2-D array.
plt.imshow(x_train_original_class[13], cmap='gray')
plt.show()
print(x_test_target_class[10].shape)
plt.imshow(x_test_target_class[10], cmap='gray')
plt.show()

# Images are assumed square: the size of axis 1 is used for both spatial axes.
image_size = x_train_original_class.shape[1]


def _to_model_input(images):
    """Reshape images to (N, image_size, image_size, 1) float32 and divide by 255."""
    return np.reshape(images, [-1, image_size, image_size, 1]).astype('float32') / 255


x_train_original_class = _to_model_input(x_train_original_class)
x_train_target_class = _to_model_input(x_train_target_class)
x_test_original_class = _to_model_input(x_test_original_class)
x_test_target_class = _to_model_input(x_test_target_class)
x_train_classifier = _to_model_input(x_train)
x_test_classifier = _to_model_input(x_test)
print(x_train.shape)
print(x_train_original_class.shape)
print(x_train_target_class.shape)

# Binary classifier target over the full training set:
# 1 where the digit is target_class, 0 everywhere else.
y_train_classifier = np.zeros(y_train.shape)
y_index_target_class = index_target_class  # same indices; no need to recompute
y_train_classifier[y_index_target_class[0]] = 1

In [None]:
# create a sampling layer
from tensorflow.keras.layers import Layer
class reparameterize(Layer):
    """Sampling layer implementing the reparameterization trick.

    Takes the pair (mean, logvar) and returns mean + eps * exp(logvar / 2),
    where eps is standard-normal noise with one row per batch element and one
    column per latent dimension.
    """
    def call(self, inputs):
        z_mean, z_logvar = inputs
        mean_shape = tf.shape(z_mean)
        n_rows, n_cols = mean_shape[0], mean_shape[1]
        noise = tf.keras.backend.random_normal(shape=(n_rows, n_cols))
        sigma = tf.exp(z_logvar * .5)  # log-variance -> standard deviation
        return noise * sigma + z_mean

In [None]:
# Generator, part 1: the convolutional encoder of the VAE.
input_shape = (image_size, image_size, 1)
batch_size = 128
kernel_size = 3
latent_dim = 10

inputs = Input(shape=input_shape, name='encoder_input')
x = inputs
# Two stride-2 'same' convolutions with 32 and then 64 filters.
for n_filters in (32, 64):
    x = Conv2D(filters=n_filters,
               kernel_size=kernel_size,
               strides=2,
               activation='relu',
               padding='same')(x)
x = Flatten()(x)
# Two parallel heads give the latent mean and log-variance; z is sampled from them.
mean = Dense(latent_dim, name="z_mean")(x)
logvar = Dense(latent_dim, name="z_log_var")(x)
z = reparameterize()([mean, logvar])

# The encoder returns the sampled code together with its distribution parameters.
encoder = Model(inputs=inputs, outputs=[z, mean, logvar], name='encoder')

In [None]:
# Generator, part 2: the decoder maps a latent vector back to an image.
latent_inputs = Input(shape=(latent_dim,), name='decoder_input')
x = Dense(units=7 * 7 * 32, activation=tf.nn.relu)(latent_inputs)
x = Reshape(target_shape=(7, 7, 32))(x)
# Two stride-2 transposed convolutions (64 then 32 filters) ...
for n_filters in (64, 32):
    x = Conv2DTranspose(filters=n_filters,
                        kernel_size=3,
                        strides=2,
                        padding='same',
                        activation='relu')(x)
# ... followed by a single-channel projection squashed by a sigmoid.
x = Conv2DTranspose(filters=1, kernel_size=3, strides=1, padding='same')(x)
outputs = Activation('sigmoid', name='decoder_output')(x)

decoder = Model(inputs=latent_inputs, outputs=outputs, name='decoder')

In [None]:
class VAE(Model):
    """VAE wrapper with an extra trainable linear map in latent space.

    Args:
        encoder: model returning ``[z, mean, logvar]`` for a batch of images.
        latent_dim: width of the latent code; also the width of ``cf_layer``.
        decoder: model mapping a latent code to an image.

    Attributes:
        cf_layer: Dense layer of ``latent_dim`` units whose kernel starts as
            the identity matrix. ``cf()`` applies it to ``z`` before decoding;
            ``call()`` and the training losses do not use it.
    """

    def __init__(self, encoder, latent_dim, decoder, **kwargs):
        super(VAE, self).__init__(**kwargs)
        self.encoder = encoder
        # Kept on the instance so sample() does not depend on a notebook global.
        self.latent_dim = latent_dim
        self.cf_layer = Dense(units=latent_dim, kernel_initializer=tf.constant_initializer(np.eye(latent_dim)))
        self.decoder = decoder

    def _compute_losses(self, images, training):
        """Return the dict of VAE losses (total, reconstruction, KL) for a batch."""
        z, mean, logvar = self.encoder(images, training=training)
        reconstruction = self.decoder(z, training=training)
        reconstruction_loss = tf.reduce_mean(
            tf.keras.losses.binary_crossentropy(images, reconstruction)
        )
        reconstruction_loss *= 28 * 28
        # NOTE(review): the KL term is averaged over latent dimensions as well
        # as over the batch, while the reconstruction term is scaled by 28*28;
        # confirm this relative weighting is intended.
        kl_loss = 1 + logvar - tf.square(mean) - tf.exp(logvar)
        kl_loss = tf.reduce_mean(kl_loss)
        kl_loss *= -0.5
        total_loss = reconstruction_loss + kl_loss
        return {
            "loss": total_loss,
            "reconstruction_loss": reconstruction_loss,
            "kl_loss": kl_loss,
        }

    def train_step(self, data):
        """One optimisation step; ``data`` may be ``x`` or a tuple whose first item is ``x``."""
        if isinstance(data, tuple):
            data = data[0]
        with tf.GradientTape() as tape:
            losses = self._compute_losses(data, training=True)
        grads = tape.gradient(losses["loss"], self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        return losses

    def test_step(self, data):
        """Evaluate the same losses as train_step so validation reports them."""
        if isinstance(data, tuple):
            data = data[0]
        return self._compute_losses(data, training=False)

    def call(self, inputs):
        """Encode ``inputs``, sample ``z`` and decode it (plain reconstruction)."""
        z, mean, logvar = self.encoder(inputs)
        return self.decoder(z)

    def cf(self, inputs):
        """Encode ``inputs``, pass ``z`` through ``cf_layer`` and decode the result."""
        z, mean, logvar = self.encoder(inputs)
        return self.decoder(self.cf_layer(z))

    @tf.function
    def sample(self, eps=None):
        """Decode latent codes; draws 100 standard-normal codes when ``eps`` is None."""
        if eps is None:
            eps = tf.random.normal(shape=(100, self.latent_dim))
        # The decoder built in this notebook already ends in a sigmoid, so the
        # output is returned as-is instead of being squashed a second time.
        return self.decoder(eps)

vae = VAE(encoder, latent_dim, decoder)

In [None]:
# Train the VAE on the original-class and target-class digits together.
# (To train on the original class only, pass x_train_original_class /
# x_test_original_class in place of the concatenated arrays.)
x_vae_train = np.concatenate([x_train_original_class, x_train_target_class])
x_vae_test = np.concatenate([x_test_original_class, x_test_target_class])

vae.compile(optimizer=Adam())

# The inputs are also passed as targets; VAE.train_step only reads the first
# element of the (x, y) tuple.
vae.fit(
    x=x_vae_train,
    y=x_vae_train,
    batch_size=batch_size,
    epochs=30,
    validation_data=(x_vae_test, x_vae_test),
)

In [None]:
# Restore the VAE weights from disk when a checkpoint is available; otherwise
# persist the weights that were just trained, so the notebook also runs on a
# fresh environment where no checkpoint has been written yet.
vae_weights_dir = models_dir + "vae/"
vae_weights_path = vae_weights_dir + 'vae_trained_on_' + str(original_class) + str(target_class) + '.tf'

# A TensorFlow-format checkpoint saved under this prefix has an '.index' file.
if os.path.exists(vae_weights_path + '.index'):
    # NOTE: this replaces whatever weights the preceding training cell produced.
    vae.load_weights(vae_weights_path)
else:
    os.makedirs(vae_weights_dir, exist_ok=True)
    vae.save_weights(vae_weights_path)

In [None]:
def generate_and_save_images(model, test_sample, save_path='vae_MNIST_improved.pdf'):
    """Plot the model's output for ``test_sample`` on a grid and save the figure.

    Args:
        model: callable Keras model returning a batch of images with a trailing
            channel axis (channel 0 is plotted).
        test_sample: batch of input images passed to the model.
        save_path: where the figure is written; defaults to the file name the
            notebook has always used.
    """
    # Use model(...) rather than model.call(...): Keras expects layers and
    # models to be invoked through __call__.
    result = model(test_sample)
    n_images = result.shape[0]
    # Keep the 4x4 layout for up to 16 images and grow the grid beyond that
    # instead of failing on subplot index 17.
    grid = max(4, int(np.ceil(np.sqrt(n_images))))
    fig = plt.figure(figsize=(4, 4))
    for i in range(n_images):
        ax = fig.add_subplot(grid, grid, i + 1)
        ax.imshow(result[i, :, :, 0], cmap='gray')
        ax.axis('off')
    fig.savefig(save_path)
    plt.show()

x_vae_train = np.concatenate([x_train_original_class, x_train_target_class])
x_vae_test = np.concatenate([x_test_original_class, x_test_target_class])

# Fixed preview set: the first 16 original-class test images, taken from the
# first batch of the test dataset.
test_dataset = tf.data.Dataset.from_tensor_slices(x_test_original_class).batch(batch_size)

test_img = []
for first_batch in test_dataset.take(1):
    test_img = first_batch[:16]

generate_and_save_images(vae, test_img)

In [None]:
# Classifier: binary CNN that scores an image as "target class" (1) or not (0).
# NOTE: this cell deliberately does not assign `latent_dim`; the classifier has
# no latent space, and overwriting that global would desynchronise it from the
# encoder/decoder built above.
input_shape = (image_size, image_size, 1)
batch_size = 128
kernel_size = 3
layer_filters = [32, 64]

inputs = Input(shape=input_shape, name='classifier_input')
x = inputs
for filters in layer_filters:
    x = Conv2D(filters=filters,
               kernel_size=kernel_size,
               strides=2,
               activation='relu',
               padding='same')(x)
# Dense head: 128 hidden units, then a single sigmoid output.
x = Flatten()(x)
x = Dense(128, activation='relu')(x)
x = Dense(1, name='classifier_output', activation='sigmoid')(x)

# Instantiate the classifier model
classifier = Model(inputs, x, name='classifier')
classifier.summary()

In [None]:
# Pre-train the classifier on all of MNIST with the binary target-vs-rest labels.
classifier.compile(optimizer='adam', loss='binary_crossentropy')
classifier.fit(
    x=x_train_classifier,
    y=y_train_classifier,
    batch_size=128,
    epochs=6,
)

In [None]:
# Persist the pre-trained classifier; it can be restored later with
# load_model() on the same path.
classifier_path = '{}classifier_trained_on_{}.tf'.format(models_dir, target_class)
classifier.save(classifier_path)

In [None]:
# `tqdm.tqdm_notebook` is deprecated; `tqdm.notebook.tqdm` is its replacement.
from tqdm.notebook import tqdm

# Labels: 0 for original-class images, 1 for target-class images.
input_label = np.zeros(shape=(len(x_train_original_class),))
target_label = np.ones(shape=(len(x_train_target_class),))
train_data = np.concatenate([x_train_original_class, x_train_target_class])
train_label = np.concatenate([input_label, target_label])

# Counterfactual-generator training set: original-class images only.
train_dataset_cgen = (tf.data.Dataset.from_tensor_slices((x_train_original_class, input_label))
                      .shuffle(len(x_train_original_class)).batch(batch_size))

# Classifier training set: original-class and target-class images.
train_dataset = (tf.data.Dataset.from_tensor_slices((train_data, train_label))
                 .shuffle(len(train_data)).batch(batch_size))

def step(vae, classifier, optimizer, optimizer2, batch_size, alpha,
         true_label=0.9, false_label=0.1):
    """Run one training cycle: update the counterfactual layer, then the classifier.

    Phase 1 iterates the module-level ``train_dataset_cgen`` and updates only
    ``vae.cf_layer`` on ``(1 - alpha) * d_g + alpha * d_c``, where ``d_g`` is
    the binary cross-entropy between an input image and its counterfactual and
    ``d_c`` the cross-entropy between the classifier's output for the
    counterfactual and ``true_label``.

    Phase 2 iterates the module-level ``train_dataset`` and updates the
    classifier towards ``false_label`` for label-0 images and ``true_label``
    for label-1 images.

    Args:
        vae: model exposing ``cf(images)`` and ``cf_layer``.
        classifier: Keras model producing one probability per image.
        optimizer: optimiser for the counterfactual layer.
        optimizer2: optimiser for the classifier.
        batch_size: unused; kept so existing calls keep working.
        alpha: weight of the classifier term in the phase-1 loss.
        true_label: smoothed positive target (default 0.9).
        false_label: smoothed negative target (default 0.1).

    Returns:
        ``[mean total loss, mean d_g, mean d_c, mean classifier loss]``; the
        first three are averaged over phase-1 batches and the last over
        phase-2 batches.
    """
    bce = tf.keras.losses.binary_crossentropy
    vae.cf_layer.trainable = True

    total = tf.zeros(())
    vae_loss = tf.zeros(())
    clf_loss = tf.zeros(())
    disc = tf.zeros(())
    cgen_batches = 0
    disc_batches = 0

    # --- Phase 1: train the counterfactual layer --------------------------
    for images, labels in tqdm(train_dataset_cgen):
        index_ori = np.where(labels.numpy() == 0)[0]
        if len(index_ori) == 0:
            # Nothing to transform in this batch; a mean over zero images
            # would be NaN.
            continue
        train_img = tf.gather(images, index_ori)
        target_tensor = tf.ones((len(index_ori), 1), dtype=tf.float32) * true_label

        # Only the forward pass is recorded; gradients are taken after the
        # tape context closes so the gradient ops are not recorded themselves.
        with tf.GradientTape() as dec_tape:
            counterfactual = vae.cf(train_img)
            d_g = tf.reduce_mean(bce(train_img, counterfactual))
            probs = classifier(counterfactual)
            d_c = tf.reduce_mean(bce(target_tensor, probs))
            total_loss = (1. - alpha) * d_g + alpha * d_c

        cf_weights = vae.cf_layer.trainable_weights
        dec_gradients = dec_tape.gradient(total_loss, cf_weights)
        optimizer.apply_gradients(zip(dec_gradients, cf_weights))

        total += total_loss
        vae_loss += d_g
        clf_loss += d_c
        cgen_batches += 1

    # --- Phase 2: train the classifier on real images ---------------------
    for images, labels in tqdm(train_dataset):
        labels = labels.numpy()
        targets = np.ones([len(labels), 1])
        targets[labels == 0] = false_label
        targets[labels == 1] = true_label
        targets = tf.convert_to_tensor(targets, dtype=tf.float32)

        with tf.GradientTape() as disc_tape:
            probs = classifier(images)
            disc_loss = tf.reduce_mean(bce(targets, probs))

        disc_gradients = disc_tape.gradient(disc_loss, classifier.trainable_weights)
        optimizer2.apply_gradients(zip(disc_gradients, classifier.trainable_weights))

        disc += disc_loss
        disc_batches += 1

    # Each mean uses the batch count of its own loop (the two datasets differ
    # in size); max(..., 1) avoids dividing by zero on an empty dataset.
    cgen_den = max(cgen_batches, 1)
    disc_den = max(disc_batches, 1)
    return [total / cgen_den, vae_loss / cgen_den, clf_loss / cgen_den, disc / disc_den]

In [None]:
latent_dim = 13
vae = VAE(encoder, latent_dim, decoder)
vae.load_weights(models_dir + 'vae/' + 'vae_08_trained_on_' + str(original_class) + str(target_class) + '.tf')
classifier = load_model(models_dir + 'classifier_trained_on_' + str(target_class) + '.tf')
optimizer=Adam(lr=0.005, beta_1=beta_1)
optimizer2=Adam(lr=0.005)
batch_size = 128
total_loss = []
vae_loss = []
clf_loss = []
dis_loss = []

train_cycles = 100
for train_cycle in range(train_cycles):
    print("Cgen train cycle:", train_cycle, train_cycles)
    # cGen training
    cgen_loss = step(vae, classifier, optimizer, optimizer2, batch_size, 0.1)
    print('cgen_loss: {a}, vae_loss: {b}, clf_loss: {c}, disc_loss: {d}'.format(a=cgen_loss[0], b=cgen_loss[1], c=cgen_loss[2], d=cgen_loss[-1]))
 
    generate_and_save_cf(vae, test_img)
    a = K.eval(cgen_loss[0])
    b = K.eval(cgen_loss[1])
    c = K.eval(cgen_loss[2])
    d = K.eval(cgen_loss[3])
    total_loss.append(a)
    vae_loss.append(b)
    clf_loss.append(c)
    dis_loss.append(d)

In [None]:
#  vae.load_weights('cgen_08_0.9_10/'+'epoch_{}'.format(100))

def generate_and_save_cf(model, test_sample):
    result = model.cf(test_sample)
    fig = plt.figure(figsize=(4, 4))
    reconstruction_loss = tf.reduce_mean(tf.keras.losses.binary_crossentropy(result, test_sample))
    print(reconstruction_loss)
    for i in range(result.shape[0]):
        plt.subplot(4, 4, i + 1)
        plt.imshow(result[i, :, :, 0], cmap='gray')
        plt.axis('off')
    # plt.savefig("cgen_13_0.1_13.pdf")
    plt.show()
    print(classifier(result))

generate_and_save_cf(vae, test_img)