<a href="https://colab.research.google.com/github/ak2742/mlplay/blob/Fine-Tuning/07)_Chat_on_SQL_DB.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Attach Google Drive to this Colab runtime; the database file used below lives under this mount.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
#@title Install libraries

!pip install -q langchain==0.1.2 langchain_experimental
!pip install -q google-generativeai langchain-google-genai

!pip install -q gradio

## **Get Gemini Key from Secrets**
Set a secret named GEMINI_KEY in Google Colab; the next cell reads it in order to run the Google Gemini LLM. You can get a Google Gemini API key from the following link: https://makersuite.google.com/app/apikey

In [None]:
from google.colab import userdata
# Read the Gemini API key from the Colab secret store instead of hardcoding it.
GEMINI_SECRET_NAME = 'GEMINI_KEY'
GOOGLE_API_KEY = userdata.get(GEMINI_SECRET_NAME)

In [None]:
#@title DB interface

from langchain_community.utilities import SQLDatabase

# Connection URI of the Chinook SQLite file stored on the mounted Drive.
CHINOOK_DB_URI = "sqlite:////content/drive/MyDrive/Colab Notebooks/chinook.db"
db = SQLDatabase.from_uri(CHINOOK_DB_URI)

# Quick sanity checks: the SQL dialect and the table names the wrapper reports.
print(db.dialect)
print(db.get_usable_table_names())

# Last expression of the cell, so the notebook displays a ten-row preview of Artists.
db.run("SELECT * FROM Artists LIMIT 10;")

In [None]:
#@title Create LLM and Prompt

from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.prompts import PromptTemplate

# Gemini chat model shared by the SQL-generation chain and the answer-phrasing step.
llm = ChatGoogleGenerativeAI(
    model="gemini-pro",
    google_api_key=GOOGLE_API_KEY,
    convert_system_message_to_human=True,
    temperature=0.1,  # low temperature: keep generations close to deterministic
)

# Prompt text used to turn a user question into a SQL query.
# Placeholders: {table_info}, {input}, {dialect} and {history}; {history} matches
# the memory_key of the conversation memory attached to the SQL chain.
# The backslash in "\\n" is escaped on purpose: in a normal (non-raw) string a bare
# "\n" would become a real line break, splitting the last instruction in two and
# never showing the model the literal two-character token it is told to avoid.
sql_prompt_template = """
Only use the following tables:
{table_info}
Question: {input}

Given an input question, first create a syntactically correct
{dialect} query to run.

Relevant pieces of previous conversation:
{history}

(You do not need to use these pieces of information if not relevant)
Dont include ```, ```sql and \\n in the output.
"""
# Wrap the template text in a PromptTemplate, declaring the four placeholders it contains.
_sql_prompt_variables = ["input", "table_info", "dialect", "history"]
prompt = PromptTemplate(
    template=sql_prompt_template,
    input_variables=_sql_prompt_variables,
)

In [None]:
#@title Create Chains

from langchain_experimental.sql.base import SQLDatabaseChain
from langchain.memory import ConversationBufferMemory
from langchain_core.output_parsers import StrOutputParser

# Conversation buffer; its contents are exposed to the prompt under the "history" key.
memory = ConversationBufferMemory(memory_key="history")

# SQL chain wiring the LLM, the database wrapper, the custom prompt and the memory together.
# verbose=True makes the chain log its intermediate steps while it runs.
db_chain = SQLDatabaseChain.from_llm(
    llm,
    db,
    prompt=prompt,
    memory=memory,
    return_direct=True,
    verbose=True,
)

In [None]:
# Second-stage prompt: combines the chain's "query" and "result" values into one request
# asking the model for a readable answer.
output_prompt = PromptTemplate(
    template="Based on the following information generate human redable response: {query},  {result}",
    input_variables=["query", "result"],
)

# Parser that reduces the final model message to a plain string.
output_parser = StrOutputParser()

# End-to-end pipeline: SQL chain -> answer prompt -> LLM -> string.
chain = (
    db_chain
    | output_prompt
    | llm
    | output_parser
)

In [None]:
#@title Add Gradio UI

import gradio as gr

def gradio_fn(msg, chat_history):
    """Answer one chat message by running it through the question-answering chain.

    Args:
        msg: Text typed by the user.
        chat_history: Conversation so far, as supplied by gr.ChatInterface.
            It is deliberately left untouched: ChatInterface records each
            (message, reply) turn itself after this callback returns, so
            appending here would store every exchange twice.

    Returns:
        The value produced by ``chain.invoke(msg)``, which is shown as the bot reply.
    """
    return chain.invoke(msg)

# Build the chat UI around gradio_fn and start serving it.
# The title names what the app actually queries (a SQL database, not a CSV file),
# in line with the description text below.
gr.ChatInterface(
    gradio_fn,
    examples=[["How many employees are there"]],
    chatbot=gr.Chatbot(height=300),
    title="SQL DB QA",
    textbox=gr.Textbox(placeholder="Ask your question here", container=False, scale=7),
    theme="soft",
    description="Ask me any question on the given database",
    cache_examples=True,
    retry_btn=None,               # hide the retry button
    undo_btn="Delete Previous",
    clear_btn="Clear",
    ).launch(share=True, debug=True)  # share=True asks Gradio for a public link