In [1]:
import os
import json
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from typing import Optional, Union, List, Dict, Any
from tqdm.notebook import tqdm
import re
from sqlalchemy import create_engine, text, MetaData, Table, Column

from llama_index.core import SQLDatabase
from llama_index.core.tools import QueryEngineTool
from llama_index.core.query_engine import NLSQLTableQueryEngine

import setup
setup.init_django()

from rag import (
    db as rag_db,
    engines as rag_engines,
    settings as rag_settings,
    prompts as rag_prompts,
    patches as rag_patches,
)

In [2]:
class SpiderDatabaseManager:
    def __init__(self, spider_dir="/home/harry/chatbotDjango/spider/spider_data"):
        self.spider_dir = spider_dir
        self.db_dir = os.path.join(spider_dir, "database")
        self.tables_file = os.path.join(spider_dir, "tables.json")
        self.db_schemas = self._load_db_schemas()
        
    def _load_db_schemas(self):
        """Load schema information for all Spider databases"""
        with open(self.tables_file, 'r') as f:
            return json.load(f)
        
    def get_db_path(self, db_id):
        """Get SQLite file path for a database ID"""
        return os.path.join(self.db_dir, db_id, f"{db_id}.sqlite")
    
    def get_schema_for_db(self, db_id):
        """Get schema information for specific database"""
        for db_schema in self.db_schemas:
            if db_schema['db_id'] == db_id:
                return db_schema
        return None
    
    def get_all_db_ids(self):
        """Get all database IDs in Spider benchmark"""
        return [db_schema['db_id'] for db_schema in self.db_schemas]
    
    def get_sqlalchemy_engine(self, db_id):
        """Create SQLAlchemy engine for a database"""
        db_path = self.get_db_path(db_id)
        return create_engine(f"sqlite:///{db_path}")
    
    def get_table_names(self, db_id):
        """Get all table names for a database"""
        schema = self.get_schema_for_db(db_id)
        if schema:
            if isinstance(schema['table_names_original'], list):
                # Check if the list contains strings or something else
                if all(isinstance(item, str) for item in schema['table_names_original']):
                    return schema['table_names_original']
                else:
                    # If table_names_original contains objects, try to extract table_name
                    return [table['table_name'] if isinstance(table, dict) and 'table_name' in table 
                           else str(table) for table in schema['table_names_original']]
        return []
    
    def get_llama_index_database(self, db_id):
        """Create LlamaIndex SQLDatabase for Spider database"""
        engine = self.get_sqlalchemy_engine(db_id)
        table_names = self.get_table_names(db_id)
        # NEW: Verify tables
        print(f"Final tables for {db_id}: {table_names}")  # NEW
        return SQLDatabase(engine, include_tables=table_names)

spider_manager = SpiderDatabaseManager()

In [3]:
import os
import json

# 1. Define the correct path using raw string format
SPIDER_DIR = "/home/harry/chatbotDjango/spider/spider_data"

# 2. Verify the path exists
if not os.path.exists(SPIDER_DIR):
    raise FileNotFoundError(f"Spider directory not found at: {SPIDER_DIR}")

# 3. Verify dev.json exists
DEV_JSON_PATH = os.path.join(SPIDER_DIR, "dev.json")
if not os.path.exists(DEV_JSON_PATH):
    raise FileNotFoundError(f"dev.json not found at: {DEV_JSON_PATH}")

def load_spider_data(split="dev"):
    """Load Spider evaluation data for specified split"""
    file_path = os.path.join(SPIDER_DIR, f"{split}.json")
    
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except UnicodeDecodeError as e:
        raise ValueError(f"Error decoding JSON file: {e}") from e

# 4. Load the data with verification
try:
    spider_dev = load_spider_data("dev")
    print(f"Successfully loaded Spider dev data with {len(spider_dev)} entries")
except Exception as e:
    print(f"Error loading data: {str(e)}")
    raise

Successfully loaded Spider dev data with 1034 entries


