Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
229 changes: 155 additions & 74 deletions cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -586,6 +586,10 @@ class CacheSender::Impl
// not be removed from mCancelledRequests. This should be handled by timeout.
auto it = mReadyResponses.find(mCurrentRequest.value());
TLLM_CHECK(it != mReadyResponses.end());
auto cancelledException
= TLLM_REQUEST_EXCEPTION(reqId, tensorrt_llm::common::RequestErrorCode::kNETWORK_ERROR,
"Context KV cache transfer cancelled after ready-signal for request %zu", reqId);
it->second.mPromise.set_exception(std::make_exception_ptr(cancelledException));
{
std::scoped_lock lkResp(mSenderMutex);
mReadyResponses.erase(it);
Expand Down Expand Up @@ -847,103 +851,149 @@ class CacheReceiver::Impl

auto* agentConnectionManager = dynamic_cast<executor::kv_cache::AgentConnectionManager*>(mManager);
std::vector<std::optional<size_t>> cacheBufferIds;
// Track every recv-buffer slot we reserve here so the RAII guard
// below can release them again if anything later in this function
// throws. Without this, an exception between assignBufferIndexForRecv()
// and the eventual freeBufferIndexForRecv() call inside unformat() /
// receiveSync() leaks the slot, and because mRecvBufferCount defaults
// to 1 the next assignBufferIndexForRecv() then blocks forever in the
// unbounded cv.wait inside BaseTransBufferManager::assignBufferIndex
// (signature #6 of NVBug 6104831).
std::vector<std::pair<BaseTransBufferManager*, std::optional<size_t>>> assignedRecvBuffers;
if (agentConnectionManager)
{
for (auto& cacheTransBufferManager : agentConnectionManager->getCacheTransBufferManagers())
{
cacheBufferIds.push_back(cacheTransBufferManager->assignBufferIndexForRecv());
auto reservedId = cacheTransBufferManager->assignBufferIndexForRecv();
cacheBufferIds.push_back(reservedId);
assignedRecvBuffers.emplace_back(cacheTransBufferManager, reservedId);
}
TLLM_CHECK(!cacheBufferIds.empty());
}
auto freeAssignedRecvBuffers = [&assignedRecvBuffers, &llmRequest]() noexcept
{
for (auto& [mgr, id] : assignedRecvBuffers)
{
if (mgr != nullptr && id.has_value())
{
try
{
mgr->freeBufferIndexForRecv(id);
}
catch (std::exception const& freeExc)
{
TLLM_LOG_ERROR("Failed to free recv buffer index for request %zu during cleanup: %s",
llmRequest.mRequestId, freeExc.what());
}
}
}
assignedRecvBuffers.clear();
};

auto allCounterparts
= mCacheTransferLayer.computeCounterparts(mSelfState.getCommState().value().getSelfIdx(), contextState);
try
{
auto allCounterparts
= mCacheTransferLayer.computeCounterparts(mSelfState.getCommState().value().getSelfIdx(), contextState);

auto kvCounterParts = mCacheTransferLayer.getKvFormatter()->getCounterparts(
mCacheTransferLayer.getCacheState(), mSelfState.getCommState().value().getSelfIdx(), destCacheState);
auto kvCounterParts = mCacheTransferLayer.getKvFormatter()->getCounterparts(
mCacheTransferLayer.getCacheState(), mSelfState.getCommState().value().getSelfIdx(), destCacheState);

bool hasRnn = mCacheTransferLayer.getCacheState().hasRnnConfig() && destCacheState.hasRnnConfig();
bool hasRnn = mCacheTransferLayer.getCacheState().hasRnnConfig() && destCacheState.hasRnnConfig();

std::vector<SizeType32> rnnCounterParts;
if (hasRnn)
{
rnnCounterParts = executor::kv_cache::targetIRanksForRnn(
destCacheState, mCacheTransferLayer.getCacheState(), mSelfState.getCommState().value().getSelfIdx())
.mIRanks;
}
std::vector<SizeType32> rnnCounterParts;
if (hasRnn)
{
rnnCounterParts = executor::kv_cache::targetIRanksForRnn(
destCacheState, mCacheTransferLayer.getCacheState(), mSelfState.getCommState().value().getSelfIdx())
.mIRanks;
}

auto connections = mManager->getConnections(commState);
std::vector<executor::kv_cache::Connection const*> allConnections;
for (auto index : allCounterparts)
{
auto const* connection = connections.at(index);
allConnections.emplace_back(connection);
}
auto connections = mManager->getConnections(commState);
std::vector<executor::kv_cache::Connection const*> allConnections;
for (auto index : allCounterparts)
{
auto const* connection = connections.at(index);
allConnections.emplace_back(connection);
}

for (size_t ci = 0; ci < allCounterparts.size(); ci++)
{
auto rank = allCounterparts[ci];
auto const* connection = connections.at(rank);
for (size_t ci = 0; ci < allCounterparts.size(); ci++)
{
auto rank = allCounterparts[ci];
auto const* connection = connections.at(rank);

bool isKvCounterpart
= std::find(kvCounterParts.begin(), kvCounterParts.end(), rank) != kvCounterParts.end();
bool isRnnCounterpart
= hasRnn && std::find(rnnCounterParts.begin(), rnnCounterParts.end(), rank) != rnnCounterParts.end();
bool isKvCounterpart
= std::find(kvCounterParts.begin(), kvCounterParts.end(), rank) != kvCounterParts.end();
bool isRnnCounterpart = hasRnn
&& std::find(rnnCounterParts.begin(), rnnCounterParts.end(), rank) != rnnCounterParts.end();

if (agentConnectionManager)
{
auto idsForRank = cacheBufferIds;
auto const& managers = agentConnectionManager->getCacheTransBufferManagers();
for (size_t i = 0; i < idsForRank.size(); i++)
if (agentConnectionManager)
{
auto kind = managers[i]->getBufferKind();
bool include = (kind != BufferKind::kRNN) ? isKvCounterpart : isRnnCounterpart;
if (!include)
auto idsForRank = cacheBufferIds;
auto const& managers = agentConnectionManager->getCacheTransBufferManagers();
for (size_t i = 0; i < idsForRank.size(); i++)
{
idsForRank[i] = std::nullopt;
auto kind = managers[i]->getBufferKind();
bool include = (kind != BufferKind::kRNN) ? isKvCounterpart : isRnnCounterpart;
if (!include)
{
idsForRank[i] = std::nullopt;
}
}
}

int validConnectionIdx = 0;
if (isKvCounterpart)
{
auto kvCpIdx
= std::find(kvCounterParts.begin(), kvCounterParts.end(), rank) - kvCounterParts.begin();
auto [pickUpIdx, localRankIdx] = mCacheTransferLayer.getKvFormatter()->pickRecvConnections(
allCounterparts.size(), mSelfState.getCacheState().value(),
mSelfState.getCommState().value().getSelfIdx(), destCacheState, allCounterparts);
validConnectionIdx
= std::find(localRankIdx.begin(), localRankIdx.end(), kvCpIdx) - localRankIdx.begin();
int validConnectionIdx = 0;
if (isKvCounterpart)
{
auto kvCpIdx
= std::find(kvCounterParts.begin(), kvCounterParts.end(), rank) - kvCounterParts.begin();
auto [pickUpIdx, localRankIdx] = mCacheTransferLayer.getKvFormatter()->pickRecvConnections(
allCounterparts.size(), mSelfState.getCacheState().value(),
mSelfState.getCommState().value().getSelfIdx(), destCacheState, allCounterparts);
validConnectionIdx
= std::find(localRankIdx.begin(), localRankIdx.end(), kvCpIdx) - localRankIdx.begin();
}
else if (isRnnCounterpart)
{
auto rnnTargetInfo = executor::kv_cache::targetIRanksForRnn(destCacheState,
mCacheTransferLayer.getCacheState(), mSelfState.getCommState().value().getSelfIdx());
auto rnnCpIdx
= std::find(rnnCounterParts.begin(), rnnCounterParts.end(), rank) - rnnCounterParts.begin();
auto [pickUpIdx, localRankIdx]
= cache_formatter_utils::pickRecvConnections(rnnCounterParts.size(),
mCacheTransferLayer.getCacheState(), mSelfState.getCommState().value().getSelfIdx(),
destCacheState, rnnCounterParts, rnnTargetInfo);
validConnectionIdx
= std::find(localRankIdx.begin(), localRankIdx.end(), rnnCpIdx) - localRankIdx.begin();
}

auto* agentConnection = dynamic_cast<executor::kv_cache::AgentConnection const*>(connection);
TLLM_CHECK(agentConnection != nullptr);

const_cast<executor::kv_cache::AgentConnection*>(agentConnection)
->sendRequestAndBufferInfo(requestInfo, idsForRank, validConnectionIdx);
}
else if (isRnnCounterpart)
else
{
auto rnnTargetInfo = executor::kv_cache::targetIRanksForRnn(destCacheState,
mCacheTransferLayer.getCacheState(), mSelfState.getCommState().value().getSelfIdx());
auto rnnCpIdx
= std::find(rnnCounterParts.begin(), rnnCounterParts.end(), rank) - rnnCounterParts.begin();
auto [pickUpIdx, localRankIdx] = cache_formatter_utils::pickRecvConnections(rnnCounterParts.size(),
mCacheTransferLayer.getCacheState(), mSelfState.getCommState().value().getSelfIdx(),
destCacheState, rnnCounterParts, rnnTargetInfo);
validConnectionIdx
= std::find(localRankIdx.begin(), localRankIdx.end(), rnnCpIdx) - localRankIdx.begin();
sendRequestInfo(connection, requestInfo);
}

auto* agentConnection = dynamic_cast<executor::kv_cache::AgentConnection const*>(connection);
TLLM_CHECK(agentConnection != nullptr);

const_cast<executor::kv_cache::AgentConnection*>(agentConnection)
->sendRequestAndBufferInfo(requestInfo, idsForRank, validConnectionIdx);
}
else
{
sendRequestInfo(connection, requestInfo);
}
auto const& resource = getReceiveCacheResource(llmRequest);
// Buffer indices are now owned by the agent connections
// (mCacheBufferIds) and will be freed by unformat() during
// receiveSync() on the success path, or by the !isReady early
// return in requestSync() on the cancelled-after-ready path.
// Hand ownership off and clear the local reservation list so
// the catch below does not double-free on the success path.
assignedRecvBuffers.clear();
return TransferSession(std::move(allConnections), DataContext{tagFromRequestId(requestId), mTerminate},
std::move(allCounterparts), mSelfState, contextState, resource->mBufferManager,
requestInfo.getIndexFromEnd(), requestInfo.getLastBlockKey(), &llmRequest,
!common::getEnvKVCacheTimeOutputPath().empty());
}
catch (...)
{
freeAssignedRecvBuffers();
throw;
}
auto const& resource = getReceiveCacheResource(llmRequest);
return TransferSession(std::move(allConnections), DataContext{tagFromRequestId(requestId), mTerminate},
std::move(allCounterparts), mSelfState, contextState, resource->mBufferManager,
requestInfo.getIndexFromEnd(), requestInfo.getLastBlockKey(), &llmRequest,
!common::getEnvKVCacheTimeOutputPath().empty());
}

std::unique_ptr<ReceiveCacheResource> const& getReceiveCacheResource(LlmRequest const& llmRequest)
Expand Down Expand Up @@ -1062,6 +1112,37 @@ class CacheReceiver::Impl
// Reuse the error state for the cancelled request.
llmRequest.setState(LlmRequestState::kDISAGG_TRANS_ERROR);
llmRequest.setKvCacheTransferEnd(std::chrono::steady_clock::now());
// Mirror what unformat() does on the success path: explicitly free
// any pre-assigned recv buffer indices so a cancelled-after-ready
// request does not leak the recv buffer slot. Without this, the
// next assignBufferIndexForRecv() call blocks forever in the
// unbounded cv.wait inside BaseTransBufferManager::assignBufferIndex
// (signature #6 of NVBug 6104831). The Layer A guard in
// sendRequestInfo() already covers exception paths; this branch
// covers the structured early return that fires whenever the
// sender-side cancellation path (signature #1 fix) sends
// is_ready=false.
auto* agentConnectionManager = dynamic_cast<executor::kv_cache::AgentConnectionManager*>(mManager);
if (agentConnectionManager != nullptr)
{
for (auto const* connection : session.getConnections())
{
auto const* agentConnection = dynamic_cast<executor::kv_cache::AgentConnection const*>(connection);
if (agentConnection == nullptr)
{
continue;
}
for (auto& mgr : agentConnectionManager->getCacheTransBufferManagers())
{
auto cacheBufferId
= agentConnection->getPreAssignedBufferId(static_cast<uint8_t>(mgr->getBufferKind()));
if (cacheBufferId.has_value())
{
mgr->freeBufferIndexForRecv(cacheBufferId);
}
}
}
}
return;
}
receiveSync(session);
Expand Down
Loading
Loading