Skip to content

Commit

Permalink
[WGSL] shader,validation,shader_io,builtins:* is failing
Browse files Browse the repository at this point in the history
https://bugs.webkit.org/show_bug.cgi?id=271254
rdar://125022496

Reviewed by Mike Wyrzykowski.

Validate builtin attributes according to the spec[1].

[1]: https://www.w3.org/TR/WGSL/#builtin-inputs-outputs

* LayoutTests/http/tests/webgpu/webgpu/shader/validation/shader_io/builtins-expected.txt:
* Source/WebGPU/WGSL/AttributeValidator.cpp:
(WGSL::AttributeValidator::validateIO):
(WGSL::AttributeValidator::validateBuiltinIO):
(WGSL::AttributeValidator::validateLocationIO):
(WGSL::AttributeValidator::validateStructIO):
(WGSL::validateIO):
* Source/WebGPU/WGSL/AttributeValidator.h:
* Source/WebGPU/WGSL/TypeCheck.cpp:
(WGSL::TypeChecker::visit):
* Source/WebGPU/WGSL/WGSL.cpp:
(WGSL::staticCheck):
* Source/WebGPU/WGSL/WGSLEnums.cpp:
* Source/WebGPU/WGSL/WGSLEnums.h:

Canonical link: https://commits.webkit.org/276454@main
  • Loading branch information
tadeuzagallo committed Mar 21, 2024
1 parent 28a06fc commit ec673c8
Show file tree
Hide file tree
Showing 7 changed files with 525 additions and 12 deletions.

Large diffs are not rendered by default.

165 changes: 165 additions & 0 deletions Source/WebGPU/WGSL/AttributeValidator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,17 @@

