added default_label argument and functionality + isort formatting #27
@@ -4,3 +4,5 @@ isort
 ruff
 docformatter
 interrogate
+numpy
+pandas
@@ -1,7 +1,7 @@
 import random
 from abc import ABC, abstractmethod
 from collections import Counter
-from typing import Any, List, Optional, Union
+from typing import Any, List, Optional, Union, Literal

 import numpy as np
 import pandas as pd

@@ -19,18 +19,42 @@
 class _BaseZeroShotGPTClassifier(ABC, BaseEstimator, ClassifierMixin, _OAIMixin):
     """Base class for zero-shot classifiers.

+    Parameters
+    ----------
+    openai_key : Optional[str] , default : None
+        Your OpenAI API key. If None, the key will be read from the SKLLM_CONFIG_OPENAI_KEY environment variable.
+    openai_org : Optional[str] , default : None
+        Your OpenAI organization. If None, the organization will be read from the SKLLM_CONFIG_OPENAI_ORG
+        environment variable.
+    openai_model : str , default : "gpt-3.5-turbo"
+        The OpenAI model to use. See https://beta.openai.com/docs/api-reference/available-models for a list of
+        available models.
+    default_label : Optional[Union[List[str], str]] , default : 'Random'
+        The default label to use if the LLM could not generate a response for a sample. If set to 'Random' a random
+        label will be chosen based on probabilities from the training set.
     """

     def __init__(
         self,
         openai_key: Optional[str] = None,
         openai_org: Optional[str] = None,
         openai_model: str = "gpt-3.5-turbo",
+        default_label: Optional[Union[List[str], str]] = 'Random',
     ):
         self._set_keys(openai_key, openai_org)
         self.openai_model = openai_model
+        self.default_label = default_label

     def _to_np(self, X):
         return _to_numpy(X)

+    @abstractmethod
+    def _get_default_label(self):
+        """ Returns the default label based on the default_label argument. """
+        raise NotImplementedError()
+
     def fit(
         self,
         X: Optional[Union[np.ndarray, pd.Series, List[str]]],
@@ -77,13 +101,31 @@ def _get_chat_completion(self, x):

 class ZeroShotGPTClassifier(_BaseZeroShotGPTClassifier):
     """Zero-shot classifier for multiclass classification.

+    Parameters
+    ----------
+    openai_key : Optional[str] , default : None
+        Your OpenAI API key. If None, the key will be read from the SKLLM_CONFIG_OPENAI_KEY environment variable.
+    openai_org : Optional[str] , default : None
+        Your OpenAI organization. If None, the organization will be read from the SKLLM_CONFIG_OPENAI_ORG
+        environment variable.
+    openai_model : str , default : "gpt-3.5-turbo"
+        The OpenAI model to use. See https://beta.openai.com/docs/api-reference/available-models for a list of
+        available models.
+    default_label : Optional[str] , default : 'Random'
+        The default label to use if the LLM could not generate a response for a sample. If set to 'Random' a random
+        label will be chosen based on probabilities from the training set.
     """

     def __init__(
         self,
         openai_key: Optional[str] = None,
         openai_org: Optional[str] = None,
         openai_model: str = "gpt-3.5-turbo",
+        default_label: Optional[str] = 'Random',
     ):
-        super().__init__(openai_key, openai_org, openai_model)
+        super().__init__(openai_key, openai_org, openai_model, default_label)

     def _extract_labels(self, y: Any) -> List[str]:
         if isinstance(y, (pd.Series, np.ndarray)):

@@ -95,6 +137,13 @@ def _extract_labels(self, y: Any) -> List[str]:
     def _get_prompt(self, x) -> str:
         return build_zero_shot_prompt_slc(x, repr(self.classes_))

+    def _get_default_label(self):
+        """ Returns the default label based on the default_label argument. """
+        if self.default_label == "Random":
+            return random.choices(self.classes_, self.probabilities_)[0]
+        else:
+            return self.default_label
+
     def _predict_single(self, x):
         completion = self._get_chat_completion(x)
         try:

@@ -116,7 +165,7 @@ def _predict_single(self, x):
         if label not in self.classes_:
             label = label.replace("'", "").replace('"', "")
             if label not in self.classes_:  # try again
-                label = random.choices(self.classes_, self.probabilities_)[0]
+                label = self._get_default_label()
         return label

     def fit(
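For orientation, here is a rough usage sketch of the new argument on the single-label classifier; the top-level import path and the toy texts/labels are assumptions for illustration, not taken from this PR:

from skllm import ZeroShotGPTClassifier  # assumed import path

X = ["The film was wonderful", "The service was awful", "It was fine, nothing special"]
y = ["positive", "negative", "neutral"]

# With default_label=None, a sample the LLM could not label is returned as None
# instead of a class drawn at random from the training-label distribution.
clf = ZeroShotGPTClassifier(openai_model="gpt-3.5-turbo", default_label=None)
clf.fit(X, y)
print(clf.predict(["An average experience"]))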
@@ -129,16 +178,37 @@ def fit(

 class MultiLabelZeroShotGPTClassifier(_BaseZeroShotGPTClassifier):
     """Zero-shot classifier for multilabel classification.

+    Parameters
+    ----------
+    openai_key : Optional[str] , default : None
+        Your OpenAI API key. If None, the key will be read from the SKLLM_CONFIG_OPENAI_KEY environment variable.
+    openai_org : Optional[str] , default : None
+        Your OpenAI organization. If None, the organization will be read from the SKLLM_CONFIG_OPENAI_ORG
+        environment variable.
+    openai_model : str , default : "gpt-3.5-turbo"
+        The OpenAI model to use. See https://beta.openai.com/docs/api-reference/available-models for a list of
+        available models.
+    default_label : Optional[Union[List[str], Literal['Random']]] , default : 'Random'
+        The default label to use if the LLM could not generate a response for a sample. If set to 'Random' a random
+        label will be chosen based on probabilities from the training set.
+    max_labels : int , default : 3
+        The maximum number of labels to predict for each sample.
     """
     def __init__(
         self,
         openai_key: Optional[str] = None,
         openai_org: Optional[str] = None,
         openai_model: str = "gpt-3.5-turbo",
+        default_label: Optional[Union[List[str], Literal['Random']]] = 'Random',
         max_labels: int = 3,
     ):
-        super().__init__(openai_key, openai_org, openai_model)
+        super().__init__(openai_key, openai_org, openai_model, default_label)
         if max_labels < 2:
             raise ValueError("max_labels should be at least 2")
+        if isinstance(default_label, str) and default_label != "Random":
+            raise ValueError("default_label should be a list of strings or 'Random'")
         self.max_labels = max_labels

     def _extract_labels(self, y) -> List[str]:

@@ -152,6 +222,19 @@ def _extract_labels(self, y) -> List[str]:
     def _get_prompt(self, x) -> str:
         return build_zero_shot_prompt_mlc(x, repr(self.classes_), self.max_labels)

+    def _get_default_label(self):
+        """ Returns the default label based on the default_label argument. """
+        result = []
+        if isinstance(self.default_label, str) and self.default_label == "Random":
+            for cls, probability in zip(self.classes_, self.probabilities_):
+                coin_flip = random.choices([0, 1], [1 - probability, probability])[0]
+                if coin_flip == 1:
+                    result.append(cls)
+        else:
+            result = self.default_label
Collaborator:
I think that this behaviour might still be a bit confusing for the users. If a string != "Random" is provided instead of a list, the label will again be a string. So, I would still add an additional type check and convert to list whenever applicable. If you intentionally want to have a flexibility of having non-list outputs, maybe this could be done for

Contributor (Author):
See lines 221-222; the output can be either a list or None.

Collaborator:
@Nadav-Barak you mean lines 210-211? I must have missed that change.
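To make the suggestion above concrete, here is a minimal sketch of the kind of normalization the reviewer describes; the helper name is hypothetical and not part of this PR:

# Hypothetical helper illustrating the reviewer's suggestion (not part of this PR):
# coerce a plain-string default_label into a one-element list so the multilabel
# classifier's fallback is always list-typed.
def _normalize_default_label(default_label):
    if isinstance(default_label, str) and default_label != "Random":
        return [default_label]
    return default_label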
+        return result
+
     def _predict_single(self, x):
         completion = self._get_chat_completion(x)
         try:

@@ -162,11 +245,10 @@ def _predict_single(self, x):
             labels = []

         labels = list(filter(lambda l: l in self.classes_, labels))

-        if len(labels) > self.max_labels:
-            labels = labels[: self.max_labels - 1]
-        elif len(labels) < 1:
-            labels = [random.choices(self.classes_, self.probabilities_)[0]]
+        if len(labels) == 0:
+            labels = self._get_default_label()
+        if labels is not None and len(labels) > self.max_labels:
+            labels = labels[:self.max_labels - 1]
         return labels

     def fit(
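Similarly, a hedged sketch of the multilabel variant's fallback behaviour; the import path and the example data are assumptions for illustration:

from skllm import MultiLabelZeroShotGPTClassifier  # assumed import path

X = ["The food was great but delivery was slow"]
y = [["food", "delivery"]]

# A list default_label is returned as-is when the LLM yields no usable labels,
# None disables the fallback entirely, and 'Random' draws a per-class coin flip
# weighted by the training-set label probabilities.
clf = MultiLabelZeroShotGPTClassifier(max_labels=3, default_label=["unknown"])
clf.fit(X, y)
print(clf.predict(["Completely unrelated text"]))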
@@ -1,5 +1,6 @@
 import openai

+
 def set_credentials(key: str, org: str):
     openai.api_key = key
     openai.organization = org
@@ -1,2 +1 @@
-from . import test_chatgpt
-from . import test_gpt_zero_shot_clf
+from . import test_chatgpt, test_gpt_zero_shot_clf