
Sampling optimizations #152

Merged — 7 commits, Apr 16, 2024

Conversation

lucasavila00 (Contributor) commented Apr 16, 2024

Makes sampling fully async, removing synchronizations.

Generation goes from 54 t/s to 59 t/s using ./target/profiling/mistralrs-server --prompt "Tell me 3 jokes." mistral-gguf

Same generation speed as Llama.cpp 😄


github-actions bot commented Apr 16, 2024

Code Metrics Report
  ───────────────────────────────────────────────────────────────────────────────
Language                 Files     Lines   Blanks  Comments     Code Complexity
───────────────────────────────────────────────────────────────────────────────
Rust                        60     20036     1439       820    17777       1130
───────────────────────────────────────────────────────────────────────────────
Total                       60     20036     1439       820    17777       1130
───────────────────────────────────────────────────────────────────────────────
Estimated Cost to Develop 54,557
Estimated Schedule Effort 10.992939 months
Estimated People Required 4.481762
───────────────────────────────────────────────────────────────────────────────
Processed 677983 bytes, 0.678 megabytes (SI)
───────────────────────────────────────────────────────────────────────────────
  

 // Sort by descending probability.
-argsort_indices.sort_by(|&i, &j| probs[j].partial_cmp(&probs[i]).unwrap());
+argsort_indices.sort_unstable_by(|&i, &j| probs[j].partial_cmp(&probs[i]).unwrap());
lucasavila00 (PR author) commented:

This sort is unstable (i.e., may reorder equal elements), in-place (i.e., does not allocate), and O(n * log(n)) worst-case.

The benchmarks here show the difference between stable and unstable sorts (I don't think adding the library itself is required):

https://github.com/orlp/glidesort

[benchmark chart]

This saves 500ms per token.
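A minimal sketch of the pattern this hunk uses (the function name here is illustrative, not from the PR): argsort token indices by descending probability with the in-place, non-allocating unstable sort. Reordering equal probabilities is harmless for sampling, so stability buys nothing.

```rust
/// Return the indices of `probs` ordered by descending probability.
/// `sort_unstable_by` sorts in place without allocating and may reorder
/// equal elements; `partial_cmp(...).unwrap()` is safe as long as the
/// probabilities contain no NaN.
fn argsort_desc(probs: &[f32]) -> Vec<usize> {
    let mut argsort_indices: Vec<usize> = (0..probs.len()).collect();
    argsort_indices.sort_unstable_by(|&i, &j| probs[j].partial_cmp(&probs[i]).unwrap());
    argsort_indices
}
```

For example, `argsort_desc(&[0.1, 0.7, 0.2])` yields `[1, 2, 0]`, which top-k/top-p sampling can then walk front to back.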

 for (token_id, logit) in logits.iter_mut().enumerate() {
     let count = context.iter().filter(|x| **x as usize == token_id).count();
     *logit = *logit
         - count as f32 * frequency_penalty
         - if count > 0 { 1. } else { 0. } * presence_penalty;
 }
-let logits_len = logits.len();
-Tensor::from_vec(logits, logits_len, device)
+Ok(logits)
lucasavila00 (PR author) commented:

Since this reads the context, we can't pre-allocate the tensor.

Allocating the tensor here was triggering a synchronization that cost 500ms.
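A standalone sketch of the penalty pass above (the function signature is an assumption for illustration; the PR's version lives inside the sampler and returns the plain `Vec<f32>` rather than building a device tensor at this point):

```rust
/// Apply frequency and presence penalties to raw logits, given the tokens
/// already generated in `context`. Returning the Vec instead of
/// constructing a tensor here keeps the hot path free of host/device
/// synchronization; the caller builds the tensor once, afterwards.
fn apply_penalties(
    mut logits: Vec<f32>,
    context: &[u32],
    frequency_penalty: f32,
    presence_penalty: f32,
) -> Vec<f32> {
    for (token_id, logit) in logits.iter_mut().enumerate() {
        // How many times this token already appeared in the context.
        let count = context.iter().filter(|&&x| x as usize == token_id).count();
        *logit -= count as f32 * frequency_penalty
            + if count > 0 { presence_penalty } else { 0.0 };
    }
    logits
}
```

With `frequency_penalty = 0.5`, `presence_penalty = 1.0`, and context `[1, 1, 2]`, logits `[1.0, 2.0, 3.0]` become `[1.0, 0.0, 1.5]`: token 1 pays the frequency penalty twice plus the presence penalty once, token 2 pays each once, and token 0 is untouched.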

@EricLBuehler EricLBuehler merged commit 62d560f into EricLBuehler:master Apr 16, 2024
11 checks passed
@EricLBuehler (Owner) commented:
Thank you!
