## TensorFlow 2 training loop control using the default Keras *Model.fit(...)* function

### Task Description

Up to now, a custom training loop in TensorFlow 2 has required writing two loops:
1. loop iterating through epochs 
2. loop iterating through batches 

The whole custom training procedure then has to be implemented inside this double-loop block of code. This is neither elegant nor robust, because the advanced features of the Keras *Model.fit(...)* function are not available.

In [1]:
import tensorflow as tf
from types import MethodType
import functools

2023-02-01 15:31:27.532975: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2023-02-01 15:31:27.533025: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.


In [2]:
def bind(instance, func, as_name=None):
    """
    Attach *func* to *instance* as a bound method and return that method.

    The attribute is stored under *as_name* when it is given, otherwise under
    ``func.__name__``. *func* must accept the instance as its first positional
    argument (i.e. "self"). Binding goes through the descriptor protocol
    (``func.__get__``), so any callable implementing ``__get__`` can be used.
    """
    attribute_name = func.__name__ if as_name is None else as_name
    method = func.__get__(instance, instance.__class__)
    setattr(instance, attribute_name, method)
    return method


In [113]:
class LoopControlerCallback(tf.keras.callbacks.Callback):
    """Keras callback that switches the model's train step between two losses.

    ``config`` maps an action name to a dict with these keys:

    * ``"cond"``  - function taking this callback as its only argument and
      returning a bool; it is evaluated at the end of every epoch.
    * ``True``    - optional dict whose ``"loss"`` is used while the control
      variable is True.
    * ``False``   - optional dict whose ``"loss"`` is used while the control
      variable is False.

    Config entries that are not dicts containing ``"cond"`` are ignored.
    For every controlled action a boolean, non-trainable ``tf.Variable`` named
    after the action is attached to the model, and the model's ``train_step``
    is replaced by one that dispatches on the ``"gate"`` action's variable.
    """

    def __init__(self, config: dict, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.config = config

    def on_train_begin(self, logs=None):
        """Reset the counters and install the control variables and steps."""
        # Meta attributes available to the "cond" functions. Only `epochs`
        # is updated by this class (see `on_epoch_end`).
        self.epochs = 0
        self.batches = 0
        self.batches_in_epoch = 0
        self.last_loss = 0.0
        self.history = []

        self._init_cv()
        self._bind_cv()
        self._bin_slaves_steps("gate")
        self._bind_master_step("gate")

    def _controlled_actions(self):
        """Return the config entries that describe a controlled action."""
        return {
            action_name: action_config
            for action_name, action_config in self.config.items()
            if isinstance(action_config, dict) and "cond" in action_config
        }

    def _get_cv_name(self, action_name):
        """Return the attribute name of the condition bound to this callback."""
        return f"{action_name}_cv"

    def _get_cc_names(self, condition_name):
        """Return the step names as ``(off_name, on_name)`` - in that order."""
        return (f"{condition_name}_off", f"{condition_name}_on")

    def _init_cv(self):
        """Attach one boolean control variable per controlled action to the model."""
        self.model.cv_names = []
        for cv_name in self._controlled_actions():
            setattr(
                self.model, cv_name, tf.Variable(False, trainable=False)
            )
            self.model.cv_names.append(cv_name)

    def _bind_cv(self):
        """Bind every action's ``"cond"`` function to this callback."""
        self.c_conds = {}
        for action_name, action_config in self._controlled_actions().items():
            name = self._get_cv_name(action_name)
            bind(self, action_config["cond"], name)
            self.c_conds[action_name] = name

    def _bind_master_step(self, name):
        """Replace ``model.train_step`` with a step dispatching on variable *name*.

        The returned logs contain the loss under ``"loss_<name>"``, the current
        result of every model metric, and the state of every control variable.
        """
        off_step_name, on_step_name = self._get_cc_names(name)
        loss_key = f"loss_{name}"

        @tf.function
        def train_step(model, data):
            loss = {
                loss_key: tf.cond(
                    getattr(model, name),
                    lambda: getattr(model, on_step_name)(data),
                    lambda: getattr(model, off_step_name)(data),
                )
            }

            metrics = {m.name: m.result() for m in model.metrics}

            control_states = {
                c_name: tf.cond(
                    getattr(model, c_name),
                    lambda: tf.constant(True),
                    lambda: tf.constant(False),
                )
                for c_name in model.cv_names
            }

            return {**loss, **metrics, **control_states}

        bind(self.model, train_step)

    def _bin_slaves_steps(self, name):
        """Bind the "on" and "off" sub-steps of action *name* to the model.

        A branch missing from the config gets a step that does no training
        and returns the constant ``-1.0``.
        """
        @tf.function
        def dummy_empty_step(model, data):
            return tf.constant(-1.0)

        # `_get_cc_names` returns (off, on): unpack in that order so that the
        # loss configured under True really is the one run while the gate is on.
        off_step_name, on_step_name = self._get_cc_names(name)
        branches = self.config[name]
        for state, step_name in ((True, on_step_name), (False, off_step_name)):
            if state in branches:
                self._bind_slave_step(step_name, branches[state]["loss"])
            else:
                bind(self.model, dummy_empty_step, step_name)

    def _bind_slave_step(self, fn_name, loss):
        """Bind to the model, as *fn_name*, one optimisation step using *loss*."""
        @tf.function
        def slave_step(model, data):
            x, y = data
            with tf.GradientTape(watch_accessed_variables=True) as tape:
                logits = model(x, training=True)
                loss_value = loss(y, logits)
            grads = tape.gradient(loss_value, tape.watched_variables())
            model.optimizer.apply_gradients(zip(grads, tape.watched_variables()))
            model.compiled_metrics.update_state(y, logits)
            return loss_value

        bind(self.model, slave_step, fn_name)

    def on_epoch_end(self, epoch, logs=None):
        """Re-evaluate every condition and store the result in its control variable."""
        self.epochs += 1
        # Use `Variable.assign` rather than rebinding the attribute: a rebound
        # Python value would be frozen into the graph traced by @tf.function.
        for control_name, control_function_name in self.c_conds.items():
            getattr(self.model, control_name).assign(
                getattr(self, control_function_name)()
            )
        


In [115]:
# Problem dimensions and batch size for the synthetic regression task.
DATASET_SIZE = 1000
INPUT_SIZE = 2
OUTPUT_SIZE = 1
BATCH_SIZE = 64

# Uniform random inputs and targets, served in batches of BATCH_SIZE.
features = tf.random.uniform((DATASET_SIZE, INPUT_SIZE))
targets = tf.random.uniform((DATASET_SIZE, OUTPUT_SIZE))
data = tf.data.Dataset.from_tensor_slices((features, targets)).batch(BATCH_SIZE)

# A single dense layer, compiled with RMSprop, MSE loss and an MAE metric.
model = tf.keras.Sequential([tf.keras.layers.Dense(OUTPUT_SIZE)])
model.compile(
    optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.01),
    loss="mse",
    metrics=["mae"],
)


