# MLflow Segment Anything

Segment Anything MLflow model logging example.

## Setup environment

Conda: Python 3.12.

### Conda

Only Python itself is strictly required from Conda. The command below also installs MLflow, Boto3 (for S3 bucket access) and python-dotenv with Conda.

```shell
conda create -n mlflow-x -c conda-forge python==3.12 mlflow==2.17.0 boto3 python-dotenv
conda activate mlflow-x
```

Additionally install JupyterLab for experiments:
```shell
conda install jupyterlab nb_conda
```

### Pip

PyTorch CPU version (GPU version is not so tricky):

```shell
pip install -f https://download.pytorch.org/whl/torch torch==2.5.0+cpu
pip install -f https://download.pytorch.org/whl/torchvision torchvision==0.20.0+cpu
```

Segment Anything:

```shell
pip install git+https://github.com/facebookresearch/segment-anything.git@6fdee8f2727f4506cfbbe553e23b895e27956588
```

OpenCV:

```shell
pip install opencv-python-headless
```

### Jupyter Lab

This notebook can be successfully launched in Jupyter Lab.

> NOTE: Jupyter cells may fail with a timeout; run Jupyter Lab with increased limits, e.g. `jupyter lab --ServerApp.rate_limit_window=1440.0`

### Variables in .env

Python dotenv is used to set environment variables from `.env` that must reside in the notebook's working directory.

> Nobody sets environment variables in notebooks, right 😉

Mandatory variables (assume MLflow tracking server is a separate host):

```shell
MLFLOW_TRACKING_URI=http://host-or-ip:5000
```

Optional variables (assume MLflow artifact storage is configured as a separate S3 storage):
```shell
MLFLOW_S3_ENDPOINT_URL="https://storage.host.name"
AWS_ACCESS_KEY_ID="s3-key-id"
AWS_SECRET_ACCESS_KEY="s3-key-secret"
```


In [None]:
import pickle

from os import environ
from traceback import format_exc

from dotenv import load_dotenv


# Load variables from the `.env` file; variables that already exist in the environment are not overwritten
load_dotenv()  # take environment variables from .env (not overwrite existing)

# MLflow tracking host address: empty string defaults to ./mlruns directory
MLFLOW_TRACKING_URI = environ.get('MLFLOW_TRACKING_URI', '')  # or 'http://localhost:5000'
# assert MLFLOW_TRACKING_URI, f"Provide MLFLOW_TRACKING_URI environment variable or set above!"

# Image file that later cells read and base64-encode (model input example, serving requests)
PATH_IMAGE = 'test.jpg'

## Log Segment Anything model

In [None]:
import base64
import json
import logging

from datetime import datetime
from io import BytesIO
from os import getcwd, makedirs
from os import path as osp
from typing import Any, Dict, List, Tuple
from urllib.request import urlretrieve

import cv2 as cv
import mlflow
import numpy as np

from mlflow.models import ModelSignature, infer_signature
from mlflow.pyfunc import PythonModel, log_model
from mlflow.types.schema import ColSpec, DataType, Schema, TensorSpec
from segment_anything import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
from torchvision import transforms as T


# Logger named "mlflow": all diagnostics of the SAM wrapper below are written through it
logs = logging.getLogger("mlflow")
# Uncomment to see the wrapper's debug messages:
# logs.setLevel(logging.DEBUG)


