Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -167,3 +167,4 @@ tmp.py
# vscode
.vscode/
tmp2.py
tmp.*
152 changes: 152 additions & 0 deletions skllm/models/_base/tagger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
from typing import Any, Union, List, Optional, Dict
from abc import abstractmethod, ABC
from numpy import ndarray
from tqdm import tqdm
import numpy as np
import pandas as pd
from skllm.utils import to_numpy as _to_numpy
from sklearn.base import (
BaseEstimator as _SklBaseEstimator,
TransformerMixin as _SklTransformerMixin,
)

from skllm.utils.rendering import display_ner
from skllm.utils.xml import filter_xml_tags, filter_unwanted_entities, json_to_xml
from skllm.prompts.builders import build_ner_prompt
from skllm.prompts.templates import (
NER_SYSTEM_MESSAGE_TEMPLATE,
NER_SYSTEM_MESSAGE_SPARSE,
EXPLAINABLE_NER_DENSE_PROMPT_TEMPLATE,
EXPLAINABLE_NER_SPARSE_PROMPT_TEMPLATE,
)
from skllm.utils import re_naive_json_extractor
import json
from concurrent.futures import ThreadPoolExecutor

class BaseTagger(ABC, _SklBaseEstimator, _SklTransformerMixin):
    """Abstract scikit-learn style base class for LLM-backed text taggers.

    Subclasses must implement ``_get_prompt``. ``_predict_single`` also relies
    on ``self.model``, ``self._get_chat_completion`` and
    ``self._convert_completion_to_str``; none of them is defined here, they are
    expected to come from the LLM mixin combined with the concrete estimator.
    """

    # Size of the thread pool used by ``transform``.
    num_workers = 1

    def fit(self, X: Any, y: Any = None):
        """
        Fits the model to the data. Usually a no-op.

        Parameters
        ----------
        X : Any
            training data
        y : Any
            training outputs

        Returns
        -------
        self
            BaseTagger
        """
        return self

    def predict(self, X: Union[np.ndarray, pd.Series, List[str]]):
        """Alias of ``transform``; returns whatever ``transform`` returns."""
        return self.transform(X)

    def fit_transform(
        self,
        X: Union[np.ndarray, pd.Series, List[str]],
        y: Optional[Union[np.ndarray, pd.Series, List[str]]] = None,
    ) -> ndarray:
        """Calls ``fit(X, y)`` and then ``transform(X)`` on the fitted estimator."""
        return self.fit(X, y).transform(X)

    def transform(self, X: Union[np.ndarray, pd.Series, List[str]]):
        """
        Transforms the input data.

        Every element of ``X`` is passed to ``_predict_single`` through a thread
        pool of ``num_workers`` threads. ``executor.map`` yields results in
        input order, so the output is aligned with ``X``.

        Parameters
        ----------
        X : Union[np.ndarray, pd.Series, List[str]]
            The input samples to tag.

        Returns
        -------
        np.ndarray
            One prediction per input sample.
        """
        X = _to_numpy(X)
        with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
            predictions = list(
                tqdm(executor.map(self._predict_single, X), total=len(X))
            )
        return np.asarray(predictions)

    def _predict_single(self, x: Any) -> Any:
        """Builds the prompt for one sample, queries the LLM and returns its reply."""
        prompt_dict = self._get_prompt(x)
        # this will be inherited from the LLM
        prediction = self._get_chat_completion(model=self.model, **prompt_dict)
        prediction = self._convert_completion_to_str(prediction)
        return prediction

    @abstractmethod
    def _get_prompt(self, x: str) -> dict:
        """Returns the prompt to use for a single input.

        The returned dict is unpacked as keyword arguments into
        ``_get_chat_completion``.
        """
        pass


