Skip to content

Commit

Permalink
[WGSL] Validate functions missing a return statement
Browse files Browse the repository at this point in the history
https://bugs.webkit.org/show_bug.cgi?id=271482
rdar://124143123

Reviewed by Mike Wyrzykowski.

Implement the behavior analysis from the spec[1] and check if functions with a
return type are missing a return statement.

[1]: https://www.w3.org/TR/WGSL/#behaviors

* Source/WebGPU/WGSL/TypeCheck.cpp:
(WGSL::TypeChecker::visit):
(WGSL::TypeChecker::analyze):
(WGSL::TypeChecker::analyzeStatements):
* Source/WebGPU/WGSL/tests/invalid/function-call.wgsl:

Canonical link: https://commits.webkit.org/276630@main
  • Loading branch information
tadeuzagallo committed Mar 25, 2024
1 parent 7cb7a6e commit 2fa5e50
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 1 deletion.
125 changes: 125 additions & 0 deletions Source/WebGPU/WGSL/TypeCheck.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
#include "Types.h"
#include "WGSLShaderModule.h"
#include <wtf/DataLog.h>
#include <wtf/OptionSet.h>
#include <wtf/SetForScope.h>
#include <wtf/SortedArrayMap.h>

Expand Down Expand Up @@ -69,6 +70,14 @@ struct Binding {
std::optional<ConstantValue> constantValue;
};

enum class Behavior : uint8_t {
Return = 1 << 0,
Break = 1 << 1,
Continue = 1 << 2,
Next = 1 << 3,
};
using Behaviors = OptionSet<Behavior>;

static ASCIILiteral bindingKindToString(Binding::Kind kind)
{
switch (kind) {
Expand Down Expand Up @@ -211,6 +220,15 @@ class TypeChecker : public AST::ScopedVisitor<Binding> {
template<typename Node>
void setConstantValue(Node&, const Type*, const ConstantValue&);

Behaviors analyze(AST::Statement&);
Behaviors analyze(AST::CompoundStatement&);
Behaviors analyze(AST::ForStatement&);
Behaviors analyze(AST::IfStatement&);
Behaviors analyze(AST::LoopStatement&);
Behaviors analyze(AST::SwitchStatement&);
Behaviors analyze(AST::WhileStatement&);
Behaviors analyzeStatements(AST::Statement::List&);

ShaderModule& m_shaderModule;
const Type* m_inferredType { nullptr };
const Type* m_returnType { nullptr };
Expand Down Expand Up @@ -693,6 +711,10 @@ void TypeChecker::visit(AST::Function& function)
for (unsigned i = 0; i < parameters.size(); ++i)
introduceValue(function.parameters()[i].name(), parameters[i]);
Base::visit(function.body());

auto behaviours = analyze(function.body());
if (!behaviours.contains(Behavior::Return) && function.maybeReturnType())
typeError(InferBottom::No, function.span(), "missing return at end of function");
}

const Type* functionType = m_types.functionType(WTFMove(parameters), m_returnType, mustUse);
Expand Down Expand Up @@ -1961,6 +1983,109 @@ const Type* TypeChecker::infer(AST::Expression& expression, Evaluation evaluatio
return inferredType;
}

Behaviors TypeChecker::analyze(AST::Statement& statement)
{
switch (statement.kind()) {
case AST::NodeKind::AssignmentStatement:
case AST::NodeKind::BreakStatement:
case AST::NodeKind::CallStatement:
case AST::NodeKind::CompoundAssignmentStatement:
case AST::NodeKind::ConstAssertStatement:
case AST::NodeKind::DecrementIncrementStatement:
case AST::NodeKind::DiscardStatement:
case AST::NodeKind::PhonyAssignmentStatement:
case AST::NodeKind::StaticAssertStatement:
case AST::NodeKind::VariableStatement:
return Behavior::Next;
case AST::NodeKind::ReturnStatement:
return Behavior::Return;
case AST::NodeKind::ContinueStatement:
return Behavior::Continue;
case AST::NodeKind::CompoundStatement:
return analyze(uncheckedDowncast<AST::CompoundStatement>(statement));
case AST::NodeKind::ForStatement:
return analyze(uncheckedDowncast<AST::ForStatement>(statement));
case AST::NodeKind::IfStatement:
return analyze(uncheckedDowncast<AST::IfStatement>(statement));
case AST::NodeKind::LoopStatement:
return analyze(uncheckedDowncast<AST::LoopStatement>(statement));
case AST::NodeKind::SwitchStatement:
return analyze(uncheckedDowncast<AST::SwitchStatement>(statement));
case AST::NodeKind::WhileStatement:
return analyze(uncheckedDowncast<AST::WhileStatement>(statement));
default:
RELEASE_ASSERT_NOT_REACHED();
}
}

Behaviors TypeChecker::analyze(AST::CompoundStatement& statement)
{
return analyzeStatements(statement.statements());
}

Behaviors TypeChecker::analyze(AST::ForStatement& statement)
{
auto behaviors = Behaviors({ Behavior::Next, Behavior::Break, Behavior::Continue });
behaviors.add(analyze(statement.body()));
behaviors.remove({ Behavior::Break, Behavior::Continue });
return behaviors;
}

Behaviors TypeChecker::analyze(AST::IfStatement& statement)
{
auto behaviors = analyze(statement.trueBody());
if (auto* elseBody = statement.maybeFalseBody())
behaviors.add(analyze(*elseBody));
return behaviors;
}

Behaviors TypeChecker::analyze(AST::LoopStatement& statement)
{
auto behaviors = analyzeStatements(statement.body());
if (auto& continuing = statement.continuing()) {
behaviors.add(analyzeStatements(continuing->body));
if (auto* breakIf = continuing->breakIf)
behaviors.add({ Behavior::Break, Behavior:: Continue });
}
if (behaviors.contains(Behavior::Break))
behaviors.remove({ Behavior::Break, Behavior::Continue });
else
behaviors.remove({ Behavior::Next, Behavior::Continue });
return behaviors;
}

Behaviors TypeChecker::analyze(AST::SwitchStatement& statement)
{
auto behaviors = analyze(statement.defaultClause().body);
for (auto& clause : statement.clauses())
behaviors.add(analyze(clause.body));
if (behaviors.contains(Behavior::Break)) {
behaviors.remove(Behavior::Break);
behaviors.add(Behavior::Break);
}
return behaviors;
}

Behaviors TypeChecker::analyze(AST::WhileStatement& statement)
{
auto behaviors = Behaviors({ Behavior::Next, Behavior::Break });
behaviors.add(analyze(statement.body()));
behaviors.remove({ Behavior::Break, Behavior::Continue });
return behaviors;
}

Behaviors TypeChecker::analyzeStatements(AST::Statement::List& statements)
{
auto behaviors = Behaviors(Behavior::Next);
for (auto& statement : statements) {
behaviors.remove(Behavior::Next);
behaviors.add(analyze(statement));
if (!behaviors.contains(Behavior::Next))
break;
}
return behaviors;
}

const Type* TypeChecker::check(AST::Expression& expression, Constraint constraint, Evaluation evaluation)
{
auto* type = infer(expression, evaluation);
Expand Down
4 changes: 3 additions & 1 deletion Source/WebGPU/WGSL/tests/invalid/function-call.wgsl
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// RUN: %not %wgslc | %check

fn f1(x: f32) -> f32 { return x; }
fn f1(x: f32) -> f32 {
// CHECK-L: missing return at end of function
}

fn f2() {
// CHECK-L: unresolved call target 'f0'
Expand Down

0 comments on commit 2fa5e50

Please sign in to comment.