diff --git a/skllm/models/_base/classifier.py b/skllm/models/_base/classifier.py index 757457c..46138c7 100644 --- a/skllm/models/_base/classifier.py +++ b/skllm/models/_base/classifier.py @@ -222,13 +222,14 @@ def predict(self, X: Union[np.ndarray, pd.Series, List[str]]): Returns ------- - List[str] + np.ndarray + The predicted classes as a numpy array. """ X = _to_numpy(X) predictions = [] for i in tqdm(range(len(X))): predictions.append(self._predict_single(X[i])) - return predictions + return np.array(predictions) def _get_unique_targets(self, y: Any): labels = self._extract_labels(y) @@ -351,6 +352,7 @@ def __init__( memory_index: Optional[IndexConstructor] = None, vectorizer: _BaseVectorizer = None, prompt_template: Optional[str] = None, + metric="euclidean", ): super().__init__( model=model, @@ -360,6 +362,7 @@ def __init__( self.vectorizer = vectorizer self.memory_index = memory_index self.n_examples = n_examples + self.metric = metric if isinstance(self, MultiLabelMixin): raise TypeError("Multi-label classification is not supported") @@ -402,7 +405,7 @@ def fit( index = self.memory_index() index.dim = embeddings.shape[1] else: - index = SklearnMemoryIndex(embeddings.shape[1]) + index = SklearnMemoryIndex(embeddings.shape[1], metric=self.metric) for embedding in embeddings: index.add(embedding) index.build() diff --git a/skllm/models/gpt/classification/few_shot.py b/skllm/models/gpt/classification/few_shot.py index b3a9cb0..a3a71c6 100644 --- a/skllm/models/gpt/classification/few_shot.py +++ b/skllm/models/gpt/classification/few_shot.py @@ -100,6 +100,7 @@ def __init__( n_examples: int = 3, memory_index: Optional[IndexConstructor] = None, vectorizer: Optional[BaseVectorizer] = None, + metric: Optional[str] = "euclidean", **kwargs, ): """ @@ -124,6 +125,8 @@ def __init__( custom memory index, for details check `skllm.memory` submodule, by default None vectorizer : Optional[BaseVectorizer], optional scikit-llm vectorizer; if None, `GPTVectorizer` is used, by default None + metric : Optional[str], optional + metric used for similarity search, by default "euclidean" """ if vectorizer is None: vectorizer = GPTVectorizer(model="text-embedding-ada-002") @@ -134,5 +137,6 @@ def __init__( n_examples=n_examples, memory_index=memory_index, vectorizer=vectorizer, + metric=metric, ) self._set_keys(key, org)