## Import required dependencies

In [0]:
import io
import base64
import pandas as pd
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import mlflow
from mlflow.deployments import get_deploy_client
from mlflow.models import infer_signature
from mlflow.pyfunc import PythonModel
from databricks.sdk import WorkspaceClient
from databricks.sdk.service import serving
from mlflow.tracking import MlflowClient
import time




## Add config

In [0]:
# Unity Catalog location of the registered model; the catalog and schema
# values are placeholders that must be replaced before running the notebook.
UC_CATALOG = "<CATALOG>"
UC_SCHEMA = "<SCHEMA>"
UC_MODEL_NAME = "resnet50_image_classifier_uc"
# Name of the model serving endpoint that is created or updated on deployment.
SERVING_ENDPOINT_NAME = "resnet50-image-serving-endpoint"
# Intended scale-to-zero setting for the served model.
SCALE_TO_ZERO_ENABLED = True

## Create an MLflow pyfunc wrapper to log the model with appropriate pre-processing

In [0]:
class ResNet50ImageClassifier(PythonModel):
    """MLflow pyfunc wrapper around a pretrained torchvision ResNet-50.

    Input is a DataFrame with an ``image_data`` column holding base64-encoded
    image bytes. Output is a DataFrame with exactly one row per input row and
    columns ``class_0`` .. ``class_{NUM_CLASSES - 1}`` containing the raw model
    outputs (no softmax is applied).
    """

    # Number of output columns produced per image.
    NUM_CLASSES = 1000

    def load_context(self, context):
        """Load the pretrained model and build the preprocessing pipeline."""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # NOTE(review): the weights are resolved by torchvision at load time, not
        # stored with the logged model — confirm the serving environment can
        # fetch them (or has them cached).
        self.model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1).to(self.device)
        self.model.eval()

        self.preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def predict(self, context, model_input):
        """Run the model on every decodable image in ``model_input``.

        Args:
            context: MLflow model context (unused).
            model_input: DataFrame with an ``image_data`` column of
                base64-encoded images.

        Returns:
            A DataFrame with the same number of rows as ``model_input``. Rows
            whose image could not be decoded or preprocessed are filled with
            NaN so that output row ``i`` always corresponds to input row ``i``.

        Raises:
            ValueError: if ``model_input`` has no ``image_data`` column.
        """
        if "image_data" not in model_input:
            raise ValueError("Input DataFrame must have 'image_data' column.")

        columns = [f"class_{i}" for i in range(self.NUM_CLASSES)]
        encoded_images = list(model_input["image_data"])

        tensors = []
        valid_rows = []  # positions of the inputs that were processed successfully
        for row, img_str in enumerate(encoded_images):
            try:
                img_bytes = base64.b64decode(img_str)
                image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
                tensors.append(self.preprocess(image))
                valid_rows.append(row)
            except Exception as e:
                # Best effort: a bad image must not fail the whole batch.
                print(f"Skipping image due to processing error: {e}")

        # Start from all-NaN so failed rows stay aligned with their inputs
        # instead of silently shifting the remaining predictions up.
        scores = torch.full((len(encoded_images), self.NUM_CLASSES), float("nan"))

        if tensors:
            batch = torch.stack(tensors).to(self.device)
            with torch.no_grad():
                preds = self.model(batch)
            scores[valid_rows] = preds.detach().cpu().to(scores.dtype)

        return pd.DataFrame(scores.numpy(), columns=columns)

## Let's create a function that logs the model with an inferred signature and registers it in Unity Catalog

