Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion skllm/completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@


def get_chat_completion(
messages, openai_key=None, openai_org=None, model="gpt-3.5-turbo", max_retries=3
messages: dict, openai_key: str=None, openai_org: str=None, model: str="gpt-3.5-turbo", max_retries: int=3
):
"""
Gets a chat completion from the OpenAI API.
"""
if model.startswith("gpt4all::"):
return _g4a_get_chat_completion(messages, model[9:])
else:
Expand Down
21 changes: 19 additions & 2 deletions skllm/gpt4all_client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Dict

try:
from gpt4all import GPT4All
except (ImportError, ModuleNotFoundError):
Expand All @@ -6,18 +8,33 @@
_loaded_models = {}


def get_chat_completion(messages: Dict, model: str = "ggml-gpt4all-j-v1.3-groovy") -> Dict:
    """
    Gets a chat completion from GPT4All.

    The model object is created on first use and stored in the module-level
    ``_loaded_models`` dict under its name, so later calls with the same
    ``model`` reuse the already constructed ``GPT4All`` instance.

    Parameters
    ----------
    messages : Dict
        The messages to use as a prompt for the chat completion.
        NOTE(review): callers appear to build a list of message dicts;
        confirm whether this annotation should be a list type.
    model : str
        The model to use for the chat completion. Defaults to "ggml-gpt4all-j-v1.3-groovy".

    Returns
    -------
    completion : Dict
        The value returned by ``GPT4All.chat_completion`` for ``messages``.

    Raises
    ------
    ImportError
        If ``GPT4All`` is ``None``, i.e. the guarded import of ``gpt4all``
        at module load did not provide the class.
    """
    if GPT4All is None:
        raise ImportError(
            "gpt4all is not installed, try `pip install scikit-llm[gpt4all]`"
        )
    # Membership is tested on the dict directly; ``.keys()`` adds nothing.
    if model not in _loaded_models:
        _loaded_models[model] = GPT4All(model)

    # NOTE(review): temp=1e-10 presumably requests near-deterministic output — confirm.
    return _loaded_models[model].chat_completion(
        messages, verbose=False, streaming=False, temp=1e-10
    )


def unload_models() -> None:
    """
    Drops all cached models.

    The module-level ``_loaded_models`` name is rebound to a new empty dict,
    so the next completion request for any model constructs it again.
    """
    global _loaded_models
    _loaded_models = {}
3 changes: 1 addition & 2 deletions skllm/models/gpt_few_shot_clf.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ def fit(
X: Union[np.ndarray, pd.Series, List[str]],
y: Union[np.ndarray, pd.Series, List[str]],
):
"""Fits the model by storing the training data and extracting the
unique targets.
"""Fits the model to the given data.

Parameters
----------
Expand Down
80 changes: 78 additions & 2 deletions skllm/models/gpt_zero_shot_clf.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class _BaseZeroShotGPTClassifier(ABC, BaseEstimator, ClassifierMixin, _OAIMixin)
The default label to use if the LLM could not generate a response for a sample. If set to 'Random' a random
label will be chosen based on probabilities from the training set.
"""

def __init__(
self,
openai_key: Optional[str] = None,
Expand All @@ -48,6 +48,19 @@ def __init__(
self.default_label = default_label

def _to_np(self, X):
    """
    Converts X to a numpy array by delegating to ``_to_numpy``.

    Parameters
    ----------
    X : Any
        The input data to convert.

    Returns
    -------
    np.ndarray
        Whatever ``_to_numpy`` returns for ``X``.
    """
    as_array = _to_numpy(X)
    return as_array

@abstractmethod
Expand All @@ -60,11 +73,35 @@ def fit(
X: Optional[Union[np.ndarray, pd.Series, List[str]]],
y: Union[np.ndarray, pd.Series, List[str], List[List[str]]],
):
"""
Extracts the target for each datapoint in X.

