Skip to content

Commit

Permalink
[WGSL] Add validation to integer division during code generation
Browse files Browse the repository at this point in the history
https://bugs.webkit.org/show_bug.cgi?id=264106
rdar://117865769

Reviewed by Mike Wyrzykowski.

According to the spec, the runtime behavior for integer division needs to handle
the following two corner cases:
- x / 0 = x
- INT_MIN / -1 = INT_MIN

The latter matches the default Metal behavior, but considering it's technically
undefined behavior, it seems safer to handle it explicitly.

* Source/WebGPU/WGSL/GlobalVariableRewriter.cpp:
(WGSL::RewriteGlobalVariables::getPacking):
* Source/WebGPU/WGSL/Metal/MetalCodeGenerator.cpp:
(WGSL::Metal::metalCodePrologue):
* Source/WebGPU/WGSL/Metal/MetalFunctionWriter.cpp:
(WGSL::Metal::FunctionDefinitionWriter::emitNecessaryHelpers):
(WGSL::Metal::FunctionDefinitionWriter::visit):
* Source/WebGPU/WGSL/TypeCheck.cpp:
(WGSL::TypeChecker::visit):
* Source/WebGPU/WGSL/WGSLShaderModule.h:
(WGSL::ShaderModule::usesDivision const):
(WGSL::ShaderModule::setUsesDivision):
(WGSL::ShaderModule::clearUsesDivision):
* Source/WebGPU/WGSL/tests/lit.cfg:
* Source/WebGPU/WGSL/tests/valid/division.wgsl: Added.

Canonical link: https://commits.webkit.org/270174@main
  • Loading branch information
tadeuzagallo committed Nov 3, 2023
1 parent b01c2d6 commit 3bb0644
Show file tree
Hide file tree
Showing 8 changed files with 137 additions and 1 deletion.
1 change: 1 addition & 0 deletions Source/WebGPU/WGSL/GlobalVariableRewriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,7 @@ auto RewriteGlobalVariables::getPacking(AST::CallExpression& call) -> Packing
);
strideExpression.m_inferredType = m_callGraph.ast().types().u32Type();

