Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WebGPU] Support tier1 level argument buffers #8275

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion Source/WTF/wtf/PlatformHave.h
Original file line number Diff line number Diff line change
Expand Up @@ -1145,7 +1145,6 @@
#endif

#if HAVE(WEBGPU_IMPLEMENTATION) && ((PLATFORM(MAC) && __MAC_OS_X_VERSION_MIN_REQUIRED >= 130000) || (PLATFORM(IOS) && __IPHONE_OS_VERSION_MIN_REQUIRED >= 160000))
#define HAVE_TIER2_ARGUMENT_BUFFERS 1
#define HAVE_METAL_BUFFER_BINDING_REFLECTION 1
#endif

Expand Down
172 changes: 44 additions & 128 deletions Source/WebGPU/WebGPU/BindGroup.mm
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#import "Device.h"
#import "Sampler.h"
#import "TextureView.h"
#import <wtf/EnumeratedArray.h>

namespace WebGPU {

Expand All @@ -50,127 +51,41 @@ static bool textureViewIsPresent(const WGPUBindGroupEntry& entry)
return entry.textureView;
}

#if HAVE(TIER2_ARGUMENT_BUFFERS)
static auto sizeOfEntries(const WGPUBindGroupDescriptor& descriptor, BindGroupLayout* bindGroupLayout)
static MTLRenderStages metalRenderStage(ShaderStage shaderStage)
{
uint32_t sizes[] = { 0, 0, 0 };
for (uint32_t entryIndex = 0; entryIndex < descriptor.entryCount; ++entryIndex) {
const WGPUBindGroupEntry& entry = descriptor.entries[entryIndex];

auto stages = bindGroupLayout ? bindGroupLayout->stagesForBinding(entry.binding) : (WGPUShaderStage_Vertex | WGPUShaderStage_Fragment | WGPUShaderStage_Compute);
constexpr WGPUShaderStage shaderStage[] = { WGPUShaderStage_Vertex, WGPUShaderStage_Fragment, WGPUShaderStage_Compute };
bool bufferIsPresent = WebGPU::bufferIsPresent(entry);
bool samplerIsPresent = WebGPU::samplerIsPresent(entry);
bool textureViewIsPresent = WebGPU::textureViewIsPresent(entry);

for (size_t currentStage = 0; currentStage < std::size(shaderStage); ++currentStage) {
WGPUShaderStage renderStage = shaderStage[currentStage];
if (!(stages & renderStage))
continue;

if (bufferIsPresent)
sizes[currentStage] += sizeof(float*);
else if (samplerIsPresent || textureViewIsPresent)
sizes[currentStage] += sizeof(MTLResourceID);
}
switch (shaderStage) {
case ShaderStage::Vertex:
return MTLRenderStageVertex;
case ShaderStage::Fragment:
return MTLRenderStageFragment;
case ShaderStage::Compute:
return (MTLRenderStages)0;
}

return std::array<uint32_t, 3>( { sizes[0], sizes[1], sizes[2] } );
}
#endif

template <typename T>
using ShaderStageArray = EnumeratedArray<ShaderStage, T, ShaderStage::Compute>;

Ref<BindGroup> Device::createBindGroup(const WGPUBindGroupDescriptor& descriptor)
{
if (descriptor.nextInChain)
return BindGroup::createInvalid(*this);

constexpr ShaderStage stages[] = { ShaderStage::Vertex, ShaderStage::Fragment, ShaderStage::Compute };
constexpr size_t stageCount = std::size(stages);
ShaderStageArray<NSUInteger> bindingIndexForStage = std::array<NSUInteger, stageCount>();
const auto& bindGroupLayout = WebGPU::fromAPI(descriptor.layout);
Vector<BindableResource> resources;
#if HAVE(TIER2_ARGUMENT_BUFFERS)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wow, so much red!

if ([m_device argumentBuffersSupport] != MTLArgumentBuffersTier1) {
uint32_t entryCount = descriptor.entryCount;
auto bindGroupLayout = descriptor.layout ? &WebGPU::fromAPI(descriptor.layout) : nullptr;
auto bufferSizes = sizeOfEntries(descriptor, bindGroupLayout);
constexpr WGPUShaderStage shaderStage[] = { WGPUShaderStage_Vertex, WGPUShaderStage_Fragment, WGPUShaderStage_Compute };
constexpr auto shaderStageLength = std::size(shaderStage);
id<MTLBuffer> argumentBuffer[shaderStageLength];
for (size_t i = 0; i < shaderStageLength; ++i)
argumentBuffer[i] = bufferSizes[i] ? safeCreateBuffer(bufferSizes[i], MTLStorageModeShared) : nil;

char* argumentBufferContents[] = { (char*)argumentBuffer[0].contents, (char*)argumentBuffer[1].contents, (char*)argumentBuffer[2].contents };
size_t bindingOffsetForStage[] = { 0, 0, 0 };
for (uint32_t i = 0; i < entryCount; ++i) {
const WGPUBindGroupEntry& entry = descriptor.entries[i];

if (entry.nextInChain)
return BindGroup::createInvalid(*this);

bool bufferIsPresent = WebGPU::bufferIsPresent(entry);
bool samplerIsPresent = WebGPU::samplerIsPresent(entry);
bool textureViewIsPresent = WebGPU::textureViewIsPresent(entry);
if (bufferIsPresent + samplerIsPresent + textureViewIsPresent != 1)
return BindGroup::createInvalid(*this);

auto stages = bindGroupLayout ? bindGroupLayout->stagesForBinding(entry.binding) : (WGPUShaderStage_Vertex | WGPUShaderStage_Fragment | WGPUShaderStage_Compute);
for (size_t currentStage = 0; currentStage < shaderStageLength; ++currentStage) {
WGPUShaderStage renderStage = shaderStage[currentStage];
if (!(stages & renderStage))
continue;

auto& offset = bindingOffsetForStage[currentStage];
ASSERT(argumentBufferContents[currentStage]);

if (bufferIsPresent) {
id<MTLBuffer> buffer = WebGPU::fromAPI(entry.buffer).buffer();
if (entry.offset > buffer.length)
return BindGroup::createInvalid(*this);

ASSERT(sizeof(float*) == sizeof(buffer.gpuAddress));
*(float**)(argumentBufferContents[currentStage] + offset) = (float*)((char*)buffer.gpuAddress + entry.offset);
offset += sizeof(MTLResourceID);
} else if (samplerIsPresent) {
id<MTLSamplerState> sampler = WebGPU::fromAPI(entry.sampler).samplerState();
ASSERT(sizeof(float*) == sizeof(sampler.gpuResourceID));
*(MTLResourceID*)(argumentBufferContents[currentStage] + offset) = sampler.gpuResourceID;
offset += sizeof(MTLResourceID);
} else if (textureViewIsPresent) {
id<MTLTexture> texture = WebGPU::fromAPI(entry.textureView).texture();
ASSERT(sizeof(float*) == sizeof(texture.gpuResourceID));
*(MTLResourceID*)(argumentBufferContents[currentStage] + offset) = texture.gpuResourceID;
resources.append({ texture, MTLResourceUsageRead, renderStage });
offset += sizeof(MTLResourceID);
}
}
}

return BindGroup::create(argumentBuffer[0], argumentBuffer[1], argumentBuffer[2], WTFMove(resources), *this);
ShaderStageArray<id<MTLArgumentEncoder>> argumentEncoder = std::array<id<MTLArgumentEncoder>, stageCount>({ bindGroupLayout.vertexArgumentEncoder(), bindGroupLayout.fragmentArgumentEncoder(), bindGroupLayout.computeArgumentEncoder() });
ShaderStageArray<id<MTLBuffer>> argumentBuffer;
for (ShaderStage stage : stages) {
auto encodedLength = bindGroupLayout.encodedLength(stage);
argumentBuffer[stage] = encodedLength ? safeCreateBuffer(encodedLength, MTLStorageModeShared) : nil;
[argumentEncoder[stage] setArgumentBuffer:argumentBuffer[stage] offset:0];
}
#endif // HAVE(TIER2_ARGUMENT_BUFFERS)

// FIXME: Validate this according to the spec.

const BindGroupLayout& bindGroupLayout = WebGPU::fromAPI(descriptor.layout);

// FIXME(PERFORMANCE): Don't allocate 3 new buffers for every bind group.
// In fact, don't even allocate a single new buffer for every bind group.
id<MTLBuffer> vertexArgumentBuffer = safeCreateBuffer(bindGroupLayout.encodedLength(), MTLStorageModeShared);
id<MTLBuffer> fragmentArgumentBuffer = safeCreateBuffer(bindGroupLayout.encodedLength(), MTLStorageModeShared);
id<MTLBuffer> computeArgumentBuffer = safeCreateBuffer(bindGroupLayout.encodedLength(), MTLStorageModeShared);
if (!vertexArgumentBuffer || !fragmentArgumentBuffer || !computeArgumentBuffer)
return BindGroup::createInvalid(*this);

auto label = fromAPI(descriptor.label);
vertexArgumentBuffer.label = label;
fragmentArgumentBuffer.label = label;
computeArgumentBuffer.label = label;

id<MTLArgumentEncoder> vertexArgumentEncoder = bindGroupLayout.vertexArgumentEncoder();
id<MTLArgumentEncoder> fragmentArgumentEncoder = bindGroupLayout.fragmentArgumentEncoder();
id<MTLArgumentEncoder> computeArgumentEncoder = bindGroupLayout.computeArgumentEncoder();
[vertexArgumentEncoder setArgumentBuffer:vertexArgumentBuffer offset:0];
[fragmentArgumentEncoder setArgumentBuffer:fragmentArgumentBuffer offset:0];
[computeArgumentEncoder setArgumentBuffer:computeArgumentBuffer offset:0];

for (uint32_t i = 0; i < descriptor.entryCount; ++i) {
for (uint32_t i = 0, entryCount = descriptor.entryCount; i < entryCount; ++i) {
const WGPUBindGroupEntry& entry = descriptor.entries[i];

if (entry.nextInChain)
Expand All @@ -179,31 +94,32 @@ static auto sizeOfEntries(const WGPUBindGroupDescriptor& descriptor, BindGroupLa
bool bufferIsPresent = WebGPU::bufferIsPresent(entry);
bool samplerIsPresent = WebGPU::samplerIsPresent(entry);
bool textureViewIsPresent = WebGPU::textureViewIsPresent(entry);
if (static_cast<int>(bufferIsPresent) + static_cast<int>(samplerIsPresent) + static_cast<int>(textureViewIsPresent) != 1)
if (bufferIsPresent + samplerIsPresent + textureViewIsPresent != 1)
return BindGroup::createInvalid(*this);

if (bufferIsPresent) {
id<MTLBuffer> buffer = WebGPU::fromAPI(entry.buffer).buffer();
[vertexArgumentEncoder setBuffer:buffer offset:static_cast<NSUInteger>(entry.offset) atIndex:entry.binding];
[fragmentArgumentEncoder setBuffer:buffer offset:static_cast<NSUInteger>(entry.offset) atIndex:entry.binding];
[computeArgumentEncoder setBuffer:buffer offset:static_cast<NSUInteger>(entry.offset) atIndex:entry.binding];
} else if (samplerIsPresent) {
id<MTLSamplerState> sampler = WebGPU::fromAPI(entry.sampler).samplerState();
[vertexArgumentEncoder setSamplerState:sampler atIndex:entry.binding];
[fragmentArgumentEncoder setSamplerState:sampler atIndex:entry.binding];
[computeArgumentEncoder setSamplerState:sampler atIndex:entry.binding];
} else if (textureViewIsPresent) {
id<MTLTexture> texture = WebGPU::fromAPI(entry.textureView).texture();
[vertexArgumentEncoder setTexture:texture atIndex:entry.binding];
[fragmentArgumentEncoder setTexture:texture atIndex:entry.binding];
[computeArgumentEncoder setTexture:texture atIndex:entry.binding];
} else {
ASSERT_NOT_REACHED();
return BindGroup::createInvalid(*this);
for (ShaderStage stage : stages) {
if (!bindGroupLayout.bindingContainsStage(entry.binding, stage))
continue;

auto& index = bindingIndexForStage[stage];
if (bufferIsPresent) {
id<MTLBuffer> buffer = WebGPU::fromAPI(entry.buffer).buffer();
if (entry.offset > buffer.length)
return BindGroup::createInvalid(*this);

[argumentEncoder[stage] setBuffer:buffer offset:entry.offset atIndex:index++];
} else if (samplerIsPresent) {
id<MTLSamplerState> sampler = WebGPU::fromAPI(entry.sampler).samplerState();
[argumentEncoder[stage] setSamplerState:sampler atIndex:index++];
} else if (textureViewIsPresent) {
id<MTLTexture> texture = WebGPU::fromAPI(entry.textureView).texture();
[argumentEncoder[stage] setTexture:texture atIndex:index++];
resources.append({ texture, MTLResourceUsageRead, metalRenderStage(stage) });
}
}
}

return BindGroup::create(vertexArgumentBuffer, fragmentArgumentBuffer, computeArgumentBuffer, WTFMove(resources), *this);
return BindGroup::create(argumentBuffer[ShaderStage::Vertex], argumentBuffer[ShaderStage::Fragment], argumentBuffer[ShaderStage::Compute], WTFMove(resources), *this);
}

BindGroup::BindGroup(id<MTLBuffer> vertexArgumentBuffer, id<MTLBuffer> fragmentArgumentBuffer, id<MTLBuffer> computeArgumentBuffer, Vector<BindableResource>&& resources, Device& device)
Expand Down
27 changes: 14 additions & 13 deletions Source/WebGPU/WebGPU/BindGroupLayout.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,19 +35,21 @@ struct WGPUBindGroupLayoutImpl {

namespace WebGPU {

enum class ShaderStage {
Vertex = 0,
Fragment = 1,
Compute = 2
};

class Device;

// https://gpuweb.github.io/gpuweb/#gpubindgrouplayout
class BindGroupLayout : public WGPUBindGroupLayoutImpl, public RefCounted<BindGroupLayout> {
WTF_MAKE_FAST_ALLOCATED;
public:
static Ref<BindGroupLayout> create(id<MTLArgumentEncoder> vertexArgumentEncoder, id<MTLArgumentEncoder> fragmentArgumentEncoder, id<MTLArgumentEncoder> computeArgumentEncoder)
{
return adoptRef(*new BindGroupLayout(vertexArgumentEncoder, fragmentArgumentEncoder, computeArgumentEncoder));
}
static Ref<BindGroupLayout> create(HashMap<uint32_t, WGPUShaderStageFlags>&& stageMapTable)
static Ref<BindGroupLayout> create(HashMap<uint32_t, WGPUShaderStageFlags>&& stageMapTable, id<MTLArgumentEncoder> vertexArgumentEncoder, id<MTLArgumentEncoder> fragmentArgumentEncoder, id<MTLArgumentEncoder> computeArgumentEncoder)
{
return adoptRef(*new BindGroupLayout(WTFMove(stageMapTable)));
return adoptRef(*new BindGroupLayout(WTFMove(stageMapTable), vertexArgumentEncoder, fragmentArgumentEncoder, computeArgumentEncoder));
}
static Ref<BindGroupLayout> createInvalid(Device&)
{
Expand All @@ -58,26 +60,25 @@ class BindGroupLayout : public WGPUBindGroupLayoutImpl, public RefCounted<BindGr

void setLabel(String&&);

bool isValid() const { return m_shaderStageForBinding.size() || m_vertexArgumentEncoder || m_fragmentArgumentEncoder || m_computeArgumentEncoder; }
bool isValid() const { return m_shaderStageForBinding.size(); }

NSUInteger encodedLength() const;
NSUInteger encodedLength(ShaderStage) const;

id<MTLArgumentEncoder> vertexArgumentEncoder() const { return m_vertexArgumentEncoder; }
id<MTLArgumentEncoder> fragmentArgumentEncoder() const { return m_fragmentArgumentEncoder; }
id<MTLArgumentEncoder> computeArgumentEncoder() const { return m_computeArgumentEncoder; }

uint32_t stagesForBinding(uint32_t binding) const;
bool bindingContainsStage(uint32_t bindingIndex, ShaderStage renderStage) const;

private:
BindGroupLayout(id<MTLArgumentEncoder> vertexArgumentEncoder, id<MTLArgumentEncoder> fragmentArgumentEncoder, id<MTLArgumentEncoder> computeArgumentEncoder);
BindGroupLayout(HashMap<uint32_t, WGPUShaderStageFlags>&&);
BindGroupLayout(HashMap<uint32_t, WGPUShaderStageFlags>&&, id<MTLArgumentEncoder>, id<MTLArgumentEncoder>, id<MTLArgumentEncoder>);
BindGroupLayout();

const HashMap<uint32_t, WGPUShaderStageFlags> m_shaderStageForBinding;

const id<MTLArgumentEncoder> m_vertexArgumentEncoder { nil };
const id<MTLArgumentEncoder> m_fragmentArgumentEncoder { nil };
const id<MTLArgumentEncoder> m_computeArgumentEncoder { nil };

const HashMap<uint32_t, WGPUShaderStageFlags> m_shaderStageForBinding;
};

} // namespace WebGPU
Loading