Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 5 additions & 11 deletions cpp/tensorrt_llm/nanobind/batch_manager/cacheTransceiver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,27 +84,21 @@ void tb::CacheTransceiverBindings::initBindings(nb::module_& m)
.def("check_gen_transfer_status", &BaseCacheTransceiver::checkGenTransferStatus)
.def("check_gen_transfer_complete", &BaseCacheTransceiver::checkGenTransferComplete);

nb::enum_<tb::CacheTransceiver::CommType>(m, "CommType")
.value("UNKNOWN", tb::CacheTransceiver::CommType::UNKNOWN)
.value("MPI", tb::CacheTransceiver::CommType::MPI)
.value("UCX", tb::CacheTransceiver::CommType::UCX)
.value("NIXL", tb::CacheTransceiver::CommType::NIXL);

nb::enum_<executor::kv_cache::CacheState::AttentionType>(m, "AttentionType")
.value("DEFAULT", executor::kv_cache::CacheState::AttentionType::kDEFAULT)
.value("MLA", executor::kv_cache::CacheState::AttentionType::kMLA);

nb::class_<tb::CacheTransceiver, tb::BaseCacheTransceiver>(m, "CacheTransceiver")
.def(nb::init<tb::kv_cache_manager::BaseKVCacheManager*, tb::CacheTransceiver::CommType,
std::vector<SizeType32>, SizeType32, SizeType32, runtime::WorldConfig, nvinfer1::DataType,
executor::kv_cache::CacheState::AttentionType, std::optional<executor::CacheTransceiverConfig>>(),
nb::arg("cache_manager"), nb::arg("comm_type"), nb::arg("num_kv_heads_per_layer"), nb::arg("size_per_head"),
.def(nb::init<tb::kv_cache_manager::BaseKVCacheManager*, std::vector<SizeType32>, SizeType32, SizeType32,
runtime::WorldConfig, nvinfer1::DataType, executor::kv_cache::CacheState::AttentionType,
std::optional<executor::CacheTransceiverConfig>>(),
nb::arg("cache_manager"), nb::arg("num_kv_heads_per_layer"), nb::arg("size_per_head"),
nb::arg("tokens_per_block"), nb::arg("world_config"), nb::arg("dtype"), nb::arg("attention_type"),
nb::arg("cache_transceiver_config") = std::nullopt);

nb::class_<tb::kv_cache_manager::CacheTransBufferManager>(m, "CacheTransBufferManager")
.def(nb::init<tb::kv_cache_manager::BaseKVCacheManager*, std::optional<size_t>>(), nb::arg("cache_manager"),
nb::arg("max_num_tokens") = std::nullopt)
.def_static("pre_alloc_buffer_size", &tb::kv_cache_manager::CacheTransBufferManager::preAllocBufferSize,
nb::arg("max_num_tokens") = std::nullopt);
nb::arg("cache_size_bytes_per_token_per_window"), nb::arg("cache_transceiver_config") = nb::none());
}
37 changes: 30 additions & 7 deletions cpp/tensorrt_llm/nanobind/executor/executorConfig.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -424,21 +424,44 @@ void initConfigBindings(nb::module_& m)
.def("__getstate__", guidedDecodingConfigGetstate)
.def("__setstate__", guidedDecodingConfigSetstate);

auto cacheTransceiverConfigGetstate
= [](tle::CacheTransceiverConfig const& self) { return nb::make_tuple(self.getMaxNumTokens()); };
auto cacheTransceiverConfigGetstate = [](tle::CacheTransceiverConfig const& self)
{ return nb::make_tuple(self.getBackendType(), self.getMaxTokensInBuffer()); };
auto cacheTransceiverConfigSetstate = [](tle::CacheTransceiverConfig& self, nb::tuple const& state)
{
if (state.size() != 1)
if (state.size() != 2)
{
throw std::runtime_error("Invalid CacheTransceiverConfig state!");
}
new (&self) tle::CacheTransceiverConfig(nb::cast<std::optional<size_t>>(state[0]));
new (&self) tle::CacheTransceiverConfig(
nb::cast<tle::CacheTransceiverConfig::BackendType>(state[0]), nb::cast<std::optional<size_t>>(state[1]));
};

nb::enum_<tle::CacheTransceiverConfig::BackendType>(m, "CacheTransceiverBackendType")
.value("DEFAULT", tle::CacheTransceiverConfig::BackendType::DEFAULT)
.value("MPI", tle::CacheTransceiverConfig::BackendType::MPI)
.value("UCX", tle::CacheTransceiverConfig::BackendType::UCX)
.value("NIXL", tle::CacheTransceiverConfig::BackendType::NIXL)
.def("from_string",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems that nanobind

nb::class_<MyType>(m, "MyType")
    .def(nb::init_implicit<MyOtherType>());

more suitable here.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How would it work? Unfortunately, I don't see a way to define an init function for an enum with nanobind.

[](std::string const& str)
{
if (str == "DEFAULT" || str == "default")
return tle::CacheTransceiverConfig::BackendType::DEFAULT;
if (str == "MPI" || str == "mpi")
return tle::CacheTransceiverConfig::BackendType::MPI;
if (str == "UCX" || str == "ucx")
return tle::CacheTransceiverConfig::BackendType::UCX;
if (str == "NIXL" || str == "nixl")
return tle::CacheTransceiverConfig::BackendType::NIXL;
throw std::runtime_error("Invalid backend type: " + str);
});

nb::class_<tle::CacheTransceiverConfig>(m, "CacheTransceiverConfig")
.def(nb::init<std::optional<size_t>>(), nb::arg("max_num_tokens") = nb::none())
.def_prop_rw("max_num_tokens", &tle::CacheTransceiverConfig::getMaxNumTokens,
&tle::CacheTransceiverConfig::setMaxNumTokens)
.def(nb::init<std::optional<tle::CacheTransceiverConfig::BackendType>, std::optional<size_t>>(),
nb::arg("backend") = std::nullopt, nb::arg("max_tokens_in_buffer") = std::nullopt)
.def_prop_rw(
"backend", &tle::CacheTransceiverConfig::getBackendType, &tle::CacheTransceiverConfig::setBackendType)
.def_prop_rw("max_tokens_in_buffer", &tle::CacheTransceiverConfig::getMaxTokensInBuffer,
&tle::CacheTransceiverConfig::setMaxTokensInBuffer)
.def("__getstate__", cacheTransceiverConfigGetstate)
.def("__setstate__", cacheTransceiverConfigSetstate);

Expand Down
5 changes: 3 additions & 2 deletions tests/unittest/bindings/test_executor_bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -2478,8 +2478,9 @@ def test_guided_decoding_config_pickle():


def test_cache_transceiver_config_pickle():
config = trtllm.CacheTransceiverConfig(backend="UCX",
max_tokens_in_buffer=1024)
config = trtllm.CacheTransceiverConfig(
backend=trtllm.CacheTransceiverBackendType.UCX,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are a lot of tests that use the string format, do we have something similar to "implicitly convertible" in Nanobind?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Which tests do you mean? I only found this test and a test in kv_cache_transceiver. But the latter accesses the enum member directly and not through a string

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have cache_transceiver_config.backend: <str> in tests/integration/defs/disaggregated/test_configs, the type of field of backend is string.

max_tokens_in_buffer=1024)
config_copy = pickle.loads(pickle.dumps(config))
assert config_copy.backend == config.backend
assert config_copy.max_tokens_in_buffer == config.max_tokens_in_buffer
Expand Down