From 388b5bca17159438366fd1b8c7f2cb3a57607b05 Mon Sep 17 00:00:00 2001 From: Jim King Date: Thu, 2 Apr 2015 09:41:56 -0400 Subject: [PATCH] [THRIFT-1025] allow TSSLServerSocket and TServerSocket to optionally bind to a specific interface --- .../src/thrift/transport/TSSLServerSocket.cpp | 8 ++- .../src/thrift/transport/TSSLServerSocket.h | 17 ++++- .../src/thrift/transport/TServerSocket.cpp | 23 ++++++- lib/cpp/src/thrift/transport/TServerSocket.h | 32 ++++++++- lib/cpp/test/Makefile.am | 5 +- lib/cpp/test/TServerSocketTest.cpp | 67 +++++++++++++++++++ lib/cpp/test/TestPortFixture.h | 36 ++++++++++ 7 files changed, 177 insertions(+), 11 deletions(-) create mode 100644 lib/cpp/test/TServerSocketTest.cpp create mode 100644 lib/cpp/test/TestPortFixture.h diff --git a/lib/cpp/src/thrift/transport/TSSLServerSocket.cpp b/lib/cpp/src/thrift/transport/TSSLServerSocket.cpp index cf686e03529..421af6ad5ff 100644 --- a/lib/cpp/src/thrift/transport/TSSLServerSocket.cpp +++ b/lib/cpp/src/thrift/transport/TSSLServerSocket.cpp @@ -27,11 +27,17 @@ namespace transport { /** * SSL server socket implementation. */ -TSSLServerSocket::TSSLServerSocket(THRIFT_SOCKET port, boost::shared_ptr factory) +TSSLServerSocket::TSSLServerSocket(int port, boost::shared_ptr factory) : TServerSocket(port), factory_(factory) { factory_->server(true); } +TSSLServerSocket::TSSLServerSocket(const std::string& address, int port, + boost::shared_ptr factory) + : TServerSocket(address, port), factory_(factory) { + factory_->server(true); +} + TSSLServerSocket::TSSLServerSocket(int port, int sendTimeout, int recvTimeout, diff --git a/lib/cpp/src/thrift/transport/TSSLServerSocket.h b/lib/cpp/src/thrift/transport/TSSLServerSocket.h index bb52b04b128..7d2dfcc0002 100644 --- a/lib/cpp/src/thrift/transport/TSSLServerSocket.h +++ b/lib/cpp/src/thrift/transport/TSSLServerSocket.h @@ -35,14 +35,25 @@ class TSSLSocketFactory; class TSSLServerSocket : public TServerSocket { public: /** - * Constructor. + * Constructor. Binds to all interfaces. * * @param port Listening port * @param factory SSL socket factory implementation */ - TSSLServerSocket(THRIFT_SOCKET port, boost::shared_ptr factory); + TSSLServerSocket(int port, boost::shared_ptr factory); + + /** + * Constructor. Binds to the specified address. + * + * @param address Address to bind to + * @param port Listening port + * @param factory SSL socket factory implementation + */ + TSSLServerSocket(const std::string& address, int port, + boost::shared_ptr factory); + /** - * Constructor. + * Constructor. Binds to all interfaces. * * @param port Listening port * @param sendTimeout Socket send timeout diff --git a/lib/cpp/src/thrift/transport/TServerSocket.cpp b/lib/cpp/src/thrift/transport/TServerSocket.cpp index e228dabf6bc..fccbcfa88c1 100644 --- a/lib/cpp/src/thrift/transport/TServerSocket.cpp +++ b/lib/cpp/src/thrift/transport/TServerSocket.cpp @@ -108,7 +108,24 @@ TServerSocket::TServerSocket(int port, int sendTimeout, int recvTimeout) intSock2_(THRIFT_INVALID_SOCKET) { } -TServerSocket::TServerSocket(string path) +TServerSocket::TServerSocket(const string& address, int port) + : port_(port), + address_(address), + serverSocket_(THRIFT_INVALID_SOCKET), + acceptBacklog_(DEFAULT_BACKLOG), + sendTimeout_(0), + recvTimeout_(0), + accTimeout_(-1), + retryLimit_(0), + retryDelay_(0), + tcpSendBuffer_(0), + tcpRecvBuffer_(0), + keepAlive_(false), + intSock1_(THRIFT_INVALID_SOCKET), + intSock2_(THRIFT_INVALID_SOCKET) { +} + +TServerSocket::TServerSocket(const string& path) : port_(0), path_(path), serverSocket_(THRIFT_INVALID_SOCKET), @@ -184,8 +201,8 @@ void TServerSocket::listen() { hints.ai_flags = AI_PASSIVE | AI_ADDRCONFIG; sprintf(port, "%d", port_); - // Wildcard address - error = getaddrinfo(NULL, port, &hints, &res0); + // If address is not specified use wildcard address (NULL) + error = getaddrinfo(address_.empty() ? NULL : &address_[0], port, &hints, &res0); if (error) { GlobalOutput.printf("getaddrinfo %d: %s", error, THRIFT_GAI_STRERROR(error)); close(); diff --git a/lib/cpp/src/thrift/transport/TServerSocket.h b/lib/cpp/src/thrift/transport/TServerSocket.h index 15339373a62..49711e85548 100644 --- a/lib/cpp/src/thrift/transport/TServerSocket.h +++ b/lib/cpp/src/thrift/transport/TServerSocket.h @@ -42,11 +42,38 @@ class TServerSocket : public TServerTransport { const static int DEFAULT_BACKLOG = 1024; + /** + * Constructor. + * + * @param port Port number to bind to + */ TServerSocket(int port); + + /** + * Constructor. + * + * @param port Port number to bind to + * @param sendTimeout Socket send timeout + * @param recvTimeout Socket receive timeout + */ TServerSocket(int port, int sendTimeout, int recvTimeout); - TServerSocket(std::string path); - ~TServerSocket(); + /** + * Constructor. + * + * @param address Address to bind to + * @param port Port number to bind to + */ + TServerSocket(const std::string& address, int port); + + /** + * Constructor used for unix sockets. + * + * @param path Pathname for unix socket. + */ + TServerSocket(const std::string& path); + + virtual ~TServerSocket(); void setSendTimeout(int sendTimeout); void setRecvTimeout(int recvTimeout); @@ -85,6 +112,7 @@ class TServerSocket : public TServerTransport { private: int port_; + std::string address_; std::string path_; THRIFT_SOCKET serverSocket_; int acceptBacklog_; diff --git a/lib/cpp/test/Makefile.am b/lib/cpp/test/Makefile.am index 43c5975b708..46ff9114dfa 100755 --- a/lib/cpp/test/Makefile.am +++ b/lib/cpp/test/Makefile.am @@ -69,6 +69,7 @@ Benchmark_SOURCES = \ Benchmark_LDADD = libtestgencpp.la check_PROGRAMS = \ + UnitTests \ TFDTransportTest \ TPipedTransportTest \ DebugProtoTest \ @@ -80,7 +81,6 @@ check_PROGRAMS = \ TransportTest \ ZlibTest \ TFileTransportTest \ - UnitTests \ link_test \ OpenSSLManualInitTest \ EnumTest @@ -106,7 +106,8 @@ UnitTests_SOURCES = \ TBufferBaseTest.cpp \ Base64Test.cpp \ ToStringTest.cpp \ - TypedefTest.cpp + TypedefTest.cpp \ + TServerSocketTest.cpp if !WITH_BOOSTTHREADS UnitTests_SOURCES += \ diff --git a/lib/cpp/test/TServerSocketTest.cpp b/lib/cpp/test/TServerSocketTest.cpp new file mode 100644 index 00000000000..ebfd03f6ea8 --- /dev/null +++ b/lib/cpp/test/TServerSocketTest.cpp @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include "TestPortFixture.h" + +using apache::thrift::transport::TServerSocket; +using apache::thrift::transport::TSocket; +using apache::thrift::transport::TTransport; +using apache::thrift::transport::TTransportException; + +BOOST_FIXTURE_TEST_SUITE ( TServerSocketTest, TestPortFixture ) + +class TestTServerSocket : public TServerSocket +{ + public: + TestTServerSocket(const std::string& address, int port) : TServerSocket(address, port) { } + using TServerSocket::acceptImpl; +}; + +BOOST_AUTO_TEST_CASE( test_bind_to_address ) +{ + TestTServerSocket sock1("localhost", m_serverPort); + sock1.listen(); + TSocket clientSock("localhost", m_serverPort); + clientSock.open(); + boost::shared_ptr accepted = sock1.acceptImpl(); + accepted->close(); + sock1.close(); + + TServerSocket sock2("this.is.truly.an.unrecognizable.address.", m_serverPort); + BOOST_CHECK_THROW(sock2.listen(), TTransportException); + sock2.close(); +} + +BOOST_AUTO_TEST_CASE( test_close_before_listen ) +{ + TServerSocket sock1("localhost", m_serverPort); + sock1.close(); +} + +BOOST_AUTO_TEST_CASE( test_get_port ) +{ + TServerSocket sock1("localHost", 888); + BOOST_CHECK_EQUAL(888, sock1.getPort()); +} + +BOOST_AUTO_TEST_SUITE_END() + diff --git a/lib/cpp/test/TestPortFixture.h b/lib/cpp/test/TestPortFixture.h new file mode 100644 index 00000000000..5b27e5e46c7 --- /dev/null +++ b/lib/cpp/test/TestPortFixture.h @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#pragma once + +#include + +class TestPortFixture +{ + public: + TestPortFixture() + { + const char *spEnv = std::getenv("THRIFT_TEST_PORT"); + m_serverPort = (spEnv) ? atoi(spEnv) : 9090; + } + + protected: + int m_serverPort; +}; +