diff --git a/skllm/models/gpt_zero_shot_clf.py b/skllm/models/gpt_zero_shot_clf.py index fe64cfb..fa18746 100644 --- a/skllm/models/gpt_zero_shot_clf.py +++ b/skllm/models/gpt_zero_shot_clf.py @@ -98,16 +98,11 @@ def _get_prompt(self, x) -> str: def _predict_single(self, x): completion = self._get_chat_completion(x) try: - if self.openai_model.startswith("gpt4all::"): - label = str( - extract_json_key( - completion["choices"][0]["message"]["content"], "label" - ) - ) - else: - label = str( - extract_json_key(completion.choices[0].message["content"], "label") + label = str( + extract_json_key( + completion["choices"][0]["message"]["content"], "label" ) + ) except Exception as e: print(completion) print(f"Could not extract the label from the completion: {str(e)}") @@ -155,10 +150,13 @@ def _get_prompt(self, x) -> str: def _predict_single(self, x): completion = self._get_chat_completion(x) try: - labels = extract_json_key(completion.choices[0].message["content"], "label") + labels = extract_json_key(completion["choices"][0]["message"]["content"], "label") if not isinstance(labels, list): - raise RuntimeError("Invalid labels type, expected list") - except Exception: + labels = labels.split(",") + labels = [l.strip() for l in labels] + except Exception as e: + print(completion) + print(f"Could not extract the label from the completion: {str(e)}") labels = [] labels = list(filter(lambda l: l in self.classes_, labels)) diff --git a/skllm/openai/chatgpt.py b/skllm/openai/chatgpt.py index dc1d862..3a73004 100644 --- a/skllm/openai/chatgpt.py +++ b/skllm/openai/chatgpt.py @@ -34,12 +34,18 @@ def get_chat_completion(messages, key, org, model="gpt-3.5-turbo", max_retries=3 def extract_json_key(json_, key): - try: - json_ = json_.replace("\n", "") - json_ = find_json_in_string(json_) - as_json = json.loads(json_) - if key not in as_json.keys(): - raise KeyError("The required key was not found") - return as_json[key] - except Exception: - return None + original_json = json_ + for i in range(2): + try: + json_ = original_json.replace("\n", "") + if i == 1: + json_ = json_.replace("'", "\"") + json_ = find_json_in_string(json_) + as_json = json.loads(json_) + if key not in as_json.keys(): + raise KeyError("The required key was not found") + return as_json[key] + except Exception: + if i == 0: + continue + return None \ No newline at end of file diff --git a/skllm/prompts/templates.py b/skllm/prompts/templates.py index cd05586..deecf80 100644 --- a/skllm/prompts/templates.py +++ b/skllm/prompts/templates.py @@ -44,7 +44,7 @@ Perform the following tasks: 1. Identify to which categories the provided text belongs to with the highest probability. 2. Assign the text sample to at least 1 but up to {max_cats} categories based on the probabilities. -3. Provide your response in a JSON format containing a single key `label` and a value corresponding to the list of assigned categories. Do not provide any additional information except the JSON. +3. Provide your response in a JSON format containing a single key `label` and a value corresponding to the array of assigned categories. Do not provide any additional information except the JSON. List of categories: {labels}