## TensorFlow 2 training loop control using the default *tf.keras.Model.fit(...)* function

### Task Description

Up to now, a custom training loop in TensorFlow 2 has required writing two loops:
1. loop iterating through epochs 
2. loop iterating through batches 

All custom training procedures then have to be implemented inside this double-loop block of code. This is neither elegant nor robust, because the advanced features of *tf.keras.Model.fit(...)* are missing.

In [2]:
import tensorflow as tf
from types import MethodType


In [3]:
class LCModel(tf.keras.Model):
    """Keras model that exposes control values as non-trainable variables.

    Every ``name -> value`` entry of ``variables`` becomes an attribute
    ``self.<name>`` holding ``tf.Variable(value, trainable=False)``.
    Remaining positional and keyword arguments are forwarded unchanged to
    ``tf.keras.Model.__init__``.
    """

    def __init__(self, variables, *args, **kwargs):
        super().__init__(*args, **kwargs)
        for attr_name, initial_value in variables.items():
            # Non-trainable: these are control values, not model weights.
            control_variable = tf.Variable(initial_value, trainable=False)
            setattr(self, attr_name, control_variable)
    

# lcm = LCModel({"gating": {
#     "init_value": False,
#     "loss_fn": tf.keras.losses.MeanAbsoluteError, 
#     "freq_epoch" : 2,
#     "delay_epoch" : 10,
#     "freq_step" : 2,
#     "delay_step" : 10,
#     "cond" : functool.partials()
#     "var" : 
#     "clipping" : [0.1, 0.2]
# }})


