-
Notifications
You must be signed in to change notification settings - Fork 1.1k
/
prompt_generation.py
68 lines (55 loc) · 2.21 KB
/
prompt_generation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from typing import Any, Union
from pandasai.pipelines.logic_unit_output import LogicUnitOutput
from ...helpers.logger import Logger
from ...prompts.base import BasePrompt
from ...prompts.generate_python_code import GeneratePythonCodePrompt
from ...prompts.generate_python_code_with_sql import GeneratePythonCodeWithSQLPrompt
from ..base_logic_unit import BaseLogicUnit
from ..pipeline_context import PipelineContext
class PromptGeneration(BaseLogicUnit):
    """
    Code Prompt Generation Stage.

    Builds the prompt used to ask the LLM for Python code. The SQL-aware
    prompt is used when ``context.config.direct_sql`` is truthy; otherwise
    the plain Python code prompt is used.
    """

    def execute(self, input: Any, **kwargs) -> LogicUnitOutput:
        """
        Generate the code-generation prompt for the current pipeline context.

        :param input: Output of the previous pipeline stage; not read here.
        :param kwargs: Keyword arguments. Only these are read:
            - 'context' (PipelineContext): the pipeline execution context
              the prompt is built from. Stored on ``self.context``.
            - 'logger' (Logger, optional): used to log the generated
              prompt. Stored on ``self.logger``; logging is skipped when
              it is absent.
        :return: A LogicUnitOutput wrapping the prompt object, a success
            flag (always True), a status message, and a dict holding the
            prompt rendered via ``prompt.to_string()``.
        """
        self.context: PipelineContext = kwargs.get("context")
        self.logger: Logger = kwargs.get("logger")

        prompt = self.get_chat_prompt(self.context)

        # The logger is optional: without this guard a missing 'logger'
        # kwarg raised AttributeError after the prompt was already built.
        if self.logger is not None:
            self.logger.log(f"Using prompt: {prompt}")

        return LogicUnitOutput(
            prompt,
            True,
            "Prompt Generated Successfully",
            {"content_type": "prompt", "value": prompt.to_string()},
        )

    def get_chat_prompt(self, context: PipelineContext) -> BasePrompt:
        """
        Build the code-generation prompt for ``context``.

        :param context: Pipeline context. Reads ``config.data_viz_library``,
            ``config.direct_sql``, and the context values
            ``"last_code_generated"`` and ``"output_type"``.
        :return: A GeneratePythonCodeWithSQLPrompt when ``config.direct_sql``
            is truthy, otherwise a GeneratePythonCodePrompt. Both receive
            the same keyword arguments.
        """
        # matplotlib is the default when no visualization library is set.
        viz_lib = context.config.data_viz_library or "matplotlib"

        # Both prompt types take identical constructor arguments, so pick
        # the class first and build it once.
        prompt_cls = (
            GeneratePythonCodeWithSQLPrompt
            if context.config.direct_sql
            else GeneratePythonCodePrompt
        )
        return prompt_cls(
            context=context,
            last_code_generated=context.get("last_code_generated"),
            viz_lib=viz_lib,
            output_type=context.get("output_type"),
        )