In [None]:
from tqdm import tqdm
from pprint import pprint
import json
import pandas as pd
import sys
sys.path.append('..')
from Trace2Tree.trace_to_tree import TraceToTree
from tree_perf import TreePerfAnalyzer

In [40]:
# replace by your profile path, it can be a single rank profile from a multi gpu run as well
path = '/home/ajassani/jan22_2025/AMD_Instinct_MI325X_llama3_70b_bf16_bsz_18_trace/quanta-cyxtera-r35a-3_172800.1736558806854893773.pt.trace.json'
with open(path, 'r') as f:
    data = json.load(f)

events = data['traceEvents']
tree = TraceToTree(events)
tree.build_tree(add_python_func=False)
perf_analyzer = TreePerfAnalyzer(tree)

Building tree with add_python_func=False
Building CPU op tree with add_python_func=False


In [41]:
# get breakdown of gpu timeline - busy time, idle time, communication time, etc
perf_analyzer.get_df_gpu_timeline()

Unnamed: 0,type,time ms,percent
0,busy_time,6521.458211,99.927717
1,computation_time,6318.257587,96.814092
2,exposed_communication_time,203.05789,3.111438
3,exposed_memcpy_time,0.142734,0.002187
4,idle_time,4.717306,0.072283
5,total_time,6526.175517,100.0


In [42]:
# table of all lowest-level CPU operations (from the call stack perspective)
# and the time they "induce" on the GPU
df_kernel_launchers = perf_analyzer.get_df_kernel_launchers()
df_kernel_launchers.round(2).head()

Unnamed: 0,name,total_direct_kernel_time,direct_kernel_count,Input Dims
0,aten::index_select,2665.0,1,"((128256, 8192), (), (73728,))"
1,triton_red_fused__to_copy_add_mean_mul_pow_rsq...,2286.47,1,"((18, 4096, 1), (18, 4096, 8192), (8192,), (18..."
2,aten::mm,24762.75,1,"((73728, 8192), (8192, 10240), (73728, 10240))"
3,triton_poi_fused_clone_1,1572.92,1,"((73728, 10240), (4096, 64, 2), (18, 4096, 64,..."
4,triton_poi_fused_clone_2,183.83,1,"((73728, 10240), (4096, 64, 2), (18, 4096, 8, ..."


In [43]:
# group by op name and summarize
# this gives an op wise breakdown of gpu time
df_kernel_launchers_summary = perf_analyzer.get_df_kernel_launchers_summary(df_kernel_launchers)
df_kernel_launchers_summary.round(2).head()

Unnamed: 0,name,total_direct_kernel_time_sum,Count,total_direct_kernel_time_ms,Percentage (%),Cumulative Percentage (%)
0,aten::mm,5576758.03,126,5576.76,88.73,88.73
1,flash_attn::_flash_attn_backward,213619.49,8,213.62,3.4,92.13
2,flash_attn::_flash_attn_forward,119653.38,8,119.65,1.9,94.04
3,aten::copy_,69959.87,4,69.96,1.11,95.15
4,triton_poi_fused_add_fill_mul_sigmoid_silu_sub_0,43190.06,8,43.19,0.69,95.84


In [44]:
# We can further get breakdown by shapes for a particular op
# We do this by filtering the name and then grouping by the input dims
df_kernel_launchers_summary_name_shapes = perf_analyzer.get_df_kernel_launchers_summary_by_shape(df_kernel_launchers, "aten::mm")
df_kernel_launchers_summary_name_shapes.round(2)