class ExplainableNER(BaseTagger):
    """Named entity tagger whose predictions carry a per-entity reasoning.

    With ``sparse_output=True`` the model is asked for a JSON array that is
    converted to tagged text via ``json_to_xml``; otherwise the model is asked
    to return the tagged text directly.
    """

    # Mapping of entity name -> definition; must be set before ``fit``.
    entities: Optional[Dict[str, str]] = None
    sparse_output = True
    _allowed_tags = ["entity", "not_entity"]

    display_predictions = False

    def fit(self, X: Any, y: Any = None):
        """Expands ``entities`` into a list of ``{"entity", "definition"}`` dicts.

        ``X`` and ``y`` are ignored. The result is stored in
        ``expanded_entities_``, which ``_get_prompt`` requires.

        Returns
        -------
        self
        """
        self.expanded_entities_ = [
            {"entity": name, "definition": definition}
            for name, definition in self.entities.items()
        ]
        return self

    # ``Union[...]`` instead of ``A | B``: the annotation is evaluated when the
    # class body runs, and ``type | type`` raises TypeError before Python 3.10.
    def transform(self, X: Union[np.ndarray, pd.Series, List[str]]):
        """Tags every sample of ``X`` and post-filters the tagged output.

        Parameters
        ----------
        X : Union[np.ndarray, pd.Series, List[str]]
            The input samples to tag.

        Returns
        -------
        list
            One filtered prediction per sample, as produced by
            ``filter_unwanted_entities``.
        """
        predictions = super().transform(X)
        if self.sparse_output:
            # The raw completion may wrap the JSON in extra text; extract the
            # JSON-looking part first and fall back to "[]" when none is found.
            json_predictions = [
                re_naive_json_extractor(p, expected_output="array")
                for p in predictions
            ]
            attributes = ["reasoning", "tag", "value"]
            predictions = [
                json_to_xml(
                    x,
                    json.loads(p),
                    "entity",
                    "not_entity",
                    value_key="value",
                    attributes=attributes,
                )
                for x, p in zip(X, json_predictions)
            ]

        predictions = [
            filter_unwanted_entities(
                filter_xml_tags(p, self._allowed_tags), self.entities
            )
            for p in predictions
        ]
        if self.display_predictions:
            print("Displaying predictions...")
            display_ner(predictions, self.entities)
        return predictions

    def _get_prompt(self, x: str) -> dict:
        """Builds the chat-completion arguments for a single sample.

        Raises
        ------
        ValueError
            If ``fit`` has not been called yet.
        """
        if not hasattr(self, "expanded_entities_"):
            raise ValueError("Model not fitted.")
        if self.sparse_output:
            # JSON output: pair the JSON prompt with the generic system message.
            # The XML tagging guidelines would contradict the requested format.
            system_message = NER_SYSTEM_MESSAGE_SPARSE
            template = EXPLAINABLE_NER_SPARSE_PROMPT_TEMPLATE
        else:
            # Format exactly once, and pass a plain list so the prompt does not
            # contain the ``dict_keys([...])`` repr.
            system_message = NER_SYSTEM_MESSAGE_TEMPLATE.format(
                entities=list(self.entities)
            )
            template = EXPLAINABLE_NER_DENSE_PROMPT_TEMPLATE
        return {
            "messages": build_ner_prompt(
                self.expanded_entities_, x, template=template
            ),
            "system_message": system_message,
        }
42 changes: 42 additions & 0 deletions skllm/models/gpt/tagging/ner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from skllm.models._base.tagger import ExplainableNER as _ExplainableNER
from skllm.llm.gpt.mixin import GPTTextCompletionMixin as _GPTTextCompletionMixin
from typing import Optional, Dict


