Skip to content

Commit

Permalink
[WGSL] Build call graph during staticCheck
Browse files Browse the repository at this point in the history
https://bugs.webkit.org/show_bug.cgi?id=270540
rdar://124095995

Reviewed by Mike Wyrzykowski.

Some validations that must happen during staticCheck/shader creation time require
knowing in which stages variables are used, and for that we need the call graph.
This patch just refactors the call graph creation into staticCheck, so that we
can use it in a subsequent patch to perform the validations.

* Source/WebGPU/WGSL/CallGraph.cpp:
(WGSL::CallGraphBuilder::CallGraphBuilder):
(WGSL::CallGraphBuilder::build):
(WGSL::CallGraphBuilder::initializeMappings):
(WGSL::buildCallGraph):
(WGSL::CallGraph::CallGraph): Deleted.
* Source/WebGPU/WGSL/CallGraph.h:
(WGSL::CallGraph::CallGraph):
(WGSL::CallGraph::ast const): Deleted.
* Source/WebGPU/WGSL/EntryPointRewriter.cpp:
(WGSL::EntryPointRewriter::EntryPointRewriter):
(WGSL::rewriteEntryPoints):
* Source/WebGPU/WGSL/EntryPointRewriter.h:
* Source/WebGPU/WGSL/GlobalVariableRewriter.cpp:
(WGSL::RewriteGlobalVariables::RewriteGlobalVariables):
(WGSL::RewriteGlobalVariables::run):
(WGSL::RewriteGlobalVariables::visitCallee):
(WGSL::RewriteGlobalVariables::visit):
(WGSL::RewriteGlobalVariables::pack):
(WGSL::RewriteGlobalVariables::getPacking):
(WGSL::RewriteGlobalVariables::collectGlobals):
(WGSL::RewriteGlobalVariables::bufferLengthType):
(WGSL::RewriteGlobalVariables::bufferLengthReferenceType):
(WGSL::RewriteGlobalVariables::packStructResource):
(WGSL::RewriteGlobalVariables::packArrayResource):
(WGSL::RewriteGlobalVariables::updateReference):
(WGSL::RewriteGlobalVariables::packStructType):
(WGSL::RewriteGlobalVariables::packArrayType):
(WGSL::RewriteGlobalVariables::insertParameter):
(WGSL::RewriteGlobalVariables::visitEntryPoint):
(WGSL::RewriteGlobalVariables::createArgumentBufferEntry):
(WGSL::RewriteGlobalVariables::finalizeArgumentBufferStruct):
(WGSL::RewriteGlobalVariables::insertStructs):
(WGSL::RewriteGlobalVariables::insertDynamicOffsetsBufferIfNeeded):
(WGSL::RewriteGlobalVariables::insertMaterializations):
(WGSL::RewriteGlobalVariables::insertLocalDefinitions):
(WGSL::RewriteGlobalVariables::initializeVariables):
(WGSL::RewriteGlobalVariables::insertWorkgroupBarrier):
(WGSL::RewriteGlobalVariables::findOrInsertLocalInvocationIndex):
(WGSL::RewriteGlobalVariables::storeInitialValue):
(WGSL::rewriteGlobalVariables):
* Source/WebGPU/WGSL/GlobalVariableRewriter.h:
* Source/WebGPU/WGSL/MangleNames.cpp:
(WGSL::NameManglerVisitor::NameManglerVisitor):
(WGSL::NameManglerVisitor::run):
(WGSL::NameManglerVisitor::visit):
(WGSL::NameManglerVisitor::introduceVariable):
(WGSL::NameManglerVisitor::readVariable const):
(WGSL::mangleNames):
* Source/WebGPU/WGSL/MangleNames.h:
* Source/WebGPU/WGSL/Metal/MetalCodeGenerator.cpp:
(WGSL::Metal::generateMetalCode):
* Source/WebGPU/WGSL/Metal/MetalCodeGenerator.h:
* Source/WebGPU/WGSL/Metal/MetalFunctionWriter.cpp:
(WGSL::Metal::FunctionDefinitionWriter::FunctionDefinitionWriter):
(WGSL::Metal::FunctionDefinitionWriter::write):
(WGSL::Metal::FunctionDefinitionWriter::emitNecessaryHelpers):
(WGSL::Metal::FunctionDefinitionWriter::visit):
(WGSL::Metal::emitMetalFunctions):
* Source/WebGPU/WGSL/Metal/MetalFunctionWriter.h:
* Source/WebGPU/WGSL/PointerRewriter.cpp:
(WGSL::PointerRewriter::PointerRewriter):
(WGSL::PointerRewriter::run):
(WGSL::PointerRewriter::rewrite):
(WGSL::PointerRewriter::visit):
(WGSL::rewritePointers):
* Source/WebGPU/WGSL/PointerRewriter.h:
* Source/WebGPU/WGSL/WGSL.cpp:
(WGSL::staticCheck):
(WGSL::prepareImpl):
(WGSL::generate):
* Source/WebGPU/WGSL/WGSL.h:
* Source/WebGPU/WGSL/WGSLShaderModule.h:
(WGSL::ShaderModule::callGraph const):
(WGSL::ShaderModule::setCallGraph):
* Source/WebGPU/WGSL/wgslc.cpp:
(runWGSL):
* Source/WebGPU/WebGPU/Pipeline.mm:
(WebGPU::createLibrary):
* Source/WebGPU/WebGPU/ShaderModule.mm:
(WebGPU::earlyCompileShaderModule):

