Feature/llm data structs (#3486)
* Organize all the llm stuff into a subpackage

* Add structs for interacting with llms
collijk committed Apr 28, 2023
1 parent c7d7564 commit b8478a9
Showing 9 changed files with 122 additions and 14 deletions.
3 changes: 1 addition & 2 deletions autogpt/agent/agent_manager.py
@@ -4,9 +4,8 @@
from typing import List

from autogpt.config.config import Config
-from autogpt.llm import create_chat_completion
+from autogpt.llm import Message, create_chat_completion
from autogpt.singleton import Singleton
-from autogpt.types.openai import Message


class AgentManager(metaclass=Singleton):
16 changes: 16 additions & 0 deletions autogpt/llm/__init__.py
@@ -1,4 +1,13 @@
from autogpt.llm.api_manager import ApiManager
+from autogpt.llm.base import (
+    ChatModelInfo,
+    ChatModelResponse,
+    EmbeddingModelInfo,
+    EmbeddingModelResponse,
+    LLMResponse,
+    Message,
+    ModelInfo,
+)
from autogpt.llm.chat import chat_with_ai, create_chat_message, generate_context
from autogpt.llm.llm_utils import (
    call_ai_function,
@@ -10,6 +19,13 @@

__all__ = [
    "ApiManager",
+    "Message",
+    "ModelInfo",
+    "ChatModelInfo",
+    "EmbeddingModelInfo",
+    "LLMResponse",
+    "ChatModelResponse",
+    "EmbeddingModelResponse",
    "create_chat_message",
    "generate_context",
    "chat_with_ai",
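For downstream modules, the practical effect of this hunk is that Message and the new model structs are importable directly from autogpt.llm rather than autogpt.types.openai. A minimal sketch of the new import surface (illustrative, not part of the diff; List is taken from typing to match the style of base.py):

from typing import List

from autogpt.llm import Message

messages: List[Message] = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Summarize this repository."},
]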
65 changes: 65 additions & 0 deletions autogpt/llm/base.py
@@ -0,0 +1,65 @@
from dataclasses import dataclass, field
from typing import List, TypedDict


class Message(TypedDict):
    """OpenAI Message object containing a role and the message content"""

    role: str
    content: str


@dataclass
class ModelInfo:
    """Struct for model information.
    Would be lovely to eventually get this directly from APIs, but needs to be scraped from
    websites for now.
    """

    name: str
    prompt_token_cost: float
    completion_token_cost: float
    max_tokens: int


@dataclass
class ChatModelInfo(ModelInfo):
    """Struct for chat model information."""

    pass


@dataclass
class EmbeddingModelInfo(ModelInfo):
    """Struct for embedding model information."""

    embedding_dimensions: int


@dataclass
class LLMResponse:
    """Standard response struct for a response from an LLM model."""

    model_info: ModelInfo
    prompt_tokens_used: int = 0
    completion_tokens_used: int = 0


@dataclass
class EmbeddingModelResponse(LLMResponse):
    """Standard response struct for a response from an embedding model."""

    embedding: List[float] = field(default_factory=list)

    def __post_init__(self):
        if self.completion_tokens_used:
            raise ValueError("Embeddings should not have completion tokens used.")


@dataclass
class ChatModelResponse(LLMResponse):
    """Standard response struct for a response from an LLM model."""

    content: str = None
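A hypothetical usage sketch (not part of this commit) showing how these structs compose: a chat call would hand back a ChatModelResponse carrying the ChatModelInfo it was served by plus its token usage, while EmbeddingModelResponse.__post_init__ rejects any non-zero completion token count. The token counts below are made up for illustration.

from autogpt.llm.base import (
    ChatModelInfo,
    ChatModelResponse,
    EmbeddingModelInfo,
    EmbeddingModelResponse,
)

gpt35 = ChatModelInfo(
    name="gpt-3.5-turbo",
    prompt_token_cost=0.002,
    completion_token_cost=0.002,
    max_tokens=4096,
)

# A chat response records which model answered and what the call cost in tokens.
reply = ChatModelResponse(
    model_info=gpt35,
    prompt_tokens_used=120,
    completion_tokens_used=45,
    content="Here is a summary ...",
)

ada = EmbeddingModelInfo(
    name="text-embedding-ada-002",
    prompt_token_cost=0.0004,
    completion_token_cost=0.0,
    max_tokens=8191,
    embedding_dimensions=1536,
)

# completion_tokens_used defaults to 0, so __post_init__ accepts this;
# passing a non-zero value would raise ValueError.
vector = EmbeddingModelResponse(model_info=ada, prompt_tokens_used=8, embedding=[0.0] * 1536)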
2 changes: 1 addition & 1 deletion autogpt/llm/chat.py
@@ -5,13 +5,13 @@

from autogpt.config import Config
from autogpt.llm.api_manager import ApiManager
+from autogpt.llm.base import Message
from autogpt.llm.llm_utils import create_chat_completion
from autogpt.llm.token_counter import count_message_tokens
from autogpt.logs import logger
from autogpt.memory_management.store_memory import (
    save_memory_trimmed_from_context_window,
)
-from autogpt.types.openai import Message

cfg = Config()

2 changes: 1 addition & 1 deletion autogpt/llm/llm_utils.py
@@ -10,8 +10,8 @@

from autogpt.config import Config
from autogpt.llm.api_manager import ApiManager
+from autogpt.llm.base import Message
from autogpt.logs import logger
-from autogpt.types.openai import Message


def retry_openai_api(
Empty file.
37 changes: 37 additions & 0 deletions autogpt/llm/providers/openai.py
@@ -0,0 +1,37 @@
from autogpt.llm.base import ChatModelInfo, EmbeddingModelInfo

OPEN_AI_CHAT_MODELS = {
    "gpt-3.5-turbo": ChatModelInfo(
        name="gpt-3.5-turbo",
        prompt_token_cost=0.002,
        completion_token_cost=0.002,
        max_tokens=4096,
    ),
    "gpt-4": ChatModelInfo(
        name="gpt-4",
        prompt_token_cost=0.03,
        completion_token_cost=0.06,
        max_tokens=8192,
    ),
    "gpt-4-32k": ChatModelInfo(
        name="gpt-4-32k",
        prompt_token_cost=0.06,
        completion_token_cost=0.12,
        max_tokens=32768,
    ),
}

OPEN_AI_EMBEDDING_MODELS = {
    "text-embedding-ada-002": EmbeddingModelInfo(
        name="text-embedding-ada-002",
        prompt_token_cost=0.0004,
        completion_token_cost=0.0,
        max_tokens=8191,
        embedding_dimensions=1536,
    ),
}

OPEN_AI_MODELS = {
    **OPEN_AI_CHAT_MODELS,
    **OPEN_AI_EMBEDDING_MODELS,
}
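These lookup tables make rough cost accounting straightforward. A hypothetical helper (not part of this commit), assuming the *_token_cost figures are USD per 1,000 tokens, which matches OpenAI's published pricing for these models at the time:

from autogpt.llm.providers.openai import OPEN_AI_MODELS


def estimate_cost_usd(model_name: str, prompt_tokens: int, completion_tokens: int = 0) -> float:
    """Estimate the USD cost of one call, assuming per-1,000-token pricing."""
    info = OPEN_AI_MODELS[model_name]
    return (
        prompt_tokens * info.prompt_token_cost
        + completion_tokens * info.completion_token_cost
    ) / 1000


# gpt-4 with 1,000 prompt + 500 completion tokens: (1000 * 0.03 + 500 * 0.06) / 1000 = 0.06
print(estimate_cost_usd("gpt-4", 1000, 500))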
2 changes: 1 addition & 1 deletion autogpt/llm/token_counter.py
@@ -5,8 +5,8 @@

import tiktoken

+from autogpt.llm.base import Message
from autogpt.logs import logger
-from autogpt.types.openai import Message


def count_message_tokens(
9 changes: 0 additions & 9 deletions autogpt/types/openai.py

This file was deleted.
