## Background

Several separate tools and scripts collect performance data from JAX traces, e.g.:
1. Tracelens/examples/generate_perf_report.py: PerfModel.torch_op_mapping, TreePerfAnalyzer
- df_gpu_timeline -> get_kernal_events -> GPUEventAnalyser.compute_metrics_dict(dict_events) -> breakdown
- df_kernel_launchers -> get_df_kernel_launchers -> get_kernel_launchers -> 'traverses the event tree', parent-child-grandchild?
- *df_kernel_launchers_summary* -> get_df_kernel_launchers_summary
- *df_kernel_launchers_summary_by_category* -> get_df_kernel_launchers_summary_by_category
- *df_kernel_launchers_unique_args* -> get_df_kernel_launchers_unique_args
- df_op_x -> dict_cat2names -> op_to_perf_model_class_map
- df_op_x -> loop_and_aggregate_kernels -> agg_kernels_in_subtree -> GPUEventAnalyser.compute_metrics_dict(dict_events) -> breakdown
2. Tracelens/Reporting/generate_perf_report_jax_analysis.py: jax_analyses
- df_gpu_events_averages -> summarize_gpu_events -> JaxGPUEventAnalyser.compute_metrics_dict(dict_events) -> breakdown
- df_gpu_events_categorized
- df_xla_grouped
- df_gemms_detailed
3. JaxTrace_Analysis
- df_gemm_perf_info
- df_perf_breakdown
- gemm_roofline
- attention_roofline
- rccl
- xla_op
- hlo_op



In [3]:
# Imports
import os, sys
from pprint import pprint
import json
import pandas as pd
from itertools import chain
from collections import Counter

from TraceLens import DataLoader
from TraceLens import TraceToTree, JaxAnalyses, TraceEventUtils, TreePerfAnalyzer

# Configs
# NOTE: these locations are specific to one machine; edit home_dir / data_dir
# to point at wherever the traces live.
home_dir='/home/guangphu'
data_dir='perf-profiling/midj_traces'

# jax xplane.pb
trace_id='mi355x/hunyuan_t129/plugins/profile/2025_06_25_12_43_37'
trace_fname='chi-mi300x-007.ord.vultr.cpe.ice.amd.com'
trace_ext='.xplane.pb'
trace_dir=os.path.join(home_dir, data_dir, trace_id)
trace_path=os.path.join(trace_dir, trace_fname) + trace_ext
trace_path_jax = trace_path

# pytorch trace.json
# The generic trace_* names are reused here, so after this cell they describe
# the pytorch trace; use trace_path_jax / trace_path_pytorch downstream.
trace_id='tts-traces/tts-traces-h100/bs16/rank_0/'
trace_fname='rocm-framework-h100-sxm-1_60628.1751372362949836640'
trace_ext='.pt.trace.json'
trace_dir=os.path.join(home_dir, data_dir, trace_id)
trace_path=os.path.join(trace_dir, trace_fname) + trace_ext
trace_path_pytorch = trace_path

# Report missing trace files up front. A bare `os.path.exists(...)` statement
# in the middle of a cell is evaluated and thrown away, so it tells the reader
# nothing; a printed warning does.
for _label, _path in (('jax', trace_path_jax), ('pytorch', trace_path_pytorch)):
    if not os.path.exists(_path):
        print(f'WARNING: {_label} trace not found: {_path}')


## Trace2Tree

In [None]:
# Load both traces and keep only their "traceEvents" lists.
# `data` is reused for both loads; after this cell it holds the jax trace.
data = DataLoader.load_data(trace_path_pytorch)
events_pt=data["traceEvents"]
# Print the event counts so the cell's output is reproducible from its code.
print('length of pt events:', len(events_pt))

data = DataLoader.load_data(trace_path_jax)
events_jax=data["traceEvents"]
print('Length of jax events:', len(events_jax))

length of pt events: 1895761


2025-08-08 07:12:00.372981: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-08-08 07:12:00.396264: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-08-08 07:12:00.578546: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-08-08 07:12:00.723127: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1754637120.844882 1274858 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1754637120.88

Length of jax events: 1982839


##### Event keys

