Load Spider Dataset

In [1]:
import os
import sqlite3
import json

In [2]:
database_dir_path = "/Users/virounikamina/Desktop/spider_data"
SCHEMA_FILE = "/Users/virounikamina/Desktop/spider_data/tables.json"

Spider Dataset Analysis

In [3]:
def list_databases(dir_path):
    """
    List all SQLite database files in the Spider dataset.
    """
    db_files = []
    db_paths = []
    for root, _, files in os.walk(dir_path):
        for file in files:
            if file.endswith(".sqlite"):
                db_files.append(file)
                db_paths.append(os.path.join(root, file))
    return db_files, db_paths

def connect_to_db(db_path):
    """
    Connect to a specific SQLite database.
    """
    if not os.path.exists(db_path):
        raise FileNotFoundError(f"Database file not found: {db_path}")
    
    # Connect to the database
    conn = sqlite3.connect(db_path)
    return conn

def list_tables(conn):
    """
    List all tables in the connected SQLite database.
    """
    query = "SELECT name FROM sqlite_master WHERE type='table';"
    cursor = conn.cursor()
    cursor.execute(query)
    return [row[0] for row in cursor.fetchall()]

def preview_table(conn, table_name, limit=5):
    """
    Preview data from a specific table.
    """
    query = f"SELECT * FROM {table_name} LIMIT {limit};"
    cursor = conn.cursor()
    cursor.execute(query)
    columns = [desc[0] for desc in cursor.description]  # Column names
    rows = cursor.fetchall()  # Data rows
    return columns, rows

In [4]:
db_names, db_paths = list_databases(database_dir_path)
# print("Available Databases:")
# for idx, db_name in enumerate(db_names):
#     print(f"{idx + 1}: {db_name}")

# Select a database to open
db_index = int(input("\nEnter the number of the database to open: ")) - 1
db_path = db_paths[db_index]

# Connect to the database
conn = connect_to_db(db_path)
print(f"\nConnected to: {db_path}")

# List tables in the database
tables = list_tables(conn)
print("# Tables: ", len(tables))
print("\nTables in the database:")
for idx, table in enumerate(tables):
    print(f"{idx + 1}: {table}")

# Select a table to preview
table_index = int(input("\nEnter the number of the table to preview: ")) - 1
table_name = tables[table_index]

# Preview the selected table
print(f"\nPreviewing table: {table_name}")
columns, rows = preview_table(conn, table_name)
print("\nColumns:", columns)
print("\nRows:")
for row in rows:
    print(row)

# Close the connection
conn.close()
print("\nConnection closed.")


Connected to: /Users/virounikamina/Desktop/spider_data/database/browser_web/browser_web.sqlite
# Tables:  3

Tables in the database:
1: Web_client_accelerator
2: browser
3: accelerator_compatible_browser

Previewing table: Web_client_accelerator

Columns: ['id', 'name', 'Operating_system', 'Client', 'Connection']

Rows:
(1, 'CACHEbox', 'Appliance (Linux)', 'End user, ISP', 'Broadband, Satellite, Wireless, Fiber, DSL')
(2, 'CProxy', 'Windows', 'user', 'up to 756kbit/s')
(3, 'Fasterfox', 'Windows, Mac, Linux and Mobile devices', 'user', 'Dialup, Wireless, Broadband, DSL')
(4, 'fasTun', 'Any', 'All', 'Any')
(5, 'Freewire', 'Windows, except NT and 95', 'ISP', 'Dial-up')

Connection closed.


Schema Formation

In [5]:
import os
import json
import numpy as np
import nltk
import re
from typing import List, Dict, Tuple
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
from sklearn.feature_extraction.text import TfidfVectorizer
from fuzzywuzzy import fuzz
from openai import OpenAI
from dotenv import load_dotenv

load_dotenv()
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))

class PSLsh:
    """Locality-Sensitive Hashing implementation for fast approximate nearest neighbor search."""
    def __init__(self, vectors, n_planes=10, n_tables=5, seed: int = 42):
        self.n_planes = n_planes
        self.n_tables = n_tables
        self.hash_tables = [{} for _ in range(n_tables)]
        self.random_planes = []
        
        np.random.seed(seed)
        
        # Generate random planes for each hash table
        for _ in range(n_tables):
            planes = np.random.randn(vectors.shape[1], n_planes)
            self.random_planes.append(planes)
            
        self.num_vectors = vectors.shape[0]
        self.vectors = vectors
        self.build_hash_tables()

    def build_hash_tables(self):
        """Build hash tables from input vectors."""
        for idx in range(self.num_vectors):
            vector = self.vectors[idx].toarray()[0]
            hashes = self.hash_vector(vector)
            for i, h in enumerate(hashes):
                if h not in self.hash_tables[i]:
                    self.hash_tables[i][h] = []
                self.hash_tables[i][h].append(idx)

    def hash_vector(self, vector):
        """Generate hash codes for a vector."""
        hashes = []
        for planes in self.random_planes:
            projections = np.dot(vector, planes)
            hash_code = ''.join(['1' if x > 0 else '0' for x in projections])
            hashes.append(hash_code)
        return hashes

    def query(self, vector):
        """Find candidate nearest neighbors for a query vector."""
        vector = vector.toarray()[0]  # Convert sparse matrix to 1D array
        hashes = self.hash_vector(vector)
        candidates = set()
        for i, h in enumerate(hashes):
            candidates.update(self.hash_tables[i].get(h, []))
        return candidates


