In [1]:
from pyprojroot import here
from langchain_community.utilities import SQLDatabase
from langchain.chains import create_sql_query_chain
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from operator import itemgetter
from langchain_openai import ChatOpenAI
import os
from dotenv import load_dotenv
# Read a local .env file (if one is found) into os.environ so that settings such as
# OPENAI_API_KEY are available to the cells below.
load_dotenv()

True

**Set the environment variables and load the LLM**

In [2]:
# load_dotenv() has already copied the .env entries into os.environ, so the key only
# needs to be validated here. Writing os.getenv(...) back into os.environ is a no-op
# when the key exists and raises "TypeError: str expected, not NoneType" when it does
# not, so fail with a readable message instead.
if not os.getenv("OPENAI_API_KEY"):
    raise RuntimeError(
        "OPENAI_API_KEY is not set. Add it to your .env file or export it "
        "before running this notebook."
    )

# Chat model used by the test chain below.
# Other models that were tried here: "gpt-4o-mini", "gpt-4o".
LLM_MODEL = "gpt-4"
llm = ChatOpenAI(model=LLM_MODEL)

**Load and test the SQLite database**

In [3]:
# Location of the SQLite file, resolved relative to the project root.
sqldb_directory = here("data/ICDD-energy.db")

# SQLite creates an empty database when the target file does not exist, which would
# only show up later as confusing "no such table" errors. Fail early instead.
if not os.path.isfile(sqldb_directory):
    raise FileNotFoundError(f"SQLite database file not found: {sqldb_directory}")

db = SQLDatabase.from_uri(
    f"sqlite:///{sqldb_directory}")

# Sanity checks: SQL dialect, visible tables, and a small sample from one table.
print(db.dialect)
print(db.get_usable_table_names())
db.run("SELECT * FROM Comments LIMIT 10;")

# Alternative inspection with plain SQLAlchemy (reference only, not executed):
# from sqlalchemy import create_engine, inspect
# from sqlalchemy.orm import sessionmaker
# engine = create_engine(f"sqlite:///{sqldb_directory}")

# # Create a session
# Session = sessionmaker(bind=engine)
# session = Session()

# # Use SQLAlchemy's Inspector to get database information
# inspector = inspect(engine)

# # Get table names
# tables = inspector.get_table_names()
# print("Tables in the database:", tables)
# print(len(tables))

sqlite
['Applications', 'Comments', 'Compositions', 'Melting Points']


"[(353, 10359, None, None, 'Decomposes at 1750 C'), (494, 10503, None, None, None), (569, 10579, None, None, None), (618, 10628, None, None, 'Transition to calcite at 520. Antacid'), (636, 10646, None, 'A new method of X-Ray crystal analysis', None), (667, 10677, None, None, None), (696, 10706, None, None, None), (738, 10750, None, None, None), (780, 10792, None, None, None), (784, 10796, None, None, None)]"

**Create the SQL agent chain and run a test query**

In [4]:
# Prompt for the last stage: the model sees the question, the generated SQL and the
# raw SQL result, and is asked to phrase the final answer.
system_role = """Given the following user question, corresponding SQL query, and SQL result, answer the user question.\n
    Question: {question}\n
    SQL Query: {query}\n
    SQL Result: {result}\n
    Answer:
    """

# Stage 1: the LLM writes a SQL query for `db` from the user question.
write_query = create_sql_query_chain(llm, db)

# Stage 2: tool that runs a SQL string against `db`.
execute_query = QuerySQLDataBaseTool(db=db)

# Stage 3: fill the prompt, call the LLM, and reduce the reply to a plain string.
answer_prompt = PromptTemplate.from_template(system_role)
answer = answer_prompt | llm | StrOutputParser()

# Wire the stages together. Each `assign` adds one key to the running input dict:
# first "query" (the generated SQL), then "result" (the output of executing it).
question_with_query = RunnablePassthrough.assign(query=write_query)
question_with_result = question_with_query.assign(
    result=itemgetter("query") | execute_query
)
chain = question_with_result | answer

  execute_query = QuerySQLDataBaseTool(db=db)


In [5]:
# Smoke test: ask a schema-level question and display the chain's answer.
message = "How many tables do I have in the database? and what are their names?"
chain_input = {"question": message}
response = chain.invoke(chain_input)
response

'You have 4 tables in the database. Their names are Comments, Applications, Compositions, and Melting Points.'

**SQL-agent Tool Design**

In [9]:
from langchain_core.tools import tool
from langchain_community.utilities import SQLDatabase
from langchain.chains import create_sql_query_chain
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from operator import itemgetter
from langchain_openai import ChatOpenAI


