diff --git a/.gitattributes b/.gitattributes
new file mode 100644
index 00000000..92d110c9
--- /dev/null
+++ b/.gitattributes
@@ -0,0 +1 @@
+*.html linguist-documentation
diff --git a/README.md b/README.md
index 40dbcb9a..02bdce63 100644
--- a/README.md
+++ b/README.md
@@ -13,7 +13,7 @@ Build, deploy, observe, and optimize your RAG engine.

 # About

-R2R (Rag to Riches) is the ultimate open-source framework for building and deploying high-quality Retrieval-Augmented Generation (RAG) systems. Designed to bridge the gap between local LLM experimentation and scalable, production-ready applications, R2R provides a comprehensive, feature-rich environment for developers.
+R2R (Rag to Riches) is the ultimate open-source answer engine for building and deploying high-quality Retrieval-Augmented Generation (RAG) systems. Designed to bridge the gap between local LLM experimentation and scalable, production-ready applications, R2R provides a comprehensive, feature-rich environment for developers.

 For a more complete view of R2R, check out our [documentation](https://r2r-docs.sciphi.ai/).
diff --git a/docs/pages/cookbooks/knowledge-graph.mdx b/docs/pages/cookbooks/knowledge-graph.mdx
index fa83b6a2..f58669da 100644
--- a/docs/pages/cookbooks/knowledge-graph.mdx
+++ b/docs/pages/cookbooks/knowledge-graph.mdx
@@ -217,7 +217,7 @@ from r2r import (
     GenerationConfig,
     Pipeline,
     R2RAppBuilder,
-    KGAgentPipe,
+    KGAgentSearchPipe,
     Relation,
     run_pipeline,
 )
@@ -445,13 +445,13 @@ Results:

 Finally, we are in a position to automatically answer difficult-to-manage queries with a knowledge agent. The snippet below injects our custom schema into a generic few-shot prompt and uses gpt-4o to create a relevant query.

 ```python filename="r2r/examples/scripts/advanced_kg_cookbook.py"
-    kg_agent_pipe = KGAgentPipe(
+    kg_agent_search_pipe = KGAgentSearchPipe(
         r2r_app.providers.kg, r2r_app.providers.llm, r2r_app.providers.prompt
     )

     # Define the pipeline
     kg_pipe = Pipeline()
-    kg_pipe.add_pipe(kg_agent_pipe)
+    kg_pipe.add_pipe(kg_agent_search_pipe)

     kg.update_agent_prompt(prompt_provider, entity_types, relations)
diff --git a/docs/pages/deep-dive/ingestion.mdx b/docs/pages/deep-dive/ingestion.mdx
index 430b4223..311e229e 100644
--- a/docs/pages/deep-dive/ingestion.mdx
+++ b/docs/pages/deep-dive/ingestion.mdx
@@ -34,7 +34,7 @@ The **R2RVectorStoragePipe** stores the generated embeddings in a vector databas
 ### Knowledge Graph Pipes

 When the knowledge graph provider settings are non-null, the pipeline includes pipes for generating and storing knowledge graph data.

-- **KGAgentPipe**: Generates Cypher queries to interact with a Neo4j knowledge graph.
+- **KGAgentSearchPipe**: Generates Cypher queries to interact with a Neo4j knowledge graph.
 - **KGStoragePipe**: Stores the generated knowledge graph data in the specified knowledge graph database.
@@ -72,7 +72,7 @@ custom_ingestion_pipeline = CustomIngestionPipeline()
 pipelines = R2RPipelineFactory(config, pipes).create_pipelines(
     ingestion_pipeline = custom_ingestion_pipeline
 )
-r2r = R2RApp(config, providers, pipelines)
+r2r = R2RApp(config=config, providers=providers, pipes=pipes, pipelines=pipelines)
diff --git a/docs/pages/index.mdx b/docs/pages/index.mdx
index e0e7b6ce..6e05e13d 100644
--- a/docs/pages/index.mdx
+++ b/docs/pages/index.mdx
@@ -6,7 +6,7 @@ import GithubButtons from '../components/GithubButtons';

-R2R (Rag to Riches) is the ultimate open-source framework for building and deploying high-quality Retrieval-Augmented Generation (RAG) systems. 
Designed to bridge the gap between local LLM experimentation and scalable, production-ready applications, R2R provides a comprehensive, feature-rich environment for developers. +R2R (Rag to Riches) is the ultimate open-source engine for building and deploying high-quality Retrieval-Augmented Generation (RAG) systems. Designed to bridge the gap between local LLM experimentation and scalable, production-ready applications, R2R provides a comprehensive, feature-rich environment for developers. ## Key Features diff --git a/docs/public/swagger.json b/docs/public/swagger.json index ebd9c108..9bc01a7f 100644 --- a/docs/public/swagger.json +++ b/docs/public/swagger.json @@ -1 +1 @@ -{"openapi":"3.1.0","info":{"title":"R2R Application API","version":"1.0.0"},"paths":{"/update_prompt":{"post":{"summary":"Update Prompt","description":"Update a prompt's template and/or input types.","operationId":"update_prompt_app_update_prompt_post","requestBody":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/UpdatePromptRequest"}}},"required":true},"responses":{"200":{"description":"Successful Response","content":{"application/json":{"schema":{}}}},"422":{"description":"Validation Error","content":{"application/json":{"schema":{"$ref":"#/components/schemas/HTTPValidationError"}}}}}}},"/ingest_documents":{"post":{"summary":"Ingest Documents","operationId":"ingest_documents_app_ingest_documents_post","requestBody":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/IngestDocumentsRequest"}}},"required":true},"responses":{"200":{"description":"Successful Response","content":{"application/json":{"schema":{}}}},"422":{"description":"Validation Error","content":{"application/json":{"schema":{"$ref":"#/components/schemas/HTTPValidationError"}}}}}}},"/update_documents":{"post":{"summary":"Update Documents","operationId":"update_documents_app_update_documents_post","requestBody":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/UpdateDocumentsRequest"}}},"required":true},"responses":{"200":{"description":"Successful Response","content":{"application/json":{"schema":{}}}},"422":{"description":"Validation Error","content":{"application/json":{"schema":{"$ref":"#/components/schemas/HTTPValidationError"}}}}}}},"/ingest_files":{"post":{"summary":"Ingest Files","description":"Ingest files into the system.","operationId":"ingest_files_app_ingest_files_post","requestBody":{"content":{"multipart/form-data":{"schema":{"$ref":"#/components/schemas/Body_ingest_files_app_ingest_files_post"}}},"required":true},"responses":{"200":{"description":"Successful Response","content":{"application/json":{"schema":{}}}},"422":{"description":"Validation Error","content":{"application/json":{"schema":{"$ref":"#/components/schemas/HTTPValidationError"}}}}}}},"/update_files":{"post":{"summary":"Update Files","operationId":"update_files_app_update_files_post","requestBody":{"content":{"multipart/form-data":{"schema":{"$ref":"#/components/schemas/Body_update_files_app_update_files_post"}}},"required":true},"responses":{"200":{"description":"Successful Response","content":{"application/json":{"schema":{}}}},"422":{"description":"Validation Error","content":{"application/json":{"schema":{"$ref":"#/components/schemas/HTTPValidationError"}}}}}}},"/search":{"post":{"summary":"Search","operationId":"search_app_search_post","requestBody":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/SearchRequest"}}},"required":true},"responses":{"200":{"description":"Successful 
Response","content":{"application/json":{"schema":{}}}},"422":{"description":"Validation Error","content":{"application/json":{"schema":{"$ref":"#/components/schemas/HTTPValidationError"}}}}}}},"/rag":{"post":{"summary":"Rag","operationId":"rag_app_rag_post","requestBody":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/RAGRequest"}}},"required":true},"responses":{"200":{"description":"Successful Response","content":{"application/json":{"schema":{}}}},"422":{"description":"Validation Error","content":{"application/json":{"schema":{"$ref":"#/components/schemas/HTTPValidationError"}}}}}}},"/evaluate":{"post":{"summary":"Evaluate","operationId":"evaluate_app_evaluate_post","requestBody":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/EvalRequest"}}},"required":true},"responses":{"200":{"description":"Successful Response","content":{"application/json":{"schema":{}}}},"422":{"description":"Validation Error","content":{"application/json":{"schema":{"$ref":"#/components/schemas/HTTPValidationError"}}}}}}},"/logs":{"get":{"summary":"Logs","operationId":"logs_app_logs_get","parameters":[{"name":"log_type_filter","in":"query","required":false,"schema":{"anyOf":[{"type":"string"},{"type":"null"}],"title":"Log Type Filter"}},{"name":"max_runs_requested","in":"query","required":false,"schema":{"type":"integer","default":100,"title":"Max Runs Requested"}}],"responses":{"200":{"description":"Successful Response","content":{"application/json":{"schema":{}}}},"422":{"description":"Validation Error","content":{"application/json":{"schema":{"$ref":"#/components/schemas/HTTPValidationError"}}}}}}},"/analytics":{"post":{"summary":"Analytics","operationId":"analytics_app_analytics_post","requestBody":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/Body_analytics_app_analytics_post"}}},"required":true},"responses":{"200":{"description":"Successful Response","content":{"application/json":{"schema":{}}}},"422":{"description":"Validation Error","content":{"application/json":{"schema":{"$ref":"#/components/schemas/HTTPValidationError"}}}}}}},"/users_stats":{"get":{"summary":"Users Stats","operationId":"users_stats_app_users_stats_get","parameters":[{"name":"user_ids","in":"query","required":false,"schema":{"anyOf":[{"type":"array","items":{"type":"string","format":"uuid"}},{"type":"null"}],"title":"User Ids"}}],"responses":{"200":{"description":"Successful Response","content":{"application/json":{"schema":{}}}},"422":{"description":"Validation Error","content":{"application/json":{"schema":{"$ref":"#/components/schemas/HTTPValidationError"}}}}}}},"/documents_info":{"get":{"summary":"Documents Info","operationId":"documents_info_app_documents_info_get","parameters":[{"name":"document_ids","in":"query","required":false,"schema":{"anyOf":[{"type":"array","items":{"type":"string"}},{"type":"null"}],"title":"Document Ids"}},{"name":"user_ids","in":"query","required":false,"schema":{"anyOf":[{"type":"array","items":{"type":"string"}},{"type":"null"}],"title":"User Ids"}}],"responses":{"200":{"description":"Successful Response","content":{"application/json":{"schema":{}}}},"422":{"description":"Validation Error","content":{"application/json":{"schema":{"$ref":"#/components/schemas/HTTPValidationError"}}}}}}},"/delete":{"delete":{"summary":"Delete","operationId":"delete_app_delete_delete","requestBody":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/DeleteRequest"}}},"required":true},"responses":{"200":{"description":"Successful 
Response","content":{"application/json":{"schema":{}}}},"422":{"description":"Validation Error","content":{"application/json":{"schema":{"$ref":"#/components/schemas/HTTPValidationError"}}}}}}},"/app_settings":{"get":{"summary":"App Settings","description":"Return the config.json and all prompts.","operationId":"app_settings_app_app_settings_get","responses":{"200":{"description":"Successful Response","content":{"application/json":{"schema":{}}}}}}},"/openapi_spec":{"get":{"summary":"Openapi Spec","operationId":"openapi_spec_app_openapi_spec_get","responses":{"200":{"description":"Successful Response","content":{"application/json":{"schema":{}}}}}}}},"components":{"schemas":{"AnalysisTypes":{"properties":{"analysis_types":{"anyOf":[{"additionalProperties":{"items":{"type":"string"},"type":"array"},"type":"object"},{"type":"null"}],"title":"Analysis Types"}},"type":"object","title":"AnalysisTypes"},"Body_analytics_app_analytics_post":{"properties":{"filter_criteria":{"$ref":"#/components/schemas/FilterCriteria"},"analysis_types":{"$ref":"#/components/schemas/AnalysisTypes"}},"type":"object","required":["filter_criteria","analysis_types"],"title":"Body_analytics_app_analytics_post"},"Body_ingest_files_app_ingest_files_post":{"properties":{"files":{"items":{"type":"string","format":"binary"},"type":"array","title":"Files"},"metadatas":{"anyOf":[{"type":"string"},{"type":"null"}],"title":"Metadatas"},"ids":{"anyOf":[{"type":"string"},{"type":"null"}],"title":"Ids"},"user_ids":{"anyOf":[{"type":"string"},{"type":"null"}],"title":"User Ids"}},"type":"object","required":["files"],"title":"Body_ingest_files_app_ingest_files_post"},"Body_update_files_app_update_files_post":{"properties":{"files":{"items":{"type":"string","format":"binary"},"type":"array","title":"Files"},"metadatas":{"anyOf":[{"type":"string"},{"type":"null"}],"title":"Metadatas"},"ids":{"anyOf":[{"type":"string"},{"type":"null"}],"title":"Ids"}},"type":"object","required":["files"],"title":"Body_update_files_app_update_files_post"},"DeleteRequest":{"properties":{"keys":{"items":{"type":"string"},"type":"array","title":"Keys"},"values":{"items":{"anyOf":[{"type":"boolean"},{"type":"integer"},{"type":"string"}]},"type":"array","title":"Values"}},"type":"object","required":["keys","values"],"title":"DeleteRequest"},"Document":{"properties":{"id":{"type":"string","format":"uuid","title":"Id"},"type":{"$ref":"#/components/schemas/DocumentType"},"data":{"anyOf":[{"type":"string"},{"type":"string","format":"binary"}],"title":"Data"},"metadata":{"type":"object","title":"Metadata"},"title":{"anyOf":[{"type":"string"},{"type":"null"}],"title":"Title"},"user_id":{"anyOf":[{"type":"string","format":"uuid"},{"type":"null"}],"title":"User Id"}},"type":"object","required":["id","type","data","metadata"],"title":"Document","description":"A document that has been stored in the system."},"DocumentType":{"type":"string","enum":["csv","docx","html","json","md","pdf","pptx","txt","xlsx","gif","png","jpg","jpeg","svg","mp3","mp4"],"title":"DocumentType","description":"Types of documents that can be 
stored."},"EvalRequest":{"properties":{"query":{"type":"string","title":"Query"},"context":{"type":"string","title":"Context"},"completion":{"type":"string","title":"Completion"}},"type":"object","required":["query","context","completion"],"title":"EvalRequest"},"FilterCriteria":{"properties":{"filters":{"anyOf":[{"additionalProperties":{"type":"string"},"type":"object"},{"type":"null"}],"title":"Filters"}},"type":"object","title":"FilterCriteria"},"HTTPValidationError":{"properties":{"detail":{"items":{"$ref":"#/components/schemas/ValidationError"},"type":"array","title":"Detail"}},"type":"object","title":"HTTPValidationError"},"IngestDocumentsRequest":{"properties":{"documents":{"items":{"$ref":"#/components/schemas/Document"},"type":"array","title":"Documents"}},"type":"object","required":["documents"],"title":"IngestDocumentsRequest"},"RAGRequest":{"properties":{"message":{"type":"string","title":"Message"},"search_filters":{"anyOf":[{"type":"string"},{"type":"null"}],"title":"Search Filters"},"search_limit":{"type":"integer","title":"Search Limit","default":10},"rag_generation_config":{"anyOf":[{"type":"string"},{"type":"null"}],"title":"Rag Generation Config"},"streaming":{"anyOf":[{"type":"boolean"},{"type":"null"}],"title":"Streaming"}},"type":"object","required":["message"],"title":"RAGRequest"},"SearchRequest":{"properties":{"query":{"type":"string","title":"Query"},"search_filters":{"anyOf":[{"type":"string"},{"type":"null"}],"title":"Search Filters"},"search_limit":{"type":"integer","title":"Search Limit","default":10}},"type":"object","required":["query"],"title":"SearchRequest"},"UpdateDocumentsRequest":{"properties":{"documents":{"items":{"$ref":"#/components/schemas/Document"},"type":"array","title":"Documents"}},"type":"object","required":["documents"],"title":"UpdateDocumentsRequest"},"UpdatePromptRequest":{"properties":{"name":{"type":"string","title":"Name"},"template":{"anyOf":[{"type":"string"},{"type":"null"}],"title":"Template"},"input_types":{"anyOf":[{"additionalProperties":{"type":"string"},"type":"object"},{"type":"null"}],"title":"Input Types"}},"type":"object","required":["name"],"title":"UpdatePromptRequest"},"ValidationError":{"properties":{"loc":{"items":{"anyOf":[{"type":"string"},{"type":"integer"}]},"type":"array","title":"Location"},"msg":{"type":"string","title":"Message"},"type":{"type":"string","title":"Error Type"}},"type":"object","required":["loc","msg","type"],"title":"ValidationError"}}}} +{"openapi":"3.1.0","info":{"title":"R2R Application API","version":"1.0.0"},"paths":{"/update_prompt":{"post":{"summary":"Update Prompt","description":"Update a prompt's template and/or input types.","operationId":"update_prompt_app_update_prompt_post","requestBody":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/R2RUpdatePromptRequest"}}},"required":true},"responses":{"200":{"description":"Successful Response","content":{"application/json":{"schema":{}}}},"422":{"description":"Validation Error","content":{"application/json":{"schema":{"$ref":"#/components/schemas/HTTPValidationError"}}}}}}},"/ingest_documents":{"post":{"summary":"Ingest Documents","operationId":"ingest_documents_app_ingest_documents_post","requestBody":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/R2RIngestDocumentsRequest"}}},"required":true},"responses":{"200":{"description":"Successful Response","content":{"application/json":{"schema":{}}}},"422":{"description":"Validation 
Error","content":{"application/json":{"schema":{"$ref":"#/components/schemas/HTTPValidationError"}}}}}}},"/update_documents":{"post":{"summary":"Update Documents","operationId":"update_documents_app_update_documents_post","requestBody":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/R2RUpdateDocumentsRequest"}}},"required":true},"responses":{"200":{"description":"Successful Response","content":{"application/json":{"schema":{}}}},"422":{"description":"Validation Error","content":{"application/json":{"schema":{"$ref":"#/components/schemas/HTTPValidationError"}}}}}}},"/ingest_files":{"post":{"summary":"Ingest Files","description":"Ingest files into the system.","operationId":"ingest_files_app_ingest_files_post","requestBody":{"content":{"multipart/form-data":{"schema":{"$ref":"#/components/schemas/Body_ingest_files_app_ingest_files_post"}}},"required":true},"responses":{"200":{"description":"Successful Response","content":{"application/json":{"schema":{}}}},"422":{"description":"Validation Error","content":{"application/json":{"schema":{"$ref":"#/components/schemas/HTTPValidationError"}}}}}}},"/update_files":{"post":{"summary":"Update Files","operationId":"update_files_app_update_files_post","requestBody":{"content":{"multipart/form-data":{"schema":{"$ref":"#/components/schemas/Body_update_files_app_update_files_post"}}},"required":true},"responses":{"200":{"description":"Successful Response","content":{"application/json":{"schema":{}}}},"422":{"description":"Validation Error","content":{"application/json":{"schema":{"$ref":"#/components/schemas/HTTPValidationError"}}}}}}},"/search":{"post":{"summary":"Search","operationId":"search_app_search_post","requestBody":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/VectorSearchRequest"}}},"required":true},"responses":{"200":{"description":"Successful Response","content":{"application/json":{"schema":{}}}},"422":{"description":"Validation Error","content":{"application/json":{"schema":{"$ref":"#/components/schemas/HTTPValidationError"}}}}}}},"/rag":{"post":{"summary":"Rag","operationId":"rag_app_rag_post","requestBody":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/R2RRAGRequest"}}},"required":true},"responses":{"200":{"description":"Successful Response","content":{"application/json":{"schema":{}}}},"422":{"description":"Validation Error","content":{"application/json":{"schema":{"$ref":"#/components/schemas/HTTPValidationError"}}}}}}},"/evaluate":{"post":{"summary":"Evaluate","operationId":"evaluate_app_evaluate_post","requestBody":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/R2REvalRequest"}}},"required":true},"responses":{"200":{"description":"Successful Response","content":{"application/json":{"schema":{}}}},"422":{"description":"Validation Error","content":{"application/json":{"schema":{"$ref":"#/components/schemas/HTTPValidationError"}}}}}}},"/logs":{"get":{"summary":"Logs","operationId":"logs_app_logs_get","parameters":[{"name":"log_type_filter","in":"query","required":false,"schema":{"anyOf":[{"type":"string"},{"type":"null"}],"title":"Log Type Filter"}},{"name":"max_runs_requested","in":"query","required":false,"schema":{"type":"integer","default":100,"title":"Max Runs Requested"}}],"responses":{"200":{"description":"Successful Response","content":{"application/json":{"schema":{}}}},"422":{"description":"Validation 
Error","content":{"application/json":{"schema":{"$ref":"#/components/schemas/HTTPValidationError"}}}}}}},"/analytics":{"post":{"summary":"Analytics","operationId":"analytics_app_analytics_post","requestBody":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/Body_analytics_app_analytics_post"}}},"required":true},"responses":{"200":{"description":"Successful Response","content":{"application/json":{"schema":{}}}},"422":{"description":"Validation Error","content":{"application/json":{"schema":{"$ref":"#/components/schemas/HTTPValidationError"}}}}}}},"/users_stats":{"get":{"summary":"Users Stats","operationId":"users_stats_app_users_stats_get","parameters":[{"name":"user_ids","in":"query","required":false,"schema":{"anyOf":[{"type":"array","items":{"type":"string","format":"uuid"}},{"type":"null"}],"title":"User Ids"}}],"responses":{"200":{"description":"Successful Response","content":{"application/json":{"schema":{}}}},"422":{"description":"Validation Error","content":{"application/json":{"schema":{"$ref":"#/components/schemas/HTTPValidationError"}}}}}}},"/documents_info":{"get":{"summary":"Documents Info","operationId":"documents_info_app_documents_info_get","parameters":[{"name":"document_ids","in":"query","required":false,"schema":{"anyOf":[{"type":"array","items":{"type":"string"}},{"type":"null"}],"title":"Document Ids"}},{"name":"user_ids","in":"query","required":false,"schema":{"anyOf":[{"type":"array","items":{"type":"string"}},{"type":"null"}],"title":"User Ids"}}],"responses":{"200":{"description":"Successful Response","content":{"application/json":{"schema":{}}}},"422":{"description":"Validation Error","content":{"application/json":{"schema":{"$ref":"#/components/schemas/HTTPValidationError"}}}}}}},"/delete":{"delete":{"summary":"Delete","operationId":"delete_app_delete_delete","requestBody":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/R2RDeleteRequest"}}},"required":true},"responses":{"200":{"description":"Successful Response","content":{"application/json":{"schema":{}}}},"422":{"description":"Validation Error","content":{"application/json":{"schema":{"$ref":"#/components/schemas/HTTPValidationError"}}}}}}},"/app_settings":{"get":{"summary":"App Settings","description":"Return the config.json and all prompts.","operationId":"app_settings_app_app_settings_get","responses":{"200":{"description":"Successful Response","content":{"application/json":{"schema":{}}}}}}},"/openapi_spec":{"get":{"summary":"Openapi Spec","operationId":"openapi_spec_app_openapi_spec_get","responses":{"200":{"description":"Successful Response","content":{"application/json":{"schema":{}}}}}}}},"components":{"schemas":{"AnalysisTypes":{"properties":{"analysis_types":{"anyOf":[{"additionalProperties":{"items":{"type":"string"},"type":"array"},"type":"object"},{"type":"null"}],"title":"Analysis 
Types"}},"type":"object","title":"AnalysisTypes"},"Body_analytics_app_analytics_post":{"properties":{"filter_criteria":{"$ref":"#/components/schemas/FilterCriteria"},"analysis_types":{"$ref":"#/components/schemas/AnalysisTypes"}},"type":"object","required":["filter_criteria","analysis_types"],"title":"Body_analytics_app_analytics_post"},"Body_ingest_files_app_ingest_files_post":{"properties":{"files":{"items":{"type":"string","format":"binary"},"type":"array","title":"Files"},"metadatas":{"anyOf":[{"type":"string"},{"type":"null"}],"title":"Metadatas"},"ids":{"anyOf":[{"type":"string"},{"type":"null"}],"title":"Ids"},"user_ids":{"anyOf":[{"type":"string"},{"type":"null"}],"title":"User Ids"}},"type":"object","required":["files"],"title":"Body_ingest_files_app_ingest_files_post"},"Body_update_files_app_update_files_post":{"properties":{"files":{"items":{"type":"string","format":"binary"},"type":"array","title":"Files"},"metadatas":{"anyOf":[{"type":"string"},{"type":"null"}],"title":"Metadatas"},"ids":{"anyOf":[{"type":"string"},{"type":"null"}],"title":"Ids"}},"type":"object","required":["files"],"title":"Body_update_files_app_update_files_post"},"R2RDeleteRequest":{"properties":{"keys":{"items":{"type":"string"},"type":"array","title":"Keys"},"values":{"items":{"anyOf":[{"type":"boolean"},{"type":"integer"},{"type":"string"}]},"type":"array","title":"Values"}},"type":"object","required":["keys","values"],"title":"R2RDeleteRequest"},"Document":{"properties":{"id":{"type":"string","format":"uuid","title":"Id"},"type":{"$ref":"#/components/schemas/DocumentType"},"data":{"anyOf":[{"type":"string"},{"type":"string","format":"binary"}],"title":"Data"},"metadata":{"type":"object","title":"Metadata"},"title":{"anyOf":[{"type":"string"},{"type":"null"}],"title":"Title"},"user_id":{"anyOf":[{"type":"string","format":"uuid"},{"type":"null"}],"title":"User Id"}},"type":"object","required":["id","type","data","metadata"],"title":"Document","description":"A document that has been stored in the system."},"DocumentType":{"type":"string","enum":["csv","docx","html","json","md","pdf","pptx","txt","xlsx","gif","png","jpg","jpeg","svg","mp3","mp4"],"title":"DocumentType","description":"Types of documents that can be stored."},"R2REvalRequest":{"properties":{"query":{"type":"string","title":"Query"},"context":{"type":"string","title":"Context"},"completion":{"type":"string","title":"Completion"}},"type":"object","required":["query","context","completion"],"title":"R2REvalRequest"},"FilterCriteria":{"properties":{"filters":{"anyOf":[{"additionalProperties":{"type":"string"},"type":"object"},{"type":"null"}],"title":"Filters"}},"type":"object","title":"FilterCriteria"},"HTTPValidationError":{"properties":{"detail":{"items":{"$ref":"#/components/schemas/ValidationError"},"type":"array","title":"Detail"}},"type":"object","title":"HTTPValidationError"},"R2RIngestDocumentsRequest":{"properties":{"documents":{"items":{"$ref":"#/components/schemas/Document"},"type":"array","title":"Documents"}},"type":"object","required":["documents"],"title":"R2RIngestDocumentsRequest"},"R2RRAGRequest":{"properties":{"message":{"type":"string","title":"Message"},"search_filters":{"anyOf":[{"type":"string"},{"type":"null"}],"title":"Search Filters"},"search_limit":{"type":"integer","title":"Search Limit","default":10},"rag_generation_config":{"anyOf":[{"type":"string"},{"type":"null"}],"title":"Rag Generation 
Config"},"streaming":{"anyOf":[{"type":"boolean"},{"type":"null"}],"title":"Streaming"}},"type":"object","required":["message"],"title":"R2RRAGRequest"},"VectorSearchRequest":{"properties":{"query":{"type":"string","title":"Query"},"search_filters":{"anyOf":[{"type":"string"},{"type":"null"}],"title":"Search Filters"},"search_limit":{"type":"integer","title":"Search Limit","default":10}},"type":"object","required":["query"],"title":"VectorSearchRequest"},"R2RUpdateDocumentsRequest":{"properties":{"documents":{"items":{"$ref":"#/components/schemas/Document"},"type":"array","title":"Documents"}},"type":"object","required":["documents"],"title":"R2RUpdateDocumentsRequest"},"R2RUpdatePromptRequest":{"properties":{"name":{"type":"string","title":"Name"},"template":{"anyOf":[{"type":"string"},{"type":"null"}],"title":"Template"},"input_types":{"anyOf":[{"additionalProperties":{"type":"string"},"type":"object"},{"type":"null"}],"title":"Input Types"}},"type":"object","required":["name"],"title":"R2RUpdatePromptRequest"},"ValidationError":{"properties":{"loc":{"items":{"anyOf":[{"type":"string"},{"type":"integer"}]},"type":"array","title":"Location"},"msg":{"type":"string","title":"Message"},"type":{"type":"string","title":"Error Type"}},"type":"object","required":["loc","msg","type"],"title":"ValidationError"}}}} diff --git a/poetry.lock b/poetry.lock index 4d6f3de4..fee627b4 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3946,4 +3946,4 @@ local-embedding = ["sentence-transformers"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.13" -content-hash = "b84fa623b6fa140d03748bf787f988e615ee91bfd7b114a2a32a764ee88bc2d9" \ No newline at end of file +content-hash = "b84fa623b6fa140d03748bf787f988e615ee91bfd7b114a2a32a764ee88bc2d9" diff --git a/r2r/__init__.py b/r2r/__init__.py index 1c0ab599..9f77ab30 100644 --- a/r2r/__init__.py +++ b/r2r/__init__.py @@ -39,8 +39,8 @@ "VectorEntry", "VectorType", "Vector", - "SearchRequest", - "SearchResult", + "VectorSearchRequest", + "VectorSearchResult", "AsyncPipe", "PipeType", "AsyncState", @@ -98,7 +98,7 @@ "R2RPromptProvider", "WebSearchPipe", "R2RAppBuilder", - "KGAgentPipe", + "KGAgentSearchPipe", # Prebuilts "MultiSearchPipe", "R2RPipeFactoryWithMultiSearch", diff --git a/r2r/core/__init__.py b/r2r/core/__init__.py index cd9cc18e..012b4154 100644 --- a/r2r/core/__init__.py +++ b/r2r/core/__init__.py @@ -14,9 +14,21 @@ extract_triples, ) from .abstractions.llama_abstractions import VectorStoreQuery -from .abstractions.llm import LLMChatCompletion, LLMChatCompletionChunk +from .abstractions.llm import ( + GenerationConfig, + LLMChatCompletion, + LLMChatCompletionChunk, +) from .abstractions.prompt import Prompt -from .abstractions.search import SearchRequest, SearchResult +from .abstractions.search import ( + AggregateSearchResult, + KGSearchRequest, + KGSearchResult, + KGSearchSettings, + VectorSearchRequest, + VectorSearchResult, + VectorSearchSettings, +) from .abstractions.user import UserStats from .abstractions.vector import Vector, VectorEntry, VectorType from .logging.kv_logger import ( @@ -51,19 +63,16 @@ TextParser, XLSXParser, ) -from .pipeline.base_pipeline import ( - EvalPipeline, - IngestionPipeline, - Pipeline, - RAGPipeline, - SearchPipeline, -) +from .pipeline.base_pipeline import EvalPipeline, Pipeline +from .pipeline.ingestion_pipeline import IngestionPipeline +from .pipeline.rag_pipeline import RAGPipeline +from .pipeline.search_pipeline import SearchPipeline from .pipes.base_pipe import AsyncPipe, AsyncState, PipeType from 
.pipes.loggable_pipe import LoggableAsyncPipe
 from .providers.embedding_provider import EmbeddingConfig, EmbeddingProvider
 from .providers.eval_provider import EvalConfig, EvalProvider
 from .providers.kg_provider import KGConfig, KGProvider
-from .providers.llm_provider import GenerationConfig, LLMConfig, LLMProvider
+from .providers.llm_provider import LLMConfig, LLMProvider
 from .providers.prompt_provider import PromptConfig, PromptProvider
 from .providers.vector_db_provider import VectorDBConfig, VectorDBProvider
 from .utils import (
@@ -99,8 +108,13 @@
     "VectorEntry",
     "VectorType",
     "Vector",
-    "SearchRequest",
-    "SearchResult",
+    "VectorSearchRequest",
+    "VectorSearchResult",
+    "VectorSearchSettings",
+    "KGSearchRequest",
+    "KGSearchResult",
+    "KGSearchSettings",
+    "AggregateSearchResult",
     "AsyncPipe",
     "PipeType",
     "AsyncState",
diff --git a/r2r/core/abstractions/document.py b/r2r/core/abstractions/document.py
index 82c809ec..a8cdb8e6 100644
--- a/r2r/core/abstractions/document.py
+++ b/r2r/core/abstractions/document.py
@@ -1,5 +1,6 @@
 """Abstractions for documents and their extractions."""

+import base64
 import json
 import logging
 import uuid
@@ -7,6 +8,7 @@
 from enum import Enum
 from typing import Optional, Union

+from fastapi import HTTPException
 from pydantic import BaseModel, Field

 logger = logging.getLogger(__name__)
@@ -43,8 +45,28 @@ class Document(BaseModel):
     data: DataType
     metadata: dict

-    title: Optional[str] = None
-    user_id: Optional[uuid.UUID] = None
+    def encode_data(self):
+        if isinstance(self.data, bytes):
+            self.data = base64.b64encode(self.data).decode("utf-8")
+        self.id = str(self.id)
+        for key, value in self.metadata.items():
+            if isinstance(value, uuid.UUID):
+                self.metadata[key] = str(value)
+
+    def decode_data(self):
+        if isinstance(self.data, str):
+            try:
+                self.data = base64.b64decode(self.data.encode("utf-8"))
+            except Exception as e:
+                raise HTTPException(
+                    status_code=400, detail=f"Failed to decode data: {e}"
+                )
+        self.id = uuid.UUID(self.id)
+        for key, value in self.metadata.items():
+            try:
+                self.metadata[key] = uuid.UUID(value)
+            except ValueError:
+                pass


 class DocumentInfo(BaseModel):
@@ -55,7 +77,6 @@ class DocumentInfo(BaseModel):
     size_in_bytes: int
     metadata: dict

-    title: Optional[str] = None
     user_id: Optional[uuid.UUID] = None
     created_at: Optional[datetime] = None
     updated_at: Optional[datetime] = None
@@ -69,7 +90,7 @@ def convert_to_db_entry(self):
         )
         return {
             "document_id": str(self.document_id),
-            "title": self.title or "N/A",
+            "title": metadata.get("title", "N/A"),
             "user_id": metadata["user_id"],
             "version": self.version,
             "size_in_bytes": self.size_in_bytes,
diff --git a/r2r/core/abstractions/llm.py b/r2r/core/abstractions/llm.py
index cef38e0d..5106379a 100644
--- a/r2r/core/abstractions/llm.py
+++ b/r2r/core/abstractions/llm.py
@@ -1,6 +1,27 @@
 """Abstractions for the LLM model."""

+from typing import Optional
+
 from openai.types.chat import ChatCompletion, ChatCompletionChunk
+from pydantic import BaseModel

 LLMChatCompletion = ChatCompletion
 LLMChatCompletionChunk = ChatCompletionChunk
+
+
+class GenerationConfig(BaseModel):
+    temperature: float = 0.1
+    top_p: float = 1.0
+    top_k: int = 100
+    max_tokens_to_sample: int = 1_024
+    model: str = "gpt-4o"
+    stream: bool = False
+    functions: Optional[list[dict]] = None
+    skip_special_tokens: bool = False
+    stop_token: Optional[str] = None
+    num_beams: int = 1
+    do_sample: bool = True
+    # Additional args to pass to the generation config
+    generate_with_chat: bool = False
+    add_generation_kwargs: Optional[dict] = {}
+    api_base: Optional[str] = None
diff --git a/r2r/core/abstractions/search.py b/r2r/core/abstractions/search.py
index 30d37413..2b558fd0 100644
--- a/r2r/core/abstractions/search.py
+++ b/r2r/core/abstractions/search.py
@@ -1,12 +1,14 @@
 """Abstractions for search functionality."""

 import uuid
-from typing import Any, Optional
+from typing import Any, Dict, List, Optional

-from pydantic import BaseModel
+from pydantic import BaseModel, Field

+from .llm import GenerationConfig

-class SearchRequest(BaseModel):
+
+class VectorSearchRequest(BaseModel):
     """Request for a search operation."""

     query: str
@@ -14,7 +16,7 @@ class SearchRequest(BaseModel):
     filters: Optional[dict[str, Any]] = None


-class SearchResult(BaseModel):
+class VectorSearchResult(BaseModel):
     """Result of a search operation."""

     id: uuid.UUID
@@ -22,10 +24,10 @@ class SearchResult(BaseModel):
     metadata: dict[str, Any]

     def __str__(self) -> str:
-        return f"SearchResult(id={self.id}, score={self.score}, metadata={self.metadata})"
+        return f"VectorSearchResult(id={self.id}, score={self.score}, metadata={self.metadata})"

     def __repr__(self) -> str:
-        return f"SearchResult(id={self.id}, score={self.score}, metadata={self.metadata})"
+        return f"VectorSearchResult(id={self.id}, score={self.score}, metadata={self.metadata})"

     def dict(self) -> dict:
         return {
@@ -33,3 +35,47 @@ def dict(self) -> dict:
             "score": self.score,
             "metadata": self.metadata,
         }
+
+
+class KGSearchRequest(BaseModel):
+    """Request for a knowledge graph search operation."""
+
+    query: str
+
+
+KGSearchResult = List[List[Dict[str, Any]]]
+
+
+class AggregateSearchResult(BaseModel):
+    """Result of an aggregate search operation."""
+
+    vector_search_results: Optional[List[VectorSearchResult]]
+    kg_search_results: Optional[KGSearchResult] = None
+
+    def __str__(self) -> str:
+        return f"AggregateSearchResult(vector_search_results={self.vector_search_results}, kg_search_results={self.kg_search_results})"
+
+    def __repr__(self) -> str:
+        return f"AggregateSearchResult(vector_search_results={self.vector_search_results}, kg_search_results={self.kg_search_results})"
+
+    def dict(self) -> dict:
+        return {
+            "vector_search_results": [
+                result.dict() for result in self.vector_search_results
+            ],
+            "kg_search_results": self.kg_search_results,
+        }
+
+
+class VectorSearchSettings(BaseModel):
+    use_vector_search: bool = True
+    search_filters: Optional[dict[str, Any]] = Field(default_factory=dict)
+    search_limit: int = 10
+    do_hybrid_search: bool = False
+
+
+class KGSearchSettings(BaseModel):
+    use_kg: bool = False
+    agent_generation_config: Optional[GenerationConfig] = Field(
+        default_factory=GenerationConfig
+    )
diff --git a/r2r/core/logging/kv_logger.py b/r2r/core/logging/kv_logger.py
index 75ee38db..36052198 100644
--- a/r2r/core/logging/kv_logger.py
+++ b/r2r/core/logging/kv_logger.py
@@ -503,10 +503,6 @@ def get_instance(cls):
     def configure(
         cls, logging_config: Optional[LoggingConfig] = LoggingConfig()
     ):
-        logger.info(
-            f"Initializing KVLoggingSingleton with config: {logging_config}"
-        )
-
         if not cls._is_configured:
             cls._config = logging_config
             cls._is_configured = True
diff --git a/r2r/core/pipeline/base_pipeline.py b/r2r/core/pipeline/base_pipeline.py
index 6f8f4242..4f08353b 100644
--- a/r2r/core/pipeline/base_pipeline.py
+++ b/r2r/core/pipeline/base_pipeline.py
@@ -2,7 +2,6 @@

 import asyncio
 import logging
-from asyncio import Queue
 from enum import Enum
 from typing import Any, AsyncGenerator, Optional

@@ -30,13 +29,11 @@ def __init__(
         self,
         pipe_logger: Optional[KVLoggingSingleton] = None,
         run_manager: Optional[RunManager] = None,
-        log_run_info: bool = True,
     ):
         self.pipes: list[AsyncPipe] = []
         self.upstream_outputs: list[list[dict[str, str]]] = []
         self.pipe_logger = pipe_logger or KVLoggingSingleton()
         self.run_manager = run_manager or RunManager(self.pipe_logger)
-        self.log_run_info = log_run_info
         self.futures = {}
         self.level = 0
@@ -57,8 +54,9 @@ async def run(
         self,
         input: Any,
         state: Optional[AsyncState] = None,
-        streaming: bool = False,
+        stream: bool = False,
         run_manager: Optional[RunManager] = None,
+        log_run_info: bool = True,
         *args: Any,
         **kwargs: Any,
     ):
@@ -75,7 +73,7 @@ async def run(
         self.state = state or AsyncState()
         current_input = input
         async with manage_run(run_manager, self.pipeline_type):
-            if self.log_run_info:
+            if log_run_info:
                 await run_manager.log_run_info(
                     key="pipeline_type",
                     value=self.pipeline_type,
@@ -94,7 +92,7 @@
                 **kwargs,
             )
             self.futures[config_name].set_result(current_input)
-        if not streaming:
+        if not stream:
             final_result = await self._consume_all(current_input)
             return final_result
         else:
@@ -206,13 +204,13 @@ async def run(
         self,
         input: Any,
         state: Optional[AsyncState] = None,
-        streaming: bool = False,
+        stream: bool = False,
         run_manager: Optional[RunManager] = None,
         *args: Any,
         **kwargs: Any,
     ):
         return await super().run(
-            input, state, streaming, run_manager, *args, **kwargs
+            input, state, stream, run_manager, *args, **kwargs
         )

     def add_pipe(
         self,
         pipe: AsyncPipe,
         add_upstream_outputs: Optional[list[dict[str, str]]] = None,
         *args,
         **kwargs,
     ) -> None:
         return super().add_pipe(pipe, add_upstream_outputs, *args, **kwargs)
@@ -226,195 +224,10 @@
-class IngestionPipeline(Pipeline):
-    """A pipeline for ingestion."""
-
-    pipeline_type: str = "ingestion"
-
-    def __init__(
-        self,
-        pipe_logger: Optional[KVLoggingSingleton] = None,
-        run_manager: Optional[RunManager] = None,
-    ):
-        super().__init__(pipe_logger, run_manager)
-        self.parsing_pipe = None
-        self.embedding_pipeline = None
-        self.kg_pipeline = None
-
-    async def run(
-        self,
-        input: Any,
-        state: Optional[AsyncState] = None,
-        streaming: bool = False,
-        run_manager: Optional[RunManager] = None,
-        *args: Any,
-        **kwargs: Any,
-    ):
-        self.state = state or AsyncState()
-
-        async with manage_run(run_manager, self.pipeline_type):
-            await run_manager.log_run_info(
-                key="pipeline_type",
-                value=self.pipeline_type,
-                is_info_log=True,
-            )
-            if self.parsing_pipe is None:
-                raise ValueError(
-                    "parsing_pipeline must be set before running the ingestion pipeline"
-                )
-            if self.embedding_pipeline is None and self.kg_pipeline is None:
-                raise ValueError(
-                    "At least one of embedding_pipeline or kg_pipeline must be set before running the ingestion pipeline"
-                )
-            # Use queues to duplicate the documents for each pipeline
-            embedding_queue = Queue()
-            kg_queue = Queue()
-
-            async def enqueue_documents():
-                async for document in await self.parsing_pipe.run(
-                    self.parsing_pipe.Input(message=input),
-                    state,
-                    run_manager,
-                    *args,
-                    **kwargs,
-                ):
-                    if self.embedding_pipeline:
-                        await embedding_queue.put(document)
-                    if self.kg_pipeline:
-                        await kg_queue.put(document)
-                await embedding_queue.put(None)
-                await kg_queue.put(None)
-
-            # Create an async generator to dequeue documents
-            async def dequeue_documents(queue: Queue) -> AsyncGenerator:
-                while True:
-                    document = await queue.get()
-                    if document is None:
-                        break
-                    yield document
-
-            # Start the document enqueuing process
-            enqueue_task = asyncio.create_task(enqueue_documents())
-
-            # Start the embedding and KG pipelines in parallel
-            if self.embedding_pipeline:
-                embedding_task = asyncio.create_task(
-                    self.embedding_pipeline.run(
-                        dequeue_documents(embedding_queue),
-                        state,
-                        streaming,
-                        run_manager,
-                        *args,
-                        **kwargs,
-                    )
-                )
-
-            if self.kg_pipeline:
-                kg_task = asyncio.create_task(
-                    self.kg_pipeline.run(
-                        dequeue_documents(kg_queue),
-                        state,
-                        streaming,
-                        run_manager,
-                        *args,
-                        **kwargs,
-                    )
-                )
-
-            # Wait for the enqueueing task to complete
-            await enqueue_task
-
-            # Wait for the embedding and KG tasks to complete
-            if self.embedding_pipeline:
-                await embedding_task
-            if self.kg_pipeline:
-                await kg_task
-
-    def add_pipe(
-        self,
-        pipe: AsyncPipe,
-        add_upstream_outputs: Optional[list[dict[str, str]]] = None,
-        parsing_pipe: bool = False,
-        kg_pipe: bool = False,
-        embedding_pipe: bool = False,
-        *args,
-        **kwargs,
-    ) -> None:
-        logger.debug(
-            f"Adding pipe {pipe.config.name} to the IngestionPipeline"
-        )
-
-        if parsing_pipe:
-            self.parsing_pipe = pipe
-        elif kg_pipe:
-            if not self.kg_pipeline:
-                self.kg_pipeline = Pipeline(log_run_info=False)
-            self.kg_pipeline.add_pipe(
-                pipe, add_upstream_outputs, *args, **kwargs
-            )
-        elif embedding_pipe:
-            if not self.embedding_pipeline:
-                self.embedding_pipeline = Pipeline(log_run_info=False)
-            self.embedding_pipeline.add_pipe(
-                pipe, add_upstream_outputs, *args, **kwargs
-            )
-        else:
-            raise ValueError("Pipe must be a parsing, embedding, or KG pipe")
-
-
-class RAGPipeline(Pipeline):
-    """A pipeline for RAG."""
-
-    pipeline_type: str = "rag"
-
-    async def run(
-        self,
-        input: Any,
-        state: Optional[AsyncState] = None,
-        streaming: bool = False,
-        run_manager: Optional[RunManager] = None,
-        *args: Any,
-        **kwargs: Any,
-    ):
-        return await super().run(
-            input, state, streaming, run_manager, *args, **kwargs
-        )
-
-    def add_pipe(
-        self,
-        pipe: AsyncPipe,
-        add_upstream_outputs: Optional[list[dict[str, str]]] = None,
-        *args,
-        **kwargs,
-    ) -> None:
-        logger.debug(f"Adding pipe {pipe.config.name} to the RAGPipeline")
-        return super().add_pipe(pipe, add_upstream_outputs, *args, **kwargs)
-
-
-class SearchPipeline(Pipeline):
-    """A pipeline for search."""
-
-    pipeline_type: str = "search"
-
-    async def run(
-        self,
-        input: Any,
-        state: Optional[AsyncState] = None,
-        streaming: bool = False,
-        run_manager: Optional[RunManager] = None,
-        *args: Any,
-        **kwargs: Any,
-    ):
-        return await super().run(
-            input, state, streaming, run_manager, *args, **kwargs
-        )
-
-    def add_pipe(
-        self,
-        pipe: AsyncPipe,
-        add_upstream_outputs: Optional[list[dict[str, str]]] = None,
-        *args,
-        **kwargs,
-    ) -> None:
-        logger.debug(f"Adding pipe {pipe.config.name} to the SearchPipeline")
-        return super().add_pipe(pipe, add_upstream_outputs, *args, **kwargs)
+async def dequeue_requests(queue: asyncio.Queue) -> AsyncGenerator:
+    """Create an async generator to dequeue requests."""
+    while True:
+        request = await queue.get()
+        if request is None:
+            break
+        yield request
diff --git a/r2r/core/pipeline/ingestion_pipeline.py b/r2r/core/pipeline/ingestion_pipeline.py
new file mode 100644
index 00000000..6ef94616
--- /dev/null
+++ b/r2r/core/pipeline/ingestion_pipeline.py
@@ -0,0 +1,143 @@
+import asyncio
+import logging
+from asyncio import Queue
+from typing import Any, Optional
+
+from ..logging.kv_logger import KVLoggingSingleton
+from ..logging.run_manager import RunManager, manage_run
+from ..pipes.base_pipe import AsyncPipe, AsyncState
+from .base_pipeline import Pipeline, dequeue_requests
+
+logger = logging.getLogger(__name__)
+
+
+class IngestionPipeline(Pipeline):
+    """A pipeline for ingestion."""
+
+    pipeline_type: str = "ingestion"
"ingestion" + + def __init__( + self, + pipe_logger: Optional[KVLoggingSingleton] = None, + run_manager: Optional[RunManager] = None, + ): + super().__init__(pipe_logger, run_manager) + self.parsing_pipe = None + self.embedding_pipeline = None + self.kg_pipeline = None + + async def run( + self, + input: Any, + state: Optional[AsyncState] = None, + stream: bool = False, + run_manager: Optional[RunManager] = None, + log_run_info: bool = True, + *args: Any, + **kwargs: Any, + ) -> None: + self.state = state or AsyncState() + + async with manage_run(run_manager, self.pipeline_type): + if log_run_info: + await run_manager.log_run_info( + key="pipeline_type", + value=self.pipeline_type, + is_info_log=True, + ) + if self.parsing_pipe is None: + raise ValueError( + "parsing_pipeline must be set before running the ingestion pipeline" + ) + if self.embedding_pipeline is None and self.kg_pipeline is None: + raise ValueError( + "At least one of embedding_pipeline or kg_pipeline must be set before running the ingestion pipeline" + ) + # Use queues to duplicate the documents for each pipeline + embedding_queue = Queue() + kg_queue = Queue() + + async def enqueue_documents(): + async for document in await self.parsing_pipe.run( + self.parsing_pipe.Input(message=input), + state, + run_manager, + *args, + **kwargs, + ): + if self.embedding_pipeline: + await embedding_queue.put(document) + if self.kg_pipeline: + await kg_queue.put(document) + await embedding_queue.put(None) + await kg_queue.put(None) + + # Start the document enqueuing process + enqueue_task = asyncio.create_task(enqueue_documents()) + + # Start the embedding and KG pipelines in parallel + if self.embedding_pipeline: + embedding_task = asyncio.create_task( + self.embedding_pipeline.run( + dequeue_requests(embedding_queue), + state, + stream, + run_manager, + log_run_info=False, # Do not log run info since we have already done so + *args, + **kwargs, + ) + ) + + if self.kg_pipeline: + kg_task = asyncio.create_task( + self.kg_pipeline.run( + dequeue_requests(kg_queue), + state, + stream, + run_manager, + log_run_info=False, # Do not log run info since we have already done so + *args, + **kwargs, + ) + ) + + # Wait for the enqueueing task to complete + await enqueue_task + + # Wait for the embedding and KG tasks to complete + if self.embedding_pipeline: + await embedding_task + if self.kg_pipeline: + await kg_task + + def add_pipe( + self, + pipe: AsyncPipe, + add_upstream_outputs: Optional[list[dict[str, str]]] = None, + parsing_pipe: bool = False, + kg_pipe: bool = False, + embedding_pipe: bool = False, + *args, + **kwargs, + ) -> None: + logger.debug( + f"Adding pipe {pipe.config.name} to the IngestionPipeline" + ) + + if parsing_pipe: + self.parsing_pipe = pipe + elif kg_pipe: + if not self.kg_pipeline: + self.kg_pipeline = Pipeline() + self.kg_pipeline.add_pipe( + pipe, add_upstream_outputs, *args, **kwargs + ) + elif embedding_pipe: + if not self.embedding_pipeline: + self.embedding_pipeline = Pipeline() + self.embedding_pipeline.add_pipe( + pipe, add_upstream_outputs, *args, **kwargs + ) + else: + raise ValueError("Pipe must be a parsing, embedding, or KG pipe") diff --git a/r2r/core/pipeline/pipeline_router.py b/r2r/core/pipeline/pipeline_router.py deleted file mode 100644 index 3a9e841e..00000000 --- a/r2r/core/pipeline/pipeline_router.py +++ /dev/null @@ -1,42 +0,0 @@ -import random -from typing import Any, Optional - -from ..logging.kv_logger import KVLoggingSingleton -from ..logging.run_manager import RunManager -from 
..pipes.base_pipe import AsyncState -from .base_pipeline import Pipeline - - -class PipelineRouter(Pipeline): - """PipelineRouter for routing to different pipelines based on weights.""" - - def __init__( - self, - pipelines: dict[Pipeline, float], - pipe_logger: Optional[KVLoggingSingleton] = None, - run_manager: Optional[RunManager] = None, - ): - super().__init__(pipe_logger, run_manager) - if not abs(sum(pipelines.values()) - 1.0) < 1e-6: - raise ValueError("The weights must sum to 1") - self.pipelines = pipelines - - async def run( - self, - input: Any, - state: Optional[AsyncState] = None, - streaming: bool = False, - run_manager: Optional[RunManager] = None, - *args: Any, - **kwargs: Any, - ): - run_manager = run_manager or self.run_manager - pipeline = self.select_pipeline() - return await pipeline.run( - input, state, streaming, run_manager, *args, **kwargs - ) - - def select_pipeline(self) -> Pipeline: - pipelines, weights = zip(*self.pipelines.items()) - selected_pipeline = random.choices(pipelines, weights)[0] - return selected_pipeline diff --git a/r2r/core/pipeline/rag_pipeline.py b/r2r/core/pipeline/rag_pipeline.py new file mode 100644 index 00000000..fed2c4ed --- /dev/null +++ b/r2r/core/pipeline/rag_pipeline.py @@ -0,0 +1,115 @@ +import asyncio +import logging +from typing import Any, Optional + +from ..abstractions.llm import GenerationConfig +from ..abstractions.search import KGSearchSettings, VectorSearchSettings +from ..logging.kv_logger import KVLoggingSingleton +from ..logging.run_manager import RunManager, manage_run +from ..pipes.base_pipe import AsyncPipe, AsyncState +from ..utils import to_async_generator +from .base_pipeline import Pipeline + +logger = logging.getLogger(__name__) + + +class RAGPipeline(Pipeline): + """A pipeline for RAG.""" + + pipeline_type: str = "rag" + + def __init__( + self, + pipe_logger: Optional[KVLoggingSingleton] = None, + run_manager: Optional[RunManager] = None, + ): + super().__init__(pipe_logger, run_manager) + self._search_pipeline = None + self._rag_pipeline = None + + async def run( + self, + input: Any, + state: Optional[AsyncState] = None, + run_manager: Optional[RunManager] = None, + log_run_info=True, + vector_search_settings: VectorSearchSettings = VectorSearchSettings(), + kg_search_settings: KGSearchSettings = KGSearchSettings(), + rag_generation_config: GenerationConfig = GenerationConfig(), + *args: Any, + **kwargs: Any, + ): + self.state = state or AsyncState() + async with manage_run(run_manager, self.pipeline_type): + if log_run_info: + await run_manager.log_run_info( + key="pipeline_type", + value=self.pipeline_type, + is_info_log=True, + ) + + if not self._search_pipeline: + raise ValueError( + "_search_pipeline must be set before running the RAG pipeline" + ) + + async def multi_query_generator(input): + tasks = [] + async for query in input: + task = asyncio.create_task( + self._search_pipeline.run( + to_async_generator([query]), + state=state, + stream=False, # do not stream the search results + run_manager=run_manager, + log_run_info=False, # do not log the run info as it is already logged above + vector_search_settings=vector_search_settings, + kg_search_settings=kg_search_settings, + *args, + **kwargs, + ) + ) + tasks.append((query, task)) + + for query, task in tasks: + yield (query, await task) + + rag_results = await self._rag_pipeline.run( + input=multi_query_generator(input), + state=state, + stream=rag_generation_config.stream, + run_manager=run_manager, + log_run_info=False, + 
+                rag_generation_config=rag_generation_config,
+                *args,
+                **kwargs,
+            )
+            return rag_results
+
+    def add_pipe(
+        self,
+        pipe: AsyncPipe,
+        add_upstream_outputs: Optional[list[dict[str, str]]] = None,
+        rag_pipe: bool = True,
+        *args,
+        **kwargs,
+    ) -> None:
+        logger.debug(f"Adding pipe {pipe.config.name} to the RAGPipeline")
+        if not rag_pipe:
+            raise ValueError(
+                "Only pipes that are part of the RAG pipeline can be added to the RAG pipeline"
+            )
+        if not self._rag_pipeline:
+            self._rag_pipeline = Pipeline()
+        self._rag_pipeline.add_pipe(
+            pipe, add_upstream_outputs, *args, **kwargs
+        )
+
+    def set_search_pipeline(
+        self,
+        _search_pipeline: Pipeline,
+        *args,
+        **kwargs,
+    ) -> None:
+        logger.debug(f"Setting search pipeline for the RAGPipeline")
+        self._search_pipeline = _search_pipeline
diff --git a/r2r/core/pipeline/search_pipeline.py b/r2r/core/pipeline/search_pipeline.py
new file mode 100644
index 00000000..89088531
--- /dev/null
+++ b/r2r/core/pipeline/search_pipeline.py
@@ -0,0 +1,139 @@
+import asyncio
+import logging
+from asyncio import Queue
+from typing import Any, Optional
+
+from ..abstractions.search import (
+    AggregateSearchResult,
+    KGSearchSettings,
+    VectorSearchSettings,
+)
+from ..logging.kv_logger import KVLoggingSingleton
+from ..logging.run_manager import RunManager, manage_run
+from ..pipes.base_pipe import AsyncPipe, AsyncState
+from .base_pipeline import Pipeline, dequeue_requests
+
+logger = logging.getLogger(__name__)
+
+
+class SearchPipeline(Pipeline):
+    """A pipeline for search."""
+
+    pipeline_type: str = "search"
+
+    def __init__(
+        self,
+        pipe_logger: Optional[KVLoggingSingleton] = None,
+        run_manager: Optional[RunManager] = None,
+    ):
+        super().__init__(pipe_logger, run_manager)
+        self._parsing_pipe = None
+        self._vector_search_pipeline = None
+        self._kg_search_pipeline = None
+
+    async def run(
+        self,
+        input: Any,
+        state: Optional[AsyncState] = None,
+        stream: bool = False,
+        run_manager: Optional[RunManager] = None,
+        log_run_info: bool = True,
+        vector_search_settings: VectorSearchSettings = VectorSearchSettings(),
+        kg_search_settings: KGSearchSettings = KGSearchSettings(),
+        *args: Any,
+        **kwargs: Any,
+    ):
+        self.state = state or AsyncState()
+        do_vector_search = (
+            self._vector_search_pipeline is not None
+            and vector_search_settings.use_vector_search
+        )
+        do_kg = (
+            self._kg_search_pipeline is not None and kg_search_settings.use_kg
+        )
+        async with manage_run(run_manager, self.pipeline_type):
+            if log_run_info:
+                await run_manager.log_run_info(
+                    key="pipeline_type",
+                    value=self.pipeline_type,
+                    is_info_log=True,
+                )
+
+            vector_search_queue = Queue()
+            kg_queue = Queue()
+
+            async def enqueue_requests():
+                async for message in input:
+                    if do_vector_search:
+                        await vector_search_queue.put(message)
+                    if do_kg:
+                        await kg_queue.put(message)
+
+                await vector_search_queue.put(None)
+                await kg_queue.put(None)
+
+            # Start the document enqueuing process
+            enqueue_task = asyncio.create_task(enqueue_requests())
+
+            # Start the embedding and KG pipelines in parallel
+            if do_vector_search:
+                vector_search_task = asyncio.create_task(
+                    self._vector_search_pipeline.run(
+                        dequeue_requests(vector_search_queue),
+                        state,
+                        stream,
+                        run_manager,
+                        log_run_info=False,
+                        vector_search_settings=vector_search_settings,
+                    )
+                )
+
+            if do_kg:
+                kg_task = asyncio.create_task(
+                    self._kg_search_pipeline.run(
+                        dequeue_requests(kg_queue),
+                        state,
+                        stream,
+                        run_manager,
+                        log_run_info=False,
+                        kg_search_settings=kg_search_settings,
+                    )
+                )
+
+            await enqueue_task
+
+            vector_search_results = (
+                await vector_search_task if do_vector_search else None
+            )
+            kg_results = await kg_task if do_kg else None
+
+            return AggregateSearchResult(
+                vector_search_results=vector_search_results,
+                kg_search_results=kg_results,
+            )
+
+    def add_pipe(
+        self,
+        pipe: AsyncPipe,
+        add_upstream_outputs: Optional[list[dict[str, str]]] = None,
+        kg_pipe: bool = False,
+        vector_search_pipe: bool = False,
+        *args,
+        **kwargs,
+    ) -> None:
+        logger.debug(f"Adding pipe {pipe.config.name} to the SearchPipeline")
+
+        if kg_pipe:
+            if not self._kg_search_pipeline:
+                self._kg_search_pipeline = Pipeline()
+            self._kg_search_pipeline.add_pipe(
+                pipe, add_upstream_outputs, *args, **kwargs
+            )
+        elif vector_search_pipe:
+            if not self._vector_search_pipeline:
+                self._vector_search_pipeline = Pipeline()
+            self._vector_search_pipeline.add_pipe(
+                pipe, add_upstream_outputs, *args, **kwargs
+            )
+        else:
+            raise ValueError("Pipe must be a vector search or KG pipe")
diff --git a/r2r/core/pipes/base_pipe.py b/r2r/core/pipes/base_pipe.py
index ebb147af..57fadfc1 100644
--- a/r2r/core/pipes/base_pipe.py
+++ b/r2r/core/pipes/base_pipe.py
@@ -89,7 +89,9 @@ def __init__(
         self._run_info = None
         self._type = type

-        logger.info(f"Initialized pipe {self.config.name} of type {self.type}")
+        logger.debug(
+            f"Initialized pipe {self.config.name} of type {self.type}"
+        )

     @property
     def config(self) -> PipeConfig:
diff --git a/r2r/core/providers/embedding_provider.py b/r2r/core/providers/embedding_provider.py
index 022fe1fe..3198584a 100644
--- a/r2r/core/providers/embedding_provider.py
+++ b/r2r/core/providers/embedding_provider.py
@@ -1,10 +1,13 @@
+import logging
 from abc import abstractmethod
 from enum import Enum
 from typing import Optional

-from ..abstractions.search import SearchResult
+from ..abstractions.search import VectorSearchResult
 from .base_provider import Provider, ProviderConfig

+logger = logging.getLogger(__name__)
+

 class EmbeddingConfig(ProviderConfig):
     """A base embedding configuration class"""
@@ -38,6 +41,7 @@ def __init__(self, config: EmbeddingConfig):
             raise ValueError(
                 "EmbeddingProvider must be initialized with a `EmbeddingConfig`."
) + logger.info(f"Initializing EmbeddingProvider with config {config}.") super().__init__(config) @@ -65,7 +69,7 @@ async def async_get_embeddings( def rerank( self, query: str, - results: list[SearchResult], + results: list[VectorSearchResult], stage: PipeStage = PipeStage.RERANK, limit: int = 10, ): diff --git a/r2r/core/providers/eval_provider.py b/r2r/core/providers/eval_provider.py index 9e4b7891..76053f87 100644 --- a/r2r/core/providers/eval_provider.py +++ b/r2r/core/providers/eval_provider.py @@ -1,7 +1,8 @@ from typing import Optional, Union +from ..abstractions.llm import GenerationConfig from .base_provider import Provider, ProviderConfig -from .llm_provider import GenerationConfig, LLMConfig +from .llm_provider import LLMConfig class EvalConfig(ProviderConfig): diff --git a/r2r/core/providers/llm_provider.py b/r2r/core/providers/llm_provider.py index b510e5f9..65bd363e 100644 --- a/r2r/core/providers/llm_provider.py +++ b/r2r/core/providers/llm_provider.py @@ -4,7 +4,7 @@ from abc import abstractmethod from typing import Optional -from pydantic import BaseModel +from r2r.core.abstractions.llm import GenerationConfig from ..abstractions.llm import LLMChatCompletion, LLMChatCompletionChunk from .base_provider import Provider, ProviderConfig @@ -12,24 +12,6 @@ logger = logging.getLogger(__name__) -class GenerationConfig(BaseModel): - temperature: float = 0.1 - top_p: float = 1.0 - top_k: int = 100 - max_tokens_to_sample: int = 1_024 - model: str - stream: bool = False - functions: Optional[list[dict]] = None - skip_special_tokens: bool = False - stop_token: Optional[str] = None - num_beams: int = 1 - do_sample: bool = True - # Additional args to pass to the generation config - generate_with_chat: bool = False - add_generation_kwargs: Optional[dict] = {} - api_base: Optional[str] = None - - class LLMConfig(ProviderConfig): """A base LLM config class""" diff --git a/r2r/core/providers/vector_db_provider.py b/r2r/core/providers/vector_db_provider.py index 09530faf..0e22bfcc 100644 --- a/r2r/core/providers/vector_db_provider.py +++ b/r2r/core/providers/vector_db_provider.py @@ -3,7 +3,7 @@ from typing import Optional, Union from ..abstractions.document import DocumentInfo -from ..abstractions.search import SearchResult +from ..abstractions.search import VectorSearchResult from ..abstractions.vector import VectorEntry from .base_provider import Provider, ProviderConfig @@ -58,7 +58,7 @@ def search( limit: int = 10, *args, **kwargs, - ) -> list[SearchResult]: + ) -> list[VectorSearchResult]: pass @abstractmethod @@ -74,7 +74,7 @@ def hybrid_search( rrf_k: int = 20, # typical value is ~2x the number of results you want *args, **kwargs, - ) -> list[SearchResult]: + ) -> list[VectorSearchResult]: pass @abstractmethod diff --git a/r2r/examples/quickstart.py b/r2r/examples/quickstart.py index 758116bc..250d81ff 100644 --- a/r2r/examples/quickstart.py +++ b/r2r/examples/quickstart.py @@ -12,12 +12,14 @@ AnalysisTypes, Document, FilterCriteria, - GenerationConfig, + KGSearchSettings, R2RAppBuilder, R2RClient, R2RConfig, + VectorSearchSettings, generate_id_from_label, ) +from r2r.core.abstractions.llm import GenerationConfig logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) @@ -60,24 +62,24 @@ def __init__( f"Running in client-server mode with base_url: {self.base_url}" ) else: - self.r2r = R2RAppBuilder(config).build() + self.r2r_app = R2RAppBuilder(config).build() logger.info("Running locally") root_path = os.path.dirname(os.path.abspath(__file__)) self.user_id = user_id 
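> Aside: the quickstart changes above standardize on importing `GenerationConfig` from `r2r.core.abstractions.llm` and on the new `VectorSearchSettings`/`KGSearchSettings` objects. A minimal sketch of the resulting call-site pattern follows; the field values are illustrative assumptions, not defaults taken from this changeset.

```python
# Hedged sketch of the new settings objects; concrete values here are
# illustrative assumptions, not defaults from the source.
from r2r import KGSearchSettings, VectorSearchSettings
from r2r.core.abstractions.llm import GenerationConfig

vector_settings = VectorSearchSettings(
    use_vector_search=True,
    search_filters={"user_id": "example-user"},  # hypothetical filter
    search_limit=10,
    do_hybrid_search=False,
)
kg_settings = KGSearchSettings(
    use_kg=False,
    agent_generation_config=GenerationConfig(model="gpt-4o"),
)
```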
self.default_files = file_list or [ - os.path.join(root_path, "data", "got.txt"), os.path.join(root_path, "data", "aristotle.txt"), - os.path.join(root_path, "data", "screen_shot.png"), - os.path.join(root_path, "data", "pg_essay_1.html"), - os.path.join(root_path, "data", "pg_essay_2.html"), - os.path.join(root_path, "data", "pg_essay_3.html"), - os.path.join(root_path, "data", "pg_essay_4.html"), - os.path.join(root_path, "data", "pg_essay_5.html"), - os.path.join(root_path, "data", "lyft_2021.pdf"), - os.path.join(root_path, "data", "uber_2021.pdf"), - os.path.join(root_path, "data", "sample.mp3"), - os.path.join(root_path, "data", "sample2.mp3"), + # os.path.join(root_path, "data", "got.txt"), + # os.path.join(root_path, "data", "screen_shot.png"), + # os.path.join(root_path, "data", "pg_essay_1.html"), + # os.path.join(root_path, "data", "pg_essay_2.html"), + # os.path.join(root_path, "data", "pg_essay_3.html"), + # os.path.join(root_path, "data", "pg_essay_4.html"), + # os.path.join(root_path, "data", "pg_essay_5.html"), + # os.path.join(root_path, "data", "lyft_2021.pdf"), + # os.path.join(root_path, "data", "uber_2021.pdf"), + # os.path.join(root_path, "data", "sample.mp3"), + # os.path.join(root_path, "data", "sample2.mp3"), ] self.file_tuples = file_tuples or [ @@ -100,7 +102,7 @@ def ingest_as_documents(self, file_paths: Optional[list[str]] = None): Document( id=generate_id_from_label(file_path), user_id=self.user_id, - title=file_path.split(os.path.sep)[-1], + title=file_path, data=data, type=file_path.split(".")[-1], metadata={}, @@ -111,7 +113,7 @@ def ingest_as_documents(self, file_paths: Optional[list[str]] = None): documents_dicts = [doc.dict() for doc in documents] response = self.client.ingest_documents(documents_dicts) else: - response = self.r2r.ingest_documents(documents) + response = self.r2r_app.ingest_documents(documents) t1 = time.time() print(f"Time taken to ingest files: {t1-t0:.2f} seconds") @@ -143,7 +145,7 @@ def update_as_documents(self, file_tuples: Optional[list[tuple]] = None): documents_dicts = [doc.dict() for doc in documents] response = self.client.update_documents(documents_dicts) else: - response = self.r2r.update_documents(documents) + response = self.r2r_app.update_documents(documents) t1 = time.time() print(f"Time taken to update documents: {t1-t0:.2f} seconds") @@ -169,7 +171,7 @@ def ingest_as_files( files = [ UploadFile( - filename=file_path.split(os.path.sep)[-1], + filename=file_path, file=open(file_path, "rb"), ) for file_path in file_paths @@ -180,18 +182,23 @@ def ingest_as_files( file.size = file.file.tell() file.file.seek(0) - metadatas = [{} for _ in file_paths] - user_ids = [self.user_id for _ in file_paths] t0 = time.time() if hasattr(self, "client"): response = self.client.ingest_files( - metadatas=None, files=file_paths, ids=ids, user_ids=user_ids + metadatas=None, + file_paths=file_paths, + document_ids=ids, + user_ids=user_ids, ) else: - response = self.r2r.ingest_files( - files=files, metadatas=metadatas, ids=ids, user_ids=user_ids + metadatas = [{} for _ in file_paths] + response = self.r2r_app.ingest_files( + files=files, + metadatas=metadatas, + document_ids=ids, + user_ids=user_ids, ) t1 = time.time() print(f"Time taken to ingest files: {t1-t0:.2f} seconds") @@ -202,7 +209,7 @@ def update_as_files(self, file_tuples: Optional[list[tuple]] = None): new_files = [ UploadFile( - filename=new_file.split(os.path.sep)[-1], + filename=new_file, file=open(new_file, "rb"), ) for old_file, new_file in file_tuples @@ -215,7 +222,7 @@ def 
update_as_files(self, file_tuples: Optional[list[tuple]] = None): metadatas = [ { - "title": old_file.split(os.path.sep)[-1], + "title": old_file, "user_id": self.user_id, } for old_file, new_file in file_tuples @@ -226,17 +233,17 @@ def update_as_files(self, file_tuples: Optional[list[tuple]] = None): response = self.client.update_files( metadatas=metadatas, files=[new for old, new in file_tuples], - ids=[ - generate_id_from_label(old_file.split(os.path.sep)[-1]) + document_ids=[ + generate_id_from_label(old_file) for old_file, new_file in file_tuples ], ) else: - response = self.r2r.update_files( + response = self.r2r_app.update_files( files=new_files, metadatas=metadatas, - ids=[ - generate_id_from_label(old_file.split(os.path.sep)[-1]) + document_ids=[ + generate_id_from_label(old_file) for old_file, new_file in file_tuples ], ) @@ -244,51 +251,106 @@ def update_as_files(self, file_tuples: Optional[list[tuple]] = None): print(f"Time taken to update files: {t1-t0:.2f} seconds") print(response) - def search(self, query: str, do_hybrid_search: bool = False): + def search( + self, + query: str, + use_vector_search: bool = True, + search_filters: Optional[dict] = None, + search_limit: int = 10, + do_hybrid_search: bool = False, + use_kg: bool = False, + kg_agent_generation_config: Optional[dict] = None, + ): + + kg_agent_generation_config = ( + GenerationConfig(**kg_agent_generation_config) + if kg_agent_generation_config + else GenerationConfig(model="gpt-4o") + ) + t0 = time.time() if hasattr(self, "client"): results = self.client.search( query, - search_filters={"user_id": self.user_id}, - do_hybrid_search=do_hybrid_search, + use_vector_search, + search_filters, + search_limit, + do_hybrid_search, + use_kg, + kg_agent_generation_config, ) else: - results = self.r2r.search( + results = self.r2r_app.search( query, - search_filters={"user_id": self.user_id}, - do_hybrid_search=do_hybrid_search, + VectorSearchSettings( + use_vector_search=use_vector_search, + search_filters=search_filters, + search_limit=search_limit, + do_hybrid_search=do_hybrid_search, + ), + KGSearchSettings( + use_kg=use_kg, + agent_generation_config=kg_agent_generation_config, + ), + ) + + if "vector_search_results" in results["results"]: + print("Vector search results:") + for result in results["results"]["vector_search_results"]: + print(result) + if ( + "kg_search_results" in results["results"] + and results["results"]["kg_search_results"] + ): + print( + "KG search results:", results["results"]["kg_search_results"] ) t1 = time.time() print(f"Time taken to search: {t1-t0:.2f} seconds") - for result in results["results"]: - print(result) def rag( self, query: str, - rag_generation_config: Optional[dict] = None, - streaming: bool = False, + use_vector_search: bool = True, + search_filters: Optional[dict] = None, + search_limit: int = 10, + do_hybrid_search: bool = False, + use_kg: bool = False, + kg_agent_generation_config: Optional[dict] = None, + stream: bool = False, + rag_generation_config: Optional[GenerationConfig] = None, ): t0 = time.time() + + kg_agent_generation_config = ( + GenerationConfig(**kg_agent_generation_config) + if kg_agent_generation_config + else GenerationConfig(model="gpt-4o") + ) + + rag_generation_config = ( + GenerationConfig(**rag_generation_config, stream=stream) + if rag_generation_config + else GenerationConfig(model="gpt-4o", stream=stream) + ) + if hasattr(self, "client"): - if not streaming: - response = self.client.rag( - query, - search_filters={"user_id": self.user_id}, - 
rag_generation_config=rag_generation_config, - streaming=streaming, - ) + response = self.client.rag( + query, + use_vector_search, + search_filters, + search_limit, + do_hybrid_search, + use_kg, + kg_agent_generation_config, + rag_generation_config, + ) + if not stream: t1 = time.time() print(f"Time taken to get RAG response: {t1-t0:.2f} seconds") print(response) else: - response = self.client.rag( - query, - search_filters={"user_id": self.user_id}, - rag_generation_config=rag_generation_config, - streaming=streaming, - ) for chunk in response: print(chunk, end="", flush=True) t1 = time.time() @@ -296,20 +358,23 @@ def rag( f"\nTime taken to stream RAG response: {t1-t0:.2f} seconds" ) else: - rag_generation_config = ( - GenerationConfig(**rag_generation_config, streaming=streaming) - if rag_generation_config - else GenerationConfig( - **{"stream": streaming, "model": "gpt-3.5-turbo"} - ) - ) - response = self.r2r.rag( + response = self.r2r_app.rag( query, search_filters={"user_id": self.user_id}, + vector_search_settings=VectorSearchSettings( + use_vector_search=use_vector_search, + search_filters=search_filters, + search_limit=search_limit, + do_hybrid_search=do_hybrid_search, + ), + kg_search_settings=KGSearchSettings( + use_kg=use_kg, + agent_generation_config=kg_agent_generation_config, + ), rag_generation_config=rag_generation_config, ) - if not streaming: + if not stream: t1 = time.time() print(f"Time taken to get RAG response: {t1-t0:.2f} seconds") print(response) @@ -358,7 +423,7 @@ def evaluate( completion=completion, ) else: - response = self.r2r.evaluate( + response = self.r2r_app.evaluate( query=query, context=context, completion=completion, @@ -386,7 +451,7 @@ def delete( if hasattr(self, "client"): response = self.client.delete(keys, values) else: - response = self.r2r.delete(keys, values) + response = self.r2r_app.delete(keys, values) t1 = time.time() print(f"Time taken to delete: {t1-t0:.2f} seconds") print(response) @@ -397,7 +462,7 @@ def logs(self, log_type_filter: Optional[str] = None): response = self.client.logs(log_type_filter) else: t0 = time.time() - response = self.r2r.logs(log_type_filter) + response = self.r2r_app.logs(log_type_filter) t1 = time.time() print(f"Time taken to get logs: {t1-t0:.2f} seconds") print(response) @@ -412,17 +477,18 @@ def documents_info( response = self.client.documents_info(document_ids, user_ids) else: t0 = time.time() - response = self.r2r.documents_info(document_ids, user_ids) + response = self.r2r_app.documents_info(document_ids, user_ids) t1 = time.time() print(f"Time taken to get document info: {t1-t0:.2f} seconds") - print(response) + for document in response: + print(document) def document_chunks(self, document_id: str): t0 = time.time() if hasattr(self, "client"): response = self.client.document_chunks(document_id) else: - response = self.r2r.document_chunks(document_id) + response = self.r2r_app.document_chunks(document_id) t1 = time.time() print(f"Time taken to get document chunks: {t1-t0:.2f} seconds") print(response) @@ -433,7 +499,7 @@ def app_settings(self): response = self.client.app_settings() else: t0 = time.time() - response = self.r2r.app_settings() + response = self.r2r_app.app_settings() t1 = time.time() print(f"Time taken to get app data: {t1-t0:.2f} seconds") print(response) @@ -445,10 +511,11 @@ def users_stats(self, user_ids: Optional[list[uuid.UUID]] = None): response = self.client.users_stats(user_ids) else: t0 = time.time() - response = self.r2r.users_stats(user_ids) + response = 
self.r2r_app.users_stats(user_ids) t1 = time.time() print(f"Time taken to get user stats: {t1-t0:.2f} seconds") - print(response) + for user in response: + print(user) def analytics( self, @@ -465,7 +532,7 @@ def analytics( analysis_types=analysis_types.model_dump(), ) else: - response = self.r2r.analytics( + response = self.r2r_app.analytics( filter_criteria=filter_criteria, analysis_types=analysis_types ) @@ -474,7 +541,7 @@ def analytics( print(response) def serve(self, host: str = "0.0.0.0", port: int = 8000): - self.r2r.serve(host, port) + self.r2r_app.serve(host, port) if __name__ == "__main__": diff --git a/r2r/examples/scripts/advanced_kg_cookbook.py b/r2r/examples/scripts/advanced_kg_cookbook.py index 0e8929e4..524d32a9 100644 --- a/r2r/examples/scripts/advanced_kg_cookbook.py +++ b/r2r/examples/scripts/advanced_kg_cookbook.py @@ -7,13 +7,13 @@ from r2r import ( Document, EntityType, - GenerationConfig, - KGAgentPipe, + KGAgentSearchPipe, Pipeline, R2RAppBuilder, Relation, run_pipeline, ) +from r2r.core.abstractions.llm import GenerationConfig def get_all_yc_co_directory_urls(): @@ -214,13 +214,13 @@ def main(max_entries=50, delete=False): print_all_relationships(kg) - kg_agent_pipe = KGAgentPipe( + kg_agent_search_pipe = KGAgentSearchPipe( r2r_app.providers.kg, r2r_app.providers.llm, r2r_app.providers.prompt ) # Define the pipeline kg_pipe = Pipeline() - kg_pipe.add_pipe(kg_agent_pipe) + kg_pipe.add_pipe(kg_agent_search_pipe) kg.update_agent_prompt(prompt_provider, entity_types, relations) diff --git a/r2r/examples/scripts/run_hyde.py b/r2r/examples/scripts/run_hyde.py index 53b0aef9..de9b7134 100644 --- a/r2r/examples/scripts/run_hyde.py +++ b/r2r/examples/scripts/run_hyde.py @@ -1,9 +1,5 @@ -from r2r import ( - GenerationConfig, - R2RAppBuilder, - R2RConfig, - R2RPipeFactoryWithMultiSearch, -) +from r2r import R2RAppBuilder, R2RConfig, R2RPipeFactoryWithMultiSearch +from r2r.core.abstractions.llm import GenerationConfig if __name__ == "__main__": # Load the configuration file diff --git a/r2r/examples/scripts/run_web_multi_search.py b/r2r/examples/scripts/run_web_multi_search.py index e7515a07..e38a4cc6 100644 --- a/r2r/examples/scripts/run_web_multi_search.py +++ b/r2r/examples/scripts/run_web_multi_search.py @@ -1,12 +1,12 @@ import fire from r2r import ( - GenerationConfig, R2RAppBuilder, R2RPipeFactoryWithMultiSearch, SerperClient, WebSearchPipe, ) +from r2r.core.abstractions.llm import GenerationConfig def run_rag_pipeline(query="Who was Aristotle?"): diff --git a/r2r/examples/scripts/run_web_search.py b/r2r/examples/scripts/run_web_search.py index 2a06dc1b..bbf5f870 100644 --- a/r2r/examples/scripts/run_web_search.py +++ b/r2r/examples/scripts/run_web_search.py @@ -1,6 +1,7 @@ import fire -from r2r import GenerationConfig, R2RAppBuilder, SerperClient, WebSearchPipe +from r2r import R2RAppBuilder, SerperClient, WebSearchPipe +from r2r.core.abstractions.llm import GenerationConfig def run_rag_pipeline(query="Who was Aristotle?"): diff --git a/r2r/main/__init__.py b/r2r/main/__init__.py index dce33590..6f149db0 100644 --- a/r2r/main/__init__.py +++ b/r2r/main/__init__.py @@ -1,4 +1,20 @@ -from .r2r_abstractions import R2RPipelines, R2RProviders +from .r2r_abstractions import ( + R2RAnalyticsRequest, + R2RDeleteRequest, + R2RDocumentChunksRequest, + R2RDocumentsInfoRequest, + R2REvalRequest, + R2RIngestDocumentsRequest, + R2RIngestFilesRequest, + R2RPipelines, + R2RProviders, + R2RRAGRequest, + R2RSearchRequest, + R2RUpdateDocumentsRequest, + R2RUpdateFilesRequest, + 
R2RUpdatePromptRequest, + R2RUsersStatsRequest, +) from .r2r_app import R2RApp from .r2r_builder import R2RAppBuilder from .r2r_client import R2RClient @@ -8,6 +24,19 @@ __all__ = [ "R2RPipelines", "R2RProviders", + "R2RUpdatePromptRequest", + "R2RIngestDocumentsRequest", + "R2RUpdateDocumentsRequest", + "R2RIngestFilesRequest", + "R2RUpdateFilesRequest", + "R2RSearchRequest", + "R2RRAGRequest", + "R2REvalRequest", + "R2RDeleteRequest", + "R2RAnalyticsRequest", + "R2RUsersStatsRequest", + "R2RDocumentsInfoRequest", + "R2RDocumentChunksRequest", "R2RApp", "R2RConfig", "R2RClient", diff --git a/r2r/main/r2r_abstractions.py b/r2r/main/r2r_abstractions.py index 211cd2b1..eca995de 100644 --- a/r2r/main/r2r_abstractions.py +++ b/r2r/main/r2r_abstractions.py @@ -1,19 +1,26 @@ -from typing import Optional +import uuid +from typing import Optional, Union from pydantic import BaseModel from r2r.core import ( + AnalysisTypes, + Document, EmbeddingProvider, EvalPipeline, EvalProvider, + FilterCriteria, + GenerationConfig, IngestionPipeline, KGProvider, + KGSearchSettings, LLMProvider, LoggableAsyncPipe, PromptProvider, RAGPipeline, SearchPipeline, VectorDBProvider, + VectorSearchSettings, ) @@ -39,6 +46,7 @@ class R2RPipes(BaseModel): eval_pipe: Optional[LoggableAsyncPipe] kg_pipe: Optional[LoggableAsyncPipe] kg_storage_pipe: Optional[LoggableAsyncPipe] + kg_agent_search_pipe: Optional[LoggableAsyncPipe] class Config: arbitrary_types_allowed = True @@ -53,3 +61,75 @@ class R2RPipelines(BaseModel): class Config: arbitrary_types_allowed = True + + +class R2RUpdatePromptRequest(BaseModel): + name: str + template: Optional[str] = None + input_types: Optional[dict[str, str]] = {} + + +class R2RIngestDocumentsRequest(BaseModel): + documents: list[Document] + versions: Optional[list[str]] = None + + +class R2RUpdateDocumentsRequest(BaseModel): + documents: list[Document] + versions: Optional[list[str]] = None + metadatas: Optional[list[dict]] = None + + +class R2RIngestFilesRequest(BaseModel): + metadatas: Optional[list[dict]] = None + document_ids: Optional[list[uuid.UUID]] = None + user_ids: Optional[list[Optional[uuid.UUID]]] = None + versions: Optional[list[str]] = None + skip_document_info: Optional[bool] = False + + +class R2RUpdateFilesRequest(BaseModel): + metadatas: Optional[list[dict]] = None + document_ids: Optional[list[uuid.UUID]] = None + + +class R2RSearchRequest(BaseModel): + query: str + vector_settings: VectorSearchSettings + kg_settings: KGSearchSettings + + +class R2RRAGRequest(BaseModel): + query: str + vector_settings: VectorSearchSettings + kg_settings: KGSearchSettings + rag_generation_config: Optional[GenerationConfig] = None + + +class R2REvalRequest(BaseModel): + query: str + context: str + completion: str + + +class R2RDeleteRequest(BaseModel): + keys: list[str] + values: list[Union[bool, int, str]] + + +class R2RAnalyticsRequest(BaseModel): + filter_criteria: FilterCriteria + analysis_types: AnalysisTypes + + +class R2RUsersStatsRequest(BaseModel): + user_ids: Optional[list[uuid.UUID]] + + +class R2RDocumentsInfoRequest(BaseModel): + document_ids: Optional[list[uuid.UUID]] + user_ids: Optional[list[uuid.UUID]] + + +class R2RDocumentChunksRequest(BaseModel): + document_id: uuid.UUID diff --git a/r2r/main/r2r_app.py b/r2r/main/r2r_app.py index 3f60ec7f..db0b4514 100644 --- a/r2r/main/r2r_app.py +++ b/r2r/main/r2r_app.py @@ -7,10 +7,17 @@ from datetime import datetime from typing import Any, Optional, Union -from fastapi import Body, FastAPI, File, Form, HTTPException, Query, UploadFile
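> Aside: the typed request models defined above in `r2r_abstractions.py` replace the ad-hoc `BaseModel` classes previously declared inline in `r2r_app.py`. A hedged sketch of the serialization round-trip they enable; the query text is an arbitrary example.

```python
# Sketch of the request-model round-trip used throughout this changeset:
# build a typed model, then serialize it to a plain dict for the HTTP body.
import json

from r2r.core import KGSearchSettings, VectorSearchSettings
from r2r.main import R2RSearchRequest

request = R2RSearchRequest(
    query="Who was Aristotle?",
    vector_settings=VectorSearchSettings(search_limit=5),
    kg_settings=KGSearchSettings(use_kg=False),
)
payload = json.loads(request.json())  # pydantic v1-style .json(), as used in r2r_client.py
# `payload` is now safe to send with requests.post(url, json=payload).
```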
+from fastapi import ( + Depends, + FastAPI, + File, + Form, + HTTPException, + Query, + UploadFile, +) from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import StreamingResponse -from pydantic import BaseModel from r2r.core import ( AnalysisTypes, @@ -18,7 +25,6 @@ DocumentInfo, DocumentType, FilterCriteria, - GenerationConfig, KVLoggingSingleton, LogProcessor, RunManager, @@ -27,10 +33,29 @@ manage_run, to_async_generator, ) +from r2r.core.abstractions.llm import GenerationConfig from r2r.pipes import EvalPipe from r2r.telemetry.telemetry_decorator import telemetry_event -from .r2r_abstractions import R2RPipelines, R2RProviders +from .r2r_abstractions import ( + KGSearchSettings, + R2RAnalyticsRequest, + R2RDeleteRequest, + R2RDocumentChunksRequest, + R2RDocumentsInfoRequest, + R2REvalRequest, + R2RIngestDocumentsRequest, + R2RIngestFilesRequest, + R2RPipelines, + R2RProviders, + R2RRAGRequest, + R2RSearchRequest, + R2RUpdateDocumentsRequest, + R2RUpdateFilesRequest, + R2RUpdatePromptRequest, + R2RUsersStatsRequest, + VectorSearchSettings, +) from .r2r_config import R2RConfig MB_CONVERSION_FACTOR = 1024 * 1024 @@ -142,6 +167,7 @@ def __init__( self, config: R2RConfig, providers: R2RProviders, + pipes: R2RPipelines, pipelines: R2RPipelines, run_manager: Optional[RunManager] = None, do_apply_cors: bool = True, @@ -150,6 +176,7 @@ def __init__( ): self.config = config self.providers = providers + self.pipes = pipes self.logging_connection = KVLoggingSingleton() self.ingestion_pipeline = pipelines.ingestion_pipeline self.search_pipeline = pipelines.search_pipeline @@ -159,6 +186,30 @@ def __init__( self.run_manager = run_manager or RunManager(self.logging_connection) self.app = FastAPI() + from fastapi import Request + from fastapi.exceptions import RequestValidationError + from fastapi.responses import JSONResponse + + @self.app.exception_handler(RequestValidationError) + async def validation_exception_handler( + request: Request, exc: RequestValidationError + ): + logger.error(f"Validation error for request: {request.url}") + logger.error(f"Validation error details: {exc.errors()}") + return JSONResponse( + status_code=422, + content={"detail": exc.errors(), "body": exc.body}, + ) + + @self.app.exception_handler(HTTPException) + async def http_exception_handler(request: Request, exc: HTTPException): + logger.error(f"HTTP exception for request: {request.url}") + logger.error(f"HTTP exception details: {exc.detail}") + return JSONResponse( + status_code=exc.status_code, + content={"detail": exc.detail}, + ) + self._setup_routes() if do_apply_cors: self._apply_cors() @@ -247,23 +298,18 @@ def _setup_routes(self): ) @syncable - async def aupsert_prompt( + async def aupdate_prompt( self, name: str, template: str, input_types: dict ): """Upsert a prompt into the system.""" - self.providers.prompt.add_prompt(name, template, input_types) + self.providers.prompt.update_prompt(name, template, input_types) return {"results": f"Prompt '{name}' added successfully."} - class UpdatePromptRequest(BaseModel): - name: str - template: Optional[str] = None - input_types: Optional[dict[str, str]] = None - @telemetry_event("UpdatePrompt") - async def update_prompt_app(self, request: UpdatePromptRequest): + async def update_prompt_app(self, request: R2RUpdatePromptRequest): """Update a prompt's template and/or input types.""" try: - return await self.aupsert_prompt( + return await self.aupdate_prompt( request.name, request.template, request.input_types ) except Exception as e: @@ -276,15 
+322,10 @@ async def update_prompt_app(self, request: UpdatePromptRequest): async def aingest_documents( self, documents: list[Document], - metadatas: Optional[list[dict]] = None, versions: Optional[list[str]] = None, *args: Any, **kwargs: Any, ): - if metadatas and len(metadatas) != len(documents): - raise ValueError( - "Number of metadata entries does not match number of documents." - ) if len(documents) == 0: raise HTTPException( status_code=400, detail="No documents provided for ingestion." ) @@ -311,20 +352,13 @@ async def aingest_documents( status_code=409, detail=f"Document with ID {document.id} already exists.", ) - skipped_documents.append(document.title or str(document.id)) + skipped_documents.append( + document.metadata.get("title", None) or str(document.id) + ) continue - document_metadata = ( - metadatas[iteration] if metadatas else document.metadata - ) - document_title = ( - document_metadata.get("title", None) or document.title - ) - document_metadata["title"] = document_title - - if document.user_id: - document_metadata["user_id"] = str(document.user_id) - document.metadata = document_metadata + document_title = document.metadata.get("title", None) + document_user_id = document.metadata.get("user_id", None) now = datetime.now() version = versions[iteration] if versions else "v0" @@ -334,16 +368,16 @@ async def aingest_documents( "document_id": document.id, "version": version, "size_in_bytes": len(document.data), - "metadata": document_metadata.copy(), + "metadata": document.metadata.copy(), "title": document_title, - "user_id": document.user_id, + "user_id": document_user_id, "created_at": now, "updated_at": now, } ) ) - processed_documents.append(document.title or str(document.id)) + processed_documents.append(document_title or str(document.id)) if skipped_documents and len(skipped_documents) == len(documents): logger.error("All provided documents already exist.") @@ -392,16 +426,15 @@ async def aingest_documents( ], } - class IngestDocumentsRequest(BaseModel): - documents: list[Document] @telemetry_event("IngestDocuments") - async def ingest_documents_app(self, request: IngestDocumentsRequest): + async def ingest_documents_app(self, request: R2RIngestDocumentsRequest): async with manage_run( self.run_manager, "ingest_documents_app" ) as run_id: try: - results = await self.aingest_documents(request.documents) + results = await self.aingest_documents( + request.documents, request.versions + ) return {"results": results} except HTTPException as he: @@ -459,12 +492,9 @@ async def aupdate_documents( document_metadata = ( metadatas[iteration] if metadatas else doc.metadata ) - document_metadata["title"] = ( - document_metadata.get("title", None) or doc.title - ) - document_metadata["user_id"] = ( - str(doc.user_id) if doc.user_id else None - ) + document_metadata["title"] = document_metadata.get( + "title", None + ) document_infos_modified.append( DocumentInfo( **{ @@ -473,7 +503,7 @@ async def aupdate_documents( "size_in_bytes": len(doc.data), "metadata": document_metadata.copy(), "title": document_metadata["title"], - "user_id": doc.user_id, + "user_id": document_metadata.get("user_id", None), "created_at": document_info.created_at, "updated_at": datetime.now(), } @@ -490,16 +520,15 @@ async def aupdate_documents( self.providers.vector_db.upsert_documents_info(document_infos_modified) return {"results": "Documents updated."} - class UpdateDocumentsRequest(BaseModel): - documents: list[Document] @telemetry_event("UpdateDocuments") - 
async def update_documents_app(self, request: UpdateDocumentsRequest): + async def update_documents_app(self, request: R2RUpdateDocumentsRequest): async with manage_run( self.run_manager, "update_documents_app" ) as run_id: try: - return await self.aupdate_documents(request.documents) + return await self.aupdate_documents( + request.documents, request.versions, request.metadatas + ) except Exception as e: await self.logging_connection.log( log_id=run_id, @@ -523,8 +552,8 @@ async def aingest_files( self, files: list[UploadFile], metadatas: Optional[list[dict]] = None, - document_ids: Optional[list[uuid.UUID]] = None, - user_ids: Optional[list[Optional[uuid.UUID]]] = None, + document_ids: Optional[list[str]] = None, + user_ids: Optional[list[Optional[str]]] = None, versions: Optional[list[str]] = None, skip_document_info: bool = False, *args: Any, @@ -595,9 +624,16 @@ async def aingest_files( status_code=415, detail=f"{file_extension} is explicitly excluded in the configuration file.", ) + document_metadata = metadatas[iteration] if metadatas else {} + + document_title = ( + document_metadata.get("title", None) + or file.filename.split(os.path.sep)[-1] + ) + document_metadata["title"] = document_title document_id = ( - generate_id_from_label(file.filename) + generate_id_from_label(document_title) if document_ids is None else document_ids[iteration] ) @@ -620,12 +656,6 @@ async def aingest_files( file_content = await file.read() logger.info(f"File read successfully: {file.filename}") - document_metadata = metadatas[iteration] if metadatas else {} - document_title = ( - document_metadata.get("title", None) or file.filename - ) - document_metadata["title"] = document_title - user_id = user_ids[iteration] if user_ids else None if user_id: document_metadata["user_id"] = str(user_id) @@ -698,59 +728,66 @@ async def aingest_files( for file in files: file.file.close() + def parse_ingest_files_form_data( + metadatas: Optional[str] = Form(None), + document_ids: str = Form(...), + user_ids: str = Form(...), + versions: Optional[str] = Form(None), + skip_document_info: bool = Form(False), + ) -> R2RIngestFilesRequest: + try: + # Parse the form data + request_data = { + "metadatas": ( + json.loads(metadatas) + if metadatas and metadatas != "null" + else None + ), + "document_ids": ( + [uuid.UUID(doc_id) for doc_id in json.loads(document_ids)] + if document_ids and document_ids != "null" + else None + ), + "user_ids": ( + [ + uuid.UUID(user_id) if user_id else None + for user_id in json.loads(user_ids) + ] + if user_ids and user_ids != "null" + else None + ), + "versions": ( + json.loads(versions) + if versions and versions != "null" + else None + ), + "skip_document_info": skip_document_info, + } + return R2RIngestFilesRequest(**request_data) + except Exception as e: + raise HTTPException( + status_code=400, detail=f"Invalid form data: {e}" + ) + @telemetry_event("IngestFiles") + # Cannot use `request = R2RIngestFilesRequest` here because of the file upload async def ingest_files_app( self, files: list[UploadFile] = File(...), - metadatas: Optional[str] = Form(None), - ids: Optional[str] = Form(None), - user_ids: Optional[str] = Form(None), + request: R2RIngestFilesRequest = Depends(parse_ingest_files_form_data), ): """Ingest files into the system.""" async with manage_run(self.run_manager, "ingest_files_app") as run_id: try: - if ids and ids != "null": - ids_list = json.loads(ids) - if len(ids_list) != 0: - try: - ids_list = [uuid.UUID(id) for id in ids_list] - except ValueError as e: - raise 
HTTPException( - status_code=400, - detail="Invalid UUID provided.", - ) from e - else: - ids_list = None - - if user_ids and user_ids != "null": - user_ids_list = json.loads(user_ids) - if len(user_ids_list) != 0: - try: - user_ids_list = [ - uuid.UUID(id) if id else None - for id in user_ids_list - ] - except ValueError as e: - raise HTTPException( - status_code=400, - detail="Invalid UUID provided.", - ) from e - else: - user_ids_list = None - - # Parse metadatas if provided - metadatas = ( - json.loads(metadatas) - if metadatas and metadatas != "null" - else None - ) # Call aingest_files with the correct order of arguments results = await self.aingest_files( files=files, - metadatas=metadatas, - document_ids=ids_list, - user_ids=user_ids_list, + metadatas=request.metadatas, + document_ids=request.document_ids, + user_ids=request.user_ids, + versions=request.versions, + skip_document_info=request.skip_document_info, ) return {"results": results} @@ -778,7 +815,7 @@ async def ingest_files_app( async def aupdate_files( self, files: list[UploadFile], - ids: list[uuid.UUID], + document_ids: list[uuid.UUID], metadatas: Optional[list[dict]] = None, *args: Any, **kwargs: Any, @@ -790,7 +827,7 @@ async def aupdate_files( try: # Parse ids if provided - if len(ids) != len(files): + if len(document_ids) != len(files): raise HTTPException( status_code=400, detail="Number of ids does not match number of files.", @@ -806,7 +843,9 @@ async def aupdate_files( # Get the current document info old_versions = [] new_versions = [] - documents_info = await self.adocuments_info(document_ids=ids) + documents_info = await self.adocuments_info( + document_ids=document_ids + ) documents_info_modified = [] if len(documents_info) != len(files): raise HTTPException( @@ -817,7 +856,7 @@ async def aupdate_files( if not document_info: raise HTTPException( status_code=404, - detail=f"Document with id {ids[it]} not found.", + detail=f"Document with id {document_ids[it]} not found.", ) current_version = document_info.version @@ -832,20 +871,21 @@ async def aupdate_files( document_info.updated_at = datetime.now() title = files[it].filename.split(os.path.sep)[-1] - document_info.title = title - document_info.metadata["title"] = title + document_info.metadata["title"] = ( + document_info.metadata.get("title", None) or title + ) documents_info_modified.append(document_info) await self.aingest_files( files, [ele.metadata for ele in documents_info_modified], - ids, + document_ids, versions=new_versions, skip_document_info=True, ) # Delete the old version - for id, old_version in zip(ids, old_versions): + for id, old_version in zip(document_ids, old_versions): await self.adelete( ["document_id", "version"], [str(id), old_version] ) @@ -862,43 +902,42 @@ async def aupdate_files( for file in files: file.file.close() - class UpdateFilesRequest(BaseModel): - files: list[UploadFile] = File(...) 
- metadatas: Optional[str] = Form(None) - ids: str = Form("") + def parse_update_files_form_data( + metadatas: Optional[str] = Form(None), + document_ids: str = Form(...), + ) -> R2RUpdateFilesRequest: + try: + # Parse the form data + request_data = { + "metadatas": ( + json.loads(metadatas) + if metadatas and metadatas != "null" + else None + ), + "document_ids": ( + [uuid.UUID(doc_id) for doc_id in json.loads(document_ids)] + if document_ids and document_ids != "null" + else None + ), + } + return R2RUpdateFilesRequest(**request_data) + except Exception as e: + raise HTTPException( + status_code=400, detail=f"Invalid form data: {e}" + ) @telemetry_event("UpdateFiles") async def update_files_app( self, files: list[UploadFile] = File(...), - metadatas: Optional[str] = Form(None), - ids: Optional[str] = Form(None), + request: R2RUpdateFilesRequest = Depends(parse_update_files_form_data), ): async with manage_run(self.run_manager, "update_files_app") as run_id: try: - # Parse metadatas if provided - metadatas = ( - json.loads(metadatas) - if metadatas and metadatas != "null" - else None - ) - - # Parse ids if provided - ids_list = json.loads(ids) - if ids_list: - ids_list = [uuid.UUID(id) for id in ids_list] - if len(ids_list) != len(files): - raise HTTPException( - status_code=400, - detail="Number of ids does not match number of files.", - ) - if metadatas and len(metadatas) != len(files): - raise HTTPException( - status_code=400, - detail="Number of metadata entries does not match number of files.", - ) return await self.aupdate_files( - files=files, metadatas=metadatas, ids=ids_list + files=files, + metadatas=request.metadatas, + document_ids=request.document_ids, ) except Exception as e: await self.logging_connection.log( @@ -920,9 +959,8 @@ async def update_files_app( async def asearch( self, query: str, - search_filters: Optional[dict] = None, - search_limit: int = 10, - do_hybrid_search: bool = False, + vector_search_settings: VectorSearchSettings = VectorSearchSettings(), + kg_search_settings: KGSearchSettings = KGSearchSettings(), *args: Any, **kwargs: Any, ): @@ -930,13 +968,11 @@ async def asearch( async with manage_run(self.run_manager, "search_app") as run_id: t0 = time.time() - search_filters = search_filters or {} results = await self.search_pipeline.run( input=to_async_generator([query]), - search_filters=search_filters, - search_limit=search_limit, + vector_search_settings=vector_search_settings, + kg_search_settings=kg_search_settings, run_manager=self.run_manager, - do_hybrid_search=do_hybrid_search, ) t1 = time.time() @@ -949,29 +985,14 @@ async def asearch( is_info_log=False, ) - return {"results": [result.dict() for result in results]} - - class SearchRequest(BaseModel): - query: str - search_filters: Optional[str] = None - search_limit: int = 10 - do_hybrid_search: Optional[bool] = False + return {"results": results.dict()} @telemetry_event("Search") - async def search_app(self, request: SearchRequest): + async def search_app(self, request: R2RSearchRequest): async with manage_run(self.run_manager, "search_app") as run_id: try: - search_filters = ( - {} - if request.search_filters is None - or request.search_filters == "null" - else json.loads(request.search_filters) - ) return await self.asearch( - request.query, - search_filters, - request.search_limit, - request.do_hybrid_search, + request.query, request.vector_settings, request.kg_settings ) except Exception as e: # TODO - Make this more modular @@ -993,17 +1014,16 @@ async def search_app(self, request: 
SearchRequest): @syncable async def arag( self, - message: str, - rag_generation_config: GenerationConfig, - search_filters: Optional[dict[str, str]] = None, - search_limit: int = 10, + query: str, + vector_search_settings: VectorSearchSettings = VectorSearchSettings(), + kg_search_settings: KGSearchSettings = KGSearchSettings(), + rag_generation_config: GenerationConfig = GenerationConfig(), *args, **kwargs, ): async with manage_run(self.run_manager, "rag_app") as run_id: try: t0 = time.time() - if rag_generation_config.stream: t1 = time.time() latency = f"{t1-t0:.2f}" await self.logging_connection.log( @@ -1016,17 +1036,16 @@ async def arag( ) async def stream_response(): - # We must re-enter the manage_run context for the stream pipeline + # We must re-enter the manage_run context for the stream pipeline async with manage_run(self.run_manager, "arag"): async for ( chunk ) in await self.streaming_rag_pipeline.run( - input=to_async_generator([message]), - streaming=True, - search_filters=search_filters, - search_limit=search_limit, - rag_generation_config=rag_generation_config, + input=to_async_generator([query]), run_manager=self.run_manager, + vector_settings=vector_search_settings, + kg_settings=kg_search_settings, + rag_generation_config=rag_generation_config, *args, **kwargs, ): @@ -1036,12 +1055,11 @@ async def stream_response(): if not rag_generation_config.stream: results = await self.rag_pipeline.run( - input=to_async_generator([message]), - streaming=False, - search_filters=search_filters, - search_limit=search_limit, - rag_generation_config=rag_generation_config, + input=to_async_generator([query]), run_manager=self.run_manager, + vector_search_settings=vector_search_settings, + kg_search_settings=kg_search_settings, + rag_generation_config=rag_generation_config, ) t1 = time.time() @@ -1066,65 +1084,21 @@ async def stream_response(): status_code=500, detail="Internal Server Error" ) - class RAGRequest(BaseModel): - message: str - search_filters: Optional[str] = None - search_limit: int = 10 - rag_generation_config: Optional[str] = None - streaming: Optional[bool] = None - @telemetry_event("RAG") - async def rag_app(self, request: RAGRequest): + async def rag_app(self, request: R2RRAGRequest): async with manage_run(self.run_manager, "rag_app") as run_id: try: - # Parse search filters - search_filters = None - if request.search_filters and request.search_filters != "null": - try: - search_filters = json.loads(request.search_filters) - except json.JSONDecodeError as jde: - logger.error( - f"Error parsing search filters: {str(jde)}" - ) - raise HTTPException( - status_code=400, - detail=f"Error parsing search filters: {str(jde)}", - ) - - # Parse RAG generation config - rag_generation_config = GenerationConfig( - model="gpt-3.5-turbo", stream=request.streaming - ) - if ( - request.rag_generation_config - and request.rag_generation_config != "null" - ): - try: - parsed_config = json.loads( - request.rag_generation_config - ) - rag_generation_config = GenerationConfig( - **parsed_config, - stream=request.streaming, - ) - except json.JSONDecodeError as jde: - logger.error( - f"Error parsing RAG generation config: {str(jde)}" - ) - raise HTTPException( - status_code=400, - detail=f"Error parsing RAG generation config: {str(jde)}", - ) from jde - # Call the async RAG method + rag_generation_config = ( + request.rag_generation_config + or GenerationConfig(model="gpt-4o") + ) response = await self.arag( - request.message, - rag_generation_config, - search_filters, - request.search_limit, + request.query, + request.vector_settings, + request.kg_settings, + rag_generation_config, ) - if request.streaming: + if rag_generation_config.stream: return StreamingResponse( response, media_type="application/json" ) @@ -1154,7 +1128,7 @@ async def rag_app(self, request: RAGRequest): value=str(e), is_info_log=False, ) - raise HTTPException(status_code=500, detail=str(e)) + raise HTTPException(status_code=500, detail=str(e)) from e @syncable async def aevaluate( @@ -1178,13 +1152,8 @@ async def aevaluate( ) return {"results": result} - class EvalRequest(BaseModel): - query: str - context: str - completion: str - @telemetry_event("Evaluate") - async def evaluate_app(self, request: EvalRequest): + async def evaluate_app(self, request: R2REvalRequest): async with manage_run(self.run_manager, "evaluate_app") as run_id: try: return await self.aevaluate( @@ -1223,12 +1192,8 @@ async def adelete( return {"results": "Entries deleted successfully."} - class DeleteRequest(BaseModel): - keys: list[str] - values: list[Union[bool, int, str]] - @telemetry_event("Delete") - async def delete_app(self, request: DeleteRequest = Body(...)): + async def delete_app(self, request: R2RDeleteRequest): try: return await self.adelete(request.keys, request.values) except Exception as e: @@ -1387,16 +1352,14 @@ async def aanalytics( @telemetry_event("Analytics") async def analytics_app( self, - filter_criteria: FilterCriteria = Body(...), - analysis_types: AnalysisTypes = Body(...), + request: R2RAnalyticsRequest, ): async with manage_run(self.run_manager, "analytics_app"): try: - return await self.aanalytics(filter_criteria, analysis_types) - except Exception as e: - await self.run_manager.log_run_info( - "error", str(e), is_info_log=False + return await self.aanalytics( + request.filter_criteria, request.analysis_types ) + except Exception as e: raise HTTPException(status_code=500, detail=str(e)) from e @syncable @@ -1428,15 +1391,13 @@ async def ausers_stats(self, user_ids: Optional[list[uuid.UUID]] = None): ) @telemetry_event("UsersStats") - async def users_stats_app( - self, user_ids: Optional[list[uuid.UUID]] = Query(None) - ): + async def users_stats_app(self, request: R2RUsersStatsRequest): try: - users_stats = await self.ausers_stats(user_ids) + users_stats = await self.ausers_stats(request.user_ids) return {"results": users_stats} except Exception as e: logger.error( - f"users_stats_app(user_ids={user_ids}) - \n\n{str(e)})" + f"users_stats_app(user_ids={request.user_ids}) - \n\n{str(e)})" ) raise HTTPException(status_code=500, detail=str(e)) from e @@ -1458,19 +1419,15 @@ async def adocuments_info( ) @telemetry_event("DocumentsInfo") - async def documents_info_app( - self, - document_ids: Optional[list[str]] = Query(None), - user_ids: Optional[list[str]] = Query(None), - ): + async def documents_info_app(self, request: R2RDocumentsInfoRequest): try: documents_info = await self.adocuments_info( - document_id=document_ids, user_id=user_ids + document_id=request.document_ids, user_id=request.user_ids ) return {"results": documents_info} except Exception as e: logger.error( - f"documents_info_app(document_ids={document_ids}, user_ids={user_ids}) - \n\n{str(e)})" + f"documents_info_app(document_ids={request.document_ids}, user_ids={request.user_ids}) - \n\n{str(e)})" ) raise HTTPException(status_code=500, detail=str(e)) from e @@ -1479,13 +1436,13 @@ async def adocument_chunks(self, document_id: str) -> list[str]: return self.providers.vector_db.get_document_chunks(document_id) @telemetry_event("DocumentChunks") - async def 
document_chunks_app(self, document_id: str): + async def document_chunks_app(self, request: R2RDocumentChunksRequest): try: - chunks = await self.adocument_chunks(document_id) + chunks = await self.adocument_chunks(request.document_id) return {"results": chunks} except Exception as e: logger.error( - f"get_document_chunks_app(document_id={document_id}) - \n\n{str(e)})" + f"get_document_chunks_app(document_id={request.document_id}) - \n\n{str(e)})" ) raise HTTPException(status_code=500, detail=str(e)) from e diff --git a/r2r/main/r2r_builder.py b/r2r/main/r2r_builder.py index aa31132d..44f90cbc 100644 --- a/r2r/main/r2r_builder.py +++ b/r2r/main/r2r_builder.py @@ -200,4 +200,4 @@ def build(self, *args, **kwargs) -> R2RApp: ) r2r_app = self.r2r_app_override or R2RApp - return r2r_app(self.config, providers, pipelines) + return r2r_app(self.config, providers, pipes, pipelines) diff --git a/r2r/main/r2r_client.py b/r2r/main/r2r_client.py index 9444fb0f..47922ddc 100644 --- a/r2r/main/r2r_client.py +++ b/r2r/main/r2r_client.py @@ -1,112 +1,115 @@ import asyncio import json import uuid -from typing import AsyncGenerator, Generator, Optional, Union +from typing import Any, AsyncGenerator, Generator, Optional, Union +import fire import httpx import nest_asyncio import requests -from r2r.core import DocumentType +from r2r.core import KGSearchSettings, VectorSearchSettings +from r2r.main.r2r_abstractions import ( + GenerationConfig, + R2RAnalyticsRequest, + R2RDeleteRequest, + R2RDocumentChunksRequest, + R2RDocumentsInfoRequest, + R2RIngestDocumentsRequest, + R2RIngestFilesRequest, + R2RRAGRequest, + R2RSearchRequest, + R2RUpdateDocumentsRequest, + R2RUpdateFilesRequest, + R2RUpdatePromptRequest, + R2RUsersStatsRequest, +) nest_asyncio.apply() -def default_serializer(obj): - if isinstance(obj, uuid.UUID): - return str(obj) - if isinstance(obj, DocumentType): - return obj.value - if isinstance(obj, bytes): - raise TypeError("Bytes serialization is not yet supported.") - raise TypeError(f"Type {type(obj)} not serializable.") - - class R2RClient: def __init__(self, base_url: str): self.base_url = base_url def update_prompt( self, - name: str, + name: str = "default_system", template: Optional[str] = None, input_types: Optional[dict] = None, ) -> dict: url = f"{self.base_url}/update_prompt" - data = { - "name": name, - "template": template, - "input_types": input_types, - } - response = requests.post( - url, - data=json.dumps(data, default=default_serializer), - headers={"Content-Type": "application/json"}, + request = R2RUpdatePromptRequest( + name=name, template=template, input_types=input_types ) + response = requests.post(url, json=json.loads(request.json())) response.raise_for_status() return response.json() - def ingest_documents(self, documents: list[dict]) -> dict: + def ingest_documents( + self, documents: list[dict], versions: Optional[list[str]] = None + ) -> dict: url = f"{self.base_url}/ingest_documents" - data = {"documents": documents} - serialized_data = json.dumps(data, default=default_serializer) - response = requests.post( - url, - data=serialized_data, - headers={"Content-Type": "application/json"}, + request = R2RIngestDocumentsRequest( + documents=documents, versions=versions ) + response = requests.post(url, json=json.loads(request.json())) response.raise_for_status() return response.json() def ingest_files( self, - files: list[str], + file_paths: list[str], metadatas: Optional[list[dict]] = None, - ids: Optional[list[str]] = None, - user_ids: Optional[list[str]] = None, + 
document_ids: Optional[list[Union[uuid.UUID, str]]] = None, + user_ids: Optional[list[Union[uuid.UUID, str]]] = None, + versions: Optional[list[str]] = None, + skip_document_info: Optional[bool] = False, ) -> dict: url = f"{self.base_url}/ingest_files" files_to_upload = [ ("files", (file, open(file, "rb"), "application/octet-stream")) - for file in files + for file in file_paths ] - data = { - "metadatas": ( - None - if metadatas is None - else json.dumps(metadatas, default=default_serializer) - ), - "document_ids": ( - None - if ids is None - else json.dumps(ids, default=default_serializer) + request = R2RIngestFilesRequest( + metadatas=metadatas, + document_ids=( + [str(ele) for ele in document_ids] if document_ids else None ), - "user_ids": ( - None - if user_ids is None - else json.dumps(user_ids, default=default_serializer) - ), - } - response = requests.post(url, files=files_to_upload, data=data) + user_ids=[str(ele) for ele in user_ids] if user_ids else None, + versions=versions, + skip_document_info=skip_document_info, + ) + response = requests.post( + url, + # must use data instead of json when sending files + data={ + k: json.dumps(v) for k, v in json.loads(request.json()).items() + }, + files=files_to_upload, + ) + response.raise_for_status() return response.json() - def update_documents(self, documents: list[dict]) -> dict: + def update_documents( + self, + documents: list[dict], + versions: Optional[list[str]] = None, + metadatas: Optional[list[dict]] = None, + ) -> dict: url = f"{self.base_url}/update_documents" - data = {"documents": documents} - serialized_data = json.dumps(data, default=default_serializer) - response = requests.post( - url, - data=serialized_data, - headers={"Content-Type": "application/json"}, + request = R2RUpdateDocumentsRequest( + documents=documents, versions=versions, metadatas=metadatas ) + response = requests.post(url, json=json.loads(request.json())) response.raise_for_status() return response.json() def update_files( self, files: list[str], - ids: list[str], + document_ids: list[str], metadatas: Optional[list[dict]] = None, ) -> dict: url = f"{self.base_url}/update_files" @@ -114,115 +117,98 @@ def update_files( ("files", (file, open(file, "rb"), "application/octet-stream")) for file in files ] - data = { - "metadatas": ( - None - if metadatas is None - else json.dumps(metadatas, default=default_serializer) - ), - "ids": json.dumps(ids, default=default_serializer), - } - response = requests.post(url, files=files_to_upload, data=data) + request = R2RUpdateFilesRequest( + metadatas=metadatas, + document_ids=document_ids, + ) + response = requests.post( + url, files=files_to_upload, data=request.json() + ) response.raise_for_status() return response.json() def search( self, query: str, - search_filters: Optional[dict] = None, + use_vector_search: bool = True, + search_filters: Optional[dict[str, Any]] = {}, search_limit: int = 10, do_hybrid_search: bool = False, + use_kg: bool = False, + kg_agent_generation_config: Optional[GenerationConfig] = None, ) -> dict: + request = R2RSearchRequest( + query=query, + vector_settings=VectorSearchSettings( + use_vector_search=use_vector_search, + search_filters=search_filters, + search_limit=search_limit, + do_hybrid_search=do_hybrid_search, + ), + kg_settings=KGSearchSettings( + use_kg=use_kg, + agent_generation_config=kg_agent_generation_config, + ), + ) url = f"{self.base_url}/search" - data = { - "query": query, - "search_filters": json.dumps(search_filters or {}), - "search_limit": search_limit, - 
"do_hybrid_search": do_hybrid_search, - } - response = requests.post(url, json=data) + response = requests.post(url, json=json.loads(request.json())) response.raise_for_status() return response.json() def rag( self, - message: str, - search_filters: Optional[dict] = None, + query: str, + use_vector_search: bool = True, + search_filters: Optional[dict[str, Any]] = {}, search_limit: int = 10, - rag_generation_config: Optional[dict] = None, - streaming: bool = False, - ) -> Union[dict, Generator[str, None, None]]: - if streaming: - return self._stream_rag_sync( - message=message, + do_hybrid_search: bool = False, + use_kg: bool = False, + kg_agent_generation_config: Optional[GenerationConfig] = None, + rag_generation_config: Optional[GenerationConfig] = None, + ) -> dict: + request = R2RRAGRequest( + query=query, + vector_settings=VectorSearchSettings( + use_vector_search=use_vector_search, search_filters=search_filters, search_limit=search_limit, - rag_generation_config=rag_generation_config, - ) + do_hybrid_search=do_hybrid_search, + ), + kg_settings=KGSearchSettings( + use_kg=use_kg, + agent_generation_config=kg_agent_generation_config, + ), + rag_generation_config=rag_generation_config, + ) + + if rag_generation_config.stream: + return self._stream_rag_sync(request) else: try: url = f"{self.base_url}/rag" - data = { - "message": message, - "search_filters": ( - json.dumps(search_filters) if search_filters else None - ), - "search_limit": search_limit, - "rag_generation_config": ( - json.dumps(rag_generation_config) - if rag_generation_config - else None - ), - "streaming": streaming, - } - - response = requests.post(url, json=data) + response = requests.post(url, json=json.loads(request.json())) response.raise_for_status() return response.json() except requests.exceptions.RequestException as e: raise e async def _stream_rag( - self, - message: str, - search_filters: Optional[dict] = None, - search_limit: int = 10, - rag_generation_config: Optional[dict] = None, + self, rag_request: R2RRAGRequest ) -> AsyncGenerator[str, None]: url = f"{self.base_url}/rag" - data = { - "message": message, - "search_filters": ( - json.dumps(search_filters) if search_filters else None - ), - "search_limit": search_limit, - "rag_generation_config": ( - json.dumps(rag_generation_config) - if rag_generation_config - else None - ), - "streaming": True, - } async with httpx.AsyncClient() as client: - async with client.stream("POST", url, json=data) as response: + async with client.stream( + "POST", url, json=json.loads(rag_request.json()) + ) as response: response.raise_for_status() async for chunk in response.aiter_text(): yield chunk def _stream_rag_sync( - self, - message: str, - search_filters: Optional[dict] = None, - search_limit: int = 10, - rag_generation_config: Optional[dict] = None, + self, rag_request: R2RRAGRequest ) -> Generator[str, None, None]: async def run_async_generator(): - async for chunk in self._stream_rag( - message=message, - search_filters=search_filters, - search_limit=search_limit, - rag_generation_config=rag_generation_config, - ): + async for chunk in self._stream_rag(rag_request): yield chunk loop = asyncio.new_event_loop() @@ -245,8 +231,8 @@ def delete( self, keys: list[str], values: list[Union[bool, int, str]] ) -> dict: url = f"{self.base_url}/delete" - data = {"keys": keys, "values": values} - response = requests.request("DELETE", url, json=data) + request = R2RDeleteRequest(keys=keys, values=values) + response = requests.delete(url, json=json.loads(request.json())) 
response.raise_for_status() return response.json() @@ -267,53 +253,50 @@ def app_settings(self) -> dict: def analytics(self, filter_criteria: dict, analysis_types: dict) -> dict: url = f"{self.base_url}/analytics" - data = { - "filter_criteria": filter_criteria, - "analysis_types": analysis_types, - } - - try: - response = requests.post(url, json=data) - response.raise_for_status() - return response.json() - except requests.exceptions.RequestException as e: - if e.response is None: - raise requests.exceptions.RequestException( - f"Error occurred while calling analytics API. {str(e)}" - ) from e - status_code = e.response.status_code - error_message = e.response.text - raise requests.exceptions.RequestException( - f"Error occurred while calling analytics API. Status Code: {status_code}, Error Message: {error_message}" - ) from e + request = R2RAnalyticsRequest( + filter_criteria=filter_criteria, analysis_types=analysis_types + ) + response = requests.post(url, json=json.loads(request.json())) + response.raise_for_status() + return response.json() def users_stats(self, user_ids: Optional[list[str]] = None) -> dict: url = f"{self.base_url}/users_stats" - params = {} - if user_ids is not None: - params["user_ids"] = ",".join(user_ids) - response = requests.get(url, params=params) + request = R2RUsersStatsRequest( + user_ids=[uuid.UUID(uid) for uid in user_ids] if user_ids else None + ) + response = requests.get(url, json=json.loads(request.json())) response.raise_for_status() return response.json() def documents_info( self, - document_ids: Optional[str] = None, - user_ids: Optional[str] = None, + document_ids: Optional[list[str]] = None, + user_ids: Optional[list[str]] = None, ) -> dict: url = f"{self.base_url}/documents_info" - params = {} - params["document_ids"] = ( - json.dumps(document_ids) if document_ids else None + request = R2RDocumentsInfoRequest( + document_ids=( + [uuid.UUID(did) for did in document_ids] + if document_ids + else None + ), + user_ids=( + [uuid.UUID(uid) for uid in user_ids] if user_ids else None + ), ) - params["user_ids"] = json.dumps(user_ids) if user_ids else None - response = requests.get(url, params=params) + response = requests.get(url, json=json.loads(request.json())) response.raise_for_status() return response.json() def document_chunks(self, document_id: str) -> dict: url = f"{self.base_url}/document_chunks" - params = {"document_id": document_id} - response = requests.get(url, params=params) + request = R2RDocumentChunksRequest(document_id=document_id) + response = requests.post(url, json=json.loads(request.json())) response.raise_for_status() return response.json() + + +if __name__ == "__main__": + client = R2RClient(base_url="http://localhost:8000") + fire.Fire(client) diff --git a/r2r/main/r2r_factory.py b/r2r/main/r2r_factory.py index 367108ea..f549846c 100644 --- a/r2r/main/r2r_factory.py +++ b/r2r/main/r2r_factory.py @@ -209,6 +209,7 @@ def create_pipes( embedding_pipe_override: Optional[LoggableAsyncPipe] = None, kg_pipe_override: Optional[LoggableAsyncPipe] = None, kg_storage_pipe_override: Optional[LoggableAsyncPipe] = None, + kg_agent_pipe_override: Optional[LoggableAsyncPipe] = None, vector_storage_pipe_override: Optional[LoggableAsyncPipe] = None, search_pipe_override: Optional[LoggableAsyncPipe] = None, rag_pipe_override: Optional[LoggableAsyncPipe] = None, @@ -227,6 +228,8 @@ def create_pipes( kg_pipe=kg_pipe_override or self.create_kg_pipe(*args, **kwargs), kg_storage_pipe=kg_storage_pipe_override or self.create_kg_storage_pipe(*args, 
diff --git a/r2r/main/r2r_factory.py b/r2r/main/r2r_factory.py
index 367108ea..f549846c 100644
--- a/r2r/main/r2r_factory.py
+++ b/r2r/main/r2r_factory.py
@@ -209,6 +209,7 @@ def create_pipes(
         embedding_pipe_override: Optional[LoggableAsyncPipe] = None,
         kg_pipe_override: Optional[LoggableAsyncPipe] = None,
         kg_storage_pipe_override: Optional[LoggableAsyncPipe] = None,
+        kg_agent_pipe_override: Optional[LoggableAsyncPipe] = None,
         vector_storage_pipe_override: Optional[LoggableAsyncPipe] = None,
         search_pipe_override: Optional[LoggableAsyncPipe] = None,
         rag_pipe_override: Optional[LoggableAsyncPipe] = None,
@@ -227,6 +228,8 @@
             kg_pipe=kg_pipe_override or self.create_kg_pipe(*args, **kwargs),
             kg_storage_pipe=kg_storage_pipe_override
             or self.create_kg_storage_pipe(*args, **kwargs),
+            kg_agent_search_pipe=kg_agent_pipe_override
+            or self.create_kg_agent_pipe(*args, **kwargs),
             vector_storage_pipe=vector_storage_pipe_override
             or self.create_vector_storage_pipe(*args, **kwargs),
             vector_search_pipe=search_pipe_override
@@ -234,7 +237,7 @@
             rag_pipe=rag_pipe_override
             or self.create_rag_pipe(*args, **kwargs),
             streaming_rag_pipe=streaming_rag_pipe_override
-            or self.create_rag_pipe(streaming=True, *args, **kwargs),
+            or self.create_rag_pipe(stream=True, *args, **kwargs),
             eval_pipe=eval_pipe_override
             or self.create_eval_pipe(*args, **kwargs),
         )
@@ -330,8 +333,20 @@ def create_kg_storage_pipe(self, *args, **kwargs) -> Any:
             embedding_provider=self.providers.embedding,
         )

-    def create_rag_pipe(self, streaming: bool = False, *args, **kwargs) -> Any:
-        if streaming:
+    def create_kg_agent_pipe(self, *args, **kwargs) -> Any:
+        if self.config.kg.provider is None:
+            return None
+
+        from r2r.pipes import KGAgentSearchPipe
+
+        return KGAgentSearchPipe(
+            kg_provider=self.providers.kg,
+            llm_provider=self.providers.llm,
+            prompt_provider=self.providers.prompt,
+        )
+
+    def create_rag_pipe(self, stream: bool = False, *args, **kwargs) -> Any:
+        if stream:
             from r2r.pipes import StreamingSearchRAGPipe

             return StreamingSearchRAGPipe(
@@ -372,7 +387,7 @@ def create_ingestion_pipeline(self, *args, **kwargs) -> IngestionPipeline:
         ingestion_pipeline.add_pipe(
             self.pipes.vector_storage_pipe, embedding_pipe=True
         )
-        # Add KG pipes provider is set
+        # Add KG pipes if provider is set
         if self.config.kg.provider is not None:
             ingestion_pipeline.add_pipe(self.pipes.kg_pipe, kg_pipe=True)
             ingestion_pipeline.add_pipe(
@@ -382,35 +397,40 @@
         return ingestion_pipeline

     def create_search_pipeline(self, *args, **kwargs) -> SearchPipeline:
+        """Factory method to create a search pipeline."""
         search_pipeline = SearchPipeline()
-        search_pipeline.add_pipe(self.pipes.vector_search_pipe)
+
+        # Add vector search pipes if embedding and vector database providers are set
+        if (
+            self.config.embedding.provider is not None
+            and self.config.vector_database.provider is not None
+        ):
+            search_pipeline.add_pipe(
+                self.pipes.vector_search_pipe, vector_search_pipe=True
+            )
+
+        # Add KG pipes if provider is set
+        if self.config.kg.provider is not None:
+            search_pipeline.add_pipe(
+                self.pipes.kg_agent_search_pipe, kg_pipe=True
+            )
+
         return search_pipeline

     def create_rag_pipeline(
-        self, streaming: bool = False, *args, **kwargs
+        self,
+        search_pipeline: SearchPipeline,
+        stream: bool = False,
+        *args,
+        **kwargs,
     ) -> RAGPipeline:
-        vector_search_pipe = self.pipes.vector_search_pipe
         rag_pipe = (
-            self.pipes.streaming_rag_pipe if streaming else self.pipes.rag_pipe
+            self.pipes.streaming_rag_pipe if stream else self.pipes.rag_pipe
         )

         rag_pipeline = RAGPipeline()
-        rag_pipeline.add_pipe(vector_search_pipe)
-        rag_pipeline.add_pipe(
-            rag_pipe,
-            add_upstream_outputs=[
-                {
-                    "prev_pipe_name": vector_search_pipe.config.name,
-                    "prev_output_field": "search_results",
-                    "input_field": "raw_search_results",
-                },
-                {
-                    "prev_pipe_name": vector_search_pipe.config.name,
-                    "prev_output_field": "search_queries",
-                    "input_field": "query",
-                },
-            ],
-        )
+        rag_pipeline.set_search_pipeline(search_pipeline)
+        rag_pipeline.add_pipe(rag_pipe)
         return rag_pipeline

     def create_eval_pipeline(self, *args, **kwargs) -> EvalPipeline:
@@ -432,16 +452,27 @@ def create_pipelines(
             self.configure_logging()
         except Exception as e:
             logger.warn(f"Error configuring logging: {e}")
-
+        search_pipeline = search_pipeline or self.create_search_pipeline(
+            *args, **kwargs
+        )
         return R2RPipelines(
             ingestion_pipeline=ingestion_pipeline
             or self.create_ingestion_pipeline(*args, **kwargs),
-            search_pipeline=search_pipeline
-            or self.create_search_pipeline(*args, **kwargs),
+            search_pipeline=search_pipeline,
             rag_pipeline=rag_pipeline
-            or self.create_rag_pipeline(streaming=False, *args, **kwargs),
+            or self.create_rag_pipeline(
+                search_pipeline=search_pipeline,
+                stream=False,
+                *args,
+                **kwargs,
+            ),
             streaming_rag_pipeline=streaming_rag_pipeline
-            or self.create_rag_pipeline(streaming=True, *args, **kwargs),
+            or self.create_rag_pipeline(
+                search_pipeline=search_pipeline,
+                stream=True,
+                *args,
+                **kwargs,
+            ),
             eval_pipeline=eval_pipeline
             or self.create_eval_pipeline(*args, **kwargs),
         )
diff --git a/r2r/pipes/__init__.py b/r2r/pipes/__init__.py
index 1d6c4006..2a78f84c 100644
--- a/r2r/pipes/__init__.py
+++ b/r2r/pipes/__init__.py
@@ -1,7 +1,7 @@
 from .abstractions.search_pipe import SearchPipe
 from .embedding_pipe import EmbeddingPipe
 from .eval_pipe import EvalPipe
-from .kg_agent_pipe import KGAgentPipe
+from .kg_agent_search_pipe import KGAgentSearchPipe
 from .kg_extraction_pipe import KGExtractionPipe
 from .kg_storage_pipe import KGStoragePipe
 from .parsing_pipe import ParsingPipe
@@ -24,6 +24,6 @@
     "VectorSearchPipe",
     "VectorStoragePipe",
     "WebSearchPipe",
-    "KGAgentPipe",
+    "KGAgentSearchPipe",
     "KGStoragePipe",
 ]
diff --git a/r2r/pipes/abstractions/generator_pipe.py b/r2r/pipes/abstractions/generator_pipe.py
index 0066ddf1..03ff2781 100644
--- a/r2r/pipes/abstractions/generator_pipe.py
+++ b/r2r/pipes/abstractions/generator_pipe.py
@@ -4,13 +4,13 @@
 from r2r.core import (
     AsyncState,
-    GenerationConfig,
     KVLoggingSingleton,
     LLMProvider,
     LoggableAsyncPipe,
     PipeType,
     PromptProvider,
 )
+from r2r.core.abstractions.llm import GenerationConfig


 class GeneratorPipe(LoggableAsyncPipe):
diff --git a/r2r/pipes/abstractions/search_pipe.py b/r2r/pipes/abstractions/search_pipe.py
index 8826f80f..737905cb 100644
--- a/r2r/pipes/abstractions/search_pipe.py
+++ b/r2r/pipes/abstractions/search_pipe.py
@@ -9,7 +9,7 @@
     KVLoggingSingleton,
     LoggableAsyncPipe,
     PipeType,
-    SearchResult,
+    VectorSearchResult,
 )

 logger = logging.getLogger(__name__)
@@ -48,7 +48,7 @@ async def search(
         limit: int = 10,
         *args: Any,
         **kwargs: Any,
-    ) -> AsyncGenerator[SearchResult, None]:
+    ) -> AsyncGenerator[VectorSearchResult, None]:
         pass

     @abstractmethod
@@ -59,5 +59,5 @@ async def _run_logic(
         run_id: uuid.UUID,
         *args: Any,
         **kwargs,
-    ) -> AsyncGenerator[SearchResult, None]:
+    ) -> AsyncGenerator[VectorSearchResult, None]:
         pass
diff --git a/r2r/pipes/embedding_pipe.py b/r2r/pipes/embedding_pipe.py
index c234b5b4..9b75d809 100644
--- a/r2r/pipes/embedding_pipe.py
+++ b/r2r/pipes/embedding_pipe.py
@@ -179,7 +179,7 @@ async def _run_logic(
                     active_tasks += 1
                 fragment_batch.clear()

-        logger.info(
+        logger.debug(
             f"Fragmented the input document ids into counts as shown: {fragment_info}"
         )
diff --git a/r2r/pipes/eval_pipe.py b/r2r/pipes/eval_pipe.py
index 8925ceb4..4789c0b4 100644
--- a/r2r/pipes/eval_pipe.py
+++ b/r2r/pipes/eval_pipe.py
@@ -7,11 +7,11 @@
 from r2r import (
     AsyncState,
     EvalProvider,
-    GenerationConfig,
     LLMChatCompletion,
     LoggableAsyncPipe,
     PipeType,
 )
+from r2r.core.abstractions.llm import GenerationConfig

 logger = logging.getLogger(__name__)
diff --git a/r2r/pipes/kg_agent_pipe.py b/r2r/pipes/kg_agent_pipe.py
deleted file mode 100644
index 750df37f..00000000
--- a/r2r/pipes/kg_agent_pipe.py
+++ /dev/null
@@ -1,198 +0,0 @@
-import logging
-import uuid
-from typing import Any, Optional
-
-from r2r.core import (
-    AsyncState,
-    GenerationConfig,
-    KGProvider,
-    KVLoggingSingleton,
-    LLMProvider,
-    PipeType,
-    PromptProvider,
-)
-
-from .abstractions.generator_pipe import GeneratorPipe
-
-logger = logging.getLogger(__name__)
-
-
-prompt = """**System Message:**
-
-You are an AI assistant capable of generating Cypher queries to interact with a Neo4j knowledge graph. The knowledge graph contains information about organizations, people, locations, and their relationships, such as founders of companies, locations of companies, and products associated with companies.
-
-**Instructions:**
-
-When a user asks a question, you will generate a Cypher query to retrieve the relevant information from the Neo4j knowledge graph. Later, you will be given a schema which specifies the available relationships to help you construct the query. First, review the examples provided to understand the expected format of the queries.
-
-### Example(s) - User Questions and Cypher Queries for an Academic Knowledge Graph
-
-**User Question:**
-"List all courses available in the computer science department."
-
-**Generated Cypher Query:**
-```cypher
-MATCH (c:COURSE)-[:OFFERED_BY]->(d:DEPARTMENT)
-WHERE d.name CONTAINS 'Computer Science'
-RETURN c.id AS Course, d.name AS Department
-ORDER BY c.id;
-```
-
-**User Question:**
-"Retrieve all courses taught by professors who have published research on natural language processing."
-
-**Generated Cypher Query:**
-```cypher
-MATCH (pr:PERSON)-[:PUBLISHED]->(p:PAPER)
-MATCH (p)-[:TOPIC]->(t:TOPIC)
-WHERE t.name CONTAINS 'Natural Language Processing'
-MATCH (c:COURSE)-[:TAUGHT_BY]->(pr)
-RETURN DISTINCT c.id AS Course, pr.name AS Professor, t.name AS Topic
-ORDER BY c.id;
-```
-
-
-### Example(s) - User Questions and Cypher Queries for an Historical Events and Figures
-
-**User Question:**
-"List all battles that occurred in the 19th century and the generals who participated in them."
-
-**Generated Cypher Query:**
-```cypher
-MATCH (b:EVENT)-[:HAPPENED_AT]->(d:DATE)
-WHERE d.year >= 1800 AND d.year < 1900 AND b.type CONTAINS 'Battle'
-MATCH (g:PERSON)-[:PARTICIPATED_IN]->(b)
-RETURN b.name AS Battle, d.year AS Year, g.name AS General
-ORDER BY d.year, b.name, g.name;
-```
-
-**User Question:**
-"Find all treaties signed in Paris and the countries involved."
-
-
-**Generated Cypher Query:**
-```cypher
-MATCH (t:EVENT)-[:HAPPENED_AT]->(l:LOCATION)
-WHERE l.name CONTAINS 'Paris' AND t.type CONTAINS 'Treaty'
-MATCH (c:ORGANIZATION)-[:SIGNED]->(t)
-RETURN t.name AS Treaty, l.name AS Location, c.name AS Country
-ORDER BY t.name, c.name;
-```
-
-
-Now, you will be provided with a schema for the entities and relationships in the Neo4j knowledge graph. Use this schema to construct Cypher queries based on user questions.
-
-- **Entities:**
-  - `ORGANIZATION` (e.g.: `COMPANY`, `SCHOOL`, `NON-PROFIT`, `OTHER`)
-  - `LOCATION` (e.g.: `CITY`, `STATE`, `COUNTRY`, `OTHER`)
-  - `PERSON`
-  - `POSITION`
-  - `DATE` (e.g.: `YEAR`, `MONTH`, `DAY`, `BATCH`, `OTHER`)
-  - `QUANTITY`
-  - `EVENT` (e.g.: `INCORPORATION`, `FUNDING_ROUND`, `ACQUISITION`, `LAUNCH`, `OTHER`)
-  - `INDUSTRY`
-  - `MEDIA` (e.g.: `EMAIL`, `WEBSITE`, `TWITTER`, `LINKEDIN`, `OTHER`)
-  - `PRODUCT`
-
-- **Relationships:**
-  - `FOUNDED`
-  - `WORKED_AT`
-  - `EDUCATED_AT`
-  - `RAISED`
-  - `REVENUE`
-  - `TEAM_SIZE`
-  - `LOCATION`
-  - `ACQUIRED_BY`
-  - `ANNOUNCED`
-  - `INDUSTRY`
-  - `PRODUCT`
-  - `FEATURES`
-  - `USES`
-  - `USED_BY`
-  - `TECHNOLOGY`
-  - `HAS`
-  - `AS_OF`
-  - `PARTICIPATED`
-  - `ASSOCIATED`
-  - `GROUP_PARTNER`
-  - `ALIAS`
-
-Use the referenced examples and schema to help you construct an appropriate Cypher query based on the following question:
-
-**User Question:**
-{question}
-
-**Generated Cypher Query:**
-"""
-
-
-class KGAgentPipe(GeneratorPipe):
-    """
-    Embeds and stores documents using a specified embedding model and database.
-    """
-
-    def __init__(
-        self,
-        kg_provider: KGProvider,
-        llm_provider: LLMProvider,
-        prompt_provider: PromptProvider,
-        pipe_logger: Optional[KVLoggingSingleton] = None,
-        type: PipeType = PipeType.INGESTOR,
-        config: Optional[GeneratorPipe.PipeConfig] = None,
-        *args,
-        **kwargs,
-    ):
-        """
-        Initializes the embedding pipe with necessary components and configurations.
-        """
-        super().__init__(
-            llm_provider=llm_provider,
-            prompt_provider=prompt_provider,
-            type=type,
-            config=config
-            or GeneratorPipe.Config(
-                name="kg_rag_pipe", task_prompt="kg_agent"
-            ),
-            pipe_logger=pipe_logger,
-            *args,
-            **kwargs,
-        )
-        self.kg_provider = kg_provider
-        self.llm_provider = llm_provider
-        self.prompt_provider = prompt_provider
-        self.pipe_run_info = None
-
-    async def _run_logic(
-        self,
-        input: GeneratorPipe.Input,
-        state: AsyncState,
-        run_id: uuid.UUID,
-        rag_generation_config: GenerationConfig,
-        *args: Any,
-        **kwargs: Any,
-    ):
-        async for message in input.message:
-            # TODO - Remove hard code
-            formatted_prompt = self.prompt_provider.get_prompt(
-                "kg_agent", {"input": message}
-            )
-            messages = self._get_message_payload(formatted_prompt)
-
-            result = self.llm_provider.get_completion(
-                messages=messages, generation_config=rag_generation_config
-            )
-
-            extraction = result.choices[0].message.content
-            query = extraction.split("```cypher")[1].split("```")[0]
-            yield self.kg_provider.structured_query(query)
-
-    def _get_message_payload(self, message: str) -> dict:
-        return [
-            {
-                "role": "system",
-                "content": self.prompt_provider.get_prompt(
-                    self.config.system_prompt,
-                ),
-            },
-            {"role": "user", "content": message},
-        ]
diff --git a/r2r/pipes/kg_agent_search_pipe.py b/r2r/pipes/kg_agent_search_pipe.py
new file mode 100644
index 00000000..3a483988
--- /dev/null
+++ b/r2r/pipes/kg_agent_search_pipe.py
@@ -0,0 +1,90 @@
+import logging
+import uuid
+from typing import Any, Optional
+
+from r2r.core import (
+    AsyncState,
+    KGProvider,
+    KGSearchSettings,
+    KVLoggingSingleton,
+    LLMProvider,
+    PipeType,
+    PromptProvider,
+)
+
+from .abstractions.generator_pipe import GeneratorPipe
+
+logger = logging.getLogger(__name__)
+
+
+class KGAgentSearchPipe(GeneratorPipe):
+    """
+    Answers a user question by generating and executing a Cypher query against the configured knowledge graph.
+    """
+
+    def __init__(
+        self,
+        kg_provider: KGProvider,
+        llm_provider: LLMProvider,
+        prompt_provider: PromptProvider,
+        pipe_logger: Optional[KVLoggingSingleton] = None,
+        type: PipeType = PipeType.INGESTOR,
+        config: Optional[GeneratorPipe.PipeConfig] = None,
+        *args,
+        **kwargs,
+    ):
+        """
+        Initializes the KG agent search pipe with necessary components and configurations.
+        """
+        super().__init__(
+            llm_provider=llm_provider,
+            prompt_provider=prompt_provider,
+            type=type,
+            config=config
+            or GeneratorPipe.Config(
+                name="kg_rag_pipe", task_prompt="kg_agent"
+            ),
+            pipe_logger=pipe_logger,
+            *args,
+            **kwargs,
+        )
+        self.kg_provider = kg_provider
+        self.llm_provider = llm_provider
+        self.prompt_provider = prompt_provider
+        self.pipe_run_info = None
+
+    async def _run_logic(
+        self,
+        input: GeneratorPipe.Input,
+        state: AsyncState,
+        run_id: uuid.UUID,
+        kg_search_settings: KGSearchSettings,
+        *args: Any,
+        **kwargs: Any,
+    ):
+        async for message in input.message:
+            # TODO - Remove hard code
+            formatted_prompt = self.prompt_provider.get_prompt(
+                "kg_agent", {"input": message}
+            )
+            messages = self._get_message_payload(formatted_prompt)
+
+            result = self.llm_provider.get_completion(
+                messages=messages,
+                generation_config=kg_search_settings.agent_generation_config,
+            )
+
+            extraction = result.choices[0].message.content
+            query = extraction.split("```cypher")[1].split("```")[0]
+            yield self.kg_provider.structured_query(query)
+
+    def _get_message_payload(self, message: str) -> dict:
+        return [
+            {
+                "role": "system",
+                "content": self.prompt_provider.get_prompt(
+                    self.config.system_prompt,
+                ),
+            },
+            {"role": "user", "content": message},
+        ]
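At runtime the new pipe formats the user's question into the `kg_agent` prompt, asks the LLM for a fenced Cypher block, and runs the extracted query through the KG provider. A standalone, dependency-free sketch of that extraction step (the fenced-response format is the same assumption the pipe itself makes about the LLM output):

```python
FENCE = "```"

def extract_cypher(llm_response: str) -> str:
    # Mirrors the parsing in KGAgentSearchPipe._run_logic: the LLM is
    # expected to wrap its query in a ```cypher fenced block.
    return llm_response.split(f"{FENCE}cypher")[1].split(FENCE)[0]

response = (
    "Here is the query:\n"
    f"{FENCE}cypher\n"
    "MATCH (p:PERSON)-[:FOUNDED]->(o:ORGANIZATION)\n"
    "RETURN p.name, o.name;\n"
    f"{FENCE}"
)
print(extract_cypher(response))  # prints the raw Cypher between the fences
```

Note that this parsing raises an `IndexError` if the model omits the fence, which is presumably why the `# TODO - Remove hard code` marker remains in the pipe.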
+ """ + + def __init__( + self, + kg_provider: KGProvider, + llm_provider: LLMProvider, + prompt_provider: PromptProvider, + pipe_logger: Optional[KVLoggingSingleton] = None, + type: PipeType = PipeType.INGESTOR, + config: Optional[GeneratorPipe.PipeConfig] = None, + *args, + **kwargs, + ): + """ + Initializes the embedding pipe with necessary components and configurations. + """ + super().__init__( + llm_provider=llm_provider, + prompt_provider=prompt_provider, + type=type, + config=config + or GeneratorPipe.Config( + name="kg_rag_pipe", task_prompt="kg_agent" + ), + pipe_logger=pipe_logger, + *args, + **kwargs, + ) + self.kg_provider = kg_provider + self.llm_provider = llm_provider + self.prompt_provider = prompt_provider + self.pipe_run_info = None + + async def _run_logic( + self, + input: GeneratorPipe.Input, + state: AsyncState, + run_id: uuid.UUID, + kg_search_settings: KGSearchSettings, + *args: Any, + **kwargs: Any, + ): + async for message in input.message: + # TODO - Remove hard code + formatted_prompt = self.prompt_provider.get_prompt( + "kg_agent", {"input": message} + ) + messages = self._get_message_payload(formatted_prompt) + + result = self.llm_provider.get_completion( + messages=messages, + generation_config=kg_search_settings.agent_generation_config, + ) + + extraction = result.choices[0].message.content + query = extraction.split("```cypher")[1].split("```")[0] + yield self.kg_provider.structured_query(query) + + def _get_message_payload(self, message: str) -> dict: + return [ + { + "role": "system", + "content": self.prompt_provider.get_prompt( + self.config.system_prompt, + ), + }, + {"role": "user", "content": message}, + ] diff --git a/r2r/pipes/kg_extraction_pipe.py b/r2r/pipes/kg_extraction_pipe.py index ca4b619a..42b7d573 100644 --- a/r2r/pipes/kg_extraction_pipe.py +++ b/r2r/pipes/kg_extraction_pipe.py @@ -12,7 +12,6 @@ Extraction, Fragment, FragmentType, - GenerationConfig, KGExtraction, KGProvider, KVLoggingSingleton, @@ -25,6 +24,7 @@ extract_triples, generate_id_from_label, ) +from r2r.core.abstractions.llm import GenerationConfig logger = logging.getLogger(__name__) @@ -212,7 +212,7 @@ async def _run_logic( ) # pass a copy if necessary fragment_batch.clear() # Clear the batch for new fragments - logger.info( + logger.debug( f"Fragmented the input document ids into counts as shown: {fragment_info}" ) diff --git a/r2r/pipes/parsing_pipe.py b/r2r/pipes/parsing_pipe.py index 7127e0f4..338c9166 100644 --- a/r2r/pipes/parsing_pipe.py +++ b/r2r/pipes/parsing_pipe.py @@ -83,9 +83,6 @@ def __init__( *args, **kwargs, ): - logger.info( - "Initializing a `ParsingPipe` to parse incoming documents." - ) super().__init__( pipe_logger=pipe_logger, type=type, @@ -175,8 +172,8 @@ async def _parse( ), ) iteration += 1 - logger.info( - f"Parsed document with id={document.id}, title={document.title}, user_id={document.user_id}, metadata={document.metadata} into {iteration} extractions in t={time.time()-t0:.2f} seconds." + logger.debug( + f"Parsed document with id={document.id}, title={document.metadata.get('title', None)}, user_id={document.metadata.get('user_id', None)}, metadata={document.metadata} into {iteration} extractions in t={time.time()-t0:.2f} seconds." 
diff --git a/r2r/pipes/query_transform_pipe.py b/r2r/pipes/query_transform_pipe.py
index bcdad34f..8e8c029f 100644
--- a/r2r/pipes/query_transform_pipe.py
+++ b/r2r/pipes/query_transform_pipe.py
@@ -5,11 +5,11 @@
 from r2r.core import (
     AsyncPipe,
     AsyncState,
-    GenerationConfig,
     LLMProvider,
     PipeType,
     PromptProvider,
 )
+from r2r.core.abstractions.llm import GenerationConfig

 from .abstractions.generator_pipe import GeneratorPipe
diff --git a/r2r/pipes/search_rag_pipe.py b/r2r/pipes/search_rag_pipe.py
index 4ee68379..5834d4a4 100644
--- a/r2r/pipes/search_rag_pipe.py
+++ b/r2r/pipes/search_rag_pipe.py
@@ -1,17 +1,17 @@
 import logging
 import uuid
-from typing import Any, AsyncGenerator, Optional
+from typing import Any, AsyncGenerator, Optional, Tuple

 from r2r.core import (
+    AggregateSearchResult,
     AsyncPipe,
     AsyncState,
-    GenerationConfig,
     LLMChatCompletion,
     LLMProvider,
     PipeType,
     PromptProvider,
-    SearchResult,
 )
+from r2r.core.abstractions.llm import GenerationConfig

 from .abstractions.generator_pipe import GeneratorPipe

@@ -20,9 +20,7 @@ class SearchRAGPipe(GeneratorPipe):

     class Input(AsyncPipe.Input):
-        message: AsyncGenerator[SearchResult, None]
-        query: list[str]
-        raw_search_results: Optional[list[SearchResult]] = None
+        message: AsyncGenerator[Tuple[str, AggregateSearchResult], None]

     def __init__(
         self,
@@ -54,8 +52,19 @@ async def _run_logic(
         *args: Any,
         **kwargs: Any,
     ) -> AsyncGenerator[LLMChatCompletion, None]:
-        context = await self._collect_context(input)
-        messages = self._get_message_payload("\n".join(input.query), context)
+        context = ""
+        search_iteration = 1
+        total_results = 0
+        sel_query = None
+        async for query, search_results in input.message:
+            if search_iteration == 1:
+                sel_query = query
+            context_piece, total_results = await self._collect_context(
+                query, search_results, search_iteration, total_results
+            )
+            context += context_piece
+            search_iteration += 1
+        messages = self._get_message_payload(sel_query, context)

         response = self.llm_provider.get_completion(
             messages=messages, generation_config=rag_generation_config
@@ -81,18 +90,36 @@ def _get_message_payload(self, query: str, context: str) -> dict:
                 "content": self.prompt_provider.get_prompt(
                     self.config.task_prompt,
                     inputs={
-                        "query": "\n".join(query),
+                        "query": query,
                         "context": context,
                     },
                 ),
             },
         ]

-    async def _collect_context(self, input: Input) -> str:
-        iteration = 0
-        context = ""
-        async for result in input.message:
-            context += f"Result {iteration+1}:\n{result.metadata['text']}\n\n"
-            iteration += 1
-
-        return context
+    async def _collect_context(
+        self,
+        query: str,
+        results: AggregateSearchResult,
+        iteration: int,
+        total_results: int,
+    ) -> Tuple[str, int]:
+        context = f"Query:\n{query}\n\n"
+        it = total_results + 1
+        if results.vector_search_results:
+            context += f"Vector Search Results({iteration}):\n"
+            for result in results.vector_search_results:
+                context += f"[{it}]: {result.metadata['text']}\n\n"
+                it += 1
+            total_results = (
+                it - 1
+            )  # Update total_results based on the last index used
+        if results.kg_search_results:
+            context += f"Knowledge Graph Search Results({iteration}):\n"
+            for result in results.kg_search_results:
+                context += f"[{it}]: {result}\n\n"
+                it += 1
+            total_results = (
+                it - 1
+            )  # Update total_results based on the last index used
+        return context, total_results
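The `_collect_context` rewrite above numbers vector and KG results with a single running index so the prompt can cite `[n]` unambiguously across result types. A simplified, dependency-free sketch of that numbering scheme (plain lists stand in for `AggregateSearchResult`):

```python
def collect_context(query, vector_texts, kg_results, iteration=1):
    # Simplified version of SearchRAGPipe._collect_context: one running
    # index [it] spans both result types, so citations never collide.
    context = f"Query:\n{query}\n\n"
    it = 1
    if vector_texts:
        context += f"Vector Search Results({iteration}):\n"
        for text in vector_texts:
            context += f"[{it}]: {text}\n\n"
            it += 1
    if kg_results:
        context += f"Knowledge Graph Search Results({iteration}):\n"
        for result in kg_results:
            context += f"[{it}]: {result}\n\n"
            it += 1
    return context, it - 1  # total results consumed so far

context, total = collect_context(
    "who founded the company?",
    ["Snippet mentioning the founder."],
    [("JOHN DOE", "FOUNDED", "ACME")],
)
print(context)
```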
diff --git a/r2r/pipes/streaming_rag_pipe.py b/r2r/pipes/streaming_rag_pipe.py
index 100f9ffe..04b866b0 100644
--- a/r2r/pipes/streaming_rag_pipe.py
+++ b/r2r/pipes/streaming_rag_pipe.py
@@ -5,12 +5,12 @@
 from r2r.core import (
     AsyncState,
-    GenerationConfig,
     LLMChatCompletionChunk,
     LLMProvider,
     PipeType,
     PromptProvider,
 )
+from r2r.core.abstractions.llm import GenerationConfig

 from .abstractions.generator_pipe import GeneratorPipe
 from .search_rag_pipe import SearchRAGPipe
@@ -55,32 +55,44 @@ async def _run_logic(
         iteration = 0
         context = ""
         # dump the search results and construct the context
-        yield f"<{self.SEARCH_STREAM_MARKER}>"
-        for result in input.raw_search_results:
-            if iteration >= 1:
-                yield ","
-            yield json.dumps(result.json())
-            context += f"Result {iteration+1}:\n{result.metadata['text']}\n\n"
-            iteration += 1
-        yield f"</{self.SEARCH_STREAM_MARKER}>"
+        async for query, search_results in input.message:
+            yield f"<{self.SEARCH_STREAM_MARKER}>"
+            if search_results.vector_search_results:
+                context += "Vector Search Results:\n"
+                for result in search_results.vector_search_results:
+                    if iteration >= 1:
+                        yield ","
+                    yield json.dumps(result.json())
+                    context += f"{iteration+1}:\n{result.metadata['text']}\n\n"
+                    iteration += 1

-        messages = self._get_message_payload(str(input.query), context)
-        yield f"<{self.COMPLETION_STREAM_MARKER}>"
-        response = ""
-        for chunk in self.llm_provider.get_completion_stream(
-            messages=messages, generation_config=rag_generation_config
-        ):
-            chunk = StreamingSearchRAGPipe._process_chunk(chunk)
-            response += chunk
-            yield chunk
+            # if search_results.kg_search_results:
+            #     for result in search_results.kg_search_results:
+            #         if iteration >= 1:
+            #             yield ","
+            #         yield json.dumps(result.json())
+            #         context += f"Result {iteration+1}:\n{result.metadata['text']}\n\n"
+            #         iteration += 1

-        yield f"</{self.COMPLETION_STREAM_MARKER}>"
+            yield f"</{self.SEARCH_STREAM_MARKER}>"

-        await self.enqueue_log(
-            run_id=run_id,
-            key="llm_response",
-            value=response,
-        )
+            messages = self._get_message_payload(query, context)
+            yield f"<{self.COMPLETION_STREAM_MARKER}>"
+            response = ""
+            for chunk in self.llm_provider.get_completion_stream(
+                messages=messages, generation_config=rag_generation_config
+            ):
+                chunk = StreamingSearchRAGPipe._process_chunk(chunk)
+                response += chunk
+                yield chunk
+
+            yield f"</{self.COMPLETION_STREAM_MARKER}>"
+
+            await self.enqueue_log(
+                run_id=run_id,
+                key="llm_response",
+                value=response,
+            )

     async def _yield_chunks(
         self,
diff --git a/r2r/pipes/vector_search_pipe.py b/r2r/pipes/vector_search_pipe.py
index 44653680..288fb9af 100644
--- a/r2r/pipes/vector_search_pipe.py
+++ b/r2r/pipes/vector_search_pipe.py
@@ -8,8 +8,9 @@
     AsyncState,
     EmbeddingProvider,
     PipeType,
-    SearchResult,
     VectorDBProvider,
+    VectorSearchResult,
+    VectorSearchSettings,
 )

 from .abstractions.search_pipe import SearchPipe
@@ -40,16 +41,19 @@ async def search(
         self,
         message: str,
         run_id: uuid.UUID,
-        do_hybrid_search: bool,
+        vector_search_settings: VectorSearchSettings,
         *args: Any,
         **kwargs: Any,
-    ) -> AsyncGenerator[SearchResult, None]:
-        search_filters_override = kwargs.get("search_filters", None)
-        search_limit_override = kwargs.get("search_limit", None)
-        search_limit = search_limit_override or self.config.search_limit
+    ) -> AsyncGenerator[VectorSearchResult, None]:
         await self.enqueue_log(
             run_id=run_id, key="search_query", value=message
         )
+        search_filters = (
+            vector_search_settings.search_filters or self.config.search_filters
+        )
+        search_limit = (
+            vector_search_settings.search_limit or self.config.search_limit
+        )
         results = []
         query_vector = self.embedding_provider.get_embedding(
             message,
@@ -58,13 +62,13 @@
             self.vector_db_provider.hybrid_search(
                 query_vector=query_vector,
                 query_text=message,
-                filters=search_filters_override or self.config.search_filters,
+                filters=search_filters,
                 limit=search_limit,
             )
-            if do_hybrid_search
+            if vector_search_settings.do_hybrid_search
             else self.vector_db_provider.search(
                 query_vector=query_vector,
-                filters=search_filters_override or self.config.search_filters,
+                filters=search_filters,
                 limit=search_limit,
             )
         )
@@ -86,10 +90,10 @@ async def _run_logic(
         input: AsyncPipe.Input,
         state: AsyncState,
         run_id: uuid.UUID,
-        do_hybrid_search: bool = False,
+        vector_search_settings: VectorSearchSettings = VectorSearchSettings(),
         *args: Any,
         **kwargs: Any,
-    ) -> AsyncGenerator[SearchResult, None]:
+    ) -> AsyncGenerator[VectorSearchResult, None]:
         search_queries = []
         search_results = []
         async for search_request in input.message:
@@ -97,7 +101,7 @@
             async for result in self.search(
                 message=search_request,
                 run_id=run_id,
-                do_hybrid_search=do_hybrid_search,
+                vector_search_settings=vector_search_settings,
                 *args,
                 **kwargs,
             ):
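With the change above, per-request search knobs travel in a single `VectorSearchSettings` object instead of loose kwargs. A hedged sketch of how a caller might now parameterize a search (field names are taken from this diff; the `asearch` call follows the tests later in the diff):

```python
from r2r import VectorSearchSettings

settings = VectorSearchSettings(
    use_vector_search=True,
    search_filters={"author": "John Doe"},
    search_limit=20,
    do_hybrid_search=False,
)

# Inside an async context, against an assembled R2RApp instance:
# results = await r2r_app.asearch(
#     "who was aristotle?", vector_search_settings=settings
# )
```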
diff --git a/r2r/pipes/vector_storage_pipe.py b/r2r/pipes/vector_storage_pipe.py
index 1f85df01..bc11628d 100644
--- a/r2r/pipes/vector_storage_pipe.py
+++ b/r2r/pipes/vector_storage_pipe.py
@@ -33,10 +33,6 @@ def __init__(
         """
         Initializes the async vector storage pipe with necessary components and configurations.
         """
-        logger.info(
-            f"Initalizing an `AsyncVectorStoragePipe` to store embeddings in a vector database."
-        )
-
         super().__init__(
             pipe_logger=pipe_logger,
             type=type,
diff --git a/r2r/pipes/web_search_pipe.py b/r2r/pipes/web_search_pipe.py
index 67da802e..14fd7f22 100644
--- a/r2r/pipes/web_search_pipe.py
+++ b/r2r/pipes/web_search_pipe.py
@@ -7,7 +7,7 @@
     AsyncPipe,
     AsyncState,
     PipeType,
-    SearchResult,
+    VectorSearchResult,
     generate_id_from_label,
 )
 from r2r.integrations import SerperClient
@@ -40,7 +40,7 @@ async def search(
         run_id: uuid.UUID,
         *args: Any,
         **kwargs: Any,
-    ) -> AsyncGenerator[SearchResult, None]:
+    ) -> AsyncGenerator[VectorSearchResult, None]:
         search_limit_override = kwargs.get("search_limit", None)
         await self.enqueue_log(
             run_id=run_id, key="search_query", value=message
@@ -56,7 +56,7 @@
             if result.get("snippet") is None:
                 continue
             result["text"] = result.pop("snippet")
-            search_result = SearchResult(
+            search_result = VectorSearchResult(
                 id=generate_id_from_label(str(result)),
                 score=result.get(
                     "score", 0
@@ -79,7 +79,7 @@ async def _run_logic(
         run_id: uuid.UUID,
         *args: Any,
         **kwargs,
-    ) -> AsyncGenerator[SearchResult, None]:
+    ) -> AsyncGenerator[VectorSearchResult, None]:
         search_queries = []
         search_results = []
         async for search_request in input.message:
diff --git a/r2r/prebuilts/multi_search.py b/r2r/prebuilts/multi_search.py
index 31a88706..d236fcfc 100644
--- a/r2r/prebuilts/multi_search.py
+++ b/r2r/prebuilts/multi_search.py
@@ -3,13 +3,13 @@
 from typing import Any, AsyncGenerator, Optional

 from r2r import (
-    GenerationConfig,
     LoggableAsyncPipe,
     QueryTransformPipe,
     R2RPipeFactory,
     SearchPipe,
-    SearchResult,
+    VectorSearchResult,
 )
+from r2r.core.abstractions.llm import GenerationConfig


 class MultiSearchPipe(LoggableAsyncPipe):
@@ -55,7 +55,7 @@ async def _run_logic(
         query_transform_generation_config: Optional[GenerationConfig] = None,
         *args: Any,
         **kwargs: Any,
-    ) -> AsyncGenerator[SearchResult, None]:
+    ) -> AsyncGenerator[VectorSearchResult, None]:
         query_transform_generation_config = (
             query_transform_generation_config
             or copy(kwargs.get("rag_generation_config", None))
diff --git a/r2r/providers/embeddings/openai/openai_base.py b/r2r/providers/embeddings/openai/openai_base.py
index eef09b49..cf959bab 100644
--- a/r2r/providers/embeddings/openai/openai_base.py
+++ b/r2r/providers/embeddings/openai/openai_base.py
@@ -3,7 +3,7 @@

 from openai import AsyncOpenAI, AuthenticationError, OpenAI

-from r2r.core import EmbeddingConfig, EmbeddingProvider, SearchResult
+from r2r.core import EmbeddingConfig, EmbeddingProvider, VectorSearchResult

 logger = logging.getLogger(__name__)

@@ -21,9 +21,6 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
     }

     def __init__(self, config: EmbeddingConfig):
-        logger.info(
-            "Initializing `OpenAIEmbeddingProvider` to provide embeddings."
-        )
         super().__init__(config)
         provider = config.provider
         if not provider:
@@ -180,7 +177,7 @@ async def async_get_embeddings(
     def rerank(
         self,
         query: str,
-        results: list[SearchResult],
+        results: list[VectorSearchResult],
         stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.RERANK,
         limit: int = 10,
     ):
diff --git a/r2r/providers/embeddings/sentence_transformer/sentence_transformer_base.py b/r2r/providers/embeddings/sentence_transformer/sentence_transformer_base.py
index b8de4ea5..26fdc40c 100644
--- a/r2r/providers/embeddings/sentence_transformer/sentence_transformer_base.py
+++ b/r2r/providers/embeddings/sentence_transformer/sentence_transformer_base.py
@@ -1,6 +1,6 @@
 import logging

-from r2r.core import EmbeddingConfig, EmbeddingProvider, SearchResult
+from r2r.core import EmbeddingConfig, EmbeddingProvider, VectorSearchResult

 logger = logging.getLogger(__name__)

@@ -123,10 +123,10 @@ def get_embeddings(
     def rerank(
         self,
         query: str,
-        results: list[SearchResult],
+        results: list[VectorSearchResult],
         stage: EmbeddingProvider.PipeStage = EmbeddingProvider.PipeStage.RERANK,
         limit: int = 10,
-    ) -> list[SearchResult]:
+    ) -> list[VectorSearchResult]:
         if stage != EmbeddingProvider.PipeStage.RERANK:
             raise ValueError("`rerank` only supports `RERANK` stage.")
         if not self.do_rerank:
diff --git a/r2r/providers/eval/llm/base_llm_eval.py b/r2r/providers/eval/llm/base_llm_eval.py
index fcdb57e1..a7a33e47 100644
--- a/r2r/providers/eval/llm/base_llm_eval.py
+++ b/r2r/providers/eval/llm/base_llm_eval.py
@@ -1,13 +1,8 @@
 from fractions import Fraction
 from typing import Union

-from r2r import (
-    EvalConfig,
-    EvalProvider,
-    GenerationConfig,
-    LLMProvider,
-    PromptProvider,
-)
+from r2r import EvalConfig, EvalProvider, LLMProvider, PromptProvider
+from r2r.core.abstractions.llm import GenerationConfig


 class LLMEvalProvider(EvalProvider):
diff --git a/r2r/providers/llms/litellm/base_litellm.py b/r2r/providers/llms/litellm/base_litellm.py
index 1f1ebb65..0898fe03 100644
--- a/r2r/providers/llms/litellm/base_litellm.py
+++ b/r2r/providers/llms/litellm/base_litellm.py
@@ -2,12 +2,12 @@
 from typing import Any, Generator, Union

 from r2r.core import (
-    GenerationConfig,
     LLMChatCompletion,
     LLMChatCompletionChunk,
     LLMConfig,
     LLMProvider,
 )
+from r2r.core.abstractions.llm import GenerationConfig

 logger = logging.getLogger(__name__)
diff --git a/r2r/providers/llms/openai/base_openai.py b/r2r/providers/llms/openai/base_openai.py
index fafd760f..4a41957c 100644
--- a/r2r/providers/llms/openai/base_openai.py
+++ b/r2r/providers/llms/openai/base_openai.py
@@ -5,12 +5,12 @@
 from typing import Union

 from r2r.core import (
-    GenerationConfig,
     LLMChatCompletion,
     LLMChatCompletionChunk,
     LLMConfig,
     LLMProvider,
 )
+from r2r.core.abstractions.llm import GenerationConfig

 logger = logging.getLogger(__name__)
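The provider diffs above are mostly the `SearchResult` to `VectorSearchResult` rename, but the rerank call shape is worth spelling out. A sketch of calling a reranker under the new type, with provider construction elided since its config fields are not shown in this diff:

```python
from r2r import VectorSearchResult, generate_id_from_label

results = [
    VectorSearchResult(
        id=generate_id_from_label("x"), score=0.9, metadata={"text": "doc1"}
    ),
    VectorSearchResult(
        id=generate_id_from_label("y"), score=0.8, metadata={"text": "doc2"}
    ),
]

# Assuming `provider` is an already-configured
# SentenceTransformerEmbeddingProvider with reranking enabled:
# reranked = provider.rerank("which doc mentions dogs?", results, limit=2)
```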
diff --git a/r2r/providers/vector_dbs/local/r2r_local_vector_db.py b/r2r/providers/vector_dbs/local/r2r_local_vector_db.py
index c63b7dcf..2c28e7ec 100644
--- a/r2r/providers/vector_dbs/local/r2r_local_vector_db.py
+++ b/r2r/providers/vector_dbs/local/r2r_local_vector_db.py
@@ -7,11 +7,11 @@

 from r2r.core import (
     DocumentInfo,
-    SearchResult,
     UserStats,
     VectorDBConfig,
     VectorDBProvider,
     VectorEntry,
+    VectorSearchResult,
 )

 logger = logging.getLogger(__name__)
@@ -19,6 +19,7 @@

 class R2RLocalVectorDB(VectorDBProvider):
     def __init__(self, config: VectorDBConfig) -> None:
+        super().__init__(config)
         if config.provider != "local":
             raise ValueError(
@@ -132,7 +133,7 @@ def search(
         limit: int = 10,
         *args,
         **kwargs,
-    ) -> list[SearchResult]:
+    ) -> list[VectorSearchResult]:
         if self.config.collection_name is None:
             raise ValueError(
                 "Collection name is not set. Please call `initialize_collection` first."
@@ -148,7 +149,9 @@
             # Local cosine similarity calculation
             score = self._cosine_similarity(query_vector, vector)
             results.append(
-                SearchResult(id=id, score=score, metadata=json_metadata)
+                VectorSearchResult(
+                    id=id, score=score, metadata=json_metadata
+                )
             )
         results.sort(key=lambda x: x.score, reverse=True)
         conn.close()
@@ -166,7 +169,7 @@ def hybrid_search(
         rrf_k: int = 20,  # typical value is ~2x the number of results you want
         *args,
         **kwargs,
-    ) -> list[SearchResult]:
+    ) -> list[VectorSearchResult]:
         raise NotImplementedError(
             "Hybrid search is not supported in R2RLocalVectorDB."
         )
diff --git a/r2r/providers/vector_dbs/pgvector/pgvector_db.py b/r2r/providers/vector_dbs/pgvector/pgvector_db.py
index c5ba0600..f18a2f04 100644
--- a/r2r/providers/vector_dbs/pgvector/pgvector_db.py
+++ b/r2r/providers/vector_dbs/pgvector/pgvector_db.py
@@ -2,18 +2,17 @@
 import logging
 import os
 import time
-import uuid
 from typing import Optional, Union

 from sqlalchemy import exc, text

 from r2r.core import (
     DocumentInfo,
-    SearchResult,
     UserStats,
     VectorDBConfig,
     VectorDBProvider,
     VectorEntry,
+    VectorSearchResult,
 )
 from r2r.vecs.client import Client
 from r2r.vecs.collection import Collection
@@ -238,7 +237,7 @@ def search(
         limit: int = 10,
         *args,
         **kwargs,
-    ) -> list[SearchResult]:
+    ) -> list[VectorSearchResult]:
         if self.collection is None:
             raise ValueError(
                 "Please call `initialize_collection` before attempting to run `search`."
@@ -249,7 +248,7 @@
         }

         return [
-            SearchResult(id=ele[0], score=float(1 - ele[1]), metadata=ele[2])  # type: ignore
+            VectorSearchResult(id=ele[0], score=float(1 - ele[1]), metadata=ele[2])  # type: ignore
             for ele in self.collection.query(
                 data=query_vector,
                 limit=limit,
@@ -272,7 +271,7 @@ def hybrid_search(
         rrf_k: int = 20,  # typical value is ~2x the number of results you want
         *args,
         **kwargs,
-    ) -> list[SearchResult]:
+    ) -> list[VectorSearchResult]:
         if self.collection is None:
             raise ValueError(
                 "Please call `initialize_collection` before attempting to run `hybrid_search`."
@@ -306,7 +305,7 @@ def hybrid_search(
         with self.vx.Session() as session:
             result = session.execute(query, params).fetchall()
         return [
-            SearchResult(id=row[0], score=1.0, metadata=row[-1])
+            VectorSearchResult(id=row[0], score=1.0, metadata=row[-1])
             for row in result
         ]

@@ -493,7 +492,7 @@ def get_users_stats(self, user_ids: Optional[list[str]] = None):
                 user_id=row[0],
                 num_files=row[1],
                 total_size_in_bytes=row[2],
-                document_ids=[uuid.UUID(doc_id) for doc_id in row[3]],
+                document_ids=row[3],
             )
             for row in results
         ]
diff --git a/tests/test_abstractions.py b/tests/test_abstractions.py
index b53f6edc..63263ef3 100644
--- a/tests/test_abstractions.py
+++ b/tests/test_abstractions.py
@@ -6,10 +6,10 @@
     AsyncPipe,
     AsyncState,
     Prompt,
-    SearchRequest,
-    SearchResult,
     Vector,
     VectorEntry,
+    VectorSearchRequest,
+    VectorSearchResult,
     VectorType,
     generate_id_from_label,
 )
@@ -83,7 +83,7 @@ def test_prompt_invalid_input_type():

 def test_search_request_with_optional_filters():
-    request = SearchRequest(
+    request = VectorSearchRequest(
         query="test", limit=10, filters={"category": "books"}
     )
     assert request.query == "test"
@@ -92,7 +92,7 @@

 def test_search_result_to_string():
-    result = SearchResult(
+    result = VectorSearchResult(
         id=generate_id_from_label("1"),
         score=9.5,
         metadata={"author": "John Doe"},
@@ -100,19 +100,19 @@
     result_str = str(result)
     assert (
         result_str
-        == f"SearchResult(id={str(generate_id_from_label('1'))}, score=9.5, metadata={{'author': 'John Doe'}})"
+        == f"VectorSearchResult(id={str(generate_id_from_label('1'))}, score=9.5, metadata={{'author': 'John Doe'}})"
     )


 def test_search_result_repr():
-    result = SearchResult(
+    result = VectorSearchResult(
         id=generate_id_from_label("1"),
         score=9.5,
         metadata={"author": "John Doe"},
     )
     assert (
         repr(result)
-        == f"SearchResult(id={str(generate_id_from_label('1'))}, score=9.5, metadata={{'author': 'John Doe'}})"
+        == f"VectorSearchResult(id={str(generate_id_from_label('1'))}, score=9.5, metadata={{'author': 'John Doe'}})"
     )
diff --git a/tests/test_embedding.py b/tests/test_embedding.py
index beb607a3..dbc6c841 100644
--- a/tests/test_embedding.py
+++ b/tests/test_embedding.py
@@ -1,6 +1,6 @@
 import pytest

-from r2r import EmbeddingConfig, SearchResult, generate_id_from_label
+from r2r import EmbeddingConfig, VectorSearchResult, generate_id_from_label
 from r2r.providers.embeddings import (
     OpenAIEmbeddingProvider,
     SentenceTransformerEmbeddingProvider,
@@ -106,12 +106,12 @@ def test_sentence_transformer_get_embeddings(sentence_transformer_provider):
 def test_sentence_transformer_rerank(sentence_transformer_provider):
     results = [
-        SearchResult(
+        VectorSearchResult(
             id=generate_id_from_label("x"),
             score=0.9,
             metadata={"text": "doc1"},
         ),
-        SearchResult(
+        VectorSearchResult(
             id=generate_id_from_label("y"),
             score=0.8,
             metadata={"text": "doc2"},
diff --git a/tests/test_end_to_end.py b/tests/test_end_to_end.py
index b2294192..ea29aeac 100644
--- a/tests/test_end_to_end.py
+++ b/tests/test_end_to_end.py
@@ -7,15 +7,16 @@

 from r2r import (
     Document,
-    GenerationConfig,
     KVLoggingSingleton,
     R2RApp,
     R2RConfig,
     R2RPipeFactory,
     R2RPipelineFactory,
     R2RProviderFactory,
+    VectorSearchSettings,
     generate_id_from_label,
 )
+from r2r.core.abstractions.llm import GenerationConfig


 @pytest.fixture(scope="function")
@@ -42,6 +43,7 @@ def r2r_app(request):
     r2r = R2RApp(
         config=config,
         providers=providers,
+        pipes=pipes,
         pipelines=pipelines,
     )

@@ -179,24 +181,29 @@ async def test_ingest_search_txt_file(r2r_app, logging_connection):
     )

     search_results = await r2r_app.asearch("who was aristotle?")
-    assert len(search_results["results"]) == 10
+    assert len(search_results["results"]["vector_search_results"]) == 10
     assert (
         "was an Ancient Greek philosopher and polymath"
-        in search_results["results"][0]["metadata"]["text"]
+        in search_results["results"]["vector_search_results"][0]["metadata"][
+            "text"
+        ]
     )

     search_results = await r2r_app.asearch(
-        "who was aristotle?", search_limit=20
+        "who was aristotle?",
+        vector_search_settings=VectorSearchSettings(search_limit=20),
     )
-    assert len(search_results["results"]) == 20
+    assert len(search_results["results"]["vector_search_results"]) == 20
     assert (
         "was an Ancient Greek philosopher and polymath"
-        in search_results["results"][0]["metadata"]["text"]
+        in search_results["results"]["vector_search_results"][0]["metadata"][
+            "text"
+        ]
     )

-    ## test streaming
+    ## test stream
     response = await r2r_app.arag(
-        message="Who was aristotle?",
+        query="Who was aristotle?",
         rag_generation_config=GenerationConfig(
             **{"model": "gpt-3.5-turbo", "stream": True}
         ),
@@ -231,10 +238,12 @@ async def test_ingest_search_then_delete(r2r_app, logging_connection):

     # Verify that the search results are not empty
     assert (
-        len(search_results["results"]) > 0
+        len(search_results["results"]["vector_search_results"]) > 0
     ), "Expected search results, but got none"
     assert (
-        search_results["results"][0]["metadata"]["text"]
+        search_results["results"]["vector_search_results"][0]["metadata"][
+            "text"
+        ]
         == "The quick brown fox jumps over the lazy dog."
     )

@@ -251,11 +260,11 @@ async def test_ingest_search_then_delete(r2r_app, logging_connection):

     # Verify that the search results are empty
     assert (
-        len(search_results_2["results"]) == 0
+        len(search_results_2["results"]["vector_search_results"]) == 0
     ), f"Expected no search results, but got {search_results_2['results']}"


-@pytest.mark.parametrize("r2r_app", ["local", "postgres"], indirect=True)
+@pytest.mark.parametrize("r2r_app", ["local", "pgvector"], indirect=True)
 @pytest.mark.asyncio
 async def test_ingest_user_documents(r2r_app, logging_connection):
     user_id_0 = generate_id_from_label("user_0")
@@ -263,13 +272,13 @@
     await r2r_app.aingest_documents(
         [
             Document(
-                id=generate_id_from_label("doc_0"),
+                id=generate_id_from_label("doc_01"),
                 data="The quick brown fox jumps over the lazy dog.",
                 type="txt",
                 metadata={"author": "John Doe", "user_id": user_id_0},
             ),
             Document(
-                id=generate_id_from_label("doc_1"),
+                id=generate_id_from_label("doc_11"),
                 data="The lazy dog jumps over the quick brown fox.",
                 type="txt",
                 metadata={"author": "John Doe", "user_id": user_id_1},
@@ -291,10 +300,10 @@
         len(user_1_docs) == 1
     ), f"Expected 1 document for user {user_id_1}, but got {len(user_1_docs)}"
     assert user_0_docs[0].document_id == generate_id_from_label(
-        "doc_0"
+        "doc_01"
     ), f"Expected document id {str(generate_id_from_label('doc_0'))} for user {user_id_0}, but got {user_0_docs[0].document_id}"
     assert user_1_docs[0].document_id == generate_id_from_label(
-        "doc_1"
+        "doc_11"
     ), f"Expected document id {str(generate_id_from_label('doc_1'))} for user {user_id_1}, but got {user_1_docs[0].document_id}"


@@ -313,12 +322,12 @@ async def test_delete_by_id(r2r_app, logging_connection):
     )

     search_results = await r2r_app.asearch("who was aristotle?")
-    assert len(search_results["results"]) > 0
+    assert len(search_results["results"]["vector_search_results"]) > 0
     await r2r_app.adelete(
         ["document_id"], [str(generate_id_from_label("doc_1"))]
     )
     search_results = await r2r_app.asearch("who was aristotle?")
-    assert len(search_results["results"]) == 0
+    assert len(search_results["results"]["vector_search_results"]) == 0


 @pytest.mark.parametrize("r2r_app", ["pgvector", "local"], indirect=True)
@@ -336,7 +345,7 @@ async def test_double_ingest(r2r_app, logging_connection):
     )

     search_results = await r2r_app.asearch("who was aristotle?")
-    assert len(search_results["results"]) == 1
+    assert len(search_results["results"]["vector_search_results"]) == 1
     with pytest.raises(Exception):
         await r2r_app.aingest_documents(
             [
diff --git a/tests/test_llms.py b/tests/test_llms.py
index 49b83908..8b0f49f3 100644
--- a/tests/test_llms.py
+++ b/tests/test_llms.py
@@ -1,6 +1,7 @@
 import pytest

-from r2r import GenerationConfig, LLMConfig
+from r2r import LLMConfig
+from r2r.core.abstractions.llm import GenerationConfig
 from r2r.providers.llms import LiteLLM

diff --git a/tests/test_server_client.py b/tests/test_server_client.py
index c5416ea1..f045f74f 100644
--- a/tests/test_server_client.py
+++ b/tests/test_server_client.py
@@ -6,12 +6,17 @@
 from fastapi.testclient import TestClient

 from r2r import (
+    KGSearchSettings,
     KVLoggingSingleton,
     R2RApp,
     R2RConfig,
+    R2RIngestFilesRequest,
     R2RPipeFactory,
     R2RPipelineFactory,
     R2RProviderFactory,
+    R2RRAGRequest,
+    R2RSearchRequest,
+    VectorSearchSettings,
     generate_id_from_label,
 )

@@ -40,6 +45,7 @@ def r2r_app(request):
     r2r = R2RApp(
         config=config,
         providers=providers,
+        pipes=pipes,
         pipelines=pipelines,
     )

@@ -107,6 +113,7 @@ async def test_ingest_txt_file(client):

     response = client.post(
         "/ingest_files/",
+        # must use data instead of json when sending files
         data={"metadatas": json.dumps([metadata])},
         files=files,
     )
@@ -124,14 +131,13 @@
 @pytest.mark.asyncio
 async def test_search(client):
     query = "who was aristotle?"
-    response = client.post(
-        "/search/",
-        json={
-            "query": query,
-            "search_filters": "{}",
-            "search_limit": "10",
-        },
+    search_request = R2RSearchRequest(
+        query=query,
+        vector_settings=VectorSearchSettings(),
+        kg_settings=KGSearchSettings(),
     )
+
+    response = client.post("/search/", json=search_request.dict())
     assert response.status_code == 200
     assert "results" in response.json()

@@ -140,15 +146,14 @@
 @pytest.mark.asyncio
 async def test_rag(client):
     query = "who was aristotle?"
-    response = client.post(
-        "/rag/",
-        json={
-            "message": query,
-            "search_filters": "{}",
-            "search_limit": "10",
-            "streaming": "false",
-        },
+    rag_request = R2RRAGRequest(
+        query=query,
+        vector_settings=VectorSearchSettings(),
+        kg_settings=KGSearchSettings(),
+        rag_generation_config=None,
     )
+
+    response = client.post("/rag/", json=rag_request.dict())
     assert response.status_code == 200
     assert "results" in response.json()

@@ -168,13 +173,17 @@ async def test_delete(client):
             ),
         ),
     ]
+    request = R2RIngestFilesRequest(
+        metadatas=[metadata],
+    )
     response = client.post(
         "/ingest_files/",
-        data={"metadatas": json.dumps([metadata])},
+        data={k: json.dumps(v) for k, v in json.loads(request.json()).items()},
         files=files,
     )
+    print("response = ", response)

     response = client.request(
         "DELETE",
         "/delete/",