# Anagram Solver RL Training

Pure RL training of a Qwen model on anagram solving, with detailed metrics tracking.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/PollyLeo6/Anagram-Solver/blob/main/train_agent.ipynb)

## Setup

In [None]:
import os
import random
import numpy as np
import torch
import json
import pandas as pd
from typing import List

# Fix random seeds for Python's `random`, NumPy and PyTorch so sampling is
# repeatable from run to run.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Notebook bootstrap: when the game module is not in the working directory,
# clone the project and change into it. The `!` and `%` lines are IPython
# magics, so this cell only runs inside Jupyter/Colab.
if not os.path.exists('anagram_game.py'):
    print('📥 Cloning repository...')
    !git clone https://github.com/PollyLeo6/Anagram-Solver.git
    %cd Anagram-Solver
    print('✅ Repository cloned!')

# Install the training stack.
# NOTE(review): `unsloth` is requested by both commands (the plain package
# name, then the git build) — confirm that both installs are needed.
!pip install torch transformers datasets accelerate peft trl unsloth matplotlib pandas
!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"

In [None]:
from unsloth import FastLanguageModel
from trl import GRPOConfig, GRPOTrainer
from datasets import Dataset
import matplotlib.pyplot as plt

# Check if files exist, if not generate them
# NOTE(review): the branch runs `utils.py` only when `utils.py` itself is
# missing from the working directory, so the command cannot succeed as
# written. The guard presumably should test for a generated artifact
# instead — confirm the intended path before relying on this cell.
if not os.path.exists('utils.py'):
    print('⚠️ Files not found, running data generation...')
    !python utils.py

from anagram_game import AnagramSolverEnv
from utils import create_english_dictionary

## Data Preparation

In [None]:
# System prompt for anagram solving. It is sent as the `system` message both
# in the training prompts and in the inference tests, and it fixes the answer
# format the reward function looks for: a JSON object with a "solutions" list.
SYSTEM_PROMPT = """You are an expert anagram solver. Your task is to unscramble letters to form valid English words.

Rules:
1. Use each letter exactly once
2. Form valid English words only
3. Respond in JSON format: {"solutions": ["word1", "word2", ...]}
4. Order words as they appear in the anagram list

Be accurate and follow the format exactly."""

def extract_json_answer(text: str) -> str:
    """Return the JSON object embedded in a model response, as a string.

    The widest span (first ``{`` to last ``}``) is tried first. If that span
    is not valid JSON -- for example because the response has stray braces
    before or after the real answer -- each ``{`` is tried in turn and the
    first object that parses is returned.

    Args:
        text: Raw model response.

    Returns:
        The substring holding a parseable JSON object, or ``text.strip()``
        when no such object is found.
    """
    start = text.find('{')
    end = text.rfind('}') + 1
    if start == -1 or end <= start:
        return text.strip()

    candidate = text[start:end]
    try:
        json.loads(candidate)
    except (ValueError, RecursionError):
        # json.JSONDecodeError is a ValueError; fall through to the scan.
        pass
    else:
        return candidate

    # Fallback: decode a single object starting at each opening brace.
    decoder = json.JSONDecoder()
    pos = start
    while pos != -1:
        try:
            _, stop = decoder.raw_decode(text, pos)
        except (ValueError, RecursionError):
            pos = text.find('{', pos + 1)
        else:
            return text[pos:stop]

    return text.strip()

def get_anagram_dataset(env, num_samples=100, difficulties=(5, 6, 7, 8)):
    """Generate the anagram dataset used for RL training.

    For every difficulty, ``env.generate`` is asked for ``num_samples`` tasks
    and each task becomes one chat-style example.

    Args:
        env: Environment exposing ``generate(num_of_questions=..., difficulty=...)``;
            each returned task must provide ``question``, ``answer`` and
            ``metadata`` attributes.
        num_samples: Number of tasks requested per difficulty level.
        difficulties: Iterable of difficulty levels to generate. The default
            is an immutable tuple so it cannot be shared and mutated across
            calls.

    Returns:
        A ``Dataset`` whose rows hold ``prompt`` (system + user messages),
        ``answer``, ``difficulty`` and ``metadata``.
    """
    data = [
        {
            'prompt': [
                {'role': 'system', 'content': SYSTEM_PROMPT},
                {'role': 'user', 'content': task.question},
            ],
            'answer': task.answer,
            'difficulty': difficulty,
            'metadata': task.metadata,
        }
        for difficulty in difficulties
        for task in env.generate(num_of_questions=num_samples, difficulty=difficulty)
    ]
    return Dataset.from_list(data)

