Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions skllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
# Names of the environment variables that back the library configuration.
# The accessors in this module read/write os.environ under these names
# (visible here for the GPT URL and the Anthropic key; presumably the same
# holds for the remaining entries -- confirm against their accessors).
_AZURE_API_VERSION_VAR = "SKLLM_CONFIG_AZURE_API_VERSION"
_GOOGLE_PROJECT = "GOOGLE_CLOUD_PROJECT"
_GPT_URL_VAR = "SKLLM_CONFIG_GPT_URL"
# Environment variable holding the Anthropic API key.
_ANTHROPIC_KEY_VAR = "SKLLM_CONFIG_ANTHROPIC_KEY"
_GGUF_DOWNLOAD_PATH = "SKLLM_CONFIG_GGUF_DOWNLOAD_PATH"
_GGUF_MAX_GPU_LAYERS = "SKLLM_CONFIG_GGUF_MAX_GPU_LAYERS"
_GGUF_VERBOSE = "SKLLM_CONFIG_GGUF_VERBOSE"
Expand Down Expand Up @@ -168,6 +169,28 @@ def get_gpt_url() -> Optional[str]:
GPT URL.
"""
return os.environ.get(_GPT_URL_VAR, None)

@staticmethod
def set_anthropic_key(key: str) -> None:
    """Store the Anthropic key in the process environment.

    Parameters
    ----------
    key : str
        Anthropic key; written to the environment variable named by
        ``_ANTHROPIC_KEY_VAR``.
    """
    environment = os.environ
    environment[_ANTHROPIC_KEY_VAR] = key

@staticmethod
def get_anthropic_key() -> Optional[str]:
    """Read the Anthropic key from the process environment.

    Returns
    -------
    Optional[str]
        Value of the environment variable named by ``_ANTHROPIC_KEY_VAR``,
        or None when that variable is not set.
    """
    # Mapping.get already yields None for a missing name.
    anthropic_key = os.environ.get(_ANTHROPIC_KEY_VAR)
    return anthropic_key

@staticmethod
def reset_gpt_url():
Expand Down
72 changes: 72 additions & 0 deletions skllm/llm/anthropic/completion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from typing import Dict, List, Optional
from skllm.llm.anthropic.credentials import set_credentials
from skllm.utils import retry

@retry(max_retries=3)
def get_chat_completion(
    messages: List[Dict],
    key: str,
    model: str = "claude-3-haiku-20240307",
    max_tokens: int = 1000,
    temperature: float = 0.0,
    system: Optional[str] = None,
    json_response: bool = False,
) -> dict:
    """Get a chat completion from the Anthropic Claude Messages API.

    Parameters
    ----------
    messages : List[Dict]
        Input messages. Only the ``"content"`` entry of each message is
        used (a missing entry becomes an empty string) and every message is
        sent with the ``"user"`` role.
    key : str
        The Anthropic API key to use.
    model : str, optional
        The Claude model to use.
    max_tokens : int, optional
        Maximum tokens to generate.
    temperature : float, optional
        Sampling temperature.
    system : str, optional
        System message to set the assistant's behavior. Omitted from the
        request when None.
    json_response : bool, optional
        Whether to request a JSON-formatted response. When True, a
        "Respond in JSON format." instruction is appended to (or used as)
        the system message. Defaults to False.

    Returns
    -------
    response
        The object returned by ``client.messages.create``, unmodified.

    Raises
    ------
    ValueError
        If ``messages`` is empty (or otherwise falsy).
    TypeError
        If ``messages`` is non-empty but not a list.

    Notes
    -----
    The exceptions above are raised inside the ``retry`` wrapper; how they
    surface to the caller depends on that decorator.
    """
    if not messages:
        raise ValueError("Messages list cannot be empty")
    if not isinstance(messages, list):
        raise TypeError("Messages must be a list")

    client = set_credentials(key)

    if json_response and system:
        # Drop trailing periods first so the joined text does not read "..".
        system = f"{system.rstrip('.')}. Respond in JSON format."
    elif json_response:
        system = "Respond in JSON format."

    formatted_messages = [
        {
            "role": "user",  # Explicitly set role to "user"
            "content": [
                {
                    "type": "text",
                    "text": message.get("content", ""),
                }
            ],
        }
        for message in messages
    ]

    request_kwargs = {
        "model": model,
        "max_tokens": max_tokens,
        "temperature": temperature,
        "messages": formatted_messages,
    }
    # Only send `system` when there is one: an explicit None would be
    # forwarded as a null system prompt instead of being left out of the
    # request, which the Messages API does not treat as "no system prompt".
    if system is not None:
        request_kwargs["system"] = system

    return client.messages.create(**request_kwargs)
13 changes: 13 additions & 0 deletions skllm/llm/anthropic/credentials.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from anthropic import Anthropic


def set_credentials(key: str) -> "Anthropic":
    """Create an Anthropic client configured with the given key.

    Despite its name, this function stores nothing globally: it builds a
    new client on every call and hands it back to the caller.

    Parameters
    ----------
    key : str
        The Anthropic key to use.

    Returns
    -------
    Anthropic
        A client constructed with ``api_key=key``.
    """
    return Anthropic(api_key=key)
103 changes: 103 additions & 0 deletions skllm/llm/anthropic/mixin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
from typing import Optional, Union, Any, List, Dict, Mapping
from skllm.config import SKLLMConfig as _Config
from skllm.llm.anthropic.completion import get_chat_completion
from skllm.utils import extract_json_key
from skllm.llm.base import BaseTextCompletionMixin, BaseClassifierMixin
import json


class ClaudeMixin:
    """Mixin that stores a Claude API key and resolves it on demand."""

    _prefer_json_output = False

    def _set_keys(self, key: Optional[str] = None) -> None:
        """Remember the estimator-specific Claude API key (may be None)."""
        self.key = key

    def _get_claude_key(self) -> str:
        """Return the instance key, falling back to the global config.

        Raises
        ------
        RuntimeError
            If neither the instance nor the config provides a key.
        """
        instance_key = self.key
        if instance_key is not None:
            return instance_key

        config_key = _Config.get_anthropic_key()
        if config_key is None:
            raise RuntimeError("Claude API key was not found")
        return config_key

class ClaudeTextCompletionMixin(ClaudeMixin, BaseTextCompletionMixin):
    """Mixin adding text completion through the Claude API."""

    def _get_chat_completion(
        self,
        model: str,
        messages: Union[str, List[Dict[str, str]]],
        system_message: Optional[str] = None,
        **kwargs: Any,
    ):
        """Request a chat completion from the Anthropic API.

        Parameters
        ----------
        model : str
            The model to use.
        messages : Union[str, List[Dict[str, str]]]
            Either a single prompt string or a list of message dicts; in
            both cases only the content is kept.
        system_message : Optional[str]
            A system message to use.
        **kwargs : Any
            Forwarded unchanged to ``get_chat_completion``.

        Returns
        -------
        completion
            Whatever ``get_chat_completion`` returns.
        """
        if isinstance(messages, str):
            # A bare string becomes a single content-only message.
            payload = [{"content": messages}]
        elif isinstance(messages, list):
            # Keep only the content of each message.
            payload = [{"content": entry["content"]} for entry in messages]
        else:
            payload = messages

        return get_chat_completion(
            messages=payload,
            key=self._get_claude_key(),
            model=model,
            system=system_message,
            json_response=self._prefer_json_output,
            **kwargs,
        )

    def _convert_completion_to_str(self, completion: Mapping[str, Any]):
        """Return the text of the first content block of a completion.

        Works with objects exposing a ``content`` attribute as well as with
        mappings. Any failure is printed and an empty string is returned.
        """
        try:
            if not hasattr(completion, "content"):
                first_block = completion.get("content", [{}])[0]
                return first_block.get("text", "")
            return completion.content[0].text
        except Exception as e:
            print(f"Error converting completion to string: {str(e)}")
            return ""

class ClaudeClassifierMixin(ClaudeTextCompletionMixin, BaseClassifierMixin):
    """Mixin adding classification through the Claude API."""

    _prefer_json_output = True

    def _extract_out_label(self, completion: Mapping[str, Any], **kwargs) -> str:
        """Pull the predicted label out of a Claude API completion.

        With JSON output preferred, the completion text is parsed as JSON
        and its ``"label"`` entry returned; text that is not valid JSON, or
        that lacks the entry, yields an empty string. Otherwise the stripped
        completion text is returned. Any unexpected failure is printed and
        an empty string is returned.
        """
        try:
            text = self._convert_completion_to_str(completion)
            if self._prefer_json_output:
                # Attempt to parse the text as JSON and extract the label.
                try:
                    parsed = json.loads(text)
                    if "label" in parsed:
                        return parsed["label"]
                except json.JSONDecodeError:
                    pass
                return ""
            return text.strip()
        except Exception as e:
            print(f"Error extracting label: {str(e)}")
            return ""

142 changes: 142 additions & 0 deletions skllm/models/anthropic/classification/few_shot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
from skllm.models._base.classifier import (
BaseFewShotClassifier,
BaseDynamicFewShotClassifier,
SingleLabelMixin,
MultiLabelMixin,
)
from skllm.llm.anthropic.mixin import ClaudeClassifierMixin
from skllm.models.gpt.vectorization import GPTVectorizer
from skllm.models._base.vectorizer import BaseVectorizer
from skllm.memory.base import IndexConstructor
from typing import Optional


class FewShotClaudeClassifier(BaseFewShotClassifier, ClaudeClassifierMixin, SingleLabelMixin):
    """Single-label few-shot text classifier backed by Anthropic's Claude API."""

    def __init__(
        self,
        model: str = "claude-3-haiku-20240307",
        default_label: str = "Random",
        prompt_template: Optional[str] = None,
        key: Optional[str] = None,
        **kwargs,
    ):
        """
        Create a few-shot classifier that queries Anthropic's Claude API.

        Parameters
        ----------
        model : str, optional
            model to use, by default "claude-3-haiku-20240307"
        default_label : str, optional
            default label for failed prediction; if "Random" -> selects randomly based on class frequencies
        prompt_template : Optional[str], optional
            custom prompt template to use, by default None
        key : Optional[str], optional
            estimator-specific API key; if None, retrieved from the global config
        **kwargs
            forwarded unchanged to the base classifier constructor
        """
        base_arguments = dict(
            model=model,
            default_label=default_label,
            prompt_template=prompt_template,
        )
        super().__init__(**base_arguments, **kwargs)
        # The key is kept on the instance; it is not passed to the base class.
        self._set_keys(key)


