# Lecture 5 A - Different ways to create Keras models.
There are three APIs for building models in Keras:
- The **Sequential model**, the most approachable API—it’s basically a Python list. As such, it’s limited to simple stacks of layers.
- The **Functional API**, which focuses on graph-like model architectures. It represents a nice mid-point between usability and flexibility, and as such, it’s the most commonly used model-building API.
- **Model subclassing**, a low-level option where we write everything ourselves from scratch. This is ideal if we want full control over every little thing. However, we won’t get access to many built-in Keras features, and we will be more at risk of making mistakes.
<img src="1.png"/>

In [None]:
from tensorflow import keras 
from tensorflow.keras import layers

## The Sequential model
The Sequential model is the simplest way to build a Keras model:
```python
model = keras.Sequential([
    layers.Dense(64, activation="relu"),
    layers.Dense(10, activation="softmax")
])
```
It’s possible to build the same model incrementally via the `add()` method, which is the equivalent of the `append()` method of a Python list:

In [None]:
model = keras.Sequential()
model.add(layers.Dense(64, activation="relu"))
model.add(layers.Dense(10, activation="softmax"))

Layers only get built when they are called for the first time, because the shape of a layer’s weights depends on the shape of its input. The Sequential model does not have any weights until we actually call it on some data:

In [None]:
model.weights
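
The same lazy building happens at the level of a single layer. A minimal sketch, with a layer size and input shape chosen arbitrarily for illustration: a `Dense` layer owns no weights until its first call, and the first dimension of its kernel is inferred from the input it receives.

```python
import numpy as np
from tensorflow.keras import layers

# Arbitrary sizes, chosen only for illustration
dense = layers.Dense(4)
print(len(dense.weights))  # No weights yet: the input dimension is still unknown

dense(np.ones((2, 3), dtype="float32"))  # The first call builds the layer
print(dense.kernel.shape)  # First kernel dimension (3) comes from the input
print(dense.bias.shape)    # The bias only depends on the number of units (4)
```
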

We can also call the `build()` method with an input shape to create the model’s weights:

In [None]:
model.build(input_shape=(None, 3))  
model.weights    

After the model is built, we can display its contents via the `summary()` method:

In [None]:
model.summary()

We can give names to everything in Keras—every model, every layer:

In [None]:
model = keras.Sequential(name="Example_model")
model.add(layers.Dense(64, activation="relu", name="first_layer"))
model.add(layers.Dense(10, activation="softmax", name="last_layer"))
model.build((None, 3))
model.summary()

When building a Sequential model incrementally, it’s useful to be able to print a summary of what the current model looks like. But we can’t print a summary until the model is built. The solution is to declare the shape of the model’s inputs via the Input class:

In [None]:
model = keras.Sequential()
model.add(keras.Input(shape=(3,)))               
model.add(layers.Dense(64, activation="relu"))
model.summary()

In [None]:
model.add(layers.Dense(10, activation="softmax"))
model.summary()

## The Functional API
The Sequential model is easy to use, but its applicability is extremely limited: it can only express models with a single input and a single output, applying one layer after the other in a sequential fashion. In practice, it’s pretty common to encounter models with multiple inputs (say, an image and its metadata), multiple outputs (different things you want to predict about the data), or a nonlinear topology. In such cases, you’d build your model using the Functional API.

The Functional API version of the previous model looks like this:

In [None]:
inputs = keras.Input(shape=(3,), name="my_input")
features = layers.Dense(64, activation="relu")(inputs)
outputs = layers.Dense(10, activation="softmax")(features)
model = keras.Model(inputs=inputs, outputs=outputs)

The `inputs` object holds information about the shape and dtype of the data that the model will process:

In [None]:
inputs.shape

In [None]:
inputs.dtype

Next, we create a layer and call it on the inputs:
```python
features = layers.Dense(64, activation="relu")(inputs)
```

In [None]:
features.shape

After obtaining the final outputs, we instantiate the model by specifying its inputs and outputs in the Model constructor:
```python
outputs = layers.Dense(10, activation="softmax")(features)
model = keras.Model(inputs=inputs, outputs=outputs)
```

The summary of the model is as follows:

In [None]:
model.summary()

The following example shows more of what the Functional API can do: a model with multiple inputs and multiple outputs.
Let us consider a system to rank customer support tickets by priority and route them to the appropriate department. Our model has three inputs:
- The title of the ticket (text input)
- The text body of the ticket (text input)
- Any tags added by the user (categorical input, assumed here to be one-hot encoded)

Our model has two outputs:
- The priority score of the ticket, a scalar between 0 and 1 (sigmoid output)
- The department that should handle the ticket (a softmax over the set of departments).

In [None]:
vocabulary_size = 10000
num_tags = 100
num_departments = 4

title = keras.Input(shape=(vocabulary_size,), name="title")
text_body = keras.Input(shape=(vocabulary_size,), name="text_body")
tags = keras.Input(shape=(num_tags,), name="tags")

features = layers.Concatenate()([title, text_body, tags])
features = layers.Dense(64, activation="relu")(features)

priority = layers.Dense(1, activation="sigmoid", name="priority")(features)
department = layers.Dense(
    num_departments, activation="softmax", name="department")(features)

model = keras.Model(inputs=[title, text_body, tags], outputs=[priority, department])

We can inspect and reuse individual layers in the model. The `model.layers` property provides the list of layers that make up the model, and for each layer we can query `layer.input` and `layer.output`.

In [None]:
model.layers

For instance, here are the input and output of `model.layers[3]`:

In [None]:
model.layers[3].input

In [None]:
model.layers[3].output

This enables us to do feature extraction, creating models that reuse intermediate features from another model.
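As a sketch of this kind of feature extraction, assume we want to add a hypothetical third output (a ticket "difficulty" estimate over three levels) on top of the existing model. We take the output of the intermediate `Dense` layer and build a new model from it. The code rebuilds a smaller version of the ticket model so it runs on its own (the sizes and the layer name `"mixing"` are illustrative choices); looking the layer up by name with `get_layer()` avoids depending on its position in `model.layers`.

```python
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

# Smaller sizes than in the text, so the sketch runs quickly
vocabulary_size = 100
num_tags = 10
num_departments = 4

title = keras.Input(shape=(vocabulary_size,), name="title")
text_body = keras.Input(shape=(vocabulary_size,), name="text_body")
tags = keras.Input(shape=(num_tags,), name="tags")

features = layers.Concatenate()([title, text_body, tags])
# Named layer, so we can look it up by name instead of by index
features = layers.Dense(64, activation="relu", name="mixing")(features)
priority = layers.Dense(1, activation="sigmoid", name="priority")(features)
department = layers.Dense(
    num_departments, activation="softmax", name="department")(features)
model = keras.Model(inputs=[title, text_body, tags], outputs=[priority, department])

# Reuse the intermediate features of the existing model
mixed_features = model.get_layer("mixing").output
# Hypothetical extra output: a 3-level difficulty estimate per ticket
difficulty = layers.Dense(3, activation="softmax", name="difficulty")(mixed_features)

new_model = keras.Model(
    inputs=[title, text_body, tags],
    outputs=[priority, department, difficulty])

priority_pred, department_pred, difficulty_pred = new_model.predict(
    [np.zeros((2, vocabulary_size), dtype="float32"),
     np.zeros((2, vocabulary_size), dtype="float32"),
     np.zeros((2, num_tags), dtype="float32")],
    verbose=0)
```

Because `new_model` is built from the same symbolic tensors, it shares its layers (and their weights) with `model`.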

## Model subclassing
Model subclassing is the most advanced model-building pattern.
Subclassing `Model` involves the following steps:
- In the `__init__()` method, define the layers the model will use.
- In the `call()` method, define the forward pass of the model, reusing the layers previously created.
- Instantiate the subclass, and call it on data to create its weights.

The Model subclassing workflow is the most flexible way to build a model. It enables us to build models that cannot be expressed as directed acyclic graphs of layers — imagine, for instance, a model where the `call()` method uses layers inside a for loop, or calls them recursively.
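A minimal sketch of such a loop, where the class name `RepeatedRefinement` and its sizes are hypothetical: the same `Dense` layer is applied `num_steps` times inside `call()`, so its weights are shared across iterations.

```python
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

class RepeatedRefinement(keras.Model):
    # Hypothetical model: one shared Dense layer applied `num_steps` times in a loop
    def __init__(self, units=8, num_steps=3):
        super().__init__()
        self.num_steps = num_steps
        self.project = layers.Dense(units, activation="relu")
        self.refine = layers.Dense(units, activation="relu")  # Reused at every step
        self.head = layers.Dense(1)

    def call(self, inputs):
        x = self.project(inputs)
        for _ in range(self.num_steps):  # Ordinary Python control flow
            x = self.refine(x)           # Same weights at every iteration
        return self.head(x)

model = RepeatedRefinement(units=8, num_steps=3)
outputs = model(np.ones((2, 5), dtype="float32"))
print(outputs.shape)
print(len(model.trainable_weights))  # 3 Dense layers x (kernel + bias)
```

Since the loop reuses `self.refine`, the model has six trainable weights (a kernel and a bias for each of its three `Dense` layers), however many steps we run.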

We will reimplement the customer support ticket management model using a Model subclass:

In [None]:
class CustomerTicketModel(keras.Model):

    def __init__(self, num_departments):
        super().__init__() # Call the super() constructor
        # Define sublayers in the constructor:
        self.concat_layer = layers.Concatenate()
        self.mixing_layer = layers.Dense(64, activation="relu")
        self.priority_scorer = layers.Dense(1, activation="sigmoid")
        self.department_classifier = layers.Dense(num_departments, activation="softmax")

    def call(self, inputs): # Define the forward pass in the call() method
        title = inputs["title"]
        text_body = inputs["text_body"]
        tags = inputs["tags"]

        features = self.concat_layer([title, text_body, tags])
        features = self.mixing_layer(features)
        priority = self.priority_scorer(features)
        department = self.department_classifier(features)
        return priority, department

Once we’ve defined the model, we can instantiate it and call it on some data to create its weights:

In [None]:
import numpy as np

num_samples = 1280

title_data = np.random.randint(0, 2, size=(num_samples, vocabulary_size))
text_body_data = np.random.randint(0, 2, size=(num_samples, vocabulary_size))
tags_data = np.random.randint(0, 2, size=(num_samples, num_tags))

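# Instantiate the subclass (otherwise `model` still refers to the Functional model above)
model = CustomerTicketModel(num_departments=num_departments)
priority, department = model(
    {"title": title_data, "text_body": text_body_data, "tags": tags_data})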

We can compile and train a Model subclass just like a Sequential or Functional model:

In [None]:
priority_data = np.random.random(size=(num_samples, 1))
department_data = np.random.randint(0, 2, size=(num_samples, num_departments))

model.compile(optimizer="rmsprop",
              #  The structure of what we pass as the loss and metrics arguments must match exactly 
              # what gets returned by call()—here, a list of two elements:
              loss=["mean_squared_error", "categorical_crossentropy"],
              metrics=[["mean_absolute_error"], ["accuracy"]])
model.fit({"title": title_data,
           "text_body": text_body_data,
           "tags": tags_data},
          [priority_data, department_data],
          epochs=1)
model.evaluate(
    # The structure of the input data must match exactly what is expected by the call() method:
    # here, a dict with keys title, text_body, and tags.
    {"title": title_data,
     "text_body": text_body_data,
     "tags": tags_data},
    # The structure of the target data must match exactly what is returned by the call() method:
    # here, a list of two elements.
    [priority_data, department_data])
priority_preds, department_preds = model.predict({"title": title_data,
                                                  "text_body": text_body_data,
                                                  "tags": tags_data})

Functional and subclassed models are substantially different in nature. A Functional model is an explicit data structure: a graph of layers, which we can view, inspect, and modify. A subclassed model is a piece of bytecode: a Python class with a `call()` method that contains raw code. This is the source of the subclassing workflow’s flexibility, but it introduces new limitations:
- the way layers are connected to each other is hidden inside the body of the `call()` method, so we cannot access that information;
- calling `summary()` will not display layer connectivity;
- we cannot plot the model topology via `plot_model()`;
- given a subclassed model, we cannot access the nodes of its layer graph to do feature extraction, because there is no graph.

## Mixing and matching different components
Choosing one of these patterns—the Sequential model, the Functional API, or Model subclassing—does not lock us out of the others. All models in the Keras API can smoothly interoperate with each other.

For instance, we can use a subclassed layer or model in a Functional model.

In [None]:
class Classifier(keras.Model):

    def __init__(self, num_classes=2):
        super().__init__()
        if num_classes == 2:
            num_units = 1
            activation = "sigmoid"
        else:
            num_units = num_classes
            activation = "softmax"
        self.dense = layers.Dense(num_units, activation=activation)

    def call(self, inputs):
        return self.dense(inputs)

inputs = keras.Input(shape=(3,))
features = layers.Dense(64, activation="relu")(inputs)
outputs = Classifier(num_classes=10)(features)
model = keras.Model(inputs=inputs, outputs=outputs)
model.summary()

Conversely, we can use a Functional model as part of a subclassed layer or model:

In [None]:
inputs = keras.Input(shape=(64,))
outputs = layers.Dense(1, activation="sigmoid")(inputs)
binary_classifier = keras.Model(inputs=inputs, outputs=outputs)

class MyModel(keras.Model):

    def __init__(self, num_classes=2):
        super().__init__()
        self.dense = layers.Dense(64, activation="relu")
        self.classifier = binary_classifier

    def call(self, inputs):
        features = self.dense(inputs)
        return self.classifier(features)

model = MyModel()

## What workflow should we choose for building Keras models?
In general, the Functional API provides a good trade-off between ease of use and flexibility. It also gives us direct access to layer connectivity, which is very powerful for use cases such as model plotting or feature extraction. If we can use the Functional API (that is, if our model can be expressed as a directed acyclic graph of layers), we should prefer it over model subclassing.

In general, using Functional models that include subclassed layers provides the best of both worlds: high development flexibility while retaining the advantages of the Functional API.
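As a sketch of that combination, here is a hypothetical custom layer, `LearnedScale`, which multiplies each feature by a trainable scale, dropped into an otherwise ordinary Functional model. Its weight is created in `build()`, once the input shape is known.

```python
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

class LearnedScale(layers.Layer):
    # Hypothetical custom layer: one trainable scale factor per input feature
    def build(self, input_shape):
        self.scale = self.add_weight(
            name="scale", shape=(input_shape[-1],),
            initializer="ones", trainable=True)

    def call(self, inputs):
        return inputs * self.scale  # Element-wise, broadcast over the batch

inputs = keras.Input(shape=(3,))
features = layers.Dense(64, activation="relu")(inputs)
features = LearnedScale()(features)  # Subclassed layer inside a Functional model
outputs = layers.Dense(10, activation="softmax")(features)
model = keras.Model(inputs=inputs, outputs=outputs)

preds = model.predict(np.zeros((2, 3), dtype="float32"), verbose=0)
```

The model is still Functional, so `summary()` lists `LearnedScale` as one node of the graph, next to the built-in layers.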

# Lecture 5B - using training and evaluation loops
The principle of progressive disclosure of complexity also applies to model training. Keras provides us with different workflows for training models. They can be as simple as calling fit() on our data, or as advanced as writing a new training algorithm from scratch.

The standard workflow encompasses: `compile()`, `fit()`, `evaluate()`, and `predict()`:

In [None]:
from tensorflow.keras.datasets import mnist

def get_mnist_model():
    inputs = keras.Input(shape=(28 * 28,))
    features = layers.Dense(512, activation="relu")(inputs)
    features = layers.Dropout(0.5)(features)
    outputs = layers.Dense(10, activation="softmax")(features)
    model = keras.Model(inputs, outputs)
    return model

(images, labels), (test_images, test_labels) = mnist.load_data()
images = images.reshape((60000, 28 * 28)).astype("float32") / 255
test_images = test_images.reshape((10000, 28 * 28)).astype("float32") / 255
train_images, val_images = images[10000:], images[:10000]
train_labels, val_labels = labels[10000:], labels[:10000]

model = get_mnist_model()
model.compile(optimizer="rmsprop",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])
model.fit(train_images, train_labels,
          epochs=3,
          validation_data=(val_images, val_labels))
test_metrics = model.evaluate(test_images, test_labels)
predictions = model.predict(test_images)

There are, however, a couple of ways we can customize this simple workflow:
- Creating our own custom metrics.
- Passing callbacks to the `fit()` method to schedule actions to be taken at specific points during training.

## Creating our own custom metrics
Metrics are key to measuring the performance of our model, in particular to measuring the difference between its performance on the training data and its performance on the test data. Commonly used metrics for classification and regression are already part of the built-in `keras.metrics` module. But if we’re doing anything out of the ordinary, we will need to write our own metrics. A Keras metric is a subclass of the `keras.metrics.Metric` class. Like layers, a metric has an internal state stored in TensorFlow variables. Unlike layers, these variables aren’t updated via backpropagation, so we have to write the state-update logic ourselves, in the `update_state()` method.

Below we implement an example of a custom metric that measures the root mean squared error (RMSE):

In [None]:
import tensorflow as tf

class RootMeanSquaredError(keras.metrics.Metric): # Subclass the Metric class

    # Define the state variables in the constructor. Like for layers, we have access to the add_weight() method
    def __init__(self, name="rmse", **kwargs):
        super().__init__(name=name, **kwargs)
        self.mse_sum = self.add_weight(name="mse_sum", initializer="zeros")
        self.total_samples = self.add_weight(name="total_samples", initializer="zeros", dtype="int32")       
    
    # Implement the state update logic in update_state(). The y_true argument is the targets (or labels) for one batch, 
    # while y_pred represents the corresponding predictions from the model. 
    def update_state(self, y_true, y_pred, sample_weight=None):
        y_true = tf.one_hot(y_true, depth=tf.shape(y_pred)[1]) # To match our MNIST model, we expect categorical predictions and integer labels.
        mse = tf.reduce_sum(tf.square(y_true - y_pred))
        self.mse_sum.assign_add(mse)
        num_samples = tf.shape(y_pred)[0]
        self.total_samples.assign_add(num_samples)
    
    # Return the current value of the metric
    def result(self):
        return tf.sqrt(self.mse_sum / tf.cast(self.total_samples, tf.float32))

    # Reset the metric state without having to reinstantiate it
    def reset_state(self):
        self.mse_sum.assign(0.)
        self.total_samples.assign(0)
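
To see what the two state variables accumulate, here is a plain-NumPy mirror of the same logic. `NumpyRMSE` is a hypothetical helper for checking values by hand, not part of Keras. Note that the squared errors are summed over both samples and classes but divided only by the number of samples.

```python
import numpy as np

class NumpyRMSE:
    # Plain-NumPy mirror of the RootMeanSquaredError state logic (hypothetical helper)
    def __init__(self):
        self.reset_state()

    def update_state(self, y_true, y_pred):
        one_hot = np.eye(y_pred.shape[1])[y_true]            # Integer labels -> one-hot rows
        self.mse_sum += np.sum(np.square(one_hot - y_pred))  # Summed over samples and classes
        self.total_samples += y_pred.shape[0]                # Counted in samples

    def result(self):
        return np.sqrt(self.mse_sum / self.total_samples)

    def reset_state(self):
        self.mse_sum = 0.0
        self.total_samples = 0

metric = NumpyRMSE()
# Two "batches" of two samples each, with two classes
metric.update_state(np.array([0, 1]), np.array([[1.0, 0.0], [0.5, 0.5]]))
metric.update_state(np.array([1, 1]), np.array([[0.0, 1.0], [0.0, 1.0]]))
print(metric.result())
```

Keeping a running sum and a running count, rather than averaging each batch, makes the result independent of how the data is split into batches.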

Our custom metrics can be used like built-in ones:

In [None]:
model = get_mnist_model()
model.compile(optimizer="rmsprop",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy", RootMeanSquaredError()])
model.fit(train_images, train_labels,
          epochs=3,
          validation_data=(val_images, val_labels))

In [None]:
test_metrics = model.evaluate(test_images, test_labels)
test_metrics

## Using callbacks
A **callback** is an object (a class instance implementing specific methods) that is passed to the model in the call to `fit()` and that is called by the model at various points during training. It has access to all the available data about the state of the model and its performance, and it can take action: interrupt training, save a model, load a different weight set, or otherwise alter the state of the model.

We can use callbacks in the following ways:
- Model checkpointing—Saving the current state of the model at different points during training.
- Early stopping—Interrupting training when the validation loss is no longer improving.
- Dynamically adjusting the value of certain parameters during training—Such as the learning rate of the optimizer.
- Logging training and validation metrics during training, or visualizing the representations learned by the model as they’re updated.

The `keras.callbacks` module includes a number of built-in callbacks:
- `keras.callbacks.ModelCheckpoint`
- `keras.callbacks.EarlyStopping`
- `keras.callbacks.LearningRateScheduler`
- `keras.callbacks.ReduceLROnPlateau`
- `keras.callbacks.CSVLogger`

When we’re training a model, we can’t tell how many epochs will be needed to get to an optimal validation loss. We have adopted the strategy of training for enough epochs that we begin overfitting, using the first run to figure out the proper number of epochs to train for, and then finally launching a new training run from scratch using this optimal number. However, this approach is wasteful. A much better way to handle this is to stop training when we measure that the validation loss is no longer improving. This can be achieved using the *EarlyStopping* callback. The *EarlyStopping* callback interrupts training once a target metric being monitored has stopped improving for a fixed number of epochs. This callback is typically used in combination with *ModelCheckpoint*, which lets us continually save the model during training.

In [None]:
# Callbacks are passed to the model via the callbacks argument in fit(), which takes a list of callbacks. 
callbacks_list = [
    keras.callbacks.EarlyStopping(# Interrupts training when improvement stops
        monitor="val_accuracy",# Monitors the model’s validation accuracy
        patience=2,# Interrupts training when accuracy has stopped improving for 2 epochs
    ),
    keras.callbacks.ModelCheckpoint(# Saves the current weights after every epoch
        filepath="checkpoint_path.keras",# Path to the destination model file
        monitor="val_loss",# val_loss is monitored
        save_best_only=True,# Only the best model seen during training will be kept
    )
]
model = get_mnist_model()
model.compile(optimizer="rmsprop",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])
model.fit(train_images, train_labels,
          epochs=15,
          callbacks=callbacks_list,
          validation_data=(val_images, val_labels))

We can also save a model manually after training:
```python
model.save("my_checkpoint_path.keras")
```

To reload the model we’ve saved, we use `keras.models.load_model()`:

In [None]:
model = keras.models.load_model("checkpoint_path.keras")
print(model.evaluate(val_images, val_labels))
model.evaluate(test_images, test_labels)

### Writing our own callbacks
If we need to take a specific action during training that isn’t covered by one of the built-in callbacks, we can write our own callback. Callbacks are implemented by subclassing the `keras.callbacks.Callback` class. We can then implement any number of the following transparently named methods, which are called at various points during training:
```python
on_epoch_begin(epoch, logs) # Called at the start of every epoch     
on_epoch_end(epoch, logs)   # Called at the end of every epoch     
on_batch_begin(batch, logs) # Called before processing each batch     
on_batch_end(batch, logs)   # Called after processing each batch        
on_train_begin(logs)        # Called at the start of training     
on_train_end(logs)          # Called at the end of training  
```

All these methods are called with a logs argument, which is a dictionary containing information about the previous batch, epoch, or training run — training and validation metrics, and so on.
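A minimal sketch of what `logs` contains at the end of each epoch, assuming a tiny regression model trained on random data (the callback name `EpochLogRecorder` and all sizes are hypothetical). Training metrics appear under their own names, and validation metrics with a `val_` prefix.

```python
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

class EpochLogRecorder(keras.callbacks.Callback):
    # Hypothetical callback: keep a copy of the logs dict from every epoch
    def on_train_begin(self, logs=None):
        self.epoch_logs = []

    def on_epoch_end(self, epoch, logs=None):
        self.epoch_logs.append(dict(logs or {}))

# Random data, for illustration only
x = np.random.random((64, 4)).astype("float32")
y = np.random.random((64, 1)).astype("float32")
model = keras.Sequential([keras.Input(shape=(4,)), layers.Dense(1)])
model.compile(optimizer="rmsprop", loss="mse", metrics=["mean_absolute_error"])

recorder = EpochLogRecorder()
model.fit(x, y, epochs=2, validation_split=0.25, callbacks=[recorder], verbose=0)
print(sorted(recorder.epoch_logs[-1]))
```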

The following callback saves a list of per-batch loss values during training and saves a graph of these values at the end of each epoch.

In [None]:
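import os
from matplotlib import pyplot as plt

class LossHistory(keras.callbacks.Callback):
    def on_train_begin(self, logs=None):
        os.makedirs("Plots", exist_ok=True)  # plt.savefig() fails if the folder does not exist
        self.per_batch_losses = []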

    def on_batch_end(self, batch, logs):
        self.per_batch_losses.append(logs.get("loss"))

    def on_epoch_end(self, epoch, logs):
        plt.clf()
        plt.plot(range(len(self.per_batch_losses)), self.per_batch_losses,
                 label="Training loss for each batch")
        plt.xlabel(f"Batch (epoch {epoch})")
        plt.ylabel("Loss")
        plt.legend()
        plt.savefig(f"Plots/plot_at_epoch_{epoch}")
        self.per_batch_losses = []

In [None]:
model = get_mnist_model()
model.compile(optimizer="rmsprop",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])
model.fit(train_images, train_labels,
          epochs=10,
          callbacks=[LossHistory()],
          validation_data=(val_images, val_labels))

### Monitoring and visualization with TensorBoard
To do good research or develop good models, we need rich, frequent feedback about what’s going on inside our models during our experiments. That’s the point of running experiments: to get information about how well a model performs. **TensorBoard** (www.tensorflow.org/tensorboard) is a browser-based application that we can run locally. It’s the best way to monitor everything that goes on inside our model during training. With TensorBoard, we can:
- Visually monitor metrics during training
- Visualize our model architecture
- Visualize histograms of activations and gradients
- Explore embeddings in 3D

The easiest way to use TensorBoard with a Keras model and the `fit()` method is to use the `keras.callbacks.TensorBoard` callback:

In [None]:
model = get_mnist_model()
model.compile(optimizer="rmsprop",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])

tensorboard = keras.callbacks.TensorBoard(
    #log_dir="/full_path_to_your_log_dir",
    log_dir="D:/Dokumenty/log_dir",
)
model.fit(train_images, train_labels,
          epochs=5,
          validation_data=(val_images, val_labels),
          callbacks=[tensorboard])

Once the model starts running, it will write logs at the target location. We can run an embedded TensorBoard instance as part of our notebook using the following commands:

In [None]:
%load_ext tensorboard
%tensorboard --logdir "D:/Dokumenty/log_dir"

## Writing our own training and evaluation loops
The built-in `fit()` workflow is solely focused on supervised learning: a setup where there are known targets associated with input data, and where we compute our loss as a function of these targets and the model’s predictions. There are other setups where no explicit targets are present, such as generative learning, self-supervised learning (where targets are obtained from the inputs), and reinforcement learning (where learning is driven by occasional “rewards”). In such situations the built-in `fit()` is not enough, and we will need to write our own custom training logic.

The contents of a typical training loop look like this:
- Run the forward pass (compute the model’s output) inside a gradient tape to obtain a loss value for the current batch of data.
- Retrieve the gradients of the loss with regard to the model’s weights.
- Update the model’s weights so as to lower the loss value on the current batch of data.

To reimplement `fit()` from scratch we need to take into consideration the following facts:
1. Some Keras layers, such as the Dropout layer, have different behaviors during training and during inference (when we use them to generate predictions). Such layers expose a `training` Boolean argument in their `call()` method. Calling `dropout(inputs, training=True)` will drop some activation entries, while calling `dropout(inputs, training=False)` does nothing. So we need to pass `training=True` when we call a Keras model during the forward pass. Our forward pass thus becomes `predictions = model(inputs, training=True)`.
2. When we retrieve the gradients of the weights of our model, we should not use `tape.gradient(loss, model.weights)`, but rather `tape.gradient(loss, model.trainable_weights)`, since layers and models own two kinds of weights:
- Trainable weights — These are meant to be updated via backpropagation to minimize the loss of the model.
- Non-trainable weights — These are meant to be updated during the forward pass by the layers that own them.
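`BatchNormalization` is a common layer with both kinds of weights. A minimal sketch (the input shape is chosen arbitrarily): `gamma` and `beta` are trainable, while `moving_mean` and `moving_variance` are non-trainable and get updated by the forward pass itself when `training=True`, with no gradient or optimizer involved.

```python
import numpy as np
from tensorflow.keras import layers

bn = layers.BatchNormalization()  # Default momentum=0.99
bn.build((None, 3))

print(len(bn.trainable_weights))      # gamma and beta
print(len(bn.non_trainable_weights))  # moving_mean and moving_variance

x = np.full((4, 3), 10.0, dtype="float32")
bn(x, training=True)                  # Forward pass only: no gradients, no optimizer
print(bn.moving_mean.numpy())         # Moved from 0 toward the batch mean of 10
```

With the default `momentum=0.99`, one batch whose mean is 10 moves the moving mean from 0 to 0.01 × 10 = 0.1.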

Taking into account these two facts, a supervised-learning training step ends up looking like this:
```python
def train_step(inputs, targets):
    with tf.GradientTape() as tape:
        predictions = model(inputs, training=True)  # Forward pass in training mode
        loss = loss_fn(targets, predictions)
    gradients = tape.gradient(loss, model.trainable_weights)  # Only the trainable weights
    optimizer.apply_gradients(zip(gradients, model.trainable_weights))  # (gradient, variable) pairs
```

3. In a low-level training loop, we can leverage Keras metrics. The metrics API is simple: call `update_state(y_true, y_pred)` for each batch of targets and predictions, and then use `result()` to query the current metric value. For instance:

In [None]:
values = range(10)
mean_tracker = keras.metrics.Mean() 
for value in values:
    mean_tracker.update_state(value) 
mean_tracker.result()

The following training step function combines the forward pass, backward pass, and metrics tracking into a `fit()`-like step that takes a batch of data and targets and returns the logs that would get displayed by the `fit()` progress bar:

In [None]:
model = get_mnist_model()

loss_fn = keras.losses.SparseCategoricalCrossentropy()
optimizer = keras.optimizers.RMSprop()
metrics = [keras.metrics.SparseCategoricalAccuracy()]
loss_tracking_metric = keras.metrics.Mean()

@tf.function
def train_step(inputs, targets):
    with tf.GradientTape() as tape:
        predictions = model(inputs, training=True)
        loss = loss_fn(targets, predictions)
    gradients = tape.gradient(loss, model.trainable_weights)
    optimizer.apply_gradients(zip(gradients, model.trainable_weights))

    logs = {}
    for metric in metrics:
        metric.update_state(targets, predictions)
        logs[metric.name] = metric.result()

    loss_tracking_metric.update_state(loss)
    logs["loss"] = loss_tracking_metric.result()
    return logs

4. We also need to use `metric.reset_state()` when we want to reset the current results (at the start of a training epoch or at the start of evaluation). 

In [None]:
def reset_metrics():
    for metric in metrics:
        metric.reset_state()
    loss_tracking_metric.reset_state()
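
To see why resetting matters, here is a small sketch with `keras.metrics.Mean` (the values are arbitrary): without `reset_state()`, values from earlier epochs keep counting towards the result.

```python
from tensorflow import keras

tracker = keras.metrics.Mean()
tracker.update_state([1.0, 2.0, 3.0])   # First "epoch"
first_epoch = float(tracker.result())   # 2.0

tracker.reset_state()                   # Start of the next epoch
tracker.update_state([10.0])
second_epoch = float(tracker.result())  # 10.0: earlier values no longer count

stale = keras.metrics.Mean()            # Same updates, but never reset
stale.update_state([1.0, 2.0, 3.0])
stale.update_state([10.0])
stale_result = float(stale.result())    # 4.0: mean of all four values
```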

Now we can lay out our complete training loop. We use a `tf.data.Dataset` object to turn our NumPy data into an iterator that iterates over the data in batches of size 32.

In [None]:
training_dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels))
training_dataset = training_dataset.batch(32)
epochs = 2
for epoch in range(epochs):
    reset_metrics()
    for inputs_batch, targets_batch in training_dataset:
        logs = train_step(inputs_batch, targets_batch)
    print(f"Results at the end of epoch {epoch}")
    for key, value in logs.items():
        print(f"...{key}: {value:.4f}")
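
With 50,000 training images and a batch size of 32, the last batch of each epoch holds only 16 samples (50,000 = 1,562 × 32 + 16). A small sketch on dummy data (70 samples, chosen for illustration) shows how `batch()` handles the remainder, and how `drop_remainder=True` discards it instead.

```python
import numpy as np
import tensorflow as tf

# Dummy data, for illustration only
features = np.zeros((70, 4), dtype="float32")
labels = np.arange(70)

dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(32)
batch_sizes = [int(x.shape[0]) for x, _ in dataset]
print(batch_sizes)  # [32, 32, 6]: the last batch is smaller

dropped = tf.data.Dataset.from_tensor_slices((features, labels)).batch(32, drop_remainder=True)
dropped_sizes = [int(x.shape[0]) for x, _ in dropped]
print(dropped_sizes)  # [32, 32]: the remainder is discarded
```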

The evaluation loop is a simple `for` loop that repeatedly calls a `test_step()` function, which processes a single batch of data. The `test_step()` function is a subset of the logic of `train_step()` since it omits the code that deals with updating the weights of the model:

In [None]:
import time

st = time.time()
def test_step(inputs, targets):
    predictions = model(inputs, training=False)
    loss = loss_fn(targets, predictions)

    logs = {}
    for metric in metrics:
        metric.update_state(targets, predictions)
        logs["val_" + metric.name] = metric.result()

    loss_tracking_metric.update_state(loss)
    logs["val_loss"] = loss_tracking_metric.result()
    return logs

val_dataset = tf.data.Dataset.from_tensor_slices((val_images, val_labels))
val_dataset = val_dataset.batch(32)
reset_metrics()
for inputs_batch, targets_batch in val_dataset:
    logs = test_step(inputs_batch, targets_batch)
print("Evaluation results:")
for key, value in logs.items():
    print(f"...{key}: {value:.4f}")
print("Time:",time.time()-st)

The evaluation loop above runs significantly slower than the built-in `evaluate()`, despite implementing essentially the same logic, because `test_step()` is not compiled (unlike `train_step()`, which we decorated with `@tf.function`). By default, TensorFlow code is executed line by line, eagerly, much like regular Python code. It’s more performant to compile our TensorFlow code into a computation graph that can be globally optimized. The syntax for this is simple: we add the `@tf.function` decorator to any function we want to compile before executing it:

In [None]:
st = time.time()
@tf.function
def test_step(inputs, targets):
    predictions = model(inputs, training=False)
    loss = loss_fn(targets, predictions)

    logs = {}
    for metric in metrics:
        metric.update_state(targets, predictions)
        logs["val_" + metric.name] = metric.result()

    loss_tracking_metric.update_state(loss)
    logs["val_loss"] = loss_tracking_metric.result()
    return logs

val_dataset = tf.data.Dataset.from_tensor_slices((val_images, val_labels))
val_dataset = val_dataset.batch(32)
reset_metrics()
for inputs_batch, targets_batch in val_dataset:
    logs = test_step(inputs_batch, targets_batch)
print("Evaluation results:")
for key, value in logs.items():
    print(f"...{key}: {value:.4f}")
print('Time:',time.time()-st)
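
A minimal sketch of what the decorator does, with a hypothetical function `scaled_sum`: the Python body runs only while TensorFlow traces it into a graph, and later calls with the same input shape and dtype reuse that graph. A Python side effect such as incrementing a counter therefore happens only during tracing.

```python
import tensorflow as tf

trace_count = 0

@tf.function
def scaled_sum(x):
    global trace_count
    trace_count += 1               # Python side effect: runs only while tracing
    return tf.reduce_sum(x) * 2.0  # Recorded into the graph

results = [float(scaled_sum(tf.ones((4,)))) for _ in range(3)]
print(results)      # [8.0, 8.0, 8.0]
print(trace_count)  # 1: the graph traced on the first call is reused
```

One consequence for the timings above: the first call to a `tf.function` includes the one-off tracing cost, so a short benchmark slightly understates the speedup.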

### Leveraging fit() with a custom training loop
This approach is a middle ground between `fit()` and a training loop written from scratch: we provide a custom training step and let the framework do the rest, by overriding the `train_step()` method of the `Model` class. `train_step()` is the method that `fit()` calls for every batch of data.

In the following example:
- We create a new class that subclasses `keras.Model`.
- We override the method `train_step(self, data)`. It returns a dictionary mapping metric names (including the loss) to their current values.
- We implement a `metrics` property that tracks the model’s Metric instances. This enables the model to automatically call `reset_state()` on the model’s metrics at the start of each epoch and at the start of a call to `evaluate()`, so we don’t have to do it by hand.

In [None]:
loss_fn = keras.losses.SparseCategoricalCrossentropy()
loss_tracker = keras.metrics.Mean(name="loss")

class CustomModel(keras.Model):
    def train_step(self, data):
        inputs, targets = data
        with tf.GradientTape() as tape:
            predictions = self(inputs, training=True)
            loss = loss_fn(targets, predictions)
        # Use the model's own weights and the optimizer passed to compile()
        gradients = tape.gradient(loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_weights))

        loss_tracker.update_state(loss)
        return {"loss": loss_tracker.result()}

    @property
    def metrics(self):
        return [loss_tracker]