In [None]:
from collections import defaultdict
from jax_nsys import (
    calculate_collective_metrics,
    compile_protos,
    display_flamegraph,
    generate_compilation_statistics,
    load_profiler_data,
    remove_child_ranges,
    xla_module_metadata,
)
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd  # type: ignore
import sys
from typing import NamedTuple

In [None]:
# Make sure that the .proto files under protos/ have been compiled to .py, and
# that those generated .py files are importable.
# Compilation runs when the output directory is missing *or empty*. That way an
# earlier attempt that only got as far as creating the directory (e.g. it was
# interrupted) does not leave the protos permanently uncompiled.
proto_dir, compiled_proto_dir = "protos", "compiled_protos"
if not os.path.isdir(compiled_proto_dir) or not os.listdir(compiled_proto_dir):
    os.makedirs(compiled_proto_dir, exist_ok=True)
    compile_protos(proto_dir=proto_dir, output_dir=compiled_proto_dir)
if compiled_proto_dir not in sys.path:
    sys.path.insert(0, compiled_proto_dir)

In [None]:
# Load the runtime profile data: per-thunk, per-module and compilation ranges.
all_data = load_profiler_data(
    frames={"thunk", "module", "compile"}, warmup_removal_heuristics=True
)
thunk_df, module_df, compile_df = (
    all_data[frame] for frame in ("thunk", "module", "compile")
)
# Rows with ProgramId == -1 are usually autotuner executions rather than real
# program modules; drop them from both the module and thunk tables for now.
module_df = module_df.loc[module_df["ProgramId"] >= 0]
thunk_df = thunk_df.loc[thunk_df["ProgramId"] >= 0]

In [None]:
# Short-list the XLA modules that dominate this application's execution time.
# `threshold` is the fraction of total execution time that may be ignored. The
# smallest modules, whose combined time stays within that fraction, are
# dropped.
threshold = 0.01
top_module_sum = (
    module_df.groupby("ProgramId")["ProjDurNs"].sum().sort_values().cumsum()
)
# The cumulative sum is ascending, so its max is the grand total.
top_module_mask = top_module_sum.div(top_module_sum.max()).gt(threshold)
top_module_ids = top_module_mask.index[top_module_mask.to_numpy()]

In [None]:
# Lightly sanitise the autotuner results. While measuring, the GEMM fusion
# autotuner builds small modules/thunks that emit their own XlaModule and
# XlaThunk ranges. Drop everything nested under XlaAutotunerMeasurement ranges.
compile_df = remove_child_ranges(
    compile_df,
    compile_df["Name"].str.startswith("XlaAutotunerMeasurement"),
)
# Optional: report autotuner compilation as one opaque block. Otherwise its
# lower-level phases (EmitLlvmIr etc.) are lumped in with their non-autotuner
# counterparts.
compile_df = remove_child_ranges(
    compile_df,
    compile_df["Name"].eq("XlaAutotunerCompilation"),
)

In [None]:
# Summarise all the observed compilation time
# The first compilation triggers a bunch of library loading, things like cuBLAS
# and cuDNN. Label that explicitly to pull it out of the generic non-leaf time.
# idxmax() on the boolean "is XlaCompile" Series gives the label of the first
# True value. If there is no XlaCompile row at all it returns the first label,
# which the assert below rejects.
first_xlacompile_index = compile_df["Name"].eq("XlaCompile").idxmax()
assert compile_df.loc[first_xlacompile_index, "Name"] == "XlaCompile"
if compile_df.loc[first_xlacompile_index, "DurNonChildNs"] > 0.0:
    # Move the non-child time of that range into a synthetic child row so it is
    # reported under its own name.
    # One past the largest existing label (presumably a fresh integer label).
    new_index = compile_df.index.max() + 1
    # The synthetic row starts as a copy of the parent, keeping its
    # DurNonChildNs, but with no children of its own.
    new_row = compile_df.loc[first_xlacompile_index, :].copy()
    new_row["DurChildNs"] = 0.0
    new_row["Name"] = "[non-leaf time in 0th XlaCompile range]"
    new_row["NumChild"] = 0
    # Append the new label to the parent's RangeStack. This presumably nests it
    # under the parent (':'-separated path of range ids — TODO confirm).
    new_row["RangeStack"] += f":{new_index}"
    # The parent gives up its non-child time and gains one child.
    # NOTE(review): the parent's DurChildNs is not increased by the moved time,
    # and the new row keeps all other parent columns (e.g. any start/duration
    # fields) unchanged. Confirm generate_compilation_statistics only depends
    # on DurNonChildNs / NumChild / RangeStack for this.
    compile_df.loc[first_xlacompile_index, "DurNonChildNs"] = 0.0
    compile_df.loc[first_xlacompile_index, "NumChild"] += 1
    compile_df = pd.concat([compile_df, pd.DataFrame([new_row], index=[new_index])])