m_callGraph.ast().setUsesDivision();
auto& elementCount = m_callGraph.ast().astBuilder().construct<AST::BinaryExpression>(
SourceSpan::empty(),
length,
Expand Down
1 change: 1 addition & 0 deletions Source/WebGPU/WGSL/Metal/MetalCodeGenerator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ static StringView metalCodePrologue()
{
return StringView {
"#include <metal_stdlib>\n"
"#include <metal_types>\n"
"\n"
"using namespace metal;\n"
"\n"_s
Expand Down
47 changes: 47 additions & 0 deletions Source/WebGPU/WGSL/Metal/MetalFunctionWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,24 @@ void FunctionDefinitionWriter::emitNecessaryHelpers()
m_stringBuilder.append(m_indent, "}\n\n");
}

if (m_callGraph.ast().usesDivision()) {
m_callGraph.ast().clearUsesDivision();
m_stringBuilder.append(m_indent, "template<typename T, typename U, typename V = conditional_t<is_scalar_v<U>, T, U>>\n");
m_stringBuilder.append(m_indent, "V __wgslDiv(T lhs, U rhs)\n");
m_stringBuilder.append(m_indent, "{\n");
{
IndentationScope scope(m_indent);
m_stringBuilder.append(m_indent, "auto predicate = V(rhs) == V(0);\n");
m_stringBuilder.append(m_indent, "if constexpr (is_signed_v<U>)\n");
{
IndentationScope scope(m_indent);
m_stringBuilder.append(m_indent, "predicate = predicate || (V(lhs) == V(numeric_limits<T>::lowest()) && V(rhs) == V(-1));\n");
}
m_stringBuilder.append(m_indent, "return lhs / select(V(rhs), V(1), predicate);\n");
}
m_stringBuilder.append(m_indent, "}\n\n");
}

if (m_callGraph.ast().usesPackedStructs()) {
m_callGraph.ast().clearUsesPackedStructs();

Expand Down Expand Up @@ -1415,6 +1433,20 @@ void FunctionDefinitionWriter::visit(AST::BinaryExpression& binary)
}
}

if (binary.operation() == AST::BinaryOperation::Divide) {
auto* resultType = binary.inferredType();
if (auto* vectorType = std::get_if<Types::Vector>(resultType))
resultType = vectorType->element;
if (satisfies(resultType, Constraints::Integer)) {
m_stringBuilder.append("__wgslDiv(");
visit(binary.leftExpression());
m_stringBuilder.append(", ");
visit(binary.rightExpression());
m_stringBuilder.append(")");
return;
}
}

m_stringBuilder.append("(");
visit(binary.leftExpression());
switch (binary.operation()) {
Expand Down Expand Up @@ -1572,6 +1604,21 @@ void FunctionDefinitionWriter::visit(AST::CallStatement& statement)

void FunctionDefinitionWriter::visit(AST::CompoundAssignmentStatement& statement)
{
if (statement.operation() == AST::BinaryOperation::Divide) {
auto* rightType = statement.rightExpression().inferredType();
if (auto* vectorType = std::get_if<Types::Vector>(rightType))
rightType = vectorType->element;
if (satisfies(rightType, Constraints::Integer)) {
visit(statement.leftExpression());
m_stringBuilder.append(" = __wgslDiv(");
visit(statement.leftExpression());
m_stringBuilder.append(", ");
visit(statement.rightExpression());
m_stringBuilder.append(")");
return;
}
}

visit(statement.leftExpression());
m_stringBuilder.append(" ", toASCIILiteral(statement.operation()), "= ");
visit(statement.rightExpression());
Expand Down
4 changes: 4 additions & 0 deletions Source/WebGPU/WGSL/TypeCheck.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -573,6 +573,8 @@ void TypeChecker::visit(AST::CompoundAssignmentStatement& statement)
{
// FIXME: Implement type checking - infer is called to avoid ASSERT in
// TypeChecker::visit(AST::Expression&)
if (statement.operation() == AST::BinaryOperation::Divide)
m_shaderModule.setUsesDivision();
infer(statement.leftExpression());
infer(statement.rightExpression());
}
Expand Down Expand Up @@ -813,6 +815,8 @@ void TypeChecker::visit(AST::IndexAccessExpression& access)

void TypeChecker::visit(AST::BinaryExpression& binary)
{
if (binary.operation() == AST::BinaryOperation::Divide)
m_shaderModule.setUsesDivision();
chooseOverload("operator", binary, toString(binary.operation()), ReferenceWrapperVector<AST::Expression, 2> { binary.leftExpression(), binary.rightExpression() }, { });
}

Expand Down
5 changes: 5 additions & 0 deletions Source/WebGPU/WGSL/WGSLShaderModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ class ShaderModule {
void setUsesWorkgroupUniformLoad() { m_usesWorkgroupUniformLoad = true; }
void clearUsesWorkgroupUniformLoad() { m_usesWorkgroupUniformLoad = false; }

bool usesDivision() const { return m_usesDivision; }
void setUsesDivision() { m_usesDivision = true; }
void clearUsesDivision() { m_usesDivision = false; }

template<typename T>
std::enable_if_t<std::is_base_of_v<AST::Node, T>, void> replace(T* current, T&& replacement)
{
Expand Down Expand Up @@ -218,6 +222,7 @@ class ShaderModule {
bool m_usesPackedStructs { false };
bool m_usesUnpackArray { false };
bool m_usesWorkgroupUniformLoad { false };
bool m_usesDivision { false };
Configuration m_configuration;
AST::Directive::List m_directives;
AST::Function::List m_functions;
Expand Down
3 changes: 2 additions & 1 deletion Source/WebGPU/WGSL/tests/lit.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,15 @@ config.environment['DYLD_FRAMEWORK_PATH'] = port._build_path()
ignored_warnings = [
'-Wno-unused-variable',
'-Wno-missing-braces',
'-Wno-c++17-extensions',
]

config.substitutions.append(('%check', '{}/bin/OutputCheck --comment=".*//" %s'.format(site.getuserbase())))
config.substitutions.append(('%wgslc', '{} %s _ 2>&1'.format(wgslc)))
config.substitutions.append(('%not', 'eval !'))
config.substitutions.append(('%metal-compile', (
"function metal_compile() {"
" set -e -o pipefail;"
" set -e -o pipefail;"
f" {wgslc} --dump-generated-code '%s' \"$1\" > '%t.metal';"
f" xcrun -sdk macosx metal -Werror {' '.join(ignored_warnings)} -c '%t.metal' -o /dev/null;"
"};"
Expand Down
75 changes: 75 additions & 0 deletions Source/WebGPU/WGSL/tests/valid/division.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
// RUN: %metal-compile main

fn testI32()
{
var x: vec2<i32>;
var y: vec2<i32>;

let a = x / y;
let b = x / y[0];
let c = x[0] / y;
let d = x[0] / y[0];
}

fn testU32()
{
var x: vec2<u32>;
var y: vec2<u32>;

let a = x / y;
let b = x / y[0];
let c = x[0] / y;
let d = x[0] / y[0];
}

fn testF32()
{
var x: vec2<f32>;
var y: vec2<f32>;

let a = x / y;
let b = x / y[0];
let c = x[0] / y;
let d = x[0] / y[0];
}

fn testI32Compound()
{
var x: vec2<i32>;
var y: vec2<i32>;

x /= y;
x /= y[0];
x[0] /= y[0];
}

fn testU32Compound()
{
var x: vec2<u32>;
var y: vec2<u32>;

x /= y;
x /= y[0];
x[0] /= y[0];
}

fn testF32Compound()
{
var x: vec2<f32>;
var y: vec2<f32>;

x /= y;
x /= y[0];
x[0] /= y[0];
}

@compute @workgroup_size(1)
fn main() {
testI32();
testU32();
testF32();

testI32Compound();
testU32Compound();
testF32Compound();
}
2 changes: 2 additions & 0 deletions Tools/TestWebKitAPI/Tests/WGSL/MetalGenerationTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ fn main() -> @location(0) vec4<f32> {

EXPECT_TRUE(mslSource.has_value());
EXPECT_EQ(*mslSource, R"(#include <metal_stdlib>
#include <metal_types>
using namespace metal;
Expand All @@ -77,6 +78,7 @@ fn main(@builtin(position) position : vec4<f32>,

EXPECT_TRUE(mslSource.has_value());
EXPECT_EQ(*mslSource, R"(#include <metal_stdlib>
#include <metal_types>
using namespace metal;
Expand Down

0 comments on commit 3bb0644

Please sign in to comment.