Skip to content

Commit

Permalink
Fix segfault caused by deopt all in NativeCall callback
Browse files Browse the repository at this point in the history
A callback from a native function starts a new runloop. When a deopt all
happens in that runloop, the outer runloop did not jump back to the unoptimized
bytecode. An sp_getspeshslot op would then cause a segfault because
tc->cur_frame->effective_spesh_slots was NULLed by the deopt.

Fix by propagating the change to interp_cur_op, interp_bytecode_start and
others back to the outer runloop.

For JIT compiled code, we have to set up the trampoline (record current
address) before doing the native call and add instructions for exiting the JIT
compiled code if we deopted during the native call.
  • Loading branch information
niner committed Dec 14, 2019
1 parent c799de4 commit d7b6855
Show file tree
Hide file tree
Showing 13 changed files with 87 additions and 39 deletions.
6 changes: 3 additions & 3 deletions src/core/frame.c
Expand Up @@ -913,7 +913,7 @@ static MVMuint64 remove_one_frame(MVMThreadContext *tc, MVMuint8 unwind) {
}

/* Switch back to the caller frame if there is one. */
if (caller && returner != tc->thread_entry_frame) {
if (caller && (returner != tc->thread_entry_frame || tc->nested_interpreter)) {

if (tc->jit_return_address != NULL) {
/* on a JIT frame, exit to interpreter afterwards */
Expand All @@ -924,7 +924,7 @@ static MVMuint64 remove_one_frame(MVMThreadContext *tc, MVMuint8 unwind) {
tc->jit_return_address = NULL;
}

tc->cur_frame = caller;
tc->cur_frame = (caller && returner != tc->thread_entry_frame) ? caller : NULL;
tc->current_frame_nr = caller->sequence_nr;

*(tc->interp_cur_op) = caller->return_address;
Expand All @@ -950,7 +950,7 @@ static MVMuint64 remove_one_frame(MVMThreadContext *tc, MVMuint8 unwind) {
}
}

return 1;
return (returner != tc->thread_entry_frame) ? 1 : 0;
}
else {
tc->cur_frame = NULL;
Expand Down
33 changes: 29 additions & 4 deletions src/core/interp.c
Expand Up @@ -138,6 +138,19 @@ void MVM_interp_run(MVMThreadContext *tc, void (*initial_invoke)(MVMThreadContex
/* Stash addresses of current op, register base and SC deref base
* in the TC; this will be used by anything that needs to switch
* the current place we're interpreting. */

MVMuint8 **backup_interp_cur_op = NULL;
MVMuint8 **backup_interp_bytecode_start = NULL;
MVMRegister **backup_interp_reg_base = NULL;
MVMCompUnit **backup_interp_cu = NULL;

if (tc->nested_interpreter) {
backup_interp_cur_op = tc->interp_cur_op;
backup_interp_bytecode_start = tc->interp_bytecode_start;
backup_interp_reg_base = tc->interp_reg_base;
backup_interp_cu = tc->interp_cu;
}

tc->interp_cur_op = &cur_op;
tc->interp_bytecode_start = &bytecode_start;
tc->interp_reg_base = &reg_base;
Expand Down Expand Up @@ -6663,10 +6676,22 @@ void MVM_interp_run(MVMThreadContext *tc, void (*initial_invoke)(MVMThreadContex
return_label:
/* Need to clear these pointer pointers since they may be rooted
* by some GC procedure. */
tc->interp_cur_op = NULL;
tc->interp_bytecode_start = NULL;
tc->interp_reg_base = NULL;
tc->interp_cu = NULL;
if (tc->nested_interpreter) {
*backup_interp_cur_op = *tc->interp_cur_op;
*backup_interp_bytecode_start = *tc->interp_bytecode_start;
*backup_interp_reg_base = *tc->interp_reg_base;
*backup_interp_cu = *tc->interp_cu;
tc->interp_cur_op = backup_interp_cur_op;
tc->interp_bytecode_start = backup_interp_bytecode_start;
tc->interp_reg_base = backup_interp_reg_base;
tc->interp_cu = backup_interp_cu;
}
else {
tc->interp_cur_op = NULL;
tc->interp_bytecode_start = NULL;
tc->interp_reg_base = NULL;
tc->interp_cu = NULL;
}
MVM_barrier();
}

Expand Down
2 changes: 2 additions & 0 deletions src/core/nativecall.c
Expand Up @@ -1186,5 +1186,7 @@ void MVM_nativecall_invoke_jit(MVMThreadContext *tc, MVMObject *site) {
MVMNativeCallBody *body = MVM_nativecall_get_nc_body(tc, site);
MVMJitCode * const jitcode = body->jitcode;

tc->cur_frame->return_address = *tc->interp_cur_op;

jitcode->func_ptr(tc, *tc->interp_cu, jitcode->labels[0]);
}
23 changes: 9 additions & 14 deletions src/core/nativecall_dyncall.c
Expand Up @@ -212,7 +212,7 @@ static void callback_invoke(MVMThreadContext *tc, void *data) {
static char callback_handler(DCCallback *cb, DCArgs *cb_args, DCValue *cb_result, MVMNativeCallback *data) {
CallbackInvokeData cid;
MVMint32 num_roots, i;
MVMRegister res;
MVMRegister res = {0};
MVMRegister *args;
unsigned int interval_id;

Expand Down Expand Up @@ -329,16 +329,12 @@ static char callback_handler(DCCallback *cb, DCArgs *cb_args, DCValue *cb_result
cid.args = args;
cid.cs = data->cs;
{
MVMuint8 **backup_interp_cur_op = tc->interp_cur_op;
MVMuint8 **backup_interp_bytecode_start = tc->interp_bytecode_start;
MVMRegister **backup_interp_reg_base = tc->interp_reg_base;
MVMCompUnit **backup_interp_cu = tc->interp_cu;

MVMFrame *backup_cur_frame = MVM_frame_force_to_heap(tc, tc->cur_frame);
MVMFrame *backup_thread_entry_frame = tc->thread_entry_frame;
void **backup_jit_return_address = tc->jit_return_address;
tc->jit_return_address = NULL;
MVMROOT2(tc, backup_cur_frame, backup_thread_entry_frame, {
MVM_gc_root_temp_push(tc, (MVMCollectable **)&backup_cur_frame);
MVM_gc_root_temp_push(tc, (MVMCollectable **)&backup_thread_entry_frame);
MVMuint32 backup_mark = MVM_gc_root_temp_mark(tc);
jmp_buf backup_interp_jump;
memcpy(backup_interp_jump, tc->interp_jump, sizeof(jmp_buf));
Expand All @@ -347,20 +343,17 @@ static char callback_handler(DCCallback *cb, DCArgs *cb_args, DCValue *cb_result
tc->cur_frame->return_value = &res;
tc->cur_frame->return_type = MVM_RETURN_OBJ;

tc->nested_interpreter++;
MVM_interp_run(tc, callback_invoke, &cid);
tc->nested_interpreter--;

tc->interp_cur_op = backup_interp_cur_op;
tc->interp_bytecode_start = backup_interp_bytecode_start;
tc->interp_reg_base = backup_interp_reg_base;
tc->interp_cu = backup_interp_cu;
tc->cur_frame = backup_cur_frame;
tc->current_frame_nr = backup_cur_frame->sequence_nr;
tc->thread_entry_frame = backup_thread_entry_frame;
tc->jit_return_address = backup_jit_return_address;
tc->thread_entry_frame = backup_thread_entry_frame;

memcpy(tc->interp_jump, backup_interp_jump, sizeof(jmp_buf));
MVM_gc_root_temp_mark_reset(tc, backup_mark);
});
MVM_gc_root_temp_pop_n(tc, 2);
}

/* Handle return value. */
Expand Down Expand Up @@ -504,6 +497,8 @@ MVMObject * MVM_nativecall_invoke(MVMThreadContext *tc, MVMObject *res_type,
unsigned int interval_id;
DCCallVM *vm;

tc->cur_frame->return_address = *tc->interp_cur_op;

/* Create and set up call VM. */
vm = dcNewCallVM(8192);
dcMode(vm, body->convention);
Expand Down
12 changes: 6 additions & 6 deletions src/core/oplist
Expand Up @@ -623,7 +623,7 @@ getuniprop_bool w(int64) r(int64) r(int64) :pure
getuniprop_str w(str) r(int64) r(int64) :pure
matchuniprop w(int64) r(int64) r(int64) r(int64) :pure
nativecallbuild w(int64) r(obj) r(str) r(str) r(str) r(obj) r(obj)
nativecallinvoke w(obj) r(obj) r(obj) r(obj)
nativecallinvoke w(obj) r(obj) r(obj) r(obj) :deoptallpoint :maycausedeopt
nativecallrefresh r(obj)
threadrun r(obj)
threadid w(int64) r(obj) :pure
Expand Down Expand Up @@ -838,11 +838,11 @@ atomicstore_o r(obj) r(obj) :invokish :maycausedeopt :specializable
atomicstore_i r(obj) r(int64) :specializable
barrierfull
coveragecontrol r(int64)
nativeinvoke_v -a r(obj) r(obj) :specializable
nativeinvoke_i -a w(int64) r(obj) r(obj) :specializable
nativeinvoke_n -a w(num64) r(obj) r(obj) :specializable
nativeinvoke_s -a w(str) r(obj) r(obj) :specializable
nativeinvoke_o -a w(obj) r(obj) r(obj) :specializable
nativeinvoke_v -a r(obj) r(obj) :specializable :deoptallpoint :maycausedeopt
nativeinvoke_i -a w(int64) r(obj) r(obj) :specializable :deoptallpoint :maycausedeopt
nativeinvoke_n -a w(num64) r(obj) r(obj) :specializable :deoptallpoint :maycausedeopt
nativeinvoke_s -a w(str) r(obj) r(obj) :specializable :deoptallpoint :maycausedeopt
nativeinvoke_o -a w(obj) r(obj) r(obj) :specializable :deoptallpoint :maycausedeopt
getarg_i w(int64) r(int16)
getarg_n w(num64) r(int16)
getarg_s w(str) r(int16)
Expand Down
24 changes: 12 additions & 12 deletions src/core/ops.c
Expand Up @@ -7916,8 +7916,8 @@ static const MVMOpInfo MVM_op_infos[] = {
"nativecallinvoke",
4,
0,
0,
0,
2,
1,
0,
0,
0,
Expand Down Expand Up @@ -10898,8 +10898,8 @@ static const MVMOpInfo MVM_op_infos[] = {
"nativeinvoke_v",
2,
0,
0,
0,
2,
1,
0,
0,
0,
Expand All @@ -10912,8 +10912,8 @@ static const MVMOpInfo MVM_op_infos[] = {
"nativeinvoke_i",
3,
0,
0,
0,
2,
1,
0,
0,
0,
Expand All @@ -10926,8 +10926,8 @@ static const MVMOpInfo MVM_op_infos[] = {
"nativeinvoke_n",
3,
0,
0,
0,
2,
1,
0,
0,
0,
Expand All @@ -10940,8 +10940,8 @@ static const MVMOpInfo MVM_op_infos[] = {
"nativeinvoke_s",
3,
0,
0,
0,
2,
1,
0,
0,
0,
Expand All @@ -10954,8 +10954,8 @@ static const MVMOpInfo MVM_op_infos[] = {
"nativeinvoke_o",
3,
0,
0,
0,
2,
1,
0,
0,
0,
Expand Down
2 changes: 2 additions & 0 deletions src/core/threadcontext.h
Expand Up @@ -333,6 +333,8 @@ struct MVMThreadContext {

MVMuint32 cur_file_idx;
MVMuint32 cur_line_no;

int nested_interpreter;
};

MVMThreadContext * MVM_tc_create(MVMThreadContext *parent, MVMInstance *instance);
Expand Down
3 changes: 3 additions & 0 deletions src/jit/compile.c
Expand Up @@ -109,6 +109,9 @@ MVMJitCode * MVM_jit_compile_graph(MVMThreadContext *tc, MVMJitGraph *jg) {
case MVM_JIT_NODE_DATA:
MVM_jit_emit_data(tc, &cl, &node->u.data);
break;
case MVM_JIT_NODE_DEOPT_CHECK:
MVM_jit_emit_deopt_check(tc, &cl);
break;
}
node = node->next;
}
Expand Down
9 changes: 9 additions & 0 deletions src/jit/graph.c
Expand Up @@ -16,6 +16,12 @@ static void jg_append_node(MVMJitGraph *jg, MVMJitNode *node) {
node->next = NULL;
}

static void jg_append_deopt_check(MVMThreadContext *tc, MVMJitGraph *jg) {
MVMJitNode *node = MVM_spesh_alloc(tc, jg->sg, sizeof(MVMJitNode));
node->type = MVM_JIT_NODE_DEOPT_CHECK;
jg_append_node(jg, node);
}

static void jg_append_primitive(MVMThreadContext *tc, MVMJitGraph *jg,
MVMSpeshIns * ins) {
MVMJitNode * node = MVM_spesh_alloc(tc, jg->sg, sizeof(MVMJitNode));
Expand Down Expand Up @@ -3514,12 +3520,15 @@ static MVMint32 consume_ins(MVMThreadContext *tc, MVMJitGraph *jg,
MVMint16 restype = ins->operands[1].reg.orig;
MVMint16 site = ins->operands[2].reg.orig;
MVMint16 cargs = ins->operands[3].reg.orig;
MVMJitCallArg targs[] = { { MVM_JIT_INTERP_VAR, { MVM_JIT_INTERP_TC } } };
jg_append_call_c(tc, jg, MVM_jit_code_trampoline, 1, targs, MVM_JIT_RV_VOID, -1);
MVMJitCallArg args[] = { { MVM_JIT_INTERP_VAR, { MVM_JIT_INTERP_TC } },
{ MVM_JIT_REG_VAL, { restype } },
{ MVM_JIT_REG_VAL, { site } },
{ MVM_JIT_REG_VAL, { cargs } } };
jg_append_call_c(tc, jg, op_to_func(tc, op), 4, args,
MVM_JIT_RV_PTR, dst);
jg_append_deopt_check(tc, jg);
break;
}
case MVM_OP_typeparameters:
Expand Down
1 change: 1 addition & 0 deletions src/jit/graph.h
Expand Up @@ -221,6 +221,7 @@ typedef enum {
MVM_JIT_NODE_CONTROL,
MVM_JIT_NODE_DATA,
MVM_JIT_NODE_EXPR_TREE,
MVM_JIT_NODE_DEOPT_CHECK,
} MVMJitNodeType;

struct MVMJitNode {
Expand Down
1 change: 1 addition & 0 deletions src/jit/interface.c
Expand Up @@ -39,6 +39,7 @@ void * MVM_jit_code_get_current_position(MVMThreadContext *tc, MVMJitCode *code,
void MVM_jit_code_set_current_position(MVMThreadContext *tc, MVMJitCode *code, MVMFrame *frame, void *position) {
assert_within_region(tc, code, position);
if (tc->cur_frame == frame && tc->jit_return_address != NULL) {
/* this overwrites the address on the stack that MVM_frame_invoke_code will ret to! */
*tc->jit_return_address = position;
} else {
frame->jit_entry_label = position;
Expand Down
1 change: 1 addition & 0 deletions src/jit/internal.h
Expand Up @@ -57,6 +57,7 @@ void MVM_jit_emit_store(MVMThreadContext *tc, MVMJitCompiler *compiler,
void MVM_jit_emit_copy(MVMThreadContext *tc, MVMJitCompiler *compiler,
MVMint8 dst_reg, MVMint8 src_num);
void MVM_jit_emit_marker(MVMThreadContext *tc, MVMJitCompiler *compiler, MVMint32 num);
void MVM_jit_emit_deopt_check(MVMThreadContext *tc, MVMJitCompiler *compiler);

MVMuint32 MVM_jit_spill_memory_select(MVMThreadContext *tc, MVMJitCompiler *compiler, MVMint8 reg_type);
void MVM_jit_spill_memory_release(MVMThreadContext *tc, MVMJitCompiler *compiler, MVMuint32 pos, MVMint8 reg_type);
Expand Down
9 changes: 9 additions & 0 deletions src/jit/x64/emit.dasc
Expand Up @@ -3519,5 +3519,14 @@ void MVM_jit_emit_data(MVMThreadContext *tc, MVMJitCompiler *compiler, MVMJitDat
|.code
}

void MVM_jit_emit_deopt_check(MVMThreadContext *tc, MVMJitCompiler *compiler) {
| mov TMP6, TC->cur_frame;
| mov TMP6, FRAME:TMP6->spesh_cand
| test TMP6, TMP6
| jnz >1
| jmp ->exit
|1:
}

/* import tiles */
|.include src/jit/x64/tiles.dasc

0 comments on commit d7b6855

Please sign in to comment.