From 70d097a4f36f2480cfda5cd736675a0826d0fc6b Mon Sep 17 00:00:00 2001 From: Weiqi Gao Date: Mon, 4 Dec 2023 08:40:16 +0800 Subject: [PATCH] Requires PyInstaller 6.x --- news/666.update.rst | 2 +- .../hooks/stdhooks/hook-torch.py | 35 +++++++++++++++---- 2 files changed, 29 insertions(+), 8 deletions(-) diff --git a/news/666.update.rst b/news/666.update.rst index 465defad..60308526 100644 --- a/news/666.update.rst +++ b/news/666.update.rst @@ -1 +1 @@ -Discard header files (and static libraries on Windows) of ``torch`` +Modernize the torch for ``hook`` and reduce the amount of unnecessarily collected data files (header files and static libraries). Requires PyInstaller >= 6.0. diff --git a/src/_pyinstaller_hooks_contrib/hooks/stdhooks/hook-torch.py b/src/_pyinstaller_hooks_contrib/hooks/stdhooks/hook-torch.py index 40014c45..6b0265de 100644 --- a/src/_pyinstaller_hooks_contrib/hooks/stdhooks/hook-torch.py +++ b/src/_pyinstaller_hooks_contrib/hooks/stdhooks/hook-torch.py @@ -10,19 +10,40 @@ # SPDX-License-Identifier: GPL-2.0-or-later # ------------------------------------------------------------------ -from PyInstaller.utils.hooks import logger, collect_data_files, is_module_satisfies, collect_dynamic_libs, collect_submodules +from PyInstaller.utils.hooks import ( + logger, + collect_data_files, + is_module_satisfies, + collect_dynamic_libs, + collect_submodules, + get_package_paths, +) -module_collection_mode = 'pyz+py' +if is_module_satisfies("PyInstaller >= 6.0"): + module_collection_mode = "pyz+py" -datas = collect_data_files("torch", excludes=["**/*.h", "**/*.hpp", "**/*.cuh", - "**/*.lib", "**/*.cpp", "**/*.pyi", "**/*.cmake"]) -binaries = collect_dynamic_libs("torch") -hiddenimports = collect_submodules("torch") + datas = collect_data_files( + "torch", + excludes=[ + "**/*.h", + "**/*.hpp", + "**/*.cuh", + "**/*.lib", + "**/*.cpp", + "**/*.pyi", + "**/*.cmake", + ], + ) + binaries = collect_dynamic_libs("torch") + hiddenimports = collect_submodules("torch") +else: + datas = [(get_package_paths("torch")[1], "torch")] # With torch 2.0.0, PyInstaller's modulegraph analysis hits the recursion limit. # So, unless the user has already done so, increase it automatically. -if is_module_satisfies('torch >= 2.0.0'): +if is_module_satisfies("torch >= 2.0.0"): import sys + new_limit = 5000 if sys.getrecursionlimit() < new_limit: logger.info("hook-torch: raising recursion limit to %d", new_limit)