namespace WGSL {

enum class Direction : uint8_t {
Input,
Output,
};

class AttributeValidator : public AST::Visitor {
public:
AttributeValidator(ShaderModule&);

std::optional<FailedCheck> validate();
std::optional<FailedCheck> validateIO();

void visit(AST::Function&) override;
void visit(AST::Parameter&) override;
Expand All @@ -54,6 +60,12 @@ class AttributeValidator : public AST::Visitor {
void validateInterpolation(const SourceSpan&, const std::optional<AST::Interpolation>&, const std::optional<unsigned>&);
void validateInvariant(const SourceSpan&, const std::optional<Builtin>&, bool);

using Builtins = HashSet<Builtin, WTF::IntHash<Builtin>, WTF::StrongEnumHashTraits<Builtin>>;
using Locations = HashSet<uint32_t, DefaultHash<uint32_t>, WTF::UnsignedWithZeroKeyHashTraits<uint32_t>>;
void validateBuiltinIO(const SourceSpan&, const Type*, ShaderStage, Builtin, Direction, Builtins&);
void validateLocationIO(const SourceSpan&, const Type*, ShaderStage, Direction, Locations&);
void validateStructIO(ShaderStage, const Types::Struct&, Direction, Builtins&, Locations&);

template<typename T>
void update(const SourceSpan&, std::optional<T>&, const T&);
void set(const SourceSpan&, bool&);
Expand Down Expand Up @@ -440,6 +452,7 @@ void AttributeValidator::validateInvariant(const SourceSpan& span, const std::op
error(span, "@invariant is only allowed on declarations that have a @builtin(position) attribute");
}


template<typename T>
void AttributeValidator::update(const SourceSpan& span, std::optional<T>& destination, const T& source)
{
Expand All @@ -463,9 +476,161 @@ void AttributeValidator::error(const SourceSpan& span, Arguments&&... arguments)
m_errors.append({ makeString(std::forward<Arguments>(arguments)...), span });
}

std::optional<FailedCheck> AttributeValidator::validateIO()
{
for (auto& entryPoint : m_shaderModule.callGraph().entrypoints()) {
auto& function = entryPoint.function;
Builtins builtins;
Locations locations;
for (auto& parameter : function.parameters()) {
const auto& span = parameter.span();
const auto* type = parameter.typeName().inferredType();

if (auto builtin = parameter.builtin()) {
validateBuiltinIO(span, type, entryPoint.stage, *builtin, Direction::Input, builtins);
continue;
}

if (parameter.location()) {
validateLocationIO(span, type, entryPoint.stage, Direction::Input, locations);
continue;
}

if (auto* structType = std::get_if<Types::Struct>(type)) {
validateStructIO(entryPoint.stage, *structType, Direction::Input, builtins, locations);
continue;
}

error(span, "missing entry point IO attribute on parameter");
}

if (!function.maybeReturnType()) {
if (entryPoint.stage == ShaderStage::Vertex)
error(function.span(), "a vertex shader must include the 'position' builtin in its return type");
continue;
}

builtins.clear();
locations.clear();
const auto& span = function.maybeReturnType()->span();
const auto* type = function.maybeReturnType()->inferredType();

if (auto builtin = function.returnTypeBuiltin())
validateBuiltinIO(span, type, entryPoint.stage, *builtin, Direction::Output, builtins);
else if (function.returnTypeLocation())
validateLocationIO(span, type, entryPoint.stage, Direction::Output, locations);
else if (auto* structType = std::get_if<Types::Struct>(type))
validateStructIO(entryPoint.stage, *structType, Direction::Output, builtins, locations);
else {
error(span, "missing entry point IO attribute on return type");
continue;
}

if (entryPoint.stage == ShaderStage::Vertex && !builtins.contains(Builtin::Position))
error(span, "a vertex shader must include the 'position' builtin in its return type");
}

if (m_errors.isEmpty())
return std::nullopt;
return FailedCheck { WTFMove(m_errors), { } };
}

void AttributeValidator::validateBuiltinIO(const SourceSpan& span, const Type* type, ShaderStage stage, Builtin builtin, Direction direction, Builtins& builtins)
{


#define TYPE_CHECK(__type) \
type != m_shaderModule.types().__type##Type(), *m_shaderModule.types().__type##Type()

#define VEC_CHECK(__count, __elementType) \
auto* vector = std::get_if<Types::Vector>(type); !vector || vector->size != __count || vector->element != m_shaderModule.types().__elementType##Type(), "vec" #__count "<" #__elementType ">"

#define CASE_(__case, __typeCheck, __type) \
case Builtin::__case: \
if (__typeCheck) { \
error(span, "store type of @builtin(", toString(Builtin::__case), ") must be '", __type, "'"); \
return; \
} \

#define CASE(__case, __typeCheck, __stage, __direction) \
CASE_(__case, __typeCheck); \
if (stage != ShaderStage::__stage || direction != Direction::__direction) { \
error(span, "@builtin(", toString(Builtin::__case), ") cannot be used for ", toString(stage), " shader ", direction == Direction::Input ? "input" : "output"); \
return; \
} \
break;

#define CASE2(__case, __typeCheck, __stage1, __direction1, __stage2, __direction2) \
CASE_(__case, __typeCheck); \
if ((stage != ShaderStage::__stage1 || direction != Direction::__direction1) && (stage != ShaderStage::__stage2 || direction != Direction::__direction2)) { \
error(span, "@builtin(", toString(Builtin::__case), ") cannot be used for ", toString(stage), " shader ", direction == Direction::Input ? "input" : "output"); \
return; \
} \
break;

switch (builtin) {
CASE(FragDepth, TYPE_CHECK(f32), Fragment, Output)
CASE(FrontFacing, TYPE_CHECK(bool), Fragment, Input)
CASE(GlobalInvocationId, VEC_CHECK(3, u32), Compute, Input)
CASE(InstanceIndex, TYPE_CHECK(u32), Vertex, Input)
CASE(LocalInvocationId, VEC_CHECK(3, u32), Compute, Input)
CASE(LocalInvocationIndex, TYPE_CHECK(u32), Compute, Input)
CASE(NumWorkgroups, VEC_CHECK(3, u32), Compute, Input)
CASE(SampleIndex, TYPE_CHECK(u32), Fragment, Input)
CASE(VertexIndex, TYPE_CHECK(u32), Vertex, Input)
CASE(WorkgroupId, VEC_CHECK(3, u32), Compute, Input)
CASE2(SampleMask, TYPE_CHECK(u32), Fragment, Input, Fragment, Output)
CASE2(Position, VEC_CHECK(4, f32), Vertex, Output, Fragment, Input)
}

auto result = builtins.add(builtin);
if (!result.isNewEntry)
error(span, "@builtin(", toString(builtin), ") appears multiple times as pipeline input");
}

void AttributeValidator::validateLocationIO(const SourceSpan& span, const Type* type, ShaderStage stage, Direction direction, Locations& locations)
{
// FIXME: implement this
UNUSED_PARAM(span);
UNUSED_PARAM(type);
UNUSED_PARAM(stage);
UNUSED_PARAM(direction);
UNUSED_PARAM(locations);
}

void AttributeValidator::validateStructIO(ShaderStage stage, const Types::Struct& structType, Direction direction, Builtins& builtins, Locations& locations)
{
for (auto& member : structType.structure.members()) {
const auto& span = member.span();
const auto* type = member.type().inferredType();

if (auto builtin = member.builtin()) {
validateBuiltinIO(span, type, stage, *builtin, direction, builtins);
continue;
}

if (member.location()) {
validateLocationIO(span, type, stage, direction, locations);
continue;
}

if (auto* structType = std::get_if<Types::Struct>(member.type().inferredType())) {
error(span, "nested structures cannot be used for entry point IO");
continue;
}

error(span, "missing entry point IO attribute");
}
}

std::optional<FailedCheck> validateAttributes(ShaderModule& shaderModule)
{
return AttributeValidator(shaderModule).validate();
}

std::optional<FailedCheck> validateIO(ShaderModule& shaderModule)
{
return AttributeValidator(shaderModule).validateIO();
}

} // namespace WGSL
1 change: 1 addition & 0 deletions Source/WebGPU/WGSL/AttributeValidator.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,6 @@ namespace WGSL {
class ShaderModule;

std::optional<FailedCheck> validateAttributes(ShaderModule&);
std::optional<FailedCheck> validateIO(ShaderModule&);

} // namespace WGSL
4 changes: 2 additions & 2 deletions Source/WebGPU/WGSL/TypeCheck.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1151,7 +1151,7 @@ void TypeChecker::visit(AST::IndexAccessExpression& access)

void TypeChecker::visit(AST::BinaryExpression& binary)
{
chooseOverload("operator", binary, toString(binary.operation()), ReferenceWrapperVector<AST::Expression, 2> { binary.leftExpression(), binary.rightExpression() }, { });
chooseOverload("operator", binary, toASCIILiteral(binary.operation()), ReferenceWrapperVector<AST::Expression, 2> { binary.leftExpression(), binary.rightExpression() }, { });

const char* operationName = nullptr;
if (binary.operation() == AST::BinaryOperation::Divide)
Expand Down Expand Up @@ -1600,7 +1600,7 @@ void TypeChecker::visit(AST::UnaryExpression& unary)
inferred(m_types.referenceType(pointer->addressSpace, pointer->element, pointer->accessMode));
return;
}
chooseOverload("operator", unary, toString(unary.operation()), ReferenceWrapperVector<AST::Expression, 1> { unary.expression() }, { });
chooseOverload("operator", unary, toASCIILiteral(unary.operation()), ReferenceWrapperVector<AST::Expression, 1> { unary.expression() }, { });
}

// Literal Expressions
Expand Down
1 change: 1 addition & 0 deletions Source/WebGPU/WGSL/WGSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ std::variant<SuccessfulCheck, FailedCheck> staticCheck(const String& wgsl, const
CHECK_PASS(typeCheck, shaderModule);
CHECK_PASS(validateAttributes, shaderModule);
RUN_PASS(buildCallGraph, shaderModule);
CHECK_PASS(validateIO, shaderModule);

Vector<Warning> warnings { };
return std::variant<SuccessfulCheck, FailedCheck>(std::in_place_type<SuccessfulCheck>, WTFMove(warnings), WTFMove(shaderModule));
Expand Down
26 changes: 17 additions & 9 deletions Source/WebGPU/WGSL/WGSLEnums.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,22 +37,28 @@ namespace WGSL {
#define CONTINUATION(args...) args RPAREN
#define EXPAND(x) x

#define ENUM_DEFINE_PRINT_INTERNAL_CASE_(__type, __name, __string, ...) \
#define ENUM_DEFINE_TO_STRING_CASE_(__type, __name, __string, ...) \
case __type::__name: \
out.print(#__string); \
return #__string##_s; \
break;

#define ENUM_DEFINE_PRINT_INTERNAL_CASE(__name) \
ENUM_DEFINE_PRINT_INTERNAL_CASE_ LPAREN __name, CONTINUATION
#define ENUM_DEFINE_TO_STRING_CASE(__name) \
ENUM_DEFINE_TO_STRING_CASE_ LPAREN __name, CONTINUATION

#define ENUM_DEFINE_PRINT_INTERNAL(__name) \
void printInternal(PrintStream& out, __name __value) \
#define ENUM_DEFINE_TO_STRING(__name) \
ASCIILiteral toString(__name __value) \
{ \
switch (__value) { \
EXPAND(ENUM_##__name(ENUM_DEFINE_PRINT_INTERNAL_CASE LPAREN __name RPAREN)) \
EXPAND(ENUM_##__name(ENUM_DEFINE_TO_STRING_CASE LPAREN __name RPAREN)) \
} \
}

#define ENUM_DEFINE_PRINT_INTERNAL(__name) \
void printInternal(PrintStream& out, __name __value) \
{ \
out.print(toString(__value)); \
}

#define ENUM_DEFINE_PARSE_ENTRY_(__type, __name, __string, ...) \
{ #__string, __type::__name },

Expand All @@ -70,6 +76,7 @@ namespace WGSL {
}

#define ENUM_DEFINE(__name) \
ENUM_DEFINE_TO_STRING(__name) \
ENUM_DEFINE_PRINT_INTERNAL(__name) \
ENUM_DEFINE_PARSE(__name)

Expand All @@ -86,8 +93,9 @@ ENUM_DEFINE(LanguageFeature);

#undef ENUM_DEFINE
#undef ENUM_DEFINE_PRINT_INTERNAL
#undef ENUM_DEFINE_PRINT_INTERNAL_CASE
#undef ENUM_DEFINE_PRINT_INTERNAL_CASE_
#undef ENUM_DEFINE_TO_STRING
#undef ENUM_DEFINE_TO_STRING_CASE
#undef ENUM_DEFINE_TO_STRING_CASE_
#undef EXPAND
#undef CONTINUATION
#undef RPAREN
Expand Down
5 changes: 5 additions & 0 deletions Source/WebGPU/WGSL/WGSLEnums.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#pragma once

namespace WTF {
class ASCIILiteral;
class PrintStream;
class String;
}
Expand Down Expand Up @@ -114,6 +115,9 @@ namespace WGSL {
#define ENUM_DECLARE_PRINT_INTERNAL(__name) \
void printInternal(WTF::PrintStream& out, __name)

#define ENUM_DECLARE_TO_STRING(__name) \
WTF::ASCIILiteral toString(__name)

#define ENUM_DECLARE_PARSE(__name) \
const __name* parse##__name(const WTF::String&)

Expand All @@ -122,6 +126,7 @@ namespace WGSL {
ENUM_##__name(ENUM_DECLARE_VALUE) \
}; \
ENUM_DECLARE_PRINT_INTERNAL(__name); \
ENUM_DECLARE_TO_STRING(__name); \
ENUM_DECLARE_PARSE(__name);

ENUM_DECLARE(AddressSpace);
Expand Down

0 comments on commit ec673c8

Please sign in to comment.