# TensorFlow + MLflow

In this guide, we show how to train a TensorFlow model and track the training with MLflow.

We train a simple image classifier on the handwritten digits of the [MNIST dataset](https://en.wikipedia.org/wiki/MNIST_database).

```python
import mlflow
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow import keras
```

## Load the dataset

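```python
# Load the MNIST train and test splits.
# shuffle_files only shuffles the order of the underlying files, not individual examples.
train_ds, test_ds = tfds.load(
    "mnist",
    split=["train", "test"],
    shuffle_files=True,
)
```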

Preprocess the data with the following steps:
* Scale each pixel's value from the integer range 0-255 to the float range [0, 1].
* Batch the dataset into groups of 128 examples.
* Prefetch batches so that input loading overlaps with training.
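The first two steps can be sketched in plain Python, without TensorFlow. `scale_pixel` and `batch` are hypothetical helpers that only mimic what `preprocess_fn` and `Dataset.batch` do; the sketch assumes the standard 60,000-image MNIST training split. It shows that dividing by 255 maps the brightest pixel to exactly 1.0, and why the training log later reports 469 steps per epoch.

```python
import math


def scale_pixel(value: int) -> float:
    """Map an 8-bit pixel value (0-255) to a float, mirroring `/ 255` in preprocess_fn."""
    return value / 255


def batch(items, batch_size):
    """Split a sequence into consecutive batches; the last batch may be smaller,
    like tf.data's batch() with its default drop_remainder=False."""
    return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]


# The darkest and brightest pixels land on both ends of the closed interval [0, 1].
print(scale_pixel(0), scale_pixel(255))  # 0.0 1.0

# 60,000 training images in batches of 128: 468 full batches plus one batch of 96.
train_batches = batch(list(range(60_000)), 128)
print(len(train_batches), len(train_batches[-1]))  # 469 96
print(math.ceil(60_000 / 128))  # 469
```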

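```python
def preprocess_fn(data):
    # Convert uint8 pixels (0-255) to float32 values in [0, 1].
    image = tf.cast(data["image"], tf.float32) / 255
    label = data["label"]
    return (image, label)


# Scale, batch 128 examples at a time, and prefetch upcoming batches in the background.
train_ds = train_ds.map(preprocess_fn).batch(128).prefetch(tf.data.AUTOTUNE)
test_ds = test_ds.map(preprocess_fn).batch(128).prefetch(tf.data.AUTOTUNE)
```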

## Define the model

```python
input_shape = (28, 28, 1)
num_classes = 10

model = keras.Sequential(
    [
        keras.Input(shape=input_shape),
        keras.layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
        keras.layers.MaxPooling2D(pool_size=(2, 2)),
        keras.layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
        keras.layers.MaxPooling2D(pool_size=(2, 2)),
        keras.layers.Flatten(),
        keras.layers.Dropout(0.5),
        keras.layers.Dense(num_classes, activation="softmax"),
    ]
)
```
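To see how an image flows through this network, the output shapes and parameter counts can be worked out by hand. The sketch below is plain Python and assumes Keras defaults for these layers: `"valid"` padding and stride 1 for `Conv2D`, and a stride equal to `pool_size` for `MaxPooling2D`. `conv2d_valid` and `maxpool` are hypothetical helpers; `model.summary()` should report the same totals.

```python
def conv2d_valid(h, w, c_in, filters, k):
    """Output shape and parameter count of a stride-1 Conv2D with 'valid' (no) padding."""
    out = (h - k + 1, w - k + 1, filters)
    params = k * k * c_in * filters + filters  # kernel weights + one bias per filter
    return out, params


def maxpool(h, w, c, p):
    """Output shape of non-overlapping p x p max pooling (floor division, no parameters)."""
    return (h // p, w // p, c)


shape, p1 = conv2d_valid(28, 28, 1, 32, 3)  # (26, 26, 32), 320 params
shape = maxpool(*shape, 2)                  # (13, 13, 32)
shape, p2 = conv2d_valid(*shape, 64, 3)     # (11, 11, 64), 18,496 params
shape = maxpool(*shape, 2)                  # (5, 5, 64)
flat = shape[0] * shape[1] * shape[2]       # 1,600 features after Flatten (Dropout adds none)
p3 = flat * 10 + 10                         # Dense(10): weights + biases = 16,010 params
total = p1 + p2 + p3
print(shape, flat, total)  # (5, 5, 64) 1600 34826
```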

```python
model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(),
    optimizer=keras.optimizers.Adam(0.001),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
```
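Because the labels are plain integers and the last layer already applies softmax, the "sparse" loss and metric compare a probability vector against a class index. A minimal per-example sketch in plain Python, with hypothetical helper names and a made-up 3-class probability vector (MNIST has 10 classes); Keras additionally averages over the batch and clips probabilities for numerical stability, which is ignored here:

```python
import math


def sparse_categorical_crossentropy(probs, label):
    """Loss for one example: negative log of the probability assigned to the true class."""
    return -math.log(probs[label])


def sparse_categorical_accuracy(probs, label):
    """1.0 if the highest-probability class equals the integer label, else 0.0."""
    predicted = max(range(len(probs)), key=lambda i: probs[i])
    return 1.0 if predicted == label else 0.0


# Hypothetical softmax output for one image.
probs = [0.1, 0.7, 0.2]
print(round(sparse_categorical_crossentropy(probs, 1), 4))  # 0.3567
print(sparse_categorical_accuracy(probs, 1))                # 1.0
print(sparse_categorical_accuracy(probs, 2))                # 0.0
```

A confident, correct prediction drives the loss toward 0, which is why the training log shows the loss falling as accuracy rises.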

## Run a local MLflow Tracking Server

Make sure MLflow is installed, then start the Tracking Server by running the command below. By default the server listens on `http://127.0.0.1:5000`, which is the tracking URI used in the next section.

```bash
mlflow server
```
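Before training, it can help to confirm from the notebook that the server is reachable. This sketch uses only the Python standard library and relies on the MLflow tracking server answering `GET /health` with HTTP 200 (recent versions return `OK`); `tracking_server_is_up` is a hypothetical helper, and any connection error or non-200 response is treated as "down".

```python
import urllib.error
import urllib.request


def tracking_server_is_up(base_url="http://127.0.0.1:5000", timeout=2.0):
    """Return True if GET <base_url>/health answers with HTTP 200, else False."""
    try:
        with urllib.request.urlopen(f"{base_url}/health", timeout=timeout) as resp:
            return resp.status == 200
    except (urllib.error.URLError, OSError):
        # Connection refused, timeout, or an HTTP error status (HTTPError is a URLError).
        return False


print(tracking_server_is_up())
```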

## Logging with MLflow

### MLflow Auto Logging

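Calling `mlflow.tensorflow.autolog()` before `model.fit()` makes MLflow create a run automatically and record the training hyperparameters, per-epoch metrics, and the trained model in it.

```python
# Point MLflow at the local Tracking Server and select (or create) the experiment.
mlflow.set_tracking_uri("http://127.0.0.1:5000")
mlflow.set_experiment("Tensorflow MNIST")

# Enable autologging before training so the fit() call below is captured.
mlflow.tensorflow.autolog()

model.fit(x=train_ds, epochs=5)
```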

```
2024/11/17 14:57:53 INFO mlflow.utils.autologging_utils: Created MLflow autologging run with ID '9699cde56ec140dbbcacb91334f64d50', which will track hyperparameters, performance metrics, model artifacts, and lineage information for the current tensorflow workflow
Epoch 1/5
469/469 ━━━━━━━━━━━━━━━━━━━━ 8s 16ms/step - loss: 0.6988 - sparse_categorical_accuracy: 0.7793
Epoch 2/5
469/469 ━━━━━━━━━━━━━━━━━━━━ 7s 15ms/step - loss: 0.1171 - sparse_categorical_accuracy: 0.9643
Epoch 3/5
469/469 ━━━━━━━━━━━━━━━━━━━━ 7s 15ms/step - loss: 0.0867 - sparse_categorical_accuracy: 0.9747
Epoch 4/5
469/469 ━━━━━━━━━━━━━━━━━━━━ 7s 16ms/step - loss: 0.0722 - sparse_categorical_accuracy: 0.9781
Epoch 5/5
469/469 ━━━━━━━━━━━━━━━━━━━━ 7s 16ms/step - loss: 0.0619 - sparse_categorical_accuracy: 0.9808
2024/11/17 14:58:36 INFO mlflow.tracking._tracking_service.client: 🏃 View run wise-mink-375 at: http://127.0.0.1:5000/#/experiments/369098936933184006/runs/9699cde56ec140dbbcacb91334f64d50.
2024/11/17 14:58:36 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: http://127.0.0.1:5000/#/experiments/369098936933184006.
```
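The "View run" link in the log follows a simple layout: the tracking URI, then `/#/experiments/<experiment_id>/runs/<run_id>`. A small sketch that rebuilds such a link, inferred only from the URLs printed above; `run_ui_url` is a hypothetical helper, not part of MLflow, and the IDs are the ones from this run.

```python
def run_ui_url(tracking_uri, experiment_id, run_id):
    """Build the MLflow UI link for a run, following the URL layout printed in the log."""
    return f"{tracking_uri.rstrip('/')}/#/experiments/{experiment_id}/runs/{run_id}"


url = run_ui_url(
    "http://127.0.0.1:5000",
    "369098936933184006",
    "9699cde56ec140dbbcacb91334f64d50",
)
print(url)
```

In a live session the IDs should not need to be copied by hand: `mlflow.last_active_run()` returns the most recent run, and its `info.experiment_id` and `info.run_id` can be passed to the helper.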