Skip to content

Commit

Permalink
[WGSL] Add support for built-in type aliases
Browse files Browse the repository at this point in the history
https://bugs.webkit.org/show_bug.cgi?id=260354
rdar://114033423

Reviewed by Dan Glastonbury.

Add support for the vecN(f|i|u) type aliases.

* Source/WebGPU/WGSL/Metal/MetalFunctionWriter.cpp:
(WGSL::Metal::FunctionDefinitionWriter::visit):
* Source/WebGPU/WGSL/TypeCheck.cpp:
(WGSL::TypeChecker::visit):
* Source/WebGPU/WGSL/TypeDeclarations.rb:
* Source/WebGPU/WGSL/generator/main.rb:
* Source/WebGPU/WGSL/tests/valid/aliases.wgsl: Added.

Canonical link: https://commits.webkit.org/267036@main
  • Loading branch information
tadeuzagallo committed Aug 18, 2023
1 parent 5326da2 commit a5dac91
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 2 deletions.
11 changes: 10 additions & 1 deletion Source/WebGPU/WGSL/Metal/MetalFunctionWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -877,7 +877,16 @@ void FunctionDefinitionWriter::visit(const Type* type, AST::CallExpression& call
static constexpr std::pair<ComparableASCIILiteral, ASCIILiteral> baseTypesMappings[] {
{ "f32", "float"_s },
{ "i32", "int"_s },
{ "u32", "unsigned"_s }
{ "u32", "uint"_s },
{ "vec2f", "float2"_s },
{ "vec2i", "int2"_s },
{ "vec2u", "uint2"_s },
{ "vec3f", "float3"_s },
{ "vec3i", "int3"_s },
{ "vec3u", "uint3"_s },
{ "vec4f", "float4"_s },
{ "vec4i", "int4"_s },
{ "vec4u", "uint4"_s }
};
static constexpr SortedArrayMap baseTypes { baseTypesMappings };

Expand Down
17 changes: 17 additions & 0 deletions Source/WebGPU/WGSL/TypeCheck.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -616,6 +616,23 @@ void TypeChecker::visit(AST::CallExpression& call)
inferred(targetBinding->type);
return;
}

if (auto* vectorType = std::get_if<Types::Vector>(targetBinding->type)) {
typeArguments.append(vectorType->element);
switch (vectorType->size) {
case 2:
targetName = "vec2"_s;
break;
case 3:
targetName = "vec3"_s;
break;
case 4:
targetName = "vec4"_s;
break;
default:
RELEASE_ASSERT_NOT_REACHED();
}
}
}

if (targetBinding->kind == Binding::Value) {
Expand Down
13 changes: 13 additions & 0 deletions Source/WebGPU/WGSL/TypeDeclarations.rb
Original file line number Diff line number Diff line change
@@ -1,4 +1,17 @@
# FIXME: add all the missing type declarations here

suffixes = {
f: F32,
i: I32,
u: U32,
}

[2, 3, 4].each do |n|
suffixes.each do |suffix, type|
type_alias :"vec#{n}#{suffix}", Vector[type, n]
end
end

operator :+, {
[T < Number].(T, T) => T,

Expand Down
15 changes: 14 additions & 1 deletion Source/WebGPU/WGSL/generator/main.rb
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,11 @@ def to_s
"#{name.to_s}[#{arguments.map(&:to_s).join(", ")}]"
end

def concrete_type
"m_types.#{name}Type(#{arguments.map { |a| a.respond_to? :to_cpp and a.to_cpp or a}.join ", "})"
end
def to_cpp
"AbstractType { m_types.#{name}Type(#{arguments.map { |a| a.respond_to? :to_cpp and a.to_cpp or a}.join ", "}) }"
"AbstractType { #{concrete_type} }"
end
end

Expand Down Expand Up @@ -221,6 +224,7 @@ def call(*arguments)

module DSL
@context = binding()
@aliases = {}
@operators = {}
@TypeVariable = VariableKind.new(:TypeVariable)
@NumericVariable = VariableKind.new(:NumericVariable)
Expand All @@ -247,8 +251,17 @@ def self.operator(name, map)
end
end

def self.type_alias(name, type)
@aliases[name] = type
end

def self.to_cpp
out = []

@aliases.each do |name, type|
out << "introduceType(AST::Identifier::make(\"#{name}\"_s), #{type.concrete_type});"
end

@operators.each do |name, overloads|
out << "m_overloadedOperations.add(\"#{name}\"_s, Vector<OverloadCandidate>({"
overloads.each { |function| out << "#{function.to_cpp(name)}," }
Expand Down
24 changes: 24 additions & 0 deletions Source/WebGPU/WGSL/tests/valid/aliases.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// RUN: %metal-compile main

fn f1(x: vec2f) {}
fn f2(x: vec2i) {}
fn f3(x: vec2u) {}
fn f4(x: vec3f) {}
fn f5(x: vec3i) {}
fn f6(x: vec3u) {}
fn f7(x: vec4f) {}
fn f8(x: vec4i) {}
fn f9(x: vec4u) {}

@compute @workgroup_size(1)
fn main() {
_ = f1(vec2f(vec2(0u)));
_ = f2(vec2i(vec2(0f)));
_ = f3(vec2u(vec2(0f)));
_ = f4(vec3f(vec3(0u)));
_ = f5(vec3i(vec3(0f)));
_ = f6(vec3u(vec3(0f)));
_ = f7(vec4f(vec4(0u)));
_ = f8(vec4i(vec4(0f)));
_ = f9(vec4u(vec4(0f)));
}

0 comments on commit a5dac91

Please sign in to comment.