feat(qa): improve code (#886)
* feat(qa): improve code

* feat: 🎸 customprompt

now in system
StanGirard committed Aug 7, 2023
1 parent fe9280b commit 7028505
Showing 3 changed files with 73 additions and 269 deletions.
73 changes: 1 addition & 72 deletions backend/core/llm/base.py
@@ -2,8 +2,6 @@
from typing import AsyncIterable, List

from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
from langchain.chains import ConversationalRetrievalChain, LLMChain
from langchain.llms.base import LLM
from logger import get_logger
from models.settings import BrainSettings # Importing settings related to the 'brain'
from pydantic import BaseModel # For data validation and settings management
@@ -73,75 +71,6 @@ class Config:
arbitrary_types_allowed = True

# The methods below define the names, arguments, and return types of the most useful functions for the child classes. They should be overridden if they are used.
@abstractmethod
def _create_llm(self, model, temperature=0, streaming=False, callbacks=None) -> LLM:
"""
Determine and construct the language model.
:param model: Language model name to be used.
:return: Language model instance
This method should take into account the following:
- Whether the model is streaming compatible
- Whether the model is private
- Whether the model should use an openai api key and use the _determine_api_key method
"""

@abstractmethod
def _create_question_chain(self, model) -> LLMChain:
"""
Determine and construct the question chain.
:param model: Language model name to be used.
:return: Question chain instance
This method should take into account the following:
- Which prompt to use (normally CONDENSE_QUESTION_PROMPT)
"""

@abstractmethod
def _create_doc_chain(self, model) -> LLMChain:
"""
Determine and construct the document chain.
:param model: Language model name to be used.
:return: Document chain instance
This method should take into account the following:
- chain_type (normally "stuff")
- Whether the model is streaming compatible and/or streaming is set (determine_streaming).
"""

@abstractmethod
def _create_qa(
self, question_chain, document_chain
) -> ConversationalRetrievalChain:
"""
Constructs a conversational retrieval chain.
:param question_chain: Chain that condenses the question and history into a standalone question
:param document_chain: Chain that combines the retrieved documents into an answer
:return: ConversationalRetrievalChain instance
"""

@abstractmethod
def _call_chain(self, chain, question, history) -> str:
"""
Call a chain with a given question and history.
:param chain: The chain, e.g. QA (ConversationalRetrievalChain)
:param question: The user prompt
:param history: The chat history from DB
:return: The answer.
"""

async def _acall_chain(self, chain, question, history) -> str:
"""
Call a chain with a given question and history.
:param chain: The chain, e.g. QA (ConversationalRetrievalChain)
:param question: The user prompt
:param history: The chat history from DB
:return: The answer.
"""
raise NotImplementedError(
"Async generation not implemented for this BrainPicking Class."
)
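An async override would use the chain's acall method; a sketch for child classes that support it:

async def _acall_chain(self, chain, question, history) -> str:
    response = await chain.acall(
        {
            "question": question,
            "chat_history": history,
        }
    )
    return response["answer"]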

@abstractmethod
def generate_answer(self, question: str) -> str:
"""
@@ -153,7 +82,7 @@ def generate_answer(self, question: str) -> str:
It should also update the chat_history in the DB.
"""


@abstractmethod
async def generate_stream(self, question: str) -> AsyncIterable:
"""
Generate a streaming answer to a given question using QA Chain.
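The AsyncIteratorCallbackHandler imported at the top of base.py supports the usual LangChain streaming pattern; a sketch of how generate_stream might use it, where qa and history stand in for the chain and DB history assembled by the methods above:

import asyncio

async def generate_stream(self, question: str) -> AsyncIterable:
    callback = AsyncIteratorCallbackHandler()
    # qa must have been built with a streaming-enabled LLM wired to `callback`.
    task = asyncio.create_task(
        qa.acall({"question": question, "chat_history": history})
    )
    async for token in callback.aiter():
        yield token
    await task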
19 changes: 1 addition & 18 deletions backend/core/llm/openai.py
@@ -1,6 +1,4 @@
from langchain.chat_models import ChatOpenAI
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms.base import BaseLLM
from llm.qa_base import QABaseBrainPicking
from logger import get_logger

@@ -46,19 +44,4 @@ def embeddings(self) -> OpenAIEmbeddings:
openai_api_key=self.openai_api_key
) # pyright: ignore reportPrivateUsage=none

def _create_llm(self, model, temperature=0, streaming=False, callbacks=None) -> BaseLLM:
"""
Determine the language model to be used.
:param model: Language model name to be used.
:param streaming: Whether to enable streaming of the model
:param callbacks: Callbacks to be used for streaming
:return: Language model instance
"""
return ChatOpenAI(
temperature=temperature,
model=model,
streaming=streaming,
verbose=True,
callbacks=callbacks,
openai_api_key=self.openai_api_key,
) # pyright: ignore reportPrivateUsage=none
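Hypothetical usage, showing how the streaming arguments fit together; `brain` is an illustrative instance of the class in this file, and the model name is an assumption:

from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler

callback = AsyncIteratorCallbackHandler()
llm = brain._create_llm(
    model="gpt-3.5-turbo",
    temperature=0,
    streaming=True,
    callbacks=[callback],
)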
