<a href="https://colab.research.google.com/github/Spinkk/Implementing-ANNs-with-Tensorflow/blob/main/w08/ANN_w08.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import time
import datetime

import numpy as np
import matplotlib
import matplotlib.pyplot as plt
 
import tensorflow as tf
import tensorflow_datasets as tfds 

# 1. Data set

In [None]:
def preprocess_tfds(dataset, batch_size=32, buffer_size=1024, prefetch_factor=tf.data.experimental.AUTOTUNE, shuffle=True):
    '''
    Build the input pipeline for a tf.data dataset of (image, label) pairs.

    The elements are optionally shuffled, grouped into fixed-size batches
    (an incomplete final batch is dropped) and then mapped to float32 images
    divided by 255. The label is discarded, so the resulting dataset yields
    image batches only. Prefetching is appended unless prefetch_factor is None.

    :param dataset: tf.data dataset whose elements are (img, label) pairs
    :param batch_size: int, number of elements per batch (default 32)
    :param buffer_size: int, shuffle buffer size (default 1024); only used when shuffle is True
    :param prefetch_factor: number of batches to prefetch, default is TF autotune; None disables prefetching
    :param shuffle: bool, shuffle the elements before batching (default True)
    :returns: preprocessed tf.data dataset
    '''

    def _scale_and_drop_label(img, label):
        # cast to float32 and divide by 255; the label is intentionally not returned
        return tf.cast(img, tf.float32)/255

    # shuffling (if requested) happens on single elements, i.e. before batching
    if shuffle:
        dataset = dataset.shuffle(buffer_size)
    batched = dataset.batch(batch_size, drop_remainder=True)

    # the map runs on whole batches because it is applied after .batch()
    scaled = batched.map(_scale_and_drop_label,
                         num_parallel_calls=tf.data.experimental.AUTOTUNE)

    # prefetching is optional: None switches it off
    if prefetch_factor is None:
        return scaled
    return scaled.prefetch(prefetch_factor)


# Human-readable class names for later analysis/visualization of the encoded
# dataset (presumably list position == integer label; verify against the tfds docs)
label_code = [
    'T-Shirt', 'Trouser', 'Pullover', 'Dress', 'Coat',
    'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot',
]

# Load the train and test splits from tfds as (image, label) pairs
# (fashion_mnist is also available from keras) and run each split
# through the same preprocessing pipeline with its default settings.
train_ds, test_ds = (
    preprocess_tfds(split_ds)
    for split_ds in tfds.load('fashion_mnist',
                              split=['train', 'test'],
                              as_supervised=True,
                              shuffle_files=False)
)

# 2. Model

In [None]:
class Discriminator(tf.keras.Model):
    """
    Convolutional binary classifier that scores images as real or fake.

    Two stride-2 Conv2D -> BatchNormalization -> ReLU stages are followed by
    Flatten and a single sigmoid unit, so the output is one value in (0, 1)
    per input image.
    """
    def __init__(self):
        super(Discriminator, self).__init__()
        self.slayers = [
            # input_shape describes a single sample, i.e. without the batch dimension
            tf.keras.layers.Conv2D(filters=32,
                                   kernel_size=3,
                                   strides=(2,2),
                                   input_shape=(28,28,1)),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Activation('relu'),
            tf.keras.layers.Conv2D(filters=64,
                                   kernel_size=3,
                                   strides=(2,2)),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Activation('relu'),
            tf.keras.layers.Flatten(),
            # binary decision of fake/real data
            tf.keras.layers.Dense(1, activation='sigmoid')]

    def call(self, x, training=False):
        """
        Run the layers in order.

        :param x: batch of images
        :param training: bool, forwarded to every layer; it only matters for
                         layers that behave differently while training (here BatchNormalization)
        :returns: tensor with one sigmoid score per input image
        """
        for layer in self.slayers:
            # Pass `training` by keyword: Keras drops it for layers whose call()
            # does not take it, so no try/except is needed and genuine errors
            # inside a layer are no longer swallowed and retried.
            x = layer(x, training=training)
        return x


