-
Notifications
You must be signed in to change notification settings - Fork 183
/
lightgbm.py
34 lines (24 loc) · 1 KB
/
lightgbm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import lightgbm as lgb
from mlserver import types
from mlserver.model import MLModel
from mlserver.utils import get_model_uri
from mlserver.codecs import NumpyCodec, NumpyRequestCodec
# Candidate artifact filenames handed to ``get_model_uri`` (as
# ``wellknown_filenames``) when resolving where the LightGBM model lives.
# NOTE(review): presumably consulted when the configured URI is a directory —
# confirm against mlserver.utils.get_model_uri.
WELLKNOWN_MODEL_FILENAMES = [
    "model.bst",
]
class LightGBMModel(MLModel):
    """
    Implementation of the MLModel interface to load and serve `lightgbm` models.

    ``load`` resolves the model artifact and builds an ``lgb.Booster`` from it;
    ``predict`` decodes the request, runs ``Booster.predict`` and returns the
    result as a single ``"predict"`` output.
    """

    async def load(self) -> bool:
        """
        Resolve the model artifact location and load it as a LightGBM Booster.

        Returns:
            ``True`` once the booster has been constructed and stored on
            ``self._model``.
        """
        # Let mlserver work out the concrete model path from the runtime
        # settings, trying the well-known artifact filenames (e.g. model.bst).
        model_uri = await get_model_uri(
            self._settings, wellknown_filenames=WELLKNOWN_MODEL_FILENAMES
        )
        self._model = lgb.Booster(model_file=model_uri)
        return True

    async def predict(self, payload: types.InferenceRequest) -> types.InferenceResponse:
        """
        Run inference on ``payload`` with the loaded LightGBM Booster.

        Must be called after ``load``: ``self._model`` is only assigned there.

        Args:
            payload: The inference request. It is decoded via
                ``self.decode_request`` with ``NumpyRequestCodec`` supplied as
                the default codec.

        Returns:
            An ``InferenceResponse`` tagged with this model's name and version,
            containing one output named ``"predict"`` that holds the booster's
            predictions encoded with ``NumpyCodec``.
        """
        decoded = self.decode_request(payload, default_codec=NumpyRequestCodec)
        # NOTE(review): Booster.predict is a synchronous call made inside this
        # coroutine, so it occupies the event loop for its whole duration.
        prediction = self._model.predict(decoded)
        return types.InferenceResponse(
            model_name=self.name,
            model_version=self.version,
            outputs=[NumpyCodec.encode_output(name="predict", payload=prediction)],
        )