Skip to content

Add dynamic bucket cache mode to improve peak and avg gpu buffer memory usage #25120

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 135 additions & 0 deletions onnxruntime/core/providers/webgpu/buffer_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,136 @@
std::vector<size_t> buckets_keys_;
};

class DynamicBucketCacheManager : public IBufferCacheManager {
public:
DynamicBucketCacheManager() {}

~DynamicBucketCacheManager() {
for (auto& pair : buckets_) {
for (auto& buffer : pair.second) {
wgpuBufferRelease(buffer);
}
}
}

void OnRunStart() override {
current_run_usage_.clear();
}

void OnRunEnd() override {
// Update memory patterns based on this session run.
for (const auto& usage : current_run_usage_) {
auto& pattern = memory_patterns_[usage.first];
pattern.request_size = usage.first;
pattern.frequency = usage.second;
}

// Adjust buckets based on the collected memory patterns every 2 runs.
// The reason for this is to allow the cache to adapt to the memory usage patterns
// of previous runs of last completed token generation session.
static size_t run_id = 0;
Copy link
Preview

Copilot AI Jun 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider whether the use of a static non-atomic run_id in OnRunEnd might lead to thread-safety issues if the provider is used in a multi-threaded context.

Copilot uses AI. Check for mistakes.

if ((run_id + 1) % 2 == 0) {
AdjustBuckets();
}
++run_id;
}

size_t CalculateBufferSize(size_t request_size) override {
size_t normalized_request_size = NormalizeBufferSize(request_size);

// Track usage for the current run
current_run_usage_[normalized_request_size]++;

// Check if we already have a bucket for this size. If not, create a new bucket so that it can cache buffers of
// this size in the current session run if the buffer is quickly released in the same session.
if (buckets_.find(normalized_request_size) == buckets_.end()) {
buckets_.emplace(normalized_request_size, std::vector<WGPUBuffer>());
buckets_keys_.push_back(normalized_request_size);
std::sort(buckets_keys_.begin(), buckets_keys_.end());
}

return normalized_request_size;
}

WGPUBuffer TryAcquireCachedBuffer(size_t buffer_size) override {
auto it = buckets_.find(buffer_size);
if (it != buckets_.end() && !it->second.empty()) {
auto buffer = it->second.back();
it->second.pop_back();
return buffer;
}
return nullptr;
}

void RegisterBuffer(WGPUBuffer /*buffer*/, size_t /*request_size*/) override {
// no-op
}

void ReleaseBuffer(WGPUBuffer buffer) override {
auto buffer_size = static_cast<size_t>(wgpuBufferGetSize(buffer));

auto it = buckets_.find(buffer_size);
if (it != buckets_.end()) {
it->second.emplace_back(buffer);
} else {
wgpuBufferRelease(buffer);
}
}

void OnRefresh() override {
// no-op
}

// Analyze memory patterns and adjust bucket sizes.
void AdjustBuckets() {
// Store old buckets to handle transitions.
auto old_buckets = std::move(buckets_);

// Clear and recreate buckets structure.
buckets_keys_.clear();
buckets_.clear();

// Create new buckets based on patterns.
for (const auto& pattern : memory_patterns_) {
// The request size here is already normalized, so we can use it directly as the bucket size key.
size_t bucket_size = pattern.second.request_size;
buckets_keys_.push_back(bucket_size);

// Initialize bucket vector.
auto& bucket = buckets_[bucket_size];

auto old_bucket_it = old_buckets.find(bucket_size);
if (old_bucket_it != old_buckets.end()) {
// Transfer buffers from old to new bucket.
bucket = std::move(old_bucket_it->second);

Check warning on line 353 in onnxruntime/core/providers/webgpu/buffer_manager.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <utility> for move [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/webgpu/buffer_manager.cc:353: Add #include <utility> for move [build/include_what_you_use] [4]
old_bucket_it->second.clear();
old_buckets.erase(old_bucket_it);
}
}

// Sort bucket sizes.
std::sort(buckets_keys_.begin(), buckets_keys_.end());

Check warning on line 360 in onnxruntime/core/providers/webgpu/buffer_manager.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <algorithm> for sort [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/webgpu/buffer_manager.cc:360: Add #include <algorithm> for sort [build/include_what_you_use] [4]

// Release any remaining buffers in old buckets that were not hit by the memoryusage patterns.
for (auto& pair : old_buckets) {
for (auto& buffer : pair.second) {
wgpuBufferRelease(buffer);
}
pair.second.clear();
}
old_buckets.clear();

// Clear patterns for next adjustment period.
memory_patterns_.clear();
}

private:
std::unordered_map<size_t, size_t> current_run_usage_; // Tracks usage in current session run.
std::unordered_map<size_t, MemoryUsagePattern> memory_patterns_; // Tracks patterns across session runs.
std::unordered_map<size_t, std::vector<WGPUBuffer>> buckets_;

Check warning on line 378 in onnxruntime/core/providers/webgpu/buffer_manager.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <unordered_map> for unordered_map<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/webgpu/buffer_manager.cc:378: Add #include <unordered_map> for unordered_map<> [build/include_what_you_use] [4]
std::vector<size_t> buckets_keys_;

Check warning on line 379 in onnxruntime/core/providers/webgpu/buffer_manager.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <vector> for vector<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/webgpu/buffer_manager.cc:379: Add #include <vector> for vector<> [build/include_what_you_use] [4]
};

std::unique_ptr<IBufferCacheManager> CreateBufferCacheManager(BufferCacheMode cache_mode) {
switch (cache_mode) {
case BufferCacheMode::Disabled:
Expand All @@ -259,6 +389,8 @@
return std::make_unique<SimpleCacheManager>();
case BufferCacheMode::Bucket:
return std::make_unique<BucketCacheManager>();
case BufferCacheMode::DynamicBucket:
return std::make_unique<DynamicBucketCacheManager>();
default:
ORT_NOT_IMPLEMENTED("Unsupported buffer cache mode");
}
Expand All @@ -278,6 +410,9 @@
case BufferCacheMode::Bucket:
os << "Bucket";
break;
case BufferCacheMode::DynamicBucket:
os << "DynamicBucket";
break;
default:
os << "Unknown(" << static_cast<int>(mode) << ")";
}
Expand Down
36 changes: 33 additions & 3 deletions onnxruntime/core/providers/webgpu/buffer_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,21 @@
Disabled,
LazyRelease,
Simple,
Bucket
Bucket,
DynamicBucket
};
std::ostream& operator<<(std::ostream& os, BufferCacheMode mode);

//
// IBufferCacheManager is an interface for buffer cache management.
//
// By implementing this interface, we can have different buffer cache management strategies.
// Currently, we have 3 strategies:
// Currently, we have 4 strategies:
// - Disabled: no cache. always allocate a new buffer and release it immediately after use.
// - LazyRelease: no cache. the difference from Disabled is that it delays the release of buffers until the next refresh.
// - Simple: a simple cache that always keeps buffers. when a buffer is requested, it tries to find a buffer in the cache.
// - Bucket: a cache that keeps buffers in different buckets based on the buffer size, with a maximum number of buffers in each bucket.
//
// - DynamicBucket: a variant of the bucket cache that dynamically adjusts bucket sizes based on usage patterns observed in real-time requests and previous session runs.
class IBufferCacheManager {
public:
virtual ~IBufferCacheManager() = default;
Expand All @@ -50,6 +51,12 @@

// when a stream refresh is requested
virtual void OnRefresh() = 0;

// Track start of session run
virtual void OnRunStart() {}

// Track end of session run and update memory patterns
virtual void OnRunEnd() {}
};

//
Expand All @@ -69,6 +76,20 @@
void Download(WGPUBuffer src, void* dst, size_t size);
void RefreshPendingBuffers();

// Notifies each configured cache manager that a session run is starting, so
// caches that track per-run usage (e.g. DynamicBucketCacheManager) can reset
// their counters. Cache members may be unset; null ones are skipped.
void OnRunStart() {
if (storage_cache_) storage_cache_->OnRunStart();
if (uniform_cache_) uniform_cache_->OnRunStart();
if (query_resolve_cache_) query_resolve_cache_->OnRunStart();
if (default_cache_) default_cache_->OnRunStart();
}

// Notifies each configured cache manager that a session run has ended, giving
// caches a chance to fold the run's usage data into their memory patterns.
// Cache members may be unset; null ones are skipped.
void OnRunEnd() {
if (storage_cache_) storage_cache_->OnRunEnd();
if (uniform_cache_) uniform_cache_->OnRunEnd();
if (query_resolve_cache_) query_resolve_cache_->OnRunEnd();
if (default_cache_) default_cache_->OnRunEnd();
}

private:
IBufferCacheManager& GetCacheManager(wgpu::BufferUsage usage) const;
IBufferCacheManager& GetCacheManager(WGPUBuffer buffer) const;
Expand All @@ -88,5 +109,14 @@
BufferManagerFactory() {}
};

