
Model grammar support via BNF #59

Closed
EricLBuehler opened this issue Apr 3, 2024 · 20 comments
Labels: new feature (New feature or request), processing (Processing related to the model)

Comments

@EricLBuehler
Owner

EricLBuehler commented Apr 3, 2024

We will implement based on this.

The idea is as follows, given a parsed BNF grammar:

  1. While the model is calculating the logits, prepare the logit bias on a worker thread (from a pool).
  2. Run normal sampling first: if the returned token is valid under the grammar, avoid applying the logit bias.
  3. While normal sampling is running, apply the logit bias on a worker thread (from a pool).
  4. If normal sampling produced a token that would be invalid, rerun sampling with the logit bias applied.
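
Roughly, steps 2–4 could look something like this (a sketch with a hypothetical Grammar type and greedy argmax standing in for the real sampler; not the actual mistral.rs API):

use std::collections::{HashMap, HashSet};

// Hypothetical grammar state: knows which token ids are legal next.
struct Grammar {
    allowed: HashSet<u32>,
}

impl Grammar {
    fn is_allowed(&self, token: u32) -> bool {
        self.allowed.contains(&token)
    }

    // Steps 1/3: built on a worker thread while the model computes the logits.
    fn logit_bias(&self, vocab_size: u32) -> HashMap<u32, f32> {
        (0..vocab_size)
            .filter(|t| !self.is_allowed(*t))
            .map(|t| (t, f32::NEG_INFINITY))
            .collect()
    }
}

// Greedy argmax stands in for the real (temperature/top-k/top-p) sampler.
fn argmax(logits: &[f32]) -> u32 {
    logits
        .iter()
        .enumerate()
        .max_by(|a, b| a.1.total_cmp(b.1))
        .map(|(i, _)| i as u32)
        .unwrap()
}

fn sample_with_grammar(logits: &mut [f32], grammar: &Grammar, bias: &HashMap<u32, f32>) -> u32 {
    // Step 2: optimistic pass without the bias.
    let tok = argmax(logits);
    if grammar.is_allowed(tok) {
        return tok;
    }
    // Step 4: the optimistic token was invalid, so apply the bias and resample.
    for (t, b) in bias {
        logits[*t as usize] += *b;
    }
    argmax(logits)
}
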
EricLBuehler added the new feature and processing labels on Apr 3, 2024
EricLBuehler changed the title from "Model grammar support via GBNF" to "Model grammar support via BNF" on Apr 3, 2024
@lucasavila00
Contributor

lucasavila00 commented Apr 4, 2024

@EricLBuehler do you have further steps planned for it?

Looking at similar designs:

Outlines generates an FSM from a regex (using https://github.com/MegaIng/interegular), then transforms the FSM from text-based to token-based (https://github.com/outlines-dev/outlines/blob/main/outlines/fsm/regex.py#L47), which is later used to influence the logits by vLLM: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/guided_logits_processors.py#L28

SGLang (https://github.com/sgl-project/sglang) also uses the Outlines FSM, but additionally implements a jump-forward optimization mechanism: https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/constrained/jump_forward.py#L8 https://lmsys.org/blog/2024-02-05-compressed-fsm/

Llama.cpp seems to compute the FSM state as it walks it, performing the same logit operations (https://github.com/ggerganov/llama.cpp/blob/master/llama.cpp#L12510) after parsing the grammar and building this data structure: https://github.com/ggerganov/llama.cpp/blob/master/llama.cpp#L11954
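
For orientation, the token-level index Outlines builds boils down to roughly a map from FSM state to the tokens that keep it alive (plus their destination states); a toy sketch of that structure and the masking step (hypothetical types, not the actual Outlines API):

use std::collections::HashMap;

// state -> (token id -> next state), built offline by walking each token's text
// through the character-level FSM compiled from the regex.
type TokenFsm = HashMap<u32, HashMap<u32, u32>>;

// At decode time, mask every token that has no transition out of the current state.
fn apply_mask(fsm: &TokenFsm, state: u32, logits: &mut [f32]) {
    if let Some(allowed) = fsm.get(&state) {
        for (tok, logit) in logits.iter_mut().enumerate() {
            if !allowed.contains_key(&(tok as u32)) {
                *logit = f32::NEG_INFINITY;
            }
        }
    }
}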

@lucasavila00
Contributor

The vLLM roadmap also shows they plan to integrate lm-format-enforcer, which does not pre-build the FSM: vllm-project/vllm#3713

And also AICI: vllm-project/vllm#2888

AICI is a Rust library; I wonder whether it would be simpler for mistral.rs to implement it, or to implement it alongside other setups.

@EricLBuehler
Owner Author

@lucasavila00, I plan on working on this next week. Would you be able to put together a skeleton PR? This looks like an exciting development.

@lucasavila00
Contributor

Sure, I'll try to integrate AICI over the weekend.

@lucasavila00
Contributor

lucasavila00 commented Apr 6, 2024

Implementing AICI has proven too complex for me to build quickly.

What I learned on #81:

  • AICI is too young: there is no library published to crates.io, the protocol might change (e.g. "[RFC] drop pre/post callback, only leave mid", microsoft/aici#68), and so on.
  • AICI supports fast-forwarding, which is complicated to implement. SGLang does fast-forwarding by appending text, re-tokenizing it, and running a new request; SGLang has RadixAttention, which can cache requests very granularly, so the only overhead is re-tokenization.
  • Configuring the side-channel messages is trivial.

To run the server on that PR use

./target/release/mistralrs-server --aicirt=PATH_TO_AICIRT --port 1234 --log target/output.txt mistral-gguf 

where PATH_TO_AICIRT is the binary from the GitHub releases: https://github.com/microsoft/aici/releases/tag/v0.0.10

Use mistral-gguf because I hardcoded some settings instead of building the code that reads them from config.json and tokenizer.json; those places are annotated with TODOs.

The PR only partially implements the mid part of the protocol; pre and post have not been implemented.

@EricLBuehler
Owner Author

Thank you for working on this! We will probably begin work on this early next week.

Like you said, AICI is pretty complicated, and we will probably not use it. I will look into the SGLang method you mentioned. If you have any ideas for implementing grammars, please let me know!

@lucasavila00
Contributor

lucasavila00 commented Apr 6, 2024

I have this project, https://github.com/lucasavila00/LmScript, which supports both vLLM and SGLang, so those are really the only two backends I know well.

Both use Outlines FSM.

The RadixAttention used by SGLang makes it faster to use prompts that generate many small pieces, for example ad-hoc structured XML generation: https://github.com/lucasavila00/LmScript/blob/main/packages/client/src/backends/executor.ts#L78-L136

Also, SGLang has two different select-from-choices operations.

If one uses a regex with multiple options, the backend eagerly selects the most probable choice token by token. This can be unintuitive (e.g. guidance-ai/guidance#564).

If one uses the native SGLang select operation, it calculates the probability for each of the options as a whole: https://github.com/sgl-project/sglang/blob/ff99c38a0711ee82926840129db840a70e91f0d9/python/sglang/backend/runtime_endpoint.py#L191-L242

SGLang's select is excellent in terms of result quality, and it is implemented with just a few extra settings sent to the server that tell it which tokens to return logprobs for. Of course, this only works efficiently because the RadixAttention cache can reuse the computation of the common prefix.
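
As a sketch, that whole-option scoring amounts to something like the following, assuming the server already returned per-token logprobs for each option appended to the shared prefix (hypothetical input shape; the length normalization here is one reasonable choice, not necessarily exactly what SGLang does):

// token_logprobs[i]: the logprobs of option i's tokens when scored after the shared
// prefix (hypothetical input shape). Returns the index of the best option.
fn select(token_logprobs: &[Vec<f32>]) -> usize {
    token_logprobs
        .iter()
        .enumerate()
        // Score = mean token logprob, so longer options aren't penalized for length.
        .map(|(i, lps)| (i, lps.iter().sum::<f32>() / lps.len() as f32))
        .max_by(|a, b| a.1.total_cmp(&b.1))
        .map(|(i, _)| i)
        .unwrap()
}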

I like SGLang a lot. The only issue is that it takes a long time to pre-compile the regexes, which are then saved to disk and available for re-use. I agree with vllm-project/vllm#3713 "For endpoint products running model as a service with customers supplying many different schemas, the cost might not be acceptable."

I would like it if SGLang's design could compile regexes quickly or reject them for excessive complexity; that is the major flaw I see with it, aside from poor error messages when it receives invalid regexes and so on.

I did look into Rust's regex-automata for the regex compilation (https://github.com/rust-lang/regex/blob/master/regex-automata/src/dfa/mod.rs#L265-L270), but as the linked lines say, it is hard to make this compilation efficient.
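
For reference, once a DFA is built, walking it per byte is cheap; the expensive part the linked comment warns about is dense::DFA::new itself. A sketch against what I believe is the regex-automata 0.4 dfa API (for guided decoding you would configure the DFA anchored; this toy uses the defaults):

use regex_automata::dfa::{dense, Automaton};
use regex_automata::Input;

// Walk the DFA over `candidate` and report (still alive, match ending at end of input).
fn check(dfa: &dense::DFA<Vec<u32>>, candidate: &[u8]) -> (bool, bool) {
    let mut state = dfa.start_state_forward(&Input::new(candidate)).unwrap();
    for &b in candidate {
        state = dfa.next_state(state, b);
        if dfa.is_dead_state(state) {
            return (false, false);
        }
    }
    let eoi = dfa.next_eoi_state(state);
    (true, dfa.is_match_state(eoi))
}

fn main() {
    // This compilation step is the slow part for large patterns.
    let dfa = dense::DFA::new(r"[a-z]+@[a-z]+\.com").unwrap();
    println!("{:?}", check(&dfa, b"user@example.com"));
}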

@lucasavila00
Contributor

Ah, SGLang doesn't do token healing. Only Guidance does https://towardsdatascience.com/the-art-of-prompt-design-prompt-boundaries-and-token-healing-3b2448b0be38

It would be amazing if SGLang's design also had token healing.

Supporting it does require a lot of changes, though: one needs to remove part of the initial prompt and then let the model re-generate the removed part.

If this is a desired feature, you might want to keep it in mind while building the other parts to make an eventual implementation easier.
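
A rough sketch of the mechanics, simplified to a single trailing token (toy vocabulary type, not a proposal for an actual API): drop the last prompt token and restrict the first sampled token to ones whose text extends the dropped text.

// Toy view of the vocabulary: the text of each token id.
struct Vocab {
    token_text: Vec<String>,
}

// Token healing: pop the trailing prompt token, then bias the first generation step
// so only tokens extending the removed text are allowed. The model is then free to
// pick a "larger" token that crosses the original prompt/generation boundary.
fn token_healing(vocab: &Vocab, prompt_ids: &mut Vec<u32>, first_step_logits: &mut [f32]) {
    if let Some(last) = prompt_ids.pop() {
        let removed = &vocab.token_text[last as usize];
        for (tok, logit) in first_step_logits.iter_mut().enumerate() {
            if !vocab.token_text[tok].starts_with(removed.as_str()) {
                *logit = f32::NEG_INFINITY;
            }
        }
    }
}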

@lucasavila00
Contributor

Other interesting things I found:

The latest version of TGI (https://github.com/huggingface/text-generation-inference) also uses the Outlines FSM and supports regex decoding.

Llama.cpp has performance issues in its grammar implementation: ggerganov/llama.cpp#4218 (comment)

Kalosm implemented regex-based decoding using the regex-automata DFA, as I mentioned above: huggingface/candle#1945 (comment)

vLLM has prefix caching now: vllm-project/vllm#2614. (Unfortunately I couldn't test it, as my RTX 2070 is too old for it, so I can't tell whether it works as well as SGLang's RadixAttention. I'm in the process of replacing the GPU and should be able to compare the approaches in a week or so.)

@EricLBuehler
Owner Author

Thanks for the links. I really like the idea of using BNF; maybe it could be converted to a regex. After looking at the Kalosm implementation and this issue, I think we could implement this with the logit bias and just parse the BNF. Reading the llama.cpp issue, I think there are a few considerations we need to take into account when implementing:

  1. Use a parsing algorithm with polynomially bounded running time.
  2. Run normal sampling first, and only if it returns a token that does not match the grammar spend the time computing the grammar logit bias.

Potentially, we could use a regex DFA to handle the matching. What do you think about that plan?

Regarding prompt caching, I think that is workable, as we could just implement an eviction policy for KV caches. However, that is a separate topic, and perhaps you could raise a tracking issue for it?

@lucasavila00
Contributor

lucasavila00 commented Apr 8, 2024

I really like the idea of using BNF

Outlines has examples of .lark files. Terminals are regexes, and one can build the parser on top of the terminals: https://github.com/outlines-dev/outlines/blob/main/outlines/grammars/json.lark https://github.com/outlines-dev/outlines/blob/main/outlines/grammars/common.lark https://github.com/outlines-dev/outlines/blob/main/outlines/fsm/guide.py

Potentially, we could use a regex DFA to handle the matching. What do you think about that plan?

Generating the FSM for a "big" regex is slow. This SGLang example (https://github.com/sgl-project/sglang?tab=readme-ov-file#json-decoding) of a JSON object with ~10 fields takes a minute to compile.

An "interpreter" for the BNF/grammar, where each terminal is a compiled regex, could work. We would cache the terminals' regex FSMs and build big JSON objects without a long compilation delay. This is how Outlines implements the .lark files.
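
A sketch of the terminal-caching side of that idea (toy cache using regex-automata; Outlines' real implementation is in the guide.py linked above): compile each terminal's regex once, keyed by its pattern, so a later grammar or request reusing e.g. a JSON string terminal pays no compile cost.

use std::collections::HashMap;
use regex_automata::dfa::dense;

// Cache of compiled terminal DFAs, keyed by the terminal's regex pattern.
#[derive(Default)]
struct TerminalCache {
    dfas: HashMap<String, dense::DFA<Vec<u32>>>,
}

impl TerminalCache {
    // Compile a terminal's regex only the first time it is seen; later grammars
    // (or later requests) reusing the same terminal pay nothing.
    fn get_or_compile(&mut self, pattern: &str) -> &dense::DFA<Vec<u32>> {
        self.dfas
            .entry(pattern.to_string())
            .or_insert_with(|| dense::DFA::new(pattern).expect("terminal regex compiles"))
    }
}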

Regarding prompt caching, I think that is workable, as we could just implement an eviction policy for KV caches. However, that is a separate topic, and perhaps you could raise a tracking issue for it?

Sure.

@EricLBuehler
Owner Author

True, compiling a big BNF would be very costly; I like the idea of only compiling the terminals. I wonder if we can multithread the logit bias calculation with a thread pool. We could do something like:

let mut threadmap = HashMap::new();

And then accumulate them:

for handle in handles {
    let res = handle.join().unwrap()?;
    ws.extend(res);
}

The vocab size for a Mistral model is 32000, so iterating over all of it would be expensive in the sampling hot loop! We are currently ~3 ms/T slower than llama.cpp on an A10 + Mistral Q4_K_M GGUF, so I want to improve performance.
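
Filled in, a self-contained sketch of that chunked accumulation could use scoped threads from std (the allowed closure is a stand-in for the real grammar check):

use std::collections::HashMap;
use std::thread;

// Split the vocab into chunks, compute each chunk's (token -> bias) pairs on its own
// thread, then merge the partial results into one bias map.
fn build_bias(vocab_size: u32, n_threads: u32, allowed: &(dyn Fn(u32) -> bool + Sync)) -> HashMap<u32, f32> {
    let chunk = (vocab_size + n_threads - 1) / n_threads;
    let mut bias = HashMap::new();
    thread::scope(|s| {
        let handles: Vec<_> = (0..n_threads)
            .map(|i| {
                s.spawn(move || {
                    let (start, end) = (i * chunk, ((i + 1) * chunk).min(vocab_size));
                    (start..end)
                        .filter(|t| !allowed(*t))
                        .map(|t| (t, f32::NEG_INFINITY))
                        .collect::<Vec<_>>()
                })
            })
            .collect();
        for handle in handles {
            bias.extend(handle.join().unwrap());
        }
    });
    bias
}

With a 32000-token vocab and, say, 8 threads, each worker only touches 4000 tokens, and the final merge is proportional to the number of banned tokens.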

@lucasavila00
Contributor

lucasavila00 commented Apr 8, 2024

On performance, I think the explanation in the AICI protocol docs is good (https://github.com/microsoft/aici/blob/main/docs/aicirt-proto.md):

...
- the LLM schedules some of the non-suspended sequences to compute logits for
- the LLM informs AICIrt about the scheduled sequences;...
- the LLM starts computing logits
- AICIrt sends the logit biases to the LLM
- LLM adds computed logits and biases and samples tokens
...

The bias computation should be done on a different thread, concurrently with the logit calculation, and it should not start at sampling time but as soon as the GPU starts the logit step.

According to AICI, if done this way, there is little cost: https://github.com/microsoft/aici?tab=readme-ov-file#performance

For example, computing allowed token set in the 32000-strong vocabulary of Llama model takes:

- about 2.0ms for Yacc grammar of the C programming language
- about 0.3ms for a regular expression
- about 0.2ms for a substring constraint, from 4kB string

Since the logits calculation usually takes longer than that, and servers usually have many CPUs, there is effectively no cost.
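
In practice that ordering could be as simple as spawning the bias job right before the forward pass and joining it right before sampling; a toy sketch with placeholder functions (not the actual pipeline code):

use std::collections::HashMap;
use std::thread;

// Placeholder for whatever per-sequence constraint state ends up computing the bias.
fn compute_bias(vocab_size: u32) -> HashMap<u32, f32> {
    (0..vocab_size).map(|t| (t, 0.0)).collect()
}

// Stands in for the (much longer) GPU forward pass that produces the logits.
fn forward_pass(vocab_size: u32) -> Vec<f32> {
    vec![0.0; vocab_size as usize]
}

fn main() {
    let vocab_size = 32_000;
    // Start the bias computation as soon as the logit step is scheduled...
    let bias_handle = thread::spawn(move || compute_bias(vocab_size));
    // ...run the forward pass concurrently...
    let mut logits = forward_pass(vocab_size);
    // ...and the biases are already waiting by the time sampling starts.
    let bias = bias_handle.join().unwrap();
    for (tok, b) in bias {
        logits[tok as usize] += b;
    }
}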

@EricLBuehler
Owner Author

Ah, that is great: that way the logit biases are ready before the logits are! Even so, I think it would still be best to sample without applying the logit biases first, as an optimization, since applying them must be done sequentially and may iterate over most of the vocab. In fact, while the initial sampling is occurring, a worker thread can apply the logit bias.

If you want to open a PR laying the groundwork for some of this, such as the logit bias application and dual-sampling setup, please feel free! I am hoping to implement grammar support this week.

@lucasavila00
Contributor

Awesome!

I'll have a busy week. I fear I'll only be available next weekend.

Once I have the time I'll look around for what I can help with.

@lucasavila00
Contributor

lucasavila00 commented Apr 10, 2024

I was reading a bit more about DFAs, FSMs, and so on while thinking about how to implement this, and I stumbled upon the details of the AICI ABI:

https://github.com/microsoft/aici/tree/main/controllers/aici_abi

It works pretty much like Kalosm, in that it runs the regex for every logit. However, they do it on a token trie, so tokens that share a common prefix don't require re-computation.
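
The prefix-sharing idea can be sketched without the aici_abi types: lay the vocabulary out as a byte trie and advance the recognizer once per trie edge, so every token under a shared prefix reuses that work (toy trait here; aici_abi pushes and pops a single mutable state instead of cloning, but the principle is the same):

use std::collections::HashMap;

// Toy byte trie over the vocabulary: each node may terminate a token id.
#[derive(Default)]
struct TrieNode {
    children: HashMap<u8, TrieNode>,
    token: Option<u32>,
}

// Toy recognizer: clonable state that accepts bytes and can reject a prefix.
trait Recognizer: Clone {
    fn push_byte(&mut self, b: u8) -> bool; // false = no completion can ever match
}

// Walk the trie once; a shared prefix like "th" is pushed through the recognizer a
// single time for "the", "this", "that", ... instead of once per token.
fn allowed_tokens<R: Recognizer>(node: &TrieNode, state: &R, out: &mut Vec<u32>) {
    if let Some(tok) = node.token {
        out.push(tok);
    }
    for (&b, child) in &node.children {
        // Owned copy of the recognizer state for this branch.
        let mut next = R::clone(state);
        if next.push_byte(b) {
            allowed_tokens(child, &next, out);
        }
    }
}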

They provide the grammar with regex terminals: https://github.com/microsoft/aici/tree/main/controllers/aici_abi#lr1-grammars

Regex: https://github.com/microsoft/aici/tree/main/controllers/aici_abi#regular-expressions

And a low-level API like the one required in the sampler, which can be used for both the first optimistic check and the re-sampling: https://github.com/microsoft/aici/tree/main/controllers/aici_abi#low-level-interface

It requires an instance of TokTrie, which can be built the way I did in the previous AICI PR: https://github.com/EricLBuehler/mistral.rs/pull/81/files#diff-8e7ab085145c61f5613962a6b30db55a222fa6bf6e432e82df9ab90dbfb4627aR900-R903

And a stateful Recognizer instance per sequence, like this regex one (https://github.com/microsoft/aici/blob/main/controllers/aici_abi/src/rx.rs#L20) or grammar one (https://github.com/microsoft/aici/blob/main/controllers/aici_abi/src/cfg.rs).

@EricLBuehler what are your thoughts?

To me it looks like we can do this by adding aici-abi as a dependency. (Note that the previous PR also added aici-runtime, which I'm no longer proposing to add.)

We could also start by copying the code and editing it. However, after reading it for a bit, I think using it as a dependency should be good enough...

@ealmloff

I also published kalosm-sample as a separate library if you want to re-use the logic. It currently walks the DFA as it generates tokens, so each individual token takes a constant amount of time to validate regardless of the generation length. If you are generating a series of different but similar sequences, I would be happy to merge support for a prompt cache.

One technique kalosm uses that I haven't seen mentioned here is constraint-accelerated batching: if you see that the next n tokens must be a specific string, you can load all of those tokens into the KV cache at once in a larger batch to accelerate generation.
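
A sketch of that fast path (hypothetical Constraint trait; the real thing would feed the forced run through the normal prefill path): when the constraint pins down the next chunk of text exactly, tokenize it and schedule all of those tokens in one batch instead of decoding them one at a time.

// Hypothetical constraint handle: reports a continuation that is fully forced, if any.
trait Constraint {
    fn forced_text(&self) -> Option<String>;
}

// Returns the token ids to feed to the model in the next forward pass: either the
// whole forced run as one larger batch, or a single normally-sampled token.
fn next_batch(
    constraint: &dyn Constraint,
    tokenize: impl Fn(&str) -> Vec<u32>,
    sample_one: impl Fn() -> u32,
) -> Vec<u32> {
    match constraint.forced_text() {
        // All forced tokens go into the KV cache at once, skipping per-token sampling.
        Some(text) => tokenize(&text),
        // Otherwise fall back to ordinary one-token-at-a-time decoding.
        None => vec![sample_one()],
    }
}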

@EricLBuehler
Owner Author

@lucasavila00, I think it would be good to do something similar to the token trie + kalosm-style regex matching.

@ealmloff, thanks for the link! If you would be able to merge support for a prompt cache, that would be much appreciated! I was wondering, in the kalosm internals, do you apply the Parser to every token in the vocab to make a logit bias, or do you do something else?

@ealmloff

ealmloff commented Apr 10, 2024

I was wondering, in the kalosm internals, do you apply the Parser to every token in the vocab to make a logit bias, or do you do something else?

Currently, yes, I do that here. If you use top_k, you could parse tokens only until you have at least k valid tokens, instead of parsing every logit.
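
That top_k shortcut could look something like this: visit tokens from highest logit down and stop running the parser once k valid candidates are found (is_valid stands in for the kalosm parser check; the sort still touches every logit, but the expensive parsing does not):

// Returns the k highest-logit tokens the constraint accepts, parsing only as many
// vocabulary entries as needed instead of all of them.
fn top_k_valid(logits: &[f32], k: usize, is_valid: impl Fn(u32) -> bool) -> Vec<u32> {
    let mut order: Vec<u32> = (0..logits.len() as u32).collect();
    // Sort token ids by logit, highest first.
    order.sort_unstable_by(|a, b| logits[*b as usize].total_cmp(&logits[*a as usize]));
    let mut out = Vec::with_capacity(k);
    for tok in order {
        if is_valid(tok) {
            out.push(tok);
            if out.len() == k {
                break;
            }
        }
    }
    out
}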

@EricLBuehler
Owner Author

Implemented in #103, thank you @lucasavila00 !
