spec: support MTP #6
Conversation
After the refactoring, all the state management of the draft context is performed outside of `common/speculative.cpp`:

```diff
diff --git a/common/speculative.cpp b/common/speculative.cpp
index ef13edd34..95329b8a6 100644
--- a/common/speculative.cpp
+++ b/common/speculative.cpp
@@ -592,19 +592,6 @@ struct common_speculative_state_mtp : public common_speculative_impl {
auto & draft_tokens = *dp.result;
draft_tokens.clear();
- if (last_n_drafted[seq_id] > 0) {
- const int32_t n_to_drop = (int32_t) last_n_drafted[seq_id] - 1;
- if (n_to_drop > 0) {
- const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id);
- if (pos_max >= 0) {
- const llama_pos drop_from = pos_max - n_to_drop + 1;
- llama_memory_seq_rm(llama_get_memory(ctx_dft), seq_id, drop_from, -1);
- }
- }
- last_n_drafted[seq_id] = 0;
- last_n_accepted[seq_id] = 0;
- }
-
// Effective draft length: min(global cap, per-sequence override).
int32_t n_max = std::max(1, params.n_max);
if (dp.n_max > 0) {
@@ -673,32 +660,9 @@ struct common_speculative_state_mtp : public common_speculative_impl {
cond_tok = best;
++pos;
}
-
- last_n_drafted[seq_id] = (uint16_t) draft_tokens.size();
}
void accept(llama_seq_id seq_id, uint16_t n_accepted) override {
- GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < last_n_drafted.size());
-
- auto * ctx_dft = this->params.ctx_dft;
-
- const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id);
- const int32_t n_drafted_last = (int32_t) last_n_drafted[seq_id];
-
- const int32_t n_to_drop = std::max(0, n_drafted_last - (int32_t) n_accepted - 1);
-
- if (pos_max < 0) {
- last_n_accepted[seq_id] = (int32_t) n_accepted;
- return;
- }
-
- if (n_to_drop > 0) {
- const llama_pos drop_from = pos_max - n_to_drop + 1;
- llama_memory_seq_rm(llama_get_memory(ctx_dft), seq_id, drop_from, -1);
- }
-
- last_n_drafted [seq_id] = 0;
- last_n_accepted[seq_id] = (int32_t) n_accepted;
}
};
```
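For context, this is roughly what the removed bookkeeping did, reconstructed as a standalone sketch (the function framing is mine; the memory-API calls are the ones visible in the diff): after drafting `n_drafted` tokens of which `n_accepted` survive verification, the tail of the draft context's KV cache holds positions that were never validated, and partial rollback trims just those instead of resetting the whole sequence.

```cpp
#include <algorithm>

#include "llama.h"

// Sketch of the partial-rollback trim that the diff removes: drop the
// n_drafted - n_accepted - 1 stale tail positions of this sequence from
// the draft context's KV cache.
static void rollback_draft_cache(llama_context * ctx_dft, llama_seq_id seq_id,
                                 int32_t n_drafted, int32_t n_accepted) {
    const int32_t n_to_drop = std::max(0, n_drafted - n_accepted - 1);
    if (n_to_drop <= 0) {
        return;
    }

    const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id);
    if (pos_max < 0) {
        return; // no cached positions for this sequence
    }

    // remove [pos_max - n_to_drop + 1, end) for this sequence only
    llama_memory_seq_rm(llama_get_memory(ctx_dft), seq_id, pos_max - n_to_drop + 1, -1);
}
```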
Give me ~1 hour and I'll open a PR here to simplify (wip: https://github.com/ggml-org/llama.cpp/tree/gg/spec-mtp-experiments)
In the partial rollback implementation, the accepted batch is not re-evaluated with the draft context, correct? I think this will narrow the difference a bit, though I'm not sure by how much. Here are the numbers on my DGX Spark (patched to add a draft acceptance loop):
Another thing: at low acceptance rates (< 0.5), the speed difference is going to be much larger. From anecdotal usage, with this PR I sometimes hit 9 tok/s when doing real coding work, whereas with partial rollback I never drop below 14 tok/s even when acceptance is low. You can perhaps try it and see; I feel the difference is quite real.
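As a back-of-the-envelope model (my assumption, not something measured in this thread): if each drafted token is accepted independently with probability `a` and the draft length is `n`, a verification pass yields `(1 - a^(n+1)) / (1 - a)` tokens on average. At `a < 0.5` nearly every pass pays the full draft-plus-verify cost for barely more than one token, which is where any extra per-pass overhead hurts most:

```cpp
#include <cmath>
#include <cstdio>

// Expected tokens produced per verification pass with draft length n and
// per-token acceptance probability a (independence assumed - a toy model).
int main() {
    const int n = 4;
    for (const double a : {0.3, 0.5, 0.8}) {
        const double tokens_per_pass = (1.0 - std::pow(a, n + 1)) / (1.0 - a);
        std::printf("a = %.1f -> %.2f tokens/pass\n", a, tokens_per_pass);
    }
    return 0;
}
```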
Did you use this branch or #7?
I used this branch, just saw #7 |
Just tried #7 as well, Qwen3.6 27B: `"wall_s_total": 100.33`. Somehow the acceptance rates are suspiciously high, maybe some accounting error. For reference, in
With the
You can observe the accepted drafts with
Yes, I'm also not sure. On Mac it is always useful for some reason; on CUDA it sometimes helps and sometimes doesn't. In any case, it can be adjusted with the

Regarding the partial rollback: it does bring a noticeable benefit on CUDA, but I still don't see a good way to support it cleanly. Among other drawbacks, the compute graph is no longer static. The logic is also not compatible with ngram speculative decoding, because that uses long drafts of ~64 tokens which would still need to be checkpointed. And for some reason that I still don't understand, it does not seem to help much on Mac.
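On the static-graph point, my reading (an illustration of the general constraint, not the PR's code): graph reuse requires consecutive batches to have the same shape, so any scheme where the per-round draft batch size depends on last round's acceptance count defeats it.

```cpp
#include <cstdint>
#include <cstdio>

// Illustration only: if the draft batch is one conditioning token plus a
// span whose length depends on the previous round's acceptance count, its
// size varies per round and the compute graph cannot be built once and reused.
static int32_t draft_batch_size(int32_t n_accepted_last) {
    return 1 + n_accepted_last;
}

int main() {
    for (const int32_t n_acc : {0, 2, 7}) {
        std::printf("accepted last round: %d -> draft batch size: %d\n",
                    n_acc, draft_batch_size(n_acc));
    }
    return 0;
}
```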
```cpp
// TODO: how to make it work with vision tokens?
if (batch_in.token == nullptr || batch_in.embd != nullptr) {
    pending_pos[seq_id] = -1;
    return true;
}
```
I'm not really sure what the correct way to process image embeddings with the MTP context is. In any case, vision MTP already seems to work to a good extent:
Here I ask it to OCR 100 random integers without speculative decoding and with MTP:
- Without spec decoding
- With MTP
With MTP it is ~2x faster, which means the MTP context "knows" about the integers in some way. But at the same time, I'm pretty sure the current way of processing is not 100% correct, because the `inp->tokens` tensor in the MTP graph is used with stale data when the input batch has image embeddings and no tokens.
I think we will figure this out later - not super important atm.
I have removed the partial rollback changes and isolated the changes to just the Qwen models. Things to work out:
- `n_seq > 1`

Note that partial rollback is extremely important for the speed-up here: for the MoE model there is actually a slowdown with MTP on this branch.
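For the `n_seq > 1` item, the shape of the per-sequence state is already visible in the diff above (vectors indexed by `seq_id`); here is a minimal self-contained sketch of that bookkeeping, with the field names borrowed from the diff and everything else my own framing:

```cpp
#include <cstdint>
#include <vector>

// Sketch: per-sequence draft bookkeeping, mirroring the last_n_drafted /
// last_n_accepted vectors that the diff above removed.
struct mtp_seq_state {
    std::vector<uint16_t> last_n_drafted;
    std::vector<int32_t>  last_n_accepted;

    explicit mtp_seq_state(size_t n_seq)
        : last_n_drafted(n_seq, 0), last_n_accepted(n_seq, 0) {}

    // record how many tokens were drafted for this sequence
    void on_draft(int32_t seq_id, size_t n_drafted) {
        last_n_drafted[(size_t) seq_id] = (uint16_t) n_drafted;
    }

    // record the verification result and clear the drafted count
    void on_accept(int32_t seq_id, uint16_t n_accepted) {
        last_n_drafted [(size_t) seq_id] = 0;
        last_n_accepted[(size_t) seq_id] = (int32_t) n_accepted;
    }
};
```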