Skip to content
Permalink
Browse files
Add a new pattern to instruction selector to utilize SMADDL supported…
… by ARM64

https://bugs.webkit.org/show_bug.cgi?id=227188

Patch by Yijia Huang <yijia_huang@apple.com> on 2021-06-22
Reviewed by Saam Barati.

Signed Multiply-Add Long(SMADDL), supported by ARM64, multiplies two 32-bit
register values, adds a 64-bit register value, and writes the result to the
64-bit destination register. The instruction selector can utilize this to
lowering certain patterns in B3 IR before further Air optimization.

Given the operation:

smaddl d, n, m, a

The equivalent patterns would be:

d = a + SExt32(n) * SExt32(m)
d = SExt32(n) * SExt32(m) + a

Given B3 IR:
Int @0 = ArgumentReg(%x0)
Int @1 = SExt32(Trunc(ArgumentReg(%x1)))
Int @2 = SExt32(Trunc(ArgumentReg(%x2)))
Int @3 = Mul(@1, @2)
Int @4 = Add(@0, @3)
Void@5 = Return(@4, Terminal)

Before Adding SMADDL:
// Old optimized AIR
SignExtend32ToPtr  %x1, %x1,           @1
SignExtend32ToPtr  %x2, %x2,           @2
MultiplyAdd64      %x1, %x2, %x0, %x0, @4
Ret64              %x0,                @5

After Adding SMADDL:
// New optimized AIR
MultiplyAddSignExtend32 %x1, %x2, %x0, %x0, @8
Ret64                   %x0,                @9

* assembler/MacroAssemblerARM64.h:
(JSC::MacroAssemblerARM64::multiplyAddSignExtend32):
* assembler/testmasm.cpp:
(JSC::testMultiplyAddSignExtend32Left):
(JSC::testMultiplyAddSignExtend32Right):
* b3/B3LowerToAir.cpp:
* b3/air/AirOpcode.opcodes:
* b3/testb3.h:
* b3/testb3_2.cpp:
(testMulAddArg):
(testMulAddArgsLeft):
(testMulAddArgsRight):
(testMulAddSignExtend32ArgsLeft):
(testMulAddSignExtend32ArgsRight):
(testMulAddArgsLeft32):
(testMulAddArgsRight32):
* b3/testb3_3.cpp:
(addArgTests):

Canonical link: https://commits.webkit.org/239048@main
git-svn-id: https://svn.webkit.org/repository/webkit/trunk@279134 268f45cc-cd09-0410-ab3c-d52691b4dbfc
  • Loading branch information
Yijia Huang authored and webkit-commit-queue committed Jun 22, 2021
1 parent e9d0c9d commit d6fa13e3f4900b43650d85655fff5f37b6ccd514
Showing 8 changed files with 241 additions and 19 deletions.
@@ -1,3 +1,63 @@
2021-06-22 Yijia Huang <yijia_huang@apple.com>

Add a new pattern to instruction selector to utilize SMADDL supported by ARM64
https://bugs.webkit.org/show_bug.cgi?id=227188

Reviewed by Saam Barati.

Signed Multiply-Add Long(SMADDL), supported by ARM64, multiplies two 32-bit
register values, adds a 64-bit register value, and writes the result to the
64-bit destination register. The instruction selector can utilize this to
lowering certain patterns in B3 IR before further Air optimization.

Given the operation:

smaddl d, n, m, a

The equivalent patterns would be:

d = a + SExt32(n) * SExt32(m)
d = SExt32(n) * SExt32(m) + a

Given B3 IR:
Int @0 = ArgumentReg(%x0)
Int @1 = SExt32(Trunc(ArgumentReg(%x1)))
Int @2 = SExt32(Trunc(ArgumentReg(%x2)))
Int @3 = Mul(@1, @2)
Int @4 = Add(@0, @3)
Void@5 = Return(@4, Terminal)

Before Adding SMADDL:
// Old optimized AIR
SignExtend32ToPtr %x1, %x1, @1
SignExtend32ToPtr %x2, %x2, @2
MultiplyAdd64 %x1, %x2, %x0, %x0, @4
Ret64 %x0, @5

After Adding SMADDL:
// New optimized AIR
MultiplyAddSignExtend32 %x1, %x2, %x0, %x0, @8
Ret64 %x0, @9

