# RL Text-to-SQL Training on Google Colab

This notebook demonstrates how to train a Text-to-SQL model with reinforcement learning (RL) on a single GPU with at least 24GB of memory (e.g., a Colab A100).

**Runtime:** GPU (A100 recommended)

**Estimated time:** 2-4 hours for 3 epochs on Spider train set

## 1. Setup

In [None]:
# Check GPU: show the attached device(s) before installing anything.
# If this command fails, presumably no GPU runtime is selected — verify the runtime type.
!nvidia-smi

In [None]:
# Clone repository
!git clone https://github.com/GameChaser782/rl-text2sql.git
%cd rl-text2sql

In [None]:
# Install dependencies
!pip install -q -U pip
!pip install -q transformers>=4.35.0 accelerate>=0.25.0 peft>=0.7.0 bitsandbytes>=0.41.0
!pip install -q trl>=0.7.0 datasets>=2.14.0 pyyaml timeout-decorator

## 2. Create Synthetic Spider-Format Test Data

In [None]:
import json
import os
import sqlite3

# Create minimal test data: one SQLite database plus Spider-style train/dev JSON files.
# The cell is idempotent - the table is rebuilt on every run, so re-executing it neither
# fails with "table students already exists" nor accumulates duplicate rows.
os.makedirs('data/spider/database/test_db', exist_ok=True)

# Create a simple test database
conn = sqlite3.connect('data/spider/database/test_db/test_db.sqlite')
try:
    cursor = conn.cursor()
    cursor.execute('DROP TABLE IF EXISTS students')
    cursor.execute('CREATE TABLE students (id INTEGER, name TEXT, age INTEGER)')
    # Bind the values as parameters: a double-quoted "Alice" is an identifier in
    # standard SQL and is only accepted as a string by SQLite's legacy quirk mode.
    cursor.executemany(
        'INSERT INTO students VALUES (?, ?, ?)',
        [(1, 'Alice', 20), (2, 'Bob', 22)],
    )
    conn.commit()
finally:
    # Release the database file even if one of the statements above fails.
    conn.close()

# Create test JSON data (same keys the later cells read: question, query, db_id)
test_data = [
    {
        "question": "What are the names of all students?",
        "query": "SELECT name FROM students",
        "db_id": "test_db"
    },
    {
        "question": "How many students are there?",
        "query": "SELECT COUNT(*) FROM students",
        "db_id": "test_db"
    }
]

with open('data/spider/train_spider.json', 'w') as f:
    json.dump(test_data, f)

# The dev split reuses only the first training example.
with open('data/spider/dev.json', 'w') as f:
    json.dump(test_data[:1], f)

print("✅ Created synthetic test data")
# List the dataset directory to confirm the JSON files and the database folder exist
!ls -la data/spider/

## 3. Quick Test: Reward Function

In [None]:
from reward import SQLRewardCalculator, RewardConfig

# Reward setup: weight 1.0 on execution, 0.3 on partial credit, partial rewards enabled
config = RewardConfig(execution_weight=1.0, partial_weight=0.3, use_partial_rewards=True)
reward_calc = SQLRewardCalculator(db_path="dummy.db", config=config)

# The prediction matches the gold query except for the missing ORDER BY clause
gold_sql = "SELECT name, age FROM users WHERE age > 18 ORDER BY age"
pred_sql = "SELECT name, age FROM users WHERE age > 18"

# Partial reward for the (prediction, gold) pair
partial = reward_calc.partial_rewards(pred_sql, gold_sql)
print(f"Partial reward: {partial:.3f}")

# Show what the calculator extracts from the predicted query
components = reward_calc._extract_sql_components(pred_sql)
print(f"\nSQL Components:")
for comp_type, comp_set in components.items():
    print(f"  {comp_type}: {comp_set}")

## 4. Load Model and Test Generation

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# Model name - choose smaller model for faster testing
model_name = "Qwen/Qwen2.5-Coder-3B-Instruct"  # 3B parameters
# Alternative: "codellama/CodeLlama-7b-hf"

print(f"Loading {model_name}...")

# QLoRA config: load the base weights 4-bit quantized (NF4 with double quantization)
# and use bfloat16 as the compute dtype.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

