In [1]:
# 1. Install core dependencies
!pip install --upgrade pip
!pip install torch sqlglot accelerate==0.25.0

# 2. Install HuggingFace libraries with compatible versions
!pip install transformers==4.34.1 datasets trl peft==0.6.2

# 3. Handle data library conflicts by downgrading numpy/pyarrow just in case (though recent pip should handle it)
!pip install numpy==1.26.4 pyarrow==14.0.0

# 4. Check installation (optional, but good practice)
print("Installation Complete. Restarting runtime if in Colab/Jupyter is recommended.")

Collecting accelerate==0.25.0
  Using cached accelerate-0.25.0-py3-none-any.whl.metadata (18 kB)
Using cached accelerate-0.25.0-py3-none-any.whl (265 kB)
Installing collected packages: accelerate
  Attempting uninstall: accelerate
    Found existing installation: accelerate 1.12.0
    Uninstalling accelerate-1.12.0:
      Successfully uninstalled accelerate-1.12.0
Successfully installed accelerate-0.25.0


ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
trl 0.26.1 requires accelerate>=1.4.0, but you have accelerate 0.25.0 which is incompatible.


Collecting transformers==4.34.1
  Using cached transformers-4.34.1-py3-none-any.whl.metadata (121 kB)
Collecting tokenizers<0.15,>=0.14 (from transformers==4.34.1)
  Using cached tokenizers-0.14.1-cp310-none-win_amd64.whl.metadata (6.8 kB)
Collecting huggingface-hub<1.0,>=0.16.4 (from transformers==4.34.1)
  Using cached huggingface_hub-0.17.3-py3-none-any.whl.metadata (13 kB)
INFO: pip is looking at multiple versions of datasets to determine which version is compatible with other requirements. This could take a while.
Collecting datasets
  Downloading datasets-4.4.0-py3-none-any.whl.metadata (19 kB)
  Downloading datasets-4.3.0-py3-none-any.whl.metadata (18 kB)
  Downloading datasets-4.2.0-py3-none-any.whl.metadata (18 kB)
  Downloading datasets-4.1.1-py3-none-any.whl.metadata (18 kB)
  Downloading datasets-4.1.0-py3-none-any.whl.metadata (18 kB)
  Downloading datasets-4.0.0-py3-none-any.whl.metadata (19 kB)
Collecting fsspec (from huggingface-hub<1.0,>=0.16.4->transformers==4.34.1)
 

In [6]:
! curl -L -o data.tar.bz2 https://github.com/salesforce/WikiSQL/raw/master/data.tar.bz2
! tar -xjf data.tar.bz2

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed

  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0
  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0
  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0

  0 24.9M    0     0    0     0      0      0 --:--:--  0:00:01 --:--:--     0
100 24.9M  100 24.9M    0     0  11.6M      0  0:00:02  0:00:02 --:--:-- 55.8M


In [8]:
import torch
import sqlglot
import pandas as pd
from datasets import Dataset, Features, Value, Sequence
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import GRPOTrainer, GRPOConfig
from peft import LoraConfig
import json 
import sqlglot
from difflib import SequenceMatcher

MODEL_NAME = "ibm-granite/granite-4.0-350m-base"
OUTPUT_DIR = "E:/training_runs/granite-grpo-wikisql"
MAX_PROMPT_LENGTH = 512
MAX_COMPLETION_LENGTH = 128
LOCAL_DATA_DIR = "data"

# Feature Schema 
WIKISQL_FEATURES = Features({
    "phase": Value("int32"),
    "question": Value("string"),
    "sql": Value("string"),
    "table": Value("string"),
})

In [3]:
# --- Load main data files (.jsonl) and table schemas (.tables.jsonl) ---

"""
Loading the local jsonl files
dev -> Validation 
train -> train 

The dev and train json files have a table_id key so we know which table was referenced and the table json has all the headers page titles, rows and so on
"""
df_main_train = pd.read_json(f"{LOCAL_DATA_DIR}/train.jsonl", lines=True)
df_main_dev = pd.read_json(f"{LOCAL_DATA_DIR}/dev.jsonl", lines=True)
df_table_train = pd.read_json(f"{LOCAL_DATA_DIR}/train.tables.jsonl", lines=True)
df_table_dev = pd.read_json(f"{LOCAL_DATA_DIR}/dev.tables.jsonl", lines=True)


