In [1]:
import json
import torch
from torch.utils.data import Dataset

In [2]:
# Read the Spider training split, the dev split and the database schema list.
with open("spider_data/train_spider.json", 'r') as train_file:
    train_data = json.load(train_file)

with open("spider_data/dev.json", 'r') as dev_file:
    dev_data = json.load(dev_file)

with open("spider_data/tables.json", 'r') as tables_file:
    tables = json.load(tables_file)

In [3]:
# Sizes of the three loaded JSON lists: training examples, dev examples, schemas.
len(train_data), len(dev_data), len(tables)

(7000, 1034, 166)

In [4]:
# Keys available on a single training example.
train_data[0].keys()

dict_keys(['db_id', 'query', 'query_toks', 'query_toks_no_value', 'question', 'question_toks', 'sql'])

In [5]:
# Print every field of the first training example, one value per line.
first_example = train_data[0]
for field in ('db_id', 'query', 'query_toks', 'query_toks_no_value',
              'question', 'question_toks', 'sql'):
    print(first_example[field])

department_management
SELECT count(*) FROM head WHERE age  >  56
['SELECT', 'count', '(', '*', ')', 'FROM', 'head', 'WHERE', 'age', '>', '56']
['select', 'count', '(', '*', ')', 'from', 'head', 'where', 'age', '>', 'value']
How many heads of the departments are older than 56 ?
['How', 'many', 'heads', 'of', 'the', 'departments', 'are', 'older', 'than', '56', '?']
{'from': {'table_units': [['table_unit', 1]], 'conds': []}, 'select': [False, [[3, [0, [0, 0, False], None]]]], 'where': [[False, 3, [0, [0, 10, False], None], 56.0, None]], 'groupBy': [], 'having': [], 'orderBy': [], 'limit': None, 'intersect': None, 'union': None, 'except': None}


dict_keys(['column_names', 'column_names_original', 'column_types', 'db_id', 'foreign_keys', 'primary_keys', 'table_names', 'table_names_original'])

In [23]:
# Show the original table names and column names of the schema at index 4.
# NOTE(review): the comma turns the two print() calls into a tuple expression, so the
# cell also echoes (None, None); two separate print statements would avoid that.
print(tables[4]['table_names_original']), print(tables[4]['column_names_original'])

['body_builder', 'people']
[[-1, '*'], [0, 'Body_Builder_ID'], [0, 'People_ID'], [0, 'Snatch'], [0, 'Clean_Jerk'], [0, 'Total'], [1, 'People_ID'], [1, 'Name'], [1, 'Height'], [1, 'Weight'], [1, 'Birth_Date'], [1, 'Birth_Place']]


(None, None)

In [25]:
# Print the original table/column names of the database the first training example uses.
db_id = train_data[0]['db_id']
for db_schema in (entry for entry in tables if entry['db_id'] == db_id):
    print(db_schema['table_names_original'])
    print(db_schema['column_names_original'])

    

['department', 'head', 'management']
[[-1, '*'], [0, 'Department_ID'], [0, 'Name'], [0, 'Creation'], [0, 'Ranking'], [0, 'Budget_in_Billions'], [0, 'Num_Employees'], [1, 'head_ID'], [1, 'name'], [1, 'born_state'], [1, 'age'], [2, 'department_ID'], [2, 'head_ID'], [2, 'temporary_acting']]


In [18]:
# For the first example's database, print its id, the schema keys, and both the
# `column_names` and `column_names_original` entries for side-by-side comparison.
data_dbid = train_data[0]['db_id']
for schema_entry in tables:
    if schema_entry['db_id'] != data_dbid:
        continue
    table_dbid = schema_entry['db_id']
    print(data_dbid)
    print(table_dbid)
    print(schema_entry.keys())
    print(schema_entry['column_names'])
    print(schema_entry['column_names_original'])

