Accelerate WTF::equal using NEON intrinsics
https://bugs.webkit.org/show_bug.cgi?id=263734
rdar://117541948

Reviewed by Yusuke Suzuki.

1. Adds a hash check to WTF::equal(StringImpl*, StringImpl*) to quickly rule
out a match before comparing characters (thanks @hyjorc1!).
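A note on why this is safe (the actual change is in the StringImpl.cpp hunk
below): StringImpl computes its hash lazily, and rawHash() returns 0 until the
hash has been computed, so the early-out only fires when both strings already
carry hashes and those hashes differ. Sketch, mirroring the hunk:

    bool equal(const StringImpl& a, const StringImpl& b)
    {
        unsigned aHash = a.rawHash(); // 0 means "hash not computed yet".
        unsigned bHash = b.rawHash();
        if (aHash != bHash && aHash && bHash)
            return false; // Both hashes known and different: cannot be equal.
        return equalCommon(a, b); // Otherwise fall back to full comparison.
    }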

2. Reworks the implementation of WTF::equal for LChar and UChar strings.
Instead of checking for leftover bytes at the end, we use overlapping loads to
reduce the amount of branching per string. And we now switch up-front on the
string length instead of falling through to other branches, reducing branching
for most small strings.
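To illustrate the overlapping-load trick: any length within a power-of-two
bucket is covered by two fixed-size loads, one anchored at the start and one
at the end of the string; the loads may overlap in the middle, but together
they inspect every byte with no tail loop and no per-byte branches. A minimal,
portable sketch (std::memcpy stands in for WTF's unalignedLoad; the helper
name is made up for illustration):

    #include <cstdint>
    #include <cstring>

    // Compares two buffers of length 5..8 using exactly two 4-byte loads each.
    static bool equal5Through8(const uint8_t* a, const uint8_t* b, unsigned length)
    {
        uint32_t aHead, bHead, aTail, bTail;
        std::memcpy(&aHead, a, 4);              // First 4 bytes.
        std::memcpy(&bHead, b, 4);
        std::memcpy(&aTail, a + length - 4, 4); // Last 4 bytes; overlaps the head when length < 8.
        std::memcpy(&bTail, b + length - 4, 4);
        return aHead == bHead && aTail == bTail;
    }

The up-front switch buckets the length using 32 - clz(length - 1), i.e. the
bit width of length - 1, so each power-of-two range of lengths gets its own
branch-free case.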

3. When CPU(ARM_NEON) is defined, use 16-byte SIMD chunks to do the comparison,
whenever strings are large enough.
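The NEON reduction works because vceqq_u8 sets each byte lane to 0xFF where
the inputs match and 0x00 where they differ, and vminvq_u8 takes the minimum
across all 16 lanes, so the result is nonzero exactly when every byte matched.
A self-contained sketch of one 16-byte chunk (assumes an ARM64 toolchain; the
helper name is illustrative):

    #include <arm_neon.h>
    #include <cstdint>

    // True iff the 16 bytes at a and b are identical.
    static bool equal16Bytes(const uint8_t* a, const uint8_t* b)
    {
        uint8x16_t eq = vceqq_u8(vld1q_u8(a), vld1q_u8(b)); // 0xFF per matching lane.
        return vminvq_u8(eq) != 0; // Minimum lane is nonzero iff all lanes matched.
    }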

4. Also adds a SIMD loop to WTF::equal(const LChar*, const UChar*, unsigned),
which is dramatically (~5x) faster than the current byte-by-byte pessimistic
implementation.
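The mixed-width loop avoids per-character conversion by widening eight LChars
at a time: vmovl_u8 zero-extends a uint8x8_t into a uint16x8_t, which can then
be compared lane-for-lane against eight UChars. One 8-character step, sketched
with standard intrinsics (helper name is illustrative):

    #include <arm_neon.h>
    #include <cstdint>

    // True iff 8 Latin-1 characters at a equal 8 UTF-16 code units at b.
    static bool equal8LCharUChar(const uint8_t* a, const uint16_t* b)
    {
        uint16x8_t aWide = vmovl_u8(vld1_u8(a)); // Zero-extend 8 bytes to 8 16-bit lanes.
        uint16x8_t bWide = vld1q_u16(b);
        return vminvq_u16(vceqq_u16(aWide, bWide)) != 0;
    }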

* Source/WTF/wtf/PlatformCPU.h:
* Source/WTF/wtf/text/StringCommon.h:
(WTF::equal):
* Source/WTF/wtf/text/StringImpl.cpp:
(WTF::equal):
* Source/WTF/wtf/text/StringImpl.h:

Canonical link: https://commits.webkit.org/269866@main
ddegazio committed Oct 27, 2023
1 parent f3216a0 commit 661be0f
Showing 4 changed files with 109 additions and 53 deletions.
2 changes: 1 addition & 1 deletion Source/WTF/wtf/PlatformCPU.h
@@ -254,7 +254,7 @@
 # error "Cannot use both of WTF_CPU_ARM_TRADITIONAL and WTF_CPU_ARM_THUMB2 platforms"
 #endif /* !defined(WTF_CPU_ARM_TRADITIONAL) && !defined(WTF_CPU_ARM_THUMB2) */
 
-#if defined(__ARM_NEON__) && !defined(WTF_CPU_ARM_NEON)
+#if (defined(__ARM_NEON__) || defined(__ARM_NEON)) && !defined(WTF_CPU_ARM_NEON)
 #define WTF_CPU_ARM_NEON 1
 #endif
 
154 changes: 102 additions & 52 deletions Source/WTF/wtf/text/StringCommon.h
@@ -29,6 +29,7 @@
 #include <algorithm>
 #include <unicode/uchar.h>
 #include <wtf/ASCIICType.h>
+#include <wtf/MathExtras.h>
 #include <wtf/NotFound.h>
 #include <wtf/UnalignedAccess.h>
 #include <wtf/text/ASCIIFastPath.h>
@@ -70,72 +71,105 @@ bool equalLettersIgnoringASCIICase(const char*, ASCIILiteral);
 #if (CPU(X86_64) || CPU(ARM64)) && !ASAN_ENABLED
 ALWAYS_INLINE bool equal(const LChar* aLChar, const LChar* bLChar, unsigned length)
 {
-    unsigned dwordLength = length >> 3;
-
-    const char* a = reinterpret_cast<const char*>(aLChar);
-    const char* b = reinterpret_cast<const char*>(bLChar);
-
-    if (dwordLength) {
-        for (unsigned i = 0; i != dwordLength; ++i) {
-            if (unalignedLoad<uint64_t>(a) != unalignedLoad<uint64_t>(b))
-                return false;
-
-            a += sizeof(uint64_t);
-            b += sizeof(uint64_t);
-        }
-    }
-
-    if (length & 4) {
-        if (unalignedLoad<uint32_t>(a) != unalignedLoad<uint32_t>(b))
-            return false;
-
-        a += sizeof(uint32_t);
-        b += sizeof(uint32_t);
-    }
-
-    if (length & 2) {
-        if (unalignedLoad<uint16_t>(a) != unalignedLoad<uint16_t>(b))
-            return false;
-
-        a += sizeof(uint16_t);
-        b += sizeof(uint16_t);
-    }
-
-    if (length & 1 && (*reinterpret_cast<const LChar*>(a) != *reinterpret_cast<const LChar*>(b)))
-        return false;
-
-    return true;
+    // These branches could be combined into one, but it's measurably faster
+    // for length 0 or 1 strings to separate them out like this.
+    if (!length)
+        return true;
+    if (length == 1)
+        return *aLChar == *bLChar;
+
+#if COMPILER(GCC_COMPATIBLE)
+    switch (sizeof(unsigned) * CHAR_BIT - clz(length - 1)) { // Works as really fast log2, since length != 0.
+#else
+    switch (fastLog2(length - 1)) {
+#endif
+    case 0:
+        RELEASE_ASSERT_NOT_REACHED();
+    case 1: // Length is 2.
+        return unalignedLoad<uint16_t>(aLChar) == unalignedLoad<uint16_t>(bLChar);
+    case 2: // Length is 3 or 4.
+        return unalignedLoad<uint16_t>(aLChar) == unalignedLoad<uint16_t>(bLChar)
+            && unalignedLoad<uint16_t>(aLChar + length - 2) == unalignedLoad<uint16_t>(bLChar + length - 2);
+    case 3: // Length is between 5 and 8 inclusive.
+        return unalignedLoad<uint32_t>(aLChar) == unalignedLoad<uint32_t>(bLChar)
+            && unalignedLoad<uint32_t>(aLChar + length - 4) == unalignedLoad<uint32_t>(bLChar + length - 4);
+    case 4: // Length is between 9 and 16 inclusive.
+        return unalignedLoad<uint64_t>(aLChar) == unalignedLoad<uint64_t>(bLChar)
+            && unalignedLoad<uint64_t>(aLChar + length - 8) == unalignedLoad<uint64_t>(bLChar + length - 8);
+#if CPU(ARM64)
+    case 5: // Length is between 17 and 32 inclusive.
+        return vminvq_u8(vandq_u8(
+            vceqq_u8(unalignedLoad<uint8x16_t>(aLChar), unalignedLoad<uint8x16_t>(bLChar)),
+            vceqq_u8(unalignedLoad<uint8x16_t>(aLChar + length - 16), unalignedLoad<uint8x16_t>(bLChar + length - 16))
+        ));
+    default: // Length is longer than 32 bytes.
+        if (!vminvq_u8(vceqq_u8(unalignedLoad<uint8x16_t>(aLChar), unalignedLoad<uint8x16_t>(bLChar))))
+            return false;
+        for (unsigned i = length % 16; i < length; i += 16) {
+            if (!vminvq_u8(vceqq_u8(unalignedLoad<uint8x16_t>(aLChar + i), unalignedLoad<uint8x16_t>(bLChar + i))))
+                return false;
+        }
+        return true;
+#else
+    default: // Length is longer than 16 bytes.
+        if (unalignedLoad<uint64_t>(aLChar) != unalignedLoad<uint64_t>(bLChar))
+            return false;
+        for (unsigned i = length % 8; i < length; i += 8) {
+            if (unalignedLoad<uint64_t>(aLChar + i) != unalignedLoad<uint64_t>(bLChar + i))
+                return false;
+        }
+        return true;
+#endif
+    }
 }
 
 ALWAYS_INLINE bool equal(const UChar* aUChar, const UChar* bUChar, unsigned length)
 {
-    unsigned dwordLength = length >> 2;
-
-    const char* a = reinterpret_cast<const char*>(aUChar);
-    const char* b = reinterpret_cast<const char*>(bUChar);
-
-    if (dwordLength) {
-        for (unsigned i = 0; i != dwordLength; ++i) {
-            if (unalignedLoad<uint64_t>(a) != unalignedLoad<uint64_t>(b))
-                return false;
-
-            a += sizeof(uint64_t);
-            b += sizeof(uint64_t);
-        }
-    }
-
-    if (length & 2) {
-        if (unalignedLoad<uint32_t>(a) != unalignedLoad<uint32_t>(b))
-            return false;
-
-        a += sizeof(uint32_t);
-        b += sizeof(uint32_t);
-    }
-
-    if (length & 1 && (*reinterpret_cast<const UChar*>(a) != *reinterpret_cast<const UChar*>(b)))
-        return false;
-
-    return true;
+    if (!length)
+        return true;
+    if (length == 1)
+        return *aUChar == *bUChar;
+
+#if COMPILER(GCC_COMPATIBLE)
+    switch (sizeof(unsigned) * CHAR_BIT - clz(length - 1)) { // Works as really fast log2, since length != 0.
+#else
+    switch (fastLog2(length - 1)) {
+#endif
+    case 0:
+        RELEASE_ASSERT_NOT_REACHED();
+    case 1: // Length is 2 (4 bytes).
+        return unalignedLoad<uint32_t>(aUChar) == unalignedLoad<uint32_t>(bUChar);
+    case 2: // Length is 3 or 4 (6-8 bytes).
+        return unalignedLoad<uint32_t>(aUChar) == unalignedLoad<uint32_t>(bUChar)
+            && unalignedLoad<uint32_t>(aUChar + length - 2) == unalignedLoad<uint32_t>(bUChar + length - 2);
+    case 3: // Length is between 5 and 8 inclusive (10-16 bytes).
+        return unalignedLoad<uint64_t>(aUChar) == unalignedLoad<uint64_t>(bUChar)
+            && unalignedLoad<uint64_t>(aUChar + length - 4) == unalignedLoad<uint64_t>(bUChar + length - 4);
+#if CPU(ARM64)
+    case 4: // Length is between 9 and 16 inclusive (18-32 bytes).
+        return vminvq_u16(vandq_u16(
+            vceqq_u16(unalignedLoad<uint16x8_t>(aUChar), unalignedLoad<uint16x8_t>(bUChar)),
+            vceqq_u16(unalignedLoad<uint16x8_t>(aUChar + length - 8), unalignedLoad<uint16x8_t>(bUChar + length - 8))
+        ));
+    default: // Length is longer than 16 (32 bytes).
+        if (!vminvq_u16(vceqq_u16(unalignedLoad<uint16x8_t>(aUChar), unalignedLoad<uint16x8_t>(bUChar))))
+            return false;
+        for (unsigned i = length % 8; i < length; i += 8) {
+            if (!vminvq_u16(vceqq_u16(unalignedLoad<uint16x8_t>(aUChar + i), unalignedLoad<uint16x8_t>(bUChar + i))))
+                return false;
+        }
+        return true;
+#else
+    default: // Length is longer than 8 (16 bytes).
+        if (unalignedLoad<uint64_t>(aUChar) != unalignedLoad<uint64_t>(bUChar))
+            return false;
+        for (unsigned i = length % 4; i < length; i += 4) {
+            if (unalignedLoad<uint64_t>(aUChar + i) != unalignedLoad<uint64_t>(bUChar + i))
+                return false;
+        }
+        return true;
+#endif
+    }
 }
 #elif CPU(X86) && !ASAN_ENABLED
 ALWAYS_INLINE bool equal(const LChar* aLChar, const LChar* bLChar, unsigned length)
@@ -296,6 +330,22 @@ ALWAYS_INLINE bool equal(const UChar* a, const UChar* b, unsigned length)
 
 ALWAYS_INLINE bool equal(const LChar* a, const UChar* b, unsigned length)
 {
+#if CPU(ARM64)
+    if (length >= 8) {
+        uint16x8_t aHalves = vmovl_u8(unalignedLoad<uint8x8_t>(a)); // Extends 8 LChars into 8 UChars.
+        uint16x8_t bHalves = unalignedLoad<uint16x8_t>(b);
+        if (!vminvq_u16(vceqq_u16(aHalves, bHalves)))
+            return false;
+        for (unsigned i = length % 8; i < length; i += 8) {
+            aHalves = vmovl_u8(unalignedLoad<uint8x8_t>(a + i));
+            bHalves = unalignedLoad<uint16x8_t>(b + i);
+            if (!vminvq_u16(vceqq_u16(aHalves, bHalves)))
+                return false;
+        }
+        return true;
+    }
+    // Otherwise, we just do a naive loop.
+#endif
     for (unsigned i = 0; i < length; ++i) {
         if (a[i] != b[i])
             return false;
4 changes: 4 additions & 0 deletions Source/WTF/wtf/text/StringImpl.cpp
@@ -1458,6 +1458,10 @@ bool equal(const StringImpl* a, const LChar* b)
 
 bool equal(const StringImpl& a, const StringImpl& b)
 {
+    unsigned aHash = a.rawHash();
+    unsigned bHash = b.rawHash();
+    if (aHash != bHash && aHash && bHash)
+        return false;
     return equalCommon(a, b);
 }
 
2 changes: 2 additions & 0 deletions Source/WTF/wtf/text/StringImpl.h
@@ -197,6 +197,8 @@ class StringImpl : private StringImplShape {
     template<typename> friend struct WTF::BufferFromStaticDataTranslator;
     template<typename> friend struct WTF::HashAndCharactersTranslator;
 
+    friend WTF_EXPORT_PRIVATE bool equal(const StringImpl&, const StringImpl&);
+
 public:
     enum BufferOwnership { BufferInternal, BufferOwned, BufferSubstring, BufferExternal };
 
