In [None]:
from collections import defaultdict
import functools
from jax_nsys import (
    align_profiler_data_timestamps,
    apply_warmup_heuristics,
    display_flamegraph,
    ensure_compiled_protos_are_importable,
    generate_compilation_statistics,
    load_profiler_data,
    remove_autotuning_detail,
    xla_module_metadata,
)
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Make sure that the .proto files under protos/ have been compiled to .py, and
# that those generated .py files are importable.
compiled_dir = ensure_compiled_protos_are_importable()

In [None]:
# Load the runtime profile data, stripping out some of the autotuner detail
all_data = remove_autotuning_detail(load_profiler_data())
# Put GPU timestamps from profiles collected by different Nsight Systems
# processes onto a common timeline; the metadata describes the correction
all_data, alignment_metadata = align_profiler_data_timestamps(all_data)
# Split the profile data, heuristically, into initialisation and steady-state
# execution
init, steady_state = apply_warmup_heuristics(all_data)

## Data format

First, look at the high-level format of the profile data frames.
The `module` frame has a single row for each XLA module execution, which typically corresponds to a single JITed JAX function:

In [None]:
# The module frame must exist; as the cell's last expression it is rendered by
# the notebook.
assert steady_state.module is not None
steady_state.module

This data frame has a three-level index:
- `ProgramId` is an integer ID that uniquely identifies the XLA module
- `ProgramExecution`: this row is the `ProgramExecution`-th execution of the module within the profiles. You may see this starting from 1, not 0, because `apply_warmup_heuristics` (called above) assigns early executions to the initialisation data (`init`) rather than to `steady_state`.
- `Device` is the global (across multiple nodes and processes) index of the GPU on which the module execution took place

The columns are as follows:
- `Name`: the name of the XLA module; this should always be the same for a given `ProgramId`
- `NumThunks`: the number of thunks executed inside this module execution
- `ProjStartMs`: the timestamp of the start of the module execution on the GPU, in milliseconds
- `ProjDurMs`: the duration of the module execution on the GPU, in milliseconds
- `OrigStartMs`: the timestamp of the start of the module launch **on the host**, in milliseconds. *i.e.* `ProjStartMs-OrigStartMs` is something like the launch latency of the first kernel
- `OrigDurMs`: the duration of the module launch **on the host**, in milliseconds
- `LocalDevice`: the index within the node/slice of the GPU on which the module execution took place
- `Process`: the global (across multiple nodes) index of the process
- `Slice`: the global index of the node/slice; devices within the same node/slice should have faster interconnects than to devices in different slices

Another profile data frame for GPU execution is `thunk`, which has a single row for each XLA thunk.
Loosely, each XLA module contains a series of thunks, and each thunk launches a GPU kernel.
In reality, thunks can be nested and may launch multiple kernels, but this data frame still provides the most granular distribution available of GPU execution time across the XLA module:

In [None]:
# The thunk frame must exist; as the cell's last expression it is rendered by
# the notebook.
assert steady_state.thunk is not None
steady_state.thunk

Here the index has four levels. `ProgramId`, `ProgramExecution` and `Device` have the same meanings as in the `module` frame.
The fourth level (in the 3rd position) shows that this row is the `ThunkIndex`-th thunk within the `ProgramExecution`-th execution of XLA module `ProgramId`.
Note that a given thunk can be executed multiple times within the same module, so indexing on the thunk name would not be unique.

The columns are as follows:
- `Name`: the name of the thunk; this should be unique within a given `ProgramId` and can be used as a key to look up XLA metadata
- `ProjStartMs`, `OrigStartMs`, `OrigDurMs`: see above; same meaning as in the `module` frame.
- `Communication`: does this thunk represent communication between GPUs (*i.e.* a NCCL collective)? XLA overlaps communication and computation kernels, and `load_profiler_data` triggers an overlap calculation. `ProjDurMs` for a communication kernel shows only the duration that was **not** overlapped with computation kernels, while `ProjDurHiddenMs` shows the duration that **was** overlapped.
- `ThunkExecution`: this row is the `ThunkExecution`-th execution of this thunk for this `(ProgramId, ProgramExecution, Device)`

The third data frame does not show any GPU execution, but is rather a host-side trace:

In [None]:
# Compilation is a host-side trace taken from the initialisation data (`init`),
# not from `steady_state`; as the cell's last expression it is rendered by the
# notebook.
assert init.compile is not None
init.compile

