From ec2d49109819fa31c95d023761e5d0bf9a20c9f1 Mon Sep 17 00:00:00 2001 From: Pringled Date: Tue, 21 Jan 2025 20:45:47 +0100 Subject: [PATCH 1/7] Updated save_pretrained --- model2vec/hf_utils.py | 11 ++++++++++- tests/test_model.py | 1 + uv.lock | 30 +++++++++++++++--------------- 3 files changed, 26 insertions(+), 16 deletions(-) diff --git a/model2vec/hf_utils.py b/model2vec/hf_utils.py index b9e853fa..3a988f2e 100644 --- a/model2vec/hf_utils.py +++ b/model2vec/hf_utils.py @@ -41,6 +41,15 @@ def save_pretrained( tokenizer.save(str(folder_path / "tokenizer.json")) json.dump(config, open(folder_path / "config.json", "w")) + # Create modules.json + modules = [{"idx": 0, "name": "0", "path": ".", "type": "sentence_transformers.models.StaticEmbedding"}] + if config.get("normalize") is True: + # If normalize=True, add the second entry for sentence_transformers.models.Normalize + modules.append({"idx": 1, "name": "1", "path": "1_Normalize", "type": "sentence_transformers.models.Normalize"}) + + with open(folder_path / "modules.json", "w", encoding="utf-8") as f: + json.dump(modules, f, indent=4) + logger.info(f"Saved model to {folder_path}") # Optionally create the model card @@ -75,7 +84,7 @@ def _create_model_card( base_model=base_model_name, license=license, language=language, - tags=["embeddings", "static-embeddings"], + tags=["embeddings", "static-embeddings", "sentence-transformers"], library_name="model2vec", **kwargs, ) diff --git a/tests/test_model.py b/tests/test_model.py index 0b2bc1d8..810fed84 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -161,6 +161,7 @@ def test_save_pretrained( assert (save_path / "model.safetensors").exists() assert (save_path / "tokenizer.json").exists() assert (save_path / "config.json").exists() + assert (save_path / "modules.json").exists() def test_load_pretrained( diff --git a/uv.lock b/uv.lock index ee17e1ec..516d7685 100644 --- a/uv.lock +++ b/uv.lock @@ -160,7 +160,7 @@ name = "click" version = "8.1.8" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "colorama", marker = "platform_system == 'Windows'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/b9/2e/0090cbf739cee7d23781ad4b89a9894a41538e4fcf4c31dcdd705b78eb8b/click-8.1.8.tar.gz", hash = "sha256:ed53c9d8990d83c2a27deae68e4ee337473f6330c040a31d4225c9574d16096a", size = 226593 } wheels = [ @@ -536,7 +536,7 @@ wheels = [ [[package]] name = "model2vec" -version = "0.3.5" +version = "0.3.6" source = { editable = "." } dependencies = [ { name = "jinja2" }, @@ -1627,19 +1627,19 @@ dependencies = [ { name = "jinja2" }, { name = "networkx", version = "3.2.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" }, { name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" }, - { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, + { name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, + { name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, + { name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, + { name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, + { name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, + { name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, + { name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, + { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, + { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, + { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, { name = "sympy" }, - { name = "triton", marker = "python_full_version < '3.13' and platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "triton", marker = "python_full_version < '3.13' and platform_machine == 'x86_64' and platform_system == 'Linux'" }, { name = "typing-extensions" }, ] wheels = [ @@ -1666,7 +1666,7 @@ name = "tqdm" version = "4.67.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "colorama", marker = "platform_system == 'Windows'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/a8/4b/29b4ef32e036bb34e4ab51796dd745cdba7ed47ad142a9f4a1eb8e0c744d/tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2", size = 169737 } wheels = [ From a69a88349822d9110a9dc0b840131cc1328cc9fd Mon Sep 17 00:00:00 2001 From: Pringled Date: Tue, 21 Jan 2025 20:46:34 +0100 Subject: [PATCH 2/7] Updated save_pretrained --- model2vec/hf_utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/model2vec/hf_utils.py b/model2vec/hf_utils.py index 3a988f2e..362df637 100644 --- a/model2vec/hf_utils.py +++ b/model2vec/hf_utils.py @@ -47,8 +47,7 @@ def save_pretrained( # If normalize=True, add the second entry for sentence_transformers.models.Normalize modules.append({"idx": 1, "name": "1", "path": "1_Normalize", "type": "sentence_transformers.models.Normalize"}) - with open(folder_path / "modules.json", "w", encoding="utf-8") as f: - json.dump(modules, f, indent=4) + json.dump(modules, open(folder_path / "modules.json", "w")) logger.info(f"Saved model to {folder_path}") From b37624eaa876c200ad6eefac90a8d302dd4ef6fa Mon Sep 17 00:00:00 2001 From: Pringled Date: Tue, 21 Jan 2025 20:46:55 +0100 Subject: [PATCH 3/7] Updated save_pretrained --- model2vec/hf_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/model2vec/hf_utils.py b/model2vec/hf_utils.py index 362df637..7c20d9e1 100644 --- a/model2vec/hf_utils.py +++ b/model2vec/hf_utils.py @@ -46,7 +46,6 @@ def save_pretrained( if config.get("normalize") is True: # If normalize=True, add the second entry for sentence_transformers.models.Normalize modules.append({"idx": 1, "name": "1", "path": "1_Normalize", "type": "sentence_transformers.models.Normalize"}) - json.dump(modules, open(folder_path / "modules.json", "w")) logger.info(f"Saved model to {folder_path}") From 8ee1bb9127de7518d5d749c383eca36ed466b575 Mon Sep 17 00:00:00 2001 From: Pringled Date: Tue, 21 Jan 2025 20:47:46 +0100 Subject: [PATCH 4/7] Updated save_pretrained --- model2vec/hf_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model2vec/hf_utils.py b/model2vec/hf_utils.py index 7c20d9e1..a1fb7768 100644 --- a/model2vec/hf_utils.py +++ b/model2vec/hf_utils.py @@ -46,7 +46,7 @@ def save_pretrained( if config.get("normalize") is True: # If normalize=True, add the second entry for sentence_transformers.models.Normalize modules.append({"idx": 1, "name": "1", "path": "1_Normalize", "type": "sentence_transformers.models.Normalize"}) - json.dump(modules, open(folder_path / "modules.json", "w")) + json.dump(modules, open(folder_path / "modules.json", "w"), indent=4) logger.info(f"Saved model to {folder_path}") From e197ecb463583d8b5ca289ed034b86f51f8ca39d Mon Sep 17 00:00:00 2001 From: Pringled Date: Tue, 21 Jan 2025 20:47:59 +0100 Subject: [PATCH 5/7] Updated save_pretrained --- model2vec/hf_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model2vec/hf_utils.py b/model2vec/hf_utils.py index a1fb7768..d0bfdd35 100644 --- a/model2vec/hf_utils.py +++ b/model2vec/hf_utils.py @@ -44,7 +44,7 @@ def save_pretrained( # Create modules.json modules = [{"idx": 0, "name": "0", "path": ".", "type": "sentence_transformers.models.StaticEmbedding"}] if config.get("normalize") is True: - # If normalize=True, add the second entry for sentence_transformers.models.Normalize + # If normalize=True, add sentence_transformers.models.Normalize modules.append({"idx": 1, "name": "1", "path": "1_Normalize", "type": "sentence_transformers.models.Normalize"}) json.dump(modules, open(folder_path / "modules.json", "w"), indent=4) From fbf78d9f33566b6fe29727cd184bdc7d63c2635f Mon Sep 17 00:00:00 2001 From: Pringled Date: Tue, 21 Jan 2025 20:48:46 +0100 Subject: [PATCH 6/7] Updated save_pretrained --- model2vec/hf_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model2vec/hf_utils.py b/model2vec/hf_utils.py index d0bfdd35..a38bdc31 100644 --- a/model2vec/hf_utils.py +++ b/model2vec/hf_utils.py @@ -39,7 +39,7 @@ def save_pretrained( folder_path.mkdir(exist_ok=True, parents=True) save_file({"embeddings": embeddings}, folder_path / "model.safetensors") tokenizer.save(str(folder_path / "tokenizer.json")) - json.dump(config, open(folder_path / "config.json", "w")) + json.dump(config, open(folder_path / "config.json", "w"), indent=4) # Create modules.json modules = [{"idx": 0, "name": "0", "path": ".", "type": "sentence_transformers.models.StaticEmbedding"}] From 44541c6e245eac3e249380997da245dce1d889b3 Mon Sep 17 00:00:00 2001 From: Pringled Date: Tue, 21 Jan 2025 20:51:57 +0100 Subject: [PATCH 7/7] Resolved comment --- model2vec/hf_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model2vec/hf_utils.py b/model2vec/hf_utils.py index a38bdc31..53adbe5c 100644 --- a/model2vec/hf_utils.py +++ b/model2vec/hf_utils.py @@ -43,7 +43,7 @@ def save_pretrained( # Create modules.json modules = [{"idx": 0, "name": "0", "path": ".", "type": "sentence_transformers.models.StaticEmbedding"}] - if config.get("normalize") is True: + if config.get("normalize"): # If normalize=True, add sentence_transformers.models.Normalize modules.append({"idx": 1, "name": "1", "path": "1_Normalize", "type": "sentence_transformers.models.Normalize"}) json.dump(modules, open(folder_path / "modules.json", "w"), indent=4)