## 2.1 Prepare the Dataset

In [12]:
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow import keras
from keras import datasets, layers, models
import matplotlib.pyplot as plt
import numpy as np
import datetime

In [13]:
# Load MNIST dataset
train_ds, test_ds = tfds.load('mnist', split=['train','test'], as_supervised=True)

def cumsum_dataset(ds, seq_len):
    #only get the targets, to keep this demonstration simple (and force students to understand the code if they are using it by rewriting it respectively)
    ds = ds.map(lambda x, t: tf.cast(t, dtype=tf.dtypes.int32))
    # use window to create subsequences. This means ds is not a dataset of datasets, i.e. every single entry in the dataset is itself a small tf.data.Dataset object with seq_len many entries!
    ds = ds.window(seq_len)
    #make sure to check tf.data.Dataset.scan() to understand how this works!
    def alternating_scan_function(state, elem):
        #state is allways the sign to use!
        old_sign = state
        #just flip the sign for every element
        new_sign = old_sign*-1
        #elem is just the target of the element. We need to apply the appropriate sign to it!
        signed_target = elem*old_sign
        #we need to return a tuple for the scan function: The new state and the output element
        out_elem = signed_target
        new_state = new_sign
        return new_state, out_elem
    #we now want to apply this function via scanning, resulting in a dataset where the signs are alternating
    #remember we have a dataset, where each element is a sub dataset due to the windowing!
    ds = ds.map(lambda sub_ds: sub_ds.scan(initial_state=1, scan_func=alternating_scan_function))
    #now we need a scanning function which implements a cumulative sum, very similar to the cumsum used above
    def scan_cum_sum_function(state, elem):
        #state is the sum up the the current element, element is the new digit to add to it
        sum_including_this_elem = state+elem
        #both the element at this position and the returned state should just be sum up to this element, saved in sum_including_this_elem
        return sum_including_this_elem, sum_including_this_elem
    #again we want to apply this to the subdatasets via scan, with a starting state of 0 (sum before summing is zero...)
    ds = ds.map(lambda sub_dataset: sub_dataset.scan(initial_state=0, scan_func=scan_cum_sum_function))
    #finally we need to create a single element from everything in the subdataset
    ds = ds.map(lambda sub_dataset: sub_dataset.batch(seq_len).get_single_element())
    return ds

In [14]:
def prepare_mnist_dataset(mnist_ds, seq_len, ds_type):
    # choose type of dataset
    if ds_type == 'Train':
        # flatten the images into vector
        mnist_ds = mnist_ds.map(lambda img, target: (tf.reshape(img, (-1,28,28,1)), target))
        # convert data from uint8 to float32
        mnist_ds = mnist_ds.map(lambda img, target: (tf.cast(img, tf.float32), target))
        # Sloppy input normalization, just bringing target values from range [0, 255] to [0, 1]
        mnist_ds = mnist_ds.map(lambda target, img: ((target/128.), img))
        mnist_ds.apply(lambda dataset: cumsum_dataset(dataset, seq_len)).take(10)
        # shuffle, batch, prefetch
        mnist_ds = mnist_ds.shuffle(1000)
        mnist_ds = mnist_ds.batch(32)
        mnist_ds = mnist_ds.prefetch(tf.data.AUTOTUNE)
        # return preprocessed dataset
        return mnist_ds
    elif ds_type == 'Test':
        # flatten the images into vector
        mnist_ds = mnist_ds.map(lambda img, target: (tf.reshape(img, (-1,28,28,1)), target))
        # convert data from uint8 to float32
        mnist_ds = mnist_ds.map(lambda img, target: (tf.cast(img, tf.float32), target))
        # Sloppy input normalization, just bringing target values from range [0, 255] to [0, 1]
        mnist_ds = mnist_ds.map(lambda target, img: ((target/128.), img))
        mnist_ds.apply(lambda dataset: cumsum_dataset(dataset, seq_len)).take(10)
        # batch, prefetch
        mnist_ds = mnist_ds.batch(32)
        mnist_ds = mnist_ds.prefetch(tf.data.AUTOTUNE)
        # return preprocessed dataset
        return mnist_ds


