## TensorFlow2 training loop control using default *tf.fit(...)* function

### Task Descripton

Up to now custom training loop in Tensorflow2 requires writing two lops:
1. loop iterating through epochs 
2. loop iterating through batches 

Then all castom training precudere will have to be implemented in these double-loop block of code. It's neither elegant nor robust due to the missing advanced features of *tf.fit(...)*.

In [3]:
import tensorflow as tf
from typing import Dict, Any, List, Tuple
import functools

In [4]:
def bind(instance, func, as_name=None):
    """
    Bind the function *func* to *instance*, with either provided name *as_name*
    or the existing name of *func*. The provided *func* should accept the 
    instance as the first argument, i.e. "self".
    """
    if as_name is None:
        as_name = func.__name__
    bound_method = func.__get__(instance, instance.__class__)
    setattr(instance, as_name, bound_method)
    return bound_method


In [185]:
@tf.function
def gate_on(self, data):
    x, y = data
        
    print("----------!!!!!!----------")
    print(self.gate_on_vars)
    if not self.gate_on_vars:
        self.gate_on_vars=self.get_gate_on_vars()
    print(self.gate_on_vars())
    
    with tf.GradientTape(watch_accessed_variables=False) as tape:
        for g in self.gate_on_vars:
            tape.watch(g)
            
        logits = self(x, training=True)
        loss_value = loss(y, logits)
    grads = tape.gradient(loss_value, tape.watched_variables())
    self.optimizer.apply_gradients(zip([tf.clip_by_value(g, -0.2, 0.3) for g in grads], tape.watched_variables()))
    self.compiled_metrics.update_state(y, logits)
    return loss_value

