In [7]:
import time
import traceback as tb

import numpy as np
import torch
from functorch.experimental.control_flow import cond

# Log the FX graph code Dynamo traces for each compiled frame; this produces the
# "TRACED GRAPH" [__graph_code] log blocks shown in the cell outputs below.
torch._logging.set_logs(graph_code=True)

## Basics

In [2]:
def foo(x, y):
    """Return ``sin(x) + cos(y)`` elementwise (the running example for torch.compile)."""
    return torch.sin(x) + torch.cos(y)


# Option 1: wrap an existing function with torch.compile.
opt_foo1 = torch.compile(foo)
print(opt_foo1(torch.randn(3, 3), torch.randn(3, 3)))


# Option 2: apply torch.compile as a decorator.
@torch.compile
def opt_foo2(x, y):
    """Same computation as ``foo`` (``sin(x) + cos(y)``), compiled via the decorator form."""
    sin_x = torch.sin(x)
    cos_y = torch.cos(y)
    return sin_x + cos_y


print(opt_foo2(torch.randn(3, 3), torch.randn(3, 3)))

V0209 03:54:40.007000 926 torch/_dynamo/output_graph.py:1983] [0/0] [__graph_code] TRACED GRAPH
V0209 03:54:40.007000 926 torch/_dynamo/output_graph.py:1983] [0/0] [__graph_code]  ===== __compiled_fn_1_5bd9b235_7c5c_439c_b2dc_f2819823176a =====
V0209 03:54:40.007000 926 torch/_dynamo/output_graph.py:1983] [0/0] [__graph_code]  /usr/local/lib/python3.12/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
V0209 03:54:40.007000 926 torch/_dynamo/output_graph.py:1983] [0/0] [__graph_code]     def forward(self, L_x_: "f32[3, 3][3, 1]cpu", L_y_: "f32[3, 3][3, 1]cpu"):
V0209 03:54:40.007000 926 torch/_dynamo/output_graph.py:1983] [0/0] [__graph_code]         l_x_ = L_x_
V0209 03:54:40.007000 926 torch/_dynamo/output_graph.py:1983] [0/0] [__graph_code]         l_y_ = L_y_
V0209 03:54:40.007000 926 torch/_dynamo/output_graph.py:1983] [0/0] [__graph_code]         
V0209 03:54:40.007000 926 torch/_dynamo/output_graph.py:1983] [0/0] [__graph_code]          # File: /tmp

tensor([[ 1.8430,  0.0587,  0.1466],
        [ 1.0785, -0.8453,  1.3462],
        [-0.1504,  0.6011,  0.4347]])
tensor([[ 1.4312,  1.9973, -0.6420],
        [ 1.0545,  0.8127,  1.6647],
        [-0.3056,  1.3990,  1.2609]])


In [3]:
def inner(x):
    """Plain (undecorated) helper returning ``sin(x)``; called from the compiled ``outer``."""
    return x.sin()


@torch.compile
def outer(x, y):
    """Return ``inner(x) + cos(y)``.

    ``inner`` is not decorated itself; it is reached only through this compiled
    function.
    """
    return inner(x) + torch.cos(y)


print(outer(torch.randn(3, 3), torch.randn(3, 3)))

V0209 03:58:17.994000 926 torch/_dynamo/output_graph.py:1983] [2/0] [__graph_code] TRACED GRAPH
V0209 03:58:17.994000 926 torch/_dynamo/output_graph.py:1983] [2/0] [__graph_code]  ===== __compiled_fn_5_5af9f1b7_44f4_448d_b25c_41b2bd9ba27e =====
V0209 03:58:17.994000 926 torch/_dynamo/output_graph.py:1983] [2/0] [__graph_code]  /usr/local/lib/python3.12/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
V0209 03:58:17.994000 926 torch/_dynamo/output_graph.py:1983] [2/0] [__graph_code]     def forward(self, L_x_: "f32[3, 3][3, 1]cpu", L_y_: "f32[3, 3][3, 1]cpu"):
V0209 03:58:17.994000 926 torch/_dynamo/output_graph.py:1983] [2/0] [__graph_code]         l_x_ = L_x_
V0209 03:58:17.994000 926 torch/_dynamo/output_graph.py:1983] [2/0] [__graph_code]         l_y_ = L_y_
V0209 03:58:17.994000 926 torch/_dynamo/output_graph.py:1983] [2/0] [__graph_code]         
V0209 03:58:17.994000 926 torch/_dynamo/output_graph.py:1983] [2/0] [__graph_code]          # File: /tmp

