Skip to content

Commit

Permalink
Protect StreamServerConnection::connection
Browse files Browse the repository at this point in the history
https://bugs.webkit.org/show_bug.cgi?id=262525
rdar://problem/116384986

Reviewed by Dan Glastonbury, Kimmo Kinnunen and Chris Dumez.

Use StreamServerConnection::protectedConnection() (that returns a Ref) in most
cases, including in generated IPC code.

* Source/WebKit/Platform/IPC/HandleMessage.h:
(IPC::handleMessageSynchronous):
* Source/WebKit/Platform/IPC/StreamServerConnection.h:
* Source/WebKit/Scripts/webkit/messages.py:
(async_message_statement):
(generate_message_handler):
* Source/WebKit/Scripts/webkit/tests/TestWithStreamBatchedMessageReceiver.cpp:
(WebKit::TestWithStreamBatched::didReceiveStreamMessage):
* Source/WebKit/Scripts/webkit/tests/TestWithStreamMessageReceiver.cpp:
(WebKit::TestWithStream::didReceiveStreamMessage):

Canonical link: https://commits.webkit.org/268893@main
  • Loading branch information
squelart committed Oct 5, 2023
1 parent 781be02 commit 295c392
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 12 deletions.
4 changes: 2 additions & 2 deletions Source/WebKit/Platform/IPC/HandleMessage.h
Original file line number Diff line number Diff line change
Expand Up @@ -294,10 +294,10 @@ void handleMessageSynchronous(StreamServerConnection& connection, Decoder& decod
static_assert(std::is_same_v<typename ValidationType::CompletionHandlerArguments, typename MessageType::ReplyArguments>);
using CompletionHandlerType = typename ValidationType::CompletionHandlerType;

logMessage(connection.connection(), MessageType::name(), object, *arguments);
logMessage(connection.protectedConnection(), MessageType::name(), object, *arguments);
callMemberFunction(object, function, WTFMove(*arguments),
CompletionHandlerType([syncRequestID, connection = Ref { connection }] (auto&&... args) mutable {
logReply(connection->connection(), MessageType::name(), args...);
logReply(connection->protectedConnection(), MessageType::name(), args...);
connection->sendSyncReply<MessageType>(syncRequestID, std::forward<decltype(args)>(args)...);
}));
}
Expand Down
3 changes: 2 additions & 1 deletion Source/WebKit/Platform/IPC/StreamServerConnection.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ class StreamServerConnection final : public ThreadSafeRefCounted<StreamServerCon
void stopReceivingMessages(ReceiverName, uint64_t destinationID);

Connection& connection() { return m_connection; }
Ref<Connection> protectedConnection() { return m_connection; }