## 2.2 The CNN & LSTM Network
    The first part of your model should have a basic CNN structure. This part should extract vector representations from each MNIST image using Conv2D layers as well as (global) pooling or Flatten layers. A Conv2D layer can be called on a batch of sequences of images, where the time dimension is in the second axis. The time dimension will then be processed like a second batch dimension (search for ”extended batch shape” in the Conv2D layer documentation page for an example).
    While Conv2D layers accept our (batch, sequence-length, image) data structure with their extended batch size functionality, for the pooling layers to work correctly you will have to wrap them in TensorFlow’s TimeDistributed layers.
    Once you have encoded all images as vectors, the shape of the tensor should be (batch, sequence-length, features), which can be fed to a non-convolutional standard LSTM.
    
## 2.3 LSTM AbstractRNNCell layer
    For the LSTM, we want to not use the (optimized) keras implementation, but instead we want you to be able to implement the LSTM logic that is applied at each time-step yourselves. For this, we want to subclass the AbstractRNNCell layer and implement its methods and define the required properties. Those are state size, output size, and get initial state, which determines the initial hidden and cell state of the LSTM (usually tensors filled with zeros). 
    The LSTM-cell layer’s call method should take one (batch of) feature vector(s) as its input, along with the ”states”, a list containing the different state tensors of the LSTM cell (cell state and hidden state!).
    In the call method, the layer then uses these two states together with thevector representation of the current MNIST image in the sequence, and updates both the hidden state, and the cell state (in the way that it is done in an LSTM). The returns should be the output of the LSTM, to be used to compute the model output for this time-step (usually the hidden state), as well as a list containing the new states (e.g. [new hidden state, new cell state]).

## 2.4 Wrapping the LSTM-Cell layer with an RNN layer
    Since the LSTM cell only provides the computation for one time-step, you would need to write a wrapper layer around it, that applies it to every time-step in the sequence, aggregating the outputs and states in a for loop.
    In Tensorflow you should use the RNN layer for this, tf.keras.layers.RNN takes an instance of your LSTM cell as the first argument in its constructor. You also need to specify whether you want the RNN wrapper layer to return the output of your LSTM-cell for every time-step or only for the last step (with the argument return sequences=True). This is generally task-dependent, so think about what makes most sense in this case. For speed-ups (at the cost of memory usage) you can set the ”unroll” argument to True.
    The ”wrapper” RNN layer then takes the sequence of vector representations of the mnist images as its input (batch, seq len, feature dim).

## 2.5 Computing the model output
    With the output of your RNN-wrapped LSTM-Cell, you can now compute the model predictions. Again depending on the task, you need to think about what your predictions should be (generally have one prediction per target that is associated with your sequence). Dense layers also behave in the same way as Conv2D layers - when there is an additional time-dimension (batch, time, features), they apply the same computation for every time-index and for every batch-index. So you could (if the task demands it) use the same Dense layer to predict targets for all time-steps. You likely do not want to have a Dense layer for each time-step’s target prediction (potential for overfitting!).

