In [1]:
import time
import torch
import torch._dynamo as dynamo
import torchvision.models as models

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def foo(x, y):
    a = torch.sin(x)
    b = torch.cos(x)
    return a + b

In [3]:
compiled_model = torch.compile(foo)
start = time.time()
out = compiled_model(torch.randn(10, 10), torch.randn(10, 10))
end = time.time()
print(end - start)

3.6635050773620605


In [4]:
opt_foo1 = dynamo.optimize("inductor")(foo)
start = time.time()
result = opt_foo1(torch.randn(10, 10), torch.randn(10, 10))
end = time.time()
print(end - start)

0.03625226020812988


In [5]:
start = time.time()
result = foo(torch.randn(10, 10), torch.randn(10, 10))
end = time.time()
print(end - start)

0.0009148120880126953


In [6]:
model = models.alexnet()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
compiled_model = torch.compile(model)
x = torch.randn(16, 3, 224, 224)
optimizer.zero_grad()

start = time.time()
out = compiled_model(x)
out.sum().backward()
optimizer.step()
end = time.time()
print(end - start)

1.5946545600891113


In [7]:
start = time.time()
out = model(x)
out.sum().backward()
optimizer.step()
end = time.time()
print(end - start)

0.21536564826965332


In [8]:
count = []
for epoch in range(10):
    start = time.time()
    out = model(x)
    out.sum().backward()
    optimizer.step()
    end = time.time()
    count.append(end - start)
    print('Finished Training:', end - start)
print(sum(count)/len(count))

Finished Training: 0.19880270957946777
Finished Training: 0.19921612739562988
Finished Training: 0.19353675842285156
Finished Training: 0.19151687622070312
Finished Training: 0.19006848335266113
Finished Training: 0.19156360626220703
Finished Training: 0.19502758979797363
Finished Training: 0.19385814666748047
Finished Training: 0.1936194896697998
Finished Training: 0.20826220512390137
0.19554719924926758


In [9]:
count = []
for epoch in range(10):
    start = time.time()
    out = compiled_model(x)
    out.sum().backward()
    optimizer.step()
    end = time.time()
    count.append(end - start)
    print('Finished Training:', end - start)
print(sum(count)/len(count))

Finished Training: 0.19549965858459473
Finished Training: 0.1931750774383545
Finished Training: 0.19072818756103516
Finished Training: 0.18522953987121582
Finished Training: 0.1881251335144043
Finished Training: 0.18946194648742676
Finished Training: 0.18743181228637695
Finished Training: 0.1903071403503418
Finished Training: 0.1918036937713623
Finished Training: 0.20378661155700684
0.19155488014221192


In [10]:
from torch.fx import symbolic_trace, GraphModule

In [13]:
traced = symbolic_trace(foo)
print(traced.graph)

graph():
    %x : [num_users=2] = placeholder[target=x]
    %y : [num_users=0] = placeholder[target=y]
    %sin : [num_users=1] = call_function[target=torch.sin](args = (%x,), kwargs = {})
    %cos : [num_users=1] = call_function[target=torch.cos](args = (%x,), kwargs = {})
    %add : [num_users=1] = call_function[target=operator.add](args = (%sin, %cos), kwargs = {})
    return add
