# Get Started with MLflow + Tensorflow

In this guide, we will show how to train your model with Tensorflow and log your training using MLflow.

We will use [Databricks Community Edition](https://community.cloud.databricks.com/) as our tracking server, which has built-in support for MLflow. Databricks CE is the free version of Databricks platform, if you haven't, please register an account via [link](https://www.databricks.com/try-databricks).

You can run code in this guide from cloud-based notebooks like Databricks notebook or Google Colab, or run it on your local machine.

## Install dependencies

Let's install the `mlflow` package.

```
%pip install -q mlflow
```

Then let's import the packages.

In [3]:
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow import keras

## Load the dataset

We will do a simple image classification on handwritten digits with [mnist dataset](https://en.wikipedia.org/wiki/MNIST_database).

Let's load the dataset using `tensorflow_datasets` (`tfds`), which returns datasets in the format of `tf.data.Dataset`.

In [4]:
# Load the mnist dataset.
train_ds, test_ds = tfds.load(
    "mnist",
    split=["train", "test"],
    shuffle_files=True,
)

2024-11-08 16:11:51.928475: W tensorflow/core/platform/cloud/google_auth_provider.cc:184] All attempts to get a Google authentication bearer token failed, returning an empty token. Retrieving token from files failed with "NOT_FOUND: Could not locate the credentials file.". Retrieving token from GCE failed with "FAILED_PRECONDITION: Error executing an HTTP request: libcurl code 6 meaning 'Couldn't resolve host name', error details: Could not resolve host: metadata".


[1mDownloading and preparing dataset 11.06 MiB (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /home/hedredo/tensorflow_datasets/mnist/3.0.1...[0m


  from .autonotebook import tqdm as notebook_tqdm
Dl Completed...: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 5/5 [00:01<00:00,  3.21 file/s]


[1mDataset mnist downloaded and prepared to /home/hedredo/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.[0m


2024-11-08 16:11:54.940386: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-11-08 16:11:54.941290: I tensorflow/c/logging.cc:34] DirectML: creating device on adapter 0 (AMD Radeon RX 6700 XT)
Dropped Escape call with ulEscapeCode : 0x03007703
2024-11-08 16:11:55.287774: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:306] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2024-11-08 16:11:55.287809: W tensorflow/core/common_runtime/pluggable_device/pluggable_device_bfc_allocator.cc:28] Overriding allow_growth setting because force_memory_growth was requested by the device.
2024-11-08 16:11:55.287833: I tensorflow/core/c

Let's preprocess our data with the following steps:
- Scale each pixel's value to `[0, 1)`.
- Batch the dataset.
- Use `prefetch` to speed up the training.

In [5]:
def preprocess_fn(data):
    image = tf.cast(data["image"], tf.float32) / 255
    label = data["label"]
    return (image, label)


train_ds = train_ds.map(preprocess_fn).batch(128).prefetch(tf.data.AUTOTUNE)
test_ds = test_ds.map(preprocess_fn).batch(128).prefetch(tf.data.AUTOTUNE)

## Define the Model

Let's define a convolutional neural network as our classifier. We can use `keras.Sequential` to stack up the layers.

In [6]:
input_shape = (28, 28, 1)
num_classes = 10

model = keras.Sequential(
    [
        keras.Input(shape=input_shape),
        keras.layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
        keras.layers.MaxPooling2D(pool_size=(2, 2)),
        keras.layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
        keras.layers.MaxPooling2D(pool_size=(2, 2)),
        keras.layers.Flatten(),
        keras.layers.Dropout(0.5),
        keras.layers.Dense(num_classes, activation="softmax"),
    ]
)

Set training-related configs, optimizers, loss function, metrics.

In [7]:
model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(),
    optimizer=keras.optimizers.Adam(0.001),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

## Set up tracking/visualization tool

In this tutorial, we will use Databricks CE as MLflow tracking server. For other options such as using your local MLflow server, please read the [Tracking Server Overview](https://mlflow.org/docs/latest/getting-started/tracking-server-overview/index.html).

If you have not, please register an account of [Databricks community edition](https://www.databricks.com/try-databricks#account). It should take no longer than 1min to register. Databricks CE (community edition) is a free platform for users to try out Databricks features. For this guide, we need the ML experiment dashboard for us to track our training progress.




After successfully registering an account on Databricks CE, let's connnect MLflow to Databricks CE. You will need to enter following information:
- **Databricks Host**: https://community.cloud.databricks.com/
- **Username**: your signed up email
- **Password**: your password

In [8]:
import mlflow

mlflow.set_tracking_uri("http://localhost:5000")

Now this colab is connected to the hosted tracking server. Let's configure MLflow metadata. Two things to set up:
- `mlflow.set_tracking_uri`: always use "databricks".
- `mlflow.set_experiment`: pick up a name you like, start with `/`.

## Logging with MLflow

There are two ways you can log to MLflow from your Tensorflow pipeline:
- MLflow auto logging.
- Use a callback.

Auto logging is simple to configure, but gives you less control. Using a callback is more flexible. Let's see how each way is done.

### MLflow Auto Logging

All you need to do is to call `mlflow.tensorflow.autolog()` before kicking off the training, then the backend will automatically log the metrics into the server you configured earlier. In our case, Databricks CE.

In [9]:
mlflow.create_experiment("/mlflow-tf-keras-mnist")

'324626298537831039'

In [10]:
# Choose any name that you like.
mlflow.set_experiment("/mlflow-tf-keras-mnist")

mlflow.tensorflow.autolog()

model.fit(x=train_ds, epochs=3)

2024/11/08 16:14:59 INFO mlflow.utils.autologging_utils: Created MLflow autologging run with ID 'a86127ba9ea1491f92c454df7b9eaf43', which will track hyperparameters, performance metrics, model artifacts, and lineage information for the current tensorflow workflow


Epoch 1/3


2024-11-08 16:15:00.990501: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.
2024-11-08 16:15:01.037767: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:306] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2024-11-08 16:15:01.037838: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:272] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 24700 MB memory) -> physical PluggableDevice (device: 0, name: DML, pci bus id: <undefined>)
2024-11-08 16:15:01.040417: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:306] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2024-11-08 16:15:01.040442: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_f





Epoch 2/3



Epoch 3/3







INFO:tensorflow:Assets written to: /tmp/tmp939zlfb4/model/data/model/assets


INFO:tensorflow:Assets written to: /tmp/tmp939zlfb4/model/data/model/assets
2024/11/08 16:15:14 INFO mlflow.tracking._tracking_service.client: üèÉ View run languid-panda-426 at: http://localhost:5000/#/experiments/324626298537831039/runs/a86127ba9ea1491f92c454df7b9eaf43.
2024/11/08 16:15:14 INFO mlflow.tracking._tracking_service.client: üß™ View experiment at: http://localhost:5000/#/experiments/324626298537831039.


<keras.callbacks.History at 0x7f1ea00fbbb0>

While your training is ongoing, you can find this training in your dashboard. Log in to your [Databricks CE](https://community.cloud.databricks.com/) account, and click on top left to select machine learning in the drop down list. Then click on the experiment icon. See the screenshot below:
![landing page](https://i.imgur.com/eQgnAcI.png)

After clicking the `Experiment` button, it will bring you to the experiment page, where you can find your runs. Clicking on the most recent experiment and run, you can find your metrics there, similar to:
![experiment page](https://i.imgur.com/uuHLttD.png)


You can click on metrics to see the chart.

Let's evaluate the training result.

In [11]:
score = model.evaluate(test_ds)

print(f"Test loss: {score[0]:.4f}")
print(f"Test accuracy: {score[1]: .2f}")



2024-11-08 16:16:11.637497: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.
2024-11-08 16:16:11.658134: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:306] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2024-11-08 16:16:11.658192: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:272] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 24700 MB memory) -> physical PluggableDevice (device: 0, name: DML, pci bus id: <undefined>)
2024-11-08 16:16:11.659953: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:306] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2024-11-08 16:16:11.659998: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_f

