Skip to content

Commit

Permalink
Implement register allocator in JIT
Browse files Browse the repository at this point in the history
Signed-off-by: Zoltan Herczeg zherczeg.u-szeged@partner.samsung.com
  • Loading branch information
Zoltan Herczeg committed May 9, 2024
1 parent d728d28 commit 24b3fb6
Show file tree
Hide file tree
Showing 19 changed files with 1,802 additions and 260 deletions.
4 changes: 3 additions & 1 deletion src/jit/Analysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ void JITCompiler::buildVariables(uint32_t requiredStackSize)
DependencyGenContext dependencyCtx(dependencySize, requiredStackSize);
bool updateDeps = true;

m_variableList = new VariableList(variableCount);
m_variableList = new VariableList(variableCount, requiredStackSize);
nextTryBlock = m_tryBlockStart;

for (uint32_t i = 0; i < requiredStackSize; i++) {
Expand Down Expand Up @@ -359,6 +359,8 @@ void JITCompiler::buildVariables(uint32_t requiredStackSize)
const ValueTypeVector& param = module()->functionType(tagType->sigIndex())->param();
Label* catchLabel = it.u.handler;

m_variableList->pushCatchUpdate(catchLabel, param.size());

dependencyCtx.update(catchLabel->m_dependencyStart, catchLabel->id(),
STACK_OFFSET(it.stackSizeToBe), param, m_variableList);
}
Expand Down
346 changes: 312 additions & 34 deletions src/jit/Backend.cpp

Large diffs are not rendered by default.

77 changes: 56 additions & 21 deletions src/jit/ByteCodeParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,31 +151,29 @@ static bool isFloatGlobal(uint32_t globalIndex, Module* module)
OL3(OTOp2F32, /* SSD */ F32, F32, F32 | S0 | S1) \
OL3(OTOp2F64, /* SSD */ F64, F64, F64 | S0 | S1) \
OL1(OTGetI32, /* S */ I32) \
OL1(OTPutI32, /* D */ I32 | TMP) \
OL1(OTPutI32, /* D */ I32) \
OL1(OTPutI64, /* D */ I64) \
OL1(OTPutV128, /* D */ V128) \
OL1(OTPutPTR, /* D */ PTR) \
OL2(OTMoveI32, /* SD */ I32, I32 | S0) \
OL2(OTMoveF32, /* SD */ F32 | NOTMP, F32 | S0) \
OL2(OTMoveI64, /* SD */ I64, I64 | S0) \
OL2(OTMoveF64, /* SD */ F64 | NOTMP, F64 | S0) \
OL2(OTMoveV128, /* SD */ V128, V128 | S0) \
OL3(OTCompareI64, /* SSD */ I64, I64, I32 | S0 | S1) \
OL2(OTEqzI64, /* SD */ I64, I32) \
OL3(OTCompareI64, /* SSD */ I64, I64, I32) \
OL3(OTCompareF32, /* SSD */ F32, F32, I32) \
OL3(OTCompareF64, /* SSD */ F64, F64, I32) \
OL3(OTCopySignF32, /* SSD */ F32, F32, F32 | TMP | S0 | S1) \
OL3(OTCopySignF64, /* SSD */ F64, F64, F64 | TMP | S0 | S1) \
OL2(OTDemoteF64, /* SD */ F64, F32 | S0) \
OL2(OTPromoteF32, /* SD */ F32, F64 | S0) \
OL4(OTLoadI32, /* SDTT */ I32, I32 | S0, PTR, I32 | S0) \
OL4(OTLoadI64, /* SDTT */ I32, I64 | S0, PTR, I32 | S0) \
OL4(OTLoadF32, /* SDTT */ I32, F32, PTR, I32 | S0) \
OL4(OTLoadF64, /* SDTT */ I32, F64, PTR, I32 | S0) \
OL4(OTLoadV128, /* SDTT */ I32, V128 | TMP, PTR, I32 | S0) \
OL5(OTLoadLaneV128, /* SSDTTT */ I32, V128 | NOTMP, V128 | TMP | S1, PTR, I32 | S0) \
OL5(OTStoreI32, /* SSTTT */ I32, I32, PTR, I32 | S0, I32 | S1) \
OL5(OTStoreI64, /* SSTTT */ I32, I64, PTR, I32 | S0, PTR | S1) \
OL4(OTStoreF32, /* SSTT */ I32, F32 | NOTMP, PTR, I32 | S0) \
OL5(OTStoreI64, /* SSTTT */ I32, I64, PTR, I32 | S0, PTR | S1) \
OL4(OTStoreF64, /* SSTT */ I32, F64 | NOTMP, PTR, I32 | S0) \
OL4(OTStoreV128, /* SSTT */ I32, V128 | TMP, PTR, I32 | S0) \
OL3(OTCallback3Arg, /* SSS */ I32, I32, I32) \
Expand All @@ -188,8 +186,8 @@ static bool isFloatGlobal(uint32_t globalIndex, Module* module)
OL2(OTGlobalSetI64, /* ST */ I64, PTR) \
OL1(OTGlobalSetF32, /* S */ F32 | NOTMP) \
OL1(OTGlobalSetF64, /* S */ F64 | NOTMP) \
OL2(OTConvertInt32FromInt64, /* SD */ I64, I32 | S0) \
OL2(OTConvertInt64FromInt32, /* SD */ I32, I64 | S0) \
OL2(OTConvertInt32FromInt64, /* SD */ I64, I32) \
OL2(OTConvertInt64FromInt32, /* SD */ I32, I64) \
OL2(OTConvertInt32FromFloat32, /* SD */ F32 | TMP, I32 | TMP) \
OL2(OTConvertInt32FromFloat64, /* SD */ F64 | TMP, I32 | TMP) \
OL2(OTConvertInt64FromFloat32Callback, /* SD */ F32, I64) \
Expand All @@ -207,9 +205,10 @@ static bool isFloatGlobal(uint32_t globalIndex, Module* module)
#define OPERAND_TYPE_LIST_MATH \
OL3(OTOp2I64, /* SSD */ I64, I64, I64 | TMP | S0 | S1) \
OL3(OTShiftI64, /* SSD */ I64, I64 | LOW, I64 | TMP | S0) \
OL4(OTMulI64, /* SSDT */ I64, I64, I64 | TMP | S0, I32 | S1) \
OL3(OTMulI64, /* SSDT */ I64, I64, I64 | S0 | S1) \
OL3(OTDivRemI64, /* SSD */ I64, I64, I64 | S0 | S1) \
OL2(OTCountZeroesI64, /* SD */ I64, I64 | TMP | S0) \
OL4(OTLoadI64, /* SDTT */ I32, I64, PTR, I32 | S0) \
OL5(OTStoreI64Low, /* SSTTT */ I32, I64 | LOW, PTR, I32 | S0, PTR | S1) \
OL1(OTGlobalGetI64, /* D */ I64_LOW) \
OL2(OTConvertInt32FromFloat32Callback, /* SD */ F32, I32) \
Expand All @@ -220,6 +219,7 @@ static bool isFloatGlobal(uint32_t globalIndex, Module* module)