"""
Create dictionary with ID as index and all of the other table info as key value pairs for this dict 
e.g 

{
    # TABLE ID 
    "1-1000181-1": { 
        
        # COLUMN NAMES ARE INNER VALUES
        "header": ["State/territory", "Text/background colour", "Format", "Current slogan", "Current series", "Notes"],
        "types": ["text", "text", "text", "text", "text", "text"],
        "rows": [
            ["Australian Capital Territory", "blue/white", "Yaa\u00b7nna", ...],
            ... # All other rows 
        ],
        "name": "table_1000181_1"
    }
}

Why? We want to look up the table quickly when we 
"""
table_dict_train = df_table_train.set_index('id').T.to_dict('dict') 
table_dict_dev = df_table_dev.set_index('id').T.to_dict('dict')

def enforce_sql_types(sql_dict):
    """
     The third value was sometimes int, sometimes string, ... 
    """
    if 'conds' in sql_dict:
        new_conds = []
        for cond in sql_dict['conds']:
            if len(cond) == 3:
                # Cast the third element to string
                cond[2] = str(cond[2])
            new_conds.append(cond)
        sql_dict['conds'] = new_conds
    return sql_dict

def restructure_table(row):
    """
     Make sure the merged table is well formed 
    """
    table_info = row['table']
    # Handle NaN/missing values from map operation
    if pd.isna(table_info) or isinstance(table_info, float): 
         table_info = {'header': [], 'types': [], 'rows': [], 'id': row['table_id']}
    else:
         # Ensure the ID is present in the table info dictionary
         table_info['id'] = row['table_id'] 
    return table_info

def merge_and_serialize(df_main, table_dict, is_train_subset=False):
    if is_train_subset:
        df_subset = df_main.head(500).copy()  # df_main.head(2000).copy() 
    else:
        df_subset = df_main.copy()

    # Merge the table data based on the table id
    df_subset['table'] = df_subset['table_id'].map(table_dict)
    df_subset['table'] = df_subset.apply(restructure_table, axis=1)

    # make sure the sql part has strings in the conditions. 
    df_subset['sql'] = df_subset['sql'].apply(enforce_sql_types)

    # make sure these columns are strings again not json. 
    df_subset['sql'] = df_subset['sql'].apply(json.dumps)
    df_subset['table'] = df_subset['table'].apply(json.dumps)

    # creates HF Dataset object. 
    return Dataset.from_pandas(df_subset.drop(columns=['table_id']), features=WIKISQL_FEATURES)

train_dataset = merge_and_serialize(df_main_train, table_dict_train, is_train_subset=True)
validation_dataset = merge_and_serialize(df_main_dev, table_dict_dev)

# The final dataset is the subsetted train_dataset
dataset = train_dataset

print(f"âœ… Data loading and preparation complete. Training subset size: {len(dataset)}")
print("First example:")
print(dataset[0])

