feat: improved Chuanhu Assistant behaviour when handling file uploads
GaiZhenbiao committed Apr 10, 2024
1 parent c759290 commit 0c4dc56
Showing 2 changed files with 79 additions and 71 deletions.
124 changes: 62 additions & 62 deletions modules/models/ChuanhuAgent.py
@@ -16,7 +16,7 @@
 from langchain.docstore.document import Document
 from langchain.text_splitter import TokenTextSplitter
 from langchain.tools import StructuredTool, Tool
-from langchain_community.callbacks import get_openai_callback
+from langchain.vectorstores.base import VectorStoreRetriever
 from langchain_community.embeddings import OpenAIEmbeddings
 from langchain_community.vectorstores import FAISS
 from langchain_core.messages.ai import AIMessage
@@ -25,9 +25,9 @@
 from langchain_openai import ChatOpenAI
 from pydantic.v1 import BaseModel, Field

-from ..config import default_chuanhu_assistant_model
 from ..index_func import construct_index
 from ..presets import SUMMARIZE_PROMPT, i18n
+from ..utils import add_source_numbers
 from .base_model import (BaseLLMModel, CallbackToIterator,
                          ChuanhuCallbackHandler)

@@ -40,24 +40,19 @@ class WebBrowsingInput(BaseModel):
     url: str = Field(description="URL of a webpage")


+class KnowledgeBaseQueryInput(BaseModel):
+    question: str = Field(
+        description="The question you want to ask the knowledge base."
+    )
+
+
 class WebAskingInput(BaseModel):
     url: str = Field(description="URL of a webpage")
     question: str = Field(
         description="Question that you want to know the answer to, based on the webpage's content."
     )
-
-
-agent_prompt = ChatPromptTemplate.from_messages(
-    [
-        ("system", "You are a helpful assistant"),
-        ("placeholder", "{chat_history}"),
-        ("human", "{input}"),
-        ("placeholder", "{agent_scratchpad}"),
-    ]
-)
-agent_prompt.input_variables = ['agent_scratchpad', 'input']


 class ChuanhuAgent_Client(BaseLLMModel):
     def __init__(self, model_name, openai_api_key, user_name="") -> None:
         super().__init__(model_name=model_name, user=user_name)
@@ -85,12 +80,14 @@ def __init__(self, model_name, openai_api_key, user_name="") -> None:
                 openai_api_key=openai_api_key,
                 model_name="gpt-4-turbo-preview",
                 openai_api_base=os.environ.get("OPENAI_API_BASE", None),
+                streaming=True,
             )
         else:
             self.llm = ChatOpenAI(
                 openai_api_key=openai_api_key,
                 model_name="gpt-3.5-turbo",
                 openai_api_base=os.environ.get("OPENAI_API_BASE", None),
+                streaming=True,
             )
         tools_to_enable = ["llm-math", "arxiv", "wikipedia"]
         # if exists GOOGLE_CSE_ID and GOOGLE_API_KEY, enable google-search-results-json
@@ -115,9 +112,7 @@ def __init__(self, model_name, openai_api_key, user_name="") -> None:
         if os.environ.get("WOLFRAM_ALPHA_APPID", None) is not None:
             tools_to_enable.append("wolfram-alpha")
         else:
-            logging.warning(
-                "WOLFRAM_ALPHA_APPID not found, wolfram-alpha is disabled."
-            )
+            logging.warning("WOLFRAM_ALPHA_APPID not found, wolfram-alpha is disabled.")
         # if exists SERPAPI_API_KEY, enable serpapi
         if os.environ.get("SERPAPI_API_KEY", None) is not None:
             tools_to_enable.append("serpapi")
@@ -161,53 +156,31 @@ def handle_file_upload(self, files, chatbot, language):
             assert index is not None, "获取索引失败"
             self.index = index
             status = i18n("索引构建完成")
-            # Summarize the document
-            logging.info(i18n("生成内容总结中……"))
-            with get_openai_callback() as cb:
-                os.environ["OPENAI_API_KEY"] = self.api_key
-                from langchain.chains.summarize import load_summarize_chain
-                from langchain.chat_models import ChatOpenAI
-                from langchain.prompts import PromptTemplate
-
-                prompt_template = (
-                    "Write a concise summary of the following:\n\n{text}\n\nCONCISE SUMMARY IN "
-                    + language
-                    + ":"
-                )
-                PROMPT = PromptTemplate(
-                    template=prompt_template, input_variables=["text"]
-                )
-                llm = ChatOpenAI()
-                chain = load_summarize_chain(
-                    llm,
-                    chain_type="map_reduce",
-                    return_intermediate_steps=True,
-                    map_prompt=PROMPT,
-                    combine_prompt=PROMPT,
-                )
-                summary = chain(
-                    {
-                        "input_documents": list(
-                            index.docstore.__dict__["_dict"].values()
-                        )
-                    },
-                    return_only_outputs=True,
-                )["output_text"]
-                logging.info(f"Summary: {summary}")
-                self.index_summary = summary
-                chatbot.append((f"Uploaded {len(files)} files", summary))
-            logging.info(cb)
+            self.index_summary = ", ".join(
+                [os.path.basename(file.name) for file in files]
+            )
         return gr.update(), chatbot, status

+    def prepare_inputs(
+        self, real_inputs, use_websearch, files, reply_language, chatbot
+    ):
+        fake_inputs = real_inputs
+        display_append = ""
+        limited_context = False
+        return limited_context, fake_inputs, display_append, real_inputs, chatbot
+
     def query_index(self, query):
-        if self.index is not None:
-            retriever = self.index.as_retriever()
-            qa = RetrievalQA.from_chain_type(
-                llm=self.llm, chain_type="stuff", retriever=retriever
-            )
-            return qa.run(query)
-        else:
-            "Error during query."
+        retriever = VectorStoreRetriever(
+            vectorstore=self.index, search_type="similarity", search_kwargs={"k": 6}
+        )
+        relevant_documents = retriever.get_relevant_documents(query)
+        reference_results = [
+            [d.page_content.strip("�"), os.path.basename(d.metadata["source"])]
+            for d in relevant_documents
+        ]
+        reference_results = add_source_numbers(reference_results)
+        reference_results = "\n".join(reference_results)
+        return reference_results

     def summary(self, text):
         texts = Document(page_content=text)
@@ -271,15 +244,42 @@ def get_answer_stream_iter(self):
         it = CallbackToIterator()
         manager = BaseCallbackManager(handlers=[ChuanhuCallbackHandler(it.callback)])

+        if "Pro" in self.model_name:
+            self.llm = ChatOpenAI(
+                openai_api_key=self.api_key,
+                model_name="gpt-4-turbo-preview",
+                openai_api_base=os.environ.get("OPENAI_API_BASE", None),
+                temperature=self.temperature,
+                streaming=True,
+            )
+        else:
+            self.llm = ChatOpenAI(
+                openai_api_key=self.api_key,
+                model_name="gpt-3.5-turbo",
+                openai_api_base=os.environ.get("OPENAI_API_BASE", None),
+                temperature=self.temperature,
+                streaming=True,
+            )
+
+        agent_prompt = ChatPromptTemplate.from_messages(
+            [
+                ("system", self.system_prompt),
+                ("placeholder", "{chat_history}"),
+                ("human", "{input}"),
+                ("placeholder", "{agent_scratchpad}"),
+            ]
+        )
+        agent_prompt.input_variables = ["agent_scratchpad", "input"]
+
         def thread_func():
             tools = self.tools
             if self.index is not None:
                 tools.append(
                     Tool.from_function(
                         func=self.query_index,
-                        name="Query Knowledge Base",
+                        name="query_knowledge_base",
                         description=f"useful when you need to know about: {self.index_summary}",
-                        args_schema=WebBrowsingInput,
+                        args_schema=KnowledgeBaseQueryInput,
                     )
                 )
             agent = create_openai_tools_agent(self.llm, tools, agent_prompt)
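For reference, the retrieval pattern that the new query_index relies on — a FAISS index wrapped in a VectorStoreRetriever doing plain similarity search with k=6 — can be exercised standalone. A minimal sketch, assuming OPENAI_API_KEY is set and faiss-cpu is installed (the file names and texts below are hypothetical, not part of the commit):

    # Minimal sketch of the retrieval flow used by the new query_index.
    import os

    from langchain.vectorstores.base import VectorStoreRetriever
    from langchain_community.embeddings import OpenAIEmbeddings
    from langchain_community.vectorstores import FAISS

    # Hypothetical documents standing in for uploaded files.
    texts = ["LangChain ships a FAISS wrapper.", "FAISS does similarity search."]
    metadatas = [{"source": "notes/a.txt"}, {"source": "notes/b.txt"}]

    index = FAISS.from_texts(texts, OpenAIEmbeddings(), metadatas=metadatas)

    # Same retriever configuration as query_index: plain similarity
    # search returning the 6 nearest chunks.
    retriever = VectorStoreRetriever(
        vectorstore=index, search_type="similarity", search_kwargs={"k": 6}
    )

    for doc in retriever.get_relevant_documents("What does FAISS do?"):
        # query_index pairs each chunk with its file name before numbering.
        print(os.path.basename(doc.metadata["source"]), "->", doc.page_content)

In the committed code, add_source_numbers (from modules/utils.py) then tags each [content, filename] pair with a citation number, and the joined text is what the query_knowledge_base tool hands back to the agent.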
26 changes: 17 additions & 9 deletions modules/models/base_model.py
@@ -14,6 +14,9 @@
 from itertools import islice
 from threading import Condition, Thread
 from typing import Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, TypeVar, Union
+from uuid import UUID
+from langchain_core.outputs import ChatGenerationChunk, GenerationChunk

 import colorama
 import PIL
@@ -109,18 +112,23 @@ def on_agent_finish(
         # self.callback(f"{finish.log}\n\n")
         logging.info(finish.log)

-    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
-        """Run on new LLM token. Only available when streaming is enabled."""
-        self.callback(token)
-
-    def on_chat_model_start(
+    def on_llm_new_token(
         self,
-        serialized: Dict[str, Any],
-        messages: List[List[BaseMessage]],
+        token: str,
+        *,
+        chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
+        run_id: UUID,
+        parent_run_id: Optional[UUID] = None,
         **kwargs: Any,
     ) -> Any:
-        """Run when a chat model starts running."""
-        pass
+        """Run on new LLM token. Only available when streaming is enabled.
+
+        Args:
+            token (str): The new token.
+            chunk (GenerationChunk | ChatGenerationChunk): The new generated chunk,
+                containing content and other information.
+        """
+        logging.info(f"### CHUNK ###: {chunk}")


 class ModelType(Enum):
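For reference, the updated on_llm_new_token matches the keyword-only signature that recent langchain-core versions invoke on streaming callback handlers. A minimal sketch of a handler built against that signature (the PrintingHandler class and the commented usage are illustrative only, not part of the commit):

    # Minimal sketch of a callback handler using the newer
    # on_llm_new_token signature; assumes langchain-core is installed.
    from typing import Any, Optional, Union
    from uuid import UUID

    from langchain_core.callbacks import BaseCallbackHandler
    from langchain_core.outputs import ChatGenerationChunk, GenerationChunk


    class PrintingHandler(BaseCallbackHandler):
        def on_llm_new_token(
            self,
            token: str,
            *,
            chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
            run_id: UUID,
            parent_run_id: Optional[UUID] = None,
            **kwargs: Any,
        ) -> Any:
            # Each streamed token arrives here; chunk carries the same
            # content plus generation metadata when available.
            print(token, end="", flush=True)


    # Hypothetical usage with a streaming chat model:
    # llm = ChatOpenAI(streaming=True, callbacks=[PrintingHandler()])
    # llm.invoke("Hello")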
