Sampling optimizations #152
Conversation
Code Metrics Report
───────────────────────────────────────────────────────────────
Language   Files   Lines   Blanks   Comments   Code    Complexity
Rust       60      20036   1439     820        17777   1130
───────────────────────────────────────────────────────────────
Total      60      20036   1439     820        17777   1130
───────────────────────────────────────────────────────────────
Estimated Cost to Develop: 54,557
Estimated Schedule Effort: 10.99 months
Estimated People Required: 4.48
───────────────────────────────────────────────────────────────
Processed 677983 bytes, 0.678 megabytes (SI)
  // Sort by descending probability.
- argsort_indices.sort_by(|&i, &j| probs[j].partial_cmp(&probs[i]).unwrap());
+ argsort_indices.sort_unstable_by(|&i, &j| probs[j].partial_cmp(&probs[i]).unwrap());
This sort is unstable (i.e., may reorder equal elements), in-place (i.e., does not allocate), and O(n * log(n)) worst-case.
The benchmarks in the repository below show the difference between stable and unstable sorting. (I don't think actually adding the library is required.)
https://github.com/orlp/glidesort
This saves 500ms per token.
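As a minimal standalone sketch of the pattern in the diff (the function name `argsort_desc` is hypothetical, not from the PR): build an index vector, then sort the indices by descending probability with `sort_unstable_by`, which is in-place and allocation-free.

```rust
// Hypothetical sketch of the argsort-by-descending-probability step.
fn argsort_desc(probs: &[f32]) -> Vec<usize> {
    let mut indices: Vec<usize> = (0..probs.len()).collect();
    // sort_unstable_by may reorder equal elements, which is harmless
    // for sampling; in exchange it avoids the allocation a stable
    // merge sort would perform.
    indices.sort_unstable_by(|&i, &j| probs[j].partial_cmp(&probs[i]).unwrap());
    indices
}

fn main() {
    let probs = [0.1_f32, 0.7, 0.2];
    let order = argsort_desc(&probs);
    assert_eq!(order, vec![1, 2, 0]);
}
```

Note that `partial_cmp(...).unwrap()` panics on NaN logits, matching the behavior of the original `sort_by` line.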
  for (token_id, logit) in logits.iter_mut().enumerate() {
      let count = context.iter().filter(|x| **x as usize == token_id).count();
      *logit = *logit
          - count as f32 * frequency_penalty
          - if count > 0 { 1. } else { 0. } * presence_penalty;
  }
  let logits_len = logits.len();
- Tensor::from_vec(logits, logits_len, device)
+ Ok(logits)
Since this reads the context, we can't pre-allocate the tensor.
Allocating the tensor here was triggering a synchronization that cost 500ms.
Thank you!
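A minimal standalone sketch of the penalty step in the diff, operating on a plain `Vec<f32>` so no tensor allocation (and hence no device synchronization) happens in the hot loop. The function name `apply_penalties` is hypothetical, not from the PR:

```rust
// Hypothetical sketch of the frequency/presence penalty loop.
// Subtracts `frequency_penalty` once per prior occurrence of the token
// in `context`, plus `presence_penalty` once if it occurred at all.
fn apply_penalties(
    logits: &mut [f32],
    context: &[u32],
    frequency_penalty: f32,
    presence_penalty: f32,
) {
    for (token_id, logit) in logits.iter_mut().enumerate() {
        // Count how often this token already appeared in the context.
        let count = context.iter().filter(|&&x| x as usize == token_id).count();
        *logit -= count as f32 * frequency_penalty
            + if count > 0 { presence_penalty } else { 0.0 };
    }
}

fn main() {
    let mut logits = vec![1.0_f32, 1.0, 1.0];
    // Token 1 appeared twice: 1.0 - 2 * 0.5 - 0.25 = -0.25.
    apply_penalties(&mut logits, &[1, 1], 0.5, 0.25);
    assert_eq!(logits, vec![1.0, -0.25, 1.0]);
}
```

Returning the penalized `Vec` directly (as `Ok(logits)` does in the diff) defers tensor creation, keeping the sampling path free of blocking allocations.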
Makes sampling fully async, removing synchronizations.
Generation goes from 54 t/s to 59 t/s using:
./target/profiling/mistralrs-server --prompt "Tell me 3 jokes." mistral-gguf
Same generation speed as Llama.cpp 😄