feat: [VRD-711] Add batch prediction method to client #3645

Merged
merged 33 commits on Mar 13, 2023
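Not part of the PR page itself, but for orientation: a minimal sketch of how the method added here is meant to be called, based on the `batch_predict` signature in the diff below. The endpoint URL, token, and input data are illustrative placeholders, not values from this PR.

import pandas as pd
from verta.deployment import DeployedModel

# Hypothetical endpoint and token, for illustration only.
dm = DeployedModel(
    prediction_url="https://app.verta.ai/api/v1/predict/my-endpoint",
    token="<access-token>",
)

input_df = pd.DataFrame({"feature_a": [1.0, 2.0, 3.0], "feature_b": [4.0, 5.0, 6.0]})

# Rows are sent to the corresponding /batch-predict/ endpoint in chunks of
# `batch_size`, and the per-batch results are concatenated into one DataFrame.
output_df = dm.batch_predict(input_df, batch_size=2)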
Changes from all commits
Commits (33)
8f0b66b  update url (hmacdonald-verta, Mar 6, 2023)
96f4c7a  Getting better, just need to fill out the TODOs now (hmacdonald-verta, Mar 8, 2023)
9a25cac  Finished adding the todos (hmacdonald-verta, Mar 8, 2023)
b27f0b4  remove unused import (hmacdonald-verta, Mar 8, 2023)
c3ff09b  Clean up! (hmacdonald-verta, Mar 8, 2023)
bc102bb  Add pandas as an optional requirement (hmacdonald-verta, Mar 8, 2023)
0be35d7  Update client/verta/verta/deployment/_deployedmodel.py (hmacdonald-verta, Mar 8, 2023)
12fd2b4  Make batch prediction url a property (hmacdonald-verta, Mar 8, 2023)
6e5b26f  Update client/verta/verta/deployment/_deployedmodel.py (hmacdonald-verta, Mar 8, 2023)
596a5fa  Update client/verta/verta/deployment/_deployedmodel.py (hmacdonald-verta, Mar 8, 2023)
93468a5  Update client/verta/verta/deployment/_deployedmodel.py (hmacdonald-verta, Mar 8, 2023)
37a4f5d  Update client/verta/verta/deployment/_deployedmodel.py (hmacdonald-verta, Mar 8, 2023)
83d114e  Update client/verta/verta/deployment/_deployedmodel.py (hmacdonald-verta, Mar 8, 2023)
a58e53a  Add one unit test! (hmacdonald-verta, Mar 8, 2023)
a856e44  handle indexes (hmacdonald-verta, Mar 9, 2023)
4d9f243  handle indexes correctly this time and fix tests so far (hmacdonald-verta, Mar 9, 2023)
93f4b8a  lots of fixes (hmacdonald-verta, Mar 10, 2023)
6d45ee2  remove axis thing (hmacdonald-verta, Mar 10, 2023)
e666ffa  clean up more (hmacdonald-verta, Mar 10, 2023)
6bed084  Fix doc string (hmacdonald-verta, Mar 10, 2023)
b59b61f  add nan test (hmacdonald-verta, Mar 10, 2023)
fa5abd8  Finish tidying up (hmacdonald-verta, Mar 10, 2023)
6378145  handle nans by converting to json ourselves (hmacdonald-verta, Mar 11, 2023)
eda87f4  Update client/verta/tests/unit_tests/test_deployed_model.py (hmacdonald-verta, Mar 11, 2023)
8868050  Update client/verta/tests/unit_tests/test_deployed_model.py (hmacdonald-verta, Mar 11, 2023)
7bc14ba  cleanup (hmacdonald-verta, Mar 11, 2023)
cdd5d12  more cleanup (hmacdonald-verta, Mar 11, 2023)
c88594e  EVEN more fixes (hmacdonald-verta, Mar 11, 2023)
bf823f6  EVEN more fixes (hmacdonald-verta, Mar 11, 2023)
a2e9dcb  Okay wow it works flawlessly (hmacdonald-verta, Mar 13, 2023)
f271d79  fix quote (hmacdonald-verta, Mar 13, 2023)
16ac5b5  Remove unnecessary comment (hmacdonald-verta, Mar 13, 2023)
71e1715  Remove unnecessary encoding (hmacdonald-verta, Mar 13, 2023)
100 changes: 100 additions & 0 deletions client/verta/tests/unit_tests/test_deployed_model.py
@@ -1,9 +1,12 @@
# -*- coding: utf-8 -*-
""" Unit tests for the verta.deployment.DeployedModel class. """
import json
import os
from typing import Any, Dict

import pytest
np = pytest.importorskip("numpy")
pd = pytest.importorskip("pandas")
from requests import Session, HTTPError
from requests.exceptions import RetryError
import responses
@@ -15,6 +18,7 @@
from verta._internal_utils import http_session

PREDICTION_URL: str = 'https://test.dev.verta.ai/api/v1/predict/test_path'
BATCH_PREDICTION_URL: str = 'https://test.dev.verta.ai/api/v1/batch-predict/test_path'
TOKEN: str = '12345678-xxxx-1a2b-3c4d-e5f6g7h8'
MOCK_RETRY: Retry = http_session.retry_config(
    max_retries=http_session.DEFAULT_MAX_RETRIES,
@@ -379,3 +383,99 @@ def test_predict_400_error_message_missing(mocked_responses) -> None:
        '400 Client Error: Bad Request for url: '
        'https://test.dev.verta.ai/api/v1/predict/test_path at '
    )


def test_batch_predict_with_one_batch_with_no_index(mocked_responses) -> None:
    """ Call batch_predict with a single batch. """
    expected_df = pd.DataFrame({"A": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], "B": [11, 12, 13, 14, 15, 16, 17, 18, 19, 20]})
    expected_df_body = json.dumps(expected_df.to_dict(orient="split"))
    mocked_responses.post(
        BATCH_PREDICTION_URL,
        body=expected_df_body,
        status=200,
    )
    creds = EmailCredentials.load_from_os_env()
    dm = DeployedModel(
        prediction_url=PREDICTION_URL,
        creds=creds,
        token=TOKEN,
    )
    # the input below is entirely irrelevant since it's smaller than the batch size
    prediction_df = dm.batch_predict(pd.DataFrame({"hi": "bye"}, index=[1]), 10)
    pd.testing.assert_frame_equal(expected_df, prediction_df)


def test_batch_predict_with_one_batch_with_index(mocked_responses) -> None:
    """ Call batch_predict with a single batch, where the output has an index. """
    expected_df = pd.DataFrame({"A": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], "B": [11, 12, 13, 14, 15, 16, 17, 18, 19, 20]}, index=["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"])
    expected_df_body = json.dumps(expected_df.to_dict(orient="split"))
    mocked_responses.post(
        BATCH_PREDICTION_URL,
        body=expected_df_body,
        status=200,
    )
    creds = EmailCredentials.load_from_os_env()
    dm = DeployedModel(
        prediction_url=PREDICTION_URL,
        creds=creds,
        token=TOKEN,
    )
    # the input below is entirely irrelevant since it's smaller than the batch size
    prediction_df = dm.batch_predict(pd.DataFrame({"hi": "bye"}, index=[1]), 10)
    pd.testing.assert_frame_equal(expected_df, prediction_df)


def test_batch_predict_with_five_batches_with_no_indexes(mocked_responses) -> None:
    """ Since the input has 5 rows and we're providing a batch_size of 1, we expect 5 batches. """
    expected_df_list = [pd.DataFrame({"A": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]}),
                        pd.DataFrame({"B": [11, 12, 13, 14, 15, 16, 17, 18, 19, 20]}),
                        pd.DataFrame({"C": [21, 22, 23, 24, 25, 26, 27, 28, 29, 30]}),
                        pd.DataFrame({"D": [31, 32, 33, 34, 35, 36, 37, 38, 39, 40]}),
                        pd.DataFrame({"E": [41, 42, 43, 44, 45, 46, 47, 48, 49, 50]}),
                        ]
    for expected_df in expected_df_list:
        mocked_responses.add(
            responses.POST,
            BATCH_PREDICTION_URL,
            body=json.dumps(expected_df.to_dict(orient="split")),
            status=200,
        )
    creds = EmailCredentials.load_from_os_env()
    dm = DeployedModel(
        prediction_url=PREDICTION_URL,
        creds=creds,
        token=TOKEN,
    )
    input_df = pd.DataFrame({"a": [1, 2, 3, 4, 5], "b": [11, 12, 13, 14, 15]})
    prediction_df = dm.batch_predict(input_df, batch_size=1)
    expected_df = pd.concat(expected_df_list)
    pd.testing.assert_frame_equal(expected_df, prediction_df)


def test_batch_predict_with_batches_and_indexes(mocked_responses) -> None:
    """ Since the input has 5 rows and we're providing a batch_size of 1, we expect 5 batches.
    Include an example of an index.
    """
    expected_df_list = [pd.DataFrame({"A": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]}, index=["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"]),
                        pd.DataFrame({"B": [11, 12, 13, 14, 15, 16, 17, 18, 19, 20]}, index=["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"]),
                        pd.DataFrame({"C": [21, 22, 23, 24, 25, 26, 27, 28, 29, 30]}, index=["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"]),
                        pd.DataFrame({"D": [31, 32, 33, 34, 35, 36, 37, 38, 39, 40]}, index=["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"]),
                        pd.DataFrame({"E": [41, 42, 43, 44, 45, 46, 47, 48, 49, 50]}, index=["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"]),
                        ]
    for expected_df in expected_df_list:
        mocked_responses.add(
            responses.POST,
            BATCH_PREDICTION_URL,
            body=json.dumps(expected_df.to_dict(orient="split")),
            status=200,
        )
    creds = EmailCredentials.load_from_os_env()
    dm = DeployedModel(
        prediction_url=PREDICTION_URL,
        creds=creds,
        token=TOKEN,
    )
    input_df = pd.DataFrame({"a": [1, 2, 3, 4, 5], "b": [11, 12, 13, 14, 15]}, index=["A", "B", "C", "D", "E"])
    prediction_df = dm.batch_predict(input_df, 1)
    expected_final_df = pd.concat(expected_df_list)
    pd.testing.assert_frame_equal(expected_final_df, prediction_df)
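A note on the tests above: they depend on pandas' "split" orientation surviving a JSON round trip, which is how the client preserves the DataFrame index across the HTTP boundary. A small standalone illustration, assuming only pandas and the standard library:

import json

import pandas as pd

df = pd.DataFrame({"A": [1, 2]}, index=["a", "b"])

# "split" orientation keeps data, index, and columns separate, so the index
# survives serialization: {'index': ['a', 'b'], 'columns': ['A'], 'data': [[1], [2]]}
body = json.dumps(df.to_dict(orient="split"))

decoded = json.loads(body)
restored = pd.DataFrame(
    data=decoded["data"], index=decoded["index"], columns=decoded["columns"]
)
pd.testing.assert_frame_equal(df, restored)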
148 changes: 114 additions & 34 deletions client/verta/verta/deployment/_deployedmodel.py
@@ -1,3 +1,5 @@


# -*- coding: utf-8 -*-

import gzip
@@ -9,11 +11,12 @@
from urllib.parse import urlparse
from verta import credentials

from .._internal_utils import _utils, http_session, importer
from .._internal_utils._utils import Connection
from .._internal_utils.access_token import AccessToken
from ..external import six


# NOTE: DeployedModel's mechanism for making requests is independent from the
# rest of the client; Client's Connection deliberately instantiates a new
# Session for each request it makes otherwise it encounters de/serialization
@@ -69,7 +72,7 @@ def __init__(
        prediction_url,
        token=None,
        creds=None
    ):
        self.prediction_url: str = prediction_url
        self._credentials: credentials.Credentials = creds or credentials.load_from_os_env()
        self._access_token: str = token
@@ -83,11 +86,11 @@ def _init_session(self):
        if self._credentials:
            session.headers.update(
                Connection.prefixed_headers_for_credentials(self._credentials)
            )
        if self._access_token:
            session.headers.update(
                AccessToken(self._access_token).headers()
            )
        self._session = session

    def __repr__(self):
@@ -107,14 +110,18 @@ def prediction_url(self, value):
            raise ValueError("not a valid prediction_url")
        self._prediction_url = parsed.geturl()

    @property
    def _batch_prediction_url(self):
        return self.prediction_url.replace("/predict/", "/batch-predict/")

    # TODO: Implement dynamic compression via separate utility and call it from here
    def _predict(
        self,
        x: Any,
        prediction_url,
        compress: bool = False,
        prediction_id: Optional[str] = None,
    ):
"""Make prediction, handling compression and error propagation."""
request_headers = dict()
if prediction_id:
@@ -130,20 +137,23 @@ def _predict(
            gzstream.seek(0)

            response = self._session.post(
                prediction_url,
                headers=request_headers,
                data=gzstream.read(),
            )
        else:
            # when passing json=x, requests sets `allow_nan=False` by default (as of 2.26.0), which we don't want
            # so we're going to dump ourselves
            body = json.dumps(x, allow_nan=True)
            response = self._session.post(
                prediction_url,
                headers=request_headers,
                data=body,
            )
        if response.status_code in (
            400,
            502,
        ):  # possibly error from the model back end
            try:
                data = _utils.body_to_json(response)
            except ValueError:  # not JSON response; 502 not from model back end
@@ -157,12 +167,10 @@
        _utils.raise_for_http_error(response=response)
        return response

    def headers(self):
        """Returns a copy of the headers attached to prediction requests."""
        return self._session.headers.copy()

    def get_curl(self):
        """
        Gets a valid cURL command.
@@ -181,16 +189,16 @@ def get_curl(self):

    # TODO: Remove deprecated `always_retry_404` and `always_retry_429` params
    def predict(
        self,
        x: List[Any],
        compress=False,
        max_retries: int = http_session.DEFAULT_MAX_RETRIES,
        always_retry_404=False,
        always_retry_429=False,
        retry_status: Set[int] = http_session.DEFAULT_STATUS_FORCELIST,
        backoff_factor: float = http_session.DEFAULT_BACKOFF_FACTOR,
        prediction_id: str = None,
    ) -> Dict[str, Any]:
"""
Makes a prediction using input `x`.

@@ -260,10 +268,9 @@ def predict(
            max_retries=max_retries,
            retry_status=retry_status,
            backoff_factor=backoff_factor,
        )
        return prediction_with_id[1]


    def predict_with_id(
        self,
        x: List[Any],
@@ -272,7 +279,7 @@ def predict_with_id(
        retry_status: Set[int] = http_session.DEFAULT_STATUS_FORCELIST,
        backoff_factor: float = http_session.DEFAULT_BACKOFF_FACTOR,
        prediction_id: str = None,
    ) -> Tuple[str, List[Any]]:
"""
Makes a prediction using input `x` the same as `predict`, but returns a tuple including the ID of the
prediction request along with the prediction results.
Expand Down Expand Up @@ -323,12 +330,85 @@ def predict_with_id(
            max_retries=max_retries,
            status_forcelist=retry_status,
            backoff_factor=backoff_factor,
        )

        response = self._predict(x, self.prediction_url, compress, prediction_id)
        id = response.headers.get('verta-request-id', '')
        return (id, _utils.body_to_json(response))

    def batch_predict(
        self,
        df,
        batch_size: int = 100,
        compress: bool = False,
        max_retries: int = http_session.DEFAULT_MAX_RETRIES,
        retry_status: Set[int] = http_session.DEFAULT_STATUS_FORCELIST,
        backoff_factor: float = http_session.DEFAULT_BACKOFF_FACTOR,
        prediction_id: str = None,
    ):
"""
Makes a prediction using input `df` of type pandas.DataFrame.

Parameters
----------
df : pd.DataFrame
A batch of inputs for the model. The dataframe must have an index (note that most pandas dataframes are
created with an automatically-generated index).
compress : bool, default False
Whether to compress the request body.
batch_size : int, default 100
The number of rows to send in each request.
max_retries : int, default 13
Maximum number of retries on status codes listed in ``retry_status``.
retry_status : set, default {404, 429, 500, 503, 504}
Set of status codes, as integers, for which retry attempts should be made. Overwrites default value.
Expand the set to include more. For example, to add status code 409 to the existing set, use:
``retry_status={404, 429, 500, 503, 504, 409}``
backoff_factor : float, default 0.3
A backoff factor to apply between retry attempts. Uses standard urllib3 sleep pattern:
``{backoff factor} * (2 ** ({number of total retries} - 1))`` with a maximum sleep time between requests of
120 seconds.
prediction_id: str, default None
A custom string to use as the ID for the prediction request. Defaults to a randomly generated UUID.

Returns
ewagner-verta marked this conversation as resolved.
Show resolved Hide resolved
-------
prediction : pd.DataFrame
Output returned by the deployed model for input `df`.

Raises
------
RuntimeError
If the deployed model encounters an error while running the prediction.
requests.HTTPError
If the server encounters an error while handing the HTTP request.

"""

        pd = importer.maybe_dependency("pandas")
        if pd is None:
            raise ImportError("pandas is not installed; try `pip install pandas`")

        # Set the retry config if it differs from current config.
        self._session = http_session.set_retry_config(
            self._session,
            max_retries=max_retries,
            status_forcelist=retry_status,
            backoff_factor=backoff_factor,
        )

        # Split into batches
        out_df_list = []
        for i in range(0, len(df), batch_size):
            batch = df.iloc[i : i + batch_size]
            serialized_batch = batch.to_dict(orient="split")
            # Predict with one batch at a time
            response = self._predict(serialized_batch, self._batch_prediction_url, compress, prediction_id)
            json_response = _utils.body_to_json(response)
            out_df = pd.DataFrame(data=json_response['data'], index=json_response['index'], columns=json_response['columns'])
            out_df_list.append(out_df)
        # Reassemble output and return to user
        return pd.concat(out_df_list)

    # TODO: Remove this method after release of 0.22.0
    @classmethod
@@ -344,12 +424,12 @@ def from_url(cls, url, token=None, creds=None):
"This method is deprecated and will be removed in an upcoming version;"
" Drop \".from_url\" and call the DeployedModel class directly with the same parameters.",
category=FutureWarning,
)
)
return cls(
prediction_url=url,
token=token,
creds=creds,
)
)


def prediction_input_unpack(func):
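A closing note on one implementation detail: `_predict` now serializes the request body itself via `json.dumps(x, allow_nan=True)` instead of passing `json=x` to requests, which, per the comment in the diff, applies `allow_nan=False` as of requests 2.26.0. A standard-library sketch of why that matters for NaN-bearing dataframes:

import json

row = {"value": float("nan")}

# json.dumps defaults to allow_nan=True and emits the (non-standard) token NaN,
# which is what the client now relies on for NaN-bearing payloads.
print(json.dumps(row))  # {"value": NaN}

# With allow_nan=False, the behavior requests applies to its json= parameter,
# the same payload is rejected outright.
try:
    json.dumps(row, allow_nan=False)
except ValueError as err:
    print("rejected:", err)  # Out of range float values are not JSON compliant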