Merged
45 commits
efa1996
implement variable window attention by breaking the block manager int…
netanel-haber Mar 24, 2025
bfdcb5a
revert isCyclic to be true if the min attention window is reached, no…
netanel-haber Mar 25, 2025
8b87594
add explanatory comment to mCyclicThreshold
netanel-haber Mar 25, 2025
8b8a5cb
load correct gemma config
netanel-haber Mar 25, 2025
1b68b16
don't shadow inputLength in addSequence - it should remain the functi…
netanel-haber Mar 25, 2025
8c7e8e4
fix KVCacheManagerVariableWindowAttentionWithReuseTest for multiple w…
netanel-haber Mar 25, 2025
ebd93c6
if TYPE_CHECKING
netanel-haber Mar 26, 2025
d79f5ad
Merge branch 'main' into feature/allocate_minimal_blocks_per_window_size
netanel-haber Mar 26, 2025
34481ed
set temp_attention_window_inputs to None explicitly
netanel-haber Mar 26, 2025
8001958
set temp_attention_window_inputs to None explicitly
netanel-haber Mar 26, 2025
c9de9c1
pass dtype as well
netanel-haber Mar 26, 2025
5ca0da1
Merge branch 'main' into feature/allocate_minimal_blocks_per_window_size
netanel-haber Mar 31, 2025
9001a08
test_gemma variable sliding window attention
netanel-haber Mar 31, 2025
e70b719
Merge branch 'main' into feature/allocate_minimal_blocks_per_window_size
netanel-haber Apr 1, 2025
a9393b6
allot a fraction of primary/secondaryBlocks to different window size …
netanel-haber Apr 2, 2025
72e8bd7
Merge branch 'main' into feature/allocate_minimal_blocks_per_window_size
netanel-haber Apr 2, 2025
a5c527f
remove || mEnableBlockReuse which erroneously triggers beamsearch cod…
netanel-haber Apr 8, 2025
71d38be
turn off request delaying for MaxUtil
netanel-haber Apr 8, 2025
a5ddf64
make comments better
netanel-haber Apr 8, 2025
cdd6ffa
Merge branch 'main' into feature/allocate_minimal_blocks_per_window_size
netanel-haber Apr 8, 2025
05f244e
windowSizesTotalSum using std::accumulate
netanel-haber Apr 8, 2025
957e8c9
Merge branch 'main' into feature/allocate_minimal_blocks_per_window_size
netanel-haber Apr 9, 2025
d340977
fix error handling of forwardAsync - forwardAsync catch-all catch cle…
netanel-haber Apr 9, 2025
6497c83
fix comments
netanel-haber Apr 9, 2025
1221a36
Merge branch 'main' into feature/allocate_minimal_blocks_per_window_size
netanel-haber Apr 9, 2025
0fffa0b
remove assert that kills disagg tests, since it isn't necessary
netanel-haber Apr 10, 2025
987573d
Merge branch 'main' into feature/allocate_minimal_blocks_per_window_size
netanel-haber Apr 10, 2025
e547f89
Merge branch 'main' into feature/allocate_minimal_blocks_per_window_size
netanel-haber Apr 10, 2025
260bed6
Merge branch 'main' into feature/allocate_minimal_blocks_per_window_size
netanel-haber Apr 10, 2025
4be6dd9
fix corrupted expression: 'isNewTask && (peftCacheManager ?' -> '(isN…
netanel-haber Apr 10, 2025
49b3d27
add Gemma3 to SUPPORTED_HF_ARCHITECTURES
netanel-haber Apr 13, 2025
4847fc5
Merge branch 'main' into feature/allocate_minimal_blocks_per_window_size
netanel-haber Apr 13, 2025
88ddfec
support Gemma3
netanel-haber Apr 14, 2025
70a0ef5
Merge commit '2fb1d65d4301a5fd28e70ae00bfb6d5af01afa4b' into feature/…
netanel-haber Apr 14, 2025
e532522
finally fix test_gemma - always spread at least {} into generate_summ…
netanel-haber Apr 14, 2025
522d4b0
finally fix test_gemma - always spread at least {} into generate_summ…
netanel-haber Apr 14, 2025
b207b7e
fix kvfactor field for deepseek
netanel-haber Apr 14, 2025
3c15112
fix comment
netanel-haber Apr 14, 2025
65be79e
fix gemma-3 entries in testlist to include vswa
netanel-haber Apr 14, 2025
bb4095e
only quantize gemma2 VSWA
netanel-haber Apr 14, 2025
dfddbe5
fix test_gemma
netanel-haber Apr 14, 2025
cf86e35
fix test_gemma
netanel-haber Apr 14, 2025
4947552
in sendRequestInfo, fromOldAllocatedBlockIds->fromOldAllocatedBlockId…
netanel-haber Apr 16, 2025
9a60e91
fix: disable KV cache reuse if using attention sink (#3021)
Funatiq Apr 15, 2025
9f6be5d
Merge branch 'main' into feature/allocate_minimal_blocks_per_window_size
Funatiq Apr 17, 2025
6 changes: 0 additions & 6 deletions cpp/include/tensorrt_llm/batch_manager/capacityScheduler.h
@@ -97,12 +97,6 @@ class MaxUtilizationScheduler : public BaseCapacityScheduler
RequestList const& activeRequests) const;

private:
/// @return {fitsKvCache, fitsPeft}
std::pair<bool, bool> trySchedulingRequestMaxUtilization(kv_cache_manager::BaseKVCacheManager const& kvCacheManager,
OptionalRef<BasePeftCacheManager const> peftCacheManager, std::shared_ptr<LlmRequest> const& req,
RequestVector& scheduledRequests, SizeType32& numScheduledBlocks, SizeType32& numScheduledPeftPages,
std::unordered_set<uint64_t>& seenTaskIds) const;

SizeType32 mMaxNumRequests;
/// @brief Boolean that indicates if multiple micro batches might be in flight
bool mManyMicroBatches;
773 changes: 600 additions & 173 deletions cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h

Large diffs are not rendered by default.
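The kvCacheManager.h diff is collapsed here, but the commit messages above ("allot a fraction of primary/secondaryBlocks to different window size …", "windowSizesTotalSum using std::accumulate") indicate that the block manager now splits its block budget across attention window sizes instead of keeping one undifferentiated pool. A minimal sketch of such a proportional split, assuming a simple share-by-window-size heuristic; the helper name and the weighting below are illustrative, not the PR's actual code:

#include <cstdint>
#include <map>
#include <numeric>
#include <vector>

using SizeType32 = std::int32_t;

// Hypothetical helper (not from the PR): divide a total primary-block budget among
// the distinct attention window sizes, proportionally to each window size. The real
// BlockManager presumably also accounts for layer counts and rounding; this only
// illustrates the accumulate-then-apportion idea named in the commit messages.
std::map<SizeType32, SizeType32> allotBlocksPerWindowSize(
    std::vector<SizeType32> const& windowSizes, SizeType32 totalPrimaryBlocks)
{
    std::map<SizeType32, SizeType32> blocksPerWindow;
    SizeType32 const windowSizesTotalSum
        = std::accumulate(windowSizes.begin(), windowSizes.end(), SizeType32{0});
    if (windowSizesTotalSum == 0)
    {
        return blocksPerWindow; // no window sizes registered
    }
    for (auto const windowSize : windowSizes)
    {
        // Each window size receives a share of the budget proportional to its size.
        blocksPerWindow[windowSize] = static_cast<SizeType32>(
            std::int64_t{totalPrimaryBlocks} * windowSize / windowSizesTotalSum);
    }
    return blocksPerWindow;
}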

56 changes: 43 additions & 13 deletions cpp/include/tensorrt_llm/batch_manager/kvCacheUtils.h
@@ -31,24 +31,28 @@ class BlockRange
{
};

BlockRange(BaseKVCacheManager const& cacheManager, LlmRequest::RequestIdType requestId, SizeType32 beam,
SizeType32 poolIdx = 0)
: mManager(&cacheManager)
, mPool(cacheManager.getBlockManager().getPrimaryPool(poolIdx))
, mBlockIds(cacheManager.getSequence(requestId).getCacheBlockIds().at(beam))
static BlockRange fromOldAllocatedBlockIds(BaseKVCacheManager const& cacheManager,
LlmRequest::RequestIdType requestId, SizeType32 beam = kFIRST_AND_ONLY_BEAM)
{
assert(kFIRST_AND_ONLY_BEAM == beam);
auto const windowSize = firstWindowSize(cacheManager);
auto const blockIds = cacheManager.getSequence(requestId).getCacheBlockIds(windowSize).at(kFIRST_AND_ONLY_BEAM);
return BlockRange(cacheManager, blockIds, requestId);
}

BlockRange(BaseKVCacheManager const& cacheManager, std::vector<SizeType32> blockIds, SizeType32 poolIdx = 0)
: mManager(&cacheManager)
, mPool(cacheManager.getBlockManager().getPrimaryPool(poolIdx))
, mBlockIds(std::move(blockIds))
static BlockRange fromNewlyAllocatedBlockIds(
BaseKVCacheManager const& cacheManager, LlmRequest::RequestIdType requestId)
{
auto const windowSize = firstWindowSize(cacheManager);
auto const blockIds = cacheManager.getNewlyAllocatedBlockIds(requestId, windowSize);
return BlockRange(cacheManager, blockIds, requestId);
}

BlockRange(runtime::ITensor::SharedPtr pool, std::vector<SizeType32> const& blockIds)
BlockRange(runtime::ITensor::SharedPtr pool, std::vector<SizeType32> const& blockIds) // Only used in tests
: mManager{nullptr}
, mPool{std::move(pool)}
, mWindowSize{0}
, mRequestId{0}
, mBlockIds{blockIds}
{
TLLM_CHECK(mPool);
@@ -84,25 +88,51 @@ class BlockRange
auto& blockManager = mManager->getBlockManager();
for (auto id : mBlockIds)
{
blockHashes.emplace_back(blockManager.getBlockById(id)->getHash());
blockHashes.emplace_back(blockManager.getBlockById(id, mWindowSize)->getHash());
}
return blockHashes;
}

void updatePoolIdx(SizeType32 poolIdx)
{
if (mManager)
TLLM_CHECK(mManager);
mPool = mManager->getBlockManager().getPrimaryPool(poolIdx);
auto const newWindowSize = mManager->getBlockManager().getPoolWindowSize(poolIdx);
if (newWindowSize != mWindowSize)
{
mPool = mManager->getBlockManager().getPrimaryPool(poolIdx);
mWindowSize = newWindowSize;
mBlockIds = mManager->getSequence(mRequestId).getCacheBlockIds(mWindowSize).at(kFIRST_AND_ONLY_BEAM);
}
}

friend class BlockIterator;

private:
BlockRange(
BaseKVCacheManager const& cacheManager, std::vector<SizeType32> blockIds, LlmRequest::RequestIdType requestId)
: mManager(&cacheManager)
, mPool(cacheManager.getBlockManager().getPrimaryPool(kFIRST_POOL_INDEX))
, mWindowSize(firstWindowSize(cacheManager))
, mRequestId(requestId)
, mBlockIds(std::move(blockIds))
{
}

static SizeType32 firstWindowSize(BaseKVCacheManager const& cacheManager)
{
constexpr SizeType32 FIRST_POOL_IDX = 0;
return cacheManager.getBlockManager().getPoolWindowSize(FIRST_POOL_IDX);
}

private:
BaseKVCacheManager const* mManager;
runtime::ITensor::SharedPtr mPool;
SizeType32 mWindowSize;
const LlmRequest::RequestIdType mRequestId;
std::vector<SizeType32> mBlockIds;

static constexpr SizeType32 kFIRST_AND_ONLY_BEAM = 0;
static constexpr SizeType32 kFIRST_POOL_INDEX = 0;
};

class BlockIterator
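In effect, the two public BlockRange constructors are replaced by named factories, and the range now remembers which window size (and therefore which set of block IDs) it is iterating. A rough usage sketch of the new interface, assuming the pool-by-pool iteration pattern used by the cache formatter below; the wrapper function and namespace spellings are illustrative:

// Sketch only: assumes the tensorrt_llm batch_manager headers changed in this PR
// (kvCacheUtils.h / kvCacheManager.h) are included and that namespaces are spelled
// as below.
namespace kvcm = tensorrt_llm::batch_manager::kv_cache_manager;
using tensorrt_llm::runtime::SizeType32;

void hashAllPools(kvcm::BaseKVCacheManager const& cacheManager,
    tensorrt_llm::batch_manager::LlmRequest::RequestIdType requestId)
{
    // Range over the blocks already allocated to the request; the receiving side
    // would use fromNewlyAllocatedBlockIds instead (see cacheFormatter.cpp below).
    auto blockRange = kvcm::BlockRange::fromOldAllocatedBlockIds(cacheManager, requestId);

    auto const numPools = cacheManager.getBlockManager().getNumPools();
    for (SizeType32 poolIdx = 0; poolIdx < numPools; ++poolIdx)
    {
        // Re-points the range at poolIdx's primary pool; if that pool's window size
        // differs from the current one, the block IDs are re-fetched for it.
        blockRange.updatePoolIdx(poolIdx);
        auto const blockHashes = blockRange.getBlockHashes(); // one hash per block in this pool
        (void) blockHashes;
    }
}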
7 changes: 3 additions & 4 deletions cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp
@@ -49,7 +49,7 @@ void CacheFormatter::formatOutput(LlmRequest const& llmRequest,
constexpr SizeType32 beam{0};
auto& blockManager = mCacheManager->getBlockManager();
size_t requestBlockNum = llmRequest.getRequestedBlockHashes().size();
auto blockRange = BlockRange(*mCacheManager, llmRequest.mRequestId, beam);
auto blockRange = BlockRange::fromOldAllocatedBlockIds(*mCacheManager, llmRequest.mRequestId, beam);
if (requestBlockNum < blockRange.size() && requestBlockNum > 0)
{
// handle block reuse, the prefix blocks are reused
@@ -109,7 +109,7 @@ void CacheFormatter::formatOutput(LlmRequest const& llmRequest,
}
TLLM_CHECK(!inputKvCacheBlocks.empty());
TLLM_CHECK(blockNum > 0);
int deviceId = mCacheManager->getBlockManager().getBufferManager().getStream().getDevice();
int deviceId = mCacheManager->getBlockManager().getStreamDevice();

if (common::getEnvTryZCopyForKVCacheTransfer()
&& (destConfig.getParallelConfig().mPipelineParallelism
@@ -318,8 +318,7 @@ void CacheFormatter::formatInput(LlmRequest const& llmRequest,
"Start receiving KV cache for request ID: %ld, context request ID: %ld.", llmRequest.mRequestId,
llmRequest.getContextPhaseParams().value().getReqId());
TLLM_CHECK(!connections.empty());
auto blockRange = BlockRange(*mCacheManager, mCacheManager->getNewlyAllocatedBlockIds(llmRequest.mRequestId));

auto blockRange = BlockRange::fromNewlyAllocatedBlockIds(*mCacheManager, llmRequest.mRequestId);
std::vector<runtime::ITensor::SharedPtr> recvBufferTmps;
std::vector<runtime::ITensor::SharedPtr> outputBuffers;
auto const numPools = mCacheManager->getBlockManager().getNumPools();
111 changes: 60 additions & 51 deletions cpp/tensorrt_llm/batch_manager/capacityScheduler.cpp
@@ -199,25 +199,32 @@ std::tuple<RequestVector, RequestVector> GuaranteedNoEvictScheduler::impl(
RequestVector scheduledRequests;

// Now check if we can add pending requests
auto const numFreeBlocks = kvCacheManager.getNumFreeBlocks();
auto const numFreeCrossBlocks = crossKvCacheManager ? crossKvCacheManager->getNumFreeBlocks() : 0;
auto const maxPeftCachePages
= peftCacheManager ? peftCacheManager->getMaxDevicePages() : std::numeric_limits<SizeType32>::max();

// The optimization of delaying requests won't work for variable window attention
bool skippingIsRelevant = (!kvCacheManager.getBlockManager().isVariableWindow())
&& (!crossKvCacheManager || !crossKvCacheManager->getBlockManager().isVariableWindow());

// Keep track of blocks contributed by requests in context phase
std::unordered_set<BlockKey, BlockKeyHasher> newlyContributedContextBlocks;
std::unordered_set<BlockKey, BlockKeyHasher> newlyContributedCrossContextBlocks;
if constexpr (!StaticBatchScheduling)
{
std::tie(newlyContributedContextBlocks, newlyContributedCrossContextBlocks)
= prefillWithChunkedContextsAlreadyExecuting(activeRequests, kvCacheManager, crossKvCacheManager);
if (skippingIsRelevant)
{
std::tie(newlyContributedContextBlocks, newlyContributedCrossContextBlocks)
= prefillWithChunkedContextsAlreadyExecuting(activeRequests, kvCacheManager, crossKvCacheManager);
}
}

// If a request is already in progress, include it
// If it's been allocated, it had resource to run to completion
// Also keep track of blocks needed to drive all in-progress requests to completion
SizeType32 reservedBlocks{0};
SizeType32 reservedCrossBlocks{0};
auto reservedBlocks = kv_cache_manager::NoEvictScheduledBlocksManager(kvCacheManager);
auto reservedCrossBlocks = crossKvCacheManager
? std::optional(kv_cache_manager::NoEvictScheduledBlocksManager(*crossKvCacheManager))
: std::nullopt;
SizeType32 claimedPeftPages{0};
std::unordered_set<uint64_t> uniqTaskIds{};
RequestVector pendingRequests;
@@ -242,16 +249,16 @@ std::tuple<RequestVector, RequestVector> GuaranteedNoEvictScheduler::impl(
else if (req->isGenerationInProgressState())
{
scheduledRequests.emplace_back(req);
reservedBlocks += kvCacheManager.getRemainingBlocksToCompletion(*req);

reservedBlocks.decrementReservedBlocks(*req);
if (reservedCrossBlocks)
reservedCrossBlocks->decrementReservedBlocks(*req);
bool const reqHasLora = req->getLoraTaskId().has_value();
bool const isNewTask = reqHasLora && !uniqTaskIds.count(req->getLoraTaskId().value());
if (isNewTask)
{
claimedPeftPages += peftCacheManager ? peftCacheManager->determineNumPages(req) : 0;
uniqTaskIds.insert(req->getLoraTaskId().value());
}
reservedCrossBlocks += crossKvCacheManager ? crossKvCacheManager->getRemainingBlocksToCompletion(*req) : 0;
}
else if (req->isDisaggGenerationInitState())
{
@@ -268,8 +275,6 @@ std::tuple<RequestVector, RequestVector> GuaranteedNoEvictScheduler::impl(
if (!StaticBatchScheduling || scheduledRequests.size() == 0)
{
// Now check if we can add pending requests
auto availableBlocks = numFreeBlocks - reservedBlocks;
auto availableCrossBlocks = numFreeCrossBlocks - reservedCrossBlocks;
auto availablePeftPages = maxPeftCachePages - claimedPeftPages;

// Loop over pending requests and add them if they can be scheduled
@@ -279,7 +284,7 @@ std::tuple<RequestVector, RequestVector> GuaranteedNoEvictScheduler::impl(
for (auto const& req : requests)
{
// if context request can reuse blocks contributed by another context request, skip
if (!StaticBatchScheduling && !req->isDisaggGenerationInitState()
if (!StaticBatchScheduling && skippingIsRelevant && !req->isDisaggGenerationInitState()
&& beneficialToSkip(req, kvCacheManager, crossKvCacheManager, newlyContributedContextBlocks,
newlyContributedCrossContextBlocks))
{
@@ -292,27 +297,26 @@ std::tuple<RequestVector, RequestVector> GuaranteedNoEvictScheduler::impl(
}
else if (req->isContextInitState() || req->isDisaggGenerationInitState())
{
auto const neededBlocks = kvCacheManager.getRemainingBlocksToCompletion(*req);
auto const neededCrossBlocks
= crossKvCacheManager ? crossKvCacheManager->getRemainingBlocksToCompletion(*req) : 0;
bool const reqHasLora = req->getLoraTaskId().has_value();
bool const isNewTask = reqHasLora && !uniqTaskIds.count(req->getLoraTaskId().value());
auto const neededPeftPages
= (isNewTask && peftCacheManager) ? peftCacheManager->determineNumPages(req) : 0;

if (neededBlocks <= availableBlocks && neededCrossBlocks <= availableCrossBlocks
&& neededPeftPages <= availablePeftPages)
bool enoughBlocks = reservedBlocks.enoughAvailableBlocks(*req);
bool enoughCrossBlocks
= reservedCrossBlocks ? reservedCrossBlocks->enoughAvailableBlocks(*req) : true;
bool reqHasLora = req->getLoraTaskId().has_value();
bool isNewTask = reqHasLora && !uniqTaskIds.count(req->getLoraTaskId().value());
auto neededPeftPages = isNewTask && peftCacheManager ? peftCacheManager->determineNumPages(req) : 0;

if (enoughBlocks && enoughCrossBlocks && neededPeftPages <= availablePeftPages)
{
scheduledRequests.emplace_back(req);
availableBlocks -= neededBlocks;
availableCrossBlocks -= neededCrossBlocks;
reservedBlocks.decrementReservedBlocks(*req);
if (reservedCrossBlocks)
reservedCrossBlocks->decrementReservedBlocks(*req);
availablePeftPages -= neededPeftPages;
if (isNewTask)
{
uniqTaskIds.insert(req->getLoraTaskId().value());
}
}
else if (neededBlocks > availableBlocks || neededCrossBlocks > availableCrossBlocks)
else if (!enoughBlocks || !enoughCrossBlocks)
{
// If one requests fails to be scheduled, break
break;
@@ -324,14 +328,25 @@ std::tuple<RequestVector, RequestVector> GuaranteedNoEvictScheduler::impl(
return {std::move(scheduledRequests), RequestVector{}};
}

// TODO(nhaber): remove forward declare and just keep the function here, right before the merge. I put it below just so
// the remote diff is easier to look at/rebase conflicts
bool trySchedulingRequestMaxUtilization(std::shared_ptr<LlmRequest> const& req, SizeType32 maxNumRequests,
RequestVector& scheduledRequests, kv_cache_manager::MaxUtilizationScheduledBlocksManager& blocksManager,
OptionalRef<BasePeftCacheManager const> peftCacheManager, SizeType32& numScheduledPeftPages,
std::unordered_set<uint64_t>& seenTaskIds);

std::tuple<RequestVector, RequestVector> MaxUtilizationScheduler::operator()(
kv_cache_manager::BaseKVCacheManager& kvCacheManager, OptionalRef<BasePeftCacheManager const> peftCacheManager,
RequestList const& activeRequests) const
{
kvCacheManager.startScheduling();

// The optimization of delaying requests won't work for variable window attention
bool skippingIsRelevant = !kvCacheManager.getBlockManager().isVariableWindow();

// Keep track of number of requests and block needed for the scheduled requests
SizeType32 numScheduledBlocks{0};
auto scheduledBlocksManager
= kv_cache_manager::MaxUtilizationScheduledBlocksManager(kvCacheManager, mManyMicroBatches);
SizeType32 numScheduledPeftPages{0};
std::unordered_set<uint64_t> seenTaskIds;

@@ -366,16 +381,17 @@ std::tuple<RequestVector, RequestVector> MaxUtilizationScheduler::operator()(
}

// if context request can reuse blocks contributed by another context request, skip
if (beneficialToSkip(
if (skippingIsRelevant
&& beneficialToSkip(
req, kvCacheManager, std::nullopt, newlyContributedContextBlocks, newlyContributedCrossContextBlocks))
{
reqIt++;
continue;
}

auto const [fitsKvCache, fitsPeftCache] = trySchedulingRequestMaxUtilization(kvCacheManager, peftCacheManager,
req, scheduledRequests, numScheduledBlocks, numScheduledPeftPages, seenTaskIds);
if (fitsKvCache && fitsPeftCache)
bool const wasScheduled = trySchedulingRequestMaxUtilization(req, mMaxNumRequests, scheduledRequests,
scheduledBlocksManager, peftCacheManager, numScheduledPeftPages, seenTaskIds);
if (wasScheduled)
{
TLLM_LOG_DEBUG("MaxUtilizationScheduler: request ID %lu -> start", req->mRequestId);
reqIt++;
@@ -405,45 +421,38 @@ std::tuple<RequestVector, RequestVector> MaxUtilizationScheduler::operator()(
return {std::move(scheduledRequests), std::move(pausedRequests)};
}

std::pair<bool, bool> MaxUtilizationScheduler::trySchedulingRequestMaxUtilization(
kv_cache_manager::BaseKVCacheManager const& kvCacheManager,
OptionalRef<BasePeftCacheManager const> peftCacheManager, std::shared_ptr<LlmRequest> const& req,
RequestVector& scheduledRequests, SizeType32& numScheduledBlocks, SizeType32& numScheduledPeftPages,
std::unordered_set<uint64_t>& seenTaskIds) const
bool trySchedulingRequestMaxUtilization(std::shared_ptr<LlmRequest> const& req, SizeType32 maxNumRequests,
RequestVector& scheduledRequests, kv_cache_manager::MaxUtilizationScheduledBlocksManager& blocksManager,
OptionalRef<BasePeftCacheManager const> peftCacheManager, SizeType32& numScheduledPeftPages,
std::unordered_set<uint64_t>& seenTaskIds)
{
if (scheduledRequests.size() < static_cast<std::size_t>(mMaxNumRequests))
if (scheduledRequests.size() < static_cast<std::size_t>(maxNumRequests))
{
SizeType32 numRequiredBlocks = kvCacheManager.getNeededBlocksOneStep(*req, mManyMicroBatches);
TLLM_LOG_DEBUG(
"MaxUtilizationScheduler: request ID %lu required blocks: %i", req->mRequestId, numRequiredBlocks);

bool const reqHasLora = req->getLoraTaskId().has_value();
bool const isNewTask = reqHasLora && !seenTaskIds.count(req->getLoraTaskId().value());
auto const numRequiredPeftPages
bool reqHasLora = req->getLoraTaskId().has_value();
bool isNewTask = reqHasLora && !seenTaskIds.count(req->getLoraTaskId().value());
SizeType32 numRequiredPeftPages
= (isNewTask && peftCacheManager) ? peftCacheManager->determineNumPages(req) : 0;
TLLM_LOG_DEBUG(
"MaxUtilizationScheduler: request ID %lu required peft pages: %i", req->mRequestId, numRequiredPeftPages);
bool const fitsKvCache
= kvCacheManager.getBlockManager().schedulingHasFreeBlocks(numScheduledBlocks + numRequiredBlocks);
bool const fitsPeft
auto const scheduledBlocksIfFitsKvCache = blocksManager.prepareNewNumberOfBlocksIfWeEndUpScheduling(*req);
bool fitsPeft
= (peftCacheManager ? numRequiredPeftPages + numScheduledPeftPages <= peftCacheManager->getMaxDevicePages()
: true);

if (fitsKvCache && fitsPeft)
if (scheduledBlocksIfFitsKvCache && fitsPeft)
{
numScheduledBlocks += numRequiredBlocks;
TLLM_LOG_DEBUG("MaxUtilizationScheduler: scheduled blocks: %i", numScheduledBlocks);
blocksManager.updateScheduledBlocks(scheduledBlocksIfFitsKvCache.value());
numScheduledPeftPages += numRequiredPeftPages;
TLLM_LOG_DEBUG("MaxUtilizationScheduler: scheduled peft pages: %i", numRequiredPeftPages);
scheduledRequests.emplace_back(req);
if (isNewTask)
{
seenTaskIds.insert(req->getLoraTaskId().value());
}
return true;
}
return std::make_pair(fitsKvCache, fitsPeft);
}
return std::make_pair(false, false);
return false;
}

CapacityScheduler::CapacityScheduler(SizeType32 maxNumRequests,
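The net effect of the scheduler changes: per-request block accounting is delegated to the new kv_cache_manager helpers (NoEvictScheduledBlocksManager and MaxUtilizationScheduledBlocksManager), which track needed blocks per window size, and the request-delaying optimization is skipped whenever the cache uses variable window attention. A condensed sketch of the MaxUtilization flow as it appears in the diff, assuming the surrounding declarations (kvCacheManager, pendingRequests, scheduledRequests, mManyMicroBatches) from the function above; PEFT/LoRA bookkeeping is omitted:

// Sketch: the manager owns the per-window-size budget; the scheduler only asks
// "would this request still fit?" and commits the answer if it does.
auto blocksManager
    = kv_cache_manager::MaxUtilizationScheduledBlocksManager(kvCacheManager, mManyMicroBatches);

for (auto const& req : pendingRequests)
{
    // Per the diff above, this returns the prospective scheduled-block counts
    // (per window size) if `req` were added, or std::nullopt if it does not fit.
    auto const scheduledBlocksIfFits = blocksManager.prepareNewNumberOfBlocksIfWeEndUpScheduling(*req);
    if (!scheduledBlocksIfFits)
    {
        break; // stop at the first request that no longer fits the KV cache
    }
    blocksManager.updateScheduledBlocks(scheduledBlocksIfFits.value());
    scheduledRequests.emplace_back(req);
}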
4 changes: 2 additions & 2 deletions cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.cpp
@@ -142,8 +142,8 @@ void DataReceiverImpl::sendRequestInfo(LlmRequest const& llmRequest)
if (cacheFormatter != nullptr)
{
auto* cacheManager = cacheFormatter->getCacheManager();
auto blockRange = kv_cache_manager::BlockRange(
*cacheManager, cacheManager->getNewlyAllocatedBlockIds(llmRequest.mRequestId));
auto blockRange
= kv_cache_manager::BlockRange::fromNewlyAllocatedBlockIds(*cacheManager, llmRequest.mRequestId);
requestInfo = RequestInfo(requestId, blockRange.getBlockHashes(), mSelfState);
}
