Skip to content

Commit

Permalink
[WebGPU] api,validation,shader_module,* is failing
Browse files Browse the repository at this point in the history
https://bugs.webkit.org/show_bug.cgi?id=267282
<radar://120726178>

Reviewed by Tadeu Zagallo.

Correct shader module validation and add passing expectations.

* LayoutTests/http/tests/webgpu/webgpu/api/validation/shader_module/entry_point-expected.txt:
* Source/WebCore/Modules/WebGPU/GPUProgrammableStage.h:
* Source/WebCore/Modules/WebGPU/Implementation/WebGPUDeviceImpl.cpp:
(WebCore::WebGPU::invalidEntryPointName):
(WebCore::WebGPU::convertToBacking):
* Source/WebCore/Modules/WebGPU/InternalAPI/WebGPUProgrammableStage.h:
* Source/WebGPU/WebGPU/ComputePipeline.mm:
(WebGPU::Device::createComputePipeline):
* Source/WebGPU/WebGPU/Pipeline.mm:
(WebGPU::createLibrary):
* Source/WebGPU/WebGPU/RenderPipeline.mm:
(WebGPU::Device::createRenderPipeline):
* Source/WebGPU/WebGPU/ShaderModule.h:
(WebGPU::ShaderModule::create):
* Source/WebGPU/WebGPU/ShaderModule.mm:
(WebGPU::earlyCompileShaderModule):
(WebGPU::handleShaderSuccessOrFailure):
(WebGPU::Device::createShaderModule):
(WebGPU::ShaderModule::ShaderModule):
(WebGPU::ShaderModule::fragmentInputsForEntryPoint const):
(WebGPU::ShaderModule::fragmentReturnTypeForEntryPoint const):
(WebGPU::ShaderModule::vertexReturnTypeForEntryPoint const):
(WebGPU::ShaderModule::stageInTypesForEntryPoint const):
(WebGPU::ShaderModule::entryPointInformation const):
(WebGPU::ShaderModule::transformedEntryPoint const):
* Source/WebKit/Shared/WebGPU/WebGPUProgrammableStage.h:
* Source/WebKit/Shared/WebGPU/WebGPUProgrammableStage.serialization.in:

Canonical link: https://commits.webkit.org/272865@main
  • Loading branch information
