
[WIP] Feature request: Implement GroupNormalization-21 #25060


Draft PR: wants to merge 4 commits into main

Conversation

Copilot AI (Contributor) commented Jun 14, 2025

Plan to implement ONNX GroupNormalization-21 support

After analyzing the codebase, I found that:

  1. GroupNormalization is currently only implemented as a contrib op (com.microsoft.GroupNorm)
  2. The standard ONNX GroupNormalization operator is NOT implemented in the CPU provider
  3. It's supported in CoreML provider, indicating it's a valid ONNX operation
  4. The existing test uses opset 18 but only tests CoreML provider

Implementation Checklist:

  • Research ONNX GroupNormalization specification details
  • Create CPU GroupNormalization kernel implementation
  • Register GroupNormalization operation in CPU provider for opsets 18-21
  • Create or adapt GroupNormalization implementation logic
  • Update tests to cover the CPU provider and opset 21 (see the test sketch after the Approach list below)
  • Test and validate the implementation

Approach:

  • Follow existing LayerNorm/InstanceNorm patterns for CPU provider implementation
  • Leverage existing normalization helper code where possible
  • Implement minimal changes to add standard ONNX support alongside existing contrib op
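For the test item in the checklist above, a minimal opset-21 CPU test might look like the sketch below. It assumes ORT's OpTester helper from the provider test utilities; the shapes and expected values are illustrative only.

```cpp
#include "gtest/gtest.h"
#include "test/providers/provider_test_utils.h"

namespace onnxruntime {
namespace test {

TEST(GroupNormalizationOpTest, BasicFloatOpset21) {
  OpTester test("GroupNormalization", 21);
  test.AddAttribute<int64_t>("num_groups", 2);
  test.AddAttribute("epsilon", 1e-5f);

  // N=1, C=4, H=1, W=2: two groups of two channels, four values per group.
  test.AddInput<float>("X", {1, 4, 1, 2},
                       {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f});
  test.AddInput<float>("scale", {4}, {1.f, 1.f, 1.f, 1.f});
  test.AddInput<float>("bias", {4}, {0.f, 0.f, 0.f, 0.f});

  // Each group has mean 2.5 (resp. 6.5) and variance 1.25, so the normalized
  // values are approximately +/-0.44721 and +/-1.34164.
  test.AddOutput<float>("Y", {1, 4, 1, 2},
                        {-1.34164f, -0.44721f, 0.44721f, 1.34164f,
                         -1.34164f, -0.44721f, 0.44721f, 1.34164f});
  test.Run();
}

}  // namespace test
}  // namespace onnxruntime
```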

Fixes #24538.



@justinchuby (Contributor)

This is the op spec for your reference:

GroupNormalization - 21
Version
name: [GroupNormalization (GitHub)](https://github.com/onnx/onnx/blob/main/docs/Operators.md#GroupNormalization)

domain: main

since_version: 21

function: True

support_level: SupportType.COMMON

shape inference: False

This version of the operator has been available since version 21.

Summary
A GroupNormalization function. Carries out group normalization as described in the paper https://arxiv.org/abs/1803.08494

This operator transforms input according to

y = scale * (x - mean) / sqrt(variance + epsilon) + bias,
where the mean and variance are computed per instance per group of channels, and scale and bias should be specified for each channel. The number of channels C should be divisible by num_groups so that there are an equal number of channels per group.

The overall computation has two stages: the first stage normalizes the elements to have zero mean and unit variance for each instance in each group, and the second stage scales and shifts the results of the first stage. The floating-point precision used in the first stage is determined by the stash_type attribute. For example, if stash_type is 1, the operator casts all input variables to 32-bit float, performs the computation, and finally casts the normalized results back to the original type of X. The second stage does not depend on stash_type.

When the number of groups is the same as the number of channels, this operator is equivalent to InstanceNormalization. When there is only one group, this operator is equivalent to LayerNormalization.
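To make the two stages concrete, here is a minimal single-threaded reference sketch of the computation in plain C++ (NCHW flattened to (N, C, spatial); independent of any ORT kernel, with the statistics accumulated in double for numerical robustness):

```cpp
#include <cmath>
#include <cstdint>
#include <vector>

// Reference sketch of GroupNormalization: X has shape (N, C, spatial) in
// row-major order, scale/bias have shape (C), and the statistics are computed
// per (instance, group). Assumes C % num_groups == 0, per the spec.
std::vector<float> GroupNormRef(const std::vector<float>& X,
                                int64_t N, int64_t C, int64_t spatial,
                                int64_t num_groups,
                                const std::vector<float>& scale,
                                const std::vector<float>& bias,
                                float epsilon = 1e-5f) {
  const int64_t cpg = C / num_groups;        // channels per group
  const int64_t group_size = cpg * spatial;  // elements behind each statistic
  std::vector<float> Y(X.size());
  for (int64_t n = 0; n < N; ++n) {
    for (int64_t g = 0; g < num_groups; ++g) {
      const int64_t base = (n * C + g * cpg) * spatial;
      // Stage 1: mean and variance over this instance's group.
      double sum = 0.0, sq_sum = 0.0;
      for (int64_t i = 0; i < group_size; ++i) {
        const double v = X[base + i];
        sum += v;
        sq_sum += v * v;
      }
      const double mean = sum / group_size;
      const double inv_std = 1.0 / std::sqrt(sq_sum / group_size - mean * mean + epsilon);
      // Stage 2: y = scale * (x - mean) / sqrt(variance + epsilon) + bias,
      // with scale and bias indexed by the global channel.
      for (int64_t c = 0; c < cpg; ++c) {
        const int64_t ch = g * cpg + c;
        for (int64_t s = 0; s < spatial; ++s) {
          const int64_t idx = base + c * spatial + s;
          Y[idx] = scale[ch] * static_cast<float>((X[idx] - mean) * inv_std) + bias[ch];
        }
      }
    }
  }
  return Y;
}
```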

Attributes
epsilon - FLOAT (default is '1e-05'):

The epsilon value to use to avoid division by zero.

num_groups - INT (required):

The number of groups of channels. It should be a divisor of the number of channels C.

stash_type - INT (default is '1'):

The floating-point precision used in stage one of the computation.

Inputs
X (heterogeneous) - T:

Input data tensor. Dimensions for image cases are (N x C x H x W), where N is the batch size, C is the number of channels, and H and W are the height and width of the data. Statistics are computed for every group of channels over C, H, and W. For non-image cases, the dimensions are in the form of (N x C x D1 x D2 ... Dn).

scale (heterogeneous) - T:

Scale tensor of shape (C).

bias (heterogeneous) - T:

Bias tensor of shape (C).

Outputs
Y (heterogeneous) - T:

The output tensor of the same shape as X.

Type Constraints
T in ( tensor(bfloat16), tensor(double), tensor(float), tensor(float16) ):

Constrain input and output types to float tensors.

[GroupNormalization - 18 vs 21](https://onnx.ai/onnx/operators/text_diff_GroupNormalization_18_21.html)

Copilot AI (Contributor Author) commented Jun 14, 2025

@justinchuby Unfortunately I hit an unexpected error while processing your comment. I've automatically reported this to GitHub.

You can ask me to try again later by mentioning me in a new comment.

If you want to contact GitHub about this error, please mention the following identifier so they can better serve you: 5e4b9df7-12c0-4e39-8582-7eec84f30bb3

Sorry for the inconvenience!

Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com>
@justinchuby (Contributor)

@copilot try again

Copilot AI (Contributor Author) commented Jun 14, 2025

> @copilot try again

I've implemented the GroupNormalization-21 specification based on the details you provided. Key changes include:

  1. Proper stash_type implementation: When stash_type=1 (default), the first stage (mean/variance calculation) now uses float32 precision as specified, then casts back to original type for the second stage (scale/bias application).

  2. Added BFloat16 support: Extended type support to include tensor(bfloat16) as required by the specification.

  3. Fixed computation logic: Updated the implementation to exactly match the ONNX spec formula: y = scale * (x - mean) / sqrt(variance + epsilon) + bias

The implementation now properly handles the two-stage computation as described in the specification, where stage one precision is controlled by stash_type and stage two uses original type precision.
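Roughly, that two-stage split looks like the following sketch for one (instance, group) slice (a hypothetical helper, not the PR's actual kernel code): with stash_type=1 the statistics are computed in float regardless of T, and the result is cast back to T.

```cpp
#include <cmath>
#include <cstdint>

// Hypothetical sketch for one (instance, group) slice: x and y point at
// cpg * spatial contiguous elements, scale/bias at this group's cpg channels.
// T is any type convertible to/from float (float, double, a half type, ...).
template <typename T>
void GroupNormSliceStash32(const T* x, T* y, int64_t cpg, int64_t spatial,
                           const T* scale, const T* bias, float epsilon) {
  const int64_t group_size = cpg * spatial;
  // Stage 1 (stash_type = 1): accumulate the statistics in float, whatever T is.
  float sum = 0.f, sq_sum = 0.f;
  for (int64_t i = 0; i < group_size; ++i) {
    const float v = static_cast<float>(x[i]);
    sum += v;
    sq_sum += v * v;
  }
  const float mean = sum / group_size;
  const float inv_std = 1.f / std::sqrt(sq_sum / group_size - mean * mean + epsilon);
  // Stage 2: apply the per-channel scale and bias, casting back to T.
  for (int64_t c = 0; c < cpg; ++c) {
    for (int64_t s = 0; s < spatial; ++s) {
      const int64_t i = c * spatial + s;
      const float v = static_cast<float>(x[i]);
      y[i] = static_cast<T>(static_cast<float>(scale[c]) * (v - mean) * inv_std +
                            static_cast<float>(bias[c]));
    }
  }
}
```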

@justinchuby (Contributor)

@copilot do not try to build the project when you are iterating, because the build is going to take too long

@justinchuby (Contributor)

@copilot Also try to register the cuda kernels if you can

…t16 support

Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com>
Copilot finished work on behalf of justinchuby June 14, 2025 22:17
Copilot AI (Contributor Author) commented Jun 14, 2025

> @copilot do not try to build the project when you are iterating, because the build is going to take too long

Understood, I'll avoid building during iteration to save time. I'll focus on the code implementation and skip build validation during development.

> @copilot Also try to register the cuda kernels if you can

I've added CUDA kernel support for GroupNormalization. Created:

  • CUDA GroupNorm implementation (group_norm.h/cc)
  • CUDA kernel implementation (group_norm_impl.h/cu)
  • Registered kernels in CUDA execution provider for opsets 18-21
  • Added support for float, double, MLFloat16, and BFloat16 types

The implementation follows the same two-stage computation pattern as the CPU version with proper stash_type support.
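For illustration only (this is a sketch, not the PR's group_norm_impl.cu), a naive CUDA kernel following the same two-stage pattern could assign one thread block per (instance, group) slice and reduce the float statistics in shared memory:

```cuda
#include <cstdint>
#include <cuda_runtime.h>

// Naive sketch: one block per (instance, group) slice, launched as
//   GroupNormKernel<<<N * num_groups, 256, 0, stream>>>(...);
// T must be convertible to/from float (float, __half, ...).
template <typename T>
__global__ void GroupNormKernel(const T* x, T* y, const T* scale, const T* bias,
                                int num_groups, int64_t channels_per_group,
                                int64_t spatial, float epsilon) {
  const int64_t group_size = channels_per_group * spatial;
  const int g = blockIdx.x % num_groups;  // group index within this instance
  const T* xg = x + static_cast<int64_t>(blockIdx.x) * group_size;
  T* yg = y + static_cast<int64_t>(blockIdx.x) * group_size;

  // Stage 1: block-wide float reduction of sum and sum-of-squares (stash_type = 1).
  __shared__ float s_sum[256];  // assumes blockDim.x == 256 (a power of two)
  __shared__ float s_sq[256];
  float sum = 0.f, sq = 0.f;
  for (int64_t i = threadIdx.x; i < group_size; i += blockDim.x) {
    const float v = static_cast<float>(xg[i]);
    sum += v;
    sq += v * v;
  }
  s_sum[threadIdx.x] = sum;
  s_sq[threadIdx.x] = sq;
  __syncthreads();
  for (unsigned stride = blockDim.x / 2; stride > 0; stride >>= 1) {
    if (threadIdx.x < stride) {
      s_sum[threadIdx.x] += s_sum[threadIdx.x + stride];
      s_sq[threadIdx.x] += s_sq[threadIdx.x + stride];
    }
    __syncthreads();
  }
  const float mean = s_sum[0] / group_size;
  const float inv_std = rsqrtf(s_sq[0] / group_size - mean * mean + epsilon);

  // Stage 2: normalize and apply the per-channel scale and bias.
  for (int64_t i = threadIdx.x; i < group_size; i += blockDim.x) {
    const int64_t ch = g * channels_per_group + i / spatial;  // global channel
    const float v = static_cast<float>(xg[i]);
    yg[i] = static_cast<T>(static_cast<float>(scale[ch]) * (v - mean) * inv_std +
                           static_cast<float>(bias[ch]));
  }
}
```

A production kernel would typically vectorize the loads and use warp shuffles for the reduction, but the overall structure would be the same.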

Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com>
Copilot finished work on behalf of justinchuby June 14, 2025 22:24
@github-actions bot left a comment


You can commit the suggested changes from lintrunner.

Comment on lines +12 to +16

// Opset 18-20 registrations (without stash_type)
#define REGISTER_ONNX_KERNEL_TYPED_VERSIONED(T) \
ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(GroupNormalization, 18, 20, T, \
KernelDefBuilder() \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \

Suggested change from lintrunner: a clang-format whitespace fix; the macro is unchanged apart from alignment.
Comment on lines +19 to +24

// Opset 21+ registrations (with stash_type)
#define REGISTER_ONNX_KERNEL_TYPED_21(T) \
ONNX_CPU_OPERATOR_TYPED_KERNEL(GroupNormalization, 21, T, \
KernelDefBuilder() \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
GroupNorm);

Suggested change from lintrunner: a clang-format whitespace fix; the macro is unchanged apart from alignment.
Comment on lines +38 to +40

ORT_ENFORCE(op_kernel_info.GetAttr("num_groups", &num_groups_).IsOK());

// stash_type is optional in opset 21, default to 1 (float32)

Suggested change from lintrunner: drop the blank line:

ORT_ENFORCE(op_kernel_info.GetAttr("num_groups", &num_groups_).IsOK());
// stash_type is optional in opset 21, default to 1 (float32)
Comment on lines +54 to +56

template<typename T>
Status GroupNorm::ComputeImpl(OpKernelContext* context, const Tensor* X, const Tensor* scale, const Tensor* bias) const {

Suggested change from lintrunner: add a space after template:

template <typename T>
Status GroupNorm::ComputeImpl(OpKernelContext* context, const Tensor* X, const Tensor* scale, const Tensor* bias) const {
Comment on lines +59 to +66

const int64_t C = x_shape[1]; // channels

// Validate that channels are divisible by num_groups
ORT_RETURN_IF_NOT(C % num_groups_ == 0, "Number of channels must be divisible by num_groups");

const int64_t channels_per_group = C / num_groups_;

// Calculate spatial dimensions (H*W*... for everything after batch and channel dims)

Suggested change from lintrunner: drop the blank lines:

const int64_t C = x_shape[1]; // channels
// Validate that channels are divisible by num_groups
ORT_RETURN_IF_NOT(C % num_groups_ == 0, "Number of channels must be divisible by num_groups");
const int64_t channels_per_group = C / num_groups_;
// Calculate spatial dimensions (H*W*... for everything after batch and channel dims)
Comment on lines +19 to +24

// Opset 21+ registrations (with stash_type)
#define REGISTER_CUDA_KERNEL_TYPED_21(T) \
ONNX_OPERATOR_TYPED_KERNEL_EX(GroupNormalization, kOnnxDomain, 21, T, kCudaExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
GroupNorm<T, float>);

Suggested change from lintrunner: a clang-format whitespace fix; the macro is unchanged apart from alignment.
Comment on lines +40 to +44

epsilon_ = tmp_epsilon;

ORT_ENFORCE(op_kernel_info.GetAttr("num_groups", &num_groups_).IsOK());

// stash_type is optional in opset 21, default to 1 (float32)

Suggested change from lintrunner: drop the blank lines:

epsilon_ = tmp_epsilon;
ORT_ENFORCE(op_kernel_info.GetAttr("num_groups", &num_groups_).IsOK());
// stash_type is optional in opset 21, default to 1 (float32)
Comment on lines +52 to +54

typedef typename ToCudaType<T>::MappedType CudaT;

// Inputs

Suggested change from lintrunner: drop the blank line:

typedef typename ToCudaType<T>::MappedType CudaT;
// Inputs
Comment on lines +65 to +70

const int64_t C = x_shape[1]; // channels

// Validate that channels are divisible by num_groups
ORT_RETURN_IF_NOT(C % num_groups_ == 0, "Number of channels must be divisible by num_groups");

// Calculate spatial dimensions (H*W*... for everything after batch and channel dims)

Suggested change from lintrunner: drop the blank lines:

const int64_t C = x_shape[1]; // channels
// Validate that channels are divisible by num_groups
ORT_RETURN_IF_NOT(C % num_groups_ == 0, "Number of channels must be divisible by num_groups");
// Calculate spatial dimensions (H*W*... for everything after batch and channel dims)
Comment on lines +74 to +76

}

Tensor* Y = ctx->Output(0, x_shape);

Suggested change from lintrunner: drop the blank line:

}
Tensor* Y = ctx->Output(0, x_shape);
@@ -0,0 +1,154 @@
// Copyright (c) Microsoft Corporation. All rights reserved.

Check warning (Code scanning / lintrunner): CLANGFORMAT/format. See https://clang.llvm.org/docs/ClangFormat.html. Run lintrunner -a to apply this patch.

@@ -0,0 +1,28 @@
// Copyright (c) Microsoft Corporation. All rights reserved.

Check warning (Code scanning / lintrunner): CLANGFORMAT/format. See https://clang.llvm.org/docs/ClangFormat.html. Run lintrunner -a to apply this patch.

@@ -0,0 +1,94 @@
// Copyright (c) Microsoft Corporation. All rights reserved.

Check warning (Code scanning / lintrunner): CLANGFORMAT/format. See https://clang.llvm.org/docs/ClangFormat.html. Run lintrunner -a to apply this patch.
Successfully merging this pull request may close these issues:

Feature request: Implement GroupNormalization-21