In [4]:
def create_spider_sql_query_engine(db_id):
    """Create SQL query engine for Spider database"""
    sql_database = spider_manager.get_llama_index_database(db_id)
    table_names = spider_manager.get_table_names(db_id)
    
    text_to_sql_prompt = rag_prompts.custom_text_to_sql_prompt
    text_to_sql_prompt.template = text_to_sql_prompt.template.replace(
        "{dialect} PostgreSQL", "SQLite"
    )
    
    sql_query_engine = NLSQLTableQueryEngine(
        sql_database=sql_database,
        tables=table_names,
        response_synthesis_prompt=rag_prompts.custom_sql_response_synthesis_prompt,
        text_to_sql_prompt=text_to_sql_prompt
    )
    
    # Use the corrected import for QueryEngineTool
    from llama_index.core.tools import QueryEngineTool
    
    sql_tool = QueryEngineTool.from_defaults(
        query_engine=sql_query_engine,
        description=f"Useful for translating natural language queries into SQL over {db_id} database"
    )
    
    return sql_query_engine

In [5]:
import time

# Global rate limiting state
LAST_API_CALL_TIME = 0
MAX_RETRIES = 3

def api_call_with_retry(func):
    """Decorator to handle rate limiting and retries for any API call"""
    def wrapper(*args, **kwargs):
        global LAST_API_CALL_TIME, MAX_RETRIES
        retries = 0
        
        while retries <= MAX_RETRIES:
            # Enforce rate limiting
            elapsed = time.time() - LAST_API_CALL_TIME
            if elapsed < 5:
                sleep_time = 5 - elapsed
                print(f"Global rate limiting: Sleeping {sleep_time:.1f}s")
                time.sleep(sleep_time)
            
            try:
                result = func(*args, **kwargs)
                LAST_API_CALL_TIME = time.time()
                return result
                
            except Exception as e:
                if "rate limit" in str(e).lower():
                    print(f"Rate limit hit ({func.__name__}). Waiting 5s (retry {retries+1}/{MAX_RETRIES})")
                    time.sleep(5)
                    LAST_API_CALL_TIME = time.time()
                    retries += 1
                else:
                    print(f"Non-retryable error in {func.__name__}: {str(e)}")
                    raise
        
        print(f"Max retries exceeded for {func.__name__}. Raising error.")
        raise RuntimeError(f"API call failed after {MAX_RETRIES} retries")
    
    return wrapper

In [6]:
@api_call_with_retry
def translate_to_farsi(text):
    """Translate English to Farsi using LLM"""
    prompt = f"Translate English to Farsi. Only return translation:\n{text}"
    response = rag_settings.Settings.llm.complete(prompt)
    return response.text.strip()

In [7]:
@api_call_with_retry
def execute_query(query_engine, query_text):
    """Execute a query with rate limiting"""
    return query_engine.query(query_text)

In [8]:
def debug_spider_schema(db_id):
    schema = spider_manager.get_schema_for_db(db_id)
    if not schema:
        print(f"No schema found for database: {db_id}")
        return
    print(f"Database ID: {db_id}")
    print(f"Schema keys: {list(schema.keys())}")
    print("\nTable Names Original structure:")
    if 'table_names_original' in schema:
        table_names = schema['table_names_original']
        print(f"Type: {type(table_names)}")
        print(f"Length: {len(table_names)}")
        if len(table_names) > 0:
            print(f"First item type: {type(table_names[0])}")
            print(f"First item: {table_names[0]}")
    else:
        print("No table_names_original in schema")
    all_db_ids = spider_manager.get_all_db_ids()
    if all_db_ids:
        print(f"\nAll database IDs (first 5): {all_db_ids[:5]}")

In [9]:
import re
import numpy as np
from tqdm.notebook import tqdm

def extract_sql_from_response(response):
    if response is None:
        return "N/A (No Response)"
    if hasattr(response, 'metadata') and getattr(response, 'metadata', None) and 'sql_query' in response.metadata:
        return response.metadata['sql_query']
    if hasattr(response, 'source_nodes'):
        for node in getattr(response, 'source_nodes', []):
            node_metadata = getattr(node, 'metadata', {})
            if 'sql_query' in node_metadata:
                return node_metadata['sql_query']
    response_text = str(getattr(response, 'response', str(response)))
    patterns = [
        r"```sql\s*([\s\S]*?)\s*```",
        r"SQL: (.*?);",
        r"SELECT.*?FROM.*?(?:WHERE.*?|ORDER BY.*?|LIMIT.*?)?$"
    ]
    for pattern in patterns:
        match = re.search(pattern, response_text, re.IGNORECASE | re.DOTALL)
        if match:
            return match.group(1).strip()
    return response_text

def evaluate_on_spider_sample(sample_size=3, use_farsi=False, random_seed=42):
    import time
    np.random.seed(random_seed)
    results = []
    db_counts = {}
    for ex in spider_dev:
        db_counts[ex['db_id']] = db_counts.get(ex['db_id'], 0) + 1
    valid_dbs = [db_id for db_id, count in db_counts.items() if count > 0]
    if not valid_dbs:
        print("No valid databases found with examples")
        return results
    selected_dbs = np.random.choice(valid_dbs, min(sample_size, len(valid_dbs)), replace=False)
    for db_id in selected_dbs:
        print(f"\n{'='*40}\nProcessing database: {db_id}")
        translation_cache = {} if use_farsi else None
        db_examples = [ex for ex in spider_dev if ex['db_id'] == db_id]
        if len(db_examples) == 0:
            print(f"Skipping {db_id} - no examples available")
            continue
        actual_sample_size = min(2, len(db_examples))
        try:
            if actual_sample_size == 1:
                sample_examples = db_examples
            else:
                sample_examples = np.random.choice(db_examples, actual_sample_size, replace=False)
        except ValueError as e:
            print(f"Error sampling examples for {db_id}: {str(e)}")
            continue
        try:
            query_engine = create_spider_sql_query_engine(db_id)
            test_response = execute_query(query_engine, "SELECT 1")
            if not test_response:
                raise ValueError("Query engine failed basic connectivity test")
        except Exception as e:
            print(f"Failed to create engine for {db_id}: {str(e)}")
            continue
        for example in tqdm(sample_examples, desc=f"Evaluating {db_id}"):
            question = example['question']
            query_text = None 
            try:
                if use_farsi:
                    if question not in translation_cache:
                        translation_cache[question] = translate_to_farsi(question)
                    query_text = translation_cache[question]
                else:
                    query_text = question
                response = execute_query(query_engine, query_text)
                predicted_sql = extract_sql_from_response(response)
                if not re.search(r"(SELECT|INSERT|UPDATE|DELETE)", predicted_sql, re.IGNORECASE):
                    raise ValueError("Generated SQL appears invalid")
                results.append({
                    'db_id': db_id,
                    'question': question,
                    'farsi_question': query_text if use_farsi else None,
                    'predicted_sql': predicted_sql,
                    'gold_sql': example['query'],
                    'success': True
                })
            except Exception as e:
                import traceback
                print(f"\nError processing query: {traceback.format_exc()}")
                results.append({
                    'db_id': db_id,
                    'question': question,
                    'farsi_question': query_text if use_farsi else None,  # 
                    'gold_sql': example.get('query', 'Not available'),
                    'error': str(e),
                    'success': False
                })
    return results

In [None]:
# English evaluation
english_results = evaluate_on_spider_sample(sample_size=2, use_farsi=False)

# Farsi evaluation 
farsi_results = evaluate_on_spider_sample(sample_size=2, use_farsi=True)


Processing database: concert_singer
Final tables for concert_singer: ['stadium', 'singer', 'concert', 'singer_in_concert']


Evaluating concert_singer:   0%|          | 0/2 [00:00<?, ?it/s]

Global rate limiting: Sleeping 5.0s
Global rate limiting: Sleeping 5.0s

