## Background

There are separate tools and scripts to collect performance data on jax traces. e.g.
1. Tracelens/examples/generate_perf_report.py: PerfModel.torch_op_mapping, TreePerfAnalyzer
- df_gpu_timeline -> get_kernal_events -> GPUEventAnalyser.compute_metrics_dict(dict_events) -> breakdown
- df_kernel_launchers -> get_df_kernel_launchers -> get_kernel_launchers -> 'traverses the event tree', parent-child-grandchild?
- *df_kernel_launchers_summary* -> get_df_kernel_launchers_summary
- *df_kernel_launchers_summary_by_category* -> get_df_kernel_launchers_summary_by_category
- *df_kernel_launchers_unique_args* -> get_df_kernel_launchers_unique_args
- df_op_x -> dict_cat2names -> op_to_perf_model_class_map
- df_op_x -> loop_and_aggregate_kernels -> agg_kernels_in_subtree -> GPUEventAnalyser.compute_metrics_dict(dict_events) -> breakdown
2. Tracelens/Reporting/gererate_perf_report_jax_analysis.py: jax_analyses
- df_gpu_events_averages -> summarize_gpu_events -> JaxGPUEventAnalyser.compute_metrics_dict(dict_events) -> breakdown
- df_gpu_events_categorized
- df_xla_grouped
- df_gemms_detailed
3. JaxTrace_Analysis
- df_gemm_perf_info
- df_perf_breakdown
- gemm_roofline
- attention_roofline
- rccl
- xla_op
- hlo_op



In [1]:
# Imports
import os, sys
from pprint import pprint
import json
import pandas as pd
from itertools import chain
from collections import Counter

from TraceLens import DataLoader
from TraceLens import TraceToTree, JaxAnalyses, TraceEventUtils, TreePerfAnalyzer

# Configs
home_dir='/home/guangphu'
data_dir='perf-profiling/midj_traces'
# jax xplane.pb
trace_id='mi355x/hunyuan_t129/plugins/profile/2025_06_25_12_43_37'
trace_fname='chi-mi300x-007.ord.vultr.cpe.ice.amd.com'
trace_ext='.xplane.pb'
trace_dir=os.path.join(home_dir, data_dir, trace_id) 
trace_path=os.path.join(home_dir, data_dir, trace_id, trace_fname) + trace_ext 
os.path.exists(trace_path)
trace_path_jax = trace_path
# pytorch trace.json
trace_id='tts-traces/tts-traces-h100/bs16/rank_0/'
trace_fname='rocm-framework-h100-sxm-1_60628.1751372362949836640'
trace_ext='.pt.trace.json'
trace_dir=os.path.join(home_dir, data_dir, trace_id) 
trace_path=os.path.join(home_dir, data_dir, trace_id, trace_fname) + trace_ext
os.path.exists(trace_path)
trace_path_pytorch = trace_path


## Trace2Tree

##### Event keys

In [None]:
data = DataLoader.load_data(trace_path_pytorch)
events_pt=data["traceEvents"]
data = DataLoader.load_data(trace_path_jax)
events_jax=data["traceEvents"]

# Using set comprehension with itertools.chain.from_iterable
names = [e['name'] for e in events_pt]
set_keys_pt = set(chain.from_iterable(d.keys() for d in events_pt))
print('torch keys:', len(set_keys_pt), set_keys_pt)
all_keys = []
for d in events_pt:
    all_keys.extend(d.keys())
key_counts = Counter(all_keys)
print(f'number of torch events: {len(events_pt)}, unique event names: {len(set(names))}', )
print('torch event key counts:', key_counts)

names = [e['name'] for e in events_jax]
set_keys_jax = set(chain.from_iterable(d.keys() for d in events_jax))
print('jax keys', len(set_keys_jax), set_keys_jax)
all_keys = []
for d in events_jax:
    all_keys.extend(d.keys())
key_counts = Counter(all_keys)
print(f'number of jax events: {len(events_jax)}, unique event names: {len(set(names))}')
print('jax event key counts:',key_counts)

keys_common = [_key for _key in set_keys_pt if _key in set_keys_jax]
keys_uniq_jax = [_key for _key in set_keys_jax if _key not in set_keys_pt]
keys_uniq_pt = [_key for _key in set_keys_pt if _key not in set_keys_jax]
print('common kyes:', len(keys_common), keys_common)
print('unique to torch:', len(keys_uniq_pt), keys_uniq_pt)
print('unique to jax:', len(keys_uniq_jax), keys_uniq_jax)

