@retry(max_retries=3)
def get_chat_completion(
    messages: List[Dict],
    key: str,
    model: str = "claude-3-haiku-20240307",
    max_tokens: int = 1000,
    temperature: float = 0.0,
    system: Optional[str] = None,
    json_response: bool = False,
):
    """
    Gets a chat completion from the Anthropic Claude API using the Messages API.

    Parameters
    ----------
    messages : List[Dict]
        Input messages. Only the "content" entry of each message is used
        (a missing entry becomes an empty string) and every message is sent
        with the "user" role.
    key : str
        The Anthropic API key to use.
    model : str, optional
        The Claude model to use.
    max_tokens : int, optional
        Maximum tokens to generate.
    temperature : float, optional
        Sampling temperature.
    system : str, optional
        System message to set the assistant's behavior. When it is None
        (and `json_response` is False) no system prompt is sent at all.
    json_response : bool, optional
        Whether to request a JSON-formatted response. Defaults to False.
        The request is made by appending an instruction to the system prompt.

    Returns
    -------
    response
        The object returned by ``client.messages.create``, passed through
        unchanged.

    Raises
    ------
    ValueError
        If `messages` is empty (or otherwise falsy).
    TypeError
        If `messages` is not a list.

    Notes
    -----
    NOTE(review): the function is wrapped in ``retry(max_retries=3)``; how the
    wrapper surfaces the exceptions above is defined in ``skllm.utils``.
    """
    if not messages:
        raise ValueError("Messages list cannot be empty")
    if not isinstance(messages, list):
        raise TypeError("Messages must be a list")

    client = set_credentials(key)

    if json_response and system:
        system = f"{system.rstrip('.')}. Respond in JSON format."
    elif json_response:
        system = "Respond in JSON format."

    formatted_messages = [
        {
            "role": "user",  # every message is deliberately sent as a user turn
            "content": [
                {
                    "type": "text",
                    "text": message.get("content", ""),
                }
            ],
        }
        for message in messages
    ]

    request = {
        "model": model,
        "max_tokens": max_tokens,
        "temperature": temperature,
        "messages": formatted_messages,
    }
    # `system` must be a string/block list or be left out entirely; forwarding
    # an explicit None is not part of the SDK contract, so omit it instead.
    if system is not None:
        request["system"] = system

    return client.messages.create(**request)


def set_credentials(key: str) -> "Anthropic":
    """Create an Anthropic client for the given API key.

    Parameters
    ----------
    key : str
        The Anthropic key to use.

    Returns
    -------
    Anthropic
        A client constructed with ``api_key=key``.
    """
    return Anthropic(api_key=key)
class ClaudeMixin:
    """A mixin class that provides Claude API key to other classes."""

    # Subclasses flip this to ask the model for a JSON-formatted answer.
    _prefer_json_output = False

    def _set_keys(self, key: Optional[str] = None) -> None:
        """Store the estimator-specific Claude API key (may be None)."""
        self.key = key

    def _get_claude_key(self) -> str:
        """Return the Claude key from the instance or, failing that, the config.

        Returns
        -------
        str
            The estimator-specific key if one was set, otherwise the value of
            ``SKLLMConfig.get_anthropic_key()``.

        Raises
        ------
        RuntimeError
            If neither source provides a key.
        """
        # getattr: fall through to the config lookup (instead of failing with
        # AttributeError) when `_set_keys` has not been called yet.
        key = getattr(self, "key", None)
        if key is None:
            key = _Config.get_anthropic_key()
        if key is None:
            raise RuntimeError("Claude API key was not found")
        return key


