Skip to content

Commit

Permalink
[WGSL] Add support for atomicCompareExchangeWeak
Browse files Browse the repository at this point in the history
https://bugs.webkit.org/show_bug.cgi?id=266912
rdar://120206827

Reviewed by Mike Wyrzykowski.

Add support for the last missing atomic function

* Source/WebGPU/WGSL/Constraints.cpp:
(WGSL::concretize):
* Source/WebGPU/WGSL/Metal/MetalFunctionWriter.cpp:
(WGSL::Metal::FunctionDefinitionWriter::emitNecessaryHelpers):
(WGSL::Metal::FunctionDefinitionWriter::visit):
* Source/WebGPU/WGSL/TypeCheck.cpp:
(WGSL::TypeChecker::visit):
* Source/WebGPU/WGSL/TypeDeclarations.rb:
* Source/WebGPU/WGSL/TypeStore.cpp:
(WGSL::TypeStore::atomicCompareExchangeResultType):
* Source/WebGPU/WGSL/TypeStore.h:
* Source/WebGPU/WGSL/Types.cpp:
(WGSL::Type::dump const):
(WGSL::conversionRank):
* Source/WebGPU/WGSL/Types.h:
* Source/WebGPU/WGSL/WGSLShaderModule.h:
(WGSL::ShaderModule::usesAtomicCompareExchange const):
(WGSL::ShaderModule::setUsesAtomicCompareExchange):
* Source/WebGPU/WGSL/generator/main.rb:
* Source/WebGPU/WGSL/tests/valid/overload.wgsl:

Canonical link: https://commits.webkit.org/272518@main
  • Loading branch information
tadeuzagallo committed Dec 28, 2023
1 parent bda8685 commit 12712ac
Show file tree
Hide file tree
Showing 11 changed files with 91 additions and 5 deletions.
3 changes: 3 additions & 0 deletions Source/WebGPU/WGSL/Constraints.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,9 @@ const Type* concretize(const Type* type, TypeStore& types)
auto* whole = concretize(primitiveStruct.values[PrimitiveStruct::ModfResult::whole], types);
return types.modfResultType(fract, whole);
}
case PrimitiveStruct::AtomicCompareExchangeResult::kind: {
return type;
}
}
},
[&](const Pointer&) -> const Type* {
Expand Down
20 changes: 20 additions & 0 deletions Source/WebGPU/WGSL/Metal/MetalFunctionWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,25 @@ void FunctionDefinitionWriter::emitNecessaryHelpers()
m_stringBuilder.append(m_indent, "}\n\n");
}

if (m_callGraph.ast().usesAtomicCompareExchange()) {
m_stringBuilder.append(m_indent, "template<typename T>\n");
m_stringBuilder.append(m_indent, "struct __atomic_compare_exchange_result {\n");
{
IndentationScope scope(m_indent);
m_stringBuilder.append(m_indent, "T old_value;\n");
m_stringBuilder.append(m_indent, "bool exchanged;\n");
}
m_stringBuilder.append(m_indent, "};\n\n");

m_stringBuilder.append(m_indent, "#define __wgslAtomicCompareExchangeWeak(atomic, compare, value) \\\n");
{
IndentationScope scope(m_indent);
m_stringBuilder.append(m_indent, "({ auto innerCompare = compare; \\\n");
m_stringBuilder.append(m_indent, "__atomic_compare_exchange_result<decltype(compare)> { innerCompare, atomic_compare_exchange_weak_explicit((atomic), &innerCompare, value, memory_order_relaxed, memory_order_relaxed) }; \\\n");
m_stringBuilder.append(m_indent, "})\n");
}
}

if (m_callGraph.ast().usesPackedStructs()) {
m_callGraph.ast().clearUsesPackedStructs();

Expand Down Expand Up @@ -1658,6 +1677,7 @@ void FunctionDefinitionWriter::visit(const Type* type, AST::CallExpression& call
}

static constexpr std::pair<ComparableASCIILiteral, ASCIILiteral> directMappings[] {
{ "atomicCompareExchangeWeak", "__wgslAtomicCompareExchangeWeak"_s },
{ "countLeadingZeros", "clz"_s },
{ "countOneBits", "popcount"_s },
{ "countTrailingZeros", "ctz"_s },
Expand Down
2 changes: 2 additions & 0 deletions Source/WebGPU/WGSL/TypeCheck.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1037,6 +1037,8 @@ void TypeChecker::visit(AST::CallExpression& call)
m_shaderModule.setUsesFrexp();
else if (targetName == "modf"_s)
m_shaderModule.setUsesModf();
else if (targetName == "atomicCompareExchangeWeak"_s)
m_shaderModule.setUsesAtomicCompareExchange();
target.m_inferredType = result;
return;
}
Expand Down
6 changes: 4 additions & 2 deletions Source/WebGPU/WGSL/TypeDeclarations.rb
Original file line number Diff line number Diff line change
Expand Up @@ -1388,8 +1388,10 @@
}
end

