From 9a1554e2d5bcb9594e6cbc910916e4c0d18c2ebf Mon Sep 17 00:00:00 2001 From: chenBright Date: Tue, 2 Jun 2026 22:45:33 +0800 Subject: [PATCH] Support RDMA handshake V3 Background ========== The legacy v2 handshake ("RDMA" magic + 36B fixed binary HelloMessage) had correctness bugs that made the wire format effectively unevolvable. v2 bugs ======= A. Client never drained the "unknown tail" bytes when a peer sent msg_len > HELLO_MSG_LEN_MIN(40). Leftover bytes stayed in the socket recv buffer and silently corrupted the next ReadFromFd (the ACK). B. Server had the symmetric version of A. C. Server computed the body read length from its LOCAL g_rdma_hello_msg_len, implicitly assuming the peer's hello is the same length as its own. A longer peer left bytes behind; a shorter peer made the read block. Client used the correct compile-time constant; the two sides were not symmetric. Combined, A/B/C meant v2 could not safely append a single byte to the hello -- even an "optional hint" appended by a newer sender would mis-align an older receiver's next read. v3 design ========= Wire format, magic-namespace-isolated from v2: [ "RDM3" 4B ][ pb_size 4B big-endian ][ RdmaHello protobuf bytes ] with pb_size in (0, 4096]. RdmaHello carries the same 6 base fields as v2 plus room to append future capabilities. Why protobuf (and not "v2 plus length prefix") ---------------------------------------------- - Variable-length fields are coming. Future capability fields will include strings (rdma_device_name, netdev_name, ...) and other variable-length data. Supporting them on a fixed-binary protocol forces us to invent and maintain a TLV layer (per-field type + length + value framing, plus version-aware deserialization). That is reimplementing protobuf badly. Using protobuf from day one costs nothing and is the canonical answer. 
- Fixes v2 bug A/B/C generically: pb_size makes the wire self-describing, so the receiver never needs to guess the length or know the peer's schema version to read the body cleanly. - Append-only field evolution out of the box: new optional fields cost old receivers nothing -- they're skipped as unknown protobuf fields. v2 with a hand-rolled length prefix would still need per-field opt-in code on every side. - Built-in validation: ParseFromArray fails fast on malformed input; required-field presence is enforced at the parse layer, not by ad-hoc has_xxx() checks scattered through wire code. - bRPC already depends on protobuf -- no new build dependency. Why a NEW MAGIC rather than a version field inside protobuf ----------------------------------------------------------- - Forces "breaking change" to be a deployment decision visible at the wire level. You cannot accidentally ship a backwards- incompatible patch via a field-semantics tweak. - Server-side dispatch routes by magic to fully independent state machines that can't entangle (no `if (version == X)` branches anywhere -- this is the abstraction's red line). - Any future breaking change bumps the magic ("RDM4", "RDM5", ...). v3 fields, once shipped, never change semantics. Rollout ======= Server-side ALWAYS accepts both v2 and v3 (no gflag, no kill-switch); magic routes to fully independent code paths. A single rolling upgrade enables v3 fleet-wide. Client-side picks the wire protocol via gflag with a safe default: FLAGS_rdma_client_handshake_version (default 2) 2 = "RDMA" legacy (zero-regression default) 3 = "RDM3" protobuf (opt-in once target servers support v3) Sub-second rollback is one flag flip away. v3 client to v2-only legacy server is NOT guaranteed to transparently fall back on the same connection -- the supported migration is "upgrade servers first, then opt-in clients". 
--- .bazelrc | 2 + .../install-all-dependencies/action.yml | 2 +- .github/workflows/ci-linux.yml | 23 +- BUILD.bazel | 1 + CMakeLists.txt | 3 +- src/brpc/rdma/rdma_endpoint.cpp | 538 ++++----- src/brpc/rdma/rdma_endpoint.h | 46 + src/brpc/rdma/rdma_handshake.cpp | 408 +++++++ src/brpc/rdma/rdma_handshake.h | 192 +++ src/brpc/rdma/rdma_handshake.proto | 46 + src/brpc/rdma_transport.h | 9 +- src/brpc/server.cpp | 25 +- src/brpc/socket.cpp | 3 +- src/brpc/socket.h | 8 + src/butil/thread_key.h | 1 + test/brpc_rdma_unittest.cpp | 1060 +++++++++++++---- test/bvar_percentile_unittest.cpp | 2 + 17 files changed, 1806 insertions(+), 563 deletions(-) create mode 100644 src/brpc/rdma/rdma_handshake.cpp create mode 100644 src/brpc/rdma/rdma_handshake.h create mode 100644 src/brpc/rdma/rdma_handshake.proto diff --git a/.bazelrc b/.bazelrc index abf05fc6d7..c10fb589bc 100644 --- a/.bazelrc +++ b/.bazelrc @@ -50,6 +50,8 @@ build --features=per_object_debug_info # We already have absl in the build, define absl=1 to tell googletest to use absl for backtrace. build --define absl=1 +build:rdma --define BRPC_WITH_RDMA=true + # For UT. 
build:test --define BRPC_BUILD_FOR_UNITTEST=true # Hide libunwind's `_Unwind_*` symbols so they don't preempt libgcc_s at diff --git a/.github/actions/install-all-dependencies/action.yml b/.github/actions/install-all-dependencies/action.yml index 86d2884b97..5c1f673ff7 100644 --- a/.github/actions/install-all-dependencies/action.yml +++ b/.github/actions/install-all-dependencies/action.yml @@ -2,7 +2,7 @@ runs: using: "composite" steps: - uses: ./.github/actions/install-essential-dependencies - - run: sudo apt-get update && sudo apt-get install -y libunwind-dev libgoogle-glog-dev automake bison flex libboost-all-dev libevent-dev libtool pkg-config libibverbs1 libibverbs-dev + - run: sudo apt-get update && sudo apt-get install -y libunwind-dev libgoogle-glog-dev automake bison flex libboost-all-dev libevent-dev libtool pkg-config libibverbs-dev shell: bash - run: | wget https://archive.apache.org/dist/thrift/0.11.0/thrift-0.11.0.tar.gz && tar -xf thrift-0.11.0.tar.gz && cd thrift-0.11.0/ diff --git a/.github/workflows/ci-linux.yml b/.github/workflows/ci-linux.yml index 8a36af6024..a334b29126 100644 --- a/.github/workflows/ci-linux.yml +++ b/.github/workflows/ci-linux.yml @@ -29,7 +29,9 @@ jobs: - name: gcc with all options uses: ./.github/actions/compile-with-make with: - options: --headers=/usr/include --libs=/usr/lib /usr/lib64 --cc=gcc --cxx=g++ --werror --with-thrift --with-glog --with-rdma --with-debug-bthread-sche-safety --with-debug-lock --with-bthread-tracer --with-asan + options: --headers=/usr/include --libs=/usr/lib /usr/lib64 --cc=gcc --cxx=g++ --werror \ + --with-thrift --with-glog --with-rdma --with-debug-bthread-sche-safety \ + --with-debug-lock --with-bthread-tracer --with-asan - name: clang with default options uses: ./.github/actions/compile-with-make @@ -39,7 +41,9 @@ jobs: - name: clang with all options uses: ./.github/actions/compile-with-make with: - options: --headers=/usr/include --libs=/usr/lib /usr/lib64 --cc=clang --cxx=clang++ --werror 
--with-thrift --with-glog --with-rdma --with-debug-bthread-sche-safety --with-debug-lock --with-bthread-tracer --with-asan + options: --headers=/usr/include --libs=/usr/lib /usr/lib64 --cc=clang --cxx=clang++ --werror \ + --with-thrift --with-glog --with-rdma --with-debug-bthread-sche-safety \ + --with-debug-lock --with-bthread-tracer --with-asan compile-with-cmake: runs-on: ubuntu-22.04 @@ -57,7 +61,9 @@ jobs: run: | export CC=gcc && export CXX=g++ mkdir gcc_build_all && cd gcc_build_all - cmake -DWITH_MESALINK=OFF -DWITH_GLOG=ON -DWITH_THRIFT=ON -DWITH_RDMA=ON -DWITH_DEBUG_BTHREAD_SCHE_SAFETY=ON -DWITH_DEBUG_LOCK=ON -DWITH_BTHREAD_TRACER=ON -DWITH_ASAN=ON -DCMAKE_POLICY_VERSION_MINIMUM=3.5 .. + cmake -DWITH_MESALINK=OFF -DWITH_GLOG=ON -DWITH_THRIFT=ON -DWITH_RDMA=ON \ + -DWITH_DEBUG_BTHREAD_SCHE_SAFETY=ON -DWITH_DEBUG_LOCK=ON -DWITH_BTHREAD_TRACER=ON \ + -DWITH_ASAN=ON -DCMAKE_POLICY_VERSION_MINIMUM=3.5 .. make -j ${{env.proc_num}} && make clean - name: clang with default options @@ -70,7 +76,9 @@ jobs: run: | export CC=clang && export CXX=clang++ mkdir clang_build_all && cd clang_build_all - cmake -DWITH_MESALINK=OFF -DWITH_GLOG=ON -DWITH_THRIFT=ON -DWITH_RDMA=ON -DWITH_DEBUG_BTHREAD_SCHE_SAFETY=ON -DWITH_DEBUG_LOCK=ON -DWITH_BTHREAD_TRACER=ON -DWITH_ASAN=ON -DCMAKE_POLICY_VERSION_MINIMUM=3.5 .. + cmake -DWITH_MESALINK=OFF -DWITH_GLOG=ON -DWITH_THRIFT=ON -DWITH_RDMA=ON \ + -DWITH_DEBUG_BTHREAD_SCHE_SAFETY=ON -DWITH_DEBUG_LOCK=ON -DWITH_BTHREAD_TRACER=ON \ + -DWITH_ASAN=ON -DCMAKE_POLICY_VERSION_MINIMUM=3.5 .. 
make -j ${{env.proc_num}} && make clean gcc-compile-with-make-protobuf: @@ -160,6 +168,7 @@ jobs: runs-on: ubuntu-22.04 steps: - uses: actions/checkout@v2 + - run: sudo apt-get update && sudo apt-get install -y libibverbs-dev - run: | bazel test --test_output=streamed \ --action_env=CC=clang \ @@ -229,6 +238,7 @@ jobs: USE_BAZEL_VERSION: "8.3.1" steps: - uses: actions/checkout@v2 + - run: sudo apt-get update && sudo apt-get install -y libibverbs-dev - name: Override protobuf version for testing run: | sed -i -E "s/(bazel_dep\(name = ['\"]protobuf['\"], version = ['\"])[^'\"]+/\1${TEST_PROTOBUF_VERSION}/" MODULE.bazel @@ -237,7 +247,6 @@ jobs: grep -qE "bazel_dep\(name = ['\"]protobuf['\"], version = ['\"]${TEST_PROTOBUF_VERSION}['\"]" MODULE.bazel \ || { echo "ERROR: failed to override protobuf version in MODULE.bazel to ${TEST_PROTOBUF_VERSION}"; exit 1; } - run: | - bazel test --action_env=CC=clang \ + bazel test --action_env=CC=clang --config=rdma \ --define with_babylon_counter=true \ - --define with_babylon_counter=true \ - //test:brpc_unittests + //test/... 
--test_arg=--gtest_filter=-RdmaRpcTest.* diff --git a/BUILD.bazel b/BUILD.bazel index 22cb508548..b51ee0f6b0 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -506,6 +506,7 @@ filegroup( srcs = glob([ "src/brpc/*.proto", "src/brpc/policy/*.proto", + "src/brpc/rdma/*.proto", ]), visibility = ["//visibility:public"], ) diff --git a/CMakeLists.txt b/CMakeLists.txt index 5e74007b66..a3ebb855cf 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -555,7 +555,8 @@ set(PROTO_FILES idl_options.proto brpc/policy/mongo.proto brpc/trackme.proto brpc/streaming_rpc_meta.proto - brpc/proto_base.proto) + brpc/proto_base.proto + brpc/rdma/rdma_handshake.proto) file(MAKE_DIRECTORY ${PROJECT_BINARY_DIR}/output/include/brpc) set(PROTOC_FLAGS ${PROTOC_FLAGS} -I${PROTOBUF_INCLUDE_DIR}) compile_proto(PROTO_HDRS PROTO_SRCS ${PROJECT_BINARY_DIR} diff --git a/src/brpc/rdma/rdma_endpoint.cpp b/src/brpc/rdma/rdma_endpoint.cpp index f09d723ca1..658c7a2fcc 100644 --- a/src/brpc/rdma/rdma_endpoint.cpp +++ b/src/brpc/rdma/rdma_endpoint.cpp @@ -31,6 +31,7 @@ #include "brpc/rdma/rdma_helper.h" #include "brpc/rdma/rdma_endpoint.h" #include "brpc/rdma_transport.h" +#include "brpc/rdma/rdma_handshake.h" DECLARE_int32(task_group_ntags); @@ -70,84 +71,30 @@ static const size_t IOBUF_BLOCK_HEADER_LEN = 32; // implementation-dependent // DO NOT change this value unless you know the safe value!!! // This is the number of reserved WRs in SQ/RQ for pure ACK. 
-static const size_t RESERVED_WR_NUM = 3; - -// magic string RDMA (4B) -// message length (2B) -// hello version (2B) -// impl version (2B): 0 means should use tcp -// block size (4B) -// sq size (2B) -// rq size (2B) -// GID (16B) -// QP number (4B) -static const char* MAGIC_STR = "RDMA"; -static const size_t MAGIC_STR_LEN = 4; -static const size_t HELLO_MSG_LEN_MIN = 40; -// static const size_t HELLO_MSG_LEN_MAX = 4096; -static const size_t ACK_MSG_LEN = 4; -static uint16_t g_rdma_hello_msg_len = 40; // In Byte -static uint16_t g_rdma_hello_version = 2; -static uint16_t g_rdma_impl_version = 1; -static uint32_t g_rdma_recv_block_size = 0; +extern const size_t RESERVED_WR_NUM = 3; + +// The local recv block size, set during GlobalInitialize. +uint32_t g_rdma_recv_block_size = 0; // static const uint32_t MAX_INLINE_DATA = 64; static const uint8_t MAX_HOP_LIMIT = 16; static const uint8_t TIMEOUT = 14; static const uint8_t RETRY_CNT = 7; -static const uint16_t MIN_QP_SIZE = 16; +extern const uint16_t MIN_QP_SIZE = 16; static const uint16_t MAX_QP_SIZE = 4096; -static const uint16_t MIN_BLOCK_SIZE = 1024; -static const uint32_t ACK_MSG_RDMA_OK = 0x1; +extern const uint16_t MIN_BLOCK_SIZE = 1024; + +// ACK message wire format (shared by all protocol versions): a single +// 4B big-endian flags word; bit 0 (HELLO_ACK_RDMA_OK) indicates the +// sender wants to use RDMA. The state machines in +// ProcessHandshakeAt{Client,Server} inline the corresponding 4B +// send/recv directly using ReadFromFd / WriteToFd. 
+static const size_t HELLO_ACK_LEN = 4; +static const uint32_t HELLO_ACK_RDMA_OK = 0x1; static butil::Mutex* g_rdma_resource_mutex = NULL; static RdmaResource* g_rdma_resource_list = NULL; -struct HelloMessage { - void Serialize(void* data) const; - void Deserialize(void* data); - - uint16_t msg_len; - uint16_t hello_ver; - uint16_t impl_ver; - uint32_t block_size; - uint16_t sq_size; - uint16_t rq_size; - uint16_t lid; - ibv_gid gid; - uint32_t qp_num; -}; - -void HelloMessage::Serialize(void* data) const { - uint16_t* current_pos = (uint16_t*)data; - *(current_pos++) = butil::HostToNet16(msg_len); - *(current_pos++) = butil::HostToNet16(hello_ver); - *(current_pos++) = butil::HostToNet16(impl_ver); - uint32_t* block_size_pos = (uint32_t*)current_pos; - *block_size_pos = butil::HostToNet32(block_size); - current_pos += 2; // move forward 4 Bytes - *(current_pos++) = butil::HostToNet16(sq_size); - *(current_pos++) = butil::HostToNet16(rq_size); - *(current_pos++) = butil::HostToNet16(lid); - memcpy(current_pos, gid.raw, 16); - uint32_t* qp_num_pos = (uint32_t*)((char*)current_pos + 16); - *qp_num_pos = butil::HostToNet32(qp_num); -} - -void HelloMessage::Deserialize(void* data) { - uint16_t* current_pos = (uint16_t*)data; - msg_len = butil::NetToHost16(*current_pos++); - hello_ver = butil::NetToHost16(*current_pos++); - impl_ver = butil::NetToHost16(*current_pos++); - block_size = butil::NetToHost32(*(uint32_t*)current_pos); - current_pos += 2; // move forward 4 Bytes - sq_size = butil::NetToHost16(*current_pos++); - rq_size = butil::NetToHost16(*current_pos++); - lid = butil::NetToHost16(*current_pos++); - memcpy(gid.raw, current_pos, 16); - qp_num = butil::NetToHost32(*(uint32_t*)((char*)current_pos + 16)); -} - RdmaResource::~RdmaResource() { if (NULL != qp) { IbvDestroyQp(qp); @@ -169,6 +116,7 @@ RdmaResource::~RdmaResource() { RdmaEndpoint::RdmaEndpoint(Socket* s) : _socket(s) , _state(UNINIT) + , _handshake_version(0) , _resource(NULL) , _send_cq_events(0) , 
_recv_cq_events(0) @@ -348,31 +296,34 @@ void RdmaEndpoint::OnNewDataFromTcp(Socket* m) { } } -bool HelloNegotiationValid(HelloMessage& msg) { - if (msg.hello_ver == g_rdma_hello_version && - msg.impl_ver == g_rdma_impl_version && - msg.block_size >= MIN_BLOCK_SIZE && - msg.sq_size >= MIN_QP_SIZE && - msg.rq_size >= MIN_QP_SIZE) { - // This can be modified for future compatibility - return true; - } - return false; -} - static const int WAIT_TIMEOUT_MS = 50; -int RdmaEndpoint::ReadFromFd(void* data, size_t len) { - CHECK(data != NULL); - int nr = 0; +// Drive an EAGAIN-aware read loop to completion (exactly `len` bytes). +// `read_once(offset, remaining)` performs ONE underlying read attempt: +// - returns > 0 : number of bytes consumed (added to running total); +// - returns = 0 : end-of-stream (the loop fails with EEOF); +// - returns < 0 : errno set; EAGAIN is handled here via butex_wait, +// any other errno bubbles up. +// `offset` is bytes already received in THIS call (initially 0); the +// callable uses it to choose the next write target (e.g. `(char*)buf +// + offset`). Callables that don't need offset (e.g. IOPortal append) +// can ignore it. +// +// Centralizes the EAGAIN/butex/EOF loop so the two ReadFromFd +// overloads below stay one-liners; any future read source (memory- +// mapped, scatter-vector, etc.) can plug in by passing its own +// `read_once`. 
+template +static int ReadFromFdLoop(butil::atomic* read_butex, + size_t len, ReadOnce&& read_once) { size_t received = 0; - do { - const int expected_val = _read_butex->load(butil::memory_order_acquire); + while (received < len) { + const int expected_val = read_butex->load(butil::memory_order_acquire); const timespec duetime = butil::milliseconds_from_now(WAIT_TIMEOUT_MS); - nr = read(_socket->fd(), (uint8_t*)data + received, len - received); + ssize_t nr = read_once(received, len - received); if (nr < 0) { if (errno == EAGAIN) { - if (bthread::butex_wait(_read_butex, expected_val, &duetime) < 0) { + if (bthread::butex_wait(read_butex, expected_val, &duetime) < 0) { if (errno != EWOULDBLOCK && errno != ETIMEDOUT) { return -1; } @@ -386,34 +337,89 @@ int RdmaEndpoint::ReadFromFd(void* data, size_t len) { } else { received += nr; } - } while (received < len); + } return 0; } -int RdmaEndpoint::WriteToFd(void* data, size_t len) { +int RdmaEndpoint::ReadFromFd(void* data, size_t len) { + CHECK(data != NULL); + const int fd = _socket->fd(); + return ReadFromFdLoop(_read_butex, len, + [data, fd](size_t offset, size_t remaining) { + return read(fd, (uint8_t*)data + offset, remaining); + }); +} + +int RdmaEndpoint::ReadFromFd(butil::IOPortal* data, size_t len) { CHECK(data != NULL); - int nw = 0; + const int fd = _socket->fd(); + return ReadFromFdLoop(_read_butex, len, + [data, fd](size_t /*offset*/, size_t remaining) { + return data->append_from_file_descriptor(fd, remaining); + }); +} + +// Drive an EAGAIN-aware write loop to completion (exactly `len` bytes). +// +// `write_once(offset, remaining)` performs ONE underlying write attempt: +// - returns >= 0 : number of bytes consumed (added to running total); +// - returns < 0 : errno set; EAGAIN triggers `wait_writable(duetime)`, +// any other errno bubbles up. +// `offset` is bytes already written in THIS call (initially 0); the +// callable uses it to choose the next read source (e.g. `(char*)buf +// + offset`). 
Callables that drain a self-tracking sink (e.g. +// IOBuf::cut_into_file_descriptor) can ignore both args. +// +// `wait_writable(duetime)` is invoked on EAGAIN to park until the fd +// becomes writable again. It returns true on wake-up (or ETIMEDOUT), +// false on hard failure. +template +static int WriteToFdLoop(size_t len, WriteOnce&& write_once, WaitWritable&& wait_writable) { size_t written = 0; - do { + while (written < len) { const timespec duetime = butil::milliseconds_from_now(WAIT_TIMEOUT_MS); - nw = write(_socket->fd(), (uint8_t*)data + written, len - written); - if (nw < 0) { - if (errno == EAGAIN) { - if (_socket->WaitEpollOut(_socket->fd(), true, &duetime) < 0) { - if (errno != ETIMEDOUT) { - return -1; - } - } - } else { - return -1; - } - } else { + ssize_t nw = write_once(written, len - written); + if (nw >= 0) { written += nw; + continue; + } + + if (errno != EAGAIN) { + return -1; } - } while (written < len); + if (!wait_writable(&duetime)) { + return -1; + } + } return 0; } +int RdmaEndpoint::WriteToFd(void* data, size_t len) { + CHECK(data != NULL); + Socket* s = _socket; + const int fd = s->fd(); + return WriteToFdLoop(len, + [data, fd](size_t offset, size_t remaining) { + return write(fd, (uint8_t*)data + offset, remaining); + }, + [s, fd](const timespec* duetime) { + return s->WaitEpollOut(fd, true, duetime) == 0 || errno == ETIMEDOUT; + }); +} + +int RdmaEndpoint::WriteToFd(butil::IOBuf* data) { + CHECK(data != NULL); + Socket* s = _socket; + const int fd = s->fd(); + return WriteToFdLoop(data->size(), + [data, fd](size_t /*offset*/, size_t /*remaining*/) { + return data->cut_into_file_descriptor(fd); + }, + [s, fd](const timespec* duetime) { + return s->WaitEpollOut(fd, true, duetime) == 0 || errno == ETIMEDOUT; + }); +} + inline void RdmaEndpoint::TryReadOnTcp() { if (_socket->_nevent.fetch_add(1, butil::memory_order_acq_rel) == 0) { if (_state == FALLBACK_TCP) { @@ -424,19 +430,52 @@ inline void RdmaEndpoint::TryReadOnTcp() { } } +void
RdmaEndpoint::ApplyRemoteHello(const ParsedHello& remote) { + _remote_recv_block_size = remote.block_size; + _local_window_capacity = + std::min(_sq_size, remote.rq_size) - RESERVED_WR_NUM; + _remote_window_capacity = + std::min(_rq_size, remote.sq_size) - RESERVED_WR_NUM; + _sq_imm_window_size = RESERVED_WR_NUM; + _remote_rq_window_size.store( + _local_window_capacity, butil::memory_order_relaxed); + _sq_window_size.store( + _local_window_capacity, butil::memory_order_relaxed); +} + +// Client-side handshake entry: the state machine. +// +// C_ALLOC_QPCQ +// | +// v +// C_HELLO_SEND (hs->SendLocalHello) +// | +// v +// C_HELLO_WAIT (hs->ReceiveAndParseRemoteHello) +// | +// v +// [negotiation: ApplyRemoteHello + C_BRINGUP_QP] +// | +// v +// C_ACK_SEND +// | +// v +// ESTABLISHED / FALLBACK_TCP void* RdmaEndpoint::ProcessHandshakeAtClient(void* arg) { - RdmaEndpoint* ep = static_cast(arg); + auto ep = static_cast(arg); SocketUniquePtr s(ep->_socket); RdmaConnect::RunGuard rg((RdmaConnect*)s->_app_connect.get()); + auto rdma_transport = static_cast(s->_transport.get()); - LOG_IF(INFO, FLAGS_rdma_trace_verbose) - << "Start handshake on " << s->_local_side; + LOG_IF(INFO, FLAGS_rdma_trace_verbose) + << "Start handshake on " << s->description(); - uint8_t data[g_rdma_hello_msg_len]; + std::unique_ptr handshake = CreateClientHandshake(ep); + CHECK(handshake != NULL); + ep->_handshake_version = handshake->ProtocolVersion(); - // First initialize CQ and QP resources + // First initialize CQ and QP resources. 
ep->_state = C_ALLOC_QPCQ; - auto* rdma_transport = static_cast(s->_transport.get()); if (ep->AllocateResources() < 0) { LOG(WARNING) << "Fallback to tcp:" << s->description(); rdma_transport->_rdma_state = RdmaTransport::RDMA_OFF; @@ -446,94 +485,40 @@ void* RdmaEndpoint::ProcessHandshakeAtClient(void* arg) { // Send hello message to server ep->_state = C_HELLO_SEND; - HelloMessage local_msg; - local_msg.msg_len = g_rdma_hello_msg_len; - local_msg.hello_ver = g_rdma_hello_version; - local_msg.impl_ver = g_rdma_impl_version; - local_msg.block_size = g_rdma_recv_block_size; - local_msg.sq_size = ep->_sq_size; - local_msg.rq_size = ep->_rq_size; - local_msg.lid = GetRdmaLid(); - local_msg.gid = GetRdmaGid(); - if (BAIDU_LIKELY(ep->_resource)) { - local_msg.qp_num = ep->_resource->qp->qp_num; - } else { - // Only happens in UT - local_msg.qp_num = 0; - } - memcpy(data, MAGIC_STR, 4); - local_msg.Serialize((char*)data + 4); - if (ep->WriteToFd(data, g_rdma_hello_msg_len) < 0) { - const int saved_errno = errno; - PLOG(WARNING) << "Fail to send hello message to server:" << s->description(); + if (handshake->SendLocalHello() < 0) { + int saved_errno = errno; + PLOG(WARNING) << "Fail to send hello message to server:" + << s->description(); s->SetFailed(saved_errno, "Fail to complete rdma handshake from %s: %s", - s->description().c_str(), berror(saved_errno)); + s->description().c_str(), berror(saved_errno)); ep->_state = FAILED; return NULL; } - // Check magic str + // Receive and parse remote hello. 
ep->_state = C_HELLO_WAIT; - if (ep->ReadFromFd(data, MAGIC_STR_LEN) < 0) { - const int saved_errno = errno; - PLOG(WARNING) << "Fail to get hello message from server:" << s->description(); + ParsedHello remote{}; + bool negotiated = false; + if (handshake->ReceiveAndParseRemoteHello(&remote, &negotiated) < 0) { + int saved_errno = errno; + PLOG(WARNING) << "Fail to receive hello from server:" + << s->description(); s->SetFailed(saved_errno, "Fail to complete rdma handshake from %s: %s", - s->description().c_str(), berror(saved_errno)); - ep->_state = FAILED; - return NULL; - } - if (memcmp(data, MAGIC_STR, MAGIC_STR_LEN) != 0) { - LOG(WARNING) << "Read unexpected data during handshake:" << s->description(); - s->SetFailed(EPROTO, "Fail to complete rdma handshake from %s: %s", - s->description().c_str(), berror(EPROTO)); + s->description().c_str(), berror(saved_errno)); ep->_state = FAILED; return NULL; } - // Read hello message from server - if (ep->ReadFromFd(data, HELLO_MSG_LEN_MIN - MAGIC_STR_LEN) < 0) { - const int saved_errno = errno; - PLOG(WARNING) << "Fail to get Hello Message from server:" << s->description(); - s->SetFailed(saved_errno, "Fail to complete rdma handshake from %s: %s", - s->description().c_str(), berror(saved_errno)); - ep->_state = FAILED; - return NULL; - } - HelloMessage remote_msg; - remote_msg.Deserialize(data); - if (remote_msg.msg_len < HELLO_MSG_LEN_MIN) { - LOG(WARNING) << "Fail to parse Hello Message length from server:" - << s->description(); - s->SetFailed(EPROTO, "Fail to complete rdma handshake from %s: %s", - s->description().c_str(), berror(EPROTO)); - ep->_state = FAILED; - return NULL; - } - - if (remote_msg.msg_len > HELLO_MSG_LEN_MIN) { - // TODO: Read Hello Message customized data - // Just for future use, should not happen now - } - - if (!HelloNegotiationValid(remote_msg)) { + if (!negotiated) { LOG(WARNING) << "Fail to negotiate with server, fallback to tcp:" << s->description(); rdma_transport->_rdma_state = 
RdmaTransport::RDMA_OFF; } else { - ep->_remote_recv_block_size = remote_msg.block_size; - ep->_local_window_capacity = - std::min(ep->_sq_size, remote_msg.rq_size) - RESERVED_WR_NUM; - ep->_remote_window_capacity = - std::min(ep->_rq_size, remote_msg.sq_size) - RESERVED_WR_NUM; - ep->_sq_imm_window_size = RESERVED_WR_NUM; - ep->_remote_rq_window_size.store( - ep->_local_window_capacity, butil::memory_order_relaxed); - ep->_sq_window_size.store( - ep->_local_window_capacity, butil::memory_order_relaxed); - + ep->ApplyRemoteHello(remote); ep->_state = C_BRINGUP_QP; - if (ep->BringUpQp(remote_msg.lid, remote_msg.gid, remote_msg.qp_num) < 0) { - LOG(WARNING) << "Fail to bringup QP, fallback to tcp:" << s->description(); + if (ep->BringUpQp(remote.lid, remote.gid, remote.qp_num) < 0) { + LOG(WARNING) << "Fail to bringup QP, fallback to tcp:" + << s->description(); rdma_transport->_rdma_state = RdmaTransport::RDMA_OFF; } else { rdma_transport->_rdma_state = RdmaTransport::RDMA_ON; @@ -542,28 +527,26 @@ void* RdmaEndpoint::ProcessHandshakeAtClient(void* arg) { // Send ACK message to server ep->_state = C_ACK_SEND; - uint32_t flags = 0; - if (rdma_transport->_rdma_state != RdmaTransport::RDMA_OFF) { - flags |= ACK_MSG_RDMA_OK; - } - uint32_t* tmp = (uint32_t*)data; // avoid GCC warning on strict-aliasing - *tmp = butil::HostToNet32(flags); - if (ep->WriteToFd(data, ACK_MSG_LEN) < 0) { - const int saved_errno = errno; - PLOG(WARNING) << "Fail to send Ack Message to server:" << s->description(); + uint32_t flags = rdma_transport->_rdma_state != RdmaTransport::RDMA_OFF ? 
HELLO_ACK_RDMA_OK : 0; + uint32_t flags_be = butil::HostToNet32(flags); + if (ep->WriteToFd(&flags_be, HELLO_ACK_LEN) < 0) { + int saved_errno = errno; + PLOG(WARNING) << "Fail to send Ack Message to server:" + << s->description(); s->SetFailed(saved_errno, "Fail to complete rdma handshake from %s: %s", - s->description().c_str(), berror(saved_errno)); + s->description().c_str(), berror(saved_errno)); ep->_state = FAILED; return NULL; } if (rdma_transport->_rdma_state == RdmaTransport::RDMA_ON) { ep->_state = ESTABLISHED; - LOG_IF(INFO, FLAGS_rdma_trace_verbose) - << "Client handshake ends (use rdma) on " << s->description(); + LOG_IF(INFO, FLAGS_rdma_trace_verbose) + << "Client handshake ends (use rdma v" << ep->_handshake_version + << ") on " << s->description(); } else { ep->_state = FALLBACK_TCP; - LOG_IF(INFO, FLAGS_rdma_trace_verbose) + LOG_IF(INFO, FLAGS_rdma_trace_verbose) << "Client handshake ends (use tcp) on " << s->description(); } @@ -572,77 +555,75 @@ void* RdmaEndpoint::ProcessHandshakeAtClient(void* arg) { return NULL; } +// Server-side handshake entry: the state machine. 
+// +// S_HELLO_WAIT (read magic + dispatch + hs->ReceiveAndParseRemoteHello) +// | +// v +// [negotiation: ApplyRemoteHello + S_ALLOC_QPCQ + S_BRINGUP_QP] +// | +// v +// S_HELLO_SEND (hs->SendLocalHello) +// | +// v +// S_ACK_WAIT +// | +// v +// ESTABLISHED / FALLBACK_TCP void* RdmaEndpoint::ProcessHandshakeAtServer(void* arg) { - RdmaEndpoint* ep = static_cast(arg); + auto ep = static_cast(arg); SocketUniquePtr s(ep->_socket); + auto rdma_transport = static_cast(s->_transport.get()); - LOG_IF(INFO, FLAGS_rdma_trace_verbose) + LOG_IF(INFO, FLAGS_rdma_trace_verbose) << "Start handshake on " << s->description(); - uint8_t data[g_rdma_hello_msg_len]; - ep->_state = S_HELLO_WAIT; - if (ep->ReadFromFd(data, MAGIC_STR_LEN) < 0) { - const int saved_errno = errno; - PLOG(WARNING) << "Fail to read Hello Message from client:" << s->description() << " " << s->_remote_side; + uint8_t magic[MAGIC_STR_LEN]; + if (ep->ReadFromFd(magic, MAGIC_STR_LEN) < 0) { + int saved_errno = errno; + PLOG(WARNING) << "Fail to read Hello Message from client:" + << s->description() << " " << s->_remote_side; s->SetFailed(saved_errno, "Fail to complete rdma handshake from %s: %s", - s->description().c_str(), berror(saved_errno)); + s->description().c_str(), berror(saved_errno)); ep->_state = FAILED; return NULL; } - auto* rdma_transport = static_cast(s->_transport.get()); - if (memcmp(data, MAGIC_STR, MAGIC_STR_LEN) != 0) { - LOG_IF(INFO, FLAGS_rdma_trace_verbose) << "It seems that the " - << "client does not use RDMA, fallback to TCP:" + + // Dispatch on magic, or fall back to TCP + std::unique_ptr handshake = CreateServerHandshakeByMagic(ep, magic); + if (!handshake) { + LOG_IF(INFO, FLAGS_rdma_trace_verbose) + << "It seems that the client does not use RDMA, fallback to TCP:" << s->description(); - // we need to copy data read back to _socket->_read_buf - s->_read_buf.append(data, MAGIC_STR_LEN); + // We need to copy data read back to _socket->_read_buf. 
+ s->_read_buf.append(magic, MAGIC_STR_LEN); ep->_state = FALLBACK_TCP; rdma_transport->_rdma_state = RdmaTransport::RDMA_OFF; ep->TryReadOnTcp(); return NULL; } + ep->_handshake_version = handshake->ProtocolVersion(); - if (ep->ReadFromFd(data, g_rdma_hello_msg_len - MAGIC_STR_LEN) < 0) { - const int saved_errno = errno; - PLOG(WARNING) << "Fail to read Hello Message from client:" << s->description(); + // Magic was already consumed above; the subclass MUST NOT re-read it. + ParsedHello remote{}; + bool negotiated = false; + if (handshake->ReceiveAndParseRemoteHello(&remote, &negotiated) < 0) { + int saved_errno = errno; + PLOG(WARNING) << "Fail to receive hello from client:" + << s->description(); s->SetFailed(saved_errno, "Fail to complete rdma handshake from %s: %s", - s->description().c_str(), berror(saved_errno)); + s->description().c_str(), berror(saved_errno)); ep->_state = FAILED; return NULL; } - HelloMessage remote_msg; - remote_msg.Deserialize(data); - if (remote_msg.msg_len < HELLO_MSG_LEN_MIN) { - LOG(WARNING) << "Fail to parse Hello Message length from client:" - << s->description(); - s->SetFailed(EPROTO, "Fail to complete rdma handshake from %s: %s", - s->description().c_str(), berror(EPROTO)); - ep->_state = FAILED; - return NULL; - } - if (remote_msg.msg_len > HELLO_MSG_LEN_MIN) { - // TODO: Read Hello Message customized header - // Just for future use, should not happen now - } - - if (!HelloNegotiationValid(remote_msg)) { + if (!negotiated) { LOG(WARNING) << "Fail to negotiate with client, fallback to tcp:" << s->description(); rdma_transport->_rdma_state = RdmaTransport::RDMA_OFF; } else { - ep->_remote_recv_block_size = remote_msg.block_size; - ep->_local_window_capacity = - std::min(ep->_sq_size, remote_msg.rq_size) - RESERVED_WR_NUM; - ep->_remote_window_capacity = - std::min(ep->_rq_size, remote_msg.sq_size) - RESERVED_WR_NUM; - ep->_sq_imm_window_size = RESERVED_WR_NUM; - ep->_remote_rq_window_size.store( - ep->_local_window_capacity, 
butil::memory_order_relaxed); - ep->_sq_window_size.store( - ep->_local_window_capacity, butil::memory_order_relaxed); - + ep->ApplyRemoteHello(remote); ep->_state = S_ALLOC_QPCQ; if (ep->AllocateResources() < 0) { LOG(WARNING) << "Fail to allocate rdma resources, fallback to tcp:" @@ -650,7 +631,7 @@ void* RdmaEndpoint::ProcessHandshakeAtServer(void* arg) { rdma_transport->_rdma_state = RdmaTransport::RDMA_OFF; } else { ep->_state = S_BRINGUP_QP; - if (ep->BringUpQp(remote_msg.lid, remote_msg.gid, remote_msg.qp_num) < 0) { + if (ep->BringUpQp(remote.lid, remote.gid, remote.qp_num) < 0) { LOG(WARNING) << "Fail to bringup QP, fallback to tcp:" << s->description(); rdma_transport->_rdma_state = RdmaTransport::RDMA_OFF; @@ -658,73 +639,55 @@ void* RdmaEndpoint::ProcessHandshakeAtServer(void* arg) { } } - // Send hello message to client ep->_state = S_HELLO_SEND; - HelloMessage local_msg; - local_msg.msg_len = g_rdma_hello_msg_len; - if (rdma_transport->_rdma_state == RdmaTransport::RDMA_OFF) { - local_msg.impl_ver = 0; - local_msg.hello_ver = 0; - } else { - local_msg.lid = GetRdmaLid(); - local_msg.gid = GetRdmaGid(); - local_msg.block_size = g_rdma_recv_block_size; - local_msg.sq_size = ep->_sq_size; - local_msg.rq_size = ep->_rq_size; - local_msg.hello_ver = g_rdma_hello_version; - local_msg.impl_ver = g_rdma_impl_version; - if (BAIDU_LIKELY(ep->_resource)) { - local_msg.qp_num = ep->_resource->qp->qp_num; - } else { - // Only happens in UT - local_msg.qp_num = 0; - } - } - memcpy(data, MAGIC_STR, 4); - local_msg.Serialize((char*)data + 4); - if (ep->WriteToFd(data, g_rdma_hello_msg_len) < 0) { - const int saved_errno = errno; - PLOG(WARNING) << "Fail to send Hello Message to client:" << s->description(); + if (handshake->SendLocalHello() < 0) { + int saved_errno = errno; + PLOG(WARNING) << "Fail to send Hello Message to client:" + << s->description(); s->SetFailed(saved_errno, "Fail to complete rdma handshake from %s: %s", - s->description().c_str(), 
berror(saved_errno)); + s->description().c_str(), berror(saved_errno)); ep->_state = FAILED; return NULL; } - // Recv ACK Message ep->_state = S_ACK_WAIT; - if (ep->ReadFromFd(data, ACK_MSG_LEN) < 0) { - const int saved_errno = errno; - PLOG(WARNING) << "Fail to read ack message from client:" << s->description(); + uint32_t flags_be = 0; + if (ep->ReadFromFd(&flags_be, HELLO_ACK_LEN) < 0) { + int saved_errno = errno; + PLOG(WARNING) << "Fail to read ack message from client:" + << s->description(); s->SetFailed(saved_errno, "Fail to complete rdma handshake from %s: %s", - s->description().c_str(), berror(saved_errno)); + s->description().c_str(), berror(saved_errno)); ep->_state = FAILED; return NULL; } + uint32_t flags = butil::NetToHost32(flags_be); + bool client_ack_ok = (flags & HELLO_ACK_RDMA_OK) != 0; - // Check RDMA enable flag - uint32_t* tmp = (uint32_t*)data; // avoid GCC warning on strict-aliasing - uint32_t flags = butil::NetToHost32(*tmp); - if (flags & ACK_MSG_RDMA_OK) { + if (client_ack_ok) { if (rdma_transport->_rdma_state == RdmaTransport::RDMA_OFF) { - LOG(WARNING) << "Fail to parse Hello Message length from client:" - << s->description(); + // Client asked for RDMA but we are falling back: protocol + // breakdown, abort the connection so the client sees a + // clean error rather than a half-up RDMA channel. 
+ LOG(WARNING) << "Client wants RDMA in ACK but server is in " + << "RDMA_OFF state: " << s->description(); s->SetFailed(EPROTO, "Fail to complete rdma handshake from %s: %s", - s->description().c_str(), berror(EPROTO)); + s->description().c_str(), berror(EPROTO)); ep->_state = FAILED; return NULL; - } else { - rdma_transport->_rdma_state = RdmaTransport::RDMA_ON; - ep->_state = ESTABLISHED; - LOG_IF(INFO, FLAGS_rdma_trace_verbose) - << "Server handshake ends (use rdma) on " << s->description(); } + rdma_transport->_rdma_state = RdmaTransport::RDMA_ON; + ep->_state = ESTABLISHED; + LOG_IF(INFO, FLAGS_rdma_trace_verbose) + << "Server handshake ends (use rdma v" << ep->_handshake_version + << ") on " << s->description(); } else { rdma_transport->_rdma_state = RdmaTransport::RDMA_OFF; ep->_state = FALLBACK_TCP; - LOG_IF(INFO, FLAGS_rdma_trace_verbose) + LOG_IF(INFO, FLAGS_rdma_trace_verbose) << "Server handshake ends (use tcp) on " << s->description(); } + ep->TryReadOnTcp(); return NULL; @@ -1076,7 +1039,7 @@ int RdmaEndpoint::PostRecv(uint32_t num, bool zerocopy) { PLOG(WARNING) << "Fail to allocate rbuf"; return -1; } else { - CHECK(static_cast(size) == g_rdma_recv_block_size) << size; + CHECK_EQ(static_cast(size), g_rdma_recv_block_size); } } if (DoPostRecv(_rbuf_data[_rq_received], g_rdma_recv_block_size) < 0) { @@ -1645,6 +1608,7 @@ std::string RdmaEndpoint::GetStateStr() const { void RdmaEndpoint::DebugInfo(std::ostream& os, butil::StringPiece connector) const { os << "rdma_state=ON" << connector << "handshake_state=" << GetStateStr() + << connector << "handshake_version=" << static_cast(_handshake_version) << connector << "rdma_sq_imm_window_size=" << _sq_imm_window_size << connector << "rdma_remote_rq_window_size=" << _remote_rq_window_size.load(butil::memory_order_relaxed) << connector << "rdma_sq_window_size=" << _sq_window_size.load(butil::memory_order_relaxed) diff --git a/src/brpc/rdma/rdma_endpoint.h b/src/brpc/rdma/rdma_endpoint.h index 
54a008f1f7..7b6652bc86 100644 --- a/src/brpc/rdma/rdma_endpoint.h +++ b/src/brpc/rdma/rdma_endpoint.h @@ -40,6 +40,24 @@ DECLARE_bool(rdma_use_polling); DECLARE_int32(rdma_poller_num); DECLARE_bool(rdma_disable_bthread); +class RdmaHandshakeClientV2; +class RdmaHandshakeServerV2; +class RdmaHandshakeClientV3; +class RdmaHandshakeServerV3; +struct ParsedHello; +class RdmaHello; +class RdmaEndpoint; +namespace v2_wire { + int ReadBodyAndNegotiate(RdmaEndpoint* ep, ParsedHello* remote, bool* negotiated); + int DrainBytes(RdmaEndpoint* ep, size_t n); +} // namespace v2_wire + +namespace v3_wire { + void FillLocalRdmaHello(const RdmaEndpoint* ep, RdmaHello* msg); + int ReadAndParseV3Hello(RdmaEndpoint* ep, RdmaHello* out); + int WriteV3Hello(RdmaEndpoint* ep, const RdmaHello& msg); +} // namespace v3_wire + class RdmaConnect : public AppConnect { public: void StartConnect(const Socket* socket, @@ -74,6 +92,15 @@ struct RdmaResource { class BAIDU_CACHELINE_ALIGNMENT RdmaEndpoint : public SocketUser { friend class RdmaConnect; friend class Socket; +friend class RdmaHandshakeClientV2; +friend class RdmaHandshakeServerV2; +friend class RdmaHandshakeClientV3; +friend class RdmaHandshakeServerV3; +friend int v2_wire::ReadBodyAndNegotiate(RdmaEndpoint*, ParsedHello*, bool*); +friend int v2_wire::DrainBytes(RdmaEndpoint*, size_t); +friend void v3_wire::FillLocalRdmaHello(const RdmaEndpoint*, RdmaHello*); +friend int v3_wire::ReadAndParseV3Hello(RdmaEndpoint*, RdmaHello*); +friend int v3_wire::WriteV3Hello(RdmaEndpoint*, const RdmaHello&); public: explicit RdmaEndpoint(Socket* s); ~RdmaEndpoint() override; @@ -181,6 +208,7 @@ friend class Socket; // wait for _read_butex if encounter EAGAIN // return -1 if encounter other errno (including EOF) int ReadFromFd(void* data, size_t len); + int ReadFromFd(butil::IOPortal* data, size_t len); // Write at most len bytes from data to fd in _socket @@ -188,6 +216,17 @@ friend class Socket; // return -1 if encounter other errno int 
WriteToFd(void* data, size_t len); + // Write data to fd in _socket. + // wait for _epollout_butex if encounter EAGAIN. + // return -1 if encounter other errno. + int WriteToFd(butil::IOBuf* data); + + // Copy negotiated remote parameters into the endpoint and compute + // the SQ/RQ window capacities. Called by both + // ProcessHandshakeAtClient and ProcessHandshakeAtServer after the + // peer's hello has been validated. + void ApplyRemoteHello(const ParsedHello& remote); + // Bringup the QP from RESET state to RTS state // Arguments: // lid: remote LID @@ -225,6 +264,13 @@ friend class Socket; // State of Handshake State _state; + // Wire-level handshake protocol version (set by dispatch in + // ProcessHandshakeAtClient/Server). Aligned with the protocol code: + // 0 = unnegotiated + // 2 = v2 "RDMA" + // 3 = v3 "RDM3" + int _handshake_version; + // rdma resource RdmaResource* _resource; diff --git a/src/brpc/rdma/rdma_handshake.cpp b/src/brpc/rdma/rdma_handshake.cpp new file mode 100644 index 0000000000..9bd2312ec4 --- /dev/null +++ b/src/brpc/rdma/rdma_handshake.cpp @@ -0,0 +1,408 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#if BRPC_WITH_RDMA + +#include "brpc/rdma/rdma_handshake.h" + +#include +#include // std::min +#include +#include +#include +#include "butil/iobuf.h" // IOBuf, IOPortal, IOBufAsZeroCopy*Stream +#include "butil/sys_byteorder.h" +#include "brpc/socket.h" +#include "brpc/rdma/rdma_endpoint.h" +#include "brpc/rdma/rdma_helper.h" +#include "brpc/rdma_transport.h" +#include "brpc/rdma/rdma_handshake.pb.h" + +namespace brpc { +namespace rdma { + +DEFINE_int32(rdma_client_handshake_version, 2, + "RDMA handshake protocol version used by client. " + "2 = legacy 'RDMA' magic (default, compatible with all servers); " + "3 = new 'RDM3' protobuf-based handshake " + "(MUST only be enabled after target servers support v3)."); + +extern const uint16_t MIN_QP_SIZE; +extern const uint16_t MIN_BLOCK_SIZE; +extern uint32_t g_rdma_recv_block_size; +extern bool g_skip_rdma_init; + +// Wire-level constants for the v2 handshake. +static const char* MAGIC_STR = "RDMA"; +static constexpr uint16_t RDMA_HELLO_V2_MSG_LEN = 40; // In Byte +extern const uint16_t RDMA_HELLO_V2_VERSION = 2; +extern const uint16_t RDMA_IMPL_V2_VERSION = 1; + +// Wire-level constants for the v3 handshake. 
+static const char* MAGIC_STR_V3 = "RDM3"; +static const size_t RDMA_HELLO_V3_PB_SIZE_LEN = 4; +static const size_t RDMA_HELLO_V3_MAX_PB_SIZE = 4096; + +namespace v2_wire { + +void HelloMessage::Serialize(void* data) const { + uint16_t* current_pos = (uint16_t*)data; + *(current_pos++) = butil::HostToNet16(msg_len); + *(current_pos++) = butil::HostToNet16(hello_ver); + *(current_pos++) = butil::HostToNet16(impl_ver); + uint32_t* block_size_pos = (uint32_t*)current_pos; + *block_size_pos = butil::HostToNet32(block_size); + current_pos += 2; // move forward 4 Bytes + *(current_pos++) = butil::HostToNet16(sq_size); + *(current_pos++) = butil::HostToNet16(rq_size); + *(current_pos++) = butil::HostToNet16(lid); + fast_memcpy(current_pos, gid.raw, 16); + uint32_t* qp_num_pos = (uint32_t*)((char*)current_pos + 16); + *qp_num_pos = butil::HostToNet32(qp_num); +} + +void HelloMessage::Deserialize(void* data) { + uint16_t* current_pos = (uint16_t*)data; + msg_len = butil::NetToHost16(*current_pos++); + hello_ver = butil::NetToHost16(*current_pos++); + impl_ver = butil::NetToHost16(*current_pos++); + block_size = butil::NetToHost32(*(uint32_t*)current_pos); + current_pos += 2; // move forward 4 Bytes + sq_size = butil::NetToHost16(*current_pos++); + rq_size = butil::NetToHost16(*current_pos++); + lid = butil::NetToHost16(*current_pos++); + fast_memcpy(gid.raw, current_pos, 16); + qp_num = butil::NetToHost32(*(uint32_t*)((char*)current_pos + 16)); +} + +static bool ValidHelloMessage(const HelloMessage& msg) { + return msg.hello_ver == RDMA_HELLO_V2_VERSION && + msg.impl_ver == RDMA_IMPL_V2_VERSION && + msg.block_size >= MIN_BLOCK_SIZE && + msg.sq_size >= MIN_QP_SIZE && + msg.rq_size >= MIN_QP_SIZE; +} + +static void TranslateV2Hello(const HelloMessage& msg, ParsedHello* out) { + out->block_size = msg.block_size; + out->sq_size = msg.sq_size; + out->rq_size = msg.rq_size; + out->lid = msg.lid; + out->gid = msg.gid; + out->qp_num = msg.qp_num; +} + +int 
ReadBodyAndNegotiate(RdmaEndpoint* ep, ParsedHello* remote, bool* negotiated) { + uint8_t data[HELLO_MSG_LEN_MIN]; + if (ep->ReadFromFd(data, HELLO_MSG_LEN_MIN - MAGIC_STR_LEN) < 0) { + return -1; + } + HelloMessage remote_msg{}; + remote_msg.Deserialize(data); + if (remote_msg.msg_len < HELLO_MSG_LEN_MIN || + remote_msg.msg_len > HELLO_MSG_LEN_MAX) { + errno = EPROTO; + return -1; + } + if (remote_msg.msg_len > HELLO_MSG_LEN_MIN) { + // Drain unknown trailing bytes so they don't pollute subsequent + // reads (e.g. the upcoming ACK message). v2 base fields already + // carry enough information for negotiation; unknown trailing + // bytes are treated as optional hints that v2 safely ignores. + size_t ext_len = remote_msg.msg_len - HELLO_MSG_LEN_MIN; + if (DrainBytes(ep, ext_len) < 0) { + return -1; + } + } + if (!ValidHelloMessage(remote_msg)) { + *negotiated = false; + return 0; + } + *negotiated = true; + TranslateV2Hello(remote_msg, remote); + return 0; +} + +int DrainBytes(RdmaEndpoint* ep, size_t n) { + uint8_t scratch[64]; + while (n > 0) { + size_t chunk = std::min(n, sizeof(scratch)); + if (ep->ReadFromFd(scratch, chunk) < 0) { + return -1; + } + n -= chunk; + } + return 0; +} + +} // namespace v2_wire + +int RdmaHandshakeClientV2::SendLocalHello() { + RdmaEndpoint* ep = _ep; + uint8_t data[RDMA_HELLO_V2_MSG_LEN]; + + v2_wire::HelloMessage local_msg{}; + local_msg.msg_len = RDMA_HELLO_V2_MSG_LEN; + local_msg.hello_ver = RDMA_HELLO_V2_VERSION; + local_msg.impl_ver = RDMA_IMPL_V2_VERSION; + local_msg.block_size = g_rdma_recv_block_size; + local_msg.sq_size = ep->_sq_size; + local_msg.rq_size = ep->_rq_size; + local_msg.lid = GetRdmaLid(); + local_msg.gid = GetRdmaGid(); + if (BAIDU_LIKELY(ep->_resource)) { + local_msg.qp_num = ep->_resource->qp->qp_num; + } else { + // Only happens in UT + local_msg.qp_num = 0; + } + fast_memcpy(data, MAGIC_STR, 4); + local_msg.Serialize((char*)data + 4); + return ep->WriteToFd(data, RDMA_HELLO_V2_MSG_LEN); +} + +int 
RdmaHandshakeClientV2::ReceiveAndParseRemoteHello(ParsedHello* remote, + bool* negotiated) { + RdmaEndpoint* ep = _ep; + + // Read and verify magic (the endpoint did NOT pre-read magic on the client side). + uint8_t magic[MAGIC_STR_LEN]; + if (ep->ReadFromFd(magic, MAGIC_STR_LEN) < 0) { + return -1; + } + if (memcmp(magic, MAGIC_STR, MAGIC_STR_LEN) != 0) { + errno = EPROTO; + return -1; + } + return v2_wire::ReadBodyAndNegotiate(ep, remote, negotiated); +} + +int RdmaHandshakeServerV2::ReceiveAndParseRemoteHello(ParsedHello* remote, bool* negotiated) { + // Magic already consumed by ProcessHandshakeAtServer. + return v2_wire::ReadBodyAndNegotiate(_ep, remote, negotiated); +} + +int RdmaHandshakeServerV2::SendLocalHello() { + uint8_t data[RDMA_HELLO_V2_MSG_LEN]; + v2_wire::HelloMessage local_msg{}; + local_msg.msg_len = RDMA_HELLO_V2_MSG_LEN; + auto rdma_transport = static_cast(_ep->_socket->_transport.get()); + if (rdma_transport->_rdma_state == RdmaTransport::RDMA_OFF) { + local_msg.hello_ver = 0; + local_msg.impl_ver = 0; + local_msg.block_size = 0; + local_msg.sq_size = 0; + local_msg.rq_size = 0; + local_msg.lid = 0; + memset(local_msg.gid.raw, 0, sizeof(local_msg.gid.raw)); + local_msg.qp_num = 0; + } else { + local_msg.hello_ver = RDMA_HELLO_V2_VERSION; + local_msg.impl_ver = RDMA_IMPL_V2_VERSION; + local_msg.block_size = g_rdma_recv_block_size; + local_msg.sq_size = _ep->_sq_size; + local_msg.rq_size = _ep->_rq_size; + local_msg.lid = GetRdmaLid(); + local_msg.gid = GetRdmaGid(); + if (BAIDU_LIKELY(_ep->_resource)) { + local_msg.qp_num = _ep->_resource->qp->qp_num; + } else { + // Only happens in UT + local_msg.qp_num = 0; + } + } + fast_memcpy(data, MAGIC_STR, 4); + local_msg.Serialize((char*)data + 4); + return _ep->WriteToFd(data, RDMA_HELLO_V2_MSG_LEN); +} + +namespace v3_wire { + +bool ValidRdmaHello(const RdmaHello& msg) { + if (msg.gid().size() != sizeof(ibv_gid)) { + return false; + } + // ParsedHello stores these as uint16_t; reject values that 
would truncate. + constexpr uint16_t MAX_UINT16 = std::numeric_limits::max(); + if (msg.sq_size() > MAX_UINT16 || msg.rq_size() > MAX_UINT16 || msg.lid() > MAX_UINT16) { + return false; + } + if (msg.block_size() < MIN_BLOCK_SIZE) { + return false; + } + if (msg.sq_size() < MIN_QP_SIZE) { + return false; + } + if (msg.rq_size() < MIN_QP_SIZE) { + return false; + } + // qp_num == 0 only happens in UT (no real QP allocated). + if (msg.qp_num() == 0 && !g_skip_rdma_init) { + return false; + } + return true; +} + +void FillLocalRdmaHello(const RdmaEndpoint* ep, RdmaHello* msg) { + msg->set_block_size(g_rdma_recv_block_size); + msg->set_sq_size(ep->_sq_size); + msg->set_rq_size(ep->_rq_size); + msg->set_lid(GetRdmaLid()); + ibv_gid gid = GetRdmaGid(); + msg->set_gid(std::string(reinterpret_cast(gid.raw), + sizeof(gid.raw))); + if (BAIDU_LIKELY(ep->_resource)) { + msg->set_qp_num(ep->_resource->qp->qp_num); + } else { + // Only happens in UT + msg->set_qp_num(0); + } +} + +int ReadAndParseV3Hello(RdmaEndpoint* ep, RdmaHello* out) { + uint8_t size_buf[RDMA_HELLO_V3_PB_SIZE_LEN]; + if (ep->ReadFromFd(size_buf, RDMA_HELLO_V3_PB_SIZE_LEN) < 0) { + return -1; + } + uint32_t pb_size = butil::NetToHost32( + *reinterpret_cast(size_buf)); + if (pb_size == 0 || pb_size > RDMA_HELLO_V3_MAX_PB_SIZE) { + errno = EPROTO; + return -1; + } + butil::IOPortal body; + if (ep->ReadFromFd(&body, pb_size) < 0) { + return -1; + } + + butil::IOBufAsZeroCopyInputStream input(body); + if (!out->ParseFromZeroCopyStream(&input)) { + LOG(ERROR) << "Failed to parse RdmaHello"; + errno = EPROTO; + return -1; + } + return 0; +} + +int WriteV3Hello(RdmaEndpoint* ep, const RdmaHello& msg) { + uint32_t pb_size = static_cast(msg.ByteSizeLong()); + if (pb_size > RDMA_HELLO_V3_MAX_PB_SIZE) { + errno = EPROTO; + return -1; + } + + // [ "RDM3" 4B ][ pb_size 4B (big-endian) ][ RdmaHello protobuf bytes ] + butil::IOBuf packet; + packet.append(MAGIC_STR_V3, MAGIC_STR_LEN); + uint32_t pb_size_be = 
butil::HostToNet32(pb_size); + packet.append(&pb_size_be, RDMA_HELLO_V3_PB_SIZE_LEN); + butil::IOBufAsZeroCopyOutputStream output(&packet); + if (!msg.SerializeToZeroCopyStream(&output)) { + LOG(ERROR) << "Failed to serialize RdmaHello"; + errno = EPROTO; + return -1; + } + return ep->WriteToFd(&packet); +} + +void TranslateHello(const RdmaHello& msg, ParsedHello* out) { + out->block_size = msg.block_size(); + out->sq_size = static_cast(msg.sq_size()); + out->rq_size = static_cast(msg.rq_size()); + out->lid = static_cast(msg.lid()); + fast_memcpy(out->gid.raw, msg.gid().data(), sizeof(out->gid.raw)); + out->qp_num = msg.qp_num(); +} + +} // namespace v3_wire + +int RdmaHandshakeClientV3::SendLocalHello() { + RdmaHello local_msg{}; + v3_wire::FillLocalRdmaHello(_ep, &local_msg); + return v3_wire::WriteV3Hello(_ep, local_msg); +} + +int RdmaHandshakeClientV3::ReceiveAndParseRemoteHello(ParsedHello* remote, + bool* negotiated) { + uint8_t magic[MAGIC_STR_LEN]; + if (_ep->ReadFromFd(magic, MAGIC_STR_LEN) < 0) { + return -1; + } + if (memcmp(magic, MAGIC_STR_V3, MAGIC_STR_LEN) != 0) { + errno = EPROTO; + return -1; + } + + RdmaHello remote_msg{}; + if (v3_wire::ReadAndParseV3Hello(_ep, &remote_msg) < 0) { + return -1; + } + if (!v3_wire::ValidRdmaHello(remote_msg)) { + *negotiated = false; + return 0; + } + *negotiated = true; + v3_wire::TranslateHello(remote_msg, remote); + return 0; +} + +int RdmaHandshakeServerV3::ReceiveAndParseRemoteHello(ParsedHello* remote, bool* negotiated) { + // Magic already consumed by ProcessHandshakeAtServer. 
+ RdmaHello remote_msg{}; + if (v3_wire::ReadAndParseV3Hello(_ep, &remote_msg) < 0) { + return -1; + } + if (!v3_wire::ValidRdmaHello(remote_msg)) { + *negotiated = false; + return 0; + } + *negotiated = true; + v3_wire::TranslateHello(remote_msg, remote); + return 0; +} + +int RdmaHandshakeServerV3::SendLocalHello() { + RdmaHello local_msg{}; + v3_wire::FillLocalRdmaHello(_ep, &local_msg); + return v3_wire::WriteV3Hello(_ep, local_msg); +} + +std::unique_ptr CreateClientHandshake(RdmaEndpoint* ep) { + switch (FLAGS_rdma_client_handshake_version) { + case 3: + return std::unique_ptr(new RdmaHandshakeClientV3(ep)); + case 2: + default: + return std::unique_ptr(new RdmaHandshakeClientV2(ep)); + } +} + +std::unique_ptr CreateServerHandshakeByMagic( + RdmaEndpoint* ep, const uint8_t magic[MAGIC_STR_LEN]) { + if (memcmp(magic, MAGIC_STR, MAGIC_STR_LEN) == 0) { + return std::unique_ptr(new RdmaHandshakeServerV2(ep)); + } + if (memcmp(magic, MAGIC_STR_V3, MAGIC_STR_LEN) == 0) { + return std::unique_ptr(new RdmaHandshakeServerV3(ep)); + } + return nullptr; +} + +} // namespace rdma +} // namespace brpc + +#endif // BRPC_WITH_RDMA diff --git a/src/brpc/rdma/rdma_handshake.h b/src/brpc/rdma/rdma_handshake.h new file mode 100644 index 0000000000..5f36a9e6e2 --- /dev/null +++ b/src/brpc/rdma/rdma_handshake.h @@ -0,0 +1,192 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#ifndef BRPC_RDMA_HANDSHAKE_H +#define BRPC_RDMA_HANDSHAKE_H + +#if BRPC_WITH_RDMA + +#include +#include +#include +#include +#include "butil/macros.h" + +namespace brpc { +namespace rdma { + +class RdmaEndpoint; + +// Length of the RDMA handshake magic string (e.g. "RDMA", "RDM3"). +static const size_t MAGIC_STR_LEN = 4; + +// Wire-format-agnostic representation of a peer's hello message. +// Each protocol version (v2 binary, v3 protobuf) translates its own +// wire format into this struct so the state-machine driver in +// RdmaEndpoint::ProcessHandshakeAt{Client,Server} stays free of any +// wire-format details. +struct ParsedHello { + uint32_t block_size; + uint16_t sq_size; + uint16_t rq_size; + uint16_t lid; + ibv_gid gid; + uint32_t qp_num; +}; + +namespace v2_wire { + +// Wire constants for the v2 hello. +// +// HELLO_MSG_LEN_MIN: total length of the base v2 hello (4B magic + +// 36B HelloMessage). Anything shorter than this is malformed. +// HELLO_MSG_LEN_MAX: upper bound for the entire v2 hello message +// length declared by HelloMessage::msg_len. Anything beyond this is +// treated as a protocol error and the connection is closed without +// attempting to drain. +static constexpr size_t HELLO_MSG_LEN_MIN = 40; +static constexpr size_t HELLO_MSG_LEN_MAX = 4096; + +// v2 binary HelloMessage. 
+struct HelloMessage { + void Serialize(void* data) const; + void Deserialize(void* data); + + uint16_t msg_len; + uint16_t hello_ver; + uint16_t impl_ver; + uint32_t block_size; + uint16_t sq_size; + uint16_t rq_size; + uint16_t lid; + ibv_gid gid; + uint32_t qp_num; +}; + +} // namespace v2_wire + +// Abstract base class of an RDMA handshake. +// +// Acts as the protocol-version dispatch point for the state machine +// driven by RdmaEndpoint::ProcessHandshakeAt{Client,Server}. +class RdmaHandshake { +public: + explicit RdmaHandshake(RdmaEndpoint* ep) : _ep(ep) {} + virtual ~RdmaHandshake() = default; + + DISALLOW_COPY_AND_ASSIGN(RdmaHandshake); + + // Wire-level protocol version (2 for "RDMA", 3 for "RDM3"). + virtual int ProtocolVersion() const = 0; + + // Build and send the local hello (including the protocol magic). + // Returns 0 on success, -1 on IO error (errno set). + // + // For a server in fallback state, implementations MUST still + // produce a sendable message; each version uses its own wire + // convention to signal "I am falling back" to the peer: + // - v2: zero hello_ver/impl_ver so the peer's v2 validation + // (ValidHelloMessage in rdma_handshake.cpp; HelloNegotiationValid + // in builds that predate this class) rejects it; + // - v3: qp_num==0 so the peer's ValidRdmaHello rejects it. + // NOTE(review): RdmaHandshakeServerV3::SendLocalHello does not + // check the fallback state itself. FillLocalRdmaHello emits + // qp_num==0 only when ep->_resource is null, and ValidRdmaHello + // accepts qp_num==0 when g_skip_rdma_init is set. Confirm that a + // fallback taken after resources were allocated is still + // signalled to the peer. + virtual int SendLocalHello() = 0; + + // Read the peer's hello, validate it, and translate into ParsedHello. + // + // Role-specific semantics: + // - Client subclasses: read & verify the 4B magic first, then the + // body. (The endpoint did NOT pre-read the magic on the client + // side.) + // - Server subclasses: read ONLY the body. The 4B magic was + // already consumed by ProcessHandshakeAtServer and was used to + // pick `this` from CreateServerHandshakeByMagic; re-reading + // would deadlock. + // + // Outputs: + // *negotiated -- true if the remote hello is structurally valid + // AND passes per-protocol negotiation checks; + // false means the peer asked for fallback or sent + // something we can't honor.
+ // Returns: + // 0 -- IO/parsing layer OK; check *negotiated and *remote. + // -1 -- IO error or unrecoverable protocol error (errno set). + virtual int ReceiveAndParseRemoteHello(ParsedHello* remote, bool* negotiated) = 0; + +protected: + RdmaEndpoint* _ep; +}; + +// v2 handshake (legacy "RDMA" magic, 36B binary HelloMessage). +class RdmaHandshakeClientV2 : public RdmaHandshake { +public: + using RdmaHandshake::RdmaHandshake; + int ProtocolVersion() const override { return 2; } + + int SendLocalHello() override; + int ReceiveAndParseRemoteHello(ParsedHello* remote, bool* negotiated) override; +}; + +class RdmaHandshakeServerV2 : public RdmaHandshake { +public: + using RdmaHandshake::RdmaHandshake; + int ProtocolVersion() const override { return 2; } + + int SendLocalHello() override; + int ReceiveAndParseRemoteHello(ParsedHello* remote, bool* negotiated) override; +}; + +// v3 handshake (new "RDM3" magic, protobuf RdmaHello). +// [ "RDM3" 4B ][ pb_size 4B (big-endian) ][ RdmaHello protobuf bytes ] +class RdmaHandshakeClientV3 : public RdmaHandshake { +public: + using RdmaHandshake::RdmaHandshake; + int ProtocolVersion() const override { return 3; } + + int SendLocalHello() override; + int ReceiveAndParseRemoteHello(ParsedHello* remote, bool* negotiated) override; +}; + +class RdmaHandshakeServerV3 : public RdmaHandshake { +public: + using RdmaHandshake::RdmaHandshake; + int ProtocolVersion() const override { return 3; } + + int SendLocalHello() override; + int ReceiveAndParseRemoteHello(ParsedHello* remote, bool* negotiated) override; +}; + +// Factory methods +// +// Pick the client-side handshake based on +// FLAGS_rdma_client_handshake_version: +// 2 (default) -> RdmaHandshakeClientV2 +// 3 -> RdmaHandshakeClientV3 +// Other values fall back to V2. +std::unique_ptr CreateClientHandshake(RdmaEndpoint* ep); + +// Pick the server-side handshake based on the 4B magic already read. 
+// Returns NULL if `magic` is not a recognized RDMA magic +// (the caller should then fallback to TCP). +// "RDMA" -> RdmaHandshakeServerV2 +// "RDM3" -> RdmaHandshakeServerV3 +std::unique_ptr CreateServerHandshakeByMagic( + RdmaEndpoint* ep, const uint8_t magic[MAGIC_STR_LEN]); + +} // namespace rdma +} // namespace brpc + +#endif // BRPC_WITH_RDMA +#endif // BRPC_RDMA_HANDSHAKE_H diff --git a/src/brpc/rdma/rdma_handshake.proto b/src/brpc/rdma/rdma_handshake.proto new file mode 100644 index 0000000000..c180b58b96 --- /dev/null +++ b/src/brpc/rdma/rdma_handshake.proto @@ -0,0 +1,46 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +syntax = "proto2"; + +package brpc.rdma; + +option cc_generic_services = false; + +// RDMA handshake v3 message. +// Carried in the body of every "RDM3" handshake packet: +// +// [ "RDM3" 4B ][ pb_size 4B ][ RdmaHello protobuf bytes ] +message RdmaHello { + // ---- v2-parity base fields (required) ---- + // Listed first and in the same logical order as the v2 binary + // HelloMessage (minus hello_ver / impl_ver, which are subsumed by + // the wrapper magic "RDM3"). Keeping the same ordering simplifies + // side-by-side reasoning when debugging mixed v2/v3 traffic. 
+ // + // Marked `required` because the handshake cannot proceed without + // any of these; ParseFromZeroCopyStream() (see ReadAndParseV3Hello) + // rejects a missing field at the protobuf layer, so ValidRdmaHello() + // needs no extra has_xxx() presence check. + required uint32 block_size = 1; + required uint32 sq_size = 2; + required uint32 rq_size = 3; + required uint32 lid = 4; + // Must be exactly 16 bytes (sizeof(ibv_gid)). + required bytes gid = 5; + required uint32 qp_num = 6; +} diff --git a/src/brpc/rdma_transport.h b/src/brpc/rdma_transport.h index 65ae88f7a6..d8520b1a6d 100644 --- a/src/brpc/rdma_transport.h +++ b/src/brpc/rdma_transport.h @@ -25,9 +25,10 @@ namespace brpc { class RdmaTransport : public Transport { - friend class TransportFactory; - friend class rdma::RdmaEndpoint; - friend class rdma::RdmaConnect; +friend class TransportFactory; +friend class rdma::RdmaEndpoint; +friend class rdma::RdmaConnect; +friend class rdma::RdmaHandshakeServerV2; public: void Init(Socket* socket, const SocketOptions& options) override; void Release() override; @@ -47,7 +48,7 @@ class RdmaTransport : public Transport { private: static bool OptionsAvailableForRdma(const ChannelOptions* opt); static bool OptionsAvailableOverRdma(const ServerOptions* opt); -private: + // The on/off state of RDMA enum RdmaState { RDMA_ON, diff --git a/src/brpc/server.cpp b/src/brpc/server.cpp index 935e5f1bb1..57f665da91 100644 --- a/src/brpc/server.cpp +++ b/src/brpc/server.cpp @@ -855,26 +855,37 @@ int Server::StartInternal(const butil::EndPoint& endpoint, return -1; } - copy_and_fill_server_options(_options, opt ? *opt : ServerOptions()); - - if (!_options.h2_settings.IsValid(true/*log_error*/)) { + // Validate the user-provided ServerOptions BEFORE + // copy_and_fill_server_options below. This is important: + // copy_and_fill_server_options unconditionally transfers ownership of + // user-provided pointers (nshead_service, thrift_service, ...) into + // _options.
If we instead validated against _options after the copy, + // a failed Start() would leave fake/invalid pointers behind in + // _options, and the NEXT Start() would attempt to `delete` them via + // FREE_PTR_IF_NOT_REUSED, crashing (see RdmaTest.server_option_invalid). + const ServerOptions default_opt; + const ServerOptions& real_opt = opt ? *opt : default_opt; + + if (!real_opt.h2_settings.IsValid(true/*log_error*/)) { LOG(ERROR) << "Invalid h2_settings"; return -1; } - if (_options.bthread_tag < BTHREAD_TAG_DEFAULT || - _options.bthread_tag >= FLAGS_task_group_ntags) { - LOG(ERROR) << "Fail to set tag " << _options.bthread_tag + if (real_opt.bthread_tag < BTHREAD_TAG_DEFAULT || + real_opt.bthread_tag >= FLAGS_task_group_ntags) { + LOG(ERROR) << "Fail to set tag " << real_opt.bthread_tag << ", tag range is [" << BTHREAD_TAG_DEFAULT << ":" << FLAGS_task_group_ntags << ")"; return -1; } - int ret = TransportFactory::ContextInitOrDie(_options.socket_mode, true, &_options); + int ret = TransportFactory::ContextInitOrDie(real_opt.socket_mode, true, &real_opt); if (ret != 0) { LOG(ERROR) << "Fail to initialize transport context for server, ret=" << ret; return -1; } + copy_and_fill_server_options(_options, real_opt); + if (_options.http_master_service) { // Check requirements for http_master_service: // has "default_method" & request/response have no fields diff --git a/src/brpc/socket.cpp b/src/brpc/socket.cpp index 0ca6950428..a3d43fa3b8 100644 --- a/src/brpc/socket.cpp +++ b/src/brpc/socket.cpp @@ -1554,8 +1554,7 @@ void Socket::CheckConnectedAndKeepWrite(int fd, int err, void* data) { g_vars->channel_conn << 1; } if (s->_app_connect) { - s->_app_connect->StartConnect(req->get_socket(), - AfterAppConnected, req); + s->_app_connect->StartConnect(req->get_socket(), AfterAppConnected, req); } else { // Successfully created a connection AfterAppConnected(0, req); diff --git a/src/brpc/socket.h b/src/brpc/socket.h index 816fccdf27..7311d73895 100644 --- 
a/src/brpc/socket.h +++ b/src/brpc/socket.h @@ -56,6 +56,10 @@ class ChannelBalancer; namespace rdma { class RdmaEndpoint; class RdmaConnect; +class RdmaHandshakeClientV2; +class RdmaHandshakeServerV2; +class RdmaHandshakeClientV3; +class RdmaHandshakeServerV3; } class Socket; @@ -317,6 +321,10 @@ friend class policy::RtmpContext; friend class schan::ChannelBalancer; friend class rdma::RdmaEndpoint; friend class rdma::RdmaConnect; +friend class rdma::RdmaHandshakeClientV2; +friend class rdma::RdmaHandshakeServerV2; +friend class rdma::RdmaHandshakeClientV3; +friend class rdma::RdmaHandshakeServerV3; friend class HealthCheckTask; friend class OnAppHealthCheckDone; friend class HealthCheckManager; diff --git a/src/butil/thread_key.h b/src/butil/thread_key.h index c150528b63..77f346d608 100644 --- a/src/butil/thread_key.h +++ b/src/butil/thread_key.h @@ -18,6 +18,7 @@ #ifndef BUTIL_THREAD_KEY_H #define BUTIL_THREAD_KEY_H +#include #include #include #include diff --git a/test/brpc_rdma_unittest.cpp b/test/brpc_rdma_unittest.cpp index ccb280f1c8..43c6edfd12 100644 --- a/test/brpc_rdma_unittest.cpp +++ b/test/brpc_rdma_unittest.cpp @@ -24,7 +24,6 @@ #include #include "butil/endpoint.h" #include "butil/fd_guard.h" -#include "butil/fd_utility.h" #include "butil/iobuf.h" #include "butil/sys_byteorder.h" #include "butil/files/temp_file.h" @@ -36,15 +35,15 @@ #include "brpc/errno.pb.h" #include "brpc/parallel_channel.h" #include "brpc/selective_channel.h" +#include "brpc/rdma_transport.h" #include "brpc/rdma/block_pool.h" #include "brpc/rdma/rdma_endpoint.h" +#include "brpc/rdma/rdma_handshake.h" +#include "brpc/rdma/rdma_handshake.pb.h" #include "brpc/rdma/rdma_helper.h" #include "echo.pb.h" static const int PORT = 8713; -static const size_t RDMA_HELLO_MSG_LEN = 40; -static uint16_t RDMA_HELLO_VERSION = 2; -static uint16_t RDMA_IMPL_VERSION = 1; using namespace brpc; @@ -56,23 +55,13 @@ DEFINE_bool(rdma_test_enable, false, "Enable tests requring rdma runtime."); namespace 
rdma { -struct HelloMessage { - void Serialize(void* data) const; - void Deserialize(void* data); - - uint16_t msg_len; - uint16_t hello_ver; - uint16_t impl_ver; - uint32_t block_size; - uint16_t sq_size; - uint16_t rq_size; - uint16_t lid; - ibv_gid gid; - uint32_t qp_num; -}; +extern const uint16_t RDMA_HELLO_V2_VERSION; +extern const uint16_t RDMA_IMPL_V2_VERSION; DECLARE_bool(rdma_trace_verbose); DECLARE_int32(rdma_memory_pool_max_regions); +DECLARE_int32(rdma_client_handshake_version); + extern ibv_cq* (*IbvCreateCq)(ibv_context*, int, void*, ibv_comp_channel*, int); extern int (*IbvDestroyCq)(ibv_cq*); extern ibv_qp* (*IbvCreateQp)(ibv_pd*, ibv_qp_init_attr*); @@ -81,8 +70,8 @@ extern int (*IbvQueryQp)(ibv_qp*, ibv_qp_attr*, ibv_qp_attr_mask, ibv_qp_init_at extern int (*IbvDestroyQp)(ibv_qp*); extern butil::atomic g_rdma_available; extern bool g_skip_rdma_init; -} -} +} // namespace rdma +} // namespace brpc static std::string g_ip = "127.0.0.1"; static butil::EndPoint g_ep; @@ -109,7 +98,7 @@ class MyEchoService : public ::test::EchoService { LOG(INFO) << "sleep " << req->sleep_us() << "us..."; bthread_usleep(req->sleep_us()); } - res->set_message(req->message()); + res->set_message("MyEchoService"); if (req->code() != 0) { res->add_code_list(req->code()); } @@ -136,11 +125,12 @@ class RdmaTest : public ::testing::Test { rdma::DumpMemoryPoolInfo(std::cout); } -private: +protected: void StartServer(bool use_rdma = true) { ServerOptions options; - options.use_rdma = use_rdma; - options.idle_timeout_sec = 10; + options.enabled_protocols = "baidu_std"; + options.socket_mode = use_rdma ? SOCKET_MODE_RDMA : SOCKET_MODE_TCP; + options.idle_timeout_sec = 5; options.max_concurrency = 0; options.internal_port = -1; EXPECT_EQ(0, _server.Start(PORT, &options)); @@ -171,6 +161,29 @@ class RdmaTest : public ::testing::Test { MyEchoService _svc; }; +// Parameterized fixture used by upper-layer RPC tests that have no +// dependency on the handshake wire format. 
The parameter is the +// client-side handshake protocol version (FLAGS_rdma_client_handshake_version), +// so every TEST_P below is automatically executed once per supported +// version. Add a new version to INSTANTIATE_TEST_SUITE_P at the bottom +// of this file and these RPC tests will gain coverage for free. +class RdmaRpcTest : public RdmaTest, + public ::testing::WithParamInterface { +protected: + void SetUp() override { + RdmaTest::SetUp(); + _saved_handshake_version = rdma::FLAGS_rdma_client_handshake_version; + rdma::FLAGS_rdma_client_handshake_version = GetParam(); + } + void TearDown() override { + rdma::FLAGS_rdma_client_handshake_version = _saved_handshake_version; + RdmaTest::TearDown(); + } + +private: + int _saved_handshake_version = 2; +}; + TEST_F(RdmaTest, client_close_before_hello_send) { StartServer(); @@ -184,7 +197,7 @@ TEST_F(RdmaTest, client_close_before_hello_send) { ASSERT_EQ(0, connect(sockfd, (sockaddr*)&addr, sizeof(sockaddr))); usleep(100000); // wait for server to handle the msg Socket* s = GetSocketFromServer(0); - ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, static_cast(s->_transport.get())->_rdma_ep->_state); close(sockfd); usleep(100000); // wait for server to handle the msg ASSERT_EQ(NULL, GetSocketFromServer(0)); @@ -205,13 +218,13 @@ TEST_F(RdmaTest, client_hello_msg_invalid_magic_str) { ASSERT_EQ(0, connect(sockfd, (sockaddr*)&addr, sizeof(sockaddr))); usleep(100000); // wait for server to handle the msg Socket* s = GetSocketFromServer(0); - ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, static_cast(s->_transport.get())->_rdma_ep->_state); - uint8_t data[RDMA_HELLO_MSG_LEN]; + uint8_t data[rdma::v2_wire::HELLO_MSG_LEN_MIN]; memcpy(data, "PRPC", 4); // send as normal baidu_std protocol ASSERT_EQ(4, write(sockfd, data, 4)); usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, 
s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, static_cast(s->_transport.get())->_rdma_ep->_state); StopServer(); } @@ -231,11 +244,11 @@ TEST_F(RdmaTest, client_close_during_hello_send) { ASSERT_EQ(0, connect(sockfd1, (sockaddr*)&addr, sizeof(sockaddr))); usleep(100000); // wait for server to handle the msg s = GetSocketFromServer(0); - ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, static_cast(s->_transport.get())->_rdma_ep->_state); memcpy(data, "RD", 2); ASSERT_EQ(2, write(sockfd1, data, 2)); // break in magic str usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); close(sockfd1); usleep(100000); // wait for server to handle the msg ASSERT_EQ(NULL, GetSocketFromServer(0)); @@ -245,11 +258,11 @@ TEST_F(RdmaTest, client_close_during_hello_send) { ASSERT_EQ(0, connect(sockfd2, (sockaddr*)&addr, sizeof(sockaddr))); usleep(100000); // wait for server to handle the msg s = GetSocketFromServer(0); - ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, static_cast(s->_transport.get())->_rdma_ep->_state); memcpy(data, "RDMA", 4); ASSERT_EQ(4, write(sockfd2, data, 4)); // break after magic str usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); close(sockfd2); usleep(100000); // wait for server to handle the msg ASSERT_EQ(NULL, GetSocketFromServer(0)); @@ -259,12 +272,12 @@ TEST_F(RdmaTest, client_close_during_hello_send) { ASSERT_EQ(0, connect(sockfd3, (sockaddr*)&addr, sizeof(sockaddr))); usleep(100000); // wait for server to handle the msg s = GetSocketFromServer(0); - ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, 
s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, static_cast(s->_transport.get())->_rdma_ep->_state); memcpy(data, "RDMA", 4); memset(data + 4, 0, 4); ASSERT_EQ(8, write(sockfd3, data, 8)); // break after magic str usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); close(sockfd3); usleep(100000); // wait for server to handle the msg ASSERT_EQ(NULL, GetSocketFromServer(0)); @@ -280,18 +293,18 @@ TEST_F(RdmaTest, client_hello_msg_invalid_len) { addr.sin_family = AF_INET; addr.sin_port = htons(PORT); Socket* s = NULL; - uint8_t data[RDMA_HELLO_MSG_LEN]; + uint8_t data[rdma::v2_wire::HELLO_MSG_LEN_MIN]; butil::fd_guard sockfd1(socket(AF_INET, SOCK_STREAM, 0)); ASSERT_TRUE(sockfd1 >= 0); ASSERT_EQ(0, connect(sockfd1, (sockaddr*)&addr, sizeof(sockaddr))); usleep(100000); // wait for server to handle the msg s = GetSocketFromServer(0); - ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, static_cast(s->_transport.get())->_rdma_ep->_state); memcpy(data, "RDMA", 4); ASSERT_EQ(4, write(sockfd1, data, 4)); // Write magic string. usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); memset(data + 4, 0, 36); ASSERT_EQ(36, write(sockfd1, data + 4, 36)); // Write invalid length. 
usleep(100000); // wait for server to handle the msg @@ -302,11 +315,11 @@ TEST_F(RdmaTest, client_hello_msg_invalid_len) { ASSERT_EQ(0, connect(sockfd2, (sockaddr*)&addr, sizeof(sockaddr))); usleep(100000); // wait for server to handle the msg s = GetSocketFromServer(0); - ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, static_cast(s->_transport.get())->_rdma_ep->_state); memcpy(data, "RDMA", 4); ASSERT_EQ(4, write(sockfd2, data, 4)); // Write magic string. usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); uint16_t len = butil::HostToNet16(35); memcpy(data + 4, &len, sizeof(len)); memset(data + 6, 0, 34); @@ -325,8 +338,8 @@ TEST_F(RdmaTest, client_hello_msg_invalid_version) { addr.sin_family = AF_INET; addr.sin_port = htons(PORT); Socket* s = NULL; - uint8_t data[RDMA_HELLO_MSG_LEN]; - uint16_t len = butil::HostToNet16(RDMA_HELLO_MSG_LEN); + uint8_t data[rdma::v2_wire::HELLO_MSG_LEN_MIN]; + uint16_t len = butil::HostToNet16(rdma::v2_wire::HELLO_MSG_LEN_MIN); uint16_t ver = butil::HostToNet16(1); butil::fd_guard sockfd1(socket(AF_INET, SOCK_STREAM, 0)); @@ -334,22 +347,29 @@ TEST_F(RdmaTest, client_hello_msg_invalid_version) { ASSERT_EQ(0, connect(sockfd1, (sockaddr*)&addr, sizeof(sockaddr))); usleep(100000); // wait for server to handle the msg s = GetSocketFromServer(0); - ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, static_cast(s->_transport.get())->_rdma_ep->_state); memcpy(data, "RDMA", 4); ASSERT_EQ(4, write(sockfd1, data, 4)); // Write magic string. 
usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); memcpy(data + 4, &len, 2); memset(data + 6, 0, 34); memcpy(data + 6, &ver, 2); // hello_ver == 1, impl_ver == 0 - ASSERT_EQ(36, write(sockfd1, data, 36)); + // Write the 36B base starting at data + 4 (NOT data). Pre-Step-1 this + // UT mistakenly wrote `data, 36` which included the leftover "RDMA" + // magic at data[0..4); the server parsed it as msg_len = 0x5244 and + // happened to fall through to NegotiationValid (which then failed on + // hello_ver). Now that Step 1 enforces a HELLO_MSG_LEN_MAX upper bound, + // such an oversized msg_len would be rejected before reaching the + // version check, breaking the intent of this UT. + ASSERT_EQ(36, write(sockfd1, data + 4, 36)); usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, s->_rdma_ep->_state); - ASSERT_EQ(Socket::RDMA_OFF, s->_rdma_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); + ASSERT_EQ(RdmaTransport::RDMA_OFF, static_cast(s->_transport.get())->_rdma_state); uint32_t flags = 0; ASSERT_EQ(sizeof(flags), write(sockfd1, &flags, sizeof(flags))); usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, static_cast(s->_transport.get())->_rdma_ep->_state); sockfd1.reset(-1); usleep(100000); // wait for server to handle the msg ASSERT_EQ(NULL, GetSocketFromServer(0)); @@ -359,21 +379,23 @@ TEST_F(RdmaTest, client_hello_msg_invalid_version) { ASSERT_EQ(0, connect(sockfd2, (sockaddr*)&addr, sizeof(sockaddr))); usleep(100000); // wait for server to handle the msg s = GetSocketFromServer(0); - ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, 
static_cast(s->_transport.get())->_rdma_ep->_state); memcpy(data, "RDMA", 4); ASSERT_EQ(4, write(sockfd2, data, 4)); // Write magic string. usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); memcpy(data + 4, &len, 2); memset(data + 6, 0, 32); memcpy(data + 8, &ver, 2); // hello_ver == 0, impl_ver == 1 - ASSERT_EQ(36, write(sockfd2, data, 36)); + // See comment above on `write(sockfd1, data + 4, 36)` for why we + // write from data + 4 instead of data. + ASSERT_EQ(36, write(sockfd2, data + 4, 36)); usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, s->_rdma_ep->_state); - ASSERT_EQ(Socket::RDMA_OFF, s->_rdma_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); + ASSERT_EQ(RdmaTransport::RDMA_OFF, static_cast(s->_transport.get())->_rdma_state); ASSERT_EQ(sizeof(flags), write(sockfd2, &flags, sizeof(flags))); usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, static_cast(s->_transport.get())->_rdma_ep->_state); sockfd2.reset(-1); usleep(100000); // wait for server to handle the msg ASSERT_EQ(NULL, GetSocketFromServer(0)); @@ -390,11 +412,11 @@ TEST_F(RdmaTest, client_hello_msg_invalid_sq_rq_block_size) { addr.sin_port = htons(PORT); Socket* s = NULL; uint32_t flags = butil::HostToNet32(0); - rdma::HelloMessage msg{}; - uint8_t data[RDMA_HELLO_MSG_LEN]; - msg.msg_len = RDMA_HELLO_MSG_LEN; - msg.hello_ver = RDMA_HELLO_VERSION; - msg.impl_ver = RDMA_IMPL_VERSION; + rdma::v2_wire::HelloMessage msg{}; + uint8_t data[rdma::v2_wire::HELLO_MSG_LEN_MIN]; + msg.msg_len = rdma::v2_wire::HELLO_MSG_LEN_MIN; + msg.hello_ver = rdma::RDMA_HELLO_V2_VERSION; + msg.impl_ver = rdma::RDMA_IMPL_V2_VERSION; 
msg.sq_size = 10; msg.rq_size = 16; @@ -406,17 +428,17 @@ TEST_F(RdmaTest, client_hello_msg_invalid_sq_rq_block_size) { ASSERT_EQ(0, connect(sockfd1, (sockaddr*)&addr, sizeof(sockaddr))); usleep(100000); // wait for server to handle the msg s = GetSocketFromServer(0); - ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, static_cast(s->_transport.get())->_rdma_ep->_state); ASSERT_EQ(4, write(sockfd1, data, 4)); // Write magic string. usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); ASSERT_EQ(36, write(sockfd1, data + 4, 36)); usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, s->_rdma_ep->_state); - ASSERT_EQ(Socket::RDMA_OFF, s->_rdma_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); + ASSERT_EQ(RdmaTransport::RDMA_OFF, static_cast(s->_transport.get())->_rdma_state); ASSERT_EQ(sizeof(flags), write(sockfd1, &flags, sizeof(flags))); usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, static_cast(s->_transport.get())->_rdma_ep->_state); sockfd1.reset(-1); usleep(100000); // wait for server to handle the msg ASSERT_EQ(NULL, GetSocketFromServer(0)); @@ -431,17 +453,17 @@ TEST_F(RdmaTest, client_hello_msg_invalid_sq_rq_block_size) { ASSERT_EQ(0, connect(sockfd2, (sockaddr*)&addr, sizeof(sockaddr))); usleep(100000); // wait for server to handle the msg s = GetSocketFromServer(0); - ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, static_cast(s->_transport.get())->_rdma_ep->_state); ASSERT_EQ(4, write(sockfd2, data, 4)); // Write magic string. 
usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); ASSERT_EQ(36, write(sockfd2, data + 4, 36)); usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, s->_rdma_ep->_state); - ASSERT_EQ(Socket::RDMA_OFF, s->_rdma_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); + ASSERT_EQ(RdmaTransport::RDMA_OFF, static_cast(s->_transport.get())->_rdma_state); ASSERT_EQ(sizeof(flags), write(sockfd2, &flags, sizeof(flags))); usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, static_cast(s->_transport.get())->_rdma_ep->_state); sockfd2.reset(-1); usleep(100000); // wait for server to handle the msg ASSERT_EQ(NULL, GetSocketFromServer(0)); @@ -456,17 +478,17 @@ TEST_F(RdmaTest, client_hello_msg_invalid_sq_rq_block_size) { ASSERT_EQ(0, connect(sockfd3, (sockaddr*)&addr, sizeof(sockaddr))); usleep(100000); // wait for server to handle the msg s = GetSocketFromServer(0); - ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, static_cast(s->_transport.get())->_rdma_ep->_state); ASSERT_EQ(4, write(sockfd3, data, 4)); // Write magic string. 
usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); ASSERT_EQ(36, write(sockfd3, data + 4, 36)); usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, s->_rdma_ep->_state); - ASSERT_EQ(Socket::RDMA_OFF, s->_rdma_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); + ASSERT_EQ(RdmaTransport::RDMA_OFF, static_cast(s->_transport.get())->_rdma_state); ASSERT_EQ(sizeof(flags), write(sockfd3, &flags, sizeof(flags))); usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, static_cast(s->_transport.get())->_rdma_ep->_state); sockfd3.reset(-1); usleep(100000); // wait for server to handle the msg ASSERT_EQ(NULL, GetSocketFromServer(0)); @@ -482,11 +504,11 @@ TEST_F(RdmaTest, client_close_after_qp_build) { addr.sin_family = AF_INET; addr.sin_port = htons(PORT); Socket* s = NULL; - rdma::HelloMessage msg; - uint8_t data[RDMA_HELLO_MSG_LEN]; - msg.msg_len = RDMA_HELLO_MSG_LEN; - msg.hello_ver = RDMA_HELLO_VERSION; - msg.impl_ver = RDMA_IMPL_VERSION; + rdma::v2_wire::HelloMessage msg{}; + uint8_t data[rdma::v2_wire::HELLO_MSG_LEN_MIN]; + msg.msg_len = rdma::v2_wire::HELLO_MSG_LEN_MIN; + msg.hello_ver = rdma::RDMA_HELLO_V2_VERSION; + msg.impl_ver = rdma::RDMA_IMPL_V2_VERSION; msg.sq_size = 16; msg.rq_size = 16; msg.block_size = 8192; @@ -500,10 +522,10 @@ TEST_F(RdmaTest, client_close_after_qp_build) { ASSERT_EQ(0, connect(sockfd1, (sockaddr*)&addr, sizeof(sockaddr))); usleep(100000); // wait for server to handle the msg s = GetSocketFromServer(0); - ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, static_cast(s->_transport.get())->_rdma_ep->_state); 
ASSERT_EQ(40, write(sockfd1, data, 40)); usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); close(sockfd1); usleep(100000); // wait for server to handle the msg ASSERT_EQ(NULL, GetSocketFromServer(0)); @@ -519,11 +541,11 @@ TEST_F(RdmaTest, client_close_during_ack_send) { addr.sin_family = AF_INET; addr.sin_port = htons(PORT); Socket* s = NULL; - rdma::HelloMessage msg; - uint8_t data[RDMA_HELLO_MSG_LEN]; - msg.msg_len = RDMA_HELLO_MSG_LEN; - msg.hello_ver = RDMA_HELLO_VERSION; - msg.impl_ver = RDMA_IMPL_VERSION; + rdma::v2_wire::HelloMessage msg{}; + uint8_t data[rdma::v2_wire::HELLO_MSG_LEN_MIN]; + msg.msg_len = rdma::v2_wire::HELLO_MSG_LEN_MIN; + msg.hello_ver = rdma::RDMA_HELLO_V2_VERSION; + msg.impl_ver = rdma::RDMA_IMPL_V2_VERSION; msg.sq_size = 16; msg.rq_size = 16; msg.block_size = 8192; @@ -537,17 +559,17 @@ TEST_F(RdmaTest, client_close_during_ack_send) { ASSERT_EQ(0, connect(sockfd1, (sockaddr*)&addr, sizeof(sockaddr))); usleep(100000); // wait for server to handle the msg s = GetSocketFromServer(0); - ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, static_cast(s->_transport.get())->_rdma_ep->_state); ASSERT_EQ(4, write(sockfd1, data, 4)); // Write magic string. 
usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); ASSERT_EQ(36, write(sockfd1, data + 4, 36)); usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); uint32_t flags = butil::HostToNet32(1); ASSERT_EQ(sizeof(flags), write(sockfd1, &flags, sizeof(flags))); usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::ESTABLISHED, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::ESTABLISHED, static_cast(s->_transport.get())->_rdma_ep->_state); close(sockfd1); usleep(100000); // wait for server to handle the msg ASSERT_EQ(NULL, GetSocketFromServer(0)); @@ -563,11 +585,11 @@ TEST_F(RdmaTest, client_close_after_ack_send) { addr.sin_family = AF_INET; addr.sin_port = htons(PORT); Socket* s = NULL; - rdma::HelloMessage msg; - uint8_t data[RDMA_HELLO_MSG_LEN]; - msg.msg_len = RDMA_HELLO_MSG_LEN; - msg.hello_ver = RDMA_HELLO_VERSION; - msg.impl_ver = RDMA_IMPL_VERSION; + rdma::v2_wire::HelloMessage msg{}; + uint8_t data[rdma::v2_wire::HELLO_MSG_LEN_MIN]; + msg.msg_len = rdma::v2_wire::HELLO_MSG_LEN_MIN; + msg.hello_ver = rdma::RDMA_HELLO_V2_VERSION; + msg.impl_ver = rdma::RDMA_IMPL_V2_VERSION; msg.sq_size = 16; msg.rq_size = 16; msg.block_size = 8192; @@ -581,18 +603,18 @@ TEST_F(RdmaTest, client_close_after_ack_send) { ASSERT_EQ(0, connect(sockfd1, (sockaddr*)&addr, sizeof(sockaddr))); usleep(100000); // wait for server to handle the msg s = GetSocketFromServer(0); - ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, static_cast(s->_transport.get())->_rdma_ep->_state); ASSERT_EQ(4, write(sockfd1, data, 4)); // Write magic string. 
usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); ASSERT_EQ(36, write(sockfd1, data + 4, 36)); usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); uint32_t flags = butil::HostToNet32(0); ASSERT_EQ(sizeof(flags), write(sockfd1, &flags, sizeof(flags))); usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, s->_rdma_ep->_state); - ASSERT_EQ(Socket::RDMA_OFF, s->_rdma_state); + ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, static_cast(s->_transport.get())->_rdma_ep->_state); + ASSERT_EQ(RdmaTransport::RDMA_OFF, static_cast(s->_transport.get())->_rdma_state); close(sockfd1); usleep(100000); // wait for server to handle the msg ASSERT_EQ(NULL, GetSocketFromServer(0)); @@ -602,17 +624,17 @@ TEST_F(RdmaTest, client_close_after_ack_send) { ASSERT_EQ(0, connect(sockfd2, (sockaddr*)&addr, sizeof(sockaddr))); usleep(100000); // wait for server to handle the msg s = GetSocketFromServer(0); - ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, static_cast(s->_transport.get())->_rdma_ep->_state); ASSERT_EQ(4, write(sockfd2, data, 4)); // Write magic string. 
usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); ASSERT_EQ(36, write(sockfd2, data + 4, 36)); usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); flags = butil::HostToNet32(1); ASSERT_EQ(sizeof(flags), write(sockfd2, &flags, sizeof(flags))); usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::ESTABLISHED, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::ESTABLISHED, static_cast(s->_transport.get())->_rdma_ep->_state); close(sockfd2); usleep(100000); // wait for server to handle the msg ASSERT_EQ(NULL, GetSocketFromServer(0)); @@ -628,11 +650,11 @@ TEST_F(RdmaTest, client_send_data_on_tcp_after_ack_send) { addr.sin_family = AF_INET; addr.sin_port = htons(PORT); Socket* s = NULL; - rdma::HelloMessage msg; - uint8_t data[RDMA_HELLO_MSG_LEN]; - msg.msg_len = RDMA_HELLO_MSG_LEN; - msg.hello_ver = RDMA_HELLO_VERSION; - msg.impl_ver = RDMA_IMPL_VERSION; + rdma::v2_wire::HelloMessage msg{}; + uint8_t data[rdma::v2_wire::HELLO_MSG_LEN_MIN]; + msg.msg_len = rdma::v2_wire::HELLO_MSG_LEN_MIN; + msg.hello_ver = rdma::RDMA_HELLO_V2_VERSION; + msg.impl_ver = rdma::RDMA_IMPL_V2_VERSION; msg.sq_size = 16; msg.rq_size = 16; msg.block_size = 8192; @@ -646,17 +668,17 @@ TEST_F(RdmaTest, client_send_data_on_tcp_after_ack_send) { ASSERT_EQ(0, connect(sockfd1, (sockaddr*)&addr, sizeof(sockaddr))); usleep(100000); // wait for server to handle the msg s = GetSocketFromServer(0); - ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, static_cast(s->_transport.get())->_rdma_ep->_state); ASSERT_EQ(4, write(sockfd1, data, 4)); // Write magic string. 
usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); ASSERT_EQ(36, write(sockfd1, data + 4, 36)); usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); uint32_t flags = butil::HostToNet32(0); ASSERT_EQ(sizeof(flags), write(sockfd1, &flags, sizeof(flags))); usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, static_cast(s->_transport.get())->_rdma_ep->_state); ASSERT_EQ(sizeof(flags), write(sockfd1, &flags, sizeof(flags))); usleep(100000); ASSERT_EQ(NULL, GetSocketFromServer(0)); @@ -666,17 +688,17 @@ TEST_F(RdmaTest, client_send_data_on_tcp_after_ack_send) { ASSERT_EQ(0, connect(sockfd2, (sockaddr*)&addr, sizeof(sockaddr))); usleep(100000); // wait for server to handle the msg s = GetSocketFromServer(0); - ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, static_cast(s->_transport.get())->_rdma_ep->_state); ASSERT_EQ(4, write(sockfd2, data, 4)); // Write magic string. 
usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); ASSERT_EQ(36, write(sockfd2, data + 4, 36)); usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); flags = butil::HostToNet32(1); ASSERT_EQ(sizeof(flags), write(sockfd2, &flags, sizeof(flags))); usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::ESTABLISHED, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::ESTABLISHED, static_cast(s->_transport.get())->_rdma_ep->_state); ASSERT_EQ(sizeof(flags), write(sockfd2, &flags, sizeof(flags))); usleep(100000); ASSERT_EQ(NULL, GetSocketFromServer(0)); @@ -690,7 +712,7 @@ TEST_F(RdmaTest, server_miss_before_hello_send) { Channel channel; ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 500; chan_options.timeout_ms = 500; chan_options.max_retry = 0; @@ -706,7 +728,7 @@ TEST_F(RdmaTest, server_miss_before_hello_send) { usleep(100000); SocketUniquePtr s; ASSERT_EQ(0, Socket::Address(cntl._single_server_id, &s)); - ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); butil::fd_guard acc_fd(accept(sockfd, NULL, NULL)); ASSERT_TRUE(acc_fd >= 0); @@ -721,7 +743,7 @@ TEST_F(RdmaTest, server_close_before_hello_send) { Channel channel; ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 500; chan_options.timeout_ms = 500; chan_options.max_retry = 0; @@ -737,15 +759,15 @@ TEST_F(RdmaTest, server_close_before_hello_send) { 
usleep(100000); SocketUniquePtr s; ASSERT_EQ(0, Socket::Address(cntl._single_server_id, &s)); - ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); butil::fd_guard acc_fd(accept(sockfd, NULL, NULL)); ASSERT_TRUE(acc_fd >= 0); - uint8_t data[RDMA_HELLO_MSG_LEN]; - ASSERT_EQ(RDMA_HELLO_MSG_LEN, read(acc_fd, data, RDMA_HELLO_MSG_LEN)); + uint8_t data[rdma::v2_wire::HELLO_MSG_LEN_MIN]; + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, read(acc_fd, data, rdma::v2_wire::HELLO_MSG_LEN_MIN)); close(acc_fd); usleep(100000); - ASSERT_EQ(rdma::RdmaEndpoint::FAILED, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::FAILED, static_cast(s->_transport.get())->_rdma_ep->_state); bthread_id_join(cntl.call_id()); ASSERT_EQ(EEOF, cntl.ErrorCode()); @@ -757,7 +779,7 @@ TEST_F(RdmaTest, server_miss_during_magic_str) { Channel channel; ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 500; chan_options.timeout_ms = 500; chan_options.max_retry = 0; @@ -773,12 +795,12 @@ TEST_F(RdmaTest, server_miss_during_magic_str) { usleep(100000); SocketUniquePtr s; ASSERT_EQ(0, Socket::Address(cntl._single_server_id, &s)); - ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); butil::fd_guard acc_fd(accept(sockfd, NULL, NULL)); ASSERT_TRUE(acc_fd >= 0); - uint8_t data[RDMA_HELLO_MSG_LEN]; - ASSERT_EQ(RDMA_HELLO_MSG_LEN, read(acc_fd, data, RDMA_HELLO_MSG_LEN)); + uint8_t data[rdma::v2_wire::HELLO_MSG_LEN_MIN]; + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, read(acc_fd, data, rdma::v2_wire::HELLO_MSG_LEN_MIN)); ASSERT_EQ(2, write(acc_fd, "RD", 2)); usleep(100000); bthread_id_join(cntl.call_id()); @@ -792,7 +814,7 @@ TEST_F(RdmaTest, server_close_during_magic_str) { Channel channel; 
ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 500; chan_options.timeout_ms = 500; chan_options.max_retry = 0; @@ -808,17 +830,17 @@ TEST_F(RdmaTest, server_close_during_magic_str) { usleep(100000); SocketUniquePtr s; ASSERT_EQ(0, Socket::Address(cntl._single_server_id, &s)); - ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); butil::fd_guard acc_fd(accept(sockfd, NULL, NULL)); ASSERT_TRUE(acc_fd >= 0); - uint8_t data[RDMA_HELLO_MSG_LEN]; - ASSERT_EQ(RDMA_HELLO_MSG_LEN, read(acc_fd, data, RDMA_HELLO_MSG_LEN)); + uint8_t data[rdma::v2_wire::HELLO_MSG_LEN_MIN]; + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, read(acc_fd, data, rdma::v2_wire::HELLO_MSG_LEN_MIN)); ASSERT_EQ(2, write(acc_fd, "RD", 2)); usleep(100000); close(acc_fd); usleep(100000); - ASSERT_EQ(rdma::RdmaEndpoint::FAILED, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::FAILED, static_cast(s->_transport.get())->_rdma_ep->_state); bthread_id_join(cntl.call_id()); ASSERT_EQ(EEOF, cntl.ErrorCode()); @@ -830,7 +852,7 @@ TEST_F(RdmaTest, server_hello_invalid_magic_str) { Channel channel; ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 500; chan_options.timeout_ms = 500; chan_options.max_retry = 0; @@ -846,15 +868,15 @@ TEST_F(RdmaTest, server_hello_invalid_magic_str) { usleep(100000); SocketUniquePtr s; ASSERT_EQ(0, Socket::Address(cntl._single_server_id, &s)); - ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); butil::fd_guard acc_fd(accept(sockfd, NULL, NULL)); ASSERT_TRUE(acc_fd >= 0); - uint8_t data[RDMA_HELLO_MSG_LEN]; - ASSERT_EQ(RDMA_HELLO_MSG_LEN, read(acc_fd, data, RDMA_HELLO_MSG_LEN)); 
+ uint8_t data[rdma::v2_wire::HELLO_MSG_LEN_MIN]; + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, read(acc_fd, data, rdma::v2_wire::HELLO_MSG_LEN_MIN)); ASSERT_EQ(4, write(acc_fd, "ABCD", 4)); usleep(100000); - ASSERT_EQ(rdma::RdmaEndpoint::FAILED, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::FAILED, static_cast(s->_transport.get())->_rdma_ep->_state); bthread_id_join(cntl.call_id()); ASSERT_EQ(EPROTO, cntl.ErrorCode()); @@ -866,7 +888,7 @@ TEST_F(RdmaTest, server_miss_during_hello_msg) { Channel channel; ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 500; chan_options.timeout_ms = 500; chan_options.max_retry = 0; @@ -882,12 +904,12 @@ TEST_F(RdmaTest, server_miss_during_hello_msg) { usleep(100000); SocketUniquePtr s; ASSERT_EQ(0, Socket::Address(cntl._single_server_id, &s)); - ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); butil::fd_guard acc_fd(accept(sockfd, NULL, NULL)); ASSERT_TRUE(acc_fd >= 0); - uint8_t data[RDMA_HELLO_MSG_LEN]; - ASSERT_EQ(RDMA_HELLO_MSG_LEN, read(acc_fd, data, RDMA_HELLO_MSG_LEN)); + uint8_t data[rdma::v2_wire::HELLO_MSG_LEN_MIN]; + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, read(acc_fd, data, rdma::v2_wire::HELLO_MSG_LEN_MIN)); ASSERT_EQ(4, write(acc_fd, "RDMA", 4)); ASSERT_EQ(2, write(acc_fd, "00", 2)); bthread_id_join(cntl.call_id()); @@ -901,7 +923,7 @@ TEST_F(RdmaTest, server_close_during_hello_msg) { Channel channel; ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 500; chan_options.timeout_ms = 500; chan_options.max_retry = 0; @@ -917,17 +939,17 @@ TEST_F(RdmaTest, server_close_during_hello_msg) { usleep(100000); SocketUniquePtr s; ASSERT_EQ(0, Socket::Address(cntl._single_server_id, &s)); - 
ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); butil::fd_guard acc_fd(accept(sockfd, NULL, NULL)); ASSERT_TRUE(acc_fd >= 0); - uint8_t data[RDMA_HELLO_MSG_LEN]; - ASSERT_EQ(RDMA_HELLO_MSG_LEN, read(acc_fd, data, RDMA_HELLO_MSG_LEN)); + uint8_t data[rdma::v2_wire::HELLO_MSG_LEN_MIN]; + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, read(acc_fd, data, rdma::v2_wire::HELLO_MSG_LEN_MIN)); ASSERT_EQ(4, write(acc_fd, "RDMA", 4)); ASSERT_EQ(2, write(acc_fd, "00", 2)); close(acc_fd); usleep(100000); - ASSERT_EQ(rdma::RdmaEndpoint::FAILED, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::FAILED, static_cast(s->_transport.get())->_rdma_ep->_state); bthread_id_join(cntl.call_id()); ASSERT_EQ(EEOF, cntl.ErrorCode()); @@ -939,7 +961,7 @@ TEST_F(RdmaTest, server_hello_invalid_msg_len) { Channel channel; ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 500; chan_options.timeout_ms = 500; chan_options.max_retry = 0; @@ -955,19 +977,19 @@ TEST_F(RdmaTest, server_hello_invalid_msg_len) { usleep(100000); SocketUniquePtr s; ASSERT_EQ(0, Socket::Address(cntl._single_server_id, &s)); - ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); butil::fd_guard acc_fd(accept(sockfd, NULL, NULL)); ASSERT_TRUE(acc_fd >= 0); - uint8_t data[RDMA_HELLO_MSG_LEN]; - ASSERT_EQ(RDMA_HELLO_MSG_LEN, read(acc_fd, data, RDMA_HELLO_MSG_LEN)); + uint8_t data[rdma::v2_wire::HELLO_MSG_LEN_MIN]; + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, read(acc_fd, data, rdma::v2_wire::HELLO_MSG_LEN_MIN)); memcpy(data, "RDMA", 4); uint16_t len = butil::HostToNet16(35); memcpy(data + 4, &len, 2); memset(data + 6, 0, 32); - ASSERT_EQ(RDMA_HELLO_MSG_LEN, write(acc_fd, data, RDMA_HELLO_MSG_LEN)); + 
ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, write(acc_fd, data, rdma::v2_wire::HELLO_MSG_LEN_MIN)); usleep(100000); - ASSERT_EQ(rdma::RdmaEndpoint::FAILED, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::FAILED, static_cast(s->_transport.get())->_rdma_ep->_state); bthread_id_join(cntl.call_id()); ASSERT_EQ(EPROTO, cntl.ErrorCode()); @@ -979,7 +1001,7 @@ TEST_F(RdmaTest, server_hello_invalid_version) { Channel channel; ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 500; chan_options.timeout_ms = 500; chan_options.max_retry = 0; @@ -995,19 +1017,19 @@ TEST_F(RdmaTest, server_hello_invalid_version) { usleep(100000); SocketUniquePtr s; ASSERT_EQ(0, Socket::Address(cntl._single_server_id, &s)); - ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); butil::fd_guard acc_fd(accept(sockfd, NULL, NULL)); ASSERT_TRUE(acc_fd >= 0); - uint8_t data[RDMA_HELLO_MSG_LEN]; - ASSERT_EQ(RDMA_HELLO_MSG_LEN, read(acc_fd, data, RDMA_HELLO_MSG_LEN)); + uint8_t data[rdma::v2_wire::HELLO_MSG_LEN_MIN]; + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, read(acc_fd, data, rdma::v2_wire::HELLO_MSG_LEN_MIN)); memcpy(data, "RDMA", 4); - uint16_t len = butil::HostToNet16(RDMA_HELLO_MSG_LEN); + uint16_t len = butil::HostToNet16(rdma::v2_wire::HELLO_MSG_LEN_MIN); memcpy(data + 4, &len, 2); memset(data + 6, 0, 32); - ASSERT_EQ(RDMA_HELLO_MSG_LEN, write(acc_fd, data, RDMA_HELLO_MSG_LEN)); + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, write(acc_fd, data, rdma::v2_wire::HELLO_MSG_LEN_MIN)); usleep(100000); - ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, static_cast(s->_transport.get())->_rdma_ep->_state); ASSERT_EQ(4, read(acc_fd, data, 4)); uint32_t* tmp = (uint32_t*)data; ASSERT_EQ(0, butil::NetToHost32(*tmp)); @@ -1022,7 +1044,7 
@@ TEST_F(RdmaTest, server_hello_invalid_sq_rq_size) { Channel channel; ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 500; chan_options.timeout_ms = 500; chan_options.max_retry = 0; @@ -1038,15 +1060,15 @@ TEST_F(RdmaTest, server_hello_invalid_sq_rq_size) { usleep(100000); SocketUniquePtr s; ASSERT_EQ(0, Socket::Address(cntl._single_server_id, &s)); - ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); butil::fd_guard acc_fd(accept(sockfd, NULL, NULL)); ASSERT_TRUE(acc_fd >= 0); - uint8_t data[RDMA_HELLO_MSG_LEN]; - ASSERT_EQ(RDMA_HELLO_MSG_LEN, read(acc_fd, data, RDMA_HELLO_MSG_LEN)); + uint8_t data[rdma::v2_wire::HELLO_MSG_LEN_MIN]; + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, read(acc_fd, data, rdma::v2_wire::HELLO_MSG_LEN_MIN)); - rdma::HelloMessage msg; - msg.msg_len = RDMA_HELLO_MSG_LEN; + rdma::v2_wire::HelloMessage msg{}; + msg.msg_len = rdma::v2_wire::HELLO_MSG_LEN_MIN; msg.hello_ver = 1; msg.impl_ver = 1; msg.sq_size = 0; @@ -1056,10 +1078,10 @@ TEST_F(RdmaTest, server_hello_invalid_sq_rq_size) { msg.gid = rdma::GetRdmaGid(); memcpy(data, "RDMA", 4); msg.Serialize(data + 4); - ASSERT_EQ(RDMA_HELLO_MSG_LEN, write(acc_fd, data, RDMA_HELLO_MSG_LEN)); + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, write(acc_fd, data, rdma::v2_wire::HELLO_MSG_LEN_MIN)); usleep(100000); - ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, static_cast(s->_transport.get())->_rdma_ep->_state); ASSERT_EQ(4, read(acc_fd, data, 4)); uint32_t* tmp = (uint32_t*)data; ASSERT_EQ(0, butil::NetToHost32(*tmp)); @@ -1074,7 +1096,7 @@ TEST_F(RdmaTest, server_miss_after_ack) { Channel channel; ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; 
chan_options.connect_timeout_ms = 500; chan_options.timeout_ms = 500; chan_options.max_retry = 0; @@ -1090,17 +1112,17 @@ TEST_F(RdmaTest, server_miss_after_ack) { usleep(100000); SocketUniquePtr s; ASSERT_EQ(0, Socket::Address(cntl._single_server_id, &s)); - ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); butil::fd_guard acc_fd(accept(sockfd, NULL, NULL)); ASSERT_TRUE(acc_fd >= 0); - uint8_t data[RDMA_HELLO_MSG_LEN]; - ASSERT_EQ(RDMA_HELLO_MSG_LEN, read(acc_fd, data, RDMA_HELLO_MSG_LEN)); + uint8_t data[rdma::v2_wire::HELLO_MSG_LEN_MIN]; + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, read(acc_fd, data, rdma::v2_wire::HELLO_MSG_LEN_MIN)); - rdma::HelloMessage msg; - msg.msg_len = RDMA_HELLO_MSG_LEN; - msg.hello_ver = RDMA_HELLO_VERSION; - msg.impl_ver = RDMA_IMPL_VERSION; + rdma::v2_wire::HelloMessage msg{}; + msg.msg_len = rdma::v2_wire::HELLO_MSG_LEN_MIN; + msg.hello_ver = rdma::RDMA_HELLO_V2_VERSION; + msg.impl_ver = rdma::RDMA_IMPL_V2_VERSION; msg.sq_size = 16; msg.rq_size = 16; msg.block_size = 8192; @@ -1108,10 +1130,10 @@ TEST_F(RdmaTest, server_miss_after_ack) { msg.gid = rdma::GetRdmaGid(); memcpy(data, "RDMA", 4); msg.Serialize(data + 4); - ASSERT_EQ(RDMA_HELLO_MSG_LEN, write(acc_fd, data, RDMA_HELLO_MSG_LEN)); + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, write(acc_fd, data, rdma::v2_wire::HELLO_MSG_LEN_MIN)); usleep(100000); - ASSERT_EQ(rdma::RdmaEndpoint::ESTABLISHED, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::ESTABLISHED, static_cast(s->_transport.get())->_rdma_ep->_state); ASSERT_EQ(4, read(acc_fd, data, 4)); uint32_t* tmp = (uint32_t*)data; ASSERT_EQ(1, butil::NetToHost32(*tmp)); @@ -1126,7 +1148,7 @@ TEST_F(RdmaTest, server_close_after_ack) { Channel channel; ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 500; 
chan_options.timeout_ms = 500; chan_options.max_retry = 0; @@ -1142,17 +1164,17 @@ TEST_F(RdmaTest, server_close_after_ack) { usleep(100000); SocketUniquePtr s; ASSERT_EQ(0, Socket::Address(cntl._single_server_id, &s)); - ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); butil::fd_guard acc_fd(accept(sockfd, NULL, NULL)); ASSERT_TRUE(acc_fd >= 0); - uint8_t data[RDMA_HELLO_MSG_LEN]; - ASSERT_EQ(RDMA_HELLO_MSG_LEN, read(acc_fd, data, RDMA_HELLO_MSG_LEN)); + uint8_t data[rdma::v2_wire::HELLO_MSG_LEN_MIN]; + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, read(acc_fd, data, rdma::v2_wire::HELLO_MSG_LEN_MIN)); - rdma::HelloMessage msg; - msg.msg_len = RDMA_HELLO_MSG_LEN; - msg.hello_ver = RDMA_HELLO_VERSION; - msg.impl_ver = RDMA_IMPL_VERSION; + rdma::v2_wire::HelloMessage msg{}; + msg.msg_len = rdma::v2_wire::HELLO_MSG_LEN_MIN; + msg.hello_ver = rdma::RDMA_HELLO_V2_VERSION; + msg.impl_ver = rdma::RDMA_IMPL_V2_VERSION; msg.sq_size = 16; msg.rq_size = 16; msg.block_size = 8192; @@ -1160,10 +1182,10 @@ TEST_F(RdmaTest, server_close_after_ack) { msg.gid = rdma::GetRdmaGid(); memcpy(data, "RDMA", 4); msg.Serialize(data + 4); - ASSERT_EQ(RDMA_HELLO_MSG_LEN, write(acc_fd, data, RDMA_HELLO_MSG_LEN)); + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, write(acc_fd, data, rdma::v2_wire::HELLO_MSG_LEN_MIN)); usleep(100000); - ASSERT_EQ(rdma::RdmaEndpoint::ESTABLISHED, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::ESTABLISHED, static_cast(s->_transport.get())->_rdma_ep->_state); ASSERT_EQ(4, read(acc_fd, data, 4)); uint32_t* tmp = (uint32_t*)data; ASSERT_EQ(1, butil::NetToHost32(*tmp)); @@ -1179,7 +1201,7 @@ TEST_F(RdmaTest, server_send_data_on_tcp_after_ack) { Channel channel; ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 500; chan_options.timeout_ms = 500; 
chan_options.max_retry = 0; @@ -1195,17 +1217,17 @@ TEST_F(RdmaTest, server_send_data_on_tcp_after_ack) { usleep(100000); SocketUniquePtr s; ASSERT_EQ(0, Socket::Address(cntl._single_server_id, &s)); - ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); butil::fd_guard acc_fd(accept(sockfd, NULL, NULL)); ASSERT_TRUE(acc_fd >= 0); - uint8_t data[RDMA_HELLO_MSG_LEN]; - ASSERT_EQ(RDMA_HELLO_MSG_LEN, read(acc_fd, data, RDMA_HELLO_MSG_LEN)); + uint8_t data[rdma::v2_wire::HELLO_MSG_LEN_MIN]; + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, read(acc_fd, data, rdma::v2_wire::HELLO_MSG_LEN_MIN)); - rdma::HelloMessage msg; - msg.msg_len = RDMA_HELLO_MSG_LEN; - msg.hello_ver = RDMA_HELLO_VERSION; - msg.impl_ver = RDMA_IMPL_VERSION; + rdma::v2_wire::HelloMessage msg{}; + msg.msg_len = rdma::v2_wire::HELLO_MSG_LEN_MIN; + msg.hello_ver = rdma::RDMA_HELLO_V2_VERSION; + msg.impl_ver = rdma::RDMA_IMPL_V2_VERSION; msg.sq_size = 16; msg.rq_size = 16; msg.block_size = 8192; @@ -1213,23 +1235,528 @@ TEST_F(RdmaTest, server_send_data_on_tcp_after_ack) { msg.gid = rdma::GetRdmaGid(); memcpy(data, "RDMA", 4); msg.Serialize(data + 4); - ASSERT_EQ(RDMA_HELLO_MSG_LEN, write(acc_fd, data, RDMA_HELLO_MSG_LEN)); + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, write(acc_fd, data, rdma::v2_wire::HELLO_MSG_LEN_MIN)); usleep(100000); - ASSERT_EQ(rdma::RdmaEndpoint::ESTABLISHED, s->_rdma_ep->_state); - ASSERT_EQ(RDMA_HELLO_MSG_LEN, write(acc_fd, data, RDMA_HELLO_MSG_LEN)); + ASSERT_EQ(rdma::RdmaEndpoint::ESTABLISHED, static_cast(s->_transport.get())->_rdma_ep->_state); + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, write(acc_fd, data, rdma::v2_wire::HELLO_MSG_LEN_MIN)); bthread_id_join(cntl.call_id()); ASSERT_EQ(EPROTO, cntl.ErrorCode()); } + +TEST_F(RdmaTest, v2_client_hello_bytes_baseline) { + butil::fd_guard sockfd(butil::tcp_listen(g_ep)); + EXPECT_TRUE(sockfd >= 0); + + Channel channel; + 
ChannelOptions chan_options; + chan_options.socket_mode = SOCKET_MODE_RDMA; + chan_options.connect_timeout_ms = 500; + chan_options.timeout_ms = 500; + chan_options.max_retry = 0; + ASSERT_EQ(0, channel.Init(g_ep, &chan_options)); + + Controller cntl; + test::EchoRequest req; + test::EchoResponse res; + req.set_message(__FUNCTION__); + google::protobuf::Closure* done = DoNothing(); + ::test::EchoService::Stub(&channel).Echo(&cntl, &req, &res, done); + + usleep(100000); + SocketUniquePtr s; + ASSERT_EQ(0, Socket::Address(cntl._single_server_id, &s)); + + butil::fd_guard acc_fd(accept(sockfd, NULL, NULL)); + ASSERT_TRUE(acc_fd >= 0); + + uint8_t data[rdma::v2_wire::HELLO_MSG_LEN_MIN]; + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, read(acc_fd, data, rdma::v2_wire::HELLO_MSG_LEN_MIN)); + + // [0..4) magic + ASSERT_EQ(0, memcmp(data, "RDMA", 4)); + // [4..6) msg_len, big-endian uint16 == 40 + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, + (size_t)(((uint16_t)data[4] << 8) | (uint16_t)data[5])); + // [6..8) hello_ver, big-endian uint16 == rdma::RDMA_HELLO_V2_VERSION + ASSERT_EQ(rdma::RDMA_HELLO_V2_VERSION, + (uint16_t)(((uint16_t)data[6] << 8) | (uint16_t)data[7])); + // [8..10) impl_ver, big-endian uint16 == rdma::RDMA_IMPL_V2_VERSION + ASSERT_EQ(rdma::RDMA_IMPL_V2_VERSION, + (uint16_t)(((uint16_t)data[8] << 8) | (uint16_t)data[9])); + + rdma::v2_wire::HelloMessage msg{}; + msg.Deserialize(data + 4); + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, msg.msg_len); + ASSERT_EQ(rdma::RDMA_HELLO_V2_VERSION, msg.hello_ver); + ASSERT_EQ(rdma::RDMA_IMPL_V2_VERSION, msg.impl_ver); + + bthread_id_join(cntl.call_id()); +} + +TEST_F(RdmaTest, v2_server_hello_bytes_baseline) { + StartServer(); + + sockaddr_in addr; + bzero((char*)&addr, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_port = htons(PORT); + + butil::fd_guard sockfd(socket(AF_INET, SOCK_STREAM, 0)); + ASSERT_TRUE(sockfd >= 0); + ASSERT_EQ(0, connect(sockfd, (sockaddr*)&addr, sizeof(sockaddr))); + usleep(100000); + 
Socket* s = GetSocketFromServer(0); + ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, static_cast(s->_transport.get())->_rdma_ep->_state); + + // Send a well-formed v2 hello so the server enters S_ACK_WAIT. + rdma::v2_wire::HelloMessage msg{}; + msg.msg_len = rdma::v2_wire::HELLO_MSG_LEN_MIN; + msg.hello_ver = rdma::RDMA_HELLO_V2_VERSION; + msg.impl_ver = rdma::RDMA_IMPL_V2_VERSION; + msg.sq_size = 16; + msg.rq_size = 16; + msg.block_size = 8192; + msg.qp_num = 0; + msg.gid = rdma::GetRdmaGid(); + + uint8_t data[rdma::v2_wire::HELLO_MSG_LEN_MIN]; + memcpy(data, "RDMA", 4); + msg.Serialize(data + 4); + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, write(sockfd, data, rdma::v2_wire::HELLO_MSG_LEN_MIN)); + usleep(100000); + ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); + + // Read server's reply hello and assert its byte-level layout. + uint8_t reply[rdma::v2_wire::HELLO_MSG_LEN_MIN]; + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, read(sockfd, reply, rdma::v2_wire::HELLO_MSG_LEN_MIN)); + + ASSERT_EQ(0, memcmp(reply, "RDMA", 4)); + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, + (size_t)(((uint16_t)reply[4] << 8) | (uint16_t)reply[5])); + ASSERT_EQ(rdma::RDMA_HELLO_V2_VERSION, + (uint16_t)(((uint16_t)reply[6] << 8) | (uint16_t)reply[7])); + ASSERT_EQ(rdma::RDMA_IMPL_V2_VERSION, + (uint16_t)(((uint16_t)reply[8] << 8) | (uint16_t)reply[9])); + + rdma::v2_wire::HelloMessage reply_msg{}; + reply_msg.Deserialize(reply + 4); + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, reply_msg.msg_len); + ASSERT_EQ(rdma::RDMA_HELLO_V2_VERSION, reply_msg.hello_ver); + ASSERT_EQ(rdma::RDMA_IMPL_V2_VERSION, reply_msg.impl_ver); + + // Drive the server into FALLBACK_TCP via ACK flags=0 so the test ends + // cleanly without requiring real RDMA hardware. 
+ uint32_t flags = butil::HostToNet32(0); + ASSERT_EQ(sizeof(flags), write(sockfd, &flags, sizeof(flags))); + usleep(100000); + ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, static_cast(s->_transport.get())->_rdma_ep->_state); + + sockfd.reset(-1); + usleep(100000); + ASSERT_EQ(NULL, GetSocketFromServer(0)); + + StopServer(); +} + +TEST_F(RdmaTest, v2_server_drains_tail_then_reads_ack) { + StartServer(); + + sockaddr_in addr; + bzero((char*)&addr, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_port = htons(PORT); + butil::fd_guard sockfd(socket(AF_INET, SOCK_STREAM, 0)); + ASSERT_TRUE(sockfd >= 0); + ASSERT_EQ(0, connect(sockfd, (sockaddr*)&addr, sizeof(sockaddr))); + usleep(100000); + Socket* s = GetSocketFromServer(0); + ASSERT_TRUE(s != NULL); + ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, static_cast(s->_transport.get())->_rdma_ep->_state); + + // Build a v2 hello with msg_len = 48 (40 base + 8B zero tail). + rdma::v2_wire::HelloMessage msg{}; + msg.msg_len = 48; + msg.hello_ver = rdma::RDMA_HELLO_V2_VERSION; + msg.impl_ver = rdma::RDMA_IMPL_V2_VERSION; + msg.sq_size = 16; + msg.rq_size = 16; + msg.block_size = 8192; + msg.qp_num = 0; + msg.gid = rdma::GetRdmaGid(); + + uint8_t buf[48]; + memcpy(buf, "RDMA", 4); + msg.Serialize(buf + 4); + memset(buf + 40, 0x00, 8); // 8B zero tail + ASSERT_EQ(48, write(sockfd, buf, 48)); + usleep(100000); + + // Send the real ACK (flags=1 = ACK_MSG_RDMA_OK). 
+ uint32_t flags = butil::HostToNet32(1); + ASSERT_EQ(sizeof(flags), write(sockfd, &flags, sizeof(flags))); + usleep(100000); + + ASSERT_EQ(rdma::RdmaEndpoint::ESTABLISHED, static_cast(s->_transport.get())->_rdma_ep->_state); + + sockfd.reset(-1); + usleep(100000); + ASSERT_EQ(NULL, GetSocketFromServer(0)); + + StopServer(); +} + +TEST_F(RdmaTest, v2_server_rejects_oversized_msg_len) { + StartServer(); + + sockaddr_in addr; + bzero((char*)&addr, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_port = htons(PORT); + butil::fd_guard sockfd(socket(AF_INET, SOCK_STREAM, 0)); + ASSERT_TRUE(sockfd >= 0); + ASSERT_EQ(0, connect(sockfd, (sockaddr*)&addr, sizeof(sockaddr))); + usleep(100000); + Socket* s = GetSocketFromServer(0); + ASSERT_TRUE(s != NULL); + ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, static_cast(s->_transport.get())->_rdma_ep->_state); + + // Build a v2 hello with msg_len = 4097 (HELLO_MSG_LEN_MAX + 1). + // We only send the 40B base; the server must reject before reading + // (and definitely before attempting to drain) any "tail". + rdma::v2_wire::HelloMessage msg{}; + msg.msg_len = 4097; + msg.hello_ver = rdma::RDMA_HELLO_V2_VERSION; + msg.impl_ver = rdma::RDMA_IMPL_V2_VERSION; + msg.sq_size = 16; + msg.rq_size = 16; + msg.block_size = 8192; + msg.qp_num = 0; + msg.gid = rdma::GetRdmaGid(); + + uint8_t buf[rdma::v2_wire::HELLO_MSG_LEN_MIN]; + memcpy(buf, "RDMA", 4); + msg.Serialize(buf + 4); + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, write(sockfd, buf, rdma::v2_wire::HELLO_MSG_LEN_MIN)); + usleep(100000); + + + ASSERT_EQ(NULL, GetSocketFromServer(0)); + + sockfd.reset(-1); + usleep(100000); + + StopServer(); +} + +// RAII for FLAGS_rdma_client_handshake_version: lets us flip the +// client-side handshake version for a single test and restore it on +// scope exit so subsequent tests stay on the v2 default. 
+class HandshakeVersionFlag { +public: + explicit HandshakeVersionFlag(int v) + : _saved(rdma::FLAGS_rdma_client_handshake_version) { + rdma::FLAGS_rdma_client_handshake_version = v; + } + ~HandshakeVersionFlag() { + rdma::FLAGS_rdma_client_handshake_version = _saved; + } +private: + int _saved; +}; + +// Build a v3 wire packet from an RdmaHello: "RDM3" + pb_size_be + body. +std::string MakeV3Packet(const rdma::RdmaHello& msg) { + std::string body; + EXPECT_TRUE(msg.SerializeToString(&body)); + std::string packet; + packet.reserve(4 + 4 + body.size()); + packet.append("RDM3", 4); + uint32_t pb_size_be = + butil::HostToNet32(static_cast(body.size())); + packet.append(reinterpret_cast(&pb_size_be), 4); + packet.append(body); + return packet; +} + +// Build a fully-valid RdmaHello: all 6 required fields are set, with +// values that pass RdmaHelloV3Wire::RdmaHelloValid(). +// - block_size = 8192 (>= MIN_BLOCK_SIZE) +// - sq_size / rq_size = 16 (>= MIN_QP_SIZE) +// - gid = exactly 16B (sizeof(ibv_gid)) +// - qp_num = 0 (allowed because g_skip_rdma_init in UT) +rdma::RdmaHello MakeValidV3Hello() { + rdma::RdmaHello msg; + msg.set_block_size(8192); + msg.set_sq_size(16); + msg.set_rq_size(16); + msg.set_lid(0); + ibv_gid gid = rdma::GetRdmaGid(); + msg.set_gid(std::string(reinterpret_cast(gid.raw), + sizeof(gid.raw))); + msg.set_qp_num(0); + return msg; +} + + +TEST_F(RdmaTest, v3_client_hello_bytes_baseline) { + HandshakeVersionFlag _hsv(3); + + butil::fd_guard sockfd(butil::tcp_listen(g_ep)); + EXPECT_TRUE(sockfd >= 0); + + Channel channel; + ChannelOptions chan_options; + chan_options.socket_mode = SOCKET_MODE_RDMA; + chan_options.connect_timeout_ms = 500; + chan_options.timeout_ms = 500; + chan_options.max_retry = 0; + ASSERT_EQ(0, channel.Init(g_ep, &chan_options)); + + Controller cntl; + test::EchoRequest req; + test::EchoResponse res; + req.set_message(__FUNCTION__); + google::protobuf::Closure* done = DoNothing(); + ::test::EchoService::Stub(&channel).Echo(&cntl, 
&req, &res, done); + + butil::fd_guard acc_fd(accept(sockfd, NULL, NULL)); + ASSERT_TRUE(acc_fd >= 0); + + // [0..4) magic "RDM3" + uint8_t magic[4]; + ASSERT_EQ(4, read(acc_fd, magic, 4)); + ASSERT_EQ(0, memcmp(magic, "RDM3", 4)); + + // [4..8) pb_size, big-endian uint32, must be in (0, 4096] + uint8_t size_buf[4]; + ASSERT_EQ(4, read(acc_fd, size_buf, 4)); + uint32_t pb_size = + butil::NetToHost32(*reinterpret_cast(size_buf)); + ASSERT_GT(pb_size, 0u); + ASSERT_LE(pb_size, 4096u); + + // [8..8+pb_size) RdmaHello protobuf body. + std::string body(pb_size, '\0'); + ASSERT_EQ((ssize_t)pb_size, read(acc_fd, &body[0], pb_size)); + rdma::RdmaHello msg; + ASSERT_TRUE(msg.ParseFromString(body)); + + // All 6 required fields must be present (ParseFromString would + // have already returned false otherwise). + ASSERT_TRUE(msg.has_block_size()); + ASSERT_TRUE(msg.has_sq_size()); + ASSERT_TRUE(msg.has_rq_size()); + ASSERT_TRUE(msg.has_lid()); + ASSERT_TRUE(msg.has_gid()); + ASSERT_TRUE(msg.has_qp_num()); + // gid wire encoding must be exactly 16 bytes (sizeof(ibv_gid)). + ASSERT_EQ(sizeof(ibv_gid), msg.gid().size()); + + // Let the RPC time out and release resources. + bthread_id_join(cntl.call_id()); +} + +TEST_F(RdmaTest, v3_server_hello_bytes_baseline) { + StartServer(); + + sockaddr_in addr; + bzero((char*)&addr, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_port = htons(PORT); + + butil::fd_guard sockfd(socket(AF_INET, SOCK_STREAM, 0)); + ASSERT_TRUE(sockfd >= 0); + ASSERT_EQ(0, connect(sockfd, (sockaddr*)&addr, sizeof(sockaddr))); + usleep(100000); + Socket* s = GetSocketFromServer(0); + ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, static_cast(s->_transport.get())->_rdma_ep->_state); + + // Send a valid v3 hello. 
+ std::string packet = MakeV3Packet(MakeValidV3Hello()); + ASSERT_EQ((ssize_t)packet.size(), + write(sockfd, packet.data(), packet.size())); + usleep(100000); + ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); + + // Read server's reply hello: 4B magic + 4B pb_size + body. + uint8_t reply_magic[4]; + ASSERT_EQ(4, read(sockfd, reply_magic, 4)); + ASSERT_EQ(0, memcmp(reply_magic, "RDM3", 4)); + + uint8_t size_buf[4]; + ASSERT_EQ(4, read(sockfd, size_buf, 4)); + uint32_t pb_size = + butil::NetToHost32(*reinterpret_cast(size_buf)); + ASSERT_GT(pb_size, 0u); + ASSERT_LE(pb_size, 4096u); + + std::string body(pb_size, '\0'); + ASSERT_EQ((ssize_t)pb_size, read(sockfd, &body[0], pb_size)); + rdma::RdmaHello reply; + ASSERT_TRUE(reply.ParseFromString(body)); + ASSERT_TRUE(reply.has_block_size()); + ASSERT_TRUE(reply.has_sq_size()); + ASSERT_TRUE(reply.has_rq_size()); + ASSERT_TRUE(reply.has_gid()); + ASSERT_EQ(sizeof(ibv_gid), reply.gid().size()); + + // Drive the server into FALLBACK_TCP via ACK flags=0 so the test ends + // cleanly without requiring real RDMA hardware. + uint32_t flags = butil::HostToNet32(0); + ASSERT_EQ((ssize_t)sizeof(flags), + write(sockfd, &flags, sizeof(flags))); + usleep(100000); + ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, static_cast(s->_transport.get())->_rdma_ep->_state); + + sockfd.reset(-1); + usleep(100000); + ASSERT_EQ(NULL, GetSocketFromServer(0)); + + StopServer(); +} + +TEST_F(RdmaTest, v3_server_rejects_zero_pb_size) { + StartServer(); + + sockaddr_in addr; + bzero((char*)&addr, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_port = htons(PORT); + butil::fd_guard sockfd(socket(AF_INET, SOCK_STREAM, 0)); + ASSERT_TRUE(sockfd >= 0); + ASSERT_EQ(0, connect(sockfd, (sockaddr*)&addr, sizeof(sockaddr))); + usleep(100000); + Socket* s = GetSocketFromServer(0); + ASSERT_TRUE(s != NULL); + + // "RDM3" + pb_size = 0 (4B big-endian zero). 
+ uint8_t buf[8] = {'R', 'D', 'M', '3', 0, 0, 0, 0}; + ASSERT_EQ(8, write(sockfd, buf, 8)); + usleep(100000); + + ASSERT_EQ(NULL, GetSocketFromServer(0)); + + sockfd.reset(-1); + StopServer(); +} + +TEST_F(RdmaTest, v3_server_rejects_oversized_pb_size) { + StartServer(); + + sockaddr_in addr; + bzero((char*)&addr, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_port = htons(PORT); + butil::fd_guard sockfd(socket(AF_INET, SOCK_STREAM, 0)); + ASSERT_TRUE(sockfd >= 0); + ASSERT_EQ(0, connect(sockfd, (sockaddr*)&addr, sizeof(sockaddr))); + usleep(100000); + Socket* s = GetSocketFromServer(0); + ASSERT_TRUE(s != NULL); + + uint8_t buf[8]; + memcpy(buf, "RDM3", 4); + uint32_t pb_size_be = butil::HostToNet32(4097); + memcpy(buf + 4, &pb_size_be, 4); + ASSERT_EQ(8, write(sockfd, buf, 8)); + usleep(100000); + + ASSERT_EQ(NULL, GetSocketFromServer(0)); + + sockfd.reset(-1); + StopServer(); +} + +TEST_F(RdmaTest, v3_server_rejects_invalid_pb_bytes) { + StartServer(); + + sockaddr_in addr; + bzero((char*)&addr, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_port = htons(PORT); + butil::fd_guard sockfd(socket(AF_INET, SOCK_STREAM, 0)); + ASSERT_TRUE(sockfd >= 0); + ASSERT_EQ(0, connect(sockfd, (sockaddr*)&addr, sizeof(sockaddr))); + usleep(100000); + Socket* s = GetSocketFromServer(0); + ASSERT_TRUE(s != NULL); + + // "RDM3" + pb_size = 8 + 8 bytes of 0xff (invalid protobuf body). 
+ uint8_t buf[16]; + memcpy(buf, "RDM3", 4); + uint32_t pb_size_be = butil::HostToNet32(8); + memcpy(buf + 4, &pb_size_be, 4); + memset(buf + 8, 0xff, 8); + ASSERT_EQ(16, write(sockfd, buf, 16)); + usleep(100000); + + ASSERT_EQ(NULL, GetSocketFromServer(0)); + + sockfd.reset(-1); + StopServer(); +} + +TEST_F(RdmaTest, v3_server_invalid_sq_size_falls_back) { + StartServer(); + + sockaddr_in addr; + bzero((char*)&addr, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_port = htons(PORT); + butil::fd_guard sockfd(socket(AF_INET, SOCK_STREAM, 0)); + ASSERT_TRUE(sockfd >= 0); + ASSERT_EQ(0, connect(sockfd, (sockaddr*)&addr, sizeof(sockaddr))); + usleep(100000); + Socket* s = GetSocketFromServer(0); + ASSERT_TRUE(s != NULL); + + rdma::RdmaHello msg = MakeValidV3Hello(); + msg.set_sq_size(0); // invalid: < MIN_QP_SIZE (16) + std::string packet = MakeV3Packet(msg); + ASSERT_EQ((ssize_t)packet.size(), + write(sockfd, packet.data(), packet.size())); + usleep(100000); + + // Server validated the hello as invalid -> _rdma_state = RDMA_OFF, + // but still proceeds to S_ACK_WAIT (sends its own reply hello). + ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); + ASSERT_EQ(RdmaTransport::RDMA_OFF, static_cast(s->_transport.get())->_rdma_state); + + // Drain server's reply hello (content not asserted here; covered + // by v3_server_hello_bytes_baseline). + uint8_t reply_hdr[8]; + ASSERT_EQ(8, read(sockfd, reply_hdr, 8)); + ASSERT_EQ(0, memcmp(reply_hdr, "RDM3", 4)); + uint32_t reply_pb_size = butil::NetToHost32( + *reinterpret_cast(reply_hdr + 4)); + std::string reply_body(reply_pb_size, '\0'); + ASSERT_EQ((ssize_t)reply_pb_size, + read(sockfd, &reply_body[0], reply_pb_size)); + + // Client ACK flags=0 -> server settles into FALLBACK_TCP. 
+ uint32_t flags = butil::HostToNet32(0); + ASSERT_EQ((ssize_t)sizeof(flags), + write(sockfd, &flags, sizeof(flags))); + usleep(100000); + ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, static_cast(s->_transport.get())->_rdma_ep->_state); + + sockfd.reset(-1); + usleep(100000); + ASSERT_EQ(NULL, GetSocketFromServer(0)); + + StopServer(); +} + TEST_F(RdmaTest, try_global_disable_rdma) { StartServer(); rdma::g_rdma_available.store(false, butil::memory_order_relaxed); Channel channel; ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 500; chan_options.timeout_ms = 500; chan_options.max_retry = 0; @@ -1245,7 +1772,7 @@ TEST_F(RdmaTest, try_global_disable_rdma) { usleep(100000); SocketUniquePtr s; ASSERT_EQ(0, Socket::Address(cntl._single_server_id, &s)); - ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, static_cast(s->_transport.get())->_rdma_ep->_state); bthread_id_join(cntl.call_id()); ASSERT_EQ(0, cntl.ErrorCode()); @@ -1256,7 +1783,7 @@ TEST_F(RdmaTest, try_global_disable_rdma) { TEST_F(RdmaTest, server_option_invalid) { Server server; ServerOptions options; - options.use_rdma = true; + options.socket_mode = SOCKET_MODE_RDMA; // rtmp and rdma are incompatible options.rtmp_service = (RtmpService*)1; @@ -1281,7 +1808,7 @@ TEST_F(RdmaTest, server_option_invalid) { TEST_F(RdmaTest, channel_option_invalid) { Channel channel; ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; // rtmp and rdma are incompatible chan_options.protocol = "rtmp"; @@ -1342,7 +1869,7 @@ TEST_F(RdmaTest, channel_option_invalid) { ASSERT_EQ(-1, channel.Init(g_ep, &chan_options)); } -TEST_F(RdmaTest, rdma_client_to_rdma_server) { +TEST_P(RdmaRpcTest, rdma_client_to_rdma_server) { if (!FLAGS_rdma_test_enable) { return; } @@ -1351,7 +1878,7 @@ TEST_F(RdmaTest, rdma_client_to_rdma_server) { 
Channel channel; ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 500; chan_options.timeout_ms = 500; chan_options.max_retry = 0; @@ -1362,14 +1889,14 @@ TEST_F(RdmaTest, rdma_client_to_rdma_server) { req.set_message(__FUNCTION__); google::protobuf::Closure* done = DoNothing(); ::test::EchoService::Stub(&channel).Echo(&cntl, &req, &res, done); - usleep(100000); + // usleep(100000); bthread_id_join(cntl.call_id()); ASSERT_EQ(0, cntl.ErrorCode()); StopServer(); } -TEST_F(RdmaTest, tcp_client_to_tcp_server) { +TEST_P(RdmaRpcTest, tcp_client_to_tcp_server) { StartServer(false); Channel channel; @@ -1391,7 +1918,7 @@ TEST_F(RdmaTest, tcp_client_to_tcp_server) { StopServer(); } -TEST_F(RdmaTest, tcp_client_to_rdma_server) { +TEST_P(RdmaRpcTest, tcp_client_to_rdma_server) { StartServer(); Channel channel; @@ -1413,12 +1940,12 @@ TEST_F(RdmaTest, tcp_client_to_rdma_server) { StopServer(); } -TEST_F(RdmaTest, rdma_client_to_tcp_server) { +TEST_P(RdmaRpcTest, rdma_client_to_tcp_server) { StartServer(false); Channel channel; ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 500; chan_options.timeout_ms = 500; chan_options.max_retry = 0; @@ -1440,12 +1967,12 @@ static const int RPC_NUM = 1024; void DumpRdmaEndpointInfo(Socket* client, Socket* server) { std::cout << std::endl << "client:"; - client->_rdma_ep->DebugInfo(std::cout); + static_cast(client->_transport.get())->_rdma_ep->DebugInfo(std::cout); std::cout << std::endl << "server:"; - server->_rdma_ep->DebugInfo(std::cout); + static_cast(server->_transport.get())->_rdma_ep->DebugInfo(std::cout); } -TEST_F(RdmaTest, send_rpcs_in_one_qp) { +TEST_P(RdmaRpcTest, send_rpcs_in_one_qp) { if (!FLAGS_rdma_test_enable) { return; } @@ -1454,9 +1981,9 @@ TEST_F(RdmaTest, send_rpcs_in_one_qp) { Channel channel; ChannelOptions chan_options; - 
chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 500; - chan_options.timeout_ms = 5000; + chan_options.timeout_ms = 50000; chan_options.max_retry = 0; ASSERT_EQ(0, channel.Init(g_ep, &chan_options)); Controller cntl[RPC_NUM]; @@ -1516,50 +2043,57 @@ TEST_F(RdmaTest, send_rpcs_in_one_qp) { Socket* m = GetSocketFromServer(0); DumpRdmaEndpointInfo(s.get(), m); } - ASSERT_TRUE(0 == cntl[i].ErrorCode() || EOVERCROWDED == cntl[i].ErrorCode()) - << "req[" << i << "] " << berror(cntl[i].ErrorCode()); + ASSERT_TRUE(0 == cntl[i].ErrorCode() || + EOVERCROWDED == cntl[i].ErrorCode()) << "req[" << i << "] " << berror(cntl[i].ErrorCode()); } + SocketUniquePtr s; + ASSERT_EQ(0, Socket::Address(cntl[0]._single_server_id, &s)); + Socket* m = GetSocketFromServer(0); + DumpRdmaEndpointInfo(s.get(), m); + StopServer(); } -TEST_F(RdmaTest, send_rpc_in_many_qp) { +TEST_P(RdmaRpcTest, send_rpc_in_many_qp) { if (!FLAGS_rdma_test_enable) { return; } + butil::ip_t ip; + ASSERT_EQ(0, butil::str2ip(g_ip.c_str(), &ip)); + Server server[100]; MyEchoService svc[100]; int num = 100; + butil::EndPoint server_eps[100]; for (int i = 0; i < num; ++i) { ServerOptions options; - options.use_rdma = true; + options.socket_mode = SOCKET_MODE_RDMA; options.idle_timeout_sec = 1; options.max_concurrency = 0; options.internal_port = -1; server[i].AddService(&svc[i], SERVER_DOESNT_OWN_SERVICE); - EXPECT_EQ(0, server[i].Start(i + 8000, &options)); + ASSERT_EQ(0, server[i].Start(0, &options)); + server_eps[i] = butil::EndPoint(ip, server[i].listen_address().port); } int port = 0; butil::IOBuf attach; attach.resize(4096); ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 500; - chan_options.timeout_ms = 500; + chan_options.timeout_ms = 100000; chan_options.max_retry = 0; Channel channel[RPC_NUM]; Server* svr[RPC_NUM]; Controller cntl[RPC_NUM]; test::EchoRequest 
req[RPC_NUM]; test::EchoResponse res[RPC_NUM]; - butil::ip_t ip; - butil::str2ip(g_ip.c_str(), &ip); for (int i = 0; i < RPC_NUM; ++i) { svr[i] = &server[i % num]; - butil::EndPoint ep(ip, 8000 + ((port++) % num)); - ASSERT_EQ(0, channel[i].Init(ep, &chan_options)); + ASSERT_EQ(0, channel[i].Init(server_eps[(port++) % num], &chan_options)); req[i].set_message(__FUNCTION__); cntl[i].request_attachment().append(attach); google::protobuf::Closure* done = DoNothing(); @@ -1569,16 +2103,19 @@ TEST_F(RdmaTest, send_rpc_in_many_qp) { bthread_id_join(cntl[i].call_id()); if (cntl[i].ErrorCode() == ERPCTIMEDOUT) { SocketUniquePtr s; - ASSERT_EQ(0, Socket::Address(cntl[i]._single_server_id, &s)); - std::vector sids; - svr[i]->_am->ListConnections(&sids); - for (size_t i = 0; i < sids.size(); ++i) { - SocketUniquePtr m; - ASSERT_EQ(0, Socket::AddressFailedAsWell(sids[i], &m)); - DumpRdmaEndpointInfo(s.get(), m.get()); + EXPECT_EQ(0, Socket::Address(cntl[i]._single_server_id, &s)); + if (s && svr[i] && svr[i]->_am) { + std::vector sids; + svr[i]->_am->ListConnections(&sids); + for (size_t j = 0; j < sids.size(); ++j) { + SocketUniquePtr m; + if (Socket::AddressFailedAsWell(sids[j], &m) == 0) { + DumpRdmaEndpointInfo(s.get(), m.get()); + } + } } } - ASSERT_EQ(0, cntl[i].ErrorCode()) << "req[" << i << "]"; + EXPECT_EQ(0, cntl[i].ErrorCode()) << "req[" << i << "]"; } for (int i = 0; i < num; ++i) { @@ -1587,7 +2124,7 @@ TEST_F(RdmaTest, send_rpc_in_many_qp) { } } -TEST_F(RdmaTest, send_rpcs_as_pooled_connection) { +TEST_P(RdmaRpcTest, send_rpcs_as_pooled_connection) { if (!FLAGS_rdma_test_enable) { return; } @@ -1596,7 +2133,7 @@ TEST_F(RdmaTest, send_rpcs_as_pooled_connection) { Channel channel; ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 30000; // it may very slow chan_options.timeout_ms = 30000; chan_options.max_retry = 0; @@ -1628,7 +2165,7 @@ TEST_F(RdmaTest, 
send_rpcs_as_pooled_connection) { StopServer(); } -TEST_F(RdmaTest, send_rpcs_as_short_connection) { +TEST_P(RdmaRpcTest, send_rpcs_as_short_connection) { if (!FLAGS_rdma_test_enable) { return; } @@ -1637,7 +2174,7 @@ TEST_F(RdmaTest, send_rpcs_as_short_connection) { Channel channel; ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 30000; // it may very slow chan_options.timeout_ms = 30000; chan_options.max_retry = 0; @@ -1669,7 +2206,7 @@ TEST_F(RdmaTest, send_rpcs_as_short_connection) { StopServer(); } -TEST_F(RdmaTest, server_stop_during_rpc) { +TEST_P(RdmaRpcTest, server_stop_during_rpc) { if (!FLAGS_rdma_test_enable) { return; } @@ -1678,7 +2215,7 @@ TEST_F(RdmaTest, server_stop_during_rpc) { Channel channel; ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 500; chan_options.timeout_ms = 3000; chan_options.max_retry = 0; @@ -1707,7 +2244,7 @@ TEST_F(RdmaTest, server_stop_during_rpc) { } } -TEST_F(RdmaTest, server_close_during_rpc) { +TEST_P(RdmaRpcTest, server_close_during_rpc) { if (!FLAGS_rdma_test_enable) { return; } @@ -1716,7 +2253,7 @@ TEST_F(RdmaTest, server_close_during_rpc) { Channel channel; ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 500; chan_options.timeout_ms = 3000; chan_options.max_retry = 0; @@ -1749,7 +2286,7 @@ TEST_F(RdmaTest, server_close_during_rpc) { StopServer(); } -TEST_F(RdmaTest, client_close_during_rpc) { +TEST_P(RdmaRpcTest, client_close_during_rpc) { if (!FLAGS_rdma_test_enable) { return; } @@ -1758,7 +2295,7 @@ TEST_F(RdmaTest, client_close_during_rpc) { Channel channel; ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 500; chan_options.timeout_ms = 3000; 
chan_options.max_retry = 0; @@ -1789,7 +2326,7 @@ TEST_F(RdmaTest, client_close_during_rpc) { StopServer(); } -TEST_F(RdmaTest, verbs_error_handling) { +TEST_P(RdmaRpcTest, verbs_error_handling) { if (!FLAGS_rdma_test_enable) { return; } @@ -1798,7 +2335,7 @@ TEST_F(RdmaTest, verbs_error_handling) { Channel channel; ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 500; chan_options.timeout_ms = 500; chan_options.max_retry = 0; @@ -1826,7 +2363,8 @@ TEST_F(RdmaTest, verbs_error_handling) { wr.sg_list = &sge; wr.num_sge = 1; ibv_send_wr* bad = NULL; - ibv_post_send(s->_rdma_ep->_resource->qp, &wr, &bad); + auto rdma_transport = static_cast(s->_transport.get()); + ibv_post_send(rdma_transport->_rdma_ep->_resource->qp, &wr, &bad); bthread_id_join(cntl.call_id()); ASSERT_EQ(ERDMA, cntl.ErrorCode()); free(buf); @@ -1834,7 +2372,7 @@ TEST_F(RdmaTest, verbs_error_handling) { StopServer(); } -TEST_F(RdmaTest, rdma_use_parallel_channel) { +TEST_P(RdmaRpcTest, rdma_use_parallel_channel) { if (!FLAGS_rdma_test_enable) { return; } @@ -1845,13 +2383,14 @@ TEST_F(RdmaTest, rdma_use_parallel_channel) { Channel subchans[NCHANS]; ParallelChannel channel; ChannelOptions opts; - opts.use_rdma = true; + opts.socket_mode = SOCKET_MODE_RDMA; for (size_t i = 0; i < NCHANS; ++i) { ASSERT_EQ(0, subchans[i].Init(_naming_url.c_str(), "rR", &opts)); ASSERT_EQ(0, channel.AddChannel( &subchans[i], DOESNT_OWN_CHANNEL, NULL, NULL)); } + ASSERT_EQ(0, channel.Init(NULL)); Controller cntl; test::EchoRequest req; @@ -1865,7 +2404,7 @@ TEST_F(RdmaTest, rdma_use_parallel_channel) { StopServer(); } -TEST_F(RdmaTest, rdma_use_selective_channel) { +TEST_P(RdmaRpcTest, rdma_use_selective_channel) { if (!FLAGS_rdma_test_enable) { return; } @@ -1875,7 +2414,7 @@ TEST_F(RdmaTest, rdma_use_selective_channel) { const size_t NCHANS = 8; SelectiveChannel channel; ChannelOptions opts; - opts.use_rdma = true; + 
opts.socket_mode = SOCKET_MODE_RDMA; ASSERT_EQ(0, channel.Init("rr", &opts)); for (size_t i = 0; i < NCHANS; ++i) { Channel* subchan = new Channel; @@ -1897,7 +2436,7 @@ TEST_F(RdmaTest, rdma_use_selective_channel) { static void MockFree(void* buf) { } -TEST_F(RdmaTest, send_rpcs_with_user_defined_iobuf) { +TEST_P(RdmaRpcTest, send_rpcs_with_user_defined_iobuf) { if (!FLAGS_rdma_test_enable) { return; } @@ -1906,7 +2445,7 @@ TEST_F(RdmaTest, send_rpcs_with_user_defined_iobuf) { Channel channel; ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 500; chan_options.timeout_ms = 500; chan_options.max_retry = 0; @@ -1961,7 +2500,7 @@ TEST_F(RdmaTest, send_rpcs_with_user_defined_iobuf) { StopServer(); } -TEST_F(RdmaTest, try_memory_pool_empty) { +TEST_P(RdmaRpcTest, try_memory_pool_empty) { if (!FLAGS_rdma_test_enable) { return; } @@ -1970,7 +2509,7 @@ TEST_F(RdmaTest, try_memory_pool_empty) { Channel channel; ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 500; chan_options.timeout_ms = 60000; chan_options.max_retry = 0; @@ -2000,6 +2539,19 @@ TEST_F(RdmaTest, try_memory_pool_empty) { StopServer(); } +// Run every TEST_P(RdmaRpcTest, ...) above twice: once with the +// client-side handshake forced to v2 ("RDMA" magic + fixed-layout +// HelloMessage), once with v3 ("RDM3" magic + protobuf RdmaHello). +// The server always accepts both via magic-byte dispatch, so this +// proves the upper-layer RPC paths behave identically under either +// wire format. 
+INSTANTIATE_TEST_SUITE_P( + HandshakeVersion, RdmaRpcTest, + ::testing::Values(2, 3), + [](const ::testing::TestParamInfo& info) { + return std::string("v") + std::to_string(info.param); + }); + #endif // if BRPC_WITH_RDMA int main(int argc, char* argv[]) { diff --git a/test/bvar_percentile_unittest.cpp b/test/bvar_percentile_unittest.cpp index f647e272ba..d9d01846a1 100644 --- a/test/bvar_percentile_unittest.cpp +++ b/test/bvar_percentile_unittest.cpp @@ -28,6 +28,7 @@ class PercentileTest : public testing::Test { void TearDown() {} }; +#if !WITH_BABYLON_COUNTER TEST_F(PercentileTest, add) { bvar::detail::Percentile p; for (int j = 0; j < 10; ++j) { @@ -51,6 +52,7 @@ TEST_F(PercentileTest, add) { b.describe(out); } } +#endif // !WITH_BABYLON_COUNTER TEST_F(PercentileTest, merge1) { // Merge 2 PercentileIntervals b1 and b2. b2 has double SAMPLE_SIZE