diff --git a/skllm/openai/chatgpt.py b/skllm/openai/chatgpt.py index 9becebd..416d23a 100644 --- a/skllm/openai/chatgpt.py +++ b/skllm/openai/chatgpt.py @@ -19,10 +19,10 @@ def get_chat_completion(messages, key, org, model="gpt-3.5-turbo", max_retries = continue def extract_json_key(json_, key): - try: - as_json = json.loads(json_.replace('\n', '')) + try: + as_json = json.loads(json_.replace('\n', '')) if key not in as_json.keys(): raise KeyError("The required key was not found") return as_json[key] except Exception as e: - return None \ No newline at end of file + return None diff --git a/skllm/tests/__init__.py b/skllm/tests/__init__.py new file mode 100644 index 0000000..19f714a --- /dev/null +++ b/skllm/tests/__init__.py @@ -0,0 +1,2 @@ +from . import test_chatgpt +from . import test_gpt_zero_shot_clf diff --git a/skllm/tests/test_chatgpt.py b/skllm/tests/test_chatgpt.py new file mode 100644 index 0000000..12706cf --- /dev/null +++ b/skllm/tests/test_chatgpt.py @@ -0,0 +1,42 @@ +import unittest +from unittest.mock import patch +from skllm.openai.chatgpt import construct_message, get_chat_completion, extract_json_key + +class TestChatGPT(unittest.TestCase): + + @patch("skllm.openai.credentials.set_credentials") + @patch("openai.ChatCompletion.create") + def test_get_chat_completion(self, mock_create, mock_set_credentials): + messages = [{"role": "system", "content": "Hello"}] + key = "some_key" + org = "some_org" + model = "gpt-3.5-turbo" + mock_create.side_effect = [Exception("API error"), "success"] + + result = get_chat_completion(messages, key, org, model) + + self.assertTrue(mock_set_credentials.call_count <= 1, "set_credentials should be called at most once") + self.assertEqual(mock_create.call_count, 2, "ChatCompletion.create should be called twice due to an exception on the first call") + self.assertEqual(result, "success") + + def test_construct_message(self): + role = "user" + content = "Hello, World!" + message = construct_message(role, content) + self.assertEqual(message, {"role": role, "content": content}) + with self.assertRaises(ValueError): + construct_message("invalid_role", content) + + def test_extract_json_key(self): + json_ = '{"key": "value"}' + key = "key" + result = extract_json_key(json_, key) + self.assertEqual(result, "value") + + # Given that the function returns None when a KeyError occurs, adjust the assertion + result_with_invalid_key = extract_json_key(json_, "invalid_key") + self.assertEqual(result_with_invalid_key, None) + + +if __name__ == '__main__': + unittest.main() diff --git a/skllm/tests/test_gpt_zero_shot_clf.py b/skllm/tests/test_gpt_zero_shot_clf.py new file mode 100644 index 0000000..b48da9c --- /dev/null +++ b/skllm/tests/test_gpt_zero_shot_clf.py @@ -0,0 +1,42 @@ +import unittest +import json +from unittest.mock import patch, MagicMock +import numpy as np +from skllm.models.gpt_zero_shot_clf import ZeroShotGPTClassifier, MultiLabelZeroShotGPTClassifier + +class TestZeroShotGPTClassifier(unittest.TestCase): + + @patch("skllm.models.gpt_zero_shot_clf.get_chat_completion", return_value=MagicMock()) + def test_fit_predict(self, mock_get_chat_completion): + clf = ZeroShotGPTClassifier(openai_key="mock_key", openai_org="mock_org") # Mock keys + X = np.array(["text1", "text2", "text3"]) + y = np.array(["class1", "class2", "class1"]) + clf.fit(X, y) + + self.assertEqual(set(clf.classes_), set(["class1", "class2"])) + self.assertEqual(clf.probabilities_, [2/3, 1/3]) + + mock_get_chat_completion.return_value.choices[0].message = {"content": json.dumps({"label": "class1"})} + predictions = clf.predict(X) + + self.assertEqual(predictions, ["class1", "class1", "class1"]) + +class TestMultiLabelZeroShotGPTClassifier(unittest.TestCase): + + @patch("skllm.models.gpt_zero_shot_clf.get_chat_completion", return_value=MagicMock()) + def test_fit_predict(self, mock_get_chat_completion): + clf = MultiLabelZeroShotGPTClassifier(openai_key="mock_key", openai_org="mock_org") # Mock keys + X = np.array(["text1", "text2", "text3"]) + y = [["class1", "class2"], ["class1", "class2"], ["class1", "class2"]] # Adjusted y to ensure [0.5, 0.5] probability + clf.fit(X, y) + + self.assertEqual(set(clf.classes_), set(["class1", "class2"])) + self.assertEqual(clf.probabilities_, [0.5, 0.5]) + + mock_get_chat_completion.return_value.choices[0].message = {"content": json.dumps({"label": ["class1", "class2"]})} + predictions = clf.predict(X) + + self.assertEqual(predictions, [["class1", "class2"], ["class1", "class2"], ["class1", "class2"]]) + +if __name__ == '__main__': + unittest.main()