-
Notifications
You must be signed in to change notification settings - Fork 14k
[DirectX] Improve error accumulation in root signature parsing #144465
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: users/joaosaffran/144577
Are you sure you want to change the base?
[DirectX] Improve error accumulation in root signature parsing #144465
Conversation
✅ With the latest revision this PR passed the C/C++ code formatter. |
e62419f
to
24f38bd
Compare
@llvm/pr-subscribers-backend-directx Author: None (joaosaffran) ChangesThis patch enhances error handling in the DirectX backend's root signature
Before this change, the parser would stop at the first error encountered. Now it Example of changes: bool HasError = false;
if (std::optional<uint32_t> Val = extractMdIntValue(RootFlagNode, 1))
RSD.Flags = *Val;
else
HasError = HasError || reportInvalidTypeError<ConstantInt>(
Ctx, "RootFlagNode", RootFlagNode, 1);
return HasError; Testing:
Patch is 25.66 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/144465.diff 2 Files Affected:
diff --git a/llvm/lib/Target/DirectX/DXILRootSignature.cpp b/llvm/lib/Target/DirectX/DXILRootSignature.cpp
index 57d5ee8ac467c..eea46e714b756 100644
--- a/llvm/lib/Target/DirectX/DXILRootSignature.cpp
+++ b/llvm/lib/Target/DirectX/DXILRootSignature.cpp
@@ -141,14 +141,15 @@ static bool parseRootFlags(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
if (RootFlagNode->getNumOperands() != 2)
return reportError(Ctx, "Invalid format for RootFlag Element");
-
+ bool HasError = false;
if (std::optional<uint32_t> Val = extractMdIntValue(RootFlagNode, 1))
RSD.Flags = *Val;
else
- return reportInvalidTypeError<ConstantInt>(Ctx, "RootFlagNode",
- RootFlagNode, 1);
+ HasError = reportInvalidTypeError<ConstantInt>(Ctx, "RootFlagNode",
+ RootFlagNode, 1) ||
+ HasError;
- return false;
+ return HasError;
}
static bool parseRootConstants(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
@@ -157,6 +158,7 @@ static bool parseRootConstants(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
if (RootConstantNode->getNumOperands() != 5)
return reportError(Ctx, "Invalid format for RootConstants Element");
+ bool HasError = false;
dxbc::RTS0::v1::RootParameterHeader Header;
// The parameter offset doesn't matter here - we recalculate it during
// serialization Header.ParameterOffset = 0;
@@ -166,31 +168,35 @@ static bool parseRootConstants(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 1))
Header.ShaderVisibility = *Val;
else
- return reportInvalidTypeError<ConstantInt>(Ctx, "RootConstantNode",
- RootConstantNode, 1);
+ HasError = reportInvalidTypeError<ConstantInt>(Ctx, "RootConstantNode",
+ RootConstantNode, 1) ||
+ HasError;
dxbc::RTS0::v1::RootConstants Constants;
if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 2))
Constants.ShaderRegister = *Val;
else
- return reportInvalidTypeError<ConstantInt>(Ctx, "RootConstantNode",
- RootConstantNode, 2);
+ HasError = reportInvalidTypeError<ConstantInt>(Ctx, "RootConstantNode",
+ RootConstantNode, 2) ||
+ HasError;
if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 3))
Constants.RegisterSpace = *Val;
else
- return reportInvalidTypeError<ConstantInt>(Ctx, "RootConstantNode",
- RootConstantNode, 3);
+ HasError = reportInvalidTypeError<ConstantInt>(Ctx, "RootConstantNode",
+ RootConstantNode, 3) ||
+ HasError;
if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 4))
Constants.Num32BitValues = *Val;
else
- return reportInvalidTypeError<ConstantInt>(Ctx, "RootConstantNode",
- RootConstantNode, 4);
-
- RSD.ParametersContainer.addParameter(Header, Constants);
+ HasError = reportInvalidTypeError<ConstantInt>(Ctx, "RootConstantNode",
+ RootConstantNode, 4) ||
+ HasError;
+ if (!HasError)
+ RSD.ParametersContainer.addParameter(Header, Constants);
- return false;
+ return HasError;
}
static bool parseRootDescriptors(LLVMContext *Ctx,
@@ -205,6 +211,7 @@ static bool parseRootDescriptors(LLVMContext *Ctx,
if (RootDescriptorNode->getNumOperands() != 5)
return reportError(Ctx, "Invalid format for Root Descriptor Element");
+ bool HasError = false;
dxbc::RTS0::v1::RootParameterHeader Header;
switch (ElementKind) {
case RootSignatureElementKind::SRV:
@@ -224,36 +231,41 @@ static bool parseRootDescriptors(LLVMContext *Ctx,
if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 1))
Header.ShaderVisibility = *Val;
else
- return reportInvalidTypeError<ConstantInt>(Ctx, "RootDescriptorNode",
- RootDescriptorNode, 1);
+ HasError = reportInvalidTypeError<ConstantInt>(Ctx, "RootDescriptorNode",
+ RootDescriptorNode, 1) ||
+ HasError;
dxbc::RTS0::v2::RootDescriptor Descriptor;
if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 2))
Descriptor.ShaderRegister = *Val;
else
- return reportInvalidTypeError<ConstantInt>(Ctx, "RootDescriptorNode",
- RootDescriptorNode, 2);
+ HasError = reportInvalidTypeError<ConstantInt>(Ctx, "RootDescriptorNode",
+ RootDescriptorNode, 2) ||
+ HasError;
if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 3))
Descriptor.RegisterSpace = *Val;
else
- return reportInvalidTypeError<ConstantInt>(Ctx, "RootDescriptorNode",
- RootDescriptorNode, 3);
+ HasError = reportInvalidTypeError<ConstantInt>(Ctx, "RootDescriptorNode",
+ RootDescriptorNode, 3) ||
+ HasError;
if (RSD.Version == 1) {
- RSD.ParametersContainer.addParameter(Header, Descriptor);
- return false;
+ if (!HasError)
+ RSD.ParametersContainer.addParameter(Header, Descriptor);
+ return HasError;
}
assert(RSD.Version > 1);
if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 4))
Descriptor.Flags = *Val;
else
- return reportInvalidTypeError<ConstantInt>(Ctx, "RootDescriptorNode",
- RootDescriptorNode, 4);
-
- RSD.ParametersContainer.addParameter(Header, Descriptor);
- return false;
+ HasError = reportInvalidTypeError<ConstantInt>(Ctx, "RootDescriptorNode",
+ RootDescriptorNode, 4) ||
+ HasError;
+ if (!HasError)
+ RSD.ParametersContainer.addParameter(Header, Descriptor);
+ return HasError;
}
static bool parseDescriptorRange(LLVMContext *Ctx,
@@ -264,14 +276,16 @@ static bool parseDescriptorRange(LLVMContext *Ctx,
if (RangeDescriptorNode->getNumOperands() != 6)
return reportError(Ctx, "Invalid format for Descriptor Range");
+ bool HasError = false;
dxbc::RTS0::v2::DescriptorRange Range;
std::optional<StringRef> ElementText =
extractMdStringValue(RangeDescriptorNode, 0);
if (!ElementText.has_value())
- return reportInvalidTypeError<MDString>(Ctx, "RangeDescriptorNode",
- RangeDescriptorNode, 0);
+ HasError = reportInvalidTypeError<MDString>(Ctx, "RangeDescriptorNode",
+ RangeDescriptorNode, 0) ||
+ HasError;
Range.RangeType =
StringSwitch<uint32_t>(*ElementText)
@@ -283,40 +297,47 @@ static bool parseDescriptorRange(LLVMContext *Ctx,
.Default(-1u);
if (Range.RangeType == -1u)
- return reportError(Ctx, "Invalid Descriptor Range type: " + *ElementText);
+ HasError =
+ reportError(Ctx, "Invalid Descriptor Range type: " + *ElementText) ||
+ HasError;
if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 1))
Range.NumDescriptors = *Val;
else
- return reportInvalidTypeError<ConstantInt>(Ctx, "RangeDescriptorNode",
- RangeDescriptorNode, 1);
+ HasError = reportInvalidTypeError<ConstantInt>(Ctx, "RangeDescriptorNode",
+ RangeDescriptorNode, 1) ||
+ HasError;
if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 2))
Range.BaseShaderRegister = *Val;
else
- return reportInvalidTypeError<ConstantInt>(Ctx, "RangeDescriptorNode",
- RangeDescriptorNode, 2);
+ HasError = reportInvalidTypeError<ConstantInt>(Ctx, "RangeDescriptorNode",
+ RangeDescriptorNode, 2) ||
+ HasError;
if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 3))
Range.RegisterSpace = *Val;
else
- return reportInvalidTypeError<ConstantInt>(Ctx, "RangeDescriptorNode",
- RangeDescriptorNode, 3);
+ HasError = reportInvalidTypeError<ConstantInt>(Ctx, "RangeDescriptorNode",
+ RangeDescriptorNode, 3) ||
+ HasError;
if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 4))
Range.OffsetInDescriptorsFromTableStart = *Val;
else
- return reportInvalidTypeError<ConstantInt>(Ctx, "RangeDescriptorNode",
- RangeDescriptorNode, 4);
+ HasError = reportInvalidTypeError<ConstantInt>(Ctx, "RangeDescriptorNode",
+ RangeDescriptorNode, 4) ||
+ HasError;
if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 5))
Range.Flags = *Val;
else
- return reportInvalidTypeError<ConstantInt>(Ctx, "RangeDescriptorNode",
- RangeDescriptorNode, 5);
-
- Table.Ranges.push_back(Range);
- return false;
+ HasError = reportInvalidTypeError<ConstantInt>(Ctx, "RangeDescriptorNode",
+ RangeDescriptorNode, 5) ||
+ HasError;
+ if (!HasError)
+ Table.Ranges.push_back(Range);
+ return HasError;
}
static bool parseDescriptorTable(LLVMContext *Ctx,
@@ -325,13 +346,14 @@ static bool parseDescriptorTable(LLVMContext *Ctx,
const unsigned int NumOperands = DescriptorTableNode->getNumOperands();
if (NumOperands < 2)
return reportError(Ctx, "Invalid format for Descriptor Table");
-
+ bool HasError = false;
dxbc::RTS0::v1::RootParameterHeader Header;
if (std::optional<uint32_t> Val = extractMdIntValue(DescriptorTableNode, 1))
Header.ShaderVisibility = *Val;
else
- return reportInvalidTypeError<ConstantInt>(Ctx, "DescriptorTableNode",
- DescriptorTableNode, 1);
+ HasError = reportInvalidTypeError<ConstantInt>(Ctx, "DescriptorTableNode",
+ DescriptorTableNode, 1) ||
+ HasError;
mcdxbc::DescriptorTable Table;
Header.ParameterType =
@@ -340,15 +362,16 @@ static bool parseDescriptorTable(LLVMContext *Ctx,
for (unsigned int I = 2; I < NumOperands; I++) {
MDNode *Element = dyn_cast<MDNode>(DescriptorTableNode->getOperand(I));
if (Element == nullptr)
- return reportInvalidTypeError<MDNode>(Ctx, "DescriptorTableNode",
- DescriptorTableNode, I);
+ HasError = reportInvalidTypeError<MDNode>(Ctx, "DescriptorTableNode",
+ DescriptorTableNode, I) ||
+ HasError;
if (parseDescriptorRange(Ctx, RSD, Table, Element))
- return true;
+ HasError = true || HasError;
}
-
- RSD.ParametersContainer.addParameter(Header, Table);
- return false;
+ if (!HasError)
+ RSD.ParametersContainer.addParameter(Header, Table);
+ return HasError;
}
static bool parseStaticSampler(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
@@ -356,87 +379,101 @@ static bool parseStaticSampler(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
if (StaticSamplerNode->getNumOperands() != 14)
return reportError(Ctx, "Invalid format for Static Sampler");
+ bool HasError = false;
dxbc::RTS0::v1::StaticSampler Sampler;
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 1))
Sampler.Filter = *Val;
else
- return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
- StaticSamplerNode, 1);
+ HasError = reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+ StaticSamplerNode, 1) ||
+ HasError;
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 2))
Sampler.AddressU = *Val;
else
- return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
- StaticSamplerNode, 2);
+ HasError = reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+ StaticSamplerNode, 2) ||
+ HasError;
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 3))
Sampler.AddressV = *Val;
else
- return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
- StaticSamplerNode, 3);
+ HasError = reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+ StaticSamplerNode, 3) ||
+ HasError;
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 4))
Sampler.AddressW = *Val;
else
- return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
- StaticSamplerNode, 4);
+ HasError = reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+ StaticSamplerNode, 4) ||
+ HasError;
if (std::optional<APFloat> Val = extractMdFloatValue(StaticSamplerNode, 5))
Sampler.MipLODBias = Val->convertToFloat();
else
- return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
- StaticSamplerNode, 5);
+ HasError = reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+ StaticSamplerNode, 5) ||
+ HasError;
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 6))
Sampler.MaxAnisotropy = *Val;
else
- return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
- StaticSamplerNode, 6);
+ HasError = reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+ StaticSamplerNode, 6) ||
+ HasError;
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 7))
Sampler.ComparisonFunc = *Val;
else
- return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
- StaticSamplerNode, 7);
+ HasError = reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+ StaticSamplerNode, 7) ||
+ HasError;
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 8))
Sampler.BorderColor = *Val;
else
- return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
- StaticSamplerNode, 8);
+ HasError = reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+ StaticSamplerNode, 8) ||
+ HasError;
if (std::optional<APFloat> Val = extractMdFloatValue(StaticSamplerNode, 9))
Sampler.MinLOD = Val->convertToFloat();
else
- return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
- StaticSamplerNode, 9);
+ HasError = reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+ StaticSamplerNode, 9) ||
+ HasError;
if (std::optional<APFloat> Val = extractMdFloatValue(StaticSamplerNode, 10))
Sampler.MaxLOD = Val->convertToFloat();
else
- return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
- StaticSamplerNode, 10);
+ HasError = reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+ StaticSamplerNode, 10) ||
+ HasError;
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 11))
Sampler.ShaderRegister = *Val;
else
- return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
- StaticSamplerNode, 11);
+ HasError = reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+ StaticSamplerNode, 11) ||
+ HasError;
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 12))
Sampler.RegisterSpace = *Val;
else
- return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
- StaticSamplerNode, 12);
+ HasError = reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+ StaticSamplerNode, 12) ||
+ HasError;
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 13))
Sampler.ShaderVisibility = *Val;
else
- return reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
- StaticSamplerNode, 13);
-
- RSD.StaticSamplers.push_back(Sampler);
- return false;
+ HasError = reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
+ StaticSamplerNode, 13) ||
+ HasError;
+ if (!HasError)
+ RSD.StaticSamplers.push_back(Sampler);
+ return HasError;
}
static bool parseRootSignatureElement(LLVMContext *Ctx,
@@ -488,7 +525,7 @@ static bool parse(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
if (Element == nullptr)
return reportError(Ctx, "Missing Root Element Metadata Node.");
- HasError = HasError || parseRootSignatureElement(Ctx, RSD, Element);
+ HasError = parseRootSignatureElement(Ctx, RSD, Element) || HasError;
}
return HasError;
@@ -699,19 +736,20 @@ static bool verifyBorderColor(uint32_t BorderColor) {
static bool verifyLOD(float LOD) { return !std::isnan(LOD); }
static bool validate(LLVMContext *Ctx, const mcdxbc::RootSignatureDesc &RSD) {
-
+ bool HasError = false;
if (!verifyVersion(RSD.Version)) {
- return reportValueError(Ctx, "Version", RSD.Version);
+ HasError = reportValueError(Ctx, "Version", RSD.Version) || HasError;
}
if (!verifyRootFlag(RSD.Flags)) {
- return reportValueError(Ctx, "RootFlags", RSD.Flags);
+ HasError = reportValueError(Ctx, "RootFlags", RSD.Flags) || HasError;
}
for (const mcdxbc::RootParameterInfo &Info : RSD.ParametersContainer) {
if (!dxbc::isValidShaderVisibility(Info.Header.ShaderVisibility))
- return reportValueError(Ctx, "ShaderVisibility",
- Info.Header.ShaderVisibility);
+ HasError = reportValueError(Ctx, "ShaderVisibility",
+ Info.Header.ShaderVisibility) ||
+ HasError;
assert(dxbc::isValidParameterType(Info.Header.ParameterType) &&
"Invalid value for ParameterType");
@@ -724,15 +762,20 @@ static bool validate(LLVMContext *Ctx, const mcdxbc::RootSignatureDesc &RSD) {
const dxbc::RTS0::v2::RootDescriptor &Descriptor =
RSD.ParametersContainer.getRootDescriptor(Info.Location);
if (!verifyRegisterValue(Descriptor.ShaderRegister))
- return reportValueError(Ctx, "ShaderRegister",
- Descriptor.ShaderRegister);
+ HasError = reportValueError(Ctx, "ShaderRegister",
+ Descriptor.ShaderRegister) ||
+ HasError;
if (!verifyRegisterSpace(Descriptor.RegisterSpace))
- return reportValueError(Ctx, "RegisterSpace", Descriptor.RegisterSpace);
+ HasError =
+ reportValueError(Ctx, "RegisterSpace", Descriptor.RegisterSpace) ||
+ HasError;
if (RSD.Version > 1)...
[truncated]
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess the awkward part is that it might be nice to clean up all the current error tests that are spread across many files to just one for each param type.
Up to you, if we want to go about doing so.
This patch enhances error handling in the DirectX backend's root signature
parsing, specifically in DXILRootSignature.cpp. The changes include:
Modify error handling to accumulate errors:
Fix root flag parsing:
Before this change, the parser would stop at the first error encountered. Now it
continues validation, collecting all errors before returning. This provides a better
developer experience by showing all issues that need to be fixed at once.
Example of changes:
Testing: