# Short tests to figure out how to use torch.compile

## Basic usage

Based on this [tutorial](https://docs.pytorch.org/tutorials/intermediate/torch_compile_tutorial.html).

```python

In [3]:
import torch

# Seed the global RNG so the random tensors created in this notebook are reproducible.
torch.manual_seed(0)
# Enable the `graph_code` log artifact; the "[__graph_code] TRACED GRAPH" lines
# in the recorded cell outputs come from this setting.
torch._logging.set_logs(graph_code=True)
# Small 3x3 float input shared by the first compiled examples.
random_input = torch.randn(3, 3)

In [4]:
def foo(x, y):
    """Return the elementwise sum ``sin(x) + cos(y)``.

    Args:
        x: Tensor passed to ``torch.sin``.
        y: Tensor passed to ``torch.cos``.

    Returns:
        Tensor holding ``torch.sin(x) + torch.cos(y)``.
    """
    return torch.sin(x) + torch.cos(y)


# Function-call form of torch.compile: wrap the existing `foo` and keep the
# returned callable under a new name; `foo` itself stays available as-is.
opt_foo1 = torch.compile(foo)
# The same tensor is passed for both arguments.
print(opt_foo1(random_input, random_input))


@torch.compile
def opt_foo2(x, y):
    """Return ``sin(x) + cos(y)`` elementwise (decorator form of ``torch.compile``).

    Args:
        x: Tensor passed to ``torch.sin``.
        y: Tensor passed to ``torch.cos``.

    Returns:
        Tensor holding ``torch.sin(x) + torch.cos(y)``.
    """
    return torch.sin(x) + torch.cos(y)


# Call the decorator-compiled variant with the same inputs that were given to opt_foo1.
print(opt_foo2(random_input, random_input))

V1224 15:01:16.498000 12061 torch/fx/passes/runtime_assert.py:118] [2/0] [__graph_code] TRACED GRAPH
V1224 15:01:16.498000 12061 torch/fx/passes/runtime_assert.py:118] [2/0] [__graph_code]  ===== pre insert_deferred_runtime_asserts __compiled_fn_5 =====
V1224 15:01:16.498000 12061 torch/fx/passes/runtime_assert.py:118] [2/0] [__graph_code]  <eval_with_key>.8 class GraphModule(torch.nn.Module):
V1224 15:01:16.498000 12061 torch/fx/passes/runtime_assert.py:118] [2/0] [__graph_code]     def forward(self, L_x_: "f32[3, 3]"):
V1224 15:01:16.498000 12061 torch/fx/passes/runtime_assert.py:118] [2/0] [__graph_code]         l_x_ = L_x_
V1224 15:01:16.498000 12061 torch/fx/passes/runtime_assert.py:118] [2/0] [__graph_code]         
V1224 15:01:16.498000 12061 torch/fx/passes/runtime_assert.py:118] [2/0] [__graph_code]          # File: /tmp/ipykernel_12061/1116440188.py:2 in foo, code: a = torch.sin(x)
V1224 15:01:16.498000 12061 torch/fx/passes/runtime_assert.py:118] [2/0] [__graph_code]        

tensor([[ 1.0294,  0.6680, -1.3920],
        [ 1.3811, -0.4167, -0.8139],
        [ 1.3123,  1.4123,  0.0935]])
tensor([[ 1.0294,  0.6680, -1.3920],
        [ 1.3811, -0.4167, -0.8139],
        [ 1.3123,  1.4123,  0.0935]])


In [5]:
def foo3(x):
    """Return ``relu(x + 1) * 2`` computed elementwise.

    Args:
        x: Input tensor.

    Returns:
        Tensor with the same shape as ``x``.
    """
    return torch.nn.functional.relu(x + 1) * 2


# Compiled counterpart of foo3; the timing code below compares it against the eager foo3.
opt_foo3 = torch.compile(foo3)


def timed(fn):
    """Run ``fn()`` once and return ``(result, elapsed_seconds)``.

    When CUDA is available the call is bracketed by two CUDA events and the
    device is synchronized before the elapsed time is read, which gives the
    most accurate measurement of GPU work. When CUDA is not available the
    function falls back to wall-clock timing with ``time.perf_counter`` so it
    can still be used on CPU-only machines.

    Args:
        fn: Zero-argument callable to execute and time.

    Returns:
        Tuple ``(result, seconds)``: the value returned by ``fn()`` and the
        elapsed time in seconds.
    """
    import time  # local import: only the CPU fallback needs it

    if not torch.cuda.is_available():
        t0 = time.perf_counter()
        result = fn()
        return result, time.perf_counter() - t0

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    result = fn()
    end.record()
    # Wait for the recorded events to complete before reading the elapsed time.
    torch.cuda.synchronize()
    # elapsed_time() reports milliseconds; convert to seconds.
    return result, start.elapsed_time(end) / 1000


# 4096x4096 random input moved to the GPU; this cell requires a CUDA device.
inp = torch.randn(4096, 4096).cuda()
# First call of the compiled function. In the recorded output it is far slower
# than the later calls, presumably because it includes compilation; confirm.
print("compile:", timed(lambda: opt_foo3(inp))[1])
print("eager:", timed(lambda: foo3(inp))[1])

V1224 15:02:56.304000 12061 torch/fx/passes/runtime_assert.py:118] [4/0] [__graph_code] TRACED GRAPH
V1224 15:02:56.304000 12061 torch/fx/passes/runtime_assert.py:118] [4/0] [__graph_code]  ===== pre insert_deferred_runtime_asserts __compiled_fn_9 =====
V1224 15:02:56.304000 12061 torch/fx/passes/runtime_assert.py:118] [4/0] [__graph_code]  <eval_with_key>.16 class GraphModule(torch.nn.Module):
V1224 15:02:56.304000 12061 torch/fx/passes/runtime_assert.py:118] [4/0] [__graph_code]     def forward(self, L_x_: "f32[4096, 4096]"):
V1224 15:02:56.304000 12061 torch/fx/passes/runtime_assert.py:118] [4/0] [__graph_code]         l_x_ = L_x_
V1224 15:02:56.304000 12061 torch/fx/passes/runtime_assert.py:118] [4/0] [__graph_code]         
V1224 15:02:56.304000 12061 torch/fx/passes/runtime_assert.py:118] [4/0] [__graph_code]          # File: /tmp/ipykernel_12061/2208011240.py:2 in foo3, code: y = x + 1
V1224 15:02:56.304000 12061 torch/fx/passes/runtime_assert.py:118] [4/0] [__graph_code]       

compile: 0.8350673217773438
eager: 0.0952995834350586


In [6]:
import numpy as np

# Turn off graph-code logging for now to prevent spam in the timing output.
torch._logging.set_logs(graph_code=False)

# Only the elapsed time is kept (`[1]`). Unpacking into `_` at notebook top
# level would shadow IPython's last-output variable and keep a reference to
# the large result tensor alive.
eager_times = []
for i in range(10):
    eager_time = timed(lambda: foo3(inp))[1]
    eager_times.append(eager_time)
    print(f"eager time {i}: {eager_time}")
print("~" * 10)

compile_times = []
for i in range(10):
    compile_time = timed(lambda: opt_foo3(inp))[1]
    compile_times.append(compile_time)
    print(f"compile time {i}: {compile_time}")
print("~" * 10)

# Compare medians so that a single slow run does not dominate the estimate.
eager_med = np.median(eager_times)
compile_med = np.median(compile_times)
speedup = eager_med / compile_med
assert speedup > 1, (
    f"expected the compiled function to be faster than eager, got speedup {speedup}x"
)
print(
    f"(eval) eager median: {eager_med}, compile median: {compile_med}, speedup: {speedup}x"
)
print("~" * 10)

eager time 0: 0.020702207565307617
eager time 1: 0.023580671310424805
eager time 2: 0.0192174072265625
eager time 3: 0.020316160202026368
eager time 4: 0.021369855880737306
eager time 5: 0.02023219108581543
eager time 6: 0.020067327499389647
eager time 7: 0.02007347106933594
eager time 8: 0.021485567092895508
eager time 9: 0.020282367706298828
~~~~~~~~~~
compile time 0: 0.009602047920227052
compile time 1: 0.006400000095367431
compile time 2: 0.006756351947784424
compile time 3: 0.007633920192718506
compile time 4: 0.007480319976806641
compile time 5: 0.007010303974151612
compile time 6: 0.00790118408203125
compile time 7: 0.006790143966674805
compile time 8: 0.006490111827850342
compile time 9: 0.00652185583114624
~~~~~~~~~~
(eval) eager median: 0.020299263954162598, compile median: 0.006900223970413209, speedup: 2.9418268220280694x
~~~~~~~~~~


In [7]:
def f1(x, y):
    """Return ``-y`` when the elements of ``x`` sum to a negative value, else ``y``.

    The branch depends on the *values* in ``x``, not only on its shape.

    Args:
        x: Tensor whose element sum selects the branch.
        y: Value that is returned, negated or unchanged.
    """
    if x.sum() < 0:
        result = -y
    else:
        result = y
    return result


# Test that `fn1` and `fn2` return the same result, given the same arguments `args`.
def test_fns(fn1, fn2, args):
    out1 = fn1(*args)
    out2 = fn2(*args)
    return torch.allclose(out1, out2)


inp1 = torch.randn(5, 5)
inp2 = torch.randn(5, 5)

# Trace f1 with (inp1, inp2) as example inputs, then compare it with eager f1
# on the same inputs and on a sign-flipped first input. The recorded output
# shows a mismatch for the flipped input, presumably because tracing only
# records the branch taken for the example inputs; confirm.
traced_f1 = torch.jit.trace(f1, (inp1, inp2))
print("traced 1, 1:", test_fns(f1, traced_f1, (inp1, inp2)))
print("traced 1, 2:", test_fns(f1, traced_f1, (-inp1, inp2)))

# Same two comparisons for torch.compile; the recorded output shows both matching eager.
compile_f1 = torch.compile(f1)
print("compile 1, 1:", test_fns(f1, compile_f1, (inp1, inp2)))
print("compile 1, 2:", test_fns(f1, compile_f1, (-inp1, inp2)))
print("~" * 10)

  if x.sum() < 0:


traced 1, 1: True
traced 1, 2: False
compile 1, 1: True
compile 1, 2: True
~~~~~~~~~~


In [8]:
import traceback as tb

# Re-enable graph-code logging, which an earlier cell switched off for the timing runs.
torch._logging.set_logs(graph_code=True)


def f2(x, y):
    """Return ``x + y``.

    The parameters carry no type annotations; the recorded TorchScript error
    shows that ``torch.jit.script`` then infers ``y`` to be a Tensor.
    """
    return x + y


inp1 = torch.randn(5, 5)
inp2 = 3  # plain Python int, not a tensor

# Script f2 and call it with the int argument. The recorded output shows this
# raising a RuntimeError because the un-annotated `y` is inferred as a Tensor.
script_f2 = torch.jit.script(f2)
try:
    script_f2(inp1, inp2)
except Exception:
    # Print the failure and carry on with the rest of the cell. `Exception`
    # replaces a bare `except:` so that KeyboardInterrupt and SystemExit are
    # no longer swallowed.
    tb.print_exc()

# The same call through torch.compile, compared against eager f2.
compile_f2 = torch.compile(f2)
print("compile 2:", test_fns(f2, compile_f2, (inp1, inp2)))
print("~" * 10)

Traceback (most recent call last):
  File "/tmp/ipykernel_12061/3652677659.py", line 15, in <module>
    script_f2(inp1, inp2)
RuntimeError: f2() Expected a value of type 'Tensor (inferred)' for argument 'y' but instead found type 'int'.
Inferred 'y' to be of type 'Tensor' because it was not annotated with an explicit type.
Position: 1
Value: 3
Declaration: f2(Tensor x, Tensor y) -> Tensor
Cast error details: Unable to cast 3 to Tensor
V1224 15:10:02.253000 12061 torch/fx/passes/runtime_assert.py:118] [8/0] [__graph_code] TRACED GRAPH
V1224 15:10:02.253000 12061 torch/fx/passes/runtime_assert.py:118] [8/0] [__graph_code]  ===== pre insert_deferred_runtime_asserts __compiled_fn_18 =====
V1224 15:10:02.253000 12061 torch/fx/passes/runtime_assert.py:118] [8/0] [__graph_code]  <eval_with_key>.31 class GraphModule(torch.nn.Module):
V1224 15:10:02.253000 12061 torch/fx/passes/runtime_assert.py:118] [8/0] [__graph_code]     def forward(self, L_x_: "f32[5, 5]"):
V1224 15:10:02.253000 12061 tor

compile 2: True
~~~~~~~~~~