class LoopControlerCallback(tf.keras.callbacks.Callback):
    def __init__(
        self, config: int, default_in_branch: Dict[str, Any] = None, verbose: bool = True, *args, **kwargs
    ) -> None:
        super(LoopControlerCallback, self).__init__(*args, **kwargs)
        self.default_in_branch: Dict[str, Any] = default_in_branch if default_in_branch else {
                "loss": None, 
                "clipping": None,
                "variables": None,
                "exclud_var": None,
            }
        self.config: Dict[str, Any] = config
        self.verbose: bool = verbose

    def on_train_begin(self, logs=None):
        """Function called directely before training. It is executed on each call of model.fit with this callback.
            Inside the scope of this funciton we can access model on which this callback works: self.model.

        Args:
            logs (_type_, optional): _description_. Defaults to None.
        """

        # meta attributes for building conditions
        self.epochs: int = 0
        self.batches: int = 0
        self.batches_in_epoch: int = 0
        self.last_loss: int = 0.0
        self.history: List[Any] = []
        
        # if loss in default branch is None, then use compiled loss
        if not self.default_in_branch["loss"]:
            self.default_in_branch["loss"] = self.model.compiled_loss

        self.model.cv_names = []
        self.c_conds = {}

        # extend config with validation step
        self.config = self._extend_config(self.config)
        # initiate control variables
        self._init_cv(self.config)
        # bind control variables to model
        self._bind_cv(self.config)
        # bind master train_step to model
        self._bind_master_step(self.config)
        # bind placeholders for list of variables for slave steps
        self._bind_slaves_steps_dif_variables(self.config)
        # bind slave train_steps to model
        self._bin_slaves_steps()

        # test buingin vars
        setattr(self.model, "gate_on_vars", None)
        bind(self.model, dummy_vars, "get_gate_on_vars")
        bind(self.model, gate_on)

    def _extend_config(self, config: Dict[str, Any]) -> None:
        """Extend and validate config file. Fill missing fields based on the default_in_branch.

        Args:
            config (Dict[str, Any]): Configuration to control model training
        """

        def validate_action_config(action_name: str, action_config: Dict[str, Any]) -> None:
            """Validate model training configuration.

            Args:
                action_name (str): name of the action slave train step
                action_config (Dict[str, Any]): configuration of the action slave train step

            Raises:
                ValueError: Missing controlable cond
                ValueError: Missing branch configuration for true/false after cond
            """

            Warning(f"------Validating Configuration for {action_name}------")
            if action_config == {}:
                Warning(
                    f"{action_name} has empty body. Condition and False or True branch must be implemented.\n It's ignored in furhter computations"
                )
            if (True not in action_config) and (False not in action_config):
                raise ValueError(
                    f"{action_name} has no False or True branch implemented"
                )
            if "cond" not in action_config:
                raise ValueError(f"{action_name} has no condition implemented.")

        pc = {}
        for action_name, action_config in config.items():
            validate_action_config(action_name, action_config)
            pc[action_name] = {"cond": action_config["cond"]}
            if True in action_config:
                pc[action_name][True] = {**self.default_in_branch, **action_config[True]}
            if False in action_config:
                pc[action_name][False] = {**self.default_in_branch, **action_config[False]}

        return pc

    def _get_cv_name(self, action_name: str) -> str:
        """Get control variable from action name

        Args:
            action_name (str): name of the action slave train step

        Returns:
            str: converted name of the action slave train step
        """
        return f"{action_name}_cv"

    def _get_cc_names(self, condition_name: str) -> Tuple[str, str]:
        return (f"{condition_name}_off", f"{condition_name}_on")

    def _init_cv(self, config) -> None:
        for cv_name in config.keys():
            setattr(self.model, cv_name, tf.Variable(False, trainable=False))
            self.model.cv_names.append(cv_name)

    def _bind_cv(self, config) -> None:
        for action_name, action_config in config.items():
            name = self._get_cv_name(action_name)
            bind(self, action_config["cond"], name)
            self.c_conds[action_name] = name

    def _bind_master_step(self, config) -> None:
        lscope = locals()
        def _get_losses(config: Dict[str, Any]) -> str:
            def _substeps_condition_execution(name: str, conf: Dict[str, Any], on: bool) -> str:
                if on:
                    return f"self.{name}_on(data)" if True in conf else "0.0"
                else:
                    return f"self.{name}_off(data)" if False in conf else "0.0"

            return "{" + ",".join([
                        f"'loss_{an}' : tf.cond(self.{an}, lambda: {_substeps_condition_execution(an, ac, True)}, lambda: {_substeps_condition_execution(an, ac, False)})"
                        for an, ac in config.items()
                    ]) + "}"

        lscope = locals()
        function_body = """
@tf.function
def train_step(self, data):
    loss = {losses_config}
    metrics = {{m.name : m.result() for m in self.metrics}}
    
    control_states = {{
            c_name: tf.cond(
                getattr(self, c_name),
                lambda: tf.constant(True),
                lambda: tf.constant(False),
            )
            for c_name in self.cv_names
    }}
    
    return {{**loss, **metrics, **control_states}}
""".format(
            **{"losses_config": _get_losses(config)}
        )

        if self.verbose:
            print("-------------------MASTER STEP-------------------")
            print(function_body)
        
        exec(function_body, {**globals(), **lscope}, lscope)
        bind(self.model, lscope["train_step"])

    def _get_name_for_sleve_step_variables(self, name, on = None):
        if on == None:
            return f"{name}_vars"
        return f"{name}_{'on' if on else 'off'}_vars"

    def _bind_slaves_steps_dif_variables(self, config):
        for action_name, action_config in config.items():
            if action_config[True] and "variables" in action_config[True] and action_config[True]["variables"]:
                # we have to bind None because variables before trainign are not initialised
                setattr(self.model, self._get_name_for_sleve_step_variables(action_name, True), action_config[True]["variables"])
            
            if False in action_config and "variables" in action_config[False] and action_config[False]["variables"]:
                # we have to bind None because variables before trainign are not initialised
                setattr(self.model, self._get_name_for_sleve_step_variables(action_name, False), action_config[False]["variables"])

    def _bin_slaves_steps(self) -> None:
        if self.verbose:
            print("-------------------SLAVE STEPS-------------------\n")
        
        for action_name, action_config in self.config.items():
            off_step_name, on_step_name = self._get_cc_names(action_name)
            if True in action_config:
                self._bind_slave_step(on_step_name, action_config[True], True)
            if False in action_config:
                self._bind_slave_step(off_step_name, action_config[False], False)

    def _bind_slave_step(self, fn_name: str, fn_config: Dict[str, Any], true_branch: bool) -> None:
        
        lscope = {
            **locals(), 
            **fn_config
            }
        
        print(lscope)

        function_body = f"""
@tf.function
def {fn_name}(self, data):
    x, y = data
        """

        if fn_config["variables"]:
            vars_name = self._get_name_for_sleve_step_variables(fn_name)
            get_vars_fname=f"get_{fn_name}_vars"
            setattr(self.model, vars_name, None)
            bind(self.model, lscope["variables"], get_vars_fname)
            lscope[get_vars_fname] = lscope["variables"]
            function_body += f"""
    print("----------!!!----------")
    print(self.{vars_name})
    if not self.{vars_name}:
        self.{vars_name}=self.{get_vars_fname}()
    print(self.{vars_name}())
    
    with tf.GradientTape(watch_accessed_variables=False) as tape:
        for g in self.{vars_name}:
            tape.watch(g)
            """
        else:
            function_body += """
    with tf.GradientTape(watch_accessed_variables=True) as tape:
            """

        function_body += """
        logits = self(x, training=True)
        loss_value = loss(y, logits)
    grads = tape.gradient(loss_value, tape.watched_variables())
    self.optimizer.apply_gradients(zip({clipping_grads}, tape.watched_variables()))
    self.compiled_metrics.update_state(y, logits)
    return loss_value
""".format(**{
    "clipping_grads": "[tf.clip_by_value(g, {clip_low}, {clip_high}) for g in grads]".format(**{
        "clip_low": fn_config["clipping"][0], "clip_high": fn_config["clipping"][1]
        }) if fn_config["clipping"] else "grads",
})

        if self.verbose:
            print(f"-------------------{fn_name}-------------------")
            print(fn_config["variables"])
            print(function_body)
        
        if not fn_config["variables"]:
            exec(function_body, {**globals(), **lscope}, lscope)
            bind(self.model, lscope[fn_name])

    def on_epoch_begin(self, epoch: int, logs) -> None:
        self.epochs += 1
        """Control gating variable from the level of callback which can work on epoch/batch level."""
        # tf.variable.assign is different than tf.variable = <sth>. The second option is compiled to static
        # value in TF graph of computation as the result of @tf.function decorators in LoopControlableModel
        for control_name, control_function_name in self.c_conds.items():
            getattr(self.model, control_name).assign(
                getattr(self, control_function_name)()
            )
    def on_epoch_end(self, epoch, logs):
        print(self.model.variables)

