From 2c0a6f798793eb259893bdf084c3ea3a48313340 Mon Sep 17 00:00:00 2001 From: Massimiliano Pronesti Date: Sun, 28 May 2023 11:29:33 +0200 Subject: [PATCH] fix(hf): wrong concatenation of prompt and suffix (#179) * fix(hf): wrong concatenation of prompt and suffix * test: enforce wrong concatenation of prompt and suffix --------- Co-authored-by: Gabriele Venturi --- pandasai/llm/base.py | 5 +++-- tests/llms/test_base_hf.py | 19 +++++++++++++++++++ 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/pandasai/llm/base.py b/pandasai/llm/base.py index 4c143a6ca..5a0bf553e 100644 --- a/pandasai/llm/base.py +++ b/pandasai/llm/base.py @@ -226,7 +226,8 @@ def query(self, payload): def call(self, instruction: Prompt, value: str, suffix: str = "") -> str: """Call the LLM""" - payload = instruction + value + suffix + prompt = str(instruction) + payload = prompt + value + suffix # sometimes the API doesn't return a valid response, so we retry passing the # output generated from the previous call as the input @@ -237,7 +238,7 @@ def call(self, instruction: Prompt, value: str, suffix: str = "") -> str: break # replace instruction + value from the inputs to avoid showing it in the output - output = response.replace(instruction + value + suffix, "") + output = response.replace(prompt + value + suffix, "") return output diff --git a/tests/llms/test_base_hf.py b/tests/llms/test_base_hf.py index cc9940091..2482b81d5 100644 --- a/tests/llms/test_base_hf.py +++ b/tests/llms/test_base_hf.py @@ -4,6 +4,7 @@ import requests from pandasai.llm.base import HuggingFaceLLM +from pandasai.prompts.base import Prompt class TestBaseHfLLM: @@ -51,3 +52,21 @@ def test_call(self, mocker): result = huggingface.call("instruction", "value", "suffix") assert result == "Generated text" + + def test_call_removes_original_prompt(self, mocker): + huggingface = HuggingFaceLLM() + huggingface.api_token = "test_token" + + class MockPrompt(Prompt): + text: str = "instruction " + + instruction = MockPrompt() + value = "value " + suffix = "suffix " + + mocker.patch.object( + huggingface, "query", return_value="instruction value suffix generated text" + ) + + result = huggingface.call(instruction, value, suffix) + assert result == "generated text"