Here the index has two levels; `ProfileName` is important when multiple reports are being analysed together (*i.e.* using `nsys-jax-combine` having run multiple `nsys-jax` processes), as the `RangeId` values referred to in `ParentId` and `RangeStack` are not unique across different `ProfileName` values.

The fourth data frame is derived from the "thunk" frame, but focuses on device-device collective communication:

In [None]:
# The communication frame must exist; as the cell's last expression it is
# rendered by the notebook.
assert steady_state.communication is not None
steady_state.communication

The index structure, and many of the columns, are equivalent to those of the `thunk` frame. Additional columns are:

- `MessageSize`: the message size of the collective in bytes; this aims to follow the same conventions as the NCCL tests
- `Collective`: the type of collective communication
- `CollectiveSize`: the number of devices participating in each instance of the collective. For example, if a JAX program is executing across 8 devices, but a particular collective involves two sub-groupings of 4 devices communicating with each other, `CollectiveSize` would be 4.
- `AlgorithmBandwidthGBPerSec` and `BusBandwidthGBPerSec`: see https://github.com/NVIDIA/nccl-tests/blob/master/doc/PERFORMANCE.md#bandwidth; note the units are GB (base 1000) not GiB (base 1024)

Before going further, show the corrections that were applied by `align_profiler_data_timestamps` above:

In [None]:
# If no collectives were profiled, this metadata is not available
if len(alignment_metadata):
    collective_size = alignment_metadata["collective_size"]
    # Per-device clock-skew samples, indexed by device. Build independent
    # placeholders: `[[]] * n` would alias one list object into every slot.
    seen_devices = [False] * collective_size
    data: list = [[] for _ in range(collective_size)]
    for device, delta_ms in alignment_metadata["collective_end_time_skews_ms"].groupby(
        "Device"
    ):
        assert not seen_devices[device], f"Device {device} seen more than once"
        seen_devices[device] = True
        data[device] = delta_ms
    # violinplot cannot summarise an empty dataset, so only draw devices that
    # actually have skew samples, keeping each at its own device index on the
    # x axis (identical to the full range when every device is present).
    plotted_devices = [device for device, seen in enumerate(seen_devices) if seen]
    fig, ax = plt.subplots()
    ax.violinplot(
        [data[device] for device in plotted_devices], positions=plotted_devices
    )
    ax.set_title(
        f"Estimated clock skew from N={alignment_metadata['collective_size']} collectives"
    )
    ax.set_xlabel("Device")
    ax.set_ylabel("Clock skew [ms]")

In [None]:
# One row per XLA module (ProgramId) summarising its steady-state executions,
# ordered from the largest to the smallest total GPU time.
module_stats = steady_state.module.groupby("ProgramId").agg(
    {
        "Name": ("count", "first"),
        "ProjDurMs": ("sum", "std"),
        "NumThunks": ("mean", "std"),
    }
)
module_stats = module_stats.sort_values(("ProjDurMs", "sum"), ascending=False)

# Build a short list of the XLA modules that dominate the execution time of
# this application: walk the modules from most to least expensive and keep
# adding them until `top_module_threshold` of the total time is covered.
top_module_threshold = 0.99
top_module_ids = []
frac_seen = 0.0
module_total_time = module_stats[("ProjDurMs", "sum")].sum()
print("     Active GPU time   Wait #Exec. #Thunks   Module name")
program_id: int
for program_id, row in module_stats.iterrows():
    module_time_ms = row[("ProjDurMs", "sum")]
    # Fraction of this module's GPU time lost to devices launching the same
    # execution at different times. This only matters when the module contains
    # collectives; otherwise a placeholder is printed.
    if steady_state.thunk.loc[program_id, "Communication"].any():
        # Latest/earliest start over devices for each program execution...
        start_times = (
            steady_state.module.loc[program_id, "ProjStartMs"]
            .groupby("ProgramExecution")
            .agg(("max", "min"))
        )
        # ...with the spread summed over all program executions
        start_spread_ms = (start_times["max"] - start_times["min"]).sum()
        wait_frac = "{:6.2%}".format(start_spread_ms / module_time_ms)
    else:
        wait_frac = "   -- "
    module_frac = module_time_ms / module_total_time
    # The module that crosses the threshold is still included
    if frac_seen < top_module_threshold:
        top_module_ids.append(program_id)
    frac_seen += module_frac
    print(
        " {:7.2%} {:9.2f}ms {} {:6} {:5.0f}±{:<3.0f} {} ({})".format(
            module_frac,
            module_time_ms,
            wait_frac,
            row[("Name", "count")],
            row[("NumThunks", "mean")],
            row[("NumThunks", "std")],
            row[("Name", "first")],
            program_id,
        )
    )
