Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 8 additions & 14 deletions demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -34,20 +34,9 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([1, 1, 1, 1, 1, 1, 0, 1, 0])"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"queries = [\n",
" \"aapl vs nvda\",\n",
Expand All @@ -59,9 +48,14 @@
" \"algae\",\n",
" \"san francisco game this week\",\n",
" \"what college did michael phelps go to\",\n",
" \"election 2024\",\n",
" \"election results\",\n",
" \"presidential election polls 2024\",\n",
"]\n",
"\n",
"query_filter.predict(queries)"
"preds = query_filter.predict(queries)\n",
"for query, pred in zip(queries, preds):\n",
" print(f\"{query} -> {pred}\")\n"
]
},
{
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ numpy = "1.26.4"
python = ">=3.10.14,<4.0"
scikit-learn = "1.4.1.post1"
sentence-transformers = {version = "3.0.1", optional = true}
spacy = {version = "3.7.5"}

[tool.poetry.extras]
embeddings = ["sentence-transformers"]
Expand Down
82 changes: 49 additions & 33 deletions src/tako_query_filter/filter.py
Original file line number Diff line number Diff line change
@@ -1,80 +1,102 @@
import json
import logging
from pathlib import Path
import re
import joblib
import numpy as np
from typing import Iterable, List, Optional, Set
from sklearn.linear_model import LogisticRegressionCV
from huggingface_hub import hf_hub_download
import spacy
from spacy.language import Language


class TakoQueryFilter:
def __init__(
self,
chart_model: LogisticRegressionCV,
topic_model: LogisticRegressionCV,
spacy_model: Language,
keywords: Set[str],
):
self.chart_model = chart_model
self.topic_model = topic_model
self.spacy_model = spacy_model
self.keywords = keywords
self.keyword_match_score = 0.9
self.model = None
self.embeddings_model = None

@classmethod
def load_from_hf(
cls,
scikit_path: str = "TakoData/ScikitModels",
revision: Optional[str] = None,
topic_revision: Optional[str] = "a8a257f706ec28a63eeb40b088b8e05b30670971",
spacy_revision: Optional[str] = "156303cfba1f9ac5ef7cfd35fe5dc8c9238a459d",
force_download: bool = False,
):
chart_model = joblib.load(
hf_hub_download(
repo_id=scikit_path,
filename="models/chart_model.pkl",
revision=revision,
force_download=force_download,
)
)
topic_model = joblib.load(
hf_hub_download(
repo_id=scikit_path,
filename="models/topic_model.pkl",
revision=revision,
revision=topic_revision,
force_download=force_download,
)
)
spacy_model_path = hf_hub_download(
repo_id="TakoData/ner-model-best",
filename="config.cfg",
revision=spacy_revision,
force_download=force_download,
)
spacy_model_dir = str(Path(spacy_model_path).parent)
spacy_model = spacy.load(spacy_model_dir)
keywords_file = hf_hub_download(
repo_id=scikit_path,
filename="models/keywords.json",
revision=revision,
revision=topic_revision,
force_download=force_download,
)
with open(keywords_file, "r") as f:
keywords = set(json.load(f))

return cls(chart_model, topic_model, keywords)
return cls(topic_model, spacy_model, keywords)

def create_embeddings(self, queries: Iterable[str]) -> np.ndarray:
if not self.model:
if not self.embeddings_model:
from sentence_transformers import SentenceTransformer

self.model = SentenceTransformer(
self.embeddings_model = SentenceTransformer(
"sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
)

embeddings = self.model.encode(list(queries), normalize_embeddings=True)
embeddings = self.embeddings_model.encode(
list(queries), normalize_embeddings=True
)
return embeddings

def extract_spacy_features(self, query: str) -> np.ndarray:
    """Build a fixed-size feature vector from the spaCy spans found in *query*.

    Runs ``self.spacy_model`` on the query, reads the span group stored under
    the ``"sc"`` key together with its per-span ``"scores"`` attribute, and
    accumulates each span vector weighted by its score.

    Args:
        query: Raw query text to analyse.

    Returns:
        A vector of shape ``(256,)``. It is all zeros when the span group is
        empty or when every score is falsy; otherwise it is the score-weighted
        sum of span vectors, mean-centred and divided by the norm of the
        un-centred sum.
    """
    doc = self.spacy_model(query)
    span_group = doc.spans["sc"]
    span_scores = span_group.attrs["scores"]

    # 256 is hard-coded; presumably the width of the span vectors produced by
    # the loaded spaCy pipeline -- TODO confirm against the model config.
    features = np.zeros((256,))
    for span, score in zip(span_group, span_scores):
        if not score:
            # Spans with a zero/empty score contribute nothing.
            continue
        features += np.array(span.vector) * score

    if len(span_group) == 0:
        return features

    # NOTE(review): the norm is taken *before* mean-centring, so the result is
    # not unit-length. Kept as-is because the downstream classifier presumably
    # was trained on features computed exactly this way.
    magnitude = np.linalg.norm(features)
    if magnitude > 0:
        features = (features - np.mean(features)) / magnitude

    return features

def predict(
self,
queries: List[str],
embeddings: np.ndarray = np.array([]),
chart_weight=0.5,
topic_weight=0.5,
):
# Use predict_proba to get class predictions
probs = self.predict_proba(queries, embeddings, chart_weight, topic_weight)
probs = self.predict_proba(queries, embeddings)
# Convert probabilities to binary predictions
predictions = (probs > 0.5).astype(int)
return predictions
Expand All @@ -83,8 +105,6 @@ def predict_proba(
self,
queries: List[str],
embeddings: np.ndarray = np.array([]),
chart_weight=0.5,
topic_weight=0.5,
) -> np.ndarray:
if len(embeddings) != len(queries):
if len(embeddings) > 0:
Expand All @@ -93,17 +113,13 @@ def predict_proba(
)
embeddings = self.create_embeddings(queries)

# Get probabilities from both models
chart_probs = self.chart_model.predict_proba(embeddings)
topic_probs = self.topic_model.predict_proba(embeddings)

# Get probabilities of the positive class (index 1) from both models
chart_probs_positive = chart_probs[:, 1]
topic_probs_positive = topic_probs[:, 1]
spacy_vectors = [self.extract_spacy_features(query) for query in queries]
# Combine embeddings with spacy vectors
X = np.hstack([embeddings, spacy_vectors])

positive_probs = (
chart_weight * chart_probs_positive + topic_weight * topic_probs_positive
)
# Get probabilities from the topic model
probs = self.topic_model.predict_proba(X)
positive_probs = probs[:, 1]

for i, query in enumerate(queries):
split_query = self._split_query(query)
Expand Down