Notebook Overview
- Start Execution
- Install and Import Libraries
- Configure Settings
- Logging Model to MLflow
- Fetching the Latest Model Version from MLflow
- Loading the Model and Running Inference

## Start Execution

In [None]:
import logging
import time

# Configure logger
logger: logging.Logger = logging.getLogger("register_model_logger")
logger.setLevel(logging.INFO)
logger.propagate = False  # Prevent duplicate logs from parent loggers

# logging.getLogger returns the same object for a given name for the whole
# kernel session, so re-running this cell must not stack another handler on
# it: every extra handler would print each message one more time.
if not logger.handlers:
    # Set formatter
    formatter: logging.Formatter = logging.Formatter(
        fmt="%(asctime)s - %(levelname)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    # Configure and attach stream handler
    stream_handler: logging.StreamHandler = logging.StreamHandler()
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)

In [None]:
# Wall-clock timestamp taken before any work is done; the final cell subtracts
# it from its own time.time() reading to report the total execution time.
start_time: float = time.time()  

logger.info("Notebook execution started.")

## Install and Import Libraries

In [None]:
from PIL import Image
import base64
from io import BytesIO
import warnings                         
import logging  


import numpy as np
import pandas as pd  

import matplotlib.pyplot as plt


from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPool2D, Flatten, Dense
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.datasets import mnist

# ------------------------ MLflow Integration ------------------------
import mlflow
from mlflow import MlflowClient
from mlflow.models.signature import ModelSignature
from mlflow.types.schema import Schema, ColSpec

## Configure Settings

In [None]:
# ------------------------- MLflow Experiment Configuration -------------------------
# Names passed to MLflow for the experiment, the run and the registered model.
EXPERIMENT_NAME: str = "MNIST with TensorFlow"
RUN_NAME: str = "MNIST_Run"
MODEL_NAME: str = "MNIST_Model"  # also used as the artifact path when the model is logged

# File name for a saved Keras model (not referenced by the other cells visible in this notebook).
MODEL_PATH: str = "model_keras_mnist.keras"

In [None]:
# Suppress Python warnings
warnings.filterwarnings("ignore")

## Logging Model to MLflow

In [None]:
def base64_to_numpy(base64_string):
    """
    Decode a base64-encoded image into a normalized array shaped for the CNN.

    The image is converted to 8-bit grayscale, resized to 28x28 pixels and
    scaled to the [0, 1] range, then given a leading batch axis and a trailing
    channel axis.

    Args:
        base64_string (str | bytes): Base64 image payload, optionally prefixed
            with a data-URI header such as ``data:image/png;base64,``.

    Returns:
        numpy.ndarray: ``float32`` array of shape ``(1, 28, 28, 1)``.

    Raises:
        TypeError: If ``base64_string`` is neither ``str`` nor ``bytes``.

    Errors raised by ``base64.b64decode`` (malformed payload) or by
    ``PIL.Image.open`` (bytes that are not a readable image) propagate
    unchanged.
    """
    if isinstance(base64_string, (bytes, bytearray)):
        base64_string = bytes(base64_string).decode("ascii")
    if not isinstance(base64_string, str):
        raise TypeError(
            f"base64_string must be str or bytes, got {type(base64_string).__name__}"
        )

    # Drop a data-URI header if one is present. The base64 alphabet contains
    # no comma, so everything up to the first comma is header.
    if "," in base64_string:
        base64_string = base64_string.split(",", 1)[1]
    image_data = base64.b64decode(base64_string)

    # convert() and resize() both return new images, so the source image can
    # be released as soon as the resized copy exists.
    with Image.open(BytesIO(image_data)) as source_image:
        grayscale = source_image if source_image.mode == 'L' else source_image.convert('L')
        resized = grayscale.resize((28, 28))

    # Same preprocessing as the training data: float32 pixels scaled to [0, 1].
    array = np.asarray(resized, dtype="float32") / 255.0

    return array.reshape(1, 28, 28, 1)

In [None]:
# Load MNIST, then bring the images into the (N, 28, 28, 1) float32 layout
# scaled to [0, 1] and one-hot encode the ten digit labels.
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.astype("float32").reshape(-1, 28, 28, 1) / 255.0
x_test = x_test.astype("float32").reshape(-1, 28, 28, 1) / 255.0
y_cat_train = to_categorical(y_train, num_classes=10)
y_cat_test = to_categorical(y_test, num_classes=10)

# Small CNN: one convolution + max-pooling stage feeding a dense classifier
# head with a 10-way softmax output.
model = Sequential([
    Conv2D(32, kernel_size=(4, 4), activation='relu', input_shape=(28, 28, 1)),
    MaxPool2D(pool_size=(2, 2)),
    Flatten(),
    Dense(128, activation='relu'),
    Dense(10, activation='softmax'),
])

model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

# Train for four epochs; the test split doubles as the validation set.
model.fit(
    x=x_train,
    y=y_cat_train,
    epochs=4,
    validation_data=(x_test, y_cat_test),
)

In [None]:
class MNISTModel(mlflow.pyfunc.PythonModel):
    """
    MLflow pyfunc wrapper that classifies base64-encoded digit images with a
    Keras model.
    """

    def __init__(self, keras_model):
        """
        Store the Keras model used for inference.

        Args:
            keras_model: Trained model exposing ``predict``; it is kept on the
                instance as ``self.model``.
        """
        self.model = keras_model

    def predict(self, context, model_input, params = None):
        """
        Predict the digit shown in each base64-encoded image.

        Args:
            context: MLflow model context (not used).
            model_input: A ``pandas.DataFrame`` whose first column holds
                base64 strings, a list/tuple of base64 strings, or a single
                value that is converted with ``str``.
            params: Optional inference parameters (not used).

        Returns:
            numpy.ndarray: Predicted class index for every input image, in
            input order.

        Raises:
            ValueError: If ``model_input`` contains no images.
        """
        try:
            if isinstance(model_input, pd.DataFrame):
                encoded_images = model_input.iloc[:, 0].tolist()
            elif isinstance(model_input, (list, tuple)):
                encoded_images = list(model_input)
            else:
                encoded_images = [str(model_input)]

            if not encoded_images:
                raise ValueError("model_input contains no images to classify")

            # Every decoded image has shape (1, 28, 28, 1); joining them along
            # axis 0 yields one batch, so no input row is silently dropped.
            image_batch = np.concatenate(
                [base64_to_numpy(encoded) for encoded in encoded_images],
                axis=0,
            )

            predictions = self.model.predict(image_batch)

            predicted_classes = np.argmax(predictions, axis = 1)

            return predicted_classes

        except Exception as e:
            logger.error(f"Error performing prediction: {str(e)}")
            raise

    @classmethod
    def log_model(cls, model_name, keras_model = None):
        """
        Log the wrapped model through ``mlflow.pyfunc.log_model`` with its
        input/output signature.

        Args:
            model_name: Artifact path under which the model is logged.
            keras_model: Trained model to wrap. When omitted, the
                notebook-level variable ``model`` is used.

        Returns:
            The value returned by ``mlflow.pyfunc.log_model``.
        """
        try:
            if keras_model is None:
                # Backward-compatible default: fall back to the notebook's
                # module-level Keras model.
                keras_model = model

            # One string column in (the base64 image), one integer column out.
            input_schema = Schema([
                ColSpec("string", name = "digit"),
            ])
            output_schema = Schema([
                ColSpec("long", name = "prediction"),
            ])

            # Define model signature
            signature = ModelSignature(inputs=input_schema, outputs=output_schema)

            # Log the model in MLflow
            return mlflow.pyfunc.log_model(
                artifact_path=model_name,
                python_model=cls(keras_model),
                signature=signature,
            )

        except Exception as e:
            logger.error(f"Error logging model: {str(e)}")
            raise

In [None]:
logger.info(f'Starting the experiment: {EXPERIMENT_NAME}')

# NOTE(review): the tracking location is hard-coded; presumably it is provided
# by the hosting environment - confirm before running the notebook elsewhere.
mlflow.set_tracking_uri('/phoenix/mlflow')
# Set the MLflow experiment name
mlflow.set_experiment(experiment_name=EXPERIMENT_NAME)