In [0]:
def log_and_register_model():
    """Log the pyfunc model to MLflow and register it in Unity Catalog.

    A small in-memory JPEG is used as the input example, and a one-row float32
    frame with the ``class_0`` .. ``class_999`` columns is used as the output
    example so the inferred signature carries numeric output columns.

    Returns:
        A ``(uc_model_path, version)`` tuple, where ``uc_model_path`` is the
        three-level Unity Catalog name and ``version`` is the model version
        created by this call.
    """
    mlflow.set_registry_uri("databricks-uc")
    uc_model_path = f"{UC_CATALOG}.{UC_SCHEMA}.{UC_MODEL_NAME}"

    # Build a base64-encoded dummy image in the same shape the model expects.
    dummy_img = Image.new('RGB', (224, 224), 'red')
    buffered = io.BytesIO()
    dummy_img.save(buffered, format="JPEG")
    dummy_b64 = base64.b64encode(buffered.getvalue()).decode()
    dummy_input_df = pd.DataFrame({"image_data": [dummy_b64]})

    # An empty frame carries no numeric dtype to infer a schema from, so use a
    # single representative float32 row instead.
    class_columns = [f"class_{i}" for i in range(1000)]
    dummy_output_df = pd.DataFrame(
        [[0.0] * len(class_columns)], columns=class_columns, dtype="float32"
    )

    with mlflow.start_run(run_name="ResNet50_UC_Logging"):
        model_info = mlflow.pyfunc.log_model(
            python_model=ResNet50ImageClassifier(),
            artifact_path="resnet50_model",
            registered_model_name=uc_model_path,
            signature=infer_signature(dummy_input_df, dummy_output_df),
            input_example=dummy_input_df
        )

    # Prefer the version reported for this exact log_model call; it cannot be
    # confused with a version registered concurrently by someone else.
    registered_version = getattr(model_info, "registered_model_version", None)
    if registered_version is not None:
        return uc_model_path, registered_version

    # Fallback when the returned ModelInfo does not expose the version:
    # wait briefly, then take the highest version in the registry.
    client = MlflowClient()
    time.sleep(5)
    model_versions = client.search_model_versions(f"name='{uc_model_path}'")
    latest_version = max(model_versions, key=lambda x: int(x.version)).version

    return uc_model_path, latest_version

## Now we can deploy the UC-registered model to Model Serving

In [0]:
def deploy_model(uc_model_path, model_version):
    """Create or update the serving endpoint to serve the given model version.

    Args:
        uc_model_path: Three-level Unity Catalog name of the registered model.
        model_version: Version of the registered model to serve.

    Raises:
        Exception: any error from the deployment client other than the
            endpoint not existing yet is re-raised.
    """
    client = get_deploy_client("databricks")

    served_model_name = f"{UC_MODEL_NAME.replace('_', '-')}-{model_version}"

    endpoint_config = {
        "served_entities": [
            {
                # Set explicitly so the traffic route below can refer to it.
                "name": served_model_name,
                "entity_name": uc_model_path,
                "entity_version": str(model_version),
                "workload_size": "Small",
                # Honour the notebook-level config instead of a hard-coded value.
                "scale_to_zero_enabled": SCALE_TO_ZERO_ENABLED
            }
        ],
        "traffic_config": {
            "routes": [
                {
                    "served_model_name": served_model_name,
                    "traffic_percentage": 100
                }
            ]
        }
    }

    # Only the existence probe is guarded: a failed update must never be
    # mistaken for a missing endpoint and trigger a create.
    try:
        client.get_endpoint(SERVING_ENDPOINT_NAME)
        endpoint_exists = True
    except Exception as e:
        message = str(e)
        if "RESOURCE_DOES_NOT_EXIST" in message or "not found" in message.lower():
            endpoint_exists = False
        else:
            print(f"An unexpected error occurred: {e}")
            raise

    if endpoint_exists:
        print(f"Endpoint '{SERVING_ENDPOINT_NAME}' already exists, updating...")
        # Prefer the dedicated config-update method when the installed client
        # provides it; otherwise fall back to update_endpoint.
        update_config = getattr(client, "update_endpoint_config", None) or client.update_endpoint
        update_config(endpoint=SERVING_ENDPOINT_NAME, config=endpoint_config)
        print(f"Endpoint '{SERVING_ENDPOINT_NAME}' updated successfully.")
    else:
        print(f"Endpoint '{SERVING_ENDPOINT_NAME}' not found, creating a new one...")
        client.create_endpoint(name=SERVING_ENDPOINT_NAME, config=endpoint_config)
        print(f"Endpoint '{SERVING_ENDPOINT_NAME}' created successfully.")


## Execute the code

In [0]:
# Log and register a new model version, then deploy that version to the serving endpoint.
uc_model_path, model_version = log_and_register_model()
deploy_model(uc_model_path, model_version)


2025/07/30 20:08:01 INFO mlflow.pyfunc: Validating input example against model signature


Uploading artifacts:   0%|          | 0/11 [00:00<?, ?it/s]

Registered model 'users.aradhya_chouhan.resnet50_image_classifier_uc' already exists. Creating a new version of this model...


Uploading artifacts:   0%|          | 0/11 [00:00<?, ?it/s]

Created version '10' of model 'users.aradhya_chouhan.resnet50_image_classifier_uc'.


Endpoint 'resnet50-image-serving-endpoint' already exists, updating...


  client.update_endpoint(endpoint=SERVING_ENDPOINT_NAME, config=endpoint_config)


Endpoint 'resnet50-image-serving-endpoint' updated successfully.
