In [1]:
from tqdm import tqdm
from pprint import pprint
import json
import pandas as pd
import sys
sys.path.append('/home/ajassani/iLoveTrace')
from Trace2Tree.trace_to_tree import TraceToTree
from tree_perf import TreePerfAnalyzer

In [9]:


# Profiler trace to analyse. Traces used previously are kept here for
# reference; point `path` at one of them to switch:
#   "/content/pytorch_profile_rank7_step120.json"
#   "/home/ajassani/trace_data/pytorch_profile_gpt-3-large-segmentation_ddp_bfloat16_bs10_level1_rank0.json"
# NOTE(review): cells further down look up names such as
# 'nn.Module: DistributedDataParallel_0' and 'FlashAttnFunc', which look like
# they belong to the segmentation trace rather than this one -- confirm which
# trace should be loaded before doing a Restart & Run All.
path = '/home/ajassani/trace_data/wide_resnet101_2.json'

# Parse the whole trace file into memory.
with open(path) as trace_file:
    data = json.load(trace_file)

# Build the event tree from the raw trace events (python-function events are
# not added) and wrap it in the perf analyzer used by the rest of the notebook.
events = data['traceEvents']
tree = TraceToTree(events)
tree.build_tree(add_python_func=False)
perf_analyzer = TreePerfAnalyzer(tree)


Building tree with add_python_func=False
Building CPU op tree with add_python_func=False


In [None]:
# Peek at the first few raw trace events only: `events` holds every entry of
# the trace's 'traceEvents' list, and displaying it whole would flood the
# cell output.
events[:5]

In [11]:
# UIDs of the CPU root nodes whose event is named 'aten::linear'.
root_nodes = [
    root_uid
    for root_uid in tree.cpu_root_nodes
    if tree.events_by_uid[root_uid]['name'] == 'aten::linear'
]

# Print the forward perf-metrics dict of every matching root node.
for uid in root_nodes:
    event = tree.events_by_uid[uid]
    result_dict = perf_analyzer.compute_fwd_perf_metrics(event)
    print(result_dict)

{'GFLOPS': 0.004097, 'Kernel Time (µs)': 27, 'TFLOPS/s': 0.15174074074074073, 'Non-Data-Mov Kernel Time (µs)': 27, 'Non-Data-Mov TFLOPS/s': 0.15174074074074073, 'FLOPS/Byte': 0.9982709956102391, 'param: M': 1, 'param: N': 1000, 'param: K': 2048, 'param: bias': True}


In [3]:
def _find_event(name):
    """Return the first event in ``tree.events`` whose 'name' equals ``name``.

    Raises:
        LookupError: if the loaded trace has no event with that name. This
            replaces the bare ``StopIteration`` an exhausted ``next(...)``
            would raise, so a wrong trace or module name is easy to diagnose.
    """
    for candidate in tree.events:
        if candidate['name'] == name:
            return candidate
    raise LookupError(f"No event named {name!r} found in the loaded trace")


def _fwd_bwd_kernel_time(event):
    """Return ``(fwd_time, bwd_time)`` aggregated kernel times for ``event``.

    The forward time is aggregated over the event's subtree; the backward
    time is aggregated over the events listed in ``event['bwd_events']``.
    """
    fwd_time, _ = perf_analyzer.agg_kernels_in_subtree(event['UID'])
    # 'bwd_events' is only read after linking, so keep this order.
    tree.link_bwd_events(event['UID'])
    bwd_time, _ = perf_analyzer.loop_and_aggregate_kernels(event['bwd_events'])
    return fwd_time, bwd_time


# Full model: the DistributedDataParallel wrapper is used as the reference
# for the 'percent_total' column below.
event = _find_event('nn.Module: DistributedDataParallel_0')
fwd_kernel_time, bwd_kernel_time = _fwd_bwd_kernel_time(event)
total_kernel_time = fwd_kernel_time + bwd_kernel_time

# (1) conv backbone part
event = _find_event('model/model.py(173): backbone_forward')
fwd_conv_backbone_kernel_time, bwd_conv_backbone_kernel_time = _fwd_bwd_kernel_time(event)
total_conv_backbone_kernel_time = fwd_conv_backbone_kernel_time + bwd_conv_backbone_kernel_time