# Class to wrap Segment Anything Model
class MLflowWrapperSAM(PythonModel):
    """
    MLflow ``PythonModel`` wrapper that serves a Segment Anything Model (SAM).

    Every input row carries a base64-encoded image and is processed in one of two modes:

    * ``'generator'`` -- automatic mask generation with ``SamAutomaticMaskGenerator``;
    * ``'predictor'`` -- prompted segmentation with ``SamPredictor`` driven by positive/negative
      points and/or bounding boxes.

    Both modes produce the same list-of-dicts output; raster masks are replaced with contour polygons.
    """

    # shape_input_sam: Tuple[int, int] = (800, 800)  # FIXME: set real SAM input shape
    min_mask_region_area: int = 100000

    def __init__(self, model, shape_input=None, mask_min=None):
        """
        Initializes the SAM wrapper.

        Args:
            model: Pre-loaded SAM model instance.
            shape_input: Currently unused; kept for interface compatibility.
            mask_min: Value for ``min_mask_region_area`` of the automatic mask generator.
                ``None`` selects the class default; an explicit ``0`` is honoured.
        """
        self.model = model
        # self.shape_input_sam = shape_input or self.shape_input_sam
        # `is None` instead of `or`: 0 is a legitimate setting and must not fall back to the default
        self.min_mask_region_area = self.min_mask_region_area if mask_min is None else mask_min

        # Set up SAM Generator and Predictor
        self.generator_sam = SamAutomaticMaskGenerator(self.model, min_mask_region_area=self.min_mask_region_area)
        self.predictor_sam = SamPredictor(self.model)

    @staticmethod
    def _debug(message, *producers):
        """
        Log a debug message whose arguments are produced lazily and defensively.

        Args:
            message: ``%``-style log message.
            *producers: Zero-argument callables returning the message arguments. They are only
                called when debug logging is enabled; their failures are logged, never raised.
        """
        if not logs.isEnabledFor(logging.DEBUG):
            return
        try:
            logs.debug(message, *(produce() for produce in producers))
        except Exception as ex:  # diagnostics must never break inference
            logs.error("%s", str(ex))

    @staticmethod
    def _as_prompt_array(values, width):
        """
        Normalize a prompt (points or boxes) to an array of shape ``(N, width)``.

        Args:
            values: Nested sequence / array of coordinates, or ``None``.
            width: Number of coordinates per prompt (2 for points, 4 for boxes).

        Returns:
            ``np.ndarray`` of shape ``(N, width)``, or ``None`` when ``values`` holds no usable prompt.
        """
        if values is None:
            return None
        array = np.asarray(values)
        # A 0-d array is what a scalar placeholder (e.g. None or NaN of an omitted optional input)
        # turns into; it is never None itself, so it has to be detected by its rank
        if array.ndim == 0 or array.size == 0 or array.shape[-1] != width:
            return None
        return array.reshape(-1, width)

    @staticmethod
    def _decode_image(image_base64):
        """
        Decode a base64-encoded image file into an RGB array.

        Raises:
            ValueError: If the decoded bytes are not an image OpenCV can read.
        """
        image_data = base64.b64decode(image_base64)
        image_array = np.frombuffer(image_data, dtype=np.uint8)
        image = cv.imdecode(image_array, cv.IMREAD_COLOR)
        logs.debug("image data is %s", type(image_data))
        logs.debug("image array is %s", type(image_array))
        logs.debug("image is %s", type(image))
        if image is None:
            # cv.imdecode signals failure with None; fail here with a readable message
            raise ValueError("Cannot decode input image: expected a base64-encoded image file.")
        logs.debug("image size = %s", image.size)
        # OpenCV decodes to BGR; the image is handed to SAM as RGB
        return cv.cvtColor(image, cv.COLOR_BGR2RGB)

    def _build_point_prompts(self, points_p, points_n):
        """
        Merge positive and negative points into SAM point prompts.

        Negative points are only used together with valid positive points.

        Returns:
            Tuple ``(points, labels)`` with labels 1 (positive) / 0 (negative),
            or ``(None, None)`` when there are no valid positive points.
        """
        positive = self._as_prompt_array(points_p, 2)
        if positive is None:
            logs.debug("no valid points found (predictor mode)! Skip...")
            return None, None
        logs.debug("input positive points are %s", positive.shape)

        points = positive
        labels = np.ones(len(positive), dtype=int)
        logs.debug("input positive labels are %s", labels.shape)

        negative = self._as_prompt_array(points_n, 2)
        if negative is not None:
            logs.debug("input negative points are %s", negative.shape)
            points = np.concatenate([positive, negative], axis=0)
            labels = np.concatenate([labels, np.zeros(len(negative), dtype=int)], axis=0)
            logs.debug("input positive and negative labels are %s", labels.shape)
        return points, labels

    @staticmethod
    def _project_prediction(prediction, input_image):
        """
        Project one prediction onto the input image, in place.

        The raster mask under ``'segmentation'`` is replaced with its external contours under
        ``'polygon'``; ``'area'``, ``'bbox'`` and ``'point_coords'`` are rescaled from the
        ``'crop_box'`` size to the image size; ``'logits'`` is guaranteed to exist.
        """
        size_source = input_image.shape[1::-1]  # (width, height) of the input image
        crop_box = prediction['crop_box']
        # Scale of the prediction frame relative to the input image, per axis
        ratio = crop_box[2] / size_source[0], crop_box[3] / size_source[1]

        # Binary mask -> 0/255 raster of the input image size -> external contours
        mask = np.array(prediction['segmentation'], dtype=np.uint8) * np.uint8(255)
        mask = cv.resize(mask, size_source, interpolation=cv.INTER_NEAREST)
        polygon, _hierarchy = cv.findContours(mask, cv.RETR_EXTERNAL, cv.CHAIN_APPROX_SIMPLE)
        prediction['polygon'] = [p.tolist() for p in polygon]
        logs.debug("polygon = %s", polygon)
        del prediction['segmentation']  # raster masks are not part of the output

        # Lengths are divided by the ratio, hence the area is divided by the product of both ratios
        prediction['area'] = prediction['area'] / (ratio[0] * ratio[1])

        # Project bbox
        bbox = prediction['bbox']
        bbox[0] = (bbox[0] + crop_box[0]) / ratio[0]
        bbox[1] = (bbox[1] + crop_box[1]) / ratio[1]
        bbox[2] /= ratio[0]
        bbox[3] /= ratio[1]
        prediction['bbox'] = bbox

        # Project points
        logs.debug("point coords = %s", prediction['point_coords'])
        for point in prediction['point_coords']:
            # TODO: point relative to prediction['crop_box'] upper left point
            point[0] /= ratio[0]
            point[1] /= ratio[1]
        prediction['logits'] = prediction.get('logits', None)  # default value if it does not exist

    def predict(self, context, model_input, params=None):
        """
        Segment every image of the batch.

        Args:
            context: MLflow model context (unused).
            model_input: Mapping-like batch with the mandatory key ``'image'`` (base64-encoded image
                files) and the optional keys ``'mode'``, ``'pos_points'``, ``'neg_points'``, ``'bboxes'``.
                A missing ``'mode'`` means ``'generator'``.
            params: Inference parameters (unused).

        Returns:
            List with one entry per input image; every entry is a list of prediction dicts.

        Raises:
            ValueError: If a mode is neither ``'generator'`` nor ``'predictor'``, or an image cannot be decoded.
        """
        # Read input data
        images_base64 = model_input['image']
        missing = [None] * len(images_base64)
        modes = model_input.get('mode', ['generator'] * len(images_base64))
        pos_points = model_input.get('pos_points', missing)
        neg_points = model_input.get('neg_points', missing)
        bboxes = model_input.get('bboxes', missing)

        # Debug messages
        self._debug(
            "model input = %s\n%s",
            lambda: str(model_input.keys()),
            lambda: str({k: type(v) for k, v in model_input.items()}),
        )
        self._debug("image size = %s", lambda: str([len(i) for i in images_base64]))
        self._debug("mode = %s", lambda: str(list(modes)))
        self._debug("pos_points = %s", lambda: str(list(pos_points)))
        self._debug("neg_points = %s", lambda: str(list(neg_points)))
        self._debug("bboxes = %s", lambda: str(list(bboxes)))

        result = []

        for mode, image_base64, points_p, points_n, boxes in zip(modes, images_base64, pos_points, neg_points, bboxes):
            input_image = self._decode_image(image_base64)

            # Check mode and apply appropriate function
            if mode == 'generator':
                predictions = self.run_generator_mode(input_image)
            elif mode == 'predictor':
                input_points, input_labels = self._build_point_prompts(points_p, points_n)
                boxes = self._as_prompt_array(boxes, 4)
                self._debug("boxes are %s", lambda: str(boxes.shape if boxes is not None else type(boxes)))
                predictions = self.run_predictor_mode(input_image, input_points, input_labels, boxes)
                logs.debug("predictions are %s of size %d", type(predictions), len(predictions))
            else:
                raise ValueError("Invalid mode specified. Use 'generator' or 'predictor'.")

            # Postprocess predictions
            for prediction in predictions:
                self._project_prediction(prediction, input_image)
            result.append(predictions)
        return result

    def run_generator_mode(self, input_image: np.ndarray) -> List[Dict[str, Any]]:
        """
        Run the automatic mask generator on an image.

        Args:
            input_image: Image array passed to ``SamAutomaticMaskGenerator.generate``.

        Returns:
            The generator output, unchanged.
        """
        logs.debug("(generator mode): input image shape = %s", input_image.shape)
        masks = self.generator_sam.generate(input_image)
        logs.debug("masks are %s", type(masks))
        self._debug("masks = %s", lambda: str([type(m) for m in masks]))
        return masks  # return masks directly

    def run_predictor_mode(self, input_image, points=None, labels=None, bboxes=None):
        """
        Run the prompted predictor on an image.

        All points form a single prompt producing one mask; every bounding box is a separate
        prompt predicted with the predictor's default multi-mask setting.

        Args:
            input_image: Image array passed to ``SamPredictor.set_image``.
            points: Point coordinates, normalized to shape ``(N, 2)``; ``None`` or empty to skip.
            labels: Point labels (1 positive, 0 negative); all points are positive when ``None``.
            bboxes: Boxes ``[x1, y1, x2, y2]``, normalized to shape ``(M, 4)``; ``None`` or empty to skip.

        Returns:
            List of prediction dicts in the format of ``unify_predictor_output``.
        """
        logs.debug("(predictor mode): input image shape = %s", input_image.shape)
        points = self._as_prompt_array(points, 2)
        bboxes = self._as_prompt_array(bboxes, 4)

        masks = []
        if points is None and bboxes is None:
            # Nothing to predict: do not spend time on embedding the image
            return masks

        self.predictor_sam.set_image(input_image)

        # Predict based on points and/or bounding boxes
        if points is not None:
            point_labels = np.ones(len(points), dtype=int) if labels is None else labels
            found, scores, low_res = self.predictor_sam.predict(
                point_coords=points, point_labels=point_labels, multimask_output=False
            )
            masks.extend(self.unify_predictor_output(found, scores, low_res, points))
        if bboxes is not None:
            for bbox in bboxes:
                found, scores, low_res = self.predictor_sam.predict(box=bbox)
                # Surrogate point from the center of the bbox
                center = int((bbox[0] + bbox[2]) / 2), int((bbox[1] + bbox[3]) / 2)
                masks.extend(self.unify_predictor_output(found, scores, low_res, [center]))

        return masks  # return masks in predictor mode

    def unify_predictor_output(self, masks, scores, logits=None, point_coords=None):
        """
        Converts the output of SamPredictor.predict to match the output format of SamAutomaticMaskGenerator.generate.

        Args:
            masks (np.ndarray): Array of predicted masks with shape (num_masks, height, width).
            scores (np.ndarray): Array of quality scores, one per mask.
            logits (np.ndarray, optional): Array of logits (may be used as inputs to SamPredictor.predict method).
            point_coords (list(list(float)), optional): List of input point coordinates.

        Returns:
            list[dict]: Unified output format as a list of dictionaries matching SamAutomaticMaskGenerator.generate.
        """
        unified_output = []

        for i, mask in enumerate(masks):
            # Calculate bounding box (bbox) for each mask
            y_indices, x_indices = mask.nonzero()
            if y_indices.size > 0 and x_indices.size > 0:
                x_min, x_max = x_indices.min(), x_indices.max()
                y_min, y_max = y_indices.min(), y_indices.max()
                bbox = [int(x_min), int(y_min), int(x_max - x_min), int(y_max - y_min)]
            else:
                bbox = [0, 0, 0, 0]
            h, w = mask.shape[-2:]
            crop_box = [0, 0, int(w), int(h)]

            # Calculate area (number of non-zero pixels in the mask)
            area = int(mask.sum())

            # The predictor score is the model's own mask quality (IoU) estimate; the predictor
            # does not report a stability score, so the same value serves as a placeholder for it
            predicted_iou = float(scores[i])
            stability_score = float(scores[i])

            lrm = None
            if logits is not None:
                try:
                    lrm = logits[i]
                except (IndexError, TypeError) as ex:
                    logs.warning("%s", str(ex))

            # Built anew for every mask: the coordinates are rescaled in place during postprocessing
            point_coords_output = np.array(point_coords if point_coords is not None else [[0, 0]]).tolist()
            logs.debug("unification points = %s", point_coords_output)

            # Append to output in unified format
            unified_output.append(
                {
                    'segmentation': mask,  # binary mask for the object
                    'bbox': bbox,  # bounding box in [x_min, y_min, width, height]
                    'area': area,  # area of the mask in pixels
                    'predicted_iou': predicted_iou,  # mask quality score returned by the predictor
                    'point_coords': point_coords_output,  # input points used for mask generation
                    'stability_score': stability_score,  # placeholder stability score
                    'crop_box': crop_box,  # whole mask frame as [0, 0, width, height]
                    'logits': lrm,  # SamPredictor low-res masks used by CVAT
                }
            )

        return unified_output


