# Model Source

This notebook experiments with re-implementing the model's source code, inserting various event handlers so that the model can be extended to n-task.

In [1]:
%load_ext tensorboard

In [2]:
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, Layer
from tensorflow.keras import Model

from collections import defaultdict
from enum import IntFlag, auto
import datetime
import math
import numpy as np
import random
import time
import os

from utils import idx_load

In [3]:
def hrr(length, normalized=True):
    """
    Generate a single random HRR vector as a real-valued TensorFlow tensor.

    Args:
        length: Number of elements in the vector (coerced with `int`).
        normalized: When True, the vector is built in the frequency domain
            from unit-magnitude coefficients with random phases arranged so
            the spectrum is conjugate-symmetric, then brought back with an
            inverse FFT and its real part is returned. When False, the vector
            is sampled from a normal distribution with mean 0 and standard
            deviation 1/sqrt(length).

    Returns:
        A tensor with `length` elements.

    Note:
        Both sampling calls pass a fixed operation-level `seed=100`.
    """
    length = int(length)

    # Non-normalized vectors are plain Gaussian samples
    if not normalized:
        return tf.random.normal(
            shape=(length,),
            mean=0.0,
            stddev=1.0/tf.sqrt(float(length)),
            dtype=tf.dtypes.float32,
            seed=100,
            name=None)

    # Number of independent phases (excludes the DC term and, for even
    # lengths, the middle term, both of which are fixed to 1)
    num_phases = int((length-1)/2)
    phases = tf.random.uniform(
        shape=(num_phases,),
        minval=-np.pi,
        maxval=np.pi,
        dtype=tf.dtypes.float32,
        seed=100,
        name=None)
    phases = tf.cast(phases, tf.complex64)

    one = tf.ones(1, dtype="complex64")
    forward_half = tf.exp(1j*phases)
    mirrored_half = tf.exp(-1j*phases[::-1])

    if length % 2:
        spectrum = [one, forward_half, mirrored_half]
    else:
        # Even lengths carry one extra fixed coefficient in the middle
        spectrum = [one, forward_half, one, mirrored_half]

    return tf.math.real(tf.signal.ifft(tf.concat(spectrum, axis=0)))


def hrrs(length, n=1, normalized=True):
    """
    Generate `n` HRR vectors and stack them along a new leading axis.

    Args:
        length: Number of elements in each vector.
        n: How many vectors to generate.
        normalized: Forwarded to `hrr` for every vector.

    Returns:
        A tensor whose first axis indexes the generated vectors.
    """
    vectors = []
    for _ in range(n):
        vectors.append(hrr(length, normalized))
    return tf.stack(vectors, axis=0)


def circ_conv(x, y):
    """
    Circularly convolve two HRR vectors.

    Both operands are cast to complex64, multiplied element-wise in the
    frequency domain, and the real part of the inverse FFT is returned.
    """
    spectrum_x = tf.signal.fft(tf.cast(x, tf.complex64))
    spectrum_y = tf.signal.fft(tf.cast(y, tf.complex64))
    return tf.math.real(tf.signal.ifft(spectrum_x*spectrum_y))


def logmod(x):
    """
    Signed logarithmic scaling: `sign(x) * log(1 + |x|)`.

    Args:
        x: A scalar or array-like of real numbers (plain Python lists are
            accepted as well as NumPy arrays).

    Returns:
        The transformed value(s); zero maps to zero and the sign of each
        element is preserved.
    """
    # np.abs (rather than the builtin abs) accepts plain sequences, and
    # log1p keeps precision when |x| is very small.
    return np.sign(x)*np.log1p(np.abs(x))
    
    
def plot(title, labels, *frameGroups):
    """
    Draw every frame group on a fresh figure with a grid and a legend.

    Args:
        title: Title placed on the axes.
        labels: Sequence of legend labels, one per frame group (or a falsy
            value for no labels).
        *frameGroups: The series to draw; forwarded to `plotFrames`.

    Note:
        Relies on `plt` (matplotlib.pyplot) being available in the notebook
        namespace; it is not imported in the import cell shown above.
    """
    _, axes = plt.subplots()
    plotFrames(axes, title, labels, *frameGroups, xlabel="Epoch", ylabel="Value")
    axes.grid()
    # The axes created above are the current axes, so this is the legend
    # that pyplot would have attached
    axes.legend()
    
    
def plotFrames(ax, title, labels, *frameGroups, xlabel=None, ylabel=None):
    """
    Plot one line per frame group on the given axes.

    Args:
        ax: Object exposing `plot(x, y, label=...)` and `set(...)`
            (e.g. a matplotlib Axes).
        title: Title applied to the axes.
        labels: Sequence of labels indexed by frame-group position, or a
            falsy value to plot without labels.
        *frameGroups: Each group is either a dict (or dict subclass) mapping
            x-positions to values, or a plain sequence of values which is
            plotted against the indices 0..len-1.
        xlabel: Optional x-axis label.
        ylabel: Optional y-axis label.
    """
    for i, frames in enumerate(frameGroups):
        if isinstance(frames, dict):
            # Use the keys themselves as x positions so sparse or unordered
            # keys stay aligned with their values
            xs = list(frames.keys())
            ys = list(frames.values())
        else:
            # Sequences have no `.values()`; plot them against their indices
            ys = list(frames)
            xs = list(range(len(ys)))
        ax.plot(xs, ys, label=(labels[i] if labels else None))
    ax.set(xlabel=xlabel, ylabel=ylabel, title=title)

In [4]:
class Verbosity(IntFlag):
    """
    Bit flags for the `verbose` argument; combine members with `|`.
    """
    # Checked in `fit` to decide whether a progress bar callback is added
    Progress = 1
    # Checked by the ATR model before printing context switch/add messages
    Contexts = 2