department_management
department_management
dict_keys(['column_names', 'column_names_original', 'column_types', 'db_id', 'foreign_keys', 'primary_keys', 'table_names', 'table_names_original'])
[[-1, '*'], [0, 'department id'], [0, 'name'], [0, 'creation'], [0, 'ranking'], [0, 'budget in billions'], [0, 'num employees'], [1, 'head id'], [1, 'name'], [1, 'born state'], [1, 'age'], [2, 'department id'], [2, 'head id'], [2, 'temporary acting']]
[[-1, '*'], [0, 'Department_ID'], [0, 'Name'], [0, 'Creation'], [0, 'Ranking'], [0, 'Budget_in_Billions'], [0, 'Num_Employees'], [1, 'head_ID'], [1, 'name'], [1, 'born_state'], [1, 'age'], [2, 'department_ID'], [2, 'head_ID'], [2, 'temporary_acting']]


In [6]:
class SpiderDataset(Dataset):
    """
    A PyTorch Dataset for the Spider Text-to-SQL task, combining the query and the
    database schema into a structured input string.

    Each item is a ``(question, formatted_schema, gold_sql)`` tuple.
    """

    def __init__(self, train_data_path=None, table_data_path=None):
        """
        Initializes the dataset by loading and processing the train and schema files.

        Args:
            train_data_path: Path to a JSON file containing a list of examples; each
                example is read through its 'question', 'query' and 'db_id' keys.
                When falsy, the module-level ``MOCK_TRAIN_JSON_CONTENT`` JSON string
                is parsed instead if it exists; otherwise no examples are loaded.
            table_data_path: Path to a JSON file containing a list of schema dicts
                (read through 'db_id', 'table_names' and 'column_names'). Falls back
                to ``MOCK_TABLES_JSON_CONTENT`` in the same way.

        A file or JSON error is reported on stdout and leaves the dataset and the
        schema map empty instead of raising.
        """
        print("Initializing SpiderDataset...")

        # --- 1. Load Data ---
        try:
            examples = self._load_json(train_data_path, "MOCK_TRAIN_JSON_CONTENT")
            table_list = self._load_json(table_data_path, "MOCK_TABLES_JSON_CONTENT")
        except (OSError, ValueError, TypeError) as e:
            # OSError: unreadable path; ValueError: malformed JSON / bad encoding;
            # TypeError: fallback content that is not a string.
            print(f"Error loading JSON files. Check file paths/content: {e}")
            examples = []
            table_list = []
        self.data = examples

        # --- 2. Process Schemas for Quick Lookup ---
        self.schema_map = {}
        for db_schema in table_list:
            self.schema_map[db_schema['db_id']] = self._format_schema(db_schema)

        print(f"Loaded {len(self.data)} training examples.")

    @staticmethod
    def _load_json(path, mock_name):
        """
        Returns the parsed JSON at `path`, or the parsed fallback string when no path is given.

        The fallback is looked up by name in the module globals so that a missing
        constant yields an empty list (with a notice) rather than a NameError that
        would be mistaken for a file problem.
        """
        if path:
            with open(path, 'r', encoding='utf-8') as f:
                return json.load(f)

        mock_content = globals().get(mock_name)
        if mock_content is None:
            print(f"No path given and {mock_name} is not defined; using an empty list.")
            return []
        return json.loads(mock_content)

    def _format_schema(self, db_schema):
        """
        Converts the database schema dictionary into the required string format:
        "| table_1 ; *, col_1, col_2, ... | table_2 ; *, col_1, col_2, ... |"

        Columns are grouped by table *index*, so two tables sharing a name keep
        their own column lists.
        """
        table_names = db_schema['table_names']
        columns_per_table = [[] for _ in table_names]

        for col_info in db_schema['column_names']:
            table_idx, col_name = col_info[0], col_info[1]
            # Skip the global '*' column (table_idx = -1)
            if table_idx >= 0:
                columns_per_table[table_idx].append(col_name)

        schema_parts = []
        for table_name, columns in zip(table_names, columns_per_table):
            # Prepend the '*' column for universal selection
            columns_str = ", ".join(['*'] + columns)
            schema_parts.append(f"| {table_name} ; {columns_str} ")

        return "".join(schema_parts) + "|"

    def __len__(self):
        """Returns the total number of samples in the training set."""
        return len(self.data)

    def __getitem__(self, idx):
        """
        Retrieves one sample from the dataset.

        Returns:
            tuple: (query, formatted_schema, gold_sql). The schema slot holds the
            string "Error: Schema not found" when the sample's db_id has no schema.
        """
        sample = self.data[idx]

        query = sample['question']
        gold_sql = sample['query']

        db_id = sample['db_id']
        formatted_schema = self.schema_map.get(db_id, "Error: Schema not found")

        return query, formatted_schema, gold_sql

