diff --git a/setup.py b/setup.py
index 7f95bf093..7fd343759 100644
--- a/setup.py
+++ b/setup.py
@@ -496,6 +496,35 @@ def _env_enabled_any(names, default="1") -> bool:
     cpp_ext = None
 
 if cpp_ext is not None:
+    # Limit compile parallelism to avoid overwhelming nvcc/cicc invocations.
+    # Respect pre-set MAX_JOBS, otherwise fall back to CPU count minus two (min 1).
+    cpu_count = os.cpu_count() or 1
+    default_max_jobs = max(1, cpu_count - 2)
+    max_jobs_raw = os.environ.get("MAX_JOBS")
+    if max_jobs_raw is None or max_jobs_raw.strip() == "":
+        effective_max_jobs = default_max_jobs
+        print(f"MAX_JOBS not set; defaulting to {effective_max_jobs} concurrent CUDA compilations.")
+    else:
+        try:
+            parsed_jobs = int(max_jobs_raw)
+        except ValueError:
+            effective_max_jobs = default_max_jobs
+            print(f"Ignoring invalid MAX_JOBS={max_jobs_raw!r}; using {effective_max_jobs}.")
+        else:
+            if parsed_jobs <= 0:
+                effective_max_jobs = default_max_jobs
+                print(f"MAX_JOBS={parsed_jobs} is non-positive; using {effective_max_jobs}.")
+            else:
+                effective_max_jobs = parsed_jobs
+
+    os.environ["MAX_JOBS"] = str(effective_max_jobs)
+    os.environ["NINJA_NUM_JOBS"] = str(effective_max_jobs)
+    print(f"Using MAX_JOBS={effective_max_jobs} to cap concurrent CUDA compilations.")
+
+    nvcc_threads = 1
+    os.environ["NVCC_THREADS"] = str(nvcc_threads)
+    print(f"Using NVCC_THREADS={nvcc_threads} for per-invocation NVCC concurrency.")
+
     # Optional conda CUDA runtime headers
     #conda_cuda_include_dir = os.path.join(get_python_lib(), "nvidia/cuda_runtime/include")
     # if os.path.isdir(conda_cuda_include_dir):
@@ -543,7 +572,7 @@ def _env_enabled_any(names, default="1") -> bool:
     # if _version_geq(NVCC_VERSION, 13, 0):
     # extra_compile_args["nvcc"].append("--device-entity-has-hidden-visibility=false")
     nvcc_extra_flags = [
-        "--threads", "8", # NVCC parallelism
+        "--threads", str(nvcc_threads), # NVCC parallelism
         "--optimize=3", # alias for -O3
         # "-rdc=true", # enable relocatable device code, required for future cuda > 13.x <-- TODO FIX ME broken loading
         # "-dlto", # compile and link <-- TODO FIX ME