# JAX NCCL Analyser Example

This notebook demonstrates how to use the `JaxNcclAnalyser` class to analyze bandwidth performance of collective communication operations in JAX traces.

## Overview

The `JaxNcclAnalyser` provides functionality to:
- Load and parse JAX trace files (protobuf format)
- Extract collective communication events (all-reduce, all-gather, etc.)
- Calculate bandwidth metrics for each collective operation
- Analyze performance across different collective types
- Generate summary statistics

#### Note

- Bus bandwidth calculations in the context of JAX are a work in progress.
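
For orientation, the sketch below shows how bus bandwidth is commonly derived from algorithm bandwidth under the convention used by the nccl-tests benchmarks. It is not taken from TraceLens and may differ from what `JaxNcclAnalyser` ends up implementing; the function name `bus_bandwidth_factor` and the collective-type strings are illustrative assumptions.

```python
def bus_bandwidth_factor(collective_type: str, n_ranks: int) -> float:
    """Factor that converts algorithm bandwidth to bus bandwidth.

    Follows the nccl-tests convention; hypothetical helper, not a TraceLens API.
    """
    if n_ranks < 2:
        raise ValueError("a collective needs at least 2 ranks")
    if collective_type == "all-reduce":
        return 2 * (n_ranks - 1) / n_ranks
    if collective_type in ("all-gather", "reduce-scatter"):
        return (n_ranks - 1) / n_ranks
    raise ValueError(f"no factor defined here for {collective_type!r}")


# Example: 100 GB/s algorithm bandwidth for an all-reduce across 8 ranks
alg_bw = 100.0
bus_bw = alg_bw * bus_bandwidth_factor("all-reduce", 8)
print(f"Bus bandwidth: {bus_bw:.1f} GB/s")
```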

#### TraceLens installation

See the installation section of the README.md file.

In [None]:
# Import required modules
import sys
sys.path.append('/path/TraceLens') # Path to the TraceLens repository
from TraceLens.NcclAnalyser import JaxNcclAnalyser
import pandas as pd
import numpy as np

## Step 1: Configure Input Parameters

First, set up the input parameters needed for the analyser:
- **traces_dir**: Root directory containing trace files
- **node_to_pb_file_mapping**: Mapping from each node rank to its protobuf trace file
- **world_size**: Total number of GPUs in the distributed setup

In [None]:
# Configure trace analysis parameters
traces_dir = "/path/to/your/trace/files/directory" # Path to the root traces dir
world_size = 32 # Distributed training world size, modify as per your case

# Create a mapping from node rank to protobuf trace file.
# Each entry maps a node rank to its corresponding .xplane.pb file.
# A utility script (NcclAnalyser/util/node_rank_to_protobuf_file_mapping.py) can generate
# this mapping. Example usage:
#
# node_to_pb_file_mapping, total_nodes = node_rank_to_protobuf_file_mapping.get_node_rank_protobuf_mapping(traces_dir)
#
# Check the generated mapping manually; it may be inaccurate in some cases.
# Alternatively, specify the mapping by hand here.
# The node ranks must be correct: they are used to map each pid in the trace events to a global GPU rank.
node_to_pb_file_mapping = {}

print(f"Traces directory: {traces_dir}")
print(f"World size: {world_size}")
print(f"Number of trace files: {len(node_to_pb_file_mapping)}")
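
Because the mapping has to be checked by hand, a small sanity check can catch obvious mistakes before the analyser is initialized. The sketch below is not part of TraceLens. It assumes the mapping keys are integer node ranks `0..N-1`, the values are file paths (absolute, or relative to `traces_dir`), and every node hosts the same number of GPUs. The helper name `check_mapping` and the demo file names are hypothetical.

```python
import os
import tempfile


def check_mapping(traces_dir, mapping, world_size):
    """Return a list of problems found; an empty list means none were detected."""
    if not mapping:
        return ["mapping is empty"]
    problems = []
    # Assumption: node ranks are the contiguous integers 0..N-1
    if set(mapping) != set(range(len(mapping))):
        problems.append(f"node ranks are not 0..{len(mapping) - 1}: {sorted(mapping)}")
    # Assumption: every node hosts the same number of GPUs
    if world_size % len(mapping) != 0:
        problems.append(f"world_size {world_size} is not divisible by {len(mapping)} nodes")
    for rank, pb_file in sorted(mapping.items()):
        path = pb_file if os.path.isabs(pb_file) else os.path.join(traces_dir, pb_file)
        if not os.path.isfile(path):
            problems.append(f"node {rank}: file not found: {path}")
    return problems


# Demo on a throwaway directory with two empty placeholder trace files
with tempfile.TemporaryDirectory() as demo_dir:
    for name in ("node0.xplane.pb", "node1.xplane.pb"):
        open(os.path.join(demo_dir, name), "wb").close()
    good_mapping = {0: "node0.xplane.pb", 1: "node1.xplane.pb"}
    bad_mapping = {0: "node0.xplane.pb", 2: "missing.xplane.pb"}
    good_problems = check_mapping(demo_dir, good_mapping, world_size=16)
    bad_problems = check_mapping(demo_dir, bad_mapping, world_size=16)

print("Good mapping problems:", good_problems)
print("Bad mapping problems:", bad_problems)
```

With real traces, the call would be `check_mapping(traces_dir, node_to_pb_file_mapping, world_size)`, adjusted to whatever key and value format the mapping actually uses.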

## Step 2: Initialize the Analyser

Create an instance of the `JaxNcclAnalyser` class. This will automatically load and parse the trace files.

In [None]:
# Initialize the analyser - this loads all trace data
print("Initializing JaxNcclAnalyser...")
nccl_analyser = JaxNcclAnalyser(traces_dir, node_to_pb_file_mapping, world_size)
print("✓ Analyser initialized successfully")

## Step 3: Build the Analysis DataFrame

Convert the trace data into a structured DataFrame for analysis. Each row represents a collective communication event on a specific GPU rank.

In [None]:
# Build the long-format dataframe
print("Building analysis dataframe...")
df = nccl_analyser.build_df_long()

print(f"✓ DataFrame created with {len(df)} events")
print(f"Columns: {list(df.columns)}")
print(f"Unique collective operations: {df['collective_name'].nunique()}")
print(f"GPU ranks represented: {sorted(df['gpu_rank'].unique())}")
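
A quick way to sanity-check the long-format DataFrame is to count events per collective and GPU rank. The sketch below runs on a small synthetic frame that contains only the `collective_name` and `gpu_rank` columns used above; the collective names are made up, and the real DataFrame has more columns.

```python
import pandas as pd

# Synthetic stand-in for the analyser's DataFrame (illustrative values only)
demo_df = pd.DataFrame({
    "collective_name": ["all-reduce.1"] * 4 + ["all-gather.2"] * 3,
    "gpu_rank": [0, 1, 2, 3, 0, 1, 2],
})

# Rows: collectives, columns: GPU ranks, values: number of events
counts = (
    demo_df.groupby(["collective_name", "gpu_rank"])
    .size()
    .unstack(fill_value=0)
)
print(counts)

# Collectives for which at least one rank recorded no event
incomplete = counts[(counts == 0).any(axis=1)]
print("Collectives missing on some ranks:", list(incomplete.index))
```

A zero count is not necessarily an error, since a collective may involve only a subset of ranks, but an unexpected gap can point to a wrong node-rank mapping or a missing trace file.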

## Step 4: Run Bandwidth Analysis

Now we'll analyze the bandwidth performance of all collective operations. This calculates:
- Bandwidth for each collective operation instance
- Statistics across different slices/iterations
- Performance metrics per replica group

In [None]:
# Analyze bandwidth for all collective operations
print("Running bandwidth analysis for all collective operations...")
bandwidth_results = nccl_analyser.analyze_all_collectives_from_df(df)

print(f"\n✓ Analysis complete for {len(bandwidth_results)} collective operations")
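
To make the per-group bandwidth figure concrete, here is a minimal sketch of one plausible calculation: data size divided by the duration of the slowest GPU in the replica group. This is an assumption about the methodology, not the analyser's actual code; in particular, the use of decimal gigabytes (1e9 bytes) and the helper name `algorithm_bandwidth_gbps` are hypothetical.

```python
def algorithm_bandwidth_gbps(data_size_bytes, durations_us):
    """Return (bandwidth in GB/s, slowest rank) for one replica group.

    durations_us maps GPU rank -> duration of the collective in microseconds.
    """
    if not durations_us:
        raise ValueError("no durations given")
    # The group is only done when its slowest member finishes
    slowest_rank = max(durations_us, key=durations_us.get)
    slowest_s = durations_us[slowest_rank] * 1e-6
    if slowest_s <= 0:
        raise ValueError("durations must be positive")
    return data_size_bytes / slowest_s / 1e9, slowest_rank


# Example: 64 MiB moved by a group of 4 GPUs (made-up durations)
demo_durations = {0: 480.0, 1: 500.0, 2: 470.0, 3: 495.0}
bw, slowest = algorithm_bandwidth_gbps(64 * 1024**2, demo_durations)
print(f"{bw:.2f} GB/s, limited by rank {slowest}")
```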

## Step 5: Analyze by Collective Types

Group the results by collective operation types (all-reduce, all-gather, etc.) to get aggregate performance statistics.

In [None]:
# Analyze performance by collective types
collective_types_data, summary_stats = nccl_analyser.analyze_collective_types_from_df(bandwidth_results)

# Display the summary table
nccl_analyser.display_summary_table(summary_stats)
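
The grouping idea can be illustrated on synthetic data. The sketch below assumes collective names of the form `<type>.<number>` and uses only the `avg_bandwidth` key of each result; both the naming scheme and the aggregation are illustrative assumptions, and the structure of `summary_stats` returned by the analyser may differ.

```python
import re

import pandas as pd

# Synthetic per-collective results (made-up names and values)
demo_results = {
    "all-reduce.1": {"avg_bandwidth": 120.0},
    "all-reduce.7": {"avg_bandwidth": 100.0},
    "all-gather.3": {"avg_bandwidth": 80.0},
}


def collective_type(name):
    """Strip a trailing '.<number>' suffix, e.g. 'all-reduce.7' -> 'all-reduce'."""
    return re.sub(r"\.\d+$", "", name)


rows = [
    {"collective_type": collective_type(name), "avg_bandwidth": res["avg_bandwidth"]}
    for name, res in demo_results.items()
]
demo_summary = (
    pd.DataFrame(rows)
    .groupby("collective_type")["avg_bandwidth"]
    .agg(["count", "mean", "min", "max"])
)
print(demo_summary)
```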

## Step 6: Detailed Analysis of a Specific Collective (Optional)

Let's examine one collective operation in detail to understand the analysis methodology.

In [None]:
# Pick the first collective for detailed analysis
if bandwidth_results:
    collective_name = next(iter(bandwidth_results))
    result = bandwidth_results[collective_name]
    
    print(f"=== DETAILED ANALYSIS: {collective_name} ===")
    print(f"Average bandwidth: {result['avg_bandwidth']:.2f} GB/s")
    print(f"Number of slices: {result['num_slices']}")
    print(f"Data size: {result['data_size_bytes']:,} bytes ({result['data_size_bytes']/(1024**3):.3f} GiB)")
    
    if result['bandwidths']:
        bandwidths = np.array(result['bandwidths'])
        print(f"Bandwidth range: {bandwidths.min():.2f} - {bandwidths.max():.2f} GB/s")
        print(f"Standard deviation: {bandwidths.std():.2f} GB/s")
    
    # Show detailed slice information
    print("\n=== SLICE-BY-SLICE BREAKDOWN ===")
    for i, slice_info in enumerate(result['slice_info'][:3]):  # Show first 3 slices
        print(f"\nSlice {i+1}: {slice_info['collective_id']}")
        print(f"  Number of replica groups: {slice_info['num_groups']}")
        print(f"  Slice average bandwidth: {slice_info['slice_avg_bandwidth']:.2f} GB/s")
        print(f"  Slice bandwidth range: {slice_info['slice_min_bandwidth']:.2f} - {slice_info['slice_max_bandwidth']:.2f} GB/s")
        
        # Show group details
        for group_detail in slice_info['group_details'][:2]:  # Show first 2 groups per slice
            print(f"    Group {group_detail['group_idx']}: {group_detail['bandwidth_gbps']:.2f} GB/s")
            print(f"      GPU group: {group_detail['gpu_group']}")
            print(f"      Slowest GPU: rank {group_detail['slowest_gpu_rank']} ({group_detail['duration_us']:.0f} μs)")
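
The same result structure can also be used to look for stragglers, that is, ranks that are repeatedly the slowest in their group. The sketch below works on a synthetic result dictionary that mirrors only the `slice_info`, `group_details`, `slowest_gpu_rank`, and `duration_us` keys used above; all values are made up.

```python
from collections import Counter

# Synthetic result for one collective (illustrative values only)
demo_result = {
    "slice_info": [
        {"group_details": [
            {"slowest_gpu_rank": 3, "duration_us": 510.0},
            {"slowest_gpu_rank": 5, "duration_us": 495.0},
        ]},
        {"group_details": [
            {"slowest_gpu_rank": 3, "duration_us": 505.0},
            {"slowest_gpu_rank": 6, "duration_us": 498.0},
        ]},
    ]
}

# Count how often each rank was the slowest member of its replica group
slowest_counts = Counter(
    group["slowest_gpu_rank"]
    for slice_info in demo_result["slice_info"]
    for group in slice_info["group_details"]
)

for rank, hits in slowest_counts.most_common():
    print(f"rank {rank}: slowest in {hits} group(s)")
```

A rank that dominates this count across many slices is a candidate for closer inspection, for example for a slow link or an overloaded host.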