spider_data = SpiderDataset(
    train_data_path="spider_data/train_spider.json",
    table_data_path="spider_data/tables.json",
)

print("\n--- Dataset Sample Examples ---")

# Show only the first sample (if any) to confirm the (question, schema, SQL) format.
for i in range(min(1, len(spider_data))):
    query, schema, sql = spider_data[i]
    print(f"Sample {i+1}:")
    print(f"  Query: {query}")
    print(f"  Schema: {schema}")
    print(f"  SQL Query: {sql}")
    print("-" * 40)

Initializing SpiderDataset...
Loaded 7000 training examples.

--- Dataset Sample Examples ---
Sample 1:
  Query: How many heads of the departments are older than 56 ?
  Schema: | department ; *, department id, name, creation, ranking, budget in billions, num employees | head ; *, head id, name, born state, age | management ; *, department id, head id, temporary acting |
  SQL Query: SELECT count(*) FROM head WHERE age  >  56
----------------------------------------


In [7]:
# Unpack the three-element item at index 10 (question, formatted schema, gold SQL).
q, sch, ans = spider_data[10]

In [5]:
# Display the current value of `sch` (a formatted schema string).
sch

'| department ; *, department id, name, creation, ranking, budget in billions, num employees | head ; *, head id, name, born state, age | management ; *, department id, head id, temporary acting |'

In [6]:
# Display the current value of `q` (a natural-language question).
q

'How many acting statuses are there?'

In [None]:
# Filesystem path for the Qwen2.5-1.5B-Instruct model (presumably a local model directory — confirm).
# NOTE(review): absolute, machine-specific path; consider making it configurable.
model_path = "/projects/p32722/Models/Qwen2.5-1.5B-Instruct"

In [27]:
class SpiderDataset:
    """
    Dataset class for inference on Spider dataset.
    Returns samples in a format suitable for model prediction.
    """
    def __init__(self, data_path, table_data_path):
        """Load the evaluation dataset and schema information."""
        print(f"Loading dataset from {data_path}...")
        
        with open(data_path, 'r') as f:
            self.data = json.load(f)
        
        with open(table_data_path, 'r') as f:
            table_list = json.load(f)
        
        # Build schema map
        self.schema_map = {}
        for db_schema in table_list:
            db_id = db_schema['db_id']
            formatted_schema = self._format_schema(db_schema)
            self.schema_map[db_id] = formatted_schema
        
        print(f"Loaded {len(self.data)} examples for inference.")
    
    def _format_schema(self, db_schema):
        """Format database schema into string representation."""
        table_names = db_schema['table_names']
        column_names_info = db_schema['column_names']
        
        table_to_columns = {name: [] for name in table_names}
        for col_info in column_names_info:
            table_idx = col_info[0]
            col_name = col_info[1]
            
            if table_idx >= 0:
                table_name = table_names[table_idx]
                table_to_columns[table_name].append(col_name)
        
        schema_parts = []
        for table_name in table_names:
            columns = table_to_columns[table_name]
            full_columns = ['*'] + columns
            columns_str = ", ".join(full_columns)
            schema_parts.append(f"| {table_name} ; {columns_str} ")
        
        return "".join(schema_parts) + "|"
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        """Return query, schema, db_id for inference."""
        sample = self.data[idx]
        query = sample['question']
        db_id = sample['db_id']
        formatted_schema = self.schema_map.get(db_id, "")

        reference, template = get_sql_template_and_reference(formatted_schema, gold_sql)
        
        return {
            'query': query,
            'schema': formatted_schema,
            'db_id': db_id,
            'gold_sql': sample.get('query', ''),  # For gold file generation
            "reference": reference,
            "template": template
        }

