In [None]:
import os
import openai
import duckdb
import pandas as pd
import re
from fpdf import FPDF
import matplotlib.pyplot as plt
import sys
import requests  # 确保已安装 requests 库

# -------------------------------
# 1. 安全加载 OpenAI API 密钥
# -------------------------------

# Read the API key from the environment so the secret never lives in source
# control or in a shared notebook. The error below names the exact variable.
openai_api_key = os.environ.get("OPENAI_API_KEY", "").strip()

if not openai_api_key:
    raise ValueError("未找到 OpenAI API 密钥。请设置 'OPENAI_API_KEY' 环境变量。")

openai.api_key = openai_api_key

# -------------------------------
# 2. 下载字体（如果需要）
# -------------------------------

def download_font(font_url, save_path, timeout=30):
    """Download a font file and store it at ``save_path``.

    Args:
        font_url (str): URL of the font file to fetch.
        save_path (str): Destination file path; an existing file is overwritten.
        timeout (float): Seconds to wait for the server before giving up.
            Without a timeout a stalled server would block the program forever.

    On any failure the error is printed and the process exits with status 1.
    """
    try:
        response = requests.get(font_url, timeout=timeout)
        response.raise_for_status()
        # The file is only written once the whole payload has been received.
        with open(save_path, 'wb') as f:
            f.write(response.content)
        print(f"字体已成功下载并保存到 {save_path}。")
    except Exception as e:
        print(f"下载字体失败: {e}")
        sys.exit(1)

# 如果您选择使用自定义字体，请确保下载并放置在正确的位置
# 这里我们不使用自定义字体，因此可以跳过此步骤

def sanitize_text(text, encoding="latin-1"):
    """Normalise a value into a single line of text that is safe to put in the PDF.

    Args:
        text: Value to clean; anything that is not a ``str`` is converted
            with ``str()`` first.
        encoding (str | None): Encoding the PDF font can represent. Characters
            that cannot be encoded are replaced with ``?``. The report uses a
            core (non-Unicode) font, hence the ``latin-1`` default. Pass
            ``None`` to skip this step, e.g. when a Unicode font is in use.

    Returns:
        str: The text with newlines, carriage returns and tabs turned into
        spaces, unencodable characters replaced, and surrounding whitespace
        removed.
    """
    if not isinstance(text, str):
        text = str(text)
    # One pass over the string instead of three chained replace() calls.
    text = text.translate(str.maketrans("\n\r\t", "   "))
    if encoding is not None:
        # Model-generated text routinely contains curly quotes, dashes or CJK
        # characters; replace them up front instead of failing at render time.
        text = text.encode(encoding, errors="replace").decode(encoding)
    return text.strip()

# -------------------------------
# 3. 数据加载使用 DuckDB
# -------------------------------

def load_data_duckdb(tables_files):
    """Create an in-memory DuckDB database and load each file into a table.

    Args:
        tables_files (dict): Maps a table name to a dict with the keys
            ``'path'`` (file location) and ``'type'`` (``'parquet'`` or
            ``'csv'``). Entries with any other type are skipped with a message.

    Returns:
        duckdb.DuckDBPyConnection | None: An open connection holding the
        tables, or ``None`` when loading failed (the error is printed and the
        partially initialised connection is closed).
    """
    # DuckDB table function used for each supported file type.
    readers = {'parquet': 'read_parquet', 'csv': 'read_csv_auto'}

    conn = None
    try:
        conn = duckdb.connect(database=':memory:')

        for table_name, file_info in tables_files.items():
            file_path = file_info['path']
            file_type = file_info['type']
            reader = readers.get(file_type)
            if reader is None:
                print(f"未知的文件类型 '{file_type}'，跳过表 '{table_name}'。")
                continue

            # Quote the identifier and escape the path literal so that names
            # or paths containing quotes cannot break the statement.
            quoted_table = '"' + str(table_name).replace('"', '""') + '"'
            quoted_path = "'" + str(file_path).replace("'", "''") + "'"
            conn.execute(
                f"CREATE TABLE {quoted_table} AS SELECT * FROM {reader}({quoted_path})"
            )
            print(f"已注册表: {table_name} -> {file_path} ({file_type})")

        print("所有表已成功注册到 DuckDB。")
        return conn
    except Exception as e:
        print(f"将数据加载到 DuckDB 时出错: {e}")
        # Do not leak the connection when registration fails half-way.
        if conn is not None:
            conn.close()
        return None

