# Model Source

This notebook experiments with re-implementing the model's source code inserting various event handlers for extension to n-task

In [1]:
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras import Model

import math
import numpy as np
import random
import time

from utils import idx_load

## Extended Model

The model below serves as a new base model for NTask

In [2]:
import copy

from tensorflow.python.data.experimental.ops import distribute_options
from tensorflow.python.data.ops import dataset_ops

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops

from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.profiler import traceme

from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.keras import callbacks as callbacks_module
from tensorflow.python.keras.utils import version_utils
from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.engine import data_adapter
from tensorflow.python.keras.engine import training
from tensorflow.python.util import nest

In [3]:
# Borrowed from https://github.com/tensorflow/tensorflow/blob/v2.2.0/tensorflow/python/keras/engine/data_adapter.py
try:
    import pandas as pd  # pylint: disable=g-import-not-at-top
except ImportError:
    pd = None

In [4]:
# Extended from https://github.com/tensorflow/tensorflow/blob/v2.2.0/tensorflow/python/keras/engine/data_adapter.py
class WindowedDataHandler(data_adapter.DataHandler):
    """
    Enumerating over this data handler yields windows of the dataset.
    This is important for n-task because if a context switch occurs
    during an epoch, the data needs to be sent back through the network.
    """
    def enumerate_epochs(self):
        # Calculate the number of samples for the window size
        batch_size = self._adapter.batch_size()
        num_samples = self._inferred_steps*batch_size
        if self._adapter.has_partial_batch():
            num_samples -= batch_size - self._adapter.partial_batch_size()
        window_size = np.ceil(num_samples/min(batch_size, num_samples))
        print(window_size)
        # Split the dataset into windows
        data_iterator = iter(self._dataset.window(window_size))
        for epoch in range(self._initial_epoch, self._epochs):
            if self._insufficient_data:
                break
            if self._adapter.should_recreate_iterator():
                data_iterator = iter(self._dataset)
            yield epoch, data_iterator
            self._adapter.on_epoch_end()

In [23]:
# Extended from https://github.com/tensorflow/tensorflow/blob/v2.2.0/tensorflow/python/keras/engine/training.py
class NTaskModelBase(Model):
#     def train_step(self, data):
#         data = data_adapter.expand_1d(data)
#         x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)

#         with backprop.GradientTape() as tape:
#             y_pred = self(x, training=True)
#             loss = self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)

#         training._minimize(self.distribute_strategy, tape, self.optimizer, loss,
#               self.trainable_variables)

#         self.compiled_metrics.update_state(y, y_pred, sample_weight)
#         return {m.name: m.result() for m in self.metrics}
    
    
    @training.enable_multi_worker
    def fit(self,
          x=None,
          y=None,
          batch_size=None,
          epochs=1,
          verbose=1,
          callbacks=None,
          validation_split=0.,
          validation_data=None,
          shuffle=True,
          class_weight=None,
          sample_weight=None,
          initial_epoch=0,
          steps_per_epoch=None,
          validation_steps=None,
          validation_batch_size=None,
          validation_freq=1,
          max_queue_size=10,
          workers=1,
          use_multiprocessing=False):

        training._keras_api_gauge.get_cell('fit').set(True)
        # Legacy graph support is contained in `training_v1.Model`.
        version_utils.disallow_legacy_graph('Model', 'fit')
        self._assert_compile_was_called()
        self._check_call_args('fit')

        if validation_split:
            # Create the validation data using the training data. Only supported for
            # `Tensor` and `NumPy` input.
            (x, y, sample_weight), validation_data = (
            data_adapter.train_validation_split((x, y, sample_weight),
                                                validation_split=validation_split,
                                                shuffle=False))

        with self.distribute_strategy.scope(), training_utils.RespectCompiledTrainableState(self):
            # Creates a `tf.data.Dataset` and handles batch and epoch iteration.
            data_handler = WindowedDataHandler(
                x=x,
                y=y,
                sample_weight=sample_weight,
                batch_size=batch_size,
                steps_per_epoch=steps_per_epoch,
                initial_epoch=initial_epoch,
                epochs=epochs,
                shuffle=shuffle,
                class_weight=class_weight,
                max_queue_size=max_queue_size,
                workers=workers,
                use_multiprocessing=use_multiprocessing,
                model=self)

            # Container that configures and calls `tf.keras.Callback`s.
            if not isinstance(callbacks, callbacks_module.CallbackList):
                callbacks = callbacks_module.CallbackList(
                    callbacks,
                    add_history=True,
                    add_progbar=verbose != 0,
                    model=self,
                    verbose=verbose,
                    epochs=epochs,
                    steps=data_handler.inferred_steps)

            self.stop_training = False
            train_function = self.make_train_function()
            callbacks.on_train_begin()
            # Handle fault-tolerance for multi-worker.
            # TODO(omalleyt): Fix the ordering issues that mean this has to
            # happen after `callbacks.on_train_begin`.
            data_handler._initial_epoch = (self._maybe_load_initial_epoch_from_ckpt(initial_epoch))
            for epoch, window_iterator in data_handler.enumerate_epochs():
                self.reset_metrics()
                callbacks.on_epoch_begin(epoch)
                dataset = tf.data.Dataset.zip(next(window_iterator))
                switched = True
                while switched:
                    switched = False
                    iterator = iter(dataset)
                    with data_handler.catch_stop_iteration():
                        for step in data_handler.steps():
                            with traceme.TraceMe( 'TraceContext', graph_type='train', epoch_num=epoch, step_num=step, batch_size=batch_size):
                                callbacks.on_train_batch_begin(step)
                                tmp_logs = train_function(iterator)
                                # Catch OutOfRangeError for Datasets of unknown size.
                                # This blocks until the batch has finished executing.
                                # TODO(b/150292341): Allow multiple async steps here.
                                if not data_handler.inferred_steps:
                                    context.async_wait()
                                logs = tmp_logs  # No error, now safe to assign to logs.
                                callbacks.on_train_batch_end(step, logs)
                epoch_logs = copy.copy(logs)

                # Run validation.
                if validation_data and self._should_eval(epoch, validation_freq):
                    val_x, val_y, val_sample_weight = (
                        data_adapter.unpack_x_y_sample_weight(validation_data))
                    val_logs = self.evaluate(
                        x=val_x,
                        y=val_y,
                        sample_weight=val_sample_weight,
                        batch_size=validation_batch_size or batch_size,
                        steps=validation_steps,
                        callbacks=callbacks,
                        max_queue_size=max_queue_size,
                        workers=workers,
                        use_multiprocessing=use_multiprocessing,
                        return_dict=True)
                    val_logs = {'val_' + name: val for name, val in val_logs.items()}
                    epoch_logs.update(val_logs)

                callbacks.on_epoch_end(epoch, epoch_logs)
                if self.stop_training:
                    break

            callbacks.on_train_end()
            return self.history