class MultiLabelFewShotClaudeClassifier(
    BaseFewShotClassifier, ClaudeClassifierMixin, MultiLabelMixin
):
    """Multi-label few-shot text classifier backed by Anthropic's Claude API."""

    def __init__(
        self,
        model: str = "claude-3-haiku-20240307",
        default_label: str = "Random",
        max_labels: Optional[int] = 5,
        prompt_template: Optional[str] = None,
        key: Optional[str] = None,
        **kwargs,
    ):
        """
        Create a multi-label few-shot classifier that queries Anthropic's Claude API.

        Parameters
        ----------
        model : str, optional
            model to use, by default "claude-3-haiku-20240307"
        default_label : str, optional
            default label for failed prediction; if "Random" -> selects randomly based on class frequencies
        max_labels : Optional[int], optional
            maximum labels per sample, by default 5
        prompt_template : Optional[str], optional
            custom prompt template to use, by default None
        key : Optional[str], optional
            estimator-specific API key; if None, retrieved from the global config
        **kwargs
            forwarded unchanged to the base classifier constructor
        """
        base_arguments = dict(
            model=model,
            default_label=default_label,
            max_labels=max_labels,
            prompt_template=prompt_template,
        )
        super().__init__(**base_arguments, **kwargs)
        # The key is kept on the instance; it is not passed to the base class.
        self._set_keys(key)


