Skip to content

Commit

Permalink
[WGSL] Use arena allocation for AST nodes
Browse files Browse the repository at this point in the history
https://bugs.webkit.org/show_bug.cgi?id=255917
rdar://108496324

Reviewed by Myles C. Maxfield.

The AST was originally implemented using only unique references/pointers, which
made it really difficult to transform the AST. Currently we have a mix of Ref(Ptr)
and unique references/pointers, which is unnecessary since all the nodes have the
same lifecycle. To simplify things, we switch to an arena allocator, inspired by
the JSC's ParserArena. I converted only a single leaf node, StructureMember, to
use the new AST builder as a proof of concept, but more nodes can be converted
incrementally later.

* Source/WebGPU/WGSL/AST/ASTBuilder.cpp: Added.
(WGSL::AST::Builder::arena):
(WGSL::AST::Builder::~Builder):
(WGSL::AST::Builder::allocateArena):
* Source/WebGPU/WGSL/AST/ASTBuilder.h: Added.
(WGSL::AST::Builder::construct):
(WGSL::AST::Builder::alignSize):
* Source/WebGPU/WGSL/AST/ASTStringDumper.cpp:
(WGSL::AST::StringDumper::visitPointerVector):
(WGSL::AST::StringDumper::visit):
* Source/WebGPU/WGSL/AST/ASTStringDumper.h:
* Source/WebGPU/WGSL/AST/ASTStructureMember.h:
* Source/WebGPU/WGSL/AST/ASTVisitor.cpp:
(WGSL::AST::Visitor::visit):
* Source/WebGPU/WGSL/EntryPointRewriter.cpp:
(WGSL::EntryPointRewriter::constructInputStruct):
(WGSL::EntryPointRewriter::visit):
* Source/WebGPU/WGSL/GlobalVariableRewriter.cpp:
(WGSL::RewriteGlobalVariables::insertStructs):
* Source/WebGPU/WGSL/MangleNames.cpp:
(WGSL::NameManglerVisitor::visit):
* Source/WebGPU/WGSL/Metal/MetalFunctionWriter.cpp:
(WGSL::Metal::FunctionDefinitionWriter::visit):
* Source/WebGPU/WGSL/Parser.cpp:
(WGSL::Parser<Lexer>::parseStructure):
(WGSL::Parser<Lexer>::parseStructureMember): Deleted.
* Source/WebGPU/WGSL/ParserPrivate.h:
(WGSL::Parser::Parser):
* Source/WebGPU/WGSL/TypeCheck.cpp:
(WGSL::TypeChecker::visitStructMembers):
* Source/WebGPU/WGSL/WGSLShaderModule.h:
(WGSL::ShaderModule::astBuilder):
* Source/WebGPU/WebGPU.xcodeproj/project.pbxproj:

Canonical link: https://commits.webkit.org/263452@main
  • Loading branch information
