Skip to content

Commit

Permalink
Use std::span more in WebSocketHandshake
Browse files Browse the repository at this point in the history
https://bugs.webkit.org/show_bug.cgi?id=271650

Reviewed by Darin Adler.

* Source/WebCore/Modules/websockets/WebSocketHandshake.cpp:
(WebCore::trimInputSample):
(WebCore::WebSocketHandshake::readServerHandshake):
(WebCore::WebSocketHandshake::readStatusLine):
(WebCore::WebSocketHandshake::readHTTPHeaders):
* Source/WebCore/Modules/websockets/WebSocketHandshake.h:
* Source/WebKit/NetworkProcess/curl/WebSocketTaskCurl.cpp:
(WebKit::WebSocketTask::validateOpeningHandshake):
* Source/WebKitLegacy/WebCoreSupport/WebSocketChannel.cpp:
(WebCore::WebSocketChannel::processBuffer):

Canonical link: https://commits.webkit.org/276696@main
  • Loading branch information
cdumez committed Mar 26, 2024
1 parent bbaf174 commit 0827d78
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 57 deletions.
102 changes: 50 additions & 52 deletions Source/WebCore/Modules/websockets/WebSocketHandshake.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,11 @@ static String hostName(const URL& url, bool secure)
}

static constexpr size_t maxInputSampleSize = 128;
static String trimInputSample(const uint8_t* p, size_t length)
static String trimInputSample(std::span<const uint8_t> input)
{
if (length <= maxInputSampleSize)
return String(p, length);
return makeString(std::span { p, length }.first(maxInputSampleSize), horizontalEllipsis);
if (input.size() <= maxInputSampleSize)
return input;
return makeString(input.first(maxInputSampleSize), horizontalEllipsis);
}

static String generateSecWebSocketKey()
Expand Down Expand Up @@ -216,17 +216,18 @@ void WebSocketHandshake::reset()
m_extensionDispatcher.reset();
}

