In [1]:
from tqdm import tqdm
from pprint import pprint
import json
import pandas as pd
import sys
sys.path.append('/home/ajassani/iLoveTrace')
from Trace2Tree.trace_to_tree import TraceToTree
from tree_perf import TreePerfAnalyzer

In [2]:


# path = "/content/pytorch_profile_rank7_step120.json"
path = "/home/ajassani/trace_data/pytorch_profile_gpt-3-large-segmentation_ddp_bfloat16_bs10_level1_rank0.json"
# path = '/home/ajassani/trace_data/wide_resnet101_2.json'
with open(path, 'r') as f:
    data = json.load(f)

events = data['traceEvents']
tree = TraceToTree(events)
tree.build_tree(add_python_func=False)
perf_analyzer = TreePerfAnalyzer(tree)


Building tree with add_python_func=False
Building CPU op tree with add_python_func=False


In [3]:
linear_root_nodes = [node for node in tree.cpu_root_nodes if tree.events_by_uid[node]['name'] == 'aten::linear']
print(f"Found {len(linear_root_nodes)} root nodes")

df_linear_ops = perf_analyzer.build_df_perf_metrics(linear_root_nodes, bwd=False)
perf_analyzer.summarize_df_perf_metrics(df_linear_ops, ['mean'])

Found 149 root nodes


Unnamed: 0,name,param: M,param: N,param: K,param: bias,GFLOPS_first,FLOPS/Byte_first,TFLOPS/s_mean,Kernel Time (µs)_sum,name_count
0,aten::linear,40960,1536,1536,True,193.336443,754.101394,261.342643,73298,98
2,aten::linear,40960,6144,1536,True,773.345772,1193.375429,364.628983,50908,24
1,aten::linear,40960,1536,6144,True,773.157028,1193.10114,438.162555,42408,24
3,aten::linear,41400,1536,1536,True,195.413299,754.249025,151.97503,2572,2
4,aten::linear,48000,64,1536,True,9.440256,61.381381,27.765459,340,1


In [4]:
linear_root_nodes = [node for node in tree.cpu_root_nodes if tree.events_by_uid[node]['name'] == 'aten::linear']
uid = linear_root_nodes[0]
event = tree.events_by_uid[uid]
perf_analyzer.compute_fwd_perf_metrics(event)

{'GFLOPS': 193.33644288,
 'Kernel Time (µs)': 778,
 'TFLOPS/s': 248.5044252956298,
 'FLOPS/Byte': 754.1013935319985,
 'param: M': 40960,
 'param: N': 1536,
 'param: K': 1536,
 'param: bias': True}

In [5]:
df_linear_ops

Unnamed: 0,cat,name,pid,tid,external_id,GFLOPS,Kernel Time (µs),TFLOPS/s,FLOPS/Byte,param: M,param: N,param: K,param: bias
0,cpu_op,aten::linear,17178,17178,9103,193.336443,778,248.504425,754.101394,40960,1536,1536,True
1,cpu_op,aten::linear,17178,17178,9126,195.413299,1301,150.202382,754.249025,41400,1536,1536,True
2,cpu_op,aten::linear,17178,17178,9149,195.413299,1271,153.747678,754.249025,41400,1536,1536,True
3,cpu_op,aten::linear,17178,17178,9190,193.336443,629,307.371133,754.101394,40960,1536,1536,True
4,cpu_op,aten::linear,17178,17178,9209,193.336443,593,326.031101,754.101394,40960,1536,1536,True
...,...,...,...,...,...,...,...,...,...,...,...,...,...
144,cpu_op,aten::linear,17178,17178,13525,193.336443,735,263.042779,754.101394,40960,1536,1536,True
145,cpu_op,aten::linear,17178,17178,13566,193.336443,650,297.440681,754.101394,40960,1536,1536,True
146,cpu_op,aten::linear,17178,17178,13600,773.345772,2120,364.785741,1193.375429,40960,6144,1536,True
147,cpu_op,aten::linear,17178,17178,13631,773.157028,1824,423.879949,1193.101140,40960,1536,6144,True