In [52]:
class AtrModel:
    """
    Base ATR model that drives context switching for a `Context` layer.

    One ATR value is stored per context. After each pass over the data the
    context loss of the hot (active) context is compared against its ATR
    value; when `ATR value - context loss` falls below the switch threshold
    a context switch is requested, otherwise the ATR value is updated.

    The methods under "Overridable" define the policy and are intended to be
    replaced by subclasses.
    """
    
    def __init__(self, switch_threshold, add_threshold=0.0, max_contexts=0):
        """
        Args:
            switch_threshold: A switch is triggered when the delta
                (ATR value minus context loss) is below this value.
            add_threshold: A new context is added (instead of reusing the
                best fit) when the delta checked in `should_add_context` is
                below this value. `None` is treated as 0.0.
            max_contexts: Upper bound on the number of contexts. Dynamic
                adding is only attempted when this is greater than 0.
        """
        
        self.switch_threshold = tf.Variable(switch_threshold, name="Switch_Threshold", trainable=False, dtype=tf.float32)
        self.add_threshold = tf.Variable(add_threshold or 0.0, name="Add_Threshold", trainable=False, dtype=tf.float32)
        self._max_contexts = max_contexts
        self._context_layer = None
        
        # Track the number of sequential switches so we can determine if no tasks fit the threshold
        self._num_seq_switches = tf.Variable(0, name="Sequential_Switches", trainable=False, dtype=tf.int64)
        
        # Indicate the epoch at which the last switch occurred
        self.epoch_switched = tf.Variable(-1, name="Last_Switch_Epoch", trainable=False, dtype=tf.int64)
        
        # If we tried to add another context after the max number of contexts was reached, we should warn the user
        # and stop checking whether or not a context should be added
        self._exceeded_context_limit = tf.Variable(False, name="Context_Limit_Exceeded", trainable=False, dtype=tf.bool)
        
        # Track the context-loss delta for the active context
        self.delta = tf.Variable(0.0, name="ATR_Delta", trainable=False, dtype=tf.float32)
        
        # Store the delta that triggered the initial context switch
        self.delta_switched = tf.Variable(0.0, name="ATR_Delta_Switched", trainable=False, dtype=tf.float32)
        
        # To be built...
        self.values = None
        self.values_initialized = None
        
        
    def set_context_layer(self, context_layer):
        """Attach the context layer whose hot context and losses this model manages."""
        self._context_layer = context_layer
    
    
    def build(self, num_contexts):
        """
        Allocate the per-context ATR storage.

        Args:
            num_contexts: Number of contexts the layer starts with. The
                storage is sized to the larger of this and `max_contexts`.
        """
        # Determine the number of contexts to create.
        # Since we can't yet dynamically add contexts, we need
        # to create the list at its max size initially.
        num_contexts = max(num_contexts, self._max_contexts)
        
        # Create the list of ATR values to track
        self.values = tf.Variable(np.zeros(num_contexts), name="ATR_Values", trainable=False, dtype=tf.float32)
        
        # A second list is created to determine uninitialized ATR values
        self.values_initialized = tf.Variable(np.zeros(num_contexts), name="ATR_Values_Initialized", trainable=False, dtype=tf.bool)
        
    
    def add_context(self):
        """Ask the context layer to make one more context available."""
        # Add the new context to the context layer
        self._context_layer.add_context()
        
        
    def switch_contexts(self, context_loss, verbose):
        """
        Choose the next hot context after a switch has been triggered.

        While fewer sequential switches than contexts have happened, the
        layer simply advances to the next context. Once every context has
        been tried, either a new context is added (when allowed and
        `should_add_context` agrees) or the best-fitting context is selected
        and its ATR value is reset from its stored context loss.

        Args:
            context_loss: Context loss of the context being switched away from.
            verbose: `Verbosity` flags; `Verbosity.Contexts` enables messages.
        """
        
        # If we have exhausted the context list, look for the one with the best fit
        if self._num_seq_switches >= self.num_contexts:
            best_fit = self.find_best_fit_context()
            
            # If no context really fits well and we can add more contexts, add a new one
            if self.max_num_contexts > 0 and not self._exceeded_context_limit and self.should_add_context(context_loss, best_fit):
                if self.num_contexts < self.max_num_contexts:
                    self.add_context()
                    self.hot_context = self.num_contexts - 1
                    if verbose & Verbosity.Contexts:
                        tf.print(f"\n[{self.context_layer.name}] Adding context {self.hot_context}")
                else:
                    # Remember that the limit was hit so the add check is skipped from now on.
                    # The hot context is left unchanged in this branch.
                    self._exceeded_context_limit.assign(True)
                    tf.print(f"\n[{self.context_layer.name}] WARNING: Attempted to add context after context limit reached")
                
            # Use the best fit context
            else:
                if verbose & Verbosity.Contexts:
                    tf.print(f"\nUsing best-fit context {best_fit}")
                # Switch to the best-fitting context
                self.hot_context = best_fit
                
                # Before the ATR value is updated...
#                 self.on_before_update(context_loss)
                
                # Update the ATR value for the new context
                self.update_atr_value(self.context_losses[self.hot_context], switched=True)

        else:
            self.context_layer.next_context()
                
    
    def update_and_switch(self, epoch, context_loss, dynamic_switch, verbose):
        """
        Either switch contexts or update the hot context's ATR value.

        Args:
            epoch: Index of the current epoch.
            context_loss: Accumulated context loss of the hot context.
            dynamic_switch: When False, switching is never attempted.
            verbose: `Verbosity` flags.

        Returns:
            False when a context switch was performed (nothing was updated);
            True when the ATR value was updated without switching.
        """
        if dynamic_switch and self.should_switch(epoch, context_loss):
            
            # Before we switch...
            self.on_before_switch(epoch, context_loss)
            
            # Count the switches
            self._num_seq_switches.assign_add(1)
            self.epoch_switched.assign(epoch)
            
            # Switch contexts and return the result
            self.switch_contexts(context_loss, verbose)
            
            # Switched, so nothing was updated
            return False
        
        # Before the ATR value is updated...
        self.on_before_update(context_loss)
            
        self.update_atr_value(context_loss, switched=False)
            
        # Reset the switch count if we previously switched
        if self._num_seq_switches != 0:
            self._num_seq_switches.assign(0)
            if verbose & Verbosity.Contexts:
                tf.print(f"\n[{self.context_layer.name}] Switched context to {self.hot_context}")
            
        # Updated successfully
        return True
    
    
    def set_atr_value(self, context_loss):
        """Store `context_loss` as the hot context's ATR value and mark that slot as initialized."""
        self.values.scatter_nd_update([[self.hot_context]], [context_loss])
        if not self.values_initialized[self.hot_context]:
            self.values_initialized.scatter_nd_update([[self.hot_context]], [True])

    # Event Handlers ------------------------------------------------------------------------------
    
    def on_before_switch(self, epoch, context_loss):
        """Record the delta that triggered a switch; only the first switch within an epoch is recorded."""
        if epoch != self.epoch_switched:
            delta = self.values[self.hot_context] - context_loss
            self.delta_switched.assign(delta)
            self.delta.assign(delta)
            
    def on_before_update(self, context_loss):
        """Record the current delta, provided the hot context already has an ATR value."""
        if self.values_initialized[self.hot_context]:
            delta = self.values[self.hot_context] - context_loss
            self.delta.assign(delta)
            
    # Overridable ---------------------------------------------------------------------------------
    
    def context_loss_fn(self, context_delta):
        """Reduce a context delta to a context loss: the mean squared error of the delta against zeros."""
        # Calculate Context Error
        # Keras MSE must have both args be arrs of floats, if one or both are arrs of ints, the output will be rounded to an int
        # This is how responsible the context layer was for the loss
        return tf.keras.losses.mean_squared_error(np.zeros(len(context_delta)), context_delta)
    
    def update_atr_value(self, context_loss, switched):
        """
        Update the ATR value of the hot context.

        The base policy overwrites the value with `context_loss`; `switched`
        is unused here and exists for subclasses.
        """
        # Update the ATR value
        self.set_atr_value(context_loss)
    
    def find_best_fit_context(self):
        """Return the index, among the active contexts, with the largest `ATR value - context loss`."""
        return tf.argmax(tf.subtract(self.values, self.context_losses)[:self.num_contexts])
    
    def should_switch(self, epoch, context_loss):
        """Return whether the hot context's delta is below the switch threshold (never for an uninitialized context)."""
        # If the ATR value has not been initialized yet, we don't need to switch
        if not self.values_initialized[self.hot_context]:
            return False
        # If the context loss exceeds the threshold
        delta = self.values[self.hot_context] - context_loss
        return delta < self.switch_threshold
    
    def should_add_context(self, context_loss, best_fit_context_idx):
        """
        Determine if a new context should be added
        Note: This is only checked after a switch has been determined
        """
        # NOTE(review): the ATR value is read at the hot context while the loss is read at the
        # best-fit context; `find_best_fit_context` pairs both at the same index — confirm intended.
        delta = self.values[self.hot_context] - self.context_losses[best_fit_context_idx]
        return delta < self.add_threshold
    
    def epoch_traces(self, epoch):
        """
        Return a dictionary of traces to plot

        Only contexts whose ATR value has been initialized contribute an
        entry to "ATR Traces".

        Note: `trace` is not defined in the cells shown with this class and
        must be available in the notebook namespace.
        """
        return {
            "ATR Traces": [
                trace(f"Context {i}", v) for i, v in enumerate(self.values.value())
                      # Test the flag's value; an identity check against None is always true for a tensor
                      if self.values_initialized[i]
            ],
            "Delta Trace": [
                trace("Switch Threshold", self.switch_threshold.value(), '--', 'grey'), # Dark grey is lighter than grey...
                trace("Add Threshold", self.add_threshold.value(), '-.', 'grey', condition=self.max_num_contexts>0),
                trace("Context Delta", self.delta_switched.value(), '-', condition=self.epoch_switched==epoch),
                trace("Context Delta", self.delta.value(), '-')
            ]
        }
        
    # Properties ----------------------------------------------------------------------------------
        
    @property
    def context_losses(self):
        """Per-context losses stored on the attached context layer."""
        return self.context_layer.context_losses
        
    @property
    def num_contexts(self):
        """Number of active contexts on the attached layer, or None when no layer is attached."""
        if self.context_layer is None:
            return None
        return self.context_layer.num_contexts
        
    @property
    def max_num_contexts(self):
        """The `max_contexts` value given at construction."""
        return self._max_contexts
    
    @property
    def context_layer(self):
        """The attached context layer (None until `set_context_layer` is called)."""
        return self._context_layer
    
    @property
    def hot_context(self):
        """Index of the active context on the attached layer."""
        return self._context_layer.hot_context
    
    @hot_context.setter
    def hot_context(self, hot_context):
        self._context_layer.hot_context = hot_context