# Define the required environment packages
# Packages installed with pip inside the model environment
pip_requirements_sam = [
    'opencv-python-headless',
    '-f https://download.pytorch.org/whl/torch torch==2.5.0+cpu',
    '-f https://download.pytorch.org/whl/torchvision torchvision==0.20.0+cpu',
    'git+https://github.com/facebookresearch/segment-anything.git@6fdee8f2727f4506cfbbe553e23b895e27956588',
]

# Conda environment specification stored together with every logged model
conda_env_sam = {
    'name': 'base',
    'channels': ['conda-forge'],
    'dependencies': [
        'python=3.12',
        'boto3',  # this may be required to download from S3
        {'pip': pip_requirements_sam},
    ],
}

# Point MLflow to the tracking server and select the experiment for all runs below
mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)
mlflow.set_experiment("Segment-Anything")


# Define model URLs
MODELS = {
    'vit_h': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth',
    'vit_l': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth',
    'vit_b': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth',
}


def get_model_weights(model_name, download_path='models'):
    """
    Return the local path of a model checkpoint, downloading it first if necessary.

    The checkpoint is downloaded to a temporary ``*.part`` file and renamed only after the
    download has finished, so an interrupted download is never mistaken for a cached checkpoint.

    Args:
        model_name: Key of ``MODELS``.
        download_path: Directory that caches the checkpoints; created when missing.

    Returns:
        Path of the checkpoint file inside ``download_path``.

    Raises:
        KeyError: If ``model_name`` is not a key of ``MODELS``.
        OSError: If the download fails (``urllib.error.URLError`` is a subclass).
    """
    import os  # local import: the notebook only imports selected names from `os`

    url = MODELS[model_name]

    # Ensure the download directory exists
    makedirs(download_path, exist_ok=True)

    # Define the full path for the weights
    weights_path = osp.join(download_path, osp.basename(url))

    # Check if the file exists; if not, download it
    if not osp.isfile(weights_path):
        print(f"Downloading {model_name} model weights...")
        partial_path = f"{weights_path}.part"
        try:
            urlretrieve(url, partial_path)
            os.replace(partial_path, weights_path)  # atomic: the final name only ever holds a complete file
        finally:
            # Only present here when the download or the rename failed
            if osp.exists(partial_path):
                os.remove(partial_path)

    return weights_path


