# LoRA Without Regret: Parameter-Efficient RL Fine-Tuning

This notebook demonstrates **LoRA without regret**: reinforcement learning (RL) fine-tuning that trains only Low-Rank Adaptation (LoRA) adapters while keeping LoRA's efficiency and performance benefits.

## What is LoRA Without Regret?

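As in standard LoRA, the base model stays frozen and only small adapter matrices are trained. What changes for RL is where the training signal comes from. We need to:
1. **Update LoRA parameters** from reward signals rather than labeled targets
2. **Generate rollouts** with the current policy (base model + adapters) for policy optimization
3. **Not freeze the adapters**: the LoRA weights are updated at every RL step, while the base weights never change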

This notebook uses **GRPO (Group Relative Policy Optimization)** to train LoRA adapters on GSM8K math problems with reward-based learning.

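## Key Differences from Standard LoRA
- **LoRA layers are NOT frozen**: they are the trainable parameters, updated online from rewards rather than once on a fixed labeled dataset
- Uses policy gradient methods (REINFORCE/GRPO) instead of supervised learning
- Rollouts are generated here with Hugging Face `generate`; the trained adapters can then be served with inference engines such as vLLM and sglang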
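The contrast with supervised learning can be made concrete with a minimal pure-Python sketch of the two per-sequence losses. Supervised fine-tuning minimizes the negative log-likelihood of a reference answer; a REINFORCE-style policy-gradient loss instead weights the summed log-likelihood of each *sampled* response by its advantage. The function names and toy log-probabilities are illustrative assumptions, not part of the notebook.

```python
from typing import List


def sft_loss(ref_token_logprobs: List[float]) -> float:
    """Supervised loss: negative log-likelihood of a reference response."""
    return -sum(ref_token_logprobs)


def policy_gradient_loss(advantages: List[float],
                         sampled_token_logprobs: List[List[float]]) -> float:
    """REINFORCE-style loss averaged over a batch of sampled responses.

    Each response's summed token log-probability is weighted by its advantage,
    so minimizing the loss raises the likelihood of positive-advantage responses
    and lowers it for negative-advantage ones.
    """
    assert len(advantages) == len(sampled_token_logprobs)
    per_sequence = [
        -adv * sum(token_lps)
        for adv, token_lps in zip(advantages, sampled_token_logprobs)
    ]
    return sum(per_sequence) / len(per_sequence)


# Toy numbers: response 1 was correct (advantage +0.5), response 2 was wrong (-0.5).
print(sft_loss([-0.5, -0.25]))  # 0.75
print(policy_gradient_loss([0.5, -0.5], [[-0.5, -0.25], [-1.0, -0.5]]))  # -0.1875
```

A zero advantage contributes nothing to the loss, so a group in which every rollout receives the same reward gives no learning signal.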

You'll need an **A100 GPU (40 GB)** or better for this notebook.

---

**References:**
- [LoRA Without Regret (GitHub)](https://github.com/michaelbzhu/lora-without-regret)
- [Thinking Machines LoRA Guide](https://thinkingmachines.ai/blog/lora/)
- [Karpathy's NanoChat RL](https://github.com/karpathy/nanochat/blob/master/scripts/chat_rl.py)

## 🛠️ Setup & Installation

In [None]:
#@title Setup
!nvidia-smi -L || true

import os, sys, random, numpy as np, torch, json, time, platform
print("Python:", sys.version)
print("CUDA available:", torch.cuda.is_available())

# Install dependencies
try:
    get_ipython().run_line_magic("uv", "pip -q install transformers==4.51.3 accelerate==1.4.0 peft==0.14.0 datasets==3.3.2 evaluate==0.4.3 trl==0.13.0 sentencepiece protobuf tqdm matplotlib > /dev/null")
except Exception:
    get_ipython().run_line_magic("pip", "-q install transformers==4.51.3 accelerate==1.4.0 peft==0.14.0 datasets==3.3.2 evaluate==0.4.3 trl==0.13.0 sentencepiece protobuf tqdm matplotlib > /dev/null")

import transformers, datasets, peft, accelerate, matplotlib
print("Transformers:", transformers.__version__)
print("Accelerate:", accelerate.__version__)
print("PEFT:", peft.__version__)

SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
assert DEVICE == "cuda", "Please connect a GPU (A100+ recommended)."

print(f"\n✓ Running on {torch.cuda.get_device_name(0)}")

## ⚙️ Configuration Parameters

All hyperparameters are defined here. You can experiment with these settings to see how they affect LoRA training with RL.

In [None]:
from dataclasses import dataclass
from typing import Optional, List
import re

@dataclass
class Config:
    # ========== Model Settings ==========
    model_id: str = "Qwen/Qwen3-0.6B-Base"
    
    # ========== LoRA Settings ==========
    # IMPORTANT: These LoRA parameters are NOT frozen - they're actively trained!
    lora_r: int = 16                    # LoRA rank (1-64, higher = more capacity)
    lora_alpha: int = 32                # LoRA scaling factor
    lora_dropout: float = 0.05          # Dropout for LoRA layers
    lora_target_modules: Optional[List[str]] = None  # Will be set to all attention + MLP
    
    # ========== RL Training Settings (GRPO) ==========
    n_grpo_steps: int = 100             # Number of RL training steps
    prompts_per_step: int = 32          # How many prompts to sample per step
    rollouts_per_prompt: int = 8        # Group size (K samples per prompt)
    
    learning_rate: float = 9e-5         # Learning rate for LoRA parameters
    weight_decay: float = 0.0           # Weight decay
    
    micro_batch_size: int = 2           # Micro-batch for gradient accumulation
    gradient_accumulation_steps: int = 128  # Accumulate gradients over this many steps
    epochs_per_step: int = 1            # Training epochs per GRPO step
    max_grad_norm: float = 1.0          # Gradient clipping
    
    # ========== Generation Settings ==========
    max_new_tokens: int = 256           # Max tokens to generate
    temperature: float = 0.7            # Sampling temperature for rollouts
    top_p: float = 0.9                  # Nucleus sampling
    eval_temperature: float = 0.0       # Greedy decoding for evaluation
    
    # ========== Task Settings (GSM8K) ==========
    prompt_template: str = (
        "Solve this math problem step by step.\n"
        "Give ONLY ONE final numeric answer (no units), inside square brackets.\n"
        "Problem: {question}\n\nSolution:"
    )
    
    # ========== Monitoring & Logging ==========
    log_every: int = 5                  # Log metrics every N steps
    eval_every: int = 10                # Evaluate on validation set every N steps
    eval_samples: int = 100             # Number of validation samples
    save_every: int = 25                # Save checkpoint every N steps
    
    # ========== Output ==========
    run_name: str = f"lora_rl_{int(time.time())}"
    output_dir: str = "./lora_rl_runs"
    
    def __post_init__(self):
        # Default to all attention + MLP projection layers
        if self.lora_target_modules is None:
            self.lora_target_modules = [
                "q_proj", "k_proj", "v_proj", "o_proj",  # Attention
                "gate_proj", "up_proj", "down_proj"      # MLP
            ]
        self.run_dir = os.path.join(self.output_dir, self.run_name)
        os.makedirs(self.run_dir, exist_ok=True)

cfg = Config()

print("="*60)
print("Configuration:")
print("="*60)
print(f"Model: {cfg.model_id}")
print(f"LoRA rank: {cfg.lora_r}, alpha: {cfg.lora_alpha}")
print(f"GRPO steps: {cfg.n_grpo_steps}")
print(f"Prompts/step: {cfg.prompts_per_step}, Rollouts/prompt: {cfg.rollouts_per_prompt}")
print(f"Learning rate: {cfg.learning_rate}")
print(f"Output: {cfg.run_dir}")
print("="*60)
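The batch settings interact. Each GRPO step produces `prompts_per_step * rollouts_per_prompt` rollouts, which the training loop splits into micro-batches of `micro_batch_size`, and an optimizer step happens every `gradient_accumulation_steps` micro-batches. This hypothetical helper (not part of the notebook) does that arithmetic and shows that the defaults give exactly one optimizer update per GRPO step:

```python
def optimizer_steps_per_grpo_step(prompts_per_step: int,
                                  rollouts_per_prompt: int,
                                  micro_batch_size: int,
                                  gradient_accumulation_steps: int,
                                  epochs_per_step: int = 1) -> float:
    """Optimizer updates per GRPO step implied by the batch settings."""
    rollouts = prompts_per_step * rollouts_per_prompt
    # Ceiling division: a final partial micro-batch still counts as one micro-batch
    micro_batches_per_epoch = -(-rollouts // micro_batch_size)
    return micro_batches_per_epoch * epochs_per_step / gradient_accumulation_steps


# Defaults: 32 prompts x 8 rollouts = 256 rollouts = 128 micro-batches of 2
print(optimizer_steps_per_grpo_step(32, 8, 2, 128))  # 1.0
```

A non-integer result means the micro-batches do not divide evenly into optimizer updates, so keeping the ratio a whole number is the simplest way to keep every update the same size when you change any of these settings.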

## 📊 Load Dataset (GSM8K)

We'll use GSM8K grade school math problems as our RL task. The reward is binary: 1 if the answer is correct, 0 otherwise.

In [None]:
from datasets import load_dataset

def render_prompt(question: str) -> str:
    return cfg.prompt_template.format(question=question)

NUMBER_RE = r"-?\d+(?:,\d{3})*(?:\.\d+)?"  # also matches thousands separators, e.g. 1,000

def normalize_number(num_text: str) -> Optional[str]:
    """Canonicalize a number so '1,000', '1000' and '1000.0' compare equal"""
    try:
        value = float(num_text.replace(",", ""))
    except ValueError:
        return None
    return str(int(value)) if value.is_integer() else str(value)

def parse_gold_answer(answer_text: str) -> Optional[str]:
    """Extract numeric answer from GSM8K format (e.g., '#### 42')"""
    m = re.search(r"####\s*(" + NUMBER_RE + ")", answer_text)
    if m:
        return normalize_number(m.group(1))
    # Fallback: last number in text
    nums = re.findall(NUMBER_RE, answer_text)
    return normalize_number(nums[-1]) if nums else None

def parse_pred_answer(text: str) -> Optional[str]:
    """Extract answer from model output (looks for [number])"""
    m = re.search(r"\[\s*(" + NUMBER_RE + r")\s*\]", text)
    if m:
        return normalize_number(m.group(1))
    # Fallback: last number in text
    nums = re.findall(NUMBER_RE, text)
    return normalize_number(nums[-1]) if nums else None

def compute_reward(generated_text: str, gold_answer: str) -> float:
    """Binary reward: 1.0 if the normalized numeric answers match, 0.0 otherwise"""
    pred = parse_pred_answer(generated_text)
    gold = parse_gold_answer(gold_answer)
    if pred is None or gold is None:
        return 0.0
    return 1.0 if pred == gold else 0.0

print("Loading GSM8K...")
ds_train = load_dataset("openai/gsm8k", "main", split="train")
ds_test = load_dataset("openai/gsm8k", "main", split="test")

# Split train into train/val
val_size = min(200, len(ds_train))
ds_val = ds_train.select(range(val_size))
ds_train = ds_train.select(range(val_size, len(ds_train)))

print(f"Splits: {len(ds_train)} train | {len(ds_val)} val | {len(ds_test)} test")

# Example
ex = ds_train[0]
print("\nExample:")
print(f"Q: {ex['question']}")
print(f"A: {ex['answer']}")
print(f"Gold: [{parse_gold_answer(ex['answer'])}]")

## 🏗️ Load Model with LoRA

We load the base model and add LoRA adapters. **Crucially, only the LoRA parameters are trainable** — the base model weights are frozen.
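Why the trainable share is so small: for a linear layer with a frozen weight `W` of shape `(d_out, d_in)`, LoRA learns `A` of shape `(r, d_in)` and `B` of shape `(d_out, r)` and adds `(lora_alpha / r) * B @ A` to `W`. It therefore trains `r * (d_in + d_out)` parameters instead of `d_in * d_out`. The sketch below does that arithmetic for one layer; the 1024-to-3072 shape is a hypothetical example, not read from the model config.

```python
def lora_trainable_params(d_in: int, d_out: int, r: int) -> int:
    """Parameters LoRA trains for one linear layer: A is (r x d_in), B is (d_out x r)."""
    return r * (d_in + d_out)


def full_params(d_in: int, d_out: int) -> int:
    """Parameters in the frozen dense weight W (d_out x d_in), ignoring bias."""
    return d_in * d_out


def lora_scaling(lora_alpha: float, r: int) -> float:
    """Standard (non-rsLoRA) scaling applied to B @ A."""
    return lora_alpha / r


# Hypothetical 1024 -> 3072 projection with the notebook's r=16, alpha=32
d_in, d_out, r = 1024, 3072, 16
lora = lora_trainable_params(d_in, d_out, r)
dense = full_params(d_in, d_out)
print(lora, dense, f"{100 * lora / dense:.2f}%")  # 65536 3145728 2.08%
print(lora_scaling(32, r))  # 2.0
```

The whole-model share printed below is lower still, because layers without adapters (such as the embeddings) contribute frozen parameters only.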

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig, get_peft_model, TaskType

def load_model_with_lora(model_id: str, lora_config: LoraConfig):
    """Load base model and apply LoRA adapters"""
    print(f"Loading base model: {model_id}...")
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True
    )
    model.config.use_cache = False  # Disable for training
    
    print("Applying LoRA adapters...")
    model = get_peft_model(model, lora_config)
    
    # Print trainable parameters
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Trainable params: {trainable_params:,} ({100 * trainable_params / total_params:.2f}%)")
    print(f"Total params: {total_params:,}")
    
    return model

# Load tokenizer
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(cfg.model_id, use_fast=True, trust_remote_code=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"  # For decoder-only models

# Configure LoRA
lora_config = LoraConfig(
    r=cfg.lora_r,
    lora_alpha=cfg.lora_alpha,
    lora_dropout=cfg.lora_dropout,
    bias="none",
    task_type=TaskType.CAUSAL_LM,
    target_modules=cfg.lora_target_modules,
)

# Load model with LoRA
model = load_model_with_lora(cfg.model_id, lora_config)

print("\n✓ Model ready for RL training!")

## 📈 Baseline Evaluation

Let's see how well the model performs before RL training.

In [None]:
from tqdm.auto import tqdm

@torch.no_grad()
def evaluate_accuracy(model, tokenizer, dataset, num_samples: int = 100, temperature: float = 0.0, batch_size: int = 16):
    """Evaluate exact match accuracy on dataset"""
    model.eval()
    model.config.use_cache = True
    
    n = min(num_samples, len(dataset))
    correct = 0
    
    for i in tqdm(range(0, n, batch_size), desc="Evaluating"):
        batch = dataset.select(range(i, min(i + batch_size, n)))
        prompts = [render_prompt(ex["question"]) for ex in batch]
        
        # Tokenize
        inputs = tokenizer(prompts, padding=True, truncation=True, max_length=2048, return_tensors="pt").to(model.device)
        
        # Generate
        gen_kwargs = dict(
            max_new_tokens=cfg.max_new_tokens,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            use_cache=True
        )
        if temperature > 0.0:
            gen_kwargs.update(do_sample=True, temperature=temperature, top_p=cfg.top_p)
        else:
            gen_kwargs.update(do_sample=False)
        
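        outputs = model.generate(**inputs, **gen_kwargs)
        # Decode only the new tokens; with left padding every prompt ends at column input_ids.shape[1]
        texts = tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True)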
        
        # Check correctness
        for ex, text in zip(batch, texts):
            reward = compute_reward(text, ex["answer"])
            correct += int(reward)
    
    model.config.use_cache = False
    accuracy = correct / n
    return accuracy

print("Computing baseline accuracy...")
baseline_acc = evaluate_accuracy(model, tokenizer, ds_test, num_samples=len(ds_test), temperature=0.0, batch_size=32)  # same examples as the final evaluation
print(f"\nBaseline Test Accuracy: {baseline_acc:.1%}")

# Save baseline
with open(os.path.join(cfg.run_dir, "baseline.json"), "w") as f:
    json.dump({"baseline_accuracy": baseline_acc}, f, indent=2)

## 🎯 GRPO Training Loop

This is the core RL training loop using **Group Relative Policy Optimization (GRPO)**:

1. Sample prompts from the training set
2. Generate K rollouts per prompt (group)
3. Compute rewards for each rollout
4. Compute advantages within each group (subtract the group's mean reward; no standard-deviation scaling)
5. Update LoRA parameters using policy gradients
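Steps 3 and 4 in isolation: a pure-Python sketch of the group-relative advantage, assuming rewards are laid out prompt by prompt (all K rollouts of prompt 0, then prompt 1, ...), as in the training loop below. The optional `scale_by_std` flag, an addition for comparison, shows the standard-deviation normalization from the original GRPO formulation; the notebook's loop only subtracts the group mean.

```python
from statistics import mean, pstdev
from typing import List


def group_advantages(rewards: List[float], group_size: int,
                     scale_by_std: bool = False, eps: float = 1e-6) -> List[float]:
    """Group-relative advantages for rewards stored prompt by prompt."""
    if len(rewards) % group_size != 0:
        raise ValueError("len(rewards) must be a multiple of group_size")
    advantages: List[float] = []
    for start in range(0, len(rewards), group_size):
        group = rewards[start:start + group_size]
        mu = mean(group)  # baseline: mean reward of all K rollouts for this prompt
        if scale_by_std:
            sigma = pstdev(group)
            advantages.extend((r - mu) / (sigma + eps) for r in group)
        else:
            advantages.extend(r - mu for r in group)
    return advantages


# Two prompts, K = 4 rollouts each: prompt 0 solved twice, prompt 1 never solved.
rewards = [1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0]
print(group_advantages(rewards, group_size=4))
# [0.5, -0.5, -0.5, 0.5, 0.0, 0.0, 0.0, 0.0]
```

The second prompt gets all-zero advantages: with a binary reward, prompts the model always fails (or always solves) contribute no gradient, which is why the group size K matters.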

**Key insight**: We only train on the generated tokens (response), not the prompt. This is enforced via masking.
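The masking rule can be checked on a toy batch. With left padding (as configured for the tokenizer), every prompt in a batch ends at the same column, the padded prompt length, so each response starts there rather than at its prompt's unpadded length. This sketch uses plain lists with made-up token ids; `PAD` and `EOS` share an id, as they do when the tokenizer's pad token falls back to EOS. Keeping the first EOS in the mask, so the policy is also trained on when to stop, is a design choice of the sketch.

```python
from typing import List


def response_mask(sequences: List[List[int]], padded_prompt_len: int,
                  eos_id: int) -> List[List[int]]:
    """1 for response tokens up to and including the first EOS, else 0.

    Assumes left-padded prompts, so every response starts at column
    `padded_prompt_len` regardless of how long the unpadded prompt was.
    """
    masks = []
    for seq in sequences:
        mask = [0] * len(seq)
        for pos in range(padded_prompt_len, len(seq)):
            mask[pos] = 1
            if seq[pos] == eos_id:
                break  # tokens after the first EOS are padding, not policy actions
        masks.append(mask)
    return masks


PAD = EOS = 0  # hypothetical ids
batch = [
    [PAD, PAD, 11, 12, 21, 22, EOS, EOS],  # short prompt, response stops early
    [13, 14, 15, 16, 23, 24, 25, 26],      # long prompt, response hits max length
]
print(response_mask(batch, padded_prompt_len=4, eos_id=EOS))
# [[0, 0, 0, 0, 1, 1, 1, 0], [0, 0, 0, 0, 1, 1, 1, 1]]
```

Counting non-pad tokens instead would give the first row a start of 2 and pull the prompt tokens 11 and 12 into the loss.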

In [None]:
import torch.nn.functional as F
from torch.utils.data import DataLoader
import pandas as pd
from IPython.display import display

def create_response_mask(input_ids: torch.Tensor, prompt_lengths: List[int], eos_token_id: int) -> torch.Tensor:
    """
    Create mask that's 1 for generated tokens (response) and 0 for prompt tokens.
    `prompt_lengths[i]` is the column where row i's response starts; with left padding
    this is the padded prompt length, the same for every row.
    The first EOS is kept (so the policy learns when to stop); tokens after it are masked out.
    """
    batch_size, seq_len = input_ids.shape
    mask = torch.zeros_like(input_ids, dtype=torch.float32)
    
    for i, prompt_len in enumerate(prompt_lengths):
        # Find first EOS in generated part
        generated_ids = input_ids[i, prompt_len:]
        eos_positions = (generated_ids == eos_token_id).nonzero(as_tuple=True)[0]
        
        if len(eos_positions) > 0:
            first_eos = eos_positions[0].item()
            mask[i, prompt_len:prompt_len + first_eos + 1] = 1.0  # include the EOS token itself
        else:
            mask[i, prompt_len:] = 1.0
    
    return mask

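def compute_policy_logprobs(model, input_ids: torch.Tensor, response_mask: torch.Tensor,
                            chunk_size: int = 8) -> torch.Tensor:
    """
    Compute log probabilities under current policy for generated tokens.
    Returns: [batch_size, seq_len] log probs (0 for masked positions)
    The batch is processed in chunks so full-vocabulary logits for a whole
    GRPO batch never have to fit in GPU memory at once.
    """
    chunks = []
    for start in range(0, input_ids.size(0), chunk_size):
        ids = input_ids[start:start + chunk_size]
        attention_mask = (ids != tokenizer.pad_token_id).long()
        # Match generate(): positions count only non-pad tokens, so left padding doesn't shift them
        position_ids = (attention_mask.cumsum(dim=-1) - 1).clamp(min=0)
        outputs = model(input_ids=ids[:, :-1],
                        attention_mask=attention_mask[:, :-1],
                        position_ids=position_ids[:, :-1])
        logits = outputs.logits.float()  # [chunk, seq_len-1, vocab]; float32 for a stable log-softmax
        
        # log p(next token) = logit of the actual next token - logsumexp over the vocabulary
        next_tokens = ids[:, 1:].unsqueeze(-1)  # [chunk, seq_len-1, 1]
        token_log_probs = logits.gather(-1, next_tokens).squeeze(-1) - torch.logsumexp(logits, dim=-1)
        
        # Shift right by one so position t holds log p(token_t | tokens_<t)
        chunks.append(F.pad(token_log_probs, (1, 0), value=0.0))  # [chunk, seq_len]
    
    token_log_probs = torch.cat(chunks, dim=0)  # [batch, seq_len]
    return token_log_probs * response_mask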

class MetricsTable:
    """Live updating table for training metrics"""
    def __init__(self):
        self.rows = []
        self.df = pd.DataFrame(columns=["step", "loss", "avg_reward", "kl_div", "val_acc"])
        self.handle = display(self.df, display_id=True)
    
    def update(self, step, loss=None, avg_reward=None, kl_div=None, val_acc=None):
        row = {"step": step}
        if loss is not None:
            row["loss"] = f"{loss:.4f}"
        if avg_reward is not None:
            row["avg_reward"] = f"{avg_reward:.3f}"
        if kl_div is not None:
            row["kl_div"] = f"{kl_div:.4f}"
        if val_acc is not None:
            row["val_acc"] = f"{val_acc:.1%}"
        
        self.rows.append(row)
        self.df = pd.DataFrame(self.rows)
        self.handle.update(self.df)

def save_lora_checkpoint(model, step: int, output_dir: str):
    """Save LoRA adapter weights"""
    save_path = os.path.join(output_dir, f"checkpoint_step_{step}")
    os.makedirs(save_path, exist_ok=True)
    model.save_pretrained(save_path)
    print(f"Saved checkpoint to {save_path}")
    return save_path

def train_grpo():
    """Main GRPO training loop"""
    print("\n" + "="*60)
    print("Starting GRPO Training")
    print("="*60)
    
    # Setup
    optimizer = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=cfg.learning_rate, weight_decay=cfg.weight_decay)  # LoRA params only
    metrics_table = MetricsTable()
    train_prompts = [render_prompt(ex["question"]) for ex in ds_train]
    train_answers = [ex["answer"] for ex in ds_train]
    
    logs = []
    
    for step in tqdm(range(cfg.n_grpo_steps), desc="GRPO Steps"):
        # ========== 1. Sample prompts ==========
        rng = np.random.default_rng(SEED + step)
        indices = rng.choice(len(train_prompts), size=cfg.prompts_per_step, replace=False)
        prompts_batch = [train_prompts[i] for i in indices]
        answers_batch = [train_answers[i] for i in indices]
        
        # Repeat each prompt K times for K rollouts
        prompts_repeated = [p for p in prompts_batch for _ in range(cfg.rollouts_per_prompt)]
        answers_repeated = [a for a in answers_batch for _ in range(cfg.rollouts_per_prompt)]
        
        # ========== 2. Generate rollouts (no grad) ==========
        model.eval()
        with torch.no_grad():
            inputs = tokenizer(prompts_repeated, padding=True, truncation=True, max_length=2048, return_tensors="pt").to(model.device)
            # Left padding aligns every prompt to end at the same column, so all responses start there
            padded_prompt_len = inputs.input_ids.shape[1]
            prompt_lengths = [padded_prompt_len] * inputs.input_ids.shape[0]
            
            # Generate with sampling (KV cache on for speed; model.config.use_cache stays False for training)
            gen_outputs = model.generate(
                **inputs,
                max_new_tokens=cfg.max_new_tokens,
                do_sample=True,
                temperature=cfg.temperature,
                top_p=cfg.top_p,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                use_cache=True,
                return_dict_in_generate=True,
                output_scores=False
            )
            
            full_sequences = gen_outputs.sequences  # [batch, seq_len]
            # Decode only the response so numbers in the prompt are never scored as answers
            generated_texts = tokenizer.batch_decode(full_sequences[:, padded_prompt_len:], skip_special_tokens=True)
            
            # ========== 3. Compute rewards ==========
            rewards = torch.tensor(
                [compute_reward(text, ans) for text, ans in zip(generated_texts, answers_repeated)],
                dtype=torch.float32,
                device=model.device
            )  # [batch]
            
            # Reshape to [num_prompts, rollouts_per_prompt]
            rewards_grouped = rewards.view(cfg.prompts_per_step, cfg.rollouts_per_prompt)
            
            # ========== 4. Compute advantages (group normalization) ==========
            # Subtract mean within each group
            group_means = rewards_grouped.mean(dim=1, keepdim=True)  # [num_prompts, 1]
            advantages = rewards_grouped - group_means  # [num_prompts, rollouts_per_prompt]
            advantages = advantages.view(-1)  # Flatten back to [batch]
            
            # Store old log probs for KL estimation
            response_mask = create_response_mask(full_sequences, prompt_lengths, tokenizer.eos_token_id)
            old_logprobs = compute_policy_logprobs(model, full_sequences, response_mask)
        
        # ========== 5. Policy gradient update ==========
        model.train()
        model.config.use_cache = False
        
        total_loss = 0.0
        num_batches = 0
        
        # Train for multiple epochs with gradient accumulation
        for epoch in range(cfg.epochs_per_step):
            # Shuffle rollouts each epoch
            perm = torch.randperm(full_sequences.size(0), device=full_sequences.device)
            
            for i in range(0, full_sequences.size(0), cfg.micro_batch_size):
                mb_idx = perm[i:i + cfg.micro_batch_size]
                
                # Micro-batch
                mb_sequences = full_sequences[mb_idx]
                mb_mask = response_mask[mb_idx]
                mb_advantages = advantages[mb_idx]
                
                # Compute current policy log probs
                curr_logprobs = compute_policy_logprobs(model, mb_sequences, mb_mask)
                
                # Policy gradient loss: -E[advantage * log_prob]
                # Sum over tokens, then weight by advantage per sequence
                sequence_logprobs = curr_logprobs.sum(dim=1)  # [micro_batch]
                loss = -(mb_advantages * sequence_logprobs).mean()
                
                # Backprop with gradient accumulation
                loss = loss / cfg.gradient_accumulation_steps
                loss.backward()
                
                total_loss += loss.item() * cfg.gradient_accumulation_steps
                num_batches += 1
                
                # Optimizer step every accumulation_steps
                if (num_batches % cfg.gradient_accumulation_steps) == 0:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.max_grad_norm)
                    optimizer.step()
                    optimizer.zero_grad()
        
        # Apply leftover accumulated gradients so they don't leak into the next GRPO step
        if num_batches % cfg.gradient_accumulation_steps != 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.max_grad_norm)
            optimizer.step()
            optimizer.zero_grad()
        
        avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
        
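        # ========== 6. Compute metrics ==========
        model.eval()  # turn off LoRA dropout so the comparison with old_logprobs is like-for-like
        with torch.no_grad():
            # Approximate KL(old || new): mean log-ratio over response tokens sampled from the old policy
            new_logprobs = compute_policy_logprobs(model, full_sequences, response_mask)
            kl_div = ((old_logprobs - new_logprobs) * response_mask).sum() / response_mask.sum()
            kl_div = kl_div.item()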
            
            avg_reward = rewards.mean().item()
        
        # ========== 7. Logging & Evaluation ==========
        log_entry = {
            "step": step,
            "loss": avg_loss,
            "avg_reward": avg_reward,
            "kl_div": kl_div,
        }
        
        if (step % cfg.log_every == 0) or (step == cfg.n_grpo_steps - 1):
            # Validation
            val_acc = None
            if (step % cfg.eval_every == 0) or (step == cfg.n_grpo_steps - 1):
                val_acc = evaluate_accuracy(model, tokenizer, ds_val, num_samples=cfg.eval_samples, temperature=cfg.eval_temperature)
                log_entry["val_acc"] = val_acc
            
            metrics_table.update(
                step=step,
                loss=avg_loss,
                avg_reward=avg_reward,
                kl_div=kl_div,
                val_acc=val_acc
            )
        
        logs.append(log_entry)
        
        # ========== 8. Save checkpoint ==========
        if (step % cfg.save_every == 0 and step > 0) or (step == cfg.n_grpo_steps - 1):
            save_lora_checkpoint(model, step, cfg.run_dir)
        
        torch.cuda.empty_cache()
    
    # Save logs
    pd.DataFrame(logs).to_csv(os.path.join(cfg.run_dir, "training_logs.csv"), index=False)
    
    print("\n✓ Training complete!")
    return logs

# Run training
training_logs = train_grpo()

## 📊 Final Evaluation

In [None]:
print("Running final test evaluation...")
final_test_acc = evaluate_accuracy(model, tokenizer, ds_test, num_samples=len(ds_test), temperature=0.0, batch_size=32)

print("\n" + "="*60)
print("FINAL RESULTS")
print("="*60)
print(f"Baseline Test Accuracy: {baseline_acc:.1%}")
print(f"Final Test Accuracy:    {final_test_acc:.1%}")
print(f"Improvement:            {(final_test_acc - baseline_acc) * 100:+.1f} percentage points")
print("="*60)

# Save final results
with open(os.path.join(cfg.run_dir, "final_results.json"), "w") as f:
    json.dump({
        "baseline_accuracy": baseline_acc,
        "final_test_accuracy": final_test_acc,
        "improvement": final_test_acc - baseline_acc
    }, f, indent=2)

print(f"\nAll artifacts saved to: {cfg.run_dir}")

## 🚀 Usage with Inference Engines (vLLM, sglang)

After training, you can use your LoRA adapters with fast inference engines like **vLLM** or **sglang** for production deployment.

### Option 1: vLLM

vLLM supports dynamic LoRA adapter loading for efficient serving:

```python
# Install vLLM
# pip install vllm

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Initialize vLLM with base model
llm = LLM(
    model="Qwen/Qwen3-0.6B-Base",
    enable_lora=True,
    max_lora_rank=16,  # must be >= the adapter's LoRA rank (cfg.lora_r)
    tensor_parallel_size=1
)

# Path to your trained LoRA adapter: replace lora_rl_123456 with your run directory.
# With n_grpo_steps=100, steps run 0-99, so the last checkpoint is checkpoint_step_99.
lora_path = "./lora_rl_runs/lora_rl_123456/checkpoint_step_99"

# Create LoRA request: adapter name, unique positive integer ID, local adapter path
lora_request = LoRARequest("math_lora", 1, lora_path)

# Sampling params
sampling_params = SamplingParams(
    temperature=0.0,
    max_tokens=256,
    stop=[tokenizer.eos_token]
)

# Generate with LoRA
prompts = [render_prompt("What is 25 * 17?")]
outputs = llm.generate(
    prompts,
    sampling_params,
    lora_request=lora_request
)

for output in outputs:
    print(output.outputs[0].text)
```

**Benefits of vLLM:**
- PagedAttention for memory efficiency
- Continuous batching for high throughput
- Multiple LoRA adapters can be served simultaneously
- Supports tensor parallelism for large models

---

### Option 2: sglang

sglang (Structured Generation Language) provides structured generation with LoRA support:

```python
# Install sglang
# pip install "sglang[all]"

# Start an sglang server with the adapter registered under the name "math_lora".
# In a terminal:
# python -m sglang.launch_server \
#   --model-path Qwen/Qwen3-0.6B-Base \
#   --lora-paths math_lora=./lora_rl_runs/lora_rl_123456/checkpoint_step_99 \
#   --port 30000

# Client code: the native /generate endpoint selects the adapter per request via "lora_path"
import requests

response = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": render_prompt("What is 25 * 17?"),
        "sampling_params": {"max_new_tokens": 256, "temperature": 0.0},
        "lora_path": "math_lora",  # omit to use the base model
    },
)
print(response.json()["text"])
```

**Benefits of sglang:**
- Structured generation primitives (constrained decoding)
- Easier to write complex prompting logic
- RadixAttention for prefix caching
- Good for agent-based applications

---

### Option 3: Merge LoRA weights into base model

If you don't need dynamic adapter switching, merge the LoRA weights:

```python
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

base_model_id = "Qwen/Qwen3-0.6B-Base"

# Load base model
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)

# Load LoRA adapter (replace lora_rl_123456 with your run directory)
lora_path = "./lora_rl_runs/lora_rl_123456/checkpoint_step_99"
peft_model = PeftModel.from_pretrained(base_model, lora_path)

# Merge the adapter into the base weights and drop the LoRA modules
merged_model = peft_model.merge_and_unload()

# Save merged model together with the base tokenizer
merged_model.save_pretrained("./merged_model")
AutoTokenizer.from_pretrained(base_model_id).save_pretrained("./merged_model")

# Now you can use this with any inference engine as a regular model
```
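Merging is mathematically equivalent to running the adapter because the LoRA update is an additive low-rank term: a LoRA layer computes `W x + (lora_alpha / r) * B (A x)`, and merging folds this into a single weight `W + (lora_alpha / r) * B A`. The pure-Python check below uses made-up 2x2 matrices with rank 1 and a stand-in scaling of 2.0; it illustrates the algebra, not PEFT internals. (With bf16 weights, the merged model can differ slightly from the unmerged one due to rounding.)

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain matrix product of a (n x k) and b (k x m)."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


def add(a: Matrix, b: Matrix, scale: float = 1.0) -> Matrix:
    """Elementwise a + scale * b."""
    return [[x + scale * y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


W = [[1.0, 2.0], [3.0, 4.0]]  # frozen base weight (d_out x d_in)
A = [[0.5, -1.0]]             # LoRA A: r x d_in, with r = 1
B = [[2.0], [1.0]]            # LoRA B: d_out x r
scaling = 2.0                 # stands in for lora_alpha / r
x = [[1.0], [2.0]]            # input column vector

unmerged = add(matmul(W, x), matmul(B, matmul(A, x)), scale=scaling)  # W x + s * B (A x)
merged = matmul(add(W, matmul(B, A), scale=scaling), x)               # (W + s * B A) x
print(unmerged, merged)  # [[-1.0], [8.0]] [[-1.0], [8.0]]
```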

---

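### Engine Comparison

| Engine | LoRA support | Notable features | Best for |
|--------|--------------|------------------|----------|
| **vLLM** | Multiple adapters, selected per request | PagedAttention, continuous batching, tensor parallelism | High-throughput serving |
| **sglang** | Multiple adapters, selected per request | RadixAttention prefix caching, structured-generation frontend | Structured generation, agents |
| **Transformers** | Native PEFT | Standard `generate` loop, not optimized for serving | Development, experimentation |
| **Merged** | N/A (adapter folded into the weights) | No adapter overhead at inference; one model per adapter | Single-model deployment |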

**Recommendation:** Use vLLM for production serving with multiple LoRA adapters, sglang for agent-based applications, or merge weights for simple deployments.

## 🔬 Advanced: Custom Reward Functions

You can easily adapt this notebook for other tasks by changing the reward function:

```python
import torch
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification

# Example 1: Length-based reward (peaks at target_length words, clipped at 0)
def length_reward(text: str, target_length: int = 100) -> float:
    length = len(text.split())
    return max(0.0, 1.0 - abs(length - target_length) / target_length)

# Example 2: Sentiment reward (for RLHF-style training)
sentiment_classifier = pipeline("sentiment-analysis")

def sentiment_reward(text: str, target_sentiment: str = "POSITIVE") -> float:
    result = sentiment_classifier(text, truncation=True)[0]
    return 1.0 if result["label"] == target_sentiment else 0.0

# Example 3: Learned reward model (a sequence classifier with a single scalar output)
reward_model_path = "path/to/reward/model"  # placeholder
rm_tokenizer = AutoTokenizer.from_pretrained(reward_model_path)
reward_model = AutoModelForSequenceClassification.from_pretrained(reward_model_path, num_labels=1)
reward_model.eval()

def learned_reward(text: str, prompt: str) -> float:
    with torch.no_grad():
        inputs = rm_tokenizer(prompt + text, return_tensors="pt", truncation=True).to(reward_model.device)
        reward = reward_model(**inputs).logits[0, 0].item()
    return reward

# To use one of these, call it in place of compute_reward(...) when building `rewards` in train_grpo
```

## 📚 Summary

In this notebook, you learned:

1. **LoRA for RL** — How to train LoRA adapters with reinforcement learning
2. **GRPO Algorithm** — Group-based policy optimization with advantage normalization
3. **Not freezing layers**: LoRA parameters are updated during RL training while the base weights stay frozen
4. **Production deployment** — Using vLLM and sglang for efficient inference

**Key takeaways:**
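- LoRA makes RL training memory-efficient: only a small fraction of the parameters (printed when the adapters are attached) is trainable
- GRPO centers each rollout's reward on its group mean, a per-prompt baseline that makes training more stable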
- Response masking ensures we only train on generated tokens
- Fast inference engines like vLLM enable production deployment

**Next steps:**
- Try different LoRA ranks and learning rates
- Experiment with other reward functions
- Deploy your trained adapter with vLLM or sglang
- Combine with other techniques (PPO, DPO, etc.)