Skip to content

Commit

Permalink
[WGSL] Add support for address-of and indirection operators
Browse files Browse the repository at this point in the history
https://bugs.webkit.org/show_bug.cgi?id=261833
<rdar://problem/115795658>

Reviewed by Dan Glastonbury.

Add the type declarations for address-of (&x, spec[1]) and indirection (*x, spec[2])
operators. This required introducing AbstractReference and AbstractPointer to the
overload resolution algorithm and exposing them to the ruby DSL, as well as exposing
address space and access mode, which was done by generalizing NumericVariables to
use the underlying enum type.

[1]: https://www.w3.org/TR/WGSL/#address-of-expr
[2]: https://www.w3.org/TR/WGSL/#indirection-expr

* Source/WebGPU/WGSL/Constraints.cpp:
(WGSL::satisfies):
(WGSL::satisfyOrPromote):
* Source/WebGPU/WGSL/Metal/MetalFunctionWriter.cpp:
(WGSL::Metal::FunctionDefinitionWriter::visit):
* Source/WebGPU/WGSL/Overload.cpp:
(WGSL::OverloadResolver::OverloadResolver):
(WGSL::OverloadResolver::materialize const):
(WGSL::OverloadResolver::considerCandidate):
(WGSL::OverloadResolver::calculateRank):
(WGSL::OverloadResolver::unify):
(WGSL::OverloadResolver::assign):
(WGSL::OverloadResolver::resolve const):
(WGSL::resolveOverloads):
(WTF::printInternal):
* Source/WebGPU/WGSL/Overload.h:
* Source/WebGPU/WGSL/TypeDeclarations.rb:
* Source/WebGPU/WGSL/generator/main.rb:
* Source/WebGPU/WGSL/tests/valid/overload.wgsl:
* Source/WebGPU/WGSL/tests/valid/pointers.wgsl:

Canonical link: https://commits.webkit.org/268245@main
  • Loading branch information
