Phambinh/circular vmm pool #774

Open
phambinhfin wants to merge 10 commits into main from phambinh/circular_vmm_pool

Conversation

@phambinhfin

@phambinhfin phambinhfin commented Apr 3, 2026

Circular VMM Pool for Update-Free Command Buffer Execution on ROCm

Problem

Enabling COLLECTIVES in XLA command buffers (HIP graphs) on ROCm causes severe performance degradation. Every iteration triggers a full NCCL stream capture cycle when buffer addresses change, making COLLECTIVES 5–15× slower than running without command buffers.

Solution

Pre-allocate physical memory with permanent virtual address (VA) mappings at startup. The HIP graph is recorded once with stable VA addresses and replayed forever. Per-iteration overhead is reduced to lightweight D2D memcpy for parameter data — no hipMemMap, no hipMemUnmap, no hipEventSynchronize.

Architecture:

Startup (once per executor):
  hipMemCreate          → physical memory chunk
  hipMemAddressReserve  → virtual address range
  hipMemMap             → permanent VA→physical mapping (never unmapped)
  hipMemSetAccess       → P2P peer access for all GPUs
  hipMallocSignalMemory → timeline counter for GPU→CPU signaling

Per iteration:
  __atomic_load_n(timeline)           ~0.001 us  (check if GPU finished previous iter)
  stream->MemcpyD2D(pool ← BFC)      ~2.9 us    (copy input data into pool VA)
  command_buffer->Submit(stream)      ~0 us      (replay pre-recorded HIP graph)
  stream->MemcpyD2D(BFC ← pool)      ~2.9 us    (copy results back to BFC)
  hipStreamWriteValue64(timeline++)   ~5.3 us    (GPU signals iteration complete)

GPU-to-CPU synchronization uses hipStreamWriteValue64 to write a monotonic timeline counter to signal memory. The CPU checks with a non-blocking atomic read (__atomic_load_n with acquire semantics, cost: 1 nanosecond) before reusing a slot. A 30-second timeout with sched_yield backoff prevents unbounded spin-wait.
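
The slot-reuse wait described above can be sketched in portable C++ (hypothetical names; the real AcquireNextSlot reads HIP signal memory via __atomic_load_n and returns DeadlineExceededError, modeled here with a host-side std::atomic and an enum):

```cpp
#include <atomic>
#include <cassert>
#include <chrono>
#include <cstdint>
#include <sched.h>

// Sketch of the AcquireNextSlot wait: a slot may be reused once the
// GPU-written timeline counter shows that the iteration which last used it
// (iteration - num_slots) has completed. The "GPU" side is modeled as a
// host atomic; names and the timeout parameter are illustrative.
enum class AcquireResult { kOk, kDeadlineExceeded };

AcquireResult WaitForSlot(const std::atomic<uint64_t>& timeline,
                          uint64_t iteration, uint64_t num_slots,
                          std::chrono::milliseconds timeout) {
  if (iteration < num_slots) return AcquireResult::kOk;  // Slot never used.
  const uint64_t required = iteration - num_slots + 1;
  const auto deadline = std::chrono::steady_clock::now() + timeout;
  constexpr int kMaxSpinIterations = 1000;  // Spin briefly, then yield.
  int spins = 0;
  while (timeline.load(std::memory_order_acquire) < required) {
    if (++spins >= kMaxSpinIterations) {
      spins = 0;
      sched_yield();  // Back off instead of burning the core.
      if (std::chrono::steady_clock::now() > deadline) {
        return AcquireResult::kDeadlineExceeded;  // e.g. GPU hang.
      }
    }
  }
  return AcquireResult::kOk;
}
```

With 2 slots, iteration i can start while iteration i-1 is still on the GPU; the wait only blocks when the GPU falls a full pool-depth behind.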

Why D2D Copy Instead of Map/Unmap?

Measured on 8× MI300X (ROCm 7.2):

| Operation | MI300X | Cost Type |
| --- | --- | --- |
| D2D MemcpyD2D (1 MB) | 2.9 us | Stream async (HBM bandwidth) |
| hipMemMap + hipMemUnmap | 20.7 us | Driver calls (page table update) |
| hipMemSetAccess (8 GPUs) | 14.4 us | Per-peer driver call |
| hipEventSynchronize | 11.9–29.5 us | CPU-blocking |
| __atomic_load_n (timeline read) | 0.001 us | Non-blocking (1 ns) |
| hipStreamWriteValue64 | 5.3 us | Stream enqueue |

Full per-iteration cycle (8 GPUs, 1 MB buffers):

| Approach | Measured Cost | Operations |
| --- | --- | --- |
| NVIDIA Remap | 25,141.9 us | EventSync + Unmap + Map + SetAccess ×8 |
| Circular VMM Pool | 11.5 us | AtomicRead + CopyIn + CopyOut + WriteValue64 |
| Speedup | 2,178× | |

D2D memcpy runs at HBM bandwidth (~3.35 TB/s on MI300X), making even large buffer copies cheaper than a single driver call. The NVIDIA remap cycle explodes with multiple GPUs due to driver lock contention.
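
As a sanity check on that claim, the pure transfer time implied by the quoted bandwidth can be computed directly (arithmetic only, using the figures above):

```cpp
#include <cassert>

// At ~3.35 TB/s HBM bandwidth, the raw transfer time for a 1 MB D2D copy
// is ~0.3 us, well under the ~20.7 us hipMemMap+hipMemUnmap driver round
// trip. Launch overhead, not bandwidth, dominates the measured 2.9 us.
double D2dTransferMicros(double bytes, double bandwidth_bytes_per_sec) {
  return bytes / bandwidth_bytes_per_sec * 1e6;
}
```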

(Benchmark source: vmm_benchmark.cpp, compiled with hipcc -O2)

Results

MaxText LLaMA 3 8B on 8× MI300X (FSDP-8, synthetic data, 10 steps)

F+C+CL+CC = FUSION,CUBLAS,CUBLASLT,CUSTOM_CALL

| # | Config | s/step | TFLOP/s/device | vs Baseline |
| --- | --- | --- | --- | --- |
| 1 | No command buffers (baseline) | 0.66 | 174 | |
| 2 | BFC + F+C+CL+CC | 0.64 | 180 | same |
| 3 | BFC + F+C+CL+CC+COLL | 3.70 | 31 | 5.6× slower |
| 4 | CircularVMM + F+C+CL+CC | 0.64 | 179 | same |
| 5 | CircularVMM + F+C+CL+CC+COLL | 0.65 | 176 | same speed |

Loss convergence is identical across all configs (10.871 → 8.973).

Micro-benchmarks

| Workload | BFC+COLL | CircularVMM+COLL | Speedup |
| --- | --- | --- | --- |
| 2× MI300X pmap (5 matmul+allreduce) | 15,412 us | 647 us | 23.8× |
| 8× MI300X transformer step | 255.6 ms | 11.1 ms | 23.0× |

Accuracy: All 5 numerical accuracy tests pass with zero bit-level difference across 20 iterations (tanh, pmap+allreduce, grad, varying inputs, large matmul on 8 GPUs).

Key Design Decisions

  • Pool ALL command buffer allocations (params + temps + live-outs), not just temps. Ensures all addresses the graph sees are stable VA addresses.
  • D2D memcpy for params — parameter data lives at BFC addresses but the graph uses pool VA addresses. Copy cost (~2.9 us for 1 MB) is negligible vs NCCL re-tracing cost (~15,000 us).
  • Copy live-out results back to BFC after execution so downstream consumers find outputs at expected addresses.
  • va_remapping=true in CommandBufferThunk — tells the thunk to record once and replay forever.
  • Thread-safe initialization — std::atomic<bool> with acquire/release for double-checked locking, std::atomic<uint64_t> for iteration counter.
  • Signal memory for timeline — hipExtMallocWithFlags(hipMallocSignalMemory), freed with absl::Cleanup guard on error paths.
  • Spin-wait with timeout — 30s deadline + sched_yield() backoff in AcquireNextSlot.
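
The initialization pattern from the bullets above (atomic flag plus mutex for double-checked locking, atomic iteration counter) can be sketched as follows; names are hypothetical stand-ins for the real CircularPoolState:

```cpp
#include <atomic>
#include <cassert>
#include <cstdint>
#include <mutex>

// Hypothetical per-executor pool state. 'initialized' uses acquire/release
// so a thread that observes true also observes all writes made before the
// release store; 'iteration' uses fetch_add so concurrent callers never
// reuse an iteration number (and hence a slot).
struct PoolState {
  std::atomic<bool> initialized{false};
  std::atomic<uint64_t> iteration{0};
  std::mutex mu;
  int num_slots = 0;  // Stands in for the expensive pool setup.
};

void EnsureInitialized(PoolState& state, int num_slots) {
  if (state.initialized.load(std::memory_order_acquire)) return;  // Fast path.
  std::lock_guard<std::mutex> lock(state.mu);
  if (state.initialized.load(std::memory_order_relaxed)) return;  // Re-check.
  state.num_slots = num_slots;  // Expensive setup would happen here.
  state.initialized.store(true, std::memory_order_release);
}

uint64_t NextIteration(PoolState& state) {
  return state.iteration.fetch_add(1, std::memory_order_relaxed);
}
```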

Files

New (3):

  • xla/stream_executor/rocm/circular_vmm_pool.h/.cc — Pool implementation
  • xla/stream_executor/rocm/circular_vmm_pool_test.cc — Unit tests

Modified (10):

  • xla/xla.proto — xla_gpu_enable_circular_vmm_pool (field 468), xla_gpu_circular_vmm_pool_slots (field 469)
  • xla/debug_options_flags.cc — Flag defaults and registration
  • xla/service/gpu/gpu_executable.h/.cc — ExecuteThunksWithCircularVmmPool, CircularPoolState
  • xla/service/gpu/thunk_emitter.cc — Enable va_remapping when circular pool is active
  • xla/backends/gpu/runtime/command_buffer_conversion_pass.cc — Same
  • xla/stream_executor/rocm/rocm_vmm_allocator.h/.cc — Remove unused param, align peer access
  • xla/stream_executor/rocm/BUILD, xla/service/gpu/BUILD — Targets and deps

Enable

XLA_FLAGS="--xla_gpu_enable_circular_vmm_pool=true \
           --xla_gpu_graph_min_graph_size=1 \
           --xla_gpu_enable_command_buffer=FUSION,CUBLAS,CUBLASLT,CUSTOM_CALL,COLLECTIVES"

Port CUDA's VMM allocator to ROCm/HIP using HIP VMM APIs (hipMemCreate,
hipMemAddressReserve, hipMemMap, hipMemSetAccess, hipStreamWriteValue64).
This enables per-GPU memory access control on AMD GPUs, replacing the
all-or-nothing hipDeviceMallocFinegrained approach that grants every GPU
access and disables L2 cache.

Four layers matching the CUDA structure:
- RocmRawMemoryAllocation: RAII wrapper for hipMemCreate/hipMemRelease
- RocmMemoryReservation: RAII wrapper for virtual address reservation,
  mapping, access control, and unmapping
- RocmVmmAllocator: simple all-in-one allocator (create + reserve + map
  + setAccess in a single Allocate call)
- RocmDeviceAddressVmmAllocator: advanced allocator with GPU timeline-
  based deferred deallocation via hipStreamWriteValue64 and signal memory

Key differences from the CUDA implementation:
- hipMemGenericAllocationHandle_t is a pointer type (not integer)
- hipStreamWriteValue64 requires signal memory (hipMallocSignalMemory)
  instead of pinned host memory (cuMemHostAlloc)
- hipDeviceptr_t is void* requiring PtrAdd() helper for offset arithmetic
- All wrap:: calls use nullptr/0ULL for proper template type deduction

Add four test files mirroring the CUDA VMM test structure:

- rocm_raw_memory_allocation_test: CreateAllocation, AddressReflectsHandle,
  SizeIsAtLeastRequested
- rocm_memory_reservation_test: CreateReservation, MapToWrongType,
  MapToSingleAllocation, ScopedMappingUnmapsOnDestruction,
  MapToMultipleAllocations, TwoReservationsDifferentAddresses
- rocm_vmm_allocator_test: AllocateAndFree, AllocateZeroBytes,
  MemcpyRoundTrip (parameterized with RdmaEnabled/RdmaDisabled)
- rocm_device_address_vmm_allocator_test: 13 single-GPU tests covering
  allocate/deallocate, memory read/write, stream accessors, deferred
  deallocation, VA reuse, destructor safety, and error handling;
  2 multi-GPU tests for cross-device allocation isolation

RocmEvent was missing the Synchronize() override, causing
ExecuteThunksWithVaRemapping to fail with "UNIMPLEMENTED: Not supported
for this Event." The VA remapping path calls unmap_event->Synchronize()
to wait for the GPU to finish before unmapping virtual addresses.

Implements hipEventSynchronize matching the CUDA CudaEvent::Synchronize
pattern.

…:SetAccess

SetAccess was only granting read/write access to the owning device,
causing GPU memory access faults when peer GPUs tried to read VMM
allocations during multi-GPU pmap/collective operations.

Now iterates over all devices and grants access to P2P-capable peers,
matching the pattern already used in RocmVmmAllocator (Layer 3).

Pre-allocate physical memory slots with permanent VA mappings at startup.
The GPU signals slot completion via hipStreamWriteValue64 to coherent host
memory; the CPU checks with a non-blocking memory read before reusing a
slot. After the first iteration, per-iteration overhead is just D2D memcpy
for params — no hipMemMap, no hipMemUnmap, no hipEventSynchronize.

Key components:
- CircularVmmPool: creates N physical chunks, each permanently mapped to
  its own VA range. AcquireNextSlot/ReleaseSlot for slot lifecycle.
- ExecuteThunksWithCircularVmmPool in gpu_executable.cc: pools all
  command buffer allocations, copies param data from BFC into pool before
  execution, copies live-out results back after execution.
- xla_gpu_enable_circular_vmm_pool flag (field 468) to opt in.
- va_remapping enabled for CommandBufferThunk when circular pool is active,
  so graphs are recorded once and replayed without re-recording.

Benchmark results on 8x MI300X (transformer training step):
- Baseline (no cmd buf): 114.5 ms/iter
- BFC + COLLECTIVES: 234.2 ms/iter (2x slower)
- Circular VMM + COLLECTIVES: 12.0 ms/iter (9.5x faster than baseline)

Benchmark on 2x MI300X (5 matmul+tanh + allreduce):
- BFC + COLLECTIVES: 19,074 us/iter
- Circular VMM + COLLECTIVES: 1,107 us/iter (17.2x faster)

All correctness tests pass (tanh, pmap+allreduce, grad, 10-layer model,
different inputs per iteration).
@phambinhfin phambinhfin added the claude-review Request a Claude AI code review for this PR label Apr 3, 2026
Comment on lines +1529 to +1531
absl::MutexLock lock(&circular_pool_mutex_);
pool_state = &circular_pools_[executor];
}

Bug: thread safety — pool_state accessed without lock

The mutex circular_pool_mutex_ is held only while looking up/inserting the CircularPoolState entry. After that, pool_state->pool, pool_state->iteration_count++, etc. are accessed without the lock.

If two threads call ExecuteThunksWithCircularVmmPool for the same executor concurrently (e.g., multi-threaded inference), they can race on iteration_count++ (non-atomic, no lock) — two iterations could use the same slot concurrently, corrupting data.

Consider either:

  1. Holding the mutex for the entire function body (but watch for deadlocks with GPU calls), or
  2. Making iteration_count an std::atomic<uint64_t> and adding a per-pool mutex for the initialization path.


Resolved -- addressed in this revision. iteration_count is now std::atomic<uint64_t> with fetch_add(1), and pool initialization uses double-checked locking with a per-pool mutex.

Comment on lines +161 to +165
uint64_t required = iteration - num_slots_ + 1;
uint64_t completed = __atomic_load_n(timeline_, __ATOMIC_ACQUIRE);
while (completed < required) {
completed = __atomic_load_n(timeline_, __ATOMIC_ACQUIRE);
}

Bug: unbounded spin-wait with no backoff or timeout

AcquireNextSlot spin-waits in a tight loop reading the timeline counter. If the GPU stalls or the signal is never written (e.g., GPU hang), the CPU thread will spin forever consuming 100% of a core.

At minimum, add a timeout that returns an error status after a reasonable period (e.g., 30s). A sched_yield() or short std::this_thread::sleep_for in the loop body would also reduce CPU waste during normal waits.


Resolved -- addressed in this revision. AcquireNextSlot now has a 30-second timeout with DeadlineExceededError, spin-wait with kMaxSpinIterations limit, and sched_yield() for backoff.

Comment on lines +206 to +218
LOG(INFO) << absl::StrFormat(
"CommandBufferThunk::Initialize: warmup_done=%d state=%d "
"va_remapping=%d requires_init=%d will_record=%d",
warmup, static_cast<int>(state), enable_command_buffer_va_remapping_,
commands_.requires_initialization(), will_record);

// Log the addresses that will be used for recording
if (will_record) {
for (auto idx : commands_.allocs_indices()) {
auto addr = execute_params.buffer_allocations->GetDeviceAddress(idx);
LOG(INFO) << absl::StrFormat(
" Initialize record addr[%d]: %p size=%d", idx, addr.opaque(),
addr.size());

Cleanup: LOG(INFO) in hot path should be VLOG

Multiple LOG(INFO) statements are added to Initialize and ExecuteOnStream that fire on every iteration. These will produce massive log output in production and add measurable overhead. The existing code already uses VLOG(2)/VLOG(3) for this purpose.

These should be converted to VLOG(2) or VLOG(3) before merging, or removed if they were only needed during development.


Resolved -- addressed in this revision. The LOG(INFO) statements in Initialize and ExecuteOnStream have been converted to VLOG(3).

Comment on lines 1753 to 1757
LOG(INFO) << absl::StreamFormat(
"ExecuteThunks: cmd_buffer_allocs=%d circular_vmm_pool=%d "
"va_remapping=%d",
command_buffer_allocation_indexes_.size(), enable_circular_vmm_pool,
enable_command_buffer_va_remapping);

Cleanup: LOG(INFO) in hot path — should be VLOG

This fires on every invocation of ExecuteThunks. The original code used XLA_VLOG_DEVICE(3, ...). Should be restored to VLOG(3) or XLA_VLOG_DEVICE to avoid flooding production logs.


Resolved -- addressed in this revision. The ExecuteThunks logging now uses XLA_VLOG_DEVICE(3, ...) instead of LOG(INFO).

"Number of slots in the circular VMM pool (default 2)."));
} // NOLINT(readability/fn_size)

// Allocates flag_values and flag_objects; this function must not be called more

Bug: default value / help text mismatch

The actual default is set to 1 (line 507: set_xla_gpu_circular_vmm_pool_slots(1)), but the flag help text here says "(default 2)".

With 1 slot, AcquireNextSlot will spin-wait for GPU completion on every iteration (no overlap), negating the primary benefit of the circular pool. Either fix the help text to say "default 1" or change the default to 2.
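
The slot-reuse rule makes this concrete (a sketch using the required = iteration - num_slots + 1 condition from AcquireNextSlot; function names are illustrative):

```cpp
#include <cassert>
#include <cstdint>

// Iteration i runs in slot i % num_slots. Before reusing that slot, the CPU
// must wait until iteration i - num_slots has signaled completion. With
// num_slots == 1 that is always the immediately preceding iteration, so CPU
// and GPU never overlap; num_slots >= 2 lets iteration i proceed while
// iteration i-1 is still executing on the GPU.
uint64_t SlotFor(uint64_t iteration, uint64_t num_slots) {
  return iteration % num_slots;
}

// Timeline value (number of completed iterations) that must be reached
// before this iteration may run; 0 means the slot has never been used.
uint64_t RequiredTimeline(uint64_t iteration, uint64_t num_slots) {
  return iteration < num_slots ? 0 : iteration - num_slots + 1;
}
```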


Partially resolved. The flag help text in debug_options_flags.cc now correctly says "(default 1)", but the proto comment in xla.proto still says "(default 2)" -- see xla_gpu_circular_vmm_pool_slots field comment. Please update the proto comment to match the actual default of 1.


Resolved -- fully addressed in this revision. Both the flag help text in debug_options_flags.cc and the proto comment in xla.proto now correctly say "(default 1)".


Resolved: both the flag help text (line 3034) and proto comment now correctly say "(default 1)", matching the actual default value.

Comment on lines +1536 to +1546
absl::btree_set<BufferAllocation::Index> pool_indexes;
absl::btree_set<BufferAllocation::Index> copy_indexes;
if (buffer_assignment_) {
for (BufferAllocation::Index idx : command_buffer_allocation_indexes_) {
const auto& alloc = buffer_assignment_->GetAllocation(idx);
if (alloc.is_constant() || alloc.size() == 0) continue;
pool_indexes.insert(idx);
if (alloc.is_entry_computation_parameter() || alloc.maybe_live_out()) {
copy_indexes.insert(idx);
}
}

Performance: pool_indexes and copy_indexes recomputed every iteration

These sets are derived from buffer_assignment_ and command_buffer_allocation_indexes_, which don't change between iterations. They should be computed once during pool initialization and cached in CircularPoolState, rather than rebuilt on every call.


Resolved -- addressed in this revision. pool_indexes and copy_indexes are now cached in CircularPoolState with an indexes_cached flag and double-checked locking, computed once at init.

Comment on lines +444 to +445
// Track last-seen BFC addresses to skip redundant D2D memcpy.
absl::flat_hash_map<BufferAllocation::Index, void*> last_param_addrs;

Unused field: last_param_addrs is never read or written

This map is declared with a comment about skipping redundant D2D memcpy, but it's never used anywhere in the implementation. Either remove it or implement the optimization it's intended for.


Resolved -- addressed in this revision. The last_param_addrs field has been removed.

Thunk::ExecutableSource executable_source, bool block_host_until_done);

struct CircularPoolState {
std::shared_ptr<void> pool;

Nit: std::shared_ptr<void> for type erasure is fragile

This works correctly because unique_ptr<CircularVmmPool> converting to shared_ptr<void> captures the right deleter. However, it sacrifices type safety and requires static_cast<CircularVmmPool*> at every use site.

Since the header already has a #if TENSORFLOW_USE_ROCM include for CircularVmmPool in the .cc file, consider forward-declaring CircularVmmPool here and using std::unique_ptr<se::gpu::CircularVmmPool> directly (with a custom deleter or forward-declared destructor). This eliminates the type-erased casts.
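
The suggested alternative (a forward-declared type behind std::unique_ptr with an out-of-line deleter) can be sketched generically; the names below are illustrative, not the actual XLA headers:

```cpp
#include <cassert>
#include <memory>

// "Header" side: only a forward declaration is needed. The deleter is
// declared here but defined where the type is complete, so unique_ptr
// destruction sites never require the full definition (std::default_delete
// would), and every use site keeps typed access with no static_cast.
struct CircularVmmPoolDemo;  // Forward-declared.

struct PoolDeleter {
  void operator()(CircularVmmPoolDemo* p) const;
};
using PoolPtr = std::unique_ptr<CircularVmmPoolDemo, PoolDeleter>;

// ".cc" side: complete type and deleter definition.
struct CircularVmmPoolDemo {
  int num_slots;
};
void PoolDeleter::operator()(CircularVmmPoolDemo* p) const { delete p; }

PoolPtr MakePool(int slots) { return PoolPtr(new CircularVmmPoolDemo{slots}); }
```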

Comment on lines +1641 to +1652
const auto& alloc = buffer_assignment_->GetAllocation(i);
if (alloc.maybe_live_out()) {
auto bfc_addr = buffer_allocations.GetDeviceAddress(i);
if (!bfc_addr.is_null() && bfc_addr.size() > 0) {
se::DeviceAddressBase bfc_dst(bfc_addr.opaque(), bfc_addr.size());
se::DeviceAddressBase pool_src(pool_addr.opaque(), bfc_addr.size());
TF_RETURN_IF_ERROR(run_options->stream()->MemcpyD2D(
&bfc_dst, pool_src, bfc_addr.size()));
}
}
}
}

Potential bug: live-out copy-back not synchronized when block_host_until_done is true

If block_host_until_done is true, ExecuteThunksImpl synchronizes the stream before returning. These D2D copy-back operations are enqueued after that sync, meaning the caller may read stale data at the BFC address.

Consider adding a stream synchronization after the copy-back when block_host_until_done is true:

if (block_host_until_done) {
  TF_RETURN_IF_ERROR(run_options->stream()->BlockHostUntilDone());
}


Resolved -- addressed in this revision. The copy-back D2D memcpy calls now happen before BlockHostUntilDone, so the sync correctly covers both the thunk execution and the copy-back.

Comment on lines +1589 to +1591
LOG(INFO) << absl::StrFormat(
"CircularVmmPool iter=%d slot=%d/%d: %d pool addrs",
iteration, slot_idx, pool->num_slots(), slot_addresses.size());

Cleanup: LOG(INFO) per-iteration — should be VLOG

This logs on every iteration. Should be VLOG(3) or similar to avoid flooding production logs.


Resolved -- addressed in this revision. Per-iteration logging in ExecuteThunksWithCircularVmmPool now uses VLOG(3). The remaining LOG(INFO) calls are one-time pool initialization messages.

// via hipStreamWriteValue64 after each iteration; the CPU reads it to
// determine when a slot is safe to reuse.
volatile uint64_t* timeline_;
void* timeline_host_ptr_; // Raw pointer for hipHostFree.

Nit: misleading comment

Comment says "Raw pointer for hipHostFree" but the destructor actually calls hipFree, not hipHostFree. The comment should say "Raw pointer for hipFree".


Resolved -- addressed in this revision. Comment now correctly reads 'Raw pointer for hipFree (signal memory).'.

Comment on lines +71 to +72
size_t granularity = 0;
TF_RETURN_IF_ERROR(ToStatus(wrap::hipMemGetAllocationGranularity(

Nit: unused parameter is_rdma_supported

GetVmmAllocationProperties takes is_rdma_supported but never uses it. Either remove the parameter or implement the RDMA support path.


Resolved -- addressed in this revision. GetVmmAllocationProperties no longer takes is_rdma_supported, and the constructor documents the parameter as unused with a comment explaining why.

@claude

claude bot commented Apr 3, 2026

Review Summary

This PR adds a circular VMM pool for ROCm to eliminate per-iteration hipMemMap/hipMemUnmap overhead in command buffer execution. The design — pre-allocating N physical slots with permanent VA mappings and GPU timeline signaling — is sound and addresses a real performance bottleneck.

Key issues to address before merge:

  • Thread safety bug — pool_state fields (including iteration_count++) are accessed outside the mutex, risking data races in concurrent execution scenarios.
  • Unbounded spin-wait — AcquireNextSlot has no timeout or backoff; a GPU hang will spin a CPU core forever.
  • Live-out copy-back not synced — when block_host_until_done is true, D2D copies enqueued after ExecuteThunksImpl are not synchronized, so callers may read stale output.
  • Default slots mismatch — actual default is 1, help text says 2. With 1 slot there is no pipelining benefit.
  • LOG(INFO) in hot paths — multiple LOG(INFO) calls fire every iteration across command_buffer_thunk.cc and gpu_executable.cc; should be VLOG(2)/VLOG(3).

See inline comments for details and additional suggestions.

@github-actions github-actions bot removed the claude-review Request a Claude AI code review for this PR label Apr 3, 2026
phambinhfin added a commit that referenced this pull request Apr 3, 2026
Fixes from Claude code review on PR #774:

1. Convert LOG(INFO) to VLOG(3) in hot paths (command_buffer_thunk.cc,
   gpu_executable.cc) to avoid flooding production logs
2. Fix help text: default pool slots is 1, not 2
3. Thread safety: use std::atomic<uint64_t> for iteration_count and
   add per-pool mutex for initialization
4. Cache pool_indexes/copy_indexes in CircularPoolState (computed once
   at init, reused every iteration instead of recomputed)
5. Add stream sync after live-out copy-back when block_host_until_done
   is true, ensuring caller reads correct data at BFC address
6. Remove unused last_param_addrs field from CircularPoolState

- Convert LOG(INFO) to VLOG(3) in hot paths to avoid flooding logs
- Fix help text: default pool slots is 1, not 2
- Thread safety: use std::atomic for iteration_count, add per-pool
  mutex for initialization
- Cache pool_indexes/copy_indexes in CircularPoolState instead of
  recomputing every iteration
- Add stream sync after live-out copy-back when block_host_until_done
- Remove unused last_param_addrs field
@phambinhfin phambinhfin force-pushed the phambinh/circular_vmm_pool branch from d67610c to af977d0 Compare April 3, 2026 12:26
@phambinhfin phambinhfin added the claude-review Request a Claude AI code review for this PR label Apr 3, 2026

- Add timeout (30s) and sched_yield to AcquireNextSlot spin-wait to
  prevent unbounded CPU burn if GPU stalls
- Remove unused is_rdma_supported parameter from RocmVmmAllocator
  (hipMemAllocationProp has no RDMA flag on ROCm)
- Fix misleading comment: timeline pointer freed with hipFree, not
  hipHostFree (allocated via hipExtMallocWithFlags signal memory)
@phambinhfin phambinhfin added claude-review Request a Claude AI code review for this PR and removed claude-review Request a Claude AI code review for this PR labels Apr 3, 2026
}

// Cache buffer classification on first call (doesn't change between iters).
if (!pool_state->indexes_cached) {

Bug: data race in double-checked locking on indexes_cached

indexes_cached is a plain bool (declared at gpu_executable.h:449), but it is read here outside the lock. Another thread may be writing true to it (at line 1548) under the lock. Reading a non-atomic variable concurrently with a write is undefined behavior in the C++ memory model.

Either:

  1. Change indexes_cached to std::atomic<bool> and use load(std::memory_order_acquire) / store(true, std::memory_order_release) for the double-checked locking pattern, or
  2. Remove the outer check and always acquire the lock (the lock is uncontended after the first call, so the overhead is negligible).


Resolved -- addressed in this revision. indexes_cached replaced by std::atomic initialized with proper acquire/release memory ordering for safe double-checked locking.


Resolved: indexes_cached has been replaced by std::atomic<bool> initialized with proper acquire/release ordering, fixing the data race.

Comment on lines +1555 to +1584
if (pool_state->pool == nullptr) {
int num_slots = has_module()
? module_config().debug_options().xla_gpu_circular_vmm_pool_slots()
: 1;

if (pool_indexes.empty()) {
return ExecuteThunksImpl(
has_module() ? &module_config().debug_options() : nullptr,
module_name_, unique_id, *thunk_executor_, executable_source,
run_options, buffer_allocations, block_host_until_done,
execution_stream_ids_, collective_memory_cache_);
}

std::vector<uint64_t> buffer_sizes;
buffer_sizes.reserve(pool_indexes.size());
for (BufferAllocation::Index idx : pool_indexes) {
buffer_sizes.push_back(buffer_allocations.GetDeviceAddress(idx).size());
}

TF_ASSIGN_OR_RETURN(
auto pool,
se::gpu::CircularVmmPool::Create(executor, buffer_sizes, num_slots));

LOG(INFO) << absl::StrFormat(
"CircularVmmPool: created %d slots for module %s on device %d "
"(%d command buffer allocations)",
num_slots, module_name_, executor->device_ordinal(),
command_buffer_allocation_indexes_.size());

pool_state->pool = std::move(pool);

Bug: pool initialization race -- pool == nullptr checked without lock

If two threads call ExecuteThunksWithCircularVmmPool concurrently for the same executor before the pool is initialized, both will see pool_state->pool == nullptr and both will create a pool. The second std::move(pool) assignment overwrites the first, and any iteration already in-flight on the first pool will use stale state.

This should use the same double-checked locking pattern as indexes_cached (once that is fixed to use atomics), or simply hold pool_state->mu during initialization:

if (pool_state->pool == nullptr) {
  absl::MutexLock lock(&pool_state->mu);
  if (pool_state->pool == nullptr) {
    // ... create pool ...
  }
}


Resolved -- addressed in this revision. Pool initialization is now guarded by atomic initialized + per-pool mutex double-checked locking. The pool creation, index computation, and initialized flag are all protected within the same critical section.


Resolved: pool initialization is now inside an atomic double-checked locking pattern with a per-pool mutex, eliminating the race.

xla/xla.proto Outdated
// map/unmap overhead entirely after startup.
optional bool xla_gpu_enable_circular_vmm_pool = 468;

// Number of slots in the circular VMM pool (default 2). Higher values allow

Nit: proto comment says "default 2" but actual default is 1

The default set in debug_options_flags.cc is set_xla_gpu_circular_vmm_pool_slots(1). This comment should say "(default 1)" to match. The flag help text was already corrected.


Resolved -- addressed in this revision. Proto comment now correctly says "(default 1)".


Resolved: proto comment now correctly says "(default 1)", matching the actual default.

Comment on lines +137 to +143
RocmMemoryReservation::~RocmMemoryReservation() {
if (ptr_ == nullptr) {
return;
}
std::unique_ptr<ActivateContext> activation = executor_->Activate();
auto unmap_status =
ToStatus(wrap::hipMemUnmap(ptr_, size_), "Error unmapping ROCm memory");

Nit: potential double-unmap when used with ScopedMapping

The destructor calls hipMemUnmap(ptr_, size_) on the full reservation range. However, if a ScopedMapping was created via MapTo() and is destroyed before the reservation (which is the case in CircularVmmPool::Slot due to C++ reverse-declaration-order destruction), ScopedMapping::~ScopedMapping() will have already called UnMap on the mapped sub-range.

This means the same range gets unmapped twice: once by the ScopedMapping destructor, then again here. HIP likely tolerates this (the LOG(ERROR) will fire silently), but it could mask real mapping bugs.

Consider either:

  • Tracking whether sub-ranges are already unmapped, or
  • Documenting that the reservation destructor is the sole owner of unmap responsibility and ScopedMapping should be released (moved-from) before the reservation is destroyed.


Resolved -- addressed in this revision. The CircularVmmPool destructor now explicitly releases ScopedMappings before clearing slots, ensuring only one unmap occurs per range.


Resolved: the destructor now explicitly releases ScopedMappings before clearing slots, avoiding the double-unmap scenario.

Comment on lines +103 to +117
```cpp
int device_count = 0;
TF_RETURN_IF_ERROR(ToStatus(wrap::hipGetDeviceCount(&device_count)));
for (int peer = 0; peer < device_count; peer++) {
  hipMemAccessDesc access_desc = GetVmmAccessDescriptor(peer);
  auto access_status =
      ToStatus(wrap::hipMemSetAccess(ptr, padded_size, &access_desc, 1));
  if (!access_status.ok()) {
    if (peer == executor->device_ordinal()) {
      return access_status;
    }
    VLOG(3) << "Could not set VMM access for peer device " << peer << ": "
            << access_status;
  }
}
```


Nit: inconsistent peer access pattern vs. RocmMemoryReservation::SetAccess

This loop unconditionally tries hipMemSetAccess for every device and silently logs failures for non-owning peers. In contrast, RocmMemoryReservation::SetAccess (in rocm_memory_reservation.cc:111-120) first checks CanEnablePeerAccessTo() and skips devices that don't support P2P access.

Both code paths serve the same purpose — setting VMM access permissions for peer devices. It would be more robust and consistent to use the same pattern (checking CanEnablePeerAccessTo() first) here as well, especially since unconditionally calling hipMemSetAccess on unsupported peers could produce spurious log noise on multi-GPU systems without full P2P connectivity.


Resolved -- addressed in this revision. rocm_vmm_allocator.cc now checks CanEnablePeerAccessTo() before calling hipMemSetAccess, matching the pattern in RocmMemoryReservation::SetAccess.


Resolved: peer access is now checked via CanEnablePeerAccessTo() before calling hipMemSetAccess, consistent with the pattern in RocmMemoryReservation::SetAccess.

@claude

claude bot commented Apr 3, 2026

Re-review Summary

Good progress — the majority of issues from the initial review have been addressed. Of the 12 original findings, 10 are fully resolved and 1 is partially resolved. Two new issues were identified in this revision, plus two minor nits on the new VMM infrastructure code.

Remaining issues (from thread updates + new findings)

Bugs (2 new):

  • indexes_cached double-checked locking data race — plain bool read outside lock needs std::atomic<bool> (gpu_executable.cc)
  • Pool initialization race — pool == nullptr check needs mutex protection (gpu_executable.cc)

Partially resolved (1):

  • Proto comment still says "default 2" but actual default is 1 (xla.proto)

Nits (2 new):

  • Potential double-unmap in RocmMemoryReservation destructor when used with ScopedMapping (rocm_memory_reservation.cc)
  • Inconsistent peer access pattern between VmmAllocate and RocmMemoryReservation::SetAccess (rocm_vmm_allocator.cc)

Still open from initial review (1, no developer reply):

  • std::shared_ptr<void> type erasure for pool storage (gpu_executable.h)

@github-actions github-actions bot removed the claude-review Request a Claude AI code review for this PR label Apr 3, 2026
- Fix data race: use std::atomic<bool> with acquire/release for
  double-checked locking on pool initialization
- Fix pool initialization race: merge index caching and pool creation
  into single mutex-protected block, checked via atomic initialized flag
- Fix proto comment: default slots is 1, not 2
- Fix double-unmap: release ScopedMappings before reservation destructor
  in CircularVmmPool to avoid unmapping same range twice
- Align peer access pattern in rocm_vmm_allocator.cc with
  RocmMemoryReservation::SetAccess (check CanEnablePeerAccessTo first)
```
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",
```

Bug: missing BUILD dependency @com_google_absl//absl/time

circular_vmm_pool.cc includes absl/time/clock.h and absl/time/time.h (for absl::Now(), absl::Seconds, absl::Time in AcquireNextSlot), but this BUILD target does not declare @com_google_absl//absl/time as a dependency. This will cause a build failure when strict dependency checking is enabled (Bazel --strict_deps / layering check).

Suggested change:

```diff
 "@com_google_absl//absl/types:span",
+"@com_google_absl//absl/time",
```


Resolved ✓ — addressed in this revision.

```cpp
std::vector<Slot> slots;
slots.reserve(num_slots);
for (int i = 0; i < num_slots; ++i) {
  TF_ASSIGN_OR_RETURN(auto physical,
```

Bug: signal memory leaked on error path

timeline_host_ptr is allocated via hipExtMallocWithFlags at line 115 but is only freed inside ~CircularVmmPool. If any of the TF_ASSIGN_OR_RETURN calls in this loop fail (e.g., RocmRawMemoryAllocation::Create or RocmMemoryReservation::Create), the function returns early without constructing a CircularVmmPool object, so the destructor never runs and timeline_host_ptr is leaked.

Consider wrapping timeline_host_ptr in a cleanup guard (e.g., absl::Cleanup) that calls hipFree and is dismissed/released only after the CircularVmmPool is successfully constructed:

```cpp
auto timeline_cleanup = absl::MakeCleanup([&]() {
  wrap::hipFree(timeline_host_ptr);
});
// ... slot creation loop ...
std::move(timeline_cleanup).Cancel();  // dismiss before returning
return absl::WrapUnique(new CircularVmmPool(...));
```


Resolved ✓ — addressed in this revision.

@claude

claude bot commented Apr 3, 2026

Re-review Summary

Good progress — 11 of 15 prior findings have been addressed. 2 new issues found in this cycle (posted inline). 4 prior findings remain open without developer reply (see existing inline threads).

New findings (2):

  • Missing @com_google_absl//absl/time BUILD dep for circular_vmm_pool
  • Signal memory (timeline_host_ptr) leaked on error path during slot creation

Still open from prior review (4):

  • shared_ptr<void> type erasure in CircularPoolState
  • Data race on indexes_cached (plain bool in double-checked locking)
  • Pool initialization race (pool==nullptr checked without lock)
  • Potential double-unmap in RocmMemoryReservation destructor with live ScopedMapping

🤖 Generated with Claude Code

@phambinhfin phambinhfin added the claude-review Request a Claude AI code review for this PR label Apr 3, 2026
@claude

claude bot commented Apr 3, 2026

Re-review Summary

6 previously flagged issues are now resolved:

  • Default value / help text mismatch (proto + flags) — fixed
  • Data race on indexes_cached — replaced with std::atomic<bool> with acquire/release
  • Pool initialization race — now uses atomic DCL with per-pool mutex
  • Proto comment default mismatch — corrected
  • Potential double-unmap in reservation destructor — ScopedMappings released explicitly before slots
  • Inconsistent peer access pattern — now checks CanEnablePeerAccessTo() first

3 prior issues still open (no new comments posted — see existing inline threads):

  • std::shared_ptr<void> type erasure for pool storage (nit)
  • Missing @com_google_absl//absl/time BUILD dependency
  • Signal memory leak on error path in slot creation

No new issues found in this revision. Good progress on addressing the prior review feedback.

@github-actions github-actions bot removed the claude-review Request a Claude AI code review for this PR label Apr 3, 2026
- Add absl::Cleanup guard for timeline signal memory in
  CircularVmmPool::Create so hipExtMallocWithFlags allocation is freed
  if slot creation fails partway through
- Add missing BUILD deps: @com_google_absl//absl/time and
  @com_google_absl//absl/cleanup for circular_vmm_pool target
- Replace std::shared_ptr<void> with typed
  std::shared_ptr<CircularVmmPool> using forward declaration,
  eliminating unsafe static_cast at use sites
@phambinhfin phambinhfin added the claude-review Request a Claude AI code review for this PR label Apr 3, 2026
Comment on lines +450 to +454
std::shared_ptr<stream_executor::gpu::CircularVmmPool> pool
ABSL_GUARDED_BY(mu);
std::atomic<uint64_t> iteration_count{0};
absl::btree_set<BufferAllocation::Index> pool_indexes ABSL_GUARDED_BY(mu);
absl::btree_set<BufferAllocation::Index> copy_indexes ABSL_GUARDED_BY(mu);

Nit: ABSL_GUARDED_BY(mu) annotations are inaccurate after initialization

pool, pool_indexes, and copy_indexes are annotated ABSL_GUARDED_BY(mu), but after the one-time initialization they are read without holding mu in ExecuteThunksWithCircularVmmPool (lines 1582-1585, 1593 of gpu_executable.cc). This is safe because the atomic double-checked locking on initialized guarantees the fields are immutable post-init, but the annotations will produce false positives with Clang's thread safety analysis (-Wthread-safety).

Consider either:

  1. Removing the ABSL_GUARDED_BY(mu) annotations and adding a comment that these fields are write-once under mu and immutable thereafter (reads protected by the initialized acquire fence), or
  2. Keeping the lock held for reads as well (unnecessary overhead for the hot path).

Comment on lines +1666 to +1672
if (block_host_until_done) {
TF_RETURN_IF_ERROR(run_options->stream()->BlockHostUntilDone());
}

// GPU signals slot completion so the CPU knows when this slot is safe to
// reuse (non-blocking write via hipStreamWriteValue64).
TF_RETURN_IF_ERROR(pool->ReleaseSlot(run_options->stream(), iteration));

[Medium] ReleaseSlot (which enqueues hipStreamWriteValue64) is called after BlockHostUntilDone. This means the GPU-side timeline signal write is not covered by the host synchronization — if the caller destroys the pool (or the process exits) immediately after this function returns with block_host_until_done=true, the signal write may not have executed yet.

Consider swapping the order so ReleaseSlot is enqueued before the sync:

Suggested change:

```diff
-if (block_host_until_done) {
-  TF_RETURN_IF_ERROR(run_options->stream()->BlockHostUntilDone());
-}
-
-// GPU signals slot completion so the CPU knows when this slot is safe to
-// reuse (non-blocking write via hipStreamWriteValue64).
-TF_RETURN_IF_ERROR(pool->ReleaseSlot(run_options->stream(), iteration));
+// GPU signals slot completion so the CPU knows when this slot is safe to
+// reuse (non-blocking write via hipStreamWriteValue64).
+TF_RETURN_IF_ERROR(pool->ReleaseSlot(run_options->stream(), iteration));
+if (block_host_until_done) {
+  TF_RETURN_IF_ERROR(run_options->stream()->BlockHostUntilDone());
+}
```

This ensures the timeline signal is included in the synchronization fence.

Comment on lines +450 to +454
std::shared_ptr<stream_executor::gpu::CircularVmmPool> pool
ABSL_GUARDED_BY(mu);
std::atomic<uint64_t> iteration_count{0};
absl::btree_set<BufferAllocation::Index> pool_indexes ABSL_GUARDED_BY(mu);
absl::btree_set<BufferAllocation::Index> copy_indexes ABSL_GUARDED_BY(mu);

[Low] pool, pool_indexes, and copy_indexes are annotated ABSL_GUARDED_BY(mu), but after initialization they are read without holding mu in the hot path (lines 1582-1593 of gpu_executable.cc). This is functionally safe because these are write-once fields protected by the initialized atomic, but the annotations are inconsistent and will produce false positives if Clang's thread safety analysis is enabled.

Consider either:

  1. Removing ABSL_GUARDED_BY(mu) (since the fields are effectively immutable post-init and the atomic<bool> initialized provides the barrier), or
  2. Adding ABSL_NO_THREAD_SAFETY_ANALYSIS to the hot-path reader function.

Comment on lines +105 to +119
int device_count = 0;
TF_RETURN_IF_ERROR(ToStatus(wrap::hipGetDeviceCount(&device_count)));
for (int peer = 0; peer < device_count; peer++) {
if (peer != executor->device_ordinal()) {
auto peer_executor_or =
const_cast<Platform*>(executor->GetPlatform())
->ExecutorForDevice(peer);
if (!peer_executor_or.ok() ||
!executor->CanEnablePeerAccessTo(peer_executor_or.value())) {
continue;
}
}
hipMemAccessDesc access_desc = GetVmmAccessDescriptor(peer);
TF_RETURN_IF_ERROR(ToStatus(
wrap::hipMemSetAccess(ptr, padded_size, &access_desc, 1),

[Medium] If hipMemSetAccess fails (line 118-119) or hipGetDeviceCount fails (line 106), TF_RETURN_IF_ERROR returns immediately without cleaning up the already-mapped memory. At this point hipMemMap has succeeded (line 90), but the error path never calls hipMemUnmap, hipMemAddressFree, or hipMemRelease — leaking both the VA mapping and the physical allocation handle.

Consider using an absl::Cleanup guard after the successful hipMemMap to ensure proper cleanup on all subsequent error paths:

```cpp
auto map_cleanup = absl::MakeCleanup([&]() {
  wrap::hipMemUnmap(ptr, padded_size);
  wrap::hipMemAddressFree(ptr, padded_size);
  wrap::hipMemRelease(handle);
});
// ... hipMemSetAccess calls ...
std::move(map_cleanup).Cancel();
return std::make_tuple(ptr, padded_size, handle);
```


```cpp
auto* pool = pool_state->pool.get();
uint64_t iteration = pool_state->iteration_count.fetch_add(1);
int slot_idx = iteration % pool->num_slots();
```

[Nit] slot_idx is only used in the VLOG(3) on line 1601. In builds where VLOG(3) compiles out, this may trigger an unused-variable warning. Consider moving this computation inside the VLOG statement or guarding it with VLOG_IS_ON(3).

@claude

claude bot commented Apr 3, 2026

Re-review Summary (Round 3)

All 19 previous inline findings have been addressed and resolved — nice work on the fixes across the board (signal memory leak, thread safety, logging levels, annotation mismatches, unused fields, etc.).

This round identified 4 new findings (2 medium, 1 low, 1 nit), posted inline:

  • ReleaseSlot after BlockHostUntilDone ordering — the timeline signal write is not covered by the host sync fence (medium)
  • VmmAllocate error path leak — hipMemSetAccess failure leaks the already-mapped VA and physical handle (medium)
  • ABSL_GUARDED_BY annotation mismatch — pool/pool_indexes/copy_indexes are read without lock post-init (low)
  • Unused slot_idx variable — only used inside VLOG(3), may warn in release builds (nit)

Overall the circular VMM pool design is solid — pre-allocated slots with permanent VA mappings and GPU timeline signaling is a clean approach to eliminating per-iteration hipMemMap/hipMemUnmap overhead.

🤖 Generated with Claude Code

@github-actions github-actions bot removed the claude-review Request a Claude AI code review for this PR label Apr 3, 2026
@phambinhfin phambinhfin added the claude-review Request a Claude AI code review for this PR label Apr 3, 2026
@github-actions github-actions bot removed the claude-review Request a Claude AI code review for this PR label Apr 3, 2026
@phambinhfin phambinhfin added the claude-review Request a Claude AI code review for this PR label Apr 3, 2026
```cpp
    pool_state->initialized.load(std::memory_order_acquire)) {
  auto cache_it = pool_state->alias_cache.find(arg_idx);
  if (cache_it != pool_state->alias_cache.end()) {
    return cache_it->second;
```

[Medium] alias_cache always returns slot-0 alias VA addresses regardless of which slot the current iteration will actually use. With the default num_slots=1 this is correct (slot 0 is the only slot, so alias and graph VAs share the same physical memory). However, with num_slots>1, there is a potential correctness issue:

  1. GenerateBufferAllocations calls BufferForAllocation, which returns slot-0 alias addresses for all pooled live-out buffers (via this alias_cache lookup).
  2. The live-out result buffers in the ExecutionOutput are set from these slot-0 alias addresses (line 1283: buffer_allocations.GetDeviceAddress(output_info.allocation_index)).
  3. ExecuteThunksWithCircularVmmPool then runs the graph on slot X (where X = iteration % N), writing results to slot-X's graph VA (different physical memory from slot 0).
  4. Downstream consumers read from the slot-0 alias VA in the ExecutionOutput, which points to slot-0's physical memory -- not slot-X's.

Result: with N>1, live-out data from iterations that don't land on slot 0 would be read from the wrong physical memory.

Since the default is N=1 and the PR description recommends N=1, this is safe in practice today. But if someone sets xla_gpu_circular_vmm_pool_slots>1, it could produce silent data corruption. Consider either:

  • Documenting that N>1 is not yet supported, or
  • Updating the alias cache per-iteration (or making ExecuteThunksWithCircularVmmPool patch the result buffers after execution).

Comment on lines +230 to +256
```cpp
absl::StatusOr<std::vector<DeviceAddressBase>>
CircularVmmPool::AcquireNextSlotAliases(uint64_t iteration) {
  int slot_idx = iteration % num_slots_;

  if (iteration >= static_cast<uint64_t>(num_slots_)) {
    uint64_t required = iteration - num_slots_ + 1;
    uint64_t completed = __atomic_load_n(timeline_, __ATOMIC_ACQUIRE);
    constexpr int kMaxSpinIterations = 1000;
    constexpr auto kTimeout = absl::Seconds(30);
    auto deadline = absl::Now() + kTimeout;
    int spin_count = 0;
    while (completed < required) {
      if (++spin_count > kMaxSpinIterations) {
        if (absl::Now() > deadline) {
          return absl::DeadlineExceededError(absl::StrFormat(
              "CircularVmmPool: timed out waiting for alias slot %d "
              "(required=%d, completed=%d)",
              iteration % num_slots_, required, completed));
        }
        sched_yield();
        spin_count = 0;
      }
      completed = __atomic_load_n(timeline_, __ATOMIC_ACQUIRE);
    }
  }

  return slots_[slot_idx].alias_addresses;
}
```

[Nit] AcquireNextSlotAliases duplicates the entire spin-wait logic from AcquireNextSlot (lines 199-228). Consider extracting the spin-wait into a private helper, e.g. WaitForSlotAvailable(uint64_t iteration), so both methods call it before returning their respective address vectors. This avoids maintaining the same timeout/yield logic in two places.

@claude

claude bot commented Apr 3, 2026

Re-review Summary

Resolved (2): Missing @com_google_absl//absl/time BUILD dep and signal memory leak (now guarded by absl::MakeCleanup) — both fixed since last review.

Still open from prior review (5): ReleaseSlot ordering after BlockHostUntilDone, GUARDED_BY annotation inconsistencies, hipMemSetAccess error-path leak, unused slot_idx variable warning — these remain unaddressed.

New findings (2): alias_cache correctness issue with N>1 slots (medium severity), and AcquireNextSlotAliases spin-wait code duplication (nit). Posted inline.

🤖 Generated with Claude Code

@github-actions github-actions bot removed the claude-review Request a Claude AI code review for this PR label Apr 3, 2026
@phambinhfin phambinhfin force-pushed the phambinh/circular_vmm_pool branch from 341069b to bea785e Compare April 3, 2026 22:32
… pool

The circular VMM pool was forcing enable_va_remapping=true in
CommandBufferThunk, which skips the warmup/initialization path
that NCCL collectives require. This caused a segfault at step 3
on MaxText LLaMA 3 8B with 8x MI300X FSDP-8.

The fix: don't set enable_va_remapping when only the circular pool
is enabled. The pool provides permanently stable VA addresses, so
CommandBufferThunk naturally detects no address changes and replays
the graph without re-recording. This allows the normal warmup path
to run, properly initializing NCCL collectives.
@phambinhfin phambinhfin force-pushed the phambinh/circular_vmm_pool branch from bea785e to 33a92a8 Compare April 3, 2026 22:51