class SpiderValueRetrieval:
    def __init__(self, spider_tables_path: str = 'spider/tables.json', lsh_seed: int = 42):
        load_dotenv()
        self.client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))

        # Load all Spider schemas
        print("DEBUG: Loading Spider schemas from:", spider_tables_path)
        with open(spider_tables_path, 'r') as f:
            self.schemas = json.load(f)
        
        # Create a mapping of db_id to schema
        self.db_schemas = {schema['db_id']: schema for schema in self.schemas}
        
        self.lemmatizer = WordNetLemmatizer()
        self.stop_words = set(stopwords.words('english'))
        
        # Initialize database-specific components
        self.column_indices = {}
        self.vectorizers = {}
        self.lsh_indices = {}
        
        # Build indices for each database
        self._build_indices()

    def process_schema(self, question: str, db_id: str) -> str:
        """Process schema with database context."""
        if db_id not in self.db_schemas:
            raise ValueError(f"Unknown database ID: {db_id}")
            
        # Get schema for this database
        schema = self.db_schemas[db_id]
        
        # Get schema relationships
        primary_keys, foreign_keys = self._get_schema_relationships(schema)
        
        # Process question with database context
        results = self.process_question(question, db_id)

    def _get_schema_relationships(self, schema: Dict) -> Tuple[List[str], List[str]]:
        """Extract primary and foreign keys for a specific database."""
        table_names = schema.get('table_names_original', [])
        column_names = schema.get('column_names_original', [])
        primary_keys = schema.get('primary_keys', [])
        foreign_keys = schema.get('foreign_keys', [])
        
        # Format primary keys
        formatted_pks = []
        for pk in primary_keys:
            table_idx, col_name = column_names[pk]
            if table_idx != -1:
                table_name = table_names[table_idx]
                formatted_pks.append(f"{table_name.lower()}.{col_name.lower()}")
        
        # Format foreign keys
        formatted_fks = []
        for fk in foreign_keys:
            fk_col = column_names[fk[0]]
            pk_col = column_names[fk[1]]
            fk_table = table_names[fk_col[0]]
            pk_table = table_names[pk_col[0]]
            formatted_fks.append(
                f"{fk_table.lower()}.{fk_col[1].lower()} = {pk_table.lower()}.{pk_col[1].lower()}"
            )
            
        return formatted_pks, formatted_fks

    def _parse_numeric_value(self, word: str) -> str:
        """Parse numeric values from words."""
        if 'billion' in word.lower():
            return '1000000000'
        elif 'million' in word.lower():
            return '1000000'
        return word
    
    def _build_indices(self):
        """Build all necessary indices for all databases."""
        for db_id, schema in self.db_schemas.items():
            # Build column index
            self.column_indices[db_id] = self._build_column_index(schema)
            
            # Build vectorizer and LSH
            terms = self._get_schema_terms(schema)
            vectorizer = TfidfVectorizer(analyzer='char', ngram_range=(1, 3), min_df=1, max_df=0.95)
            term_vectors = vectorizer.fit_transform(terms)
            
            self.vectorizers[db_id] = {
                'vectorizer': vectorizer,
                'terms': terms
            }
            
            self.lsh_indices[db_id] = {
                'lsh': PSLsh(term_vectors, n_planes=10, n_tables=5),
                'vectors': term_vectors
            }

    def _build_column_index(self, schema: Dict) -> Dict:
        """Build column index for a specific database schema."""
        column_index = {}
        table_names = schema.get('table_names_original', [])
        column_names = schema.get('column_names_original', [])
        column_types = schema.get('column_types', [])
        
        for (table_idx, col_name), col_type in zip(column_names[1:], column_types[1:]):  # Skip first row (*) 
            if table_idx != -1:  # Skip table_idx == -1 which represents '*'
                table_name = table_names[table_idx].lower()
                qualified_name = f"{table_name}.{col_name.lower()}"
                
                column_index[qualified_name] = {
                    'table': table_name,
                    'column': col_name.lower(),
                    'type': col_type,
                    'words': self._split_column_name(col_name),
                    'synonyms': self._get_column_synonyms(col_name)
                }
        
        return column_index

    def _split_column_name(self, column_name: str) -> List[str]:
        """Split column name into individual words."""
        words = re.sub('([A-Z][a-z]+)', r' \1', re.sub('([A-Z]+)', r' \1', column_name)).split()
        words.extend(column_name.split('_'))
        return [word.lower() for word in words if word]

    def _get_column_synonyms(self, column_name: str) -> List[str]:
        """Get synonyms for words in column name."""
        words = self._split_column_name(column_name)
        return list(set(words))  # For Spider, we'll just use the words themselves as synonyms

    def _get_schema_terms(self, schema: Dict) -> List[str]:
        """Get all terms from a specific database schema."""
        terms = []
        table_names = schema.get('table_names_original', [])
        column_names = schema.get('column_names_original', [])
        
        for idx, table in enumerate(table_names):
            table = table.lower()
            terms.append(table)
            
            # Add column terms
            table_columns = [(t_idx, col) for t_idx, col in column_names if t_idx == idx]
            for _, column in table_columns:
                terms.append(f"{table}.{column.lower()}")
                
        return terms

    def preprocess_text(self, text: str) -> List[str]:
        """Tokenize and lemmatize input text, removing stop words."""
        if not text:
            return []
            
        try:
            tokens = nltk.word_tokenize(str(text).lower())
            filtered_tokens = [word for word in tokens if word not in self.stop_words and word.isalnum()]
            lemmatized_tokens = [self.lemmatizer.lemmatize(token) for token in filtered_tokens]
            return lemmatized_tokens
        except Exception as e:
            print(f"Error in preprocessing text '{text}': {str(e)}")
            return []

    def _find_similar_words(self, word: str, db_id: str) -> List[Tuple[str, float]]:
        """Find similar words in the database-specific schema."""
        if not word:
            return []

        word = word.lower()
        matches = []
        
        # Direct matching with column names
        column_index = self.column_indices[db_id]
        for qualified_name, metadata in column_index.items():
            score = 0.0
            
            # Check exact matches in column words
            if word in metadata['words']:
                matches.append((qualified_name, 1.0))
                continue
            
            # Fuzzy match with column words
            for col_word in metadata['words']:
                ratio = fuzz.ratio(word, col_word) / 100.0
                if ratio > score:
                    score = ratio
            
            if score > 0.6:
                matches.append((qualified_name, score))

        # LSH-based matching as backup
        if len(matches) < 5:
            try:
                vectorizer = self.vectorizers[db_id]['vectorizer']
                terms = self.vectorizers[db_id]['terms']
                word_vector = vectorizer.transform([word])
                
                lsh = self.lsh_indices[db_id]['lsh']
                vectors = self.lsh_indices[db_id]['vectors']
                
                candidate_indices = lsh.query(word_vector)
                
                for idx in candidate_indices:
                    term = terms[idx]
                    if not any(term == match[0] for match in matches):
                        candidate_vector = vectors[idx].toarray()[0]
                        word_vector_array = word_vector.toarray()[0]
                        dist = np.linalg.norm(word_vector_array - candidate_vector)
                        sim = 1 / (1 + dist)
                        if sim > 0.5:
                            matches.append((term, sim * 0.8))
            except Exception as e:
                print(f"LSH matching failed: {e}")

        matches.sort(key=lambda x: x[1], reverse=True)
        return matches[:5]

    def _extract_keywords(self, question: str, db_id: str) -> Dict:
        """Extract keywords with database-specific context."""
        schema = self.db_schemas[db_id]
        primary_keys, foreign_keys = self._get_schema_relationships(schema)
        
        system_prompt = f"""Given this database schema:
        Tables: {schema['table_names_original']}
        Columns: {schema['column_names_original']}
        
        Primary Keys: {primary_keys}
        Foreign Keys: {foreign_keys}

        Extract relevant keywords, keyphrases, and numerical values from the question in JSON format."""

        response = self.client.chat.completions.create(
            model="gpt-4o",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": f"Question: {question}"}
            ],
            functions=[
                {
                    "name": "extract_components",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "keywords": {"type": "array", "items": {"type": "string"}},
                            "keyphrases": {"type": "array", "items": {"type": "string"}},
                            "numerical_values": {"type": "array", "items": {"type": "string"}}
                        },
                        "required": ["keywords"]
                    }
                }
            ],
            function_call={"name": "extract_components"}
        )

        # Access the function_call attribute directly
        function_call = response.choices[0].message.function_call
        arguments = function_call.arguments
        extracted_info = json.loads(arguments)

        # Debugging statement
        print("Extracted Info:", extracted_info)

        return extracted_info


    def process_question(self, question: str, db_id: str) -> Dict:
        """Process question with database context."""
        # Extract keywords using database-specific schema
        extracted_info = self._extract_keywords(question, db_id)
        
        # Process words
        words = []
        for key in ['keywords', 'keyphrases', 'numerical_values']:
            words.extend(extracted_info.get(key, []))
        
        # Debugging statement
        print("Words Extracted:", words)

        processed_words = []
        for word in words:
            processed_words.extend(self.preprocess_text(word))
        
        processed_words = list(set(processed_words))
        
        # Debugging statement
        print("Processed Words:", processed_words)

        # Find similar columns using database-specific indices
        similar_matches = {}
        for word in processed_words:
            similar_matches[word] = self._find_similar_words(word, db_id)
            # Debugging statement
            print(f"Similar matches for '{word}': {similar_matches[word]}")
        
        return {
            "question": question,
            "extracted_info": extracted_info,
            "processed_words": processed_words,
            "similar_matches": similar_matches,
            "schema_relationships": {
                "primary_keys": self._get_schema_relationships(self.db_schemas[db_id])[0],
                "foreign_keys": self._get_schema_relationships(self.db_schemas[db_id])[1]
            }
        }


    def process_schema(self, question: str, db_id: str) -> str:
        """Process schema with database context."""
        if db_id not in self.db_schemas:
            raise ValueError(f"Unknown database ID: {db_id}")
            
        # Get schema for this database
        schema = self.db_schemas[db_id]
        
        # Get schema relationships
        primary_keys, foreign_keys = self._get_schema_relationships(schema)
        
        # Process question with database context
        results = self.process_question(question, db_id)
        
        table_columns = []
        relevant_primary_keys = []
        relevant_foreign_keys = []
        
        # Use database-specific indices and relationships
        for word, matches in results['similar_matches'].items():
            if matches:
                top_match = matches[0]
                if top_match[1] > 0.7:
                    if word in results['extracted_info'].get('numerical_values', []):
                        # Handle numerical values
                        value = self._parse_numeric_value(word)
                        table_columns.append(f"{top_match[0]} > {value}")
                    else:
                        table_columns.append(top_match[0])
        
        # Debugging statement
        print("Table Columns:", table_columns)

        # Get relevant tables for this database
        tables_needed = set()
        for link in table_columns:
            if '.' in link:
                tables_needed.add(link.split('.')[0].lower())
        
        # Add relevant primary keys
        for pk in primary_keys:
            table = pk.split('.')[0].lower()
            if table in tables_needed:
                relevant_primary_keys.append(pk)
                
        # Add relevant foreign keys
        for fk in foreign_keys:
            tables_in_fk = set(part.split('.')[0].lower() for part in fk.split(' = '))
            if tables_in_fk.intersection(tables_needed):
                relevant_foreign_keys.append(fk)
        
        schema_dict = {
            "table_columns": table_columns,
            "primary_keys": relevant_primary_keys,
            "foreign_keys": relevant_foreign_keys,
            "schema_links": table_columns  # Added for DIN SQL compatibility
        }
        
        # Debugging statement
        print("Schema Dict:", schema_dict)
        
        return str(schema_dict)


