Skip to content

Commit

Permalink
[WGSL] shader,validation,expression,unary,address_of_and_indirection:…
Browse files Browse the repository at this point in the history
…* is failing

https://bugs.webkit.org/show_bug.cgi?id=270095
rdar://123636841

Reviewed by Mike Wyrzykowski.

Move the logic for address of and indirection from a type declaration into the
type checker in order to separate it from the arithmetic operators, and implement
the additional validation:
- don't allow taking the address of a vector component
- don't allow taking the address of textures and samplers

* Source/WebGPU/WGSL/TypeCheck.cpp:
(WGSL::TypeChecker::visit):
* Source/WebGPU/WGSL/TypeDeclarations.rb:
* Source/WebGPU/WGSL/TypeStore.cpp:
(WGSL::ReferenceKey::encode const):
(WGSL::TypeStore::referenceType):
* Source/WebGPU/WGSL/TypeStore.h:
* Source/WebGPU/WGSL/Types.h:

Canonical link: https://commits.webkit.org/275373@main
  • Loading branch information
tadeuzagallo committed Feb 27, 2024
1 parent 3211e58 commit 8529f01
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 19 deletions.
51 changes: 45 additions & 6 deletions Source/WebGPU/WGSL/TypeCheck.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1002,7 +1002,7 @@ void TypeChecker::visit(AST::Expression&)

