Skip to content

[lldb] Add Socket::CreatePair #145015

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions lldb/include/lldb/Host/Socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,10 @@ class Socket : public IOObject {
static std::unique_ptr<Socket> Create(const SocketProtocol protocol,
Status &error);

using Pair = std::pair<std::unique_ptr<Socket>, std::unique_ptr<Socket>>;
static llvm::Expected<Pair>
CreatePair(std::optional<SocketProtocol> protocol = std::nullopt);

virtual Status Connect(llvm::StringRef name) = 0;
virtual Status Listen(llvm::StringRef name, int backlog) = 0;

Expand Down
4 changes: 4 additions & 0 deletions lldb/include/lldb/Host/common/TCPSocket.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ class TCPSocket : public Socket {
TCPSocket(NativeSocket socket, bool should_close);
~TCPSocket() override;

using Pair =
std::pair<std::unique_ptr<TCPSocket>, std::unique_ptr<TCPSocket>>;
static llvm::Expected<Pair> CreatePair();

// returns port number or 0 if error
uint16_t GetLocalPortNumber() const;

Expand Down
4 changes: 4 additions & 0 deletions lldb/include/lldb/Host/posix/DomainSocket.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ class DomainSocket : public Socket {
DomainSocket(NativeSocket socket, bool should_close);
explicit DomainSocket(bool should_close);

using Pair =
std::pair<std::unique_ptr<DomainSocket>, std::unique_ptr<DomainSocket>>;
static llvm::Expected<Pair> CreatePair();

Status Connect(llvm::StringRef name) override;
Status Listen(llvm::StringRef name, int backlog) override;

Expand Down
17 changes: 17 additions & 0 deletions lldb/source/Host/common/Socket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,23 @@ std::unique_ptr<Socket> Socket::Create(const SocketProtocol protocol,
return socket_up;
}

llvm::Expected<Socket::Pair>
Socket::CreatePair(std::optional<SocketProtocol> protocol) {
constexpr SocketProtocol kBestProtocol =
LLDB_ENABLE_POSIX ? ProtocolUnixDomain : ProtocolTcp;
switch (protocol.value_or(kBestProtocol)) {
case ProtocolTcp:
return TCPSocket::CreatePair();
#if LLDB_ENABLE_POSIX
case ProtocolUnixDomain:
case ProtocolUnixAbstract:
return DomainSocket::CreatePair();
#endif
default:
return llvm::createStringError("Unsupported protocol");
}
}

llvm::Expected<std::unique_ptr<Socket>>
Socket::TcpConnect(llvm::StringRef host_and_port) {
Log *log = GetLog(LLDBLog::Connection);
Expand Down
26 changes: 26 additions & 0 deletions lldb/source/Host/common/TCPSocket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,32 @@ TCPSocket::TCPSocket(NativeSocket socket, bool should_close)

TCPSocket::~TCPSocket() { CloseListenSockets(); }

llvm::Expected<TCPSocket::Pair> TCPSocket::CreatePair() {
auto listen_socket_up = std::make_unique<TCPSocket>(true);
if (Status error = listen_socket_up->Listen("localhost:0", 5); error.Fail())
return error.takeError();

std::string connect_address =
llvm::StringRef(listen_socket_up->GetListeningConnectionURI()[0])
.split("://")
.second.str();

auto connect_socket_up = std::make_unique<TCPSocket>(true);
if (Status error = connect_socket_up->Connect(connect_address); error.Fail())
return error.takeError();

// Connection has already been made above, so a short timeout is sufficient.
Socket *accept_socket;
if (Status error =
listen_socket_up->Accept(std::chrono::seconds(1), accept_socket);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you think of a reason why we couldn't change the signature of Socket::Accept to return an llvm::Expected<std::unique_ptr>> instead of this Status + Socket* out parameter? Not trying to sign you up for it but it stands out.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No reason. I sort of actually have a patch stack for a major overhaul of the socket classes, and it fixes this (among other things). This PR is kind of a part of a very long and windy way of upstreaming it. :)

error.Fail())
return error.takeError();

return Pair(
std::move(connect_socket_up),
std::unique_ptr<TCPSocket>(static_cast<TCPSocket *>(accept_socket)));
}

bool TCPSocket::IsValid() const {
return m_socket != kInvalidSocketValue || m_listen_sockets.size() != 0;
}
Expand Down
27 changes: 27 additions & 0 deletions lldb/source/Host/posix/DomainSocket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@
#endif

#include "llvm/Support/Errno.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/FileSystem.h"

#include <cstddef>
#include <fcntl.h>
#include <memory>
#include <sys/socket.h>
#include <sys/un.h>
Expand Down Expand Up @@ -76,6 +78,31 @@ DomainSocket::DomainSocket(SocketProtocol protocol, NativeSocket socket,
m_socket = socket;
}

llvm::Expected<DomainSocket::Pair> DomainSocket::CreatePair() {
int sockets[2];
int type = SOCK_STREAM;
#ifdef SOCK_CLOEXEC
type |= SOCK_CLOEXEC;
#endif
if (socketpair(AF_UNIX, type, 0, sockets) == -1)
return llvm::errorCodeToError(llvm::errnoAsErrorCode());

#ifndef SOCK_CLOEXEC
for (int s : sockets) {
int r = fcntl(s, F_SETFD, FD_CLOEXEC | fcntl(s, F_GETFD));
assert(r == 0);
(void)r;
}
#endif

return Pair(std::unique_ptr<DomainSocket>(
new DomainSocket(ProtocolUnixDomain, sockets[0],
/*should_close=*/true)),
std::unique_ptr<DomainSocket>(
new DomainSocket(ProtocolUnixDomain, sockets[1],
/*should_close=*/true)));
}

Status DomainSocket::Connect(llvm::StringRef name) {
sockaddr_un saddr_un;
socklen_t saddr_un_len;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1141,34 +1141,14 @@ void GDBRemoteCommunication::DumpHistory(Stream &strm) { m_history.Dump(strm); }
llvm::Error
GDBRemoteCommunication::ConnectLocally(GDBRemoteCommunication &client,
GDBRemoteCommunication &server) {
const int backlog = 5;
TCPSocket listen_socket(true);
if (llvm::Error error =
listen_socket.Listen("localhost:0", backlog).ToError())
return error;
llvm::Expected<Socket::Pair> pair = Socket::CreatePair();
if (!pair)
return pair.takeError();

llvm::SmallString<32> remote_addr;
llvm::raw_svector_ostream(remote_addr)
<< "connect://localhost:" << listen_socket.GetLocalPortNumber();

std::unique_ptr<ConnectionFileDescriptor> conn_up(
new ConnectionFileDescriptor());
Status status;
if (conn_up->Connect(remote_addr, &status) != lldb::eConnectionStatusSuccess)
return llvm::createStringError(llvm::inconvertibleErrorCode(),
"Unable to connect: %s", status.AsCString());

// The connection was already established above, so a short timeout is
// sufficient.
Socket *accept_socket = nullptr;
if (Status accept_status =
listen_socket.Accept(std::chrono::seconds(1), accept_socket);
accept_status.Fail())
return accept_status.takeError();

client.SetConnection(std::move(conn_up));
client.SetConnection(
std::make_unique<ConnectionFileDescriptor>(pair->first.release()));
server.SetConnection(
std::make_unique<ConnectionFileDescriptor>(accept_socket));
std::make_unique<ConnectionFileDescriptor>(pair->second.release()));
return llvm::Error::success();
}

Expand Down
26 changes: 15 additions & 11 deletions lldb/unittests/Core/CommunicationTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@
//===----------------------------------------------------------------------===//

#include "lldb/Core/Communication.h"
#include "TestingSupport/SubsystemRAII.h"
#include "lldb/Core/ThreadedCommunication.h"
#include "lldb/Host/Config.h"
#include "lldb/Host/ConnectionFileDescriptor.h"
#include "lldb/Host/Pipe.h"
#include "lldb/Host/Socket.h"
#include "llvm/Testing/Support/Error.h"
#include "gtest/gtest.h"
#include "TestingSupport/Host/SocketTestUtilities.h"
#include "TestingSupport/SubsystemRAII.h"

#include <chrono>
#include <thread>
Expand All @@ -31,15 +31,17 @@ class CommunicationTest : public testing::Test {
};

static void CommunicationReadTest(bool use_read_thread) {
std::unique_ptr<TCPSocket> a, b;
ASSERT_TRUE(CreateTCPConnectedSockets("localhost", &a, &b));
llvm::Expected<Socket::Pair> pair = Socket::CreatePair();
ASSERT_THAT_EXPECTED(pair, llvm::Succeeded());
Socket &a = *pair->first;

size_t num_bytes = 4;
ASSERT_THAT_ERROR(a->Write("test", num_bytes).ToError(), llvm::Succeeded());
ASSERT_THAT_ERROR(a.Write("test", num_bytes).ToError(), llvm::Succeeded());
ASSERT_EQ(num_bytes, 4U);

ThreadedCommunication comm("test");
comm.SetConnection(std::make_unique<ConnectionFileDescriptor>(b.release()));
comm.SetConnection(
std::make_unique<ConnectionFileDescriptor>(pair->second.release()));
comm.SetCloseOnEOF(true);

if (use_read_thread) {
Expand Down Expand Up @@ -73,7 +75,7 @@ static void CommunicationReadTest(bool use_read_thread) {
EXPECT_THAT_ERROR(error.ToError(), llvm::Failed());

// This read should return EOF.
ASSERT_THAT_ERROR(a->Close().ToError(), llvm::Succeeded());
ASSERT_THAT_ERROR(a.Close().ToError(), llvm::Succeeded());
error.Clear();
EXPECT_EQ(
comm.Read(buf, sizeof(buf), std::chrono::seconds(5), status, &error), 0U);
Expand Down Expand Up @@ -118,17 +120,19 @@ TEST_F(CommunicationTest, ReadThread) {
}

TEST_F(CommunicationTest, SynchronizeWhileClosing) {
std::unique_ptr<TCPSocket> a, b;
ASSERT_TRUE(CreateTCPConnectedSockets("localhost", &a, &b));
llvm::Expected<Socket::Pair> pair = Socket::CreatePair();
ASSERT_THAT_EXPECTED(pair, llvm::Succeeded());
Socket &a = *pair->first;

ThreadedCommunication comm("test");
comm.SetConnection(std::make_unique<ConnectionFileDescriptor>(b.release()));
comm.SetConnection(
std::make_unique<ConnectionFileDescriptor>(pair->second.release()));
comm.SetCloseOnEOF(true);
ASSERT_TRUE(comm.StartReadThread());

// Ensure that we can safely synchronize with the read thread while it is
// closing the read end (in response to us closing the write end).
ASSERT_THAT_ERROR(a->Close().ToError(), llvm::Succeeded());
ASSERT_THAT_ERROR(a.Close().ToError(), llvm::Succeeded());
comm.SynchronizeWithReadThread();

ASSERT_TRUE(comm.StopReadThread());
Expand Down
35 changes: 35 additions & 0 deletions lldb/unittests/Host/SocketTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,41 @@ TEST_F(SocketTest, DecodeHostAndPort) {
llvm::HasValue(Socket::HostAndPort{"abcd:12fg:AF58::1", 12345}));
}

TEST_F(SocketTest, CreatePair) {
std::vector<std::optional<Socket::SocketProtocol>> functional_protocols = {
std::nullopt,
Socket::ProtocolTcp,
#if LLDB_ENABLE_POSIX
Socket::ProtocolUnixDomain,
Socket::ProtocolUnixAbstract,
#endif
};
for (auto p : functional_protocols) {
auto expected_socket_pair = Socket::CreatePair(p);
ASSERT_THAT_EXPECTED(expected_socket_pair, llvm::Succeeded());
Socket &a = *expected_socket_pair->first;
Socket &b = *expected_socket_pair->second;
size_t num_bytes = 1;
ASSERT_THAT_ERROR(a.Write("a", num_bytes).takeError(), llvm::Succeeded());
ASSERT_EQ(num_bytes, 1);
char c;
ASSERT_THAT_ERROR(b.Read(&c, num_bytes).takeError(), llvm::Succeeded());
ASSERT_EQ(num_bytes, 1);
ASSERT_EQ(c, 'a');
}

std::vector<Socket::SocketProtocol> erroring_protocols = {
#if !LLDB_ENABLE_POSIX
Socket::ProtocolUnixDomain,
Socket::ProtocolUnixAbstract,
#endif
};
for (auto p : erroring_protocols) {
ASSERT_THAT_EXPECTED(Socket::CreatePair(p),
llvm::FailedWithMessage("Unsupported protocol"));
}
}

#if LLDB_ENABLE_POSIX
TEST_F(SocketTest, DomainListenConnectAccept) {
llvm::SmallString<64> Path;
Expand Down
Loading