class ClaudeTextCompletionMixin(ClaudeMixin, BaseTextCompletionMixin):
    """A mixin class that provides text completion capabilities using the Claude API."""

    def _get_chat_completion(
        self,
        model: str,
        messages: Union[str, List[Dict[str, str]]],
        system_message: Optional[str] = None,
        **kwargs: Any,
    ):
        """Gets a chat completion from the Anthropic API.

        Parameters
        ----------
        model : str
            The model to use.
        messages : Union[str, List[Dict[str, str]]]
            Input messages. A plain string becomes a single message; for a
            list, only each item's "content" entry is kept.
        system_message : Optional[str]
            A system message to use.
        **kwargs : Any
            Forwarded unchanged to ``get_chat_completion``.

        Returns
        -------
        completion
            Whatever ``get_chat_completion`` returns.
        """
        if isinstance(messages, str):
            messages = [{"content": messages}]
        elif isinstance(messages, list):
            messages = [{"content": msg["content"]} for msg in messages]

        completion = get_chat_completion(
            messages=messages,
            key=self._get_claude_key(),
            model=model,
            system=system_message,
            json_response=self._prefer_json_output,
            **kwargs,
        )
        return completion

    def _convert_completion_to_str(self, completion: Mapping[str, Any]):
        """Converts Claude API completion to string.

        Accepts either an object exposing ``.content[0].text`` or a mapping
        with the same layout. This is best-effort: on any failure the error
        is printed and an empty string is returned.
        """
        try:
            if hasattr(completion, 'content'):
                return completion.content[0].text
            return completion.get('content', [{}])[0].get('text', '')
        except Exception as e:
            print(f"Error converting completion to string: {str(e)}")
            return ""


class ClaudeClassifierMixin(ClaudeTextCompletionMixin, BaseClassifierMixin):
    """A mixin class that provides classification capabilities using Claude API."""

    _prefer_json_output = True

    @staticmethod
    def _parse_json_object(text: str):
        """Parse `text` as JSON, tolerating text around a single JSON object.

        The JSON format is only requested through the system prompt, so the
        reply may carry prose or a code fence around the object. When strict
        parsing fails, the span from the first "{" to the last "}" is tried.
        Returns the parsed value, or None when nothing parses.
        """
        try:
            return json.loads(text)
        except json.JSONDecodeError:
            pass

        start, end = text.find("{"), text.rfind("}")
        if start == -1 or end <= start:
            return None
        try:
            return json.loads(text[start:end + 1])
        except json.JSONDecodeError:
            return None

    def _extract_out_label(self, completion: Mapping[str, Any], **kwargs) -> str:
        """Extracts the label from a Claude API completion.

        Without JSON output the stripped completion text is the label.
        With JSON output the value stored under "label" is returned.
        An empty string is returned when no label can be extracted.
        """
        try:
            content = self._convert_completion_to_str(completion)
            if not self._prefer_json_output:
                return content.strip()

            data = self._parse_json_object(content)
            # Only a JSON object can carry a "label" entry; other JSON values
            # (lists, strings, numbers) simply yield no label.
            if isinstance(data, dict) and "label" in data:
                return data["label"]
            return ""

        except Exception as e:
            print(f"Error extracting label: {str(e)}")
            return ""
class FewShotClaudeClassifier(BaseFewShotClassifier, ClaudeClassifierMixin, SingleLabelMixin):
    """Few-shot text classifier using Anthropic's Claude API for single-label classification tasks."""

    def __init__(
        self,
        model: str = "claude-3-haiku-20240307",
        default_label: str = "Random",
        prompt_template: Optional[str] = None,
        key: Optional[str] = None,
        **kwargs,
    ):
        """
        Few-shot text classifier using Anthropic's Claude API.

        Parameters
        ----------
        model : str, optional
            model to use, by default "claude-3-haiku-20240307"
        default_label : str, optional
            default label for failed prediction; if "Random" -> selects randomly based on class frequencies
        prompt_template : Optional[str], optional
            custom prompt template to use, by default None
        key : Optional[str], optional
            estimator-specific API key; if None, retrieved from the global config
        **kwargs
            forwarded to the base classifier constructor
        """
        super().__init__(
            model=model,
            default_label=default_label,
            prompt_template=prompt_template,
            **kwargs,
        )
        self._set_keys(key)


