# 1.加载配置文件

# In [None]:
from sentence_transformers import SentenceTransformer
import os
from openai import AzureOpenAI
import json
from aip import AipSpeech

# Azure OpenAI API configuration.
# The key and endpoint are hard-coded placeholders here; they are copied into
# the process environment and then read back when the client is constructed.
AZURE_OPENAI_API_KEY = '***'
AZURE_OPENAI_ENDPOINT = '***'
# NOTE(review): these assignments overwrite any values already present in the
# environment for the same variable names.
os.environ['AZURE_OPENAI_API_KEY'] = AZURE_OPENAI_API_KEY
os.environ['AZURE_OPENAI_ENDPOINT'] = AZURE_OPENAI_ENDPOINT
os.environ['OPENAI_API_VERSION'] = '2024-05-01-preview'

# Module-level client shared by get_completion(); the API version is passed
# explicitly and matches the OPENAI_API_VERSION value set above.
client = AzureOpenAI(
    api_key=os.getenv("AZURE_OPENAI_API_KEY"),
    api_version="2024-05-01-preview",
    azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT")
)

def get_completion(prompt, model="gpt-4o"):
    """Send ``prompt`` as a single user message and return the reply text.

    The request is issued through the module-level ``client`` with
    ``temperature=0``. Only the message content of the first returned
    choice is handed back to the caller.
    """
    completion = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        temperature=0,
    )
    first_choice = completion.choices[0]
    return first_choice.message.content

def load_model(model_path):
    """Construct a ``SentenceTransformer`` from ``model_path`` and return it."""
    model = SentenceTransformer(model_path)
    return model

def load_json(file_path):
    """Parse the JSON document stored at ``file_path`` and return the result.

    The file is read as UTF-8 explicitly. Without an explicit encoding,
    ``open`` falls back to the platform's preferred encoding, which corrupts
    or rejects non-ASCII (e.g. Chinese) text on systems whose default is not
    UTF-8.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the content is not valid JSON.
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        return json.load(file)

def encode_questions(embedding_model, questions):
    """Encode ``questions`` with ``embedding_model`` and return its output.

    The model's ``encode`` method is called with
    ``normalize_embeddings=True``.
    """
    encode_options = {"normalize_embeddings": True}
    return embedding_model.encode(questions, **encode_options)

# Load the local embedding model from a path relative to the working directory.
embedding_model_path = "model/Dmeta-embedding-zh"
embedding_model = load_model(embedding_model_path)

# Load the example dataset. Each item is read below through its 'question',
# 'query' and 'mask_query' keys.
all_data_path = "dataset/all_data_process_unique.json"
all_data = load_json(all_data_path)

# Parallel lists: index i in each list refers to the same dataset item.
questions = [item['question'] for item in all_data]
queries = [item['query'] for item in all_data]
mask_queries = [item['mask_query'] for item in all_data]

# Embed every example question once at import time so later lookups only
# need to embed the incoming question.
question_embeddings = encode_questions(embedding_model, questions)

# Set up the Baidu ASR (speech recognition) client; credentials are placeholders.
BAIDU_APP_ID = '***'
BAIDU_API_KEY = '***'
BAIDU_SECRET_KEY = '***'
aip_speech = AipSpeech(BAIDU_APP_ID, BAIDU_API_KEY, BAIDU_SECRET_KEY)

# 2.输出sql

# In [None]:
import numpy as np
from sklearn.metrics.pairwise import euclidean_distances
import re
import time
import speech_recognition as sr
import logging

# Embedding model
def encode_questions(embedding_model, questions):
    """Return ``embedding_model.encode`` output for ``questions``.

    Normalized embeddings are requested via ``normalize_embeddings=True``.
    """
    vectors = embedding_model.encode(questions, normalize_embeddings=True)
    return vectors

def compute_distances(question_embedding, question_embeddings):
    """Return Euclidean distances between the query and the stored embeddings.

    ``euclidean_distances`` produces a 2-D matrix; singleton axes are removed
    with ``squeeze`` as before. When both inputs contain a single row the
    squeezed result would be a 0-d array, on which ``len()`` fails, so the
    result is promoted back to at least one dimension. All other shapes are
    returned exactly as ``squeeze`` leaves them.
    """
    pairwise = euclidean_distances(question_embedding, question_embeddings)
    return np.atleast_1d(pairwise.squeeze())

def extract_sql(query):
    """Extract a single-line SQL statement from a model response.

    Prompt markers (``### SQL:``, ``###`` and any remaining ``#``) are removed
    and newlines become spaces. If the text contains a ```` ```sql ```` fence
    (tag matched case-insensitively), the fenced content is returned; a fence
    that is never closed yields everything after the opening tag. Without a
    tagged fence, the whole text is returned with backticks removed.

    Returns:
        The extracted SQL with surrounding whitespace stripped.
    """
    cleaned = query.replace("### SQL:", "").replace("###", "").replace("#", "")
    # Newlines become spaces (not nothing) so tokens on adjacent lines
    # such as "SELECT a\nFROM t" are not glued together.
    flattened = cleaned.replace("\n", " ")

    # The closing fence is optional: "(?:```|$)" lets an unterminated block
    # capture through the end of the text instead of capturing nothing.
    fenced = re.search(r"```sql(.*?)(?:```|$)", flattened, re.IGNORECASE)
    if fenced:
        return fenced.group(1).strip()
    return flattened.replace("`", "").strip()