âœ… Data loading and preparation complete. Training subset size: 500
First example:
{'phase': 1, 'question': 'Tell me what the notes are for South Australia ', 'sql': '{"sel": 5, "conds": [[3, 0, "SOUTH AUSTRALIA"]], "agg": 0}', 'table': '{"header": ["State/territory", "Text/background colour", "Format", "Current slogan", "Current series", "Notes"], "types": ["text", "text", "text", "text", "text", "text"], "rows": [["Australian Capital Territory", "blue/white", "Yaa\\u00b7nna", "ACT \\u00b7 CELEBRATION OF A CENTURY 2013", "YIL\\u00b700A", "Slogan screenprinted on plate"], ["New South Wales", "black/yellow", "aa\\u00b7nn\\u00b7aa", "NEW SOUTH WALES", "BX\\u00b799\\u00b7HI", "No slogan on current series"], ["New South Wales", "black/white", "aaa\\u00b7nna", "NSW", "CPX\\u00b712A", "Optional white slimline series"], ["Northern Territory", "ochre/white", "Ca\\u00b7nn\\u00b7aa", "NT \\u00b7 OUTBACK AUSTRALIA", "CB\\u00b706\\u00b7ZZ", "New series began in June 2011"], ["Queensland", "maroon

In [6]:
# ==========================================
# 5. Data Formatting and Reward Functions
# ==========================================

"""
def format_wikisql(example):
"""
#Formats the input to include the table schema and deserializes the JSON strings.
"""
    # Deserialize the 'table' JSON string back into a dictionary
    table_dict = json.loads(example["table"])
    
    table_header = table_dict["header"] 
    question = example["question"]
    
    prompt = (
        f"Generate a SQL query to answer the question based on the table schema.\n"
        f"Schema: {table_header}\n"
        f"Question: {question}\n"
        f"Wrap your answer in <sql> tags. For example: <sql>SELECT * FROM table</sql>.\n"
        f"Answer:"
    )
    # The 'table' key is still a string in the output, but we extract the header here.
    return {
        "prompt": prompt,
        "schema_cols": table_header # Passing the header directly for simplicity
    }
"""
def format_wikisql_with_gold_target(example):
    """
    Formats the input prompt and secretly embeds the ground truth SQL
    to bypass the GRPOTrainer's restrictive keyword arguments.
    """
    table_dict = json.loads(example["table"])
    
    table_header = table_dict["header"] 
    question = example["question"]
    # ðŸŒŸ NEW: Get the raw SQL JSON string
    gold_sql_json_string = example["sql"]
    
    prompt = (
        f"Generate a SQL query to answer the question based on the table schema.\n"
        f"Schema: {table_header}\n"
        f"Question: {question}\n"
        f"Wrap your answer in <sql> tags. For example: <sql>SELECT * FROM table</sql>.\n"
        f"Answer:"
    )
    
    # ðŸŒŸ NEW: Embed the ground truth into the prompt string using hidden tags
    # The model will learn to ignore this, but the reward function can find it.
    prompt_with_target = f"[GOLD_SQL]{gold_sql_json_string}[/GOLD_SQL]\n{prompt}"
    
    return {
        "prompt": prompt_with_target, # Use the augmented prompt for training
        "schema_cols": table_header
    }

# Map the formatting function
# dataset = dataset.map(format_wikisql)
dataset = dataset.map(format_wikisql_with_gold_target)

def format_reward_func(completions, **kwargs):
    """Reward 1: Format Compliance (use of <sql> tags)."""
    rewards = []
    for completion in completions:
        if "<sql>" in completion and "</sql>" in completion:
            rewards.append(1.0)
        else:
            rewards.append(0.0) 
    return rewards

def sql_syntax_reward_func(completions, **kwargs):
    """Reward 2: SQL Syntax Validity (using sqlglot)."""
    rewards = []
    for completion in completions:
        try:
            if "<sql>" in completion and "</sql>" in completion:
                query = completion.split("<sql>")[1].split("</sql>")[0].strip()
            else:
                query = completion.strip()
            
            sqlglot.parse(query, read="sqlite")
            rewards.append(2.0)
        except Exception:
            rewards.append(-1.0)
    return rewards

def schema_consistency_reward_func(prompts, completions, **kwargs):
    """Reward 3: Schema Consistency (no hallucinated columns)."""
    rewards = []
    for prompt, completion in zip(prompts, completions):
        try:
            schema_part = prompt.split("Schema: ")[1].split("\nQuestion")[0]
            # Safely parse the column names from the prompt string
            valid_cols = set(c.strip().strip("'").strip('"').lower() for c in schema_part.strip("[]").split(",") if c)
            
            if "<sql>" in completion and "</sql>" in completion:
                query = completion.split("<sql>")[1].split("</sql>")[0].strip()
            else:
                query = completion.strip()

            parsed = sqlglot.parse_one(query, read="sqlite")
            used_cols = set(col.name.lower() for col in parsed.find_all(sqlglot.exp.Column))
            
            if "*" in used_cols: 
                used_cols.remove("*")
                
            if used_cols.issubset(valid_cols):
                rewards.append(1.5)
            else:
                rewards.append(-0.5)
                
        except Exception:
            rewards.append(0.0)
            
    return rewards

def logical_structure_reward_func(prompts, completions, **kwargs):
    """
    Reward 4: Scores based on logical structure.
    This version extracts the ground truth from the prompt string.
    """
    rewards = []

    for prompt, completion in zip(prompts, completions):
        score = 0.0
        
        # --- 1. Extract Ground Truth from Prompt ---
        try:
            # Find the hidden GOLD_SQL tag in the prompt
            start_idx = prompt.find("[GOLD_SQL]") + len("[GOLD_SQL]")
            end_idx = prompt.find("[/GOLD_SQL]")
            gt_json_str = prompt[start_idx:end_idx].strip()
            gt_dict = json.loads(gt_json_str)
        except Exception:
            # If we can't find or parse the gold target, we can't apply this reward.
            rewards.append(0.0) 
            continue
            
        # 2. Extract Generated Query
        try:
            query_string = completion.split("<sql>")[1].split("</sql>")[0].strip()
            query_string_upper = query_string.upper()
        except Exception:
            rewards.append(-3.0) # Penalty for broken format
            continue

        # 3. Parse Ground Truth Requirements
        gt_has_where = 'conds' in gt_dict and len(gt_dict['conds']) > 0
        gt_has_agg = gt_dict.get('agg', 0) > 0 

        # 4. Score Logic (The scoring logic remains the same)
        if "SELECT *" in query_string_upper: score -= 1.0
            
        if gt_has_where:
            if 'WHERE' in query_string_upper: score += 3.0 
            else: score -= 4.0 
            
        agg_keywords = ['COUNT', 'AVG', 'SUM', 'MAX', 'MIN']
        generated_has_agg = any(agg in query_string_upper for agg in agg_keywords)

        if gt_has_agg:
            if generated_has_agg: score += 3.0
            else: score -= 4.0
                
        rewards.append(score)
            
    return rewards

def intent_accuracy_reward_func(prompts, completions, **kwargs):
    rewards = []
    
    for prompt, completion in zip(prompts, completions):
        score = 0.0
        
        # --- 1. Extract Ground Truth & Schema from Prompt ---
        try:
            # Extract Gold JSON
            gt_start = prompt.find("[GOLD_SQL]") + len("[GOLD_SQL]")
            gt_end = prompt.find("[/GOLD_SQL]")
            gt_data = json.loads(prompt[gt_start:gt_end].strip())
            
            # Extract Schema to map indices to names
            # format: "Schema: ['State', 'Slogan', ...]"
            schema_start = prompt.find("Schema: ") + len("Schema: ")
            schema_end = prompt.find("\nQuestion")
            # a safe eval to turn string list "['a', 'b']" into list ['a', 'b']
            schema_cols = eval(prompt[schema_start:schema_end].strip()) 
        except Exception:
            rewards.append(0.0)
            continue

        # --- 2. Parse Generated Query ---
        try:
            # Extract SQL and normalize
            gen_sql = completion.split("<sql>")[1].split("</sql>")[0].strip()
            gen_parsed = sqlglot.parse_one(gen_sql)
            gen_sql_upper = gen_sql.upper()
        except Exception:
            rewards.append(-1.0) # Penalty for unparseable SQL
            continue

        # --- 3. Check Aggregation Accuracy (Fixes "COUNT" issues) ---
        # Map WikiSQL agg indices to keywords: 0=None, 1=MAX, 2=MIN, 3=COUNT, 4=SUM, 5=AVG
        agg_map = {0: None, 1: 'MAX', 2: 'MIN', 3: 'COUNT', 4: 'SUM', 5: 'AVG'}
        target_agg = agg_map.get(gt_data.get('agg', 0))

        if target_agg:
            # If Gold has aggregation, Generation MUST have it
            if target_agg in gen_sql_upper:
                score += 1.0
            else:
                score -= 1.0 # Penalty for missing it
        else:
            # If Gold has NO aggregation, Generation should NOT have it (prevent hallucinations)
            # We check if any agg keyword is present
            if any(x in gen_sql_upper for x in ['MAX(', 'MIN(', 'COUNT(', 'SUM(', 'AVG(']):
                score -= 0.5 

        # --- 4. Check Column Selection Accuracy (Fixes "Wrong selection") ---
        # The Gold JSON tells us the index of the SELECT column
        sel_col_idx = gt_data.get('sel')
        if sel_col_idx is not None and sel_col_idx < len(schema_cols):
            target_col_name = schema_cols[sel_col_idx].lower()
            
            # Check if the generated SQL selects this column
            # We look for the column name in the generated string (simplified check)
            # A strict check would use sqlglot to find the SELECT expression
            if target_col_name in gen_sql.lower():
                score += 1.0
            else:
                score -= 1.0 # Significant penalty for selecting the wrong column

        # --- 5. Check Condition Accuracy (Fixes "WHERE" clauses) ---
        gt_conds = gt_data.get('conds', [])
        if gt_conds:
            # Check if we have a WHERE clause
            if "WHERE" not in gen_sql_upper:
                score -= 1.0
            else:
                # Check if we are filtering by the RIGHT column
                for cond in gt_conds:
                    col_idx = cond[0]
                    if col_idx < len(schema_cols):
                        cond_col_name = schema_cols[col_idx].lower()
                        if cond_col_name in gen_sql.lower():
                            score += 1.0
                        else:
                            score -= 0.5 # Wrong filter column

        rewards.append(score)

    return rewards

Map: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 500/500 [00:00<00:00, 7812.83 examples/s]


In [5]:
dataset[0] 

{'phase': 1,
 'question': 'Tell me what the notes are for South Australia ',
 'sql': '{"sel": 5, "conds": [[3, 0, "SOUTH AUSTRALIA"]], "agg": 0}',
 'table': '{"header": ["State/territory", "Text/background colour", "Format", "Current slogan", "Current series", "Notes"], "types": ["text", "text", "text", "text", "text", "text"], "rows": [["Australian Capital Territory", "blue/white", "Yaa\\u00b7nna", "ACT \\u00b7 CELEBRATION OF A CENTURY 2013", "YIL\\u00b700A", "Slogan screenprinted on plate"], ["New South Wales", "black/yellow", "aa\\u00b7nn\\u00b7aa", "NEW SOUTH WALES", "BX\\u00b799\\u00b7HI", "No slogan on current series"], ["New South Wales", "black/white", "aaa\\u00b7nna", "NSW", "CPX\\u00b712A", "Optional white slimline series"], ["Northern Territory", "ochre/white", "Ca\\u00b7nn\\u00b7aa", "NT \\u00b7 OUTBACK AUSTRALIA", "CB\\u00b706\\u00b7ZZ", "New series began in June 2011"], ["Queensland", "maroon/white", "nnn\\u00b7aaa", "QUEENSLAND \\u00b7 SUNSHINE STATE", "999\\u00b7TLG", "

In [9]:
# ==========================================
# 6. Model & Trainer Initialization (CORRECTED)
# ==========================================

# Load Tokenizer (Keep this, as it's needed for the AutoModelForCausalLM and data prep)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Load Model
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)

# LoRA Configuration
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"], 
    task_type="CAUSAL_LM",
    lora_dropout=0.05,
)

