Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CudaSPIRV] Allow using integral non-type template parameters as attribute args #131546

Merged

Conversation

alexander-shaposhnikov
Copy link
Collaborator

Allow using integral non-type template parameters as attribute arguments of
reqd_work_group_size and work_group_size_hint.

Test plan:
ninja check-all

@llvmbot llvmbot added clang Clang issues not falling into any other category clang-tools-extra backend:AMDGPU clang-tidy clang:frontend Language frontend issues, e.g. anything involving "Sema" clang:codegen IR generation bugs: mangling, exceptions, etc. labels Mar 17, 2025
@llvmbot
Copy link
Member

llvmbot commented Mar 17, 2025

@llvm/pr-subscribers-clang-tidy
@llvm/pr-subscribers-backend-amdgpu

@llvm/pr-subscribers-clang

Author: Alexander Shaposhnikov (alexander-shaposhnikov)

Changes

Allow using integral non-type template parameters as attribute arguments of
reqd_work_group_size and work_group_size_hint.

Test plan:
ninja check-all


Full diff: https://github.com/llvm/llvm-project/pull/131546.diff

9 Files Affected:

  • (modified) clang-tools-extra/clang-tidy/altera/SingleWorkItemBarrierCheck.cpp (+6-2)
  • (modified) clang/include/clang/Basic/Attr.td (+2-5)
  • (modified) clang/lib/CodeGen/CodeGenFunction.cpp (+12-6)
  • (modified) clang/lib/CodeGen/Targets/AMDGPU.cpp (+11-3)
  • (modified) clang/lib/CodeGen/Targets/TCE.cpp (+9-9)
  • (modified) clang/lib/Sema/SemaDeclAttr.cpp (+70-9)
  • (modified) clang/lib/Sema/SemaTemplateInstantiateDecl.cpp (+32)
  • (added) clang/test/SemaCUDA/spirv-attrs-diag.cu (+38)
  • (modified) clang/test/SemaCUDA/spirv-attrs.cu (+16)