mwyrzykowski committed Jan 10, 2024
1 parent f568db0 commit 119ba1f
Show file tree
Hide file tree
Showing 11 changed files with 225 additions and 32 deletions.
Original file line number Diff line number Diff line change
@@ -1 +1,104 @@
(Populate me when we're ready to investigate this test)

PASS :compute:isAsync=true;shaderModuleEntryPoint="main";stageEntryPoint="main"
PASS :compute:isAsync=true;shaderModuleEntryPoint="main";stageEntryPoint=""
PASS :compute:isAsync=true;shaderModuleEntryPoint="main";stageEntryPoint="main%5Cu0000"
PASS :compute:isAsync=true;shaderModuleEntryPoint="main";stageEntryPoint="main%5Cu0000a"
PASS :compute:isAsync=true;shaderModuleEntryPoint="main";stageEntryPoint="mian"
PASS :compute:isAsync=true;shaderModuleEntryPoint="main";stageEntryPoint="main%20"
PASS :compute:isAsync=true;shaderModuleEntryPoint="main";stageEntryPoint="ma%20in"
PASS :compute:isAsync=true;shaderModuleEntryPoint="main";stageEntryPoint="main%5Cn"
PASS :compute:isAsync=true;shaderModuleEntryPoint="mian";stageEntryPoint="mian"
PASS :compute:isAsync=true;shaderModuleEntryPoint="mian";stageEntryPoint="main"
PASS :compute:isAsync=true;shaderModuleEntryPoint="mainmain";stageEntryPoint="mainmain"
PASS :compute:isAsync=true;shaderModuleEntryPoint="mainmain";stageEntryPoint="foo"
PASS :compute:isAsync=true;shaderModuleEntryPoint="main_t12V3";stageEntryPoint="main_t12V3"
PASS :compute:isAsync=true;shaderModuleEntryPoint="main_t12V3";stageEntryPoint="main_t12V5"
PASS :compute:isAsync=true;shaderModuleEntryPoint="main_t12V3";stageEntryPoint="_main_t12V3"
PASS :compute:isAsync=true;shaderModuleEntryPoint="s%C3%A9quen%C3%A7age";stageEntryPoint="s%C3%A9quen%C3%A7age"
PASS :compute:isAsync=true;shaderModuleEntryPoint="s%C3%A9quen%C3%A7age";stageEntryPoint="se%CC%81quen%C3%A7age"
PASS :compute:isAsync=false;shaderModuleEntryPoint="main";stageEntryPoint="main"
PASS :compute:isAsync=false;shaderModuleEntryPoint="main";stageEntryPoint=""
PASS :compute:isAsync=false;shaderModuleEntryPoint="main";stageEntryPoint="main%5Cu0000"
PASS :compute:isAsync=false;shaderModuleEntryPoint="main";stageEntryPoint="main%5Cu0000a"
PASS :compute:isAsync=false;shaderModuleEntryPoint="main";stageEntryPoint="mian"
PASS :compute:isAsync=false;shaderModuleEntryPoint="main";stageEntryPoint="main%20"
PASS :compute:isAsync=false;shaderModuleEntryPoint="main";stageEntryPoint="ma%20in"
PASS :compute:isAsync=false;shaderModuleEntryPoint="main";stageEntryPoint="main%5Cn"
PASS :compute:isAsync=false;shaderModuleEntryPoint="mian";stageEntryPoint="mian"
PASS :compute:isAsync=false;shaderModuleEntryPoint="mian";stageEntryPoint="main"
PASS :compute:isAsync=false;shaderModuleEntryPoint="mainmain";stageEntryPoint="mainmain"
PASS :compute:isAsync=false;shaderModuleEntryPoint="mainmain";stageEntryPoint="foo"
PASS :compute:isAsync=false;shaderModuleEntryPoint="main_t12V3";stageEntryPoint="main_t12V3"
PASS :compute:isAsync=false;shaderModuleEntryPoint="main_t12V3";stageEntryPoint="main_t12V5"
PASS :compute:isAsync=false;shaderModuleEntryPoint="main_t12V3";stageEntryPoint="_main_t12V3"
PASS :compute:isAsync=false;shaderModuleEntryPoint="s%C3%A9quen%C3%A7age";stageEntryPoint="s%C3%A9quen%C3%A7age"
PASS :compute:isAsync=false;shaderModuleEntryPoint="s%C3%A9quen%C3%A7age";stageEntryPoint="se%CC%81quen%C3%A7age"
PASS :vertex:isAsync=true;shaderModuleEntryPoint="main";stageEntryPoint="main"
PASS :vertex:isAsync=true;shaderModuleEntryPoint="main";stageEntryPoint=""
PASS :vertex:isAsync=true;shaderModuleEntryPoint="main";stageEntryPoint="main%5Cu0000"
PASS :vertex:isAsync=true;shaderModuleEntryPoint="main";stageEntryPoint="main%5Cu0000a"
PASS :vertex:isAsync=true;shaderModuleEntryPoint="main";stageEntryPoint="mian"
PASS :vertex:isAsync=true;shaderModuleEntryPoint="main";stageEntryPoint="main%20"
PASS :vertex:isAsync=true;shaderModuleEntryPoint="main";stageEntryPoint="ma%20in"
PASS :vertex:isAsync=true;shaderModuleEntryPoint="main";stageEntryPoint="main%5Cn"
PASS :vertex:isAsync=true;shaderModuleEntryPoint="mian";stageEntryPoint="mian"
PASS :vertex:isAsync=true;shaderModuleEntryPoint="mian";stageEntryPoint="main"
PASS :vertex:isAsync=true;shaderModuleEntryPoint="mainmain";stageEntryPoint="mainmain"
PASS :vertex:isAsync=true;shaderModuleEntryPoint="mainmain";stageEntryPoint="foo"
PASS :vertex:isAsync=true;shaderModuleEntryPoint="main_t12V3";stageEntryPoint="main_t12V3"
PASS :vertex:isAsync=true;shaderModuleEntryPoint="main_t12V3";stageEntryPoint="main_t12V5"
PASS :vertex:isAsync=true;shaderModuleEntryPoint="main_t12V3";stageEntryPoint="_main_t12V3"
PASS :vertex:isAsync=true;shaderModuleEntryPoint="s%C3%A9quen%C3%A7age";stageEntryPoint="s%C3%A9quen%C3%A7age"
PASS :vertex:isAsync=true;shaderModuleEntryPoint="s%C3%A9quen%C3%A7age";stageEntryPoint="se%CC%81quen%C3%A7age"
PASS :vertex:isAsync=false;shaderModuleEntryPoint="main";stageEntryPoint="main"
PASS :vertex:isAsync=false;shaderModuleEntryPoint="main";stageEntryPoint=""
PASS :vertex:isAsync=false;shaderModuleEntryPoint="main";stageEntryPoint="main%5Cu0000"
PASS :vertex:isAsync=false;shaderModuleEntryPoint="main";stageEntryPoint="main%5Cu0000a"
PASS :vertex:isAsync=false;shaderModuleEntryPoint="main";stageEntryPoint="mian"
PASS :vertex:isAsync=false;shaderModuleEntryPoint="main";stageEntryPoint="main%20"
PASS :vertex:isAsync=false;shaderModuleEntryPoint="main";stageEntryPoint="ma%20in"
PASS :vertex:isAsync=false;shaderModuleEntryPoint="main";stageEntryPoint="main%5Cn"
PASS :vertex:isAsync=false;shaderModuleEntryPoint="mian";stageEntryPoint="mian"
PASS :vertex:isAsync=false;shaderModuleEntryPoint="mian";stageEntryPoint="main"
PASS :vertex:isAsync=false;shaderModuleEntryPoint="mainmain";stageEntryPoint="mainmain"
PASS :vertex:isAsync=false;shaderModuleEntryPoint="mainmain";stageEntryPoint="foo"
PASS :vertex:isAsync=false;shaderModuleEntryPoint="main_t12V3";stageEntryPoint="main_t12V3"
PASS :vertex:isAsync=false;shaderModuleEntryPoint="main_t12V3";stageEntryPoint="main_t12V5"
PASS :vertex:isAsync=false;shaderModuleEntryPoint="main_t12V3";stageEntryPoint="_main_t12V3"
PASS :vertex:isAsync=false;shaderModuleEntryPoint="s%C3%A9quen%C3%A7age";stageEntryPoint="s%C3%A9quen%C3%A7age"
PASS :vertex:isAsync=false;shaderModuleEntryPoint="s%C3%A9quen%C3%A7age";stageEntryPoint="se%CC%81quen%C3%A7age"
PASS :fragment:isAsync=true;shaderModuleEntryPoint="main";stageEntryPoint="main"
PASS :fragment:isAsync=true;shaderModuleEntryPoint="main";stageEntryPoint=""
PASS :fragment:isAsync=true;shaderModuleEntryPoint="main";stageEntryPoint="main%5Cu0000"
PASS :fragment:isAsync=true;shaderModuleEntryPoint="main";stageEntryPoint="main%5Cu0000a"
PASS :fragment:isAsync=true;shaderModuleEntryPoint="main";stageEntryPoint="mian"
PASS :fragment:isAsync=true;shaderModuleEntryPoint="main";stageEntryPoint="main%20"
PASS :fragment:isAsync=true;shaderModuleEntryPoint="main";stageEntryPoint="ma%20in"
PASS :fragment:isAsync=true;shaderModuleEntryPoint="main";stageEntryPoint="main%5Cn"
PASS :fragment:isAsync=true;shaderModuleEntryPoint="mian";stageEntryPoint="mian"
PASS :fragment:isAsync=true;shaderModuleEntryPoint="mian";stageEntryPoint="main"
PASS :fragment:isAsync=true;shaderModuleEntryPoint="mainmain";stageEntryPoint="mainmain"
PASS :fragment:isAsync=true;shaderModuleEntryPoint="mainmain";stageEntryPoint="foo"
PASS :fragment:isAsync=true;shaderModuleEntryPoint="main_t12V3";stageEntryPoint="main_t12V3"
PASS :fragment:isAsync=true;shaderModuleEntryPoint="main_t12V3";stageEntryPoint="main_t12V5"
PASS :fragment:isAsync=true;shaderModuleEntryPoint="main_t12V3";stageEntryPoint="_main_t12V3"
PASS :fragment:isAsync=true;shaderModuleEntryPoint="s%C3%A9quen%C3%A7age";stageEntryPoint="s%C3%A9quen%C3%A7age"
PASS :fragment:isAsync=true;shaderModuleEntryPoint="s%C3%A9quen%C3%A7age";stageEntryPoint="se%CC%81quen%C3%A7age"
PASS :fragment:isAsync=false;shaderModuleEntryPoint="main";stageEntryPoint="main"
PASS :fragment:isAsync=false;shaderModuleEntryPoint="main";stageEntryPoint=""
PASS :fragment:isAsync=false;shaderModuleEntryPoint="main";stageEntryPoint="main%5Cu0000"
PASS :fragment:isAsync=false;shaderModuleEntryPoint="main";stageEntryPoint="main%5Cu0000a"
PASS :fragment:isAsync=false;shaderModuleEntryPoint="main";stageEntryPoint="mian"
PASS :fragment:isAsync=false;shaderModuleEntryPoint="main";stageEntryPoint="main%20"
PASS :fragment:isAsync=false;shaderModuleEntryPoint="main";stageEntryPoint="ma%20in"
PASS :fragment:isAsync=false;shaderModuleEntryPoint="main";stageEntryPoint="main%5Cn"
PASS :fragment:isAsync=false;shaderModuleEntryPoint="mian";stageEntryPoint="mian"
PASS :fragment:isAsync=false;shaderModuleEntryPoint="mian";stageEntryPoint="main"
PASS :fragment:isAsync=false;shaderModuleEntryPoint="mainmain";stageEntryPoint="mainmain"
PASS :fragment:isAsync=false;shaderModuleEntryPoint="mainmain";stageEntryPoint="foo"
PASS :fragment:isAsync=false;shaderModuleEntryPoint="main_t12V3";stageEntryPoint="main_t12V3"
PASS :fragment:isAsync=false;shaderModuleEntryPoint="main_t12V3";stageEntryPoint="main_t12V5"
PASS :fragment:isAsync=false;shaderModuleEntryPoint="main_t12V3";stageEntryPoint="_main_t12V3"
PASS :fragment:isAsync=false;shaderModuleEntryPoint="s%C3%A9quen%C3%A7age";stageEntryPoint="s%C3%A9quen%C3%A7age"
PASS :fragment:isAsync=false;shaderModuleEntryPoint="s%C3%A9quen%C3%A7age";stageEntryPoint="se%CC%81quen%C3%A7age"

2 changes: 1 addition & 1 deletion Source/WebCore/Modules/WebGPU/GPUProgrammableStage.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ struct GPUProgrammableStage {
}

GPUShaderModule* module { nullptr };
String entryPoint;
std::optional<String> entryPoint;
Vector<KeyValuePair<String, GPUPipelineConstantValue>> constants;
};

