In [None]:
import json
import torch
import matplotlib.pyplot as plt
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer, EarlyStoppingCallback, TrainerCallback
import os

# ---- Model / data locations ----
MODEL_NAME = "google/flan-t5-small"  # base weights used when no local checkpoint exists
DATA_PATH = "sft_data/text_to_sql_sampled_balanced.jsonl"
OUTPUT_DIR = "./flan-t5-sql-model"

# ---- Training hyper-parameters ----
MAX_INPUT_LENGTH = 512  # prompt tokens kept after truncation
BATCH_SIZE = 5          # per-device train/eval batch size
EPOCHS = 5              # epochs run over each chunk of BATCH_TRAINING_SIZE examples
LEARNING_RATE = 5e-5

# ---- Chunked training / checkpointing ----
BATCH_TRAINING_SIZE = 1000  # examples per training chunk
SAVE_EVERY_N_STEPS = 100    # checkpoint cadence, in optimizer steps

# Make sure the output directory exists up front and keep the resume bookkeeping
# file next to the checkpoints.
os.makedirs(OUTPUT_DIR, exist_ok=True)
progress_file = os.path.join(OUTPUT_DIR, "training_progress.json")

# Custom callback that forces a checkpoint on a fixed step cadence
class FrequentCheckpointCallback(TrainerCallback):
    """Trainer callback that requests a save every ``save_steps`` optimizer steps.

    Args:
        save_steps: checkpoint cadence; a save is requested whenever
            ``state.global_step`` is an exact multiple of it.
    """

    def __init__(self, save_steps):
        self.save_steps = save_steps

    def on_step_end(self, args, state, control, **kwargs):
        """Set ``control.should_save`` on multiples of ``save_steps`` and return ``control``."""
        on_save_boundary = not (state.global_step % self.save_steps)
        if on_save_boundary:
            control.should_save = True
        return control

# Report which accelerator PyTorch will use
_cuda_available = torch.cuda.is_available()
print(f"Using device: {'GPU' if _cuda_available else 'CPU'}")
if _cuda_available:
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")
print()

In [None]:
# Load or resume model and tokenizer.
# Prefer the highest-numbered `checkpoint-<step>` directory in OUTPUT_DIR (written by the
# Trainer in the training cells); otherwise start from the base MODEL_NAME weights.
checkpoint_dir = None
if os.path.isdir(OUTPUT_DIR):
    numbered_checkpoints = [
        (int(name.split("-", 1)[1]), name)
        for name in os.listdir(OUTPUT_DIR)
        # Skip anything whose suffix is not an integer instead of crashing on int()
        if name.startswith("checkpoint-") and name.split("-", 1)[1].isdigit()
    ]
    if numbered_checkpoints:
        _, latest_checkpoint = max(numbered_checkpoints)
        checkpoint_dir = os.path.join(OUTPUT_DIR, latest_checkpoint)

if checkpoint_dir is not None:
    print(f"‚úÖ Found checkpoint: {checkpoint_dir}")
    print("Loading from checkpoint...")
    model_source = checkpoint_dir
else:
    print("No checkpoint found. Loading base model...")
    model_source = MODEL_NAME
tokenizer = AutoTokenizer.from_pretrained(model_source)
model = AutoModelForSeq2SeqLM.from_pretrained(model_source)

# Load training data (JSON Lines: one JSON object per line; blank lines are skipped)
print("\nLoading training data...")
with open(DATA_PATH, 'r', encoding='utf-8') as f:
    data_list = [json.loads(line) for line in f if line.strip()]

print(f"Loaded {len(data_list)} total examples")
if not data_list:
    # The record preview and length statistics below need at least one example.
    raise ValueError(f"No examples found in {DATA_PATH}")

# Show the first record before and after a seeded shuffle
print("\nFirst record BEFORE shuffle:")
print(json.dumps(data_list[0], indent=2))

import random
random.seed(42)  # reproducible shuffle
random.shuffle(data_list)

print("\nFirst record AFTER shuffle:")
print(json.dumps(data_list[0], indent=2))

print("\n‚úÖ Data has been shuffled randomly.\n")

# Analyze SQL query lengths to determine MAX_TARGET_LENGTH
print("\n--- Analyzing SQL Query Lengths ---")
# NOTE(review): lengths are measured on the raw SQL, but training targets go through
# encode_sql_for_training (defined in a later cell), which replaces `<`, `>`, `<=`, `>=`
# with words such as LESS_THAN; those may tokenize longer, so confirm the +20 headroom
# below is enough.
sql_lengths = [len(tokenizer.encode(item["output"])) for item in data_list]
sorted_lengths = sorted(sql_lengths)

avg_length = sum(sql_lengths) / len(sql_lengths)
min_length = sorted_lengths[0]
max_length = sorted_lengths[-1]
percentile_95 = sorted_lengths[int(len(sorted_lengths) * 0.95)]
percentile_99 = sorted_lengths[int(len(sorted_lengths) * 0.99)]

print(f"Average SQL length: {avg_length:.1f} tokens")
print(f"Min SQL length: {min_length} tokens")
print(f"Max SQL length: {max_length} tokens")
print(f"95th percentile: {percentile_95} tokens")
print(f"99th percentile: {percentile_99} tokens")

# Set MAX_TARGET_LENGTH from the 99th percentile plus headroom, capped at 512 tokens
MAX_TARGET_LENGTH = min(percentile_99 + 20, 512)
print(f"Setting MAX_TARGET_LENGTH to: {MAX_TARGET_LENGTH} tokens\n")


In [None]:
# Split into train / validation / test before chunking, so every training chunk is
# evaluated against the same held-out sets.
from sklearn.model_selection import train_test_split

print("Creating train/val/test split...")

# Hold out 10% for test, then 0.1111 of the remaining 90% (~10% overall) for validation.
train_val_data, test_data = train_test_split(data_list, test_size=0.1, random_state=42)
train_data, val_data = train_test_split(train_val_data, test_size=0.1111, random_state=42)

for split_label, split_rows in (("Total training examples", train_data),
                                ("Validation examples", val_data),
                                ("Test examples", test_data)):
    print(f"{split_label}: {len(split_rows)}")

# Number of training chunks = ceil(len(train_data) / BATCH_TRAINING_SIZE)
num_batches = -(-len(train_data) // BATCH_TRAINING_SIZE)
print(f"\nWill train in {num_batches} batches of ~{BATCH_TRAINING_SIZE} examples each\n")


In [None]:
# Build a Hugging Face Dataset from raw JSONL records
def create_dataset(data_list):
    """Return a ``Dataset`` with ``instruction``, ``input`` and ``output`` columns.

    ``input`` defaults to ``""`` when a record lacks it; ``instruction`` and
    ``output`` are required (a missing key raises ``KeyError``).
    """
    instructions = [record["instruction"] for record in data_list]
    inputs = [record.get("input", "") for record in data_list]
    outputs = [record["output"] for record in data_list]
    return Dataset.from_dict({"instruction": instructions, "input": inputs, "output": outputs})

# Helpers that swap SQL comparison operators for word placeholders before tokenization
def encode_sql_for_training(sql):
    """Replace SQL comparison operators with word placeholders before tokenization.

    Per the original note, these operators conflict with the tokenizer's special
    tokens; ``decode_sql_from_model`` maps the placeholders back.

    Substitutions, in this order so that two-character operators win:
        ``<>`` -> ``!=``  (same meaning in SQL; otherwise it would become
                           ``LESS_THAN GREATER_THAN`` and decode to the invalid ``< >``)
        ``<=`` -> ``LESS_EQUAL``     ``>=`` -> ``GREATER_EQUAL``
        ``<``  -> ``LESS_THAN``      ``>``  -> ``GREATER_THAN``

    Afterwards all whitespace runs are collapsed to single spaces. This is a plain
    text substitution, so it also rewrites ``<``/``>`` and whitespace inside string
    literals.

    Args:
        sql: SQL query text.
    Returns:
        The encoded query string.
    """
    sql = sql.replace('<>', ' != ')
    for operator, placeholder in (('<=', 'LESS_EQUAL'), ('>=', 'GREATER_EQUAL'),
                                  ('<', 'LESS_THAN'), ('>', 'GREATER_THAN')):
        sql = sql.replace(operator, f' {placeholder} ')
    return ' '.join(sql.split())  # Normalize spaces

def decode_sql_from_model(sql):
    """Map the operator placeholders used in training targets back to SQL operators.

    ``LESS_EQUAL``/``GREATER_EQUAL`` are restored before ``LESS_THAN``/``GREATER_THAN``.
    Surrounding spaces are kept as they are (e.g. ``a LESS_THAN 5`` -> ``a < 5``).
    """
    placeholder_to_operator = (
        ('LESS_EQUAL', '<='),
        ('GREATER_EQUAL', '>='),
        ('LESS_THAN', '<'),
        ('GREATER_THAN', '>'),
    )
    decoded = sql
    for placeholder, operator in placeholder_to_operator:
        decoded = decoded.replace(placeholder, operator)
    return decoded

# Tokenize prompts and operator-encoded SQL targets for the Seq2Seq trainer
def preprocess_function(examples):
    """Tokenize a batch of instruction/input/output records for Seq2Seq training.

    Prompts are ``"Translate this to SQL: {instruction} {input}"`` (stripped);
    targets are the SQL passed through ``encode_sql_for_training``.

    Nothing is padded here. ``DataCollatorForSeq2Seq`` pads each batch dynamically
    and fills label padding with ``-100`` (its default ``label_pad_token_id``), which
    the loss ignores. Padding labels to ``max_length`` here instead would fill them
    with the pad token id, and the loss would then treat that padding as real targets.

    Args:
        examples: batched mapping with ``"instruction"``, ``"input"`` and ``"output"`` lists.
    Returns:
        Tokenizer output for the prompts plus a ``"labels"`` list of target token ids.
    """
    inputs = [
        f"Translate this to SQL: {instr} {inp}".strip()
        for instr, inp in zip(examples["instruction"], examples["input"])
    ]

    # Encode outputs to avoid tokenizer issues with < and >
    targets = [encode_sql_for_training(sql) for sql in examples["output"]]

    model_inputs = tokenizer(inputs, max_length=MAX_INPUT_LENGTH, truncation=True)
    labels = tokenizer(targets, max_length=MAX_TARGET_LENGTH, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

In [None]:
## Manual restart: run this cell ONLY to make the chunked training loop (next cell) start at a
## specific batch. It overwrites the existing progress file.
# Writes `progress_file` from the configuration cell, which is the file the training loop reads,
# rather than creating a second progress file in the working directory.
restart_batch = 5  # Will start training batch 5 (skipping batches 1-4)

if not 1 <= restart_batch <= num_batches:
    raise ValueError(f"restart_batch must be between 1 and {num_batches}, got {restart_batch}")

os.makedirs(OUTPUT_DIR, exist_ok=True)
with open(progress_file, 'w') as f:
    json.dump({
        "completed_batches": restart_batch - 1,  # -1 because we count completed batches
        "current_batch_started": False,
        "all_losses": [],
        "total_batches": num_batches,  # derived from the split, not hardcoded
        "last_batch_max_checkpoint": 0
    }, f)

print(f"‚úÖ Training will restart from batch {restart_batch}")
print("Now run your training code")

In [None]:
import shutil

# Create validation and test datasets (these stay constant across training chunks)
print("Preparing validation and test datasets...")
val_dataset = create_dataset(val_data)
val_dataset = val_dataset.map(preprocess_function, batched=True, remove_columns=val_dataset.column_names)
test_dataset = create_dataset(test_data)
test_dataset = test_dataset.map(preprocess_function, batched=True, remove_columns=test_dataset.column_names)


def _numbered_checkpoints(output_dir):
    """Return ``{dir_name: step}`` for every ``checkpoint-<int>`` directory in ``output_dir``.

    Entries whose suffix is not an integer are skipped; a missing directory yields ``{}``.
    """
    if not os.path.isdir(output_dir):
        return {}
    numbered = {}
    for name in os.listdir(output_dir):
        prefix, _, suffix = name.partition("-")
        if prefix == "checkpoint" and suffix.isdigit():
            numbered[name] = int(suffix)
    return numbered


def _save_progress(batch_in_flight):
    """Write the resume state (read back at the top of this loop) to ``progress_file``.

    ``batch_in_flight`` is True while a chunk is training, so a crash resumes that chunk.
    Reads the module-level ``completed_batches``, ``all_losses`` and
    ``last_batch_max_checkpoint`` at call time.
    """
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    with open(progress_file, 'w') as f:
        json.dump({
            "completed_batches": completed_batches,
            "current_batch_started": batch_in_flight,
            "all_losses": all_losses,
            "total_batches": num_batches,
            "last_batch_max_checkpoint": last_batch_max_checkpoint
        }, f)


# Resume state (overwritten below when a progress file exists)
all_losses = []
completed_batches = 0
current_batch_started = False
last_batch_max_checkpoint = 0  # Highest checkpoint step written by the last finished chunk

if os.path.exists(progress_file):
    with open(progress_file, 'r') as f:
        progress = json.load(f)
    completed_batches = progress.get("completed_batches", 0)
    current_batch_started = progress.get("current_batch_started", False)
    all_losses = progress.get("all_losses", [])
    last_batch_max_checkpoint = progress.get("last_batch_max_checkpoint", 0)
    print(f"‚úÖ Resuming from batch {completed_batches + 1}/{num_batches}\n")

for batch_idx in range(completed_batches, num_batches):
    print(f"\n{'='*80}")
    print(f"TRAINING BATCH {batch_idx + 1}/{num_batches}")
    print(f"{'='*80}\n")

    # Get this chunk of training data
    start_idx = batch_idx * BATCH_TRAINING_SIZE
    end_idx = min(start_idx + BATCH_TRAINING_SIZE, len(train_data))
    batch_train_data = train_data[start_idx:end_idx]

    print(f"Training on examples {start_idx} to {end_idx} ({len(batch_train_data)} examples)")

    train_dataset_batch = create_dataset(batch_train_data)
    train_dataset_batch = train_dataset_batch.map(preprocess_function, batched=True, remove_columns=train_dataset_batch.column_names)

    data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

    resume_from_checkpoint = None
    if current_batch_started and batch_idx == completed_batches:
        # CASE 1: this chunk was interrupted mid-training -> resume from its latest checkpoint
        checkpoints = _numbered_checkpoints(OUTPUT_DIR)
        if checkpoints:
            latest_checkpoint = max(checkpoints, key=checkpoints.get)
            resume_from_checkpoint = os.path.join(OUTPUT_DIR, latest_checkpoint)
            print(f"üìå Resuming batch {batch_idx + 1} from checkpoint: {resume_from_checkpoint}")
    else:
        # CASE 2: new chunk -> delete checkpoints left by the previous chunk. Each chunk builds
        # a new Trainer below, so its checkpoint numbering starts over.
        if batch_idx > 0:
            for ckpt, ckpt_num in _numbered_checkpoints(OUTPUT_DIR).items():
                if ckpt_num <= last_batch_max_checkpoint:
                    try:
                        shutil.rmtree(os.path.join(OUTPUT_DIR, ckpt))
                        print(f"üßπ Cleaned up old checkpoint: {ckpt}")
                    except OSError as err:
                        # Best effort: report and keep training rather than abort the run.
                        print(f"Could not remove old checkpoint {ckpt}: {err}")
        print(f"üÜï Starting fresh training for batch {batch_idx + 1}")

    # Mark that we've started this chunk (for crash recovery)
    _save_progress(batch_in_flight=True)

    training_args = Seq2SeqTrainingArguments(
        output_dir=OUTPUT_DIR,
        num_train_epochs=EPOCHS,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        learning_rate=LEARNING_RATE,
        weight_decay=0.01,
        save_total_limit=2,
        eval_strategy="epoch",
        save_strategy="steps",
        save_steps=SAVE_EVERY_N_STEPS,
        load_best_model_at_end=False,
        logging_steps=50,
        push_to_hub=False,
    )

    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset_batch,
        eval_dataset=val_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
        callbacks=[FrequentCheckpointCallback(save_steps=SAVE_EVERY_N_STEPS)]
    )

    print(f"Starting training for batch {batch_idx + 1}...")
    try:
        train_result = trainer.train(resume_from_checkpoint=resume_from_checkpoint)
    except Exception as e:
        # The in-flight marker written above lets the next run resume this chunk.
        print(f"\n‚ùå Error during training: {e}")
        print(f"Progress saved. Restart to resume from checkpoint.")
        raise

    # Highest checkpoint step produced by this chunk (used for cleanup at the next chunk)
    checkpoints = _numbered_checkpoints(OUTPUT_DIR)
    if checkpoints:
        last_batch_max_checkpoint = max(checkpoints.values())

    completed_batches = batch_idx + 1
    # Clear the in-flight flag. Without this, after resuming a crashed chunk every later chunk
    # in this run would take CASE 1 above, "resume" from the previous chunk's final checkpoint,
    # and skip its own cleanup.
    current_batch_started = False

    # Collect training / evaluation losses from this chunk
    all_losses.extend(
        log for log in trainer.state.log_history if "loss" in log or "eval_loss" in log
    )

    # Save progress: chunk complete, nothing in flight
    _save_progress(batch_in_flight=False)

    print(f"\n‚úÖ Batch {batch_idx + 1}/{num_batches} complete!")
    print(f"Progress saved. Safe to stop/restart anytime.\n")

In [None]:
# Configuration to train everything all at once (alternative to the chunked loop above).
# These assignments overwrite the configuration-cell globals. preprocess_function reads
# MAX_INPUT_LENGTH / MAX_TARGET_LENGTH when it runs, so the new limits apply below.
import random

MODEL_NAME = "google/flan-t5-small"
DATA_PATH = "sft_data/text_to_sql_sampled_balanced.jsonl"
OUTPUT_DIR = "./flan-t5-sql-model"
MAX_INPUT_LENGTH = 512
MAX_TARGET_LENGTH = 256
BATCH_SIZE = 8
EPOCHS = 10
LEARNING_RATE = 3e-4
SPLIT_SEED = 42  # fixes the train/val/test split across runs

# Load all data at once (JSON Lines; blank lines skipped)
with open(DATA_PATH, 'r', encoding='utf-8') as f:
    all_data = [json.loads(line) for line in f if line.strip()]

# Split 80/10/10 after a *seeded* shuffle so the held-out test set is the same on every run;
# an unseeded shuffle silently produced a different test set each time.
# NOTE(review): `model` below is whatever earlier cells left in memory (possibly already
# fine-tuned on the chunked split), so some of these test examples may have been trained on.
# Reload MODEL_NAME first if an unbiased test score matters.
random.Random(SPLIT_SEED).shuffle(all_data)
train_size = int(0.8 * len(all_data))
val_size = int(0.1 * len(all_data))

train_data = all_data[:train_size]
val_data = all_data[train_size:train_size+val_size]
test_data = all_data[train_size+val_size:]

print(f"Train: {len(train_data)}, Val: {len(val_data)}, Test: {len(test_data)}")

# Create datasets
train_dataset = create_dataset(train_data)
train_dataset = train_dataset.map(preprocess_function, batched=True, remove_columns=train_dataset.column_names)

val_dataset = create_dataset(val_data)
val_dataset = val_dataset.map(preprocess_function, batched=True, remove_columns=val_dataset.column_names)

test_dataset = create_dataset(test_data)
test_dataset = test_dataset.map(preprocess_function, batched=True, remove_columns=test_dataset.column_names)

# Training arguments: evaluate and save every epoch, restore the lowest-eval-loss weights at the end
training_args = Seq2SeqTrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    learning_rate=LEARNING_RATE,
    weight_decay=0.01,
    save_total_limit=2,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    logging_steps=50,
    push_to_hub=False,
)

# Data collator
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

# Trainer
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

# Train
print("Starting training...")
train_result = trainer.train()

# Save final model
trainer.save_model(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)
print(f"‚úÖ Model saved to {OUTPUT_DIR}")

In [None]:
# Wrap-up: save the final weights and report the loss on the held-out test set
print("\n" + "=" * 80)
print("ALL BATCHES COMPLETE!")
print("=" * 80 + "\n")

# Save final model (weights, then tokenizer) into OUTPUT_DIR
print("Saving final model...")
for artifact in (model, tokenizer):
    artifact.save_pretrained(OUTPUT_DIR)
print(f"‚úÖ Model saved to {OUTPUT_DIR}\n")

print("--- Evaluating on Test Set ---")

# Evaluation-only arguments: no periodic evaluation, just the batch size for the test pass
eval_args = Seq2SeqTrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_eval_batch_size=BATCH_SIZE,
    eval_strategy="no",
)

trainer = Seq2SeqTrainer(model=model, args=eval_args, tokenizer=tokenizer, data_collator=data_collator)
test_results = trainer.evaluate(eval_dataset=test_dataset)
print(f"Test Loss: {test_results['eval_loss']:.4f}\n")


In [None]:
# Plot the training and validation losses collected by the chunked training loop (all_losses).
# The x axis is the position of each logged entry, not the Trainer's `step`: every training
# chunk uses a new Trainer, so its step numbering starts over.
print("--- Generating Loss Plots ---")
train_losses = [log["loss"] for log in all_losses if "loss" in log]
eval_losses = [log["eval_loss"] for log in all_losses if "eval_loss" in log]

# (values, x label, y label, title, line style) for each non-empty series
loss_panels = [
    panel for panel in (
        (train_losses, 'Logged training entry (every logging_steps)', 'Training Loss',
         'Training Loss Over Time', dict(marker='o', linewidth=1, markersize=3)),
        (eval_losses, 'Evaluation (one per epoch)', 'Validation Loss',
         'Validation Loss Over Time', dict(marker='s', linewidth=2, markersize=6, color='orange')),
    )
    if panel[0]
]

if loss_panels:
    fig, axes = plt.subplots(1, len(loss_panels), figsize=(12, 6), squeeze=False)
    for ax, (values, xlabel, ylabel, title, style) in zip(axes[0], loss_panels):
        ax.plot(values, **style)
        ax.set(xlabel=xlabel, ylabel=ylabel, title=title)
        ax.grid(True, alpha=0.3)
    fig.tight_layout()
    plt.show()
else:
    print("No losses recorded; nothing to plot.")

In [None]:
# TODO: check that the train, test, and val sets contain unique queries (no overlap between splits); address this in the data generator.

In [None]:
# Show sample predictions: run the model on up to 10 reproducibly sampled test examples,
# execute both the predicted and the reference SQL against customer_data.db, and compare.
print("\n--- Sample Predictions from Test Set ---")
import random
import sqlite3
import pandas as pd

# Connect to database (closed in the `finally` below, even if generation fails)
conn = sqlite3.connect("customer_data.db")

random.seed(42)
sample_indices = random.sample(range(len(test_dataset)), min(10, len(test_dataset)))

try:
    for idx in sample_indices:
        sample = test_dataset[idx]
        input_ids = torch.tensor([sample["input_ids"]]).to(model.device)
        outputs = model.generate(input_ids, max_length=MAX_TARGET_LENGTH)
        predicted_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
        actual_sql = tokenizer.decode(sample["labels"], skip_special_tokens=True)

        # DECODE the operators back to SQL format
        predicted_sql = decode_sql_from_model(predicted_sql)
        actual_sql = decode_sql_from_model(actual_sql)

        nl_question = test_data[idx]["instruction"]

        print(f"\n{'='*80}")
        print(f"Example {idx}:")
        print(f"\nNatural Language: {nl_question}")
        print(f"\nPredicted SQL: {predicted_sql}")
        print(f"\nActual SQL:    {actual_sql}")

        # Reset per example. Otherwise a failed prediction would be compared against the
        # previous example's results, or raise NameError on the very first example (which was
        # then misreported as "Actual query failed").
        result_df = None
        actual_result_df = None

        # Try to execute predicted SQL
        try:
            result_df = pd.read_sql_query(predicted_sql, conn)
            print(f"\n‚úÖ Predicted Query Results ({len(result_df)} rows):")
            if len(result_df) > 0:
                # Show first 5 rows
                print(result_df.head().to_string(index=False))
                if len(result_df) > 5:
                    print(f"... (showing 5 of {len(result_df)} rows)")
            else:
                print("  No results returned")
        except Exception as e:
            print(f"\n‚ùå Predicted query failed: {str(e)}")

        # Try to execute actual SQL for comparison
        try:
            actual_result_df = pd.read_sql_query(actual_sql, conn)
            print(f"\nüìã Actual Query Results ({len(actual_result_df)} rows):")
            if len(actual_result_df) > 0:
                print(actual_result_df.head().to_string(index=False))
                if len(actual_result_df) > 5:
                    print(f"... (showing 5 of {len(actual_result_df)} rows)")
            else:
                print("  No results returned")
        except Exception as e:
            print(f"\n‚ùå Actual query failed: {str(e)}")

        # Compare results (only meaningful when the reference query ran)
        if actual_result_df is not None:
            if result_df is None:
                print("\nMISMATCH: predicted query failed but the actual query succeeded")
            elif result_df.equals(actual_result_df):
                print("\n‚úÖ MATCH: Predicted and actual queries return identical results!")
            else:
                print(f"\n‚ö†Ô∏è  MISMATCH: Different results (Predicted: {len(result_df)} rows, Actual: {len(actual_result_df)} rows)")

        print(f"{'='*80}")
finally:
    conn.close()
print("\n‚úÖ All training complete!")




In [None]:
# Show sample predictions (only for queries containing < or >): same report as the previous
# cell, restricted to test examples whose reference SQL uses an angle-bracket comparison.
print("\n--- Sample Predictions from Test Set (Filtered for < or >) ---")
import random
import sqlite3
import pandas as pd

# Connect to database (closed in the `finally` below)
conn = sqlite3.connect("customer_data.db")

