Skip to content

Commit

Permalink
Support and surface (print_)last_interpreter_log() to the public AP…
Browse files Browse the repository at this point in the history
…I, introduce `TraceResults` (#115)
  • Loading branch information
apaz-cli committed Apr 9, 2024
1 parent 3a91b01 commit dba8ce7
Show file tree
Hide file tree
Showing 10 changed files with 188 additions and 141 deletions.
2 changes: 1 addition & 1 deletion docs/source/reference/thunder.rst
Expand Up @@ -32,7 +32,7 @@ Querying information on compiled functions and modules
cache_misses
list_transforms
last_interpreted_instructions
last_interpreted_history
last_interpreter_log
last_compile_options
..
compile
Expand Down
80 changes: 58 additions & 22 deletions thunder/__init__.py
Expand Up @@ -12,6 +12,7 @@

from looseversion import LooseVersion

from thunder.core.interpreter import InterpreterLogItem
from thunder.core.options import (
INTERPRETATION_OPTIONS,
resolve_interpretation_option,
Expand All @@ -20,6 +21,7 @@
SHARP_EDGES_OPTIONS,
)
from thunder.core.trace import (
TraceResults,
TraceCtx,
from_trace,
set_tracectx,
Expand Down Expand Up @@ -58,6 +60,7 @@
DictProxy,
AnyProxy,
)
from thunder.core.interpreter import print_interpreter_log, print_to_log
from thunder.core.jit_ext import thunder_general_jit
from thunder.executors.torch_autograd import split_forward_backward, ThunderFunction
from thunder.cudagraphs import CUDAGraphExecutor
Expand Down Expand Up @@ -171,7 +174,7 @@ def __version__():


# Translates the Python function to a thunder program using the thunder interpreter
def _general_frontend(fn: Callable, args, kwargs, /, *, sharp_edges: SHARP_EDGES_OPTIONS) -> tuple[TraceCtx, TraceCtx]:
def _general_frontend(fn: Callable, args, kwargs, /, *, sharp_edges: SHARP_EDGES_OPTIONS) -> TraceResults:
return thunder_general_jit(fn, args, kwargs, sharp_edges=sharp_edges)


Expand Down Expand Up @@ -442,7 +445,7 @@ def get_computation_and_inputs(*args, **kwargs):
cs.cache_hits += 1
cs.last_traces = comp_traces
cs.last_interpreted_instructions = None
cs.last_interpreted_history = None
cs.last_interpreter_log = None
cs.last_prologue_traces = pro_traces
cs.last_prologue = pro
cs.last_prologue_transformation_start = 0
Expand Down Expand Up @@ -481,7 +484,7 @@ def get_computation_and_inputs(*args, **kwargs):
cs.cache_hits += 1
cs.last_traces = comp_traces
cs.last_interpreted_instructions = None
cs.last_interpreted_history = None
cs.last_interpreter_log = None
cs.last_prologue_traces = pro_traces
cs.last_prologue = pro

Expand All @@ -501,17 +504,15 @@ def get_computation_and_inputs(*args, **kwargs):
with langctxs.langctx(cd.langctx):
prologue_trc: TraceCtx
computation_trc: TraceCtx
prologue_trc, computation_trc, *maybe_epilogue = interpreter(
fn, args, kwargs, sharp_edges=cd.sharp_edges
)

if maybe_epilogue:
epilogue_traces = maybe_epilogue
if epilogue_traces[-1] is not None:
epilogue = epilogue_traces[-1].python_callable()
else:
epilogue_traces = None
epilogue = None
jit_results: TraceResults = interpreter(fn, args, kwargs, sharp_edges=cd.sharp_edges)
prologue_trc = jit_results.prologue_trace
computation_trc = jit_results.computation_trace
epilogue_trc = jit_results.epilogue_trace
last_interpreter_log = jit_results.interpreter_log

if epilogue_trc is not None:
epilogue_traces = [epilogue_trc]
epilogue = epilogue_trc.python_callable()
else:
epilogue_traces = None
epilogue = None
Expand Down Expand Up @@ -542,6 +543,8 @@ def get_computation_and_inputs(*args, **kwargs):
cs.last_traces = computation_traces
backward_traces = []
cs.last_backward_traces = backward_traces
cs.last_interpreter_log = last_interpreter_log
cs.last_interpreted_instructions = (i for i in last_interpreter_log if isinstance(i, dis.Instruction))

computation_trc = dce(computation_trc)
computation_traces.append(computation_trc)
Expand Down Expand Up @@ -787,28 +790,61 @@ def list_transforms(fn) -> list:
return fn._lc_transforms


def last_interpreted_instructions(fn: Callable) -> list[dis.Instruction]:
"""Returns the list of instructions the interpreter encountered while tracing through the
def last_interpreter_log(fn: Callable) -> list[InterpreterLogItem]:
"""Returns the list of instructions and other information the interpreter encountered while tracing through the
user program (on the last cache miss).
"""
cs = compile_stats(fn)
if cs is None:
raise TypeError(f"{fn} doesn't seem to be a thunder compiled function.")
if cs.last_interpreted_instructions is None:
if cs.last_interpreter_log is None:
raise TypeError(f"{fn} doesn't seem to have been called yet.")
return cs.last_interpreted_instructions
return cs.last_interpreter_log


def last_interpreted_history(fn: Callable) -> list[dis.Instruction | str]:
"""Returns the list of instructions and other information the interpreter encountered while tracing through the
def last_interpreted_instructions(fn: Callable) -> list[dis.Instruction]:
"""Returns the list of instructions the interpreter encountered while tracing through the
user program (on the last cache miss).
"""
cs = compile_stats(fn)
if cs is None:
raise TypeError(f"{fn} doesn't seem to be a thunder compiled function.")
if cs.last_interpreted_history is None:
if cs.last_interpreted_instructions is None:
raise TypeError(f"{fn} doesn't seem to have been called yet.")
return cs.last_interpreted_history
return list(cs.last_interpreted_instructions)


def print_last_interpreter_log(
fn: Callable,
/,
print_fn: Callable = print,
use_colors: bool = True,
indent: bool = True,
max_depth: int | None = None,
color_internals: bool = False,
print_source_code: bool = True,
) -> None:
"""Prints a log of the last run of the interpreter for the given function.
Args:
fn: The function returned by `thunder.jit()` to print the last interpreter run log for. The function must have been called at least once first.
print_fn: The function to use for printing. Defaults to builtin `print`.
use_colors: Whether to use colors in the output. Defaults to `None`, which attempts to autodetect if the terminal supports ANSI color.
indent: Whether to indent the output with function scope. Defaults to `True`.
max_depth: The maximum indentation depth of the output. Doesn't print log items nested deeper than the max depth. Defaults to `None`, which means no limit.
color_internals: Whether to color instructions implicitly interpreted by other instructions. Defaults to `False`, so that only the instructions in the user's code are highlighted in color.
print_source_code: Whether to print the source line below each LineLogItem in the log. Defaults to `True`.
"""
log = last_interpreter_log(fn)
print_interpreter_log(
log,
print_fn=print_fn,
use_colors=use_colors,
indent=indent,
max_depth=max_depth,
color_internals=color_internals,
print_source_code=print_source_code,
)


def last_compile_options(fn: Callable, /) -> None:
Expand Down
8 changes: 5 additions & 3 deletions thunder/common.py
@@ -1,4 +1,6 @@
import dis
from typing import Any, Optional
from collections.abc import Generator
from collections.abc import Callable
from enum import Enum, auto
from collections import deque, defaultdict
Expand Down Expand Up @@ -58,8 +60,8 @@ def __init__(self):
self.last_traces = None
self.last_prologue = None
self.last_prologue_traces = None
self.last_interpreted_instructions = None
self.last_interpreted_history = None
self.last_interpreted_instructions: Generator[dis.Instruction, None, None] | None = None
self.last_interpreter_log: list[InterpreterLogItem] | None = None

# torch.autograd.Function specific data
self.last_backward_traces = None
Expand Down Expand Up @@ -466,7 +468,7 @@ def cache_get(
# TODO Consider modeling additional calls to trace()
# TODO RC1 Change the way this is called to be trace(langctx, inline_trace, rename_proxies...)(fn, *args, **kwargs)
# to separate the traced function's args and kwargs from this function's kwargs
from thunder.core.interpreter import make_opaque
from thunder.core.interpreter import InterpreterLogItem, make_opaque


def trace(
Expand Down

0 comments on commit dba8ce7

Please sign in to comment.