diff --git a/docs/api-guide.md b/docs/api-guide.md
index 2dc2c0b..2a71d1c 100644
--- a/docs/api-guide.md
+++ b/docs/api-guide.md
@@ -312,6 +312,15 @@ if (daqiri::is_tx_burst_available(burst)) {
}
```
+For connection-oriented transports such as TCP socket mode, attach the connection ID before
+sending when you need to target a specific peer. RX bursts from those transports can be
+inspected with the matching getter:
+
+```cpp
+daqiri::set_connection_id(burst, conn_id);
+auto rx_conn_id = daqiri::get_connection_id(rx_burst);
+```
+
### Step 2: Fill packets
Use the header helper functions for standard UDP packets:
diff --git a/docs/daqiri-api.html b/docs/daqiri-api.html
index eff48e9..572a690 100644
--- a/docs/daqiri-api.html
+++ b/docs/daqiri-api.html
@@ -198,6 +198,7 @@
fn
diff --git a/examples/socket_bench.cpp b/examples/socket_bench.cpp
index 690e15f..6c937d0 100644
--- a/examples/socket_bench.cpp
+++ b/examples/socket_bench.cpp
@@ -105,8 +105,7 @@ void socket_worker(const SocketBenchConfig& cfg, std::atomic& stop, Socket
std::memset(payload, static_cast(stats.sent_packets & 0xff), cfg.message_size);
daqiri::set_packet_lengths(msg, 0, {cfg.message_size});
- // Socket transport optionally consumes conn_id from the generic burst header.
- msg->rdma_hdr.conn_id = conn_id;
+ daqiri::set_connection_id(msg, conn_id);
if (daqiri::send_tx_burst(msg) == daqiri::Status::SUCCESS) {
stats.sent_packets++;
diff --git a/include/daqiri/common.h b/include/daqiri/common.h
index c132691..dee1b5d 100644
--- a/include/daqiri/common.h
+++ b/include/daqiri/common.h
@@ -537,6 +537,21 @@ int64_t get_num_packets(BurstParams *burst);
*/
int64_t get_q_id(BurstParams *burst);
+/**
+ * @brief Get the transport connection ID associated with a burst
+ *
+ * @param burst Burst structure with transport metadata
+ */
+uintptr_t get_connection_id(const BurstParams *burst);
+
+/**
+ * @brief Set the transport connection ID associated with a burst
+ *
+ * @param burst Burst structure with transport metadata
+ * @param conn_id Connection ID representing a unique client/server connection
+ */
+void set_connection_id(BurstParams *burst, uintptr_t conn_id);
+
/**
* @brief Get mac address of an interface
*
diff --git a/include/daqiri/types.h b/include/daqiri/types.h
index 4cfe51d..a8a4041 100644
--- a/include/daqiri/types.h
+++ b/include/daqiri/types.h
@@ -85,7 +85,7 @@ enum class RDMAOpCode {
enum class RDMACompletionType { RX, TX, INVALID };
-struct AdvNetRdmaBurstHdr {
+struct BurstTransportHeader {
uint8_t version;
RDMAOpCode opcode;
Status status;
@@ -142,7 +142,7 @@ struct BurstHeader {
struct BurstParams {
union {
BurstHeader hdr;
- AdvNetRdmaBurstHdr rdma_hdr;
+ BurstTransportHeader transport_hdr;
};
std::array pkts;
diff --git a/python/daqiri_common_pybind.cpp b/python/daqiri_common_pybind.cpp
index 67de72e..fac4ab7 100644
--- a/python/daqiri_common_pybind.cpp
+++ b/python/daqiri_common_pybind.cpp
@@ -448,16 +448,16 @@ void bind_config_types(py::module_ &m) {
.def(py::init<>())
.def_readwrite("hdr", &BurstParams::hdr)
.def_property(
- "rdma_conn_id",
- [](const BurstParams &burst) { return burst.rdma_hdr.conn_id; },
+ "connection_id",
+ [](const BurstParams &burst) { return get_connection_id(&burst); },
[](BurstParams &burst, uintptr_t conn_id) {
- burst.rdma_hdr.conn_id = conn_id;
+ set_connection_id(&burst, conn_id);
})
.def_property(
"rdma_wr_id",
- [](const BurstParams &burst) { return burst.rdma_hdr.wr_id; },
+ [](const BurstParams &burst) { return burst.transport_hdr.wr_id; },
[](BurstParams &burst, uint64_t wr_id) {
- burst.rdma_hdr.wr_id = wr_id;
+ burst.transport_hdr.wr_id = wr_id;
});
py::class_(m, "RDMAConfig")
@@ -728,6 +728,8 @@ PYBIND11_MODULE(_daqiri, m) {
m.def("set_num_packets", &set_num_packets, "burst"_a, "num"_a);
m.def("get_num_packets", &get_num_packets, "burst"_a);
m.def("get_q_id", &get_q_id, "burst"_a);
+ m.def("set_connection_id", &set_connection_id, "burst"_a, "conn_id"_a);
+ m.def("get_connection_id", &get_connection_id, "burst"_a);
m.def(
"get_segment_packet_ptr",
diff --git a/src/common.cpp b/src/common.cpp
index 4b9c747..4127c00 100644
--- a/src/common.cpp
+++ b/src/common.cpp
@@ -243,6 +243,16 @@ int64_t get_q_id(BurstParams* burst) {
return burst->hdr.hdr.q_id;
}
+uintptr_t get_connection_id(const BurstParams* burst) {
+ assert(burst != nullptr && "burst is null");
+ return burst->transport_hdr.conn_id;
+}
+
+void set_connection_id(BurstParams* burst, uintptr_t conn_id) {
+ assert(burst != nullptr && "burst is null");
+ burst->transport_hdr.conn_id = conn_id;
+}
+
void set_num_packets(BurstParams* burst, int64_t num) {
assert(burst != nullptr && "burst is null");
burst->hdr.hdr.num_pkts = num;
diff --git a/src/managers/rdma/daqiri_rdma_mgr.cpp b/src/managers/rdma/daqiri_rdma_mgr.cpp
index df88152..a7c06cf 100644
--- a/src/managers/rdma/daqiri_rdma_mgr.cpp
+++ b/src/managers/rdma/daqiri_rdma_mgr.cpp
@@ -342,10 +342,10 @@ inline int RdmaMgr::set_affinity(int cpu_core) {
Status RdmaMgr::send_tx_burst(BurstParams* burst) {
struct rte_ring* ring;
- auto ri = tx_rings_map_.find(reinterpret_cast(burst->rdma_hdr.conn_id));
+ const auto conn_id = get_connection_id(burst);
+ auto ri = tx_rings_map_.find(reinterpret_cast(conn_id));
if (ri == tx_rings_map_.end()) {
- DAQIRI_LOG_ERROR("Invalid server connection ID in send_tx_burst: {}",
- burst->rdma_hdr.conn_id);
+ DAQIRI_LOG_ERROR("Invalid server connection ID in send_tx_burst: {}", conn_id);
return Status::INVALID_PARAMETER;
}
@@ -423,11 +423,13 @@ void RdmaMgr::rdma_thread(bool is_server, rdma_thread_params* tparams) {
msg = it->second;
- if (msg->rdma_hdr.conn_id != reinterpret_cast(tparams->client_id)) {
+ const auto conn_id = get_connection_id(msg);
+ const auto expected_conn_id = reinterpret_cast(tparams->client_id);
+ if (conn_id != expected_conn_id) {
DAQIRI_LOG_CRITICAL("Wrong connection ID in receive completion {}: {} != {}",
wc.wr_id,
- msg->rdma_hdr.conn_id,
- reinterpret_cast(tparams->client_id));
+ conn_id,
+ expected_conn_id);
}
outstanding_receive_wr_ids.erase(it);
@@ -436,13 +438,13 @@ void RdmaMgr::rdma_thread(bool is_server, rdma_thread_params* tparams) {
}
// Only populate a header to indicate which burst needs to be freed
- // msg->rdma_hdr.opcode = ibv_opcode_to_daqiri_opcode(wc.opcode);
- msg->rdma_hdr.status =
+ // msg->transport_hdr.opcode = ibv_opcode_to_daqiri_opcode(wc.opcode);
+ msg->transport_hdr.status =
wc.status == IBV_WC_SUCCESS ? Status::SUCCESS : Status::GENERIC_FAILURE;
- // msg->rdma_hdr.conn_id = reinterpret_cast(tparams->client_id);
- msg->rdma_hdr.server = is_server;
- msg->rdma_hdr.tx = false;
- // msg->rdma_hdr.wr_id = wc.wr_id;
+ // set_connection_id(msg, reinterpret_cast(tparams->client_id));
+ msg->transport_hdr.server = is_server;
+ msg->transport_hdr.tx = false;
+ // msg->transport_hdr.wr_id = wc.wr_id;
if (rte_ring_enqueue(rx_ring, reinterpret_cast(msg)) != 0) {
DAQIRI_LOG_CRITICAL("Failed to enqueue RX completion message");
@@ -477,11 +479,13 @@ void RdmaMgr::rdma_thread(bool is_server, rdma_thread_params* tparams) {
msg = it->second;
- if (msg->rdma_hdr.conn_id != reinterpret_cast(tparams->client_id)) {
+ const auto conn_id = get_connection_id(msg);
+ const auto expected_conn_id = reinterpret_cast(tparams->client_id);
+ if (conn_id != expected_conn_id) {
DAQIRI_LOG_CRITICAL("Wrong connection ID in send completion {}: {} != {}",
wc.wr_id,
- msg->rdma_hdr.conn_id,
- reinterpret_cast(tparams->client_id));
+ conn_id,
+ expected_conn_id);
}
outstanding_send_wr_ids.erase(it);
@@ -490,13 +494,13 @@ void RdmaMgr::rdma_thread(bool is_server, rdma_thread_params* tparams) {
}
// Only populate a header to indicate which burst needs to be freed
- // msg->rdma_hdr.opcode = ibv_opcode_to_daqiri_opcode(wc.opcode);
- msg->rdma_hdr.tx = true;
- msg->rdma_hdr.status =
+ // msg->transport_hdr.opcode = ibv_opcode_to_daqiri_opcode(wc.opcode);
+ msg->transport_hdr.tx = true;
+ msg->transport_hdr.status =
wc.status == IBV_WC_SUCCESS ? Status::SUCCESS : Status::GENERIC_FAILURE;
- // msg->rdma_hdr.conn_id = reinterpret_cast(tparams->client_id);
- msg->rdma_hdr.server = is_server;
- // msg->rdma_hdr.wr_id = wc.wr_id;
+ // set_connection_id(msg, reinterpret_cast(tparams->client_id));
+ msg->transport_hdr.server = is_server;
+ // msg->transport_hdr.wr_id = wc.wr_id;
if (rte_ring_enqueue(rx_ring, reinterpret_cast(msg)) != 0) {
DAQIRI_LOG_CRITICAL("Failed to enqueue RX completion message");
@@ -515,15 +519,15 @@ void RdmaMgr::rdma_thread(bool is_server, rdma_thread_params* tparams) {
continue;
}
- const auto local_mr = mrs_.find(std::string(burst->rdma_hdr.local_mr_name));
+ const auto local_mr = mrs_.find(std::string(burst->transport_hdr.local_mr_name));
if (local_mr == mrs_.end()) {
DAQIRI_LOG_CRITICAL("Couldn't find MR with name {} in registry",
- burst->rdma_hdr.local_mr_name);
+ burst->transport_hdr.local_mr_name);
free_tx_burst(burst);
continue;
}
- switch (burst->rdma_hdr.opcode) {
+ switch (burst->transport_hdr.opcode) {
case RDMAOpCode::SEND: {
// Get lkey for this PD
auto pd = pd_map_.find(tparams->client_id->verbs);
@@ -537,12 +541,12 @@ void RdmaMgr::rdma_thread(bool is_server, rdma_thread_params* tparams) {
auto lkey = local_mr->second.ctx_mr_map_.find(pd->second);
if (lkey == local_mr->second.ctx_mr_map_.end()) {
DAQIRI_LOG_CRITICAL("Couldn't find MR with name {} in registry",
- burst->rdma_hdr.local_mr_name);
+ burst->transport_hdr.local_mr_name);
free_tx_burst(burst);
continue;
}
- for (int p = 0; p < burst->rdma_hdr.num_pkts; p++) {
+ for (int p = 0; p < burst->transport_hdr.num_pkts; p++) {
ibv_send_wr wr;
ibv_send_wr* bad_wr;
ibv_sge sge;
@@ -551,7 +555,7 @@ void RdmaMgr::rdma_thread(bool is_server, rdma_thread_params* tparams) {
sge.addr = (uint64_t)burst->pkts[0][p];
sge.length = (uint32_t)burst->pkt_lens[0][p];
sge.lkey = lkey->second->lkey;
- wr.wr_id = burst->rdma_hdr.wr_id + p; // Auto-increment wr_id to be unique
+ wr.wr_id = burst->transport_hdr.wr_id + p; // Auto-increment wr_id to be unique
wr.sg_list = &sge;
wr.num_sge = 1;
wr.opcode = IBV_WR_SEND;
@@ -564,7 +568,7 @@ void RdmaMgr::rdma_thread(bool is_server, rdma_thread_params* tparams) {
continue;
}
- outstanding_send_wr_ids[burst->rdma_hdr.wr_id + p] = burst;
+ outstanding_send_wr_ids[burst->transport_hdr.wr_id + p] = burst;
}
break;
@@ -582,12 +586,12 @@ void RdmaMgr::rdma_thread(bool is_server, rdma_thread_params* tparams) {
auto lkey = local_mr->second.ctx_mr_map_.find(pd->second);
if (lkey == local_mr->second.ctx_mr_map_.end()) {
DAQIRI_LOG_CRITICAL("Couldn't find MR with name {} in registry",
- burst->rdma_hdr.local_mr_name);
+ burst->transport_hdr.local_mr_name);
free_tx_burst(burst);
continue;
}
- for (int p = 0; p < burst->rdma_hdr.num_pkts; p++) {
+ for (int p = 0; p < burst->transport_hdr.num_pkts; p++) {
struct ibv_recv_wr recv_wr;
struct ibv_sge sge;
struct ibv_recv_wr* bad_wr = NULL;
@@ -601,7 +605,7 @@ void RdmaMgr::rdma_thread(bool is_server, rdma_thread_params* tparams) {
// Prepare Receive Work Request
memset(&recv_wr, 0, sizeof(recv_wr));
- recv_wr.wr_id = burst->rdma_hdr.wr_id + p;
+ recv_wr.wr_id = burst->transport_hdr.wr_id + p;
recv_wr.next = NULL;
recv_wr.sg_list = &sge;
recv_wr.num_sge = 1;
@@ -614,7 +618,7 @@ void RdmaMgr::rdma_thread(bool is_server, rdma_thread_params* tparams) {
continue;
}
- outstanding_receive_wr_ids[burst->rdma_hdr.wr_id + p] = burst;
+ outstanding_receive_wr_ids[burst->transport_hdr.wr_id + p] = burst;
}
break;
}
@@ -638,20 +642,20 @@ Status RdmaMgr::rdma_connect_to_server(const std::string& dst_addr, uint16_t dst
}
RDMAOpCode RdmaMgr::rdma_get_opcode(BurstParams* burst) {
- return burst->rdma_hdr.opcode;
+ return burst->transport_hdr.opcode;
}
Status RdmaMgr::rdma_set_header(BurstParams* burst, RDMAOpCode op_code, uintptr_t conn_id,
bool is_server, int num_pkts, uint64_t wr_id,
const std::string& local_mr_name) {
- burst->rdma_hdr.opcode = op_code;
- burst->rdma_hdr.conn_id = conn_id;
- burst->rdma_hdr.server = is_server;
- burst->rdma_hdr.num_pkts = num_pkts;
- burst->rdma_hdr.num_segs = 1;
- burst->rdma_hdr.wr_id = wr_id;
- snprintf(burst->rdma_hdr.local_mr_name,
- sizeof(burst->rdma_hdr.local_mr_name),
+ burst->transport_hdr.opcode = op_code;
+ set_connection_id(burst, conn_id);
+ burst->transport_hdr.server = is_server;
+ burst->transport_hdr.num_pkts = num_pkts;
+ burst->transport_hdr.num_segs = 1;
+ burst->transport_hdr.wr_id = wr_id;
+ snprintf(burst->transport_hdr.local_mr_name,
+ sizeof(burst->transport_hdr.local_mr_name),
"%s",
local_mr_name.c_str());
return Status::SUCCESS;
@@ -917,12 +921,12 @@ Status RdmaMgr::rdma_get_port_queue(uintptr_t conn_id, uint16_t* port, uint16_t*
Status RdmaMgr::get_tx_packet_burst(BurstParams* burst) {
// RDMA isn't allowing split segments yet
- assert(burst->rdma_hdr.num_segs == 1);
- assert(burst->rdma_hdr.num_pkts <= MAX_RDMA_BATCH);
- auto burst_pool = mem_pools_.find(burst->rdma_hdr.local_mr_name);
+ assert(burst->transport_hdr.num_segs == 1);
+ assert(burst->transport_hdr.num_pkts <= MAX_RDMA_BATCH);
+ auto burst_pool = mem_pools_.find(burst->transport_hdr.local_mr_name);
if (burst_pool == mem_pools_.end()) {
DAQIRI_LOG_ERROR("Failed to look up burst pool name for MR {}",
- burst->rdma_hdr.local_mr_name);
+ burst->transport_hdr.local_mr_name);
return Status::INVALID_PARAMETER;
}
@@ -933,13 +937,13 @@ Status RdmaMgr::get_tx_packet_burst(BurstParams* burst) {
int rx = rte_ring_dequeue_bulk(burst_pool->second,
reinterpret_cast(burst->pkts[0]),
- burst->rdma_hdr.num_pkts,
+ burst->transport_hdr.num_pkts,
nullptr);
- if (rx != burst->rdma_hdr.num_pkts) {
- DAQIRI_LOG_ERROR("Asked for {} packets, got {}", burst->rdma_hdr.num_pkts, rx);
+ if (rx != burst->transport_hdr.num_pkts) {
+ DAQIRI_LOG_ERROR("Asked for {} packets, got {}", burst->transport_hdr.num_pkts, rx);
rte_ring_enqueue_bulk(burst_pool->second,
reinterpret_cast(burst->pkts[0]),
- burst->rdma_hdr.num_pkts,
+ burst->transport_hdr.num_pkts,
nullptr);
return Status::NO_FREE_BURST_BUFFERS;
}
@@ -955,14 +959,14 @@ Status RdmaMgr::get_tx_packet_burst(BurstParams* burst) {
}
bool RdmaMgr::is_tx_burst_available(BurstParams* burst) {
- auto burst_pool = mem_pools_.find(burst->rdma_hdr.local_mr_name);
+ auto burst_pool = mem_pools_.find(burst->transport_hdr.local_mr_name);
if (burst_pool == mem_pools_.end()) {
DAQIRI_LOG_ERROR("Failed to look up burst pool name for MR {}",
- burst->rdma_hdr.local_mr_name);
+ burst->transport_hdr.local_mr_name);
return false;
}
- if (rte_ring_count(burst_pool->second) < burst->rdma_hdr.num_pkts) { return false; }
+ if (rte_ring_count(burst_pool->second) < burst->transport_hdr.num_pkts) { return false; }
return true;
}
@@ -1341,21 +1345,21 @@ void RdmaMgr::free_rx_burst(BurstParams* burst) {
}
void RdmaMgr::free_tx_burst(BurstParams* burst) {
- auto burst_pool = mem_pools_.find(burst->rdma_hdr.local_mr_name);
+ auto burst_pool = mem_pools_.find(burst->transport_hdr.local_mr_name);
if (burst_pool != mem_pools_.end()) {
int ret = rte_ring_enqueue_bulk(burst_pool->second,
reinterpret_cast(burst->pkts[0]),
- burst->rdma_hdr.num_pkts,
+ burst->transport_hdr.num_pkts,
nullptr);
- if (ret != burst->rdma_hdr.num_pkts) {
+ if (ret != burst->transport_hdr.num_pkts) {
DAQIRI_LOG_CRITICAL(
- "Asked to free {} packets, only enqueued {}", burst->rdma_hdr.num_pkts, ret);
+ "Asked to free {} packets, only enqueued {}", burst->transport_hdr.num_pkts, ret);
}
}
rte_mempool_put(tx_burst_pool_, (void*)burst->pkts[0]);
rte_mempool_put(pkt_len_pool_, (void*)burst->pkt_lens[0]);
- burst->rdma_hdr.num_pkts = 0;
+ burst->transport_hdr.num_pkts = 0;
rte_mempool_put(tx_meta, burst);
}
diff --git a/src/managers/socket/daqiri_socket_mgr.cpp b/src/managers/socket/daqiri_socket_mgr.cpp
index 3ccec89..841533e 100644
--- a/src/managers/socket/daqiri_socket_mgr.cpp
+++ b/src/managers/socket/daqiri_socket_mgr.cpp
@@ -815,7 +815,7 @@ Status SocketMgr::send_tx_burst(BurstParams* burst) {
{
std::lock_guard lock(state_mutex_);
- auto requested_id = burst->rdma_hdr.conn_id;
+ auto requested_id = get_connection_id(burst);
if (requested_id != 0) {
auto it = connections_.find(requested_id);
if (it != connections_.end()) { conn = it->second; }
@@ -1230,7 +1230,7 @@ void SocketMgr::tcp_rx_loop(std::shared_ptr conn) {
std::memcpy(payload, tmp.data(), static_cast(rx));
burst->pkts[0][0] = payload;
burst->pkt_lens[0][0] = static_cast(rx);
- burst->rdma_hdr.conn_id = conn->conn_id;
+ set_connection_id(burst, conn->conn_id);
push_rx_burst(conn->rx_queue, burst);
rx_pkts_.fetch_add(1);
@@ -1301,7 +1301,7 @@ void SocketMgr::udp_rx_loop(int if_index) {
std::memcpy(payload, iovs[static_cast(i)].iov_base, rx);
burst->pkts[0][0] = payload;
burst->pkt_lens[0][0] = static_cast(rx);
- burst->rdma_hdr.conn_id = ep->primary_conn_id;
+ set_connection_id(burst, ep->primary_conn_id);
push_rx_burst(ep->rx_queue_state, burst);
rx_pkts_.fetch_add(1);