diff --git a/demo.ipynb b/demo.ipynb
index 27b085c..2c230f5 100644
--- a/demo.ipynb
+++ b/demo.ipynb
@@ -29,7 +29,13 @@
    "source": [
     "from tako_query_filter.filter import TakoQueryFilter\n",
     "\n",
-    "query_filter = TakoQueryFilter.load_from_hf(force_download=True)"
+    "query_filter = TakoQueryFilter.load_from_hf(force_download=True)\n",
+    "# Alternatively, you can load from local paths\n",
+    "# query_filter = TakoQueryFilter.load_from_local(\n",
+    "# topic_model_path=\"local-scikit-model-path/topic_model.pkl\",\n",
+    "# spacy_model_path=\"local-spacy-model-path\",\n",
+    "# keywords_path=\"local-keywords-path/keywords.json\"\n",
+    "# )\n"
    ]
   },
   {
diff --git a/pyproject.toml b/pyproject.toml
index fb28132..de52fa8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ description = "Combines models to predict which queries Tako's API should handle
 name = "tako-query-filter"
 packages = [{include = "tako_query_filter", from = "src"}]
 readme = "README.md"
-version = "0.1.2"
+version = "0.1.3"
 
 [tool.poetry.dependencies]
 huggingface-hub = "0.26.1"
diff --git a/src/tako_query_filter/filter.py b/src/tako_query_filter/filter.py
index e23a1b4..f0f45bc 100644
--- a/src/tako_query_filter/filter.py
+++ b/src/tako_query_filter/filter.py
@@ -57,6 +57,34 @@ def load_from_hf(
 
         return cls(topic_model, spacy_model, keywords)
 
+    @classmethod
+    def load_from_local(
+        cls,
+        topic_model_path: str,
+        spacy_model_path: str,
+        keywords_path: str,
+    ) -> "TakoQueryFilter":
+        """Load TakoQueryFilter from local file paths.
+
+        Args:
+            topic_model_path: Path to the scikit-learn topic model pickle file
+            spacy_model_path: Path to the spacy model directory
+            keywords_path: Path to the whitelist keywords JSON file
+
+        Returns:
+            TakoQueryFilter: Initialized filter with models loaded from local paths
+        """
+        # NOTE: joblib.load unpickles the file, which can execute arbitrary
+        # code; only pass paths to model files that come from a trusted source.
+        topic_model = joblib.load(topic_model_path)
+        spacy_model = spacy.load(spacy_model_path)
+
+        # JSON is UTF-8; do not depend on the platform's default encoding.
+        with open(keywords_path, "r", encoding="utf-8") as f:
+            keywords = set(json.load(f))
+
+        return cls(topic_model, spacy_model, keywords)
+
     def create_embeddings(self, queries: Iterable[str]) -> np.ndarray:
         if not self.embeddings_model:
             from sentence_transformers import SentenceTransformer