In [None]:
# Define Enum for LLMProvider Options
from enum import Enum
class LLMProviderOptions(Enum):
    """Identifiers for the LLM backends an agent can be created with.

    Each member's value is the lowercase provider name string, so a member
    can be looked up from its value, e.g. ``LLMProviderOptions("openai")``.
    """

    OPENAI = "openai"
    GEMINI = "gemini"
    # NOTE: no provider class is registered for this option yet.
    ANTHROPIC = "anthropic"

In [None]:
# Base config shared by all providers; provider-specific configs extend it
from pydantic import BaseModel, ConfigDict
from typing import Optional

class BaseAgentConfig(BaseModel):
    # Settings shared by every provider config; provider-specific configs
    # subclass this and add their own optional fields.
    # Documented with comments instead of a docstring on purpose: pydantic
    # copies a class docstring into model_json_schema(), which
    # LLMProvider.supported_config() returns to callers.
    model_config = ConfigDict(extra="forbid")  # unknown fields fail validation
    model: str  # required model name (presumably provider-specific - confirm)
    temperature: Optional[float] = None  # optional; defaults to None
    max_tokens: Optional[int] = None  # optional; defaults to None

In [None]:
# Provider Specific Config Extensions
from typing import Dict, Any

class OpenAIConfig(BaseAgentConfig):
    # Config accepted by OpenAIProvider (its config_schema). Inherits model,
    # temperature, max_tokens and extra="forbid" from BaseAgentConfig.
    # Comments rather than a docstring so the JSON schema stays unchanged.
    top_p: Optional[float] = None  # optional; defaults to None
    presence_penalty: Optional[float] = None  # optional; defaults to None
    frequency_penalty: Optional[float] = None  # optional; defaults to None


class GeminiConfig(BaseAgentConfig):
    # Config accepted by GeminiProvider (its config_schema). Inherits model,
    # temperature, max_tokens and extra="forbid" from BaseAgentConfig.
    # Comments rather than a docstring so the JSON schema stays unchanged.
    safety_settings: Optional[Dict[str, Any]] = None  # free-form dict; shape not validated here
    top_k: Optional[int] = None  # optional; defaults to None

In [None]:
# Model for messages
from typing import List, Literal
from pydantic import BaseModel

class Message(BaseModel):
    # One conversation turn; LLMProvider.history is a list of these.
    # Comments rather than a docstring so the JSON schema stays unchanged.
    role: Literal["system", "user", "assistant", "tool"]  # any other role fails validation
    content: str  # text of the turn

In [None]:
# Abstract base class that every LLM provider implements
from abc import ABC, abstractmethod
from typing import Generator, Dict, Any, List, Type