In [47]:
class AtrMovingAverage(AtrModel):
    """ATR model whose value for a context is averaged with each new context loss."""

    def update_atr_value(self, context_loss, switched):
        """
        Update the hot context's ATR value.

        The new value is the mean of the stored value and `context_loss`,
        except right after a switch or when the context has no stored value
        yet, in which case `context_loss` is stored directly.
        """
        has_history = not switched and self.values_initialized[self.hot_context]
        if has_history:
            new_value = (self.values[self.hot_context] + context_loss) / 2.0
        else:
            new_value = context_loss
        self.set_atr_value(new_value)

In [7]:
class Context(Layer):
    """
    Keras layer that binds its input to the active ("hot") context.

    The layer holds one fixed, non-trainable HRR vector per context and
    outputs the circular convolution of the input with the hot context's
    vector. It also accumulates a context loss per context and delegates
    switching decisions to an optional ATR model.
    """
    
    def __init__(self, contexts=1, atr_model=None, **kwargs):
        """
        Args:
            contexts: Number of contexts available initially.
            atr_model: Optional ATR model that handles the switching
                mechanisms. Without one, the layer never switches on its own.
            **kwargs: Forwarded to the Keras `Layer` constructor.
        """
        super(Context, self).__init__(**kwargs)
        
        # The ATR model handles the switching mechanisms
        self._atr_model = atr_model
        
        # Information Tracking (the per-context loss variable is created in the build step)
        self._num_contexts = tf.Variable(contexts, name="Num_Contexts", trainable=False, dtype=tf.int64)
        self._hot_context = tf.Variable(0, name="Hot_Context", trainable=False, dtype=tf.int64)
        
        
    def build(self, input_shape):
        """
        Create the HRR kernel and the per-context loss storage.

        The kernel is allocated for the larger of the current number of
        contexts and the ATR model's maximum, because contexts are "added"
        later by exposing more of this pre-allocated kernel.
        """
        # Store the input shape since weights can be rebuilt later
        self._input_shape = int(input_shape[-1])
        
        # Build the ATR model, if one was provided (it is optional in the constructor)
        if self._atr_model is not None:
            self._atr_model.set_context_layer(self)
            self._atr_model.build(self.num_contexts)
            max_contexts = self._atr_model.max_num_contexts
        else:
            max_contexts = 0
        
        # The number of contexts to create in the kernel
        num_kernel_contexts = max(self.num_contexts, max_contexts)
        
        # Create the HRR initializer. This will create the list of HRR vectors
        initializer = lambda shape, dtype=None: hrrs(self._input_shape, n=num_kernel_contexts)
        self.kernel = self.add_weight(name="context", shape=[num_kernel_contexts, self._input_shape], initializer=initializer, trainable=False)
        
        # Store the context losses for each context
        self._context_loss = tf.Variable(np.zeros(num_kernel_contexts), name="Context_Losses", trainable=False, dtype=float)
        
        #TEMP
        self._max_contexts = num_kernel_contexts
        
        
    def call(self, inputs):
        """
        Calculate the output for this layer.
        
        This layer convolves the input values with the context HRR vector
        to produce the output tensor.
        """
        # Fetch the hot context's HRR vector
        context_hrr = self.kernel[self.hot_context]
        
        # Return the resulting convolution between the inputs and the context HRR
        return circ_conv(inputs, context_hrr)
    
    
    def update_and_switch(self, epoch, dynamic_switch, verbose):
        """
        Update ATR values and switch contexts if necessary.
        Returns True if no context switch occurs; False otherwise
        """
        # If there is no ATR model, there's nothing to update
        if self._atr_model is None:
            return True
        
        # Update the ATR model
        result = self._atr_model.update_and_switch(epoch, self.context_loss, dynamic_switch, verbose)
        
        # Clear the context loss when we're done.
        # Note that this clears the slot of whichever context is hot *after* the call above.
        self.clear_context_loss()
        
        # Did the ATR model update or switch?
        return result
        
    
    def add_context(self):
        """
        Make one more context available.

        Returns:
            True when the context count was increased; False when the
            pre-allocated capacity is already fully used.
        """
        # TODO Context adding: the kernel is not grown here. `build` already allocated
        # `_max_contexts` HRR vectors, so adding a context only raises the active count.
        if self._num_contexts < self._max_contexts:
            self._num_contexts.assign_add(1)
            return True
        return False
        
    
    def clear_context_loss(self):
        """Reset the accumulated context loss of the hot context to zero."""
        self._context_loss.scatter_nd_update([[self.hot_context]], [0.0])
    
    
    def add_context_loss(self, context_loss):
        """
        Accumulate context loss for the hot context.

        The incoming delta is first reduced to a loss by the ATR model's
        `context_loss_fn`, or by a mean squared error against zeros when no
        ATR model is attached.
        """
        if self._atr_model is not None:
            context_loss = self._atr_model.context_loss_fn(context_loss)
        else:
            context_loss = tf.keras.losses.mean_squared_error(np.zeros(len(context_loss)), context_loss)
        self._context_loss.scatter_nd_add([[self.hot_context]], [context_loss])
        
        
    def next_context(self):
        """Switch to the next sequential context, wrapping around after the last one"""
        self.hot_context = (self.hot_context + 1) % self.num_contexts
        
    
    @property
    def atr_model(self):
        """The attached ATR model, or None."""
        return self._atr_model
    
    @property
    def context_loss(self):
        """Accumulated context loss of the hot context."""
        return self._context_loss[self.hot_context]
    
    @property
    def context_losses(self):
        """The variable holding the accumulated loss of every context."""
        return self._context_loss
        
    @property
    def num_contexts(self):
        """Current number of active contexts."""
        return self._num_contexts.value()
    
    @property
    def hot_context(self):
        """Get the active context index"""
        return self._hot_context.value()
    
    @hot_context.setter
    def hot_context(self, hot_context):
        # Only indices of currently active contexts are accepted
        if hot_context not in range(self.num_contexts):
            raise ValueError("`Provided context does not exist")
        self._hot_context.assign(hot_context)

## Extended Model

The model below serves as a new base model for NTask

In [11]:
from collections import deque
import copy

from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer as lso
from tensorflow.python.data.experimental.ops import distribute_options
from tensorflow.python.data.ops import dataset_ops

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops

from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.profiler import traceme

from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import parameter_server_strategy
from tensorflow.python.keras import backend
from tensorflow.python.keras import callbacks as callbacks_module
from tensorflow.python.keras.utils import version_utils
from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.engine import data_adapter
from tensorflow.python.keras.engine import training
from tensorflow.python.util import nest

In [12]:
# Borrowed from https://github.com/tensorflow/tensorflow/blob/v2.2.0/tensorflow/python/keras/engine/data_adapter.py
try:
    import pandas as pd  # pylint: disable=g-import-not-at-top
except ImportError:
    pd = None