# (2) Transformer part: the cross-attention module plus 24 encoder modules.
transformer_node_names = ['nn.Module: CrossAttention_0']
transformer_node_names.extend([f'nn.Module: TransformerEncoder_{i}' for i in range(24)])
# Keep the matching event objects so tree.events is scanned only once; their
# 'bwd_events' entries are read from these same objects after linking.
transformer_events = [ev for ev in tree.events if ev['name'] in transformer_node_names]
transformer_node_uids = [ev['UID'] for ev in transformer_events]
fwd_transformer_kernel_time, _ = perf_analyzer.loop_and_aggregate_kernels(transformer_node_uids)
# Link the backward events of every transformer module, then flatten the
# per-module lists into one list of backward event UIDs.
for event_uid in transformer_node_uids:
    tree.link_bwd_events(event_uid)
bwd_events_uids = [bwd_uid for ev in transformer_events for bwd_uid in ev['bwd_events']]
bwd_transformer_kernel_time, _ = perf_analyzer.loop_and_aggregate_kernels(bwd_events_uids)
total_transformer_kernel_time = fwd_transformer_kernel_time + bwd_transformer_kernel_time

# (3) Decoder part
event = _find_event('nn.Module: Decoder_0')
fwd_decoder_kernel_time, bwd_decoder_kernel_time = _fwd_bwd_kernel_time(event)
total_decoder_kernel_time = fwd_decoder_kernel_time + bwd_decoder_kernel_time

# One row per (component, direction): kernel time divided by 1000 for the
# '(ms)' column, and its share of the full-model kernel time in percent.
component_kernel_times = [
    ('Conv Backbone fwd', fwd_conv_backbone_kernel_time),
    ('Conv Backbone bwd', bwd_conv_backbone_kernel_time),
    ('Transformer fwd', fwd_transformer_kernel_time),
    ('Transformer bwd', bwd_transformer_kernel_time),
    ('Decoder fwd', fwd_decoder_kernel_time),
    ('Decoder bwd', bwd_decoder_kernel_time),
]
list_model_level_metrics = [
    {
        'name': label,
        'kernel_time (ms)': kernel_time / 1000,
        'percent_total': round(kernel_time / total_kernel_time * 100, 2),
    }
    for label, kernel_time in component_kernel_times
]

# put into df
df_model_level_metrics = pd.DataFrame(list_model_level_metrics)
display(df_model_level_metrics) # Use display for better formatting in Colab
# save df
df_model_level_metrics.to_csv('model_level_metrics.csv', index=False)

Unnamed: 0,name,kernel_time (ms),percent_total
0,Conv Backbone fwd,203.257,9.41
1,Conv Backbone bwd,414.32,19.18
2,Transformer fwd,297.84,13.79
3,Transformer bwd,950.451,44.01
4,Decoder fwd,10.326,0.48
5,Decoder bwd,283.447,13.12


In [4]:
# UIDs of the CPU root nodes whose event is named 'aten::linear'.
linear_root_nodes = [
    root_uid
    for root_uid in tree.cpu_root_nodes
    if tree.events_by_uid[root_uid]['name'] == 'aten::linear'
]
print(f"Found {len(linear_root_nodes)} root nodes")

# Build the per-op metrics frame with bwd=False, then show its 'mean'
# summary as the cell output.
df_linear_ops = perf_analyzer.build_df_perf_metrics(
    linear_root_nodes,
    bwd=False,
)
perf_analyzer.summarize_df_perf_metrics(df_linear_ops, ['mean'])

Found 149 root nodes


Unnamed: 0,name,param: M,param: N,param: K,param: bias,GFLOPS_first,FLOPS/Byte_first,TFLOPS/s_mean,Non-Data-Mov Kernel Time (µs)_sum,Kernel Time (µs)_sum
0,aten::linear,40960,1536,1536,True,193.336443,754.101394,261.342643,48339,73298
2,aten::linear,40960,6144,1536,True,773.157028,1193.10114,364.539992,35732,50908
1,aten::linear,40960,1536,6144,True,773.345772,1193.375429,438.26952,38175,42408
3,aten::linear,41400,1536,1536,True,195.413299,754.249025,151.97503,1626,2572
4,aten::linear,48000,64,1536,True,9.510912,61.839609,27.973271,117,340


