In [None]:
import os
import openai
import duckdb
import pandas as pd
import re
from fpdf import FPDF
import matplotlib.pyplot as plt
import sys
import requests  # Ensure the requests library is installed

# -------------------------------
# 1. Securely Load OpenAI API Key
# -------------------------------

# Read the API key from the environment so the secret never lives in the
# source file (or in a shared notebook).
openai_api_key = os.environ.get("OPENAI_API_KEY", "").strip()

# Fail fast with an actionable message instead of failing later on the first API call.
if not openai_api_key:
    raise ValueError("OpenAI API key not found. Please set the 'OPENAI_API_KEY' environment variable.")

openai.api_key = openai_api_key

# -------------------------------
# 2. Download Font (If Needed)
# -------------------------------

def download_font(font_url, save_path, timeout=30):
    """
    Download a font file and write it to disk.

    On any failure the error is printed and the process exits with status 1.

    Args:
        font_url (str): URL of the font file to fetch.
        save_path (str): Local path the downloaded bytes are written to.
        timeout (float): Seconds to wait for the server before giving up.
            Without a timeout, requests may block indefinitely.
    """
    try:
        response = requests.get(font_url, timeout=timeout)
        # Treat HTTP error statuses (4xx/5xx) as failures rather than saving an error page.
        response.raise_for_status()
        with open(save_path, 'wb') as f:
            f.write(response.content)
        print(f"Font successfully downloaded and saved to {save_path}.")
    except Exception as e:
        print(f"Failed to download font: {e}")
        sys.exit(1)

# If you choose to use a custom font, ensure it is downloaded and placed in the correct location
# Here we are not using a custom font, so this step can be skipped

# Single-pass character map: whitespace control characters become spaces, and
# common typographic characters that are outside Latin-1 are mapped to plain
# ASCII look-alikes.
_PDF_SAFE_TRANSLATION = str.maketrans({
    '\n': ' ',
    '\r': ' ',
    '\t': ' ',
    '\u2018': "'",
    '\u2019': "'",
    '\u201c': '"',
    '\u201d': '"',
    '\u2013': '-',
    '\u2014': '-',
    '\u2026': '...',
    '\u2022': '*',
})


def sanitize_text(text, encoding="latin-1"):
    """
    Clean and process text to ensure it displays correctly in the PDF.

    Newlines, carriage returns and tabs are replaced by spaces, common
    typographic punctuation is mapped to ASCII equivalents, and any character
    that still cannot be represented in ``encoding`` is replaced by ``?``.
    The report is rendered with the built-in Arial font, which cannot encode
    characters outside Latin-1.

    Args:
        text (str): The text to be cleaned. Non-string values are converted
            with ``str()`` first.
        encoding (str | None): Target encoding of the PDF font. Pass ``None``
            to skip the encoding step (e.g. when a Unicode font is used).

    Returns:
        str: The cleaned text with leading/trailing whitespace removed.
    """
    if not isinstance(text, str):
        text = str(text)
    text = text.translate(_PDF_SAFE_TRANSLATION)
    if encoding is not None:
        # Round-trip through the font encoding so unsupported characters
        # degrade to '?' instead of raising when the PDF is written.
        text = text.encode(encoding, errors='replace').decode(encoding)
    return text.strip()

# -------------------------------
# 3. Load Data Using DuckDB
# -------------------------------

