In [4]:
import torch
from torch._dynamo import optimize
from typing import *

In [32]:
def my_compiler( 
        gm: torch.fx.GraphModule,
        example_inputs: List[torch.Tensor]):
    print("my_compiler() called with FX graph:")
    gm.graph.print_tabular(); print()
    #print(f"code: {gm.graph.python_code()}")
    return gm.forward  # python callable

In [33]:
def toy_example(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b

In [34]:
torch._dynamo.reset()
toy_example = optimize(my_compiler)(toy_example)

In [35]:
a = torch.tensor(data=[1. for i in range(6)]).reshape((2,3))
a_neg = torch.tensor(data=[-1. for i in range(6)]).reshape((2,3))
b = torch.randn((2, 3))
for i in range(4):
    print(f">>> Iteration {i}")
    toy_example(a, b * (-1) ** i)

>>> Iteration 0
my_compiler() called with FX graph:
opcode         name     target                                                  args              kwargs
-------------  -------  ------------------------------------------------------  ----------------  --------
placeholder    a        a                                                       ()                {}
placeholder    b        b                                                       ()                {}
call_function  abs_1    <built-in method abs of type object at 0x7f21751e6b80>  (a,)              {}
call_function  add      <built-in function add>                                 (abs_1, 1)        {}
call_function  truediv  <built-in function truediv>                             (a, add)          {}
call_method    sum_1    sum                                                     (b,)              {}
call_function  lt       <built-in function lt>                                  (sum_1, 0)        {}
output         output   outpu

In [36]:
func = lambda x, y : -y if x.sum() < 0 else y

In [45]:
torch._dynamo.reset()
func = optimize(my_compiler)(func)

## The compilation cache is reused

In [46]:
for i in range(6):
    print(f"Iteration {i}:")
    func(a, b)
    func(a_neg, b)

Iteration 0:
my_compiler() called with FX graph:
opcode         name    target                  args        kwargs
-------------  ------  ----------------------  ----------  --------
placeholder    x       x                       ()          {}
call_method    sum_1   sum                     (x,)        {}
call_function  lt      <built-in function lt>  (sum_1, 0)  {}
output         output  output                  ((lt,),)    {}

my_compiler() called with FX graph:
opcode         name    target                   args       kwargs
-------------  ------  -----------------------  ---------  --------
placeholder    y       y                        ()         {}
call_function  neg     <built-in function neg>  (y,)       {}
output         output  output                   ((neg,),)  {}

Iteration 1:
Iteration 2:
Iteration 3:
Iteration 4:
Iteration 5:


## Non-torch function call

In [47]:
import scipy

In [78]:
def draw_example(a, b):
    import numpy as np
    aa = np.randn((2,3))
    sum = a + b
    return sum.numpy() + aa
    return aa

In [79]:
torch._dynamo.reset()
torch._dynamo.config.verbose=True
func = optimize(my_compiler)(draw_example)

In [80]:
func(a, b)

InternalTorchDynamoError: 