In [16]:
# Re-derive the 'aten::linear' CPU root nodes in this cell.
linear_root_nodes = [
    root_uid
    for root_uid in tree.cpu_root_nodes
    if tree.events_by_uid[root_uid]['name'] == 'aten::linear'
]

# Forward perf metrics of the first such node, shown as the cell output.
uid = linear_root_nodes[0]
event = tree.events_by_uid[uid]
perf_analyzer.compute_fwd_perf_metrics(event)

{'GFLOPS': 193.33644288,
 'Kernel Time (µs)': 778,
 'TFLOPS/s': 248.5044252956298,
 'Non-Data-Mov Kernel Time (µs)': 480,
 'Non-Data-Mov TFLOPS/s': 402.78425599999997,
 'FLOPS/Byte': 754.1013935319985,
 'param: M': 40960,
 'param: N': 1536,
 'param: K': 1536,
 'param: bias': True}

In [5]:
# Show the per-op metrics frame that build_df_perf_metrics produced for the
# 'aten::linear' root nodes (bwd=False).
df_linear_ops

Unnamed: 0,cat,name,pid,tid,external_id,GFLOPS,Kernel Time (µs),TFLOPS/s,Non-Data-Mov Kernel Time (µs),Non-Data-Mov TFLOPS/s,FLOPS/Byte,param: M,param: N,param: K,param: bias
0,cpu_op,aten::linear,17178,17178,9103,193.336443,778,248.504425,480,402.784256,754.101394,40960,1536,1536,True
1,cpu_op,aten::linear,17178,17178,9126,195.413299,1301,150.202382,824,237.152062,754.249025,41400,1536,1536,True
2,cpu_op,aten::linear,17178,17178,9149,195.413299,1271,153.747678,802,243.657480,754.249025,41400,1536,1536,True
3,cpu_op,aten::linear,17178,17178,9190,193.336443,629,307.371133,486,397.811611,754.101394,40960,1536,1536,True
4,cpu_op,aten::linear,17178,17178,9209,193.336443,593,326.031101,458,422.131971,754.101394,40960,1536,1536,True
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
144,cpu_op,aten::linear,17178,17178,13525,193.336443,735,263.042779,461,419.384909,754.101394,40960,1536,1536,True
145,cpu_op,aten::linear,17178,17178,13566,193.336443,650,297.440681,507,381.334207,754.101394,40960,1536,1536,True
146,cpu_op,aten::linear,17178,17178,13600,773.157028,2120,364.696711,1481,522.050660,1193.101140,40960,6144,1536,True
147,cpu_op,aten::linear,17178,17178,13631,773.345772,1824,423.983427,1622,476.785309,1193.375429,40960,1536,6144,True


In [6]:
# Backward counterpart of the linear metrics: same root nodes, bwd=True,
# with the 'mean' summary shown as the cell output.
df_linear_bwd_ops = perf_analyzer.build_df_perf_metrics(
    linear_root_nodes,
    bwd=True,
)
perf_analyzer.summarize_df_perf_metrics(df_linear_bwd_ops, ['mean'])

Unnamed: 0,name,param: M,param: N,param: K,param: bias,GFLOPS_first,FLOPS/Byte_first,TFLOPS/s_mean,Non-Data-Mov Kernel Time (µs)_sum,Kernel Time (µs)_sum
0,aten::linear,40960,1536,1536,True,386.609971,999.179059,212.946378,166553,179200
2,aten::linear,40960,6144,1536,True,1546.251141,1321.340617,325.657982,108279,114028
1,aten::linear,40960,1536,6144,True,1546.439885,1950.763412,340.85699,108361,109205
3,aten::linear,41400,1536,1536,True,390.763008,999.438215,237.908783,3268,3292
4,aten::linear,48000,64,1536,True,18.948096,118.321855,27.381642,496,692


