From 73d6e132722a4639bfb00ef3c25354099c4caf4e Mon Sep 17 00:00:00 2001 From: Iryna K Date: Thu, 24 Aug 2023 10:32:43 +0200 Subject: [PATCH] fix gpt prediction function --- pyproject.toml | 2 +- skllm/models/gpt/gpt.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 47b3ca1..77c4740 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ dependencies = [ "google-cloud-aiplatform>=1.27.0" ] name = "scikit-llm" -version = "0.4.0" +version = "0.4.1" authors = [ { name="Oleg Kostromin", email="kostromin97@gmail.com" }, { name="Iryna Kondrashchenko", email="iryna230520@gmail.com" }, diff --git a/skllm/models/gpt/gpt.py b/skllm/models/gpt/gpt.py index cbaab53..6cbd4e5 100644 --- a/skllm/models/gpt/gpt.py +++ b/skllm/models/gpt/gpt.py @@ -149,6 +149,10 @@ def _get_prompt(self, x: str) -> str: def _build_label(self, label: str): return label + def _predict_single(self, x): + completion = self._get_chat_completion(x) + return completion["choices"][0]["message"]["content"] + def fit( self, X: Union[np.ndarray, pd.Series, List[str]],