# Absolute path of the `models` directory inside the current working directory;
# it is passed to get_model_weights() as the checkpoint download directory
PREFIX_MODELS = osp.realpath(osp.join(getcwd(), 'models'))
makedirs(PREFIX_MODELS, exist_ok=True)

assert osp.isdir(PREFIX_MODELS), f"check {PREFIX_MODELS} exists!"

# Name parts: joined into the registered model name and attached to every run as tags
NAME_PROJECT = 'MLflow'
NAME_TASK = 'Segment-Anything'
NAME_SUBTASK = 'SAM'

# Results of mlflow.register_model() keyed by SAM model name (filled in by the loop below)
models_registered = {}

# Load and encode image to base64
with open(PATH_IMAGE, 'rb') as image_file:
    image_base64 = base64.b64encode(image_file.read()).decode('utf-8')

# Define the input schema: `image` is the only required column;
# `mode`, `bboxes`, `pos_points` and `neg_points` are optional
input_schema = Schema.from_json(
    '[{"type": "string", "name": "mode", "required": false}, {"type": '
    '"string", "name": "image", "required": true}, {"type": "array", '
    '"items": {"type": "array", "items": {"type": "long"}}, "name": '
    '"bboxes", "required": false}, {"type": "array", "items": {"type": '
    '"array", "items": {"type": "long"}}, "name": "pos_points", '
    '"required": false}, {"type": "array", "items": {"type": "array", '
    '"items": {"type": "long"}}, "name": "neg_points", "required": '
    'false}]'
)