In [6]:
# linear bwd metrics
df_linear_bwd_ops = perf_analyzer.build_df_perf_metrics(linear_root_nodes, bwd=True)
perf_analyzer.summarize_df_perf_metrics(df_linear_bwd_ops, ['mean'])

Unnamed: 0,name,param: M,param: N,param: K,param: bias,GFLOPS_first,FLOPS/Byte_first,TFLOPS/s_mean,Kernel Time (µs)_sum,name_count
0,aten::linear,40960,1536,1536,True,386.609971,671.584699,212.946378,179200,98
2,aten::linear,40960,6144,1536,True,1546.439885,999.186992,325.697734,114028,24
1,aten::linear,40960,1536,6144,True,1546.251141,1137.824074,340.815388,109205,24
3,aten::linear,41400,1536,1536,True,390.763008,671.70173,237.908783,3292,2
4,aten::linear,48000,64,1536,True,18.87744,60.764576,27.279538,692,1


In [7]:
# flash attention metrics
fa_root_nodes = [node for node in tree.cpu_root_nodes if tree.events_by_uid[node]['name'] == 'FlashAttnFunc']
print(f"Found {len(fa_root_nodes)} root nodes")
df_fa_fwd_ops = perf_analyzer.build_df_perf_metrics(fa_root_nodes, bwd=False)
perf_analyzer.summarize_df_perf_metrics(df_fa_fwd_ops, ['mean'])

Found 25 root nodes


Unnamed: 0,name,param: B,param: N_Q,param: N_K,param: H,param: d_k,param: dropout,param: causal,param: flash_impl,GFLOPS_first,FLOPS/Byte_first,TFLOPS/s_mean,Kernel Time (µs)_sum,name_count
0,FlashAttnFunc,10,4096,4096,16,96,0.0,False,True,1030.792151,2048.0,303.081004,81653,24
1,FlashAttnFunc,10,4096,4140,16,96,0.0,False,True,1041.865114,2058.941234,286.305335,3639,1


In [8]:
# flash attention bwd metrics
df_fa_bwd_ops = perf_analyzer.build_df_perf_metrics(fa_root_nodes, bwd=True)
perf_analyzer.summarize_df_perf_metrics(df_fa_bwd_ops, ['mean'])

Unnamed: 0,name,param: B,param: N_Q,param: N_K,param: H,param: d_k,param: dropout,param: causal,param: flash_impl,GFLOPS_first,TFLOPS/s_mean,Kernel Time (µs)_sum,name_count
0,FlashAttnFunc,10,4096,4096,16,96,0.0,False,True,2576.980378,134.131422,461134,24
1,FlashAttnFunc,10,4096,4140,16,96,0.0,False,True,2604.662784,131.184225,19855,1


In [9]:
conv2d_root_nodes = [node for node in tree.cpu_root_nodes if tree.events_by_uid[node]['name'] == 'aten::conv2d']
print(f"Found {len(conv2d_root_nodes)} root nodes")

df_conv2d_ops = perf_analyzer.build_df_perf_metrics(conv2d_root_nodes, bwd=False)
perf_analyzer.summarize_df_perf_metrics(df_conv2d_ops, ['mean'])

Found 276 root nodes


