Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 12 additions & 10 deletions .github/container/nsys_jax/nsys_jax/scripts/nsys_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import tempfile
import time
import traceback
from typing import Optional
import zipfile

from .utils import execute_analysis_script, shuffle_analysis_arg
Expand Down Expand Up @@ -259,21 +260,22 @@ def override_nsys_default(arg, value):
if "JAX_ENABLE_COMPILATION_CACHE" not in env:
env["JAX_ENABLE_COMPILATION_CACHE"] = "false"

def format_flag(tup: tuple[str, Optional[str]]) -> str:
    """Render one parsed XLA flag back into its command-line form.

    Args:
        tup: A ``(name, value)`` pair. ``name`` is given without the leading
            ``--``; a ``value`` of ``None`` means the flag has no ``=value``
            part.

    Returns:
        ``--name`` when the value is ``None``, otherwise ``--name=value``.
    """
    n, v = tup
    # Only None means "no value"; an empty string still yields "--name=".
    return f"--{n}" if v is None else f"--{n}={v}"

# Get the existing XLA_FLAGS and parse them into a dictionary.
xla_flag_list = shlex.split(env.get("XLA_FLAGS", ""))
xla_flags = {}
for flag in xla_flag_list:
xla_flags: dict[str, Optional[str]] = {}
for flag in shlex.split(env.get("XLA_FLAGS", "")):
assert flag.startswith("--")
bits = flag[2:].split("=", maxsplit=1)
name, value = bits[0], bits[1] if len(bits) > 1 else None
assert name not in xla_flags
if name in xla_flags:
print(
f"WARNING: {format_flag((name, xla_flags[name]))} being overridden by {flag}"
)
xla_flags[name] = value

def as_list(flags: dict[str, Optional[str]]) -> list[str]:
    """Serialise a mapping of XLA flags into a list of command-line tokens.

    Args:
        flags: Maps each flag name (without the leading ``--``) to its value,
            or to ``None`` for a flag that has no ``=value`` part.

    Returns:
        One ``--name`` or ``--name=value`` string per entry, in the mapping's
        iteration order.
    """
    return [f"--{n}" if v is None else f"--{n}={v}" for n, v in flags.items()]

assert xla_flag_list == as_list(xla_flags)

def as_bool(s):
"""String -> bool conversion following XLA's semantics."""
if s.lower() == "true" or s == "1":
Expand All @@ -298,7 +300,7 @@ def as_bool(s):

# Serialise the modified XLA flags. shlex.join is tempting, but doesn't seem to
# get the right result for --xla_dump_hlo_pass_re=.*, as it adds extra quotes.
env["XLA_FLAGS"] = " ".join(as_list(xla_flags))
env["XLA_FLAGS"] = " ".join(map(format_flag, xla_flags.items()))

# Run the application in nsys
# TODO: consider being more fault-tolerant?
Expand Down
16 changes: 12 additions & 4 deletions .github/container/nsys_jax/nsys_jax/scripts/patch_nsys.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import os
import re
import shutil
import subprocess
Expand Down Expand Up @@ -1372,11 +1371,20 @@ def main():
patch_content = None
if patch_content is not None:
print(f"Patching Nsight Systems version {m.group(1)}")
# e.g. /opt/nvidia/nsight-systems-cli/2024.7.1/target-linux-x64
tdir = os.path.dirname(os.path.realpath(nsys))
nsys_recipe_help = subprocess.check_output(
[nsys, "recipe", "--help"], text=True
)
m = re.search(
r"List of required Python packages: '(.*?)/nsys_recipe/requirements/common.txt'",
nsys_recipe_help,
)
assert m is not None, (
f"Could not determine target directory from: {nsys_recipe_help}"
)
# e.g. /opt/nvidia/nsight-systems-cli/2024.7.1/target-linux-x64/python/packages
subprocess.run(
[shutil.which("git"), "apply"],
cwd=os.path.join(tdir, "python", "packages"),
cwd=m.group(1),
input=patch_content,
check=True,
text=True,
Expand Down
21 changes: 20 additions & 1 deletion .github/workflows/nsys-jax.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: nsys-jax pure-Python CI
name: nsys-jax non-GPU CI

concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
Expand Down Expand Up @@ -274,3 +274,22 @@ jobs:
if [[ $format_status != 0 || $check_status != 0 ]]; then
exit 1
fi
# Checks that nsys-jax installs from this revision and that
# nsys-jax-patch-nsys runs against each listed Nsight Systems package.
installation:
  strategy:
    matrix:
      include:
        - container: "nvidia/cuda:12.6.3-base-ubuntu24.04"
          nsys_package: "cuda-nsight-systems-12-6"
        - container: "nvidia/cuda:12.8.0-base-ubuntu24.04"
          nsys_package: "cuda-nsight-systems-12-8"
  runs-on: ubuntu-latest
  container: "${{ matrix.container }}"
  steps:
    - name: Install ${{ matrix.nsys_package }}
      run: |
        apt-get update
        apt-get install -y git python3-pip ${{ matrix.nsys_package }}
    - name: Install nsys-jax
      # github.head_ref is a branch name chosen by the pull-request author, so
      # it is handed to the shell through an environment variable and quoted
      # rather than being expanded into the script text.
      # NOTE(review): a head_ref that only exists in a fork will not resolve
      # against NVIDIA/JAX-Toolbox -- confirm fork PRs are out of scope here.
      env:
        NSYS_JAX_REF: ${{ github.head_ref || github.sha }}
      run: pip install --break-system-packages "git+https://github.com/NVIDIA/JAX-Toolbox.git@${NSYS_JAX_REF}#subdirectory=.github/container/nsys_jax"
    - name: Run nsys-jax-patch-nsys
      run: nsys-jax-patch-nsys
Loading