def generate_prompt(examples, create_table_sql_prompt, question, database_records_prompt):
    """Assemble the text-to-SQL prompt sent to the chat model.

    The prompt consists, in order, of the example block, the instruction
    line, the schema section, the sample-records section, the question and a
    trailing ``### SQL:`` cue. It begins and ends with a newline.
    """
    sections = (
        "",  # leading newline
        "### 以下是基于类似问题提供的一些问题和相应的SQL查询的示例对：",
        f"{examples}",
        "### 仅通过SQLite SQL查询回答问题，不需要解释。您必须在确保正确性的同时最小化SQL执行时间。",
        "### 给定以下数据库架构:",
        "#",
        f"{create_table_sql_prompt}",
        "#",
        "### 以下是相关数据库中引用的一些数据信息:",
        "#",
        f"{database_records_prompt}",
        "#",
        f"### 问题:{question}",
        "### SQL:",
        "",  # trailing newline
    )
    return "\n".join(sections)

def jaccard_similarity(str1, str2):
    """Return the Jaccard similarity of the whitespace-separated token sets.

    The score is ``|A & B| / |A | B|`` where ``A`` and ``B`` are the sets of
    tokens produced by ``str.split()``. When neither string contains any
    token the union is empty; ``0.0`` is returned in that case instead of
    raising ``ZeroDivisionError``.
    """
    tokens1 = set(str1.split())
    tokens2 = set(str2.split())
    union = tokens1 | tokens2
    if not union:
        return 0.0
    return len(tokens1 & tokens2) / len(union)

def compute_combined_scores(distances, mask_similarities, threshold):
    """Pair each distance with its mask similarity, keeping qualifying entries.

    Returns a list of ``(index, distance, mask_similarity)`` tuples, in index
    order, for every position whose mask similarity is at least
    ``threshold``.
    """
    kept = []
    for position, distance in enumerate(distances):
        similarity = mask_similarities[position]
        if similarity >= threshold:
            kept.append((position, distance, similarity))
    return kept

def filter_prompt_lines(prompt, found_tables):
    """Drop commented content lines that mention none of ``found_tables``.

    A line is subject to filtering only when it starts with ``"# "`` and is
    longer than three characters; such a line is kept only if at least one
    table name occurs in it as a substring. Every other line is kept as is.
    """
    def keep(line):
        is_content_line = line.startswith("# ") and len(line) > 3
        if not is_content_line:
            return True
        return any(name in line for name in found_tables)

    return "\n".join(line for line in prompt.split('\n') if keep(line))

