## TensorFlow 2 training loop control using the default *model.fit(...)* function

### Task Description

Until now, a custom training loop in TensorFlow 2 has required writing two loops:
1. loop iterating through epochs 
2. loop iterating through batches 

The whole custom training procedure then has to be implemented inside this double-loop block of code. This is neither elegant nor robust, because the advanced features of *model.fit(...)* are lost.

In [1]:
import tensorflow as tf
from typing import Dict, Any, List, Tuple
import functools

2023-02-20 08:16:13.607055: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2023-02-20 08:16:13.607085: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.


In [43]:
def bind(instance, func, as_name=None):
    """
    Attach *func* to *instance* as a bound method and return that method.

    The attribute is stored on the instance itself under *as_name*; when
    *as_name* is None the function's own ``__name__`` is used instead.
    *func* must take the instance as its first positional argument
    (i.e. "self").
    """
    attribute_name = func.__name__ if as_name is None else as_name
    # The descriptor protocol turns the plain function into a method bound to *instance*.
    method = func.__get__(instance, instance.__class__)
    setattr(instance, attribute_name, method)
    return method

In [44]:
# Sizes of the synthetic regression problem used in this notebook.
DATASET_SIZE = 1000
INPUT_SIZE = 2
OUTPUT_SIZE = 1
BATCH_SIZE = 64

# Uniformly random (features, targets) pairs, sliced per example and then batched.
data = tf.data.Dataset.from_tensor_slices(
    (
        tf.random.uniform(shape=(DATASET_SIZE, INPUT_SIZE)),
        tf.random.uniform(shape=(DATASET_SIZE, OUTPUT_SIZE)),
    )
)
data = data.batch(BATCH_SIZE)