class GPTExplainableNER(_ExplainableNER, _GPTTextCompletionMixin):
    def __init__(
        self,
        entities: Dict[str, str],
        display_predictions: bool = False,
        sparse_output: bool = True,
        model: str = "gpt-4o",
        key: Optional[str] = None,
        org: Optional[str] = None,
        num_workers: int = 1,
    ) -> None:
        """
        Explainable NER estimator backed by an OpenAI/GPT API-compatible model.

        Parameters
        ----------
        entities : dict
            maps each entity name to the description of that entity
        display_predictions : bool, optional
            if True, predictions are displayed after tagging, by default False
        sparse_output : bool, optional
            if True, a sparse representation of the predictions is requested, by default True
        model : str, optional
            name of the model to query, by default "gpt-4o"
        key : Optional[str], optional
            API key for this estimator only; None falls back to the global config, by default None
        org : Optional[str], optional
            ORG key for this estimator only; None falls back to the global config, by default None
        num_workers : int, optional
            how many worker threads to use, by default 1
        """
        # Credentials are handled by the mixin.
        self._set_keys(key, org)
        # Plain hyper-parameters, stored under the same names as the arguments.
        self.entities = entities
        self.model = model
        self.sparse_output = sparse_output
        self.display_predictions = display_predictions
        self.num_workers = num_workers
27 changes: 26 additions & 1 deletion skllm/prompts/builders.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Union
from typing import Union, Optional

from skllm.prompts.templates import (
FEW_SHOT_CLF_PROMPT_TEMPLATE,
Expand All @@ -8,6 +8,7 @@
TRANSLATION_PROMPT_TEMPLATE,
ZERO_SHOT_CLF_PROMPT_TEMPLATE,
ZERO_SHOT_MLCLF_PROMPT_TEMPLATE,
EXPLAINABLE_NER_DENSE_PROMPT_TEMPLATE,
)