# This averages over all profiled compilations and handles parallel compilation
compile_time_ns = generate_compilation_statistics(compile_df)


def clean_compilation_range_name(name):
    """
    Map a raw compilation range name onto the label used in the summary.

    This is where the granularity of the compilation summary is decided, e.g.
    whether individual XLA passes are reported separately or as one entry.
    """
    # Autotuner measurements carry the name of the HLO op being tuned; drop it.
    if name.startswith("XlaAutotunerMeasurement"):
        return "XlaAutotunerMeasurement"
    # With parallel backend compilation, XlaEmitGpuAsm and XlaOptimizeLlvmIr
    # ranges carry a split_module suffix. Strip it so they merge.
    stripped = name.removesuffix(":#module=split_module#")
    # Report every XlaPass / XlaPassPipeline range as a single "XlaPass" entry.
    pass_prefixes = ("XlaPass:#", "XlaPassPipeline:#")
    return "XlaPass" if stripped.startswith(pass_prefixes) else stripped


# Summarise the results more by combining together different passes
compile_summary = (
    compile_time_ns.groupby(clean_compilation_range_name)
    .agg("sum")
    .sort_values(by=["DurNonChildNs"], ascending=False)
)
total_compile_time = compile_summary["DurNonChildNs"].sum()
# Print out the largest entries adding up to at least this fraction of the total
threshold = 0.99
compile_summary["FracNonChild"] = compile_summary["DurNonChildNs"] / total_compile_time
print(f"Top {threshold:.0%}+ of {total_compile_time*1e-9:.2f}s compilation time")
# Keep a row while the rows *before* it cover less than `threshold`. This keeps
# the row that crosses the threshold, so the printed rows really cover at least
# `threshold` of the total. Filtering on `cumsum() <= threshold` would stop
# short of it, and would print nothing if one entry alone exceeded it.
covered_before = compile_summary["FracNonChild"].cumsum().shift(fill_value=0.0)
for row in compile_summary[covered_before < threshold].itertuples():
    print(f"{row.FracNonChild:6.2%} {row.DurNonChildNs*1e-9:.2f}s {row.Index}")

In [None]:
# Summarise all the XLA modules that have been seen in this profile. Note that
# this does *not* respect the `top_module_ids` list derived above.
# Count thunks per module once, up front. Building a boolean mask over the
# whole thunk table for every module row would be O(#modules x #thunks).
thunks_per_module = thunk_df["ModuleId"].value_counts()
module_stats = defaultdict(list)
for module_row in module_df.itertuples():
    # Modules with no thunks are absent from the counts -> 0.
    num_thunks = thunks_per_module.get(module_row.Index, 0)
    module_stats[module_row.Name].append(
        {"GPU time [ms]": 1e-6 * module_row.ProjDurNs, "#Thunks": num_thunks}
    )


class Summary(NamedTuple):
    """Aggregate of one statistic over the repeated executions of a module.

    reduce_module_stats fills these with np.mean / np.std / np.sum of the
    per-execution values. For its "#Instances" entry, mean and total both hold
    the execution count and std is 0.0.
    """

    mean: float  # mean over executions
    std: float  # standard deviation over executions
    total: float  # sum over executions


def reduce_module_stats(module_stats) -> dict[str, Summary]:
    """
    Reduce the per-execution statistics of one module to per-key summaries.

    ``module_stats`` is a list with one dict per execution, all sharing the same
    keys, e.g. ``[{"a": 0.3}, {"a": 0.4}]``. The result maps each key to a
    ``Summary(mean, std, total)`` over the executions (std is the population
    standard deviation, as computed by ``np.std``). It also contains an
    ``"#Instances"`` entry with the number of executions, e.g.
    ``{"#Instances": Summary(2, 0.0, 2), "a": Summary(0.35, 0.05, 0.7)}``.

    An empty list yields only the ``"#Instances"`` entry, with a count of 0.

    Raises:
        ValueError: if the dicts do not all have the same keys.
    """
    num_instances = len(module_stats)
    r = {"#Instances": Summary(mean=num_instances, std=0.0, total=num_instances)}
    if num_instances == 0:
        # Nothing to aggregate; do not index into an empty list.
        return r
    keys = module_stats[0].keys()
    for stats in module_stats[1:]:
        # Explicit check rather than `assert`, which is stripped under -O.
        if stats.keys() != keys:
            raise ValueError(
                f"Inconsistent statistics keys: {sorted(stats.keys())} "
                f"vs {sorted(keys)}"
            )
    for k in keys:
        values = [stats[k] for stats in module_stats]
        r[k] = Summary(mean=np.mean(values), std=np.std(values), total=np.sum(values))
    return r


