Add support for online drift detectors #1108

Merged · 10 commits · Apr 20, 2023
72 changes: 54 additions & 18 deletions runtimes/alibi-detect/mlserver_alibi_detect/runtime.py
@@ -1,8 +1,9 @@
import os
import numpy as np

from pydantic.error_wrappers import ValidationError
from typing import Optional, List
from pydantic import BaseSettings
from typing import Optional, List, Union
from pydantic import BaseSettings, Field
from functools import cached_property

from alibi_detect.saving import load_detector
@@ -41,6 +42,12 @@ class Config:
inference runs for all of them).
"""

state_save_freq: Optional[int] = Field(100, gt=0)
"""
Save the detector state after every `state_save_freq` predictions.
Only applicable to detectors with a `save_state` method.
"""


class AlibiDetectRuntime(MLModel):
"""
@@ -58,15 +65,21 @@ def __init__(self, settings: ModelSettings):
super().__init__(settings)

async def load(self) -> bool:
model_uri = await get_model_uri(self._settings)
self._model_uri = await get_model_uri(self._settings)
try:
self._model = load_detector(model_uri)
self._model = load_detector(self._model_uri)
# Update AlibiDetectSettings according to whether the detector is an online drift detector (i.e. has a save_state method)
# TODO - in future we may use self._model.meta['online'] here, but must reconsider outlier detectors first
if hasattr(self._model, 'save_state'):
self._ad_settings.batch_size = None
else:
self._ad_settings.state_save_freq = None
except (
ValueError,
FileNotFoundError,
EOFError,
NotImplementedError,
ValidationError,
) as e:
raise MLServerError(
f"Invalid configuration for model {self._settings.name}: {e}"
@@ -105,23 +118,37 @@ def _detect(self, payload: InferenceRequest) -> InferenceResponse:
input_data = self.decode_request(payload, default_codec=NumpyRequestCodec)
predict_kwargs = self._ad_settings.predict_parameters

try:
y = self._model.predict(np.array(input_data), **predict_kwargs)
return self._encode_response(y)
except (ValueError, IndexError) as e:
raise InferenceError(
f"Invalid predict parameters for model {self._settings.name}: {e}"
) from e
# If batch is configured, wrap X in a list so that it is not unpacked
X = np.array(input_data)
if self._ad_settings.batch_size:
X = [X]
Review comment from @ascillitoe (Contributor, Author), Apr 19, 2023:

Wrap X in a list (of length 1) so that we don't unpack into single instances when running in batch mode (i.e. offline drift detectors or outlier detectors). pred is then returned as a list of length 1.

This slightly circuitous logic is intended to avoid having two duplicate try-except blocks for the two (offline vs online) use cases...
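
A minimal standalone sketch of the effect being described (illustrative only, not part of the diff):

```python
import numpy as np

X = np.random.normal(size=(5, 3))  # a decoded request of 5 instances

# Offline/outlier mode (batch_size configured): wrap X so the loop sees the
# whole (5, 3) batch in a single predict() call, and pred has length 1.
batched = [X]

# Online mode: iterate the array directly, yielding one row per predict()
# call, so pred has one entry per instance.
online = list(X)

assert len(batched) == 1
assert len(online) == X.shape[0]
```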


# Run detector inference
pred = []
for x in X:
# Prediction
try:
pred.append(self._model.predict(x, **predict_kwargs))
except (ValueError, IndexError) as e:
raise InferenceError(
f"Invalid predict parameters for model {self._settings.name}: {e}"
) from e
# Save state if necessary
if self._ad_settings.state_save_freq and \
self._model.t % self._ad_settings.state_save_freq == 0 and self._model.t > 0:
self._model.save_state(os.path.join(self._model_uri, 'state'))

return self._encode_response(self._postproc_pred(pred))

def _encode_response(self, y: dict) -> InferenceResponse:
outputs = []
for key in y["data"]:
outputs.append(
NumpyCodec.encode_output(name=key, payload=np.array([y["data"][key]]))
NumpyCodec.encode_output(name=key, payload=np.array(y["data"][key]))
)

# Add headers
y["meta"]["headers"] = {
y["meta"]["headers"] = { # TODO - if we want to viz the sliding windows, where should we store window size?
"x-seldon-alibi-type": self.alibi_type,
"x-seldon-alibi-method": self.alibi_method,
}
@@ -132,6 +159,15 @@ def _encode_response(self, y: dict) -> InferenceResponse:
outputs=outputs,
)

@staticmethod
def _postproc_pred(pred: Union[dict, List[dict]]) -> dict:
data = {key: [] for key in pred[0]['data'].keys()}
for i, pred_i in enumerate(pred):
for key in data:
data[key].append(pred_i['data'][key])
y = {'data': data, 'meta': pred[0]['meta']}
return y

@cached_property
def alibi_method(self) -> str:
module: str = type(self._model).__module__
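A rough sketch of what the new `_postproc_pred` helper does to the per-instance predictions collected in the online path (toy values, illustrative only):

```python
from mlserver_alibi_detect import AlibiDetectRuntime

# Two per-instance prediction dicts, as returned by successive predict()
# calls of an online detector (values made up for illustration).
pred = [
    {"data": {"is_drift": 0, "time": 1}, "meta": {"name": "CVMDriftOnline"}},
    {"data": {"is_drift": 1, "time": 2}, "meta": {"name": "CVMDriftOnline"}},
]

y = AlibiDetectRuntime._postproc_pred(pred)
# y == {"data": {"is_drift": [0, 1], "time": [1, 2]},
#       "meta": {"name": "CVMDriftOnline"}}
```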
39 changes: 38 additions & 1 deletion runtimes/alibi-detect/tests/conftest.py
@@ -5,7 +5,7 @@
import tensorflow as tf

from tensorflow.keras.layers import Dense, InputLayer
from alibi_detect.cd import TabularDrift
from alibi_detect.cd import TabularDrift, CVMDriftOnline
from alibi_detect.od import OutlierVAE
from alibi_detect.saving import save_detector

@@ -18,6 +18,8 @@
tf.keras.backend.clear_session()

P_VAL_THRESHOLD = 0.05
ERT = 50
WINDOW_SIZES = [10]

TESTS_PATH = os.path.dirname(__file__)
TESTDATA_PATH = os.path.join(TESTS_PATH, "testdata")
@@ -125,6 +127,21 @@ def drift_detector_settings(
)


@pytest.fixture
def online_drift_detector_settings(
online_drift_detector_uri: str,
) -> ModelSettings:
return ModelSettings(
name="alibi-detect-model",
implementation=AlibiDetectRuntime,
parameters=ModelParameters(
uri=online_drift_detector_uri,
version="v1.2.3", # TODO - should this match version.py?
extra={"batch_size": 50, "state_save_freq": 10}, # spec batch_size to check that it is ignored
),
)


@pytest.fixture
def drift_detector_uri(tmp_path: str) -> str:
X_ref = np.array([[1, 2, 3]])
@@ -137,9 +154,29 @@ def drift_detector_uri(tmp_path: str) -> str:
return detector_uri


@pytest.fixture
def online_drift_detector_uri(tmp_path: str) -> str:
X_ref = np.ones((10, 3))

cd = CVMDriftOnline(X_ref, ert=ERT, window_sizes=WINDOW_SIZES)

detector_uri = os.path.join(tmp_path, "alibi-detector-artifacts")
save_detector(cd, detector_uri)

return detector_uri


@pytest.fixture
async def drift_detector(drift_detector_settings: ModelSettings) -> AlibiDetectRuntime:
model = AlibiDetectRuntime(drift_detector_settings)
model.ready = await model.load()

return model


@pytest.fixture
async def online_drift_detector(online_drift_detector_settings: ModelSettings) -> AlibiDetectRuntime:
model = AlibiDetectRuntime(online_drift_detector_settings)
model.ready = await model.load()

return model
76 changes: 72 additions & 4 deletions runtimes/alibi-detect/tests/test_drift_detector.py
@@ -1,13 +1,18 @@
import numpy as np

from alibi_detect.cd import TabularDrift
from alibi_detect.cd import TabularDrift, CVMDriftOnline

from mlserver.types import InferenceRequest
from mlserver.codecs import NumpyRequestCodec
from mlserver.types import (
InferenceRequest,
Parameters,
RequestInput
)

from mlserver.codecs import NumpyCodec, NumpyRequestCodec

from mlserver_alibi_detect import AlibiDetectRuntime

from .conftest import P_VAL_THRESHOLD
from .conftest import P_VAL_THRESHOLD, ERT, WINDOW_SIZES


async def test_load_folder(
@@ -17,6 +22,13 @@ async def test_load_folder(
assert type(drift_detector._model) == TabularDrift


async def test_load_folder_online(
online_drift_detector: AlibiDetectRuntime,
):
assert online_drift_detector.ready
assert type(online_drift_detector._model) == CVMDriftOnline


async def test_predict(
drift_detector: AlibiDetectRuntime,
inference_request: InferenceRequest,
@@ -79,3 +91,59 @@ async def test_predict_batch_cleared(
# Batch should now be cleared (and started from scratch)
response = await drift_detector.predict(inference_request)
assert len(response.outputs) == 0


async def test_predict_online(
online_drift_detector: AlibiDetectRuntime,
inference_request: InferenceRequest,
):
# Test a request of length 1
response = await online_drift_detector.predict(inference_request)
assert len(response.outputs) == 7
assert response.outputs[0].name == "is_drift"
assert response.outputs[0].shape == [1, 1]
assert response.outputs[1].name == "distance"
assert response.outputs[2].name == "p_val"
assert response.outputs[3].name == "threshold"
assert response.outputs[4].name == "time"
assert response.outputs[4].data[0] == 1
assert response.outputs[5].name == "ert"
assert response.outputs[5].data[0] == ERT
assert response.outputs[6].name == "test_stat"
assert response.outputs[6].shape == [1, 1, 3]


async def test_predict_batch_online(
online_drift_detector: AlibiDetectRuntime
):
# Send a batch request; the drift detector should run on one instance at a time
batch_size = 50
data = np.random.normal(size=(batch_size, 3))
inference_request = InferenceRequest(
parameters=Parameters(content_type=NumpyRequestCodec.ContentType),
inputs=[
RequestInput(
name="predict",
shape=data.shape,
data=data.tolist(),
datatype="FP32",
)
],
)
response = await online_drift_detector.predict(inference_request)
assert len(response.outputs) == 7
assert response.outputs[0].name == "is_drift"
assert response.outputs[0].shape == [50, 1]
assert response.outputs[1].name == "distance"
assert response.outputs[2].name == "p_val"
assert response.outputs[3].name == "threshold"
assert response.outputs[4].name == "time"
assert response.outputs[4].data[-1] == 50
assert response.outputs[5].name == "ert"
assert response.outputs[5].data[0] == ERT
assert response.outputs[6].name == "test_stat"
assert response.outputs[6].shape == [50, 1, 3]
# Test stat should be NaN until the test window is filled
test_stats = NumpyCodec.decode_output(response.outputs[6])
assert np.isnan(test_stats[0]).all()
assert not np.isnan(test_stats[WINDOW_SIZES[0]]).all()
23 changes: 23 additions & 0 deletions runtimes/alibi-detect/tests/test_runtime.py
@@ -1,3 +1,4 @@
import os
import pytest

from mlserver.codecs import CodecError
@@ -16,3 +17,25 @@ async def test_multiple_inputs_error(

with pytest.raises(CodecError):
await outlier_detector.predict(inference_request)


async def test_saving_state(
online_drift_detector: AlibiDetectRuntime,
inference_request: InferenceRequest,
):
save_freq = online_drift_detector._ad_settings.state_save_freq
state_uri = os.path.join(online_drift_detector._model_uri, 'state')

# Check nothing is written after (save_freq - 1) requests
for _ in range(save_freq - 1): # type: ignore
await online_drift_detector.predict(inference_request)
assert not os.path.isdir(state_uri)

# Check state written after (save_freq) requests
await online_drift_detector.predict(inference_request)
assert os.path.isdir(state_uri)

# Check state properly loaded in new runtime
new_online_drift_detector = AlibiDetectRuntime(online_drift_detector.settings)
await new_online_drift_detector.load()
assert new_online_drift_detector._model.t == save_freq