Permalink
Browse files

Add support for IF/WHEN/CALL in pseudocompiler

This implements spilling and invalidation to ensure the validity
of register during execution. It generates some pretty lousy code.
  • Loading branch information...
bdw committed Jul 16, 2015
1 parent 0130e96 commit c451b05efff9564d710d517d76e5912f97b4178d
Showing with 256 additions and 39 deletions.
  1. +241 −23 src/jit/compile.c
  2. +3 −3 src/jit/expr.h
  3. +12 −13 src/jit/exprlist
View
@@ -150,17 +150,26 @@ static const X64_REGISTER FREE_REGISTERS[] = {
rax, rcx, rdx, rsi, rdi, r8, r9, r10, r11
};
/* NB - valid only for POSIX call convention */
static const MVMint32 CALL_REGISTERS[] = {
/* rdi, rsi, rdx, rcx, r8, r9 */
4, 3, 2, 1, 5, 6
};
#define NUM_REGISTERS (sizeof(FREE_REGISTERS)/sizeof(X64_REGISTER))
#define NUM_CALL_REGS (sizeof(FREE_REGISTERS)/sizeof(X64_REGISTER))
typedef struct {
MVMint32 reg_used[NUM_REGISTERS]; /* last use of a register */
MVMint32 *nodes_reg; /* register number of last computed operand */
MVMint32 *nodes_reg; /* register number of node */
MVMint32 *nodes_spill; /* spill location of a given node, if spilled */
MVMint32 last_register; /* last register number allocated; used
to implement the 'register ring' */
MVMint32 last_spill; /* last spill threshold, used to ensure
* that every value receives it's own
* node */
MVMint32 cond_depth; /* last used local label depth by
conditional statement */
} CompilerRegisterState;
static void take_register(MVMThreadContext *tc, CompilerRegisterState *state, MVMint32 regnum) {
@@ -177,13 +186,52 @@ static void take_register(MVMThreadContext *tc, CompilerRegisterState *state, MV
MVM_jit_log(tc, "mov [rsp+0x%x], %s\n", spill*sizeof(MVMRegister),
X64_REGISTER_NAMES[FREE_REGISTERS[regnum]]);
/* mark node as spilled */
state->nodes_reg[node] = -1;
state->nodes_reg[node] = -1;
state->nodes_spill[node] = spill;
}
/* mark register as free */
state->reg_used[regnum] = -1;
}
static void emit_full_spill(MVMThreadContext *tc, CompilerRegisterState *state) {
/* Equivalent to simply taking all registers */
MVMint32 i;
for (i = 0; i < NUM_REGISTERS; i++) {
take_register(tc, state, i);
}
}
/* invalidate registers */
static void invalidate_registers(MVMThreadContext *tc, CompilerRegisterState *state) {
MVMint32 i;
for (i = 0; i < NUM_REGISTERS; i++) {
MVMint32 node = state->reg_used[i];
if (node >= 0) {
state->nodes_reg[node] = -1;
state->reg_used[i] = -1;
}
}
}
static void load_node_to(MVMThreadContext *tc, CompilerRegisterState *state, MVMint32 node, MVMint32 regnum) {
MVMint32 cur_reg = state->nodes_reg[node];
if (state->reg_used[regnum] == node)
return;
if (cur_reg >= 0) {
MVM_jit_log(tc, "mov %s, %s\n", X64_REGISTER_NAMES[FREE_REGISTERS[regnum]],
X64_REGISTER_NAMES[FREE_REGISTERS[cur_reg]]);
state->reg_used[cur_reg] = -1; /* is now free */
} else if (state->nodes_spill[node] >= 0) {
MVM_jit_log(tc, "mov %s, [rsp+0x%x]\n", X64_REGISTER_NAMES[FREE_REGISTERS[regnum]],
state->nodes_spill[node]*sizeof(MVMRegister));
} else {
MVM_oops(tc, "Requested load of node %d but not in register or memory\n");
}
state->nodes_reg[node] = regnum;
state->reg_used[regnum] = node;
}
static MVMint32 get_next_register(MVMThreadContext *tc, CompilerRegisterState *state,
X64_REGISTER *regs, MVMint32 nregs) {
MVMint32 i, j;
@@ -202,7 +250,7 @@ static MVMint32 get_next_register(MVMThreadContext *tc, CompilerRegisterState *s
MVM_oops(tc, "Could not allocate a register\n");
}
static void emit_expr_op(MVMThreadContext *tc, MVMJitExprNode op,
static void emit_expr_op(MVMThreadContext *tc, CompilerRegisterState *state, MVMJitExprNode op,
X64_REGISTER *regs, MVMJitExprNode *args) {
switch(op) {
case MVM_JIT_LOAD:
@@ -222,6 +270,26 @@ static void emit_expr_op(MVMThreadContext *tc, MVMJitExprNode op,
MVM_jit_log(tc, "lea %s, [%s+%s*%d]\n", X64_REGISTER_NAMES[regs[2]],
X64_REGISTER_NAMES[regs[0]], X64_REGISTER_NAMES[regs[1]], args[0]);
break;
case MVM_JIT_ADD:
MVM_jit_log(tc, "mov %s, %s\n", X64_REGISTER_NAMES[regs[2]], X64_REGISTER_NAMES[regs[0]]);
MVM_jit_log(tc, "add %s, %s\n", X64_REGISTER_NAMES[regs[2]], X64_REGISTER_NAMES[regs[1]]);
break;
case MVM_JIT_AND:
MVM_jit_log(tc, "mov %s, %s\n", X64_REGISTER_NAMES[regs[2]], X64_REGISTER_NAMES[regs[0]]);
MVM_jit_log(tc, "and %s, %s\n", X64_REGISTER_NAMES[regs[2]], X64_REGISTER_NAMES[regs[1]]);
break;
case MVM_JIT_NZ:
take_register(tc, state, 0);
MVM_jit_log(tc, "test %s, %s\n", X64_REGISTER_NAMES[regs[0]], X64_REGISTER_NAMES[regs[0]]);
MVM_jit_log(tc, "setnz al\n");
MVM_jit_log(tc, "movzx %s, al\n", X64_REGISTER_NAMES[regs[1]]);
break;
case MVM_JIT_ZR:
take_register(tc, state, 0);
MVM_jit_log(tc, "test %s, %s\n", X64_REGISTER_NAMES[regs[0]], X64_REGISTER_NAMES[regs[0]]);
MVM_jit_log(tc, "setnz al\n");
MVM_jit_log(tc, "movzx %s, al\n", X64_REGISTER_NAMES[regs[1]]);
break;
case MVM_JIT_LOCAL:
MVM_jit_log(tc, "mov %s, %s\n", X64_REGISTER_NAMES[regs[0]], X64_REGISTER_NAMES[rbx]);
break;
@@ -237,22 +305,21 @@ static void emit_expr_op(MVMThreadContext *tc, MVMJitExprNode op,
case MVM_JIT_CONST:
MVM_jit_log(tc, "mov %s, 0x%x\n", X64_REGISTER_NAMES[regs[0]], (MVMint32)args[0]);
break;
case MVM_JIT_CALL:
MVM_jit_log(tc, "call %s\n", X64_REGISTER_NAMES[regs[0]]);
break;
default: {
const MVMJitExprOpInfo *info = MVM_jit_expr_op_info(tc, op);
MVM_jit_log(tc, "not yet sure how to compile %s\n", info->name);
}
}
}
static void compile_expr_op(MVMThreadContext *tc, MVMJitTreeTraverser *traverser,
MVMJitExprTree *tree, MVMint32 node) {
CompilerRegisterState *state = traverser->data;
MVMJitExprNode op = tree->nodes[node];
static void load_op_regs(MVMThreadContext *tc, CompilerRegisterState *state,
MVMJitExprTree *tree, MVMint32 node, X64_REGISTER *regs) {
MVMJitExprNode op = tree->nodes[node];
const MVMJitExprOpInfo *info = MVM_jit_expr_op_info(tc, op);
X64_REGISTER regs[8];
MVMint32 i, j, first_child, nchild;
if (traverser->visits[node] > 1) /* no revisits */
return;
MVMint32 i, first_child, nchild;
first_child = node + 1;
nchild = (info->nchild < 0 ? tree->nodes[first_child++] : info->nchild);
/* ensure child nodes have been computed into memory */
@@ -262,28 +329,176 @@ static void compile_expr_op(MVMThreadContext *tc, MVMJitTreeTraverser *traverser
if (MVM_jit_expr_op_info(tc, tree->nodes[child])->vtype == MVM_JIT_VOID)
continue;
if (regnum < 0) {
/* child does not reside in a register */
MVMint32 spill = state->nodes_spill[child];
if (spill < 0) {
MVM_oops(tc, "Child %d of %s is needed, but not register or memory", i, info->name);
}
regnum = get_next_register(tc, state, regs, i);
/* emit load */
MVM_jit_log(tc, "mov %s, [rsp+0x%x]\n", X64_REGISTER_NAMES[FREE_REGISTERS[regnum]], spill);
/* store child as existing in the register */
state->reg_used[regnum] = child;
state->nodes_reg[child] = regnum;
load_node_to(tc, state, child, regnum);
}
regs[i] = FREE_REGISTERS[regnum];
}
if (info->vtype != MVM_JIT_VOID) {
/* assign an output register */
MVMint32 regnum = get_next_register(tc, state, regs, i);
regs[i] = FREE_REGISTERS[regnum];
state->nodes_reg[node] = regnum;
state->nodes_reg[node] = regnum;
state->reg_used[regnum] = node;
}
emit_expr_op(tc, op, regs, tree->nodes + first_child + nchild);
}
static void load_call_regs(MVMThreadContext *tc, CompilerRegisterState *state,
MVMJitExprTree *tree, MVMint32 node, X64_REGISTER *regs) {
MVMint32 arglist = tree->nodes[node+2];
MVMint32 nargs = tree->nodes[arglist+1];
MVMint32 i, func;
for (i = 0; i < nargs; i++) {
MVMint32 carg = tree->nodes[arglist+2+i];
MVMint32 argval = tree->nodes[carg+1];
MVMint32 argtyp = tree->nodes[carg+2];
MVMint32 regnum = CALL_REGISTERS[i];
X64_REGISTER reg = FREE_REGISTERS[regnum];
/* whatever, ignore argtyp */
load_node_to(tc, state, argval, regnum);
regs[i] = reg;
}
/* load function register */
func = get_next_register(tc, state, regs, nargs);
load_node_to(tc, state, tree->nodes[node+1], func);
regs[0] = FREE_REGISTERS[func];
}
static void prepare_expr_op(MVMThreadContext *tc, MVMJitTreeTraverser *traverser,
MVMJitExprTree *tree, MVMint32 node) {
/* Spill before call or if or when */
CompilerRegisterState *state = traverser->data;
MVMJitExprNode op = tree->nodes[node];
if (traverser->visits > 0)
return;
/* Conditional blocks should spill before they are run */
switch (op) {
case MVM_JIT_WHEN:
/* require a label for the statement end */
state->cond_depth += 1;
emit_full_spill(tc, state);
break;
case MVM_JIT_IF:
/* require two labels, one for the alternative block, and one for the statement end */
state->cond_depth += 2;
emit_full_spill(tc, state);
break;
default:
break;
}
}
static void compile_expr_labels(MVMThreadContext *tc, MVMJitTreeTraverser *traverser,
MVMJitExprTree *tree, MVMint32 node, MVMint32 i) {
CompilerRegisterState *state = traverser->data;
MVMJitExprNode op = tree->nodes[node];
if (traverser->visits > 0)
return;
switch (op) {
case MVM_JIT_IF:
/* 'ternary operator' or 'expression style' if */
if (i == 0) {
/* branch to second option (label 1) */
MVM_jit_log(tc, "jnz >%d\n", state->cond_depth - 1);
} else if (i == 1) {
/* move result value into place (i.e. rax) */
load_node_to(tc, state, node + 2, 0);
/* just after first option, branch to end */
MVM_jit_log(tc, "jmp >%d\n", state->cond_depth);
MVM_jit_log(tc, "%d:\n" , state->cond_depth - 1);
/* registers used in the 'then' block are not usable in the 'else' block */
invalidate_registers(tc, state);
} else {
/* move result value into place */
load_node_to(tc, state, node + 3, 0);
/* emit end label */
MVM_jit_log(tc, "%d:\n", state->cond_depth);
invalidate_registers(tc, state);
}
break;
case MVM_JIT_WHEN:
if (i == 0) {
MVM_jit_log(tc, "jnz >%d\n", state->cond_depth);
} else {
MVM_jit_log(tc, "%d:\n", state->cond_depth);
invalidate_registers(tc, state);
}
break;
}
}
static void compile_expr_op(MVMThreadContext *tc, MVMJitTreeTraverser *traverser,
MVMJitExprTree *tree, MVMint32 node) {
CompilerRegisterState *state = traverser->data;
MVMJitExprNode op = tree->nodes[node];
const MVMJitExprOpInfo *info = MVM_jit_expr_op_info(tc, op);
X64_REGISTER regs[8];
MVMJitExprNode *args = tree->nodes + node + (info->nchild < 0 ? tree->nodes[node+1] + 1 : info->nchild) + 1;
if (traverser->visits[node] > 1) /* no revisits */
return;
switch(op) {
case MVM_JIT_CALL:
/* spill before we go */
emit_full_spill(tc, state);
load_call_regs(tc, state, tree, node, regs);
emit_expr_op(tc, state, op, regs, args);
/* all loaded registers are now invalid */
invalidate_registers(tc, state);
if (args[0] != MVM_JIT_RV_VOID) {
state->nodes_reg[node] = 0;
state->reg_used[0] = node;
}
break;
case MVM_JIT_WHEN:
state->cond_depth -= 1;
break;
case MVM_JIT_IF:
/* result value lives in rax (as assured in compile_expr_labels) */
state->reg_used[0] = node;
state->nodes_reg[node] = 0;
state->cond_depth -= 2;
break;
case MVM_JIT_DO:
{
MVMint32 nchild = tree->nodes[node+1];
MVMint32 last_child = tree->nodes[node+1+nchild];
MVMint32 regnum = get_next_register(tc, state, regs, 0);
load_node_to(tc, state, last_child, regnum);
state->reg_used[regnum] = node;
state->nodes_reg[node] = regnum;
break;
}
case MVM_JIT_ALL:
{
MVMint32 nchild = tree->nodes[node+1];
MVMint32 first_child = tree->nodes[node+2];
MVMint32 result, value, i;
if (nchild == 0)
MVM_oops(tc, "No child for ALL is clearly an error");
result = get_next_register(tc, state, regs, 0);
regs[0] = FREE_REGISTERS[result];
value = get_next_register(tc, state, regs, 1);
/* load first child */
load_node_to(tc, state, first_child, result);
for (i = 1; i < nchild; i++) {
MVMint32 child = tree->nodes[node+i+2];
load_node_to(tc, state, child, value);
MVM_jit_log(tc, "and %s, %s\n", X64_REGISTER_NAMES[FREE_REGISTERS[result]],
X64_REGISTER_NAMES[FREE_REGISTERS[value]]);
}
state->nodes_reg[node] = result;
state->reg_used[result] = node;
break;
}
case MVM_JIT_CARG:
case MVM_JIT_ARGLIST:
break;
default:
load_op_regs(tc, state, tree, node, regs);
emit_expr_op(tc, state, op, regs, args);
break;
}
}
@@ -303,12 +518,15 @@ void MVM_jit_compile_expr_tree(MVMThreadContext *tc, MVMJitGraph *jg, MVMJitExpr
state.nodes_spill = MVM_malloc(sizeof(MVMint32)*tree->nodes_num);
state.last_register = 0;
state.last_spill = 0;
state.cond_depth = 0;
memset(state.nodes_reg, -1, sizeof(MVMint32)*tree->nodes_num);
memset(state.nodes_spill, -1, sizeof(MVMint32)*tree->nodes_num);
/* initialize compiler */
memset(&compiler, 0, sizeof(MVMJitTreeTraverser));
compiler.data = &state;
compiler.preorder = &prepare_expr_op;
compiler.inorder = &compile_expr_labels;
compiler.postorder = &compile_expr_op;
MVM_jit_expr_tree_traverse(tc, tree, &compiler);
View
@@ -53,12 +53,12 @@ enum MVMJitExprVtype { /* value type */
_(ANY, -1, 0, FLAG), \
/* control operators */ \
_(DO, -1, 0, REG), \
_(IF, 2, 0, VOID), \
_(IFELSE, 3, 0, REG), \
_(WHEN, 2, 0, VOID), \
_(IF, 3, 0, REG), \
/* call c functions */ \
_(CALL, 2, 1, REG), \
_(ARGLIST, -1, 0, VOID), \
_(CARG, 1, 1, REG), \
_(CARG, 1, 1, VOID), \
/* interpreter special variables */ \
_(TC, 0, 0, REG), \
_(CU, 0, 0, MEM), \
Oops, something went wrong.

0 comments on commit c451b05

Please sign in to comment.