In [None]:
import openai
import duckdb
import pandas as pd
import os
import re

# Configure the OpenAI API key. Prefer the OPENAI_API_KEY environment variable so the
# secret does not have to be pasted into (and committed with) the notebook; the
# placeholder is only a fallback and will fail authentication if left in place.
openai.api_key = os.environ.get("OPENAI_API_KEY", "Your API here")

def generate_sql(user_question, model="ft:gpt-4o-2024-08-06:personal::AXYv83vn"):
    """
    使用 Fine-Tuned 的 GPT 模型将自然语言问题转换为 SQL 语句。

    :param user_question: 用户的自然语言问题
    :param model: Fine-Tuned 的 GPT 模型名称
    :return: 生成的 SQL 语句
    """
    prompt = f"""
Pretend you are an expert at converting natural language questions into accurate SQL queries. Please generate an accurate SQL query based on the following natural language question and database schema provided below. Think sequentially and refer to the sample natural language questions with correct and incorrect outputs as well.

Database Schema:
Table 1: t_zacks_fc (This table contains fundamental indicators for companies)
Columns: 'ticker' = Unique zacks Identifier for each company/stock, ticker or trading symbol, 'comp_name' = Company name, 'exchange' = Exchange traded, 'per_end_date' = Period end date which represents quarterly data, 'per_type' = Period type (eg. Q for quarterly data), 'filing_date' = Filing date, 'filing_type' = Filing type: 10-K, 10-Q, PRELIM, 'zacks_sector_code' = Zacks sector code (Numeric Value eg. 11 = Aerospace), 'eps_diluted_net_basic’ = Earnings per share (EPS) net (Company's net earnings or losses attributable to common shareholders per basic share basis), 'lterm_debt_net_tot' = Net long-term debt (The net amount of long term debt issued and repaid. This field is either calculated as the sum of the long term debt fields or used if a company does not report debt issued and repaid separately).
Keys: ticker, per_end_date, per_type

Table 2: t_zacks_fr (This table contains fundamental ratios for companies)
Columns: 'ticker' = Unique zacks Identifier for each company/stock, ticker or trading symbol, 'per_end_date' = Period end date which represents quarterly data, 'per_type' = Period type (eg. Q for quarterly data), ‘ret_invst’ = Return on investments (An indicator of how profitable a company is relative to its assets invested by shareholders and long-term bond holders. Calculated by dividing a company's operating earnings by its long-term debt and shareholders equity), ‘tot_debt_tot_equity’ = Total debt / total equity (A measure of a company's financial leverage calculated by dividing its long-term debt by stockholders' equity).
Keys: ticker, per_end_date, per_type.

Table 3: t_zacks_mktv (This table contains market value data for companies)
Columns: 'ticker' = Unique zacks Identifier for each company/stock, ticker or trading symbol, 'per_end_date' = Period end date which represents quarterly data, 'per_type' = Period type (eg. Q for quarterly data), ‘mkt_val’ = Market Cap of Company (shares out x last monthly price per share - unit is in Millions).
Keys: ticker, per_end_date, per_type.

Table 4: t_zacks_shrs (This table contains shares outstanding data for companies)
Columns: 'ticker' = Unique zacks Identifier for each company/stock, ticker or trading symbol, 'per_end_date' = Period end date which represents quarterly data, 'per_type' = Period type (eg. Q for quarterly data), ‘shares_out’ = Number of Common Shares Outstanding from the front page of 10K/Q.
Keys: ticker, per_end_date, per_type.

Table 5: t_zacks_sectors (This table contains the zacks sector codes and their corresponding sectors)
Columns: 'zacks_sector_code' = Unique identifier for each zacks sector, 'sector': the sector descriptions that correspond to the sector code 
Keys: zacks_sector_code 

Sample natural language questions with correct and incorrect outputs: 
Sample prompt 1: Output ticker with the largest market value recorded on any given period end date. 
Correct output for prompt 1: SELECT ticker, per_end_date, MAX(mkt_val) AS max_market_value FROM t_zacks_mktv GROUP BY per_end_date ORDER BY max_market_value DESC LIMIT 1;
Incorrect output for prompt 1: SELECT MAX(mkt_val) , ticker FROM t_zacks_mktv GROUP BY ticker

Sample prompt 2: What is the company name with the lowest market cap?
Correct output for prompt 2: SELECT fc.comp_name, mktv.ticker, mktv.mkt_val FROM t_zacks_mktv AS mktv JOIN t_zacks_fc AS fc ON mktv.ticker = fc.ticker WHERE mktv.mkt_val = (SELECT MIN(mkt_val) FROM t_zacks_mktv);
Incorrect output for prompt 2:  SELECT T1.comp_name FROM t_zacks_fc AS T1 INNER JOIN t_zacks_mktv AS T2 ON T1.ticker = T2.ticker AND T1.per_end_date = T2.per_end_date AND T1.per_type = T2.per_type ORDER BY T2.mkt_val LIMIT 1

Sample prompt 3: Filter t_zacks_fc to only show companies with a total debt-to-equity ratio greater than 1.
Correct output for prompt 3: SELECT * FROM t_zacks_fr WHERE tot_debt_tot_equity > 1;
Incorrect output for prompt 3: SELECT * FROM t_zacks_fr WHERE t_zacks_mktv > 1;

Sample prompt 4: Filter t_zacks_shrs to include companies with more than 500 million shares outstanding as of the most recent quarter.
Correct output for prompt 4: SELECT *
FROM t_zacks_shrs
WHERE shares_out > 5000
ORDER BY per_end_date DESC;
Incorrect output for prompt 4: SELECT * FROM t_zacks_shrs WHERE shares_out > 500000000

Sample prompt 5: Combine t_zacks_mktv and t_zacks_shrs to show tickers with market cap and shares outstanding in the latest period end date.
Correct output for prompt 5: SELECT mktv.ticker, mktv.per_end_date, mktv.mkt_val, shrs.shares_out
FROM t_zacks_mktv mktv
JOIN t_zacks_shrs shrs ON mktv.ticker = shrs.ticker AND mktv.per_end_date = shrs.per_end_date
ORDER BY mktv.per_end_date DESC;
Incorrect output for prompt 5: SELECT ticker, mkt_val, shares_out FROM t_zacks_mktv INNER JOIN t_zacks_shrs ON t_zacks_mktv.ticker = t_zacks_shrs.ticker AND t_zacks_mktv.per_end_date = t_zacks_shrs.per_end_date ORDER BY per_end_date DESC LIMIT 1

Sample prompt 6: Join t_zacks_fc and t_zacks_fr to show tickers with total debt-to-equity ratios and EPS from NASDAQ as of Q2 2024.
Correct output for prompt 6: SELECT fc.ticker, fc.eps_diluted_net_basic, fr.tot_debt_tot_equity
FROM t_zacks_fc fc
JOIN t_zacks_fr fr ON fc.ticker = fr.ticker AND fc.per_end_date = fr.per_end_date
WHERE fc.exchange = 'NASDAQ' AND fc.per_type = 'Q' AND fc.per_end_date BETWEEN '2024-04-01' AND '2024-06-30';
Incorrect output for prompt 6: SELECT T1.ticker, T1.eps_diluted_net_basic, T2.ret_invst, T2.tot_debt_tot_equity FROM t_zacks_fc AS T1 INNER JOIN t_zacks_fr AS T2 ON T1.ticker = T2.ticker AND T1.per_end_date = T2.per_end_date WHERE T1.exchange = 'NASDAQ' AND T1.per_type = 'Q2';

Please make sure that when you are joining 2 or more tables, you are using all 3 keys (ticker, per_end_date & per_type) Also, ensure that the SQL query is syntactically correct and provides the expected output based on the natural language question provided.

User's Question:
{user_question}

Please provide only the SQL query without any markdown, code block syntax, or explanations.
    """
    try:
        response = openai.ChatCompletion.create(
            model=model,
            messages=[
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": prompt}
            ],
            max_tokens=300,  # 根据复杂性调整
            temperature=0.0,  # 设置为0以获得更确定性的输出
            n=1,
            stop=None
        )
        # 获取生成的SQL查询
        raw_sql = response.choices[0].message['content'].strip()
        print("\n原始生成的SQL查询:")
        print(raw_sql)

        # 清理可能的代码块语法
        # 移除 ```sql 和 ``` 标记
        sql_query = re.sub(r'^```sql\s*', '', raw_sql, flags=re.IGNORECASE)
        sql_query = re.sub(r'```$', '', sql_query, flags=re.IGNORECASE)
        sql_query = sql_query.strip()

        print("\n清理后的SQL查询:")
        print(sql_query)
        return sql_query
    except openai.OpenAIError as e:
        print(f"生成SQL查询时发生错误: {e}")
        return None