# -------------------------------
# 4. 生成 SQL 查询
# -------------------------------

def _strip_code_fences(text):
    """Return ``text`` without a surrounding Markdown code fence, if present."""
    text = text.strip()
    # Opening fence, with or without the "sql" language tag.
    text = re.sub(r'^```(?:sql)?\s*', '', text, flags=re.IGNORECASE)
    # Closing fence.
    text = re.sub(r'\s*```$', '', text)
    return text.strip()


def generate_sql(user_question, model="ft:gpt-4o-2024-08-06:personal::AXYv83vn"):
    """Ask an OpenAI chat model to translate a question into a SQL query.

    The prompt embeds the database schema and several worked examples, then
    the user's question. The reply is printed, stripped of any Markdown code
    fence, printed again and returned.

    Args:
        user_question (str): The user's natural-language question.
        model (str): Name of the chat model to call.

    Returns:
        str | None: The cleaned SQL text (possibly empty if the model returned
        nothing), or ``None`` when the OpenAI call raised an error.
    """
    prompt = f"""
Pretend you are an expert at converting natural language questions into accurate SQL queries. Please generate an accurate SQL query based on the following natural language question and database schema provided below. Think sequentially and refer to the sample natural language questions with correct and incorrect outputs as well.

Database Schema:
Table 1: t_zacks_fc (This table contains fundamental indicators for companies)
Columns: 'ticker' = Unique zacks Identifier for each company/stock, ticker or trading symbol, 'comp_name' = Company name, 'exchange' = Exchange traded, 'per_end_date' = Period end date which represents quarterly data, 'per_type' = Period type (eg. Q for quarterly data), 'filing_date' = Filing date, 'filing_type' = Filing type: 10-K, 10-Q, PRELIM, 'zacks_sector_code' = Zacks sector code (Numeric Value eg. 11 = Aerospace), 'eps_diluted_net_basic’ = Earnings per share (EPS) net (Company's net earnings or losses attributable to common shareholders per basic share basis), 'lterm_debt_net_tot' = Net long-term debt (The net amount of long term debt issued and repaid. This field is either calculated as the sum of the long term debt fields or used if a company does not report debt issued and repaid separately).
Keys: ticker, per_end_date, per_type

Table 2: t_zacks_fr (This table contains fundamental ratios for companies)
Columns: 'ticker' = Unique zacks Identifier for each company/stock, ticker or trading symbol, 'per_end_date' = Period end date which represents quarterly data, 'per_type' = Period type (eg. Q for quarterly data), ‘ret_invst’ = Return on investments (An indicator of how profitable a company is relative to its assets invested by shareholders and long-term bond holders. Calculated by dividing a company's operating earnings by its long-term debt and shareholders equity), ‘tot_debt_tot_equity’ = Total debt / total equity (A measure of a company's financial leverage calculated by dividing its long-term debt by stockholders' equity).
Keys: ticker, per_end_date, per_type.

Table 3: t_zacks_mktv (This table contains market value data for companies)
Columns: 'ticker' = Unique zacks Identifier for each company/stock, ticker or trading symbol, 'per_end_date' = Period end date which represents quarterly data, 'per_type' = Period type (eg. Q for quarterly data), ‘mkt_val’ = Market Cap of Company (shares out x last monthly price per share - unit is in Millions).
Keys: ticker, per_end_date, per_type.

Table 4: t_zacks_shrs (This table contains shares outstanding data for companies)
Columns: 'ticker' = Unique zacks Identifier for each company/stock, ticker or trading symbol, 'per_end_date' = Period end date which represents quarterly data, 'per_type' = Period type (eg. Q for quarterly data), ‘shares_out’ = Number of Common Shares Outstanding from the front page of 10K/Q.
Keys: ticker, per_end_date, per_type.

Table 5: t_zacks_sectors (This table contains the zacks sector codes and their corresponding sectors)
Columns: 'zacks_sector_code' = Unique identifier for each zacks sector, 'sector': the sector descriptions that correspond to the sector code 
Keys: zacks_sector_code 

Sample natural language questions with correct and incorrect outputs: 
Sample prompt 1: Output ticker with the largest market value recorded on any given period end date. 
Correct output for prompt 1: SELECT ticker, per_end_date, MAX(mkt_val) AS max_market_value FROM t_zacks_mktv GROUP BY per_end_date ORDER BY max_market_value DESC LIMIT 1;
Incorrect output for prompt 1: SELECT MAX(mkt_val) , ticker FROM t_zacks_mktv GROUP BY ticker

Sample prompt 2: What is the company name with the lowest market cap?
Correct output for prompt 2: SELECT fc.comp_name, mktv.ticker, mktv.mkt_val FROM t_zacks_mktv AS mktv JOIN t_zacks_fc AS fc ON mktv.ticker = fc.ticker WHERE mktv.mkt_val = (SELECT MIN(mkt_val) FROM t_zacks_mktv);
Incorrect output for prompt 2:  SELECT T1.comp_name FROM t_zacks_fc AS T1 INNER JOIN t_zacks_mktv AS T2 ON T1.ticker = T2.ticker AND T1.per_end_date = T2.per_end_date AND T1.per_type = T2.per_type ORDER BY T2.mkt_val LIMIT 1

Sample prompt 3: Filter t_zacks_fc to only show companies with a total debt-to-equity ratio greater than 1.
Correct output for prompt 3: SELECT * FROM t_zacks_fr WHERE tot_debt_tot_equity > 1;
Incorrect output for prompt 3: SELECT * FROM t_zacks_fr WHERE t_zacks_mktv > 1;

Sample prompt 4: Filter t_zacks_shrs to include companies with more than 500 million shares outstanding as of the most recent quarter.
Correct output for prompt 4: SELECT *
FROM t_zacks_shrs
WHERE shares_out > 5000
ORDER BY per_end_date DESC;
Incorrect output for prompt 4: SELECT * FROM t_zacks_shrs WHERE shares_out > 500000000

Sample prompt 5: Combine t_zacks_mktv and t_zacks_shrs to show tickers with market cap and shares outstanding in the latest period end date.
Correct output for prompt 5: SELECT mktv.ticker, mktv.per_end_date, mktv.mkt_val, shrs.shares_out
FROM t_zacks_mktv mktv
JOIN t_zacks_shrs shrs ON mktv.ticker = shrs.ticker AND mktv.per_end_date = shrs.per_end_date
ORDER BY mktv.per_end_date DESC;
Incorrect output for prompt 5: SELECT ticker, mkt_val, shares_out FROM t_zacks_mktv INNER JOIN t_zacks_shrs ON t_zacks_mktv.ticker = t_zacks_shrs.ticker AND t_zacks_mktv.per_end_date = t_zacks_shrs.per_end_date ORDER BY per_end_date DESC LIMIT 1

Sample prompt 6: Join t_zacks_fc and t_zacks_fr to show tickers with total debt-to-equity ratios and EPS from NASDAQ as of Q2 2024.
Correct output for prompt 6: SELECT fc.ticker, fc.eps_diluted_net_basic, fr.tot_debt_tot_equity
FROM t_zacks_fc fc
JOIN t_zacks_fr fr ON fc.ticker = fr.ticker AND fc.per_end_date = fr.per_end_date
WHERE fc.exchange = 'NASDAQ' AND fc.per_type = 'Q' AND fc.per_end_date BETWEEN '2024-04-01' AND '2024-06-30';
Incorrect output for prompt 6: SELECT T1.ticker, T1.eps_diluted_net_basic, T2.ret_invst, T2.tot_debt_tot_equity FROM t_zacks_fc AS T1 INNER JOIN t_zacks_fr AS T2 ON T1.ticker = T2.ticker AND T1.per_end_date = T2.per_end_date WHERE T1.exchange = 'NASDAQ' AND T1.per_type = 'Q2';

Please make sure that when you are joining 2 or more tables, you are using all 3 keys (ticker, per_end_date & per_type). Also, ensure that the SQL query is syntactically correct and provides the expected output based on the natural language question provided.

User's Question:
{user_question}

Please provide only the SQL query without any markdown, code block syntax, or explanations.
    """
    try:
        response = openai.ChatCompletion.create(
            model=model,
            messages=[
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": prompt}
            ],
            max_tokens=300,  # raise for more complex queries
            temperature=0.0,  # 0 keeps the output as deterministic as possible
            n=1,
            stop=None
        )
        # The content may be missing; treat that as an empty reply rather
        # than crashing on .strip().
        raw_sql = (response.choices[0].message['content'] or "").strip()
        print("\n生成的 SQL 查询:")
        print(raw_sql)

        # The model sometimes wraps the query in a Markdown fence despite the
        # instruction not to; remove it whether or not it carries a "sql" tag.
        sql_query = _strip_code_fences(raw_sql)

        print("\n清理后的 SQL 查询:")
        print(sql_query)
        return sql_query
    except openai.OpenAIError as e:
        print(f"生成 SQL 查询时发生错误: {e}")
        return None