Unnamed: 0,name,param: input_shape,param: filter_shape,param: stride,param: padding,param: dilation,param: groups,param: bias,param: transposed_conv,GFLOPS_first,FLOPS/Byte_first,TFLOPS/s_mean,Kernel Time (µs)_sum,name_count
28,aten::conv2d,"(10, 896, 59, 91)","(896, 896, 1, 1)","(1, 1)","(0, 0)","(1, 1)",1,False,False,86.206382,444.292733,206.430882,23920,57
27,aten::conv2d,"(10, 896, 59, 91)","(896, 56, 3, 3)","(1, 1)","(1, 1)","(1, 1)",16,False,False,48.49109,250.822736,97.645955,13426,27
17,aten::conv2d,"(10, 224, 469, 724)","(224, 56, 3, 3)","(2, 2)","(1, 1)","(1, 1)",4,False,False,192.081254,100.959884,45.274699,12729,3
20,aten::conv2d,"(10, 448, 118, 181)","(448, 56, 3, 3)","(1, 1)","(1, 1)","(1, 1)",8,False,False,96.449311,251.703019,82.544592,10522,9
14,aten::conv2d,"(10, 224, 235, 362)","(224, 224, 1, 1)","(1, 1)","(0, 0)","(1, 1)",1,False,False,85.369446,111.985256,95.731039,8075,9
13,aten::conv2d,"(10, 224, 235, 362)","(224, 56, 3, 3)","(1, 1)","(1, 1)","(1, 1)",4,False,False,192.081254,251.925373,72.977839,7897,3
21,aten::conv2d,"(10, 448, 118, 181)","(448, 448, 1, 1)","(1, 1)","(0, 0)","(1, 1)",1,False,False,85.732721,223.765318,245.430713,7350,21
24,aten::conv2d,"(10, 448, 235, 362)","(448, 56, 3, 3)","(2, 2)","(1, 1)","(1, 1)",8,False,False,96.449311,101.094983,46.897823,6170,3
15,aten::conv2d,"(10, 224, 235, 362)","(448, 224, 1, 1)","(1, 1)","(0, 0)","(1, 1)",1,False,False,170.738893,149.307124,133.117671,3848,3
3,aten::conv2d,"(10, 32, 469, 724)","(224, 32, 1, 1)","(1, 1)","(0, 0)","(1, 1)",1,False,False,48.678748,27.999769,39.907164,3662,3


In [11]:
# conv2d bwd 
df_conv2d_bwd_ops = perf_analyzer.build_df_perf_metrics(conv2d_root_nodes, bwd=True)
perf_analyzer.summarize_df_perf_metrics(df_conv2d_bwd_ops, ['mean'])

Unnamed: 0,name,param: input_shape,param: filter_shape,param: stride,param: padding,param: dilation,param: groups,param: bias,param: transposed_conv,GFLOPS_first,FLOPS/Byte_first,TFLOPS/s_mean,Kernel Time (µs)_sum,name_count
28,aten::conv2d,"(10, 896, 59, 91)","(896, 896, 1, 1)","(1, 1)","(0, 0)","(1, 1)",1,False,False,172.412764,444.292733,204.136473,48518,57
17,aten::conv2d,"(10, 224, 469, 724)","(224, 56, 3, 3)","(2, 2)","(1, 1)","(1, 1)",4,False,False,384.162509,100.959884,31.997011,36019,3
27,aten::conv2d,"(10, 896, 59, 91)","(896, 56, 3, 3)","(1, 1)","(1, 1)","(1, 1)",16,False,False,96.98218,250.822736,85.186251,30777,27
20,aten::conv2d,"(10, 448, 118, 181)","(448, 56, 3, 3)","(1, 1)","(1, 1)","(1, 1)",8,False,False,192.898621,251.703019,69.197994,25102,9
21,aten::conv2d,"(10, 448, 118, 181)","(448, 448, 1, 1)","(1, 1)","(0, 0)","(1, 1)",1,False,False,171.465441,223.765318,146.426581,24808,21
14,aten::conv2d,"(10, 224, 235, 362)","(224, 224, 1, 1)","(1, 1)","(0, 0)","(1, 1)",1,False,False,170.738893,111.985256,85.029022,18183,9
13,aten::conv2d,"(10, 224, 235, 362)","(224, 56, 3, 3)","(1, 1)","(1, 1)","(1, 1)",4,False,False,384.162509,251.925373,67.152075,17163,3
24,aten::conv2d,"(10, 448, 235, 362)","(448, 56, 3, 3)","(2, 2)","(1, 1)","(1, 1)",8,False,False,192.898621,101.094983,33.88152,17080,3
3,aten::conv2d,"(10, 32, 469, 724)","(224, 32, 1, 1)","(1, 1)","(0, 0)","(1, 1)",1,False,False,97.357496,27.999769,18.308031,15955,3
15,aten::conv2d,"(10, 224, 235, 362)","(448, 224, 1, 1)","(1, 1)","(0, 0)","(1, 1)",1,False,False,341.477786,149.307124,126.570726,8097,3


