In [1]:
!pip install camel-tools --no-build-isolation --no-deps
!pip install transformers

Collecting camel-tools
  Downloading camel_tools-1.5.6-py3-none-any.whl.metadata (10 kB)
Downloading camel_tools-1.5.6-py3-none-any.whl (124 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m124.7/124.7 kB[0m [31m4.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: camel-tools
Successfully installed camel-tools-1.5.6


In [2]:
import numpy as np
import pandas as pd
import sqlite3
import os
import torch
import re
import sqlparse
from datasets import Dataset
from transformers import (
    T5ForConditionalGeneration,
    AutoTokenizer,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    DataCollatorForSeq2Seq,
    TrainerCallback
)
from camel_tools.utils.normalize import (
    normalize_unicode,
    normalize_alef_maksura_ar,
    normalize_teh_marbuta_ar
)
from camel_tools.utils.dediac import dediac_ar
from sklearn.metrics import accuracy_score
import warnings
warnings.filterwarnings("ignore")

2025-05-06 05:13:24.277441: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1746508404.526642      19 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1746508404.594465      19 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [3]:
# 1. Data Loading and Schema Extraction
def get_all_databases(folder_path):
    """Retrieve all SQLite database files from subdirectories."""
    db_files = []
    for root, _, files in os.walk(folder_path):
        for file in files:
            if file.endswith(".sqlite"):
                db_files.append(os.path.join(root, file))
    return db_files

def extract_schema(db_path):
    """Extract tables and columns from a SQLite database."""
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()
    
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
    tables = cursor.fetchall()
    
    schema = {}
    for table in tables:
        table_name = table[0]
        cursor.execute(f"PRAGMA table_info({table_name});")
        columns = cursor.fetchall()
        schema[table_name] = [col[1] for col in columns]  # col[1] is column name
    
    conn.close()
    return schema

In [4]:
# Load database schemas
folder_path = "/kaggle/input/txt-to-sql/Text To SQL Task/Dataset/database"
sqlite_files = get_all_databases(folder_path)
all_schemas = {db: extract_schema(db) for db in sqlite_files}

In [5]:
{
  "database1.sqlite": {
    "students": ["id", "name", "department"],
    "courses": ["id", "title", "credits"]
  },
  "database2.sqlite": {
    "employees": ["id", "name", "salary"],
    "departments": ["id", "name"]
  }
}

{'database1.sqlite': {'students': ['id', 'name', 'department'],
  'courses': ['id', 'title', 'credits']},
 'database2.sqlite': {'employees': ['id', 'name', 'salary'],
  'departments': ['id', 'name']}}

# DATA

In [6]:
train = pd.read_json('/kaggle/input/finalll/AR_spider.jsonl', lines=True)

In [7]:
train.head()

Unnamed: 0,question,query,arabic,db_id
0,How many heads of the departments are older th...,SELECT count(*) FROM head WHERE age > 56,كم عدد رؤساء الأقسام الذين تزيد أعمارهم عن 56 ...,department_management
1,"List the name, born state and age of the heads...","SELECT name , born_state , age FROM head ORD...",اعرض قائمة بأسماء رؤساء الأقسام، مكان ميلادهم،...,department_management
2,"List the creation year, name and budget of eac...","SELECT creation , name , budget_in_billions ...",اعرض قائمة بسنوات الإنشاء، وأسماء وميزانيات كل...,department_management
3,What are the maximum and minimum budget of the...,"SELECT max(budget_in_billions) , min(budget_i...",ما هي أقصى وأدنى ميزانية للأقسام؟,department_management
4,What is the average number of employees of the...,SELECT avg(num_employees) FROM department WHER...,ما هو المتوسط ​​لعدد الموظفين في الأقسام الذين...,department_management


In [8]:
train.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 6396 entries, 0 to 6395
Data columns (total 4 columns):
 #   Column    Non-Null Count  Dtype 
---  ------    --------------  ----- 
 0   question  6396 non-null   object
 1   query     6396 non-null   object
 2   arabic    6396 non-null   object
 3   db_id     6396 non-null   object
dtypes: object(4)
memory usage: 200.0+ KB


In [9]:
from camel_tools.utils.dediac import dediac_ar
import re
# Enhanced Arabic text normalization
def preprocess_arabic(text):
    """Enhanced Arabic text normalization"""
    if not isinstance(text, str):
        return ""
    
    # Normalize Unicode
    text = normalize_unicode(text)
    
    # Normalize Arabic-specific characters
    text = normalize_teh_marbuta_ar(text)
    text = normalize_alef_maksura_ar(text)
    
    # Remove diacritics more thoroughly
    text = re.sub(r'[\u064b-\u065f]', '', text)  # Remove all diacritics
    
    # Standardize numbers (Arabic to Western)
    text = re.sub(r'[٠١٢٣٤٥٦٧٨٩]', lambda x: str(ord(x.group(0)) - 1632), text)
    
    # Clean text
    text = re.sub(r'[^\w\s،؛؟]', '', text)  # Keep Arabic punctuation
    text = re.sub(r'\s+', ' ', text)  # Remove extra spaces
    
    return text.strip().lower()

def normalize_sql(query):
    """Enhanced SQL normalization"""
    try:
        parsed = sqlparse.parse(query)[0]
        # Standardize formatting
        formatted = parsed.format(reindent=True, keyword_case='upper')
        # Remove extra whitespace
        return ' '.join(formatted.split())
    except:
        return query.lower()

train["arabic"] = train["arabic"].apply(preprocess_arabic)
train["normalized_query"] = train["query"].apply(normalize_sql)

In [10]:
train.columns

Index(['question', 'query', 'arabic', 'db_id', 'normalized_query'], dtype='object')

In [11]:
from transformers import T5ForConditionalGeneration, AutoTokenizer, T5Config

model_name = "cssupport/t5-small-awesome-text-to-sql" 
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)
config = T5Config(
    vocab_size=len(tokenizer),
    d_model=768,  # Increased from default 512
    d_ff=3072,    # Increased from default 2048
    num_layers=8,  # Increased from default 6
    num_heads=12,  # Increased from default 8
    dropout_rate=0.1,
    decoder_start_token_id=tokenizer.pad_token_id
)

model = T5ForConditionalGeneration(config)

# Special tokens for SQL generation
sql_tokens = ["<", ">", "=", "!=", "<=", ">=", "(", ")", ",", ";", "[", "]"]
tokenizer.add_tokens(sql_tokens)
model.resize_token_embeddings(len(tokenizer))

tokenizer_config.json:   0%|          | 0.00/2.41k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/2.20k [00:00<?, ?B/s]

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


config.json:   0%|          | 0.00/2.37k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/242M [00:00<?, ?B/s]

The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


Embedding(32104, 768)

In [12]:
from transformers import T5ForConditionalGeneration, AutoTokenizer
from collections import Counter
import re
import pandas as pd


# Load CAMeLBERT tokenizer
camelbert_tokenizer = AutoTokenizer.from_pretrained("CAMeL-Lab/bert-base-arabic-camelbert-da")

# Function to analyze Arabic text and get most frequent tokens
def get_most_frequent_arabic_tokens(texts, camelbert_tokenizer, top_n=2000):
    # Tokenize all Arabic text with CAMeLBERT
    all_tokens = []
    for text in texts:
        tokens = camelbert_tokenizer.tokenize(text)
        all_tokens.extend(tokens)
    
    # Count token frequencies
    token_counts = Counter(all_tokens)
    
    # Filter for Arabic tokens (simple regex for Arabic script)
    arabic_pattern = re.compile(r'[\u0600-\u06FF\u0750-\u077F\u08A0-\u08FF]+')
    arabic_tokens = {
        token: count for token, count in token_counts.items() 
        if arabic_pattern.search(token) and count > 3  # Only tokens appearing >3 times
    }
    
    # Sort by frequency and get top N
    sorted_tokens = sorted(arabic_tokens.items(), key=lambda x: x[1], reverse=True)
    return [token for token, count in sorted_tokens[:top_n]]

# Get Arabic texts from your dataset
def extract_arabic_texts(data):
    if isinstance(data, pd.DataFrame):
        return data['arabic'].tolist()
    elif hasattr(data, '__iter__') and not isinstance(data, str):
        # Handle Hugging Face Dataset or list of dicts
        if isinstance(data[0], dict):
            return [item['arabic'] for item in data]
        elif isinstance(data[0], str):
            return data
    return []

arabic_texts = extract_arabic_texts(train)  # Replace 'train' with your dataset variable
if not arabic_texts:
    raise ValueError("Could not extract Arabic texts from the dataset")

top_arabic_tokens = get_most_frequent_arabic_tokens(arabic_texts, camelbert_tokenizer, top_n=5000)

# Add only tokens not already in the T5 tokenizer
new_arabic_tokens = [
    token for token in top_arabic_tokens 
    if token not in tokenizer.get_vocab()
]

print(f"Adding {len(new_arabic_tokens)} new Arabic tokens to the tokenizer")
tokenizer.add_tokens(new_arabic_tokens)

# Add SQL special tokens (check if they exist first)
sql_tokens = ["<", ">", "=", "!=", "<=", ">=", "(", ")", ",", ";", "[", "]"]
new_sql_tokens = [
    token for token in sql_tokens 
    if token not in tokenizer.get_vocab()
]

print(f"Adding {len(new_sql_tokens)} new SQL tokens to the tokenizer")
tokenizer.add_tokens(new_sql_tokens)

# Resize model embeddings
if len(new_arabic_tokens) > 0 or len(new_sql_tokens) > 0:
    print("Resizing model embeddings...")
    model.resize_token_embeddings(len(tokenizer))
else:
    print("No new tokens to add")

tokenizer_config.json:   0%|          | 0.00/86.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/468 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/305k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

Adding 2390 new Arabic tokens to the tokenizer
Adding 0 new SQL tokens to the tokenizer
Resizing model embeddings...


In [13]:
# Add SQL-specific output embeddings
sql_keywords = ["SELECT", "FROM", "WHERE", "JOIN", "GROUP BY", "ORDER BY", "LIMIT"]
with torch.no_grad():
    for keyword in sql_keywords:
        ids = tokenizer(keyword, add_special_tokens=False).input_ids
        if len(ids) == 1:  # Single token
            # Boost probability of SQL keywords
            model.lm_head.weight[ids[0]] *= 1.2

In [14]:
hf_dataset = Dataset.from_pandas(train)
val_dataset1= hf_dataset.train_test_split(test_size=0.2, seed=42)['test']

In [15]:
from datasets import Dataset
from transformers import DataCollatorForSeq2Seq
import numpy as np
import re
import os

# Improved schema matching function
def find_matching_schema(db_id, all_schemas):
    # Try exact match first
    if db_id in all_schemas:
        return all_schemas[db_id]
    
    # Try different variations
    basename = os.path.basename(db_id).replace('.sqlite', '')
    possible_keys = [
        k for k in all_schemas.keys() 
        if basename in k.lower() or db_id.lower() in k.lower()
    ]
    
    if possible_keys:
        return all_schemas[possible_keys[0]]
    
    raise ValueError(f"No matching schema found for {db_id}")
# Format schema as before
def format_schema(schema_dict):
    schema_text = "Database schema:\n"
    for table, columns in schema_dict.items():
        schema_text += f"[TABLE]{table}[/TABLE]("
        schema_text += ", ".join(f"[COLUMN]{col}[/COLUMN]" for col in columns)
        schema_text += ")\n"
    return schema_text.strip()

# Preprocessing function
def preprocess_function(examples):
    try:
        # Handle both batched and non-batched cases
        if isinstance(examples['db_id'], list):
            schema_texts = []
            for db_id in examples['db_id']:
                try:
                    schema = find_matching_schema(db_id, all_schemas)
                    schema_texts.append(format_schema(schema))
                except KeyError as e:
                    print(f"Skipping example with db_id: {db_id} - {str(e)}")
                    schema_texts.append("Database schema not available")
            
            arabic_texts = examples['arabic']
            queries = examples["query"]
        else:
            try:
                schema = find_matching_schema(examples['db_id'], all_schemas)
                schema_texts = [format_schema(schema)]
            except KeyError as e:
                print(f"Skipping example with db_id: {examples['db_id']} - {str(e)}")
                schema_texts = ["Database schema not available"]
            
            arabic_texts = [examples['arabic']]
            queries = [examples["query"]]
        
        # Construct inputs
        inputs = [
            f"{schema_text}\n\nTranslate Arabic to SQL: {arabic}"
            for schema_text, arabic in zip(schema_texts, arabic_texts)
        ]
        
        # Tokenize
        model_inputs = tokenizer(
            inputs,
            max_length=384,
            truncation=True,
            padding="max_length",
            return_tensors="np"
        )
        
        with tokenizer.as_target_tokenizer():
            labels = tokenizer(
                queries,
                max_length=256,
                truncation=True,
                padding="max_length",
                return_tensors="np"
            )
        
        model_inputs["labels"] = labels["input_ids"]
        return model_inputs
    
    except Exception as e:
        print(f"Error processing batch: {str(e)}")
        raise

# Apply preprocessing with error handling
try:
    tokenized_dataset = hf_dataset.map(
        preprocess_function,
        batched=True,
        batch_size=32,
        remove_columns=hf_dataset.column_names
    )
    
    # Split dataset
    split = tokenized_dataset.train_test_split(test_size=0.2, seed=42)
    train_dataset = split["train"]
    val_dataset = split["test"]
    
except Exception as e:
    print(f"Failed to preprocess dataset: {str(e)}")
    print("\nDebugging info:")
    print(f"Sample db_ids: {hf_dataset['db_id'][:5]}")
    print(f"Sample schema keys: {list(all_schemas.keys())[:5]}")
    raise

Map:   0%|          | 0/6396 [00:00<?, ? examples/s]

In [16]:
def generate_sql_with_constraints(arabic_text, db_id, model, tokenizer, all_schemas):
    # Format input text
    schema = find_matching_schema(db_id, all_schemas)
    input_text = f"Translate Arabic to SQL:\nSchema: {format_schema(schema)}\nQuestion: {arabic_text}\nSQL:"
    
    # Prepare constraints
    table_names = list(schema.keys())
    column_names = [col for cols in schema.values() for col in cols]
    
    constraints = [
        SQLGrammarConstraint(tokenizer),
        TableConstraint(tokenizer, table_names),
        ColumnConstraint(tokenizer, column_names)
    ]
    
    # Tokenize and generate
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        constraints=constraints,
        max_length=256,
        num_beams=5,
        early_stopping=True
    )
    
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

In [17]:
os.environ["WANDB_DISABLED"] = "true"

In [18]:
def format_input(arabic_text, db_id):
    schema = find_matching_schema(db_id, all_schemas)
    schema_text = "Database schema:\n"
    for table, columns in schema.items():
        schema_text += f"- {table} ({', '.join(columns)})\n"
    
    return (
        f"Translate Arabic to SQL:\n"
        f"Schema: {schema_text.strip()}\n"
        f"Question: {arabic_text}\n"
        f"SQL:"
    )

In [19]:
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

In [20]:
class CustomCallback(TrainerCallback):
    def on_step_begin(self, args, state, control, **kwargs):
        """Dynamic learning rate adjustment"""
        if state.global_step < args.warmup_steps:
            lr_scale = min(1.0, float(state.global_step) / float(args.warmup_steps))
        else:
            progress = float(state.global_step - args.warmup_steps) / float(args.max_steps - args.warmup_steps)
            lr_scale = max(0.0, 0.5 * (1.0 + np.cos(np.pi * progress)))
        
        for param_group in kwargs['optimizer'].param_groups:
            param_group['lr'] = args.learning_rate * lr_scale

def compute_metrics(eval_pred):
    """Robust metrics calculation with overflow protection"""
    preds, labels = eval_pred
    
    # Handle tuple output from model
    if isinstance(preds, tuple):
        preds = preds[0]
    
    # Convert to numpy if not already
    preds = preds.numpy() if hasattr(preds, "numpy") else np.array(preds)
    labels = labels.numpy() if hasattr(labels, "numpy") else np.array(labels)
    
    # Replace -100 with pad_token_id and clip to valid token IDs
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    
    # Clip predictions and labels to valid token ID range
    vocab_size = len(tokenizer)
    preds = np.clip(preds, 0, vocab_size - 1)
    labels = np.clip(labels, 0, vocab_size - 1)
    
    # Safely decode predictions and labels
    decoded_preds = []
    for pred in preds:
        try:
            decoded = tokenizer.decode(pred, skip_special_tokens=True)
            decoded_preds.append(decoded)
        except:
            decoded_preds.append("")  # Fallback for invalid sequences
    
    decoded_labels = []
    for label in labels:
        try:
            decoded = tokenizer.decode(label, skip_special_tokens=True)
            decoded_labels.append(decoded)
        except:
            decoded_labels.append("")  # Fallback for invalid sequences
    
    # Filter out empty strings that resulted from decoding errors
    valid_pairs = [(p, l) for p, l in zip(decoded_preds, decoded_labels) if p and l]
    if not valid_pairs:
        return {
            "exact_match": 0.0,
            "token_accuracy": 0.0,
            "combined_score": 0.0
        }
    
    valid_preds, valid_labels = zip(*valid_pairs)
    
    # Calculate exact match accuracy
    exact_matches = [1 if pred.strip() == label.strip() else 0 
                    for pred, label in valid_pairs]
    exact_match = sum(exact_matches) / len(exact_matches)
    
    # Calculate token-level accuracy
    token_accuracies = []
    for pred, label in valid_pairs:
        pred_tokens = pred.split()
        label_tokens = label.split()
        correct = sum(1 for p, l in zip(pred_tokens, label_tokens) if p == l)
        token_accuracies.append(correct / max(len(pred_tokens), len(label_tokens)))
    
    token_accuracy = sum(token_accuracies) / len(token_accuracies) if token_accuracies else 0.0
    
    return {
        "exact_match": exact_match,
        "token_accuracy": token_accuracy,
        "combined_score": 0.7 * exact_match + 0.3 * token_accuracy,
        "valid_samples": len(valid_pairs),
        "total_samples": len(decoded_preds)
    }

In [21]:
from transformers import Seq2SeqTrainingArguments

# Updated training arguments for Transformers 4.51.1
training_args = Seq2SeqTrainingArguments(
    output_dir="./results",
    eval_strategy="epoch",  # Changed from evaluation_strategy
    eval_steps=500,
    save_strategy="epoch",  # Changed from save_strategy
    save_steps=500,
    learning_rate=1e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=10,
    predict_with_generate=True,
    fp16=True,
    gradient_accumulation_steps=2,
    warmup_steps=500,
    logging_dir="./logs",
    logging_steps=100,
    load_best_model_at_end=True,
    metric_for_best_model="eval_combined_score",
    greater_is_better=True,
    report_to="none",
    push_to_hub=False,
    generation_max_length=256,
    generation_num_beams=5,
    warmup_ratio=0.1,  # Instead of fixed steps
    lr_scheduler_type="cosine", 
)

In [22]:
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    callbacks=[CustomCallback()]
)

