feat(gpt4): image generation #2569

Merged 1 commit on May 9, 2024
95 changes: 89 additions & 6 deletions backend/modules/brain/integrations/GPT4/Brain.py
@@ -1,8 +1,15 @@
import json
import operator
from typing import Annotated, AsyncIterable, List, Sequence, TypedDict
from typing import Annotated, AsyncIterable, List, Optional, Sequence, Type, TypedDict
from uuid import UUID

from langchain.callbacks.manager import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from langchain.pydantic_v1 import BaseModel as BaseModelV1
from langchain.pydantic_v1 import Field as FieldV1
from langchain.tools import BaseTool
from langchain_community.tools import DuckDuckGoSearchResults
from langchain_core.messages import BaseMessage, ToolMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
@@ -15,6 +22,8 @@
from modules.chat.dto.chats import ChatQuestion
from modules.chat.dto.outputs import GetChatHistoryOutput
from modules.chat.service.chat_service import ChatService
from openai import AsyncOpenAI, OpenAI
from pydantic import BaseModel


class AgentState(TypedDict):
@@ -28,6 +37,56 @@ class AgentState(TypedDict):
chat_service = ChatService()


class ImageGenerationInput(BaseModelV1):
query: str = FieldV1(
...,
title="description",
description="A detailled prompt to generate the image from. Takes into account the history of the chat.",
)


class ImageGeneratorTool(BaseTool):
name = "image-generator"
description = "useful for when you need to answer questions about current events"
args_schema: Type[BaseModel] = ImageGenerationInput
return_direct = True

def _run(
self, query: str, run_manager: Optional[CallbackManagerForToolRun] = None
) -> str:
client = OpenAI()

response = client.images.generate(
model="dall-e-3",
prompt=query,
size="1024x1024",
quality="standard",
n=1,
)
logger.info(response.data[0])
image_url = response.data[0].url
revised_prompt = response.data[0].revised_prompt
# Return the revised prompt plus the image as a markdown embed
return f"{revised_prompt} \n ![Generated Image]({image_url})"

async def _arun(
self, query: str, run_manager: Optional[AsyncCallbackManagerForToolRun] = None
) -> str:
"""Use the tool asynchronously."""
# Async variant of _run, using the async OpenAI client.
client = AsyncOpenAI()
response = await client.images.generate(
model="dall-e-3",
prompt=query,
size="1024x1024",
quality="standard",
n=1,
)
image_url = response.data[0].url
# Make the URL a markdown image
return f"![Generated Image]({image_url})"


class GPT4Brain(KnowledgeBrainQA):
"""This is the Notion brain class. it is a KnowledgeBrainQA has the data is stored locally.
It is going to call the Data Store internally to get the data.
@@ -36,7 +95,7 @@ class GPT4Brain(KnowledgeBrainQA):
KnowledgeBrainQA (_type_): A brain that stores the knowledge internally
"""

tools: List[BaseTool] = [DuckDuckGoSearchResults()]
tools: List[BaseTool] = [DuckDuckGoSearchResults(), ImageGeneratorTool()]
tool_executor: ToolExecutor = ToolExecutor(tools)
model_function: ChatOpenAI = None

@@ -54,10 +113,16 @@ def calculate_pricing(self):
def should_continue(self, state):
messages = state["messages"]
last_message = messages[-1]
# Route return-direct tools (the image generator) straight to the `final` node

if last_message.tool_calls:
name = last_message.tool_calls[0]["name"]
if name == "image-generator":
return "final"
# If there is no function call, then we finish
if not last_message.tool_calls:
return "end"
# Otherwise if there is, we continue
# Otherwise if there is, we check whether it is supposed to return directly
else:
return "continue"

@@ -76,6 +141,9 @@ def call_tool(self, state):
last_message = messages[-1]
# We construct a ToolInvocation from the tool call
tool_call = last_message.tool_calls[0]
tool_name = tool_call["name"]
arguments = tool_call["args"]

action = ToolInvocation(
tool=tool_call["name"],
tool_input=tool_call["args"],
@@ -96,6 +164,7 @@ def create_graph(self):
# Define the two nodes we will cycle between
workflow.add_node("agent", self.call_model)
workflow.add_node("action", self.call_tool)
workflow.add_node("final", self.call_tool)

# Set the entrypoint as `agent`
# This means that this node is the first one called
@@ -117,6 +186,8 @@
{
# If `tools`, then we call the tool node.
"continue": "action",
# If `final`, run the tool once and return its output directly.
"final": "final",
# Otherwise we finish.
"end": END,
},
@@ -125,6 +196,7 @@
# We now add a normal edge from `tools` to `agent`.
# This means that after `tools` is called, `agent` node is called next.
workflow.add_edge("action", "agent")
workflow.add_edge("final", END)

# Finally, we compile it!
# This compiles it into a LangChain Runnable,
@@ -196,18 +268,29 @@ async def generate_stream(
print(f"Done tool: {event['name']}")
print(f"Tool output was: {event['data'].get('output')}")
print("--")
elif kind == "on_chain_end":
output = event["data"]["output"]
final_output = [item for item in output if "final" in item]
if final_output:
if (
final_output[0]["final"]["messages"][0].name
== "image-generator"
):
final_message = final_output[0]["final"]["messages"][0].content
response_tokens.append(final_message)
streamed_chat_history.assistant = final_message
yield f"data: {json.dumps(streamed_chat_history.dict())}"

self.save_answer(question, response_tokens, streamed_chat_history, save_answer)

def generate_answer(
self, chat_id: UUID, question: ChatQuestion, save_answer: bool = True
) -> GetChatHistoryOutput:
conversational_qa_chain = self.get_chain()
transformed_history, streamed_chat_history = (
self.initialize_streamed_chat_history(chat_id, question)
transformed_history, _ = self.initialize_streamed_chat_history(
chat_id, question
)
filtered_history = self.filter_history(transformed_history, 20, 2000)
response_tokens = []
config = {"metadata": {"conversation_id": str(chat_id)}}

prompt = ChatPromptTemplate.from_messages(
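The graph change is the core of this PR: a second tool node, `final`, wired straight to END so that return-direct tools (here the image generator) short-circuit the agent loop. Below is a minimal standalone sketch of that pattern, assuming langgraph's StateGraph API as used in the diff; `RETURN_DIRECT_TOOLS` and `build_graph` are illustrative names, not code from this PR.

```python
import operator
from typing import Annotated, Sequence, TypedDict

from langchain_core.messages import BaseMessage
from langgraph.graph import END, StateGraph


class AgentState(TypedDict):
    # Messages accumulate across steps, as in the PR's AgentState.
    messages: Annotated[Sequence[BaseMessage], operator.add]


RETURN_DIRECT_TOOLS = {"image-generator"}  # illustrative, not a name in the PR


def should_continue(state: AgentState) -> str:
    last_message = state["messages"][-1]
    if not last_message.tool_calls:
        return "end"  # no tool call: the agent's answer is final
    if last_message.tool_calls[0]["name"] in RETURN_DIRECT_TOOLS:
        return "final"  # return the tool's output verbatim
    return "continue"  # run the tool, then hand control back to the agent


def build_graph(call_model, call_tool):
    workflow = StateGraph(AgentState)
    workflow.add_node("agent", call_model)
    workflow.add_node("action", call_tool)
    workflow.add_node("final", call_tool)  # same executor, different exit edge
    workflow.set_entry_point("agent")
    workflow.add_conditional_edges(
        "agent",
        should_continue,
        {"continue": "action", "final": "final", "end": END},
    )
    workflow.add_edge("action", "agent")  # normal tools loop back to the agent
    workflow.add_edge("final", END)  # return-direct tools end the run
    return workflow.compile()
```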
4 changes: 2 additions & 2 deletions backend/modules/brain/knowledge_brain_qa.py
@@ -283,8 +283,8 @@ def generate_answer(
self, chat_id: UUID, question: ChatQuestion, save_answer: bool = True
) -> GetChatHistoryOutput:
conversational_qa_chain = self.knowledge_qa.get_chain()
transformed_history, streamed_chat_history = (
self.initialize_streamed_chat_history(chat_id, question)
transformed_history, _ = self.initialize_streamed_chat_history(
chat_id, question
)
metadata = self.metadata or {}
citations = None
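To try the new tool in isolation, here is a hedged sketch (not part of this diff); it assumes `OPENAI_API_KEY` is set in the environment and that the import path mirrors this repository's layout.

```python
# Hedged sketch: invoking ImageGeneratorTool outside the agent graph.
from modules.brain.integrations.GPT4.Brain import ImageGeneratorTool

tool = ImageGeneratorTool()
# `query` is the single field declared on ImageGenerationInput.
result = tool.run({"query": "A detailed watercolor of a lighthouse at dawn"})
# Expected shape: "<revised prompt> \n ![Generated Image](https://...)"
print(result)
```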