In [None]:
# Text-to-SQL Error Analysis Script for Google Colab
# For bachelor thesis on schema-enhanced Text-to-SQL generation

# Colab-only setup: mount Google Drive so that the /content/drive/... paths
# used in the configuration below are reachable from this runtime.
from google.colab import drive
drive.mount('/content/drive')

# Install required packages (if needed).
# NOTE: `!pip` is IPython/Colab shell syntax; this line is not valid in a plain
# .py file and only works when executed as a notebook cell.
!pip install -q matplotlib seaborn pandas tqdm

import json
import os
import re
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from typing import Dict, List, Any, Counter
from collections import defaultdict
from tqdm import tqdm

# --- Configuration ---
# Root folder on the mounted Google Drive and the experiment whose predictions are analyzed.
DRIVE_BASE_DIR = "/content/drive/MyDrive/text2sql"
TARGET_EXPERIMENT_NAME = "t5_large_sql_types_schema_v3_nofp16"
# Human-readable label used in report headings and chart titles.
MODEL_NAME_LABEL = "T5-Large with Schema Types"

# Paths derived from the two settings above.
EVAL_OUTPUT_DIR = f"{DRIVE_BASE_DIR}/eval_results/{TARGET_EXPERIMENT_NAME}"
PREDICTIONS_FILE = f"{EVAL_OUTPUT_DIR}/predictions.json"
ANALYSIS_OUTPUT_DIR = f"{DRIVE_BASE_DIR}/error_analysis/{TARGET_EXPERIMENT_NAME}"

# Reports and charts are written here; create it up front (no error if it already exists).
os.makedirs(ANALYSIS_OUTPUT_DIR, exist_ok=True)

print("\n".join([
    f"Analyzing predictions from: {PREDICTIONS_FILE}",
    f"Saving analysis results to: {ANALYSIS_OUTPUT_DIR}",
    f"Using model label: {MODEL_NAME_LABEL}",
]))

# Warn early so a wrong experiment name is visible immediately.
if os.path.exists(PREDICTIONS_FILE):
    print("Predictions file found.")
else:
    print(f"\n*** WARNING: Predictions file not found at {PREDICTIONS_FILE} ***")
    print(f"*** Please ensure the path and TARGET_EXPERIMENT_NAME ('{TARGET_EXPERIMENT_NAME}') are correct. ***")

# --- Utility Functions ---
def extract_tables(sql: str) -> List[str]:
    """Return table names that directly follow FROM or JOIN in *sql*.

    Matching is case-insensitive. All FROM targets are listed first (in order
    of appearance), followed by all JOIN targets. Only bare identifiers made of
    letters, digits and underscores are recognised; ``FROM (SELECT ...)`` yields
    nothing for that position.
    """
    keyword_patterns = (
        r'\bFROM\s+([a-zA-Z0-9_]+)',
        r'\bJOIN\s+([a-zA-Z0-9_]+)',
    )
    return [
        hit.group(1).strip()
        for pattern in keyword_patterns
        for hit in re.finditer(pattern, sql, re.IGNORECASE)
    ]

