feat: [VRD-711] Add batch prediction method to client (#3645)
Create a batch_predict method that accepts and returns pandas.DataFrames. The method splits the input dataframe into smaller dataframes of at most batch_size rows, makes a prediction against the model with each batch, then reassembles the per-batch outputs into a single dataframe that is returned to the user.
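For orientation, a minimal usage sketch of the new method (illustrative only; the endpoint URL, token, and column names below are placeholders, not values from this repository):

import pandas as pd

from verta.deployment import DeployedModel

# Hypothetical values -- replace with a real endpoint URL and access token.
model = DeployedModel(
    prediction_url="https://app.example.verta.ai/api/v1/predict/my-model",
    token="xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx",
)

input_df = pd.DataFrame({"feature_a": range(1000), "feature_b": range(1000)})

# Rows are sent to the model in chunks of at most 250; the per-batch outputs
# are concatenated back into a single dataframe in the original row order.
output_df = model.batch_predict(input_df, batch_size=250)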
hmacdonald-verta committed Mar 13, 2023
1 parent b10531a commit 2200a64
Showing 2 changed files with 214 additions and 34 deletions.
100 changes: 100 additions & 0 deletions client/verta/tests/unit_tests/test_deployed_model.py
@@ -1,9 +1,12 @@
# -*- coding: utf-8 -*-
""" Unit tests for the verta.deployment.DeployedModel class. """
import json
import os
from typing import Any, Dict

import pytest
np = pytest.importorskip("numpy")
pd = pytest.importorskip("pandas")
from requests import Session, HTTPError
from requests.exceptions import RetryError
import responses
@@ -15,6 +18,7 @@
from verta._internal_utils import http_session

PREDICTION_URL: str = 'https://test.dev.verta.ai/api/v1/predict/test_path'
BATCH_PREDICTION_URL: str = 'https://test.dev.verta.ai/api/v1/batch-predict/test_path'
TOKEN: str = '12345678-xxxx-1a2b-3c4d-e5f6g7h8'
MOCK_RETRY: Retry = http_session.retry_config(
max_retries=http_session.DEFAULT_MAX_RETRIES,
@@ -379,3 +383,99 @@ def test_predict_400_error_message_missing(mocked_responses) -> None:
'400 Client Error: Bad Request for url: '
'https://test.dev.verta.ai/api/v1/predict/test_path at '
)


def test_batch_predict_with_one_batch_with_no_index(mocked_responses) -> None:
""" Call batch_predict with a single batch. """
expected_df = pd.DataFrame({"A": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], "B": [11, 12, 13, 14, 15, 16, 17, 18, 19, 20]})
expected_df_body = json.dumps(expected_df.to_dict(orient="split"))
mocked_responses.post(
BATCH_PREDICTION_URL,
body=expected_df_body,
status=200,
)
creds = EmailCredentials.load_from_os_env()
dm = DeployedModel(
prediction_url=PREDICTION_URL,
creds=creds,
token=TOKEN,
)
# the input content is irrelevant here: the response is mocked, and since the input is smaller than the batch size only one request is made
prediction_df = dm.batch_predict(pd.DataFrame({"hi": "bye"}, index=[1]), 10)
pd.testing.assert_frame_equal(expected_df, prediction_df)


def test_batch_predict_with_one_batch_with_index(mocked_responses) -> None:
""" Call batch_predict with a single batch, where the output has an index. """
expected_df = pd.DataFrame({"A": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], "B": [11, 12, 13, 14, 15, 16, 17, 18, 19, 20]}, index=["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"])
expected_df_body = json.dumps(expected_df.to_dict(orient="split"))
mocked_responses.post(
BATCH_PREDICTION_URL,
body=expected_df_body,
status=200,
)
creds = EmailCredentials.load_from_os_env()
dm = DeployedModel(
prediction_url=PREDICTION_URL,
creds=creds,
token=TOKEN,
)
# the input content is irrelevant here: the response is mocked, and since the input is smaller than the batch size only one request is made
prediction_df = dm.batch_predict(pd.DataFrame({"hi": "bye"}, index=[1]), 10)
pd.testing.assert_frame_equal(expected_df, prediction_df)


def test_batch_predict_with_five_batches_with_no_indexes(mocked_responses) -> None:
""" Since the input has 5 rows and we're providing a batch_size of 1, we expect 5 batches."""
expected_df_list = [pd.DataFrame({"A": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]}),
pd.DataFrame({"B": [11, 12, 13, 14, 15, 16, 17, 18, 19, 20]}),
pd.DataFrame({"C": [21, 22, 23, 24, 25, 26, 27, 28, 29, 30]}),
pd.DataFrame({"D": [31, 32, 33, 34, 35, 36, 37, 38, 39, 40]}),
pd.DataFrame({"E": [41, 42, 43, 44, 45, 46, 47, 48, 49, 50]}),
]
for expected_df in expected_df_list:
mocked_responses.add(
responses.POST,
BATCH_PREDICTION_URL,
body=json.dumps(expected_df.to_dict(orient="split")),
status=200,
)
creds = EmailCredentials.load_from_os_env()
dm = DeployedModel(
prediction_url=PREDICTION_URL,
creds=creds,
token=TOKEN,
)
input_df = pd.DataFrame({"a": [1, 2, 3, 4, 5], "b": [11, 12, 13, 14, 15]})
prediction_df = dm.batch_predict(input_df, batch_size=1)
expected_df = pd.concat(expected_df_list)
pd.testing.assert_frame_equal(expected_df, prediction_df)


def test_batch_predict_with_batches_and_indexes(mocked_responses) -> None:
""" Since the input has 5 rows and we're providing a batch_size of 1, we expect 5 batches.
Include an example of an index.
"""
expected_df_list = [pd.DataFrame({"A": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]}, index=["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"]),
pd.DataFrame({"B": [11, 12, 13, 14, 15, 16, 17, 18, 19, 20]}, index=["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"]),
pd.DataFrame({"C": [21, 22, 23, 24, 25, 26, 27, 28, 29, 30]}, index=["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"]),
pd.DataFrame({"D": [31, 32, 33, 34, 35, 36, 37, 38, 39, 40]}, index=["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"]),
pd.DataFrame({"E": [41, 42, 43, 44, 45, 46, 47, 48, 49, 50]}, index=["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"]),
]
for expected_df in expected_df_list:
mocked_responses.add(
responses.POST,
BATCH_PREDICTION_URL,
body=json.dumps(expected_df.to_dict(orient="split")),
status=200,
)
creds = EmailCredentials.load_from_os_env()
dm = DeployedModel(
prediction_url=PREDICTION_URL,
creds=creds,
token=TOKEN,
)
input_df = pd.DataFrame({"a": [1, 2, 3, 4, 5], "b": [11, 12, 13, 14, 15]}, index=["A", "B", "C", "D", "E"])
prediction_df = dm.batch_predict(input_df, 1)
expected_final_df = pd.concat(expected_df_list)
pd.testing.assert_frame_equal(expected_final_df, prediction_df)
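Both the mocked responses above and the client implementation below serialize dataframes with pandas' "split" orientation. As a quick reference (an illustrative sketch, not part of the diff), the round trip looks like this:

import json

import pandas as pd

df = pd.DataFrame({"A": [1, 2]}, index=["x", "y"])

# Serialize the way the client sends each batch.
payload = df.to_dict(orient="split")   # {"index": ["x", "y"], "columns": ["A"], "data": [[1], [2]]}
body = json.dumps(payload)

# Reassemble the way the client handles each batch response.
parsed = json.loads(body)
restored = pd.DataFrame(data=parsed["data"], index=parsed["index"], columns=parsed["columns"])
assert restored.equals(df)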
148 changes: 114 additions & 34 deletions client/verta/verta/deployment/_deployedmodel.py
@@ -1,3 +1,5 @@


# -*- coding: utf-8 -*-

import gzip
@@ -9,11 +11,12 @@
from urllib.parse import urlparse
from verta import credentials

from .._internal_utils import _utils, http_session
from .._internal_utils import _utils, http_session, importer
from .._internal_utils._utils import Connection
from .._internal_utils.access_token import AccessToken
from ..external import six


# NOTE: DeployedModel's mechanism for making requests is independent from the
# rest of the client; Client's Connection deliberately instantiates a new
# Session for each request it makes otherwise it encounters de/serialization
@@ -69,7 +72,7 @@ def __init__(
prediction_url,
token=None,
creds=None
):
):
self.prediction_url: str = prediction_url
self._credentials: credentials.Credentials = creds or credentials.load_from_os_env()
self._access_token: str = token
@@ -83,11 +86,11 @@ def _init_session(self):
if self._credentials:
session.headers.update(
Connection.prefixed_headers_for_credentials(self._credentials)
)
)
if self._access_token:
session.headers.update(
AccessToken(self._access_token).headers()
)
)
self._session = session

def __repr__(self):
@@ -107,14 +110,18 @@ def prediction_url(self, value):
raise ValueError("not a valid prediction_url")
self._prediction_url = parsed.geturl()

@property
def _batch_prediction_url(self):
return self.prediction_url.replace("/predict/", "/batch-predict/")

# TODO: Implement dynamic compression via separate utility and call it from here
def _predict(
self,
x: Any,
compress: bool=False,
prediction_id: Optional[str]=None,
):
prediction_url,
compress: bool = False,
prediction_id: Optional[str] = None,
):
"""Make prediction, handling compression and error propagation."""
request_headers = dict()
if prediction_id:
@@ -130,20 +137,23 @@ def _predict(
gzstream.seek(0)

response = self._session.post(
self._prediction_url,
prediction_url,
headers=request_headers,
data=gzstream.read(),
)
)
else:
# when passing json=x, requests sets `allow_nan=False` by default (as of 2.26.0), which we don't want,
# so we serialize the JSON body ourselves
body = json.dumps(x, allow_nan=True)
response = self._session.post(
self.prediction_url,
prediction_url,
headers=request_headers,
json=x,
)
data=body,
)
if response.status_code in (
400,
502,
): # possibly error from the model back end
400,
502,
): # possibly error from the model back end
try:
data = _utils.body_to_json(response)
except ValueError: # not JSON response; 502 not from model back end
@@ -157,12 +167,10 @@
_utils.raise_for_http_error(response=response)
return response


def headers(self):
"""Returns a copy of the headers attached to prediction requests."""
return self._session.headers.copy()


def get_curl(self):
"""
Gets a valid cURL command.
@@ -181,16 +189,16 @@

# TODO: Removed deprecated `always_retry_404` and `always_retry_429` params
def predict(
self,
x: List[Any],
compress=False,
max_retries: int = http_session.DEFAULT_MAX_RETRIES,
always_retry_404=False,
always_retry_429=False,
retry_status: Set[int] = http_session.DEFAULT_STATUS_FORCELIST,
backoff_factor: float = http_session.DEFAULT_BACKOFF_FACTOR,
prediction_id: str = None,
) -> Dict[str, Any]:
self,
x: List[Any],
compress=False,
max_retries: int = http_session.DEFAULT_MAX_RETRIES,
always_retry_404=False,
always_retry_429=False,
retry_status: Set[int] = http_session.DEFAULT_STATUS_FORCELIST,
backoff_factor: float = http_session.DEFAULT_BACKOFF_FACTOR,
prediction_id: str = None,
) -> Dict[str, Any]:
"""
Makes a prediction using input `x`.
@@ -260,10 +268,9 @@ def predict(
max_retries=max_retries,
retry_status=retry_status,
backoff_factor=backoff_factor,
)
)
return prediction_with_id[1]


def predict_with_id(
self,
x: List[Any],
@@ -272,7 +279,7 @@ def predict_with_id(
retry_status: Set[int] = http_session.DEFAULT_STATUS_FORCELIST,
backoff_factor: float = http_session.DEFAULT_BACKOFF_FACTOR,
prediction_id: str = None,
) -> Tuple[str, List[Any]]:
) -> Tuple[str, List[Any]]:
"""
Makes a prediction using input `x` the same as `predict`, but returns a tuple including the ID of the
prediction request along with the prediction results.
@@ -323,12 +330,85 @@ def predict_with_id(
max_retries=max_retries,
status_forcelist=retry_status,
backoff_factor=backoff_factor,
)
)

response = self._predict(x, compress, prediction_id)
response = self._predict(x, self.prediction_url, compress, prediction_id)
id = response.headers.get('verta-request-id', '')
return (id, _utils.body_to_json(response))

def batch_predict(
self,
df,
batch_size: int = 100,
compress: bool = False,
max_retries: int = http_session.DEFAULT_MAX_RETRIES,
retry_status: Set[int] = http_session.DEFAULT_STATUS_FORCELIST,
backoff_factor: float = http_session.DEFAULT_BACKOFF_FACTOR,
prediction_id: str = None,
):
"""
Makes a prediction using input `df` of type pandas.DataFrame.
Parameters
----------
df : pd.DataFrame
A batch of inputs for the model. The dataframe must have an index (note that most pandas dataframes are
created with an automatically-generated index).
compress : bool, default False
Whether to compress the request body.
batch_size : int, default 100
The number of rows to send in each request.
max_retries : int, default 13
Maximum number of retries on status codes listed in ``retry_status``.
retry_status : set, default {404, 429, 500, 503, 504}
Set of status codes, as integers, for which retry attempts should be made. Overwrites default value.
Expand the set to include more. For example, to add status code 409 to the existing set, use:
``retry_status={404, 429, 500, 503, 504, 409}``
backoff_factor : float, default 0.3
A backoff factor to apply between retry attempts. Uses standard urllib3 sleep pattern:
``{backoff factor} * (2 ** ({number of total retries} - 1))`` with a maximum sleep time between requests of
120 seconds.
prediction_id: str, default None
A custom string to use as the ID for the prediction request. Defaults to a randomly generated UUID.
Returns
-------
prediction : pd.DataFrame
Output returned by the deployed model for input `df`.
Raises
------
RuntimeError
If the deployed model encounters an error while running the prediction.
requests.HTTPError
If the server encounters an error while handing the HTTP request.
"""

pd = importer.maybe_dependency("pandas")
if pd is None:
raise ImportError("pandas is not installed; try `pip install pandas`")

# Set the retry config if it differs from current config.
self._session = http_session.set_retry_config(
self._session,
max_retries=max_retries,
status_forcelist=retry_status,
backoff_factor=backoff_factor,
)

# Split into batches
out_df_list = []
for i in range(0, len(df), batch_size):
batch = df.iloc[i:i+batch_size]
serialized_batch = batch.to_dict(orient="split")
# Predict with one batch at a time
response = self._predict(serialized_batch, self._batch_prediction_url, compress, prediction_id)
json_response = _utils.body_to_json(response)
out_df = pd.DataFrame(data=json_response['data'], index=json_response['index'], columns=json_response['columns'])
out_df_list.append(out_df)
# Reassemble output and return to user
return pd.concat(out_df_list)
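For reference (an illustrative sketch, not part of the commit), the iloc slicing above yields a smaller final batch whenever the row count is not a multiple of batch_size, and concatenating the per-batch outputs preserves the original row order and index:

import pandas as pd

df = pd.DataFrame({"x": range(1050)})
batch_size = 100

batches = [df.iloc[i:i + batch_size] for i in range(0, len(df), batch_size)]
assert len(batches) == 11             # ten full batches of 100 rows plus a final batch of 50
assert len(batches[-1]) == 50
assert pd.concat(batches).equals(df)  # row order and index are preserved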

# TODO: Remove this method after release of 0.22.0
@classmethod
@@ -344,12 +424,12 @@ def from_url(cls, url, token=None, creds=None):
"This method is deprecated and will be removed in an upcoming version;"
" Drop \".from_url\" and call the DeployedModel class directly with the same parameters.",
category=FutureWarning,
)
)
return cls(
prediction_url=url,
token=token,
creds=creds,
)
)


def prediction_input_unpack(func):