def load_data_duckdb(tables_files):
    """
    Register Parquet and CSV files as tables in DuckDB.

    Each entry is materialised with ``CREATE TABLE ... AS SELECT * FROM
    <reader>('<path>')``. Entries with an unknown file type are skipped with a
    message. If anything fails, the error is printed, the partially populated
    connection is closed, and ``None`` is returned.

    Args:
        tables_files (dict): A mapping dictionary of table names to file paths and types.
            Each value is a dict with a 'path' and a 'type' ('parquet' or 'csv').
            Table names are interpolated into SQL as-is, so they must be
            trusted, valid identifiers.

    Returns:
        duckdb.DuckDBPyConnection: An active DuckDB connection with registered tables,
        or None if loading failed.
    """
    # File type -> DuckDB table function used to read it.
    readers = {'parquet': 'read_parquet', 'csv': 'read_csv_auto'}

    conn = None
    try:
        # Connect to an in-memory DuckDB database
        conn = duckdb.connect(database=':memory:')

        # Register each file as a table
        for table_name, file_info in tables_files.items():
            file_path = file_info['path']
            file_type = file_info['type']
            reader = readers.get(file_type)
            if reader is None:
                print(f"Unknown file type '{file_type}', skipping table '{table_name}'.")
                continue
            # Double embedded single quotes so a path containing an apostrophe
            # cannot terminate the SQL string literal early.
            escaped_path = str(file_path).replace("'", "''")
            conn.execute(f"CREATE TABLE {table_name} AS SELECT * FROM {reader}('{escaped_path}')")
            print(f"Registered table: {table_name} -> {file_path} ({file_type})")

        print("All tables have been successfully registered to DuckDB.")
        return conn
    except Exception as e:
        print(f"Error loading data into DuckDB: {e}")
        # The caller only receives None on failure, so release the connection here.
        if conn is not None:
            conn.close()
        return None

# -------------------------------
# 4. Generate SQL Query
# -------------------------------

