In [None]:
#| default_exp grpc

In [None]:
#| hide
from nbdev.showdoc import *

# gRPC Client

In [None]:
#| export
import os
import grpc
import torchserve_client.proto.inference_pb2 as inference_pb2
import torchserve_client.proto.inference_pb2_grpc as inference_pb2_grpc
import torchserve_client.proto.management_pb2 as management_pb2
import torchserve_client.proto.management_pb2_grpc as management_pb2_grpc

In [None]:
#| exporti
class BaseClient:
    """Shared base for the gRPC management and inference clients.

    Resolves the TorchServe host and stores it, without any ``scheme://`` prefix,
    in ``self.base_url``; subclasses append ``:<port>`` to build their gRPC target.
    """

    def __init__(self, base_url=None):
        """Resolve the TorchServe host.

        Args:
            base_url: Server URL such as ``"http://localhost"``. When falsy, the
                ``TORCHSERVE_URL`` environment variable is used, falling back to
                ``"http://localhost"``. The scheme is optional.
        """
        base_url = base_url or os.environ.get('TORCHSERVE_URL', 'http://localhost')
        # Drop an optional "scheme://" prefix (a bare host no longer raises IndexError)
        # and any trailing slash, so subclasses can safely append ":<port>".
        self.base_url = base_url.split('//', 1)[-1].rstrip('/')

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(base_url={self.base_url})"

    def _filter_none_values(self, data):
        """Return a copy of ``data`` without the entries whose value is ``None``."""
        return {key: value for key, value in data.items() if value is not None}

In [None]:
#| exporti
class ManagementClient(BaseClient):
    """gRPC client for TorchServe's management API (``ManagementAPIsServiceStub``).

    Each RPC wrapper builds the matching ``management_pb2`` request from the
    arguments that are not ``None`` and returns the stub's response as-is.
    """

    def __init__(self, base_url=None, port=7071):
        """Open an insecure channel to ``<host>:<port>`` and create the stub.

        Args:
            base_url: Server URL, resolved by ``BaseClient``.
            port: Management gRPC port (default 7071).
        """
        super().__init__(base_url)
        self.port = port
        self.base_url = f"{self.base_url}:{self.port}"
        self.channel = grpc.insecure_channel(self.base_url)
        self.stub = management_pb2_grpc.ManagementAPIsServiceStub(self.channel)

    def close(self):
        """Close the underlying gRPC channel."""
        self.channel.close()

    def _build_request(self, message_cls, **fields):
        """Instantiate ``message_cls`` with only the fields whose value is not ``None``."""
        # Omitting None arguments leaves the corresponding proto fields unset.
        return message_cls(**self._filter_none_values(fields))

    def describe_model(self, model_name, model_version=None, customized=None):
        """Describe ``model_name``; ``model_version`` and ``customized`` are sent when given."""
        # Previously only model_name was forwarded; the optional arguments were dropped.
        request = self._build_request(
            management_pb2.DescribeModelRequest,
            model_name=model_name,
            model_version=model_version,
            customized=customized,
        )
        return self.stub.DescribeModel(request)

    def list_models(self, limit=None, next_page_token=None):
        """List registered models, optionally paginated with ``limit``/``next_page_token``."""
        request = self._build_request(
            management_pb2.ListModelsRequest,
            limit=limit,
            next_page_token=next_page_token,
        )
        return self.stub.ListModels(request)

    def register_model(
        self,
        batch_size=None,
        handler=None,
        initial_workers=None,
        max_batch_delay=None,
        model_name=None,
        response_timeout=None,
        runtime=None,
        synchronous=None,
        url=None,
        s3_sse_kms=None,
    ):
        """Register a model archive (``url``) with the server."""
        request = self._build_request(
            management_pb2.RegisterModelRequest,
            batch_size=batch_size,
            handler=handler,
            initial_workers=initial_workers,
            max_batch_delay=max_batch_delay,
            model_name=model_name,
            response_timeout=response_timeout,
            runtime=runtime,
            synchronous=synchronous,
            url=url,
            s3_sse_kms=s3_sse_kms,
        )
        return self.stub.RegisterModel(request)

    def scale_worker(
        self,
        model_name,
        min_worker=None,
        max_worker=None,
        model_version=None,
        number_gpu=None,
        synchronous=None,
        timeout=None,
    ):
        """Configure the number of workers serving ``model_name``."""
        request = self._build_request(
            management_pb2.ScaleWorkerRequest,
            model_name=model_name,
            min_worker=min_worker,
            max_worker=max_worker,
            model_version=model_version,
            number_gpu=number_gpu,
            synchronous=synchronous,
            timeout=timeout,
        )
        return self.stub.ScaleWorker(request)

    def set_default(self, model_name, model_version):
        """Make ``model_version`` the default version of ``model_name``."""
        request = self._build_request(
            management_pb2.SetDefaultRequest,
            model_name=model_name,
            model_version=model_version,
        )
        return self.stub.SetDefault(request)

    def unregister_model(self, model_name, model_version=None):
        """Unregister ``model_name`` (a specific ``model_version`` when given)."""
        request = self._build_request(
            management_pb2.UnregisterModelRequest,
            model_name=model_name,
            model_version=model_version,
        )
        return self.stub.UnregisterModel(request)

In [None]:
#| exporti
class InferenceClient(BaseClient):
    """gRPC client for TorchServe's inference API (``InferenceAPIsServiceStub``)."""

    def __init__(self, base_url=None, port=7070):
        """Open an insecure channel to ``<host>:<port>`` and create the stub.

        Args:
            base_url: Server URL, resolved by ``BaseClient``.
            port: Inference gRPC port (default 7070).
        """
        super().__init__(base_url)
        self.port = port
        self.base_url = f"{self.base_url}:{self.port}"
        self.channel = grpc.insecure_channel(self.base_url)
        self.stub = inference_pb2_grpc.InferenceAPIsServiceStub(self.channel)

    def close(self):
        """Close the underlying gRPC channel."""
        self.channel.close()

    def ping(self):
        """Query server health; the response's ``health`` field carries the status."""
        # inference.proto declares ``rpc Ping(google.protobuf.Empty)``, and a gRPC stub
        # call requires its request argument, so ``self.stub.Ping()`` raised TypeError.
        # protobuf is already a hard runtime dependency of the generated *_pb2 modules.
        from google.protobuf import empty_pb2

        return self.stub.Ping(empty_pb2.Empty())

    def _predictions_request(self, model_name, input_data, model_version):
        """Build a ``PredictionsRequest``, omitting arguments that are ``None``."""
        # The proto field is named ``input``; passing ``input_data=`` to the message
        # constructor raised ValueError, so map the argument explicitly.
        fields = {"model_name": model_name, "model_version": model_version, "input": input_data}
        return inference_pb2.PredictionsRequest(**self._filter_none_values(fields))

    def predictions(self, model_name, input_data, model_version=None):
        """Run a single prediction.

        Args:
            model_name: Name of the model.
            input_data: Mapping of input names to bytes, e.g. ``{"data": data}``.
            model_version: Version to use; the server's default when ``None``.

        Returns:
            The prediction response (its ``prediction`` field holds the raw bytes).
        """
        request = self._predictions_request(model_name, input_data, model_version)
        return self.stub.Predictions(request)

    def stream_predictions(self, model_name, input_data, model_version=None):
        """Run a streaming prediction; arguments are the same as for ``predictions``.

        Returns:
            The result of the ``StreamPredictions`` stub call, which is presumably an
            iterator of prediction responses (server-streaming RPC).
        """
        request = self._predictions_request(model_name, input_data, model_version)
        return self.stub.StreamPredictions(request)

In [None]:
#| export
class TorchServeClientGRPC:
    """Facade bundling the gRPC management and inference clients for one server.

    Args:
        base_url: Server URL shared by both sub-clients; when ``None`` the host comes
            from ``TORCHSERVE_URL`` (or ``http://localhost``).
        management_port: Port of the gRPC management API.
        inference_port: Port of the gRPC inference API.
    """

    def __init__(self, base_url=None, management_port=7071, inference_port=7070):
        self.management = ManagementClient(base_url, management_port)
        self.inference = InferenceClient(base_url, inference_port)

    def __repr__(self):
        mgmt = self.management
        infer = self.inference
        # management.base_url is "<host>:<port>"; drop the trailing port to show the host.
        host = mgmt.base_url.rsplit(':', 1)[0]
        parts = [
            f"base_url={host}",
            f"management_port={mgmt.port}",
            f"inference_port={infer.port}",
        ]
        return "TorchServeClientGRPC(" + ", ".join(parts) + ")"

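To create a gRPC client, instantiate `TorchServeClientGRPC`. With no arguments, it uses the host from the `TORCHSERVE_URL` environment variable (default `http://localhost`) and the default ports.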

In [None]:
#|eval: false
# Build a client with the defaults: host from TORCHSERVE_URL (or http://localhost),
# management port 7071 and inference port 7070
ts_client = TorchServeClientGRPC()
ts_client

TorchServeClientGRPC(base_url=localhost, management_port=7071, inference_port=7070)

To customize the base URL or the default ports, pass them as arguments during initialization:

In [None]:
#|eval: false
# Point the client at a specific server and set both gRPC ports explicitly
ts_client = TorchServeClientGRPC(
    base_url='http://your-torchserve-server.com',
    management_port=7071,
    inference_port=7070,
)
ts_client

TorchServeClientGRPC(base_url=your-torchserve-server.com, management_port=7071, inference_port=7070)

## Management APIs

Here is the list of all the supported gRPC management endpoints:

- `describe_model`: Provide detailed information about the default version of a model

    **Arguments**:

    - `model_name` (str, required): Name of the model to describe

    - `model_version` (str, optional): Version of the model to describe

    - `customized` (bool, optional): Customized metadata

    **Usage**:

    ```python
    response = ts_client.management.describe_model(model_name="mnist")
    response.msg
    ```


- `list_models`: List all registered models in TorchServe

    **Arguments**:

    - `limit` (int, optional): Maximum number of items to return (default: 100).

    - `next_page_token` (int, optional): Token to retrieve the next set of results

    **Usage**:
    
    ```python
    response = ts_client.management.list_models()
    response.msg
    ```

- `register_model`: Register a new model with TorchServe

    **Arguments**:
    
    - `batch_size` (int, optional): Inference batch size (default: 1).
    
    - `handler` (str, optional): Inference handler entry-point.
    
    - `initial_workers` (int, optional): Number of initial workers (default: 0).
    
    - `max_batch_delay` (int, optional): Maximum delay for batch aggregation (default: 100).
    
    - `model_name` (str, optional): Name of the model.
    
    - `response_timeout` (int, optional): Maximum time for model response (default: 120 seconds).
    
    - `runtime` (str, optional): Runtime for model custom service code.
    
    - `synchronous` (bool, optional): Synchronous worker creation (default: False).
    
    - `url` (str, required): Model archive download URL.
    
    - `s3_sse_kms` (bool, optional): S3 SSE KMS enabled (default: False).

    **Usage**:

    ```python
    response = ts_client.management.register_model(url="mnist.mar", model_name="mnist")
    response.msg
    ```    

- `scale_worker`: Configure the number of workers for a model. This is an asynchronous call by default.

    **Arguments**:

    - `model_name` (str, required): Name of the model to scale workers.
    
    - `model_version` (str, optional): Model version.
    
    - `max_worker` (int, optional): Maximum number of worker processes.
    
    - `min_worker` (int, optional): Minimum number of worker processes.
    
    - `number_gpu` (int, optional): Number of GPU worker processes to create.
    
    - `synchronous` (bool, optional): Synchronous call (default: False).
    
    - `timeout` (int, optional): Wait time for worker completion (0: terminate immediately, -1: wait infinitely).

    **Usage**:

    ```python
    response = ts_client.management.scale_worker(model_name="mnist", min_worker=2)
    response.msg
    ```    


- `set_default`: Set the default version of a model

    **Arguments**:

    - `model_name` (str, required): Name of the model for which the default version should be updated
    
    - `model_version` (str, required): Version of the model to set as the default version

    **Usage**:

    ```python
    response = ts_client.management.set_default(model_name="mnist", model_version="2.0")
    response.msg
    ```


- `unregister_model`: Unregister a particular version of a model from TorchServe. This call is asynchronous by default.
    
    **Arguments**:
    
    - `model_name` (str, required): Name of the model to unregister.
    
    - `model_version` (str, optional): Version of the model to unregister. If none, then default version of the model will be unregistered.

    **Usage**:

    ```python
    response = ts_client.management.unregister_model(model_name="mnist")
    response.msg
    ```    

Check the [`management.proto`](https://github.com/pytorch/serve/blob/master/frontend/server/src/main/resources/proto/management.proto) file to better understand the arguments of each method.

## Inference APIs

Here is the list of all the supported gRPC inference endpoints:

- `ping`: Check the server's health status

    **Usage**:

    ```python
    response = ts_client.inference.ping()
    response.health
    ```

- `predictions`: Get predictions
    
    **Arguments**:

    - `model_name` (str, required): Name of the model.

    - `model_version` (str, optional): Version of the model. If not provided, default version will be used.

    - `input_data` (Dict[str, bytes], required): Input data for model prediction, e.g. `{"data": data}`

    **Usage**:
    
    ```python
    response = ts_client.inference.predictions(model_name="mnist", input_data={"data": data})
    response.prediction.decode("utf-8")
    ```

- `stream_predictions`: Get streaming predictions

    **Arguments**:
    
    - `model_name` (str, required): Name of the model.
    
    - `model_version` (str, optional): Version of the model. If not provided, default version will be used.
    
    - `input_data` (Dict[str, bytes], required): Input data for model prediction, e.g. `{"data": data}`

    **Usage**:

    ```python
    responses = ts_client.inference.stream_predictions(model_name="mnist", input_data={"data": data})
    for response in responses:
        print(response.prediction.decode("utf-8"))
    ```

Again, for more detail about gRPC request and response objects, refer to [`inference.proto`](https://github.com/pytorch/serve/blob/master/frontend/server/src/main/resources/proto/inference.proto).

In [None]:
#| hide
# Export the `#| export` / `#| exporti` cells of this notebook to the library module
import nbdev
nbdev.nbdev_export()