Canonical link: https://commits.webkit.org/275826@main
  • Loading branch information
tadeuzagallo committed Mar 8, 2024
1 parent b46e324 commit 9af7ee6
Show file tree
Hide file tree
Showing 21 changed files with 298 additions and 301 deletions.
33 changes: 10 additions & 23 deletions Source/WebGPU/WGSL/CallGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,45 +33,37 @@

namespace WGSL {

CallGraph::CallGraph(ShaderModule& shaderModule)
: m_ast(shaderModule)
{
}

class CallGraphBuilder : public AST::Visitor {
public:
CallGraphBuilder(ShaderModule& shaderModule, const HashMap<String, std::optional<PipelineLayout>>& pipelineLayouts, HashMap<String, Reflection::EntryPointInformation>& entryPoints)
: m_callGraph(shaderModule)
, m_pipelineLayouts(pipelineLayouts)
, m_entryPoints(entryPoints)
CallGraphBuilder(ShaderModule& shaderModule)
: m_shaderModule(shaderModule)
{
}

CallGraph build();
void build();

void visit(AST::Function&) override;
void visit(AST::CallExpression&) override;

private:
void initializeMappings();

ShaderModule& m_shaderModule;
CallGraph m_callGraph;
const HashMap<String, std::optional<PipelineLayout>>& m_pipelineLayouts;
HashMap<String, Reflection::EntryPointInformation>& m_entryPoints;
HashMap<AST::Function*, unsigned> m_calleeBuildingMap;
Vector<CallGraph::Callee>* m_callees { nullptr };
Deque<AST::Function*> m_queue;
};

CallGraph CallGraphBuilder::build()
void CallGraphBuilder::build()
{
initializeMappings();
return m_callGraph;
m_shaderModule.setCallGraph(WTFMove(m_callGraph));
}

void CallGraphBuilder::initializeMappings()
{
for (auto& declaration : m_callGraph.m_ast.declarations()) {
for (auto& declaration : m_shaderModule.declarations()) {
auto* function = dynamicDowncast<AST::Function>(declaration);
if (!function)
continue;
Expand All @@ -82,15 +74,10 @@ void CallGraphBuilder::initializeMappings()
ASSERT_UNUSED(result, result.isNewEntry);
}

if (!m_pipelineLayouts.contains(name))
continue;

if (!function->stage())
continue;

auto addResult = m_entryPoints.add(function->name(), Reflection::EntryPointInformation { });
ASSERT(addResult.isNewEntry);
m_callGraph.m_entrypoints.append({ *function, *function->stage(), addResult.iterator->value });
m_callGraph.m_entrypoints.append({ *function, *function->stage(), function->name() });
m_queue.append(function);
}

Expand Down Expand Up @@ -132,9 +119,9 @@ void CallGraphBuilder::visit(AST::CallExpression& call)
m_callees->at(result.iterator->value).callSites.append(&call);
}

CallGraph buildCallGraph(ShaderModule& shaderModule, const HashMap<String, std::optional<PipelineLayout>>& pipelineLayouts, HashMap<String, Reflection::EntryPointInformation>& entryPoints)
void buildCallGraph(ShaderModule& shaderModule)
{
return CallGraphBuilder(shaderModule, pipelineLayouts, entryPoints).build();
CallGraphBuilder(shaderModule).build();
}

} // namespace WGSL
16 changes: 5 additions & 11 deletions Source/WebGPU/WGSL/CallGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,22 +25,18 @@

#pragma once

// FIXME: move Stage out of StageAttribute so we don't need to include this
#include "ASTForward.h"
#include "ASTStageAttribute.h"
#include "WGSLEnums.h"
#include <wtf/HashMap.h>
#include <wtf/Vector.h>
#include <wtf/text/WTFString.h>