In [None]:
# Load dictionary and create environment
create_english_dictionary()
with open('dictionary.txt', 'r', encoding='utf-8') as f:
    # Stream the file line by line and drop blank lines, so the empty string
    # never ends up in the dictionary as a "word".
    dictionary_words = [word for word in (line.strip() for line in f) if word]

# Create environment with full dictionary (a set, for O(1) membership tests)
env = AnagramSolverEnv()
env.dictionary = set(dictionary_words)

# Generate RL training dataset: 100 tasks for each of the four difficulties
rl_dataset = get_anagram_dataset(env, num_samples=100, difficulties=[5, 6, 7, 8])
print(f"Generated {len(rl_dataset)} RL training examples")
print(f"Dictionary size: {len(dictionary_words)} words")

## Reward Function

In [None]:
def anagram_reward_func(prompts, completions, answer, **kwargs) -> List[float]:
    """Score each completion against its reference answer.

    Reward components (maximum 2.0 per completion):
      * 0.2 when the response is a JSON object whose ``solutions`` is a list;
      * 1.8 when the parsed response equals the parsed reference exactly;
      * otherwise up to 1.0 of partial credit: the fraction of reference
        words that also appear in the response's ``solutions`` list.

    Responses (or references) that cannot be parsed as JSON earn 0.0.

    Args:
        prompts: Prompts for the batch (unused; part of the reward-function
            signature).
        completions: One entry per sample; the text is read from
            ``completion[0]['content']``.
        answer: Reference answers as JSON strings, aligned with ``completions``.
        **kwargs: Extra dataset columns forwarded by the trainer (ignored).

    Returns:
        One float reward per completion, in input order.
    """
    responses = [completion[0]['content'] for completion in completions]
    rewards = []

    for response, correct_answer in zip(responses, answer):
        try:
            response_json = json.loads(extract_json_answer(response))
            correct_json = json.loads(correct_answer)
        except (AttributeError, TypeError, ValueError, RecursionError):
            # Unparseable response or reference: no reward.
            rewards.append(0.0)
            continue

        response_solutions = (
            response_json.get('solutions') if isinstance(response_json, dict) else None
        )
        correct_solutions = (
            correct_json.get('solutions') if isinstance(correct_json, dict) else None
        )
        has_solution_list = isinstance(response_solutions, list)

        reward = 0.0

        # Format reward (0.2 for a JSON object with a "solutions" list)
        if has_solution_list:
            reward += 0.2

        if response_json == correct_json:
            # Correctness reward (1.8 for the exact answer)
            reward += 1.8
        elif has_solution_list and isinstance(correct_solutions, list):
            # Partial credit only when both sides really are lists; a bare
            # string would otherwise be split into single characters.
            try:
                correct_words = set(correct_solutions)
                response_words = set(response_solutions)
            except TypeError:
                # Unhashable entries (e.g. nested lists): treat the response
                # as invalid and give no reward at all.
                reward = 0.0
            else:
                if correct_words:
                    overlap = len(correct_words & response_words)
                    reward += 1.0 * (overlap / len(correct_words))

        rewards.append(reward)

    return rewards

## Model Loading

In [None]:
# Load the Qwen base checkpoint and its tokenizer through Unsloth
base_model_options = dict(
    model_name="unsloth/Qwen2.5-1.5B-Instruct",
    max_seq_length=2048,
    dtype=None,
    load_in_4bit=True,
)
model, tokenizer = FastLanguageModel.from_pretrained(**base_model_options)

