# CIFAR-10 Image Classification with TensorFlow and MLflow

In [9]:
import mlflow
import mlflow.tensorflow
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import itertools
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.utils import to_categorical

In [10]:
# Load CIFAR-10, then prepare images and labels for training
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

# Collapse the label arrays to 1-D so they can be used directly as class indices
y_train, y_test = y_train.flatten(), y_test.flatten()

# Convert images to float32 and rescale by 1/255
x_train, x_test = (
    images.astype('float32') / 255.0 for images in (x_train, x_test)
)

# One-hot encode the integer labels for the categorical cross-entropy loss
num_classes = 10
y_train_cat, y_test_cat = (
    to_categorical(labels, num_classes) for labels in (y_train, y_test)
)

# Hold out the leading `val_split` fraction of the training set for validation;
# the remainder is what the model is actually fitted on.
val_split = 0.2
num_val = int(len(x_train) * val_split)
val_slice, train_slice = slice(None, num_val), slice(num_val, None)

x_val, y_val_cat, y_val = x_train[val_slice], y_train_cat[val_slice], y_train[val_slice]
x_train_sub, y_train_sub = x_train[train_slice], y_train_cat[train_slice]

In [11]:
# Define a simple CNN model factory
def create_model(dropout_rate=0.5, input_shape=(32, 32, 3)):
    """Build a small, uncompiled CNN classifier.

    Architecture: three 3x3 ReLU convolutions (32, 64, 64 filters) with 2x2
    max-pooling after the first two, then Flatten -> Dropout -> Dense(64, ReLU)
    -> Dense(num_classes, softmax).

    Args:
        dropout_rate: Fraction of units dropped by the Dropout layer that sits
            between the flattened convolutional features and the dense head.
        input_shape: Shape of a single input image as (height, width, channels).
            Defaults to the 32x32 RGB shape used by the rest of the notebook.

    Returns:
        A `Sequential` model. It is not compiled; the caller picks the
        optimizer, loss, and metrics.

    Note:
        The size of the output layer is read from the notebook-level
        `num_classes` variable, which must be defined before calling this.
    """
    return models.Sequential([
        # Declare the input with an explicit Input layer rather than passing
        # `input_shape=` to the first Conv2D; recent Keras versions warn about
        # the latter when it is used inside a Sequential model.
        layers.Input(shape=input_shape),
        layers.Conv2D(32, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.Flatten(),
        layers.Dropout(dropout_rate),
        layers.Dense(64, activation='relu'),
        layers.Dense(num_classes, activation='softmax'),
    ])

In [12]:
# Utility: plot and save confusion matrix
def plot_confusion_matrix(cm, classes, normalize=False,
                          title='Confusion matrix', cmap=plt.cm.Blues):
    """Draw a confusion matrix as an annotated heat map on a new figure.

    A new 6x6-inch figure is created and left open as the current pyplot
    figure, so the caller can follow up with `plt.savefig(...)`,
    `plt.show()` or `plt.close()`.

    Args:
        cm: Square 2-D array of counts, rows = true labels, columns =
            predicted labels. Must hold integers when `normalize` is False,
            because cells are then rendered with the 'd' format.
        classes: Sequence of tick labels, one per row/column of `cm`.
        normalize: If True, divide each row by its sum so that cells show the
            fraction of each true class. Rows whose sum is zero (a class with
            no true samples) are shown as zeros instead of NaN.
        title: Title placed above the plot.
        cmap: Matplotlib colormap used for the heat map.

    Returns:
        None.
    """
    cm = np.asarray(cm)
    if normalize:
        row_totals = cm.sum(axis=1, keepdims=True)
        # A class that never occurs among the true labels has a zero row sum.
        # A plain division would turn that row (and the text-colour threshold
        # below) into NaN, so only divide where the total is non-zero.
        cm = np.divide(cm, row_totals,
                       out=np.zeros(cm.shape, dtype=float),
                       where=row_totals != 0)

    fig, ax = plt.subplots(figsize=(6, 6))
    image = ax.imshow(cm, interpolation='nearest', cmap=cmap)
    ax.set_title(title)
    fig.colorbar(image, ax=ax)

    tick_marks = np.arange(len(classes))
    ax.set_xticks(tick_marks)
    ax.set_xticklabels(classes, rotation=45)
    ax.set_yticks(tick_marks)
    ax.set_yticklabels(classes)

    fmt = '.2f' if normalize else 'd'
    # Cells darker than half of the maximum get white text so the number
    # stays readable against the colormap.
    thresh = cm.max() / 2.0
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        ax.text(j, i, format(cm[i, j], fmt),
                horizontalalignment='center',
                color='white' if cm[i, j] > thresh else 'black')

    ax.set_ylabel('True label')
    ax.set_xlabel('Predicted label')
    fig.tight_layout()

In [13]:
# Configure MLflow experiment
mlflow.set_experiment('cifar10_tensorflow_cnn')
# Turn autologging off explicitly: params, metrics, artifacts and the model are
# logged by hand in the training loop. `disable` must be passed by keyword --
# the first positional parameter of `autolog` is a different option, so
# `autolog(False)` would leave autologging switched on.
mlflow.tensorflow.autolog(disable=True)

In [6]:
# Hyperparameter tuning loop (similar to scikit-learn CV loop):
# one MLflow run per dropout rate, each logging its params, final validation
# metrics, a validation confusion matrix, test metrics and the trained model.
dropout_rates = [0.3, 0.5]
epochs = 10
batch_size = 64
seed = 42

class_names = [str(i) for i in range(num_classes)]

for dropout_rate in dropout_rates:
    # Re-seed Python, NumPy and TensorFlow before every run so that the runs
    # are repeatable and differ only in the hyperparameter under study
    # (some GPU kernels may still introduce small nondeterminism).
    tf.keras.utils.set_random_seed(seed)

    with mlflow.start_run(run_name=f"dropout_{dropout_rate}"):
        mlflow.log_params({
            "dropout_rate": dropout_rate,
            "epochs": epochs,
            "batch_size": batch_size,
            "seed": seed,
        })

        model = create_model(dropout_rate=dropout_rate)
        model.compile(
            optimizer=tf.keras.optimizers.Adam(),
            loss="categorical_crossentropy",
            metrics=["accuracy"],
        )

        history = model.fit(
            x_train_sub,
            y_train_sub,
            validation_data=(x_val, y_val_cat),
            epochs=epochs,
            batch_size=batch_size,
            verbose=2,
        )

        # Log final validation metrics (values after the last epoch)
        val_loss = history.history["val_loss"][-1]
        val_accuracy = history.history["val_accuracy"][-1]
        mlflow.log_metric("val_loss", float(val_loss))
        mlflow.log_metric("val_accuracy", float(val_accuracy))

        # Confusion matrix on validation set
        y_val_pred_prob = model.predict(x_val, verbose=0)
        y_val_pred = np.argmax(y_val_pred_prob, axis=1)
        # Pin the label set: without `labels`, a class that appears in neither
        # y_val nor the predictions would be dropped, and the matrix would no
        # longer line up with `class_names`.
        cm = confusion_matrix(y_val, y_val_pred, labels=np.arange(num_classes))

        plot_confusion_matrix(cm, classes=class_names,
                              title="Validation Confusion Matrix")
        cm_path = "confusion_matrix.png"
        plt.savefig(cm_path, bbox_inches="tight")
        plt.close()  # release the figure; one is created per run
        mlflow.log_artifact(cm_path)

        # Test metrics logged in the SAME run.
        # NOTE: these are for reporting only -- choose between dropout rates
        # using the validation metrics, not the test metrics.
        test_loss, test_accuracy = model.evaluate(x_test, y_test_cat, verbose=0)
        mlflow.log_metric("test_loss", float(test_loss))
        mlflow.log_metric("test_accuracy", float(test_accuracy))

        # Log the trained model
        mlflow.tensorflow.log_model(model, artifact_path="model")

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/10




625/625 - 11s - 18ms/step - accuracy: 0.3957 - loss: 1.6359 - val_accuracy: 0.5173 - val_loss: 1.3322
Epoch 2/10
Epoch 2/10




625/625 - 11s - 18ms/step - accuracy: 0.5349 - loss: 1.2987 - val_accuracy: 0.5902 - val_loss: 1.1776
Epoch 3/10
Epoch 3/10




625/625 - 8s - 13ms/step - accuracy: 0.5912 - loss: 1.1497 - val_accuracy: 0.6124 - val_loss: 1.1182
Epoch 4/10
Epoch 4/10




625/625 - 8s - 13ms/step - accuracy: 0.6278 - loss: 1.0581 - val_accuracy: 0.6402 - val_loss: 1.0270
Epoch 5/10
Epoch 5/10




625/625 - 8s - 13ms/step - accuracy: 0.6529 - loss: 0.9855 - val_accuracy: 0.6811 - val_loss: 0.9232
Epoch 6/10
Epoch 6/10
625/625 - 8s - 13ms/step - accuracy: 0.6758 - loss: 0.9286 - val_accuracy: 0.6652 - val_loss: 0.9536
Epoch 7/10
625/625 - 8s - 13ms/step - accuracy: 0.6758 - loss: 0.9286 - val_accuracy: 0.6652 - val_loss: 0.9536
Epoch 7/10




625/625 - 8s - 12ms/step - accuracy: 0.6910 - loss: 0.8792 - val_accuracy: 0.6930 - val_loss: 0.8923
Epoch 8/10
Epoch 8/10




625/625 - 10s - 16ms/step - accuracy: 0.7050 - loss: 0.8414 - val_accuracy: 0.7005 - val_loss: 0.8798
Epoch 9/10
Epoch 9/10




625/625 - 8s - 13ms/step - accuracy: 0.7165 - loss: 0.8045 - val_accuracy: 0.7091 - val_loss: 0.8351
Epoch 10/10
Epoch 10/10




625/625 - 7s - 12ms/step - accuracy: 0.7278 - loss: 0.7688 - val_accuracy: 0.7166 - val_loss: 0.8161


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)
  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/10




625/625 - 11s - 18ms/step - accuracy: 0.3675 - loss: 1.6961 - val_accuracy: 0.5064 - val_loss: 1.3495
Epoch 2/10
Epoch 2/10




625/625 - 10s - 16ms/step - accuracy: 0.5160 - loss: 1.3339 - val_accuracy: 0.5297 - val_loss: 1.2852
Epoch 3/10
Epoch 3/10




625/625 - 9s - 14ms/step - accuracy: 0.5680 - loss: 1.2123 - val_accuracy: 0.6160 - val_loss: 1.0845
Epoch 4/10
Epoch 4/10




625/625 - 8s - 14ms/step - accuracy: 0.5983 - loss: 1.1303 - val_accuracy: 0.6435 - val_loss: 1.0126
Epoch 5/10
Epoch 5/10




625/625 - 9s - 14ms/step - accuracy: 0.6225 - loss: 1.0653 - val_accuracy: 0.6512 - val_loss: 0.9860
Epoch 6/10
Epoch 6/10




625/625 - 10s - 15ms/step - accuracy: 0.6407 - loss: 1.0122 - val_accuracy: 0.6674 - val_loss: 0.9402
Epoch 7/10
Epoch 7/10




625/625 - 9s - 15ms/step - accuracy: 0.6562 - loss: 0.9686 - val_accuracy: 0.6679 - val_loss: 0.9355
Epoch 8/10
Epoch 8/10




625/625 - 10s - 16ms/step - accuracy: 0.6702 - loss: 0.9339 - val_accuracy: 0.6900 - val_loss: 0.8867
Epoch 9/10
Epoch 9/10




625/625 - 9s - 15ms/step - accuracy: 0.6831 - loss: 0.9034 - val_accuracy: 0.6952 - val_loss: 0.8622
Epoch 10/10
Epoch 10/10




625/625 - 9s - 14ms/step - accuracy: 0.6914 - loss: 0.8764 - val_accuracy: 0.6944 - val_loss: 0.8608




## Loading the model with the highest validation accuracy (0.7166 after the final epoch)

In [14]:
# Look the best run up instead of pinning a run ID: a hard-coded ID only exists
# in the tracking store it was created in, so the notebook would fail on a
# fresh "Restart & Run All" anywhere else.
best_runs = mlflow.search_runs(
    experiment_names=["cifar10_tensorflow_cnn"],
    filter_string="attributes.status = 'FINISHED'",
    order_by=["metrics.val_accuracy DESC"],
    max_results=1,
)
if best_runs.empty:
    raise RuntimeError(
        "No finished runs found in experiment 'cifar10_tensorflow_cnn'; "
        "run the training cell first."
    )

best_run_id = best_runs.iloc[0]["run_id"]
loaded_model = mlflow.tensorflow.load_model(f"runs:/{best_run_id}/model")

In [15]:
# Score the reloaded model on the test split and report loss / accuracy
test_scores = loaded_model.evaluate(x_test, y_test_cat, verbose=0)
test_loss, test_accuracy = test_scores
print(
    f"Test loss: {test_loss:.4f}",
    f"Test accuracy: {test_accuracy:.4f}",
    sep="\n",
)

Test loss: 0.8291
Test accuracy: 0.7114


## Local deployment and serving of the CIFAR-10 model

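We serve the best TensorFlow model locally with MLflow, then query it via HTTP. The next cell assumes the scoring server is already running on port 1235, e.g. started from a terminal with `mlflow models serve -m "runs:/<run_id>/model" --port 1235`.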

In [50]:
import json
import requests

# Send the first few test images to the locally served model's REST endpoint.
num_examples = 8
x_batch = x_test[:num_examples]          # shape: (N, 32, 32, 3)
x_batch_list = x_batch.tolist()          # nested list for JSON

host = "127.0.0.1"
port = "1235"
url = f"http://{host}:{port}/invocations"

headers = {"Content-Type": "application/json"}

payload = json.dumps({"inputs": x_batch_list})

# Always pass a timeout: without one, requests waits indefinitely if the
# scoring server is down or hung, and the cell never returns.
request_timeout_s = 30
response = requests.post(url=url, headers=headers, data=payload,
                         timeout=request_timeout_s)
print("Status code:", response.status_code)

# Show only a preview of the body; the full response is one long line of
# floats that floods the notebook output.
preview_chars = 300
body_preview = response.text[:preview_chars]
if len(response.text) > preview_chars:
    body_preview += "..."
print("Raw response:", body_preview)

# Stop here on an HTTP error (the preview above shows the server's message)
# rather than failing later while parsing the predictions.
response.raise_for_status()

Status code: 200
Raw response: {"predictions": [[0.007945500314235687, 0.0005136856925673783, 0.062369707971811295, 0.6200320720672607, 0.011319908313453197, 0.25463297963142395, 0.02709686942398548, 0.003119528293609619, 0.005819974932819605, 0.007149782031774521], [0.0822228416800499, 0.061189915984869, 2.3025691916700453e-05, 2.3491569663747214e-05, 8.562147741031367e-06, 1.6617105984551017e-06, 6.0922357079107314e-06, 6.451009085139958e-06, 0.845633327960968, 0.010884624905884266], [0.14168456196784973, 0.06660491973161697, 0.005248929373919964, 0.013661198318004608, 0.005727095995098352, 0.0024737876374274492, 0.008335418067872524, 0.016216203570365906, 0.21332769095897675, 0.5267202258110046], [0.8414548635482788, 0.016507118940353394, 0.0017927006119862199, 0.0013356013223528862, 0.0018879343988373876, 7.168805313995108e-05, 0.000879703089594841, 4.183408236713149e-06, 0.13457100093364716, 0.0014952534111216664], [2.352754381718114e-05, 1.6026061757656862e-06, 0.0136080170050263

In [51]:
import json
import numpy as np

# Parse the scoring server's JSON body and pull out the per-class probabilities
response_body = json.loads(response.text)
probs = np.array(response_body["predictions"])   # one row per example, one column per class

# The predicted class is the column with the highest probability in each row
preds = probs.argmax(axis=1)

print("Predicted classes:", preds)
print("True labels:", y_test[:len(preds)])

Predicted classes: [3 8 9 0 4 6 1 6]
True labels: [3 8 8 0 6 6 1 6]