tensor([[ 0.6077,  1.1759,  1.4414],
        [-0.8267,  1.8781,  1.2680],
        [ 0.0673,  0.3714,  0.1308]])


## Torch Modules

In [4]:
class MyModule(torch.nn.Module):
    """A single 3 -> 3 linear layer followed by ReLU."""

    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(3, 3)

    def forward(self, x):
        """Return ``relu(lin(x))``; ``x`` must have a trailing dimension of size 3."""
        return torch.nn.functional.relu(self.lin(x))


# Option 1: compile the module in place with nn.Module.compile().
mod1 = MyModule()
mod1.compile()
print(mod1(torch.randn(3, 3)))

# Option 2: wrap the module with torch.compile and use the returned callable.
mod2 = torch.compile(MyModule())
print(mod2(torch.randn(3, 3)))

V0209 04:03:03.683000 926 torch/_dynamo/output_graph.py:1983] [3/0] [__graph_code] TRACED GRAPH
V0209 04:03:03.683000 926 torch/_dynamo/output_graph.py:1983] [3/0] [__graph_code]  ===== __compiled_fn_7_c3138e92_e800_426b_87bd_24279e6678f1 =====
V0209 04:03:03.683000 926 torch/_dynamo/output_graph.py:1983] [3/0] [__graph_code]  /usr/local/lib/python3.12/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
V0209 04:03:03.683000 926 torch/_dynamo/output_graph.py:1983] [3/0] [__graph_code]     def forward(self, L_self_modules_lin_parameters_weight_: "f32[3, 3][3, 1]cpu", L_self_modules_lin_parameters_bias_: "f32[3][1]cpu", L_x_: "f32[3, 3][3, 1]cpu"):
V0209 04:03:03.683000 926 torch/_dynamo/output_graph.py:1983] [3/0] [__graph_code]         l_self_modules_lin_parameters_weight_ = L_self_modules_lin_parameters_weight_
V0209 04:03:03.683000 926 torch/_dynamo/output_graph.py:1983] [3/0] [__graph_code]         l_self_modules_lin_parameters_bias_ = L_self_modules_lin

tensor([[0.0000, 0.4400, 0.3770],
        [0.0000, 0.9943, 0.2302],
        [0.0000, 1.1746, 0.6812]], grad_fn=<CompiledFunctionBackward>)
tensor([[0.0000, 0.0000, 0.3902],
        [0.3950, 0.0000, 0.3283],
        [0.4503, 0.0000, 0.0000]], grad_fn=<CompiledFunctionBackward>)


## Speedup Demo

In [2]:
def foo3(x):
    """Pointwise chain ``relu(x + 1) * 2`` used for the speedup benchmark."""
    return torch.nn.functional.relu(x + 1) * 2


opt_foo3 = torch.compile(foo3)


def timed(fn):
    """Run ``fn()`` and return ``(result, elapsed_seconds)``.

    With CUDA available, the interval is measured with CUDA events followed by a
    device synchronization, so asynchronously launched GPU work is included.
    Without CUDA, falls back to ``time.perf_counter`` so the helper also works on
    CPU-only machines (previously it raised there).
    """
    if not torch.cuda.is_available():
        begin = time.perf_counter()
        result = fn()
        return result, time.perf_counter() - begin

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    result = fn()
    end.record()
    # Wait for all queued GPU work so both events have completed.
    torch.cuda.synchronize()
    # elapsed_time reports milliseconds; convert to seconds.
    return result, start.elapsed_time(end) / 1000


# Benchmark input, allocated directly on the target device (no host allocation
# followed by a copy). Falls back to CPU when CUDA is unavailable.
# Note: the first call to `opt_foo3` triggers compilation, so timing a single
# call mostly measures compile time; the next cell repeats the measurement.
inp = torch.randn(4096, 4096, device="cuda" if torch.cuda.is_available() else "cpu")

In [7]:
# Turn off graph-code logging for the benchmark to avoid flooding the output.
torch._logging.set_logs(graph_code=False)

N_TIMING_RUNS = 10


def _benchmark(fn, label, n_runs=N_TIMING_RUNS):
    """Warm up ``fn`` once, then time it ``n_runs`` times with ``timed``.

    The untimed warm-up call keeps one-off costs (e.g. torch.compile compilation
    on the first call) out of the measurements, so results are representative
    even on a fresh kernel.

    Prints one line per run plus a separator and returns the list of times in
    seconds.
    """
    fn()  # warm-up
    times = []
    for i in range(n_runs):
        _, elapsed = timed(fn)
        times.append(elapsed)
        print(f"{label} time {i}: {elapsed}")
    print("~" * 10)
    return times