Test loss: 0.0456
Test accuracy:  0.99


### Log with MLflow Callback

Auto logging is powerful and convenient, but if you are looking for a more native way as Tensorflow pipelines, you can use `mlflow.tensorflow.MllflowCallback` inside `model.fit()`, it will log:
- Your model configuration, layers, hyperparameters and so on.
- The training stats, including losses and metrics configured with `model.compile()`.

In [12]:
from mlflow.tensorflow import MlflowCallback

# Turn off autologging.
mlflow.tensorflow.autolog(disable=True)

with mlflow.start_run() as run:
    model.fit(
        x=train_ds,
        epochs=2,
        callbacks=[MlflowCallback(run)],
    )

Epoch 1/2
Epoch 2/2


2024/11/08 16:16:58 INFO mlflow.tracking._tracking_service.client: üèÉ View run bemused-bee-341 at: http://localhost:5000/#/experiments/324626298537831039/runs/f622f62a1f0a4f76b6d120be49086195.
2024/11/08 16:16:58 INFO mlflow.tracking._tracking_service.client: üß™ View experiment at: http://localhost:5000/#/experiments/324626298537831039.


Going to the Databricks CE experiment view, you will see a similar dashboard as before.

### Customize the MLflow Callback

If you want to add extra logging logic, you can customize the MLflow callback. You can either subclass from `keras.callbacks.Callback` and write everything from scratch or subclass from `mlflow.tensorflow.MllflowCallback` to add you custom logging logic.

Let's look at an example that we want to replace the loss with its log value to log to MLflow.

In [13]:
import math


# Create our own callback by subclassing `MlflowCallback`.
class MlflowCustomCallback(MlflowCallback):
    def on_epoch_end(self, epoch, logs=None):
        if not self.log_every_epoch:
            return
        loss = logs["loss"]
        logs["log_loss"] = math.log(loss)
        del logs["loss"]
        self.metrics_logger.record_metrics(logs, epoch)

Train the model with the new callback.

In [14]:
with mlflow.start_run() as run:
    run_id = run.info.run_id
    model.fit(
        x=train_ds,
        epochs=2,
        callbacks=[MlflowCustomCallback(run)],
    )

Epoch 1/2



Epoch 2/2





2024/11/08 16:18:40 INFO mlflow.tracking._tracking_service.client: üèÉ View run skittish-cod-577 at: http://localhost:5000/#/experiments/324626298537831039/runs/943ad3a6054e45b2a60b8e88ef4d8bad.
2024/11/08 16:18:40 INFO mlflow.tracking._tracking_service.client: üß™ View experiment at: http://localhost:5000/#/experiments/324626298537831039.


Going to your Databricks CE page, you should find the `log_loss` is replacing the `loss` metric, similar to what is shown in the screenshot below.

![log loss screenshot](https://i.imgur.com/dncAwaP.png)

## Wrap up

Now you have learned the basic integration between MLflow and Tensorflow. There are a few things not covered by this quickstart, e.g., saving TF model to MLflow and loading it back. For a detailed guide, please refer to our main guide for integration between MLflow and Tensorflow.