Skip to content

[NFC][RootSignature] Use llvm::EnumEntry for serialization of Root Signature Elements #144106

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jun 16, 2025

Conversation

inbelic
Copy link
Contributor

@inbelic inbelic commented Jun 13, 2025

It has pointed out here that we may be able to use llvm::EnumEntry so that we can re-use the printing logic across enumerations.

  • Enables re-use of printEnum and printFlags methods via templates

  • Allows easy definition of getEnumName function for enum-to-string conversion, eliminating the need to use a string stream for constructing the Name SmallString

  • Also, does a small fix-up of the operands for descriptor table clause to be consistent with other Build* methods

For reference, the test-cases that must not change expected output.

@llvmbot llvmbot added the HLSL HLSL Language Support label Jun 13, 2025
@llvmbot
Copy link
Member

llvmbot commented Jun 13, 2025

@llvm/pr-subscribers-hlsl

Author: Finn Plummer (inbelic)

Changes
  • Enables re-use of printEnum and printFlags methods via templates

  • Allows easy definition of getEnumName function for enum-to-string conversion, eliminating the need to use a string stream for constructing the Name SmallString

  • Also, does a small fix-up of the operands for descriptor table clause to be consistent with other Build* methods


Full diff: https://github.com/llvm/llvm-project/pull/144106.diff

1 Files Affected:

  • (modified) llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp (+103-105)
diff --git a/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp b/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp
index 765a3bcbed7e2..ab5ced523996a 100644
--- a/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp
+++ b/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp
@@ -15,112 +15,48 @@
 #include "llvm/ADT/bit.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Metadata.h"
