In [None]:
#| default_exp rest

In [None]:
#| hide
from nbdev.showdoc import *

# REST Client

In [None]:
#| exporti
import os
import requests

In [None]:
#| exporti
class BaseClient:
    """Small HTTP helper shared by the TorchServe REST clients.

    Keeps a normalised base URL (no trailing slash) and provides
    `_make_request`, which sends a request through `requests` and returns the
    decoded JSON body.
    """

    def __init__(self, base_url=None):
        """
        base_url: str, optional. Root URL of the server. When omitted or empty,
            the `TORCHSERVE_URL` environment variable is used; when that is
            unset or empty as well, `http://localhost` is used.
        """
        # Chain with `or` instead of relying on the `.get` default: an
        # environment variable that is set but empty must also fall back to
        # localhost, otherwise every request URL would start with ":<port>".
        base_url = base_url or os.environ.get('TORCHSERVE_URL') or 'http://localhost'
        # Endpoints passed to `_make_request` start with "/", so drop any
        # trailing slash here to avoid "//" in the final URL.
        self.base_url = base_url.rstrip('/')

    def _make_request(self, method, endpoint, json=None, params=None, files=None):
        """
        Send an HTTP request to `base_url + endpoint` and return the parsed JSON response.

        method: str. The HTTP verb, e.g. 'GET' or 'POST'.
        endpoint: str. Path appended verbatim to `self.base_url`.
        json: dict. The JSON body of the request.
        params: dict. The URL parameters of the request.
        files: [dict]. The files to upload.

        `response.raise_for_status()` is called before decoding, so an HTTP
        error status raises instead of returning a body.
        """
        url = f"{self.base_url}{endpoint}"
        response = requests.request(method, url, json=json, params=params, files=files)
        response.raise_for_status()
        return response.json()

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(base_url={self.base_url})"

    def _filter_none_values(self, data):
        """Return a copy of `data` without the entries whose value is None (0 and False are kept)."""
        return {key: value for key, value in data.items() if value is not None}

In [None]:
#| exporti
class ManagementClient(BaseClient):
    """Client for the TorchServe Management API (port 8081 unless overridden)."""

    def __init__(self, base_url=None, port=8081):
        """
        base_url: str, optional. Server root URL without a port (resolved by `BaseClient`).
        port: int. Port appended to the base URL.
        """
        super().__init__(base_url)
        self.port = port
        # From here on `base_url` carries the port, e.g. "http://localhost:8081".
        self.base_url = f"{self.base_url}:{self.port}"

    @staticmethod
    def _model_endpoint(model_name, version=None):
        """Return `/models/{model_name}`, with `/{version}` appended when `version` is truthy."""
        if version:
            return f"/models/{model_name}/{version}"
        return f"/models/{model_name}"

    def register_model(self, url, model_name=None, handler=None, runtime=None,
                       batch_size=1, max_batch_delay=100, initial_workers=0,
                       synchronous=False, response_timeout=120):
        """
        Register a model by POSTing the options below to `/models` as a JSON body.

        url: str. Location of the model archive to register.
        model_name, handler, runtime: optional; left out of the request when None.
        batch_size, max_batch_delay, initial_workers, synchronous,
        response_timeout: forwarded to TorchServe unchanged.

        Returns the decoded JSON response.
        """
        data = {
            'url': url,
            'model_name': model_name,
            'handler': handler,
            'runtime': runtime,
            'batch_size': batch_size,
            'max_batch_delay': max_batch_delay,
            'initial_workers': initial_workers,
            'synchronous': synchronous,
            'response_timeout': response_timeout
        }
        data = self._filter_none_values(data)
        return self._make_request('POST', '/models', json=data)

    def scale_workers(self, model_name, version=None, min_worker=1, max_worker=None,
                      synchronous=False, timeout=-1):
        """
        Change the number of workers of a model via PUT `/models/{model_name}[/{version}]`.

        model_name: str. Name of the registered model.
        version: optional. Model version to scale; omitted from the URL when falsy.
        min_worker: int. Minimum number of workers.
        max_worker: int, optional. Maximum number of workers; defaults to `min_worker`.
        synchronous: bool. Forwarded to TorchServe unchanged.
        timeout: int. Forwarded to TorchServe unchanged.

        Returns the decoded JSON response.
        """
        params = {
            'min_worker': min_worker,
            'max_worker': max_worker if max_worker is not None else min_worker,
            'synchronous': synchronous,
            'timeout': timeout
        }

        # The scaling options are documented by TorchServe as query-string
        # parameters of this endpoint, so they must travel as URL parameters;
        # sent as a JSON body they are not picked up and the server falls back
        # to its defaults.
        return self._make_request('PUT', self._model_endpoint(model_name, version), params=params)

    def describe_model(self, model_name, version=None, customized=False):
        """
        Return the model description via GET `/models/{model_name}[/{version}]`.

        version: optional. Pass `all` to get the status of every version of the
            model. When omitted, no version is put in the URL and TorchServe
            chooses which version to describe (its default version, per the
            Management API documentation).
        customized: bool. When truthy, `customized` is added as a query
            parameter; otherwise no query parameters are sent.
        """
        params = {}
        if customized:
            params['customized'] = customized

        return self._make_request('GET', self._model_endpoint(model_name, version), params=params)

    def unregister_model(self, model_name, version=None):
        """Unregister a model via DELETE `/models/{model_name}[/{version}]` and return the JSON response."""
        return self._make_request('DELETE', self._model_endpoint(model_name, version))

    def list_models(self, limit=100, next_page_token=None):
        """
        List registered models via GET `/models`.

        limit: int. Sent as the `limit` query parameter.
        next_page_token: optional. Sent as `next_page_token`; omitted when None.
        """
        params = {
            'limit': limit,
            'next_page_token': next_page_token
        }
        params = self._filter_none_values(params)
        return self._make_request('GET', '/models', params=params)

    def api_description(self):
        """Send an OPTIONS request to `/` and return the JSON response."""
        return self._make_request('OPTIONS', '/')

    def set_default_version(self, model_name, version):
        """Make `version` the default version of `model_name` via PUT `/models/{model_name}/{version}/set-default`."""
        endpoint = f"/models/{model_name}/{version}/set-default"
        return self._make_request('PUT', endpoint)