In [28]:
# Build the inference dataset from the dev split and the schema file.
spider = SpiderDataset(data_path="spider_data/dev.json", table_data_path="spider_data/tables.json")


Loading dataset from spider_data/dev.json...
Loaded 1034 examples for inference.


In [29]:
# Fetch the item at index 2, keep its question, schema string and gold SQL in
# separate variables, then display the whole item.
data = spider[2]
q, sch, gold_sql = data['query'], data['schema'], data['gold_sql']
data

{'query': 'Show name, country, age for all singers ordered by age from the oldest to the youngest.',
 'schema': '| stadium ; *, stadium id, location, name, capacity, highest, lowest, average | singer ; *, singer id, name, country, song name, song release year, age, is male | concert ; *, concert id, concert name, theme, stadium id, year | singer in concert ; *, concert id, singer id |',
 'db_id': 'concert_singer',
 'gold_sql': 'SELECT name ,  country ,  age FROM singer ORDER BY age DESC',
 'reference': '| singer; name, country, age |',
 'template': 'SELECT _ ,  _ ,  _ FROM _ ORDER BY _ DESC'}

In [None]:
import re
from typing import Tuple, Dict, List


def parse_schema(schema_str: str) -> Dict[str, List[str]]:
    """
    Turn a "| table ; col, col | table ; col |" string into {table: [columns]}.

    Segments are separated by '|'; empty segments and segments without a ';' are
    ignored. Table names are lower-cased and stripped, column names are stripped.
    Only the first ';' of a segment separates the table name from its columns.
    """
    table_columns: Dict[str, List[str]] = {}
    for raw_segment in schema_str.split('|'):
        segment = raw_segment.strip()
        if not segment or ';' not in segment:
            continue
        table_part, _, column_part = segment.partition(';')
        table_columns[table_part.strip().lower()] = [
            piece.strip() for piece in column_part.strip().split(',')
        ]
    return table_columns


def normalize(name: str) -> str:
    """Return `name` stripped and lower-cased, with every whitespace and underscore run removed."""
    lowered = name.strip().lower()
    return re.sub(r'[\s_]+', '', lowered)


# Words that can directly follow a table name in a FROM/JOIN clause and therefore
# must never be read as that table's alias.
_NON_ALIAS_KEYWORDS = (
    'where', 'join', 'inner', 'left', 'right', 'full', 'outer', 'cross', 'natural',
    'on', 'using', 'group', 'order', 'having', 'limit', 'offset',
    'union', 'intersect', 'except', 'as', 'and', 'or', 'not', 'select', 'from',
)

# "FROM|JOIN <table> [AS] <alias>", where <alias> is not one of the keywords above.
# The negative lookahead makes the match fail *without consuming* the keyword, so a
# following "JOIN <table> AS <alias>" is still seen by the next search.
_ALIAS_PATTERN = re.compile(
    r'\b(?:FROM|JOIN)\s+([a-zA-Z_][\w]*)\s+(?:AS\s+)?'
    r'(?!(?:' + '|'.join(_NON_ALIAS_KEYWORDS) + r')\b)'
    r'([a-zA-Z_][\w]*)',
    re.IGNORECASE,
)