# Load tokenizer; fall back to the EOS token for padding when no pad token is defined
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Load model with the quantization config above.
# NOTE(review): trust_remote_code=True permits code shipped in the model repository
# to run locally; the tokenizer is loaded without it — confirm it is needed here.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True
)

# Prepare the quantized model for training
model = prepare_model_for_kbit_training(model)

# LoRA config: rank 16, alpha 32, dropout 0.05, applied to the q/k/v/o projection modules
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

# Add LoRA adapters on top of the quantized base model
model = get_peft_model(model, lora_config)
# NOTE(review): prepare_model_for_kbit_training may already enable gradient
# checkpointing by default, which would make this call redundant — confirm.
model.gradient_checkpointing_enable()

# Report how many parameters are trainable after wrapping with LoRA
model.print_trainable_parameters()

print("\nModel loaded successfully!")

In [None]:
# Test generation
def generate_sql(question, model, tokenizer):
    """Sample a completion for ``question`` and return the decoded text.

    Args:
        question: Natural-language question placed into the prompt.
        model: Object exposing ``device`` and ``generate(**inputs, ...)``.
        tokenizer: Callable that encodes the prompt and offers ``decode``.

    Returns:
        The decoding of the first returned sequence with special tokens skipped
        (``outputs[0]`` as produced by ``model.generate``).
    """
    encoded = tokenizer(f"Question: {question}\nSQL:", return_tensors="pt")
    encoded = encoded.to(model.device)

    # Sampling settings: up to 128 new tokens at temperature 0.7
    sampling_kwargs = dict(max_new_tokens=128, temperature=0.7, do_sample=True)
    generated = model.generate(**encoded, **sampling_kwargs)

    return tokenizer.decode(generated[0], skip_special_tokens=True)

# Test: run one question through the LoRA-wrapped model loaded above (before any RL
# training in this notebook) and print the question next to the decoded response.
question = "What are the names of all students?"
response = generate_sql(question, model, tokenizer)
print(f"Question: {question}")
print(f"Response: {response}")

## 5. Prepare Training Data (Small Subset)

In [None]:
import json

# Read the training split from disk
with open('data/spider/train_spider.json', 'r') as f:
    spider_data = json.loads(f.read())

print(f"Total training examples: {len(spider_data)}")

# Keep at most `subset_size` leading examples (slicing past the end simply
# returns everything that is available)
subset_size = 100  # Adjust based on time constraints
train_subset = spider_data[:subset_size]

# Persist the subset for the training script
with open('data/spider/train_subset.json', 'w') as f:
    f.write(json.dumps(train_subset))

print(f"Using {len(train_subset)} examples for training")
print(f"\nExample:")
first_example = train_subset[0]
print(json.dumps(first_example, indent=2))

## 6. Training Configuration

In [None]:
# Create config for quick training
config = {
    'model_name': model_name,
    'use_qlora': True,
    'train_data': 'data/spider/train_subset.json',
    'db_root': 'data/spider/database',
    'num_samples': 4,
    'temperature': 0.7,
    'num_epochs': 2,  # Reduced for quick testing
    'batch_size': 1,
    'gradient_accumulation_steps': 4,  # Reduced for speed
    'learning_rate': 1e-5,
    'kl_coef': 0.1,
    'execution_weight': 1.0,
    'partial_weight': 0.3,
    # Must be the directory the evaluation, interactive and export cells read from
    # ('outputs/rl-text2sql'); any other value leaves them pointing at a missing path.
    'output_dir': 'outputs/rl-text2sql',
    'seed': 42
}

# Save config so the training script can be launched with --config
import yaml
with open('config_colab.yaml', 'w') as f:
    yaml.dump(config, f)

# Echo the configuration in block style for readability
print("Configuration:")
print(yaml.dump(config, default_flow_style=False))

In [None]:
import os
# Expose only GPU 0 through the environment.
# NOTE(review): this kernel already loaded a model onto the GPU in an earlier cell, so
# the variable presumably only affects processes started afterwards (such as the
# `!python train_rl.py` call below) rather than this kernel — confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Verify only one GPU visible
# NOTE(review): these values are reported by the current kernel, not by the training
# subprocess; get_device_name(0) presumably raises when no GPU is available.
import torch
print(f"GPUs available: {torch.cuda.device_count()}")
print(f"Using: {torch.cuda.get_device_name(0)}")

