From 00490c4dd08fe80408acf45b9d62d83bd69a7fc0 Mon Sep 17 00:00:00 2001 From: Oleh Date: Tue, 30 May 2023 14:14:58 +0200 Subject: [PATCH 1/2] gpt4all support --- CONTRIBUTING.md | 3 +- README.md | 52 ++++++++++++++++++++++--------- pyproject.toml | 7 ++--- skllm/completions.py | 13 ++++++++ skllm/gpt4all_client.py | 23 ++++++++++++++ skllm/models/gpt_zero_shot_clf.py | 28 +++++++++++------ skllm/openai/chatgpt.py | 27 +++++++++++----- skllm/utils.py | 17 ++++++++-- 8 files changed, 129 insertions(+), 41 deletions(-) create mode 100644 skllm/completions.py create mode 100644 skllm/gpt4all_client.py diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index c2d3cd8..991a0ee 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -16,6 +16,7 @@ There are several ways you can contribute to this project: **Important:** before contributing, we recommend that you open an issue to discuss your planned changes. This allows us to align our goals, provide guidance, and potentially find other contributors interested in collaborating on the same feature or bug fix. > ### Legal Notice +> > When contributing to this project, you must agree that you have authored 100% of the content, that you have the necessary rights to the content and that the content you contribute may be provided under the project license. ## Development dependencies @@ -23,7 +24,7 @@ There are several ways you can contribute to this project: In order to install all development dependencies, run the following command: ```shell -pip install -e ".[dev]" +pip install -r requirements-dev.txt ``` To ensure that you follow the development workflow, please setup the pre-commit hooks: diff --git a/README.md b/README.md index 7da6b27..0467e22 100644 --- a/README.md +++ b/README.md @@ -25,7 +25,14 @@ You can support the project in the following ways: ### Configuring OpenAI API Key -At the moment Scikit-LLM is only compatible with some of the OpenAI models. Hence, a user-provided OpenAI API key is required. 
+At the moment the majority of the Scikit-LLM estimators are only compatible with some of the OpenAI models. Hence, a user-provided OpenAI API key is required. + +```python +from skllm.config import SKLLMConfig + +SKLLMConfig.set_openai_key("") +SKLLMConfig.set_openai_org("") +``` ```python from skllm.config import SKLLMConfig @@ -39,6 +46,35 @@ SKLLMConfig.set_openai_org("") - If you have a free trial OpenAI account, the [rate limits](https://platform.openai.com/docs/guides/rate-limits/overview) are not sufficient (specifically 3 requests per minute). Please switch to the "pay as you go" plan first. - When calling `SKLLMConfig.set_openai_org`, you have to provide your organization ID and **NOT** the name. You can find your ID [here](https://platform.openai.com/account/org-settings). +### Using GPT4ALL + +In addition to OpenAI, some of the models can use [gpt4all](https://gpt4all.io/index.html) as a backend. + +**This feature is considered highly experimental!** + +In order to use gpt4all, you need to install the corresponding submodule: + +```bash +pip install scikit-llm[gpt4all] +``` + +In order to switch from OpenAI to a GPT4ALL model, simply provide a string of the format `gpt4all::<model_name>` as an argument. While the model runs completely locally, the estimator still treats it as an OpenAI endpoint and will try to check that the API key is present. You can provide any string as a key. + +```python +SKLLMConfig.set_openai_key("any string") +SKLLMConfig.set_openai_org("any string") + +ZeroShotGPTClassifier(openai_model="gpt4all::ggml-gpt4all-j-v1.3-groovy") +``` + +When running for the first time, the model file will be downloaded automatially. + +When using gpt4all please keep the following in mind: + +1. Not all gpt4all models are commercially licensable, please consult gpt4all website for more details. +2. The accuracy of the models may be much lower compared to ones provided by OpenAI (especially gpt-4). +3. 
Not all of the available models were tested, some may not work with scikit-llm at all. + ### Zero-Shot Text Classification One of the powerful ChatGPT features is the ability to perform text classification without being re-trained. For that, the only requirement is that the labels must be descriptive. @@ -207,17 +243,3 @@ Please be aware that the `max_words` hyperparameter sets a soft limit, which is - [ ] Open source models *The order of the elements in the roadmap is arbitrary and does not reflect the planned order of implementation.* - -## Contributing - -In order to install all development dependencies, run the following command: - -```shell -pip install -e ".[dev]" -``` - -To ensure that you follow the development workflow, please setup the pre-commit hooks: - -```shell -pre-commit install -``` diff --git a/pyproject.toml b/pyproject.toml index be5c561..2d904ff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ dependencies = [ "tqdm>=4.60.0", ] name = "scikit-llm" -version = "0.1.0b3" +version = "0.1.0" authors = [ { name="Oleg Kostromin", email="kostromin97@gmail.com" }, { name="Iryna Kondrashchenko", email="iryna230520@gmail.com" }, @@ -24,10 +24,9 @@ classifiers = [ "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", ] -dynamic = ["optional-dependencies"] -[tool.setuptools.dynamic.optional-dependencies] -dev = { file = ["requirements-dev.txt"] } +[project.optional-dependencies] +gpt4all = ["gpt4all>=0.2.0"] [tool.ruff] select = [ diff --git a/skllm/completions.py b/skllm/completions.py new file mode 100644 index 0000000..02e07df --- /dev/null +++ b/skllm/completions.py @@ -0,0 +1,13 @@ +from skllm.gpt4all_client import get_chat_completion as _g4a_get_chat_completion +from skllm.openai.chatgpt import get_chat_completion as _oai_get_chat_completion + + +def get_chat_completion( + messages, openai_key=None, openai_org=None, model="gpt-3.5-turbo", max_retries=3 +): + if model.startswith("gpt4all::"): + return 
_g4a_get_chat_completion(messages, model[9:]) + else: + return _oai_get_chat_completion( + messages, openai_key, openai_org, model, max_retries + ) diff --git a/skllm/gpt4all_client.py b/skllm/gpt4all_client.py new file mode 100644 index 0000000..e8ef0ae --- /dev/null +++ b/skllm/gpt4all_client.py @@ -0,0 +1,23 @@ +try: + from gpt4all import GPT4All +except (ImportError, ModuleNotFoundError): + GPT4All = None + +_loaded_models = {} + + +def get_chat_completion(messages, model="ggml-gpt4all-j-v1.3-groovy"): + if GPT4All is None: + raise ImportError( + "gpt4all is not installed, try `pip install scikit-llm[gpt4all]`" + ) + if model not in _loaded_models.keys(): + _loaded_models[model] = GPT4All(model) + return _loaded_models[model].chat_completion( + messages, verbose=False, streaming=False, temp=1e-10 + ) + + +def unload_models(): + global _loaded_models + _loaded_models = {} diff --git a/skllm/models/gpt_zero_shot_clf.py b/skllm/models/gpt_zero_shot_clf.py index b57f700..fe64cfb 100644 --- a/skllm/models/gpt_zero_shot_clf.py +++ b/skllm/models/gpt_zero_shot_clf.py @@ -8,11 +8,8 @@ from sklearn.base import BaseEstimator, ClassifierMixin from tqdm import tqdm -from skllm.openai.chatgpt import ( - construct_message, - extract_json_key, - get_chat_completion, -) +from skllm.completions import get_chat_completion +from skllm.openai.chatgpt import construct_message, extract_json_key from skllm.openai.mixin import OpenAIMixin as _OAIMixin from skllm.prompts.builders import ( build_zero_shot_prompt_mlc, @@ -101,14 +98,25 @@ def _get_prompt(self, x) -> str: def _predict_single(self, x): completion = self._get_chat_completion(x) try: - label = str( - extract_json_key(completion.choices[0].message["content"], "label") - ) - except Exception: + if self.openai_model.startswith("gpt4all::"): + label = str( + extract_json_key( + completion["choices"][0]["message"]["content"], "label" + ) + ) + else: + label = str( + extract_json_key(completion.choices[0].message["content"], 
"label") + ) + except Exception as e: + print(completion) + print(f"Could not extract the label from the completion: {str(e)}") label = "" if label not in self.classes_: - label = random.choices(self.classes_, self.probabilities_)[0] + label = label.replace("'", "").replace('"', "") + if label not in self.classes_: # try again + label = random.choices(self.classes_, self.probabilities_)[0] return label def fit( diff --git a/skllm/openai/chatgpt.py b/skllm/openai/chatgpt.py index 15edc50..dc1d862 100644 --- a/skllm/openai/chatgpt.py +++ b/skllm/openai/chatgpt.py @@ -1,34 +1,45 @@ -import openai -from time import sleep import json +from time import sleep + +import openai + from skllm.openai.credentials import set_credentials +from skllm.utils import find_json_in_string + def construct_message(role, content): if role not in ("system", "user", "assistant"): raise ValueError("Invalid role") return {"role": role, "content": content} -def get_chat_completion(messages, key, org, model="gpt-3.5-turbo", max_retries = 3): + +def get_chat_completion(messages, key, org, model="gpt-3.5-turbo", max_retries=3): set_credentials(key, org) error_msg = None error_type = None for _ in range(max_retries): try: completion = openai.ChatCompletion.create( - model=model, temperature=0., messages=messages + model=model, temperature=0.0, messages=messages ) return completion except Exception as e: error_msg = str(e) - error_type = type(e).__name__ + error_type = type(e).__name__ sleep(3) - print(f"Could not obtain the completion after {max_retries} retries: `{error_type} :: {error_msg}`") + print( + f"Could not obtain the completion after {max_retries} retries: `{error_type} ::" + f" {error_msg}`" + ) + def extract_json_key(json_, key): try: - as_json = json.loads(json_.replace('\n', '')) + json_ = json_.replace("\n", "") + json_ = find_json_in_string(json_) + as_json = json.loads(json_) if key not in as_json.keys(): raise KeyError("The required key was not found") return as_json[key] - 
except Exception as e: + except Exception: return None diff --git a/skllm/utils.py b/skllm/utils.py index 8580124..49b2c03 100644 --- a/skllm/utils.py +++ b/skllm/utils.py @@ -1,11 +1,22 @@ -import numpy as np +import numpy as np import pandas as pd + def to_numpy(X): if isinstance(X, pd.Series): X = X.to_numpy().astype(object) elif isinstance(X, list): - X = np.asarray(X, dtype = object) + X = np.asarray(X, dtype=object) if isinstance(X, np.ndarray) and len(X.shape) > 1: X = np.squeeze(X) - return X \ No newline at end of file + return X + + +def find_json_in_string(string): + start = string.find("{") + end = string.rfind("}") + if start != -1 and end != -1: + json_string = string[start : end + 1] + else: + json_string = {} + return json_string From d43b264fddea9f26135172ac61d0ad5b8593108a Mon Sep 17 00:00:00 2001 From: Oleh Date: Tue, 30 May 2023 14:26:13 +0200 Subject: [PATCH 2/2] updated readme --- README.md | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 0467e22..21b9cf3 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,7 @@ You can support the project in the following ways: - ⭐ Star Scikit-LLM on GitHub (click the star button in the top right corner) - 🐦 Check out our related project - [Falcon AutoML](https://github.com/OKUA1/falcon) -- 💡 Provide your feedback or propose ideas in the [issues](https://github.com/iryna-kondr/scikit-llm/issues) section +- 💡 Provide your feedback or propose ideas in the [issues](https://github.com/iryna-kondr/scikit-llm/issues) section or [Discord](https://discord.gg/NTaRnRpf) - 🔗 Post about Scikit-LLM on LinkedIn or other platforms ## Documentation 📚 @@ -69,6 +69,11 @@ ZeroShotGPTClassifier(openai_model="gpt4all::ggml-gpt4all-j-v1.3-groovy") When running for the first time, the model file will be downloaded automatially. 
+At the moment only the following estimators support gpt4all as a backend: +- `ZeroShotGPTClassifier` +- `MultiLabelZeroShotGPTClassifier` +- `FewShotGPTClassifier` + When using gpt4all please keep the following in mind: 1. Not all gpt4all models are commercially licensable, please consult gpt4all website for more details.