In [1]:
from flask import Flask, request, jsonify
from sqlalchemy import create_engine, text
import openai
from openai import OpenAI
import nest_asyncio
import threading

# Apply the nest_asyncio patch
nest_asyncio.apply()

app = Flask(__name__)

#Parameters
DB_USER = "XXX"
DB_PWD = "XXX"
DATABASE_URI = 'postgresql+psycopg2://' + DB_USER + ':' + DB_PWD + '@localhost/<db_name>'
OPEN_API_KEY="XXX"

# Set up DB Connection Engine
engine = create_engine(DATABASE_URI)

# Set up OpenAI API key
openai.api_key = OPEN_API_KEY

client = OpenAI(api_key=OPEN_API_KEY)


def generate_sql(natural_language_query):
   
    
    # Create prompt for OpenAI
    prompt = f"""
    You are a SQL query generator. Given the following table structures, all tables are in the financials schema so must be
    referenced with financials.<table_name>:

    Table: stock_price_history
    column_name             data_type                  Description
    "symbol"	            "character varying"        The stock symbold
    "close_date"	        "date"                     The date of the price, the latest date will be the latest price 
    "close_price"	        "numeric"                  The stock price

    Table: stock_quotes
    column_name             data_type                  Description
    "symbol"	            "character varying"        The stock symbold for the company
    "long_name"   	        "character varying"        The name for the company e.g. "Marks and Spencer Group plc", Apple Inc
    "financial_currency"	"character varying"        The financial currency of the companye.g. USD, GBP, normally based on the country of the company
    "market_cap"	        "numeric"                  The Market Capitalisation value for the company
    "exchange"              "character varying"        The code of the stock exchange where the stock is trades e.g. LSE, NYSE


     For the tables above feel free to join them if required. For company names if for Example the query is
     "Give me the latest prices and market cap for Barclays the query would be:-

     SELECT sq.symbol, sq.long_name, sq.market_cap,  pr.latest_price_date, pr.latest_price
     FROM financials.stock_quotes sq
     JOIN (SELECT sph.symbol, sph.close_date as latest_price_date, sph.close_price AS latest_price
	  FROM financials.stock_price_history sph
	  WHERE close_date = (SELECT MAX(close_date) FROM financials.stock_price_history ld WHERE sph.symbol = ld.symbol)
	 ) pr
	 ON pr.symbol = sq.symbol
     WHERE LOWER(sq.long_name) like '%barclays%'

    Convert this natural language query to SQL:
    Query: {natural_language_query}
    """
    
    response = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": prompt}]
    )
    
    sql_query = response.choices[0].message.content
    
    # Clean up the SQL query by removing code block markers
    sql_query = sql_query.replace("```sql", "").replace("```", "").strip()
    
    print("Cleaned SQL Query:", sql_query)
    return sql_query

@app.route("/generate_query")
def generate_and_execute_query():
    # Read the query parameter from the URL
    natural_language_query = request.args.get("query")
    
    if not natural_language_query:
        return "<pre>Error: No query provided</pre>", 400
    
    try:
        # Generate SQL query from natural language
        sql_query = generate_sql(natural_language_query)

        # Debugging line: print the generated SQL query
        print("Generated SQL Query:", sql_query)

        # Execute SQL query and fetch data
        with engine.connect() as connection:
            result = connection.execute(text(sql_query))  # Use text() to ensure raw SQL
            rows = [dict(row) for row in result.mappings()]
        
        if rows:
            # Extract column headers and calculate maximum column width
            columns = rows[0].keys()
            col_widths = {col: max(len(str(col)), *(len(str(row[col])) for row in rows)) for col in columns}

            # Create the header row
            header = " | ".join(f"{col.ljust(col_widths[col])}" for col in columns)
            separator = " | ".join("".ljust(col_widths[col], '_') for col in columns)

            # Create the rows
            table_rows = [header, separator]
            for row in rows:
                row_str = " | ".join(f"{str(row[col]).ljust(col_widths[col])}" for col in columns)
                table_rows.append(row_str)

            # Convert the rows to a text block
            text_table = "\n".join(table_rows)
        else:
            text_table = "No results found."

        # Return the text table as preformatted text
        return f"<pre>{text_table}</pre>"
    
    except Exception as e:
        # Handle errors (like SQL syntax errors or connection errors)
        print("Error occurred:", e)  # Print the error for debugging
        return f"<pre>Error: {str(e)}</pre>", 500



# Run the Flask app in a separate thread
def run_app():
    app.run(port=5000, debug=True, use_reloader=False)

