Skip to content

Commit

Permalink
[WebGPU] Use UnsignedWithZeroKeyHashTraits in m_cachedBindGroupLayout…
Browse files Browse the repository at this point in the history
…s instead of adding 1 to the key

https://bugs.webkit.org/show_bug.cgi?id=257068
<rdar://problem/109590315>

Reviewed by Dan Glastonbury.

This is exactly what UnsignedWithZeroKeyHashTraits is meant for.

* Source/WebGPU/WebGPU/ComputePipeline.h:
* Source/WebGPU/WebGPU/ComputePipeline.mm:
(WebGPU::ComputePipeline::getBindGroupLayout):
* Source/WebGPU/WebGPU/RenderPipeline.h:
* Source/WebGPU/WebGPU/RenderPipeline.mm:
(WebGPU::RenderPipeline::getBindGroupLayout):

Canonical link: https://commits.webkit.org/264304@main
  • Loading branch information
litherum committed May 21, 2023
1 parent 25578b6 commit 597aa39
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 6 deletions.
4 changes: 3 additions & 1 deletion Source/WebGPU/WebGPU/ComputePipeline.h
Expand Up @@ -28,6 +28,8 @@
#import "BindGroupLayout.h"

#import <wtf/FastMalloc.h>
#import <wtf/HashMap.h>
#import <wtf/HashTraits.h>
#import <wtf/Ref.h>
#import <wtf/RefCounted.h>

Expand Down Expand Up @@ -81,7 +83,7 @@ class ComputePipeline : public WGPUComputePipelineImpl, public RefCounted<Comput
#endif
const PipelineLayout *m_pipelineLayout { nullptr };
const Ref<Device> m_device;
HashMap<uint32_t, Ref<BindGroupLayout>> m_cachedBindGroupLayouts;
HashMap<uint32_t, Ref<BindGroupLayout>, WTF::DefaultHash<uint32_t>, WTF::UnsignedWithZeroKeyHashTraits<uint32_t>> m_cachedBindGroupLayouts;
const MTLSize m_threadsPerThreadgroup;
};

Expand Down
4 changes: 2 additions & 2 deletions Source/WebGPU/WebGPU/ComputePipeline.mm
Expand Up @@ -138,7 +138,7 @@ static MTLSize metalSize(auto workgroupSize)
if (m_pipelineLayout)
return const_cast<BindGroupLayout*>(&m_pipelineLayout->bindGroupLayout(groupIndex));

auto it = m_cachedBindGroupLayouts.find(groupIndex + 1);
auto it = m_cachedBindGroupLayouts.find(groupIndex);
if (it != m_cachedBindGroupLayouts.end())
return it->value.ptr();

Expand All @@ -159,7 +159,7 @@ static MTLSize metalSize(auto workgroupSize)
bindGroupLayoutDescriptor.entryCount = entries.size();
bindGroupLayoutDescriptor.entries = entries.size() ? &entries[0] : nullptr;
auto bindGroupLayout = m_device->createBindGroupLayout(bindGroupLayoutDescriptor);
m_cachedBindGroupLayouts.add(groupIndex + 1, bindGroupLayout);
m_cachedBindGroupLayouts.add(groupIndex, bindGroupLayout);

return bindGroupLayout.ptr();
#else
Expand Down
3 changes: 2 additions & 1 deletion Source/WebGPU/WebGPU/RenderPipeline.h
Expand Up @@ -27,6 +27,7 @@

#import <wtf/FastMalloc.h>
#import <wtf/HashMap.h>
#import <wtf/HashTraits.h>
#import <wtf/Ref.h>
#import <wtf/RefCounted.h>

Expand Down Expand Up @@ -77,7 +78,7 @@ class RenderPipeline : public WGPURenderPipelineImpl, public RefCounted<RenderPi
const id<MTLRenderPipelineState> m_renderPipelineState { nil };

const Ref<Device> m_device;
HashMap<uint32_t, Ref<BindGroupLayout>> m_cachedBindGroupLayouts;
HashMap<uint32_t, Ref<BindGroupLayout>, WTF::DefaultHash<uint32_t>, WTF::UnsignedWithZeroKeyHashTraits<uint32_t>> m_cachedBindGroupLayouts;
MTLPrimitiveType m_primitiveType;
std::optional<MTLIndexType> m_indexType;
MTLWinding m_frontFace;
Expand Down
4 changes: 2 additions & 2 deletions Source/WebGPU/WebGPU/RenderPipeline.mm
Expand Up @@ -526,7 +526,7 @@ static void populateStencilOperation(MTLStencilDescriptor *mtlStencil, const WGP
if (m_pipelineLayout)
return const_cast<BindGroupLayout*>(&m_pipelineLayout->bindGroupLayout(groupIndex));

auto it = m_cachedBindGroupLayouts.find(groupIndex + 1);
auto it = m_cachedBindGroupLayouts.find(groupIndex);
if (it != m_cachedBindGroupLayouts.end())
return it->value.ptr();

Expand Down Expand Up @@ -557,7 +557,7 @@ static void populateStencilOperation(MTLStencilDescriptor *mtlStencil, const WGP
bindGroupLayoutDescriptor.entryCount = entries.size();
bindGroupLayoutDescriptor.entries = entries.size() ? &entries[0] : nullptr;
auto bindGroupLayout = m_device->createBindGroupLayout(bindGroupLayoutDescriptor);
m_cachedBindGroupLayouts.add(groupIndex + 1, bindGroupLayout);
m_cachedBindGroupLayouts.add(groupIndex, bindGroupLayout);

return WebGPU::releaseToAPI(WTFMove(bindGroupLayout));
#else
Expand Down

0 comments on commit 597aa39

Please sign in to comment.