# FIXME: Implement atomicCompareExchangeWeak (which depends on the result struct that is not currently supported)
# fn atomicCompareExchangeWeak(atomic_ptr: ptr<AS, atomic<T>, read_write>, cmp: T, v: T) -> __atomic_compare_exchange_result<T>
function :atomicCompareExchangeWeak, {
[AS].(ptr[AS, atomic[i32], read_write], i32, i32) => __atomic_compare_exchange_result_i32,
[AS].(ptr[AS, atomic[u32], read_write], u32, u32) => __atomic_compare_exchange_result_u32,
}

# 16.9. Data Packing Built-in Functions (https://www.w3.org/TR/WGSL/#pack-builtin-functions)
# FIXME: implement
Expand Down
18 changes: 18 additions & 0 deletions Source/WebGPU/WGSL/TypeStore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,24 @@ const Type* TypeStore::modfResultType(const Type* fract, const Type* whole)
return type;
}

const Type* TypeStore::atomicCompareExchangeResultType(const Type* type)
{
const auto& load = [&](const Type*& member) {
if (member)
return member;
FixedVector<const Type*> values(2);
values[PrimitiveStruct::AtomicCompareExchangeResult::oldValue] = type;
values[PrimitiveStruct::AtomicCompareExchangeResult::exchanged] = boolType();
member = allocateType<PrimitiveStruct>("__atomic_compare_exchange_result"_s, PrimitiveStruct::ModfResult::kind, values);
return member;
};

if (type == m_i32)
return load(m_atomicCompareExchangeResultI32);
ASSERT(type == m_u32);
return load(m_atomicCompareExchangeResultU32);
}

template<typename TypeKind, typename... Arguments>
const Type* TypeStore::allocateType(Arguments&&... arguments)
{
Expand Down
3 changes: 3 additions & 0 deletions Source/WebGPU/WGSL/TypeStore.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ class TypeStore {
const Type* typeConstructorType(ASCIILiteral, std::function<const Type*(AST::ElaboratedTypeExpression&)>&&);
const Type* frexpResultType(const Type*, const Type*);
const Type* modfResultType(const Type*, const Type*);
const Type* atomicCompareExchangeResultType(const Type*);

private:
template<typename TypeKind, typename... Arguments>
Expand Down Expand Up @@ -135,6 +136,8 @@ class TypeStore {
const Type* m_textureDepthMultisampled2d;
const Type* m_atomicI32;
const Type* m_atomicU32;
const Type* m_atomicCompareExchangeResultI32;
const Type* m_atomicCompareExchangeResultU32;
};

} // namespace WGSL
6 changes: 5 additions & 1 deletion Source/WebGPU/WGSL/Types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ void Type::dump(PrintStream& out) const
case PrimitiveStruct::ModfResult::kind:
out.print(*structure.values[PrimitiveStruct::ModfResult::fract]);
break;
case PrimitiveStruct::AtomicCompareExchangeResult::kind:
out.print(*structure.values[PrimitiveStruct::AtomicCompareExchangeResult::oldValue]);
break;
}
out.print(">");
},
Expand Down Expand Up @@ -260,10 +263,11 @@ ConversionRank conversionRank(const Type* from, const Type* to)
return conversionRank(fromPrimitiveStruct->values[PrimitiveStruct::FrexpResult::fract], toPrimitiveStruct->values[PrimitiveStruct::FrexpResult::fract]);
case PrimitiveStruct::ModfResult::kind:
return conversionRank(fromPrimitiveStruct->values[PrimitiveStruct::ModfResult::fract], toPrimitiveStruct->values[PrimitiveStruct::ModfResult::fract]);
case PrimitiveStruct::AtomicCompareExchangeResult::kind:
return std::nullopt;
}
}

// FIXME: add the abstract result conversion rules
return std::nullopt;
}