# Define the output schema: an array of prediction objects;
# every property except `logits` is required
output_schema = Schema.from_json(
    '[{"type": "array", "items": {"type": "object", "properties": '
    '{"area": {"type": "double", "required": true}, "bbox": {"type": '
    '"array", "items": {"type": "double"}, "required": true}, '
    '"crop_box": {"type": "array", "items": {"type": "long"}, '
    '"required": true}, "logits": {"type": "array", "items": {"type": '
    '"array", "items": {"type": "float"}}, "required": false}, '
    '"point_coords": {"type": "array", "items": {"type": "array", '
    '"items": {"type": "double"}}, "required": true}, "polygon": '
    '{"type": "array", "items": {"type": "array", "items": {"type": '
    '"array", "items": {"type": "array", "items": {"type": "long"}}}}, '
    '"required": true}, "predicted_iou": {"type": "double", '
    '"required": true}, "stability_score": {"type": "double", '
    '"required": true}}}, "required": true}]'
)

# Explicit signature built from the two schemas above (passed to log_model below)
signature = ModelSignature(inputs=input_schema, outputs=output_schema)
# signature = None

# Prepare the model input example
# data = {'mode': 'generator', 'image': image_base64}
# Predictor-mode example that uses every optional input: one box, two positive and two negative points
data = {
    'mode': 'predictor',
    'image': image_base64,
    'bboxes': [[200, 200, 600, 600]],
    'pos_points': [[380, 380], [420, 420]],
    'neg_points': [[180, 180], [620, 620]],
}