Parameters
----------
X : Optional[Union[np.ndarray, pd.Series, List[str]]]
The input array data to fit the model to.

y : Union[np.ndarray, pd.Series, List[str], List[List[str]]]
The target array data to fit the model to.

"""
X = self._to_np(X)
self.classes_, self.probabilities_ = self._get_unique_targets(y)
return self

def predict(self, X: Union[np.ndarray, pd.Series, List[str]]):
"""
Predicts the class of each input.

Parameters
----------
X : Union[np.ndarray, pd.Series, List[str]]
The input data to predict the class of.

Returns
-------
List[str]
"""
X = self._to_np(X)
predictions = []
for i in tqdm(range(len(X))):
Expand All @@ -75,7 +112,7 @@ def predict(self, X: Union[np.ndarray, pd.Series, List[str]]):
def _extract_labels(self, y: Any) -> List[str]:
pass

def _get_unique_targets(self, y):
def _get_unique_targets(self, y:Any):
labels = self._extract_labels(y)

counts = Counter(labels)
Expand Down Expand Up @@ -128,6 +165,17 @@ def __init__(
super().__init__(openai_key, openai_org, openai_model, default_label)

def _extract_labels(self, y: Any) -> List[str]:
"""
Return the class labels as a list.

Parameters
----------
y : Any

Returns
-------
List[str]
"""
if isinstance(y, (pd.Series, np.ndarray)):
labels = y.tolist()
else:
Expand All @@ -145,6 +193,9 @@ def _get_default_label(self):
return self.default_label

def _predict_single(self, x):
"""
Predicts the labels for a single sample.
"""
completion = self._get_chat_completion(x)
try:
label = str(
Expand Down Expand Up @@ -207,6 +258,17 @@ def __init__(
self.max_labels = max_labels

def _extract_labels(self, y) -> List[str]:
"""
Extracts the labels into a list.

Parameters
----------
y : Any

Returns
-------
List[str]
"""
labels = []
for l in y:
for j in l:
Expand All @@ -231,6 +293,9 @@ def _get_default_label(self):
return result

def _predict_single(self, x):
"""
Predicts the labels for a single sample.
"""
completion = self._get_chat_completion(x)
try:
labels = extract_json_key(completion["choices"][0]["message"]["content"], "label")
Expand All @@ -254,4 +319,15 @@ def fit(
X: Optional[Union[np.ndarray, pd.Series, List[str]]],
y: List[List[str]],
):
"""
Calls the parent fit method on input data.

Parameters
----------
X : Optional[Union[np.ndarray, pd.Series, List[str]]]
Input array data
y : List[List[str]]
The labels.

"""
return super().fit(X, y)
54 changes: 53 additions & 1 deletion skllm/openai/base_gpt.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from typing import Any, List, Optional, Union

import numpy as np
Expand All @@ -18,6 +20,19 @@ class BaseZeroShotGPTTransformer(_BaseEstimator, _TransformerMixin, _OAIMixin):
default_output = "Output is unavailable"

def _get_chat_completion(self, X):
"""
Gets the chat completion for the given input using the OpenAI API.

Parameters
----------
X : str
Input string

Returns
-------
str

"""
prompt = self._get_prompt(X)
msgs = []
msgs.append(construct_message("system", self.system_msg))
Expand All @@ -31,10 +46,35 @@ def _get_chat_completion(self, X):
print(f"Skipping a sample due to the following error: {str(e)}")
return self.default_output

def fit(self, X: Any = None, y: Any = None, **kwargs: Any) -> "BaseZeroShotGPTTransformer":
    """
    Fits the model to the data.

    Nothing is learned here: the arguments are accepted but not used, and
    the instance itself is returned so calls can be chained.

    Parameters
    ----------
    X : Any, optional
        Not used.
    y : Any, optional
        Not used.
    kwargs : dict, optional
        Not used.

    Returns
    -------
    self : BaseZeroShotGPTTransformer
    """
    # The return annotation is quoted so it resolves as a forward reference
    # whether or not annotations are evaluated lazily.
    return self

def transform(self, X: Optional[Union[np.ndarray, pd.Series, List[str]]], **kwargs: Any) -> ndarray:
"""
Converts a list of strings using the OpenAI API and a predefined prompt.

