Skip to content

Commit

Permalink
[WGSL] Fix failing wgslc tests
Browse files Browse the repository at this point in the history
https://bugs.webkit.org/show_bug.cgi?id=274479
rdar://128485179

Reviewed by Mike Wyrzykowski.

There were 2 issues after 278991@main:
- we emitted definitions for `__unpack(PackedVec3)` even when we didn't emit
  the PackedVec3 struct
- in a previous patch we tried skipping some explicit pack/unpack calls for
  vec3/packed_vec3, but that doesn't work for PackedVec3, so we can no longer
  skip these calls.

* Source/WebGPU/WGSL/AST/ASTForward.h:
* Source/WebGPU/WGSL/AST/ASTParameter.h:
* Source/WebGPU/WGSL/AST/ASTVariable.h:
* Source/WebGPU/WGSL/GlobalVariableRewriter.cpp:
(WGSL::RewriteGlobalVariables::visitCallee):
(WGSL::RewriteGlobalVariables::visit):
(WGSL::RewriteGlobalVariables::pack):
(WGSL::RewriteGlobalVariables::packStructResource):
(WGSL::RewriteGlobalVariables::packArrayResource):
(WGSL::RewriteGlobalVariables::insertMaterializations):
* Source/WebGPU/WGSL/Metal/MetalFunctionWriter.cpp:
(WGSL::Metal::FunctionDefinitionWriter::emitNecessaryHelpers):
(WGSL::Metal::FunctionDefinitionWriter::visit):
(WGSL::Metal::FunctionDefinitionWriter::shouldPackType const):
(WGSL::Metal::FunctionDefinitionWriter::emitPackedVector):
(WGSL::Metal::FunctionDefinitionWriter::serializeVariable):
* Source/WebGPU/WGSL/WGSLShaderModule.h:
(WGSL::ShaderModule::usesPackVector const):
(WGSL::ShaderModule::setUsesPackVector):
(WGSL::ShaderModule::clearUsesPackVector):
(WGSL::ShaderModule::usesUnpackVector const):
(WGSL::ShaderModule::setUsesUnpackVector):
(WGSL::ShaderModule::clearUsesUnpackVector):
* Source/WebGPU/WGSL/tests/valid/packing.wgsl:

Canonical link: https://commits.webkit.org/279141@main
  • Loading branch information
tadeuzagallo committed May 22, 2024
1 parent 625de07 commit 8cb6094
Show file tree
Hide file tree
Showing 7 changed files with 136 additions and 71 deletions.
1 change: 1 addition & 0 deletions Source/WebGPU/WGSL/AST/ASTForward.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,5 +113,6 @@ enum class ParameterRole : uint8_t;
enum class StructureRole : uint8_t;
enum class UnaryOperation : uint8_t;
enum class VariableFlavor : uint8_t;
enum class VariableRole : uint8_t;

} // namespace WGSL::AST
1 change: 1 addition & 0 deletions Source/WebGPU/WGSL/AST/ASTParameter.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ enum class ParameterRole : uint8_t {
UserDefined,
StageIn,
BindGroup,
PackedResource,
};