eager_times = _benchmark(lambda: foo3(inp), "eager")
compile_times = _benchmark(lambda: opt_foo3(inp), "compile")

eager_med = np.median(eager_times)
compile_med = np.median(compile_times)
speedup = eager_med / compile_med
# Expected to hold, presumably because the compiled version can fuse the three
# pointwise ops in foo3 — a sanity check, not a guarantee on every device.
assert speedup > 1, f"expected torch.compile to be faster, got {speedup:.2f}x"
print(
    f"(eval) eager median: {eager_med}, compile median: {compile_med}, speedup: {speedup}x"
)
print("~" * 10)

eager time 0: 0.001882688045501709
eager time 1: 0.0019129279851913452
eager time 2: 0.0017240320444107055
eager time 3: 0.0017167359590530395
eager time 4: 0.0017345600128173828
eager time 5: 0.0017388479709625243
eager time 6: 0.0017455040216445923
eager time 7: 0.001720960021018982
eager time 8: 0.0017236160039901733
eager time 9: 0.0017163840532302857
~~~~~~~~~~
compile time 0: 0.0008069120049476623
compile time 1: 0.0007274240255355835
compile time 2: 0.0006666560173034668
compile time 3: 0.0006573759913444519
compile time 4: 0.0006492159962654114
compile time 5: 0.0006516799926757813
compile time 6: 0.0006692799925804139
compile time 7: 0.0006451519727706909
compile time 8: 0.0006405760049819946
compile time 9: 0.0006440960168838501
~~~~~~~~~~
(eval) eager median: 0.0017292960286140443, compile median: 0.0006545279920101166, speedup: 2.6420505306476114x
~~~~~~~~~~


## TorchScript Comparison

In [8]:
def f1(x, y):
    """Return ``-y`` if ``x`` sums to a negative value, otherwise ``y`` itself.

    The branch depends on tensor data, which ``torch.jit.trace`` bakes in from
    the example inputs.
    """
    return -y if x.sum() < 0 else y


# Test that `fn1` and `fn2` return the same result, given the same arguments `args`.
def test_fns(fn1, fn2, args):
    out1 = fn1(*args)
    out2 = fn2(*args)
    return torch.allclose(out1, out2)


inp1, inp2 = torch.randn(5, 5), torch.randn(5, 5)
# (label, args) pairs: "1, 2" flips the sign of x so f1 takes the other branch.
cases = (("1, 1", (inp1, inp2)), ("1, 2", (-inp1, inp2)))

# torch.jit.trace records only the branch taken for the example inputs.
traced_f1 = torch.jit.trace(f1, (inp1, inp2))
for tag, case_args in cases:
    print(f"traced {tag}:", test_fns(f1, traced_f1, case_args))

# torch.compile falls back to Python at the data-dependent branch (a graph
# break), so both cases match eager.
compile_f1 = torch.compile(f1)
for tag, case_args in cases:
    print(f"compile {tag}:", test_fns(f1, compile_f1, case_args))
print("~" * 10)

  if x.sum() < 0:


traced 1, 1: True
traced 1, 2: False
compile 1, 1: True
compile 1, 2: True
~~~~~~~~~~


In [9]:
torch._logging.set_logs(graph_code=True)


def f2(x, y):
    return x + y


inp1 = torch.randn(5, 5)
inp2 = 3

# TorchScript types unannotated arguments as Tensor, so passing a Python int for
# `y` is expected to fail here. Catch `Exception` rather than using a bare
# `except:` so KeyboardInterrupt/SystemExit still propagate.
script_f2 = torch.jit.script(f2)
try:
    script_f2(inp1, inp2)
except Exception:
    tb.print_exc()

# torch.compile accepts the Python int argument without annotations.
compile_f2 = torch.compile(f2)
print("compile 2:", test_fns(f2, compile_f2, (inp1, inp2)))
print("~" * 10)

Traceback (most recent call last):
  File "/tmp/ipython-input-3652677659.py", line 15, in <cell line: 0>
    script_f2(inp1, inp2)
