diff --git a/setup.py b/setup.py index f9a5027e3..7f95bf093 100644 --- a/setup.py +++ b/setup.py @@ -525,6 +525,8 @@ def _env_enabled_any(names, default="1") -> bool: cutlass_root / "examples/common/include", cutlass_root / "tools/library/include", ] + if "GPTQMODEL_CUTLASS_DIR" not in os.environ: + os.environ["GPTQMODEL_CUTLASS_DIR"] = str(cutlass_root) cutlass_include_flags = [f"-I{path}" for path in cutlass_include_paths] extra_compile_args["cxx"] += cutlass_include_flags extra_compile_args["nvcc"] += cutlass_include_flags