try:
    # Decode every reference query once, mapping operator placeholders back to SQL symbols
    reference_sql = [
        decode_sql_from_model(tokenizer.decode(row["labels"], skip_special_tokens=True))
        for row in test_dataset
    ]
    filtered_indices = [i for i, sql in enumerate(reference_sql) if "<" in sql or ">" in sql]

    if not filtered_indices:
        print("‚ö†Ô∏è No queries with < or > found in test set.")
    else:
        random.seed(42)
        sample_indices = random.sample(filtered_indices, min(10, len(filtered_indices)))

        for idx in sample_indices:
            sample = test_dataset[idx]
            input_ids = torch.tensor([sample["input_ids"]]).to(model.device)
            outputs = model.generate(input_ids, max_length=MAX_TARGET_LENGTH)
            # Decode operators back to SQL format
            predicted_sql = decode_sql_from_model(tokenizer.decode(outputs[0], skip_special_tokens=True))
            actual_sql = reference_sql[idx]

            nl_question = test_data[idx]["instruction"]

            print(f"\n{'='*80}")
            print(f"Example {idx}:")
            print(f"\nNatural Language: {nl_question}")
            print(f"\nPredicted SQL: {predicted_sql}")
            print(f"\nActual SQL:    {actual_sql}")

            # Reset per example so a failed prediction is never compared against stale results
            result_df = None
            actual_result_df = None

            # Try to execute predicted SQL
            try:
                result_df = pd.read_sql_query(predicted_sql, conn)
                print(f"\n‚úÖ Predicted Query Results ({len(result_df)} rows):")
                if len(result_df) > 0:
                    print(result_df.head().to_string(index=False))
                    if len(result_df) > 5:
                        print(f"... (showing 5 of {len(result_df)} rows)")
                else:
                    print("  No results returned")
            except Exception as e:
                print(f"\n‚ùå Predicted query failed: {str(e)}")

            # Try to execute actual SQL for comparison
            try:
                actual_result_df = pd.read_sql_query(actual_sql, conn)
                print(f"\nüìã Actual Query Results ({len(actual_result_df)} rows):")
                if len(actual_result_df) > 0:
                    print(actual_result_df.head().to_string(index=False))
                    if len(actual_result_df) > 5:
                        print(f"... (showing 5 of {len(actual_result_df)} rows)")
                else:
                    print("  No results returned")
            except Exception as e:
                print(f"\n‚ùå Actual query failed: {str(e)}")

            # Compare results (only meaningful when the reference query ran)
            if actual_result_df is not None:
                if result_df is None:
                    print("\nMISMATCH: predicted query failed but the actual query succeeded")
                elif result_df.equals(actual_result_df):
                    print("\n‚úÖ MATCH: Predicted and actual queries return identical results!")
                else:
                    print(f"\n‚ö†Ô∏è  MISMATCH: Different results (Predicted: {len(result_df)} rows, Actual: {len(actual_result_df)} rows)")

            print(f"{'='*80}")
finally:
    conn.close()
print("\n‚úÖ All training complete!")


In [None]:
# Comprehensive test-set analysis with query-type labeling: categorize every reference
# query, run the model on it, and compare both the SQL text and the executed results.
import random
import re
import sqlite3

import pandas as pd

print("\n" + "=" * 80)
print("COMPREHENSIVE TEST SET ANALYSIS")
print("=" * 80)

# Connect to database (closed after the failure analysis below)
conn = sqlite3.connect("customer_data.db")

# Helper function to categorize queries.
# Keyword patterns use word boundaries and flexible whitespace, so identifiers such as
# `withdrawal` or `joined_at` are not mistaken for WITH / JOIN, and `CASE\n  WHEN` still counts.
_QUERY_KEYWORD_PATTERNS = [
    ('CTE', r'\bWITH\b'),
    ('CASE_WHEN', r'\bCASE\s+WHEN\b'),
    ('JOIN', r'\bJOIN\b'),
    ('GROUP_BY', r'\bGROUP\s+BY\b'),
    ('ORDER_BY', r'\bORDER\s+BY\b'),
    ('WINDOW_FUNCTION', r'\bROW_NUMBER\s*\(\s*\)|\bPARTITION\s+BY\b'),
    ('AVG', r'\bAVG\s*\('),
    ('MAX', r'\bMAX\s*\('),
    ('MIN', r'\bMIN\s*\('),
    ('SUM', r'\bSUM\s*\('),
    ('COUNT', r'\bCOUNT\s*\('),
]
# Comparison operators (case-sensitive). Both the SQL symbols and the training placeholders
# are recognized. Strict `<` / `>` exclude `<=`, `>=` and `<>`, so a `<=` query is no longer
# also counted as LESS_THAN_OP.
_QUERY_OPERATOR_PATTERNS = [
    ('LESS_THAN_OP', r'\bLESS_THAN\b|<(?![=>])'),
    ('GREATER_THAN_OP', r'\bGREATER_THAN\b|(?<!<)>(?!=)'),
    ('LESS_EQUAL_OP', r'\bLESS_EQUAL\b|<='),
    ('GREATER_EQUAL_OP', r'\bGREATER_EQUAL\b|>='),
]


def categorize_query(sql):
    """Categorize a SQL query by the constructs it uses and a coarse complexity level.

    Args:
        sql: SQL text; comparison operators may be symbols or LESS_THAN-style placeholders.
    Returns:
        ``(categories, complexity)``. ``categories`` is the list of matched labels in a
        fixed order. When nothing matches, it holds exactly one fallback label instead:
        SIMPLE_FILTER (``SELECT * FROM ... WHERE``), DISTINCT, or SIMPLE_SELECT.
        ``complexity`` is 'SIMPLE' (1 label), 'MEDIUM' (2) or 'COMPLEX' (3+).
    """
    categories = [label for label, pattern in _QUERY_KEYWORD_PATTERNS
                  if re.search(pattern, sql, re.IGNORECASE)]
    categories += [label for label, pattern in _QUERY_OPERATOR_PATTERNS
                   if re.search(pattern, sql)]

    if not categories:
        # A bare `SELECT COUNT(*)` never reaches here because it already matched COUNT
        # above, so the former SIMPLE_COUNT branch was unreachable and has been dropped.
        if (re.search(r'\bSELECT\s+\*\s+FROM\b', sql, re.IGNORECASE)
                and re.search(r'\bWHERE\b', sql, re.IGNORECASE)):
            categories.append('SIMPLE_FILTER')
        elif re.search(r'\bSELECT\s+DISTINCT\b', sql, re.IGNORECASE):
            categories.append('DISTINCT')
        else:
            categories.append('SIMPLE_SELECT')

    if len(categories) >= 3:
        complexity = 'COMPLEX'
    elif len(categories) == 2:
        complexity = 'MEDIUM'
    else:
        complexity = 'SIMPLE'
    return categories, complexity

# Helper function to check whether two queries match textually
def queries_match_semantically(predicted, actual):
    """Compare two SQL strings as text (despite the name, this is not a semantic check).

    Returns:
        ``(True, "EXACT_MATCH")`` if equal after stripping outer whitespace,
        ``(True, "WHITESPACE_DIFF")`` if equal after collapsing all whitespace runs,
        ``(False, "DIFFERENT")`` otherwise.
    """
    if predicted.strip() == actual.strip():
        return True, "EXACT_MATCH"

    same_modulo_whitespace = ' '.join(predicted.split()) == ' '.join(actual.split())
    return (True, "WHITESPACE_DIFF") if same_modulo_whitespace else (False, "DIFFERENT")

# Helper function to try executing both queries and compare their results
def execute_and_compare(predicted_sql, actual_sql, conn):
    """Run both queries on ``conn`` and compare the resulting DataFrames.

    Returns:
        ``(label, matched)`` where ``matched`` is True only for ``"RESULTS_MATCH"``
        (same row count and ``DataFrame.equals``). Other labels are
        ``"RESULTS_DIFFER"``, ``"ROW_COUNT_DIFF (p vs a)"``, ``"BOTH_FAILED"``,
        ``"PRED_FAILED: <first 50 chars of error>"`` and ``"ACTUAL_FAILED: <...>"``.
    """
    def run(sql):
        # (frame, None) on success, (None, error message) on any failure
        try:
            return pd.read_sql_query(sql, conn), None
        except Exception as exc:
            return None, str(exc)

    pred_result, pred_error = run(predicted_sql)
    actual_result, actual_error = run(actual_sql)

    if pred_error is not None and actual_error is not None:
        return "BOTH_FAILED", False
    if pred_error is not None:
        return f"PRED_FAILED: {pred_error[:50]}", False
    if actual_error is not None:
        return f"ACTUAL_FAILED: {actual_error[:50]}", False
    if len(pred_result) != len(actual_result):
        return f"ROW_COUNT_DIFF ({len(pred_result)} vs {len(actual_result)})", False
    if pred_result.equals(actual_result):
        return "RESULTS_MATCH", True
    return "RESULTS_DIFFER", False

# Analyze all test samples: generate a prediction for each one and compare it with the
# reference SQL both as text and by executing the two queries.
random.seed(42)
all_indices = list(range(len(test_dataset)))

results = []