for ms in ('b', 'l', 'h'):
    # One MLflow run and one registered model per SAM size
    model_name = f"vit_{ms}"
    NAME_VERSION = f"{ms}+cpu"
    NAME_MODEL_SAM = "-".join((NAME_PROJECT, NAME_TASK, NAME_SUBTASK, NAME_VERSION))
    PATH_ARTIFACT = 'segment-anything-sam'
    TIMESTAMP = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")

    # Download the weights (if missing), build the SAM model and wrap it for MLflow
    path_model_weights = get_model_weights(model_name, PREFIX_MODELS)
    model = sam_model_registry[model_name](checkpoint=path_model_weights)
    wrapped_model = MLflowWrapperSAM(model)

    # raise KeyboardInterrupt
    run_tags = {'project': NAME_PROJECT, 'task': NAME_TASK, 'model': NAME_SUBTASK, 'version': NAME_VERSION}
    with mlflow.start_run(run_name=f"{NAME_MODEL_SAM}-{TIMESTAMP}") as run_active:
        log_model(
            artifact_path=PATH_ARTIFACT,
            python_model=wrapped_model,
            conda_env=conda_env_sam,
            signature=signature,
            input_example=data,
        )
        mlflow.set_tags(run_tags)
        run_id = run_active.info.run_id

    # Register the logged artifact under the model name
    model_uri = f"runs:/{run_id}/{PATH_ARTIFACT}"
    models_registered[model_name] = mlflow.register_model(model_uri, NAME_MODEL_SAM)

### Check logged models

In [None]:
# Display the registration results collected by the logging loop (one entry per SAM model name)
models_registered

In [None]:
# Display the registered model name left over from the last iteration of the logging loop
NAME_MODEL_SAM

In [None]:
from mlflow.pyfunc import load_model


# Load the latest version of the registered model named NAME_MODEL_SAM back from the model registry
model = load_model(f"models:/{NAME_MODEL_SAM}/latest")

# Print the signature stored in the loaded model's metadata
print(model.metadata.signature)

In [None]:
import json

# Display the loaded model's metadata, parsed from its JSON form into Python objects
json.loads(model.metadata.to_json())

### Inferred signatures

As Python `object`:
```
inputs: 
  ['mode': string (optional), 'image': string (required), 'bboxes': Array(Array(long)) (optional), 'pos_points': Array(Array(long)) (optional), 'neg_points': Array(Array(long)) (optional)]
outputs: 
  [Array({area: double (required), bbox: Array(double) (required), crop_box: Array(long) (required), logits: Array(Array(float)) (optional), point_coords: Array(Array(double)) (required), polygon: Array(Array(Array(Array(long)))) (required), predicted_iou: double (required), stability_score: double (required)}) (required)]
params: 
  None
```

As Python `dict`:
```python
{'inputs': '[{"type": "string", "name": "mode", "required": false}, {"type": '
           '"string", "name": "image", "required": true}, {"type": "array", '
           '"items": {"type": "array", "items": {"type": "long"}}, "name": '
           '"bboxes", "required": false}, {"type": "array", "items": {"type": '
           '"array", "items": {"type": "long"}}, "name": "pos_points", '
           '"required": false}, {"type": "array", "items": {"type": "array", '
           '"items": {"type": "long"}}, "name": "neg_points", "required": '
           'false}]',
 'outputs': '[{"type": "array", "items": {"type": "object", "properties": '
            '{"area": {"type": "double", "required": true}, "bbox": {"type": '
            '"array", "items": {"type": "double"}, "required": true}, '
            '"crop_box": {"type": "array", "items": {"type": "long"}, '
            '"required": true}, "logits": {"type": "array", "items": {"type": '
            '"array", "items": {"type": "float"}}, "required": false}, '
            '"point_coords": {"type": "array", "items": {"type": "array", '
            '"items": {"type": "double"}}, "required": true}, "polygon": '
            '{"type": "array", "items": {"type": "array", "items": {"type": '
            '"array", "items": {"type": "array", "items": {"type": "long"}}}}, '
            '"required": true}, "predicted_iou": {"type": "double", '
            '"required": true}, "stability_score": {"type": "double", '
            '"required": true}}}, "required": true}]',
 'params': None}
```

