# RL Text-to-SQL with GRPO on Spider

Trains a Text-to-SQL model using GRPO (Group Relative Policy Optimization)
on the Spider benchmark.

**Before running:** `Runtime > Change runtime type > T4 GPU`

- Part A (sections 0–2): Setup — check the GPU, install dependencies, and write the source files
- Part B: Pipeline test with dummy data
- Part C: Train on Spider

## 0. Check GPU

In [None]:
!nvidia-smi
import torch

print(f"\nCUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    # Device properties expose `total_memory` (in bytes); `total_mem` does not
    # exist and made this cell raise AttributeError on every GPU runtime.
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")

## 1. Install Dependencies

In [None]:
!pip install -q torch transformers==4.46.3 peft==0.13.2 bitsandbytes==0.44.1 accelerate==1.1.1 datasets sqlparse func-timeout pyyaml tqdm matplotlib

## 2. Create Source Files

Write all the project source files into a `src/` package for Colab.

In [None]:
import os

# `src/` receives the %%writefile modules below; `outputs/` holds plots and checkpoints.
for _required_dir in ("src", "outputs"):
    os.makedirs(_required_dir, exist_ok=True)

In [None]:
%%writefile src/__init__.py
# RL Text-to-SQL with GRPO

In [None]:
%%writefile src/sql_executor.py
"""
Safe SQL execution with timeout (threading-based).
"""
import sqlite3
import threading
from typing import Optional, Tuple, Any


class SQLTimeoutError(Exception):
    """Error type for a SQL query that exceeds its time budget.

    NOTE(review): nothing in this module raises it — SQLExecutor.execute reports
    timeouts through its ``(success, rows, error)`` return tuple instead.
    """


def _run_query(sql, db_path, result_holder):
    """Run a SQL query in a thread and store the result."""
    try:
        conn = sqlite3.connect(db_path, timeout=5)
        conn.text_factory = str
        cursor = conn.cursor()
        cursor.execute(sql)
        result_holder["results"] = cursor.fetchall()
        result_holder["success"] = True
        conn.close()
    except sqlite3.Error as e:
        result_holder["error"] = f"SQLite error: {e}"
    except Exception as e:
        result_holder["error"] = f"Execution error: {e}"


class SQLExecutor:
    """Run SQL against SQLite database files with a wall-clock timeout and compare result sets."""

    def __init__(self, timeout=10):
        self.timeout = timeout  # seconds to wait for a query before reporting a timeout

    def execute(self, sql, db_path):
        """Execute ``sql`` on the database at ``db_path`` in a worker thread.

        Returns:
            ``(True, rows, None)`` on success, otherwise ``(False, None, message)``
            for empty SQL, a timeout, or an execution error.
        """
        if not sql or not sql.strip():
            return False, None, "Empty SQL query"
        result_holder = {"success": False, "results": None, "error": None}
        # daemon=True so a runaway query can never keep the interpreter alive at exit.
        worker = threading.Thread(target=_run_query, args=(sql, db_path, result_holder), daemon=True)
        worker.start()
        worker.join(timeout=self.timeout)
        if worker.is_alive():
            # Abort the query if the worker published its connection, so timed-out
            # threads do not keep burning CPU for the rest of training.
            self._interrupt(result_holder.get("conn"))
            return False, None, f"SQL execution timed out after {self.timeout}s"
        if result_holder["success"]:
            return True, result_holder["results"], None
        # "error" is pre-seeded with None, so a dict.get default would never apply.
        return False, None, result_holder["error"] or "Unknown error"

    @staticmethod
    def _interrupt(conn):
        """Best-effort abort of a query still running on ``conn`` (which may be None or closed)."""
        if conn is None:
            return
        try:
            conn.interrupt()
        except sqlite3.ProgrammingError:
            # The worker finished and closed the connection after is_alive() was checked.
            pass

    def compare_results(self, pred_results, gold_results):
        """Return True when both results hold the same rows with the same multiplicities.

        Row order is ignored, but duplicates are not. Comparing plain sets would
        let ``SELECT DISTINCT x`` match a gold ``SELECT x`` (and vice versa), which
        an RL policy can exploit.
        """
        from collections import Counter

        if pred_results is None or gold_results is None:
            return False
        try:
            return Counter(map(tuple, pred_results)) == Counter(map(tuple, gold_results))
        except (TypeError, ValueError):
            return pred_results == gold_results

    def execute_and_compare(self, pred_sql, gold_sql, db_path):
        """Execute predicted and gold SQL on the same database and compare their results.

        Returns a dict with ``pred_success``, ``gold_success``, ``execution_match``
        (False unless both ran and their results are equal), ``pred_error``,
        ``pred_results`` and ``gold_results``.
        """
        pred_ok, pred_res, pred_err = self.execute(pred_sql, db_path)
        gold_ok, gold_res, gold_err = self.execute(gold_sql, db_path)
        match = False
        if pred_ok and gold_ok:
            match = self.compare_results(pred_res, gold_res)
        return {"pred_success": pred_ok, "gold_success": gold_ok, "execution_match": match,
                "pred_error": pred_err, "pred_results": pred_res, "gold_results": gold_res}

In [None]:
%%writefile src/reward.py
"""
Multi-component reward function (SQL-R1 style).
"""
import re
import sqlparse
from typing import List, Optional
from dataclasses import dataclass
from src.sql_executor import SQLExecutor


@dataclass
class RewardConfig:
    """Reward weights used by ``RewardComputer.compute_reward``, plus the SQL timeout."""
    correct_execution: float = 1.0  # added when predicted and gold results match
    valid_but_wrong: float = 0.1  # added when the prediction executes but results differ
    invalid_sql: float = -0.5  # added when the prediction fails to execute (including timeouts)
    format_bonus: float = 0.2  # added whenever the response contains a ```sql fenced block
    partial_match_bonus: float = 0.3  # scaled by identifier overlap with gold; valid-but-wrong case only
    execution_timeout: int = 10  # seconds SQLExecutor waits for each query


@dataclass
class RewardResult:
    """Outcome of scoring one response with ``RewardComputer.compute_reward``."""
    total_reward: float  # sum of all applicable reward components
    execution_correct: bool  # predicted results equal the gold results
    sql_valid: bool  # predicted SQL executed without error
    format_ok: bool  # response contained a ```sql fenced block
    partial_match_score: float  # 1.0 if correct; identifier overlap in [0, 1] if valid-but-wrong; 0.0 if invalid
    error_message: Optional[str] = None  # execution error; set only when the SQL is invalid


class RewardComputer:
    """Turn a model response into a scalar reward by executing its SQL against the gold query.

    The reward is ``format_bonus`` (if the answer is in a ```sql fence) plus exactly one of:

    - ``correct_execution`` when the results match the gold query;
    - ``valid_but_wrong`` plus ``partial_match_bonus`` scaled by identifier overlap
      when the SQL runs but the results differ;
    - ``invalid_sql`` when the SQL fails to run.
    """

    # ```sql / ```sqlite fence. The \b keeps "sql" from matching the first three
    # letters of "sqlite", which used to leak "ite" into the extracted query.
    _SQL_FENCE = re.compile(r"```(?:sqlite|sql)\b\s*(.*?)\s*```", re.DOTALL | re.IGNORECASE)
    # Any other fence. A lowercase language tag alone on the fence line (```postgresql)
    # is skipped rather than captured as part of the query.
    _ANY_FENCE = re.compile(r"```(?:[a-z0-9_+-]*[ \t]*\n)?\s*(.*?)\s*```", re.DOTALL)
    # Last resort: a bare SELECT up to the first semicolon or the end of the text.
    _BARE_SELECT = re.compile(r"(SELECT\s+.*?)(?:;|\Z)", re.DOTALL | re.IGNORECASE)

    def __init__(self, config: "RewardConfig"):
        self.config = config
        self.executor = SQLExecutor(timeout=config.execution_timeout)

    def extract_sql_from_response(self, response):
        """Return the SQL in ``response``.

        Tries, in order: a ```sql/```sqlite fence, any fence, and a bare SELECT.
        Falls back to the whole stripped response.
        """
        for pattern in (self._SQL_FENCE, self._ANY_FENCE, self._BARE_SELECT):
            m = pattern.search(response)
            if m:
                return m.group(1).strip()
        return response.strip()

    def check_format(self, response):
        """True if the response wraps its answer in a ```sql (or ```sqlite) fenced block."""
        return self._SQL_FENCE.search(response) is not None

    def _extract_identifiers(self, parsed):
        """Lower-cased Name / Name.Placeholder token values of a parsed sqlparse statement."""
        ids = []
        for token in parsed.flatten():
            if token.ttype in (sqlparse.tokens.Name, sqlparse.tokens.Name.Placeholder):
                ids.append(token.value.lower())
        return ids

    def compute_partial_match(self, pred_sql, gold_sql):
        """Fraction of the gold query's identifiers that also appear in the prediction.

        Returns 0.0 when either query fails to parse or gold has no identifiers.
        """
        try:
            pred_parsed = sqlparse.parse(pred_sql)
            gold_parsed = sqlparse.parse(gold_sql)
            if not pred_parsed or not gold_parsed:
                return 0.0
            pred_tokens = set(self._extract_identifiers(pred_parsed[0]))
            gold_tokens = set(self._extract_identifiers(gold_parsed[0]))
            if not gold_tokens:
                return 0.0
            return len(pred_tokens & gold_tokens) / len(gold_tokens)
        except Exception:
            # Best-effort shaping term: an unparsable query simply earns no partial credit.
            return 0.0

    def compute_reward(self, response, gold_sql, db_path):
        """Score one response against ``gold_sql`` on the database at ``db_path``."""
        reward = 0.0
        pred_sql = self.extract_sql_from_response(response)
        format_ok = self.check_format(response)
        if format_ok:
            reward += self.config.format_bonus
        exec_result = self.executor.execute_and_compare(pred_sql, gold_sql, db_path)
        if exec_result["execution_match"]:
            reward += self.config.correct_execution
            return RewardResult(total_reward=reward, execution_correct=True, sql_valid=True,
                                format_ok=format_ok, partial_match_score=1.0)
        if exec_result["pred_success"]:
            reward += self.config.valid_but_wrong
            partial = self.compute_partial_match(pred_sql, gold_sql)
            reward += self.config.partial_match_bonus * partial
            return RewardResult(total_reward=reward, execution_correct=False, sql_valid=True,
                                format_ok=format_ok, partial_match_score=partial)
        reward += self.config.invalid_sql
        return RewardResult(total_reward=reward, execution_correct=False, sql_valid=False,
                            format_ok=format_ok, partial_match_score=0.0, error_message=exec_result["pred_error"])

    def compute_group_rewards(self, responses, gold_sql, db_path):
        """Score every response of a group against the same gold query and database."""
        return [self.compute_reward(resp, gold_sql, db_path) for resp in responses]

In [None]:
%%writefile src/data_utils.py
"""
Spider dataset loading and prompt formatting.
"""
import json
import os
import sqlite3
from typing import Dict, List, Optional
from dataclasses import dataclass


@dataclass
class Text2SQLSample:
    """One Text-to-SQL example: a question, its gold SQL, and the database it targets."""
    question: str  # natural-language question
    gold_sql: str  # reference SQL (Spider "query" key, or "SQL" as a fallback)
    db_id: str  # database identifier, also the folder/file stem of the .sqlite file
    db_path: str  # path to the SQLite file the SQL is executed against
    schema: str  # CREATE TABLE statements shown to the model in the prompt


# System turn for every chat prompt (see build_chat_messages). It asks for the answer
# in a ```sql fenced block, which RewardComputer.check_format rewards.
SYSTEM_PROMPT = ("You are an expert SQL assistant. Given a natural language question "
    "and a database schema, generate the correct SQL query.\n\n"
    "Rules:\n- Output ONLY the SQL query wrapped in ```sql``` tags\n"
    "- Use proper SQL syntax for SQLite\n- Do not include explanations outside the SQL block")


def get_schema_from_db(db_path):
    """Return the ``CREATE TABLE`` statements of a SQLite database, separated by blank lines.

    The database is opened read-only, so a wrong path yields an error string
    instead of silently creating an empty database file. The connection is
    closed even if the query fails. SQLite's internal tables (e.g.
    ``sqlite_sequence``, created by AUTOINCREMENT) are excluded because they are
    not part of the user schema the model should query.

    Returns:
        The schema text, or a string starting with ``"-- Error reading schema:"``
        when the database cannot be read. The error is embedded rather than raised.
    """
    from contextlib import closing
    from pathlib import Path

    try:
        uri = Path(db_path).resolve().as_uri() + "?mode=ro"
        with closing(sqlite3.connect(uri, uri=True)) as conn:
            tables = conn.execute(
                "SELECT sql FROM sqlite_master "
                "WHERE type='table' AND sql IS NOT NULL AND substr(name, 1, 7) <> 'sqlite_'"
            ).fetchall()
        return "\n\n".join(t[0] for t in tables if t[0])
    except Exception as e:
        return f"-- Error reading schema: {e}"


def format_prompt(question, schema):
    """Build the user-turn text: the schema, then the question, then the output-format instruction."""
    template = ("Given the following database schema:\n\n{schema}\n\n"
                "Question: {question}\n\n"
                "Generate the SQL query. Wrap it in ```sql``` tags.")
    return template.format(schema=schema, question=question)


def build_chat_messages(question, schema):
    """Return the chat for one question: the fixed system turn followed by the user turn."""
    system_turn = {"role": "system", "content": SYSTEM_PROMPT}
    user_turn = {"role": "user", "content": format_prompt(question, schema)}
    return [system_turn, user_turn]


def load_spider_dataset(data_file, db_dir, max_samples=None):
    """Load Spider-format questions whose SQLite database exists under ``db_dir``.

    Args:
        data_file: JSON list of items with ``question``, ``db_id`` and ``query``
            (or ``SQL``) keys.
        db_dir: directory containing ``<db_id>/<db_id>.sqlite`` files.
        max_samples: stop after this many loaded samples; ``None`` means no limit
            (``0`` now really means zero rather than "unlimited").

    Returns:
        A list of :class:`Text2SQLSample`. Items whose database file is missing
        are skipped and counted in a printed warning.

    Raises:
        FileNotFoundError: if ``data_file`` does not exist.
    """
    if not os.path.exists(data_file):
        raise FileNotFoundError(f"Spider data file not found: {data_file}")
    with open(data_file, "r", encoding="utf-8") as f:
        data = json.load(f)
    samples = []
    skipped = 0
    # Many questions share one database; read each schema only once.
    schema_cache = {}
    for item in data:
        if max_samples is not None and len(samples) >= max_samples:
            break
        db_id = item["db_id"]
        db_path = os.path.join(db_dir, db_id, f"{db_id}.sqlite")
        if not os.path.exists(db_path):
            skipped += 1
            continue
        if db_id not in schema_cache:
            schema_cache[db_id] = get_schema_from_db(db_path)
        samples.append(Text2SQLSample(
            question=item["question"], gold_sql=item.get("query", item.get("SQL", "")),
            db_id=db_id, db_path=db_path, schema=schema_cache[db_id]))
    if skipped > 0:
        print(f"Warning: Skipped {skipped} samples (missing db files)")
    print(f"Loaded {len(samples)} samples from {data_file}")
    return samples

In [None]:
%%writefile src/model_utils.py
"""
Model loading with QLoRA (float16 for T4 compatibility).
"""
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, TaskType


def load_tokenizer(model_name):
    """Load the tokenizer for ``model_name`` with padding on the left.

    If the tokenizer defines no pad token, the EOS token (and its id) is reused for padding.
    """
    tok = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, padding_side="left")
    if tok.pad_token is not None:
        return tok
    tok.pad_token = tok.eos_token
    tok.pad_token_id = tok.eos_token_id
    return tok


