Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,5 @@ isort
ruff
docformatter
interrogate
numpy
pandas
100 changes: 91 additions & 9 deletions skllm/models/gpt_zero_shot_clf.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import random
from abc import ABC, abstractmethod
from collections import Counter
from typing import Any, List, Optional, Union
from typing import Any, List, Optional, Union, Literal

import numpy as np
import pandas as pd
Expand All @@ -19,18 +19,42 @@


class _BaseZeroShotGPTClassifier(ABC, BaseEstimator, ClassifierMixin, _OAIMixin):
"""Base class for zero-shot classifiers.

Parameters
----------
openai_key : Optional[str] , default : None
Your OpenAI API key. If None, the key will be read from the SKLLM_CONFIG_OPENAI_KEY environment variable.
openai_org : Optional[str] , default : None
Your OpenAI organization. If None, the organization will be read from the SKLLM_CONFIG_OPENAI_ORG
environment variable.
openai_model : str , default : "gpt-3.5-turbo"
The OpenAI model to use. See https://beta.openai.com/docs/api-reference/available-models for a list of
available models.
default_label : Optional[Union[List[str], str]] , default : 'Random'
The default label to use if the LLM could not generate a response for a sample. If set to 'Random' a random
label will be chosen based on probabilities from the training set.
"""

def __init__(
self,
openai_key: Optional[str] = None,
openai_org: Optional[str] = None,
openai_model: str = "gpt-3.5-turbo",
default_label: Optional[Union[List[str], str]] = 'Random',
):
self._set_keys(openai_key, openai_org)
self.openai_model = openai_model
self.default_label = default_label

def _to_np(self, X):
return _to_numpy(X)

@abstractmethod
def _get_default_label(self):
""" Returns the default label based on the default_label argument. """
raise NotImplementedError()