Processing database: dog_kennels
Final tables for dog_kennels: ['Breeds', 'Charges', 'Sizes', 'Treatment_Types', 'Owners', 'Dogs', 'Professionals', 'Treatments']
Global rate limiting: Sleeping 5.0s


Evaluating dog_kennels:   0%|          | 0/2 [00:00<?, ?it/s]

Global rate limiting: Sleeping 5.0s


In [None]:
def analyze_results(results):
    if not results:
        print("No results to analyze")
        return pd.DataFrame()
    df = pd.DataFrame(results)
    if df.empty:
        print("No results available")
        return df
    if 'success' not in df.columns:
        print("No successful queries found")
        df['success'] = False
    success_rate = df['success'].mean() if 'success' in df.columns else 0
    print(f"Total Queries: {len(df)}")
    print(f"Success Rate: {success_rate:.2%}")
    if 'error' in df.columns and (~df['success']).any():
        print("\nError Distribution:")
        print(df[~df['success']]['error'].value_counts())
    if 'db_id' in df.columns and not df.empty:
        plt.figure(figsize=(10, 5))
        df['db_id'].value_counts().plot(kind='bar')
        plt.title("Query Distribution by Database")
        plt.show()
    return df

In [None]:
# Fix the existing data frame creation issue
english_df = analyze_results(english_results)
farsi_df = analyze_results(farsi_results)

# Now create the comparison bar chart correctly
plt.figure(figsize=(8, 5))
plt.bar(['English', 'Farsi'], 
        [english_df['success'].mean() if not english_df.empty and 'success' in english_df.columns else 0, 
         farsi_df['success'].mean() if not farsi_df.empty and 'success' in farsi_df.columns else 0])
plt.ylim(0, 1)
plt.title("Success Rate Comparison")
plt.ylabel("Success Rate")
plt.grid(axis='y', linestyle='--', alpha=0.7)
plt.show()

