In [72]:
import os
import re
from operator import itemgetter
from typing import List
from urllib.parse import quote

import pandas as pd
from langchain.chains import create_sql_query_chain
from langchain.chains.openai_tools import create_extraction_chain_pydantic
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
from langchain_community.utilities.sql_database import SQLDatabase
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI

class LangchainActions:

    
    os.environ["OPENAI_API_KEY"] = os.getenv('OPENAI_API_KEY')

    def __init__(self, user, host, password, name, dbType, file_name, model='gpt-4o'):
        self.user = user
        self.host = host
        self.password = password
        self.name = name
        self.dbType = dbType
        self.model = model
        self.filename = file_name
       

    def createURI(self, dbType, tabelnames:list, rows=3):
        if dbType == 'MySQL':
            return SQLDatabase.from_uri(f"mysql+pymysql://{self.user}:{self.password}@{self.host}/{self.name}", sample_rows_in_table_info = rows, include_tables = tabelnames)
        elif dbType == 'Postgres':
            return SQLDatabase.from_uri(f"postgresql+psycopg2://{self.user}:{self.password}@{self.host}/{self.name}", sample_rows_in_table_info = rows, include_tables = tabelnames)

    def get_table_details(self, filename):
        # Read the CSV file into a DataFrame
        table_description = pd.read_csv(filename)
        table_docs = []

        # Iterate over the DataFrame rows to create Document objects
        table_details = ""
        for index, row in table_description.iterrows():
            table_details = table_details + "Table Name:" + row['Table Name'].lower() + "\n" + "Table Description:" + row['Description'] + "\n\n"

        return table_details


    def answer_query(self, question:str):
        llm = ChatOpenAI(model = self.model, temperature=0)

        table_details = self.get_table_details(self.filename)
        table_details_prompt = f"""Return the names of ALL the SQL tables that MIGHT be relevant to the user question. \
        The tables are:

        {table_details}

        Remember to include ALL POTENTIALLY RELEVANT tables, even if you're not sure that they're needed."""

        class Table(BaseModel):
            """Table in SQL database."""

            name: str = Field(description="Name of table in SQL database.")

        def get_tables(tables: List[Table]) -> List[str]:
            tables  = [table.name for table in tables]
            return tables
        
        select_table = {"input": itemgetter("question")} | create_extraction_chain_pydantic(Table, llm, system_message=table_details_prompt) | get_tables
        tables = select_table.invoke({"question": question})
        
        db = self.createURI(self.dbType, tables)

        generate_query = create_sql_query_chain(llm, db)
        execute_query = QuerySQLDataBaseTool(db = db)
        
        answer_prompt = PromptTemplate.from_template(
            """Given the following user question, corresponding SQL query, and SQL result, answer the user question.

        Question: {question}
        SQL Query: {query}
        SQL Result: {result}
        Answer: """
        )
        rephrase_answer = answer_prompt | llm | StrOutputParser()

        def stripper(query:str):
            # Adjust regex to capture content within ```sql and ```
            match = re.search(r"```sql\n(.*?)\n```", query, re.DOTALL)
            if match:
                return (match.group(1)).replace('SQLQuery:', '').strip()
            return None

        chain = (
            RunnablePassthrough.assign(query=generate_query)
            .assign(clean_query=lambda x: stripper(x['query']))
            .assign(result=itemgetter("clean_query") | execute_query
            )
            | rephrase_answer
        )

        return chain.invoke({"question": question})

        
    


In [73]:
# Build the helper from connection settings held in environment variables;
# the table catalogue comes from the hospital model CSV.
langchain = LangchainActions(
    user=os.getenv('DB_USER'),
    host=os.getenv('DB_HOST'),
    password=os.getenv('DB_PASSWORD'),
    name=os.getenv('DB_NAME'),
    dbType='Postgres',
    file_name='Hospital_Relational_Model.csv',
)

In [74]:
# Ask a sample question. As the last expression of the cell, the returned
# answer string is what the notebook echoes as the cell output.
langchain.answer_query('Top 5 test supplies')

"The top 5 test supplies are:\n\n1. Micropore with a total quantity of 48\n2. Ryle's Tube with a total quantity of 42\n3. Thermometers with a total quantity of 41\n4. Torch with a total quantity of 40\n5. Scissors with a total quantity of 40"