print(
    f"{top_module_threshold:.1%}+ of execution time accounted for by module ID(s): {' '.join(map(str, top_module_ids))}"
)

In [None]:
# Summarise all the observed compilation time; this averages over all profiled compilations and handles parallel compilation
# The summary below aggregates this frame's DurNonChildMs column (milliseconds).
compile_time_ms = generate_compilation_statistics(init.compile)


def clean_compilation_range_name(name):
    """
    Map a compilation range name onto the label it is summarised under.

    This defines how compilation phases are grouped, e.g. whether XLA's passes
    are kept separate or lumped in together.
    """
    autotuner_label = "XlaAutotunerMeasurement"
    if name.startswith(autotuner_label):
        # Drop the name of the HLO op being autotuned
        return autotuner_label
    # Parallel backend compilation leads to these split_module names for
    # XlaEmitGpuAsm and XlaOptimizeLlvmIr
    return name.removesuffix(":#module=split_module#")


# Summarise the results more by combining together different passes
compile_summary = (
    compile_time_ms.groupby(clean_compilation_range_name)
    .agg("sum")
    .sort_values(by=["DurNonChildMs"], ascending=False)
)
total_compile_time = compile_summary["DurNonChildMs"].sum()  # milliseconds
# Print out the largest entries adding up to at least this fraction of the total
threshold = 0.97
compile_summary["FracNonChild"] = compile_summary["DurNonChildMs"] / total_compile_time
# DurNonChildMs is in milliseconds, so scale by 1e-3 to print seconds
print(f"Top {threshold:.0%}+ of {total_compile_time*1e-3:.2f}s compilation time")
# Keep every entry whose *preceding* cumulative fraction is still below the
# threshold. This includes the entry that crosses the threshold, so at least
# `threshold` of the time is shown, and the top entry is shown even when it
# alone exceeds the threshold.
frac_before = compile_summary["FracNonChild"].cumsum().shift(fill_value=0.0)
for row in compile_summary[frac_before < threshold].itertuples():
    print(f"{row.FracNonChild:6.2%} {row.DurNonChildMs*1e-3:.2f}s {row.Index}")

In [None]:
# Summarise the thunks/kernels that have been seen. Here we do respect the
# `top_module_ids` list derived above, as in particular the definition (3) of
# the total runtime is sensitive to outliers. This is probably a reasonable
# default, but it is still a heuristic.
top_module_thunk_df = steady_state.thunk.loc[top_module_ids]
top_module_df = steady_state.module.loc[top_module_ids].assign(
    ProjEndMs=lambda frame: frame["ProjStartMs"] + frame["ProjDurMs"]
)
# Total GPU time of each distinct (ProgramId, thunk name), largest first
thunk_summary = (
    top_module_thunk_df.groupby(["ProgramId", "Name"])
    .agg({"ProjDurMs": "sum"})
    .sort_values("ProjDurMs", ascending=False)
)

# Three definitions of the total runtime:
# 1. the sum of all thunk/kernel runtimes, after overlap subtraction
# 2. the sum of all module runtimes, which is (1) plus any time the GPU is idle
#    *during* execution of a module
# 3. the time from the first thunk in the first module starting to execute on
#    the GPU and the last thunk in the last module finishing its execution on
#    the GPU, which is (2) plus any time the GPU is idle between execution of
#    modules.
# (3) can easily include compilation and initialisation time if the profile is
# not collected in a targeted manner.
#
# When multiple GPUs are driven by the same process, (3) is computed per GPU
# and then summed over GPUs.
all_thunks_active_ms = thunk_summary["ProjDurMs"].sum()  # (1)
all_modules_active_ms = top_module_df["ProjDurMs"].sum()  # (2)
top_module_duration_df = top_module_df.groupby("Device").agg(
    {"ProjStartMs": "min", "ProjEndMs": "max"}
)
all_modules_wall_ms = (
    top_module_duration_df["ProjEndMs"]
    .sub(top_module_duration_df["ProjStartMs"])
    .sum()
)  # (3)