def fit(
self,
X: Optional[Union[np.ndarray, pd.Series, List[str]]],
Expand Down Expand Up @@ -77,13 +101,31 @@ def _get_chat_completion(self, x):


class ZeroShotGPTClassifier(_BaseZeroShotGPTClassifier):
"""Zero-shot classifier for multiclass classification.

Parameters
----------
openai_key : Optional[str] , default : None
Your OpenAI API key. If None, the key will be read from the SKLLM_CONFIG_OPENAI_KEY environment variable.
openai_org : Optional[str] , default : None
Your OpenAI organization. If None, the organization will be read from the SKLLM_CONFIG_OPENAI_ORG
environment variable.
openai_model : str , default : "gpt-3.5-turbo"
The OpenAI model to use. See https://beta.openai.com/docs/api-reference/available-models for a list of
available models.
default_label : Optional[str] , default : 'Random'
The default label to use if the LLM could not generate a response for a sample. If set to 'Random' a random
label will be chosen based on probabilities from the training set.
"""

def __init__(
self,
openai_key: Optional[str] = None,
openai_org: Optional[str] = None,
openai_model: str = "gpt-3.5-turbo",
default_label: Optional[str] = 'Random',
):
super().__init__(openai_key, openai_org, openai_model)
super().__init__(openai_key, openai_org, openai_model, default_label)

def _extract_labels(self, y: Any) -> List[str]:
if isinstance(y, (pd.Series, np.ndarray)):
Expand All @@ -95,6 +137,13 @@ def _extract_labels(self, y: Any) -> List[str]:
def _get_prompt(self, x) -> str:
return build_zero_shot_prompt_slc(x, repr(self.classes_))

def _get_default_label(self):
""" Returns the default label based on the default_label argument. """
if self.default_label == "Random":
return random.choices(self.classes_, self.probabilities_)[0]
else:
return self.default_label

def _predict_single(self, x):
completion = self._get_chat_completion(x)
try:
Expand All @@ -116,7 +165,7 @@ def _predict_single(self, x):
if label not in self.classes_:
label = label.replace("'", "").replace('"', "")
if label not in self.classes_: # try again
label = random.choices(self.classes_, self.probabilities_)[0]
label = self._get_default_label()
return label

def fit(
Expand All @@ -129,16 +178,37 @@ def fit(


class MultiLabelZeroShotGPTClassifier(_BaseZeroShotGPTClassifier):
"""Zero-shot classifier for multilabel classification.

Parameters
----------
openai_key : Optional[str] , default : None
Your OpenAI API key. If None, the key will be read from the SKLLM_CONFIG_OPENAI_KEY environment variable.
openai_org : Optional[str] , default : None
Your OpenAI organization. If None, the organization will be read from the SKLLM_CONFIG_OPENAI_ORG
environment variable.
openai_model : str , default : "gpt-3.5-turbo"
The OpenAI model to use. See https://beta.openai.com/docs/api-reference/available-models for a list of
available models.
default_label : Optional[Union[List[str], Literal['Random']] , default : 'Random'
The default label to use if the LLM could not generate a response for a sample. If set to 'Random' a random
label will be chosen based on probabilities from the training set.
max_labels : int , default : 3
The maximum number of labels to predict for each sample.
"""
def __init__(
self,
openai_key: Optional[str] = None,
openai_org: Optional[str] = None,
openai_model: str = "gpt-3.5-turbo",
default_label: Optional[Union[List[str], Literal['Random']]] = 'Random',
max_labels: int = 3,
):
super().__init__(openai_key, openai_org, openai_model)
super().__init__(openai_key, openai_org, openai_model, default_label)
if max_labels < 2:
raise ValueError("max_labels should be at least 2")
if isinstance(default_label, str) and default_label != "Random":
raise ValueError("default_label should be a list of strings or 'Random'")
self.max_labels = max_labels

def _extract_labels(self, y) -> List[str]:
Expand All @@ -152,6 +222,19 @@ def _extract_labels(self, y) -> List[str]:
def _get_prompt(self, x) -> str:
return build_zero_shot_prompt_mlc(x, repr(self.classes_), self.max_labels)

def _get_default_label(self):
""" Returns the default label based on the default_label argument. """
result = []
if isinstance(self.default_label, str) and self.default_label == "Random":
for cls, probability in zip(self.classes_, self.probabilities_):
coin_flip = random.choices([0,1], [1-probability, probability])[0]
if coin_flip == 1:
result.append(cls)
else:
result = self.default_label
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that this behaviour might still be a bit confusing for the users. If a string != "Random" is provided instead of a list, the label will again be a string. So, I would still add an additional type check and convert to list whenever applicable.

If you intentionally want to have a flexibility of having non-list outputs, maybe this could be done for default_label = None as a special case (this should be properly documented then). But in my opinion, it would be always better to have a list as an output.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See lines 221-222, The output can be either a list or a None.
They main point of this PR is to facilitate a way to ignore the model predictions when he fails,
that is most commonly achieved by setting label = None

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Nadav-Barak you mean lines 210-211? I must have missed that change.


return result

def _predict_single(self, x):
completion = self._get_chat_completion(x)
try:
Expand All @@ -162,11 +245,10 @@ def _predict_single(self, x):
labels = []

labels = list(filter(lambda l: l in self.classes_, labels))

if len(labels) > self.max_labels:
labels = labels[: self.max_labels - 1]
elif len(labels) < 1:
labels = [random.choices(self.classes_, self.probabilities_)[0]]
if len(labels) == 0:
labels = self._get_default_label()
if labels is not None and len(labels) > self.max_labels:
labels = labels[:self.max_labels - 1]
return labels

def fit(
Expand Down
22 changes: 10 additions & 12 deletions skllm/openai/base_gpt.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
from typing import Any, List, Optional, Union

import numpy as np
from numpy import ndarray
import pandas as pd
from skllm.utils import to_numpy as _to_numpy
from typing import Any, Optional, Union, List
from skllm.openai.mixin import OpenAIMixin as _OAIMixin
from numpy import ndarray
from sklearn.base import BaseEstimator as _BaseEstimator
from sklearn.base import TransformerMixin as _TransformerMixin
from tqdm import tqdm
from sklearn.base import (
BaseEstimator as _BaseEstimator,
TransformerMixin as _TransformerMixin,
)
from skllm.openai.chatgpt import (
construct_message,
get_chat_completion,
)

from skllm.openai.chatgpt import construct_message, get_chat_completion
from skllm.openai.mixin import OpenAIMixin as _OAIMixin
from skllm.utils import to_numpy as _to_numpy


class BaseZeroShotGPTTransformer(_BaseEstimator, _TransformerMixin, _OAIMixin):

Expand Down
1 change: 1 addition & 0 deletions skllm/openai/credentials.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import openai


def set_credentials(key: str, org: str):
openai.api_key = key
openai.organization = org
5 changes: 4 additions & 1 deletion skllm/openai/embeddings.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import openai
from time import sleep

import openai

from skllm.openai.credentials import set_credentials


def get_embedding(
text, key: str, org: str, model="text-embedding-ada-002", max_retries=3
):
Expand Down
2 changes: 2 additions & 0 deletions skllm/openai/mixin.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import Optional

from skllm.config import SKLLMConfig as _Config


class OpenAIMixin:

def _set_keys(self, key: Optional[str] = None, org: Optional[str] = None) -> None:
Expand Down
16 changes: 8 additions & 8 deletions skllm/preprocessing/gpt_vectorizer.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from sklearn.base import (
BaseEstimator as _BaseEstimator,
TransformerMixin as _TransformerMixin,
)
from typing import Any, Optional, Union, List
from tqdm import tqdm
from typing import Any, List, Optional, Union

import numpy as np
from numpy import ndarray
import pandas as pd
from skllm.openai.mixin import OpenAIMixin as _OAIMixin
from numpy import ndarray
from sklearn.base import BaseEstimator as _BaseEstimator
from sklearn.base import TransformerMixin as _TransformerMixin
from tqdm import tqdm

from skllm.openai.embeddings import get_embedding as _get_embedding
from skllm.openai.mixin import OpenAIMixin as _OAIMixin
from skllm.utils import to_numpy as _to_numpy


Expand Down
3 changes: 1 addition & 2 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
from . import test_chatgpt
from . import test_gpt_zero_shot_clf
from . import test_chatgpt, test_gpt_zero_shot_clf
11 changes: 9 additions & 2 deletions tests/test_chatgpt.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
import unittest
from unittest.mock import patch
from skllm.openai.chatgpt import construct_message, get_chat_completion, extract_json_key

from skllm.openai.chatgpt import (
construct_message,
extract_json_key,
get_chat_completion,
)


class TestChatGPT(unittest.TestCase):

Expand All @@ -16,7 +22,8 @@ def test_get_chat_completion(self, mock_create, mock_set_credentials):
result = get_chat_completion(messages, key, org, model)

self.assertTrue(mock_set_credentials.call_count <= 1, "set_credentials should be called at most once")
self.assertEqual(mock_create.call_count, 2, "ChatCompletion.create should be called twice due to an exception on the first call")
self.assertEqual(mock_create.call_count, 2, "ChatCompletion.create should be called twice due to an exception "
"on the first call")
self.assertEqual(result, "success")

def test_construct_message(self):
Expand Down
Loading