# -------------------------------
# 5. 执行 SQL 查询 使用 DuckDB
# -------------------------------

def execute_sql_duckdb(sql, conn):
    """Run ``sql`` on ``conn`` and hand the rows back as a DataFrame.

    Args:
        sql (str): The SQL statement to execute.
        conn: Connection object whose ``execute(sql)`` result offers
            ``fetchdf()`` (a DuckDB connection in this program).

    Returns:
        pd.DataFrame: The query result. If anything goes wrong the error is
        printed and an empty DataFrame is returned instead.
    """
    try:
        frame = conn.execute(sql).fetchdf()
        print("\nSQL 查询执行成功。")
        row_total = frame.shape[0]
        print(f"检索到的记录数: {row_total}")
    except Exception as exc:
        print(f"执行 SQL 查询时出错: {exc}")
        return pd.DataFrame()
    return frame

# -------------------------------
# 6. 生成分析 使用 OpenAI API
# -------------------------------

def generate_analysis_from_openai(dataframe, user_question,
                                  model="ft:gpt-4o-2024-08-06:personal::AYFZ3Shk"):
    """Ask an OpenAI chat model for an analyst-style commentary on query results.

    Args:
        dataframe (pd.DataFrame): Result of the SQL query.
        user_question (str): The question that produced ``dataframe``.
        model (str): Name of the chat model to call.

    Returns:
        str: The generated commentary; a fixed notice when ``dataframe`` is
        empty; or an error description when the OpenAI call fails.
    """
    if dataframe.empty:
        return "No data available for analysis."

    try:
        table_md = dataframe.to_markdown(index=False)
    except ImportError:
        # to_markdown depends on the optional 'tabulate' package; fall back
        # to plain text so a missing extra does not abort the whole session.
        table_md = dataframe.to_string(index=False)
    prompt = f"""
I have executed a SQL query based on the following user question and obtained the data below.

User's Question:
{user_question}

Data Table:
{table_md}

Pretend you are an experienced equity analyst working in the banking industry. Please analyze this data in the style of an expert equity analyst, highlighting trends, comparing companies, analyzing significance of metrics, and noting any interesting insights regarding this data.
    """

    try:
        response = openai.ChatCompletion.create(
            model=model,
            messages=[
                {"role": "system", "content": "You are an experienced equity analyst."},
                {"role": "user", "content": prompt}
            ],
            max_tokens=500,
            temperature=0.5,
            n=1,
            stop=None
        )
        # The content may be missing; treat that as an empty commentary.
        analysis = (response.choices[0].message['content'] or "").strip()
        print("\n生成的分析:")
        print(analysis)
        return analysis
    except openai.OpenAIError as e:
        return f"生成分析时发生错误: {e}"