trainer.train()

Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


Epoch,Training Loss,Validation Loss,Exact Match,Token Accuracy,Combined Score,Valid Samples,Total Samples
1,2.5021,1.001313,0.0,0.016123,0.004837,1280,1280
2,0.7726,0.634521,0.0,0.084206,0.025262,1280,1280
3,0.6344,0.526576,0.0,0.050244,0.015073,1280,1280
4,0.5164,0.471679,0.0,0.063143,0.018943,1280,1280
5,0.4633,0.442789,0.0,0.052671,0.015801,1280,1280
6,0.459,0.435164,0.0,0.05217,0.015651,1280,1280
7,0.4494,0.434383,0.0,0.056288,0.016886,1280,1280
8,0.4477,0.423285,0.0,0.059164,0.017749,1280,1280
9,0.4393,0.402828,0.0,0.057948,0.017384,1280,1280
10,0.4089,0.377137,0.0,0.06973,0.020919,1280,1280


There were missing keys in the checkpoint model loaded: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight', 'lm_head.weight'].


TrainOutput(global_step=1600, training_loss=0.6634806847572327, metrics={'train_runtime': 15655.3529, 'train_samples_per_second': 3.268, 'train_steps_per_second': 0.102, 'total_flos': 1.557720319131648e+16, 'train_loss': 0.6634806847572327, 'epoch': 10.0})

