feature: support json serialize Conversation outputs
pepesi committed Aug 24, 2022
1 parent cc720fa · commit 553ec92
Showing 3 changed files with 68 additions and 4 deletions.
22 changes: 20 additions & 2 deletions runtimes/huggingface/mlserver_huggingface/common.py
@@ -1,13 +1,16 @@
 import os
+import io
+import base64
 import json
 from typing import Optional, Dict
 from distutils.util import strtobool
+from PIL import Image

 import numpy as np
 from pydantic import BaseSettings
 from mlserver.errors import MLServerError

-from transformers.pipelines import pipeline
+from transformers.pipelines import pipeline, Conversation
 from transformers.pipelines.base import Pipeline
 from transformers.models.auto.tokenization_auto import AutoTokenizer

@@ -128,8 +131,23 @@ def load_pipeline_from_settings(hf_settings: HuggingFaceSettings) -> Pipeline:
     return pp


-class NumpyEncoder(json.JSONEncoder):
+class CommonJSONEncoder(json.JSONEncoder):
     def default(self, obj):
         if isinstance(obj, np.ndarray):
             return obj.tolist()
+        if isinstance(obj, (np.float_, np.float16, np.float32, np.float64)):
+            return float(str(obj))
+        if isinstance(obj, (np.int_, np.int8, np.int16, np.int32, np.int64)):
+            return int(obj)
+        if isinstance(obj, Image.Image):
+            buf = io.BytesIO()
+            obj.save(buf, format="png")
+            return base64.b64encode(buf.getvalue()).decode()
+        if isinstance(obj, Conversation):
+            return {
+                'uuid': str(obj.uuid),
+                'past_user_inputs': obj.past_user_inputs,
+                'generated_responses': obj.generated_responses,
+                'new_user_input': obj.new_user_input
+            }
         return json.JSONEncoder.default(self, obj)
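
To make the change concrete, here is a minimal usage sketch (not part of the commit) that feeds a mixed payload of numpy values and a Conversation through the new encoder. The field names in the payload dict are illustrative:

import json

import numpy as np
from transformers.pipelines import Conversation
from mlserver_huggingface.common import CommonJSONEncoder

conv = Conversation(text="hello")  # new_user_input is "hello"; uuid is freshly generated

payload = {
    "scores": np.array([0.25, 0.75], dtype=np.float32),  # ndarray branch -> [0.25, 0.75]
    "label_id": np.int64(1),                             # numpy int branch -> 1
    "conversation": conv,                                # Conversation branch -> uuid/inputs/responses dict
}

print(json.dumps(payload, cls=CommonJSONEncoder))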
4 changes: 2 additions & 2 deletions runtimes/huggingface/mlserver_huggingface/runtime.py
@@ -14,7 +14,7 @@
     parse_parameters_from_env,
     InvalidTranformerInitialisation,
     load_pipeline_from_settings,
-    NumpyEncoder,
+    CommonJSONEncoder
 )
 from mlserver_huggingface.codecs import MultiStringRequestCodec
 from transformers.pipelines import SUPPORTED_TASKS
@@ -84,7 +84,7 @@ async def predict(self, payload: InferenceRequest) -> InferenceResponse:
         prediction = self._model(*args, **kwargs)

         # TODO: Convert hf output to v2 protocol, for now we use to_json
-        str_out = json.dumps(prediction, cls=NumpyEncoder)
+        str_out = json.dumps(prediction, cls=CommonJSONEncoder)
         prediction_encoded = StringCodec.encode_output(payload=[str_out], name="output")

         return InferenceResponse(
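For context on how predict ships this string out (a sketch under assumed values, not from the commit): the JSON text is wrapped into a single output tensor named "output", which MLServer's StringCodec should encode as a BYTES tensor:

from mlserver.codecs import StringCodec

# JSON string as produced by CommonJSONEncoder (contents illustrative)
str_out = '{"generated_responses": ["hello!"], "new_user_input": "hello"}'

output = StringCodec.encode_output(payload=[str_out], name="output")
print(output.name, output.datatype)  # expected: "output" and "BYTES"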
46 changes: 46 additions & 0 deletions runtimes/huggingface/tests/test_common.py
@@ -0,0 +1,46 @@
+import pytest
+import json
+
+import numpy as np
+import uuid
+from transformers.pipelines import Conversation
+from mlserver_huggingface.common import CommonJSONEncoder
+
+
+@pytest.mark.parametrize(
+    "output, expected",
+    [
+        (
+            {
+                "f_": np.float_(1),
+                "f16": np.float_(1),
+                "f32": np.float_(1),
+                "f64": np.float_(1),
+                "i_": np.int_(1.0),
+                "i0": np.int0(1.0),
+                "i8": np.int8(1.0),
+                "i16": np.int16(1.0),
+                "i32": np.int32(1.0),
+                "i64": np.int64(1.0),
+            },
+            '{"f_": 1.0, "f16": 1.0, "f32": 1.0, "f64": 1.0, "i_": '
+            + '1, "i0": 1, "i8": 1, "i16": 1, "i32": 1, "i64": 1}',
+        ),
+        (
+            {
+                "ints": np.full(1, 1, dtype=np.int8),
+                "floats": np.full(1, 1, dtype=np.float32),
+            },
+            '{"ints": [1], "floats": [1.0]}',
+        ),
+        (np.full((2, 2, 2), 1, dtype=np.int8), "[[[1, 1], [1, 1]], [[1, 1], [1, 1]]]"),
+        (Conversation(
+            text="hello",
+            conversation_id=uuid.UUID('712dcbad-a042-4d9d-ab4d-84f20d6d9e7e'),
+            generated_responses=["hello!"]
+        ), '{"uuid": "712dcbad-a042-4d9d-ab4d-84f20d6d9e7e", "past_user_inputs": [],'
+            + ' "generated_responses": ["hello!"], "new_user_input": "hello"}')
+    ],
+)
+def test_json_encoder(output, expected):
+    assert json.dumps(output, cls=CommonJSONEncoder) == expected
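
As a small sanity check mirroring the Conversation test case above (illustrative, not part of the commit), the encoded output can be read straight back with json.loads:

import json
import uuid

from transformers.pipelines import Conversation
from mlserver_huggingface.common import CommonJSONEncoder

conv = Conversation(
    text="hello",
    conversation_id=uuid.UUID("712dcbad-a042-4d9d-ab4d-84f20d6d9e7e"),
    generated_responses=["hello!"],
)

# Round-trip: serialize with the custom encoder, then parse with plain json
decoded = json.loads(json.dumps(conv, cls=CommonJSONEncoder))
assert decoded["new_user_input"] == "hello"
assert decoded["generated_responses"] == ["hello!"]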