def extract_columns(sql: str) -> List[str]:
    """Heuristically extract column names referenced by a SQL query.

    Columns are collected, in this order, from the first SELECT list, the first
    WHERE clause, the GROUP BY list and the ORDER BY list.
    - Table qualifiers are dropped (``T1.name`` -> ``name``).
    - Function wrappers are unwrapped (``count(x)`` -> ``x``).
    - ``DISTINCT`` and ``AS alias`` are removed, and ``*`` is skipped.

    This is a regex heuristic, not a SQL parser. Nested queries and complex
    expressions may still produce spurious or missing names.

    Args:
        sql: SQL query text.

    Returns:
        Column names in order of appearance (duplicates kept, empty strings removed).
    """
    select_pattern = r'\bSELECT\s+(.*?)\s+FROM'
    where_pattern = r'\bWHERE\s+(.*?)(?:\bGROUP BY|\bORDER BY|\bLIMIT|\bJOIN|\bUNION|\s*$)'
    groupby_pattern = r'\bGROUP BY\s+(.*?)(?:\bHAVING|\bORDER BY|\bLIMIT|\s*$)'
    orderby_pattern = r'\bORDER BY\s+(.*?)(?:\bLIMIT|\s*$)'
    # A column followed by a comparison operator. Keyword operators carry word
    # boundaries; without them, values such as 'Berlin' or 'Paris' would match
    # IN / IS and yield bogus names like "Berl". NOT LIKE / NOT IN are listed so
    # that "NOT" itself is never captured as a column.
    where_column_pattern = (
        r'([a-zA-Z0-9_\.]+)\s*'
        r'(?:>=|<=|!=|<>|=|>|<|\bNOT\s+LIKE\b|\bNOT\s+IN\b|\bLIKE\b|\bIN\b|\bIS\b|\bBETWEEN\b)'
    )

    columns: List[str] = []

    select_match = re.search(select_pattern, sql, re.IGNORECASE | re.DOTALL)
    if select_match:
        select_cols = select_match.group(1).strip()
        # DISTINCT is a modifier, not part of a column name (also inside count(DISTINCT x)).
        select_cols = re.sub(r'\bDISTINCT\s+', '', select_cols, flags=re.IGNORECASE)
        # Unwrap single-level function calls: count(x) -> x, max(T1.a) -> T1.a.
        select_cols = re.sub(r'[a-zA-Z0-9_]+\s*\(([^)]*)\)', r'\1', select_cols)
        # Drop aliases; \b keeps identifiers that merely end in "as" intact.
        select_cols = re.sub(r'\bAS\s+[a-zA-Z0-9_]+', '', select_cols, flags=re.IGNORECASE)
        for col in select_cols.split(','):
            col = col.strip().split('.')[-1]
            if col != '*':
                columns.append(col)

    where_match = re.search(where_pattern, sql, re.IGNORECASE | re.DOTALL)
    if where_match:
        where_clause = where_match.group(1).strip()
        for match in re.finditer(where_column_pattern, where_clause, re.IGNORECASE):
            columns.append(match.group(1).strip().split('.')[-1])

    groupby_match = re.search(groupby_pattern, sql, re.IGNORECASE | re.DOTALL)
    if groupby_match:
        for col in groupby_match.group(1).strip().split(','):
            columns.append(col.strip().split('.')[-1])

    orderby_match = re.search(orderby_pattern, sql, re.IGNORECASE | re.DOTALL)
    if orderby_match:
        orderby_cols = orderby_match.group(1).strip()
        # Replace trailing sort directions with a separator; \b avoids eating
        # the tail of identifiers such as "price_desc".
        orderby_cols = re.sub(r'\b(?:ASC|DESC)(?:\s*,|$)', ',', orderby_cols, flags=re.IGNORECASE)
        for col in orderby_cols.split(','):
            if col.strip():
                columns.append(col.strip().split('.')[-1])

    return [col for col in columns if col]

def has_aggregation(sql: str) -> bool:
    """Return True if *sql* aggregates.

    Detects COUNT/SUM/AVG/MIN/MAX, GROUP BY or HAVING, matched as whole words
    regardless of case.
    """
    found = re.search(r'\b(COUNT|SUM|AVG|MIN|MAX|GROUP BY|HAVING)\b', sql, flags=re.IGNORECASE)
    return found is not None

def has_join(sql: str) -> bool:
    """Return True if *sql* contains the JOIN keyword (whole word, any case)."""
    join_regex = re.compile(r'\bJOIN\b', re.IGNORECASE)
    return join_regex.search(sql) is not None

def has_nesting(sql: str) -> bool:
    """Return True if *sql* contains a parenthesised SELECT.

    The check looks for an opening parenthesis followed, after optional
    whitespace, by SELECT.
    """
    subquery_match = re.search(r'\(\s*SELECT', sql, flags=re.IGNORECASE)
    return subquery_match is not None

def has_order(sql: str) -> bool:
    """Return True if *sql* contains ``ORDER BY`` (single space, any case)."""
    order_regex = re.compile(r'\bORDER BY\b', re.IGNORECASE)
    return order_regex.search(sql) is not None