print("\nAnalyzing all test samples...")
for idx in all_indices:
    sample = test_dataset[idx]
    input_ids = torch.tensor([sample["input_ids"]]).to(model.device)
    outputs = model.generate(input_ids, max_length=MAX_TARGET_LENGTH)

    # Token ids -> text -> SQL with real comparison operators
    predicted_sql, actual_sql = (
        decode_sql_from_model(tokenizer.decode(token_ids, skip_special_tokens=True))
        for token_ids in (outputs[0], sample["labels"])
    )
    nl_question = test_data[idx]["instruction"]

    actual_categories, actual_complexity = categorize_query(actual_sql)
    predicted_categories, predicted_complexity = categorize_query(predicted_sql)
    is_match, match_type = queries_match_semantically(predicted_sql, actual_sql)
    exec_result, results_match = execute_and_compare(predicted_sql, actual_sql, conn)

    results.append(dict(
        idx=idx,
        question=nl_question,
        predicted_sql=predicted_sql,
        actual_sql=actual_sql,
        actual_categories=actual_categories,
        actual_complexity=actual_complexity,
        predicted_categories=predicted_categories,
        is_exact_match=is_match,
        match_type=match_type,
        exec_result=exec_result,
        results_match=results_match,
    ))

# Convert per-sample results to a DataFrame for aggregate reporting
df_results = pd.DataFrame(results)

print("\n" + "=" * 80)
print("SUMMARY STATISTICS")
print("=" * 80)

total = len(df_results)
exact_matches = df_results['is_exact_match'].sum()
result_matches = df_results['results_match'].sum()

print(f"\nTotal test samples: {total}")
for match_label, match_count in (("Exact SQL matches", exact_matches),
                                 ("Semantic matches (same results)", result_matches)):
    print(f"{match_label}: {match_count} ({100*match_count/total:.1f}%)")

print("\n" + "-" * 80)
print("ACCURACY BY QUERY COMPLEXITY")
print("-" * 80)

# Execution accuracy per complexity bucket (empty buckets are skipped)
for complexity in ('SIMPLE', 'MEDIUM', 'COMPLEX'):
    subset = df_results.loc[df_results['actual_complexity'] == complexity]
    if subset.empty:
        continue
    acc = subset['results_match'].sum()
    print(f"{complexity:12} : {acc}/{len(subset)} ({100*acc/len(subset):.1f}%)")

print("\n" + "-" * 80)
print("ACCURACY BY QUERY TYPE")
print("-" * 80)

# Every category label that appears on at least one reference query
all_categories = set().union(*df_results['actual_categories'])

# Execution accuracy per category (a query counts toward every category it carries)
category_stats = []
for cat in sorted(all_categories):
    subset = df_results[df_results['actual_categories'].apply(lambda x: cat in x)]
    if subset.empty:
        continue
    correct = subset['results_match'].sum()
    total_cat = len(subset)
    accuracy = 100 * correct / total_cat
    category_stats.append(dict(category=cat, total=total_cat, correct=correct, accuracy=accuracy))

df_cat_stats = pd.DataFrame(category_stats).sort_values('accuracy')
print(df_cat_stats.to_string(index=False))

print("\n" + "-" * 80)
print("LOWEST PERFORMING QUERY TYPES (Need More Training Data)")
print("-" * 80)

# Categories under 85% execution accuracy, weakest first
worst_performing = df_cat_stats.loc[df_cat_stats['accuracy'] < 85].sort_values('accuracy')
if worst_performing.empty:
    print("‚úÖ All query types performing above 85%!")
else:
    print(worst_performing.to_string(index=False))
    print("\n‚ö†Ô∏è  Focus training data generation on these query types!")

print("\n" + "-" * 80)
print("SAMPLE FAILURES BY TYPE")
print("-" * 80)

# For each of the (up to 5) weakest categories, show its first failing example
for _, row in worst_performing.head(5).iterrows():
    cat = row['category']
    print(f"\nüìå Category: {cat} (Accuracy: {row['accuracy']:.1f}%)")

    in_category = df_results['actual_categories'].apply(lambda x: cat in x)
    failures = df_results[in_category & (~df_results['results_match'])]
    if failures.empty:
        continue

    example = failures.iloc[0]
    print(f"   Question: {example['question'][:80]}...")
    print(f"   Expected: {example['actual_sql'][:100]}...")
    print(f"   Got:      {example['predicted_sql'][:100]}...")
    print(f"   Issue:    {example['exec_result']}")

print("\n" + "=" * 80)
print("DETAILED FAILURE ANALYSIS")
print("=" * 80)

failures = df_results[~df_results['results_match']].copy()
print(f"\nTotal failures: {len(failures)}")

if len(failures) > 0:
    print("\nFailure reasons:")
    failure_reasons = failures['exec_result'].value_counts()
    for reason, count in failure_reasons.items():
        print(f"  {reason}: {count}")

    print("\n" + "-" * 80)
    print("SHOWING 10 RANDOM FAILURE EXAMPLES")
    print("-" * 80)

    # Reproducible sample of at most 10 failures
    sample_failures = failures.sample(min(10, len(failures)), random_state=42)

    for i, (_, row) in enumerate(sample_failures.iterrows(), start=1):
        separator = "=" * 80
        print(f"\n{separator}")
        print(f"FAILURE EXAMPLE {i} (Categories: {', '.join(row['actual_categories'])})")
        print(separator)
        print(f"Question: {row['question']}")
        print(f"\nExpected SQL:\n{row['actual_sql']}")
        print(f"\nPredicted SQL:\n{row['predicted_sql']}")
        print(f"\nIssue: {row['exec_result']}")
        print(f"Complexity: {row['actual_complexity']}")

conn.close()

# Persist per-example results and per-category statistics for later inspection
output_csv = "test_set_analysis.csv"
cat_stats_csv = "category_performance.csv"

df_results.to_csv(output_csv, index=False)
print(f"\n‚úÖ Detailed results saved to {output_csv}")

df_cat_stats.to_csv(cat_stats_csv, index=False)
print(f"‚úÖ Category statistics saved to {cat_stats_csv}")

In [None]:
# ==========================================
# RETRAIN EXISTING MODEL WITH ENHANCED DATA
# ==========================================

import json
import torch
import matplotlib.pyplot as plt
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer, EarlyStoppingCallback, TrainerCallback
import os
import random

# Configuration
MODEL_DIR = "./flan-t5-sql-model"  # Previously trained model to continue from
ENHANCED_DATA_PATH = "sft_data/text_to_sql_enhanced.jsonl"
OUTPUT_DIR = "./flan-t5-sql-model-v2"  # Save improved model here
MAX_INPUT_LENGTH = 512
MAX_TARGET_LENGTH = 256
BATCH_SIZE = 8
EPOCHS = 5  # NOTE(review): same epoch count as the initial training run (also 5), not fewer
# NOTE(review): this is 2x the initial run's 5e-5, i.e. a HIGHER learning rate.
# Reduce it if gentle continued fine-tuning is the intent.
LEARNING_RATE = 1e-4
SAVE_EVERY_N_STEPS = 200  # also used as eval_steps so evaluations line up with checkpoints
# Split parameters. The evaluation cell must reuse these exact values to select
# the same held-out rows; otherwise it ends up evaluating training data.
SPLIT_SEED = 42
TRAIN_FRACTION = 0.8
VAL_FRACTION = 0.1

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load your EXISTING trained model
print("\n" + "="*80)
print("LOADING EXISTING TRAINED MODEL")
print("="*80)
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_DIR)
model.to(device)
print(f"‚úÖ Loaded model from {MODEL_DIR}")

# Load enhanced dataset (one JSON object per line; blank lines are skipped so a
# trailing newline does not raise JSONDecodeError)
print("\n" + "="*80)
print("LOADING ENHANCED DATASET")
print("="*80)
with open(ENHANCED_DATA_PATH, 'r', encoding='utf-8') as f:
    all_data = [json.loads(line) for line in f if line.strip()]

print(f"Total samples: {len(all_data)}")

# Shuffle deterministically, then split train / val / test by fraction
random.seed(SPLIT_SEED)
random.shuffle(all_data)
train_size = int(TRAIN_FRACTION * len(all_data))
val_size = int(VAL_FRACTION * len(all_data))

train_data = all_data[:train_size]
val_data = all_data[train_size:train_size+val_size]
test_data = all_data[train_size+val_size:]

# Evaluation runs every SAVE_EVERY_N_STEPS on val_data and the test split is
# evaluated at the end; fail early with a clear message if any split is empty.
assert train_data and val_data and test_data, (
    f"Empty split (train={len(train_data)}, val={len(val_data)}, test={len(test_data)}); "
    "the dataset is too small for the configured fractions"
)

print(f"  Train: {len(train_data)} samples")
print(f"  Val: {len(val_data)} samples")
print(f"  Test: {len(test_data)} samples")