# Wrap the base model with LoRA adapters on the listed projection modules
lora_target_modules = [
    "q_proj",
    "k_proj",
    "v_proj",
    "o_proj",
    "gate_proj",
    "up_proj",
    "down_proj",
]
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    lora_alpha=16,
    lora_dropout=0.1,
    bias="none",
    target_modules=lora_target_modules,
    use_gradient_checkpointing="unsloth",
)

print("✅ Qwen model loaded with LoRA adapters")
# Only parameters that still require gradients are counted
trainable_param_count = sum(
    param.numel() for param in model.parameters() if param.requires_grad
)
print(f"Model parameters: {trainable_param_count:,}")

## GRPO Training Configuration

In [None]:
# GRPO configuration, grouped by concern.
# NOTE(review): some TRL releases require the effective batch size to be
# divisible by `num_generations` — confirm batch size 2 / accumulation 2 /
# 4 generations against the installed TRL version.
grpo_config = GRPOConfig(
    # Generation
    use_vllm=False,  # Disable for better logging
    num_generations=4,
    max_prompt_length=512,
    max_completion_length=128,
    # Optimiser and schedule
    optim="adamw_8bit",
    learning_rate=5e-6,
    adam_beta1=0.9,
    adam_beta2=0.99,
    weight_decay=0.01,
    max_grad_norm=1.0,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    # Batching and run length
    per_device_train_batch_size=2,
    gradient_accumulation_steps=2,
    max_steps=500,  # More steps for detailed tracking
    dataloader_num_workers=0,
    # Logging, evaluation cadence and checkpoints
    logging_steps=1,  # Log every step
    eval_steps=50,
    save_steps=100,
    report_to="none",
    output_dir="./rl_outputs",
    logging_dir="./rl_logs",
)

# Echo the headline settings back for a quick sanity check
print(
    "\n".join(
        [
            "GRPO Configuration:",
            f"- Learning rate: {grpo_config.learning_rate}",
            f"- Batch size: {grpo_config.per_device_train_batch_size}",
            f"- Max steps: {grpo_config.max_steps}",
            f"- Num generations: {grpo_config.num_generations}",
        ]
    )
)

## RL Training with Metrics Tracking

In [None]:
# Custom callback for detailed metrics tracking
class MetricsCallback:
    """Collect the metric dictionaries logged during training.

    The class follows the Hugging Face Trainer callback calling convention
    without importing it: every event is invoked as
    ``event(args, state, control, **kwargs)``. ``on_log`` records the logged
    metrics; every other ``on_*`` event resolves to a no-op so the trainer
    can dispatch its full event set to this object.
    """

    def __init__(self):
        self.metrics_history = []

    def __getattr__(self, name):
        # Only reached for attributes not found normally. Any trainer event
        # this class does not implement is answered with a no-op.
        if name.startswith('on_'):
            return self._ignore_event
        raise AttributeError(
            f"{type(self).__name__!r} object has no attribute {name!r}"
        )

    @staticmethod
    def _ignore_event(*args, **kwargs):
        """Accept any trainer event and leave the control flow untouched."""
        return None

    def on_log(self, args=None, state=None, control=None, logs=None, **kwargs):
        """Store one log record.

        Accepts both the trainer call ``on_log(args, state, control, logs=...)``
        and the legacy direct call ``on_log(logs_dict)``.
        """
        if logs is None and isinstance(args, dict):
            # Legacy call style: the single positional argument is the logs.
            logs = args
        if not logs:
            return

        record = dict(logs)
        # The step is taken from the trainer state when the record lacks it.
        step = getattr(state, 'global_step', None)
        if step is not None:
            record.setdefault('step', step)
        # Mirror a logged 'loss' into the 'train_loss' column used downstream.
        if 'loss' in record:
            record.setdefault('train_loss', record['loss'])
        self.metrics_history.append(record)

    def get_dataframe(self):
        """Return the collected metrics with a fixed set of columns.

        Returns an empty DataFrame when nothing was logged. Columns that were
        never logged are filled with 0.0.
        """
        if not self.metrics_history:
            return pd.DataFrame()

        df = pd.DataFrame(self.metrics_history)
        # Ensure required columns exist
        required_cols = ['step', 'train_loss', 'reward', 'reward_std', 'completion_length', 'kl']
        for col in required_cols:
            if col not in df.columns:
                df[col] = 0.0
        return df[required_cols]

