Skip to content

Commit

Permalink
[WGSL] Add support for frexp and primitive structs
Browse files Browse the repository at this point in the history
https://bugs.webkit.org/show_bug.cgi?id=264336
rdar://118057393

Reviewed by Mike Wyrzykowski.

Initial support for operations that return built-in structs. For now, there's no
support for constant evaluation yet, since that requires constant structs, which
I will implement next.

* Source/WebGPU/WGSL/ConstantFunctions.h:
(WGSL::zeroValue):
* Source/WebGPU/WGSL/Constraints.cpp:
(WGSL::concretize):
* Source/WebGPU/WGSL/GlobalVariableRewriter.cpp:
(WGSL::bindingMemberForGlobal):
* Source/WebGPU/WGSL/Metal/MetalFunctionWriter.cpp:
(WGSL::Metal::FunctionDefinitionWriter::emitNecessaryHelpers):
(WGSL::Metal::FunctionDefinitionWriter::visit):
(WGSL::Metal::FunctionDefinitionWriter::serializeConstant):
* Source/WebGPU/WGSL/TypeCheck.cpp:
(WGSL::TypeChecker::visit):
(WGSL::TypeChecker::convertValue):
* Source/WebGPU/WGSL/TypeDeclarations.rb:
* Source/WebGPU/WGSL/TypeStore.cpp:
(WGSL::PrimitiveStructKey::encode const):
(WGSL::TypeStore::frexpResultType):
* Source/WebGPU/WGSL/TypeStore.h:
* Source/WebGPU/WGSL/Types.cpp:
(WGSL::Type::dump const):
(WGSL::conversionRank):
(WGSL::Type::size const):
(WGSL::Type::alignment const):
* Source/WebGPU/WGSL/Types.h:
* Source/WebGPU/WGSL/WGSLShaderModule.h:
(WGSL::ShaderModule::usesFrexp const):
(WGSL::ShaderModule::setUsesFrexp):
* Source/WebGPU/WGSL/generator/main.rb:
* Source/WebGPU/WGSL/tests/valid/overload.wgsl:

Canonical link: https://commits.webkit.org/270334@main
  • Loading branch information
tadeuzagallo committed Nov 7, 2023
1 parent 0335988 commit b056194
Show file tree
Hide file tree
Showing 13 changed files with 214 additions and 2 deletions.
5 changes: 5 additions & 0 deletions Source/WebGPU/WGSL/ConstantFunctions.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,11 @@ static ConstantValue zeroValue(const Type* type)
// yet have ConstantStruct
RELEASE_ASSERT_NOT_REACHED();
},
[&](const Types::PrimitiveStruct&) -> ConstantValue {
// FIXME: this is valid and needs to be implemented, but we don't
// yet have ConstantStruct
RELEASE_ASSERT_NOT_REACHED();
},
[&](const Types::Matrix& matrix) -> ConstantValue {
ConstantMatrix result(matrix.columns, matrix.rows);
auto value = zeroValue(matrix.element);
Expand Down
9 changes: 9 additions & 0 deletions Source/WebGPU/WGSL/Constraints.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,15 @@ const Type* concretize(const Type* type, TypeStore& types)
[&](const Struct&) -> const Type* {
return type;
},
[&](const PrimitiveStruct& primitiveStruct) -> const Type* {
switch (primitiveStruct.kind) {
case PrimitiveStruct::FrexpResult::kind: {
auto* fract = concretize(primitiveStruct.values[PrimitiveStruct::FrexpResult::fract], types);
auto* exp = concretize(primitiveStruct.values[PrimitiveStruct::FrexpResult::exp], types);
return types.frexpResultType(fract, exp);
}
}
},
[&](const Pointer&) -> const Type* {
return type;
},
Expand Down
2 changes: 2 additions & 0 deletions Source/WebGPU/WGSL/GlobalVariableRewriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -976,6 +976,8 @@ static BindGroupLayoutEntry::BindingMember bindingMemberForGlobal(auto& global)
.hasDynamicOffset = false,
.minBindingSize = 0
};
}, [&](const PrimitiveStruct&) -> BindGroupLayoutEntry::BindingMember {
RELEASE_ASSERT_NOT_REACHED();
}, [&](const Reference&) -> BindGroupLayoutEntry::BindingMember {
RELEASE_ASSERT_NOT_REACHED();
}, [&](const Pointer&) -> BindGroupLayoutEntry::BindingMember {
Expand Down
39 changes: 39 additions & 0 deletions Source/WebGPU/WGSL/Metal/MetalFunctionWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,29 @@ void FunctionDefinitionWriter::emitNecessaryHelpers()
m_stringBuilder.append(m_indent, "}\n\n");
}


