## TensorFlow 2 training-loop control using the default *tf.keras.Model.fit(...)* method

### Task Description

Until now, a custom training loop in TensorFlow 2 has required writing two nested loops:
1. a loop iterating through epochs
2. a loop iterating through batches

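The entire custom training procedure then has to be implemented inside this double-loop block of code. This is neither elegant nor robust, because it gives up the advanced features of *Model.fit(...)* (callbacks, metric tracking, progress reporting, etc.). The approach below keeps *fit(...)* and instead switches the per-batch training step with a boolean `tf.Variable` gate that a callback updates between epochs.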

In [2]:
import tensorflow as tf

2023-01-04 09:50:45.340284: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2023-01-04 09:50:45.340323: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.


In [25]:
class LoopControlableModel(tf.keras.Model):
    """Keras model whose per-batch training step is switched by a boolean gate.

    ``train_step`` (invoked by ``fit``) reads the non-trainable ``gate``
    variable and dispatches each batch to ``train_step_active`` when it is
    ``True`` or to ``train_step_passive`` when it is ``False``. The gate is a
    ``tf.Variable`` so that updating it with ``gate.assign(...)`` (e.g. from a
    callback) is seen by the already-traced graph.

    Both branches currently perform the same gradient update and differ only
    in the ``"active"``/``"passive"`` flags they report; override one or both
    to give each phase its own behaviour.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Gate control variable: False -> passive branch, True -> active branch.
        self.gate = tf.Variable(False, trainable=False)

    @tf.function
    def train_step(self, data):
        """Run one training step on ``data``, choosing the branch from the gate.

        Returns:
            The dict produced by the selected branch: the current result of
            every model metric plus boolean ``"active"`` and ``"passive"`` flags.
        """
        # tf.cond (not a Python `if`) so the branch is selected at run time from
        # the variable's current value; both branches must return the same
        # dict structure.
        return tf.cond(
            self.gate,
            lambda: self.train_step_active(data),
            lambda: self.train_step_passive(data),
        )

    @tf.function
    def train_step_active(self, data):
        """Training step used while the gate is open (``gate == True``)."""
        self._gated_update_step(data)
        return self._gated_logs(active=True)

    @tf.function
    def train_step_passive(self, data):
        """Training step used while the gate is closed (``gate == False``)."""
        self._gated_update_step(data)
        return self._gated_logs(active=False)

    def _gated_update_step(self, data):
        """Apply one optimizer update on ``data`` and update compiled metrics."""
        # NOTE(review): assumes (features, targets) batches without sample weights.
        x, y = data
        with tf.GradientTape() as tape:
            logits = self(x, training=True)
            # compiled_loss also updates the tracked "loss" metric; include any
            # layer regularization losses, as the default Keras train_step does.
            loss_value = self.compiled_loss(
                y, logits, regularization_losses=self.losses
            )
        variables = tape.watched_variables()
        grads = tape.gradient(loss_value, variables)
        self.optimizer.apply_gradients(zip(grads, variables))
        # Without this, metrics given to compile() (e.g. "mae") are never updated.
        self.compiled_metrics.update_state(y, logits)

    def _gated_logs(self, active):
        """Return current metric results plus the branch-indicator flags."""
        logs = {m.name: m.result() for m in self.metrics}
        logs.update({"active": active, "passive": not active})
        return logs
    

In [26]:
class LoopControlerCallback(tf.keras.callbacks.Callback):
    """Keras callback that sets ``model.gate`` on a periodic epoch schedule.

    Intended for models exposing a boolean ``tf.Variable`` named ``gate``
    (such as ``LoopControlableModel``). When the epoch with 0-based index
    ``epoch`` ends, the gate is set to ``epoch % gating_frequency == 0``.
    The new value therefore applies to the *next* epoch. The first epoch runs
    with whatever value the gate was initialised to.

    Args:
        gating_frequency: Period of the gate schedule, in epochs (>= 1).

    Raises:
        ValueError: If ``gating_frequency`` is less than 1 (0 would otherwise
            raise ``ZeroDivisionError`` in the middle of training).
    """

    def __init__(self, gating_frequency: int, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        if gating_frequency < 1:
            raise ValueError(
                f"gating_frequency must be >= 1, got {gating_frequency!r}"
            )
        self.gating_frequency = gating_frequency

    def on_epoch_end(self, epoch, logs=None):
        """Update the model's gate from the index of the epoch that just ended."""
        # tf.Variable.assign mutates the variable in place, so graphs traced via
        # @tf.function see the new value. Rebinding the attribute
        # (`model.gate = ...`) would not: the traced graph keeps the old value.
        self.model.gate.assign(epoch % self.gating_frequency == 0)

In [27]:
class LoopControledModel(LoopControlableModel):
    """Gate-controllable model consisting of a single ``Dense`` layer.

    Args:
        output_size: Number of units in the output ``Dense`` layer.
    """

    def __init__(self, output_size, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # The whole architecture: one fully connected projection.
        self.layer = tf.keras.layers.Dense(output_size)

    def call(self, inputs):
        """Forward pass: project ``inputs`` through the dense layer."""
        outputs = self.layer(inputs)
        return outputs

In [28]:
DATASET_SIZE = 1000
INPUT_SIZE = 2
OUTPUT_SIZE = 1
BATCH_SIZE = 64

# Synthetic regression data: uniform random features and targets, sliced into
# (x, y) pairs and grouped into batches of BATCH_SIZE.
data = tf.data.Dataset.from_tensor_slices(
    (
        tf.random.uniform((DATASET_SIZE, INPUT_SIZE)),
        tf.random.uniform((DATASET_SIZE, OUTPUT_SIZE)),
    )
).batch(batch_size=BATCH_SIZE)


In [29]:
model = LoopControledModel(output_size=OUTPUT_SIZE)

# Compile: RMSprop optimizer, mean-squared-error loss, mean-absolute-error metric.
model.compile(
    optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.01),
    loss="mse",
    metrics=["mae"],
)

# Train for 10 epochs; the callback updates the model's gate at each epoch end.
gate_controller = LoopControlerCallback(gating_frequency=2)
history = model.fit(
    data,
    epochs=10,
    verbose=1,
    callbacks=[gate_controller],
)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
