Skip to content

Commit

Permalink
Adopt dynamicDowncast<> in some WebGPU code
Browse files Browse the repository at this point in the history
https://bugs.webkit.org/show_bug.cgi?id=270409

Reviewed by Tadeu Zagallo.

For security & performance.

* Source/WebGPU/WGSL/Metal/MetalFunctionWriter.cpp:
(WGSL::Metal::FunctionDefinitionWriter::write):
(WGSL::Metal::FunctionDefinitionWriter::visit):
(WGSL::Metal::fragDepthIdentifierForFunction):
* Source/WebGPU/WGSL/PointerRewriter.cpp:
(WGSL::PointerRewriter::visit):
* Source/WebGPU/WGSL/TypeCheck.cpp:
(WGSL::TypeChecker::visit):
(WGSL::TypeChecker::chooseOverload):
(WGSL::TypeChecker::resolve):
(WGSL::TypeChecker::texelFormat):
(WGSL::TypeChecker::accessMode):
(WGSL::TypeChecker::addressSpace):
* Source/WebGPU/WGSL/WGSL.cpp:
(WGSL::evaluate):
* Source/WebGPU/WebGPU/ShaderModule.mm:
(WebGPU::ShaderModule::parseFragmentReturnType):
(WebGPU::ShaderModule::ShaderModule):

Canonical link: https://commits.webkit.org/275614@main
  • Loading branch information
annevk committed Mar 4, 2024
1 parent 094154f commit e827b50
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 67 deletions.
42 changes: 17 additions & 25 deletions Source/WebGPU/WGSL/Metal/MetalFunctionWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -164,15 +164,15 @@ void FunctionDefinitionWriter::write()
emitNecessaryHelpers();

for (auto& declaration : m_callGraph.ast().declarations()) {
if (is<AST::Structure>(declaration))
visit(downcast<AST::Structure>(declaration));
else if (is<AST::Variable>(declaration))
visitGlobal(downcast<AST::Variable>(declaration));
if (auto* structure = dynamicDowncast<AST::Structure>(declaration))
visit(*structure);
else if (auto* variable = dynamicDowncast<AST::Variable>(declaration))
visitGlobal(*variable);
}

for (auto& declaration : m_callGraph.ast().declarations()) {
if (is<AST::Structure>(declaration))
generatePackingHelpers(downcast<AST::Structure>(declaration));
if (auto* structure = dynamicDowncast<AST::Structure>(declaration))
generatePackingHelpers(*structure);
}

for (auto& entryPoint : m_callGraph.entrypoints())
Expand Down Expand Up @@ -1108,17 +1108,12 @@ void FunctionDefinitionWriter::visit(const Type* type, AST::Expression& expressi
return;
}

switch (expression.kind()) {
case AST::NodeKind::CallExpression:
visit(type, downcast<AST::CallExpression>(expression));
break;
case AST::NodeKind::IdentityExpression:
visit(type, downcast<AST::IdentityExpression>(expression).expression());
break;
default:
if (auto* call = dynamicDowncast<AST::CallExpression>(expression))
visit(type, *call);
else if (auto* identity = dynamicDowncast<AST::IdentityExpression>(expression))
visit(type, identity->expression());
else
AST::Visitor::visit(expression);
break;
}
}

