# Anagram Solver RL Training

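Train an LLM agent to solve anagram puzzles with supervised fine-tuning (SFT) followed by a simple policy-gradient RL loop (a lightweight stand-in for GRPO), using a train/test word split.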

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/PollyLeo6/Anagram-Solver/blob/main/train_agent.ipynb)

## 🚀 Quick Setup

In [None]:
import os
import random
import numpy as np
import torch

# Fix random seeds for reproducibility
# NOTE(review): this seeds the Python, NumPy and torch RNGs only; kernel-level
# GPU nondeterminism (e.g. cuDNN algorithm selection) is not disabled here.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# If anagram_game.py is not in the current working directory, assume we are not
# inside the repository checkout yet: clone it and switch into it so that the
# project modules imported later (anagram_game, utils) can be found. The `!` and
# `%cd` lines are IPython magics, so this cell only runs under Jupyter/Colab.
if not os.path.exists('anagram_game.py'):
    print('📥 Cloning repository...')
    !git clone https://github.com/PollyLeo6/Anagram-Solver.git
    %cd Anagram-Solver
    print('✅ Repository cloned!')

print('🎯 Ready to train!')

## Setup

In [None]:
# Install dependencies (IPython shell magics).
# NOTE(review): the second command reinstalls unsloth from the GitHub main branch
# (with the "colab-new" extra) on top of the PyPI release installed by the first.
# No versions are pinned, so library APIs used below (e.g. SFTTrainer keyword
# arguments) may differ between runs.
!pip install torch transformers datasets accelerate peft trl unsloth matplotlib
!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"

In [None]:
from unsloth import FastLanguageModel
from transformers import TrainingArguments
from trl import SFTTrainer, PPOTrainer, PPOConfig
from datasets import Dataset
import json
import matplotlib.pyplot as plt

from anagram_game import AnagramSolverEnv
from utils import create_english_dictionary

def generate_system_prompt():
    return """You are an expert anagram solver. Your task is to unscramble letters to form valid English words.

Rules:
1. Use each letter exactly once
2. Form valid English words only
3. Respond in JSON format: {"solutions": ["word1", "word2", ...]}
4. Order words as they appear in the anagram list

Be accurate and follow the format exactly."""

## Create Train/Test Split

In [None]:
# Candidate vocabulary for the puzzles, grouped roughly by word length.
full_dictionary = [
    # 3-letter words
    "cat", "dog", "car", "sun", "run", "big", "red", "hot", "old", "new",
    "boy", "man", "day", "way", "may", "say", "try", "fly", "sky", "eye",
    "bat", "hat", "mat", "sat", "fat", "pat", "rat", "eat", "tea", "sea",

    # 4- and 5-letter words
    "book", "tree", "door", "fire", "wind", "star", "moon", "fish", "bird", "bear",
    "house", "water", "light", "dark", "fast", "slow", "good", "nice", "blue", "green",
    "black", "white", "small", "large", "phone", "table", "chair", "paper", "money", "happy",
    "love", "time", "work", "play", "game", "food", "hand", "head", "face", "back",

    # 5-letter words
    "world", "right", "great", "small", "every", "start", "place", "where", "after",
    "think", "never", "again", "might", "still", "while", "sound", "below", "voice", "young",
    "house", "point", "group", "music", "party", "story", "movie", "beach", "ocean", "river",

    # 6+ letter words for challenge
    "elephant", "butterfly", "computer", "keyboard", "monitor", "speaker", "headset",
    "programming", "algorithm", "function", "variable", "constant", "database",
    "university", "education", "knowledge", "learning", "teaching", "student",
    "chocolate", "strawberry", "blueberry", "raspberry", "pineapple", "watermelon"
]

# The literal above repeats some words (e.g. "house", "small"); drop duplicates
# while preserving order so a word cannot appear in both splits.
full_dictionary = list(dict.fromkeys(full_dictionary))

# Split 70% train / 30% test by anagram class (words made of the same letters,
# e.g. "eat"/"tea"), so the same scrambled letters can never be both a train
# puzzle and a test puzzle. A dedicated RNG seeded with SEED keeps the split
# reproducible independently of any other use of the global `random` state.
anagram_classes = {}
for word in full_dictionary:
    anagram_classes.setdefault("".join(sorted(word)), []).append(word)
word_groups = list(anagram_classes.values())
random.Random(SEED).shuffle(word_groups)

split_idx = int(0.7 * len(word_groups))
train_words = [w for group in word_groups[:split_idx] for w in group]
test_words = [w for group in word_groups[split_idx:] for w in group]

print(f"Total words: {len(full_dictionary)}")
print(f"Train words: {len(train_words)}")
print(f"Test words: {len(test_words)}")
print(f"Train sample: {train_words[:5]}")
print(f"Test sample: {test_words[:5]}")

## Create Custom Environment Classes

In [None]:
class TrainAnagramEnv(AnagramSolverEnv):
    """AnagramSolverEnv whose ``dictionary`` attribute is the training split.

    NOTE(review): the attribute is replaced only after the parent's
    ``__init__`` has run. This assumes ``AnagramSolverEnv`` reads
    ``self.dictionary`` when generating puzzles, rather than data derived from
    it during ``__init__``. Confirm this, otherwise the split has no effect.
    """

    def __init__(self):
        super().__init__()
        self.dictionary = {*train_words}


class TestAnagramEnv(AnagramSolverEnv):
    """AnagramSolverEnv whose ``dictionary`` attribute is the held-out test split.

    Same caveat as ``TrainAnagramEnv``: the override happens after the
    parent's ``__init__``.
    """

    def __init__(self):
        super().__init__()
        self.dictionary = {*test_words}


# One environment per split (train is built first, then test).
train_env, test_env = TrainAnagramEnv(), TestAnagramEnv()

print("✅ Train and test environments created")

## Load Model

In [None]:
# LoRA adapter configuration: rank-16 adapters on every attention and MLP
# projection, with Unsloth's gradient-checkpointing mode.
_lora_settings = dict(
    r=16,
    lora_alpha=16,
    lora_dropout=0.1,
    bias="none",
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    use_gradient_checkpointing="unsloth",
)

# Base instruct model, loaded in 4-bit; dtype=None leaves the choice to Unsloth.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/Qwen2.5-1.5B-Instruct",
    max_seq_length=2048,
    dtype=None,
    load_in_4bit=True,
)
model = FastLanguageModel.get_peft_model(model, **_lora_settings)

print("Model loaded with LoRA adapters")

## Generate Training Data

In [None]:
# Generate training data from train_words only
train_tasks = []
for difficulty in range(1, 8):  # Only up to level 7 for SFT
    tasks = train_env.generate(num_of_questions=50, difficulty=difficulty)
    train_tasks.extend(tasks)

print(f"Generated {len(train_tasks)} training examples")

# Convert to HuggingFace format
def format_conversation(example):
    """Render one question/answer record as a complete chat transcript.

    The system prompt, the user question and the assistant answer are passed
    through ``tokenizer.apply_chat_template`` as text (no tokenization, no
    trailing generation prompt). The result is returned as ``{"text": ...}``
    for use as an SFT training string.
    """
    turns = (
        ("system", generate_system_prompt()),
        ("user", example['question']),
        ("assistant", example['answer']),
    )
    chat = [{"role": role, "content": content} for role, content in turns]
    rendered = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=False)
    return {"text": rendered}

# One plain record per task; Dataset.from_list turns the keys into columns.
train_data = [{'question': t.question, 'answer': t.answer} for t in train_tasks]

hf_dataset = Dataset.from_list(train_data)
# Adds a "text" column holding the chat-formatted transcript of each record.
formatted_dataset = hf_dataset.map(format_conversation)
print(f"Formatted dataset size: {len(formatted_dataset)}")

## Evaluation Functions

In [None]:
def evaluate_model(model, tokenizer, env, difficulty=5, num_samples=20):
    """Return the fraction of freshly generated puzzles the model solves.

    Generates ``num_samples`` puzzles from ``env`` at ``difficulty``. For each
    one it samples a single response (temperature 0.7, up to 100 new tokens)
    and checks it with ``env.verifier.verify``.

    Args:
        model: Causal LM with ``generate``; it is put in eval mode while
            sampling, and its previous train/eval mode is restored afterwards.
        tokenizer: Tokenizer matching ``model`` (must support chat templates).
        env: Environment providing ``generate(...)`` and ``verifier.verify``.
        difficulty: Difficulty level passed to ``env.generate``.
        num_samples: Number of puzzles requested from ``env.generate``.

    Returns:
        Accuracy in [0, 1] over the puzzles actually generated, or 0.0 when
        no puzzles were generated.

    Errors raised while generating or verifying a single puzzle are counted
    as incorrect answers. They are reported in one summary line rather than
    being silently dropped.
    """
    system_prompt = generate_system_prompt()
    tasks = list(env.generate(num_of_questions=num_samples, difficulty=difficulty))
    if not tasks:
        return 0.0

    correct_count = 0
    error_count = 0
    first_error = None

    was_training = model.training
    model.eval()  # disable dropout (e.g. LoRA dropout) while sampling
    try:
        for task in tasks:
            try:
                messages = [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": task.question}
                ]
                input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
                inputs = tokenizer(input_text, return_tensors="pt").to(model.device)

                with torch.no_grad():
                    outputs = model.generate(**inputs, max_new_tokens=100, temperature=0.7, do_sample=True, pad_token_id=tokenizer.eos_token_id)

                # Decode only the newly generated tokens, not the prompt.
                response = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
                if env.verifier.verify(task, response):
                    correct_count += 1
            except Exception as exc:  # counted as incorrect, but reported below
                error_count += 1
                if first_error is None:
                    first_error = exc
    finally:
        model.train(was_training)

    if error_count:
        print(f"⚠️ evaluate_model: {error_count}/{len(tasks)} samples raised errors "
              f"(counted as incorrect); first: {first_error!r}")
    return correct_count / len(tasks)

# Baseline: accuracy on difficulty-5 puzzles before any fine-tuning,
# evaluated on the train split first, then the test split.
print("Evaluating baseline...")
baseline_train, baseline_test = [
    evaluate_model(model, tokenizer, env, difficulty=5)
    for env in (train_env, test_env)
]
print(f"Baseline - Train: {baseline_train:.3f}, Test: {baseline_test:.3f}")

## SFT Training

In [None]:
# Supervised fine-tuning (SFT) on the chat-formatted train-split puzzles.
training_args = TrainingArguments(
    output_dir="./anagram_sft",
    num_train_epochs=2,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,  # effective batch of 2 * 4 = 8 per device
    learning_rate=5e-5,
    warmup_steps=50,
    logging_steps=10,
    save_steps=200,
    save_total_limit=2,
    report_to="none",
    remove_unused_columns=False,
)

# Train only on the rendered "text" column. Because remove_unused_columns=False,
# the raw "question"/"answer" string columns would otherwise be kept next to the
# tokenized inputs and (depending on the TRL version) reach the data collator,
# which cannot batch plain strings.
sft_dataset = formatted_dataset.remove_columns(["question", "answer"])

# NOTE(review): `tokenizer`, `dataset_text_field` and `max_seq_length` are
# SFTTrainer keyword arguments in older TRL releases. Newer releases move them
# to `SFTConfig` / `processing_class`, and `trl` is installed unpinned, so
# confirm against the installed version.
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=sft_dataset,
    tokenizer=tokenizer,
    dataset_text_field="text",
    max_seq_length=2048,
)

print("Starting SFT training...")
trainer.train()
print("SFT training complete!")

## Post-SFT Evaluation

In [None]:
# After SFT: difficulty 5 on both splits, plus difficulty 8 on the test split.
print("Evaluating after SFT...")
sft_train, sft_test, sft_test_hard = (
    evaluate_model(model, tokenizer, env, difficulty=level)
    for env, level in ((train_env, 5), (test_env, 5), (test_env, 8))
)

print(f"After SFT - Train: {sft_train:.3f}, Test: {sft_test:.3f}, Test Hard: {sft_test_hard:.3f}")

## RL Setup: Reward Function and Training Prompts (simple stand-in for GRPO)

In [None]:
def reward_function(question: str, response: str, task=None) -> float:
    """Score a model response for RL: 1.0 if it solves the puzzle, else 0.0.

    The response is verified against the task that actually produced
    *question*, using ``test_env.verifier.verify``.

    Args:
        question: The task question, or any prompt containing it verbatim
            (e.g. the chat-templated query built from ``task.question``).
        response: Decoded model output.
        task: The task being answered. When omitted, it is looked up in the
            global ``rl_tasks`` list. The lookup picks the task whose
            ``question`` occurs in *question*; if several match, the longest
            one wins, so a question that is a prefix of another is not
            mistaken for it.

    Returns:
        1.0 when the verifier accepts the response. Otherwise 0.0, including
        when no matching task is found or when verification raises (e.g. on
        malformed model output).
    """
    if task is None:
        candidates = [t for t in rl_tasks if t.question and t.question in question]
        if not candidates:
            return 0.0
        task = max(candidates, key=lambda t: len(t.question))
    try:
        return 1.0 if test_env.verifier.verify(task, response) else 0.0
    except Exception:
        # Malformed output must not crash training; score it as a failure.
        return 0.0

# RL puzzles use test-split words only (never seen during SFT), 30 puzzles at
# each of the hardest difficulty levels 8, 9 and 10.
# NOTE(review): because RL trains on test-split words, later "Test" scores are
# no longer measured on words the model was never trained on.
rl_tasks = [
    task
    for level in (8, 9, 10)
    for task in test_env.generate(num_of_questions=30, difficulty=level)
]

print(f"Generated {len(rl_tasks)} RL training examples from test words")

# Chat-formatted prompts (system + user turn, ending with the assistant
# generation prompt) that the RL loop samples responses from.
_rl_system_prompt = generate_system_prompt()
rl_queries = [
    tokenizer.apply_chat_template(
        [
            {"role": "system", "content": _rl_system_prompt},
            {"role": "user", "content": task.question}
        ],
        tokenize=False,
        add_generation_prompt=True,
    )
    for task in rl_tasks
]

print(f"Prepared {len(rl_queries)} RL queries")

## Simple RL Training Loop

In [None]:
# Simple on-policy RL loop (GRPO is not used here). For each prompt, sample one
# response and score it with `reward_function`. When the sample succeeds, take a
# REINFORCE-style step that raises the log-likelihood of the sampled response
# tokens given the prompt. Only positive rewards trigger updates.
from torch.nn import functional as F

RL_NUM_QUERIES = 50  # limit for demo
RL_EPOCHS = 2

# Only the trainable (LoRA) parameters need optimizer state.
optimizer = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=1e-5)
rl_batch = rl_queries[:RL_NUM_QUERIES]

print("Starting RL training...")
for epoch in range(RL_EPOCHS):
    total_reward = 0.0
    for query in rl_batch:
        inputs = tokenizer(query, return_tensors="pt").to(model.device)
        prompt_len = inputs['input_ids'].shape[1]

        # Sample without dropout; same token budget as evaluate_model so
        # multi-word JSON answers are not truncated.
        model.eval()
        with torch.no_grad():
            sequences = model.generate(
                **inputs,
                max_new_tokens=100,
                temperature=0.8,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
            )

        response_ids = sequences[:, prompt_len:]
        response = tokenizer.decode(response_ids[0], skip_special_tokens=True)
        reward = reward_function(query, response)
        total_reward += reward

        # Only update on good responses (and only if something was generated).
        if reward <= 0.5 or response_ids.shape[1] == 0:
            continue

        model.train()
        logits = model(input_ids=sequences, attention_mask=torch.ones_like(sequences)).logits
        # Logits at position t predict token t+1, so predictions for the
        # response tokens start at the last prompt position.
        response_logits = logits[:, prompt_len - 1:-1, :]
        log_probs = F.log_softmax(response_logits.float(), dim=-1)
        token_log_probs = log_probs.gather(-1, response_ids.unsqueeze(-1)).squeeze(-1)
        loss = -reward * token_log_probs.mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    avg_reward = total_reward / max(len(rl_batch), 1)
    print(f"Epoch {epoch+1}, Average Reward: {avg_reward:.3f}")

print("RL training complete!")

## Final Evaluation

In [None]:
print("Final evaluation...")
final_train = evaluate_model(model, tokenizer, train_env, difficulty=5)
final_test = evaluate_model(model, tokenizer, test_env, difficulty=5)
final_test_hard = evaluate_model(model, tokenizer, test_env, difficulty=8)

# NOTE(review): the RL stage trains on puzzles built from test_words, so the
# post-RL "Test" numbers are not measured on words the model never trained on.
print("\nResults Summary:")
print(f"Baseline    - Train: {baseline_train:.3f}, Test: {baseline_test:.3f}, Test Hard: n/a")
print(f"After SFT   - Train: {sft_train:.3f}, Test: {sft_test:.3f}, Test Hard: {sft_test_hard:.3f}")
print(f"After RL    - Train: {final_train:.3f}, Test: {final_test:.3f}, Test Hard: {final_test_hard:.3f}")

# Plot results. Hard-level accuracy was never measured for the baseline, so it
# is NaN (no bar drawn) rather than a misleading 0.0.
methods = ['Baseline', 'SFT', 'SFT+RL']
train_scores = [baseline_train, sft_train, final_train]
test_scores = [baseline_test, sft_test, final_test]
test_hard_scores = [np.nan, sft_test_hard, final_test_hard]

x = np.arange(len(methods))
width = 0.25

plt.figure(figsize=(10, 6))
plt.bar(x - width, train_scores, width, label='Train Words', alpha=0.8)
plt.bar(x, test_scores, width, label='Test Words (Easy)', alpha=0.8)
plt.bar(x + width, test_hard_scores, width, label='Test Words (Hard)', alpha=0.8)

plt.xlabel('Method')
plt.ylabel('Accuracy')
plt.title('Model Performance: Train vs Test Words')
plt.xticks(x, methods)
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

print("\nKey Insights:")
print(f"- Overfitting check: Train vs Test gap = {final_train - final_test:.3f}")
print(f"- RL improvement on hard test-split words (used for RL training): {final_test_hard - sft_test_hard:.3f}")
print(f"- Generalization: Test performance = {final_test:.3f}")

## Test on Specific Examples

In [None]:
# Inspect a few difficulty-8 puzzles built from test-split words.
# NOTE(review): the RL stage trained on test-split puzzles, so these words are
# not unseen by the final model.
test_tasks = test_env.generate(num_of_questions=3, difficulty=8)
system_prompt = generate_system_prompt()
model.eval()  # disable dropout (e.g. LoRA dropout) while sampling

for i, task in enumerate(test_tasks, 1):
    print(f"\n--- Test Example {i} (Test-Split Words) ---")
    print(f"Anagrams: {task.metadata['anagrams']}")
    print(f"Expected: {task.metadata['target_words']}")

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": task.question}
    ]

    input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=100,
            temperature=0.7,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,  # same setting as evaluate_model
        )

    # Decode only the newly generated tokens, not the prompt.
    response = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
    try:
        result = test_env.verifier.verify(task, response)
    except Exception as exc:  # malformed output: report it instead of aborting the demo
        result = f"False (verification error: {exc!r})"

    print(f"Model response: {response}")
    print(f"Correct: {result}")