44 changes: 34 additions & 10 deletions cpp/include/tensorrt_llm/batch_manager/cacheTransceiver.h
@@ -35,6 +35,7 @@
#include <torch/custom_class.h>
#include <torch/python.h>
#include <type_traits>
#include <unordered_set>
#include <vector>

using SizeType32 = tensorrt_llm::runtime::SizeType32;
@@ -204,13 +205,17 @@ class BaseCacheTransceiver
{
public:
virtual ~BaseCacheTransceiver() = default;
virtual void respondAndSendAsync(LlmRequest* llmRequest) = 0;
// These methods take std::shared_ptr<LlmRequest> so the transceiver and
// its async workers can hold a strong reference for the duration of the
// transfer. See the comment on CacheTransceiver::mSenderFutures for the
// lifetime invariant (kept in one place to avoid drift).
virtual void respondAndSendAsync(std::shared_ptr<LlmRequest> llmRequest) = 0;
virtual void respondAndSendLayerWise(
RequestVector const& requests, std::shared_ptr<ContextProgress> const& progress)
= 0;

virtual void requestAndReceiveSync(LlmRequest* llmRequest) = 0;
virtual void requestAndReceiveAsync(LlmRequest* llmRequest) = 0;
virtual void requestAndReceiveSync(std::shared_ptr<LlmRequest> llmRequest) = 0;
virtual void requestAndReceiveAsync(std::shared_ptr<LlmRequest> llmRequest) = 0;

/// Check all requests transferring context, and return the requests that have completed or encountered an error.
virtual RequestStatuses checkContextTransferStatus(
@@ -221,7 +226,7 @@ class BaseCacheTransceiver

[[nodiscard]] virtual bool checkGenTransferComplete() const = 0;

virtual bool cancelRequest(LlmRequest* llmRequest) = 0;
virtual bool cancelRequest(std::shared_ptr<LlmRequest> llmRequest) = 0;
};

class CacheTransceiver : public BaseCacheTransceiver
@@ -252,13 +257,13 @@ class CacheTransceiver : public BaseCacheTransceiver

virtual ~CacheTransceiver();

void respondAndSendAsync(LlmRequest* llmRequest) override;
void respondAndSendAsync(std::shared_ptr<LlmRequest> llmRequest) override;

void respondAndSendLayerWise(
RequestVector const& requests, std::shared_ptr<ContextProgress> const& progress) override;

void requestAndReceiveSync(LlmRequest* llmRequest) override;
void requestAndReceiveAsync(LlmRequest* llmRequest) override;
void requestAndReceiveSync(std::shared_ptr<LlmRequest> llmRequest) override;
void requestAndReceiveAsync(std::shared_ptr<LlmRequest> llmRequest) override;

RequestStatuses checkContextTransferStatus(
std::optional<int> const& atLeastRequestNum = std::nullopt, bool markComplete = false) override;
@@ -267,7 +272,7 @@

[[nodiscard]] bool checkGenTransferComplete() const override;

virtual bool cancelRequest(LlmRequest* llmRequest) override;
virtual bool cancelRequest(std::shared_ptr<LlmRequest> llmRequest) override;

private:
void initializeCommState();
@@ -276,8 +281,27 @@

std::unique_ptr<CacheSender> mCacheSender;
std::unique_ptr<CacheReceiver> mCacheReceiver;
std::vector<std::pair<LlmRequest*, std::future<void>>> mSenderFutures;
std::vector<std::pair<LlmRequest*, std::future<void>>> mRequesterFutures;
// Store shared_ptr rather than raw LlmRequest* so these vectors hold a
// strong reference for the duration of the transfer. Otherwise Python's
// _terminate_request can drop its pybind shared_ptr while the C++ side's
// raw pointer is still dereferenced by checkGenTransferStatus /
// checkContextTransferStatus (a use-after-free confirmed with
// MALLOC_PERTURB_=85, which fills freed memory with 0x55 and so yields
// mRequestId=0x5555555555555555).
//
// Eviction policy is asymmetric:
// - mRequesterFutures (gen side): on timeout, keep the entry tracked
// via mTimedOutRequesterIds until the worker future resolves. A
// timeout/cancel is not a quiescence proof on the recv side, so the
// advertised receive buffers may still be written to until the worker
// unwinds. See checkGenTransferStatus.
// - mSenderFutures (ctx side): erased immediately on completion,
// exception, or timeout. Sender zombies empirically unwind on peer
// teardown (decode-pod restart), and CacheSender::cancelRequest is
// only required to clear bookkeeping for telemetry / re-enqueue
// paths. See checkContextTransferStatus.
std::vector<std::pair<std::shared_ptr<LlmRequest>, std::future<void>>> mSenderFutures;
std::vector<std::pair<std::shared_ptr<LlmRequest>, std::future<void>>> mRequesterFutures;
std::unordered_set<LlmRequest::RequestIdType> mTimedOutRequesterIds;
mpi::MpiComm const* mMpiWorldComm{nullptr};

std::shared_ptr<CacheTransceiverComm> mGroupComm;
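The lifetime invariant above is easiest to see in isolation. Below is a minimal, self-contained C++ sketch — the Request and Registry types are hypothetical stand-ins for LlmRequest and CacheTransceiver, not the real API — showing why the tracking entry must co-own the request: the entry and the async worker each hold a strong reference, so an external owner (the inner scope below plays the role of Python) dropping its shared_ptr mid-transfer cannot leave the status-check path dereferencing freed memory.

#include <chrono>
#include <cstdint>
#include <future>
#include <memory>
#include <thread>
#include <utility>
#include <vector>

struct Request // stand-in for LlmRequest
{
    std::uint64_t mRequestId;
};

class Registry // stand-in for the transceiver's futures vectors
{
public:
    void sendAsync(std::shared_ptr<Request> req)
    {
        // Copy the shared_ptr into the worker AND into the tracking entry:
        // even if every external owner drops its reference mid-transfer,
        // the Request outlives the transfer.
        auto fut = std::async(std::launch::async,
            [req]
            {
                std::this_thread::sleep_for(std::chrono::milliseconds(50));
                (void) req->mRequestId; // safe: strong reference held
            });
        mFutures.emplace_back(std::move(req), std::move(fut));
    }

    void checkTransferStatus()
    {
        for (auto it = mFutures.begin(); it != mFutures.end();)
        {
            // Dereferencing is safe only because the entry co-owns the
            // Request; with a raw pointer this is exactly the UAF above.
            (void) it->first->mRequestId;
            if (it->second.wait_for(std::chrono::seconds(0)) == std::future_status::ready)
            {
                it = mFutures.erase(it); // last strong reference may die here
            }
            else
            {
                ++it;
            }
        }
    }

private:
    std::vector<std::pair<std::shared_ptr<Request>, std::future<void>>> mFutures;
};

int main()
{
    Registry registry;
    {
        auto req = std::make_shared<Request>(Request{42});
        registry.sendAsync(req);
    } // the caller's reference is dropped here, mid-transfer
    std::this_thread::sleep_for(std::chrono::milliseconds(100));
    registry.checkTransferStatus(); // entry dereferenced, then erased
}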
9 changes: 9 additions & 0 deletions cpp/include/tensorrt_llm/executor/transferAgent.h
@@ -288,6 +288,15 @@ class TransferStatus
virtual ~TransferStatus() = default;
[[nodiscard]] virtual bool isCompleted() const = 0;
virtual TransferState wait(int64_t timeout_ms = -1) const = 0;

/// Release the backend transfer request. If the request is still active,
/// backends may attempt to cancel it. A true return only means the backend
/// accepted release of the transfer handle; callers must still treat remote
/// memory quiescence as backend-specific.
[[nodiscard]] virtual bool release()
{
return false;
}
};

struct BaseAgentConfig
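A sketch of how a caller might honor this contract. TransferStatus here is a trimmed stand-in for the class above, and the TransferState names and the StuckTransfer backend are assumptions for illustration; the point is the fail-closed policy: a true return from release() drops the handle but does not prove the peer has stopped writing the advertised memory.

#include <cstdint>
#include <cstdio>

enum class TransferState { kSUCCESS, kFAILURE, kTIMEOUT }; // assumed states

struct TransferStatus // trimmed stand-in for the class above
{
    virtual ~TransferStatus() = default;
    virtual TransferState wait(std::int64_t timeout_ms = -1) const = 0;
    [[nodiscard]] virtual bool release() { return false; }
};

void finishTransfer(TransferStatus& status)
{
    if (status.wait(/*timeout_ms=*/5000) == TransferState::kSUCCESS)
    {
        return; // provably quiescent: buffers may be recycled
    }
    // Timed out or failed: ask the backend to drop the transfer handle.
    bool const released = status.release();
    // Even released == true does not prove remote quiescence, so fail
    // closed: do not return the advertised buffers to the pool (compare
    // the poisoning logic in baseTransBuffer.cpp below).
    std::printf("transfer not quiescent (released=%d); failing closed\n", static_cast<int>(released));
}

struct StuckTransfer : TransferStatus // hypothetical backend that never finishes
{
    TransferState wait(std::int64_t) const override { return TransferState::kTIMEOUT; }
    bool release() override { return true; } // handle dropped; quiescence still unknown
};

int main()
{
    StuckTransfer t;
    finishTransfer(t);
}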
189 changes: 181 additions & 8 deletions cpp/tensorrt_llm/batch_manager/baseTransBuffer.cpp
@@ -21,11 +21,83 @@
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/common/opUtils.h"

#include <exception>
#include <mutex>

namespace tensorrt_llm::batch_manager
{

namespace
{

char const* bufferKindName(BufferKind kind)
{
switch (kind)
{
case BufferKind::kKV: return "kv";
case BufferKind::kKV_INDEXER: return "kv_indexer";
case BufferKind::kRNN: return "rnn";
}
return "unknown";
}

} // namespace

void BufferIndexHolder::release() noexcept
{
// Happy-path release: frees the slot and disarms the holder in one
// noexcept call. Used in place of an older detach() + explicit
// freeBufferIndex*() sequence so a throw between the two calls cannot
// leave the holder in a partially-released state.
if (!mHeld || mMgr == nullptr)
{
return;
}
try
{
if (mIsRecv)
{
mMgr->freeBufferIndexForRecv(mIndex);
}
else
{
mMgr->freeBufferIndexForSend(mIndex);
}
}
catch (...)
{
// Swallow the exception: release() is invoked from the (noexcept)
// destructor, and any exit path that did not release explicitly
// relies on this fallback to free the slot.
}
mHeld = false;
}

void BufferIndexHolder::poison() noexcept
{
if (!mHeld || mMgr == nullptr)
{
return;
}
try
{
if (mIsRecv)
{
mMgr->poisonBufferIndexForRecv(mIndex);
}
else
{
mMgr->poisonBufferIndexForSend(mIndex);
}
}
catch (...)
{
// poisonBufferIndex is noexcept; keep this as belt-and-suspenders so
// fail-closed cleanup cannot throw from an exception path.
}
mHeld = false;
}
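The intended usage pattern for the holder, as a self-contained sketch. SlotPool, SlotHolder, and transferOnce are hypothetical stand-ins for BaseTransBufferManager, BufferIndexHolder, and a transfer worker (the real constructor signatures are not shown in this diff); the 0/1/2 flag states mirror mBufferIndexFlag in freeBufferIndex/poisonBufferIndex below, where a poisoned slot is never returned to the pool.

#include <optional>
#include <vector>

// Minimal stand-in pool: 0 = free, 1 = in use, 2 = poisoned (never reused).
struct SlotPool
{
    std::vector<int> flags = std::vector<int>(4, 0);

    std::optional<int> assign()
    {
        for (int i = 0; i < static_cast<int>(flags.size()); ++i)
        {
            if (flags[i] == 0)
            {
                flags[i] = 1;
                return i;
            }
        }
        return std::nullopt;
    }

    void free(int i) noexcept { if (flags[i] == 1) flags[i] = 0; }
    void poison(int i) noexcept { flags[i] = 2; }
};

// RAII holder mirroring BufferIndexHolder's release()/poison() split.
class SlotHolder
{
public:
    SlotHolder(SlotPool& pool, int index) : mPool{&pool}, mIndex{index} {}
    ~SlotHolder() { release(); } // fallback for exits that forgot to release

    void release() noexcept // happy path: slot returns to the pool
    {
        if (mHeld) { mPool->free(mIndex); mHeld = false; }
    }

    void poison() noexcept // quiescence unknown: slot is retired for good
    {
        if (mHeld) { mPool->poison(mIndex); mHeld = false; }
    }

private:
    SlotPool* mPool;
    int mIndex;
    bool mHeld{true};
};

bool transferOnce(SlotPool& pool, bool simulateStuckPeer)
{
    auto const idx = pool.assign();
    if (!idx) { return false; }
    SlotHolder holder{pool, *idx};
    if (simulateStuckPeer)
    {
        holder.poison(); // peer may still write this slot: fail closed
        return false;
    }
    holder.release();
    return true;
}

int main()
{
    SlotPool pool;
    transferOnce(pool, /*simulateStuckPeer=*/false); // slot recycled
    transferOnce(pool, /*simulateStuckPeer=*/true); // slot poisoned forever
}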

BaseTransBufferManager::BaseTransBufferManager(
size_t transferBufferSize, nvinfer1::DataType dataType, std::optional<size_t> maxNumTokens)
: mDataType{dataType}
@@ -54,26 +126,40 @@ BaseTransBufferManager::BaseTransBufferManager(
allocateBuffer();
}

std::optional<int> BaseTransBufferManager::assignBufferIndexForSend()
std::optional<int> BaseTransBufferManager::assignBufferIndexForSend(
std::atomic<bool> const* perRequestCancel, int64_t waitSliceMs, std::optional<uint64_t> requestIdForLog)
{
return assignBufferIndex(mConcurrenceSendResource, mSendBufferCount, mOnlyUseDynamicBuffer);
return assignBufferIndex(mConcurrenceSendResource, mSendBufferCount, mOnlyUseDynamicBuffer, perRequestCancel,
waitSliceMs, requestIdForLog);
}

void BaseTransBufferManager::freeBufferIndexForSend(std::optional<int> bufferId)
{
freeBufferIndex(mConcurrenceSendResource, bufferId, mSendBufferCount, mOnlyUseDynamicBuffer);
}

std::optional<int> BaseTransBufferManager::assignBufferIndexForRecv()
void BaseTransBufferManager::poisonBufferIndexForSend(std::optional<int> bufferId) noexcept
{
poisonBufferIndex(mConcurrenceSendResource, bufferId, mSendBufferCount, mOnlyUseDynamicBuffer, "send");
}

std::optional<int> BaseTransBufferManager::assignBufferIndexForRecv(
std::atomic<bool> const* perRequestCancel, int64_t waitSliceMs, std::optional<uint64_t> requestIdForLog)
{
return assignBufferIndex(mConcurrenceRecvResource, mRecvBufferCount, mOnlyUseDynamicBuffer);
return assignBufferIndex(mConcurrenceRecvResource, mRecvBufferCount, mOnlyUseDynamicBuffer, perRequestCancel,
waitSliceMs, requestIdForLog);
}

void BaseTransBufferManager::freeBufferIndexForRecv(std::optional<int> bufferId)
{
freeBufferIndex(mConcurrenceRecvResource, bufferId, mRecvBufferCount, mOnlyUseDynamicBuffer);
}

void BaseTransBufferManager::poisonBufferIndexForRecv(std::optional<int> bufferId) noexcept
{
poisonBufferIndex(mConcurrenceRecvResource, bufferId, mRecvBufferCount, mOnlyUseDynamicBuffer, "recv");
}

std::tuple<std::vector<runtime::ITensor::SharedPtr>, size_t, bool> BaseTransBufferManager::getOrAllocateSendBuffers(
std::optional<int> bufferId, int targetNum, std::vector<size_t> const& requestedNumberOfElements,
runtime::BufferManager const& bufferManagerToUse)
@@ -225,16 +311,46 @@ void BaseTransBufferManager::allocateBuffer()
}
}

std::optional<int> BaseTransBufferManager::assignBufferIndex(
ConcurrenceResource& resource, size_t bufferCount, bool onlyUseDynamicBuffer)
std::optional<int> BaseTransBufferManager::assignBufferIndex(ConcurrenceResource& resource, size_t bufferCount,
bool onlyUseDynamicBuffer, std::atomic<bool> const* perRequestCancel, int64_t waitSliceMs,
std::optional<uint64_t> requestIdForLog)
{
if (onlyUseDynamicBuffer)
{
TLLM_CHECK_WITH_INFO(!resource.mPoisoned.load(std::memory_order_relaxed),
"Cannot assign dynamic cache transfer buffer kind=%s because a previous transfer left dynamic transfer "
"memory poisoned. The process must restart before these memory ranges can be safely reused.",
bufferKindName(getBufferKind()));
return std::nullopt;
}
// Wait in bounded slices rather than blocking indefinitely: the
// per-request cancel atomic is not signalled through this condition
// variable, so re-checking it between slices lets a cancel fired while
// this request is parked here interrupt the wait, and keeps the drain
// worker responsive to shutdown (mTerminate).
std::unique_lock lk(resource.mBuffersMutex);
resource.mBuffersCV.wait(
lk, [&resource, bufferCount]() { return static_cast<size_t>(resource.mConcurrence) < bufferCount; });
auto const predicate = [&resource, bufferCount]()
{
return resource.mPoisoned.load(std::memory_order_relaxed)
|| static_cast<size_t>(resource.mConcurrence) < bufferCount;
};
auto const slice = std::chrono::milliseconds{waitSliceMs};
while (!predicate())
{
resource.mBuffersCV.wait_for(lk, slice);
if (perRequestCancel != nullptr && perRequestCancel->load(std::memory_order_relaxed))
{
auto const reqIdStr
= requestIdForLog.has_value() ? std::to_string(requestIdForLog.value()) : std::string{"?"};
TLLM_THROW("assignBufferIndex cancelled via perRequestCancel (reqId=%s)", reqIdStr.c_str());
}
}
TLLM_CHECK_WITH_INFO(!resource.mPoisoned.load(std::memory_order_relaxed),
"Cannot assign cache transfer buffer kind=%s because a previous transfer left the buffer pool poisoned. "
"The process must restart before these memory ranges can be safely reused.",
bufferKindName(getBufferKind()));
int bufferId = -1;
for (size_t i = 0; i < bufferCount; i++)
{
@@ -264,13 +380,70 @@ void BaseTransBufferManager::freeBufferIndex(
TLLM_CHECK(static_cast<size_t>(bufferId.value()) < bufferCount);
{
std::scoped_lock lk(resource.mBuffersMutex);
if (resource.mBufferIndexFlag[bufferId.value()] == 2)
{
TLLM_LOG_ERROR("Refusing to free poisoned cache transfer buffer kind=%s index=%d",
bufferKindName(getBufferKind()), bufferId.value());
return;
}
resource.mBufferIndexFlag[bufferId.value()] = 0;
}
resource.mConcurrence--;
resource.mBuffersCV.notify_one();
}
}

void BaseTransBufferManager::poisonBufferIndex(ConcurrenceResource& resource, std::optional<int> bufferId,
size_t bufferCount, bool onlyUseDynamicBuffer, char const* direction) noexcept
{
resource.mPoisoned.store(true, std::memory_order_relaxed);

if (onlyUseDynamicBuffer)
{
TLLM_LOG_ERROR(
"Poisoned dynamic %s cache transfer buffer kind=%s. Dynamic transfer memory cannot be safely reused; "
"the process must restart.",
direction, bufferKindName(getBufferKind()));
resource.mBuffersCV.notify_all();
return;
}

if (!bufferId.has_value())
{
TLLM_LOG_ERROR("Poisoned unknown %s cache transfer buffer kind=%s. The process must restart.", direction,
bufferKindName(getBufferKind()));
resource.mBuffersCV.notify_all();
return;
}

try
{
TLLM_CHECK(static_cast<size_t>(bufferId.value()) < bufferCount);
{
std::scoped_lock lk(resource.mBuffersMutex);
if (resource.mBufferIndexFlag[bufferId.value()] == 1)
{
resource.mBufferIndexFlag[bufferId.value()] = 2;
}
}
TLLM_LOG_ERROR(
"Poisoned %s cache transfer buffer kind=%s index=%d. The slot will not be returned to the pool because "
"transport quiescence is unknown; restart the process before serving more KV transfers.",
direction, bufferKindName(getBufferKind()), bufferId.value());
}
catch (std::exception const& e)
{
TLLM_LOG_ERROR("Exception while poisoning %s cache transfer buffer kind=%s index=%d: %s", direction,
bufferKindName(getBufferKind()), bufferId.value_or(-1), e.what());
}
catch (...)
{
TLLM_LOG_ERROR("Unknown exception while poisoning %s cache transfer buffer kind=%s index=%d", direction,
bufferKindName(getBufferKind()), bufferId.value_or(-1));
}
resource.mBuffersCV.notify_all();
}

size_t BaseTransBufferManager::getRecvBufferCount()
{
return mRecvBufferCount;
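The bounded-slice wait used by assignBufferIndex, reduced to a standalone pattern. The names here (waitSliced, ready, cancel) are illustrative, not from the codebase; the point is that chopping a condition_variable wait into slices bounds the latency for observing a flag that is never signalled through the CV, such as the per-request cancel atomic.

#include <atomic>
#include <chrono>
#include <condition_variable>
#include <mutex>
#include <stdexcept>
#include <thread>

// Waits for `ready` under `mtx`/`cv`, but re-checks `cancel` every slice so
// a cancellation raised without notifying the CV cannot park us forever.
void waitSliced(std::mutex& mtx, std::condition_variable& cv, bool const& ready,
    std::atomic<bool> const& cancel, std::chrono::milliseconds slice)
{
    std::unique_lock lk(mtx);
    while (!ready)
    {
        cv.wait_for(lk, slice); // wakes on notify OR after one slice
        if (cancel.load(std::memory_order_relaxed))
        {
            throw std::runtime_error("wait cancelled");
        }
    }
}

int main()
{
    std::mutex mtx;
    std::condition_variable cv;
    bool ready = false;
    std::atomic<bool> cancel{false};

    std::thread producer(
        [&]
        {
            std::this_thread::sleep_for(std::chrono::milliseconds(30));
            {
                std::lock_guard lk(mtx); // writers of `ready` must hold mtx
                ready = true;
            }
            cv.notify_all();
        });
    waitSliced(mtx, cv, ready, cancel, std::chrono::milliseconds(10));
    producer.join();
}

A plain cv.wait(lk, pred) would park forever on cancellation, since nothing notifies the CV when the cancel atomic flips; the slice bounds that blind spot to one waitSliceMs.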