Skip to content

Commit 1e09d06

Browse files
inbelicsivan-shani
authored andcommitted
[HLSL][RootSignature] Add parsing of floats for StaticSampler (llvm#140181)
- defines in-memory representaiton of MipLODBias to allow for testing of a float parameter - defines `handleInt` and `handleFloat` to handle converting a token's `NumSpelling` into a valid float - plugs this into `parseFloatParam` to fill in the MipLODBias param The parsing of floats is required to match the behaviour of DXC. This behaviour is outlined as follows: - if the number is an integer then convert it using `_atoi64`, check for overflow and static_cast this to a float - if the number is a float then convert it using `strtod`, check for float overflow and static_cast this to a float, this will implicitly also check for double over/underflow and if the string is malformed then it will return an error This pr matches this behaviour by parsing as, uint/int accordingly and then casting, or, by using the correct APFloat semantics/rounding mode with `NumericLiteralParser`. - adds testing of error diagnostics and valid float param values to demonstrate functionality Part 2 of llvm#126574
1 parent f94d6a5 commit 1e09d06

File tree

7 files changed

+376
-6
lines changed

7 files changed

+376
-6
lines changed

clang/include/clang/Basic/DiagnosticParseKinds.td

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1856,7 +1856,11 @@ def err_hlsl_unexpected_end_of_params
18561856
: Error<"expected %0 to denote end of parameters, or, another valid parameter of %1">;
18571857
def err_hlsl_rootsig_repeat_param : Error<"specified the same parameter '%0' multiple times">;
18581858
def err_hlsl_rootsig_missing_param : Error<"did not specify mandatory parameter '%0'">;
1859-
def err_hlsl_number_literal_overflow : Error<"integer literal is too large to be represented as a 32-bit %select{signed |}0 integer type">;
1859+
def err_hlsl_number_literal_overflow : Error<
1860+
"%select{integer|float}0 literal is too large to be represented as a "
1861+
"%select{32-bit %select{signed|}1 integer|float}0 type">;
1862+
def err_hlsl_number_literal_underflow : Error<
1863+
"float literal has a magnitude that is too small to be represented as a float type">;
18601864
def err_hlsl_rootsig_non_zero_flag : Error<"flag value is neither a literal 0 nor a named value">;
18611865

18621866
} // end of Parser diagnostics

clang/include/clang/Lex/HLSLRootSignatureTokenKinds.def

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,9 @@ KEYWORD(flags)
100100
KEYWORD(numDescriptors)
101101
KEYWORD(offset)
102102

103+
// StaticSampler Keywords:
104+
KEYWORD(mipLODBias)
105+
103106
// Unbounded Enum:
104107
UNBOUNDED_ENUM(unbounded, "unbounded")
105108

clang/include/clang/Parse/ParseHLSLRootSignature.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,12 +111,14 @@ class RootSignatureParser {
111111

112112
struct ParsedStaticSamplerParams {
113113
std::optional<llvm::hlsl::rootsig::Register> Reg;
114+
std::optional<float> MipLODBias;
114115
};
115116
std::optional<ParsedStaticSamplerParams> parseStaticSamplerParams();
116117

117118
// Common parsing methods
118119
std::optional<uint32_t> parseUIntParam();
119120
std::optional<llvm::hlsl::rootsig::Register> parseRegister();
121+
std::optional<float> parseFloatParam();
120122

121123
/// Parsing methods of various enums
122124
std::optional<llvm::hlsl::rootsig::ShaderVisibility> parseShaderVisibility();
@@ -128,6 +130,19 @@ class RootSignatureParser {
128130
/// Use NumericLiteralParser to convert CurToken.NumSpelling into a unsigned
129131
/// 32-bit integer
130132
std::optional<uint32_t> handleUIntLiteral();
133+
/// Use NumericLiteralParser to convert CurToken.NumSpelling into a signed
134+
/// 32-bit integer
135+
std::optional<int32_t> handleIntLiteral(bool Negated);
136+
/// Use NumericLiteralParser to convert CurToken.NumSpelling into a float
137+
///
138+
/// This matches the behaviour of DXC, which is as follows:
139+
/// - convert the spelling with `strtod`
140+
/// - check for a float overflow
141+
/// - cast the double to a float
142+
/// The behaviour of `strtod` is replicated using:
143+
/// Semantics: llvm::APFloat::Semantics::S_IEEEdouble
144+
/// RoundingMode: llvm::RoundingMode::NearestTiesToEven
145+
std::optional<float> handleFloatLiteral(bool Negated);
131146

132147
/// Flags may specify the value of '0' to denote that there should be no
133148
/// flags set.

clang/lib/Parse/ParseHLSLRootSignature.cpp

Lines changed: 157 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,10 @@ std::optional<StaticSampler> RootSignatureParser::parseStaticSampler() {
376376

377377
Sampler.Reg = Params->Reg.value();
378378

379+
// Fill in optional values
380+
if (Params->MipLODBias.has_value())
381+
Sampler.MipLODBias = Params->MipLODBias.value();
382+
379383
if (consumeExpectedToken(TokenKind::pu_r_paren,
380384
diag::err_hlsl_unexpected_end_of_params,
381385
/*param of=*/TokenKind::kw_StaticSampler))
@@ -661,6 +665,23 @@ RootSignatureParser::parseStaticSamplerParams() {
661665
return std::nullopt;
662666
Params.Reg = Reg;
663667
}
668+
669+
// `mipLODBias` `=` NUMBER
670+
if (tryConsumeExpectedToken(TokenKind::kw_mipLODBias)) {
671+
if (Params.MipLODBias.has_value()) {
672+
getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_repeat_param)
673+
<< CurToken.TokKind;
674+
return std::nullopt;
675+
}
676+
677+
if (consumeExpectedToken(TokenKind::pu_equal))
678+
return std::nullopt;
679+
680+
auto MipLODBias = parseFloatParam();
681+
if (!MipLODBias.has_value())
682+
return std::nullopt;
683+
Params.MipLODBias = MipLODBias;
684+
}
664685
} while (tryConsumeExpectedToken(TokenKind::pu_comma));
665686

666687
return Params;
@@ -709,6 +730,39 @@ std::optional<Register> RootSignatureParser::parseRegister() {
709730
return Reg;
710731
}
711732

733+
std::optional<float> RootSignatureParser::parseFloatParam() {
734+
assert(CurToken.TokKind == TokenKind::pu_equal &&
735+
"Expects to only be invoked starting at given keyword");
736+
// Consume sign modifier
737+
bool Signed =
738+
tryConsumeExpectedToken({TokenKind::pu_plus, TokenKind::pu_minus});
739+
bool Negated = Signed && CurToken.TokKind == TokenKind::pu_minus;
740+
741+
// DXC will treat a postive signed integer as unsigned
742+
if (!Negated && tryConsumeExpectedToken(TokenKind::int_literal)) {
743+
std::optional<uint32_t> UInt = handleUIntLiteral();
744+
if (!UInt.has_value())
745+
return std::nullopt;
746+
return float(UInt.value());
747+
}
748+
749+
if (Negated && tryConsumeExpectedToken(TokenKind::int_literal)) {
750+
std::optional<int32_t> Int = handleIntLiteral(Negated);
751+
if (!Int.has_value())
752+
return std::nullopt;
753+
return float(Int.value());
754+
}
755+
756+
if (tryConsumeExpectedToken(TokenKind::float_literal)) {
757+
std::optional<float> Float = handleFloatLiteral(Negated);
758+
if (!Float.has_value())
759+
return std::nullopt;
760+
return Float.value();
761+
}
762+
763+
return std::nullopt;
764+
}
765+
712766
std::optional<llvm::hlsl::rootsig::ShaderVisibility>
713767
RootSignatureParser::parseShaderVisibility() {
714768
assert(CurToken.TokKind == TokenKind::pu_equal &&
@@ -819,22 +873,121 @@ std::optional<uint32_t> RootSignatureParser::handleUIntLiteral() {
819873
PP.getSourceManager(), PP.getLangOpts(),
820874
PP.getTargetInfo(), PP.getDiagnostics());
821875
if (Literal.hadError)
822-
return true; // Error has already been reported so just return
876+
return std::nullopt; // Error has already been reported so just return
823877

824-
assert(Literal.isIntegerLiteral() && "IsNumberChar will only support digits");
878+
assert(Literal.isIntegerLiteral() &&
879+
"NumSpelling can only consist of digits");
825880

826-
llvm::APSInt Val = llvm::APSInt(32, false);
881+
llvm::APSInt Val(32, /*IsUnsigned=*/true);
827882
if (Literal.GetIntegerValue(Val)) {
828883
// Report that the value has overflowed
829884
PP.getDiagnostics().Report(CurToken.TokLoc,
830885
diag::err_hlsl_number_literal_overflow)
831-
<< 0 << CurToken.NumSpelling;
886+
<< /*integer type*/ 0 << /*is signed*/ 0;
832887
return std::nullopt;
833888
}
834889

835890
return Val.getExtValue();
836891
}
837892

893+
std::optional<int32_t> RootSignatureParser::handleIntLiteral(bool Negated) {
894+
// Parse the numeric value and do semantic checks on its specification
895+
clang::NumericLiteralParser Literal(CurToken.NumSpelling, CurToken.TokLoc,
896+
PP.getSourceManager(), PP.getLangOpts(),
897+
PP.getTargetInfo(), PP.getDiagnostics());
898+
if (Literal.hadError)
899+
return std::nullopt; // Error has already been reported so just return
900+
901+
assert(Literal.isIntegerLiteral() &&
902+
"NumSpelling can only consist of digits");
903+
904+
llvm::APSInt Val(32, /*IsUnsigned=*/true);
905+
// GetIntegerValue will overwrite Val from the parsed Literal and return
906+
// true if it overflows as a 32-bit unsigned int
907+
bool Overflowed = Literal.GetIntegerValue(Val);
908+
909+
// So we then need to check that it doesn't overflow as a 32-bit signed int:
910+
int64_t MaxNegativeMagnitude = -int64_t(std::numeric_limits<int32_t>::min());
911+
Overflowed |= (Negated && MaxNegativeMagnitude < Val.getExtValue());
912+
913+
int64_t MaxPositiveMagnitude = int64_t(std::numeric_limits<int32_t>::max());
914+
Overflowed |= (!Negated && MaxPositiveMagnitude < Val.getExtValue());
915+
916+
if (Overflowed) {
917+
// Report that the value has overflowed
918+
PP.getDiagnostics().Report(CurToken.TokLoc,
919+
diag::err_hlsl_number_literal_overflow)
920+
<< /*integer type*/ 0 << /*is signed*/ 1;
921+
return std::nullopt;
922+
}
923+
924+
if (Negated)
925+
Val = -Val;
926+
927+
return int32_t(Val.getExtValue());
928+
}
929+
930+
std::optional<float> RootSignatureParser::handleFloatLiteral(bool Negated) {
931+
// Parse the numeric value and do semantic checks on its specification
932+
clang::NumericLiteralParser Literal(CurToken.NumSpelling, CurToken.TokLoc,
933+
PP.getSourceManager(), PP.getLangOpts(),
934+
PP.getTargetInfo(), PP.getDiagnostics());
935+
if (Literal.hadError)
936+
return std::nullopt; // Error has already been reported so just return
937+
938+
assert(Literal.isFloatingLiteral() &&
939+
"NumSpelling consists only of [0-9.ef+-]. Any malformed NumSpelling "
940+
"will be caught and reported by NumericLiteralParser.");
941+
942+
// DXC used `strtod` to convert the token string to a float which corresponds
943+
// to:
944+
auto DXCSemantics = llvm::APFloat::Semantics::S_IEEEdouble;
945+
auto DXCRoundingMode = llvm::RoundingMode::NearestTiesToEven;
946+
947+
llvm::APFloat Val(llvm::APFloat::EnumToSemantics(DXCSemantics));
948+
llvm::APFloat::opStatus Status(Literal.GetFloatValue(Val, DXCRoundingMode));
949+
950+
// Note: we do not error when opStatus::opInexact by itself as this just
951+
// denotes that rounding occured but not that it is invalid
952+
assert(!(Status & llvm::APFloat::opStatus::opInvalidOp) &&
953+
"NumSpelling consists only of [0-9.ef+-]. Any malformed NumSpelling "
954+
"will be caught and reported by NumericLiteralParser.");
955+
956+
assert(!(Status & llvm::APFloat::opStatus::opDivByZero) &&
957+
"It is not possible for a division to be performed when "
958+
"constructing an APFloat from a string");
959+
960+
if (Status & llvm::APFloat::opStatus::opUnderflow) {
961+
// Report that the value has underflowed
962+
PP.getDiagnostics().Report(CurToken.TokLoc,
963+
diag::err_hlsl_number_literal_underflow);
964+
return std::nullopt;
965+
}
966+
967+
if (Status & llvm::APFloat::opStatus::opOverflow) {
968+
// Report that the value has overflowed
969+
PP.getDiagnostics().Report(CurToken.TokLoc,
970+
diag::err_hlsl_number_literal_overflow)
971+
<< /*float type*/ 1;
972+
return std::nullopt;
973+
}
974+
975+
if (Negated)
976+
Val = -Val;
977+
978+
double DoubleVal = Val.convertToDouble();
979+
double FloatMax = double(std::numeric_limits<float>::max());
980+
if (FloatMax < DoubleVal || DoubleVal < -FloatMax) {
981+
// Report that the value has overflowed
982+
PP.getDiagnostics().Report(CurToken.TokLoc,
983+
diag::err_hlsl_number_literal_overflow)
984+
<< /*float type*/ 1;
985+
return std::nullopt;
986+
}
987+
988+
return static_cast<float>(DoubleVal);
989+
}
990+
838991
bool RootSignatureParser::verifyZeroFlag() {
839992
assert(CurToken.TokKind == TokenKind::int_literal);
840993
auto X = handleUIntLiteral();

clang/unittests/Lex/LexHLSLRootSignatureTest.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,8 @@ TEST_F(LexHLSLRootSignatureTest, ValidLexAllTokensTest) {
136136
space visibility flags
137137
numDescriptors offset
138138
139+
mipLODBias
140+
139141
unbounded
140142
DESCRIPTOR_RANGE_OFFSET_APPEND
141143

0 commit comments

Comments
 (0)