def load_model_with_lora(model_name, lora_r=16, lora_alpha=32, lora_dropout=0.05,
                         target_modules=None, quantization="4bit"):
    """Load a causal LM (optionally bitsandbytes-quantised) and wrap it with LoRA adapters.

    Args:
        model_name: Hugging Face model id or local path.
        lora_r, lora_alpha, lora_dropout: LoRA rank, alpha and dropout.
        target_modules: module names that receive adapters. Defaults to the
            q/k/v/o attention and gate/up/down MLP projections.
        quantization: ``"4bit"`` (NF4, double quantisation, fp16 compute),
            ``"8bit"``, or any other value for an unquantised fp16 model.

    Returns:
        The PEFT-wrapped model (its trainable-parameter summary is printed).
    """
    if target_modules is None:
        target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
    if quantization == "4bit":
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True, bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16, bnb_4bit_use_double_quant=True)
    elif quantization == "8bit":
        # Previously "8bit" loaded an unquantised model yet still ran
        # prepare_model_for_kbit_training on it below.
        bnb_config = BitsAndBytesConfig(load_in_8bit=True)
    else:
        bnb_config = None
    model = AutoModelForCausalLM.from_pretrained(
        model_name, quantization_config=bnb_config, torch_dtype=torch.float16,
        device_map="auto", trust_remote_code=True)
    if quantization in ("4bit", "8bit"):
        model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
    lora_config = LoraConfig(r=lora_r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
        target_modules=target_modules, task_type=TaskType.CAUSAL_LM, bias="none")
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()
    return model