tadeuzagallo committed Apr 27, 2023
1 parent 3257310 commit af2e0ab
Show file tree
Hide file tree
Showing 13 changed files with 191 additions and 23 deletions.
62 changes: 62 additions & 0 deletions Source/WebGPU/WGSL/AST/ASTBuilder.cpp
@@ -0,0 +1,62 @@
/*
* Copyright (c) 2023 Apple Inc. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
* are met:
* 1. Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY APPLE INC. ``AS IS'' AND ANY
* EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
* PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL APPLE INC. OR
* CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
* EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
* PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
* PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
* OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/

#include "config.h"
#include "ASTBuilder.h"

namespace WGSL::AST {

DEFINE_ALLOCATOR_WITH_HEAP_IDENTIFIER(WGSLAST);

Builder::Builder(Builder&& other)
{
m_arena = std::exchange(other.m_arena, nullptr);
m_arenaEnd = std::exchange(other.m_arenaEnd, nullptr);
m_arenas = WTFMove(other.m_arenas);
m_nodes = WTFMove(other.m_nodes);
}

inline uint8_t* Builder::arena()
{
ASSERT(m_arenaEnd);
return m_arenaEnd - arenaSize;
}

Builder::~Builder()
{
size_t size = m_nodes.size();
for (size_t i = 0; i < size; ++i)
m_nodes[i]->~Node();
}

void Builder::allocateArena()
{
m_arenas.append(MallocPtr<uint8_t, WGSLASTMalloc>::malloc(arenaSize));
m_arena = m_arenas.last().get();
m_arenaEnd = m_arena + arenaSize;
ASSERT(arena() == m_arena);
}

} // namespace WGSL::AST
83 changes: 83 additions & 0 deletions Source/WebGPU/WGSL/AST/ASTBuilder.h
@@ -0,0 +1,83 @@
/*
* Copyright (c) 2023 Apple Inc. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
* are met:
* 1. Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY APPLE INC. ``AS IS'' AND ANY
* EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
* PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL APPLE INC. OR
* CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
* EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
* PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
* PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
* OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/

#pragma once

#include "ASTNode.h"
#include <wtf/MallocPtr.h>
#include <wtf/Noncopyable.h>
#include <wtf/Nonmovable.h>
#include <wtf/Vector.h>

#define WGSL_AST_BUILDER_NODE(Node) \
WTF_MAKE_NONCOPYABLE(Node); \
WTF_MAKE_NONMOVABLE(Node); \
friend class Builder; \

namespace WGSL::AST {

DECLARE_ALLOCATOR_WITH_HEAP_IDENTIFIER(WGSLAST);

class Builder {
WTF_MAKE_NONCOPYABLE(Builder);

public:
static constexpr size_t arenaSize = 0x4000;

Builder() = default;
Builder(Builder&&);
~Builder();

template<typename T, typename... Arguments, typename = std::enable_if_t<std::is_base_of_v<Node, T>>>
T& construct(Arguments&&... arguments)
{
constexpr size_t size = sizeof(T);
constexpr size_t alignedSize = alignSize(size);
static_assert(alignedSize <= arenaSize);
if (UNLIKELY(static_cast<size_t>(m_arenaEnd - m_arena) < alignedSize))
allocateArena();

auto* node = new (m_arena) T(std::forward<Arguments>(arguments)...);
m_arena += alignedSize;
m_nodes.append(node);
return *node;
}

private:
static constexpr size_t alignSize(size_t size)
{
return (size + sizeof(WTF::AllocAlignmentInteger) - 1) & ~(sizeof(WTF::AllocAlignmentInteger) - 1);
}

uint8_t* arena();
void allocateArena();

uint8_t* m_arena { nullptr };
uint8_t* m_arenaEnd { nullptr };
Vector<MallocPtr<uint8_t, WGSLASTMalloc>> m_arenas;
Vector<Node*> m_nodes;
};

} // namespace WGSL::AST
18 changes: 10 additions & 8 deletions Source/WebGPU/WGSL/AST/ASTStructureMember.h
Expand Up @@ -26,29 +26,31 @@
#pragma once

#include "ASTAttribute.h"
#include "ASTBuilder.h"
#include "ASTIdentifier.h"
#include "ASTTypeName.h"

namespace WGSL::AST {

class StructureMember final : public Node {
WTF_MAKE_FAST_ALLOCATED;
WGSL_AST_BUILDER_NODE(StructureMember);

public:
using List = UniqueRefVector<StructureMember>;
using List = Vector<std::reference_wrapper<StructureMember>>;

NodeKind kind() const final;
Identifier& name() { return m_name; }
TypeName& type() { return m_type; }
Attribute::List& attributes() { return m_attributes; }

private:
StructureMember(SourceSpan span, Identifier&& name, TypeName::Ref&& type, Attribute::List&& attributes)
: Node(span)
, m_name(WTFMove(name))
, m_attributes(WTFMove(attributes))
, m_type(WTFMove(type))
{ }

NodeKind kind() const final;
Identifier& name() { return m_name; }
TypeName& type() { return m_type; }
Attribute::List& attributes() { return m_attributes; }

private:
Identifier m_name;
Attribute::List m_attributes;
TypeName::Ref m_type;
Expand Down
4 changes: 2 additions & 2 deletions Source/WebGPU/WGSL/EntryPointRewriter.cpp
Expand Up @@ -159,7 +159,7 @@ void EntryPointRewriter::constructInputStruct()
// insert `var ${parameter.name()} = ${structName}.${parameter.name()}`
AST::StructureMember::List structMembers;
for (auto& parameter : m_parameters) {
structMembers.append(makeUniqueRef<AST::StructureMember>(
structMembers.append(m_shaderModule.astBuilder().construct<AST::StructureMember>(
SourceSpan::empty(),
WTFMove(parameter.name),
WTFMove(parameter.type),
Expand Down Expand Up @@ -253,7 +253,7 @@ void EntryPointRewriter::visit(Vector<String>& path, MemberOrParameter&& data)
)
));
path.append(data.name);
for (auto& member : structType->structure.members())
for (AST::StructureMember& member : structType->structure.members())
visit(path, MemberOrParameter { member.name(), member.type(), member.attributes() });
path.removeLast();
return;
Expand Down
2 changes: 1 addition & 1 deletion Source/WebGPU/WGSL/GlobalVariableRewriter.cpp
Expand Up @@ -312,7 +312,7 @@ void RewriteGlobalVariables::insertStructs(const UsedGlobals& usedGlobals)
AST::TypeName::Ref memberType = *global->declaration->maybeTypeName();
if (shouldBeReference)
memberType = adoptRef(*new AST::ReferenceTypeName(span, WTFMove(memberType)));
structMembers.append(makeUniqueRef<AST::StructureMember>(
structMembers.append(m_callGraph.ast().astBuilder().construct<AST::StructureMember>(
span,
AST::Identifier::make(global->declaration->name()),
WTFMove(memberType),
Expand Down
2 changes: 1 addition & 1 deletion Source/WebGPU/WGSL/MangleNames.cpp
Expand Up @@ -148,7 +148,7 @@ void NameManglerVisitor::visit(AST::Structure& structure)
introduceVariable(structure.name(), MangledName::Type);

NameMap fieldMap;
for (auto& member : structure.members()) {
for (AST::StructureMember& member : structure.members()) {
AST::Visitor::visit(member.type());
auto mangledName = makeMangledName(member.name(), MangledName::Field);
fieldMap.add(member.name(), mangledName);
Expand Down
2 changes: 1 addition & 1 deletion Source/WebGPU/WGSL/Metal/MetalFunctionWriter.cpp
Expand Up @@ -168,7 +168,7 @@ void FunctionDefinitionWriter::visit(AST::Structure& structDecl)
m_stringBuilder.append(m_indent, "struct ", structDecl.name(), " {\n");
{
IndentationScope scope(m_indent);
for (auto& member : structDecl.members()) {
for (AST::StructureMember& member : structDecl.members()) {
m_stringBuilder.append(m_indent);
visit(member.type());
m_stringBuilder.append(" ", member.name());
Expand Down
12 changes: 9 additions & 3 deletions Source/WebGPU/WGSL/Parser.cpp
Expand Up @@ -77,6 +77,12 @@ struct TemplateTypes<TT> {
return { WTFMove(astNodeResult) }; \
} while (false)

#define RETURN_ARENA_NODE(type, ...) \
do { \
AST::type& astNodeResult = m_builder.construct<AST::type>(CURRENT_SOURCE_SPAN() __VA_OPT__(,) __VA_ARGS__); /* NOLINT */ \
return { astNodeResult }; \
} while (false)