In [24]:
class NTaskModel(NTaskModelBase):
    pass

___

# Test

In [6]:
def test(ModelClass, x_train, y_train, seed=5, **kwargs):
    # Set the random seed for all used libraries
    random.seed(seed)
    np.random.seed(seed)
    tf.random.set_seed(seed)
    
    # Create the model
    inp = Input(x_train[0].shape)
    x = Dense(128, activation="relu")(inp)
    x = Dense(1, activation="sigmoid")(x)
    model = ModelClass(inputs=inp, outputs=x)
    
    # Compile the model
    model.compile(
        loss=tf.keras.losses.BinaryCrossentropy(),
        optimizer=tf.keras.optimizers.SGD(1e-4)
    )
    
    # Train the model
    model.fit(x_train, y_train, **kwargs)
    
    # Calculate and display the accuracy
    result = (np.round(model(x_train)).astype(int).flatten() == y_train.flatten()).sum()
    print(f"{result}/{len(y_train)}; Accuracy: {100*result/len(y_train):.2f}%")

### Dataset

In [7]:
# Training images
training_images = idx_load("../datasets/mnist/train-images.idx3-ubyte")
training_images.shape

(60000, 28, 28)

In [8]:
# Training labels
training_labels = idx_load("../datasets/mnist/train-labels.idx1-ubyte")
training_labels.shape

(60000,)

In [9]:
# Normalize the datasets
training_images = training_images.reshape(len(training_images), 28*28) / 255.0

In [10]:
logic_gate_labels = np.array([
    [[0], [1], [1], [0]], # XOR
    [[1], [0], [0], [1]], # XNOR
    [[0], [0], [0], [1]], # AND
    [[0], [1], [1], [1]], # OR
    [[1], [0], [0], [0]], # NOR
    [[1], [1], [1], [0]], # NAND
    [[1], [0], [1], [0]], # Custom 1
    [[0], [1], [0], [1]]  # Custom 2
])

logic_gate_inputs = np.array([[-1, -1], [-1, 1], [1, -1], [1, 1]])

___

### Evaluation

In [11]:
# MNIST number is even
x_train = training_images
y_train = np.array([int(i % 2 == 0) for i in training_labels])

In [12]:
# Verify on the first 10 the dataset seems correct...
print(training_labels[:10])
print(y_train[:10])

[5 0 4 1 9 2 1 3 1 4]
[0 1 1 0 0 1 0 0 0 1]


In [20]:
s = time.time()
test(Model, x_train, y_train, epochs=10, batch_size=64, verbose=1)
print(time.time() - s)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
48394/60000; Accuracy: 80.66%
13.526658296585083


In [21]:
s = time.time()
test(ExtendedModel, x_train, y_train, epochs=10, batch_size=64, verbose=1)
print(time.time() - s)

938.0
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
48394/60000; Accuracy: 80.66%
13.328534841537476


In [664]:
test(Model, logic_gate_inputs, logic_gate_labels[0], epochs=1, batch_size=3, verbose=1)

2/4; Accuracy: 50.00%


In [663]:
test(ExtendedModel, logic_gate_inputs, logic_gate_labels[0], epochs=1, batch_size=3, verbose=1)

2.0
2/4; Accuracy: 50.00%


In [503]:
epochs = g_dataset.window(4)

In [507]:
for win_x, win_y in epochs:


<_VariantDataset shapes: (1, 2), types: tf.int64> <_VariantDataset shapes: (1, 1), types: tf.int64>
tf.Tensor([[-1  1]], shape=(1, 2), dtype=int64)
tf.Tensor([[-1 -1]], shape=(1, 2), dtype=int64)
tf.Tensor([[1 1]], shape=(1, 2), dtype=int64)
tf.Tensor([[ 1 -1]], shape=(1, 2), dtype=int64)


In [519]:
gen = (i for i in range(10))

In [520]:
iterator = lambda: next(gen)

In [521]:
next(iterator)

TypeError: 'function' object is not an iterator

In [523]:
iterator()

1

In [525]:
g_dataset

<_OptionsDataset shapes: ((1, 2), (1, 1)), types: (tf.int64, tf.int64)>

In [527]:
g_dataset.size

AttributeError: '_OptionsDataset' object has no attribute 'size'

In [640]:
a = iter([1, 2, 3])
b = iter([4, 5, 6])
c = next(zip(a, b))

In [641]:
c

(1, 4)