From 16f8f2660df5f2578da67d5463fa1113a6f3e4f7 Mon Sep 17 00:00:00 2001 From: Oleg Kostromin Date: Sun, 23 Jul 2023 19:34:54 +0200 Subject: [PATCH] Added support for gpt4all>=1.0 --- README.md | 8 ++++---- pyproject.toml | 2 +- skllm/gpt4all_client.py | 26 ++++++++++++++++++++------ 3 files changed, 25 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index 001046e..7b08ffe 100644 --- a/README.md +++ b/README.md @@ -48,7 +48,7 @@ SKLLMConfig.set_openai_org("") ```python from skllm.config import SKLLMConfig -SKLLMConfig.set_openai_key("") #use azure key instead +SKLLMConfig.set_openai_key("") # use azure key instead SKLLMConfig.set_azure_api_base("") # start with "azure::" prefix when setting the model name @@ -76,7 +76,7 @@ In order to switch from OpenAI to GPT4ALL model, simply provide a string of the SKLLMConfig.set_openai_key("any string") SKLLMConfig.set_openai_org("any string") -ZeroShotGPTClassifier(openai_model="gpt4all::ggml-gpt4all-j-v1.3-groovy") +ZeroShotGPTClassifier(openai_model="gpt4all::ggml-model-gpt4all-falcon-q4_0.bin") ``` When running for the first time, the model file will be downloaded automatially. @@ -225,11 +225,12 @@ labels = clf.predict(X) ### Text Classification with Google PaLM 2 At the moment 3 PaLM based models are available in test mode: + - `ZeroShotPaLMClassifier` - zero-shot text classification with PaLM 2; - `PaLMClassifier` - fine-tunable text classifier with PaLM 2; - `PaLM` - fine-tunable estimator that can be trained on arbitrary text input-output pairs. 
-Example: +Example: ```python from skllm.models.palm import PaLMClassifier @@ -311,4 +312,3 @@ X = get_translation_dataset() t = GPTTranslator(openai_model="gpt-3.5-turbo", output_language="English") translated_text = t.fit_transform(X) ``` - diff --git a/pyproject.toml b/pyproject.toml index edb1bf2..806a150 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,7 @@ classifiers = [ ] [project.optional-dependencies] -gpt4all = ["gpt4all>=0.2.0"] +gpt4all = ["gpt4all>=1.0.0"] [tool.ruff] select = [ diff --git a/skllm/gpt4all_client.py b/skllm/gpt4all_client.py index 2331463..03f360a 100644 --- a/skllm/gpt4all_client.py +++ b/skllm/gpt4all_client.py @@ -8,9 +8,14 @@ _loaded_models = {} -def get_chat_completion(messages: Dict, model: str="ggml-gpt4all-j-v1.3-groovy") -> Dict: - """ - Gets a chat completion from GPT4All +def _make_openai_compatabile(message: str) -> Dict: + return {"choices": [{"message": {"content": message, "role": "assistant"}}]} + + +def get_chat_completion( + messages: Dict, model: str = "ggml-model-gpt4all-falcon-q4_0.bin" +) -> Dict: + """Gets a chat completion from GPT4All. Parameters ---------- @@ -28,11 +33,20 @@ def get_chat_completion(messages: Dict, model: str="ggml-gpt4all-j-v1.3-groovy") "gpt4all is not installed, try `pip install scikit-llm[gpt4all]`" ) if model not in _loaded_models.keys(): - _loaded_models[model] = GPT4All(model) + loaded_model = GPT4All(model) + _loaded_models[model] = loaded_model + loaded_model._current_prompt_template = loaded_model.config["promptTemplate"] - return _loaded_models[model].chat_completion( - messages, verbose=False, streaming=False, temp=1e-10 + prompt = _loaded_models[model]._format_chat_prompt_template( + messages, _loaded_models[model].config["systemPrompt"] ) + generated = _loaded_models[model].generate( + prompt, + streaming=False, + temp=1e-10, + ) + + return _make_openai_compatabile(generated) def unload_models() -> None: