In [2]:
import time
import datetime

import numpy as np
import matplotlib
import matplotlib.pyplot as plt
 
import tensorflow as tf
import tensorflow_datasets as tfds 

### Helper class for timing

In [None]:
class Timer:
    """
    Minimal stopwatch used to time training runs.

    Call ``start()`` to begin measuring and ``stop()`` to finish; ``stop()``
    returns the elapsed wall-clock seconds as measured by ``time.perf_counter``.
    Misuse (starting twice, stopping while idle) prints a hint instead of raising.
    """

    def __init__(self):
        # perf_counter() reading taken by start(); None while the timer is idle.
        self._start_time = None

    def start(self):
        """
        Begin timing.

        If the timer is already running, a hint is printed and the existing
        start time is kept. Always returns None.
        """
        if self._start_time is None:
            self._start_time = time.perf_counter()
        else:
            print("Timer is running. Use .stop() to stop it")

    def stop(self):
        """
        Finish timing and report the elapsed time.

        Returns:
            Seconds elapsed since ``start()`` (a float), or 0 if the timer was
            not running (a hint is printed in that case).
        """
        started_at = self._start_time
        if started_at is None:
            print("Timer is not running. Use .start() to start it")
            return 0

        elapsed = time.perf_counter() - started_at
        self._start_time = None
        return elapsed


# 1. Task
At every timestep the network is given two query digits and must decide which of the two has appeared more often in the sequence up to the current timestep.

In [91]:
def data_pair_gen(max_len=100, seed=None):
    """
    Endlessly yield random (sequence, query digits, counts) samples.

    Each sample consists of:
      * ``x``       - uint8 sequence of random digits in [1, 9]; its length is
                      drawn uniformly from [1, max_len];
      * ``context`` - two uint8 query digits, each drawn from [1, 9]
                      (the two may be equal);
      * ``y``       - uint8 counts of how often each query digit occurs in ``x``.

    Args:
        max_len: Maximum sequence length (inclusive). Must lie in [1, 255]
            because the counts in ``y`` are stored as uint8.
        seed: Optional integer seed for a private ``np.random.RandomState`` to
            make the stream reproducible. ``None`` (the default) draws from the
            global ``np.random`` state, as before.

    Yields:
        Tuple ``(x, context, y)`` of ``tf.Tensor`` objects.

    Raises:
        ValueError: if ``max_len`` is outside [1, 255]. Because this is a
            generator, the error surfaces on the first ``next()`` call.
    """
    uint8_max = np.iinfo(np.uint8).max
    if not 1 <= max_len <= uint8_max:
        raise ValueError(
            f"max_len must be in [1, {uint8_max}] so counts fit in uint8, got {max_len}")

    rng = np.random if seed is None else np.random.RandomState(seed)
    while True:
        length = rng.randint(1, max_len + 1)
        x = rng.randint(1, 10, length, dtype=np.uint8)
        context = rng.randint(1, 10, 2, dtype=np.uint8)
        # Each count is at most len(x) <= max_len <= 255, so uint8 cannot overflow.
        y = np.array([np.count_nonzero(x == value) for value in context], dtype=np.uint8)

        yield tf.constant(x), tf.constant(context), tf.constant(y)
        
# Element spec of data_pair_gen: (variable-length digit sequence, the two query
# digits, their two counts) -- all uint8, matching what the generator yields.
SAMPLE_DTYPES = (tf.uint8, tf.uint8, tf.uint8)
SAMPLE_SHAPES = ((None,), (2,), (2,))

if tuple(int(part) for part in tf.__version__.split(".")[:2]) >= (2, 4):
    # tf>=2.4: output_signature supersedes the deprecated output_types/output_shapes.
    ds = tf.data.Dataset.from_generator(
        generator=data_pair_gen,
        output_signature=tuple(tf.TensorSpec(shape=shape, dtype=dtype)
                               for shape, dtype in zip(SAMPLE_SHAPES, SAMPLE_DTYPES)),
    )
else:
    # tf==2.2.0 / 2.3: only output_types/output_shapes are available.
    ds = tf.data.Dataset.from_generator(generator=data_pair_gen,
                                        output_types=SAMPLE_DTYPES,
                                        output_shapes=SAMPLE_SHAPES)

print(list(ds.take(1)))

[(<tf.Tensor: shape=(79,), dtype=uint8, numpy=
array([5, 2, 9, 3, 9, 4, 6, 8, 1, 6, 3, 4, 3, 5, 9, 7, 5, 6, 9, 8, 4, 2,
       4, 8, 2, 4, 1, 2, 4, 3, 9, 3, 7, 8, 5, 4, 4, 5, 1, 6, 8, 8, 5, 9,
       2, 1, 7, 3, 2, 2, 6, 1, 1, 9, 6, 8, 6, 5, 5, 5, 8, 4, 5, 8, 4, 6,
       6, 1, 5, 6, 5, 4, 3, 7, 8, 4, 1, 1, 9], dtype=uint8)>, <tf.Tensor: shape=(2,), dtype=uint8, numpy=array([1, 7], dtype=uint8)>, <tf.Tensor: shape=(2,), dtype=uint8, numpy=array([9, 4], dtype=uint8)>)]


# 2. Model
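To unroll the network, picture copies of the same cell (sharing the same weights) placed side by side, one per timestep: each copy receives the input of its own timestep together with the state produced by the previous copy. First implement the unrolling with a plain Python for loop, then change it to graph mode.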

In [None]:
class LSTM_cell(tf.keras.layers.Layer):
    """
    A single LSTM cell that advances the recurrent state by one timestep.

    The cell keeps no per-sequence state on ``self``: the previous
    ``(hidden_state, cell_state)`` pair is passed in explicitly and the new pair
    can be returned (``return_state=True``). The same cell and its weights can
    therefore be reused at every step of an unrolled loop.

    Args:
        hidden_dim: Size ``h`` of both the hidden state and the cell state.
    """
    def __init__(self, hidden_dim = 1):
        super(LSTM_cell, self).__init__()
        self.h = hidden_dim  # dimension of cell state and hidden state

    def _gate_params(self, name, input_dim, bias_initializer):
        """Create the (h, h + d) weight matrix and (h,) bias of one gate."""
        w = self.add_weight(name=f"w_{name}",
                            shape=(self.h, self.h + input_dim),
                            initializer=tf.random_normal_initializer(),
                            trainable=True)
        b = self.add_weight(name=f"b_{name}",
                            shape=(self.h,),
                            initializer=bias_initializer,
                            trainable=True)
        return w, b

    def build(self, input_shape):
        """
        Create the gate parameters once the input feature size is known.

        Args:
            input_shape: Shape of ``input``; its last entry is the feature size ``d``.

        Raises:
            ValueError: if the last dimension of the input is unknown.
        """
        # input_shape is a full shape such as (batch, d), not an int, so the
        # feature size has to be read from its last axis.
        input_dim = tf.TensorShape(input_shape)[-1]
        if input_dim is None:
            raise ValueError("The last dimension of the LSTM_cell input must be known.")
        input_dim = int(input_dim)

        # forget gate; its bias is initially 1 (a common choice that favours
        # keeping the cell state early in training)
        self.w_f, self.b_f = self._gate_params(
            "f", input_dim, tf.keras.initializers.Constant(value=1.0))
        # input gate
        self.w_i, self.b_i = self._gate_params("i", input_dim, tf.random_normal_initializer())
        # candidate layer
        self.w_c, self.b_c = self._gate_params("c", input_dim, tf.random_normal_initializer())
        # output gate
        self.w_o, self.b_o = self._gate_params("o", input_dim, tf.random_normal_initializer())
        super(LSTM_cell, self).build(input_shape)

    def call(self, input, states=None, return_state=False):
        """
        Advance the LSTM by one timestep.

        Args:
            input: Input ``x_t`` of shape (batch, d).
            states: Pair ``(hidden_state, cell_state)`` from the previous step,
                each of shape (batch, h). ``None`` starts from all-zero states.
            return_state: If True, also return the new ``(hidden_state, cell_state)``
                pair so the caller can feed it into the next step.

        Returns:
            The new hidden state ``h_t`` of shape (batch, h), or
            ``(h_t, (h_t, c_t))`` when ``return_state`` is True.
        """
        if states is None:
            zeros = tf.zeros((tf.shape(input)[0], self.h), dtype=input.dtype)
            states = (zeros, zeros)
        hidden_state, cell_state = states

        # [h_{t-1}, x_t] joined along the feature axis (not the batch axis): (batch, h + d)
        concat_input = tf.concat([hidden_state, input], axis=-1)

        def affine(w, b):
            # W @ [h_{t-1}, x_t] + b for every row of the batch. w is (h, h + d),
            # so multiplying by its transpose yields (batch, h) and b broadcasts per row.
            return tf.linalg.matmul(concat_input, w, transpose_b=True) + b

        # f_t = sigmoid(W_f @ [h_{t-1}, x_t] + b_f), likewise for i_t and o_t
        f_t = tf.keras.activations.sigmoid(affine(self.w_f, self.b_f))
        i_t = tf.keras.activations.sigmoid(affine(self.w_i, self.b_i))
        o_t = tf.keras.activations.sigmoid(affine(self.w_o, self.b_o))
        # candidates for the new cell state use tanh instead of sigmoid
        c_tilde_t = tf.keras.activations.tanh(affine(self.w_c, self.b_c))

        # C_t = f_t * C_{t-1} + i_t * C_tilde_t
        new_cell_state = tf.math.multiply(f_t, cell_state) + tf.math.multiply(i_t, c_tilde_t)
        # h_t = o_t * tanh(C_t)
        new_hidden_state = tf.math.multiply(o_t, tf.keras.activations.tanh(new_cell_state))

        if return_state:
            return new_hidden_state, (new_hidden_state, new_cell_state)
        return new_hidden_state

In [None]:
class LSTM_net(tf.keras.Model):
    '''
    Build a LSTM net with a single recurrent node.

    Layers: a ReLU Dense read-in (100 units) -> LSTM_cell -> a single-unit
    sigmoid Dense read-out (logistic classification).

    Args:
        hidden_dim: Size of the LSTM cell's hidden and cell state.
    '''
    def __init__(self, hidden_dim=1):
        super(LSTM_net, self).__init__()
        # readin layer dim is subject to change depending on data structure.
        # NOTE(review): input_shape only takes effect for the first layer of a
        # Sequential model; in a subclassed Model the layer is built from the
        # first input it actually receives -- confirm it is still wanted.
        self.readin = tf.keras.layers.Dense(100, activation='relu', input_shape=(3,))
        self.recurrent = LSTM_cell(hidden_dim)
        # logistic classification
        self.readout = tf.keras.layers.Dense(1, activation='sigmoid')

    def call(self, x):
        """
        Pass ``x`` once through read-in, the recurrent cell and read-out
        (no unrolling over time yet).

        Args:
            x: Input features for the read-in layer.

        Returns:
            Output of the sigmoid read-out layer.
        """
        x = self.readin(x)
        x = self.recurrent(x)
        x = self.readout(x)
        return x

    # TODO: unroll the network