# MNIST Classification with Keras on Databricks (Serverless CPU)

This notebook trains a simple Keras model on the MNIST digits dataset on Databricks Serverless CPU, with MLflow tracking. It first tries to load the dataset from a Unity Catalog volume; if that file is missing or cannot be loaded, it falls back to the built-in Keras MNIST loader.

- Uses MLflow for experiment and run tracking, and logs the trained model
- Assumes MNIST is stored as an `mnist.npz` file in a Unity Catalog volume at `/Volumes/<catalog>/<schema>/<volume>/mnist.npz`
- Designed to be lightweight for CPU-only environments


In [0]:
# Notebook parameters, read from Databricks widgets in a fixed order:
# catalog, schema, experiment path, volume.
CATALOG, SCHEMA, EXPERIMENT_PATH, VOLUME = map(
    dbutils.widgets.get,
    ("catalog", "schema", "exp_path", "volume"),
)

In [0]:
# Databricks and environment setup
# - Set TensorFlow log level for cleaner output (before TensorFlow is imported)
# - Import core libraries
# - Seed the random number generators for reproducibility
# - Configure MLflow experiment location

import os
import time
from typing import Tuple

# Reduce TF verbosity for CPU runs.
# This is set before the TensorFlow import on purpose: the native runtime picks
# the variable up once it starts logging, so setting it after
# `import tensorflow` can be too late to silence import-time messages.
# setdefault keeps any level the environment has already chosen.
os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "2")  # 0=all,1=info,2=warning,3=error

import mlflow  # noqa: E402  (imports intentionally follow the env-var setup)
import tensorflow as tf  # noqa: E402
from tensorflow import keras  # noqa: E402
from tensorflow.keras import layers  # noqa: E402

# Reproducibility
SEED = 2025
keras.utils.set_random_seed(SEED)

# Route all subsequent runs to the experiment chosen via the notebook widget.
mlflow.set_experiment(EXPERIMENT_PATH)

print(f"TF version: {tf.__version__}")
print(f"Keras version: {keras.__version__}")
print(f"MLflow experiment: {EXPERIMENT_PATH}")
print(f"MLflow tracking URI: {mlflow.get_tracking_uri()}")


In [0]:
# Data loading from Unity Catalog Volume with fallback to Keras loader
# - Expected file: /Volumes/<catalog>/<schema>/<volume>/mnist.npz
# - Any failure to load from the UC path is printed, then the Keras loader is used instead

import numpy as np

# Configure the UC Volume path pieces.
# Precedence for CATALOG / SCHEMA / VOLUME: a non-empty environment variable of
# the same name, then the value already read from the notebook widgets, then a
# built-in default. Falling back to the widget value (rather than straight to
# the default) keeps the widget selection from being silently discarded.
CATALOG = os.environ.get("CATALOG") or CATALOG or "main"
SCHEMA = os.environ.get("SCHEMA") or SCHEMA or "default"
VOLUME = os.environ.get("VOLUME") or VOLUME or "datasets"
# The file name has no widget; it comes from the environment or the default.
UC_FILE = os.environ.get("UC_FILE") or "mnist.npz"

UC_DATA_PATH = f"/Volumes/{CATALOG}/{SCHEMA}/{VOLUME}/{UC_FILE}"
print(f"Attempting to load MNIST from: {UC_DATA_PATH}")

def _load_from_uc_npz(path: str):
    """Read the MNIST train/test arrays from the ``.npz`` archive at *path*.

    Args:
        path: Filesystem path of the archive. It must contain the entries
            ``x_train``, ``y_train``, ``x_test`` and ``y_test``.

    Returns:
        ``((x_train, y_train), (x_test, y_test))``.

    Raises:
        FileNotFoundError: If nothing exists at *path*.
    """
    if os.path.exists(path):
        # Pull the arrays out while the archive is still open; they remain
        # usable after the context manager closes the file.
        with np.load(path) as archive:
            train_pair = (archive["x_train"], archive["y_train"])
            test_pair = (archive["x_test"], archive["y_test"])
        return train_pair, test_pair

    raise FileNotFoundError(f"Unity Catalog file not found at {path}")


# Prefer the Unity Catalog copy. Any failure is reported and the built-in
# Keras loader is used instead, so the notebook still runs without the volume.
try:
    (x_train, y_train), (x_test, y_test) = _load_from_uc_npz(UC_DATA_PATH)
except Exception as load_error:
    print(f"UC load failed: {load_error}. Falling back to Keras loader...")
    (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
    source = "keras_builtin"
else:
    source = "unity_catalog"

print(f"Loaded MNIST from: {source}")
print(f"Train shape: {x_train.shape}, Test shape: {x_test.shape}")

# Per split: cast to float32, scale pixel values by 1/255, then flatten each
# 28x28 image into a 784-element vector for the MLP.
x_train = (x_train.astype("float32") / 255.0).reshape((-1, 28 * 28))
x_test = (x_test.astype("float32") / 255.0).reshape((-1, 28 * 28))

num_classes = 10

# One-hot encode the integer labels for categorical crossentropy.
# The raw y_train / y_test arrays are left untouched.
y_train_categorical, y_test_categorical = (
    keras.utils.to_categorical(labels, num_classes)
    for labels in (y_train, y_test)
)


In [0]:
# Build a simple MLP model for MNIST classification
# - Two hidden ReLU layers (the second half the width of the first), each
#   followed by dropout, and a softmax output over the digit classes

HIDDEN_UNITS, DROPOUT_RATE, LEARNING_RATE = 256, 0.2, 1e-3

# Layers are created in the same order they are stacked.
_mlp_stack = [
    layers.Input(shape=(28 * 28,)),
    layers.Dense(units=HIDDEN_UNITS, activation="relu"),
    layers.Dropout(rate=DROPOUT_RATE),
    layers.Dense(units=HIDDEN_UNITS // 2, activation="relu"),
    layers.Dropout(rate=DROPOUT_RATE),
    layers.Dense(units=num_classes, activation="softmax"),
]
model = keras.Sequential(_mlp_stack)

# Labels are one-hot encoded, hence categorical (not sparse) crossentropy.
_adam = keras.optimizers.Adam(learning_rate=LEARNING_RATE)
model.compile(optimizer=_adam, loss="categorical_crossentropy", metrics=["accuracy"])

model.summary()


In [0]:
# Train model with MLflow autologging
# - Autolog tracks params and metrics automatically
# - The Keras model itself is logged once, explicitly, with a signature

# log_models=False: the model is logged explicitly at the end of the run with an
# inferred signature, so letting autolog log it as well would store it twice.
mlflow.tensorflow.autolog(log_models=False)

BATCH_SIZE = 128
EPOCHS = 5  # Keep small for CPU demo
VALIDATION_SPLIT = 0.1

# Timestamp suffix keeps run names distinct across repeated executions.
run_name = f"mnist-keras-mlp-cpu-{int(time.time())}"

with mlflow.start_run(run_name=run_name):
    # Manual params alongside autolog: hyperparameters plus the data source and
    # seed, which are only known to this notebook.
    mlflow.log_params({
        "hidden_units": HIDDEN_UNITS,
        "dropout_rate": DROPOUT_RATE,
        "learning_rate": LEARNING_RATE,
        "batch_size": BATCH_SIZE,
        "epochs": EPOCHS,
        "validation_split": VALIDATION_SPLIT,
        "data_source": source,
        "seed": SEED,
    })

    # The last VALIDATION_SPLIT fraction of the training data is held out by
    # Keras for validation metrics.
    history = model.fit(
        x_train,
        y_train_categorical,
        validation_split=VALIDATION_SPLIT,
        batch_size=BATCH_SIZE,
        epochs=EPOCHS,
        verbose=2,
    )

    # Explicitly log the model with a signature to ease serving.
    # The signature is inferred from a small sample of inputs and the model's
    # predictions on that same sample.
    input_signature = mlflow.models.signature.infer_signature(
        x_train[:100],
        model.predict(x_train[:100], verbose=0),
    )
    mlflow.keras.log_model(
        model,
        artifact_path="model",
        signature=input_signature,
        registered_model_name=None,  # set if you want to register in Model Registry
    )

print("Training completed.")


In [0]:
# Evaluate on the test set and log final metrics
# - Metrics and a small prediction sample are logged to a dedicated evaluation run

import tempfile

import pandas as pd

with mlflow.start_run(run_name=f"mnist-eval-{int(time.time())}"):
    test_loss, test_acc = model.evaluate(x_test, y_test_categorical, verbose=0)
    mlflow.log_metrics({"test_loss": float(test_loss), "test_accuracy": float(test_acc)})

    # Log a small batch of predictions as an artifact for inspection.
    # Labels come from the raw (non one-hot) y_test so they line up with argmax.
    sample_images = x_test[:16]
    sample_labels = y_test[:16]
    preds = model.predict(sample_images, verbose=0)
    pred_labels = np.argmax(preds, axis=1)

    # Create a compact CSV for quick checking
    df_pred = pd.DataFrame({
        "label": sample_labels,
        "pred": pred_labels,
    })

    # Write to a private temporary directory instead of a fixed /tmp path so
    # concurrent runs cannot overwrite each other's file; the directory is
    # removed once the artifact has been uploaded. The file name is unchanged,
    # so the artifact still appears as eval/mnist_sample_preds.csv.
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_csv_path = os.path.join(tmp_dir, "mnist_sample_preds.csv")
        df_pred.to_csv(tmp_csv_path, index=False)
        mlflow.log_artifact(tmp_csv_path, artifact_path="eval")

print({"test_accuracy": float(test_acc), "test_loss": float(test_loss)})