In [11]:
def _summarize_event_keys(events):
    """Summarise which top-level keys appear in a list of trace events.

    Args:
        events: iterable of dict-like events; every event must have a 'name'.

    Returns:
        (key_set, key_counts, names) where `key_set` is the set of all keys
        seen, `key_counts` is a Counter of how many events carry each key,
        and `names` is the list of event names in input order.

    Raises:
        KeyError: if an event has no 'name' entry.
    """
    names = [e['name'] for e in events]
    # Count keys straight from the chained iterator instead of first
    # materialising one huge list of every key of every event.
    key_counts = Counter(chain.from_iterable(e.keys() for e in events))
    return set(key_counts), key_counts, names


set_keys_pt, key_counts, names = _summarize_event_keys(events_pt)
print('torch keys:', len(set_keys_pt), set_keys_pt)
print(f'number of torch events: {len(events_pt)}, unique event names: {len(set(names))}')
print('torch event key counts:', key_counts)

set_keys_jax, key_counts, names = _summarize_event_keys(events_jax)
print('jax keys', len(set_keys_jax), set_keys_jax)
print(f'number of jax events: {len(events_jax)}, unique event names: {len(set(names))}')
print('jax event key counts:',key_counts)

# Sorted so the lists (and everything printed from them in later cells) come
# out in the same order on every run; raw set iteration order is not stable
# across kernel restarts.
keys_common = sorted(set_keys_pt & set_keys_jax)
keys_uniq_jax = sorted(set_keys_jax - set_keys_pt)
keys_uniq_pt = sorted(set_keys_pt - set_keys_jax)
print('common keys:', len(keys_common), keys_common)
print('unique to torch:', len(keys_uniq_pt), keys_uniq_pt)
print('unique to jax:', len(keys_uniq_jax), keys_uniq_jax)

torch keys: 11 {'pid', 'dur', 'cat', 'ts', 'args', 'id', 'bp', 'tid', 's', 'ph', 'name'}
number of torch events: 1895761, unique event names: 1101
torch event key counts: Counter({'ph': 1895761, 'name': 1895761, 'pid': 1895761, 'tid': 1895761, 'ts': 1895761, 'cat': 1895689, 'args': 1754535, 'dur': 1686174, 'id': 141205, 'bp': 81114, 's': 68312})
jax keys 10 {'z', 'dur', 'ts', 'sf', 'tid', 'args', 'pid', 'thread_count', 'ph', 'name'}
number of jax events: 1982839, unique event names: 11388
jax event key counts: Counter({'args': 1982839, 'name': 1982839, 'ph': 1982839, 'pid': 1982839, 'tid': 1982821, 'ts': 1982601, 'dur': 1982601, 'z': 264583, 'sf': 184277, 'thread_count': 9})
common kyes: 7 ['pid', 'dur', 'ts', 'args', 'tid', 'ph', 'name']
unique to torch: 4 ['cat', 'id', 'bp', 's']
unique to jax: 3 ['z', 'sf', 'thread_count']


##### Example Event

In [6]:
def _show_event(event, label, keys_uniq, keys_shared):
    """Print one event's values for framework-specific and shared keys.

    Args:
        event: a single trace event (dict).
        label: short framework tag used in the 'uniq to <label>:' prefix.
        keys_uniq: keys that only this framework's events carry.
        keys_shared: keys seen in both frameworks' traces.

    Missing keys print as None. A key being "common" only means it occurs
    somewhere in both traces, not that every event has it, so plain indexing
    could raise KeyError for an arbitrary event.
    """
    for key in keys_uniq:
        print(f'uniq to {label}:', key, event.get(key, None))
    for key in keys_shared:
        print('common key:', key, event.get(key, None))


# Index 999 is an arbitrary sample; both traces must have at least 1000 events.
_event = events_jax[999]
_show_event(_event, 'jax', keys_uniq_jax, keys_common)

_event = events_pt[999]
_show_event(_event, 'pt', keys_uniq_pt, keys_common)

uniq to jax: z 3
uniq to jax: sf None
uniq to jax: thread_count None
common key: pid 1
common key: dur 6.931
common key: ts 233560.213
common key: args {'correlation_id': 34507, 'theoretical_occupancy_pct': 0, 'occupancy_min_grid_size': 0, 'occupancy_suggested_block_size': 0, 'tf_op': 'XlaModule:', 'name': 'jit(train_step)/jit(main)/dot_general', 'hlo_op': 'gemm_fusion_dot.607.0', 'hlo_module': 'jit_train_step', 'is_eager': 0}
common key: tid 1
common key: ph X
common key: name gemm_fusion_dot_571_0
uniq to pt: cat cpu_op
uniq to pt: id None
uniq to pt: bp None
uniq to pt: s None
common key: pid 60628
common key: dur 31.199
common key: ts 7850735315604.46
common key: args {'External id': 860, 'Record function id': 0, 'finished': True, 'Ev Idx': 859}
common key: tid 60628
common key: ph X
common key: name TorchDynamo Cache Lookup