In [None]:
# from pprint import pprint


# pprint(model.metadata.signature.to_dict())

### Stop here before continuing with MLflow serving

In [None]:
# Deliberately interrupt execution at this cell: the cells below send requests
# to an MLflow serving endpoint that has to be started manually first
raise KeyboardInterrupt

## Predict with MLflow serving

MLflow serving must be started manually on any suitable machine. Let's follow some key assumptions:

1. MLflow serving is going to be deployed on localhost
2. MLflow tracking server is a separate host (not localhost)
3. MLflow artifact storage is an S3-compatible storage (bucket)
4. Variables `MLFLOW_TRACKING_URI`, `MLFLOW_S3_ENDPOINT_URL`, `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY` are set in `.env` file

**Shell script example** (make sure `.env` contains only `VARIABLE=value` lines, comments and blank lines):

```shell
#!/usr/bin/env bash

# Check if .env file exists
if [ ! -f .env ]; then
    echo "Error: .env file not found!" >&2
    exit 1
fi

# Patterns matching a value wrapped in one pair of double or single quotes
re_double='^"(.*)"$'
re_single="^'(.*)'\$"

# Load environment variables from .env file
# (`|| [[ -n "$line" ]]` keeps the last line when the file has no trailing newline)
while IFS= read -r line || [[ -n "$line" ]]; do
    # Skip empty lines, comments and lines that are not VARIABLE=value
    [[ -z "$line" || "$line" =~ ^[[:space:]]*# || "$line" != *=* ]] && continue
    name=${line%%=*}
    value=${line#*=}
    # Strip surrounding quotes: a plain `export "$line"` would make them part of the value
    if [[ "$value" =~ $re_double || "$value" =~ $re_single ]]; then
        value=${BASH_REMATCH[1]}
    fi
    # Export the variable
    export "$name=$value"
done < .env

# Default model name
MODEL="MLflow-Segment-Anything-SAM-b+cpu/latest"
# First argument is a model name
MODEL=${1:-$MODEL}
# Second argument is a port number
PORT=${2:-10055}

# Env manager set to `local` implies that MLflow will not try
# to create any environment, but use local active environment (conda, pyenv, venv, etc)
mlflow models serve -m "models:/$MODEL" -p "$PORT" --env-manager local --timeout 1440
```

Assume script name is `start-mlflow`:

```shell
chmod +x start-mlflow

./start-mlflow
```

In [None]:
# Fall back to the default registered model name when the logging cells above were not executed
try:
    NAME_MODEL_SAM
except NameError:
    NAME_MODEL_SAM = 'MLflow-Segment-Anything-SAM-b+cpu'

## Automatic Mask Generator

Predict using SAM Automatic Mask Generation mode.

### Input data

In [None]:
import base64
import json
import requests
from PIL import Image
from io import BytesIO

# PATH_IMAGE = 'test.jpg'

# Load and encode image to base64
with open(PATH_IMAGE, 'rb') as image_file:
    image_base64 = base64.b64encode(image_file.read()).decode('utf-8')

# Prepare the payload
data = {'mode': 'generator', 'image': image_base64}
payload = {'instances': [data]}

# Send the request to the MLflow serving endpoint
url = 'http://localhost:10055/invocations'
headers = {'Content-Type': 'application/json'}
# Without a timeout the request may block forever if the server never answers.
# (connect, read) in seconds: the read limit matches `--timeout 1440` of the serving script.
response = requests.post(url, headers=headers, data=json.dumps(payload), timeout=(10, 1440))

### Visualize results

This is an example of how to interpret MLflow SAM serving results (generator mode).

In [None]:
from random import randint

import cv2 as cv
import numpy as np

from matplotlib import pyplot as plt


# Draw contours, prompt points and boxes of every prediction over the source image
if response.status_code != 200:
    print(f"Request failed with status code {response.status_code}: {response.text}")