class MultiLabelFewShotClaudeClassifier(
    BaseFewShotClassifier, ClaudeClassifierMixin, MultiLabelMixin
):
    """Few-shot text classifier using Anthropic's Claude API for multi-label classification tasks."""

    def __init__(
        self,
        model: str = "claude-3-haiku-20240307",
        default_label: str = "Random",
        max_labels: Optional[int] = 5,
        prompt_template: Optional[str] = None,
        key: Optional[str] = None,
        **kwargs,
    ):
        """
        Multi-label few-shot text classifier using Anthropic's Claude API.

        Parameters
        ----------
        model : str, optional
            model to use, by default "claude-3-haiku-20240307"
        default_label : str, optional
            default label for failed prediction; if "Random" -> selects randomly based on class frequencies
        max_labels : Optional[int], optional
            maximum labels per sample, by default 5
        prompt_template : Optional[str], optional
            custom prompt template to use, by default None
        key : Optional[str], optional
            estimator-specific API key; if None, retrieved from the global config
        **kwargs
            forwarded to the base classifier constructor
        """
        super().__init__(
            model=model,
            default_label=default_label,
            max_labels=max_labels,
            prompt_template=prompt_template,
            **kwargs,
        )
        self._set_keys(key)


class DynamicFewShotClaudeClassifier(
    BaseDynamicFewShotClassifier, ClaudeClassifierMixin, SingleLabelMixin
):
    """
    Dynamic few-shot text classifier using Anthropic's Claude API for
    single-label classification tasks with dynamic example selection using GPT embeddings.
    """

    def __init__(
        self,
        model: str = "claude-3-haiku-20240307",
        default_label: str = "Random",
        prompt_template: Optional[str] = None,
        key: Optional[str] = None,
        n_examples: int = 3,
        memory_index: Optional[IndexConstructor] = None,
        vectorizer: Optional[BaseVectorizer] = None,
        metric: Optional[str] = "euclidean",
        **kwargs,
    ):
        """
        Dynamic few-shot text classifier using Anthropic's Claude API.
        For each sample, N closest examples are retrieved from the memory.

        Parameters
        ----------
        model : str, optional
            model to use, by default "claude-3-haiku-20240307"
        default_label : str, optional
            default label for failed prediction; if "Random" -> selects randomly based on class frequencies
        prompt_template : Optional[str], optional
            custom prompt template to use, by default None
        key : Optional[str], optional
            estimator-specific Anthropic API key; if None, retrieved from the global config.
            It is used for the Claude calls only, never for the embedding vectorizer.
        n_examples : int, optional
            number of closest examples per class to be retrieved, by default 3
        memory_index : Optional[IndexConstructor], optional
            custom memory index, for details check `skllm.memory` submodule
        vectorizer : Optional[BaseVectorizer], optional
            scikit-llm vectorizer; if None, `GPTVectorizer` is used
        metric : Optional[str], optional
            metric used for similarity search, by default "euclidean"
        **kwargs
            accepted for signature compatibility; not forwarded to the base constructor
        """
        if vectorizer is None:
            # `key` is the Anthropic key, while GPTVectorizer comes from the
            # GPT package, so the Claude key must not be handed to it.
            # key=None is exactly what was passed before whenever no
            # estimator-specific key was supplied.
            # NOTE(review): presumably the vectorizer then resolves its own
            # credentials from the global config — confirm.
            vectorizer = GPTVectorizer(model="text-embedding-ada-002", key=None)
        super().__init__(
            model=model,
            default_label=default_label,
            prompt_template=prompt_template,
            n_examples=n_examples,
            memory_index=memory_index,
            vectorizer=vectorizer,
            metric=metric,
        )
        self._set_keys(key)