In [13]:
def _minimize(strategy, tape, optimizer, loss, trainable_variables):
    """Run one optimization step and hand back the gradients that were applied.

    Conceptually this is
    ```python
    gradients = tape.gradient(loss, trainable_variables)
    optimizer.apply_gradients(zip(gradients, trainable_variables))
    ```
    with additional handling for loss scaling (when the optimizer is a
    `LossScaleOptimizer`), gradient aggregation and gradient clipping.

    Args:
      strategy: `tf.distribute.Strategy`.
      tape: A gradient tape. The loss must have been computed under this tape.
      optimizer: The optimizer used to minimize the loss.
      loss: The loss tensor.
      trainable_variables: The variables that will be updated in order to
        minimize the loss.

    Return:
      The gradients after aggregation, unscaling and clipping.
    """
    uses_loss_scaling = isinstance(optimizer, lso.LossScaleOptimizer)

    # The tape is re-entered so that the scaling of the loss is recorded too
    with tape:
        if uses_loss_scaling:
            loss = optimizer.get_scaled_loss(loss)

    gradients = tape.gradient(loss, trainable_variables)

    # Whether to aggregate gradients outside of optimizer. This requires support
    # of the optimizer and doesn't work with ParameterServerStrategy and
    # CentralStorageStrategy.
    aggregate_outside = (
        optimizer._HAS_AGGREGATE_GRAD  # pylint: disable=protected-access
        and not isinstance(
            strategy.extended,
            parameter_server_strategy.ParameterServerStrategyExtended))

    if aggregate_outside:
        # Aggregation happens before unscaling, in case a subclass of
        # LossScaleOptimizer all-reduces in fp16. All-reducing in fp16 can only
        # be done on scaled gradients, not unscaled gradients, for numeric
        # stability.
        gradients = optimizer._aggregate_gradients(  # pylint: disable=protected-access
            zip(gradients, trainable_variables))

    if uses_loss_scaling:
        gradients = optimizer.get_unscaled_gradients(gradients)

    gradients = optimizer._clip_gradients(gradients)  # pylint: disable=protected-access

    # Nothing to apply when there are no variables to train
    if not trainable_variables:
        return gradients

    grads_and_vars = zip(gradients, trainable_variables)
    if aggregate_outside:
        optimizer.apply_gradients(
            grads_and_vars,
            experimental_aggregate_gradients=False)
    else:
        optimizer.apply_gradients(grads_and_vars)
    return gradients

In [14]:
# Extended from https://github.com/tensorflow/tensorflow/blob/v2.2.0/tensorflow/python/keras/engine/data_adapter.py
class WindowedDataHandler(data_adapter.DataHandler):
    """
    Enumerating over this data handler yields windows of the dataset.
    This is important for n-task because if a context switch occurs
    during an epoch the data needs to be sent back through the network.
    """
    def calc_window_size(self):
        """
        Compute the window size as the number of batches needed to cover the
        samples of one epoch.

        Returns:
            int: The window size to pass to `Dataset.window`.

        Note:
            Assumes the adapter reports a batch size and that the number of
            steps could be inferred (both non-None) — TODO confirm for
            dataset/generator inputs.
        """
        batch_size = self._adapter.batch_size()
        num_samples = self._inferred_steps*batch_size
        if self._adapter.has_partial_batch():
            # The last batch is short; remove the samples it is missing
            num_samples -= batch_size - self._adapter.partial_batch_size()
        # `np.ceil` returns a float; the window size must be an integer
        return int(np.ceil(num_samples/min(batch_size, num_samples)))
    
    def enumerate_epochs(self):
        """
        Yield `(epoch, window_iterator)` pairs.

        The iterator walks over windows of the underlying dataset; it is
        recreated for an epoch when the adapter asks for it, and iteration
        stops early once the handler has flagged insufficient data.
        """
        data_iterator = iter(self._dataset.window(self.calc_window_size()))
        for epoch in range(self._initial_epoch, self._epochs):
            if self._insufficient_data:
                break
            if self._adapter.should_recreate_iterator():
                data_iterator = iter(self._dataset.window(self.calc_window_size()))
            yield epoch, data_iterator
            self._adapter.on_epoch_end()

