feat: add StreamingResponse support for ConversationalRetrievalChain chain type #16

Merged · 10 commits · May 10, 2023
1 change: 1 addition & 0 deletions examples/README.md
@@ -34,6 +34,7 @@ To run a demo example, select the command based on the langchain use case you wa

 - Conversation Chain: `uvicorn app.conversation_chain:app --reload`
 - Retrieval QA with Sources Chain: `uvicorn app.retrieval_qa_chain:app --reload`
+- Conversational Retrieval: `uvicorn app.conversational_retrieval:app --reload`

 You can also use the "Run & Debug" VSCode feature to run one of the applications.

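Once the conversational retrieval demo below is running, its streaming endpoint can be exercised from Python. A minimal client sketch, assuming the server listens on localhost:8000 and using `httpx` (which is not a dependency of this repo):

```python
import httpx

payload = {
    "query": "What is a ConversationalRetrievalChain?",
    "history": [],  # list of [human, ai] message pairs from earlier turns
}

# Stream the response body as the /chat endpoint produces it.
with httpx.stream(
    "POST", "http://localhost:8000/chat", json=payload, timeout=None
) as response:
    for chunk in response.iter_text():
        print(chunk, end="", flush=True)
```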
84 changes: 84 additions & 0 deletions examples/app/conversational_retrieval.py
@@ -0,0 +1,84 @@
from functools import lru_cache
from typing import Callable

from dotenv import load_dotenv
from fastapi import Depends, FastAPI
from fastapi.templating import Jinja2Templates
from langchain.chains import ConversationalRetrievalChain, LLMChain
from langchain.chains.question_answering import load_qa_chain
from langchain.chat_models import ChatOpenAI
from pydantic import BaseModel

from fastapi_async_langchain.responses import ConversationalRetrievalStreamingResponse
from fastapi_async_langchain.testing import mount_gradio_app

load_dotenv()

app = mount_gradio_app(FastAPI(title="ConversationalRetrievalChainDemo"))

templates = Jinja2Templates(directory="templates")


class QueryRequest(BaseModel):
    query: str
    history: list[list[str]] = []


def conversational_retrieval_chain_dependency() -> (
    Callable[[], ConversationalRetrievalChain]
):
    @lru_cache(maxsize=1)  # build the chain once and reuse it across requests
    def dependency() -> ConversationalRetrievalChain:
        from langchain.chains.conversational_retrieval.prompts import (
            CONDENSE_QUESTION_PROMPT,
        )
        from langchain.embeddings import OpenAIEmbeddings
        from langchain.vectorstores import FAISS

        # Load a prebuilt FAISS index (see the build sketch after this file).
        db = FAISS.load_local(
            folder_path="vector_stores/",
            index_name="langchain-python",
            embeddings=OpenAIEmbeddings(),
        )

        # Rephrases a follow-up question into a standalone question.
        question_generator = LLMChain(
            llm=ChatOpenAI(
                temperature=0,
                streaming=True,
            ),
            prompt=CONDENSE_QUESTION_PROMPT,
        )
        # Answers the standalone question over the retrieved documents.
        doc_chain = load_qa_chain(
            llm=ChatOpenAI(
                temperature=0,
                streaming=True,
            ),
            chain_type="stuff",
        )

        return ConversationalRetrievalChain(
            combine_docs_chain=doc_chain,
            question_generator=question_generator,
            retriever=db.as_retriever(),
            return_source_documents=True,
            verbose=True,
        )

    return dependency


conversational_retrieval_chain = conversational_retrieval_chain_dependency()


@app.post("/chat")
async def chat(
    request: QueryRequest,
    chain: ConversationalRetrievalChain = Depends(conversational_retrieval_chain),
) -> ConversationalRetrievalStreamingResponse:
    inputs = {
        "question": request.query,
        "chat_history": [(human, ai) for human, ai in request.history],
    }
    return ConversationalRetrievalStreamingResponse.from_chain(
        chain, inputs, media_type="text/event-stream"
    )
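The demo loads a prebuilt FAISS index named `langchain-python` from `vector_stores/`, which is not part of this diff. A minimal sketch for building a compatible index, assuming a local `docs.txt` source file (the loader, splitter settings, and file name are illustrative, not part of this PR):

```python
from langchain.document_loaders import TextLoader
from langchain.embeddings import OpenAIEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS

# Load and chunk the source documents.
documents = TextLoader("docs.txt").load()
chunks = RecursiveCharacterTextSplitter(
    chunk_size=1000, chunk_overlap=100
).split_documents(documents)

# Embed the chunks and save the index where the demo expects it.
db = FAISS.from_documents(chunks, OpenAIEmbeddings())
db.save_local(folder_path="vector_stores/", index_name="langchain-python")
```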
7 changes: 6 additions & 1 deletion fastapi_async_langchain/responses/__init__.py
@@ -1,4 +1,9 @@
+from .conversational_retrieval import ConversationalRetrievalStreamingResponse
 from .llm import LLMChainStreamingResponse
 from .retrieval_qa import RetrievalQAStreamingResponse

-__all__ = ["LLMChainStreamingResponse", "RetrievalQAStreamingResponse"]
+__all__ = [
+    "LLMChainStreamingResponse",
+    "RetrievalQAStreamingResponse",
+    "ConversationalRetrievalStreamingResponse",
+]
22 changes: 22 additions & 0 deletions fastapi_async_langchain/responses/conversational_retrieval.py
@@ -0,0 +1,22 @@
from typing import Any, Awaitable, Callable, Dict, Union

from langchain.chains.conversational_retrieval.base import BaseConversationalRetrievalChain
from starlette.types import Send

from ..callbacks import AsyncRetrievalQAStreamingCallback
from .base import BaseLangchainStreamingResponse


class ConversationalRetrievalStreamingResponse(BaseLangchainStreamingResponse):
    """BaseLangchainStreamingResponse wrapper for ConversationalRetrievalChain instances."""

    @staticmethod
    def _create_chain_executor(
        chain: BaseConversationalRetrievalChain, inputs: Union[Dict[str, Any], Any]
    ) -> Callable[[Send], Awaitable[Any]]:
        async def wrapper(send: Send):
            return await chain.acall(
                inputs=inputs, callbacks=[AsyncRetrievalQAStreamingCallback(send=send)]
            )

        return wrapper
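All streaming response classes share this pattern: a subclass of `BaseLangchainStreamingResponse` only has to supply `_create_chain_executor`, which wraps the chain's `acall` with a `Send`-aware callback. A sketch of how support for another chain type could be added alongside the modules above (the class name and callback import here are hypothetical, not part of the library):

```python
from typing import Any, Awaitable, Callable, Dict, Union

from langchain.chains.base import Chain
from starlette.types import Send

from ..callbacks import AsyncMyChainStreamingCallback  # hypothetical callback
from .base import BaseLangchainStreamingResponse


class MyChainStreamingResponse(BaseLangchainStreamingResponse):
    """Hypothetical wrapper illustrating the extension pattern."""

    @staticmethod
    def _create_chain_executor(
        chain: Chain, inputs: Union[Dict[str, Any], Any]
    ) -> Callable[[Send], Awaitable[Any]]:
        async def wrapper(send: Send):
            # The callback receives each generated token and forwards it
            # to the client through the ASGI `send` callable.
            return await chain.acall(
                inputs=inputs, callbacks=[AsyncMyChainStreamingCallback(send=send)]
            )

        return wrapper
```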
2 changes: 2 additions & 0 deletions fastapi_async_langchain/websockets/__init__.py
@@ -1,9 +1,11 @@
 from .base import BaseLangchainWebsocketConnection
+from .conversational_retrieval import ConversationalRetrievalWebsocketConnection
 from .llm import LLMChainWebsocketConnection
 from .retrieval_qa import RetrievalQAWebsocketConnection

 __all__ = [
     "BaseLangchainWebsocketConnection",
     "LLMChainWebsocketConnection",
     "RetrievalQAWebsocketConnection",
+    "ConversationalRetrievalWebsocketConnection",
 ]
4 changes: 3 additions & 1 deletion pyproject.toml
@@ -9,7 +9,9 @@ packages = [{include = "fastapi_async_langchain"}]
 [tool.poetry.dependencies]
 python = "^3.9"
 fastapi = "^0.95.1"
-langchain = "^0.0.157"
+langchain = "^0.0.164"
+urllib3 = "<=1.26.15" # added due to poetry errors
+python-dotenv = "^1.0.0"

 [build-system]
 requires = ["poetry-core"]