# -------------------------------
# 7. 定义 PDF 生成类
# -------------------------------

class PDF(FPDF):
    """FPDF document with helpers for the report's titles, text, tables and images."""

    def __init__(self):
        super().__init__()
        # Default body font (a core font, no font file required).
        self.set_font("Arial", "", 12)
        self.set_auto_page_break(auto=True, margin=15)
        self.set_margins(left=15, top=20, right=15)

    def header(self):
        """Draw the report title at the top of a page."""
        self.set_font("Arial", "B", 16)
        self.cell(0, 10, "Equity Analyst Report", align="C", ln=True)
        self.ln(5)

    def chapter_title(self, title):
        """Write a left-aligned, bold section heading."""
        self.set_font("Arial", "B", 14)
        title = sanitize_text(title)
        self.cell(0, 10, title, 0, 1, "L")
        self.ln(2)

    def chapter_body(self, body):
        """Write a block of body text, wrapped across the page width."""
        self.set_font("Arial", "", 12)
        body = sanitize_text(body)
        self.multi_cell(0, 10, body)
        self.ln()

    def table(self, data):
        """Render ``data`` (a DataFrame) as a bordered grid.

        The header row is repeated at the top of every page the table spans.
        An empty frame produces a short notice instead of a table.
        """
        if data.empty:
            self.set_font("Arial", "I", 12)
            self.cell(0, 10, "No data available to display.", 0, 1, 'C')
            self.ln()
            return

        row_height = 8
        col_widths = self.calculate_col_widths(data)
        # Widths are looked up with the ORIGINAL column labels and everything
        # is then handled by position: the sanitized label may differ from
        # the original (non-string or padded names), and duplicate column
        # names would otherwise make row lookups ambiguous.
        widths = [col_widths[col] for col in data.columns]
        labels = [sanitize_text(col) for col in data.columns]

        # Keep the header together with at least one data row.
        if self.get_y() + 2 * row_height > self.page_break_trigger:
            self.add_page()
        self._table_header_row(labels, widths, row_height)

        for values in data.itertuples(index=False, name=None):
            # Check the space actually left on the CURRENT page, so that
            # continuation pages are filled completely.
            if self.get_y() + row_height > self.page_break_trigger:
                self.add_page()
                self._table_header_row(labels, widths, row_height)
            for width, value in zip(widths, values):
                cell_text = sanitize_text(value) if pd.notnull(value) else ""
                self.cell(width, row_height, cell_text, 1, 0, 'C')
            self.ln()
        self.ln()

    def _table_header_row(self, labels, widths, row_height):
        """Draw one bold header row, then switch back to the table body font."""
        self.set_font("Arial", "B", 10)
        for label, width in zip(labels, widths):
            self.cell(width, row_height, label, 1, 0, 'C')
        self.ln()
        self.set_font("Arial", "", 10)

    def calculate_col_widths(self, data):
        """Return ``{column label: width}`` with the printable width split evenly."""
        column_count = len(data.columns)
        if column_count == 0:
            return {}
        # Derive the printable width from the real margins, not a fixed 30.
        usable_width = self.w - self.l_margin - self.r_margin
        return {col: usable_width / column_count for col in data.columns}

    def add_image(self, image_path, title, width=180):
        """Add a titled image; print a notice and do nothing if the file is missing."""
        if not os.path.exists(image_path):
            print(f"图像文件 {image_path} 不存在。")
            return
        self.chapter_title(title)
        self.image(image_path, w=width)
        self.ln(10)

