diff --git a/skllm/completions.py b/skllm/completions.py index 02e07df..dc64911 100644 --- a/skllm/completions.py +++ b/skllm/completions.py @@ -3,8 +3,11 @@ def get_chat_completion( - messages, openai_key=None, openai_org=None, model="gpt-3.5-turbo", max_retries=3 + messages: dict, openai_key: str=None, openai_org: str=None, model: str="gpt-3.5-turbo", max_retries: int=3 ): + """ + Gets a chat completion from the OpenAI API. + """ if model.startswith("gpt4all::"): return _g4a_get_chat_completion(messages, model[9:]) else: diff --git a/skllm/gpt4all_client.py b/skllm/gpt4all_client.py index e8ef0ae..2331463 100644 --- a/skllm/gpt4all_client.py +++ b/skllm/gpt4all_client.py @@ -1,3 +1,5 @@ +from typing import Dict + try: from gpt4all import GPT4All except (ImportError, ModuleNotFoundError): @@ -6,18 +8,33 @@ _loaded_models = {} -def get_chat_completion(messages, model="ggml-gpt4all-j-v1.3-groovy"): +def get_chat_completion(messages: Dict, model: str="ggml-gpt4all-j-v1.3-groovy") -> Dict: + """ + Gets a chat completion from GPT4All + + Parameters + ---------- + messages : Dict + The messages to use as a prompt for the chat completion. + model : str + The model to use for the chat completion. Defaults to "ggml-gpt4all-j-v1.3-groovy". 
+ + Returns + ------- + completion : Dict + """ if GPT4All is None: raise ImportError( "gpt4all is not installed, try `pip install scikit-llm[gpt4all]`" ) if model not in _loaded_models.keys(): _loaded_models[model] = GPT4All(model) + return _loaded_models[model].chat_completion( messages, verbose=False, streaming=False, temp=1e-10 ) -def unload_models(): +def unload_models() -> None: global _loaded_models _loaded_models = {} diff --git a/skllm/models/gpt_few_shot_clf.py b/skllm/models/gpt_few_shot_clf.py index b552f38..31a6747 100644 --- a/skllm/models/gpt_few_shot_clf.py +++ b/skllm/models/gpt_few_shot_clf.py @@ -25,8 +25,7 @@ def fit( X: Union[np.ndarray, pd.Series, List[str]], y: Union[np.ndarray, pd.Series, List[str]], ): - """Fits the model by storing the training data and extracting the - unique targets. + """Fits the model to the given data. Parameters ---------- diff --git a/skllm/models/gpt_zero_shot_clf.py b/skllm/models/gpt_zero_shot_clf.py index f970485..fd8a75c 100644 --- a/skllm/models/gpt_zero_shot_clf.py +++ b/skllm/models/gpt_zero_shot_clf.py @@ -35,7 +35,7 @@ class _BaseZeroShotGPTClassifier(ABC, BaseEstimator, ClassifierMixin, _OAIMixin) The default label to use if the LLM could not generate a response for a sample. If set to 'Random' a random label will be chosen based on probabilities from the training set. """ - + def __init__( self, openai_key: Optional[str] = None, @@ -48,6 +48,19 @@ def __init__( self.default_label = default_label def _to_np(self, X): + """ + Converts X to a numpy array. + + Parameters + ---------- + X : Any + The input data to convert to a numpy array. + + Returns + ------- + np.ndarray + The input data as a numpy array. + """ return _to_numpy(X) @abstractmethod @@ -60,11 +73,35 @@ def fit( X: Optional[Union[np.ndarray, pd.Series, List[str]]], y: Union[np.ndarray, pd.Series, List[str], List[List[str]]], ): + """ + Extracts the target for each datapoint in X. 
+ + Parameters + ---------- + X : Optional[Union[np.ndarray, pd.Series, List[str]]] + The input array data to fit the model to. + + y : Union[np.ndarray, pd.Series, List[str], List[List[str]]] + The target array data to fit the model to. + + """ X = self._to_np(X) self.classes_, self.probabilities_ = self._get_unique_targets(y) return self def predict(self, X: Union[np.ndarray, pd.Series, List[str]]): + """ + Predicts the class of each input. + + Parameters + ---------- + X : Union[np.ndarray, pd.Series, List[str]] + The input data to predict the class of. + + Returns + ------- + List[str] + """ X = self._to_np(X) predictions = [] for i in tqdm(range(len(X))): @@ -75,7 +112,7 @@ def predict(self, X: Union[np.ndarray, pd.Series, List[str]]): def _extract_labels(self, y: Any) -> List[str]: pass - def _get_unique_targets(self, y): + def _get_unique_targets(self, y:Any): labels = self._extract_labels(y) counts = Counter(labels) @@ -128,6 +165,17 @@ def __init__( super().__init__(openai_key, openai_org, openai_model, default_label) def _extract_labels(self, y: Any) -> List[str]: + """ + Return the class labels as a list. + + Parameters + ---------- + y : Any + + Returns + ------- + List[str] + """ if isinstance(y, (pd.Series, np.ndarray)): labels = y.tolist() else: @@ -145,6 +193,9 @@ def _get_default_label(self): return self.default_label def _predict_single(self, x): + """ + Predicts the labels for a single sample. + """ completion = self._get_chat_completion(x) try: label = str( @@ -207,6 +258,17 @@ def __init__( self.max_labels = max_labels def _extract_labels(self, y) -> List[str]: + """ + Extracts the labels into a list. + + Parameters + ---------- + y : Any + + Returns + ------- + List[str] + """ labels = [] for l in y: for j in l: @@ -231,6 +293,9 @@ def _get_default_label(self): return result def _predict_single(self, x): + """ + Predicts the labels for a single sample. 
+ """ completion = self._get_chat_completion(x) try: labels = extract_json_key(completion["choices"][0]["message"]["content"], "label") @@ -254,4 +319,15 @@ def fit( X: Optional[Union[np.ndarray, pd.Series, List[str]]], y: List[List[str]], ): + """ + Calls the parent fit method on input data. + + Parameters + ---------- + X : Optional[Union[np.ndarray, pd.Series, List[str]]] + Input array data + y : List[List[str]] + The labels. + + """ return super().fit(X, y) diff --git a/skllm/openai/base_gpt.py b/skllm/openai/base_gpt.py index 8741a7d..3fcc6c8 100644 --- a/skllm/openai/base_gpt.py +++ b/skllm/openai/base_gpt.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import Any, List, Optional, Union import numpy as np @@ -18,6 +20,19 @@ class BaseZeroShotGPTTransformer(_BaseEstimator, _TransformerMixin, _OAIMixin): default_output = "Output is unavailable" def _get_chat_completion(self, X): + """ + Gets the chat completion for the given input using open ai API. + + Parameters + ---------- + X : str + Input string + + Returns + ------- + str + + """ prompt = self._get_prompt(X) msgs = [] msgs.append(construct_message("system", self.system_msg)) @@ -31,10 +46,35 @@ def _get_chat_completion(self, X): print(f"Skipping a sample due to the following error: {str(e)}") return self.default_output - def fit(self, X: Any = None, y: Any = None, **kwargs: Any): + def fit(self, X: Any = None, y: Any = None, **kwargs: Any) -> BaseZeroShotGPTTransformer: + """ + Fits the model to the data. + + Parameters + ---------- + X : Any, optional + y : Any, optional + kwargs : dict, optional + + Returns + ------- + self : BaseZeroShotGPTTransformer + """ + return self def transform(self, X: Optional[Union[np.ndarray, pd.Series, List[str]]], **kwargs: Any) -> ndarray: + """ + Converts a list of strings using the open ai API and a predefined prompt. 
+ + Parameters + ---------- + X : Union[np.ndarray, pd.Series, List[str]] + + Returns + ------- + ndarray + """ X = _to_numpy(X) transformed = [] for i in tqdm(range(len(X))): @@ -45,4 +85,16 @@ def transform(self, X: Optional[Union[np.ndarray, pd.Series, List[str]]], **kwar return transformed def fit_transform(self, X: Optional[Union[np.ndarray, pd.Series, List[str]]], y=None, **fit_params) -> ndarray: + """ + Fits and transforms a list of strings using the transform method. + This is modelled to function as the sklearn fit_transform method + + Parameters + ---------- + X : np.ndarray, pd.Series, or list + + Returns + ------- + ndarray + """ return self.fit(X, y).transform(X) \ No newline at end of file diff --git a/skllm/openai/chatgpt.py b/skllm/openai/chatgpt.py index 3a73004..7ea93a7 100644 --- a/skllm/openai/chatgpt.py +++ b/skllm/openai/chatgpt.py @@ -1,5 +1,6 @@ import json from time import sleep +from typing import Any import openai @@ -7,13 +8,47 @@ from skllm.utils import find_json_in_string -def construct_message(role, content): +def construct_message(role: str, content: str) -> dict: + """ + Constructs a message for the OpenAI API. + + Parameters + ---------- + role : str + The role of the message. Must be one of "system", "user", or "assistant". + content : str + The content of the message. + + Returns + ------- + message : dict + """ if role not in ("system", "user", "assistant"): raise ValueError("Invalid role") return {"role": role, "content": content} -def get_chat_completion(messages, key, org, model="gpt-3.5-turbo", max_retries=3): +def get_chat_completion(messages: dict, key: str, org: str, model: str="gpt-3.5-turbo", max_retries: int=3): + """ + Gets a chat completion from the OpenAI API. + + Parameters + ---------- + messages : dict + input messages to use. + key : str + The OPEN AI key to use. + org : str + The OPEN AI organization ID to use. + model : str, optional + The OPEN AI model to use. Defaults to "gpt-3.5-turbo". 
+ max_retries : int, optional + The maximum number of retries to use. Defaults to 3. + + Returns + ------- + completion : dict + """ set_credentials(key, org) error_msg = None error_type = None @@ -33,7 +68,16 @@ def get_chat_completion(messages, key, org, model="gpt-3.5-turbo", max_retries=3 ) -def extract_json_key(json_, key): + +def extract_json_key(json_: str, key: str): + """ + Extracts a JSON key from a string. + + json_ : str + The JSON string to extract the key from. + key : str + The key to extract. + """ original_json = json_ for i in range(2): try: @@ -48,4 +92,4 @@ def extract_json_key(json_, key): except Exception: if i == 0: continue - return None \ No newline at end of file + return None diff --git a/skllm/openai/credentials.py b/skllm/openai/credentials.py index 733274a..ac11815 100644 --- a/skllm/openai/credentials.py +++ b/skllm/openai/credentials.py @@ -1,6 +1,15 @@ import openai +def set_credentials(key: str, org: str) -> None: + """ + Set the OpenAI key and organization. -def set_credentials(key: str, org: str): + Parameters + ---------- + key : str + The OpenAI key to use. + org : str + The OpenAI organization ID to use. + """ openai.api_key = key openai.organization = org diff --git a/skllm/openai/embeddings.py b/skllm/openai/embeddings.py index b06f3e8..6712bda 100644 --- a/skllm/openai/embeddings.py +++ b/skllm/openai/embeddings.py @@ -6,8 +6,29 @@ def get_embedding( - text, key: str, org: str, model="text-embedding-ada-002", max_retries=3 -): + text: str, key: str, org: str, model: str="text-embedding-ada-002", max_retries: int=3 +): + """ + Encodes a string and returns its embedding. + + Parameters + ---------- + text : str + The string to encode. + key : str + The OpenAI key to use. + org : str + The OpenAI organization ID to use. + model : str, optional + The model to use. Defaults to "text-embedding-ada-002". + max_retries : int, optional + The maximum number of retries to use. Defaults to 3. 
+ + Returns + ------- + emb : list + The GPT embedding for the string. + """ set_credentials(key, org) text = text.replace("\n", " ") error_msg = None diff --git a/skllm/openai/mixin.py b/skllm/openai/mixin.py index 1ab07cd..6fc0856 100644 --- a/skllm/openai/mixin.py +++ b/skllm/openai/mixin.py @@ -4,12 +4,24 @@ class OpenAIMixin: - + """ + A mixin class that provides OpenAI key and organization to other classes. + """ def _set_keys(self, key: Optional[str] = None, org: Optional[str] = None) -> None: + """ + Set the OpenAI key and organization. + """ self.openai_key = key self.openai_org = org def _get_openai_key(self) -> str: + """ + Get the OpenAI key from the class or the config file. + + Returns + ------- + openai_key: str + """ key = self.openai_key if key is None: key = _Config.get_openai_key() @@ -18,6 +30,13 @@ def _get_openai_key(self) -> str: return key def _get_openai_org(self) -> str: + """ + Get the OpenAI organization ID from the class or the config file. + + Returns + ------- + openai_org: str + """ key = self.openai_org if key is None: key = _Config.get_openai_org() diff --git a/skllm/preprocessing/gpt_summarizer.py b/skllm/preprocessing/gpt_summarizer.py index 8804275..e5b22c3 100644 --- a/skllm/preprocessing/gpt_summarizer.py +++ b/skllm/preprocessing/gpt_summarizer.py @@ -5,6 +5,21 @@ class GPTSummarizer(_BaseGPT): + """ + A text summarizer. + + Parameters + ---------- + openai_key : str, optional + The OPEN AI key to use. Defaults to None. + openai_org : str, optional + The OPEN AI organization ID to use. Defaults to None. + openai_model : str, optional + The OPEN AI model to use. Defaults to "gpt-3.5-turbo". + max_words : int, optional + The maximum number of words to use in the summary. Defaults to 15. + + """ system_msg = "You are a text summarizer." default_output = "Summary is unavailable." 
@@ -19,5 +34,18 @@ def __init__( self.openai_model = openai_model self.max_words = max_words - def _get_prompt(self, X) -> str: + + def _get_prompt(self, X: str) -> str: + """ + Generates the prompt for the given input. + + Parameters + ---------- + X : str + sample to summarize + + Returns + ------- + str + """ return build_summary_prompt(X, self.max_words) diff --git a/skllm/preprocessing/gpt_translator.py b/skllm/preprocessing/gpt_translator.py index 35e4d68..368ef58 100644 --- a/skllm/preprocessing/gpt_translator.py +++ b/skllm/preprocessing/gpt_translator.py @@ -16,7 +16,7 @@ def __init__( openai_org: Optional[str] = None, openai_model: str = "gpt-3.5-turbo", output_language: str = "English", - ): + ) -> None: self._set_keys(openai_key, openai_org) self.openai_model = openai_model self.output_language = output_language diff --git a/skllm/preprocessing/gpt_vectorizer.py b/skllm/preprocessing/gpt_vectorizer.py index 18c17b6..fe4537b 100644 --- a/skllm/preprocessing/gpt_vectorizer.py +++ b/skllm/preprocessing/gpt_vectorizer.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import Any, List, Optional, Union import numpy as np @@ -13,6 +15,18 @@ class GPTVectorizer(_BaseEstimator, _TransformerMixin, _OAIMixin): + """ + A class that uses an OpenAI embedding model to convert text to GPT embeddings. + + Parameters + ---------- + openai_embedding_model : str + The OpenAI embedding model to use. Defaults to "text-embedding-ada-002". + openai_key : str, optional + The OpenAI key to use. Defaults to None. + openai_org : str, optional + The OpenAI organization ID to use. Defaults to None. 
+ """ def __init__( self, openai_embedding_model: str = "text-embedding-ada-002", @@ -22,10 +36,37 @@ def __init__( self.openai_embedding_model = openai_embedding_model self._set_keys(openai_key, openai_org) - def fit(self, X: Any = None, y: Any = None, **kwargs): + def fit(self, X: Any = None, y: Any = None, **kwargs) -> GPTVectorizer: + """ + Fits the GPTVectorizer to the data. + This is modelled to function as the sklearn fit method. + + Parameters + ---------- + X : Any, optional + y : Any, optional + kwargs : dict, optional + + Returns + ------- + self : GPTVectorizer + """ return self def transform(self, X: Optional[Union[np.ndarray, pd.Series, List[str]]]) -> ndarray: + """ + Transforms a list of strings into a list of GPT embeddings. + This is modelled to function as the sklearn transform method + + Parameters + ---------- + X : Optional[Union[np.ndarray, pd.Series, List[str]]] + The input array of strings to transform into GPT embeddings. + + Returns + ------- + embeddings : np.ndarray + """ X = _to_numpy(X) embeddings = [] for i in tqdm(range(len(X))): @@ -36,4 +77,18 @@ def transform(self, X: Optional[Union[np.ndarray, pd.Series, List[str]]]) -> nda return embeddings def fit_transform(self, X: Optional[Union[np.ndarray, pd.Series, List[str]]], y=None, **fit_params) -> ndarray: + """ + Fits and transforms a list of strings into a list of GPT embeddings. + This is modelled to function as the sklearn fit_transform method + + Parameters + ---------- + X : Optional[Union[np.ndarray, pd.Series, List[str]]] + The input array of strings to transform into GPT embeddings. 
+ y : Any, optional + + Returns + ------- + embeddings : np.ndarray + """ return self.fit(X, y).transform(X) diff --git a/skllm/utils.py b/skllm/utils.py index 49b2c03..eca2289 100644 --- a/skllm/utils.py +++ b/skllm/utils.py @@ -1,8 +1,21 @@ import numpy as np import pandas as pd +from typing import Any -def to_numpy(X): +def to_numpy(X: Any) -> np.ndarray: + """ + Converts a pandas Series or list to a numpy array. + + Parameters + ---------- + X : Any + The data to convert to a numpy array. + + Returns + ------- + X : np.ndarray + """ if isinstance(X, pd.Series): X = X.to_numpy().astype(object) elif isinstance(X, list): @@ -12,11 +25,24 @@ def to_numpy(X): return X -def find_json_in_string(string): +def find_json_in_string(string: str) -> str: + """ + Finds the JSON object in a string. + + Parameters + ---------- + string : str + The string to search for a JSON object. + + Returns + ------- + json_string : str + """ + start = string.find("{") end = string.rfind("}") if start != -1 and end != -1: json_string = string[start : end + 1] else: - json_string = {} + json_string = "{}" return json_string