
Integrate triton with huggingface runtime #3601

Draft: sivanantha321 wants to merge 6 commits into master from llm-pytriton-support

Conversation

sivanantha321 (Member)

What this PR does / why we need it:

Which issue(s) this PR fixes (optional, in fixes #<issue number>(, fixes #<issue_number>, ...) format, will close the issue(s) when PR gets merged):
Fixes #

Type of changes
Please delete options that are not relevant.

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • This change requires a documentation update

Feature/Issue validation/testing:

Please describe the tests that you ran to verify your changes and relevant result summary. Provide instructions so it can be reproduced.
Please also list any relevant details for your test configuration.

  • Test A

  • Test B

  • Logs

Special notes for your reviewer:

  1. Please confirm that if this PR changes any image versions, then that's the sole change this PR makes.

Checklist:

  • Have you added unit/e2e tests that prove your fix is effective or that this feature works?
  • Has code been commented, particularly in hard-to-understand areas?
  • Have you made corresponding changes to the documentation?

Release note:

Integrate triton with huggingface runtime

oss-prow-bot bot commented Apr 15, 2024

[APPROVALNOTIFIER] This PR is NOT APPROVED

This pull-request has been approved by: sivanantha321
Once this PR has been reviewed and has the lgtm label, please assign njhill for approval by writing /assign @njhill in a comment. For more information see: The Kubernetes Code Review Process.

The full list of commands accepted by this bot can be found here.

Needs approval from an approver in each of these files:

Approvers can indicate their approval by writing /approve in a comment
Approvers can cancel approval by writing /approve cancel in a comment

@sivanantha321 sivanantha321 force-pushed the llm-pytriton-support branch 2 times, most recently from a9be41f to 32182db Compare April 15, 2024 18:39
Signed-off-by: Sivanantham Chinnaiyan <sivanantham.chinnaiyan@ideas2it.com>
Signed-off-by: Sivanantham Chinnaiyan <sivanantham.chinnaiyan@ideas2it.com>
Signed-off-by: Sivanantham Chinnaiyan <sivanantham.chinnaiyan@ideas2it.com>
Signed-off-by: Sivanantham Chinnaiyan <sivanantham.chinnaiyan@ideas2it.com>
@sivanantha321 sivanantha321 force-pushed the llm-pytriton-support branch 2 times, most recently from 3075b51 to f4a7c2c Compare April 22, 2024 08:43
Signed-off-by: Sivanantham Chinnaiyan <sivanantham.chinnaiyan@ideas2it.com>
Signed-off-by: Sivanantham Chinnaiyan <sivanantham.chinnaiyan@ideas2it.com>
COPY huggingfaceserver huggingfaceserver
RUN cd huggingfaceserver && poetry install --no-interaction --no-cache
RUN cd huggingfaceserver && poetry install --no-interaction --no-cache --extras "vllm"
Member commented:

We need to pin the version.

yuzisun (Member) commented Apr 28, 2024:

@sivanantha321 let's discuss the implementation. I think we can move the code to the base kserve level so the transformer can also leverage this implementation, not only huggingface models.

@@ -68,6 +68,9 @@ def list_of_strings(arg):
parser.add_argument(
    "--return_token_type_ids", action="store_true", help="Return token type ids"
)
parser.add_argument(
    "--enable_triton", action="store_true", help="Use triton as the runtime"
Member commented:

We have introduced the new argument backend for huggingfaceserver, so we can add triton as another backend in addition to [vllm, huggingface].
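For illustration only (not part of the PR): a minimal sketch of what folding triton into a --backend choice could look like. The Backend enum and flag names here are assumptions based on this comment, not the actual huggingfaceserver code.

from enum import Enum, auto
import argparse


class Backend(Enum):
    # hypothetical enum mirroring the comment: triton added alongside vllm and huggingface
    AUTO = auto()
    HUGGINGFACE = auto()
    VLLM = auto()
    TRITON = auto()


parser = argparse.ArgumentParser()
parser.add_argument(
    "--backend",
    default="auto",
    choices=[b.name.lower() for b in Backend],
    help="Backend to serve the model with (would replace a separate --enable_triton flag)",
)
args = parser.parse_args()
backend = Backend[args.backend.upper()]

With this shape, the --enable_triton flag added in the diff above would become --backend triton.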

def _triton_client(self):
    if self._triton_client_instance is None:
        url = f"grpc://localhost:{self._triton_config.grpc_port}"
        self._triton_client_instance = AsyncioModelClient(
Member commented:

How is the AsyncioModelClient different from Triton's gRPC Python client?
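For context (an assumption, not confirmed in the PR): AsyncioModelClient appears to be PyTriton's asyncio client (pytriton.client.AsyncioModelClient), which wraps Triton's HTTP/gRPC clients behind a plain numpy-dict interface instead of tritonclient's InferInput/InferResult objects. A minimal usage sketch, with a placeholder model name and input name:

import asyncio

import numpy as np
from pytriton.client import AsyncioModelClient


async def main():
    # "hf-model" and "input_ids" are placeholders for whatever the server actually binds
    async with AsyncioModelClient("grpc://localhost:8001", "hf-model") as client:
        result = await client.infer_sample(
            input_ids=np.array([101, 2023, 102], dtype=np.int64)
        )
        print(result)  # dict mapping output name -> numpy array


asyncio.run(main())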

@@ -191,6 +238,84 @@ def load(self) -> bool:
            self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})
        logger.info(f"successfully loaded tokenizer for task: {self.task}")

        if self.use_triton:
yuzisun (Member) commented Apr 29, 2024:

We can consider adding this option to PredictorConfig, so any KServe model is able to launch the triton server and load the model in the load function when the option is enabled.
https://github.com/kserve/kserve/blob/master/python/kserve/kserve/model.py#L89
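For illustration only: a rough sketch of the suggestion above, not the actual kserve API. The existing PredictorConfig fields are abbreviated, and enable_triton / triton_grpc_port are hypothetical names.

from dataclasses import dataclass


@dataclass
class PredictorConfig:
    # existing fields abbreviated; see kserve/model.py for the real definition
    predictor_host: str = ""
    predictor_protocol: str = "v1"
    # hypothetical additions suggested in the comment above
    enable_triton: bool = False
    triton_grpc_port: int = 8001


class Model:
    def __init__(self, name: str, predictor_config: PredictorConfig = None):
        self.name = name
        self.predictor_config = predictor_config or PredictorConfig()

    def load(self) -> bool:
        if self.predictor_config.enable_triton:
            # any KServe model could launch Triton and register itself here,
            # instead of each runtime (e.g. huggingfaceserver) doing it separately
            self._start_triton_server()
        return True

    def _start_triton_server(self) -> None:
        ...  # launch Triton / PyTriton and bind the model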

Comment on lines +289 to +307

def infer_fn(requests: List[Request]):
    responses = []
    for request in requests:
        input_tensors = {}
        for input_name, input_array in request.data.items():
            input_tensors[input_name] = torch.tensor(
                input_array, device=self.device
            )

        if (
            self.task == MLTask.text2text_generation.value
            or self.task == MLTask.text_generation
        ):
            outputs = self.model.generate(**input_tensors)
        else:
            outputs = self.model(**input_tensors)
        responses.append({"outputs": outputs.numpy()})
    return responses
Member commented:

The inference should be done by Triton, not by calling self.model. I guess infer_fn is for the Python backend?
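For context (an assumption about the design, not confirmed in the PR): this looks like the PyTriton pattern, where a Python infer function is bound to an embedded Triton instance and Triton's Python backend calls back into it, so the model still executes in this process behind Triton's endpoints. A minimal sketch of that binding, with placeholder tensor names and shapes rather than the PR's inputs_config/outputs_config:

import numpy as np
from pytriton.model_config import ModelConfig, Tensor
from pytriton.triton import Triton


def infer_fn(requests):
    # placeholder implementation: echo the first input tensor of each request
    return [{"outputs": next(iter(request.data.values()))} for request in requests]


triton = Triton()
triton.bind(
    model_name="hf-model",  # placeholder name
    infer_func=infer_fn,
    inputs=[Tensor(name="input_ids", dtype=np.int64, shape=(-1,))],
    outputs=[Tensor(name="outputs", dtype=np.int64, shape=(-1,))],
    config=ModelConfig(batching=False),
    strict=True,
)
triton.run()  # starts the Triton server without blocking the caller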

inputs=inputs_config,
outputs=outputs_config,
strict=True,
config=ModelConfig(batching=False, response_cache=True),
Member commented:

We might also want to expose the batching configs.
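For illustration only: a sketch of exposing the batching settings through CLI flags, assuming PyTriton's ModelConfig / DynamicBatcher API; the flag names are made up.

import argparse

from pytriton.model_config import DynamicBatcher, ModelConfig

parser = argparse.ArgumentParser()
# hypothetical flags; the diff above hardcodes ModelConfig(batching=False, response_cache=True)
parser.add_argument("--triton_batching", action="store_true", help="Enable batching in Triton")
parser.add_argument("--triton_max_batch_size", type=int, default=4)
parser.add_argument("--triton_max_queue_delay_us", type=int, default=100)
args = parser.parse_args()

model_config = ModelConfig(
    batching=args.triton_batching,
    max_batch_size=args.triton_max_batch_size,
    batcher=DynamicBatcher(max_queue_delay_microseconds=args.triton_max_queue_delay_us),
    response_cache=True,
)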

@@ -364,6 +508,10 @@ async def predict(
        if self.vllm_engine:
            raise InferenceError(VLLM_USE_GENERATE_ENDPOINT_ERROR)

        if self.use_triton:
            res = await self._triton_client.infer_sample(**input_batch)
Member commented:

Why is this calling infer_sample?
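For context (again assuming the client is PyTriton's AsyncioModelClient): infer_sample sends a single un-batched sample, while infer_batch expects arrays with a leading batch dimension and requires batching to be enabled on the bound model, so the answer likely depends on whether input_batch already carries a batch axis. A small sketch of the difference, with a placeholder input name:

import numpy as np
from pytriton.client import AsyncioModelClient


async def call_triton(client: AsyncioModelClient):
    # single sample: no batch dimension on the arrays
    single = await client.infer_sample(
        input_ids=np.array([101, 2023, 102], dtype=np.int64)
    )
    # batched request: leading batch dimension on every array
    batched = await client.infer_batch(
        input_ids=np.array([[101, 2023, 102], [101, 4937, 102]], dtype=np.int64)
    )
    return single, batched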