void TypeChecker::visit(AST::FieldAccessExpression& access)
{
const auto& accessImpl = [&](const Type* baseType, bool* canBeReference = nullptr) -> const Type* {
const auto& accessImpl = [&](const Type* baseType, bool* canBeReference = nullptr, bool* isVector = nullptr) -> const Type* {
if (isBottom(baseType))
return m_types.bottomType();

Expand Down Expand Up @@ -1037,6 +1037,8 @@ void TypeChecker::visit(AST::FieldAccessExpression& access)
if (std::holds_alternative<Types::Vector>(*baseType)) {
auto& vector = std::get<Types::Vector>(*baseType);
auto* result = vectorFieldAccess(vector, access);
if (isVector)
*isVector = true;
if (result && canBeReference)
*canBeReference = !std::holds_alternative<Types::Vector>(*result);
return result;
Expand All @@ -1049,9 +1051,10 @@ void TypeChecker::visit(AST::FieldAccessExpression& access)
auto* baseType = infer(access.base(), m_evaluation);
if (const auto* reference = std::get_if<Types::Reference>(baseType)) {
bool canBeReference = true;
if (const Type* result = accessImpl(reference->element, &canBeReference)) {
bool isVector = false;
if (const Type* result = accessImpl(reference->element, &canBeReference, &isVector)) {
if (canBeReference)
result = m_types.referenceType(reference->addressSpace, result, reference->accessMode);
result = m_types.referenceType(reference->addressSpace, result, reference->accessMode, isVector);
inferred(result);
}
return;
Expand Down Expand Up @@ -1086,7 +1089,7 @@ void TypeChecker::visit(AST::IndexAccessExpression& access)
access.setConstantValue(std::get<T>(*constantBase)[index]);
};

const auto& accessImpl = [&](const Type* base) -> const Type* {
const auto& accessImpl = [&](const Type* base, bool* isVector = nullptr) -> const Type* {
if (isBottom(base))
return m_types.bottomType();

Expand All @@ -1099,6 +1102,8 @@ void TypeChecker::visit(AST::IndexAccessExpression& access)
size = *constantSize;
constantAccess.operator()<ConstantArray>(size);
} else if (auto* vector = std::get_if<Types::Vector>(base)) {
if (isVector)
*isVector = true;
result = vector->element;
constantAccess.operator()<ConstantVector>(vector->size);
} else if (auto* matrix = std::get_if<Types::Matrix>(base)) {
Expand Down Expand Up @@ -1127,8 +1132,9 @@ void TypeChecker::visit(AST::IndexAccessExpression& access)
}

if (const auto* reference = std::get_if<Types::Reference>(base)) {
if (const Type* result = accessImpl(reference->element)) {
result = m_types.referenceType(reference->addressSpace, result, reference->accessMode);
bool isVector = false;
if (const Type* result = accessImpl(reference->element, &isVector)) {
result = m_types.referenceType(reference->addressSpace, result, reference->accessMode, isVector);
inferred(result);
}
return;
Expand Down Expand Up @@ -1549,6 +1555,39 @@ void TypeChecker::bitcast(AST::CallExpression& call, const Vector<const Type*>&

void TypeChecker::visit(AST::UnaryExpression& unary)
{
if (unary.operation() == AST::UnaryOperation::AddressOf) {
auto* type = infer(unary.expression(), Evaluation::Runtime);
auto* reference = std::get_if<Types::Reference>(type);
if (!reference) {
typeError(unary.span(), "cannot take address of expression");
return;
}

if (reference->addressSpace == AddressSpace::Handle) {
typeError(unary.span(), "cannot take the address of expression in handle address space");
return;
}

if (reference->isVectorComponent) {
typeError(unary.span(), "cannot take the address of a vector component");
return;
}

inferred(m_types.pointerType(reference->addressSpace, reference->element, reference->accessMode));
return;
}

if (unary.operation() == AST::UnaryOperation::Dereference) {
auto* type = infer(unary.expression(), Evaluation::Runtime);
auto* pointer = std::get_if<Types::Pointer>(type);
if (!pointer) {
typeError(unary.span(), "cannot dereference expression of type '", *type, "'");
return;
}

inferred(m_types.referenceType(pointer->addressSpace, pointer->element, pointer->accessMode));
return;
}
chooseOverload("operator", unary, toString(unary.operation()), ReferenceWrapperVector<AST::Expression, 1> { unary.expression() }, { });
}

Expand Down
8 changes: 0 additions & 8 deletions Source/WebGPU/WGSL/TypeDeclarations.rb
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,6 @@
must_use: true,
const: 'constantBitwiseAnd',

# unary
# FIXME: move this out of here
[AS, T, AM].(ref[AS, T, AM]) => ptr[AS, T, AM],

# binary
[].(bool, bool) => bool,
[N].(vec[N][bool], vec[N][bool]) => vec[N][bool],
Expand Down Expand Up @@ -98,10 +94,6 @@
must_use: true,
const: "constantMultiply",

# unary
# FIXME: move this out of here
[AS, T, AM].(ptr[AS, T, AM]) => ref[AS, T, AM],

# binary
[T < Number].(T, T) => T,

Expand Down
9 changes: 5 additions & 4 deletions Source/WebGPU/WGSL/TypeStore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,9 @@ struct ReferenceKey {
const Type* elementType;
AddressSpace addressSpace;
AccessMode accessMode;
bool isVectorComponent;

TypeCache::EncodedKey encode() const { return std::tuple(TypeCache::Reference, WTF::enumToUnderlyingType(addressSpace), WTF::enumToUnderlyingType(accessMode), 0, bitwise_cast<uintptr_t>(elementType)); }
TypeCache::EncodedKey encode() const { return std::tuple(TypeCache::Reference, WTF::enumToUnderlyingType(addressSpace), WTF::enumToUnderlyingType(accessMode), isVectorComponent, bitwise_cast<uintptr_t>(elementType)); }
};

struct PointerKey {
Expand Down Expand Up @@ -217,13 +218,13 @@ const Type* TypeStore::functionType(WTF::Vector<const Type*>&& parameters, const
return allocateType<Function>(WTFMove(parameters), result, mustUse);
}

const Type* TypeStore::referenceType(AddressSpace addressSpace, const Type* element, AccessMode accessMode)
const Type* TypeStore::referenceType(AddressSpace addressSpace, const Type* element, AccessMode accessMode, bool isVectorComponent)
{
ReferenceKey key { element, addressSpace, accessMode };
ReferenceKey key { element, addressSpace, accessMode, isVectorComponent };
const Type* type = m_cache.find(key);
if (type)
return type;
type = allocateType<Reference>(addressSpace, accessMode, element);
type = allocateType<Reference>(addressSpace, accessMode, element, isVectorComponent);
m_cache.insert(key, type);
return type;
}
Expand Down
2 changes: 1 addition & 1 deletion Source/WebGPU/WGSL/TypeStore.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ class TypeStore {
const Type* textureType(Types::Texture::Kind, const Type*);
const Type* textureStorageType(Types::TextureStorage::Kind, TexelFormat, AccessMode);
const Type* functionType(Vector<const Type*>&&, const Type*, bool mustUse);
const Type* referenceType(AddressSpace, const Type*, AccessMode);
const Type* referenceType(AddressSpace, const Type*, AccessMode, bool isVectorComponent = false);
const Type* pointerType(AddressSpace, const Type*, AccessMode);
const Type* atomicType(const Type*);
const Type* typeConstructorType(ASCIILiteral, std::function<const Type*(AST::ElaboratedTypeExpression&)>&&);
Expand Down
1 change: 1 addition & 0 deletions Source/WebGPU/WGSL/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ struct Reference {
AddressSpace addressSpace;
AccessMode accessMode;
const Type* element;
bool isVectorComponent;
};

struct Pointer {
Expand Down

0 comments on commit 8529f01

Please sign in to comment.