# Aggregate HLO module statistics over repeated executions of them
agg_module_stats = [(k, reduce_module_stats(v)) for k, v in module_stats.items()]


def sort_key(x):
    """Return the total GPU time [ms] of one ``(module_name, stats)`` entry."""
    return x[1]["GPU time [ms]"].total


# Print the modules by total active GPU time, largest first. Each line shows
# the module's share of the total, its number of executions, and the mean ±
# std number of thunks per execution.
agg_module_stats.sort(key=sort_key, reverse=True)
total = sum(sort_key(x) for x in agg_module_stats)
print("      Active GPU time #Exec. #Thunks  Module name")
for entry in agg_module_stats:
    module_name, stats = entry
    module_time = sort_key(entry)
    # Guard against dividing by zero when no module recorded any GPU time.
    module_share = 100.0 * module_time / total if total else 0.0
    print(
        " {:7.2f}% {:9.2f}ms {:5} {:5.0f}±{:<3.0f} {}".format(
            module_share,
            module_time,
            stats["#Instances"].mean,
            stats["#Thunks"].mean,
            stats["#Thunks"].std,
            module_name,
        )
    )

In [None]:
# Summarise the thunks/kernels that have been seen. Unlike the module summary,
# this respects the `top_module_ids` list derived above, because definition (3)
# of the total runtime below is particularly sensitive to outliers. This is
# probably a reasonable default, but it is still a heuristic.
top_module_thunk_df = thunk_df.loc[thunk_df["ProgramId"].isin(top_module_ids)]
top_module_df = module_df.loc[module_df["ProgramId"].isin(top_module_ids)].assign(
    ProjEndNs=lambda frame: frame["ProjStartNs"] + frame["ProjDurNs"]
)
# Total GPU time per (ProgramId, thunk name), largest first.
thunk_summary = top_module_thunk_df.groupby(["ProgramId", "Name"]).agg(
    {"ProjDurNs": "sum"}
)
thunk_summary = thunk_summary.sort_values("ProjDurNs", ascending=False)

# Three definitions of the total runtime:
# 1. the sum of all thunk/kernel runtimes, after overlap subtraction;
# 2. the sum of all module runtimes, i.e. (1) plus any time the GPU is idle
#    *during* execution of a module;
# 3. the span from the first thunk of the first module starting on the GPU to
#    the last thunk of the last module finishing there, i.e. (2) plus any time
#    the GPU is idle between module executions.
# (3) can easily include compilation and initialisation time if the profile is
# not collected in a targeted manner.
#
# (3) is computed separately for each TID and then summed. The intent is to
# handle several GPUs driven by the same process (presumably one TID per
# GPU — TODO confirm).
all_modules_active_ns = top_module_df["ProjDurNs"].sum()  # (2)
all_thunks_active_ns = thunk_summary["ProjDurNs"].sum()  # (1)
top_module_duration_df = top_module_df.groupby("TID").agg(
    {"ProjStartNs": "min", "ProjEndNs": "max"}
)
all_modules_wall_ns = top_module_duration_df.pipe(
    lambda spans: spans["ProjEndNs"] - spans["ProjStartNs"]
).sum()  # (3)

# Accumulators that project per-thunk GPU runtime onto three views: by HLO
# opcode, by op_name hierarchy, and by source-code location.
op_runtime: dict[str, float]
op_name_runtime: dict[tuple[str, ...], float]
src_runtime: dict[tuple[str, ...], float]
op_runtime, op_name_runtime, src_runtime = (defaultdict(float) for _ in range(3))

# Synthetic path prefixes that shape the hierarchical (flamegraph) views.
gpu_active = ["[GPU active]"]
gpu_active_unknown = [*gpu_active, "[Unknown]"]
gpu_idle_inside_modules, gpu_idle_between_modules = (
    ["[GPU idle during module execution]"],
    ["[GPU idle between module executions]"],
)