In [6]:
import time

def classification_prompt_maker(question, schema_links, db_id, spider_schemas):
    """Create classification prompt with Spider database context."""
    # Get specific database schema
    schema = next((s for s in spider_schemas if s['db_id'] == db_id), None)
    if not schema:
        raise ValueError(f"Unknown database ID: {db_id}")
    
    # Format schema info for the prompt
    schema_info = f"Tables: {schema['table_names_original']}\nColumns: {schema['column_names_original']}"
    
    instruction = """Given the database schema:
{schema_info}

Primary Keys: {primary_keys}
Foreign Keys: {foreign_keys}

Classify the question as:
- EASY: no JOIN and no nested queries needed
- NON-NESTED: needs JOIN but no nested queries
- NESTED: needs nested queries

Question: "{question}"
Schema Links: {schema_links}

Let's think step by step:"""

    # Get formatted keys
    def format_keys(schema):
        pks = []
        fks = []
        for pk in schema.get('primary_keys', []):
            table_idx, col_name = schema['column_names_original'][pk]
            if table_idx != -1:
                table_name = schema['table_names_original'][table_idx]
                pks.append(f"{table_name.lower()}.{col_name.lower()}")
        
        for fk in schema.get('foreign_keys', []):
            fk_col = schema['column_names_original'][fk[0]]
            pk_col = schema['column_names_original'][fk[1]]
            fk_table = schema['table_names_original'][fk_col[0]]
            pk_table = schema['table_names_original'][pk_col[0]]
            fks.append(f"{fk_table.lower()}.{fk_col[1].lower()} = {pk_table.lower()}.{pk_col[1].lower()}")
        
        return pks, fks

    primary_keys, foreign_keys = format_keys(schema)
    
    return instruction.format(
        schema_info=schema_info,
        primary_keys=primary_keys,
        foreign_keys=foreign_keys,
        question=question,
        schema_links=schema_links
    )

def process_question_classification(question, schema_links, db_id, spider_schemas):
    """Process question classification with Spider database context."""
    def extract_classification(text):
        print(f"Trying to extract classification from: {text}")
        text = text.upper()
        
        for class_type in ["EASY", "NON-NESTED", "NESTED"]:
            if class_type in text:
                return class_type
                
        patterns = ["LABEL:", "CLASSIFICATION:", "CAN BE CLASSIFIED AS"]
        for pattern in patterns:
            if pattern in text:
                parts = text.split(pattern)
                if len(parts) > 1:
                    result = parts[1].strip().strip('"').strip("'")
                    classification = result.split()[0].strip()
                    if classification in ["EASY", "NON-NESTED", "NESTED"]:
                        return classification
        
        return "NESTED"  # Default fallback

    classification = None
    attempts = 0
    while classification is None and attempts < 3:
        try:
            client = OpenAI()
            response = client.chat.completions.create(
                model="gpt-4",
                messages=[{
                    "role": "user", 
                    "content": classification_prompt_maker(
                        question=question,
                        schema_links=schema_links,
                        db_id=db_id,
                        spider_schemas=spider_schemas
                    )
                }],
                temperature=0.0,
                max_tokens=300
            )
            raw_response = response.choices[0].message.content
            print("Raw response:", raw_response)
            classification = extract_classification(raw_response)
        except Exception as e:
            print(f"Error occurred: {str(e)}")
            time.sleep(3)
            attempts += 1
    
    final_class = classification if classification else "NESTED"
    return f'"{final_class}"'

def process_question_sql(question, predicted_class, schema_links, db_id, spider_schemas, max_retries=3):
    def extract_sql(text):
        if "SQL:" in text:
            return text.split("SQL:")[-1].strip()
        return text.strip()

    def make_spider_prompt(template_type):
        schema = next((s for s in spider_schemas if s['db_id'] == db_id), None)
        if not schema:
            raise ValueError(f"Unknown database ID: {db_id}")

        # Spider-specific examples
        examples = {
            "easy": '''Q: "How many clubs are there?"
Schema_links: [club.id]
SQL: SELECT COUNT(*) FROM club''',
            
            "medium": '''Q: "Show the names of all teams and their leagues."
Schema_links: [team.name, league.name]
A: Let's think step by step. We need to join teams with leagues.
SQL: SELECT team.name, league.name 
FROM team 
JOIN league ON team.league_id = league.id''',
            
            "hard": '''Q: "Find players who scored more goals than average."
Schema_links: [player.name, player.goals]
A: Let's think step by step:
1. Calculate average goals
2. Find players above average
SQL: SELECT name FROM player 
WHERE goals > (SELECT AVG(goals) FROM player)'''
        }

        template = examples[template_type]
        prompt = f"""Database Schema for {db_id}:
Tables: {schema['table_names_original']}
Columns: {schema['column_names_original']}

Example:
{template}

Generate SQL for:
Question: {question}
Schema_links: {schema_links}
"""
        return prompt

    for attempt in range(max_retries):
        try:
            if '"EASY"' in predicted_class:
                prompt = make_spider_prompt("easy")
            elif '"NON-NESTED"' in predicted_class:
                prompt = make_spider_prompt("medium")
            else:
                prompt = make_spider_prompt("hard")

            client = OpenAI()
            response = client.chat.completions.create(
                model="gpt-4",
                messages=[{"role": "user", "content": prompt}],
                temperature=0.0,
                max_tokens=500
            )
            sql = extract_sql(response.choices[0].message.content)
            break
        except Exception as e:
            print(f"Attempt {attempt + 1} failed: {str(e)}")
            if attempt < max_retries - 1:
                time.sleep(3)
            else:
                sql = "SELECT"
    
    return sql

In [7]:
from pydantic import BaseModel, Field
from typing import List, Dict, Any, Optional, Literal
from datetime import datetime
import time
from openai import OpenAI
import json


final_output_schema_json = json.dumps({
    "$schema": "http://json-schema.org/draft-07/schema#",
    "type": "object",
    "properties": {
        "user_nlp_query": {
            "type": "string",
            "description": "The original natural language query to be translated into SQL"
        },
        "reasonings": {
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    "thought": {
                        "type": "string",
                        "description": "A thought about the user's question"
                    },
                    "helpful": {
                        "type": "boolean",
                        "description": "Whether the thought is helpful to solving the user's question"
                    }
                }
            },
            "description": "Step-by-step reasoning process for query generation"
        },
        "generated_sql_query": {
            "type": "string",
            "description": "The final SQL query that answers the natural language question"
        }
    }
})

# Define comprehensive thought instructions
thought_instructions = """
```
Thought Instructions:
```

```
1. Initial Analysis
- Identify the core request in the question
- Map question terms to database schema elements
- Determine if aggregation is needed
```

```
2. Complexity Assessment
- Evaluate if joins are needed
- Check if subqueries or CTEs are required
- Determine grouping requirements
```

```
3. Schema Analysis
- Identify primary tables needed
- Map columns to required data
- Understand table relationships
```

```
4. Query Planning
- Determine optimal join order if needed
- Plan filtering conditions
- Consider performance implications
```

```
5. SQL Components
- Select clause composition
- From clause and join structure
- Where clause conditions
- Group by and having requirements
```

```
6. Final Validation
- Verify schema compatibility
- Check column name accuracy
- Ensure proper syntax
```
"""

