[wip] refactor: add more features to vec_t
#1142
Conversation
Summary of Changes

Hello @yzh119, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request starts the process of enhancing the `vec_t` class to provide more flexible and accessible features, addressing points raised in recent issues. The current changes introduce the ability to access sub-vectors with a specified granularity via a new `lane` method, and include some necessary type corrections for internal vector data representation and casting.

Highlights

- Vector Granularity Access: A new `lane<size_t G>(size_t lane_idx)` method has been added to the `vec_t` class template and its specializations. It returns the sub-vector of size `G` (where `G` must divide the total vector size) at index `lane_idx`. This is intended to support operations on smaller vector units, such as pairs of BF16/FP16 elements.
- Type Corrections: The internal data type of some `vec_t` specializations (e.g., `vec_t<__nv_fp8_e4m3, 8>`, `vec_t<half, 4>`) has been corrected from `uint2` to `int2`. Corresponding pointer casts in `load`, `store`, `memcpy`, and the `fast_dequant_f8f16x4` helper function have also been updated from `uint2*` to `int2*`.
Code Review

This pull request introduces a new `lane` method to `vec_t` for accessing sub-vectors, and changes some internal data types from `uint2` to `int2`. The primary concern is a recurring pointer-arithmetic error in the implementation of the `lane` method across nearly all its specializations. The `uint2` to `int2` changes appear to be for consistency and are generally acceptable. It's recommended to correct the `lane` method's pointer logic and add specific tests for this new functionality.
```cpp
template <size_t G>
FLASHINFER_INLINE __device__ vec_t<__nv_fp8_e5m2, G> lane(size_t lane_idx) const {
  static_assert(vec_size % G == 0, "G must divide vec_size");
  return *reinterpret_cast<const vec_t<__nv_fp8_e5m2, G>*>(&data[lane_idx / vec_size] +
                                                           G * (lane_idx % vec_size));
}
```
The pointer arithmetic for the `lane` method in this generic `vec_t<__nv_fp8_e5m2, vec_size>` is incorrect: the arithmetic is performed on `int4*` instead of `const __nv_fp8_e5m2*`.

Corrected form:

```cpp
return *reinterpret_cast<const vec_t<__nv_fp8_e5m2, G>*>(
    reinterpret_cast<const __nv_fp8_e5m2*>(data) + G * lane_idx);
```
```cpp
FLASHINFER_INLINE __device__ vec_t<nv_bfloat16, G> lane(size_t lane_idx) const {
  static_assert(2 % G == 0, "G must divide vec_size");
  return *reinterpret_cast<const vec_t<nv_bfloat16, G>*>(&data + G * lane_idx);
}
```
The pointer arithmetic in the `lane` method has the same issue: `&data` is `nv_bfloat162*`, so the arithmetic should be based on `const nv_bfloat16*`.

Corrected form:

```cpp
return *reinterpret_cast<const vec_t<nv_bfloat16, G>*>(
    reinterpret_cast<const nv_bfloat16*>(&data) + G * lane_idx);
```
```cpp
FLASHINFER_INLINE __device__ vec_t<__nv_fp8_e4m3, G> lane(size_t lane_idx) const {
  static_assert(2 % G == 0, "G must divide vec_size");
  return *reinterpret_cast<const vec_t<__nv_fp8_e4m3, G>*>(&data + G * lane_idx);
}
```
The pointer arithmetic in the `lane` method appears incorrect: `&data` is of type `__nv_fp8x2_e4m3*`, while the calculation should be based on the size of `__nv_fp8_e4m3`.

A corrected approach:

```cpp
return *reinterpret_cast<const vec_t<__nv_fp8_e4m3, G>*>(
    reinterpret_cast<const __nv_fp8_e4m3*>(&data) + G * lane_idx);
```
```cpp
FLASHINFER_INLINE __device__ vec_t<__nv_fp8_e4m3, G> lane(size_t lane_idx) const {
  static_assert(4 % G == 0, "G must divide vec_size");
  return *reinterpret_cast<const vec_t<__nv_fp8_e4m3, G>*>(&data + G * lane_idx);
}
```
The pointer arithmetic in the `lane` method has the same issue as noted for `vec_t<__nv_fp8_e4m3, 2>`: the base pointer for the arithmetic should be cast to `const __nv_fp8_e4m3*`.

Corrected form:

```cpp
return *reinterpret_cast<const vec_t<__nv_fp8_e4m3, G>*>(
    reinterpret_cast<const __nv_fp8_e4m3*>(&data) + G * lane_idx);
```
```cpp
FLASHINFER_INLINE __device__ vec_t<float, G> lane(size_t lane_idx) const {
  static_assert(vec_size % G == 0, "G must divide vec_size");
  return *reinterpret_cast<const vec_t<float, G>*>(&data[lane_idx / vec_size] +
                                                   G * (lane_idx % vec_size));
}
```
The pointer arithmetic for the `lane` method in this generic `vec_t<float, vec_size>` is incorrect: the arithmetic is performed on `float4*` instead of `const float*`.

Corrected form:

```cpp
return *reinterpret_cast<const vec_t<float, G>*>(
    reinterpret_cast<const float*>(data) + G * lane_idx);
```
```diff
@@ -872,12 +915,17 @@ FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 4>::memcpy(__nv_fp8_e5m2* dst,

 template <>
 struct vec_t<__nv_fp8_e5m2, 8> {
-  uint2 data;
+  int2 data;
```
```cpp
template <size_t G>
FLASHINFER_INLINE __device__ vec_t<half, G> lane(size_t lane_idx) const {
  static_assert(1 % G == 0, "G must divide vec_size");
  return *reinterpret_cast<const vec_t<half, G>*>(&data + G * lane_idx);
}
```
```cpp
FLASHINFER_INLINE __device__ vec_t<float, G> lane(size_t lane_idx) const {
  static_assert(1 % G == 0, "G must divide vec_size");
  return *reinterpret_cast<const vec_t<float, G>*>(&data + G * lane_idx);
}
```
```cpp
FLASHINFER_INLINE __device__ vec_t<__nv_fp8_e5m2, G> lane(size_t lane_idx) const {
  static_assert(1 % G == 0, "G must divide vec_size");
  return *reinterpret_cast<const vec_t<__nv_fp8_e5m2, G>*>(&data + G * lane_idx);
}
```
```diff
@@ -599,12 +616,17 @@ FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 4>::memcpy(__nv_fp8_e4m3* dst,

 template <>
 struct vec_t<__nv_fp8_e4m3, 8> {
-  uint2 data;
+  int2 data;
```
From a dependent pull request's description: "We enabled and updated the fused quantization in all-reduce/moe all-reduce. Depends on #1142."
📌 Description

Recent communication operator PRs #1096 #1108 #1131 #1134 indicate that we need to add more features to the `vec_t` class to make it more accessible:

- elementwise operations on `vec_t` (+/-/*/max/etc.)
- `vec_t<T, N>::any(cond)`

This PR implements these functionalities.

🔍 Related Issues

#1096 #1108 #1131 #1134

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

Reviewer Notes