Skip to content

Commit c451b05

Browse files
committed
Add support for IF/WHEN/CALL in pseudocompiler
This implements spilling and invalidation to ensure the validity of register during execution. It generates some pretty lousy code.
1 parent 0130e96 commit c451b05

File tree

3 files changed

+256
-39
lines changed

3 files changed

+256
-39
lines changed

src/jit/compile.c

Lines changed: 241 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -150,17 +150,26 @@ static const X64_REGISTER FREE_REGISTERS[] = {
150150
rax, rcx, rdx, rsi, rdi, r8, r9, r10, r11
151151
};
152152

153+
/* NB - valid only for POSIX call convention */
154+
static const MVMint32 CALL_REGISTERS[] = {
155+
/* rdi, rsi, rdx, rcx, r8, r9 */
156+
4, 3, 2, 1, 5, 6
157+
};
158+
153159
#define NUM_REGISTERS (sizeof(FREE_REGISTERS)/sizeof(X64_REGISTER))
160+
#define NUM_CALL_REGS (sizeof(FREE_REGISTERS)/sizeof(X64_REGISTER))
154161

155162
typedef struct {
156163
MVMint32 reg_used[NUM_REGISTERS]; /* last use of a register */
157-
MVMint32 *nodes_reg; /* register number of last computed operand */
164+
MVMint32 *nodes_reg; /* register number of node */
158165
MVMint32 *nodes_spill; /* spill location of a given node, if spilled */
159166
MVMint32 last_register; /* last register number allocated; used
160167
to implement the 'register ring' */
161168
MVMint32 last_spill; /* last spill threshold, used to ensure
162169
* that every value receives it's own
163170
* node */
171+
MVMint32 cond_depth; /* last used local label depth by
172+
conditional statement */
164173
} CompilerRegisterState;
165174

166175
static void take_register(MVMThreadContext *tc, CompilerRegisterState *state, MVMint32 regnum) {
@@ -177,13 +186,52 @@ static void take_register(MVMThreadContext *tc, CompilerRegisterState *state, MV
177186
MVM_jit_log(tc, "mov [rsp+0x%x], %s\n", spill*sizeof(MVMRegister),
178187
X64_REGISTER_NAMES[FREE_REGISTERS[regnum]]);
179188
/* mark node as spilled */
180-
state->nodes_reg[node] = -1;
189+
state->nodes_reg[node] = -1;
181190
state->nodes_spill[node] = spill;
182191
}
183192
/* mark register as free */
184193
state->reg_used[regnum] = -1;
185194
}
186195

