# 2D Segmentation Task with Triton Inference Service

This guide walks you through tailoring the FLaVor inference service to a toy 2D segmentation task using Triton Inference Server.

## Prerequisites
Ensure you have the following dependencies installed:

```
python >= 3.10
```

## Implementation

### Set Up Imports

```python
import os
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple

import cv2
import numpy as np

from flavor.serve.apps import InferAPP
from flavor.serve.inference.data_models.api import (
    BaseAiCOCOImageInputDataModel,
    BaseAiCOCOImageOutputDataModel,
)
from flavor.serve.inference.data_models.functional import AiImage
from flavor.serve.inference.inference_models import (
    BaseAiCOCOImageInferenceModel,
    TritonInferenceModel,
    TritonInferenceModelSharedSystemMemory,
)
from flavor.serve.inference.strategies import AiCOCOSegmentationOutputStrategy
```


### Retrieve Model Information from Triton Inference Server
First, initialize `TritonInferenceModel` with the server URL and the model name. Then, access the `model_state` attribute to obtain information about the model deployed on Triton Inference Server.

```python
triton_model = TritonInferenceModel(
    triton_url="triton:8000", model_name="toyseg", model_version="",
)
triton_model.model_state
```

```
{'name': 'toyseg',
 'version': '1',
 'state': 'READY',
 'backend': 'onnxruntime',
 'max_batch_size': 4,
 'devices': ['KIND_GPU'],
 'inputs': [{'name': 'input', 'data_type': 'TYPE_FP32', 'dims': [3, -1, -1]}],
 'outputs': [{'name': 'logits',
   'data_type': 'TYPE_FP32',
   'dims': [-1, -1, -1]}]}
```

### Set Up the Inference Model

To perform inference with models on Triton Inference Server, use the `TritonInferenceModel` class and specify the model name, the model version, and the URL of the remote host. For a more memory-efficient setup, particularly if you are hosting Triton Inference Server on the same machine, you can use the `TritonInferenceModelSharedSystemMemory` class instead.

To go deeper into the implementation, we create `SegmentationTritonInferenceModel`, which inherits from `BaseAiCOCOImageInferenceModel`. A few abstract methods must be overridden: `define_inference_network`, `set_categories`, `set_regressions`, `data_reader`, and `output_formatter`. The inference-related methods `preprocess`, `inference`, and `postprocess` are overridden only when necessary. In this example, the input sent to the model on Triton Inference Server must match the model information shown in the block above: a dictionary with the key `input` whose value has data type `TYPE_FP32` and dimensions `[3, -1, -1]`, where `-1` denotes a dimension of any size. Because `max_batch_size` is greater than 0, Triton also expects a leading batch dimension that is not listed in `dims`, so the array we send has shape `[batch, 3, H, W]`.
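The input constraints above can be checked locally before a request ever reaches Triton. The sketch below compares a NumPy array with one entry of `model_state["inputs"]`. The helper `check_input` and the `TRITON_TO_NUMPY` mapping are hypothetical and not part of FLaVor; the only rule taken from Triton is that, with `max_batch_size > 0`, the full input shape is a batch dimension followed by `dims`.

```python
from typing import Any, Dict, List

import numpy as np

# Hypothetical mapping from Triton config data types to NumPy dtypes; extend as needed.
TRITON_TO_NUMPY = {
    "TYPE_FP32": np.float32,
    "TYPE_FP16": np.float16,
    "TYPE_INT64": np.int64,
    "TYPE_UINT8": np.uint8,
}


def check_input(x: np.ndarray, spec: Dict[str, Any], max_batch_size: int) -> List[str]:
    """Compare an array with one entry of model_state["inputs"]; return a list of problems."""
    problems: List[str] = []
    expected_dtype = np.dtype(TRITON_TO_NUMPY[spec["data_type"]])
    if x.dtype != expected_dtype:
        problems.append(f"dtype is {x.dtype}, expected {expected_dtype}")

    dims = list(spec["dims"])
    if max_batch_size > 0:
        # With batching enabled, `dims` excludes the leading batch dimension.
        dims = [-1] + dims
    if x.ndim != len(dims):
        problems.append(f"rank is {x.ndim}, expected {len(dims)}")
        return problems

    for axis, (got, want) in enumerate(zip(x.shape, dims)):
        if want != -1 and got != want:  # -1 accepts any size
            problems.append(f"axis {axis} has size {got}, expected {want}")
    if max_batch_size > 0 and x.shape[0] > max_batch_size:
        problems.append(f"batch size {x.shape[0]} exceeds max_batch_size {max_batch_size}")
    return problems


# Values copied from the model_state output shown earlier.
model_state = {
    "max_batch_size": 4,
    "inputs": [{"name": "input", "data_type": "TYPE_FP32", "dims": [3, -1, -1]}],
}
spec = model_state["inputs"][0]
print(check_input(np.zeros((1, 3, 256, 256), dtype=np.float32), spec, 4))  # -> []
print(check_input(np.zeros((1, 3, 256, 256), dtype=np.uint8), spec, 4))  # -> ['dtype is uint8, expected float32']
```

Returning a list of problems instead of raising on the first one makes it easy to see every mismatch at once while adapting `preprocess`.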

Next, we implement the submethods `define_inference_network`, `set_categories`, and `set_regressions`. They are invoked in the `__init__()` constructor of the parent class `BaseAiCOCOImageInferenceModel`. `define_inference_network` defines the inference network and loads its pre-trained weights. Here, we return a `TritonInferenceModel` or a `TritonInferenceModelSharedSystemMemory` instance depending on the setting, passing the arguments that identify the deployed model we want to use. `set_categories` and `set_regressions` define the category and regression information. For example, a segmentation output contains `c` channels, and `set_categories` specifies the exact meaning of each channel. Refer to the following example for details.
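To make the channel-to-category relationship concrete, the sketch below counts the pixels assigned to each displayed category of a one-hot mask. It assumes that channel `i` of the mask corresponds to `categories[i]`, which is how the category list in this tutorial is ordered; `summarize_mask` is a hypothetical helper, not a FLaVor function, and the 2x2 label map is toy data.

```python
from typing import Any, Dict, List

import numpy as np


def summarize_mask(onehot: np.ndarray, categories: List[Dict[str, Any]]) -> Dict[str, int]:
    """Count pixels per displayed category, assuming channel i maps to categories[i]."""
    if onehot.shape[0] != len(categories):
        raise ValueError(f"{onehot.shape[0]} channels but {len(categories)} categories")
    return {
        category["name"]: int(onehot[i].sum())
        for i, category in enumerate(categories)
        if category["display"]  # skip hidden channels such as Background
    }


categories = [
    {"name": "Background", "display": False},
    {"name": "Foreground 1", "display": True},
    {"name": "Foreground 2", "display": True},
]
labels = np.array([[0, 1], [2, 1]])  # toy 2x2 label map
onehot = (labels[None] == np.arange(3)[:, None, None]).astype(np.int8)  # 3, 2, 2
print(summarize_mask(onehot, categories))  # -> {'Foreground 1': 2, 'Foreground 2': 1}
```

A mismatch between the number of output channels and the number of categories is a common configuration error, so the helper raises early in that case.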

We then implement the remaining submethods used by the `__call__` function of our inference model. See the workflow below.

### `__call__` function workflow for the inference model
![__call__](images/call.png "inference workflow")
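The toy pipeline below mimics how the outputs of these hooks could feed into each other. It assumes the call order `data_reader`, `preprocess`, `inference`, `postprocess`, `output_formatter`, inferred from the method names; it is not FLaVor's implementation, whose `__call__` also handles the AiCOCO input and output data models. `ToyInferencePipeline` and `fake_network` are hypothetical stand-ins, and no Triton server is involved.

```python
from typing import Any, Callable, Dict, List, Sequence, Tuple

import numpy as np


class ToyInferencePipeline:
    """Hypothetical stand-in showing how each hook's output feeds the next one."""

    def __init__(self, network: Callable[[Dict[str, np.ndarray]], Dict[str, np.ndarray]]):
        self.network = network
        self.trace: List[str] = []  # records the order in which hooks run

    def data_reader(self, files: Sequence[str]) -> Tuple[np.ndarray, None, None]:
        self.trace.append("data_reader")
        return np.zeros((4, 4, 3), dtype=np.uint8), None, None  # fake H, W, C image

    def preprocess(self, data: np.ndarray) -> np.ndarray:
        self.trace.append("preprocess")
        return np.transpose(data, (2, 0, 1))[None].astype(np.float32)  # -> 1, C, H, W

    def inference(self, x: np.ndarray) -> Dict[str, np.ndarray]:
        self.trace.append("inference")
        return self.network({"input": x})

    def postprocess(self, out: Dict[str, np.ndarray]) -> np.ndarray:
        self.trace.append("postprocess")
        return np.argmax(out["logits"][0], axis=0)  # 1, C, H, W -> H, W label map

    def output_formatter(self, model_out: np.ndarray) -> Dict[str, Any]:
        self.trace.append("output_formatter")
        return {"labels": model_out.tolist()}

    def __call__(self, files: Sequence[str]) -> Dict[str, Any]:
        data, _, _ = self.data_reader(files)
        x = self.preprocess(data)
        out = self.inference(x)
        mask = self.postprocess(out)
        return self.output_formatter(mask)


def fake_network(inputs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
    # Hypothetical stand-in for the Triton model: 3-class logits where class 1 always wins.
    _, _, h, w = inputs["input"].shape
    logits = np.zeros((1, 3, h, w), dtype=np.float32)
    logits[:, 1] = 1.0
    return {"logits": logits}


pipeline = ToyInferencePipeline(fake_network)
result = pipeline(["image.png"])
print(pipeline.trace)
```

Swapping `fake_network` for a real network object is the only change needed to move from this offline check to live inference, which is why the network is injected through the constructor.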

```python
class SegmentationTritonInferenceModel(BaseAiCOCOImageInferenceModel):
    def __init__(
        self,
        triton_url: str = "triton:8000",
        model_name: str = "toyseg",
        model_version: str = "",
        is_shared_memory: bool = False,
    ):
        self.formatter = AiCOCOSegmentationOutputStrategy()

        # Store the connection settings before super().__init__(), which invokes
        # define_inference_network() and therefore needs them.
        self.triton_url = triton_url
        self.model_name = model_name
        self.model_version = model_version
        self.is_shared_memory = is_shared_memory
        super().__init__()

    def define_inference_network(self) -> Callable:
        # Shared system memory is the more memory-efficient option when Triton runs on the same host.
        if self.is_shared_memory:
            return TritonInferenceModelSharedSystemMemory(
                self.triton_url, self.model_name, self.model_version
            )
        return TritonInferenceModel(self.triton_url, self.model_name, self.model_version)

    def set_categories(self) -> List[Dict[str, Any]]:
        # One entry per output channel, in channel order.
        categories = [
            {"name": "Background", "display": False},
            {"name": "Foreground 1", "display": True},
            {"name": "Foreground 2", "display": True},
        ]
        return categories

    def set_regressions(self) -> None:
        # This model has no regression outputs.
        return None

    def data_reader(self, files: Sequence[str], **kwargs) -> Tuple[np.ndarray, None, None]:
        img = cv2.imread(files[0])  # h, w, c (BGR, uint8)
        if img is None:
            raise ValueError(f"Failed to read image: {files[0]}")
        return img, None, None

    def preprocess(self, data: np.ndarray) -> np.ndarray:
        data = cv2.resize(data, (256, 256), interpolation=cv2.INTER_NEAREST)
        data = np.transpose(data, (2, 0, 1))  # h, w, c -> c, h, w
        data = np.expand_dims(data, axis=0)  # c, h, w -> 1, c, h, w (batch dimension)
        return data.astype(np.float32)  # the model input is TYPE_FP32

    def inference(self, x: np.ndarray) -> Dict[str, np.ndarray]:
        # The key must match the input name reported by model_state.
        return self.network.forward({"input": x})

    def postprocess(
        self, out_dict: Dict[str, np.ndarray], metadata: Optional[Any] = None
    ) -> np.ndarray:
        out = out_dict["logits"][0]  # 1, c, h, w -> c, h, w
        onehot_out = np.zeros_like(out, dtype=np.int8)
        out = np.argmax(out, axis=0)  # c, h, w -> h, w (winning channel per pixel)
        for i in range(len(onehot_out)):
            onehot_out[i] = out == i  # binary mask for channel i
        return onehot_out

    def output_formatter(
        self,
        model_out: np.ndarray,
        images: Sequence[AiImage],
        categories: Sequence[Dict[str, Any]],
        **kwargs
    ) -> BaseAiCOCOImageOutputDataModel:
        output = self.formatter(model_out=model_out, images=images, categories=categories)
        return output
```
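The loop in `postprocess` builds the one-hot mask one channel at a time. If that ever becomes a bottleneck, NumPy broadcasting produces the same mask in a single expression. The sketch below checks the two approaches against each other on random logits; `onehot_loop` and `onehot_vectorized` are illustrative names, not FLaVor functions, and the logits are synthetic rather than real model output.

```python
import numpy as np


def onehot_loop(logits: np.ndarray) -> np.ndarray:
    """The loop from postprocess, for a single (C, H, W) logits array."""
    onehot = np.zeros_like(logits, dtype=np.int8)
    labels = np.argmax(logits, axis=0)  # H, W
    for i in range(len(onehot)):
        onehot[i] = labels == i
    return onehot


def onehot_vectorized(logits: np.ndarray) -> np.ndarray:
    """Same result without a Python loop: compare the label map with every channel index."""
    labels = np.argmax(logits, axis=0)  # H, W
    channels = np.arange(logits.shape[0])[:, None, None]  # C, 1, 1
    return (labels[None] == channels).astype(np.int8)  # broadcast to C, H, W


rng = np.random.default_rng(0)
logits = rng.standard_normal((3, 64, 64)).astype(np.float32)  # synthetic logits
assert np.array_equal(onehot_loop(logits), onehot_vectorized(logits))
print(onehot_vectorized(logits).sum(axis=0).min())  # -> 1 (each pixel is in exactly one channel)
```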

### Integration with InferAPP
We can integrate the inference model defined above with FLaVor `InferAPP`, a FastAPI application. To initialize the application, define `input_data_model` and `output_data_model`, which are the standard input and output structures of the service, and provide `infer_function` as the main inference operation. Once the service starts, the `/invocations` API endpoint is available to process inference requests. We encourage users to implement a stand-alone Python script based on this Jupyter notebook tutorial.

```python
# This block is only needed in Jupyter notebooks, not in a stand-alone script.
import nest_asyncio
nest_asyncio.apply()
```
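The reason for `nest_asyncio` is that a Jupyter kernel executes cells while its own asyncio event loop is already running, and `asyncio.run()` refuses to start a second loop in that situation. The standard-library sketch below reproduces the error; it assumes, without checking FLaVor's internals, that `app.run()` starts its server through an asyncio event loop. `serve_stub` and `notebook_cell` are hypothetical names. `nest_asyncio.apply()` patches asyncio so that this kind of nested call is allowed.

```python
import asyncio


async def serve_stub() -> str:
    # Hypothetical stand-in for the server coroutine that app.run() may start.
    return "served"


async def notebook_cell() -> str:
    # Runs inside an active event loop, like code in a Jupyter cell.
    coro = serve_stub()
    try:
        return asyncio.run(coro)
    except RuntimeError as exc:
        coro.close()  # avoid a "coroutine was never awaited" warning
        return f"RuntimeError: {exc}"


result = asyncio.run(notebook_cell())
print(result)  # -> RuntimeError: asyncio.run() cannot be called from a running event loop
```

In a stand-alone script there is no outer loop, so the server can start its own and the patch is unnecessary.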

#### Initiate the service

```python
app = InferAPP(
    # Point triton_url at your own Triton Inference Server host.
    infer_function=SegmentationTritonInferenceModel(
        triton_url="triton.user-hannchyun-chen:8000", model_name="toyseg"
    ),
    input_data_model=BaseAiCOCOImageInputDataModel,
    output_data_model=BaseAiCOCOImageOutputDataModel,
)
```

```python
# Serve on the port given by the PORT environment variable, defaulting to 9111.
app.run(port=int(os.getenv("PORT", 9111)))
```

### Send request
We can send a request to the running server with `send_request.py`, which opens the input file and the corresponding JSON file and sends them as form data. The response is expected to be in AiCOCO format.
```bash
# pwd: examples/inference
python send_request.py -f test_data/seg/300.png -d test_data/seg/input_seg.json
```
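`send_request.py` is not shown in this guide, so the sketch below only illustrates what sending files "as form data" involves, using the Python standard library. The helper `build_multipart`, the form field names `files` and `data`, and the JSON payload are hypothetical placeholders; check `send_request.py` for the field names and payload the service actually expects. The port `9111` and the `/invocations` path come from this guide. The request is built but not sent.

```python
import json
import uuid
from typing import Dict, List, Tuple
from urllib.request import Request


def build_multipart(
    fields: Dict[str, str], files: List[Tuple[str, str, bytes]]
) -> Tuple[bytes, str]:
    """Encode text fields and (field_name, filename, content) files as multipart/form-data."""
    boundary = uuid.uuid4().hex
    lines: List[bytes] = []
    for name, value in fields.items():
        lines += [
            f"--{boundary}".encode(),
            f'Content-Disposition: form-data; name="{name}"'.encode(),
            b"",
            value.encode(),
        ]
    for name, filename, content in files:
        lines += [
            f"--{boundary}".encode(),
            f'Content-Disposition: form-data; name="{name}"; filename="{filename}"'.encode(),
            b"Content-Type: application/octet-stream",
            b"",
            content,
        ]
    lines += [f"--{boundary}--".encode(), b""]  # closing delimiter plus trailing CRLF
    return b"\r\n".join(lines), f"multipart/form-data; boundary={boundary}"


body, content_type = build_multipart(
    fields={"data": json.dumps({"images": []})},  # hypothetical field name and payload
    files=[("files", "300.png", b"\x89PNG placeholder bytes")],  # hypothetical field name
)
req = Request(
    "http://localhost:9111/invocations",
    data=body,
    headers={"Content-Type": content_type},
    method="POST",
)
# urllib.request.urlopen(req) would send it to a running server.
```

In practice a higher-level HTTP client handles this encoding for you; spelling it out shows that each file and the JSON travel as separate named parts of one POST body.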

## Set Up the Dockerfile
To interact with other services, the inference model must be wrapped in a Docker container. Here is an example Dockerfile.

```dockerfile
# Ubuntu 22.04 ships Python 3.10, matching the python >= 3.10 prerequisite
FROM nvidia/cuda:12.2.2-runtime-ubuntu22.04

RUN apt-get update \
    && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
        python3 \
        python3-pip \
    && ln -sf /usr/bin/python3 /usr/bin/python \
    && rm -rf /var/lib/apt/lists/*

RUN python -m pip install -U https://github.com/ailabstw/FLaVor/archive/refs/heads/release/stable.zip && python -m pip install flavor

WORKDIR /app

COPY your_script.py /app/

CMD ["python", "your_script.py"]
```