##### Example Event

In [None]:
_event = events_jax[999] 
for _key in keys_uniq_jax:
    print('uniq to jax:', _key, _event.get(_key, None))
for _key in keys_common:
    print('common key:', _key, _event[_key])

_event = events_pt[999]
for _key in keys_uniq_pt:
    print('uniq to pt:', _key, _event.get(_key, None))
for _key in keys_common:
    print('common key:', _key, _event[_key])

##### Tree categories

In [2]:
from TraceLens import JaxTraceToTree

data = DataLoader.load_data(trace_path_jax)
data_pb = data['traceEvents']
categorizer = JaxAnalyses.prepare_event_categorizer(data_pb)
events = TraceEventUtils.non_metadata_events(data_pb)
linking_key = 'correlation_id'
metadata = TraceEventUtils.get_metadata(data_pb)
tree_jax = JaxTraceToTree(events, linking_key=linking_key, event_to_category=categorizer)
tree_jax.build_tree(metadata=metadata, pb_file_name=trace_path_jax)

#jax tree
cats = [categorizer(e) for e in tree_jax.events if categorizer(e)]
print('jax number of events:', len(cats), '\ncats:', len(set(cats)), set(cats))

2025-08-19 10:39:58.549500: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-08-19 10:39:58.571337: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-08-19 10:39:58.710472: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-08-19 10:39:58.839442: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1755599998.950806  167785 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1755599998.98

Building tree with add_python_func=False
Building CPU op tree with add_python_func=False
jax number of events: 1982601 
cats: 5 {'cpu_op', 'kernel', 'python function', 'Unknown', 'memcpy'}


In [None]:
data = DataLoader.load_data(trace_path_pytorch)
data = data['traceEvents']
categorizer_pt = TraceToTree.default_categorizer 
tree_pt = TraceToTree(data, event_to_category=categorizer)

# torch tree
cats = [categorizer_pt(e) for e in tree_pt.events if categorizer_pt(e)]
print('torch number of events::', len(cats), '\ncats:', len(set(cats)), set(cats))

torch number of events:: 1895689 
cats: 14 {'cpu_op', 'kernel', 'cuda_runtime', 'ac2g', 'cuda_driver', 'user_annotation', 'python_function', 'overhead', 'cpu_instant_event', 'fwdbwd', 'gpu_memcpy', 'gpu_memset', 'gpu_user_annotation', 'Trace'}


##### GEMM

In [12]:
from TraceLens.PerfModel.torch_op_mapping import categorize_torch_op, dict_cat2names
from TraceLens.PerfModel.jax_op_mapping import categorize_jax_op

In [18]:
tree_pt.events[0]

{'ph': 'X',
 'cat': 'user_annotation',
 'name': 'ProfilerStep#2',
 'pid': 60628,
 'tid': 60628,
 'ts': 7850735204323.829,
 'dur': 245104.15,
 'args': {'External id': 1,
  'Record function id': 0,
  'finished': True,
  'Ev Idx': 0},
 <TraceKeys.UID: 'UID'>: 0,
 <TraceKeys.TimeEnd: 't_end'>: 7850735449427.9795}

In [None]:
op_categorizer = categorize_jax_op
gemms = [e for e in tree_jax.events if categorizer(e)=='kernel' and op_categorizer(e) =='GEMM']
print(len(gemms))
print(gemms[0]['metadata']['backend_config'])


16680
backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__triton_gemm","triton_gemm_config":{"block_m":"64","block_n":"16","block_k":"32","split_k":"16","num_stages":"1","num_warps":"4","num_ctas":"1"}},"force_earliest_schedule":false}


In [None]:
import json
from collections import Counter

def get_backend_config(gemm_event):
    backend_config = gemm_event['metadata']['backend_config'].split('=')[1]
    dict_backend_config = json.loads(backend_config)
    return dict_backend_config.keys() #['fusion_backend_config']


In [None]:
gemm_event = gemms[100]
Counter([get_backend_config(gemm)['kind'] for gemm in gemms])

KeyError: 'fusion_backend_config'

In [None]:
pprint(dict_backend_config.get('triton_gemm_config', None))
pprint(gemms[0]['metadata']['operands'])

AttributeError: 'str' object has no attribute 'read'

In [23]:
gemms[0]['args']

