In [2]:
import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np
import matplotlib.pyplot as plt
import datetime

%load_ext tensorboard

In [12]:
# Sequence length (number of consecutive MNIST examples per sequence) and
# number of sequences per minibatch.
seq_len, batch_size = 100, 32

# as_supervised=True makes every element an (image, label) tuple instead of a feature dict.
ds_train, ds_test = tfds.load(
    'mnist',
    split=['train', 'test'],
    as_supervised=True,
)
def cumsum_dataset(ds, seq_len):
  """Group (image, label) pairs into sequences with alternating-sign cumulative-sum targets.

  Consecutive, non-overlapping windows of `seq_len` examples are formed. Within each
  window the labels are multiplied by +1, -1, +1, ... and then cumulatively summed,
  so the target at position i is  t_0 - t_1 + t_2 - ... +/- t_i.

  Args:
    ds: tf.data.Dataset of (image, label) pairs, e.g. from tfds.load(..., as_supervised=True).
    seq_len: number of consecutive examples per sequence.

  Returns:
    tf.data.Dataset of (images, targets): float32 images stacked along a new leading
    axis of length seq_len, and int32 targets with leading axis seq_len.
    A trailing partial window is dropped so that every element has the same shape.
  """
  # Scale pixels to roughly [-1, 1) and cast labels to int32 for the integer scans below.
  # (Assumes uint8 pixel values in [0, 255], as tfds MNIST provides.)
  ds = ds.map(lambda x, t: ((tf.cast(x, tf.float32) / 128.) - 1, tf.cast(t, tf.int32)))
  # window() turns ds into a dataset of datasets: every entry is an (images, labels) pair
  # of small tf.data.Dataset objects with seq_len entries each.
  # drop_remainder=True discards a final, shorter window; otherwise that element would
  # have a different shape and the later .batch(batch_size) would fail.
  ds = ds.window(seq_len, drop_remainder=True)

  # See tf.data.Dataset.scan: the scan function maps (state, elem) -> (new_state, output).
  def alternating_scan_function(state, elem):
    """Multiply the label by the current sign (state) and flip the sign for the next one."""
    sign = state
    return sign * -1, elem * sign

  # Apply the sign flipping to every label sub-dataset; the image sub-datasets are untouched.
  ds = ds.map(lambda x, t: (x, t.scan(initial_state=tf.constant(1, tf.int32),
                                      scan_func=alternating_scan_function)))

  def scan_cum_sum_function(state, elem):
    """Running sum: both the new state and the emitted element are the sum up to this element."""
    sum_including_this_elem = state + elem
    return sum_including_this_elem, sum_including_this_elem

  # The running sum starts at 0 (nothing has been summed yet).
  ds = ds.map(lambda x, t: (x, t.scan(initial_state=tf.constant(0, tf.int32),
                                      scan_func=scan_cum_sum_function)))
  # Collapse each sub-dataset into a single tensor: batching all seq_len entries yields
  # exactly one element, which get_single_element() extracts.
  ds = ds.map(lambda x, t: (x.batch(seq_len).get_single_element(),
                            t.batch(seq_len).get_single_element()))
  return ds

def _make_sequence_pipeline(ds, shuffle=False, shuffle_buffer=1000):
    """Turn an (image, label) dataset into cached, batched, prefetched sequence batches.

    Uses cumsum_dataset with the notebook-level seq_len, then cache(), an optional
    shuffle(), batch(batch_size) and prefetch(AUTOTUNE).

    Shuffling comes after cache() so the windowing/scan work runs only once, while
    the order of sequences still changes every epoch (shuffle's
    reshuffle_each_iteration defaults to True).

    Args:
        ds: dataset of (image, label) pairs.
        shuffle: whether to shuffle the sequences (use for training data only).
        shuffle_buffer: shuffle buffer size, counted in sequences. 1000 should cover
            all sequences of MNIST's train split at seq_len=100 (60,000 / 100) -- adjust
            for other sizes.
    """
    ds = cumsum_dataset(ds, seq_len).cache()
    if shuffle:
        ds = ds.shuffle(shuffle_buffer)
    return ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)

# Only the training data is shuffled; evaluation order does not affect the metrics.
ds_train = _make_sequence_pipeline(ds_train, shuffle=True)
ds_test = _make_sequence_pipeline(ds_test)