# Helper functions
def create_dataset(data_list):
    """Build a `datasets.Dataset` with "instruction", "input" and "output" columns.

    Args:
        data_list: sequence of dicts. "instruction" and "output" are required
            (KeyError otherwise); a missing "input" becomes "".

    Returns:
        The Dataset produced by `Dataset.from_dict` over the three columns.
    """
    instructions = [rec["instruction"] for rec in data_list]
    extra_inputs = [rec.get("input", "") for rec in data_list]
    targets = [rec["output"] for rec in data_list]
    return Dataset.from_dict({"instruction": instructions, "input": extra_inputs, "output": targets})

def preprocess_function(examples):
    """Tokenize a batch of instruction/input/output records for seq2seq training.

    Each prompt is "Translate this to SQL: {instruction} {input}" (stripped).
    Prompts are tokenized to MAX_INPUT_LENGTH and targets to MAX_TARGET_LENGTH,
    both truncated and padded to their max length.

    Target positions holding the pad token are replaced with -100 so the
    cross-entropy loss ignores them. DataCollatorForSeq2Seq only applies -100
    to padding it adds itself, and these labels are already fully padded here.
    Without the mask, the loss would be dominated by learning to predict padding.

    Args:
        examples: batched columns from `Dataset.map(batched=True)` with keys
            "instruction", "input" and "output" (lists of str).

    Returns:
        The tokenizer's encoding of the prompts with an added "labels" entry.
    """
    inputs = [
        f"Translate this to SQL: {instr} {inp}".strip()
        for instr, inp in zip(examples["instruction"], examples["input"])
    ]

    model_inputs = tokenizer(inputs, max_length=MAX_INPUT_LENGTH, truncation=True, padding="max_length")
    labels = tokenizer(examples["output"], max_length=MAX_TARGET_LENGTH, truncation=True, padding="max_length")

    pad_id = tokenizer.pad_token_id
    model_inputs["labels"] = [
        [token if token != pad_id else -100 for token in sequence]
        for sequence in labels["input_ids"]
    ]
    return model_inputs

# Prepare datasets: wrap each split in a Dataset, then tokenize it and drop the raw text columns.
print("\n" + "="*80)
print("PREPARING DATASETS")
print("="*80)

tokenized_splits = []
for split_records in (train_data, val_data, test_data):
    raw_split = create_dataset(split_records)
    tokenized_splits.append(
        raw_split.map(preprocess_function, batched=True, remove_columns=raw_split.column_names)
    )
train_dataset, val_dataset, test_dataset = tokenized_splits

print("‚úÖ Datasets prepared")

# Callback that requests a checkpoint save on a fixed step cadence.
# NOTE(review): the Seq2SeqTrainingArguments below already use save_strategy="steps"
# with the same save_steps, so the Trainer saves on these steps anyway.
class FrequentCheckpointCallback(TrainerCallback):
    """Set `control.should_save` whenever the global step is a multiple of `save_steps`."""

    def __init__(self, save_steps):
        self.save_steps = save_steps

    def on_step_end(self, args, state, control, **kwargs):
        """Flag a save on cadence steps; otherwise return `control` unchanged."""
        if state.global_step % self.save_steps:
            return control
        control.should_save = True
        return control

# Training arguments, grouped by concern
training_args = Seq2SeqTrainingArguments(
    # Output / persistence
    output_dir=OUTPUT_DIR,
    push_to_hub=False,
    save_safetensors=False,
    # Optimisation
    num_train_epochs=EPOCHS,
    learning_rate=LEARNING_RATE,
    weight_decay=0.01,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    # Evaluation and checkpointing share one cadence: load_best_model_at_end
    # requires the eval and save schedules to line up.
    eval_strategy="steps",
    eval_steps=SAVE_EVERY_N_STEPS,
    save_strategy="steps",
    save_steps=SAVE_EVERY_N_STEPS,
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    # Logging
    logging_steps=50,
)

# Data collator (pads batches and builds decoder inputs for the seq2seq model)
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

# Trainer
# NOTE(review): newer transformers releases deprecate `tokenizer=` in favour of `processing_class=`.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    callbacks=[FrequentCheckpointCallback(save_steps=SAVE_EVERY_N_STEPS)],
)

# Train (continue training)
print("\n" + "="*80)
print("üöÄ STARTING CONTINUED TRAINING WITH ENHANCED DATASET")
print("="*80)
print(f"Training on {len(train_dataset)} samples for {EPOCHS} epochs")
print(f"Learning rate: {LEARNING_RATE}")
print("="*80 + "\n")

train_result = trainer.train()

# Save final model (with load_best_model_at_end=True, the trainer holds the best
# checkpoint by eval_loss at this point)
print("\n" + "="*80)
print("SAVING ENHANCED MODEL")
print("="*80)
trainer.save_model(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)
print(f"‚úÖ Enhanced model saved to {OUTPUT_DIR}")

# Evaluate on the held-out test split. The "test" prefix makes the trainer log
# this as "test_loss" instead of appending another "eval_loss" entry to
# log_history, which would otherwise be mistaken for the final validation loss.
print("\n" + "="*80)
print("EVALUATING ON TEST SET")
print("="*80)
test_results = trainer.evaluate(test_dataset, metric_key_prefix="test")
test_loss = test_results["test_loss"]
print(f"Test Loss: {test_loss:.4f}")

# Show training history: keep the global step alongside each logged loss so
# curves can be plotted against a real x-axis.
print("\n" + "="*80)
print("TRAINING SUMMARY")
print("="*80)
logs = trainer.state.log_history
train_points = [(entry["step"], entry["loss"]) for entry in logs if "loss" in entry]
val_points = [(entry["step"], entry["eval_loss"]) for entry in logs if "eval_loss" in entry]
train_steps = [step for step, _ in train_points]
train_losses = [loss for _, loss in train_points]
val_steps = [step for step, _ in val_points]
eval_losses = [loss for _, loss in val_points]

if train_losses:
    print(f"Initial training loss: {train_losses[0]:.4f}")
    print(f"Final training loss: {train_losses[-1]:.4f}")
    print(f"Loss reduction: {train_losses[0] - train_losses[-1]:.4f}")

if eval_losses:
    print(f"\nInitial validation loss: {eval_losses[0]:.4f}")
    print(f"Final validation loss: {eval_losses[-1]:.4f}")
    print(f"Val loss reduction: {eval_losses[0] - eval_losses[-1]:.4f}")

print(f"\nTest loss: {test_loss:.4f}")

# Plot training curves. Training loss is logged every logging_steps and
# validation loss every eval_steps, so both are plotted against global step.
if train_losses and eval_losses:
    fig, (ax_train, ax_val) = plt.subplots(1, 2, figsize=(12, 5))

    ax_train.plot(train_steps, train_losses, label='Training Loss', marker='o')
    ax_train.set(xlabel='Step', ylabel='Loss', title='Training Loss Over Time')
    ax_train.legend()
    ax_train.grid(True, alpha=0.3)

    ax_val.plot(val_steps, eval_losses, label='Validation Loss', marker='s', color='orange')
    ax_val.set(xlabel='Step', ylabel='Loss', title='Validation Loss Over Time')
    ax_val.legend()
    ax_val.grid(True, alpha=0.3)

    fig.tight_layout()
    curves_path = os.path.join(OUTPUT_DIR, "training_curves.png")
    fig.savefig(curves_path, dpi=150, bbox_inches='tight')
    print(f"\n‚úÖ Training curves saved to {curves_path}")
    plt.show()

print("\n" + "="*80)
print("‚úÖ RETRAINING COMPLETE!")
print("="*80)
print(f"Original model: {MODEL_DIR}")
print(f"Enhanced model: {OUTPUT_DIR}")
print(f"\nExpected improvements:")
print("  ‚Ä¢ GREATER_EQUAL: 42% ‚Üí 75%+")
print("  ‚Ä¢ LESS_EQUAL: 67% ‚Üí 80%+")
print("  ‚Ä¢ CTE: 66% ‚Üí 78%+")
print("  ‚Ä¢ Overall: 90% ‚Üí 93%+")
print("\nRun your test evaluation script to verify improvements!")
print("="*80)

In [None]:
print("\n" + "="*80)
print("COMPREHENSIVE TEST SET ANALYSIS - UPDATED MODEL")
print("="*80)

import random
import sqlite3
import pandas as pd
import re
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from datasets import Dataset
import json

# Configuration
MODEL_DIR = "./flan-t5-sql-model-v2"  # Your newly trained model
TEST_DATA_PATH = "sft_data/text_to_sql_enhanced.jsonl"  # Same file the retraining cell split
DB_PATH = "customer_data.db"
MAX_TARGET_LENGTH = 256
MAX_INPUT_LENGTH = 512
# These MUST match the retraining cell's split (seed 42, 80% train, 10% val).
# The same file, seed and fractions reproduce the same permutation, so the rows
# after train+val are the held-out test split. Taking the first N shuffled rows
# instead would select training rows.
SPLIT_SEED = 42
TRAIN_FRACTION = 0.8
VAL_FRACTION = 0.1
MAX_TEST_SAMPLES = 200  # None = evaluate the whole held-out test split