# Initialize callback
metrics_callback = MetricsCallback()

# Create GRPO trainer. TRL's GRPOTrainer takes the reward callable(s) as
# `reward_funcs` and the tokenizer as `processing_class`; it has no
# `reward_function` or `tokenizer` parameters.
grpo_trainer = GRPOTrainer(
    model=model,
    reward_funcs=[anagram_reward_func],
    args=grpo_config,
    train_dataset=rl_dataset,
    processing_class=tokenizer,
)

# Add callback so every logged metrics record is captured for later analysis
grpo_trainer.add_callback(metrics_callback)

print("🚀 Starting RL training...")
print(f"Training on {len(rl_dataset)} examples")

# Start training
grpo_trainer.train()

print("✅ RL training complete!")

## Training Metrics Analysis

In [None]:
# Pull the collected metrics into a DataFrame
metrics_df = metrics_callback.get_dataframe()

if metrics_df.empty:
    print("⚠️ No metrics collected during training")
else:
    print("📊 Training Metrics Summary:")
    print(metrics_df.describe())

    # Show the first and the last ten logged rows
    for heading, take_rows in (
        ("\n📈 First 10 training steps:", metrics_df.head),
        ("\n📈 Last 10 training steps:", metrics_df.tail),
    ):
        print(heading)
        print(take_rows(10).to_string(index=False))

    # Persist the table for offline inspection
    metrics_df.to_csv('training_metrics.csv', index=False)
    print("\n💾 Metrics saved to training_metrics.csv")

## Training Progress Visualization

In [None]:
if metrics_df.empty:
    print("⚠️ No metrics to visualize")
else:
    # 2 x 3 grid: five single-metric panels plus one combined panel
    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    fig.suptitle('RL Training Progress', fontsize=16)

    # (axes, metric column, panel title, y-axis label)
    single_metric_panels = [
        (axes[0, 0], 'train_loss', 'Training Loss', 'Loss'),
        (axes[0, 1], 'reward', 'Reward', 'Reward'),
        (axes[0, 2], 'reward_std', 'Reward Standard Deviation', 'Reward Std'),
        (axes[1, 0], 'completion_length', 'Completion Length', 'Length'),
        (axes[1, 1], 'kl', 'KL Divergence', 'KL'),
    ]
    for panel, column, title, y_label in single_metric_panels:
        panel.plot(metrics_df['step'], metrics_df[column])
        panel.set_title(title)
        panel.set_xlabel('Step')
        panel.set_ylabel(y_label)
        panel.grid(True, alpha=0.3)

    # Combined panel: reward on the left y-axis (green) ...
    ax2 = axes[1, 2]
    ax2.plot(metrics_df['step'], metrics_df['reward'], 'g-', label='Reward')
    ax2.set_xlabel('Step')
    ax2.set_ylabel('Reward', color='g')
    ax2.tick_params(axis='y', labelcolor='g')

    # ... and loss on a twin right y-axis (red)
    ax3 = ax2.twinx()
    ax3.plot(metrics_df['step'], metrics_df['train_loss'], 'r-', label='Loss')
    ax3.set_ylabel('Loss', color='r')
    ax3.tick_params(axis='y', labelcolor='r')

    ax2.set_title('Reward vs Loss')
    ax2.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('training_progress.png', dpi=300, bbox_inches='tight')
    plt.show()

    print("📊 Training progress plots saved as 'training_progress.png'")

## Model Inference Testing

