diff --git a/pyproject.toml b/pyproject.toml
index bf1dd45..2818764 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -55,6 +55,8 @@ ignore = [
     "D205",
     "D401",
     "E501",
+    "N803",
+    "N806",
 ]
 extend-exclude = ["tests/*.py", "setup.py"]
 target-version = "py38"
diff --git a/skllm/models/gpt_zero_shot_clf.py b/skllm/models/gpt_zero_shot_clf.py
index 00354ec..b57f700 100644
--- a/skllm/models/gpt_zero_shot_clf.py
+++ b/skllm/models/gpt_zero_shot_clf.py
@@ -1,20 +1,25 @@
-from typing import Optional, Union, List, Any
-import numpy as np
-import pandas as pd
-from collections import Counter
 import random
-from tqdm import tqdm
 from abc import ABC, abstractmethod
+from collections import Counter
+from typing import Any, List, Optional, Union
+
+import numpy as np
+import pandas as pd
 from sklearn.base import BaseEstimator, ClassifierMixin
-from skllm.openai.prompts import get_zero_shot_prompt_slc, get_zero_shot_prompt_mlc
+from tqdm import tqdm
+
 from skllm.openai.chatgpt import (
     construct_message,
-    get_chat_completion,
     extract_json_key,
+    get_chat_completion,
 )
-from skllm.config import SKLLMConfig as _Config
-from skllm.utils import to_numpy as _to_numpy
 from skllm.openai.mixin import OpenAIMixin as _OAIMixin
+from skllm.prompts.builders import (
+    build_zero_shot_prompt_mlc,
+    build_zero_shot_prompt_slc,
+)
+from skllm.utils import to_numpy as _to_numpy
+
 
 class _BaseZeroShotGPTClassifier(ABC, BaseEstimator, ClassifierMixin, _OAIMixin):
     def __init__(
@@ -34,7 +39,7 @@ def fit(
         X: Optional[Union[np.ndarray, pd.Series, List[str]]],
         y: Union[np.ndarray, pd.Series, List[str], List[List[str]]],
     ):
-        X = self._to_np(X) 
+        X = self._to_np(X)
         self.classes_, self.probabilities_ = self._get_unique_targets(y)
         return self
 
@@ -91,7 +96,7 @@ def _extract_labels(self, y: Any) -> List[str]:
         return labels
 
     def _get_prompt(self, x) -> str:
-        return get_zero_shot_prompt_slc(x, self.classes_)
+        return build_zero_shot_prompt_slc(x, repr(self.classes_))
 
     def _predict_single(self, x):
         completion = self._get_chat_completion(x)
@@ -99,7 +104,7 @@ def _predict_single(self, x):
             label = str(
                 extract_json_key(completion.choices[0].message["content"], "label")
             )
-        except Exception as e:
+        except Exception:
             label = ""
 
         if label not in self.classes_:
@@ -137,7 +142,7 @@ def _extract_labels(self, y) -> List[str]:
         return labels
 
     def _get_prompt(self, x) -> str:
-        return get_zero_shot_prompt_mlc(x, self.classes_, self.max_labels)
+        return build_zero_shot_prompt_mlc(x, repr(self.classes_), self.max_labels)
 
     def _predict_single(self, x):
         completion = self._get_chat_completion(x)
@@ -145,7 +150,7 @@ def _predict_single(self, x):
             labels = extract_json_key(completion.choices[0].message["content"], "label")
             if not isinstance(labels, list):
                 raise RuntimeError("Invalid labels type, expected list")
-        except Exception as e:
+        except Exception:
             labels = []
 
         labels = list(filter(lambda l: l in self.classes_, labels))
diff --git a/skllm/openai/prompts.py b/skllm/openai/prompts.py
deleted file mode 100644
index 63a7198..0000000
--- a/skllm/openai/prompts.py
+++ /dev/null
@@ -1,50 +0,0 @@
-def get_zero_shot_prompt_slc(x, labels):
-    lines = [
-        "You will be provided with the following information:",
-        "1. An arbitrary text sample. The sample is delimited with triple backticks.",
-        "2. List of categories the text sample can be assigned to. The list is delimited with square brackets. The categories in the list are enclosed in the single quotes and comma separated.",
-        "",
-        "Perform the following tasks:",
-        "1. Identify to which category the provided text belongs to with the highest probability.",
-        "2. Assign the provided text to that category.",
-        "3. Provide your response in a JSON format containing a single key `label` and a value corresponding to the assigned category. Do not provide any additional information except the JSON.",
-        "\n",
-        f"List of categories: {repr(labels)}",
-        "\n",
-        f"Text sample: ```{x}```",
-        "\n",
-        "Your JSON response: "
-    ]
-    prompt = "\n".join(lines)
-    return prompt
-
-def get_zero_shot_prompt_mlc(x, labels, max_cats):
-    lines = [
-        "You will be provided with the following information:",
-        "1. An arbitrary text sample. The sample is delimited with triple backticks.",
-        f"2. List of categories the text sample can be assigned to. The list is delimited with square brackets. The categories in the list are enclosed in the single quotes and comma separated. The text sample belongs to at least one category but cannot exceed {max_cats}.",
-        "",
-        "Perform the following tasks:",
-        "1. Identify to which categories the provided text belongs to with the highest probability.",
-        f"2. Assign the text sample to at least 1 but up to {max_cats} categories based on the probabilities.",
-        "3. Provide your response in a JSON format containing a single key `label` and a value corresponding to the list of assigned categories. Do not provide any additional information except the JSON.",
-        "\n",
-        f"List of categories: {repr(labels)}",
-        "\n",
-        f"Text sample: ```{x}```",
-        "\n",
-        "Your JSON response: "
-    ]
-    prompt = "\n".join(lines)
-    return prompt
-
-def get_summary_prompt(x, max_words):
-    lines = [
-        "Your task is to generate a summary of the text sample.",
-        f"Summarize the text sample provided below, delimited by triple backticks, in at most {max_words} words.",
-        "\n",
-        f"Text sample: ```{x}```",
-        f"Summarized text:"
-    ]
-    prompt = "\n".join(lines)
-    return prompt
\ No newline at end of file
diff --git a/skllm/preprocessing/gpt_summarizer.py b/skllm/preprocessing/gpt_summarizer.py
index 88d7c26..8804275 100644
--- a/skllm/preprocessing/gpt_summarizer.py
+++ b/skllm/preprocessing/gpt_summarizer.py
@@ -1,9 +1,10 @@
-from skllm.openai.prompts import get_summary_prompt
 from typing import Optional
+
 from skllm.openai.base_gpt import BaseZeroShotGPTTransformer as _BaseGPT
+from skllm.prompts.builders import build_summary_prompt
 
 
-class GPTSummarizer(_BaseGPT): 
+class GPTSummarizer(_BaseGPT):
     system_msg = "You are a text summarizer."
     default_output = "Summary is unavailable."
 
@@ -17,6 +18,6 @@ def __init__(
         self._set_keys(openai_key, openai_org)
         self.openai_model = openai_model
         self.max_words = max_words
-    
+
     def _get_prompt(self, X) -> str:
-        return get_summary_prompt(X, self.max_words)
\ No newline at end of file
+        return build_summary_prompt(X, self.max_words)
diff --git a/skllm/prompts/builders.py b/skllm/prompts/builders.py
new file mode 100644
index 0000000..c9d7418
--- /dev/null
+++ b/skllm/prompts/builders.py
@@ -0,0 +1,80 @@
+from typing import Union
+
+from skllm.prompts.templates import (
+    SUMMARY_PROMPT_TEMPLATE,
+    ZERO_SHOT_CLF_PROMPT_TEMPLATE,
+    ZERO_SHOT_MLCLF_PROMPT_TEMPLATE,
+)
+
+# TODO add validators
+
+
+def build_zero_shot_prompt_slc(
+    x: str, labels: str, template: str = ZERO_SHOT_CLF_PROMPT_TEMPLATE
+) -> str:
+    """Builds a prompt for zero-shot single-label classification.
+
+    Parameters
+    ----------
+    x : str
+        sample to classify
+    labels : str
+        candidate labels in a list-like representation
+    template : str
+        prompt template to use, must contain placeholders for all variables, by default ZERO_SHOT_CLF_PROMPT_TEMPLATE
+
+    Returns
+    -------
+    str
+        prepared prompt
+    """
+    return template.format(x=x, labels=labels)
+
+
+def build_zero_shot_prompt_mlc(
+    x: str,
+    labels: str,
+    max_cats: Union[int, str],
+    template: str = ZERO_SHOT_MLCLF_PROMPT_TEMPLATE,
+) -> str:
+    """Builds a prompt for zero-shot multi-label classification.
+
+    Parameters
+    ----------
+    x : str
+        sample to classify
+    labels : str
+        candidate labels in a list-like representation
+    max_cats : Union[int,str]
+        maximum number of categories to assign
+    template : str
+        prompt template to use, must contain placeholders for all variables, by default ZERO_SHOT_MLCLF_PROMPT_TEMPLATE
+
+    Returns
+    -------
+    str
+        prepared prompt
+    """
+    return template.format(x=x, labels=labels, max_cats=max_cats)
+
+
+def build_summary_prompt(
+    x: str, max_words: Union[int, str], template: str = SUMMARY_PROMPT_TEMPLATE
+) -> str:
+    """Builds a prompt for text summarization.
+
+    Parameters
+    ----------
+    x : str
+        sample to summarize
+    max_words : Union[int,str]
+        maximum number of words to use in the summary
+    template : str
+        prompt template to use, must contain placeholders for all variables, by default SUMMARY_PROMPT_TEMPLATE
+
+    Returns
+    -------
+    str
+        prepared prompt
+    """
+    return template.format(x=x, max_words=max_words)
diff --git a/skllm/prompts/templates.py b/skllm/prompts/templates.py
new file mode 100644
index 0000000..18785c0
--- /dev/null
+++ b/skllm/prompts/templates.py
@@ -0,0 +1,41 @@
+ZERO_SHOT_CLF_PROMPT_TEMPLATE = """
+You will be provided with the following information:
+1. An arbitrary text sample. The sample is delimited with triple backticks.
+2. List of categories the text sample can be assigned to. The list is delimited with square brackets. The categories in the list are enclosed in the single quotes and comma separated.
+
+Perform the following tasks:
+1. Identify to which category the provided text belongs to with the highest probability.
+2. Assign the provided text to that category.
+3. Provide your response in a JSON format containing a single key `label` and a value corresponding to the assigned category. Do not provide any additional information except the JSON.
+
+List of categories: {labels}
+
+Text sample: ```{x}```
+
+Your JSON response:
+"""
+
+ZERO_SHOT_MLCLF_PROMPT_TEMPLATE = """
+You will be provided with the following information:
+1. An arbitrary text sample. The sample is delimited with triple backticks.
+2. List of categories the text sample can be assigned to. The list is delimited with square brackets. The categories in the list are enclosed in the single quotes and comma separated. The text sample belongs to at least one category but cannot exceed {max_cats}.
+
+Perform the following tasks:
+1. Identify to which categories the provided text belongs to with the highest probability.
+2. Assign the text sample to at least 1 but up to {max_cats} categories based on the probabilities.
+3. Provide your response in a JSON format containing a single key `label` and a value corresponding to the list of assigned categories. Do not provide any additional information except the JSON.
+
+List of categories: {labels}
+
+Text sample: ```{x}```
+
+Your JSON response:
+"""
+
+SUMMARY_PROMPT_TEMPLATE = """
+Your task is to generate a summary of the text sample.
+Summarize the text sample provided below, delimited by triple backticks, in at most {max_words} words.
+
+Text sample: ```{x}```
+Summarized text:
+"""
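
A minimal usage sketch of the new builder/template pair introduced by this patch (not part of the diff; the sample text and label list below are invented for illustration):

from skllm.prompts.builders import build_zero_shot_prompt_slc
from skllm.prompts.templates import ZERO_SHOT_CLF_PROMPT_TEMPLATE

# Hypothetical inputs, for illustration only.
sample = "The new GPU delivers twice the performance of its predecessor."
labels = repr(["hardware", "software", "sports"])  # classifiers pass repr(self.classes_), i.e. a list-like string

# The builder simply fills the template's {x} and {labels} placeholders.
prompt = build_zero_shot_prompt_slc(sample, labels, template=ZERO_SHOT_CLF_PROMPT_TEMPLATE)
print(prompt)  # filled-in prompt, ready to send to the chat model

Passing the template explicitly is optional here, since ZERO_SHOT_CLF_PROMPT_TEMPLATE is already the default; a custom template only needs to expose the same placeholders.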