# GRPO Config
training_args = GRPOConfig(
    output_dir=OUTPUT_DIR,
    learning_rate=1e-5,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    num_generations=4,
    max_prompt_length=MAX_PROMPT_LENGTH,
    max_completion_length=MAX_COMPLETION_LENGTH,
    num_train_epochs=1,
    logging_steps=10,
    save_strategy="steps",
    save_steps=100,
    report_to="none"
)

trainer = GRPOTrainer(
    model=model,
    reward_funcs=[
        format_reward_func, 
        sql_syntax_reward_func, 
        schema_consistency_reward_func,
        logical_structure_reward_func,
        intent_accuracy_reward_func
    ],
    args=training_args,
    train_dataset=dataset,
    peft_config=peft_config
    # REMOVED: tokenizer=tokenizer  <-- This was the unexpected keyword argument
)

# ==========================================
# 7. Execution
# ==========================================
print("Starting GRPO Training for Text-to-SQL with full logical rewards...")
try:
    trainer.train() 
except KeyboardInterrupt:
    print("\nTraining interrupted by user. Proceeding to final save...")
    
# Save the final adapter
trainer.save_model(OUTPUT_DIR)
print(f"Training complete. Model saved to {OUTPUT_DIR}")

The model is already on multiple devices. Skipping the move to device specified in `args`.