int WebSocketHandshake::readServerHandshake(const uint8_t* header, size_t len)
int WebSocketHandshake::readServerHandshake(std::span<const uint8_t> header)
{
ASSERT(header.size() <= static_cast<size_t>(std::numeric_limits<int>::max()));
m_mode = Incomplete;
int statusCode;
String statusText;
int lineLength = readStatusLine(header, len, statusCode, statusText);
int lineLength = readStatusLine(header, statusCode, statusText);
if (lineLength == -1)
return -1;
if (statusCode == -1) {
m_mode = Failed; // m_failureReason is set inside readStatusLine().
return len;
return header.size();
}
LOG(Network, "WebSocketHandshake %p readServerHandshake() Status code is %d", this, statusCode);

Expand All @@ -237,28 +238,28 @@ int WebSocketHandshake::readServerHandshake(const uint8_t* header, size_t len)
if (statusCode != 101) {
m_mode = Failed;
m_failureReason = makeString("Unexpected response code: ", statusCode);
return len;
return header.size();
}
m_mode = Normal;
if (!memmem(header, len, "\r\n\r\n", 4)) {
if (!memmem(header.data(), header.size(), "\r\n\r\n", 4)) {
// Just hasn't been received fully yet.
m_mode = Incomplete;
return -1;
}
auto p = readHTTPHeaders(header + lineLength, header + len);
if (!p) {
auto p = readHTTPHeaders(header.subspan(lineLength));
if (!p.data()) {
LOG(Network, "WebSocketHandshake %p readServerHandshake() readHTTPHeaders() failed", this);
m_mode = Failed; // m_failureReason is set inside readHTTPHeaders().
return len;
return header.size();
}
if (!checkResponseHeaders()) {
LOG(Network, "WebSocketHandshake %p readServerHandshake() checkResponseHeaders() failed", this);
m_mode = Failed;
return p - header;
return p.data() - header.data();
}

m_mode = Connected;
return p - header;
return p.data() - header.data();
}

WebSocketHandshake::Mode WebSocketHandshake::mode() const
Expand Down Expand Up @@ -356,7 +357,7 @@ static inline bool headerHasValidHTTPVersion(StringView httpStatusLine)
// Returns the header length (including "\r\n"), or -1 if we have not received enough data yet.
// If the line is malformed or the status code is not a 3-digit number,
// statusCode and statusText will be set to -1 and a null string, respectively.
int WebSocketHandshake::readStatusLine(const uint8_t* header, size_t headerLength, int& statusCode, String& statusText)
int WebSocketHandshake::readStatusLine(std::span<const uint8_t> header, int& statusCode, String& statusText)
{
// Arbitrary size limit to prevent the server from sending an unbounded
// amount of data with no newlines and forcing us to buffer it all.
Expand All @@ -365,57 +366,55 @@ int WebSocketHandshake::readStatusLine(const uint8_t* header, size_t headerLengt
statusCode = -1;
statusText = nullAtom();

const uint8_t* space1 = nullptr;
const uint8_t* space2 = nullptr;
const uint8_t* p;
size_t consumedLength;

for (p = header, consumedLength = 0; consumedLength < headerLength; p++, consumedLength++) {
if (*p == ' ') {
if (!space1)
space1 = p;
else if (!space2)
space2 = p;
} else if (*p == '\0') {
std::optional<size_t> firstSpaceIndex;
std::optional<size_t> secondSpaceIndex;

size_t index = 0;
for (; index < header.size(); ++index) {
if (header[index] == ' ') {
if (!firstSpaceIndex)
firstSpaceIndex = index;
else if (!secondSpaceIndex)
secondSpaceIndex = index;
} else if (header[index] == '\0') {
// The caller isn't prepared to deal with null bytes in status
// line. WebSockets specification doesn't prohibit this, but HTTP
// does, so we'll just treat this as an error.
m_failureReason = "Status line contains embedded null"_s;
return p + 1 - header;
} else if (!isASCII(*p)) {
return index + 1;
} else if (!isASCII(header[index])) {
m_failureReason = "Status line contains non-ASCII character"_s;
return p + 1 - header;
} else if (*p == '\n')
return index + 1;
} else if (header[index] == '\n')
break;
}
if (consumedLength == headerLength)
if (index == header.size())
return -1; // We have not received '\n' yet.

auto end = p + 1;
int lineLength = end - header;
int lineLength = index + 1;
if (lineLength > maximumLength) {
m_failureReason = "Status line is too long"_s;
return maximumLength;
}

// The line must end with "\r\n".
if (lineLength < 2 || *(end - 2) != '\r') {
if (lineLength < 2 || header[index - 1] != '\r') {
m_failureReason = "Status line does not end with CRLF"_s;
return lineLength;
}

if (!space1 || !space2) {
m_failureReason = makeString("No response code found: ", trimInputSample(header, lineLength - 2));
if (!firstSpaceIndex || !secondSpaceIndex) {
m_failureReason = makeString("No response code found: ", trimInputSample(header.first(lineLength - 2)));
return lineLength;
}

StringView httpStatusLine(std::span(header, space1 - header));
StringView httpStatusLine(header.first(*firstSpaceIndex));
if (!headerHasValidHTTPVersion(httpStatusLine)) {
m_failureReason = makeString("Invalid HTTP version string: ", httpStatusLine);
return lineLength;
}

StringView statusCodeString(std::span(space1 + 1, space2 - space1 - 1));
StringView statusCodeString(header.subspan(*firstSpaceIndex + 1, *secondSpaceIndex - *firstSpaceIndex - 1));
if (statusCodeString.length() != 3) // Status code must consist of three digits.
return lineLength;
for (int i = 0; i < 3; ++i) {
Expand All @@ -426,23 +425,22 @@ int WebSocketHandshake::readStatusLine(const uint8_t* header, size_t headerLengt
}

statusCode = parseInteger<int>(statusCodeString).value();
statusText = String(space2 + 1, end - space2 - 3); // Exclude "\r\n".
statusText = String(header.subspan(*secondSpaceIndex + 1, index - *secondSpaceIndex - 3)); // Exclude "\r\n".
return lineLength;
}

const uint8_t* WebSocketHandshake::readHTTPHeaders(const uint8_t* start, const uint8_t* end)
std::span<const uint8_t> WebSocketHandshake::readHTTPHeaders(std::span<const uint8_t> data)
{
StringView name;
String value;
bool sawSecWebSocketExtensionsHeaderField = false;
bool sawSecWebSocketAcceptHeaderField = false;
bool sawSecWebSocketProtocolHeaderField = false;
auto p = start;
for (; p < end; p++) {
size_t consumedLength = parseHTTPHeader(std::span(p, end - p), m_failureReason, name, value);
for (; !data.empty(); data = data.subspan(1)) {
size_t consumedLength = parseHTTPHeader(data, m_failureReason, name, value);
if (!consumedLength)
return nullptr;
p += consumedLength;
return { };
data = data.subspan(consumedLength);

// Stop once we consumed an empty line.
if (name.isEmpty())
Expand All @@ -462,38 +460,38 @@ const uint8_t* WebSocketHandshake::readHTTPHeaders(const uint8_t* start, const u
|| headerName == HTTPHeaderName::SecWebSocketProtocol)
&& !value.containsOnlyASCII()) {
m_failureReason = makeString(name, " header value should only contain ASCII characters");
return nullptr;
return { };
}

if (headerName == HTTPHeaderName::SecWebSocketExtensions) {
if (sawSecWebSocketExtensionsHeaderField) {
m_failureReason = "The Sec-WebSocket-Extensions header must not appear more than once in an HTTP response"_s;
return nullptr;
return { };
}
if (!m_extensionDispatcher.processHeaderValue(value)) {
m_failureReason = m_extensionDispatcher.failureReason();
return nullptr;
return { };
}
sawSecWebSocketExtensionsHeaderField = true;
} else {
if (headerName == HTTPHeaderName::SecWebSocketAccept) {
if (sawSecWebSocketAcceptHeaderField) {
m_failureReason = "The Sec-WebSocket-Accept header must not appear more than once in an HTTP response"_s;
return nullptr;
return { };
}
sawSecWebSocketAcceptHeaderField = true;
} else if (headerName == HTTPHeaderName::SecWebSocketProtocol) {
if (sawSecWebSocketProtocolHeaderField) {
m_failureReason = "The Sec-WebSocket-Protocol header must not appear more than once in an HTTP response"_s;
return nullptr;
return { };
}
sawSecWebSocketProtocolHeaderField = true;
}

m_serverHandshakeResponse.addHTTPHeaderField(headerName, value);
}
}
return p;
return data;
}

bool WebSocketHandshake::checkResponseHeaders()
Expand Down
6 changes: 3 additions & 3 deletions Source/WebCore/Modules/websockets/WebSocketHandshake.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class WebSocketHandshake {

WEBCORE_EXPORT void reset();

WEBCORE_EXPORT int readServerHandshake(const uint8_t* header, size_t len);
WEBCORE_EXPORT int readServerHandshake(std::span<const uint8_t> header);
WEBCORE_EXPORT Mode mode() const;
WEBCORE_EXPORT String failureReason() const; // Returns a string indicating the reason of failure if mode() == Failed.

Expand All @@ -88,10 +88,10 @@ class WebSocketHandshake {

private:

int readStatusLine(const uint8_t* header, size_t headerLength, int& statusCode, String& statusText);
int readStatusLine(std::span<const uint8_t> header, int& statusCode, String& statusText);

// Reads all headers except for the two predefined ones.
const uint8_t* readHTTPHeaders(const uint8_t* start, const uint8_t* end);
std::span<const uint8_t> readHTTPHeaders(std::span<const uint8_t>);
void processHeaders();
bool checkResponseHeaders();

Expand Down
2 changes: 1 addition & 1 deletion Source/WebKit/NetworkProcess/curl/WebSocketTaskCurl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ Expected<bool, String> WebSocketTask::validateOpeningHandshake()
return makeUnexpected("Unexpected handshakeing condition"_s);
}

auto headerLength = m_handshake->readServerHandshake(m_receiveBuffer.data(), m_receiveBuffer.size());
auto headerLength = m_handshake->readServerHandshake(m_receiveBuffer.span());
if (headerLength <= 0)
return false;

Expand Down
2 changes: 1 addition & 1 deletion Source/WebKitLegacy/WebCoreSupport/WebSocketChannel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@ bool WebSocketChannel::processBuffer()
Ref<WebSocketChannel> protectedThis(*this); // The client can close the channel, potentially removing the last reference.

if (m_handshake->mode() == WebSocketHandshake::Incomplete) {
int headerLength = m_handshake->readServerHandshake(m_buffer.data(), m_buffer.size());
int headerLength = m_handshake->readServerHandshake(m_buffer.span());
if (headerLength <= 0)
return false;
if (m_handshake->mode() == WebSocketHandshake::Connected) {
Expand Down

0 comments on commit 0827d78

Please sign in to comment.