Hi,
I have built a RAG app and I am loading an LLM with LlamaCpp. However, I have problems getting streaming to work for FastAPI or LangServe requests. Streaming works in my terminal, but I don't know what I have to change to make it work in FastAPI/LangServe.
Here is my LangServe code:
from langchain_community.vectorstores.pgvector import PGVector
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
from langchain_core.runnables import RunnableParallel
import os
from langchain_community.embeddings import HuggingFaceBgeEmbeddings, HuggingFaceEmbeddings
import box
import yaml
from langchain_community.llms import LlamaCpp
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from operator import itemgetter
from typing import TypedDict
from fastapi import FastAPI
from fastapi.responses import RedirectResponse
from langserve import add_routes
from fastapi.middleware.cors import CORSMiddleware
from starlette.staticfiles import StaticFiles
from langchain_core.output_parsers import StrOutputParser
with open('./config/config.yml', 'r', encoding='utf8') as ymlfile:
    cfg = box.Box(yaml.safe_load(ymlfile))
def build_llm(model_path, temperature=cfg.RAG_TEMPERATURE, max_tokens=cfg.MAX_TOKENS, callback=StreamingStdOutCallbackHandler()):
    callback_manager = CallbackManager([callback])
    n_gpu_layers = 1  # Metal: set to 1 is enough (tried with more)
    n_batch = 512  # or 1024; should be between 1 and n_ctx, consider the amount of RAM of your Apple Silicon chip
    llm = LlamaCpp(
        max_tokens=max_tokens,
        n_threads=8,  # for performance
        model_path=model_path,
        temperature=temperature,
        f16_kv=True,
        n_ctx=15000,  # 8k, but you have to leave room for instruction, history, etc.
        n_gpu_layers=n_gpu_layers,
        n_batch=n_batch,
        callback_manager=callback_manager,
        verbose=True,  # verbose is required to pass to the callback manager
        top_p=0.75,
        top_k=40,
        repeat_penalty=1.1,
        streaming=True,
        model_kwargs={
            # 'repetition_penalty': 1.1,
            # 'mirostat': 2,
        },
    )
    return llm
embeddings = HuggingFaceEmbeddings(model_name=cfg.EMBEDDING_MODEL_NAME,
                                   model_kwargs={'device': 'mps'})

PG_COLLECTION_NAME = "PGVECTOR_BKB"
model_path = "./modelle/sauerkrautlm-mixtral-8x7b-instruct.Q4_K_M.gguf"
CONNECTION_STRING = "MY_CONNECTION_STRING"

vector_store = PGVector(
    collection_name=PG_COLLECTION_NAME,
    connection_string=CONNECTION_STRING,
    embedding_function=embeddings
)
prompt= """
<s> [INST] Du bist RagBot, ein hilfsbereiter Assistent. Antworte nur auf Deutsch. Verwende die folgenden Kontextinformationen, um die Frage am Ende knapp zu beantworten. Wenn du die Antwort nicht kennst, sag einfach, dass du es nicht weisst. Erfinde keine Antwort! Falls der Nutzer allgemeine Fragen stellt, führe Smalltalk mit Ihm.
### Hier der Kontext: ###
{context}
### Hier die Frage: ###
{question}
Antwort: [/INST]
"""
def model_response_prompt():
    return PromptTemplate(template=prompt, input_variables=['context', 'question'])

prompt_temp = model_response_prompt()
llm = build_llm(model_path, temperature=cfg.NO_RAG_TEMPERATURE, max_tokens=cfg.NO_RAG_MAX_TOKENS)
class RagInput(TypedDict):
    question: str

final_chain = (
    RunnableParallel(
        context=(itemgetter("question") | vector_store.as_retriever()),
        question=itemgetter("question")
    )
    | RunnableParallel(
        answer=(prompt_temp | llm),
        docs=itemgetter("context")
    )
).with_types(input_type=RagInput)
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "http://localhost:3000"
    ],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# app.mount("/rag/static", StaticFiles(directory="./source_docs"), name="static")

@app.get("/")
async def redirect_root_to_docs():
    return RedirectResponse("/docs")

# Edit this to add the chain you want to add
add_routes(app, final_chain, path="/rag")

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
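For reference, a quick way to check chain-level streaming outside of FastAPI is to iterate over the chain with the Runnable streaming interface. This is only a minimal sketch against the chain defined above, and the question string is just a placeholder:

# Minimal sketch, run in the same module as the code above: stream the chain
# directly via the Runnable .stream() interface, independent of the
# StreamingStdOutCallbackHandler that writes to the terminal.
# The question string is only a placeholder.
for chunk in final_chain.stream({"question": "Was ist RagBot?"}):
    # each chunk is a partial dict, e.g. {"answer": "..."} or {"docs": [...]}
    print(chunk, flush=True)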
When I try it out in the LangServe playground (http://0.0.0.0:8000/rag/playground/), the response gets streamed in my terminal but not in the playground.
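For completeness, this is the kind of client-side check I would expect to print tokens incrementally. It is only a sketch, assuming the server above is running on localhost:8000 and using langserve's RemoteRunnable client; the question string is a placeholder:

# Minimal client-side sketch (assumption: the server above runs on localhost:8000).
# RemoteRunnable talks to the /rag route added by add_routes, and its .stream()
# consumes the /rag/stream endpoint, so chunks should arrive incrementally here
# if streaming works end to end.
from langserve import RemoteRunnable

remote_chain = RemoteRunnable("http://localhost:8000/rag")
for chunk in remote_chain.stream({"question": "Was ist RagBot?"}):
    print(chunk, flush=True)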
So how can I make this work?