// Tracks how often a given (normalized) buffer size was requested, so that a
// cache manager can adapt its bucket layout to observed allocation patterns.
struct MemoryUsagePattern {
  size_t request_size;  // Normalized request size in bytes.
  size_t frequency;     // Number of requests observed for this size.

  // explicit: prevents accidental implicit conversion from a bare size_t
  // (cpplint runtime/explicit). Default arguments keep it usable as a default
  // constructor, e.g. when value-initialized by unordered_map::operator[].
  explicit MemoryUsagePattern(size_t size = 0, size_t freq = 0)
      : request_size(size), frequency(freq) {}
};

} // namespace webgpu
} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -881,6 +881,9 @@ Status WebGpuExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_
context_.StartProfiling();
}

// Start tracking memory usage for this session run.
context_.BufferManager().OnRunStart();

if (IsGraphCaptureEnabled() && IsGraphCaptureAllowed() && !IsGraphCaptured(0)) {
ORT_NOT_IMPLEMENTED("graph capture not implemented");
}
Expand All @@ -903,6 +906,9 @@ Status WebGpuExecutionProvider::OnRunEnd(bool /* sync_stream */, const onnxrunti
context_.CollectProfilingData(profiler_->Events());
}

// Update memory patterns from this session run.
context_.BufferManager().OnRunEnd();

context_.OnRunEnd();

if (context_.ValidationMode() >= ValidationMode::Basic) {
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,8 @@ std::shared_ptr<IExecutionProviderFactory> WebGpuProviderFactoryCreator::Create(
return webgpu::BufferCacheMode::Simple;
} else if (buffer_cache_mode_str == kBufferCacheMode_Bucket) {
return webgpu::BufferCacheMode::Bucket;
} else if (buffer_cache_mode_str == kBufferCacheMode_DynamicBucket) {
return webgpu::BufferCacheMode::DynamicBucket;
} else {
ORT_THROW("Invalid buffer cache mode: ", config_entry_str);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ constexpr const char* kBufferCacheMode_Disabled = "disabled";
constexpr const char* kBufferCacheMode_LazyRelease = "lazyRelease";
constexpr const char* kBufferCacheMode_Simple = "simple";
constexpr const char* kBufferCacheMode_Bucket = "bucket";
constexpr const char* kBufferCacheMode_DynamicBucket = "dynamicBucket";

constexpr const char* kValidationMode_Disabled = "disabled";
constexpr const char* kValidationMode_wgpuOnly = "wgpuOnly";
Expand Down
Loading