Starting GRPO Training for Text-to-SQL with full logical rewards...


Step,Training Loss
10,0.0965
20,-0.0099
30,-0.0154
40,0.0687
50,0.0742
60,-0.013
70,0.019
80,0.032
90,-0.0065
100,-0.038


'IF WHERE 2 THEN 1 ELSE 2 END AS name IN(SELECT DISTINCT columns FROM table WHERE column_of_query=2 O' contains unsupported syntax. Falling back to parsing as a 'Command'.
'IF WHERE 2 THEN 1 ELSE 2 END AS name IN(SELECT DISTINCT columns FROM table WHERE column_of_query=2 O' contains unsupported syntax. Falling back to parsing as a 'Command'.
'IF WHERE 2 THEN 1 ELSE 2 END AS name IN(SELECT DISTINCT columns FROM table WHERE column_of_query=2 O' contains unsupported syntax. Falling back to parsing as a 'Command'.
'CALL thread_pthreads_thread_get_intermediate(false,(Index_Ldap)0) ON sajmiÅ¡te SERVED_BY<HOSTNAME>

Y' contains unsupported syntax. Falling back to parsing as a 'Command'.
'CALL thread_pthreads_thread_get_intermediate(false,(Index_Ldap)0) ON sajmiÅ¡te SERVED_BY<HOSTNAME>

