## Background

Several separate tools and scripts collect performance data from PyTorch and JAX traces, e.g.:
1. Tracelens/examples/generate_perf_report.py: PerfModel.torch_op_mapping, TreePerfAnalyzer
- df_gpu_timeline -> get_kernel_events -> GPUEventAnalyser.compute_metrics_dict(dict_events) -> breakdown
- df_kernel_launchers -> get_df_kernel_launchers -> get_kernel_launchers -> 'traverses the event tree', parent-child-grandchild?
- *df_kernel_launchers_summary* -> get_df_kernel_launchers_summary
- *df_kernel_launchers_summary_by_category* -> get_df_kernel_launchers_summary_by_category
- *df_kernel_launchers_unique_args* -> get_df_kernel_launchers_unique_args
- df_op_x -> dict_cat2names -> op_to_perf_model_class_map
- df_op_x -> loop_and_aggregate_kernels -> agg_kernels_in_subtree -> GPUEventAnalyser.compute_metrics_dict(dict_events) -> breakdown
2. Tracelens/Reporting/generate_perf_report_jax_analysis.py: jax_analyses
- df_gpu_events_averages -> summarize_gpu_events -> JaxGPUEventAnalyser.compute_metrics_dict(dict_events) -> breakdown
- df_gpu_events_categorized
- df_xla_grouped
- df_gemms_detailed
3. JaxTrace_Analysis
- df_gemm_perf_info
- df_perf_breakdown
- gemm_roofline
- attention_roofline
- rccl
- xla_op
- hlo_op



In [None]:
# Imports
import os, sys
from pprint import pprint
import json
import pandas as pd
from itertools import chain
from collections import Counter

from TraceLens import DataLoader
from TraceLens import TraceToTree, JaxAnalyses, TraceEventUtils, TreePerfAnalyzer

# Configs
home_dir = '/home/guangphu'
data_dir = 'perf-profiling/midj_traces'

# jax xplane.pb
trace_id = 'mi355x/hunyuan_t129/plugins/profile/2025_06_25_12_43_37'
trace_fname = 'chi-mi300x-007.ord.vultr.cpe.ice.amd.com'
trace_ext = '.xplane.pb'
trace_dir = os.path.join(home_dir, data_dir, trace_id)
trace_path = os.path.join(trace_dir, trace_fname) + trace_ext
# Jupyter only displays a cell's *last* expression, so a bare
# os.path.exists(...) here would be silently discarded; print it instead.
print('jax trace exists:', os.path.exists(trace_path), trace_path)
trace_path_jax = trace_path

# pytorch trace.json
trace_id = 'tts-traces/tts-traces-h100/bs16/rank_0/'
trace_fname = 'rocm-framework-h100-sxm-1_60628.1751372362949836640'
trace_ext = '.pt.trace.json'
trace_dir = os.path.join(home_dir, data_dir, trace_id)
trace_path = os.path.join(trace_dir, trace_fname) + trace_ext
print('pytorch trace exists:', os.path.exists(trace_path), trace_path)
trace_path_pytorch = trace_path


## Trace2Tree

In [None]:
# Load the PyTorch trace first and keep only its event list; `data` is left
# bound to the loaded JAX trace.
events_pt = DataLoader.load_data(trace_path_pytorch)["traceEvents"]
data = DataLoader.load_data(trace_path_jax)
events_jax = data["traceEvents"]

##### Event keys

In [None]:
def _summarize_event_keys(events, label):
    """Print key statistics for a list of trace-event dicts and return them.

    Args:
        events: list of event dicts; every event must have a 'name' key.
        label: framework label used in the printed messages (e.g. 'torch').

    Returns:
        (set of every key seen, Counter of how many events carry each key,
        list of event names in input order)
    """
    names = [e['name'] for e in events]
    # A single pass over all keys; the Counter's keys are exactly the key set.
    key_counts = Counter(chain.from_iterable(e.keys() for e in events))
    keys = set(key_counts)
    print(f'{label} keys:', len(keys), keys)
    print(f'number of {label} events: {len(events)}, unique event names: {len(set(names))}')
    print(f'{label} event key counts:', key_counts)
    return keys, key_counts, names


set_keys_pt, key_counts, names = _summarize_event_keys(events_pt, 'torch')
set_keys_jax, key_counts, names = _summarize_event_keys(events_jax, 'jax')

# Sorted so the printed key lists are stable across runs (iteration order of
# a set of strings varies with hash randomization).
keys_common = sorted(set_keys_pt & set_keys_jax)
keys_uniq_jax = sorted(set_keys_jax - set_keys_pt)
keys_uniq_pt = sorted(set_keys_pt - set_keys_jax)
print('common keys:', len(keys_common), keys_common)
print('unique to torch:', len(keys_uniq_pt), keys_uniq_pt)
print('unique to jax:', len(keys_uniq_jax), keys_uniq_jax)

##### Example Event

In [None]:
# Sample position to inspect in each trace; clamped for shorter traces.
example_idx = 999

