Skip to content

Commit

Permalink
Add torch to setup_requires & dynamic import to prevent import erro…
Browse files Browse the repository at this point in the history
…rs when installing via pip (#514)

* dynamic import of torch to prevent import error when installing

* Update setup.py

Co-authored-by: Xintao <wxt1994@126.com>
Co-authored-by: Hans Brouwer <hans@brouwer.work>
  • Loading branch information
3 people committed Aug 30, 2022
1 parent b4f48db commit 3974c3f
Showing 1 changed file with 11 additions and 5 deletions.
16 changes: 11 additions & 5 deletions setup.py
Expand Up @@ -5,8 +5,6 @@
import os
import subprocess
import time
import torch
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension

version_file = 'basicsr/version.py'

Expand Down Expand Up @@ -117,6 +115,12 @@ def get_requirements(filename='requirements.txt'):
if __name__ == '__main__':
cuda_ext = os.getenv('BASICSR_EXT') # whether compile cuda ext
if cuda_ext == 'True':
try:
import torch
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension
except ImportError:
raise ImportError('Unable to import torch - torch is needed to build cuda extensions')

ext_modules = [
make_cuda_ext(
name='deform_conv_ext',
Expand All @@ -134,8 +138,10 @@ def get_requirements(filename='requirements.txt'):
sources=['src/upfirdn2d.cpp'],
sources_cuda=['src/upfirdn2d_kernel.cu']),
]
setup_kwargs = dict(cmdclass={'build_ext': BuildExtension})
else:
ext_modules = []
setup_kwargs = dict()

write_version_py()
setup(
Expand All @@ -159,8 +165,8 @@ def get_requirements(filename='requirements.txt'):
'Programming Language :: Python :: 3.8',
],
license='Apache License 2.0',
setup_requires=['cython', 'numpy'],
setup_requires=['cython', 'numpy', 'torch'],
install_requires=get_requirements(),
ext_modules=ext_modules,
cmdclass={'build_ext': BuildExtension},
zip_safe=False)
zip_safe=False,
**setup_kwargs)

0 comments on commit 3974c3f

Please sign in to comment.