In [None]:
#| exporti
class InferenceClient(BaseClient):
    """Client for the TorchServe Inference API (port 8080 unless overridden)."""

    def __init__(self, base_url=None, port=8080):
        """
        base_url: str, optional. Server root URL without a port (resolved by `BaseClient`).
        port: int. Port appended to the base URL.
        """
        super().__init__(base_url)
        self.port = port
        # From here on `base_url` carries the port, e.g. "http://localhost:8080".
        self.base_url = f"{self.base_url}:{self.port}"

    def api_description(self):
        """Send an OPTIONS request to `/` and return the JSON response."""
        return self._make_request('OPTIONS', '/')

    def health_check(self):
        """Send GET `/ping` and return the JSON response."""
        return self._make_request('GET', '/ping')

    def prediction(self, model_name, data, version=None):
        """
        POST `data` as a file upload to `/predictions/{model_name}[/{version}]`.

        `data` is handed to `requests` as its `files` argument, e.g.

        data = [
            ('data', open('docs/images/dogs-before.jpg', 'rb')),
            ('data', open('docs/images/kitten_small.jpg', 'rb')),
        ]

        The caller owns any file objects it passes in and is responsible for
        closing them. Returns the decoded JSON response.
        """
        if version:
            endpoint = f"/predictions/{model_name}/{version}"
        else:
            endpoint = f"/predictions/{model_name}"

        return self._make_request('POST', endpoint, files=data)

    def explanation(self, model_name, data):
        """
        POST `data` as a file upload to `/explanations/{model_name}`.

        data <string : bytes>

        Returns the decoded JSON response.
        """
        endpoint = f"/explanations/{model_name}"
        return self._make_request('POST', endpoint, files=data)

    # Original, misspelled method name; kept as an alias so existing callers
    # continue to work.
    explaination = explanation

In [None]:
#| export
class TorchServeClientREST:
    """Facade bundling a management client and an inference client.

    The two sub-clients are exposed as the `management` and `inference`
    attributes and share the same base URL, each with its own port.
    """

    def __init__(self, base_url=None, management_port=8081, inference_port=8080):
        """
        base_url: str, optional. Passed unchanged to both sub-clients.
        management_port: int. Port given to the management client.
        inference_port: int. Port given to the inference client.
        """
        self.management = ManagementClient(base_url, management_port)
        self.inference = InferenceClient(base_url, inference_port)

    def __repr__(self):
        mgmt, infer = self.management, self.inference
        # The sub-client URL ends in ":<port>"; cut at the last colon to show
        # the bare base URL.
        root = mgmt.base_url.rsplit(':', 1)[0]
        fields = f"base_url={root}, management_port={mgmt.port}, inference_port={infer.port}"
        return f"TorchServeClientREST({fields})"

To make calls to the REST endpoints, simply initialize a `TorchServeClientREST` object as shown below:

In [None]:
#|eval: false
# Initialize the REST TorchServeClient object.
# With no arguments the base URL comes from TORCHSERVE_URL (or http://localhost),
# and the management/inference ports default to 8081/8080.
ts_client = TorchServeClientREST()
ts_client

