In [1]:
import sys
sys.path.append('..')

### Set OpenAI key 

In [2]:
import os
import configparser

config = configparser.ConfigParser()
config.read('../../.secrets.ini')
openai_api_key = config['OPENAI']['OPENAI_API_KEY']

os.environ.update({'OPENAI_API_KEY': openai_api_key})

### Get tools

In [3]:
from models.llm.chain import GraphChain, DraftChunkChain

In [4]:
graph_template_prompt_path = '../../openai_skt/models/templates/graph_prompt_template.txt'
with open(graph_template_prompt_path, 'r') as f:
    graph_template = f.read()
graph_chain = GraphChain(graph_template=graph_template, input_variables=["graph_to_draw"])

In [5]:
draft_chunk_template_prompt_path = '../../openai_skt/models/templates/draft_chunk_prompt_template.txt'
with open(draft_chunk_template_prompt_path, 'r') as f:
    draft_chunk_template = f.read()
draft_chunk_chain = DraftChunkChain(draft_chunk_template=draft_chunk_template, input_variables=["draft", "query"])

In [6]:
from tools import DatabaseTool, DraftChunkTool, GraphTool

In [44]:
from typing import Optional, Type, Any
from pydantic import BaseModel

from langchain.tools import BaseTool

from models.llm.chain import DraftChunkChain

class DraftChunkTool(BaseTool):
    name = "draft_chunk_tool"
    description = "A tool to extract a part of draft that corresponds to the user's query."
    args_schema: Optional[Type[BaseModel]] = None
    """Pydantic model class to validate and parse the tool's input arguments."""

    draft_chunk_chain: Any
    draft: str = None

    def __init__(self, draft_chunk_chain) -> None:
        super().__init__()
        self.draft_chunk_chain = draft_chunk_chain

    def set_draft(self, draft):
        self.draft = draft

    def _run(self, query:str) -> dict:
        result = self.draft_chunk_chain.run(draft=self.draft, query=query)
        return result
    
    async def _arun(self, query:str) -> dict:
        result = await self.draft_chunk_chain.arun(draft=self.draft, query=query)
        return result

In [45]:
database_tool = DatabaseTool()
draft_chunk_tool = DraftChunkTool(draft_chunk_chain=draft_chunk_chain)
graph_tool = GraphTool(graph_chain=graph_chain)

In [63]:
tools = [database_tool, graph_tool]

### Set Agent

In [59]:
import re
from typing import List, Union

from langchain import LLMChain, OpenAI
from langchain.chat_models import ChatOpenAI
from langchain.agents import Tool, AgentExecutor, LLMSingleActionAgent, AgentOutputParser
from langchain.prompts import StringPromptTemplate
from langchain.schema import AgentAction, AgentFinish, OutputParserException
from langchain.tools import BaseTool

In [72]:
# Set up a prompt template
class CustomPromptTemplate(StringPromptTemplate):
    # The template to use
    template: str
    # The list of tools available
    tools: List[BaseTool]

    def format(self, **kwargs) -> str:
        # Get the intermediate steps (AgentAction, Observation tuples)
        # Format them in a particular way
        intermediate_steps = kwargs.pop("intermediate_steps")
        thoughts = ""
        for action, observation in intermediate_steps:
            thoughts += action.log
            thoughts += f"\nObservation: {observation}\nThought: "
        # Set the agent_scratchpad variable to that value
        kwargs["agent_scratchpad"] = thoughts
        # Create a tools variable from the list of tools provided
        kwargs["tools"] = "\n".join([f"{tool.name}: {tool.description}" for tool in self.tools])
        # Create a list of tool names for the tools provided
        kwargs["tool_names"] = ", ".join([tool.name for tool in self.tools])
        print(kwargs)
        return self.template.format(**kwargs)

class CustomOutputParser(AgentOutputParser):

    def parse(self, llm_output: str) -> Union[AgentAction, AgentFinish]:
        # Check if agent should finish
        if "Final Answer:" in llm_output:
            return AgentFinish(
                # Return values is generally always a dictionary with a single `output` key
                # It is not recommended to try anything else at the moment :)
                return_values={"output": llm_output.split("Final Answer:")[-1].strip()},
                log=llm_output,
            )
        # Parse out the action and action input
        regex = r"Action\s*\d*\s*:(.*?)\nAction\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)"
        match = re.search(regex, llm_output, re.DOTALL)
        if not match:
            raise OutputParserException(f"Could not parse LLM output: `{llm_output}`")
        action = match.group(1).strip()
        action_input = match.group(2)
        # Return the action and action input
        return AgentAction(tool=action, tool_input=action_input.strip(" ").strip('"'), log=llm_output)
    
class DraftEditAgent:
    # TODO: Database 객체로 선언 안하는 방법이 있는지 생각
    def __init__(self, tools, draft_edit_prompt_path='../openai_skt/models/templates/draft_edit_prompt_template.txt', verbose=False) -> None:
        with open(draft_edit_prompt_path, 'r') as f:
            self.draft_edit_prompt_template = f.read()
        
        self.output_parser = CustomOutputParser()
        self.verbose = verbose
        self.tools = tools

        self.draft_edit_prompt = CustomPromptTemplate(
            template=self.draft_edit_prompt_template,
            tools=self.tools,
            # This omits the `agent_scratchpad`, `tools`, and `tool_names` variables because those are generated dynamically
            # This includes the `intermediate_steps` variable because that is needed
            input_variables=["user_query", "draft", "intermediate_steps"]
        )

        self.llm = ChatOpenAI(model='gpt-3.5-turbo-16k', temperature=0, verbose=self.verbose)
        self.draft_edit_chain = LLMChain(llm=self.llm, prompt=self.draft_edit_prompt, verbose=self.verbose)
        tool_names = [tool.name for tool in self.tools]
        self.agent = LLMSingleActionAgent(
            llm_chain=self.draft_edit_chain, 
            output_parser=self.output_parser,
            stop=["\nObservation:"], 
            allowed_tools=tool_names
        )
        self.agent_executor = AgentExecutor.from_agent_and_tools(agent=self.agent, tools=self.tools, verbose=self.verbose)

    def run(self, database, draft, query):
        part_draft = draft_chunk_chain.run(draft=draft, query=query)
        input_dict = self.parse_input(database, part_draft, query)
        result = self.agent_executor.run(input_dict)
        return result

    async def arun(self, database, draft, query):
        input_dict = self.parse_input(database, draft, query)
        result = await self.agent_executor.arun(input_dict)
        return result
    
    def parse_input(self, database, draft, query):
        database_tool = self.tools[0]
        database_tool.set_database(database)
        # draft_chunk_tool = self.tools[1]
        # draft_chunk_tool.set_draft(draft)
        input_dict = {'user_query': query, 'draft': draft}
        return input_dict

In [73]:
draft_edit_agent = DraftEditAgent(tools=tools, verbose=True, draft_edit_prompt_path='../../openai_skt/models/templates/draft_edit_prompt_template.txt')

In [12]:
from database import DataBase

In [13]:
# from embedchain.embedchain import EmbedChain
# from embedchain.config import AppConfig
# embed_chain = EmbedChain(config=AppConfig())

In [None]:
# database = DataBase.load(database_path='./user/test_2/database.json', embed_chain=embed_chain)

In [51]:
with open('./user/test_2/draft_0.md', 'r') as f:
    draft = f.read()

In [74]:
draft_edit_agent.run(database=database, draft=draft, query="가장 상단에 비트코인 가격 차트를 그려줘")



[1m> Entering new AgentExecutor chain...[0m


[1m> Entering new LLMChain chain...[0m
{'user_query': '가장 상단에 비트코인 가격 차트를 그려줘', 'draft': '하지만 비트코인은 가치의 변동성과 부정적인 영향을 가지고 있습니다. 비트코인의 가격은 수요와 공급에 따라 결정되기 때문에 시장 상황에 따라 큰 폭으로 변동할 수 있습니다. 이러한 가치의 변동성은 투자자에게 큰 위험을 안고 있으며, 예측하기 어렵다는 단점을 가지고 있습니다.', 'agent_scratchpad': '', 'tools': 'database: A tool to extract data from a database with a query\ngraph_tool: A tool to draw a graph. It return image path of the graph.', 'tool_names': 'database, graph_tool'}
Prompt after formatting:
[32;1m[1;3mYou should modify the draft according to the user's requirements. You have access to the following tools:

database: A tool to extract data from a database with a query
graph_tool: A tool to draw a graph. It return image path of the graph.

Use the following format:

Draft: the input draft
Requirements: the input user requirements
Thought: you should always think about what to do
Action: the action to take, should be one of [database, graph_tool]
Action

'비트코인 가격 예측에 대한 웹 페이지의 토큰 수를 나타내는 차트입니다.'