def create_reference_model(model):
    """Snapshot the current LoRA weights to serve as the frozen reference policy.

    Returns a dict mapping every parameter name containing ``"lora_"`` to a
    detached copy of its data. Other (base-model) parameters are not copied.
    """
    snapshot = {
        pname: tensor.data.clone().detach()
        for pname, tensor in model.named_parameters()
        if "lora_" in pname
    }
    print(f"Reference policy: {len(snapshot)} LoRA tensors frozen")
    return snapshot

In [None]:
%%writefile src/grpo_trainer.py
"""
GRPO trainer for Text-to-SQL.
Generates K candidates, executes them, uses group-normalized advantages.
"""
import torch
import torch.nn.functional as F
from typing import List, Tuple
from dataclasses import dataclass
from src.reward import RewardComputer, RewardConfig, RewardResult
from src.data_utils import Text2SQLSample, build_chat_messages


@dataclass
class GRPOConfig:
    """Hyperparameters for GRPOTrainer."""
    group_size: int = 4  # completions sampled per question (the "group")
    clip_epsilon: float = 0.2  # ratio is clipped to [1 - eps, 1 + eps] in the surrogate
    kl_coeff: float = 0.05  # weight of GRPOTrainer.compute_kl_penalty in the loss
    temperature: float = 0.7  # sampling temperature passed to generate()
    max_new_tokens: int = 256  # generation length cap per completion
    top_p: float = 0.95  # nucleus-sampling threshold passed to generate()


