<a href="https://colab.research.google.com/github/SunOneLeo/nlp2sql/blob/main/nl2sql_web.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 准备环境

安装基础依赖

In [26]:
# Install the packages the cells below import: gradio (web UI), mdtex2html
# (used to convert chat messages to HTML), langchain (OpenAI / SQLDatabase /
# SQLDatabaseChain) and openai.
# NOTE(review): versions are unpinned, while later cells rely on older APIs
# (`from langchain import SQLDatabaseChain`, `Textbox(...).style(...)`,
# overriding `Chatbot.postprocess`) that newer releases may have moved or
# removed — consider pinning versions known to work with this notebook.
!pip install gradio
!pip install mdtex2html
!pip install langchain
!pip install openai

挂载 Google Drive。如果你已有自己的 OpenAI API key，并且数据库文件也不放在 Google Drive 上，可以省略这一步（此时需要自行设置环境变量 OPENAI_API_KEY 和数据库路径）。

In [None]:
from google.colab import drive
from langchain import OpenAI, SQLDatabase, SQLDatabaseChain
import os

# Mount Google Drive under /content/gdrive; the key file and the SQLite
# database paths used in later cells live below this mount point.
drive.mount(mountpoint='/content/gdrive')


读取 OpenAI API key 并写入环境变量 OPENAI_API_KEY

In [None]:
# Export the OpenAI API key as the OPENAI_API_KEY environment variable, which
# the OpenAI LLM created later is expected to pick up.
# A key that is already set in the environment is kept, so the Drive key file
# is only required when no key has been configured yet.
api_key = os.environ.get("OPENAI_API_KEY", "").strip()
if not api_key:
    with open('/content/gdrive/My Drive/openai/openai-key', 'r', encoding='utf-8') as f:
        api_key = f.read().strip()

os.environ["OPENAI_API_KEY"] = api_key

加载 SQLite 数据库文件（Chinook.db）

In [None]:
# Local alternative: a database file next to the notebook.
# db = SQLDatabase.from_uri("sqlite:///Chinook.db")

# Load the database from the mounted Google Drive.
db_path = '/content/gdrive/My Drive/openai/Chinook.db'
# SQLite creates a new, empty database when the file does not exist, which
# would leave the SQL chain with no tables to query; fail early instead.
if not os.path.isfile(db_path):
    raise FileNotFoundError(f"SQLite database not found: {db_path}")
db = SQLDatabase.from_uri(f"sqlite:///{db_path}")

# Optional: print the schema description held by `db`.
# print(db.table_info)


In [None]:
# Alternative configuration with an explicit sampling temperature:
# llm = OpenAI(temperature=0)
# NOTE(review): OpenAI has retired text-davinci-003; pick a currently available
# completion model that the installed LangChain version supports.
llm = OpenAI(model_name="text-davinci-003")

# LangChain SQLDatabaseChain over `db` driven by `llm`; predict() below calls
# db_chain.run(question) and shows the returned text as the answer.
# Variant with LangChain's query checker enabled:
# db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True, use_query_checker=True)
db_chain = SQLDatabaseChain.from_llm(llm=llm, db=db, verbose=True)

# 启动 Web 界面

In [None]:
import gradio as gr
# The module is `mdtex2html` (as installed above and used by postprocess);
# the previous spelling `mdtex2htmlb` raised ModuleNotFoundError.
import mdtex2html

"""Override Chatbot.postprocess"""
def postprocess(self, y):
    if y is None:
        return []
    for i, (message, response) in enumerate(y):
        y[i] = (
            None if message is None else mdtex2html.convert((message)),
            None if response is None else mdtex2html.convert(response),
        )
    return y


# Monkey-patch the Chatbot class so its instances use the postprocess
# function defined above instead of Gradio's own implementation.
gr.Chatbot.postprocess = postprocess


# Per-character escapes applied to lines inside fenced code blocks.
# No replacement text contains another mapped character, so one translate()
# pass gives the same result as applying the replacements one after another.
_CODE_ESCAPES = str.maketrans({
    "`": "\\`",
    "<": "&lt;",
    ">": "&gt;",
    " ": "&nbsp;",
    "*": "&ast;",
    "_": "&lowbar;",
    "-": "&#45;",
    ".": "&#46;",
    "!": "&#33;",
    "(": "&#40;",
    ")": "&#41;",
    "$": "&#36;",
})


