# RL Training Quick Test Notebook

Quick notebook template for testing RL training ideas.

**Use this for:**
- Testing a new paper's idea quickly
- Prototyping reward functions
- Small-scale experiments
- Debugging training loops

## 1. Setup and Imports

In [None]:
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
import numpy as np
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

# Report the runtime environment so version or GPU problems are visible early.
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    # Only device index 0 is queried; additional GPUs are not listed here.
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

## 2. Configuration

In [None]:
# Model configuration
MODEL_NAME = "meta-llama/Llama-2-7b-hf"  # Change to your model
# dtype passed as `torch_dtype` when loading both the policy and the reference model
PRECISION = torch.bfloat16
# Device that tokenized inputs and the reward tensor are moved to
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Training configuration
NUM_EPOCHS = 1  # Number of passes over the prompt list
BATCH_SIZE = 2  # Prompts per optimizer step
LEARNING_RATE = 1e-6  # Learning rate given to AdamW
MAX_LENGTH = 256  # Truncation length (in tokens) for prompt + response when scoring
MAX_NEW_TOKENS = 50  # Generation budget per response during training

# RL configuration
TEMPERATURE = 0.8  # Sampling temperature used when generating responses
KL_COEF = 0.1  # Weight of the policy-vs-reference log-prob gap subtracted from the reward

## 3. Load Model and Tokenizer

In [None]:
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
    # Batched tokenization with padding=True needs a pad token; fall back to EOS.
    tokenizer.pad_token = tokenizer.eos_token

# Load model (the policy that gets optimized)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=PRECISION,
    device_map="auto"
)

# Reference model for KL (frozen)
ref_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=PRECISION,
    device_map="auto"
)
ref_model.eval()
# eval() only switches layer modes (dropout etc.); it does not stop gradient
# tracking. Disable requires_grad so the reference weights are actually frozen
# and can never be updated or accumulate gradients by accident.
ref_model.requires_grad_(False)

# Optimizer (only the policy's parameters are registered)
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

print(f"Model loaded: {MODEL_NAME}")
print(f"Parameters: {sum(p.numel() for p in model.parameters()) / 1e9:.2f}B")

## 4. Prepare Data

Replace this with your actual dataset loading.

In [None]:
# Example: Simple prompt dataset
# Each entry is one raw prompt string; the list is consumed in slices of
# BATCH_SIZE by the training loop.
prompts = [
    "What is the capital of France?",
    "Explain quantum computing in simple terms.",
    "Write a short poem about AI.",
    "How do neural networks work?",
    "What is reinforcement learning?",
]

# Or load from HuggingFace (replace the dataset id and column name with yours)
# dataset = load_dataset("your/dataset")
# prompts = dataset['train']['prompt'][:100]  # First 100

print(f"Loaded {len(prompts)} prompts")

## 5. Define Reward Function

**TODO: Implement your reward function here!**

In [None]:
def compute_reward(prompt: str, response: str) -> float:
    """
    Score a response with a placeholder heuristic.

    TODO: Replace this with your actual reward function!

    Args:
        prompt: The prompt that produced the response (unused by this placeholder).
        response: The generated text to score.

    Returns:
        The sum of two independent bonuses: 0.5 when the whitespace-split word
        count lies in 10..100 inclusive, and 0.3 when the stripped text ends
        with '.', '!' or '?'.
    """
    # Example simple reward (REPLACE THIS)
    word_count = len(response.split())
    length_bonus = 0.5 if 10 <= word_count <= 100 else 0.0

    # endswith() on an empty string is False, so blank responses earn nothing here.
    trimmed = response.strip()
    punctuation_bonus = 0.3 if trimmed.endswith(tuple('.!?')) else 0.0

    return length_bonus + punctuation_bonus

# Smoke-test the reward function on a fixed sentence before using it in training
test_response = "This is a test response."
test_reward = compute_reward(prompt="Test prompt", response=test_response)
print(f"Test reward: {test_reward}")

## 6. Helper Functions

