Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions skllm/openai/chatgpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ def get_chat_completion(messages, key, org, model="gpt-3.5-turbo", max_retries =
continue

def extract_json_key(json_, key):
try:
as_json = json.loads(json_.replace('\n', ''))
try:
as_json = json.loads(json_.replace('\n', ''))
if key not in as_json.keys():
raise KeyError("The required key was not found")
return as_json[key]
except Exception as e:
return None
return None
2 changes: 2 additions & 0 deletions skllm/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from . import test_chatgpt
from . import test_gpt_zero_shot_clf
42 changes: 42 additions & 0 deletions skllm/tests/test_chatgpt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import unittest
from unittest.mock import patch
from skllm.openai.chatgpt import construct_message, get_chat_completion, extract_json_key

class TestChatGPT(unittest.TestCase):

@patch("skllm.openai.credentials.set_credentials")
@patch("openai.ChatCompletion.create")
def test_get_chat_completion(self, mock_create, mock_set_credentials):
messages = [{"role": "system", "content": "Hello"}]
key = "some_key"
org = "some_org"
model = "gpt-3.5-turbo"
mock_create.side_effect = [Exception("API error"), "success"]

result = get_chat_completion(messages, key, org, model)

self.assertTrue(mock_set_credentials.call_count <= 1, "set_credentials should be called at most once")
self.assertEqual(mock_create.call_count, 2, "ChatCompletion.create should be called twice due to an exception on the first call")
self.assertEqual(result, "success")

def test_construct_message(self):
role = "user"
content = "Hello, World!"
message = construct_message(role, content)
self.assertEqual(message, {"role": role, "content": content})
with self.assertRaises(ValueError):
construct_message("invalid_role", content)

def test_extract_json_key(self):
json_ = '{"key": "value"}'
key = "key"
result = extract_json_key(json_, key)
self.assertEqual(result, "value")

# Given that the function returns None when a KeyError occurs, adjust the assertion
result_with_invalid_key = extract_json_key(json_, "invalid_key")
self.assertEqual(result_with_invalid_key, None)


if __name__ == '__main__':
unittest.main()
42 changes: 42 additions & 0 deletions skllm/tests/test_gpt_zero_shot_clf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import unittest
import json
from unittest.mock import patch, MagicMock
import numpy as np
from skllm.models.gpt_zero_shot_clf import ZeroShotGPTClassifier, MultiLabelZeroShotGPTClassifier

class TestZeroShotGPTClassifier(unittest.TestCase):

@patch("skllm.models.gpt_zero_shot_clf.get_chat_completion", return_value=MagicMock())
def test_fit_predict(self, mock_get_chat_completion):
clf = ZeroShotGPTClassifier(openai_key="mock_key", openai_org="mock_org") # Mock keys
X = np.array(["text1", "text2", "text3"])
y = np.array(["class1", "class2", "class1"])
clf.fit(X, y)

self.assertEqual(set(clf.classes_), set(["class1", "class2"]))
self.assertEqual(clf.probabilities_, [2/3, 1/3])

mock_get_chat_completion.return_value.choices[0].message = {"content": json.dumps({"label": "class1"})}
predictions = clf.predict(X)

self.assertEqual(predictions, ["class1", "class1", "class1"])

class TestMultiLabelZeroShotGPTClassifier(unittest.TestCase):

@patch("skllm.models.gpt_zero_shot_clf.get_chat_completion", return_value=MagicMock())
def test_fit_predict(self, mock_get_chat_completion):
clf = MultiLabelZeroShotGPTClassifier(openai_key="mock_key", openai_org="mock_org") # Mock keys
X = np.array(["text1", "text2", "text3"])
y = [["class1", "class2"], ["class1", "class2"], ["class1", "class2"]] # Adjusted y to ensure [0.5, 0.5] probability
clf.fit(X, y)

self.assertEqual(set(clf.classes_), set(["class1", "class2"]))
self.assertEqual(clf.probabilities_, [0.5, 0.5])

mock_get_chat_completion.return_value.choices[0].message = {"content": json.dumps({"label": ["class1", "class2"]})}
predictions = clf.predict(X)

self.assertEqual(predictions, [["class1", "class2"], ["class1", "class2"], ["class1", "class2"]])

if __name__ == '__main__':
unittest.main()