In [None]:
# Check what's in the config: print the YAML file exactly as the training script will read it
!cat config_colab.yaml

In [None]:
# Patch train_rl.py so that --model_name no longer defaults to CodeLlama
# (presumably so the model named in the config file is used instead - TODO confirm).
# The outcome is reported honestly: patched, already patched, or pattern not found,
# instead of claiming success when nothing was replaced.
SCRIPT_PATH = 'train_rl.py'
OLD_ARG = "parser.add_argument('--model_name', type=str, default='codellama/CodeLlama-7b-hf',"
NEW_ARG = "parser.add_argument('--model_name', type=str, default=None,"

# Read train_rl.py
with open(SCRIPT_PATH, 'r') as f:
    content = f.read()

if OLD_ARG in content:
    # Remove the default model and write the script back
    content = content.replace(OLD_ARG, NEW_ARG)
    with open(SCRIPT_PATH, 'w') as f:
        f.write(content)
    print("✅ Fixed train_rl.py - removed default model")
elif NEW_ARG in content:
    # Re-running the cell after a successful patch is a no-op.
    print("train_rl.py is already patched - nothing to do")
else:
    # The upstream script changed; leave it untouched and tell the user.
    print("⚠️ Expected --model_name default not found in train_rl.py - file left unchanged")

## 7. Run Training

In [None]:
# Run with both config AND required args (bug in script)
# NOTE(review): --train_data and --db_root repeat values that are also written to
# config_colab.yaml; presumably the script insists on them on the command line even
# when --config is given — confirm against train_rl.py.
!python train_rl.py \
    --config config_colab.yaml \
    --train_data data/spider/train_subset.json \
    --db_root data/spider/database

In [None]:
# In another cell, monitor GPU usage
!watch -n 1 nvidia-smi

## 8. Evaluation

