Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

net: don't try to set TCP_NODELAY on local Unix sockets. #7573

Merged
merged 1 commit into from
Nov 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion net/DelaySocket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class DelaySocket : public Socket {
std::vector<std::shared_ptr<WriteChunk>> _chunks;
public:
DelaySocket(int delayMs, int fd) :
Socket (fd), _delayMs(delayMs),
Socket (fd, Socket::Type::Unix), _delayMs(delayMs),
_state(ReadWrite)
{
// setSocketBufferSize(Socket::DefaultSendBufferSize);
Expand Down
7 changes: 4 additions & 3 deletions net/NetUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,12 +130,13 @@ connect(const std::string& host, const std::string& port, const bool isSSL,
}
else
{
Socket::Type type = ai->ai_family == AF_INET ? Socket::Type::IPv4 : Socket::Type::IPv6;
#if ENABLE_SSL
if (isSSL)
socket = StreamSocket::create<SslStreamSocket>(host, fd, true, protocolHandler);
socket = StreamSocket::create<SslStreamSocket>(host, fd, type, true, protocolHandler);
#endif
if (!socket && !isSSL)
socket = StreamSocket::create<StreamSocket>(host, fd, true, protocolHandler);
socket = StreamSocket::create<StreamSocket>(host, fd, type, true, protocolHandler);

if (socket)
{
Expand Down Expand Up @@ -218,4 +219,4 @@ bool parseUri(std::string uri, std::string& scheme, std::string& host, std::stri
return !host.empty();
}

} // namespace net
} // namespace net
6 changes: 3 additions & 3 deletions net/ServerSocket.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
class SocketFactory
{
public:
virtual std::shared_ptr<Socket> create(const int fd) = 0;
virtual std::shared_ptr<Socket> create(const int fd, Socket::Type type) = 0;
};

/// A non-blocking, streaming socket.
Expand Down Expand Up @@ -107,9 +107,9 @@ class ServerSocket : public Socket

protected:
/// Create a Socket instance from the accepted socket FD.
std::shared_ptr<Socket> createSocketFromAccept(int fd) const
std::shared_ptr<Socket> createSocketFromAccept(int fd, Socket::Type type) const
{
return _sockFactory->create(fd);
return _sockFactory->create(fd, type);
}

private:
Expand Down
16 changes: 11 additions & 5 deletions net/Socket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -608,7 +608,8 @@ bool SocketPoll::insertNewUnixSocket(
}

std::shared_ptr<StreamSocket> socket
= StreamSocket::create<StreamSocket>(std::string(), fd, true, websocketHandler);
= StreamSocket::create<StreamSocket>(std::string(), fd, Socket::Type::Unix,
true, websocketHandler);
if (!socket)
{
LOG_ERR("Failed to create socket unix socket at " << location);
Expand Down Expand Up @@ -664,7 +665,7 @@ void SocketPoll::insertNewFakeSocket(
else
{
std::shared_ptr<StreamSocket> socket;
socket = StreamSocket::create<StreamSocket>(std::string(), fd, true, websocketHandler);
socket = StreamSocket::create<StreamSocket>(std::string(), fd, Socket::Type::Unix, true, websocketHandler);
if (socket)
{
LOG_TRC("Sending 'hello' instead of HTTP GET for now");
Expand Down Expand Up @@ -893,29 +894,34 @@ std::shared_ptr<Socket> ServerSocket::accept()
// Create a socket object using the factory.
if (rc != -1)
{
std::shared_ptr<Socket> _socket = createSocketFromAccept(rc);

#if !MOBILEAPP
char addrstr[INET6_ADDRSTRLEN];

Socket::Type type;
const void *inAddr;
if (clientInfo.sin6_family == AF_INET)
{
auto ipv4 = (struct sockaddr_in *)&clientInfo;
inAddr = &(ipv4->sin_addr);
type = Socket::Type::IPv4;
}
else
{
auto ipv6 = (struct sockaddr_in6 *)&clientInfo;
inAddr = &(ipv6->sin6_addr);
type = Socket::Type::IPv6;
}

std::shared_ptr<Socket> _socket = createSocketFromAccept(rc, type);

inet_ntop(clientInfo.sin6_family, inAddr, addrstr, sizeof(addrstr));
_socket->setClientAddress(addrstr);

LOG_TRC("Accepted socket #" << _socket->getFD() << " has family "
<< clientInfo.sin6_family << " address "
<< _socket->clientAddress());
#else
std::shared_ptr<Socket> _socket = createSocketFromAccept(rc, Socket::Type::Unix);
#endif
return _socket;
}
Expand Down Expand Up @@ -978,7 +984,7 @@ std::shared_ptr<Socket> LocalServerSocket::accept()
if (rc < 0)
return std::shared_ptr<Socket>(nullptr);

std::shared_ptr<Socket> _socket = createSocketFromAccept(rc);
std::shared_ptr<Socket> _socket = createSocketFromAccept(rc, Socket::Type::Unix);
// Sanity check this incoming socket
#ifndef __FreeBSD__
#define CREDS_UID(c) c.uid
Expand Down
24 changes: 13 additions & 11 deletions net/Socket.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ class Socket
Socket(Type type)
: _fd(createSocket(type))
{
init();
init(type);
}

virtual ~Socket()
Expand Down Expand Up @@ -363,15 +363,19 @@ class Socket
protected:
/// Construct based on an existing socket fd.
/// Used by accept() only.
Socket(const int fd)
Socket(const int fd, Type type)
: _fd(fd)
{
init();
init(type);
}

void init()
inline void logPrefix(std::ostream& os) const { os << '#' << _fd << ": "; }

private:
void init(Type type)
{
setNoDelay();
if (type != Type::Unix)
setNoDelay();
_ignoreInput = false;
_sendBufferSize = DefaultSendBufferSize;
_owner = std::this_thread::get_id();
Expand All @@ -389,8 +393,6 @@ class Socket
#endif
}

inline void logPrefix(std::ostream& os) const { os << '#' << _fd << ": "; }

private:
std::string _clientAddress;
const int _fd;
Expand Down Expand Up @@ -921,10 +923,10 @@ class StreamSocket : public Socket,
};

/// Create a StreamSocket from native FD.
StreamSocket(std::string host, const int fd, bool /* isClient */,
StreamSocket(std::string host, const int fd, Type type, bool /* isClient */,
std::shared_ptr<ProtocolHandlerInterface> socketHandler,
ReadType readType = NormalRead) :
Socket(fd),
Socket(fd, type),
_hostname(std::move(host)),
_socketHandler(std::move(socketHandler)),
_bytesSent(0),
Expand Down Expand Up @@ -1194,12 +1196,12 @@ class StreamSocket : public Socket,
/// We need this helper since the handler needs a shared_ptr to the socket
/// but we can't have a shared_ptr in the ctor.
template <typename TSocket>
static std::shared_ptr<TSocket> create(std::string hostname, const int fd, bool isClient,
static std::shared_ptr<TSocket> create(std::string hostname, const int fd, Type type, bool isClient,
std::shared_ptr<ProtocolHandlerInterface> handler,
ReadType readType = NormalRead)
{
ProtocolHandlerInterface* pHandler = handler.get();
auto socket = std::make_shared<TSocket>(std::move(hostname), fd, isClient,
auto socket = std::make_shared<TSocket>(std::move(hostname), fd, type, isClient,
std::move(handler), readType);
pHandler->onConnect(socket);
return socket;
Expand Down
4 changes: 2 additions & 2 deletions net/SslSocket.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@
class SslStreamSocket final : public StreamSocket
{
public:
SslStreamSocket(const std::string& host, const int fd, bool isClient,
SslStreamSocket(const std::string& host, const int fd, Type type, bool isClient,
std::shared_ptr<ProtocolHandlerInterface> responseClient,
ReadType readType = NormalRead)
: StreamSocket(host, fd, isClient, std::move(responseClient), readType)
: StreamSocket(host, fd, type, isClient, std::move(responseClient), readType)
, _bio(nullptr)
, _ssl(nullptr)
, _sslWantsTo(SslWantsTo::Neither)
Expand Down
10 changes: 5 additions & 5 deletions test/HttpRequestTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ class HttpRequestTests final : public CPPUNIT_NS::TestFixture

class ServerSocketFactory final : public SocketFactory
{
std::shared_ptr<Socket> create(const int physicalFd) override
std::shared_ptr<Socket> create(const int physicalFd, Socket::Type type) override
{
int fd = physicalFd;

Expand All @@ -125,12 +125,12 @@ class HttpRequestTests final : public CPPUNIT_NS::TestFixture
#if ENABLE_SSL
if (helpers::haveSsl())
return StreamSocket::create<SslStreamSocket>(
std::string(), fd, false, std::make_shared<ServerRequestHandler>());
std::string(), fd, type, false, std::make_shared<ServerRequestHandler>());
else
return StreamSocket::create<StreamSocket>(std::string(), fd, false,
return StreamSocket::create<StreamSocket>(std::string(), fd, type, false,
std::make_shared<ServerRequestHandler>());
#else
return StreamSocket::create<StreamSocket>(std::string(), fd, false,
return StreamSocket::create<StreamSocket>(std::string(), fd, type, false,
std::make_shared<ServerRequestHandler>());
#endif
}
Expand Down Expand Up @@ -179,7 +179,7 @@ void HttpRequestTests::testSslHostname()
{
const std::string host = "localhost";
std::shared_ptr<SslStreamSocket> socket = StreamSocket::create<SslStreamSocket>(
host, _port, false, std::make_shared<ServerRequestHandler>());
host, _port, Socket::Type::All, false, std::make_shared<ServerRequestHandler>());
LOK_ASSERT_EQUAL(host, socket->getSslServername());
}
#endif
Expand Down
10 changes: 6 additions & 4 deletions tools/WebSocketDump.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,17 +185,19 @@ class DumpSocketFactory final : public SocketFactory
public:
DumpSocketFactory(bool isSSL) : _isSSL(isSSL) {}

std::shared_ptr<Socket> create(const int physicalFd) override
std::shared_ptr<Socket> create(const int physicalFd, Socket::Type type) override
{
#if ENABLE_SSL
if (_isSSL)
return StreamSocket::create<SslStreamSocket>(
std::string(), physicalFd, false, std::make_shared<ClientRequestDispatcher>());
std::string(), physicalFd, type, false,
std::make_shared<ClientRequestDispatcher>());
#else
(void)_isSSL;
#endif
return StreamSocket::create<StreamSocket>(std::string(), physicalFd, false,
std::make_shared<ClientRequestDispatcher>());
return StreamSocket::create<StreamSocket>(
std::string(), physicalFd, type, false,
std::make_shared<ClientRequestDispatcher>());
}
};

Expand Down
17 changes: 9 additions & 8 deletions wsd/COOLWSD.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1019,7 +1019,7 @@ class InotifySocket : public Socket
{
public:
InotifySocket():
Socket(inotify_init1(IN_NONBLOCK))
Socket(inotify_init1(IN_NONBLOCK), Socket::Type::Unix)
, m_stopOnConfigChange(true)
{
if (getFD() == -1)
Expand Down Expand Up @@ -5464,7 +5464,7 @@ std::map<std::string, std::string> ClientRequestDispatcher::StaticFileContentCac

class PlainSocketFactory final : public SocketFactory
{
std::shared_ptr<Socket> create(const int physicalFd) override
std::shared_ptr<Socket> create(const int physicalFd, Socket::Type type) override
{
int fd = physicalFd;
#if !MOBILEAPP
Expand All @@ -5477,15 +5477,16 @@ class PlainSocketFactory final : public SocketFactory
fd = delayfd;
}
#endif
return StreamSocket::create<StreamSocket>(std::string(), fd, false,
std::make_shared<ClientRequestDispatcher>());
return StreamSocket::create<StreamSocket>(
std::string(), fd, type, false,
std::make_shared<ClientRequestDispatcher>());
}
};

#if ENABLE_SSL
class SslSocketFactory final : public SocketFactory
{
std::shared_ptr<Socket> create(const int physicalFd) override
std::shared_ptr<Socket> create(const int physicalFd, Socket::Type type) override
{
int fd = physicalFd;

Expand All @@ -5494,18 +5495,18 @@ class SslSocketFactory final : public SocketFactory
fd = Delay::create(SimulatedLatencyMs, physicalFd);
#endif

return StreamSocket::create<SslStreamSocket>(std::string(), fd, false,
return StreamSocket::create<SslStreamSocket>(std::string(), fd, type, false,
std::make_shared<ClientRequestDispatcher>());
}
};
#endif

class PrisonerSocketFactory final : public SocketFactory
{
std::shared_ptr<Socket> create(const int fd) override
std::shared_ptr<Socket> create(const int fd, Socket::Type type) override
{
// No local delay.
return StreamSocket::create<StreamSocket>(std::string(), fd, false,
return StreamSocket::create<StreamSocket>(std::string(), fd, type, false,
std::make_shared<PrisonerRequestDispatcher>(),
StreamSocket::ReadType::UseRecvmsgExpectFD);
}
Expand Down
8 changes: 6 additions & 2 deletions wsd/DocumentBroker.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,12 @@ class ChildProcess final : public WSProcess
int urpToKitFD = socket->getIncomingFD(URPToKit);
if (urpFromKitFD != -1 && urpToKitFD != -1)
{
_urpFromKit = StreamSocket::create<StreamSocket>(std::string(), urpFromKitFD, false, std::make_shared<UrpHandler>(this));
_urpToKit = StreamSocket::create<StreamSocket>(std::string(), urpToKitFD, false, std::make_shared<UrpHandler>(this));
_urpFromKit = StreamSocket::create<StreamSocket>(
std::string(), urpFromKitFD, Socket::Type::Unix,
false, std::make_shared<UrpHandler>(this));
_urpToKit = StreamSocket::create<StreamSocket>(
std::string(), urpToKitFD, Socket::Type::Unix,
false, std::make_shared<UrpHandler>(this));
}
}

Expand Down