## TensorFlow2 training loop control using default *tf.fit(...)* function

### Task Description

Until now, a custom training loop in TensorFlow 2 has required writing two loops:
1. loop iterating through epochs 
2. loop iterating through batches 

The whole custom training procedure then has to be implemented inside this double-loop block of code. This is neither elegant nor robust, because the advanced features of *tf.fit(...)* are lost.

In [114]:
import tensorflow as tf
from types import MethodType
import functools

In [115]:
# class LoopControlableModel(tf.keras.Model):

#     def __init__(self, *args, **kwargs):
#         super(LoopControlableModel, self).__init__(*args, **kwargs)
#         self.gate = tf.Variable(False, trainable=False) # gate control variable
    
#     @tf.function
#     def train_step(self, data):
#         train_metrics = tf.cond(
#             self.gate, 
#             lambda: self.train_step_active(data),
#             lambda: self.train_step_passive(data)
#         )

#         return train_metrics
    

In [116]:
def assign_to_instance(instance):
    """Build a decorator that installs the decorated function on ``instance``.

    The decorated name does not run the function. Calling it stores, under
    the function's own ``__name__``, a ``functools.partial`` of the function
    on ``instance``; the partial has ``instance`` pre-bound as the first
    positional argument, followed by whatever arguments were passed to the
    call.
    """
    def decorator_assign_to_instance(func):
        @functools.wraps(func)
        def wrapper_assign_to_instance(*args, **kwargs):
            """Attach ``func`` to ``instance`` as a partial; returns ``None``.

            Idea kept from the original note: the installed callable is a
            ``functools.partial``, so a ``functools.update_wrapper`` call
            could be considered to restore the function's metadata on it.
            """
            bound = functools.partial(func, instance, *args, **kwargs)
            setattr(instance, func.__name__, bound)
        return wrapper_assign_to_instance
    return decorator_assign_to_instance

In [125]:
class LoopControlerCallback(tf.keras.callbacks.Callback):
    """Switch a model between two training steps while using ``fit``.

    When training begins, the callback attaches a boolean ``gate`` variable
    to the model and installs three callables on it: ``train_step`` (which
    dispatches on the gate with ``tf.cond``), ``train_step_active`` and
    ``train_step_passive``. At the end of every epoch the gate is updated
    from the epoch index and ``gating_frequency``.

    Args:
        gating_frequency: Non-zero integer; the gate is set to
            ``epoch % gating_frequency == 0`` at the end of each epoch.
        *args, **kwargs: Forwarded to ``tf.keras.callbacks.Callback``.

    Raises:
        ValueError: If ``gating_frequency`` is zero (the modulo in
            ``on_epoch_end`` would otherwise fail after the first epoch).
    """

    def __init__(self, gating_frequency: int, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        if gating_frequency == 0:
            # Fail at construction rather than with a ZeroDivisionError
            # once the first epoch has already been trained.
            raise ValueError("gating_frequency must be non-zero")
        self.gating_frequency = gating_frequency

    def on_train_begin(self, logs=None):
        """Attach the gate variable and the gated train steps to the model."""
        # Gate control variable; starts closed, so the first epoch runs the
        # passive step.
        self.model.gate = tf.Variable(False, trainable=False)

        # NOTE: inside the nested functions below, ``self`` is the *model*
        # (bound by ``assign_to_instance``), not this callback.

        @assign_to_instance(self.model)
        @tf.function
        def train_step(self, data):
            # Pick the branch from the current value of the gate variable.
            return tf.cond(
                self.gate,
                lambda: self.train_step_active(data),
                lambda: self.train_step_passive(data),
            )

        @assign_to_instance(self.model)
        @tf.function
        def train_step_active(self, data):
            x, y = data
            with tf.GradientTape(watch_accessed_variables=True) as tape:
                logits = self(x, training=True)
                loss_value = self.compiled_loss(y, logits)
            grads = tape.gradient(loss_value, tape.watched_variables())
            self.optimizer.apply_gradients(zip(grads, tape.watched_variables()))
            # Without this call the metrics passed to ``compile`` are never
            # updated and only report their initial value.
            self.compiled_metrics.update_state(y, logits)
            return {
                **{m.name: m.result() for m in self.metrics},
                "active": True,
                "passive": False,
            }

        @assign_to_instance(self.model)
        @tf.function
        def train_step_passive(self, data):
            x, y = data
            with tf.GradientTape(watch_accessed_variables=True) as tape:
                logits = self(x, training=True)
                loss_value = self.compiled_loss(y, logits)
            grads = tape.gradient(loss_value, tape.watched_variables())
            self.optimizer.apply_gradients(zip(grads, tape.watched_variables()))
            # Keep the compiled metrics in sync in this branch as well.
            self.compiled_metrics.update_state(y, logits)
            return {
                **{m.name: m.result() for m in self.metrics},
                "active": False,
                "passive": True,
            }

        # The decorated names are installers: calling each one stores the
        # corresponding function on the model, bound to the model.
        train_step()
        train_step_active()
        train_step_passive()

    def on_epoch_end(self, epoch, logs=None):
        """Control the gating variable from the callback, at epoch level.

        The value assigned here is the one the gated ``train_step`` reads
        afterwards, i.e. the result for ``epoch`` applies to the epoch that
        follows it.
        """
        # ``Variable.assign`` is different from rebinding ``model.gate = ...``:
        # a rebound value would be compiled as a static value into the TF
        # graph produced by the @tf.function decorators, whereas ``assign``
        # changes the variable the graph reads.
        self.model.gate.assign(epoch % self.gating_frequency == 0)

In [130]:
# Problem dimensions for the synthetic regression demo.
DATASET_SIZE = 1000
INPUT_SIZE = 2
OUTPUT_SIZE = 1
BATCH_SIZE = 64

# Uniformly random (features, targets) pairs, served in fixed-size batches.
data = tf.data.Dataset.from_tensor_slices(
    (
        tf.random.uniform((DATASET_SIZE, INPUT_SIZE)),
        tf.random.uniform((DATASET_SIZE, OUTPUT_SIZE)),
    )
).batch(BATCH_SIZE)

# Single dense layer mapping the inputs to the outputs.
model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(OUTPUT_SIZE))

# Compile with RMSprop, MSE loss and MAE as the reported metric.
model.compile(
    optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.01),
    loss="mse",
    metrics=["mae"],
)

# Train through the regular fit() loop; the callback installs the gated
# train step on the model.
history = model.fit(
    data,
    epochs=4,
    verbose=1,
    callbacks=[LoopControlerCallback(2)],
)

Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4
