[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/sascha-senko/TensorflowCourse/blob/HSinger04/ANNwTFHW9.ipynb)

## Global TODO: Only for Hermann

* Add the 2 digits as inputs to each datum
* Add labels to each datum

In [115]:
import matplotlib.pyplot as plt
import numpy as np
import sys
import random
%load_ext tensorboard
%tensorflow_version 2.x
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D, Flatten, Dense, Conv2DTranspose, \
 Reshape, MaxPooling2D, Dropout, BatchNormalization, UpSampling2D, ReLU, \
 ELU, Layer
from tensorflow import debugging as debug
import tensorflow_probability as tfp
from functools import partial

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


## Define some constants for dataset

In [116]:
# arbitrarily set. Feel free to change these
DATA_SIZE = 10000
SEQ_SIZE = 25
SHUFFLE_SIZE = DATA_SIZE
PREFETCH_SIZE = tf.data.experimental.AUTOTUNE
BATCH_SIZE = 32

## Define Dataset

In [117]:
def digit_sequence(data_size, size=SEQ_SIZE):
    num = 0
    while num < data_size:
        # yield the two context digits, the sequence and the label
        seq = np.random.randint(10, size=size)
        seq = tf.one_hot(seq, 10, axis=-1)   

        yield seq 
        num += 1

x_train = tf.data.Dataset.from_generator(digit_sequence, args=[DATA_SIZE], output_signature=tf.TensorSpec((25,10)))

## Prepare data

In [118]:
def get_context_and_labels(seq_and_context):
    seq, context = seq_and_context
    # get all unique digits of sequence
    digits, _ = tf.unique(seq)
    # context digits
    context = tf.random.shuffle(digits)[:2]

    # now get the labels
    labels = tf.TensorArray(dtype=tf.float32, size=SEQ_SIZE)

    # counts how much more often the first context digit was observed over the second
    first_vs_second_occurance = 0

    for i in tf.range(SEQ_SIZE):
        digit = seq[i]

        if digit == context[0]:
            first_vs_second_occurance += 1
        elif digit == context[1]:
            first_vs_second_occurance -= 1

        if first_vs_second_occurance >= 0:        
            labels.write(i, 0)
        else:
            labels.write(i, 1)    

    # gather returns Tensor from TensorArray
    return tf.ragged.constant(context, labels.gather(tf.range(SEQ_SIZE)))

def get_context(seq):
    # get all unique digits of sequence
    digits, _ = tf.unique(seq)
    # context digits
    context = tf.random.shuffle(digits)[:2]

    return context

def get_labels(seq_and_context):
    seq = seq_and_context[:25]
    context = seq_and_context[25:]

    # now get the labels
    labels = tf.TensorArray(dtype=tf.int16, size=SEQ_SIZE)

    # counts how much more often the first context digit was observed over the second
    first_vs_second_occurance = 0

    for i in tf.range(SEQ_SIZE):
        digit = seq[i]
        if digit == context[0]:
            first_vs_second_occurance += 1
        elif digit == context[1]:
            first_vs_second_occurance -= 1

        if first_vs_second_occurance >= 0:        
            labels.write(i, 0)
        else:
            labels.write(i, 1)    

    # gather returns Tensor from TensorArray
    return labels.gather(tf.range(SEQ_SIZE))

def data_pipeline(data):
    """ 
    helper function for data pipeline - does all the things we need 
    
    No shuffling needed, as generator creates random sequence everytime
    """

    data = data.batch(BATCH_SIZE)
    return data

x_train = data_pipeline(x_train)

## LSTM Cell

In [119]:
class LSTM_Cell(tf.keras.layers.Layer):
    def __init__(self, hidden_size):
        super(LSTM_Cell, self).__init__()

        self.hidden_size = hidden_size
        #gates
        self.input_gate = Dense(hidden_size, activation="sigmoid")
        # setting forget bias to one initially is important, 
        # probably because the very first hidden and cell state that gets fed in 
        # call is just a dummy zero vector and doesn't provide any information  
        self.forget_gate = Dense(hidden_size, bias_initializer='ones', activation="sigmoid")
        self.output_gate = Dense(hidden_size, activation="sigmoid")
        self.cell_state_candidates = Dense(hidden_size, activation="tanh")

    #@tf.function
    def call(self, input, hidden_state, cell_state):
        # x is 1-D
        # TODO: axis -1 or 1?
        print("hidden_state's shape")
        print(hidden_state.shape)
        print("input's shape")
        print(input.shape)
        concat_input = tf.concat([hidden_state, input], axis=-1) 
        new_cell_state = cell_state * self.forget_gate(concat_input) 
        new_cell_state += self.input_gate(concat_input) * self.cell_state_candidates(concat_input)
        new_hidden_state = tf.keras.activations.tanh(cell_state) * self.output_gate(concat_input) # new hidden state is also output
        return new_cell_state, new_hidden_state      

## LSTM

