diff --git a/src/brpc/controller.cpp b/src/brpc/controller.cpp index 2522745403..dfa4ce1f55 100644 --- a/src/brpc/controller.cpp +++ b/src/brpc/controller.cpp @@ -274,6 +274,7 @@ void Controller::ResetPods() { _request_stream = INVALID_STREAM_ID; _response_stream = INVALID_STREAM_ID; _remote_stream_settings = NULL; + _auth_flags = 0; } Controller::Call::Call(Controller::Call* rhs) @@ -1162,7 +1163,7 @@ void Controller::IssueRPC(int64_t start_realtime_us) { wopt.id_wait = cid; wopt.abstime = pabstime; wopt.pipelined_count = _pipelined_count; - wopt.with_auth = has_flag(FLAGS_REQUEST_WITH_AUTH); + wopt.auth_flags = _auth_flags; wopt.ignore_eovercrowded = has_flag(FLAGS_IGNORE_EOVERCROWDED); int rc; size_t packet_size = 0; diff --git a/src/brpc/controller.h b/src/brpc/controller.h index 7fb3596520..41f1267de7 100755 --- a/src/brpc/controller.h +++ b/src/brpc/controller.h @@ -138,7 +138,6 @@ friend void policy::ProcessThriftRequest(InputMessageBase*); static const uint32_t FLAGS_PB_BYTES_TO_BASE64 = (1 << 11); static const uint32_t FLAGS_ALLOW_DONE_TO_RUN_IN_PLACE = (1 << 12); static const uint32_t FLAGS_USED_BY_RPC = (1 << 13); - static const uint32_t FLAGS_REQUEST_WITH_AUTH = (1 << 15); static const uint32_t FLAGS_PB_JSONIFY_EMPTY_ARRAY = (1 << 16); static const uint32_t FLAGS_ENABLED_CIRCUIT_BREAKER = (1 << 17); static const uint32_t FLAGS_ALWAYS_PRINT_PRIMITIVE_FIELDS = (1 << 18); @@ -807,6 +806,8 @@ friend void policy::ProcessThriftRequest(InputMessageBase*); // Thrift method name, only used when thrift protocol enabled std::string _thrift_method_name; + + uint32_t _auth_flags; }; // Advises the RPC system that the caller desires that the RPC call be diff --git a/src/brpc/details/controller_private_accessor.h b/src/brpc/details/controller_private_accessor.h index 990efc915d..1be7df8bb8 100644 --- a/src/brpc/details/controller_private_accessor.h +++ b/src/brpc/details/controller_private_accessor.h @@ -128,13 +128,11 @@ class ControllerPrivateAccessor { void set_readable_progressive_attachment(ReadableProgressiveAttachment* s) { _cntl->_rpa.reset(s); } - void add_with_auth() { - _cntl->add_flag(Controller::FLAGS_REQUEST_WITH_AUTH); + void set_auth_flags(uint32_t auth_flags) { + _cntl->_auth_flags = auth_flags; } - void clear_with_auth() { - _cntl->clear_flag(Controller::FLAGS_REQUEST_WITH_AUTH); - } + void clear_auth_flags() { _cntl->_auth_flags = 0; } std::string& protocol_param() { return _cntl->protocol_param(); } const std::string& protocol_param() const { return _cntl->protocol_param(); } diff --git a/src/brpc/policy/redis_authenticator.cpp b/src/brpc/policy/redis_authenticator.cpp index b6921f78ae..152d9832e6 100644 --- a/src/brpc/policy/redis_authenticator.cpp +++ b/src/brpc/policy/redis_authenticator.cpp @@ -28,7 +28,12 @@ namespace policy { int RedisAuthenticator::GenerateCredential(std::string* auth_str) const { butil::IOBuf buf; - brpc::RedisCommandFormat(&buf, "AUTH %s", passwd_.c_str()); + if (!passwd_.empty()) { + brpc::RedisCommandFormat(&buf, "AUTH %s", passwd_.c_str()); + } + if (db_ >= 0) { + brpc::RedisCommandFormat(&buf, "SELECT %d", db_); + } *auth_str = buf.to_string(); return 0; } diff --git a/src/brpc/policy/redis_authenticator.h b/src/brpc/policy/redis_authenticator.h index 739f93460c..8359811de1 100644 --- a/src/brpc/policy/redis_authenticator.h +++ b/src/brpc/policy/redis_authenticator.h @@ -26,8 +26,8 @@ namespace policy { // Request to redis for authentication. class RedisAuthenticator : public Authenticator { public: - RedisAuthenticator(const std::string& passwd) - : passwd_(passwd) {} + RedisAuthenticator(const std::string& passwd, int db = -1) + : passwd_(passwd), db_(db) {} int GenerateCredential(std::string* auth_str) const; @@ -36,8 +36,21 @@ class RedisAuthenticator : public Authenticator { return 0; } + uint32_t GetAuthFlags() const { + uint32_t n = 0; + if (!passwd_.empty()) { + ++n; + } + if (db_ >= 0) { + ++n; + } + return n; + } + private: const std::string passwd_; + + int db_; }; } // namespace policy diff --git a/src/brpc/policy/redis_protocol.cpp b/src/brpc/policy/redis_protocol.cpp index 7cd705a90b..67e5213307 100644 --- a/src/brpc/policy/redis_protocol.cpp +++ b/src/brpc/policy/redis_protocol.cpp @@ -20,6 +20,7 @@ #include // MethodDescriptor #include // Message #include +#include "brpc/policy/redis_authenticator.h" #include "butil/logging.h" // LOG() #include "butil/time.h" #include "butil/iobuf.h" // butil::IOBuf @@ -214,7 +215,7 @@ ParseResult ParseRedisMessage(butil::IOBuf* source, Socket* socket, socket->reset_parsing_context(msg); } - const int consume_count = (pi.with_auth ? 1 : pi.count); + const int consume_count = (pi.auth_flags ? pi.auth_flags : pi.count); ParseError err = msg->response.ConsumePartialIOBuf(*source, consume_count); if (err != PARSE_OK) { @@ -222,18 +223,21 @@ ParseResult ParseRedisMessage(butil::IOBuf* source, Socket* socket, return MakeParseError(err); } - if (pi.with_auth) { - if (msg->response.reply_size() != 1 || - !(msg->response.reply(0).type() == brpc::REDIS_REPLY_STATUS && - msg->response.reply(0).data().compare("OK") == 0)) { - LOG(ERROR) << "Redis Auth failed: " << msg->response; - return MakeParseError(PARSE_ERROR_NO_RESOURCE, - "Fail to authenticate with Redis"); + if (pi.auth_flags) { + for (int i = 0; i < (int)pi.auth_flags; ++i) { + if (i >= msg->response.reply_size() || + !(msg->response.reply(i).type() == + brpc::REDIS_REPLY_STATUS && + msg->response.reply(i).data().compare("OK") == 0)) { + LOG(ERROR) << "Redis Auth failed: " << msg->response; + return MakeParseError(PARSE_ERROR_NO_RESOURCE, + "Fail to authenticate with Redis"); + } } DestroyingPtr auth_msg( static_cast(socket->release_parsing_context())); - pi.with_auth = false; + pi.auth_flags = 0; continue; } @@ -333,9 +337,15 @@ void PackRedisRequest(butil::IOBuf* buf, return cntl->SetFailed(EREQUEST, "Fail to generate credential"); } buf->append(auth_str); - ControllerPrivateAccessor(cntl).add_with_auth(); + const RedisAuthenticator* redis_auth = + dynamic_cast(auth); + if (redis_auth == NULL) { + return cntl->SetFailed(EREQUEST, "Fail to generate credential"); + } + ControllerPrivateAccessor(cntl).set_auth_flags( + redis_auth->GetAuthFlags()); } else { - ControllerPrivateAccessor(cntl).clear_with_auth(); + ControllerPrivateAccessor(cntl).clear_auth_flags(); } buf->append(request); diff --git a/src/brpc/socket.cpp b/src/brpc/socket.cpp index 59a570ac71..de61f4ad38 100644 --- a/src/brpc/socket.cpp +++ b/src/brpc/socket.cpp @@ -295,7 +295,7 @@ bool Socket::CreatedByConnect() const { } SocketMessage* const DUMMY_USER_MESSAGE = (SocketMessage*)0x1; -const uint32_t MAX_PIPELINED_COUNT = 32768; +const uint32_t MAX_PIPELINED_COUNT = 16384; struct BAIDU_CACHELINE_ALIGNMENT Socket::WriteRequest { static WriteRequest* const UNCONNECTED; @@ -306,12 +306,12 @@ struct BAIDU_CACHELINE_ALIGNMENT Socket::WriteRequest { Socket* socket; uint32_t pipelined_count() const { - return (_pc_and_udmsg >> 48) & 0x7FFF; + return (_pc_and_udmsg >> 48) & 0x3FFF; } - bool is_with_auth() const { - return _pc_and_udmsg & 0x8000000000000000ULL; + uint32_t get_auth_flags() const { + return (_pc_and_udmsg >> 62) & 0x03; } - void clear_pipelined_count_and_with_auth() { + void clear_pipelined_count_and_auth_flags() { _pc_and_udmsg &= 0xFFFFFFFFFFFFULL; } SocketMessage* user_message() const { @@ -321,9 +321,9 @@ struct BAIDU_CACHELINE_ALIGNMENT Socket::WriteRequest { _pc_and_udmsg &= 0xFFFF000000000000ULL; } void set_pipelined_count_and_user_message( - uint32_t pc, SocketMessage* msg, bool with_auth) { - if (with_auth) { - pc |= (1 << 15); + uint32_t pc, SocketMessage* msg, uint32_t auth_flags) { + if (auth_flags) { + pc |= (auth_flags & 0x03) << 14; } _pc_and_udmsg = ((uint64_t)pc << 48) | (uint64_t)(uintptr_t)msg; } @@ -337,7 +337,7 @@ struct BAIDU_CACHELINE_ALIGNMENT Socket::WriteRequest { // is already failed. (void)msg->AppendAndDestroySelf(&dummy_buf, NULL); } - set_pipelined_count_and_user_message(0, NULL, false); + set_pipelined_count_and_user_message(0, NULL, 0); return true; } return false; @@ -376,9 +376,9 @@ void Socket::WriteRequest::Setup(Socket* s) { // The struct will be popped when reading a message from the socket. PipelinedInfo pi; pi.count = pc; - pi.with_auth = is_with_auth(); + pi.auth_flags = get_auth_flags(); pi.id_wait = id_wait; - clear_pipelined_count_and_with_auth(); // avoid being pushed again + clear_pipelined_count_and_auth_flags(); // avoid being pushed again s->PushPipelinedInfo(pi); } } @@ -1462,7 +1462,7 @@ int Socket::Write(butil::IOBuf* data, const WriteOptions* options_in) { req->next = WriteRequest::UNCONNECTED; req->id_wait = opt.id_wait; req->set_pipelined_count_and_user_message( - opt.pipelined_count, DUMMY_USER_MESSAGE, opt.with_auth); + opt.pipelined_count, DUMMY_USER_MESSAGE, opt.auth_flags); return StartWrite(req, opt); } @@ -1497,7 +1497,7 @@ int Socket::Write(SocketMessagePtr<>& msg, const WriteOptions* options_in) { // wait until it points to a valid WriteRequest or NULL. req->next = WriteRequest::UNCONNECTED; req->id_wait = opt.id_wait; - req->set_pipelined_count_and_user_message(opt.pipelined_count, msg.release(), opt.with_auth); + req->set_pipelined_count_and_user_message(opt.pipelined_count, msg.release(), opt.auth_flags); return StartWrite(req, opt); } diff --git a/src/brpc/socket.h b/src/brpc/socket.h index 4be6a73165..5ca0970039 100644 --- a/src/brpc/socket.h +++ b/src/brpc/socket.h @@ -153,11 +153,11 @@ struct PipelinedInfo { PipelinedInfo() { reset(); } void reset() { count = 0; - with_auth = false; + auth_flags = 0; id_wait = INVALID_BTHREAD_ID; } uint32_t count; - bool with_auth; + uint32_t auth_flags; bthread_id_t id_wait; }; @@ -256,7 +256,7 @@ friend class policy::H2GlobalStreamCreator; // The request contains authenticating information which will be // responded by the server and processed specially when dealing // with the response. - bool with_auth; + uint32_t auth_flags; // Do not return EOVERCROWDED // Default: false @@ -264,7 +264,7 @@ friend class policy::H2GlobalStreamCreator; WriteOptions() : id_wait(INVALID_BTHREAD_ID), abstime(NULL) - , pipelined_count(0), with_auth(false) + , pipelined_count(0), auth_flags(0) , ignore_eovercrowded(false) {} }; int Write(butil::IOBuf *msg, const WriteOptions* options = NULL);