def parse_text(text):
    """Convert a chat message into the HTML fragment shown in the Chatbot.

    Adapted from https://github.com/GaiZhenbiao/ChuanhuChatGPT/.

    Empty lines are dropped; every remaining line after the first is prefixed
    with ``<br>``. Lines containing a triple-backtick fence alternately open a
    ``<pre><code class="language-LANG">`` block (LANG is the text after the
    last backtick on that line) and close it with ``<br></code></pre>``.
    Lines inside a code block have backticks, ``<``, ``>``, spaces and
    characters such as ``*``, ``_``, ``-``, ``.``, ``!``, ``(``, ``)``, ``$``
    replaced by escaped/entity forms so they display literally.

    Args:
        text: Raw message text.

    Returns:
        The concatenated HTML string (``""`` for empty input).
    """
    lines = [line for line in text.split("\n") if line]
    in_code = False
    for i, line in enumerate(lines):
        if "```" in line:
            # Fences toggle between opening and closing a code block.
            in_code = not in_code
            if in_code:
                lang = line.split("`")[-1]
                lines[i] = f'<pre><code class="language-{lang}">'
            else:
                lines[i] = "<br></code></pre>"
        elif i > 0:
            if in_code:
                line = line.translate(_CODE_ESCAPES)
            lines[i] = "<br>" + line
    return "".join(lines)


def predict(input, chatbot, max_length, top_p, temperature, history, past_key_values):
    """Answer one natural-language question with ``db_chain`` and add it to the chat.

    Gradio event handler written as a generator; it yields exactly once.

    Args:
        input: The user's question from the textbox.
        chatbot: List of ``(question_html, answer_html)`` pairs; the new turn
            is appended to it.
        max_length: Slider value from the UI; not used here.
        top_p: Slider value from the UI; not used here.
        temperature: Slider value from the UI; not used here (``db_chain`` is
            built once with its own LLM settings).
        history: UI state, passed through unchanged.
        past_key_values: UI state, passed through unchanged.

    Yields:
        ``(chatbot, history, past_key_values)``.
    """
    # Do not spend an LLM call on an empty or whitespace-only submission.
    if not input or not input.strip():
        yield chatbot, history, past_key_values
        return
    question_html = parse_text(input)
    answer_html = parse_text(db_chain.run(input))
    chatbot.append((question_html, answer_html))
    yield chatbot, history, past_key_values


def reset_user_input():
    """Return a Gradio update that sets the input textbox's value to ''."""
    blank = ''
    return gr.update(value=blank)


def reset_state():
    """Return fresh values for (chatbot, history, past_key_values).

    Wired to the "Clear History" button: an empty chat, an empty history
    list and ``None``. New lists are created on every call.
    """
    cleared_chat = []
    cleared_history = []
    return cleared_chat, cleared_history, None


# Build the web UI: a title, the chat pane, an input box with a Submit button,
# and a side column with a "Clear History" button and three sliders.
with gr.Blocks() as demo:
    gr.HTML("""<h1 align="center">NLP2SQL</h1>""")

    chatbot = gr.Chatbot()
    with gr.Row():
        with gr.Column(scale=4):
            with gr.Column(scale=12):
                # NOTE(review): `.style(...)` belongs to older Gradio 3.x releases
                # (deprecated later, removed in 4.x) — confirm the installed version.
                user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(
                    container=False)
            with gr.Column(min_width=32, scale=1):
                submitBtn = gr.Button("Submit", variant="primary")
        with gr.Column(scale=1):
            emptyBtn = gr.Button("Clear History")
            # NOTE(review): these slider values are passed to predict(), which does
            # not use them; the LLM settings are fixed when db_chain is created.
            max_length = gr.Slider(0, 32768, value=8192, step=1.0, label="Maximum length", interactive=True)
            top_p = gr.Slider(0, 1, value=0.8, step=0.01, label="Top P", interactive=True)
            temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)

    # Per-session state; predict() passes both values through unchanged.
    history = gr.State([])
    past_key_values = gr.State(None)

    # Submit: run the question through predict() and update chat + state ...
    submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history, past_key_values],
                    [chatbot, history, past_key_values], show_progress=True)

    # ... and, as a second listener on the same click, clear the input box.
    submitBtn.click(reset_user_input, [], [user_input])

    # "Clear History": reset chatbot, history and past_key_values.
    emptyBtn.click(reset_state, outputs=[chatbot, history, past_key_values], show_progress=True)

# Enable the request queue (presumably required because predict is a generator —
# verify for the installed Gradio version) and launch with a public share link,
# opening a browser tab where supported.
demo.queue().launch(share=True, inbrowser=True)
