## TensorFlow 2 training loop control using the default Keras *model.fit(...)* function

### Task Description

Up to now, a custom training loop in TensorFlow 2 has required writing two loops:
1. loop iterating through epochs 
2. loop iterating through batches 

The whole custom training procedure then has to be implemented inside this double-loop block of code. This is neither elegant nor robust, because it gives up the advanced features of Keras *model.fit(...)*.

In [24]:
import functools
import warnings
from typing import Dict, Any, List, Tuple

import tensorflow as tf

In [3]:
def bind(instance, func, as_name=None):
    """
    Attach *func* to *instance* as a bound method and return that method.

    The attribute name is *as_name* when given, otherwise the function's own
    ``__name__``. *func* must take the instance as its first positional
    argument (i.e. "self"). Only this one instance is affected; its class is
    left untouched.
    """
    attribute = func.__name__ if as_name is None else as_name
    # The descriptor protocol turns the plain function into a method bound to *instance*.
    method = func.__get__(instance, instance.__class__)
    setattr(instance, attribute, method)
    return method


In [35]:
class LoopControlerCallback(tf.keras.callbacks.Callback):
    """Keras callback that installs a generated, condition-gated ``train_step`` on the model.

    For every action in ``config`` the callback:

    * creates a non-trainable boolean ``tf.Variable`` on the model, named after the action,
    * generates one training sub-step per configured branch (``<action>_on`` for the
      ``True`` branch, ``<action>_off`` for the ``False`` branch) and binds it to the model,
    * generates a master ``train_step`` that picks the sub-step with ``tf.cond`` on the
      action's variable and binds it to the model.

    At the beginning of every epoch each variable is re-assigned from the action's
    ``cond`` function. ``cond`` is bound to this callback, so it receives the callback
    as ``self`` and can read e.g. ``self.epochs``.

    Expected ``config`` layout::

        {
            "<action_name>": {
                "cond": <function(self) -> bool>,
                True: {"loss": ...},     # optional branch, used while cond is True
                False: {"loss": ...},    # optional branch, used while cond is False
            },
        }

    At least one of the two branches must be present. Action names are pasted into
    generated source code, so they must be valid Python identifiers.
    """

    def __init__(
        self, config: Dict[str, Any], default_in_branch: Dict[str, Any] = None, verbose: bool = True, *args, **kwargs
    ) -> None:
        """Store the raw configuration; the model is only modified in ``on_train_begin``.

        Args:
            config (Dict[str, Any]): Action name -> action configuration (see class docstring).
            default_in_branch (Dict[str, Any], optional): Values merged into every configured
                branch before the branch's own values. Defaults to a dict with the keys
                "loss", "clipping", "var" and "exclud_var" all set to None.
            verbose (bool, optional): Print the generated source code. Defaults to True.
        """
        super(LoopControlerCallback, self).__init__(*args, **kwargs)
        # Work on a copy: on_train_begin fills in the compiled loss and must not
        # modify the dictionary owned by the caller.
        self.default_in_branch: Dict[str, Any] = dict(default_in_branch) if default_in_branch else {
                "loss": None, 
                "clipping": None,
                "var": None,
                "exclud_var": None,
            }
        self.config: Dict[str, Any] = config
        self.verbose: bool = verbose

    def on_train_begin(self, logs=None):
        """Prepare the model; called by Keras on every ``model.fit`` that uses this callback.

        Resets the counters that ``cond`` functions may read, resolves the default loss,
        normalizes the configuration and binds the control variables and the generated
        training steps to ``self.model``.

        Args:
            logs (optional): Not used. Defaults to None.
        """

        # meta attributes for building conditions
        self.epochs: int = 0
        self.batches: int = 0
        self.batches_in_epoch: int = 0
        self.last_loss: float = 0.0
        self.history: List[Any] = []

        # if loss in default branch is missing or None, then use compiled loss
        if not self.default_in_branch.get("loss"):
            self.default_in_branch["loss"] = self.model.compiled_loss

        # extend config with validation step
        self.config = self._extend_config(self.config)
        # initiate control variables
        self._init_cv()
        # bind control functions to this callback
        self._bind_cv()
        # bind master train_step to model
        self._bind_master_step()
        # bind slave train_steps to model
        self._bin_slaves_steps()

    def _validate_action_config(self, action_name: str, action_config: Dict[str, Any]) -> None:
        """Validate the configuration of a single action.

        Args:
            action_name (str): name of the action slave train step
            action_config (Dict[str, Any]): configuration of the action slave train step

        Raises:
            ValueError: The action name is not a valid Python identifier
            ValueError: Missing branch configuration for true/false after cond
            ValueError: Missing controllable cond
        """
        if self.verbose:
            print(f"------Validating Configuration for {action_name}------")
        # The name becomes an attribute name and part of exec'd source code, so
        # anything that is not an identifier must be rejected before code generation.
        if not isinstance(action_name, str) or not action_name.isidentifier():
            raise ValueError(
                f"{action_name!r} is not a valid Python identifier and cannot be used as an action name."
            )
        if action_config == {}:
            warnings.warn(
                f"{action_name} has empty body. Condition and False or True branch must be implemented.\n It's ignored in furhter computations",
                stacklevel=2,
            )
        if (True not in action_config) and (False not in action_config):
            raise ValueError(
                f"{action_name} has no False or True branch implemented"
            )
        if "cond" not in action_config:
            raise ValueError(f"{action_name} has no condition implemented.")

    def _extend_config(self, config: Dict[str, Any]) -> Dict[str, Any]:
        """Validate the configuration and fill missing branch fields from ``default_in_branch``.

        Only branches present in the input are kept; a missing branch stays missing.

        Args:
            config (Dict[str, Any]): Configuration to control model training

        Returns:
            Dict[str, Any]: New configuration with "cond" and the completed branches.
        """
        pc = {}
        for action_name, action_config in config.items():
            self._validate_action_config(action_name, action_config)
            pc[action_name] = {"cond": action_config["cond"]}
            for branch in (True, False):
                if branch in action_config:
                    # Branch values win over the defaults.
                    pc[action_name][branch] = {**self.default_in_branch, **action_config[branch]}

        return pc

    def _get_cv_name(self, action_name: str) -> str:
        """Get the name under which the action's condition function is bound to this callback.

        Args:
            action_name (str): name of the action slave train step

        Returns:
            str: attribute name of the bound condition function
        """
        return f"{action_name}_cv"

    def _get_cc_names(self, condition_name: str) -> Tuple[str, str]:
        """Return the names of the generated sub-steps as ``(off_name, on_name)``."""
        return (f"{condition_name}_off", f"{condition_name}_on")

    def _init_cv(self) -> None:
        """Create one boolean control variable per action on the model.

        Each variable is a non-trainable ``tf.Variable`` initialised to False and stored
        as a model attribute named after the action; the names are collected in
        ``self.model.cv_names``.
        """
        self.model.cv_names = []
        for cv_name in self.config:
            setattr(self.model, cv_name, tf.Variable(False, trainable=False))
            self.model.cv_names.append(cv_name)

    def _bind_cv(self) -> None:
        """Bind every action's ``cond`` function to this callback.

        ``self.c_conds`` maps the action name to the attribute name of the bound function.
        """
        self.c_conds = {}
        for action_name, action_config in self.config.items():
            name = self._get_cv_name(action_name)
            bind(self, action_config["cond"], name)
            self.c_conds[action_name] = name

    def _bind_master_step(self) -> None:
        """Generate the master ``train_step`` and bind it to the model.

        The generated function returns a dict with one ``loss_<action>`` entry per
        action (the result of the selected sub-step, or 0.0 when the selected branch is
        not configured), the current metric results and the control variable states.
        """

        def _branch_call(step_name: str, configured: bool) -> str:
            # Source of one tf.cond branch; an unconfigured branch contributes 0.0.
            return f"self.{step_name}(data)" if configured else "0.0"

        entries = []
        for an, ac in self.config.items():
            off_name, on_name = self._get_cc_names(an)
            on_call = _branch_call(on_name, True in ac)
            off_call = _branch_call(off_name, False in ac)
            entries.append(f"'loss_{an}' : tf.cond(self.{an}, lambda: {on_call}, lambda: {off_call})")
        losses_config = "{" + ",".join(entries) + "}"

        # Doubled braces are literal braces for str.format; only {losses_config} is substituted.
        function_body = """
@tf.function
def train_step(self, data):
    loss = {losses_config}
    metrics = {{m.name : m.result() for m in self.metrics}}
    
    control_states = {{
            c_name: tf.cond(
                getattr(self, c_name),
                lambda: tf.constant(True),
                lambda: tf.constant(False),
            )
            for c_name in self.cv_names
    }}
    
    return {{**loss, **metrics, **control_states}}
""".format(
            losses_config=losses_config
        )
        if self.verbose:
            print("-------------------MASTER STEP-------------------")
            print(function_body)
        # The generated code only needs module-level names (tf); the definition
        # itself lands in a dedicated dict so nothing else is overwritten.
        exec_locals: Dict[str, Any] = {}
        exec(function_body, dict(globals()), exec_locals)
        bind(self.model, exec_locals["train_step"])

    def _bin_slaves_steps(self) -> None:
        """Generate and bind the sub-steps for every configured branch of every action."""
        if self.verbose:
            print("-------------------SLAVE STEPS-------------------\n")
        for action_name, action_config in self.config.items():
            off_step_name, on_step_name = self._get_cc_names(action_name)
            if True in action_config:
                self._bind_slave_step(on_step_name, action_config[True])
            if False in action_config:
                self._bind_slave_step(off_step_name, action_config[False])

    def _bind_slave_step(self, fn_name: str, fn_config: Dict[str, Any]) -> None:
        """Generate one training sub-step named *fn_name* and bind it to the model.

        The entries of *fn_config* are exposed to the generated code as global names.
        The generated step currently reads only ``loss``; "clipping", "var" and
        "exclud_var" are not used by it.

        Args:
            fn_name (str): name of the generated sub-step
            fn_config (Dict[str, Any]): completed branch configuration
        """
        function_body = f"""
@tf.function
def {fn_name}(self, data):
   x, y = data
   with tf.GradientTape(watch_accessed_variables=True) as tape:
         logits = self(x, training=True)
         loss_value = loss(y, logits)
   grads = tape.gradient(loss_value, tape.watched_variables())
   self.optimizer.apply_gradients(zip(grads, tape.watched_variables()))
   self.compiled_metrics.update_state(y, logits)
   return loss_value
"""
        if self.verbose:
            print(f"-------------------{fn_name}-------------------")
            print(function_body)
        # `loss` in the generated body is resolved through these globals.
        exec_globals = {**globals(), **fn_config}
        exec_locals: Dict[str, Any] = {}
        exec(function_body, exec_globals, exec_locals)
        bind(self.model, exec_locals[fn_name])

    def on_epoch_begin(self, epoch: int, logs=None) -> None:
        """Control gating variable from the level of callback which can work on epoch/batch level.

        Increments ``self.epochs`` and then re-evaluates every condition function,
        storing the result in the matching control variable of the model.

        Args:
            epoch (int): index of the epoch, as passed by Keras (not used)
            logs (optional): Not used. Defaults to None.
        """
        self.epochs += 1
        # tf.variable.assign is different than tf.variable = <sth>. The second option is compiled to static
        # value in TF graph of computation as the result of @tf.function decorators on the generated steps
        for control_name, control_function_name in self.c_conds.items():
            getattr(self.model, control_name).assign(
                getattr(self, control_function_name)()
            )