for _events, _uniq_keys, _label in (
    (events_jax, keys_uniq_jax, 'jax'),
    (events_pt, keys_uniq_pt, 'pt'),
):
    # Clamp so traces with fewer than example_idx + 1 events don't raise IndexError.
    _event = _events[min(example_idx, len(_events) - 1)]
    for _key in _uniq_keys:
        print(f'uniq to {_label}:', _key, _event.get(_key, None))
    # A "common" key occurs in *some* event of both traces, not necessarily
    # in this one, so use .get instead of indexing to avoid KeyError.
    for _key in keys_common:
        print('common key:', _key, _event.get(_key, None))

##### Tree categories

In [None]:
from TraceLens import JaxTraceToTree

# Build the JAX event tree from the xplane.pb trace.
data = DataLoader.load_data(trace_path_jax)
data_pb = data['traceEvents']
categorizer = JaxAnalyses.prepare_event_categorizer(data_pb)
events = TraceEventUtils.non_metadata_events(data_pb)
linking_key = 'correlation_id'
metadata = TraceEventUtils.get_metadata(data_pb)
tree_jax = JaxTraceToTree(events, linking_key=linking_key, event_to_category=categorizer)
tree_jax.build_tree()

# jax tree: categorize each event once and keep only truthy categories.
cats = [c for c in map(categorizer, events) if c]
print('jax number of events:', len(cats), '\ncats:', len(set(cats)), set(cats))

In [None]:
from TraceLens import TraceToTree

# Build the PyTorch event tree from the trace.json.
data = DataLoader.load_data(trace_path_pytorch)
events = data['traceEvents']
categorizer = TraceToTree.default_categorizer
tree_pt = TraceToTree(events, event_to_category=categorizer)
# Build the tree as the JAX cell does for tree_jax.
# NOTE(review): presumably required before tree_pt can be queried — confirm
# against TraceToTree.build_tree.
tree_pt.build_tree()

# torch tree: categorize each event once and keep only truthy categories.
cats = [c for c in map(categorizer, events) if c]
print('torch number of events:', len(cats), '\ncats:', len(set(cats)), set(cats))


## Performance

### 1. Temporal Breakdown
- Breakdown of time taken by the GPUs in terms of time spent in computation, communication, memory events, and idle time across all ranks.


In [None]:
# Kernel-event CSV from the tracelens/jax logs directory, located relative to
# the home_dir configured above instead of a duplicated absolute path.
kernel_events = os.path.join(
    home_dir, 'perf-profiling/logs/tracelens/jax/trace_analysis_results_kernel_events.csv'
)
df = pd.read_csv(kernel_events)
# Print the shape separately so the head below renders as a DataFrame
# rather than as text inside a tuple.
print(df.shape)
df.head(3)

### 2. Kernel Breakdown


In [None]:
# PyTorch kernel-launcher CSV from the tracelens/pytorch logs directory,
# located relative to the home_dir configured above.
kernel_launchers = os.path.join(
    home_dir, 'perf-profiling/logs/tracelens/pytorch/trace_analysis_results_kernel_launchers.csv'
)
df = pd.read_csv(kernel_launchers)


In [None]:
import pandas as pd

# JAX kernel-launcher CSV from the tracelens/jax logs directory, located
# relative to the home_dir configured above.
kernel_launchers = os.path.join(
    home_dir, 'perf-profiling/logs/tracelens/jax/trace_analysis_results_kernel_launchers.csv'
)
df = pd.read_csv(kernel_launchers)

# Process ids to inspect.
# NOTE(review): presumably the 8 GPU processes of the trace — confirm.
pids_to_inspect = range(1, 9)
for pid in pids_to_inspect:
    df_pid = df.loc[df['pid'] == pid]
    print(pid, df_pid.shape)
    # Unique kernel names within *this* pid (not across the whole table).
    print('number of uniq names', df_pid['name'].nunique(dropna=False))

In [None]:
# Preview the first three rows of the loaded table.
df.iloc[:3]

In [None]:
# Rows whose process id is 6, previewed by their first three entries.
df6 = df[df['pid'].eq(6)]
df6.iloc[:3]


### 3. Trace Comparison



### 4. Memory Profiling

In [None]:
# Reference list of the DataFrames/figures produced by each existing tool.
# These are names only: none of them are defined in this notebook, so
# listing them as bare expressions would raise NameError. They are
# recorded as strings instead.
reference_dataframes = {
    'Tracelens/examples/generate_perf_report.py': [
        'df_gpu_timeline',
        'df_kernel_launchers',
        'df_op_x',
    ],
    'Tracelens/Reporting/generate_perf_report_jax_analysis.py': [
        'df_gpu_events_averages',
        'df_gpu_events_categorized',
        'df_xla_grouped',
        'df_gemms_detailed',
    ],
    'JaxTrace_Analysis': [
        'df_gemm_perf_info',
        'df_perf_breakdown',
        'roofline_figs',
    ],
}
pprint(reference_dataframes)