In [15]:
# Extended from https://github.com/tensorflow/tensorflow/blob/v2.2.0/tensorflow/python/keras/engine/training.py
class NTaskModelBase(Model):
    """
    This abstract model integrates the raw mechanisms and handlers into
    Tensorflow Keras' model class. These mechanisms can be implemented by
    inheriting from this class.

    Subclasses customise behaviour through three hooks:
    `add_context_loss`, `initialize_epoch` and `update_and_switch`.
    """
    
    def __init__(self, *args, **kwargs):
        super(NTaskModelBase, self).__init__(*args, **kwargs)
        # Gradient accumulation is not implemented yet; these are placeholders
        self.accumulate_gradients = False
        self.accumulated_gradients = None
        
        
    def compile(self, *args, accumulate_gradients=False, **kwargs):
        """
        Compile the model as Keras does, optionally enabling the
        (unfinished) gradient accumulation mode.
        """
        super(NTaskModelBase, self).compile(*args, **kwargs)
        
        # TODO
        if accumulate_gradients:
            self.accumulate_gradients = True
        
    
    def train_step(self, data):
        """
        Run one training step on a batch.

        In addition to the usual forward/backward pass, the gradients
        returned by `_minimize` are handed to `add_context_loss`.

        Returns:
            dict mapping metric names to their current results.
        """
        data = data_adapter.expand_1d(data)
        x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)

        with backprop.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)
            
        gradients = _minimize(self.distribute_strategy, tape, self.optimizer, loss,
              self.trainable_variables)
        
        # Add context loss to layers
        self.add_context_loss(gradients)

        self.compiled_metrics.update_state(y, y_pred, sample_weight)
        return {m.name: m.result() for m in self.metrics}
    
    
    @training.enable_multi_worker
    def fit(self,
            x=None,
            y=None,
            batch_size=None,
            epochs=1,
            verbose=1,
            dynamic_switch=True,
            callbacks=None,
            validation_split=0.,
            validation_data=None,
            shuffle=True,
            class_weight=None,
            sample_weight=None,
            initial_epoch=0,
            steps_per_epoch=None,
            validation_steps=None,
            validation_batch_size=None,
            validation_freq=1,
            max_queue_size=10,
            workers=1,
            use_multiprocessing=False):
        """
        Train the model, re-running an epoch whenever a context switch occurs.

        The arguments match Keras' `Model.fit` with these differences:

        Args:
            verbose: Interpreted as `Verbosity` flags; the progress bar is
                shown when `Verbosity.Progress` is set.
            dynamic_switch: Forwarded to `update_and_switch` to enable or
                disable dynamic context switching.

        For each epoch the trainable weights are snapshotted. After the
        steps of the epoch have run, `update_and_switch` is consulted; if it
        reports a switch, the snapshot is restored, metrics are reset and
        the same window of data is processed again.

        Returns:
            The `History` object recorded during training.
        """

        training._keras_api_gauge.get_cell('fit').set(True)
        # Legacy graph support is contained in `training_v1.Model`.
        version_utils.disallow_legacy_graph('Model', 'fit')
        self._assert_compile_was_called()
        self._check_call_args('fit')

        if validation_split:
            # Create the validation data using the training data. Only supported for
            # `Tensor` and `NumPy` input.
            (x, y, sample_weight), validation_data = (
                data_adapter.train_validation_split((x, y, sample_weight),
                                                    validation_split=validation_split,
                                                    shuffle=False))

        with self.distribute_strategy.scope(), training_utils.RespectCompiledTrainableState(self):
            # Creates a `tf.data.Dataset` and handles batch and epoch iteration.
            data_handler = WindowedDataHandler(
                x=x,
                y=y,
                sample_weight=sample_weight,
                batch_size=batch_size,
                steps_per_epoch=steps_per_epoch,
                initial_epoch=initial_epoch,
                epochs=epochs,
                shuffle=shuffle,
                class_weight=class_weight,
                max_queue_size=max_queue_size,
                workers=workers,
                use_multiprocessing=use_multiprocessing,
                model=self)

            # Container that configures and calls `tf.keras.Callback`s.
            if not isinstance(callbacks, callbacks_module.CallbackList):
                callbacks = callbacks_module.CallbackList(
                    callbacks,
                    add_history=True,
                    add_progbar=bool(verbose & Verbosity.Progress),
                    model=self,
                    verbose=verbose,
                    epochs=epochs,
                    steps=data_handler.inferred_steps)

            self.stop_training = False
            train_function = self.make_train_function()
            callbacks.on_train_begin()
            # Handle fault-tolerance for multi-worker.
            # TODO(omalleyt): Fix the ordering issues that mean this has to
            # happen after `callbacks.on_train_begin`.
            data_handler._initial_epoch = (self._maybe_load_initial_epoch_from_ckpt(initial_epoch))
            for epoch, window_iterator in data_handler.enumerate_epochs():
                self.reset_metrics()
                callbacks.on_epoch_begin(epoch)
                # Each window is turned back into a dataset so the same data
                # can be iterated again if the epoch has to be repeated
                dataset = tf.data.Dataset.zip(next(window_iterator))
                switched = True
                # Snapshot of the weights to roll back to after a context switch
                weights = backend.batch_get_value(self.trainable_variables)
                while switched:
                    self.initialize_epoch(epoch)
                    iterator = iter(dataset)
                    # NOTE(review): if `catch_stop_iteration` swallows an exception raised
                    # inside the step loop, the `update_and_switch` call below is skipped and
                    # `switched` stays True — confirm this cannot repeat indefinitely.
                    with data_handler.catch_stop_iteration():
                        for step in data_handler.steps():
                            with traceme.TraceMe( 'TraceContext', graph_type='train', epoch_num=epoch, step_num=step, batch_size=batch_size):
                                callbacks.on_train_batch_begin(step)
                                tmp_logs = train_function(iterator)
                                # Catch OutOfRangeError for Datasets of unknown size.
                                # This blocks until the batch has finished executing.
                                # TODO(b/150292341): Allow multiple async steps here.
                                if not data_handler.inferred_steps:
                                    context.async_wait()
                                logs = tmp_logs  # No error, now safe to assign to logs.
                                callbacks.on_train_batch_end(step, logs)
                                
                        switched = not self.update_and_switch(epoch, dynamic_switch, verbose)
                        # If a switch occurred, we need to restore the weights
                        if switched:
                            backend.batch_set_value(zip(self.trainable_variables, weights))
                            self.reset_metrics()
                    
                epoch_logs = copy.copy(logs)
                
                # Accumulated gradients are only applied once something has
                # actually been accumulated (the mechanism is still a TODO)
                if self.accumulate_gradients and self.accumulated_gradients is not None:
                    self.optimizer.apply_gradients(zip(self.accumulated_gradients, self.trainable_variables))

                # Run validation.
                if validation_data and self._should_eval(epoch, validation_freq):
                    val_x, val_y, val_sample_weight = (
                        data_adapter.unpack_x_y_sample_weight(validation_data))
                    val_logs = self.evaluate(
                        x=val_x,
                        y=val_y,
                        sample_weight=val_sample_weight,
                        batch_size=validation_batch_size or batch_size,
                        steps=validation_steps,
                        callbacks=callbacks,
                        max_queue_size=max_queue_size,
                        workers=workers,
                        use_multiprocessing=use_multiprocessing,
                        return_dict=True)
                    val_logs = {'val_' + name: val for name, val in val_logs.items()}
                    epoch_logs.update(val_logs)

                callbacks.on_epoch_end(epoch, epoch_logs)
                if self.stop_training:
                    break

            callbacks.on_train_end()
            return self.history
        
    def add_context_loss(self, gradients):
        """
        Calculate and add context loss to context layers.

        Hook called from `train_step` with the gradients of the step;
        the base implementation does nothing.
        """
        pass
        
        
    def initialize_epoch(self, epoch):
        """
        Reset context loss in context layers.

        Hook called before every (re-)run of an epoch; the base
        implementation does nothing.
        """
        pass
        
        
    def update_and_switch(self, epoch, dynamic_switch=True, verbose=0):
        """
        Update the context layers
        
        Args:
            epoch [int]: Index of the epoch that just finished
            dynamic_switch [bool]: Enable/disable dynamic switching mechanisms
            verbose [int]: `Verbosity` flags
        Return:
            [bool]: Indicate if no switches occurred
        """
        # The base model has no context layers, so no switch can occur.
        # Returning None here would be read as "switched" by `fit` and
        # the epoch would be repeated forever.
        return True

In [16]:
class NTaskModel(NTaskModelBase):
    """
    `NTaskModelBase` subclass that implements the context hooks by delegating
    to the `Context` layers found in the model.

    Attributes:
        ctx_layers       : Indices into `self.layers` of every `Context` layer
        ctx_gradient_map : Maps each of those layer indices to the position in
                           the gradients list that is used to compute the
                           layer's context loss
    """

    def __init__(self, *args, **kwargs):
        super(NTaskModel, self).__init__(*args, **kwargs)
        self.ctx_layers = [i for i, layer in enumerate(self.layers) if isinstance(layer, Context)]

        # We need to map the context layer to their gradient indices.
        # `index` counts the trainable variables of all layers that come
        # before the layer currently being visited.
        self.ctx_gradient_map = {}
        index = 0
        for i, layer in enumerate(self.layers):
            if isinstance(layer, Context):
                # The bias gradient.
                # NOTE(review): presumably the bias of the layer following the
                # context layer, which only holds if the context layer itself
                # has no trainable variables - see _calc_context_loss.
                self.ctx_gradient_map[i] = index + 1
            index += len(layer.trainable_variables)


    def _calc_context_loss(self, ctx_layer_idx, gradients):
        """
        Compute the delta at the context layer `ctx_layer_idx` from the
        gradients list.

        IMPORTANT:
        1) Assumes no use of activation function on Ntask layer
        2) Assumes that the layer following the Ntask layer:
            a) Is a Dense layer
            b) Is using bias
               - ex: Dense(20, ... , use_bias=True)
               - note Keras Dense layer uses bias by default if no value is given for use_bias param
        3) Assumes index of the next layer's gradient is known within the gradients list returned from gradient tape in a tape.gradient call
        4) If the above points aren't met, things will break and it may be hard to locate the bugs

        Args:
            ctx_layer_idx : Index of the context layer within `self.layers`
            gradients     : The list of gradients to read the delta from
        Return:
            The tensordot of the selected gradient with the transposed first
            weight tensor of the layer at `ctx_layer_idx + 1`
        """
        # From the delta rule in neural network math
        index = self.ctx_gradient_map[ctx_layer_idx]
        delta_at_next_layer = gradients[index]
        transpose_of_weights_at_next_layer = tf.transpose(self.layers[ctx_layer_idx + 1].weights[0])

        # Calculate delta at n-task layer
        context_delta = tf.tensordot(delta_at_next_layer, transpose_of_weights_at_next_layer, 1)
        return context_delta


    def initialize_epoch(self, epoch):
        """Per-epoch hook; currently a no-op."""
        # Clear context loss (probably going to use a new mechanism here)
#         for i in self.ctx_layers:
#             self.layers[i].clear_context_loss()
        pass


    def add_context_loss(self, gradients):
        """Compute the context loss of every context layer and pass it to that layer."""
        for i in self.ctx_layers:
            self.layers[i].add_context_loss(self._calc_context_loss(i, gradients))


    def update_and_switch(self, epoch, dynamic_switch=True, verbose=0):
        """
        Call `update_and_switch` on every context layer, last layer first.

        The defaults mirror the signature declared on the base class so the
        method can be called the same way on either class.

        Args:
            epoch          : Forwarded to each context layer
            dynamic_switch : Forwarded to each context layer
            verbose        : Forwarded to each context layer
        Return:
            The logical AND of the values returned by the layers; True when
            the model has no context layers.
        """
        updated = True
        for i in reversed(self.ctx_layers):
            layer = self.layers[i]
            # `&=` does not short-circuit: every layer is updated even after
            # one of them has reported False
            updated &= layer.update_and_switch(epoch, dynamic_switch=dynamic_switch, verbose=verbose)
        return updated


    def get_contexts(self):
        """Return the `hot_context` of every context layer, in layer order."""
        return [self.layers[layer].hot_context for layer in self.ctx_layers]


    def set_contexts(self, contexts):
        """
        Set the `hot_context` of every context layer.

        Args:
            contexts : Sequence with one entry per context layer, in layer order
        """
        for i, layer in enumerate(self.ctx_layers):
            self.layers[layer].hot_context = contexts[i]