# Accumulators onto which the thunk runtimes are projected for the different
# presentations below.
op_runtime: dict[str, float] = defaultdict(float)
op_name_runtime: dict[tuple[str, ...], float] = defaultdict(float)
src_runtime: dict[tuple[str, ...], float] = defaultdict(float)

# Synthetic top-level entries that structure the source-code view
gpu_active = ["[GPU active]"]
gpu_active_unknown = [*gpu_active, "[Unknown]"]
gpu_idle_inside_modules = ["[GPU idle during module execution]"]
gpu_idle_between_modules = ["[GPU idle between module executions]"]


@functools.cache
def instructions_and_frames(hlo_module, instruction_name):
    """
    Collect metadata and stack frames for an HLO instruction and its callees.

    Returns ``(opcode, metadata, frames)``. ``opcode`` belongs to the
    instruction called ``instruction_name``. ``metadata`` and ``frames`` are
    parallel lists covering that instruction first, followed by every
    instruction of each computation it calls. Results are memoised with
    ``functools.cache``, so both arguments must be hashable.
    """
    _, instruction = hlo_module.find_instruction(instruction_name)
    root_proto = instruction.proto()
    protos = [root_proto]
    for computation_id in root_proto.called_computation_ids:
        protos.extend(hlo_module.find_computation(computation_id).instructions)
    metadata = [proto.metadata for proto in protos]
    frames = [hlo_module.get_stack_frames(md.stack_frame_id) for md in metadata]
    return root_proto.opcode, metadata, frames


for thunk_row in thunk_summary.itertuples():
    program_id, thunk_name = thunk_row.Index
    # policy="all" means we may get a set of HloProto instead of a single one, if
    # nsys-jax-combine was used and the dumped metadata were not bitwise identical
    hlo_modules = xla_module_metadata(program_id, policy="all")
    thunk_opcode, inst_metadata, inst_frames = hlo_modules.unique_result(
        lambda proto: instructions_and_frames(proto, thunk_name)
    )
    thunk_ms = thunk_row.ProjDurMs

    # Attribute the runtime to the opcode, i.e. fusion/custom-call/...
    op_runtime[thunk_opcode] += thunk_ms

    # Attribute the runtime to source locations. This is inherently approximate:
    # several instructions and stack traces share each unit of GPU runtime and
    # there is no known weighting, so the instruction itself and every
    # instruction of its called computations with non-empty metadata get equal
    # weight.
    Location = tuple[str, ...]
    # Candidate locations in decreasing order of preference. The first
    # non-empty set receives the whole runtime, split evenly between its
    # members; the last entry is never empty.
    src_runtime_preferences: tuple[set[Location], ...] = (
        set(),  # gpu_active + source frames + op_name
        set(),  # gpu_active_unknown + op_name
        {tuple(gpu_active_unknown)},  # fallback
    )
    op_name_runtime_preferences: tuple[set[Location], ...] = (
        set(),  # gpu_active + "/"-separated op_name components
        {tuple(gpu_active_unknown)},  # fallback
    )
    for meta, frames in zip(inst_metadata, inst_frames):
        op_name = [meta.op_name] if len(meta.op_name) else []
        if len(frames):
            src_runtime_preferences[0].add(tuple(gpu_active + frames + op_name))
        if op_name:
            src_runtime_preferences[1].add(tuple(gpu_active_unknown + op_name))
            op_name_runtime_preferences[0].add(
                tuple(gpu_active + op_name[0].split("/"))
            )

    for preferences, runtime in (
        (src_runtime_preferences, src_runtime),
        (op_name_runtime_preferences, op_name_runtime),
    ):
        chosen = next(locs for locs in preferences if len(locs) > 0)
        share_ms = thunk_ms / len(chosen)
        for loc in chosen:
            runtime[loc] += share_ms


# GPU idle time *inside* module executions is (2) - (1), and idle time
# *between* module executions is (3) - (2). Neither difference is guaranteed to
# be non-negative, so clamp both at zero and use the same clamped values in
# every view below (opcode summary and both flamegraphs).
gpu_idle_inside_modules_ms = max(0.0, all_modules_active_ms - all_thunks_active_ms)
gpu_idle_between_modules_ms = max(0.0, all_modules_wall_ms - all_modules_active_ms)