def add_comment_to_sql(sql):
    return "\n".join(["# " + line for line in sql.strip().split("\n")])

def _find_closing_paren(text, open_index):
    """Return the index of the ')' matching the '(' at ``open_index``, or -1."""
    depth = 0
    active_quote = None
    for index in range(open_index, len(text)):
        char = text[index]
        if active_quote is not None:
            # Inside a quoted run, only the matching quote character matters.
            if char == active_quote:
                active_quote = None
        elif char in "'\"`":
            active_quote = char
        elif char == '(':
            depth += 1
        elif char == ')':
            depth -= 1
            if depth == 0:
                return index
    return -1


def _split_top_level_commas(body):
    """Split ``body`` on commas that lie outside parentheses and quotes."""
    parts = []
    current = []
    depth = 0
    active_quote = None
    for char in body:
        if active_quote is not None:
            if char == active_quote:
                active_quote = None
        elif char in "'\"`":
            active_quote = char
        elif char == '(':
            depth += 1
        elif char == ')':
            depth -= 1
        elif char == ',' and depth == 0:
            parts.append("".join(current))
            current = []
            continue
        current.append(char)
    parts.append("".join(current))
    return parts


def extract_table_info(sql):
    """Summarize every ``CREATE TABLE`` statement as ``"name (col1, col2)"``.

    For each statement the column list is located by matching parentheses
    (ignoring anything inside quotes), then split on top-level commas so that
    commas inside types such as ``DECIMAL(10,2)`` or inside quoted comments
    do not produce bogus columns. The first identifier of each definition is
    taken as the column name. Unquoted definitions that start with a
    table-constraint keyword (``PRIMARY``, ``FOREIGN``, ``UNIQUE``, ...) are
    skipped. ``IF NOT EXISTS`` and backtick/double-quote quoted names are
    accepted; the quotes are not included in the output.

    Statements whose parentheses are unbalanced are ignored. Backslash-escaped
    quotes are not interpreted.

    Returns:
        A list with one ``"table (col, col, ...)"`` string per table, in
        order of appearance.
    """
    header_regex = re.compile(
        r'CREATE\s+TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?'
        r'[`"]?(\w+)[`"]?(?:\.[`"]?(\w+)[`"]?)?\s*\(',
        re.IGNORECASE,
    )
    column_regex = re.compile(r'([`"]?)(\w+)')
    constraint_keywords = {
        'PRIMARY', 'FOREIGN', 'UNIQUE', 'CONSTRAINT', 'CHECK',
        'KEY', 'INDEX', 'FULLTEXT', 'SPATIAL',
    }

    table_info = []
    search_from = 0
    while True:
        header = header_regex.search(sql, search_from)
        if header is None:
            break

        open_index = header.end() - 1  # the header pattern ends on '('
        close_index = _find_closing_paren(sql, open_index)
        if close_index == -1:
            search_from = header.end()
            continue

        table_name = ".".join(part for part in header.groups() if part)
        columns = []
        for definition in _split_top_level_commas(sql[open_index + 1:close_index]):
            first_word = column_regex.match(definition.strip())
            if first_word is None:
                continue
            quote, name = first_word.groups()
            # A quoted identifier is always a column, even if it spells a keyword.
            if not quote and name.upper() in constraint_keywords:
                continue
            columns.append(name)

        table_info.append(f"{table_name} ({', '.join(columns)})")
        search_from = close_index + 1

    return table_info

def format_sql_statements(sql):
    """Return ``sql`` with one whitespace-normalized statement per line.

    The input is stripped and split on ``;`` (plus any whitespace after it).
    Blank pieces are discarded; every remaining statement has its runs of
    whitespace collapsed to single spaces and gets a trailing ``;``.
    """
    pieces = re.split(r';\s*', sql.strip())
    one_liners = [
        " ".join(piece.split()) + ";"
        for piece in pieces
        if piece.strip()
    ]
    return "\n".join(one_liners)