Y' contains unsupported syntax. Falling back to parsing as a 'Command'.


Training complete. Model saved to E:/training_runs/granite-grpo-wikisql


In [11]:
import torch
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM
import json
import pandas as pd
from datasets import Dataset, Features, Value
import os # <-- IMPORTED OS FOR PATH JOINING

# --- FIX 1: UPDATED OUTPUT_DIR to E: DRIVE ---
# Assuming the adapter was saved here after the GRPO training run.
OUTPUT_DIR = "E:/training_runs/granite-grpo-wikisql" 
MAX_PROMPT_LENGTH = 512
MAX_COMPLETION_LENGTH = 128

# --- FIX 2: Using os.path.join for robust local data access ---
# This ensures the script finds the data files regardless of where the script is run from.
LOCAL_DATA_DIR = os.path.join(os.getcwd(), "data")
# --- Utility Functions (Must match the ones used in pre-processing) ---
def format_wikisql(example):
    """Formats the input prompt for the model."""
    table_dict = json.loads(example["table"])
    table_header = table_dict["header"] 
    question = example["question"]
    
    prompt = (
        f"Generate a SQL query to answer the question based on the table schema.\n"
        f"Schema: {table_header}\n"
        f"Question: {question}\n"
        f"Wrap your answer in <sql> tags. For example: <sql>SELECT * FROM table</sql>.\n"
        f"Answer:"
    )
    return {"prompt": prompt}

def get_validation_data():
    """Simplified data loading for the dev set."""
    df_main_dev = pd.read_json(f"{LOCAL_DATA_DIR}/dev.jsonl", lines=True)
    df_table_dev = pd.read_json(f"{LOCAL_DATA_DIR}/dev.tables.jsonl", lines=True)
    table_dict_dev = df_table_dev.set_index('id').T.to_dict('dict')
    
    # We need the full merge and serialize here, but for brevity, we assume 
    # a function similar to merge_and_serialize was used to generate this.
    
    # For a quick test, we will just merge and serialize the first 10 dev samples:
    df_subset = df_main_dev.head(10).copy()
    df_subset['table'] = df_subset['table_id'].map(table_dict_dev)
    # NOTE: In a full script, you would include restructure_table and enforce_sql_types here
    
    # Simple serialization for demonstration:
    df_subset['sql'] = df_subset['sql'].apply(json.dumps)
    df_subset['table'] = df_subset['table'].apply(json.dumps)
    
    dataset = Dataset.from_pandas(df_subset.drop(columns=['table_id']), features=WIKISQL_FEATURES)
    return dataset.map(format_wikisql)

# ==========================================
# 3. Model Loading and Inference
# ==========================================
print("1. Loading Tokenizer and Base Model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)

print("2. Loading and Merging LoRA Adapter...")
model = PeftModel.from_pretrained(base_model, OUTPUT_DIR)
model = model.merge_and_unload()
model.eval()
print("âœ… Model adapter merged and ready for inference.")

# ==========================================
# 4. Testing Inference
# ==========================================
validation_dataset = get_validation_data()
test_sample = validation_dataset[1] # Use the second example

prompt = test_sample["prompt"]
ground_truth_sql = json.loads(test_sample["sql"]) # Deserialize for comparison

print("\n--- TEST CASE ---")
print(f"Question: {test_sample['question']}")
print(f"Ground Truth SQL: {ground_truth_sql}")
print("--- Generating Query ---")

inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# Generate the completion
with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=MAX_COMPLETION_LENGTH,
        do_sample=False,
        pad_token_id=tokenizer.pad_token_id
    )

# Decode the output, stripping the input prompt
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
generated_completion = generated_text[len(prompt):].strip()

# --- Post-Processing ---
try:
    # Extract the SQL from between the tags (as done in the reward functions)
    generated_sql = generated_completion.split("<sql>")[1].split("</sql>")[0].strip()
except IndexError:
    generated_sql = "Extraction Failed (Tags Missing)"

