-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add example deployment scripts (#553)
- Loading branch information
Showing
10 changed files
with
1,686 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
# Model Deployment with BentoML and Triton Inference Server | ||
|
||
1. Install the required dependencies with poetry. | ||
```bash | ||
poetry install --with deploy | ||
``` | ||
2. Serialize trained model and move it to the `model_repo` directory. Then create | ||
a `config.pbtxt` file for the model. | ||
|
||
**Example - torchxrayvision model** | ||
```python | ||
import torch | ||
import torchxrayvision as xrv | ||
|
||
|
||
model = xrv.models.ResNet(weights="resnet50-res512-all").eval().cuda() | ||
|
||
dummy_input = (-1024 - 1024) * torch.rand(1, 1, 512, 512) + 1024 | ||
dummy_input = dummy_input.cuda() | ||
|
||
torch.jit.trace(model, dummy_input).save("model_repo/resnet50_res512_all/1/model.pt") | ||
``` | ||
See `model_repo/resnet50_res512_all/config.pbtxt` for an example of a pytorch model configuration file. | ||
|
||
**Example - sklearn model** | ||
```python | ||
from skl2onnx import to_onnx | ||
|
||
|
||
onnx_model = to_onnx( | ||
<sklearn_model>, | ||
<input_data>, | ||
options={"zipmap": False}, | ||
) | ||
with open("model_repo/<model_name>/1/model.onnx", "wb") as f: | ||
f.write(onnx_model.SerializeToString()) | ||
``` | ||
See `model_repo/heart_failure_prediction/config.pbtxt` for an example of an ONNX model configuration file. | ||
3. Create a service with BentoML with a triton runner. See `service.py` for an example. | ||
4. Define a bentofile to specify which files to include in the bento. See `bentofile.yaml` for an example. | ||
5. Build a bento. | ||
```bash | ||
bentoml build --do-not-track | ||
``` | ||
6. Containerize the bento. | ||
```bash | ||
bentoml containerize -t model-service:alpha --enable-features=triton --do-not-track model-service:latest | ||
``` | ||
|
||
7. Run the container with docker. | ||
```bash | ||
docker run -d --gpus=1 --rm -p 3000:3000 model-service:alpha | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
service: service:svc | ||
include: | ||
- /model_repo | ||
- /*.py | ||
exclude: | ||
- /__pycache__ | ||
python: | ||
packages: | ||
- bentoml[triton] | ||
- torchxrayvision==1.2.1 | ||
docker: | ||
base_image: nvcr.io/nvidia/tritonserver:24.01-py3 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
monitoring: | ||
enabled: true | ||
type: default | ||
options: | ||
output_dir: ./monitoring |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
backend: "pytorch" | ||
name: "densenet121_res224_all" | ||
max_batch_size: 64 | ||
dynamic_batching { | ||
max_queue_delay_microseconds: 100 | ||
} | ||
input { | ||
name: "INPUT__0" | ||
data_type: TYPE_FP32 | ||
dims: 1 | ||
dims: 224 | ||
dims: 224 | ||
} | ||
output { | ||
name: "OUTPUT__0" | ||
data_type: TYPE_FP32 | ||
dims: -1 | ||
dims: 18 | ||
} | ||
instance_group [ | ||
{ | ||
count: 1 | ||
kind: KIND_GPU | ||
gpus: [0] | ||
} | ||
] | ||
model_warmup [ | ||
{ | ||
name : "random sample" | ||
count: 1 | ||
batch_size: 1 | ||
inputs { | ||
key: "INPUT__0" | ||
value: { | ||
data_type: TYPE_FP32 | ||
dims: [1, 224, 224] | ||
random_data: true | ||
} | ||
} | ||
} | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
backend: "onnxruntime" | ||
name: "heart_failure_prediction" | ||
max_batch_size: 0 | ||
input { | ||
name: "X" | ||
data_type: TYPE_FP32 | ||
dims: [-1, 21] | ||
} | ||
output { | ||
name: "label" | ||
data_type: TYPE_INT64 | ||
dims: -1 | ||
} | ||
instance_group [ | ||
{ | ||
count: 1 | ||
kind: KIND_CPU | ||
} | ||
] | ||
optimization { execution_accelerators { | ||
cpu_execution_accelerator : [ { | ||
name : "openvino" | ||
}] | ||
}} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
backend: "pytorch" | ||
name: "resnet50_res512_all" | ||
max_batch_size: 32 | ||
dynamic_batching { | ||
max_queue_delay_microseconds: 100 | ||
} | ||
input { | ||
name: "INPUT__0" | ||
data_type: TYPE_FP32 | ||
dims: 1 | ||
dims: 512 | ||
dims: 512 | ||
} | ||
output { | ||
name: "OUTPUT__0" | ||
data_type: TYPE_FP32 | ||
dims: -1 | ||
dims: 18 | ||
} | ||
instance_group [ | ||
{ | ||
count: 1 | ||
kind: KIND_GPU | ||
gpus: [0] | ||
} | ||
] | ||
model_warmup [{ | ||
name : "random sample" | ||
batch_size: 1 | ||
inputs { | ||
key: "INPUT__0" | ||
value: { | ||
data_type: TYPE_FP32 | ||
dims: [1, 512, 512] | ||
random_data: true | ||
} | ||
} | ||
}] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
"""Model serving service with Triton Inference Server as backend.""" | ||
|
||
from __future__ import annotations | ||
|
||
from typing import TYPE_CHECKING, Any, Literal | ||
|
||
import bentoml | ||
import numpy as np | ||
import torchxrayvision as xrv | ||
from torchvision import transforms | ||
|
||
|
||
if TYPE_CHECKING: | ||
from PIL.Image import Image | ||
|
||
|
||
def get_transform(image_size: int) -> transforms.Compose: | ||
"""Get image transformation for model inference.""" | ||
return transforms.Compose( | ||
[ | ||
xrv.datasets.XRayCenterCrop(), | ||
xrv.datasets.XRayResizer(image_size), | ||
], | ||
) | ||
|
||
|
||
triton_runner = bentoml.triton.Runner( | ||
"triton_runner", | ||
"src/model_repo", | ||
tritonserver_type="http", | ||
cli_args=[ | ||
"--exit-on-error=true", # exits if any error occurs during initialization | ||
"--http-restricted-api=model-repository:access-key=admin", # restrict access to load/unload APIs | ||
"--model-control-mode=explicit", # enable explicit model loading/unloading | ||
"--load-model=resnet50_res512_all", | ||
], | ||
) | ||
svc = bentoml.Service("model-service", runners=[triton_runner]) | ||
|
||
|
||
@svc.api( # type: ignore | ||
input=bentoml.io.Multipart(im=bentoml.io.Image(), model_name=bentoml.io.Text()), | ||
output=bentoml.io.JSON(), | ||
) | ||
async def classify_xray(im: Image, model_name: str) -> dict[str, float]: | ||
"""Classify X-ray image using specified model.""" | ||
img = np.asarray(im) | ||
img = xrv.datasets.normalize( | ||
img, | ||
img.max(), | ||
reshape=True, # normalize image to [-1024, 1024] | ||
) | ||
|
||
model_repo_index = await triton_runner.get_model_repository_index() | ||
available_models = [model["name"] for model in model_repo_index] | ||
if model_name not in available_models: | ||
raise bentoml.exceptions.InvalidArgument( | ||
f"Expected model name to be one of {available_models}, but got {model_name}", | ||
) | ||
|
||
img_size = 224 | ||
if "resnet" in model_name: | ||
img_size = 512 | ||
|
||
img = get_transform(img_size)(img) | ||
|
||
if len(img.shape) == 3: | ||
img = img[None] # add batch dimension | ||
|
||
InferResult = await getattr(triton_runner, model_name).async_run(img) # noqa: N806 | ||
return dict( | ||
zip(xrv.datasets.default_pathologies, InferResult.as_numpy("OUTPUT__0")[0]), | ||
) | ||
|
||
|
||
@svc.api( # type: ignore | ||
input=bentoml.io.NumpyNdarray(dtype="float32", shape=(-1, 21)), | ||
output=bentoml.io.NumpyNdarray(dtype="int64", shape=(-1,)), | ||
) | ||
async def predict_heart_failure(X: np.ndarray) -> np.ndarray: # type: ignore | ||
"""Run inference on heart failure prediction model.""" | ||
InferResult = await triton_runner.heart_failure_prediction.async_run( # noqa: N806 | ||
X, | ||
) | ||
return InferResult.as_numpy("label") # type: ignore[no-any-return] | ||
|
||
|
||
# Triton Model management API | ||
@svc.api(input=bentoml.io.JSON(), output=bentoml.io.JSON()) # type: ignore | ||
async def model_config(input_model: dict[Literal["model_name"], str]) -> dict[str, Any]: | ||
"""Retrieve model configuration from Triton Inference Server.""" | ||
return await triton_runner.get_model_config(input_model["model_name"]) # type: ignore | ||
|
||
|
||
@svc.api(input=bentoml.io.Text(), output=bentoml.io.JSON()) # type: ignore | ||
async def unload_model(input_model: str, ctx: bentoml.Context) -> dict[str, str]: | ||
"""Unload a model from memory.""" | ||
await triton_runner.unload_model( | ||
input_model, | ||
headers=ctx.request.headers, | ||
) # noqa: E501 | ||
return {"unloaded": input_model} | ||
|
||
|
||
@svc.api(input=bentoml.io.Text(), output=bentoml.io.JSON()) # type: ignore | ||
async def load_model(input_model: str, ctx: bentoml.Context) -> dict[str, str]: | ||
"""Load a model into memory.""" | ||
await triton_runner.load_model(input_model, headers=ctx.request.headers) | ||
return {"loaded": input_model} | ||
|
||
|
||
@svc.api(input=bentoml.io.Text(), output=bentoml.io.JSON()) # type: ignore | ||
async def list_models(_: str, ctx: bentoml.Context) -> list[str]: | ||
"""Return a list of models available in the model repository.""" | ||
return await triton_runner.get_model_repository_index(headers=ctx.request.headers) # type: ignore |
Oops, something went wrong.