Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 16 additions & 24 deletions backend/utils/retrieval/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,9 @@ def determine_conversation(state: GraphState):
question = extract_question_from_conversation(state.get("messages", []))
print("determine_conversation parsed question:", question)

# stream
if state.get('streaming', False):
state['callback'].put_thought_nowait(question)
# # stream
# if state.get('streaming', False):
# state['callback'].put_thought_nowait(question)

return {"parsed_question": question}

Expand Down Expand Up @@ -145,21 +145,13 @@ def no_context_conversation(state: GraphState):
def omi_question(state: GraphState):
print("no_context_omi_question node")

# streaming
streaming = state.get("streaming")
if streaming:
state['callback'].put_thought_nowait("Searching through Omi's documents")

context: dict = get_github_docs_content()
context_str = 'Documentation:\n\n'.join([f'{k}:\n {v}' for k, v in context.items()])

# streaming
streaming = state.get("streaming")
if streaming:
state['callback'].put_thought_nowait(f"Found {len(context.items())} relevant documents")

# streaming
if streaming:
state['callback'].put_thought_nowait("Reasoning")
# state['callback'].put_thought_nowait("Reasoning")
answer: str = answer_omi_question_stream(
state.get("messages", []), context_str,
callbacks=[state.get('callback')]
Expand Down Expand Up @@ -214,9 +206,9 @@ def retrieve_date_filters(state: GraphState):
def query_vectors(state: GraphState):
print("query_vectors")

# stream
if state.get('streaming', False):
state['callback'].put_thought_nowait("Searching through your memories")
# # stream
# if state.get('streaming', False):
# state['callback'].put_thought_nowait("Searching through your memories")

date_filters = state.get("date_filters")
uid = state.get("uid")
Expand All @@ -241,13 +233,13 @@ def query_vectors(state: GraphState):
)
memories = memories_db.get_memories_by_id(uid, memories_id)

# stream
if state.get('streaming', False):
if len(memories) == 0:
msg = "No relevant memories found"
else:
msg = f"Found {len(memories)} relevant memories"
state['callback'].put_thought_nowait(msg)
## stream
#if state.get('streaming', False):
# if len(memories) == 0:
# msg = "No relevant memories found"
# else:
# msg = f"Found {len(memories)} relevant memories"
# state['callback'].put_thought_nowait(msg)

# print(memories_id)
return {"memories_found": memories}
Expand All @@ -259,7 +251,7 @@ def qa_handler(state: GraphState):
# streaming
streaming = state.get("streaming")
if streaming:
state['callback'].put_thought_nowait("Reasoning")
# state['callback'].put_thought_nowait("Reasoning")
memories = state.get("memories_found", [])
response: str = qa_rag_stream(
uid,
Expand Down