In [36]:
# Synthetic regression problem: uniformly random inputs and targets.
DATASET_SIZE = 1000
INPUT_SIZE = 2
OUTPUT_SIZE = 1
BATCH_SIZE = 64

data = tf.data.Dataset.from_tensor_slices(
    (
        tf.random.uniform((DATASET_SIZE, INPUT_SIZE)),
        tf.random.uniform((DATASET_SIZE, OUTPUT_SIZE)),
    )
)
data = data.batch(BATCH_SIZE)

# Single dense layer; the compiled "mse" loss is the fallback used by branches
# that do not configure their own loss.
model = tf.keras.Sequential([tf.keras.layers.Dense(OUTPUT_SIZE)])
model.compile(
    loss="mse",
    metrics=["mae"],
    optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.01),
)


def gate_config(self):
    """Condition for the "gate" action: true on odd values of ``self.epochs``."""
    is_odd_epoch = self.epochs % 2 == 1
    return is_odd_epoch

def delay_config(self):
    """Condition for the "delay" action: true once ``self.epochs`` exceeds 5."""
    past_delay = self.epochs > 5
    return past_delay

# Training schedule:
#   gate  - MAE sub-step while gate_config is true; the empty False branch is
#           completed from the callback defaults.
#   delay - MSE sub-step while delay_config is true; no False branch at all.
config = {
    "gate": {
        "cond": gate_config,
        True: {"loss": tf.keras.losses.MeanAbsoluteError()},
        False: {},
    },
    "delay": {
        "cond": delay_config,
        True: {"loss": tf.keras.losses.MeanSquaredError()},
    },
}