In [120]:
class LSTM(Model):
    def __init__(self, hidden_size):
        super(LSTM, self).__init__()
        self.hidden_size = hidden_size
        self.LSTM_Cell = LSTM_Cell(hidden_size)
        # < 0.5 for first context digit occuring more often, >= 0.5 for second
        self.Classifier = Dense(1, activation='sigmoid')
        
    # TODO: Let's see if call works, as I am initializing a state  
    # if I end up using tf.Variable, make sure it's untrainable
    #@tf.function
    def call(self, x):
        results = tf.TensorArray(tf.float32, size=SEQ_SIZE)

        hidden_state = tf.zeros((BATCH_SIZE, self.hidden_size))
        cell_state = tf.zeros((BATCH_SIZE, self.hidden_size))
        
        # TODO: check if it's zero everytime right here with debug
        for index in range(SEQ_SIZE):
            digit = x[:,index,:]
            # TODO: check if compatible with tf.function
            print("digit's shape")
            print(digit.shape)
            cell_state, hidden_state = self.LSTM_Cell(digit, hidden_state, cell_state)
            results.write(index, hidden_state)

        output = self.Classifier(tf.transpose(results.stack(), perm=[1,0,2]))

        return output

In [121]:
model = LSTM(10)
for input in x_train:
    print(input.shape)
    print(model(input).shape)
    break

(32, 25, 10)
digit's shape
(32, 10)
hidden_state's shape
(32, 10)
input's shape
(32, 10)
Object was never used (type <class 'tensorflow.python.ops.tensor_array_ops.TensorArray'>):
<tensorflow.python.ops.tensor_array_ops.TensorArray object at 0x7f17ae24cd30>
If you want to mark it as used call its "mark_used()" method.
It was originally created here:
  File "/usr/local/lib/python3.6/dist-packages/IPython/core/interactiveshell.py", line 2882, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)  File "<ipython-input-121-23cc79fdadeb>", line 4, in <module>
    print(model(input).shape)  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/base_layer.py", line 1012, in __call__
    outputs = call_fn(inputs, *args, **kwargs)  File "<ipython-input-120-ed3c92a9a684>", line 25, in call
    results.write(index, hidden_state)  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/util/tf_should_use.py", line 249, in wrapped
    error_in_function=error

## Define some constants

In [None]:
NUM_EPOCHS = 10
LEARNING_RATE = 0.0001   
OPTIMIZER = tf.keras.optimizers.Adam(LEARNING_RATE)
BCE = tf.keras.losses.BinaryCrossentropy() 
NUM_BATCHES = (int(x_train.cardinality()))
HIDDEN_SIZE = 10

import datetime
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
lstm_log_dir = 'logs/gradient_tape/' + current_time + '/lstm'
train_writer = tf.summary.create_file_writer(lstm_log_dir)

In [None]:
@tf.function
def train_step(model, inputs, labels, optimizer):
    # loss_object and optimizer_object are instances of respective tensorflow classes
    with tf.GradientTape() as tape:
        prediction = model(inputs)
        loss = BCE(labels, prediction)
        gradients = tape.gradient(loss, model.trainable_variables)

    # update weights  
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    prediction = tf.cast(tf.math.round(prediction), labels.dtype)
    correct = tf.math.equal(prediction, correct)

    return loss, correct

## Helper functions

In [None]:
# TODO: Comment about BPTT
# TODO: Regression or classification problem?
# TODO: args
# TODO: weight update still incorrect. Remember that we have an output for each part
@tf.function
def train_step(model, inputs, labels, optimizer):
    # loss_object and optimizer_object are instances of respective tensorflow classes
    with tf.GradientTape() as tape:
        prediction = model(inputs)
        loss = BCE(labels, prediction)
        gradients = tape.gradient(loss, model.trainable_variables)

    # update weights  
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    prediction = tf.cast(tf.math.round(prediction), labels.dtype)
    correct = tf.math.equal(prediction, correct)

    return loss, correct

# TODO: epoch needs to be tf.function compatible
#@tf.function
def one_epoch(model, optimizer, loss_tracker, accuracy_tracker, train_data, epoch):
    # reset statistics
    loss_tracker.reset_states()
    accuracy_tracker.reset_states()

    # TODO: not just input, but also other things
    for inputs in train_data:

        #context = tf.map_fn(get_context_and_labels, inputs, fn_output_signature=tf.TensorSpec((25, ), tf.float32))
        # TODO: add labels back
        context = tf.map_fn(get_context, inputs)
        # tried to create labels_inp more cleanly separated (e.g. tuple),
        # but due to the mysterious nature of map_fn, 
        # I had to fall back on this ugly fix
        labels_inp = tf.concat((inputs, context), 1)
        labels = tf.map_fn(get_labels, labels_inp)
        # TODO: remove
        print(inputs[0])
        print(context[0])
        print(labels[0])
        break
        
        # TODO: dunno if recursive tf.function worked so well
        loss, accuracy = train_step(model, inputs, labels, optimizer)

        loss_tracker.update_state(loss)
        accuracy_tracker.update_state(accuracy)

    # Write statistics into summary
    with train_writer.as_default():
        tf.summary.scalar('loss', loss_tracker.result(), step=epoch)
        tf.summary.scalar('accuracy', accuracy_tracker.result(), step=epoch)

## Train

In [None]:
# Clear any logs from previous runs
%rm -rf ./logs/

# remove all active models for memory purposes
tf.keras.backend.clear_session()

model = LSTM(HIDDEN_SIZE)

loss_tracker = tf.keras.metrics.Mean()
accuracy_tracker = tf.keras.metrics.Mean()

for epoch in range(NUM_EPOCHS):
    print('Epoch: ' + str(epoch+1))
    one_epoch(model, OPTIMIZER, loss_tracker, accuracy_tracker, x_train, epoch)

In [None]:
# Open tensorboard
%tensorboard --logdir logs/gradient_tape