* assembler/MacroAssemblerARM64.h:
(JSC::MacroAssemblerARM64::multiplyAddSignExtend32):
* assembler/testmasm.cpp:
(JSC::testMultiplyAddSignExtend32Left):
(JSC::testMultiplyAddSignExtend32Right):
* b3/B3LowerToAir.cpp:
* b3/air/AirOpcode.opcodes:
* b3/testb3.h:
* b3/testb3_2.cpp:
(testMulAddArg):
(testMulAddArgsLeft):
(testMulAddArgsRight):
(testMulAddSignExtend32ArgsLeft):
(testMulAddSignExtend32ArgsRight):
(testMulAddArgsLeft32):
(testMulAddArgsRight32):
* b3/testb3_3.cpp:
(addArgTests):

2021-06-22 Saam Barati <sbarati@apple.com>

jitCompileAndSetHeuristics shouldn't return true when we fail to compile
@@ -627,6 +627,11 @@ class MacroAssemblerARM64 : public AbstractMacroAssembler<Assembler> {
m_assembler.madd<64>(dest, mulLeft, mulRight, summand);
}

void multiplyAddSignExtend32(RegisterID mulLeft, RegisterID mulRight, RegisterID summand, RegisterID dest)
{
m_assembler.smaddl(dest, mulLeft, mulRight, summand);
}

void multiplySub64(RegisterID mulLeft, RegisterID mulRight, RegisterID minuend, RegisterID dest)
{
m_assembler.msub<64>(dest, mulLeft, mulRight, minuend);
@@ -921,6 +921,52 @@ void testMul32SignExtend()
}
}

void testMultiplyAddSignExtend32Left()
{
// d = SExt32(n) * SExt32(m) + a
auto add = compile([=] (CCallHelpers& jit) {
emitFunctionPrologue(jit);

jit.multiplyAddSignExtend32(GPRInfo::argumentGPR0,
GPRInfo::argumentGPR1,
GPRInfo::argumentGPR2,
GPRInfo::returnValueGPR);

emitFunctionEpilogue(jit);
jit.ret();
});

for (auto n : int32Operands()) {
for (auto m : int32Operands()) {
for (auto a : int64Operands())
CHECK_EQ(invoke<int64_t>(add, n, m, a), static_cast<int64_t>(n) * static_cast<int64_t>(m) + a);
}
}
}

void testMultiplyAddSignExtend32Right()
{
// d = a + SExt32(n) * SExt32(m)
auto add = compile([=] (CCallHelpers& jit) {
emitFunctionPrologue(jit);

jit.multiplyAddSignExtend32(GPRInfo::argumentGPR1,
GPRInfo::argumentGPR2,
GPRInfo::argumentGPR0,
GPRInfo::returnValueGPR);

emitFunctionEpilogue(jit);
jit.ret();
});

for (auto a : int64Operands()) {
for (auto n : int32Operands()) {
for (auto m : int32Operands())
CHECK_EQ(invoke<int64_t>(add, a, n, m), a + static_cast<int64_t>(n) * static_cast<int64_t>(m));
}
}
}

