torch.compile example that demonstrates how to use torch.compile for inference

In [4]:
import torch
def newfn(x):
    """Return ``sin(cos(x))``, evaluated element-wise with torch ops."""
    return torch.sin(torch.cos(x))

# Sample input: a 1-D tensor of 10,000 normally distributed values.
input_tensor = torch.randn(10000)

# Wrap the eager function with the "inductor" backend, then call the wrapper
# exactly like the original function.
new_fn = torch.compile(newfn, backend="inductor")
a = new_fn(input_tensor)

# The compiled function returns one value per input element.
print(a.shape[0])

10000


In [7]:
torch.compiler.list_backends()

['cudagraphs', 'inductor', 'openxla', 'tvm']

In [None]:
!pip install torch==2.2.0 torchvision==0.17.0

# Example: Optimizing a model

In [None]:
import torch
# `pretrained=True` is deprecated in torchvision (>= 0.13); requesting the
# "IMAGENET1K_V1" weights loads the same checkpoint without the warning.
model = torch.hub.load('pytorch/vision', 'resnet50', weights='IMAGENET1K_V1')
model.eval()  # inference: BatchNorm must use its running statistics

opt_model = torch.compile(model, backend="inductor")

# Autograd bookkeeping is not needed for inference. Show the output shape
# instead of dumping every logit into the notebook.
with torch.no_grad():
    logits = opt_model(torch.randn(1, 3, 64, 64))
logits.shape

# Example: Optimizing a pretrained model

In [None]:
import torch
from transformers import BertTokenizer, BertModel


# Load the tokenizer and the pretrained encoder, then compile the encoder.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')
model.eval()  # make inference mode explicit (disables dropout)
model = torch.compile(model, backend="inductor")

text = "Hello Shravani How Are you?"
encode_input = tokenizer(text, return_tensors='pt')

# Inference only: skip autograd bookkeeping.
with torch.no_grad():
    output = model(**encode_input)

# Print the shapes rather than the full tensors to keep the cell output readable.
print(output.last_hidden_state.shape)
print(output.pooler_output.shape)

# 1. Dynamo Tracing
Dynamo traces a Python function and generates a computation graph (an FX graph) from it.

🔍 What This Does Internally
Dynamo intercepts execution of f(x)

Wraps input into ProxyTensor

Symbolically executes bytecode

Builds an FX graph

Returns an explanation object

In [4]:
import torch
import torch._dynamo as dynamo

def f(x):
    """Apply ``torch.cos`` to ``x`` and then ``torch.sin`` to that result."""
    # The intermediate names `y` and `z` are kept as-is: the graph dump printed
    # by this cell shows nodes carrying these names.
    y = torch.cos(x)
    z = torch.sin(y)
    return z

x = torch.randn(5)

# dynamo.explain(f) returns a callable; calling it with sample inputs yields
# an explanation object whose `graphs` attribute lists the captured graphs.
explanation = dynamo.explain(f)(x)
captured_graphs = explanation.graphs
gm = captured_graphs[0]   # GraphModule

# One row per node: operation kind, node name, and call target.
print("Graph Structure:")
for node in gm.graph.nodes:
    row = f"{node.op:15} | {node.name:10} | target={node.target}"
    print(row)

Graph Structure:
placeholder     | l_x_       | target=L_x_
call_function   | y          | target=<built-in method cos of type object at 0x7da47b2e4b40>
call_function   | z          | target=<built-in method sin of type object at 0x7da47b2e4b40>
output          | output     | target=output


# 2. Graph Break
Graph breaks split your program into multiple smaller compiled pieces, reducing optimization and sometimes slowing down torch.compile.

In [6]:
import torch
import torch._dynamo as dynamo

def f(x):
    """Double ``x``, save the doubled tensor to ``temp.pt``, and return its sine.

    The ``torch.save`` call sits between the two tensor operations; it is the
    statement this example uses to split the function into two graphs.
    """
    y = x * 2
    # Dynamo does not trace into torch.save (see the break reason printed by
    # this cell). Side effect: writes "temp.pt" to the current working directory.
    torch.save(y, "temp.pt")  # Unsupported operation
    z = torch.sin(y)
    return z

x = torch.randn(4)

# Use the explain(f)(*args) form, consistent with the other cells; the older
# explain(f, *args) form is deprecated and emits a FutureWarning.
explanation = dynamo.explain(f)(x)

print("Number of FX graphs:", len(explanation.graphs))

# Iterate over every captured graph instead of hard-coding graphs[0] and
# graphs[1], so the cell still works if the number of graph breaks changes.
for index, graph_module in enumerate(explanation.graphs, start=1):
    print(f"\nFX Graph {index}:")
    print(graph_module)

print("\nGraph Break Reasons:")
for reason in explanation.break_reasons:
    print(reason)

Number of FX graphs: 2

FX Graph 1:
GraphModule()



def forward(self, L_x_ : torch.Tensor):
    l_x_ = L_x_
    y = l_x_ * 2;  l_x_ = None
    return (y,)
    
# To see more debug info, please use `graph_module.print_readable()`

FX Graph 2:
GraphModule()



def forward(self, L_y_ : torch.Tensor):
    l_y_ = L_y_
    z = torch.sin(l_y_);  l_y_ = None
    return (z,)
    
# To see more debug info, please use `graph_module.print_readable()`

Graph Break Reasons:
GraphCompileReason(reason='Attempted to call function marked as skipped\n  Explanation: Dynamo developers have intentionally marked that the function `save` in file `/usr/local/lib/python3.12/dist-packages/torch/serialization.py` should not be traced.\n  Hint: Avoid calling the function `save`.\n  Hint: Apply `@torch._dynamo.dont_skip_tracing` to the function `save` to force tracing into the function. More graph breaks may occur as a result of attempting to trace into the function.\n  Hint: Please file an issue to PyTorch.\n\n  

# 3. Guards
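While tracing, Dynamo makes assumptions about the inputs and global state (for example tensor dtype and shape, or whether grad mode is enabled). For each assumption it generates a guard: a runtime check that must pass before the compiled graph is reused.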

In [10]:
import torch
import torch._dynamo as dynamo

def f(x):
    """Multiply ``x`` by 2 and return the sine of the product."""
    # The intermediate names `y` and `z` are kept as-is: they appear as node
    # names in the FX graph printed by this cell.
    y = x * 2
    z = torch.sin(y)
    return z

x = torch.randn(4)

# Trace `f` on a sample input and collect the captured graphs and guards.
explanation = dynamo.explain(f)(x)
captured_graphs = explanation.graphs

print("Number of FX graphs:", len(captured_graphs))

print("\nFX Graph:")
first_graph = captured_graphs[0].graph
print(first_graph)

# Each guard is printed using its own string representation.
print("\nOut Guards:")
for guard in explanation.out_guards:
    print(guard)

Number of FX graphs: 1

FX Graph:
graph():
    %l_x_ : torch.Tensor [num_users=1] = placeholder[target=L_x_]
    %y : [num_users=1] = call_function[target=operator.mul](args = (%l_x_, 2), kwargs = {})
    %z : [num_users=1] = call_function[target=torch.sin](args = (%y,), kwargs = {})
    return (z,)

Out Guards:
Name: ''
    Source: shape_env
    Create Function: SHAPE_ENV
    Guard Types: None
    Code List: None
    Object Weakref: None
    Guarded Class Weakref: None

Name: ''
    Source: global
    Create Function: DETERMINISTIC_ALGORITHMS
    Guard Types: None
    Code List: None
    Object Weakref: None
    Guarded Class Weakref: None

Name: ''
    Source: global
    Create Function: GRAD_MODE
    Guard Types: None
    Code List: None
    Object Weakref: None
    Guarded Class Weakref: None

Name: ''
    Source: global
    Create Function: DEFAULT_DEVICE
    Guard Types: ['DEFAULT_DEVICE']
    Code List: ['utils_device.CURRENT_DEVICE == None']
    Object Weakref: None
    Guarded

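# 4. Recompilations
When a guard fails on a later call (below, the input dtype changes from float32 to float64), the compiled code cannot be reused and the function is recompiled for the new input.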

In [11]:
import torch

@torch.compile
def fn(x):
    """Return ``x`` multiplied by 2."""
    doubled = x * 2
    return doubled
# Log each recompilation together with the guard failure that triggered it,
# so the recompile this section is about is actually visible in the output.
torch._logging.set_logs(recompiles=True)

# First call (float32): `fn` is compiled for float32 inputs.
print(fn(torch.ones(3, 3, dtype=torch.float32)))

# Second call (float64): the input dtype differs from the first call, so the
# existing compiled code is not reused and `fn` is compiled again.
print(fn(torch.ones(3, 3, dtype=torch.float64)))

tensor([[2., 2., 2.],
        [2., 2., 2.],
        [2., 2., 2.]])
tensor([[2., 2., 2.],
        [2., 2., 2.],
        [2., 2., 2.]], dtype=torch.float64)
