Skip to content

Commit

Permalink
[WGSL] Add support for constant structs
Browse files Browse the repository at this point in the history
https://bugs.webkit.org/show_bug.cgi?id=269710
rdar://123238708

Reviewed by Mike Wyrzykowski.

We already had ConstantStruct, but it was only being used for primitive structs
(like the one returned by frexp), so we just needed to create the constant values
and add the lookup logic to FieldAccessExpression.

* Source/WebGPU/WGSL/TypeCheck.cpp:
(WGSL::TypeChecker::visit):
(WGSL::TypeChecker::convertValueImpl):

Canonical link: https://commits.webkit.org/275043@main
  • Loading branch information
tadeuzagallo committed Feb 20, 2024
1 parent 4d69396 commit 3cfd9dd
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 9 deletions.
9 changes: 5 additions & 4 deletions Source/WebGPU/WGSL/ConstantFunctions.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,11 @@ static ConstantValue zeroValue(const Type* type)
result.elements[i] = value;
return result;
},
[&](const Types::Struct&) -> ConstantValue {
// FIXME: this is valid and needs to be implemented, but we don't
// yet have ConstantStruct
RELEASE_ASSERT_NOT_REACHED();
[&](const Types::Struct& structType) -> ConstantValue {
HashMap<String, ConstantValue> constantFields;
for (auto& [key, type] : structType.fields)
constantFields.set(key, zeroValue(type));
return ConstantStruct { WTFMove(constantFields) };
},
[&](const Types::PrimitiveStruct&) -> ConstantValue {
// Primitive structs can't be zero initialized
Expand Down
33 changes: 28 additions & 5 deletions Source/WebGPU/WGSL/TypeCheck.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1001,6 +1001,10 @@ void TypeChecker::visit(AST::FieldAccessExpression& access)
typeError(access.span(), "struct '", *baseType, "' does not have a member called '", access.fieldName(), "'");
return nullptr;
}
if (auto constant = access.base().constantValue()) {
auto& constantStruct = std::get<ConstantStruct>(*constant);
access.setConstantValue(constantStruct.fields.get(access.fieldName().id()));
}
return it->value;
}

Expand Down Expand Up @@ -1205,6 +1209,8 @@ void TypeChecker::visit(AST::CallExpression& call)
return;
}

HashMap<String, ConstantValue> constantFields;
bool isConstant = true;
for (unsigned i = 0; i < numberOfArguments; ++i) {
auto& argument = call.arguments()[i];
auto& member = structType->structure.members()[i];
Expand All @@ -1216,8 +1222,19 @@ void TypeChecker::visit(AST::CallExpression& call)
}
argument.m_inferredType = fieldType;
auto& value = argument.m_constantValue;
if (value.has_value())
convertValue(argument.span(), argument.inferredType(), value);
if (value.has_value()) {
if (convertValue(argument.span(), argument.inferredType(), value))
constantFields.set(member.name(), *value);
else
isConstant = false;
}
}
if (isConstant) {

if (numberOfArguments)
setConstantValue(call, targetBinding->type, ConstantStruct { WTFMove(constantFields) });
else
setConstantValue(call, targetBinding->type, zeroValue(targetBinding->type));
}
inferred(targetBinding->type);
return;
Expand Down Expand Up @@ -2077,9 +2094,15 @@ bool TypeChecker::convertValueImpl(const SourceSpan& span, const Type* type, Con
}
return true;
},
[&](const Types::Struct&) -> bool {
// FIXME: this should be supported
RELEASE_ASSERT_NOT_REACHED();
[&](const Types::Struct& structType) -> bool {
auto& constantStruct = std::get<ConstantStruct>(value);
for (auto& [key, type] : structType.fields) {
auto it = constantStruct.fields.find(key);
RELEASE_ASSERT(it != constantStruct.fields.end());
if (!convertValueImpl(span, type, it->value))
return false;
}
return true;
},
[&](const Types::PrimitiveStruct& primitiveStruct) -> bool {
auto& constantStruct = std::get<ConstantStruct>(value);
Expand Down

0 comments on commit 3cfd9dd

Please sign in to comment.