___

In [48]:
class AtrLoggerTensorBoard(tf.keras.callbacks.BaseLogger):
    """
    Callback that writes the ATR value of each context to TensorBoard.

    A separate summary file writer is kept for every context of every context
    layer, stored in `self.writers` as layer -> list of writers.
    """

    def __init__(self, logdir, *args, **kwargs):
        super(AtrLoggerTensorBoard, self).__init__(*args, **kwargs)
        self.logdir = logdir
        self.writers = {}

    def set_model(self, model):
        """Attach the model and start with an empty writer list per context layer."""
        super(AtrLoggerTensorBoard, self).set_model(model)
        self.writers = {}
        for layer_index in self.model.ctx_layers:
            self.writers[self.model.layers[layer_index]] = []

    def on_epoch_end(self, epoch, logs=None):
        """Create any missing writers, then log one scalar per context."""
        for ctx_layer, layer_writers in self.writers.items():
            # One writer per context; only the ones that do not exist yet are created
            missing = range(len(layer_writers), ctx_layer.num_contexts)
            layer_writers.extend(
                tf.summary.create_file_writer(os.path.join(self.logdir, f"{ctx_layer.name}_Atr_{context}"))
                for context in missing
            )

            tag = f"{ctx_layer.name}_Atr_Trace"
            for context, writer in enumerate(layer_writers):
                with writer.as_default():
                    atr_value = ctx_layer.atr_model.values[context]
                    if atr_value is None:
                        # Nothing to record for this context
                        continue
                    tf.summary.scalar(tag, data=atr_value, step=epoch)

In [53]:
class AtrLogger(tf.keras.callbacks.BaseLogger):
    """
    Log ATR models via matplotlib

    After every epoch the traces reported by each context layer's ATR model
    are collected; `plot` draws them, one figure per context layer.
    """

    def __init__(self, *args, **kwargs):
        super(AtrLogger, self).__init__(*args, **kwargs)
        self.plots = None # layer_name -> { plot_name -> { trace_name -> { data } } }
        self.model = None

    def plot(self, vertical=True, figsize=(30, 6)):
        """
        Draw the collected traces.

        Args:
            vertical : Stack the plots of a layer in one column (True) or
                       place them side by side in one row (False)
            figsize  : Figure size passed to `plt.subplots`
        """
        # `self.plots` is still None if no NTaskModel was ever attached
        for layer, plots in (self.plots or {}).items():
            if not plots:
                # Nothing recorded for this layer; a grid with zero rows or
                # columns cannot be created
                continue
            dim = (len(plots), 1) if vertical else (1, len(plots))
            # squeeze=False always yields a 2-D array of Axes, so indexing
            # also works when the layer has a single plot
            fig, axs = plt.subplots(*dim, figsize=figsize, sharey=False, squeeze=False)
            axs = axs.ravel()
            for i, (plot_name, traces) in enumerate(plots.items()):
                for label, trace in traces.items():
                    axs[i].plot(
                        trace["x"], trace["y"],
                        label=label,
                        color=trace["color"],
                        linestyle=trace["style"])
                axs[i].set_title(plot_name)
                axs[i].set_xlabel("Epoch")
                axs[i].set_ylabel("Value")
                axs[i].grid(True)
                axs[i].legend()
            fig.suptitle(layer.name)

    def set_model(self, model):
        """
        Attach `model`; models that are not an NTaskModel are ignored.

        The collected plots are reset only when a different model is attached.
        """
        if not isinstance(model, NTaskModel):
            return
        if self.model != model:
            self.plots = {}
            for i in model.ctx_layers:
                layer = model.layers[i]
                if layer.atr_model is not None:
                    # plot_name -> trace label -> trace data, created on first access
                    self.plots[layer] = defaultdict(lambda: defaultdict(lambda: {
                        "x": [],
                        "y": [],
                        "style": None,
                        "color": None
                    }))
        super(AtrLogger, self).set_model(model)

    def on_epoch_end(self, epoch, logs=None):
        """Append this epoch's value to every trace reported by the ATR models."""
        if not self.plots:
            # No NTaskModel attached (set_model returned early or was never
            # called), or no context layer has an ATR model
            return
        for layer, plots in self.plots.items():
            for plot_name, traces in layer.atr_model.epoch_traces(epoch).items():
                for trace_data in traces:
                    if trace_data is not None:
                        (label, value, style, color) = trace_data
                        trace = plots[plot_name][label]
                        # Style and color are fixed by the first point of a trace
                        if len(trace["x"]) == 0:
                            trace["style"] = style
                            trace["color"] = color
                        trace["x"].append(epoch)
                        trace["y"].append(value)

In [54]:
def trace(label, value, style='-', color=None, condition=True):
    """
    Bundle one trace sample as a `(label, value, style, color)` tuple.

    Returns None instead of the tuple when `condition` is falsy.
    """
    return (label, value, style, color) if condition else None

___

## Utility Functions

In [18]:
def set_seed(seed):
    """Seed Python's `random`, NumPy and TensorFlow with the same value."""
    for seed_rng in (random.seed, np.random.seed, tf.random.set_seed):
        seed_rng(seed)

In [60]:
def train(model,
          x_train,
          y_train_list,
          cycles=1,
          epochs=1,
          task_shuffle=True,
          initial_task_shuffle=False,
          explicit_contexts=None,
          assert_contexts=False,
          **kwargs):
    """
    Train an NTask model on a dataset containing samples from multiple contexts.

    Args:
      model               : The NTask model instance
      x_train             : The input data
      y_train_list        : A list with the y_train values of each task
      cycles              : The number of iterations over the entire y_train_list
      epochs              : The number of epochs to perform on each y_train element within y_train_list
      task_shuffle        : Shuffle the order of the tasks after each cycle (does not shuffle data within the task)
      initial_task_shuffle: Shuffle the order of the tasks before the first cycle. If set to false, the
                            tasks of the first cycle are presented in the order provided
      explicit_contexts   : A list of context mappings for each task, one context per context layer.
                            If none (either entirely, or contains None as an item), the current task
                            will be dynamically switched.
      assert_contexts     : Currently unused
      **kwargs            : All other keyword arguments are passed to `model.fit`

    Returns:
      A `(task_map, context_map)` tuple: `task_map` is the task order used in the
      first cycle and `context_map[task]` holds `model.get_contexts()` as recorded
      after the most recent fit on that task (None if the task was never fitted).

    Raises:
      ValueError: If `explicit_contexts` does not have one entry per task, or an
                  entry does not have one context per context layer.
    """

    # Validate provided explicit contexts
    if explicit_contexts is None:
        explicit_contexts = [None]*len(y_train_list)
    elif len(explicit_contexts) != len(y_train_list):
        raise ValueError(f"Supplied number of explicit contexts ({len(explicit_contexts)}) does not match the number of tasks ({len(y_train_list)}).")
    else:
        for i, task in enumerate(explicit_contexts):
            # The model exposes its context layers as `ctx_layers`
            if task is not None and len(task) != len(model.ctx_layers):
                raise ValueError(f"Provided explicit contexts for task {i} does not match the number of context layers.")

    # Create an index map of the tasks
    indices = np.arange(len(y_train_list))

    # Shuffle if necessary
    if initial_task_shuffle:
        np.random.shuffle(indices)

    # Map the shuffled tasks in the order they are passed to the `fit` method
    task_map = indices.copy()

    # Track the layer contexts for each task for later evaluation
    context_map = [None]*len(indices)

    for cycle in range(cycles):
        for i, task in enumerate(indices):
            y_train = y_train_list[task]

            # Epoch numbers keep counting across tasks and cycles, so `fit` is
            # given an absolute start and end epoch
            initial_epoch = cycle*len(y_train_list)*epochs + i*epochs
            end_epoch = initial_epoch + epochs

            # Set contexts explicitly if necessary
            if explicit_contexts[task] is not None:
                model.set_contexts(explicit_contexts[task])

            # Dynamic switching is only enabled for tasks without explicit contexts
            model.fit(x_train, y_train, epochs=end_epoch, initial_epoch=initial_epoch, dynamic_switch=(explicit_contexts[task] is None), **kwargs)

            # Record the contexts the model is in after fitting this task
            context_map[task] = model.get_contexts()

        if task_shuffle:
            np.random.shuffle(indices)

    return task_map, context_map

