Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,35 @@ def _env_enabled_any(names, default="1") -> bool:
cpp_ext = None

if cpp_ext is not None:
# Cap build parallelism so concurrent nvcc/cicc processes do not swamp the machine.
# A valid, positive MAX_JOBS from the environment wins; anything else (unset, blank,
# non-integer, zero or negative) falls back to "CPU count minus two", never below 1.
cpu_count = os.cpu_count() or 1  # os.cpu_count() may return None
default_max_jobs = max(1, cpu_count - 2)
effective_max_jobs = default_max_jobs  # replaced below only by a usable override

requested_jobs = os.environ.get("MAX_JOBS")
if requested_jobs is not None and requested_jobs.strip():
    try:
        requested_value = int(requested_jobs)
    except ValueError:
        print(f"Ignoring invalid MAX_JOBS={requested_jobs!r}; using {effective_max_jobs}.")
    else:
        if requested_value > 0:
            effective_max_jobs = requested_value
        else:
            print(f"MAX_JOBS={requested_value} is non-positive; using {effective_max_jobs}.")
else:
    print(f"MAX_JOBS not set; defaulting to {effective_max_jobs} concurrent CUDA compilations.")

# Publish the resolved limit under both variable names so they always agree.
os.environ.update(
    {
        "MAX_JOBS": str(effective_max_jobs),
        "NINJA_NUM_JOBS": str(effective_max_jobs),
    }
)
print(f"Using MAX_JOBS={effective_max_jobs} to cap concurrent CUDA compilations.")

# Per-invocation nvcc thread count; the same value is passed to nvcc's
# "--threads" flag further down in this file.
nvcc_threads = 1
os.environ["NVCC_THREADS"] = str(nvcc_threads)
print(f"Using NVCC_THREADS={nvcc_threads} for per-invocation NVCC concurrency.")

# Optional conda CUDA runtime headers
#conda_cuda_include_dir = os.path.join(get_python_lib(), "nvidia/cuda_runtime/include")
# if os.path.isdir(conda_cuda_include_dir):
Expand Down Expand Up @@ -543,7 +572,7 @@ def _env_enabled_any(names, default="1") -> bool:
# if _version_geq(NVCC_VERSION, 13, 0):
# extra_compile_args["nvcc"].append("--device-entity-has-hidden-visibility=false")
nvcc_extra_flags = [
"--threads", "8", # NVCC parallelism
"--threads", str(nvcc_threads), # NVCC parallelism
"--optimize=3", # alias for -O3
# "-rdc=true", # enable relocatable device code, required for future cuda > 13.x <-- TODO FIX ME broken loading
# "-dlto", # compile and link <-- TODO FIX ME
Expand Down
Loading