def generate_sql(user_question, model="ft:gpt-4o-2024-08-06:personal::AXYv83vn"):
    """
    Generate an SQL query based on the user's natural language question using the OpenAI GPT model.

    The question is embedded in a prompt that describes the five Zacks tables
    and gives few-shot examples. The model's reply is printed, stripped of any
    Markdown code fence, printed again, and returned.

    Args:
        user_question (str): The user's natural language question.
        model (str): The name of the GPT model to use.

    Returns:
        str: The generated SQL query, or None if the OpenAI call failed.
    """
    # NOTE(review): the prompt is kept verbatim. Some few-shot examples look
    # inconsistent with its own instructions (e.g. the "correct" joins use two
    # keys while the closing paragraph asks for all three) - confirm with the
    # prompt owner before editing, since the model is fine-tuned.
    prompt = f"""
Pretend you are an expert at converting natural language questions into accurate SQL queries. Please generate an accurate SQL query based on the following natural language question and database schema provided below. Think sequentially and refer to the sample natural language questions with correct and incorrect outputs as well.

Database Schema:
Table 1: t_zacks_fc (This table contains fundamental indicators for companies)
Columns: 'ticker' = Unique Zacks Identifier for each company/stock, ticker or trading symbol, 'comp_name' = Company name, 'exchange' = Exchange traded, 'per_end_date' = Period end date which represents quarterly data, 'per_type' = Period type (e.g., Q for quarterly data), 'filing_date' = Filing date, 'filing_type' = Filing type: 10-K, 10-Q, PRELIM, 'zacks_sector_code' = Zacks sector code (Numeric Value e.g., 11 = Aerospace), 'eps_diluted_net_basic’ = Earnings per share (EPS) net (Company's net earnings or losses attributable to common shareholders per basic share basis), 'lterm_debt_net_tot' = Net long-term debt (The net amount of long-term debt issued and repaid. This field is either calculated as the sum of the long-term debt fields or used if a company does not report debt issued and repaid separately).
Keys: ticker, per_end_date, per_type

Table 2: t_zacks_fr (This table contains fundamental ratios for companies)
Columns: 'ticker' = Unique Zacks Identifier for each company/stock, ticker or trading symbol, 'per_end_date' = Period end date which represents quarterly data, 'per_type' = Period type (e.g., Q for quarterly data), ‘ret_invst’ = Return on investments (An indicator of how profitable a company is relative to its assets invested by shareholders and long-term bond holders. Calculated by dividing a company's operating earnings by its long-term debt and shareholders equity), ‘tot_debt_tot_equity’ = Total debt / total equity (A measure of a company's financial leverage calculated by dividing its long-term debt by stockholders' equity).
Keys: ticker, per_end_date, per_type.

Table 3: t_zacks_mktv (This table contains market value data for companies)
Columns: 'ticker' = Unique Zacks Identifier for each company/stock, ticker or trading symbol, 'per_end_date' = Period end date which represents quarterly data, 'per_type' = Period type (e.g., Q for quarterly data), ‘mkt_val’ = Market Cap of Company (shares out x last monthly price per share - unit is in Millions).
Keys: ticker, per_end_date, per_type.

Table 4: t_zacks_shrs (This table contains shares outstanding data for companies)
Columns: 'ticker' = Unique Zacks Identifier for each company/stock, ticker or trading symbol, 'per_end_date' = Period end date which represents quarterly data, 'per_type' = Period type (e.g., Q for quarterly data), ‘shares_out’ = Number of Common Shares Outstanding from the front page of 10K/Q.
Keys: ticker, per_end_date, per_type.

Table 5: t_zacks_sectors (This table contains the Zacks sector codes and their corresponding sectors)
Columns: 'zacks_sector_code' = Unique identifier for each Zacks sector, 'sector' = The sector descriptions that correspond to the sector code 
Keys: zacks_sector_code 

Sample natural language questions with correct and incorrect outputs: 
Sample prompt 1: Output ticker with the largest market value recorded on any given period end date. 
Correct output for prompt 1: SELECT ticker, per_end_date, MAX(mkt_val) AS max_market_value FROM t_zacks_mktv GROUP BY per_end_date ORDER BY max_market_value DESC LIMIT 1;
Incorrect output for prompt 1: SELECT MAX(mkt_val), ticker FROM t_zacks_mktv GROUP BY ticker

Sample prompt 2: What is the company name with the lowest market cap?
Correct output for prompt 2: SELECT fc.comp_name, mktv.ticker, mktv.mkt_val FROM t_zacks_mktv AS mktv JOIN t_zacks_fc AS fc ON mktv.ticker = fc.ticker WHERE mktv.mkt_val = (SELECT MIN(mkt_val) FROM t_zacks_mktv);
Incorrect output for prompt 2: SELECT T1.comp_name FROM t_zacks_fc AS T1 INNER JOIN t_zacks_mktv AS T2 ON T1.ticker = T2.ticker AND T1.per_end_date = T2.per_end_date AND T1.per_type = T2.per_type ORDER BY T2.mkt_val LIMIT 1

Sample prompt 3: Filter t_zacks_fc to only show companies with a total debt-to-equity ratio greater than 1.
Correct output for prompt 3: SELECT * FROM t_zacks_fr WHERE tot_debt_tot_equity > 1;
Incorrect output for prompt 3: SELECT * FROM t_zacks_fr WHERE t_zacks_mktv > 1;

Sample prompt 4: Filter t_zacks_shrs to include companies with more than 500 million shares outstanding as of the most recent quarter.
Correct output for prompt 4: SELECT *
FROM t_zacks_shrs
WHERE shares_out > 5000
ORDER BY per_end_date DESC;
Incorrect output for prompt 4: SELECT * FROM t_zacks_shrs WHERE shares_out > 500000000

Sample prompt 5: Combine t_zacks_mktv and t_zacks_shrs to show tickers with market cap and shares outstanding in the latest period end date.
Correct output for prompt 5: SELECT mktv.ticker, mktv.per_end_date, mktv.mkt_val, shrs.shares_out
FROM t_zacks_mktv mktv
JOIN t_zacks_shrs shrs ON mktv.ticker = shrs.ticker AND mktv.per_end_date = shrs.per_end_date
ORDER BY mktv.per_end_date DESC;
Incorrect output for prompt 5: SELECT ticker, mkt_val, shares_out FROM t_zacks_mktv INNER JOIN t_zacks_shrs ON t_zacks_mktv.ticker = t_zacks_shrs.ticker AND t_zacks_mktv.per_end_date = t_zacks_shrs.per_end_date ORDER BY per_end_date DESC LIMIT 1

Sample prompt 6: Join t_zacks_fc and t_zacks_fr to show tickers with total debt-to-equity ratios and EPS from NASDAQ as of Q2 2024.
Correct output for prompt 6: SELECT fc.ticker, fc.eps_diluted_net_basic, fr.tot_debt_tot_equity
FROM t_zacks_fc fc
JOIN t_zacks_fr fr ON fc.ticker = fr.ticker AND fc.per_end_date = fr.per_end_date
WHERE fc.exchange = 'NASDAQ' AND fc.per_type = 'Q' AND fc.per_end_date BETWEEN '2024-04-01' AND '2024-06-30';
Incorrect output for prompt 6: SELECT T1.ticker, T1.eps_diluted_net_basic, T2.ret_invst, T2.tot_debt_tot_equity FROM t_zacks_fc AS T1 INNER JOIN t_zacks_fr AS T2 ON T1.ticker = T2.ticker AND T1.per_end_date = T2.per_end_date WHERE T1.exchange = 'NASDAQ' AND T1.per_type = 'Q2';

Please make sure that when you are joining 2 or more tables, you are using all 3 keys (ticker, per_end_date & per_type). Also, ensure that the SQL query is syntactically correct and provides the expected output based on the natural language question provided.

User's Question:
{user_question}

Please provide only the SQL query without any markdown, code block syntax, or explanations.
    """
    try:
        response = openai.ChatCompletion.create(
            model=model,
            messages=[
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": prompt}
            ],
            max_tokens=300,  # Adjust based on complexity
            temperature=0.0,  # Set to 0 for more deterministic output
            n=1,
            stop=None
        )
        # Extract SQL query
        raw_sql = response.choices[0].message['content'].strip()
        print("\nGenerated SQL Query:")
        print(raw_sql)

        # Clean any code block formatting. The opening fence may or may not
        # carry a language tag (```sql or a bare ```), so accept both.
        sql_query = re.sub(r'^```(?:sql)?\s*', '', raw_sql, flags=re.IGNORECASE)
        sql_query = re.sub(r'\s*```$', '', sql_query)
        sql_query = sql_query.strip()

        print("\nCleaned SQL Query:")
        print(sql_query)
        return sql_query
    except openai.OpenAIError as e:
        print(f"Error generating SQL query: {e}")
        return None