In [59]:
def evaluate(model,
             x, y_list,
             task_map,
             context_map,
             display_predictions=True,
             return_dict=True,
             **kwargs):
    """
    Evaluate the model once per task, each time in the contexts recorded for it.

    For every position `i` in `task_map` the model is switched to
    `context_map[i]` and evaluated against `y_list[i]`; only the number of
    entries of `task_map` is used, not its values.

    Args:
      model               : The model to evaluate
      x                   : The input data
      y_list              : The expected outputs of each task
      task_map            : The task map returned by `train`
      context_map         : The context map returned by `train`
      display_predictions : Print the model's predictions for each task
      return_dict         : Passed to `model.evaluate`
      **kwargs            : All other keyword arguments are passed to `model.evaluate`

    Returns:
      A list with the result of `model.evaluate` for each task
    """
    def _evaluate_task(index):
        targets = y_list[index]
        model.set_contexts(context_map[index])
        if display_predictions:
            tf.print(model.predict(x))
        return model.evaluate(x, targets, return_dict=return_dict, **kwargs)

    return [_evaluate_task(index) for index, _ in enumerate(task_map)]

# Test

In [20]:
def test(ModelClass, init_args, compile_args, x_train, y_train, seed=5, **kwargs):
    """
    Build, train and score a small binary classifier without context layers.

    Args:
      ModelClass   : The model class to instantiate from the functional graph
      init_args    : Extra keyword arguments for the model constructor
      compile_args : Extra keyword arguments for `model.compile`
      x_train      : The input data
      y_train      : The expected outputs
      seed         : Seed applied to `random`, NumPy and TensorFlow
      **kwargs     : All other keyword arguments are passed to `model.fit`
    """
    # Seed every library that is a source of randomness
    for seed_fn in (random.seed, np.random.seed, tf.random.set_seed):
        seed_fn(seed)

    # Input -> Dense(128, relu) -> Dense(1, sigmoid)
    inputs = Input(x_train[0].shape)
    hidden = Dense(128, activation="relu")(inputs)
    outputs = Dense(1, activation="sigmoid")(hidden)
    model = ModelClass(inputs=inputs, outputs=outputs, **init_args)

    model.compile(
        loss=tf.keras.losses.BinaryCrossentropy(),
        optimizer=tf.keras.optimizers.SGD(1e-4),
        **compile_args
    )

    model.fit(x_train, y_train, **kwargs)

    # Round the sigmoid outputs to 0/1 and count the matches with the targets
    predictions = np.round(model(x_train)).astype(int).flatten()
    result = (predictions == y_train.flatten()).sum()
    total = len(y_train)
    print(f"{result}/{total}; Accuracy: {100*result/total:.2f}%")

In [21]:
def test_context(ModelClass, init_args, compile_args, x_train, y_train_list, cycles=1, seed=5, epochs=1, **kwargs):
    """
    Build, train and report on a small network containing a `Context` layer
    whose ATR model is limited to two contexts.

    The tasks in `y_train_list` are fitted in order for `epochs` epochs each and
    the whole list is repeated `cycles` times; no context is set explicitly
    during training. Afterwards the predictions are printed for every context
    index in `range(len(y_train_list))` and the accuracy on the last task is
    printed.

    Args:
      ModelClass   : The model class to instantiate from the functional graph
      init_args    : Extra keyword arguments for the model constructor
      compile_args : Extra keyword arguments for `model.compile`
      x_train      : The input data
      y_train_list : The expected outputs of each task
      cycles       : The number of iterations over the entire y_train_list
      seed         : Seed applied to `random`, NumPy and TensorFlow
      epochs       : The number of epochs to perform on each task per cycle
      **kwargs     : All other keyword arguments are passed to `model.fit`

    Returns:
      The trained model
    """
    # Set the random seed for all used libraries
    random.seed(seed)
    np.random.seed(seed)
    tf.random.set_seed(seed)

    # Create the model
    inp = Input(x_train[0].shape)
    x = Dense(128, activation="relu", use_bias=True)(inp)
    x = Context(2, AtrMovingAverage(max_contexts=2, switch_threshold=-0.02))(x)
    x = Dense(1, activation="sigmoid", use_bias=True)(x)
    model = ModelClass(inputs=inp, outputs=x, **init_args)

    # Compile the model
    model.compile(
        loss=tf.keras.losses.BinaryCrossentropy(),
        optimizer=tf.keras.optimizers.SGD(1e-1),
        **compile_args
    )

    # ATR traces are collected for matplotlib. To log to TensorBoard instead, use
    # tf.keras.callbacks.TensorBoard(logdir, histogram_freq=1) together with
    # AtrLoggerTensorBoard(logdir).
    callbacks = [
        AtrLogger()
    ]

    model.summary()

    # Train the model; epoch numbers keep counting across tasks and cycles
    for cycle in range(cycles):
        for context, y_train in enumerate(y_train_list):
            initial_epoch = cycle*len(y_train_list)*epochs + context*epochs
            model.fit(x_train, y_train, callbacks=callbacks, initial_epoch=initial_epoch, epochs=initial_epoch + epochs, **kwargs)

    for context in range(len(y_train_list)):
        model.set_contexts([context])
        tf.print(model.predict(x_train))

    # Calculate and display the accuracy on the last task. The loop above leaves
    # the model in the context with the last task's index. The targets are looked
    # up explicitly so this also works when no training iteration ran (cycles=0).
    if len(y_train_list) > 0:
        y_last = y_train_list[-1]
        result = (np.round(model(x_train)).astype(int).flatten() == y_last.flatten()).sum()
        print(f"{result}/{len(y_last)}; Accuracy: {100*result/len(y_last):.2f}%")

    return model