print("Top 10 thunks by GPU runtime")
# thunk_summary is sorted by total GPU time (descending), so the rows with
# n < 10 are the most expensive thunks. Every row is also attributed to the
# opcode, source-location and op_name views below.
for n, thunk_row in enumerate(thunk_summary.itertuples()):
    program_id, thunk_name = thunk_row.Index
    if program_id == -1:
        # No module information -> probably an autotuning run.
        # NOTE(review): thunk_df was already filtered to ProgramId >= 0 when it
        # was loaded, so this branch looks unreachable; kept as a safeguard.
        continue
    # Look up the HLO instruction this thunk corresponds to.
    hlo_module = xla_module_metadata(program_id)
    hlo_comp, hlo_inst = hlo_module.find_instruction(thunk_name)
    if n < 10:
        print(
            " {:5.2f}% {:5.2f}ms {} {}".format(
                100.0 * thunk_row.ProjDurNs / all_thunks_active_ns,
                1e-6 * thunk_row.ProjDurNs,
                thunk_name,
                hlo_inst.metadata.op_name,
            )
        )

    # Summarise by opcode, i.e. fusion/custom-call/...
    op_runtime[hlo_inst.opcode] += thunk_row.ProjDurNs

    # Summarise by source location. This is inherently approximate because
    # there are multiple instructions and stack traces attributed to each unit
    # of GPU runtime, and we do not know how to weight them. For now, give
    # equal weight to the instruction `hlo_inst` and all instructions in called
    # computations that have non-empty metadata.
    called_instructions = [
        called_inst
        for called_comp_id in hlo_inst.called_computation_ids
        for called_inst in hlo_module.find_computation(called_comp_id).instructions
    ]
    # Candidate attribution paths, most informative first. Only the first
    # non-empty set is used:
    #   [0] stack frames (+ op_name), [1] "[Unknown]" + op_name,
    #   [2] fallback "[GPU active]/[Unknown]".
    src_runtime_preferences: tuple[set[tuple[str, ...]], ...] = (
        set(),
        set(),
        {tuple(gpu_active_unknown)},
    )
    # Same idea for the op_name view: [0] op_name split on "/", [1] fallback.
    op_name_runtime_preferences: tuple[set[tuple[str, ...]], ...] = (
        set(),
        {tuple(gpu_active_unknown)},
    )
    for inst in [hlo_inst] + called_instructions:
        frames = hlo_module.get_stack_frames(inst.metadata.stack_frame_id)
        op_name = [inst.metadata.op_name] if len(inst.metadata.op_name) else []
        if len(frames):
            src_runtime_preferences[0].add(tuple(gpu_active + frames + op_name))
        if len(op_name):
            src_runtime_preferences[1].add(tuple(gpu_active_unknown + op_name))
            op_name_runtime_preferences[0].add(
                tuple(gpu_active + op_name[0].split("/"))
            )
    # Split the thunk's runtime equally over the distinct paths in the first
    # non-empty preference set (sets deduplicate identical paths).
    for locations in src_runtime_preferences:
        if len(locations) > 0:
            weight = thunk_row.ProjDurNs / len(locations)
            for loc in locations:
                src_runtime[loc] += weight
            break
    for locations in op_name_runtime_preferences:
        if len(locations) > 0:
            weight = thunk_row.ProjDurNs / len(locations)
            for loc in locations:
                op_name_runtime[loc] += weight
            break


# Opcode view: use total time (2) as the reference ("_total"). The difference
# between (2) and (3) cannot easily be attributed to individual opcodes.
op_runtime.update(
    {
        "_total": all_modules_active_ns,
        "GPU idle during modules": all_modules_active_ns - all_thunks_active_ns,
    }
)

# Source-location and op_name views: add explicit idle entries so the top
# level of the hierarchy adds up to total time (3), assuming the visualisation
# can handle this.
for runtime_view in (src_runtime, op_name_runtime):
    runtime_view[tuple(gpu_idle_inside_modules)] = (
        all_modules_active_ns - all_thunks_active_ns
    )
    runtime_view[tuple(gpu_idle_between_modules)] = (
        all_modules_wall_ns - all_modules_active_ns
    )

In [None]:
print("GPU runtime by operation type")
# Largest contributors first. Keys starting with "_" (e.g. "_total") are
# bookkeeping entries, not operation types, so they are not listed.
for op_type, op_ns in sorted(op_runtime.items(), key=lambda item: -item[1]):
    if not op_type.startswith("_"):
        op_pct = 100.0 * op_ns / op_runtime["_total"]
        print(" {:5.2f}% {:10.2f}ms {}".format(op_pct, 1e-6 * op_ns, op_type))