def has_limit(sql: str) -> bool:
    """Return True if *sql* contains the LIMIT keyword (whole word, any case)."""
    limit_match = re.search(r'\bLIMIT\b', sql, flags=re.IGNORECASE)
    return limit_match is not None

def has_union(sql: str) -> bool:
    """Return True if *sql* contains the UNION keyword (whole word, any case)."""
    union_regex = re.compile(r'\bUNION\b', re.IGNORECASE)
    return union_regex.search(sql) is not None

def count_conditions(sql: str) -> int:
    """Estimate the number of predicates in the first WHERE clause of *sql*.

    The clause ends at GROUP BY, ORDER BY, LIMIT, a set operator
    (UNION / INTERSECT / EXCEPT) or the end of the string. The clause therefore
    does not absorb the WHERE of a following compound query.

    The estimate is ``1 + #AND + #OR``. The AND belonging to
    ``x BETWEEN a AND b`` is part of a single predicate and is not counted.
    AND/OR inside nested subqueries or string literals are still counted
    (heuristic).

    Args:
        sql: SQL query text.

    Returns:
        0 when there is no WHERE clause, otherwise the estimated predicate count.
    """
    where_match = re.search(
        r'\bWHERE\s+(.*?)(?:\bGROUP BY|\bORDER BY|\bLIMIT|\bUNION|\bINTERSECT|\bEXCEPT|\s*$)',
        sql,
        re.IGNORECASE | re.DOTALL,
    )
    if not where_match:
        return 0
    where_clause = where_match.group(1)
    and_count = len(re.findall(r'\bAND\b', where_clause, re.IGNORECASE))
    or_count = len(re.findall(r'\bOR\b', where_clause, re.IGNORECASE))
    between_count = len(re.findall(r'\bBETWEEN\b', where_clause, re.IGNORECASE))
    # Each BETWEEN consumes one AND; clamp in case of malformed SQL (BETWEEN without AND).
    return 1 + max(and_count - between_count, 0) + or_count

def categorize_error(gold_sql: str, pred_sql: str) -> Dict[str, bool]:
    """Flag which heuristic error categories apply to a gold/predicted SQL pair.

    Table and column names are compared case-insensitively. This matches the
    lower-cased exact-match check in ``analyze_predictions``. As a result, a
    prediction that differs only in identifier casing (``Singer`` vs ``singer``)
    is not reported as a table or column selection error.

    Args:
        gold_sql: Reference SQL query.
        pred_sql: Predicted SQL query.

    Returns:
        Dict mapping each category name to True when its heuristic fires:
        - a gold table/column is missing from the prediction;
        - JOIN / aggregation / nesting / ORDER BY / LIMIT / UNION presence differs;
        - the WHERE-condition counts differ;
        - a likely syntax problem (the word "syntax" or unbalanced parentheses);
        - an empty prediction;
        - a predicted table absent from the gold query (aliases T1-T4 excepted).
    """
    gold_tables = {table.lower() for table in extract_tables(gold_sql)}
    pred_tables = {table.lower() for table in extract_tables(pred_sql)}
    gold_columns = {col.lower() for col in extract_columns(gold_sql)}
    pred_columns = {col.lower() for col in extract_columns(pred_sql)}
    # Alias-style names that should never count as a foreign table.
    alias_names = {'t1', 't2', 't3', 't4'}

    error_types = {
        # Every gold table / column must also appear somewhere in the prediction.
        'table_selection': not gold_tables.issubset(pred_tables),
        'column_selection': not gold_columns.issubset(pred_columns),
        'join_error': has_join(gold_sql) != has_join(pred_sql),
        'aggregation_error': has_aggregation(gold_sql) != has_aggregation(pred_sql),
        'nesting_error': has_nesting(gold_sql) != has_nesting(pred_sql),
        'order_error': has_order(gold_sql) != has_order(pred_sql),
        'limit_error': has_limit(gold_sql) != has_limit(pred_sql),
        'union_error': has_union(gold_sql) != has_union(pred_sql),
        'condition_count_error': count_conditions(gold_sql) != count_conditions(pred_sql),
        'syntax_error': 'syntax' in pred_sql.lower() or pred_sql.count('(') != pred_sql.count(')'),
        'empty_result': len(pred_sql.strip()) == 0,
        'wrong_db_schema': any(
            table not in gold_tables and table not in alias_names
            for table in pred_tables if table
        ),
    }
    return error_types

