Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions cpp/tensorrt_llm/common/envUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,28 @@ std::string getEnvNixlInterface()
return nixlInterface;
}

std::string getEnvNixlBackend()
{
static std::once_flag flag;
static std::string nixlBackend;

std::call_once(flag,
[&]()
{
char const* nixl_backend = std::getenv("TRTLLM_NIXL_KVCACHE_BACKEND");
if (nixl_backend)
{
nixlBackend = nixl_backend;
}
else
{
// Default to UCX if not specified
nixlBackend = "UCX";
}
});
return nixlBackend;
}

bool getEnvDisaggLayerwise()
{
static bool const disaggLayerwise = getBoolEnv("TRTLLM_DISAGG_LAYERWISE");
Expand Down
2 changes: 2 additions & 0 deletions cpp/tensorrt_llm/common/envUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ std::string getEnvUCXInterface();

std::string getEnvNixlInterface();

std::string getEnvNixlBackend();

bool getEnvDisaggLayerwise();

bool getEnvParallelCacheSend();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <netdb.h>
#include <netinet/in.h>
#include <nixl_types.h>
#include <set>
#include <sys/file.h>
#include <sys/stat.h>
#include <unistd.h>
Expand Down Expand Up @@ -345,15 +346,27 @@ NixlTransferAgent::NixlTransferAgent(BaseAgentConfig const& config)
mRawAgent = std::make_unique<nixlAgent>(config.mName, std::move(nixlConfig));
}

std::string nixlBackend = common::getEnvNixlBackend();
// List of supported backends - extend this list as new backends are added
static const std::set<std::string> kSUPPORTED_BACKENDS = {"UCX"};

if (kSUPPORTED_BACKENDS.find(nixlBackend) == kSUPPORTED_BACKENDS.end())
{
TLLM_LOG_ERROR("Unsupported NIXL backend: %s, fallback to UCX", nixlBackend.c_str());
nixlBackend = "UCX";
}

TLLM_LOG_INFO("NixlTransferAgent::NixlTransferAgent using NIXL backend: %s", nixlBackend.c_str());

nixl_b_params_t init1;
nixl_mem_list_t mems1;
status = mRawAgent->getPluginParams("UCX", mems1, init1);
status = mRawAgent->getPluginParams(nixlBackend.c_str(), mems1, init1);
TLLM_CHECK(status == NIXL_SUCCESS);

status = mRawAgent->createBackend("UCX", init1, mRawBackend);
status = mRawAgent->createBackend(nixlBackend.c_str(), init1, mRawBackend);
if (status != NIXL_SUCCESS || !mRawBackend)
{
TLLM_THROW("Failed to create NIXL backend");
TLLM_THROW("Failed to create NIXL backend: %s", nixlBackend.c_str());
}
mExtraParams.backends.push_back(mRawBackend);
TLLM_LOG_INFO("NixlTransferAgent::NixlTransferAgent mAddress: %s", mAddress.c_str());
Expand Down