Skip to content

Commit

Permalink
Requires PyInstaller 6.x
Browse files Browse the repository at this point in the history
  • Loading branch information
CarlGao4 committed Dec 4, 2023
1 parent f3b19f7 commit 70d097a
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 8 deletions.
2 changes: 1 addition & 1 deletion news/666.update.rst
Original file line number Diff line number Diff line change
@@ -1 +1 @@
Discard header files (and static libraries on Windows) of ``torch``
Modernize the torch for ``hook`` and reduce the amount of unnecessarily collected data files (header files and static libraries). Requires PyInstaller >= 6.0.
35 changes: 28 additions & 7 deletions src/_pyinstaller_hooks_contrib/hooks/stdhooks/hook-torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,40 @@
# SPDX-License-Identifier: GPL-2.0-or-later
# ------------------------------------------------------------------

from PyInstaller.utils.hooks import logger, collect_data_files, is_module_satisfies, collect_dynamic_libs, collect_submodules
from PyInstaller.utils.hooks import (
logger,
collect_data_files,
is_module_satisfies,
collect_dynamic_libs,
collect_submodules,
get_package_paths,
)

module_collection_mode = 'pyz+py'
if is_module_satisfies("PyInstaller >= 6.0"):
module_collection_mode = "pyz+py"

datas = collect_data_files("torch", excludes=["**/*.h", "**/*.hpp", "**/*.cuh",
"**/*.lib", "**/*.cpp", "**/*.pyi", "**/*.cmake"])
binaries = collect_dynamic_libs("torch")
hiddenimports = collect_submodules("torch")
datas = collect_data_files(
"torch",
excludes=[
"**/*.h",
"**/*.hpp",
"**/*.cuh",
"**/*.lib",
"**/*.cpp",
"**/*.pyi",
"**/*.cmake",
],
)
binaries = collect_dynamic_libs("torch")
hiddenimports = collect_submodules("torch")
else:
datas = [(get_package_paths("torch")[1], "torch")]

# With torch 2.0.0, PyInstaller's modulegraph analysis hits the recursion limit.
# So, unless the user has already done so, increase it automatically.
if is_module_satisfies('torch >= 2.0.0'):
if is_module_satisfies("torch >= 2.0.0"):
import sys

new_limit = 5000
if sys.getrecursionlimit() < new_limit:
logger.info("hook-torch: raising recursion limit to %d", new_limit)
Expand Down

0 comments on commit 70d097a

Please sign in to comment.