# Create a more detailed comparison function
def create_detailed_evaluation_report(english_results, farsi_results):
    """Generate detailed comparison between English and Farsi SQL generation"""
    # Create a combined DataFrame with language indicator
    if not english_results and not farsi_results:
        print("No results available for analysis")
        return
    
    # Add language indicator to each result set
    for r in english_results:
        r['language'] = 'English'
    for r in farsi_results:
        r['language'] = 'Farsi'
    
    # Combine results
    all_results = english_results + farsi_results
    df = pd.DataFrame(all_results)
    
    # 1. Print overall statistics
    print(f"Total queries evaluated: {len(df)}")
    print(f"English queries: {len(english_results)}")
    print(f"Farsi queries: {len(farsi_results)}")
    
    # 2. Success rates by language
    success_by_lang = df.groupby('language')['success'].agg(['count', 'mean'])
    success_by_lang.columns = ['Total Queries', 'Success Rate']
    success_by_lang['Success Rate'] = success_by_lang['Success Rate'].apply(lambda x: f"{x:.2%}")
    print("\nSuccess Rate by Language:")
    print(success_by_lang)
    
    # 3. Paired comparison for same questions
    if english_results and farsi_results:
        # Create lookup dictionaries for both result sets
        english_dict = {r['question']: r for r in english_results}
        farsi_dict = {r['question']: r for r in farsi_results}
        
        # Find common questions
        common_questions = set(english_dict.keys()) & set(farsi_dict.keys())
        
        if common_questions:
            print(f"\nFound {len(common_questions)} queries that were tested in both languages")
            
            # Create paired comparison table
            paired_rows = []
            for q in common_questions:
                en_result = english_dict[q]
                fa_result = farsi_dict[q]
                
                paired_rows.append({
                    'Question': q,
                    'English_Success': en_result.get('success', False),
                    'Farsi_Success': fa_result.get('success', False),
                    'English_SQL': en_result.get('predicted_sql', 'N/A'),
                    'Farsi_SQL': fa_result.get('predicted_sql', 'N/A'),
                    'Gold_SQL': en_result.get('gold_sql', 'N/A'),
                    'Farsi_Translation': fa_result.get('farsi_question', 'N/A')
                })
            
            paired_df = pd.DataFrame(paired_rows)
            
            # Calculate agreement statistics
            both_success = sum((paired_df['English_Success'] & paired_df['Farsi_Success']))
            both_fail = sum((~paired_df['English_Success'] & ~paired_df['Farsi_Success']))
            en_only = sum((paired_df['English_Success'] & ~paired_df['Farsi_Success']))
            fa_only = sum((~paired_df['English_Success'] & paired_df['Farsi_Success']))
            
            print(f"Both languages succeeded: {both_success} ({both_success/len(paired_df):.2%})")
            print(f"Both languages failed: {both_fail} ({both_fail/len(paired_df):.2%})")
            print(f"Only English succeeded: {en_only} ({en_only/len(paired_df):.2%})")
            print(f"Only Farsi succeeded: {fa_only} ({fa_only/len(paired_df):.2%})")
            
            # Display example translations and SQL generations
            if len(paired_df) > 0:
                print("\nExample Translation and SQL Generation:")
                for idx, row in paired_df.head(min(3, len(paired_df))).iterrows():
                    print(f"\nQuestion: {row['Question']}")
                    print(f"Farsi Translation: {row['Farsi_Translation']}")
                    print(f"Gold SQL: {row['Gold_SQL']}")
                    print(f"English Generated SQL: {row['English_SQL']}")
                    print(f"Farsi Generated SQL: {row['Farsi_SQL']}")
                    print(f"Match: {'✓' if row['English_Success'] and row['Farsi_Success'] else '✗'}")
        else:
            print("\nNo common questions found between English and Farsi evaluations")
    
    # 4. Error analysis
    if 'error' in df.columns:
        print("\nError Analysis:")
        error_by_lang = df[~df['success']].groupby(['language', 'error']).size().reset_index()
        error_by_lang.columns = ['Language', 'Error Type', 'Count']
        error_by_lang = error_by_lang.sort_values(['Language', 'Count'], ascending=[True, False])
        
        # Group similar errors
        error_by_lang['Error Category'] = error_by_lang['Error Type'].apply(categorize_error)
        
        error_categories = error_by_lang.groupby(['Language', 'Error Category']).agg({'Count': 'sum'}).reset_index()
        error_categories = error_categories.sort_values(['Language', 'Count'], ascending=[True, False])
        
        # Plot error distribution
        plt.figure(figsize=(12, 6))
        for i, lang in enumerate(['English', 'Farsi']):
            lang_errors = error_categories[error_categories['Language'] == lang]
            if not lang_errors.empty:
                plt.subplot(1, 2, i+1)
                plt.pie(lang_errors['Count'], labels=lang_errors['Error Category'], 
                        autopct='%1.1f%%', startangle=90)
                plt.axis('equal')
                plt.title(f'{lang} Error Categories')
        
        plt.tight_layout()
        plt.show()
    
    # 5. Database-wise performance
    if 'db_id' in df.columns:
        print("\nPerformance by Database:")
        db_performance = df.groupby(['language', 'db_id'])['success'].agg(['count', 'mean'])
        db_performance.columns = ['Total Queries', 'Success Rate']
        print(db_performance)
        
        # Plot database performance comparison
        plt.figure(figsize=(12, 6))
        db_success = df.pivot_table(index='db_id', columns='language', 
                                   values='success', aggfunc='mean')
        db_success.plot(kind='bar', figsize=(12, 6))
        plt.title('Success Rate by Database and Language')
        plt.ylabel('Success Rate')
        plt.ylim(0, 1)
        plt.grid(axis='y', linestyle='--', alpha=0.5)
        plt.tight_layout()
        plt.show()
    
    return df

