Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 10 additions & 12 deletions skllm/models/gpt_zero_shot_clf.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,16 +98,11 @@ def _get_prompt(self, x) -> str:
def _predict_single(self, x):
completion = self._get_chat_completion(x)
try:
if self.openai_model.startswith("gpt4all::"):
label = str(
extract_json_key(
completion["choices"][0]["message"]["content"], "label"
)
)
else:
label = str(
extract_json_key(completion.choices[0].message["content"], "label")
label = str(
extract_json_key(
completion["choices"][0]["message"]["content"], "label"
)
)
except Exception as e:
print(completion)
print(f"Could not extract the label from the completion: {str(e)}")
Expand Down Expand Up @@ -155,10 +150,13 @@ def _get_prompt(self, x) -> str:
def _predict_single(self, x):
completion = self._get_chat_completion(x)
try:
labels = extract_json_key(completion.choices[0].message["content"], "label")
labels = extract_json_key(completion["choices"][0]["message"]["content"], "label")
if not isinstance(labels, list):
raise RuntimeError("Invalid labels type, expected list")
except Exception:
labels = labels.split(",")
labels = [l.strip() for l in labels]
except Exception as e:
print(completion)
print(f"Could not extract the label from the completion: {str(e)}")
labels = []

labels = list(filter(lambda l: l in self.classes_, labels))
Expand Down
24 changes: 15 additions & 9 deletions skllm/openai/chatgpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,18 @@ def get_chat_completion(messages, key, org, model="gpt-3.5-turbo", max_retries=3


def extract_json_key(json_, key):
    """Return the value stored under ``key`` in a JSON object embedded in text.

    Up to two parsing attempts are made on ``json_`` with newlines removed:
    first on the text as-is, then with every single quote replaced by a
    double quote (presumably to tolerate completions that use Python-style
    quoting instead of strict JSON -- confirm against callers).

    The JSON substring is located by ``find_json_in_string``, which is
    defined elsewhere in this module.

    Parameters
    ----------
    json_ : str
        Text expected to contain a JSON object.
    key : str
        Key to look up in the parsed object.

    Returns
    -------
    Any or None
        The value at ``key`` from the first attempt that parses successfully
        and contains the key; ``None`` if both attempts fail for any reason
        (including ``json_`` not being a string).
    """
    for normalize_quotes in (False, True):
        try:
            candidate = json_.replace("\n", "")
            if normalize_quotes:
                # Second pass only: strict JSON requires double quotes.
                candidate = candidate.replace("'", '"')
            as_json = json.loads(find_json_in_string(candidate))
            if not isinstance(as_json, dict) or key not in as_json:
                raise KeyError("The required key was not found")
            return as_json[key]
        except Exception:
            # Deliberately broad: this helper is best-effort and callers rely
            # on ``None`` (not an exception) to signal "nothing extracted".
            continue
    return None
2 changes: 1 addition & 1 deletion skllm/prompts/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
Perform the following tasks:
1. Identify to which categories the provided text belongs to with the highest probability.
2. Assign the text sample to at least 1 but up to {max_cats} categories based on the probabilities.
3. Provide your response in a JSON format containing a single key `label` and a value corresponding to the list of assigned categories. Do not provide any additional information except the JSON.
3. Provide your response in a JSON format containing a single key `label` and a value corresponding to the array of assigned categories. Do not provide any additional information except the JSON.

List of categories: {labels}

Expand Down