class Parameter final : public Node {
Expand Down
13 changes: 12 additions & 1 deletion Source/WebGPU/WGSL/AST/ASTVariable.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ enum class VariableFlavor : uint8_t {
Var,
};

enum class VariableRole : uint8_t {
UserDefined,
PackedResource,
};

class Variable final : public Declaration {
WGSL_AST_BUILDER_NODE(Variable);
friend AttributeValidator;
Expand All @@ -60,6 +65,10 @@ class Variable final : public Declaration {
NodeKind kind() const override;
VariableFlavor flavor() const { return m_flavor; };
VariableFlavor& flavor() { return m_flavor; };

VariableRole role() const { return m_role; }
VariableRole& role() { return m_role; }

Identifier& name() override { return m_name; }
Identifier& originalName() { return m_originalName; }
Attribute::List& attributes() { return m_attributes; }
Expand All @@ -85,7 +94,7 @@ class Variable final : public Declaration {
: Variable(span, flavor, WTFMove(name), { }, type, initializer, { })
{ }

Variable(SourceSpan span, VariableFlavor flavor, Identifier&& name, VariableQualifier::Ptr qualifier, Expression::Ptr type, Expression::Ptr initializer, Attribute::List&& attributes)
Variable(SourceSpan span, VariableFlavor flavor, Identifier&& name, VariableQualifier::Ptr qualifier, Expression::Ptr type, Expression::Ptr initializer, Attribute::List&& attributes, VariableRole role = VariableRole::UserDefined)
: Declaration(span)
, m_name(WTFMove(name))
, m_originalName(m_name)
Expand All @@ -94,6 +103,7 @@ class Variable final : public Declaration {
, m_type(type)
, m_initializer(initializer)
, m_flavor(flavor)
, m_role(role)
{
ASSERT(m_type || m_initializer);
}
Expand All @@ -107,6 +117,7 @@ class Variable final : public Declaration {
Expression::Ptr m_type;
Expression::Ptr m_initializer;
VariableFlavor m_flavor;
VariableRole m_role;
Expression::Ptr m_referenceType { nullptr };

// Computed properties
Expand Down
56 changes: 28 additions & 28 deletions Source/WebGPU/WGSL/GlobalVariableRewriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ class RewriteGlobalVariables : public AST::Visitor {
AST::Expression* m_bufferLengthReferenceType { nullptr };
AST::Function* m_currentFunction { nullptr };
HashMap<std::pair<unsigned, unsigned>, unsigned> m_globalsUsingDynamicOffset;
HashSet<AST::Expression*> m_doNotUnpack;
};

std::optional<Error> RewriteGlobalVariables::run()
Expand Down Expand Up @@ -212,12 +213,13 @@ void RewriteGlobalVariables::visitCallee(const CallGraph::Callee& callee)
}
ASSERT(type);

auto parameterRole = global.declaration->role() == AST::VariableRole::PackedResource ? AST::ParameterRole::PackedResource : AST::ParameterRole::UserDefined;
m_shaderModule.append(callee.target->parameters(), m_shaderModule.astBuilder().construct<AST::Parameter>(
SourceSpan::empty(),
AST::Identifier::make(read),
*type,
AST::Attribute::List { },
AST::ParameterRole::UserDefined
parameterRole
));
}

Expand Down Expand Up @@ -250,6 +252,7 @@ void RewriteGlobalVariables::visitCallee(const CallGraph::Callee& callee)
);
global.m_inferredType = it->value.declaration->storeType();
m_shaderModule.append(call->arguments(), global);
m_doNotUnpack.add(&global);
}
}

Expand Down Expand Up @@ -355,15 +358,13 @@ void RewriteGlobalVariables::visit(AST::AssignmentStatement& statement)
ASSERT(lhsPacking != Packing::Either);
if (lhsPacking == Packing::PackedVec3)
lhsPacking = Packing::Either;
else
lhsPacking = static_cast<Packing>(lhsPacking | Packing::Vec3);
pack(lhsPacking, statement.rhs());
}

void RewriteGlobalVariables::visit(AST::VariableStatement& statement)
{
if (auto* initializer = statement.variable().maybeInitializer())
pack(static_cast<Packing>(Packing::Unpacked | Packing::Vec3), *initializer);
pack(static_cast<Packing>(Packing::Unpacked), *initializer);
}

void RewriteGlobalVariables::visit(AST::PhonyAssignmentStatement& statement)
Expand All @@ -378,6 +379,9 @@ void RewriteGlobalVariables::visit(AST::Expression& expression)