def main(your_question, create_table_sql, database_records):
    """Generate a SQL query for ``your_question`` using two model calls.

    Steps:
      1. Pick the ``k`` stored examples whose question embeddings are closest
         to the question and ask the model for a preliminary SQL.
      2. Mask table names (``<mask>``) and column names (``<unk>``) in that
         SQL and compare it with every stored ``mask_query`` by Jaccard
         similarity.
      3. Keep examples whose similarity reaches the threshold, order them by
         embedding distance, and build a second prompt from the closest
         ``k`` of them. Schema/record lines that mention none of the tables
         used by the preliminary SQL are removed from that prompt.
      4. Ask the model again and return the extracted SQL.

    Args:
        your_question: Natural-language question.
        create_table_sql: ``CREATE TABLE`` statements separated by ``;``.
        database_records: Optional sample-record text separated by ``;``.

    Returns:
        The final SQL string. If ``create_table_sql`` is blank, the 3-tuple
        ``(error_message, "", "")`` is returned instead.
    """
    if not create_table_sql.strip():
        # NOTE(review): this error path returns a 3-tuple while the success
        # path returns a str; kept unchanged for existing callers.
        return "ERROR: 创建表的SQL语句不能为空。", "", ""

    k = 5
    threshold = 0.4

    # --- First pass: nearest examples by question embedding ---------------
    your_question_embedding = encode_questions(embedding_model, [your_question])
    distances = compute_distances(your_question_embedding, question_embeddings)
    top_k_indices = distances.argsort()[:k]
    examples = "\n".join([f"### {questions[idx]}\n{queries[idx]}\n" for idx in top_k_indices]).strip()

    formatted_sql = format_sql_statements(create_table_sql)
    create_table_sql_prompt = add_comment_to_sql(formatted_sql)

    formatted_database_records = format_sql_statements(database_records)
    database_records_prompt = add_comment_to_sql(formatted_database_records)

    prompt = generate_prompt(examples, create_table_sql_prompt, your_question, database_records_prompt)
    response_extract = extract_sql(get_completion(prompt))

    # --- Collect schema identifiers ----------------------------------------
    table_names = []
    column_names = []
    for info in extract_table_info(formatted_sql):
        table_name, _, columns = info.partition(' (')
        table_names.append(table_name.strip())
        column_names.extend(columns.strip(')').split(', '))

    # De-duplicate while keeping order, and drop empty names: an empty name
    # would turn the pattern below into r'\b\b', which matches at every word
    # boundary and would scatter placeholders through the whole SQL.
    table_names = [name for name in dict.fromkeys(table_names) if name]
    column_names = [name for name in dict.fromkeys(column_names) if name]

    def whole_word(name):
        # Escape the name so characters such as '.' are matched literally.
        return r'\b' + re.escape(name) + r'\b'

    # Whole-word matching, consistent with the masking below, so that table
    # "user" is not reported as found merely because "user_roles" appears.
    found_tables = [table for table in table_names if re.search(whole_word(table), response_extract)]

    # --- Mask identifiers and score against stored masked queries ----------
    response_mask = response_extract
    for table in table_names:
        response_mask = re.sub(whole_word(table), '<mask>', response_mask)
    for column in column_names:
        response_mask = re.sub(whole_word(column), '<unk>', response_mask)

    mask_similarities = [jaccard_similarity(response_mask, mask_query) for mask_query in mask_queries]

    # compute_combined_scores already applies the threshold; only ordering by
    # embedding distance remains to be done here.
    combined_scores = compute_combined_scores(distances, mask_similarities, threshold)
    combined_scores = sorted(combined_scores, key=lambda x: x[1])
    top_k_indices = [x[0] for x in combined_scores[:k]]

    examples_new = "\n".join([f"### {questions[idx]}\n{queries[idx]}\n" for idx in top_k_indices]).strip()

    # --- Second pass --------------------------------------------------------
    prompt_new = generate_prompt(examples_new, create_table_sql_prompt, your_question, database_records_prompt)
    if found_tables:
        prompt_final = filter_prompt_lines(prompt_new, found_tables)
    else:
        # Filtering with no detected tables would strip every schema and
        # record line, leaving the model nothing to work from.
        prompt_final = prompt_new

    response_final = get_completion(prompt_final)
    response_final = extract_sql(response_final)
    return response_final