# Start the Flask server in a thread to allow it to run in the background
thread = threading.Thread(target=run_app)
thread.start()


 * Serving Flask app '__main__'
 * Debug mode: on


 * Running on http://127.0.0.1:5000
Press CTRL+C to quit
127.0.0.1 - - [12/Nov/2024 19:54:05] "GET /generate_query?query=Add%20currency%20to%20that%20last%20request HTTP/1.1" 200 -


Cleaned SQL Query: SELECT sq.symbol, sq.long_name, sq.market_cap, sq.financial_currency, pr.latest_price_date, pr.latest_price
FROM financials.stock_quotes sq
JOIN (SELECT sph.symbol, sph.close_date as latest_price_date, sph.close_price AS latest_price
	FROM financials.stock_price_history sph
	WHERE close_date = (SELECT MAX(close_date) FROM financials.stock_price_history ld WHERE sph.symbol = ld.symbol)
) pr
ON pr.symbol = sq.symbol
WHERE LOWER(sq.long_name) like '%barclays%'
Generated SQL Query: SELECT sq.symbol, sq.long_name, sq.market_cap, sq.financial_currency, pr.latest_price_date, pr.latest_price
FROM financials.stock_quotes sq
JOIN (SELECT sph.symbol, sph.close_date as latest_price_date, sph.close_price AS latest_price
	FROM financials.stock_price_history sph
	WHERE close_date = (SELECT MAX(close_date) FROM financials.stock_price_history ld WHERE sph.symbol = ld.symbol)
) pr
ON pr.symbol = sq.symbol
WHERE LOWER(sq.long_name) like '%barclays%'


127.0.0.1 - - [12/Nov/2024 19:54:07] "GET /generate_query?query=Show%20me%20the%20symbol,%20company%20name,%20market%20cap,%20exchange%20code%20and%20latest%20prices%20for%20all%20Japanese%20Companies HTTP/1.1" 200 -


Cleaned SQL Query: SELECT sq.symbol, sq.long_name, sq.market_cap, sq.exchange, pr.latest_price_date, pr.latest_price
FROM financials.stock_quotes sq
JOIN (SELECT sph.symbol, sph.close_date AS latest_price_date, sph.close_price AS latest_price
      FROM financials.stock_price_history sph
      WHERE close_date = (SELECT MAX(close_date) FROM financials.stock_price_history ld WHERE sph.symbol = ld.symbol)
     ) pr
ON pr.symbol = sq.symbol
WHERE sq.financial_currency = 'JPY';
Generated SQL Query: SELECT sq.symbol, sq.long_name, sq.market_cap, sq.exchange, pr.latest_price_date, pr.latest_price
FROM financials.stock_quotes sq
JOIN (SELECT sph.symbol, sph.close_date AS latest_price_date, sph.close_price AS latest_price
      FROM financials.stock_price_history sph
      WHERE close_date = (SELECT MAX(close_date) FROM financials.stock_price_history ld WHERE sph.symbol = ld.symbol)
     ) pr
ON pr.symbol = sq.symbol
WHERE sq.financial_currency = 'JPY';


127.0.0.1 - - [12/Nov/2024 19:54:14] "GET /generate_query?query=Add%20currency%20to%20that%20last%20request HTTP/1.1" 200 -


Cleaned SQL Query: SELECT sq.symbol, sq.long_name, sq.market_cap, sq.financial_currency, pr.latest_price_date, pr.latest_price
FROM financials.stock_quotes sq
JOIN (SELECT sph.symbol, sph.close_date as latest_price_date, sph.close_price AS latest_price
	  FROM financials.stock_price_history sph
	  WHERE close_date = (SELECT MAX(close_date) FROM financials.stock_price_history ld WHERE sph.symbol = ld.symbol)
	 ) pr
ON pr.symbol = sq.symbol
WHERE LOWER(sq.long_name) like '%barclays%'
Generated SQL Query: SELECT sq.symbol, sq.long_name, sq.market_cap, sq.financial_currency, pr.latest_price_date, pr.latest_price
FROM financials.stock_quotes sq
JOIN (SELECT sph.symbol, sph.close_date as latest_price_date, sph.close_price AS latest_price
	  FROM financials.stock_price_history sph
	  WHERE close_date = (SELECT MAX(close_date) FROM financials.stock_price_history ld WHERE sph.symbol = ld.symbol)
	 ) pr
ON pr.symbol = sq.symbol
WHERE LOWER(sq.long_name) like '%barclays%'