#define OPERAND_TYPE_LIST_MATH \
OL3(OTOp2I64, /* SSD */ I64, I64, I64 | S0 | S1) \
OL4(OTLoadI64, /* SDTT */ I32, I64 | S0, PTR, I32 | S0) \
OL1(OTGlobalGetI64, /* D */ I64) \
OL2(OTConvertInt64FromFloat32, /* SD */ F32 | TMP, I64 | TMP) \
OL2(OTConvertInt64FromFloat64, /* SD */ F64 | TMP, I64 | TMP) \
Expand All @@ -234,8 +234,8 @@ static bool isFloatGlobal(uint32_t globalIndex, Module* module)
OL1(OTGlobalSetV128, /* S */ V128 | NOTMP) \
OL2(OTSplatI32, /* SD */ I32, V128 | TMP) \
OL2(OTSplatI64, /* SD */ I64, V128 | TMP) \
OL2(OTSplatF32, /* SD */ F32 | NOTMP, V128 | TMP | S0) \
OL2(OTSplatF64, /* SD */ F64 | NOTMP, V128 | TMP | S0) \
OL2(OTSplatF32, /* SD */ F32 | NOTMP, V128 | TMP) \
OL2(OTSplatF64, /* SD */ F64 | NOTMP, V128 | TMP) \
OL2(OTV128ToI32, /* SD */ V128 | TMP, I32) \
OL4(OTBitSelectV128, /* SSSD */ V128 | TMP, V128 | TMP, V128 | NOTMP, V128 | TMP | S2) \
OL2(OTExtractLaneI64, /* SD */ V128 | TMP, I64) \
Expand Down Expand Up @@ -565,7 +565,7 @@ static void compileFunction(JITCompiler* compiler)
case ByteCode::I64GeUOpcode: {
group = Instruction::Compare;
paramCount = 2;
info = Instruction::kIsMergeCompare;
info = Instruction::kIsMergeCompare | Instruction::kFreeUnusedEarly;
requiredInit = OTCompareI64;
break;
}
Expand Down Expand Up @@ -723,20 +723,22 @@ static void compileFunction(JITCompiler* compiler)
case ByteCode::I64EqzOpcode: {
group = Instruction::Compare;
paramCount = 1;
info = Instruction::kIsMergeCompare;
requiredInit = OTOp1I64;
info = Instruction::kIsMergeCompare | Instruction::kFreeUnusedEarly;
requiredInit = OTEqzI64;
break;
}
case ByteCode::I32WrapI64Opcode: {
group = Instruction::Convert;
paramCount = 1;
info = Instruction::kFreeUnusedEarly;
requiredInit = OTConvertInt32FromInt64;
break;
}
case ByteCode::I64ExtendI32SOpcode:
case ByteCode::I64ExtendI32UOpcode: {
group = Instruction::Convert;
paramCount = 1;
info = Instruction::kFreeUnusedEarly;
requiredInit = OTConvertInt64FromInt32;
break;
}
Expand Down Expand Up @@ -934,6 +936,7 @@ static void compileFunction(JITCompiler* compiler)
Instruction* instr = compiler->appendExtended(byteCode, Instruction::Call, opcode,
functionType->param().size() + callerCount, functionType->result().size());
Operand* operand = instr->operands();
instr->addInfo(Instruction::kIsCallback | Instruction::kFreeUnusedEarly);

