# Deploying and Using the MedImageParse Model for Inference with Batch Endpoints
This example shows how to deploy MedImageParse, a state-of-the-art segmentation model tailored for biomedical imaging, to an Azure ML batch endpoint. This notebook uses Python 3.10 and the Azure ML Python SDK v2.

### Task
The primary task is semantic segmentation. Given an image and a text prompt, the model identifies and labels the regions of the image that match the prompt's semantic meaning.
 
### Model
MedImageParse is powered by a transformer-based architecture, fine-tuned for segmentation tasks on extensive biomedical image datasets. It is designed to excel in handling complex segmentation challenges across diverse imaging modalities. 

### Inference data
For this demonstration, we use histopathology images stained with H&E (hematoxylin and eosin) and focus on cell phenotyping: segmenting and identifying the different types of cells in a tissue sample.

### Outline
1. Setup pre-requisites
2. Pick a model to deploy
3. Deploy the model to a batch endpoint
4. Test the endpoint
5. Clean up resources - delete the endpoint


## 1. Setup pre-requisites
* Install [Azure ML Client library for Python](https://learn.microsoft.com/en-us/python/api/overview/azure/ai-ml-readme?view=azure-python)
* Connect to AzureML Workspace and authenticate.

In [None]:
from azure.ai.ml import MLClient, Input
from azure.ai.ml.entities import (
    BatchEndpoint,
    ModelBatchDeployment,
    ModelBatchDeploymentSettings,
    Model,
    AmlCompute,
    Data,
    BatchRetrySettings,
    CodeConfiguration,
    Environment,
)
from azure.ai.ml.constants import AssetTypes, BatchDeploymentOutputAction
from azure.identity import DefaultAzureCredential
import pandas as pd

credential = DefaultAzureCredential()
ml_workspace = MLClient.from_config(credential)
print("Workspace:", ml_workspace)
ml_registry = MLClient(credential, registry_name="azureml")
print("Registry:", ml_registry)

## 2. Pick a model to deploy

In this example, we use the `MedImageParse` model. If you have opened this notebook for a different model, replace the model name accordingly.

In [None]:
model = ml_registry.models.get(name="MedImageParse", label="latest")
model

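## 3. Deploy the model to a batch endpoint
Batch endpoints run asynchronous inference jobs over large amounts of data. Each invocation submits a job that reads its input from storage and writes the predictions back to storage, which suits scoring many images at once rather than serving real-time requests.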

The steps below show how to deploy an endpoint programmatically. You can skip the steps in this section if you just want to test an existing endpoint. 

### Create compute cluster

In [None]:
compute_name = "mip-batch-cluster"
if not any(filter(lambda m: m.name == compute_name, ml_workspace.compute.list())):
    compute_cluster = AmlCompute(
        name=compute_name,
        description="GPU cluster compute for MedImageParse inference",
        min_instances=0,
        max_instances=1,
        size="Standard_NC6s_v3",
    )
    ml_workspace.compute.begin_create_or_update(compute_cluster).result()

### Create batch endpoint

In [None]:
import random
import string

endpoint_prefix = "mip-batch"
endpoint_list = list(
    filter(
        lambda m: m.name.startswith(endpoint_prefix),
        ml_workspace.batch_endpoints.list(),
    )
)

if endpoint_list:
    endpoint = endpoint_list[0]
    print("Found existing endpoint:", endpoint.name)
else:
    # Create a unique endpoint name by appending a random suffix
    allowed_chars = string.ascii_lowercase + string.digits
    endpoint_suffix = "".join(random.choice(allowed_chars) for _ in range(5))
    endpoint_name = f"{endpoint_prefix}-{endpoint_suffix}"
    endpoint = BatchEndpoint(
        name=endpoint_name,
        description="A batch endpoint for scoring images from MedImageParse.",
        tags={"type": "medimageparse"},
    )
    ml_workspace.batch_endpoints.begin_create_or_update(endpoint).result()
    print(f"Created new endpoint: {endpoint_name}")
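
The find-or-create logic above can be tried out on plain strings before touching a workspace. This sketch uses a hypothetical helper, `pick_or_create_endpoint_name`, that reuses the first name starting with the prefix and otherwise appends the same kind of five-character random suffix. Passing a seeded `random.Random` makes the generated name reproducible.

```python
import random
import string


def pick_or_create_endpoint_name(existing_names, prefix, suffix_len=5, rng=None):
    """Reuse the first name that starts with prefix, else build a new one.

    Hypothetical helper mirroring the endpoint cell's logic on plain strings.
    Returns (name, created) where created is True for a newly generated name.
    """
    for name in existing_names:
        if name.startswith(prefix):
            return name, False
    rng = rng or random.Random()
    allowed_chars = string.ascii_lowercase + string.digits
    suffix = "".join(rng.choice(allowed_chars) for _ in range(suffix_len))
    return f"{prefix}-{suffix}", True


print(pick_or_create_endpoint_name(["other-ep", "mip-batch-ab12c"], "mip-batch"))
# ('mip-batch-ab12c', False)
name, created = pick_or_create_endpoint_name([], "mip-batch", rng=random.Random(0))
print(name, created)
```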

### Deploy MedImageParse to batch endpoint

- **max_concurrency_per_instance**: The number of worker processes to spawn on each compute instance. Each worker process loads its own copy of the model into GPU memory, so use enough workers to keep the GPU busy without exceeding the available GPU memory.
- **retry_settings**: The `timeout` (in seconds) applies to each mini-batch and may need to be adjusted for the mini-batch size. Larger mini-batches need a longer timeout; otherwise, the worker process may be terminated before it finishes.
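
Both settings come down to simple arithmetic. The sketch below is a rule of thumb, not an Azure ML formula: it assumes each worker holds one full model copy in GPU memory and that scoring time grows linearly with mini-batch size. The helper names, the 1 GB reserve, the 3.5 GB per-worker footprint, and the 5 seconds per image are all hypothetical; the 16 GB figure corresponds to the single V100 GPU in `Standard_NC6s_v3`.

```python
import math


def max_workers_per_instance(gpu_memory_gb, model_memory_gb, reserve_gb=1.0):
    """Estimate how many model copies fit on one GPU (hypothetical rule of thumb)."""
    usable_gb = gpu_memory_gb - reserve_gb  # keep some memory free for the runtime
    if usable_gb < model_memory_gb:
        raise ValueError("A single model copy does not fit in GPU memory")
    return math.floor(usable_gb / model_memory_gb)


def timeout_seconds(mini_batch_size, seconds_per_item, safety_factor=2.0):
    """Scale the per-mini-batch timeout with the mini-batch size, plus a margin."""
    return math.ceil(mini_batch_size * seconds_per_item * safety_factor)


# Assumed figures, not measurements: one 16 GB GPU and about 3.5 GB per worker
print(max_workers_per_instance(16, 3.5))  # 4
# Assumed 5 s per image with 10 images per mini-batch
print(timeout_seconds(mini_batch_size=10, seconds_per_item=5))  # 100
```

Replace the assumed figures with values measured for your own model and compute size before relying on the estimates.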

In [None]:
deployment = ModelBatchDeployment(
    name="mip-dpl",
    description="A deployment for model MedImageParse",
    endpoint_name=endpoint.name,
    model=model,
    compute=compute_name,
    settings=ModelBatchDeploymentSettings(
        max_concurrency_per_instance=4,
        mini_batch_size=1,
        instance_count=1,
        output_action=BatchDeploymentOutputAction.APPEND_ROW,
        output_file_name="predictions.csv",
        retry_settings=BatchRetrySettings(max_retries=3, timeout=300),
        logging_level="info",
    ),
)
ml_workspace.begin_create_or_update(deployment).result()

In [None]:
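# Confirm the deployment was created; logs for each scoring run are available from its job
ml_workspace.batch_deployments.get(name=deployment.name, endpoint_name=endpoint.name)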

In [None]:
endpoint = ml_workspace.batch_endpoints.get(endpoint.name)
endpoint.defaults.deployment_name = deployment.name
ml_workspace.batch_endpoints.begin_create_or_update(endpoint).result()
print(f"The default deployment is {endpoint.defaults.deployment_name}")

## 4. Test the endpoint: base64-encoded image and text

### Load test dataset
Download the test dataset with the following command: `azcopy copy --recursive https://azuremlexampledata.blob.core.windows.net/data/healthcare-ai/ /home/azureuser/data/`

In [None]:
import glob

root_dir = "/home/azureuser/data/healthcare-ai/medimageinsight-examparameter/pngs"

png_files = glob.glob(f"{root_dir}/**/*.png", recursive=True)
print(f"Found {len(png_files)} PNG files")

### Create the input CSV file

#### Zero-Padding Batch Filenames

> In the example below, only one batch file is created (`batch_input_001.csv`). If you need to create more batches, use a zero-padded batch index in the file names.

For a more detailed example, refer to the MedImageInsight notebook `mi2-deploy-batch-endpoint.ipynb`. In that notebook, the function `write_to_csv()` automatically creates batch files with **zero-padded numeric suffixes** (e.g., `batch_input_001.csv`, `batch_input_002.csv`, ..., `batch_input_010.csv`). Use the same zero-padded index when enumerating your own batches.

Zero-padding ensures that sorting the file names as strings (lexicographic order) yields the **correct numerical order**. Without padding, `batch10` sorts **before** `batch2` and `batch3`, which can lead to confusing or incorrect alignment between batch input files and batch output results. Zero-padding keeps the ordering predictable and avoids mismatches during downstream processing or aggregation.
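
The naming scheme can be sketched with a hypothetical `write_batches()` helper. It is not the `write_to_csv()` function from the MedImageInsight notebook; it simply splits a DataFrame into chunks and names each file with an index padded to at least three digits, then compares how padded and unpadded names sort.

```python
import os
import tempfile

import pandas as pd


def write_batches(df, out_dir, batch_size, prefix="batch_input"):
    """Split df into CSV files named with a zero-padded index (hypothetical helper)."""
    num_batches = max(1, -(-len(df) // batch_size))  # ceiling division
    width = max(3, len(str(num_batches)))  # pad to at least 3 digits, e.g. 001
    paths = []
    for i in range(num_batches):
        chunk = df.iloc[i * batch_size : (i + 1) * batch_size]
        path = os.path.join(out_dir, f"{prefix}_{i + 1:0{width}d}.csv")
        chunk.to_csv(path)  # same to_csv call as the notebook cell
        paths.append(path)
    return paths


# Unpadded names sort lexicographically out of numeric order
print(sorted(["batch_input_10.csv", "batch_input_2.csv"]))
# ['batch_input_10.csv', 'batch_input_2.csv']

# Zero-padded names sort in numeric order
df = pd.DataFrame({"image": [f"img{i}" for i in range(25)], "text": "abnormality"})
with tempfile.TemporaryDirectory() as tmp_dir:
    files = write_batches(df, tmp_dir, batch_size=2)
    names = [os.path.basename(p) for p in files]
print(names[0], names[-1], sorted(names) == names)
# batch_input_001.csv batch_input_013.csv True
```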

In [None]:
import base64
import os


def read_base64_image(image_path):
    with open(image_path, "rb") as f:
        return base64.b64encode(f.read()).decode("utf-8")


data = []
for f in png_files:
    base64_image = read_base64_image(f)
    data.append([base64_image, "abnormality"])

csv_path = os.path.join(os.getcwd(), "batch_input_001.csv")
df_input = pd.DataFrame(data, columns=["image", "text"])
df_input.to_csv(csv_path)
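
The `image` column holds the raw PNG file bytes as standard base64 text. Base64 output contains only letters, digits, `+`, `/`, and `=`, so it passes through CSV unchanged. The self-contained check below uses placeholder bytes in place of a real image and an in-memory buffer instead of a file.

```python
import base64
import io

import pandas as pd

# Placeholder bytes standing in for a PNG file (not a real image)
payload = b"\x89PNG\r\n\x1a\n placeholder bytes"
encoded = base64.b64encode(payload).decode("utf-8")

# Write and re-read a one-row table the same way the cell above does (index included)
buffer = io.StringIO()
pd.DataFrame([[encoded, "abnormality"]], columns=["image", "text"]).to_csv(buffer)
buffer.seek(0)
df_check = pd.read_csv(buffer, index_col=0)

# The base64 text survives the CSV round trip and decodes to the original bytes
decoded = base64.b64decode(df_check.loc[0, "image"])
print(decoded == payload, df_check.loc[0, "text"])
# True abnormality
```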

### Load the test dataset into AzureML

In [None]:
dataset_name = "mip-png-dataset"

png_dataset = Data(
    path=csv_path,
    type=AssetTypes.URI_FILE,
    description="Base64-encoded images and text prompts for MedImageParse batch scoring",
    name=dataset_name,
)

ml_workspace.data.create_or_update(png_dataset)

### Verify the test dataset is uploaded successfully

In [None]:
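# Fetch the registered asset so the job input below points to the uploaded copy
png_dataset = ml_workspace.data.get(name=dataset_name, label="latest")
png_dataset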

### Submit a job to the batch endpoint

In [None]:
# Avoid shadowing Python's built-in input()
job_input = Input(type=AssetTypes.URI_FILE, path=png_dataset.path)
job_input

In [None]:
job = ml_workspace.batch_endpoints.invoke(endpoint_name=endpoint.name, input=job_input)

In [None]:
# Monitor job progress
ml_workspace.jobs.stream(job.name)

### Download the job output
The MedImageParse segmentation masks and labels are written to the file `named-outputs/score/predictions.csv`.

In [None]:
scoring_job = list(ml_workspace.jobs.list(parent_job_name=job.name))[0]
scoring_job

In [None]:
ml_workspace.jobs.download(
    name=scoring_job.name, download_path=".", output_name="score"
)

### Load job result

In [None]:
pred_csv_path = os.path.join("named-outputs", "score", "predictions.csv")
df_result = pd.read_csv(pred_csv_path, header=None)
print("df_result.shape:", df_result.shape)
print(df_result.iloc[0])  # print first row

### Display job result

In [None]:
import base64
import json
import matplotlib.pyplot as plt
import numpy as np


def parse_image(json_encoded):
    """Decode an image pixel data array in JSON.
    Return image pixel data as an array.
    """
    # Parse the JSON string
    array_metadata = json.loads(json_encoded)
    # Extract Base64 string, shape, and dtype
    base64_encoded = array_metadata["data"]
    shape = tuple(array_metadata["shape"])
    dtype = np.dtype(array_metadata["dtype"])
    # Decode Base64 to byte string
    array_bytes = base64.b64decode(base64_encoded)
    # Convert byte string back to NumPy array and reshape
    array = np.frombuffer(array_bytes, dtype=dtype).reshape(shape)
    return array


def parse_labels(s):
    return json.loads(s.replace("'", '"'))


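def convert_to_rgba(image_np):
    """Return an H x W x 4 uint8 RGBA version of an image array."""
    # plt.imread returns PNG pixels as floats in [0, 1]; scale them to uint8
    if np.issubdtype(image_np.dtype, np.floating):
        image_np = (image_np * 255).round().astype(np.uint8)
    # Expand grayscale images to three channels
    if image_np.ndim == 2:
        image_np = np.stack([image_np] * 3, axis=-1)
    # Keep images that already have an alpha channel as they are
    if image_np.shape[2] == 4:
        return image_np
    # Add a fully opaque alpha channel
    alpha_channel = np.full(image_np.shape[:2] + (1,), 255, dtype=image_np.dtype)
    return np.concatenate((image_np, alpha_channel), axis=2)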


def plot_segmentation_masks(original_image, segmentation_masks, labels):
    """Plot a list of segmentation masks over an image."""
    fig, ax = plt.subplots(1, len(segmentation_masks) + 1, figsize=(10, 5))
    ax[0].imshow(original_image)
    ax[0].set_title("Original Image")

    for i, mask in enumerate(segmentation_masks):
        ax[i + 1].imshow(original_image)
        ax[i + 1].set_title(labels[i])
        mask_temp = original_image.copy()
        mask_temp[mask > 128] = [255, 0, 0, 255]
        mask_temp[mask <= 128] = [0, 0, 0, 0]
        ax[i + 1].imshow(mask_temp, alpha=0.9)
    plt.show()
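
`parse_image()` expects each mask array as a JSON object with base64 `data`, a `shape`, and a `dtype`. To test the decoder offline, a hypothetical inverse such as `encode_array()` below can produce that format from a NumPy array. It assumes the bytes are laid out in NumPy's default C order and native byte order; the model's actual encoder may differ in such details.

```python
import base64
import json

import numpy as np


def encode_array(array):
    """Pack an array as JSON with base64 data, shape, and dtype (hypothetical helper)."""
    return json.dumps(
        {
            "data": base64.b64encode(array.tobytes()).decode("utf-8"),
            "shape": list(array.shape),
            "dtype": str(array.dtype),
        }
    )


def decode_array(json_encoded):
    """Inverse of encode_array(), following the same steps as parse_image()."""
    metadata = json.loads(json_encoded)
    array_bytes = base64.b64decode(metadata["data"])
    dtype = np.dtype(metadata["dtype"])
    return np.frombuffer(array_bytes, dtype=dtype).reshape(tuple(metadata["shape"]))


# Two tiny 2 x 3 uint8 masks stacked along the first axis
masks = np.array(
    [[[0, 255, 0], [255, 0, 0]], [[0, 0, 0], [0, 0, 255]]], dtype=np.uint8
)
restored = decode_array(encode_array(masks))
print(restored.shape, restored.dtype, np.array_equal(restored, masks))
# (2, 2, 3) uint8 True
```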

In [None]:
for index in range(len(df_input)):
    orig_image = convert_to_rgba(plt.imread(png_files[index]))
    result = df_result.iloc[index]

    image_features = parse_image(result.iloc[1])
    labels = parse_labels(result.iloc[2])

    # Plot each segmentation mask over the original image
    print(f"Image {index}")
    plot_segmentation_masks(orig_image, image_features, labels)
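
Beyond visual inspection, the same `> 128` threshold that `plot_segmentation_masks()` uses can turn each mask into a number, such as the fraction of the image it covers. `mask_coverage()` below is a hypothetical helper demonstrated on synthetic masks; inside the loop above it could be called as `mask_coverage(image_features, labels)`.

```python
import numpy as np


def mask_coverage(masks, labels, threshold=128):
    """Return the fraction of pixels above `threshold` in each mask, keyed by label.

    Uses the same > 128 cut-off as plot_segmentation_masks() (hypothetical helper).
    """
    return {label: float((mask > threshold).mean()) for mask, label in zip(masks, labels)}


# Synthetic 4 x 4 masks: the first covers 4 of 16 pixels, the second covers none
masks = np.zeros((2, 4, 4), dtype=np.uint8)
masks[0, :2, :2] = 255
print(mask_coverage(masks, ["label A", "label B"]))
# {'label A': 0.25, 'label B': 0.0}
```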

## 5. Clean up resources - delete the batch endpoint

In [None]:
ml_workspace.batch_endpoints.begin_delete(endpoint.name).result()