In [7]:
# Flash attention, forward: UIDs of the CPU root nodes named 'FlashAttnFunc'.
fa_root_nodes = [
    root_uid
    for root_uid in tree.cpu_root_nodes
    if tree.events_by_uid[root_uid]['name'] == 'FlashAttnFunc'
]
print(f"Found {len(fa_root_nodes)} root nodes")

# Per-op metrics with bwd=False; the 'mean' summary is the cell output.
df_fa_fwd_ops = perf_analyzer.build_df_perf_metrics(
    fa_root_nodes,
    bwd=False,
)
perf_analyzer.summarize_df_perf_metrics(df_fa_fwd_ops, ['mean'])

Found 25 root nodes


Unnamed: 0,name,param: B,param: N_Q,param: N_K,param: H,param: d_k,param: dropout,param: causal,param: flash_impl,GFLOPS_first,FLOPS/Byte_first,TFLOPS/s_mean,Non-Data-Mov Kernel Time (µs)_sum,Kernel Time (µs)_sum
0,FlashAttnFunc,10,4096,4096,16,96,0.0,False,True,1030.792151,2048.0,303.081004,81653,81653
1,FlashAttnFunc,10,4096,4140,16,96,0.0,False,True,1041.865114,2053.456043,286.305335,3639,3639


In [8]:
# Flash attention, backward: same root nodes with bwd=True; the 'mean'
# summary is the cell output.
df_fa_bwd_ops = perf_analyzer.build_df_perf_metrics(
    fa_root_nodes,
    bwd=True,
)
perf_analyzer.summarize_df_perf_metrics(df_fa_bwd_ops, ['mean'])

Unnamed: 0,name,param: B,param: N_Q,param: N_K,param: H,param: d_k,param: dropout,param: causal,param: flash_impl,GFLOPS_first,TFLOPS/s_mean,Non-Data-Mov Kernel Time (µs)_sum,Kernel Time (µs)_sum
0,FlashAttnFunc,10,4096,4096,16,96,0.0,False,True,2576.980378,134.131422,461134,461134
1,FlashAttnFunc,10,4096,4140,16,96,0.0,False,True,2604.662784,131.184225,19855,19855


In [9]:
# UIDs of the CPU root nodes whose event is named 'aten::conv2d'.
conv2d_root_nodes = [
    root_uid
    for root_uid in tree.cpu_root_nodes
    if tree.events_by_uid[root_uid]['name'] == 'aten::conv2d'
]
print(f"Found {len(conv2d_root_nodes)} root nodes")

# Per-op metrics with bwd=False; the 'mean' summary is the cell output.
df_conv2d_ops = perf_analyzer.build_df_perf_metrics(
    conv2d_root_nodes,
    bwd=False,
)
perf_analyzer.summarize_df_perf_metrics(df_conv2d_ops, ['mean'])

Found 276 root nodes


Unnamed: 0,name,param: B,param: C_in,param: H,param: W,param: C_out,param: K_h,param: K_w,param: stride,param: padding,param: dilation,param: groups,param: bias,GFLOPS_first,TFLOPS/s_mean,Non-Data-Mov Kernel Time (µs)_sum,Kernel Time (µs)_sum
27,aten::conv2d,10,896,59,91,896,1,1,"(1, 1)","(0, 0)","(1, 1)",1,False,86.206382,206.430882,13009,23920
28,aten::conv2d,10,896,59,91,896,3,3,"(1, 1)","(1, 1)","(1, 1)",16,False,48.49109,97.645955,6503,13426
17,aten::conv2d,10,224,469,724,224,3,3,"(2, 2)","(1, 1)","(1, 1)",4,False,192.081254,45.274699,4575,12729
21,aten::conv2d,10,448,118,181,448,3,3,"(1, 1)","(1, 1)","(1, 1)",8,False,96.449311,82.544592,4339,10522
13,aten::conv2d,10,224,235,362,224,1,1,"(1, 1)","(0, 0)","(1, 1)",1,False,85.369446,95.731039,8011,8075
14,aten::conv2d,10,224,235,362,224,3,3,"(1, 1)","(1, 1)","(1, 1)",4,False,192.081254,72.977839,3048,7897
20,aten::conv2d,10,448,118,181,448,1,1,"(1, 1)","(0, 0)","(1, 1)",1,False,85.732721,245.430713,7201,7350
24,aten::conv2d,10,448,235,362,448,3,3,"(2, 2)","(1, 1)","(1, 1)",8,False,96.449311,46.897823,2231,6170
15,aten::conv2d,10,224,235,362,448,1,1,"(1, 1)","(0, 0)","(1, 1)",1,False,170.738893,133.117671,3826,3848
3,aten::conv2d,10,32,469,724,224,1,1,"(1, 1)","(0, 0)","(1, 1)",1,False,48.678748,39.907164,3644,3662