In [186]:
DATASET_SIZE, INPUT_SIZE, OUTPUT_SIZE = 1000, 2, 1
BATCH_SIZE = 64
data = tf.data.Dataset.from_tensor_slices(
    (tf.random.uniform((DATASET_SIZE, INPUT_SIZE)), tf.random.uniform((DATASET_SIZE, OUTPUT_SIZE)))
    ).batch(BATCH_SIZE)

class TM(tf.keras.Model):
    def __init__(self) -> None:
        super().__init__()
        

model = tf.keras.Sequential([tf.keras.layers.Dense(OUTPUT_SIZE)])
# compile model
model.compile(optimizer=
    tf.keras.optimizers.RMSprop(learning_rate=0.01),
    loss="mse", 
    metrics=["mae"])

def get_gate_on_vars(self):
    return self.variables[:2]

bind(model, get_gate_on_vars)

def gate_config(self):
    return self.epochs % 2 == 1

def delay_config(self):
    return self.epochs > 5

def dummy_vars(self):
    return [self.variables[0]]

config = {
    "gate": {
        "cond": gate_config,
        True: {
            "loss": tf.keras.losses.MeanAbsoluteError(),
            "clipping": (-0.2, 0.3),
            # "variables": dummy_vars
        },
        False: {}
        # False: {
        #     # "loss": tf.keras.losses.MeanAbsoluteError(), 
        # }
    },
    "delay": {
        "cond": delay_config,
        True: {
            "loss": tf.keras.losses.MeanSquaredError(), 
        },
    },

}

lcc = LoopControlerCallback(config)
# start training
history = model.fit(data, epochs = 10, verbose = 1,
    callbacks=[lcc])

-------------------MASTER STEP-------------------

@tf.function
def train_step(self, data):
    loss = {'loss_gate' : tf.cond(self.gate, lambda: self.gate_on(data), lambda: self.gate_off(data)),'loss_delay' : tf.cond(self.delay, lambda: self.delay_on(data), lambda: 0.0)}
    metrics = {m.name : m.result() for m in self.metrics}
    
    control_states = {
            c_name: tf.cond(
                getattr(self, c_name),
                lambda: tf.constant(True),
                lambda: tf.constant(False),
            )
            for c_name in self.cv_names
    }
    
    return {**loss, **metrics, **control_states}

-------------------SLAVE STEPS-------------------

