-
Notifications
You must be signed in to change notification settings - Fork 830
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add V2 data plane support for alibi detect server
- Loading branch information
1 parent
4a960a0
commit 0e9eee5
Showing
17 changed files
with
4,090 additions
and
58 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
from http import HTTPStatus | ||
from typing import Dict, List | ||
|
||
import numpy as np | ||
import tornado | ||
from adserver.protocols.request_handler import ( | ||
RequestHandler, | ||
) # pylint: disable=no-name-in-module | ||
|
||
def _create_np_from_v2(data: list,ty: str, shape: list) -> np.array: | ||
npty = np.float | ||
if ty == "BOOL": | ||
npty = np.bool | ||
elif ty == "UINT8": | ||
npty = np.uint8 | ||
elif ty == "UINT16": | ||
npty = np.uint16 | ||
elif ty == "UINT32": | ||
npty = np.uint32 | ||
elif ty == "UINT64": | ||
npty = np.uint64 | ||
elif ty == "INT8": | ||
npty = np.int8 | ||
elif ty == "INT16": | ||
npty = np.int16 | ||
elif ty == "INT32": | ||
npty = np.int32 | ||
elif ty == "INT64": | ||
npty = np.int64 | ||
elif ty == "FP16": | ||
npty = np.float32 | ||
elif ty == "FP32": | ||
npty = np.float32 | ||
elif ty == "FP64": | ||
npty = np.float64 | ||
else: | ||
raise ValueError(f"V2 unknown type or type that can't be coerced {ty}") | ||
|
||
arr = np.array(data, dtype=npty) | ||
arr.shape = tuple(shape) | ||
return arr | ||
|
||
|
||
class V2RequestHandler(RequestHandler): | ||
def __init__(self, request: Dict): # pylint: disable=useless-super-delegation | ||
super().__init__(request) | ||
|
||
def validate(self): | ||
if not "inputs" in self.request: | ||
raise tornado.web.HTTPError( | ||
status_code=HTTPStatus.BAD_REQUEST, | ||
reason='Expected key "data" in request body', | ||
) | ||
# assumes single input | ||
inputs = self.request["inputs"][0] | ||
data_type = inputs["datatype"] | ||
|
||
if data_type == "BYTES": | ||
raise tornado.web.HTTPError( | ||
status_code=HTTPStatus.BAD_REQUEST, | ||
reason='v2 protocol BYTES data can not be presently handled"', | ||
) | ||
|
||
def extract_request(self) -> List: | ||
inputs = self.request["inputs"][0] | ||
data_type = inputs["datatype"] | ||
shape = inputs["shape"] | ||
data = inputs["data"] | ||
arr = _create_np_from_v2(data, data_type, shape) | ||
return arr.tolist() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.