# NL2SQL 

This notebook demonstrates agentic schema refinement for NL2SQL (unsupervised version). We use [BIRD-SQL](https://bird-bench.github.io/), a widely used Text-to-SQL benchmark.

In [14]:
import os
import json
import shutil
import pandas as pd
import ast
import numpy as np

from src.utils import prompt_llm, extract_json_from_llm_response, flatten, text_embedding
from src.database import SQLiteDatabase
from src.postprocess import process_views
from src.refinement import refine_schema

In [2]:
database_name = 'california_schools'
result_version = 'refine'
nl_to_sql_file = './workspace/bird/dev.json'

# Create workspace
workspace = f'./workspace/bird/dev_databases/{database_name}/{result_version}'
os.makedirs(workspace, exist_ok=True)

# Copy database into workspace. This is the database that will be modified during the refinement process.
original_database_file = f'./workspace/bird/dev_databases/{database_name}/{database_name}.sqlite'
database_file = os.path.join(workspace, f'{database_name}_{result_version}.sqlite')
shutil.copy(original_database_file, database_file)

'./workspace/bird/dev_databases/california_schools/refine/california_schools_refine.sqlite'

In [3]:
db = SQLiteDatabase(database_name, database_file, nl_to_sql_file)

## Multi-agent Schema Refinement

### Schema Refinement

Run the multi-agent refinement framework. During schema refinement, agents hold individual chat sessions in which they create views.
The resulting chat log is used to extract the views in the next stage.

In [4]:
log_file = os.path.join(workspace, f'refine_{db.database_name}.log')

In [5]:
!python src/refinement.py \
    --workspace {workspace} \
    --db_name {database_name} \
    --db_file {database_file} \
    --instr_file './agent_instructions.yml' \
    --temperature 0.2 \
    --verify \
    --n_chats 5 \
    --n_rounds 10 \
    --n_verification_rounds 8 \
    --subsample \
    --n_samples 10 \
    --n_sampled_tables 3 \
    --sample_data \
    --cache_seed 10 \
    &> {log_file}

### Postprocessing

* The chat log is parsed into individual chats. (`refine_<database_name>_chats.jsonl`)
* Each chat is then processed by an LLM agent to extract one or more pairs, each consisting of a high-level analysis task and the set of views that are useful for solving it. This dataset will be used for instruction tuning in the next stage. (`refine_<database_name>_task_views.jsonl`)
* Each view definition is parsed by a SQL parser to detect the original tables and columns it uses. The set of original tables and columns are called sources. Views that share sources are later associated in a graph. (`refine_<database_name>_sql_parsed.jsonl`)
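
To illustrate the last step, here is a minimal sketch of how views that share sources could be linked in a graph. It is not the implementation in `process_views`; the function name `build_view_graph`, the adjacency-dict representation, and the example sources are assumptions made for illustration.

```python
from collections import defaultdict
from itertools import combinations


def build_view_graph(view_sources):
    """Link views that share at least one source column.

    view_sources maps a view name to the set of 'table.column' sources it uses.
    Returns an adjacency dict: view name -> set of neighbouring view names.
    """
    # Invert the mapping: source column -> views that read it.
    views_by_source = defaultdict(set)
    for view, sources in view_sources.items():
        for source in sources:
            views_by_source[source].add(view)

    # Every view gets a node, even if it shares nothing with the others.
    graph = {view: set() for view in view_sources}
    for views in views_by_source.values():
        for a, b in combinations(sorted(views), 2):
            graph[a].add(b)
            graph[b].add(a)
    return graph


# Hypothetical sources, loosely modelled on the california_schools views.
example_sources = {
    "charter_school_status": {"schools.charter", "schools.district"},
    "school_funding_info": {"schools.charter", "schools.county"},
    "sat_scores": {"satscores.avgscrread", "satscores.avgscrmath"},
}
view_graph = build_view_graph(example_sources)
```

In this toy input, `charter_school_status` and `school_funding_info` become neighbours because both read `schools.charter`, while `sat_scores` stays isolated.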

In [None]:
process_views(db, workspace, log_file, generate_instructions=True)

### Retrieve Views 

Start by retrieving the generated views. The view definitions were already parsed by a SQL parser (`refine_<database_name>_sql_parsed.jsonl`). For each view:
* Find the source tables and columns that appear in the view definition. Note that we only keep columns that appear in the header, since these are the attributes the view contains. Intermediate columns may also be used to connect data from different tables (e.g., in JOIN statements), but they are not recorded as sources here.
* Generate a textual description. We ask an LLM to write a one-sentence description of each view, given its definition. The description should focus on the semantic interpretation rather than the structure.

In [7]:
views_parsed_file = os.path.join(workspace, f'refine_{db.database_name}_sql_parsed.jsonl')

In [8]:
rows = []

# TODO: handle views defined with nested queries (multiple SELECT clauses)

with open(views_parsed_file, 'r') as f:
    for line in f:
        view_dict = json.loads(line)
        # Skip views that could not be parsed or that have no name
        if not isinstance(view_dict['sql_parsed'], dict) or not view_dict['view_name']:
            continue
        view_name = view_dict['view_name'].replace('__', ' ').replace('_', ' ')
        column_set = {c for c in flatten(view_dict['sql_parsed']['select']) if isinstance(c, str)}
        # Skip views whose parsed header contains the '__all__' marker
        if '__all__' in column_set:
            continue
        table_set = {c.split('.')[0] for c in column_set}  # keep only the tables that appear in the header
        llm_description = prompt_llm(f"Provide a one-sentence description for the database view defined as follows: {view_dict['sql']} . Focus on the semantic interpretation rather than the structure. Reply with one short sentence.", "You are a database expert.")
        _columns_in_header = [c.replace('__', ' ').replace('_', ' ') for c in column_set]
        sources = ['column {} from table {}'.format(c.split('.')[1], c.split('.')[0]) for c in _columns_in_header]
        text = "Name: {}. Sources: {}. Description: {}".format(view_name, ', '.join(sources), llm_description)
        rows.append({'view_name': view_dict['view_name'], 'view_description': text, 'tables': table_set, 'columns': column_set, 'sql': view_dict['sql']})

# Build the dataframe once instead of concatenating inside the loop
df = pd.DataFrame(rows, columns=['view_name', 'view_description', 'tables', 'columns', 'sql'])

# remove duplicates
df = df.drop_duplicates(subset=['sql'])

display(df.head())
print('Total number of views:', len(df))

Unnamed: 0,view_name,view_description,tables,columns,sql
0,charter_school_status,Name: charter school status. Sources: column c...,{__schools},"{__schools.charter__, __schools.district__, __...",CREATE VIEW charter_school_status AS SELECT s....
1,sat_scores,Name: sat scores. Sources: column avgscrread ...,{__satscores},"{__satscores.avgscrread__, __satscores.avgscrm...",CREATE VIEW sat_scores AS SELECT ss.cds AS CDS...
2,charter_school_performance,Name: charter school performance. Sources: col...,"{__sat_scores, __charter_school_status, __frpm...","{__sat_scores.reading_score__, __frpm_eligibil...",CREATE VIEW charter_school_performance AS SELE...
3,school_funding_info,Name: school funding info. Sources: column cou...,{__schools},"{__schools.county__, __schools.charter__, __sc...",CREATE VIEW school_funding_info AS SELECT ...
4,funding_sat_performance,Name: funding sat performance. Sources: column...,"{__school_sat_performance, __school_funding_info}","{__school_sat_performance.avgscrmath__, __scho...",CREATE VIEW funding_sat_performance AS SELECT ...


Total number of views: 13


### View Embeddings

Embed the textual descriptions generated in the previous step. Save the dataframe, which contains the descriptions, the table and column sources, and the embeddings, to the workspace. (`refine_<database_name>_views.csv`)
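
A CSV file stores every cell as text, so sets and lists do not survive a save and reload unchanged; this is why the reload cell further down parses the `embedding`, `tables`, and `columns` columns with `ast.literal_eval`. The sketch below shows the effect with the standard-library `csv` module as a stand-in for pandas. The row contents are made up for illustration.

```python
import ast
import csv
import io

# A made-up row with a set and a list, mirroring the 'tables' and 'embedding' columns.
row = {
    "view_name": "sat_scores",
    "tables": {"satscores"},
    "embedding": [0.25, -0.5, 0.125],
}

# Write the row to an in-memory CSV; non-string values are stored as their str() form.
buffer = io.StringIO()
writer = csv.DictWriter(buffer, fieldnames=list(row))
writer.writeheader()
writer.writerow(row)

# Read it back: every field is now a plain string.
buffer.seek(0)
loaded = next(csv.DictReader(buffer))

# ast.literal_eval turns the string representations back into Python objects.
restored_tables = ast.literal_eval(loaded["tables"])
restored_embedding = ast.literal_eval(loaded["embedding"])
```

`ast.literal_eval` only evaluates Python literals, so it is a safer choice than `eval` for this kind of round trip.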

In [9]:
df['embedding'] = df.view_description.apply(lambda x: text_embedding(x, model='text-embedding-3-large'))
df.to_csv(os.path.join(workspace, f'refine_{db.database_name}_views.csv'), index=False)

In [10]:
df = pd.read_csv(os.path.join(workspace, f'refine_{db.database_name}_views.csv'))
df['embedding'] = df['embedding'].apply(ast.literal_eval)
df['tables'] = df['tables'].apply(ast.literal_eval)
df['columns'] = df['columns'].apply(ast.literal_eval)
display(df.head())
print('Total number of views:', len(df))

Unnamed: 0,view_name,view_description,tables,columns,sql,embedding
0,charter_school_status,Name: charter school status. Sources: column c...,{__schools},"{__schools.cdscode__, __schools.district__, __...",CREATE VIEW charter_school_status AS SELECT s....,"[-0.005505330394953489, 0.008263121359050274, ..."
1,sat_scores,Name: sat scores. Sources: column avgscrread ...,{__satscores},"{__satscores.cds__, __satscores.avgscrmath__, ...",CREATE VIEW sat_scores AS SELECT ss.cds AS CDS...,"[-0.013216014951467514, 0.0050041223876178265,..."
2,charter_school_performance,Name: charter school performance. Sources: col...,"{__sat_scores, __charter_school_status, __frpm...","{__sat_scores.reading_score__, __frpm_eligibil...",CREATE VIEW charter_school_performance AS SELE...,"[-0.010539255104959011, -0.004365139175206423,..."
3,school_funding_info,Name: school funding info. Sources: column cou...,{__schools},"{__schools.county__, __schools.charter__, __sc...",CREATE VIEW school_funding_info AS SELECT ...,"[-0.024454733356833458, 0.019905706867575645, ..."
4,funding_sat_performance,Name: funding sat performance. Sources: column...,"{__school_sat_performance, __school_funding_info}","{__school_sat_performance.avgscrmath__, __scho...",CREATE VIEW funding_sat_performance AS SELECT ...,"[-0.01732783205807209, 0.019792819395661354, -..."


Total number of views: 13


## NL2SQL Evaluation

In [11]:
queries = db.get_nl_queries()

In [12]:
test_query_id = 10
print(queries[test_query_id])

Query    : For the school with the highest average score in Reading in the SAT test, what is its FRPM count for students aged 5-17?
SQL      : SELECT T2.`FRPM Count (Ages 5-17)` FROM satscores AS T1 INNER JOIN frpm AS T2 ON T1.cds = T2.CDSCode ORDER BY T1.AvgScrRead DESC LIMIT 1 
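
One way to check a generated query against the reference SQL is to execute both and compare the returned rows. The sketch below is an assumption about how such a check could look, not the benchmark's official evaluation script; `execution_match` is a hypothetical helper, and the in-memory table is toy data rather than the BIRD database.

```python
import sqlite3
from collections import Counter


def execution_match(conn, predicted_sql, gold_sql):
    """Return True if both queries yield the same multiset of rows."""
    try:
        predicted_rows = conn.execute(predicted_sql).fetchall()
    except sqlite3.Error:
        return False  # a predicted query that fails to run counts as a miss
    gold_rows = conn.execute(gold_sql).fetchall()
    # Compare as multisets so that row order does not matter.
    return Counter(predicted_rows) == Counter(gold_rows)


# Toy in-memory table, not the BIRD data.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE satscores (cds TEXT, AvgScrRead INTEGER)")
conn.executemany(
    "INSERT INTO satscores VALUES (?, ?)",
    [("a", 500), ("b", 620), ("c", 580)],
)

gold = "SELECT cds FROM satscores ORDER BY AvgScrRead DESC LIMIT 1"
same = execution_match(
    conn,
    "SELECT cds FROM satscores WHERE AvgScrRead = (SELECT MAX(AvgScrRead) FROM satscores)",
    gold,
)
different = execution_match(conn, "SELECT cds FROM satscores ORDER BY AvgScrRead ASC LIMIT 1", gold)
broken = execution_match(conn, "SELECT missing FROM satscores", gold)
```

Comparing results rather than SQL text means that two differently written but equivalent queries, such as the `ORDER BY ... LIMIT 1` and `MAX` variants above, are treated as a match.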


### Retrieve Relevant Views 

In [16]:
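def retrieve_top_k_views(view_df, query: str, k: int = 5):
    """
    Retrieve the top-k most relevant views for a given query.

    Views are ranked by the dot product between the query embedding and each view embedding.
    """
    query_embedding = text_embedding(query, model='text-embedding-3-large')
    similarity = view_df.embedding.apply(lambda x: np.dot(query_embedding, x))
    # assign() returns a new dataframe, so the caller's dataframe is left unchanged
    top_k_views = view_df.assign(similarity=similarity).sort_values('similarity', ascending=False).head(k)
    return top_k_views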

top_k_relevant_views = retrieve_top_k_views(df, queries[test_query_id].get_nl_query(), k=5)
top_k_relevant_views_wording = '\n\n'.join(top_k_relevant_views['sql'].values)
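
The retrieval function ranks views by dot product, which matches cosine similarity only when all embeddings have unit length (we have not verified this for the model used here). The sketch below, with made-up two-dimensional vectors and hypothetical helper names, shows a ranking that normalizes explicitly and therefore does not depend on vector length.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two vectors; 0.0 if either vector is all zeros."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return dot / (norm_a * norm_b)


def top_k(query_vec, named_vecs, k=2):
    """Return the k (name, score) pairs with the highest cosine similarity."""
    scored = [(name, cosine_similarity(query_vec, vec)) for name, vec in named_vecs.items()]
    scored.sort(key=lambda item: item[1], reverse=True)
    return scored[:k]


# Made-up, unnormalized vectors standing in for view embeddings.
toy_views = {
    "sat_scores": [1.0, 0.0],
    "school_funding_info": [0.0, 3.0],
    "charter_school_performance": [2.0, 2.0],
}
ranking = top_k([1.0, 0.1], toy_views, k=2)
ranked_names = [name for name, _ in ranking]
```

With these toy vectors, a plain dot product would rank `charter_school_performance` first because of its larger norm, whereas cosine similarity ranks `sat_scores` first.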

### SQL generation using the schema only

In [17]:
user_message = f"""
I have the following database schema:

BEGIN SCHEMA

{db.schema_wording()}

END SCHEMA

Please, give me a SQL query that answers the following question:
{queries[test_query_id].get_nl_query()}
Please, respond with just the SQL query. Do not include any additional comments or explanations.
"""

response = prompt_llm(user_message, "You are a database expert.", model='gpt-4o')

print("Prompt :", user_message)
print("Response :", response)

Prompt : 
I have the following database schema:

BEGIN SCHEMA

CREATE TABLE frpm (
  CDSCode TEXT NOT NULL PRIMARY KEY,
  Academic Year TEXT,
  County Code TEXT,
  District Code INTEGER,
  School Code TEXT,
  County Name TEXT,
  District Name TEXT,
  School Name TEXT,
  District Type TEXT,
  School Type TEXT,
  Educational Option Type TEXT,
  NSLP Provision Status TEXT,
  Charter School (Y/N) INTEGER,
  Charter School Number TEXT,
  Charter Funding Type TEXT,
  IRC INTEGER,
  Low Grade TEXT,
  High Grade TEXT,
  Enrollment (K-12) REAL,
  Free Meal Count (K-12) REAL,
  Percent (%) Eligible Free (K-12) REAL,
  FRPM Count (K-12) REAL,
  Percent (%) Eligible FRPM (K-12) REAL,
  Enrollment (Ages 5-17) REAL,
  Free Meal Count (Ages 5-17) REAL,
  Percent (%) Eligible Free (Ages 5-17) REAL,
  FRPM Count (Ages 5-17) REAL,
  Percent (%) Eligible FRPM (Ages 5-17) REAL,
  2013-14 CALPADS Fall 1 Certification Status INTEGER
  FOREIGN KEY (CDSCode) REFERENCES schools(CDSCode)
);

-- Sample Data:
('0

### SQL generation using the schema and relevant views (RAG)

In [18]:
user_message = f"""
I have the following database schema:

BEGIN SCHEMA

{db.schema_wording()}

END SCHEMA

I have the following views that can be used to answer the question:

BEGIN VIEWS

{top_k_relevant_views_wording}

END VIEWS

Please, give me a SQL query that answers the following question:
{queries[test_query_id].get_nl_query()}
Please, respond with just the SQL query. Do not include any additional comments or explanations.
"""

response = prompt_llm(user_message, "You are a database expert.", model='gpt-4o')

print("Prompt :", user_message)
print("Response :", response)

Prompt : 
I have the following database schema:

BEGIN SCHEMA

CREATE TABLE frpm (
  CDSCode TEXT NOT NULL PRIMARY KEY,
  Academic Year TEXT,
  County Code TEXT,
  District Code INTEGER,
  School Code TEXT,
  County Name TEXT,
  District Name TEXT,
  School Name TEXT,
  District Type TEXT,
  School Type TEXT,
  Educational Option Type TEXT,
  NSLP Provision Status TEXT,
  Charter School (Y/N) INTEGER,
  Charter School Number TEXT,
  Charter Funding Type TEXT,
  IRC INTEGER,
  Low Grade TEXT,
  High Grade TEXT,
  Enrollment (K-12) REAL,
  Free Meal Count (K-12) REAL,
  Percent (%) Eligible Free (K-12) REAL,
  FRPM Count (K-12) REAL,
  Percent (%) Eligible FRPM (K-12) REAL,
  Enrollment (Ages 5-17) REAL,
  Free Meal Count (Ages 5-17) REAL,
  Percent (%) Eligible Free (Ages 5-17) REAL,
  FRPM Count (Ages 5-17) REAL,
  Percent (%) Eligible FRPM (Ages 5-17) REAL,
  2013-14 CALPADS Fall 1 Certification Status INTEGER
  FOREIGN KEY (CDSCode) REFERENCES schools(CDSCode)
);

-- Sample Data:
('0