In [None]:
class BasicCNN(tf.keras.Model):
    """Convolutional 10-class classifier with hand-written train/test steps.

    Architecture, in order:
    - Three stages. Each stage starts with BatchNormalization, followed by 3x3
      'same'-padded ReLU convolutions, each convolution followed by Dropout.
    - Stage widths/depths are 24 filters x 6 convolutions, 48 x 4 and 96 x 2.
    - The first two stages end in 2x2 max pooling, the last in global average pooling.
    - A Dense softmax layer then produces 10 class probabilities.

    The loss is CategoricalCrossentropy and accuracy is CategoricalAccuracy, so
    targets are expected to be one-hot vectors of length 10.

    Args:
        dropout_rate: rate used by every Dropout layer (default 0.1).
    """

    def __init__(self, dropout_rate=0.1):
        super().__init__()

        # metrics_list[0] must stay the loss tracker: the step methods rely on that order.
        self.metrics_list = [tf.keras.metrics.Mean(name="loss"),
                             tf.keras.metrics.CategoricalAccuracy(name="acc")]

        self.optimizer = tf.keras.optimizers.Adam()

        self.loss_function = tf.keras.losses.CategoricalCrossentropy()

        # Layers are created in exactly the same order as the original hand-written list.
        self.layer_list = [
            *self._conv_stage(24, 6, dropout_rate),
            tf.keras.layers.MaxPooling2D(pool_size=2, strides=2),

            *self._conv_stage(48, 4, dropout_rate),
            tf.keras.layers.MaxPooling2D(pool_size=2, strides=2),

            *self._conv_stage(96, 2, dropout_rate),
            tf.keras.layers.GlobalAveragePooling2D(),

            tf.keras.layers.Dense(10, activation='softmax'),
        ]

    @staticmethod
    def _conv_stage(filters, n_convs, dropout_rate):
        """Return [BatchNorm, (Conv2D 3x3 ReLU, Dropout) * n_convs] for one stage."""
        stage = [tf.keras.layers.BatchNormalization()]
        for _ in range(n_convs):
            stage.append(tf.keras.layers.Conv2D(filters=filters, kernel_size=3,
                                                padding='same', activation='relu'))
            stage.append(tf.keras.layers.Dropout(dropout_rate))
        return stage

    @tf.function
    def call(self, x, training=None):
        """Forward pass.

        `training` is forwarded explicitly, so Dropout and BatchNormalization switch
        between train and inference behaviour. Layers without a training argument
        have it stripped by Keras.
        """
        for layer in self.layer_list:
            x = layer(x, training=training)
        return x

    @property
    def metrics(self):
        """List of all metric objects of the model (loss tracker first)."""
        return self.metrics_list

    def reset_metrics(self):
        """Reset the accumulated state of every metric."""
        for metric in self.metrics:
            # reset_state() is the current API; reset_states() is deprecated and absent in Keras 3.
            metric.reset_state()

    def _update_metrics(self, target, output, loss):
        """Update the loss tracker and all other metrics, returning {name: current result}."""
        self.metrics[0].update_state(loss)
        # Every metric except the loss tracker (accuracy etc.) compares targets and outputs.
        for metric in self.metrics[1:]:
            metric.update_state(target, output)
        return {m.name: m.result() for m in self.metrics}

    @tf.function
    def train_step(self, data):
        """Run one optimisation step on an (inputs, targets) batch and update the metrics."""
        img, target = data

        with tf.GradientTape() as tape:
            output = self(img, training=True)
            loss = self.loss_function(target, output)

        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

        return self._update_metrics(target, output, loss)

    @tf.function
    def test_step(self, data):
        """Evaluate one (inputs, targets) batch in inference mode and update the metrics."""
        img, target = data

        output = self(img, training=False)
        loss = self.loss_function(target, output)

        return self._update_metrics(target, output, loss)

In [None]:
# TensorBoard log locations: each run writes to logs/<config_name>/<timestamp>/{train,val},
# so runs with different configurations or start times end up in separate directories.
config_name = "batchnormblock+dropout0.1layer+avg"
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")

train_log_path = f"logs/{config_name}/{current_time}/train"
val_log_path = f"logs/{config_name}/{current_time}/val"

# One summary writer per phase (training metrics first, then validation metrics).
train_summary_writer, val_summary_writer = (
    tf.summary.create_file_writer(train_log_path),
    tf.summary.create_file_writer(val_log_path),
)

def training_loop(model, train_ds, val_ds, epochs, train_summary_writer, val_summary_writer):
    """Train `model` for `epochs` epochs, evaluating on `val_ds` after each epoch.

    After the training pass and after the validation pass of every epoch, the
    metrics accumulated over that whole pass are:
    - written once to the matching TensorBoard writer at step=epoch,
    - printed,
    - then reset.

    Args:
        model: object providing train_step(data), test_step(data), a `metrics` list of
            metric objects and reset_metrics() -- e.g. BasicCNN.
        train_ds: iterable of training batches, each passed to model.train_step.
        val_ds: iterable of validation batches, each passed to model.test_step.
        epochs: number of passes over train_ds.
        train_summary_writer: tf.summary file writer for training metrics.
        val_summary_writer: tf.summary file writer for validation metrics.
    """
    # The original referenced `tqdm` without importing it; import it here and fall back
    # to a plain iterator if the package is not installed.
    try:
        from tqdm.auto import tqdm
    except ImportError:
        def tqdm(iterable, **_kwargs):
            return iterable

    def _log_print_reset(writer, epoch, prefix):
        """Log current metric results at step=epoch, print them with `prefix`, then reset."""
        results = {metric.name: metric.result() for metric in model.metrics}
        with writer.as_default():
            for name, value in results.items():
                tf.summary.scalar(name, value, step=epoch)
        print([f"{prefix}{name}: {value.numpy()}" for name, value in results.items()])
        model.reset_metrics()

    for epoch in range(epochs):
        print(f"Epoch {epoch}:")

        # Training: metrics accumulate over all batches and are logged once per epoch
        # (logging per batch would write many points at the same step).
        for data in tqdm(train_ds, position=0, leave=True):
            model.train_step(data)
        _log_print_reset(train_summary_writer, epoch, "")

        # Validation: same accumulation, logged to the validation writer.
        for data in val_ds:
            model.test_step(data)
        _log_print_reset(val_summary_writer, epoch, "val_")

        print("\n")

In [None]:
%tensorboard --logdir logs/