## TensorFlow 2 training-loop control using the default *Model.fit(...)* method

### Task Description

Until now, a custom training loop in TensorFlow 2 has required writing two nested loops:
1. a loop iterating over epochs
2. a loop iterating over batches

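The whole custom training procedure then has to be implemented inside this double loop. That is neither elegant nor robust, because it gives up the advanced features of *Model.fit(...)* (callbacks, progress reporting, metric handling). This notebook instead keeps *Model.fit(...)* and steers training from a Keras callback: at the end of each epoch the callback updates boolean `tf.Variable` flags on the model, and a custom `train_step` reads them to decide which training branch (and loss) to run.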

In [92]:
import tensorflow as tf
from types import MethodType
import functools

In [93]:
def bind(instance, func, as_name=None):
    """
    Attach *func* to *instance* as a bound method and return that method.

    The method is stored on *instance* under *as_name* when one is given,
    otherwise under ``func.__name__``. *func* must take the instance as its
    first positional argument (i.e. "self"). Binding goes through the
    descriptor protocol (``__get__``), so objects that customise binding,
    such as ``tf.function`` wrappers, bind the way they would on a class.
    """
    target_name = as_name if as_name is not None else func.__name__
    method = func.__get__(instance, instance.__class__)
    setattr(instance, target_name, method)
    return method


In [174]:
class LoopControlerCallback(tf.keras.callbacks.Callback):
    """Control a custom ``train_step`` from inside ``Model.fit`` via flags.

    For every key of *config* the callback puts a non-trainable boolean
    ``tf.Variable`` on the model (under the same attribute name). At the end
    of each epoch it assigns that variable the result of the key's ``"cond"``
    function. The ``train_step`` installed on the model reads these variables
    rather than Python booleans, so the branch it takes can change between
    epochs.

    Args:
        config: mapping ``control_name -> {"cond": fn, True: {...},
            False: {...}}``. ``fn`` is bound to this callback and called with
            no further arguments, so it can read counters such as
            ``self.epochs``. For the ``"gate"`` control, the optional
            ``config["gate"][True]["loss"]`` and ``config["gate"][False]["loss"]``
            entries select the loss of the on/off training branches. When
            absent, MeanAbsoluteError (on) and MeanSquaredError (off) are used.

    NOTE(review): the installed ``train_step`` only branches on the
    ``"gate"`` control. Other controls are created, updated and reported in
    the logs, but do not change the update step.
    """

    def __init__(self, config: dict, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.config = config

    def on_train_begin(self, logs=None):
        """Reset the counters and install the controllable step on the model."""
        # Meta attributes the "cond" functions can build conditions from.
        self.epochs = 0
        self.batches = 0
        # NOTE(review): initialised here but never updated by this class.
        self.batches_in_epoch = 0
        self.last_loss = 0.0
        self.history = []

        @tf.function
        def train_step(model, data):
            # The gate variable picks which bound training branch runs.
            loss = {
                "loss": tf.cond(
                    model.gate,
                    lambda: model.gate_on(data),
                    lambda: model.gate_off(data),
                )
            }
            metrics = {m.name: m.result() for m in model.metrics}
            # Expose the current value of every control variable in the logs.
            control_states = {
                c_name: tf.identity(getattr(model, c_name))
                for c_name in model.cv_names
            }
            return {**loss, **metrics, **control_states}

        bind(self.model, train_step)

        self._init_cv()
        self._bind_cv()

        # Take the branch losses from the config instead of hard-coding them;
        # the previous hard-coded defaults are used only as fallbacks.
        gate_settings = self.config.get("gate", {})
        self._bind_train_branch(
            "gate_on",
            self._branch_loss(
                gate_settings, True, tf.keras.losses.MeanAbsoluteError
            ),
        )
        self._bind_train_branch(
            "gate_off",
            self._branch_loss(
                gate_settings, False, tf.keras.losses.MeanSquaredError
            ),
        )

    @staticmethod
    def _branch_loss(control_config, state, default_factory):
        """Return ``control_config[state]["loss"]``, or ``default_factory()`` if absent."""
        loss = control_config.get(state, {}).get("loss")
        return default_factory() if loss is None else loss

    def _get_cv_name(self, action_name):
        """Return the attribute name under which the "cond" of *action_name* is bound."""
        return f"{action_name}_variable_control"

    def _get_cc_names(self, condition_name):
        """Return the ``(off, on)`` names for *condition_name*.

        NOTE(review): not referenced elsewhere in this class.
        """
        return (f"{condition_name}_off", f"{condition_name}_on")

    def _init_cv(self):
        """Create (or reset to ``False``) one boolean control variable per config key.

        The names are recorded in ``self.model.cv_names`` so that
        ``train_step`` can report them.
        """
        self.model.cv_names = []
        for control_variable_name in self.config:
            existing = getattr(self.model, control_variable_name, None)
            if isinstance(existing, tf.Variable) and existing.dtype == tf.bool:
                # Reuse the variable from an earlier fit(). Keras caches the
                # traced train function between fit() calls, and that graph
                # captured this exact variable, so a new one would never be read.
                existing.assign(False)
            else:
                setattr(
                    self.model,
                    control_variable_name,
                    tf.Variable(False, trainable=False),
                )
            self.model.cv_names.append(control_variable_name)

    def _bind_cv(self):
        """Bind each config ``"cond"`` function to this callback.

        Fills ``self.c_conds`` with ``control_name -> bound method name``.
        """
        self.c_conds = {}
        for action_name, action_config in self.config.items():
            name = self._get_cv_name(action_name)
            bind(self, action_config["cond"], name)
            self.c_conds[action_name] = name

    def _bind_train_branch(self, fn_name, loss):
        """Attach to the model a ``tf.function`` named *fn_name* that runs one update with *loss*.

        The bound function takes a batch ``data`` that unpacks to ``(x, y)``.
        It applies one optimizer step to the variables watched during the
        forward pass, updates the compiled metrics and returns the batch loss.
        """

        def branch(model, data):
            x, y = data
            with tf.GradientTape(watch_accessed_variables=True) as tape:
                logits = model(x, training=True)
                loss_value = loss(y, logits)
            # Differentiate w.r.t. the variables the forward pass actually read.
            variables = tape.watched_variables()
            grads = tape.gradient(loss_value, variables)
            model.optimizer.apply_gradients(zip(grads, variables))
            model.compiled_metrics.update_state(y, logits)
            return loss_value

        # A plain closure captures *loss* directly (no exec of a source string).
        # Keep the requested name for the traced function.
        branch.__name__ = fn_name
        bind(self.model, tf.function(branch), fn_name)

    def on_epoch_end(self, epoch, logs=None):
        """Count the finished epoch and re-evaluate every control condition.

        The condition functions run after ``self.epochs`` has been
        incremented, so they see the number of completed epochs.
        """
        self.epochs += 1
        # Variable.assign updates the value the traced train_step reads.
        # Rebinding the attribute (model.gate = <sth>) would not, because the
        # graph was compiled against the original variable.
        for control_name, control_function_name in self.c_conds.items():
            condition = getattr(self, control_function_name)
            getattr(self.model, control_name).assign(condition())

    def on_batch_end(self, batch, logs=None):
        """Count processed training batches across all epochs of this fit()."""
        self.batches += 1
        


In [175]:
DATASET_SIZE, INPUT_SIZE, OUTPUT_SIZE = 1000, 2, 1
BATCH_SIZE = 64
# Synthetic regression data: uniform random inputs and targets, batched.
data = tf.data.Dataset.from_tensor_slices(
    (
        tf.random.uniform((DATASET_SIZE, INPUT_SIZE)),
        tf.random.uniform((DATASET_SIZE, OUTPUT_SIZE)),
    )
).batch(BATCH_SIZE)

# A single dense layer is enough to exercise the loop-control machinery.
# (The stray debug variable `model.todel` was removed: it was never read and
# only added an extra weight to the model.)
model = tf.keras.Sequential([tf.keras.layers.Dense(OUTPUT_SIZE)])

# Compile the model. NOTE: when LoopControlerCallback replaces train_step,
# gradient updates use the losses of its gate branches, so the "mse" given
# here is not what drives training.
model.compile(
    optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.01),
    loss="mse",
    metrics=["mae"],
)


def gate_config(self):
    """Condition for the "gate" control: true when ``self.epochs`` is even.

    *self* is the object this function is bound to. When used as
    ``config["gate"]["cond"]`` that is the loop-control callback.
    """
    _, remainder = divmod(self.epochs, 2)
    return remainder == 0

def delay_config(self):
    """Condition that is true when ``self.epochs`` is a multiple of 4.

    NOTE(review): not referenced by the config in this notebook.
    """
    _, remainder = divmod(self.epochs, 4)
    return remainder == 0


# Loop-control configuration: "cond" decides the gate value after each
# epoch; the True/False entries hold the per-branch settings.
config = {
    "gate": {
        "cond": gate_config,
        True: {"loss": tf.keras.losses.MeanSquaredError()},
        False: {"loss": tf.keras.losses.MeanAbsoluteError()},
    },
}

lcc = LoopControlerCallback(config)
# Start training; the callback updates the control variables between epochs.
history = model.fit(data, epochs=6, verbose=1, callbacks=[lcc])

Epoch 1/6
Epoch 2/6
Epoch 3/6
Epoch 4/6
Epoch 5/6
Epoch 6/6


In [176]:
import time
class B:
    """Scratch experiment: attach a runtime-built method that uses a local object."""

    def __init__(self) -> None:
        self.x = 2

    def boo(self):
        """Bind a ``train`` method to this instance, call it once and return ``True``.

        ``train`` prints the MeanAbsoluteError of two fixed vectors and returns
        ``self.x``; ``boo`` prints that return value.
        """
        bb = tf.keras.losses.MeanAbsoluteError()

        # A plain closure captures ``bb`` directly; building the source as a
        # string and exec-ing it is unnecessary.
        def train(instance):
            print(bb([1, 2, 3], [3, 4, 5]))
            return instance.x

        bind(self, train)
        print(self.train())
        return True

# Run the experiment once: boo() prints the MAE tensor and b.x, then returns True.
b = B()
b.boo()

tf.Tensor(2, shape=(), dtype=int32)
2


True