1
+ import pytest
2
+ from fastapi import HTTPException , status
3
+ from unittest .mock import patch , MagicMock
4
+
5
+ from openai_api_client .dependencies .openai import OpenAIService , openai_service
6
+ from openai_api_client .schemas .openai import OpenAIRequest , OpenAIResponse , OpenAIModel
7
+
8
+ # Mocking the OpenAI library for unit testing
9
+ @pytest .fixture
10
+ def mock_openai ():
11
+ with patch ("openai_api_client.dependencies.openai.openai" ) as mock_openai :
12
+ yield mock_openai
13
+
14
+ # Test cases for text completion
15
+ class TestOpenAI_CompleteText :
16
+ def test_complete_text_success (self , mock_openai ):
17
+ """Test successful text completion with valid parameters."""
18
+ mock_openai .Completion .create .return_value = MagicMock (
19
+ choices = [MagicMock (text = "This is the completed text." )],
20
+ )
21
+ request = OpenAIRequest (text = "This is the prompt." )
22
+ service = OpenAIService ()
23
+ response = service .complete_text (text = request .text )
24
+ assert response .response == "This is the completed text."
25
+ mock_openai .Completion .create .assert_called_once_with (
26
+ engine = request .model , prompt = request .text , temperature = request .temperature , max_tokens = request .max_tokens
27
+ )
28
+
29
+ def test_complete_text_invalid_model (self , mock_openai ):
30
+ """Test handling of invalid OpenAI model."""
31
+ mock_openai .Completion .create .side_effect = openai .error .InvalidRequestError ("Invalid model." )
32
+ request = OpenAIRequest (text = "This is the prompt." , model = "invalid_model" )
33
+ service = OpenAIService ()
34
+ with pytest .raises (HTTPException ) as exc :
35
+ service .complete_text (text = request .text , model = request .model )
36
+ assert exc .value .status_code == status .HTTP_400_BAD_REQUEST
37
+ assert "Invalid request to OpenAI API" in str (exc .value .detail )
38
+
39
+ def test_complete_text_rate_limit (self , mock_openai ):
40
+ """Test handling of OpenAI API rate limit."""
41
+ mock_openai .Completion .create .side_effect = openai .error .RateLimitError ("Rate limit exceeded." )
42
+ request = OpenAIRequest (text = "This is the prompt." )
43
+ service = OpenAIService ()
44
+ with pytest .raises (HTTPException ) as exc :
45
+ service .complete_text (text = request .text )
46
+ assert exc .value .status_code == status .HTTP_429_TOO_MANY_REQUESTS
47
+ assert "OpenAI API rate limit exceeded" in str (exc .value .detail )
48
+
49
+ def test_complete_text_authentication_error (self , mock_openai ):
50
+ """Test handling of authentication error with OpenAI API."""
51
+ mock_openai .Completion .create .side_effect = openai .error .AuthenticationError ("Invalid API key." )
52
+ request = OpenAIRequest (text = "This is the prompt." )
53
+ service = OpenAIService ()
54
+ with pytest .raises (HTTPException ) as exc :
55
+ service .complete_text (text = request .text )
56
+ assert exc .value .status_code == status .HTTP_401_UNAUTHORIZED
57
+ assert "Invalid OpenAI API key" in str (exc .value .detail )
58
+
59
+ def test_complete_text_timeout_error (self , mock_openai ):
60
+ """Test handling of timeout error during API call."""
61
+ mock_openai .Completion .create .side_effect = openai .error .TimeoutError ("Request timed out." )
62
+ request = OpenAIRequest (text = "This is the prompt." )
63
+ service = OpenAIService ()
64
+ with pytest .raises (HTTPException ) as exc :
65
+ service .complete_text (text = request .text )
66
+ assert exc .value .status_code == status .HTTP_504_GATEWAY_TIMEOUT
67
+ assert "Request to OpenAI API timed out" in str (exc .value .detail )
68
+
69
+ def test_complete_text_connection_error (self , mock_openai ):
70
+ """Test handling of connection error with OpenAI API."""
71
+ mock_openai .Completion .create .side_effect = openai .error .APIConnectionError ("Connection error." )
72
+ request = OpenAIRequest (text = "This is the prompt." )
73
+ service = OpenAIService ()
74
+ with pytest .raises (HTTPException ) as exc :
75
+ service .complete_text (text = request .text )
76
+ assert exc .value .status_code == status .HTTP_500_INTERNAL_SERVER_ERROR
77
+ assert "Error connecting to OpenAI API" in str (exc .value .detail )
78
+
79
+ def test_complete_text_general_api_error (self , mock_openai ):
80
+ """Test handling of general API error during call."""
81
+ mock_openai .Completion .create .side_effect = openai .error .APIError ("General API error." )
82
+ request = OpenAIRequest (text = "This is the prompt." )
83
+ service = OpenAIService ()
84
+ with pytest .raises (HTTPException ) as exc :
85
+ service .complete_text (text = request .text )
86
+ assert exc .value .status_code == status .HTTP_500_INTERNAL_SERVER_ERROR
87
+ assert "Error calling OpenAI API" in str (exc .value .detail )
88
+
89
+ # Test cases for text translation
90
+ class TestOpenAI_TranslateText :
91
+ def test_translate_text_success (self , mock_openai ):
92
+ """Test successful text translation with valid parameters."""
93
+ mock_openai .Translation .create .return_value = MagicMock (
94
+ choices = [MagicMock (text = "This is the translated text." )],
95
+ )
96
+ request = OpenAIRequest (text = "This is the text to translate." , source_language = "en" , target_language = "fr" )
97
+ service = OpenAIService ()
98
+ response = service .translate_text (text = request .text , source_language = request .source_language , target_language = request .target_language )
99
+ assert response .response == "This is the translated text."
100
+ mock_openai .Translation .create .assert_called_once_with (
101
+ model = "gpt-3.5-turbo" , from_language = "en" , to_language = "fr" , text = "This is the text to translate."
102
+ )
103
+
104
+ def test_translate_text_invalid_request (self , mock_openai ):
105
+ """Test handling of invalid request for translation."""
106
+ mock_openai .Translation .create .side_effect = openai .error .InvalidRequestError ("Invalid request." )
107
+ request = OpenAIRequest (text = "This is the text to translate." , source_language = "invalid" , target_language = "fr" )
108
+ service = OpenAIService ()
109
+ with pytest .raises (HTTPException ) as exc :
110
+ service .translate_text (text = request .text , source_language = request .source_language , target_language = request .target_language )
111
+ assert exc .value .status_code == status .HTTP_400_BAD_REQUEST
112
+ assert "Invalid request to OpenAI API" in str (exc .value .detail )
113
+
114
+ def test_translate_text_rate_limit (self , mock_openai ):
115
+ """Test handling of OpenAI API rate limit during translation."""
116
+ mock_openai .Translation .create .side_effect = openai .error .RateLimitError ("Rate limit exceeded." )
117
+ request = OpenAIRequest (text = "This is the text to translate." , source_language = "en" , target_language = "fr" )
118
+ service = OpenAIService ()
119
+ with pytest .raises (HTTPException ) as exc :
120
+ service .translate_text (text = request .text , source_language = request .source_language , target_language = request .target_language )
121
+ assert exc .value .status_code == status .HTTP_429_TOO_MANY_REQUESTS
122
+ assert "OpenAI API rate limit exceeded" in str (exc .value .detail )
123
+
124
+ def test_translate_text_authentication_error (self , mock_openai ):
125
+ """Test handling of authentication error during translation."""
126
+ mock_openai .Translation .create .side_effect = openai .error .AuthenticationError ("Invalid API key." )
127
+ request = OpenAIRequest (text = "This is the text to translate." , source_language = "en" , target_language = "fr" )
128
+ service = OpenAIService ()
129
+ with pytest .raises (HTTPException ) as exc :
130
+ service .translate_text (text = request .text , source_language = request .source_language , target_language = request .target_language )
131
+ assert exc .value .status_code == status .HTTP_401_UNAUTHORIZED
132
+ assert "Invalid OpenAI API key" in str (exc .value .detail )
133
+
134
+ def test_translate_text_timeout_error (self , mock_openai ):
135
+ """Test handling of timeout error during translation."""
136
+ mock_openai .Translation .create .side_effect = openai .error .TimeoutError ("Request timed out." )
137
+ request = OpenAIRequest (text = "This is the text to translate." , source_language = "en" , target_language = "fr" )
138
+ service = OpenAIService ()
139
+ with pytest .raises (HTTPException ) as exc :
140
+ service .translate_text (text = request .text , source_language = request .source_language , target_language = request .target_language )
141
+ assert exc .value .status_code == status .HTTP_504_GATEWAY_TIMEOUT
142
+ assert "Request to OpenAI API timed out" in str (exc .value .detail )
143
+
144
+ def test_translate_text_connection_error (self , mock_openai ):
145
+ """Test handling of connection error during translation."""
146
+ mock_openai .Translation .create .side_effect = openai .error .APIConnectionError ("Connection error." )
147
+ request = OpenAIRequest (text = "This is the text to translate." , source_language = "en" , target_language = "fr" )
148
+ service = OpenAIService ()
149
+ with pytest .raises (HTTPException ) as exc :
150
+ service .translate_text (text = request .text , source_language = request .source_language , target_language = request .target_language )
151
+ assert exc .value .status_code == status .HTTP_500_INTERNAL_SERVER_ERROR
152
+ assert "Error connecting to OpenAI API" in str (exc .value .detail )
153
+
154
+ def test_translate_text_general_api_error (self , mock_openai ):
155
+ """Test handling of general API error during translation."""
156
+ mock_openai .Translation .create .side_effect = openai .error .APIError ("General API error." )
157
+ request = OpenAIRequest (text = "This is the text to translate." , source_language = "en" , target_language = "fr" )
158
+ service = OpenAIService ()
159
+ with pytest .raises (HTTPException ) as exc :
160
+ service .translate_text (text = request .text , source_language = request .source_language , target_language = request .target_language )
161
+ assert exc .value .status_code == status .HTTP_500_INTERNAL_SERVER_ERROR
162
+ assert "Error calling OpenAI API" in str (exc .value .detail )
163
+
164
+ # Test cases for text summarization
165
+ class TestOpenAI_SummarizeText :
166
+ def test_summarize_text_success (self , mock_openai ):
167
+ """Test successful text summarization with valid parameters."""
168
+ mock_openai .Completion .create .return_value = MagicMock (
169
+ choices = [MagicMock (text = "This is the summarized text." )],
170
+ )
171
+ request = OpenAIRequest (text = "This is the text to summarize." )
172
+ service = OpenAIService ()
173
+ response = service .summarize_text (text = request .text )
174
+ assert response .response == "This is the summarized text."
175
+ mock_openai .Completion .create .assert_called_once_with (
176
+ engine = request .model , prompt = f"Summarize the following text:\n \n { request .text } " , temperature = 0.7 , max_tokens = 256
177
+ )
178
+
179
+ def test_summarize_text_invalid_request (self , mock_openai ):
180
+ """Test handling of invalid request for summarization."""
181
+ mock_openai .Completion .create .side_effect = openai .error .InvalidRequestError ("Invalid request." )
182
+ request = OpenAIRequest (text = "This is the text to summarize." )
183
+ service = OpenAIService ()
184
+ with pytest .raises (HTTPException ) as exc :
185
+ service .summarize_text (text = request .text )
186
+ assert exc .value .status_code == status .HTTP_400_BAD_REQUEST
187
+ assert "Invalid request to OpenAI API" in str (exc .value .detail )
188
+
189
+ def test_summarize_text_rate_limit (self , mock_openai ):
190
+ """Test handling of OpenAI API rate limit during summarization."""
191
+ mock_openai .Completion .create .side_effect = openai .error .RateLimitError ("Rate limit exceeded." )
192
+ request = OpenAIRequest (text = "This is the text to summarize." )
193
+ service = OpenAIService ()
194
+ with pytest .raises (HTTPException ) as exc :
195
+ service .summarize_text (text = request .text )
196
+ assert exc .value .status_code == status .HTTP_429_TOO_MANY_REQUESTS
197
+ assert "OpenAI API rate limit exceeded" in str (exc .value .detail )
198
+
199
+ def test_summarize_text_authentication_error (self , mock_openai ):
200
+ """Test handling of authentication error during summarization."""
201
+ mock_openai .Completion .create .side_effect = openai .error .AuthenticationError ("Invalid API key." )
202
+ request = OpenAIRequest (text = "This is the text to summarize." )
203
+ service = OpenAIService ()
204
+ with pytest .raises (HTTPException ) as exc :
205
+ service .summarize_text (text = request .text )
206
+ assert exc .value .status_code == status .HTTP_401_UNAUTHORIZED
207
+ assert "Invalid OpenAI API key" in str (exc .value .detail )
208
+
209
+ def test_summarize_text_timeout_error (self , mock_openai ):
210
+ """Test handling of timeout error during summarization."""
211
+ mock_openai .Completion .create .side_effect = openai .error .TimeoutError ("Request timed out." )
212
+ request = OpenAIRequest (text = "This is the text to summarize." )
213
+ service = OpenAIService ()
214
+ with pytest .raises (HTTPException ) as exc :
215
+ service .summarize_text (text = request .text )
216
+ assert exc .value .status_code == status .HTTP_504_GATEWAY_TIMEOUT
217
+ assert "Request to OpenAI API timed out" in str (exc .value .detail )
218
+
219
+ def test_summarize_text_connection_error (self , mock_openai ):
220
+ """Test handling of connection error during summarization."""
221
+ mock_openai .Completion .create .side_effect = openai .error .APIConnectionError ("Connection error." )
222
+ request = OpenAIRequest (text = "This is the text to summarize." )
223
+ service = OpenAIService ()
224
+ with pytest .raises (HTTPException ) as exc :
225
+ service .summarize_text (text = request .text )
226
+ assert exc .value .status_code == status .HTTP_500_INTERNAL_SERVER_ERROR
227
+ assert "Error connecting to OpenAI API" in str (exc .value .detail )
228
+
229
+ def test_summarize_text_general_api_error (self , mock_openai ):
230
+ """Test handling of general API error during summarization."""
231
+ mock_openai .Completion .create .side_effect = openai .error .APIError ("General API error." )
232
+ request = OpenAIRequest (text = "This is the text to summarize." )
233
+ service = OpenAIService ()
234
+ with pytest .raises (HTTPException ) as exc :
235
+ service .summarize_text (text = request .text )
236
+ assert exc .value .status_code == status .HTTP_500_INTERNAL_SERVER_ERROR
237
+ assert "Error calling OpenAI API" in str (exc .value .detail )
238
+
239
+ # Test cases for model retrieval
240
+ class TestOpenAI_GetModel :
241
+ def test_get_model_success (self , mock_openai ):
242
+ """Test successful model retrieval with valid model ID."""
243
+ mock_openai .Model .retrieve .return_value = MagicMock (
244
+ id = "model-id" ,
245
+ object = "model" ,
246
+ created = 1678886400 ,
247
+ owned_by = "user-id" ,
248
+ permissions = [{"id" : "permission-id" , "allow" : ["fine-tune" , "use" ], "created" : 1678886400 }],
249
+ root = "model-id" ,
250
+ parent = "model-id" ,
251
+ is_moderation_model = False ,
252
+ is_text_search_model = False ,
253
+ is_translation_model = False ,
254
+ is_code_model = False ,
255
+ is_embedding_model = False ,
256
+ is_question_answering_model = False ,
257
+ is_completion_model = True ,
258
+ is_chat_completion_model = False ,
259
+ is_fine_tuned = False ,
260
+ is_available = True ,
261
+ )
262
+ service = OpenAIService ()
263
+ response = service .get_model (model_id = "model-id" )
264
+ assert response .id == "model-id"
265
+ assert response .is_completion_model
266
+ assert response .is_available
267
+ mock_openai .Model .retrieve .assert_called_once_with ("model-id" )
268
+
269
+ def test_get_model_invalid_model_id (self , mock_openai ):
270
+ """Test handling of invalid model ID."""
271
+ mock_openai .Model .retrieve .side_effect = openai .error .InvalidRequestError ("Invalid model ID." )
272
+ service = OpenAIService ()
273
+ with pytest .raises (HTTPException ) as exc :
274
+ service .get_model (model_id = "invalid-model-id" )
275
+ assert exc .value .status_code == status .HTTP_400_BAD_REQUEST
276
+ assert "Invalid request to OpenAI API" in str (exc .value .detail )
277
+
278
+ def test_get_model_rate_limit (self , mock_openai ):
279
+ """Test handling of OpenAI API rate limit during model retrieval."""
280
+ mock_openai .Model .retrieve .side_effect = openai .error .RateLimitError ("Rate limit exceeded." )
281
+ service = OpenAIService ()
282
+ with pytest .raises (HTTPException ) as exc :
283
+ service .get_model (model_id = "model-id" )
284
+ assert exc .value .status_code == status .HTTP_429_TOO_MANY_REQUESTS
285
+ assert "OpenAI API rate limit exceeded" in str (exc .value .detail )
286
+
287
+ def test_get_model_authentication_error (self , mock_openai ):
288
+ """Test handling of authentication error during model retrieval."""
289
+ mock_openai .Model .retrieve .side_effect = openai .error .AuthenticationError ("Invalid API key." )
290
+ service = OpenAIService ()
291
+ with pytest .raises (HTTPException ) as exc :
292
+ service .get_model (model_id = "model-id" )
293
+ assert exc .value .status_code == status .HTTP_401_UNAUTHORIZED
294
+ assert "Invalid OpenAI API key" in str (exc .value .detail )
295
+
296
+ def test_get_model_timeout_error (self , mock_openai ):
297
+ """Test handling of timeout error during model retrieval."""
298
+ mock_openai .Model .retrieve .side_effect = openai .error .TimeoutError ("Request timed out." )
299
+ service = OpenAIService ()
300
+ with pytest .raises (HTTPException ) as exc :
301
+ service .get_model (model_id = "model-id" )
302
+ assert exc .value .status_code == status .HTTP_504_GATEWAY_TIMEOUT
303
+ assert "Request to OpenAI API timed out" in str (exc .value .detail )
304
+
305
+ def test_get_model_connection_error (self , mock_openai ):
306
+ """Test handling of connection error during model retrieval."""
307
+ mock_openai .Model .retrieve .side_effect = openai .error .APIConnectionError ("Connection error." )
308
+ service = OpenAIService ()
309
+ with pytest .raises (HTTPException ) as exc :
310
+ service .get_model (model_id = "model-id" )
311
+ assert exc .value .status_code == status .HTTP_500_INTERNAL_SERVER_ERROR
312
+ assert "Error connecting to OpenAI API" in str (exc .value .detail )
313
+
314
+ def test_get_model_general_api_error (self , mock_openai ):
315
+ """Test handling of general API error during model retrieval."""
316
+ mock_openai .Model .retrieve .side_effect = openai .error .APIError ("General API error." )
317
+ service = OpenAIService ()
318
+ with pytest .raises (HTTPException ) as exc :
319
+ service .get_model (model_id = "model-id" )
320
+ assert exc .value .status_code == status .HTTP_500_INTERNAL_SERVER_ERROR
321
+ assert "Error calling OpenAI API" in str (exc .value .detail )
0 commit comments