# Connect to database read-only: model-generated SQL is executed against it, and
# a generated DROP/DELETE/UPDATE must not be able to modify the data.
conn = sqlite3.connect(f"file:{DB_PATH}?mode=ro", uri=True)

# Load model
print(f"\nLoading updated model from {MODEL_DIR}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_DIR)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()
print(f"‚úÖ Model loaded on {device}")

# Load test data: recreate the training cell's shuffle and keep only the held-out tail
print(f"\nLoading test data...")
with open(TEST_DATA_PATH, 'r', encoding='utf-8') as f:
    all_records = [json.loads(line) for line in f if line.strip()]

random.seed(SPLIT_SEED)
random.shuffle(all_records)
_n_train = int(TRAIN_FRACTION * len(all_records))
_n_val = int(VAL_FRACTION * len(all_records))
test_data = all_records[_n_train + _n_val:]
if MAX_TEST_SAMPLES is not None:
    test_data = test_data[:MAX_TEST_SAMPLES]
print(f"‚úÖ Using {len(test_data)} test samples")

# Prepare test dataset
def create_dataset(data_list):
    """Return a `datasets.Dataset` with instruction/input/output columns.

    "instruction" and "output" must be present in every record (KeyError
    otherwise); records without "input" get an empty string.
    """
    _required = object()

    def pluck(key, default=_required):
        # Collect one column across all records, in order.
        if default is _required:
            return [record[key] for record in data_list]
        return [record.get(key, default) for record in data_list]

    return Dataset.from_dict({
        "instruction": pluck("instruction"),
        "input": pluck("input", ""),
        "output": pluck("output"),
    })

def preprocess_function(examples):
    """Tokenize prompts and reference SQL for evaluation.

    Prompts are "Translate this to SQL: {instruction} {input}" (stripped),
    padded/truncated to MAX_INPUT_LENGTH. Targets are padded/truncated to
    MAX_TARGET_LENGTH. Labels keep the tokenizer's pad ids (not -100) so they
    can be passed straight to `tokenizer.decode`.
    """
    prompts = []
    for instruction, extra in zip(examples["instruction"], examples["input"]):
        prompts.append(f"Translate this to SQL: {instruction} {extra}".strip())

    encoded = tokenizer(prompts, max_length=MAX_INPUT_LENGTH, truncation=True, padding="max_length")
    target_ids = tokenizer(
        examples["output"], max_length=MAX_TARGET_LENGTH, truncation=True, padding="max_length"
    )["input_ids"]
    encoded["labels"] = target_ids
    return encoded

_raw_test_dataset = create_dataset(test_data)
test_dataset = _raw_test_dataset.map(preprocess_function, batched=True, remove_columns=_raw_test_dataset.column_names)

# Decode SQL operators
def decode_sql_from_model(sql):
    """Replace the model's word placeholders with SQL comparison operators.

    Only space-delimited placeholders are rewritten (e.g. "a LESS_EQUAL 5"
    becomes "a <= 5"); the replacements are applied in the order listed.
    """
    placeholder_to_operator = (
        (' LESS_EQUAL ', ' <= '),
        (' GREATER_EQUAL ', ' >= '),
        (' LESS_THAN ', ' < '),
        (' GREATER_THAN ', ' > '),
    )
    for placeholder, operator in placeholder_to_operator:
        sql = sql.replace(placeholder, operator)
    return sql

# Helper function to categorize queries
def categorize_query(sql):
    """Categorize a SQL query by the constructs it uses, plus a coarse complexity.

    Keywords are matched as whole words, so e.g. a column named ``join_date`` or
    ``withdrawal`` no longer counts as JOIN / CTE.

    Comparison operators are matched exclusively:
    - ``<=`` counts only as LESS_EQUAL_OP and ``>=`` only as GREATER_EQUAL_OP.
    - ``<>`` counts as neither LESS_THAN_OP nor GREATER_THAN_OP.

    Previously every ``<=``/``>=`` query was also tagged ``<``/``>``, which
    inflated both the operator categories and the complexity. Both literal
    operators and the model's placeholder words (LESS_THAN, GREATER_EQUAL, ...)
    are recognized.

    Args:
        sql: query text.

    Returns:
        (categories, complexity): category labels in a fixed order, and
        'SIMPLE' (fewer than 2 labels), 'MEDIUM' (2) or 'COMPLEX' (3 or more).
    """
    sql_upper = sql.upper()
    categories = []

    def has(pattern, text=sql_upper):
        return re.search(pattern, text) is not None

    # Main query types
    if has(r'\bWITH\b'):
        categories.append('CTE')
    if has(r'\bCASE\s+WHEN\b'):
        categories.append('CASE_WHEN')
    if has(r'\bJOIN\b'):
        categories.append('JOIN')
    if has(r'\bGROUP\s+BY\b'):
        categories.append('GROUP_BY')
    if has(r'\bORDER\s+BY\b'):
        categories.append('ORDER_BY')
    if has(r'\bROW_NUMBER\s*\(\s*\)') or has(r'\bPARTITION\s+BY\b'):
        categories.append('WINDOW_FUNCTION')

    # Aggregation types (same order as before)
    for func in ('AVG', 'MAX', 'MIN', 'SUM', 'COUNT'):
        if has(rf'\b{func}\s*\('):
            categories.append(func)

    # Operators. Lookarounds keep '<' from matching inside '<=' / '<>' and '>'
    # from matching inside '>=' / '<>'.
    if has(r'\bLESS_THAN\b', sql) or has(r'<(?![=>])', sql):
        categories.append('LESS_THAN_OP')
    if has(r'\bGREATER_THAN\b', sql) or has(r'(?<!<)>(?!=)', sql):
        categories.append('GREATER_THAN_OP')
    if has(r'\bLESS_EQUAL\b', sql) or '<=' in sql:
        categories.append('LESS_EQUAL_OP')
    if has(r'\bGREATER_EQUAL\b', sql) or '>=' in sql:
        categories.append('GREATER_EQUAL_OP')

    # Simple queries. (A former 'SELECT COUNT(*)' -> SIMPLE_COUNT branch was
    # unreachable: any COUNT( already adds the COUNT category above.)
    if not categories:
        if 'SELECT * FROM' in sql_upper and 'WHERE' in sql_upper:
            categories.append('SIMPLE_FILTER')
        elif 'SELECT DISTINCT' in sql_upper:
            categories.append('DISTINCT')
        else:
            categories.append('SIMPLE_SELECT')

    # Complexity level
    if len(categories) >= 3:
        complexity = 'COMPLEX'
    elif len(categories) >= 2:
        complexity = 'MEDIUM'
    else:
        complexity = 'SIMPLE'

    return categories, complexity

# Helper function to check if queries match semantically
def queries_match_semantically(predicted, actual):
    """Compare two SQL strings textually (no SQL semantics are involved).

    Returns:
        (True, "EXACT_MATCH") if equal after stripping outer whitespace;
        (True, "WHITESPACE_DIFF") if equal once every whitespace run is collapsed;
        (False, "DIFFERENT") otherwise.
    """
    if predicted.strip() == actual.strip():
        return True, "EXACT_MATCH"
    # split() yields whitespace-free tokens, so equal token lists <=> equal collapsed strings
    if predicted.split() == actual.split():
        return True, "WHITESPACE_DIFF"
    return False, "DIFFERENT"

# Helper function to try executing and compare results
def execute_and_compare(predicted_sql, actual_sql, conn, ignore_column_names=True):
    """Execute both queries on `conn` and compare their result sets.

    `DataFrame.equals` also compares column labels. By default the predicted
    frame is therefore relabelled positionally with the reference's column
    names first (when the column counts agree). An alias-only difference such
    as ``COUNT(*)`` vs ``COUNT(*) AS total`` is then not reported as a wrong
    result. Values, dtypes, column count and row order must still agree.

    Args:
        predicted_sql: model-generated SQL.
        actual_sql: reference SQL.
        conn: open DB connection accepted by `pd.read_sql_query`.
        ignore_column_names: pass False for the strict, label-sensitive comparison.

    Returns:
        (status, matched). status is one of:
        - "RESULTS_MATCH", "RESULTS_DIFFER"
        - "ROW_COUNT_DIFF (p vs a)"
        - "BOTH_FAILED"
        - "PRED_FAILED: <err>" or "ACTUAL_FAILED: <err>" (error text cut to 50 chars)
        matched is True only for "RESULTS_MATCH".
    """
    def _run(sql):
        # (frame, None) on success, (None, error text) on any execution failure.
        try:
            return pd.read_sql_query(sql, conn), None
        except Exception as exc:  # invalid generated SQL is an expected outcome here
            return None, str(exc)

    pred_result, pred_error = _run(predicted_sql)
    actual_result, actual_error = _run(actual_sql)

    if pred_error is not None and actual_error is not None:
        return "BOTH_FAILED", False
    if pred_error is not None:
        return f"PRED_FAILED: {pred_error[:50]}", False
    if actual_error is not None:
        return f"ACTUAL_FAILED: {actual_error[:50]}", False

    if len(pred_result) != len(actual_result):
        return f"ROW_COUNT_DIFF ({len(pred_result)} vs {len(actual_result)})", False

    if ignore_column_names and pred_result.shape[1] == actual_result.shape[1]:
        pred_result = pred_result.copy()
        pred_result.columns = actual_result.columns

    if pred_result.equals(actual_result):
        return "RESULTS_MATCH", True
    return "RESULTS_DIFFER", False

# Analyze all test samples
random.seed(42)  # NOTE(review): nothing below in this cell draws from `random`
all_indices = list(range(len(test_dataset)))

results = []

print("\nAnalyzing all test samples...")
for idx in all_indices:
    if idx % 20 == 0:
        print(f"  Processing {idx}/{len(test_dataset)}...")

    sample = test_dataset[idx]
    # test_dataset was mapped from test_data without reordering, so idx indexes both.
    record = test_data[idx]
    input_ids = torch.tensor([sample["input_ids"]]).to(model.device)
    # Prompts were padded to MAX_INPUT_LENGTH; pass the mask explicitly so padding is ignored.
    attention_mask = torch.tensor([sample["attention_mask"]]).to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=MAX_TARGET_LENGTH,
        )

    predicted_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Reference SQL comes from the raw record rather than from decoding the
    # tokenized labels. The labels are truncated to MAX_TARGET_LENGTH, and a
    # tokenize/decode round trip may not reproduce the text exactly.
    # Whitespace runs are collapsed so the space-delimited placeholders handled
    # by decode_sql_from_model are still matched.
    actual_sql = ' '.join(record["output"].split())

    # Decode operators
    predicted_sql = decode_sql_from_model(predicted_sql)
    actual_sql = decode_sql_from_model(actual_sql)

    nl_question = record["instruction"]

    # Categorize
    actual_categories, actual_complexity = categorize_query(actual_sql)
    predicted_categories, predicted_complexity = categorize_query(predicted_sql)

    # Check if match
    is_match, match_type = queries_match_semantically(predicted_sql, actual_sql)

    # Execute and compare
    exec_result, results_match = execute_and_compare(predicted_sql, actual_sql, conn)

    results.append({
        'idx': idx,
        'question': nl_question,
        'predicted_sql': predicted_sql,
        'actual_sql': actual_sql,
        'actual_categories': actual_categories,
        'actual_complexity': actual_complexity,
        'predicted_categories': predicted_categories,
        'is_exact_match': is_match,
        'match_type': match_type,
        'exec_result': exec_result,
        'results_match': results_match
    })

