Glm4 mtp optimizations #4
                
I've created this draft to share my findings on what to fix or improve to make MTP usable. Currently, MTP's output quality is good, but its performance is worse than not using it at all. Therefore, it's not enough to be on par with the baseline; we need to be faster.
My initial plan is to find areas for improvement. It's not necessary to implement everything at once, but some of these should be on our radar for the future. They are:

- Graph reuse
- `llama_context::decode` calls
- Multi-token drafts

There are likely more things to improve, but for now I find these to be the most impactful. Below are my thoughts on each:
1) Graph Reuse: The baseline implementation always reuses the graph. The process is simple: it stores the graph, and in the next call to `llama_context::process_ubatch` it checks whether the stored graph can be reused; if not, it's deleted and the new one is stored. This works well after the first token is generated, since subsequent graphs are identical. The main bottleneck isn't calling `llama_model::build_graph` constantly, but rather `ggml_backend_sched_alloc_graph`, which has to allocate the compute resources on the backend.

The first fix was simple: just store one graph. In this case the main model's token-generation graph, which is one of the most expensive, will always be reused. On my machine this gave an uplift of 13.8% for small prompts.
Current state: Halted.
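To make the idea concrete, here is a minimal sketch of the single-stored-graph approach, assuming a hypothetical cache struct and helper; the names below are not the actual llama.cpp API, and the real reuse check lives inside `llama_context::process_ubatch`:

```cpp
// Hypothetical sketch: keep exactly one cached graph (the main model's
// token-generation graph) and skip re-allocation whenever it can be reused.

// Stand-in for "build via llama_model::build_graph and allocate via
// ggml_backend_sched_alloc_graph" -- the expensive path.
ggml_cgraph * build_and_alloc_gen_graph(const llama_ubatch & ubatch);

struct single_graph_cache {
    ggml_cgraph * gf = nullptr; // cached token-generation graph
};

static ggml_cgraph * get_gen_graph(single_graph_cache & cache, const llama_ubatch & ubatch) {
    // Only single-token generation ubatches are candidates for reuse; prompt
    // processing and the MTP graphs still take the slow path.
    if (cache.gf != nullptr && ubatch.n_tokens == 1) {
        // Cheap path: skips llama_model::build_graph and, more importantly,
        // ggml_backend_sched_alloc_graph, which is the real bottleneck.
        return cache.gf;
    }
    // Slow path: rebuild, re-allocate on the backend, and store for next time.
    cache.gf = build_and_alloc_gen_graph(ubatch);
    return cache.gf;
}
```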
After that, I tried to store a graph for every operation, or at least the ones that don't involve the KV cache. By applying `llm_graph_context::cb` to certain layers I could store and reuse the graphs, and I was able to compile and test this using only the CPU backend. However, I was unable to get it working with the offload policy. In theory the `cb` function should handle that, but something else seems to be blocking the allocation and computation specifically. Is it mixing the offload policies of the main model and the MTP? This needs a deeper investigation, and I lack the proper knowledge in this area, so I'm setting it aside for now.
2) `decode` calls: MTP was successfully implemented inside `decode`, but it uses the old logic where each operation requires an expensive function call. Here is a comparison of how many calls we make in different scenarios:

- LLM - Normal:
- Draft Model:
- MTP (Current Slow Implementation):
One way to make MTP more usable is to match the number of calls of a typical draft model. To do that, it's necessary to combine the KV cache update and the draft generation into a single call.
Current state: In progress.
I successfully merged the KV cache update with the draft generation. This required creating a custom batch and sinfo, and changing some of the logic around the embeddings and hidden states that the MTP needs. The version in this branch produces correct output, so it isn't breaking quality, but the draft acceptance rate has dropped to around 25%. I believe this happens because the first step (the KV update) correctly uses the hidden state from the main model, while the subsequent draft step uses a new hidden state generated by the MTP itself during the update. I still need to confirm this theory and apply a fix, which should hopefully bring the acceptance rate back to its previous level.
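If the theory holds, the fix is mostly about which hidden state the draft step reads. Here is a minimal sketch of the intended data flow, with purely hypothetical helper names standing in for the two graph executions inside the merged call:

```cpp
// Hypothetical helpers: the MTP KV-cache update and the single-token draft.
const float * mtp_kv_update  (llama_context * ctx, const float * h, llama_token tok);
llama_token   mtp_draft_token(llama_context * ctx, const float * h, llama_token tok);

// Merged update + draft for the last accepted token.
llama_token mtp_update_and_draft(llama_context * ctx,
                                 const float   * h_main, // hidden state from the main model
                                 llama_token     last_accepted) {
    // Step 1: KV-cache update. This already consumes the correct hidden state
    // (h_main) and produces the MTP's own hidden state as a by-product.
    const float * h_mtp = mtp_kv_update(ctx, h_main, last_accepted);

    // Step 2: draft the next token. Suspected cause of the ~25% acceptance
    // rate: this step currently reads h_mtp from step 1. If the theory is
    // right, it should keep reading h_main instead.
    (void) h_mtp;
    return mtp_draft_token(ctx, h_main, last_accepted);
}
```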
One last thing: this change will still require a separate warmup call on the first interaction, but this is less impactful than merging the update and draft steps. To merge the warmup step, it would be necessary to track the sinfo to know when the prompt processing has finished its last batch, and then insert a new slot for the draft token.
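A rough sketch of that bookkeeping, with hypothetical names (the real sinfo handling lives in the batch-splitting code):

```cpp
// Hypothetical: while splitting the prompt into ubatches, detect the final
// chunk so one extra slot can be reserved for the draft token, letting the
// MTP warmup ride along instead of needing a separate decode call afterwards.
static bool is_last_prompt_chunk(uint32_t n_tokens_done,
                                 uint32_t n_tokens_in_chunk,
                                 uint32_t n_prompt_tokens) {
    return n_tokens_done + n_tokens_in_chunk >= n_prompt_tokens;
}
```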
3) Multi-token drafts: We discussed this in another PR. The problem was that for each new draft token, the MTP's KV cache needed to be updated, which was painful to do before. Now that we are using the `decode` function, it's more feasible: if the unified update/draft implementation works, we could simply increase the batch and sinfo size to make the model draft more tokens.
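A rough sketch of what that could look like, using a simplified signature of the hypothetical `mtp_update_and_draft()` from point 2 (hidden-state plumbing omitted):

```cpp
#include <vector>
#include "llama.h"

// Hypothetical unified call from point 2: updates the MTP KV cache and drafts one token.
llama_token mtp_update_and_draft(llama_context * ctx, llama_token last_token);

// Draft n_draft tokens by feeding each drafted token back into the unified call.
// The loop only shows the token feedback; the actual implementation would grow
// the batch/sinfo by one slot per drafted token instead of issuing separate calls.
static std::vector<llama_token> mtp_draft_n(llama_context * ctx, llama_token last_accepted, int n_draft) {
    std::vector<llama_token> drafts;
    llama_token cur = last_accepted;
    for (int i = 0; i < n_draft; ++i) {
        cur = mtp_update_and_draft(ctx, cur);
        drafts.push_back(cur);
    }
    return drafts;
}
```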
These are some of my ideas. I'd appreciate any insights you might have on how to better handle some of these things, or even new ideas for improvements that I haven't spotted here.