Integrate triton with huggingface runtime #3601
base: master
Conversation
[APPROVALNOTIFIER] This PR is NOT APPROVED. This pull-request has been approved by: sivanantha321. The full list of commands accepted by this bot can be found here.
Needs approval from an approver in each of these files.
Approvers can indicate their approval by writing /approve in a comment.
sivanantha321 force-pushed the branch from a9be41f to 32182db, from 3075b51 to f4a7c2c, and from f4a7c2c to 632374e (commits signed off by Sivanantham Chinnaiyan <sivanantham.chinnaiyan@ideas2it.com>).
COPY huggingfaceserver huggingfaceserver
RUN cd huggingfaceserver && poetry install --no-interaction --no-cache
RUN cd huggingfaceserver && poetry install --no-interaction --no-cache --extras "vllm"
We need to pin the version
@sivanantha321 let's discuss the implementation. I think we can move the code to the base kserve level so that the transformer can also leverage this implementation, not only huggingface models.
@@ -68,6 +68,9 @@ def list_of_strings(arg):
parser.add_argument(
    "--return_token_type_ids", action="store_true", help="Return token type ids"
)
parser.add_argument(
    "--enable_triton", action="store_true", help="Use triton as the runtime"
We have introduced the new argument backend for huggingfaceserver, so we can add triton as another backend in addition to [vllm, huggingface].
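A sketch of how that could look, assuming the existing --backend flag is an enum-style choice; the Backend names below are illustrative and not the actual kserve implementation:

import argparse
from enum import Enum


class Backend(Enum):
    auto = "auto"
    huggingface = "huggingface"
    vllm = "vllm"
    triton = "triton"  # hypothetical new choice alongside vllm and huggingface


parser = argparse.ArgumentParser()
parser.add_argument(
    "--backend",
    default=Backend.auto,
    type=lambda name: Backend[name],
    choices=list(Backend),
    help="Backend used to serve the model",
)

args = parser.parse_args(["--backend", "triton"])
print(args.backend)  # Backend.triton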
def _triton_client(self):
    if self._triton_client_instance is None:
        url = f"grpc://localhost:{self._triton_config.grpc_port}"
        self._triton_client_instance = AsyncioModelClient(
How is the AsyncioModelClient different from Triton's gRPC Python client?
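For context, AsyncioModelClient appears to come from NVIDIA PyTriton (pytriton.client), which wraps Triton's HTTP/gRPC clients behind an asyncio API that fetches the model config and handles numpy conversion. A minimal usage sketch under that assumption; the model and tensor names are made up:

import asyncio

import numpy as np
from pytriton.client import AsyncioModelClient


async def main():
    # The URL scheme selects the underlying protocol (grpc:// or http://).
    async with AsyncioModelClient("grpc://localhost:8001", "huggingface-model") as client:
        # infer_sample sends a single, unbatched sample; keyword names must
        # match the input tensors declared when the model was bound.
        result = await client.infer_sample(
            input_ids=np.array([101, 2023, 102], dtype=np.int64)
        )
        print(result)


asyncio.run(main())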
@@ -191,6 +238,84 @@ def load(self) -> bool:
            self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})
        logger.info(f"successfully loaded tokenizer for task: {self.task}")

        if self.use_triton:
We can consider adding this option to PredictorConfig, so any KServe model is able to launch the triton server and load the model in the load function when the option is enabled. https://github.com/kserve/kserve/blob/master/python/kserve/kserve/model.py#L89
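A purely illustrative sketch of that suggestion; the enable_triton field and the load() hook below are hypothetical and not part of the current kserve PredictorConfig or Model API:

from dataclasses import dataclass


@dataclass
class PredictorConfig:
    predictor_host: str = ""
    # Hypothetical option: when enabled, load() starts a local Triton server
    # and registers the model with it instead of serving it in-process.
    enable_triton: bool = False
    triton_grpc_port: int = 8001


class Model:
    def __init__(self, name: str, predictor_config: PredictorConfig):
        self.name = name
        self.predictor_config = predictor_config
        self.ready = False

    def load(self) -> bool:
        if self.predictor_config.enable_triton:
            # Launch/bind Triton here (omitted); any KServe model, including
            # transformers, could then reuse the same code path.
            pass
        self.ready = True
        return self.ready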
def infer_fn(requests: List[Request]):
    responses = []
    for request in requests:
        input_tensors = {}
        for input_name, input_array in request.data.items():
            input_tensors[input_name] = torch.tensor(
                input_array, device=self.device
            )

        if (
            self.task == MLTask.text2text_generation.value
            or self.task == MLTask.text_generation
        ):
            outputs = self.model.generate(**input_tensors)
        else:
            outputs = self.model(**input_tensors)
        responses.append({"outputs": outputs.numpy()})
    return responses
The inference should be done by Triton, not by calling self.model. I guess infer_func is for the python backend?
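If the PR is using NVIDIA PyTriton (which the AsyncioModelClient and ModelConfig usage suggests), then infer_func is indeed the Python-side callable that PyTriton binds to the Triton server: Triton handles scheduling and the wire protocol, while the model still runs in this process. A hedged sketch of that binding; the model and tensor names are illustrative:

import numpy as np
from pytriton.model_config import ModelConfig, Tensor
from pytriton.triton import Triton, TritonConfig


def infer_fn(requests):
    # Runs in the serving process; Triton only routes requests to it, which is
    # why the snippet above calls self.model directly.
    return [{"outputs": np.zeros((1, 2), dtype=np.float32)} for _ in requests]


triton = Triton(config=TritonConfig(grpc_port=8001))
triton.bind(
    model_name="huggingface-model",
    infer_func=infer_fn,
    inputs=[Tensor(name="input_ids", dtype=np.int64, shape=(-1,))],
    outputs=[Tensor(name="outputs", dtype=np.float32, shape=(-1,))],
    config=ModelConfig(batching=False),
)
triton.run()  # starts the server in the background; serve() would block instead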
    inputs=inputs_config,
    outputs=outputs_config,
    strict=True,
    config=ModelConfig(batching=False, response_cache=True),
We might also want to expose the batching configs
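Assuming PyTriton's ModelConfig is what is being constructed here, exposing the batching knobs could look roughly like this; the variable names are illustrative placeholders for whatever server arguments or config fields end up carrying them:

from pytriton.model_config import DynamicBatcher, ModelConfig

# Hypothetical values that could be surfaced as CLI flags or config fields
# instead of being hard-coded.
max_batch_size = 8
max_queue_delay_us = 100

batched_config = ModelConfig(
    batching=True,
    max_batch_size=max_batch_size,
    batcher=DynamicBatcher(max_queue_delay_microseconds=max_queue_delay_us),
    response_cache=True,
)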
@@ -364,6 +508,10 @@ async def predict(
        if self.vllm_engine:
            raise InferenceError(VLLM_USE_GENERATE_ENDPOINT_ERROR)

        if self.use_triton:
            res = await self._triton_client.infer_sample(**input_batch)
Why is this calling infer_sample?
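For reference, in PyTriton's client API (assuming that is the client in use) infer_sample sends one sample without a batch dimension, while infer_batch expects arrays whose first axis is the batch dimension, so the call should match how the model was bound. A small illustration with made-up tensor names:

import numpy as np
from pytriton.client import ModelClient

with ModelClient("grpc://localhost:8001", "huggingface-model") as client:
    # One unbatched sample.
    single = client.infer_sample(input_ids=np.array([101, 2023, 102], dtype=np.int64))
    # A batch of one sample; note the extra leading dimension.
    batched = client.infer_batch(input_ids=np.array([[101, 2023, 102]], dtype=np.int64))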