<a href="https://colab.research.google.com/github/Raytan/nl2sql/blob/main/nl2sql_playground.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 环境准备

In [1]:
!pip install langchain
!pip install openai
from langchain import OpenAI, SQLDatabase, SQLDatabaseChain
import os


Collecting langchain
  Downloading langchain-0.0.240-py3-none-any.whl (1.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.4/1.4 MB[0m [31m5.7 MB/s[0m eta [36m0:00:00[0m
Collecting dataclasses-json<0.6.0,>=0.5.7 (from langchain)
  Downloading dataclasses_json-0.5.13-py3-none-any.whl (26 kB)
Collecting langsmith<0.1.0,>=0.0.11 (from langchain)
  Downloading langsmith-0.0.14-py3-none-any.whl (29 kB)
Collecting openapi-schema-pydantic<2.0,>=1.2 (from langchain)
  Downloading openapi_schema_pydantic-1.2.4-py3-none-any.whl (90 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m90.0/90.0 kB[0m [31m8.2 MB/s[0m eta [36m0:00:00[0m
Collecting marshmallow<4.0.0,>=3.18.0 (from dataclasses-json<0.6.0,>=0.5.7->langchain)
  Downloading marshmallow-3.20.1-py3-none-any.whl (49 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m49.4/49.4 kB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting typing-inspect<1,>=0.4.0 (from dataclas

挂载GoogleDrive，有自己的ak可以省略这一步

In [2]:
from google.colab import drive
drive.mount('/content/gdrive')


Mounted at /content/gdrive


导入openAI key

In [3]:
with open('/content/gdrive/My Drive/openai/openai-key', 'r') as f:
  api_key = f.read().strip()

os.environ["OPENAI_API_KEY"] = api_key

准备数据库

In [4]:
# 本地数据库
# db = SQLDatabase.from_uri("sqlite:///Chinook.db")

# 从GoogleDrive导入
db_path = '/content/gdrive/My Drive/openai/Chinook.db'
db = SQLDatabase.from_uri(f"sqlite:///{db_path}")

# 打印，可以不做
# print(db.table_info)


选择模型

chatgpt的模型可以参考：https://platform.openai.com/docs/models/overview

In [5]:
from langchain.llms import OpenAI
from langchain.chat_models import ChatOpenAI

#llm = OpenAI(temperature=0)
llm = OpenAI(model_name='gpt-3.5-turbo')




# Router和Chain


In [6]:
from langchain import PromptTemplate, OpenAI, LLMChain
from langchain.chains.router import MultiPromptChain
from langchain.llms import OpenAI
from langchain.chains import ConversationChain
from langchain.chains.llm import LLMChain
from langchain.prompts import PromptTemplate

# db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True, use_query_checker=True)
db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)

prompt_template = "判断这个语句 {query} 的目的，如果是sql查询，回答 sql; 其他回答 chat "

llm_chain = LLMChain(
    llm = llm,
    prompt=PromptTemplate.from_template(prompt_template)
)


#llm_chain("SELECT * FROM (Customer INNER JOIN \"Employee\" ON \"Customer\".\"SupportRepId\" = \"Employee\".\"EmployeeId\")")
llm_chain("统计一下有多少员工")


{'query': '统计一下有多少员工', 'text': 'sql'}

In [29]:
from langchain.chains.router import MultiPromptChain
from langchain.llms import OpenAI
from langchain.chains import ConversationChain
from langchain.chains.llm import LLMChain
from langchain.prompts import PromptTemplate
from langchain.chains.router.llm_router import LLMRouterChain, RouterOutputParser
from langchain.chains.router.multi_prompt_prompt import MULTI_PROMPT_ROUTER_TEMPLATE

physics_template = """You are a very smart physics professor. \
You are great at answering questions about physics in a concise and easy to understand manner. \
When you don't know the answer to a question you admit that you don't know.

Here is a question:
{input}"""


math_template = """You are a very good mathematician. You are great at answering math questions. \
You are so good because you are able to break down hard problems into their component parts, \
answer the component parts, and then put them together to answer the broader question.

Here is a question:
{input}"""

prompt_infos = [
    {
        "name": "physics",
        "description": "Good for answering questions about physics",
        "prompt_template": physics_template,
    },
    {
        "name": "math",
        "description": "Good for answering math questions",
        "prompt_template": math_template,
    },
]

destination_chains = {}
for p_info in prompt_infos:
    name = p_info["name"]
    prompt_template = p_info["prompt_template"]
    prompt = PromptTemplate(template=prompt_template, input_variables=["input"])
    chain = LLMChain(llm=llm, prompt=prompt)
    destination_chains[name] = chain
default_chain = ConversationChain(llm=llm, output_key="text")

destinations = [f"{p['name']}: {p['description']}" for p in prompt_infos]
destinations_str = "\n".join(destinations)
router_template = MULTI_PROMPT_ROUTER_TEMPLATE.format(destinations=destinations_str)
print(router_template)
router_prompt = PromptTemplate(
    template=router_template,
    input_variables=["input"],
    output_parser=RouterOutputParser(),
)
router_chain = LLMRouterChain.from_llm(llm, router_prompt)

chain = MultiPromptChain(
    router_chain=router_chain,
    destination_chains=destination_chains,
    default_chain=default_chain,
    verbose=True,
)

# print(chain.run("What is black body radiation?"))

print(
    chain.run(
        "What is the first prime number greater than 40 such that one plus the prime number is divisible by 3"
    )
)

Given a raw text input to a language model select the model prompt best suited for the input. You will be given the names of the available prompts and a description of what the prompt is best suited for. You may also revise the original input if you think that revising it will ultimately lead to a better response from the language model.

<< FORMATTING >>
Return a markdown code snippet with a JSON object formatted to look like:
```json
{{
    "destination": string \ name of the prompt to use or "DEFAULT"
    "next_inputs": string \ a potentially modified version of the original input
}}
```

REMEMBER: "destination" MUST be one of the candidate prompt names specified below OR it can be "DEFAULT" if the input is not well suited for any of the candidate prompts.
REMEMBER: "next_inputs" can just be the original input if you don't think any modifications are needed.

<< CANDIDATE PROMPTS >>
physics: Good for answering questions about physics
math: Good for answering math questions

<< INPUT

# SQL模型
基础查询

In [37]:
db_chain.run("列出跟员工有关的表")
#db_chain.run("列出10张专辑")



[1m> Entering new SQLDatabaseChain chain...[0m
列出跟员工有关的表
SQLQuery:[32;1m[1;3mSELECT * FROM 
("Customer"
INNER JOIN "Employee" ON "Customer"."SupportRepId" = "Employee"."EmployeeId")[0m
SQLResult: [33;1m[1;3m[(1, 'Luís', 'Gonçalves', 'Embraer - Empresa Brasileira de Aeronáutica S.A.', 'Av. Brigadeiro Faria Lima, 2170', 'São José dos Campos', 'SP', 'Brazil', '12227-000', '+55 (12) 3923-5555', '+55 (12) 3923-5566', 'luisg@embraer.com.br', 3, 3, 'Peacock', 'Jane', 'Sales Support Agent', 2, '1973-08-29 00:00:00', '2002-04-01 00:00:00', '1111 6 Ave SW', 'Calgary', 'AB', 'Canada', 'T2P 5M5', '+1 (403) 262-3443', '+1 (403) 262-6712', 'jane@chinookcorp.com'), (2, 'Leonie', 'Köhler', None, 'Theodor-Heuss-Straße 34', 'Stuttgart', None, 'Germany', '70174', '+49 0711 2842222', None, 'leonekohler@surfeu.de', 5, 5, 'Johnson', 'Steve', 'Sales Support Agent', 2, '1965-03-03 00:00:00', '2003-10-17 00:00:00', '7727B 41 Ave', 'Calgary', 'AB', 'Canada', 'T3B 1Y7', '1 (780) 836-9987', '1 (780) 836-

InvalidRequestError: ignored

聚合查询

In [7]:
#db_chain.run("帮我统计一下有多少员工")
db_chain.run("帮我统计一下员工的平均年龄")



[1m> Entering new SQLDatabaseChain chain...[0m
帮我统计一下员工的平均年龄
SQLQuery:[32;1m[1;3mSELECT AVG(strftime('%Y', 'now') - strftime('%Y', BirthDate)) AS AverageAge
FROM Employee[0m
SQLResult: [33;1m[1;3m[(58.5,)][0m
Answer:[32;1m[1;3mThe average age of the employees is 58.5.[0m
[1m> Finished chain.[0m


'The average age of the employees is 58.5.'

# 运行web网页

In [None]:
!pip install gradio
!pip install mdtex2html

Collecting mdtex2html
  Downloading mdtex2html-1.2.0-py3-none-any.whl (13 kB)
Collecting latex2mathml (from mdtex2html)
  Downloading latex2mathml-3.76.0-py3-none-any.whl (73 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m73.4/73.4 kB[0m [31m1.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: latex2mathml, mdtex2html
Successfully installed latex2mathml-3.76.0 mdtex2html-1.2.0


In [None]:
import gradio as gr
import mdtex2html

"""Override Chatbot.postprocess"""
def postprocess(self, y):
    if y is None:
        return []
    for i, (message, response) in enumerate(y):
        y[i] = (
            None if message is None else mdtex2html.convert((message)),
            None if response is None else mdtex2html.convert(response),
        )
    return y


gr.Chatbot.postprocess = postprocess


def parse_text(text):
    """copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/"""
    lines = text.split("\n")
    lines = [line for line in lines if line != ""]
    count = 0
    for i, line in enumerate(lines):
        if "```" in line:
            count += 1
            items = line.split('`')
            if count % 2 == 1:
                lines[i] = f''
            else:
                lines[i] = f''
        else:
            if i > 0:
                if count % 2 == 1:
                    line = line.replace("`", "\`")
                    line = line.replace("<", "<")
                    line = line.replace(">", ">")
                    line = line.replace(" ", " ")
                    line = line.replace("*", "&ast;")
                    line = line.replace("_", "&lowbar;")
                    line = line.replace("-", "-")
                    line = line.replace(".", ".")
                    line = line.replace("!", "!")
                    line = line.replace("(", "(")
                    line = line.replace(")", ")")
                    line = line.replace("$", "$")
                lines[i] = ""+line
    text = "".join(lines)
    return text


def predict(input, chatbot, max_length, top_p, temperature, history, past_key_values):
    chatbot.append((parse_text(input), ""))
    chatbot[-1] = (parse_text(input), parse_text(db_chain.run(input)))
    yield chatbot, history, past_key_values


def reset_user_input():
    return gr.update(value='')


def reset_state():
    return [], [], None


with gr.Blocks() as demo:
    gr.HTML("""NLP2SQL""")

    chatbot = gr.Chatbot()
    with gr.Row():
        with gr.Column(scale=4):
            with gr.Column(scale=12):
                user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(
                    container=False)
            with gr.Column(min_width=32, scale=1):
                submitBtn = gr.Button("Submit", variant="primary")
        with gr.Column(scale=1):
            emptyBtn = gr.Button("Clear History")
            max_length = gr.Slider(0, 32768, value=8192, step=1.0, label="Maximum length", interactive=True)
            top_p = gr.Slider(0, 1, value=0.8, step=0.01, label="Top P", interactive=True)
            temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)

    history = gr.State([])
    past_key_values = gr.State(None)

    submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history, past_key_values],
                    [chatbot, history, past_key_values], show_progress=True)

    submitBtn.click(reset_user_input, [], [user_input])

    emptyBtn.click(reset_state, outputs=[chatbot, history, past_key_values], show_progress=True)

demo.queue().launch(share=True, inbrowser=True)

  user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(


Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
Running on public URL: https://5e0a4ddfea6fb6601b.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)