else:
    # Decode response
    response_data = response.json()
    if not isinstance(response_data, dict) or 'predictions' not in response_data:
        print(f"Unexpected response format: {type(response_data)}")
    else:
        # response_data['predictions'] holds one list of predictions per image of payload['instances']
        for predictions in response_data['predictions']:
            image = plt.imread(PATH_IMAGE).copy()
            for prediction in predictions:
                # Random colour per prediction
                color = (randint(15, 240), randint(15, 240), randint(15, 240))
                contours = tuple(np.array(contour) for contour in prediction['polygon'])
                image = cv.drawContours(image, contours, -1, color, 2, cv.LINE_AA)
                for point in prediction['point_coords']:
                    center = tuple(round(coordinate) for coordinate in point)
                    image = cv.circle(image, center, 3, (255, 255, 255), 1, cv.LINE_AA)
                # The box is drawn from (x, y) to (x + width, y + height)
                left, top, width, height = prediction['bbox'][:4]
                corner_min = round(left), round(top)
                corner_max = round(left + width), round(top + height)
                image = cv.rectangle(image, corner_min, corner_max, (5, 5, 5), 1, cv.LINE_AA)
            plt.figure(figsize=(1280 / 72, 960 / 72), dpi=72)
            plt.imshow(image)
            plt.axis('off')
            plt.show()

In [None]:
# Show which fields a prediction carries. Best effort: `response_data` only exists after
# a successful request and may hold no predictions at all.
try:
    print(response_data['predictions'][0][0].keys())
except (NameError, LookupError, TypeError, AttributeError) as ex:
    # Narrow exception list instead of a bare `except`, which would also swallow KeyboardInterrupt
    print(f"No prediction to inspect: {ex!r}")

## Predictor

Predict using SAM Prediction mode.

### Input data

In [None]:
import base64
import json
import requests
from PIL import Image
from io import BytesIO

# PATH_IMAGE = 'test.jpg'

# Load and encode image to base64
with open(PATH_IMAGE, 'rb') as image_file:
    image_base64 = base64.b64encode(image_file.read()).decode('utf-8')

# Prepare the payload: one box, two positive and two negative points
data = {
    'mode': 'predictor',
    'image': image_base64,
    'bboxes': [[200, 200, 600, 600]],
    'pos_points': [[380, 380], [420, 420]],
    'neg_points': [[180, 180], [620, 620]],
}
payload = {'instances': [data]}

# Send the request to the MLflow serving endpoint
url = 'http://localhost:10055/invocations'
headers = {'Content-Type': 'application/json'}
# Without a timeout the request may block forever if the server never answers.
# (connect, read) in seconds: the read limit matches `--timeout 1440` of the serving script.
response = requests.post(url, headers=headers, data=json.dumps(payload), timeout=(10, 1440))

### Visualize results

This is an example of how to interpret MLflow SAM serving results (predictor mode).

In [None]:
from random import randint

import cv2 as cv
import numpy as np

from matplotlib import pyplot as plt


# Overlay the predictor results (contours, prompt points, boxes) on the source image
if response.status_code != 200:
    print(f"Request failed with status code {response.status_code}: {response.text}")
else:
    # Decode response
    response_data = response.json()
    if not isinstance(response_data, dict) or 'predictions' not in response_data:
        print(f"Unexpected response format: {type(response_data)}")
    else:
        # Each item of response_data['predictions'] belongs to one image of payload['instances']
        for predictions in response_data['predictions']:
            image = plt.imread(PATH_IMAGE).copy()
            for prediction in predictions:
                # Random colour per prediction
                color = (randint(15, 240), randint(15, 240), randint(15, 240))
                contours = tuple(np.array(contour) for contour in prediction['polygon'])
                image = cv.drawContours(image, contours, -1, color, 2, cv.LINE_AA)
                for point in prediction['point_coords']:
                    center = tuple(round(coordinate) for coordinate in point)
                    image = cv.circle(image, center, 3, (255, 255, 255), 1, cv.LINE_AA)
                # The box is drawn from (x, y) to (x + width, y + height)
                left, top, width, height = prediction['bbox'][:4]
                corner_min = round(left), round(top)
                corner_max = round(left + width), round(top + height)
                image = cv.rectangle(image, corner_min, corner_max, (5, 5, 5), 1, cv.LINE_AA)
            plt.figure(figsize=(1280 / 72, 960 / 72), dpi=72)
            plt.imshow(image)
            plt.axis('off')
            plt.show()

In [None]:
# Show which fields a prediction carries. Best effort: `response_data` only exists after
# a successful request and may hold no predictions at all.
try:
    print(response_data['predictions'][0][0].keys())
except (NameError, LookupError, TypeError, AttributeError) as ex:
    # Narrow exception list instead of a bare `except`, which would also swallow KeyboardInterrupt
    print(f"No prediction to inspect: {ex!r}")

# DEBUG