class GRPOTrainer:
    """GRPO trainer that updates a LoRA policy one question at a time.

    Each :meth:`train_step`:

    1. samples ``group_size`` completions for a single question;
    2. rewards them by executing the extracted SQL;
    3. normalises the rewards within the group into advantages;
    4. takes one optimizer step on the per-token clipped surrogate plus a
       penalty pulling the LoRA weights towards their reference snapshot.
    """

    def __init__(self, model, tokenizer, ref_params, optimizer, grpo_config, reward_config, device="cuda"):
        self.model = model
        self.tokenizer = tokenizer
        self.ref_params = ref_params  # parameter name -> frozen tensor (see compute_kl_penalty)
        self.optimizer = optimizer
        self.config = grpo_config
        self.reward_computer = RewardComputer(reward_config)
        self.device = device
        self.global_step = 0

    def _encode_prompt(self, sample):
        """Tokenise the chat prompt for ``sample``; shared by generation and the loss so both see identical ids."""
        messages = build_chat_messages(sample.question, sample.schema)
        prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        # NOTE(review): with the tokenizer's default right-side truncation, prompts longer
        # than 1024 tokens lose their tail (the question and the assistant header).
        return self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024).to(self.device)

    @torch.no_grad()
    def generate_group(self, sample):
        """Sample ``group_size`` completions one at a time.

        Returns:
            ``(responses, old_log_probs, generated_ids)``. The last two are
            ``(group_size, max_len)`` tensors right-padded with 0.0 and
            ``pad_token_id``. Log-probs are computed from the raw logits divided
            by the sampling temperature — the same quantity
            :meth:`compute_policy_loss` recomputes. The top-p-processed
            ``scores`` used previously (renormalised, with -inf entries) made the
            importance ratio differ from 1 even before any update.
        """
        inputs = self._encode_prompt(sample)
        prompt_length = inputs["input_ids"].shape[1]
        temperature = self.config.temperature
        self.model.eval()
        responses, token_log_probs, generated = [], [], []
        for _ in range(self.config.group_size):
            outputs = self.model.generate(
                **inputs, max_new_tokens=self.config.max_new_tokens, do_sample=True,
                temperature=temperature, top_p=self.config.top_p,
                num_return_sequences=1, return_dict_in_generate=True, output_logits=True,
                pad_token_id=self.tokenizer.pad_token_id)
            gen_ids = outputs.sequences[0, prompt_length:]
            # outputs.logits holds one (batch, vocab) tensor of unprocessed logits per step.
            logits = torch.stack(outputs.logits, dim=1)[0, : gen_ids.shape[0]]
            log_probs = F.log_softmax(logits / temperature, dim=-1)
            token_log_probs.append(log_probs.gather(-1, gen_ids.unsqueeze(-1)).squeeze(-1))
            generated.append(gen_ids)
            responses.append(self.tokenizer.decode(gen_ids, skip_special_tokens=True))
        max_len = max(ids.shape[0] for ids in generated)
        padded_log_probs = torch.zeros(self.config.group_size, max_len, device=self.device)
        padded_gen_ids = torch.full((self.config.group_size, max_len), self.tokenizer.pad_token_id,
                                    device=self.device, dtype=generated[0].dtype)
        for k, (ids, lps) in enumerate(zip(generated, token_log_probs)):
            padded_gen_ids[k, : ids.shape[0]] = ids
            padded_log_probs[k, : ids.shape[0]] = lps
        return responses, padded_log_probs, padded_gen_ids

    def compute_advantages(self, rewards):
        """Group-normalise rewards: ``(r - mean) / (std + 1e-8)``.

        A group with fewer than two rewards has an undefined (NaN) sample std,
        which used to turn the whole loss into NaN. Such groups now get zero advantage.
        """
        rewards_tensor = torch.tensor(rewards, dtype=torch.float32)
        if rewards_tensor.numel() < 2:
            return torch.zeros_like(rewards_tensor).to(self.device)
        return ((rewards_tensor - rewards_tensor.mean()) / (rewards_tensor.std() + 1e-8)).to(self.device)

    def compute_kl_penalty(self):
        """Mean squared-distance penalty between current and reference LoRA weights.

        NOTE: this is an L2 weight-space regulariser (summed squared difference
        per tensor, averaged over tensors), not a KL divergence between output
        distributions, despite the name.
        """
        kl = torch.tensor(0.0, device=self.device)
        count = 0
        for name, param in self.model.named_parameters():
            if name in self.ref_params:
                kl = kl + F.mse_loss(param, self.ref_params[name].to(param.device), reduction="sum")
                count += 1
        return kl / max(count, 1)

    def compute_policy_loss(self, sample, responses, old_log_probs, generated_ids, advantages):
        """Return ``(loss, metrics)`` for one group.

        The loss is the GRPO clipped surrogate, clipped per token and averaged
        over each completion's tokens and then over completions, plus
        ``kl_coeff * compute_kl_penalty()``. ``responses`` is accepted for
        interface compatibility but unused.
        """
        # NOTE(review): train mode enables LoRA dropout (absent during generation);
        # presumably kept so gradient checkpointing stays active for memory.
        self.model.train()
        prompt_ids = self._encode_prompt(sample)["input_ids"]
        prompt_len = prompt_ids.shape[1]
        temperature = self.config.temperature
        eps = self.config.clip_epsilon
        total_loss = torch.tensor(0.0, device=self.device)
        valid_samples = 0
        for k in range(self.config.group_size):
            # Drop padding; apply the same mask to the stored log-probs so they stay aligned.
            keep = generated_ids[k] != self.tokenizer.pad_token_id
            gen_ids_k = generated_ids[k][keep].unsqueeze(0)
            if gen_ids_k.shape[1] == 0:
                continue
            full_ids = torch.cat([prompt_ids, gen_ids_k], dim=1)
            logits = self.model(full_ids, use_cache=False).logits
            # Position i predicts token i+1, so completion tokens are scored by positions prompt_len-1 .. end-1.
            gen_logits = logits[:, prompt_len - 1:-1, :] / temperature
            new_log_probs = F.log_softmax(gen_logits, dim=-1)
            gen_len = min(gen_ids_k.shape[1], new_log_probs.shape[1])
            new_token_lp = torch.gather(new_log_probs[:, :gen_len, :], 2,
                                        gen_ids_k[:, :gen_len].unsqueeze(-1)).squeeze(-1)
            old_token_lp = old_log_probs[k][keep][:gen_len].unsqueeze(0).detach()
            ratio = torch.exp(new_token_lp - old_token_lp)
            adv = advantages[k]
            surr1 = ratio * adv
            surr2 = torch.clamp(ratio, 1.0 - eps, 1.0 + eps) * adv
            # Per-token clipping as in PPO/GRPO; clipping only the mean ratio let
            # individual tokens move arbitrarily far.
            total_loss = total_loss + (-torch.min(surr1, surr2)).mean()
            valid_samples += 1
        if valid_samples > 0:
            total_loss = total_loss / valid_samples
        kl_penalty = self.compute_kl_penalty()
        loss = total_loss + self.config.kl_coeff * kl_penalty
        return loss, {"policy_loss": total_loss.item(), "kl_penalty": kl_penalty.item(),
                      "total_loss": loss.item(), "mean_advantage": advantages.mean().item()}

    def train_step(self, sample):
        """Run one GRPO update on ``sample`` and return logging metrics.

        The update samples a group, scores it, computes advantages, then
        backpropagates with gradient-norm clipping at 1.0 and steps the optimizer.
        """
        responses, old_log_probs, generated_ids = self.generate_group(sample)
        reward_results = self.reward_computer.compute_group_rewards(responses, sample.gold_sql, sample.db_path)
        rewards = [r.total_reward for r in reward_results]
        advantages = self.compute_advantages(rewards)
        loss, metrics = self.compute_policy_loss(sample, responses, old_log_probs, generated_ids, advantages)
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        self.optimizer.step()
        n_correct = sum(1 for r in reward_results if r.execution_correct)
        n_valid = sum(1 for r in reward_results if r.sql_valid)
        metrics.update({"mean_reward": sum(rewards)/len(rewards), "max_reward": max(rewards),
                        "correct_ratio": n_correct/len(rewards), "valid_ratio": n_valid/len(rewards),
                        "step": self.global_step})
        self.global_step += 1
        return metrics

---
## Part B: Pipeline Test (Dummy Data)

Quick check that the GRPO loop works before downloading Spider.

In [None]:
# Load model
# Loads the tokenizer and the 4-bit QLoRA policy, then snapshots the LoRA weights
# as the reference policy. `tokenizer`, `model` and `ref_params` are used by later cells.
from src.model_utils import load_model_with_lora, load_tokenizer, create_reference_model
import torch

MODEL_NAME = "Qwen/Qwen2.5-Coder-3B-Instruct"

print("Loading tokenizer...")
tokenizer = load_tokenizer(MODEL_NAME)

print("Loading model with QLoRA (4-bit)...")
torch.cuda.empty_cache()  # release cached, unused allocator blocks before the large load
model = load_model_with_lora(model_name=MODEL_NAME, lora_r=16, lora_alpha=32, lora_dropout=0.05, quantization="4bit")

print("\nCreating reference policy...")
ref_params = create_reference_model(model)

print(f"\nGPU memory: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")

In [None]:
# Create dummy test data for pipeline verification
# Builds a throwaway two-table SQLite database in a temp dir and five
# question/gold-SQL pairs against it (`test_samples`, used by the next cell).
import sqlite3, tempfile
from src.data_utils import Text2SQLSample, get_schema_from_db

tmp_dir = tempfile.mkdtemp()  # NOTE(review): never removed; acceptable for a Colab session
test_db_path = f"{tmp_dir}/test.sqlite"

conn = sqlite3.connect(test_db_path)
c = conn.cursor()
c.execute("CREATE TABLE employees (id INTEGER PRIMARY KEY, name TEXT, department TEXT, salary REAL)")
c.execute("CREATE TABLE departments (id INTEGER PRIMARY KEY, name TEXT, budget REAL)")
# 5 employees across 3 departments.
c.executemany("INSERT INTO employees VALUES (?,?,?,?)", [
    (1,"Alice","Engineering",120000), (2,"Bob","Marketing",90000),
    (3,"Charlie","Engineering",110000), (4,"Diana","HR",95000), (5,"Eve","Engineering",130000)])
c.executemany("INSERT INTO departments VALUES (?,?,?)", [
    (1,"Engineering",500000), (2,"Marketing",200000), (3,"HR",150000)])
conn.commit()
conn.close()

schema = get_schema_from_db(test_db_path)
# Positional Text2SQLSample fields: question, gold_sql, db_id, db_path, schema.
test_samples = [
    Text2SQLSample("How many employees are there?", "SELECT COUNT(*) FROM employees", "test", test_db_path, schema),
    Text2SQLSample("What is the average salary?", "SELECT AVG(salary) FROM employees", "test", test_db_path, schema),
    Text2SQLSample("List employees in Engineering.", "SELECT name FROM employees WHERE department='Engineering'", "test", test_db_path, schema),
    Text2SQLSample("What is the highest salary?", "SELECT MAX(salary) FROM employees", "test", test_db_path, schema),
    Text2SQLSample("How many departments are there?", "SELECT COUNT(*) FROM departments", "test", test_db_path, schema),
]
print(f"Created {len(test_samples)} test samples")

In [None]:
# Run 3 GRPO training steps on dummy data
# Fails loudly on a NaN/inf loss instead of unconditionally reporting success.
import math
from src.grpo_trainer import GRPOTrainer, GRPOConfig
from src.reward import RewardConfig

optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=5e-6, weight_decay=0.01)
trainer = GRPOTrainer(
    model=model, tokenizer=tokenizer, ref_params=ref_params, optimizer=optimizer,
    grpo_config=GRPOConfig(group_size=4, max_new_tokens=256),
    reward_config=RewardConfig(), device="cuda")

print("Running 3 GRPO steps...\n")
for i, sample in enumerate(test_samples[:3]):
    print(f"Step {i+1}/3 | Q: '{sample.question}'")
    metrics = trainer.train_step(sample)
    print(f"  Loss={metrics['total_loss']:.4f}  Reward={metrics['mean_reward']:.3f}  "
          f"Correct={metrics['correct_ratio']:.0%}  Valid={metrics['valid_ratio']:.0%}\n")
    if not math.isfinite(metrics['total_loss']):
        raise RuntimeError(f"Non-finite loss at step {i+1}: {metrics['total_loss']}")

print("Pipeline test passed. Ready for Spider.")

---
## Part C: Train on Spider

Download the Spider questions and SQLite databases, measure a baseline, run GRPO training, and compare execution accuracy before and after.

In [None]:
# Download Spider dataset
!pip install -q datasets
from datasets import load_dataset
import json, os, shutil

print("Downloading Spider from HuggingFace...")
ds = load_dataset("xlangai/spider")

os.makedirs("data/spider", exist_ok=True)


def _spider_records(split):
    """Keep only the fields load_spider_dataset reads: question, query and db_id."""
    return [{"question": row["question"], "query": row["query"], "db_id": row["db_id"]} for row in split]


train_records = _spider_records(ds["train"])
with open("data/spider/train_spider.json", "w") as out:
    json.dump(train_records, out)
print(f"Saved {len(train_records)} training samples")

dev_records = _spider_records(ds["validation"])
with open("data/spider/dev.json", "w") as out:
    json.dump(dev_records, out)
print(f"Saved {len(dev_records)} dev samples")

# The HF dataset ships questions/SQL only; the SQLite files come separately.
for note in ("\nNote: You still need the SQLite database files.",
             "Download from https://yale-lily.github.io/spider and upload the database/ folder.",
             "Or run the next cell to try downloading them."):
    print(note)

In [None]:
# Download Spider database files
import subprocess

if not os.path.exists("data/spider/database"):
    print("Downloading Spider databases...")
    !wget -q https://drive.usercontent.google.com/download?id=1403EGqzIDoHMdQF4c9Bkyl7dZLZ5Wt6J -O spider_dbs.zip 2>/dev/null || echo "Direct download failed, trying gdown..."
    
    if os.path.exists("spider_dbs.zip"):
        !unzip -q spider_dbs.zip -d data/spider/
        print("Databases downloaded and extracted")
    else:
        !pip install -q gdown
        !gdown --id 1403EGqzIDoHMdQF4c9Bkyl7dZLZ5Wt6J -O spider_dbs.zip
        if os.path.exists("spider_dbs.zip"):
            !unzip -q spider_dbs.zip -d data/spider/
            print("Databases downloaded and extracted")
        else:
            print("Could not auto-download databases.")
            print("Please download manually from https://yale-lily.github.io/spider")
            print("Upload the 'database' folder to data/spider/database/")
else:
    num_dbs = len(os.listdir("data/spider/database"))
    print(f"Spider databases already present: {num_dbs} databases")

In [None]:
# Load Spider dataset
from src.data_utils import load_spider_dataset

# Start with 100 samples for quick iteration (set to None for the full splits)
MAX_TRAIN = 100
MAX_EVAL = 50

train_data = load_spider_dataset("data/spider/train_spider.json", "data/spider/database", max_samples=MAX_TRAIN)
eval_data = load_spider_dataset("data/spider/dev.json", "data/spider/database", max_samples=MAX_EVAL)

print(f"\nTrain: {len(train_data)} samples")
print(f"Eval:  {len(eval_data)} samples")
# Samples are skipped when their .sqlite file is missing; fail with a clear message
# rather than an IndexError below or silent no-op training/evaluation later.
if not train_data or not eval_data:
    raise RuntimeError(
        "No Spider samples loaded: expected databases at "
        "data/spider/database/<db_id>/<db_id>.sqlite (see the database download cell).")
print(f"\nExample: [{train_data[0].db_id}] {train_data[0].question}")
print(f"  Gold: {train_data[0].gold_sql}")

In [None]:
# Evaluate model BEFORE RL training (baseline)
# NOTE(review): Part B already applied 3 GRPO updates to `model`, so this is not
# the untouched checkpoint (the effect at lr=5e-6 is presumably small).
from src.reward import RewardComputer, RewardConfig
from src.data_utils import build_chat_messages

# Up to 20 dev questions. There may be fewer if dev databases were skipped, so
# accuracy divides by the actual count instead of a hard-coded 20.
eval_subset = eval_data[:20]
model.eval()
reward_computer = RewardComputer(RewardConfig())
correct_before = 0

print("Evaluating baseline (before RL)...\n")
for sample in eval_subset:
    messages = build_chat_messages(sample.question, sample.schema)
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024).to("cuda")
    with torch.no_grad():
        output = model.generate(**inputs, max_new_tokens=256, do_sample=False, pad_token_id=tokenizer.pad_token_id)
    response = tokenizer.decode(output[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
    r = reward_computer.compute_reward(response, sample.gold_sql, sample.db_path)
    correct_before += int(r.execution_correct)
    status = "PASS" if r.execution_correct else "FAIL"
    print(f"  [{status}] [{sample.db_id}] {sample.question[:60]}")

n_baseline = len(eval_subset)
baseline_acc = correct_before / n_baseline if n_baseline else 0.0
print(f"\nBaseline accuracy: {correct_before}/{n_baseline} = {baseline_acc:.0%}")

In [None]:
# GRPO Training on Spider
import random
from src.grpo_trainer import GRPOTrainer, GRPOConfig

# Reset optimizer for fresh training.
# NOTE(review): the LoRA weights keep Part B's 3 updates; only the optimizer state is fresh.
optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=5e-6, weight_decay=0.01)

trainer = GRPOTrainer(
    model=model, tokenizer=tokenizer, ref_params=ref_params, optimizer=optimizer,
    grpo_config=GRPOConfig(group_size=4, clip_epsilon=0.2, kl_coeff=0.05, temperature=0.7, max_new_tokens=256),
    reward_config=RewardConfig(), device="cuda")

NUM_EPOCHS = 1
all_metrics = []

for epoch in range(NUM_EPOCHS):
    print(f"\n{'='*60}")
    # The dash was previously mis-encoded (mojibake "â€”").
    print(f"EPOCH {epoch+1}/{NUM_EPOCHS} — Training on {len(train_data)} Spider samples")
    print(f"{'='*60}")
    random.shuffle(train_data)
    epoch_start = len(all_metrics)

    for step, sample in enumerate(train_data):
        print(f"\nStep {step+1}/{len(train_data)} | [{sample.db_id}] {sample.question[:50]}...")
        metrics = trainer.train_step(sample)
        all_metrics.append(metrics)
        print(f"  Loss={metrics['total_loss']:.4f}  Reward={metrics['mean_reward']:.3f}  "
              f"Correct={metrics['correct_ratio']:.0%}  Valid={metrics['valid_ratio']:.0%}")

    # Slice by start index: `all_metrics[-0:]` would be the whole list, and an
    # empty epoch used to divide by zero.
    epoch_metrics = all_metrics[epoch_start:]
    if epoch_metrics:
        avg_reward = sum(m['mean_reward'] for m in epoch_metrics) / len(epoch_metrics)
        avg_correct = sum(m['correct_ratio'] for m in epoch_metrics) / len(epoch_metrics)
        print(f"\nEpoch {epoch+1} summary: avg_reward={avg_reward:.3f}, avg_correct={avg_correct:.1%}")
    else:
        print(f"\nEpoch {epoch+1} summary: no training samples")

print("\n" + "="*60)
print("TRAINING COMPLETE")
print("="*60)

In [None]:
# Plot training curves
import matplotlib.pyplot as plt

# (metric key, panel title, line format) for each subplot, left to right.
panels = [
    ('total_loss', 'Total Loss', 'b-'),
    ('mean_reward', 'Mean Reward', 'g-'),
    ('correct_ratio', 'Execution Correct Ratio', 'r-'),
]

fig, axes = plt.subplots(1, 3, figsize=(18, 5))
steps = range(len(all_metrics))

for ax, (key, title, fmt) in zip(axes, panels):
    ax.plot(steps, [m[key] for m in all_metrics], fmt, alpha=0.7)
    ax.set_title(title)
    ax.set_xlabel('Step')
    if key == 'correct_ratio':
        # Ratios live in [0, 1]; pad slightly so the extremes stay visible.
        ax.set_ylim(-0.05, 1.05)
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('outputs/spider_training_curves.png', dpi=150)
plt.show()

In [None]:
# Evaluate AFTER RL training
# Same questions as the baseline cell (first 20 dev samples, or fewer if fewer loaded);
# accuracy divides by the actual count instead of a hard-coded 20.
eval_subset = eval_data[:20]
model.eval()
correct_after = 0

print("Evaluating after GRPO training...\n")
for sample in eval_subset:
    messages = build_chat_messages(sample.question, sample.schema)
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024).to("cuda")
    with torch.no_grad():
        output = model.generate(**inputs, max_new_tokens=256, do_sample=False, pad_token_id=tokenizer.pad_token_id)
    response = tokenizer.decode(output[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
    r = reward_computer.compute_reward(response, sample.gold_sql, sample.db_path)
    correct_after += int(r.execution_correct)
    status = "PASS" if r.execution_correct else "FAIL"
    pred_sql = reward_computer.extract_sql_from_response(response)
    print(f"  [{status}] [{sample.db_id}] {sample.question[:50]}")
    print(f"      Gold: {sample.gold_sql}")
    print(f"      Pred: {pred_sql}")

n_post = len(eval_subset)
post_acc = correct_after / n_post if n_post else 0.0
print(f"\n{'='*50}")
print(f"Before RL: {baseline_acc:.0%}")
print(f"After RL:  {post_acc:.0%}")
print(f"Improvement: {post_acc - baseline_acc:+.0%}")
print(f"{'='*50}")

In [None]:
# Save model
# LoRA adapter (via PEFT's save_pretrained) and tokenizer go to the same directory.
SAVE_DIR = "outputs/grpo_spider_model"
for artifact in (model, tokenizer):
    artifact.save_pretrained(SAVE_DIR)
print(f"Model saved to {SAVE_DIR}/")