Skip to content

Commit

Permalink
net: don't try to set TCP_NODELAY on local Unix sockets.
Browse files Browse the repository at this point in the history
Change-Id: Ic0f720554080a01b7a46abbffda023834016c0ea
Signed-off-by: Michael Meeks <michael.meeks@collabora.com>
  • Loading branch information
mmeeks committed Nov 1, 2023
1 parent 112fefa commit 08d9081
Show file tree
Hide file tree
Showing 10 changed files with 60 additions and 44 deletions.
2 changes: 1 addition & 1 deletion net/DelaySocket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class DelaySocket : public Socket {
std::vector<std::shared_ptr<WriteChunk>> _chunks;
public:
DelaySocket(int delayMs, int fd) :
Socket (fd), _delayMs(delayMs),
Socket (fd, Socket::Type::Unix), _delayMs(delayMs),
_state(ReadWrite)
{
// setSocketBufferSize(Socket::DefaultSendBufferSize);
Expand Down
7 changes: 4 additions & 3 deletions net/NetUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,12 +130,13 @@ connect(const std::string& host, const std::string& port, const bool isSSL,
}
else
{
Socket::Type type = ai->ai_family == AF_INET ? Socket::Type::IPv4 : Socket::Type::IPv6;
#if ENABLE_SSL
if (isSSL)
socket = StreamSocket::create<SslStreamSocket>(host, fd, true, protocolHandler);
socket = StreamSocket::create<SslStreamSocket>(host, fd, type, true, protocolHandler);
#endif
if (!socket && !isSSL)
socket = StreamSocket::create<StreamSocket>(host, fd, true, protocolHandler);
socket = StreamSocket::create<StreamSocket>(host, fd, type, true, protocolHandler);

if (socket)
{
Expand Down Expand Up @@ -218,4 +219,4 @@ bool parseUri(std::string uri, std::string& scheme, std::string& host, std::stri
return !host.empty();
}

} // namespace net
} // namespace net
6 changes: 3 additions & 3 deletions net/ServerSocket.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
class SocketFactory
{
public:
virtual std::shared_ptr<Socket> create(const int fd) = 0;
virtual std::shared_ptr<Socket> create(const int fd, Socket::Type type) = 0;
};

/// A non-blocking, streaming socket.
Expand Down Expand Up @@ -107,9 +107,9 @@ class ServerSocket : public Socket

protected:
/// Create a Socket instance from the accepted socket FD.
std::shared_ptr<Socket> createSocketFromAccept(int fd) const
std::shared_ptr<Socket> createSocketFromAccept(int fd, Socket::Type type) const
{
return _sockFactory->create(fd);
return _sockFactory->create(fd, type);
}

private:
Expand Down
16 changes: 11 additions & 5 deletions net/Socket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -608,7 +608,8 @@ bool SocketPoll::insertNewUnixSocket(
}

std::shared_ptr<StreamSocket> socket
= StreamSocket::create<StreamSocket>(std::string(), fd, true, websocketHandler);
= StreamSocket::create<StreamSocket>(std::string(), fd, Socket::Type::Unix,
true, websocketHandler);
if (!socket)
{
LOG_ERR("Failed to create socket unix socket at " << location);
Expand Down Expand Up @@ -664,7 +665,7 @@ void SocketPoll::insertNewFakeSocket(
else
{
std::shared_ptr<StreamSocket> socket;
socket = StreamSocket::create<StreamSocket>(std::string(), fd, true, websocketHandler);
socket = StreamSocket::create<StreamSocket>(std::string(), fd, Socket::Type::Unix, true, websocketHandler);
if (socket)
{
LOG_TRC("Sending 'hello' instead of HTTP GET for now");
Expand Down Expand Up @@ -893,29 +894,34 @@ std::shared_ptr<Socket> ServerSocket::accept()
// Create a socket object using the factory.
if (rc != -1)
{
std::shared_ptr<Socket> _socket = createSocketFromAccept(rc);

#if !MOBILEAPP
char addrstr[INET6_ADDRSTRLEN];

Socket::Type type;
const void *inAddr;
if (clientInfo.sin6_family == AF_INET)
{
auto ipv4 = (struct sockaddr_in *)&clientInfo;
inAddr = &(ipv4->sin_addr);
type = Socket::Type::IPv4;
}
else
{
auto ipv6 = (struct sockaddr_in6 *)&clientInfo;
inAddr = &(ipv6->sin6_addr);
type = Socket::Type::IPv6;
}

std::shared_ptr<Socket> _socket = createSocketFromAccept(rc, type);

inet_ntop(clientInfo.sin6_family, inAddr, addrstr, sizeof(addrstr));
_socket->setClientAddress(addrstr);

LOG_TRC("Accepted socket #" << _socket->getFD() << " has family "
<< clientInfo.sin6_family << " address "
<< _socket->clientAddress());
#else
std::shared_ptr<Socket> _socket = createSocketFromAccept(rc, Socket::Type::Unix);
#endif
return _socket;
}
Expand Down Expand Up @@ -978,7 +984,7 @@ std::shared_ptr<Socket> LocalServerSocket::accept()
if (rc < 0)
return std::shared_ptr<Socket>(nullptr);

std::shared_ptr<Socket> _socket = createSocketFromAccept(rc);
std::shared_ptr<Socket> _socket = createSocketFromAccept(rc, Socket::Type::Unix);
// Sanity check this incoming socket
#ifndef __FreeBSD__
#define CREDS_UID(c) c.uid
Expand Down
24 changes: 13 additions & 11 deletions net/Socket.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ class Socket
Socket(Type type)
: _fd(createSocket(type))
{
init();
init(type);
}

virtual ~Socket()
Expand Down Expand Up @@ -363,15 +363,19 @@ class Socket
protected:
/// Construct based on an existing socket fd.
/// Used by accept() only.
Socket(const int fd)
Socket(const int fd, Type type)
: _fd(fd)
{
init();
init(type);
}

void init()
inline void logPrefix(std::ostream& os) const { os << '#' << _fd << ": "; }

private:
void init(Type type)
{
setNoDelay();
if (type != Type::Unix)
setNoDelay();
_ignoreInput = false;
_sendBufferSize = DefaultSendBufferSize;
_owner = std::this_thread::get_id();
Expand All @@ -389,8 +393,6 @@ class Socket
#endif
}

inline void logPrefix(std::ostream& os) const { os << '#' << _fd << ": "; }

private:
std::string _clientAddress;
const int _fd;
Expand Down Expand Up @@ -921,10 +923,10 @@ class StreamSocket : public Socket,
};

/// Create a StreamSocket from native FD.
StreamSocket(std::string host, const int fd, bool /* isClient */,
StreamSocket(std::string host, const int fd, Type type, bool /* isClient */,
std::shared_ptr<ProtocolHandlerInterface> socketHandler,
ReadType readType = NormalRead) :
Socket(fd),
Socket(fd, type),
_hostname(std::move(host)),
_socketHandler(std::move(socketHandler)),
_bytesSent(0),
Expand Down Expand Up @@ -1194,12 +1196,12 @@ class StreamSocket : public Socket,
/// We need this helper since the handler needs a shared_ptr to the socket
/// but we can't have a shared_ptr in the ctor.
template <typename TSocket>
static std::shared_ptr<TSocket> create(std::string hostname, const int fd, bool isClient,
static std::shared_ptr<TSocket> create(std::string hostname, const int fd, Type type, bool isClient,
std::shared_ptr<ProtocolHandlerInterface> handler,
ReadType readType = NormalRead)
{
ProtocolHandlerInterface* pHandler = handler.get();
auto socket = std::make_shared<TSocket>(std::move(hostname), fd, isClient,
auto socket = std::make_shared<TSocket>(std::move(hostname), fd, type, isClient,
std::move(handler), readType);
pHandler->onConnect(socket);
return socket;
Expand Down
4 changes: 2 additions & 2 deletions net/SslSocket.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@
class SslStreamSocket final : public StreamSocket
{
public:
SslStreamSocket(const std::string& host, const int fd, bool isClient,
SslStreamSocket(const std::string& host, const int fd, Type type, bool isClient,
std::shared_ptr<ProtocolHandlerInterface> responseClient,
ReadType readType = NormalRead)
: StreamSocket(host, fd, isClient, std::move(responseClient), readType)
: StreamSocket(host, fd, type, isClient, std::move(responseClient), readType)
, _bio(nullptr)
, _ssl(nullptr)
, _sslWantsTo(SslWantsTo::Neither)
Expand Down
10 changes: 5 additions & 5 deletions test/HttpRequestTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ class HttpRequestTests final : public CPPUNIT_NS::TestFixture

class ServerSocketFactory final : public SocketFactory
{
std::shared_ptr<Socket> create(const int physicalFd) override
std::shared_ptr<Socket> create(const int physicalFd, Socket::Type type) override
{
int fd = physicalFd;

Expand All @@ -125,12 +125,12 @@ class HttpRequestTests final : public CPPUNIT_NS::TestFixture
#if ENABLE_SSL
if (helpers::haveSsl())
return StreamSocket::create<SslStreamSocket>(
std::string(), fd, false, std::make_shared<ServerRequestHandler>());
std::string(), fd, type, false, std::make_shared<ServerRequestHandler>());
else
return StreamSocket::create<StreamSocket>(std::string(), fd, false,
return StreamSocket::create<StreamSocket>(std::string(), fd, type, false,
std::make_shared<ServerRequestHandler>());
#else
return StreamSocket::create<StreamSocket>(std::string(), fd, false,
return StreamSocket::create<StreamSocket>(std::string(), fd, type, false,
std::make_shared<ServerRequestHandler>());
#endif
}
Expand Down Expand Up @@ -179,7 +179,7 @@ void HttpRequestTests::testSslHostname()
{
const std::string host = "localhost";
std::shared_ptr<SslStreamSocket> socket = StreamSocket::create<SslStreamSocket>(
host, _port, false, std::make_shared<ServerRequestHandler>());
host, _port, Socket::Type::All, false, std::make_shared<ServerRequestHandler>());
LOK_ASSERT_EQUAL(host, socket->getSslServername());
}
#endif
Expand Down
10 changes: 6 additions & 4 deletions tools/WebSocketDump.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,17 +185,19 @@ class DumpSocketFactory final : public SocketFactory
public:
DumpSocketFactory(bool isSSL) : _isSSL(isSSL) {}

std::shared_ptr<Socket> create(const int physicalFd) override
std::shared_ptr<Socket> create(const int physicalFd, Socket::Type type) override
{
#if ENABLE_SSL
if (_isSSL)
return StreamSocket::create<SslStreamSocket>(
std::string(), physicalFd, false, std::make_shared<ClientRequestDispatcher>());
std::string(), physicalFd, type, false,
std::make_shared<ClientRequestDispatcher>());
#else
(void)_isSSL;
#endif
return StreamSocket::create<StreamSocket>(std::string(), physicalFd, false,
std::make_shared<ClientRequestDispatcher>());
return StreamSocket::create<StreamSocket>(
std::string(), physicalFd, type, false,
std::make_shared<ClientRequestDispatcher>());
}
};

Expand Down
17 changes: 9 additions & 8 deletions wsd/COOLWSD.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1019,7 +1019,7 @@ class InotifySocket : public Socket
{
public:
InotifySocket():
Socket(inotify_init1(IN_NONBLOCK))
Socket(inotify_init1(IN_NONBLOCK), Socket::Type::Unix)
, m_stopOnConfigChange(true)
{
if (getFD() == -1)
Expand Down Expand Up @@ -5464,7 +5464,7 @@ std::map<std::string, std::string> ClientRequestDispatcher::StaticFileContentCac

class PlainSocketFactory final : public SocketFactory
{
std::shared_ptr<Socket> create(const int physicalFd) override
std::shared_ptr<Socket> create(const int physicalFd, Socket::Type type) override
{
int fd = physicalFd;
#if !MOBILEAPP
Expand All @@ -5477,15 +5477,16 @@ class PlainSocketFactory final : public SocketFactory
fd = delayfd;
}
#endif
return StreamSocket::create<StreamSocket>(std::string(), fd, false,
std::make_shared<ClientRequestDispatcher>());
return StreamSocket::create<StreamSocket>(
std::string(), fd, type, false,
std::make_shared<ClientRequestDispatcher>());
}
};

#if ENABLE_SSL
class SslSocketFactory final : public SocketFactory
{
std::shared_ptr<Socket> create(const int physicalFd) override
std::shared_ptr<Socket> create(const int physicalFd, Socket::Type type) override
{
int fd = physicalFd;

Expand All @@ -5494,18 +5495,18 @@ class SslSocketFactory final : public SocketFactory
fd = Delay::create(SimulatedLatencyMs, physicalFd);
#endif

return StreamSocket::create<SslStreamSocket>(std::string(), fd, false,
return StreamSocket::create<SslStreamSocket>(std::string(), fd, type, false,
std::make_shared<ClientRequestDispatcher>());
}
};
#endif

class PrisonerSocketFactory final : public SocketFactory
{
std::shared_ptr<Socket> create(const int fd) override
std::shared_ptr<Socket> create(const int fd, Socket::Type type) override
{
// No local delay.
return StreamSocket::create<StreamSocket>(std::string(), fd, false,
return StreamSocket::create<StreamSocket>(std::string(), fd, type, false,
std::make_shared<PrisonerRequestDispatcher>(),
StreamSocket::ReadType::UseRecvmsgExpectFD);
}
Expand Down
8 changes: 6 additions & 2 deletions wsd/DocumentBroker.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,12 @@ class ChildProcess final : public WSProcess
int urpToKitFD = socket->getIncomingFD(URPToKit);
if (urpFromKitFD != -1 && urpToKitFD != -1)
{
_urpFromKit = StreamSocket::create<StreamSocket>(std::string(), urpFromKitFD, false, std::make_shared<UrpHandler>(this));
_urpToKit = StreamSocket::create<StreamSocket>(std::string(), urpToKitFD, false, std::make_shared<UrpHandler>(this));
_urpFromKit = StreamSocket::create<StreamSocket>(
std::string(), urpFromKitFD, Socket::Type::Unix,
false, std::make_shared<UrpHandler>(this));
_urpToKit = StreamSocket::create<StreamSocket>(
std::string(), urpToKitFD, Socket::Type::Unix,
false, std::make_shared<UrpHandler>(this));
}
}

Expand Down

0 comments on commit 08d9081

Please sign in to comment.