[WIP] Feature request: Implement GroupNormalization-21 #25060
base: main
Conversation
This is the op spec for your reference:
@justinchuby Unfortunately I hit an unexpected error while processing your comment. I've automatically reported this to GitHub. You can ask me to try again later by mentioning me in a new comment. If you want to contact GitHub about this error, please mention the following identifier so they can better serve you: Sorry for the inconvenience!
Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com>
@copilot try again
I've implemented the GroupNormalization-21 specification based on the details you provided. Key changes include:
The implementation now properly handles the two-stage computation as described in the specification, where stage-one precision is controlled by stash_type and stage two uses the original type's precision.
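As a concrete illustration of that two-stage pattern, here is a minimal standalone sketch (plain C++, not the actual onnxruntime kernel; the NCHW layout and helper names are assumptions): stage one accumulates per-group statistics in a wider "stash" precision, stage two normalizes and applies the per-channel scale and bias.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <vector>

// Standalone sketch of two-stage GroupNormalization over a flattened
// NCHW buffer. Stage one: mean/variance accumulated in double (standing
// in for the stash_type precision). Stage two: normalization plus
// per-channel scale/bias in the element type.
std::vector<float> GroupNormSketch(const std::vector<float>& X,
                                   const std::vector<float>& scale,
                                   const std::vector<float>& bias,
                                   int64_t N, int64_t C, int64_t spatial,
                                   int64_t num_groups, float epsilon) {
  assert(C % num_groups == 0);
  const int64_t cpg = C / num_groups;  // channels per group
  std::vector<float> Y(X.size());
  for (int64_t n = 0; n < N; ++n) {
    for (int64_t g = 0; g < num_groups; ++g) {
      // Stage one: statistics over every element of the group.
      double sum = 0.0, sq_sum = 0.0;
      const int64_t count = cpg * spatial;
      for (int64_t c = g * cpg; c < (g + 1) * cpg; ++c) {
        for (int64_t s = 0; s < spatial; ++s) {
          const double v = X[(n * C + c) * spatial + s];
          sum += v;
          sq_sum += v * v;
        }
      }
      const float mean = static_cast<float>(sum / count);
      const float var = static_cast<float>(sq_sum / count) - mean * mean;
      const float inv_std = 1.0f / std::sqrt(var + epsilon);
      // Stage two: normalize, then apply per-channel scale and bias.
      for (int64_t c = g * cpg; c < (g + 1) * cpg; ++c) {
        for (int64_t s = 0; s < spatial; ++s) {
          const int64_t i = (n * C + c) * spatial + s;
          Y[i] = (X[i] - mean) * inv_std * scale[c] + bias[c];
        }
      }
    }
  }
  return Y;
}
```

A constant input normalizes to zero, so the output reduces to the bias; that makes the sketch easy to sanity-check by hand.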
@copilot do not try to build the project when you are iterating, because the build is going to take too long |
@copilot Also try to register the cuda kernels if you can |
…t16 support
Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com>
Understood, I'll avoid building during iteration to save time. I'll focus on the code implementation and skip build validation during development.
I've added CUDA kernel support for GroupNormalization. Created:
The implementation follows the same two-stage computation pattern as the CPU version, with proper stash_type support.
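The CUDA registrations below use the form `GroupNorm<T, float>`, pairing the tensor element type with a separate accumulation ("stash") type. A minimal standalone sketch of that two-parameter pattern (illustrative only; `MeanOf` is a made-up helper, and `float`/`double` stand in for half/float):

```cpp
#include <cassert>
#include <vector>

// Illustrative two-type pattern: element type T, accumulation type AccT.
// In the kernels, T would be the tensor dtype (e.g. half) and AccT the
// stash type (float); here float/double play those roles.
template <typename T, typename AccT>
AccT MeanOf(const std::vector<T>& v) {
  AccT sum = AccT(0);
  for (const T& x : v) sum += static_cast<AccT>(x);  // accumulate in AccT
  return sum / static_cast<AccT>(v.size());
}
```

Separating the two types lets low-precision tensors keep their storage format while the reduction runs at full precision.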
Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com>
You can commit the suggested changes from lintrunner.
```cpp
// Opset 18-20 registrations (without stash_type)
#define REGISTER_ONNX_KERNEL_TYPED_VERSIONED(T) \
  ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(GroupNormalization, 18, 20, T, \
                                           KernelDefBuilder() \
                                               .TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
```
Suggested change (clang-format whitespace only):
```cpp
// Opset 18-20 registrations (without stash_type)
#define REGISTER_ONNX_KERNEL_TYPED_VERSIONED(T) \
  ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(GroupNormalization, 18, 20, T, \
                                           KernelDefBuilder() \
                                               .TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
```
```cpp
// Opset 21+ registrations (with stash_type)
#define REGISTER_ONNX_KERNEL_TYPED_21(T) \
  ONNX_CPU_OPERATOR_TYPED_KERNEL(GroupNormalization, 21, T, \
                                 KernelDefBuilder() \
                                     .TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
                                 GroupNorm);
```
Suggested change (clang-format whitespace only):
```cpp
// Opset 21+ registrations (with stash_type)
#define REGISTER_ONNX_KERNEL_TYPED_21(T) \
  ONNX_CPU_OPERATOR_TYPED_KERNEL(GroupNormalization, 21, T, \
                                 KernelDefBuilder() \
                                     .TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
                                 GroupNorm);
```
```cpp
ORT_ENFORCE(op_kernel_info.GetAttr("num_groups", &num_groups_).IsOK());

// stash_type is optional in opset 21, default to 1 (float32)
```
Suggested change (drop the blank line):
```cpp
ORT_ENFORCE(op_kernel_info.GetAttr("num_groups", &num_groups_).IsOK());
// stash_type is optional in opset 21, default to 1 (float32)
```
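Since `stash_type` is optional with a default of 1, the constructor has to fall back when the attribute is absent. A standalone sketch of that default-on-missing pattern (a hypothetical map-based attribute bag, not the onnxruntime `OpKernelInfo` API):

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <string>

// Hypothetical attribute lookup: return the stored value if present,
// otherwise the supplied default (e.g. stash_type defaults to 1 = float32).
int64_t GetAttrOrDefault(const std::map<std::string, int64_t>& attrs,
                         const std::string& name, int64_t default_value) {
  auto it = attrs.find(name);
  return it != attrs.end() ? it->second : default_value;
}
```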
```cpp
template<typename T>
Status GroupNorm::ComputeImpl(OpKernelContext* context, const Tensor* X, const Tensor* scale, const Tensor* bias) const {
```
Suggested change (`template <typename T>` with a space, per clang-format):
```cpp
template <typename T>
Status GroupNorm::ComputeImpl(OpKernelContext* context, const Tensor* X, const Tensor* scale, const Tensor* bias) const {
```
```cpp
const int64_t C = x_shape[1];  // channels

// Validate that channels are divisible by num_groups
ORT_RETURN_IF_NOT(C % num_groups_ == 0, "Number of channels must be divisible by num_groups");

const int64_t channels_per_group = C / num_groups_;

// Calculate spatial dimensions (H*W*... for everything after batch and channel dims)
```
Suggested change (drop the blank lines):
```cpp
const int64_t C = x_shape[1];  // channels
// Validate that channels are divisible by num_groups
ORT_RETURN_IF_NOT(C % num_groups_ == 0, "Number of channels must be divisible by num_groups");
const int64_t channels_per_group = C / num_groups_;
// Calculate spatial dimensions (H*W*... for everything after batch and channel dims)
```
```cpp
// Opset 21+ registrations (with stash_type)
#define REGISTER_CUDA_KERNEL_TYPED_21(T) \
  ONNX_OPERATOR_TYPED_KERNEL_EX(GroupNormalization, kOnnxDomain, 21, T, kCudaExecutionProvider, \
                                (*KernelDefBuilder::Create()) \
                                    .TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
                                GroupNorm<T, float>);
```
Suggested change (clang-format whitespace only):
```cpp
// Opset 21+ registrations (with stash_type)
#define REGISTER_CUDA_KERNEL_TYPED_21(T) \
  ONNX_OPERATOR_TYPED_KERNEL_EX(GroupNormalization, kOnnxDomain, 21, T, kCudaExecutionProvider, \
                                (*KernelDefBuilder::Create()) \
                                    .TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
                                GroupNorm<T, float>);
```
```cpp
epsilon_ = tmp_epsilon;

ORT_ENFORCE(op_kernel_info.GetAttr("num_groups", &num_groups_).IsOK());

// stash_type is optional in opset 21, default to 1 (float32)
```
Suggested change (drop the blank lines):
```cpp
epsilon_ = tmp_epsilon;
ORT_ENFORCE(op_kernel_info.GetAttr("num_groups", &num_groups_).IsOK());
// stash_type is optional in opset 21, default to 1 (float32)
```
```cpp
typedef typename ToCudaType<T>::MappedType CudaT;

// Inputs
```
Suggested change (drop the blank line):
```cpp
typedef typename ToCudaType<T>::MappedType CudaT;
// Inputs
```
```cpp
const int64_t C = x_shape[1];  // channels

// Validate that channels are divisible by num_groups
ORT_RETURN_IF_NOT(C % num_groups_ == 0, "Number of channels must be divisible by num_groups");

// Calculate spatial dimensions (H*W*... for everything after batch and channel dims)
```
Suggested change (drop the blank lines):
```cpp
const int64_t C = x_shape[1];  // channels
// Validate that channels are divisible by num_groups
ORT_RETURN_IF_NOT(C % num_groups_ == 0, "Number of channels must be divisible by num_groups");
// Calculate spatial dimensions (H*W*... for everything after batch and channel dims)
```
```cpp
}

Tensor* Y = ctx->Output(0, x_shape);
```
Suggested change (drop the blank line):
```cpp
}
Tensor* Y = ctx->Output(0, x_shape);
```
```diff
@@ -0,0 +1,154 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
```
Check warning — Code scanning / lintrunner: CLANGFORMAT/format. Run `lintrunner -a` to apply this patch.
```diff
@@ -0,0 +1,28 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
```
Check warning — Code scanning / lintrunner: CLANGFORMAT/format. Run `lintrunner -a` to apply this patch.
```diff
@@ -0,0 +1,94 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
```
Check warning — Code scanning / lintrunner: CLANGFORMAT/format. Run `lintrunner -a` to apply this patch.
Plan to implement ONNX GroupNormalization-21 support
After analyzing the codebase, I found that:
Implementation Checklist:
Approach:
Fixes #24538.