Unnamed: 0,Input Dims,Total Kernel Time (µs),Count,Mean Kernel Time (µs),Std Kernel Time (µs),Max Direct Kernel Count,Min Direct Kernel Count,Total Kernel Time (ms),Percentage (%),Cumulative Percentage (%)
0,"((73728, 8192), (8192, 28672), (73728, 28672))",1193258.36,24,49719.1,2422.12,1,1,1193.26,21.4,21.4
1,"((73728, 28672), (28672, 8192), (73728, 8192))",1162965.71,24,48456.9,4183.92,1,1,1162.97,20.85,42.25
2,"((28672, 73728), (73728, 8192), (28672, 8192))",748381.82,16,46773.86,369.05,1,1,748.38,13.42,55.67
3,"((73728, 8192), (8192, 128256))",493321.2,2,246660.6,310.83,1,1,493.32,8.85,64.52
4,"((8192, 73728), (73728, 28672), (8192, 28672))",467928.45,8,58491.06,6695.2,1,1,467.93,8.39,72.91
5,"((128256, 73728), (73728, 8192))",382835.1,2,191417.55,99.46,1,1,382.84,6.86,79.77
6,"((73728, 128256), (128256, 8192))",354791.3,2,177395.65,92.43,1,1,354.79,6.36,86.13
7,"((73728, 8192), (8192, 8192), (73728, 8192))",207449.51,16,12965.59,700.94,1,1,207.45,3.72,89.85
8,"((73728, 8192), (8192, 10240), (73728, 10240))",170397.23,8,21299.65,3409.7,1,1,170.4,3.06,92.91
9,"((10240, 73728), (73728, 8192), (10240, 8192))",145209.72,8,18151.22,96.42,1,1,145.21,2.6,95.51


In [45]:
# Roofline for ops
# currently we have GEMM, CONV fwd+bwd, FA
# many more coming soon

# Example 1 GEMM
gemm_events = [event for event in tree.events if event['name'] in ['aten::addmm', 'aten::mm']]
gemm_events_uids = [event['UID'] for event in gemm_events]
print(f"Found {len(gemm_events)} gemm events")

# take an example event and compute perf metrics
gemm_event = gemm_events[0]
print("Event dict:")
pprint(gemm_event)
print("Perf metrics dict:")
pprint(perf_analyzer.compute_perf_metrics(gemm_event))


Found 126 gemm events
Event dict:
{'UID': 109,
 'args': {'Concrete Inputs': ['', '', ''],
          'Ev Idx': 95,
          'External id': 96,
          'Input Dims': [[73728, 8192], [8192, 10240], [73728, 10240]],
          'Input Strides': [[8192, 1], [1, 8192], [10240, 1]],
          'Input type': ['c10::BFloat16', 'c10::BFloat16', 'c10::BFloat16'],
          'Record function id': 0},
 'cat': 'cpu_op',
 'children': [110, 52127, 52129, 52131, 52133],
 'direct_kernel_count': 1,
 'dur': 56.364,
 'name': 'aten::mm',
 'parent': 106,
 'ph': 'X',
 'pid': 172800,
 't_end': 926440231441.65,
 'tid': 172800,
 'total_direct_kernel_time': 24762.7509765625,
 'tree': True,
 'ts': 926440231385.286}
Perf metrics dict:
{'FLOPS/Byte': 4286.511627906977,
 'GFLOPS': 12369.50581248,
 'Kernel Time (µs)': 24762.7509765625,
 'Kernel sum Time (µs)': 24762.751,
 'TB/s': 0.11653314103635923,
 'TFLOPS/s': 499.5206640888775,
 'param: K': 8192,
 'param: M': 73728,
 'param: N': 10240,
 'param: bias': False}


In [46]:
# build table for compute perf metrics for all gemm events
df_gemm_ops = perf_analyzer.build_df_perf_metrics(gemm_events_uids, bwd=False, non_data_mov=True)
df_gemm_ops.head()

Unnamed: 0,cat,name,pid,tid,external_id,GFLOPS,Kernel Time (µs),Kernel sum Time (µs),TFLOPS/s,Non-Data-Mov Kernel Time (µs),Non-Data-Mov TFLOPS/s,FLOPS/Byte,TB/s,param: M,param: N,param: K,param: bias
0,cpu_op,aten::mm,172800,172800,96,12369.505812,24762.750977,24762.751,499.520664,24762.750977,499.520664,4286.511628,0.116533,73728,10240,8192,False
1,cpu_op,aten::mm,172800,172800,113,9895.60465,13167.574951,13167.575,751.513068,13167.574951,751.513068,3880.421053,0.193668,73728,8192,8192,False
2,cpu_op,aten::mm,172800,172800,116,34634.616275,53424.876953,53424.877,648.286309,53424.876953,648.286309,5864.727273,0.11054,73728,28672,8192,False
3,cpu_op,aten::mm,172800,172800,118,34634.616275,53317.717041,53317.717,649.589258,53317.717041,649.589258,5864.727273,0.110762,73728,28672,8192,False
4,cpu_op,aten::mm,172800,172800,121,34634.616275,55024.0,55024.0,629.445629,55024.0,629.445629,5864.727273,0.107327,73728,8192,28672,False