def gate_config(self):
    """Return True when the number of epochs counted on *self* is even."""
    remainder = self.epochs % 2
    return remainder == 0

# Loop-control configuration: the "gate" action is driven by `gate_config`
# and carries one loss for each state of its control variable.
config = {
    "default_step": True,
    "gate": {
        "cond": gate_config,
        True: {"loss": tf.keras.losses.MeanSquaredError()},
        False: {"loss": tf.keras.losses.MeanAbsoluteError()},
    },
}

lcc = LoopControlerCallback(config)

# Train through the standard fit() loop with the callback attached.
history = model.fit(data, epochs=6, verbose=1, callbacks=[lcc])

Epoch 1/6
Epoch 2/6
Epoch 3/6
Epoch 4/6
Epoch 5/6
Epoch 6/6


In [5]:
import time
class B:
    """Scratch class: bind an exec-generated function that uses a local object."""

    def __init__(self) -> None:
        self.x = 2

    def boo(self):
        """Generate ``train`` with exec, bind it to this instance and call it.

        Prints the value returned by ``train`` and returns True.
        """
        bb = tf.keras.losses.MeanAbsoluteError()
        # Snapshot of the locals (`self`, `bb`); exec stores `train` in it.
        scope = locals()
        source = """
def train(self):
    print(bb([1,2,3], [3,4,5]))
    return self.x
"""
        # The locals are merged into the globals handed to exec so that the
        # generated function can resolve `bb` when it is called later.
        exec(source, {**globals(), **scope}, scope)
        bind(self, scope["train"])
        print(self.train())
        return True

# Run the experiment: boo() binds the generated train() to `b`, calls it
# (train() prints the loss value) and prints the value train() returns.
b = B()
b.boo()
# self.train_substep()

tf.Tensor(2, shape=(), dtype=int32)
2


True