In [65]:
class LoopControlerCallback(tf.keras.callbacks.Callback):
    """Keras callback that replaces the model's ``train_step`` with a generated, config-driven one.

    ``config`` maps an action name to a dict holding:

    * ``"cond"``: a function taking this callback as ``self`` and returning the
      gate value for the coming epoch;
    * ``True`` and/or ``False``: the settings of the training sub-step executed
      when the gate is on / off. Recognised keys are ``"loss"``, ``"clipping"``
      (a ``(low, high)`` pair) and ``"variables"`` (a function taking the model
      as ``self`` and returning the variables to watch). ``"exclud_var"`` is
      accepted but not read anywhere in this class.

    On ``on_train_begin`` the callback generates the source of one ``train_step``
    (the "master" step) and one ``<action>_on`` / ``<action>_off`` function per
    configured branch (the "slave" steps), ``exec``s that source and binds the
    resulting functions to ``self.model``.

    NOTE(security): action names and clipping bounds are interpolated into
    source code that is passed to ``exec``. Only use configurations from a
    trusted origin.
    """

    def __init__(
        self, config: Dict[str, Any], default_in_branch: Dict[str, Any] = None, verbose: bool = True, *args, **kwargs
    ) -> None:
        """Store the configuration.

        Args:
            config (Dict[str, Any]): training-control configuration (see class docstring).
            default_in_branch (Dict[str, Any], optional): values used for branch keys
                that a branch does not set itself. Missing keys fall back to None.
            verbose (bool, optional): print the generated source code. Defaults to True.
        """
        super().__init__(*args, **kwargs)
        base_defaults: Dict[str, Any] = {
            "loss": None,
            "clipping": None,
            "variables": None,
            "exclud_var": None,
        }
        # Merge into a fresh dict: the caller's dict is never mutated (on_train_begin
        # writes the "loss" entry) and a partial dict cannot cause a KeyError later.
        self.default_in_branch: Dict[str, Any] = {**base_defaults, **(default_in_branch or {})}
        self.config: Dict[str, Any] = config
        self.verbose: bool = verbose

    def on_train_begin(self, logs=None):
        """Prepare the model right before training starts.

        Runs on every ``model.fit`` call that uses this callback. ``self.model``
        is available here, so this is where the generated steps get bound to it.

        Args:
            logs (optional): unused. Defaults to None.
        """

        # meta attributes for building conditions
        self.epochs: int = 0
        self.batches: int = 0
        self.batches_in_epoch: int = 0
        self.last_loss: float = 0.0
        self.history: List[Any] = []

        # if loss in default branch is None, then use compiled loss
        if self.default_in_branch["loss"] is None:
            self.default_in_branch["loss"] = self.model.compiled_loss

        # extend config with validation step
        self.config = self._extend_config(self.config)

        # bind control variables and control conditions
        self._bind_controlers(self.config)

        # bind master train_step to model
        self._bind_master_step(self.config)

        # bind slave train_steps to model
        self._bind_slaves_steps(self.config)

    def _bind_master_step(self, config: Dict[str, Any]) -> None:
        """Generate ``train_step`` and bind it to the model.

        The generated step selects, per action, the ``_on`` or ``_off`` sub-step
        through ``tf.cond`` on the action's control variable and returns the
        losses, the current metric results and the control states in one dict.

        Args:
            config (Dict[str, Any]): extended configuration (output of ``_extend_config``).
        """

        def _branch_call(name: str, action_config: Dict[str, Any], on: bool) -> str:
            # A branch that is not configured contributes a constant zero loss.
            if on in action_config:
                return f"self.{self._get_actoin_step_name(name, on)}(data)"
            return "0.0"

        def _get_losses(config: Dict[str, Any]) -> str:
            return "{" + ",".join([
                        f"'loss_{an}' : tf.cond(self.control_variables['{an}'], lambda: {_branch_call(an, ac, True)}, lambda: {_branch_call(an, ac, False)})"
                        for an, ac in config.items()
                    ]) + "}"

        function_body = """
@tf.function
def train_step(self, data):
    loss = {losses_config}
    metrics = {{m.name : m.result() for m in self.metrics}}
    
    control_states = {{
            control_name: tf.cond(
                control_value,
                lambda: tf.constant(True),
                lambda: tf.constant(False),
            )
            for control_name, control_value in self.control_variables.items()
    }}
    
    return {{**loss, **metrics, **control_states}}
""".format(losses_config=_get_losses(config))

        if self.verbose:
            print("-------------------MASTER STEP-------------------")
            print(function_body)

        # The generated code only needs module-level names (e.g. tf); the new
        # function is collected in a separate, initially empty namespace.
        defined: Dict[str, Any] = {}
        exec(function_body, {**globals()}, defined)
        bind(self.model, defined["train_step"])

    def _bind_slaves_steps(self, config) -> None:
        """Generate and bind one sub-step per configured branch of every action."""

        if self.verbose:
            print("-------------------SLAVE STEPS-------------------\n")

        for action_name, action_config in config.items():
            for branch in (True, False):
                if branch in action_config:
                    self._bind_slave_step(action_name, action_config[branch], branch)

    def _bind_slave_step(self, action_name: str, fn_config: Dict[str, Any], branch: bool) -> None:
        """Generate one gradient sub-step and bind it to the model.

        Args:
            action_name (str): name of the action the step belongs to.
            fn_config (Dict[str, Any]): branch settings; ``"loss"``, ``"clipping"``
                and ``"variables"`` are read here.
            branch (bool): True for the ``_on`` step, False for the ``_off`` step.
        """
        fn_name = self._get_actoin_step_name(action_name, branch)

        function_body = f"""
@tf.function
def {fn_name}(self, data):
    x, y = data
        """

        if fn_config["variables"]:
            # Watch only the variables returned by the user-supplied selector.
            get_vars_fn = "_get_" + fn_name + "_variables"
            bind(self.model, fn_config["variables"], get_vars_fn)
            function_body += f"""
    with tf.GradientTape(watch_accessed_variables=False) as tape:
        for g in self.{get_vars_fn}():
            tape.watch(g)
            """
        else:
            function_body += """
    with tf.GradientTape(watch_accessed_variables=True) as tape:
            """

        clipping = fn_config["clipping"]
        if clipping:
            clip_low, clip_high = clipping[0], clipping[1]
            clipping_grads = f"[tf.clip_by_value(g, {clip_low}, {clip_high}) for g in grads]"
        else:
            clipping_grads = "grads"

        function_body += """
        logits = self(x, training=True)
        loss_value = loss(y, logits)
    grads = tape.gradient(loss_value, tape.watched_variables())
    self.optimizer.apply_gradients(zip({clipping_grads}, tape.watched_variables()))
    self.compiled_metrics.update_state(y, logits)
    return loss_value
""".format(clipping_grads=clipping_grads)

        if self.verbose:
            print(f"-------------------{fn_name}-------------------")
            print(function_body)

        # The branch settings are exposed as globals of the generated function;
        # this is how the name `loss` used in the template resolves.
        defined: Dict[str, Any] = {}
        exec(function_body, {**globals(), **fn_config}, defined)
        bind(self.model, defined[fn_name])

    def on_epoch_begin(self, epoch: int, logs=None) -> None:
        """Advance the epoch counter and refresh every gating variable.

        Each action's condition function is evaluated on this callback and its
        result is written into the matching control variable on the model.
        """
        self.epochs += 1
        # Variable.assign is used on purpose: rebinding the attribute to a new value
        # would be baked into the traced graph of the @tf.function-decorated steps
        # as a constant.
        for action_name in self.config:
            self.model.control_variables[action_name].assign(
                getattr(self, self.control_conditions[action_name])()
            )

    def _get_actoin_step_name(self, action_name: str, branch: bool) -> str:
        """Return the name of the generated sub-step for *action_name* and *branch*."""
        return f"{action_name}_on" if branch else f"{action_name}_off"

    def _bind_controlers(self, config) -> None:
        """Create one boolean control variable per action and bind its condition.

        The variables are stored on the model (``model.control_variables``); the
        condition functions are bound to this callback under ``<action>_condition``.
        """
        self.model.control_variables = {}
        self.control_conditions = {}
        for action_name, action_config in config.items():
            self.model.control_variables[action_name] = tf.Variable(False, trainable=False)
            condition_function_name = action_name + "_condition"
            bind(self, action_config["cond"], condition_function_name)
            self.control_conditions[action_name] = condition_function_name

    def _extend_config(self, config: Dict[str, Any]) -> Dict[str, Any]:
        """Validate the configuration and fill missing branch fields from ``default_in_branch``.

        Args:
            config (Dict[str, Any]): Configuration to control model training

        Returns:
            Dict[str, Any]: new configuration holding, per action, ``"cond"`` and the
            configured branches with defaults applied. The input is not modified.

        Raises:
            ValueError: an action has no branch, no condition, or an unusable name.
        """
        # Local import: `warnings` is not among the imports of this file.
        import warnings

        def validate_action_config(action_name: str, action_config: Dict[str, Any]) -> None:
            """Validate model training configuration.

            Args:
                action_name (str): name of the action slave train step
                action_config (Dict[str, Any]): configuration of the action slave train step

            Raises:
                ValueError: Missing branch configuration for true/false after cond
                ValueError: Missing controlable cond
                ValueError: action name cannot be used to build a method name
            """

            if self.verbose:
                print(f"------Validating Configuration for {action_name}------")
            if not action_config:
                warnings.warn(
                    f"{action_name} has empty body. Condition and False or True branch must be implemented."
                )
            if (True not in action_config) and (False not in action_config):
                raise ValueError(
                    f"{action_name} has no False or True branch implemented"
                )
            if "cond" not in action_config:
                raise ValueError(f"{action_name} has no condition implemented.")
            # The name becomes part of generated source code, so it has to be an identifier.
            if not (isinstance(action_name, str) and action_name.isidentifier()):
                raise ValueError(
                    f"{action_name!r} is not a valid Python identifier; action names are used to build method names."
                )

        pc = {}
        for action_name, action_config in config.items():
            validate_action_config(action_name, action_config)
            pc[action_name] = {"cond": action_config["cond"]}
            for branch in (True, False):
                if branch in action_config:
                    pc[action_name][branch] = {**self.default_in_branch, **action_config[branch]}

        return pc