# TODO add validators
Expand Down Expand Up @@ -190,3 +191,27 @@ def build_translation_prompt(
prepared prompt
"""
return template.format(x=x, output_language=output_language)


def build_ner_prompt(
    entities: list,
    x: str,
    template: str = EXPLAINABLE_NER_DENSE_PROMPT_TEMPLATE,
) -> str:
    """Renders a named entity recognition prompt from a template.

    Parameters
    ----------
    entities : list
        entities the model should recognize
    x : str
        text sample in which entities are to be recognized
    template : str, optional
        prompt template; it is filled through ``str.format`` with the keyword arguments
        ``entities`` and ``x``, by default EXPLAINABLE_NER_DENSE_PROMPT_TEMPLATE

    Returns
    -------
    str
        the rendered prompt
    """
    placeholders = {"entities": entities, "x": x}
    return template.format(**placeholders)
87 changes: 87 additions & 0 deletions skllm/prompts/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,3 +118,90 @@
Original text: ```{x}```
Output:
"""

# System message with XML tagging guidelines. Contains a single ``{entities}``
# placeholder that must be filled with ``str.format`` before use.
NER_SYSTEM_MESSAGE_TEMPLATE = """You are an expert in Natural Language Processing. Your task is to identify common Named Entities (NER) in a text provided by the user.
Mark the entities with tags according to the following guidelines:
- Use XML format to tag entities;
- All entities must be enclosed in <entity>...</entity> tags; All other text must be enclosed in <not_entity>...</not_entity> tags; No content should be outside of these tags;
- The tagging operation must be invertible, i.e. the original text must be recoverable from the tagged text; This is crucial and easy to overlook, double-check this requirement;
- Adjacent entities should be separated into different tags;
- The list of entities is strictly restricted to the following: {entities}.
"""

# Generic system message without placeholders or formatting guidelines.
NER_SYSTEM_MESSAGE_SPARSE = """You are an expert in Natural Language Processing."""

# User prompt that asks for the input text to be returned with XML tags.
# Placeholders filled via ``str.format``: ``{entities}`` and ``{x}``.
EXPLAINABLE_NER_DENSE_PROMPT_TEMPLATE = """You are provided with a text. Your task is to identify and tag all named entities within the text using the following entity types only:
{entities}

For each entity, provide a brief explanation for your choice within an XML comment. Use the following XML tag format for each entity:

<entity><reasoning>Your reasoning here</reasoning><tag>ENTITY_NAME_UPPERCASE</tag><value>Entity text</value></entity>

The remaining text must be enclosed in a <not_entity>TEXT</not_entity> tag.

Focus on the context and meaning of each entity rather than just the exact words. The tags should encompass the entire entity based on its definition and usage in the sentence. It is crucial to base your decision on the description of the entity, not just its name.

Format example:

Input:
```This text contains some entity and another entity.```

Output:
```xml
<not_entity>This text contains </not_entity><entity><reasoning>some justification</reasoning><tag>ENTITY1</tag><value>some entity</value></entity><not_entity> and another </not_entity><entity><reasoning>another justification</reasoning><tag>ENTITY2</tag><value>entity</value></entity><not_entity>.</not_entity>
```

Input:
```
{x}
```

Output (original text with tags):
"""


# User prompt that asks for a JSON array of objects with the keys "reasoning",
# "tag" and "value" instead of tagged text.
# Placeholders filled via ``str.format``: ``{entities}`` and ``{x}``. The braces
# of the embedded JSON schema are doubled so that ``str.format`` emits them as
# literal ``{`` / ``}``.
EXPLAINABLE_NER_SPARSE_PROMPT_TEMPLATE = """You are provided with a text. Your task is to identify and tag all named entities within the text using the following entity types only:
{entities}

You must provide the following information for each entity:
- The reasoning of why you tagged the entity as such; Based on the reasoning, a non-expert should be able to evaluate your decision;
- The tag of the entity (uppercase);
- The value of the entity (as it appears in the text).

Your response should be json formatted using the following schema:

{{
"$schema": "http://json-schema.org/draft-04/schema#",
"type": "array",
"items": [
{{
"type": "object",
"properties": {{
"reasoning": {{
"type": "string"
}},
"tag": {{
"type": "string"
}},
"value": {{
"type": "string"
}}
}},
"required": [
"reasoning",
"tag",
"value"
]
}}
]
}}


Input:
```
{x}
```

Output json:
"""

31 changes: 28 additions & 3 deletions skllm/utils.py → skllm/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pandas as pd
from functools import wraps
from time import sleep

import re

def to_numpy(X: Any) -> np.ndarray:
"""Converts a pandas Series or list to a numpy array.
Expand All @@ -23,10 +23,11 @@ def to_numpy(X: Any) -> np.ndarray:
elif isinstance(X, list):
X = np.asarray(X, dtype=object)
if isinstance(X, np.ndarray) and len(X.shape) > 1:
X = np.squeeze(X)
# do not squeeze the first dim
X = np.squeeze(X, axis=tuple([i for i in range(1, len(X.shape))]))
return X


# TODO: replace with re version below
def find_json_in_string(string: str) -> str:
"""Finds the JSON object in a string.

Expand All @@ -48,6 +49,30 @@ def find_json_in_string(string: str) -> str:
return json_string



def re_naive_json_extractor(json_string: str, expected_output: str = "object") -> str:
    """Finds the first JSON-like object or array in a string using regex.

    The match is greedy: it starts at the first ``{`` or ``[`` that has a
    matching closing character somewhere later in the string and extends to
    the LAST such closing character, newlines included. The result is not
    validated as JSON.

    Parameters
    ----------
    json_string : str
        The string to search for a JSON object or array.
    expected_output : str, optional
        Only controls the fallback when nothing matches: "object" yields "{}",
        any other value yields "[]", by default "object"

    Returns
    -------
    json_string : str
        The matched substring if found, otherwise an empty JSON object or array.
    """
    json_pattern = r"(\{.*\}|\[.*\])"
    # DOTALL lets "." span line breaks, so pretty-printed JSON is matched too.
    match = re.search(json_pattern, json_string, re.DOTALL)
    if match:
        return match.group(0)
    return "{}" if expected_output == "object" else "[]"




def extract_json_key(json_: str, key: str):
"""Extracts JSON key from a string.

Expand Down
Loading