Pytorch hub models don't download to correct directory #119

@onetonfoot

Currently, the code that downloads a model's weights doesn't use torch.hub.hub_dir, so the weights are downloaded to the current working directory instead of the directory set by default or by torch.hub.set_dir. It also means that the weights are re-downloaded each time the working directory changes, e.g.:

import os
import torch

ssd_model = torch.hub.load('NVIDIA/DeepLearningExamples', 'nvidia_ssd')

os.chdir("../")

# Downloads the weights a second time rather than using the cache
ssd_model = torch.hub.load('NVIDIA/DeepLearningExamples', 'nvidia_ssd')

One possible fix would be to download with torch.hub.load_state_dict_from_url instead. However, it requires the filenames to follow the naming convention model-<sha256>.pth, so the weight files would have to be renamed.
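To illustrate the renaming step, here is a minimal sketch that derives a hash-suffixed filename from an existing weight file. As far as I understand the PyTorch docs, the hash check (check_hash=True) expects the suffix to be the first eight or more hex digits of the file's SHA256. The helper name hashed_filename and the prefix_len parameter are hypothetical, not part of torch.

```python
import hashlib
import os


def hashed_filename(path, prefix_len=8):
    """Return '<stem>-<sha256 prefix><ext>' for the file at path.

    Hypothetical helper; prefix_len=8 assumes the 'first eight or more
    digits' convention described for check_hash=True.
    """
    sha256 = hashlib.sha256()
    with open(path, "rb") as f:
        # Read in 1 MiB chunks so large checkpoints are not loaded at once.
        for chunk in iter(lambda: f.read(1 << 20), b""):
            sha256.update(chunk)
    stem, ext = os.path.splitext(os.path.basename(path))
    return f"{stem}-{sha256.hexdigest()[:prefix_len]}{ext}"
```

The hosted file would then need to be published under the returned name so that the URL passed to torch.hub.load_state_dict_from_url carries the hash.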

Alternatively, something along the lines of:

ckpt_file = os.path.join(
    torch.hub._get_torch_home(), "checkpoints", os.path.basename(checkpoint))
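A slightly fuller sketch of that idea, assuming the checkpoint is fetched from a plain URL. The names cached_checkpoint_path and download_checkpoint, and the cache_root and fetch parameters, are hypothetical. The cache directory is passed in rather than hard-coded, so the caller decides which torch.hub location to use.

```python
import os
import urllib.request
from urllib.parse import urlparse


def cached_checkpoint_path(url, cache_root):
    """Path under <cache_root>/checkpoints where the file behind url is stored."""
    filename = os.path.basename(urlparse(url).path)
    return os.path.join(cache_root, "checkpoints", filename)


def download_checkpoint(url, cache_root, fetch=urllib.request.urlretrieve):
    """Download url into the cache unless it is already there; return the path."""
    ckpt_file = cached_checkpoint_path(url, cache_root)
    if not os.path.exists(ckpt_file):
        os.makedirs(os.path.dirname(ckpt_file), exist_ok=True)
        fetch(url, ckpt_file)
    return ckpt_file
```

In the hub entrypoint, cache_root could be torch.hub._get_torch_home() as above, or the public torch.hub.get_dir() if the installed version provides it. Either way the path no longer depends on the current working directory, so a second torch.hub.load after os.chdir would reuse the cached file.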

Thanks