# -------------------------------
# 5. Execute SQL Query Using DuckDB
# -------------------------------

def execute_sql_duckdb(sql, conn):
    """
    Run a SQL statement on a DuckDB connection and return the rows as a DataFrame.

    Any exception raised while executing or fetching is printed, and an empty
    DataFrame is returned in its place.

    Args:
        sql (str): The SQL query to execute.
        conn (duckdb.DuckDBPyConnection): An active DuckDB connection.

    Returns:
        pd.DataFrame: The query result, or an empty DataFrame on error.
    """
    try:
        cursor = conn.execute(sql)
        frame = cursor.fetchdf()
        print("\nSQL query executed successfully.")
        row_total = frame.shape[0]
        print(f"Number of records retrieved: {row_total}")
    except Exception as err:
        print(f"Error executing SQL query: {err}")
        return pd.DataFrame()
    return frame

# -------------------------------
# 6. Generate Analysis Using OpenAI API
# -------------------------------

def generate_analysis_from_openai(dataframe, user_question, model="ft:gpt-4o-2024-08-06:personal::AYFZ3Shk"):
    """
    Ask the OpenAI model for an equity-analyst style commentary on query results.

    The DataFrame is rendered as a text table and sent, together with the
    user's question, as the prompt.

    Args:
        dataframe (pd.DataFrame): The query result to analyse.
        user_question (str): The natural language question that produced the data.
        model (str): The name of the GPT model to use.

    Returns:
        str: The generated analysis; a fixed message if the DataFrame is empty;
        or an error description if the OpenAI call failed.
    """
    if dataframe.empty:
        return "No data available for analysis."

    try:
        table_md = dataframe.to_markdown(index=False)
    except ImportError:
        # to_markdown needs the optional 'tabulate' package; fall back to the
        # plain-text rendering rather than aborting the whole report.
        table_md = dataframe.to_string(index=False)

    prompt = f"""
I have executed an SQL query based on the following user question and obtained the data below.

User's Question:
{user_question}

Data Table:
{table_md}

Pretend you are an experienced equity analyst working in the banking industry. Please analyze this data in the style of an expert equity analyst, highlighting trends, comparing companies, analyzing significance of metrics, and noting any interesting insights regarding this data.
    """

    try:
        # Call OpenAI API
        response = openai.ChatCompletion.create(
            model=model,
            messages=[
                {"role": "system", "content": "You are an experienced equity analyst."},
                {"role": "user", "content": prompt}
            ],
            max_tokens=500,
            temperature=0.5,
            n=1,
            stop=None
        )
        # Get analysis content
        analysis = response.choices[0].message['content'].strip()
        print("\nGenerated Analysis:")
        print(analysis)
        return analysis
    except openai.OpenAIError as e:
        return f"Error generating analysis: {e}"

