In [1]:
import os
import sys

# Make the repository checkout importable as `thunder`: the root is taken to be
# three directories above the kernel's working directory (assumes the kernel
# was started in this notebook's own directory — TODO confirm).
thunder_path = os.path.abspath(os.path.join(os.getcwd(), os.pardir, os.pardir, os.pardir))
if thunder_path not in sys.path:
    sys.path.append(thunder_path)

In [2]:
import thunder

In [3]:
def masked_apply(x, mask, layer_0, layer_1):  # Toy function to trace; the `if ... else` below makes its bytecode branch.
    x = layer_0(x, mask)
    x = layer_1(x, mask if mask is not None else 1)  # 1 stands in for a missing mask (parsed as blocks 1 and 2)
    return x, mask is None  # tuple result, built by BUILD_TUPLE in the parsed output

In [4]:
from thunder.core.script.frontend import (
    _construct_protograph,
    acquire_method,
)

# Build only the protograph for `masked_apply`; its `provenance` attribute is
# unpacked stage by stage in the next cell.
proto_graph = _construct_protograph(masked_apply)

In [5]:
from thunder.core.script import parse, values

# Walk the protograph's provenance chain back to its origin. Each protograph
# transform records its provenance as a `(transform, prior_proto_graph)` tuple;
# the run of tuples ends at a `values.ParsedSymbolic`.
provenance = []
last = proto_graph.provenance
while isinstance(last, tuple):
    provenance.append(last)
    transform, prior_proto_graph = last
    last = prior_proto_graph.provenance
assert isinstance(last, values.ParsedSymbolic)

# Follow the two remaining links (ParsedSymbolic -> ParsedFunctional ->
# Disassembled). The walk above collected transforms newest-first, so flip them
# into the order they were applied and put the three earlier stages in front.
parsed_symbolic = last
parsed = parsed_symbolic.provenance
disassembled = parsed.provenance
protograph_transforms = provenance[::-1]
provenance = [disassembled, parsed, parsed_symbolic, *protograph_transforms]

assert isinstance(disassembled, parse.Disassembled)
assert isinstance(parsed, parse.ParsedFunctional)
assert isinstance(parsed_symbolic, values.ParsedSymbolic)

In [6]:
import dis
import io
import itertools

print(disassembled.code, "\n\n")

# Left column: CPython's own disassembly of the code object, captured as text.
buffer = io.StringIO()
dis.dis(disassembled.code, file=buffer)
dis_lines = buffer.getvalue().splitlines()

# Right column: the blocks in `disassembled.raw`, one opname per row. Each
# block's first row carries its index, and a blank row follows every block.
block_lines = []
for idx, block in enumerate(disassembled.raw.values()):
    block_lines += [f"{f'{idx})' if idy == 0 else '':<4}{instruction.opname}" for idy, instruction in enumerate(block)]
    block_lines.append("")

# NOTE(review): rows are paired purely by position, so `dis`'s blank separator
# lines and the blank rows between blocks can shift the two columns out of step.
pad = max(map(len, dis_lines))
for dis_line, block_line in itertools.zip_longest(dis_lines, block_lines, fillvalue=""):
    print(dis_line.ljust(pad + 10) + " ||" + " " * 10 + block_line)

<code object masked_apply at 0x7f064f0d8500, file "/tmp/ipykernel_630313/933819650.py", line 1> 


  2           0 LOAD_FAST                2 (layer_0)           ||          0)  LOAD_FAST
              2 LOAD_FAST                0 (x)                 ||              LOAD_FAST
              4 LOAD_FAST                1 (mask)              ||              LOAD_FAST
              6 CALL_FUNCTION            2                     ||              CALL_FUNCTION
              8 STORE_FAST               0 (x)                 ||              STORE_FAST
                                                               ||              LOAD_FAST
  3          10 LOAD_FAST                3 (layer_1)           ||              LOAD_FAST
             12 LOAD_FAST                0 (x)                 ||              LOAD_FAST
             14 LOAD_FAST                1 (mask)              ||              LOAD_CONST
             16 LOAD_CONST               0 (None)              ||              IS_OP
         

In [7]:
# Dump the parsed blocks: each block's stack inputs => outputs, its instructions
# with the operands they consume/produce, and the blocks it jumps to.
print(parsed.summary)

Block 0:  [] => [layer_1, v0]
  LOAD[layer_0, x, mask]
  CALL_FUNCTION . . . . . . . . . . . . . . (layer_0, x, mask) -> v0
  STORE[x]
  LOAD[layer_1, x, mask, None: CONST]
  IS_OP . . . . . . . . . . . . . . . . . . (mask, None) -> v1
  POP_JUMP_IF_FALSE . . . . . . . . . . . . (v1) -> 
      -> 1, 2(Jump)

Block 1:  [⓵ , ⓶ ] => [⓵ , ⓶ , mask]
  LOAD[mask]
  JUMP_FORWARD
      -> 3(Jump)

Block 2:  [⓵ , ⓶ ] => [⓵ , ⓶ , 1]
  LOAD[1: CONST]
  JUMP_ABSOLUTE*
      -> 3(Jump)

Block 3:  [⓵ , ⓶ , ⓷ ] => []
  CALL_FUNCTION . . . . . . . . . . . . . . (⓵ , ⓶ , ⓷ ) -> v0
  STORE[x]
  LOAD[x, mask, None: CONST]
  IS_OP . . . . . . . . . . . . . . . . . . (mask, None) -> v1
  BUILD_TUPLE . . . . . . . . . . . . . . . (v0, v1) -> v2
  RETURN_VALUE . . . . . . . . . . . . . .  (v2) -> 



In [8]:
def pretty_repr(x) -> str:
    """Render one symbolic instruction input compactly.

    - ``parse.VariableKey`` -> ``identifier(SCOPE)``, e.g. ``mask(LOCAL)``
    - ``values.OutputRef``  -> ``OutputRef(OPNAME, idx=N)``
    - anything else         -> ``repr(x)``
    """
    if isinstance(x, parse.VariableKey):
        rendered = f"{x.identifier}({x.scope.name})"
    elif isinstance(x, values.OutputRef):
        rendered = f"OutputRef({x.instruction.opname}, idx={x.idx})"
    else:
        rendered = repr(x)
    return rendered

# One paragraph per block: every instruction with its symbolic inputs and outputs.
# (`begin` is not printed: at this point it's all just placeholders.)
for block, begin, end in parsed_symbolic.blocks:
    for instruction, symbolic in block.items():
        inputs = ", ".join(map(pretty_repr, symbolic.inputs.ordered))
        print(instruction.opname.ljust(25) + f"  ({inputs}) -> {symbolic.outputs}")
    print()

CALL_FUNCTION              (layer_0(LOCAL), x(LOCAL), mask(LOCAL)) -> (IntermediateValue(at 0x7f064f0c26b0),)
IS_OP                      (mask(LOCAL), None(CONST)) -> (IntermediateValue(at 0x7f064f0c3070),)
POP_JUMP_IF_FALSE          (OutputRef(IS_OP, idx=0)) -> ()

JUMP_FORWARD               () -> ()

JUMP_ABSOLUTE              () -> ()

CALL_FUNCTION              (0(STACK), 1(STACK), 2(STACK)) -> (IntermediateValue(at 0x7f064f0f0040),)
IS_OP                      (mask(LOCAL), None(CONST)) -> (IntermediateValue(at 0x7f064f0f0250),)
BUILD_TUPLE                (OutputRef(CALL_FUNCTION, idx=0), OutputRef(IS_OP, idx=0)) -> (IntermediateValue(at 0x7f064f0f0460),)
RETURN_VALUE               (OutputRef(BUILD_TUPLE, idx=0)) -> ()



In [9]:
# Names of the transforms that produced the final protograph, oldest first.
for transform, _ in protograph_transforms:
    print(transform.__name__)

# `sep=""` stops print's default single-space separator from indenting the
# first line of the protograph dump.
print(f"\n{'=' * 80}\n", proto_graph, sep="")

Unlink
MarkTuples
AddTransitive
Connect

 ProtoBlock: 0x7f064f0f1cf0
  CALL_FUNCTION
  IS_OP
  POP_JUMP_IF_FALSE

ProtoBlock: 0x7f064ef1e560
  JUMP_FORWARD

ProtoBlock: 0x7f064ef81b10
  JUMP_ABSOLUTE

ProtoBlock: 0x7f064ef9cfd0
  CALL_FUNCTION
  IS_OP
  BUILD_TUPLE
  RETURN_VALUE


In [10]:
# Run `acquire_method` on the same function and print the graph it returns.
print(g := acquire_method(masked_apply))

Graph of
  Block (reached from [None])
    CALL_FUNCTION([<thunder.core.script.graph.PhiValue object at 0x7f064f0c1b70 PhiValue 0x7f064f0c1b70 ()>, <thunder.core.script.graph.PhiValue object at 0x7f064efc0310 PhiValue 0x7f064efc0310 ()>, <thunder.core.script.graph.PhiValue object at 0x7f064efc0280 PhiValue 0x7f064efc0280 ()>])
    IS_OP 1 (1)
    POP_JUMP_IF_FALSE 13 (26)
  Block (reached from [<thunder.core.script.graph.Node object at 0x7f069f46da80 POP_JUMP_IF_FALSE 13 (26)>])
    JUMP_FORWARD 1 (28)
  Block (reached from [<thunder.core.script.graph.Node object at 0x7f069f46da80 POP_JUMP_IF_FALSE 13 (26)>])
    JUMP_ABSOLUTE 14 (None)
  Block (reached from [<thunder.core.script.graph.Node object at 0x7f064ef9d120 JUMP_FORWARD 1 (28)>, <thunder.core.script.graph.Node object at 0x7f064ef9d480 JUMP_ABSOLUTE 14 (None)>])
    CALL_FUNCTION([<thunder.core.script.graph.PhiValue object at 0x7f064efc01c0 PhiValue 0x7f064efc01c0 ()>, <thunder.core.script.graph.PhiValue object at 0x7f064efc0220