diff --git a/skllm/llm/gpt/mixin.py b/skllm/llm/gpt/mixin.py
index 5949935..704ee28 100644
--- a/skllm/llm/gpt/mixin.py
+++ b/skllm/llm/gpt/mixin.py
@@ -66,8 +66,10 @@ def _set_keys(self, key: Optional[str] = None, org: Optional[str] = None) -> Non
         """
         Set the OpenAI key and organization.
         """
-        self.openai_key = key
-        self.openai_org = org
+
+        self.key = key
+        self.org = org
+
 
     def _get_openai_key(self) -> str:
         """
@@ -75,9 +77,9 @@ def _get_openai_key(self) -> str:
 
         Returns
         -------
-        openai_key: str
+        key: str
         """
-        key = self.openai_key
+        key = self.key
         if key is None:
             key = _Config.get_openai_key()
         if key is None:
@@ -90,14 +92,14 @@ def _get_openai_org(self) -> str:
 
         Returns
         -------
-        openai_org: str
+        org: str
         """
-        key = self.openai_org
-        if key is None:
-            key = _Config.get_openai_org()
-        if key is None:
+        org = self.org
+        if org is None:
+            org = _Config.get_openai_org()
+        if org is None:
             raise RuntimeError("OpenAI organization was not found")
-        return key
+        return org
 
 
 class GPTTextCompletionMixin(GPTMixin, BaseTextCompletionMixin):
@@ -262,4 +264,4 @@ def _tune(self, X, y):
         self.openai_model = job.fine_tuned_model
         self.model = self.openai_model  # openai_model is probably not required anymore
         delete_file(client, job.training_file)
-        print(f"Finished training.")
+        print(f"Finished training.")
\ No newline at end of file
diff --git a/skllm/models/_base/classifier.py b/skllm/models/_base/classifier.py
index 46138c7..03e0688 100644
--- a/skllm/models/_base/classifier.py
+++ b/skllm/models/_base/classifier.py
@@ -4,9 +4,11 @@
     BaseEstimator as _SklBaseEstimator,
     ClassifierMixin as _SklClassifierMixin,
 )
+import warnings
 import numpy as np
 import pandas as pd
 from tqdm import tqdm
+from concurrent.futures import ThreadPoolExecutor
 import random
 from collections import Counter
 from skllm.llm.base import (
@@ -211,7 +213,7 @@ def fit(
         self.classes_, self.probabilities_ = self._get_unique_targets(y)
         return self
 
-    def predict(self, X: Union[np.ndarray, pd.Series, List[str]]):
+    def predict(self, X: Union[np.ndarray, pd.Series, List[str]], num_workers: int = 1):
         """
         Predicts the class of each input.
 
@@ -219,6 +221,9 @@ def predict(self, X: Union[np.ndarray, pd.Series, List[str]]):
         ----------
         X : Union[np.ndarray, pd.Series, List[str]]
             The input data to predict the class of.
+
+        num_workers : int
+            number of workers to use for multithreaded prediction, default 1
 
         Returns
         -------
@@ -226,9 +231,12 @@ def predict(self, X: Union[np.ndarray, pd.Series, List[str]]):
             The predicted classes as a numpy array.
         """
         X = _to_numpy(X)
-        predictions = []
-        for i in tqdm(range(len(X))):
-            predictions.append(self._predict_single(X[i]))
+
+        if num_workers > 1:
+            warnings.warn("Passing num_workers to predict is temporary and will be removed in the future.")
+        with ThreadPoolExecutor(max_workers=num_workers) as executor:
+            predictions = list(tqdm(executor.map(self._predict_single, X), total=len(X)))
+
         return np.array(predictions)
 
     def _get_unique_targets(self, y: Any):
diff --git a/skllm/models/gpt/classification/few_shot.py b/skllm/models/gpt/classification/few_shot.py
index a3a71c6..b7dcd8b 100644
--- a/skllm/models/gpt/classification/few_shot.py
+++ b/skllm/models/gpt/classification/few_shot.py
@@ -129,7 +129,7 @@ def __init__(
             metric used for similarity search, by default "euclidean"
         """
         if vectorizer is None:
-            vectorizer = GPTVectorizer(model="text-embedding-ada-002")
+            vectorizer = GPTVectorizer(model="text-embedding-ada-002", key=key, org=org)
         super().__init__(
             model=model,
             default_label=default_label,
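Minimal usage sketch of the new num_workers argument in predict. The classifier class, model name, and toy data are illustrative placeholders (any estimator inheriting the shared base predict should behave the same), and the OpenAI key/org are assumed to be configured elsewhere, e.g. via the key and org constructor arguments touched in this diff:

    from skllm.models.gpt.classification.few_shot import FewShotGPTClassifier

    X_train = ["the movie was great", "the service was terrible"]
    y_train = ["positive", "negative"]
    X_test = ["an enjoyable watch", "a waste of time"]

    # placeholder class/model; key/org can also be passed to the constructor
    clf = FewShotGPTClassifier(model="gpt-3.5-turbo")
    clf.fit(X_train, y_train)

    # num_workers > 1 fans the per-sample requests out over a thread pool and
    # emits the "temporary" warning added in this diff; num_workers=1 keeps the
    # previous sequential behaviour (single-threaded executor).
    predictions = clf.predict(X_test, num_workers=4)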