## TensorFlow2 training loop control using default *tf.fit(...)* function

### Task Description

Up to now, a custom training loop in TensorFlow 2 has required writing two loops:
1. loop iterating through epochs 
2. loop iterating through batches 

The whole custom training procedure then has to be implemented inside this double-loop block of code. That is neither elegant nor robust, because the advanced features of *tf.fit(...)* are missing.

In [3]:
import tensorflow as tf
from types import MethodType
import functools

2023-01-30 22:16:27.643002: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2023-01-30 22:16:27.643031: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.


In [4]:
def bind(instance, func, as_name=None):
    """
    Attach *func* to *instance* as a bound method and return that method.

    The attribute is named *as_name* when one is given; otherwise the
    function's own ``__name__`` is used. *func* must take the instance as
    its first positional argument (the usual "self"). Only *instance* is
    modified; its class is left untouched.
    """
    attribute = func.__name__ if as_name is None else as_name
    # The descriptor protocol turns the plain function into a bound method.
    method = func.__get__(instance, instance.__class__)
    setattr(instance, attribute, method)
    return method


In [83]:
class LoopControlerCallback(tf.keras.callbacks.Callback):
    """Keras callback that installs a gated ``train_step`` and drives its flags.

    When training starts, the callback binds custom ``train_step``, ``gate_on``
    and ``gate_off`` functions to the model, creates one non-trainable boolean
    ``tf.Variable`` on the model for every key of *config*, and binds every
    ``config[name]["cond"]`` function to the callback itself. At the end of
    each epoch the condition functions are evaluated and their results are
    assigned to the matching model variables.

    Args:
        config: Mapping from control-variable name to a dict holding at least
            a ``"cond"`` entry, a function that receives this callback as
            ``self`` and returns a boolean. The installed ``train_step`` reads
            ``model.gate``, so the mapping must contain a ``"gate"`` key.
            Entries other than ``"cond"`` are not read by this class.
    """

    def __init__(self, config: dict, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.config = config

    def on_train_begin(self, logs=None):
        """Install the gated training step and the control variables on the model."""
        self.epochs = 0
        self.steps = 0

        def run_supervised_step(model, data):
            """Run one optimisation step on *data* and return the metric results."""
            x, y = data
            with tf.GradientTape(watch_accessed_variables=True) as tape:
                logits = model(x, training=True)
                loss_value = model.compiled_loss(y, logits)
            variables = tape.watched_variables()
            grads = tape.gradient(loss_value, variables)
            model.optimizer.apply_gradients(zip(grads, variables))
            # Without this call the metrics passed to ``compile`` are never
            # updated, so they would be missing from the reported results.
            model.compiled_metrics.update_state(y, logits)
            return {m.name: m.result() for m in model.metrics}

        # NOTE: the three functions below are bound to the *model*, so inside
        # them ``self`` is the Keras model, not this callback.

        @tf.function
        def train_step(self, data):
            # Dispatch on the ``gate`` variable inside the graph.
            train_metrics = tf.cond(
                self.gate, lambda: self.gate_on(data), lambda: self.gate_off(data)
            )

            # Report the current value of every control variable next to the
            # training metrics.
            final_results = {
                **train_metrics,
                **{
                    c_name: tf.cond(
                        getattr(self, c_name),
                        lambda: tf.constant(True),
                        lambda: tf.constant(False),
                    )
                    for c_name in self.control_variables_names
                },
            }

            return final_results

        @tf.function
        def gate_on(self, data):
            # Branch taken while ``gate`` is True.
            return run_supervised_step(self, data)

        @tf.function
        def gate_off(self, data):
            # Branch taken while ``gate`` is False; currently the same step
            # as ``gate_on``.
            return run_supervised_step(self, data)

        bind(self.model, train_step)
        bind(self.model, gate_on)
        bind(self.model, gate_off)

        self.set_control_variables()
        self.set_conditions()

    def __transform_cond_name__(self, action_name):
        """Return the attribute name under which a condition function is bound."""
        return f"{action_name}_variable_control"

    def set_control_variables(self):
        """Create one boolean, non-trainable variable on the model per config key.

        Every variable starts as ``False``; the variable names are recorded in
        ``model.control_variables_names``.
        """
        self.model.control_variables_names = []
        for control_variable_name in self.config:
            setattr(
                self.model, control_variable_name, tf.Variable(False, trainable=False)
            )
            self.model.control_variables_names.append(control_variable_name)

    def set_conditions(self):
        """Bind every ``"cond"`` function of the config to this callback.

        ``self.all_control_cond`` maps each control-variable name to the name
        of the bound condition method.
        """
        self.all_control_cond = {}
        for action_name, action_config in self.config.items():
            name = self.__transform_cond_name__(action_name)
            print(f"Biding for {action_name} in {name}")
            bind(self, action_config["cond"], name)
            self.all_control_cond[action_name] = name

    def on_epoch_end(self, epoch, logs=None):
        """Re-evaluate every condition and store the result in its model variable.

        The control variables are first updated here, so they keep their
        initial ``False`` value during the first epoch.
        """
        self.epochs += 1
        # ``Variable.assign`` updates the value read by the already traced
        # graph. Rebinding the attribute (``model.x = ...``) would not, because
        # the ``tf.function``-decorated step keeps the variable it was traced
        # with.
        for control_name, control_function_name in self.all_control_cond.items():
            getattr(self.model, control_name).assign(
                getattr(self, control_function_name)()
            )

In [85]:
# Synthetic regression problem: uniformly random inputs and targets.
DATASET_SIZE = 1000
INPUT_SIZE = 2
OUTPUT_SIZE = 1
BATCH_SIZE = 64

features = tf.random.uniform((DATASET_SIZE, INPUT_SIZE))
targets = tf.random.uniform((DATASET_SIZE, OUTPUT_SIZE))
data = tf.data.Dataset.from_tensor_slices((features, targets)).batch(BATCH_SIZE)

# Minimal model: a single dense layer with OUTPUT_SIZE units.
model = tf.keras.Sequential([tf.keras.layers.Dense(OUTPUT_SIZE)])
# NOTE(review): `todel` is not read anywhere in the visible code - confirm
# whether this attribute can be removed.
model.todel = tf.Variable(True)

# Compile the model with an MSE loss and MAE as the reported metric.
model.compile(
    optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.01),
    loss="mse",
    metrics=["mae"],
)


def gate_config(self):
    """Return True when ``self.epochs`` is a multiple of 2."""
    remainder = self.epochs % 2
    return remainder == 0

def delay_config(self):
    """Return True when ``self.epochs`` is a multiple of 4."""
    remainder = self.epochs % 4
    return remainder == 0


# One entry per control variable. "cond" is the function evaluated by the
# callback at the end of every epoch; the True/False entries of "gate" name a
# loss class for each gate state.
config = dict(
    gate={
        "cond": gate_config,
        True: dict(loss=tf.keras.losses.MeanSquaredError),
        False: dict(loss=tf.keras.losses.MeanAbsoluteError),
    },
    delay=dict(cond=delay_config),
)

# Train through the regular Keras fit loop, with the callback controlling it.
loop_controller = LoopControlerCallback(config)
history = model.fit(
    data,
    epochs=10,
    verbose=1,
    callbacks=[loop_controller],
)

Biding for gate in gate_variable_control
Biding for delay in delay_variable_control
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