##### Tree categories

In [12]:
#jax tree
trace_events = events_jax
categorizer_jax =  JaxAnalyses.prepare_event_categorizer(trace_events)
# NOTE(review): the filtered list assigned here is never read again -- the
# categorisation below runs over the full events_jax list. Confirm whether
# the metadata-free events were meant to be categorised instead.
trace_events = TraceEventUtils.non_metadata_events(trace_events)
events, categorizer = events_jax, categorizer_jax
# Call the categorizer once per event and keep only truthy categories.
cats = [cat for cat in map(categorizer, events) if cat]
print('jax number of events:', len(cats), '\ncats:', len(set(cats)), set(cats))

# torch tree
trace_events = events_pt 
categorizer_pt = TraceToTree.default_categorizer 
events, categorizer = events_pt, categorizer_pt
cats = [cat for cat in map(categorizer, events) if cat]
print('torch number of events::', len(cats), '\ncats:', len(set(cats)), set(cats))


jax number of events: 1982839 
cats: 5 {'memcpy', 'Unknown', 'kernel', 'cpu_op', 'python function'}
torch number of events:: 1895689 
cats: 14 {'ac2g', 'gpu_memcpy', 'Trace', 'cuda_driver', 'gpu_memset', 'user_annotation', 'gpu_user_annotation', 'cpu_instant_event', 'cuda_runtime', 'python_function', 'kernel', 'cpu_op', 'overhead', 'fwdbwd'}


## Performance

### 1. Temporal Breakdown
- Breakdown of GPU time into computation, communication, memory events, and idle time, across all ranks.


In [16]:
# Kernel-event table for the jax trace, read from a CSV on local disk.
kernel_events = '/home/guangphu/perf-profiling/logs/tracelens/jax/trace_analysis_results_kernel_events.csv'
df = pd.read_csv(filepath_or_buffer=kernel_events)
# Show the frame's dimensions together with its first three rows.
(df.shape, df.head(n=3))