lcc = LoopControlerCallback(config)
# start training
history = model.fit(data, epochs=10, verbose=1, callbacks=[lcc])

-------------------MASTER STEP-------------------

@tf.function
def train_step(self, data):
    loss = {'loss_gate' : tf.cond(self.gate, lambda: self.gate_on(data), lambda: self.gate_off(data)),'loss_delay' : tf.cond(self.delay, lambda: self.delay_on(data), lambda: 0.0)}
    metrics = {m.name : m.result() for m in self.metrics}
    
    control_states = {
            c_name: tf.cond(
                getattr(self, c_name),
                lambda: tf.constant(True),
                lambda: tf.constant(False),
            )
            for c_name in self.cv_names
    }
    
    return {**loss, **metrics, **control_states}

-------------------SLAVE STEPS-------------------

-------------------gate_on-------------------

@tf.function
def gate_on(self, data):
   x, y = data
   with tf.GradientTape(watch_accessed_variables=True) as tape:
         logits = self(x, training=True)
         loss_value = loss(y, logits)
   grads = tape.gradient(loss_value, tape.watched_variables())
   self.optimiz

In [86]:
# Configuration used to exercise parse_config below:
#   gate  - both branches with explicit losses,
#   delay - True branch only,
#   test  - empty True branch, so every field comes from the defaults.
# An entry with only "cond" (no True/False branch) would be rejected by validation.
config = {
    "gate": {
        "cond": gate_config,
        True: {"loss": tf.keras.losses.MeanSquaredError()},
        False: {"loss": tf.keras.losses.MeanAbsoluteError()},
    },
    "delay": {
        "cond": delay_config,
        True: {"loss": tf.keras.losses.MeanSquaredError()},
    },
    "test": {
        "cond": delay_config,
        True: {},
    },
}