In [47]:
# summarize by grouping across params M K N and bias and computing aggregate metrics
perf_analyzer.summarize_df_perf_metrics(df_gemm_ops, ['mean'])

Unnamed: 0,name,param: M,param: N,param: K,param: bias,GFLOPS_first,FLOPS/Byte_first,TB/s_mean,TFLOPS/s_mean,Non-Data-Mov TFLOPS/s_mean,Non-Data-Mov Kernel Time (µs)_sum,Kernel Time (µs)_sum,name_count
0,aten::mm,73728,28672,8192,False,34634.616275,5864.727273,0.11905,698.194193,698.194193,1193258.0,1193258.0,24
1,aten::mm,73728,8192,28672,False,34634.616275,5864.727273,0.122698,719.592525,719.592525,1162966.0,1162966.0,24
2,aten::mm,28672,8192,73728,False,34634.616275,5864.727273,0.126265,740.512598,740.512598,748381.8,748381.8,16
3,aten::mm,73728,128256,8192,False,154928.060301,6972.01359,0.090089,628.102683,628.102683,493321.2,493321.2,2
4,aten::mm,8192,28672,73728,False,34634.616275,5864.727273,0.102298,599.949658,599.949658,467928.5,467928.5,8
5,aten::mm,128256,8192,73728,False,154928.060301,6972.01359,0.116089,809.372403,809.372403,382835.1,382835.1,2
6,aten::mm,73728,8192,128256,False,154928.060301,6972.01359,0.125265,873.347696,873.347696,354791.3,354791.3,2
7,aten::mm,73728,8192,8192,False,9895.60465,3880.421053,0.197245,765.392934,765.392934,207449.5,207449.5,16
8,aten::mm,73728,10240,8192,False,12369.505812,4286.511628,0.138983,595.753241,595.753241,170397.2,170397.2,8
9,aten::mm,10240,8192,73728,False,12369.505812,4286.511628,0.158984,681.486622,681.486622,145209.7,145209.7,8


In [48]:
# Example 2a FA fwd
fa_root_nodes = [node for node in tree.cpu_root_nodes if tree.events_by_uid[node]['name'] == 'FlashAttnFunc']
df_fa_fwd_ops = perf_analyzer.build_df_perf_metrics(fa_root_nodes, bwd=False, non_data_mov=True)
perf_analyzer.summarize_df_perf_metrics(df_fa_fwd_ops, ['mean'])

KeyError: 'name'

In [None]:
# Example 2b FA bwd
# Note: bwd events for a fwd pass event are found 
# by traversing the autograd links
df_fa_bwd_ops = perf_analyzer.build_df_perf_metrics(fa_root_nodes, bwd=True, non_data_mov=True)
perf_analyzer.summarize_df_perf_metrics(df_fa_bwd_ops, ['mean'])

In [None]:
# Example 3a conv2d fwd
conv2d_root_nodes = [node for node in tree.cpu_root_nodes if tree.events_by_uid[node]['name'] == 'aten::conv2d']

df_conv2d_ops = perf_analyzer.build_df_perf_metrics(conv2d_root_nodes, bwd=False, non_data_mov=True)
perf_analyzer.summarize_df_perf_metrics(df_conv2d_ops, ['mean'])

In [None]:
# Example 3b conv2d bwd 
df_conv2d_bwd_ops = perf_analyzer.build_df_perf_metrics(conv2d_root_nodes, bwd=True, non_data_mov=True)
perf_analyzer.summarize_df_perf_metrics(df_conv2d_bwd_ops, ['mean'])