void testSub32Args()
{
for (auto value : int32Operands()) {
@@ -3270,6 +3316,8 @@ void run(const char* filter) WTF_IGNORES_THREAD_SAFETY_ANALYSIS
RUN(testLoadStorePair64Int64());
RUN(testLoadStorePair64Double());
RUN(testMul32SignExtend());
RUN(testMultiplyAddSignExtend32Left());
RUN(testMultiplyAddSignExtend32Right());
RUN(testSub32Args());
RUN(testSub32Imm());
RUN(testSub32ArgImm());
@@ -2544,36 +2544,57 @@ class LowerToAir {
case Add: {
if (tryAppendLea())
return;


ASSERT(isValidForm(MultiplyAdd64, Arg::Tmp, Arg::Tmp, Arg::Tmp, Arg::Tmp)
== isValidForm(MultiplyAddSignExtend32, Arg::Tmp, Arg::Tmp, Arg::Tmp, Arg::Tmp));
Air::Opcode multiplyAddOpcode = tryOpcodeForType(MultiplyAdd32, MultiplyAdd64, m_value->type());
if (isValidForm(multiplyAddOpcode, Arg::Tmp, Arg::Tmp, Arg::Tmp, Arg::Tmp)) {
Value* left = m_value->child(0);
Value* right = m_value->child(1);
if (!imm(right) || m_valueToTmp[right]) {
auto tryAppendMultiplyAdd = [&] (Value* left, Value* right) -> bool {
if (left->opcode() != Mul || !canBeInternal(left))
auto tryMultiply = [&] (Value* v) -> bool {
if (v->opcode() != Mul || !canBeInternal(v))
return false;
if (m_locked.contains(v->child(0)) || m_locked.contains(v->child(1)))
return false;
return true;
};

auto trySExt32 = [&] (Value* v) -> bool {
return v->opcode() == SExt32 && canBeInternal(v);
};

// MADD: d = n * m + a
auto tryAppendMultiplyAdd = [&] (Value* left, Value* right) -> bool {
if (!tryMultiply(left))
return false;
Value* multiplyLeft = left->child(0);
Value* multiplyRight = left->child(1);
if (canBeInternal(multiplyLeft) || canBeInternal(multiplyRight))
return false;

// SMADDL: d = SExt32(n) * SExt32(m) + a
if (multiplyAddOpcode == MultiplyAdd64 && trySExt32(multiplyLeft) && trySExt32(multiplyRight)) {
append(MultiplyAddSignExtend32,
tmp(multiplyLeft->child(0)),
tmp(multiplyRight->child(0)),
tmp(right),
tmp(m_value));
commitInternal(multiplyLeft);
commitInternal(multiplyRight);
commitInternal(left);
return true;
}

append(multiplyAddOpcode, tmp(multiplyLeft), tmp(multiplyRight), tmp(right), tmp(m_value));
commitInternal(left);

return true;
};

if (tryAppendMultiplyAdd(left, right))
return;
if (tryAppendMultiplyAdd(right, left))
if (tryAppendMultiplyAdd(left, right) || tryAppendMultiplyAdd(right, left))
return;
}
}

appendBinOp<Add32, Add64, AddDouble, AddFloat, Commutative>(
m_value->child(0), m_value->child(1));
appendBinOp<Add32, Add64, AddDouble, AddFloat, Commutative>(m_value->child(0), m_value->child(1));
return;
}

@@ -251,6 +251,9 @@ arm64: MultiplyAdd32 U:G:32, U:G:32, U:G:32, ZD:G:32
arm64: MultiplyAdd64 U:G:64, U:G:64, U:G:64, D:G:64
Tmp, Tmp, Tmp, Tmp

arm64: MultiplyAddSignExtend32 U:G:32, U:G:32, U:G:64, D:G:64
Tmp, Tmp, Tmp, Tmp

arm64: MultiplySub32 U:G:32, U:G:32, U:G:32, ZD:G:32
Tmp, Tmp, Tmp, Tmp

@@ -872,6 +872,8 @@ void testMulAddArgsLeft();
void testMulAddArgsRight();
void testMulAddArgsLeft32();
void testMulAddArgsRight32();
void testMulAddSignExtend32ArgsLeft();
void testMulAddSignExtend32ArgsRight();
void testMulSubArgsLeft();
void testMulSubArgsRight();
void testMulSubArgsLeft32();
@@ -841,7 +841,10 @@ void testMulAddArg(int a)
root->appendNew<Value>(proc, Mul, Origin(), value, value),
value));

CHECK(compileAndRun<int>(proc, a) == a * a + a);
auto code = compileProc(proc);
if (isARM64())
checkUsesInstruction(*code, "madd");
CHECK(invoke<int64_t>(*code, a, a, a) == a * a + a);
}

void testMulArgs(int a, int b)
@@ -1004,13 +1007,14 @@ void testMulAddArgsLeft()
root->appendNewControlValue(proc, Return, Origin(), added);

auto code = compileProc(proc);
if (isARM64())
checkUsesInstruction(*code, "madd");

auto testValues = int64Operands();
for (auto a : testValues) {
for (auto b : testValues) {
for (auto c : testValues) {
for (auto c : testValues)
CHECK(invoke<int64_t>(*code, a.value, b.value, c.value) == a.value * b.value + c.value);
}
}
}
}
@@ -1028,12 +1032,87 @@ void testMulAddArgsRight()
root->appendNewControlValue(proc, Return, Origin(), added);

auto code = compileProc(proc);
if (isARM64())
checkUsesInstruction(*code, "madd");

auto testValues = int64Operands();
for (auto a : testValues) {
for (auto b : testValues) {
for (auto c : testValues) {
for (auto c : testValues)
CHECK(invoke<int64_t>(*code, a.value, b.value, c.value) == a.value + b.value * c.value);
}
}
}

