diff --git a/Dockerfile b/Dockerfile index 52198a69f..e033f7a81 100644 --- a/Dockerfile +++ b/Dockerfile @@ -50,8 +50,9 @@ RUN annif completion --bash >> /etc/bash.bashrc # Enable tab completion RUN groupadd -g 998 annif_user && \ useradd -r -u 998 -g annif_user annif_user && \ chmod -R a+rX /Annif && \ - mkdir -p /Annif/tests/data && \ + mkdir -p /Annif/tests/data /Annif/projects.d && \ chown -R annif_user:annif_user /annif-projects /Annif/tests/data USER annif_user +ENV HF_HOME="/tmp" CMD annif diff --git a/annif/cli.py b/annif/cli.py index d8ca1ea56..cc62a0f96 100644 --- a/annif/cli.py +++ b/annif/cli.py @@ -17,8 +17,12 @@ import annif.parallel import annif.project import annif.registry -from annif import cli_util -from annif.exception import NotInitializedException, NotSupportedException +from annif import cli_util, hfh_util +from annif.exception import ( + NotInitializedException, + NotSupportedException, + OperationFailedException, +) from annif.project import Access from annif.util import metric_code @@ -582,6 +586,124 @@ def run_hyperopt(project_id, paths, docs_limit, trials, jobs, metric, results_fi click.echo("---") +@cli.command("upload") +@click.argument("project_ids_pattern", shell_complete=cli_util.complete_param) +@click.argument("repo_id") +@click.option( + "--token", + help="""Authentication token, obtained from the Hugging Face Hub. + Will default to the stored token.""", +) +@click.option( + "--revision", + help="""An optional git revision to commit from. Defaults to the head of the "main" + branch.""", +) +@click.option( + "--commit-message", + help="""The summary / title / first line of the generated commit.""", +) +@cli_util.common_options +def run_upload(project_ids_pattern, repo_id, token, revision, commit_message): + """ + Upload selected projects and their vocabularies to a Hugging Face Hub repository. + \f + This command zips the project directories and vocabularies of the projects + that match the given `project_ids_pattern` to archive files, and uploads the + archives along with the project configurations to the specified Hugging Face + Hub repository. An authentication token and commit message can be given with + options. + """ + from huggingface_hub import HfApi + from huggingface_hub.utils import HfHubHTTPError, HFValidationError + + projects = hfh_util.get_matching_projects(project_ids_pattern) + click.echo(f"Uploading project(s): {', '.join([p.project_id for p in projects])}") + + commit_message = ( + commit_message + if commit_message is not None + else f"Upload project(s) {project_ids_pattern} with Annif" + ) + + fobjs, operations = [], [] + try: + fobjs, operations = hfh_util.prepare_commits(projects, repo_id) + api = HfApi() + api.create_commit( + repo_id=repo_id, + operations=operations, + commit_message=commit_message, + revision=revision, + token=token, + ) + except (HfHubHTTPError, HFValidationError) as err: + raise OperationFailedException(str(err)) + finally: + for fobj in fobjs: + fobj.close() + + +@cli.command("download") +@click.argument("project_ids_pattern") +@click.argument("repo_id") +@click.option( + "--token", + help="""Authentication token, obtained from the Hugging Face Hub. + Will default to the stored token.""", +) +@click.option( + "--revision", + help=""" + An optional Git revision id which can be a branch name, a tag, or a commit + hash. + """, +) +@click.option( + "--force", + "-f", + default=False, + is_flag=True, + help="Replace an existing project/vocabulary/config with the downloaded one", +) +@cli_util.common_options +def run_download(project_ids_pattern, repo_id, token, revision, force): + """ + Download selected projects and their vocabularies from a Hugging Face Hub + repository. + \f + This command downloads the project and vocabulary archives and the + configuration files of the projects that match the given + `project_ids_pattern` from the specified Hugging Face Hub repository and + unzips the archives to `data/` directory and places the configuration files + to `projects.d/` directory. An authentication token and revision can + be given with options. + """ + + project_ids = hfh_util.get_matching_project_ids_from_hf_hub( + project_ids_pattern, repo_id, token, revision + ) + click.echo(f"Downloading project(s): {', '.join(project_ids)}") + + vocab_ids = set() + for project_id in project_ids: + project_zip_cache_path = hfh_util.download_from_hf_hub( + f"projects/{project_id}.zip", repo_id, token, revision + ) + hfh_util.unzip_archive(project_zip_cache_path, force) + config_file_cache_path = hfh_util.download_from_hf_hub( + f"{project_id}.cfg", repo_id, token, revision + ) + vocab_ids.add(hfh_util.get_vocab_id_from_config(config_file_cache_path)) + hfh_util.copy_project_config(config_file_cache_path, force) + + for vocab_id in vocab_ids: + vocab_zip_cache_path = hfh_util.download_from_hf_hub( + f"vocabs/{vocab_id}.zip", repo_id, token, revision + ) + hfh_util.unzip_archive(vocab_zip_cache_path, force) + + @cli.command("completion") @click.option("--bash", "shell", flag_value="bash") @click.option("--zsh", "shell", flag_value="zsh") diff --git a/annif/cli_util.py b/annif/cli_util.py index 9f33f8153..2a64582f2 100644 --- a/annif/cli_util.py +++ b/annif/cli_util.py @@ -17,8 +17,8 @@ from annif.project import Access if TYPE_CHECKING: + import io from datetime import datetime - from io import TextIOWrapper from click.core import Argument, Context, Option @@ -185,7 +185,7 @@ def show_hits( hits: SuggestionResult, project: AnnifProject, lang: str, - file: TextIOWrapper | None = None, + file: io.TextIOWrapper | None = None, ) -> None: """ Print subject suggestions to the console or a file. The suggestions are displayed as @@ -234,7 +234,7 @@ def generate_filter_params(filter_batch_max_limit: int) -> list[tuple[int, float def _get_completion_choices( param: Argument, ) -> dict[str, AnnifVocabulary] | dict[str, AnnifProject] | list: - if param.name == "project_id": + if param.name in ("project_id", "project_ids_pattern"): return annif.registry.get_projects() elif param.name == "vocab_id": return annif.registry.get_vocabs() diff --git a/annif/hfh_util.py b/annif/hfh_util.py new file mode 100644 index 000000000..045e4710f --- /dev/null +++ b/annif/hfh_util.py @@ -0,0 +1,240 @@ +"""Utility functions for interactions with Hugging Face Hub.""" + +import binascii +import configparser +import importlib +import io +import os +import pathlib +import shutil +import tempfile +import time +import zipfile +from fnmatch import fnmatch +from typing import Any + +import click +from flask import current_app + +import annif +from annif.exception import OperationFailedException +from annif.project import Access, AnnifProject + +logger = annif.logger + + +def get_matching_projects(pattern: str) -> list[AnnifProject]: + """ + Get projects that match the given pattern. + """ + return [ + proj + for proj in annif.registry.get_projects(min_access=Access.private).values() + if fnmatch(proj.project_id, pattern) + ] + + +def prepare_commits(projects: list[AnnifProject], repo_id: str) -> tuple[list, list]: + """Prepare and pre-upload data and config commit operations for projects to a + Hugging Face Hub repository.""" + from huggingface_hub import preupload_lfs_files + + fobjs, operations = [], [] + data_dirs = {p.datadir for p in projects} + vocab_dirs = {p.vocab.datadir for p in projects} + all_dirs = data_dirs.union(vocab_dirs) + + for data_dir in all_dirs: + fobj, operation = _prepare_datadir_commit(data_dir) + preupload_lfs_files(repo_id, additions=[operation]) + fobjs.append(fobj) + operations.append(operation) + + for project in projects: + fobj, operation = _prepare_config_commit(project) + fobjs.append(fobj) + operations.append(operation) + + return fobjs, operations + + +def _prepare_datadir_commit(data_dir: str) -> tuple[io.BufferedRandom, Any]: + from huggingface_hub import CommitOperationAdd + + zip_repo_path = data_dir.split(os.path.sep, 1)[1] + ".zip" + fobj = _archive_dir(data_dir) + operation = CommitOperationAdd(path_in_repo=zip_repo_path, path_or_fileobj=fobj) + return fobj, operation + + +def _prepare_config_commit(project: AnnifProject) -> tuple[io.BytesIO, Any]: + from huggingface_hub import CommitOperationAdd + + config_repo_path = project.project_id + ".cfg" + fobj = _get_project_config(project) + operation = CommitOperationAdd(path_in_repo=config_repo_path, path_or_fileobj=fobj) + return fobj, operation + + +def _is_train_file(fname: str) -> bool: + train_file_patterns = ("-train", "tmp-") + for pat in train_file_patterns: + if pat in fname: + return True + return False + + +def _archive_dir(data_dir: str) -> io.BufferedRandom: + fp = tempfile.TemporaryFile() + path = pathlib.Path(data_dir) + fpaths = [fpath for fpath in path.glob("**/*") if not _is_train_file(fpath.name)] + with zipfile.ZipFile(fp, mode="w") as zfile: + zfile.comment = bytes( + f"Archived by Annif {importlib.metadata.version('annif')}", + encoding="utf-8", + ) + for fpath in fpaths: + logger.debug(f"Adding {fpath}") + arcname = os.path.join(*fpath.parts[1:]) + zfile.write(fpath, arcname=arcname) + fp.seek(0) + return fp + + +def _get_project_config(project: AnnifProject) -> io.BytesIO: + fp = tempfile.TemporaryFile(mode="w+t") + config = configparser.ConfigParser() + config[project.project_id] = project.config + config.write(fp) # This needs tempfile in text mode + fp.seek(0) + # But for upload fobj needs to be in binary mode + return io.BytesIO(fp.read().encode("utf8")) + + +def get_matching_project_ids_from_hf_hub( + project_ids_pattern: str, repo_id: str, token, revision: str +) -> list[str]: + """Get project IDs of the projects in a Hugging Face Model Hub repository that match + the given pattern.""" + all_repo_file_paths = _list_files_in_hf_hub(repo_id, token, revision) + return [ + path.rsplit(".cfg")[0] + for path in all_repo_file_paths + if fnmatch(path, f"{project_ids_pattern}.cfg") + ] + + +def _list_files_in_hf_hub(repo_id: str, token: str, revision: str) -> list[str]: + from huggingface_hub import list_repo_files + from huggingface_hub.utils import HfHubHTTPError, HFValidationError + + try: + return [ + repofile + for repofile in list_repo_files( + repo_id=repo_id, token=token, revision=revision + ) + ] + except (HfHubHTTPError, HFValidationError) as err: + raise OperationFailedException(str(err)) + + +def download_from_hf_hub( + filename: str, repo_id: str, token: str, revision: str +) -> list[str]: + from huggingface_hub import hf_hub_download + from huggingface_hub.utils import HfHubHTTPError, HFValidationError + + try: + return hf_hub_download( + repo_id=repo_id, + filename=filename, + token=token, + revision=revision, + ) + except (HfHubHTTPError, HFValidationError) as err: + raise OperationFailedException(str(err)) + + +def unzip_archive(src_path: str, force: bool) -> None: + """Unzip a zip archive of projects and vocabularies to a directory, by + default data/ under current directory.""" + datadir = current_app.config["DATADIR"] + with zipfile.ZipFile(src_path, "r") as zfile: + archive_comment = str(zfile.comment, encoding="utf-8") + logger.debug( + f'Extracting archive {src_path}; archive comment: "{archive_comment}"' + ) + for member in zfile.infolist(): + _unzip_member(zfile, member, datadir, force) + + +def _unzip_member( + zfile: zipfile.ZipFile, member: zipfile.ZipInfo, datadir: str, force: bool +) -> None: + dest_path = os.path.join(datadir, member.filename) + if os.path.exists(dest_path) and not force: + _handle_existing_file(member, dest_path) + return + logger.debug(f"Unzipping to {dest_path}") + zfile.extract(member, path=datadir) + _restore_timestamps(member, dest_path) + + +def _handle_existing_file(member: zipfile.ZipInfo, dest_path: str) -> None: + if _are_identical_member_and_file(member, dest_path): + logger.debug(f"Skipping unzip to {dest_path}; already in place") + else: + click.echo(f"Not overwriting {dest_path} (use --force to override)") + + +def _are_identical_member_and_file(member: zipfile.ZipInfo, dest_path: str) -> bool: + path_crc = _compute_crc32(dest_path) + return path_crc == member.CRC + + +def _restore_timestamps(member: zipfile.ZipInfo, dest_path: str) -> None: + date_time = time.mktime(member.date_time + (0, 0, -1)) + os.utime(dest_path, (date_time, date_time)) + + +def copy_project_config(src_path: str, force: bool) -> None: + """Copy a given project configuration file to projects.d/ directory.""" + project_configs_dest_dir = "projects.d" + os.makedirs(project_configs_dest_dir, exist_ok=True) + + dest_path = os.path.join(project_configs_dest_dir, os.path.basename(src_path)) + if os.path.exists(dest_path) and not force: + if _are_identical_files(src_path, dest_path): + logger.debug(f"Skipping copy to {dest_path}; already in place") + else: + click.echo(f"Not overwriting {dest_path} (use --force to override)") + else: + logger.debug(f"Copying to {dest_path}") + shutil.copy(src_path, dest_path) + + +def _are_identical_files(src_path: str, dest_path: str) -> bool: + src_crc32 = _compute_crc32(src_path) + dest_crc32 = _compute_crc32(dest_path) + return src_crc32 == dest_crc32 + + +def _compute_crc32(path: str) -> int: + if os.path.isdir(path): + return 0 + + size = 1024 * 1024 * 10 # 10 MiB chunks + with open(path, "rb") as fp: + crcval = 0 + while chunk := fp.read(size): + crcval = binascii.crc32(chunk, crcval) + return crcval + + +def get_vocab_id_from_config(config_path: str) -> str: + """Get the vocabulary ID from a configuration file.""" + config = configparser.ConfigParser() + config.read(config_path) + section = config.sections()[0] + return config[section]["vocab"] diff --git a/docs/source/commands.rst b/docs/source/commands.rst index 849f6aadf..a5ae2f46f 100644 --- a/docs/source/commands.rst +++ b/docs/source/commands.rst @@ -66,6 +66,20 @@ Project administration N/A +.. click:: annif.cli:run_upload + :prog: annif upload + +**REST equivalent** + + N/A + +.. click:: annif.cli:run_download + :prog: annif download + +**REST equivalent** + + N/A + **************************** Subject index administration **************************** diff --git a/pyproject.toml b/pyproject.toml index e83921c3b..be416961b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,6 +48,7 @@ python-dateutil = "2.8.*" tomli = { version = "2.0.*", python = "<3.11" } simplemma = "0.9.*" jsonschema = "4.21.*" +huggingface-hub = "0.22.*" fasttext-wheel = { version = "0.9.2", optional = true } voikko = { version = "0.5.*", optional = true } diff --git a/tests/huggingface-cache/dummy-en.cfg b/tests/huggingface-cache/dummy-en.cfg new file mode 100644 index 000000000..58398e8d0 --- /dev/null +++ b/tests/huggingface-cache/dummy-en.cfg @@ -0,0 +1,7 @@ +[dummy-en] +name=Dummy English +language=en +backend=dummy +analyzer=snowball(english) +vocab=dummy +access=hidden diff --git a/tests/huggingface-cache/dummy-fi.cfg b/tests/huggingface-cache/dummy-fi.cfg new file mode 100644 index 000000000..4d996f9b6 --- /dev/null +++ b/tests/huggingface-cache/dummy-fi.cfg @@ -0,0 +1,8 @@ +[dummy-fi] +name=Dummy Finnish +language=fi +backend=dummy +analyzer=snowball(finnish) +key=value +vocab=dummy +access=public diff --git a/tests/huggingface-cache/projects/dummy-en.zip b/tests/huggingface-cache/projects/dummy-en.zip new file mode 100644 index 000000000..5325bf527 Binary files /dev/null and b/tests/huggingface-cache/projects/dummy-en.zip differ diff --git a/tests/huggingface-cache/projects/dummy-fi.zip b/tests/huggingface-cache/projects/dummy-fi.zip new file mode 100644 index 000000000..3c6f29f4a Binary files /dev/null and b/tests/huggingface-cache/projects/dummy-fi.zip differ diff --git a/tests/huggingface-cache/vocabs/dummy.zip b/tests/huggingface-cache/vocabs/dummy.zip new file mode 100644 index 000000000..b43a5f3eb Binary files /dev/null and b/tests/huggingface-cache/vocabs/dummy.zip differ diff --git a/tests/test_cli.py b/tests/test_cli.py index 77adeab0f..46a9fa0ad 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -12,8 +12,11 @@ from click.shell_completion import ShellComplete from click.testing import CliRunner +from huggingface_hub.utils import HFValidationError import annif.cli +import annif.cli_util +import annif.hfh_util import annif.parallel runner = CliRunner(env={"ANNIF_CONFIG": "annif.default_config.TestingConfig"}) @@ -1072,6 +1075,257 @@ def test_routes_with_connexion_app(): assert re.search(r"app.home\s+GET\s+\/", result) +@mock.patch("huggingface_hub.HfApi.preupload_lfs_files") +@mock.patch("huggingface_hub.CommitOperationAdd") +@mock.patch("huggingface_hub.HfApi.create_commit") +def test_upload(create_commit, CommitOperationAdd, preupload_lfs_files): + result = runner.invoke(annif.cli.cli, ["upload", "dummy-fi", "dummy-repo"]) + assert not result.exception + assert create_commit.call_count == 1 + assert CommitOperationAdd.call_count == 3 # projects, vocab, config + assert ( + mock.call( + path_or_fileobj=mock.ANY, # io.BufferedRandom object + path_in_repo="data/vocabs/dummy.zip", + ) + in CommitOperationAdd.call_args_list + ) + assert ( + mock.call( + path_or_fileobj=mock.ANY, # io.BufferedRandom object + path_in_repo="data/projects/dummy-fi.zip", + ) + in CommitOperationAdd.call_args_list + ) + assert ( + mock.call( + path_or_fileobj=mock.ANY, # io.BytesIO object + path_in_repo="dummy-fi.cfg", + ) + in CommitOperationAdd.call_args_list + ) + assert ( + mock.call( + repo_id="dummy-repo", + operations=mock.ANY, + commit_message="Upload project(s) dummy-fi with Annif", + token=None, + revision=None, + ) + in create_commit.call_args_list + ) + + +@mock.patch("huggingface_hub.HfApi.preupload_lfs_files") +@mock.patch("huggingface_hub.CommitOperationAdd") +@mock.patch("huggingface_hub.HfApi.create_commit") +def test_upload_many(create_commit, CommitOperationAdd, preupload_lfs_files): + result = runner.invoke(annif.cli.cli, ["upload", "dummy-*", "dummy-repo"]) + assert not result.exception + assert create_commit.call_count == 1 + assert CommitOperationAdd.call_count == 11 + + +def test_upload_nonexistent_repo(): + failed_result = runner.invoke(annif.cli.cli, ["upload", "dummy-fi", "nonexistent"]) + assert failed_result.exception + assert failed_result.exit_code != 0 + assert "Repository Not Found for url:" in failed_result.output + + +def hf_hub_download_mock_side_effect(filename, repo_id, token, revision): + return "tests/huggingface-cache/" + filename # Mocks the downloaded file paths + + +@mock.patch( + "huggingface_hub.list_repo_files", + return_value=[ # Mocks the filenames in repo + "projects/dummy-fi.zip", + "vocabs/dummy.zip", + "dummy-fi.cfg", + "projects/dummy-en.zip", + "vocabs/dummy.zip", + "dummy-en.cfg", + ], +) +@mock.patch( + "huggingface_hub.hf_hub_download", + side_effect=hf_hub_download_mock_side_effect, +) +@mock.patch("annif.hfh_util.copy_project_config") +def test_download_dummy_fi( + copy_project_config, hf_hub_download, list_repo_files, testdatadir +): + result = runner.invoke( + annif.cli.cli, + [ + "download", + "dummy-fi", + "mock-repo", + ], + ) + assert not result.exception + assert list_repo_files.called + assert hf_hub_download.called + assert hf_hub_download.call_args_list == [ + mock.call( + repo_id="mock-repo", + filename="projects/dummy-fi.zip", + token=None, + revision=None, + ), + mock.call( + repo_id="mock-repo", + filename="dummy-fi.cfg", + token=None, + revision=None, + ), + mock.call( + repo_id="mock-repo", + filename="vocabs/dummy.zip", + token=None, + revision=None, + ), + ] + dirpath = os.path.join(str(testdatadir), "projects", "dummy-fi") + fpath = os.path.join(str(dirpath), "file.txt") + assert os.path.exists(fpath) + assert copy_project_config.call_args_list == [ + mock.call("tests/huggingface-cache/dummy-fi.cfg", False) + ] + + +@mock.patch( + "huggingface_hub.list_repo_files", + return_value=[ # Mock filenames in repo + "projects/dummy-fi.zip", + "vocabs/dummy.zip", + "dummy-fi.cfg", + "projects/dummy-en.zip", + "vocabs/dummy.zip", + "dummy-en.cfg", + ], +) +@mock.patch( + "huggingface_hub.hf_hub_download", + side_effect=hf_hub_download_mock_side_effect, +) +@mock.patch("annif.hfh_util.copy_project_config") +def test_download_dummy_fi_and_en( + copy_project_config, hf_hub_download, list_repo_files, testdatadir +): + result = runner.invoke( + annif.cli.cli, + [ + "download", + "dummy-??", + "mock-repo", + ], + ) + assert not result.exception + assert list_repo_files.called + assert hf_hub_download.called + assert hf_hub_download.call_args_list == [ + mock.call( + repo_id="mock-repo", + filename="projects/dummy-fi.zip", + token=None, + revision=None, + ), + mock.call( + repo_id="mock-repo", + filename="dummy-fi.cfg", + token=None, + revision=None, + ), + mock.call( + repo_id="mock-repo", + filename="projects/dummy-en.zip", + token=None, + revision=None, + ), + mock.call( + repo_id="mock-repo", + filename="dummy-en.cfg", + token=None, + revision=None, + ), + mock.call( + repo_id="mock-repo", + filename="vocabs/dummy.zip", + token=None, + revision=None, + ), + ] + dirpath_fi = os.path.join(str(testdatadir), "projects", "dummy-fi") + fpath_fi = os.path.join(str(dirpath_fi), "file.txt") + assert os.path.exists(fpath_fi) + dirpath_en = os.path.join(str(testdatadir), "projects", "dummy-en") + fpath_en = os.path.join(str(dirpath_en), "file.txt") + assert os.path.exists(fpath_en) + assert copy_project_config.call_args_list == [ + mock.call("tests/huggingface-cache/dummy-fi.cfg", False), + mock.call("tests/huggingface-cache/dummy-en.cfg", False), + ] + + +@mock.patch( + "huggingface_hub.list_repo_files", + side_effect=HFValidationError, +) +@mock.patch( + "huggingface_hub.hf_hub_download", +) +def test_download_list_repo_files_failed( + hf_hub_download, + list_repo_files, +): + failed_result = runner.invoke( + annif.cli.cli, + [ + "download", + "dummy-fi", + "mock-repo", + ], + ) + assert failed_result.exception + assert failed_result.exit_code != 0 + assert "Error: Operation failed:" in failed_result.output + assert list_repo_files.called + assert not hf_hub_download.called + + +@mock.patch( + "huggingface_hub.list_repo_files", + return_value=[ # Mock filenames in repo + "projects/dummy-fi.zip", + "vocabs/dummy.zip", + "dummy-fi.cfg", + ], +) +@mock.patch( + "huggingface_hub.hf_hub_download", + side_effect=HFValidationError, +) +def test_download_hf_hub_download_failed( + hf_hub_download, + list_repo_files, +): + failed_result = runner.invoke( + annif.cli.cli, + [ + "download", + "dummy-fi", + "mock-repo", + ], + ) + assert failed_result.exception + assert failed_result.exit_code != 0 + assert "Error: Operation failed:" in failed_result.output + assert list_repo_files.called + assert hf_hub_download.called + + def test_completion_script_generation(): result = runner.invoke(annif.cli.cli, ["completion", "--bash"]) assert not result.exception diff --git a/tests/test_hfh_util.py b/tests/test_hfh_util.py new file mode 100644 index 000000000..ce3d6aac9 --- /dev/null +++ b/tests/test_hfh_util.py @@ -0,0 +1,103 @@ +"""Unit test module for Hugging Face Hub utilities.""" + +import io +import os.path +import zipfile +from datetime import datetime, timezone +from unittest import mock + +import annif.hfh_util + + +def test_archive_dir(testdatadir): + dirpath = os.path.join(str(testdatadir), "projects", "dummy-fi") + os.makedirs(dirpath, exist_ok=True) + open(os.path.join(str(dirpath), "foo.txt"), "a").close() + open(os.path.join(str(dirpath), "-train.txt"), "a").close() + + fobj = annif.hfh_util._archive_dir(dirpath) + assert isinstance(fobj, io.BufferedRandom) + + with zipfile.ZipFile(fobj, mode="r") as zfile: + archived_files = zfile.namelist() + assert len(archived_files) == 1 + assert os.path.split(archived_files[0])[1] == "foo.txt" + + +def test_get_project_config(app_project): + result = annif.hfh_util._get_project_config(app_project) + assert isinstance(result, io.BytesIO) + string_result = result.read().decode("UTF-8") + assert "[dummy-en]" in string_result + + +def test_unzip_archive_initial(testdatadir): + dirpath = os.path.join(str(testdatadir), "projects", "dummy-fi") + fpath = os.path.join(str(dirpath), "file.txt") + annif.hfh_util.unzip_archive( + os.path.join("tests", "huggingface-cache", "projects", "dummy-fi.zip"), + force=False, + ) + assert os.path.exists(fpath) + assert os.path.getsize(fpath) == 0 # Zero content from zip + ts = os.path.getmtime(fpath) + assert datetime.fromtimestamp(ts).astimezone(tz=timezone.utc) == datetime( + 1980, 1, 1, 0, 0 + ).astimezone(tz=timezone.utc) + + +def test_unzip_archive_no_overwrite(testdatadir): + dirpath = os.path.join(str(testdatadir), "projects", "dummy-fi") + fpath = os.path.join(str(dirpath), "file.txt") + os.makedirs(dirpath, exist_ok=True) + with open(fpath, "wt") as pf: + print("Existing content", file=pf) + + annif.hfh_util.unzip_archive( + os.path.join("tests", "huggingface-cache", "projects", "dummy-fi.zip"), + force=False, + ) + assert os.path.exists(fpath) + assert os.path.getsize(fpath) == 17 # Existing content + assert datetime.now().timestamp() - os.path.getmtime(fpath) < 1 + + +def test_unzip_archive_overwrite(testdatadir): + dirpath = os.path.join(str(testdatadir), "projects", "dummy-fi") + fpath = os.path.join(str(dirpath), "file.txt") + os.makedirs(dirpath, exist_ok=True) + with open(fpath, "wt") as pf: + print("Existing content", file=pf) + + annif.hfh_util.unzip_archive( + os.path.join("tests", "huggingface-cache", "projects", "dummy-fi.zip"), + force=True, + ) + assert os.path.exists(fpath) + assert os.path.getsize(fpath) == 0 # Zero content from zip + ts = os.path.getmtime(fpath) + assert datetime.fromtimestamp(ts).astimezone(tz=timezone.utc) == datetime( + 1980, 1, 1, 0, 0 + ).astimezone(tz=timezone.utc) + + +@mock.patch("os.path.exists", return_value=True) +@mock.patch("annif.hfh_util._compute_crc32", return_value=0) +@mock.patch("shutil.copy") +def test_copy_project_config_no_overwrite(copy, _compute_crc32, exists): + annif.hfh_util.copy_project_config( + os.path.join("tests", "huggingface-cache", "dummy-fi.cfg"), force=False + ) + assert not copy.called + + +@mock.patch("os.path.exists", return_value=True) +@mock.patch("shutil.copy") +def test_copy_project_config_overwrite(copy, exists): + annif.hfh_util.copy_project_config( + os.path.join("tests", "huggingface-cache", "dummy-fi.cfg"), force=True + ) + assert copy.called + assert copy.call_args == mock.call( + "tests/huggingface-cache/dummy-fi.cfg", "projects.d/dummy-fi.cfg" + )