# Anagram Solver RL Training

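Train an LLM agent to solve anagram puzzles using GRPO. Note: the training cells in this notebook run supervised fine-tuning with TRL's `SFTTrainer` on `train_dataset.jsonl`; they do not contain a GRPO training loop.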

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/PollyLeo6/Anagram-Solver/blob/main/train_agent.ipynb)

## 🚀 Quick Setup (Run this first!)

In [None]:
# Auto-setup for Colab
import os

# If the game module is not in the working directory, clone the repository
# and change into it so that the local imports and data files used by the
# later cells resolve.
if not os.path.exists('anagram_game.py'):
    print('📥 Cloning repository...')
    !git clone https://github.com/PollyLeo6/Anagram-Solver.git
    %cd Anagram-Solver
    print('✅ Repository cloned!')

# Generate datasets
# utils.py is run as a script only when train_dataset.jsonl is missing;
# presumably that script writes the file -- verify against utils.py.
if not os.path.exists('train_dataset.jsonl'):
    print('📊 Generating datasets...')
    !python utils.py
    print('✅ Datasets generated!')

print('🎯 Ready to train!')

## Setup

In [None]:
# Install packages
# NOTE(review): unsloth is installed twice -- from PyPI by the first command
# and again from GitHub (with the "colab-new" extra) by the second. Confirm
# whether the PyPI install in the first command is still needed.
!pip install torch transformers datasets accelerate peft trl unsloth matplotlib
!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"

In [None]:
import torch
from unsloth import FastLanguageModel
from transformers import TrainingArguments
from trl import SFTTrainer
from datasets import Dataset
import json
import matplotlib.pyplot as plt

from anagram_game import AnagramSolverEnv
from utils import create_english_dictionary, correctness_reward_func

def generate_system_prompt():
    """Build the system message that tells the model how to answer.

    Returns:
        The fixed instruction text: the solver role, four numbered rules
        (including the JSON reply format), and a closing reminder, with a
        blank line between the three sections.
    """
    intro = (
        "You are an expert anagram solver. "
        "Your task is to unscramble letters to form valid English words."
    )
    rules = (
        "Rules:",
        "1. Use each letter exactly once",
        "2. Form valid English words only",
        '3. Respond in JSON format: {"solutions": ["word1", "word2", ...]}',
        "4. Order words as they appear in the anagram list",
    )
    closing = "Be accurate and follow the format exactly."
    # Sections are separated by one empty line; rules sit on consecutive lines.
    return "\n\n".join((intro, "\n".join(rules), closing))

## Load Model

In [None]:
# Load the Qwen2.5-1.5B instruct checkpoint through unsloth (4-bit loading
# requested, dtype left to the library's default selection).
base_model_options = dict(
    model_name="unsloth/Qwen2.5-1.5B-Instruct",
    max_seq_length=2048,
    dtype=None,
    load_in_4bit=True,
)
model, tokenizer = FastLanguageModel.from_pretrained(**base_model_options)

# Wrap the base model with LoRA adapters on the listed projection modules.
lora_target_modules = [
    "q_proj", "k_proj", "v_proj", "o_proj",
    "gate_proj", "up_proj", "down_proj",
]
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    target_modules=lora_target_modules,
    lora_alpha=16,
    lora_dropout=0.1,
    bias="none",
    use_gradient_checkpointing="unsloth",
)

print("Model loaded with LoRA adapters")

## Prepare Training Data

In [None]:
# Load existing training data (JSON Lines: one JSON object per line).
# Whitespace-only lines are skipped so a stray empty line between records
# does not make json.loads raise, and the file is decoded as UTF-8 explicitly
# instead of relying on the platform's default encoding.
with open('train_dataset.jsonl', 'r', encoding='utf-8') as f:
    train_data = [json.loads(line) for line in f if line.strip()]

print(f"Loaded {len(train_data)} training examples")

# Convert to HuggingFace dataset format
def format_conversation(example):
    """Render one training record as a single chat-formatted string.

    Args:
        example: Mapping with 'question' and 'answer' entries.

    Returns:
        A dict ``{"text": ...}`` containing the system/user/assistant
        exchange rendered by the tokenizer's chat template, untokenized and
        without a trailing generation prompt.
    """
    turns = (
        ("system", generate_system_prompt()),
        ("user", example['question']),
        ("assistant", example['answer']),
    )
    conversation = [{"role": role, "content": content} for role, content in turns]

    rendered = tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=False,
    )
    return {"text": rendered}

# Wrap the loaded records in a datasets.Dataset and run format_conversation
# over every row; that function returns {"text": ...} for each record.
hf_dataset = Dataset.from_list(train_data)
formatted_dataset = hf_dataset.map(format_conversation)

print(f"Formatted dataset size: {len(formatted_dataset)}")

## Baseline Evaluation

In [None]:
def evaluate_model(model, tokenizer, num_samples=10, difficulty=5, max_new_tokens=100):
    """Estimate solve accuracy on freshly generated anagram tasks.

    A new AnagramSolverEnv generates the tasks on every call and responses
    are sampled (temperature 0.7), so repeated calls can return different
    values for the same model.

    Args:
        model: Causal LM exposing ``generate``, ``device``, ``training``,
            ``eval()`` and ``train()``.
        tokenizer: Tokenizer matching ``model``; must provide a chat template.
        num_samples: Number of tasks requested from the environment.
        difficulty: Difficulty passed to ``env.generate``.
        max_new_tokens: Generation budget per response.

    Returns:
        Fraction (0.0-1.0) of evaluated tasks for which ``env.verify``
        reports ``'correct'``. Returns 0.0 when no tasks were generated.
    """
    env = AnagramSolverEnv()
    system_prompt = generate_system_prompt()

    # Generate test tasks; materialize them so the count used for the
    # accuracy is the number actually evaluated, not the number requested.
    test_tasks = list(env.generate(num_of_questions=num_samples, difficulty=difficulty))
    if not test_tasks:
        return 0.0

    # Evaluate in eval mode so dropout (the LoRA adapters use
    # lora_dropout=0.1) does not perturb generation, then restore the
    # previous mode so a following training run is unaffected.
    was_training = model.training
    model.eval()
    correct_count = 0
    try:
        for task in test_tasks:
            # Prepare input
            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": task.question},
            ]
            input_text = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )

            # Generate response
            inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    temperature=0.7,
                    do_sample=True,
                    pad_token_id=tokenizer.eos_token_id,
                )

            # Decode only the newly generated tokens (skip the prompt part)
            prompt_length = inputs['input_ids'].shape[1]
            response = tokenizer.decode(
                outputs[0][prompt_length:],
                skip_special_tokens=True,
            )

            result = env.verify(task, response)
            if result.get('correct', False):
                correct_count += 1
    finally:
        if was_training:
            model.train()

    return correct_count / len(test_tasks)

# Evaluate baseline
# evaluate_model generates its own tasks and samples responses on each call,
# so this number is a noisy estimate (10 tasks by default) and will vary
# between runs.
baseline_accuracy = evaluate_model(model, tokenizer)
print(f"Baseline accuracy: {baseline_accuracy:.3f}")

## Training

In [None]:
# Training configuration
training_args = TrainingArguments(
    output_dir="./anagram_model",
    num_train_epochs=2,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,  # 2 * 4 = 8 sequences per optimizer step per device
    learning_rate=5e-5,
    warmup_steps=50,
    logging_steps=10,
    save_steps=200,
    save_total_limit=2,
    report_to="none",
    # remove_unused_columns is deliberately left at its default (True).
    # With False, the raw string columns of formatted_dataset (e.g. the
    # question/answer/text fields) are kept next to the tokenized fields,
    # and TRL warns that this can break the default collator on a
    # non-packed dataset.
)

# Initialize trainer: supervised fine-tuning on the chat-formatted "text"
# column produced by format_conversation.
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=formatted_dataset,
    tokenizer=tokenizer,
    dataset_text_field="text",
    max_seq_length=2048,
)

print("Trainer initialized")

In [None]:
# Start training
# Runs the SFTTrainer configured above; checkpoints go to the trainer's
# output_dir ("./anagram_model") according to its save settings.
print("Starting training...")
trainer.train()
print("Training complete!")

# Save model
# Written to a separate directory from the intermediate checkpoints.
trainer.save_model("./anagram_model_final")
print("Model saved!")

## Evaluation

In [None]:
# Score the fine-tuned model with the same evaluation routine as the baseline
trained_accuracy = evaluate_model(model, tokenizer)
print(f"Trained accuracy: {trained_accuracy:.3f}")

# Bar chart comparing the two accuracies
models = ['Baseline', 'Trained']
accuracies = [baseline_accuracy, trained_accuracy]

fig, ax = plt.subplots(figsize=(8, 6))
bars = ax.bar(models, accuracies, color=['blue', 'green'], alpha=0.7)
ax.set_ylabel('Accuracy')
ax.set_title('Model Performance Comparison')
ax.set_ylim(0, 1)

# Print each value just above the top of its bar, centred horizontally
for rect, value in zip(bars, accuracies):
    label_x = rect.get_x() + rect.get_width()/2
    label_y = rect.get_height() + 0.01
    ax.text(label_x, label_y, f'{value:.3f}', ha='center', va='bottom')

ax.grid(True, alpha=0.3)
fig.tight_layout()
plt.show()

# Absolute and relative change in accuracy after training.
improvement = trained_accuracy - baseline_accuracy
if baseline_accuracy > 0:
    print(f"\nImprovement: {improvement:.3f} ({improvement/baseline_accuracy:.1%} relative)")
else:
    # A baseline of 0 (plausible with only a handful of sampled tasks) makes
    # the relative change undefined; report the absolute change only instead
    # of raising ZeroDivisionError.
    print(f"\nImprovement: {improvement:.3f} (relative change undefined: baseline accuracy is 0)")

## Test Examples

In [None]:
# Test on specific examples
env = AnagramSolverEnv()
test_tasks = env.generate(num_of_questions=3, difficulty=7)
system_prompt = generate_system_prompt()

# The model may still be in training mode after the training cell; switch to
# eval mode so dropout (lora_dropout=0.1) does not affect the responses.
model.eval()

for i, task in enumerate(test_tasks, 1):
    print(f"\n--- Example {i} ---")
    print(f"Anagrams: {task.metadata['anagrams']}")
    print(f"Correct: {task.metadata['target_words']}")

    # Generate response
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": task.question},
    ]

    input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=100,
            temperature=0.7,
            do_sample=True,
            # Same padding token as evaluate_model uses for generation.
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the tokens generated after the prompt
    response = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
    result = env.verify(task, response)

    print(f"Model response: {response}")
    print(f"Score: {result['score']}, Correct: {result['correct']}")