In [None]:
# Flamegraph of GPU time attributed to source-code locations (src_runtime).
display_flamegraph(
    width=1250,
    title="Source code flamegraph",
    filename="source_code.svg",
    data=src_runtime,
)

In [None]:
# Flamegraph of GPU time attributed to each instruction's op_name, split into
# levels on "/" (op_name_runtime).
display_flamegraph(
    filename="op_name.svg",
    data=op_name_runtime,
    width=1250,
    title="op_name flamegraph",
)

In [None]:
# Collective communication performance. The full duration of each collective
# thunk is its visible duration plus ProjDurHiddenNs.
comm_df = calculate_collective_metrics(thunk_df)
comm_df["ProjDurFullNs"] = comm_df["ProjDurNs"] + comm_df["ProjDurHiddenNs"]
comm_df["ProjEndNs"] = comm_df["ProjStartNs"] + comm_df["ProjDurFullNs"]
fig, axs = plt.subplots(ncols=3, figsize=[15, 5])
for comm, df in comm_df.groupby("Collective"):
    # The grouped data frame will have a row for each device that is participating in
    # this instance of this collective, in the loose SPMD sense. Depending on the JAX
    # program, there may be different sub-groupings that are participating in smaller
    # collectives in the strict/NCCL sense. TODO: it would be better to identify those
    # sub-groupings and group them, but we currently lack the relevant information.
    collective_df = df.groupby(["ProgramId", "Name", "ModuleExecution"])
    # Take the fastest device kernel as a proxy for the actual bandwidth of the
    # collective.
    bandwidth_df = collective_df.agg(
        {
            "BusBandwidthGBPerSec": "max",
            "MessageSize": "min",
            "ProjStartNs": "min",
            "ProjDurFullNs": "min",
            "ProjEndNs": "max",
        }
    )
    axs[0].plot(
        bandwidth_df["MessageSize"],
        bandwidth_df["BusBandwidthGBPerSec"],
        "o",
        label=comm,
    )
    # Take last_end - first_start - fastest_duration as a proxy for time lost due
    # to stragglers / failing to operate in neat lockstep.
    wait_time_ns = (
        bandwidth_df["ProjEndNs"]
        - bandwidth_df["ProjStartNs"]
        - bandwidth_df["ProjDurFullNs"]
    )
    # Wait time as a multiple of the fastest device's duration (a ratio, not %).
    wait_time_pc = wait_time_ns / bandwidth_df["ProjDurFullNs"]
    axs[1].hist(wait_time_ns * 1e-6, 100, label=comm)
    axs[2].plot(bandwidth_df["MessageSize"], wait_time_pc, "o", label=comm)
axs[0].set_title("Bus bandwidth per collective execution")
axs[0].set_xlabel("Message size (B)")
axs[0].set_xscale("log")
axs[0].set_ylabel("Bus bandwidth (GB/s)")
axs[1].set_title("Straggler wait time")
axs[1].set_xlabel("Wait time (ms)")
axs[1].set_ylabel("Count")
axs[2].set_title("Wait time relative to fastest device")
axs[2].set_xlabel("Message size (B)")
axs[2].set_ylabel("Wait time (multiple of fastest)")
axs[2].set_xscale("log")
# Note: zero wait times cannot be represented faithfully on a log axis.
axs[2].set_yscale("log")
# Every panel is drawn with label=comm, so every panel gets a legend. Only add
# one when something was plotted, to avoid "No artists with labels" warnings if
# the profile contains no collectives.
for ax in axs:
    if ax.get_legend_handles_labels()[0]:
        ax.legend()
fig.tight_layout()

In [None]:
# Run-to-run variability of compute (non-communication) thunks. For each
# (ProgramId, Name) pair, plot the coefficient of variation (std / mean) of its
# duration against its mean duration. Thunks seen only once have an undefined
# (NaN) std and therefore do not appear.
compute_times = (
    thunk_df[~thunk_df["Communication"]]
    .groupby(["ProgramId", "Name"])
    .agg({"ProjDurNs": ["mean", "std"]})
    .sort_values(("ProjDurNs", "mean"))
)
fig, ax = plt.subplots()
ax.plot(
    compute_times[("ProjDurNs", "mean")],
    compute_times[("ProjDurNs", "std")] / compute_times[("ProjDurNs", "mean")],
    "o",
)
ax.set_xscale("log")
ax.set_xlabel("Mean thunk duration (ns)")
ax.set_ylabel("Std. dev. / mean of thunk duration")
ax.set_title("Variability of compute thunk durations")
fig.tight_layout()