reasoning_instructions = """
```
SQL Generation Guidelines:
1. Use COUNT(*) for simple counts
2. Avoid unnecessary DISTINCT
3. Use table aliases for clarity
4. Include proper JOIN conditions
5. Handle NULL values appropriately
```

```
Format Requirements:
1. Use clear indentation
2. Align related clauses
3. Use meaningful aliases
4. Format complex conditions clearly
```
"""

class Thought(BaseModel):
    """Thought structure with improved validation"""
    thought: str = Field(description="The reasoning step")
    helpful: bool = Field(default=True)

class FinalOutput(BaseModel):
    """Complete output structure containing the query, reasoning, and SQL"""
    user_nlp_query: str = Field(
        description="The original natural language query to be translated into SQL"
    )
    reasonings: List[Thought] = Field(
        description="Step-by-step reasoning process for query generation"
    )
    generated_sql_query: str = Field(
        description="The final SQL query that answers the natural language question"
    )

def create_prompt(question: str, schema_links: List[str], schema: Dict[str, Any], complexity: str) -> str:
    """Create enhanced prompt with better examples and guidance"""
    
    examples = {
        "EASY": """
Example:
Q: "How many students are there?"
A: This requires a simple count from the student table
SQL: SELECT COUNT(*) FROM student
""",
        "NON-NESTED": """
Example:
Q: "List student names and their department names"
A: This requires joining student and department tables
SQL: SELECT s.name, d.name 
FROM student s
JOIN department d ON s.dept_id = d.id
""",
        "NESTED": """
Example:
Q: "Find students with above average grades"
A: This requires a subquery to calculate the average
SQL: SELECT name 
FROM student 
WHERE grade > (SELECT AVG(grade) FROM student)
"""
    }

    return f"""You are an expert SQL developer tasked with generating precise SQL queries.

SCHEMA INFORMATION:
Tables: {schema['table_names_original']}
Columns: {schema['column_names_original']}

QUESTION: {question}
SCHEMA_LINKS: {schema_links}
COMPLEXITY: {complexity}

{thought_instructions}

{reasoning_instructions}

{examples.get(complexity, examples['EASY'])}

GUIDELINES:
1. Generate clear, efficient SQL
2. Use proper table aliases
3. Include complete reasoning
4. Match schema exactly

RESPONSE FORMAT:
{final_output_schema_json} """

def process_question_sql(
    question: str,
    predicted_class: str,
    schema_links: List[str],
    db_id: str,
    spider_schemas: List[Dict[str, Any]],
    max_retries: int = 3
) -> FinalOutput:
    """Generate SQL with thoughts and reasoning"""
    
    try:
        # Get schema
        schema = next((s for s in spider_schemas if s['db_id'] == db_id), None)
        if not schema:
            raise ValueError(f"Unknown database ID: {db_id}")

        for attempt in range(max_retries):
            try:
                prompt = create_prompt(
                    question=question,
                    schema_links=schema_links,
                    schema=schema,
                    complexity=predicted_class
                )
                
                client = OpenAI()
                response = client.chat.completions.create(
                    model="gpt-4o",
                    messages=[
                        {
                            "role": "system", 
                            "content": """You are an SQL expert. Return JSON with this exact format:
                            {
                                "user_nlp_query": "the original question",
                                "reasonings": [
                                    {"thought": "your reasoning step", "helpful": true}
                                ],
                                "generated_sql_query": "your SQL query"
                            }"""
                        },
                        {"role": "user", "content": prompt}
                    ],
                    temperature=0,
                    max_tokens=1500,
                    response_format={"type": "json_object"}
                )
                
                content = response.choices[0].message.content
                print(f"Raw GPT response: {content}")  # Debug print
                
                try:
                    result = json.loads(content)
                    return FinalOutput(
                        user_nlp_query=result.get("user_nlp_query", question),
                        reasonings=[
                            Thought(**thought) for thought in result.get("reasonings", [])
                        ] or [Thought(thought="Direct SQL generation", helpful=True)],
                        generated_sql_query=result.get("generated_sql_query", "SELECT 1")
                    )
                except Exception as e:
                    print(f"Error parsing response: {str(e)}")
                    if attempt == max_retries - 1:
                        return FinalOutput(
                            user_nlp_query=question,
                            reasonings=[
                                Thought(
                                    thought=f"Failed to parse response: {str(e)}",
                                    helpful=False
                                )
                            ],
                            generated_sql_query="SELECT 1"
                        )
                    continue
                    
            except Exception as e:
                print(f"Attempt {attempt + 1} failed: {str(e)}")
                if attempt < max_retries - 1:
                    time.sleep(3)
                else:
                    return FinalOutput(
                        user_nlp_query=question,
                        reasonings=[
                            Thought(
                                thought=f"Error in process: {str(e)}",
                                helpful=False
                            )
                        ],
                        generated_sql_query="SELECT 1"
                    )
                continue
                
    except Exception as e:
        print(f"Error: {str(e)}")
        return FinalOutput(
            user_nlp_query=question,
            reasonings=[
                Thought(
                    thought=f"Critical error: {str(e)}",
                    helpful=False
                )
            ],
            generated_sql_query="SELECT 1"
        )

    return FinalOutput(
        user_nlp_query=question,
        reasonings=[
            Thought(
                thought="Maximum retries exceeded",
                helpful=False
            )
        ],
        generated_sql_query="SELECT 1"
    )

In [8]:
import time
from openai import OpenAI

def easy_prompt_maker(question, schema_links, schema):
    """Create prompt for easy Spider questions."""
    prompt = f"""Database Schema:
Tables: {schema['table_names_original']}
Columns: {schema['column_names_original']}

Example:
Q: "How many clubs are there?"
Schema_links: [club.id]
SQL: SELECT COUNT(*) FROM club

Now generate SQL for:
Question: {question}
Schema_links: {schema_links}
SQL:"""
    return prompt

def medium_prompt_maker(question, schema_links, schema):
    """Create prompt for medium (non-nested) Spider questions."""
    prompt = f"""Database Schema:
Tables: {schema['table_names_original']}
Columns: {schema['column_names_original']}

Example:
Q: "Show the names of all teams and their leagues."
Schema_links: [team.name, league.name]
A: Let's think step by step. We need to join teams with leagues.
SQL: SELECT team.name, league.name 
FROM team 
JOIN league ON team.league_id = league.id

Now generate SQL for:
Question: {question}
Schema_links: {schema_links}
A: Let's think step by step."""
    return prompt

def hard_prompt_maker(question, schema_links, schema):
    """Create prompt for hard (nested) Spider questions."""
    prompt = f"""Database Schema:
Tables: {schema['table_names_original']}
Columns: {schema['column_names_original']}

Example:
Q: "Find players who scored more goals than average."
Schema_links: [player.name, player.goals]
A: Let's think step by step:
1. Calculate average goals
2. Find players above average
SQL: SELECT name FROM player 
WHERE goals > (SELECT AVG(goals) FROM player)

Now generate SQL for:
Question: {question}
Schema_links: {schema_links}
A: Let's think step by step."""
    return prompt