# Use total time (2) when summarising over opcodes, as it's not trivial to
# collapse away the difference between (2) and (3).
op_runtime["_total"] = all_modules_active_ms
op_runtime["GPU idle during modules"] = gpu_idle_inside_modules_ms

# When summarising over source locations use total time (3) as the top level of
# the hierarchy, assuming that the visualisation will be able to handle this.
src_runtime[tuple(gpu_idle_inside_modules)] = gpu_idle_inside_modules_ms
src_runtime[tuple(gpu_idle_between_modules)] = gpu_idle_between_modules_ms
op_name_runtime[tuple(gpu_idle_inside_modules)] = gpu_idle_inside_modules_ms
op_name_runtime[tuple(gpu_idle_between_modules)] = gpu_idle_between_modules_ms

In [None]:
print("GPU runtime by operation type")
# Largest runtime first; keys starting with "_" (e.g. "_total") are
# bookkeeping entries rather than operation types.
for op_type, runtime_ms in sorted(op_runtime.items(), key=lambda item: -item[1]):
    if not op_type.startswith("_"):
        share_pct = 100.0 * runtime_ms / op_runtime["_total"]
        print(" {:5.2f}% {:10.2f}ms {}".format(share_pct, runtime_ms, op_type))

In [None]:
# Render src_runtime -- GPU time keyed by tuples of
# [GPU state, stack frames..., op_name] -- as a flamegraph.
display_flamegraph(
    data=src_runtime,
    title="Source code flamegraph",
    filename="source_code.svg",
    width=1250,
)

In [None]:
# Render op_name_runtime -- GPU time keyed by the "/"-separated components of
# each HLO op_name -- as a flamegraph.
display_flamegraph(
    data=op_name_runtime, title="op_name flamegraph", filename="op_name.svg", width=1250
)

In [None]:
if len(steady_state.communication):
    fig, axes_grid = plt.subplots(
        ncols=3, figsize=[15, 5], squeeze=False, tight_layout=True
    )
    # axs[0]: bandwidth vs message size; axs[1]: wait-time distribution per
    # collective type; axs[2]: wait time vs message size
    axs = axes_grid[0]
    wait_series, wait_labels = [], []
    comm_df = steady_state.communication
    # Full kernel duration = part not overlapped with compute + hidden part.
    # comm_df (also used by a later cell) gains these two derived columns.
    comm_df["ProjDurFullMs"] = comm_df["ProjDurMs"] + comm_df["ProjDurHiddenMs"]
    comm_df["ProjEndMs"] = comm_df["ProjStartMs"] + comm_df["ProjDurFullMs"]
    for collective_name, coll_rows in comm_df.groupby("Collective"):
        # The grouped data frame will have a row for each device that is participating in
        # this instance of this collective, in the loose SPMD sense. Depending on the JAX
        # program, there may be different sub-groupings that are participating in smaller
        # collectives in the strict/NCCL sense. TODO: it would be better to identify those
        # sub-groupings and group them, but we currently lack the relevant information.
        #
        # One row per instance of the collective; the fastest device kernel is the
        # proxy for the collective's actual bandwidth and duration.
        instance_stats = coll_rows.groupby(
            ["ProgramId", "ProgramExecution", "ThunkIndex"]
        ).agg(
            {
                "BusBandwidthGBPerSec": "max",
                "MessageSize": "min",
                "ProjStartMs": "min",
                "ProjDurFullMs": "min",
                "ProjEndMs": "max",
                "Name": "count",
            }
        )
        message_sizes = instance_stats["MessageSize"]
        # last_end - first_start - fastest_duration approximates the time lost to
        # stragglers / failing to operate in neat lockstep.
        straggler_ms = (
            instance_stats["ProjEndMs"]
            - instance_stats["ProjStartMs"]
            - instance_stats["ProjDurFullMs"]
        )
        axs[0].plot(
            message_sizes,
            instance_stats["BusBandwidthGBPerSec"],
            "o",
            label=collective_name,
        )
        axs[2].plot(message_sizes, straggler_ms, "o", label=collective_name)
        wait_series.append(straggler_ms)
        wait_labels.append(collective_name)

    axs[0].set_xlabel("Message size (B)")
    axs[0].set_xscale("log")
    axs[0].set_ylabel("Bus bandwidth (GB/s)")
    axs[0].legend()

    axs[1].boxplot(wait_series, vert=True)
    axs[1].set_xticks(list(range(1, len(wait_series) + 1)), labels=wait_labels)
    axs[1].set_xlabel("Collective")
    axs[1].set_ylabel("Wait time [ms]")
    axs[1].set_yscale("log")

    axs[2].set_xlabel("Message size (B)")
    axs[2].set_ylabel("Wait time [ms]")
    axs[2].set_xscale("log")
    axs[2].set_yscale("log")

