diff --git a/.github/container/nsys_jax/nsys_jax/scripts/nsys_jax.py b/.github/container/nsys_jax/nsys_jax/scripts/nsys_jax.py index dfd6b29d1..e5287c823 100644 --- a/.github/container/nsys_jax/nsys_jax/scripts/nsys_jax.py +++ b/.github/container/nsys_jax/nsys_jax/scripts/nsys_jax.py @@ -18,6 +18,7 @@ import tempfile import time import traceback +from typing import Optional import zipfile from .utils import execute_analysis_script, shuffle_analysis_arg @@ -259,21 +260,22 @@ def override_nsys_default(arg, value): if "JAX_ENABLE_COMPILATION_CACHE" not in env: env["JAX_ENABLE_COMPILATION_CACHE"] = "false" + def format_flag(tup): + n, v = tup + return f"--{n}" if v is None else f"--{n}={v}" + # Get the existing XLA_FLAGS and parse them into a dictionary. - xla_flag_list = shlex.split(env.get("XLA_FLAGS", "")) - xla_flags = {} - for flag in xla_flag_list: + xla_flags: dict[str, Optional[str]] = {} + for flag in shlex.split(env.get("XLA_FLAGS", "")): assert flag.startswith("--") bits = flag[2:].split("=", maxsplit=1) name, value = bits[0], bits[1] if len(bits) > 1 else None - assert name not in xla_flags + if name in xla_flags: + print( + f"WARNING: {format_flag((name, xla_flags[name]))} being overriden by {flag}" + ) xla_flags[name] = value - def as_list(flags): - return [f"--{n}" if v is None else f"--{n}={v}" for n, v in flags.items()] - - assert xla_flag_list == as_list(xla_flags) - def as_bool(s): """String -> bool conversion following XLA's semantics.""" if s.lower() == "true" or s == "1": @@ -298,7 +300,7 @@ def as_bool(s): # Serialise the modified XLA flags. shlex.join is tempting, but doesn't seem to # get the right result for --xla_dump_hlo_pass_re=.*, as it adds extra quotes. - env["XLA_FLAGS"] = " ".join(as_list(xla_flags)) + env["XLA_FLAGS"] = " ".join(map(format_flag, xla_flags.items())) # Run the application in nsys # TODO: consider being more fault-tolerant? diff --git a/.github/container/nsys_jax/nsys_jax/scripts/patch_nsys.py b/.github/container/nsys_jax/nsys_jax/scripts/patch_nsys.py index 44496e99a..bbf364560 100644 --- a/.github/container/nsys_jax/nsys_jax/scripts/patch_nsys.py +++ b/.github/container/nsys_jax/nsys_jax/scripts/patch_nsys.py @@ -1,4 +1,3 @@ -import os import re import shutil import subprocess @@ -1372,11 +1371,20 @@ def main(): patch_content = None if patch_content is not None: print(f"Patching Nsight Systems version {m.group(1)}") - # e.g. /opt/nvidia/nsight-systems-cli/2024.7.1/target-linux-x64 - tdir = os.path.dirname(os.path.realpath(nsys)) + nsys_recipe_help = subprocess.check_output( + [nsys, "recipe", "--help"], text=True + ) + m = re.search( + r"List of required Python packages: '(.*?)/nsys_recipe/requirements/common.txt'", + nsys_recipe_help, + ) + assert m is not None, ( + f"Could not determine target directory from: {nsys_recipe_help}" + ) + # e.g. /opt/nvidia/nsight-systems-cli/2024.7.1/target-linux-x64/python/packages subprocess.run( [shutil.which("git"), "apply"], - cwd=os.path.join(tdir, "python", "packages"), + cwd=m.group(1), input=patch_content, check=True, text=True, diff --git a/.github/workflows/nsys-jax.yaml b/.github/workflows/nsys-jax.yaml index e15cd557f..6e8fbfaf5 100644 --- a/.github/workflows/nsys-jax.yaml +++ b/.github/workflows/nsys-jax.yaml @@ -1,4 +1,4 @@ -name: nsys-jax pure-Python CI +name: nsys-jax non-GPU CI concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} @@ -274,3 +274,22 @@ jobs: if [[ $format_status != 0 || $check_status != 0 ]]; then exit 1 fi + installation: + strategy: + matrix: + include: + - container: "nvidia/cuda:12.6.3-base-ubuntu24.04" + nsys_package: "cuda-nsight-systems-12-6" + - container: "nvidia/cuda:12.8.0-base-ubuntu24.04" + nsys_package: "cuda-nsight-systems-12-8" + runs-on: ubuntu-latest + container: "${{ matrix.container }}" + steps: + - name: Install ${{ matrix.nsys_package }} + run: | + apt-get update + apt-get install -y git python3-pip ${{ matrix.nsys_package }} + - name: Install nsys-jax + run: pip install --break-system-packages git+https://github.com/NVIDIA/JAX-Toolbox.git@${{ github.head_ref || github.sha }}#subdirectory=.github/container/nsys_jax + - name: Run nsys-jax-patch-nsys + run: nsys-jax-patch-nsys