In [None]:
def test_model_inference(model, tokenizer, env, num_tests=10):
    """Test the trained model on newly generated anagram examples.

    Generates tasks at difficulty 6, samples one response per task, checks it
    with ``env.verifier.verify`` and writes all records to
    ``inference_results.csv``.

    Args:
        model: Model exposing ``generate`` and ``device``.
        tokenizer: Tokenizer providing ``apply_chat_template`` and ``decode``.
        env: Environment providing ``generate`` and ``verifier.verify``.
        num_tests: Number of tasks requested from the environment.

    Returns:
        ``(accuracy, results)`` where ``accuracy`` is the fraction of
        generated tasks answered correctly (0.0 when no task was generated)
        and ``results`` is a list with one dict per task.
    """
    print("🧪 Testing model inference...")

    # Generate test examples. Materialise them so the accuracy denominator is
    # the number of tasks actually produced, not the number requested.
    test_tasks = list(env.generate(num_of_questions=num_tests, difficulty=6))
    total = len(test_tasks)

    results = []
    correct_count = 0

    for i, task in enumerate(test_tasks, 1):
        print(f"\n--- Test {i}/{total} ---")
        print(f"Question: {task.question}")
        print(f"Correct answer: {task.answer}")

        try:
            # Prepare input
            messages = [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": task.question}
            ]

            input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            inputs = tokenizer(input_text, return_tensors="pt").to(model.device)

            # Generate response
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=100,
                    temperature=0.7,
                    do_sample=True,
                    pad_token_id=tokenizer.eos_token_id
                )

            # Decode only the newly generated tokens, not the echoed prompt
            response = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
            print(f"Model response: {response}")

            # Verify correctness
            is_correct = env.verifier.verify(task, response)
            print(f"Correct: {'✅' if is_correct else '❌'}")
        except Exception as e:
            # Best effort: a failing example is recorded as incorrect and the
            # remaining examples still run.
            print(f"Error: {e}")
            response = f"Error: {e}"
            is_correct = False

        if is_correct:
            correct_count += 1

        results.append({
            'question': task.question,
            'correct_answer': task.answer,
            'model_response': response,
            'is_correct': is_correct,
            'difficulty': task.difficulty
        })

    # Guard against an empty task list instead of dividing by zero
    accuracy = correct_count / total if total else 0.0
    print(f"\n📊 Inference Results:")
    print(f"Accuracy: {accuracy:.2%} ({correct_count}/{total})")

    # Save results
    results_df = pd.DataFrame(results)
    results_df.to_csv('inference_results.csv', index=False)
    print(f"💾 Results saved to 'inference_results.csv'")

    return accuracy, results

# Run inference tests on 20 generated tasks; `accuracy` is reused by the
# final summary cell and `test_results` holds the per-task records.
accuracy, test_results = test_model_inference(model, tokenizer, env, num_tests=20)

## Final Summary

In [None]:
# Header
print("🎯 RL Training Summary")
print("=" * 50)

# Training figures are only reported when metrics were collected
if not metrics_df.empty:
    reward_series = metrics_df['reward']
    loss_series = metrics_df['train_loss']
    print("📈 Training Progress:")
    print(f"  - Total steps: {len(metrics_df)}")
    print(f"  - Final reward: {reward_series.iloc[-1]:.4f}")
    print(f"  - Final loss: {loss_series.iloc[-1]:.4f}")
    print(f"  - Average reward: {reward_series.mean():.4f}")
    print(f"  - Max reward: {reward_series.max():.4f}")

# Inference results, produced artifacts and the closing line
closing_lines = [
    "\n🧪 Inference Performance:",
    f"  - Test accuracy: {accuracy:.2%}",
    "  - Model: Qwen2.5-1.5B-Instruct + LoRA",
    "  - Training method: Pure RL (GRPO)",
    "\n📁 Generated Files:",
    "  - training_metrics.csv: Detailed training metrics",
    "  - training_progress.png: Training visualization",
    "  - inference_results.csv: Test results",
    "  - Model saved in: ./rl_outputs/",
    "\n✅ Training and evaluation complete!",
]
print("\n".join(closing_lines))