print("\n--- RESULTS ---")
print(f"Generated Raw Text: {generated_completion}")
print(f"Extracted SQL: {generated_sql}")
print("\n--- FINAL VERDICT ---")

# Simple comparison for demonstration
if generated_sql and generated_sql != "Extraction Failed (Tags Missing)":
    print("SUCCESS: The model generated syntactically plausible SQL.")
else:
    print("FAILURE: The model failed to adhere to the required format.")

1. Loading Tokenizer and Base Model...
2. Loading and Merging LoRA Adapter...
âœ… Model adapter merged and ready for inference.


Map: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 10/10 [00:00<00:00, 77.55 examples/s]



--- TEST CASE ---
Question: How many schools did player number 3 play at?
Ground Truth SQL: {'sel': 5, 'conds': [[1, 0, '3']], 'agg': 3}
--- Generating Query ---

--- RESULTS ---
Generated Raw Text: <sql>SELECT COUNT(Player) FROM table WHERE Player = 3</sql>

Player, No., Nationality, Position, Years in Toronto, School/Club Team
3, England, England, Goalkeeper, University of St Andrews, St Andrews
4, England, England, Goalkeeper, University of St Andrews, St Andrews
5, England, England, Goalkeeper, University of St Andrews, St Andrews
6, England, England, Goalkeeper, University of St Andrews, St Andrews
7, England, England, Goalkeeper, University of St Andrews, St Andrews
8, England, England, Goalkeeper
Extracted SQL: SELECT COUNT(Player) FROM table WHERE Player = 3

--- FINAL VERDICT ---
SUCCESS: The model generated syntactically plausible SQL.


In [12]:
test_sample_1 = validation_dataset[0]

prompt = test_sample_1["prompt"]
ground_truth_sql = json.loads(test_sample_1["sql"])

print("\n--- TEST CASE 2 ---")
print(f"Question: {test_sample_1['question']}")
print(f"Ground Truth SQL: {ground_truth_sql}")
print("--- Generating Query ---")

inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# Generate the completion
with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=MAX_COMPLETION_LENGTH,
        do_sample=False, # Use greedy decoding for predictable results
        pad_token_id=tokenizer.pad_token_id
    )

generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
generated_completion = generated_text[len(prompt):].strip()

# --- Post-Processing ---
try:
    generated_sql = generated_completion.split("<sql>")[1].split("</sql>")[0].strip()
except IndexError:
    generated_sql = "Extraction Failed (Tags Missing)"

print("\n--- RESULTS ---")
print(f"Extracted SQL: {generated_sql}")


--- TEST CASE 2 ---
Question: What position does the player who played for butler cc (ks) play?
Ground Truth SQL: {'sel': 3, 'conds': [[5, 0, 'Butler CC (KS)']], 'agg': 0}
--- Generating Query ---

--- RESULTS ---
Extracted SQL: SELECT * FROM player WHERE nationality = 'ks' AND position = 'butler cc (ks)' AND years_in_toronto = 0 AND school/club_team = 'butler cc (ks)'


In [14]:
test_sample_5 = validation_dataset[5]

prompt = test_sample_5["prompt"]
ground_truth_sql = json.loads(test_sample_5["sql"])

print("\n--- TEST CASE 3 ---")
print(f"Question: {test_sample_5['question']}")
print(f"Ground Truth SQL: {ground_truth_sql}")
print("--- Generating Query ---")

inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# Generate the completion
with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=MAX_COMPLETION_LENGTH,
        do_sample=False,
        pad_token_id=tokenizer.pad_token_id
    )

generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
generated_completion = generated_text[len(prompt):].strip()

# --- Post-Processing ---
try:
    generated_sql = generated_completion.split("<sql>")[1].split("</sql>")[0].strip()
except IndexError:
    generated_sql = "Extraction Failed (Tags Missing)"

print("\n--- RESULTS ---")
print(f"Extracted SQL: {generated_sql}")


--- TEST CASE 3 ---
Question: Who are all of the players on the Westchester High School club team?
Ground Truth SQL: {'sel': 0, 'conds': [[5, 0, 'Westchester High School']], 'agg': 0}
--- Generating Query ---

--- RESULTS ---
Extracted SQL: SELECT * FROM players WHERE nationality = 'Westchester High School' AND position = 'Player' AND years_in_toronto = 0 AND school/club_team = 'Toronto' AND nationality = 'Westchester High School' AND position = 'Player' AND years_in_toronto = 0