def validate_action_config(action_name, action_config) -> bool:
    """Check that one action configuration has a condition and at least one branch.

    Args:
        action_name: name of the action, used only in messages
        action_config: mapping that must contain "cond" and a True and/or False branch

    Returns:
        bool: True when the configuration is valid.

    Raises:
        ValueError: neither a True nor a False branch is present
        ValueError: the "cond" entry is missing
    """
    if action_config == {}:
        # Previously a Warning object was created and discarded, so nothing was reported.
        warnings.warn(
            f"{action_name} has empty body. Condition and False or True branch must be implemented.\n It's ignored in furhter computations",
            stacklevel=2,
        )
    # A single branch (True only or False only) is a valid configuration.
    if (True not in action_config) and (False not in action_config):
        raise ValueError(f"{action_name} has no False or True branch implemented")
    if "cond" not in action_config:
        raise ValueError(f"{action_name} has no condition implemented.")
    return True


            

def parse_config(config, default_in_branch = None):
    """Validate *config* and complete every configured branch with default values.

    Args:
        config: mapping of action name -> action configuration
        default_in_branch: values merged into each configured branch before the
            branch's own values; when empty or None a built-in default is used.

    Returns:
        dict: action name -> {"cond": ..., True: branch or None, False: branch or None};
        a branch that is absent from the input is reported as None.
    """
    defaults = default_in_branch or {
        "loss" : "self.compiled_loss",
        "clipping" : None, 
        "var": None,
        "exclud_var": None,
    }

    def complete(action_config, branch):
        # Branch values override the defaults; a missing branch stays None.
        if branch not in action_config:
            return None
        return {**defaults, **action_config[branch]}

    parsed = {}
    for action_name, action_config in config.items():
        validate_action_config(action_name, action_config)
        on_branch = complete(action_config, True)
        off_branch = complete(action_config, False)
        parsed[action_name] = {
            "cond": action_config["cond"],
            True: on_branch,
            False: off_branch,
        }

    return parsed
            
