In [1]:
#preprocess
import tensorflow as tf
import tensorflow_datasets as tfds
def new_target_fnc(ds, sequence_len):
  """
  Creates list of new targets by alternately adding and subtracting
  the labels inside each sequence.

  Within every run of ``sequence_len`` consecutive elements the running
  target restarts at the first label. After that, labels at odd positions
  in the sequence are subtracted and labels at even positions are added.
  For example, labels a, b, c, d give the targets a, a-b, a-b+c, a-b+c-d.

  Parameters
  ----------
  ds : iterable of (image, target) pairs
      e.g. the original MNIST dataset yielding ``(image, label)`` tuples.
      Only the label (index 1) is read.
  sequence_len : int
      Number of consecutive elements per sequence; the running sum resets
      at every multiple of this length. Must be positive.

  Returns
  -------
  l : list of int
      One new target per element of ``ds``, in iteration order.

  Raises
  ------
  ValueError
      If ``sequence_len`` is not positive.
  """
  if sequence_len <= 0:
    raise ValueError(f"sequence_len must be positive, got {sequence_len}")
  l = []
  running = 0
  for i, elem in enumerate(ds):
    label = int(elem[1])
    # The sign depends on the position INSIDE the current sequence. Using the
    # global index i would flip the pattern for odd sequence lengths.
    pos = i % sequence_len
    if pos == 0:
      running = label
    elif pos % 2 == 1:
      running -= label
    else:
      running += label
    l.append(running)
  return l


def prepare_mnist_data(mnist, batch_size, slice_size, shuffle_buffer_size=2000):
    """Preprocess the MNIST dataset into batches of labelled digit sequences.

    The steps are:
    - cast images to float32 and scale them to about [-1, 1];
    - replace the labels with the alternating running sums from ``new_target_fnc``;
    - group consecutive examples into sequences of ``slice_size``.

    Args:
      mnist (tf.data.Dataset): dataset of ``(image, label)`` pairs to prepare.
        Its iteration order must be deterministic. The new targets are
        computed in one pass and zipped back onto a second pass.
      batch_size (int): maximum number of sequences per batch.
      slice_size (int): length of each sequence.
      shuffle_buffer_size (int): buffer size used when shuffling sequences.

    Returns:
      tf.data.Dataset yielding ``(images, targets)`` pairs.
      - images have shape ``(batch, slice_size) + image_shape``;
      - targets have shape ``(batch, slice_size)``;
      - ``batch`` is at most ``batch_size``;
      - an incomplete trailing sequence is dropped.
    """
    # Step 1: only the images change here; the labels stay untouched.
    # Convert data from uint8 to float32.
    data = mnist.map(lambda img, target: (tf.cast(img, tf.float32), target))
    # Sloppy input normalization: bring values from [0, 255] to about [-1, 1].
    data = data.map(lambda img, target: ((img / 128.) - 1., target))

    # Step 2: compute the new sequence targets. Only the labels are read, so
    # iterate the raw dataset and skip the image conversion above.
    new_targets = tf.data.Dataset.from_tensor_slices(
        new_target_fnc(mnist, slice_size))

    # Pair every (image, old_target) with its new target, then keep only the
    # image and the new target.
    prepared = tf.data.Dataset.zip((data, new_targets))
    prepared = prepared.map(
        lambda img_and_label, new_target: (img_and_label[0], new_target))

    # Group consecutive examples into sequences. A short trailing sequence
    # would have a different shape and break the batching below, so drop it.
    prepared = prepared.batch(slice_size, drop_remainder=True)

    # Step 3: cache BEFORE shuffling. Caching after the shuffle would freeze
    # the first epoch's order and replay it in every later epoch.
    prepared = prepared.cache()
    prepared = prepared.shuffle(shuffle_buffer_size)
    prepared = prepared.batch(batch_size)
    prepared = prepared.prefetch(tf.data.AUTOTUNE)

    return prepared

In [4]:
class LSTMCell(tf.keras.layers.AbstractRNNCell):
    """A single LSTM step built from four Dense gate layers.

    Every gate reads the concatenation of the current input and the previous
    hidden state. The cell is meant to be wrapped in ``tf.keras.layers.RNN``.
    """

    def __init__(self, units, batch_size=None, **kwargs):
        """Constructor function for an LSTM cell.

        Args:
          units (int): number of units (size of the hidden and cell state).
          batch_size (int, optional): fallback batch size for
            ``get_initial_state``. Used only when that method receives
            neither ``inputs`` nor ``batch_size``.
          **kwargs: forwarded to ``tf.keras.layers.AbstractRNNCell``.
        """
        super().__init__(**kwargs)
        self.units = units
        self.batch_size = batch_size
        # forget gate; bias initialised to ones, pushing the initial gate
        # outputs toward keeping the previous cell state
        self.fg_layer = tf.keras.layers.Dense(
            units,
            activation='sigmoid',
            bias_initializer='ones'
        )
        # input gate
        self.ig_layer = tf.keras.layers.Dense(units, activation='sigmoid')
        # output gate
        self.og_layer = tf.keras.layers.Dense(units, activation='sigmoid')
        # candidate cell state
        self.cell_layer = tf.keras.layers.Dense(units, activation='tanh')

    @property
    def state_size(self):
        # [hidden state size, cell state size]
        return [tf.TensorShape(self.units), tf.TensorShape(self.units)]

    @property
    def output_size(self):
        return tf.TensorShape(self.units)

    def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
        """Return zero-filled ``[hidden_state, cell_state]``.

        This is a regular method, not a property, because ``tf.keras.layers.RNN``
        calls it with arguments. The batch size passed by the RNN is preferred
        over the stored one, so a smaller final batch still works.

        Args:
          inputs: optional input tensor; its leading dimension and dtype are
            used when ``batch_size`` / ``dtype`` are not given.
          batch_size: number of sequences in the batch.
          dtype: dtype of the returned states (defaults to float32).

        Raises:
          ValueError: if no batch size can be determined.
        """
        if batch_size is None:
            batch_size = tf.shape(inputs)[0] if inputs is not None else self.batch_size
        if batch_size is None:
            raise ValueError("get_initial_state needs inputs or a batch_size")
        if dtype is None:
            dtype = inputs.dtype if inputs is not None else tf.float32
        return [tf.zeros([batch_size, self.units], dtype=dtype),
                tf.zeros([batch_size, self.units], dtype=dtype)]

    def call(self, x, states):
        """Compute one forward step of the LSTM.

        Args:
          x: inputs for the current time step. They are concatenated with the
            previous hidden state along axis 1, so presumably (batch, features).
          states: ``[prev_hidden_state, prev_cell_state]``.

        Returns:
          ``(hidden_state, [hidden_state, cell_state])``: the step output and
          the list of new states.
        """
        prev_hidden_state, prev_cell_state = states
        # all gates read the same [x, h_{t-1}] input
        xh = tf.concat([x, prev_hidden_state], axis=1)
        ffilter = self.fg_layer(xh)      # forget gate output
        ifilter = self.ig_layer(xh)      # input gate output
        cs_cand = self.cell_layer(xh)    # cell state candidates
        # c_t = f * c_{t-1} + i * c_candidate
        cell_state = tf.math.multiply(ffilter, prev_cell_state) + \
                     tf.math.multiply(ifilter, cs_cand)
        ofilter = self.og_layer(xh)      # output gate output
        # h_t = o * tanh(c_t)
        hidden_state = tf.math.multiply(ofilter, tf.nn.tanh(cell_state))

        # output of the LSTM and the list of new states
        return hidden_state, [hidden_state, cell_state]

