Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,13 @@
"source": [
"from tako_query_filter.filter import TakoQueryFilter\n",
"\n",
"query_filter = TakoQueryFilter.load_from_hf(force_download=True)"
"query_filter = TakoQueryFilter.load_from_hf(force_download=True)\n",
"# Alternatively, you can load from local paths\n",
"# query_filter = TakoQueryFilter.load_from_local(\n",
"# topic_model_path=\"local-scikit-model-path/topic_model.pkl\",\n",
"# spacy_model_path=\"local-spacy-model-path\",\n",
"# keywords_path=\"local-scikit-model-path/topic_model.pkl\"\n",
"# )\n"
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ description = "Combines models to predict which queries Tako's API should handle
name = "tako-query-filter"
packages = [{include = "tako_query_filter", from = "src"}]
readme = "README.md"
version = "0.1.2"
version = "0.1.3"

[tool.poetry.dependencies]
huggingface-hub = "0.26.1"
Expand Down
25 changes: 25 additions & 0 deletions src/tako_query_filter/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,31 @@ def load_from_hf(

return cls(topic_model, spacy_model, keywords)

@classmethod
def load_from_local(
    cls,
    topic_model_path: str,
    spacy_model_path: str,
    keywords_path: str,
) -> "TakoQueryFilter":
    """Load TakoQueryFilter from local file paths.

    Args:
        topic_model_path: Path to the scikit-learn topic model pickle file.
        spacy_model_path: Path to the spacy model directory.
        keywords_path: Path to the whitelist keywords JSON file. The file
            is read as UTF-8 and its parsed contents are converted to a
            set.

    Returns:
        TakoQueryFilter: Initialized filter with models loaded from local
        paths.
    """
    # SECURITY: joblib.load unpickles the file, which can execute arbitrary
    # code. Only pass paths to model files that come from a trusted source.
    topic_model = joblib.load(topic_model_path)
    spacy_model = spacy.load(spacy_model_path)

    # Pin the encoding: JSON is UTF-8 by specification, and relying on the
    # platform default can mis-decode non-ASCII keywords on some systems.
    with open(keywords_path, encoding="utf-8") as f:
        keywords = set(json.load(f))

    return cls(topic_model, spacy_model, keywords)

def create_embeddings(self, queries: Iterable[str]) -> np.ndarray:
if not self.embeddings_model:
from sentence_transformers import SentenceTransformer
Expand Down