class DynamicFewShotClaudeClassifier(
    BaseDynamicFewShotClassifier, ClaudeClassifierMixin, SingleLabelMixin
):
    """
    Dynamic few-shot text classifier using Anthropic's Claude API for
    single-label classification tasks with dynamic example selection using GPT embeddings.
    """

    def __init__(
        self,
        model: str = "claude-3-haiku-20240307",
        default_label: str = "Random",
        prompt_template: Optional[str] = None,
        key: Optional[str] = None,
        n_examples: int = 3,
        memory_index: Optional[IndexConstructor] = None,
        vectorizer: Optional[BaseVectorizer] = None,
        metric: Optional[str] = "euclidean",
        **kwargs,
    ):
        """
        Dynamic few-shot text classifier using Anthropic's Claude API.
        For each sample, N closest examples are retrieved from the memory.

        Parameters
        ----------
        model : str, optional
            model to use, by default "claude-3-haiku-20240307"
        default_label : str, optional
            default label for failed prediction; if "Random" -> selects randomly based on class frequencies
        prompt_template : Optional[str], optional
            custom prompt template to use, by default None
        key : Optional[str], optional
            estimator-specific Claude API key; if None, retrieved from the global config.
            It is not handed to the default vectorizer.
        n_examples : int, optional
            number of closest examples per class to be retrieved, by default 3
        memory_index : Optional[IndexConstructor], optional
            custom memory index, for details check `skllm.memory` submodule
        vectorizer : Optional[BaseVectorizer], optional
            scikit-llm vectorizer; if None, a `GPTVectorizer` using the
            "text-embedding-ada-002" model is created with no explicit key
        metric : Optional[str], optional
            metric used for similarity search, by default "euclidean"
        **kwargs
            accepted for signature compatibility; not used by this constructor
        """
        if vectorizer is None:
            # `key` is the Claude key (it is what `_set_keys` stores and what
            # the Claude requests use), so it must not be given to the GPT
            # embedding model. Passing None is the same call the vectorizer
            # already receives whenever no estimator-specific key is set.
            vectorizer = GPTVectorizer(model="text-embedding-ada-002", key=None)
        # NOTE(review): **kwargs is not forwarded here, unlike in the other
        # Claude few-shot classifiers -- confirm whether the base constructor
        # is meant to receive it.
        super().__init__(
            model=model,
            default_label=default_label,
            prompt_template=prompt_template,
            n_examples=n_examples,
            memory_index=memory_index,
            vectorizer=vectorizer,
            metric=metric,
        )
        self._set_keys(key)
Loading
Loading