diff --git a/backend/app/config.py b/backend/app/config.py index 0bdc761a..4446fa47 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -13,6 +13,7 @@ class EmbeddingConfig(TypedDict): model_id: str chunk_size: int chunk_overlap: int + enable_partition_pdf: bool # Configure generation parameter for Claude chat response. @@ -42,6 +43,7 @@ class EmbeddingConfig(TypedDict): # NOTE: consider that cohere allows up to 2048 tokens per request "chunk_size": 1000, "chunk_overlap": 200, + "enable_partition_pdf": False, } # Configure search parameter to fetch relevant documents from vector store. diff --git a/backend/app/repositories/custom_bot.py b/backend/app/repositories/custom_bot.py index 9e3b3523..f0c8d88e 100644 --- a/backend/app/repositories/custom_bot.py +++ b/backend/app/repositories/custom_bot.py @@ -334,6 +334,12 @@ def find_private_bot_by_id(user_id: str, bot_id: str) -> BotModel: and "chunk_overlap" in item["EmbeddingParams"] else 200 ), + enable_partition_pdf=( + item["EmbeddingParams"]["enable_partition_pdf"] + if "EmbeddingParams" in item + and "enable_partition_pdf" in item["EmbeddingParams"] + else False + ), ), generation_params=GenerationParamsModel( **( @@ -407,6 +413,12 @@ def find_public_bot_by_id(bot_id: str) -> BotModel: and "chunk_overlap" in item["EmbeddingParams"] else 200 ), + enable_partition_pdf=( + item["EmbeddingParams"]["enable_partition_pdf"] + if "EmbeddingParams" in item + and "enable_partition_pdf" in item["EmbeddingParams"] + else False + ), ), generation_params=GenerationParamsModel( **( diff --git a/backend/app/repositories/models/custom_bot.py b/backend/app/repositories/models/custom_bot.py index 35d7ebed..84f06b00 100644 --- a/backend/app/repositories/models/custom_bot.py +++ b/backend/app/repositories/models/custom_bot.py @@ -6,6 +6,7 @@ class EmbeddingParamsModel(BaseModel): chunk_size: int chunk_overlap: int + enable_partition_pdf: bool class KnowledgeModel(BaseModel): diff --git a/backend/app/routes/bot.py b/backend/app/routes/bot.py index 5b45c27b..c4050fae 100644 --- a/backend/app/routes/bot.py +++ b/backend/app/routes/bot.py @@ -129,6 +129,7 @@ def get_private_bot(request: Request, bot_id: str): embedding_params=EmbeddingParams( chunk_size=bot.embedding_params.chunk_size, chunk_overlap=bot.embedding_params.chunk_overlap, + enable_partition_pdf=bot.embedding_params.enable_partition_pdf, ), knowledge=Knowledge( source_urls=bot.knowledge.source_urls, diff --git a/backend/app/routes/schemas/bot.py b/backend/app/routes/schemas/bot.py index 6ca9b4bf..b7212f8d 100644 --- a/backend/app/routes/schemas/bot.py +++ b/backend/app/routes/schemas/bot.py @@ -16,6 +16,7 @@ class EmbeddingParams(BaseSchema): chunk_size: int chunk_overlap: int + enable_partition_pdf: bool class GenerationParams(BaseSchema): @@ -94,6 +95,8 @@ def is_embedding_required(self, current_bot_model: BotModel) -> bool: == current_bot_model.embedding_params.chunk_size and self.embedding_params.chunk_overlap == current_bot_model.embedding_params.chunk_overlap + and self.embedding_params.enable_partition_pdf + == current_bot_model.embedding_params.enable_partition_pdf ): pass else: diff --git a/backend/app/usecases/bot.py b/backend/app/usecases/bot.py index de8f109c..dcf10d4c 100644 --- a/backend/app/usecases/bot.py +++ b/backend/app/usecases/bot.py @@ -131,6 +131,12 @@ def create_new_bot(user_id: str, bot_input: BotInput) -> BotOutput: else DEFAULT_EMBEDDING_CONFIG["chunk_overlap"] ) + enable_partition_pdf = ( + bot_input.embedding_params.enable_partition_pdf + if bot_input.embedding_params + else DEFAULT_EMBEDDING_CONFIG["enable_partition_pdf"] + ) + generation_params = ( bot_input.generation_params.model_dump() if bot_input.generation_params @@ -158,6 +164,7 @@ def create_new_bot(user_id: str, bot_input: BotInput) -> BotOutput: embedding_params=EmbeddingParamsModel( chunk_size=chunk_size, chunk_overlap=chunk_overlap, + enable_partition_pdf=enable_partition_pdf, ), generation_params=GenerationParamsModel(**generation_params), search_params=SearchParamsModel(**search_params), @@ -185,6 +192,7 @@ def create_new_bot(user_id: str, bot_input: BotInput) -> BotOutput: embedding_params=EmbeddingParams( chunk_size=chunk_size, chunk_overlap=chunk_overlap, + enable_partition_pdf=enable_partition_pdf, ), generation_params=GenerationParams(**generation_params), search_params=SearchParams(**search_params), @@ -239,6 +247,12 @@ def modify_owned_bot( else DEFAULT_EMBEDDING_CONFIG["chunk_overlap"] ) + enable_partition_pdf = ( + modify_input.embedding_params.enable_partition_pdf + if modify_input.embedding_params + else DEFAULT_EMBEDDING_CONFIG["enable_partition_pdf"] + ) + generation_params = ( modify_input.generation_params.model_dump() if modify_input.generation_params @@ -266,6 +280,7 @@ def modify_owned_bot( embedding_params=EmbeddingParamsModel( chunk_size=chunk_size, chunk_overlap=chunk_overlap, + enable_partition_pdf=enable_partition_pdf, ), generation_params=GenerationParamsModel(**generation_params), search_params=SearchParamsModel(**search_params), @@ -286,6 +301,7 @@ def modify_owned_bot( embedding_params=EmbeddingParams( chunk_size=chunk_size, chunk_overlap=chunk_overlap, + enable_partition_pdf=enable_partition_pdf, ), generation_params=GenerationParams(**generation_params), search_params=SearchParams(**search_params), diff --git a/backend/embedding.requirements.txt b/backend/embedding.requirements.txt index 26a3e400..4470fef7 100644 --- a/backend/embedding.requirements.txt +++ b/backend/embedding.requirements.txt @@ -15,3 +15,5 @@ unstructured[docx]==0.12.6 unstructured[xlsx]==0.12.6 unstructured[pptx]==0.12.6 unstructured[md]==0.12.6 +retry==0.9.2 +types-retry==0.9.9.4 \ No newline at end of file diff --git a/backend/embedding/loaders/s3.py b/backend/embedding/loaders/s3.py index b8d6da55..148aac9f 100644 --- a/backend/embedding/loaders/s3.py +++ b/backend/embedding/loaders/s3.py @@ -1,9 +1,14 @@ import os import tempfile - +import logging import boto3 +from distutils.util import strtobool from embedding.loaders.base import BaseLoader, Document from unstructured.partition.auto import partition +from unstructured.partition.pdf import partition_pdf + +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) class S3FileLoader(BaseLoader): @@ -11,10 +16,17 @@ class S3FileLoader(BaseLoader): Reference: `langchain_community.document_loaders.S3FileLoader` class """ - def __init__(self, bucket: str, key: str, mode: str = "single"): + def __init__( + self, + bucket: str, + key: str, + mode: str = "single", + enable_partition_pdf: bool = False, + ): self.bucket = bucket self.key = key self.mode = mode + self.enable_partition_pdf = enable_partition_pdf def _get_elements(self) -> list: """Get elements.""" @@ -23,7 +35,19 @@ def _get_elements(self) -> list: file_path = f"{temp_dir}/{self.key}" os.makedirs(os.path.dirname(file_path), exist_ok=True) s3.download_file(self.bucket, self.key, file_path) - return partition(filename=file_path) + extension = os.path.splitext(file_path)[1] + + if extension == ".pdf" and self.enable_partition_pdf == True: + logger.info(f"Start partitioning using hi-resolution mode: {file_path}") + return partition_pdf( + filename=file_path, + strategy="hi_res", + infer_table_structure=True, + extract_images_in_pdf=False, + ) + else: + logger.info(f"Start partitioning using auto mode: {file_path}") + return partition(filename=file_path) def _get_metadata(self) -> dict: return {"source": f"s3://{self.bucket}/{self.key}"} diff --git a/backend/embedding/main.py b/backend/embedding/main.py index b146bc9e..bb722c17 100644 --- a/backend/embedding/main.py +++ b/backend/embedding/main.py @@ -1,10 +1,14 @@ import argparse import json import logging +import multiprocessing +from multiprocessing.managers import ListProxy import os +from typing import Any import pg8000 import requests +from retry import retry from app.config import DEFAULT_EMBEDDING_CONFIG from app.repositories.common import _get_table_client, RecordNotFoundError @@ -24,9 +28,13 @@ from ulid import ULID logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) BEDROCK_REGION = os.environ.get("BEDROCK_REGION", "us-east-1") - +RETRIES_TO_INSERT_TO_POSTGRES = 4 +RETRY_DELAY_TO_INSERT_TO_POSTGRES = 2 +RETRIES_TO_UPDATE_SYNC_STATUS = 4 +RETRY_DELAY_TO_UPDATE_SYNC_STATUS = 2 DB_NAME = os.environ.get("DB_NAME", "postgres") DB_HOST = os.environ.get("DB_HOST", "") @@ -48,8 +56,9 @@ def get_exec_id() -> str: return task_id +@retry(tries=RETRIES_TO_INSERT_TO_POSTGRES, delay=RETRY_DELAY_TO_INSERT_TO_POSTGRES) def insert_to_postgres( - bot_id: str, contents: list[str], sources: list[str], embeddings: list[list[float]] + bot_id: str, contents: ListProxy, sources: ListProxy, embeddings: ListProxy ): conn = pg8000.connect( database=DB_NAME, @@ -70,13 +79,13 @@ def insert_to_postgres( zip(sources, contents, embeddings) ): id_ = str(ULID()) - print(f"Preview of content {i}: {content[:200]}") + logger.info(f"Preview of content {i}: {content[:200]}") values_to_insert.append( (id_, bot_id, content, source, json.dumps(embedding)) ) cursor.executemany(insert_query, values_to_insert) conn.commit() - print(f"Successfully inserted {len(values_to_insert)} records.") + logger.info(f"Successfully inserted {len(values_to_insert)} records.") except Exception as e: conn.rollback() raise e @@ -84,6 +93,7 @@ def insert_to_postgres( conn.close() +@retry(tries=RETRIES_TO_UPDATE_SYNC_STATUS, delay=RETRY_DELAY_TO_UPDATE_SYNC_STATUS) def update_sync_status( user_id: str, bot_id: str, @@ -105,9 +115,9 @@ def update_sync_status( def embed( loader: BaseLoader, - contents: list[str], - sources: list[str], - embeddings: list[list[float]], + contents: ListProxy, + sources: ListProxy, + embeddings: ListProxy, chunk_size: int, chunk_overlap: int, ): @@ -139,12 +149,13 @@ def main( filenames: list[str], chunk_size: int, chunk_overlap: int, + enable_partition_pdf: bool, ): exec_id = "" try: exec_id = get_exec_id() except Exception as e: - print(f"[ERROR] Failed to get exec_id: {e}") + logger.error(f"[ERROR] Failed to get exec_id: {e}") exec_id = "FAILED_TO_GET_ECS_EXEC_ID" update_sync_status( @@ -169,44 +180,56 @@ def main( return # Calculate embeddings using LangChain - contents: list[str] = [] - sources: list[str] = [] - embeddings: list[list[float]] = [] - - if len(source_urls) > 0: - embed( - UrlLoader(source_urls), - contents, - sources, - embeddings, - chunk_size, - chunk_overlap, - ) - if len(sitemap_urls) > 0: - for sitemap_url in sitemap_urls: - raise NotImplementedError() - if len(filenames) > 0: - for filename in filenames: + with multiprocessing.Manager() as manager: + contents: ListProxy = manager.list() + sources: ListProxy = manager.list() + embeddings: ListProxy = manager.list() + + if len(source_urls) > 0: embed( - S3FileLoader( - bucket=DOCUMENT_BUCKET, - key=compose_upload_document_s3_path(user_id, bot_id, filename), - ), + UrlLoader(source_urls), contents, sources, embeddings, chunk_size, chunk_overlap, ) - - print(f"Number of chunks: {len(contents)}") - - # Insert records into postgres - insert_to_postgres(bot_id, contents, sources, embeddings) - status_reason = "Successfully inserted to vector store." + if len(sitemap_urls) > 0: + for sitemap_url in sitemap_urls: + raise NotImplementedError() + if len(filenames) > 0: + with multiprocessing.Pool(processes=None) as pool: + futures = [ + pool.apply_async( + embed, + args=( + S3FileLoader( + bucket=DOCUMENT_BUCKET, + key=compose_upload_document_s3_path( + user_id, bot_id, filename + ), + enable_partition_pdf=enable_partition_pdf, + ), + contents, + sources, + embeddings, + chunk_size, + chunk_overlap, + ), + ) + for filename in filenames + ] + for future in futures: + future.get() + + logger.info(f"Number of chunks: {len(contents)}") + + # Insert records into postgres + insert_to_postgres(bot_id, contents, sources, embeddings) + status_reason = "Successfully inserted to vector store." except Exception as e: - print("[ERROR] Failed to embed.") - print(e) + logger.error("[ERROR] Failed to embed.") + logger.error(e) update_sync_status( user_id, bot_id, @@ -243,17 +266,26 @@ def main( embedding_params = new_image.embedding_params chunk_size = embedding_params.chunk_size chunk_overlap = embedding_params.chunk_overlap + enable_partition_pdf = embedding_params.enable_partition_pdf knowledge = new_image.knowledge sitemap_urls = knowledge.sitemap_urls source_urls = knowledge.source_urls filenames = knowledge.filenames - print(f"source_urls to crawl: {source_urls}") - print(f"sitemap_urls to crawl: {sitemap_urls}") - print(f"filenames: {filenames}") - print(f"chunk_size: {chunk_size}") - print(f"chunk_overlap: {chunk_overlap}") + logger.info(f"source_urls to crawl: {source_urls}") + logger.info(f"sitemap_urls to crawl: {sitemap_urls}") + logger.info(f"filenames: {filenames}") + logger.info(f"chunk_size: {chunk_size}") + logger.info(f"chunk_overlap: {chunk_overlap}") + logger.info(f"enable_partition_pdf: {enable_partition_pdf}") main( - user_id, bot_id, sitemap_urls, source_urls, filenames, chunk_size, chunk_overlap + user_id, + bot_id, + sitemap_urls, + source_urls, + filenames, + chunk_size, + chunk_overlap, + enable_partition_pdf, ) diff --git a/backend/requirements.txt b/backend/requirements.txt index 2c1c54ef..428c7c24 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -11,3 +11,5 @@ pg8000==1.30.3 argparse==1.4.0 anthropic==0.18.1 anthropic[bedrock]==0.18.1 +retry==0.9.2 +types-retry==0.9.9.4 \ No newline at end of file diff --git a/backend/tests/test_repositories/test_conversation.py b/backend/tests/test_repositories/test_conversation.py index 98db72f3..fcd452dd 100644 --- a/backend/tests/test_repositories/test_conversation.py +++ b/backend/tests/test_repositories/test_conversation.py @@ -368,6 +368,7 @@ def setUp(self) -> None: embedding_params=EmbeddingParamsModel( chunk_size=DEFAULT_EMBEDDING_CONFIG["chunk_size"], chunk_overlap=DEFAULT_EMBEDDING_CONFIG["chunk_overlap"], + enable_partition_pdf=DEFAULT_EMBEDDING_CONFIG["enable_partition_pdf"], ), generation_params=GenerationParamsModel( max_tokens=2000, @@ -404,6 +405,7 @@ def setUp(self) -> None: embedding_params=EmbeddingParamsModel( chunk_size=DEFAULT_EMBEDDING_CONFIG["chunk_size"], chunk_overlap=DEFAULT_EMBEDDING_CONFIG["chunk_overlap"], + enable_partition_pdf=DEFAULT_EMBEDDING_CONFIG["enable_partition_pdf"], ), generation_params=GenerationParamsModel( max_tokens=2000, diff --git a/backend/tests/test_repositories/test_custom_bot.py b/backend/tests/test_repositories/test_custom_bot.py index e3c96d6e..dda88b29 100644 --- a/backend/tests/test_repositories/test_custom_bot.py +++ b/backend/tests/test_repositories/test_custom_bot.py @@ -48,6 +48,7 @@ def test_store_and_find_bot(self): embedding_params=EmbeddingParamsModel( chunk_size=DEFAULT_EMBEDDING_CONFIG["chunk_size"], chunk_overlap=DEFAULT_EMBEDDING_CONFIG["chunk_overlap"], + enable_partition_pdf=DEFAULT_EMBEDDING_CONFIG["enable_partition_pdf"], ), generation_params=GenerationParamsModel( max_tokens=2000, @@ -89,6 +90,11 @@ def test_store_and_find_bot(self): bot.embedding_params.chunk_overlap, DEFAULT_EMBEDDING_CONFIG["chunk_overlap"], ) + + self.assertEqual( + bot.embedding_params.enable_partition_pdf, + DEFAULT_EMBEDDING_CONFIG["enable_partition_pdf"], + ) self.assertEqual(bot.generation_params.max_tokens, 2000) self.assertEqual(bot.generation_params.top_k, 250) self.assertEqual(bot.generation_params.top_p, 0.999) @@ -133,6 +139,7 @@ def test_update_bot_last_used_time(self): embedding_params=EmbeddingParamsModel( chunk_size=DEFAULT_EMBEDDING_CONFIG["chunk_size"], chunk_overlap=DEFAULT_EMBEDDING_CONFIG["chunk_overlap"], + enable_partition_pdf=DEFAULT_EMBEDDING_CONFIG["enable_partition_pdf"], ), generation_params=GenerationParamsModel( max_tokens=2000, @@ -179,6 +186,7 @@ def test_update_delete_bot_publication(self): embedding_params=EmbeddingParamsModel( chunk_size=DEFAULT_EMBEDDING_CONFIG["chunk_size"], chunk_overlap=DEFAULT_EMBEDDING_CONFIG["chunk_overlap"], + enable_partition_pdf=DEFAULT_EMBEDDING_CONFIG["enable_partition_pdf"], ), generation_params=GenerationParamsModel( max_tokens=2000, @@ -234,6 +242,7 @@ def test_update_bot(self): embedding_params=EmbeddingParamsModel( chunk_size=DEFAULT_EMBEDDING_CONFIG["chunk_size"], chunk_overlap=DEFAULT_EMBEDDING_CONFIG["chunk_overlap"], + enable_partition_pdf=DEFAULT_EMBEDDING_CONFIG["enable_partition_pdf"], ), generation_params=GenerationParamsModel( max_tokens=2000, @@ -265,8 +274,7 @@ def test_update_bot(self): description="Updated Description", instruction="Updated Instruction", embedding_params=EmbeddingParamsModel( - chunk_size=500, - chunk_overlap=100, + chunk_size=500, chunk_overlap=100, enable_partition_pdf=False ), generation_params=GenerationParamsModel( max_tokens=2500, @@ -294,6 +302,8 @@ def test_update_bot(self): self.assertEqual(bot.embedding_params.chunk_size, 500) self.assertEqual(bot.embedding_params.chunk_overlap, 100) + self.assertEqual(bot.embedding_params.enable_partition_pdf, False) + self.assertEqual(bot.generation_params.max_tokens, 2500) self.assertEqual(bot.generation_params.top_k, 200) self.assertEqual(bot.generation_params.top_p, 0.99) @@ -324,6 +334,7 @@ def setUp(self) -> None: embedding_params=EmbeddingParamsModel( chunk_size=DEFAULT_EMBEDDING_CONFIG["chunk_size"], chunk_overlap=DEFAULT_EMBEDDING_CONFIG["chunk_overlap"], + enable_partition_pdf=DEFAULT_EMBEDDING_CONFIG["enable_partition_pdf"], ), generation_params=GenerationParamsModel( max_tokens=2000, @@ -361,6 +372,7 @@ def setUp(self) -> None: embedding_params=EmbeddingParamsModel( chunk_size=DEFAULT_EMBEDDING_CONFIG["chunk_size"], chunk_overlap=DEFAULT_EMBEDDING_CONFIG["chunk_overlap"], + enable_partition_pdf=DEFAULT_EMBEDDING_CONFIG["enable_partition_pdf"], ), generation_params=GenerationParamsModel( max_tokens=2000, @@ -398,6 +410,7 @@ def setUp(self) -> None: embedding_params=EmbeddingParamsModel( chunk_size=DEFAULT_EMBEDDING_CONFIG["chunk_size"], chunk_overlap=DEFAULT_EMBEDDING_CONFIG["chunk_overlap"], + enable_partition_pdf=DEFAULT_EMBEDDING_CONFIG["enable_partition_pdf"], ), generation_params=GenerationParamsModel( max_tokens=2000, @@ -435,6 +448,7 @@ def setUp(self) -> None: embedding_params=EmbeddingParamsModel( chunk_size=DEFAULT_EMBEDDING_CONFIG["chunk_size"], chunk_overlap=DEFAULT_EMBEDDING_CONFIG["chunk_overlap"], + enable_partition_pdf=DEFAULT_EMBEDDING_CONFIG["enable_partition_pdf"], ), generation_params=GenerationParamsModel( max_tokens=2000, @@ -471,6 +485,7 @@ def setUp(self) -> None: embedding_params=EmbeddingParamsModel( chunk_size=DEFAULT_EMBEDDING_CONFIG["chunk_size"], chunk_overlap=DEFAULT_EMBEDDING_CONFIG["chunk_overlap"], + enable_partition_pdf=DEFAULT_EMBEDDING_CONFIG["enable_partition_pdf"], ), generation_params=GenerationParamsModel( max_tokens=2000, @@ -507,6 +522,7 @@ def setUp(self) -> None: embedding_params=EmbeddingParamsModel( chunk_size=DEFAULT_EMBEDDING_CONFIG["chunk_size"], chunk_overlap=DEFAULT_EMBEDDING_CONFIG["chunk_overlap"], + enable_partition_pdf=DEFAULT_EMBEDDING_CONFIG["enable_partition_pdf"], ), generation_params=GenerationParamsModel( max_tokens=2000, @@ -615,6 +631,7 @@ def setUp(self) -> None: embedding_params=EmbeddingParamsModel( chunk_size=DEFAULT_EMBEDDING_CONFIG["chunk_size"], chunk_overlap=DEFAULT_EMBEDDING_CONFIG["chunk_overlap"], + enable_partition_pdf=DEFAULT_EMBEDDING_CONFIG["enable_partition_pdf"], ), generation_params=GenerationParamsModel( max_tokens=2000, @@ -651,6 +668,7 @@ def setUp(self) -> None: embedding_params=EmbeddingParamsModel( chunk_size=DEFAULT_EMBEDDING_CONFIG["chunk_size"], chunk_overlap=DEFAULT_EMBEDDING_CONFIG["chunk_overlap"], + enable_partition_pdf=DEFAULT_EMBEDDING_CONFIG["enable_partition_pdf"], ), generation_params=GenerationParamsModel( max_tokens=2000, @@ -687,6 +705,7 @@ def setUp(self) -> None: embedding_params=EmbeddingParamsModel( chunk_size=DEFAULT_EMBEDDING_CONFIG["chunk_size"], chunk_overlap=DEFAULT_EMBEDDING_CONFIG["chunk_overlap"], + enable_partition_pdf=DEFAULT_EMBEDDING_CONFIG["enable_partition_pdf"], ), generation_params=GenerationParamsModel( max_tokens=2000, @@ -744,6 +763,7 @@ def test_update_bot_visibility(self): embedding_params=EmbeddingParamsModel( chunk_size=DEFAULT_EMBEDDING_CONFIG["chunk_size"], chunk_overlap=DEFAULT_EMBEDDING_CONFIG["chunk_overlap"], + enable_partition_pdf=DEFAULT_EMBEDDING_CONFIG["enable_partition_pdf"], ), generation_params=GenerationParamsModel( max_tokens=2000, diff --git a/backend/tests/test_usecases/utils/bot_factory.py b/backend/tests/test_usecases/utils/bot_factory.py index b891cc45..04af5585 100644 --- a/backend/tests/test_usecases/utils/bot_factory.py +++ b/backend/tests/test_usecases/utils/bot_factory.py @@ -33,6 +33,7 @@ def create_test_private_bot( embedding_params=EmbeddingParamsModel( chunk_size=DEFAULT_EMBEDDING_CONFIG["chunk_size"], chunk_overlap=DEFAULT_EMBEDDING_CONFIG["chunk_overlap"], + enable_partition_pdf=DEFAULT_EMBEDDING_CONFIG["enable_partition_pdf"], ), generation_params=GenerationParamsModel( max_tokens=2000, @@ -78,6 +79,7 @@ def create_test_public_bot( embedding_params=EmbeddingParamsModel( chunk_size=DEFAULT_EMBEDDING_CONFIG["chunk_size"], chunk_overlap=DEFAULT_EMBEDDING_CONFIG["chunk_overlap"], + enable_partition_pdf=DEFAULT_EMBEDDING_CONFIG["enable_partition_pdf"], ), generation_params=GenerationParamsModel( max_tokens=2000, diff --git a/cdk/bin/bedrock-chat.ts b/cdk/bin/bedrock-chat.ts index f4c03dca..0844e73f 100644 --- a/cdk/bin/bedrock-chat.ts +++ b/cdk/bin/bedrock-chat.ts @@ -33,6 +33,10 @@ const USER_POOL_DOMAIN_PREFIX: string = app.node.tryGetContext( const RDS_SCHEDULES: CronScheduleProps = app.node.tryGetContext("rdbSchedules"); const ENABLE_MISTRAL: boolean = app.node.tryGetContext("enableMistral"); +// container size of embedding ecs tasks +const EMBEDDING_CONTAINER_VCPU:number = app.node.tryGetContext("embeddingContainerVcpu") +const EMBEDDING_CONTAINER_MEMORY:number = app.node.tryGetContext("embeddingContainerMemory") + // WAF for frontend // 2023/9: Currently, the WAF for CloudFront needs to be created in the North America region (us-east-1), so the stacks are separated // https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-wafv2-webacl.html @@ -63,5 +67,7 @@ const chat = new BedrockChatStack(app, `BedrockChatStack`, { ALLOWED_SIGN_UP_EMAIL_DOMAINS, rdsSchedules: RDS_SCHEDULES, enableMistral: ENABLE_MISTRAL, + embeddingContainerVcpu: EMBEDDING_CONTAINER_VCPU, + embeddingContainerMemory: EMBEDDING_CONTAINER_MEMORY, }); chat.addDependency(waf); diff --git a/cdk/cdk.json b/cdk/cdk.json index 48836863..c9e5a5a3 100644 --- a/cdk/cdk.json +++ b/cdk/cdk.json @@ -80,6 +80,8 @@ "rdbSchedules": { "stop": {}, "start": {} - } + }, + "embeddingContainerVcpu": 2048, + "embeddingContainerMemory": 4096 } } \ No newline at end of file diff --git a/cdk/lib/bedrock-chat-stack.ts b/cdk/lib/bedrock-chat-stack.ts index f6b539ae..81c5f6b4 100644 --- a/cdk/lib/bedrock-chat-stack.ts +++ b/cdk/lib/bedrock-chat-stack.ts @@ -33,6 +33,8 @@ export interface BedrockChatStackProps extends StackProps { readonly allowedSignUpEmailDomains: string[]; readonly rdsSchedules: CronScheduleProps; readonly enableMistral: boolean; + readonly embeddingContainerVcpu: number; + readonly embeddingContainerMemory: number; } export class BedrockChatStack extends cdk.Stack { @@ -170,6 +172,8 @@ export class BedrockChatStack extends cdk.Stack { dbConfig, tableAccessRole: database.tableAccessRole, documentBucket, + embeddingContainerVcpu: props.embeddingContainerVcpu, + embeddingContainerMemory: props.embeddingContainerMemory, }); documentBucket.grantRead(embedding.container.taskDefinition.taskRole); diff --git a/cdk/lib/constructs/embedding.ts b/cdk/lib/constructs/embedding.ts index 20fcf17c..317f00a8 100644 --- a/cdk/lib/constructs/embedding.ts +++ b/cdk/lib/constructs/embedding.ts @@ -33,6 +33,8 @@ export interface EmbeddingProps { readonly bedrockRegion: string; readonly tableAccessRole: iam.IRole; readonly documentBucket: IBucket; + readonly embeddingContainerVcpu: number; + readonly embeddingContainerMemory: number; } export class Embedding extends Construct { @@ -47,13 +49,14 @@ export class Embedding extends Construct { */ const cluster = new ecs.Cluster(this, "Cluster", { vpc: props.vpc, + containerInsights: true, }); const taskDefinition = new ecs.FargateTaskDefinition( this, "TaskDefinition", { - cpu: 2048, - memoryLimitMiB: 4096, + cpu: props.embeddingContainerVcpu, + memoryLimitMiB: props.embeddingContainerMemory, runtimePlatform: { cpuArchitecture: ecs.CpuArchitecture.X86_64, operatingSystemFamily: ecs.OperatingSystemFamily.LINUX, diff --git a/cdk/test/cdk.test.ts b/cdk/test/cdk.test.ts index 6a00983f..a5ec3c87 100644 --- a/cdk/test/cdk.test.ts +++ b/cdk/test/cdk.test.ts @@ -29,6 +29,8 @@ describe("Fine-grained Assertions Test", () => { start: {}, }, enableMistral: false, + embeddingContainerVcpu: 1024, + embeddingContainerMemory: 2048, } ); const hasGoogleProviderTemplate = Template.fromStack( @@ -82,6 +84,8 @@ describe("Fine-grained Assertions Test", () => { start: {}, }, enableMistral: false, + embeddingContainerVcpu: 1024, + embeddingContainerMemory: 2048, } ); const hasOidcProviderTemplate = Template.fromStack(hasOidcProviderStack); @@ -124,6 +128,8 @@ describe("Fine-grained Assertions Test", () => { start: {}, }, enableMistral: false, + embeddingContainerVcpu: 1024, + embeddingContainerMemory: 2048, }); const template = Template.fromStack(stack); @@ -166,6 +172,8 @@ describe("Scheduler Test", () => { }, }, enableMistral: false, + embeddingContainerVcpu: 1024, + embeddingContainerMemory: 2048, }); const template = Template.fromStack(hasScheduleStack); template.hasResourceProperties("AWS::Scheduler::Schedule", { @@ -192,6 +200,8 @@ describe("Scheduler Test", () => { start: {}, }, enableMistral: false, + embeddingContainerVcpu: 1024, + embeddingContainerMemory: 2048, }); const template = Template.fromStack(defaultStack); // The stack should have only 1 rule for exporting the data from ddb to s3 diff --git a/frontend/src/@types/bot.d.ts b/frontend/src/@types/bot.d.ts index 5319c365..6b8e0fea 100644 --- a/frontend/src/@types/bot.d.ts +++ b/frontend/src/@types/bot.d.ts @@ -22,6 +22,7 @@ export type BotKnowledge = { export type EmdeddingParams = { chunkSize: number; chunkOverlap: number; + enablePartitionPdf: boolean; }; export type BotKnowledgeDiff = { diff --git a/frontend/src/constants/index.ts b/frontend/src/constants/index.ts index 19eb15bc..e49ebf15 100644 --- a/frontend/src/constants/index.ts +++ b/frontend/src/constants/index.ts @@ -3,6 +3,7 @@ import { EmdeddingParams, GenerationParams, SearchParams } from '../@types/bot'; export const DEFAULT_EMBEDDING_CONFIG: EmdeddingParams = { chunkSize: 1000, chunkOverlap: 200, + enablePartitionPdf: false }; export const EDGE_EMBEDDING_PARAMS = { diff --git a/frontend/src/i18n/en/index.ts b/frontend/src/i18n/en/index.ts index 3733bed0..e7ceb918 100644 --- a/frontend/src/i18n/en/index.ts +++ b/frontend/src/i18n/en/index.ts @@ -376,6 +376,10 @@ How would you categorize this email?`, label: 'chunk overlap', hint: 'You can specify the number of overlapping characters between adjacent chunks.', }, + enablePartitionPdf: { + label: 'Enable detailed PDF analysis. If enabled, the PDF will be analyzed in detail over time.', + hint: 'It is effective when you want to improve search accuracy. Computation costs increase because computation takes more time.', + }, help: { chunkSize: "When the chunk size is too small, contextual information can be lost, and when it's too large, different contextual information may exist within the same chunk, potentially reducing search accuracy.", diff --git a/frontend/src/i18n/ja/index.ts b/frontend/src/i18n/ja/index.ts index c7f74d92..8b8bccfc 100644 --- a/frontend/src/i18n/ja/index.ts +++ b/frontend/src/i18n/ja/index.ts @@ -380,6 +380,10 @@ const translation = { label: 'チャンクオーバーラップ', hint: '隣接するチャンク同士で重複する文字数を指定します。', }, + enablePartitionPdf: { + label: 'PDFの詳細解析の有効化。有効にすると時間をかけてPDFを詳細に分析します。', + hint: '検索精度を高めたい場合に有効です。計算により多くの時間がかかるため計算コストが増加します。', + }, help: { chunkSize: 'チャンクサイズが小さすぎると文脈情報が失われ、大きすぎると同一チャンクの中に異なる文脈の情報が存在することになり、検索精度が低下する場合があります。', diff --git a/frontend/src/pages/BotEditPage.tsx b/frontend/src/pages/BotEditPage.tsx index 9f2afdaf..1b082315 100644 --- a/frontend/src/pages/BotEditPage.tsx +++ b/frontend/src/pages/BotEditPage.tsx @@ -25,6 +25,7 @@ import { EDGE_SEARCH_PARAMS, } from '../constants'; import { Slider } from '../components/Slider'; +import Toggle from '../components/Toggle'; import ExpandableDrawerGroup from '../components/ExpandableDrawerGroup'; import useErrorMessage from '../hooks/useErrorMessage'; import Help from '../components/Help'; @@ -56,6 +57,7 @@ const BotEditPage: React.FC = () => { const [embeddingParams, setEmbeddingParams] = useState({ chunkSize: DEFAULT_EMBEDDING_CONFIG.chunkSize, chunkOverlap: DEFAULT_EMBEDDING_CONFIG.chunkOverlap, + enablePartitionPdf: DEFAULT_EMBEDDING_CONFIG.enablePartitionPdf, }); const [addedFilenames, setAddedFilenames] = useState([]); const [unchangedFilenames, setUnchangedFilenames] = useState([]); @@ -364,6 +366,7 @@ const BotEditPage: React.FC = () => { embeddingParams: { chunkSize: embeddingParams.chunkSize, chunkOverlap: embeddingParams.chunkOverlap, + enablePartitionPdf: embeddingParams.enablePartitionPdf, }, generationParams: { maxTokens, @@ -417,6 +420,7 @@ const BotEditPage: React.FC = () => { embeddingParams: { chunkSize: embeddingParams?.chunkSize, chunkOverlap: embeddingParams?.chunkOverlap, + enablePartitionPdf: embeddingParams?.enablePartitionPdf, }, generationParams: { maxTokens, @@ -678,6 +682,19 @@ const BotEditPage: React.FC = () => { errorMessage={errorMessages['chunkOverlap']} /> +
+ + setEmbeddingParams((params) => ({ + ...params, + enablePartitionPdf: enablePartitionPdf, + })) + } + /> +