def extract_aliases(sql_query: str) -> Dict[str, str]:
    """
    Extract table aliases declared in FROM/JOIN clauses.

    Args:
        sql_query: SQL text to scan.

    Returns:
        Mapping of lower-cased alias -> lower-cased table name. Both
        "FROM t AS a" and "FROM t a" forms are recognised; SQL keywords that
        follow an un-aliased table (WHERE, JOIN, ORDER, ...) are not treated
        as aliases. If an alias is declared more than once, the last wins.
    """
    return {
        match.group(2).lower(): match.group(1).lower()
        for match in _ALIAS_PATTERN.finditer(sql_query)
    }


def get_sql_template_and_reference(schema_str: str, sql_query: str) -> Tuple[str, str]:
    """
    Generates the schema reference and the SQL template for a query.

    Args:
        schema_str: Schema in the "| table ; col, col | ..." format understood by
            `parse_schema`.
        sql_query: SQL text whose table and column names should be located.

    Returns:
        (reference, template), in that order:
        * reference - "| table; col, col | table; col |" listing every schema table
          mentioned in the SQL together with those of its columns whose names occur
          in the SQL. Tables appear in schema order, so the string is reproducible
          from run to run. "| |" when nothing matched.
        * template - the SQL with every matched table/column name replaced by '_'.
          Qualifiers keep their original spelling ("T1.name" -> "T1._").

    Matching is purely name-based (case-, space- and underscore-insensitive): a
    column name found anywhere in the SQL is attributed to every found table that
    has a column of that name.
    """
    schema = parse_schema(schema_str)
    aliases = extract_aliases(sql_query)

    # Normalised name -> schema spelling. Built in schema order so that, if two
    # names normalise identically, the winner does not depend on set ordering.
    all_tables = list(schema.keys())
    norm_table_map = {normalize(t): t for t in all_tables}
    norm_col_map = {normalize(c): c for cols in schema.values() for c in cols}

    # Candidate identifiers: bare words, double-quoted names and '*'.
    tokens = re.findall(r'[A-Za-z_][A-Za-z0-9_]*|"[^"]+"|\*', sql_query)

    found_tables = set()
    found_columns = set()
    for token in tokens:
        tok_norm = normalize(token.strip('"'))
        if tok_norm in norm_table_map:
            found_tables.add(norm_table_map[tok_norm])
        if tok_norm in norm_col_map:
            found_columns.add(norm_col_map[tok_norm])

    # Include aliases' base tables if they appear in the schema.
    for table in aliases.values():
        if table in schema:
            found_tables.add(table)

    # Iterate tables in schema order rather than set order (deterministic output).
    ordered_tables = [t for t in all_tables if t in found_tables]
    # Hoisted: the normalised column set is the same for every table.
    found_col_norms = {normalize(c) for c in found_columns}

    # --- Build Reference ---
    reference_parts = []
    for table in ordered_tables:
        cols_for_table = [c for c in schema[table] if normalize(c) in found_col_norms]
        reference_parts.append(
            f"{table}; {', '.join(cols_for_table)}" if cols_for_table else f"{table};"
        )
    reference = f"| {' | '.join(reference_parts)} |" if reference_parts else "| |"

    # --- Build Template ---
    def blank_qualified(qualifier: str, columns: List[str], text: str) -> str:
        """Rewrite `qualifier.column` as `qualifier._`, keeping the qualifier as written."""
        for col in columns:
            for variant in (col, col.replace(' ', '_')):
                pattern = rf'\b({re.escape(qualifier)})\.{re.escape(variant)}\b'
                text = re.sub(
                    pattern, lambda m: f'{m.group(1)}._', text, flags=re.IGNORECASE
                )
        return text

    template = sql_query

    # Replace alias.column -> alias._
    for alias, table in aliases.items():
        if table in schema:
            template = blank_qualified(alias, schema[table], template)

    # Replace table.column -> table._
    for table in ordered_tables:
        template = blank_qualified(table, schema[table], template)

    # Replace remaining names with '_'. Longest first so that a multi-word name is
    # blanked before any shorter name it contains; ties are broken alphabetically.
    all_names = sorted(found_tables | found_columns, key=lambda n: (-len(n), n))
    for name in all_names:
        for variant in (name, name.replace(' ', '_')):
            template = re.sub(
                rf'\b{re.escape(variant)}\b', '_', template, flags=re.IGNORECASE
            )

    return reference, template