# -------------------------------
# 7. Define PDF Generation Class
# -------------------------------

class PDF(FPDF):
    """
    FPDF document with helpers for the equity analyst report.

    Provides a fixed page header plus helpers for section titles, body text,
    a bordered data table (with the header row repeated on each page), and
    titled images. All text is passed through sanitize_text before rendering.
    """

    # Layout constants, in the document's user units.
    _BOTTOM_MARGIN = 15
    _TABLE_ROW_HEIGHT = 8

    def __init__(self):
        super().__init__()
        # Use default Arial font
        self.set_font("Arial", "", 12)
        # Set auto page break
        self.set_auto_page_break(auto=True, margin=self._BOTTOM_MARGIN)
        # Set margins
        self.set_margins(left=15, top=20, right=15)

    def header(self):
        """Draw the centred report title at the top of the page."""
        self.set_font("Arial", "B", 16)
        self.cell(0, 10, "Equity Analyst Report", align="C", ln=True)
        self.ln(5)

    def chapter_title(self, title):
        """Write a bold, left-aligned section title on its own line."""
        self.set_font("Arial", "B", 14)
        title = sanitize_text(title)
        self.cell(0, 10, title, 0, 1, "L")
        self.ln(2)

    def chapter_body(self, body):
        """Write a block of body text that wraps across the full page width."""
        self.set_font("Arial", "", 12)
        body = sanitize_text(body)
        self.multi_cell(0, 10, body)
        self.ln()

    def table(self, data):
        """
        Render a DataFrame as a bordered table with equal-width columns.

        The header row is repeated at the top of every page the table spans.
        Null values are rendered as empty cells.

        Args:
            data (pd.DataFrame): The data to render. An empty frame produces a
                single "No data available to display." line instead.
        """
        if data.empty:
            self.set_font("Arial", "I", 12)
            self.cell(0, 10, "No data available to display.", 0, 1, 'C')
            self.ln()
            return

        # Look widths up by the ORIGINAL column label; the sanitized text is
        # only for display and may differ from the key (non-string labels,
        # labels with surrounding whitespace).
        col_widths = self.calculate_col_widths(data)
        widths = [col_widths[col] for col in data.columns]
        labels = [sanitize_text(col) for col in data.columns]

        # Avoid leaving a header row stranded at the bottom of a page with no
        # data row beneath it.
        if self._rows_that_fit() < 2:
            self.add_page()

        self._table_header(labels, widths)
        rows_left = self._rows_that_fit()
        # itertuples keeps each column's own dtype (iterrows would upcast
        # integers to floats when a row mixes numeric types).
        for values in data.itertuples(index=False, name=None):
            if rows_left <= 0:
                self.add_page()
                # Repeat table header on new page
                self._table_header(labels, widths)
                # A fresh page has more room than the partially filled first
                # one, so the capacity must be recomputed for every page.
                rows_left = self._rows_that_fit()
            for width, value in zip(widths, values):
                cell_text = sanitize_text(str(value)) if pd.notnull(value) else ""
                self.cell(width, self._TABLE_ROW_HEIGHT, cell_text, 1, 0, 'C')
            self.ln()
            rows_left -= 1
        self.ln()

    def _table_header(self, labels, widths):
        """Draw one bold header row, then switch back to the regular table font."""
        self.set_font("Arial", "B", 10)
        for label, width in zip(labels, widths):
            self.cell(width, self._TABLE_ROW_HEIGHT, label, 1, 0, 'C')
        self.ln()
        self.set_font("Arial", "", 10)

    def _rows_that_fit(self):
        """Return how many table rows fit between the cursor and the bottom margin."""
        return int((self.h - self.y - self._BOTTOM_MARGIN) / self._TABLE_ROW_HEIGHT)

    def calculate_col_widths(self, data):
        """
        Split the printable width evenly between the columns of *data*.

        Returns:
            dict: Column label -> width.
        """
        # 30 = left margin + right margin (15 each, see __init__).
        max_width = (self.w - 30) / len(data.columns)
        return {col: max_width for col in data.columns}

    def add_image(self, image_path, title, width=180):
        """
        Insert an image under a section title.

        If the file does not exist, a message is printed and nothing is drawn.

        Args:
            image_path (str): Path of the image file.
            title (str): Section title written above the image.
            width (float): Rendered image width.
        """
        if not os.path.exists(image_path):
            print(f"Image file {image_path} does not exist.")
            return
        self.chapter_title(title)
        self.image(image_path, w=width)
        self.ln(10)