Expand Down
34 changes: 27 additions & 7 deletions Source/WebCore/Modules/WebGPU/Implementation/WebGPUDeviceImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,11 @@

namespace WebCore::WebGPU {

static auto invalidEntryPointName()
{
return CString("");
}

DeviceImpl::DeviceImpl(WebGPUPtr<WGPUDevice>&& device, Ref<SupportedFeatures>&& features, Ref<SupportedLimits>&& limits, ConvertToBackingContext& convertToBackingContext)
: Device(WTFMove(features), WTFMove(limits))
, m_backing(device.copyRef())
Expand Down Expand Up @@ -318,7 +323,12 @@ static auto convertToBacking(const ComputePipelineDescriptor& descriptor, Conver
{
auto label = descriptor.label.utf8();

auto entryPoint = descriptor.compute.entryPoint.utf8();
std::optional<CString> entryPoint;
if (auto& descriptorEntryPoint = descriptor.compute.entryPoint) {
entryPoint = descriptorEntryPoint->utf8();
if (descriptorEntryPoint->length() != String::fromUTF8(entryPoint->data()).length())
entryPoint = invalidEntryPointName();
}

auto constantNames = descriptor.compute.constants.map([](const auto& constant) {
bool lengthsMatch = constant.key.length() == String::fromUTF8(constant.key.utf8().data()).length();
Expand All @@ -340,7 +350,7 @@ static auto convertToBacking(const ComputePipelineDescriptor& descriptor, Conver
descriptor.layout ? convertToBackingContext.convertToBacking(*descriptor.layout) : nullptr, {
nullptr,
convertToBackingContext.convertToBacking(descriptor.compute.module),
entryPoint.data(),
entryPoint ? entryPoint->data() : nullptr,
static_cast<uint32_t>(backingConstantEntries.size()),
backingConstantEntries.data(),
}
Expand All @@ -361,7 +371,12 @@ static auto convertToBacking(const RenderPipelineDescriptor& descriptor, Convert
{
auto label = descriptor.label.utf8();

auto vertexEntryPoint = descriptor.vertex.entryPoint.utf8();
std::optional<CString> vertexEntryPoint;
if (auto& descriptorEntryPoint = descriptor.vertex.entryPoint) {
vertexEntryPoint = descriptorEntryPoint->utf8();
if (descriptorEntryPoint->length() != String::fromUTF8(vertexEntryPoint->data()).length())
vertexEntryPoint = invalidEntryPointName();
}

auto vertexConstantNames = descriptor.vertex.constants.map([](const auto& constant) {
bool lengthsMatch = constant.key.length() == String::fromUTF8(constant.key.utf8().data()).length();
Expand Down Expand Up @@ -424,10 +439,15 @@ static auto convertToBacking(const RenderPipelineDescriptor& descriptor, Convert
.depthBiasClamp = descriptor.depthStencil ? descriptor.depthStencil->depthBiasClamp : 0,
};

auto fragmentEntryPoint = descriptor.fragment ? descriptor.fragment->entryPoint.utf8() : CString("");

std::optional<CString> fragmentEntryPoint;
Vector<CString> fragmentConstantNames;
if (descriptor.fragment) {
if (auto& descriptorEntryPoint = descriptor.fragment->entryPoint) {
fragmentEntryPoint = descriptorEntryPoint->utf8();
if (descriptorEntryPoint->length() != String::fromUTF8(descriptor.fragment->entryPoint->utf8().data()).length())
fragmentEntryPoint = invalidEntryPointName();
}

fragmentConstantNames = descriptor.fragment->constants.map([](const auto& constant) {
bool lengthsMatch = constant.key.length() == String::fromUTF8(constant.key.utf8().data()).length();
return lengthsMatch ? constant.key.utf8() : "";
Expand Down Expand Up @@ -489,7 +509,7 @@ static auto convertToBacking(const RenderPipelineDescriptor& descriptor, Convert
WGPUFragmentState fragmentState {
nullptr,
descriptor.fragment ? convertToBackingContext.convertToBacking(descriptor.fragment->module) : nullptr,
fragmentEntryPoint.data(),
fragmentEntryPoint ? fragmentEntryPoint->data() : nullptr,
static_cast<uint32_t>(fragmentConstantEntries.size()),
fragmentConstantEntries.data(),
static_cast<uint32_t>(colorTargets.size()),
Expand All @@ -510,7 +530,7 @@ static auto convertToBacking(const RenderPipelineDescriptor& descriptor, Convert
descriptor.layout ? convertToBackingContext.convertToBacking(*descriptor.layout) : nullptr, {
nullptr,
convertToBackingContext.convertToBacking(descriptor.vertex.module),
vertexEntryPoint.data(),
vertexEntryPoint ? vertexEntryPoint->data() : nullptr,
static_cast<uint32_t>(vertexConstantEntries.size()),
vertexConstantEntries.data(),
static_cast<uint32_t>(backingBuffers.size()),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ using PipelineConstantValue = double; // May represent WGSL’s bool, f32, i32,

struct ProgrammableStage {
ShaderModule& module;
String entryPoint;
std::optional<String> entryPoint;
Vector<KeyValuePair<String, PipelineConstantValue>> constants;
};

Expand Down
3 changes: 1 addition & 2 deletions Source/WebGPU/WebGPU/ComputePipeline.mm
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,7 @@ static MTLSize metalSize(auto workgroupSize, const HashMap<String, WGSL::Constan

PipelineLayout& pipelineLayout = WebGPU::fromAPI(descriptor.layout);
auto label = fromAPI(descriptor.label);
auto entryPoint = fromAPI(descriptor.compute.entryPoint);
auto libraryCreationResult = createLibrary(m_device, shaderModule, &pipelineLayout, entryPoint.length() ? entryPoint : shaderModule.defaultComputeEntryPoint(), label);
auto libraryCreationResult = createLibrary(m_device, shaderModule, &pipelineLayout, descriptor.compute.entryPoint ? fromAPI(descriptor.compute.entryPoint) : shaderModule.defaultComputeEntryPoint(), label);
if (!libraryCreationResult || &pipelineLayout.device() != this)
return returnInvalidComputePipeline(*this, isAsync);

Expand Down
5 changes: 4 additions & 1 deletion Source/WebGPU/WebGPU/Pipeline.mm
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,10 @@

namespace WebGPU {

std::optional<LibraryCreationResult> createLibrary(id<MTLDevice> device, const ShaderModule& shaderModule, const PipelineLayout* pipelineLayout, const String& entryPoint, NSString *label)
std::optional<LibraryCreationResult> createLibrary(id<MTLDevice> device, const ShaderModule& shaderModule, const PipelineLayout* pipelineLayout, const String& untransformedEntryPoint, NSString *label)
{
// FIXME: Remove below line when https://bugs.webkit.org/show_bug.cgi?id=266774 is completed
auto entryPoint = shaderModule.transformedEntryPoint(untransformedEntryPoint);
if (!entryPoint.length() || !shaderModule.isValid())
return std::nullopt;

Expand All @@ -58,6 +60,7 @@
auto iterator = prepareResult.entryPoints.find(entryPoint);
if (iterator == prepareResult.entryPoints.end())
return std::nullopt;

const auto& entryPointInformation = iterator->value;

return { { library, entryPointInformation } };
Expand Down
7 changes: 2 additions & 5 deletions Source/WebGPU/WebGPU/RenderPipeline.mm
Original file line number Diff line number Diff line change
Expand Up @@ -1255,8 +1255,7 @@ static uint32_t componentsForDataType(MTLDataType dataType)
if (&vertexModule.device() != this)
return returnInvalidRenderPipeline(*this, isAsync, "Vertex module was created with a different device"_s);

const auto& vertexFunctionName = fromAPI(descriptor.vertex.entryPoint);
const auto& vertexEntryPoint = vertexFunctionName.length() ? vertexFunctionName : vertexModule.defaultVertexEntryPoint();
const auto& vertexEntryPoint = descriptor.vertex.entryPoint ? fromAPI(descriptor.vertex.entryPoint) : vertexModule.defaultVertexEntryPoint();
auto libraryCreationResult = createLibrary(m_device, vertexModule, pipelineLayout, vertexEntryPoint, label);
if (!libraryCreationResult)
return returnInvalidRenderPipeline(*this, isAsync, "Vertex library failed creation"_s);
Expand Down Expand Up @@ -1298,9 +1297,7 @@ static uint32_t componentsForDataType(MTLDataType dataType)
RELEASE_ASSERT(fragmentShaderModule);
usesFragDepth = fragmentShaderModule->usesFragDepth();
usesSampleMask = fragmentShaderModule->usesSampleMask();
const auto& fragmentFunctionName = fromAPI(fragmentDescriptor.entryPoint);

const auto& fragmentEntryPoint = fragmentFunctionName.length() ? fragmentFunctionName : fragmentModule.defaultFragmentEntryPoint();
const auto& fragmentEntryPoint = fragmentDescriptor.entryPoint ? fromAPI(fragmentDescriptor.entryPoint) : fragmentModule.defaultFragmentEntryPoint();
auto libraryCreationResult = createLibrary(m_device, fragmentModule, pipelineLayout, fragmentEntryPoint, label);
if (!libraryCreationResult)
return returnInvalidRenderPipeline(*this, isAsync, "Fragment library could not be created"_s);
Expand Down
Loading

0 comments on commit 119ba1f

Please sign in to comment.