# Display the normalized configuration: configured branches are completed with
# the defaults, branches that are absent are reported as None.
parse_config(config)

{'gate': {'cond': <function __main__.gate_config(self)>,
  True: {'loss': <keras.losses.MeanSquaredError at 0x7f4c4d08da00>,
   'clipping': None,
   'var': None,
   'exclud_var': None},
  False: {'loss': <keras.losses.MeanAbsoluteError at 0x7f4c4d08d490>,
   'clipping': None,
   'var': None,
   'exclud_var': None}},
 'delay': {'cond': <function __main__.delay_config(self)>,
  True: {'loss': <keras.losses.MeanSquaredError at 0x7f4c4d08ddf0>,
   'clipping': None,
   'var': None,
   'exclud_var': None},
  False: None},
 'test': {'cond': <function __main__.delay_config(self)>,
  True: {'loss': 'self.compiled_loss',
   'clipping': None,
   'var': None,
   'exclud_var': None},
  False: None}}

In [62]:
# Sanity check that membership tests work with a boolean element (evaluates to False).
False not in {False}

False

In [81]:
from inspect import signature

def foo(x: int = 1, y: int = 2) -> bool:
    """Return whether *x* is strictly greater than *y* (sample function for signature inspection)."""
    is_greater = x > y
    return is_greater
# Read the declared return annotation of foo from its signature.
sig = signature(foo)
sig.return_annotation

bool

In [None]:
import time
class B:
    """Scratch class checking that exec-generated code can be bound to an instance as a method."""
    def __init__(self) -> None:
        self.x = 2
    
    def boo(self):
        """Generate a ``train`` function with exec, bind it to this instance and call it.

        Returns:
            bool: always True.
        """
        
        bb = tf.keras.losses.MeanAbsoluteError()
        # Snapshot of the local names defined so far (self, bb); it is reused below
        # both as a source of globals and as the place where exec stores `train`.
        adict = locals()
        # print(adict)
        # NOTE: the f-string has no placeholders; `bb` and `self.x` are resolved when
        # the generated function runs, not when this string is built.
        function_body = f"""
def train(self):
    print(bb([1,2,3], [3,4,5]))
    return self.x
"""
        # Globals of the generated function = module globals + the snapshot, which is
        # what makes `bb` visible inside `train`; the definition itself lands in adict.
        exec(function_body,{**globals(), **adict}, adict)
        # print(adict)
        # time.sleep(0.1)
        bind(self, adict["train"])
        print(self.train())
        # bind(self, test3)
        # bind(self, test)
        return True

