Skip to content

Commit

Permalink
[WGSL] Generate the correct code for arrayLength
Browse files Browse the repository at this point in the history
https://bugs.webkit.org/show_bug.cgi?id=262703
rdar://116524838

Reviewed by Mike Wyrzykowski.

As of 268884@main, the API now exposes the size of runtime-sized arrays at the
end of the argument buffer. This patch updates the compiler to match this argument
buffer layout and rewrite calls to arrayLength to read these sizes.

* Source/WebGPU/WGSL/GlobalVariableRewriter.cpp:
(WGSL::RewriteGlobalVariables::getPacking):
(WGSL::RewriteGlobalVariables::collectGlobals):
(WGSL::RewriteGlobalVariables::containsRuntimeArray):
(WGSL::RewriteGlobalVariables::packArrayType):
* Source/WebGPU/WGSL/Metal/MetalFunctionWriter.cpp:
(WGSL::Metal::emitArrayLength):
(WGSL::Metal::FunctionDefinitionWriter::visit):
* Source/WebGPU/WGSL/tests/valid/array-length.wgsl: Added.

Canonical link: https://commits.webkit.org/268995@main
  • Loading branch information
tadeuzagallo committed Oct 6, 2023
1 parent c9c17fb commit 8f80944
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 1 deletion.
3 changes: 3 additions & 0 deletions Source/WebGPU/WGSL/AST/ASTVariable.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include "ASTVariableQualifier.h"