enum DispatchResult : bool {
HasNoMessages,
Expand Down Expand Up @@ -113,7 +114,7 @@ class StreamServerConnection final : public ThreadSafeRefCounted<StreamServerCon
bool dispatchOutOfStreamMessage(Decoder&&);

using WakeUpClient = StreamServerConnectionBuffer::WakeUpClient;
Ref<IPC::Connection> m_connection;
const Ref<IPC::Connection> m_connection;
RefPtr<StreamConnectionWorkQueue> m_workQueue;
StreamServerConnectionBuffer m_buffer;

Expand Down
4 changes: 2 additions & 2 deletions Source/WebKit/Scripts/webkit/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,7 +592,7 @@ def async_message_statement(receiver, message):

connection = 'connection, '
if receiver.has_attribute(STREAM_ATTRIBUTE):
connection = 'connection.connection(), '
connection = 'connection.protectedConnection(), '
if receiver.has_attribute(NOT_USING_IPC_CONNECTION_ATTRIBUTE):
connection = ''

Expand Down Expand Up @@ -1170,7 +1170,7 @@ def collect_message_statements(messages, message_statement_function):
result.append(' UNUSED_PARAM(decoder);\n')
result.append(' UNUSED_PARAM(connection);\n')
result.append('#if ENABLE(IPC_TESTING_API)\n')
result.append(' if (connection.connection().ignoreInvalidMessageForTesting())\n')
result.append(' if (connection.protectedConnection()->ignoreInvalidMessageForTesting())\n')
result.append(' return;\n')
result.append('#endif // ENABLE(IPC_TESTING_API)\n')
result.append(' ASSERT_NOT_REACHED_WITH_MESSAGE("Unhandled stream message %s to %" PRIu64, IPC::description(decoder.messageName()), decoder.destinationID());\n')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,11 @@ namespace WebKit {
void TestWithStreamBatched::didReceiveStreamMessage(IPC::StreamServerConnection& connection, IPC::Decoder& decoder)
{
if (decoder.messageName() == Messages::TestWithStreamBatched::SendString::name())
return IPC::handleMessage<Messages::TestWithStreamBatched::SendString>(connection.connection(), decoder, this, &TestWithStreamBatched::sendString);
return IPC::handleMessage<Messages::TestWithStreamBatched::SendString>(connection.protectedConnection(), decoder, this, &TestWithStreamBatched::sendString);
UNUSED_PARAM(decoder);
UNUSED_PARAM(connection);
#if ENABLE(IPC_TESTING_API)
if (connection.connection().ignoreInvalidMessageForTesting())
if (connection.protectedConnection()->ignoreInvalidMessageForTesting())
return;
#endif // ENABLE(IPC_TESTING_API)
ASSERT_NOT_REACHED_WITH_MESSAGE("Unhandled stream message %s to %" PRIu64, IPC::description(decoder.messageName()), decoder.destinationID());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,14 @@ namespace WebKit {
void TestWithStream::didReceiveStreamMessage(IPC::StreamServerConnection& connection, IPC::Decoder& decoder)
{
if (decoder.messageName() == Messages::TestWithStream::SendString::name())
return IPC::handleMessage<Messages::TestWithStream::SendString>(connection.connection(), decoder, this, &TestWithStream::sendString);
return IPC::handleMessage<Messages::TestWithStream::SendString>(connection.protectedConnection(), decoder, this, &TestWithStream::sendString);
if (decoder.messageName() == Messages::TestWithStream::SendStringAsync::name())
return IPC::handleMessageAsync<Messages::TestWithStream::SendStringAsync>(connection.connection(), decoder, this, &TestWithStream::sendStringAsync);
return IPC::handleMessageAsync<Messages::TestWithStream::SendStringAsync>(connection.protectedConnection(), decoder, this, &TestWithStream::sendStringAsync);
if (decoder.messageName() == Messages::TestWithStream::CallWithIdentifier::name())
return IPC::handleMessageAsyncWithReplyID<Messages::TestWithStream::CallWithIdentifier>(connection.connection(), decoder, this, &TestWithStream::callWithIdentifier);
return IPC::handleMessageAsyncWithReplyID<Messages::TestWithStream::CallWithIdentifier>(connection.protectedConnection(), decoder, this, &TestWithStream::callWithIdentifier);
#if PLATFORM(COCOA)
if (decoder.messageName() == Messages::TestWithStream::SendMachSendRight::name())
return IPC::handleMessage<Messages::TestWithStream::SendMachSendRight>(connection.connection(), decoder, this, &TestWithStream::sendMachSendRight);
return IPC::handleMessage<Messages::TestWithStream::SendMachSendRight>(connection.protectedConnection(), decoder, this, &TestWithStream::sendMachSendRight);
#endif
if (decoder.messageName() == Messages::TestWithStream::SendStringSync::name())
return IPC::handleMessageSynchronous<Messages::TestWithStream::SendStringSync>(connection, decoder, this, &TestWithStream::sendStringSync);
Expand All @@ -66,7 +66,7 @@ void TestWithStream::didReceiveStreamMessage(IPC::StreamServerConnection& connec
UNUSED_PARAM(decoder);
UNUSED_PARAM(connection);
#if ENABLE(IPC_TESTING_API)
if (connection.connection().ignoreInvalidMessageForTesting())
if (connection.protectedConnection()->ignoreInvalidMessageForTesting())
return;
#endif // ENABLE(IPC_TESTING_API)
ASSERT_NOT_REACHED_WITH_MESSAGE("Unhandled stream message %s to %" PRIu64, IPC::description(decoder.messageName()), decoder.destinationID());
Expand Down

0 comments on commit 295c392

Please sign in to comment.