<!--
Copyright (c) 2024 - 2025 Advanced Micro Devices, Inc. All rights reserved.

See LICENSE for license information.
-->

# JAX NCCL Analyser Example

This notebook demonstrates how to use the `JaxNcclAnalyser` class to analyze bandwidth performance of collective communication operations in JAX traces.

## Overview

The `JaxNcclAnalyser` provides functionality to:
- Load and parse JAX trace files (protobuf format)
- Extract collective communication events (all-reduce, all-gather, etc.)
- Calculate bandwidth metrics for each collective operation
- Analyze performance across different collective types
- Generate summary statistics 

## TraceLens Installation

For detailed installation instructions, dependencies, and setup requirements, please refer to the [TraceLens README.md](https://github.com/ROCm/TraceLens/blob/main/README.md) in the main repository.

## Step 0: Import Required Libraries

Import the necessary modules for JAX NCCL analysis and data manipulation.

In [1]:
from TraceLens import JaxNcclAnalyser
import pandas as pd
import numpy as np

## Step 1: Configure Input Parameters

### Trace Analysis Configuration

This section configures the parameters for analyzing distributed training traces from multiple nodes.

### Key Configuration Steps:

1. **Set trace directory**: Point to the root directory containing all trace files
2. **Configure world size**: Number of processes participating in distributed training setup (typically one process per GPU) 
3. **Create node-to-protobuf mapping**: Maps each node rank to its `.xplane.pb` trace file path

### Expected Directory Structure

The `traces_dir` should contain trace files organized by nodes, typically with this structure:

```
traces_dir/
├── node_0/
│   ├── plugins/
│   │   └── profile/
│   │       └── xyz/
│   │           └── xyz.xplane.pb
│   └── xla_dumps/
│       └── module_xyz.jit_train_step.xxxxx_gpu_after_optimizations.txt
├── node_1/
│   ├── plugins/
│   │   └── profile/
│   │       └── xyz/
│   │           └── xyz.xplane.pb
│   └── xla_dumps/
│       └── module_xyz.jit_train_step.xxxxx_gpu_after_optimizations.txt
└── ...
```

**Note**: The exact paths may vary depending on how you collected the traces. 

### Node Rank Mapping

The node rank determines how GPU ranks are calculated using the formula: 
`global_gpu_rank = node_rank * gpus_per_node + local_gpu_id`
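
As a quick check of the formula, the sketch below computes a global rank for a given node and local GPU. The helper name `global_gpu_rank` is illustrative only and is not part of TraceLens.

```python
def global_gpu_rank(node_rank: int, local_gpu_id: int, gpus_per_node: int = 8) -> int:
    """Map (node rank, local GPU id) to a global GPU rank using the formula above."""
    if not 0 <= local_gpu_id < gpus_per_node:
        raise ValueError(f"local_gpu_id must be in [0, {gpus_per_node}), got {local_gpu_id}")
    return node_rank * gpus_per_node + local_gpu_id


# Node 2, local GPU 5, with 8 GPUs per node: 2 * 8 + 5
print(global_gpu_rank(2, 5))  # 21
```

An incorrect node rank shifts every GPU rank on that node by a multiple of `gpus_per_node`, which is why the node-to-file mapping needs to be right.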

### Mapping Options:

**Option 1: Auto-generation** (Use with caution)
- A utility function, `get_node_rank_protobuf_mapping`, is available, but its output may not match your specific setup
- Always verify that the generated mapping matches your actual configuration

**Option 2: Manual mapping** (Recommended)
- Ensures accuracy by explicitly defining each node's trace file path
- Guarantees correct node rank assignment
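
If your traces follow the `node_<rank>/plugins/profile/<run>/<name>.xplane.pb` layout shown earlier, the manual mapping can be written as a short loop instead of being typed by hand. The following is a minimal sketch under that assumption; `build_node_mapping` is a hypothetical helper, not a TraceLens function, and its output should still be checked against your cluster configuration.

```python
import re
from pathlib import Path


def build_node_mapping(traces_dir: str) -> dict[int, str]:
    """Build {node_rank: xplane.pb path}, assuming directories named node_<rank>."""
    mapping = {}
    for node_dir in sorted(Path(traces_dir).glob("node_*")):
        match = re.fullmatch(r"node_(\d+)", node_dir.name)
        if match is None or not node_dir.is_dir():
            continue  # Skip anything that is not a node_<rank> directory
        pb_files = sorted(node_dir.glob("plugins/profile/*/*.xplane.pb"))
        if len(pb_files) != 1:
            # Refuse to guess when a node has zero or several candidate traces
            raise ValueError(f"{node_dir}: expected 1 .xplane.pb file, found {len(pb_files)}")
        mapping[int(match.group(1))] = str(pb_files[0])
    return mapping
```

The helper raises instead of picking a file when a node directory is ambiguous, so a wrong trace is never silently assigned to a node rank.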

### Common Configurations:
- **32 GPUs**: 4 nodes × 8 GPUs/node (MI300X nodes)
- **64 GPUs**: 8 nodes × 8 GPUs/node  (MI300X nodes)
- **128 GPUs**: 16 nodes × 8 GPUs/node (MI300X nodes)

In [None]:
# Configure trace analysis parameters
traces_dir = "/path/to/your/trace/files/directory"
# Modify world_size to match your setup
world_size = 32  # Number of processes participating in distributed training (typically one process per GPU)

# Option 1: Auto-generate mapping (verify output)
from TraceLens.NcclAnalyser.util import node_rank_to_protobuf_file_mapping
# node_to_pb_file_mapping, total_nodes = node_rank_to_protobuf_file_mapping.get_node_rank_protobuf_mapping(traces_dir)

# Option 2: Manual mapping (recommended)
node_to_pb_file_mapping = {
    # 0: "/path/to/traces_dir/node_0/plugins/profile/xyz/xyz.xplane.pb",
    # 1: "/path/to/traces_dir/node_1/plugins/profile/xyz/xyz.xplane.pb",
    # 2: "/path/to/traces_dir/node_2/plugins/profile/xyz/xyz.xplane.pb",
    # ... continue for all nodes
}

# Validation
gpus_per_node = 8
expected_nodes = world_size // gpus_per_node

print(f"Traces directory: {traces_dir}")
print(f"World size: {world_size} GPUs")
print(f"Expected nodes: {expected_nodes} (assuming {gpus_per_node} GPUs per node)")
print(f"Configured trace files: {len(node_to_pb_file_mapping)}")
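
The cell above only prints the configuration. A minimal sketch of turning those prints into explicit checks follows; `check_mapping` is a hypothetical helper, and it assumes node ranks are numbered `0` to `expected_nodes - 1`.

```python
import os


def check_mapping(node_to_pb_file_mapping, world_size, gpus_per_node=8):
    """Return a list of problems found; an empty list means all checks passed."""
    problems = []
    if world_size % gpus_per_node != 0:
        problems.append(f"world_size {world_size} is not a multiple of {gpus_per_node}")
    expected_nodes = world_size // gpus_per_node
    # Assumption: node ranks are contiguous and start at 0
    missing = sorted(set(range(expected_nodes)) - set(node_to_pb_file_mapping))
    if missing:
        problems.append(f"no trace file configured for node ranks {missing}")
    for rank, path in sorted(node_to_pb_file_mapping.items()):
        if not os.path.isfile(path):
            problems.append(f"node {rank}: file not found: {path}")
    return problems
```

Calling `check_mapping(node_to_pb_file_mapping, world_size, gpus_per_node)` before initializing the analyser surfaces an empty or incomplete mapping early.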

## Step 2: Initialize the Analyser

Create an instance of the `JaxNcclAnalyser` class. This will automatically load and parse the trace files.

### What Happens During Initialization

The `JaxNcclAnalyser` constructor automatically performs several key initialization steps:

1. **Load and parse trace data** from protobuf files using `JaxTraceToTree`
2. **Extract NCCL/RCCL communication events** from the trace data
3. **Build the collective information DataFrame** by parsing XLA dump files

### XLA dumps parsing

The `JaxNcclAnalyser` uses an integrated XLA parser to extract collective operation metadata from XLA dump files, providing crucial information about replica groups, data sizes, and operation types for accurate bandwidth calculations.

#### Expected Directory Structure

The analyser expects XLA dump files in this structure:
```
traces_dir/
├── node_0/
│   ├── plugins/profile/xyz/xyz.xplane.pb
│   └── xla_dumps/
│       └── *jit_train_step.xxxxx_gpu_after_optimizations.txt
├── node_1/
│   ├── plugins/profile/xyz/xyz.xplane.pb  
│   └── xla_dumps/
│       └── *jit_train_step.xxxxx_gpu_after_optimizations.txt
└── ...
```

#### Automatic XLA parsing

**Important**: By default, the analyser builds the XLA file mapping automatically, so no manual `node_to_xla_file_map` parameter is required. It does this by:
- Scanning `traces_dir` for XLA dump files matching `*jit_train_step.gfx942_gpu_after_optimizations.txt`
- Creating a node-to-XLA-file mapping based on directory structure similarity

**Manual XLA File Mapping (Optional)**: If the automatic detection doesn't work for your directory structure, you can manually specify the `node_to_xla_file_map` parameter when creating the `JaxNcclAnalyser` instance:
```python
node_to_xla_file_map = {
    0: "/path/to/node_0/xla_dumps/module_xyz.jit_train_step.xxxxx_gpu_after_optimizations.txt",
    1: "/path/to/node_1/xla_dumps/module_xyz.jit_train_step.xxxxx_gpu_after_optimizations.txt",
    # ... continue for all nodes
}
nccl_analyser = JaxNcclAnalyser(traces_dir, node_to_pb_file_mapping, world_size, node_to_xla_file_map=node_to_xla_file_map)
```

#### XLA Parser Capabilities

The integrated `XLACollectiveParser` provides:
- **Collective Detection**: Identifies all-reduce, all-gather, reduce-scatter, all-to-all, collective-permute operations
- **Replica Group Parsing**: Handles complex formats, including explicit groups such as `{{0,1,2,3},{4,5,6,7}}` and replica group specifications in the IotaTileAssignment format such as `[4,8]<=[32]`
- **Data Size Calculation**: Parses tensor shard dimensions and data types for precise bandwidth calculations
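
To make the two replica group formats concrete, here is a minimal sketch of how they can be expanded into lists of GPU ranks. It is not the `XLACollectiveParser` implementation; `parse_replica_groups` is a hypothetical name, and the sketch deliberately rejects transposed iota forms (those with a trailing `T(...)`) instead of guessing their meaning.

```python
import re


def parse_replica_groups(spec: str) -> list[list[int]]:
    """Expand a replica group string into explicit lists of ranks."""
    spec = spec.strip()
    iota = re.fullmatch(r"\[(\d+),(\d+)\]<=\[(\d+(?:,\d+)*)\](T\(.*\))?", spec)
    if iota:
        num_groups, group_size = int(iota.group(1)), int(iota.group(2))
        if iota.group(4):
            raise NotImplementedError("transposed iota forms are not handled in this sketch")
        total = 1
        for dim in iota.group(3).split(","):
            total *= int(dim)
        if total != num_groups * group_size:
            raise ValueError(f"inconsistent iota specification: {spec}")
        # Without a transpose, ranks 0..total-1 are split into consecutive groups
        return [list(range(g * group_size, (g + 1) * group_size)) for g in range(num_groups)]
    # Explicit form: collect every innermost {...} list of integers
    groups = re.findall(r"\{([\d,\s]+)\}", spec)
    return [[int(x) for x in g.split(",") if x.strip()] for g in groups]


print(parse_replica_groups("{{0,1,2,3},{4,5,6,7}}"))  # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(len(parse_replica_groups("[4,8]<=[32]")))       # 4
```

Under this reading, `[4,8]<=[32]` describes 4 groups of 8 consecutive ranks, which matches the 4 nodes × 8 GPUs configuration used in this notebook.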

### Processing Details

The analyser will:
- Process each node's protobuf trace file (`.xplane.pb`)
- Filter for collective communication events
- Parse XLA dump files to extract replica groups and data sizes
- Build the foundational data structures needed for bandwidth analysis

### Parameters

- **`traces_dir`**: Directory containing trace files
- **`node_to_pb_file_mapping`**: Mapping from nodes to protobuf file paths  
- **`world_size`**: Total number of processes in the distributed setup
- **`node_to_xla_file_map`**: Optional mapping to XLA files (**automatically built if None**)

In [None]:
# Initialize the analyser - this loads all trace data
print("Initializing JaxNcclAnalyser...")
nccl_analyser = JaxNcclAnalyser(traces_dir, node_to_pb_file_mapping, world_size)
print("✓ Analyser initialized successfully")

## Step 3: Build the Analysis DataFrame

Convert the loaded trace data into a structured long-format DataFrame for analysis. This core function processes all the trace events from each node and constructs a comprehensive table where **each row represents a single collective communication event on a specific GPU rank**.

### What `build_df_long()` Does:

**Data Transformation**: Processes the internal `node_to_trace_data` storage and flattens it into a structured table format

**Event Enrichment**: For each collective communication event, extracts and combines:
- **Timing information**: Timestamps (`ts`) and durations (`dur`) from trace events
- **Process mapping**: Node ID, GPU rank, and pid for distributed setup tracking  
- **Collective metadata**: Operation names (`collective_name`), HLO modules, and correlation IDs
- **XLA dump lookup**: Replica groups and the amount of data exchanged in each collective event, looked up in the DataFrame parsed from the XLA dumps

#### Core Processing Steps

The `build_df_long()` method performs these key operations:

1. **Event Extraction**: Processes each collective communication event from all nodes
2. **Row Construction**: For each event, creates a row with:
   - Basic info: `node`, `gpu_rank`, `pid`, `ts` (timestamp), `dur` (duration)
   - Collective metadata: `collective_name`, `hlo_module`, `correlation_id`
   - XLA data: `replica_groups`, `data(bytes)` (looked up from parsed XLA dumps)

3. **In-Group Indexing**: Groups events by `collective_name`, `pid`, and `node`, then assigns sequential indices in timestamp order.
4. **Collective ID Generation**: Creates unique identifiers using the format `{collective_name}_{index_in_group}`
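
The indexing and ID generation steps can be sketched with pandas as follows. This is an illustration on made-up data, not the code inside `build_df_long()`; the column names follow the ones listed above, and `index_in_group` is assumed to be the name of the intermediate column.

```python
import pandas as pd

# Two processes on one node, each executing the same collective twice
events = pd.DataFrame(
    {
        "node": [0, 0, 0, 0],
        "pid": [1, 1, 2, 2],
        "collective_name": ["all-gather-start.1"] * 4,
        "ts": [200.0, 100.0, 210.0, 105.0],
    }
)

# Sort by timestamp so that cumcount() numbers events in time order within each group
events = events.sort_values("ts")
events["index_in_group"] = events.groupby(["collective_name", "pid", "node"]).cumcount()
events["collective_id"] = (
    events["collective_name"] + "_" + events["index_in_group"].astype(str)
)

print(events.sort_values(["pid", "ts"])[["pid", "ts", "collective_id"]])
```

Because the index is assigned per process, the first occurrence on every GPU receives the same `collective_id`, which is what later allows all participants in one collective to be grouped together.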

#### Key Outputs

- **Per-GPU Events**: Each row represents one collective event on a specific GPU rank
- **Unique IDs**: `collective_id` enables grouping all GPUs participating in the same collective operation

**Output**: A comprehensive DataFrame with columns including `node`, `gpu_rank`, `pid`, `ts`, `dur`, `collective_name`, `replica_groups`, `data(bytes)`, `collective_id`.

In [None]:
# Build the long-format dataframe
print("Building analysis dataframe...")
df = nccl_analyser.build_df_long()

## Step 4: Run Bandwidth Analysis

Now we'll analyze the bandwidth performance of all collective operations. This calculates:
- Bandwidth for each collective operation instance
- Statistics across slices (the multiple communication events associated with each HLO op)
- Performance metrics per replica group

### Core Processing in `analyze_all_collectives_from_df()`

This method performs **comprehensive bandwidth analysis** of all collective communication operations found in the DataFrame. Here's the detailed breakdown of the core processing:

#### 1. **Collective Discovery & Iteration**
- **Extracts unique collective operations** from the DataFrame using `df["collective_name"].unique()`
- **Iterates through each collective type** (e.g., `all-reduce`, `all-gather`, `reduce-scatter`)

#### 2. **Per-Collective Bandwidth Calculation**
For each collective operation, the method calls `_calculate_collective_bandwidth_from_df()` which performs:

##### **Slice Analysis** 
- **Groups by collective_id**: Each collective operation typically has multiple "slices" (communication events)
- **Per-slice processing**: Analyzes bandwidth for each slice separately

##### **Replica Group Bandwidth Calculation**
- **Identifies participating GPUs**: Uses the replica groups parsed from the XLA dumps
- **Calculates timing**: Finds the fastest GPU in each replica group
- **Algorithmic bandwidth formula**: `data_bytes / fastest_gpu_duration`
- **Bus bandwidth calculation**: Applies a collective-specific scaling factor to the algorithmic bandwidth
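
The calculation for a single replica group can be sketched as below. The scaling factors are an assumption: they follow the convention documented for the nccl-tests benchmarks (`2(n-1)/n` for all-reduce, `(n-1)/n` for all-gather, reduce-scatter, and all-to-all) and may differ from the factors TraceLens applies. All function names and sample values are illustrative.

```python
def algorithmic_bandwidth_gbps(data_bytes: float, duration_us: float) -> float:
    """bytes / microseconds equals MB/s; dividing by 1e3 gives GB/s (1 GB = 1e9 bytes)."""
    return data_bytes / duration_us / 1e3


def bus_bandwidth_scaler(collective_type: str, group_size: int) -> float:
    """Scaling factor from algorithmic to bus bandwidth (nccl-tests convention, assumed)."""
    n = group_size
    if collective_type == "all-reduce":
        return 2 * (n - 1) / n
    if collective_type in ("all-gather", "reduce-scatter", "all-to-all"):
        return (n - 1) / n
    return 1.0  # No scaling assumed for other collectives


# Made-up per-rank durations (in microseconds) for one replica group
durations_us = {0: 1250.0, 1: 1000.0, 2: 1100.0}
fastest_rank = min(durations_us, key=durations_us.get)

alg_bw = algorithmic_bandwidth_gbps(1_000_000, durations_us[fastest_rank])
bus_bw = alg_bw * bus_bandwidth_scaler("all-gather", group_size=8)
print(f"fastest rank: {fastest_rank}, alg: {alg_bw:.3f} GB/s, bus: {bus_bw:.3f} GB/s")
```

Under this convention, the scaler for an 8-GPU all-gather or reduce-scatter group is `7/8 = 0.875`.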


##### **Example Processing Flow for a collective**:

**All values here are placeholders that demonstrate the structure.**
```
Analyzing: xyz.28 (collective)
Number of slices: xy

  Slice: xyz.28_0
    Available GPU ranks in data: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]
      Analyzing group 0: [0, 1, 2, 3, 4, 5, 6, 7]
        Matching GPUs in data: [0, 1, 2, 3, 4, 5, 6, 7]
        Algorithmic Bandwidth: xyz.xy GB/s
        Bus Bandwidth: abc.ab GB/s (scaler: 0.ab)
        Fastest GPU: rank 3, duration: xyz.cd μs
        Group data size: xyz bytes (0.ab GB)
        Algorithmic bytes: wxyz (0.abc GB)
        Replica group size: 8
        ...

    Slice summary:
      Algorithmic BW: avg=xyz.xy GB/s
      Bus BW: avg=xyz.xy GB/s      

    ...

Overall Results:
  Average algorithmic bandwidth: xyz.xy GB/s
  Average bus bandwidth: abc.ab GB/s
```


#### 3. **Statistical Aggregation**
For each collective operation, computes:
- **Per-slice bandwidths**: Array of bandwidth measurements across each slice
- **Overall averages**: Mean bandwidth across all slices
- **Data size information**: Bytes exchanged per operation 


#### 4. **Results Structure**
Returns a nested dictionary, keyed by collective name, that you can use for further analysis:

**All values here are placeholders that demonstrate the structure.**
```python
{
  'xyz.28': {
    'bandwidths': [100, 105, 110, 115, 120...],  # Per-slice algorithmic BW (GB/s)
    'bus_bandwidths': [90, 95, 100, 105, 110 ...],  # Per-slice bus BW (GB/s)  
    'avg_bandwidth': 100,                       # Overall algorithmic average (GB/s)
    'avg_bus_bandwidth': 98,                   # Overall bus average (GB/s)
    'num_slices': 35,                              # Number of slices
    'data_size_bytes': 100000,                  # 0.095 MB per group
    'slice_info': [                                # Detailed per-slice data
      {
        'collective_id': 'xyz.28_0',
        'group_bandwidths': [110, 115, 120, 125],  # Per replica group BW
        'group_bus_bandwidths': [90, 95, 100, 105],  # Per replica group bus BW
        'group_details': [
          {
            'group_idx': 0,
            'gpu_group': [0, 1, 2, 3, 4, 5, 6, 7],
            'bandwidth_gbps': 110,
            'algorithmic_bandwidth_gbps': 110,
            'bus_bandwidth_gbps': 90,
            'bus_bandwidth_scaler': 0.875,
            'duration_us': 1000,
            'fastest_gpu_rank': 3,...        
          }...
        ],
        'slice_avg_bandwidth': 117,
        'slice_min_bandwidth': 100,
        'slice_max_bandwidth': 120,
        'num_groups': 4
      }...
    ]
  }...
}
```
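
For further analysis it is often easier to work with flat rows than with the nested dictionary. The sketch below walks the structure shown above and emits one row per replica group; `flatten_group_details` is a hypothetical helper that relies only on the keys shown in the placeholder example.

```python
def flatten_group_details(bandwidth_results: dict) -> list[dict]:
    """Return one flat record per (collective, slice, replica group)."""
    rows = []
    for name, result in bandwidth_results.items():
        for slice_info in result.get("slice_info", []):
            for group in slice_info.get("group_details", []):
                rows.append(
                    {
                        "collective_name": name,
                        "collective_id": slice_info["collective_id"],
                        "group_idx": group["group_idx"],
                        "algorithmic_bandwidth_gbps": group["algorithmic_bandwidth_gbps"],
                        "bus_bandwidth_gbps": group["bus_bandwidth_gbps"],
                        "duration_us": group["duration_us"],
                    }
                )
    return rows
```

The returned list can be passed directly to `pd.DataFrame(rows)` for filtering, sorting, or plotting.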

In [None]:
# Analyze bandwidth for all collective operations
print("Running bandwidth analysis for all collective operations...")
bandwidth_results = nccl_analyser.analyze_all_collectives_from_df(df)
print(f"\n✓ Analysis complete for {len(bandwidth_results)} collective operations")

## Step 5: Analyze by Collective Types

Group the results by collective operation types (all-reduce, all-gather, etc.) to get aggregate performance statistics.

### Core Processing Functions

**`analyze_collective_types_from_df(bandwidth_results)`**:
- **Groups collectives by type**: Categorizes individual collective operations (e.g., "all-reduce-start.1", "all-reduce-start.2") into collective types ("all-reduce")
- **Aggregates bandwidth data**: Combines all bandwidth measurements for each collective type across all slices and operations
- **Computes statistics**: Calculates mean for both algorithmic and bus bandwidth per collective type
- **Returns**: Dictionary of collective type data and structured summary statistics
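
The grouping and averaging described above can be approximated as follows. This is a sketch, not the TraceLens implementation: `collective_type`, `summarize_by_type`, the prefix-matching rule, and the output keys are all assumptions made for illustration.

```python
from collections import defaultdict
from statistics import mean

KNOWN_TYPES = ("all-reduce", "all-gather", "reduce-scatter", "all-to-all", "collective-permute")


def collective_type(name: str) -> str:
    """Map a name such as 'all-reduce-start.1' to its type by prefix matching."""
    return next((t for t in KNOWN_TYPES if name.startswith(t)), "other")


def summarize_by_type(bandwidth_results: dict) -> dict:
    """Pool per-slice bandwidths by collective type and average them."""
    pooled = defaultdict(lambda: {"bandwidths": [], "bus_bandwidths": [], "count": 0})
    for name, result in bandwidth_results.items():
        entry = pooled[collective_type(name)]
        entry["bandwidths"].extend(result["bandwidths"])
        entry["bus_bandwidths"].extend(result["bus_bandwidths"])
        entry["count"] += 1
    return {
        ctype: {
            "avg_bandwidth": mean(entry["bandwidths"]),
            "avg_bus_bandwidth": mean(entry["bus_bandwidths"]),
            "num_collectives": entry["count"],
        }
        for ctype, entry in pooled.items()
        if entry["bandwidths"]  # Skip types without any measurements
    }
```

Note that pooling slices before averaging weights each collective by its number of slices; averaging the per-collective means instead would give a different result.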

**`display_summary_table(summary_stats)`**:
- **Formats results**: Creates a formatted table displaying performance statistics for each collective type
- **Shows key metrics**: Displays average algorithmic and bus bandwidths and operation counts
- **Outputs**: A summary table with collective performance comparison

In [None]:
# Analyze performance by collective types
collective_types_data, summary_stats = nccl_analyser.analyze_collective_types_from_df(bandwidth_results)
# Display the summary table
nccl_analyser.display_summary_table(summary_stats)

## Step 6: Detailed Analysis of Specific Collective (OPTIONAL)

Let's examine one collective operation in detail to understand the analysis methodology.

In [None]:
# Analyze the specific collective: xyz
target_collective = "xyz" # Modify as per your case

if bandwidth_results:
    # Search for the target collective
    collective_name = None
    for name in bandwidth_results.keys():
        if target_collective in name:
            collective_name = name
            break
    
    if collective_name:
        result = bandwidth_results[collective_name]
        
        print(f"=== DETAILED ANALYSIS: {collective_name} ===")
        print(f"Average algorithmic bandwidth: {result['avg_bandwidth']:.2f} GB/s")
        print(f"Average bus bandwidth: {result['avg_bus_bandwidth']:.2f} GB/s")
        print(f"Number of slices: {result['num_slices']}")

        
        if result['bandwidths']:
            bandwidths = np.array(result['bandwidths'])
            bus_bandwidths = np.array(result['bus_bandwidths'])
            print(f"Algorithmic bandwidth range: {bandwidths.min():.2f} - {bandwidths.max():.2f} GB/s")
            print(f"Bus bandwidth range: {bus_bandwidths.min():.2f} - {bus_bandwidths.max():.2f} GB/s")
            print(f"Algorithmic bandwidth std dev: {bandwidths.std():.2f} GB/s")
            print(f"Bus bandwidth std dev: {bus_bandwidths.std():.2f} GB/s")
        
        # Show detailed slice information
        print("\n=== SLICE-BY-SLICE BREAKDOWN ===")
        for i, slice_info in enumerate(result['slice_info'][:3]):  # Show first 3 slices
            print(f"\nSlice {i+1}: {slice_info['collective_id']}")
            print(f"  Number of replica groups: {slice_info['num_groups']}")
            print(f"  Slice average algorithmic bandwidth: {slice_info['slice_avg_bandwidth']:.2f} GB/s")
            print(f"  Slice algorithmic bandwidth range: {slice_info['slice_min_bandwidth']:.2f} - {slice_info['slice_max_bandwidth']:.2f} GB/s")
            
            # Show group details
            for group_detail in slice_info['group_details'][:2]:  # Show first 2 groups per slice
                print(f"    Group {group_detail['group_idx']}: {group_detail['algorithmic_bandwidth_gbps']:.2f} GB/s (alg), {group_detail['bus_bandwidth_gbps']:.2f} GB/s (bus)")
                print(f"      GPU group: {group_detail['gpu_group']}")
                print(f"      Fastest GPU: rank {group_detail['fastest_gpu_rank']} ({group_detail['duration_us']:.0f} μs)")
    else:
        print(f"❌ Collective '{target_collective}' not found in bandwidth results")
        print("Available collectives:")
        for name in sorted(bandwidth_results.keys()):
            print(f"  - {name}")
else:
    print("❌ No bandwidth results available")