def debugger(question: str, sql: str, predicted_class: str, schema_dict: Dict[str, Any]) -> str:
    """Create debug prompt based on query complexity."""
    
    if '"EASY"' in predicted_class:
        prompt_used = easy_prompt_maker(
            question=question,
            schema_links=schema_dict.get("schema_links", []),
            schema=schema_dict
        )
    elif '"NON-NESTED"' in predicted_class:
        prompt_used = medium_prompt_maker(
            question=question,
            schema_links=schema_dict.get("schema_links", []),
            schema=schema_dict
        )
    else:
        prompt_used = hard_prompt_maker(
            question=question,
            schema_links=schema_dict.get("schema_links", []),
            schema=schema_dict
        )

    instruction = f"""#### For the given question, use the provided tables, columns, foreign keys, and primary keys to check if the given SQLite SQL QUERY has any issues. If there are any issues, fix them and return the fixed SQLite QUERY in the output. If there are no issues, return the SQLite SQL QUERY as is in the output.
#### Background Information:
Relevant Schema Links: {schema_dict.get("schema_links", [])}
Prompt Used to Generate the Candidate SQLite SQL Query:
'''
{prompt_used}
'''
#### Use the following instructions for fixing the SQL QUERY:
1) Use the database values that are explicitly mentioned in the question.
2) Pay attention to the columns that are used for the JOIN by using the Foreign_keys.
3) Use DESC and DISTINCT only when needed.
4) Pay attention to the columns that are used for the GROUP BY statement.
5) Pay attention to the columns that are used for the SELECT statement.
6) Only change the GROUP BY clause when necessary (Avoid redundant columns in GROUP BY).
7) Use GROUP BY on one column only.

#### Question: {question}
#### SQLite SQL QUERY
{sql}
#### SQLite FIXED SQL QUERY
"""

    return instruction

def GPT4_debug(prompt: str) -> str:
    """Debug SQL using GPT-4."""
    client = OpenAI()
    try:
        response = client.chat.completions.create(
            model="gpt-4o",
            messages=[{
                "role": "system",
                "content": "You are a SQL expert. Return only the fixed SQL query with no explanation."
            },
            {
                "role": "user",
                "content": prompt
            }],
            temperature=0.0,
            max_tokens=350,
            top_p=1.0,
            frequency_penalty=0.0,
            presence_penalty=0.0,
            stop=["#", ";", "\n\n"]
        )
        return response.choices[0].message.content
    except Exception as e:
        print(f"Error in GPT4_debug: {str(e)}")
        return None
    
def refine_query(question: str, sql: str, predicted_class: str, schema_links: List[str], db_id: str, spider_schemas: List[Dict[str, Any]]) -> str:
    """Refine and debug the SQL query."""
    max_attempts = 3
    attempt = 0
    debugged_SQL = None
    
    # Get schema
    schema = next((s for s in spider_schemas if s['db_id'] == db_id), None)
    if not schema:
        return sql

    # Create schema dict with required format
    schema_dict = {
        "table_names_original": schema.get('table_names_original', []),
        "column_names_original": schema.get('column_names_original', []),
        "schema_links": schema_links,
        "primary_keys": schema.get('primary_keys', []),
        "foreign_keys": schema.get('foreign_keys', [])
    }
    
    while debugged_SQL is None and attempt < max_attempts:
        try:
            debug_prompt = debugger(
                question=question,
                sql=sql,
                predicted_class=predicted_class,
                schema_dict=schema_dict
            )
            debugged_SQL = GPT4_debug(debug_prompt)
            
            if debugged_SQL:
                # Clean up the response
                debugged_SQL = debugged_SQL.replace("\n", " ").strip()
                
                try:
                    # Try to extract SQL if wrapped in markdown
                    if "```" in debugged_SQL:
                        debugged_SQL = debugged_SQL.split("```sql")[-1].split("```")[0].strip()
                except:
                    # If extraction fails, use the whole response
                    pass
                
                print(f"Refined SQL (Attempt {attempt + 1}):", debugged_SQL)
                return debugged_SQL
            
        except Exception as e:
            print(f"Error in refinement attempt {attempt + 1}: {str(e)}")
            time.sleep(3)
        
        attempt += 1
    
    return sql

In [9]:
''' 
import os
import json
import time
from typing import Dict, Any, List
from openai import OpenAI

class SpiderTester:
    def __init__(self, spider_path: str = '/Users/virounikamina/Desktop/spider_data', mode='dev'):
        """Initialize Spider tester with dataset path and mode."""
        self.spider_path = spider_path
        
        # Load schemas
        with open(os.path.join(spider_path, 'tables.json'), 'r') as f:
            self.spider_schemas = json.load(f)
            
        # Select correct data file
        data_files = {
            'dev': 'dev.json',
            'train_spider': 'train_spider.json',
            'train_others': 'train_others.json'
        }
        
        if mode not in data_files:
            raise ValueError(f"Mode must be one of {list(data_files.keys())}")
            
        data_file = data_files[mode]
        print(f"Loading {data_file}...")
        
        # Load test data
        with open(os.path.join(spider_path, data_file), 'r') as f:
            self.test_data = json.load(f)
            
        print(f"Loaded {len(self.test_data)} questions")
            
        # Create results file
        self.results_file = os.path.join(
            spider_path, 
            f'results_{mode}_{time.strftime("%Y%m%d_%H%M%S")}.txt'
        )
        
        # Initialize value retrieval if needed
        self.value_retrieval = SpiderValueRetrieval(
            spider_tables_path=os.path.join(spider_path, 'tables.json')
        )

    def log(self, message: str, question_num: int = None):
        """Log message to file and console with optional question number."""
        if question_num is not None:
            message = f"Question {question_num}: {message}"
            
        print(message)
        with open(self.results_file, 'a') as f:
            f.write(message + '\n')

    def run_test(self, num_questions: int = None):
        if num_questions is None:
            num_questions = len(self.test_data)
        else:
            num_questions = min(num_questions, len(self.test_data))

        processed = 0
        successful = 0
        
        self.log(f"\nTesting {num_questions} questions")
        
        for idx, test_case in enumerate(self.test_data[:num_questions]):
            try:
                question = test_case['question']
                db_id = test_case['db_id']
                ground_truth = test_case['query']
                
                self.log(f"\nProcessing Question {idx + 1}/{num_questions}")
                self.log(f"Question: {question}")
                self.log(f"Database: {db_id}")
                
                # Get schema links
                schema_links = self.value_retrieval.process_schema(question, db_id)
                self.log(f"Schema Links: {schema_links}")
                
                try:
                    # Get classification
                    classification = process_question_classification(
                        question=question,
                        schema_links=schema_links,
                        db_id=db_id,
                        spider_schemas=self.spider_schemas
                    )
                    self.log(f"Classification: {classification}")
                    
                    # Get initial SQL with reasoning
                    process_thesql = process_question_sql(
                        question=question,
                        predicted_class=classification,
                        schema_links=schema_links,
                        db_id=db_id,
                        spider_schemas=self.spider_schemas
                    )
                    
                    self.log("Reasoning Steps:")
                    for thought in process_thesql.reasonings:
                        self.log(f"- {thought.thought}")
                        
                    self.log(f"Initial SQL: {process_thesql.generated_sql_query}")
                    
                    # Refine the SQL with all required parameters
                    final_sql = refine_query(
                        question=question,
                        sql=process_thesql.generated_sql_query,
                        predicted_class=classification,  # Added
                        schema_links=schema_links,      # Added
                        db_id=db_id,
                        spider_schemas=self.spider_schemas
                    )
                    
                    self.log(f"Final SQL: {final_sql}")
                    self.log(f"Ground Truth: {ground_truth}")
                    
                    # Compare results
                    #if self.compare_sql(final_sql, ground_truth):
                     #   successful += 1
                     #   self.log("✓ Match")
                    #else:
                    #    self.log("✗ No match")
                        
                    #processed += 1
                    
                except Exception as e:
                    self.log(f"Error processing question: {str(e)}")
                    continue
                    
            except Exception as e:
                self.log(f"Error in test case: {str(e)}")
                continue

        # Print summary
        self.log("\n=== Testing Summary ===")
        self.log(f"Total Questions Processed: {processed}")
        self.log(f"Successful Matches: {successful}")
        if processed > 0:
            self.log(f"Success Rate: {(successful/processed)*100:.2f}%")

def main():
    """Main entry point for Spider testing."""
    print("Available modes:")
    print("1. dev (1034 questions)")
    print("2. train_spider (7000 questions)")
    print("3. train_others (1659 questions)")
    
    mode = input("Choose mode (dev/train_spider/train_others) [default: dev]: ").strip() or 'dev'
    num_questions = input("How many questions to process? (press Enter for all): ").strip()
    
    try:
        tester = SpiderTester(mode=mode)
        tester.run_test(num_questions=int(num_questions) if num_questions else None)
    except Exception as e:
        print(f"Critical error: {str(e)}")

if __name__ == "__main__":
    main()
    '''

