From 89c6f4be66d42a2966043ac5bb66960e4635e02b Mon Sep 17 00:00:00 2001 From: Iryna K Date: Sun, 18 Jun 2023 18:03:53 +0200 Subject: [PATCH] Translator updated --- skllm/openai/base_gpt.py | 48 +++++++++++++-------------- skllm/preprocessing/gpt_translator.py | 14 +++++++- skllm/prompts/templates.py | 15 +++------ 3 files changed, 40 insertions(+), 37 deletions(-) diff --git a/skllm/openai/base_gpt.py b/skllm/openai/base_gpt.py index 3fcc6c8..95ccee9 100644 --- a/skllm/openai/base_gpt.py +++ b/skllm/openai/base_gpt.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, List, Optional, Union +from typing import Any import numpy as np import pandas as pd @@ -14,24 +14,21 @@ from skllm.utils import to_numpy as _to_numpy -class BaseZeroShotGPTTransformer(_BaseEstimator, _TransformerMixin, _OAIMixin): - +class BaseZeroShotGPTTransformer(_BaseEstimator, _TransformerMixin, _OAIMixin): system_msg = "You are an scikit-learn transformer." default_output = "Output is unavailable" def _get_chat_completion(self, X): - """ - Gets the chat completion for the given input using open ai API. + """Gets the chat completion for the given input using open ai API. Parameters ---------- X : str Input string - + Returns ------- str - """ prompt = self._get_prompt(X) msgs = [] @@ -45,10 +42,11 @@ def _get_chat_completion(self, X): except Exception as e: print(f"Skipping a sample due to the following error: {str(e)}") return self.default_output - - def fit(self, X: Any = None, y: Any = None, **kwargs: Any) -> BaseZeroShotGPTTransformer: - """ - Fits the model to the data. + + def fit( + self, X: Any = None, y: Any = None, **kwargs: Any + ) -> BaseZeroShotGPTTransformer: + """Fits the model to the data. Parameters ---------- @@ -60,12 +58,13 @@ def fit(self, X: Any = None, y: Any = None, **kwargs: Any) -> BaseZeroShotGPTTra ------- self : BaseZeroShotGPTTransformer """ - return self - def transform(self, X: Optional[Union[np.ndarray, pd.Series, List[str]]], **kwargs: Any) -> ndarray: - """ - Converts a list of strings using the open ai API and a predefined prompt. + def transform( + self, X: np.ndarray | pd.Series | list[str], **kwargs: Any + ) -> ndarray: + """Converts a list of strings using the open ai API and a predefined + prompt. Parameters ---------- @@ -78,16 +77,15 @@ def transform(self, X: Optional[Union[np.ndarray, pd.Series, List[str]]], **kwar X = _to_numpy(X) transformed = [] for i in tqdm(range(len(X))): - transformed.append( - self._get_chat_completion(X[i]) - ) - transformed = np.asarray(transformed, dtype = object) + transformed.append(self._get_chat_completion(X[i])) + transformed = np.asarray(transformed, dtype=object) return transformed - def fit_transform(self, X: Optional[Union[np.ndarray, pd.Series, List[str]]], y=None, **fit_params) -> ndarray: - """ - Fits and transforms a list of strings using the transform method. - This is modelled to function as the sklearn fit_transform method + def fit_transform( + self, X: np.ndarray | pd.Series | list[str], y=None, **fit_params + ) -> ndarray: + """Fits and transforms a list of strings using the transform method. + This is modelled to function as the sklearn fit_transform method. Parameters ---------- @@ -95,6 +93,6 @@ def fit_transform(self, X: Optional[Union[np.ndarray, pd.Series, List[str]]], y= Returns ------- - ndarray + ndarray """ - return self.fit(X, y).transform(X) \ No newline at end of file + return self.fit(X, y).transform(X) diff --git a/skllm/preprocessing/gpt_translator.py b/skllm/preprocessing/gpt_translator.py index 368ef58..3ad86a7 100644 --- a/skllm/preprocessing/gpt_translator.py +++ b/skllm/preprocessing/gpt_translator.py @@ -1,4 +1,8 @@ -from typing import Optional +from typing import Any, List, Optional, Union + +import numpy as np +from numpy import ndarray +from pandas import Series from skllm.openai.base_gpt import BaseZeroShotGPTTransformer as _BaseGPT from skllm.prompts.builders import build_translation_prompt @@ -35,3 +39,11 @@ def _get_prompt(self, X: str) -> str: translated sample """ return build_translation_prompt(X, self.output_language) + + def transform(self, X: Union[ndarray, Series, List[str]], **kwargs: Any) -> ndarray: + y = super().transform(X, **kwargs) + y = np.asarray( + [i.replace("[Translated text:]", "").replace("```", "").strip() for i in y], + dtype=object, + ) + return y diff --git a/skllm/prompts/templates.py b/skllm/prompts/templates.py index b90fc3b..d3c2926 100644 --- a/skllm/prompts/templates.py +++ b/skllm/prompts/templates.py @@ -81,16 +81,9 @@ """ TRANSLATION_PROMPT_TEMPLATE = """ -You will be provided with an arbitrary text sample, delimited by triple backticks. -Your task is to translate this text to {output_language} language and output the translated text. +If the original text, delimited by triple backticks, is already in {output_language} language, output the original text. +Otherwise, translate the original text, delimited by triple backticks, to {output_language} language, and output the translated text only. Do not output any additional information except the translated text. -Perform the following actions: -1. Determine the language of the text sample. -2. If the language is not {output_language}, translate the text sample to {output_language} language. -3. Output the translated text. -If the text sample provided is not in a recognizable language, output "No translation available". -Do not output any additional information except the translated text. - -Text sample: ```{x}``` -Translated text: +Original text: ```{x}``` +Output: """