From 10961760fd88274e518d5a74cab2a66aca3e6705 Mon Sep 17 00:00:00 2001 From: Oleh Kostromin Date: Wed, 24 Jul 2024 21:06:36 +0200 Subject: [PATCH] updated list of tunable models --- pyproject.toml | 2 +- skllm/__init__.py | 2 +- skllm/llm/gpt/mixin.py | 12 ++++++++---- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e3c8e20..c1b1b73 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ dependencies = [ "google-cloud-aiplatform[pipelines]>=1.27.0,<2.0.0" ] name = "scikit-llm" -version = "1.3.0" +version = "1.3.1" authors = [ { name="Oleh Kostromin", email="kostromin97@gmail.com" }, { name="Iryna Kondrashchenko", email="iryna230520@gmail.com" }, diff --git a/skllm/__init__.py b/skllm/__init__.py index 4715137..008a11e 100644 --- a/skllm/__init__.py +++ b/skllm/__init__.py @@ -1,2 +1,2 @@ -__version__ = '1.3.0' +__version__ = '1.3.1' __author__ = 'Iryna Kondrashchenko, Oleh Kostromin' diff --git a/skllm/llm/gpt/mixin.py b/skllm/llm/gpt/mixin.py index 704ee28..a0e0c41 100644 --- a/skllm/llm/gpt/mixin.py +++ b/skllm/llm/gpt/mixin.py @@ -66,10 +66,9 @@ def _set_keys(self, key: Optional[str] = None, org: Optional[str] = None) -> Non """ Set the OpenAI key and organization. """ - + self.key = key self.org = org - def _get_openai_key(self) -> str: """ @@ -214,7 +213,12 @@ def _get_embeddings(self, text: np.ndarray) -> List[List[float]]: # for now this works only with OpenAI class GPTTunableMixin(BaseTunableMixin): - _supported_tunable_models = ["gpt-3.5-turbo-0125", "gpt-3.5-turbo"] + _supported_tunable_models = [ + "gpt-3.5-turbo-0125", + "gpt-3.5-turbo", + "gpt-4o-mini-2024-07-18", + "gpt-4o-mini", + ] def _build_label(self, label: str): return json.dumps({"label": label}) @@ -264,4 +268,4 @@ def _tune(self, X, y): self.openai_model = job.fine_tuned_model self.model = self.openai_model # openai_model is probably not required anymore delete_file(client, job.training_file) - print(f"Finished training.") \ No newline at end of file + print(f"Finished training.")