class ZeroShotClaudeClassifier(
    _BaseZeroShotClassifier, _ClaudeClassifierMixin, _SingleLabelMixin
):
    """Single-label zero-shot text classifier backed by Anthropic Claude models."""

    def __init__(
        self,
        model: str = "claude-3-haiku-20240307",
        default_label: str = "Random",
        prompt_template: Optional[str] = None,
        key: Optional[str] = None,
        **kwargs,
    ):
        """
        Create a zero-shot Claude classifier.

        Parameters
        ----------
        model : str, optional
            Claude model name, by default "claude-3-haiku-20240307".
        default_label : str, optional
            Label used when a prediction fails; "Random" (the default) draws one
            according to the class frequencies.
        prompt_template : Optional[str], optional
            Custom prompt template, by default None.
        key : Optional[str], optional
            Estimator-specific API key; when None the global config is consulted.
        **kwargs
            Extra keyword arguments handed to the base classifier.
        """
        base_params = {
            "model": model,
            "default_label": default_label,
            "prompt_template": prompt_template,
        }
        super().__init__(**base_params, **kwargs)
        self._set_keys(key)


class CoTClaudeClassifier(
    _BaseCoTClassifier, _ClaudeClassifierMixin, _SingleLabelMixin
):
    """Single-label chain-of-thought text classifier backed by Anthropic Claude models."""

    def __init__(
        self,
        model: str = "claude-3-haiku-20240307",
        default_label: str = "Random",
        prompt_template: Optional[str] = None,
        key: Optional[str] = None,
        **kwargs,
    ):
        """
        Create a chain-of-thought Claude classifier.

        Parameters
        ----------
        model : str, optional
            Claude model name, by default "claude-3-haiku-20240307".
        default_label : str, optional
            Label used when a prediction fails; "Random" (the default) draws one
            according to the class frequencies.
        prompt_template : Optional[str], optional
            Custom prompt template, by default None.
        key : Optional[str], optional
            Estimator-specific API key; when None the global config is consulted.
        **kwargs
            Extra keyword arguments handed to the base classifier.
        """
        base_params = {
            "model": model,
            "default_label": default_label,
            "prompt_template": prompt_template,
        }
        super().__init__(**base_params, **kwargs)
        self._set_keys(key)


class MultiLabelZeroShotClaudeClassifier(
    _BaseZeroShotClassifier, _ClaudeClassifierMixin, _MultiLabelMixin
):
    """Multi-label zero-shot text classifier backed by Anthropic Claude models."""

    def __init__(
        self,
        model: str = "claude-3-haiku-20240307",
        default_label: str = "Random",
        max_labels: Optional[int] = 5,
        prompt_template: Optional[str] = None,
        key: Optional[str] = None,
        **kwargs,
    ):
        """
        Create a multi-label zero-shot Claude classifier.

        Parameters
        ----------
        model : str, optional
            Claude model name, by default "claude-3-haiku-20240307".
        default_label : str, optional
            Label used when a prediction fails; "Random" (the default) draws one
            according to the class frequencies.
        max_labels : Optional[int], optional
            Upper bound on the number of labels per sample, by default 5.
        prompt_template : Optional[str], optional
            Custom prompt template, by default None.
        key : Optional[str], optional
            Estimator-specific API key; when None the global config is consulted.
        **kwargs
            Extra keyword arguments handed to the base classifier.
        """
        base_params = {
            "model": model,
            "default_label": default_label,
            "max_labels": max_labels,
            "prompt_template": prompt_template,
        }
        super().__init__(**base_params, **kwargs)
        self._set_keys(key)
class AnthropicExplainableNER(_ExplainableNER, _ClaudeTextCompletionMixin):
    """Explainable named entity recognition model driven by Anthropic's Claude API."""

    def __init__(
        self,
        entities: Dict[str, str],
        display_predictions: bool = False,
        sparse_output: bool = True,
        model: str = "claude-3-haiku-20240307",
        key: Optional[str] = None,
        num_workers: int = 1,
    ) -> None:
        """
        Configure the Claude-based entity recognizer.

        Parameters
        ----------
        entities : dict
            entities to recognize: entity name -> description
        display_predictions : bool, optional
            whether to display predictions, by default False
        sparse_output : bool, optional
            whether to generate a sparse representation of the predictions, by default True
        model : str, optional
            model to use, by default "claude-3-haiku-20240307"
        key : Optional[str], optional
            estimator-specific API key; if None, retrieved from the global config
        num_workers : int, optional
            number of workers (threads) to use, by default 1
        """
        self._set_keys(key)
        # Plain attribute storage; each value is kept exactly as passed in.
        self.entities = entities
        self.model = model
        self.num_workers = num_workers
        self.sparse_output = sparse_output
        self.display_predictions = display_predictions