' \nimport os\nimport json\nimport time\nfrom typing import Dict, Any, List\nfrom openai import OpenAI\n\nclass SpiderTester:\n    def __init__(self, spider_path: str = \'/Users/virounikamina/Desktop/spider_data\', mode=\'dev\'):\n        """Initialize Spider tester with dataset path and mode."""\n        self.spider_path = spider_path\n        \n        # Load schemas\n        with open(os.path.join(spider_path, \'tables.json\'), \'r\') as f:\n            self.spider_schemas = json.load(f)\n            \n        # Select correct data file\n        data_files = {\n            \'dev\': \'dev.json\',\n            \'train_spider\': \'train_spider.json\',\n            \'train_others\': \'train_others.json\'\n        }\n        \n        if mode not in data_files:\n            raise ValueError(f"Mode must be one of {list(data_files.keys())}")\n            \n        data_file = data_files[mode]\n        print(f"Loading {data_file}...")\n        \n        # Load test data\n        with open(o

In [10]:
############################################ COLUMN MAPPING

from typing import Union, Tuple, List, Optional

class CMBackground(BaseModel):
    """A setup to the background for the user."""

    background: str = Field(description="Background for the user's question", min_length=10)


class CMThought(BaseModel):
    """A thought about the user's question."""

    thought: str  = Field(description="Text of the thought.")
#     helpful: bool = Field(description="Whether the thought is helpful to solving the user's question.")


class CMObservation(BaseModel):
    """An observation on the sequence of thoughts and observations generated so far."""

    observation: str = Field(description="An insightful observation on the sequence of thoughts and observations generated so far.")
    

class CMReasonings(BaseModel):
    """Returns a detailed reasoning to the user's question."""

    reasonings: list[Union[CMBackground, CMThought, CMObservation]] = Field(
        description="Reasonings to solve the users questions."
        #, min_length=5
    )

reasonings_schema_json = CMReasonings.model_json_schema()

class FinalQueryOutput(BaseModel):
    
    input_sql_query_1: str = Field(
        description=f"""Returns the exact same first query that the user gave as input.""")
        
    input_sql_query_2: str = Field(
        description=f"""Returns the exact same second query that the user gave as input.""")

    reasonings: list[Union[CMBackground, CMThought, CMObservation]] = Field(
        description="Reasonings to solve the users questions."
        #, min_length=5
    )
        
    column_mapping_list: List[Tuple[str, str]] = Field(
        description=f"""Returns the list of the corresponding column names in first sql query, sql 1, which
        corresponds to the column name in the other sql query, sql 2, as a list of tuple entries""")
    
column_mapping_schema_json = FinalQueryOutput.model_json_schema()

complete_user_prompts = """
```
Task Overview
```
Given two sql queries which are supposed to be equivalent, as inputs, 
the task is to give a column mapping between the output columns in one sql query
to the other sql query.

The mapping should include any table aliases present in the column names.
For example, if one query uses 'COLUMN_NAME' and another uses 'alias.column_name',
the mapping should be ['COLUMN_NAME', 'alias.column_name'].
```

```
The mapping is to be generated as a list of tuples.
```

```
For each element of the list which would be a tuple, 
the first entry in the tuple would be the column name used in sql query 1,
and the second entry in the tuple would be the corresponding column name in the sql query 2.
```
"""

reasoning_instructions = """
```
1. Reasoning you provide should first focus on whether the input sql queries contain 
a nested query or not.
2. It should give a plan on how to solve this question.
3. It should explain each of the clauses and why they are structured the way they are structured. 
For example, if there is a `group_by`, an explanation should be given as to why it exists.
```

```
Format the generated sql with proper indentation - the columns in the
(`select` statement should have more indentation than keyword `select`
and so on for each SQL clause.)
```
"""

thought_instructions = f"""
```
Thought Instructions:
```

```
Generate thoughts of increasing complexity.
Each thought should build on the previous ones and thoughts 
should progressively cover the nuances of the problem at hand.
```

```
Generate two separate thoughts, one each for the two input sql queries, 
to figure out the list of output columns in each of the sql queries.
```

```
Generate a thought to figure out the list of columns in sql query 1
which are present in both the sql queries.
```

```
Generate a thought to figure out the list of columns in sql query 1 
which are in sql query 1 but 
which are not present in sql query 2.
```

```
Generate a thought to figure out the list of columns in sql query 1
which are in sql query 2 but 
which are not present in sql query 1.
```

```
If the query uses common table expressions or nested queries, 
the above thoughts should be generated for each of the CTE separately.
```


```
Closing Thoughts and Observations
```
These should summarize:
1. The structure of the SQL query:
    - This states whether the query has any nested query.
    If so, the structure of the nested query is also mentioned.
    If not, a summary of the function of each of the select`, `where`, `group_by` etc. clauses
    should be mentioned.
2. An explanation of why the mapping is correct.
"""

reasoning_schema_instructions = f"""
```
Use the following JSON Schema as the grammar to create the structure 
for the step by step reasoning, and then to 
create the final SQL query.
```

```
Schema for Reasoning:
```
{reasonings_schema_json}
```

```
The instructions on how to structure the reasoning is provided below:
```
{thought_instructions}
```

```
Schema for Overall Output:
(This includes the reasonings schema above as an element)
```
{column_mapping_schema_json}
```

```
The final response should be a json with `names` as 
    `input_sql_query_1`,
    `input_sql_query_2`,
    `reasonings`,
    `column_mapping_list`.
```
"""


def get_user_prompt_for_question(input_sql_query_1, input_sql_query_2, input_table_schema, complete_user_prompts):
    
    user_prompt = f"""
```
Here are the two sql statements that are to be compared:
```

```
SQL Query 1:
```
{input_sql_query_1}
```

```
SQL Query 2:
```
{input_sql_query_2}
```

```
Generate a column mapping corresponding to the given input sql queries
and the description of the table provided below.
```
{input_table_schema}
```

```
Here's a more detailed set of instructions:
```
{complete_user_prompts}
```

```
Reasoning as to why the query is correct:
```
{reasoning_instructions}


{reasoning_schema_instructions}

```
Response for Column Mapping Generation:
```
"""
    
    return user_prompt


def call_openai_model(system_prompt, user_prompt, model_name):

    chat_history = [
        {
            'role': 'system', 
            'content': system_prompt
        },
        {
            'role': 'user', 
            'content': user_prompt
        }, 

    ]
    
    final_response = {}
    
    try:
        
        response = client.chat.completions.create(
            model           = model_name, 
            messages        = chat_history, 
            response_format = {"type":"json_object"}
        )
        
        final_response = response.choices[0].message.content
    
    except Exception:

        response = {
            "content": "An error occured. Please retry your chat. \
                If you keep getting this error, you may be out of OpenAI \
                completion tokens. Contact #help-ai on slack for assistance."
        }
        return response

    return final_response


system_prompt_snippet_001 = """
```
You are the most intelligent person in the world.
```
"""

system_prompt_snippet_002 = """

```
You will receive a $500 tip if you follow ALL the instructions specified.
```
"""

system_prompt_snippet_003 = """

```
Instructions
```
Give a column mapping between two equivalent sql statements
which may differ in the names of columns used in the output
and may also differ in the structure, but the overall meaning
and function of the query is meant to be the same.
```

```
Use step by step reasoning and at each step generate thoughts of increasing complexity.
```
"""

system_prompt_snippet_004 = """

```
Getting this answer right is important for my career. Please do your best.
```
"""

system_prompt = f"""
{system_prompt_snippet_001}
{system_prompt_snippet_002}
{system_prompt_snippet_003}
{system_prompt_snippet_004}
"""