In [10]:
# conv2d, backward: same root nodes with bwd=True; the 'mean' summary is
# the cell output.
df_conv2d_bwd_ops = perf_analyzer.build_df_perf_metrics(
    conv2d_root_nodes,
    bwd=True,
)
perf_analyzer.summarize_df_perf_metrics(df_conv2d_bwd_ops, ['mean'])

Unnamed: 0,name,param: B,param: C_in,param: H,param: W,param: C_out,param: K_h,param: K_w,param: stride,param: padding,param: dilation,param: groups,param: bias,GFLOPS_first,TFLOPS/s_mean,Non-Data-Mov Kernel Time (µs)_sum,Kernel Time (µs)_sum
27,aten::conv2d,10,896,59,91,896,1,1,"(1, 1)","(0, 0)","(1, 1)",1,False,172.412764,204.136473,40148,48518
17,aten::conv2d,10,224,469,724,224,3,3,"(2, 2)","(1, 1)","(1, 1)",4,False,958.771538,79.856371,12761,36019
28,aten::conv2d,10,896,59,91,896,3,3,"(1, 1)","(1, 1)","(1, 1)",16,False,96.98218,85.186251,20133,30777
21,aten::conv2d,10,448,118,181,448,3,3,"(1, 1)","(1, 1)","(1, 1)",8,False,192.898621,69.197994,14167,25102
20,aten::conv2d,10,448,118,181,448,1,1,"(1, 1)","(0, 0)","(1, 1)",1,False,171.465441,146.426581,15616,24808
13,aten::conv2d,10,224,235,362,224,1,1,"(1, 1)","(0, 0)","(1, 1)",1,False,170.738893,85.029022,9890,18183
14,aten::conv2d,10,224,235,362,224,3,3,"(1, 1)","(1, 1)","(1, 1)",4,False,384.162509,67.152075,9694,17163
24,aten::conv2d,10,448,235,362,448,3,3,"(2, 2)","(1, 1)","(1, 1)",8,False,480.61182,84.416668,6591,17080
3,aten::conv2d,10,32,469,724,224,1,1,"(1, 1)","(0, 0)","(1, 1)",1,False,97.357496,18.308031,9123,15955
15,aten::conv2d,10,224,235,362,448,1,1,"(1, 1)","(0, 0)","(1, 1)",1,False,341.477786,126.570726,3735,8097


In [11]:
# UIDs of the CPU root nodes whose event is named 'aten::conv3d'.
conv3d_root_nodes = [
    root_uid
    for root_uid in tree.cpu_root_nodes
    if tree.events_by_uid[root_uid]['name'] == 'aten::conv3d'
]
print(f"Found {len(conv3d_root_nodes)} root nodes")

# Per-op metrics with bwd=False; the 'mean' summary is the cell output.
df_conv3d_ops = perf_analyzer.build_df_perf_metrics(
    conv3d_root_nodes,
    bwd=False,
)
perf_analyzer.summarize_df_perf_metrics(df_conv3d_ops, ['mean'])

Found 16 root nodes


