From 145a22803339e85094a607385544bb793b93b0a1 Mon Sep 17 00:00:00 2001 From: AndreasKarasenko Date: Tue, 28 May 2024 16:48:09 +0200 Subject: [PATCH 1/7] propagate key to vectorizer --- skllm/models/gpt/classification/few_shot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skllm/models/gpt/classification/few_shot.py b/skllm/models/gpt/classification/few_shot.py index a3a71c6..b7dcd8b 100644 --- a/skllm/models/gpt/classification/few_shot.py +++ b/skllm/models/gpt/classification/few_shot.py @@ -129,7 +129,7 @@ def __init__( metric used for similarity search, by default "euclidean" """ if vectorizer is None: - vectorizer = GPTVectorizer(model="text-embedding-ada-002") + vectorizer = GPTVectorizer(model="text-embedding-ada-002", key=key, org=org) super().__init__( model=model, default_label=default_label, From eb6a2b6439740ac46218c103093c5556fbebb043 Mon Sep 17 00:00:00 2001 From: AndreasKarasenko Date: Tue, 28 May 2024 16:48:48 +0200 Subject: [PATCH 2/7] renamed key, org to conform to scikit-learn __innit__ reqs --- skllm/llm/gpt/mixin.py | 34 ++++++++++++++++++++-------------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/skllm/llm/gpt/mixin.py b/skllm/llm/gpt/mixin.py index 5949935..049c427 100644 --- a/skllm/llm/gpt/mixin.py +++ b/skllm/llm/gpt/mixin.py @@ -66,8 +66,14 @@ def _set_keys(self, key: Optional[str] = None, org: Optional[str] = None) -> Non """ Set the OpenAI key and organization. """ - self.openai_key = key - self.openai_org = org + + self.key = key + self.org = org + + if self.key is None: + self.key = self._get_openai_key() + if org is None: + self.org = self._get_openai_org() def _get_openai_key(self) -> str: """ @@ -77,12 +83,12 @@ def _get_openai_key(self) -> str: ------- openai_key: str """ - key = self.openai_key - if key is None: - key = _Config.get_openai_key() - if key is None: + openai_key = self.key + if openai_key is None: + openai_key = _Config.get_openai_key() + if openai_key is None: raise RuntimeError("OpenAI key was not found") - return key + return openai_key def _get_openai_org(self) -> str: """ @@ -92,12 +98,12 @@ def _get_openai_org(self) -> str: ------- openai_org: str """ - key = self.openai_org - if key is None: - key = _Config.get_openai_org() - if key is None: + openai_org = self.org + if openai_org is None: + openai_org = _Config.get_openai_org() + if openai_org is None: raise RuntimeError("OpenAI organization was not found") - return key + return openai_org class GPTTextCompletionMixin(GPTMixin, BaseTextCompletionMixin): @@ -212,7 +218,7 @@ def _get_embeddings(self, text: np.ndarray) -> List[List[float]]: # for now this works only with OpenAI class GPTTunableMixin(BaseTunableMixin): - _supported_tunable_models = ["gpt-3.5-turbo-0125", "gpt-3.5-turbo"] + _supported_tunable_models = ["gpt-3.5-turbo-0613", "gpt-3.5-turbo"] def _build_label(self, label: str): return json.dumps({"label": label}) @@ -262,4 +268,4 @@ def _tune(self, X, y): self.openai_model = job.fine_tuned_model self.model = self.openai_model # openai_model is probably not required anymore delete_file(client, job.training_file) - print(f"Finished training.") + print(f"Finished training.") \ No newline at end of file From 4194a3789f9588b989d14204109d923115dfebb5 Mon Sep 17 00:00:00 2001 From: AndreasKarasenko Date: Fri, 7 Jun 2024 14:35:05 +0200 Subject: [PATCH 3/7] removed default assign in _set_keys, removed mention of openai --- skllm/llm/gpt/mixin.py | 26 +++++++++++--------------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/skllm/llm/gpt/mixin.py b/skllm/llm/gpt/mixin.py index 049c427..365f4cd 100644 --- a/skllm/llm/gpt/mixin.py +++ b/skllm/llm/gpt/mixin.py @@ -70,10 +70,6 @@ def _set_keys(self, key: Optional[str] = None, org: Optional[str] = None) -> Non self.key = key self.org = org - if self.key is None: - self.key = self._get_openai_key() - if org is None: - self.org = self._get_openai_org() def _get_openai_key(self) -> str: """ @@ -83,12 +79,12 @@ def _get_openai_key(self) -> str: ------- openai_key: str """ - openai_key = self.key - if openai_key is None: - openai_key = _Config.get_openai_key() - if openai_key is None: + key = self.key + if key is None: + key = _Config.get_openai_key() + if key is None: raise RuntimeError("OpenAI key was not found") - return openai_key + return key def _get_openai_org(self) -> str: """ @@ -98,12 +94,12 @@ def _get_openai_org(self) -> str: ------- openai_org: str """ - openai_org = self.org - if openai_org is None: - openai_org = _Config.get_openai_org() - if openai_org is None: + org = self.org + if org is None: + org = _Config.get_openai_org() + if org is None: raise RuntimeError("OpenAI organization was not found") - return openai_org + return org class GPTTextCompletionMixin(GPTMixin, BaseTextCompletionMixin): @@ -218,7 +214,7 @@ def _get_embeddings(self, text: np.ndarray) -> List[List[float]]: # for now this works only with OpenAI class GPTTunableMixin(BaseTunableMixin): - _supported_tunable_models = ["gpt-3.5-turbo-0613", "gpt-3.5-turbo"] + _supported_tunable_models = ["gpt-3.5-turbo-0125", "gpt-3.5-turbo"] def _build_label(self, label: str): return json.dumps({"label": label}) From 5ff4ac5fdd73d3e7edb1f897dd375795d9e0594a Mon Sep 17 00:00:00 2001 From: AndreasKarasenko Date: Fri, 7 Jun 2024 14:48:04 +0200 Subject: [PATCH 4/7] fixed docstring --- skllm/llm/gpt/mixin.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/skllm/llm/gpt/mixin.py b/skllm/llm/gpt/mixin.py index 365f4cd..704ee28 100644 --- a/skllm/llm/gpt/mixin.py +++ b/skllm/llm/gpt/mixin.py @@ -77,7 +77,7 @@ def _get_openai_key(self) -> str: Returns ------- - openai_key: str + key: str """ key = self.key if key is None: @@ -92,7 +92,7 @@ def _get_openai_org(self) -> str: Returns ------- - openai_org: str + org: str """ org = self.org if org is None: From b8bd1bd12fe4fcdbdaff1e4726f8adbe595461b5 Mon Sep 17 00:00:00 2001 From: AndreasKarasenko Date: Fri, 7 Jun 2024 16:27:25 +0200 Subject: [PATCH 5/7] added multi-threading --- skllm/models/_base/classifier.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/skllm/models/_base/classifier.py b/skllm/models/_base/classifier.py index 46138c7..2168f5c 100644 --- a/skllm/models/_base/classifier.py +++ b/skllm/models/_base/classifier.py @@ -7,6 +7,7 @@ import numpy as np import pandas as pd from tqdm import tqdm +from concurrent.futures import ThreadPoolExecutor import random from collections import Counter from skllm.llm.base import ( @@ -211,7 +212,7 @@ def fit( self.classes_, self.probabilities_ = self._get_unique_targets(y) return self - def predict(self, X: Union[np.ndarray, pd.Series, List[str]]): + def predict(self, X: Union[np.ndarray, pd.Series, List[str]], num_workers: int = 1): """ Predicts the class of each input. @@ -219,6 +220,9 @@ def predict(self, X: Union[np.ndarray, pd.Series, List[str]]): ---------- X : Union[np.ndarray, pd.Series, List[str]] The input data to predict the class of. + + num_workers : int + number of workers to use for multithreaded prediction, default 1 Returns ------- @@ -226,9 +230,11 @@ def predict(self, X: Union[np.ndarray, pd.Series, List[str]]): The predicted classes as a numpy array. """ X = _to_numpy(X) - predictions = [] - for i in tqdm(range(len(X))): - predictions.append(self._predict_single(X[i])) + + + with ThreadPoolExecutor(max_workers=num_workers) as executor: + predictions = list(tqdm(executor.map(self._predict_single, X), total=len(X))) + return np.array(predictions) def _get_unique_targets(self, y: Any): From 2c47d9100ba4dd671315d0961316a02b3815c068 Mon Sep 17 00:00:00 2001 From: AndreasKarasenko Date: Sat, 8 Jun 2024 16:46:54 +0200 Subject: [PATCH 6/7] added depreceiation warning --- skllm/models/_base/classifier.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/skllm/models/_base/classifier.py b/skllm/models/_base/classifier.py index 2168f5c..43b094e 100644 --- a/skllm/models/_base/classifier.py +++ b/skllm/models/_base/classifier.py @@ -231,7 +231,8 @@ def predict(self, X: Union[np.ndarray, pd.Series, List[str]], num_workers: int = """ X = _to_numpy(X) - + if num_workers > 1: + raise DeprecationWarning("Passing num_workers to predict is temporary and will be removed in the future.") with ThreadPoolExecutor(max_workers=num_workers) as executor: predictions = list(tqdm(executor.map(self._predict_single, X), total=len(X))) From 1917943613d8ad2727dc8fdc377e8a884f4111f3 Mon Sep 17 00:00:00 2001 From: AndreasKarasenko Date: Sat, 8 Jun 2024 18:25:09 +0200 Subject: [PATCH 7/7] use warnings.warn instead of raise --- skllm/models/_base/classifier.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/skllm/models/_base/classifier.py b/skllm/models/_base/classifier.py index 43b094e..03e0688 100644 --- a/skllm/models/_base/classifier.py +++ b/skllm/models/_base/classifier.py @@ -4,6 +4,7 @@ BaseEstimator as _SklBaseEstimator, ClassifierMixin as _SklClassifierMixin, ) +import warnings import numpy as np import pandas as pd from tqdm import tqdm @@ -232,7 +233,7 @@ def predict(self, X: Union[np.ndarray, pd.Series, List[str]], num_workers: int = X = _to_numpy(X) if num_workers > 1: - raise DeprecationWarning("Passing num_workers to predict is temporary and will be removed in the future.") + warnings.warn("Passing num_workers to predict is temporary and will be removed in the future.") with ThreadPoolExecutor(max_workers=num_workers) as executor: predictions = list(tqdm(executor.map(self._predict_single, X), total=len(X)))