# Print schema, gold SQL, question, reference and template for every index
# below len(spider) - 700.
for i in range(len(spider) - 700):
    sample = spider[i]
    schema, sql_query, query = sample['schema'], sample['gold_sql'], sample['query']
    reference, template = get_sql_template_and_reference(schema, sql_query)

    # Print the results
    report_lines = (
        f"\nSchema: \"{schema}\"",
        f"SQL Query: \"{sql_query}\"",
        f"NL Query: \"{query}\"",
        "## Output",
        f"**Reference:** `{reference}`",
        f"**Template:** `{template}`",
    )
    print(*report_lines, sep="\n")

In [43]:
import re

def clean_sql_query(sql_string: str) -> str:
    """
    Reduce raw text to a single-line SQL string.

    If the text contains a complete triple-backtick block (optionally tagged
    "sql", case-insensitive), only the content of the first such block is kept.
    Otherwise the whole text is kept and any leftover fence markers are removed.
    In both cases every run of whitespace, including line breaks, becomes one
    space and the result is stripped.

    Args:
        sql_string: Raw text that may contain fenced SQL, newlines and other text.

    Returns:
        The cleaned, single-line SQL string.
    """
    fenced = re.search(
        r'```\s*(?:sql)?\s*(.*?)```', sql_string, flags=re.IGNORECASE | re.DOTALL
    )

    if fenced is not None:
        # Keep just what sits between the opening and closing fence.
        body = fenced.group(1)
    else:
        # No closed fence: drop stray opening markers (with or without the tag).
        body = re.sub(r'```\s*sql\s*', '', sql_string, flags=re.IGNORECASE)
        body = re.sub(r'```', '', body)

    # Newlines and carriage returns are whitespace, so one pass collapses them too.
    return re.sub(r'\s+', ' ', body).strip()


# Example Usage:
# This input contains no ``` fence, so clean_sql_query only collapses whitespace and
# strips the trailing newline; the surrounding format(...) wrapper is left in place.
sql_input = """format(''' SELECT COUNT(*) FROM ( SELECT DISTINCT singer_id FROM singer_in_concert ) ''', {'table': 'singer'})
"""

clean_sql_query(sql_input)

"format(''' SELECT COUNT(*) FROM ( SELECT DISTINCT singer_id FROM singer_in_concert ) ''', {'table': 'singer'})"

In [44]:
# Sample text containing an instruction, a schema, a question and an "SQL Query:" section
# (presumably a prompt followed by a model completion — confirm).
a = """Given the database schema and question, generate the SQL query. Enclose the SQL query with in ```sql
(generated query)```. Write it in a single line

    Schema: | stadium ; *, stadium id, location, name, capacity, highest, lowest, average | singer ; *, singer id, name, country, song name, song release year, age, is male | concert ; *, concert id, concert name, theme, stadium id, year | singer in concert ; *, concert id, singer id |

    Question: How many singers do we have?

    SQL Query: format('''
    SELECT COUNT(*) FROM {table}
    ''', {'table': 'singer'}) in ```sql
    (generated query)
    ```
    """

# Keep only the text that follows the last 'SQL Query:' marker.
a.split('SQL Query:')[-1]

" format('''\n    SELECT COUNT(*) FROM {table}\n    ''', {'table': 'singer'}) in ```sql\n    (generated query)\n    ```\n    "