In [23]:
def validate_and_correct_sql(predicted_sql, schema):
    """Simple SQL validation"""
    try:
        parsed = sqlparse.parse(predicted_sql)[0]
        # Basic validation - check if tables exist in schema
        tables_in_query = set()
        for token in parsed.tokens:
            if isinstance(token, sqlparse.sql.Identifier):
                tables_in_query.add(token.get_real_name())
        
        for table in tables_in_query:
            if table not in schema:
                return None  # Invalid table name
        
        return predicted_sql
    except:
        return None

In [24]:
def arabic_to_sql(arabic_text, db_id):
    """Enhanced prediction function"""
    try:
        schema = find_matching_schema(db_id, all_schemas)
        input_text = (
            f"Translate this Arabic question to SQL:\n"
            f"Database schema:\n{format_schema(schema)}\n"
            f"Question: {arabic_text}\n"
            f"SQL:"
        )
        
        inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=512).to(model.device)
        
        # Beam search with length penalty
        outputs = model.generate(
            **inputs,
            max_length=256,
            num_beams=8,
            early_stopping=True,
            length_penalty=0.6,
            no_repeat_ngram_size=3,
            num_return_sequences=3  # Get top 3 candidates
        )
        
        # Decode and select best valid SQL
        candidates = [tokenizer.decode(o, skip_special_tokens=True) for o in outputs]
        for sql in candidates:
            validated_sql = validate_and_correct_sql(sql, schema)
            if validated_sql:
                return validated_sql
        
        return candidates[0]  # Return first candidate if none validate
    
    except Exception as e:
        print(f"Error generating SQL: {str(e)}")
        return "SELECT * FROM unknown_table"  # Fallback query

In [25]:
# Example usage
sample = val_dataset1[0]
print("Arabic Question:", sample["arabic"])
print("Generated SQL:", arabic_to_sql(sample["arabic"], sample["db_id"]))
print("Actual SQL:", sample["query"])

Arabic Question: ما هي تسجيلات المدارس التي لا تنتمي إلي الكاثوليكيه؟
Generated SQL: SELECT T1, T2 ON T1.id WHERE T1.customer_id = T1.name, count( *) DESC LIMIT 1
Actual SQL: SELECT LOCATION FROM school ORDER BY Founded DESC