In [5]:
#model
from keras.layers import Dense

class MyModel(tf.keras.Model):
    """Per-frame CNN encoder followed by an LSTM and a linear read-out.

    Inputs are presumably image sequences shaped
    (batch, time, rows, cols, channels), as produced by ``prepare_mnist_data``
    (TODO confirm). The output has shape (batch, time, 1): one value per step.
    """

    # 1. constructor
    def __init__(self, lstm_units=32, batch_size=32):
        """Build the layers.

        Args:
          lstm_units (int): number of units of the LSTM cell.
          batch_size (int): batch size handed to ``LSTMCell``.
        """
        super().__init__()

        # Basic CNN structure. The Conv2D layers receive the 5-D sequence
        # tensor directly. This relies on Conv2D accepting extra leading batch
        # dimensions; confirm for the TF version in use.
        self.convlayer1 = tf.keras.layers.Conv2D(filters=24, kernel_size=3, padding='same', activation='relu')
        self.convlayer2 = tf.keras.layers.Conv2D(filters=24, kernel_size=3, padding='same', activation='relu')
        self.pooling = tf.keras.layers.MaxPooling2D(pool_size=2, strides=2)

        self.convlayer3 = tf.keras.layers.Conv2D(filters=48, kernel_size=3, padding='same', activation='relu')
        self.convlayer4 = tf.keras.layers.Conv2D(filters=48, kernel_size=3, padding='same', activation='relu')
        self.global_pool = tf.keras.layers.GlobalAvgPool2D()

        # The pooling layers only take 4-D input, so they are applied per time
        # step through TimeDistributed. The wrappers are built once here
        # rather than re-created on every call.
        self.td_pooling = tf.keras.layers.TimeDistributed(self.pooling)
        self.td_global_pool = tf.keras.layers.TimeDistributed(self.global_pool)

        # LSTM structure
        self.lstm_cell = LSTMCell(units=lstm_units, batch_size=batch_size)
        self.lstm_layer = tf.keras.layers.RNN(self.lstm_cell, return_sequences=True, unroll=False)

        # one output value per time step
        self.out = tf.keras.layers.Dense(1)

    # No @tf.function here: Keras already traces call() inside fit/evaluate/predict.
    def call(self, x):
        """Compute forward step.

        Args:
          x: batch of image sequences (see the class docstring).

        Returns:
          Tensor of shape (batch, time, 1).
        """
        x = self.convlayer1(x)
        x = self.convlayer2(x)
        x = self.td_pooling(x)
        x = self.convlayer3(x)
        x = self.convlayer4(x)
        x = self.td_global_pool(x)
        x = self.lstm_layer(x)
        return self.out(x)


In [None]:
# main
import tensorflow_datasets as tfds
import tensorflow as tf
import datetime
%load_ext tensorboard

# get mnist from tensorflow_dataets
train_ds, test_ds = tfds.load('mnist', split=['train', 'test'], as_supervised=True)

#preprocess the data and make sequences
train_ds = prepare_mnist_data(train_ds, 32, 4)
test_ds = prepare_mnist_data(test_ds, 32, 4)


Downloading and preparing dataset 11.06 MiB (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to ~/tensorflow_datasets/mnist/3.0.1...


Dl Completed...:   0%|          | 0/5 [00:00<?, ? file/s]

Dataset mnist downloaded and prepared to ~/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.


In [None]:
model = MyModel()
optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001)
loss = tf.keras.losses.MeanSquaredError()
#metrics = tf.keras.metrics.MeanSquaredError()

#set tensorboard
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
logging_callback = tf.keras.callbacks.TensorBoard(log_dir=f"./logs/LSTM/{current_time}")
%tensorboard --logdir="logs/LSTM"

model.compile(optimizer, loss) #, metrics)
history = model.fit(train_ds,
          validation_data = test_ds, 
          initial_epoch=0,
          epochs=15, 
          callbacks=[logging_callback]) 