class LLMProvider(ABC):
    """Abstract base class for chat-style LLM providers.

    An instance holds a provider config, an optional system message, and the
    conversation history as a list of ``Message`` objects. Concrete
    subclasses assign ``config_schema`` and implement the generation APIs
    plus ``capability_metadata``.
    """

    # Pydantic model class for this provider's config; each concrete subclass
    # assigns it, and ``supported_config`` exposes its JSON schema.
    config_schema: Type[BaseAgentConfig]

    def __init__(self, config: BaseAgentConfig, system_message: str):
        self.config = config
        # Routed through the ``system_message`` property setter so the value
        # lives in ``_system_message``, the single attribute that
        # ``set_system_message`` and ``_build_messages_payload`` also use.
        self.system_message = system_message
        self.history: List[Message] = []

    @property
    def system_message(self) -> Optional[str]:
        """System prompt injected ahead of the history (``None`` if unset)."""
        return self._system_message

    @system_message.setter
    def system_message(self, msg: Optional[str]) -> None:
        self._system_message = msg

    # -------------------------
    # History Management
    # -------------------------
    def add_user_message(self, content: str) -> None:
        """Append a user turn to the history."""
        self.history.append(Message(role="user", content=content))

    def add_assistant_message(self, content: str) -> None:
        """Append an assistant turn to the history."""
        self.history.append(Message(role="assistant", content=content))

    def rollback(self, steps: int = 1) -> None:
        """Drop the last ``steps`` messages from the history.

        The system message is stored separately and is never removed.
        A non-positive ``steps`` is a no-op; a value larger than the history
        length clears the history.
        """
        if steps > 0:
            self.history = self.history[:-steps]

    def save_history(self) -> List[Message]:
        """Return a shallow copy of the history list.

        The list is new, but the ``Message`` objects inside it are shared.
        """
        return self.history.copy()

    def compress_history(self, summarizer: "LLMProvider") -> None:
        """Replace the history with a single assistant-role summary.

        ``summarizer.generate`` receives a transcript of the current history.
        Any messages that call records go into the summarizer's own history;
        passing ``self`` is fine because the history is replaced afterwards.
        An empty history is left as-is and the summarizer is not called.
        """
        if not self.history:
            # Nothing to summarize; avoid a pointless (and possibly costly) call.
            return
        transcript = "\n\n".join(f"{m.role}: {m.content}" for m in self.history)
        summary = summarizer.generate(
            "Summarize the following conversation:\n" + transcript
        )
        self.history = [Message(role="assistant", content=summary)]

    # -------------------------
    # Messages Management
    # -------------------------

    def set_system_message(self, msg: Optional[str]) -> None:
        """Replace the system message at runtime (held in memory only)."""
        self._system_message = msg

    # helpers to build provider payload (default behavior)
    def _build_messages_payload(self) -> List[Dict[str, str]]:
        """Return the history as a list of ``{"role", "content"}`` dicts.

        This is the default neutral representation: the system message, when
        set and non-empty, is injected as the first entry. Child providers
        may override this to fit their API.
        """
        msgs = []
        if self._system_message:
            msgs.append({"role": "system", "content": self._system_message})
        for m in self.history:
            msgs.append({"role": m.role, "content": m.content})
        return msgs

    # -------------------------
    # Generation APIs
    # -------------------------
    @abstractmethod
    def generate(self, prompt: str, **kwargs) -> str:
        """Return a complete text reply to ``prompt``."""

    @abstractmethod
    def generate_stream(self, prompt: str, **kwargs) -> Generator[str, None, None]:
        """Yield the reply to ``prompt`` as text chunks."""

    @abstractmethod
    def generate_structured(
        self,
        prompt: str,
        response_schema: Dict[str, Any],
        **kwargs
    ) -> Dict[str, Any]:
        """Return a reply to ``prompt`` shaped by ``response_schema``."""

    @abstractmethod
    def generate_with_tools(
        self,
        prompt: str,
        tools: List[BaseModel],
        enforce_tool_use: bool = False,
        stream: bool = False,
        **kwargs
    ):
        """Generate a reply to ``prompt`` that may invoke ``tools``."""

    # -------------------------
    # Introspection
    # -------------------------
    @classmethod
    def supported_config(cls) -> Dict[str, Any]:
        """Return the JSON schema of ``config_schema`` (the valid config fields)."""
        return cls.config_schema.model_json_schema()

    @classmethod
    @abstractmethod
    def capability_metadata(cls) -> Dict[str, Any]:
        """Describe which features this provider supports."""

In [None]:
class OpenAIProvider(LLMProvider):
    """OpenAI-backed provider; the SDK calls are still stubbed out.

    Implements every abstract method of ``LLMProvider``. Without them,
    instantiating the class (and so ``create_agent``) raises ``TypeError``.
    """

    config_schema = OpenAIConfig

    def generate(self, prompt: str, **kwargs) -> str:
        """Record ``prompt``, produce a reply, record the reply, and return it."""
        self.add_user_message(prompt)
        # OpenAI SDK call here
        response = "openai-response"
        self.add_assistant_message(response)
        return response

    def generate_stream(self, prompt: str, **kwargs) -> Generator[str, None, None]:
        """Yield reply chunks for ``prompt`` (stub).

        NOTE(review): unlike ``generate``, this does not record the exchange
        in ``history``.
        """
        yield "openai-stream"

    def generate_structured(self, prompt, response_schema, **kwargs):
        """Return a schema-shaped reply (stub: always an empty dict)."""
        return {}

    def generate_with_tools(
        self,
        prompt: str,
        tools: List[BaseModel],
        enforce_tool_use: bool = False,
        stream: bool = False,
        **kwargs,
    ):
        """Generate a reply that may call ``tools`` (stub: always an empty dict)."""
        return {}

    def generate_tool_calls(self, prompt, tools, stream=False, **kwargs):
        """Backward-compatible alias that delegates to ``generate_with_tools``."""
        return self.generate_with_tools(prompt, tools, stream=stream, **kwargs)

    @classmethod
    def capability_metadata(cls) -> Dict[str, Any]:
        """Describe which generation APIs this provider exposes."""
        # Placeholder flags mirroring the methods defined on this class;
        # revisit once the real SDK calls are wired in.
        return {
            "provider": "openai",
            "streaming": True,
            "structured_output": True,
            "tool_calling": True,
        }


