# SQL-R1 Extension: RL Training with Enhanced Rewards

This notebook demonstrates the RL extension for Text-to-SQL with enhanced reward components.

**Requirements:**
- Google Colab with a GPU (T4, V100, or A100)
- ~24 GB of GPU memory recommended (a T4 has 16 GB, and V100s come in 16 GB and 32 GB variants)
- Runtime: GPU (Runtime > Change runtime type > GPU)

**Features:**
- Schema-aware rewards (detects hallucinations)
- Structural rewards (SELECT, WHERE, JOIN matching)
- Enhanced syntax rewards (AST validation)
- Optimized for a 24 GB GPU (3B model, gradient checkpointing)

## 1. Setup and Installation

In [None]:
# Check GPU availability
!nvidia-smi

In [None]:
# Clone the repository (replace with your repo URL)
!git clone https://github.com/YOUR_USERNAME/sql-r1-extension.git
%cd sql-r1-extension

In [None]:
# Install dependencies
!pip install -q torch transformers datasets sqlparse pyyaml wandb accelerate bitsandbytes

## 2. Test Enhanced Reward Components

In [None]:
# Test Schema-Aware Reward
from extensions.reward_enhanced import SchemaAwareReward

schema = {
    'tables': {
        'employees': {'columns': ['id', 'name', 'salary', 'department']},
        'departments': {'columns': ['id', 'name', 'budget']}
    }
}

reward = SchemaAwareReward(weight=-0.5)

# Test valid query
valid_sql = "SELECT name, salary FROM employees WHERE department = 'engineering'"
print(f"Valid SQL reward: {reward.compute(valid_sql, schema)}")

# Test hallucinated table
invalid_sql = "SELECT name FROM customers"  # 'customers' doesn't exist
print(f"Hallucinated table reward: {reward.compute(invalid_sql, schema)}")

# Test hallucinated column
invalid_sql2 = "SELECT name, age FROM employees"  # 'age' doesn't exist
print(f"Hallucinated column reward: {reward.compute(invalid_sql2, schema)}")
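To make hallucination detection concrete, here is a minimal, regex-based sketch of what a schema check can look like. It is not the `SchemaAwareReward` implementation: `find_hallucinations`, `schema_penalty`, and the keyword list are hypothetical, and the check ignores table aliases, qualified names such as `e.name`, and subqueries.

```python
import re
from typing import Dict, List

# Hypothetical, deliberately small keyword list: enough for the queries above
SQL_KEYWORDS = {
    "select", "from", "where", "join", "inner", "left", "right", "outer", "on",
    "and", "or", "not", "as", "distinct", "group", "order", "by", "having",
    "limit", "asc", "desc", "in", "is", "null", "like", "between",
    "count", "sum", "avg", "min", "max",
}


def find_hallucinations(sql: str, schema: Dict) -> List[str]:
    """Return identifiers in `sql` that are not known tables or columns."""
    known_tables = {t.lower() for t in schema["tables"]}
    known_columns = {
        c.lower() for info in schema["tables"].values() for c in info["columns"]
    }
    # Remove string literals so values like 'engineering' are not treated as names
    without_literals = re.sub(r"'[^']*'", " ", sql)
    unknown = []
    prev = ""
    for token in re.findall(r"[A-Za-z_][A-Za-z0-9_]*", without_literals):
        low = token.lower()
        if low in SQL_KEYWORDS:
            pass  # keywords and aggregate names are never hallucinations
        elif prev in ("from", "join"):
            if low not in known_tables:  # name in table position must be a table
                unknown.append(token)
        elif low not in known_columns and low not in known_tables:
            unknown.append(token)
        prev = low
    return unknown


def schema_penalty(sql: str, schema: Dict, weight: float = -0.5) -> float:
    """Apply `weight` (a negative penalty) if any hallucinated name is found."""
    return weight if find_hallucinations(sql, schema) else 0.0
```

With the `schema` defined above, `find_hallucinations("SELECT name, age FROM employees", schema)` returns `['age']`. A flat penalty is one design choice; scaling it by the number of unknown names is another.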

In [None]:
# Test Structural Reward
from extensions.reward_enhanced import StructuralReward

structural = StructuralReward(
    select_weight=0.3,
    where_weight=0.3,
    join_weight=0.2
)

generated = "SELECT name, salary FROM employees WHERE department = 'engineering'"
ground_truth = "SELECT name, salary FROM employees WHERE department = 'engineering'"

print(f"Exact match reward: {structural.compute(generated, ground_truth)}")

# Partial match (SELECT list matches, WHERE clause differs)
generated2 = "SELECT name, salary FROM employees WHERE id = 1"
print(f"Partial match reward: {structural.compute(generated2, ground_truth)}")
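Clause-level matching can be sketched with regular expressions that pull out the SELECT list, the WHERE condition, and the JOIN section, then compare them after normalizing case and whitespace. This is a hypothetical stand-in for `StructuralReward`: the function names are illustrative, clauses absent from the ground truth are assumed to earn nothing, and a regex is no substitute for a real SQL parser on nested queries.

```python
import re
from typing import Dict, Optional

# Each pattern captures the body of one clause (simple, non-nested queries only)
CLAUSE_PATTERNS = {
    "select": r"\bselect\b(.*?)\bfrom\b",
    "where": r"\bwhere\b(.*?)(?:\bgroup\s+by\b|\border\s+by\b|\blimit\b|$)",
    "join": r"\bjoin\b(.*?)(?:\bwhere\b|\bgroup\s+by\b|\border\s+by\b|\blimit\b|$)",
}


def extract_clause(sql: str, clause: str) -> Optional[str]:
    """Return the normalized body of `clause`, or None if the query lacks it."""
    match = re.search(CLAUSE_PATTERNS[clause], sql, flags=re.IGNORECASE | re.DOTALL)
    if match is None:
        return None
    return " ".join(match.group(1).lower().split())


def structural_score(
    generated: str, ground_truth: str, weights: Optional[Dict[str, float]] = None
) -> float:
    """Sum the weights of clauses whose normalized text matches the ground truth."""
    weights = weights or {"select": 0.3, "where": 0.3, "join": 0.2}
    score = 0.0
    for clause, weight in weights.items():
        expected = extract_clause(ground_truth, clause)
        if expected is None:
            continue  # nothing to match against, so no credit either way
        if extract_clause(generated, clause) == expected:
            score += weight
    return round(score, 3)
```

Because the ground truth above has no JOIN, an exact match scores 0.6 under this sketch's convention rather than the full 0.8, and the partial match scores 0.3.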

In [None]:
# Test Enhanced Syntax Reward
from extensions.reward_enhanced import EnhancedSyntaxReward

syntax = EnhancedSyntaxReward(weight=0.2)

valid_sql = "SELECT * FROM employees"
print(f"Valid SQL reward: {syntax.compute(valid_sql)}")

invalid_sql = "SELECT FROM WHERE"  # Invalid syntax
print(f"Invalid SQL reward: {syntax.compute(invalid_sql)}")
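One inexpensive way to approximate a syntax check without writing a grammar is to let SQLite's own parser prepare the statement. The sketch below swaps in that technique; it is not necessarily what `EnhancedSyntaxReward` does, it only accepts the SQLite dialect, and `sqlite_syntax_ok` and `syntax_reward` are hypothetical names. Errors about missing tables or columns are deliberately ignored, since those belong to the schema-aware reward.

```python
import sqlite3


def sqlite_syntax_ok(sql: str) -> bool:
    """True if SQLite can parse `sql`; unknown tables/columns are not syntax errors."""
    conn = sqlite3.connect(":memory:")  # empty database: only the parser matters
    try:
        conn.execute("EXPLAIN " + sql)  # prepares the statement; no data is touched
        return True
    except sqlite3.OperationalError as err:
        message = str(err).lower()
        # "no such table" / "no such column" mean the statement parsed fine
        return "syntax error" not in message and "incomplete input" not in message
    except (sqlite3.Error, sqlite3.Warning):
        return False  # e.g. several statements packed into one string
    finally:
        conn.close()


def syntax_reward(sql: str, weight: float = 0.2) -> float:
    """Return `weight` for parseable SQL and 0.0 otherwise."""
    return weight if sqlite_syntax_ok(sql) else 0.0
```

On the two queries in this cell, `syntax_reward` gives 0.2 for `SELECT * FROM employees` and 0.0 for `SELECT FROM WHERE`.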

In [None]:
# Test Enhanced Reward Computer (integrates all components)
from extensions.reward_enhanced import EnhancedRewardComputer

def baseline_reward(solution_str, ground_truth, **kwargs):
    """Simple baseline reward for testing."""
    return 1.0 if '```sql' in solution_str else 0.0

computer = EnhancedRewardComputer(
    baseline_reward_fn=baseline_reward,
    enable_enhanced=True,
    schema_weight=-0.5,
    structural_select_weight=0.3,
    structural_where_weight=0.3,
    structural_join_weight=0.2,
    syntax_weight=0.2
)

solution = """<think>Need to query employees table</think>
<answer>
```sql
SELECT name, salary FROM employees WHERE department = 'engineering'
```
</answer>"""

ground_truth = {
    'sql': "SELECT name, salary FROM employees WHERE department = 'engineering'"
}

result = computer.compute_reward(solution, ground_truth, schema=schema)
print(f"\nTotal reward: {result.total:.3f}")
print("Breakdown:")
print(f"  Baseline: {result.baseline_total:.3f}")
print(f"  Schema: {result.schema:.3f}")
print(f"  Structural: {result.structural:.3f}")
print(f"  Syntax: {result.syntax:.3f}")
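Two steps happen before the components can score a response: the SQL has to be pulled out of the `<answer>` block, and the component scores have to be combined into one number. A minimal sketch of both follows, assuming the last fenced `sql` block holds the final answer and that the total is a plain sum of already-weighted components; `extract_sql` and `RewardBreakdown` are hypothetical and may differ from what `EnhancedRewardComputer` does internally.

```python
import re
from dataclasses import dataclass
from typing import Optional

# Non-greedy match of each fenced sql block (three backticks, then "sql")
SQL_FENCE = re.compile(r"`{3}sql\s*(.*?)`{3}", re.IGNORECASE | re.DOTALL)


def extract_sql(solution_str: str) -> Optional[str]:
    """Return the last fenced SQL block, so a draft inside <think> does not win."""
    blocks = SQL_FENCE.findall(solution_str)
    return blocks[-1].strip() if blocks else None


@dataclass
class RewardBreakdown:
    baseline_total: float
    schema: float
    structural: float
    syntax: float

    @property
    def total(self) -> float:
        # Assumption: an unweighted sum; per-component weights are applied upstream
        return self.baseline_total + self.schema + self.structural + self.syntax
```

Applied to the `solution` string above, `extract_sql` returns `SELECT name, salary FROM employees WHERE department = 'engineering'`.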

## 3. Run Training Demo

In [None]:
# Run simplified training demo
# This demonstrates the integration but doesn't do full RL training
# For full training, integrate with SQL-R1's VERL framework

!python train_colab.py \
    --model "Qwen/Qwen2.5-Coder-3B-Instruct" \
    --config "configs/train_24gb.yaml" \
    --output-dir "./outputs" \
    --num-steps 10 \
    --batch-size 2 \
    --demo

## 4. Monitor GPU Memory Usage

In [None]:
import torch

if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Total memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    print(f"Allocated: {torch.cuda.memory_allocated(0) / 1e9:.2f} GB")
    print(f"Reserved (cached): {torch.cuda.memory_reserved(0) / 1e9:.2f} GB")
else:
    print("No GPU available")

## 5. Test with Custom Data

In [None]:
# Create custom test case
from train_colab import SimplifiedTrainer

trainer = SimplifiedTrainer(
    model_name="Qwen/Qwen2.5-Coder-3B-Instruct",
    config_path="configs/train_24gb.yaml",
    output_dir="./outputs"
)

# Test generation
prompt = """Convert the following question to SQL:
Question: Show all employees with salary greater than 50000
Schema: employees(id, name, salary, department)

SQL:"""

response = trainer.generate_response(prompt, max_length=256)
print(f"Generated response:\n{response}")

# Compute reward
ground_truth = {
    'sql': 'SELECT * FROM employees WHERE salary > 50000'
}

schema = {
    'tables': {
        'employees': {'columns': ['id', 'name', 'salary', 'department']}
    }
}

reward = trainer.compute_reward(response, ground_truth, schema)
print("\nReward breakdown:")
print(f"  Total: {reward.total:.3f}")
print(f"  Baseline: {reward.baseline_total:.3f}")
print(f"  Schema: {reward.schema:.3f}")
print(f"  Structural: {reward.structural:.3f}")
print(f"  Syntax: {reward.syntax:.3f}")

## 6. Integration with SQL-R1 (Production)

For production use, integrate with SQL-R1's full training pipeline:

```python
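# In SQL-R1's training script (verl/trainer/main_ppo.py)
from extensions.reward_enhanced import EnhancedRewardComputer
from extensions.config import load_config

# Load configuration
config = load_config('configs/train_24gb.yaml')

# Wrap the existing reward function. `original_reward_function` is a placeholder
# for SQL-R1's baseline reward, presumably with the same signature as
# `baseline_reward` in Section 2: (solution_str, ground_truth, **kwargs).
enhanced_reward = EnhancedRewardComputer(
    baseline_reward_fn=original_reward_function,
    enable_enhanced=config.reward.enable_enhanced,
    schema_weight=config.reward.schema_weight,
    structural_select_weight=config.reward.structural_select_weight,
    structural_where_weight=config.reward.structural_where_weight,
    structural_join_weight=config.reward.structural_join_weight,
    syntax_weight=config.reward.syntax_weight
)

# Use in the training loop, once per sampled response. `model_output` is the
# decoded response; `batch['ground_truth']` and `batch['schema']` stand for the
# ground truth and schema belonging to that same sample.
reward_result = enhanced_reward.compute_reward(
    solution_str=model_output,
    ground_truth=batch['ground_truth'],
    schema=batch['schema']
)
reward = reward_result.total  # scalar reward for the RL update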
```
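RL trainers consume one scalar per sampled response, while the per-component breakdown is most useful for logging. A minimal, framework-agnostic sketch of that split follows, assuming the reward function returns an object with the `total`, `baseline_total`, `schema`, `structural`, and `syntax` attributes used earlier in this notebook; `score_batch` and the `reward/...` metric keys are hypothetical and not part of VERL.

```python
from typing import Any, Callable, Dict, List, Sequence, Tuple

COMPONENTS = ("baseline_total", "schema", "structural", "syntax")


def score_batch(
    reward_fn: Callable[..., Any],
    outputs: Sequence[str],
    ground_truths: Sequence[Any],
    schemas: Sequence[Dict],
) -> Tuple[List[float], Dict[str, float]]:
    """Return per-sample scalar rewards plus batch means of each component."""
    if not (len(outputs) == len(ground_truths) == len(schemas)):
        raise ValueError("outputs, ground_truths and schemas must align")
    scores: List[float] = []
    sums = {name: 0.0 for name in COMPONENTS}
    for output, truth, schema in zip(outputs, ground_truths, schemas):
        result = reward_fn(output, truth, schema=schema)
        scores.append(float(result.total))  # the scalar the policy update sees
        for name in COMPONENTS:
            sums[name] += float(getattr(result, name))
    n = max(len(scores), 1)  # avoid dividing by zero on an empty batch
    metrics = {f"reward/{name}": value / n for name, value in sums.items()}
    return scores, metrics
```

In the setting above, a call could look like `scores, metrics = score_batch(enhanced_reward.compute_reward, outputs, ground_truths, schemas)`, with `scores` fed to the optimizer and `metrics` sent to a logger such as wandb.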

## 7. Results and Analysis

The enhanced reward system provides:

1. **Schema-Aware Rewards**: Penalizes hallucinated tables/columns
2. **Structural Rewards**: Rewards partial correctness (SELECT, WHERE, JOIN)
3. **Enhanced Syntax Rewards**: Validates SQL syntax at AST level

Expected improvements:
- Reduced hallucinations (fewer non-existent tables/columns)
- Better structural correctness (even when exact match fails)
- Improved syntax validity

For a full evaluation, run the complete RL training pipeline and evaluate on the Spider or WikiSQL benchmarks.
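Spider-style evaluations usually compare query results rather than SQL text. The sketch below shows a simplified execution-match check on a throwaway SQLite database. It is not the official Spider evaluation script, `execution_match` is a hypothetical name, the toy rows are made up, and the order-insensitive comparison would need adjusting for queries with `ORDER BY`.

```python
import sqlite3
from collections import Counter


def execution_match(pred_sql: str, gold_sql: str, conn: sqlite3.Connection) -> bool:
    """True if both queries return the same multiset of rows on `conn`."""
    try:
        pred_rows = conn.execute(pred_sql).fetchall()
    except (sqlite3.Error, sqlite3.Warning):
        return False  # a query that fails to run counts as incorrect
    gold_rows = conn.execute(gold_sql).fetchall()
    # Counter ignores row order but keeps duplicate rows
    return Counter(pred_rows) == Counter(gold_rows)


# Toy database with made-up rows, matching the schema used in this notebook
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE employees (id INTEGER, name TEXT, salary INTEGER, department TEXT)")
conn.executemany(
    "INSERT INTO employees VALUES (?, ?, ?, ?)",
    [(1, "Ada", 90000, "engineering"), (2, "Bob", 40000, "sales")],
)

gold = "SELECT * FROM employees WHERE salary > 50000"
pred = "SELECT id, name, salary, department FROM employees WHERE salary > 50000"
print(execution_match(pred, gold, conn))  # True
print(execution_match("SELECT name FROM customers", gold, conn))  # False (no such table)
```

A hallucinated table makes the prediction fail to execute, so this metric also picks up the schema errors the enhanced reward targets.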