Unnamed: 0,name,param: B,param: C_in,param: H,param: W,param: D,param: C_out,param: K_h,param: K_w,param: K_d,param: stride,param: padding,param: dilation,param: groups,param: bias,GFLOPS_first,TFLOPS/s_mean,Non-Data-Mov Kernel Time (µs)_sum,Kernel Time (µs)_sum
5,aten::conv3d,10,32,1,480,640,16,1,3,3,"(1, 1, 1)","(0, 0, 0)","(1, 1, 1)",1,True,9.417288,8.960312,379,1051
6,aten::conv3d,10,64,1,60,80,64,1,3,3,"(1, 1, 1)","(0, 0, 0)","(1, 1, 1)",1,True,1.114714,10.542897,290,758
3,aten::conv3d,10,32,1,240,320,32,1,3,3,"(1, 1, 1)","(0, 0, 0)","(1, 1, 1)",1,True,4.674244,14.246525,268,659
1,aten::conv3d,10,16,1,480,640,16,1,3,3,"(1, 1, 1)","(0, 0, 0)","(1, 1, 1)",1,True,4.733041,7.193072,295,658
4,aten::conv3d,10,32,1,480,640,16,1,1,1,"(1, 1, 1)","(0, 0, 0)","(1, 1, 1)",1,True,3.19488,6.339048,182,504
0,aten::conv3d,10,16,1,480,640,3,1,3,3,"(1, 1, 1)","(0, 0, 0)","(1, 1, 1)",1,True,0.887445,1.796448,193,494
8,aten::conv3d,10,64,1,120,160,32,1,3,3,"(1, 1, 1)","(0, 0, 0)","(1, 1, 1)",1,True,2.296941,14.5376,68,158
2,aten::conv3d,10,32,1,120,160,32,1,3,3,"(1, 1, 1)","(0, 0, 0)","(1, 1, 1)",1,True,1.151453,10.012639,45,115
7,aten::conv3d,10,64,1,120,160,32,1,1,1,"(1, 1, 1)","(0, 0, 0)","(1, 1, 1)",1,True,0.792576,9.324424,71,85


In [12]:
# conv3d, backward: same root nodes with bwd=True; the 'mean' summary is
# the cell output.
df_conv3d_bwd_ops = perf_analyzer.build_df_perf_metrics(
    conv3d_root_nodes,
    bwd=True,
)
perf_analyzer.summarize_df_perf_metrics(df_conv3d_bwd_ops, ['mean'])

Unnamed: 0,name,param: B,param: C_in,param: H,param: W,param: D,param: C_out,param: K_h,param: K_w,param: K_d,param: stride,param: padding,param: dilation,param: groups,param: bias,GFLOPS_first,TFLOPS/s_mean,Non-Data-Mov Kernel Time (µs)_sum,Kernel Time (µs)_sum
0,aten::conv3d,10,16,1,480,640,3,1,3,3,"(1, 1, 1)","(0, 0, 0)","(1, 1, 1)",1,True,1.772181,0.023932,73770,74052
4,aten::conv3d,10,32,1,480,640,16,1,1,1,"(1, 1, 1)","(0, 0, 0)","(1, 1, 1)",1,True,6.340608,0.145996,42795,43430
5,aten::conv3d,10,32,1,480,640,16,1,3,3,"(1, 1, 1)","(0, 0, 0)","(1, 1, 1)",1,True,18.854472,0.442676,41726,42592
1,aten::conv3d,10,16,1,480,640,16,1,3,3,"(1, 1, 1)","(0, 0, 0)","(1, 1, 1)",1,True,9.451633,0.231142,40345,40891
3,aten::conv3d,10,32,1,240,320,32,1,3,3,"(1, 1, 1)","(0, 0, 0)","(1, 1, 1)",1,True,9.392836,1.05237,17332,17851
6,aten::conv3d,10,64,1,60,80,64,1,3,3,"(1, 1, 1)","(0, 0, 0)","(1, 1, 1)",1,True,2.294362,3.269196,4210,4915
8,aten::conv3d,10,64,1,120,160,32,1,3,3,"(1, 1, 1)","(0, 0, 0)","(1, 1, 1)",1,True,4.656237,4.291462,998,1085
2,aten::conv3d,10,32,1,120,160,32,1,3,3,"(1, 1, 1)","(0, 0, 0)","(1, 1, 1)",1,True,2.331101,2.598775,837,897
7,aten::conv3d,10,64,1,120,160,32,1,1,1,"(1, 1, 1)","(0, 0, 0)","(1, 1, 1)",1,True,1.579008,2.494483,622,633