#define RETURN_NODE_REF(type, ...) \
return { adoptRef(*new AST::type(CURRENT_SOURCE_SPAN(), __VA_ARGS__)) };

Expand Down Expand Up @@ -549,7 +555,7 @@ Result<AST::Structure::Ref> Parser<Lexer>::parseStructure(AST::Attribute::List&&
AST::StructureMember::List members;
while (current().type != TokenType::BraceRight) {
PARSE(member, StructureMember);
members.append(makeUniqueRef<AST::StructureMember>(WTFMove(member)));
members.append(member);
if (current().type == TokenType::Comma)
consume();
else
Expand All @@ -562,7 +568,7 @@ Result<AST::Structure::Ref> Parser<Lexer>::parseStructure(AST::Attribute::List&&
}

template<typename Lexer>
Result<AST::StructureMember> Parser<Lexer>::parseStructureMember()
Result<std::reference_wrapper<AST::StructureMember>> Parser<Lexer>::parseStructureMember()
{
START_PARSE();

Expand All @@ -571,7 +577,7 @@ Result<AST::StructureMember> Parser<Lexer>::parseStructureMember()
CONSUME_TYPE(Colon);
PARSE(type, TypeName);

RETURN_NODE(StructureMember, WTFMove(name), WTFMove(type), WTFMove(attributes));
RETURN_ARENA_NODE(StructureMember, WTFMove(name), WTFMove(type), WTFMove(attributes));
}

template<typename Lexer>
Expand Down
6 changes: 5 additions & 1 deletion Source/WebGPU/WGSL/ParserPrivate.h
Expand Up @@ -26,6 +26,7 @@
#pragma once

#include "ASTAttribute.h"
#include "ASTBuilder.h"
#include "ASTExpression.h"
#include "ASTForward.h"
#include "ASTStatement.h"
Expand All @@ -34,6 +35,7 @@
#include "ASTVariable.h"
#include "CompilationMessage.h"
#include "Lexer.h"
#include "WGSLShaderModule.h"
#include <wtf/Ref.h>

namespace WGSL {
Expand All @@ -45,6 +47,7 @@ class Parser {
public:
Parser(ShaderModule& shaderModule, Lexer& lexer)
: m_shaderModule(shaderModule)
, m_builder(shaderModule.astBuilder())
, m_lexer(lexer)
, m_current(lexer.lex())
{
Expand All @@ -58,7 +61,7 @@ class Parser {
Result<AST::Attribute::List> parseAttributes();
Result<AST::Attribute::Ref> parseAttribute();
Result<AST::Structure::Ref> parseStructure(AST::Attribute::List&&);
Result<AST::StructureMember> parseStructureMember();
Result<std::reference_wrapper<AST::StructureMember>> parseStructureMember();
Result<AST::TypeName::Ref> parseTypeName();
Result<AST::TypeName::Ref> parseTypeNameAfterIdentifier(AST::Identifier&&, SourcePosition start);
Result<AST::TypeName::Ref> parseArrayType();
Expand Down Expand Up @@ -101,6 +104,7 @@ class Parser {
Token& current() { return m_current; }

ShaderModule& m_shaderModule;
AST::Builder& m_builder;
Lexer& m_lexer;
Token m_current;
};
Expand Down
2 changes: 1 addition & 1 deletion Source/WebGPU/WGSL/TypeCheck.cpp
Expand Up @@ -178,7 +178,7 @@ void TypeChecker::visitStructMembers(AST::Structure& structure)
ASSERT(std::holds_alternative<Types::Struct>(**type));

auto& structType = std::get<Types::Struct>(**type);
for (auto& member : structure.members()) {
for (AST::StructureMember& member : structure.members()) {
auto* memberType = resolve(member.type());
auto result = structType.fields.add(member.name().id(), memberType);
ASSERT_UNUSED(result, result.isNewEntry);
Expand Down
5 changes: 4 additions & 1 deletion Source/WebGPU/WGSL/WGSLShaderModule.h
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021 Apple Inc. All rights reserved.
* Copyright (c) 2021-2023 Apple Inc. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
Expand All @@ -25,6 +25,7 @@

#pragma once

#include "ASTBuilder.h"
#include "ASTDirective.h"
#include "ASTFunction.h"
#include "ASTStructure.h"
Expand Down Expand Up @@ -56,6 +57,7 @@ class ShaderModule {
AST::Structure::List& structures() { return m_structures; }
AST::Variable::List& variables() { return m_variables; }
TypeStore& types() { return m_types; }
AST::Builder& astBuilder() { return m_astBuilder; }

template<typename T>
std::enable_if_t<std::is_base_of_v<AST::Node, T>, void> replace(T* current, T&& replacement)
Expand Down Expand Up @@ -136,6 +138,7 @@ class ShaderModule {
AST::Structure::List m_structures;
AST::Variable::List m_variables;
TypeStore m_types;
AST::Builder m_astBuilder;
Vector<std::function<void()>> m_replacements;
};

Expand Down
8 changes: 8 additions & 0 deletions Source/WebGPU/WebGPU.xcodeproj/project.pbxproj
Expand Up @@ -123,6 +123,8 @@
9776BE732992A236002D6D93 /* Overload.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 9776BE712992A236002D6D93 /* Overload.cpp */; };
9776BE742992A236002D6D93 /* Overload.h in Headers */ = {isa = PBXBuildFile; fileRef = 9776BE722992A236002D6D93 /* Overload.h */; };
9776BE7629957E12002D6D93 /* WGSLShaderModule.h in Headers */ = {isa = PBXBuildFile; fileRef = 9776BE7529957E12002D6D93 /* WGSLShaderModule.h */; };
97835C9329F7C9C600939EBA /* ASTBuilder.h in Headers */ = {isa = PBXBuildFile; fileRef = 97835C9229F7C9C600939EBA /* ASTBuilder.h */; };
97835C9529F7D85A00939EBA /* ASTBuilder.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 97835C9429F7D85A00939EBA /* ASTBuilder.cpp */; };
9789C31A297EA105009E9006 /* CallGraph.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 9789C318297EA105009E9006 /* CallGraph.cpp */; };
978A9125298A4E8400B37E5E /* MangleNames.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 978A9123298A4E8400B37E5E /* MangleNames.cpp */; };
978A9126298A4E8400B37E5E /* MangleNames.h in Headers */ = {isa = PBXBuildFile; fileRef = 978A9124298A4E8400B37E5E /* MangleNames.h */; };
Expand Down Expand Up @@ -363,6 +365,8 @@
9776BE712992A236002D6D93 /* Overload.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = Overload.cpp; sourceTree = "<group>"; };
9776BE722992A236002D6D93 /* Overload.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = Overload.h; sourceTree = "<group>"; };
9776BE7529957E12002D6D93 /* WGSLShaderModule.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = WGSLShaderModule.h; sourceTree = "<group>"; };
97835C9229F7C9C600939EBA /* ASTBuilder.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = ASTBuilder.h; sourceTree = "<group>"; };
97835C9429F7D85A00939EBA /* ASTBuilder.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = ASTBuilder.cpp; sourceTree = "<group>"; };
9789C318297EA105009E9006 /* CallGraph.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = CallGraph.cpp; sourceTree = "<group>"; };
9789C319297EA105009E9006 /* CallGraph.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = CallGraph.h; sourceTree = "<group>"; };
978A9123298A4E8400B37E5E /* MangleNames.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = MangleNames.cpp; sourceTree = "<group>"; };
Expand Down Expand Up @@ -602,6 +606,8 @@
3A12AECB28FCFA9800C1B975 /* ASTBitcastExpression.h */,
3A12AED028FCFC5500C1B975 /* ASTBoolLiteral.h */,
3A12AE9F28FCE94B00C1B975 /* ASTBreakStatement.h */,
97835C9429F7D85A00939EBA /* ASTBuilder.cpp */,
97835C9229F7C9C600939EBA /* ASTBuilder.h */,
3A12AE9028FCE94A00C1B975 /* ASTBuiltinAttribute.h */,
33EA188527BC26DF00A1DD52 /* ASTCallExpression.h */,
3A12AEA328FCE94C00C1B975 /* ASTCompoundAssignmentStatement.h */,
Expand Down Expand Up @@ -707,6 +713,7 @@
3A12AECD28FCFA9800C1B975 /* ASTBitcastExpression.h in Headers */,
3A12AED628FCFC5600C1B975 /* ASTBoolLiteral.h in Headers */,
3A12AEB928FCE94C00C1B975 /* ASTBreakStatement.h in Headers */,
97835C9329F7C9C600939EBA /* ASTBuilder.h in Headers */,
3A12AEAA28FCE94C00C1B975 /* ASTBuiltinAttribute.h in Headers */,
33EA188627BC26DF00A1DD52 /* ASTCallExpression.h in Headers */,
3A12AEBD28FCE94C00C1B975 /* ASTCompoundAssignmentStatement.h in Headers */,
Expand Down Expand Up @@ -951,6 +958,7 @@
buildActionMask = 2147483647;
files = (
3AD0D23E2988F3AB0080D728 /* ASTBinaryExpression.cpp in Sources */,
97835C9529F7D85A00939EBA /* ASTBuilder.cpp in Sources */,
3A9D02A4298390CF00888A75 /* ASTStringDumper.cpp in Sources */,
3AD0D23B2988ED8F0080D728 /* ASTUnaryExpression.cpp in Sources */,
3A1337E728FBD56400F29B73 /* ASTVisitor.cpp in Sources */,
Expand Down

0 comments on commit af2e0ab

Please sign in to comment.