namespace WGSL {

class ShaderModule;
struct PipelineLayout;
struct PrepareResult;

namespace Reflection {
struct EntryPointInformation;
}

class CallGraph {
friend class CallGraphBuilder;

Expand All @@ -53,22 +49,20 @@ class CallGraph {
struct EntryPoint {
AST::Function& function;
ShaderStage stage;
Reflection::EntryPointInformation& information;
String originalName;
};

ShaderModule& ast() const { return m_ast; }
const Vector<EntryPoint>& entrypoints() const { return m_entrypoints; }
const Vector<Callee>& callees(AST::Function& function) const { return m_calleeMap.find(&function)->value; }

private:
CallGraph(ShaderModule&);
CallGraph() { }

ShaderModule& m_ast;
Vector<EntryPoint> m_entrypoints;
HashMap<String, AST::Function*> m_functionsByName;
HashMap<AST::Function*, Vector<Callee>> m_calleeMap;
};

CallGraph buildCallGraph(ShaderModule&, const HashMap<String, std::optional<PipelineLayout>>& pipelineLayouts, HashMap<String, Reflection::EntryPointInformation>&);
void buildCallGraph(ShaderModule&);

} // namespace WGSL
33 changes: 7 additions & 26 deletions Source/WebGPU/WGSL/EntryPointRewriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@

#include "AST.h"
#include "ASTVisitor.h"
#include "CallGraph.h"
#include "TypeStore.h"
#include "Types.h"
#include "WGSL.h"
Expand All @@ -38,7 +37,7 @@ namespace WGSL {

class EntryPointRewriter {
public:
EntryPointRewriter(ShaderModule&, const AST::Function&, ShaderStage, Reflection::EntryPointInformation&);
EntryPointRewriter(ShaderModule&, const AST::Function&, ShaderStage);

void rewrite();

Expand Down Expand Up @@ -75,34 +74,14 @@ class EntryPointRewriter {
const Type* m_structType;
String m_structTypeName;
String m_structParameterName;
Reflection::EntryPointInformation& m_information;
unsigned m_builtinID { 0 };
};

EntryPointRewriter::EntryPointRewriter(ShaderModule& shaderModule, const AST::Function& function, ShaderStage stage, Reflection::EntryPointInformation& information)
EntryPointRewriter::EntryPointRewriter(ShaderModule& shaderModule, const AST::Function& function, ShaderStage stage)
: m_stage(stage)
, m_shaderModule(shaderModule)
, m_function(function)
, m_information(information)
{
switch (m_stage) {
case ShaderStage::Compute: {
for (auto& attribute : function.attributes()) {
auto* workgroupSize = dynamicDowncast<AST::WorkgroupSizeAttribute>(attribute);
if (!workgroupSize)
continue;
m_information.typedEntryPoint = Reflection::Compute { &workgroupSize->x(), workgroupSize->maybeY(), workgroupSize->maybeZ() };
break;
}
break;
}
case ShaderStage::Vertex:
m_information.typedEntryPoint = Reflection::Vertex { false };
break;
case ShaderStage::Fragment:
m_information.typedEntryPoint = Reflection::Fragment { };
break;
}
}

void EntryPointRewriter::rewrite()
Expand Down Expand Up @@ -370,10 +349,12 @@ void EntryPointRewriter::appendBuiltins()
}
}

void rewriteEntryPoints(CallGraph& callGraph)
void rewriteEntryPoints(ShaderModule& shaderModule, const HashMap<String, std::optional<PipelineLayout>>& pipelineLayouts)
{
for (auto& entryPoint : callGraph.entrypoints()) {
EntryPointRewriter rewriter(callGraph.ast(), entryPoint.function, entryPoint.stage, entryPoint.information);
for (auto& entryPoint : shaderModule.callGraph().entrypoints()) {
if (!pipelineLayouts.contains(entryPoint.originalName))
continue;
EntryPointRewriter rewriter(shaderModule, entryPoint.function, entryPoint.stage);
rewriter.rewrite();
}
}
Expand Down
8 changes: 5 additions & 3 deletions Source/WebGPU/WGSL/EntryPointRewriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,13 @@

#pragma once

#include <wtf/text/WTFString.h>

namespace WGSL {

class CallGraph;
struct PrepareResult;
class ShaderModule;
struct PipelineLayout;

void rewriteEntryPoints(CallGraph&);
void rewriteEntryPoints(ShaderModule&, const HashMap<String, std::optional<PipelineLayout>>&);

} // namespace WGSL
Loading

0 comments on commit 9af7ee6

Please sign in to comment.