In [None]:
def generate_response(model, tokenizer, prompt, max_new_tokens=50):
    """Sample one continuation of ``prompt`` and return only the new text.

    Args:
        model: Causal language model exposing ``generate`` and ``device``.
        tokenizer: Tokenizer paired with ``model``.
        prompt: Prompt text to continue.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded continuation with surrounding whitespace stripped. The
        prompt itself is never part of the returned string.
    """
    # Place the inputs on the device of the model that was passed in, so the
    # helper works for any model rather than only one living on the global DEVICE.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_length = inputs["input_ids"].shape[1]

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=TEMPERATURE,
            top_p=0.9,
            pad_token_id=tokenizer.pad_token_id
        )

    # The generated sequence starts with the prompt tokens. Cut at the token
    # level: slicing the decoded string by len(prompt) is wrong whenever
    # decoding does not reproduce the prompt text character for character
    # (whitespace normalization, special tokens, unknown characters).
    new_tokens = outputs[0][prompt_length:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()


def compute_log_probs(model, input_ids, attention_mask):
    """Return the summed log-probability of each sequence's own tokens.

    The logits at position ``t`` score the token at position ``t + 1``, so the
    first token of every row is never scored. Target positions whose
    attention-mask entry is 0 are multiplied by zero before summing.

    Returns:
        A tensor with one summed log-probability per row of ``input_ids``.
    """
    logits = model(input_ids=input_ids, attention_mask=attention_mask).logits

    # Drop the last position: it predicts a token past the end of the sequence.
    vocab_log_probs = F.log_softmax(logits, dim=-1)[:, :-1, :]
    next_tokens = input_ids[:, 1:]
    picked = vocab_log_probs.gather(2, next_tokens.unsqueeze(-1)).squeeze(-1)

    # Shift the mask the same way as the targets so both line up.
    keep = attention_mask[:, 1:].bool()
    return (picked * keep).sum(dim=1)

## 7. Training Loop

In [None]:
# Track metrics (one entry per optimizer step)
metrics_history = {
    'loss': [],
    'reward': [],
    'kl': []
}

for epoch in range(NUM_EPOCHS):
    print(f"\nEpoch {epoch + 1}/{NUM_EPOCHS}")
    model.train()

    # Batch processing
    for i in tqdm(range(0, len(prompts), BATCH_SIZE)):
        batch_prompts = prompts[i:i+BATCH_SIZE]

        # Generate responses (sampled one prompt at a time, without gradients)
        responses = [generate_response(model, tokenizer, p, MAX_NEW_TOKENS)
                    for p in batch_prompts]

        # Compute rewards
        rewards = torch.tensor([compute_reward(p, r) for p, r in zip(batch_prompts, responses)]).to(DEVICE)

        # Tokenize full sequences
        full_texts = [p + r for p, r in zip(batch_prompts, responses)]
        encodings = tokenizer(
            full_texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=MAX_LENGTH
        ).to(DEVICE)

        # Compute log probs (summed over every non-padded token, prompt included)
        current_log_probs = compute_log_probs(model, encodings['input_ids'], encodings['attention_mask'])

        with torch.no_grad():
            ref_log_probs = compute_log_probs(ref_model, encodings['input_ids'], encodings['attention_mask'])

        # Per-sequence log-ratio between policy and reference, used as a
        # sample-based KL estimate
        kl_div = current_log_probs - ref_log_probs

        # Compute advantages (simple GRPO-style).
        # The advantage is a fixed weight on the log-prob gradient, so the KL
        # term is detached: otherwise backprop would also flow through the
        # advantage (and its mean/std normalization) and distort the update.
        advantages = rewards - KL_COEF * kl_div.detach()
        if advantages.numel() > 1:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
        # else: a single-sample batch (e.g. the last, partial batch) keeps its
        # raw advantage. std() of one element is NaN, which would make the loss
        # NaN and poison every weight on the next optimizer step.

        # Policy loss
        loss = -(current_log_probs * advantages).mean()

        # Backward
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad()

        # Track metrics
        metrics_history['loss'].append(loss.item())
        metrics_history['reward'].append(rewards.mean().item())
        metrics_history['kl'].append(kl_div.mean().item())

print("\nTraining complete!")

## 8. Visualize Training Metrics

In [None]:
# One panel per tracked metric: (history key, panel title, y-axis label)
panels = [
    ('loss', 'Loss', 'Loss'),
    ('reward', 'Mean Reward', 'Reward'),
    ('kl', 'KL Divergence', 'KL'),
]

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Every panel shares the same x-axis meaning: one point per training step.
for ax, (key, title, ylabel) in zip(axes, panels):
    ax.plot(metrics_history[key])
    ax.set_title(title)
    ax.set_xlabel('Step')
    ax.set_ylabel(ylabel)

plt.tight_layout()
plt.show()

## 9. Test the Trained Model

In [None]:
# Prompts for a qualitative spot check of the policy after training
test_prompts = [
    "What is machine learning?",
    "Explain the Transformer architecture.",
    "How does reinforcement learning work?"
]

model.eval()

for prompt in test_prompts:
    response = generate_response(model, tokenizer, prompt, max_new_tokens=100)
    reward = compute_reward(prompt, response)

    # One print call, newline-separated: a divider followed by the three fields.
    print(
        f"\n{'='*60}",
        f"Prompt: {prompt}",
        f"Response: {response}",
        f"Reward: {reward:.3f}",
        sep="\n",
    )

## 10. Save Model (Optional)

In [None]:
# Uncomment to save
# model.save_pretrained("./trained_model")
# tokenizer.save_pretrained("./trained_model")
# print("Model saved to ./trained_model")

## Notes and Next Steps

**What to modify for your use case:**

1. **Reward Function** (Section 5): This is the most important part! Implement your actual reward.
2. **Data Loading** (Section 4): Load your specific dataset.
3. **Model** (Section 3): Change to your model.
4. **Hyperparameters** (Section 2): Tune for your task.

**For production training:**
- Use the full training scripts instead of this notebook
- Add proper logging (wandb, tensorboard)
- Implement checkpointing
- Scale to multiple GPUs
- Add evaluation metrics