+#include "llvm/Support/ScopedPrinter.h"
 
 namespace llvm {
 namespace hlsl {
 namespace rootsig {
 
-static raw_ostream &operator<<(raw_ostream &OS, const Register &Reg) {
-  switch (Reg.ViewType) {
-  case RegisterType::BReg:
-    OS << "b";
-    break;
-  case RegisterType::TReg:
-    OS << "t";
-    break;
-  case RegisterType::UReg:
-    OS << "u";
-    break;
-  case RegisterType::SReg:
-    OS << "s";
-    break;
-  }
-  OS << Reg.Number;
-  return OS;
+template <typename T>
+static StringRef getEnumName(const T Value, ArrayRef<EnumEntry<T>> Enums) {
+  for (const auto &EnumItem : Enums)
+    if (EnumItem.Value == Value)
+      return EnumItem.Name;
+  return "";
 }
 
-static raw_ostream &operator<<(raw_ostream &OS,
-                               const ShaderVisibility &Visibility) {
-  switch (Visibility) {
-  case ShaderVisibility::All:
-    OS << "All";
-    break;
-  case ShaderVisibility::Vertex:
-    OS << "Vertex";
-    break;
-  case ShaderVisibility::Hull:
-    OS << "Hull";
-    break;
-  case ShaderVisibility::Domain:
-    OS << "Domain";
-    break;
-  case ShaderVisibility::Geometry:
-    OS << "Geometry";
-    break;
-  case ShaderVisibility::Pixel:
-    OS << "Pixel";
-    break;
-  case ShaderVisibility::Amplification:
-    OS << "Amplification";
-    break;
-  case ShaderVisibility::Mesh:
-    OS << "Mesh";
-    break;
-  }
-
-  return OS;
-}
-
-static raw_ostream &operator<<(raw_ostream &OS, const ClauseType &Type) {
-  switch (Type) {
-  case ClauseType::CBuffer:
-    OS << "CBV";
-    break;
-  case ClauseType::SRV:
-    OS << "SRV";
-    break;
-  case ClauseType::UAV:
-    OS << "UAV";
-    break;
-  case ClauseType::Sampler:
-    OS << "Sampler";
-    break;
-  }
+template <typename T>
+static raw_ostream &printEnum(raw_ostream &OS, const T Value,
+                              ArrayRef<EnumEntry<T>> Enums) {
+  OS << getEnumName(Value, Enums);
 
   return OS;
 }
 
-static raw_ostream &operator<<(raw_ostream &OS,
-                               const DescriptorRangeFlags &Flags) {
+template <typename T>
+static raw_ostream &printFlags(raw_ostream &OS, const T Value,
+                               ArrayRef<EnumEntry<T>> Flags) {
   bool FlagSet = false;
-  unsigned Remaining = llvm::to_underlying(Flags);
+  unsigned Remaining = llvm::to_underlying(Value);
   while (Remaining) {
     unsigned Bit = 1u << llvm::countr_zero(Remaining);
     if (Remaining & Bit) {
       if (FlagSet)
         OS << " | ";
 
-      switch (static_cast<DescriptorRangeFlags>(Bit)) {
-      case DescriptorRangeFlags::DescriptorsVolatile:
-        OS << "DescriptorsVolatile";
-        break;
-      case DescriptorRangeFlags::DataVolatile:
-        OS << "DataVolatile";
-        break;
-      case DescriptorRangeFlags::DataStaticWhileSetAtExecute:
-        OS << "DataStaticWhileSetAtExecute";
-        break;
-      case DescriptorRangeFlags::DataStatic:
-        OS << "DataStatic";
-        break;
-      case DescriptorRangeFlags::DescriptorsStaticKeepingBufferBoundsChecks:
-        OS << "DescriptorsStaticKeepingBufferBoundsChecks";
-        break;
-      default:
+      bool Found = false;
+      for (const auto &FlagItem : Flags)
+        if (FlagItem.Value == T(Bit)) {
+          OS << FlagItem.Name;
+          Found = true;
+          break;
+        }
+      if (!Found)
         OS << "invalid: " << Bit;
-        break;
-      }
-
       FlagSet = true;
     }
     Remaining &= ~Bit;
@@ -128,6 +64,68 @@ static raw_ostream &operator<<(raw_ostream &OS,
 
   if (!FlagSet)
     OS << "None";
+  return OS;
+}
+
+static const EnumEntry<RegisterType> RegisterNames[] = {
+    {"b", RegisterType::BReg},
+    {"t", RegisterType::TReg},
+    {"u", RegisterType::UReg},
+    {"s", RegisterType::SReg},
+};
+
+static raw_ostream &operator<<(raw_ostream &OS, const Register &Reg) {
+  printEnum(OS, Reg.ViewType, ArrayRef(RegisterNames));
+  OS << Reg.Number;
+
+  return OS;
+}
+
+static const EnumEntry<ShaderVisibility> VisibilityNames[] = {
+    {"All", ShaderVisibility::All},
+    {"Vertex", ShaderVisibility::Vertex},
+    {"Hull", ShaderVisibility::Hull},
+    {"Domain", ShaderVisibility::Domain},
+    {"Geometry", ShaderVisibility::Geometry},
+    {"Pixel", ShaderVisibility::Pixel},
+    {"Amplification", ShaderVisibility::Amplification},
+    {"Mesh", ShaderVisibility::Mesh},
+};
+
+static raw_ostream &operator<<(raw_ostream &OS,
+                               const ShaderVisibility &Visibility) {
+  printEnum(OS, Visibility, ArrayRef(VisibilityNames));
+
+  return OS;
+}
+
+static const EnumEntry<dxil::ResourceClass> ResourceClassNames[] = {
+    {"CBV", dxil::ResourceClass::CBuffer},
+    {"SRV", dxil::ResourceClass::SRV},
+    {"UAV", dxil::ResourceClass::UAV},
+    {"Sampler", dxil::ResourceClass::Sampler},
+};
+
+static raw_ostream &operator<<(raw_ostream &OS, const ClauseType &Type) {
+  printEnum(OS, dxil::ResourceClass(llvm::to_underlying(Type)),
+            ArrayRef(ResourceClassNames));
+
+  return OS;
+}
+
+static const EnumEntry<DescriptorRangeFlags> DescriptorRangeFlagNames[] = {
+    {"DescriptorsVolatile", DescriptorRangeFlags::DescriptorsVolatile},
+    {"DataVolatile", DescriptorRangeFlags::DataVolatile},
+    {"DataStaticWhileSetAtExecute",
+     DescriptorRangeFlags::DataStaticWhileSetAtExecute},
+    {"DataStatic", DescriptorRangeFlags::DataStatic},
+    {"DescriptorsStaticKeepingBufferBoundsChecks",
+     DescriptorRangeFlags::DescriptorsStaticKeepingBufferBoundsChecks},
+};
+
+static raw_ostream &operator<<(raw_ostream &OS,
+                               const DescriptorRangeFlags &Flags) {
+  printFlags(OS, Flags, ArrayRef(DescriptorRangeFlagNames));
 
   return OS;
 }
@@ -236,12 +234,12 @@ MDNode *MetadataBuilder::BuildRootConstants(const RootConstants &Constants) {
 
 MDNode *MetadataBuilder::BuildRootDescriptor(const RootDescriptor &Descriptor) {
   IRBuilder<> Builder(Ctx);
-  llvm::SmallString<7> Name;
-  llvm::raw_svector_ostream OS(Name);
-  OS << "Root" << ClauseType(llvm::to_underlying(Descriptor.Type));
-
+  StringRef TypeName =
+      getEnumName(dxil::ResourceClass(llvm::to_underlying(Descriptor.Type)),
+                  ArrayRef(ResourceClassNames));
+  llvm::SmallString<7> Name({"Root", TypeName});
   Metadata *Operands[] = {
-      MDString::get(Ctx, OS.str()),
+      MDString::get(Ctx, Name),
       ConstantAsMetadata::get(
           Builder.getInt32(llvm::to_underlying(Descriptor.Visibility))),
       ConstantAsMetadata::get(Builder.getInt32(Descriptor.Reg.Number)),
@@ -277,19 +275,19 @@ MDNode *MetadataBuilder::BuildDescriptorTable(const DescriptorTable &Table) {
 MDNode *MetadataBuilder::BuildDescriptorTableClause(
     const DescriptorTableClause &Clause) {
   IRBuilder<> Builder(Ctx);
-  std::string Name;
-  llvm::raw_string_ostream OS(Name);
-  OS << Clause.Type;
-  return MDNode::get(
-      Ctx, {
-               MDString::get(Ctx, OS.str()),
-               ConstantAsMetadata::get(Builder.getInt32(Clause.NumDescriptors)),
-               ConstantAsMetadata::get(Builder.getInt32(Clause.Reg.Number)),
-               ConstantAsMetadata::get(Builder.getInt32(Clause.Space)),
-               ConstantAsMetadata::get(Builder.getInt32(Clause.Offset)),
-               ConstantAsMetadata::get(
-                   Builder.getInt32(llvm::to_underlying(Clause.Flags))),
-           });
+  StringRef Name =
+      getEnumName(dxil::ResourceClass(llvm::to_underlying(Clause.Type)),
+                  ArrayRef(ResourceClassNames));
+  Metadata *Operands[] = {
+      MDString::get(Ctx, Name),
+      ConstantAsMetadata::get(Builder.getInt32(Clause.NumDescriptors)),
+      ConstantAsMetadata::get(Builder.getInt32(Clause.Reg.Number)),
+      ConstantAsMetadata::get(Builder.getInt32(Clause.Space)),
+      ConstantAsMetadata::get(Builder.getInt32(Clause.Offset)),
+      ConstantAsMetadata::get(
+          Builder.getInt32(llvm::to_underlying(Clause.Flags))),
+  };
+  return MDNode::get(Ctx, Operands);
 }
 
 MDNode *MetadataBuilder::BuildStaticSampler(const StaticSampler &Sampler) {

Copy link
Contributor

@joaosaffran joaosaffran left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for addressing my suggestions

Copy link
Contributor

@alsepkow alsepkow left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Can't approve though.

@inbelic inbelic merged commit 63b80dd into llvm:main Jun 16, 2025
8 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
HLSL HLSL Language Support
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants