In [1]:
import torch
from torch._dynamo import optimize
from typing import *
from torch import _dynamo



All the valid backends

In [2]:
_dynamo.list_backends()

['aot_ts_nvfuser',
 'cudagraphs',
 'inductor',
 'ipex',
 'nvprims_nvfuser',
 'onnxrt',
 'tensorrt',
 'tvm']

# optimizer() usage with inductor as backend

A naive example

In [3]:
from torch import nn
import torch.nn.functional as F


class Net(nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.fc = nn.Linear(128, 10)

  def forward(self, x):
    x = self.fc(x)
    x = F.relu(x)
    return x

With configs we could alter the behavior of both TorchDynamo and TorchInductor.

In [4]:
from torch._inductor import config as inductor_config
from torch._dynamo import config as dynamo_config
import logging

inductor_config.debug = True
dynamo_config.verbose = True
dynamo_config.log_level = logging.DEBUG
dynamo_config.output_code = True

In [5]:
foo = Net()
foo = torch.compile(foo)

In [6]:
foo
type(foo)

torch._dynamo.eval_frame.OptimizedModule

When enable `inductor.debug`, it could dump the python code it codegened.

In [7]:
a = torch.randn((2, 128))

foo(a)

[2023-02-28 16:11:48,236] torch._dynamo.eval_frame: [DEBUG] skipping __init__ /usr/lib/python3.8/contextlib.py
[2023-02-28 16:11:48,237] torch._dynamo.eval_frame: [DEBUG] skipping __enter__ /usr/lib/python3.8/contextlib.py
[2023-02-28 16:11:48,237] torch._dynamo.eval_frame: [DEBUG] skipping __init__ /usr/lib/python3.8/contextlib.py
[2023-02-28 16:11:48,238] torch._dynamo.eval_frame: [DEBUG] skipping __enter__ /usr/lib/python3.8/contextlib.py
[2023-02-28 16:11:48,238] torch._dynamo.eval_frame: [DEBUG] skipping enable_dynamic /home/chunwei/trienv/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py
[2023-02-28 16:11:48,255] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing forward
[2023-02-28 16:11:48,256] torch._dynamo.symbolic_convert: [DEBUG] TRACE starts_line /tmp/ipykernel_3597807/3534211230.py:11
[2023-02-28 16:11:48,256] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_FAST self []
[2023-02-28 16:11:48,257] torch._dynamo.symbolic_convert: [DEBUG] TRACE 

compile_fx_inner:
class GraphModule(torch.nn.Module):
    def forward(self, primals_1: f32[10, 128], primals_2: f32[10], primals_3: f32[2, 128]):
        # File: /tmp/ipykernel_3597807/3534211230.py:11, code: x = self.fc(x)
        permute: f32[128, 10] = torch.ops.aten.permute.default(primals_1, [1, 0]);  primals_1 = None
        addmm: f32[2, 10] = torch.ops.aten.addmm.default(primals_2, primals_3, permute);  primals_2 = permute = None
        
        # File: /tmp/ipykernel_3597807/3534211230.py:12, code: x = F.relu(x)
        relu: f32[2, 10] = torch.ops.aten.relu.default(addmm);  addmm = None
        le: b8[2, 10] = torch.ops.aten.le.Scalar(relu, 0)
        return [relu, primals_3, le]
        


[2023-02-28 16:11:52,160] torch._inductor.graph: [INFO] Output code: /tmp/torchinductor_chunwei/na/cnanocxvxyalevkulr775i4z4mzgb7pag7gju2ollw75ulnzeqpl.py
[2023-02-28 16:11:52,161] torch._inductor.compile_fx: [INFO] Step 3: torchinductor done compiling FORWARDS graph 0
[2023-02-28 16:11:52,163] torch._dynamo.output_graph: [INFO] Step 2: done compiler function debug_wrapper
[2023-02-28 16:11:52,166] torch._dynamo.output_graph: [INFO] TRACED GRAPH
 __compiled_fn_0 <eval_with_key>.5 opcode         name     target                             args        kwargs
-------------  -------  ---------------------------------  ----------  --------
placeholder    x        x                                  ()          {}
call_module    self_fc  self_fc                            (x,)        {}
call_function  relu     <function relu at 0x7fa2c81fc9d0>  (self_fc,)  {}
output         output   output                             ((relu,),)  {}

[2023-02-28 16:11:52,167] torch._dynamo.convert_frame: [INFO


from ctypes import c_void_p, c_long
import torch
import math
import random
from torch import empty_strided, as_strided, device
from torch._inductor.codecache import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels

aten = torch.ops.aten
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
async_compile = AsyncCompile()

import triton
import triton.language as tl
from torch._inductor.triton_ops.autotune import grid
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream


kernel_cpp_0 = async_compile.cpp('''
#include "/tmp/torchinductor_chunwei/zt/cztcl2vp5yqlnhofzpqfficjcxgyict6e3xhfdd7sdbkipp4p44x.h"
extern "C" void kernel(float* __restrict__ in_out_ptr0,
                       bool* __restrict__ out_ptr0)
{
    {
        #pragma GCC ivdep
        for(long i0=0; i0<20; i0+=1)
        {
            auto tmp0 = in_out_ptr0[i0];
            auto tmp1 = tmp0 * (tmp0>0);
            auto tmp2 = static_cast<float>(0);
            auto tmp3 = tmp

tensor([[0.1221, 0.0000, 0.0000, 0.6039, 0.0000, 0.0000, 0.2984, 0.7128, 0.0169,
         0.0000],
        [0.0000, 0.3757, 0.0000, 0.0000, 0.0000, 0.9155, 0.7447, 0.2878, 0.2604,
         -0.0000]], grad_fn=<CompiledFunctionBackward>)

# Dive into dynamo

According to the definition of `_dynamo.optimize`: 

```python
def optimize(
    backend="inductor",
    *,
    nopython=False,
    guard_export_fn=None,
    guard_fail_fn=None,
    disable=False,
    dynamic=False,
):
```

The `backend` argument could be either a `str` or a `callable`.
Let's hack it with a custom callable to dump something.

In [23]:
my_graph_id = 0
def my_compiler( 
        gm: torch.fx.GraphModule,
        inputs: List[torch.Tensor]):
    global my_graph_id
    print(f"my_compiler() called with FX graph-{my_graph_id}:")
    my_graph_id += 1
    gm.print_readable()
    print()
    #print("tabular:")
    #gm.graph.print_tabular(); print()
    #print(f"code: {gm.graph.python_code()}")
    return gm.forward  # python callable

## Example 1

In [29]:
def foo1(a:torch.tensor, b:torch.tensor):
  x = a + b
  if b.sum() < 0:
    x = x * -1
  return x 

foo1_ = optimize(my_compiler)(foo1)

Note that, this kernel contains a `if` `if b.sum() < 0`, since the `b.sum()` is determined by its value(dynamic), so it should break the graph into two cases:

The first, when the condition is true:

```python
x = a + b
x = x * -1
return x
```

The second, when the condition is false:

```python
x = a + b
return x
```

In [30]:
torch._dynamo.reset() # reset all che compilation cache
dynamo_config.log_level  = logging.INFO

a = torch.ones((2, 3))
b = torch.ones((2, 3))

# It should tigger both cases of the if-else
foo1_(a, b)

[2023-02-28 16:40:53,790] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing foo1
[2023-02-28 16:40:53,795] torch._dynamo.output_graph: [INFO] Step 2: calling compiler function my_compiler
[2023-02-28 16:40:53,795] torch._dynamo.output_graph: [INFO] Step 2: done compiler function my_compiler
[2023-02-28 16:40:53,796] torch._dynamo.output_graph: [INFO] TRACED GRAPH
 __compiled_fn_55 <eval_with_key>.62 opcode         name    target                   args          kwargs
-------------  ------  -----------------------  ------------  --------
placeholder    a       a                        ()            {}
placeholder    b       b                        ()            {}
call_function  add     <built-in function add>  (a, b)        {}
call_method    sum_1   sum                      (b,)          {}
call_function  lt      <built-in function lt>   (sum_1, 0)    {}
output         output  output                   ((add, lt),)  {}

[2023-02-28 16:40:53,797] torch._dynamo.sym

my_compiler() called with FX graph-5:
class GraphModule(torch.nn.Module):
    def forward(self, a : torch.Tensor, b : torch.Tensor):
        # File: /tmp/ipykernel_3597807/1265408374.py:2, code: x = a + b
        add = a + b;  a = None
        
        # File: /tmp/ipykernel_3597807/1265408374.py:3, code: if b.sum() < 0:
        sum_1 = b.sum();  b = None
        lt = sum_1 < 0;  sum_1 = None
        return (add, lt)
        

__resume_at_20_56:
  3           0 JUMP_ABSOLUTE           22
              2 LOAD_FAST                1 (a)
              4 LOAD_FAST                2 (b)
              6 BINARY_ADD
              8 STORE_FAST               0 (x)
             10 LOAD_FAST                2 (b)
             12 LOAD_ATTR                0 (sum)
             14 CALL_FUNCTION            0
             16 LOAD_CONST               1 (0)
             18 COMPARE_OP               0 (<)
             20 POP_JUMP_IF_FALSE       30

  4     >>   22 LOAD_FAST                0 (x)
             24

tensor([[2., 2., 2.],
        [2., 2., 2.]])

In the above case, due to both a and b are positive, it should not goto the if-then block.




In the exaple above, it do break into two graphs, but not from expected:

- graph1: the expressions before the if, with the condition computation
- graph2: the expressions after the if

In [31]:
#torch._dynamo.reset() # reset all che compilation cache
foo1_(a, -b)

[2023-02-28 16:40:56,050] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing <graph break in foo1>
[2023-02-28 16:40:56,052] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo done tracing <graph break in foo1> (RETURN_VALUE)
[2023-02-28 16:40:56,054] torch._dynamo.output_graph: [INFO] Step 2: calling compiler function my_compiler
[2023-02-28 16:40:56,055] torch._dynamo.output_graph: [INFO] Step 2: done compiler function my_compiler
[2023-02-28 16:40:56,055] torch._dynamo.output_graph: [INFO] TRACED GRAPH
 __compiled_fn_58 <eval_with_key>.64 opcode         name    target                   args       kwargs
-------------  ------  -----------------------  ---------  --------
placeholder    x       x                        ()         {}
call_function  mul     <built-in function mul>  (x, -1)    {}
output         output  output                   ((mul,),)  {}

[2023-02-28 16:40:56,056] torch._dynamo.convert_frame: [INFO] ORIGINAL BYTECODE <graph break in foo1>

my_compiler() called with FX graph-6:
class GraphModule(torch.nn.Module):
    def forward(self, x : torch.Tensor):
        # File: /tmp/ipykernel_3597807/1265408374.py:4, code: x = x * -1
        mul = x * -1;  x = None
        return (mul,)
        

transform.output <torch._dynamo.output_graph.OutputGraph object at 0x7fa338ba2dc0>


tensor([[-0., -0., -0.],
        [-0., -0., -0.]])

## Example 2

### Execute once case

In [45]:
def foo2(a:torch.tensor, b:torch.tensor):
  x = a + b
  if b.sum() < 0:
    x = x * -1
  if a.sum() < 0:
    x = x * -2
  x = 2 * x
  return x

foo2_ = optimize(my_compiler)(foo2)

In [46]:
torch._dynamo.reset() # reset all che compilation cache
my_graph_id = 0

a = torch.ones((2, 3))
b = torch.ones((2, 3))

# It should tigger only one case of the if-else
foo2_(a, b)

[2023-02-28 17:18:54,598] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing foo2
[2023-02-28 17:18:54,603] torch._dynamo.output_graph: [INFO] Step 2: calling compiler function my_compiler
[2023-02-28 17:18:54,603] torch._dynamo.output_graph: [INFO] Step 2: done compiler function my_compiler
[2023-02-28 17:18:54,604] torch._dynamo.output_graph: [INFO] TRACED GRAPH
 __compiled_fn_106 <eval_with_key>.108 opcode         name    target                   args          kwargs
-------------  ------  -----------------------  ------------  --------
placeholder    a       a                        ()            {}
placeholder    b       b                        ()            {}
call_function  add     <built-in function add>  (a, b)        {}
call_method    sum_1   sum                      (b,)          {}
call_function  lt      <built-in function lt>   (sum_1, 0)    {}
output         output  output                   ((add, lt),)  {}

[2023-02-28 17:18:54,605] torch._dynamo.s

my_compiler() called with FX graph-0:
class GraphModule(torch.nn.Module):
    def forward(self, a : torch.Tensor, b : torch.Tensor):
        # File: /tmp/ipykernel_3597807/2689548108.py:2, code: x = a + b
        add = a + b;  a = None
        
        # File: /tmp/ipykernel_3597807/2689548108.py:3, code: if b.sum() < 0:
        sum_1 = b.sum();  b = None
        lt = sum_1 < 0;  sum_1 = None
        return (add, lt)
        

__resume_at_20_107:
  3           0 JUMP_ABSOLUTE           22
              2 LOAD_FAST                0 (a)
              4 LOAD_FAST                2 (b)
              6 BINARY_ADD
              8 STORE_FAST               1 (x)
             10 LOAD_FAST                2 (b)
             12 LOAD_ATTR                0 (sum)
             14 CALL_FUNCTION            0
             16 LOAD_CONST               1 (0)
             18 COMPARE_OP               0 (<)
             20 POP_JUMP_IF_FALSE       30

  4     >>   22 LOAD_FAST                1 (x)
             2

tensor([[4., 4., 4.],
        [4., 4., 4.]])

### Exectue all the cases

In [47]:
#torch._dynamo.reset() # reset all che compilation cache
#my_graph_id = 0

# It should tigger all the four combinations of the if-conditions
#foo2_(a, b)
foo2_(-a, b)
#foo2_(-a, b)
#foo2_(-a, -b)

[2023-02-28 17:18:59,100] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing <graph break in foo2>
[2023-02-28 17:18:59,102] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo done tracing <graph break in foo2> (RETURN_VALUE)
[2023-02-28 17:18:59,104] torch._dynamo.output_graph: [INFO] Step 2: calling compiler function my_compiler
[2023-02-28 17:18:59,105] torch._dynamo.output_graph: [INFO] Step 2: done compiler function my_compiler
[2023-02-28 17:18:59,105] torch._dynamo.output_graph: [INFO] TRACED GRAPH
 __compiled_fn_113 <eval_with_key>.114 opcode         name    target                   args         kwargs
-------------  ------  -----------------------  -----------  --------
placeholder    x       x                        ()           {}
call_function  mul     <built-in function mul>  (x, -2)      {}
call_function  mul_1   <built-in function mul>  (2, mul)     {}
output         output  output                   ((mul_1,),)  {}

[2023-02-28 17:18:59,106]

my_compiler() called with FX graph-3:
class GraphModule(torch.nn.Module):
    def forward(self, x : torch.Tensor):
        # File: /tmp/ipykernel_3597807/2689548108.py:6, code: x = x * -2
        mul = x * -2;  x = None
        
        # File: /tmp/ipykernel_3597807/2689548108.py:7, code: x = 2 * x
        mul_1 = 2 * mul;  mul = None
        return (mul_1,)
        

transform.output <torch._dynamo.output_graph.OutputGraph object at 0x7fa338b19340>


tensor([[-0., -0., -0.],
        [-0., -0., -0.]])

In [48]:
foo2_(a, -b)

[2023-02-28 17:19:37,187] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing <graph break in foo2>
[2023-02-28 17:19:37,191] torch._dynamo.output_graph: [INFO] Step 2: calling compiler function my_compiler
[2023-02-28 17:19:37,192] torch._dynamo.output_graph: [INFO] Step 2: done compiler function my_compiler
[2023-02-28 17:19:37,193] torch._dynamo.output_graph: [INFO] TRACED GRAPH
 __compiled_fn_114 <eval_with_key>.116 opcode         name    target                   args          kwargs
-------------  ------  -----------------------  ------------  --------
placeholder    a       a                        ()            {}
placeholder    x       x                        ()            {}
call_function  mul     <built-in function mul>  (x, -1)       {}
call_method    sum_1   sum                      (a,)          {}
call_function  lt      <built-in function lt>   (sum_1, 0)    {}
output         output  output                   ((mul, lt),)  {}

[2023-02-28 17:19:37,193

my_compiler() called with FX graph-4:
class GraphModule(torch.nn.Module):
    def forward(self, a : torch.Tensor, x : torch.Tensor):
        # File: /tmp/ipykernel_3597807/2689548108.py:4, code: x = x * -1
        mul = x * -1;  x = None
        
        # File: /tmp/ipykernel_3597807/2689548108.py:5, code: if a.sum() < 0:
        sum_1 = a.sum();  a = None
        lt = sum_1 < 0;  sum_1 = None
        return (mul, lt)
        

__resume_at_42_115:
  5           0 JUMP_ABSOLUTE           42
              2 LOAD_FAST                1 (a)
              4 LOAD_FAST                2 (b)
              6 BINARY_ADD
              8 STORE_FAST               0 (x)
             10 LOAD_FAST                2 (b)
             12 LOAD_ATTR                0 (sum)
             14 CALL_FUNCTION            0
             16 LOAD_CONST               1 (0)
             18 COMPARE_OP               0 (<)
             20 POP_JUMP_IF_FALSE       30
             22 LOAD_FAST                0 (x)
             

tensor([[-0., -0., -0.],
        [-0., -0., -0.]])

In [49]:
foo2_(-a, -b)

tensor([[-8., -8., -8.],
        [-8., -8., -8.]])

### Python Binary Code for this example

In [35]:
import dis
dis.dis(foo2)

  2           0 LOAD_FAST                0 (a)
              2 LOAD_FAST                1 (b)
              4 BINARY_ADD
              6 STORE_FAST               2 (x)

  3           8 LOAD_FAST                1 (b)
             10 LOAD_METHOD              0 (sum)
             12 CALL_METHOD              0
             14 LOAD_CONST               1 (0)
             16 COMPARE_OP               0 (<)
             18 POP_JUMP_IF_FALSE       28

  4          20 LOAD_FAST                2 (x)
             22 LOAD_CONST               2 (-1)
             24 BINARY_MULTIPLY
             26 STORE_FAST               2 (x)

  5     >>   28 LOAD_FAST                0 (a)
             30 LOAD_METHOD              0 (sum)
             32 CALL_METHOD              0
             34 LOAD_CONST               1 (0)
             36 COMPARE_OP               0 (<)
             38 POP_JUMP_IF_FALSE       48

  6          40 LOAD_FAST                2 (x)
             42 LOAD_CONST               2 (-1)
       

In [36]:
torch._dynamo.reset() # reset all che compilation cache
my_graph_id = 0

a = torch.randn((2, 3))
b = torch.randn((2, 3))

# It should tigger only one case of the if-else
foo2_(a, b)

[2023-02-28 16:44:15,936] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing foo2
[2023-02-28 16:44:15,942] torch._dynamo.output_graph: [INFO] Step 2: calling compiler function my_compiler
[2023-02-28 16:44:15,942] torch._dynamo.output_graph: [INFO] Step 2: done compiler function my_compiler
[2023-02-28 16:44:15,943] torch._dynamo.output_graph: [INFO] TRACED GRAPH
 __compiled_fn_77 <eval_with_key>.82 opcode         name    target                   args          kwargs
-------------  ------  -----------------------  ------------  --------
placeholder    a       a                        ()            {}
placeholder    b       b                        ()            {}
call_function  add     <built-in function add>  (a, b)        {}
call_method    sum_1   sum                      (b,)          {}
call_function  lt      <built-in function lt>   (sum_1, 0)    {}
output         output  output                   ((add, lt),)  {}

[2023-02-28 16:44:15,944] torch._dynamo.sym

my_compiler() called with FX graph-0:
class GraphModule(torch.nn.Module):
    def forward(self, a : torch.Tensor, b : torch.Tensor):
        # File: /tmp/ipykernel_3597807/2473059898.py:2, code: x = a + b
        add = a + b;  a = None
        
        # File: /tmp/ipykernel_3597807/2473059898.py:3, code: if b.sum() < 0:
        sum_1 = b.sum();  b = None
        lt = sum_1 < 0;  sum_1 = None
        return (add, lt)
        

__resume_at_20_78:
  3           0 JUMP_ABSOLUTE           22
              2 LOAD_FAST                0 (a)
              4 LOAD_FAST                2 (b)
              6 BINARY_ADD
              8 STORE_FAST               1 (x)
             10 LOAD_FAST                2 (b)
             12 LOAD_ATTR                0 (sum)
             14 CALL_FUNCTION            0
             16 LOAD_CONST               1 (0)
             18 COMPARE_OP               0 (<)
             20 POP_JUMP_IF_FALSE       30

  4     >>   22 LOAD_FAST                1 (x)
             24

tensor([[ 1.4001, -0.2423, -4.1457],
        [ 1.0497, -1.7117,  1.4444]])

## bytecode Instructions to FX Graph

The execution pass:

1. optimize()
2. convert_frame(), Try to convert a frame into an FX graph, if error leave frame unmodified.
   1. `result = inner_convert(frame, cache_size, hooks)`
3. convert_frame_assert()
4. _compile(), it gets an CompileFn which takes an `fx.GraphModule` as input and outputs a list of `torch.Tensor`.
   1. transform()
      1. `tracer = InstructionTranslator(...)`
      2. `tracer.run()`

Get the instructions from foo2 Code Object.

In [17]:
from torch._dynamo.bytecode_transformation import cleaned_instructions
from torch._dynamo.bytecode_analysis import propagate_line_nums

instructions = cleaned_instructions(foo2.__code__)
propagate_line_nums(instructions)
instructions

[Instruction(opcode=124, opname='LOAD_FAST', arg=0, argval='a', offset=0, starts_line=2, is_jump_target=False, target=None),
 Instruction(opcode=124, opname='LOAD_FAST', arg=1, argval='b', offset=2, starts_line=2, is_jump_target=False, target=None),
 Instruction(opcode=23, opname='BINARY_ADD', arg=None, argval=None, offset=4, starts_line=2, is_jump_target=False, target=None),
 Instruction(opcode=125, opname='STORE_FAST', arg=2, argval='x', offset=6, starts_line=2, is_jump_target=False, target=None),
 Instruction(opcode=124, opname='LOAD_FAST', arg=1, argval='b', offset=8, starts_line=3, is_jump_target=False, target=None),
 Instruction(opcode=106, opname='LOAD_ATTR', arg=0, argval='sum', offset=10, starts_line=3, is_jump_target=False, target=None),
 Instruction(opcode=131, opname='CALL_FUNCTION', arg=0, argval=0, offset=12, starts_line=3, is_jump_target=False, target=None),
 Instruction(opcode=100, opname='LOAD_CONST', arg=1, argval=0, offset=14, starts_line=3, is_jump_target=False, tar