diff --git a/setup.py b/setup.py index 9e31e945e..bc228e473 100644 --- a/setup.py +++ b/setup.py @@ -5,8 +5,6 @@ import os import subprocess import time -import torch -from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension version_file = 'basicsr/version.py' @@ -117,6 +115,12 @@ def get_requirements(filename='requirements.txt'): if __name__ == '__main__': cuda_ext = os.getenv('BASICSR_EXT') # whether compile cuda ext if cuda_ext == 'True': + try: + import torch + from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension + except ImportError: + raise ImportError('Unable to import torch - torch is needed to build cuda extensions') + ext_modules = [ make_cuda_ext( name='deform_conv_ext', @@ -134,8 +138,10 @@ def get_requirements(filename='requirements.txt'): sources=['src/upfirdn2d.cpp'], sources_cuda=['src/upfirdn2d_kernel.cu']), ] + setup_kwargs = dict(cmdclass={'build_ext': BuildExtension}) else: ext_modules = [] + setup_kwargs = dict() write_version_py() setup( @@ -159,8 +165,8 @@ def get_requirements(filename='requirements.txt'): 'Programming Language :: Python :: 3.8', ], license='Apache License 2.0', - setup_requires=['cython', 'numpy'], + setup_requires=['cython', 'numpy', 'torch'], install_requires=get_requirements(), ext_modules=ext_modules, - cmdclass={'build_ext': BuildExtension}, - zip_safe=False) + zip_safe=False, + **setup_kwargs)