Skip to content

Commit 1e25cf6

Browse files
committed
generated file: tests/unit/test_openai.py
1 parent 8e9ffcd commit 1e25cf6

File tree

1 file changed

+321
-0
lines changed

1 file changed

+321
-0
lines changed

Diff for: tests/unit/test_openai.py

+321
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,321 @@
1+
import pytest
2+
from fastapi import HTTPException, status
3+
from unittest.mock import patch, MagicMock
4+
5+
from openai_api_client.dependencies.openai import OpenAIService, openai_service
6+
from openai_api_client.schemas.openai import OpenAIRequest, OpenAIResponse, OpenAIModel
7+
8+
# Mocking the OpenAI library for unit testing
9+
@pytest.fixture
10+
def mock_openai():
11+
with patch("openai_api_client.dependencies.openai.openai") as mock_openai:
12+
yield mock_openai
13+
14+
# Test cases for text completion
15+
class TestOpenAI_CompleteText:
16+
def test_complete_text_success(self, mock_openai):
17+
"""Test successful text completion with valid parameters."""
18+
mock_openai.Completion.create.return_value = MagicMock(
19+
choices=[MagicMock(text="This is the completed text.")],
20+
)
21+
request = OpenAIRequest(text="This is the prompt.")
22+
service = OpenAIService()
23+
response = service.complete_text(text=request.text)
24+
assert response.response == "This is the completed text."
25+
mock_openai.Completion.create.assert_called_once_with(
26+
engine=request.model, prompt=request.text, temperature=request.temperature, max_tokens=request.max_tokens
27+
)
28+
29+
def test_complete_text_invalid_model(self, mock_openai):
30+
"""Test handling of invalid OpenAI model."""
31+
mock_openai.Completion.create.side_effect = openai.error.InvalidRequestError("Invalid model.")
32+
request = OpenAIRequest(text="This is the prompt.", model="invalid_model")
33+
service = OpenAIService()
34+
with pytest.raises(HTTPException) as exc:
35+
service.complete_text(text=request.text, model=request.model)
36+
assert exc.value.status_code == status.HTTP_400_BAD_REQUEST
37+
assert "Invalid request to OpenAI API" in str(exc.value.detail)
38+
39+
def test_complete_text_rate_limit(self, mock_openai):
40+
"""Test handling of OpenAI API rate limit."""
41+
mock_openai.Completion.create.side_effect = openai.error.RateLimitError("Rate limit exceeded.")
42+
request = OpenAIRequest(text="This is the prompt.")
43+
service = OpenAIService()
44+
with pytest.raises(HTTPException) as exc:
45+
service.complete_text(text=request.text)
46+
assert exc.value.status_code == status.HTTP_429_TOO_MANY_REQUESTS
47+
assert "OpenAI API rate limit exceeded" in str(exc.value.detail)
48+
49+
def test_complete_text_authentication_error(self, mock_openai):
50+
"""Test handling of authentication error with OpenAI API."""
51+
mock_openai.Completion.create.side_effect = openai.error.AuthenticationError("Invalid API key.")
52+
request = OpenAIRequest(text="This is the prompt.")
53+
service = OpenAIService()
54+
with pytest.raises(HTTPException) as exc:
55+
service.complete_text(text=request.text)
56+
assert exc.value.status_code == status.HTTP_401_UNAUTHORIZED
57+
assert "Invalid OpenAI API key" in str(exc.value.detail)
58+
59+
def test_complete_text_timeout_error(self, mock_openai):
60+
"""Test handling of timeout error during API call."""
61+
mock_openai.Completion.create.side_effect = openai.error.TimeoutError("Request timed out.")
62+
request = OpenAIRequest(text="This is the prompt.")
63+
service = OpenAIService()
64+
with pytest.raises(HTTPException) as exc:
65+
service.complete_text(text=request.text)
66+
assert exc.value.status_code == status.HTTP_504_GATEWAY_TIMEOUT
67+
assert "Request to OpenAI API timed out" in str(exc.value.detail)
68+
69+
def test_complete_text_connection_error(self, mock_openai):
70+
"""Test handling of connection error with OpenAI API."""
71+
mock_openai.Completion.create.side_effect = openai.error.APIConnectionError("Connection error.")
72+
request = OpenAIRequest(text="This is the prompt.")
73+
service = OpenAIService()
74+
with pytest.raises(HTTPException) as exc:
75+
service.complete_text(text=request.text)
76+
assert exc.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
77+
assert "Error connecting to OpenAI API" in str(exc.value.detail)
78+
79+
def test_complete_text_general_api_error(self, mock_openai):
80+
"""Test handling of general API error during call."""
81+
mock_openai.Completion.create.side_effect = openai.error.APIError("General API error.")
82+
request = OpenAIRequest(text="This is the prompt.")
83+
service = OpenAIService()
84+
with pytest.raises(HTTPException) as exc:
85+
service.complete_text(text=request.text)
86+
assert exc.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
87+
assert "Error calling OpenAI API" in str(exc.value.detail)
88+
89+
# Test cases for text translation
90+
class TestOpenAI_TranslateText:
91+
def test_translate_text_success(self, mock_openai):
92+
"""Test successful text translation with valid parameters."""
93+
mock_openai.Translation.create.return_value = MagicMock(
94+
choices=[MagicMock(text="This is the translated text.")],
95+
)
96+
request = OpenAIRequest(text="This is the text to translate.", source_language="en", target_language="fr")
97+
service = OpenAIService()
98+
response = service.translate_text(text=request.text, source_language=request.source_language, target_language=request.target_language)
99+
assert response.response == "This is the translated text."
100+
mock_openai.Translation.create.assert_called_once_with(
101+
model="gpt-3.5-turbo", from_language="en", to_language="fr", text="This is the text to translate."
102+
)
103+
104+
def test_translate_text_invalid_request(self, mock_openai):
105+
"""Test handling of invalid request for translation."""
106+
mock_openai.Translation.create.side_effect = openai.error.InvalidRequestError("Invalid request.")
107+
request = OpenAIRequest(text="This is the text to translate.", source_language="invalid", target_language="fr")
108+
service = OpenAIService()
109+
with pytest.raises(HTTPException) as exc:
110+
service.translate_text(text=request.text, source_language=request.source_language, target_language=request.target_language)
111+
assert exc.value.status_code == status.HTTP_400_BAD_REQUEST
112+
assert "Invalid request to OpenAI API" in str(exc.value.detail)
113+
114+
def test_translate_text_rate_limit(self, mock_openai):
115+
"""Test handling of OpenAI API rate limit during translation."""
116+
mock_openai.Translation.create.side_effect = openai.error.RateLimitError("Rate limit exceeded.")
117+
request = OpenAIRequest(text="This is the text to translate.", source_language="en", target_language="fr")
118+
service = OpenAIService()
119+
with pytest.raises(HTTPException) as exc:
120+
service.translate_text(text=request.text, source_language=request.source_language, target_language=request.target_language)
121+
assert exc.value.status_code == status.HTTP_429_TOO_MANY_REQUESTS
122+
assert "OpenAI API rate limit exceeded" in str(exc.value.detail)
123+
124+
def test_translate_text_authentication_error(self, mock_openai):
125+
"""Test handling of authentication error during translation."""
126+
mock_openai.Translation.create.side_effect = openai.error.AuthenticationError("Invalid API key.")
127+
request = OpenAIRequest(text="This is the text to translate.", source_language="en", target_language="fr")
128+
service = OpenAIService()
129+
with pytest.raises(HTTPException) as exc:
130+
service.translate_text(text=request.text, source_language=request.source_language, target_language=request.target_language)
131+
assert exc.value.status_code == status.HTTP_401_UNAUTHORIZED
132+
assert "Invalid OpenAI API key" in str(exc.value.detail)
133+
134+
def test_translate_text_timeout_error(self, mock_openai):
135+
"""Test handling of timeout error during translation."""
136+
mock_openai.Translation.create.side_effect = openai.error.TimeoutError("Request timed out.")
137+
request = OpenAIRequest(text="This is the text to translate.", source_language="en", target_language="fr")
138+
service = OpenAIService()
139+
with pytest.raises(HTTPException) as exc:
140+
service.translate_text(text=request.text, source_language=request.source_language, target_language=request.target_language)
141+
assert exc.value.status_code == status.HTTP_504_GATEWAY_TIMEOUT
142+
assert "Request to OpenAI API timed out" in str(exc.value.detail)
143+
144+
def test_translate_text_connection_error(self, mock_openai):
145+
"""Test handling of connection error during translation."""
146+
mock_openai.Translation.create.side_effect = openai.error.APIConnectionError("Connection error.")
147+
request = OpenAIRequest(text="This is the text to translate.", source_language="en", target_language="fr")
148+
service = OpenAIService()
149+
with pytest.raises(HTTPException) as exc:
150+
service.translate_text(text=request.text, source_language=request.source_language, target_language=request.target_language)
151+
assert exc.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
152+
assert "Error connecting to OpenAI API" in str(exc.value.detail)
153+
154+
def test_translate_text_general_api_error(self, mock_openai):
155+
"""Test handling of general API error during translation."""
156+
mock_openai.Translation.create.side_effect = openai.error.APIError("General API error.")
157+
request = OpenAIRequest(text="This is the text to translate.", source_language="en", target_language="fr")
158+
service = OpenAIService()
159+
with pytest.raises(HTTPException) as exc:
160+
service.translate_text(text=request.text, source_language=request.source_language, target_language=request.target_language)
161+
assert exc.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
162+
assert "Error calling OpenAI API" in str(exc.value.detail)
163+
164+
# Test cases for text summarization
165+
class TestOpenAI_SummarizeText:
166+
def test_summarize_text_success(self, mock_openai):
167+
"""Test successful text summarization with valid parameters."""
168+
mock_openai.Completion.create.return_value = MagicMock(
169+
choices=[MagicMock(text="This is the summarized text.")],
170+
)
171+
request = OpenAIRequest(text="This is the text to summarize.")
172+
service = OpenAIService()
173+
response = service.summarize_text(text=request.text)
174+
assert response.response == "This is the summarized text."
175+
mock_openai.Completion.create.assert_called_once_with(
176+
engine=request.model, prompt=f"Summarize the following text:\n\n{request.text}", temperature=0.7, max_tokens=256
177+
)
178+
179+
def test_summarize_text_invalid_request(self, mock_openai):
180+
"""Test handling of invalid request for summarization."""
181+
mock_openai.Completion.create.side_effect = openai.error.InvalidRequestError("Invalid request.")
182+
request = OpenAIRequest(text="This is the text to summarize.")
183+
service = OpenAIService()
184+
with pytest.raises(HTTPException) as exc:
185+
service.summarize_text(text=request.text)
186+
assert exc.value.status_code == status.HTTP_400_BAD_REQUEST
187+
assert "Invalid request to OpenAI API" in str(exc.value.detail)
188+
189+
def test_summarize_text_rate_limit(self, mock_openai):
190+
"""Test handling of OpenAI API rate limit during summarization."""
191+
mock_openai.Completion.create.side_effect = openai.error.RateLimitError("Rate limit exceeded.")
192+
request = OpenAIRequest(text="This is the text to summarize.")
193+
service = OpenAIService()
194+
with pytest.raises(HTTPException) as exc:
195+
service.summarize_text(text=request.text)
196+
assert exc.value.status_code == status.HTTP_429_TOO_MANY_REQUESTS
197+
assert "OpenAI API rate limit exceeded" in str(exc.value.detail)
198+
199+
def test_summarize_text_authentication_error(self, mock_openai):
200+
"""Test handling of authentication error during summarization."""
201+
mock_openai.Completion.create.side_effect = openai.error.AuthenticationError("Invalid API key.")
202+
request = OpenAIRequest(text="This is the text to summarize.")
203+
service = OpenAIService()
204+
with pytest.raises(HTTPException) as exc:
205+
service.summarize_text(text=request.text)
206+
assert exc.value.status_code == status.HTTP_401_UNAUTHORIZED
207+
assert "Invalid OpenAI API key" in str(exc.value.detail)
208+
209+
def test_summarize_text_timeout_error(self, mock_openai):
210+
"""Test handling of timeout error during summarization."""
211+
mock_openai.Completion.create.side_effect = openai.error.TimeoutError("Request timed out.")
212+
request = OpenAIRequest(text="This is the text to summarize.")
213+
service = OpenAIService()
214+
with pytest.raises(HTTPException) as exc:
215+
service.summarize_text(text=request.text)
216+
assert exc.value.status_code == status.HTTP_504_GATEWAY_TIMEOUT
217+
assert "Request to OpenAI API timed out" in str(exc.value.detail)
218+
219+
def test_summarize_text_connection_error(self, mock_openai):
220+
"""Test handling of connection error during summarization."""
221+
mock_openai.Completion.create.side_effect = openai.error.APIConnectionError("Connection error.")
222+
request = OpenAIRequest(text="This is the text to summarize.")
223+
service = OpenAIService()
224+
with pytest.raises(HTTPException) as exc:
225+
service.summarize_text(text=request.text)
226+
assert exc.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
227+
assert "Error connecting to OpenAI API" in str(exc.value.detail)
228+
229+
def test_summarize_text_general_api_error(self, mock_openai):
230+
"""Test handling of general API error during summarization."""
231+
mock_openai.Completion.create.side_effect = openai.error.APIError("General API error.")
232+
request = OpenAIRequest(text="This is the text to summarize.")
233+
service = OpenAIService()
234+
with pytest.raises(HTTPException) as exc:
235+
service.summarize_text(text=request.text)
236+
assert exc.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
237+
assert "Error calling OpenAI API" in str(exc.value.detail)
238+
239+
# Test cases for model retrieval
240+
class TestOpenAI_GetModel:
241+
def test_get_model_success(self, mock_openai):
242+
"""Test successful model retrieval with valid model ID."""
243+
mock_openai.Model.retrieve.return_value = MagicMock(
244+
id="model-id",
245+
object="model",
246+
created=1678886400,
247+
owned_by="user-id",
248+
permissions=[{"id": "permission-id", "allow": ["fine-tune", "use"], "created": 1678886400}],
249+
root="model-id",
250+
parent="model-id",
251+
is_moderation_model=False,
252+
is_text_search_model=False,
253+
is_translation_model=False,
254+
is_code_model=False,
255+
is_embedding_model=False,
256+
is_question_answering_model=False,
257+
is_completion_model=True,
258+
is_chat_completion_model=False,
259+
is_fine_tuned=False,
260+
is_available=True,
261+
)
262+
service = OpenAIService()
263+
response = service.get_model(model_id="model-id")
264+
assert response.id == "model-id"
265+
assert response.is_completion_model
266+
assert response.is_available
267+
mock_openai.Model.retrieve.assert_called_once_with("model-id")
268+
269+
def test_get_model_invalid_model_id(self, mock_openai):
270+
"""Test handling of invalid model ID."""
271+
mock_openai.Model.retrieve.side_effect = openai.error.InvalidRequestError("Invalid model ID.")
272+
service = OpenAIService()
273+
with pytest.raises(HTTPException) as exc:
274+
service.get_model(model_id="invalid-model-id")
275+
assert exc.value.status_code == status.HTTP_400_BAD_REQUEST
276+
assert "Invalid request to OpenAI API" in str(exc.value.detail)
277+
278+
def test_get_model_rate_limit(self, mock_openai):
279+
"""Test handling of OpenAI API rate limit during model retrieval."""
280+
mock_openai.Model.retrieve.side_effect = openai.error.RateLimitError("Rate limit exceeded.")
281+
service = OpenAIService()
282+
with pytest.raises(HTTPException) as exc:
283+
service.get_model(model_id="model-id")
284+
assert exc.value.status_code == status.HTTP_429_TOO_MANY_REQUESTS
285+
assert "OpenAI API rate limit exceeded" in str(exc.value.detail)
286+
287+
def test_get_model_authentication_error(self, mock_openai):
288+
"""Test handling of authentication error during model retrieval."""
289+
mock_openai.Model.retrieve.side_effect = openai.error.AuthenticationError("Invalid API key.")
290+
service = OpenAIService()
291+
with pytest.raises(HTTPException) as exc:
292+
service.get_model(model_id="model-id")
293+
assert exc.value.status_code == status.HTTP_401_UNAUTHORIZED
294+
assert "Invalid OpenAI API key" in str(exc.value.detail)
295+
296+
def test_get_model_timeout_error(self, mock_openai):
297+
"""Test handling of timeout error during model retrieval."""
298+
mock_openai.Model.retrieve.side_effect = openai.error.TimeoutError("Request timed out.")
299+
service = OpenAIService()
300+
with pytest.raises(HTTPException) as exc:
301+
service.get_model(model_id="model-id")
302+
assert exc.value.status_code == status.HTTP_504_GATEWAY_TIMEOUT
303+
assert "Request to OpenAI API timed out" in str(exc.value.detail)
304+
305+
def test_get_model_connection_error(self, mock_openai):
306+
"""Test handling of connection error during model retrieval."""
307+
mock_openai.Model.retrieve.side_effect = openai.error.APIConnectionError("Connection error.")
308+
service = OpenAIService()
309+
with pytest.raises(HTTPException) as exc:
310+
service.get_model(model_id="model-id")
311+
assert exc.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
312+
assert "Error connecting to OpenAI API" in str(exc.value.detail)
313+
314+
def test_get_model_general_api_error(self, mock_openai):
315+
"""Test handling of general API error during model retrieval."""
316+
mock_openai.Model.retrieve.side_effect = openai.error.APIError("General API error.")
317+
service = OpenAIService()
318+
with pytest.raises(HTTPException) as exc:
319+
service.get_model(model_id="model-id")
320+
assert exc.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
321+
assert "Error calling OpenAI API" in str(exc.value.detail)

0 commit comments

Comments
 (0)