def _write_text_file(path: str, text: str, success_message: str, failure_prefix: str) -> None:
    """Write *text* to *path* as UTF-8, printing a status line instead of raising."""
    try:
        with open(path, 'w', encoding='utf-8') as f:
            f.write(text)
        print(success_message)
    except Exception as e:
        # Best-effort: a failed write must not abort the remaining outputs.
        print(f"{failure_prefix}: {e}")


def _save_bar_chart(labels, values, *, palette, title, xlabel, ylabel, figsize,
                    plot_path, success_message, failure_prefix,
                    reference_y=None, ylim=None) -> None:
    """Draw a seaborn bar chart, save it as ``<plot_path>.png`` (300 dpi) and close it.

    ``reference_y`` adds a dashed red horizontal line; ``ylim`` is a (low, high) pair.
    """
    plt.figure(figsize=figsize)
    ax = sns.barplot(x=labels, y=values, palette=palette)
    ax.set_title(title, fontsize=14)
    ax.set_xlabel(xlabel, fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    if reference_y is not None:
        plt.axhline(y=reference_y, color='r', linestyle='--', alpha=0.5)
    plt.xticks(rotation=45, ha='right')
    if ylim is not None:
        plt.ylim(*ylim)
    plt.tight_layout()
    try:
        plt.savefig(f"{plot_path}.png", dpi=300)
        print(success_message)
    except Exception as e:
        print(f"{failure_prefix}: {e}")
    finally:
        # Always close, even when saving fails, so failed figures are not left open.
        plt.close()


def analyze_predictions(predictions_filepath: str, analysis_output_dir: str, model_label: str):
    """Analyze model predictions and write a text report, error examples and charts.

    Args:
        predictions_filepath: JSON file holding a list of prediction records. Each
            record is read via the keys ``gold_sql``, ``pred_sql``, ``db_id`` and
            ``question``. Missing or null SQL values are treated as ``''``.
        analysis_output_dir: Existing directory that receives
            ``error_analysis_report.txt``, ``error_examples.txt`` and the PNG charts.
        model_label: Human-readable model name used in report and chart titles.

    Returns:
        A summary dict with the keys:
        - ``total_queries``;
        - ``error_counts``;
        - ``error_percentages``: percent of *all* analyzed queries;
        - ``component_recall``;
        - ``db_error_rates``.
        Returns ``None`` when the file cannot be read or decoded, is not a JSON
        list, or contains no predictions.
    """
    print(f"Loading predictions from {predictions_filepath}...")
    try:
        with open(predictions_filepath, 'r', encoding='utf-8') as f:
            predictions = json.load(f)
    except FileNotFoundError:
        print(f"ERROR: Predictions file not found at {predictions_filepath}")
        return None
    except json.JSONDecodeError:
        print(f"ERROR: Could not decode JSON from {predictions_filepath}")
        return None

    if not isinstance(predictions, list):
        # Records are accessed with .get() below; a top-level dict/scalar would crash.
        print(f"ERROR: Expected a JSON list of predictions in {predictions_filepath}")
        return None

    print(f"Analyzing {len(predictions)} predictions")

    error_counts = defaultdict(int)
    component_counts = defaultdict(lambda: {'gold': 0, 'pred': 0, 'correct': 0})
    db_performance = defaultdict(lambda: {'correct': 0, 'total': 0})
    error_examples = defaultdict(list)
    component_checks = (
        ('join', has_join), ('aggregation', has_aggregation), ('nesting', has_nesting),
        ('order', has_order), ('limit', has_limit), ('union', has_union),
    )

    for i, item in enumerate(tqdm(predictions, desc="Analyzing errors")):
        # `or ''` also covers explicit JSON nulls, which would break .lower() and regexes.
        gold_sql = item.get('gold_sql') or ''
        pred_sql = item.get('pred_sql') or ''
        db_id = item.get('db_id', 'unknown_db')
        question = item.get('question', '')

        # Exact match after lower-casing and collapsing whitespace.
        is_match = ' '.join(gold_sql.lower().split()) == ' '.join(pred_sql.lower().split())

        db_performance[db_id]['total'] += 1
        if is_match:
            db_performance[db_id]['correct'] += 1
        else:
            for error_type, has_error in categorize_error(gold_sql, pred_sql).items():
                if not has_error:
                    continue
                error_counts[error_type] += 1
                # Keep at most 5 examples per category for the examples file.
                if len(error_examples[error_type]) < 5:
                    error_examples[error_type].append({
                        'id': i, 'question': question, 'db_id': db_id,
                        'gold_sql': gold_sql, 'pred_sql': pred_sql
                    })

        # Component presence is tallied for every query, matched or not.
        for component, detector in component_checks:
            has_gold = detector(gold_sql)
            has_pred = detector(pred_sql)
            if has_gold:
                component_counts[component]['gold'] += 1
            if has_pred:
                component_counts[component]['pred'] += 1
            if has_gold and has_pred:
                component_counts[component]['correct'] += 1

    # --- Reporting and Plotting ---
    total_queries = len(predictions)
    if total_queries == 0:
        print("No predictions found to analyze.")
        return None

    # Percentages are relative to ALL analyzed queries, not only the incorrect ones.
    error_percentages = {error: count / total_queries * 100 for error, count in error_counts.items()}
    sorted_errors = sorted(error_percentages.items(), key=lambda x: x[1], reverse=True)

    # Generate report text
    report = f"Error Analysis for {model_label}\n"
    report += f"Total queries analyzed: {total_queries}\n\n"
    report += "Error Type Frequencies:\n"
    for error, percentage in sorted_errors:
        report += f"  {error}: {percentage:.2f}% ({error_counts[error]} instances)\n"

    report += "\nSQL Component Presence Analysis (Recall approximation):\n"
    component_recall = {}
    for component, counts in component_counts.items():
        recall = counts['correct'] / counts['gold'] * 100 if counts['gold'] > 0 else 0
        component_recall[component] = recall
        report += f"  {component.capitalize()}: {recall:.2f}% ({counts['correct']}/{counts['gold']})\n"

    report += "\nTop 5 Databases by Error Rate (min 5 queries):\n"
    db_error_rates = {
        db_id: 100 * (1 - stats['correct'] / stats['total'])
        for db_id, stats in db_performance.items() if stats['total'] >= 5
    }
    sorted_dbs = sorted(db_error_rates.items(), key=lambda x: x[1], reverse=True)
    for db_id, error_rate in sorted_dbs[:5]:
        stats = db_performance[db_id]
        report += f"  {db_id}: {error_rate:.2f}% errors ({stats['correct']}/{stats['total']} correct)\n"

    report_path = os.path.join(analysis_output_dir, "error_analysis_report.txt")
    _write_text_file(report_path, report, f"Report saved to {report_path}", "Error saving report")

    # Error examples file
    examples_report = "Error Examples:\n\n"
    for error_type, examples in error_examples.items():
        examples_report += f"== {error_type} Examples ==\n"
        for i, example in enumerate(examples):
            examples_report += f"Example {i+1}:\n"
            examples_report += f"Question: {example.get('question', 'N/A')}\n"
            examples_report += f"DB: {example.get('db_id', 'N/A')}\n"
            examples_report += f"Gold SQL: {example.get('gold_sql', 'N/A')}\n"
            examples_report += f"Pred SQL: {example.get('pred_sql', 'N/A')}\n\n"
        examples_report += "---\n\n"
    examples_path = os.path.join(analysis_output_dir, "error_examples.txt")
    _write_text_file(examples_path, examples_report,
                     f"Error examples saved to {examples_path}", "Error saving examples")

    # --- Plotting ---
    sns.set_style("whitegrid")

    # 1. Error type distribution
    if sorted_errors:
        plot_path = os.path.join(analysis_output_dir, "error_distribution")
        _save_bar_chart(
            [e.replace('_', ' ').title() for e, _ in sorted_errors],
            [p for _, p in sorted_errors],
            palette="muted",
            title=f'Error Type Distribution ({model_label})',
            xlabel='Error Type',
            # Values are count / total_queries, i.e. a share of all analyzed queries.
            ylabel='Percentage of All Queries (%)',
            figsize=(12, 7),
            plot_path=plot_path,
            success_message=f"Error distribution chart saved to {plot_path}.png",
            failure_prefix="Error saving error distribution plot",
        )

    # 2. SQL Component Recall
    if component_recall:
        components = sorted(component_recall.keys())
        plot_path = os.path.join(analysis_output_dir, "component_recall")
        _save_bar_chart(
            [c.capitalize() for c in components],
            [component_recall[c] for c in components],
            palette="viridis",
            title=f'SQL Component Recall ({model_label})',
            xlabel='SQL Component',
            ylabel='Recall (%)',
            figsize=(10, 6),
            plot_path=plot_path,
            success_message=f"Component recall chart saved to {plot_path}.png",
            failure_prefix="Error saving component recall plot",
            reference_y=50,
        )

    # 3. Database performance (Top 10 worst error rates)
    if sorted_dbs:
        plot_path = os.path.join(analysis_output_dir, "db_error_rates")
        _save_bar_chart(
            [db_id for db_id, _ in sorted_dbs[:10]],
            [rate for _, rate in sorted_dbs[:10]],
            palette="rocket",
            title=f'Top 10 Databases by Error Rate ({model_label})',
            xlabel='Database ID',
            ylabel='Error Rate (%)',
            figsize=(12, 7),
            plot_path=plot_path,
            success_message=f"DB error rates chart saved to {plot_path}.png",
            failure_prefix="Error saving DB error rate plot",
            ylim=(0, 100),
        )

    # Return summary statistics dictionary
    analysis_results = {
        'total_queries': total_queries,
        'error_counts': dict(error_counts),
        'error_percentages': error_percentages,
        'component_recall': component_recall,
        'db_error_rates': db_error_rates
    }
    return analysis_results


# --- Main Execution ---
# Run the analysis only when the predictions file exists; otherwise report why it was skipped.
if not os.path.exists(PREDICTIONS_FILE):
    print(f"\nSkipping analysis because predictions file was not found: {PREDICTIONS_FILE}")
else:
    print("\n--- Running Error Analysis ---")
    analysis_summary = analyze_predictions(
        predictions_filepath=PREDICTIONS_FILE,
        analysis_output_dir=ANALYSIS_OUTPUT_DIR,
        model_label=MODEL_NAME_LABEL,
    )

    if not analysis_summary:
        print("\nAnalysis could not be completed.")
    else:
        print("\n--- Analysis Summary ---")
        print(f"Total queries analyzed: {analysis_summary['total_queries']}")

        # Three most frequent error categories, highest percentage first.
        print("\nTop 3 error types:")
        summary_sorted_errors = sorted(
            analysis_summary['error_percentages'].items(),
            key=lambda pair: pair[1],
            reverse=True,
        )
        for error, percentage in summary_sorted_errors[:3]:
            instance_count = analysis_summary['error_counts'].get(error, 0)
            print(f"  {error.replace('_', ' ').title()}: {percentage:.2f}% ({instance_count})")

        # Every tracked component, best recall first.
        print("\nSQL Component Recall:")
        summary_sorted_components = sorted(
            analysis_summary['component_recall'].items(),
            key=lambda pair: pair[1],
            reverse=True,
        )
        for component, recall in summary_sorted_components:
            print(f"  {component.capitalize()}: {recall:.2f}%")

        print(f"\nFull analysis report and charts saved to: {ANALYSIS_OUTPUT_DIR}")

print("\n--- Error Analysis Script Finished ---")
print(f"Find results in your Google Drive: {ANALYSIS_OUTPUT_DIR}")