diff --git a/torchgen/gen.py b/torchgen/gen.py index 5c9b156b50442..78699bcac1f8a 100644 --- a/torchgen/gen.py +++ b/torchgen/gen.py @@ -2710,6 +2710,12 @@ def main() -> None: help="output directory", default="build/aten/src/ATen", ) + parser.add_argument( + "--aoti-install-dir", + "--aoti_install_dir", + help="output directory for AOTInductor shim", + default="torch/csrc/inductor/aoti_torch/generated", + ) parser.add_argument( "--rocm", action="store_true", @@ -2830,15 +2836,15 @@ def main() -> None: pathlib.Path(core_install_dir).mkdir(parents=True, exist_ok=True) ops_install_dir = f"{options.install_dir}/ops" pathlib.Path(ops_install_dir).mkdir(parents=True, exist_ok=True) + aoti_install_dir = f"{options.aoti_install_dir}" + pathlib.Path(aoti_install_dir).mkdir(parents=True, exist_ok=True) core_fm = make_file_manager(options=options, install_dir=core_install_dir) cpu_fm = make_file_manager(options=options) cpu_vec_fm = make_file_manager(options=options) cuda_fm = make_file_manager(options=options) ops_fm = make_file_manager(options=options, install_dir=ops_install_dir) - aoti_fm = make_file_manager( - options=options, install_dir="torch/csrc/inductor/aoti_torch/generated" - ) + aoti_fm = make_file_manager(options=options, install_dir=aoti_install_dir) # Only a limited set of dispatch keys get CPUFunctions.h headers generated # for them; this is the set