In [1]:
import logging
import torch
import torch._dynamo
import torch._inductor
import torch.nn as nn
from torch.profiler import profile, record_function, ProfilerActivity

In [2]:
import os
# Optional TorchDynamo/Inductor debugging switches, left disabled. Uncomment
# individual lines to get more verbose compiler diagnostics for a run.
#os.environ['TORCH_COMPILE_DEBUG'] = "1"
#torch._dynamo.config.log_level = logging.DEBUG
#torch._dynamo.config.verbose = True
#torch._dynamo.config.log_level = logging.INFO
#torch._dynamo.config.output_code = True
# NOTE(review): presumably kept at 1 so that any recompilation is surfaced
# right away instead of quietly adding cache entries -- confirm the intent.
torch._dynamo.config.cache_size_limit = 1
# Tensors and modules created after this point without an explicit device
# are allocated on the GPU, so the cells below need a CUDA-capable machine.
torch.set_default_device('cuda')

In [3]:
from typing import List
import tabulate
def custom_backend(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    """Pass-through debugging backend: show the FX graph, then run it unchanged.

    The signature follows the ``(graph_module, example_inputs)`` shape used
    for ``torch.compile(backend=...)`` callables.

    Args:
        gm: FX graph module to inspect.
        example_inputs: Example tensors supplied with the graph; not used here.

    Returns:
        The bound ``gm.forward`` callable, i.e. the graph with no further
        transformation applied.
    """
    print("custom backend called with FX graph:")
    captured_graph = gm.graph
    # print_tabular() needs the `tabulate` package to render the table.
    captured_graph.print_tabular()
    # Alternative view: print(gm.code) shows the generated Python source.
    passthrough = gm.forward
    return passthrough

In [13]:
class SingleConvLayer(nn.Module):
    """Small CNN block: 3x3 convolution, batch norm, ReLU, then 2x2 max pooling."""

    def __init__(self):
        super().__init__()
        # 3 input channels -> 32 feature maps. A 3x3 kernel with stride 1 and
        # padding 1 keeps the spatial size unchanged.
        self.conv = nn.Conv2d(
            in_channels=3,
            out_channels=32,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        self.batch = nn.BatchNorm2d(32)
        self.relu = nn.ReLU()
        # 2x2 window with stride 2 halves height and width.
        self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, input):
        """Run conv -> batch norm -> ReLU -> max pool on ``input`` and return the result."""
        features = self.conv(input)
        normalized = self.batch(features)
        activated = self.relu(normalized)
        return self.max_pool(activated)

In [14]:
# Synthetic input batch of uniform-random values in [0, 1): 2048 samples with
# 3 channels each (matching the Conv2d input channels) at 32x32 resolution.
data = torch.rand((2048, 3, 32, 32))

In [19]:
# Plain, uncompiled instance kept as the eager-mode baseline.
eager_model = SingleConvLayer()

# Clear TorchDynamo's compiled state so re-running this cell compiles afresh.
torch._dynamo.reset()
# The compiled model wraps a second, separately constructed instance, so its
# randomly initialised weights are not shared with eager_model.
# fullgraph=True asks for the whole forward pass to be captured as one graph.
graph_model = torch.compile(
    SingleConvLayer(),
    fullgraph=True,
    backend="inductor",
)

In [20]:
# Profile a single forward pass, recording GPU-side activity only.
# The context manager calls start()/stop() for us and guarantees the profiler
# is stopped even if the forward pass raises, which the bare start()/stop()
# pair did not.
with profile(activities=[ProfilerActivity.CUDA]) as prof:
    # NOTE: torch.compile compiles lazily, so when this is the first call after
    # the model was (re)compiled the trace also covers compilation-time GPU work.
    graph_model(data)
    # To profile the uncompiled baseline instead, call: eager_model(data)

# Show the ten entries with the largest self CUDA time.
print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=10))

STAGE:2023-05-30 13:19:09 3433778:3433778 ActivityProfilerController.cpp:311] Completed Stage: Warm Up


-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
            cudnn_ampere_scudnn_128x32_relu_small_nn_v1         0.00%       0.000us         0.00%       0.000us       0.000us     401.000us        19.54%     401.000us     401.000us             1  
                                      triton__0d1d2d3d4         0.00%       0.000us         0.00%       0.000us       0.000us     349.000us        17.01%     349.000us     349.000us             1  
         

STAGE:2023-05-30 13:19:09 3433778:3433778 ActivityProfilerController.cpp:317] Completed Stage: Collection
STAGE:2023-05-30 13:19:09 3433778:3433778 ActivityProfilerController.cpp:321] Completed Stage: Post Processing


In [8]:
# Ask TorchDynamo how it traces the model on this input. This form of
# explain() returns a 6-tuple, unpacked here one field per line.
(
    explanation,
    out_guards,
    graphs,
    ops_per_graph,
    break_reasons,
    explanation_verbose,
) = torch._dynamo.explain(graph_model, data)
# The verbose report covers graph count, break reasons and compile metrics.
print(explanation_verbose)

Dynamo produced 1 graphs with 0 graph break and 0 ops
 Break reasons: 

1. return_value
  File "/tmp/ipykernel_3433778/2101468636.py", line 14, in forward
    return out
 
TorchDynamo compilation metrics:
Function                          Runtimes (s)
------------------------------  --------------
_compile                                0.0221
OutputGraph.call_user_compiler          0


In [9]:
# Dump the list of captured FX GraphModule objects produced by the
# torch._dynamo.explain() call (one entry per traced graph).
print(graphs)

[GraphModule(
  (self_conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (self_batch): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (self_relu): ReLU()
  (self_max_pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)]
