Feature/streaming #251

Merged
merged 8 commits on May 31, 2023
Changes from 6 commits
113 changes: 79 additions & 34 deletions application/app.py
@@ -5,11 +5,12 @@
import os
import traceback

import openai
import dotenv
import requests
from celery import Celery
from celery.result import AsyncResult
from flask import Flask, request, render_template, send_from_directory, jsonify
from flask import Flask, request, render_template, send_from_directory, jsonify, Response
from langchain import FAISS
from langchain import VectorDBQA, HuggingFaceHub, Cohere, OpenAI
from langchain.chains import LLMChain, ConversationalRetrievalChain
@@ -107,6 +108,31 @@ def run_async_chain(chain, question, chat_history):
result["answer"] = answer
return result

def get_vectorstore(data):
    if "active_docs" in data:
        if data["active_docs"].split("/")[0] == "local":
            if data["active_docs"].split("/")[1] == "default":
                vectorstore = ""
            else:
                vectorstore = "indexes/" + data["active_docs"]
        else:
            vectorstore = "vectors/" + data["active_docs"]
        if data["active_docs"] == "default":
            vectorstore = ""
    else:
        vectorstore = ""
    return vectorstore
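
A quick sketch of the path mapping get_vectorstore implements — the active_docs values below are hypothetical examples, with the resulting paths as comments:

# Illustrative calls; the document names are made-up examples.
get_vectorstore({"active_docs": "local/default"})      # -> ""  (built-in default index)
get_vectorstore({"active_docs": "local/my-docs/"})     # -> "indexes/local/my-docs/"
get_vectorstore({"active_docs": "en/flask/2.0/gpt/"})  # -> "vectors/en/flask/2.0/gpt/"
get_vectorstore({"active_docs": "default"})            # -> ""  (explicit default override)
get_vectorstore({})                                    # -> ""  (no docs selected)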

def get_docsearch(vectorstore, embeddings_key):
    if settings.EMBEDDINGS_NAME == "openai_text-embedding-ada-002":
        docsearch = FAISS.load_local(vectorstore, OpenAIEmbeddings(openai_api_key=embeddings_key))
    elif settings.EMBEDDINGS_NAME == "huggingface_sentence-transformers/all-mpnet-base-v2":
        docsearch = FAISS.load_local(vectorstore, HuggingFaceHubEmbeddings())
    elif settings.EMBEDDINGS_NAME == "huggingface_hkunlp/instructor-large":
        docsearch = FAISS.load_local(vectorstore, HuggingFaceInstructEmbeddings())
    elif settings.EMBEDDINGS_NAME == "cohere_medium":
        docsearch = FAISS.load_local(vectorstore, CohereEmbeddings(cohere_api_key=embeddings_key))
    return docsearch
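
Together, the two helpers resolve a request to a loaded FAISS index. A minimal sketch of how a caller wires them up, assuming a hypothetical document name and key:

# Hypothetical wiring of the helpers above; the doc name and key are examples.
vectorstore = get_vectorstore({"active_docs": "local/my-docs/"})  # -> "indexes/local/my-docs/"
docsearch = get_docsearch(vectorstore, embeddings_key="sk-...")   # FAISS index + matching embeddings
docs = docsearch.similarity_search("How do I deploy this?", k=2)  # two most similar chunks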

@celery.task(bind=True)
def ingest(self, directory, formats, name_job, filename, user):
@@ -119,6 +145,53 @@ def home():
return render_template("index.html", api_key_set=api_key_set, llm_choice=settings.LLM_NAME,
embeddings_choice=settings.EMBEDDINGS_NAME)

def complete_stream(question, docsearch, chat_history, api_key):
    openai.api_key = api_key
    docs = docsearch.similarity_search(question, k=2)
    # join all page_content together with a newline
    docs_together = "\n".join([doc.page_content for doc in docs])

    # swap {summaries} in chat_combine_template with the summaries from the docs
    p_chat_combine = chat_combine_template.replace("{summaries}", docs_together)
    completion = openai.ChatCompletion.create(model="gpt-3.5-turbo", messages=[
        {"role": "system", "content": p_chat_combine},
        {"role": "user", "content": question},
    ], stream=True, max_tokens=1000, temperature=0)

    for line in completion:
        # check if the delta contains content
        if "content" in line["choices"][0]["delta"]:
            data = json.dumps({"answer": str(line["choices"][0]["delta"]["content"])})
            yield f"data: {data}\n\n"
    # send {"type": "end"} as JSON to indicate that the stream has ended
    data = json.dumps({"type": "end"})
    yield f"data: {data}\n\n"
@app.route("/stream", methods=['POST', 'GET'])
def stream():
# get parameter from url question
question = request.args.get('question')
history = request.args.get('history')
# check if active_docs is set

if not api_key_set:
api_key = request.args.get("api_key")
else:
api_key = settings.API_KEY
if not embeddings_key_set:
embeddings_key = request.args.get("embeddings_key")
else:
embeddings_key = settings.EMBEDDINGS_KEY
if "active_docs" in request.args:
vectorstore = get_vectorstore({"active_docs": request.args.get("active_docs")})
else:
vectorstore = ""
docsearch = get_docsearch(vectorstore, embeddings_key)


#question = "Hi"
return Response(complete_stream(question, docsearch, chat_history= history, api_key=api_key), mimetype='text/event-stream')
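
For reference, a minimal Python client sketch (not part of this PR) that consumes /stream the same way the frontend's EventSource does; the host, keys, and question are assumptions about a local deployment:

import json
import requests

# Stream an answer from the endpoint above; host and keys are illustrative.
response = requests.get(
    "http://localhost:5001/stream",
    params={
        "question": "What does this project do?",
        "history": "",
        "active_docs": "default",
        "api_key": "sk-...",         # only used when the server has no key configured
        "embeddings_key": "sk-...",  # likewise for the embeddings key
    },
    stream=True,
)
for raw_line in response.iter_lines():
    if not raw_line.startswith(b"data: "):
        continue  # skip blank separator lines between SSE events
    event = json.loads(raw_line[len(b"data: "):])
    if event.get("type") == "end":
        break  # the generator signals completion with {"type": "end"}
    print(event["answer"], end="", flush=True)  # tokens arrive incrementally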


@app.route("/api/answer", methods=["POST"])
def api_answer():
@@ -138,41 +211,13 @@ def api_answer():
    # use try and except to check for exception
    try:
        # check if the vectorstore is set
        if "active_docs" in data:
            if data["active_docs"].split("/")[0] == "local":
                if data["active_docs"].split("/")[1] == "default":
                    vectorstore = ""
                else:
                    vectorstore = "indexes/" + data["active_docs"]
            else:
                vectorstore = "vectors/" + data["active_docs"]
            if data["active_docs"] == "default":
                vectorstore = ""
        else:
            vectorstore = ""
        print(vectorstore)
        # vectorstore = "outputs/inputs/"
        vectorstore = get_vectorstore(data)
        # loading the index and the store and the prompt template
        # Note if you have used other embeddings than OpenAI, you need to change the embeddings
        if settings.EMBEDDINGS_NAME == "openai_text-embedding-ada-002":
            docsearch = FAISS.load_local(vectorstore, OpenAIEmbeddings(openai_api_key=embeddings_key))
        elif settings.EMBEDDINGS_NAME == "huggingface_sentence-transformers/all-mpnet-base-v2":
            docsearch = FAISS.load_local(vectorstore, HuggingFaceHubEmbeddings())
        elif settings.EMBEDDINGS_NAME == "huggingface_hkunlp/instructor-large":
            docsearch = FAISS.load_local(vectorstore, HuggingFaceInstructEmbeddings())
        elif settings.EMBEDDINGS_NAME == "cohere_medium":
            docsearch = FAISS.load_local(vectorstore, CohereEmbeddings(cohere_api_key=embeddings_key))

        # create a prompt template
        if history:
            history = json.loads(history)
            template_temp = template_hist.replace("{historyquestion}", history[0]).replace("{historyanswer}",
                                                                                           history[1])
            c_prompt = PromptTemplate(input_variables=["summaries", "question"], template=template_temp,
                                      template_format="jinja2")
        else:
            c_prompt = PromptTemplate(input_variables=["summaries", "question"], template=template,
                                      template_format="jinja2")
        docsearch = get_docsearch(vectorstore, embeddings_key)

        c_prompt = PromptTemplate(input_variables=["summaries", "question"], template=template,
                                  template_format="jinja2")

        q_prompt = PromptTemplate(input_variables=["context", "question"], template=template_quest,
                                  template_format="jinja2")
1 change: 1 addition & 0 deletions docker-compose.yaml
@@ -5,6 +5,7 @@ services:
    build: ./frontend
    environment:
      - VITE_API_HOST=http://localhost:5001
      - VITE_API_STREAMING=$VITE_API_STREAMING
    ports:
      - "5173:5173"
    depends_on:
50 changes: 48 additions & 2 deletions frontend/src/conversation/conversationApi.ts
@@ -7,6 +7,7 @@ export function fetchAnswerApi(
  question: string,
  apiKey: string,
  selectedDocs: Doc,
  history: Array<any> = [],
): Promise<Answer> {
  let namePath = selectedDocs.name;
  if (selectedDocs.language === namePath) {
@@ -37,15 +38,15 @@
      question: question,
      api_key: apiKey,
      embeddings_key: apiKey,
      history: localStorage.getItem('chatHistory'),
      history: history,
      active_docs: docPath,
    }),
  })
    .then((response) => {
      if (response.ok) {
        return response.json();
      } else {
        Promise.reject(response);
        return Promise.reject(new Error(response.statusText));
      }
    })
    .then((data) => {
@@ -54,6 +55,51 @@
    });
}

export function fetchAnswerSteaming(
  question: string,
  apiKey: string,
  selectedDocs: Doc,
  onEvent: (event: MessageEvent) => void,
): Promise<Answer> {
  let namePath = selectedDocs.name;
  if (selectedDocs.language === namePath) {
    namePath = '.project';
  }

  let docPath = 'default';
  if (selectedDocs.location === 'local') {
    docPath = 'local' + '/' + selectedDocs.name + '/';
  } else if (selectedDocs.location === 'remote') {
    docPath =
      selectedDocs.language +
      '/' +
      namePath +
      '/' +
      selectedDocs.version +
      '/' +
      selectedDocs.model +
      '/';
  }

  return new Promise<Answer>((resolve, reject) => {
    const url = new URL(apiHost + '/stream');
    url.searchParams.append('question', question);
    url.searchParams.append('api_key', apiKey);
    url.searchParams.append('embeddings_key', apiKey);
    url.searchParams.append('history', localStorage.getItem('chatHistory') ?? '');
    url.searchParams.append('active_docs', docPath);

    const eventSource = new EventSource(url.href);

    eventSource.onmessage = onEvent;

    eventSource.onerror = (error) => {
      console.log('Connection failed.');
      eventSource.close();
      reject(error);
    };
  });
}

export function sendFeedback(
  prompt: string,
  response: string,
93 changes: 72 additions & 21 deletions frontend/src/conversation/conversationSlice.ts
@@ -1,27 +1,64 @@
import { createAsyncThunk, createSlice, PayloadAction } from '@reduxjs/toolkit';
import store from '../store';
import { fetchAnswerApi } from './conversationApi';
import { Answer, ConversationState, Query } from './conversationModels';
import { fetchAnswerApi, fetchAnswerSteaming } from './conversationApi';
import { Answer, ConversationState, Query, Status } from './conversationModels';

const initialState: ConversationState = {
  queries: [],
  status: 'idle',
};

export const fetchAnswer = createAsyncThunk<
  Answer,
  { question: string },
  { state: RootState }
>('fetchAnswer', async ({ question }, { getState }) => {
  const state = getState();
const API_STREAMING = import.meta.env.VITE_API_STREAMING === 'true';

  const answer = await fetchAnswerApi(
    question,
    state.preference.apiKey,
    state.preference.selectedDocs!,
  );
  return answer;
});
export const fetchAnswer = createAsyncThunk<Answer, { question: string }>(
  'fetchAnswer',
  async ({ question }, { dispatch, getState }) => {
    const state = getState() as RootState;
    if (state.preference) {
      if (API_STREAMING) {
        await fetchAnswerSteaming(
          question,
          state.preference.apiKey,
          state.preference.selectedDocs!,
          (event) => {
            const data = JSON.parse(event.data);

            // check if the 'end' event has been received
            if (data.type === 'end') {
              // set status to 'idle'
              dispatch(conversationSlice.actions.setStatus('idle'));
            } else {
              const result = data.answer;
              dispatch(
                updateStreamingQuery({
                  index: state.conversation.queries.length - 1,
                  query: { response: result },
                }),
              );
            }
          },
        );
      } else {
        const answer = await fetchAnswerApi(
          question,
          state.preference.apiKey,
          state.preference.selectedDocs!,
          state.conversation.queries,
        );
        if (answer) {
          dispatch(
            updateQuery({
              index: state.conversation.queries.length - 1,
              query: { response: answer.answer },
            }),
          );
          dispatch(conversationSlice.actions.setStatus('idle'));
        }
      }
    }
    return { answer: '', query: question, result: '' };
  },
);

export const conversationSlice = createSlice({
  name: 'conversation',
@@ -30,6 +67,21 @@ export const conversationSlice = createSlice({
    addQuery(state, action: PayloadAction<Query>) {
      state.queries.push(action.payload);
    },
    updateStreamingQuery(
      state,
      action: PayloadAction<{ index: number; query: Partial<Query> }>,
    ) {
      const index = action.payload.index;
      if (action.payload.query.response) {
        state.queries[index].response =
          (state.queries[index].response || '') + action.payload.query.response;
      } else {
        state.queries[index] = {
          ...state.queries[index],
          ...action.payload.query,
        };
      }
    },
    updateQuery(
      state,
      action: PayloadAction<{ index: number; query: Partial<Query> }>,
@@ -40,17 +92,15 @@
        ...action.payload.query,
      };
    },
    setStatus(state, action: PayloadAction<Status>) {
      state.status = action.payload;
    },
  },
  extraReducers(builder) {
    builder
      .addCase(fetchAnswer.pending, (state) => {
        state.status = 'loading';
      })
      .addCase(fetchAnswer.fulfilled, (state, action) => {
        state.status = 'idle';
        state.queries[state.queries.length - 1].response =
          action.payload.answer;
      })
      .addCase(fetchAnswer.rejected, (state, action) => {
        state.status = 'failed';
        state.queries[state.queries.length - 1].error =
@@ -65,5 +115,6 @@ export const selectQueries = (state: RootState) => state.conversation.queries;

export const selectStatus = (state: RootState) => state.conversation.status;

export const { addQuery, updateQuery } = conversationSlice.actions;
export const { addQuery, updateQuery, updateStreamingQuery } =
  conversationSlice.actions;
export default conversationSlice.reducer;