In [11]:
curr = os.getcwd()
print(curr)
output_file = os.path.join(curr, 'spider_all_outputs')
def append_to_file(output, qnum, filename=output_file):
    # Check if file exists
    output_filename= filename+str(qnum)+'.txt'
    if not os.path.exists(output_filename):
        with open(output_filename, 'w') as file:
            file.write("Test_Spider Output Log\n")
            file.write("=" * 80 + "\n")
    # Append the output
    with open(output_filename, 'a') as file:
        file.write(output + "\n" + "=" * 80 + "\n")

/Users/virounikamina/Desktop/PIMCO-Text2SQL/test


In [12]:
import sqlite3
import io
import csv
def execute_sql(query: str) -> str:
    conn = None
    try:
        conn = sqlite3.connect('sqlite/nport.db')
        cursor = conn.cursor()

        # Execute the query with a timeout
        cursor.execute(query)

        # Fetch column names and rows
        columns = [description[0] for description in cursor.description]
        rows = cursor.fetchall()

        # Convert the results to CSV
        output = io.StringIO()
        writer = csv.writer(output)
        writer.writerow(columns)
        writer.writerows(rows)
        csv_data = output.getvalue()
        output.close()

        return csv_data
    except sqlite3.Error as e:
        print(f"Database error: {str(e)}")
        raise e
    except Exception as e:
        print(f"Error executing SQL: {str(e)}")
        raise e
    finally:
        if conn:
            conn.close()

In [13]:
def compare_csv_strings(csv_data1: str, csv_data2: str) -> bool:
    # Use io.StringIO to read the CSV strings as file-like objects
    csv_file1 = io.StringIO(csv_data1)
    csv_file2 = io.StringIO(csv_data2)
    
    # Create CSV readers for each CSV string
    reader1 = csv.reader(csv_file1)
    reader2 = csv.reader(csv_file2)
    
    # Compare rows one by one
    for row1, row2 in zip(reader1, reader2):
        if row1 != row2:
            return False  # Rows are different
    
    # Check if there are extra rows in either file
    try:
        next(reader1)
        return False  # Extra rows in csv_data1
    except StopIteration:
        pass

    try:
        next(reader2)
        return False  # Extra rows in csv_data2
    except StopIteration:
        pass

    return True  # CSVs are identical



import pandas as pd
import re

def get_aggregate_columns(sql_query):
    """
    Extract resulting output column names of aggregate functions in the SQL query,
    handling duplicates and default naming conventions.
    """
    aggregate_functions = ["SUM", "AVG", "COUNT", "MAX", "MIN"]
    output_columns = []

    # Regex to match aggregate functions with optional aliasing
    pattern = rf"({'|'.join(aggregate_functions)})\((.*?)\)(?:\s+AS\s+([\w_]+))?"
    
    matches = re.findall(pattern, sql_query, re.IGNORECASE)
    function_counter = {}  # Track occurrences of each aggregate function
    
    for func, inner, alias in matches:
        func_lower = func.lower()
        if alias:  # Explicit alias defined
            output_columns.append(alias)
        else:  # No alias, use default naming conventions
            if func_lower not in function_counter:
                function_counter[func_lower] = 0
            else:
                function_counter[func_lower] += 1
            # Generate default name (e.g., sum, sum_1, sum_2, etc.)
            if function_counter[func_lower] == 0:
                output_columns.append(f"{func_lower}({inner.strip()})")  # Default naming for SQLite
            else:
                output_columns.append(f"{func_lower}({inner.strip()})_{function_counter[func_lower]}")  # Add suffix

    return output_columns

def evaluate_sql_accuracy(generated_sql, ground_truth_sql, generated_csv, ground_truth_csv, qnum):
    """
    Evaluate the accuracy of generated SQL by comparing the resulting CSV files.
    """
    # Load CSV files
    gen_df = pd.read_csv(io.StringIO(generated_csv))
    gt_df = pd.read_csv(io.StringIO(ground_truth_csv))
    
    # Ensure all ground truth columns are in the generated DataFrame
    for col in gt_df.columns:
        if col not in gen_df.columns:
            append_to_file("False, not all ground truth columns are in generated csv",qnum)
            return False

    # Identify resulting output columns of aggregate functions in both SQL queries
    gt_agg_columns = get_aggregate_columns(ground_truth_sql)

    # Remove aggregate function columns from both DataFrames
    gen_df = gen_df.drop(columns=gt_agg_columns, errors='ignore')
    gt_df = gt_df.drop(columns=gt_agg_columns, errors='ignore')

    # Align columns in the generated DataFrame to match ground truth
    gen_subset = gen_df[gt_df.columns]

    # Check if rows match exactly
    if not gen_subset.equals(gt_df):
        append_to_file("False, all ground truth columns exist, but rows mismatch",qnum)
        return False  # Row mismatch detected

    append_to_file("True, all ground truth columns exist, and rows match", qnum)
    return True  # All checks passed





def compare_csv_din(ground_truth_query: str, llm_query: str, qnum: int, db_id: str, spider_schemas: list, value_retrieval):
    ## let LLM stack query the database
    append_to_file(f"Ground Truth Query: {ground_truth_query}", qnum)

    try: 
        # Get schema links using Spider-specific value retrieval
        schema_links = value_retrieval.process_schema(llm_query, db_id)
        append_to_file(f"Schema Links for Question: {llm_query}\n{schema_links}", qnum)
    
    except Exception as e:
        err_string = (f"Error in process_schema of Value Retrieval: {e}")
        print(err_string)
        append_to_file(err_string,qnum)
        raise e
    try:
        # Get classification with Spider-specific parameters
        classification = process_question_classification(
            question=llm_query,
            schema_links=schema_links,
            db_id=db_id,
            spider_schemas=spider_schemas
        )
        append_to_file(f"classification: {classification}", qnum)
    except Exception as e:
        err_string = (f"Error in process_question_classification of Classification: {e}")
        print(err_string)
        append_to_file(err_string, qnum)
        raise e
    

    try:
        # Generate SQL with Spider-specific parameters
        process_thesql = process_question_sql(
            question=llm_query,
            predicted_class=classification,
            schema_links=schema_links,
            db_id=db_id,
            spider_schemas=spider_schemas
        )
        append_to_file(f"Thoughts: {process_thesql.reasonings}", qnum)
        append_to_file(f"SQL: {process_thesql.generated_sql_query}", qnum)
    except Exception as e:
        err_string = (f"Error in process_question_sql of SQL Generation: {e}")
        print(err_string)
        append_to_file(err_string, qnum)
        raise e

    try:
        # Refine query with Spider-specific parameters
        final_output = refine_query(
            question=llm_query,
            sql=process_thesql.generated_sql_query,
            predicted_class=classification,
            schema_links=schema_links,
            db_id=db_id,
            spider_schemas=spider_schemas
        ).replace("```sql", "").replace("```", "").strip()
        append_to_file(f"final_output: {final_output}", qnum)
    except Exception as e:
        err_string = (f"Error in refine_query of Self-Correction: {e}")
        print(err_string)
        append_to_file(err_string, qnum)
        raise e



    # Add column mapping here
    try:
        column_mappings_prompt = get_user_prompt_for_question(
            ground_truth_query,
            final_output,
            schema_links,  # Changed from schema_dict to schema_links
            complete_user_prompts
        )

        column_mappings_response = call_openai_model(
            system_prompt=system_prompt,
            user_prompt=column_mappings_prompt,
            model_name='gpt-4o'
        )
        
        response_parsed = json.loads(column_mappings_response)
        append_to_file(f"Column Mappings: {json.dumps(response_parsed['column_mapping_list'], indent=2)}", qnum)
    except Exception as e:
        err_string = f"Error Mapping Columns: {str(e)}"
        print(err_string)
        append_to_file(err_string, qnum)
        raise e
    '''
    try:
        llm_csv = execute_sql(final_output)
    except Exception as e:
        err_string = (f"Error Executing LLM-Generated SQL: {str(e)}")
        print(err_string)
        append_to_file(err_string,qnum)
        raise e
    
    try:
    ## direct query to the database
        ground_truth_csv = execute_sql(ground_truth_query)
    except Exception as e:
        err_string = (f"Error Executing Ground Truth SQL: {str(e)}")
        print(err_string)
        append_to_file(err_string,qnum)
        raise e
    ## compare results
    
    try:
        #diff=compare_csv_strings(ground_truth_csv,llm_csv)
        diff = evaluate_sql_accuracy(generated_sql=final_output,ground_truth_sql=ground_truth_query,generated_csv=llm_csv,ground_truth_csv=ground_truth_csv, qnum=qnum)
        if diff:
            print("CSV outputs match perfectly.")
            return True
        else:
            print("Mismatch found.")
            return False
    except Exception as e:
        err_string=(f"Error comparing CSVs: {str(e)}")
        print(err_string)
        append_to_file(err_string,qnum)
        raise e'''

