# Model Architecture Rewrite

This notebook attempts to rewrite the NTaskModel class to utilize overridable methods to retain all Keras model properties and features.

**Note** Due to what seems like some sort of bug in Keras, this rewrite is currently not possible

In [335]:
from tensorflow.python.keras.engine import data_adapter
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras import Model
import tensorflow as tf

import numpy as np
import random

In [503]:
class NTaskModel(Model):
    
    def compile(self, *args, **kwargs):
        super(NTaskModel, self).compile(*args, **kwargs)
        self.context_layers = []
        self.my_test_count = 0
        
        
    def _calc_context_loss(self, context_layer_idx, gradients):
        """
        IMPORTANT: 
        1) Assumes no use of activation function on Ntask layer
        2) Assumes that the layer following the Ntask layer:
            a) Is a Dense layer
            b) Is using bias 
               — ex: Dense(20, ... , use_bias=True) 
               — note Keras Dense layer uses bias by default if no value is given for use_bias param
        3) Assumes index of the next layer's gradient is known within the gradients list returned from gradient tape in a tape.gradient call
        4) If the above points aren't met, things will break and it may be hard to locate the bugs
        """
        # From the delta rule in neural network math
        delta_at_next_layer = gradients[context_layer_idx + 1]
        transpose_of_weights_at_next_layer = tf.transpose(self.layers[context_layer_idx + 1].get_weights()[0])
        
        # Calculate delta at n-task layer
        context_delta = np.dot(delta_at_next_layer, transpose_of_weights_at_next_layer).astype(np.float)
        
        # Calculate Context Error
        # Keras MSE must have both args be arrs of floats, if one or both are arrs of ints, the output will be rounded to an int
        # This is how responsible the context layer was for the loss
        return tf.keras.losses.mean_squared_error(np.zeros(len(context_delta)), context_delta)
    
        
    def _forward_pass(self, x, y, sample_weight=None):
        """
        Performs a forward pass with active switching mechanism
        x: input data
        y: expected output (required for switching mechanisms)
        """
        
        # Perform forward pass and calculate loss
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)
            
        # Extract the gradients for the loss of the current sample
        gradients = tape.gradient(loss, self.trainable_variables)

        for context_layer_idx in self.context_layers:
            self.layers[context_layer_idx].context_loss += self._calc_context_loss(context_layer_idx, gradients)
        
        # Return the calculated gradients
        return y_pred, gradients
    
    
    def train_step(self, batch):
    
        # Unpack the data
        data = data_adapter.expand_1d(batch)
        x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(batch)
    
        # Perform a forward pass and calculate gradients
        y_pred, gradients = self._forward_pass(x, y, sample_weight)
        
        # Apply the gradients to the model
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        
        # Update the metrics
        self.compiled_metrics.update_state(y, y_pred, sample_weight)
        
        self.my_test_count += 1
        tf.print("Train step", self.my_test_count)
        
        return {metric.name: metric.result() for metric in self.metrics}
    
    
    def fit(self, *args, **kwargs):
        # Method 1: Modify batch size here to size of dataset, divide in train step
        return super(NTaskModel, self).fit(*args, **kwargs)
        

In [504]:
inp = Input((2,))
x = Dense(40, activation="relu")(inp)
x = Dense(1, activation="sigmoid")(x)
model = NTaskModel(inputs=inp, outputs=x)

In [505]:
model.compile(optimizer=tf.keras.optimizers.Adam(1e-4), loss=tf.keras.losses.BinaryCrossentropy(), metrics=["accuracy"])

In [506]:
x_train = np.array([[1, 1], [-1, 1], [2, 2], [3, 3]]) # 4 inputs
y_train = np.array([[1], [0], [2], [3]])              # 4 labels

In [507]:
model.fit(x_train, y_train, batch_size=1, shuffle=False, verbose=0)

Train step 2
Train step 2
Train step 2
Train step 2


<tensorflow.python.keras.callbacks.History at 0x7fdf04f210d0>

In [502]:
model.my_test_count

2

In [413]:
model.layers[2].get_weights()

[array([[-0.3652719 ],
        [-0.14661784],
        [ 0.1964958 ],
        [-0.05448712],
        [-0.0798344 ],
        [ 0.07384975],
        [ 0.23355207],
        [ 0.16030094],
        [ 0.24935973],
        [-0.27872664],
        [ 0.38130707],
        [ 0.17336681],
        [-0.10716698],
        [-0.14784165],
        [-0.22933787],
        [-0.0591571 ],
        [ 0.03553462],
        [-0.20349382],
        [ 0.07389298],
        [-0.30909204],
        [ 0.30697566],
        [-0.03409614],
        [-0.06402037],
        [ 0.07586429],
        [ 0.04679683],
        [ 0.08286428],
        [ 0.31718922],
        [-0.1294789 ],
        [-0.3269272 ],
        [ 0.03682972],
        [-0.24866512],
        [ 0.19342953],
        [ 0.35329303],
        [ 0.37803873],
        [-0.16188507],
        [-0.02978953],
        [ 0.01155861],
        [ 0.23300216],
        [-0.34957296],
        [-0.25856268]], dtype=float32),
 array([0.00021756], dtype=float32)]