Parameters
----------
X : Union[np.ndarray, pd.Series, List[str]]

Returns
-------
ndarray
"""
X = _to_numpy(X)
transformed = []
for i in tqdm(range(len(X))):
Expand All @@ -45,4 +85,16 @@ def transform(self, X: Optional[Union[np.ndarray, pd.Series, List[str]]], **kwar
return transformed

def fit_transform(self, X: Optional[Union[np.ndarray, pd.Series, List[str]]], y=None, **fit_params) -> ndarray:
    """
    Fits and transforms a list of strings using the transform method.
    This is modelled to function as the sklearn fit_transform method.

    Parameters
    ----------
    X : np.ndarray, pd.Series, or list
        The input strings; passed to both ``fit`` and ``transform``.
    y : optional
        Passed through to ``fit``.
    fit_params : dict, optional
        Extra keyword arguments, forwarded to ``fit``.

    Returns
    -------
    ndarray
        The result of ``transform(X)`` on the fitted instance.
    """
    # Forward fit_params instead of silently discarding them.
    return self.fit(X, y, **fit_params).transform(X)
52 changes: 48 additions & 4 deletions skllm/openai/chatgpt.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,54 @@
import json
from time import sleep
from typing import Any

import openai

from skllm.openai.credentials import set_credentials
from skllm.utils import find_json_in_string


def construct_message(role: str, content: str) -> dict:
    """
    Constructs a message for the OpenAI API.

    Parameters
    ----------
    role : str
        The role of the message. Must be one of "system", "user", or "assistant".
    content : str
        The content of the message.

    Returns
    -------
    message : dict
        A dict with the keys "role" and "content".

    Raises
    ------
    ValueError
        If ``role`` is not one of the three accepted values.
    """
    if role not in ("system", "user", "assistant"):
        raise ValueError("Invalid role")
    return {"role": role, "content": content}


def get_chat_completion(messages, key, org, model="gpt-3.5-turbo", max_retries=3):
def get_chat_completion(messages: dict, key: str, org: str, model: str="gpt-3.5-turbo", max_retries: int=3):
"""
Gets a chat completion from the OpenAI API.

Parameters
----------
messages : dict
input messages to use.
key : str
The OpenAI key to use.
org : str
The OpenAI organization ID to use.
model : str, optional
The OpenAI model to use. Defaults to "gpt-3.5-turbo".
max_retries : int, optional
The maximum number of retries to use. Defaults to 3.

Returns
-------
completion : dict
"""
set_credentials(key, org)
error_msg = None
error_type = None
Expand All @@ -33,7 +68,16 @@ def get_chat_completion(messages, key, org, model="gpt-3.5-turbo", max_retries=3
)


def extract_json_key(json_, key):

def extract_json_key(json_: str, key: str):
"""
Extracts JSON key from a string.

json_ : str
The JSON string to extract the key from.
key : str
The key to extract.
"""
original_json = json_
for i in range(2):
try:
Expand All @@ -48,4 +92,4 @@ def extract_json_key(json_, key):
except Exception:
if i == 0:
continue
return None
return None
11 changes: 10 additions & 1 deletion skllm/openai/credentials.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
import openai

def set_credentials(key: str, org: str) -> None:
    """
    Set the OpenAI key and organization.

    Parameters
    ----------
    key : str
        The OpenAI key to use.
    org : str
        The OpenAI organization ID to use.
    """
    # Both values are stored as attributes on the imported ``openai`` module.
    openai.api_key = key
    openai.organization = org
Loading