From b47321bb57a15f2c9e97a0a9965d2d5f3a19d8d8 Mon Sep 17 00:00:00 2001 From: noahjax Date: Tue, 26 Nov 2024 09:56:13 -0800 Subject: [PATCH 1/2] add the ability to load from a path --- demo.ipynb | 43 +++++++++++++++++++++++++++++---- src/tako_query_filter/filter.py | 25 +++++++++++++++++++ 2 files changed, 63 insertions(+), 5 deletions(-) diff --git a/demo.ipynb b/demo.ipynb index 27b085c..dded4c8 100644 --- a/demo.ipynb +++ b/demo.ipynb @@ -23,20 +23,53 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/noahjax/Work/QueryFilter/.venv/lib/python3.10/site-packages/spacy/util.py:910: UserWarning: [W095] Model 'en_pipeline' (0.0.0) was trained with spaCy v3.8.2 and may not be 100% compatible with the current version (3.7.5). If you see errors or degraded performance, download a newer compatible model or retrain your custom model with the current spaCy version. For more details and available updates, run: python -m spacy validate\n", + " warnings.warn(warn_msg)\n" + ] + } + ], "source": [ "from tako_query_filter.filter import TakoQueryFilter\n", "\n", - "query_filter = TakoQueryFilter.load_from_hf(force_download=True)" + "query_filter = TakoQueryFilter.load_from_hf(force_download=True)\n", + "# Alternatively, you can load from local paths\n", + "# query_filter = TakoQueryFilter.load_from_local(\n", + "# topic_model_path=\"local-scikit-model-path/topic_model.pkl\",\n", + "# spacy_model_path=\"local-spacy-model-path\",\n", + "# keywords_path=\"local-keywords-path/keywords.json\"\n", + "# )\n" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "aapl vs nvda -> 1\n", + "aapl vs nvidia -> 1\n", + "aapl -> 1\n", + "nvda -> 1\n", + "nvidia -> 1\n", + "trump vs. 
harris -> 1\n", + "algae -> 0\n", + "san francisco game this week -> 1\n", + "what college did michael phelps go toelection 2024 -> 1\n", + "election results -> 1\n", + "presidential election polls 2024 -> 1\n" + ] + } + ], "source": [ "queries = [\n", " \"aapl vs nvda\",\n", diff --git a/src/tako_query_filter/filter.py b/src/tako_query_filter/filter.py index e23a1b4..f0f45bc 100644 --- a/src/tako_query_filter/filter.py +++ b/src/tako_query_filter/filter.py @@ -57,6 +57,31 @@ def load_from_hf( return cls(topic_model, spacy_model, keywords) + @classmethod + def load_from_local( + cls, + topic_model_path: str, + spacy_model_path: str, + keywords_path: str, + ): + """Load TakoQueryFilter from local file paths. + + Args: + topic_model_path: Path to the scikit-learn topic model pickle file + spacy_model_path: Path to the spacy model directory + keywords_path: Path to the whitelist keywords JSON file + + Returns: + TakoQueryFilter: Initialized filter with models loaded from local paths + """ + topic_model = joblib.load(topic_model_path) + spacy_model = spacy.load(spacy_model_path) + + with open(keywords_path, "r") as f: + keywords = set(json.load(f)) + + return cls(topic_model, spacy_model, keywords) + def create_embeddings(self, queries: Iterable[str]) -> np.ndarray: if not self.embeddings_model: from sentence_transformers import SentenceTransformer From 81953693a6b790a6806fea45852c0da4bae643a8 Mon Sep 17 00:00:00 2001 From: noahjax Date: Tue, 26 Nov 2024 10:05:08 -0800 Subject: [PATCH 2/2] bump version --- demo.ipynb | 35 ++++------------------------------- pyproject.toml | 2 +- 2 files changed, 5 insertions(+), 32 deletions(-) diff --git a/demo.ipynb b/demo.ipynb index dded4c8..2c230f5 100644 --- a/demo.ipynb +++ b/demo.ipynb @@ -23,18 +23,9 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - 
"/Users/noahjax/Work/QueryFilter/.venv/lib/python3.10/site-packages/spacy/util.py:910: UserWarning: [W095] Model 'en_pipeline' (0.0.0) was trained with spaCy v3.8.2 and may not be 100% compatible with the current version (3.7.5). If you see errors or degraded performance, download a newer compatible model or retrain your custom model with the current spaCy version. For more details and available updates, run: python -m spacy validate\n", - " warnings.warn(warn_msg)\n" - ] - } - ], + "outputs": [], "source": [ "from tako_query_filter.filter import TakoQueryFilter\n", "\n", @@ -49,27 +40,9 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "aapl vs nvda -> 1\n", - "aapl vs nvidia -> 1\n", - "aapl -> 1\n", - "nvda -> 1\n", - "nvidia -> 1\n", - "trump vs. harris -> 1\n", - "algae -> 0\n", - "san francisco game this week -> 1\n", - "what college did michael phelps go toelection 2024 -> 1\n", - "election results -> 1\n", - "presidential election polls 2024 -> 1\n" - ] - } - ], + "outputs": [], "source": [ "queries = [\n", " \"aapl vs nvda\",\n", diff --git a/pyproject.toml b/pyproject.toml index fb28132..de52fa8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ description = "Combines models to predict which queries Tako's API should handle name = "tako-query-filter" packages = [{include = "tako_query_filter", from = "src"}] readme = "README.md" -version = "0.1.2" +version = "0.1.3" [tool.poetry.dependencies] huggingface-hub = "0.26.1"