# -------------------------------
# 8. 生成 PDF 报告
# -------------------------------

def generate_pdf_report(pdf, analysis_text, data_table, chart_paths, filename="equity_analyst_report.pdf"):
    """Fill ``pdf`` with the report sections and write it to ``filename``.

    Sections, in order: a fixed overview, the analysis text, one image per
    chart (only when ``chart_paths`` is non-empty) and the data table.

    Args:
        pdf: Report document offering ``chapter_title``, ``chapter_body``,
            ``add_image``, ``table`` and ``output``.
        analysis_text (str): Commentary placed in the "Analysis" section.
        data_table: Tabular data passed to ``pdf.table``.
        chart_paths (list[str]): Image files to embed; may be empty.
        filename (str): Output path of the PDF.

    A failure while saving is reported on stdout rather than raised.
    """
    # (substring of the lower-cased path, caption), checked in this order.
    captions = (
        ('pie_chart', "Market Value Distribution Pie Chart"),
        ('bar_chart', "Market Value Comparison Bar Chart"),
    )

    # Overview
    pdf.chapter_title("Selected Companies Overview")
    pdf.chapter_body(
        "This report provides an analysis of selected companies based on the user's query, including data on revenue, net income, and market capitalization."
    )

    # Analysis
    pdf.chapter_title("Analysis")
    pdf.chapter_body(analysis_text)

    # Charts
    if chart_paths:
        pdf.chapter_title("Visualizations")
        for chart_path in chart_paths:
            lowered = chart_path.lower()
            caption = next(
                (text for marker, text in captions if marker in lowered),
                "Chart",
            )
            pdf.add_image(chart_path, caption)

    # Data table
    pdf.chapter_title("Company Financial Data")
    pdf.table(data_table)

    # Save
    try:
        pdf.output(filename)
        print(f"报告已生成并保存为 {filename}")
    except Exception as exc:
        print(f"保存 PDF 时出错: {exc}")

# -------------------------------
# 9. 交互式聊天功能
# -------------------------------

def interactive_chat_duckdb(conn, tables_files):
    """Run the question -> SQL -> analysis -> PDF loop until the user quits.

    For every question the loop generates SQL, executes it on ``conn``,
    requests an analysis, draws charts and writes a PDF report. Blank lines
    are ignored, ``help``/``h`` prints example questions, and ``exit``/
    ``quit`` (or end of input) ends the session.

    Args:
        conn: Open DuckDB connection holding the data tables.
        tables_files (dict): Table-to-file mapping; accepted for interface
            compatibility and not used by the loop.
    """
    print("\n开始与助手聊天。您可以询问有关数据的问题。")
    print("输入 'exit' 或 'quit' 以结束。\n")

    while True:
        try:
            user_input = input("您: ").strip()
        except EOFError:
            # stdin was closed (piped input, Ctrl-D): end the session cleanly.
            print("结束聊天。再见！")
            break

        # Do not spend an API call on an empty line.
        if not user_input:
            continue

        command = user_input.lower()
        if command in ("exit", "quit"):
            print("结束聊天。再见！")
            break
        if command in ("help", "h"):
            print("\n您可以询问与数据相关的问题，例如：")
            print("- 哪些公司是按市值排名前5的？")
            print("- 显示公司X的财务指标。")
            print("- 比较科技行业中公司的市值。\n")
            continue

        # Turn the question into SQL.
        sql_query = generate_sql(user_input)
        if not sql_query:
            print("生成 SQL 查询失败。请尝试其他问题。")
            continue

        # Run the SQL.
        query_result = execute_sql_duckdb(sql_query, conn)
        if query_result.empty:
            print("SQL 查询未返回任何数据。")
            continue
        print("\n查询结果:")
        print(query_result)

        # Commentary and charts for the result.
        analysis_text = generate_analysis_from_openai(query_result, user_input)
        chart_paths = generate_charts(query_result)

        # Build the report. generate_pdf_report prints whether saving
        # succeeded or failed, so no unconditional success message here.
        pdf_filename = "equity_analyst_report.pdf"
        pdf = PDF()
        pdf.add_page()
        generate_pdf_report(pdf, analysis_text, query_result, chart_paths, filename=pdf_filename)

# -------------------------------
# 10. 生成图表
# -------------------------------

def _draw_market_value_pie(ax, top_companies):
    """Draw the market-value share of each ticker as a pie chart on ``ax``."""
    ax.pie(top_companies['mkt_val'], labels=top_companies['ticker'], autopct='%1.1f%%', startangle=140)
    ax.set_title('Top 5 Companies Market Value Distribution')


def _draw_market_value_bar(ax, top_companies):
    """Draw the market value of each ticker as a bar chart on ``ax``."""
    ax.bar(top_companies['ticker'], top_companies['mkt_val'], color='skyblue')
    ax.set_xlabel('Ticker')
    ax.set_ylabel('Market Value (Millions)')
    ax.set_title('Market Value of Top 5 Companies')


def generate_charts(dataframe, output_dir="charts"):
    """Save a pie chart and a bar chart of the five largest tickers by market value.

    ``mkt_val`` is summed per ``ticker`` and the five largest totals are
    plotted. Each chart is produced independently, so a failure in one does
    not prevent the other, and every figure is closed even on failure.

    Args:
        dataframe (pd.DataFrame): Query result; needs the columns
            ``'ticker'`` and ``'mkt_val'`` for any chart to be drawn.
        output_dir (str): Directory for the PNG files; created when missing.

    Returns:
        list[str]: Paths of the charts that were written (possibly empty).
    """
    if dataframe.empty:
        print("没有数据可用于生成图表。")
        return []

    if 'mkt_val' not in dataframe.columns or 'ticker' not in dataframe.columns:
        print("数据中不包含生成图表所需的列 ('mkt_val', 'ticker')。")
        return []

    chart_paths = []
    try:
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(output_dir, exist_ok=True)
        aggregated_data = dataframe.groupby('ticker', as_index=False)['mkt_val'].sum()
        top_companies = aggregated_data.nlargest(5, 'mkt_val')
    except Exception as e:
        print(f"生成图表时出错: {e}")
        return chart_paths

    charts = (
        ('market_value_pie_chart.png', (6, 6), _draw_market_value_pie),
        ('market_value_bar_chart.png', (8, 6), _draw_market_value_bar),
    )
    for file_name, figsize, draw in charts:
        fig = None
        try:
            fig, ax = plt.subplots(figsize=figsize)
            draw(ax, top_companies)
            chart_path = os.path.join(output_dir, file_name)
            fig.savefig(chart_path, bbox_inches='tight')
            chart_paths.append(chart_path)
        except Exception as e:
            # e.g. a pie chart cannot be drawn from negative values.
            print(f"生成图表时出错: {e}")
        finally:
            # Always release the figure; pyplot keeps it alive otherwise.
            if fig is not None:
                plt.close(fig)

    if chart_paths:
        print(f"\n图表已生成并保存在 '{output_dir}' 目录。")

    return chart_paths