static void visitArguments(FunctionDefinitionWriter* writer, AST::CallExpression& call, unsigned startOffset = 0)
Expand Down Expand Up @@ -1792,9 +1787,8 @@ static void emitUnpack4xU8(FunctionDefinitionWriter* writer, AST::CallExpression

void FunctionDefinitionWriter::visit(const Type* type, AST::CallExpression& call)
{
if (is<AST::ElaboratedTypeExpression>(call.target())) {
auto& base = downcast<AST::ElaboratedTypeExpression>(call.target()).base();
if (base == "bitcast"_s) {
if (auto* target = dynamicDowncast<AST::ElaboratedTypeExpression>(call.target())) {
if (target->base() == "bitcast"_s) {
emitBitcast(this, call);
return;
}
Expand Down Expand Up @@ -1825,7 +1819,7 @@ void FunctionDefinitionWriter::visit(const Type* type, AST::CallExpression& call
return;
}

if (is<AST::IdentifierExpression>(call.target())) {
if (auto* target = dynamicDowncast<AST::IdentifierExpression>(call.target())) {
static constexpr std::pair<ComparableASCIILiteral, void(*)(FunctionDefinitionWriter*, AST::CallExpression&)> builtinMappings[] {
{ "__dynamicOffset", emitDynamicOffset },
{ "arrayLength", emitArrayLength },
Expand Down Expand Up @@ -1873,7 +1867,7 @@ void FunctionDefinitionWriter::visit(const Type* type, AST::CallExpression& call
{ "workgroupUniformLoad", emitWorkgroupUniformLoad },
};
static constexpr SortedArrayMap builtins { builtinMappings };
const auto& targetName = downcast<AST::IdentifierExpression>(call.target()).identifier().id();
const auto& targetName = target->identifier().id();
if (auto mappedBuiltin = builtins.get(targetName)) {
mappedBuiltin(this, call);
return;
Expand Down Expand Up @@ -2253,10 +2247,8 @@ static std::optional<std::pair<String, String>> fragDepthIdentifierForFunction(A
if (member.builtin() == WGSL::Builtin::FragDepth)
return std::make_pair(returnStruct->structure.name(), member.name());
for (auto& attribute : member.attributes()) {
if (attribute.kind() != AST::NodeKind::BuiltinAttribute)
continue;
auto& builtinAttribute = downcast<AST::BuiltinAttribute>(attribute);
if (builtinAttribute.builtin() == WGSL::Builtin::FragDepth)
auto* builtinAttribute = dynamicDowncast<AST::BuiltinAttribute>(attribute);
if (builtinAttribute && builtinAttribute->builtin() == WGSL::Builtin::FragDepth)
return std::make_pair(returnStruct->structure.name(), member.name());
}
}
Expand Down
8 changes: 3 additions & 5 deletions Source/WebGPU/WGSL/PointerRewriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,14 +153,12 @@ void PointerRewriter::visit(AST::UnaryExpression& unary)
AST::Expression* nested = &unary.expression();
while (is<AST::IdentityExpression>(*nested))
nested = &downcast<AST::IdentityExpression>(*nested).expression();
if (!is<AST::UnaryExpression>(*nested))
return;

auto& nestedUnary = downcast<AST::UnaryExpression>(*nested);
if (nestedUnary.operation() != AST::UnaryOperation::AddressOf)
auto* nestedUnary = dynamicDowncast<AST::UnaryExpression>(*nested);
if (!nestedUnary || nestedUnary->operation() != AST::UnaryOperation::AddressOf)
return;

m_callGraph.ast().replace(unary, nestedUnary.expression());
m_callGraph.ast().replace(unary, nestedUnary->expression());
}

void rewritePointers(CallGraph& callGraph)
Expand Down
34 changes: 15 additions & 19 deletions Source/WebGPU/WGSL/TypeCheck.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1370,35 +1370,34 @@ void TypeChecker::visit(AST::CallExpression& call)
return;
}

if (is<AST::ArrayTypeExpression>(target)) {
AST::ArrayTypeExpression& array = downcast<AST::ArrayTypeExpression>(target);
if (auto* array = dynamicDowncast<AST::ArrayTypeExpression>(target)) {
const Type* elementType = nullptr;
unsigned elementCount;

if (array.maybeElementType()) {
if (!array.maybeElementCount()) {
if (array->maybeElementType()) {
if (!array->maybeElementCount()) {
typeError(call.span(), "cannot construct a runtime-sized array");
return;
}
elementType = resolve(*array.maybeElementType());
auto* elementCountType = infer(*array.maybeElementCount(), m_evaluation);
elementType = resolve(*array->maybeElementType());
auto* elementCountType = infer(*array->maybeElementCount(), m_evaluation);

if (isBottom(elementType) || isBottom(elementCountType)) {
inferred(m_types.bottomType());
return;
}

if (!unify(m_types.i32Type(), elementCountType) && !unify(m_types.u32Type(), elementCountType)) {
typeError(array.span(), "array count must be an i32 or u32 value, found '", *elementCountType, "'");
typeError(array->span(), "array count must be an i32 or u32 value, found '", *elementCountType, "'");
return;
}

if (!elementType->isConstructible()) {
typeError(array.span(), "'", *elementType, "' cannot be used as an element type of an array");
typeError(array->span(), "'", *elementType, "' cannot be used as an element type of an array");
return;
}

auto constantValue = array.maybeElementCount()->constantValue();
auto constantValue = array->maybeElementCount()->constantValue();
if (!constantValue) {
typeError(call.span(), "array must have constant size in order to be constructed");
return;
Expand Down Expand Up @@ -1431,7 +1430,7 @@ void TypeChecker::visit(AST::CallExpression& call)
argument.m_inferredType = elementType;
}
} else {
ASSERT(!array.maybeElementCount());
ASSERT(!array->maybeElementCount());
elementCount = call.arguments().size();
if (!elementCount) {
typeError(call.span(), "cannot infer array element type from constructor");
Expand All @@ -1446,7 +1445,7 @@ void TypeChecker::visit(AST::CallExpression& call)
elementType = argumentType;

if (!elementType->isConstructible()) {
typeError(array.span(), "'", *elementType, "' cannot be used as an element type of an array");
typeError(array->span(), "'", *elementType, "' cannot be used as an element type of an array");
return;
}

Expand Down Expand Up @@ -1881,9 +1880,9 @@ const Type* TypeChecker::chooseOverload(const char* kind, AST::Expression& expre
callArguments[i].m_inferredType = overload->parameters[i];
inferred(overload->result);

if (it->value.kind == OverloadedDeclaration::Constructor && is<AST::CallExpression>(expression)) {
auto& call = downcast<AST::CallExpression>(expression);
call.m_isConstructor = true;
if (it->value.kind == OverloadedDeclaration::Constructor) {
if (auto* call = dynamicDowncast<AST::CallExpression>(expression))
call->m_isConstructor = true;
}

unsigned argumentCount = callArguments.size();
Expand Down Expand Up @@ -1975,8 +1974,8 @@ const Type* TypeChecker::check(AST::Expression& expression, Constraint constrain
const Type* TypeChecker::resolve(AST::Expression& type)
{
ASSERT(!m_inferredType);
if (is<AST::IdentifierExpression>(type))
inferred(lookupType(downcast<AST::IdentifierExpression>(type).identifier()));
if (auto* identifierExpression = dynamicDowncast<AST::IdentifierExpression>(type))
inferred(lookupType(identifierExpression->identifier()));
else
Base::visit(type);
ASSERT(m_inferredType);
Expand Down Expand Up @@ -2320,7 +2319,6 @@ std::optional<TexelFormat> TypeChecker::texelFormat(AST::Expression& expression)
return std::nullopt;
}

ASSERT(is<AST::IdentifierExpression>(expression));
auto& formatName = downcast<AST::IdentifierExpression>(expression).identifier();

auto* format = parseTexelFormat(formatName.id());
Expand All @@ -2339,7 +2337,6 @@ std::optional<AccessMode> TypeChecker::accessMode(AST::Expression& expression)
return std::nullopt;
}

ASSERT(is<AST::IdentifierExpression>(expression));
auto& accessName = downcast<AST::IdentifierExpression>(expression).identifier();

auto* accessMode = parseAccessMode(accessName.id());
Expand All @@ -2358,7 +2355,6 @@ std::optional<AddressSpace> TypeChecker::addressSpace(AST::Expression& expressio
return std::nullopt;
}

ASSERT(is<AST::IdentifierExpression>(expression));
auto& addressSpaceName = downcast<AST::IdentifierExpression>(expression).identifier();

auto* addressSpace = parseAddressSpace(addressSpaceName.id());
Expand Down
1 change: 0 additions & 1 deletion Source/WebGPU/WGSL/WGSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,6 @@ ConstantValue evaluate(const AST::Expression& expression, const HashMap<String,
{
if (auto constantValue = expression.constantValue())
return *constantValue;
ASSERT(is<const AST::IdentifierExpression>(expression));
auto constantValue = constants.get(downcast<const AST::IdentifierExpression>(expression).identifier());
const_cast<AST::Expression&>(expression).setConstantValue(constantValue);
return constantValue;
Expand Down
32 changes: 15 additions & 17 deletions Source/WebGPU/WebGPU/ShaderModule.mm
Original file line number Diff line number Diff line change
Expand Up @@ -464,10 +464,10 @@ static WGPUVertexFormat vertexFormatTypeForStructMember(const WGSL::Type* type)
populateOutputState(entryPoint, *member.builtin());

for (auto& attribute : member.attributes()) {
if (attribute.kind() != WGSL::AST::NodeKind::BuiltinAttribute)
auto* builtinAttribute = dynamicDowncast<WGSL::AST::BuiltinAttribute>(attribute);
if (!builtinAttribute)
continue;
auto& builtinAttribute = downcast<WGSL::AST::BuiltinAttribute>(attribute);
populateOutputState(entryPoint, builtinAttribute.builtin());
populateOutputState(entryPoint, builtinAttribute->builtin());
}

if (!member.location() || member.builtin())
Expand Down Expand Up @@ -671,45 +671,43 @@ static void populateStageInMap(const WGSL::Type& type, ShaderModule::VertexStage
if (std::holds_alternative<WGSL::SuccessfulCheck>(m_checkResult)) {
auto& check = std::get<WGSL::SuccessfulCheck>(m_checkResult);
for (auto& declaration : check.ast->declarations()) {
if (!is<WGSL::AST::Function>(declaration))
auto* function = dynamicDowncast<WGSL::AST::Function>(declaration);
if (!function || !function->stage())
continue;
auto& function = downcast<WGSL::AST::Function>(declaration);
if (!function.stage())
continue;
switch (*function.stage()) {
switch (*function->stage()) {
case WGSL::ShaderStage::Vertex: {
m_stageInTypesForEntryPoint.add(function.name(), parseStageIn(function));
if (auto expression = function.maybeReturnType()) {
m_stageInTypesForEntryPoint.add(function->name(), parseStageIn(*function));
if (auto expression = function->maybeReturnType()) {
if (auto* inferredType = expression->inferredType())
m_vertexReturnTypeForEntryPoint.add(function.name(), parseVertexReturnType(*inferredType));
m_vertexReturnTypeForEntryPoint.add(function->name(), parseVertexReturnType(*inferredType));
}
if (!allowVertexDefault || m_defaultVertexEntryPoint.length()) {
allowVertexDefault = false;
m_defaultVertexEntryPoint = emptyString();
continue;
}
m_defaultVertexEntryPoint = function.name();
m_defaultVertexEntryPoint = function->name();
} break;
case WGSL::ShaderStage::Fragment: {
m_fragmentInputsForEntryPoint.add(function.name(), parseFragmentInputs(function));
if (auto expression = function.maybeReturnType()) {
m_fragmentInputsForEntryPoint.add(function->name(), parseFragmentInputs(*function));
if (auto expression = function->maybeReturnType()) {
if (auto* inferredType = expression->inferredType())
m_fragmentReturnTypeForEntryPoint.add(function.name(), parseFragmentReturnType(*inferredType, function.name()));
m_fragmentReturnTypeForEntryPoint.add(function->name(), parseFragmentReturnType(*inferredType, function->name()));
}
if (!allowFragmentDefault || m_defaultFragmentEntryPoint.length()) {
allowFragmentDefault = false;
m_defaultFragmentEntryPoint = emptyString();
continue;
}
m_defaultFragmentEntryPoint = function.name();
m_defaultFragmentEntryPoint = function->name();
} break;
case WGSL::ShaderStage::Compute: {
if (!allowComputeDefault || m_defaultComputeEntryPoint.length()) {
allowComputeDefault = false;
m_defaultComputeEntryPoint = emptyString();
continue;
}
m_defaultComputeEntryPoint = function.name();
m_defaultComputeEntryPoint = function->name();
} break;
default:
ASSERT_NOT_REACHED();
Expand Down

0 comments on commit e827b50

Please sign in to comment.