TorchServeClientREST(base_url=http://localhost, management_port=8081, inference_port=8080)

If you wish to customize the *base URL*, *management port*, or *inference port* of your TorchServe server, you can pass them as arguments during initialization:

In [None]:
#|eval: false
# Customize the base URL, management port, and inference port.
# Pass the base URL without a port: each sub-client appends its own.
ts_client = TorchServeClientREST(base_url='http://your-torchserve-server.com', 
                             management_port=8081, inference_port=8080)
ts_client

TorchServeClientREST(base_url=http://your-torchserve-server.com, management_port=8081, inference_port=8080)

Alternatively, if you don't provide a base URL during initialization, the client will check for the presence of `TORCHSERVE_URL` in the environment variables. If the variable is not found, it will gracefully fall back to using *localhost* as the default.

## Management APIs

With TorchServe Management APIs, you can effortlessly manage your models at runtime. Here's a quick rundown of the actions you can perform using our `TorchServeClient` SDK:

1. **Register a Model**: Easily register a model with TorchServe using the `ts_client.management.register_model()` method.

In [None]:
#|eval: false
# Only the archive URL is given, so every other option keeps its default
# (including initial_workers=0).
ts_client.management.register_model('https://torchserve.pytorch.org/mar_files/squeezenet1_1.mar')

2. **Increase/Decrease Workers**: Scale the number of workers for a specific model with simplicity using `ts_client.management.scale_workers()`.

In [None]:
#|eval: false
# Issues PUT /models/squeezenet1_1 with min_worker=1 and max_worker=2;
# synchronous and timeout keep their defaults.
ts_client.management.scale_workers('squeezenet1_1', min_worker=1, max_worker=2)

{'status': 'Processing worker updates...'}

3. **Model Status**: Curious about a model's status? Fetch all the details you need using `ts_client.management.describe_model()`.

In [None]:
#|eval: false
# No version argument, so this issues GET /models/squeezenet1_1.
ts_client.management.describe_model('squeezenet1_1')

[{'modelName': 'squeezenet1_1',
  'modelVersion': '1.0',
  'modelUrl': 'https://torchserve.pytorch.org/mar_files/squeezenet1_1.mar',
  'runtime': 'python',
  'minWorkers': 1,
  'maxWorkers': 1,
  'batchSize': 1,
  'maxBatchDelay': 100,
  'loadedAtStartup': False,
  'workers': [{'id': '9001',
    'startTime': '2023-07-17T22:55:40.155Z',
    'status': 'UNLOADING',
    'memoryUsage': 0,
    'pid': -1,
    'gpu': False,
    'gpuUsage': 'N/A'}]}]

4. **List Registered Models**: Quickly fetch a list of all registered models using `ts_client.management.list_models()`.

In [None]:
#|eval: false
# List registered models (GET /models); the default call sends limit=100
# and no next_page_token.
ts_client.management.list_models()

{'models': [{'modelName': 'squeezenet1_1',
   'modelUrl': 'https://torchserve.pytorch.org/mar_files/squeezenet1_1.mar'}]}

5. **Set Default Model Version**: Ensure the desired version of a model is the default choice with the `ts_client.management.set_default_version()` method.

In [None]:
#|eval: false
# Issues PUT /models/squeezenet1_1/1.0/set-default.
ts_client.management.set_default_version('squeezenet1_1', '1.0')

{'status': 'Default vesion succsesfully updated for model "squeezenet1_1" to "1.0"'}

6. **Unregister a Model**: If you need to bid farewell to a model, use the `ts_client.management.unregister_model()` function to gracefully remove it from TorchServe.

In [None]:
#|eval: false
# No version argument, so this issues DELETE /models/squeezenet1_1.
ts_client.management.unregister_model('squeezenet1_1')

{'status': 'Model "squeezenet1_1" unregistered'}

7. **API Description**: View a full list of the Management APIs using `ts_client.management.api_description()`.

In [None]:
#|eval: false
# Issues an OPTIONS request to "/" on the management port.
ts_client.management.api_description()


Remember, all these management APIs can be accessed conveniently under the namespace `ts_client.management`.

## Inference APIs

TorchServeClient allows you to interact with the Inference API, which listens on port 8080, enabling you to run inference on your samples effortlessly. Here are the available APIs under the `ts_client.inference` namespace:


1. **API Description**: Want to explore what APIs and options are available? Use `ts_client.inference.api_description()` to get a comprehensive list.


In [None]:
#|eval: false
# Issues an OPTIONS request to "/" on the inference port.
ts_client.inference.api_description()


2. **Health Check API**: Ensure the health of the running server with the `ts_client.inference.health_check()` method.


In [None]:
#|eval: false
# Issues GET /ping on the inference port.
ts_client.inference.health_check()

{'status': 'Healthy'}


3. **Predictions API**: Get predictions from the served model using `ts_client.inference.prediction()`.


In [None]:
#|eval: false
# Open the image in a `with` block so the file handle is closed once the
# upload finishes. Replace the path with an image available on your machine.
with open('docs/images/kitten_small.jpg', 'rb') as image_file:
    prediction = ts_client.inference.prediction('squeezenet1_1', data={'data': image_file})
# A bare expression inside `with` is not displayed, so show the result here.
prediction

{'lynx': 0.5455798506736755,
 'tabby': 0.2794159948825836,
 'Egyptian_cat': 0.10391879826784134,
 'tiger_cat': 0.06263326108455658,
 'leopard': 0.0050191376358270645}


4. **Explanations API**: Dive into the served model's explanations with ease using `ts_client.inference.explaination()` (note the method's spelling).


In [None]:
#|eval: false
# Open the image in a `with` block so the file handle is closed once the
# upload finishes. Replace the path with an image available on your machine.
with open('docs/images/kitten_small.jpg', 'rb') as image_file:
    explanation = ts_client.inference.explaination('squeezenet1_1', data={'data': image_file})
# A bare expression inside `with` is not displayed, so show the result here.
explanation

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()