# -------------------------------
# 11. 主函数集成所有功能
# -------------------------------

def main():
    """Load the data files into DuckDB and start the interactive chat.

    Returns early, after listing the missing paths, when any input file is
    absent. The DuckDB connection is closed when the chat ends or fails.
    """
    # Table name -> source file and format. Adjust the paths as needed.
    tables_files = {
        't_zacks_fc': {'path': 't_zacks_fc.parquet', 'type': 'parquet'},
        't_zacks_fr': {'path': 't_zacks_fr.parquet', 'type': 'parquet'},
        't_zacks_mktv': {'path': 't_zacks_mktv.parquet', 'type': 'parquet'},
        't_zacks_shrs': {'path': 't_zacks_shrs.parquet', 'type': 'parquet'},
        't_zacks_sectors': {'path': 't_zacks_sectors.csv', 'type': 'csv'},
    }

    # Refuse to start unless every input file is present.
    missing_files = []
    for info in tables_files.values():
        if not os.path.exists(info['path']):
            missing_files.append(info['path'])
    if missing_files:
        print("以下文件未找到，请检查路径:")
        for path in missing_files:
            print(f" - {path}")
        return

    conn = None
    try:
        conn = load_data_duckdb(tables_files)
        if not conn:
            print("无法将数据加载到 DuckDB。正在退出。")
            return

        interactive_chat_duckdb(conn, tables_files)

    except Exception as e:
        print(f"发生错误: {e}")
    finally:
        # Close the connection if one was opened.
        if conn:
            conn.close()

# Script entry point: run the application and exit quietly on Ctrl-C
# instead of showing a traceback.
if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        print("\n程序被用户中断。正在退出。")
        sys.exit()


已注册表: t_zacks_fc -> t_zacks_fc.parquet (parquet)
已注册表: t_zacks_fr -> t_zacks_fr.parquet (parquet)
已注册表: t_zacks_mktv -> t_zacks_mktv.parquet (parquet)
已注册表: t_zacks_shrs -> t_zacks_shrs.parquet (parquet)
已注册表: t_zacks_sectors -> t_zacks_sectors.csv (csv)
所有表已成功注册到 DuckDB。

开始与助手聊天。您可以询问有关数据的问题。
输入 'exit' 或 'quit' 以结束。


生成的 SQL 查询:
SELECT fc.*, fr.*, mktv.*, shrs.*
FROM t_zacks_fc fc
JOIN t_zacks_fr fr ON fc.ticker = fr.ticker AND fc.per_end_date = fr.per_end_date AND fc.per_type = fr.per_type
JOIN t_zacks_mktv mktv ON fc.ticker = mktv.ticker AND fc.per_end_date = mktv.per_end_date AND fc.per_type = mktv.per_type
JOIN t_zacks_shrs shrs ON fc.ticker = shrs.ticker AND fc.per_end_date = shrs.per_end_date AND fc.per_type = shrs.per_type
WHERE fc.ticker = 'AAPL'
ORDER BY fc.per_end_date DESC;

清理后的 SQL 查询:
SELECT fc.*, fr.*, mktv.*, shrs.*
FROM t_zacks_fc fc
JOIN t_zacks_fr fr ON fc.ticker = fr.ticker AND fc.per_end_date = fr.per_end_date AND fc.per_type = fr.per_type
JOIN t_zacks_mktv mktv