In [None]:
# Arbitrary thresholds for detailed view:
var_threshold = 0.10  # std/mean must be larger than this
# only this many of the slowest kernels clearing `var_threshold` will be shown
detailed_limit = 10

# Calculate statistics over different devices and different executions of each thunk, including multiple executions of the same thunk within the same module
compute_durations = steady_state.thunk.loc[
    ~steady_state.thunk["Communication"], ("Name", "ProjDurMs")
].groupby(["ProgramId", "Name"])
compute_duration_stats = compute_durations["ProjDurMs"].agg(("mean", "std"))
compute_duration_means = compute_duration_stats["mean"]
compute_duration_rel_stds = compute_duration_stats["std"] / compute_duration_means

# Calculate a threshold such that `detailed_limit` points satisfy (std/mean > var_threshold && mean > mean_threshold)
high_variance_means = compute_duration_means[
    compute_duration_rel_stds > var_threshold
].sort_values(ascending=False)
mean_threshold = sum(high_variance_means[detailed_limit - 1 : detailed_limit + 1]) / 2
detailed_mask = (compute_duration_rel_stds > var_threshold) & (
    compute_duration_means > mean_threshold
)
assert (
    detailed_mask.sum() <= detailed_limit
), f"Aimed for {detailed_limit} and got {detailed_mask.sum()}"

fig, axs = plt.subplots(
    ncols=2, width_ratios=[1, 2], figsize=[15, 5], tight_layout=True
)
fig.suptitle(
    rf"{detailed_limit} slowest thunks with $\sigma/\mu$ > {var_threshold:.0%}"
)
axs[0].set_xlabel(r"Mean execution time ($\mu$) [ms]")
axs[0].set_ylabel(r"Execution time variability $\sigma/\mu$ [%]")
axs[0].scatter(
    compute_duration_means[~detailed_mask],
    100 * compute_duration_rel_stds[~detailed_mask],
)
axs[0].scatter(
    compute_duration_means[detailed_mask],
    100 * compute_duration_rel_stds[detailed_mask],
    label="Included in detailed view",
)
# Set explicitly so they don't get adjusted
xlims, ylims = axs[0].get_xlim(), axs[0].get_ylim()
axs[0].set_xlim(xlims)
axs[0].set_ylim(ylims)
axs[0].fill_between(
    [mean_threshold, xlims[1]],
    [100 * var_threshold, 100 * var_threshold],
    [ylims[1], ylims[1]],
    alpha=0.2,
    color="green",
    zorder=0,
)
axs[0].legend()


def durations_ms(idx):
    """Return every ProjDurMs sample of the thunk identified by idx = (ProgramId, Name)."""
    program_id, thunk_name = idx
    program_thunks = steady_state.thunk.loc[program_id, ("Name", "ProjDurMs")]
    is_this_thunk = program_thunks["Name"] == thunk_name
    return program_thunks.loc[is_this_thunk, "ProjDurMs"]


# Thunks in the detailed view, in descending order of mean execution time
detailed_index = high_variance_means[high_variance_means > mean_threshold].index
detail_positions = np.arange(len(detailed_index))
axs[1].violinplot(
    [durations_ms(idx) for idx in detailed_index],
    positions=detail_positions,
    vert=False,
)
axs[1].set_xlabel("Execution time [ms]")
axs[1].set_yticks(
    detail_positions,
    labels=[f"{name} ({pid})" for pid, name in detailed_index],
);

