Skip to content

Commit

Permalink
add test for a streaming assistant (#349)
Browse files Browse the repository at this point in the history
  • Loading branch information
pmeier committed Mar 11, 2024
1 parent 1043241 commit 4a667f9
Showing 1 changed file with 30 additions and 10 deletions.
40 changes: 30 additions & 10 deletions tests/deploy/api/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,45 @@
import pytest
from fastapi.testclient import TestClient

from ragna.assistants import RagnaDemoAssistant
from ragna.deploy import Config
from ragna.deploy._api import app

from .utils import authenticate


@pytest.mark.parametrize("stream_answer", [True, False])
def test_e2e(tmp_local_root, stream_answer):
config = Config(local_root=tmp_local_root)
check_api(config, stream_answer=stream_answer)
class TestAssistant(RagnaDemoAssistant):
@property
def max_input_size(self) -> int:
return 0

def answer(self, prompt, sources, *, multiple_answer_chunks: bool):
content = next(super().answer(prompt, sources))

if multiple_answer_chunks:
for chunk in content.split(" "):
yield f"{chunk} "
else:
yield content


def check_api(config, *, stream_answer):
@pytest.mark.parametrize("multiple_answer_chunks", [True, False])
@pytest.mark.parametrize("stream_answer", [True, False])
def test_e2e(tmp_local_root, multiple_answer_chunks, stream_answer):
config = Config(local_root=tmp_local_root, assistants=[TestAssistant])

document_root = config.local_root / "documents"
document_root.mkdir()
document_path = document_root / "test.txt"
with open(document_path, "w") as file:
file.write("!\n")

# Reset starlette_sse AppStatus for each run
# See https://github.com/sysid/sse-starlette/issues/59
from sse_starlette.sse import AppStatus

AppStatus.should_exit_event = None

with TestClient(app(config=config, ignore_unavailable_components=False)) as client:
authenticate(client)

Expand Down Expand Up @@ -66,7 +86,7 @@ def check_api(config, *, stream_answer):
"name": "test-chat",
"source_storage": source_storage,
"assistant": assistant,
"params": {},
"params": {"multiple_answer_chunks": multiple_answer_chunks},
"documents": [document],
}
chat = client.post("/chats", json=chat_metadata).raise_for_status().json()
Expand Down Expand Up @@ -96,10 +116,10 @@ def check_api(config, *, stream_answer):
json={"prompt": prompt, "stream": True},
) as event_source:
for sse in event_source.iter_sse():
chunk = json.loads(sse.data)
chunks.append(chunk["content"])
message = chunk
message["content"] = "".join(chunks)
chunks.append(json.loads(sse.data))
message = chunks[0]
assert all(chunk["sources"] is None for chunk in chunks[1:])
message["content"] = "".join(chunk["content"] for chunk in chunks)
else:
message = (
client.post(f"/chats/{chat['id']}/answer", json={"prompt": prompt})
Expand Down

0 comments on commit 4a667f9

Please sign in to comment.