Skip to content

Commit

Permalink
[python] Adds content-type response for DeepSpeed and FasterTransform…
Browse files Browse the repository at this point in the history
…er handler (deepjavalibrary#797)
  • Loading branch information
frankfliu authored and KexinFeng committed Aug 16, 2023
1 parent 096a545 commit 855ea9f
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 4 deletions.
2 changes: 2 additions & 0 deletions engines/python/setup/djl_python/deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,7 @@ def inference(self, inputs: Input):
generated_text = self.tokenizer.batch_decode(
output_tokens, skip_special_tokens=True)
outputs.add([{"generated_text": s} for s in generated_text])
outputs.add_property("content-type", "application/json")
return outputs

result = self.pipeline(input_data, **model_kwargs)
Expand All @@ -324,6 +325,7 @@ def inference(self, inputs: Input):
"generated_responses": result.generated_responses,
},
}
outputs.add_property("content-type", "application/json")

outputs.add(result)
except Exception as e:
Expand Down
4 changes: 3 additions & 1 deletion engines/python/setup/djl_python/fastertransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,9 @@ def inference(self, inputs: Input):
beam_width=beam_width,
**parameters)
result = [{"generated_text": s} for s in result]
outputs = Output().add(result)
outputs = Output()
outputs.add_property("content-type", "application/json")
outputs.add(result)
except Exception as e:
logging.exception("FasterTransformer inference failed")
outputs = Output().error((str(e)))
Expand Down
5 changes: 2 additions & 3 deletions engines/python/setup/djl_python/streaming_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,15 +164,14 @@ def _validate_inputs(model, inputs):
logging.warning(
f"Model config does not contain architectures field. Supported architectures: *{StreamingUtils.SUPPORTED_MODEL_ARCH_SUFFIXES}"
)
model_arch_supported = True
else:
model_arch_list = model.config.architectures
model_arch_supported = any(
model_arch.endswith(
StreamingUtils.SUPPORTED_MODEL_ARCH_SUFFIXES)
for model_arch in model_arch_list)
if not model_arch_supported:
assert False, f"model archs: {model_arch_list} is not in supported list: *{StreamingUtils.SUPPORTED_MODEL_ARCH_SUFFIXES}"
if not model_arch_supported:
assert False, f"model archs: {model_arch_list} is not in supported list: *{StreamingUtils.SUPPORTED_MODEL_ARCH_SUFFIXES}"
if isinstance(inputs, list):
assert len(inputs) >= 1, "[ERROR] empty input list"
else:
Expand Down

0 comments on commit 855ea9f

Please sign in to comment.