# Convert to DataFrame for analysis
df_results = pd.DataFrame(results)

print("\n" + "=" * 80)
print("SUMMARY STATISTICS - UPDATED MODEL")
print("=" * 80)

total = len(df_results)
exact_matches = df_results['is_exact_match'].sum()
result_matches = df_results['results_match'].sum()
exact_pct = 100 * exact_matches / total
result_pct = 100 * result_matches / total

print(f"\nTotal test samples: {total}")
print(f"Exact SQL matches: {exact_matches} ({exact_pct:.1f}%)")
print(f"Semantic matches (same results): {result_matches} ({result_pct:.1f}%)")

print("\n" + "-" * 80)
print("ACCURACY BY QUERY COMPLEXITY")
print("-" * 80)

for level in ('SIMPLE', 'MEDIUM', 'COMPLEX'):
    level_rows = df_results[df_results['actual_complexity'] == level]
    n_rows = len(level_rows)
    if n_rows == 0:
        continue
    n_ok = level_rows['results_match'].sum()
    print(f"{level:12} : {n_ok}/{n_rows} ({100*n_ok/n_rows:.1f}%)")

print("\n" + "-" * 80)
print("ACCURACY BY QUERY TYPE")
print("-" * 80)

# Every category label that appears on at least one reference query
all_categories = {label for labels in df_results['actual_categories'] for label in labels}

category_stats = []
for cat in sorted(all_categories):
    in_category = df_results['actual_categories'].apply(lambda labels, wanted=cat: wanted in labels)
    subset = df_results[in_category]
    if len(subset) == 0:
        continue
    correct = subset['results_match'].sum()
    category_stats.append({
        'category': cat,
        'total': len(subset),
        'correct': correct,
        'accuracy': 100 * correct / len(subset),
    })

df_cat_stats = pd.DataFrame(category_stats).sort_values('accuracy')
print(df_cat_stats.to_string(index=False))

print("\n" + "-" * 80)
print("üéØ TARGET QUERY TYPES PERFORMANCE (After Retraining)")
print("-" * 80)

target_types = {
    'GREATER_EQUAL_OP': 'GREATER_EQUAL (>=)',
    'LESS_EQUAL_OP': 'LESS_EQUAL (<=)',
    'CTE': 'CTE Queries'
}

for key, label in target_types.items():
    matching_stats = df_cat_stats[df_cat_stats['category'] == key]
    if len(matching_stats) == 0:
        print(f"{label:25} : No samples in test set")
        continue
    stat = matching_stats.iloc[0]
    print(f"{label:25} : {stat['correct']}/{stat['total']} ({stat['accuracy']:.1f}%)")

print("\n" + "-" * 80)
print("LOWEST PERFORMING QUERY TYPES")
print("-" * 80)

# Categories below 85% accuracy, weakest first
worst_performing = df_cat_stats[df_cat_stats['accuracy'] < 85].sort_values('accuracy')
if len(worst_performing) == 0:
    print("‚úÖ All query types performing above 85%!")
else:
    print(worst_performing.to_string(index=False))
    print("\n‚ö†Ô∏è  These query types still need improvement!")

print("\n" + "-" * 80)
print("SAMPLE FAILURES BY TYPE")
print("-" * 80)

# First failing example for each of the (up to) five weakest categories
for _, row in worst_performing.head(5).iterrows():
    cat = row['category']
    print(f"\nüìå Category: {cat} (Accuracy: {row['accuracy']:.1f}%)")

    in_category = df_results['actual_categories'].apply(lambda labels, wanted=cat: wanted in labels)
    failures = df_results[in_category & ~df_results['results_match']]
    if len(failures) == 0:
        continue

    example = failures.iloc[0]
    print(f"   Question: {example['question'][:80]}...")
    print(f"   Expected: {example['actual_sql'][:100]}...")
    print(f"   Got:      {example['predicted_sql'][:100]}...")
    print(f"   Issue:    {example['exec_result']}")

separator = "=" * 80

print("\n" + separator)
print("DETAILED FAILURE ANALYSIS")
print(separator)

# All rows whose executed results did not match the reference
failures = df_results[~df_results['results_match']].copy()
print(f"\nTotal failures: {len(failures)}")

if len(failures) > 0:
    print("\nFailure reasons:")
    failure_reasons = failures['exec_result'].value_counts()
    for reason_text, reason_count in failure_reasons.items():
        print(f"  {reason_text}: {reason_count}")

    print("\n" + "-" * 80)
    print("SHOWING 10 RANDOM FAILURE EXAMPLES")
    print("-" * 80)

    sample_failures = failures.sample(min(10, len(failures)), random_state=42)

    for rank, (_, fail_row) in enumerate(sample_failures.iterrows(), 1):
        print(f"\n{separator}")
        print(f"FAILURE EXAMPLE {rank} (Categories: {', '.join(fail_row['actual_categories'])})")
        print(separator)
        print(f"Question: {fail_row['question']}")
        for heading, column in (("Expected SQL", 'actual_sql'), ("Predicted SQL", 'predicted_sql')):
            print(f"\n{heading}:\n{fail_row[column]}")
        print(f"\nIssue: {fail_row['exec_result']}")
        print(f"Complexity: {fail_row['actual_complexity']}")

conn.close()

# Save detailed results to CSV
output_csv = "updated_model_test_analysis.csv"
df_results.to_csv(output_csv, index=False)
print(f"\n‚úÖ Detailed results saved to {output_csv}")

# Save category statistics
cat_stats_csv = "updated_model_category_performance.csv"
df_cat_stats.to_csv(cat_stats_csv, index=False)
print(f"‚úÖ Category statistics saved to {cat_stats_csv}")

print("\n" + separator)
print("‚úÖ ANALYSIS COMPLETE FOR UPDATED MODEL!")
print(separator)
print(f"Model tested: {MODEL_DIR}")
print(f"Compare these results with your original model to see improvements!")
print(separator)

In [None]:
import sqlite3

# Inspect the customer_demographics schema (PRAGMA table_info returns one row per column).
conn = sqlite3.connect("customer_data.db")
cursor = conn.cursor()
table_columns = cursor.execute("PRAGMA table_info(customer_demographics);").fetchall()
print(table_columns)