# -------------------------------
# 8. Generate PDF Report
# -------------------------------

def generate_pdf_report(pdf, analysis_text, data_table, chart_paths, filename="equity_analyst_report.pdf"):
    """
    Assemble the report sections on *pdf* and write the document to *filename*.

    Sections are added in order: overview, analysis, visualizations (only when
    chart paths are given), and the data table. A failure while saving is
    printed rather than raised.

    Args:
        pdf (PDF): Report document that already has a page to draw on.
        analysis_text (str): Commentary placed in the "Analysis" section.
        data_table (pd.DataFrame): Data rendered in the final table section.
        chart_paths (list): Image file paths for the "Visualizations" section.
        filename (str): Output path of the PDF file.
    """
    # Overview section
    pdf.chapter_title("Selected Companies Overview")
    pdf.chapter_body(
        "This report provides an analysis of selected companies based on the user's query, including data on revenue, net income, and market capitalization."
    )

    # Analysis section
    pdf.chapter_title("Analysis")
    pdf.chapter_body(analysis_text)

    # Charts section: the title is picked from the first marker found in the
    # lower-cased path, falling back to a generic label.
    if chart_paths:
        pdf.chapter_title("Visualizations")
        titles_by_marker = (
            ('pie_chart', "Market Value Distribution Pie Chart"),
            ('bar_chart', "Market Value Comparison Bar Chart"),
        )
        for chart_path in chart_paths:
            lowered_path = chart_path.lower()
            chart_title = next(
                (label for marker, label in titles_by_marker if marker in lowered_path),
                "Chart",
            )
            pdf.add_image(chart_path, chart_title)

    # Data table section
    pdf.chapter_title("Company Financial Data")
    pdf.table(data_table)

    # Write the file
    try:
        pdf.output(filename)
        print(f"Report generated and saved as {filename}")
    except Exception as save_error:
        print(f"Error saving PDF: {save_error}")

# -------------------------------
# 9. Interactive Chat Functionality
# -------------------------------

