Skip to content

Commit

Permalink
pythonGH-111848: Set the IP when de-optimizing (pythonGH-112065)
Browse files Browse the repository at this point in the history
* Replace jumps with deopts in tier 2

* Fewer special cases of uop names

* Add target field to uop IR

* Remove more redundant SET_IP and _CHECK_VALIDITY micro-ops

* Extend whitelist of non-escaping API functions.
  • Loading branch information
markshannon authored and aisk committed Feb 11, 2024
1 parent 3611cf6 commit 44f3d21
Show file tree
Hide file tree
Showing 6 changed files with 118 additions and 98 deletions.
104 changes: 52 additions & 52 deletions Include/internal/pycore_opcode_metadata.h

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions Include/internal/pycore_uops.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@ extern "C" {
#define _Py_UOP_MAX_TRACE_LENGTH 128

typedef struct {
uint32_t opcode;
uint32_t oparg;
uint16_t opcode;
uint16_t oparg;
uint32_t target;
uint64_t operand; // A cache entry
} _PyUOpInstruction;

Expand Down
2 changes: 2 additions & 0 deletions Python/ceval.c
Original file line number Diff line number Diff line change
Expand Up @@ -1067,6 +1067,7 @@ _PyEval_EvalFrameDefault(PyThreadState *tstate, _PyInterpreterFrame *frame, int
UOP_STAT_INC(opcode, miss);
frame->return_offset = 0; // Dispatch to frame->instr_ptr
_PyFrame_SetStackPointer(frame, stack_pointer);
frame->instr_ptr = next_uop[-1].target + _PyCode_CODE((PyCodeObject *)frame->f_executable);
Py_DECREF(current_executor);
// Fall through
// Jump here from ENTER_EXECUTOR
Expand All @@ -1077,6 +1078,7 @@ _PyEval_EvalFrameDefault(PyThreadState *tstate, _PyInterpreterFrame *frame, int
// Jump here from _EXIT_TRACE
exit_trace:
_PyFrame_SetStackPointer(frame, stack_pointer);
frame->instr_ptr = next_uop[-1].target + _PyCode_CODE((PyCodeObject *)frame->f_executable);
Py_DECREF(current_executor);
OPT_HIST(trace_uop_execution_counter, trace_run_length_hist);
goto enter_tier_one;
Expand Down
48 changes: 22 additions & 26 deletions Python/optimizer.c
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,8 @@ translate_bytecode_to_trace(
#define DPRINTF(level, ...)
#endif

#define ADD_TO_TRACE(OPCODE, OPARG, OPERAND) \

#define ADD_TO_TRACE(OPCODE, OPARG, OPERAND, TARGET) \
DPRINTF(2, \
" ADD_TO_TRACE(%s, %d, %" PRIu64 ")\n", \
uop_name(OPCODE), \
Expand All @@ -458,23 +459,12 @@ translate_bytecode_to_trace(
trace[trace_length].opcode = (OPCODE); \
trace[trace_length].oparg = (OPARG); \
trace[trace_length].operand = (OPERAND); \
trace[trace_length].target = (TARGET); \
trace_length++;

#define INSTR_IP(INSTR, CODE) \
((uint32_t)((INSTR) - ((_Py_CODEUNIT *)(CODE)->co_code_adaptive)))

#define ADD_TO_STUB(INDEX, OPCODE, OPARG, OPERAND) \
DPRINTF(2, " ADD_TO_STUB(%d, %s, %d, %" PRIu64 ")\n", \
(INDEX), \
uop_name(OPCODE), \
(OPARG), \
(uint64_t)(OPERAND)); \
assert(reserved > 0); \
reserved--; \
trace[(INDEX)].opcode = (OPCODE); \
trace[(INDEX)].oparg = (OPARG); \
trace[(INDEX)].operand = (OPERAND);

// Reserve space for n uops
#define RESERVE_RAW(n, opname) \
if (trace_length + (n) > max_length) { \
Expand All @@ -483,7 +473,7 @@ translate_bytecode_to_trace(
OPT_STAT_INC(trace_too_long); \
goto done; \
} \
reserved = (n); // Keep ADD_TO_TRACE / ADD_TO_STUB honest
reserved = (n); // Keep ADD_TO_TRACE honest

// Reserve space for main+stub uops, plus 3 for _SET_IP, _CHECK_VALIDITY and _EXIT_TRACE
#define RESERVE(main, stub) RESERVE_RAW((main) + (stub) + 3, uop_name(opcode))
Expand All @@ -493,7 +483,7 @@ translate_bytecode_to_trace(
if (trace_stack_depth >= TRACE_STACK_SIZE) { \
DPRINTF(2, "Trace stack overflow\n"); \
OPT_STAT_INC(trace_stack_overflow); \
ADD_TO_TRACE(_SET_IP, 0, 0); \
ADD_TO_TRACE(_EXIT_TRACE, 0, 0, 0); \
goto done; \
} \
trace_stack[trace_stack_depth].code = code; \
Expand All @@ -513,22 +503,28 @@ translate_bytecode_to_trace(
PyUnicode_AsUTF8(code->co_filename),
code->co_firstlineno,
2 * INSTR_IP(initial_instr, code));

uint32_t target = 0;
top: // Jump here after _PUSH_FRAME or likely branches
for (;;) {
target = INSTR_IP(instr, code);
RESERVE_RAW(3, "epilogue"); // Always need space for _SET_IP, _CHECK_VALIDITY and _EXIT_TRACE
ADD_TO_TRACE(_SET_IP, INSTR_IP(instr, code), 0);
ADD_TO_TRACE(_CHECK_VALIDITY, 0, 0);
ADD_TO_TRACE(_SET_IP, target, 0, target);
ADD_TO_TRACE(_CHECK_VALIDITY, 0, 0, target);

uint32_t opcode = instr->op.code;
uint32_t oparg = instr->op.arg;
uint32_t extras = 0;

while (opcode == EXTENDED_ARG) {

if (opcode == EXTENDED_ARG) {
instr++;
extras += 1;
opcode = instr->op.code;
oparg = (oparg << 8) | instr->op.arg;
if (opcode == EXTENDED_ARG) {
instr--;
goto done;
}
}

if (opcode == ENTER_EXECUTOR) {
Expand All @@ -554,7 +550,7 @@ translate_bytecode_to_trace(
DPRINTF(4, "%s(%d): counter=%x, bitcount=%d, likely=%d, uopcode=%s\n",
uop_name(opcode), oparg,
counter, bitcount, jump_likely, uop_name(uopcode));
ADD_TO_TRACE(uopcode, max_length, 0);
ADD_TO_TRACE(uopcode, max_length, 0, target);
if (jump_likely) {
_Py_CODEUNIT *target_instr = next_instr + oparg;
DPRINTF(2, "Jump likely (%x = %d bits), continue at byte offset %d\n",
Expand All @@ -569,7 +565,7 @@ translate_bytecode_to_trace(
{
if (instr + 2 - oparg == initial_instr && code == initial_code) {
RESERVE(1, 0);
ADD_TO_TRACE(_JUMP_TO_TOP, 0, 0);
ADD_TO_TRACE(_JUMP_TO_TOP, 0, 0, 0);
}
else {
OPT_STAT_INC(inner_loop);
Expand Down Expand Up @@ -653,7 +649,7 @@ translate_bytecode_to_trace(
expansion->uops[i].offset);
Py_FatalError("garbled expansion");
}
ADD_TO_TRACE(uop, oparg, operand);
ADD_TO_TRACE(uop, oparg, operand, target);
if (uop == _POP_FRAME) {
TRACE_STACK_POP();
DPRINTF(2,
Expand Down Expand Up @@ -682,15 +678,15 @@ translate_bytecode_to_trace(
PyUnicode_AsUTF8(new_code->co_filename),
new_code->co_firstlineno);
OPT_STAT_INC(recursive_call);
ADD_TO_TRACE(_SET_IP, 0, 0);
ADD_TO_TRACE(_EXIT_TRACE, 0, 0, 0);
goto done;
}
if (new_code->co_version != func_version) {
// func.__code__ was updated.
// Perhaps it may happen again, so don't bother tracing.
// TODO: Reason about this -- is it better to bail or not?
DPRINTF(2, "Bailing because co_version != func_version\n");
ADD_TO_TRACE(_SET_IP, 0, 0);
ADD_TO_TRACE(_EXIT_TRACE, 0, 0, 0);
goto done;
}
// Increment IP to the return address
Expand All @@ -707,7 +703,7 @@ translate_bytecode_to_trace(
2 * INSTR_IP(instr, code));
goto top;
}
ADD_TO_TRACE(_SET_IP, 0, 0);
ADD_TO_TRACE(_EXIT_TRACE, 0, 0, 0);
goto done;
}
}
Expand All @@ -732,7 +728,7 @@ translate_bytecode_to_trace(
assert(code == initial_code);
// Skip short traces like _SET_IP, LOAD_FAST, _SET_IP, _EXIT_TRACE
if (trace_length > 4) {
ADD_TO_TRACE(_EXIT_TRACE, 0, 0);
ADD_TO_TRACE(_EXIT_TRACE, 0, 0, target);
DPRINTF(1,
"Created a trace for %s (%s:%d) at byte offset %d -- length %d+%d\n",
PyUnicode_AsUTF8(code->co_qualname),
Expand Down
22 changes: 10 additions & 12 deletions Python/optimizer_analysis.c
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,15 @@ remove_unneeded_uops(_PyUOpInstruction *buffer, int buffer_size)
{
// Note that we don't enter stubs, those SET_IPs are needed.
int last_set_ip = -1;
bool need_ip = true;
bool maybe_invalid = false;
for (int pc = 0; pc < buffer_size; pc++) {
int opcode = buffer[pc].opcode;
if (opcode == _SET_IP) {
if (!need_ip && last_set_ip >= 0) {
buffer[last_set_ip].opcode = NOP;
}
need_ip = false;
buffer[pc].opcode = NOP;
last_set_ip = pc;
}
else if (opcode == _CHECK_VALIDITY) {
if (maybe_invalid) {
/* Exiting the trace requires that IP is correct */
need_ip = true;
maybe_invalid = false;
}
else {
Expand All @@ -42,12 +36,16 @@ remove_unneeded_uops(_PyUOpInstruction *buffer, int buffer_size)
break;
}
else {
// If opcode has ERROR or DEOPT, set need_ip to true
if (_PyOpcode_opcode_metadata[opcode].flags & (HAS_ERROR_FLAG | HAS_DEOPT_FLAG) || opcode == _PUSH_FRAME) {
need_ip = true;
}
if (_PyOpcode_opcode_metadata[opcode].flags & HAS_ESCAPES_FLAG) {
if (OPCODE_HAS_ESCAPES(opcode)) {
maybe_invalid = true;
if (last_set_ip >= 0) {
buffer[last_set_ip].opcode = _SET_IP;
}
}
if (OPCODE_HAS_ERROR(opcode) || opcode == _PUSH_FRAME) {
if (last_set_ip >= 0) {
buffer[last_set_ip].opcode = _SET_IP;
}
}
}
}
Expand Down
35 changes: 29 additions & 6 deletions Tools/cases_generator/flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import parsing
from typing import AbstractSet

WHITELIST = (
NON_ESCAPING_FUNCTIONS = (
"Py_INCREF",
"_PyDictOrValues_IsValues",
"_PyObject_DictOrValuesPointer",
Expand All @@ -31,9 +31,29 @@
"_PyLong_IsNonNegativeCompact",
"_PyLong_CompactValue",
"_Py_NewRef",
"_Py_IsImmortal",
"_Py_STR",
"_PyLong_Add",
"_PyLong_Multiply",
"_PyLong_Subtract",
"Py_NewRef",
"_PyList_ITEMS",
"_PyTuple_ITEMS",
"_PyList_AppendTakeRef",
"_Py_atomic_load_uintptr_relaxed",
"_PyFrame_GetCode",
"_PyThreadState_HasStackSpace",
)

def makes_escaping_api_call(instr: parsing.Node) -> bool:
ESCAPING_FUNCTIONS = (
"import_name",
"import_from",
)


def makes_escaping_api_call(instr: parsing.InstDef) -> bool:
if "CALL_INTRINSIC" in instr.name:
return True;
tkns = iter(instr.tokens)
for tkn in tkns:
if tkn.kind != lx.IDENTIFIER:
Expand All @@ -44,13 +64,17 @@ def makes_escaping_api_call(instr: parsing.Node) -> bool:
return False
if next_tkn.kind != lx.LPAREN:
continue
if tkn.text in ESCAPING_FUNCTIONS:
return True
if not tkn.text.startswith("Py") and not tkn.text.startswith("_Py"):
continue
if tkn.text.endswith("Check"):
continue
if tkn.text.startswith("Py_Is"):
continue
if tkn.text.endswith("CheckExact"):
continue
if tkn.text in WHITELIST:
if tkn.text in NON_ESCAPING_FUNCTIONS:
continue
return True
return False
Expand All @@ -74,7 +98,7 @@ def __post_init__(self) -> None:
self.bitmask = {name: (1 << i) for i, name in enumerate(self.names())}

@staticmethod
def fromInstruction(instr: parsing.Node) -> "InstructionFlags":
def fromInstruction(instr: parsing.InstDef) -> "InstructionFlags":
has_free = (
variable_used(instr, "PyCell_New")
or variable_used(instr, "PyCell_GET")
Expand All @@ -101,8 +125,7 @@ def fromInstruction(instr: parsing.Node) -> "InstructionFlags":
or variable_used(instr, "resume_with_error")
),
HAS_ESCAPES_FLAG=(
variable_used(instr, "tstate")
or makes_escaping_api_call(instr)
makes_escaping_api_call(instr)
),
)

Expand Down

0 comments on commit 44f3d21

Please sign in to comment.