diff --git a/.gitignore b/.gitignore
index 707cff7..8b6a6c8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -167,3 +167,4 @@ tmp.py
 # vscode
 .vscode/
 tmp2.py
+tmp.*
\ No newline at end of file
diff --git a/skllm/models/_base/tagger.py b/skllm/models/_base/tagger.py
new file mode 100644
index 0000000..2bb90aa
--- /dev/null
+++ b/skllm/models/_base/tagger.py
@@ -0,0 +1,152 @@
+from typing import Any, Union, List, Optional, Dict
+from abc import abstractmethod, ABC
+from numpy import ndarray
+from tqdm import tqdm
+import numpy as np
+import pandas as pd
+from skllm.utils import to_numpy as _to_numpy
+from sklearn.base import (
+    BaseEstimator as _SklBaseEstimator,
+    TransformerMixin as _SklTransformerMixin,
+)
+
+from skllm.utils.rendering import display_ner
+from skllm.utils.xml import filter_xml_tags, filter_unwanted_entities, json_to_xml
+from skllm.prompts.builders import build_ner_prompt
+from skllm.prompts.templates import (
+    NER_SYSTEM_MESSAGE_TEMPLATE,
+    NER_SYSTEM_MESSAGE_SPARSE,
+    EXPLAINABLE_NER_DENSE_PROMPT_TEMPLATE,
+    EXPLAINABLE_NER_SPARSE_PROMPT_TEMPLATE,
+)
+from skllm.utils import re_naive_json_extractor
+import json
+from concurrent.futures import ThreadPoolExecutor
+
+
+class BaseTagger(ABC, _SklBaseEstimator, _SklTransformerMixin):
+
+    num_workers = 1
+
+    def fit(self, X: Any, y: Any = None):
+        """
+        Fits the model to the data. Usually a no-op.
+
+        Parameters
+        ----------
+        X : Any
+            training data
+        y : Any
+            training outputs
+
+        Returns
+        -------
+        self
+            BaseTagger
+        """
+        return self
+
+    def predict(self, X: Union[np.ndarray, pd.Series, List[str]]):
+        return self.transform(X)
+
+    def fit_transform(
+        self,
+        X: Union[np.ndarray, pd.Series, List[str]],
+        y: Optional[Union[np.ndarray, pd.Series, List[str]]] = None,
+    ) -> ndarray:
+        return self.fit(X, y).transform(X)
+
+    def transform(self, X: Union[np.ndarray, pd.Series, List[str]]):
+        """
+        Transforms the input data.
+
+        Parameters
+        ----------
+        X : Union[np.ndarray, pd.Series, List[str]]
+            The input data to tag.
+
+        Returns
+        -------
+        ndarray
+        """
+        X = _to_numpy(X)
+        predictions = []
+        with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
+            predictions = list(tqdm(executor.map(self._predict_single, X), total=len(X)))
+        return np.asarray(predictions)
+
+    def _predict_single(self, x: Any) -> Any:
+        prompt_dict = self._get_prompt(x)
+        # _get_chat_completion is provided by the concrete LLM mixin
+        prediction = self._get_chat_completion(model=self.model, **prompt_dict)
+        prediction = self._convert_completion_to_str(prediction)
+        return prediction
+
+    @abstractmethod
+    def _get_prompt(self, x: str) -> dict:
+        """Returns the prompt to use for a single input."""
+        pass
+
+
+class ExplainableNER(BaseTagger):
+    entities: Optional[Dict[str, str]] = None
+    sparse_output = True
+    _allowed_tags = ["entity", "not_entity"]
+
+    display_predictions = False
+
+    def fit(self, X: Any, y: Any = None):
+        entities = []
+        for k, v in self.entities.items():
+            entities.append({"entity": k, "definition": v})
+        self.expanded_entities_ = entities
+        return self
+
+    def transform(self, X: Union[np.ndarray, pd.Series, List[str]]):
+        predictions = super().transform(X)
+        if self.sparse_output:
+            json_predictions = [
+                re_naive_json_extractor(p, expected_output="array") for p in predictions
+            ]
+            predictions = []
+            attributes = ["reasoning", "tag", "value"]
+            for x, p in zip(X, json_predictions):
+                p_json = json.loads(p)
+                predictions.append(
+                    json_to_xml(
+                        x,
+                        p_json,
+                        "entity",
+                        "not_entity",
+                        value_key="value",
+                        attributes=attributes,
+                    )
+                )
+
+        predictions = [
+            filter_unwanted_entities(
+                filter_xml_tags(p, self._allowed_tags), self.entities
+            )
+            for p in predictions
+        ]
+        if self.display_predictions:
+            print("Displaying predictions...")
+            display_ner(predictions, self.entities)
+        return predictions
+
+    def _get_prompt(self, x: str) -> dict:
+        if not hasattr(self, "expanded_entities_"):
+            raise ValueError("Model not fitted.")
+        # the XML-tagging system message only applies to dense output;
+        # sparse (JSON) mode uses the generic system message
+        system_message = (
+            NER_SYSTEM_MESSAGE_SPARSE
+            if self.sparse_output
+            else NER_SYSTEM_MESSAGE_TEMPLATE
+        )
+        template = (
+            EXPLAINABLE_NER_SPARSE_PROMPT_TEMPLATE
+            if self.sparse_output
+            else EXPLAINABLE_NER_DENSE_PROMPT_TEMPLATE
+        )
+        return {
+            "messages": build_ner_prompt(self.expanded_entities_, x, template=template),
+            "system_message": system_message.format(entities=self.entities.keys()),
+        }
diff --git a/skllm/models/gpt/tagging/ner.py b/skllm/models/gpt/tagging/ner.py
new file mode 100644
index 0000000..98857a8
--- /dev/null
+++ b/skllm/models/gpt/tagging/ner.py
@@ -0,0 +1,42 @@
+from skllm.models._base.tagger import ExplainableNER as _ExplainableNER
+from skllm.llm.gpt.mixin import GPTTextCompletionMixin as _GPTTextCompletionMixin
+from typing import Optional, Dict
+
+
+class GPTExplainableNER(_ExplainableNER, _GPTTextCompletionMixin):
+    def __init__(
+        self,
+        entities: Dict[str, str],
+        display_predictions: bool = False,
+        sparse_output: bool = True,
+        model: str = "gpt-4o",
+        key: Optional[str] = None,
+        org: Optional[str] = None,
+        num_workers: int = 1,
+    ) -> None:
+        """
+        Named entity recognition using OpenAI/GPT API-compatible models.
+
+        Parameters
+        ----------
+        entities : dict
+            dictionary of entities to recognize, with keys as entity names and values as descriptions
+        display_predictions : bool, optional
+            whether to display predictions, by default False
+        sparse_output : bool, optional
+            whether to generate a sparse representation of the predictions, by default True
+        model : str, optional
+            model to use, by default "gpt-4o"
+        key : Optional[str], optional
+            estimator-specific API key; if None, retrieved from the global config, by default None
+        org : Optional[str], optional
+            estimator-specific ORG key; if None, retrieved from the global config, by default None
+        num_workers : int, optional
+            number of workers (threads) to use, by default 1
+        """
+        self._set_keys(key, org)
+        self.model = model
+        self.entities = entities
+        self.display_predictions = display_predictions
+        self.sparse_output = sparse_output
+        self.num_workers = num_workers
\ No newline at end of file
diff --git a/skllm/prompts/builders.py b/skllm/prompts/builders.py
index 823c20a..e24ae72 100644
--- a/skllm/prompts/builders.py
+++ b/skllm/prompts/builders.py
@@ -1,4 +1,4 @@
-from typing import Union
+from typing import Union, Optional
 
 from skllm.prompts.templates import (
     FEW_SHOT_CLF_PROMPT_TEMPLATE,
@@ -8,6 +8,7 @@
     TRANSLATION_PROMPT_TEMPLATE,
     ZERO_SHOT_CLF_PROMPT_TEMPLATE,
     ZERO_SHOT_MLCLF_PROMPT_TEMPLATE,
+    EXPLAINABLE_NER_DENSE_PROMPT_TEMPLATE,
 )
 
 # TODO add validators
