Skip to content

Commit

Permalink
fix(hf): wrong concatenation of prompt and suffix (#179)
Browse files Browse the repository at this point in the history
* fix(hf): wrong concatenation of prompt and suffix

* test: enforce wrong concatenation of prompt and suffix

---------

Co-authored-by: Gabriele Venturi <lele.venturi@gmail.com>
  • Loading branch information
mspronesti and gventuri committed May 28, 2023
1 parent 65cc9c8 commit 2c0a6f7
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 2 deletions.
5 changes: 3 additions & 2 deletions pandasai/llm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,8 @@ def query(self, payload):
def call(self, instruction: Prompt, value: str, suffix: str = "") -> str:
"""Call the LLM"""

payload = instruction + value + suffix
prompt = str(instruction)
payload = prompt + value + suffix

# sometimes the API doesn't return a valid response, so we retry passing the
# output generated from the previous call as the input
Expand All @@ -237,7 +238,7 @@ def call(self, instruction: Prompt, value: str, suffix: str = "") -> str:
break

# replace instruction + value from the inputs to avoid showing it in the output
output = response.replace(instruction + value + suffix, "")
output = response.replace(prompt + value + suffix, "")
return output


Expand Down
19 changes: 19 additions & 0 deletions tests/llms/test_base_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import requests

from pandasai.llm.base import HuggingFaceLLM
from pandasai.prompts.base import Prompt


class TestBaseHfLLM:
Expand Down Expand Up @@ -51,3 +52,21 @@ def test_call(self, mocker):

result = huggingface.call("instruction", "value", "suffix")
assert result == "Generated text"

def test_call_removes_original_prompt(self, mocker):
huggingface = HuggingFaceLLM()
huggingface.api_token = "test_token"

class MockPrompt(Prompt):
text: str = "instruction "

instruction = MockPrompt()
value = "value "
suffix = "suffix "

mocker.patch.object(
huggingface, "query", return_value="instruction value suffix generated text"
)

result = huggingface.call(instruction, value, suffix)
assert result == "generated text"

0 comments on commit 2c0a6f7

Please sign in to comment.