Skip to content

Commit

Permalink
[WGSL] Add validation to integer modulo
Browse files Browse the repository at this point in the history
https://bugs.webkit.org/show_bug.cgi?id=264603
rdar://118239748

Reviewed by Mike Wyrzykowski.

Similar to division, we need to check for modulo with zero or
INT_MIN and -1.

* Source/WebGPU/WGSL/ConstantFunctions.h:
(WGSL::BINARY_OPERATION):
* Source/WebGPU/WGSL/Metal/MetalFunctionWriter.cpp:
(WGSL::Metal::FunctionDefinitionWriter::emitNecessaryHelpers):
(WGSL::Metal::FunctionDefinitionWriter::visit):
* Source/WebGPU/WGSL/TypeCheck.cpp:
(WGSL::TypeChecker::visit):
* Source/WebGPU/WGSL/WGSLShaderModule.h:
(WGSL::ShaderModule::usesModulo const):
(WGSL::ShaderModule::setUsesModulo):
* Source/WebGPU/WGSL/tests/invalid/modulo.wgsl: Added.

Canonical link: https://commits.webkit.org/270636@main
  • Loading branch information
tadeuzagallo committed Nov 13, 2023
1 parent b92ebbd commit e99d13b
Show file tree
Hide file tree
Showing 5 changed files with 280 additions and 14 deletions.
15 changes: 11 additions & 4 deletions Source/WebGPU/WGSL/ConstantFunctions.h
Original file line number Diff line number Diff line change
Expand Up @@ -547,11 +547,18 @@ BINARY_OPERATION(Divide, Number, [&]<typename T>(T left, T right) -> ConstantRes
return { { left / right } };
});

BINARY_OPERATION(Modulo, Number, [&](auto left, auto right) {
BINARY_OPERATION(Modulo, Number, [&]<typename T>(T left, T right) -> ConstantResult {
if constexpr (std::is_floating_point_v<decltype(left)>)
return fmod(left, right);
else
return left % right;
return { { fmod(left, right) } };
else {
if (!right)
return makeUnexpected("invalid modulo by zero"_s);
if constexpr (std::is_signed_v<T>) {
if (left == std::numeric_limits<T>::lowest() && right == -1)
return makeUnexpected("invalid modulo overflow"_s);
}
return { { left % right } };
}
});

// Comparison Operations
Expand Down
37 changes: 33 additions & 4 deletions Source/WebGPU/WGSL/Metal/MetalFunctionWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,23 @@ void FunctionDefinitionWriter::emitNecessaryHelpers()
m_stringBuilder.append(m_indent, "}\n\n");
}

if (m_callGraph.ast().usesModulo()) {
m_stringBuilder.append(m_indent, "template<typename T, typename U, typename V = conditional_t<is_scalar_v<U>, T, U>>\n");
m_stringBuilder.append(m_indent, "V __wgslMod(T lhs, U rhs)\n");
m_stringBuilder.append(m_indent, "{\n");
{
IndentationScope scope(m_indent);
m_stringBuilder.append(m_indent, "auto predicate = V(rhs) == V(0);\n");
m_stringBuilder.append(m_indent, "if constexpr (is_signed_v<U>)\n");
{
IndentationScope scope(m_indent);
m_stringBuilder.append(m_indent, "predicate = predicate || (V(lhs) == V(numeric_limits<T>::lowest()) && V(rhs) == V(-1));\n");
}
m_stringBuilder.append(m_indent, "return select(lhs % V(rhs), V(0), predicate);\n");
}
m_stringBuilder.append(m_indent, "}\n\n");
}


if (m_callGraph.ast().usesFrexp()) {
m_stringBuilder.append(m_indent, "template<typename T, typename U>\n");
Expand Down Expand Up @@ -1666,12 +1683,18 @@ void FunctionDefinitionWriter::visit(AST::BinaryExpression& binary)
}
}

if (binary.operation() == AST::BinaryOperation::Divide) {
const char* helperFunction = nullptr;
if (binary.operation() == AST::BinaryOperation::Divide)
helperFunction = "__wgslDiv";
else if (binary.operation() == AST::BinaryOperation::Modulo)
helperFunction = "__wgslMod";

if (helperFunction) {
auto* resultType = binary.inferredType();
if (auto* vectorType = std::get_if<Types::Vector>(resultType))
resultType = vectorType->element;
if (satisfies(resultType, Constraints::Integer)) {
m_stringBuilder.append("__wgslDiv(");
m_stringBuilder.append(helperFunction, "(");
visit(binary.leftExpression());
m_stringBuilder.append(", ");
visit(binary.rightExpression());
Expand Down Expand Up @@ -1837,13 +1860,19 @@ void FunctionDefinitionWriter::visit(AST::CallStatement& statement)

void FunctionDefinitionWriter::visit(AST::CompoundAssignmentStatement& statement)
{
if (statement.operation() == AST::BinaryOperation::Divide) {
const char* helperFunction = nullptr;
if (statement.operation() == AST::BinaryOperation::Divide)
helperFunction = "__wgslDiv";
else if (statement.operation() == AST::BinaryOperation::Modulo)
helperFunction = "__wgslMod";

if (helperFunction) {
auto* rightType = statement.rightExpression().inferredType();
if (auto* vectorType = std::get_if<Types::Vector>(rightType))
rightType = vectorType->element;
if (satisfies(rightType, Constraints::Integer)) {
visit(statement.leftExpression());
m_stringBuilder.append(" = __wgslDiv(");
m_stringBuilder.append(" = ", helperFunction, "(");
visit(statement.leftExpression());
m_stringBuilder.append(", ");
visit(statement.rightExpression());
Expand Down
28 changes: 22 additions & 6 deletions Source/WebGPU/WGSL/TypeCheck.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -576,15 +576,23 @@ void TypeChecker::visit(AST::CompoundAssignmentStatement& statement)
infer(statement.leftExpression());
infer(statement.rightExpression());

if (statement.operation() == AST::BinaryOperation::Divide) {
const char* operationName = nullptr;
if (statement.operation() == AST::BinaryOperation::Divide)
operationName = "division";
else if (statement.operation() == AST::BinaryOperation::Modulo)
operationName = "modulo";
if (operationName) {
auto* rightType = statement.rightExpression().inferredType();
if (auto* vectorType = std::get_if<Types::Vector>(rightType))
rightType = vectorType->element;
if (satisfies(rightType, Constraints::Integer)) {
m_shaderModule.setUsesDivision();
if (statement.operation() == AST::BinaryOperation::Divide)
m_shaderModule.setUsesDivision();
else
m_shaderModule.setUsesModulo();
auto rightValue = statement.rightExpression().constantValue();
if (rightValue && containsZero(*rightValue, statement.rightExpression().inferredType()))
typeError(InferBottom::No, statement.span(), "invalid division by zero");
typeError(InferBottom::No, statement.span(), "invalid ", operationName, " by zero");
}
}
}
Expand Down Expand Up @@ -858,16 +866,24 @@ void TypeChecker::visit(AST::BinaryExpression& binary)
{
chooseOverload("operator", binary, toString(binary.operation()), ReferenceWrapperVector<AST::Expression, 2> { binary.leftExpression(), binary.rightExpression() }, { });

if (binary.operation() == AST::BinaryOperation::Divide) {
const char* operationName = nullptr;
if (binary.operation() == AST::BinaryOperation::Divide)
operationName = "division";
else if (binary.operation() == AST::BinaryOperation::Modulo)
operationName = "modulo";
if (operationName) {
auto* rightType = binary.rightExpression().inferredType();
if (auto* vectorType = std::get_if<Types::Vector>(rightType))
rightType = vectorType->element;
if (satisfies(rightType, Constraints::Integer)) {
m_shaderModule.setUsesDivision();
if (binary.operation() == AST::BinaryOperation::Divide)
m_shaderModule.setUsesDivision();
else
m_shaderModule.setUsesModulo();
auto leftValue = binary.leftExpression().constantValue();
auto rightValue = binary.rightExpression().constantValue();
if (!leftValue && rightValue && containsZero(*rightValue, binary.rightExpression().inferredType()))
typeError(InferBottom::No, binary.span(), "invalid division by zero");
typeError(InferBottom::No, binary.span(), "invalid ", operationName, " by zero");
}
}
}
Expand Down
4 changes: 4 additions & 0 deletions Source/WebGPU/WGSL/WGSLShaderModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ class ShaderModule {
bool usesDivision() const { return m_usesDivision; }
void setUsesDivision() { m_usesDivision = true; }

bool usesModulo() const { return m_usesModulo; }
void setUsesModulo() { m_usesModulo = true; }

bool usesFrexp() const { return m_usesFrexp; }
void setUsesFrexp() { m_usesFrexp = true; }

Expand Down Expand Up @@ -224,6 +227,7 @@ class ShaderModule {
bool m_usesUnpackArray { false };
bool m_usesWorkgroupUniformLoad { false };
bool m_usesDivision { false };
bool m_usesModulo { false };
bool m_usesFrexp { false };
Configuration m_configuration;
AST::Directive::List m_directives;
Expand Down
210 changes: 210 additions & 0 deletions Source/WebGPU/WGSL/tests/invalid/modulo.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
// RUN: %not %wgslc | %check

fn testAbstractInt()
{
// CHECK-L: modulo by zero
_ = 42 % 0;
// CHECK-L: modulo by zero
_ = 42 % vec2(1, 0);
// CHECK-L: modulo by zero
_ = vec2(42) % 0;
// CHECK-L: modulo by zero
_ = vec2(42) % vec2(1, 0);

let x = 42;
// CHECK-L: modulo by zero
_ = x % 0;
// CHECK-L: modulo by zero
_ = x % vec2(1, 0);
// CHECK-L: modulo by zero
_ = vec2(x) % 0;
// CHECK-L: modulo by zero
_ = vec2(x) % vec2(1, 0);

// CHECK-NOT-L: invalid modulo overflow
_ = (-2147483647 - 1) % -1;
// CHECK-NOT-L: invalid modulo overflow
_ = vec2(-2147483647 - 1, 1) % vec2(-1, 1);

// CHECK-L: invalid modulo overflow
_ = (-9223372036854775807 - 1) % -1;
// CHECK-L: invalid modulo overflow
_ = vec2(-9223372036854775807 - 1, 1) % vec2(-1, 1);
}

fn testI32()
{
// CHECK-L: modulo by zero
_ = 42i % 0i;
// CHECK-L: modulo by zero
_ = 42i % vec2(1i, 0i);
// CHECK-L: modulo by zero
_ = vec2(42i) % 0i;
// CHECK-L: modulo by zero
_ = vec2(42i) % vec2(1i, 0i);

let x = 42i;
// CHECK-L: modulo by zero
_ = x % 0i;
// CHECK-L: modulo by zero
_ = x % vec2(1i, 0i);
// CHECK-L: modulo by zero
_ = vec2(x) % 0i;
// CHECK-L: modulo by zero
_ = vec2(x) % vec2(1i, 0i);

// CHECK-L: invalid modulo overflow
_ = (-2147483647i - 1i) % -1i;
// CHECK-L: invalid modulo overflow
_ = vec2(-2147483647i - 1i, 1i) % vec2(-1i, 1i);
}

fn testU32()
{
// CHECK-L: modulo by zero
_ = 42u % 0u;
// CHECK-L: modulo by zero
_ = 42u % vec2(1u, 0u);
// CHECK-L: modulo by zero
_ = vec2(42u) % 0u;
// CHECK-L: modulo by zero
_ = vec2(42u) % vec2(1u, 0u);

let x = 42u;
// CHECK-L: modulo by zero
_ = x % 0u;
// CHECK-L: modulo by zero
_ = x % vec2(1u, 0u);
// CHECK-L: modulo by zero
_ = vec2(x) % 0u;
// CHECK-L: modulo by zero
_ = vec2(x) % vec2(1u, 0u);
}

fn testAbstractFloat()
{
// CHECK-L: value NaN cannot be represented as '<AbstractFloat>'
_ = 42.0 % 0.0;
// CHECK-L: value NaN cannot be represented as '<AbstractFloat>'
_ = 42.0 % vec2(1.0, 0.0);
// CHECK-L: value NaN cannot be represented as '<AbstractFloat>'
_ = vec2(42.0) % 0.0;
// CHECK-L: value NaN cannot be represented as '<AbstractFloat>'
_ = vec2(42.0) % vec2(1.0, 0.0);

// CHECK-NOT-L: modulo by zero
let x = 42.0;
// CHECK-NOT-L: modulo by zero
_ = x % 0.0;
// CHECK-NOT-L: modulo by zero
_ = x % vec2(1.0, 0.0);
// CHECK-NOT-L: modulo by zero
_ = vec2(x) % 0.0;
// CHECK-NOT-L: modulo by zero
_ = vec2(x) % vec2(1.0, 0.0);

// CHECK-NOT-L: invalid modulo overflow
_ = (-2147483647.0 - 1.0) % -1.0;
// CHECK-NOT-L: invalid modulo overflow
_ = vec2(-2147483647.0 - 1.0, 1.0) % vec2(-1.0, 1.0);

// CHECK-NOT-L: invalid modulo overflow
_ = (-9223372036854775807.0 - 1.0) % -1.0;
// CHECK-NOT-L: invalid modulo overflow
_ = vec2(-9223372036854775807.0 - 1.0, 1.0) % vec2(-1.0, 1.0);

// CHECK-NOT-L: invalid modulo overflow
_ = -340282346638528859811704183484516925440.0 - 1.0 % -1.0;
// CHECK-NOT-L: invalid modulo overflow
_ = vec2(-340282346638528859811704183484516925440.0 - 1.0, 1.0) % vec2(-1.0, 1.0);
}

fn testF32()
{
// CHECK-L: value NaN cannot be represented as 'f32'
_ = 42f % 0f;
// CHECK-L: value NaN cannot be represented as 'f32'
_ = 42f % vec2(1f, 0f);
// CHECK-L: value NaN cannot be represented as 'f32'
_ = vec2(42f) % 0f;
// CHECK-L: value NaN cannot be represented as 'f32'
_ = vec2(42f) % vec2(1f, 0f);

let x = 42f;
// CHECK-NOT-L: modulo by zero
_ = x % 0f;
// CHECK-NOT-L: modulo by zero
_ = x % vec2(1f, 0f);
// CHECK-NOT-L: modulo by zero
_ = vec2(x) % 0f;
// CHECK-NOT-L: modulo by zero
_ = vec2(x) % vec2(1f, 0f);

// CHECK-NOT-L: invalid modulo overflow
_ = (-2147483647f - 1f) % -1f;
// CHECK-NOT-L: invalid modulo overflow
_ = vec2(-2147483647f - 1f, 1f) % vec2(-1f, 1f);

// CHECK-NOT-L: invalid modulo overflow
_ = (-9223372036854775807.f - 1.f) % -1.f;
// CHECK-NOT-L: invalid modulo overflow
_ = vec2(-9223372036854775807.f - 1.f, 1.f) % vec2(-1.f, 1.f);

// CHECK-NOT-L: invalid modulo overflow
_ = (-340282346638528859811704183484516925439.f - 1f) % -1f;
// CHECK-NOT-L: invalid modulo overflow
_ = vec2(-340282346638528859811704183484516925439.f - 1f, 1f) % vec2(-1f, 1f);
}

fn testI32Compound()
{
var y: vec2<i32>;
// CHECK-L: modulo by zero
y %= 0;
// CHECK-L: modulo by zero
y %= vec2(1, 0);
// CHECK-L: modulo by zero
y[0] %= 0;
// CHECK-L: modulo by zero
y[0] %= vec2(1, 0);
}

fn testU32Compound()
{
var y: vec2<u32>;
// CHECK-L: modulo by zero
y %= 0;
// CHECK-L: modulo by zero
y %= vec2(1, 0);
// CHECK-L: modulo by zero
y[0] %= 0;
// CHECK-L: modulo by zero
y[0] %= vec2(1, 0);
}

fn testF32Compound()
{
// FIXME: these depende on proper typing for compound statements
// var y: vec2<f32>;
// skip-CHECK-NOT-L: modulo by zero
// y %= 0;
// skip-CHECK-NOT-L: modulo by zero
// y %= vec2(1, 0);
// skip-CHECK-NOT-L: modulo by zero
// y[0] %= 0;
// skip-CHECK-NOT-L: modulo by zero
// y[0] %= vec2(1, 0);
}

@compute @workgroup_size(1)
fn main() {
testAbstractInt();
testI32();
testU32();
testAbstractFloat();
testF32();

testI32Compound();
testU32Compound();
testF32Compound();
}

0 comments on commit e99d13b

Please sign in to comment.