class Generator(tf.keras.Model):
    """
    Maps latent noise vectors to single-channel images with values in (0, 1).

    A Dense layer expands the latent vector to prod(restore_shape) units, which
    are reshaped to `restore_shape` and upsampled by two stride-2 transposed
    convolutions with 'same' padding; a final sigmoid bounds the output.

    :param latent_dim: int, length of the latent input vector (default 100)
    :param restore_shape: tuple (height, width, channels) the dense output is reshaped to
    """
    def __init__(self, latent_dim=100, restore_shape=(7,7,64)):
        super(Generator, self).__init__()
        # kept as an attribute so callers can sample noise of the right size
        self.latent_dim = latent_dim
        self.slayers = [tf.keras.layers.Dense(units=int(tf.math.reduce_prod(restore_shape)),
                                             input_shape=(latent_dim,)),
                       tf.keras.layers.BatchNormalization(),
                       tf.keras.layers.Activation('relu'),
                       # reshape to 3 dim with depth dim again
                       tf.keras.layers.Reshape(target_shape=restore_shape),
                       # (2,2) strided transposed conv to upsample
                       tf.keras.layers.Conv2DTranspose(filters=32,
                                                       kernel_size=(3,3),
                                                       strides=(2,2),
                                                       padding='same'),
                       tf.keras.layers.BatchNormalization(),
                       tf.keras.layers.Activation('relu'),
                       # second (2,2) strided transposed conv; a single filter restores the one-channel image
                       tf.keras.layers.Conv2DTranspose(filters=1,
                                                       kernel_size=(3,3),
                                                       strides=(2,2),
                                                       padding='same'),
                       tf.keras.layers.BatchNormalization(),
                       # use sigmoid to get values between 0 and 1
                       tf.keras.layers.Activation('sigmoid')]

    def call(self, x, training=False):
        """
        Run the layers in order.

        :param x: batch of latent vectors, shape (batch, latent_dim)
        :param training: bool, forwarded to every layer; it only matters for
                         layers that behave differently while training (here BatchNormalization)
        :returns: batch of generated images
        """
        for layer in self.slayers:
            # Pass `training` by keyword: Keras drops it for layers whose call()
            # does not take it, so no try/except is needed and genuine errors
            # inside a layer are no longer swallowed and retried.
            x = layer(x, training=training)
        return x


In [None]:
# Smoke test of both untrained models: one real batch through the
# discriminator, one batch of latent noise through the generator.
exdis = Discriminator()
for img in train_ds.take(1):
    res = exdis(img)

exgen = Generator()
# The generator consumes latent noise of size latent_dim, not the
# discriminator's (batch, 1) scores, so sample noise with a matching batch size.
exnoise = tf.random.normal([tf.shape(img)[0], exgen.latent_dim])
tf.shape(exgen(exnoise))

<tf.Tensor: shape=(4,), dtype=int32, numpy=array([32, 28, 28,  1], dtype=int32)>

# 3. Training

In [None]:
class Timer():
    """
    Minimal stopwatch used to time phases of the training run.

    The timer is either idle (`_start_time` is None) or running
    (`_start_time` holds the `time.perf_counter()` reading taken by `start`).
    """
    def __init__(self):
        # None means "not running"
        self._start_time = None

    def start(self):
        """
        Begin timing. If the timer is already running, print a hint and
        leave the current start time untouched. Always returns None.
        """
        if self._start_time is None:
            self._start_time = time.perf_counter()
        else:
            print("Timer is running. Use .stop() to stop it")

    def stop(self):
        """
        End timing and return the elapsed seconds since `start`.
        If the timer is not running, print a hint and return 0.
        """
        started_at = self._start_time
        if started_at is None:
            print("Timer is not running. Use .start() to start it")
            return 0

        elapsed = time.perf_counter() - started_at
        # back to idle so the timer can be started again
        self._start_time = None
        return elapsed



In [None]:
# Hyperparameters
epochs = 10
learning_rate = 0.0001
latent_dim = 100

tf.keras.backend.clear_session() # clear session from previous models
timer = Timer() # Instantiate the timer

# Instantiate models
# latent_dim is passed explicitly so that the generator's noise size and the
# shape used in build() below cannot drift apart when the hyperparameter changes
generator = Generator(latent_dim=latent_dim)
discriminator = Discriminator()
generator.build((None,latent_dim))
discriminator.build((None,28,28,1))
generator.summary()
discriminator.summary()

# Instantiate optimizer (the training loop uses this one instance for both networks)
adam = tf.keras.optimizers.Adam(learning_rate) 

Model: "generator"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense (Dense)                multiple                  316736    
_________________________________________________________________
batch_normalization (BatchNo multiple                  12544     
_________________________________________________________________
activation (Activation)      multiple                  0         
_________________________________________________________________
reshape (Reshape)            multiple                  0         
_________________________________________________________________
conv2d_transpose (Conv2DTran multiple                  18464     
_________________________________________________________________
batch_normalization_1 (Batch multiple                  128       
_________________________________________________________________
activation_1 (Activation)    multiple                  0 

In [None]:
def training_step_GAN(data, model_gen, model_disc, optimizer, metric_gen, metric_disc):
    '''
    Train generator and discriminator for one pass over `data`.

    For every batch of real images a same-sized batch of fake images is
    generated from standard-normal noise. The discriminator is trained to
    output 1 for real and 0 for fake images, the generator is trained to make
    the discriminator output 1 for its fakes. Both use binary cross-entropy.

    :param data: iterable of real image batches
    :param model_gen: generator model; must expose `latent_dim`
    :param model_disc: discriminator model
    :param optimizer: optimizer applied to both networks
    :param metric_gen: metric updated with the generator loss of every batch
    :param metric_disc: metric updated with the discriminator loss of every batch
    '''
    # NOTE(review): one optimizer instance is shared by both networks; some newer
    # Keras optimizer implementations may require one instance per variable set - confirm
    bce = tf.keras.losses.BinaryCrossentropy()

    for batch_true in data:
        # one latent vector per real image in the batch
        noise_in_z = tf.random.normal([tf.shape(batch_true)[0], model_gen.latent_dim])

        # The forward passes MUST run inside the tapes: a tape only records
        # operations executed within its context, otherwise every gradient is None.
        # Two tapes are used because each (non-persistent) tape can be queried once.
        with tf.GradientTape() as tape_gen, tf.GradientTape() as tape_disc:
            # Generate fake images
            batch_gen = model_gen(noise_in_z, training=True)

            # Discriminator decides whether provided images are fake or not
            fake_pred = model_disc(batch_gen, training=True)
            true_pred = model_disc(batch_true, training=True)

            # discriminator should ideally assign 1 to true img and 0 to fake
            loss_disc = bce(tf.zeros_like(fake_pred), fake_pred) + \
                bce(tf.ones_like(true_pred), true_pred)
            # generator should fool disc. that fake images are real (label 1)
            loss_gen = bce(tf.ones_like(fake_pred), fake_pred)

        # gradients are taken after leaving the tape context so the gradient
        # computation itself is not recorded
        g_disc = tape_disc.gradient(loss_disc, model_disc.trainable_variables)
        g_gen = tape_gen.gradient(loss_gen, model_gen.trainable_variables)

        # Gradient descent
        optimizer.apply_gradients(zip(g_disc, model_disc.trainable_variables))
        optimizer.apply_gradients(zip(g_gen, model_gen.trainable_variables))

        # Save mean loss values
        metric_gen.update_state(loss_gen)
        metric_disc.update_state(loss_disc)


def evaluation_step_GAN(data, model_gen, model_disc, num_image=1):
    '''
    Plot real images (left column) next to generated images (right column).

    Only the first batch of `data` is used. One row is drawn per image; if the
    batch holds fewer than `num_image` images a hint is printed and the
    remaining rows stay empty. The figure is left open for the caller to show.

    :param data: iterable of real image batches with a trailing channel dimension of size 1
    :param model_gen: generator model; must expose `latent_dim`
    :param model_disc: discriminator model (currently unused, kept for a stable signature)
    :param num_image: int, number of real/generated pairs to plot (default 1)
    '''
    # squeeze=False keeps `ax` two-dimensional, so ax[i, j] also works for num_image == 1;
    # width covers the two columns, height grows with the number of rows
    fig, ax = plt.subplots(nrows=num_image, ncols=2, figsize=(10, 5*num_image), squeeze=False)

    for batch_img_true in data:
        if tf.shape(batch_img_true)[0] < num_image:
            print('NUM_IMAGE SHOULD BE AT MOST THE BATCH SIZE')

        imgs_true = batch_img_true[:num_image]

        # one latent vector per displayed image; the model needs the leading batch dimension
        noise_in_z = tf.random.normal([tf.shape(imgs_true)[0], model_gen.latent_dim])
        imgs_gen = model_gen(noise_in_z, training=False)

        for i, (img_true, img_gen) in enumerate(zip(imgs_true, imgs_gen)):
            # drop the channel dimension so imshow receives a 2-D grayscale array
            ax[i,0].imshow(tf.squeeze(img_true, axis=-1).numpy(), cmap='gray')
            ax[i,0].set_title('real')
            ax[i,1].imshow(tf.squeeze(img_gen, axis=-1).numpy(), cmap='gray')
            ax[i,1].set_title('generated')
        break  # just take the first batch


In [None]:
# Running means of the per-batch losses; the training loop resets them after every epoch
train_loss_gen = tf.keras.metrics.Mean(name='generator')
train_loss_disc = tf.keras.metrics.Mean(name='discriminator')

# Tensorboard loggers: one timestamped directory per run, with separate
# sub-directories (and writers) for training and test summaries
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
train_log_dir = f'logs/gradient_tape/{current_time}/train_ResNet'
test_log_dir = f'logs/gradient_tape/{current_time}/test_ResNet'
train_summary_writer = tf.summary.create_file_writer(train_log_dir)
test_summary_writer = tf.summary.create_file_writer(test_log_dir)

# Epoch-wise histories for later visualization
losses_disc, losses_gen, times = [], [], []

In [None]:
for epoch in range(epochs):
    print(f'\n[EPOCH] ____________________{epoch}____________________')
    # Training
    timer.start()
    # metrics are passed by keyword so generator/discriminator cannot be swapped
    training_step_GAN(train_ds, generator, discriminator, adam,
                      metric_gen=train_loss_gen, metric_disc=train_loss_disc)
    # read the epoch means as plain floats BEFORE the metrics are reset;
    # storing the metric objects themselves would leave every history entry
    # pointing at the same (reset) object
    epoch_loss_disc = float(train_loss_disc.result())
    epoch_loss_gen = float(train_loss_gen.result())
    # logging our metrics to a file which is used by tensorboard
    with train_summary_writer.as_default():
        tf.summary.scalar('discriminator', epoch_loss_disc, step=epoch)
        tf.summary.scalar('generator', epoch_loss_gen, step=epoch)
    # append in epoch-wise history list
    losses_disc.append(epoch_loss_disc)
    losses_gen.append(epoch_loss_gen)
    # reset metrics for next epoch use
    train_loss_disc.reset_states()
    train_loss_gen.reset_states()
    # time and print progress
    elapsed_time = timer.stop()
    times.append(elapsed_time)
    print(f'[{epoch}] - Finished Epoch in {elapsed_time:0.2f} seconds - loss_disc: {epoch_loss_disc:0.4f}, loss_gen: {epoch_loss_gen:0.4f}')

    # Test and visualize
    timer.start()
    evaluation_step_GAN(test_ds, generator, discriminator)
    elapsed_time = timer.stop()
    times.append(elapsed_time)
    plt.show()
    print(f'\n[{epoch}] - Finished evaluation')

    # Print progress every third epoch
    if epoch%3 == 0:
        print(f'\n[INFO] - Total time elapsed: {np.sum(times)/60:0.4f} min. Total time remaining: {(np.sum(times)/(epoch+1))*(epochs-epoch-1)/60:0.4f} min.')
print(f'[INFO] - Total run time: {np.sum(times)/60:0.4f} min.')


[EPOCH] ____________________0____________________
g_disc None


ValueError: ignored