def execute_sql(sql, tables_files):
    """
    Run a SQL query against tables loaded from Parquet and CSV files.

    Each entry of ``tables_files`` is materialized as a table in a fresh in-memory
    DuckDB database, then ``sql`` is executed there. The connection is always closed,
    including when loading a file or running the query fails.

    :param sql: the SQL statement to execute
    :param tables_files: dict mapping table name -> {'path': file path, 'type': 'parquet' | 'csv'};
        entries with any other type are reported and skipped
    :return: the query result as a DataFrame (via ``fetchdf``), or None if any error occurred
    """
    # DuckDB table function used to read each supported file type.
    readers = {'parquet': 'read_parquet', 'csv': 'read_csv_auto'}
    conn = None
    try:
        conn = duckdb.connect(database=':memory:')

        for table_name, file_info in tables_files.items():
            file_path = file_info['path']
            file_type = file_info['type']
            reader = readers.get(file_type)
            if reader is None:
                print(f"未知的文件类型 '{file_type}'，跳过表 '{table_name}'。")
                continue
            # Double embedded single quotes so a path containing ' cannot break out of
            # the SQL string literal. Table names come from the caller's config and are
            # assumed to be plain identifiers.
            escaped_path = str(file_path).replace("'", "''")
            conn.execute(f"CREATE TABLE {table_name} AS SELECT * FROM {reader}('{escaped_path}')")
            print(f"已注册表: {table_name} -> {file_path} ({file_type})")

        # NOTE(review): `sql` is model-generated and runs with full DuckDB privileges
        # (e.g. it could read or write local files); consider restricting this if the
        # input is not trusted.
        return conn.execute(sql).fetchdf()
    except Exception as e:
        # Top-level boundary for the script: report the error and signal failure with None.
        print(f"执行SQL查询时发生错误: {e}")
        return None
    finally:
        # Previously skipped on any exception, leaking the connection.
        if conn is not None:
            conn.close()

