In [28]:
import os

import torch
from torch import nn
from triton.testing import do_bench
from torch.profiler import profile, record_function, ProfilerActivity

In [24]:
os.environ['TORCH_COMPILE_DEBUG'] = '1'

## Creating a basic model

In [25]:
class MLP_modified(nn.Module):
    """
    Squares each ReLU output so the forward pass contains extra element-wise ops whose optimization can be observed
    """
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(in_features=1, out_features=5)
        self.l2 = nn.Linear(in_features=5, out_features=1)

    def forward(self, x):
        x = self.l1(x).relu() ** 2
        return self.l2(x).relu() ** 2


In [58]:
device = 'cpu'
in_data = torch.arange(0.0,100.0,2).to(device).unsqueeze(dim=1)

## PyTorch Profiler
The PyTorch profiler is used as a context manager and accepts a number of parameters. Some of the most useful are:

- activities - a list of activities to profile:
     - ProfilerActivity.CPU - PyTorch operators, TorchScript functions and user-defined code labels (see record_function below);
     - ProfilerActivity.CUDA - on-device CUDA kernels;
     - ProfilerActivity.XPU - on-device XPU kernels;

- record_shapes - whether to record shapes of the operator inputs;
- profile_memory - whether to report amount of memory consumed by model’s Tensors;
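As a minimal sketch of how these parameters fit together, the snippet below always profiles CPU activity and adds CUDA only when a GPU is present. The 64×64 matrix multiply is an arbitrary stand-in workload, not part of the original notebook.

```python
import torch
from torch.profiler import profile, ProfilerActivity

# Always collect host-side operator timings; add on-device kernels only if a GPU exists
activities = [ProfilerActivity.CPU]
if torch.cuda.is_available():
    activities.append(ProfilerActivity.CUDA)

x = torch.randn(64, 64)  # arbitrary stand-in workload
with profile(activities=activities, record_shapes=True, profile_memory=True) as prof:
    y = x @ x

# Memory columns appear in the table because profile_memory=True
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))
```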

#### Comparing running on CPU vs. MPS

In [67]:
device = 'mps'
in_data = torch.arange(0.0,100.0,2).to(device).unsqueeze(dim=1)
model = MLP_modified().to(device)
with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
    with record_function("model_inference"):
        model(in_data)


Note that we can use the record_function context manager to label arbitrary code ranges with user-provided names (model_inference is used as a label in the example above). The profiler lets you check which operators were called while executing a code range wrapped in a profiler context manager. If multiple profiler ranges are active at the same time (e.g. in parallel PyTorch threads), each profiling context manager tracks only the operators of its corresponding range. The profiler also automatically profiles asynchronous tasks launched with torch.jit._fork and, in the case of a backward pass, the backward-pass operators launched by the backward() call.
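Labels can also be nested inside a model's forward pass to attribute time to individual parts of it. The sketch below is a hypothetical variant of MLP_modified (the names LabelledMLP, layer_1 and layer_2 are illustrative, not from the original) that wraps each layer in its own record_function range:

```python
import torch
from torch import nn
from torch.profiler import profile, record_function, ProfilerActivity


class LabelledMLP(nn.Module):
    """Hypothetical variant of MLP_modified with one labelled range per layer."""

    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(in_features=1, out_features=5)
        self.l2 = nn.Linear(in_features=5, out_features=1)

    def forward(self, x):
        with record_function("layer_1"):
            x = self.l1(x).relu() ** 2
        with record_function("layer_2"):
            x = self.l2(x).relu() ** 2
        return x


model = LabelledMLP()
in_data = torch.arange(0.0, 100.0, 2).unsqueeze(dim=1)

with profile(activities=[ProfilerActivity.CPU]) as prof:
    with record_function("model_inference"):  # outer label containing both layer labels
        model(in_data)

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
```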

In [68]:
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))

-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                         Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
              model_inference        15.63%       5.346ms       100.00%      34.210ms      34.210ms             1  
                 aten::linear        68.72%      23.510ms        68.74%      23.516ms      11.758ms             2  
                    aten::pow         7.04%       2.410ms         9.80%       3.354ms       1.677ms             2  
                   aten::relu         5.81%       1.986ms         5.83%       1.994ms     997.061us             2  
                   aten::item         1.88%     644.291us         2.76%     942.957us     471.479us             2  
    aten::_local_scalar_dense         0.87%     298.666us         0.87% 

Note the difference between self CPU time and total CPU time: operators can call other operators, and self CPU time excludes the time spent in those child operator calls, while total CPU time includes it. To sort by self CPU time instead, pass sort_by="self_cpu_time_total" to the table call. To get finer-grained results that include operator input shapes, pass group_by_input_shape=True to key_averages() (note: this requires running the profiler with record_shapes=True).
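A minimal sketch of both options, using a small nn.Sequential stand-in rather than the notebook's MLP_modified. Because self time is computed as total time minus the time of child calls, every row should satisfy CPU total ≥ Self CPU:

```python
import torch
from torch import nn
from torch.profiler import profile, ProfilerActivity

model = nn.Sequential(nn.Linear(1, 5), nn.ReLU(), nn.Linear(5, 1))  # stand-in model
in_data = torch.arange(0.0, 100.0, 2).unsqueeze(dim=1)

# record_shapes=True is required for group_by_input_shape=True below
with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
    model(in_data)

# Rank operators by the time spent in the operator itself, excluding children
print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=10))

# One row per (operator, input shapes) combination
print(prof.key_averages(group_by_input_shape=True).table(sort_by="cpu_time_total", row_limit=10))
```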

In [69]:
device = 'cpu'
in_data = torch.arange(0.0,100.0,2).to(device).unsqueeze(dim=1)
model = MLP_modified().to(device)
with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
    with record_function("model_inference"):
        model(in_data)
        
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))

----------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                  Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
----------------------  ------------  ------------  ------------  ------------  ------------  ------------  
       model_inference         4.93%     535.875us       100.00%      10.878ms      10.878ms             1  
          aten::linear         7.77%     845.667us        75.39%       8.200ms       4.100ms             2  
           aten::addmm        33.38%       3.631ms        40.68%       4.425ms       2.213ms             2  
               aten::t        25.28%       2.749ms        26.93%       2.930ms       1.465ms             2  
            aten::relu         4.70%     511.043us        16.41%       1.785ms     892.438us             2  
       aten::clamp_min        11.71%       1.274ms        11.71%       1.274ms     636.916us             2  
           aten::co
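Both tables above time a single first call on a freshly created model, so they likely include one-time setup costs, and with only ProfilerActivity.CPU enabled, asynchronous MPS work may not be fully reflected in the timings. A fairer comparison, sketched below under those assumptions, warms the model up with a few untimed calls and, on MPS, calls torch.mps.synchronize() inside the labelled range so queued GPU work is included. The helper names and the warm-up count of 3 are arbitrary choices, not part of the original notebook.

```python
import torch
from torch import nn
from torch.profiler import profile, record_function, ProfilerActivity


class MLP_modified(nn.Module):
    # Same architecture as the notebook's model
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(in_features=1, out_features=5)
        self.l2 = nn.Linear(in_features=5, out_features=1)

    def forward(self, x):
        x = self.l1(x).relu() ** 2
        return self.l2(x).relu() ** 2


def sync(device):
    # MPS executes asynchronously; wait for queued kernels before the range closes
    if device == "mps":
        torch.mps.synchronize()


def profile_after_warmup(device, warmup_iters=3):
    model = MLP_modified().to(device)
    in_data = torch.arange(0.0, 100.0, 2).to(device).unsqueeze(dim=1)
    # Untimed warm-up calls absorb one-time setup costs
    for _ in range(warmup_iters):
        model(in_data)
    sync(device)
    with profile(activities=[ProfilerActivity.CPU]) as prof:
        with record_function("model_inference"):
            model(in_data)
            sync(device)
    return prof


devices = ["cpu"] + (["mps"] if torch.backends.mps.is_available() else [])
for device in devices:
    prof = profile_after_warmup(device)
    print(device)
    print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))
```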

#### Memory Consumption Analysis

In [78]:
model = MLP_modified().to(device)
with profile(activities=[ProfilerActivity.CPU], profile_memory=True, record_shapes=True) as prof:
    model(in_data)

print(prof.key_averages().table(sort_by="cpu_memory_usage", row_limit=10))

----------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                  Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls  
----------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
          aten::linear         6.96%      36.333us        65.79%     343.199us     171.599us       1.17 Kb           0 b             2  
           aten::addmm        28.92%     150.871us        36.01%     187.828us      93.914us       1.17 Kb       1.17 Kb             2  
            aten::relu         3.63%      18.957us        10.60%      55.290us      27.645us       1.17 Kb           0 b             2  
       aten::clamp_min         6.96%      36.333us         6.96%      36.333us      18.167us       1.17 Kb       1.17 Kb             2  
             aten::pow        22.92%     
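In the table above, aten::addmm reports 1.17 Kb of self CPU memory across its two calls, while aten::linear shows the same 1.17 Kb under CPU Mem but 0 b under Self CPU Mem: the allocation happens inside the addmm child and is attributed to linear only through its total. That figure is consistent with the two layer outputs, 50×5 + 50×1 float32 values = 1,200 bytes ≈ 1.17 Kb. The sketch below sorts by self_cpu_memory_usage so that operators that allocate memory themselves come first, and reproduces that byte count; it redefines the notebook's model so it runs on its own.

```python
import torch
from torch import nn
from torch.profiler import profile, ProfilerActivity


class MLP_modified(nn.Module):
    # Same architecture as the notebook's model
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(in_features=1, out_features=5)
        self.l2 = nn.Linear(in_features=5, out_features=1)

    def forward(self, x):
        x = self.l1(x).relu() ** 2
        return self.l2(x).relu() ** 2


model = MLP_modified()
in_data = torch.arange(0.0, 100.0, 2).unsqueeze(dim=1)  # shape (50, 1)

with profile(activities=[ProfilerActivity.CPU], profile_memory=True, record_shapes=True) as prof:
    model(in_data)

# Operators that allocate memory themselves (not via children) come first
print(prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=10))

# Bytes for the two addmm outputs: (50 x 5) + (50 x 1) float32 values
bytes_per_float = torch.tensor([], dtype=torch.float32).element_size()
expected_addmm_bytes = (50 * 5 + 50 * 1) * bytes_per_float
print(expected_addmm_bytes, "bytes =", round(expected_addmm_bytes / 1024, 2), "Kb")
```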

#### Tracing Functionality

In [79]:
model = MLP_modified().to(device)
with profile(activities=[ProfilerActivity.CPU], profile_memory=True, record_shapes=True) as prof:
    model(in_data)

prof.export_chrome_trace("trace.json")

You can examine the sequence of profiled operators and CUDA/XPU kernels by opening the exported trace.json in the Chrome trace viewer (chrome://tracing).
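The exported file is plain JSON in the Chrome trace event format, so it can also be inspected programmatically. A minimal sketch, assuming a small nn.Sequential stand-in model and a temporary output path (both illustrative): operator spans are stored as complete events ("ph": "X") in the traceEvents list, with their durations in microseconds.

```python
import json
import os
import tempfile

import torch
from torch import nn
from torch.profiler import profile, ProfilerActivity

model = nn.Sequential(nn.Linear(1, 5), nn.ReLU(), nn.Linear(5, 1))  # stand-in model
in_data = torch.arange(0.0, 100.0, 2).unsqueeze(dim=1)

with profile(activities=[ProfilerActivity.CPU]) as prof:
    model(in_data)

trace_path = os.path.join(tempfile.gettempdir(), "trace.json")
prof.export_chrome_trace(trace_path)

with open(trace_path) as f:
    trace = json.load(f)

# Keep only complete events ("X") for ATen operators
op_events = [
    e for e in trace["traceEvents"]
    if e.get("ph") == "X" and e.get("name", "").startswith("aten::")
]
for e in op_events[:10]:
    print(e["name"], e.get("dur"))
```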

#### Examining stack traces

In [85]:
model = MLP_modified().to(device)
with profile(activities=[ProfilerActivity.CPU], profile_memory=True, record_shapes=True, with_stack=True) as prof:
    model(in_data)

# Print aggregated stats
print(prof.key_averages(group_by_stack_n=5).table(sort_by='cpu_time_total', row_limit=2))

----------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                  Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
----------------------  ------------  ------------  ------------  ------------  ------------  ------------  
          aten::linear         9.61%      41.212us        79.98%     343.017us     171.508us             2  
           aten::addmm        36.48%     156.466us        41.34%     177.299us      88.649us             2  
----------------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 428.857us
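With with_stack=True, key_averages(group_by_stack_n=5) aggregates rows by operator name plus the top 5 frames of the Python call stack, so the same operator reached from different source lines is reported separately. A minimal sketch that prints the recorded frames for aten::addmm, assuming a small nn.Sequential stand-in model; the frames are typically strings of the form "file(line): function", and the list may be empty if no Python frames were captured in a given environment.

```python
import torch
from torch import nn
from torch.profiler import profile, ProfilerActivity

model = nn.Sequential(nn.Linear(1, 5), nn.ReLU(), nn.Linear(5, 1))  # stand-in model
in_data = torch.arange(0.0, 100.0, 2).unsqueeze(dim=1)

# with_stack=True records the Python call stack of each operator
with profile(activities=[ProfilerActivity.CPU], with_stack=True) as prof:
    model(in_data)

# Rows are keyed by operator name plus its top 5 stack frames
for evt in prof.key_averages(group_by_stack_n=5):
    if evt.key == "aten::addmm":
        print(evt.key, "calls:", evt.count)
        for frame in evt.stack:
            print("   ", frame)
```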



#### Using profiler to analyse long-running jobs

The profiler assumes that a long-running job is composed of steps, numbered starting from zero. Consider a schedule created with torch.profiler.schedule(skip_first=10, wait=5, warmup=1, active=3, repeat=2). It defines the following sequence of actions for the profiler:

The skip_first parameter tells the profiler to ignore the first 10 steps (the default value of skip_first is zero).

After the first skip_first steps, the profiler starts executing profiler cycles.

Each cycle consists of three phases:

- idling (wait=5 steps): the profiler is not active;
- warming up (warmup=1 step): the profiler starts tracing, but the results are discarded; this phase exists because samples taken at the beginning of a trace are usually skewed by extra overhead;
- active tracing (active=3 steps): the profiler traces and records data.

An optional repeat parameter specifies an upper bound on the number of cycles. By default (zero value), the profiler keeps executing cycles for as long as the job runs.

Thus, with this schedule, the profiler skips the first 15 steps (10 from skip_first plus 5 from wait), spends the next step warming up, actively records the next 3 steps, skips another 5 steps, spends the next step warming up, and actively records another 3 steps. Because repeat=2 is specified, the profiler stops recording after the first two cycles.

At the end of each cycle, the profiler calls the specified on_trace_ready function and passes itself as the argument. This function processes the new trace, either by producing the table output or by saving the output to disk as a trace file.

To signal to the profiler that the next step has started, call prof.step(). The current profiler step is stored in prof.step_num.

Note that the cell below uses a simpler schedule (wait=1, warmup=1, active=3, with no skip_first and no repeat), so each cycle is 5 steps long and cycles repeat for the whole loop.
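Because torch.profiler.schedule returns a plain function that maps a step number to a ProfilerAction, the cycle arithmetic for the schedule described above (skip_first=10, wait=5, warmup=1, active=3, repeat=2) can be checked without running a model. A minimal sketch that evaluates it for the first 32 steps:

```python
from torch.profiler import schedule, ProfilerAction

# Schedule with the parameters discussed in the text
long_job_schedule = schedule(skip_first=10, wait=5, warmup=1, active=3, repeat=2)

# schedule(...) returns a function: step number -> ProfilerAction
actions = [long_job_schedule(step) for step in range(32)]

warmup_steps = [s for s, a in enumerate(actions) if a == ProfilerAction.WARMUP]
recorded_steps = [
    s for s, a in enumerate(actions)
    if a in (ProfilerAction.RECORD, ProfilerAction.RECORD_AND_SAVE)
]
# RECORD_AND_SAVE marks the last active step of a cycle; on_trace_ready fires after it
save_steps = [s for s, a in enumerate(actions) if a == ProfilerAction.RECORD_AND_SAVE]

print("warmup steps:  ", warmup_steps)    # [15, 24]
print("recorded steps:", recorded_steps)  # [16, 17, 18, 25, 26, 27]
print("save steps:    ", save_steps)      # [18, 27]
```

Every step from 28 onward maps to ProfilerAction.NONE because repeat=2 caps the number of cycles.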

In [92]:
device = 'cpu'
sort_by_keyword = "self_" + device + "_time_total"

def trace_handler(p):
    output = p.key_averages().table(sort_by=sort_by_keyword, row_limit=10)
    print(output)
    p.export_chrome_trace("/tmp/trace_" + str(p.step_num) + ".json")

with profile(
    activities=[ProfilerActivity.CPU],
    schedule=torch.profiler.schedule(
        wait=1,
        warmup=1,
        active=3),
    on_trace_ready=trace_handler
) as p:
    for idx in range(16):
        model(in_data)
        p.step()

----------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                  Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
----------------------  ------------  ------------  ------------  ------------  ------------  ------------  
         ProfilerStep*        53.63%     165.377us       100.00%     308.340us     102.780us             3  
           aten::addmm        14.89%      45.899us        20.53%      63.312us      10.552us             6  
               aten::t         7.17%      22.122us         9.19%      28.329us       4.721us             6  
       aten::clamp_min         6.46%      19.913us         6.46%      19.913us       3.319us             6  
             aten::pow         4.86%      14.997us         5.35%      16.497us       2.750us             6  
           aten::copy_         3.82%      11.789us         3.82%      11.789us       1.965us             6  
            aten::r
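The cell above uses schedule(wait=1, warmup=1, active=3) with no repeat, so each 5-step cycle ends with a save: during the 16-step loop, cycles complete after steps 4, 9 and 14, and trace_handler should run three times. Each printed table covers only the 3 active steps of its cycle, which is why ProfilerStep* shows 3 calls. A minimal sketch that counts the handler calls instead of printing tables; the stand-in model and counting_handler are hypothetical.

```python
import torch
from torch import nn
from torch.profiler import profile, schedule, ProfilerActivity

model = nn.Sequential(nn.Linear(1, 5), nn.ReLU(), nn.Linear(5, 1))  # stand-in model
in_data = torch.arange(0.0, 100.0, 2).unsqueeze(dim=1)

seen_steps = []  # p.step_num observed at each on_trace_ready call


def counting_handler(p):
    # Hypothetical handler: record when it fires instead of printing a table
    seen_steps.append(p.step_num)


with profile(
    activities=[ProfilerActivity.CPU],
    schedule=schedule(wait=1, warmup=1, active=3),
    on_trace_ready=counting_handler,
) as p:
    for _ in range(16):
        model(in_data)
        p.step()  # advance the schedule by one step

print(f"on_trace_ready fired {len(seen_steps)} times, at step_num {seen_steps}")
```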