-
Notifications
You must be signed in to change notification settings - Fork 14k
[NFC][RootSignature] Use llvm::EnumEntry
for serialization of Root Signature Elements
#144106
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-hlsl Author: Finn Plummer (inbelic) Changes
Full diff: https://github.com/llvm/llvm-project/pull/144106.diff 1 Files Affected:
diff --git a/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp b/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp
index 765a3bcbed7e2..ab5ced523996a 100644
--- a/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp
+++ b/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp
@@ -15,112 +15,48 @@
#include "llvm/ADT/bit.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Metadata.h"
+#include "llvm/Support/ScopedPrinter.h"
namespace llvm {
namespace hlsl {
namespace rootsig {
-static raw_ostream &operator<<(raw_ostream &OS, const Register &Reg) {
- switch (Reg.ViewType) {
- case RegisterType::BReg:
- OS << "b";
- break;
- case RegisterType::TReg:
- OS << "t";
- break;
- case RegisterType::UReg:
- OS << "u";
- break;
- case RegisterType::SReg:
- OS << "s";
- break;
- }
- OS << Reg.Number;
- return OS;
+template <typename T>
+static StringRef getEnumName(const T Value, ArrayRef<EnumEntry<T>> Enums) {
+ for (const auto &EnumItem : Enums)
+ if (EnumItem.Value == Value)
+ return EnumItem.Name;
+ return "";
}
-static raw_ostream &operator<<(raw_ostream &OS,
- const ShaderVisibility &Visibility) {
- switch (Visibility) {
- case ShaderVisibility::All:
- OS << "All";
- break;
- case ShaderVisibility::Vertex:
- OS << "Vertex";
- break;
- case ShaderVisibility::Hull:
- OS << "Hull";
- break;
- case ShaderVisibility::Domain:
- OS << "Domain";
- break;
- case ShaderVisibility::Geometry:
- OS << "Geometry";
- break;
- case ShaderVisibility::Pixel:
- OS << "Pixel";
- break;
- case ShaderVisibility::Amplification:
- OS << "Amplification";
- break;
- case ShaderVisibility::Mesh:
- OS << "Mesh";
- break;
- }
-
- return OS;
-}
-
-static raw_ostream &operator<<(raw_ostream &OS, const ClauseType &Type) {
- switch (Type) {
- case ClauseType::CBuffer:
- OS << "CBV";
- break;
- case ClauseType::SRV:
- OS << "SRV";
- break;
- case ClauseType::UAV:
- OS << "UAV";
- break;
- case ClauseType::Sampler:
- OS << "Sampler";
- break;
- }
+template <typename T>
+static raw_ostream &printEnum(raw_ostream &OS, const T Value,
+ ArrayRef<EnumEntry<T>> Enums) {
+ OS << getEnumName(Value, Enums);
return OS;
}
-static raw_ostream &operator<<(raw_ostream &OS,
- const DescriptorRangeFlags &Flags) {
+template <typename T>
+static raw_ostream &printFlags(raw_ostream &OS, const T Value,
+ ArrayRef<EnumEntry<T>> Flags) {
bool FlagSet = false;
- unsigned Remaining = llvm::to_underlying(Flags);
+ unsigned Remaining = llvm::to_underlying(Value);
while (Remaining) {
unsigned Bit = 1u << llvm::countr_zero(Remaining);
if (Remaining & Bit) {
if (FlagSet)
OS << " | ";
- switch (static_cast<DescriptorRangeFlags>(Bit)) {
- case DescriptorRangeFlags::DescriptorsVolatile:
- OS << "DescriptorsVolatile";
- break;
- case DescriptorRangeFlags::DataVolatile:
- OS << "DataVolatile";
- break;
- case DescriptorRangeFlags::DataStaticWhileSetAtExecute:
- OS << "DataStaticWhileSetAtExecute";
- break;
- case DescriptorRangeFlags::DataStatic:
- OS << "DataStatic";
- break;
- case DescriptorRangeFlags::DescriptorsStaticKeepingBufferBoundsChecks:
- OS << "DescriptorsStaticKeepingBufferBoundsChecks";
- break;
- default:
+ bool Found = false;
+ for (const auto &FlagItem : Flags)
+ if (FlagItem.Value == T(Bit)) {
+ OS << FlagItem.Name;
+ Found = true;
+ break;
+ }
+ if (!Found)
OS << "invalid: " << Bit;
- break;
- }
-
FlagSet = true;
}
Remaining &= ~Bit;
@@ -128,6 +64,68 @@ static raw_ostream &operator<<(raw_ostream &OS,
if (!FlagSet)
OS << "None";
+ return OS;
+}
+
+static const EnumEntry<RegisterType> RegisterNames[] = {
+ {"b", RegisterType::BReg},
+ {"t", RegisterType::TReg},
+ {"u", RegisterType::UReg},
+ {"s", RegisterType::SReg},
+};
+
+static raw_ostream &operator<<(raw_ostream &OS, const Register &Reg) {
+ printEnum(OS, Reg.ViewType, ArrayRef(RegisterNames));
+ OS << Reg.Number;
+
+ return OS;
+}
+
+static const EnumEntry<ShaderVisibility> VisibilityNames[] = {
+ {"All", ShaderVisibility::All},
+ {"Vertex", ShaderVisibility::Vertex},
+ {"Hull", ShaderVisibility::Hull},
+ {"Domain", ShaderVisibility::Domain},
+ {"Geometry", ShaderVisibility::Geometry},
+ {"Pixel", ShaderVisibility::Pixel},
+ {"Amplification", ShaderVisibility::Amplification},
+ {"Mesh", ShaderVisibility::Mesh},
+};
+
+static raw_ostream &operator<<(raw_ostream &OS,
+ const ShaderVisibility &Visibility) {
+ printEnum(OS, Visibility, ArrayRef(VisibilityNames));
+
+ return OS;
+}
+
+static const EnumEntry<dxil::ResourceClass> ResourceClassNames[] = {
+ {"CBV", dxil::ResourceClass::CBuffer},
+ {"SRV", dxil::ResourceClass::SRV},
+ {"UAV", dxil::ResourceClass::UAV},
+ {"Sampler", dxil::ResourceClass::Sampler},
+};
+
+static raw_ostream &operator<<(raw_ostream &OS, const ClauseType &Type) {
+ printEnum(OS, dxil::ResourceClass(llvm::to_underlying(Type)),
+ ArrayRef(ResourceClassNames));
+
+ return OS;
+}
+
+static const EnumEntry<DescriptorRangeFlags> DescriptorRangeFlagNames[] = {
+ {"DescriptorsVolatile", DescriptorRangeFlags::DescriptorsVolatile},
+ {"DataVolatile", DescriptorRangeFlags::DataVolatile},
+ {"DataStaticWhileSetAtExecute",
+ DescriptorRangeFlags::DataStaticWhileSetAtExecute},
+ {"DataStatic", DescriptorRangeFlags::DataStatic},
+ {"DescriptorsStaticKeepingBufferBoundsChecks",
+ DescriptorRangeFlags::DescriptorsStaticKeepingBufferBoundsChecks},
+};
+
+static raw_ostream &operator<<(raw_ostream &OS,
+ const DescriptorRangeFlags &Flags) {
+ printFlags(OS, Flags, ArrayRef(DescriptorRangeFlagNames));
return OS;
}
@@ -236,12 +234,12 @@ MDNode *MetadataBuilder::BuildRootConstants(const RootConstants &Constants) {
MDNode *MetadataBuilder::BuildRootDescriptor(const RootDescriptor &Descriptor) {
IRBuilder<> Builder(Ctx);
- llvm::SmallString<7> Name;
- llvm::raw_svector_ostream OS(Name);
- OS << "Root" << ClauseType(llvm::to_underlying(Descriptor.Type));
-
+ StringRef TypeName =
+ getEnumName(dxil::ResourceClass(llvm::to_underlying(Descriptor.Type)),
+ ArrayRef(ResourceClassNames));
+ llvm::SmallString<7> Name({"Root", TypeName});
Metadata *Operands[] = {
- MDString::get(Ctx, OS.str()),
+ MDString::get(Ctx, Name),
ConstantAsMetadata::get(
Builder.getInt32(llvm::to_underlying(Descriptor.Visibility))),
ConstantAsMetadata::get(Builder.getInt32(Descriptor.Reg.Number)),
@@ -277,19 +275,19 @@ MDNode *MetadataBuilder::BuildDescriptorTable(const DescriptorTable &Table) {
MDNode *MetadataBuilder::BuildDescriptorTableClause(
const DescriptorTableClause &Clause) {
IRBuilder<> Builder(Ctx);
- std::string Name;
- llvm::raw_string_ostream OS(Name);
- OS << Clause.Type;
- return MDNode::get(
- Ctx, {
- MDString::get(Ctx, OS.str()),
- ConstantAsMetadata::get(Builder.getInt32(Clause.NumDescriptors)),
- ConstantAsMetadata::get(Builder.getInt32(Clause.Reg.Number)),
- ConstantAsMetadata::get(Builder.getInt32(Clause.Space)),
- ConstantAsMetadata::get(Builder.getInt32(Clause.Offset)),
- ConstantAsMetadata::get(
- Builder.getInt32(llvm::to_underlying(Clause.Flags))),
- });
+ StringRef Name =
+ getEnumName(dxil::ResourceClass(llvm::to_underlying(Clause.Type)),
+ ArrayRef(ResourceClassNames));
+ Metadata *Operands[] = {
+ MDString::get(Ctx, Name),
+ ConstantAsMetadata::get(Builder.getInt32(Clause.NumDescriptors)),
+ ConstantAsMetadata::get(Builder.getInt32(Clause.Reg.Number)),
+ ConstantAsMetadata::get(Builder.getInt32(Clause.Space)),
+ ConstantAsMetadata::get(Builder.getInt32(Clause.Offset)),
+ ConstantAsMetadata::get(
+ Builder.getInt32(llvm::to_underlying(Clause.Flags))),
+ };
+ return MDNode::get(Ctx, Operands);
}
MDNode *MetadataBuilder::BuildStaticSampler(const StaticSampler &Sampler) {
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for addressing my suggestions
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Can't approve though.
It has pointed out here that we may be able to use
llvm::EnumEntry
so that we can re-use the printing logic across enumerations.Enables re-use of
printEnum
andprintFlags
methods via templatesAllows easy definition of
getEnumName
function for enum-to-string conversion, eliminating the need to use a string stream for constructing the Name SmallStringAlso, does a small fix-up of the operands for descriptor table clause to be consistent with other
Build*
methodsFor reference, the test-cases that must not change expected output.