# Start an MLflow run
with mlflow.start_run(run_name=RUN_NAME) as run:
    # Use the notebook logger: the root logger (logging.info) has no INFO-level
    # handler configured here, so the message would not be shown.
    logger.info(f"Run's Artifact URI: {run.info.artifact_uri}")

    # Evaluate on the held-out test split and record both metrics in one call.
    test_loss, test_accuracy = model.evaluate(x_test, y_cat_test, verbose=0)
    mlflow.log_metrics({"test_accuracy": test_accuracy, "test_loss": test_loss})

    # Log the model to MLflow
    MNISTModel.log_model(model_name=MODEL_NAME)

    # Register the logged model in MLflow Model Registry
    mlflow.register_model(
        model_uri=f"runs:/{run.info.run_id}/{MODEL_NAME}",
        name=MODEL_NAME
    )

logger.info(f'Registered the model: {MODEL_NAME}')

## Fetching the Latest Model Version from MLflow

In [None]:
# Initialize the MLflow client
client = MlflowClient()

# Collect every registered version of the model. search_model_versions is used
# instead of get_latest_versions(..., stages=[...]), which MLflow has
# deprecated together with model registry stages.
model_metadata = client.search_model_versions(f"name='{MODEL_NAME}'")
if not model_metadata:
    raise RuntimeError(f"No registered versions found for model '{MODEL_NAME}'")

# Compare version identifiers as integers so that e.g. version 10 ranks above 9.
latest_model_version = max(model_metadata, key=lambda mv: int(mv.version)).version

# Fetch model information, including its signature
model_info = mlflow.models.get_model_info(f"models:/{MODEL_NAME}/{latest_model_version}")

# Log the latest model version and its signature
logger.info(f"Latest Model Version: {latest_model_version}")
logger.info(f"Model Signature: {model_info.signature}")

## Loading the Model and Running Inference

In [None]:
# Keep the registry-loaded pyfunc wrapper under its own name. Rebinding `model`
# would shadow the trained Keras model and break a re-run of the cells that
# evaluate and log it.
loaded_model = mlflow.pyfunc.load_model(model_uri=f"models:/{MODEL_NAME}/{latest_model_version}")

# Base64 example
base = "/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/wAALCAAcABwBAREA/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/9oACAEBAAA/APn+prW0uL66itbSCSe4lbbHFEpZmPoAOtaWt+FtZ8Ox28mqWghS4LCNlmSQblxuU7GO1huGQcHnpWPRXoOiWF/pfhiwh0K2ln8R+JfMWNoh89vaK2w7f7pdg2Wzwq9sk1X+IY03SItJ8JabMLk6QsjX1wpOJLuQr5gHsuxQP/rVw1Fen+EfFmueF/AN3qkup3C2yMbPR7QkBXmbJkk9SsYOcfd3MK8yd3lkaSRizsSzMxyST1JptFXLnVb280+ysJ5y9rYhxbx4AEe9tzdBzk9z7elU6K//2Q=="
numpy_image = base64_to_numpy(base)

# Show the preprocessed 28x28 image that the model will actually receive
plt.imshow(numpy_image.squeeze(), cmap= 'gray')
plt.title("Preprocessed input image")
plt.axis("off")
plt.show()

# The DataFrame column name matches the "digit" field of the model signature
base_input = pd.DataFrame({"digit": [base]})
# Prediction of base64
predictions = loaded_model.predict(base_input)

logger.info(f"Predicted class: {predictions}")

In [None]:
# Split the wall-clock time elapsed since `start_time` into whole minutes and
# the remaining seconds for the closing log lines.
end_time: float = time.time()
elapsed_time: float = end_time - start_time
minutes_part, elapsed_seconds = divmod(elapsed_time, 60)
elapsed_minutes: int = int(minutes_part)

logger.info(f"⏱️ Total execution time: {elapsed_minutes}m {elapsed_seconds:.2f}s")
logger.info("✅ Notebook execution completed successfully.")

Built with ❤️ using [**Z by HP AI Studio**](https://zdocs.datascience.hp.com/docs/aistudio/overview).