In [22]:
def test_context_dynamic(ModelClass, init_args, compile_args, x_train, y_train_list, cycles=1, seed=5, epochs=1, **kwargs):
    """
    Build, train and report on a small network containing a `Context` layer
    that starts with one context and whose ATR model allows up to
    `len(y_train_list) + 1` contexts.

    The tasks in `y_train_list` are fitted in order for `epochs` epochs each and
    the whole list is repeated `cycles` times; no context is set explicitly
    during training. Afterwards the predictions are printed for every context
    index in `range(len(y_train_list))` and the accuracy on the last task is
    printed.

    Args:
      ModelClass   : The model class to instantiate from the functional graph
      init_args    : Extra keyword arguments for the model constructor
      compile_args : Extra keyword arguments for `model.compile`
      x_train      : The input data
      y_train_list : The expected outputs of each task
      cycles       : The number of iterations over the entire y_train_list
      seed         : Seed applied to `random`, NumPy and TensorFlow
      epochs       : The number of epochs to perform on each task per cycle
      **kwargs     : All other keyword arguments are passed to `model.fit`

    Returns:
      The trained model
    """
    # Set the random seed for all used libraries
    random.seed(seed)
    np.random.seed(seed)
    tf.random.set_seed(seed)

    # Create the model
    inp = Input(x_train[0].shape)
    x = Dense(128, activation="relu", use_bias=True)(inp)
    x = Context(1, AtrMovingAverage(max_contexts=len(y_train_list)+1, switch_threshold=-0.02, add_threshold=-0.04))(x)
    x = Dense(1, activation="sigmoid", use_bias=True)(x)
    model = ModelClass(inputs=inp, outputs=x, **init_args)

    # Compile the model
    model.compile(
        loss=tf.keras.losses.BinaryCrossentropy(),
        optimizer=tf.keras.optimizers.SGD(1e-1),
        **compile_args
    )

    logdir = os.path.join("logs", datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
    callbacks = [
        tf.keras.callbacks.TensorBoard(logdir, histogram_freq=1),
        # AtrLogger takes no log directory: its positional arguments are forwarded
        # to the BaseLogger constructor, so passing `logdir` would be taken as
        # BaseLogger's first argument instead of a path.
        # NOTE(review): use AtrLoggerTensorBoard(logdir) here if the ATR traces
        # are meant to go to TensorBoard as well.
        AtrLogger()
    ]

    # Train the model; epoch numbers keep counting across tasks and cycles
    for cycle in range(cycles):
        for context, y_train in enumerate(y_train_list):
            initial_epoch = cycle*len(y_train_list)*epochs + context*epochs
            model.fit(x_train, y_train, callbacks=callbacks, initial_epoch=initial_epoch, epochs=initial_epoch + epochs, **kwargs)

    for context in range(len(y_train_list)):
        model.set_contexts([context])
        tf.print(model.predict(x_train))

    # Calculate and display the accuracy on the last task. The loop above leaves
    # the model in the context with the last task's index. The targets are looked
    # up explicitly so this also works when no training iteration ran (cycles=0).
    if len(y_train_list) > 0:
        y_last = y_train_list[-1]
        result = (np.round(model(x_train)).astype(int).flatten() == y_last.flatten()).sum()
        print(f"{result}/{len(y_last)}; Accuracy: {100*result/len(y_last):.2f}%")

    return model

### Dataset

In [23]:
# Training images: load the MNIST image file with `idx_load`; the trailing
# expression displays the shape of the loaded data
training_images = idx_load("../datasets/mnist/train-images.idx3-ubyte")
training_images.shape

(60000, 28, 28)

In [24]:
# Training labels: load the MNIST label file with `idx_load`; the trailing
# expression displays the shape of the loaded data
training_labels = idx_load("../datasets/mnist/train-labels.idx1-ubyte")
training_labels.shape

(60000,)

In [25]:
# Normalize the datasets: flatten every 28x28 image into a 784-element row and
# divide by 255 (presumably mapping byte pixel values into [0, 1] - confirm).
# NOTE(review): this rebinds `training_images`, so running the cell a second
# time divides by 255 again; re-run the loading cell first.
training_images = training_images.reshape(len(training_images), 28*28) / 255.0

In [26]:
# Expected outputs of each two-input logic gate, shaped (gate, sample, 1).
# Sample k of a gate is the target for logic_gate_inputs[k].
logic_gate_labels = np.array([
    [[target] for target in gate_outputs]
    for gate_outputs in (
        (0, 1, 1, 0), # XOR
        (1, 0, 0, 1), # XNOR
        (0, 0, 0, 1), # AND
        (0, 1, 1, 1), # OR
        (1, 0, 0, 0), # NOR
        (1, 1, 1, 0), # NAND
        (1, 0, 1, 0), # Custom 1
        (0, 1, 0, 1), # Custom 2
    )
])

# The four input combinations, with -1 standing in for a logical 0
logic_gate_inputs = np.array([(first, second) for first in (-1, 1) for second in (-1, 1)])

___

### Model Evaluation

In [27]:
# MNIST number is even: the target is 1 for an even digit label and 0 for an odd one
x_train = training_images
y_train = (np.asarray(training_labels) % 2 == 0).astype(int)

In [28]:
# Sanity check: print the first 10 digit labels above the even/odd targets
# derived from them
print(training_labels[:10])
print(y_train[:10])

[5 0 4 1 9 2 1 3 1 4]
[0 1 1 0 0 1 0 0 0 1]


In [29]:
# Baseline: time the plain Keras `Model` on the MNIST even/odd task
%time test(Model, {}, {"metrics": ["accuracy"]}, x_train, y_train, epochs=10, batch_size=64, verbose=1)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
48394/60000; Accuracy: 80.66%
CPU times: user 32.4 s, sys: 18 s, total: 50.4 s
Wall time: 14.5 s


In [30]:
# Same run with `NTaskModel`; the network built by `test` has no Context layers
%time test(NTaskModel, {}, {"metrics": ["accuracy"]}, x_train, y_train, epochs=10, batch_size=64, verbose=1)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
48394/60000; Accuracy: 80.66%
CPU times: user 31.9 s, sys: 17.2 s, total: 49.1 s
Wall time: 14.9 s


In [243]:
# `NTaskModel` on the XOR targets (logic_gate_labels[0]): 500 epochs, one sample per batch
%time test(NTaskModel, {}, {}, logic_gate_inputs, logic_gate_labels[0], epochs=500, batch_size=1, verbose=0)

2/4; Accuracy: 50.00%
CPU times: user 3.06 s, sys: 406 ms, total: 3.47 s
Wall time: 2.65 s


In [79]:
# Same XOR targets, but a single epoch with all four samples in one batch
%time test(NTaskModel, {}, {}, logic_gate_inputs, logic_gate_labels[0], epochs=1, batch_size=4, verbose=1)

DictWrapper({})
2/4; Accuracy: 50.00%
CPU times: user 250 ms, sys: 15.6 ms, total: 266 ms
Wall time: 243 ms


In [22]:
# Switch the (single) context layer to context 1.
# NOTE(review): in this part of the notebook `model` is only assigned by the
# `test_context` call in a later cell, so this cell depends on out-of-order
# execution - confirm where `model` is meant to come from.
model.set_contexts([1])

In [24]:
# Evaluate the model against the XNOR targets (logic_gate_labels[1]).
# NOTE(review): relies on `model` and its current context left behind by other cells.
model.evaluate(logic_gate_inputs, logic_gate_labels[1])



[0.0006013047532178462, 1.0]

In [8]:
# Train on XOR and XNOR (logic_gate_labels[:2]) for 3 cycles of 10 epochs per task.
# NOTE(review): the saved output is a NameError - the cell defining `test_context`
# has to be run before this one.
%time model = test_context(NTaskModel, {}, {"metrics": [tf.keras.metrics.BinaryAccuracy()]}, logic_gate_inputs, logic_gate_labels[:2], cycles=3, epochs=10, batch_size=1, verbose=1)

NameError: name 'test_context' is not defined

___

In [None]:
# Expected outputs of each two-input logic gate, shaped (gate, sample, 1).
# Sample k of a gate is the target for x_train[k].
# NOTE(review): these are the same values as `logic_gate_labels` and
# `logic_gate_inputs`, and the cell rebinds `x_train`, which an earlier cell
# set to the MNIST images.
labels = np.array([
    [[target] for target in gate_outputs]
    for gate_outputs in (
        (0, 1, 1, 0), # XOR
        (1, 0, 0, 1), # XNOR
        (0, 0, 0, 1), # AND
        (0, 1, 1, 1), # OR
        (1, 0, 0, 0), # NOR
        (1, 1, 1, 0), # NAND
        (1, 0, 1, 0), # Custom 1
        (0, 1, 0, 1), # Custom 2
    )
])

# The four input combinations, with -1 standing in for a logical 0
x_train = np.array([(first, second) for first in (-1, 1) for second in (-1, 1)])

In [51]:
# Display the XOR targets
logic_gate_labels[0]

array([[0],
       [1],
       [1],
       [0]])

In [10]:
# NOTE(review): hard-coded process id left over from an earlier session
# (presumably a TensorBoard server - confirm). Running this on a fresh kernel
# targets whatever process currently has that id, or fails.
!kill 4469

In [618]:
# Open TensorBoard inline on the `logs` directory, which is where
# `test_context_dynamic` creates its timestamped log directories
%tensorboard --logdir logs