Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 23 additions & 25 deletions skllm/openai/base_gpt.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Any, List, Optional, Union
from typing import Any

import numpy as np
import pandas as pd
Expand All @@ -14,24 +14,21 @@
from skllm.utils import to_numpy as _to_numpy


class BaseZeroShotGPTTransformer(_BaseEstimator, _TransformerMixin, _OAIMixin):

class BaseZeroShotGPTTransformer(_BaseEstimator, _TransformerMixin, _OAIMixin):
system_msg = "You are an scikit-learn transformer."
default_output = "Output is unavailable"

def _get_chat_completion(self, X):
"""
Gets the chat completion for the given input using open ai API.
"""Gets the chat completion for the given input using open ai API.

Parameters
----------
X : str
Input string

Returns
-------
str

"""
prompt = self._get_prompt(X)
msgs = []
Expand All @@ -45,10 +42,11 @@ def _get_chat_completion(self, X):
except Exception as e:
print(f"Skipping a sample due to the following error: {str(e)}")
return self.default_output

def fit(self, X: Any = None, y: Any = None, **kwargs: Any) -> BaseZeroShotGPTTransformer:
"""
Fits the model to the data.

def fit(
self, X: Any = None, y: Any = None, **kwargs: Any
) -> BaseZeroShotGPTTransformer:
"""Fits the model to the data.

Parameters
----------
Expand All @@ -60,12 +58,13 @@ def fit(self, X: Any = None, y: Any = None, **kwargs: Any) -> BaseZeroShotGPTTra
-------
self : BaseZeroShotGPTTransformer
"""

return self

def transform(self, X: Optional[Union[np.ndarray, pd.Series, List[str]]], **kwargs: Any) -> ndarray:
"""
Converts a list of strings using the open ai API and a predefined prompt.
def transform(
self, X: np.ndarray | pd.Series | list[str], **kwargs: Any
) -> ndarray:
"""Converts a list of strings using the open ai API and a predefined
prompt.

Parameters
----------
Expand All @@ -78,23 +77,22 @@ def transform(self, X: Optional[Union[np.ndarray, pd.Series, List[str]]], **kwar
X = _to_numpy(X)
transformed = []
for i in tqdm(range(len(X))):
transformed.append(
self._get_chat_completion(X[i])
)
transformed = np.asarray(transformed, dtype = object)
transformed.append(self._get_chat_completion(X[i]))
transformed = np.asarray(transformed, dtype=object)
return transformed

def fit_transform(self, X: Optional[Union[np.ndarray, pd.Series, List[str]]], y=None, **fit_params) -> ndarray:
"""
Fits and transforms a list of strings using the transform method.
This is modelled to function as the sklearn fit_transform method
def fit_transform(
self, X: np.ndarray | pd.Series | list[str], y=None, **fit_params
) -> ndarray:
"""Fits and transforms a list of strings using the transform method.
This is modelled to function as the sklearn fit_transform method.

Parameters
----------
X : np.ndarray, pd.Series, or list

Returns
-------
ndarray
ndarray
"""
return self.fit(X, y).transform(X)
return self.fit(X, y).transform(X)
14 changes: 13 additions & 1 deletion skllm/preprocessing/gpt_translator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
from typing import Optional
from typing import Any, List, Optional, Union

import numpy as np
from numpy import ndarray
from pandas import Series

from skllm.openai.base_gpt import BaseZeroShotGPTTransformer as _BaseGPT
from skllm.prompts.builders import build_translation_prompt
Expand Down Expand Up @@ -35,3 +39,11 @@ def _get_prompt(self, X: str) -> str:
translated sample
"""
return build_translation_prompt(X, self.output_language)

def transform(self, X: Union[ndarray, Series, List[str]], **kwargs: Any) -> ndarray:
y = super().transform(X, **kwargs)
y = np.asarray(
[i.replace("[Translated text:]", "").replace("```", "").strip() for i in y],
dtype=object,
)
return y
15 changes: 4 additions & 11 deletions skllm/prompts/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,16 +81,9 @@
"""

TRANSLATION_PROMPT_TEMPLATE = """
You will be provided with an arbitrary text sample, delimited by triple backticks.
Your task is to translate this text to {output_language} language and output the translated text.
If the original text, delimited by triple backticks, is already in {output_language} language, output the original text.
Otherwise, translate the original text, delimited by triple backticks, to {output_language} language, and output the translated text only. Do not output any additional information except the translated text.

Perform the following actions:
1. Determine the language of the text sample.
2. If the language is not {output_language}, translate the text sample to {output_language} language.
3. Output the translated text.
If the text sample provided is not in a recognizable language, output "No translation available".
Do not output any additional information except the translated text.

Text sample: ```{x}```
Translated text:
Original text: ```{x}```
Output:
"""