diff --git a/.gitignore b/.gitignore
index 707cff7..8b6a6c8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -167,3 +167,4 @@ tmp.py
# vscode
.vscode/
tmp2.py
+tmp.*
\ No newline at end of file
diff --git a/skllm/models/_base/tagger.py b/skllm/models/_base/tagger.py
new file mode 100644
index 0000000..2bb90aa
--- /dev/null
+++ b/skllm/models/_base/tagger.py
@@ -0,0 +1,156 @@
+from typing import Any, Union, List, Optional, Dict
+from abc import abstractmethod, ABC
+from numpy import ndarray
+from tqdm import tqdm
+import numpy as np
+import pandas as pd
+from skllm.utils import to_numpy as _to_numpy
+from sklearn.base import (
+ BaseEstimator as _SklBaseEstimator,
+ TransformerMixin as _SklTransformerMixin,
+)
+
+from skllm.utils.rendering import display_ner
+from skllm.utils.xml import filter_xml_tags, filter_unwanted_entities, json_to_xml
+from skllm.prompts.builders import build_ner_prompt
+from skllm.prompts.templates import (
+ NER_SYSTEM_MESSAGE_TEMPLATE,
+ NER_SYSTEM_MESSAGE_SPARSE,
+ EXPLAINABLE_NER_DENSE_PROMPT_TEMPLATE,
+ EXPLAINABLE_NER_SPARSE_PROMPT_TEMPLATE,
+)
+from skllm.utils import re_naive_json_extractor
+import json
+from concurrent.futures import ThreadPoolExecutor
+
+class BaseTagger(ABC, _SklBaseEstimator, _SklTransformerMixin):
+
+ num_workers = 1
+
+ def fit(self, X: Any, y: Any = None):
+ """
+ Fits the model to the data. Usually a no-op.
+
+ Parameters
+ ----------
+ X : Any
+ training data
+ y : Any
+ training outputs
+
+ Returns
+ -------
+ self
+ BaseTagger
+ """
+ return self
+
+ def predict(self, X: Union[np.ndarray, pd.Series, List[str]]):
+ return self.transform(X)
+
+ def fit_transform(
+ self,
+ X: Union[np.ndarray, pd.Series, List[str]],
+ y: Optional[Union[np.ndarray, pd.Series, List[str]]] = None,
+ ) -> ndarray:
+ return self.fit(X, y).transform(X)
+
+ def transform(self, X: Union[np.ndarray, pd.Series, List[str]]):
+ """
+ Transforms the input data.
+
+ Parameters
+ ----------
+ X : Union[np.ndarray, pd.Series, List[str]]
+            The input samples to tag.
+
+ Returns
+ -------
+        ndarray
+ """
+ X = _to_numpy(X)
+        # inputs are processed concurrently by up to num_workers threads
+        with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
+            predictions = list(tqdm(executor.map(self._predict_single, X), total=len(X)))
+ return np.asarray(predictions)
+
+ def _predict_single(self, x: Any) -> Any:
+ prompt_dict = self._get_prompt(x)
+        # _get_chat_completion and _convert_completion_to_str are provided by the backend mixin (e.g. GPTTextCompletionMixin)
+ prediction = self._get_chat_completion(model=self.model, **prompt_dict)
+ prediction = self._convert_completion_to_str(prediction)
+ return prediction
+
+ @abstractmethod
+ def _get_prompt(self, x: str) -> dict:
+ """Returns the prompt to use for a single input."""
+ pass
+
+
+class ExplainableNER(BaseTagger):
+ entities: Optional[Dict[str, str]] = None
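+    # if True, the model returns a sparse JSON list of entities that is converted to
+    # XML locally; if False, the model is asked to annotate the text with XML tags directly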
+ sparse_output = True
+ _allowed_tags = ["entity", "not_entity"]
+
+ display_predictions = False
+
+    def fit(self, X: Any, y: Any = None):
+        """Prepares the list of entity definitions used to build prompts; X and y are ignored."""
+        self.expanded_entities_ = [
+            {"entity": k, "definition": v} for k, v in self.entities.items()
+        ]
+        return self
+
+    def transform(self, X: Union[np.ndarray, pd.Series, List[str]]):
+ predictions = super().transform(X)
+ if self.sparse_output:
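+            # in sparse mode the model returns a JSON array of entity records;
+            # these are converted back into inline XML tags over the original text below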
+ json_predictions = [
+ re_naive_json_extractor(p, expected_output="array") for p in predictions
+ ]
+ predictions = []
+ attributes = ["reasoning", "tag", "value"]
+ for x, p in zip(X, json_predictions):
+ p_json = json.loads(p)
+ predictions.append(
+ json_to_xml(
+ x,
+ p_json,
+ "entity",
+ "not_entity",
+ value_key="value",
+ attributes=attributes,
+ )
+ )
+
+ predictions = [
+ filter_unwanted_entities(
+ filter_xml_tags(p, self._allowed_tags), self.entities
+ )
+ for p in predictions
+ ]
+ if self.display_predictions:
+ print("Displaying predictions...")
+ display_ner(predictions, self.entities)
+ return predictions
+
+ def _get_prompt(self, x: str) -> dict:
+ if not hasattr(self, "expanded_entities_"):
+ raise ValueError("Model not fitted.")
+ system_message = (
+            NER_SYSTEM_MESSAGE_SPARSE
+            if self.sparse_output
+            else NER_SYSTEM_MESSAGE_TEMPLATE
+ )
+ template = (
+ EXPLAINABLE_NER_SPARSE_PROMPT_TEMPLATE
+ if self.sparse_output
+ else EXPLAINABLE_NER_DENSE_PROMPT_TEMPLATE
+ )
+ return {
+ "messages": build_ner_prompt(self.expanded_entities_, x, template=template),
+ "system_message": system_message.format(entities=self.entities.keys()),
+ }
diff --git a/skllm/models/gpt/tagging/ner.py b/skllm/models/gpt/tagging/ner.py
new file mode 100644
index 0000000..98857a8
--- /dev/null
+++ b/skllm/models/gpt/tagging/ner.py
@@ -0,0 +1,50 @@
+from skllm.models._base.tagger import ExplainableNER as _ExplainableNER
+from skllm.llm.gpt.mixin import GPTTextCompletionMixin as _GPTTextCompletionMixin
+from typing import Optional, Dict
+
+
+class GPTExplainableNER(_ExplainableNER, _GPTTextCompletionMixin):
+ def __init__(
+ self,
+ entities: Dict[str, str],
+ display_predictions: bool = False,
+ sparse_output: bool = True,
+ model: str = "gpt-4o",
+ key: Optional[str] = None,
+ org: Optional[str] = None,
+ num_workers: int = 1,
+ ) -> None:
+ """
+ Named entity recognition using OpenAI/GPT API-compatible models.
+
+ Parameters
+ ----------
+ entities : dict
+ dictionary of entities to recognize, with keys as entity names and values as descriptions
+ display_predictions : bool, optional
+ whether to display predictions, by default False
+ sparse_output : bool, optional
+ whether to generate a sparse representation of the predictions, by default True
+ model : str, optional
+ model to use, by default "gpt-4o"
+ key : Optional[str], optional
+ estimator-specific API key; if None, retrieved from the global config, by default None
+ org : Optional[str], optional
+ estimator-specific ORG key; if None, retrieved from the global config, by default None
+ num_workers : int, optional
+ number of workers (threads) to use, by default 1
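+
+        Example
+        -------
+        A minimal usage sketch; the entity definitions and key below are illustrative placeholders::
+
+            entities = {"disease": "a named medical condition", "drug": "a medication or active substance"}
+            ner = GPTExplainableNER(entities=entities, key="<YOUR_OPENAI_KEY>")
+            tagged = ner.fit_transform(["Aspirin is commonly used to relieve mild headaches."])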
+ """
+ self._set_keys(key, org)
+ self.model = model
+ self.entities = entities
+ self.display_predictions = display_predictions
+ self.sparse_output = sparse_output
+ self.num_workers = num_workers
\ No newline at end of file
diff --git a/skllm/prompts/builders.py b/skllm/prompts/builders.py
index 823c20a..e24ae72 100644
--- a/skllm/prompts/builders.py
+++ b/skllm/prompts/builders.py
@@ -1,4 +1,4 @@
-from typing import Union
+from typing import Union, Optional
from skllm.prompts.templates import (
FEW_SHOT_CLF_PROMPT_TEMPLATE,
@@ -8,6 +8,7 @@
TRANSLATION_PROMPT_TEMPLATE,
ZERO_SHOT_CLF_PROMPT_TEMPLATE,
ZERO_SHOT_MLCLF_PROMPT_TEMPLATE,
+ EXPLAINABLE_NER_DENSE_PROMPT_TEMPLATE,
)
# TODO add validators
@@ -190,3 +191,31 @@
prepared prompt
"""
return template.format(x=x, output_language=output_language)
+
+
+def build_ner_prompt(
+ entities: list,
+ x: str,
+ template: str = EXPLAINABLE_NER_DENSE_PROMPT_TEMPLATE,
+) -> str:
+ """Builds a prompt for named entity recognition.
+
+ Parameters
+ ----------
+ entities : list
+        list of entities to recognize; typically a list of dicts with "entity" and "definition" keys
+ x : str
+ sample to recognize entities in
+ template : str, optional
+ prompt template to use, must contain placeholders for all variables, by default EXPLAINABLE_NER_DENSE_PROMPT_TEMPLATE
+
+ Returns
+ -------
+ str
+ prepared prompt
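+
+    Example (illustrative entity list)::
+
+        prompt = build_ner_prompt([{"entity": "PER", "definition": "a person's name"}], "John lives in Paris.")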
+ """
+ return template.format(entities=entities, x=x)
diff --git a/skllm/prompts/templates.py b/skllm/prompts/templates.py
index 552230f..96334b7 100644
--- a/skllm/prompts/templates.py
+++ b/skllm/prompts/templates.py
@@ -118,3 +118,90 @@
Original text: ```{x}```
Output:
"""
+
+NER_SYSTEM_MESSAGE_TEMPLATE = """You are an expert in Natural Language Processing. Your task is to identify common Named Entities (NER) in a text provided by the user.
+Mark the entities with tags according to the following guidelines:
+ - Use XML format to tag entities;
+ - All entities must be enclosed in