diff --git a/backend/app/tests/api/routes/test_threads.py b/backend/app/tests/api/routes/test_threads.py index 08ed1930..341b8d11 100644 --- a/backend/app/tests/api/routes/test_threads.py +++ b/backend/app/tests/api/routes/test_threads.py @@ -1,36 +1,40 @@ +import uuid from unittest.mock import MagicMock, patch -import pytest, uuid -from fastapi import FastAPI -from fastapi.testclient import TestClient +import pytest from sqlmodel import select from app.api.routes.threads import ( process_run, - router, validate_thread, setup_thread, process_message_content, handle_openai_error, poll_run_and_prepare_response, ) -from app.models import APIKey, OpenAI_Thread +from app.models import OpenAI_Thread from app.crud import get_thread_result from app.core.langfuse.langfuse import LangfuseTracer import openai from openai import OpenAIError -# Wrap the router in a FastAPI app instance. -app = FastAPI() -app.include_router(router) -client = TestClient(app) - -@patch("app.api.routes.threads.OpenAI") -def test_threads_endpoint(mock_openai, db): +@patch("app.api.routes.threads.configure_openai") +@patch("app.api.routes.threads.get_provider_credential") +@patch("app.api.routes.threads.send_callback") +@patch("app.api.routes.threads.process_run") +def test_threads_endpoint( + mock_process_run, + mock_send_callback, + mock_get_provider_credential, + mock_configure_openai, + client, + db, + user_api_key_header, +): """ Test the /threads endpoint when creating a new thread. - The patched OpenAI client simulates: + The patched configure_openai function simulates: - A successful assistant ID validation. - New thread creation with a dummy thread id. - No existing runs. @@ -49,21 +53,19 @@ def test_threads_endpoint(mock_openai, db): # Simulate that no active run exists. dummy_client.beta.threads.runs.list.return_value = MagicMock(data=[]) - mock_openai.return_value = dummy_client - - # Get an API key from the database - api_key_record = db.exec(select(APIKey).where(APIKey.is_deleted is False)).first() - if not api_key_record: - pytest.skip("No API key found in the database for testing") - - headers = {"X-API-KEY": api_key_record.key} + # Mock get_provider_credential to return dummy credentials + mock_get_provider_credential.return_value = {"api_key": "dummy_api_key"} + # Mock configure_openai to return our dummy client + mock_configure_openai.return_value = (dummy_client, True) request_data = { "question": "What is Glific?", "assistant_id": "assistant_123", "callback_url": "http://example.com/callback", } - response = client.post("/threads", json=request_data, headers=headers) + response = client.post( + "/api/v1/threads", json=request_data, headers=user_api_key_header + ) assert response.status_code == 200 response_json = response.json() assert response_json["success"] is True @@ -72,7 +74,8 @@ def test_threads_endpoint(mock_openai, db): assert response_json["data"]["thread_id"] == "dummy_thread_id" -@patch("app.api.routes.threads.OpenAI") +@patch("app.api.routes.threads.configure_openai") +@patch("app.api.routes.threads.get_provider_credential") @pytest.mark.parametrize( "remove_citation, expected_message", [ @@ -86,15 +89,21 @@ def test_threads_endpoint(mock_openai, db): ), ], ) -def test_process_run_variants(mock_openai, remove_citation, expected_message): +def test_process_run_variants( + mock_get_provider_credential, + mock_configure_openai, + remove_citation, + expected_message, +): """ Test process_run for both remove_citation variants: - - Mocks the OpenAI client to simulate a completed run. + - Mocks the configure_openai function to simulate a completed run. - Verifies that send_callback is called with the expected message based on the remove_citation flag. """ # Setup the mock client. mock_client = MagicMock() - mock_openai.return_value = mock_client + mock_get_provider_credential.return_value = {"api_key": "dummy_api_key"} + mock_configure_openai.return_value = (mock_client, True) # Create the request with the variable remove_citation flag. request = { @@ -132,14 +141,48 @@ def test_process_run_variants(mock_openai, remove_citation, expected_message): assert payload["success"] is True -@patch("app.api.routes.threads.OpenAI") -def test_threads_sync_endpoint_success(mock_openai, db): +@patch("app.api.routes.threads.configure_openai") +@patch("app.api.routes.threads.get_provider_credential") +def test_threads_sync_endpoint_active_run( + mock_get_provider_credential, mock_configure_openai, client, db, user_api_key_header +): + """Test the /threads/sync endpoint when there's an active run.""" + # Setup mock client + mock_client = MagicMock() + mock_get_provider_credential.return_value = {"api_key": "dummy_api_key"} + mock_configure_openai.return_value = (mock_client, True) + + # Simulate active run + mock_run = MagicMock() + mock_run.status = "in_progress" + mock_client.beta.threads.runs.list.return_value = MagicMock(data=[mock_run]) + + request_data = { + "question": "Test question", + "assistant_id": "assistant_123", + "thread_id": "existing_thread", + } + + # Expect the endpoint to raise when there's an active run + with pytest.raises(Exception) as excinfo: + client.post( + "/api/v1/threads/sync", json=request_data, headers=user_api_key_header + ) + assert "active run" in str(excinfo.value).lower() + + +@patch("app.api.routes.threads.configure_openai") +@patch("app.api.routes.threads.get_provider_credential") +def test_threads_sync_endpoint_success( + mock_get_provider_credential, mock_configure_openai, client, db, user_api_key_header +): """Test the /threads/sync endpoint for successful completion.""" # Setup mock client mock_client = MagicMock() - mock_openai.return_value = mock_client + mock_get_provider_credential.return_value = {"api_key": "dummy_api_key"} + mock_configure_openai.return_value = (mock_client, True) - # Simulate thread validation + # Simulate thread validation (no active runs) mock_client.beta.threads.runs.list.return_value = MagicMock(data=[]) # Simulate thread creation @@ -153,6 +196,10 @@ def test_threads_sync_endpoint_success(mock_openai, db): # Simulate successful run mock_run = MagicMock() mock_run.status = "completed" + mock_run.usage.prompt_tokens = 10 + mock_run.usage.completion_tokens = 20 + mock_run.usage.total_tokens = 30 + mock_run.model = "gpt-4" mock_client.beta.threads.runs.create_and_poll.return_value = mock_run # Simulate message retrieval @@ -160,55 +207,23 @@ def test_threads_sync_endpoint_success(mock_openai, db): dummy_message.content = [MagicMock(text=MagicMock(value="Test response"))] mock_client.beta.threads.messages.list.return_value.data = [dummy_message] - # Get API key - api_key_record = db.exec(select(APIKey).where(APIKey.is_deleted is False)).first() - if not api_key_record: - pytest.skip("No API key found in the database for testing") - - headers = {"X-API-KEY": api_key_record.key} request_data = { "question": "Test question", "assistant_id": "assistant_123", } - response = client.post("/threads/sync", json=request_data, headers=headers) + response = client.post( + "/api/v1/threads/sync", json=request_data, headers=user_api_key_header + ) + assert response.status_code == 200 response_json = response.json() assert response_json["success"] is True assert response_json["data"]["status"] == "success" assert response_json["data"]["message"] == "Test response" assert response_json["data"]["thread_id"] == "sync_thread_id" - - -@patch("app.api.routes.threads.OpenAI") -def test_threads_sync_endpoint_active_run(mock_openai, db): - """Test the /threads/sync endpoint when there's an active run.""" - # Setup mock client - mock_client = MagicMock() - mock_openai.return_value = mock_client - - # Simulate active run - mock_run = MagicMock() - mock_run.status = "in_progress" - mock_client.beta.threads.runs.list.return_value = MagicMock(data=[mock_run]) - - # Get API key - api_key_record = db.exec(select(APIKey).where(APIKey.is_deleted is False)).first() - if not api_key_record: - pytest.skip("No API key found in the database for testing") - - headers = {"X-API-KEY": api_key_record.key} - request_data = { - "question": "Test question", - "assistant_id": "assistant_123", - "thread_id": "existing_thread", - } - - response = client.post("/threads/sync", json=request_data, headers=headers) - assert response.status_code == 200 - response_json = response.json() - assert response_json["success"] is False - assert "active run" in response_json["error"].lower() + assert "diagnostics" in response_json["data"] + assert response_json["data"]["diagnostics"]["total_tokens"] == 30 def test_validate_thread_no_thread_id(): @@ -393,8 +408,11 @@ def test_handle_openai_error_with_none_body(): assert result == "None body error" -@patch("app.api.routes.threads.OpenAI") -def test_poll_run_and_prepare_response_completed(mock_openai, db): +@patch("app.api.routes.threads.configure_openai") +@patch("app.api.routes.threads.get_provider_credential") +def test_poll_run_and_prepare_response_completed( + mock_get_provider_credential, mock_configure_openai, db +): mock_client = MagicMock() mock_run = MagicMock() mock_run.status = "completed" @@ -403,7 +421,8 @@ def test_poll_run_and_prepare_response_completed(mock_openai, db): mock_message = MagicMock() mock_message.content = [MagicMock(text=MagicMock(value="Answer"))] mock_client.beta.threads.messages.list.return_value.data = [mock_message] - mock_openai.return_value = mock_client + mock_get_provider_credential.return_value = {"api_key": "dummy_api_key"} + mock_configure_openai.return_value = (mock_client, True) request = { "question": "What is Glific?", @@ -418,12 +437,16 @@ def test_poll_run_and_prepare_response_completed(mock_openai, db): assert result.response.strip() == "Answer" -@patch("app.api.routes.threads.OpenAI") -def test_poll_run_and_prepare_response_openai_error_handling(mock_openai, db): +@patch("app.api.routes.threads.configure_openai") +@patch("app.api.routes.threads.get_provider_credential") +def test_poll_run_and_prepare_response_openai_error_handling( + mock_get_provider_credential, mock_configure_openai, db +): mock_client = MagicMock() mock_error = OpenAIError("Simulated OpenAI error") mock_client.beta.threads.runs.create_and_poll.side_effect = mock_error - mock_openai.return_value = mock_client + mock_get_provider_credential.return_value = {"api_key": "dummy_api_key"} + mock_configure_openai.return_value = (mock_client, True) request = { "question": "Failing run", @@ -445,12 +468,16 @@ def test_poll_run_and_prepare_response_openai_error_handling(mock_openai, db): assert "Simulated OpenAI error" in (result.error or "") -@patch("app.api.routes.threads.OpenAI") -def test_poll_run_and_prepare_response_non_completed(mock_openai, db): +@patch("app.api.routes.threads.configure_openai") +@patch("app.api.routes.threads.get_provider_credential") +def test_poll_run_and_prepare_response_non_completed( + mock_get_provider_credential, mock_configure_openai, db +): mock_client = MagicMock() mock_run = MagicMock(status="failed") mock_client.beta.threads.runs.create_and_poll.return_value = mock_run - mock_openai.return_value = mock_client + mock_get_provider_credential.return_value = {"api_key": "dummy_api_key"} + mock_configure_openai.return_value = (mock_client, True) request = { "question": "Incomplete run", @@ -471,24 +498,31 @@ def test_poll_run_and_prepare_response_non_completed(mock_openai, db): assert result.status == "failed" -@patch("app.api.routes.threads.OpenAI") -def test_threads_start_endpoint_creates_thread(mock_openai, db): +@patch("app.api.routes.threads.configure_openai") +@patch("app.api.routes.threads.get_provider_credential") +@patch("app.api.routes.threads.poll_run_and_prepare_response") +def test_threads_start_endpoint_creates_thread( + mock_poll_run, + mock_get_provider_credential, + mock_configure_openai, + client, + db, + user_api_key_header, +): """Test /threads/start creates thread and schedules background task.""" mock_client = MagicMock() mock_thread = MagicMock() mock_thread.id = "mock_thread_001" mock_client.beta.threads.create.return_value = mock_thread mock_client.beta.threads.messages.create.return_value = None - mock_openai.return_value = mock_client - - api_key_record = db.exec(select(APIKey).where(APIKey.is_deleted is False)).first() - if not api_key_record: - pytest.skip("No API key found in the database for testing") + mock_get_provider_credential.return_value = {"api_key": "dummy_api_key"} + mock_configure_openai.return_value = (mock_client, True) - headers = {"X-API-KEY": api_key_record.key} data = {"question": "What's 2+2?", "assistant_id": "assist_123"} - response = client.post("/threads/start", json=data, headers=headers) + response = client.post( + "/api/v1/threads/start", json=data, headers=user_api_key_header + ) assert response.status_code == 200 res_json = response.json() assert res_json["success"] @@ -497,7 +531,7 @@ def test_threads_start_endpoint_creates_thread(mock_openai, db): assert res_json["data"]["prompt"] == "What's 2+2?" -def test_threads_result_endpoint_success(db): +def test_threads_result_endpoint_success(client, db, user_api_key_header): """Test /threads/result/{thread_id} returns completed thread.""" thread_id = f"test_processing_{uuid.uuid4()}" question = "Capital of France?" @@ -506,12 +540,9 @@ def test_threads_result_endpoint_success(db): db.add(OpenAI_Thread(thread_id=thread_id, prompt=question, response=message)) db.commit() - api_key_record = db.exec(select(APIKey).where(APIKey.is_deleted is False)).first() - if not api_key_record: - pytest.skip("No API key found in the database for testing") - - headers = {"X-API-KEY": api_key_record.key} - response = client.get(f"/threads/result/{thread_id}", headers=headers) + response = client.get( + f"/api/v1/threads/result/{thread_id}", headers=user_api_key_header + ) assert response.status_code == 200 data = response.json()["data"] @@ -521,7 +552,7 @@ def test_threads_result_endpoint_success(db): assert data["prompt"] == question -def test_threads_result_endpoint_processing(db): +def test_threads_result_endpoint_processing(client, db, user_api_key_header): """Test /threads/result/{thread_id} returns processing status if no message yet.""" thread_id = f"test_processing_{uuid.uuid4()}" question = "What is Glific?" @@ -529,50 +560,45 @@ def test_threads_result_endpoint_processing(db): db.add(OpenAI_Thread(thread_id=thread_id, prompt=question, response=None)) db.commit() - api_key_record = db.exec(select(APIKey).where(APIKey.is_deleted is False)).first() - if not api_key_record: - pytest.skip("No API key found in the database for testing") - - headers = {"X-API-KEY": api_key_record.key} - response = client.get(f"/threads/result/{thread_id}", headers=headers) + response = client.get( + f"/api/v1/threads/result/{thread_id}", headers=user_api_key_header + ) assert response.status_code == 200 data = response.json()["data"] assert data["status"] == "processing" - assert data["message"] is None + assert data["response"] is None assert data["thread_id"] == thread_id assert data["prompt"] == question -def test_threads_result_not_found(db): +def test_threads_result_not_found(client, user_api_key_header): """Test /threads/result/{thread_id} returns error for nonexistent thread.""" - api_key_record = db.exec(select(APIKey).where(APIKey.is_deleted is False)).first() - if not api_key_record: - pytest.skip("No API key found in the database for testing") - - headers = {"X-API-KEY": api_key_record.key} - response = client.get("/threads/result/nonexistent_thread", headers=headers) - - assert response.status_code == 200 - assert response.json()["success"] is False - assert "not found" in response.json()["error"].lower() + response = client.get( + "/api/v1/threads/result/nonexistent_thread", headers=user_api_key_header + ) + assert response.status_code == 404 + response_data = response.json() + assert response_data["success"] is False + assert "thread not found" in response_data["error"].lower() -@patch("app.api.routes.threads.OpenAI") -def test_threads_start_missing_question(mock_openai, db): +@patch("app.api.routes.threads.configure_openai") +@patch("app.api.routes.threads.get_provider_credential") +def test_threads_start_missing_question( + mock_get_provider_credential, mock_configure_openai, client, user_api_key_header +): """Test /threads/start with missing 'question' key in request.""" - mock_openai.return_value = MagicMock() - - api_key_record = db.exec(select(APIKey).where(APIKey.is_deleted is False)).first() - if not api_key_record: - pytest.skip("No API key found in the database for testing") - - headers = {"X-API-KEY": api_key_record.key} + mock_get_provider_credential.return_value = {"api_key": "dummy_api_key"} + mock_configure_openai.return_value = (MagicMock(), True) bad_data = {"assistant_id": "assist_123"} # no "question" key - response = client.post("/threads/start", json=bad_data, headers=headers) + response = client.post( + "/api/v1/threads/start", json=bad_data, headers=user_api_key_header + ) assert response.status_code == 422 # Unprocessable Entity (FastAPI will raise 422) error_response = response.json() - assert "detail" in error_response + assert error_response["success"] is False + assert "question" in error_response["error"]