RuntimeError: f2() Expected a value of type 'Tensor (inferred)' for argument 'y' but instead found type 'int'.
Inferred 'y' to be of type 'Tensor' because it was not annotated with an explicit type.
Position: 1
Value: 3
Declaration: f2(Tensor x, Tensor y) -> Tensor
Cast error details: Unable to cast 3 to Tensor
V0209 04:24:12.132000 435 torch/_dynamo/output_graph.py:1983] [4/0] [__graph_code] TRACED GRAPH
V0209 04:24:12.132000 435 torch/_dynamo/output_graph.py:1983] [4/0] [__graph_code]  ===== __compiled_fn_10_1ebbd991_8d47_4d5d_a7d1_236cd1b30dc8 =====
V0209 04:24:12.132000 435 torch/_dynamo/output_graph.py:1983] [4/0] [__graph_code]  /usr/local/lib/python3.12/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
V0209 04:24:12.132000 435 torch/_dynamo/output_graph.py:1983] [4/0] [__graph_code]     def forward(self, L_x_: "f32[5, 

compile 2: True
~~~~~~~~~~


## Graph Breaks

In [8]:
def bar(a, b):
    """Return ``a / (|a| + 1)`` times ``b``, with ``b`` negated if it sums below zero.

    The ``b.sum() < 0`` test is data-dependent control flow, which causes a
    graph break under torch.compile.
    """
    scaled = a / (a.abs() + 1)
    signed_b = b * -1 if b.sum() < 0 else b
    return scaled * signed_b


opt_bar = torch.compile(bar)
inp1, inp2 = torch.ones(10), torch.ones(10)
# b.sum() >= 0 here, so the negation branch in `bar` is not taken.
opt_bar(inp1, inp2)

V0209 06:24:20.609000 950 torch/_dynamo/output_graph.py:1983] [1/0] [__graph_code] TRACED GRAPH
V0209 06:24:20.609000 950 torch/_dynamo/output_graph.py:1983] [1/0] [__graph_code]  ===== __compiled_fn_18_c3e4e3f1_d6f7_4392_930d_4d53a903df0f =====
V0209 06:24:20.609000 950 torch/_dynamo/output_graph.py:1983] [1/0] [__graph_code]  /usr/local/lib/python3.12/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
V0209 06:24:20.609000 950 torch/_dynamo/output_graph.py:1983] [1/0] [__graph_code]     def forward(self, L_a_: "f32[10][1]cpu", L_b_: "f32[10][1]cpu"):
V0209 06:24:20.609000 950 torch/_dynamo/output_graph.py:1983] [1/0] [__graph_code]         l_a_ = L_a_
V0209 06:24:20.609000 950 torch/_dynamo/output_graph.py:1983] [1/0] [__graph_code]         l_b_ = L_b_
V0209 06:24:20.609000 950 torch/_dynamo/output_graph.py:1983] [1/0] [__graph_code]         
V0209 06:24:20.609000 950 torch/_dynamo/output_graph.py:1983] [1/0] [__graph_code]          # File: /tmp/ipython-

tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000])

In [9]:
# Negated `b`: the other side of the data-dependent branch now runs, and a new
# graph (frame [3/0]) appears in the log output below.
opt_bar(inp1, -inp2)

V0209 06:24:21.182000 950 torch/_dynamo/output_graph.py:1983] [3/0] [__graph_code] TRACED GRAPH
V0209 06:24:21.182000 950 torch/_dynamo/output_graph.py:1983] [3/0] [__graph_code]  ===== __compiled_fn_24_ef61c6ba_0cf7_4f88_94f4_297fb9600c28 =====
V0209 06:24:21.182000 950 torch/_dynamo/output_graph.py:1983] [3/0] [__graph_code]  /usr/local/lib/python3.12/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
V0209 06:24:21.182000 950 torch/_dynamo/output_graph.py:1983] [3/0] [__graph_code]     def forward(self, L_b_: "f32[10][1]cpu", L_x_: "f32[10][1]cpu"):
V0209 06:24:21.182000 950 torch/_dynamo/output_graph.py:1983] [3/0] [__graph_code]         l_b_ = L_b_
V0209 06:24:21.182000 950 torch/_dynamo/output_graph.py:1983] [3/0] [__graph_code]         l_x_ = L_x_
V0209 06:24:21.182000 950 torch/_dynamo/output_graph.py:1983] [3/0] [__graph_code]         
V0209 06:24:21.182000 950 torch/_dynamo/output_graph.py:1983] [3/0] [__graph_code]          # File: /tmp/ipython-

tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000])

In [10]:
# Reset to clear the torch.compile cache
torch._dynamo.reset()
# Call with both signs of `b` on a cold cache; only the value of the last
# expression is displayed by the notebook.
opt_bar(inp1, inp2)
opt_bar(inp1, -inp2)