tadeuzagallo committed Sep 21, 2023
1 parent a76cf09 commit 3d505c7
Show file tree
Hide file tree
Showing 8 changed files with 176 additions and 42 deletions.
6 changes: 6 additions & 0 deletions Source/WebGPU/WGSL/Constraints.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ namespace WGSL {

bool satisfies(const Type* type, Constraint constraint)
{
if (constraint == Constraints::None)
return true;

auto* primitive = std::get_if<Types::Primitive>(type);
if (!primitive) {
if (auto* reference = std::get_if<Types::Reference>(type))
Expand Down Expand Up @@ -67,6 +70,9 @@ bool satisfies(const Type* type, Constraint constraint)

const Type* satisfyOrPromote(const Type* type, Constraint constraint, const TypeStore& types)
{
if (constraint == Constraints::None)
return type;

auto* primitive = std::get_if<Types::Primitive>(type);
if (!primitive) {
if (auto* reference = std::get_if<Types::Reference>(type))
Expand Down
17 changes: 7 additions & 10 deletions Source/WebGPU/WGSL/Metal/MetalFunctionWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -725,6 +725,7 @@ void FunctionDefinitionWriter::visit(const Type* type)
const char* addressSpace = nullptr;
switch (pointer.addressSpace) {
case AddressSpace::Function:
case AddressSpace::Private:
addressSpace = "thread";
break;
case AddressSpace::Workgroup:
Expand All @@ -737,16 +738,12 @@ void FunctionDefinitionWriter::visit(const Type* type)
addressSpace = "device";
break;
case AddressSpace::Handle:
case AddressSpace::Private:
break;
}
if (!addressSpace) {
visit(pointer.element);
return;
RELEASE_ASSERT_NOT_REACHED();
}
if (pointer.accessMode == AccessMode::Read)
m_stringBuilder.append("const ");
m_stringBuilder.append(addressSpace, " ");
if (addressSpace)
m_stringBuilder.append(addressSpace, " ");
visit(pointer.element);
m_stringBuilder.append("*");
},
Expand Down Expand Up @@ -1025,11 +1022,11 @@ void FunctionDefinitionWriter::visit(AST::UnaryExpression& unary)
case AST::UnaryOperation::Not:
m_stringBuilder.append("!");
break;

case AST::UnaryOperation::AddressOf:
m_stringBuilder.append("&");
break;
case AST::UnaryOperation::Dereference:
// FIXME: Implement these
RELEASE_ASSERT_NOT_REACHED();
m_stringBuilder.append("*");
break;
}
visit(unary.expression());
Expand Down
100 changes: 83 additions & 17 deletions Source/WebGPU/WGSL/Overload.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,16 +77,16 @@ class OverloadResolver {
const Type* materialize(const AbstractScalarType&) const;

bool unify(const AbstractValue&, unsigned);
void assign(NumericVariable, unsigned);
std::optional<unsigned> resolve(NumericVariable) const;
void assign(ValueVariable, unsigned);
std::optional<unsigned> resolve(ValueVariable) const;
unsigned materialize(const AbstractValue&) const;

TypeStore& m_types;
const Vector<OverloadCandidate>& m_candidates;
const Vector<const Type*>& m_valueArguments;
const Vector<const Type*>& m_typeArguments;
FixedVector<const Type*> m_typeSubstitutions;
FixedVector<std::optional<unsigned>> m_numericSubstitutions;
FixedVector<std::optional<unsigned>> m_valueSubstitutions;
};

OverloadResolver::OverloadResolver(TypeStore& types, const Vector<OverloadCandidate>& candidates, const Vector<const Type*>& valueArguments, const Vector<const Type*>& typeArguments, unsigned numberOfTypeSubstitutions, unsigned numberOfValueSubstitutions)
Expand All @@ -95,7 +95,7 @@ OverloadResolver::OverloadResolver(TypeStore& types, const Vector<OverloadCandid
, m_valueArguments(valueArguments)
, m_typeArguments(typeArguments)
, m_typeSubstitutions(numberOfTypeSubstitutions)
, m_numericSubstitutions(numberOfValueSubstitutions)
, m_valueSubstitutions(numberOfValueSubstitutions)
{
}

Expand Down Expand Up @@ -172,6 +172,22 @@ const Type* OverloadResolver::materialize(const AbstractType& abstractType) cons
if (auto* element = materialize(texture.element))
return m_types.textureType(element, texture.kind);
return nullptr;
},
[&](const AbstractReference& reference) -> const Type* {
if (auto* element = materialize(reference.element)) {
auto addressSpace = materialize(reference.addressSpace);
auto accessMode = materialize(reference.accessMode);
return m_types.referenceType(static_cast<AddressSpace>(addressSpace), element, static_cast<AccessMode>(accessMode));
}
return nullptr;
},
[&](const AbstractPointer& pointer) -> const Type* {
if (auto* element = materialize(pointer.element)) {
auto addressSpace = materialize(pointer.addressSpace);
auto accessMode = materialize(pointer.accessMode);
return m_types.pointerType(static_cast<AddressSpace>(addressSpace), element, static_cast<AccessMode>(accessMode));
}
return nullptr;
});
}

Expand All @@ -197,7 +213,7 @@ unsigned OverloadResolver::materialize(const AbstractValue& abstractValue) const
[&](unsigned value) -> unsigned {
return value;
},
[&](NumericVariable variable) -> unsigned {
[&](ValueVariable variable) -> unsigned {
std::optional<unsigned> resolvedValue = resolve(variable);
ASSERT(resolvedValue.has_value());
return *resolvedValue;
Expand All @@ -221,7 +237,7 @@ std::optional<ViableOverload> OverloadResolver::considerCandidate(const Overload
return std::nullopt;

m_typeSubstitutions.fill(nullptr);
m_numericSubstitutions.fill(std::nullopt);
m_valueSubstitutions.fill(std::nullopt);

for (unsigned i = 0; i < m_typeArguments.size(); ++i) {
if (!assign(candidate.typeVariables[i], m_typeArguments[i]))
Expand Down Expand Up @@ -283,6 +299,16 @@ ConversionRank OverloadResolver::calculateRank(const AbstractType& parameter, co
return conversionRank(argumentType, resolvedType);
}

if (auto* referenceParameter = std::get_if<AbstractReference>(&parameter)) {
auto& referenceArgument = std::get<Types::Reference>(*argumentType);
return calculateRank(referenceParameter->element, referenceArgument.element);
}

if (auto* pointerParameter = std::get_if<AbstractPointer>(&parameter)) {
auto& pointerArgument = std::get<Types::Pointer>(*argumentType);
return calculateRank(pointerParameter->element, pointerArgument.element);
}

if (auto* reference = std::get_if<Types::Reference>(argumentType)) {
ASSERT(reference->accessMode != AccessMode::Write);
return calculateRank(parameter, reference->element);
Expand Down Expand Up @@ -368,6 +394,28 @@ bool OverloadResolver::unify(const AbstractType& parameter, const Type* argument
if (auto* variable = std::get_if<TypeVariable>(&parameter))
return unify(variable, argumentType);

if (auto* referenceParameter = std::get_if<AbstractReference>(&parameter)) {
auto* referenceArgument = std::get_if<Types::Reference>(argumentType);
if (!referenceArgument)
return false;
if (!unify(referenceParameter->addressSpace, WTF::enumToUnderlyingType(referenceArgument->addressSpace)))
return false;
if (!unify(referenceParameter->accessMode, WTF::enumToUnderlyingType(referenceArgument->accessMode)))
return false;
return unify(referenceParameter->element, referenceArgument->element);
}

if (auto* pointerParameter = std::get_if<AbstractPointer>(&parameter)) {
auto* pointerArgument = std::get_if<Types::Pointer>(argumentType);
if (!pointerArgument)
return false;
if (!unify(pointerParameter->addressSpace, WTF::enumToUnderlyingType(pointerArgument->addressSpace)))
return false;
if (!unify(pointerParameter->accessMode, WTF::enumToUnderlyingType(pointerArgument->accessMode)))
return false;
return unify(pointerParameter->element, pointerArgument->element);
}

if (auto* reference = std::get_if<Types::Reference>(argumentType)) {
if (reference->accessMode == AccessMode::Write)
return false;
Expand Down Expand Up @@ -422,7 +470,7 @@ bool OverloadResolver::unify(const AbstractValue& parameter, unsigned argumentVa
if (auto* parameterValue = std::get_if<unsigned>(&parameter))
return *parameterValue == argumentValue;

auto variable = std::get<NumericVariable>(parameter);
auto variable = std::get<ValueVariable>(parameter);
auto resolvedValue = resolve(variable);
if (!resolvedValue.has_value()) {
assign(variable, argumentValue);
Expand All @@ -442,20 +490,20 @@ bool OverloadResolver::assign(TypeVariable variable, const Type* type)
return true;
}

void OverloadResolver::assign(NumericVariable variable, unsigned value)
void OverloadResolver::assign(ValueVariable variable, unsigned value)
{
logLn("assign ", variable, " => ", value);
m_numericSubstitutions[variable.id] = { value };
m_valueSubstitutions[variable.id] = { value };
}

const Type* OverloadResolver::resolve(TypeVariable variable) const
{
return m_typeSubstitutions[variable.id];
}

std::optional<unsigned> OverloadResolver::resolve(NumericVariable variable) const
std::optional<unsigned> OverloadResolver::resolve(ValueVariable variable) const
{
return m_numericSubstitutions[variable.id];
return m_valueSubstitutions[variable.id];
}

ConversionRank OverloadResolver::conversionRank(const Type* from, const Type* to) const
Expand All @@ -472,7 +520,7 @@ std::optional<SelectedOverload> resolveOverloads(TypeStore& types, const Vector<
unsigned numberOfValueSubstitutions = 0;
for (const auto& candidate : candidates) {
numberOfTypeSubstitutions = std::max(numberOfTypeSubstitutions, static_cast<unsigned>(candidate.typeVariables.size()));
numberOfValueSubstitutions = std::max(numberOfValueSubstitutions, static_cast<unsigned>(candidate.numericVariables.size()));
numberOfValueSubstitutions = std::max(numberOfValueSubstitutions, static_cast<unsigned>(candidate.valueVariables.size()));
}
OverloadResolver resolver(types, candidates, valueArguments, typeArguments, numberOfTypeSubstitutions, numberOfValueSubstitutions);
return resolver.resolve();
Expand All @@ -482,7 +530,7 @@ std::optional<SelectedOverload> resolveOverloads(TypeStore& types, const Vector<

namespace WTF {

void printInternal(PrintStream& out, const WGSL::NumericVariable& variable)
void printInternal(PrintStream& out, const WGSL::ValueVariable& variable)
{
out.print("val", variable.id);
}
Expand All @@ -493,7 +541,7 @@ void printInternal(PrintStream& out, const WGSL::AbstractValue& value)
[&](unsigned value) {
out.print(value);
},
[&](WGSL::NumericVariable variable) {
[&](WGSL::ValueVariable variable) {
printInternal(out, variable);
});
}
Expand Down Expand Up @@ -533,6 +581,24 @@ void printInternal(PrintStream& out, const WGSL::AbstractType& type)
out.print("<");
printInternal(out, texture.element);
out.print(">");
},
[&](const WGSL::AbstractReference& reference) {
out.print("ref<");
printInternal(out, reference.addressSpace);
out.print(", ");
printInternal(out, reference.element);
out.print(", ");
printInternal(out, reference.accessMode);
out.print(">");
},
[&](const WGSL::AbstractPointer& pointer) {
out.print("ptr<");
printInternal(out, pointer.addressSpace);
out.print(", ");
printInternal(out, pointer.element);
out.print(", ");
printInternal(out, pointer.accessMode);
out.print(">");
});
}

Expand All @@ -549,7 +615,7 @@ void printInternal(PrintStream& out, const WGSL::AbstractScalarType& type)

void printInternal(PrintStream& out, const WGSL::OverloadCandidate& candidate)
{
if (candidate.typeVariables.size() || candidate.numericVariables.size()) {
if (candidate.typeVariables.size() || candidate.valueVariables.size()) {
bool first = true;
out.print("<");
for (auto& typeVariable : candidate.typeVariables) {
Expand All @@ -558,11 +624,11 @@ void printInternal(PrintStream& out, const WGSL::OverloadCandidate& candidate)
first = false;
printInternal(out, typeVariable);
}
for (auto& numericVariable : candidate.numericVariables) {
for (auto& valueVariable : candidate.valueVariables) {
if (!first)
out.print(", ");
first = false;
printInternal(out, numericVariable);
printInternal(out, valueVariable);
}
out.print(">");
}
Expand Down
25 changes: 20 additions & 5 deletions Source/WebGPU/WGSL/Overload.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,10 @@

namespace WGSL {

struct NumericVariable {
struct ValueVariable {
unsigned id;
};

using AbstractValue = std::variant<NumericVariable, unsigned>;
using AbstractValue = std::variant<ValueVariable, unsigned>;

struct TypeVariable {
unsigned id;
Expand All @@ -44,10 +43,14 @@ struct TypeVariable {
struct AbstractVector;
struct AbstractMatrix;
struct AbstractTexture;
struct AbstractReference;
struct AbstractPointer;
using AbstractType = std::variant<
AbstractVector,
AbstractMatrix,
AbstractTexture,
AbstractReference,
AbstractPointer,
TypeVariable,
const Type*
>;
Expand All @@ -73,9 +76,21 @@ struct AbstractTexture {
Types::Texture::Kind kind;
};

struct AbstractReference {
AbstractValue addressSpace;
AbstractScalarType element;
AbstractValue accessMode;
};

struct AbstractPointer {
AbstractValue addressSpace;
AbstractScalarType element;
AbstractValue accessMode;
};

struct OverloadCandidate {
Vector<TypeVariable, 1> typeVariables;
Vector<NumericVariable, 2> numericVariables;
Vector<ValueVariable, 2> valueVariables;
Vector<AbstractType, 2> parameters;
AbstractType result;
};
Expand All @@ -90,7 +105,7 @@ std::optional<SelectedOverload> resolveOverloads(TypeStore&, const Vector<Overlo
} // namespace WGSL

namespace WTF {
void printInternal(PrintStream&, const WGSL::NumericVariable&);
void printInternal(PrintStream&, const WGSL::ValueVariable&);
void printInternal(PrintStream&, const WGSL::AbstractValue&);
void printInternal(PrintStream&, const WGSL::TypeVariable&);
void printInternal(PrintStream&, const WGSL::AbstractType&);
Expand Down
8 changes: 8 additions & 0 deletions Source/WebGPU/WGSL/TypeDeclarations.rb
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@
}

operator :*, {
# unary
[AS, T, AM].(Ptr[AS, T, AM]) => Ref[AS, T, AM],

# binary
[T < Number].(T, T) => T,

# vector scaling
Expand Down Expand Up @@ -111,6 +115,10 @@
}

operator :'&', {
# unary
[AS, T, AM].(Ref[AS, T, AM]) => Ptr[AS, T, AM],

# binary
[].(Bool, Bool) => Bool,
[N].(Vector[Bool, N], Vector[Bool, N]) => Vector[Bool, N],
}
Expand Down
Loading

0 comments on commit 3d505c7

Please sign in to comment.