Packing RewriteGlobalVariables::pack(Packing expectedPacking, AST::Expression& expression)
{
if (m_doNotUnpack.contains(&expression))
return expectedPacking;

const auto& visitAndReplace = [&](auto& expression) -> Packing {
auto packing = getPacking(expression);
if (expectedPacking & packing)
Expand All @@ -387,9 +391,11 @@ Packing RewriteGlobalVariables::pack(Packing expectedPacking, AST::Expression& e
if (auto* referenceType = std::get_if<Types::Reference>(type))
type = referenceType->element;
ASCIILiteral operation;
if (std::holds_alternative<Types::Struct>(*type))
if (std::holds_alternative<Types::Struct>(*type)) {
if (!type->isConstructible())
return packing;
operation = packing & Packing::Packed ? "__unpack"_s : "__pack"_s;
else if (std::holds_alternative<Types::Array>(*type)) {
} else if (std::holds_alternative<Types::Array>(*type)) {
// array of vec3 can be implicitly converted
if (packing & Packing::Vec3)
m_shaderModule.setUsesPackedVec3();
Expand All @@ -401,26 +407,12 @@ Packing RewriteGlobalVariables::pack(Packing expectedPacking, AST::Expression& e
m_shaderModule.setUsesPackArray();
}
} else {
ASSERT(std::holds_alternative<Types::Vector>(*type));
auto& vector = std::get<Types::Vector>(*type);
ASSERT(std::holds_alternative<Types::Primitive>(*vector.element));
switch (std::get<Types::Primitive>(*vector.element).kind) {
case Types::Primitive::AbstractInt:
case Types::Primitive::I32:
operation = packing & Packing::Packed ? "int3"_s : "packed_int3"_s;
break;
case Types::Primitive::U32:
operation = packing & Packing::Packed ? "uint3"_s : "packed_uint3"_s;
break;
case Types::Primitive::AbstractFloat:
case Types::Primitive::F32:
operation = packing & Packing::Packed ? "float3"_s : "packed_float3"_s;
break;
case Types::Primitive::F16:
operation = packing & Packing::Packed ? "half3"_s : "packed_half3"_s;
break;
default:
RELEASE_ASSERT_NOT_REACHED();
if (packing & Packing::Packed) {
operation = "__unpack"_s;
m_shaderModule.setUsesUnpackVector();
} else {
operation = "__pack"_s;
m_shaderModule.setUsesPackVector();
}
}
RELEASE_ASSERT(!operation.isNull());
Expand Down Expand Up @@ -762,13 +754,19 @@ void RewriteGlobalVariables::packStructResource(AST::Variable& global, const Typ
auto& namedTypeName = downcast<AST::IdentifierExpression>(*global.maybeTypeName());
m_shaderModule.replace(namedTypeName, packedType);
updateReference(global, packedType);
m_shaderModule.replace(&global.role(), AST::VariableRole::PackedResource);
}

void RewriteGlobalVariables::packArrayResource(AST::Variable& global, const Types::Array* arrayType)
{
const Type* packedArrayType = packArrayType(arrayType);
if (!packedArrayType)
if (!packedArrayType) {
if (arrayType->element->packing() & Packing::Vec3) {
m_shaderModule.setUsesPackedVec3();
m_shaderModule.replace(&global.role(), AST::VariableRole::PackedResource);
}
return;
}

const Type* packedStructType = std::get<Types::Array>(*packedArrayType).element;
auto& packedType = m_shaderModule.astBuilder().construct<AST::IdentifierExpression>(
Expand All @@ -787,6 +785,7 @@ void RewriteGlobalVariables::packArrayResource(AST::Variable& global, const Type

m_shaderModule.replace(arrayTypeName, packedArrayTypeName);
updateReference(global, packedArrayTypeName);
m_shaderModule.replace(&global.role(), AST::VariableRole::PackedResource);
}

void RewriteGlobalVariables::updateReference(AST::Variable& global, AST::Expression& packedType)
Expand Down Expand Up @@ -1806,7 +1805,8 @@ void RewriteGlobalVariables::insertMaterializations(AST::Function& function, con
nullptr,
global->declaration->maybeReferenceType(),
initializer,
AST::Attribute::List { }
AST::Attribute::List { },
AST::VariableRole::PackedResource
);

auto& variableStatement = m_shaderModule.astBuilder().construct<AST::VariableStatement>(SourceSpan::empty(), variable);
Expand Down
Loading

0 comments on commit 8cb6094

Please sign in to comment.