namespace WGSL {
class RewriteGlobalVariables;
class TypeChecker;
struct Type;

Expand All @@ -47,7 +48,9 @@ enum class VariableFlavor : uint8_t {

class Variable final : public Declaration {
WGSL_AST_BUILDER_NODE(Variable);
friend RewriteGlobalVariables;
friend TypeChecker;

public:
using Ref = std::reference_wrapper<Variable>;
using List = ReferenceWrapperVector<Variable>;
Expand Down
88 changes: 87 additions & 1 deletion Source/WebGPU/WGSL/GlobalVariableRewriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ class RewriteGlobalVariables : public AST::Visitor {
void packResource(AST::Variable&);
void packArrayResource(AST::Variable&, const Types::Array*);
void packStructResource(AST::Variable&, const Types::Struct*);
bool containsRuntimeArray(const Type*);
const Type* packType(const Type*);
const Type* packStructType(const Types::Struct*);
const Type* packArrayType(const Types::Array*);
Expand Down Expand Up @@ -422,6 +423,37 @@ auto RewriteGlobalVariables::getPacking(AST::UnaryExpression& expression) -> Pac

auto RewriteGlobalVariables::getPacking(AST::CallExpression& call) -> Packing
{
if (is<AST::IdentifierExpression>(call.target())) {
auto& target = downcast<AST::IdentifierExpression>(call.target());
if (target.identifier() == "arrayLength"_s) {
ASSERT(call.arguments().size() == 1);
const auto& getBase = [&](auto&& getBase, AST::Expression& expression) -> AST::Expression& {
if (is<AST::IdentityExpression>(expression))
return getBase(getBase, downcast<AST::IdentityExpression>(expression).expression());
if (is<AST::UnaryExpression>(expression))
return getBase(getBase, downcast<AST::UnaryExpression>(expression).expression());
if (is<AST::FieldAccessExpression>(expression))
return getBase(getBase, downcast<AST::FieldAccessExpression>(expression).base());
if (is<AST::IdentifierExpression>(expression))
return expression;
RELEASE_ASSERT_NOT_REACHED();
};
auto& base = getBase(getBase, call.arguments()[0]);
ASSERT(is<AST::IdentifierExpression>(base));
auto& identifier = downcast<AST::IdentifierExpression>(base).identifier();
ASSERT(m_globals.contains(identifier));
auto lengthName = makeString("__", identifier, "_ArrayLength");
auto& length = m_callGraph.ast().astBuilder().construct<AST::IdentifierExpression>(
SourceSpan::empty(),
AST::Identifier::make(lengthName)
);
length.m_inferredType = m_callGraph.ast().types().u32Type();
m_callGraph.ast().replace(call, length);
visit(base); // we also need to mark the array as read
return getPacking(length);
}
}

for (auto& argument : call.arguments())
pack(Packing::Unpacked, argument);
return Packing::Unpacked;
Expand All @@ -447,6 +479,7 @@ auto RewriteGlobalVariables::packingForType(const Type* type) -> Packing
void RewriteGlobalVariables::collectGlobals()
{
auto& globalVars = m_callGraph.ast().variables();
Vector<std::tuple<AST::Variable*, unsigned>> bufferLengths;
for (auto& globalVar : globalVars) {
std::optional<unsigned> group;
std::optional<unsigned> binding;
Expand Down Expand Up @@ -479,6 +512,45 @@ void RewriteGlobalVariables::collectGlobals()
auto result = m_groupBindingMap.add(resource->group, Vector<std::pair<unsigned, String>>());
result.iterator->value.append({ resource->binding, globalVar.name() });
packResource(globalVar);

if (containsRuntimeArray(globalVar.maybeReferenceType()->inferredType()))
bufferLengths.append({ &globalVar, *group });
}
}

if (!bufferLengths.isEmpty()) {
auto& type = m_callGraph.ast().astBuilder().construct<AST::IdentifierExpression>(SourceSpan::empty(), AST::Identifier::make("u32"_s));
type.m_inferredType = m_callGraph.ast().types().u32Type();
auto& referenceType = m_callGraph.ast().astBuilder().construct<AST::ReferenceTypeExpression>(
SourceSpan::empty(),
type
);
referenceType.m_inferredType = m_callGraph.ast().types().referenceType(AddressSpace::Handle, m_callGraph.ast().types().u32Type(), AccessMode::Read);

for (const auto& [variable, group] : bufferLengths) {
auto name = AST::Identifier::make(makeString("__", variable->name(), "_ArrayLength"));
auto& lengthVariable = m_callGraph.ast().astBuilder().construct<AST::Variable>(
SourceSpan::empty(),
AST::VariableFlavor::Var,
AST::Identifier::make(name),
&type,
nullptr
);
lengthVariable.m_referenceType = &referenceType;

auto it = m_groupBindingMap.find(group);
ASSERT(it != m_groupBindingMap.end());

auto binding = it->value.last().first + 1;
auto result = m_globals.add(name, Global {
{ {
group,
binding,
} },
&lengthVariable
});
ASSERT_UNUSED(result, result.isNewEntry);
it->value.append({ binding, name });
}
}
}
Expand Down Expand Up @@ -561,6 +633,17 @@ void RewriteGlobalVariables::updateReference(AST::Variable& global, AST::Express
m_callGraph.ast().replace(reference, packedTypeReference);
}

bool RewriteGlobalVariables::containsRuntimeArray(const Type* type)
{
if (auto* referenceType = std::get_if<Types::Reference>(type))
return containsRuntimeArray(referenceType->element);
if (auto* structType = std::get_if<Types::Struct>(type))
return containsRuntimeArray(structType->structure.members().last().type().inferredType());
if (auto* arrayType = std::get_if<Types::Array>(type))
return !arrayType->size.has_value();
return false;
}

const Type* RewriteGlobalVariables::packType(const Type* type)
{
if (auto* structType = std::get_if<Types::Struct>(type))
Expand Down Expand Up @@ -615,9 +698,12 @@ const Type* RewriteGlobalVariables::packArrayType(const Types::Array* arrayType)
if (!structType)
return nullptr;

const Type* packedStructType = packStructType(structType);
if (!packedStructType)
return nullptr;

m_callGraph.ast().setUsesUnpackArray();
m_callGraph.ast().setUsesPackArray();
const Type* packedStructType = packStructType(structType);
return m_callGraph.ast().types().arrayType(packedStructType, arrayType->size);
}

Expand Down
6 changes: 6 additions & 0 deletions Source/WebGPU/WGSL/Metal/MetalFunctionWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1267,6 +1267,11 @@ static void emitAtomicExchange(FunctionDefinitionWriter* writer, AST::CallExpres
atomicFunction("atomic_exchange_explicit", writer, call);
}

static void emitArrayLength(FunctionDefinitionWriter* writer, AST::CallExpression& call)
{
writer->visit(call.arguments()[0]);
writer->stringBuilder().append(".size()");
}

void FunctionDefinitionWriter::visit(const Type* type, AST::CallExpression& call)
{
Expand Down Expand Up @@ -1299,6 +1304,7 @@ void FunctionDefinitionWriter::visit(const Type* type, AST::CallExpression& call

if (is<AST::IdentifierExpression>(call.target())) {
static constexpr std::pair<ComparableASCIILiteral, void(*)(FunctionDefinitionWriter*, AST::CallExpression&)> builtinMappings[] {
{ "arrayLength", emitArrayLength },
{ "atomicAdd", emitAtomicAdd },
{ "atomicExchange", emitAtomicExchange },
{ "atomicLoad", emitAtomicLoad },
Expand Down
39 changes: 39 additions & 0 deletions Source/WebGPU/WGSL/tests/valid/array-length.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// RUN: %metal-compile main

struct S {
x: array<f32>,
}

struct Packed {
x: vec3<f32>
}


struct T {
x: array<Packed>,
}

@group(0) @binding(0) var<storage, read_write> x: array<f32>;
@group(0) @binding(1) var<storage, read_write> y: S;
@group(0) @binding(2) var<storage, read_write> z: T;

fn f() -> u32 {
let x1 = arrayLength(&x);
let y1 = arrayLength(&y.x);
let z1 = arrayLength(&z.x);

let xptr = &x;
let yptr = &y.x;
let zptr = &z.x;

let x2 = arrayLength(xptr);
let y2 = arrayLength(yptr);
let z2 = arrayLength(zptr);

return x1 + y1 + z1 + x2 + y2 + z2;
}

@compute @workgroup_size(1, 1, 1)
fn main() {
let x = f();
}

0 comments on commit 8f80944

Please sign in to comment.