In [None]:
# Create a new working evaluate script.
# The script text is kept in a string and written to evaluate_fixed.py so it can be
# launched with `!python` in a fresh process.
evaluate_code = '''
import torch
import argparse
import json
from pathlib import Path
from typing import List, Dict
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

from reward import SQLRewardCalculator, RewardConfig

class Text2SQLEvaluator:
    """Generates SQL for each test question and scores it against the gold query."""

    def __init__(self, model, tokenizer, reward_calculator, device="cuda"):
        self.model = model.to(device)
        self.model.eval()
        self.tokenizer = tokenizer
        self.reward_calculator = reward_calculator
        self.device = device

    def create_prompt(self, question: str, schema: str = None) -> str:
        """Build the generation prompt, including the schema when one is provided."""
        if schema:
            prompt = f"""Given the database schema:
{schema}

Question: {question}

Generate a SQL query to answer this question:
SQL:"""
        else:
            prompt = f"Question: {question}\\nSQL:"
        return prompt

    @torch.no_grad()
    def generate_sql(self, question: str, schema: str = None) -> str:
        """Greedily decode a SQL query for the question.

        Only the newly generated tokens are decoded. Decoding the full sequence
        would include the prompt, and a question or schema containing the word
        "select" would then be mistaken for the start of the query.
        """
        prompt = self.create_prompt(question, schema)
        inputs = self.tokenizer(prompt, return_tensors="pt", padding=True, truncation=True).to(self.device)
        # Greedy decoding: no temperature is passed because it has no effect when do_sample=False.
        outputs = self.model.generate(**inputs, max_new_tokens=256, do_sample=False, pad_token_id=self.tokenizer.pad_token_id)
        # The returned sequence starts with the prompt tokens; skip them.
        prompt_len = inputs["input_ids"].shape[1]
        generated_text = self.tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
        return self.extract_sql(generated_text)

    def extract_sql(self, text: str) -> str:
        """Return the text from the first SELECT up to the first semicolon (exclusive).

        When no SELECT is present the stripped text is returned unchanged.
        """
        text = text.strip()
        select_idx = text.upper().find('SELECT')
        if select_idx != -1:
            sql = text[select_idx:]
            semicolon_idx = sql.find(';')
            if semicolon_idx != -1:
                sql = sql[:semicolon_idx]
            return sql.strip()
        return text

    def evaluate(self, test_data: List[Dict], db_root: str, output_file: str = None) -> Dict[str, float]:
        """Score every example and optionally save metrics plus per-example predictions.

        Each example needs the keys question, query and db_id; schema is optional.
        Returns a dict with execution_accuracy, exact_match and total.
        """
        total = len(test_data)
        execution_correct = 0
        exact_match = 0
        predictions = []

        for example in tqdm(test_data, desc="Evaluating"):
            question = example['question']
            gold_sql = example['query']
            db_id = example['db_id']
            db_path = f"{db_root}/{db_id}/{db_id}.sqlite"
            schema = example.get('schema')

            pred_sql = self.generate_sql(question, schema)

            reward_dict = self.reward_calculator.compute_reward(pred_sql, gold_sql, question, db_path)

            # Compute each verdict once and reuse it for the counters and the record.
            is_execution_correct = reward_dict['execution'] == 1.0
            is_exact_match = pred_sql.strip().upper() == gold_sql.strip().upper()

            if is_execution_correct:
                execution_correct += 1

            if is_exact_match:
                exact_match += 1

            predictions.append({
                'question': question,
                'gold_sql': gold_sql,
                'pred_sql': pred_sql,
                'db_id': db_id,
                'execution_correct': is_execution_correct,
                'exact_match': is_exact_match
            })

        metrics = {
            'execution_accuracy': execution_correct / total if total > 0 else 0,
            'exact_match': exact_match / total if total > 0 else 0,
            'total': total
        }

        if output_file:
            # Create the target directory so saving does not fail when it is missing.
            Path(output_file).parent.mkdir(parents=True, exist_ok=True)
            with open(output_file, 'w') as f:
                json.dump({'metrics': metrics, 'predictions': predictions}, f, indent=2)

        return metrics

def main(args):
    """Load the test data, base model and LoRA adapters, then run the evaluation."""
    print("="*80)
    print("Text-to-SQL Evaluation")
    print("="*80)

    print(f"\\nLoading test data from {args.test_data}...")
    with open(args.test_data, 'r') as f:
        test_data = json.load(f)
    print(f"Test examples: {len(test_data)}")

    print(f"\\nLoading model from {args.model_path}...")

    # Load tokenizer from base model
    tokenizer = AutoTokenizer.from_pretrained(args.base_model)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load base model
    base = AutoModelForCausalLM.from_pretrained(args.base_model, device_map={"": 0}, torch_dtype=torch.bfloat16)

    # Load LoRA adapters
    model = PeftModel.from_pretrained(base, args.model_path)

    reward_config = RewardConfig(execution_weight=1.0, partial_weight=0.0, timeout_seconds=5, use_partial_rewards=False)
    reward_calculator = SQLRewardCalculator(db_path="", config=reward_config)

    evaluator = Text2SQLEvaluator(model=model, tokenizer=tokenizer, reward_calculator=reward_calculator, device="cuda" if torch.cuda.is_available() else "cpu")

    print("\\nEvaluating...")
    metrics = evaluator.evaluate(test_data=test_data, db_root=args.db_root, output_file=args.output_file)

    print("\\n" + "="*80)
    print("RESULTS")
    print("="*80)
    print(f"Execution Accuracy: {metrics['execution_accuracy']:.2%}")
    print(f"Exact Match:        {metrics['exact_match']:.2%}")
    print(f"Total Examples:     {metrics['total']}")
    print("="*80)

    if args.output_file:
        print(f"\\nPredictions saved to {args.output_file}")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Evaluate Text-to-SQL model")
    parser.add_argument('--model_path', type=str, required=True)
    parser.add_argument('--base_model', type=str, required=True)
    parser.add_argument('--test_data', type=str, required=True)
    parser.add_argument('--db_root', type=str, required=True)
    parser.add_argument('--output_file', type=str)
    args = parser.parse_args()
    main(args)
'''

with open('evaluate_fixed.py', 'w') as f:
    f.write(evaluate_code)

print("✅ Created evaluate_fixed.py")

In [None]:
# Inspect the directory that the evaluation and interactive cells below load the trained model from
!ls -la outputs/rl-text2sql/