In [5]:
conv3d_root_nodes = [node for node in tree.cpu_root_nodes if tree.events_by_uid[node]['name'] == 'aten::conv3d']
print(f"Found {len(conv3d_root_nodes)} root nodes")

df_conv3d_ops = perf_analyzer.build_df_perf_metrics(conv3d_root_nodes, bwd=False)
perf_analyzer.summarize_df_perf_metrics(df_conv3d_ops, ['mean'])

Found 16 root nodes


Unnamed: 0,name,param: input_shape,param: filter_shape,param: stride,param: padding,param: dilation,param: groups,param: bias,param: transposed_conv,GFLOPS_first,FLOPS/Byte_first,TFLOPS/s_mean,Kernel Time (µs)_sum,name_count
5,aten::conv3d,"(10, 32, 1, 480, 640)","(16, 32, 1, 3, 3)","(1, 1, 1)","(0, 0, 0)","(1, 1, 1)",1,False,False,28.105482,95.530041,26.741658,1051,1
6,aten::conv3d,"(10, 64, 1, 60, 80)","(64, 64, 1, 3, 3)","(1, 1, 1)","(0, 0, 0)","(1, 1, 1)",1,False,False,3.335455,277.759018,31.546537,758,7
3,aten::conv3d,"(10, 32, 1, 240, 320)","(32, 32, 1, 3, 3)","(1, 1, 1)","(0, 0, 0)","(1, 1, 1)",1,False,False,13.950075,142.919099,42.518127,659,2
1,aten::conv3d,"(10, 16, 1, 480, 640)","(16, 16, 1, 3, 3)","(1, 1, 1)","(0, 0, 0)","(1, 1, 1)",1,False,False,14.052741,71.735324,21.356749,658,1
4,aten::conv3d,"(10, 32, 1, 480, 640)","(16, 32, 1, 1, 1)","(1, 1, 1)","(0, 0, 0)","(1, 1, 1)",1,False,False,3.145728,10.66663,6.241524,504,1
0,aten::conv3d,"(10, 16, 1, 480, 640)","(3, 16, 1, 3, 3)","(1, 1, 1)","(0, 0, 0)","(1, 1, 1)",1,False,False,2.634889,22.597151,5.333783,494,1
8,aten::conv3d,"(10, 64, 1, 120, 160)","(32, 64, 1, 3, 3)","(1, 1, 1)","(0, 0, 0)","(1, 1, 1)",1,False,False,6.872924,188.067305,43.49952,158,1
2,aten::conv3d,"(10, 32, 1, 120, 160)","(32, 32, 1, 3, 3)","(1, 1, 1)","(0, 0, 0)","(1, 1, 1)",1,False,False,3.436462,141.776473,29.882279,115,1
7,aten::conv3d,"(10, 64, 1, 120, 160)","(32, 64, 1, 1, 1)","(1, 1, 1)","(0, 0, 0)","(1, 1, 1)",1,False,False,0.786432,21.330963,9.252141,85,1


In [None]:
# conv3d bwd
df_conv3d_bwd_ops = perf_analyzer.build_df_perf_metrics(conv3d_root_nodes, bwd=True)
perf_analyzer.summarize_df_perf_metrics(df_conv3d_bwd_ops, ['mean'])