V0209 06:24:22.544000 950 torch/_dynamo/output_graph.py:1983] [0/0] [__graph_code] TRACED GRAPH
V0209 06:24:22.544000 950 torch/_dynamo/output_graph.py:1983] [0/0] [__graph_code]  ===== __compiled_fn_26_fb8ad3cf_e08d_4cab_8843_e799fbc2c4fa =====
V0209 06:24:22.544000 950 torch/_dynamo/output_graph.py:1983] [0/0] [__graph_code]  /usr/local/lib/python3.12/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
V0209 06:24:22.544000 950 torch/_dynamo/output_graph.py:1983] [0/0] [__graph_code]     def forward(self, L_a_: "f32[10][1]cpu", L_b_: "f32[10][1]cpu"):
V0209 06:24:22.544000 950 torch/_dynamo/output_graph.py:1983] [0/0] [__graph_code]         l_a_ = L_a_
V0209 06:24:22.544000 950 torch/_dynamo/output_graph.py:1983] [0/0] [__graph_code]         l_b_ = L_b_
V0209 06:24:22.544000 950 torch/_dynamo/output_graph.py:1983] [0/0] [__graph_code]         
V0209 06:24:22.544000 950 torch/_dynamo/output_graph.py:1983] [0/0] [__graph_code]          # File: /tmp/ipython-

tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000])

In [11]:
# Reset to clear the torch.compile cache
torch._dynamo.reset()

# With fullgraph=True a graph break raises instead of splitting the graph; the
# data-dependent `if b.sum() < 0` in `bar` is expected to trigger it.
opt_bar_fullgraph = torch.compile(bar, fullgraph=True)
try:
    opt_bar_fullgraph(torch.randn(10), torch.randn(10))
except Exception:  # not bare: let KeyboardInterrupt/SystemExit propagate
    tb.print_exc()

Traceback (most recent call last):
  File "/tmp/ipython-input-387069252.py", line 6, in <cell line: 0>
    opt_bar_fullgraph(torch.randn(10), torch.randn(10))
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 841, in compile_wrapper
    raise e.with_traceback(None) from e.__cause__  # User compiler error
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch._dynamo.exc.Unsupported: Data-dependent branching
  Explanation: Detected data-dependent branching (e.g. `if my_tensor.sum() > 0:`). Dynamo does not support tracing dynamic control flow.
  Hint: This graph break is fundamental - it is unlikely that Dynamo will ever be able to trace through your code. Consider finding a workaround.
  Hint: Use `torch.cond` to express dynamic control flow.

  Developer debug context: attempted to jump with TensorVariable()

 For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0170.html

from user code:
 

In [12]:
@torch.compile(fullgraph=True)
def bar_fixed(a, b):
    x = a / (torch.abs(a) + 1)

    def true_branch(y):
        return y * -1

    def false_branch(y):
        # NOTE: torch.cond doesn't allow aliased outputs
        return y.clone()

    b = cond(b.sum() < 0, true_branch, false_branch, (b,))
    return x * b


# Both branch outcomes now run through the single fullgraph-compiled function;
# only the second call's result is displayed by the notebook.
bar_fixed(inp1, inp2)
bar_fixed(inp1, -inp2)

V0209 06:28:42.417000 950 torch/_dynamo/output_graph.py:1983] [1/0] [__graph_code] TRACED GRAPH
V0209 06:28:42.417000 950 torch/_dynamo/output_graph.py:1983] [1/0] [__graph_code]  ===== __compiled_fn_35_8727a923_e9db_42be_9ff3_b856c774bdac =====
V0209 06:28:42.417000 950 torch/_dynamo/output_graph.py:1983] [1/0] [__graph_code]  /usr/local/lib/python3.12/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
V0209 06:28:42.417000 950 torch/_dynamo/output_graph.py:1983] [1/0] [__graph_code]     def forward(self, L_a_: "f32[10][1]cpu", L_b_: "f32[10][1]cpu"):
V0209 06:28:42.417000 950 torch/_dynamo/output_graph.py:1983] [1/0] [__graph_code]         l_a_ = L_a_
V0209 06:28:42.417000 950 torch/_dynamo/output_graph.py:1983] [1/0] [__graph_code]         l_b_ = L_b_
V0209 06:28:42.417000 950 torch/_dynamo/output_graph.py:1983] [1/0] [__graph_code]         
V0209 06:28:42.417000 950 torch/_dynamo/output_graph.py:1983] [1/0] [__graph_code]          # File: /tmp/ipython-

tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000])