void testMulAddSignExtend32ArgsLeft()
{
// d = SExt32(n) * SExt32(m) + a
Procedure proc;
BasicBlock* root = proc.addBlock();

Value* nValue = root->appendNew<Value>(
proc, SExt32, Origin(),
root->appendNew<Value>(
proc, Trunc, Origin(),
root->appendNew<ArgumentRegValue>(proc, Origin(), GPRInfo::argumentGPR0)));
Value* mValue = root->appendNew<Value>(
proc, SExt32, Origin(),
root->appendNew<Value>(
proc, Trunc, Origin(),
root->appendNew<ArgumentRegValue>(proc, Origin(), GPRInfo::argumentGPR1)));
Value* aValue = root->appendNew<ArgumentRegValue>(proc, Origin(), GPRInfo::argumentGPR2);

Value* mulValue = root->appendNew<Value>(proc, Mul, Origin(), nValue, mValue);
Value* addValue = root->appendNew<Value>(proc, Add, Origin(), mulValue, aValue);
root->appendNewControlValue(proc, Return, Origin(), addValue);

auto code = compileProc(proc);
if (isARM64())
checkUsesInstruction(*code, "smaddl");

for (auto n : int32Operands()) {
for (auto m : int32Operands()) {
for (auto a : int64Operands()) {
int64_t lhs = invoke<int64_t>(*code, n.value, m.value, a.value);
int64_t rhs = static_cast<int64_t>(n.value) * static_cast<int64_t>(m.value) + a.value;
CHECK(lhs == rhs);
}
}
}
}

void testMulAddSignExtend32ArgsRight()
{
// d = a + SExt32(n) * SExt32(m)
Procedure proc;
BasicBlock* root = proc.addBlock();

Value* aValue = root->appendNew<ArgumentRegValue>(proc, Origin(), GPRInfo::argumentGPR0);
Value* nValue = root->appendNew<Value>(
proc, SExt32, Origin(),
root->appendNew<Value>(
proc, Trunc, Origin(),
root->appendNew<ArgumentRegValue>(proc, Origin(), GPRInfo::argumentGPR1)));
Value* mValue = root->appendNew<Value>(
proc, SExt32, Origin(),
root->appendNew<Value>(
proc, Trunc, Origin(),
root->appendNew<ArgumentRegValue>(proc, Origin(), GPRInfo::argumentGPR2)));

Value* mulValue = root->appendNew<Value>(proc, Mul, Origin(), nValue, mValue);
Value* addValue = root->appendNew<Value>(proc, Add, Origin(), aValue, mulValue);
root->appendNewControlValue(proc, Return, Origin(), addValue);

auto code = compileProc(proc);
if (isARM64())
checkUsesInstruction(*code, "smaddl");

for (auto a : int64Operands()) {
for (auto n : int32Operands()) {
for (auto m : int32Operands()) {
int64_t lhs = invoke<int64_t>(*code, a.value, n.value, m.value);
int64_t rhs = a.value + static_cast<int64_t>(n.value) * static_cast<int64_t>(m.value);
CHECK(lhs == rhs);
}
}
}
@@ -1055,13 +1134,14 @@ void testMulAddArgsLeft32()
root->appendNewControlValue(proc, Return, Origin(), added);

auto code = compileProc(proc);
if (isARM64())
checkUsesInstruction(*code, "madd");

auto testValues = int32Operands();
for (auto a : testValues) {
for (auto b : testValues) {
for (auto c : testValues) {
for (auto c : testValues)
CHECK(invoke<int32_t>(*code, a.value, b.value, c.value) == a.value * b.value + c.value);
}
}
}
}
@@ -1082,13 +1162,14 @@ void testMulAddArgsRight32()
root->appendNewControlValue(proc, Return, Origin(), added);

auto code = compileProc(proc);
if (isARM64())
checkUsesInstruction(*code, "madd");

auto testValues = int32Operands();
for (auto a : testValues) {
for (auto b : testValues) {
for (auto c : testValues) {
for (auto c : testValues)
CHECK(invoke<int32_t>(*code, a.value, b.value, c.value) == a.value + b.value * c.value);
}
}
}
}
@@ -3196,6 +3196,8 @@ void addArgTests(const char* filter, Deque<RefPtr<SharedTask<void()>>>& tasks)
RUN(testMulAddArgsRight());
RUN(testMulAddArgsLeft32());
RUN(testMulAddArgsRight32());
RUN(testMulAddSignExtend32ArgsLeft());
RUN(testMulAddSignExtend32ArgsRight());
RUN(testMulSubArgsLeft());
RUN(testMulSubArgsRight());
RUN(testMulSubArgsLeft32());

0 comments on commit d6fa13e

Please sign in to comment.