In [14]:
import csv
import json

# Load data from query_summary.csv
def load_queries(input_file):
    llm_query = []
    ground_truth_query = []
    with open(input_file, 'r', newline='', encoding='utf-8') as csvfile:
        reader = csv.DictReader(csvfile)
        for row in reader:
            llm_query.append(row["Question"])
            ground_truth_query.append(row["SQL"])
    return llm_query, ground_truth_query

# Save arrays to file
def save_queries_to_file(file_path, llm_query, ground_truth_query):
    with open(file_path, 'w', encoding='utf-8') as file:
        json.dump({"llm_query": llm_query, "ground_truth_query": ground_truth_query}, file)

# Load arrays from file
def load_queries_from_file(file_path):
    with open(file_path, 'r', encoding='utf-8') as file:
        data = json.load(file)
        return data["llm_query"], data["ground_truth_query"]
def process_queries(spider_path='/Users/virounikamina/Desktop/spider_data', mode='dev'):
    """Process queries from Spider dev.json file."""
    # Load Spider schemas
    tables_path = os.path.join(spider_path, 'tables.json')
    with open(tables_path, 'r') as f:
        spider_schemas = json.load(f)
    
    # Initialize value retrieval once with proper path
    value_retrieval = SpiderValueRetrieval(spider_tables_path=tables_path)
    
    # Load input file based on mode
    data_files = {
        'dev': 'dev.json',
        'train_spider': 'train_spider.json',
        'train_others': 'train_others.json'
    }
    
    if mode not in data_files:
        raise ValueError(f"Mode must be one of {list(data_files.keys())}")
        
    input_file = data_files[mode]
    print(f"Loading {input_file}...")
    
    # Load test data
    with open(os.path.join(spider_path, input_file), 'r') as f:
        test_data = json.load(f)
        
    print(f"Loaded {len(test_data)} questions")
    
    # Prepare output files in current directory
    output_file = f"spider_results_{mode}.csv"
    with open(output_file, 'w', newline='', encoding='utf-8') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow([
            "Question_Number",
            "Database_ID",
            "Question",
            "Ground_Truth_Query",
            "Match_Result"
        ])

    processed = 0
    successful = 0
    
    print(f"\nTesting {len(test_data)} questions")
    
    # Process each question in test data
    for i, test_case in enumerate(test_data):
        question = test_case['question']
        db_id = test_case['db_id']
        ground_truth = test_case['query']

        print("=" * 120)
        print(f"Processing Question {i + 1}/{len(test_data)}")
        print(f"Database ID: {db_id}")
        print("=" * 120)

        try:
            result = compare_csv_din(
                ground_truth_query=ground_truth,
                llm_query=question,
                qnum=i,
                db_id=db_id,
                spider_schemas=spider_schemas,
                value_retrieval=value_retrieval  # Pass the initialized value_retrieval instance
            )
            processed += 1
            if result:
                successful += 1
            
            write_to_output(output_file, i, db_id, question, ground_truth, result)
            
        except Exception as e:
            error_msg = str(e)
            print(f"Error processing question {i}: {error_msg}")
            write_to_output(output_file, i, db_id, question, ground_truth, "Error")
            append_to_file(f"Error: {error_msg}", i)
    
import os
import json
import csv
import time
from datetime import datetime

# Define Spider path
SPIDER_PATH = '/Users/virounikamina/Desktop/spider_data'

def write_to_output(file_path, qnum, db_id, question, ground_truth, result):
    """Write a single result to the output CSV file."""
    with open(file_path, 'a', newline='', encoding='utf-8') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow([
            qnum,
            db_id,
            question,
            ground_truth,
            result
        ])

def process_queries(spider_path='/Users/virounikamina/Desktop/spider_data', mode='dev'):
    """Process queries from Spider dev.json file."""
    # Load Spider schemas
    tables_path = os.path.join(spider_path, 'tables.json')
    with open(tables_path, 'r') as f:
        spider_schemas = json.load(f)
    
    # Initialize value retrieval once with proper path
    value_retrieval = SpiderValueRetrieval(spider_tables_path=tables_path)
    
    # Load input file based on mode
    data_files = {
        'dev': 'dev.json',
        'train_spider': 'train_spider.json',
        'train_others': 'train_others.json'
    }
    
    if mode not in data_files:
        raise ValueError(f"Mode must be one of {list(data_files.keys())}")
        
    input_file = data_files[mode]
    print(f"Loading {input_file}...")
    
    # Load test data
    with open(os.path.join(spider_path, input_file), 'r') as f:
        test_data = json.load(f)
        
    print(f"Loaded {len(test_data)} questions")
    
    # Create timestamp for this run
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    
    # Prepare output files in current directory
    output_file = f"spider_results_{mode}_{timestamp}.csv"
    with open(output_file, 'w', newline='', encoding='utf-8') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow([
            "Question_Number",
            "Database_ID",
            "Question",
            "Ground_Truth_Query",
            "Match_Result"
        ])

    processed = 0
    successful = 0
    
    print(f"\nTesting {len(test_data)} questions")
    
    # Process each question in test data
    for i, test_case in enumerate(test_data):
        question = test_case['question']
        db_id = test_case['db_id']
        ground_truth = test_case['query']

        print("=" * 120)
        print(f"Processing Question {i + 1}/{len(test_data)}")
        print(f"Database ID: {db_id}")
        print("=" * 120)

        try:
            result = compare_csv_din(
                ground_truth_query=ground_truth,
                llm_query=question,
                qnum=i,
                db_id=db_id,
                spider_schemas=spider_schemas,
                value_retrieval=value_retrieval
            )
            processed += 1
            if result:
                successful += 1
            
            write_to_output(output_file, i, db_id, question, ground_truth, result)
            
        except Exception as e:
            error_msg = str(e)
            print(f"Error processing question {i}: {error_msg}")
            write_to_output(output_file, i, db_id, question, ground_truth, "Error")
            append_to_file(f"Error: {error_msg}", i)


if __name__ == "__main__":
    # Call process_queries with Spider path
    process_queries(
        spider_path=SPIDER_PATH,
        mode='dev'  # or 'train_spider' or 'train_others'
    )

DEBUG: Loading Spider schemas from: /Users/virounikamina/Desktop/spider_data/tables.json
Loading dev.json...
Loaded 1034 questions

Testing 1034 questions
Processing Question 1/1034
Database ID: concert_singer
Extracted Info: {'keywords': ['singers', 'count'], 'keyphrases': ['How many singers'], 'numerical_values': []}
Words Extracted: ['singers', 'count', 'How many singers']
Processed Words: ['many', 'singer', 'count']
Similar matches for 'many': []
Similar matches for 'singer': [('singer.singer_id', 1.0), ('singer_in_concert.singer_id', 1.0), ('singer', 0.8)]
Similar matches for 'count': [('singer.country', 0.83), ('concert.concert_id', 0.67), ('concert.concert_name', 0.67), ('singer_in_concert.concert_id', 0.67)]
Table Columns: ['singer.singer_id', 'singer.country']
Schema Dict: {'table_columns': ['singer.singer_id', 'singer.country'], 'primary_keys': ['singer.singer_id'], 'foreign_keys': ['singer_in_concert.singer_id = singer.singer_id'], 'schema_links': ['singer.singer_id', 'singe

KeyboardInterrupt: 