class ICDDSQLAgentTool:
    """
    Question-answering helper for the ICDD-energy SQLite database.

    A user question (about compositions, comments, melting points, etc.) is turned into a SQL query by a chat model,
    the query is executed against the database, and the same model then words the final answer from the question,
    the generated query and the raw SQL result.

    Attributes:
        sql_agent_llm (ChatOpenAI): Chat model used both to write the SQL query and to phrase the final answer.
        system_role (str): Prompt template for the answering step; it expects the `question`, `query` and `result` variables.
        db (SQLDatabase): Wrapper around the SQLite database the generated queries are executed on.
        chain (Runnable): The composed pipeline (write query -> execute query -> answer). Invoke it with
            `{"question": "<text>"}`; the final stage is a string output parser, so it yields the answer as text.

    Methods:
        __init__: Validates the database path and builds the language model, the database wrapper and the pipeline.
    """

    def __init__(self, llm: str, sqldb_directory: str, llm_temerature: float) -> None:
        """
        Validates the database location and assembles the query-answering pipeline.

        Args:
            llm (str): The name of the language model to be used for generating and interpreting SQL queries.
            sqldb_directory (str): Path of the SQLite database file. Despite the name it is interpolated directly
                into the `sqlite:///` URL, so it must point at the file itself, not at its folder.
            llm_temerature (float): The temperature setting for the language model. The misspelled parameter name
                is kept on purpose because callers pass it by keyword.

        Raises:
            FileNotFoundError: If `sqldb_directory` does not point to an existing file.
        """
        # Function-scope import so the class does not depend on imports from other cells.
        from pathlib import Path

        # SQLite creates an empty database when the file is missing; every generated query
        # would then fail with "no such table". Check up front, before any client is built.
        # "" and ":memory:" are left alone: they do not name a file on disk.
        db_location = str(sqldb_directory)
        if db_location not in ("", ":memory:") and not Path(db_location).is_file():
            raise FileNotFoundError(
                f"SQLite database file not found: {db_location}")

        self.sql_agent_llm = ChatOpenAI(
            model=llm, temperature=llm_temerature)
        self.system_role = """Given the following user question, corresponding SQL query, and SQL result, answer the user question.\n
            Question: {question}\n
            SQL Query: {query}\n
            SQL Result: {result}\n
            Answer:
            """
        self.db = SQLDatabase.from_uri(
            f"sqlite:///{sqldb_directory}")
        # Sanity check: show which tables the agent can see.
        print(self.db.get_usable_table_names())

        # Stage 1 writes the SQL, stage 2 executes it, stage 3 phrases the answer.
        execute_query = QuerySQLDataBaseTool(db=self.db)
        write_query = create_sql_query_chain(
            self.sql_agent_llm, self.db)
        answer_prompt = PromptTemplate.from_template(
            self.system_role)

        answer = answer_prompt | self.sql_agent_llm | StrOutputParser()
        # Each `assign` adds a key to the input dict: "query" (generated SQL), then
        # "result" (output of running that SQL); the enriched dict feeds the answer prompt.
        self.chain = (
            RunnablePassthrough.assign(query=write_query).assign(
                result=itemgetter("query") | execute_query
            )
            | answer
        )

In [12]:
import sys
from pyprojroot import here
import os

# Make the project root importable so the `src` package below resolves from the notebook.
# Guarded so that re-running this cell does not keep appending duplicate sys.path entries.
PROJECT_ROOT = str(here())
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

from src.agent_graph.load_tools_config import LoadToolsConfig
from langchain_core.tools import tool

# Tool settings; the SQL tool below reads the model name, temperature and database path from it.
TOOLS_CFG = LoadToolsConfig()

# Lazily created, shared agent instance. Building ICDDSQLAgentTool sets up a chat
# client, a database wrapper and the chain, so it is done once instead of per call.
# Re-running this cell resets the cache.
_ICDD_SQL_AGENT = None


def _get_icdd_sql_agent() -> "ICDDSQLAgentTool":
    """Return the shared ICDD SQL agent, creating it from TOOLS_CFG on first use."""
    global _ICDD_SQL_AGENT
    if _ICDD_SQL_AGENT is None:
        _ICDD_SQL_AGENT = ICDDSQLAgentTool(
            llm=TOOLS_CFG.icdd_sqlagent_llm,
            sqldb_directory=TOOLS_CFG.icdd_sqldb_directory,
            llm_temerature=TOOLS_CFG.icdd_sqlagent_llm_temperature
        )
    return _ICDD_SQL_AGENT


@tool
def query_icdd_sqldb(query: str) -> str:
    """Query the ICDD-energy SQL Database and access all the energy related materials information. Input should be a search query."""
    # The docstring above is the tool description, so it is kept short and free of
    # implementation detail.
    return _get_icdd_sql_agent().chain.invoke({"question": query})

In [13]:
# Tools are Runnables: go through .invoke() with the tool's named argument instead of
# calling the tool object directly, which is deprecated in langchain-core.
query_icdd_sqldb.invoke({"query": "How many tables do I have in the database? and what are their names?"})

['Applications', 'Comments', 'Compositions', 'Melting Points']


"You have 4 tables in the database. Their names are 'Comments', 'Applications', 'Compositions', and 'Melting Points'."