def interactive_chat_duckdb(conn, tables_files):
    """
    Run the interactive question loop.

    For each question: generate SQL, execute it, print the result, generate an
    analysis and charts, and write a PDF report. The loop ends on 'exit',
    'quit', or end of input.

    Args:
        conn (duckdb.DuckDBPyConnection): An active DuckDB connection with the tables loaded.
        tables_files (dict): Table-to-file mapping. Not used by this function;
            kept so existing callers do not break.
    """
    print("\nStarting chat with the assistant. You can ask questions about the data.")
    print("Type 'exit' or 'quit' to end.\n")

    while True:
        try:
            user_input = input("You: ").strip()
        except EOFError:
            # stdin was closed (Ctrl+D, piped input exhausted): end like 'exit'.
            print("\nEnding chat. Goodbye!")
            break

        # An empty line carries no question; do not spend an API call on it.
        if not user_input:
            continue

        command = user_input.lower()
        if command in ("exit", "quit"):
            print("Ending chat. Goodbye!")
            break
        if command in ("help", "h"):
            print("\nYou can ask questions related to the data, such as:")
            print("- Which companies are ranked top 5 by market cap?")
            print("- Show financial metrics for company X.")
            print("- Compare market caps of companies in the technology sector.\n")
            continue

        # Generate SQL query based on user input
        sql_query = generate_sql(user_input)
        if not sql_query:
            print("Failed to generate SQL query. Please try another question.")
            continue

        # Execute SQL query
        query_result = execute_sql_duckdb(sql_query, conn)
        if query_result.empty:
            print("SQL query returned no data.")
            continue

        print("\nQuery Results:")
        print(query_result)

        # Generate analysis
        analysis_text = generate_analysis_from_openai(query_result, user_input)

        # Generate charts
        chart_paths = generate_charts(query_result)

        # Initialize PDF
        pdf_filename = "equity_analyst_report.pdf"
        pdf = PDF()
        pdf.add_page()

        # Generate PDF report with charts. generate_pdf_report prints the
        # outcome itself (success or save error), so no second, unconditional
        # success message is printed here.
        generate_pdf_report(pdf, analysis_text, query_result, chart_paths, filename=pdf_filename)

# -------------------------------
# 10. Generate Charts
# -------------------------------

def _save_chart(path, figsize, draw):
    """Draw one chart on a fresh figure, save it to *path*, and always close the figure."""
    plt.figure(figsize=figsize)
    try:
        draw()
        plt.savefig(path, bbox_inches='tight')
    finally:
        # Close even when drawing/saving fails so figures do not pile up in memory.
        plt.close()


def generate_charts(dataframe, output_dir="charts"):
    """
    Create market-value pie and bar charts for the top 5 tickers.

    Market values are summed per ticker and the five largest totals are
    plotted. Charts are only produced when the data has both 'mkt_val' and
    'ticker' columns. Any error during chart generation is printed and the
    charts saved so far are returned.

    Args:
        dataframe (pd.DataFrame): Query result to visualise.
        output_dir (str): Directory the PNG files are written to; created if missing.

    Returns:
        list: Paths of the chart images that were written (possibly empty).
    """
    if dataframe.empty:
        print("No data available to generate charts.")
        return []

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(output_dir, exist_ok=True)

    chart_paths = []

    # Check if necessary columns are present
    if not ('mkt_val' in dataframe.columns and 'ticker' in dataframe.columns):
        print("Data does not contain the required columns ('mkt_val', 'ticker') for generating charts.")
        return chart_paths

    try:
        # Aggregate data by 'ticker'
        aggregated_data = dataframe.groupby('ticker', as_index=False)['mkt_val'].sum()

        # Select top 5 companies by market cap
        top_companies = aggregated_data.nlargest(5, 'mkt_val')

        def draw_pie():
            plt.pie(top_companies['mkt_val'], labels=top_companies['ticker'], autopct='%1.1f%%', startangle=140)
            plt.title('Top 5 Companies Market Value Distribution')

        def draw_bar():
            plt.bar(top_companies['ticker'], top_companies['mkt_val'], color='skyblue')
            plt.xlabel('Ticker')
            plt.ylabel('Market Value (Millions)')
            plt.title('Market Value of Top 5 Companies')

        # --- Pie Chart ---
        pie_chart_path = os.path.join(output_dir, 'market_value_pie_chart.png')
        _save_chart(pie_chart_path, (6, 6), draw_pie)
        chart_paths.append(pie_chart_path)

        # --- Bar Chart ---
        bar_chart_path = os.path.join(output_dir, 'market_value_bar_chart.png')
        _save_chart(bar_chart_path, (8, 6), draw_bar)
        chart_paths.append(bar_chart_path)

        print(f"\nCharts have been generated and saved in the '{output_dir}' directory.")
    except Exception as e:
        print(f"Error generating charts: {e}")

    return chart_paths

