feat: Allow users to configure LLM inference parameters per bot from the UI. Closes #166. (#303)

* feat: Allow users to configure LLM inference parameters per bot from the UI

* docs: Update README file

* chore: Fix linting errors

* chore: Fix linting errors

* chore: Update as per feedback comments

* chore: Fix linting errors

* chore: Update based on PR feedback
jessieweiyi committed May 16, 2024
1 parent 116edcf commit de4f042
Showing 26 changed files with 982 additions and 46 deletions.
7 changes: 3 additions & 4 deletions README.md
@@ -211,13 +211,12 @@ Update `enableMistral` to `true` in [cdk.json](./cdk/cdk.json), and run `cdk deploy`
> [!Important]
> This project focuses on Anthropic Claude models; support for Mistral models is limited. For example, the prompt examples are based on Claude models. This is a Mistral-only option: once you enable Mistral models, you can use only Mistral models for all chat features, not both Claude and Mistral models.
### Configure text generation
### Configure default text generation

Edit [config.py](./backend/app/config.py) and run `cdk deploy`.
Users can adjust the [text generation parameters](https://docs.anthropic.com/claude/reference/complete_post) from the custom bot creation screen. If no custom bot is used, the default parameters defined in [config.py](./backend/app/config.py) apply.

```py
# See: https://docs.anthropic.com/claude/reference/complete_post
GENERATION_CONFIG = {
DEFAULT_GENERATION_CONFIG = {
"max_tokens": 2000,
"top_k": 250,
"top_p": 0.999,
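A minimal sketch of that fallback, with illustrative names (the app's actual resolution happens in `backend/app/bedrock.py`, shown below): per-bot parameters, when present, take precedence over these defaults via a dict spread.

```py
# Illustrative only: per-bot overrides win over the defaults.
DEFAULT_GENERATION_CONFIG = {
    "max_tokens": 2000,
    "top_k": 250,
    "top_p": 0.999,
}

def resolve_config(bot_params: dict | None) -> dict:
    # Dict spread: keys from bot_params replace the default values.
    return {**DEFAULT_GENERATION_CONFIG, **(bot_params or {})}

assert resolve_config(None)["max_tokens"] == 2000
assert resolve_config({"max_tokens": 500})["max_tokens"] == 500
```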
42 changes: 36 additions & 6 deletions backend/app/bedrock.py
@@ -6,10 +6,11 @@
from app.config import (
BEDROCK_PRICING,
DEFAULT_EMBEDDING_CONFIG,
GENERATION_CONFIG,
MISTRAL_GENERATION_CONFIG,
DEFAULT_GENERATION_CONFIG,
DEFAULT_MISTRAL_GENERATION_CONFIG,
)
from app.repositories.models.conversation import MessageModel
from app.repositories.models.custom_bot import GenerationParamsModel
from app.utils import get_bedrock_client, is_anthropic_model
from pydantic import BaseModel

@@ -32,21 +33,27 @@ def compose_args(
model: str,
instruction: str | None = None,
stream: bool = False,
generation_params: GenerationParamsModel | None = None,
) -> dict:
# if model is from Anthropic, use AnthropicBedrock
# otherwise, use bedrock client
model_id = get_model_id(model)
if is_anthropic_model(model_id):
return compose_args_for_anthropic_client(messages, model, instruction, stream)
return compose_args_for_anthropic_client(
messages, model, instruction, stream, generation_params
)
else:
return compose_args_for_other_client(messages, model, instruction, stream)
return compose_args_for_other_client(
messages, model, instruction, stream, generation_params
)


def compose_args_for_other_client(
messages: list[MessageModel],
model: str,
instruction: str | None = None,
stream: bool = False,
generation_params: GenerationParamsModel | None = None,
) -> dict:
arg_messages = []
for message in messages:
@@ -64,7 +71,18 @@ def compose_args_for_other_client(
arg_messages.append(m)

args = {
**MISTRAL_GENERATION_CONFIG,
**DEFAULT_MISTRAL_GENERATION_CONFIG,
**(
{
"max_tokens": generation_params.max_tokens,
"top_k": generation_params.top_k,
"top_p": generation_params.top_p,
"temperature": generation_params.temperature,
"stop_sequences": generation_params.stop_sequences,
}
if generation_params
else {}
),
"model": get_model_id(model),
"messages": arg_messages,
"stream": stream,
@@ -79,6 +97,7 @@ def compose_args_for_anthropic_client(
model: str,
instruction: str | None = None,
stream: bool = False,
generation_params: GenerationParamsModel | None = None,
) -> dict:
"""Compose arguments for Anthropic client.
Ref: https://docs.anthropic.com/claude/reference/messages_post
@@ -110,7 +129,18 @@
arg_messages.append(m)

args = {
**GENERATION_CONFIG,
**DEFAULT_GENERATION_CONFIG,
**(
{
"max_tokens": generation_params.max_tokens,
"top_k": generation_params.top_k,
"top_p": generation_params.top_p,
"temperature": generation_params.temperature,
"stop_sequences": generation_params.stop_sequences,
}
if generation_params
else {}
),
"model": get_model_id(model),
"messages": arg_messages,
"stream": stream,
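The five-key override dict above is duplicated between `compose_args_for_anthropic_client` and `compose_args_for_other_client`. A small helper could factor it out; this is a sketch based on the models in this diff, not code from the commit:

```py
from app.repositories.models.custom_bot import GenerationParamsModel


def generation_params_to_args(params: GenerationParamsModel | None) -> dict:
    """Convert optional per-bot parameters into keyword overrides (empty when unset)."""
    if params is None:
        return {}
    return {
        "max_tokens": params.max_tokens,
        "top_k": params.top_k,
        "top_p": params.top_p,
        "temperature": params.temperature,
        "stop_sequences": params.stop_sequences,
    }


# Each call site would then reduce to:
# args = {**DEFAULT_GENERATION_CONFIG, **generation_params_to_args(generation_params), ...}
```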
8 changes: 4 additions & 4 deletions backend/app/config.py
@@ -1,7 +1,7 @@
from typing import TypedDict


class GenerationConfig(TypedDict):
class GenerationParams(TypedDict):
max_tokens: int
top_k: int
top_p: float
@@ -18,7 +18,7 @@ class EmbeddingConfig(TypedDict):
# Configure generation parameters for Claude chat responses.
# Adjust the values according to your application.
# See: https://docs.anthropic.com/claude/reference/complete_post
GENERATION_CONFIG: GenerationConfig = {
DEFAULT_GENERATION_CONFIG: GenerationParams = {
"max_tokens": 2000,
"top_k": 250,
"top_p": 0.999,
@@ -27,7 +27,7 @@ class EmbeddingConfig(TypedDict):
}

# Ref: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-mistral.html#model-parameters-mistral-request-response
MISTRAL_GENERATION_CONFIG: GenerationConfig = {
DEFAULT_MISTRAL_GENERATION_CONFIG: GenerationParams = {
"max_tokens": 4096,
"top_k": 50,
"top_p": 0.9,
@@ -45,7 +45,7 @@ class EmbeddingConfig(TypedDict):
}

# Configure search parameters used to fetch relevant documents from the vector store.
SEARCH_CONFIG = {
DEFAULT_SEARCH_CONFIG = {
"max_results": 20,
}

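Since the defaults are annotated with the `GenerationParams` TypedDict, a type checker flags typos or missing keys when the defaults are edited. A hypothetical example (assuming the fields elided above match `GenerationParamsModel`):

```py
from app.config import GenerationParams

# A static checker (mypy/pyright) rejects this dict: "max_token" is not
# a declared key, and the other required keys are missing. At runtime a
# TypedDict performs no validation, so the typo would otherwise slip through.
BROKEN_CONFIG: GenerationParams = {
    "max_token": 2000,  # typo: should be "max_tokens"
}
```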
50 changes: 49 additions & 1 deletion backend/app/repositories/custom_bot.py
@@ -17,20 +17,34 @@
decompose_bot_alias_id,
decompose_bot_id,
)
from app.config import (
DEFAULT_GENERATION_CONFIG as DEFAULT_CLAUDE_GENERATION_CONFIG,
DEFAULT_MISTRAL_GENERATION_CONFIG,
DEFAULT_SEARCH_CONFIG,
)
from app.repositories.models.custom_bot import (
BotAliasModel,
BotMeta,
BotMetaWithStackInfo,
BotModel,
EmbeddingParamsModel,
KnowledgeModel,
GenerationParamsModel,
SearchParamsModel,
)
from app.routes.schemas.bot import type_sync_status
from app.utils import get_current_time
from boto3.dynamodb.conditions import Attr, Key
from botocore.exceptions import ClientError

TABLE_NAME = os.environ.get("TABLE_NAME", "")
ENABLE_MISTRAL = os.environ.get("ENABLE_MISTRAL", "") == "true"

DEFAULT_GENERATION_CONFIG = (
DEFAULT_MISTRAL_GENERATION_CONFIG
if ENABLE_MISTRAL
else DEFAULT_CLAUDE_GENERATION_CONFIG
)

logger = logging.getLogger(__name__)
sts_client = boto3.client("sts")
@@ -50,6 +64,8 @@ def store_bot(user_id: str, custom_bot: BotModel):
"LastBotUsed": decimal(custom_bot.last_used_time),
"IsPinned": custom_bot.is_pinned,
"EmbeddingParams": custom_bot.embedding_params.model_dump(),
"GenerationParams": custom_bot.generation_params.model_dump(),
"SearchParams": custom_bot.search_params.model_dump(),
"Knowledge": custom_bot.knowledge.model_dump(),
"SyncStatus": custom_bot.sync_status,
"SyncStatusReason": custom_bot.sync_status_reason,
@@ -70,6 +86,8 @@ def update_bot(
description: str,
instruction: str,
embedding_params: EmbeddingParamsModel,
generation_params: GenerationParamsModel,
search_params: SearchParamsModel,
knowledge: KnowledgeModel,
sync_status: type_sync_status,
sync_status_reason: str,
@@ -83,7 +101,7 @@
try:
response = table.update_item(
Key={"PK": user_id, "SK": compose_bot_id(user_id, bot_id)},
UpdateExpression="SET Title = :title, Description = :description, Instruction = :instruction,EmbeddingParams = :embedding_params, Knowledge = :knowledge, SyncStatus = :sync_status, SyncStatusReason = :sync_status_reason",
UpdateExpression="SET Title = :title, Description = :description, Instruction = :instruction,EmbeddingParams = :embedding_params, Knowledge = :knowledge, SyncStatus = :sync_status, SyncStatusReason = :sync_status_reason, GenerationParams = :generation_params, SearchParams = :search_params",
ExpressionAttributeValues={
":title": title,
":description": description,
Expand All @@ -92,6 +110,8 @@ def update_bot(
":embedding_params": embedding_params.model_dump(),
":sync_status": sync_status,
":sync_status_reason": sync_status_reason,
":generation_params": generation_params.model_dump(),
":search_params": search_params.model_dump(),
},
ReturnValues="ALL_NEW",
ConditionExpression="attribute_exists(PK) AND attribute_exists(SK)",
@@ -315,6 +335,20 @@ def find_private_bot_by_id(user_id: str, bot_id: str) -> BotModel:
else 200
),
),
generation_params=GenerationParamsModel(
**(
item["GenerationParams"]
if "GenerationParams" in item
else DEFAULT_GENERATION_CONFIG
)
),
search_params=SearchParamsModel(
max_results=(
item["SearchParams"]["max_results"]
if "SearchParams" in item
else DEFAULT_SEARCH_CONFIG["max_results"]
)
),
knowledge=KnowledgeModel(**item["Knowledge"]),
sync_status=item["SyncStatus"],
sync_status_reason=item["SyncStatusReason"],
@@ -374,6 +408,20 @@ def find_public_bot_by_id(bot_id: str) -> BotModel:
else 200
),
),
generation_params=GenerationParamsModel(
**(
item["GenerationParams"]
if "GenerationParams" in item
else DEFAULT_GENERATION_CONFIG
)
),
search_params=SearchParamsModel(
max_results=(
item["SearchParams"]["max_results"]
if "SearchParams" in item
else DEFAULT_SEARCH_CONFIG["max_results"]
)
),
knowledge=KnowledgeModel(**item["Knowledge"]),
sync_status=item["SyncStatus"],
sync_status_reason=item["SyncStatusReason"],
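Bot records stored before this change carry no `GenerationParams` or `SearchParams` attributes, so both lookups above fall back to the defaults (Claude or Mistral, chosen by the `ENABLE_MISTRAL` environment variable). A condensed sketch of that read path, using the names from this diff:

```py
from app.config import DEFAULT_SEARCH_CONFIG
from app.repositories.custom_bot import DEFAULT_GENERATION_CONFIG
from app.repositories.models.custom_bot import GenerationParamsModel, SearchParamsModel

# Illustrative: a bot item saved before this feature lacks GenerationParams.
legacy_item = {"Title": "my-bot"}

generation_params = GenerationParamsModel(
    **(
        legacy_item["GenerationParams"]
        if "GenerationParams" in legacy_item
        else DEFAULT_GENERATION_CONFIG  # resolved from ENABLE_MISTRAL above
    )
)
search_params = SearchParamsModel(
    max_results=(
        legacy_item["SearchParams"]["max_results"]
        if "SearchParams" in legacy_item
        else DEFAULT_SEARCH_CONFIG["max_results"]
    )
)
```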
10 changes: 10 additions & 0 deletions backend/app/repositories/models/common.py
@@ -0,0 +1,10 @@
from decimal import Decimal
from pydantic.functional_serializers import PlainSerializer
from typing_extensions import Annotated

# Custom float type that serializes to Decimal, as DynamoDB requires
Float = Annotated[
    # Note: apply str() before Decimal() to preserve the float's literal precision
float,
PlainSerializer(lambda v: Decimal(str(v)), return_type=Decimal),
]
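A quick check of what `Float` buys (assuming pydantic v2, which the `model_dump()` calls elsewhere in this diff imply): `model_dump()` emits `Decimal` values, which the DynamoDB resource API requires in place of Python floats, and round-tripping through `str()` keeps the literal precision.

```py
from pydantic import BaseModel

from app.repositories.models.common import Float


class Sample(BaseModel):
    top_p: Float


print(Sample(top_p=0.999).model_dump())  # {'top_p': Decimal('0.999')}

# Without the str() round-trip, Decimal(0.999) would capture the float's
# binary representation error instead of the literal value 0.999.
```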
15 changes: 15 additions & 0 deletions backend/app/repositories/models/custom_bot.py
@@ -1,5 +1,6 @@
from app.routes.schemas.bot import type_sync_status
from pydantic import BaseModel
from app.repositories.models.common import Float


class EmbeddingParamsModel(BaseModel):
@@ -13,6 +14,18 @@ class KnowledgeModel(BaseModel):
filenames: list[str]


class GenerationParamsModel(BaseModel):
max_tokens: int
top_k: int
top_p: Float
temperature: Float
stop_sequences: list[str]


class SearchParamsModel(BaseModel):
max_results: int


class BotModel(BaseModel):
id: str
title: str
@@ -25,6 +38,8 @@ class BotModel(BaseModel):
owner_user_id: str
is_pinned: bool
embedding_params: EmbeddingParamsModel
generation_params: GenerationParamsModel
search_params: SearchParamsModel
knowledge: KnowledgeModel
sync_status: type_sync_status
sync_status_reason: str
12 changes: 12 additions & 0 deletions backend/app/routes/bot.py
@@ -16,6 +16,8 @@
BotSwitchVisibilityInput,
EmbeddingParams,
Knowledge,
GenerationParams,
SearchParams,
)
from app.usecases.bot import (
create_new_bot,
@@ -133,6 +135,16 @@ def get_private_bot(request: Request, bot_id: str):
sitemap_urls=bot.knowledge.sitemap_urls,
filenames=bot.knowledge.filenames,
),
generation_params=GenerationParams(
max_tokens=bot.generation_params.max_tokens,
top_k=bot.generation_params.top_k,
top_p=bot.generation_params.top_p,
temperature=bot.generation_params.temperature,
stop_sequences=bot.generation_params.stop_sequences,
),
search_params=SearchParams(
max_results=bot.search_params.max_results,
),
sync_status=bot.sync_status,
sync_status_reason=bot.sync_status_reason,
sync_last_exec_id=bot.sync_last_exec_id,
20 changes: 20 additions & 0 deletions backend/app/routes/schemas/bot.py
@@ -18,6 +18,18 @@ class EmbeddingParams(BaseSchema):
chunk_overlap: int


class GenerationParams(BaseSchema):
max_tokens: int
top_k: int
top_p: float
temperature: float
stop_sequences: list[str]


class SearchParams(BaseSchema):
max_results: int


class Knowledge(BaseSchema):
source_urls: list[str]
sitemap_urls: list[str]
Expand All @@ -39,6 +51,8 @@ class BotInput(BaseSchema):
instruction: str
description: str | None
embedding_params: EmbeddingParams | None
generation_params: GenerationParams | None
search_params: SearchParams | None
knowledge: Knowledge | None


@@ -47,6 +61,8 @@ class BotModifyInput(BaseSchema):
instruction: str
description: str | None
embedding_params: EmbeddingParams | None
generation_params: GenerationParams | None
search_params: SearchParams | None
knowledge: KnowledgeDiffInput | None

def has_update_files(self) -> bool:
Expand Down Expand Up @@ -92,6 +108,8 @@ class BotModifyOutput(BaseSchema):
instruction: str
description: str
embedding_params: EmbeddingParams
generation_params: GenerationParams
search_params: SearchParams
knowledge: Knowledge


Expand All @@ -107,6 +125,8 @@ class BotOutput(BaseSchema):
# Whether the bot is owned by the user
owned: bool
embedding_params: EmbeddingParams
generation_params: GenerationParams
search_params: SearchParams
knowledge: Knowledge
sync_status: type_sync_status
sync_status_reason: str
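For illustration, a per-bot override as the new schemas would carry it (example values only; `BaseSchema`'s aliasing behavior is not shown in this diff):

```py
from app.routes.schemas.bot import GenerationParams, SearchParams

generation_params = GenerationParams(
    max_tokens=1500,
    top_k=100,
    top_p=0.9,
    temperature=0.5,
    stop_sequences=["Human:", "Assistant:"],
)
search_params = SearchParams(max_results=10)
```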