
[wip] refactor: add more features to vec_t #1142


Draft: wants to merge 1 commit into main

Conversation

yzh119 (Collaborator) commented Jun 12, 2025

📌 Description

Recent communication-operator work (#1096 #1108 #1131 #1134) indicates that we need to add more features to the vec_t class to make it more accessible:

  1. access elements at a granularity of 2 (because CUDA provides vector instructions for bf16 pairs, etc.)
  2. binary operations on vec_t (+/-/*/max/etc.)
  3. determine whether the vector contains an element satisfying some condition: vec_t<T, N>::any(cond)
  4. fp4 support
  5. rich vector type casting, such as int to float or float to int

This PR implements these functionalities.
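
For illustration, here is a rough sketch of how the extended API could be used once these features land. The method names (lane, any, cast_from, operator+) follow the wish list above but are assumptions rather than the final interface, and the header path is assumed as well:

```cuda
// Illustrative sketch only: lane/any/cast_from/operator+ are assumed names
// from the wish list above, not the final vec_t interface.
#include <cuda_bf16.h>

#include <flashinfer/vec_dtypes.cuh>  // header path assumed

using flashinfer::vec_t;

__global__ void vec_feature_sketch(const nv_bfloat16* in, float* out) {
  vec_t<nv_bfloat16, 8> a, b;
  a.load(in);      // existing 128-bit vectorized load
  b.load(in + 8);

  // (1) granularity-2 access: view two bf16 elements as one packed sub-vector.
  vec_t<nv_bfloat16, 2> pair = a.lane<2>(0);
  (void)pair;  // see the bf16x2 sketch later in this thread

  // (2) element-wise binary operation on whole vectors (proposed).
  vec_t<nv_bfloat16, 8> s = a + b;

  // (3) predicate over elements (proposed): true if any element is negative.
  bool has_neg = s.any([](nv_bfloat16 x) { return __bfloat162float(x) < 0.f; });

  // (5) rich type casting (proposed), e.g. bf16 -> fp32.
  vec_t<float, 8> f;
  f.cast_from(s);
  if (!has_neg) f.store(out);
}
```

Item (4), fp4 support, is not shown in the sketch.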

🔍 Related Issues

#1096 #1108 #1131 #1134

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

@gemini-code-assist bot left a comment

Summary of Changes

Hello @yzh119, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request starts the process of enhancing the vec_t class to provide more flexible and accessible features, addressing points raised in recent issues. The current changes introduce the ability to access sub-vectors with specified granularity via a new lane method and include some necessary type corrections for internal vector data representation and casting.

Highlights

  • Vector Granularity Access: I've added a new lane<size_t G>(size_t lane_idx) method to the vec_t class template and its specializations. This allows accessing a sub-vector of size G (where G must divide the total vector size) at a specific index lane_idx. This is intended to support operations on smaller vector units, like pairs of BF16/FP16 elements (a usage sketch follows this list).
  • Type Corrections: I've corrected the internal data type for some vec_t specializations (e.g., vec_t<__nv_fp8_e4m3, 8>, vec_t<half, 4>) from uint2 to int2. Corresponding pointer casts in load, store, memcpy, and the fast_dequant_f8f16x4 helper function have also been updated from uint2* to int2*.
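
As a concrete illustration of the granularity access, here is a hedged sketch of how lane<G>() could feed packed bf16x2 math; the header path and the reinterpret to nv_bfloat162 are assumptions, not code from this PR:

```cuda
// Sketch only: assumes lane<G>() returns a vec_t<T, G> view as described above.
#include <cuda_bf16.h>

#include <flashinfer/vec_dtypes.cuh>  // header path assumed

using flashinfer::vec_t;

__device__ nv_bfloat162 sum_as_pairs(const vec_t<nv_bfloat16, 8>& v) {
  nv_bfloat162 acc = __float2bfloat162_rn(0.f);
  // 8 elements viewed as 4 lanes of 2 elements each (G = 2 divides 8).
  for (size_t i = 0; i < 4; ++i) {
    vec_t<nv_bfloat16, 2> pair = v.lane<2>(i);
    // Reinterpret the 2-element sub-vector as a packed nv_bfloat162 so the
    // bf16x2 add instruction can be used on it.
    acc = __hadd2(acc, *reinterpret_cast<const nv_bfloat162*>(&pair));
  }
  return acc;
}
```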

@gemini-code-assist bot left a comment

Code Review

This pull request introduces a new lane method to vec_t for accessing sub-vectors and changes some internal data types from uint2 to int2. The primary concern is a recurring pointer arithmetic error in the implementation of the lane method across nearly all its specializations. The uint2 to int2 changes appear to be for consistency and are generally acceptable. It's recommended to correct the lane method's pointer logic and add specific tests for this new functionality.
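
On the testing recommendation, a minimal round-trip check along the following lines would catch the indexing error described above; the kernel, harness, and header path are hypothetical, not part of the PR:

```cuda
// Hypothetical test: split an 8-float vector into 2-float lanes and write them
// back in lane order. The output equals the input only if lane<G>() indexes in
// elements rather than in packed float4 units.
#include <cstdio>

#include <cuda_runtime.h>

#include <flashinfer/vec_dtypes.cuh>  // header path assumed

using flashinfer::vec_t;

__global__ void LaneRoundTrip(const float* in, float* out) {
  vec_t<float, 8> v;
  v.load(in);
  for (size_t i = 0; i < 4; ++i) {
    vec_t<float, 2> sub = v.lane<2>(i);
    sub.store(out + 2 * i);
  }
}

int main() {
  float h_in[8], h_out[8];
  for (int i = 0; i < 8; ++i) h_in[i] = static_cast<float>(i);
  float *d_in, *d_out;
  cudaMalloc(&d_in, sizeof(h_in));
  cudaMalloc(&d_out, sizeof(h_out));
  cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
  LaneRoundTrip<<<1, 1>>>(d_in, d_out);
  cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
  int mismatches = 0;
  for (int i = 0; i < 8; ++i) mismatches += (h_out[i] != h_in[i]);
  printf("lane<2> round trip: %s\n", mismatches ? "FAILED" : "OK");
  cudaFree(d_in);
  cudaFree(d_out);
  return mismatches ? 1 : 0;
}
```

With the pointer arithmetic flagged in the comments below, lane indices beyond 0 would step in whole float4 units and read past the vector's storage, so the comparison would fail.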

Comment on lines +981 to +985
template <size_t G>
FLASHINFER_INLINE __device__ vec_t<__nv_fp8_e5m2, G> lane(size_t lane_idx) const {
static_assert(vec_size % G == 0, "G must divide vec_size");
return *reinterpret_cast<const vec_t<__nv_fp8_e5m2, G>*>(&data[lane_idx / vec_size] +
G * (lane_idx % vec_size));
Severity: high

The pointer arithmetic for the lane method in this generic vec_t<__nv_fp8_e5m2, vec_size> is incorrect: the arithmetic is performed on int4* instead of const __nv_fp8_e5m2*.

Corrected form:
return *reinterpret_cast<const vec_t<__nv_fp8_e5m2, G>*>(reinterpret_cast<const __nv_fp8_e5m2*>(data) + G * lane_idx);


Comment on lines +1325 to +1327
FLASHINFER_INLINE __device__ vec_t<nv_bfloat16, G> lane(size_t lane_idx) const {
static_assert(2 % G == 0, "G must divide vec_size");
return *reinterpret_cast<const vec_t<nv_bfloat16, G>*>(&data + G * lane_idx);
Severity: high

The pointer arithmetic in the lane method has the same issue. &data is nv_bfloat162*. Arithmetic should be based on const nv_bfloat16*.

Corrected form:
return *reinterpret_cast<const vec_t<nv_bfloat16, G>*>(reinterpret_cast<const nv_bfloat16*>(&data) + G * lane_idx);


Comment on lines +522 to +524
FLASHINFER_INLINE __device__ vec_t<__nv_fp8_e4m3, G> lane(size_t lane_idx) const {
static_assert(2 % G == 0, "G must divide vec_size");
return *reinterpret_cast<const vec_t<__nv_fp8_e4m3, G>*>(&data + G * lane_idx);
Severity: high

The pointer arithmetic in the lane method appears incorrect. &data is of type __nv_fp8x2_e4m3*. The calculation should be based on the size of __nv_fp8_e4m3.

A corrected approach:
return *reinterpret_cast<const vec_t<__nv_fp8_e4m3, G>*>(reinterpret_cast<const __nv_fp8_e4m3*>(&data) + G * lane_idx);


Comment on lines +573 to +575
FLASHINFER_INLINE __device__ vec_t<__nv_fp8_e4m3, G> lane(size_t lane_idx) const {
static_assert(4 % G == 0, "G must divide vec_size");
return *reinterpret_cast<const vec_t<__nv_fp8_e4m3, G>*>(&data + G * lane_idx);
Severity: high

The pointer arithmetic in the lane method has the same issue as noted for vec_t<__nv_fp8_e4m3, 2>. The base pointer for arithmetic should be cast to const __nv_fp8_e4m3*.

Corrected form:
return *reinterpret_cast<const vec_t<__nv_fp8_e4m3, G>*>(reinterpret_cast<const __nv_fp8_e4m3*>(&data) + G * lane_idx);


Comment on lines +1589 to +1592
FLASHINFER_INLINE __device__ vec_t<float, G> lane(size_t lane_idx) const {
static_assert(vec_size % G == 0, "G must divide vec_size");
return *reinterpret_cast<const vec_t<float, G>*>(&data[lane_idx / vec_size] +
G * (lane_idx % vec_size));
Severity: high

The pointer arithmetic for the lane method in this generic vec_t<float, vec_size> is incorrect: the arithmetic is performed on float4* instead of const float*.

Corrected form:
return *reinterpret_cast<const vec_t<float, G>*>(reinterpret_cast<const float*>(data) + G * lane_idx);


@@ -872,12 +915,17 @@ FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 4>::memcpy(__nv_fp8_e5m2* dst,

template <>
struct vec_t<__nv_fp8_e5m2, 8> {
-  uint2 data;
+  int2 data;
Severity: medium

The data member changes from uint2 to int2, consistent with the other similar specializations.

Comment on lines +1070 to +1073
template <size_t G>
FLASHINFER_INLINE __device__ vec_t<half, G> lane(size_t lane_idx) const {
static_assert(1 % G == 0, "G must divide vec_size");
return *reinterpret_cast<const vec_t<half, G>*>(&data + G * lane_idx);
Severity: medium

For vec_t<half, 1>, this lane implementation is correct.

Comment on lines +1507 to +1509
FLASHINFER_INLINE __device__ vec_t<float, G> lane(size_t lane_idx) const {
static_assert(1 % G == 0, "G must divide vec_size");
return *reinterpret_cast<const vec_t<float, G>*>(&data + G * lane_idx);
Severity: medium

For vec_t<float, 1>, this lane implementation is correct.

Comment on lines +775 to +777
FLASHINFER_INLINE __device__ vec_t<__nv_fp8_e5m2, G> lane(size_t lane_idx) const {
static_assert(1 % G == 0, "G must divide vec_size");
return *reinterpret_cast<const vec_t<__nv_fp8_e5m2, G>*>(&data + G * lane_idx);
Severity: medium

For vec_t<__nv_fp8_e5m2, 1>, this lane implementation is correct.

@@ -599,12 +616,17 @@ FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 4>::memcpy(__nv_fp8_e4m3* dst,

template <>
struct vec_t<__nv_fp8_e4m3, 8> {
-  uint2 data;
+  int2 data;
Severity: medium

The data member changes from uint2 to int2. This should be functionally equivalent as long as the bit patterns are handled and reinterpreted correctly.
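
To illustrate the bit-pattern point, here is a small host-side sketch (not from this PR) showing that eight packed fp8 bytes produce the same 64-bit pattern whether they are carried in a uint2 or an int2:

```cuda
// Illustration only: the container type (uint2 vs. int2) does not change the
// bytes that load/store/memcpy move around for 8 packed fp8 values.
#include <cstdio>
#include <cstring>

#include <cuda_fp8.h>
#include <cuda_runtime.h>

int main() {
  __nv_fp8_e4m3 elems[8];
  for (int i = 0; i < 8; ++i) elems[i] = __nv_fp8_e4m3(static_cast<float>(i));

  uint2 as_uint2;
  int2 as_int2;
  std::memcpy(&as_uint2, elems, sizeof(as_uint2));
  std::memcpy(&as_int2, elems, sizeof(as_int2));

  // Same 64-bit pattern either way; only the declared type differs.
  printf("bit patterns equal: %d\n",
         std::memcmp(&as_uint2, &as_int2, sizeof(as_uint2)) == 0);
  return 0;
}
```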

yzh119 pushed a commit that referenced this pull request Jul 2, 2025

## 📌 Description

We enabled and updated the fused quantization in all-reduce/moe
all-reduce.

## 🔍 Related Issues

Depends on #1142 

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes
