In [1]:
!pip install -q -U git+https://github.com/huggingface/trl.git
!pip install -q -U transformers accelerate peft datasets bitsandbytes sqlglot

import torch
import trl
from trl import GRPOTrainer, GRPOConfig

print(f"✅ TRL version: {trl.__version__}")
print(f"✅ GRPOTrainer imported successfully!")
print(f"✅ GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'No GPU found'}")

[0mFound existing installation: transformers 4.57.3
Uninstalling transformers-4.57.3:
  Successfully uninstalled transformers-4.57.3
Found existing installation: peft 0.18.0
Uninstalling peft-0.18.0:
  Successfully uninstalled peft-0.18.0
Found existing installation: accelerate 1.12.0
Uninstalling accelerate-1.12.0:
  Successfully uninstalled accelerate-1.12.0
[0m  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.0/44.0 kB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m380.9/380.9 kB[0m [31m8.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.0/12.0 MB[0m [31m58.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for trl (pyproject.toml) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━

In [6]:
! curl -L -o data.tar.bz2 https://github.com/salesforce/WikiSQL/raw/master/data.tar.bz2
! tar -xjf data.tar.bz2

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0
100 24.9M  100 24.9M    0     0  15.9M      0  0:00:01  0:00:01 --:--:-- 44.2M


In [7]:
import torch
import sqlglot
import pandas as pd
from datasets import Dataset, Features, Value, Sequence
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import GRPOTrainer, GRPOConfig
from peft import LoraConfig
import json
import sqlglot
from difflib import SequenceMatcher

# Checkpoint that the tokenizer and the base model are loaded from.
MODEL_NAME = "ibm-granite/granite-4.0-350m-base"
# Directory passed to GRPOConfig(output_dir=...) and trainer.save_model(...).
# NOTE(review): the leading "/" makes this an absolute path at the filesystem root — confirm that is intended.
OUTPUT_DIR = "/granite-grpo-wikisql"
# Length limits passed to GRPOConfig (max_prompt_length / max_completion_length).
MAX_PROMPT_LENGTH = 512
MAX_COMPLETION_LENGTH = 128
# Folder the WikiSQL *.jsonl files are read from.
LOCAL_DATA_DIR = "data"

# Feature schema for the Hugging Face Dataset. `sql` and `table` are declared as
# plain strings because they are stored JSON-encoded (see merge_and_serialize).
WIKISQL_FEATURES = Features({
    "phase": Value("int32"),
    "question": Value("string"),
    "sql": Value("string"),
    "table": Value("string"),
})

In [32]:
# --- Load main data files (.jsonl) and table schemas (.tables.jsonl) ---
#
# Local jsonl files, loaded in this order:
#   train.jsonl / dev.jsonl               -> questions (dev is used as validation)
#   train.tables.jsonl / dev.tables.jsonl -> the tables those questions refer to
#
# Every question row carries a `table_id` key naming the table it was asked about;
# the table files hold the headers, page titles, rows and so on for each id.
df_main_train, df_main_dev, df_table_train, df_table_dev = (
    pd.read_json(f"{LOCAL_DATA_DIR}/{stem}.jsonl", lines=True)
    for stem in ("train", "dev", "train.tables", "dev.tables")
)

# Turn each table frame into a dict keyed by table id, whose values are dicts of
# the remaining table fields, e.g.
#
# {
#     # TABLE ID
#     "1-1000181-1": {
#         # COLUMN NAMES ARE INNER VALUES
#         "header": ["State/territory", "Text/background colour", "Format", ...],
#         "types": ["text", "text", "text", "text", "text", "text"],
#         "rows": [
#             ["Australian Capital Territory", "blue/white", "Yaa\u00b7nna", ...],
#             ...  # all other rows
#         ],
#         "name": "table_1000181_1"
#     }
# }
#
# Why? We want to look up a table quickly by its id when merging it onto the
# question rows and when building the SQLite database.
table_dict_train, table_dict_dev = (
    tables.set_index('id').T.to_dict('dict')
    for tables in (df_table_train, df_table_dev)
)

def enforce_sql_types(sql_dict):
    """Return a copy of a WikiSQL query dict whose condition values are all strings.

    In the raw data the third element of a condition (the comparison value) is
    sometimes an int or float and sometimes a string; casting it to ``str`` gives
    the column a single consistent type.

    The input is NOT modified. The dicts stored in a DataFrame column are shared
    with every shallow copy of that frame (``DataFrame.copy()`` does not deep-copy
    nested Python objects), so mutating them in place would silently change the
    source frame as well.

    Args:
        sql_dict: Mapping with an optional ``'conds'`` list of conditions. Only
            conditions with exactly three elements have their value cast.

    Returns:
        A new dict with the same keys; ``'conds'`` (when present) is a new list of
        new lists.
    """
    normalized = dict(sql_dict)
    if 'conds' in normalized:
        normalized['conds'] = [
            [cond[0], cond[1], str(cond[2])] if len(cond) == 3 else list(cond)
            for cond in normalized['conds']
        ]
    return normalized

def restructure_table(row):
    """Return a well-formed table dict for one merged question row.

    Args:
        row: A row (or any mapping) with a ``'table'`` entry, holding the table
            dict looked up for this question or a missing value when the lookup
            failed, and a ``'table_id'`` entry.

    Returns:
        A new dict containing the table fields plus ``'id'`` set to the row's
        ``table_id``. When no table dict is available, an empty placeholder with
        ``header``/``types``/``rows`` set to empty lists is returned instead.
    """
    table_info = row['table']
    if isinstance(table_info, dict):
        # Build a new dict instead of writing 'id' into table_info: that dict is
        # the entry of the shared table lookup and is reused by every question
        # that refers to the same table.
        return {**table_info, 'id': row['table_id']}
    # Anything that is not a dict (NaN / None from a failed .map lookup) becomes
    # an empty placeholder so downstream serialization always sees the same keys.
    return {'header': [], 'types': [], 'rows': [], 'id': row['table_id']}

def merge_and_serialize(df_main, table_dict, is_train_subset=False, subset_size=2000):
    """Attach table data to each question and return a Hugging Face Dataset.

    Args:
        df_main: DataFrame of questions with at least ``table_id`` and ``sql``
            columns, plus the remaining columns named in ``WIKISQL_FEATURES``.
        table_dict: Mapping from table id to that table's dict of fields.
        is_train_subset: When True, only the first ``subset_size`` rows are used.
        subset_size: Number of leading rows kept when ``is_train_subset`` is True.
            Defaults to 2000, the value that used to be hard-coded.

    Returns:
        A ``Dataset`` with the ``WIKISQL_FEATURES`` schema in which ``sql`` and
        ``table`` are JSON-encoded strings.
    """
    selected = df_main.head(subset_size) if is_train_subset else df_main
    df_subset = selected.copy()

    # Look up each question's table by id, then normalize it (missing tables
    # become an empty placeholder and every table carries its own id).
    df_subset['table'] = df_subset['table_id'].map(table_dict)
    df_subset['table'] = df_subset.apply(restructure_table, axis=1)

    # Cast condition values to strings so the field has one consistent type.
    df_subset['sql'] = df_subset['sql'].apply(enforce_sql_types)

    # Nested structures are stored as JSON strings to match WIKISQL_FEATURES.
    df_subset['sql'] = df_subset['sql'].apply(json.dumps)
    df_subset['table'] = df_subset['table'].apply(json.dumps)

    # preserve_index=False keeps a non-default pandas index (e.g. after filtering)
    # from being added as an extra column that the fixed feature schema lacks.
    return Dataset.from_pandas(
        df_subset.drop(columns=['table_id']),
        features=WIKISQL_FEATURES,
        preserve_index=False,
    )

# Build the datasets: train is limited to a leading subset (is_train_subset=True),
# while the dev split is converted in full.
train_dataset = merge_and_serialize(df_main_train, table_dict_train, is_train_subset=True)
validation_dataset = merge_and_serialize(df_main_dev, table_dict_dev)

# The final dataset is the subsetted train_dataset
dataset = train_dataset

print(f"Data loading and preparation complete. Training subset size: {len(dataset)}")
print("First example:")
# Printing a whole example includes the full JSON-encoded table, so the output is long.
print(dataset[0])

Data loading and preparation complete. Training subset size: 2000
First example:
{'phase': 1, 'question': 'Tell me what the notes are for South Australia ', 'sql': '{"sel": 5, "conds": [[3, 0, "SOUTH AUSTRALIA"]], "agg": 0}', 'table': '{"header": ["State/territory", "Text/background colour", "Format", "Current slogan", "Current series", "Notes"], "types": ["text", "text", "text", "text", "text", "text"], "rows": [["Australian Capital Territory", "blue/white", "Yaa\\u00b7nna", "ACT \\u00b7 CELEBRATION OF A CENTURY 2013", "YIL\\u00b700A", "Slogan screenprinted on plate"], ["New South Wales", "black/yellow", "aa\\u00b7nn\\u00b7aa", "NEW SOUTH WALES", "BX\\u00b799\\u00b7HI", "No slogan on current series"], ["New South Wales", "black/white", "aaa\\u00b7nna", "NSW", "CPX\\u00b712A", "Optional white slimline series"], ["Northern Territory", "ochre/white", "Ca\\u00b7nn\\u00b7aa", "NT \\u00b7 OUTBACK AUSTRALIA", "CB\\u00b706\\u00b7ZZ", "New series began in June 2011"], ["Queensland", "maroon/wh

In [33]:
import sqlite3
import os

def setup_sqlite_db_raw(table_dict, db_path):
    """Build a fresh SQLite database with one all-TEXT table per WikiSQL table.

    Each table is named ``table_<id>`` with ``-`` replaced by ``_`` and uses the raw
    header strings as column names. A header that repeats exactly gets a numeric
    suffix (``name``, ``name_2``, ``name_3`` ...).

    Args:
        table_dict: Mapping from table id to a dict with ``'header'`` (column
            names) and ``'rows'`` (list of row value lists).
        db_path: Path of the database file. An existing file is deleted first.

    Returns:
        An open ``sqlite3.Connection`` to the new database, with changes committed.

    Tables that cannot be created or populated are skipped (best effort); the
    number skipped and the first error are printed instead of being dropped
    silently.
    """
    def quote_ident(name):
        # Double any embedded quote so headers containing '"' still form a valid
        # quoted identifier instead of making the statement fail.
        return '"' + str(name).replace('"', '""') + '"'

    if os.path.exists(db_path):
        # Best effort: close the connection left over from a previous run of this
        # cell so the file lock is released before the file is deleted.
        previous_conn = globals().get("training_conn")
        if previous_conn is not None:
            try:
                previous_conn.close()
            except Exception:
                pass  # whatever it holds, it no longer matters once the file is gone

    # WAL mode keeps side files next to the database; stale ones must not be
    # paired with the freshly created file.
    for suffix in ("", "-wal", "-shm"):
        stale_path = f"{db_path}{suffix}"
        if os.path.exists(stale_path):
            os.remove(stale_path)

    conn = sqlite3.connect(db_path, timeout=30)
    conn.execute("PRAGMA journal_mode=WAL;")
    print("Building SQLite database with RAW headers...")

    skipped = []
    for table_id, info in table_dict.items():
        safe_table_name = f"table_{table_id.replace('-', '_')}"

        # Handle exact duplicate headers
        final_headers = []
        seen = {}
        for h in info['header']:
            h_str = str(h)
            if h_str in seen:
                seen[h_str] += 1
                h_str = f"{h_str}_{seen[h_str]}"
            else:
                seen[h_str] = 1
            final_headers.append(h_str)

        table_sql_name = quote_ident(safe_table_name)
        cols_definition = ", ".join(f"{quote_ident(name)} TEXT" for name in final_headers)
        placeholders = ", ".join(["?"] * len(final_headers))

        try:
            conn.execute(f"CREATE TABLE {table_sql_name} ({cols_definition});")
            conn.executemany(f"INSERT INTO {table_sql_name} VALUES ({placeholders});", info['rows'])
        except Exception as exc:
            # Keep going so one malformed table does not abort the whole build,
            # but remember the failure so it can be reported below.
            skipped.append((table_id, exc))

    conn.commit()
    if skipped:
        first_id, first_exc = skipped[0]
        print(
            f"Warning: {len(skipped)} of {len(table_dict)} tables could not be fully loaded "
            f"(first failure: {first_id}: {first_exc})"
        )
    return conn

# Module-level connection to the training tables; the execution reward function reads it as a global.
training_conn = setup_sqlite_db_raw(table_dict_train, "wikisql_training.db")

Building SQLite database with RAW headers...


In [34]:
import pandas as pd

# Smoke test: read two rows back from the first training table to confirm the
# database was populated and to see the column names exactly as stored.
sample_table_id = list(table_dict_train.keys())[0]
# Same naming rule as setup_sqlite_db_raw: "table_" + id with '-' replaced by '_'.
safe_name = f"table_{sample_table_id.replace('-', '_')}"

print(f"Testing Query on: {safe_name}")
print("-" * 30)

try:
    query = f'SELECT * FROM "{safe_name}" LIMIT 2'
    df_test = pd.read_sql_query(query, training_conn)

    print("Successfully retrieved data:")
    display(df_test)

    print("\nRaw Column Names in DB:")
    print(df_test.columns.tolist())

except Exception as e:
    # Report instead of raising so the rest of the notebook can still be run.
    print(f"Query failed: {e}")

Testing Query on: table_1_1000181_1
------------------------------
Successfully retrieved data:


Unnamed: 0,State/territory,Text/background colour,Format,Current slogan,Current series,Notes
0,Australian Capital Territory,blue/white,Yaa·nna,ACT · CELEBRATION OF A CENTURY 2013,YIL·00A,Slogan screenprinted on plate
1,New South Wales,black/yellow,aa·nn·aa,NEW SOUTH WALES,BX·99·HI,No slogan on current series



Raw Column Names in DB:
['State/territory', 'Text/background colour', 'Format', 'Current slogan', 'Current series', 'Notes']


In [36]:
# Inspect one prepared training example; `sql` and `table` are JSON-encoded strings.
dataset[0]

{'phase': 1,
 'question': 'Tell me what the notes are for South Australia ',
 'sql': '{"sel": 5, "conds": [[3, 0, "SOUTH AUSTRALIA"]], "agg": 0}',
 'table': '{"header": ["State/territory", "Text/background colour", "Format", "Current slogan", "Current series", "Notes"], "types": ["text", "text", "text", "text", "text", "text"], "rows": [["Australian Capital Territory", "blue/white", "Yaa\\u00b7nna", "ACT \\u00b7 CELEBRATION OF A CENTURY 2013", "YIL\\u00b700A", "Slogan screenprinted on plate"], ["New South Wales", "black/yellow", "aa\\u00b7nn\\u00b7aa", "NEW SOUTH WALES", "BX\\u00b799\\u00b7HI", "No slogan on current series"], ["New South Wales", "black/white", "aaa\\u00b7nna", "NSW", "CPX\\u00b712A", "Optional white slimline series"], ["Northern Territory", "ochre/white", "Ca\\u00b7nn\\u00b7aa", "NT \\u00b7 OUTBACK AUSTRALIA", "CB\\u00b706\\u00b7ZZ", "New series began in June 2011"], ["Queensland", "maroon/white", "nnn\\u00b7aaa", "QUEENSLAND \\u00b7 SUNSHINE STATE", "999\\u00b7TLG", "

In [1]:
import json
import sqlglot

def format_wikisql_with_gold_target(example):
    """Build the training prompt for one example, with metadata for the rewards.

    The first two prompt lines carry the gold query (the WikiSQL JSON logic, as
    stored in ``example["sql"]``) and the table id inside ``[GOLD_SQL]`` /
    ``[TABLE_ID]`` tags so the reward functions can recover them from the prompt.

    NOTE(review): because the tags are part of the prompt text, the gold query is
    also visible to the model during generation — confirm that is acceptable.

    Args:
        example: Mapping with ``"table"`` (JSON string holding ``"header"`` and
            ``"id"``), ``"question"`` and ``"sql"``.

    Returns:
        ``{"prompt": <prompt text>}``.
    """
    table = json.loads(example["table"])
    header = table["header"]
    question = example["question"]
    gold = example["sql"]
    table_id = table["id"]

    prompt_lines = [
        f"[GOLD_SQL]{gold}[/GOLD_SQL]",
        f"[TABLE_ID]{table_id}[/TABLE_ID]",
        "Generate a SQL query to answer the question based on the table schema.",
        f"Schema: {header}",
        f"Question: {question}",
        "Wrap your answer in <sql> tags.",
        "Answer:",
    ]
    return {"prompt": "\n".join(prompt_lines)}

def strict_format_reward_func(completions, **kwargs):
    """Reward completions that consist of a ``<sql>...</sql>`` wrapper.

    Per completion (after stripping surrounding whitespace):
      * 2.0 - starts with ``<sql>`` and ends with ``</sql>``;
      * 0.5 - wrapped correctly but the text contains ``Question:``, ``Table:``
        or ``Schema:`` (the prompt was echoed back);
      * 0.0 - anything else.

    Args:
        completions: Iterable of completion strings.
        **kwargs: Accepted and ignored.

    Returns:
        List of floats, one per completion, in input order.
    """
    echo_markers = ("Question:", "Table:", "Schema:")

    def grade(raw):
        body = raw.strip()
        if not (body.startswith("<sql>") and body.endswith("</sql>")):
            return 0.0
        echoed = any(marker in body for marker in echo_markers)
        return 0.5 if echoed else 2.0

    return [grade(raw) for raw in completions]

def _score_sql_completion(prompt, completion):
    """Score one completion against the gold query embedded in its prompt.

    Raises whatever the prompt parsing, gold-SQL construction or gold-SQL
    execution raises; the caller turns that into a reward of 0.0.
    """
    import ast

    # Recover the metadata that the prompt formatter embedded in the prompt.
    gt_json = json.loads(prompt.split("[GOLD_SQL]")[1].split("[/GOLD_SQL]")[0])
    # The schema is the repr of a Python list. literal_eval only accepts literals,
    # whereas eval() would execute arbitrary expressions found in the prompt text.
    schema_cols = ast.literal_eval(prompt.split("Schema: ")[1].split("\nQuestion")[0])
    table_id = prompt.split("[TABLE_ID]")[1].split("[/TABLE_ID]")[0]

    if "<sql>" not in completion or "</sql>" not in completion:
        return 0.0

    gen_sql = completion.split("<sql>")[1].split("</sql>")[0].strip()

    if "*" in gen_sql:
        return -2.0

    # Syntax check. An empty query must not earn the syntax reward.
    if not gen_sql:
        return -1.0
    try:
        statements = sqlglot.parse(gen_sql, read="sqlite")
    except Exception:
        return -1.0  # Penalty for broken SQL
    if not any(statement is not None for statement in statements):
        return -1.0
    score = 3.0

    # NOTE(review): gold_json_to_sql_string is not defined in the visible part of
    # this notebook — confirm it is defined before training starts.
    gold_sql_str = gold_json_to_sql_string(gt_json, table_id, schema_cols)
    gold_res = set(training_conn.execute(gold_sql_str).fetchall())

    # The generated SQL is untrusted model output: run it with writes disabled so
    # a generated DROP/INSERT/UPDATE cannot damage the shared training database.
    training_conn.execute("PRAGMA query_only = ON;")
    try:
        gen_res = set(training_conn.execute(gen_sql).fetchall())
    except Exception:
        score -= 2.0  # Penalty if the SQL is valid syntax but fails on THIS table
    else:
        if gen_res == gold_res:
            score += 5.0
    finally:
        training_conn.execute("PRAGMA query_only = OFF;")

    return max(0.0, score)  # Ensure no negative total scores for valid attempts


def execution_and_syntax_combo_reward(prompts, completions, **kwargs):
    """Reward syntactically valid SQL and, on top of that, correct execution.

    Maximum reward: 8.0. Per (prompt, completion) pair:
      *  0.0 - no ``<sql>...</sql>`` block, or scoring itself failed;
      * -2.0 - the extracted SQL contains ``*``;
      * -1.0 - the extracted SQL is empty or does not parse as SQLite SQL;
      *  1.0 - parses (+3) but raises when executed (-2);
      *  3.0 - parses and runs, but the result set differs from the gold result;
      *  8.0 - parses, runs, and the result set equals the gold result (+5).

    Args:
        prompts: Prompt strings containing the ``[GOLD_SQL]``, ``[TABLE_ID]`` and
            ``Schema:`` sections.
        completions: Completion strings, aligned with ``prompts``.
        **kwargs: Accepted and ignored.

    Returns:
        List of floats, one per pair, in input order.

    Reads the module-level ``training_conn`` SQLite connection. Unexpected errors
    while scoring still yield 0.0 so training is not interrupted, but they are
    reported through ``warnings.warn`` instead of being dropped silently.
    """
    import warnings

    rewards = []
    for prompt, completion in zip(prompts, completions):
        try:
            rewards.append(_score_sql_completion(prompt, completion))
        except Exception as exc:
            warnings.warn(
                f"execution_and_syntax_combo_reward: scoring failed, reward set to 0.0 "
                f"({type(exc).__name__}: {exc})",
                RuntimeWarning,
            )
            rewards.append(0.0)
    return rewards

In [None]:
# Add the "prompt" column (with the embedded gold-SQL / table-id tags) to every training example.
dataset = dataset.map(format_wikisql_with_gold_target)

In [None]:

# Load the tokenizer; fall back to EOS as the padding token when the checkpoint defines none.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Load Model
# NOTE(review): the weights are loaded in bfloat16 here while the training arguments
# below set bf16=False / fp16=True — confirm this combination is intended.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)

# LoRA Configuration: adapters on the attention projection modules only.
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
    task_type="CAUSAL_LM",
    lora_dropout=0.05,
)

# GRPO Config
training_args = GRPOConfig(
    output_dir=OUTPUT_DIR,
    learning_rate=1e-5,
    bf16=False, # Disabled because of Colab
    fp16=True,
    tf32=False,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    num_generations=4,
    max_prompt_length=MAX_PROMPT_LENGTH,
    max_completion_length=MAX_COMPLETION_LENGTH,
    num_train_epochs=1,
    logging_steps=10,
    save_strategy="steps",
    save_steps=100,
    report_to="none",
    use_vllm=False
)

# NOTE(review): no reward weights are passed, so the two reward functions are
# presumably combined with the trainer's default weighting; the "Weight: 2.0 / 8.0"
# notes in their docstrings describe their maximum values — confirm.
trainer = GRPOTrainer(
    model=model,
    reward_funcs=[
        strict_format_reward_func,
        execution_and_syntax_combo_reward
    ],
    args=training_args,
    train_dataset=dataset,
    peft_config=peft_config
)

print("Starting GRPO Training for Text-to-SQL with full logical rewards...")
try:
    trainer.train()
except KeyboardInterrupt:
    # Interrupting the cell is treated as "stop early": fall through and save what was trained so far.
    print("\nTraining interrupted by user. Proceeding to final save...")

# Save the final adapter
trainer.save_model(OUTPUT_DIR)
print(f"Training complete. Model saved to {OUTPUT_DIR}")

The model is already on multiple devices. Skipping the move to device specified in `args`.


Starting GRPO Training for Text-to-SQL with full logical rewards...


Step,Training Loss
10,0.0148
20,0.0104
30,0.0117
40,0.0116
50,0.0273
60,0.0094
70,0.0
80,0.0006
90,0.0023
100,0.0126


In [40]:
import torch
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM
import json
import pandas as pd
from datasets import Dataset, Features, Value
import os # <-- IMPORTED OS FOR PATH JOINING

# Directory the LoRA adapter is loaded from. It has the same value as the
# OUTPUT_DIR that the training cell passes to trainer.save_model().
OUTPUT_DIR = "/granite-grpo-wikisql"
MAX_PROMPT_LENGTH = 512
MAX_COMPLETION_LENGTH = 128

# Data directory resolved against the current working directory, so the data is
# only found when the notebook runs from the folder that contains `data/`.
LOCAL_DATA_DIR = os.path.join(os.getcwd(), "data")
# --- Utility functions for evaluation ---
def format_wikisql(example):
    """Build the inference prompt for one example.

    Unlike ``format_wikisql_with_gold_target`` this prompt has no ``[GOLD_SQL]`` /
    ``[TABLE_ID]`` lines, and its instruction line includes the worked example
    ``<sql>SELECT * FROM table</sql>``.

    Args:
        example: Mapping with ``"table"`` (JSON string holding ``"header"``) and
            ``"question"``.

    Returns:
        ``{"prompt": <prompt text>}``.
    """
    header = json.loads(example["table"])["header"]
    question = example["question"]

    prompt_lines = [
        "Generate a SQL query to answer the question based on the table schema.",
        f"Schema: {header}",
        f"Question: {question}",
        "Wrap your answer in <sql> tags. For example: <sql>SELECT * FROM table</sql>.",
        "Answer:",
    ]
    return {"prompt": "\n".join(prompt_lines)}

def get_validation_data(n_samples=10):
    """Load dev examples, attach their tables and add the inference prompt.

    This is a lighter version of ``merge_and_serialize``: it does not apply
    ``restructure_table`` or ``enforce_sql_types``, so the serialized table has no
    ``id`` field and condition values keep their original types.

    Args:
        n_samples: Number of leading dev rows to load. Defaults to 10, the value
            that used to be hard-coded; pass ``None`` to load the whole dev split.

    Returns:
        A ``Dataset`` with the ``WIKISQL_FEATURES`` columns (``sql`` and ``table``
        JSON-encoded) plus a ``prompt`` column produced by ``format_wikisql``.
    """
    df_main_dev = pd.read_json(f"{LOCAL_DATA_DIR}/dev.jsonl", lines=True)
    df_table_dev = pd.read_json(f"{LOCAL_DATA_DIR}/dev.tables.jsonl", lines=True)
    table_dict_dev = df_table_dev.set_index('id').T.to_dict('dict')

    selected = df_main_dev if n_samples is None else df_main_dev.head(n_samples)
    df_subset = selected.copy()
    df_subset['table'] = df_subset['table_id'].map(table_dict_dev)

    # Nested structures are stored as JSON strings to match WIKISQL_FEATURES.
    df_subset['sql'] = df_subset['sql'].apply(json.dumps)
    df_subset['table'] = df_subset['table'].apply(json.dumps)

    # preserve_index=False keeps a non-default pandas index from being added as an
    # extra column that the fixed feature schema lacks.
    dev_dataset = Dataset.from_pandas(
        df_subset.drop(columns=['table_id']),
        features=WIKISQL_FEATURES,
        preserve_index=False,
    )
    return dev_dataset.map(format_wikisql)

# ==========================================
# 3. Model Loading and Inference
# ==========================================
# NOTE: MODEL_NAME is not redefined in this cell; it comes from the configuration cell earlier in the notebook.
print("1. Loading Tokenizer and Base Model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Fall back to EOS as the padding token when the checkpoint defines none.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)

print("2. Loading and Merging LoRA Adapter...")
# Load the adapter saved in OUTPUT_DIR on top of the base model, then fold it into the base weights.
model = PeftModel.from_pretrained(base_model, OUTPUT_DIR)
model = model.merge_and_unload()
model.eval()
print("✅ Model adapter merged and ready for inference.")

# ==========================================
# 4. Testing Inference
# ==========================================
validation_dataset = get_validation_data()
test_sample = validation_dataset[1] # Use the second example

prompt = test_sample["prompt"]
ground_truth_sql = json.loads(test_sample["sql"]) # Deserialize for comparison

print("\n--- TEST CASE ---")
print(f"Question: {test_sample['question']}")
print(f"Ground Truth SQL: {ground_truth_sql}")
print("--- Generating Query ---")

inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
# Number of prompt tokens, used below to separate the completion from the prompt.
prompt_token_count = inputs["input_ids"].shape[-1]

# Generate the completion (greedy decoding)
with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=MAX_COMPLETION_LENGTH,
        do_sample=False,
        pad_token_id=tokenizer.pad_token_id
    )

# Full decoded sequence (prompt + completion), kept for inspection.
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
# Decode only the newly generated tokens. Cutting the decoded text at len(prompt)
# is only correct when decoding reproduces the prompt character for character,
# which tokenizers do not guarantee.
generated_completion = tokenizer.decode(
    outputs[0][prompt_token_count:], skip_special_tokens=True
).strip()

# --- Post-Processing ---
try:
    # Extract the SQL from between the tags (as done in the reward functions)
    generated_sql = generated_completion.split("<sql>")[1].split("</sql>")[0].strip()
except IndexError:
    generated_sql = "Extraction Failed (Tags Missing)"

print("\n--- RESULTS ---")
print(f"Generated Raw Text: {generated_completion}")
print(f"Extracted SQL: {generated_sql}")
print("\n--- FINAL VERDICT ---")

# NOTE: this only checks that non-empty text was found between <sql> tags; the
# query is neither parsed nor compared with the ground truth.
if generated_sql and generated_sql != "Extraction Failed (Tags Missing)":
    print("SUCCESS: The model generated syntactically plausible SQL.")
else:
    print("FAILURE: The model failed to adhere to the required format.")

1. Loading Tokenizer and Base Model...
2. Loading and Merging LoRA Adapter...
✅ Model adapter merged and ready for inference.


Map:   0%|          | 0/10 [00:00<?, ? examples/s]


--- TEST CASE ---
Question: How many schools did player number 3 play at?
Ground Truth SQL: {'sel': 5, 'conds': [[1, 0, '3']], 'agg': 3}
--- Generating Query ---

--- RESULTS ---
Generated Raw Text: <sql>SELECT * FROM table</sql>
Extracted SQL: SELECT * FROM table

--- FINAL VERDICT ---
SUCCESS: The model generated syntactically plausible SQL.


In [42]:
test_sample_1 = validation_dataset[0]

prompt = test_sample_1["prompt"]
ground_truth_sql = json.loads(test_sample_1["sql"])

print("\n--- TEST CASE 2 ---")
print(f"Question: {test_sample_1['question']}")
print(f"Ground Truth SQL: {ground_truth_sql}")
print("--- Generating Query ---")

inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
# Number of prompt tokens, used below to separate the completion from the prompt.
prompt_token_count = inputs["input_ids"].shape[-1]

# Generate the completion
with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=MAX_COMPLETION_LENGTH,
        do_sample=False, # Use greedy decoding for predictable results
        pad_token_id=tokenizer.pad_token_id
    )

# Full decoded sequence (prompt + completion), kept for inspection.
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
# Decode only the newly generated tokens. Cutting the decoded text at len(prompt)
# is only correct when decoding reproduces the prompt character for character,
# which tokenizers do not guarantee.
generated_completion = tokenizer.decode(
    outputs[0][prompt_token_count:], skip_special_tokens=True
).strip()

# --- Post-Processing ---
try:
    generated_sql = generated_completion.split("<sql>")[1].split("</sql>")[0].strip()
except IndexError:
    generated_sql = "Extraction Failed (Tags Missing)"

print("\n--- RESULTS ---")
print(f"Extracted SQL: {generated_sql}")


--- TEST CASE 2 ---
Question: What position does the player who played for butler cc (ks) play?
Ground Truth SQL: {'sel': 3, 'conds': [[5, 0, 'Butler CC (KS)']], 'agg': 0}
--- Generating Query ---

--- RESULTS ---
Extracted SQL: Extraction Failed (Tags Missing)


In [43]:
test_sample_5 = validation_dataset[5]

prompt = test_sample_5["prompt"]
ground_truth_sql = json.loads(test_sample_5["sql"])

print("\n--- TEST CASE 3 ---")
print(f"Question: {test_sample_5['question']}")
print(f"Ground Truth SQL: {ground_truth_sql}")
print("--- Generating Query ---")

inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
# Number of prompt tokens, used below to separate the completion from the prompt.
prompt_token_count = inputs["input_ids"].shape[-1]

# Generate the completion (greedy decoding)
with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=MAX_COMPLETION_LENGTH,
        do_sample=False,
        pad_token_id=tokenizer.pad_token_id
    )

# Full decoded sequence (prompt + completion), kept for inspection.
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
# Decode only the newly generated tokens. Cutting the decoded text at len(prompt)
# is only correct when decoding reproduces the prompt character for character,
# which tokenizers do not guarantee.
generated_completion = tokenizer.decode(
    outputs[0][prompt_token_count:], skip_special_tokens=True
).strip()

# --- Post-Processing ---
try:
    generated_sql = generated_completion.split("<sql>")[1].split("</sql>")[0].strip()
except IndexError:
    generated_sql = "Extraction Failed (Tags Missing)"

print("\n--- RESULTS ---")
print(f"Extracted SQL: {generated_sql}")


--- TEST CASE 3 ---
Question: Who are all of the players on the Westchester High School club team?
Ground Truth SQL: {'sel': 0, 'conds': [[5, 0, 'Westchester High School']], 'agg': 0}
--- Generating Query ---

--- RESULTS ---
Extracted SQL: SELECT * FROM table


In [44]:
def get_raw_model_output(model, tokenizer, question, schema, max_new_tokens=128, temperature=0.7):
    """Sample one completion and return it undecorated, special tokens included.

    Args:
        model: Model exposing ``generate(...)`` and a ``device`` attribute.
        tokenizer: Tokenizer used to encode the prompt and decode the output.
        question: Natural-language question placed in the prompt.
        schema: Text placed after ``Schema: `` in the prompt.
        max_new_tokens: Upper bound on generated tokens (default 128).
        temperature: Sampling temperature (default 0.7).

    Returns:
        The decoded text of the newly generated tokens only, with special tokens
        (such as the end-of-text marker) left in.
    """
    # 1. Prepare the prompt (No GOLD_SQL tags here, just like real life)
    prompt = (
        f"Generate a SQL query to answer the question based on the table schema.\n"
        f"Schema: {schema}\n"
        f"Question: {question}\n"
        f"Wrap your answer in <sql> tags. For example: <sql>SELECT * FROM table</sql>.\n"
        f"Answer:"
    )

    # 2. Tokenize, placing the inputs on the model's own device rather than a
    #    hard-coded "cuda" so this also works on other devices.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_token_count = len(inputs["input_ids"][0])

    # 3. Generate (sampling, so the output differs between calls)
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        pad_token_id=tokenizer.eos_token_id
    )

    # 4. Decode only the tokens after the prompt, keeping special tokens visible.
    #    Splitting the decoded text on "Answer:" returns the wrong span whenever
    #    that text also occurs in the question, the schema or the completion.
    return tokenizer.decode(outputs[0][prompt_token_count:], skip_special_tokens=False)

# --- RUN TEST ---
# Hand-written question and schema, passed as plain strings (the schema is inserted into the prompt as-is).
test_q = "How many schools did player number 3 play at?"
test_s = "['ID', 'Player', 'No.', 'School/Club Team', 'Years', 'Notes']"

# Generation samples (do_sample=True), so the printed output can change between runs.
raw_response = get_raw_model_output(model, tokenizer, test_q, test_s)

print("--- RAW MODEL OUTPUT ---")
print(raw_response)
print("------------------------")

--- RAW MODEL OUTPUT ---
 <sql>SELECT * FROM table</sql><|end_of_text|>
------------------------
