support output logprobs with turbomind backend. #1391

Merged · merged 20 commits · Apr 21, 2024

Conversation

@irexyc (Collaborator) commented Apr 3, 2024

Motivation

Add logprobs output.

OpenAI uses different logprobs structures for the chat.completions and completions APIs, whereas vLLM uses the same structure for both. The completions-style logprobs structure is more user-friendly, so I followed vLLM and use it for both APIs.

Modification

  • Return the pytorch / turbomind engine output as a dataclass (a rough sketch follows this list).
  • Compute logprobs from the logits for the turbomind backend.
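
For context, a rough sketch of what such a dataclass-based engine output might look like. The class name and exact fields here are illustrative, inferred from the docstring and the Response object quoted later in this thread, not copied from the PR:

from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class EngineOutput:
    status: int           # response type (ResponseType in lmdeploy)
    token_ids: List[int]  # generated token ids
    num_token: int        # number of generated tokens; may exceed len(token_ids)
                          # by one when a trailing stop word is trimmed
    logprobs: Optional[List[Dict[int, float]]] = None  # per-token {token_id: logprob}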

Use cases (Optional)

from openai import OpenAI
client = OpenAI(base_url='http://0.0.0.0:23333/v1', api_key='sk-l6bdprDovMBW6bRs22B05f5dBa3f417d8bC13e3d131f73Aa')

model_name = client.models.list().data[0].id
completion = client.chat.completions.create(
  model=model_name,
  messages=[
    {"role": "user", "content": "讲一个笑话"}
  ],
  logprobs=True,
  top_logprobs=2,
  max_tokens=10,
  # stream=True
)

completion = client.completions.create(
  model=model_name,
  prompt="今天天气真好",
  logprobs=2,
  max_tokens=10,
  # stream=True
)
from lmdeploy import pipeline, GenerationConfig, PytorchEngineConfig

pipe = pipeline('/nvme/shared/vicuna-7b-v1.5/', backend_config=PytorchEngineConfig())
pipe('hello', gen_config=GenerationConfig(logprobs=10, top_k=40, max_new_tokens=10))
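
The per-token logprobs returned by the pipeline can then be inspected. This is a hedged sketch based on the Response structure shown later in this thread, where logprobs is a list of dicts mapping a candidate token id to its log probability:

response = pipe('hello', gen_config=GenerationConfig(logprobs=10, top_k=40, max_new_tokens=10))
# one {candidate_token_id: logprob} dict per generated token, aligned with response.token_ids
for token_id, candidates in zip(response.token_ids, response.logprobs):
    print(token_id, candidates.get(token_id), len(candidates))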

@lvhan028 added the enhancement (New feature or request) label on Apr 3, 2024
@lvhan028 (Collaborator) commented Apr 4, 2024

The build failed on the Windows platform.

@lvhan028 (Collaborator) commented:

You may merge the latest main to resolve the pr_ete_test workflow error.

@lvhan028 lvhan028 requested a review from AllentDan April 17, 2024 03:22
Args:
    status (ResponseType): the response type.
    token_ids (List[int]): the output token ids.
    num_token (int): the length of output token, for turbomind, num_token
Reviewer (Collaborator):

Is it possible that there is one extra token?

Author (irexyc):

When a stop word is hit:

if output[-1].item() in gen_config.stop_words:
    # the trailing stop word is dropped from token_ids, but len_ still counts it
    outputs = (status, output[:-1].tolist(), len_)

@@ -61,6 +61,8 @@ class ChatCompletionRequestQos(BaseModel):
    messages: Union[str, List[Dict[str, str]]]
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 1.0
    logprobs: Optional[bool] = False
    top_logprobs: Optional[int] = None
Reviewer (Collaborator):

Is there an upper bound on it?

Author (irexyc):

OpenAI's upper limit is 5. vLLM has no limit. turbomind is constrained by the top_k kernel, so the upper bound is 1024 (or 1023).

Reviewer (Collaborator):

Then the parameter's validity needs to be checked. It must not cause serious problems such as crashes or hangs.
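
A minimal validation sketch along those lines. The 1024 bound comes from the comment above; the function name and error type are illustrative, not the PR's actual code:

MAX_TOP_LOGPROBS = 1024  # cap imposed by turbomind's top_k sampling kernel

def check_top_logprobs(top_logprobs):
    # reject out-of-range values early instead of letting the kernel crash or hang
    if top_logprobs is not None and not 0 <= top_logprobs <= MAX_TOP_LOGPROBS:
        raise ValueError(
            f'top_logprobs must be in [0, {MAX_TOP_LOGPROBS}], got {top_logprobs}')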

@lvhan028 (Collaborator) commented:

from lmdeploy import pipeline, GenerationConfig, PytorchEngineConfig, TurbomindEngineConfig

pipe = pipeline('/workspace/140_models/InternLM/internlm2-chat-7b', backend_config=PytorchEngineConfig())
response = pipe('hello', gen_config=GenerationConfig(logprobs=10, top_k=1, max_new_tokens=10))
print(response)

The pytorch engine should warn that "logprobs" is not supported yet.
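
A hedged sketch of such a warning; the logger name, function name, and call site are illustrative:

import logging

logger = logging.getLogger('lmdeploy')

def warn_if_logprobs_unsupported(gen_config, backend):
    # the pytorch engine ignores logprobs for now, so warn instead of failing silently
    if backend == 'pytorch' and getattr(gen_config, 'logprobs', None):
        logger.warning('logprobs is not supported by the pytorch backend yet and will be ignored')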

@lvhan028 (Collaborator) commented Apr 17, 2024

from lmdeploy import pipeline, GenerationConfig, PytorchEngineConfig, TurbomindEngineConfig


pipe = pipeline('/workspace/140_models/InternLM/internlm2-chat-7b')
response = pipe('hello', gen_config=GenerationConfig(logprobs=10, top_k=1, max_new_tokens=10))
print(response)

The result is:

Response(text='你好!有什么我可以帮助你的吗?', generate_token_len=9, input_token_len=103, session_id=0, finish_reason='stop', token_ids=[77230, 60477, 69259, 74010, 68417, 68364, 61076, 60504], logprobs=[{77230: 0.0}, {60477: 0.0}, {69259: 0.0}, {74010: 0.0}, {68417: 0.0}, {68364: 0.0}, {61076: 0.0}, {60504: 0.0}])
  1. Why are all the prob values 0.0?
  2. generate_token_len does not match the lengths of token_ids and logprobs. I think this will confuse users, and it is also troublesome for us to explain. Can this be handled in turbomind.py?
  3. When top_k is set to 2, the result also looks wrong.

@irexyc (Collaborator, Author) commented Apr 17, 2024

With top_k=1 there is only one candidate token, so its probability is 1 and taking the log gives 0.

The length of logprobs should be the same as the length of token_ids. The mismatch with generate_token_len should be because a stop word was hit; as I recall, that count is kept for the kv_cache step.

What is the result when top_k is 2?
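
A toy check of that explanation (plain Python, not lmdeploy code): with top_k=1 only one candidate survives, so its renormalized probability is 1 and log(1) = 0.

import math

top_k_logits = [3.2]  # only one candidate survives top_k=1
denom = sum(math.exp(x) for x in top_k_logits)
probs = [math.exp(x) / denom for x in top_k_logits]
print(probs)                         # [1.0]
print([math.log(p) for p in probs])  # [0.0]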

@lvhan028 (Collaborator) commented:

At the pipeline level, the generate parameters and behavior at inference time need to be consistent with transformers.
We define consistent behavior as follows:

  1. Given the same batch of prompts and the same generate config, if every prompt in the batch gets the same result through transformers, then lmdeploy should behave the same way. We do not require the results to be exactly identical to transformers' results.

@lvhan028 (Collaborator) commented:

(quoting @irexyc) With top_k=1 there is only one candidate token, so its probability is 1 and taking the log gives 0. The length of logprobs should be the same as the length of token_ids; the mismatch with generate_token_len should be because a stop word was hit. What is the result when top_k is 2?

I forgot that the log is taken; then there should be no problem.

@lvhan028 (Collaborator) commented Apr 17, 2024

I suggest adding a unit test for the sampling kernel.

src/turbomind/kernels/sampling_topp_kernels.cu (outdated review thread, resolved)
async for res in generator:
    logprobs = None
    if request.logprobs and res.logprobs:
Reviewer (Collaborator):

Can there be a case here where request.logprobs is set but res.logprobs is missing?

Author (irexyc):

Yes, when using the pytorch backend.

@lvhan028 (Collaborator) commented:

Please add examples to pipeline.md and api_server.md, respectively, showing how to obtain logprobs.
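
A hedged sketch of the kind of api_server.md example being asked for, reusing the client and model_name from the use cases above and reading the completions-style logprobs. The field names follow the OpenAI completions API, which this PR says it mirrors; the exact attributes may differ:

resp = client.completions.create(model=model_name, prompt='今天天气真好',
                                 logprobs=2, max_tokens=10)
lp = resp.choices[0].logprobs
# tokens, their chosen logprobs, and the top-k alternatives per position
for token, token_logprob, top in zip(lp.tokens, lp.token_logprobs, lp.top_logprobs):
    print(token, token_logprob, top)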

lmdeploy/turbomind/turbomind.py (outdated review thread, resolved)
src/turbomind/kernels/sampling_topk_kernels.cu (outdated review thread, resolved)
@lzhangzz (Collaborator) commented:

We need to benchmark the performance impact of requesting logprobs.

@lvhan028 (Collaborator) commented:

(quoting @lzhangzz) We need to benchmark the performance impact of requesting logprobs.

internlm2-7b, rps 23.734

@lvhan028 (Collaborator) commented:

The evaluation test passed.

@irexyc (Collaborator, Author) commented Apr 20, 2024

With vocab = 92544:
[image attachment]

@lvhan028 merged commit b797a90 into InternLM:main on Apr 21, 2024
9 checks passed