In [15]:
class RNNCell(tf.keras.layers.AbstractRNNCell):
    def __init__(self, recurrent_units_1, recurrent_units_2):
        super().__init__()

        self.recurrent_units_1 = recurrent_units_1
        self.recurrent_units_2 = recurrent_units_2

        self.linear_1 = tf.keras.layers.TimeDistributed(
            tf.keras.layers.Dense(recurrent_units_1))
        self.linear_2 = tf.keras.layers.TimeDistributed(
            tf.keras.layers.Dense(recurrent_units_2))
    
        # First recurrent layer in the RNN
        self.recurrent_layer_1 = tf.keras.layers.Conv2D(filters=recurrent_units_1,
                                    kernel_size=3,
                                    padding = 'same',
                                    kernel_initializer=tf.keras.initializers.Orthogonal(
                                        gain=1.0, seed=None),
                                    activation=tf.nn.tanh)
        
        # layer normalization for trainability
        self.batch_norm_1 = tf.keras.layers.BatchNormalization()

        # Second recurrent layer in the RNN
        self.recurrent_layer_2 = tf.keras.layers.Conv2D(filters=recurrent_units_2,
                                    kernel_size=3,
                                    padding = 'same',
                                    kernel_initializer=tf.keras.initializers.Orthogonal(
                                        gain=1.0, seed=None),
                                    activation=tf.nn.tanh)

        # layer normalization for trainability
        self.batch_norm_2 = tf.keras.layers.BatchNormalization()

    @property
    def state_size(self):
        return [tf.TensorShape([self.recurrent_units_1]), 
                tf.TensorShape([self.recurrent_units_2])]

    @property
    def output_size(self):
        return [tf.TensorShape([self.recurrent_units_2])]

    def initial_state(self, inputs=None, batch_size=None, dtype=None):
        return [tf.zeros([self.recurrent_units_1]),
                tf.zeros([self.recurrent_units_2])]
    
    @tf.function
    def __call__(self, inputs, states):
        # Unpack the states
        state_layer_1 = states[0]
        state_layer_2 = states[1]

        # Linearly project input
        x = self.linear_1(inputs) + state_layer_1
        # Apply first reccurent layer
        new_state_layer_1 = self.recurrent_layer_1(x)
        # Apply first layer's layer norm
        x = self.batch_norm_1(new_state_layer_1)
        # linearly project output of layer norm
        x = self.linear_2(x) + state_layer_2
        # Apply second reccurent layer
        new_state_layer_2 = self.recurrent_layer_2(x)
        # Apply second layer's layer norm
        x = self.batch_norm_2(new_state_layer_2)

        # Return output and the list of new states of the layers
        return x, [new_state_layer_1, new_state_layer_2] 

    def get_config(self):
        return {'recurrent_units_1': self.recurrent_units_1,
                'recurrent_units_2': self.recurrent_units_2}

In [16]:
class RNNModel(tf.keras.Model):
    def __init__(self, input_n, output_n):
        super().__init__()

        self.rnn_cell = RNNCell(input_n, output_n)
        # Return_sequences collects and returns the output of the rnn_cell for all time-steps
        # Unroll unrolls the network for speed (at the cost of memory)
        self.rnn_layer = tf.keras.layers.RNN(self.rnn_cell, return_sequences=False, unroll=True)

        self.avg_global_pool = tf.keras.layers.TimeDistributed(
            tf.keras.layers.GlobalAvgPool2D())

        self.flatten = tf.keras.layers.TimeDistributed(
            tf.keras.layers.Flatten())

        self.output_layer = tf.keras.layers.TimeDistributed(
            tf.keras.layers.Dense(units=10,activation=tf.nn.sigmoid))

        self.metrics_list = [
            tf.keras.metrics.Mean(name='loss'),
            tf.keras.metrics.CategoricalAccuracy(name='accuracy')
        ]

    @tf.function
    def __call__(self, sequence, training=False):
        x = self.rnn_layer(sequence)
        x = self.avg_global_pool(sequence)
        x = self.flatten(sequence)
        return self.output_layer(x)

    @property
    def metrics(self):
        return self.metrics_list

    def reset_metrics(self):
        for metric in self.metrics:
            metric.reset_state()

    def train_step(self, data):
        sequence, label = data
        with tf.GradientTape() as tape:
            output = self(sequence, training=True)
            loss = self.compiled_loss(label, output, regularization_losses=self.losses)
        gradients = tape.gradient(loss, self.trainable_variables)
        
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        
        self.metrics[0].update_state(loss)
        self.metrics[1].update_state(label, output)
        
        return {m.name : m.result() for m in self.metrics}

    def test_step(self, data):
        sequence, label = data
        output = self(sequence, training=False)
        loss = self.compiled_loss(label, output, regularization_losses=self.losses)
                
        self.metrics[0].update_state(loss)
        self.metrics[1].update_state(label, output)
        
        return {m.name : m.result() for m in self.metrics}    