In [None]:
if len(steady_state.communication):
    fig, grid = plt.subplots(
        nrows=len(top_module_ids),
        figsize=[15, 5 * len(top_module_ids)],
        squeeze=False,
        tight_layout=True,
    )
    # Start and end times of every computation (non-communication) thunk
    time_df = steady_state.thunk.loc[
        ~steady_state.thunk["Communication"], ("ProjStartMs", "ProjDurMs")
    ]
    time_df["ProjEndMs"] = time_df["ProjStartMs"] + time_df.pop("ProjDurMs")

    def interleave(df):
        """Flatten ProjStartMs/ProjEndMs into [start0, end0, start1, end1, ...]."""
        s, e = df["ProjStartMs"], df["ProjEndMs"]
        r = np.empty((s.size + e.size,), dtype=s.dtype)
        r[0::2] = s
        r[1::2] = e
        return r

    # Only modules with at least one communication thunk have rows in comm_df;
    # `comm_df.loc[program_id]` would raise KeyError for the others.
    comm_program_ids = set(comm_df.index.get_level_values("ProgramId"))
    devices_to_show = 8
    for n_row, program_id in enumerate(top_module_ids):
        x_values = []
        y_values = defaultdict(list)
        ax = grid[n_row][0]
        for module_execution, exec_df in time_df.loc[program_id].groupby(
            "ProgramExecution"
        ):
            # Mean over devices to get a single [thunk0_start, thunk0_end, thunk1_start, ...]
            # array for this execution of this module
            mean_times = interleave(exec_df.groupby("ThunkIndex").agg("mean"))
            # x axis of the plot will be the average over executions of the module
            x_values.append(mean_times - mean_times[0])
            for device, device_values in exec_df.groupby("Device"):
                # [thunk0_start, thunk0_end, ...] array for one device within one module exec
                # with the average over devices subtracted
                y_values[device].append(interleave(device_values) - mean_times)
        mean_start_time_ms = np.mean(x_values, axis=0)
        # Row i of all_values belongs to device_ids[i], which need not equal i
        # (e.g. if the Device IDs are not exactly 0..N-1). np.array needs every
        # device to have the same number of executions and thunks.
        device_ids = list(y_values.keys())
        all_values = np.array(list(y_values.values()))
        ax.plot(
            mean_start_time_ms,
            np.min(all_values, axis=(0, 1)),
            "k:",
            lw=1,
            label="min/max",
        )
        ax.plot(mean_start_time_ms, np.max(all_values, axis=(0, 1)), "k:", lw=1)
        std = np.std(all_values, axis=(0, 1))
        ax.fill_between(
            mean_start_time_ms, -std, +std, alpha=0.2, label=r"$\pm1\sigma$"
        )
        # max abs(bias) over ProgramExecution within a device, summed over ThunkIndex
        outlier_devices = np.sum(np.max(np.abs(all_values), axis=1), axis=1)
        for _, device_pos in sorted(
            zip(outlier_devices, range(all_values.shape[0])), reverse=True
        )[:devices_to_show]:
            ax.plot(
                mean_start_time_ms,
                np.mean(all_values[device_pos], axis=0),
                label=f"Device {device_ids[device_pos]}",
            )

        # End time of each collective relative to the module start, averaged over
        # devices, grouped by the number of participating devices
        comm_x_values = defaultdict(list)
        if program_id in comm_program_ids:
            for module_execution, exec_df in comm_df.loc[program_id].groupby(
                "ProgramExecution"
            ):
                # assign() returns a new frame instead of writing a column into
                # the groupby sub-frame
                exec_df = exec_df.assign(
                    EndInModuleMs=exec_df["ProjEndMs"]
                    - steady_state.module.loc[
                        (program_id, module_execution), "ProjStartMs"
                    ]
                )
                tmp = exec_df.groupby("ThunkIndex").agg(
                    {
                        "Name": "first",
                        "Collective": "first",
                        "CollectiveSize": "first",
                        "EndInModuleMs": "mean",
                    }
                )
                for coll_size, values in tmp.groupby("CollectiveSize"):
                    comm_x_values[coll_size].append(values["EndInModuleMs"])
        (_, xmax), (ymin, ymax) = ax.get_xlim(), ax.get_ylim()
        ax.set_xlim(0, xmax)
        ax.set_ylim(ymin, ymax)
        if comm_x_values:
            largest_collective = max(comm_x_values.keys())
            for n_color, (coll_size, values) in enumerate(comm_x_values.items()):
                collective_times = np.mean(values, axis=0)
                ax.vlines(
                    collective_times,
                    ymin,
                    # Draw taller vertical lines for collectives involving more devices
                    ymin * (1 - coll_size / largest_collective),
                    color=f"C{n_color}",
                    label=f"{coll_size}-device collective",
                    linestyle="--",
                )

        ax.set_title(
            f"{steady_state.module.loc[program_id, 'Name'].iloc[0]} ({program_id}), {min(outlier_devices.size, devices_to_show)} most extreme devices"
        )
        ax.set_xlabel("Mean time within module [ms]")
        ax.set_ylabel("Mean(executions) bias from mean(executions&devices) [ms]")
        ax.legend(ncols=2)