for (auto it : functionType->param()) {
operand->ref = STACK_OFFSET(*stackOffset);
Expand Down Expand Up @@ -1206,12 +1209,14 @@ static void compileFunction(JITCompiler* compiler)
case ByteCode::F32X4SplatOpcode: {
group = Instruction::SplatSIMD;
paramCount = 1;
info = Instruction::kFreeUnusedEarly;
requiredInit = OTSplatF32;
break;
}
case ByteCode::F64X2SplatOpcode: {
group = Instruction::SplatSIMD;
paramCount = 1;
info = Instruction::kFreeUnusedEarly;
requiredInit = OTSplatF64;
break;
}
Expand Down Expand Up @@ -1460,13 +1465,13 @@ static void compileFunction(JITCompiler* compiler)

switch (opcode) {
case ByteCode::MoveI32Opcode:
requiredInit = OTMoveI32;
requiredInit = OTOp1I32;
break;
case ByteCode::MoveF32Opcode:
requiredInit = OTMoveF32;
break;
case ByteCode::MoveI64Opcode:
requiredInit = OTMoveI64;
requiredInit = OTOp1I64;
break;
case ByteCode::MoveF64Opcode:
requiredInit = OTMoveF64;
Expand All @@ -1493,7 +1498,6 @@ static void compileFunction(JITCompiler* compiler)
Operand* operands = instr->operands();

if (isFloatGlobal(globalGet32->index(), compiler->module())) {
instr->addInfo(Instruction::kIsGlobalFloatBit);
instr->setRequiredRegsDescriptor(OTGlobalGetF32);
}

Expand All @@ -1508,7 +1512,6 @@ static void compileFunction(JITCompiler* compiler)
Operand* operands = instr->operands();

if (isFloatGlobal(globalGet64->index(), compiler->module())) {
instr->addInfo(Instruction::kIsGlobalFloatBit);
instr->setRequiredRegsDescriptor(OTGlobalGetF64);
}

Expand All @@ -1533,7 +1536,6 @@ static void compileFunction(JITCompiler* compiler)
Operand* operands = instr->operands();

if (isFloatGlobal(globalSet32->index(), compiler->module())) {
instr->addInfo(Instruction::kIsGlobalFloatBit);
instr->setRequiredRegsDescriptor(OTGlobalSetF32);
}

Expand All @@ -1548,7 +1550,6 @@ static void compileFunction(JITCompiler* compiler)
Operand* operands = instr->operands();

if (isFloatGlobal(globalSet64->index(), compiler->module())) {
instr->addInfo(Instruction::kIsGlobalFloatBit);
instr->setRequiredRegsDescriptor(OTGlobalSetF64);
}

Expand Down Expand Up @@ -1899,19 +1900,53 @@ static void compileFunction(JITCompiler* compiler)
}

compiler->buildVariables(STACK_OFFSET(function->requiredStackSize()));
compiler->allocateRegisters();

if (compiler->verboseLevel() >= 1) {
compiler->dump();
}

compiler->allocateRegisters();
compiler->freeVariables();

Walrus::JITFunction* jitFunc = new JITFunction();

function->setJITFunction(jitFunc);
compiler->compileFunction(jitFunc, true);
}