Expand Down
15 changes: 15 additions & 0 deletions Source/WebGPU/WGSL/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ struct PrimitiveStruct {
enum Kind : uint8_t {
FrexpResult,
ModfResult,
AtomicCompareExchangeResult,
};

public:
Expand Down Expand Up @@ -163,9 +164,23 @@ struct PrimitiveStruct {
static constexpr SortedArrayMap map { mapEntries };
};

struct AtomicCompareExchangeResult {
static constexpr Kind kind = Kind::AtomicCompareExchangeResult;
static constexpr unsigned oldValue = 0;
static constexpr unsigned exchanged = 1;

static constexpr std::pair<ComparableASCIILiteral, unsigned> mapEntries[] {
{ "exchanged", exchanged },
{ "old_value", oldValue },
};

static constexpr SortedArrayMap map { mapEntries };
};

static constexpr SortedArrayMap<std::pair<ComparableASCIILiteral, unsigned>[2]> keys[] {
FrexpResult::map,
ModfResult::map,
AtomicCompareExchangeResult::map,
};

String name;
Expand Down
4 changes: 4 additions & 0 deletions Source/WebGPU/WGSL/WGSLShaderModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ class ShaderModule {
bool usesModf() const { return m_usesModf; }
void setUsesModf() { m_usesModf = true; }

bool usesAtomicCompareExchange() const { return m_usesAtomicCompareExchange; }
void setUsesAtomicCompareExchange() { m_usesAtomicCompareExchange = true; }

template<typename T>
std::enable_if_t<std::is_base_of_v<AST::Node, T>, void> replace(T* current, T&& replacement)
{
Expand Down Expand Up @@ -244,6 +247,7 @@ class ShaderModule {
bool m_usesModulo { false };
bool m_usesFrexp { false };
bool m_usesModf { false };
bool m_usesAtomicCompareExchange { false };
OptionSet<Extension> m_enabledExtensions;
OptionSet<LanguageFeature> m_requiredFeatures;
Configuration m_configuration;
Expand Down
3 changes: 3 additions & 0 deletions Source/WebGPU/WGSL/generator/main.rb
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,9 @@ def self.prologue
__modf_result_vec4_abstract = Constructor.new(:modfResult, [vec4[abstract_float], vec4[abstract_float]])
__modf_result_vec4_f16 = Constructor.new(:modfResult, [vec4[f16], vec4[f16]])
__modf_result_vec4_f32 = Constructor.new(:modfResult, [vec4[f32], vec4[f32]])
__atomic_compare_exchange_result_i32 = Constructor.new(:atomicCompareExchangeResult, [i32])
__atomic_compare_exchange_result_u32 = Constructor.new(:atomicCompareExchangeResult, [u32])
EOS
end

Expand Down
16 changes: 14 additions & 2 deletions Source/WebGPU/WGSL/tests/valid/overload.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -3516,6 +3516,7 @@ fn testTextureStore()

// 16.8. Atomic Built-in Functions (https://www.w3.org/TR/WGSL/#atomic-builtin-functions)
var<workgroup> x: atomic<i32>;
@group(8) @binding(0) var<storage, read_write> y: atomic<i32>;

// RUN: %metal-compile testAtomicFunctions
@compute @workgroup_size(1)
Expand All @@ -3531,13 +3532,15 @@ fn testAtomicLoad()
{
// [AS, T].(ptr[AS, atomic[T], read_write]) => T,
_ = atomicLoad(&x);
_ = atomicLoad(&y);
}

// 16.8.2
fn testAtomicStore()
{
/*[AS, T].(ptr[AS, atomic[T], read_write], T) => void,*/
atomicStore(&x, 42);
atomicStore(&y, 42);
}

// 16.8.3. Atomic Read-modify-write (this spec entry contains several functions)
Expand All @@ -3552,10 +3555,19 @@ fn testAtomicReadWriteModify()
_ = atomicOr(&x, 42);
_ = atomicXor(&x, 42);
_ = atomicExchange(&x, 42);
_ = atomicCompareExchangeWeak(&x, 42, 13);

_ = atomicAdd(&y, 42);
_ = atomicSub(&y, 42);
_ = atomicMax(&y, 42);
_ = atomicMin(&y, 42);
_ = atomicAnd(&y, 42);
_ = atomicOr(&y, 42);
_ = atomicXor(&y, 42);
_ = atomicExchange(&y, 42);
_ = atomicCompareExchangeWeak(&y, 42, 13);
}

// FIXME: Implement atomicCompareExchangeWeak (which depends on the result struct that is not currently supported)

// 16.9. Data Packing Built-in Functions (https://www.w3.org/TR/WGSL/#pack-builtin-functions)
// FIXME: implement

Expand Down

0 comments on commit 12712ac

Please sign in to comment.