Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions skllm/models/_base/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,13 +222,14 @@ def predict(self, X: Union[np.ndarray, pd.Series, List[str]]):

Returns
-------
- List[str]
+ np.ndarray
+ The predicted classes as a numpy array.
"""
X = _to_numpy(X)
predictions = []
for i in tqdm(range(len(X))):
predictions.append(self._predict_single(X[i]))
- return predictions
+ return np.array(predictions)

def _get_unique_targets(self, y: Any):
labels = self._extract_labels(y)
Expand Down Expand Up @@ -351,6 +352,7 @@ def __init__(
memory_index: Optional[IndexConstructor] = None,
vectorizer: _BaseVectorizer = None,
prompt_template: Optional[str] = None,
+ metric="euclidean",
):
super().__init__(
model=model,
Expand All @@ -360,6 +362,7 @@ def __init__(
self.vectorizer = vectorizer
self.memory_index = memory_index
self.n_examples = n_examples
+ self.metric = metric
if isinstance(self, MultiLabelMixin):
raise TypeError("Multi-label classification is not supported")

Expand Down Expand Up @@ -402,7 +405,7 @@ def fit(
index = self.memory_index()
index.dim = embeddings.shape[1]
else:
- index = SklearnMemoryIndex(embeddings.shape[1])
+ index = SklearnMemoryIndex(embeddings.shape[1], metric=self.metric)
for embedding in embeddings:
index.add(embedding)
index.build()
Expand Down
4 changes: 4 additions & 0 deletions skllm/models/gpt/classification/few_shot.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def __init__(
n_examples: int = 3,
memory_index: Optional[IndexConstructor] = None,
vectorizer: Optional[BaseVectorizer] = None,
+ metric: Optional[str] = "euclidean",
**kwargs,
):
"""
Expand All @@ -124,6 +125,8 @@ def __init__(
custom memory index, for details check `skllm.memory` submodule, by default None
vectorizer : Optional[BaseVectorizer], optional
scikit-llm vectorizer; if None, `GPTVectorizer` is used, by default None
+ metric : Optional[str], optional
+ metric used for similarity search, by default "euclidean"
"""
if vectorizer is None:
vectorizer = GPTVectorizer(model="text-embedding-ada-002")
Expand All @@ -134,5 +137,6 @@ def __init__(
n_examples=n_examples,
memory_index=memory_index,
vectorizer=vectorizer,
+ metric=metric,
)
self._set_keys(key, org)