Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dlclibrary/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,5 @@
# Licensed under GNU Lesser General Public License v3.0
#

from dlclibrary.dlcmodelzoo.modelzoo_download import download_hugginface_model
from dlclibrary.dlcmodelzoo.modelzoo_download import download_huggingface_model
from dlclibrary.version import __version__, VERSION
88 changes: 32 additions & 56 deletions dlclibrary/dlcmodelzoo/modelzoo_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
"mouse_pupil_vclose",
"horse_sideview",
"full_macaque",
"superanimal_mouse",
"superanimal_mouse_topview",
]


Expand All @@ -29,83 +29,59 @@ def _get_dlclibrary_path():
return os.path.split(importlib.util.find_spec("dlclibrary").origin)[0]


def _loadmodelnames():
"""Loads URLs and commit hashes for available models."""
def _load_model_names():
"""Load URLs and commit hashes for available models."""
from ruamel.yaml import YAML

fn = os.path.join(_get_dlclibrary_path(), "modelzoo_urls.yaml")
with open(fn) as file:
return YAML().load(file)


def download_huggingface_model(modelname, target_dir=".", removeHFfolder=True):
def download_huggingface_model(modelname, target_dir=".", remove_hf_folder=True):
"""
Downloads a DeepLabCut Model Zoo Project from Hugging Face
Download a DeepLabCut Model Zoo Project from Hugging Face

Parameters
----------
modelname : string
Name of the ModelZoo model. For visualizations see: http://www.mackenziemathislab.org/dlc-modelzoo
target_dir : directory (as string)
Directory where to store the model weigths and pose_cfg.yaml file
removeHFfolder : bool, default True
Directory where to store the model weights and pose_cfg.yaml file
remove_hf_folder : bool, default True
Whether to remove the directory structure provided by HuggingFace after downloading and decompressing data into DeepLabCut format.
"""
from huggingface_hub import hf_hub_download
import tarfile, os
import tarfile
from pathlib import Path

neturls = _loadmodelnames()
neturls = _load_model_names()
if modelname not in neturls:
raise ValueError(f"`modelname` should be one of: {', '.join(modelname)}.")

if modelname in neturls.keys():
print("Loading....", modelname)
url = neturls[modelname].split("/")
repo_id, targzfn = url[0] + "/" + url[1], str(url[-1])
print("Loading....", modelname)
url = neturls[modelname].split("/")
repo_id, targzfn = url[0] + "/" + url[1], str(url[-1])

hf_hub_download(repo_id, targzfn, cache_dir=str(target_dir))
# creates a new subfolder as indicated below, unzipping from there and deleting this folder
hf_hub_download(repo_id, targzfn, cache_dir=str(target_dir))

# Building the HuggingFaceHub download path:
hf_path = (
"models--"
+ url[0]
+ "--"
+ url[1]
+ "/snapshots/"
+ str(neturls[modelname + "_commit"])
+ "/"
+ targzfn
)
# Create a new subfolder as indicated below, unzipping from there and deleting this folder
hf_folder = f"models--{url[0]}--{url[1]}"
hf_path = os.path.join(
hf_folder,
"snapshots",
str(neturls[modelname + "_commit"]),
targzfn,
)

filename = os.path.join(target_dir, hf_path)
with tarfile.open(filename, mode="r:gz") as tar:
for member in tar:
if not member.isdir():
fname = Path(member.name).name # getting the filename
tar.makefile(member, target_dir + "/" + fname)
# tar.extractall(target_dir, members=tarfilenamecutting(tar))
filename = os.path.join(target_dir, hf_path)
with tarfile.open(filename, mode="r:gz") as tar:
for member in tar:
if not member.isdir():
fname = Path(member.name).name
tar.makefile(member, os.path.join(target_dir, fname))

if removeHFfolder:
# Removing folder
import shutil
if remove_hf_folder:
import shutil

shutil.rmtree(
Path(os.path.join(target_dir, "models--" + url[0] + "--" + url[1]))
)

else:
models = [fn for fn in neturls.keys()]
print("Model does not exist: ", modelname)
print("Pick one of the following: ", MODELOPTIONS)


if __name__ == "__main__":
print("Randomly downloading a model for testing...")

import random

# modelname = 'full_cat'
modelname = random.choice(MODELOPTIONS)

target_dir = "/Users/alex/Downloads" # folder has to exist!
download_hugginface_model(modelname, target_dir)
shutil.rmtree(os.path.join(target_dir, hf_folder))
4 changes: 2 additions & 2 deletions dlclibrary/modelzoo_urls.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,5 @@ horse_sideview_commit: fd0329b2ffc8fe7a5e6eb3d4850ebca75987e92c
full_macaque: mwmathis/DeepLabCutModelZoo-macaque_full/DLC_macaque_full_resnet50.tar.gz
full_macaque_commit: 4c7ebf2628d5b7eb0483356595256fb01b7e1a9e

superanimal_mouse: mwmathis/DeepLabCutModelZoo-SuperAnimal-TopViewMouse/DLC_ma_supertopview5k_resnet_50_iteration-0_shuffle-1.tar.gz
superanimal_mouse_commit: a7d7df40c3307a3c7a0ceeb2593d46a783235b28
superanimal_mouse_topview: mwmathis/DeepLabCutModelZoo-SuperAnimal-TopViewMouse/DLC_ma_supertopview5k_resnet_50_iteration-0_shuffle-1.tar.gz
superanimal_mouse_topview_commit: a7d7df40c3307a3c7a0ceeb2593d46a783235b28
14 changes: 10 additions & 4 deletions tests/test_modeldownload.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,21 @@
#
# Licensed under GNU Lesser General Public License v3.0
#
import dlclibrary
import os
import pytest


def test_catdownload(tmp_path_factory):
# TODO: just download the lightweight stuff..
import dlclibrary, os

def test_download_huggingface_model(tmp_path_factory):
folder = tmp_path_factory.mktemp("cat")
dlclibrary.download_huggingface_model("full_cat", str(folder))

assert os.path.exists(folder / "pose_cfg.yaml")
assert os.path.exists(folder / "snapshot-75000.meta")
# Verify that the Hugging Face folder was removed
assert not any(f.startswith("models--") for f in os.listdir(folder))


def test_download_huggingface_wrong_model():
with pytest.raises(ValueError):
dlclibrary.download_huggingface_model("wrong_model_name")