diff --git a/model2vec/hf_utils.py b/model2vec/hf_utils.py index b9e853fa..53adbe5c 100644 --- a/model2vec/hf_utils.py +++ b/model2vec/hf_utils.py @@ -39,7 +39,14 @@ def save_pretrained( folder_path.mkdir(exist_ok=True, parents=True) save_file({"embeddings": embeddings}, folder_path / "model.safetensors") tokenizer.save(str(folder_path / "tokenizer.json")) - json.dump(config, open(folder_path / "config.json", "w")) + json.dump(config, open(folder_path / "config.json", "w"), indent=4) + + # Create modules.json + modules = [{"idx": 0, "name": "0", "path": ".", "type": "sentence_transformers.models.StaticEmbedding"}] + if config.get("normalize"): + # If normalize=True, add sentence_transformers.models.Normalize + modules.append({"idx": 1, "name": "1", "path": "1_Normalize", "type": "sentence_transformers.models.Normalize"}) + json.dump(modules, open(folder_path / "modules.json", "w"), indent=4) logger.info(f"Saved model to {folder_path}") @@ -75,7 +82,7 @@ def _create_model_card( base_model=base_model_name, license=license, language=language, - tags=["embeddings", "static-embeddings"], + tags=["embeddings", "static-embeddings", "sentence-transformers"], library_name="model2vec", **kwargs, ) diff --git a/tests/test_model.py b/tests/test_model.py index 0b2bc1d8..810fed84 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -161,6 +161,7 @@ def test_save_pretrained( assert (save_path / "model.safetensors").exists() assert (save_path / "tokenizer.json").exists() assert (save_path / "config.json").exists() + assert (save_path / "modules.json").exists() def test_load_pretrained( diff --git a/uv.lock b/uv.lock index ee17e1ec..516d7685 100644 --- a/uv.lock +++ b/uv.lock @@ -160,7 +160,7 @@ name = "click" version = "8.1.8" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "colorama", marker = "platform_system == 'Windows'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/b9/2e/0090cbf739cee7d23781ad4b89a9894a41538e4fcf4c31dcdd705b78eb8b/click-8.1.8.tar.gz", hash = "sha256:ed53c9d8990d83c2a27deae68e4ee337473f6330c040a31d4225c9574d16096a", size = 226593 } wheels = [ @@ -536,7 +536,7 @@ wheels = [ [[package]] name = "model2vec" -version = "0.3.5" +version = "0.3.6" source = { editable = "." } dependencies = [ { name = "jinja2" }, @@ -1627,19 +1627,19 @@ dependencies = [ { name = "jinja2" }, { name = "networkx", version = "3.2.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" }, { name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" }, - { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, + { name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, + { name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, + { name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, + { name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, + { name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, + { name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, + { name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, + { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, + { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, + { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, { name = "sympy" }, - { name = "triton", marker = "python_full_version < '3.13' and platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "triton", marker = "python_full_version < '3.13' and platform_machine == 'x86_64' and platform_system == 'Linux'" }, { name = "typing-extensions" }, ] wheels = [ @@ -1666,7 +1666,7 @@ name = "tqdm" version = "4.67.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "colorama", marker = "platform_system == 'Windows'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/a8/4b/29b4ef32e036bb34e4ab51796dd745cdba7ed47ad142a9f4a1eb8e0c744d/tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2", size = 169737 } wheels = [