def categorize_error(error_text):
    """Categorize error messages into broader categories"""
    error_text = str(error_text).lower()
    
    if 'sql' in error_text and ('invalid' in error_text or 'syntax' in error_text):
        return 'SQL Syntax Error'
    elif 'generated sql appears invalid' in error_text:
        return 'Invalid SQL Generation'
    elif 'connection' in error_text or 'timeout' in error_text:
        return 'Connection/Timeout Error'
    elif 'schema' in error_text or 'table' in error_text or 'column' in error_text:
        return 'Schema Understanding Error'
    elif 'query' in error_text and 'execution' in error_text:
        return 'Query Execution Error'
    elif 'translation' in error_text:
        return 'Translation Error'
    else:
        return 'Other Error'

# Run the detailed analysis
combined_df = create_detailed_evaluation_report(english_results, farsi_results)

# Add visualization of translation quality
if farsi_results:
    # Extract sample translations
    translations = [(r['question'], r['farsi_question']) for r in farsi_results if 'farsi_question' in r and r['farsi_question']]
    
    if translations:
        # Display sample translations
        print("\nSample English to Farsi Translations:")
        for i, (en, fa) in enumerate(translations[:5]):
            print(f"{i+1}. English: {en}")
            print(f"   Farsi:   {fa}")
            print()
        
        # Create translation analysis
        translation_success = sum(1 for r in farsi_results if r.get('success', False))
        print(f"Total translation attempts: {len(translations)}")
        print(f"Successful queries after translation: {translation_success} ({translation_success/len(translations):.2%})")

# Compare SQL generation complexity
if english_results and farsi_results:
    print("\nSQL Complexity Analysis:")
    
    def analyze_sql_complexity(sql):
        """Calculate complexity metrics for SQL query"""
        if not sql or not isinstance(sql, str):
            return {'length': 0, 'joins': 0, 'conditions': 0, 'aggregations': 0}
        
        # Count SQL features
        joins = len(re.findall(r'\bjoin\b', sql, re.IGNORECASE))
        conditions = len(re.findall(r'\bwhere\b|\band\b|\bor\b', sql, re.IGNORECASE))
        aggregations = len(re.findall(r'\bcount\b|\bsum\b|\bavg\b|\bmax\b|\bmin\b', sql, re.IGNORECASE))
        
        return {
            'length': len(sql),
            'joins': joins,
            'conditions': conditions,
            'aggregations': aggregations
        }
    
    # Analyze only successful generations
    en_successful = [r for r in english_results if r.get('success', False) and 'predicted_sql' in r]
    fa_successful = [r for r in farsi_results if r.get('success', False) and 'predicted_sql' in r]
    
    en_complexity = [analyze_sql_complexity(r['predicted_sql']) for r in en_successful]
    fa_complexity = [analyze_sql_complexity(r['predicted_sql']) for r in fa_successful]
    
    if en_complexity and fa_complexity:
        # Calculate averages
        en_avg = {k: sum(d[k] for d in en_complexity) / len(en_complexity) for k in en_complexity[0]}
        fa_avg = {k: sum(d[k] for d in fa_complexity) / len(fa_complexity) for k in fa_complexity[0]}
        
        # Print complexity comparison
        print("Average SQL Complexity Metrics:")
        metrics = ['length', 'joins', 'conditions', 'aggregations']
        for metric in metrics:
            print(f"{metric.capitalize()}: English={en_avg[metric]:.2f}, Farsi={fa_avg[metric]:.2f}")
        
        # Plot complexity comparison
        plt.figure(figsize=(10, 6))
        x = np.arange(len(metrics))
        width = 0.35
        
        plt.bar(x - width/2, [en_avg[m] for m in metrics], width, label='English')
        plt.bar(x + width/2, [fa_avg[m] for m in metrics], width, label='Farsi')
        
        plt.xlabel('Metrics')
        plt.ylabel('Average Value')
        plt.title('SQL Complexity Comparison')
        plt.xticks(x, [m.capitalize() for m in metrics])
        plt.legend()
        plt.grid(axis='y', linestyle='--', alpha=0.5)
        plt.tight_layout()
        plt.show()