diff --git a/hubconf.py b/hubconf.py index 7d599deb4..66788b84a 100644 --- a/hubconf.py +++ b/hubconf.py @@ -66,7 +66,7 @@ def nvidia_ncf(pretrained=True, **kwargs): checkpoint = 'https://developer.nvidia.com/joc-ncf-fp16-pyt-20190225' else: checkpoint = 'https://developer.nvidia.com/joc-ncf-fp32-pyt-20190225' - ckpt_file = os.path.basename(checkpoint) + ckpt_file = os.path.join(torch.hub._get_torch_home(), "checkpoints", os.path.basename(checkpoint)) if not os.path.exists(ckpt_file) or force_reload: sys.stderr.write('Downloading checkpoint from {}\n'.format(checkpoint)) urllib.request.urlretrieve(checkpoint, ckpt_file) @@ -130,7 +130,7 @@ def nvidia_tacotron2(pretrained=True, **kwargs): checkpoint = 'https://developer.nvidia.com/joc-tacotron2-fp16-pyt-20190306' else: checkpoint = 'https://developer.nvidia.com/joc-tacotron2-fp32-pyt-20190306' - ckpt_file = os.path.basename(checkpoint) + ckpt_file = os.path.join(torch.hub._get_torch_home(), "checkpoints", os.path.basename(checkpoint)) if not os.path.exists(ckpt_file) or force_reload: sys.stderr.write('Downloading checkpoint from {}\n'.format(checkpoint)) urllib.request.urlretrieve(checkpoint, ckpt_file) @@ -190,7 +190,7 @@ def nvidia_waveglow(pretrained=True, **kwargs): checkpoint = 'https://developer.nvidia.com/joc-waveglow-fp16-pyt-20190306' else: checkpoint = 'https://developer.nvidia.com/joc-waveglow-fp32-pyt-20190306' - ckpt_file = os.path.basename(checkpoint) + ckpt_file = os.path.join(torch.hub._get_torch_home(), "checkpoints", os.path.basename(checkpoint)) if not os.path.exists(ckpt_file) or force_reload: sys.stderr.write('Downloading checkpoint from {}\n'.format(checkpoint)) urllib.request.urlretrieve(checkpoint, ckpt_file) @@ -360,7 +360,7 @@ def batchnorm_to_float(module): checkpoint = 'https://developer.nvidia.com/joc-ssd-fp16-pyt-20190225' else: checkpoint = 'https://developer.nvidia.com/joc-ssd-fp32-pyt-20190225' - ckpt_file = os.path.basename(checkpoint) + ckpt_file = os.path.join(torch.hub._get_torch_home(), "checkpoints", os.path.basename(checkpoint)) if not os.path.exists(ckpt_file) or force_reload: sys.stderr.write('Downloading checkpoint from {}\n'.format(checkpoint)) urllib.request.urlretrieve(checkpoint, ckpt_file)