In [None]:
class GeminiProvider(LLMProvider):
    """Gemini-backed provider; the SDK calls are still stubbed out.

    Implements every abstract method of ``LLMProvider``. Without them,
    instantiating the class (and so ``create_agent``) raises ``TypeError``.
    """

    config_schema = GeminiConfig

    def generate(self, prompt: str, **kwargs) -> str:
        """Record ``prompt``, produce a reply, record the reply, and return it."""
        self.add_user_message(prompt)
        # Gemini SDK call here
        response = "gemini-response"
        self.add_assistant_message(response)
        return response

    def generate_stream(self, prompt: str, **kwargs) -> Generator[str, None, None]:
        """Yield reply chunks for ``prompt`` (stub).

        NOTE(review): unlike ``generate``, this does not record the exchange
        in ``history``.
        """
        yield "gemini-stream"

    def generate_structured(self, prompt, response_schema, **kwargs):
        """Return a schema-shaped reply (stub: always an empty dict)."""
        return {}

    def generate_with_tools(
        self,
        prompt: str,
        tools: List[BaseModel],
        enforce_tool_use: bool = False,
        stream: bool = False,
        **kwargs,
    ):
        """Generate a reply that may call ``tools`` (stub: always an empty dict)."""
        return {}

    def generate_tool_calls(self, prompt, tools, stream=False, **kwargs):
        """Backward-compatible alias that delegates to ``generate_with_tools``."""
        return self.generate_with_tools(prompt, tools, stream=stream, **kwargs)

    @classmethod
    def capability_metadata(cls) -> Dict[str, Any]:
        """Describe which generation APIs this provider exposes."""
        # Placeholder flags mirroring the methods defined on this class;
        # revisit once the real SDK calls are wired in.
        return {
            "provider": "gemini",
            "streaming": True,
            "structured_output": True,
            "tool_calling": True,
        }


In [None]:
from typing import Type, Union

# Lookup table consulted by create_agent: provider option -> implementing
# class. LLMProviderOptions.ANTHROPIC has no entry yet.
PROVIDER_REGISTRY: Dict[LLMProviderOptions, Type[LLMProvider]] = {
    LLMProviderOptions.OPENAI: OpenAIProvider,  # validated with OpenAIConfig
    LLMProviderOptions.GEMINI: GeminiProvider,  # validated with GeminiConfig
}

In [None]:
def create_agent(
    provider: Union[LLMProviderOptions, str],
    config: BaseAgentConfig,
    system_message: str,
) -> LLMProvider:
    """Build the provider instance registered for ``provider``.

    Args:
        provider: An ``LLMProviderOptions`` member or its string value
            (e.g. ``"openai"``).
        config: Provider config; must be an instance of the selected
            provider class's ``config_schema``.
        system_message: System prompt handed to the provider.

    Returns:
        A new instance of the registered provider class.

    Raises:
        KeyError: ``provider`` is not a known option, or the option has no
            registered provider class (e.g. ``LLMProviderOptions.ANTHROPIC``).
        TypeError: ``config`` is not an instance of the provider's schema.
    """
    try:
        # Enum value lookup returns the member itself when given a member,
        # so both enum members and their string values are accepted.
        option = LLMProviderOptions(provider)
        provider_cls = PROVIDER_REGISTRY[option]
    except (ValueError, KeyError):
        available = ", ".join(opt.value for opt in PROVIDER_REGISTRY)
        # Raise KeyError (not ValueError) to keep the original failure type.
        raise KeyError(
            f"No provider registered for {provider!r}; available: {available}"
        ) from None

    # Runtime safety: ensure correct config type
    if not isinstance(config, provider_cls.config_schema):
        raise TypeError(
            f"{option.value} expects {provider_cls.config_schema.__name__}, "
            f"got {type(config).__name__}"
        )

    return provider_cls(config=config, system_message=system_message)