{'self': <__main__.LoopControlerCallback object at 0x7ff457953d90>, 'fn_name': 'gate_on', 'fn_config': {'loss': <keras.losses.MeanAbsoluteError object at 0x7ff478638e20>, 'clipping': (-0.2, 0.3), 'variables': None, 'exclud_var': None}, 'true_branch': True, 'loss': <keras.losses.MeanAbsoluteError object at 0x7ff478638e20

ValueError: in user code:

    File "/home/filip/workspace/tf/venv/lib/python3.8/site-packages/keras/engine/training.py", line 1051, in train_function  *
        return step_function(self, iterator)
    File "/tmp/ipykernel_7363/1499073282.py", line 8, in gate_on  *
        self.gate_on_vars=self.get_gate_on_vars()
    File "/tmp/ipykernel_7363/3443802096.py", line 31, in dummy_vars  **
        return [self.variables[0]]
    File "/home/filip/workspace/tf/venv/lib/python3.8/site-packages/keras/engine/base_layer.py", line 2066, in variables
        return self.weights
    File "/home/filip/workspace/tf/venv/lib/python3.8/site-packages/keras/engine/training.py", line 2829, in weights
        return self._dedup_weights(self._undeduplicated_weights)
    File "/home/filip/workspace/tf/venv/lib/python3.8/site-packages/keras/engine/training.py", line 2834, in _undeduplicated_weights
        self._assert_weights_created()
    File "/home/filip/workspace/tf/venv/lib/python3.8/site-packages/keras/engine/sequential.py", line 472, in _assert_weights_created
        super(functional.Functional, self)._assert_weights_created()  # pylint: disable=bad-super-call
    File "/home/filip/workspace/tf/venv/lib/python3.8/site-packages/keras/engine/training.py", line 3027, in _assert_weights_created
        raise ValueError(f'Weights for model {self.name} have not yet been '

    ValueError: Weights for model sequential_79 have not yet been created. Weights are created when the Model is first called on inputs or `build()` is called with an `input_shape`.


In [187]:
model.variables

ValueError: Weights for model sequential_79 have not yet been created. Weights are created when the Model is first called on inputs or `build()` is called with an `input_shape`.

In [142]:
model.get_gate_on_vars()

ValueError: Weights for model sequential_59 have not yet been created. Weights are created when the Model is first called on inputs or `build()` is called with an `input_shape`.

In [113]:
model.get_gate_on_vars()

ValueError: Weights for model sequential_45 have not yet been created. Weights are created when the Model is first called on inputs or `build()` is called with an `input_shape`.

In [111]:
model.__dir__()

['_self_setattr_tracking',
 '_is_model_for_instrumentation',
 '_instrumented_keras_api',
 '_instrumented_keras_layer_class',
 '_instrumented_keras_model_class',
 '_trainable',
 '_stateful',
 'built',
 '_input_spec',
 '_build_input_shape',
 '_saved_model_inputs_spec',
 '_saved_model_arg_spec',
 '_supports_masking',
 '_name',
 '_activity_regularizer',
 '_trainable_weights',
 '_non_trainable_weights',
 '_updates',
 '_thread_local',
 '_callable_losses',
 '_losses',
 '_metrics',
 '_metrics_lock',
 '_dtype_policy',
 '_compute_dtype_object',
 '_autocast',
 '_self_tracked_trackables',
 '_inbound_nodes_value',
 '_outbound_nodes_value',
 '_expects_training_arg',
 '_default_training_arg',
 '_expects_mask_arg',
 '_dynamic',
 '_initial_weights',
 '_auto_track_sub_layers',
 '_preserve_input_structure_in_config',
 '_name_scope_on_declaration',
 '_captured_weight_regularizer',
 '_is_graph_network',
 'inputs',
 'outputs',
 'input_names',
 'output_names',
 '_compute_output_and_mask_jointly',
 '_distribu

In [None]:
        
        lscope = {
            **locals(), 
            **fn_config
            }
        
        vars_name = self._get_name_for_sleve_step_variables(fn_name)
        function_body = f"""
@tf.function
def {fn_name}(self, data):
    x, y = data        
        """

        if fn_config["variables"]:
            lscope["action_vars"] = getattr(self.model, vars_name)
            lscope["get_slave_vars"] = lscope["variables"]
            function_body += f"""
    if not self.{vars_name}:
        self.{vars_name}=get_slave_vars()
    with tf.GradientTape(watch_accessed_variables=False) as tape:
        for g in self.{vars_name}:
            tape.watch(g)
            """
        else:
            function_body += """
    with tf.GradientTape(watch_accessed_variables=True) as tape:
            """

        function_body += """
        logits = self(x, training=True)
        loss_value = loss(y, logits)
    grads = tape.gradient(loss_value, tape.watched_variables())
    self.optimizer.apply_gradients(zip({clipping_grads}, tape.watched_variables()))
    self.compiled_metrics.update_state(y, logits)
    return loss_value
""".format(**{
    "clipping_grads": "[tf.clip_by_value(g, {clip_low}, {clip_high}) for g in grads]".format(**{
        "clip_low": fn_config["clipping"][0], "clip_high": fn_config["clipping"][1]
        }) if fn_config["clipping"] else "grads",
})


In [None]:
        function_body = """
@tf.function
def {fn_name}(self, data):
    {variables}
    x, y = data
    with tf.GradientTape(watch_accessed_variables=True) as tape:
        logits = self(x, training=True)
        loss_value = loss(y, logits)
    grads = tape.gradient(loss_value, tape.watched_variables())
    self.optimizer.apply_gradients(zip({clipping_grads}, tape.watched_variables()))
    self.compiled_metrics.update_state(y, logits)
    return loss_value
""".format(**{
    "fn_name": fn_name,
    "clipping_grads": "[tf.clip_by_value(g, {clip_low}, {clip_high}) for g in grads]".format(**{"clip_low": fn_config["clipping"][0], "clip_high": fn_config["clipping"][1]}) if fn_config["clipping"] else "grads",
    "variables": "if not self.{var_names}:\n        self.{var_names}=variables()\n".format(**{
        "variables" : lscope["variables"], "var_names" : vars_name
        }) if lscope["variables"] else "",
    "test" : self._get_name_for_sleve_step_variables(fn_name)
})

In [32]:
class LoopControlerCallback(tf.keras.callbacks.Callback):
    def __init__(
        self, config: int, default_in_branch: Dict[str, Any] = None, verbose: bool = True, *args, **kwargs
    ) -> None:
        super(LoopControlerCallback, self).__init__(*args, **kwargs)
        self.default_in_branch: Dict[str, Any] = default_in_branch if default_in_branch else {
                "loss": None, 
                "clipping": None,
                "variables": None,
                "exclud_var": None,
            }
        self.config: Dict[str, Any] = config
        self.verbose: bool = verbose


    def on_train_begin(self, logs=None):
        """Function called directely before training. It is executed on each call of model.fit with this callback.
            Inside the scope of this funciton we can access model on which this callback works: self.model.

        Args:
            logs (_type_, optional): _description_. Defaults to None.
        """

        # meta attributes for building conditions
        self.epochs: int = 0
        self.batches: int = 0
        self.batches_in_epoch: int = 0
        self.last_loss: int = 0.0
        self.history: List[Any] = []

        # control variables list of names 
        self.model.cv_names = []
        # 
        self.c_conds = {}
        
        # if loss in default branch is None, then use compiled loss
        if not self.default_in_branch["loss"]:
            self.default_in_branch["loss"] = self.model.compiled_loss
        
        # extend config with validation step
        self.config = self._extend_actions_config(self.config)
        # bind control variables to the model
        self._bind_slave_variables(self.config)
        # bind slave control confition to the callback
        self._bind_slave_conditions(self.config)
        # bind master train_step to model
        self._bind_master_step(self.config)
        # bind placeholders for list of variables for slave steps
        self._bind_slaves_steps_dif_variables(self.config)
        # bind slave train_steps to model
        self._bin_slaves_steps()

    def _get_slave_variable_name(self, action_name: str) -> str:
        """Get control variable from action name

        Args:
            action_name (str): name of the action slave train step

        Returns:
            str: converted name of the action slave train step
        """
        return f"{action_name}_slave_var"

    def _get_slave_condition_names(self, condition_name: str) -> Tuple[str, str]:
        return (f"{condition_name}_off", f"{condition_name}_on")

    def _bind_slave_variables(self, config) -> None:
        for cv_name in config.keys():
            name = self._get_slave_variable_name(cv_name)
            setattr(self.model, name, tf.Variable(False, trainable=False))
            self.model.cv_names.append(name)

    def _get_slave_condition_name(self, action_name: str) -> str:
        """Get control variable from action name

        Args:
            action_name (str): name of the action slave train step

        Returns:
            str: converted name of the action slave train step
        """
        return f"{action_name}_slave_cond" 

    def _bind_slave_conditions(self, config) -> None:
        for action_name, action_config in config.items():
            name = self._get_slave_condition_name(action_name)
            bind(self, action_config["cond"], name)
            self.c_conds[self._get_slave_variable_name(action_name)] = name

    def _bind_master_step(self, config) -> None:
        lscope = locals()
        def _get_losses(config: Dict[str, Any]) -> str:
            def _substep_exec(name: str, conf: Dict[str, Any], on: bool) -> str:
                if on:
                    return f"self.{name}_on(data)" if True in conf else "0.0"
                else:
                    return f"self.{name}_off(data)" if False in conf else "0.0"

            return "{" + ",".join([
                        f"'loss_{an}' : tf.cond(self.{self._get_slave_variable_name(an)}, lambda: {_substep_exec(an, ac, True)}, lambda: {_substep_exec(an, ac, False)})"
                        for an, ac in config.items()
                    ]) + "}"

        lscope = locals()
        function_body = """
@tf.function
def train_step(self, data):
    loss = {losses_config}
    metrics = {{m.name : m.result() for m in self.metrics}}
    
    control_states = {{
            c_name: tf.cond(
                getattr(self, c_name),
                lambda: tf.constant(True),
                lambda: tf.constant(False),
            )
            for c_name in self.cv_names
    }}
    
    return {{**loss, **metrics, **control_states}}
""".format(
            **{"losses_config": _get_losses(config)}
        )
        if self.verbose:
            print("-------------------MASTER STEP-------------------")
            print(function_body)
        exec(function_body, {**globals(), **lscope}, lscope)
        bind(self.model, lscope["train_step"])

    def _get_name_for_sleve_step_variables(self, name, on = None):
        if on == None:
            return f"{name}_vars"
        return f"{name}_{'on' if on else 'off'}_vars"

    def _bind_slaves_steps_dif_variables(self, config):
        for action_name, action_config in config.items():
            if action_config[True] and "variables" in action_config[True] and action_config[True]["variables"]:
                # we have to bind None because variables before trainign are not initialised
                setattr(self.model, self._get_name_for_sleve_step_variables(action_name, True), action_config[True]["variables"])


            if False in action_config and "variables" in action_config[False] and action_config[False]["variables"]:
                # we have to bind None because variables before trainign are not initialised
                setattr(self.model, self._get_name_for_sleve_step_variables(action_name, False), action_config[False]["variables"])

    def _bin_slaves_steps(self) -> None:
        if self.verbose:
            print("-------------------SLAVE STEPS-------------------\n")
        for action_name, action_config in self.config.items():
            off_step_name, on_step_name = self._get_slave_condition_names(action_name)
            if True in action_config:
                self._bind_slave_step(on_step_name, action_config[True])
            if False in action_config:
                self._bind_slave_step(off_step_name, action_config[False])

    def _bind_slave_step(self, fn_name: str, fn_config: Dict[str, Any]) -> None:
        lscope = {
            **locals(), 
            **fn_config,
            **{"vars_control_name": self._get_name_for_sleve_step_variables(fn_name)}
            }
        # print(getattr(self.model, vars_control_name)())
        function_body = """
@tf.function
def {fn_name}(self, data):
    x, y = data
    with tf.GradientTape(watch_accessed_variables=True) as tape:
        logits = self(x, training=True)
        loss_value = loss(y, logits)
    grads = tape.gradient(loss_value, tape.watched_variables())
    self.optimizer.apply_gradients(zip({clipping_grads}, tape.watched_variables()))
    self.compiled_metrics.update_state(y, logits)
""".format(**{
    "fn_name": fn_name,
    "clipping_grads": "[tf.clip_by_value(g, {clip_low}, {clip_high}) for g in grads]".format(**{"clip_low": fn_config["clipping"][0], "clip_high": fn_config["clipping"][1]}) if fn_config["clipping"] else "grads",
    "variables": "if not getattr(self.model, vars_control_name):\n      getattr(self.model, vars_control_name)=variables()\n" if fn_config["variables"] else "",
    "test" : self._get_name_for_sleve_step_variables(fn_name)

})
        if self.verbose:
            print(f"-------------------{fn_name}-------------------")
            print(fn_config["variables"])
            print(function_body)
            
        exec(function_body, {**globals(), **lscope}, lscope)
        bind(self.model, lscope[fn_name])

    def _extend_actions_config(self, config: Dict[str, Any]) -> None:       
        """Extend and validate config file. Fill missing fields based on the default_in_branch.

        Args:
            config (Dict[str, Any]): Configuration to control model training
        """

        def _validate_action_config(action_name: str, action_config: Dict[str, Any]) -> None:
            """Validate model training configuration.

            Args:
                action_name (str): name of the action slave train step
                action_config (Dict[str, Any]): configuration of the action slave train step

            Raises:
                ValueError: Missing controlable cond
                ValueError: Missing branch configuration for true/false after cond
            """

            Warning(f"------Validating Configuration for {action_name}------")
            if action_config == {}:
                Warning(
                    f"{action_name} has empty body. Condition and False or True branch must be implemented.\n It's ignored in furhter computations"
                )
            if (True not in action_config) and (False not in action_config):
                raise ValueError(
                    f"{action_name} has no False or True branch implemented"
                )
            if "cond" not in action_config:
                raise ValueError(f"{action_name} has no condition implemented.")

        extended_config = {}
        for action_name, action_config in config.items():
            _validate_action_config(action_name, action_config)
            extended_config[action_name] = {"cond": action_config["cond"]}
            if True in action_config:
                extended_config[action_name][True] = {**self.default_in_branch, **action_config[True]}
            if False in action_config:
                extended_config[action_name][False] = {**self.default_in_branch, **action_config[False]}

        return extended_config

    def on_epoch_begin(self, epoch: int, logs) -> None:
        self.epochs += 1
        """Control gating variable from the level of callback which can work on epoch/batch level."""
        # tf.variable.assign is different than tf.variable = <sth>. The second option is compiled to static
        # value in TF graph of computation as the result of @tf.function decorators in LoopControlableModel
        # print(self.c_conds)
        print(self.c_conds)
        for control_variable_name, control_function_name in self.c_conds.items():
            print(getattr(self.model, control_variable_name))
            print(getattr(self, control_function_name)())

            getattr(self.model, control_variable_name).assign(
                getattr(self, control_function_name)()
            )

In [33]:
DATASET_SIZE, INPUT_SIZE, OUTPUT_SIZE = 1000, 2, 1
BATCH_SIZE = 64
data = tf.data.Dataset.from_tensor_slices(
    (tf.random.uniform((DATASET_SIZE, INPUT_SIZE)), tf.random.uniform((DATASET_SIZE, OUTPUT_SIZE)))
    ).batch(BATCH_SIZE)


model = tf.keras.Sequential([tf.keras.layers.Dense(OUTPUT_SIZE)])
# compile model
model.compile(optimizer=
    tf.keras.optimizers.RMSprop(learning_rate=0.01),
    loss="mse", 
    metrics=["mae"])


def gate_config(self):
    return self.epochs % 2 == 1

def delay_config(self):
    return self.epochs > 5

def dummy_vars(self,):
    return [self.variables[0]]

config = {
    "gate": {
        "cond": gate_config,
        True: {
            "loss": tf.keras.losses.MeanAbsoluteError(),
            "clipping": (-0.2, 0.3),
            "variables": dummy_vars
        },
        False: {}
        # False: {
        #     # "loss": tf.keras.losses.MeanAbsoluteError(), 
        # }
    },
    "delay": {
        "cond": delay_config,
        True: {
            "loss": tf.keras.losses.MeanSquaredError(), 
        },
    },

}

lcc = LoopControlerCallback(config)
# start training
history = model.fit(data, epochs = 10, verbose = 1,
    callbacks=[lcc])

-------------------MASTER STEP-------------------

@tf.function
def train_step(self, data):
    loss = {'loss_gate' : tf.cond(self.gate_slave_var, lambda: self.gate_on(data), lambda: self.gate_off(data)),'loss_delay' : tf.cond(self.delay_slave_var, lambda: self.delay_on(data), lambda: 0.0)}
    metrics = {m.name : m.result() for m in self.metrics}
    
    control_states = {
            c_name: tf.cond(
                getattr(self, c_name),
                lambda: tf.constant(True),
                lambda: tf.constant(False),
            )
            for c_name in self.cv_names
    }
    
    return {**loss, **metrics, **control_states}

-------------------SLAVE STEPS-------------------

-------------------gate_on-------------------
<function dummy_vars at 0x7ff4986234c0>

@tf.function
def gate_on(self, data):
    x, y = data
    with tf.GradientTape(watch_accessed_variables=True) as tape:
        logits = self(x, training=True)
        loss_value = loss(y, logits)
    grads = tape.g

TypeError: in user code:

    File "/home/filip/workspace/tf/venv/lib/python3.8/site-packages/keras/engine/training.py", line 1051, in train_function  *
        return step_function(self, iterator)
    File "/home/filip/workspace/tf/venv/lib/python3.8/site-packages/keras/engine/training.py", line 1040, in step_function  **
        outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/home/filip/workspace/tf/venv/lib/python3.8/site-packages/keras/engine/training.py", line 1030, in run_step  **
        outputs = model.train_step(data)
    File "<string>", line 4, in train_step
        

    TypeError: true_fn and false_fn arguments to tf.cond must have the same number, type, and overall structure of return values.
    
    true_fn output: Tensor("cond_1/Identity_1:0", shape=(), dtype=bool)
    false_fn output: Tensor("cond_1/Identity:0", shape=(), dtype=float32)
    
    Error details:
    Tensor("cond_1/Identity_1:0", shape=(), dtype=bool) and Tensor("cond_1/Identity:0", shape=(), dtype=float32) have different types


In [31]:
model.gate

<keras.engine.sequential.Sequential at 0x7ff498651b50>

In [86]:
config = {
    "gate": {
        "cond": gate_config,
        True: {
            "loss": tf.keras.losses.MeanSquaredError(), 
        },
        False: {
            "loss": tf.keras.losses.MeanAbsoluteError(), 
        }
    },
    "delay": {
        "cond": delay_config,
        True: {
            "loss": tf.keras.losses.MeanSquaredError(), 
        },
    },
    "test": {
        "cond": delay_config,
        True: {}
    },
    # "test2": {
    #     "cond": delay_config,
    # },

}

def validate_action_config(action_name, action_config) -> bool:
        Warning(f"------Validating Configuration for {action_name}------")
        if action_config == {}:
            Warning(f"{action_name} has empty body. Condition and False or True branch must be implemented.\n It's ignored in furhter computations")
        if ((True not in action_config) and (True not in action_config)):
            raise ValueError(f"{action_name} has no False or True branch implemented")
        if "cond" not in action_config:
            raise ValueError(f"{action_name} has no condition implemented.")
        return True


            

def parse_config(config, default_in_branch = None):
    if not default_in_branch:
        default_in_branch = {
            "loss" : "self.compiled_loss",
            "clipping" : None, 
            "var": None,
            "exclud_var": None,
        }

    pc = {}
    for action_name, action_config in config.items():
        _ = validate_action_config(action_name, action_config)
        false_branch, true_branch = {}, {}
        true_branch = {**default_in_branch,**action_config[True]} if True in action_config else None
        false_branch = {**default_in_branch,**action_config[False]} if False in action_config else None
        pc[action_name] = {"cond": action_config["cond"], True: true_branch, False: false_branch}

    return pc 
            
parse_config(config)

{'gate': {'cond': <function __main__.gate_config(self)>,
  True: {'loss': <keras.losses.MeanSquaredError at 0x7f4c4d08da00>,
   'clipping': None,
   'var': None,
   'exclud_var': None},
  False: {'loss': <keras.losses.MeanAbsoluteError at 0x7f4c4d08d490>,
   'clipping': None,
   'var': None,
   'exclud_var': None}},
 'delay': {'cond': <function __main__.delay_config(self)>,
  True: {'loss': <keras.losses.MeanSquaredError at 0x7f4c4d08ddf0>,
   'clipping': None,
   'var': None,
   'exclud_var': None},
  False: None},
 'test': {'cond': <function __main__.delay_config(self)>,
  True: {'loss': 'self.compiled_loss',
   'clipping': None,
   'var': None,
   'exclud_var': None},
  False: None}}

In [62]:
False not in {False}

False

In [81]:
from inspect import signature

def foo(x: int = 1, y: int = 2) -> bool:
    return x > y
sig = signature(foo)
sig.return_annotation

bool

In [None]:
@tf.function
def gate_on(self, data):
    if not self.gate_var:
	getattr(self, self._get_name_for_sleve_step_variables(gate_on))=variables

    x, y = data
    with tf.GradientTape(watch_accessed_variables=True) as tape:
        logits = self(x, training=True)
        loss_value = loss(y, logits)
    grads = tape.gradient(loss_value, tape.watched_variables())
    self.optimizer.apply_gradients(zip([tf.clip_by_value(g, -0.2, 0.3) for g in grads], tape.watched_variables()))
    self.compiled_metrics.update_state(y, logits)
    return loss_value

In [None]:
import time
class B:
    def __init__(self) -> None:
        self.x = 2
    
    def boo(self):
        
        bb = tf.keras.losses.MeanAbsoluteError()
        adict = locals()
        # print(adict)
        function_body = f"""
def train(self):
    print(bb([1,2,3], [3,4,5]))
    return self.x
"""
        exec(function_body,{**globals(), **adict}, adict)
        # print(adict)
        # time.sleep(0.1)
        bind(self, adict["train"])
        print(self.train())
        # bind(self, test3)
        # bind(self, test)
        return True

b = B()
b.boo()
# self.train_substep()

tf.Tensor(2, shape=(), dtype=int32)
2


True