In [None]:
# Text-to-SQL prompt (schema-only variant): the model sees the question ({input}),
# the schema text create_sql_query_chain fills into {table_info}, and a row limit
# {top_k}; it must answer with a bare PostgreSQL query.
#
# The system text is built from adjacent string literals with explicit "\n"
# separators. The previous version used backslash line-continuations inside a
# single literal, which dropped every newline and embedded the source indentation,
# so the "- ..." instruction list reached the model as one run-on line.
text_to_sql_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are a PostgreSQL expert. Given an input question, create a syntactically correct PostgreSQL query to run and return ONLY the generated Query and nothing else. Unless otherwise specified, do not return more than {top_k} rows.\n\n"
            "Here is the relevant table info: {table_info}\n"
            "Pay close attention on which column is in which table. if context contains more than one tables then create a query by performing JOIN operation only using the column unitid for the tables.\n"
            "Follow these Instructions for creating syntactically correct SQL query:\n"
            "- Be sure not to query for columns that do not exist in the tables and use alias only where required.\n"
            "- Always use the column 'instnm' associated with the 'unitid' in the generated query.\n"
            "- Whenever asked for Institute Names, return the institute names using column 'instnm' associated with the 'unitid' in the generated query.\n"
            "- Likewise, when asked about the average (AVG function) or ratio, ensure the appropriate aggregation function is used.\n"
            "- Pay close attention to the filtering criteria mentioned in the question and incorporate them using the WHERE clause in your SQL query.\n"
            "- If the question involves multiple conditions, use logical operators such as AND, OR to combine them effectively.\n"
            "- When dealing with date or timestamp columns, use appropriate date functions (e.g., DATE_PART, EXTRACT) for extracting specific parts of the date or performing date arithmetic.\n"
            "- If the question involves grouping of data (e.g., finding totals or averages for different categories), use the GROUP BY clause along with appropriate aggregate functions.\n"
            "- Consider using aliases for tables and columns to improve readability of the query, especially in case of complex joins or subqueries.\n"
            "- If necessary, use subqueries or common table expressions (CTEs) to break down the problem into smaller, more manageable parts.",
        ),
        # Prior conversation turns; callers must supply a "messages" list when invoking.
        MessagesPlaceholder(variable_name="messages"),
        ("human", "{input}"),
    ]
)

# Generates a SQL Query.
# NOTE(review): `llm` is only assigned in a later cell (the retriever cell); on a
# fresh kernel this line raises NameError unless `llm` is defined beforehand.
generate_sql_query = create_sql_query_chain(llm, db, text_to_sql_prompt)

In [None]:
def get_table_details(csv_path: str = "Data/table_descriptions.csv") -> str:
    """Render the table catalogue CSV as plain text for the table-selection prompt.

    The CSV must have ``Table`` and ``Description`` columns. Each row becomes::

        Table Name:<Table>
        Table Description:<Description>
        <blank line>

    Args:
        csv_path: Path to the CSV file. Defaults to the notebook-relative path
            used previously, so existing ``get_table_details()`` calls are unchanged.

    Returns:
        The concatenated text for all rows, in file order ("" for a header-only file).
    """
    # Read every cell as text and keep empty cells as "" rather than NaN. With the
    # default parsing, a blank description became a float NaN and the former
    # ``str + float`` concatenation raised TypeError.
    table_description = pd.read_csv(csv_path, dtype=str, keep_default_na=False)

    # Build each row's text once and join, instead of re-concatenating a growing
    # string inside an iterrows() loop.
    return "".join(
        f"Table Name:{name}\nTable Description:{description}\n\n"
        for name, description in zip(
            table_description["Table"], table_description["Description"]
        )
    )

# Creating a Pydantic Base Model: the schema handed to create_extraction_chain_pydantic,
# whose parsed instances are turned into plain names by get_tables().
# NOTE(review): the extraction chain converts this model into a tool schema for the
# LLM, so the class docstring and Field description likely become prompt text —
# treat edits to them as prompt changes.
class Table(BaseModel):
    """Table in SQL database."""

    # Table name as extracted by the LLM; read via `table.name` in get_tables().
    name: str = Field(description="Name of table in SQL database.")


def get_tables(tables: List[Table]) -> List[str]:
    """Turn the parsed ``Table`` models from the extraction chain into plain names.

    Args:
        tables: ``Table`` instances produced by the extraction chain.

    Returns:
        The ``name`` of each table, in the same order as ``tables``.
    """
    names: List[str] = []
    for table in tables:
        names.append(table.name)
    return names

# Plain-text catalogue of every table and its description, rendered from the CSV.
table_details = get_table_details()