## Using compile-time information

As well as analysing profile data, it can be helpful to explore compile-time metadata programmatically.
This is an example of looking at the buffer assignment metadata, to show statically how the memory usage evolves throughout a module execution.

In [None]:
from xla.service.hlo_pb2 import HeapSimulatorTrace

Kind = HeapSimulatorTrace.Event.Kind


def heap_usage_trace(trace, buffer_sizes):
    """
    Turn a heap simulator trace into the running total of heap memory in use.

    buffer_sizes maps {buffer_id: buffer_size}. A logical buffer that
    SHARE_WITHs a canonical buffer reuses that buffer's allocation. An
    allocation's size is therefore added when its first logical buffer appears
    and removed when its last one is freed. The returned array starts at 0 and
    has one further entry per such allocation or release; the final assertion
    checks that it also ends at 0.
    """
    users = defaultdict(int)  # allocation id -> number of live logical buffers
    canonical_of = {}  # aliasing buffer id -> allocation id it shares
    deltas = [0]

    def adjust(alloc_id, step):
        users[alloc_id] += step
        return users[alloc_id]

    for event in trace.events:
        if event.kind == Kind.FREE:
            alloc_id = canonical_of.pop(event.buffer_id, event.buffer_id)
            if adjust(alloc_id, -1) == 0:
                # The last logical buffer using this allocation has gone
                deltas.append(-buffer_sizes[alloc_id])
            continue
        if event.kind == Kind.SHARE_WITH:
            assert event.buffer_id not in canonical_of
            alloc_id = event.share_with_canonical_id
            canonical_of[event.buffer_id] = alloc_id
        else:
            assert event.kind == Kind.ALLOC
            alloc_id = event.buffer_id
        if adjust(alloc_id, +1) == 1:
            # First logical buffer on this allocation, not a later alias
            deltas.append(buffer_sizes[alloc_id])
    assert all(count == 0 for count in users.values())
    heap_usage = np.cumsum(deltas)
    assert heap_usage[0] == 0 and heap_usage[-1] == 0
    return heap_usage

Here we use the execution profile, indirectly via `top_module_ids`, to only actually draw traces from modules that contributed non-negligibly to the execution time.

In [None]:
# Number of heap simulator traces recorded for each of the top modules
num_traces = {
    module_id: xla_module_metadata(module_id, policy="all").unique_result(
        lambda hlo_module: len(
            hlo_module.proto().buffer_assignment.heap_simulator_traces
        )
    )
    for module_id in top_module_ids
}
module_ids_with_traces = {
    module_id: n_traces for module_id, n_traces in num_traces.items() if n_traces
}
if not module_ids_with_traces:
    # Without this guard, max() of an empty sequence or a zero-sized subplot
    # grid would raise.
    print("No heap simulator traces found for the top XLA module(s)")
else:
    # One row per module that has traces, one column per trace
    max_n_traces = max(module_ids_with_traces.values())
    n_modules = len(module_ids_with_traces)
    fig, axs = plt.subplots(
        ncols=max_n_traces,
        nrows=n_modules,
        figsize=[max_n_traces * 5, n_modules * 5],
        squeeze=False,
    )
    for n_module, module_id in enumerate(module_ids_with_traces):
        protos = xla_module_metadata(module_id, policy="all")
        sizes_by_logical_id = protos.unique_result(
            lambda proto: {
                buffer.id: buffer.size
                for buffer in proto.proto().buffer_assignment.logical_buffers
            }
        )
        traces = protos.unique_result(
            lambda proto: proto.proto().buffer_assignment.heap_simulator_traces
        )
        for n_trace, trace in enumerate(traces):
            heap_usage = heap_usage_trace(trace, buffer_sizes=sizes_by_logical_id)
            ax = axs[n_module][n_trace]
            ax.plot(heap_usage / 1e9)
            ax.set_title(f"Module {module_id}")
            ax.set_xlabel("Program order")
            ax.set_ylabel("Heap memory usage [GB]")
            print(
                f"Peak heap memory usage in module ID {module_id} is {max(heap_usage) / 1e9:.3f} GB"
            )