((173931, 17),
    pid  tid                      name         ts   dur ph  \
 0    1    1  loop_convert_fusion_1111  164176.73  2.28  X   
 1    1    1        loop_add_fusion_51  164208.66  2.28  X   
 2    1    1        loop_add_fusion_52  164237.70  2.32  X   
 
                                                 args    z  UID      t_end  \
 0  {'correlation_id': 245, 'theoretical_occupancy...  2.0    0  164179.02   
 1  {'correlation_id': 259, 'theoretical_occupancy...  3.0    1  164210.94   
 2  {'correlation_id': 268, 'theoretical_occupancy...  2.0    2  164240.03   
 
                                              process  \
 0  {'process_name': 'chi-mi300x-007.ord.vultr.cpe...   
 1  {'process_name': 'chi-mi300x-007.ord.vultr.cpe...   
 2  {'process_name': 'chi-mi300x-007.ord.vultr.cpe...   
 
                                               thread   parent  tree  \
 0  {'thread_name': 'Stream #1(Kernel,Memset)', 't...  1821028  True   
 1  {'thread_name': 'Stream #1(Kernel,Memset)',

### 2. Kernel Breakdown


In [None]:
# Kernel-launcher table for the pytorch trace, read from a CSV on local disk.
kernel_launchers = '/home/guangphu/perf-profiling/logs/tracelens/pytorch/trace_analysis_results_kernel_launchers.csv'
df = pd.read_csv(filepath_or_buffer=kernel_launchers)
# Preview left disabled; uncomment to inspect:
# print(df.shape, df.head(3))


(29760, 10)         name op category  UID  total_direct_kernel_time  direct_kernel_count  \
0  aten::cat       other    3                      4.09                    1   
1  aten::cat       other    6                      3.62                    1   
2  aten::cat       other   16                      3.33                    1   

                                   Input Dims                Input type  \
0         (((16, 256, 1), (16, 256, 32)), ())  ('TensorList', 'Scalar')   
1         (((16, 256, 1), (16, 256, 32)), ())  ('TensorList', 'Scalar')   
2  (((32,), (131072,), (64,), (262144,)), ())  ('TensorList', 'Scalar')   

                        Input Strides Concrete Inputs  \
0  (((256, 1, 1), (8192, 32, 1)), ())      ('', '-1')   
1  (((256, 1, 1), (8192, 32, 1)), ())      ('', '-1')   
2      (((1,), (1,), (1,), (1,)), ())       ('', '0')   

                                        kernel_names  
0  ['void at::native::(anonymous namespace)::CatA...  
1  ['void at::native::(anon

In [11]:
import pandas as pd

# Kernel-launcher table for the jax trace, read from a CSV on local disk.
kernel_launchers = '/home/guangphu/perf-profiling/logs/tracelens/jax/trace_analysis_results_kernel_launchers.csv'
df = pd.read_csv(kernel_launchers)

# Per-pid row count and number of distinct launcher names.
# The name count is taken from the pid's own rows; counting over the whole
# frame inside the loop would print the same number for every pid.
# NOTE(review): pids 1..8 are hard-coded -- confirm this matches the trace.
for pid in range(1,9):
    df_pid = df.loc[df['pid']==pid]
    print(pid, df_pid.shape)
    print('number of uniq names', df_pid['name'].nunique(dropna=False))

1 (21764, 13)
number of uniq names 821
2 (21761, 13)
number of uniq names 821
3 (21761, 13)
number of uniq names 821
4 (21761, 13)
number of uniq names 821
5 (21761, 13)
number of uniq names 821
6 (21601, 13)
number of uniq names 821
7 (21761, 13)
number of uniq names 821
8 (21761, 13)
number of uniq names 821


In [15]:
# Distinct thread ids present in the kernel-launcher table.
set(df['tid'])

{1, 19, 20, 22, 23}

In [14]:
# First three rows of the kernel-launcher table.
df.iloc[:3]

Unnamed: 0,name,op category,UID,total_direct_kernel_time,direct_kernel_count,Input Dims,Input type,Input Strides,Concrete Inputs,pid,tid,external_id,kernel_names
0,loop_convert_fusion_1111,other,0,2.28,1,,,,,1,1,,['loop_convert_fusion_1111']
1,loop_add_fusion_51,other,1,2.28,1,,,,,1,1,,['loop_add_fusion_51']
2,loop_add_fusion_52,other,2,2.32,1,,,,,1,1,,['loop_add_fusion_52']


In [12]:
# Rows launched by pid 6 only, then a three-row preview of that slice.
df6 = df[df['pid'].eq(6)]
df6.head(n=3)

Unnamed: 0,name,op category,UID,total_direct_kernel_time,direct_kernel_count,Input Dims,Input type,Input Strides,Concrete Inputs,pid,tid,external_id,kernel_names
108808,loop_convert_fusion_1111,other,273375,2.4,1,,,,,6,1,,['loop_convert_fusion_1111']
108809,loop_add_fusion_51,other,273376,2.04,1,,,,,6,1,,['loop_add_fusion_51']
108810,loop_add_fusion_52,other,273377,2.12,1,,,,,6,1,,['loop_add_fusion_52']



### 3. Trace Comparison



### 4. Memory Profiling

In [None]:
# Output tables produced by each reporting tool, listed by name only.
# They are kept as strings because none of these objects is created in this
# notebook; referencing them as bare names would raise NameError on a fresh
# kernel.
report_dataframes = {
    # Tracelens/examples/generate_perf_report.py
    'Tracelens/examples/generate_perf_report.py': [
        'df_gpu_timeline',
        'df_kernel_launchers',
        'df_op_x',
    ],
    # Tracelens/Reporting/generate_perf_report_jax_analysis.py
    'Tracelens/Reporting/generate_perf_report_jax_analysis.py': [
        'df_gpu_events_averages',
        'df_gpu_events_categorized',
        'df_xla_grouped',
        'df_gemms_detailed',
    ],
    # JaxTrace_Analysis
    'JaxTrace_Analysis': [
        'df_gemm_perf_info',
        'df_perf_breakdown',
        'roofline_figs',
    ],
}
report_dataframes
