In [None]:
import time

import torch
import torchvision

Initialize the DNN model (ResNet-50) and compile it with TorchInductor

In [None]:
# Load ResNet-50 with ImageNet-pretrained weights and move it to the GPU.
# `pretrained=True` is deprecated (torchvision >= 0.13) in favour of the
# `weights=` enum; IMAGENET1K_V1 is exactly the checkpoint that
# `pretrained=True` used to load, so the model is unchanged.
model = torchvision.models.resnet50(
    weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1
).to("cuda")

In [None]:
# Drop any graphs and guards cached by earlier torch.compile calls in this
# kernel, so re-running this cell always recompiles from a clean slate.
torch._dynamo.reset()

# Wrap the model with torch.compile (TorchInductor is the default backend).
# "trace.enabled" turns on Inductor's debug tracing, which writes intermediate
# compilation artifacts to disk. Compilation is lazy: nothing is compiled until
# the wrapped model is first called.
resnet50_compiled = torch.compile(model, options={"trace.enabled": True})

Set up the training inputs, loss function, and optimizer

In [None]:
# Synthetic data for this example: a single batch of 64 random 3x224x224
# images together with random ground-truth class indices.
torch.manual_seed(0)  # make the random batch reproducible across runs
batch_size = 64
num_classes = 1000  # width of ResNet-50's ImageNet classification head
inputs = torch.randn(batch_size, 3, 224, 224, device="cuda")
# CrossEntropyLoss expects targets either as integer class indices of shape
# (N,) or as class probabilities of shape (N, C). Plain torch.randn values are
# not valid probabilities (they can be negative and do not sum to 1), which
# makes the loss meaningless, so draw valid class indices instead.
labels = torch.randint(0, num_classes, (batch_size,), device="cuda")

# Loss function and optimizer. The torch.compile wrapper exposes the same
# parameter tensors as the original `model`, so Adam updates the real weights.
learning_rate = 0.001
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(resnet50_compiled.parameters(), lr=learning_rate)

Wrap `optimizer.step()` in `torch.compile()` so the parameter update is also compiled

In [None]:
def optimizer_step_fn(optimizer):
    """Build a torch.compile'd wrapper around ``optimizer.step()``.

    Args:
        optimizer: the optimizer whose ``step`` should be run through
            torch.compile.

    Returns:
        A zero-argument compiled callable; each call performs one
        ``optimizer.step()``. Compilation happens lazily on the first call.
    """
    def _step():
        optimizer.step()

    compile_options = {"trace.enabled": True}  # dump Inductor debug trace
    return torch.compile(_step, options=compile_options)


optimizer_step = optimizer_step_fn(optimizer)

(Unused) Alternative: wrap the forward pass and the loss computation into a single compiled graph

In [None]:
# Alternative (disabled): compile the forward pass *and* the loss computation
# together as one torch.compile region, rather than compiling only the model.
# To try it, uncomment the code below and, in the training-iteration cell,
# replace the separate forward/loss lines with
# `outputs, loss = forward(inputs, labels)`.
# NOTE(review): `forward_fn` is handed the already-compiled
# `resnet50_compiled`; passing the plain `model` may be what was intended, so
# the model is not wrapped in torch.compile twice — confirm.
# def forward_fn(model_fn, loss_fn):
#     '''Return torch.compile'd version of forward pass and loss calculation'''
#     def f(inputs, labels):
#         outputs = model_fn(inputs)
#         loss = loss_fn(outputs, labels)
#         return outputs, loss
#     return torch.compile(
#         f,
#         options={
#             "trace.enabled": True,
#         },
#     )
# forward = forward_fn(resnet50_compiled, criterion)

Run one training iteration

In [None]:
# Reset the gradients of all parameters managed by the optimizer, so this step
# does not accumulate on top of gradients from a previous backward pass.
optimizer.zero_grad()

# Forward pass through the compiled model. torch.compile compiles lazily, so
# the first execution of this cell also includes compilation time.
outputs = resnet50_compiled(inputs)
loss = criterion(outputs, labels) # criterion is torch.nn.CrossEntropyLoss()
# Alternative when the combined forward+loss cell above is enabled:
# outputs, loss = forward(inputs, labels)

# Backward pass: compute gradients of the loss w.r.t. the model parameters.
loss.backward()

# Parameter update: one Adam step through the torch.compile'd wrapper.
optimizer_step()

Example: profile a single ATen operator in isolation

In [None]:
# Micro-benchmark a single ATen operator outside the model.
# NOTE(review): the tensor names (`primals_*`) and the argument list appear to
# be copied from the TorchInductor trace of the compiled model above; the call
# matches ResNet-50's stem convolution (7x7, stride 2, padding 3) — confirm.
num_iter = 1000
num_warmup = 10  # untimed iterations to absorb one-off startup costs
device = "cuda"

# Allocate the dummy operands directly on the device instead of creating them
# on the host and copying them over.
primals_321 = torch.randn(64, 3, 224, 224, device=device)  # input (N, C_in, H, W)
primals_1 = torch.randn(64, 3, 7, 7, device=device)  # weight (C_out, C_in, kH, kW)

# aten.convolution(input, weight, bias, stride, padding, dilation,
#                  transposed, output_padding, groups)
# Bound once so the timed loop measures the op, not attribute lookups.
conv_op = torch.ops.aten.convolution.default
conv_args = (primals_321, primals_1, None, [2, 2], [3, 3], [1, 1], False, [0, 0], 1)

# Warm-up: the first launches pay one-off costs (library/handle
# initialisation, algorithm selection, caching-allocator growth) that would
# otherwise skew the steady-state average.
for _ in range(num_warmup):
    convolution = conv_op(*conv_args)

# CUDA kernels run asynchronously. Drain all pending work (tensor init above,
# earlier cells, warm-up) so the timer starts with an idle device.
torch.cuda.synchronize()
t0 = time.perf_counter()

# profile the operator
for _ in range(num_iter):
    convolution = conv_op(*conv_args)
# Wait for every timed kernel to finish; without this only launches are timed.
torch.cuda.synchronize()

t1 = time.perf_counter()

print(f"Time taken: {(t1 - t0) / num_iter * 1000} ms")