if (m_callGraph.ast().usesFrexp()) {
m_stringBuilder.append(m_indent, "template<typename T, typename U>\n");
m_stringBuilder.append(m_indent, "struct __frexp_result {\n");
{
IndentationScope scope(m_indent);
m_stringBuilder.append(m_indent, "T fract;\n");
m_stringBuilder.append(m_indent, "U exp;\n");
}
m_stringBuilder.append(m_indent, "};\n\n");

m_stringBuilder.append(m_indent, "template<typename T, typename U = conditional_t<is_vector_v<T>, vec<int, vec_elements<T>::value ?: 2>, int>>\n");
m_stringBuilder.append(m_indent, "__frexp_result<T, U> __wgslFrexp(T value)\n");
m_stringBuilder.append(m_indent, "{\n");
{
IndentationScope scope(m_indent);
m_stringBuilder.append(m_indent, "__frexp_result<T, U> result;\n");
m_stringBuilder.append(m_indent, "result.fract = frexp(value, result.exp);\n");
m_stringBuilder.append(m_indent, "return result;\n");
}
m_stringBuilder.append(m_indent, "}\n\n");
}

if (m_callGraph.ast().usesPackedStructs()) {
m_callGraph.ast().clearUsesPackedStructs();

Expand Down Expand Up @@ -744,6 +767,17 @@ void FunctionDefinitionWriter::visit(const Type* type)
if (m_structRole.has_value() && *m_structRole == AST::StructureRole::PackedResource)
m_stringBuilder.append("::PackedType");
},
[&](const PrimitiveStruct& structure) {
m_stringBuilder.append(structure.name, "<");
bool first = true;
for (auto& value : structure.values) {
if (!first)
m_stringBuilder.append(", ");
first = false;
visit(value);
}
m_stringBuilder.append(">");
},
[&](const Texture& texture) {
const char* type;
const char* access = "sample";
Expand Down Expand Up @@ -1391,6 +1425,7 @@ void FunctionDefinitionWriter::visit(const Type* type, AST::CallExpression& call
{ "dpdy", "dfdy"_s },
{ "dpdyCoarse", "dfdy"_s },
{ "dpdyFine", "dfdy"_s },
{ "frexp", "__wgslFrexp"_s },
{ "fwidthCoarse", "fwidth"_s },
{ "fwidthFine", "fwidth"_s },
{ "inverseSqrt", "rsqrt"_s },
Expand Down Expand Up @@ -1870,6 +1905,10 @@ void FunctionDefinitionWriter::serializeConstant(const Type* type, ConstantValue
// Not supported yet
RELEASE_ASSERT_NOT_REACHED();
},
[&](const PrimitiveStruct&) {
// Not supported yet
RELEASE_ASSERT_NOT_REACHED();
},
[&](const Pointer&) {
RELEASE_ASSERT_NOT_REACHED();
},
Expand Down
16 changes: 16 additions & 0 deletions Source/WebGPU/WGSL/TypeCheck.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -745,6 +745,16 @@ void TypeChecker::visit(AST::FieldAccessExpression& access)
return it->value;
}

if (auto* primitiveStruct = std::get_if<Types::PrimitiveStruct>(baseType)) {
const auto& keys = Types::PrimitiveStruct::keys[primitiveStruct->kind];
auto* key = keys.tryGet(access.fieldName().id());
if (!key) {
typeError(access.span(), "struct '", *baseType, "' does not have a member called '", access.fieldName(), "'");
return nullptr;
}
return primitiveStruct->values[*key];
}

if (std::holds_alternative<Types::Vector>(*baseType)) {
auto& vector = std::get<Types::Vector>(*baseType);
auto* result = vectorFieldAccess(vector, access);
Expand Down Expand Up @@ -961,6 +971,8 @@ void TypeChecker::visit(AST::CallExpression& call)
// FIXME: this will go away once we track used intrinsics properly
if (targetName == "workgroupUniformLoad"_s)
m_shaderModule.setUsesWorkgroupUniformLoad();
else if (targetName == "frexp"_s)
m_shaderModule.setUsesFrexp();
target.m_inferredType = result;
return;
}
Expand Down Expand Up @@ -1635,6 +1647,10 @@ bool TypeChecker::convertValue(const SourceSpan& span, const Type* type, Constan
// FIXME: this should be supported
RELEASE_ASSERT_NOT_REACHED();
},
[&](const Types::PrimitiveStruct&) -> Conversion {
// FIXME: this should be supported
RELEASE_ASSERT_NOT_REACHED();
},
[&](const Types::Function&) -> Conversion {
RELEASE_ASSERT_NOT_REACHED();
},
Expand Down
16 changes: 15 additions & 1 deletion Source/WebGPU/WGSL/TypeDeclarations.rb
Original file line number Diff line number Diff line change
Expand Up @@ -618,7 +618,21 @@
must_use: true,
const: true,

# FIXME: this needs the special return types __frexp_result_*
[].(f32) => __frexp_result_f32,
# [].(f16) => __frexp_result_f16,
[].(abstract_float) => __frexp_result_abstract,

[].(vec2[f32]) => __frexp_result_vec2_f32,
# [].(vec2[f16]) => __frexp_result_vec2_f16,
[].(vec2[abstract_float]) => __frexp_result_vec2_abstract,

[].(vec3[f32]) => __frexp_result_vec3_f32,
# [].(vec3[f16]) => __frexp_result_vec3_f16,
[].(vec3[abstract_float]) => __frexp_result_vec3_abstract,

[].(vec4[f32]) => __frexp_result_vec4_f32,
# [].(vec4[f16]) => __frexp_result_vec4_f16,
[].(vec4[abstract_float]) => __frexp_result_vec4_abstract,
}

