Skip to content

Commit

Permalink
[WGSL] shader,execution,expression,call,builtin,bitcast:* is failing
Browse files Browse the repository at this point in the history
https://bugs.webkit.org/show_bug.cgi?id=267333
rdar://120783284

Reviewed by Mike Wyrzykowski.

The constant implementation of bitcast reused the `convertValue` helper, which used
static_cast to convert between types. Allow passing different casts to convertValue
and use bitwise_cast instead.

* Source/WebGPU/WGSL/ConstantFunctions.h:
(WGSL::StaticCast::cast):
(WGSL::BitwiseCast::cast):
(WGSL::convertValue):
(WGSL::CONSTANT_FUNCTION):

Canonical link: https://commits.webkit.org/272893@main
  • Loading branch information
tadeuzagallo committed Jan 11, 2024
1 parent 074ad42 commit 42abab7
Showing 1 changed file with 50 additions and 28 deletions.
78 changes: 50 additions & 28 deletions Source/WebGPU/WGSL/ConstantFunctions.h
Original file line number Diff line number Diff line change
Expand Up @@ -266,45 +266,66 @@ static ConstantResult constantTernaryOperation(const FixedVector<ConstantValue>&
}, arguments[0], arguments[1], arguments[2]);
}

template<typename DestinationType>
template<typename U>
struct StaticCast {
template<typename T>
static U cast(T t) { return static_cast<U>(t); }
};

template<typename U>
struct BitwiseCast {
template<typename T>
static U cast(T t) { return bitwise_cast<U>(t); }
};

template<typename DestinationType, template <typename U> typename Cast = StaticCast>
static ConstantValue convertValue(ConstantValue value)
{
if (auto* boolean = std::get_if<bool>(&value))
return static_cast<DestinationType>(*boolean);
if (auto* i32 = std::get_if<int32_t>(&value))
return static_cast<DestinationType>(*i32);
if (auto* u32 = std::get_if<uint32_t>(&value))
return static_cast<DestinationType>(*u32);
if (auto* abstractInt = std::get_if<int64_t>(&value))
return static_cast<DestinationType>(*abstractInt);
if (auto* f32 = std::get_if<float>(&value))
return static_cast<DestinationType>(*f32);
if (auto* f16 = std::get_if<half>(&value))
return static_cast<DestinationType>(*f16);
if (auto* abstractFloat = std::get_if<double>(&value))
return static_cast<DestinationType>(*abstractFloat);
if constexpr (std::is_same_v<Cast<DestinationType>, StaticCast<DestinationType>> || sizeof(DestinationType) == 4) {
if (auto* i32 = std::get_if<int32_t>(&value))
return Cast<DestinationType>::cast(*i32);
if (auto* u32 = std::get_if<uint32_t>(&value))
return Cast<DestinationType>::cast(*u32);
if (auto* f32 = std::get_if<float>(&value))
return Cast<DestinationType>::cast(*f32);
}

if constexpr (std::is_same_v<Cast<DestinationType>, StaticCast<DestinationType>> || sizeof(DestinationType) == 2) {
if (auto* f16 = std::get_if<half>(&value))
return Cast<DestinationType>::cast(*f16);
}

if constexpr (std::is_same_v<Cast<DestinationType>, StaticCast<DestinationType>>) {
if (auto* boolean = std::get_if<bool>(&value))
return Cast<DestinationType>::cast(*boolean);
if (auto* abstractInt = std::get_if<int64_t>(&value))
return Cast<DestinationType>::cast(*abstractInt);
if (auto* abstractFloat = std::get_if<double>(&value))
return Cast<DestinationType>::cast(*abstractFloat);
}
RELEASE_ASSERT_NOT_REACHED();
}

template<template <typename U> typename Cast = StaticCast>
static ConstantValue convertValue(const Type* targetType, ConstantValue value)
{
ASSERT(std::holds_alternative<Types::Primitive>(*targetType));
auto& primitive = std::get<Types::Primitive>(*targetType);
switch (primitive.kind) {
case Types::Primitive::AbstractInt:
return convertValue<int64_t>(value);
return convertValue<int64_t, Cast>(value);
case Types::Primitive::I32:
return convertValue<int32_t>(value);
return convertValue<int32_t, Cast>(value);
case Types::Primitive::U32:
return convertValue<uint32_t>(value);
return convertValue<uint32_t, Cast>(value);
case Types::Primitive::AbstractFloat:
return convertValue<double>(value);
return convertValue<double, Cast>(value);
case Types::Primitive::F32:
return convertValue<float>(value);
return convertValue<float, Cast>(value);
case Types::Primitive::F16:
return convertValue<half>(value);
return convertValue<half, Cast>(value);
case Types::Primitive::Bool:
return convertValue<bool>(value);
return convertValue<bool, Cast>(value);
default:
RELEASE_ASSERT_NOT_REACHED();
}
Expand Down Expand Up @@ -1604,22 +1625,23 @@ CONSTANT_FUNCTION(Bitcast)
value = 0;
}

result.elements[offset] = bitwise_cast<half>(static_cast<uint16_t>(value));
result.elements[offset + 1] = bitwise_cast<half>(static_cast<uint16_t>(value >> 16));
auto parts = bitwise_cast<std::array<half, 2>>(value);
result.elements[offset] = parts[0];
result.elements[offset + 1] = parts[1];
return std::nullopt;
};

const auto& join = [&](const Type* type, const ConstantVector& vector, unsigned offset) -> ConstantValue {
uint32_t value = 0;
value |= bitwise_cast<uint16_t>(std::get<half>(vector.elements[offset]));
value |= static_cast<uint32_t>(bitwise_cast<uint16_t>(std::get<half>(vector.elements[offset + 1]))) << 16;
return convertValue(type, value);
return convertValue<BitwiseCast>(type, value);
};

const auto& vectorVector = [&](const Types::Vector& dst, const ConstantVector& src) -> ConstantResult {
if (dst.size == src.elements.size()) {
return scalarOrVector([&](auto& value) {
return convertValue(dst.element, value);
return convertValue<BitwiseCast>(dst.element, value);
}, src);
}

Expand Down Expand Up @@ -1663,9 +1685,9 @@ CONSTANT_FUNCTION(Bitcast)
auto result = convertInteger<int32_t>(*abstractInt);
if (!result.has_value())
return makeUnexpected(makeString("value ", String::number(*abstractInt), " cannot be represented as 'i32'"));
return { convertValue(resultType, *result) };
return { convertValue<BitwiseCast>(resultType, *result) };
}
return { convertValue(resultType, argument) };
return { convertValue<BitwiseCast>(resultType, argument) };
}

// Type checker helpers
Expand Down

0 comments on commit 42abab7

Please sign in to comment.