# 06: Inference Testing

**Purpose:** This notebook tests the MLflow model we logged in the previous notebook (05). We test inference in two ways:

1.  **Batch Inference (Spark UDF):** Simulates a large-scale batch scoring job on a Spark DataFrame.
2.  **Real-time Inference (Pandas):** Simulates a single request, such as one coming from a REST API, using a Pandas DataFrame.

This notebook confirms that the logged artifact (model + image processor) is a complete, deployable pipeline.

### 1. Environment Setup

**Reason:** The inference environment must match the training environment. A mismatched library version (for example, a different `transformers` release) can fail to load the logged pipeline.

* `libaio-dev`: A system-level library required by `transformers`/`accelerate` on this cluster's OS (Ubuntu Noble). Without it, PyTorch fails to initialize.
* `transformers==4.49.0`: Pinned to the exact version the model was trained with, which is also the version `mlflow==2.21.3` is compatible with.
* `uv`: This Python installer is now required by `mlflow` to manage virtual environments for UDFs.

In [0]:
%sh
# Update package lists and install the required AIO library
sudo apt-get update && sudo apt-get install -y libaio-dev

In [0]:
# Pin transformers to the version used in training and logging; install uv for MLflow's UDF environments
%pip install transformers==4.49.0 uv

# Restart the Python kernel for the new libraries to take effect
dbutils.library.restartPython()
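After the restart, it can be worth confirming that the pin actually took effect before loading the model. Below is a minimal sketch using only the standard library's `importlib.metadata`; `EXPECTED_VERSIONS` and `find_version_mismatches` are hypothetical names, and the dict should list whatever versions the training notebook really used.

```python
from importlib.metadata import PackageNotFoundError, version

# Hypothetical pin list: adjust to the versions the training notebook actually used
EXPECTED_VERSIONS = {"transformers": "4.49.0"}


def find_version_mismatches(expected):
    """Return {package: installed_version_or_None} for every package whose
    installed version differs from its expected pin (None = not installed)."""
    mismatches = {}
    for package, pinned in expected.items():
        try:
            installed = version(package)
        except PackageNotFoundError:
            installed = None
        if installed != pinned:
            mismatches[package] = installed
    return mismatches


mismatches = find_version_mismatches(EXPECTED_VERSIONS)
if mismatches:
    print(f"Version mismatches (package -> installed): {mismatches}")
else:
    print("All pinned packages match.")
```

An empty result means every listed package is installed at its pinned version.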

In [0]:
import mlflow
import pandas as pd
import base64
import torch
from pyspark.sql.functions import col, udf
from pyspark.sql.types import StringType

# --- Configuration ---
# MLflow Run ID from notebook 05 (Manual Logging)
RUN_ID = "6ffbb8dafe5e4395b1c93d2547ad160a"
MODEL_URI = f"runs:/{RUN_ID}/model"

# Delta table created in notebook 04
DATASET_NAME = "train_dataset"

# Column metadata that lets display() render the binary 'image' column as a JPEG preview
img_meta = {
    "spark.contentAnnotation": '{"mimeType":"image/jpeg"}'
}
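Pasting a run ID by hand is easy to get slightly wrong (a truncated ID, a trailing space). Below is a small guard, sketched under the assumption that run IDs are 32-character lowercase hex strings like the one above. `make_model_uri` is a hypothetical helper, not an MLflow API.

```python
import re

# Assumption: run IDs are 32-character lowercase hex strings,
# matching the format of the RUN_ID used in this notebook.
_RUN_ID_PATTERN = re.compile(r"[0-9a-f]{32}")


def make_model_uri(run_id, artifact_path="model"):
    """Build a runs:/ URI. Strips stray whitespace from a pasted ID and raises
    ValueError on anything that is not a 32-character hex string."""
    cleaned = run_id.strip()
    if not _RUN_ID_PATTERN.fullmatch(cleaned):
        raise ValueError(f"Unexpected run ID format: {run_id!r}")
    return f"runs:/{cleaned}/{artifact_path}"
```

Used as `MODEL_URI = make_model_uri(RUN_ID)`, it yields the same URI as the f-string above, but a malformed ID fails immediately instead of surfacing later as a model-loading error.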

### 2. Define Preprocessing Functions

The logged MLflow model is a *pipeline* that includes the `image_processor`, so it expects the raw image rather than the preprocessed `pixel_values` tensor.

We define two functions that convert the raw image `bytes` into a Base64-encoded string, which the pipeline accepts:
1.  `bytes_to_base64_udf`: A Spark UDF for batch inference.
2.  `bytes_to_base64_pd`: A regular Python function for Pandas inference.

In [0]:
def bytes_to_base64_pd(image_bytes):
    """
    Python function for Pandas: Converts raw image bytes into a Base64-encoded string.
    Returns None for a missing image instead of raising a TypeError.
    """
    if image_bytes is None:
        return None
    return base64.b64encode(image_bytes).decode('utf-8')

# Spark UDF: wraps the same conversion so it can be applied to a Spark column
bytes_to_base64_udf = udf(bytes_to_base64_pd, StringType())
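A quick local round-trip shows what the pipeline will receive: an ASCII string that decodes back to exactly the original bytes. This sketch runs without Spark. `bytes_to_base64` and `fake_image` are hypothetical names, and the byte string is a stand-in (a JPEG-style header plus filler), not a real image.

```python
import base64


def bytes_to_base64(image_bytes):
    # Mirrors bytes_to_base64_pd from the cell above, plus a None guard
    if image_bytes is None:
        return None
    return base64.b64encode(image_bytes).decode("utf-8")


# Hypothetical stand-in for an image: JPEG/JFIF magic bytes followed by filler
fake_image = bytes([0xFF, 0xD8, 0xFF, 0xE0]) + b"fake-jpeg-payload"
encoded = bytes_to_base64(fake_image)
print(type(encoded).__name__, encoded)

# Decoding must give back exactly the original bytes
assert base64.b64decode(encoded) == fake_image
# Spark binary columns reach Python UDFs as bytearray; b64encode accepts that too
assert bytes_to_base64(bytearray(fake_image)) == encoded
```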

### 3. Test 1: Batch Inference with Spark UDF

Here, we test scoring on a sample of the Spark DataFrame.

In [0]:
print("Loading Spark DataFrame...")
df = spark.read.table(DATASET_NAME)

# Get 5 'noise' samples for testing
sample_df = df.filter(col('label') == 'noise').limit(5)

# 1. Preprocess the data: Convert 'image' bytes to 'image_base64' string
inference_df = sample_df.withColumn("image_base64", bytes_to_base64_udf(col("image")))

# 2. Load the model as a Spark UDF
# We must specify result_type=StringType() because the model's 'task' 
# (image-classification) makes it output a JSON-like string, not a float.
print("Loading model as Spark UDF...")
loaded_model_udf = mlflow.pyfunc.spark_udf(
    spark, 
    model_uri=MODEL_URI, 
    result_type=StringType()
)

# 3. Run prediction
# We call the UDF *only* on the 'image_base64' column.
print("Running batch prediction on 5 samples...")
predictions_df = inference_df.withColumn(
    'predictions', 
    loaded_model_udf(col("image_base64"))
)

# Attach the image metadata (withColumn returns a new DataFrame, so reassign it)
predictions_df = predictions_df.withColumn('image', col('image').alias('image', metadata=img_meta))
display(predictions_df.select("name", "image", "label", "predictions"))
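To compare the `predictions` column against `label`, each prediction string can be reduced to its top label. The exact string depends on the MLflow version and on how the model's output is serialized, so this minimal sketch (the helper `top_label` is hypothetical) accepts a few plausible shapes rather than assuming one definitive format. Check the displayed values before relying on it.

```python
import json


def top_label(prediction):
    """Reduce one prediction string to its top-scoring label.

    Assumed shapes (verify against the real output first):
      - a JSON list of {"label": ..., "score": ...} dicts
      - a single JSON dict with a "label" key
      - a bare label string, which is returned unchanged
    """
    if prediction is None:
        return None
    try:
        parsed = json.loads(prediction)
    except (TypeError, json.JSONDecodeError):
        return prediction  # Not JSON: treat the whole string as the label
    if isinstance(parsed, str):
        return parsed  # A JSON-quoted label such as '"noise"'
    if isinstance(parsed, dict):
        return parsed.get("label")
    if isinstance(parsed, list) and parsed and all(isinstance(p, dict) for p in parsed):
        # Pick the entry with the highest score; missing scores rank last
        best = max(parsed, key=lambda p: p.get("score", float("-inf")))
        return best.get("label")
    return prediction
```

In the notebook it could be registered the same way as the Base64 helper, e.g. `udf(top_label, StringType())`, and applied to `col("predictions")` to get a column that lines up with `label`.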

### 4. Test 2: Real-time Inference with Pandas

Here, we test scoring using `pyfunc.load_model` and a Pandas DataFrame. This simulates how a REST API (like Databricks Model Serving) would use the model for a single request.
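Over HTTP, the same Base64 strings would travel as JSON. Below is a minimal sketch of building a request body in MLflow's `dataframe_split` input format, which the MLflow scoring server and Databricks Model Serving accept. The helper name and the fake byte strings are hypothetical. The `image` column name is an assumption based on the Pandas input built in the cell below, and the model's signature may expect a different name. No request is sent.

```python
import base64
import json


def to_dataframe_split_payload(images, column="image"):
    """Serialize raw image bytes into a `dataframe_split` JSON request body."""
    encoded = [base64.b64encode(b).decode("utf-8") for b in images]
    payload = {
        "dataframe_split": {
            "columns": [column],
            "data": [[s] for s in encoded],  # one row per image, a single column
        }
    }
    return json.dumps(payload)


body = to_dataframe_split_payload([b"first-fake-image", b"second-fake-image"])
print(body)
```

Building the body locally like this makes it easy to compare what an endpoint receives with what `pyfunc_model.predict` receives in the cell below.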

In [0]:
# 1. Load the model as a standard PyFunc object
print("Loading PyFunc model...")
pyfunc_model = mlflow.pyfunc.load_model(MODEL_URI)

# 2. Get 5 'surfing' (normal) samples and 5 'noise' (anomaly) samples
print("Preparing Pandas DataFrames...")
normal_pd_df = df.filter(col('label') == 'surfing').limit(5).toPandas()
anomaly_pd_df = df.filter(col('label') == 'noise').limit(5).toPandas()

# 3. Prepare the input data
# The pipeline expects a DataFrame where each row is one base64 string.
normal_input_data = pd.DataFrame(normal_pd_df['image'].apply(bytes_to_base64_pd))
anomaly_input_data = pd.DataFrame(anomaly_pd_df['image'].apply(bytes_to_base64_pd))

# 4. Run predictions
print("\n--- Predicting on NORMAL (surfing) samples ---")
normal_predictions = pyfunc_model.predict(normal_input_data)
display(normal_predictions)

print("\n--- Predicting on ANOMALY (noise) samples ---")
anomaly_predictions = pyfunc_model.predict(anomaly_input_data)
display(anomaly_predictions)
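Eyeballing two displayed tables works for five rows each, but a single number is easier to track across model versions. Below is a minimal sketch that assumes the predictions have already been reduced to one top label per row (how to get there depends on the shape `predict` returns in your MLflow version). `label_match_rate` is a hypothetical helper, and the example labels are made up for illustration, not real model output.

```python
def label_match_rate(predicted_labels, expected_label):
    """Fraction of predictions equal to the expected label (0.0 if there are none)."""
    predicted_labels = list(predicted_labels)
    if not predicted_labels:
        return 0.0
    hits = sum(1 for label in predicted_labels if label == expected_label)
    return hits / len(predicted_labels)


# Made-up labels for illustration only
example_labels = ["surfing", "noise", "surfing", "surfing"]
print(f"Share classified as 'surfing': {label_match_rate(example_labels, 'surfing'):.0%}")
```

With real data, the normal samples would be scored against `"surfing"` and the anomaly samples against `"noise"`.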