In [1]:
import sys
import os
import re
import openai as OpenAI
from typing import List, Tuple, Dict
import json
from dotenv import load_dotenv
from openai import OpenAI
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
import nltk
from nltk.stem import WordNetLemmatizer
from nltk.corpus import stopwords
from rapidfuzz import fuzz
os.chdir("..")
from chatgpt_api import chat_prompt
from chatgpt_api import api
print(os.getcwd())

nltk.download('punkt')
nltk.download('wordnet')
nltk.download('stopwords')


current_dir = os.path.dirname(os.path.abspath("/d/GithubRepos/PIMCO-Text2SQL"))
din_modules_path = os.path.join(current_dir, 'chatgpt_api')
sys.path.append(din_modules_path)

client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
if not client.api_key:
    raise ValueError("OpenAI API key not configured")


```
Task Description:
The task is to transform the natural language query into a SQL query for SQLite database.
This involves parsing the intent of the query and understanding the structure of the data to generate an appropriate SQL command.
```

```
Database Overview:
- The Database combines information from 30 tables of the NPORT dataset from quarter 4 of 2019 to quarter 3 of 2024.
- The data includes a comprehensive view of fund-level information, holdings, debt securities, repurchase agreements, and derivative instruments.
- Each relation represents detailed information about financial transactions, security holdings, and fund performance, including key identifiers like ACCESSION_NUMBER, HOLDING_ID, and CUSIP for borrowers, holdings, and securities.
- The table provides essential metrics like total assets, liabilities, interest rate risks, monthly returns, and details for securities lending and collateral.
- The table aggregates all the data to provide a holistic view of financial

DEBUG:httpx:load_ssl_context verify=True cert=None trust_env=True http2=False
DEBUG:httpx:load_verify_locations cafile='C:\\Users\\User\\anaconda3\\Library\\ssl\\cacert.pem'
INFO:chatgpt_api.api:OpenAI API key loaded successfully
DEBUG:chatgpt_api.api:Attempting to load schema from chatgpt_api/schema.json
INFO:chatgpt_api.api:Schema loaded successfully


Expected schema path: chatgpt_api/schema.json


[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\User\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package wordnet to
[nltk_data]     C:\Users\User\AppData\Roaming\nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package stopwords to
[nltk_data]     C:\Users\User\AppData\Roaming\nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\User\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package wordnet to
[nltk_data]     C:\Users\User\AppData\Roaming\nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package stopwords to
[nltk_data]     C:\Users\User\AppData\Roaming\nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
DEBUG:httpx:load_ssl_context verify=True cert=None trust_env=True http

d:\GithubRepos\PIMCO-Text2SQL


In [2]:
text = """
1. "List the top 5 registrants by total net assets, including their CIK and country."
   SQL: 
   WITH FundAssets AS (
       SELECT R.CIK, R.REGISTRANT_NAME, R.COUNTRY, F.NET_ASSETS
       FROM REGISTRANT R
       JOIN FUND_REPORTED_INFO F ON R.ACCESSION_NUMBER = F.ACCESSION_NUMBER
   )
   SELECT CIK, REGISTRANT_NAME, COUNTRY, NET_ASSETS
   FROM FundAssets
   ORDER BY NET_ASSETS DESC
   LIMIT 5;

2. "Find all holdings with a fair value level of Level 1 and their corresponding fund names."
   SQL: 
   WITH HoldingsCTE AS (
       SELECT H.HOLDING_ID, H.ISSUER_NAME, H.FAIR_VALUE_LEVEL, F.SERIES_NAME
       FROM FUND_REPORTED_HOLDING H
       JOIN FUND_REPORTED_INFO F ON H.ACCESSION_NUMBER = F.ACCESSION_NUMBER
       WHERE H.FAIR_VALUE_LEVEL = 'Level 1'
   )
   SELECT HOLDING_ID, ISSUER_NAME, SERIES_NAME
   FROM HoldingsCTE;

3. "Calculate the total collateral amount for repurchase agreements grouped by counterparty."
   SQL: 
   WITH CollateralCTE AS (
    SELECT RCP.NAME AS Counterparty_Name, SUM(RC.COLLATERAL_AMOUNT) AS Total_Collateral
    FROM REPURCHASE_COLLATERAL RC
    JOIN REPURCHASE_COUNTERPARTY RCP ON RC.HOLDING_ID = RCP.HOLDING_ID
    GROUP BY RCP.NAME
   )
   SELECT Counterparty_Name, Total_Collateral
   FROM CollateralCTE
   ORDER BY Total_Collateral DESC;

4. "Locate funds that have both securities lending activities and repurchase agreements."
   SQL: 
   WITH SecuritiesLending AS (
       SELECT ACCESSION_NUMBER
       FROM SECURITIES_LENDING
       WHERE IS_LOAN_BY_FUND = 'Y'
   ),
   RepurchaseAgreements AS (
       SELECT ACCESSION_NUMBER
       FROM REPURCHASE_AGREEMENT
   )
   SELECT F.SERIES_NAME
   FROM FUND_REPORTED_INFO F
   WHERE F.ACCESSION_NUMBER IN (SELECT ACCESSION_NUMBER FROM SecuritiesLending)
     AND F.ACCESSION_NUMBER IN (SELECT ACCESSION_NUMBER FROM RepurchaseAgreements);

5. "Find borrowers who have borrowed more than $5,000,000, including their names and LEIs."
   SQL: 
   WITH BorrowedAmounts AS (
       SELECT BORROWER_ID, SUM(AGGREGATE_VALUE) AS Total_Borrowed
       FROM BORROWER
       GROUP BY BORROWER_ID
       HAVING SUM(AGGREGATE_VALUE) > 5000000
   )
   SELECT B.NAME, B.LEI, BA.Total_Borrowed
   FROM BORROWER B
   JOIN BorrowedAmounts BA ON B.BORROWER_ID = BA.BORROWER_ID;

6. "List all derivative counterparties along with the number of derivative instruments they are involved in."
   SQL: 
   WITH CounterpartyCounts AS (
       SELECT DC.DERIVATIVE_COUNTERPARTY_NAME, COUNT(*) AS Instrument_Count
       FROM DERIVATIVE_COUNTERPARTY DC
       JOIN FUND_REPORTED_HOLDING H ON DC.HOLDING_ID = H.HOLDING_ID
       JOIN DEBT_SECURITY D ON H.HOLDING_ID = D.HOLDING_ID
       GROUP BY DC.DERIVATIVE_COUNTERPARTY_NAME
   )
   SELECT DERIVATIVE_COUNTERPARTY_NAME, Instrument_Count
   FROM CounterpartyCounts
   ORDER BY Instrument_Count DESC;

7. "Compute the average annualized rate for debt securities grouped by coupon type."
   SQL: 
   WITH RateAverages AS (
       SELECT DS.COUPON_TYPE, AVG(DS.ANNUALIZED_RATE) AS Avg_Annualized_Rate
       FROM DEBT_SECURITY DS
       WHERE DS.ANNUALIZED_RATE IS NOT NULL
       GROUP BY DS.COUPON_TYPE
   )
   SELECT COUPON_TYPE, Avg_Annualized_Rate
   FROM RateAverages
   ORDER BY Avg_Annualized_Rate DESC;

8. "Get funds that have experienced a net decrease in assets over the last three reporting periods."
   SQL: 
   WITH AssetChanges AS (
       SELECT F.ACCESSION_NUMBER, F.SERIES_NAME, S.REPORT_DATE, F.NET_ASSETS,
              LAG(F.NET_ASSETS, 1) OVER (PARTITION BY F.SERIES_NAME ORDER BY S.REPORT_DATE) AS Previous_Period_Assets
       FROM FUND_REPORTED_INFO F
       JOIN SUBMISSION S ON F.ACCESSION_NUMBER = S.ACCESSION_NUMBER
   )
   SELECT DISTINCT AC.SERIES_NAME
   FROM AssetChanges AC
   WHERE AC.NET_ASSETS < AC.Previous_Period_Assets
     AND AC.Previous_Period_Assets IS NOT NULL;

9. "Identify issuers with more than three different securities holdings, including their names and CUSIPs."
   SQL: 
   WITH IssuerHoldings AS (
       SELECT H.ISSUER_NAME, H.ISSUER_CUSIP, COUNT(DISTINCT H.HOLDING_ID) AS Holding_Count
       FROM FUND_REPORTED_HOLDING H
       GROUP BY H.ISSUER_NAME, H.ISSUER_CUSIP
       HAVING COUNT(DISTINCT H.HOLDING_ID) > 3
   )
   SELECT ISSUER_NAME, ISSUER_CUSIP, Holding_Count
   FROM IssuerHoldings
   ORDER BY Holding_Count DESC;

10. "Calculate the total notional amount of derivatives per currency and identify the top 3 currencies by notional amount."
    SQL: 
    WITH NotionalSums AS (
        SELECT ODNA.CURRENCY_CODE, SUM(ODNA.NOTIONAL_AMOUNT) AS Total_Notional
        FROM OTHER_DERIV_NOTIONAL_AMOUNT ODNA
        GROUP BY ODNA.CURRENCY_CODE
    )
    SELECT CURRENCY_CODE, Total_Notional
    FROM NotionalSums
    ORDER BY Total_Notional DESC
    LIMIT 3;

11. "List funds with liquidation preferences exceeding their net assets."
    SQL: 
    WITH FundPreferences AS (
        SELECT F.SERIES_NAME, F.LIQUIDATION_PREFERENCE, F.NET_ASSETS
        FROM FUND_REPORTED_INFO F
    )
    SELECT SERIES_NAME, LIQUIDATION_PREFERENCE, NET_ASSETS
    FROM FundPreferences
    WHERE LIQUIDATION_PREFERENCE > NET_ASSETS;

12. "Find all convertible securities that are contingent and have a conversion ratio above 1.5."
    SQL: 
    WITH ConvertibleCTE AS (
        SELECT DS.HOLDING_ID, CSC.CONVERSION_RATIO
        FROM DEBT_SECURITY DS
        JOIN CONVERTIBLE_SECURITY_CURRENCY CSC ON DS.HOLDING_ID = CSC.HOLDING_ID
        WHERE DS.IS_CONVTIBLE_CONTINGENT = 'Y' AND CSC.CONVERSION_RATIO > 1.5
    )
    SELECT HOLDING_ID, CONVERSION_RATIO
    FROM ConvertibleCTE;

13. "Analyze the distribution of asset categories within the top 10 largest funds by total assets."
    SQL: 
    WITH TopFunds AS (
        SELECT SERIES_NAME, ACCESSION_NUMBER
        FROM FUND_REPORTED_INFO
        ORDER BY TOTAL_ASSETS DESC
        LIMIT 10
    ),
    AssetDistribution AS (
        SELECT H.ASSET_CAT, COUNT(*) AS Category_Count
        FROM FUND_REPORTED_HOLDING H
        JOIN TopFunds T ON H.ACCESSION_NUMBER = T.ACCESSION_NUMBER
        GROUP BY H.ASSET_CAT
    )
    SELECT ASSET_CAT, Category_Count
    FROM AssetDistribution
    ORDER BY Category_Count DESC;
    
14. "Find the top 10 funds with the highest average monthly returns in the past quarter."
   SQL: 
   WITH AvgMonthlyReturn AS (
       SELECT ACCESSION_NUMBER, 
              (MONTHLY_TOTAL_RETURN1 + MONTHLY_TOTAL_RETURN2 + MONTHLY_TOTAL_RETURN3) / 3.0 AS Avg_Return
       FROM MONTHLY_TOTAL_RETURN
   )
   SELECT F.SERIES_NAME, A.ACCESSION_NUMBER, A.Avg_Return
   FROM AvgMonthlyReturn A
   JOIN FUND_REPORTED_INFO F ON A.ACCESSION_NUMBER = F.ACCESSION_NUMBER
   ORDER BY A.Avg_Return DESC
   LIMIT 10;

15. "Compare the latest net asset values of the top 5 performing funds."
    SQL: 
    WITH TopPerformingFunds AS (
        SELECT 
            ACCESSION_NUMBER, 
            (MONTHLY_TOTAL_RETURN1 + MONTHLY_TOTAL_RETURN2 + MONTHLY_TOTAL_RETURN3) / 3.0 AS Avg_Return
        FROM 
            MONTHLY_TOTAL_RETURN
        ORDER BY 
            Avg_Return DESC
        LIMIT 5
    )
    SELECT 
        FR.SERIES_NAME, 
        FR.NET_ASSETS, 
        TP.Avg_Return
    FROM 
        TopPerformingFunds TP
    JOIN 
        FUND_REPORTED_INFO FR ON TP.ACCESSION_NUMBER = FR.ACCESSION_NUMBER;

16. "Calculate the overall average return across all funds for the most recent month."
    SQL: 
    WITH LatestReturns AS (
        SELECT 
            M.ACCESSION_NUMBER, 
            M.MONTHLY_TOTAL_RETURN1
        FROM 
            MONTHLY_TOTAL_RETURN M
        JOIN 
            SUBMISSION S ON M.ACCESSION_NUMBER = S.ACCESSION_NUMBER
        WHERE 
            S.REPORT_DATE = (SELECT MAX(REPORT_DATE) FROM SUBMISSION)
    )
    SELECT 
        AVG(MONTHLY_TOTAL_RETURN1) AS Average_Return
    FROM 
        LatestReturns;

17. "Find the interest rate risk for each fund and identify those with the highest risk scores."
    SQL: 
    WITH InterestRiskScores AS (
        SELECT 
            IR.ACCESSION_NUMBER, 
            -- Calculating composite risk score by summing absolute values of DV01 and DV100 columns
            (ABS(CAST(IR.INTRST_RATE_CHANGE_3MON_DV01 AS FLOAT)) +
            ABS(CAST(IR.INTRST_RATE_CHANGE_1YR_DV01 AS FLOAT)) +
            ABS(CAST(IR.INTRST_RATE_CHANGE_5YR_DV01 AS FLOAT)) +
            ABS(CAST(IR.INTRST_RATE_CHANGE_10YR_DV01 AS FLOAT)) +
            ABS(CAST(IR.INTRST_RATE_CHANGE_30YR_DV01 AS FLOAT)) +
            ABS(CAST(IR.INTRST_RATE_CHANGE_3MON_DV100 AS FLOAT)) +
            ABS(CAST(IR.INTRST_RATE_CHANGE_1YR_DV100 AS FLOAT)) +
            ABS(CAST(IR.INTRST_RATE_CHANGE_5YR_DV100 AS FLOAT)) +
            ABS(CAST(IR.INTRST_RATE_CHANGE_10YR_DV100 AS FLOAT)) +
            ABS(CAST(IR.INTRST_RATE_CHANGE_30YR_DV100 AS FLOAT))
            ) AS Composite_Risk_Score
        FROM 
            INTEREST_RATE_RISK IR
    )
    SELECT 
        FR.SERIES_NAME, 
        FR.ACCESSION_NUMBER, 
        IRS.Composite_Risk_Score
    FROM 
        InterestRiskScores IRS
    JOIN 
        FUND_REPORTED_INFO FR ON IRS.ACCESSION_NUMBER = FR.ACCESSION_NUMBER
    ORDER BY 
        IRS.Composite_Risk_Score DESC
    LIMIT 5;

18. "Analyze the composition of fund portfolios by categorizing assets and their total values."
    SQL: 
    WITH PortfolioComposition AS (
    SELECT 
        ACCESSION_NUMBER, 
        ASSET_CAT, 
        SUM(CAST(CURRENCY_VALUE AS FLOAT)) AS Total_Value
    FROM 
        FUND_REPORTED_HOLDING
    GROUP BY 
        ACCESSION_NUMBER, 
        ASSET_CAT
    )
    SELECT 
        F.SERIES_NAME, 
        PC.ASSET_CAT, 
        PC.Total_Value
    FROM 
        PortfolioComposition PC
    JOIN 
        FUND_REPORTED_INFO F ON PC.ACCESSION_NUMBER = F.ACCESSION_NUMBER
    ORDER BY 
        F.SERIES_NAME, 
        PC.Total_Value DESC;

19. "Identify the most common asset categories across all fund portfolios."
    SQL: 
    WITH AssetCounts AS (
        SELECT ASSET_CAT, COUNT(*) AS Count
        FROM FUND_REPORTED_HOLDING
        GROUP BY ASSET_CAT
    )
    SELECT ASSET_CAT, Count
    FROM AssetCounts
    ORDER BY Count DESC
    LIMIT 5;

20. "Retrieve funds that have experienced a net decrease in assets over the last three reporting periods."
   SQL: 
   WITH AssetChanges AS (
       SELECT F.ACCESSION_NUMBER, F.SERIES_NAME, S.REPORT_DATE, F.NET_ASSETS,
              LAG(F.NET_ASSETS, 1) OVER (PARTITION BY F.SERIES_NAME ORDER BY S.REPORT_DATE) AS Previous_Period_Assets
       FROM FUND_REPORTED_INFO F
       JOIN SUBMISSION S ON F.ACCESSION_NUMBER = S.ACCESSION_NUMBER
   )
   SELECT DISTINCT AC.SERIES_NAME
   FROM AssetChanges AC
   WHERE AC.NET_ASSETS < AC.Previous_Period_Assets
     AND AC.Previous_Period_Assets IS NOT NULL;"""


pattern = r'"\s*(.*?)\s*"\s*SQL:\s*(WITH.*?;)(?=\n\s*\d+|$)'
matches = re.findall(pattern, text, re.DOTALL)

ground_truth_query = [match[1] for match in matches]
llm_query = [match[0] for match in matches]
print("Queries:", ground_truth_query)
print("SQL Statements:", llm_query)

Queries: ['WITH FundAssets AS (\n       SELECT R.CIK, R.REGISTRANT_NAME, R.COUNTRY, F.NET_ASSETS\n       FROM REGISTRANT R\n       JOIN FUND_REPORTED_INFO F ON R.ACCESSION_NUMBER = F.ACCESSION_NUMBER\n   )\n   SELECT CIK, REGISTRANT_NAME, COUNTRY, NET_ASSETS\n   FROM FundAssets\n   ORDER BY NET_ASSETS DESC\n   LIMIT 5;', "WITH HoldingsCTE AS (\n       SELECT H.HOLDING_ID, H.ISSUER_NAME, H.FAIR_VALUE_LEVEL, F.SERIES_NAME\n       FROM FUND_REPORTED_HOLDING H\n       JOIN FUND_REPORTED_INFO F ON H.ACCESSION_NUMBER = F.ACCESSION_NUMBER\n       WHERE H.FAIR_VALUE_LEVEL = 'Level 1'\n   )\n   SELECT HOLDING_ID, ISSUER_NAME, SERIES_NAME\n   FROM HoldingsCTE;", 'WITH CollateralCTE AS (\n    SELECT RCP.NAME AS Counterparty_Name, SUM(RC.COLLATERAL_AMOUNT) AS Total_Collateral\n    FROM REPURCHASE_COLLATERAL RC\n    JOIN REPURCHASE_COUNTERPARTY RCP ON RC.HOLDING_ID = RCP.HOLDING_ID\n    GROUP BY RCP.NAME\n   )\n   SELECT Counterparty_Name, Total_Collateral\n   FROM CollateralCTE\n   ORDER BY Total_

In [3]:
SCHEMA_FILE = 'chatgpt_api/schema.json'
print(f"Expected schema path: {SCHEMA_FILE}")  


def format_schema_for_gpt(schema):
    if not schema:
        return "No schema available"
        
    formatted_schema = []
    tables = schema.get('schema', {}).get('tables', [])
    
    for table in tables:
        table_name = table.get('name')
        formatted_schema.append(f"\nTable: {table_name}")
        formatted_schema.append("Columns:")
        for column in table.get('columns', []):
            col_name = column.get('name')
            col_type = column.get('type')
            formatted_schema.append(f"- {col_name} ({col_type})")
    
    return "\n".join(formatted_schema)

try:
    db_schema = api.load_schema_from_json(SCHEMA_FILE)
except Exception as e:
    db_schema = None

schema_info = format_schema_for_gpt(db_schema)
print(schema_info)


DEBUG:chatgpt_api.api:Attempting to load schema from chatgpt_api/schema.json
INFO:chatgpt_api.api:Schema loaded successfully


Expected schema path: chatgpt_api/schema.json

Table: REGISTRANT
Columns:
- ACCESSION_NUMBER (TEXT)
- CIK (TEXT)
- REGISTRANT_NAME (TEXT)
- FILE_NUM (TEXT)
- LEI (TEXT)
- ADDRESS1 (TEXT)
- ADDRESS2 (TEXT)
- CITY (TEXT)
- STATE (TEXT)
- COUNTRY (TEXT)
- ZIP (TEXT)
- PHONE (TEXT)
- QUARTER (TEXT)

Table: FUND_REPORTED_INFO
Columns:
- ACCESSION_NUMBER (TEXT)
- SERIES_NAME (TEXT)
- SERIES_ID (TEXT)
- SERIES_LEI (TEXT)
- TOTAL_ASSETS (TEXT)
- TOTAL_LIABILITIES (TEXT)
- NET_ASSETS (TEXT)
- ASSETS_ATTRBT_TO_MISC_SECURITY (TEXT)
- ASSETS_INVESTED (TEXT)
- BORROWING_PAY_WITHIN_1YR (TEXT)
- CTRLD_COMPANIES_PAY_WITHIN_1YR (TEXT)
- OTHER_AFFILIA_PAY_WITHIN_1YR (TEXT)
- OTHER_PAY_WITHIN_1YR (TEXT)
- BORROWING_PAY_AFTER_1YR (TEXT)
- CTRLD_COMPANIES_PAY_AFTER_1YR (TEXT)
- OTHER_AFFILIA_PAY_AFTER_1YR (TEXT)
- OTHER_PAY_AFTER_1YR (TEXT)
- DELAYED_DELIVERY (TEXT)
- STANDBY_COMMITMENT (TEXT)
- LIQUIDATION_PREFERENCE (TEXT)
- CASH_NOT_RPTD_IN_C_OR_D (TEXT)
- CREDIT_SPREAD_3MON_INVEST (TEXT)
- CREDIT_SPREA

In [4]:
primary_keys = {
            'SUBMISSION': ['ACCESSION_NUMBER'],
            'REGISTRANT': ['ACCESSION_NUMBER'],
            'FUND_REPORTED_INFO': ['ACCESSION_NUMBER'],
            'INTEREST_RATE_RISK': ['ACCESSION_NUMBER', 'INTEREST_RATE_RISK_ID'],
            'BORROWER': ['ACCESSION_NUMBER', 'BORROWER_ID'],
            'BORROW_AGGREGATE': ['ACCESSION_NUMBER', 'BORROW_AGGREGATE_ID'],
            'MONTHLY_TOTAL_RETURN': ['ACCESSION_NUMBER', 'MONTHLY_TOTAL_RETURN_ID'],
            'MONTHLY_RETURN_CAT_INSTRUMENT': ['ACCESSION_NUMBER', 'ASSET_CAT', 'INSTRUMENT_KIND'],
            'FUND_VAR_INFO': ['ACCESSION_NUMBER'],
            'FUND_REPORTED_HOLDING': ['ACCESSION_NUMBER', 'HOLDING_ID'],
            'IDENTIFIERS': ['HOLDING_ID', 'IDENTIFIERS_ID'],
            'DEBT_SECURITY': [],  
            'DEBT_SECURITY_REF_INSTRUMENT': ['HOLDING_ID', 'DEBT_SECURITY_REF_ID'],
            'CONVERTIBLE_SECURITY_CURRENCY': ['HOLDING_ID', 'CONVERTIBLE_SECURITY_ID'],
            'REPURCHASE_AGREEMENT': ['HOLDING_ID'],
            'REPURCHASE_COUNTERPARTY': ['HOLDING_ID', 'REPURCHASE_COUNTERPARTY_ID'],
            'REPURCHASE_COLLATERAL': ['HOLDING_ID', 'REPURCHASE_COLLATERAL_ID'],
            'DERIVATIVE_COUNTERPARTY': ['HOLDING_ID', 'DERIVATIVE_COUNTERPARTY_ID'],
            'SWAPTION_OPTION_WARNT_DERIV': ['HOLDING_ID'],
            'DESC_REF_INDEX_BASKET': ['HOLDING_ID'],
            'DESC_REF_INDEX_COMPONENT': ['HOLDING_ID', 'DESC_REF_INDEX_COMPONENT_ID'],
            'DESC_REF_OTHER': ['HOLDING_ID', 'DESC_REF_OTHER_ID'],
            'FUT_FWD_NONFOREIGNCUR_CONTRACT': ['HOLDING_ID'],
            'FWD_FOREIGNCUR_CONTRACT_SWAP': ['HOLDING_ID'],
            'NONFOREIGN_EXCHANGE_SWAP': ['HOLDING_ID'],
            'FLOATING_RATE_RESET_TENOR': ['HOLDING_ID', 'RATE_RESET_TENOR_ID'],
            'OTHER_DERIV': ['HOLDING_ID'],
            'OTHER_DERIV_NOTIONAL_AMOUNT': ['HOLDING_ID', 'OTHER_DERIV_NOTIONAL_AMOUNT_ID'],
            'SECURITIES_LENDING': ['HOLDING_ID'],
            'EXPLANATORY_NOTE': ['ACCESSION_NUMBER', 'EXPLANATORY_NOTE_ID']
        }

foreign_keys = [
            # ACCESSION_NUMBER relationships
            'REGISTRANT.ACCESSION_NUMBER = FUND_REPORTED_INFO.ACCESSION_NUMBER',
            'INTEREST_RATE_RISK.ACCESSION_NUMBER = FUND_REPORTED_INFO.ACCESSION_NUMBER',
            'BORROWER.ACCESSION_NUMBER = FUND_REPORTED_INFO.ACCESSION_NUMBER',
            'BORROW_AGGREGATE.ACCESSION_NUMBER = FUND_REPORTED_INFO.ACCESSION_NUMBER',
            'MONTHLY_TOTAL_RETURN.ACCESSION_NUMBER = FUND_REPORTED_INFO.ACCESSION_NUMBER',
            'MONTHLY_RETURN_CAT_INSTRUMENT.ACCESSION_NUMBER = FUND_REPORTED_INFO.ACCESSION_NUMBER',
            'FUND_VAR_INFO.ACCESSION_NUMBER = FUND_REPORTED_INFO.ACCESSION_NUMBER',
            'FUND_REPORTED_HOLDING.ACCESSION_NUMBER = FUND_REPORTED_INFO.ACCESSION_NUMBER',
            'EXPLANATORY_NOTE.ACCESSION_NUMBER = FUND_REPORTED_INFO.ACCESSION_NUMBER',
            'SUBMISSION.ACCESSION_NUMBER = FUND_REPORTED_INFO.ACCESSION_NUMBER',

            # HOLDING_ID relationships
            'IDENTIFIERS.HOLDING_ID = FUND_REPORTED_HOLDING.HOLDING_ID',
            'DEBT_SECURITY.HOLDING_ID = FUND_REPORTED_HOLDING.HOLDING_ID',
            'DEBT_SECURITY_REF_INSTRUMENT.HOLDING_ID = FUND_REPORTED_HOLDING.HOLDING_ID',
            'CONVERTIBLE_SECURITY_CURRENCY.HOLDING_ID = FUND_REPORTED_HOLDING.HOLDING_ID',
            'REPURCHASE_AGREEMENT.HOLDING_ID = FUND_REPORTED_HOLDING.HOLDING_ID',
            'REPURCHASE_COUNTERPARTY.HOLDING_ID = FUND_REPORTED_HOLDING.HOLDING_ID',
            'REPURCHASE_COLLATERAL.HOLDING_ID = FUND_REPORTED_HOLDING.HOLDING_ID',
            'DERIVATIVE_COUNTERPARTY.HOLDING_ID = FUND_REPORTED_HOLDING.HOLDING_ID',
            'SWAPTION_OPTION_WARNT_DERIV.HOLDING_ID = FUND_REPORTED_HOLDING.HOLDING_ID',
            'DESC_REF_INDEX_BASKET.HOLDING_ID = FUND_REPORTED_HOLDING.HOLDING_ID',
            'DESC_REF_INDEX_COMPONENT.HOLDING_ID = FUND_REPORTED_HOLDING.HOLDING_ID',
            'DESC_REF_OTHER.HOLDING_ID = FUND_REPORTED_HOLDING.HOLDING_ID',
            'FUT_FWD_NONFOREIGNCUR_CONTRACT.HOLDING_ID = FUND_REPORTED_HOLDING.HOLDING_ID',
            'FWD_FOREIGNCUR_CONTRACT_SWAP.HOLDING_ID = FUND_REPORTED_HOLDING.HOLDING_ID',
            'NONFOREIGN_EXCHANGE_SWAP.HOLDING_ID = FUND_REPORTED_HOLDING.HOLDING_ID',
            'FLOATING_RATE_RESET_TENOR.HOLDING_ID = FUND_REPORTED_HOLDING.HOLDING_ID',
            'OTHER_DERIV.HOLDING_ID = FUND_REPORTED_HOLDING.HOLDING_ID',
            'OTHER_DERIV_NOTIONAL_AMOUNT.HOLDING_ID = FUND_REPORTED_HOLDING.HOLDING_ID',
            'SECURITIES_LENDING.HOLDING_ID = FUND_REPORTED_HOLDING.HOLDING_ID'
        ]

In [5]:
def explore_keys():
    """Explore potential primary and foreign keys in the database"""
    import sqlite3
    
    # Connect to database
    conn = sqlite3.connect('sqlite/nport.db')
    cursor = conn.cursor()

    # Get all tables
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
    tables = cursor.fetchall()

    print("Database Key Analysis:")
    print("-" * 80)

    # Analyze each table
    for table in tables:
        table_name = table[0]
        print(f"\nTable: {table_name}")

        # Get column info
        cursor.execute(f"PRAGMA table_info({table_name})")
        columns = cursor.fetchall()
        
        # Get sample count for potential key columns
        for col in columns:
            col_name = col[1]
            # Check if column name contains potential key indicators
            if any(key_term in col_name.lower() for key_term in ['_id', 'accession', 'number']):
                cursor.execute(f"""
                    SELECT COUNT(*) total_rows, 
                           COUNT(DISTINCT {col_name}) unique_values 
                    FROM {table_name}
                    WHERE {col_name} IS NOT NULL
                """)
                stats = cursor.fetchone()
                print(f"Column: {col_name}")
                print(f"Total rows: {stats[0]}")
                print(f"Unique values: {stats[1]}")
                
                # If unique values equals total rows, likely a key
                if stats[0] == stats[1] and stats[0] > 0:
                    print(">>> Potential PRIMARY KEY <<<")

        # Look for foreign key relationships
        for col in columns:
            col_name = col[1]
            if col_name == 'ACCESSION_NUMBER':
                cursor.execute(f"""
                    SELECT COUNT(*) FROM {table_name} t1
                    WHERE EXISTS (
                        SELECT 1 FROM FUND_REPORTED_INFO t2 
                        WHERE t1.ACCESSION_NUMBER = t2.ACCESSION_NUMBER
                    )
                """)
                match_count = cursor.fetchone()[0]
                if match_count > 0:
                    print(f"Foreign Key: {table_name}.ACCESSION_NUMBER -> FUND_REPORTED_INFO.ACCESSION_NUMBER")
            
            elif col_name == 'HOLDING_ID':
                cursor.execute(f"""
                    SELECT COUNT(*) FROM {table_name} t1
                    WHERE EXISTS (
                        SELECT 1 FROM FUND_REPORTED_HOLDING t2 
                        WHERE t1.HOLDING_ID = t2.HOLDING_ID
                    )
                """)
                match_count = cursor.fetchone()[0]
                if match_count > 0:
                    print(f"Foreign Key: {table_name}.HOLDING_ID -> FUND_REPORTED_HOLDING.HOLDING_ID")

    conn.close()

# Run the analysis
explore_keys()

Database Key Analysis:
--------------------------------------------------------------------------------

Table: REGISTRANT
Column: ACCESSION_NUMBER
Total rows: 2822
Unique values: 2822
>>> Potential PRIMARY KEY <<<
Foreign Key: REGISTRANT.ACCESSION_NUMBER -> FUND_REPORTED_INFO.ACCESSION_NUMBER

Table: FUND_REPORTED_INFO
Column: ACCESSION_NUMBER
Total rows: 2822
Unique values: 2822
>>> Potential PRIMARY KEY <<<
Column: SERIES_ID
Total rows: 2822
Unique values: 2643
Foreign Key: FUND_REPORTED_INFO.ACCESSION_NUMBER -> FUND_REPORTED_INFO.ACCESSION_NUMBER

Table: INTEREST_RATE_RISK
Column: ACCESSION_NUMBER
Total rows: 4629
Unique values: 1460
Column: INTEREST_RATE_RISK_ID
Total rows: 4629
Unique values: 4629
>>> Potential PRIMARY KEY <<<
Foreign Key: INTEREST_RATE_RISK.ACCESSION_NUMBER -> FUND_REPORTED_INFO.ACCESSION_NUMBER

Table: BORROWER
Column: ACCESSION_NUMBER
Total rows: 12685
Unique values: 1189
Column: BORROWER_ID
Total rows: 12685
Unique values: 12685
>>> Potential PRIMARY KEY <<<


In [None]:
############################################ VALUE RETRIEVAL AND SCHEMA LINKING
class PSLsh:
    def __init__(self, vectors, n_planes=10, n_tables=5, seed: int = 42):
        self.n_planes = n_planes
        self.n_tables = n_tables
        self.hash_tables = [{} for _ in range(n_tables)]
        self.random_planes = []
        
        np.random.seed(seed)
        
        for _ in range(n_tables):
            planes = np.random.randn(vectors.shape[1], n_planes)
            self.random_planes.append(planes)
            
        self.num_vectors = vectors.shape[0]
        self.vectors = vectors
        self.build_hash_tables()

    def build_hash_tables(self):
        for idx in range(self.num_vectors):
            vector = self.vectors[idx].toarray()[0]
            hashes = self.hash_vector(vector)
            for i, h in enumerate(hashes):
                if h not in self.hash_tables[i]:
                    self.hash_tables[i][h] = []
                self.hash_tables[i][h].append(idx)

    def hash_vector(self, vector):
        hashes = []
        for planes in self.random_planes:
            projections = np.dot(vector, planes)
            hash_code = ''.join(['1' if x > 0 else '0' for x in projections])
            hashes.append(hash_code)
        return hashes

    def query(self, vector):
        hashes = self.hash_vector(vector)
        candidates = set()
        for i, h in enumerate(hashes):
            candidates.update(self.hash_tables[i].get(h, []))
        return candidates


class ValueRetrieval:
    financial_terms = {
            'total': ['total', 'sum', 'aggregate', 'combined'],
            'assets': ['asset', 'holdings', 'investments', 'securities'],
            'liabilities': ['liability', 'debt', 'obligations'],
            'net': ['net', 'pure', 'adjusted'],
            'fund': ['fund', 'portfolio', 'investment vehicle'],
            'return': ['return', 'yield', 'profit', 'gain'],
            'monthly': ['monthly', 'month', 'monthly basis'],
            'rate': ['rate', 'percentage', 'ratio'],
            'risk': ['risk', 'exposure', 'vulnerability']
        }
    
    def __init__(self, schema_path: str = 'chatgpt_api/schema.json', lsh_seed: int = 42):
        load_dotenv()
        self.client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))

        # Load schema
        print("DEBUG: Loading schema file:", schema_path)
        with open(schema_path, 'r') as f:
            self.schema = json.load(f)

        # Initialize lemmatizer and stop words
        self.lemmatizer = WordNetLemmatizer()
        self.stop_words = set(stopwords.words('english'))
        
        # Build column name index
        self.column_index = self._build_column_index()        

        # Build vectorizer and LSH for backup matching
        self.build_vectorizer_and_lsh(seed=lsh_seed)
        
        # Get schema relationships
        self.primary_keys, self.foreign_keys = self.discover_schema_relationships()

    def _build_column_index(self) -> Dict:
        column_index = {}
        tables = self.schema.get('schema', {}).get('tables', [])
        
        for table in tables:
            table_name = table.get('name', '').lower()
            for column in table.get('columns', []):
                column_name = column.get('name', '').lower()
                
                # Store the full qualified name and column properties
                qualified_name = f"{table_name}.{column_name}"
                column_index[qualified_name] = {
                    'table': table_name,
                    'column': column_name,
                    'type': column.get('type', ''),
                    'words': self._split_column_name(column_name),
                    'synonyms': self._get_column_synonyms(column_name)
                }
                
        return column_index

    def _split_column_name(self, column_name: str) -> List[str]:
        """Split column name into individual words."""
        # Handle  underscore + camel case.
        words = re.sub('([A-Z][a-z]+)', r' \1', re.sub('([A-Z]+)', r' \1', column_name)).split()
        words.extend(column_name.split('_'))
        return [word.lower() for word in words if word]

    def _get_column_synonyms(self, column_name: str) -> List[str]:
        """Get synonyms for words in column name."""
        words = self._split_column_name(column_name)
        synonyms = []
        
        for word in words:
            if word in self.financial_terms:
                synonyms.extend(self.financial_terms[word])
                
        return list(set(synonyms))

    def build_vectorizer_and_lsh(self, seed: int):
        self.vectorizer = TfidfVectorizer(analyzer='char', ngram_range=(1, 3), min_df=1, max_df=0.95)
        self.term_list = self.get_schema_terms()
        self.term_vectors = self.vectorizer.fit_transform(self.term_list)
        self.lsh = PSLsh(self.term_vectors, n_planes=10, n_tables=5)

    def get_schema_terms(self) -> List[str]:
        terms = []
        tables = self.schema.get('schema', {}).get('tables', [])
        for table in tables:
            table_name = table.get('name', '').lower()
            terms.append(table_name)
            for column in table.get('columns', []):
                column_name = column.get('name', '').lower()
                terms.append(f"{table_name}.{column_name}")
        return terms

    def discover_schema_relationships(self):
        # Define our primary keys and foreign keys here
        primary_keys = {
            'SUBMISSION': ['ACCESSION_NUMBER'],
            'REGISTRANT': ['ACCESSION_NUMBER'],
            'FUND_REPORTED_INFO': ['ACCESSION_NUMBER'],
            'INTEREST_RATE_RISK': ['ACCESSION_NUMBER', 'INTEREST_RATE_RISK_ID'],
            'BORROWER': ['ACCESSION_NUMBER', 'BORROWER_ID'],
            'BORROW_AGGREGATE': ['ACCESSION_NUMBER', 'BORROW_AGGREGATE_ID'],
            'MONTHLY_TOTAL_RETURN': ['ACCESSION_NUMBER', 'MONTHLY_TOTAL_RETURN_ID'],
            'MONTHLY_RETURN_CAT_INSTRUMENT': ['ACCESSION_NUMBER', 'ASSET_CAT', 'INSTRUMENT_KIND'],
            'FUND_VAR_INFO': ['ACCESSION_NUMBER'],
            'FUND_REPORTED_HOLDING': ['ACCESSION_NUMBER', 'HOLDING_ID'],
            'IDENTIFIERS': ['HOLDING_ID', 'IDENTIFIERS_ID'],
            'DEBT_SECURITY': [],  
            'DEBT_SECURITY_REF_INSTRUMENT': ['HOLDING_ID', 'DEBT_SECURITY_REF_ID'],
            'CONVERTIBLE_SECURITY_CURRENCY': ['HOLDING_ID', 'CONVERTIBLE_SECURITY_ID'],
            'REPURCHASE_AGREEMENT': ['HOLDING_ID'],
            'REPURCHASE_COUNTERPARTY': ['HOLDING_ID', 'REPURCHASE_COUNTERPARTY_ID'],
            'REPURCHASE_COLLATERAL': ['HOLDING_ID', 'REPURCHASE_COLLATERAL_ID'],
            'DERIVATIVE_COUNTERPARTY': ['HOLDING_ID', 'DERIVATIVE_COUNTERPARTY_ID'],
            'SWAPTION_OPTION_WARNT_DERIV': ['HOLDING_ID'],
            'DESC_REF_INDEX_BASKET': ['HOLDING_ID'],
            'DESC_REF_INDEX_COMPONENT': ['HOLDING_ID', 'DESC_REF_INDEX_COMPONENT_ID'],
            'DESC_REF_OTHER': ['HOLDING_ID', 'DESC_REF_OTHER_ID'],
            'FUT_FWD_NONFOREIGNCUR_CONTRACT': ['HOLDING_ID'],
            'FWD_FOREIGNCUR_CONTRACT_SWAP': ['HOLDING_ID'],
            'NONFOREIGN_EXCHANGE_SWAP': ['HOLDING_ID'],
            'FLOATING_RATE_RESET_TENOR': ['HOLDING_ID', 'RATE_RESET_TENOR_ID'],
            'OTHER_DERIV': ['HOLDING_ID'],
            'OTHER_DERIV_NOTIONAL_AMOUNT': ['HOLDING_ID', 'OTHER_DERIV_NOTIONAL_AMOUNT_ID'],
            'SECURITIES_LENDING': ['HOLDING_ID'],
            'EXPLANATORY_NOTE': ['ACCESSION_NUMBER', 'EXPLANATORY_NOTE_ID']
        }

        foreign_keys = [
            # ACCESSION_NUMBER relationships
            'REGISTRANT.ACCESSION_NUMBER = FUND_REPORTED_INFO.ACCESSION_NUMBER',
            'INTEREST_RATE_RISK.ACCESSION_NUMBER = FUND_REPORTED_INFO.ACCESSION_NUMBER',
            'BORROWER.ACCESSION_NUMBER = FUND_REPORTED_INFO.ACCESSION_NUMBER',
            'BORROW_AGGREGATE.ACCESSION_NUMBER = FUND_REPORTED_INFO.ACCESSION_NUMBER',
            'MONTHLY_TOTAL_RETURN.ACCESSION_NUMBER = FUND_REPORTED_INFO.ACCESSION_NUMBER',
            'MONTHLY_RETURN_CAT_INSTRUMENT.ACCESSION_NUMBER = FUND_REPORTED_INFO.ACCESSION_NUMBER',
            'FUND_VAR_INFO.ACCESSION_NUMBER = FUND_REPORTED_INFO.ACCESSION_NUMBER',
            'FUND_REPORTED_HOLDING.ACCESSION_NUMBER = FUND_REPORTED_INFO.ACCESSION_NUMBER',
            'EXPLANATORY_NOTE.ACCESSION_NUMBER = FUND_REPORTED_INFO.ACCESSION_NUMBER',
            'SUBMISSION.ACCESSION_NUMBER = FUND_REPORTED_INFO.ACCESSION_NUMBER',

            # HOLDING_ID relationships
            'IDENTIFIERS.HOLDING_ID = FUND_REPORTED_HOLDING.HOLDING_ID',
            'DEBT_SECURITY.HOLDING_ID = FUND_REPORTED_HOLDING.HOLDING_ID',
            'DEBT_SECURITY_REF_INSTRUMENT.HOLDING_ID = FUND_REPORTED_HOLDING.HOLDING_ID',
            'CONVERTIBLE_SECURITY_CURRENCY.HOLDING_ID = FUND_REPORTED_HOLDING.HOLDING_ID',
            'REPURCHASE_AGREEMENT.HOLDING_ID = FUND_REPORTED_HOLDING.HOLDING_ID',
            'REPURCHASE_COUNTERPARTY.HOLDING_ID = FUND_REPORTED_HOLDING.HOLDING_ID',
            'REPURCHASE_COLLATERAL.HOLDING_ID = FUND_REPORTED_HOLDING.HOLDING_ID',
            'DERIVATIVE_COUNTERPARTY.HOLDING_ID = FUND_REPORTED_HOLDING.HOLDING_ID',
            'SWAPTION_OPTION_WARNT_DERIV.HOLDING_ID = FUND_REPORTED_HOLDING.HOLDING_ID',
            'DESC_REF_INDEX_BASKET.HOLDING_ID = FUND_REPORTED_HOLDING.HOLDING_ID',
            'DESC_REF_INDEX_COMPONENT.HOLDING_ID = FUND_REPORTED_HOLDING.HOLDING_ID',
            'DESC_REF_OTHER.HOLDING_ID = FUND_REPORTED_HOLDING.HOLDING_ID',
            'FUT_FWD_NONFOREIGNCUR_CONTRACT.HOLDING_ID = FUND_REPORTED_HOLDING.HOLDING_ID',
            'FWD_FOREIGNCUR_CONTRACT_SWAP.HOLDING_ID = FUND_REPORTED_HOLDING.HOLDING_ID',
            'NONFOREIGN_EXCHANGE_SWAP.HOLDING_ID = FUND_REPORTED_HOLDING.HOLDING_ID',
            'FLOATING_RATE_RESET_TENOR.HOLDING_ID = FUND_REPORTED_HOLDING.HOLDING_ID',
            'OTHER_DERIV.HOLDING_ID = FUND_REPORTED_HOLDING.HOLDING_ID',
            'OTHER_DERIV_NOTIONAL_AMOUNT.HOLDING_ID = FUND_REPORTED_HOLDING.HOLDING_ID',
            'SECURITIES_LENDING.HOLDING_ID = FUND_REPORTED_HOLDING.HOLDING_ID'
        ]

        formatted_pks = []
        for table, keys in primary_keys.items():
            for key in keys:
                formatted_pks.append(f"{table}.{key}")

        return formatted_pks, foreign_keys

    def find_similar_words(self, word: str) -> List[Tuple[str, float]]:
        """Better matching using multiple techniques - backup method with financial terms dictionary."""
        if not word:
            return []

        word = word.lower()
        
        matches = []
        
        # 1. Direct matching with column names and their components
        for qualified_name, metadata in self.column_index.items():
            score = 0.0
            
            # Check exact matches in column words
            if word in metadata['words']:
                matches.append((qualified_name, 1.0))
                continue
                
            # Check synonyms
            if word in self.financial_terms.get(word, []):
                matches.append((qualified_name, 0.9))
                continue
            
            # Fuzzy match with column words
            for col_word in metadata['words']:
                ratio = fuzz.ratio(word, col_word) / 100.0
                if ratio > score:
                    score = ratio
            
            # Fuzzy match with synonyms
            for term, synonyms in self.financial_terms.items():
                if term in metadata['words']:
                    for synonym in synonyms:
                        ratio = fuzz.ratio(word, synonym) / 100.0
                        if ratio > score:
                            score = ratio * 0.9  # Slightly lower weight for synonym matches
            
            if score > 0.6:  # Only include if similarity is above 60%
                matches.append((qualified_name, score))

        # 2. LSH-based matching as backup
        if len(matches) < 5:  # If we have fewer than 5 matches, try LSH
            try:
                word_vector = self.vectorizer.transform([word]).toarray()[0]
                candidate_indices = self.lsh.query(word_vector)
                
                for idx in candidate_indices:
                    term = self.term_list[idx]
                    if not any(term == m[0] for m in matches):  # Avoid duplicates
                        candidate_vector = self.term_vectors[idx].toarray()[0]
                        dist = np.linalg.norm(word_vector - candidate_vector)
                        sim = 1 / (1 + dist)
                        if sim > 0.5:  # Only include if similarity is above 50%
                            matches.append((term, sim * 0.8))
            except Exception as e:
                print(f"LSH matching failed: {e}")

        # Remove duplicates keeping highest score and sort by score
        unique_matches = {}
        for term, score in matches:
            if term not in unique_matches or score > unique_matches[term]:
                unique_matches[term] = score
        
        matches = [(term, score) for term, score in unique_matches.items()]
        matches.sort(key=lambda x: x[1], reverse=True)
        
        # Print debug info
        print(f"Found {len(matches)} matches for '{word}':")
        for match, score in matches[:5]:
            print(f"  {match}: {score:.4f}")
        
        return matches[:5] if matches else [('fund_reported_info.total_assets', 0.6)] if word in ['total', 'asset', 'assets'] else []
    
    def extract_keywords(self, question: str) -> Dict:
        system_prompt = """Given a financial database schema:
        {schema_info}

        Primary Keys: {primary_keys}
        Foreign Keys: {foreign_keys}

        Extract from the question schema-aware components using the examples below."""

        few_shot_examples = """
        Example Question: "Show me all equity-focused funds"
        {
        "keywords": ["equity", "funds", "series"],
        "keyphrases": ["equity-focused funds"], 
        "table_matches": ["FUND_REPORTED_INFO"],
        "column_matches": ["SERIES_NAME", "TOTAL_ASSETS"],
        "primary_keys": ["FUND_REPORTED_INFO.ACCESSION_NUMBER"]
        }

        Example Question: "Show fund holdings over 1 billion in assets"
        {
        "keywords": ["holdings", "assets", "funds"],
        "numerical_values": ["1 billion"],
        "table_matches": ["FUND_REPORTED_INFO", "FUND_REPORTED_HOLDING"],
        "column_matches": ["TOTAL_ASSETS", "SERIES_NAME", "HOLDING_VALUE"],
        "required_joins": [
            "FUND_REPORTED_INFO to FUND_REPORTED_HOLDING via ACCESSION_NUMBER"
        ],
        "primary_keys": [
            "FUND_REPORTED_INFO.ACCESSION_NUMBER",
            "FUND_REPORTED_HOLDING.HOLDING_ID"
        ],
        "foreign_keys": [
            "FUND_REPORTED_HOLDING.ACCESSION_NUMBER = FUND_REPORTED_INFO.ACCESSION_NUMBER"
        ]
        }"""

        response = self.client.chat.completions.create(
            model="gpt-4o",
            messages=[
                {"role": "system", "content": system_prompt.format(
                    schema_info=self.schema,
                    primary_keys=self.primary_keys,
                    foreign_keys=self.foreign_keys
                )},
                {"role": "user", "content": few_shot_examples + f"\n\nQuestion: {question}"}
            ],
            tools=[{
                "type": "function",
                "function": {
                    "name": "extract_components",
                    "description": "Extract components mapping to schema",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "keywords": {"type": "array", "items": {"type": "string"}},
                            "keyphrases": {"type": "array", "items": {"type": "string"}},
                            "table_matches": {"type": "array", "items": {"type": "string"}},
                            "column_matches": {"type": "array", "items": {"type": "string"}},
                            "required_joins": {"type": "array", "items": {"type": "string"}},
                            "primary_keys": {"type": "array", "items": {"type": "string"}},
                            "foreign_keys": {"type": "array", "items": {"type": "string"}},
                            "numerical_values": {"type": "array", "items": {"type": "string"}}
                        },
                        "required": ["keywords", "table_matches", "column_matches"]
                    }
                }
            }],
            tool_choice={"type": "function", "function": {"name": "extract_components"}}
        )

        function_call = response.choices[0].message.tool_calls[0].function
        return json.loads(function_call.arguments)

    def preprocess_text(self, text: str) -> List[str]:
        """Tokenize and lemmatize input text, removing stop words."""
        if not text:  # Add check for empty text
            return []
            
        try:
            tokens = nltk.word_tokenize(str(text).lower())
            filtered_tokens = [word for word in tokens if word not in self.stop_words and word.isalnum()]
            lemmatized_tokens = [self.lemmatizer.lemmatize(token) for token in filtered_tokens]
            return lemmatized_tokens
        except Exception as e:
            print(f"Error in preprocessing text '{text}': {str(e)}")
            return []  # Return empty list instead of None on error
       
       
    def process_schema(self, question: str) -> str:
        # Get all the processing results
        results = self.process_question(question)
        
        # Organize schema links by type
        table_columns = []
        relevant_primary_keys = []
        relevant_foreign_keys = []
        
        # 1. Get main table/column matches
        for word, matches in results['similar_matches'].items():
            if matches:
                # Only take the top match if score > 0.7
                top_match = matches[0]  # (match, score)
                if top_match[1] > 0.7:
                    # Handle numerical values
                    if word in results['extracted_info'].get('numerical_values', []):
                        if 'billion' in word.lower():
                            table_columns.append(f"{top_match[0]} > 1000000000")
                        elif 'million' in word.lower():
                            table_columns.append(f"{top_match[0]} > 1000000")
                        else:
                            table_columns.append(f"{top_match[0]} > {word}")
                    else:
                        table_columns.append(top_match[0])
        
        # 2. Get relevant tables
        tables_needed = set()
        for link in table_columns:
            if '.' in link:
                tables_needed.add(link.split('.')[0].upper())
        
        # 3. Add relevant primary keys
        for pk in results['schema_relationships']['primary_keys']:
            table = pk.split('.')[0]
            if table in tables_needed:
                relevant_primary_keys.append(pk)
        
        # 4. Add relevant foreign keys
        for fk in results['schema_relationships']['foreign_keys']:
            tables_in_fk = set(part.split('.')[0] for part in fk.split(' = '))
            if tables_in_fk.intersection(tables_needed):
                relevant_foreign_keys.append(fk)
        
        # Format output with sections
        schema_dict = {
            "table_columns": table_columns,
            "primary_keys": relevant_primary_keys,
            "foreign_keys": relevant_foreign_keys
            #### ADD ONE MORE KEY AS SCHEMA_LINKS FROM DIN_SQL
            #### ADD SCHEMA LINKS
        }
        
        print("\nProcessed Schema Links:")
        print("Table Columns:", table_columns)
        print("Primary Keys:", relevant_primary_keys)
        print("Foreign Keys:", relevant_foreign_keys)
        
        return str(schema_dict)


    def process_question(self, question: str) -> Dict:
        # Extract keywords using gpt
        extracted_info = self.extract_keywords(question)

        words = []
        for key in ['keywords', 'keyphrases', 'named_entities', 'numerical_values']:
            words.extend(extracted_info.get(key, []))

        # Preprocess the words (lemmatize, remove stop words)
        processed_words = []
        for word in words:
            processed_words.extend(self.preprocess_text(word))

        # Remove duplicates
        processed_words = list(set(processed_words))

        # Find similar columns for each word
        similar_matches = {}
        for word in processed_words:
            similar_matches[word] = self.find_similar_words(word)

        # Combine the results
        result = {
            "question": question,
            "extracted_info": extracted_info,
            "processed_words": processed_words,
            "similar_matches": similar_matches,
            "schema_relationships": {
                "primary_keys": self.primary_keys,
                "foreign_keys": self.foreign_keys
            }
        }
        return result
    
if __name__ == "__main__":
    vr = ValueRetrieval(schema_path='chatgpt_api/schema.json')
    schema_links = vr.process_schema("Show me all funds with total assets over 1 billion")
    print("Schema Links:", schema_links)

DEBUG:httpx:load_ssl_context verify=True cert=None trust_env=True http2=False
DEBUG:httpx:load_verify_locations cafile='C:\\Users\\User\\anaconda3\\Library\\ssl\\cacert.pem'


DEBUG:openai._base_client:Request options: {'method': 'post', 'url': '/chat/completions', 'files': None, 'json_data': {'messages': [{'role': 'system', 'content': "Given a financial database schema:\n        {'type': 'database', 'schema': {'tables': [{'name': 'REGISTRANT', 'columns': [{'name': 'ACCESSION_NUMBER', 'type': 'TEXT'}, {'name': 'CIK', 'type': 'TEXT'}, {'name': 'REGISTRANT_NAME', 'type': 'TEXT'}, {'name': 'FILE_NUM', 'type': 'TEXT'}, {'name': 'LEI', 'type': 'TEXT'}, {'name': 'ADDRESS1', 'type': 'TEXT'}, {'name': 'ADDRESS2', 'type': 'TEXT'}, {'name': 'CITY', 'type': 'TEXT'}, {'name': 'STATE', 'type': 'TEXT'}, {'name': 'COUNTRY', 'type': 'TEXT'}, {'name': 'ZIP', 'type': 'TEXT'}, {'name': 'PHONE', 'type': 'TEXT'}, {'name': 'QUARTER', 'type': 'TEXT'}]}, {'name': 'FUND_REPORTED_INFO', 'columns': [{'name': 'ACCESSION_NUMBER', 'type': 'TEXT'}, {'name': 'SERIES_NAME', 'type': 'TEXT'}, {'name': 'SERIES_ID', 'type': 'TEXT'}, {'name': 'SERIES_LEI', 'type': 'TEXT'}, {'name': 'TOTAL_ASSETS

DEBUG: Loading schema file: chatgpt_api/schema.json


DEBUG:httpcore.connection:connect_tcp.complete return_value=<httpcore._backends.sync.SyncStream object at 0x0000022CC61937D0>
DEBUG:httpcore.connection:start_tls.started ssl_context=<ssl.SSLContext object at 0x0000022CF9AB8DD0> server_hostname='api.openai.com' timeout=5.0
DEBUG:httpcore.connection:start_tls.complete return_value=<httpcore._backends.sync.SyncStream object at 0x0000022CFB3FB6D0>
DEBUG:httpcore.http11:send_request_headers.started request=<Request [b'POST']>
DEBUG:httpcore.http11:send_request_headers.complete
DEBUG:httpcore.http11:send_request_body.started request=<Request [b'POST']>
DEBUG:httpcore.http11:send_request_body.complete
DEBUG:httpcore.http11:receive_response_headers.started request=<Request [b'POST']>
DEBUG:httpcore.http11:receive_response_headers.complete return_value=(b'HTTP/1.1', 200, b'OK', [(b'Date', b'Tue, 26 Nov 2024 00:41:55 GMT'), (b'Content-Type', b'application/json'), (b'Transfer-Encoding', b'chunked'), (b'Connection', b'keep-alive'), (b'access-contr

Found 2 matches for 'billion':
  submission.filing_date: 0.6154
  submission.is_last_filing: 0.6154
Found 279 matches for 'fund':
  securities_lending.is_loan_by_fund: 1.0000
  registrant.accession_number: 0.9000
  registrant.cik: 0.9000
  registrant.registrant_name: 0.9000
  registrant.file_num: 0.9000
Found 0 matches for '1':
Found 279 matches for 'total':
  fund_reported_info.total_assets: 1.0000
  fund_reported_info.total_liabilities: 1.0000
  monthly_total_return.monthly_total_return_id: 1.0000
  monthly_total_return.monthly_total_return1: 1.0000
  monthly_total_return.monthly_total_return2: 1.0000
Found 7 matches for 'asset':
  fund_reported_holding.asset_cat: 1.0000
  monthly_return_cat_instrument.asset_cat: 1.0000
  fund_reported_info.total_assets: 0.9000
  fund_reported_info.net_assets: 0.9000
  fund_reported_info.assets_attrbt_to_misc_security: 0.9000

Processed Schema Links:
Table Columns: ['securities_lending.is_loan_by_fund', 'fund_reported_info.total_assets', 'fund_report

In [None]:
############################################ CLASSIFICATION
classification_prompt = '''Q: "Find the filing date and submission number of all reports filed for an NPORT-P submission."
schema_links: [submission.filing_date, submission.sub_type = "NPORT-P", submission.accession_number]
A: Let’s think step by step. The SQL query for the question "Find the filing date and submission number of all reports filed for an NPORT-P submission." needs these tables = [submission], so we don't need JOIN.
Plus, it doesn't require nested queries with (INTERSECT, UNION, EXCEPT, IN, NOT IN), and we need the answer to the questions = [""]. 
So, we don't need JOIN and don't need nested queries, then the SQL query can be classified as "EASY".
Label: "EASY"

Q: "Get the names and CIK of registrants who are located in California."
schema_links: [registrant.registrant_name, registrant.cik, registrant.state = "US-CA"]
A: Let’s think step by step. The SQL query for the question "Get the names and CIK of registrants who are located in California." needs these tables = [registrant], so we don't need JOIN.
Plus, it doesn't require nested queries with (INTERSECT, UNION, EXCEPT, IN, NOT IN), and we need the answer to the questions = [""]. 
So, we don't need JOIN and don't need nested queries, then the SQL query can be classified as "EASY".
Label: "EASY"

Q: "Find the names and CIK of registrants in California, but only for those whose total assets are above 100 million."
schema_links: [registrant.registrant_name, registrant.cik, registrant.state = "US-CA", fund_reported_info.total_assets > 100000000]
A: Let's analyze this. The query involves data from two tables: "registrant" for registrant details and "fund_reported_info" for total assets. Since we need to check if total assets exceed 100 million, a nested query is necessary to filter based on this condition. This is a nested query. So, the SQL query can be classified as "NESTED."
Label: "NESTED"

'''

def classification_prompt_maker(question, relevant_schema_links):
   instruction = """# Given the database schema:
{schema_info}

Primary Keys:
{primary_keys}

Foreign Keys:
{foreign_keys}

- For the given question, classify it as EASY, NON-NESTED, or NESTED based on nested queries and JOIN
- if need nested queries: predict NESTED
- elif need JOIN and don't need nested queries: predict NON-NESTED
- elif don't need JOIN and don't need nested queries: predict EASY

Consider table relationships and what joins would be needed."""

   prompt = instruction.format(
       schema_info=schema_info,
       primary_keys=primary_keys,
       foreign_keys=foreign_keys
   ) + classification_prompt + f'Q: "{question}"\nrelevant_schema_links: {relevant_schema_links}\nA: Let\'s think step by step.'
       
   return prompt

def process_question_classification(question, relevant_schema_links):
    def extract_classification(text):
        print(f"Trying to extract classification from: {text}")
        # Common patterns in GPT's response
        patterns = [
            "Label:", 
            "Classification:", 
            "The SQL query can be classified as",
            "can be classified as"
        ]
        
        text = text.upper()  # Normalize text
        # Direct class detection
        for class_type in ["EASY", "NON-NESTED", "NESTED"]:
            if class_type in text:
                return class_type

        # Try splitting with different patterns
        for pattern in patterns:
            if pattern.upper() in text:
                parts = text.split(pattern.upper())
                if len(parts) > 1:
                    # Get the last part and clean it
                    result = parts[1].strip().strip('"').strip("'")
                    # Extract first word as classification
                    classification = result.split()[0].strip()
                    if classification in ["EASY", "NON-NESTED", "NESTED"]:
                        return classification
                        
        return "NESTED"  # Default fallback

    classification = None
    attempts = 0
    while classification is None and attempts < 3:
        try:
            print("Attempting classification...")
            response = client.chat.completions.create(
                model="gpt-4o",
                messages=[{
                    "role": "user", 
                    "content": classification_prompt_maker(question, relevant_schema_links) #### ADD SCHEMA LINKS
                }],
                n=1,
                stream=False,
                temperature=0.0,
                max_tokens=300,
                top_p=1.0,
                frequency_penalty=0.0,
                presence_penalty=0.0
            )
            raw_response = response.choices[0].message.content
            print("Raw response:", raw_response)
            classification = extract_classification(raw_response)
        except Exception as e:
            print(f"Error occurred: {str(e)}")
            time.sleep(3)
            attempts += 1
    
    final_class = classification if classification else "NESTED"
    return f'"{final_class}"'

############################################ SQL GENERATION
easy_prompt = '''Q: "Find the issuers with a balance greater than 1 million."
Schema_links: [fund_reported_holding.balance]
SQL: SELECT DISTINCT issuer_name 
      FROM fund_reported_holding 
      WHERE balance > 1000000
'''

medium_prompt = '''Q: "Find the total upfront payments and receipts for swaps with fixed rate receipts."
Schema_links: [nonforeign_exchange_swap.upfront_payment, nonforeign_exchange_swap.upfront_receipt, nonforeign_exchange_swap.fixed_rate_receipt]
A: Let’s think step by step. For creating the SQL for the given question, we need to filter the swaps that have fixed rate receipts. Then, sum up the upfront payments and receipts. First, create an intermediate representation, then use it to construct the SQL query.
Intermediate_representation: 
SELECT SUM(nonforeign_exchange_swap.upfront_payment) + SUM(nonforeign_exchange_swap.upfront_receipt) 
FROM nonforeign_exchange_swap 
WHERE nonforeign_exchange_swap.fixed_rate_receipt IS NOT NULL
SQL: 
SELECT SUM(upfront_payment) + SUM(upfront_receipt) 
FROM nonforeign_exchange_swap 
WHERE fixed_rate_receipt IS NOT NULL
'''

hard_prompt = '''Q: "Find the borrowers with aggregate value greater than $1 million and whose interest rate change at 10-year maturity for a 100 basis point change is positive."
Schema_links: [borrower.aggregate_value, borrower.name, interest_rate_risk.intrst_rate_change_10yr_dv100]
A: Let's think step by step. First, we need to filter borrowers with aggregate values greater than $1 million. Then, we need to check for interest rate changes at 10-year maturity where the change is positive. 
The SQL query for the sub-question "What are the borrowers with aggregate value greater than $1 million and positive interest rate change at 10-year maturity for 100 basis points?" is:

Intermediate_representation: 
SELECT borrower.name 
FROM borrower 
JOIN interest_rate_risk 
ON borrower.accession_number = interest_rate_risk.accession_number 
WHERE borrower.aggregate_value > 1000000 
AND interest_rate_risk.intrst_rate_change_10yr_dv100 > 0

SQL: 
SELECT borrower.name 
FROM borrower 
JOIN interest_rate_risk 
ON borrower.accession_number = interest_rate_risk.accession_number 
WHERE borrower.aggregate_value > 1000000 
AND interest_rate_risk.intrst_rate_change_10yr_dv100 > 0
'''

def hard_prompt_maker(question, schema_links, sub_questions=""):
   instruction = f"""# Given the database schema:
{schema_info}

Primary Keys:
{primary_keys}

Foreign Keys:
{foreign_keys}

Use the intermediate representation and schema links to generate SQL queries."""

   if sub_questions=="":
       stepping = f'''\nA: Let's think step by step.'''
   else:
       stepping = f'''\nA: Let's think step by step.'''
   prompt = instruction + hard_prompt + chat_prompt.gpt_queries_hard + f'Q: "{question}"\nschema_links: {schema_links}\nA:'
   return prompt

def medium_prompt_maker(question, schema_links):
   instruction = f"""# Given the database schema:
{schema_info}

Primary Keys:
{primary_keys}

Foreign Keys:
{foreign_keys}

Use the schema links and Intermediate_representation to generate SQL queries."""

   prompt = instruction + medium_prompt + chat_prompt.gpt_queries_medium + f'Q: "{question}"\nSchema_links: {schema_links}\nA: Let\'s think step by step.'
   return prompt

def easy_prompt_maker(question, schema_links):
   instruction = f"""# Given the database schema:
{schema_info}

Primary Keys:
{primary_keys}

Foreign Keys:
{foreign_keys}

Use the schema links to generate SQL queries."""

   prompt = instruction + easy_prompt + chat_prompt.gpt_queries_easy + f'Q: "{question}"\nSchema_links: {schema_links}\nSQL:' #### ADD SCHEMA LINKS
   return prompt

In [None]:
import time
def process_question_sql(question, predicted_class, schema_links, max_retries=3):
    def extract_sql(text):
        print(f"\nTrying to extract SQL from: {text}")  # Debug print
        if not text:
            return "SELECT"
            
        markers = ["SQL:", "Query:", "QUERY:", "SQL Query:", "Final SQL:"]
        for marker in markers:
            if marker in text:
                parts = text.split(marker)
                if len(parts) > 1:
                    sql = parts[-1].strip()
                    print(f"Found SQL after {marker}: {sql}")  # Debug print
                    return sql
        print("No SQL marker found, returning full text")  # Debug print
        return text.strip()

    if '"EASY"' in predicted_class:
        print("EASY")
        for attempt in range(max_retries):
            try:
                SQL = GPT4_generation(easy_prompt_maker(
                    question=question,
                    schema_links=schema_links
                ))
                if SQL:
                    SQL = extract_sql(SQL)
                    break
            except Exception as e:
                print(f"Attempt {attempt + 1} failed: {str(e)}")
                if attempt < max_retries - 1:
                    time.sleep(3)
                else:
                    SQL = "SELECT"
                    
    elif '"NON-NESTED"' in predicted_class:
        print("NON-NESTED")
        for attempt in range(max_retries):
            try:
                SQL = GPT4_generation(medium_prompt_maker(
                    question=question,
                    schema_links=schema_links
                ))
                if SQL:
                    SQL = extract_sql(SQL)
                    break
            except Exception as e:
                print(f"Attempt {attempt + 1} failed: {str(e)}")
                if attempt < max_retries - 1:
                    time.sleep(3)
                else:
                    SQL = "SELECT"
                    
    else:
        print("NESTED")
        for attempt in range(max_retries):
            try:
                SQL = GPT4_generation(hard_prompt_maker(
                    question=question,
                    schema_links=schema_links
                ))
                if SQL:
                    SQL = extract_sql(SQL)
                    break
            except Exception as e:
                print(f"Attempt {attempt + 1} failed: {str(e)}")
                if attempt < max_retries - 1:
                    time.sleep(3)
                else:
                    SQL = "SELECT"

    return SQL if SQL else "SELECT"

def GPT4_generation(prompt, max_retries=3):
    for attempt in range(max_retries):
        try:
            response = client.chat.completions.create(
                model="gpt-4o", 
                messages=[{"role": "user", "content": prompt}],
                n = 1,
                stream = False,
                temperature=0.0,
                max_tokens=600,
                top_p = 1.0,
                frequency_penalty=0.0,
                presence_penalty=0.0
                # Removed stop=["Q:"] as it cause issues
            )
            return response.choices[0].message.content
        except Exception as e:
            print(f"Attempt {attempt + 1} failed: {str(e)}")
            if attempt < max_retries - 1:
                time.sleep(3)
            else:
                print("Max retries reached")
                return None
    return None

In [9]:
############################################ SELF CORRECTION
def debuger(test_sample_text,sql):
	instruction = """#### For the given question, the following SQL query was generated 
	
	Use the provided tables, columns, foreign keys, and primary keys to fix the given SQLite SQL QUERY for any issues. If there are any problems, fix them and return the fixed SQLite QUERY in the output. If there are no issues, return the SQLite SQL QUERY as is in the output.
"for these set of instructions, this sql was generated; check whether or not this is correctly generated, and assess on these 7 parameters"
	
	#### Use the following instructions for fixing the SQL QUERY:
1) Use the database values that are explicitly mentioned in the question.
2) Pay attention to the columns that are used for the JOIN by using the Foreign_keys.
3) Use DESC and DISTINCT when needed.
4) Pay attention to the columns that are used for the GROUP BY statement.
5) Pay attention to the columns that are used for the SELECT statement.
6) Only change the GROUP BY clause when necessary (Avoid redundant columns in GROUP BY).
7) Use GROUP BY on one column only.
"""
	prompt = instruction + '#### Question: ' + test_sample_text + '\n#### SQLite SQL QUERY\n' + sql +'\n#### SQLite FIXED SQL QUERY' + schema_links
	return prompt



def GPT4_debug(prompt):
    response = client.chat.completions.create(
        model="gpt-4o",
        messages=[{"role": "user", "content": prompt}],
        n = 1,
        stream = False,
        temperature=0.0,
        max_tokens=350,
        top_p = 1.0,
        frequency_penalty=0.0,
        presence_penalty=0.0,
        stop = ["#", ";","\n\n"]
    )
    return response.choices[0].message.content


def refine_query(question, sql):
	debugged_SQL = None
	while debugged_SQL is None:
		try:
			debugged_SQL = GPT4_debug(debuger(question,sql)).replace("\n", " ")
		except:
			time.sleep(3)
			pass
	SQL = debugged_SQL.split('sql', 1)
	print(SQL)


In [10]:
def generate_din_sql(question: str):
       schema_info = vr.process_schema(question)
       classification = process_question_classification(question, schema_info)
       process_thesql = process_question_sql(question, classification, schema_info)
       debug = debuger(question, process_thesql)
       final_output = GPT4_debug(debug)
       return (final_output)

In [11]:
import sqlite3
import io
import csv
def execute_sql(query: str) -> str:
    conn = None
    try:
        conn = sqlite3.connect('sqlite/nport.db')
        cursor = conn.cursor()

        # Execute the query with a timeout
        cursor.execute(query)

        # Fetch column names and rows
        columns = [description[0] for description in cursor.description]
        rows = cursor.fetchall()

        # Convert the results to CSV
        output = io.StringIO()
        writer = csv.writer(output)
        writer.writerow(columns)
        writer.writerows(rows)
        csv_data = output.getvalue()
        output.close()

        return csv_data
    except sqlite3.Error as e:
        print(f"Database error: {str(e)}")
        raise e
    except Exception as e:
        print(f"Error executing SQL: {str(e)}")
        raise e
    finally:
        if conn:
            conn.close()

def compare_csv_strings(csv_data1: str, csv_data2: str) -> bool:
    # Use io.StringIO to read the CSV strings as file-like objects
    csv_file1 = io.StringIO(csv_data1)
    csv_file2 = io.StringIO(csv_data2)
    
    # Create CSV readers for each CSV string
    reader1 = csv.reader(csv_file1)
    reader2 = csv.reader(csv_file2)
    
    # Compare rows one by one
    for row1, row2 in zip(reader1, reader2):
        if row1 != row2:
            return False  # Rows are different
    
    # Check if there are extra rows in either file
    try:
        next(reader1)
        return False  # Extra rows in csv_data1
    except StopIteration:
        pass

    try:
        next(reader2)
        return False  # Extra rows in csv_data2
    except StopIteration:
        pass

    return True  # CSVs are identical


def compare_csv_din(ground_truth_query: str, llm_query: str):
    ## let LLM stack query the database
    try: 
        din_generated_sql = generate_din_sql(llm_query)
        llm_csv = execute_sql(din_generated_sql.replace("```sql", "").replace("```", "").strip())
    except Exception as e:
        print(f"Unexpected error: {str(e)}")
        raise e
    ## direct execute of ground truth SQL on the database
    try:
        ground_truth_csv = execute_sql(ground_truth_query)
        ## compare results
    except Exception as e:
        print(f"Error Executing Ground Truth SQL: {str(e)}")
        raise e
    ## compare 2 SQL outputs using compare_csv_strings
    try:
        diff=compare_csv_strings(ground_truth_csv,llm_csv)
        if diff:
            print("CSV outputs match perfectly.")
            return True
        else:
            print("Mismatch found.")
            return False
    except Exception as e:
        print(f"Error comparing CSVs: {str(e)}")
        raise e




In [12]:
results = []
while len(results) < 20:
    results.append(None)
i=0
for i in range(len(llm_query)):
    results[i]=(str(i)+'. '+ str(compare_csv_din(ground_truth_query[i],llm_query[i])))

DEBUG:openai._base_client:Request options: {'method': 'post', 'url': '/chat/completions', 'files': None, 'json_data': {'messages': [{'role': 'system', 'content': "Given a financial database schema:\n        {'type': 'database', 'schema': {'tables': [{'name': 'REGISTRANT', 'columns': [{'name': 'ACCESSION_NUMBER', 'type': 'TEXT'}, {'name': 'CIK', 'type': 'TEXT'}, {'name': 'REGISTRANT_NAME', 'type': 'TEXT'}, {'name': 'FILE_NUM', 'type': 'TEXT'}, {'name': 'LEI', 'type': 'TEXT'}, {'name': 'ADDRESS1', 'type': 'TEXT'}, {'name': 'ADDRESS2', 'type': 'TEXT'}, {'name': 'CITY', 'type': 'TEXT'}, {'name': 'STATE', 'type': 'TEXT'}, {'name': 'COUNTRY', 'type': 'TEXT'}, {'name': 'ZIP', 'type': 'TEXT'}, {'name': 'PHONE', 'type': 'TEXT'}, {'name': 'QUARTER', 'type': 'TEXT'}]}, {'name': 'FUND_REPORTED_INFO', 'columns': [{'name': 'ACCESSION_NUMBER', 'type': 'TEXT'}, {'name': 'SERIES_NAME', 'type': 'TEXT'}, {'name': 'SERIES_ID', 'type': 'TEXT'}, {'name': 'SERIES_LEI', 'type': 'TEXT'}, {'name': 'TOTAL_ASSETS

Found 0 matches for '5':
Found 20 matches for 'country':
  registrant.country: 1.0000
  fund_reported_holding.investment_country: 1.0000
  repurchase_agreement.central_counter_party: 0.8571
  repurchase_counterparty.repurchase_counterparty_id: 0.7368
  derivative_counterparty.derivative_counterparty_id: 0.7368
Found 1 matches for 'cik':
  registrant.cik: 1.0000
Found 5 matches for 'registrant':
  registrant.registrant_name: 1.0000
  fund_reported_info.reinvestment_flow_mon1: 0.6364
  fund_reported_info.reinvestment_flow_mon2: 0.6364
  fund_reported_info.reinvestment_flow_mon3: 0.6364
  monthly_total_return.monthly_total_return_id: 0.6250
Found 279 matches for 'total':
  fund_reported_info.total_assets: 1.0000
  fund_reported_info.total_liabilities: 1.0000
  monthly_total_return.monthly_total_return_id: 1.0000
  monthly_total_return.monthly_total_return1: 1.0000
  monthly_total_return.monthly_total_return2: 1.0000
Found 279 matches for 'net':
  fund_reported_info.net_assets: 1.0000
  fu

DEBUG:httpcore.http11:receive_response_headers.complete return_value=(b'HTTP/1.1', 200, b'OK', [(b'Date', b'Tue, 26 Nov 2024 00:42:05 GMT'), (b'Content-Type', b'application/json'), (b'Transfer-Encoding', b'chunked'), (b'Connection', b'keep-alive'), (b'access-control-expose-headers', b'X-Request-ID'), (b'openai-organization', b'user-annfuni26pdtuawdwdj6zorw'), (b'openai-processing-ms', b'6104'), (b'openai-version', b'2020-10-01'), (b'x-ratelimit-limit-requests', b'10000'), (b'x-ratelimit-limit-tokens', b'2000000'), (b'x-ratelimit-remaining-requests', b'9999'), (b'x-ratelimit-remaining-tokens', b'1995450'), (b'x-ratelimit-reset-requests', b'6ms'), (b'x-ratelimit-reset-tokens', b'136ms'), (b'x-request-id', b'req_c23e64d7b6372d09e5d1fb2cbbb205f8'), (b'strict-transport-security', b'max-age=31536000; includeSubDomains; preload'), (b'CF-Cache-Status', b'DYNAMIC'), (b'Set-Cookie', b'__cf_bm=zwYiPRQ9ZiYXsf6a1LTVP2MC.JXAHFHyX6okCkH.P2s-1732581725-1.0.1.1-v7vn2_3u4ujiyYKXn0LEVjgnY1etO1s6T4OfYBx52

Raw response: To answer the question "List the top 5 registrants by total net assets, including their CIK and country," we need to consider the following:

1. **Tables Involved**: 
   - `REGISTRANT`: This table contains information about the registrant, including `CIK` and `COUNTRY`.
   - `FUND_REPORTED_INFO`: This table contains financial information, including `NET_ASSETS`.

2. **Required Columns**:
   - From `REGISTRANT`: `CIK`, `COUNTRY`.
   - From `FUND_REPORTED_INFO`: `NET_ASSETS`.

3. **Join Requirement**:
   - We need to join `REGISTRANT` and `FUND_REPORTED_INFO` on the `ACCESSION_NUMBER` to combine registrant details with their financial information.

4. **Sorting and Limiting**:
   - We need to sort the results by `NET_ASSETS` in descending order to find the top 5 registrants.
   - We will limit the results to the top 5.

Given these requirements, the query involves a join between two tables but does not require nested queries. Therefore, the SQL query can be classified as "N

DEBUG:httpcore.http11:receive_response_headers.complete return_value=(b'HTTP/1.1', 200, b'OK', [(b'Date', b'Tue, 26 Nov 2024 00:42:15 GMT'), (b'Content-Type', b'application/json'), (b'Transfer-Encoding', b'chunked'), (b'Connection', b'keep-alive'), (b'access-control-expose-headers', b'X-Request-ID'), (b'openai-organization', b'user-annfuni26pdtuawdwdj6zorw'), (b'openai-processing-ms', b'10066'), (b'openai-version', b'2020-10-01'), (b'x-ratelimit-limit-requests', b'10000'), (b'x-ratelimit-limit-tokens', b'2000000'), (b'x-ratelimit-remaining-requests', b'9999'), (b'x-ratelimit-remaining-tokens', b'1988554'), (b'x-ratelimit-reset-requests', b'6ms'), (b'x-ratelimit-reset-tokens', b'343ms'), (b'x-request-id', b'req_4551b8b904631e74345c5ef802ad0767'), (b'strict-transport-security', b'max-age=31536000; includeSubDomains; preload'), (b'CF-Cache-Status', b'DYNAMIC'), (b'X-Content-Type-Options', b'nosniff'), (b'Server', b'cloudflare'), (b'CF-RAY', b'8e85d625cfbe08ca-LAX'), (b'Content-Encoding', 


Trying to extract SQL from: To generate the SQL query for the given question, we need to follow these steps:

1. **Identify the Tables and Columns Needed**: 
   - We need information from the `REGISTRANT` table for the registrant's CIK and country.
   - We need the `NET_ASSETS` from the `FUND_REPORTED_INFO` table to calculate the total net assets for each registrant.

2. **Establish the Relationship Between Tables**:
   - The `REGISTRANT` and `FUND_REPORTED_INFO` tables are linked by the `ACCESSION_NUMBER` column, as indicated by the foreign key relationship: `REGISTRANT.ACCESSION_NUMBER = FUND_REPORTED_INFO.ACCESSION_NUMBER`.

3. **Aggregate the Data**:
   - We need to sum the `NET_ASSETS` for each registrant to get the total net assets.

4. **Order and Limit the Results**:
   - We need to order the results by total net assets in descending order to get the top registrants.
   - Limit the results to the top 5 registrants.

5. **Construct the SQL Query**:
   - Use a `JOIN` to combine 

DEBUG:httpcore.http11:receive_response_headers.complete return_value=(b'HTTP/1.1', 200, b'OK', [(b'Date', b'Tue, 26 Nov 2024 00:42:16 GMT'), (b'Content-Type', b'application/json'), (b'Transfer-Encoding', b'chunked'), (b'Connection', b'keep-alive'), (b'access-control-expose-headers', b'X-Request-ID'), (b'openai-organization', b'user-annfuni26pdtuawdwdj6zorw'), (b'openai-processing-ms', b'1185'), (b'openai-version', b'2020-10-01'), (b'x-ratelimit-limit-requests', b'10000'), (b'x-ratelimit-limit-tokens', b'2000000'), (b'x-ratelimit-remaining-requests', b'9999'), (b'x-ratelimit-remaining-tokens', b'1998346'), (b'x-ratelimit-reset-requests', b'6ms'), (b'x-ratelimit-reset-tokens', b'49ms'), (b'x-request-id', b'req_ab124bea6ca9d92edbcb7f3deeb51393'), (b'strict-transport-security', b'max-age=31536000; includeSubDomains; preload'), (b'CF-Cache-Status', b'DYNAMIC'), (b'X-Content-Type-Options', b'nosniff'), (b'Server', b'cloudflare'), (b'CF-RAY', b'8e85d665cc5108ca-LAX'), (b'Content-Encoding', b'

Database error: near "The": syntax error
Unexpected error: near "The": syntax error


OperationalError: near "The": syntax error