def recognize_speech():
    """Record from the microphone and return the text recognized by Baidu ASR.

    Relies on the module-level names ``mic`` and ``r`` (created in the
    ``__main__`` section) and on ``aip_speech``.

    Returns:
        The recognized text, or ``""`` when recognition fails. A missing,
        empty or malformed service response is reported as a failure rather
        than raising.
    """
    logging.info('录音中...')
    with mic as source:
        r.adjust_for_ambient_noise(source)
        audio = r.listen(source, phrase_time_limit=10)  # record for at most 10 seconds
    logging.info('录音结束，识别中...')

    start_time = time.time()
    audio_data = audio.get_wav_data()

    # Send the WAV bytes to the Baidu speech recognition API.
    ret = aip_speech.asr(audio_data, 'wav', 16000, {'dev_pid': 1536, })

    result = ""
    if isinstance(ret, dict) and ret.get('err_no') == 0:
        result = ''.join(ret.get('result', []))
        print("识别结果:", result)
        end_time = time.time()
        print("识别时间:", end_time - start_time, "秒")
    else:
        # ret may be None/empty or lack 'err_msg'; show whatever is available
        # instead of failing on the lookup.
        error_detail = ret.get('err_msg', ret) if isinstance(ret, dict) else ret
        print("识别失败:", error_detail)
    logging.info('end')

    return result

if __name__ == "__main__":
    # Interactive entry point: collect the schema, optional sample records and
    # the question (typed or spoken), then print the generated SQL.
    logging.basicConfig(level=logging.DEBUG)

    # Speech recognition objects; recognize_speech() reads these module-level names.
    r = sr.Recognizer()
    mic = sr.Microphone(sample_rate=16000)
    create_table_sql = input("输入创建表的SQL语句,必填项,每个表请以';'进行分隔,在每个字段后面加入COMMENT信息会让结果更加精确\n示例:CREATE TABLE t_device_online_status (id INT PRIMARY KEY AUTO_INCREMENT COMMENT '序号',mac_id CHAR(16) NOT NULL COMMENT '设备id',login TINYINT(1) NOT NULL COMMENT '设备上下线状态，设备上线：1，设备下线：0',created TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间',modified TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '修改时间');")
    database_records = input("输入数据表的具体数据信息,可选项,每个表的信息请以';'进行分隔,添加这部分可使结果更加精确\n示例:t_device_online_status (id [1, 2, 3], mac_id [0012*************4057, 0012*************3747, 0012*************5913],login [1, 1, 1],created [2024-07-01 00:00:00.0, 2024-07-01 00:00:00.0, 2024-07-01 00:00:00.0],modified [2024-07-01 03:24:30.0, 2024-07-01 23:55:40.0, 2024-07-01 23:20:40.0]);")

    input_choice = input("请选择输入方式（1：音频输入，2：文字输入）：").strip()
    if input_choice == '1':
        # Run speech recognition to obtain the question.
        your_question = recognize_speech()
    elif input_choice == '2':
        your_question = input("请输入您的问题：")
    else:
        print("无效的选择,请选择1或2")
        # Raise SystemExit directly; the exit() helper is injected by the
        # site module and is not guaranteed to exist in every runtime.
        raise SystemExit

    final_sql = main(your_question,create_table_sql, database_records)
    if isinstance(final_sql, tuple):
        # main() signals an error with (message, "", ""); show only the message.
        final_sql = final_sql[0]
    print(f"final-sql为:\n{final_sql}")