In [None]:
# Create the directory that will receive results/predictions.json
!mkdir -p results

In [None]:
# Evaluate the trained adapters on the dev split; metrics and per-example predictions
# are written to results/predictions.json.
# --base_model repeats the checkpoint name assigned to `model_name` earlier in this
# notebook; keep the two in sync if a different model is chosen.
!python evaluate_fixed.py \
    --model_path outputs/rl-text2sql \
    --base_model Qwen/Qwen2.5-Coder-3B-Instruct \
    --test_data data/spider/dev.json \
    --db_root data/spider/database \
    --output_file results/predictions.json

In [None]:
# Read the evaluation output written by evaluate_fixed.py
with open('results/predictions.json', 'r') as f:
    results = json.load(f)

banner = "=" * 80
metrics = results['metrics']

# Headline metrics
print(banner)
print("EVALUATION RESULTS")
print(banner)
print(f"Execution Accuracy: {metrics['execution_accuracy']:.2%}")
print(f"Exact Match:        {metrics['exact_match']:.2%}")
print(f"Total Examples:     {metrics['total']}")
print(banner)

# First five predictions, one block per example
print("\nSample Predictions:")
sample_predictions = results['predictions'][:5]
for i, pred in enumerate(sample_predictions):
    print(f"\n--- Example {i+1} ---")
    print(f"Question: {pred['question']}")
    print(f"Gold SQL: {pred['gold_sql']}")
    print(f"Pred SQL: {pred['pred_sql']}")
    print(f"Correct: {pred['execution_correct']}")

## 9. Interactive Testing

In [None]:
# Load trained model for interactive testing.
# Relies on names created in earlier cells: model_name, torch and AutoModelForCausalLM.
from peft import PeftModel

# Load base model in bfloat16 (unlike the training cell above, no quantization config is passed)
base_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.bfloat16
)

# Load LoRA adapters from the training output directory and switch to eval mode
trained_model = PeftModel.from_pretrained(base_model, "outputs/rl-text2sql")
trained_model.eval()

print("Trained model loaded!")

In [None]:
# Interactive generation
def ask_question(question):
    """Return the SQL that the fine-tuned model generates for ``question``.

    Uses the notebook-level ``tokenizer`` and ``trained_model``.

    Args:
        question: Natural-language question to translate.

    Returns:
        The generated SQL text, stripped of surrounding whitespace. If the model
        continues with an invented follow-up ``Question:`` block, only the text
        before it is returned.
    """
    prompt = f"Question: {question}\nSQL:"

    inputs = tokenizer(prompt, return_tensors="pt").to(trained_model.device)

    # Greedy decoding; temperature is omitted because it has no effect when
    # do_sample=False and only triggers a configuration warning.
    outputs = trained_model.generate(
        **inputs,
        max_new_tokens=256,
        do_sample=False
    )

    # Decode only the continuation: the returned sequence starts with the prompt tokens.
    prompt_len = inputs["input_ids"].shape[1]
    continuation = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)

    # Splitting the full text on "SQL:" and taking the last piece would return a
    # hallucinated follow-up query whenever the model keeps generating further
    # "Question: ... SQL: ..." pairs; keep the first answer instead.
    sql = continuation.split("Question:")[0].strip()

    return sql

# Try some questions about the students table.
# The first two also appear in the synthetic training data created earlier; the
# average-age question does not.
questions = [
    "What are the names of all students?",
    "How many students are there?",
    "What is the average age of students?"
]

# Print each question with the SQL returned by ask_question, separated by a blank line
for q in questions:
    sql = ask_question(q)
    print(f"Q: {q}")
    print(f"A: {sql}")
    print()

## 10. Save Outputs on Kaggle (Optional)

In [None]:
# Kaggle automatically saves anything in /kaggle/working/ 
# Just copy files there - they'll be available after session ends

!mkdir -p /kaggle/working/outputs
!cp -r outputs/rl-text2sql /kaggle/working/
!cp results/predictions.json /kaggle/working/ 2>/dev/null || echo "No predictions yet"

print("✅ Model saved to Kaggle outputs (will persist after session)")
print("Access via: Notebook → Output → Download")