In [4]:
class LoopControlableModel(tf.keras.Model):
    """Keras model whose ``train_step`` is routed by a boolean gate variable.

    ``train_step`` uses ``tf.cond`` on ``self.gate`` to dispatch to either
    ``train_step_active`` (gate is True) or ``train_step_passive`` (gate is
    False). The gate is a non-trainable ``tf.Variable`` so that its value can
    be changed with ``gate.assign(...)`` while training runs through ``fit``.

    Both branches currently perform the same gradient update and differ only
    in the ``"active"`` / ``"passive"`` flags they report; subclasses may
    override either method to give the two branches different behaviour.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Gate control variable; starts as False, so the passive branch runs
        # until something assigns True to it.
        self.gate = tf.Variable(False, trainable=False)

    @tf.function
    def train_step(self, data):
        """Run one training step on ``data``, picking the branch from ``self.gate``.

        Returns the log dictionary produced by the selected branch. Both
        branches must return dictionaries with the same keys, because
        ``tf.cond`` requires matching output structures.
        """
        return tf.cond(
            self.gate,
            lambda: self.train_step_active(data),
            lambda: self.train_step_passive(data),
        )

    @tf.function
    def train_step_active(self, data):
        """Training step used while the gate is True.

        Returns the current metric results plus ``active=True`` and
        ``passive=False``.
        """
        self._gradient_step(data)
        return self._step_logs(active=True)

    @tf.function
    def train_step_passive(self, data):
        """Training step used while the gate is False.

        Returns the current metric results plus ``active=False`` and
        ``passive=True``.
        """
        self._gradient_step(data)
        return self._step_logs(active=False)

    def _gradient_step(self, data):
        """Forward pass, loss, gradient update and metric update for one batch.

        ``data`` is unpacked as an ``(x, y)`` pair.
        """
        x, y = data
        with tf.GradientTape(watch_accessed_variables=True) as tape:
            logits = self(x, training=True)
            loss_value = self.compiled_loss(y, logits)
        # Differentiate with respect to every variable the tape saw during the
        # forward pass, and update exactly those variables.
        watched = tape.watched_variables()
        grads = tape.gradient(loss_value, watched)
        self.optimizer.apply_gradients(zip(grads, watched))
        # Feed the metrics passed to ``compile`` (e.g. "mae"); without this
        # call they are never updated and do not appear in the returned logs.
        self.compiled_metrics.update_state(y, logits)

    def _step_logs(self, active):
        """Build the per-step log dict: metric results plus the branch flags."""
        logs = {m.name: m.result() for m in self.metrics}
        logs["active"] = active
        logs["passive"] = not active
        return logs
    

In [5]:
class LoopControlerCallback(tf.keras.callbacks.Callback):
    """Callback that toggles the model's ``gate`` variable once per epoch.

    At the end of epoch ``epoch`` the gate is set to
    ``epoch % gating_frequency == 0``; the new value is what the training
    steps executed after that point will read.

    Args:
        gating_frequency: Modulus applied to the epoch index. Must be
            non-zero.

    Raises:
        ValueError: If ``gating_frequency`` is zero (the modulo in
            ``on_epoch_end`` would otherwise fail with ``ZeroDivisionError``
            only after the first epoch has already been trained).
    """

    def __init__(self, gating_frequency: int, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        if gating_frequency == 0:
            raise ValueError("gating_frequency must be non-zero")
        self.gating_frequency = gating_frequency

    def on_train_begin(self, logs=None):
        """Print the JSON description of the attached model when training starts."""
        tf.print(self.model.to_json())

    def on_epoch_end(self, epoch, logs=None):
        """Control the gating variable from the callback level.

        ``logs`` defaults to ``None`` to match ``on_train_begin`` and the
        Keras ``Callback`` base signature.
        """
        # Use Variable.assign rather than rebinding ``self.model.gate``:
        # a rebound Python value would be compiled into the TF graph as a
        # static value by the @tf.function decorators in LoopControlableModel.
        self.model.gate.assign(epoch % self.gating_frequency == 0)

In [6]:
class LoopControledModel(LoopControlableModel):
    """Gate-controlled model consisting of a single ``Dense`` layer.

    ``output_size`` is passed to ``tf.keras.layers.Dense``; all other
    arguments are forwarded to ``LoopControlableModel.__init__``.
    """

    def __init__(self, output_size, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Architecture of the model: one dense layer.
        self.layer = tf.keras.layers.Dense(output_size)

    def call(self, inputs):
        """Apply the dense layer to ``inputs`` and return its output."""
        outputs = self.layer(inputs)
        return outputs

In [7]:
# Sizes of the synthetic dataset and of the mini-batches drawn from it.
DATASET_SIZE = 1000
INPUT_SIZE = 2
OUTPUT_SIZE = 1
BATCH_SIZE = 64

# Uniformly random (inputs, targets) pairs, sliced per example and then
# grouped into batches of BATCH_SIZE.
data = tf.data.Dataset.from_tensor_slices(
    (
        tf.random.uniform((DATASET_SIZE, INPUT_SIZE)),
        tf.random.uniform((DATASET_SIZE, OUTPUT_SIZE)),
    )
)
data = data.batch(BATCH_SIZE)


2023-01-27 14:13:56.849956: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory
2023-01-27 14:13:56.850022: W tensorflow/stream_executor/cuda/cuda_driver.cc:269] failed call to cuInit: UNKNOWN ERROR (303)
2023-01-27 14:13:56.850074: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:156] kernel driver does not appear to be running on this host (filip-HP-ProBook-440-G3): /proc/driver/nvidia/version does not exist
2023-01-27 14:13:56.853177: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [14]:
# Build the gate-controlled model.
model = LoopControledModel(OUTPUT_SIZE)

# Compile it with RMSprop, an MSE loss and an MAE metric.
model.compile(
    optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.01),
    loss="mse",
    metrics=["mae"],
)

# Train through the stock fit() loop; the callback drives the gate.
history = model.fit(
    data,
    epochs=10,
    verbose=1,
    callbacks=[LoopControlerCallback(2)],
)

{"class_name": "LoopControledModel", "config": {}, "keras_version": "2.9.0", "backend": "tensorflow"}
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


In [23]:
import functools

class A:
    """Minimal object with two integer attributes, ``x`` and ``y``."""

    def __init__(self) -> None:
        self.x, self.y = 1, 2


# Shared instance that the decorator experiment attaches functions to.
a = A()

def assign_to_instance(instance):
    """Return a decorator that attaches the decorated function to ``instance``.

    The decorated name is replaced by a registration wrapper. Calling that
    wrapper with ``*args, **kwargs`` stores
    ``functools.partial(func, instance, *args, **kwargs)`` on ``instance``
    under the function's own ``__name__`` and returns ``None``. Nothing is
    attached until the wrapper has been called at least once; each call
    replaces the previously stored partial.
    """

    def decorate(func):
        @functools.wraps(func)
        def register(*args, **kwargs):
            # ``instance`` is always bound first, so it plays the role of
            # ``self`` when the stored partial is invoked later.
            bound = functools.partial(func, instance, *args, **kwargs)
            setattr(instance, func.__name__, bound)
            return None

        return register

    return decorate

@assign_to_instance(a)
def train(self, a, b):
    """Print ``self.x`` and ``self.y``, then the two arguments ``a`` and ``b``.

    Note that the parameter ``a`` shadows the module-level instance ``a``
    inside this function.
    """
    for attribute_value in (self.x, self.y):
        print(attribute_value)
    print(f"a: {a}")
    print(f"b: {b}")


# The decorated name is the registration wrapper: calling it with no extra
# arguments attaches the function to the instance ``a`` as ``a.train``.
train()
# Now the function can be called like a method of ``a``.
a.train(4, 5)


1
2
a: 4
b: 5
