diff --git a/docs/api-guide.md b/docs/api-guide.md index 2dc2c0b..2a71d1c 100644 --- a/docs/api-guide.md +++ b/docs/api-guide.md @@ -312,6 +312,15 @@ if (daqiri::is_tx_burst_available(burst)) { } ``` +For connection-oriented transports such as TCP socket mode, attach the connection ID before +sending when you need to target a specific peer. RX bursts from those transports can be +inspected with the matching getter: + +```cpp +daqiri::set_connection_id(burst, conn_id); +auto rx_conn_id = daqiri::get_connection_id(rx_burst); +``` + ### Step 2: Fill packets Use the header helper functions for standard UDP packets: diff --git a/docs/daqiri-api.html b/docs/daqiri-api.html index eff48e9..572a690 100644 --- a/docs/daqiri-api.html +++ b/docs/daqiri-api.html @@ -198,6 +198,7 @@
Receive (RX)
get_rx_burst fn get_num_packets fn + connection ID fn get_packet_ptr fn get_packet_length fn get_packet_flow_id fn @@ -328,6 +329,20 @@

Segments

+
+
+ fn + set_connection_id / get_connection_id + (burst, conn_id?) + +
+
+

Sets or reads the transport connection ID carried by a burst. Connection-oriented transports use this to select a peer for TX and identify the peer that produced an RX burst.

+
daqiri::set_connection_id(burst, conn_id);
+uintptr_t rx_conn_id = daqiri::get_connection_id(rx_burst);
+
+
+
fn diff --git a/examples/socket_bench.cpp b/examples/socket_bench.cpp index 690e15f..6c937d0 100644 --- a/examples/socket_bench.cpp +++ b/examples/socket_bench.cpp @@ -105,8 +105,7 @@ void socket_worker(const SocketBenchConfig& cfg, std::atomic& stop, Socket std::memset(payload, static_cast(stats.sent_packets & 0xff), cfg.message_size); daqiri::set_packet_lengths(msg, 0, {cfg.message_size}); - // Socket transport optionally consumes conn_id from the generic burst header. - msg->rdma_hdr.conn_id = conn_id; + daqiri::set_connection_id(msg, conn_id); if (daqiri::send_tx_burst(msg) == daqiri::Status::SUCCESS) { stats.sent_packets++; diff --git a/include/daqiri/common.h b/include/daqiri/common.h index c132691..dee1b5d 100644 --- a/include/daqiri/common.h +++ b/include/daqiri/common.h @@ -537,6 +537,21 @@ int64_t get_num_packets(BurstParams *burst); */ int64_t get_q_id(BurstParams *burst); +/** + * @brief Get the transport connection ID associated with a burst + * + * @param burst Burst structure with transport metadata + */ +uintptr_t get_connection_id(const BurstParams *burst); + +/** + * @brief Set the transport connection ID associated with a burst + * + * @param burst Burst structure with transport metadata + * @param conn_id Connection ID representing a unique client/server connection + */ +void set_connection_id(BurstParams *burst, uintptr_t conn_id); + /** * @brief Get mac address of an interface * diff --git a/include/daqiri/types.h b/include/daqiri/types.h index 4cfe51d..a8a4041 100644 --- a/include/daqiri/types.h +++ b/include/daqiri/types.h @@ -85,7 +85,7 @@ enum class RDMAOpCode { enum class RDMACompletionType { RX, TX, INVALID }; -struct AdvNetRdmaBurstHdr { +struct BurstTransportHeader { uint8_t version; RDMAOpCode opcode; Status status; @@ -142,7 +142,7 @@ struct BurstHeader { struct BurstParams { union { BurstHeader hdr; - AdvNetRdmaBurstHdr rdma_hdr; + BurstTransportHeader transport_hdr; }; std::array pkts; diff --git a/python/daqiri_common_pybind.cpp b/python/daqiri_common_pybind.cpp index 67de72e..fac4ab7 100644 --- a/python/daqiri_common_pybind.cpp +++ b/python/daqiri_common_pybind.cpp @@ -448,16 +448,16 @@ void bind_config_types(py::module_ &m) { .def(py::init<>()) .def_readwrite("hdr", &BurstParams::hdr) .def_property( - "rdma_conn_id", - [](const BurstParams &burst) { return burst.rdma_hdr.conn_id; }, + "connection_id", + [](const BurstParams &burst) { return get_connection_id(&burst); }, [](BurstParams &burst, uintptr_t conn_id) { - burst.rdma_hdr.conn_id = conn_id; + set_connection_id(&burst, conn_id); }) .def_property( "rdma_wr_id", - [](const BurstParams &burst) { return burst.rdma_hdr.wr_id; }, + [](const BurstParams &burst) { return burst.transport_hdr.wr_id; }, [](BurstParams &burst, uint64_t wr_id) { - burst.rdma_hdr.wr_id = wr_id; + burst.transport_hdr.wr_id = wr_id; }); py::class_(m, "RDMAConfig") @@ -728,6 +728,8 @@ PYBIND11_MODULE(_daqiri, m) { m.def("set_num_packets", &set_num_packets, "burst"_a, "num"_a); m.def("get_num_packets", &get_num_packets, "burst"_a); m.def("get_q_id", &get_q_id, "burst"_a); + m.def("set_connection_id", &set_connection_id, "burst"_a, "conn_id"_a); + m.def("get_connection_id", &get_connection_id, "burst"_a); m.def( "get_segment_packet_ptr", diff --git a/src/common.cpp b/src/common.cpp index 4b9c747..4127c00 100644 --- a/src/common.cpp +++ b/src/common.cpp @@ -243,6 +243,16 @@ int64_t get_q_id(BurstParams* burst) { return burst->hdr.hdr.q_id; } +uintptr_t get_connection_id(const BurstParams* burst) { + assert(burst != nullptr && "burst is null"); + return burst->transport_hdr.conn_id; +} + +void set_connection_id(BurstParams* burst, uintptr_t conn_id) { + assert(burst != nullptr && "burst is null"); + burst->transport_hdr.conn_id = conn_id; +} + void set_num_packets(BurstParams* burst, int64_t num) { assert(burst != nullptr && "burst is null"); burst->hdr.hdr.num_pkts = num; diff --git a/src/managers/rdma/daqiri_rdma_mgr.cpp b/src/managers/rdma/daqiri_rdma_mgr.cpp index df88152..a7c06cf 100644 --- a/src/managers/rdma/daqiri_rdma_mgr.cpp +++ b/src/managers/rdma/daqiri_rdma_mgr.cpp @@ -342,10 +342,10 @@ inline int RdmaMgr::set_affinity(int cpu_core) { Status RdmaMgr::send_tx_burst(BurstParams* burst) { struct rte_ring* ring; - auto ri = tx_rings_map_.find(reinterpret_cast(burst->rdma_hdr.conn_id)); + const auto conn_id = get_connection_id(burst); + auto ri = tx_rings_map_.find(reinterpret_cast(conn_id)); if (ri == tx_rings_map_.end()) { - DAQIRI_LOG_ERROR("Invalid server connection ID in send_tx_burst: {}", - burst->rdma_hdr.conn_id); + DAQIRI_LOG_ERROR("Invalid server connection ID in send_tx_burst: {}", conn_id); return Status::INVALID_PARAMETER; } @@ -423,11 +423,13 @@ void RdmaMgr::rdma_thread(bool is_server, rdma_thread_params* tparams) { msg = it->second; - if (msg->rdma_hdr.conn_id != reinterpret_cast(tparams->client_id)) { + const auto conn_id = get_connection_id(msg); + const auto expected_conn_id = reinterpret_cast(tparams->client_id); + if (conn_id != expected_conn_id) { DAQIRI_LOG_CRITICAL("Wrong connection ID in receive completion {}: {} != {}", wc.wr_id, - msg->rdma_hdr.conn_id, - reinterpret_cast(tparams->client_id)); + conn_id, + expected_conn_id); } outstanding_receive_wr_ids.erase(it); @@ -436,13 +438,13 @@ void RdmaMgr::rdma_thread(bool is_server, rdma_thread_params* tparams) { } // Only populate a header to indicate which burst needs to be freed - // msg->rdma_hdr.opcode = ibv_opcode_to_daqiri_opcode(wc.opcode); - msg->rdma_hdr.status = + // msg->transport_hdr.opcode = ibv_opcode_to_daqiri_opcode(wc.opcode); + msg->transport_hdr.status = wc.status == IBV_WC_SUCCESS ? Status::SUCCESS : Status::GENERIC_FAILURE; - // msg->rdma_hdr.conn_id = reinterpret_cast(tparams->client_id); - msg->rdma_hdr.server = is_server; - msg->rdma_hdr.tx = false; - // msg->rdma_hdr.wr_id = wc.wr_id; + // set_connection_id(msg, reinterpret_cast(tparams->client_id)); + msg->transport_hdr.server = is_server; + msg->transport_hdr.tx = false; + // msg->transport_hdr.wr_id = wc.wr_id; if (rte_ring_enqueue(rx_ring, reinterpret_cast(msg)) != 0) { DAQIRI_LOG_CRITICAL("Failed to enqueue RX completion message"); @@ -477,11 +479,13 @@ void RdmaMgr::rdma_thread(bool is_server, rdma_thread_params* tparams) { msg = it->second; - if (msg->rdma_hdr.conn_id != reinterpret_cast(tparams->client_id)) { + const auto conn_id = get_connection_id(msg); + const auto expected_conn_id = reinterpret_cast(tparams->client_id); + if (conn_id != expected_conn_id) { DAQIRI_LOG_CRITICAL("Wrong connection ID in send completion {}: {} != {}", wc.wr_id, - msg->rdma_hdr.conn_id, - reinterpret_cast(tparams->client_id)); + conn_id, + expected_conn_id); } outstanding_send_wr_ids.erase(it); @@ -490,13 +494,13 @@ void RdmaMgr::rdma_thread(bool is_server, rdma_thread_params* tparams) { } // Only populate a header to indicate which burst needs to be freed - // msg->rdma_hdr.opcode = ibv_opcode_to_daqiri_opcode(wc.opcode); - msg->rdma_hdr.tx = true; - msg->rdma_hdr.status = + // msg->transport_hdr.opcode = ibv_opcode_to_daqiri_opcode(wc.opcode); + msg->transport_hdr.tx = true; + msg->transport_hdr.status = wc.status == IBV_WC_SUCCESS ? Status::SUCCESS : Status::GENERIC_FAILURE; - // msg->rdma_hdr.conn_id = reinterpret_cast(tparams->client_id); - msg->rdma_hdr.server = is_server; - // msg->rdma_hdr.wr_id = wc.wr_id; + // set_connection_id(msg, reinterpret_cast(tparams->client_id)); + msg->transport_hdr.server = is_server; + // msg->transport_hdr.wr_id = wc.wr_id; if (rte_ring_enqueue(rx_ring, reinterpret_cast(msg)) != 0) { DAQIRI_LOG_CRITICAL("Failed to enqueue RX completion message"); @@ -515,15 +519,15 @@ void RdmaMgr::rdma_thread(bool is_server, rdma_thread_params* tparams) { continue; } - const auto local_mr = mrs_.find(std::string(burst->rdma_hdr.local_mr_name)); + const auto local_mr = mrs_.find(std::string(burst->transport_hdr.local_mr_name)); if (local_mr == mrs_.end()) { DAQIRI_LOG_CRITICAL("Couldn't find MR with name {} in registry", - burst->rdma_hdr.local_mr_name); + burst->transport_hdr.local_mr_name); free_tx_burst(burst); continue; } - switch (burst->rdma_hdr.opcode) { + switch (burst->transport_hdr.opcode) { case RDMAOpCode::SEND: { // Get lkey for this PD auto pd = pd_map_.find(tparams->client_id->verbs); @@ -537,12 +541,12 @@ void RdmaMgr::rdma_thread(bool is_server, rdma_thread_params* tparams) { auto lkey = local_mr->second.ctx_mr_map_.find(pd->second); if (lkey == local_mr->second.ctx_mr_map_.end()) { DAQIRI_LOG_CRITICAL("Couldn't find MR with name {} in registry", - burst->rdma_hdr.local_mr_name); + burst->transport_hdr.local_mr_name); free_tx_burst(burst); continue; } - for (int p = 0; p < burst->rdma_hdr.num_pkts; p++) { + for (int p = 0; p < burst->transport_hdr.num_pkts; p++) { ibv_send_wr wr; ibv_send_wr* bad_wr; ibv_sge sge; @@ -551,7 +555,7 @@ void RdmaMgr::rdma_thread(bool is_server, rdma_thread_params* tparams) { sge.addr = (uint64_t)burst->pkts[0][p]; sge.length = (uint32_t)burst->pkt_lens[0][p]; sge.lkey = lkey->second->lkey; - wr.wr_id = burst->rdma_hdr.wr_id + p; // Auto-increment wr_id to be unique + wr.wr_id = burst->transport_hdr.wr_id + p; // Auto-increment wr_id to be unique wr.sg_list = &sge; wr.num_sge = 1; wr.opcode = IBV_WR_SEND; @@ -564,7 +568,7 @@ void RdmaMgr::rdma_thread(bool is_server, rdma_thread_params* tparams) { continue; } - outstanding_send_wr_ids[burst->rdma_hdr.wr_id + p] = burst; + outstanding_send_wr_ids[burst->transport_hdr.wr_id + p] = burst; } break; @@ -582,12 +586,12 @@ void RdmaMgr::rdma_thread(bool is_server, rdma_thread_params* tparams) { auto lkey = local_mr->second.ctx_mr_map_.find(pd->second); if (lkey == local_mr->second.ctx_mr_map_.end()) { DAQIRI_LOG_CRITICAL("Couldn't find MR with name {} in registry", - burst->rdma_hdr.local_mr_name); + burst->transport_hdr.local_mr_name); free_tx_burst(burst); continue; } - for (int p = 0; p < burst->rdma_hdr.num_pkts; p++) { + for (int p = 0; p < burst->transport_hdr.num_pkts; p++) { struct ibv_recv_wr recv_wr; struct ibv_sge sge; struct ibv_recv_wr* bad_wr = NULL; @@ -601,7 +605,7 @@ void RdmaMgr::rdma_thread(bool is_server, rdma_thread_params* tparams) { // Prepare Receive Work Request memset(&recv_wr, 0, sizeof(recv_wr)); - recv_wr.wr_id = burst->rdma_hdr.wr_id + p; + recv_wr.wr_id = burst->transport_hdr.wr_id + p; recv_wr.next = NULL; recv_wr.sg_list = &sge; recv_wr.num_sge = 1; @@ -614,7 +618,7 @@ void RdmaMgr::rdma_thread(bool is_server, rdma_thread_params* tparams) { continue; } - outstanding_receive_wr_ids[burst->rdma_hdr.wr_id + p] = burst; + outstanding_receive_wr_ids[burst->transport_hdr.wr_id + p] = burst; } break; } @@ -638,20 +642,20 @@ Status RdmaMgr::rdma_connect_to_server(const std::string& dst_addr, uint16_t dst } RDMAOpCode RdmaMgr::rdma_get_opcode(BurstParams* burst) { - return burst->rdma_hdr.opcode; + return burst->transport_hdr.opcode; } Status RdmaMgr::rdma_set_header(BurstParams* burst, RDMAOpCode op_code, uintptr_t conn_id, bool is_server, int num_pkts, uint64_t wr_id, const std::string& local_mr_name) { - burst->rdma_hdr.opcode = op_code; - burst->rdma_hdr.conn_id = conn_id; - burst->rdma_hdr.server = is_server; - burst->rdma_hdr.num_pkts = num_pkts; - burst->rdma_hdr.num_segs = 1; - burst->rdma_hdr.wr_id = wr_id; - snprintf(burst->rdma_hdr.local_mr_name, - sizeof(burst->rdma_hdr.local_mr_name), + burst->transport_hdr.opcode = op_code; + set_connection_id(burst, conn_id); + burst->transport_hdr.server = is_server; + burst->transport_hdr.num_pkts = num_pkts; + burst->transport_hdr.num_segs = 1; + burst->transport_hdr.wr_id = wr_id; + snprintf(burst->transport_hdr.local_mr_name, + sizeof(burst->transport_hdr.local_mr_name), "%s", local_mr_name.c_str()); return Status::SUCCESS; @@ -917,12 +921,12 @@ Status RdmaMgr::rdma_get_port_queue(uintptr_t conn_id, uint16_t* port, uint16_t* Status RdmaMgr::get_tx_packet_burst(BurstParams* burst) { // RDMA isn't allowing split segments yet - assert(burst->rdma_hdr.num_segs == 1); - assert(burst->rdma_hdr.num_pkts <= MAX_RDMA_BATCH); - auto burst_pool = mem_pools_.find(burst->rdma_hdr.local_mr_name); + assert(burst->transport_hdr.num_segs == 1); + assert(burst->transport_hdr.num_pkts <= MAX_RDMA_BATCH); + auto burst_pool = mem_pools_.find(burst->transport_hdr.local_mr_name); if (burst_pool == mem_pools_.end()) { DAQIRI_LOG_ERROR("Failed to look up burst pool name for MR {}", - burst->rdma_hdr.local_mr_name); + burst->transport_hdr.local_mr_name); return Status::INVALID_PARAMETER; } @@ -933,13 +937,13 @@ Status RdmaMgr::get_tx_packet_burst(BurstParams* burst) { int rx = rte_ring_dequeue_bulk(burst_pool->second, reinterpret_cast(burst->pkts[0]), - burst->rdma_hdr.num_pkts, + burst->transport_hdr.num_pkts, nullptr); - if (rx != burst->rdma_hdr.num_pkts) { - DAQIRI_LOG_ERROR("Asked for {} packets, got {}", burst->rdma_hdr.num_pkts, rx); + if (rx != burst->transport_hdr.num_pkts) { + DAQIRI_LOG_ERROR("Asked for {} packets, got {}", burst->transport_hdr.num_pkts, rx); rte_ring_enqueue_bulk(burst_pool->second, reinterpret_cast(burst->pkts[0]), - burst->rdma_hdr.num_pkts, + burst->transport_hdr.num_pkts, nullptr); return Status::NO_FREE_BURST_BUFFERS; } @@ -955,14 +959,14 @@ Status RdmaMgr::get_tx_packet_burst(BurstParams* burst) { } bool RdmaMgr::is_tx_burst_available(BurstParams* burst) { - auto burst_pool = mem_pools_.find(burst->rdma_hdr.local_mr_name); + auto burst_pool = mem_pools_.find(burst->transport_hdr.local_mr_name); if (burst_pool == mem_pools_.end()) { DAQIRI_LOG_ERROR("Failed to look up burst pool name for MR {}", - burst->rdma_hdr.local_mr_name); + burst->transport_hdr.local_mr_name); return false; } - if (rte_ring_count(burst_pool->second) < burst->rdma_hdr.num_pkts) { return false; } + if (rte_ring_count(burst_pool->second) < burst->transport_hdr.num_pkts) { return false; } return true; } @@ -1341,21 +1345,21 @@ void RdmaMgr::free_rx_burst(BurstParams* burst) { } void RdmaMgr::free_tx_burst(BurstParams* burst) { - auto burst_pool = mem_pools_.find(burst->rdma_hdr.local_mr_name); + auto burst_pool = mem_pools_.find(burst->transport_hdr.local_mr_name); if (burst_pool != mem_pools_.end()) { int ret = rte_ring_enqueue_bulk(burst_pool->second, reinterpret_cast(burst->pkts[0]), - burst->rdma_hdr.num_pkts, + burst->transport_hdr.num_pkts, nullptr); - if (ret != burst->rdma_hdr.num_pkts) { + if (ret != burst->transport_hdr.num_pkts) { DAQIRI_LOG_CRITICAL( - "Asked to free {} packets, only enqueued {}", burst->rdma_hdr.num_pkts, ret); + "Asked to free {} packets, only enqueued {}", burst->transport_hdr.num_pkts, ret); } } rte_mempool_put(tx_burst_pool_, (void*)burst->pkts[0]); rte_mempool_put(pkt_len_pool_, (void*)burst->pkt_lens[0]); - burst->rdma_hdr.num_pkts = 0; + burst->transport_hdr.num_pkts = 0; rte_mempool_put(tx_meta, burst); } diff --git a/src/managers/socket/daqiri_socket_mgr.cpp b/src/managers/socket/daqiri_socket_mgr.cpp index 3ccec89..841533e 100644 --- a/src/managers/socket/daqiri_socket_mgr.cpp +++ b/src/managers/socket/daqiri_socket_mgr.cpp @@ -815,7 +815,7 @@ Status SocketMgr::send_tx_burst(BurstParams* burst) { { std::lock_guard lock(state_mutex_); - auto requested_id = burst->rdma_hdr.conn_id; + auto requested_id = get_connection_id(burst); if (requested_id != 0) { auto it = connections_.find(requested_id); if (it != connections_.end()) { conn = it->second; } @@ -1230,7 +1230,7 @@ void SocketMgr::tcp_rx_loop(std::shared_ptr conn) { std::memcpy(payload, tmp.data(), static_cast(rx)); burst->pkts[0][0] = payload; burst->pkt_lens[0][0] = static_cast(rx); - burst->rdma_hdr.conn_id = conn->conn_id; + set_connection_id(burst, conn->conn_id); push_rx_burst(conn->rx_queue, burst); rx_pkts_.fetch_add(1); @@ -1301,7 +1301,7 @@ void SocketMgr::udp_rx_loop(int if_index) { std::memcpy(payload, iovs[static_cast(i)].iov_base, rx); burst->pkts[0][0] = payload; burst->pkt_lens[0][0] = static_cast(rx); - burst->rdma_hdr.conn_id = ep->primary_conn_id; + set_connection_id(burst, ep->primary_conn_id); push_rx_burst(ep->rx_queue_state, burst); rx_pkts_.fetch_add(1);