diff --git a/clang-tools-extra/clang-tidy/altera/SingleWorkItemBarrierCheck.cpp b/clang-tools-extra/clang-tidy/altera/SingleWorkItemBarrierCheck.cpp
index df21c425ea956..c5da66a1f28b6 100644
--- a/clang-tools-extra/clang-tidy/altera/SingleWorkItemBarrierCheck.cpp
+++ b/clang-tools-extra/clang-tidy/altera/SingleWorkItemBarrierCheck.cpp
@@ -54,8 +54,12 @@ void SingleWorkItemBarrierCheck::check(const MatchFinder::MatchResult &Result) {
     bool IsNDRange = false;
     if (MatchedDecl->hasAttr<ReqdWorkGroupSizeAttr>()) {
       const auto *Attribute = MatchedDecl->getAttr<ReqdWorkGroupSizeAttr>();
-      if (Attribute->getXDim() > 1 || Attribute->getYDim() > 1 ||
-          Attribute->getZDim() > 1)
+      auto Eval = [&](Expr *E) {
+        return E->EvaluateKnownConstInt(MatchedDecl->getASTContext())
+            .getExtValue();
+      };
+      if (Eval(Attribute->getXDim()) > 1 || Eval(Attribute->getYDim()) > 1 ||
+          Eval(Attribute->getZDim()) > 1)
         IsNDRange = true;
     }
     if (IsNDRange) // No warning if kernel is treated as an NDRange.
diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td
index 4d34346460561..cceb4085d523d 100644
--- a/clang/include/clang/Basic/Attr.td
+++ b/clang/include/clang/Basic/Attr.td
@@ -3044,8 +3044,7 @@ def NoDeref : TypeAttr {
 def ReqdWorkGroupSize : InheritableAttr {
   // Does not have a [[]] spelling because it is an OpenCL-related attribute.
   let Spellings = [GNU<"reqd_work_group_size">];
-  let Args = [UnsignedArgument<"XDim">, UnsignedArgument<"YDim">,
-              UnsignedArgument<"ZDim">];
+  let Args = [ExprArgument<"XDim">, ExprArgument<"YDim">, ExprArgument<"ZDim">];
   let Subjects = SubjectList<[Function], ErrorDiag>;
   let Documentation = [Undocumented];
 }
@@ -3053,9 +3052,7 @@ def ReqdWorkGroupSize : InheritableAttr {
 def WorkGroupSizeHint :  InheritableAttr {
   // Does not have a [[]] spelling because it is an OpenCL-related attribute.
   let Spellings = [GNU<"work_group_size_hint">];
-  let Args = [UnsignedArgument<"XDim">,
-              UnsignedArgument<"YDim">,
-              UnsignedArgument<"ZDim">];
+  let Args = [ExprArgument<"XDim">, ExprArgument<"YDim">, ExprArgument<"ZDim">];
   let Subjects = SubjectList<[Function], ErrorDiag>;
   let Documentation = [Undocumented];
 }
diff --git a/clang/lib/CodeGen/CodeGenFunction.cpp b/clang/lib/CodeGen/CodeGenFunction.cpp
index 447192bc7f60c..de1e04a7982fa 100644
--- a/clang/lib/CodeGen/CodeGenFunction.cpp
+++ b/clang/lib/CodeGen/CodeGenFunction.cpp
@@ -649,18 +649,24 @@ void CodeGenFunction::EmitKernelMetadata(const FunctionDecl *FD,
   }
 
   if (const WorkGroupSizeHintAttr *A = FD->getAttr<WorkGroupSizeHintAttr>()) {
+    auto Eval = [&](Expr *E) {
+      return E->EvaluateKnownConstInt(FD->getASTContext()).getExtValue();
+    };
     llvm::Metadata *AttrMDArgs[] = {
-        llvm::ConstantAsMetadata::get(Builder.getInt32(A->getXDim())),
-        llvm::ConstantAsMetadata::get(Builder.getInt32(A->getYDim())),
-        llvm::ConstantAsMetadata::get(Builder.getInt32(A->getZDim()))};
+        llvm::ConstantAsMetadata::get(Builder.getInt32(Eval(A->getXDim()))),
+        llvm::ConstantAsMetadata::get(Builder.getInt32(Eval(A->getYDim()))),
+        llvm::ConstantAsMetadata::get(Builder.getInt32(Eval(A->getZDim())))};
     Fn->setMetadata("work_group_size_hint", llvm::MDNode::get(Context, AttrMDArgs));
   }
 
   if (const ReqdWorkGroupSizeAttr *A = FD->getAttr<ReqdWorkGroupSizeAttr>()) {
+    auto Eval = [&](Expr *E) {
+      return E->EvaluateKnownConstInt(FD->getASTContext()).getExtValue();
+    };
     llvm::Metadata *AttrMDArgs[] = {
-        llvm::ConstantAsMetadata::get(Builder.getInt32(A->getXDim())),
-        llvm::ConstantAsMetadata::get(Builder.getInt32(A->getYDim())),
-        llvm::ConstantAsMetadata::get(Builder.getInt32(A->getZDim()))};
+        llvm::ConstantAsMetadata::get(Builder.getInt32(Eval(A->getXDim()))),
+        llvm::ConstantAsMetadata::get(Builder.getInt32(Eval(A->getYDim()))),
+        llvm::ConstantAsMetadata::get(Builder.getInt32(Eval(A->getZDim())))};
     Fn->setMetadata("reqd_work_group_size", llvm::MDNode::get(Context, AttrMDArgs));
   }
 
diff --git a/clang/lib/CodeGen/Targets/AMDGPU.cpp b/clang/lib/CodeGen/Targets/AMDGPU.cpp
index 7d84310ba0bf5..a14f599330644 100644
--- a/clang/lib/CodeGen/Targets/AMDGPU.cpp
+++ b/clang/lib/CodeGen/Targets/AMDGPU.cpp
@@ -753,12 +753,20 @@ void CodeGenModule::handleAMDGPUFlatWorkGroupSizeAttr(
     int32_t *MaxThreadsVal) {
   unsigned Min = 0;
   unsigned Max = 0;
+  auto Eval = [&](Expr *E) {
+    return E->EvaluateKnownConstInt(getContext()).getExtValue();
+  };
   if (FlatWGS) {
-    Min = FlatWGS->getMin()->EvaluateKnownConstInt(getContext()).getExtValue();
-    Max = FlatWGS->getMax()->EvaluateKnownConstInt(getContext()).getExtValue();
+    Min = Eval(
+        FlatWGS
+            ->getMin()); //->EvaluateKnownConstInt(getContext()).getExtValue();
+    Max = Eval(
+        FlatWGS
+            ->getMax()); //->EvaluateKnownConstInt(getContext()).getExtValue();
   }
   if (ReqdWGS && Min == 0 && Max == 0)
-    Min = Max = ReqdWGS->getXDim() * ReqdWGS->getYDim() * ReqdWGS->getZDim();
+    Min = Max = Eval(ReqdWGS->getXDim()) * Eval(ReqdWGS->getYDim()) *
+                Eval(ReqdWGS->getZDim());
 
   if (Min != 0) {
     assert(Min <= Max && "Min must be less than or equal Max");
diff --git a/clang/lib/CodeGen/Targets/TCE.cpp b/clang/lib/CodeGen/Targets/TCE.cpp
index d7178b4b8a949..bfce262ea965c 100644
--- a/clang/lib/CodeGen/Targets/TCE.cpp
+++ b/clang/lib/CodeGen/Targets/TCE.cpp
@@ -53,15 +53,15 @@ void TCETargetCodeGenInfo::setTargetAttributes(
         SmallVector<llvm::Metadata *, 5> Operands;
         Operands.push_back(llvm::ConstantAsMetadata::get(F));
 
-        Operands.push_back(
-            llvm::ConstantAsMetadata::get(llvm::Constant::getIntegerValue(
-                M.Int32Ty, llvm::APInt(32, Attr->getXDim()))));
-        Operands.push_back(
-            llvm::ConstantAsMetadata::get(llvm::Constant::getIntegerValue(
-                M.Int32Ty, llvm::APInt(32, Attr->getYDim()))));
-        Operands.push_back(
-            llvm::ConstantAsMetadata::get(llvm::Constant::getIntegerValue(
-                M.Int32Ty, llvm::APInt(32, Attr->getZDim()))));
+        auto Eval = [&](Expr *E) {
+          return E->EvaluateKnownConstInt(FD->getASTContext());
+        };
+        Operands.push_back(llvm::ConstantAsMetadata::get(
+            llvm::Constant::getIntegerValue(M.Int32Ty, Eval(Attr->getXDim()))));
+        Operands.push_back(llvm::ConstantAsMetadata::get(
+            llvm::Constant::getIntegerValue(M.Int32Ty, Eval(Attr->getYDim()))));
+        Operands.push_back(llvm::ConstantAsMetadata::get(
+            llvm::Constant::getIntegerValue(M.Int32Ty, Eval(Attr->getZDim()))));
 
         // Add a boolean constant operand for "required" (true) or "hint"
         // (false) for implementing the work_group_size_hint attr later.
diff --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp
index 769f5316ed934..0a8a3e1c49414 100644
--- a/clang/lib/Sema/SemaDeclAttr.cpp
+++ b/clang/lib/Sema/SemaDeclAttr.cpp
@@ -2914,21 +2914,70 @@ static void handleWeakImportAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
   D->addAttr(::new (S.Context) WeakImportAttr(S.Context, AL));
 }
 
+// Checks whether an argument of launch_bounds-like attribute is
+// acceptable, performs implicit conversion to Rvalue, and returns
+// non-nullptr Expr result on success. Otherwise, it returns nullptr
+// and may output an error.
+template <class Attribute>
+static Expr *makeAttributeArgExpr(Sema &S, Expr *E, const Attribute &Attr,
+                                  const unsigned Idx) {
+  if (S.DiagnoseUnexpandedParameterPack(E))
+    return nullptr;
+
+  // Accept template arguments for now as they depend on something else.
+  // We'll get to check them when they eventually get instantiated.
+  if (E->isValueDependent())
+    return E;
+
+  std::optional<llvm::APSInt> I = llvm::APSInt(64);
+  if (!(I = E->getIntegerConstantExpr(S.Context))) {
+    S.Diag(E->getExprLoc(), diag::err_attribute_argument_n_type)
+        << &Attr << Idx << AANT_ArgumentIntegerConstant << E->getSourceRange();
+    return nullptr;
+  }
+  // Make sure we can fit it in 32 bits.
+  if (!I->isIntN(32)) {
+    S.Diag(E->getExprLoc(), diag::err_ice_too_large)
+        << toString(*I, 10, false) << 32 << /* Unsigned */ 1;
+    return nullptr;
+  }
+  if (*I < 0)
+    S.Diag(E->getExprLoc(), diag::err_attribute_requires_positive_integer)
+        << &Attr << /*non-negative*/ 1 << E->getSourceRange();
+
+  // We may need to perform implicit conversion of the argument.
+  InitializedEntity Entity = InitializedEntity::InitializeParameter(
+      S.Context, S.Context.getConstType(S.Context.IntTy), /*consume*/ false);
+  ExprResult ValArg = S.PerformCopyInitialization(Entity, SourceLocation(), E);
+  assert(!ValArg.isInvalid() &&
+         "Unexpected PerformCopyInitialization() failure.");
+
+  return ValArg.getAs<Expr>();
+}
+
 // Handles reqd_work_group_size and work_group_size_hint.
 template <typename WorkGroupAttr>
 static void handleWorkGroupSize(Sema &S, Decl *D, const ParsedAttr &AL) {
-  uint32_t WGSize[3];
+  Expr *WGSize[3];
   for (unsigned i = 0; i < 3; ++i) {
-    const Expr *E = AL.getArgAsExpr(i);
-    if (!S.checkUInt32Argument(AL, E, WGSize[i], i,
-                               /*StrictlyUnsigned=*/true))
+    if (Expr *E = makeAttributeArgExpr(S, AL.getArgAsExpr(i), AL, i))
+      WGSize[i] = E;
+    else
       return;
   }
 
-  if (!llvm::all_of(WGSize, [](uint32_t Size) { return Size == 0; })) {
+  auto IsZero = [&](Expr *E) {
+    if (E->isValueDependent())
+      return false;
+    std::optional<llvm::APSInt> I = E->getIntegerConstantExpr(S.Context);
+    assert(I && "Non-integer constant expr");
+    return I->isZero();
+  };
+
+  if (!llvm::all_of(WGSize, IsZero)) {
     for (unsigned i = 0; i < 3; ++i) {
       const Expr *E = AL.getArgAsExpr(i);
-      if (WGSize[i] == 0) {
+      if (IsZero(WGSize[i])) {
         S.Diag(AL.getLoc(), diag::err_attribute_argument_is_zero)
             << AL << E->getSourceRange();
         return;
@@ -2936,10 +2985,22 @@ static void handleWorkGroupSize(Sema &S, Decl *D, const ParsedAttr &AL) {
     }
   }
 
+  auto Equal = [&](Expr *LHS, Expr *RHS) {
+    if (LHS->isValueDependent() || RHS->isValueDependent())
+      return true;
+    std::optional<llvm::APSInt> L = LHS->getIntegerConstantExpr(S.Context);
+    assert(L && "Non-integer constant expr");
+    std::optional<llvm::APSInt> R = RHS->getIntegerConstantExpr(S.Context);
+    assert(L && "Non-integer constant expr");
+    return L == R;
+  };
+
   WorkGroupAttr *Existing = D->getAttr<WorkGroupAttr>();
-  if (Existing && !(Existing->getXDim() == WGSize[0] &&
-                    Existing->getYDim() == WGSize[1] &&
-                    Existing->getZDim() == WGSize[2]))
+  if (Existing &&
+      !llvm::equal(std::initializer_list<Expr *>{Existing->getXDim(),
+                                                 Existing->getYDim(),
+                                                 Existing->getZDim()},
+                   WGSize, Equal))
     S.Diag(AL.getLoc(), diag::warn_duplicate_attribute) << AL;
 
   D->addAttr(::new (S.Context)
diff --git a/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp b/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp
index 170c0b0d39f86..d5330430a3b73 100644
--- a/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp
+++ b/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp
@@ -572,6 +572,32 @@ static void instantiateDependentAMDGPUFlatWorkGroupSizeAttr(
   S.AMDGPU().addAMDGPUFlatWorkGroupSizeAttr(New, Attr, MinExpr, MaxExpr);
 }
 
+static void instantiateDependentReqdWorkGroupSizeAttr(
+    Sema &S, const MultiLevelTemplateArgumentList &TemplateArgs,
+    const ReqdWorkGroupSizeAttr &Attr, Decl *New) {
+  // Both min and max expression are constant expressions.
+  EnterExpressionEvaluationContext Unevaluated(
+      S, Sema::ExpressionEvaluationContext::ConstantEvaluated);
+
+  ExprResult Result = S.SubstExpr(Attr.getXDim(), TemplateArgs);
+  if (Result.isInvalid())
+    return;
+  Expr *X = Result.getAs<Expr>();
+
+  Result = S.SubstExpr(Attr.getYDim(), TemplateArgs);
+  if (Result.isInvalid())
+    return;
+  Expr *Y = Result.getAs<Expr>();
+
+  Result = S.SubstExpr(Attr.getZDim(), TemplateArgs);
+  if (Result.isInvalid())
+    return;
+  Expr *Z = Result.getAs<Expr>();
+
+  ASTContext &Context = S.getASTContext();
+  New->addAttr(::new (Context) ReqdWorkGroupSizeAttr(Context, Attr, X, Y, Z));
+}
+
 ExplicitSpecifier Sema::instantiateExplicitSpecifier(
     const MultiLevelTemplateArgumentList &TemplateArgs, ExplicitSpecifier ES) {
   if (!ES.getExpr())
@@ -812,6 +838,12 @@ void Sema::InstantiateAttrs(const MultiLevelTemplateArgumentList &TemplateArgs,
       continue;
     }
 
+    if (const auto *ReqdWorkGroupSize =
+            dyn_cast<ReqdWorkGroupSizeAttr>(TmplAttr)) {
+      instantiateDependentReqdWorkGroupSizeAttr(*this, TemplateArgs,
+                                                *ReqdWorkGroupSize, New);
+    }
+
     if (const auto *AMDGPUFlatWorkGroupSize =
             dyn_cast<AMDGPUFlatWorkGroupSizeAttr>(TmplAttr)) {
       instantiateDependentAMDGPUFlatWorkGroupSizeAttr(
diff --git a/clang/test/SemaCUDA/spirv-attrs-diag.cu b/clang/test/SemaCUDA/spirv-attrs-diag.cu
new file mode 100644
index 0000000000000..8eae45623f5c5
--- /dev/null
+++ b/clang/test/SemaCUDA/spirv-attrs-diag.cu
@@ -0,0 +1,38 @@
+// RUN: %clang_cc1 -triple spirv64 -aux-triple x86_64-unknown-linux-gnu \
+// RUN:   -fcuda-is-device -verify -fsyntax-only %s
+
+#include "Inputs/cuda.h"
+
+__attribute__((reqd_work_group_size(0x100000000, 1, 1))) // expected-error {{integer constant expression evaluates to value 4294967296 that cannot be represented in a 32-bit unsigned integer type}}
+__global__ void TestTooBigArg1(void);
+
+__attribute__((work_group_size_hint(0x100000000, 1, 1))) // expected-error {{integer constant expression evaluates to value 4294967296 that cannot be represented in a 32-bit unsigned integer type}}
+__global__ void TestTooBigArg2(void);
+
+template <int... Args>
+__attribute__((reqd_work_group_size(Args))) // expected-error {{expression contains unexpanded parameter pack 'Args'}}
+__global__ void TestTemplateVariadicArgs1(void) {}
+
+template <int... Args>
+__attribute__((work_group_size_hint(Args))) // expected-error {{expression contains unexpanded parameter pack 'Args'}}
+__global__ void TestTemplateVariadicArgs2(void) {}
+
+template <class a> // expected-note {{declared here}}
+__attribute__((reqd_work_group_size(a, 1, 1))) // expected-error {{'a' does not refer to a value}}
+__global__ void TestTemplateArgClass1(void) {}
+
+template <class a> // expected-note {{declared here}}
+__attribute__((work_group_size_hint(a, 1, 1))) // expected-error {{'a' does not refer to a value}}
+__global__ void TestTemplateArgClass2(void) {}
+
+constexpr int A = 512;
+
+__attribute__((reqd_work_group_size(A, A, A)))
+__global__ void TestConstIntArg1(void) {}
+
+__attribute__((work_group_size_hint(A, A, A)))
+__global__ void TestConstIntArg2(void) {}
+
+
+
+
diff --git a/clang/test/SemaCUDA/spirv-attrs.cu b/clang/test/SemaCUDA/spirv-attrs.cu
index 6539421423ee1..fd944e665f169 100644
--- a/clang/test/SemaCUDA/spirv-attrs.cu
+++ b/clang/test/SemaCUDA/spirv-attrs.cu
@@ -8,11 +8,27 @@
 __attribute__((reqd_work_group_size(128, 1, 1)))
 __global__ void reqd_work_group_size_128_1_1() {}
 
+template <unsigned a, unsigned b, unsigned c>
+__attribute__((reqd_work_group_size(a, b, c)))
+__global__ void reqd_work_group_size_a_b_c() {}
+
+template <>
+__global__ void reqd_work_group_size_a_b_c<128,1,1>(void);
+
 __attribute__((work_group_size_hint(2, 2, 2)))
 __global__ void work_group_size_hint_2_2_2() {}
 
+template <unsigned a, unsigned b, unsigned c>
+__attribute__((work_group_size_hint(a, b, c)))
+__global__ void work_group_size_hint_a_b_c() {}
+
+template <>
+__global__ void work_group_size_hint_a_b_c<128,1,1>(void);
+
 __attribute__((vec_type_hint(int)))
 __global__ void vec_type_hint_int() {}
 
 __attribute__((intel_reqd_sub_group_size(64)))
 __global__ void intel_reqd_sub_group_size_64() {}
+
+

@llvmbot
Copy link
Member

llvmbot commented Mar 17, 2025

@llvm/pr-subscribers-clang-tools-extra

Author: Alexander Shaposhnikov (alexander-shaposhnikov)

Changes

Allow using integral non-type template parameters as attribute arguments of
reqd_work_group_size and work_group_size_hint.

Test plan:
ninja check-all


Full diff: https://github.com/llvm/llvm-project/pull/131546.diff

9 Files Affected:

  • (modified) clang-tools-extra/clang-tidy/altera/SingleWorkItemBarrierCheck.cpp (+6-2)
  • (modified) clang/include/clang/Basic/Attr.td (+2-5)
  • (modified) clang/lib/CodeGen/CodeGenFunction.cpp (+12-6)
  • (modified) clang/lib/CodeGen/Targets/AMDGPU.cpp (+11-3)
  • (modified) clang/lib/CodeGen/Targets/TCE.cpp (+9-9)
  • (modified) clang/lib/Sema/SemaDeclAttr.cpp (+70-9)
  • (modified) clang/lib/Sema/SemaTemplateInstantiateDecl.cpp (+32)
  • (added) clang/test/SemaCUDA/spirv-attrs-diag.cu (+38)
  • (modified) clang/test/SemaCUDA/spirv-attrs.cu (+16)
diff --git a/clang-tools-extra/clang-tidy/altera/SingleWorkItemBarrierCheck.cpp b/clang-tools-extra/clang-tidy/altera/SingleWorkItemBarrierCheck.cpp
index df21c425ea956..c5da66a1f28b6 100644
--- a/clang-tools-extra/clang-tidy/altera/SingleWorkItemBarrierCheck.cpp
+++ b/clang-tools-extra/clang-tidy/altera/SingleWorkItemBarrierCheck.cpp
@@ -54,8 +54,12 @@ void SingleWorkItemBarrierCheck::check(const MatchFinder::MatchResult &Result) {
     bool IsNDRange = false;
     if (MatchedDecl->hasAttr<ReqdWorkGroupSizeAttr>()) {
       const auto *Attribute = MatchedDecl->getAttr<ReqdWorkGroupSizeAttr>();
-      if (Attribute->getXDim() > 1 || Attribute->getYDim() > 1 ||
-          Attribute->getZDim() > 1)
+      auto Eval = [&](Expr *E) {
+        return E->EvaluateKnownConstInt(MatchedDecl->getASTContext())
+            .getExtValue();
+      };
+      if (Eval(Attribute->getXDim()) > 1 || Eval(Attribute->getYDim()) > 1 ||
+          Eval(Attribute->getZDim()) > 1)
         IsNDRange = true;
     }
     if (IsNDRange) // No warning if kernel is treated as an NDRange.
diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td
index 4d34346460561..cceb4085d523d 100644
--- a/clang/include/clang/Basic/Attr.td
+++ b/clang/include/clang/Basic/Attr.td
@@ -3044,8 +3044,7 @@ def NoDeref : TypeAttr {
 def ReqdWorkGroupSize : InheritableAttr {
   // Does not have a [[]] spelling because it is an OpenCL-related attribute.
   let Spellings = [GNU<"reqd_work_group_size">];
-  let Args = [UnsignedArgument<"XDim">, UnsignedArgument<"YDim">,
-              UnsignedArgument<"ZDim">];
+  let Args = [ExprArgument<"XDim">, ExprArgument<"YDim">, ExprArgument<"ZDim">];
   let Subjects = SubjectList<[Function], ErrorDiag>;
   let Documentation = [Undocumented];
 }
@@ -3053,9 +3052,7 @@ def ReqdWorkGroupSize : InheritableAttr {
 def WorkGroupSizeHint :  InheritableAttr {
   // Does not have a [[]] spelling because it is an OpenCL-related attribute.
   let Spellings = [GNU<"work_group_size_hint">];
-  let Args = [UnsignedArgument<"XDim">,
-              UnsignedArgument<"YDim">,
-              UnsignedArgument<"ZDim">];
+  let Args = [ExprArgument<"XDim">, ExprArgument<"YDim">, ExprArgument<"ZDim">];
   let Subjects = SubjectList<[Function], ErrorDiag>;
   let Documentation = [Undocumented];
 }
diff --git a/clang/lib/CodeGen/CodeGenFunction.cpp b/clang/lib/CodeGen/CodeGenFunction.cpp
index 447192bc7f60c..de1e04a7982fa 100644
--- a/clang/lib/CodeGen/CodeGenFunction.cpp
+++ b/clang/lib/CodeGen/CodeGenFunction.cpp
@@ -649,18 +649,24 @@ void CodeGenFunction::EmitKernelMetadata(const FunctionDecl *FD,
   }
 
   if (const WorkGroupSizeHintAttr *A = FD->getAttr<WorkGroupSizeHintAttr>()) {
+    auto Eval = [&](Expr *E) {
+      return E->EvaluateKnownConstInt(FD->getASTContext()).getExtValue();
+    };
     llvm::Metadata *AttrMDArgs[] = {
-        llvm::ConstantAsMetadata::get(Builder.getInt32(A->getXDim())),
-        llvm::ConstantAsMetadata::get(Builder.getInt32(A->getYDim())),
-        llvm::ConstantAsMetadata::get(Builder.getInt32(A->getZDim()))};
+        llvm::ConstantAsMetadata::get(Builder.getInt32(Eval(A->getXDim()))),
+        llvm::ConstantAsMetadata::get(Builder.getInt32(Eval(A->getYDim()))),
+        llvm::ConstantAsMetadata::get(Builder.getInt32(Eval(A->getZDim())))};
     Fn->setMetadata("work_group_size_hint", llvm::MDNode::get(Context, AttrMDArgs));
   }
 
   if (const ReqdWorkGroupSizeAttr *A = FD->getAttr<ReqdWorkGroupSizeAttr>()) {
+    auto Eval = [&](Expr *E) {
+      return E->EvaluateKnownConstInt(FD->getASTContext()).getExtValue();
+    };
     llvm::Metadata *AttrMDArgs[] = {
-        llvm::ConstantAsMetadata::get(Builder.getInt32(A->getXDim())),
-        llvm::ConstantAsMetadata::get(Builder.getInt32(A->getYDim())),
-        llvm::ConstantAsMetadata::get(Builder.getInt32(A->getZDim()))};
+        llvm::ConstantAsMetadata::get(Builder.getInt32(Eval(A->getXDim()))),
+        llvm::ConstantAsMetadata::get(Builder.getInt32(Eval(A->getYDim()))),
+        llvm::ConstantAsMetadata::get(Builder.getInt32(Eval(A->getZDim())))};
     Fn->setMetadata("reqd_work_group_size", llvm::MDNode::get(Context, AttrMDArgs));
   }
 
diff --git a/clang/lib/CodeGen/Targets/AMDGPU.cpp b/clang/lib/CodeGen/Targets/AMDGPU.cpp
index 7d84310ba0bf5..a14f599330644 100644
--- a/clang/lib/CodeGen/Targets/AMDGPU.cpp
+++ b/clang/lib/CodeGen/Targets/AMDGPU.cpp
@@ -753,12 +753,20 @@ void CodeGenModule::handleAMDGPUFlatWorkGroupSizeAttr(
     int32_t *MaxThreadsVal) {
   unsigned Min = 0;
   unsigned Max = 0;
+  auto Eval = [&](Expr *E) {
+    return E->EvaluateKnownConstInt(getContext()).getExtValue();
+  };
   if (FlatWGS) {
-    Min = FlatWGS->getMin()->EvaluateKnownConstInt(getContext()).getExtValue();
-    Max = FlatWGS->getMax()->EvaluateKnownConstInt(getContext()).getExtValue();
+    Min = Eval(
+        FlatWGS
+            ->getMin()); //->EvaluateKnownConstInt(getContext()).getExtValue();
+    Max = Eval(
+        FlatWGS
+            ->getMax()); //->EvaluateKnownConstInt(getContext()).getExtValue();
   }
   if (ReqdWGS && Min == 0 && Max == 0)
-    Min = Max = ReqdWGS->getXDim() * ReqdWGS->getYDim() * ReqdWGS->getZDim();
+    Min = Max = Eval(ReqdWGS->getXDim()) * Eval(ReqdWGS->getYDim()) *
+                Eval(ReqdWGS->getZDim());
 
   if (Min != 0) {
     assert(Min <= Max && "Min must be less than or equal Max");
diff --git a/clang/lib/CodeGen/Targets/TCE.cpp b/clang/lib/CodeGen/Targets/TCE.cpp
index d7178b4b8a949..bfce262ea965c 100644
--- a/clang/lib/CodeGen/Targets/TCE.cpp
+++ b/clang/lib/CodeGen/Targets/TCE.cpp
@@ -53,15 +53,15 @@ void TCETargetCodeGenInfo::setTargetAttributes(
         SmallVector<llvm::Metadata *, 5> Operands;
         Operands.push_back(llvm::ConstantAsMetadata::get(F));
 
-        Operands.push_back(
-            llvm::ConstantAsMetadata::get(llvm::Constant::getIntegerValue(
-                M.Int32Ty, llvm::APInt(32, Attr->getXDim()))));
-        Operands.push_back(
-            llvm::ConstantAsMetadata::get(llvm::Constant::getIntegerValue(
-                M.Int32Ty, llvm::APInt(32, Attr->getYDim()))));
-        Operands.push_back(
-            llvm::ConstantAsMetadata::get(llvm::Constant::getIntegerValue(
-                M.Int32Ty, llvm::APInt(32, Attr->getZDim()))));
+        auto Eval = [&](Expr *E) {
+          return E->EvaluateKnownConstInt(FD->getASTContext());
+        };
+        Operands.push_back(llvm::ConstantAsMetadata::get(
+            llvm::Constant::getIntegerValue(M.Int32Ty, Eval(Attr->getXDim()))));
+        Operands.push_back(llvm::ConstantAsMetadata::get(
+            llvm::Constant::getIntegerValue(M.Int32Ty, Eval(Attr->getYDim()))));
+        Operands.push_back(llvm::ConstantAsMetadata::get(
+            llvm::Constant::getIntegerValue(M.Int32Ty, Eval(Attr->getZDim()))));
 
         // Add a boolean constant operand for "required" (true) or "hint"
         // (false) for implementing the work_group_size_hint attr later.
diff --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp
index 769f5316ed934..0a8a3e1c49414 100644
--- a/clang/lib/Sema/SemaDeclAttr.cpp
+++ b/clang/lib/Sema/SemaDeclAttr.cpp
@@ -2914,21 +2914,70 @@ static void handleWeakImportAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
   D->addAttr(::new (S.Context) WeakImportAttr(S.Context, AL));
 }
 
+// Checks whether an argument of launch_bounds-like attribute is
+// acceptable, performs implicit conversion to Rvalue, and returns
+// non-nullptr Expr result on success. Otherwise, it returns nullptr
+// and may output an error.
+template <class Attribute>
+static Expr *makeAttributeArgExpr(Sema &S, Expr *E, const Attribute &Attr,
+                                  const unsigned Idx) {
+  if (S.DiagnoseUnexpandedParameterPack(E))
+    return nullptr;
+
+  // Accept template arguments for now as they depend on something else.
+  // We'll get to check them when they eventually get instantiated.
+  if (E->isValueDependent())
+    return E;
+
+  std::optional<llvm::APSInt> I = llvm::APSInt(64);
+  if (!(I = E->getIntegerConstantExpr(S.Context))) {
+    S.Diag(E->getExprLoc(), diag::err_attribute_argument_n_type)
+        << &Attr << Idx << AANT_ArgumentIntegerConstant << E->getSourceRange();
+    return nullptr;
+  }
+  // Make sure we can fit it in 32 bits.
+  if (!I->isIntN(32)) {
+    S.Diag(E->getExprLoc(), diag::err_ice_too_large)
+        << toString(*I, 10, false) << 32 << /* Unsigned */ 1;
+    return nullptr;
+  }
+  if (*I < 0)
+    S.Diag(E->getExprLoc(), diag::err_attribute_requires_positive_integer)
+        << &Attr << /*non-negative*/ 1 << E->getSourceRange();
+
+  // We may need to perform implicit conversion of the argument.
+  InitializedEntity Entity = InitializedEntity::InitializeParameter(
+      S.Context, S.Context.getConstType(S.Context.IntTy), /*consume*/ false);
+  ExprResult ValArg = S.PerformCopyInitialization(Entity, SourceLocation(), E);
+  assert(!ValArg.isInvalid() &&
+         "Unexpected PerformCopyInitialization() failure.");
+
+  return ValArg.getAs<Expr>();
+}
+
 // Handles reqd_work_group_size and work_group_size_hint.
 template <typename WorkGroupAttr>
 static void handleWorkGroupSize(Sema &S, Decl *D, const ParsedAttr &AL) {
-  uint32_t WGSize[3];
+  Expr *WGSize[3];
   for (unsigned i = 0; i < 3; ++i) {
-    const Expr *E = AL.getArgAsExpr(i);
-    if (!S.checkUInt32Argument(AL, E, WGSize[i], i,
-                               /*StrictlyUnsigned=*/true))
+    if (Expr *E = makeAttributeArgExpr(S, AL.getArgAsExpr(i), AL, i))
+      WGSize[i] = E;
+    else
       return;
   }
 
-  if (!llvm::all_of(WGSize, [](uint32_t Size) { return Size == 0; })) {
+  auto IsZero = [&](Expr *E) {
+    if (E->isValueDependent())
+      return false;
+    std::optional<llvm::APSInt> I = E->getIntegerConstantExpr(S.Context);
+    assert(I && "Non-integer constant expr");
+    return I->isZero();
+  };
+
+  if (!llvm::all_of(WGSize, IsZero)) {
     for (unsigned i = 0; i < 3; ++i) {
       const Expr *E = AL.getArgAsExpr(i);
-      if (WGSize[i] == 0) {
+      if (IsZero(WGSize[i])) {
         S.Diag(AL.getLoc(), diag::err_attribute_argument_is_zero)
             << AL << E->getSourceRange();
         return;
@@ -2936,10 +2985,22 @@ static void handleWorkGroupSize(Sema &S, Decl *D, const ParsedAttr &AL) {
     }
   }
 
+  auto Equal = [&](Expr *LHS, Expr *RHS) {
+    if (LHS->isValueDependent() || RHS->isValueDependent())
+      return true;
+    std::optional<llvm::APSInt> L = LHS->getIntegerConstantExpr(S.Context);
+    assert(L && "Non-integer constant expr");
+    std::optional<llvm::APSInt> R = RHS->getIntegerConstantExpr(S.Context);
+    assert(L && "Non-integer constant expr");
+    return L == R;
+  };
+
   WorkGroupAttr *Existing = D->getAttr<WorkGroupAttr>();
-  if (Existing && !(Existing->getXDim() == WGSize[0] &&
-                    Existing->getYDim() == WGSize[1] &&
-                    Existing->getZDim() == WGSize[2]))
+  if (Existing &&
+      !llvm::equal(std::initializer_list<Expr *>{Existing->getXDim(),
+                                                 Existing->getYDim(),
+                                                 Existing->getZDim()},
+                   WGSize, Equal))
     S.Diag(AL.getLoc(), diag::warn_duplicate_attribute) << AL;
 
   D->addAttr(::new (S.Context)
diff --git a/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp b/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp
index 170c0b0d39f86..d5330430a3b73 100644
--- a/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp
+++ b/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp
@@ -572,6 +572,32 @@ static void instantiateDependentAMDGPUFlatWorkGroupSizeAttr(
   S.AMDGPU().addAMDGPUFlatWorkGroupSizeAttr(New, Attr, MinExpr, MaxExpr);
 }
 
+static void instantiateDependentReqdWorkGroupSizeAttr(
+    Sema &S, const MultiLevelTemplateArgumentList &TemplateArgs,
+    const ReqdWorkGroupSizeAttr &Attr, Decl *New) {
+  // Both min and max expression are constant expressions.
+  EnterExpressionEvaluationContext Unevaluated(
+      S, Sema::ExpressionEvaluationContext::ConstantEvaluated);
+
+  ExprResult Result = S.SubstExpr(Attr.getXDim(), TemplateArgs);
+  if (Result.isInvalid())
+    return;
+  Expr *X = Result.getAs<Expr>();
+
+  Result = S.SubstExpr(Attr.getYDim(), TemplateArgs);
+  if (Result.isInvalid())
+    return;
+  Expr *Y = Result.getAs<Expr>();
+
+  Result = S.SubstExpr(Attr.getZDim(), TemplateArgs);
+  if (Result.isInvalid())
+    return;
+  Expr *Z = Result.getAs<Expr>();
+
+  ASTContext &Context = S.getASTContext();
+  New->addAttr(::new (Context) ReqdWorkGroupSizeAttr(Context, Attr, X, Y, Z));
+}
+
 ExplicitSpecifier Sema::instantiateExplicitSpecifier(
     const MultiLevelTemplateArgumentList &TemplateArgs, ExplicitSpecifier ES) {
   if (!ES.getExpr())
@@ -812,6 +838,12 @@ void Sema::InstantiateAttrs(const MultiLevelTemplateArgumentList &TemplateArgs,
       continue;
     }
 
+    if (const auto *ReqdWorkGroupSize =
+            dyn_cast<ReqdWorkGroupSizeAttr>(TmplAttr)) {
+      instantiateDependentReqdWorkGroupSizeAttr(*this, TemplateArgs,
+                                                *ReqdWorkGroupSize, New);
+    }
+
     if (const auto *AMDGPUFlatWorkGroupSize =
             dyn_cast<AMDGPUFlatWorkGroupSizeAttr>(TmplAttr)) {
       instantiateDependentAMDGPUFlatWorkGroupSizeAttr(
diff --git a/clang/test/SemaCUDA/spirv-attrs-diag.cu b/clang/test/SemaCUDA/spirv-attrs-diag.cu
new file mode 100644
index 0000000000000..8eae45623f5c5
--- /dev/null
+++ b/clang/test/SemaCUDA/spirv-attrs-diag.cu
@@ -0,0 +1,38 @@
+// RUN: %clang_cc1 -triple spirv64 -aux-triple x86_64-unknown-linux-gnu \
+// RUN:   -fcuda-is-device -verify -fsyntax-only %s
+
+#include "Inputs/cuda.h"
+
+__attribute__((reqd_work_group_size(0x100000000, 1, 1))) // expected-error {{integer constant expression evaluates to value 4294967296 that cannot be represented in a 32-bit unsigned integer type}}
+__global__ void TestTooBigArg1(void);
+
+__attribute__((work_group_size_hint(0x100000000, 1, 1))) // expected-error {{integer constant expression evaluates to value 4294967296 that cannot be represented in a 32-bit unsigned integer type}}
+__global__ void TestTooBigArg2(void);
+
+template <int... Args>
+__attribute__((reqd_work_group_size(Args))) // expected-error {{expression contains unexpanded parameter pack 'Args'}}
+__global__ void TestTemplateVariadicArgs1(void) {}
+
+template <int... Args>
+__attribute__((work_group_size_hint(Args))) // expected-error {{expression contains unexpanded parameter pack 'Args'}}
+__global__ void TestTemplateVariadicArgs2(void) {}
+
+template <class a> // expected-note {{declared here}}
+__attribute__((reqd_work_group_size(a, 1, 1))) // expected-error {{'a' does not refer to a value}}
+__global__ void TestTemplateArgClass1(void) {}
+
+template <class a> // expected-note {{declared here}}
+__attribute__((work_group_size_hint(a, 1, 1))) // expected-error {{'a' does not refer to a value}}
+__global__ void TestTemplateArgClass2(void) {}
+
+constexpr int A = 512;
+
+__attribute__((reqd_work_group_size(A, A, A)))
+__global__ void TestConstIntArg1(void) {}
+
+__attribute__((work_group_size_hint(A, A, A)))
+__global__ void TestConstIntArg2(void) {}
+
+
+
+
diff --git a/clang/test/SemaCUDA/spirv-attrs.cu b/clang/test/SemaCUDA/spirv-attrs.cu
index 6539421423ee1..fd944e665f169 100644
--- a/clang/test/SemaCUDA/spirv-attrs.cu
+++ b/clang/test/SemaCUDA/spirv-attrs.cu
@@ -8,11 +8,27 @@
 __attribute__((reqd_work_group_size(128, 1, 1)))
 __global__ void reqd_work_group_size_128_1_1() {}
 
+template <unsigned a, unsigned b, unsigned c>
+__attribute__((reqd_work_group_size(a, b, c)))
+__global__ void reqd_work_group_size_a_b_c() {}
+
+template <>
+__global__ void reqd_work_group_size_a_b_c<128,1,1>(void);
+
 __attribute__((work_group_size_hint(2, 2, 2)))
 __global__ void work_group_size_hint_2_2_2() {}
 
+template <unsigned a, unsigned b, unsigned c>
+__attribute__((work_group_size_hint(a, b, c)))
+__global__ void work_group_size_hint_a_b_c() {}
+
+template <>
+__global__ void work_group_size_hint_a_b_c<128,1,1>(void);
+
 __attribute__((vec_type_hint(int)))
 __global__ void vec_type_hint_int() {}
 
 __attribute__((intel_reqd_sub_group_size(64)))
 __global__ void intel_reqd_sub_group_size_64() {}
+
+

@llvmbot
Copy link
Member

llvmbot commented Mar 17, 2025

@llvm/pr-subscribers-clang-codegen

Author: Alexander Shaposhnikov (alexander-shaposhnikov)

Changes

Allow using integral non-type template parameters as attribute arguments of
reqd_work_group_size and work_group_size_hint.

Test plan:
ninja check-all


Full diff: https://github.com/llvm/llvm-project/pull/131546.diff

9 Files Affected:

  • (modified) clang-tools-extra/clang-tidy/altera/SingleWorkItemBarrierCheck.cpp (+6-2)
  • (modified) clang/include/clang/Basic/Attr.td (+2-5)
  • (modified) clang/lib/CodeGen/CodeGenFunction.cpp (+12-6)
  • (modified) clang/lib/CodeGen/Targets/AMDGPU.cpp (+11-3)
  • (modified) clang/lib/CodeGen/Targets/TCE.cpp (+9-9)
  • (modified) clang/lib/Sema/SemaDeclAttr.cpp (+70-9)
  • (modified) clang/lib/Sema/SemaTemplateInstantiateDecl.cpp (+32)
  • (added) clang/test/SemaCUDA/spirv-attrs-diag.cu (+38)
  • (modified) clang/test/SemaCUDA/spirv-attrs.cu (+16)
diff --git a/clang-tools-extra/clang-tidy/altera/SingleWorkItemBarrierCheck.cpp b/clang-tools-extra/clang-tidy/altera/SingleWorkItemBarrierCheck.cpp
index df21c425ea956..c5da66a1f28b6 100644
--- a/clang-tools-extra/clang-tidy/altera/SingleWorkItemBarrierCheck.cpp
+++ b/clang-tools-extra/clang-tidy/altera/SingleWorkItemBarrierCheck.cpp
@@ -54,8 +54,12 @@ void SingleWorkItemBarrierCheck::check(const MatchFinder::MatchResult &Result) {
     bool IsNDRange = false;
     if (MatchedDecl->hasAttr<ReqdWorkGroupSizeAttr>()) {
       const auto *Attribute = MatchedDecl->getAttr<ReqdWorkGroupSizeAttr>();
-      if (Attribute->getXDim() > 1 || Attribute->getYDim() > 1 ||
-          Attribute->getZDim() > 1)
+      auto Eval = [&](Expr *E) {
+        return E->EvaluateKnownConstInt(MatchedDecl->getASTContext())
+            .getExtValue();
+      };
+      if (Eval(Attribute->getXDim()) > 1 || Eval(Attribute->getYDim()) > 1 ||
+          Eval(Attribute->getZDim()) > 1)
         IsNDRange = true;
     }
     if (IsNDRange) // No warning if kernel is treated as an NDRange.
diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td
index 4d34346460561..cceb4085d523d 100644
--- a/clang/include/clang/Basic/Attr.td
+++ b/clang/include/clang/Basic/Attr.td
@@ -3044,8 +3044,7 @@ def NoDeref : TypeAttr {
 def ReqdWorkGroupSize : InheritableAttr {
   // Does not have a [[]] spelling because it is an OpenCL-related attribute.
   let Spellings = [GNU<"reqd_work_group_size">];
-  let Args = [UnsignedArgument<"XDim">, UnsignedArgument<"YDim">,
-              UnsignedArgument<"ZDim">];
+  let Args = [ExprArgument<"XDim">, ExprArgument<"YDim">, ExprArgument<"ZDim">];
   let Subjects = SubjectList<[Function], ErrorDiag>;
   let Documentation = [Undocumented];
 }
@@ -3053,9 +3052,7 @@ def ReqdWorkGroupSize : InheritableAttr {
 def WorkGroupSizeHint :  InheritableAttr {
   // Does not have a [[]] spelling because it is an OpenCL-related attribute.
   let Spellings = [GNU<"work_group_size_hint">];
-  let Args = [UnsignedArgument<"XDim">,
-              UnsignedArgument<"YDim">,
-              UnsignedArgument<"ZDim">];
+  let Args = [ExprArgument<"XDim">, ExprArgument<"YDim">, ExprArgument<"ZDim">];
   let Subjects = SubjectList<[Function], ErrorDiag>;
   let Documentation = [Undocumented];
 }
diff --git a/clang/lib/CodeGen/CodeGenFunction.cpp b/clang/lib/CodeGen/CodeGenFunction.cpp
index 447192bc7f60c..de1e04a7982fa 100644
--- a/clang/lib/CodeGen/CodeGenFunction.cpp
+++ b/clang/lib/CodeGen/CodeGenFunction.cpp
@@ -649,18 +649,24 @@ void CodeGenFunction::EmitKernelMetadata(const FunctionDecl *FD,
   }
 
   if (const WorkGroupSizeHintAttr *A = FD->getAttr<WorkGroupSizeHintAttr>()) {
+    auto Eval = [&](Expr *E) {
+      return E->EvaluateKnownConstInt(FD->getASTContext()).getExtValue();
+    };
     llvm::Metadata *AttrMDArgs[] = {
-        llvm::ConstantAsMetadata::get(Builder.getInt32(A->getXDim())),
-        llvm::ConstantAsMetadata::get(Builder.getInt32(A->getYDim())),
-        llvm::ConstantAsMetadata::get(Builder.getInt32(A->getZDim()))};
+        llvm::ConstantAsMetadata::get(Builder.getInt32(Eval(A->getXDim()))),
+        llvm::ConstantAsMetadata::get(Builder.getInt32(Eval(A->getYDim()))),
+        llvm::ConstantAsMetadata::get(Builder.getInt32(Eval(A->getZDim())))};
     Fn->setMetadata("work_group_size_hint", llvm::MDNode::get(Context, AttrMDArgs));
   }
 
   if (const ReqdWorkGroupSizeAttr *A = FD->getAttr<ReqdWorkGroupSizeAttr>()) {
+    auto Eval = [&](Expr *E) {
+      return E->EvaluateKnownConstInt(FD->getASTContext()).getExtValue();
+    };
     llvm::Metadata *AttrMDArgs[] = {
-        llvm::ConstantAsMetadata::get(Builder.getInt32(A->getXDim())),
-        llvm::ConstantAsMetadata::get(Builder.getInt32(A->getYDim())),
-        llvm::ConstantAsMetadata::get(Builder.getInt32(A->getZDim()))};
+        llvm::ConstantAsMetadata::get(Builder.getInt32(Eval(A->getXDim()))),
+        llvm::ConstantAsMetadata::get(Builder.getInt32(Eval(A->getYDim()))),
+        llvm::ConstantAsMetadata::get(Builder.getInt32(Eval(A->getZDim())))};
     Fn->setMetadata("reqd_work_group_size", llvm::MDNode::get(Context, AttrMDArgs));
   }
 
diff --git a/clang/lib/CodeGen/Targets/AMDGPU.cpp b/clang/lib/CodeGen/Targets/AMDGPU.cpp
index 7d84310ba0bf5..a14f599330644 100644
--- a/clang/lib/CodeGen/Targets/AMDGPU.cpp
+++ b/clang/lib/CodeGen/Targets/AMDGPU.cpp
@@ -753,12 +753,20 @@ void CodeGenModule::handleAMDGPUFlatWorkGroupSizeAttr(
     int32_t *MaxThreadsVal) {
   unsigned Min = 0;
   unsigned Max = 0;
+  auto Eval = [&](Expr *E) {
+    return E->EvaluateKnownConstInt(getContext()).getExtValue();
+  };
   if (FlatWGS) {
-    Min = FlatWGS->getMin()->EvaluateKnownConstInt(getContext()).getExtValue();
-    Max = FlatWGS->getMax()->EvaluateKnownConstInt(getContext()).getExtValue();
+    Min = Eval(
+        FlatWGS
+            ->getMin()); //->EvaluateKnownConstInt(getContext()).getExtValue();
+    Max = Eval(
+        FlatWGS
+            ->getMax()); //->EvaluateKnownConstInt(getContext()).getExtValue();
   }
   if (ReqdWGS && Min == 0 && Max == 0)
-    Min = Max = ReqdWGS->getXDim() * ReqdWGS->getYDim() * ReqdWGS->getZDim();
+    Min = Max = Eval(ReqdWGS->getXDim()) * Eval(ReqdWGS->getYDim()) *
+                Eval(ReqdWGS->getZDim());
 
   if (Min != 0) {
     assert(Min <= Max && "Min must be less than or equal Max");
diff --git a/clang/lib/CodeGen/Targets/TCE.cpp b/clang/lib/CodeGen/Targets/TCE.cpp
index d7178b4b8a949..bfce262ea965c 100644
--- a/clang/lib/CodeGen/Targets/TCE.cpp
+++ b/clang/lib/CodeGen/Targets/TCE.cpp
@@ -53,15 +53,15 @@ void TCETargetCodeGenInfo::setTargetAttributes(
         SmallVector<llvm::Metadata *, 5> Operands;
         Operands.push_back(llvm::ConstantAsMetadata::get(F));
 
-        Operands.push_back(
-            llvm::ConstantAsMetadata::get(llvm::Constant::getIntegerValue(
-                M.Int32Ty, llvm::APInt(32, Attr->getXDim()))));
-        Operands.push_back(
-            llvm::ConstantAsMetadata::get(llvm::Constant::getIntegerValue(
-                M.Int32Ty, llvm::APInt(32, Attr->getYDim()))));
-        Operands.push_back(
-            llvm::ConstantAsMetadata::get(llvm::Constant::getIntegerValue(
-                M.Int32Ty, llvm::APInt(32, Attr->getZDim()))));
+        auto Eval = [&](Expr *E) {
+          return E->EvaluateKnownConstInt(FD->getASTContext());
+        };
+        Operands.push_back(llvm::ConstantAsMetadata::get(
+            llvm::Constant::getIntegerValue(M.Int32Ty, Eval(Attr->getXDim()))));
+        Operands.push_back(llvm::ConstantAsMetadata::get(
+            llvm::Constant::getIntegerValue(M.Int32Ty, Eval(Attr->getYDim()))));
+        Operands.push_back(llvm::ConstantAsMetadata::get(
+            llvm::Constant::getIntegerValue(M.Int32Ty, Eval(Attr->getZDim()))));
 
         // Add a boolean constant operand for "required" (true) or "hint"
         // (false) for implementing the work_group_size_hint attr later.
diff --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp
index 769f5316ed934..0a8a3e1c49414 100644
--- a/clang/lib/Sema/SemaDeclAttr.cpp
+++ b/clang/lib/Sema/SemaDeclAttr.cpp
@@ -2914,21 +2914,70 @@ static void handleWeakImportAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
   D->addAttr(::new (S.Context) WeakImportAttr(S.Context, AL));
 }
 
+// Checks whether an argument of launch_bounds-like attribute is
+// acceptable, performs implicit conversion to Rvalue, and returns
+// non-nullptr Expr result on success. Otherwise, it returns nullptr
+// and may output an error.
+template <class Attribute>
+static Expr *makeAttributeArgExpr(Sema &S, Expr *E, const Attribute &Attr,
+                                  const unsigned Idx) {
+  if (S.DiagnoseUnexpandedParameterPack(E))
+    return nullptr;
+
+  // Accept template arguments for now as they depend on something else.
+  // We'll get to check them when they eventually get instantiated.
+  if (E->isValueDependent())
+    return E;
+
+  std::optional<llvm::APSInt> I = llvm::APSInt(64);
+  if (!(I = E->getIntegerConstantExpr(S.Context))) {
+    S.Diag(E->getExprLoc(), diag::err_attribute_argument_n_type)
+        << &Attr << Idx << AANT_ArgumentIntegerConstant << E->getSourceRange();
+    return nullptr;
+  }
+  // Make sure we can fit it in 32 bits.
+  if (!I->isIntN(32)) {
+    S.Diag(E->getExprLoc(), diag::err_ice_too_large)
+        << toString(*I, 10, false) << 32 << /* Unsigned */ 1;
+    return nullptr;
+  }
+  if (*I < 0)
+    S.Diag(E->getExprLoc(), diag::err_attribute_requires_positive_integer)
+        << &Attr << /*non-negative*/ 1 << E->getSourceRange();
+
+  // We may need to perform implicit conversion of the argument.
+  InitializedEntity Entity = InitializedEntity::InitializeParameter(
+      S.Context, S.Context.getConstType(S.Context.IntTy), /*consume*/ false);
+  ExprResult ValArg = S.PerformCopyInitialization(Entity, SourceLocation(), E);
+  assert(!ValArg.isInvalid() &&
+         "Unexpected PerformCopyInitialization() failure.");
+
+  return ValArg.getAs<Expr>();
+}
+
 // Handles reqd_work_group_size and work_group_size_hint.
 template <typename WorkGroupAttr>
 static void handleWorkGroupSize(Sema &S, Decl *D, const ParsedAttr &AL) {
-  uint32_t WGSize[3];
+  Expr *WGSize[3];
   for (unsigned i = 0; i < 3; ++i) {
-    const Expr *E = AL.getArgAsExpr(i);
-    if (!S.checkUInt32Argument(AL, E, WGSize[i], i,
-                               /*StrictlyUnsigned=*/true))
+    if (Expr *E = makeAttributeArgExpr(S, AL.getArgAsExpr(i), AL, i))
+      WGSize[i] = E;
+    else
       return;
   }
 
-  if (!llvm::all_of(WGSize, [](uint32_t Size) { return Size == 0; })) {
+  auto IsZero = [&](Expr *E) {
+    if (E->isValueDependent())
+      return false;
+    std::optional<llvm::APSInt> I = E->getIntegerConstantExpr(S.Context);
+    assert(I && "Non-integer constant expr");
+    return I->isZero();
+  };
+
+  if (!llvm::all_of(WGSize, IsZero)) {
     for (unsigned i = 0; i < 3; ++i) {
       const Expr *E = AL.getArgAsExpr(i);
-      if (WGSize[i] == 0) {
+      if (IsZero(WGSize[i])) {
         S.Diag(AL.getLoc(), diag::err_attribute_argument_is_zero)
             << AL << E->getSourceRange();
         return;
@@ -2936,10 +2985,22 @@ static void handleWorkGroupSize(Sema &S, Decl *D, const ParsedAttr &AL) {
     }
   }
 
+  auto Equal = [&](Expr *LHS, Expr *RHS) {
+    if (LHS->isValueDependent() || RHS->isValueDependent())
+      return true;
+    std::optional<llvm::APSInt> L = LHS->getIntegerConstantExpr(S.Context);
+    assert(L && "Non-integer constant expr");
+    std::optional<llvm::APSInt> R = RHS->getIntegerConstantExpr(S.Context);
+    assert(L && "Non-integer constant expr");
+    return L == R;
+  };
+
   WorkGroupAttr *Existing = D->getAttr<WorkGroupAttr>();
-  if (Existing && !(Existing->getXDim() == WGSize[0] &&
-                    Existing->getYDim() == WGSize[1] &&
-                    Existing->getZDim() == WGSize[2]))
+  if (Existing &&
+      !llvm::equal(std::initializer_list<Expr *>{Existing->getXDim(),
+                                                 Existing->getYDim(),
+                                                 Existing->getZDim()},
+                   WGSize, Equal))
     S.Diag(AL.getLoc(), diag::warn_duplicate_attribute) << AL;
 
   D->addAttr(::new (S.Context)
diff --git a/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp b/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp
index 170c0b0d39f86..d5330430a3b73 100644
--- a/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp
+++ b/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp
@@ -572,6 +572,32 @@ static void instantiateDependentAMDGPUFlatWorkGroupSizeAttr(
   S.AMDGPU().addAMDGPUFlatWorkGroupSizeAttr(New, Attr, MinExpr, MaxExpr);
 }
 
+static void instantiateDependentReqdWorkGroupSizeAttr(
+    Sema &S, const MultiLevelTemplateArgumentList &TemplateArgs,
+    const ReqdWorkGroupSizeAttr &Attr, Decl *New) {
+  // Both min and max expression are constant expressions.
+  EnterExpressionEvaluationContext Unevaluated(
+      S, Sema::ExpressionEvaluationContext::ConstantEvaluated);
+
+  ExprResult Result = S.SubstExpr(Attr.getXDim(), TemplateArgs);
+  if (Result.isInvalid())
+    return;
+  Expr *X = Result.getAs<Expr>();
+
+  Result = S.SubstExpr(Attr.getYDim(), TemplateArgs);
+  if (Result.isInvalid())
+    return;
+  Expr *Y = Result.getAs<Expr>();
+
+  Result = S.SubstExpr(Attr.getZDim(), TemplateArgs);
+  if (Result.isInvalid())
+    return;
+  Expr *Z = Result.getAs<Expr>();
+
+  ASTContext &Context = S.getASTContext();
+  New->addAttr(::new (Context) ReqdWorkGroupSizeAttr(Context, Attr, X, Y, Z));
+}
+
 ExplicitSpecifier Sema::instantiateExplicitSpecifier(
     const MultiLevelTemplateArgumentList &TemplateArgs, ExplicitSpecifier ES) {
   if (!ES.getExpr())
@@ -812,6 +838,12 @@ void Sema::InstantiateAttrs(const MultiLevelTemplateArgumentList &TemplateArgs,
       continue;
     }
 
+    if (const auto *ReqdWorkGroupSize =
+            dyn_cast<ReqdWorkGroupSizeAttr>(TmplAttr)) {
+      instantiateDependentReqdWorkGroupSizeAttr(*this, TemplateArgs,
+                                                *ReqdWorkGroupSize, New);
+    }
+
     if (const auto *AMDGPUFlatWorkGroupSize =
             dyn_cast<AMDGPUFlatWorkGroupSizeAttr>(TmplAttr)) {
       instantiateDependentAMDGPUFlatWorkGroupSizeAttr(
diff --git a/clang/test/SemaCUDA/spirv-attrs-diag.cu b/clang/test/SemaCUDA/spirv-attrs-diag.cu
new file mode 100644
index 0000000000000..8eae45623f5c5
--- /dev/null
+++ b/clang/test/SemaCUDA/spirv-attrs-diag.cu
@@ -0,0 +1,38 @@
+// RUN: %clang_cc1 -triple spirv64 -aux-triple x86_64-unknown-linux-gnu \
+// RUN:   -fcuda-is-device -verify -fsyntax-only %s
+
+#include "Inputs/cuda.h"
+
+__attribute__((reqd_work_group_size(0x100000000, 1, 1))) // expected-error {{integer constant expression evaluates to value 4294967296 that cannot be represented in a 32-bit unsigned integer type}}
+__global__ void TestTooBigArg1(void);
+
+__attribute__((work_group_size_hint(0x100000000, 1, 1))) // expected-error {{integer constant expression evaluates to value 4294967296 that cannot be represented in a 32-bit unsigned integer type}}
+__global__ void TestTooBigArg2(void);
+
+template <int... Args>
+__attribute__((reqd_work_group_size(Args))) // expected-error {{expression contains unexpanded parameter pack 'Args'}}
+__global__ void TestTemplateVariadicArgs1(void) {}
+
+template <int... Args>
+__attribute__((work_group_size_hint(Args))) // expected-error {{expression contains unexpanded parameter pack 'Args'}}
+__global__ void TestTemplateVariadicArgs2(void) {}
+
+template <class a> // expected-note {{declared here}}
+__attribute__((reqd_work_group_size(a, 1, 1))) // expected-error {{'a' does not refer to a value}}
+__global__ void TestTemplateArgClass1(void) {}
+
+template <class a> // expected-note {{declared here}}
+__attribute__((work_group_size_hint(a, 1, 1))) // expected-error {{'a' does not refer to a value}}
+__global__ void TestTemplateArgClass2(void) {}
+
+constexpr int A = 512;
+
+__attribute__((reqd_work_group_size(A, A, A)))
+__global__ void TestConstIntArg1(void) {}
+
+__attribute__((work_group_size_hint(A, A, A)))
+__global__ void TestConstIntArg2(void) {}
+
+
+
+
diff --git a/clang/test/SemaCUDA/spirv-attrs.cu b/clang/test/SemaCUDA/spirv-attrs.cu
index 6539421423ee1..fd944e665f169 100644
--- a/clang/test/SemaCUDA/spirv-attrs.cu
+++ b/clang/test/SemaCUDA/spirv-attrs.cu
@@ -8,11 +8,27 @@
 __attribute__((reqd_work_group_size(128, 1, 1)))
 __global__ void reqd_work_group_size_128_1_1() {}
 
+template <unsigned a, unsigned b, unsigned c>
+__attribute__((reqd_work_group_size(a, b, c)))
+__global__ void reqd_work_group_size_a_b_c() {}
+
+template <>
+__global__ void reqd_work_group_size_a_b_c<128,1,1>(void);
+
 __attribute__((work_group_size_hint(2, 2, 2)))
 __global__ void work_group_size_hint_2_2_2() {}
 
+template <unsigned a, unsigned b, unsigned c>
+__attribute__((work_group_size_hint(a, b, c)))
+__global__ void work_group_size_hint_a_b_c() {}
+
+template <>
+__global__ void work_group_size_hint_a_b_c<128,1,1>(void);
+
 __attribute__((vec_type_hint(int)))
 __global__ void vec_type_hint_int() {}
 
 __attribute__((intel_reqd_sub_group_size(64)))
 __global__ void intel_reqd_sub_group_size_64() {}
+
+

Copy link

github-actions bot commented Mar 17, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

Copy link
Contributor

@ShangwuYao ShangwuYao left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pretty neat!!

@alexander-shaposhnikov alexander-shaposhnikov force-pushed the template_attr branch 2 times, most recently from b5aefb3 to 6f00315 Compare March 17, 2025 21:26
@arsenm arsenm requested review from Artem-B and yxsamliu March 18, 2025 01:29
Copy link
Contributor

@AlexVlx AlexVlx left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

Copy link
Collaborator

@yxsamliu yxsamliu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Thanks

@alexander-shaposhnikov alexander-shaposhnikov merged commit 297f0b3 into llvm:main Mar 19, 2025
11 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend:AMDGPU clang:codegen IR generation bugs: mangling, exceptions, etc. clang:frontend Language frontend issues, e.g. anything involving "Sema" clang Clang issues not falling into any other category clang-tidy clang-tools-extra
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants