<a href="https://colab.research.google.com/github/ak2742/mlplay/blob/Fine-Tuning/06)_text_to_SQL_agent.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Make Google Drive visible inside this Colab runtime; later cells read the
# SQLite database file from a path under this mount point.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
#@title Install Langchain and Google Gemini related libraries

!pip install -q langchain==0.1.2
!pip install -q google-generativeai langchain-google-genai

## **Get Gemini Key from Secrets**
Set the `GEMINI_KEY` secret in Google Colab and read it here to run the Google Gemini LLM. You can get a Google Gemini API key from the following link: https://makersuite.google.com/app/apikey

In [None]:
# Read the Gemini API key from the Colab secrets panel so that the key is never
# written into the notebook itself.
from google.colab import userdata

GEMINI_SECRET_NAME = 'GEMINI_KEY'
GOOGLE_API_KEY = userdata.get(GEMINI_SECRET_NAME)

In [None]:
#@title DB interface

from pathlib import Path

from langchain_community.utilities import SQLDatabase

# Location of the Chinook sample database on the mounted Google Drive.
DB_PATH = Path("/content/drive/MyDrive/Colab Notebooks/chinook.db")

# SQLite creates a new, empty database when the file does not exist, which
# would only surface later as a confusing "no such table" error. Fail early
# with a clear message instead.
if not DB_PATH.is_file():
    raise FileNotFoundError(
        f"SQLite database not found at {DB_PATH!s}. "
        "Mount Google Drive and upload chinook.db to that location first."
    )

# Absolute paths need four slashes in a SQLAlchemy SQLite URI:
# "sqlite:///" followed by the path's own leading "/".
db = SQLDatabase.from_uri(f"sqlite:///{DB_PATH}")

# Quick sanity checks: SQL dialect, available tables, and a sample query.
print(db.dialect)
print(db.get_usable_table_names())
db.run("SELECT * FROM Artists LIMIT 10;")

In [None]:
#@title Convert question to SQL query

import re
from langchain.chains import create_sql_query_chain
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.prompts import ChatPromptTemplate

# Gemini chat model shared by the following cells; a low temperature keeps the
# generated SQL close to deterministic.
llm = ChatGoogleGenerativeAI(model="gemini-pro", google_api_key=GOOGLE_API_KEY, convert_system_message_to_human=True, temperature=0.1)

# Chain that turns a natural-language question into a SQL query for `db`.
chain = create_sql_query_chain(llm, db)

response = chain.invoke({"question": "How many employees are there"})

# The model frequently wraps the query in a Markdown code fence. Strip:
#   - an opening fence at the very start, with any language tag
#     (```sql, ```sqlite, or a bare ```), up to and including its newline
#   - a closing fence at the very end, with any surrounding whitespace
# A response without fences is left untouched.
regex_pattern = r'^\s*```[A-Za-z]*[ \t]*\r?\n|\s*```\s*$'

# Remove the fences and any leftover surrounding whitespace to get the bare SQL.
modified_response = re.sub(regex_pattern, '', response).strip()

# The cleaned query can be executed directly with db.run(modified_response).
print(modified_response)

In [None]:
#@title Execute SQL query

from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
from langchain_core.prompts import PromptTemplate

# Custom prompt for query generation. It asks the model for a plain SQL string
# (no Markdown), so the output can be passed straight to the execution tool.
template = '''Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the {top_k} answer.
Use the following format:

Question: "Question here"
"SQL Query to run"
SQLResult: "Result of the SQLQuery"
Answer: "Final answer here"

Only use the following tables:

{table_info}.

Provide SQL query as simple string without any markdown.

Question: {input}'''
prompt = PromptTemplate.from_template(template)

# Stage 1: question -> SQL text, using the custom prompt above.
write_query = create_sql_query_chain(llm, db, prompt)

# Stage 2: tool bound to `db` that receives the SQL text produced by stage 1.
execute_query = QuerySQLDataBaseTool(db=db)

# Pipe the generated query directly into the execution tool.
chain = write_query | execute_query
chain.invoke({"question": "How many employees are there"})

In [None]:
#@title Generate answer based on the query data

from operator import itemgetter

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough

# Prompt that shows the model the question, the SQL that was run, and the raw
# result, and asks for a natural-language answer.
answer_prompt = PromptTemplate.from_template(
"""Given the following user question, corresponding SQL query, and SQL result, answer the user question.

Question: {question}
{query}
SQL Result: {result}
Answer: """
)

# Final stage: fill the prompt, call the LLM, and reduce the reply to a string.
answer = answer_prompt | llm | StrOutputParser()

# Build the pipeline one stage at a time so each intermediate shape is visible:
#   {"question"} -> + "query" -> + "result" -> answer text
with_query = RunnablePassthrough.assign(query=write_query)
with_result = with_query.assign(result=itemgetter("query") | execute_query)
chain = with_result | answer

chain.invoke({"question": "How many employees are there"})

In [None]:
#@title create agent

# Import both the agent factory and the toolkit from langchain_community.
# The toolkit's former path (langchain.agents.agent_toolkits) is a legacy
# re-export; using one package keeps the imports consistent.
from langchain_community.agent_toolkits import SQLDatabaseToolkit, create_sql_agent
from langchain.agents.agent_types import AgentType

# Agent that is given SQL tools for `db` and the shared `llm`; verbose=True
# prints the agent's intermediate steps while it runs.
agent_executor = create_sql_agent(
    llm,
    toolkit=SQLDatabaseToolkit(db=db, llm=llm),
    agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    verbose=True,
)

In [None]:
#@title run agent

# A two-part question: the agent has to aggregate sales by country and then
# pick the top one.
agent_question = "List the total sales per country. Which country's customers spent the most?"

agent_executor.invoke({"input": agent_question})