# -------------------------------
# 11. Main Function Integrating All Features
# -------------------------------

def main():
    """
    Entry point: verify the data files exist, load them into DuckDB, and start the chat.

    Returns early (after printing a message) if any data file is missing or
    the data cannot be loaded. The DuckDB connection is closed on the way out.
    """
    # Define mapping of table names to file paths and types
    tables_files = {
        't_zacks_fc': {'path': 't_zacks_fc.parquet', 'type': 'parquet'},       # Replace with actual paths
        't_zacks_fr': {'path': 't_zacks_fr.parquet', 'type': 'parquet'},
        't_zacks_mktv': {'path': 't_zacks_mktv.parquet', 'type': 'parquet'},
        't_zacks_shrs': {'path': 't_zacks_shrs.parquet', 'type': 'parquet'},
        't_zacks_sectors': {'path': 't_zacks_sectors.csv', 'type': 'csv'}      # CSV file
    }

    # Check if all files exist
    missing_files = [info['path'] for info in tables_files.values() if not os.path.exists(info['path'])]
    if missing_files:
        print("The following files were not found. Please check the paths:")
        for path in missing_files:
            print(f" - {path}")
        return

    # Bind the name before the try block so the cleanup below can test it
    # directly instead of catching NameError.
    conn = None
    try:
        # Load data into DuckDB
        conn = load_data_duckdb(tables_files)
        if conn is None:
            print("Unable to load data into DuckDB. Exiting.")
            return

        # Start interactive chat
        interactive_chat_duckdb(conn, tables_files)

    except Exception as e:
        print(f"An error occurred: {e}")
    finally:
        # Ensure DuckDB connection is closed
        if conn is not None:
            conn.close()

# Script entry point: start the interactive session only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        # Ctrl+C ends the session with a short message instead of a traceback.
        print("\nProgram interrupted by user. Exiting.")
        sys.exit()


Registered table: t_zacks_fc -> t_zacks_fc.parquet (parquet)
Registered table: t_zacks_fr -> t_zacks_fr.parquet (parquet)
Registered table: t_zacks_mktv -> t_zacks_mktv.parquet (parquet)
Registered table: t_zacks_shrs -> t_zacks_shrs.parquet (parquet)
Registered table: t_zacks_sectors -> t_zacks_sectors.csv (csv)
All tables have been successfully registered to DuckDB.

Starting chat with the assistant. You can ask questions about the data.
Type 'exit' or 'quit' to end.


Generated SQL Query:
SELECT * FROM t_zacks_fc WHERE ticker = 'AAPL';

Cleaned SQL Query:
SELECT * FROM t_zacks_fc WHERE ticker = 'AAPL';

SQL query executed successfully.
Number of records retrieved: 94

Query Results:
   ticker comp_name exchange per_end_date per_type filing_date filing_type  \
0    AAPL     Apple   NASDAQ   2006-03-31        Q  2006-05-05        10-Q   
1    AAPL     Apple   NASDAQ   2006-06-30        Q  2006-12-29        10-Q   
2    AAPL     Apple   NASDAQ   2006-09-30        Q  2006-12-29        1