From 484f86022b9c97583208de1f328a2cad256c21aa Mon Sep 17 00:00:00 2001 From: Olli Lupton Date: Mon, 10 Mar 2025 12:00:46 +0000 Subject: [PATCH 1/5] nsys-jax-patch-nsys: cope with nsys being a shim --- .../nsys_jax/nsys_jax/scripts/patch_nsys.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/.github/container/nsys_jax/nsys_jax/scripts/patch_nsys.py b/.github/container/nsys_jax/nsys_jax/scripts/patch_nsys.py index 44496e99a..e2dd5abc7 100644 --- a/.github/container/nsys_jax/nsys_jax/scripts/patch_nsys.py +++ b/.github/container/nsys_jax/nsys_jax/scripts/patch_nsys.py @@ -1372,11 +1372,20 @@ def main(): patch_content = None if patch_content is not None: print(f"Patching Nsight Systems version {m.group(1)}") - # e.g. /opt/nvidia/nsight-systems-cli/2024.7.1/target-linux-x64 - tdir = os.path.dirname(os.path.realpath(nsys)) + nsys_recipe_help = subprocess.check_output( + [nsys, "recipe", "--help"], text=True + ) + m = re.search( + r"List of required Python packages: '(.*?)/nsys_recipe/requirements/common.txt'", + nsys_recipe_help, + ) + assert m is not None, ( + f"Could not determine target directory from: {nsys_recipe_help}" + ) + # e.g. /opt/nvidia/nsight-systems-cli/2024.7.1/target-linux-x64/python/packages subprocess.run( [shutil.which("git"), "apply"], - cwd=os.path.join(tdir, "python", "packages"), + cwd=m.group(1), input=patch_content, check=True, text=True, From e7d4288b9391f73aa297eb5ba9acbe144f6c4c0d Mon Sep 17 00:00:00 2001 From: Olli Lupton Date: Mon, 10 Mar 2025 12:23:45 +0000 Subject: [PATCH 2/5] nsys-jax-patch-nsys: test nvidia/cuda container The cuda-nsight-systems-MAJOR-MINOR packages produce an installation where nsys is a shim script, rather than the actual executable. 
--- .github/workflows/nsys-jax.yaml | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/.github/workflows/nsys-jax.yaml b/.github/workflows/nsys-jax.yaml index e15cd557f..6e8fbfaf5 100644 --- a/.github/workflows/nsys-jax.yaml +++ b/.github/workflows/nsys-jax.yaml @@ -1,4 +1,4 @@ -name: nsys-jax pure-Python CI +name: nsys-jax non-GPU CI concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} @@ -274,3 +274,22 @@ jobs: if [[ $format_status != 0 || $check_status != 0 ]]; then exit 1 fi + installation: + strategy: + matrix: + include: + - container: "nvidia/cuda:12.6.3-base-ubuntu24.04" + nsys_package: "cuda-nsight-systems-12-6" + - container: "nvidia/cuda:12.8.0-base-ubuntu24.04" + nsys_package: "cuda-nsight-systems-12-8" + runs-on: ubuntu-latest + container: "${{ matrix.container }}" + steps: + - name: Install ${{ matrix.nsys_package }} + run: | + apt-get update + apt-get install -y git python3-pip ${{ matrix.nsys_package }} + - name: Install nsys-jax + run: pip install --break-system-packages git+https://github.com/NVIDIA/JAX-Toolbox.git@${{ github.head_ref || github.sha }}#subdirectory=.github/container/nsys_jax + - name: Run nsys-jax-patch-nsys + run: nsys-jax-patch-nsys From b0410ce24133cd3e9b5e1cd902789e38c2773382 Mon Sep 17 00:00:00 2001 From: Olli Lupton Date: Mon, 10 Mar 2025 12:43:13 +0000 Subject: [PATCH 3/5] nsys-jax: support repeated XLA_FLAGS The last value wins. Earlier values are pruned from the XLA_FLAGS passed to the profiled application. 
--- .../nsys_jax/nsys_jax/scripts/nsys_jax.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/.github/container/nsys_jax/nsys_jax/scripts/nsys_jax.py b/.github/container/nsys_jax/nsys_jax/scripts/nsys_jax.py index dfd6b29d1..fa8c939fb 100644 --- a/.github/container/nsys_jax/nsys_jax/scripts/nsys_jax.py +++ b/.github/container/nsys_jax/nsys_jax/scripts/nsys_jax.py @@ -259,21 +259,22 @@ def override_nsys_default(arg, value): if "JAX_ENABLE_COMPILATION_CACHE" not in env: env["JAX_ENABLE_COMPILATION_CACHE"] = "false" + def format_flag(tup): + n, v = tup + return f"--{n}" if v is None else f"--{n}={v}" + # Get the existing XLA_FLAGS and parse them into a dictionary. - xla_flag_list = shlex.split(env.get("XLA_FLAGS", "")) xla_flags = {} - for flag in xla_flag_list: + for flag in shlex.split(env.get("XLA_FLAGS", "")): assert flag.startswith("--") bits = flag[2:].split("=", maxsplit=1) name, value = bits[0], bits[1] if len(bits) > 1 else None - assert name not in xla_flags + if name in xla_flags: + print( + f"WARNING: {format_flag((name, xla_flags[name]))} being overriden by {flag}" + ) xla_flags[name] = value - def as_list(flags): - return [f"--{n}" if v is None else f"--{n}={v}" for n, v in flags.items()] - - assert xla_flag_list == as_list(xla_flags) - def as_bool(s): """String -> bool conversion following XLA's semantics.""" if s.lower() == "true" or s == "1": @@ -298,7 +299,7 @@ def as_bool(s): # Serialise the modified XLA flags. shlex.join is tempting, but doesn't seem to # get the right result for --xla_dump_hlo_pass_re=.*, as it adds extra quotes. - env["XLA_FLAGS"] = " ".join(as_list(xla_flags)) + env["XLA_FLAGS"] = " ".join(map(format_flag, xla_flags.items())) # Run the application in nsys # TODO: consider being more fault-tolerant? 
From 36a4648401632771fc228a7b26b553038d38f31d Mon Sep 17 00:00:00 2001 From: Olli Lupton Date: Mon, 10 Mar 2025 12:47:35 +0000 Subject: [PATCH 4/5] linter fix --- .github/container/nsys_jax/nsys_jax/scripts/patch_nsys.py | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/container/nsys_jax/nsys_jax/scripts/patch_nsys.py b/.github/container/nsys_jax/nsys_jax/scripts/patch_nsys.py index e2dd5abc7..bbf364560 100644 --- a/.github/container/nsys_jax/nsys_jax/scripts/patch_nsys.py +++ b/.github/container/nsys_jax/nsys_jax/scripts/patch_nsys.py @@ -1,4 +1,3 @@ -import os import re import shutil import subprocess From 2f57d9d901da81f6bdbc7970f6478861d86936b8 Mon Sep 17 00:00:00 2001 From: Olli Lupton Date: Mon, 10 Mar 2025 12:51:48 +0000 Subject: [PATCH 5/5] mypy fix --- .github/container/nsys_jax/nsys_jax/scripts/nsys_jax.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/container/nsys_jax/nsys_jax/scripts/nsys_jax.py b/.github/container/nsys_jax/nsys_jax/scripts/nsys_jax.py index fa8c939fb..e5287c823 100644 --- a/.github/container/nsys_jax/nsys_jax/scripts/nsys_jax.py +++ b/.github/container/nsys_jax/nsys_jax/scripts/nsys_jax.py @@ -18,6 +18,7 @@ import tempfile import time import traceback +from typing import Optional import zipfile from .utils import execute_analysis_script, shuffle_analysis_arg @@ -264,7 +265,7 @@ def format_flag(tup): return f"--{n}" if v is None else f"--{n}={v}" # Get the existing XLA_FLAGS and parse them into a dictionary. - xla_flags = {} + xla_flags: dict[str, Optional[str]] = {} for flag in shlex.split(env.get("XLA_FLAGS", "")): assert flag.startswith("--") bits = flag[2:].split("=", maxsplit=1)