const uint8_t* VariableList::getOperandDescriptor(Instruction* instr)
{
uint32_t requiredInit = OTNone;

switch (instr->opcode()) {
case ByteCode::Load32Opcode:
requiredInit = OTLoadF32;
break;
case ByteCode::Load64Opcode:
requiredInit = OTLoadF64;
break;
case ByteCode::Store32Opcode:
requiredInit = OTStoreF32;
break;
case ByteCode::Store64Opcode:
requiredInit = OTStoreF64;
break;
default:
break;
}

if (requiredInit != OTNone) {
ASSERT((instr->paramCount() + instr->resultCount()) == 2);
VariableList::Variable& variable = variables[instr->getParam(1)->ref];

if (variable.info & Instruction::FloatOperandMarker) {
return Instruction::getOperandDescriptorByOffset(requiredInit);
}
}

return instr->getOperandDescriptor();
}

void Module::jitCompile(ModuleFunction** functions, size_t functionsLength, int verboseLevel)
{
JITCompiler compiler(this, verboseLevel);
Expand Down
45 changes: 40 additions & 5 deletions src/jit/CallInl.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,19 +95,40 @@ static void emitCall(sljit_compiler* compiler, Instruction* instr)

Operand* operand = instr->operands();
for (auto it : functionType->param()) {
if (VARIABLE_TYPE(operand->ref) == Operand::Immediate && !(VARIABLE_GET_IMM(operand->ref)->info() & Instruction::kKeepInstruction)) {
emitStoreImmediate(compiler, *stackOffset, VARIABLE_GET_IMM(operand->ref));
Operand dst;
dst.ref = VARIABLE_SET(STACK_OFFSET(*stackOffset), Operand::Offset);

switch (VARIABLE_TYPE(operand->ref)) {
case Operand::Immediate:
ASSERT(!(VARIABLE_GET_IMM(operand->ref)->info() & Instruction::kKeepInstruction));
emitStoreImmediate(compiler, &dst, VARIABLE_GET_IMM(operand->ref), false);
break;
case Operand::Register:
emitMove(compiler, Instruction::valueTypeToOperandType(it), operand, &dst);
break;
}

operand++;
stackOffset += (valueSize(it) + (sizeof(size_t) - 1)) / sizeof(size_t);
}

if (instr->opcode() == ByteCode::CallIndirectOpcode) {
if (VARIABLE_TYPE(operand->ref) == Operand::Immediate && !(VARIABLE_GET_IMM(operand->ref)->info() & Instruction::kKeepInstruction)) {
CallIndirect* callIndirect = reinterpret_cast<CallIndirect*>(instr->byteCode());
emitStoreImmediate(compiler, callIndirect->calleeOffset(), VARIABLE_GET_IMM(operand->ref));
CallIndirect* callIndirect = reinterpret_cast<CallIndirect*>(instr->byteCode());

switch (VARIABLE_TYPE(operand->ref)) {
case Operand::Immediate: {
ASSERT(!(VARIABLE_GET_IMM(operand->ref)->info() & Instruction::kKeepInstruction));
Const32* value = reinterpret_cast<Const32*>(VARIABLE_GET_IMM(operand->ref)->byteCode());
sljit_emit_op1(compiler, SLJIT_MOV32, SLJIT_MEM1(kFrameReg), callIndirect->calleeOffset(), SLJIT_IMM, static_cast<sljit_s32>(value->value()));
break;
}
case Operand::Register: {
sljit_emit_op1(compiler, SLJIT_MOV32, SLJIT_MEM1(kFrameReg), callIndirect->calleeOffset(), static_cast<sljit_s32>(VARIABLE_GET_REF(operand->ref)), 0);
break;
}
}

operand++;
}

sljit_emit_op1(compiler, SLJIT_MOV_P, SLJIT_R0, 0, SLJIT_IMM, reinterpret_cast<sljit_sw>(instr->byteCode()));
Expand All @@ -118,6 +139,20 @@ static void emitCall(sljit_compiler* compiler, Instruction* instr)

sljit_jump* jump = sljit_emit_cmp(compiler, SLJIT_NOT_EQUAL, SLJIT_R0, 0, SLJIT_IMM, ExecutionContext::NoError);

for (auto it : functionType->result()) {
ASSERT(VARIABLE_TYPE(operand->ref) != Operand::Immediate);

if (VARIABLE_TYPE(operand->ref) == Operand::Register) {
Operand src;

src.ref = VARIABLE_SET(STACK_OFFSET(*stackOffset), Operand::Offset);
emitMove(compiler, Instruction::valueTypeToOperandType(it), &src, operand);
}

operand++;
stackOffset += (valueSize(it) + (sizeof(size_t) - 1)) / sizeof(size_t);
}

if (context->currentTryBlock == InstanceConstData::globalTryBlock) {
context->appendTrapJump(ExecutionContext::ReturnToLabel, jump);
return;
Expand Down

0 comments on commit 24b3fb6

Please sign in to comment.