In [66]:

# Small regression network: an input layer followed by two stacked Dense layers.
model = tf.keras.Sequential()
model.add(tf.keras.layers.InputLayer(INPUT_SIZE))
model.add(tf.keras.layers.Dense(OUTPUT_SIZE))
model.add(tf.keras.layers.Dense(OUTPUT_SIZE))

# Compile with RMSprop, MSE as the compiled loss and MAE as the tracked metric.
model.compile(
    optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.01),
    loss="mse",
    metrics=["mae"],
)

def gate_config(self):
    """Gate condition: true when the epoch counter of *self* is odd."""
    remainder = self.epochs % 2
    return remainder == 1

def delay_config(self):
    """Delay condition: true once the epoch counter of *self* exceeds 5."""
    threshold = 5
    return self.epochs > threshold

def dummy_vars(self):
    """Variable selector: a one-element list holding the first entry of ``self.variables``."""
    first_variable = self.variables[0]
    return [first_variable]

# Training-control configuration, one entry per action.
config = {}

# "delay": gated by delay_config; only the "on" branch is configured.
config["delay"] = {
    "cond": delay_config,
    True: {"loss": tf.keras.losses.MeanSquaredError()},
}

# "gate": gated by gate_config; the "on" branch clips gradients and watches only
# the variables returned by dummy_vars, the "off" branch uses the defaults.
config["gate"] = {
    "cond": gate_config,
    True: {
        "loss": tf.keras.losses.MeanAbsoluteError(),
        "clipping": (-0.2, 0.3),
        "variables": dummy_vars,
    },
    False: {"loss": tf.keras.losses.MeanAbsoluteError()},
}

lcc = LoopControlerCallback(config, verbose=1)

# start training
history = model.fit(data, callbacks=[lcc], epochs=10, verbose=1)

-------------------MASTER STEP-------------------

@tf.function
def train_step(self, data):
    loss = {'loss_delay' : tf.cond(self.control_variables['delay'], lambda: self.delay_on(data), lambda: 0.0),'loss_gate' : tf.cond(self.control_variables['gate'], lambda: self.gate_on(data), lambda: self.gate_off(data))}
    metrics = {m.name : m.result() for m in self.metrics}
    
    control_states = {
            control_name: tf.cond(
                control_value,
                lambda: tf.constant(True),
                lambda: tf.constant(False),
            )
            for control_name, control_value in self.control_variables.items()
    }
    
    return {**loss, **metrics, **control_states}

-------------------SLAVE STEPS-------------------

-------------------delay_on-------------------

@tf.function
def delay_on(self, data):
    x, y = data
        
    with tf.GradientTape(watch_accessed_variables=True) as tape:
            
        logits = self(x, training=True)
        loss_