{'correlation_id': 2687,
 'theoretical_occupancy_pct': 0,
 'occupancy_min_grid_size': 0,
 'occupancy_suggested_block_size': 0,
 'tf_op': 'XlaModule:',
 'name': 'jit(train_step)/jit(main)/dot_general',
 'hlo_op': 'gemm_fusion_dot.8020.0',
 'hlo_module': 'jit_train_step',
 'is_eager': 0}

In [None]:
event = {'pid': 1, 'tid': 1, 'name': 'gemm_fusion_dot_8020_0', 
         'ts': 169717.392, 'dur': 16.585, 'ph': 'X', 
         'args': {'correlation_id': 2687, 'theoretical_occupancy_pct': 0, 
                  'occupancy_min_grid_size': 0, 'occupancy_suggested_block_size': 0, 
                  'tf_op': 'XlaModule:', 
                  'name': 'jit(train_step)/jit(main)/dot_general', 
                  'hlo_op': 'gemm_fusion_dot.8020.0', 
                  'hlo_module': 'jit_train_step', 'is_eager': 0}, 
        'z': 1, 'UID': 225, 't_end': 169733.97699999998,
        'process': {'process_name': 'chi-mi300x-007.ord.vultr.cpe.ice.amd.com /device:GPU:0', 'process_sort_index': 1}, 
        'thread': {'thread_name': 'Stream #1(Kernel,Memset)', 'thread_sort_index': 1}, 
        'parent': 1821255, 'tree': True, 
        'metadata': {'output': 'bf16[16,6144]{1,0}', 'operands': ['bf16[1,3072]{1,0}', 'bf16[3072,6144]{1,0}'], 'computation': 'rest', 'metadata': 'metadata={op_name="jit(train_step)/jit(main)/dot_general" source_file="/usr/local/lib/python3.10/dist-packages/flax/nnx/nn/linear.py" source_line=367 deduplicated_name="gemm_fusion_dot.7978.0"', 'backend_config': 'backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__triton_gemm","triton_gemm_config":{"block_m":"64","block_n":"16","block_k":"32","split_k":"16","num_stages":"1","num_warps":"4","num_ctas":"1"}},"force_earliest_schedule":false}'}, 
        'cat': 'kernel', 'gpu_events': [225] 
        }


In [None]:
from TraceLens.PerfModel.jax_op_mapping import categorize_jax_op, dict_cat_to_perf_model
from TraceLens.TreePerf.jax_analyses import JaxAnalyses

profile_filepath = trace_path_jax
data = DataLoader.load_data(profile_filepath)
data_pb = data['traceEvents']
categorizer = JaxAnalyses.prepare_event_categorizer(data_pb)

metrics_event = {'cat': categorizer(event), 
                'name': event['name'],
                'UID': event['UID'],
        'pid': event['pid'], 'tid': event['tid'],
        'external_id': event['args'].get('External id')}

## Performance

### 1. Temporal Breakdown
- Breakdown of time taken by the GPUs in terms of time spent in computation, communication, memory events, and idle time across all ranks.


In [None]:
kernel_events = '/home/guangphu/perf-profiling/logs/tracelens/jax/trace_analysis_results_kernel_events.csv'
df = pd.read_csv(kernel_events)
df.shape, df.head(3)

### 2. Kernel Breakdown


In [None]:
kernel_launchers = '/home/guangphu/perf-profiling/logs/tracelens/pytorch/trace_analysis_results_kernel_launchers.csv'
df = pd.read_csv(kernel_launchers)
# print(df.shape, df.head(3))


In [None]:
import pandas as pd

kernel_launchers = '/home/guangphu/perf-profiling/logs/tracelens/jax/trace_analysis_results_kernel_launchers.csv'
df = pd.read_csv(kernel_launchers)
#print(df.shape, df.head(3))
#df.loc[df['name']=='loop_convert_fusion_1111']
for pid in range(1,9):
    print(pid, df.loc[df['pid']==pid].shape)
    print('number of uniq names', len(set(df.name)))

In [None]:
df.head(3)

In [None]:
df6 = df.loc[df['pid']==6]
df6.head(3)


### 3. Trace Comparison



### 4. Memory Profiling

In [None]:
# Tracelens/examples/generate_perf_report.py
df_gpu_timeline
df_kernel_launchers
df_op_x

# Tracelens/Reporting/gererate_perf_report_jax_analysis.py
df_gpu_events_averages
df_gpu_events_categorized
df_xla_grouped
df_gemms_detailed

# JaxTrace_Analysis
df_gemm_perf_info
df_perf_breakdown
roofline_figs