196+
197+
static void emit_full_spill(MVMThreadContext *tc, CompilerRegisterState *state) {
198+
/* Equivalent to simply taking all registers */
199+
MVMint32 i;
200+
for (i = 0; i < NUM_REGISTERS; i++) {
201+
take_register(tc, state, i);
202+
}
203+
}
204+
205+
/* invalidate registers */
206+
static void invalidate_registers(MVMThreadContext *tc, CompilerRegisterState *state) {
207+
MVMint32 i;
208+
for (i = 0; i < NUM_REGISTERS; i++) {
209+
MVMint32 node = state->reg_used[i];
210+
if (node >= 0) {
211+
state->nodes_reg[node] = -1;
212+
state->reg_used[i] = -1;
213+
}
214+
}
215+
}
216+
217+
static void load_node_to(MVMThreadContext *tc, CompilerRegisterState *state, MVMint32 node, MVMint32 regnum) {
218+
MVMint32 cur_reg = state->nodes_reg[node];
219+
if (state->reg_used[regnum] == node)
220+
return;
221+
if (cur_reg >= 0) {
222+
MVM_jit_log(tc, "mov %s, %s\n", X64_REGISTER_NAMES[FREE_REGISTERS[regnum]],
223+
X64_REGISTER_NAMES[FREE_REGISTERS[cur_reg]]);
224+
state->reg_used[cur_reg] = -1; /* is now free */
225+
} else if (state->nodes_spill[node] >= 0) {
226+
MVM_jit_log(tc, "mov %s, [rsp+0x%x]\n", X64_REGISTER_NAMES[FREE_REGISTERS[regnum]],
227+
state->nodes_spill[node]*sizeof(MVMRegister));
228+
} else {
229+
MVM_oops(tc, "Requested load of node %d but not in register or memory\n");
230+
}
231+
state->nodes_reg[node] = regnum;
232+
state->reg_used[regnum] = node;
233+
}
234+
187235
static MVMint32 get_next_register(MVMThreadContext *tc, CompilerRegisterState *state,
188236
X64_REGISTER *regs, MVMint32 nregs) {
189237
MVMint32 i, j;
@@ -202,7 +250,7 @@ static MVMint32 get_next_register(MVMThreadContext *tc, CompilerRegisterState *s
202250
MVM_oops(tc, "Could not allocate a register\n");
203251
}
204252

205-
static void emit_expr_op(MVMThreadContext *tc, MVMJitExprNode op,
253+
static void emit_expr_op(MVMThreadContext *tc, CompilerRegisterState *state, MVMJitExprNode op,
206254
X64_REGISTER *regs, MVMJitExprNode *args) {
207255
switch(op) {
208256
case MVM_JIT_LOAD:
@@ -222,6 +270,26 @@ static void emit_expr_op(MVMThreadContext *tc, MVMJitExprNode op,
222270
MVM_jit_log(tc, "lea %s, [%s+%s*%d]\n", X64_REGISTER_NAMES[regs[2]],
223271
X64_REGISTER_NAMES[regs[0]], X64_REGISTER_NAMES[regs[1]], args[0]);
224272
break;
273+
case MVM_JIT_ADD:
274+
MVM_jit_log(tc, "mov %s, %s\n", X64_REGISTER_NAMES[regs[2]], X64_REGISTER_NAMES[regs[0]]);
275+
MVM_jit_log(tc, "add %s, %s\n", X64_REGISTER_NAMES[regs[2]], X64_REGISTER_NAMES[regs[1]]);
276+
break;
277+
case MVM_JIT_AND:
278+
MVM_jit_log(tc, "mov %s, %s\n", X64_REGISTER_NAMES[regs[2]], X64_REGISTER_NAMES[regs[0]]);
279+
MVM_jit_log(tc, "and %s, %s\n", X64_REGISTER_NAMES[regs[2]], X64_REGISTER_NAMES[regs[1]]);
280+
break;
281+
case MVM_JIT_NZ:
282+
take_register(tc, state, 0);
283+
MVM_jit_log(tc, "test %s, %s\n", X64_REGISTER_NAMES[regs[0]], X64_REGISTER_NAMES[regs[0]]);
284+
MVM_jit_log(tc, "setnz al\n");
285+
MVM_jit_log(tc, "movzx %s, al\n", X64_REGISTER_NAMES[regs[1]]);
286+
break;
287+
case MVM_JIT_ZR:
288+
take_register(tc, state, 0);
289+
MVM_jit_log(tc, "test %s, %s\n", X64_REGISTER_NAMES[regs[0]], X64_REGISTER_NAMES[regs[0]]);
290+
MVM_jit_log(tc, "setnz al\n");
291+
MVM_jit_log(tc, "movzx %s, al\n", X64_REGISTER_NAMES[regs[1]]);
292+
break;
225293
case MVM_JIT_LOCAL:
226294
MVM_jit_log(tc, "mov %s, %s\n", X64_REGISTER_NAMES[regs[0]], X64_REGISTER_NAMES[rbx]);
227295
break;
@@ -237,22 +305,21 @@ static void emit_expr_op(MVMThreadContext *tc, MVMJitExprNode op,
237305
case MVM_JIT_CONST:
238306
MVM_jit_log(tc, "mov %s, 0x%x\n", X64_REGISTER_NAMES[regs[0]], (MVMint32)args[0]);
239307
break;
308+
case MVM_JIT_CALL:
309+
MVM_jit_log(tc, "call %s\n", X64_REGISTER_NAMES[regs[0]]);
310+
break;
240311
default: {
241312
const MVMJitExprOpInfo *info = MVM_jit_expr_op_info(tc, op);
242313
MVM_jit_log(tc, "not yet sure how to compile %s\n", info->name);
243314
}
244315
}
245316
}
246317

247-
static void compile_expr_op(MVMThreadContext *tc, MVMJitTreeTraverser *traverser,
248-
MVMJitExprTree *tree, MVMint32 node) {
249-
CompilerRegisterState *state = traverser->data;
250-
MVMJitExprNode op = tree->nodes[node];
318+
static void load_op_regs(MVMThreadContext *tc, CompilerRegisterState *state,
319+
MVMJitExprTree *tree, MVMint32 node, X64_REGISTER *regs) {
320+
MVMJitExprNode op = tree->nodes[node];
251321
const MVMJitExprOpInfo *info = MVM_jit_expr_op_info(tc, op);
252-
X64_REGISTER regs[8];
253-
MVMint32 i, j, first_child, nchild;
254-
if (traverser->visits[node] > 1) /* no revisits */
255-
return;
322+
MVMint32 i, first_child, nchild;
256323
first_child = node + 1;
257324
nchild = (info->nchild < 0 ? tree->nodes[first_child++] : info->nchild);
258325
/* ensure child nodes have been computed into memory */
@@ -262,28 +329,176 @@ static void compile_expr_op(MVMThreadContext *tc, MVMJitTreeTraverser *traverser
262329
if (MVM_jit_expr_op_info(tc, tree->nodes[child])->vtype == MVM_JIT_VOID)
263330
continue;
264331
if (regnum < 0) {
265-
/* child does not reside in a register */
266-
MVMint32 spill = state->nodes_spill[child];
267-
if (spill < 0) {
268-
MVM_oops(tc, "Child %d of %s is needed, but not register or memory", i, info->name);
269-
}
270332
regnum = get_next_register(tc, state, regs, i);
271-
/* emit load */
272-
MVM_jit_log(tc, "mov %s, [rsp+0x%x]\n", X64_REGISTER_NAMES[FREE_REGISTERS[regnum]], spill);
273-
/* store child as existing in the register */
274-
state->reg_used[regnum] = child;
275-
state->nodes_reg[child] = regnum;
333+
load_node_to(tc, state, child, regnum);
276334
}
277335
regs[i] = FREE_REGISTERS[regnum];
278336
}
279337
if (info->vtype != MVM_JIT_VOID) {
280338
/* assign an output register */
281339
MVMint32 regnum = get_next_register(tc, state, regs, i);
282340
regs[i] = FREE_REGISTERS[regnum];
283-
state->nodes_reg[node] = regnum;
341+
state->nodes_reg[node] = regnum;
284342
state->reg_used[regnum] = node;
285343
}
286-
emit_expr_op(tc, op, regs, tree->nodes + first_child + nchild);
344+
}
345+
346+
static void load_call_regs(MVMThreadContext *tc, CompilerRegisterState *state,
347+
MVMJitExprTree *tree, MVMint32 node, X64_REGISTER *regs) {
348+
MVMint32 arglist = tree->nodes[node+2];
349+
MVMint32 nargs = tree->nodes[arglist+1];
350+
MVMint32 i, func;
351+
for (i = 0; i < nargs; i++) {
352+
MVMint32 carg = tree->nodes[arglist+2+i];
353+
MVMint32 argval = tree->nodes[carg+1];
354+
MVMint32 argtyp = tree->nodes[carg+2];
355+
MVMint32 regnum = CALL_REGISTERS[i];
356+
X64_REGISTER reg = FREE_REGISTERS[regnum];
357+
/* whatever, ignore argtyp */
358+
load_node_to(tc, state, argval, regnum);
359+
regs[i] = reg;
360+
}
361+
/* load function register */
362+
func = get_next_register(tc, state, regs, nargs);
363+
load_node_to(tc, state, tree->nodes[node+1], func);
364+
regs[0] = FREE_REGISTERS[func];
365+
}
366+
367+
static void prepare_expr_op(MVMThreadContext *tc, MVMJitTreeTraverser *traverser,
368+
MVMJitExprTree *tree, MVMint32 node) {
369+
/* Spill before call or if or when */
370+
CompilerRegisterState *state = traverser->data;
371+
MVMJitExprNode op = tree->nodes[node];
372+
if (traverser->visits > 0)
373+
return;
374+
/* Conditional blocks should spill before they are run */
375+
switch (op) {
376+
case MVM_JIT_WHEN:
377+
/* require a label for the statement end */
378+
state->cond_depth += 1;
379+
emit_full_spill(tc, state);
380+
break;
381+
case MVM_JIT_IF:
382+
/* require two labels, one for the alternative block, and one for the statement end */
383+
state->cond_depth += 2;
384+
emit_full_spill(tc, state);
385+
break;
386+
default:
387+
break;
388+
}
389+
}
390+
391+
static void compile_expr_labels(MVMThreadContext *tc, MVMJitTreeTraverser *traverser,
392+
MVMJitExprTree *tree, MVMint32 node, MVMint32 i) {
393+
CompilerRegisterState *state = traverser->data;
394+
MVMJitExprNode op = tree->nodes[node];
395+
if (traverser->visits > 0)
396+
return;
397+
switch (op) {
398+
case MVM_JIT_IF:
399+
/* 'ternary operator' or 'expression style' if */
400+
if (i == 0) {
401+
/* branch to second option (label 1) */
402+
MVM_jit_log(tc, "jnz >%d\n", state->cond_depth - 1);
403+
} else if (i == 1) {
404+
/* move result value into place (i.e. rax) */
405+
load_node_to(tc, state, node + 2, 0);
406+
/* just after first option, branch to end */
407+
MVM_jit_log(tc, "jmp >%d\n", state->cond_depth);
408+
MVM_jit_log(tc, "%d:\n" , state->cond_depth - 1);
409+
/* registers used in the 'then' block are not usable in the 'else' block */
410+
invalidate_registers(tc, state);
411+
} else {
412+
/* move result value into place */
413+
load_node_to(tc, state, node + 3, 0);
414+
/* emit end label */
415+
MVM_jit_log(tc, "%d:\n", state->cond_depth);
416+
invalidate_registers(tc, state);
417+
}
418+
break;
419+
case MVM_JIT_WHEN:
420+
if (i == 0) {
421+
MVM_jit_log(tc, "jnz >%d\n", state->cond_depth);
422+
} else {
423+
MVM_jit_log(tc, "%d:\n", state->cond_depth);
424+
invalidate_registers(tc, state);
425+
}
426+
break;
427+
}
428+
}
429+
430+
static void compile_expr_op(MVMThreadContext *tc, MVMJitTreeTraverser *traverser,
431+
MVMJitExprTree *tree, MVMint32 node) {
432+
CompilerRegisterState *state = traverser->data;
433+
MVMJitExprNode op = tree->nodes[node];
434+
const MVMJitExprOpInfo *info = MVM_jit_expr_op_info(tc, op);
435+
X64_REGISTER regs[8];
436+
MVMJitExprNode *args = tree->nodes + node + (info->nchild < 0 ? tree->nodes[node+1] + 1 : info->nchild) + 1;
437+
if (traverser->visits[node] > 1) /* no revisits */
438+
return;
439+
switch(op) {
440+
case MVM_JIT_CALL:
441+
/* spill before we go */
442+
emit_full_spill(tc, state);
443+
load_call_regs(tc, state, tree, node, regs);
444+
emit_expr_op(tc, state, op, regs, args);
445+
/* all loaded registers are now invalid */
446+
invalidate_registers(tc, state);
447+
if (args[0] != MVM_JIT_RV_VOID) {
448+
state->nodes_reg[node] = 0;
449+
state->reg_used[0] = node;
450+
}
451+
break;
452+
case MVM_JIT_WHEN:
453+
state->cond_depth -= 1;
454+
break;
455+
case MVM_JIT_IF:
456+
/* result value lives in rax (as assured in compile_expr_labels) */
457+
state->reg_used[0] = node;
458+
state->nodes_reg[node] = 0;
459+
state->cond_depth -= 2;
460+
break;
461+
case MVM_JIT_DO:
462+
{
463+
MVMint32 nchild = tree->nodes[node+1];
464+
MVMint32 last_child = tree->nodes[node+1+nchild];
465+
MVMint32 regnum = get_next_register(tc, state, regs, 0);
466+
load_node_to(tc, state, last_child, regnum);
467+
state->reg_used[regnum] = node;
468+
state->nodes_reg[node] = regnum;
469+
break;
470+
}
471+
case MVM_JIT_ALL:
472+
{
473+
MVMint32 nchild = tree->nodes[node+1];
474+
MVMint32 first_child = tree->nodes[node+2];
475+
MVMint32 result, value, i;
476+
if (nchild == 0)
477+
MVM_oops(tc, "No child for ALL is clearly an error");
478+
result = get_next_register(tc, state, regs, 0);
479+
regs[0] = FREE_REGISTERS[result];
480+
value = get_next_register(tc, state, regs, 1);
481+
/* load first child */
482+
load_node_to(tc, state, first_child, result);
483+
for (i = 1; i < nchild; i++) {
484+
MVMint32 child = tree->nodes[node+i+2];
485+
load_node_to(tc, state, child, value);
486+
MVM_jit_log(tc, "and %s, %s\n", X64_REGISTER_NAMES[FREE_REGISTERS[result]],
487+
X64_REGISTER_NAMES[FREE_REGISTERS[value]]);
488+
}
489+
state->nodes_reg[node] = result;
490+
state->reg_used[result] = node;
491+
break;
492+
}
493+
case MVM_JIT_CARG:
494+
case MVM_JIT_ARGLIST:
495+
break;
496+
default:
497+
load_op_regs(tc, state, tree, node, regs);
498+
emit_expr_op(tc, state, op, regs, args);
499+
break;
500+
}
501+
287502
}
288503

289504

@@ -303,12 +518,15 @@ void MVM_jit_compile_expr_tree(MVMThreadContext *tc, MVMJitGraph *jg, MVMJitExpr
303518
state.nodes_spill = MVM_malloc(sizeof(MVMint32)*tree->nodes_num);
304519
state.last_register = 0;
305520
state.last_spill = 0;
521+
state.cond_depth = 0;
306522
memset(state.nodes_reg, -1, sizeof(MVMint32)*tree->nodes_num);
307523
memset(state.nodes_spill, -1, sizeof(MVMint32)*tree->nodes_num);
308524

309525
/* initialize compiler */
310526
memset(&compiler, 0, sizeof(MVMJitTreeTraverser));
311527
compiler.data = &state;
528+
compiler.preorder = &prepare_expr_op;
529+
compiler.inorder = &compile_expr_labels;
312530
compiler.postorder = &compile_expr_op;
313531

314532
MVM_jit_expr_tree_traverse(tc, tree, &compiler);

src/jit/expr.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,12 +53,12 @@ enum MVMJitExprVtype { /* value type */
5353
_(ANY, -1, 0, FLAG), \
5454
/* control operators */ \
5555
_(DO, -1, 0, REG), \
56-
_(IF, 2, 0, VOID), \
57-
_(IFELSE, 3, 0, REG), \
56+
_(WHEN, 2, 0, VOID), \
57+
_(IF, 3, 0, REG), \
5858
/* call c functions */ \
5959
_(CALL, 2, 1, REG), \
6060
_(ARGLIST, -1, 0, VOID), \
61-
_(CARG, 1, 1, REG), \
61+
_(CARG, 1, 1, VOID), \
6262
/* interpreter special variables */ \
6363
_(TC, 0, 0, REG), \
6464
_(CU, 0, 0, MEM), \

0 commit comments

Comments
 (0)