Closed
Description
Currently, the code that downloads a model's weights doesn't use torch.hub.hub_dir, so the weights are downloaded to the current working directory instead of the directory set by default or by torch.hub.set_dir. It also means that the weights are re-downloaded every time the working directory changes, e.g.:
```python
import os

import torch

ssd_model = torch.hub.load('NVIDIA/DeepLearningExamples', 'nvidia_ssd')
os.chdir("../")
ssd_model = torch.hub.load('NVIDIA/DeepLearningExamples', 'nvidia_ssd')
# Downloads the weights a second time rather than using the cache
```

One possible fix could be to use torch.hub.load_state_dict_from_url for the download instead. However, when its hash check is enabled (check_hash=True), it requires the filenames to follow the naming convention model-<sha256>.pth, so the weight files would have to be renamed.
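To make the renaming concrete: with check_hash=True, torch.hub extracts the hex run between a "-" and the following "." in the filename (e.g. a hypothetical model-1a2b3c4d.pth) and checks that it is a prefix of the file's SHA-256 digest. The sketch below reproduces that check with the standard library and computes a compliant name. The regex is modeled on the one torch.hub uses; hashed_name and name_matches_contents are hypothetical helpers, not torch APIs.

```python
import hashlib
import os
import re

# Modeled on torch.hub's pattern: the hex run between "-" and the next "."
HASH_REGEX = re.compile(r"-([a-f0-9]*)\.")


def sha256_of(path, chunk_size=1 << 20):
    """Return the hex SHA-256 digest of the file at `path`, read in chunks."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def hashed_name(path, prefix_len=8):
    """Return <stem>-<first prefix_len hex digits of SHA-256><ext> for `path`."""
    stem, ext = os.path.splitext(os.path.basename(path))
    return f"{stem}-{sha256_of(path)[:prefix_len]}{ext}"


def name_matches_contents(path):
    """True if the filename carries a hash prefix that matches the file's digest."""
    match = HASH_REGEX.search(os.path.basename(path))
    return bool(match) and sha256_of(path).startswith(match.group(1))
```

A file without a hash in its name fails the check; after renaming it to the value returned by hashed_name, it passes.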
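Alternatively, something along the lines of:

```python
import os

import torch

# `checkpoint` is the URL of the weight file; the file is cached under
# <torch home>/checkpoints, keyed by its basename, independent of the cwd
ckpt_file = os.path.join(
    torch.hub._get_torch_home(), "checkpoints", os.path.basename(checkpoint))
```

Thanks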
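For the alternative to avoid re-downloading, the entrypoint also has to skip the download when the cached file already exists. A minimal sketch of that check, assuming the cache root is passed in by the caller (in practice torch.hub.get_dir(), which honors torch.hub.set_dir, rather than the private _get_torch_home()). The function name cached_checkpoint is hypothetical, and urllib.request.urlretrieve stands in for whatever downloader is actually used.

```python
import os
import urllib.parse
import urllib.request


def cached_checkpoint(url, cache_root):
    """Return a local path for `url`, downloading only on a cache miss.

    `cache_root` stands in for torch.hub.get_dir(); files are stored in its
    "checkpoints" subdirectory, so the result does not depend on the cwd.
    """
    ckpt_dir = os.path.join(cache_root, "checkpoints")
    os.makedirs(ckpt_dir, exist_ok=True)
    filename = os.path.basename(urllib.parse.urlparse(url).path)
    ckpt_file = os.path.join(ckpt_dir, filename)
    if not os.path.exists(ckpt_file):
        # Download under a temporary name first so an interrupted transfer
        # never leaves a truncated file that looks like a cache hit.
        tmp_file = ckpt_file + ".partial"
        urllib.request.urlretrieve(url, tmp_file)
        os.replace(tmp_file, ckpt_file)
    return ckpt_file
```

Called as `cached_checkpoint(checkpoint, torch.hub.get_dir())`, both torch.hub.load calls in the example above would resolve to the same file, and the second call would not touch the network.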