diff --git a/docs/source/reference/thunder.rst b/docs/source/reference/thunder.rst index 706f9d6f0..fea68c5d7 100644 --- a/docs/source/reference/thunder.rst +++ b/docs/source/reference/thunder.rst @@ -32,7 +32,7 @@ Querying information on compiled functions and modules cache_misses list_transforms last_interpreted_instructions - last_interpreted_history + last_interpreter_log last_compile_options .. compile diff --git a/thunder/__init__.py b/thunder/__init__.py index 73c295d64..c54bdda0a 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -12,6 +12,7 @@ from looseversion import LooseVersion +from thunder.core.interpreter import InterpreterLogItem from thunder.core.options import ( INTERPRETATION_OPTIONS, resolve_interpretation_option, @@ -20,6 +21,7 @@ SHARP_EDGES_OPTIONS, ) from thunder.core.trace import ( + TraceResults, TraceCtx, from_trace, set_tracectx, @@ -58,6 +60,7 @@ DictProxy, AnyProxy, ) +from thunder.core.interpreter import print_interpreter_log, print_to_log from thunder.core.jit_ext import thunder_general_jit from thunder.executors.torch_autograd import split_forward_backward, ThunderFunction from thunder.cudagraphs import CUDAGraphExecutor @@ -171,7 +174,7 @@ def __version__(): # Translates the Python function to a thunder program using the thunder interpreter -def _general_frontend(fn: Callable, args, kwargs, /, *, sharp_edges: SHARP_EDGES_OPTIONS) -> tuple[TraceCtx, TraceCtx]: +def _general_frontend(fn: Callable, args, kwargs, /, *, sharp_edges: SHARP_EDGES_OPTIONS) -> TraceResults: return thunder_general_jit(fn, args, kwargs, sharp_edges=sharp_edges) @@ -442,7 +445,7 @@ def get_computation_and_inputs(*args, **kwargs): cs.cache_hits += 1 cs.last_traces = comp_traces cs.last_interpreted_instructions = None - cs.last_interpreted_history = None + cs.last_interpreter_log = None cs.last_prologue_traces = pro_traces cs.last_prologue = pro cs.last_prologue_transformation_start = 0 @@ -481,7 +484,7 @@ def get_computation_and_inputs(*args, **kwargs): cs.cache_hits += 1 cs.last_traces = comp_traces cs.last_interpreted_instructions = None - cs.last_interpreted_history = None + cs.last_interpreter_log = None cs.last_prologue_traces = pro_traces cs.last_prologue = pro @@ -501,17 +504,15 @@ def get_computation_and_inputs(*args, **kwargs): with langctxs.langctx(cd.langctx): prologue_trc: TraceCtx computation_trc: TraceCtx - prologue_trc, computation_trc, *maybe_epilogue = interpreter( - fn, args, kwargs, sharp_edges=cd.sharp_edges - ) - - if maybe_epilogue: - epilogue_traces = maybe_epilogue - if epilogue_traces[-1] is not None: - epilogue = epilogue_traces[-1].python_callable() - else: - epilogue_traces = None - epilogue = None + jit_results: TraceResults = interpreter(fn, args, kwargs, sharp_edges=cd.sharp_edges) + prologue_trc = jit_results.prologue_trace + computation_trc = jit_results.computation_trace + epilogue_trc = jit_results.epilogue_trace + last_interpreter_log = jit_results.interpreter_log + + if epilogue_trc is not None: + epilogue_traces = [epilogue_trc] + epilogue = epilogue_trc.python_callable() else: epilogue_traces = None epilogue = None @@ -542,6 +543,8 @@ def get_computation_and_inputs(*args, **kwargs): cs.last_traces = computation_traces backward_traces = [] cs.last_backward_traces = backward_traces + cs.last_interpreter_log = last_interpreter_log + cs.last_interpreted_instructions = (i for i in last_interpreter_log if isinstance(i, dis.Instruction)) computation_trc = dce(computation_trc) computation_traces.append(computation_trc) @@ -787,28 +790,61 @@ def list_transforms(fn) -> list: return fn._lc_transforms -def last_interpreted_instructions(fn: Callable) -> list[dis.Instruction]: - """Returns the list of instructions the interpreter encountered while tracing through the +def last_interpreter_log(fn: Callable) -> list[InterpreterLogItem]: + """Returns the list of instructions and other information the interpreter encountered while tracing through the user program (on the last cache miss). """ cs = compile_stats(fn) if cs is None: raise TypeError(f"{fn} doesn't seem to be a thunder compiled function.") - if cs.last_interpreted_instructions is None: + if cs.last_interpreter_log is None: raise TypeError(f"{fn} doesn't seem to have been called yet.") - return cs.last_interpreted_instructions + return cs.last_interpreter_log -def last_interpreted_history(fn: Callable) -> list[dis.Instruction | str]: - """Returns the list of instructions and other information the interpreter encountered while tracing through the +def last_interpreted_instructions(fn: Callable) -> list[dis.Instruction]: + """Returns the list of instructions the interpreter encountered while tracing through the user program (on the last cache miss). """ cs = compile_stats(fn) if cs is None: raise TypeError(f"{fn} doesn't seem to be a thunder compiled function.") - if cs.last_interpreted_history is None: + if cs.last_interpreted_instructions is None: raise TypeError(f"{fn} doesn't seem to have been called yet.") - return cs.last_interpreted_history + return list(cs.last_interpreted_instructions) + + +def print_last_interpreter_log( + fn: Callable, + /, + print_fn: Callable = print, + use_colors: bool = True, + indent: bool = True, + max_depth: int | None = None, + color_internals: bool = False, + print_source_code: bool = True, +) -> None: + """Prints a log of the last run of the interpreter for the given function. + + Args: + fn: The function returned by `thunder.jit()` to print the last interpreter run log for. The function must have been called at least once first. + print_fn: The function to use for printing. Defaults to builtin `print`. + use_colors: Whether to use colors in the output. Defaults to `None`, which attempts to autodetect if the terminal supports ANSI color. + indent: Whether to indent the output with function scope. Defaults to `True`. + max_depth: The maximum indentation depth of the output. Doesn't print log items nested deeper than the max depth. Defaults to `None`, which means no limit. + color_internals: Whether to color instructions implicitly interpreted by other instructions. Defaults to `False`, so that only the instructions in the user's code are highlighted in color. + print_source_code: Whether to print the source line below each LineLogItem in the log. Defaults to `True`. + """ + log = last_interpreter_log(fn) + print_interpreter_log( + log, + print_fn=print_fn, + use_colors=use_colors, + indent=indent, + max_depth=max_depth, + color_internals=color_internals, + print_source_code=print_source_code, + ) def last_compile_options(fn: Callable, /) -> None: diff --git a/thunder/common.py b/thunder/common.py index ec0ed1d26..42db8c627 100644 --- a/thunder/common.py +++ b/thunder/common.py @@ -1,4 +1,6 @@ +import dis from typing import Any, Optional +from collections.abc import Generator from collections.abc import Callable from enum import Enum, auto from collections import deque, defaultdict @@ -58,8 +60,8 @@ def __init__(self): self.last_traces = None self.last_prologue = None self.last_prologue_traces = None - self.last_interpreted_instructions = None - self.last_interpreted_history = None + self.last_interpreted_instructions: Generator[dis.Instruction, None, None] | None = None + self.last_interpreter_log: list[InterpreterLogItem] | None = None # torch.autograd.Function specific data self.last_backward_traces = None @@ -466,7 +468,7 @@ def cache_get( # TODO Consider modeling additional calls to trace() # TODO RC1 Change the way this is called to be trace(langctx, inline_trace, rename_proxies...)(fn, *args, **kwargs) # to separate the traced function's args and kwargs from this function's kwargs -from thunder.core.interpreter import make_opaque +from thunder.core.interpreter import InterpreterLogItem, make_opaque def trace( diff --git a/thunder/core/interpreter.py b/thunder/core/interpreter.py index 5a88abff7..6cf18a241 100644 --- a/thunder/core/interpreter.py +++ b/thunder/core/interpreter.py @@ -451,44 +451,38 @@ def interpretercompilectx(_interpretercompilectx: InterpreterCompileCtx): reset_interpretercompilectx(tok) -class LineHistoryItem(TypedDict): +class LineLogItem(TypedDict): kind: Literal["Line"] fn: Callable | CodeType filename: str position: Positions | None -class OpaqueHistoryItem(TypedDict): +class OpaqueLogItem(TypedDict): kind: Literal["Opaque"] fn: Callable -class LookasideHistoryItem(TypedDict): +class LookasideLogItem(TypedDict): kind: Literal["Lookaside"] fn: Callable -class CallHistoryItem(TypedDict): +class CallLogItem(TypedDict): kind: Literal["InterpreterCall"] fn: Callable prev_frame: str -class ReturnHistoryItem(TypedDict): +class ReturnLogItem(TypedDict): kind: Literal["InterpreterReturn"] fn: Callable is_signal: bool rval: type | INTERPRETER_SIGNALS -InterpreterHistoryItem = ( - dis.Instruction - | str - | LineHistoryItem - | OpaqueHistoryItem - | LookasideHistoryItem - | CallHistoryItem - | ReturnHistoryItem +InterpreterLogItem = ( + dis.Instruction | str | LineLogItem | OpaqueLogItem | LookasideLogItem | CallLogItem | ReturnLogItem ) @@ -521,7 +515,7 @@ class InterpreterRuntimeCtx: def __init__(self, *, debug_log: None | StringIO = None): self.frame_stack: list[InterpreterFrame] = [] self._globals_dict: dict[str, Any] | None = None - self._history: list[InterpreterHistoryItem] = [] + self._interpreter_log: list[InterpreterLogItem] = [] self._interpreted_instructions: list[dis.Instruction] = [] self._curexc: BaseException | None = None # The exception_stack mirrors the exc_info/exc_state from PyThreadState @@ -574,14 +568,14 @@ def interpreted_instructions(self) -> list[dis.Instruction]: # The operations and opaque calls encountered while interpreting @property - def history(self) -> list[InterpreterHistoryItem]: - return self._history + def interp_log(self) -> list[InterpreterLogItem]: + return self._interpreter_log - def record(self, val: InterpreterHistoryItem, /) -> None: - self._history.append(val) + def record(self, val: InterpreterLogItem, /) -> None: + self._interpreter_log.append(val) if self.debug_log is not None: - self.debug_log.write(f"Appended to history: {val}" + os.linesep) + self.debug_log.write(f"Appended to log: {val}" + os.linesep) def peek_interpreter_stack(self) -> InterpreterStack: return self.frame_stack[-1].interpreter_stack @@ -607,8 +601,8 @@ def push_frame_stack(self, frame: InterpreterFrame): pf = self._pop_frame_stack() assert pf is frame, "Frame stack inconsistency" - # TODO Instead of appending to both history and and interpreted_instructions we could - # consider just appending to history and then filtering to only instructions when + # TODO Instead of appending to both the log and and interpreted_instructions we could + # consider just appending to the log and then filtering to only instructions when # interpreted_instructions is accessed def record_interpreted_instruction(self, inst: dis.Instruction, /) -> InterpreterRuntimeCtx: self._interpreted_instructions.append(inst) @@ -640,16 +634,16 @@ def record_interpreter_call(self, fn: Callable) -> InterpreterRuntimeCtx: def record_interpreter_return(self, fn: Callable, rval: Any | INTERPRETER_SIGNALS, /) -> InterpreterRuntimeCtx: is_signal: bool = isinstance(rval, INTERPRETER_SIGNALS) - rv: type | INTERPRETER_SIGNALS = rval if is_signal else type(rval) - self.record(ReturnHistoryItem(kind="InterpreterReturn", fn=fn, is_signal=is_signal, rval=rv)) + rv: type | INTERPRETER_SIGNALS = rval if is_signal else type(unwrap(rval)) + self.record(ReturnLogItem(kind="InterpreterReturn", fn=fn, is_signal=is_signal, rval=rv)) return self def record_opaque_call(self, fn: Callable) -> InterpreterRuntimeCtx: - self.record(OpaqueHistoryItem(kind="Opaque", fn=fn)) + self.record(OpaqueLogItem(kind="Opaque", fn=fn)) return self def record_lookaside(self, fn: Callable) -> InterpreterRuntimeCtx: - self.record(LookasideHistoryItem(kind="Lookaside", fn=fn)) + self.record(LookasideLogItem(kind="Lookaside", fn=fn)) return self def record_position( @@ -664,7 +658,7 @@ def record_position( self._prev_position = position self._prev_filename = filename - line = LineHistoryItem(kind="Line", fn=fn, filename=filename, position=position) + line = LineLogItem(kind="Line", fn=fn, filename=filename, position=position) self.record(line) return self @@ -672,14 +666,14 @@ def format_traceback(self): return os.linesep.join(f.format_with_source() for f in self.frame_stack) -def print_to_history(*objects, sep=" ", end=os.linesep): +def print_to_log(*objects, sep=" ", end=os.linesep): if sep is None: sep = " " if end is None: end = os.linesep ctx: InterpreterRuntimeCtx = get_interpreterruntimectx() - ctx._history.append(str(sep).join(str(o) for o in objects) + str(end)) + ctx._interpreter_log.append(str(sep).join(str(o) for o in objects) + str(end)) _interpreterruntimectx = contextvars.ContextVar("interpreterruntimectx") @@ -6660,7 +6654,7 @@ def fn_2(args, kwargs): interpretation_result: Any = _interpret_call(wrapped_fn_2, args, kwargs) interpretation_result = unwrap(interpretation_result) - except Exception as e: + except BaseException as e: # TODO Highlight the portion of the line that originated the opcode on Python versions that include # the line offset information in the instruction traceback_str = os.linesep.join(f.format_with_source() for f in runtimectx.frame_stack) @@ -6668,14 +6662,9 @@ def fn_2(args, kwargs): f"Encountered exception {type(e).__name__}: {e} while tracing {fn}:{os.linesep}" f"{traceback_str}" ) raise InterpreterError(msg) from e - finally: - # NOTE: Wrapped functions are valid to assign new attributes to. - fn_._last_interpreted_instructions = runtimectx.interpreted_instructions # type: ignore - fn_._last_interpreted_history = runtimectx.history # type: ignore - # # NOTE: Wrapped functions are valid to assign new attributes to. - # fn_._last_interpreted_instructions = runtimectx.interpreted_instructions # type: ignore - # fn_._last_interpreted_history = runtimectx.history # type: ignore + # NOTE: Wrapped functions are valid to assign new attributes to. + fn_._last_interpreter_log = runtimectx.interp_log # type: ignore if interpretation_result is INTERPRETER_SIGNALS.EXCEPTION_RAISED: e = runtimectx.curexc @@ -6689,19 +6678,19 @@ def fn_2(args, kwargs): return fn_ -def last_interpreted_instructions(fn: Callable) -> None | list[dis.Instruction]: - return getattr(fn, "_last_interpreted_instructions", None) +def last_interpreted_instructions(fn: Callable) -> list[dis.Instruction]: + return [i for i in getattr(fn, "_last_interpreter_log", ()) if isinstance(i, dis.Instruction)] -def last_interpreted_history(fn: Callable) -> None | list[InterpreterHistoryItem]: - return getattr(fn, "_last_interpreted_history", None) +def last_interpreter_log(fn: Callable) -> list[InterpreterLogItem]: + return getattr(fn, "_last_interpreter_log", []) -def print_history( - history: list[InterpreterHistoryItem], +def print_interpreter_log( + interpreter_log: list[InterpreterLogItem], /, print_fn: Callable = print, - use_colors: bool = True, + use_colors: bool | None = None, indent: bool = True, max_depth: int | None = None, color_internals: bool = False, @@ -6712,36 +6701,37 @@ def print_history( c_indent = -1 inside_inner_interpreter = False - for item in history: + for item in interpreter_log: linecolor = "" nl = "" deindent = False source_line = None - # Match each kind of history item. The history items are instructions, strings, + # Match each kind of log item. The log items are instructions, strings, # or typed dicts, with "kind" describing what kind of entry it is. match item: case dis.Instruction(): if color_internals or not inside_inner_interpreter: linecolor = colors["MAGENTA"] - history_line = f"Instruction('{item.opname}', arg={item.arg}, argrepr={repr(item.argrepr)})" + log_line = f"Instruction('{item.opname}', arg={item.arg}, argrepr={repr(item.argrepr)})" case str(): # Print the string as-is, indented, without colors. linecolor = colors["RESET"] - history_line = item + log_line = item case {"kind": "Line", "fn": _fn, "filename": filename, "position": position}: - # LineHistoryItem + # LineLogItem + _fn = unwrap(_fn) inside_inner_interpreter = interpreter_path in filename if color_internals or not inside_inner_interpreter: linecolor = colors["YELLOW"] nl = os.linesep fnname = extract_callable_name(_fn) if position: - history_line = f"# Line {filename}:{position.lineno} in {fnname}()" + log_line = f"# Line {filename}:{position.lineno} in {fnname}()" else: - history_line = f"# {filename} in {fnname}()" + log_line = f"# {filename} in {fnname}()" if not print_source_code or not position: continue @@ -6754,65 +6744,46 @@ def print_history( source_line = linestr case {"kind": "InterpreterCall", "fn": fn, "prev_frame": prev_frame}: - # CallHistoryItem + # CallLogItem + fn = unwrap(fn) if color_internals or not inside_inner_interpreter: linecolor = colors["GREEN"] c_indent += 1 - history_line = f"Interpreting call to {extract_callable_name(fn)}() from {prev_frame}{'()' if not prev_frame.endswith('>') else ''}" + log_line = f"Interpreting call to {extract_callable_name(fn)}() from {prev_frame}{'()' if not prev_frame.endswith('>') else ''}" case {"kind": "InterpreterReturn", "fn": fn, "is_signal": is_signal, "rval": rval}: - # ReturnHistoryItem + # ReturnLogItem + fn = unwrap(fn) + rval = unwrap(rval) if color_internals or not inside_inner_interpreter: linecolor = colors["RED"] deindent = True meaning = "signal" if is_signal else "value of type" val = rval if is_signal else rval.__qualname__ - history_line = f"Returning from call to {extract_callable_name(fn)}() with {meaning} {val}" + log_line = f"Returning from call to {extract_callable_name(fn)}() with {meaning} {val}" case {"kind": "Lookaside", "fn": fn}: - # LookasideHistoryItem + # LookasideLogItem + fn = unwrap(fn) if color_internals or not inside_inner_interpreter: linecolor = colors["BLUE"] - history_line = f"Lookaside to {extract_callable_name(fn)}()" + log_line = f"Lookaside to {extract_callable_name(fn)}()" case {"kind": "Opaque", "fn": fn}: - # OpaqueHistoryItem + # OpaqueLogItem + fn = unwrap(fn) if color_internals or not inside_inner_interpreter: linecolor = colors["CYAN"] - history_line = f"Opaque call to {fn} with name {extract_callable_name(fn)}" + log_line = f"Opaque call to {fn} with name {extract_callable_name(fn)}" case _: - raise NotImplementedError(f"Unexpected history item {item}") + raise NotImplementedError(f"Unexpected log item {item}") if max_depth is None or c_indent <= max_depth: - print_fn(f"{nl}{' ' * c_indent if indent else ''}{linecolor}{history_line}{colors['RESET']}") + print_fn(f"{nl}{' ' * c_indent if indent else ''}{linecolor}{log_line}{colors['RESET']}") if source_line: print_fn(f"{' ' * c_indent if indent else ''}{linecolor}{source_line}{colors['RESET']}") if deindent: c_indent -= 1 - - -def print_last_interpreted_history( - fn: Callable, - /, - print_fn: Callable = print, - use_colors: bool = True, - indent: bool = True, - max_depth: int | None = None, - color_internals: bool = False, - print_source_code: bool = True, -) -> None: - if (history := last_interpreted_history(fn)) is None: - print("No history could be found.") - return - print_history( - history, - print_fn=print_fn, - use_colors=use_colors, - indent=indent, - max_depth=max_depth, - color_internals=color_internals, - print_source_code=print_source_code, - ) diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py index 0f9206037..5982fda18 100644 --- a/thunder/core/jit_ext.py +++ b/thunder/core/jit_ext.py @@ -67,6 +67,7 @@ ) from thunder.core.trace import set_tracectx, reset_tracectx, tracectx, from_trace from thunder.core.interpreter import ( + InterpreterLogItem, interpret, _interpret_call, CapsuleType, @@ -98,7 +99,7 @@ from thunder.extend import Executor from thunder.common import CompileData, CompileStats -from thunder.core.trace import TraceCtx +from thunder.core.trace import TraceCtx, TraceResults from thunder.torch import _torch_to_thunder_function_map from thunder.clang import _clang_fn_set from thunder.core.pytree import tree_map @@ -1394,9 +1395,7 @@ def _get_process_group_from(*fn_and_args) -> Optional["ProcessGroup"]: return found_pg -def thunder_general_jit( - fn: Callable, args, kwargs, /, *, sharp_edges: SHARP_EDGES_OPTIONS -) -> tuple[TraceCtx, TraceCtx]: +def thunder_general_jit(fn: Callable, args, kwargs, /, *, sharp_edges: SHARP_EDGES_OPTIONS) -> TraceResults: # TODO: move into wrap_callback or so if isinstance(fn, torch.nn.parallel.DistributedDataParallel): raise NotImplementedError( @@ -1409,7 +1408,7 @@ def thunder_general_jit( prologue_trace: TraceCtx = TraceCtx(fn) computation_trace: TraceCtx = TraceCtx() - epilogue_trace: TraceCtx = TraceCtx() + epilogue_trace: TraceCtx | None = TraceCtx() si = SigInfo("prologue") si.varargs = ("args", None) @@ -1441,6 +1440,8 @@ def thunder_general_jit( prims.python_return(result) process_recorded_modifications(ctx, epilogue_trace) + last_interpreter_log = jfn._last_interpreter_log + pro_to_comp, computation_intermediates = get_computation_inputs_and_intermediates(computation_trace) epilogue_inputs, _ = get_computation_inputs_and_intermediates(epilogue_trace) @@ -1499,4 +1500,4 @@ def restrict_proxy_swapmap(proxies: tuple[Proxy]) -> dict[Variable, Proxy]: epilogue_trace, restrict_proxy_swapmap(pro_to_epi_proxies + comp_to_epi_proxies), "epilogue" ) - return prologue_trace, computation_trace, epilogue_trace + return TraceResults(prologue_trace, computation_trace, epilogue_trace, last_interpreter_log) diff --git a/thunder/core/trace.py b/thunder/core/trace.py index 88f195058..dc2029e79 100644 --- a/thunder/core/trace.py +++ b/thunder/core/trace.py @@ -572,3 +572,16 @@ def _set_execution_file(path: str) -> None: def _get_execution_file() -> None | str: return _execution_file.get() + + +# +# Container for the two/three types of traces, plus extra tracked data +# + + +class TraceResults: + def __init__(self, prologue: TraceCtx, computation: TraceCtx, epilogue: TraceCtx | None, interpreter_log: list): + self.prologue_trace = prologue + self.computation_trace: TraceCtx = computation + self.epilogue_trace = epilogue + self.interpreter_log = interpreter_log diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py index 33d1ed521..744112ae5 100644 --- a/thunder/core/transforms.py +++ b/thunder/core/transforms.py @@ -419,7 +419,10 @@ def add_transform(cfn: Callable, transform: Callable) -> Callable: **cd.compile_options, ) - cs = CompileStats() + cs = getattr(cfn, "_lc_cs", None) + if cs is None: + cs = CompileStats() + transforms = cfn._lc_transforms + [transform] potransforms = cfn._lc_post_optimization_transforms using_grad_transform = cfn._using_grad_transform diff --git a/thunder/functional.py b/thunder/functional.py index a5cc86409..b28abc140 100644 --- a/thunder/functional.py +++ b/thunder/functional.py @@ -14,6 +14,7 @@ from thunder.core.trace import ( TraceCtx, tracectx, + TraceResults, ) import thunder.core.prims as prims @@ -300,7 +301,7 @@ def _eager_unpack(x: Any, /, name: None | str, *, co: CACHE_OPTIONS) -> tuple[Pr # returns what the original function did def _eager_unpacking_interpreter( interpreter: Callable, fn: Callable, args, kwargs, /, *, interpreter_name: str -) -> tuple[TraceCtx, TraceCtx]: +) -> TraceResults: # Unpacks the inputs si: SigInfo = get_siginfo(fn, args, kwargs) @@ -385,7 +386,9 @@ def _eager_unpacking_interpreter( csi.args.append((p.name, None)) computation_trc.add_name(p.name) - result = interpreter(si.unwrapped_fn)(*interpretation_args, **interpretation_kwargs) + jfn = interpreter(si.unwrapped_fn) + result = jfn(*interpretation_args, **interpretation_kwargs) + interpreter_log = getattr(jfn, "_last_interpreter_log", []) # Validates that the returned items are proxies or printable values def leaf_test(x: Any) -> bool: @@ -412,13 +415,11 @@ def leaf_test(x: Any) -> bool: computation_trc._siginfo = csi computation_trc.args = computation_args - return prologue_trc, computation_trc + return TraceResults(prologue_trc, computation_trc, None, interpreter_log) # Translates the Python function a thunder program using the Python interpreter -def _python_interpreter( - fn: Callable, args, kwargs, /, *, sharp_edges: SHARP_EDGES_OPTIONS -) -> tuple[TraceCtx, TraceCtx]: +def _python_interpreter(fn: Callable, args, kwargs, /, *, sharp_edges: SHARP_EDGES_OPTIONS) -> TraceResults: if sharp_edges is not SHARP_EDGES_OPTIONS.ALLOW: raise ValueError( f"Detecting sharp edges is not supported when using the Python interpreter. To detect sharp edges use another interpretation option." @@ -433,7 +434,7 @@ def _interpreter(fn_): # Translates the Python function to a thunder program using the thunder interpreter def _translate_functions_interpreter( fn: Callable, args, kwargs, /, *, sharp_edges: SHARP_EDGES_OPTIONS -) -> tuple[TraceCtx, TraceCtx]: +) -> TraceResults: from thunder.core.jit_ext import minimal_thunder_jit pjit = partial(minimal_thunder_jit, sharp_edges=sharp_edges) diff --git a/thunder/tests/test_core.py b/thunder/tests/test_core.py index 377a284a6..c8af5a58e 100644 --- a/thunder/tests/test_core.py +++ b/thunder/tests/test_core.py @@ -2186,9 +2186,9 @@ def func(qkv): compiled = thunder.jit(func, executors=executor.executors_list()) out = compiled(qkv) - history = thunder.last_traces(compiled) + traces = thunder.last_traces(compiled) torch.testing.assert_close(out, func(qkv)) - assert "scaled_dot_product_attention" in tuple(bsym.sym.id for bsym in history[-1].bound_symbols) + assert "scaled_dot_product_attention" in tuple(bsym.sym.id for bsym in traces[-1].bound_symbols) @instantiate(dtypes=NOTHING) diff --git a/thunder/tests/test_interpreter.py b/thunder/tests/test_interpreter.py index d326cb2b5..b3641d901 100644 --- a/thunder/tests/test_interpreter.py +++ b/thunder/tests/test_interpreter.py @@ -19,7 +19,9 @@ make_opaque, interpret, InterpreterError, - print_last_interpreted_history, + print_interpreter_log, + last_interpreter_log, + last_interpreted_instructions, ) # @@ -1455,10 +1457,10 @@ def foo(): jfoo = jit(foo) assert foo() == 3 assert jfoo() == 3 - assert any(i.opname == "MATCH_KEYS" for i in jfoo._last_interpreted_instructions) - assert any(i.opname == "MATCH_MAPPING" for i in jfoo._last_interpreted_instructions) + assert any(i.opname == "MATCH_KEYS" for i in last_interpreted_instructions(jfoo)) + assert any(i.opname == "MATCH_MAPPING" for i in last_interpreted_instructions(jfoo)) if "COPY_DICT_WITHOUT_KEYS" in dis.opmap.keys(): - assert any(i.opname == "COPY_DICT_WITHOUT_KEYS" for i in jfoo._last_interpreted_instructions) + assert any(i.opname == "COPY_DICT_WITHOUT_KEYS" for i in last_interpreted_instructions(jfoo)) # Test MATCH_SEQUENCE def bar(): @@ -1473,9 +1475,9 @@ def bar(): jbar = jit(bar) assert bar() == 3 assert jbar() == 3 - assert any(i.opname == "MATCH_SEQUENCE" for i in jbar._last_interpreted_instructions) - assert any(i.opname == "GET_LEN" for i in jbar._last_interpreted_instructions) - assert any(i.opname == "UNPACK_EX" for i in jbar._last_interpreted_instructions) + assert any(i.opname == "MATCH_SEQUENCE" for i in last_interpreted_instructions(jbar)) + assert any(i.opname == "GET_LEN" for i in last_interpreted_instructions(jbar)) + assert any(i.opname == "UNPACK_EX" for i in last_interpreted_instructions(jbar)) def test_class_match_statement(jit): @@ -1499,7 +1501,7 @@ def foo(): jfoo = jit(foo) assert foo() == 3 assert jfoo() == 3 - assert any(i.opname == "MATCH_CLASS" for i in jfoo._last_interpreted_instructions) + assert any(i.opname == "MATCH_CLASS" for i in last_interpreted_instructions(jfoo)) def test_match_fallthrough(jit): @@ -2719,7 +2721,7 @@ def test_displayhook(jit): def smt(s): interpreter.runsource(s) - smt("from thunder.core.interpreter import interpret") + smt("from thunder.core.interpreter import interpret, last_interpreted_instructions") smt( """ def foo(): @@ -2734,7 +2736,7 @@ def foo(): ) smt("jfoo = interpret(foo)") smt("jfoo()") - smt("assert any(i.opname == 'PRINT_EXPR' for i in jfoo._last_interpreted_instructions)") + smt("assert any(i.opname == 'PRINT_EXPR' for i in last_interpreted_instructions(jfoo))") py_out: str = py_redirect.getvalue() assert py_out == "redirected 5\nredirected 6\nredirected 7\nReset.\n", py_out @@ -2756,7 +2758,7 @@ def __init__(self): assert cp().bar == jp().bar assert any(i.opname == "LOAD_BUILD_CLASS" for i in dis.get_instructions(foo)) - assert any(i.opname == "LOAD_BUILD_CLASS" for i in jfoo._last_interpreted_instructions) + assert any(i.opname == "LOAD_BUILD_CLASS" for i in last_interpreted_instructions(jfoo)) def test_with(jit): @@ -2943,6 +2945,24 @@ def baz(): assert res == jres +def test_print_log_types(jit): + def foo(): + return 5 + + jfoo = jit(foo) + jfoo() + + log = last_interpreter_log(jfoo) + + # print into string + buf = io.StringIO() + with redirect_stdout(buf): + print_interpreter_log(log, use_colors=False, indent=False) + bufstr = buf.getvalue() + + assert "Returning from call to test_print_log_types..foo() with value of type int" in bufstr + + def test_is_jitting_with_raise(jit): def foo(): return is_jitting_with_raise() @@ -3023,9 +3043,9 @@ def fn(*args): y = jm(x) y.sum().backward() - # Find the hook registration in the history the normal way + # Find the hook registration in the log the normal way found = False - for item in jm._last_interpreted_history: + for item in last_interpreter_log(jm): if (not isinstance(item, dict)) or (item["kind"] != "Opaque"): continue _fn = item["fn"] @@ -3035,10 +3055,10 @@ def fn(*args): assert found - # Redirect print_last_interpreted_history from stdout to a string, and assert that it's in there. + # Redirect print_last_interpreter_log from stdout to a string, and assert that it's in there. buf = io.StringIO() with redirect_stdout(buf): - print_last_interpreted_history(jm, use_colors=False, indent=False) + print_interpreter_log(last_interpreter_log(jm), use_colors=False, indent=False) match_against = "Opaque call to with name _FunctionBase.register_hook" assert match_against in buf.getvalue()