In [21]:
import tensorflow as tf
import tensorflow.keras as keras 
 
class CustomModel(keras.Model):
    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        x, y = data

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # Forward pass
            # Compute the loss value
            # (the loss function is configured in `compile()`)
            loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)

        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        # Update metrics (includes the metric that tracks the loss)
        self.compiled_metrics.update_state(y, y_pred)
        # Return a dict mapping metric names to current value
        
        return {m.name: m.result() for m in self.metrics}

In [22]:
import resource
def using(point=""):
    usage=resource.getrusage(resource.RUSAGE_SELF)
    return '''%s: usertime=%s systime=%s mem=%s mb
           '''%(point,usage[0],usage[1],
                usage[2]/1024.0 )

In [None]:
import numpy as np

# Construct and compile an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(100)(inputs)

model = CustomModel(inputs, outputs)
model.compile(optimizer="adam", loss="mse", metrics=["mae"])
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))

for i in range(int(1e5)):

    model.fit(x, y, epochs=int(1e1), verbose=False)
    print(using("after"))

after: usertime=441.986728 systime=57.024734 mem=1430.1171875 mb
           
after: usertime=442.473599 systime=57.126383 mem=1430.1328125 mb
           
after: usertime=442.98091 systime=57.205181 mem=1430.1328125 mb
           
after: usertime=443.513417 systime=57.259892 mem=1430.1328125 mb
           
after: usertime=444.040051 systime=57.32454 mem=1430.1328125 mb
           
after: usertime=444.563174 systime=57.389004 mem=1430.265625 mb
           
after: usertime=445.102599 systime=57.437083 mem=1430.265625 mb
           
after: usertime=445.629428 systime=57.49727 mem=1430.265625 mb
           
after: usertime=446.143298 systime=57.58127 mem=1430.453125 mb
           
after: usertime=446.63779 systime=57.676899 mem=1430.453125 mb
           
after: usertime=447.137084 systime=57.763376 mem=1430.453125 mb
           
after: usertime=447.648496 systime=57.839668 mem=1430.453125 mb
           
after: usertime=448.165233 systime=57.913475 mem=1430.453125 mb
           
after: usert