## 2.6 Training
    We still highly recommend to write another custom training-loop for this homework. However, if you want to follow the code presented in this week, feel free to already use the model.compile and model.fit methods for the first time. Make sure to track your experiments properly, save configs (e.g. hyperparameters) of your settings, save logs (e.g. with Tensorboard) and checkpoint your model’s weights (or even the complete model). To visualize your results you can modify your training loop to also write metrics to lists, or rely on the default history callback that model.fit uses.

In [17]:
def visualization(history):
    # Plotting the loss data
    plt.plot(history.history['loss'], label='training')
    plt.plot(history.history['val_loss'], label='validation')
    plt.legend(loc='best')
    plt.xlabel('Epoch')
    plt.ylabel('Categorical Crossentropy loss')

    plt.show()

In [18]:
SEQUENCE_LENGTH = 10

In [19]:
train_dataset = prepare_mnist_dataset(train_ds, SEQUENCE_LENGTH, 'Train')
# Сheck the contents of the dataset
for img, label in train_dataset:
    print(img.shape, label.shape)
    break

(32, 1, 28, 28, 1) (32,)


In [20]:
val_dataset = prepare_mnist_dataset(train_ds, SEQUENCE_LENGTH, 'Test')
for img, label in val_dataset:
    print(img.shape, label.shape)
    break

(32, 1, 28, 28, 1) (32,)


In [21]:
# input_shape = (28,28,1)
input_units = 24
output_units = 48

model = RNNModel(input_units, output_units)
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
loss = tf.keras.losses.CategoricalCrossentropy(from_logits=True)

In [22]:
# Compile the model
model.compile(optimizer=optimizer,
                loss=loss) 


In [23]:
EXPERIMENT_NAME = 'RNN_cumsum'
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
logging_callback = tf.keras.callbacks.TensorBoard(log_dir=f'./logs/{EXPERIMENT_NAME}/{current_time}')

# Train the model with fit function
history = model.fit(train_dataset,
                        validation_data=val_dataset,
                        epochs = 20,
                        callbacks = logging_callback)
    
# Save model
model.save(filepath='./saved_models/')

# Visualize the data
visualization(history)

Epoch 1/20


ValueError: in user code:

    File "C:\Anaconda3\lib\site-packages\keras\engine\training.py", line 1160, in train_function  *
        return step_function(self, iterator)
    File "C:\Anaconda3\lib\site-packages\keras\engine\training.py", line 1146, in step_function  **
        outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "C:\Anaconda3\lib\site-packages\keras\engine\training.py", line 1135, in run_step  **
        outputs = model.train_step(data)
    File "C:\Users\п\AppData\Local\Temp\ipykernel_4352\4165944845.py", line 43, in train_step
        loss = self.compiled_loss(label, output, regularization_losses=self.losses)
    File "C:\Anaconda3\lib\site-packages\keras\engine\compile_utils.py", line 265, in __call__
        loss_value = loss_obj(y_t, y_p, sample_weight=sw)
    File "C:\Anaconda3\lib\site-packages\keras\losses.py", line 152, in __call__
        losses = call_fn(y_true, y_pred)
    File "C:\Anaconda3\lib\site-packages\keras\losses.py", line 272, in call  **
        return ag_fn(y_true, y_pred, **self._fn_kwargs)
    File "C:\Anaconda3\lib\site-packages\keras\losses.py", line 1990, in categorical_crossentropy
        return backend.categorical_crossentropy(
    File "C:\Anaconda3\lib\site-packages\keras\backend.py", line 5529, in categorical_crossentropy
        target.shape.assert_is_compatible_with(output.shape)

    ValueError: Shapes (None,) and (None, 1, 10) are incompatible