def main():
    """
    Interactive entry point: read a question, generate SQL for it, run it and print the result.

    Stops early without calling the model in two cases: when any data file is missing
    (the missing paths are listed), or when the question is empty or whitespace-only.
    """
    # Map each table name to its backing file path and file type.
    tables_files = {
        't_zacks_fc': {'path': 't_zacks_fc.parquet', 'type': 'parquet'},       # replace with your Parquet file paths
        't_zacks_fr': {'path': 't_zacks_fr.parquet', 'type': 'parquet'},
        't_zacks_mktv': {'path': 't_zacks_mktv.parquet', 'type': 'parquet'},
        't_zacks_shrs': {'path': 't_zacks_shrs.parquet', 'type': 'parquet'},
        't_zacks_sectors': {'path': 't_zacks_sectors.csv', 'type': 'csv'}       # CSV file
    }

    # Verify every data file exists before spending an API call.
    missing_files = [info['path'] for info in tables_files.values() if not os.path.exists(info['path'])]
    if missing_files:
        print("以下文件未找到，请检查路径:")
        for path in missing_files:
            print(f" - {path}")
        return

    # The user's natural-language question; a blank one would only waste a model call.
    user_question = input("请输入你的问题: ").strip()
    if not user_question:
        print("问题不能为空。")
        return

    sql_query = generate_sql(user_question)
    if not sql_query:
        print("未能生成SQL语句。")
        return

    results = execute_sql(sql_query, tables_files)
    if results is None:
        print("未能获取查询结果。")
        return

    print("\n查询结果:")
    print(results)

if __name__ == "__main__":
    # Run the interactive pipeline only when executed directly, not when imported.
    main()



原始生成的SQL查询:
SELECT fc.*, fr.ret_invst, fr.tot_debt_tot_equity, mktv.mkt_val, shrs.shares_out
FROM t_zacks_fc fc
JOIN t_zacks_fr fr ON fc.ticker = fr.ticker AND fc.per_end_date = fr.per_end_date AND fc.per_type = fr.per_type
JOIN t_zacks_mktv mktv ON fc.ticker = mktv.ticker AND fc.per_end_date = mktv.per_end_date AND fc.per_type = mktv.per_type
JOIN t_zacks_shrs shrs ON fc.ticker = shrs.ticker AND fc.per_end_date = shrs.per_end_date AND fc.per_type = shrs.per_type
WHERE fc.ticker = 'MSFT' AND fc.per_end_date >= '2010-01-01';

清理后的SQL查询:
SELECT fc.*, fr.ret_invst, fr.tot_debt_tot_equity, mktv.mkt_val, shrs.shares_out
FROM t_zacks_fc fc
JOIN t_zacks_fr fr ON fc.ticker = fr.ticker AND fc.per_end_date = fr.per_end_date AND fc.per_type = fr.per_type
JOIN t_zacks_mktv mktv ON fc.ticker = mktv.ticker AND fc.per_end_date = mktv.per_end_date AND fc.per_type = mktv.per_type
JOIN t_zacks_shrs shrs ON fc.ticker = shrs.ticker AND fc.per_end_date = shrs.per_end_date AND fc.per_type = shrs.per_type
W