@@ -190,3 +191,27 @@ def build_translation_prompt(
         prepared prompt
     """
     return template.format(x=x, output_language=output_language)
+
+
+def build_ner_prompt(
+    entities: list,
+    x: str,
+    template: str = EXPLAINABLE_NER_DENSE_PROMPT_TEMPLATE,
+) -> str:
+    """Builds a prompt for named entity recognition.
+
+    Parameters
+    ----------
+    entities : list
+        list of entities to recognize
+    x : str
+        sample to recognize entities in
+    template : str, optional
+        prompt template to use, must contain placeholders for all variables, by default EXPLAINABLE_NER_DENSE_PROMPT_TEMPLATE
+
+    Returns
+    -------
+    str
+        prepared prompt
+    """
+    return template.format(entities=entities, x=x)
diff --git a/skllm/prompts/templates.py b/skllm/prompts/templates.py
index 552230f..96334b7 100644
--- a/skllm/prompts/templates.py
+++ b/skllm/prompts/templates.py
@@ -118,3 +118,90 @@
 Original text: ```{x}```
 Output:
 """
+
+NER_SYSTEM_MESSAGE_TEMPLATE = """You are an expert in Natural Language Processing. Your task is to identify common Named Entities (NER) in a text provided by the user.
+Mark the entities with tags according to the following guidelines:
+    - Use XML format to tag entities;
+    - All entities must be enclosed in <entity>...</entity> tags; All other text must be enclosed in <not_entity>...</not_entity> tags; No content should be outside of these tags;
+    - The tagging operation must be invertible, i.e. the original text must be recoverable from the tagged text; This is crucial and easy to overlook, double-check this requirement;
+    - Adjacent entities should be separated into different tags;
+    - The list of entities is strictly restricted to the following: {entities}.
+"""
+
+NER_SYSTEM_MESSAGE_SPARSE = """You are an expert in Natural Language Processing."""
+
+EXPLAINABLE_NER_DENSE_PROMPT_TEMPLATE = """You are provided with a text. Your task is to identify and tag all named entities within the text using the following entity types only:
+{entities}
+
+For each entity, provide a brief explanation for your choice within the <reasoning> tag. Use the following XML tag format for each entity:
+
+<entity><reasoning>Your reasoning here</reasoning><tag>ENTITY_NAME_UPPERCASE</tag><value>Entity text</value></entity>
+
+The remaining text must be enclosed in a <not_entity> tag.
+
+Focus on the context and meaning of each entity rather than just the exact words. The tags should encompass the entire entity based on its definition and usage in the sentence. It is crucial to base your decision on the description of the entity, not just its name.
+
+Format example:
+
+Input:
+```This text contains some entity and another entity.```
+
+Output:
+```xml
+<not_entity>This text contains </not_entity><entity><reasoning>some justification</reasoning><tag>ENTITY1</tag><value>some entity</value></entity><not_entity> and another </not_entity><entity><reasoning>another justification</reasoning><tag>ENTITY2</tag><value>entity</value></entity><not_entity>.</not_entity>
+```
+
+Input:
+```
+{x}
+```
+
+Output (original text with tags):
+"""
+
+
+EXPLAINABLE_NER_SPARSE_PROMPT_TEMPLATE = """You are provided with a text. Your task is to identify and tag all named entities within the text using the following entity types only:
+{entities}
+
+You must provide the following information for each entity:
+- The reasoning why you tagged the entity as such; Based on the reasoning, a non-expert should be able to evaluate your decision;
+- The tag of the entity (uppercase);
+- The value of the entity (as it appears in the text).
+
+Your response should be JSON formatted using the following schema:
+
+{{
+    "$schema": "http://json-schema.org/draft-04/schema#",
+    "type": "array",
+    "items": [
+        {{
+            "type": "object",
+            "properties": {{
+                "reasoning": {{
+                    "type": "string"
+                }},
+                "tag": {{
+                    "type": "string"
+                }},
+                "value": {{
+                    "type": "string"
+                }}
+            }},
+            "required": [
+                "reasoning",
+                "tag",
+                "value"
+            ]
+        }}
+    ]
+}}
+
+
+Input:
+```
+{x}
+```
+
+Output JSON:
+"""
+
diff --git a/skllm/utils.py b/skllm/utils/__init__.py
similarity index 75%
rename from skllm/utils.py
rename to skllm/utils/__init__.py
index ca6e968..29436ac 100644
--- a/skllm/utils.py
+++ b/skllm/utils/__init__.py
@@ -4,7 +4,7 @@
 import pandas as pd
 from functools import wraps
 from time import sleep
-
+import re
 
 def to_numpy(X: Any) -> np.ndarray:
     """Converts a pandas Series or list to a numpy array.
@@ -23,10 +23,11 @@
     elif isinstance(X, list):
         X = np.asarray(X, dtype=object)
     if isinstance(X, np.ndarray) and len(X.shape) > 1:
-        X = np.squeeze(X)
+        # do not squeeze the first dim
+        X = np.squeeze(X, axis=tuple([i for i in range(1, len(X.shape))]))
     return X
 
-
+# TODO: replace with re version below
 def find_json_in_string(string: str) -> str:
     """Finds the JSON object in a string.
@@ -48,6 +49,30 @@
     return json_string
 
 
+def re_naive_json_extractor(json_string: str, expected_output: str = "object") -> str:
+    """Finds the first JSON-like object or array in a string using regex.
+
+    Parameters
+    ----------
+    json_string : str
+        The string to search for a JSON object or array.
+    expected_output : str, optional
+        The expected JSON type ("object" or "array"); determines the fallback value, by default "object"
+
+    Returns
+    -------
+    json_string : str
+        A JSON string if found, otherwise an empty JSON object or array.
+    """
+    json_pattern = r"(\{.*\}|\[.*\])"
+    match = re.search(json_pattern, json_string, re.DOTALL)
+    if match:
+        return match.group(0)
+    else:
+        return r"{}" if expected_output == "object" else "[]"
+
+
 def extract_json_key(json_: str, key: str):
     """Extracts JSON key from a string.
diff --git a/skllm/utils/rendering.py b/skllm/utils/rendering.py
new file mode 100644
index 0000000..d9f4289
--- /dev/null
+++ b/skllm/utils/rendering.py
@@ -0,0 +1,139 @@
+import random
+import re
+import html
+from typing import Dict, List
+
+color_palettes = {
+    "light": [
+        "lightblue",
+        "lightgreen",
+        "lightcoral",
+        "lightsalmon",
+        "lightyellow",
+        "lightpink",
+        "lightgray",
+        "lightcyan",
+    ],
+    "dark": [
+        "darkblue",
+        "darkgreen",
+        "darkred",
+        "darkorange",
+        "darkgoldenrod",
+        "darkmagenta",
+        "darkgray",
+        "darkcyan",
+    ],
+}
+
+
+def get_random_color():
+    return f"#{random.randint(0, 0xFFFFFF):06x}"
+
+
+# def validate_text(input_text, output_text):
+#     # Verify the original text was not changed (other than addition of tags)
+#     stripped_output_text = re.sub(r'<<.*?>>', '', output_text)
+#     stripped_output_text = re.sub(r'<>', '', stripped_output_text)
+#     if not all(word in stripped_output_text.split() for word in input_text.split()):
+#         raise ValueError("Original text was altered.")
+#     return True
+
+
+# TODO: In the future this should probably be replaced with a proper HTML template
+def render_ner(output_texts, allowed_entities):
+    entity_colors = {}
+    all_entities = [k.upper() for k in allowed_entities.keys()]
+
+    for i, entity in enumerate(all_entities):
+        if i < len(color_palettes["light"]):
+            entity_colors[entity] = {
+                "light": color_palettes["light"][i],
+                "dark": color_palettes["dark"][i],
+            }
+        else:
+            random_color = get_random_color()
+            entity_colors[entity] = {"light": random_color, "dark": random_color}
+
+    def replace_match(match):
+        reasoning, entity, text = match.groups()
+        entity = entity.upper()
+        # highlight the entity value and expose the reasoning as a hover tooltip
+        return (
+            f'<span style="background-color: {entity_colors[entity]["light"]}" '
+            f'title="{html.escape(reasoning)}">{text}</span>'
+        )
+
+    legend_html = "<div>"
+    legend_html += "<b>Entities:</b> "
+    for entity in entity_colors.keys():
+        description = allowed_entities.get(entity, "No description")
+        legend_html += (
+            f'<span style="background-color: {entity_colors[entity]["light"]}" '
+            f'title="{html.escape(description)}">{entity}</span> '
+        )
+    legend_html += "</div><br>"
+
+    css = ""
+
+    rendered_html = ""
+    for output_text in output_texts:
+        # unwrap the plain-text spans
+        none_pattern = re.compile(r"<not_entity>(.*?)</not_entity>")
+        output_text = none_pattern.sub(r"\1", output_text)
+        # groups: reasoning, tag (entity type), value (displayed text)
+        pattern = re.compile(
+            r"<entity><reasoning>(.*?)</reasoning><tag>(.*?)</tag><value>(.*?)</value></entity>"
+        )
+        highlighted_html = pattern.sub(replace_match, output_text)
+        rendered_html += highlighted_html + "<br>"
" + + return css + legend_html + rendered_html + + +def display_ner(output_texts: List[str], allowed_entities: Dict[str, str]): + rendered_html = render_ner(output_texts, allowed_entities) + if is_running_in_jupyter(): + from IPython.display import display, HTML + + display(HTML(rendered_html)) + else: + with open("skllm_ner_output.html", "w") as f: + f.write(rendered_html) + try: + import webbrowser + + webbrowser.open("skllm_ner_output.html") + except Exception: + print( + "Output saved to 'skllm_ner_output.html', please open it in a browser." + ) + + +def is_running_in_jupyter(): + try: + from IPython import get_ipython + + if "IPKernelApp" in get_ipython().config: + return True + except Exception: + return False + return False diff --git a/skllm/utils/xml.py b/skllm/utils/xml.py new file mode 100644 index 0000000..84ef2e1 --- /dev/null +++ b/skllm/utils/xml.py @@ -0,0 +1,60 @@ +import re + + +def filter_xml_tags(xml_string, tags): + pattern = "|".join(f"<{tag}>.*?" for tag in tags) + regex = re.compile(pattern, re.DOTALL) + matches = regex.findall(xml_string) + return "".join(matches) + + +def filter_unwanted_entities(xml_string, allowed_entities): + allowed_values_pattern = "|".join(allowed_entities.keys()) + replacement = r"\3" + pattern = rf"(.*?)(?!{allowed_values_pattern})(.*?)(.*?)" + return re.sub(pattern, replacement, xml_string) + + +def replace_all_at_once(text, replacements): + sorted_keys = sorted(replacements, key=len, reverse=True) + regex = re.compile(r"(" + "|".join(map(re.escape, sorted_keys)) + r")") + return regex.sub(lambda match: replacements[match.group(0)], text) + + +def json_to_xml( + original_text: str, + tags: list, + tag_root: str, + non_tag_root: str, + value_key: str = "value", + attributes: list = None, +): + + if len(tags) == 0: + return f"<{non_tag_root}>{original_text}" + + if attributes is None: + attributes = tags[0].keys() + + replacements = {} + for item in tags: + value = item.get(value_key, "") + if not value: + continue + + attribute_parts = [] + for attr in attributes: + if attr in item: + attribute_parts.append(f"<{attr}>{item[attr]}") + attribute_str = "".join(attribute_parts) + replacements[value] = f"<{tag_root}>{attribute_str}" + original_text = replace_all_at_once(original_text, replacements) + + parts = re.split(f"(<{tag_root}>.*?)", original_text) + final_text = "" + for part in parts: + if not part.startswith(f"<{tag_root}>"): + final_text += f"<{non_tag_root}>{part}" + else: + final_text += part + return final_text