# Run the experiment: boo() prints the loss computed by the generated method and
# then the value that method returns.
b = B()
b.boo()
# self.train_substep()

tf.Tensor(2, shape=(), dtype=int32)
2


True

In [24]:
# Configuration used to try out the source generation below; only the presence
# of the True/False branches matters for the generated text.
# NOTE(review): "delay" reuses gate_config as its condition here - confirm this is intended.
config = {
    "gate": {
        "cond": gate_config,
        True: {"loss": tf.keras.losses.MeanSquaredError()},
        False: {"loss": tf.keras.losses.MeanAbsoluteError()},
    },
    "delay": {
        "cond": gate_config,
        True: {"loss": tf.keras.losses.MeanSquaredError()},
    },
}

def _get_losses(config):
    def tmp(name, conf, on):
        if on:
            return f'self.{name}_on(data)' if True in conf else '0.0'
        else:
            return f'self.{name}_off(data)' if False in conf else '0.0'

    return "{"+",".join([f"loss_{an} : tf.cond(self.{an}, lambda: {tmp(an, ac, True)}, lambda: {tmp(an, ac, False)}" for an, ac in config.items()])+"}"

# NOTE(review): `name` is not used anywhere in this cell.
name = "gate"
# Template of the master train_step. It is filled in with str.format, so doubled
# braces are literal braces in the generated code and only {losses_config} is
# replaced (by the text returned from _get_losses).
function_body = """
@tf.function
def train_step(self, data):
    loss = {losses_config}
    metrics = {{m.name : m.result() for m in self.metrics}}
    
    control_states = {{
            c_name: tf.cond(
                getattr(self, c_name),
                lambda: tf.constant(True),
                lambda: tf.constant(False),
            )
            for c_name in self.cv_names
    }}
    
    return {{**loss, **metrics, **control_states}}
""".format(**{"losses_config": _get_losses(config)})
# Display the generated source text.
function_body

'\n@tf.function\ndef train_step(self, data):\n    loss = {loss_gate : tf.cond(self.gate, lambda: self.gate_on(data), lambda: self.gate_off(data),loss_delay : tf.cond(self.delay, lambda: self.delay_on(data), lambda: 0.0}\n    metrics = {m.name : m.result() for m in self.metrics}\n    \n    control_states = {\n            c_name: tf.cond(\n                getattr(self, c_name),\n                lambda: tf.constant(True),\n                lambda: tf.constant(False),\n            )\n            for c_name in self.cv_names\n    }\n    \n    return {**loss, **metrics, **control_states}\n'

In [25]:
# Display only the generated source of the per-action loss dictionary.
_get_losses(config)

'{loss_gate : tf.cond(self.gate, lambda: self.gate_on(data), lambda: self.gate_off(data),loss_delay : tf.cond(self.delay, lambda: self.delay_on(data), lambda: 0.0}'

In [None]:

@tf.function
def train_step(self, data):
    """Hand-written copy of the generated master step for a single "gate" action.

    Runs ``self.gate_on(data)`` or ``self.gate_off(data)`` depending on the boolean
    control variable ``self.gate`` and returns the loss together with the current
    metric results and the control variable states.
    """
    # The key has to be a string; a bare `loss_gate` is an undefined name.
    loss = {'loss_gate': tf.cond(self.gate, lambda: self.gate_on(data), lambda: self.gate_off(data))}
    metrics = {m.name: m.result() for m in self.metrics}

    # Report the state of every control variable alongside the losses and metrics.
    control_states = {
        c_name: tf.cond(
            getattr(self, c_name),
            lambda: tf.constant(True),
            lambda: tf.constant(False),
        )
        for c_name in self.cv_names
    }

    return {**loss, **metrics, **control_states}