# 17.5.33
Expand Down
22 changes: 22 additions & 0 deletions Source/WebGPU/WGSL/TypeStore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,13 @@ struct PointerKey {
TypeCache::EncodedKey encode() const { return std::tuple(TypeCache::Pointer, WTF::enumToUnderlyingType(addressSpace), WTF::enumToUnderlyingType(accessMode), 0, bitwise_cast<uintptr_t>(elementType)); }
};

struct PrimitiveStructKey {
unsigned kind;
const Type* elementType;

TypeCache::EncodedKey encode() const { return std::tuple(TypeCache::PrimitiveStruct, kind, 0, 0, bitwise_cast<uintptr_t>(elementType)); }
};

template<typename Key>
const Type* TypeCache::find(const Key& key) const
{
Expand Down Expand Up @@ -227,6 +234,21 @@ const Type* TypeStore::typeConstructorType(ASCIILiteral name, std::function<cons
return allocateType<TypeConstructor>(name, WTFMove(constructor));
}

const Type* TypeStore::frexpResultType(const Type* fract, const Type* exp)
{
PrimitiveStructKey key { PrimitiveStruct::FrexpResult::kind, fract };
const Type* type = m_cache.find(key);
if (type)
return type;

FixedVector<const Type*> values(2);
values[PrimitiveStruct::FrexpResult::fract] = fract;
values[PrimitiveStruct::FrexpResult::exp] = exp;
type = allocateType<PrimitiveStruct>("__frexp_result"_s, PrimitiveStruct::FrexpResult::kind, values);
m_cache.insert(key, type);
return type;
}

template<typename TypeKind, typename... Arguments>
const Type* TypeStore::allocateType(Arguments&&... arguments)
{
Expand Down
2 changes: 2 additions & 0 deletions Source/WebGPU/WGSL/TypeStore.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class TypeCache {
TextureStorage,
Reference,
Pointer,
PrimitiveStruct,
};

using EncodedKey = std::tuple<uint8_t, uint8_t, uint16_t, uint32_t, uintptr_t>;
Expand Down Expand Up @@ -101,6 +102,7 @@ class TypeStore {
const Type* pointerType(AddressSpace, const Type*, AccessMode);
const Type* atomicType(const Type*);
const Type* typeConstructorType(ASCIILiteral, std::function<const Type*(AST::ElaboratedTypeExpression&)>&&);
const Type* frexpResultType(const Type*, const Type*);

private:
template<typename TypeKind, typename... Arguments>
Expand Down
28 changes: 28 additions & 0 deletions Source/WebGPU/WGSL/Types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,15 @@ void Type::dump(PrintStream& out) const
[&](const Struct& structure) {
out.print(structure.structure.name());
},
[&](const PrimitiveStruct& structure) {
out.print(structure.name, "<");
switch (structure.kind) {
case PrimitiveStruct::FrexpResult::kind:
out.print(*structure.values[PrimitiveStruct::FrexpResult::fract]);
break;
}
out.print(">");
},
[&](const Function& function) {
out.print("(");
bool first = true;
Expand Down Expand Up @@ -234,6 +243,19 @@ ConversionRank conversionRank(const Type* from, const Type* to)
return conversionRank(fromArray->element, toArray->element);
}

if (auto* fromPrimitiveStruct = std::get_if<PrimitiveStruct>(from)) {
auto* toPrimitiveStruct = std::get_if<PrimitiveStruct>(to);
if (!toPrimitiveStruct)
return std::nullopt;
auto kind = fromPrimitiveStruct->kind;
if (kind != toPrimitiveStruct->kind)
return std::nullopt;
switch (kind) {
case PrimitiveStruct::FrexpResult::kind:
return conversionRank(fromPrimitiveStruct->values[PrimitiveStruct::FrexpResult::fract], toPrimitiveStruct->values[PrimitiveStruct::FrexpResult::fract]);
}
}

// FIXME: add the abstract result conversion rules
return std::nullopt;
}
Expand Down Expand Up @@ -285,6 +307,9 @@ unsigned Type::size() const
[&](const Struct& structure) -> unsigned {
return structure.structure.size();
},
[&](const PrimitiveStruct&) -> unsigned {
RELEASE_ASSERT_NOT_REACHED();
},
[&](const Function&) -> unsigned {
RELEASE_ASSERT_NOT_REACHED();
},
Expand Down Expand Up @@ -355,6 +380,9 @@ unsigned Type::alignment() const
[&](const Struct& structure) -> unsigned {
return structure.structure.alignment();
},
[&](const PrimitiveStruct&) -> unsigned {
RELEASE_ASSERT_NOT_REACHED();
},
[&](const Function&) -> unsigned {
RELEASE_ASSERT_NOT_REACHED();
},
Expand Down
33 changes: 33 additions & 0 deletions Source/WebGPU/WGSL/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,11 @@

#include "ASTForward.h"
#include "WGSLEnums.h"
#include <wtf/FixedVector.h>
#include <wtf/HashMap.h>
#include <wtf/Markable.h>
#include <wtf/PrintStream.h>
#include <wtf/SortedArrayMap.h>
#include <wtf/text/WTFString.h>

namespace WGSL {
Expand Down Expand Up @@ -126,6 +128,35 @@ struct Struct {
HashMap<String, const Type*> fields { };
};

struct PrimitiveStruct {
private:
enum Kind : uint8_t {
FrexpResult,
};

public:
struct FrexpResult {
static constexpr Kind kind = Kind::FrexpResult;
static constexpr unsigned fract = 0;
static constexpr unsigned exp = 1;

static constexpr std::pair<ComparableASCIILiteral, unsigned> mapEntries[] {
{ "exp", exp },
{ "fract", fract },
};

static constexpr SortedArrayMap map { mapEntries };
};

static constexpr SortedArrayMap<std::pair<ComparableASCIILiteral, unsigned>[2]> keys[] {
FrexpResult::map,
};

String name;
Kind kind;
FixedVector<const Type*> values;
};

struct Function {
WTF::Vector<const Type*> parameters;
const Type* result;
Expand Down Expand Up @@ -163,6 +194,7 @@ struct Type : public std::variant<
Types::Matrix,
Types::Array,
Types::Struct,
Types::PrimitiveStruct,
Types::Function,
Types::Texture,
Types::TextureStorage,
Expand All @@ -179,6 +211,7 @@ struct Type : public std::variant<
Types::Matrix,
Types::Array,
Types::Struct,
Types::PrimitiveStruct,
Types::Function,
Types::Texture,
Types::TextureStorage,
Expand Down
4 changes: 4 additions & 0 deletions Source/WebGPU/WGSL/WGSLShaderModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ class ShaderModule {
bool usesDivision() const { return m_usesDivision; }
void setUsesDivision() { m_usesDivision = true; }

bool usesFrexp() const { return m_usesFrexp; }
void setUsesFrexp() { m_usesFrexp = true; }

template<typename T>
std::enable_if_t<std::is_base_of_v<AST::Node, T>, void> replace(T* current, T&& replacement)
{
Expand Down Expand Up @@ -221,6 +224,7 @@ class ShaderModule {
bool m_usesUnpackArray { false };
bool m_usesWorkgroupUniformLoad { false };
bool m_usesDivision { false };
bool m_usesFrexp { false };
Configuration m_configuration;
AST::Directive::List m_directives;
AST::Function::List m_functions;
Expand Down
19 changes: 19 additions & 0 deletions Source/WebGPU/WGSL/generator/main.rb
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,25 @@ def self.prologue
texture_storage_2d = texture_storage[TextureStorage2d]
texture_storage_2d_array = texture_storage[TextureStorage2dArray]
texture_storage_3d = texture_storage[TextureStorage3d]
# primitive structs
__frexp_result_abstract = Constructor.new(:frexpResult, [abstract_float, abstract_int])
# __frexp_result_f16 = Constructor.new(:frexpResult, [f16, i32])
__frexp_result_f32 = Constructor.new(:frexpResult, [f32, i32])
__frexp_result_vec2_abstract = Constructor.new(:frexpResult, [vec2[abstract_float], vec2[abstract_int]])
# __frexp_result_vec2_f16 = Constructor.new(:frexpResult, [vec2[f16], vec2[i32]])
__frexp_result_vec2_f32 = Constructor.new(:frexpResult, [vec2[f32], vec2[i32]])
__frexp_result_vec3_abstract = Constructor.new(:frexpResult, [vec3[abstract_float], vec3[abstract_int]])
# __frexp_result_vec3_f16 = Constructor.new(:frexpResult, [vec3[f16], vec3[i32]])
__frexp_result_vec3_f32 = Constructor.new(:frexpResult, [vec3[f32], vec3[i32]])
__frexp_result_vec4_abstract = Constructor.new(:frexpResult, [vec4[abstract_float], vec4[abstract_int]])
# __frexp_result_vec4_f16 = Constructor.new(:frexpResult, [vec4[f16], vec4[i32]])
__frexp_result_vec4_f32 = Constructor.new(:frexpResult, [vec4[f32], vec4[i32]])
EOS
end

Expand Down
Loading

0 comments on commit b056194

Please sign in to comment.