class ClaudeSummarizer(_BaseSummarizer, _ClaudeTextCompletionMixin):
    """Summarizer that produces text summaries through the Anthropic Claude API."""

    def __init__(
        self,
        model: str = "claude-3-haiku-20240307",
        key: Optional[str] = None,
        max_words: int = 15,
        focus: Optional[str] = None,
    ) -> None:
        """
        Configure the Claude summarizer.

        Parameters
        ----------
        model : str, optional
            Model to use, by default "claude-3-haiku-20240307"
        key : Optional[str], optional
            Estimator-specific API key; if None, retrieved from global config
        max_words : int, optional
            Soft limit of the summary length, by default 15
        focus : Optional[str], optional
            Concept in the text to focus on, by default None
        """
        self._set_keys(key)
        self.system_message = "You are a text summarizer. Provide concise and accurate summaries."
        self.focus = focus
        self.max_words = max_words
        self.model = model


class ClaudeTranslator(_BaseTranslator, _ClaudeTextCompletionMixin):
    """Translator that renders text into a target language through the Anthropic Claude API."""

    # Class-level fallback text; how it is consumed is up to the base translator.
    default_output = "Translation is unavailable."

    def __init__(
        self,
        model: str = "claude-3-haiku-20240307",
        key: Optional[str] = None,
        output_language: str = "English",
    ) -> None:
        """
        Configure the Claude translator.

        Parameters
        ----------
        model : str, optional
            Model to use, by default "claude-3-haiku-20240307"
        key : Optional[str], optional
            Estimator-specific API key; if None, retrieved from global config
        output_language : str, optional
            Target language, by default "English"
        """
        self._set_keys(key)
        translator_prompt = (
            "You are a professional translator. Provide accurate translations "
            "while maintaining the original meaning and tone of the text."
        )
        self.system_message = translator_prompt
        self.output_language = output_language
        self.model = model
def _text_completion(text):
    """Build a dict shaped like the completion payload the mocked call returns."""
    return {"content": [{"type": "text", "text": text}]}


class TestClaudeMixin(unittest.TestCase):
    """The base mixin hands back the key it was given."""

    def test_ClaudeMixin(self):
        mixin = ClaudeMixin()
        mixin._set_keys("test_key")
        self.assertEqual(mixin._get_claude_key(), "test_key")


class TestClaudeTextCompletionMixin(unittest.TestCase):
    """Text completion goes through the (patched) API call and is converted to text."""

    @patch("skllm.llm.anthropic.mixin.get_chat_completion")
    def test_chat_completion_with_valid_params(self, mock_get_chat_completion):
        mock_get_chat_completion.return_value = _text_completion("test response")

        mixin = ClaudeTextCompletionMixin()
        mixin._set_keys("test_key")
        completion = mixin._get_chat_completion(
            model="claude-3-haiku-20240307",
            messages="Hello",
            system_message="Test system",
        )

        converted = mixin._convert_completion_to_str(completion)
        self.assertEqual(converted, "test response")
        mock_get_chat_completion.assert_called_once()


class TestClaudeClassifierMixin(unittest.TestCase):
    """The classifier mixin pulls the label out of a JSON completion."""

    @patch("skllm.llm.anthropic.mixin.get_chat_completion")
    def test_extract_out_label_with_valid_completion(self, mock_get_chat_completion):
        mock_get_chat_completion.return_value = _text_completion(
            '{"label":"hello world"}'
        )

        mixin = ClaudeClassifierMixin()
        mixin._set_keys("test_key")
        completion = mixin._get_chat_completion(
            model="claude-3-haiku-20240307",
            messages="Hello",
            system_message="World",
        )

        label = mixin._extract_out_label(completion)
        self.assertEqual(label, "hello world")
        mock_get_chat_completion.assert_called_once()