# create_extraction_chain_pydantic builds a prompt template from system_message,
# and "{" / "}" mark template variables there. The catalogue is external data, so
# double any braces; after formatting, the model sees the original text, and a
# description containing braces can no longer break the prompt.
escaped_table_details = table_details.replace("{", "{{").replace("}", "}}")

table_details_prompt = f"""Refer the Above Context and Return the names of SQL Tables mentioned in the above context\n\n
The tables are:

{escaped_table_details}
 """

# {"question": str, ...} -> list of table-name strings the LLM judges relevant.
# NOTE(review): `llm` is assigned in a later cell; on a fresh kernel this raises
# NameError unless `llm` already exists.
table_chain = {"input": itemgetter("question")} | create_extraction_chain_pydantic(
    Table, llm, system_message=table_details_prompt) | get_tables

In [None]:
# table_chain (defined above) already maps the incoming "question" key to the
# "input" key the extraction chain expects, so it is used as-is. Wrapping it again
# with {"input": itemgetter("question")} would hand the inner chain a dict that has
# no "question" key (KeyError), and every re-run of this cell would nest another
# wrapper around the global.
#
# The extracted names go under "table_names_to_use", the key create_sql_query_chain
# reads to restrict the schema rendered into {table_info}.
# Invoke with {"question": ..., "messages": [...]}: the prompt's MessagesPlaceholder
# requires "messages".
full_chain = RunnablePassthrough.assign(table_names_to_use=table_chain) | generate_sql_query

In [None]:
def metadata_func(record: dict, metadata: dict) -> dict:
    """Copy table and column details from one JSON record into the document metadata.

    Used as JSONLoader's ``metadata_func``. The three column lists are stored as
    their ``str()`` representation (e.g. ``"['unitid', 'instnm']"``), with ``None``
    for any field a column entry lacks.

    Args:
        record: One element selected by the loader's jq schema. Keys read:
            ``Table_Name``, ``Table_Description`` and ``Columns`` (a list of dicts
            with ``Column_Name``, ``Data_Type`` and ``Column_Description``).
        metadata: The loader's metadata dict for this document; updated in place.

    Returns:
        The same ``metadata`` object, after the update.
    """
    names, types, descriptions = [], [], []
    for column in record.get("Columns", []):
        names.append(column.get("Column_Name"))
        types.append(column.get("Data_Type"))
        descriptions.append(column.get("Column_Description"))

    metadata["Table_Name"] = record.get("Table_Name")
    metadata["Table_Description"] = record.get("Table_Description")
    metadata["Column_Names"] = str(names)
    metadata["Data_Type"] = str(types)
    metadata["Column_Description"] = str(descriptions)
    return metadata


# Embedding model used to index the table-metadata documents.
embedding_function = OpenAIEmbeddings(
    openai_api_key=openai_api_key, model="text-embedding-ada-002")


# One document per element of each top-level object's "Table_Info" array: the page
# content is the table name, and metadata_func attaches its description and columns.
# The path is relative to the notebook directory, like "Data/table_descriptions.csv"
# in get_table_details(); the previous absolute /Users/... path only worked on the
# author's machine.
# NOTE(review): assumes the notebook runs from the project root containing Data/ — confirm.
loader = JSONLoader(
    file_path="Data/tableinfo.json",
    jq_schema=".[].Table_Info[]",
    content_key="Table_Name",
    metadata_func=metadata_func,
)
data = loader.load()
vectorstore = Chroma.from_documents(
    data, embedding_function)

# NOTE(review): earlier cells (generate_sql_query, table_chain) also use `llm`, but
# it is first assigned here; on a fresh kernel they fail unless this assignment is
# moved above them.
llm = ChatOpenAI(model="gpt-3.5-turbo-1106", temperature=0)
retriever = vectorstore.as_retriever()


template = """Answer the question based only on the following context:
    {context}
    Search for the table descriptions in the context and accordingly search for column names and associated column description. Include only relevant tables and columns which can be used by the downstream Text-to-SQL Agent to create SQL Queries for generating answer.
    Search for any information performing the following tasks:
    1. Table Names
    2. Table Descriptions
    3. Column Names
    4. Column Descriptions
    5. Encoded Values
    Finally, only return table names, column names and Encoded Values only (if availabe).

    Question: {question}
    """
retriever_prompt = ChatPromptTemplate.from_template(template)

# question (str) -> retrieved table documents + question -> LLM summary of the
# relevant tables/columns, returned as plain text.
# Uses the prompt and model built in this cell. The previous `prompt` / `model`
# names are not defined anywhere in the visible notebook, so building the chain
# raised NameError.
retriever_chain = (
    {"context": retriever, "question": RunnablePassthrough()}
    | retriever_prompt
    | llm
    | StrOutputParser()
)

# Example Response by invoking our retriever_chain
retriever_chain.invoke("List of Institutes accepting secondary school GPA for getting admission in Undergrad program")

# Based on the provided context, the relevant information for the downstream Text-to-SQL Agent to create SQL Queries containing Join, Filtering, and Sub Query is as follows:
# 1. Table Names:
# - ADM2022
# - IC2022_CAMPUSES
# 2. Table Descriptions:
# - ADM2022: Contains information related to admissions for the year 2022.
# - IC2022_CAMPUSES: Contains information about the campuses associated with the institutes.
# 3. Column Names:
# - ADM.unitid
# - IC.index
# - IC.campusid
# - IC.pcaddr
# - IC.pccity
# - ADM.admcon1
# 4. Column Descriptions:
# - ADM.unitid: Primary Key in the "ADM2022" table, likely used for joining operations.
# - IC.index: Likely used for joining operations with the unitid.
# - IC.campusid: Identifier for the campus associated with the institute.
# - IC.pcaddr: Physical address of the campus.
# - IC.pccity: City where the campus is located.
# - ADM.admcon1: Possibly a column indicating admission conditions, with a value of 1 indicating institutes accepting secondary school GPA for the Undergrad program admissions.


In [None]:
# Context-aware Text-to-SQL prompt. Besides {top_k} / {table_info} (filled in by
# create_sql_query_chain) and {input}, it expects {context}: the retriever chain's
# summary of relevant tables and columns, passed under the "context" key when the
# final chain is assembled.
#
# Built from adjacent string literals with explicit "\n" separators. The previous
# backslash line-continuations inside a single literal dropped every newline and
# embedded the source indentation, so the instruction bullets reached the model as
# one run-on line.
text_to_sql_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are a PostgreSQL expert. Given an input question, create a syntactically correct PostgreSQL query to run and return ONLY the generated Query and nothing else. Remember NOT include backticks ```sql ``` before and after the created query. Unless otherwise specified, do not return more than {top_k} rows.\n\n"
            "Here is the relevant table info: {table_info}\n"
            "Finally, Use only tables names and Column names mentioned in:\n\n {context} to create correct SQL Query and pay close attention on which column is in which table. if context contains more than one tables then create a query by performing JOIN operation only using the column unitid for the tables.\n"
            "Follow these Instructions for creating syntactically correct SQL query:\n"
            "- Be sure not to query for columns that do not exist in the tables and use alias only where required.\n"
            "- Always use the column 'instnm' associated with the 'unitid' in the generated query.\n"
            "- Whenever asked for Institute Names, return the institute names using column 'instnm' associated with the 'unitid' in the generated query.\n"
            "- Likewise, when asked about the average (AVG function) or ratio, ensure the appropriate aggregation function is used.\n"
            "- Pay close attention to the filtering criteria mentioned in the question and incorporate them using the WHERE clause in your SQL query.\n"
            "- If the question involves multiple conditions, use logical operators such as AND, OR to combine them effectively.\n"
            "- When dealing with date or timestamp columns, use appropriate date functions (e.g., DATE_PART, EXTRACT) for extracting specific parts of the date or performing date arithmetic.\n"
            "- If the question involves grouping of data (e.g., finding totals or averages for different categories), use the GROUP BY clause along with appropriate aggregate functions.\n"
            "- Consider using aliases for tables and columns to improve readability of the query, especially in case of complex joins or subqueries.\n"
            "- If necessary, use subqueries or common table expressions (CTEs) to break down the problem into smaller, more manageable parts.",
        ),
        # Prior conversation turns; callers must supply a "messages" list when invoking.
        MessagesPlaceholder(variable_name="messages"),
        ("human", "{input}"),
    ]
)

In [None]:
# Generate SQL Query Chain (uses the context-aware text_to_sql_prompt defined above).
generate_sql_query = create_sql_query_chain(llm, db, text_to_sql_prompt)

# Table Chain with Pydantic: {"question": str, ...} -> list of relevant table names.
# Rebuilt from its parts so this cell does not depend on whatever an earlier cell
# left in the global `table_chain`.
table_chain = {"input": itemgetter("question")} | create_extraction_chain_pydantic(
    Table, llm, system_message=table_details_prompt) | get_tables

# Final Chain combined with SQL Query chain and Pydantic Chain.
# Input: {"question": str, "messages": list}. "messages" feeds the prompt's
# MessagesPlaceholder.
# - context: RunnablePassthrough.assign hands each branch the whole input dict, but
#   retriever_chain starts with a vector-store retriever that needs the question
#   string, so the question is extracted first.
# - table_names_to_use: the key create_sql_query_chain reads to restrict the schema
#   in {table_info}. The previous "table_names_to_assign" key was ignored, so every
#   table's schema was sent.
final_chain = (
    RunnablePassthrough.assign(
        context=itemgetter("question") | retriever_chain,
        table_names_to_use=table_chain,
    )
    | generate_sql_query
)
