# Anagram Solver RL Training

Train an LLM agent to solve anagram puzzles using GRPO, with proper data preparation.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/PollyLeo6/Anagram-Solver/blob/main/train_agent.ipynb)

## Setup

In [None]:
import os
import random
import numpy as np
import torch
import json
import re

# Fix random seeds so that Python's `random`, NumPy and PyTorch all start
# from the same state on every run.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Fetch the project sources only when `anagram_game.py` is not already in the
# current working directory, then change into the cloned checkout.
if not os.path.exists('anagram_game.py'):
    print('📥 Cloning repository...')
    !git clone https://github.com/PollyLeo6/Anagram-Solver.git
    %cd Anagram-Solver
    print('✅ Repository cloned!')

# Install the training stack.
# NOTE(review): `unsloth` is installed twice (PyPI release, then the git
# build) and both installs run after `torch` has already been imported above;
# confirm the intended version and whether a runtime restart is required.
!pip install torch transformers datasets accelerate peft trl unsloth matplotlib
!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"

In [None]:
from unsloth import FastLanguageModel
from transformers import TrainingArguments
from trl import SFTTrainer
from datasets import Dataset
import matplotlib.pyplot as plt

from anagram_game import AnagramSolverEnv
from utils import create_english_dictionary

## Data Preparation

In [None]:
# System prompt shared by dataset generation and evaluation. It pins the
# output contract: a JSON object with a "solutions" list, in anagram order.
SYSTEM_PROMPT = """You are an expert anagram solver. Your task is to unscramble letters to form valid English words.

Rules:
1. Use each letter exactly once
2. Form valid English words only
3. Respond in JSON format: {"solutions": ["word1", "word2", ...]}
4. Order words as they appear in the anagram list

Be accurate and follow the format exactly."""

def extract_json_answer(text: str) -> str:
    """Return the first valid JSON object embedded in ``text``.

    The text is scanned left to right; at every ``{`` a JSON decode is
    attempted, and the exact substring of the first object that parses is
    returned. Unlike slicing from the first ``{`` to the last ``}``, this
    still succeeds when the reply contains stray braces before or after the
    actual answer (e.g. a trailing ``{...}`` in an explanation).

    Args:
        text: Raw model response.

    Returns:
        The JSON object as it appears in ``text``, or ``text.strip()`` when
        no decodable JSON object is present.
    """
    decoder = json.JSONDecoder()
    start = text.find('{')
    while start != -1:
        try:
            # raw_decode parses one value starting at `start` and reports
            # where it ends, so trailing chatter does not break parsing.
            _, end = decoder.raw_decode(text, start)
        except (ValueError, RecursionError):
            # Not JSON at this brace (or absurdly nested): try the next one.
            start = text.find('{', start + 1)
        else:
            return text[start:end]
    return text.strip()

def get_anagram_dataset(env, num_samples=71, difficulties=range(1, 8)):
    """Build a HuggingFace ``Dataset`` of chat-formatted anagram tasks.

    For every level in ``difficulties`` the environment is asked for
    ``num_samples`` tasks; each task becomes one record holding the
    system + user prompt, the reference answer, the difficulty level and
    the task metadata.
    """

    def to_record(task, level):
        # One dataset row: the two-message prompt plus supervision fields.
        system_turn = {'role': 'system', 'content': SYSTEM_PROMPT}
        user_turn = {'role': 'user', 'content': task.question}
        return {
            'prompt': [system_turn, user_turn],
            'answer': task.answer,
            'difficulty': level,
            'metadata': task.metadata,
        }

    records = [
        to_record(task, level)
        for level in difficulties
        for task in env.generate(num_of_questions=num_samples, difficulty=level)
    ]
    return Dataset.from_list(records)

In [None]:
# Load the dictionary and split it 70/30 into train and test vocabularies.
# NOTE(review): create_english_dictionary() is presumably what writes
# 'dictionary.txt' — confirm against utils.py.
create_english_dictionary()
with open('dictionary.txt', 'r', encoding='utf-8') as f:
    # Stream the file line by line. Blank lines are dropped so the empty
    # string never becomes a "word", and duplicates are removed (keeping
    # first-seen order) so the same word cannot end up in both splits.
    full_dictionary = list(dict.fromkeys(
        word for word in (line.strip() for line in f) if word
    ))

# Shuffle in place (seeded earlier in the notebook) before cutting the split.
random.shuffle(full_dictionary)
split_idx = int(0.7 * len(full_dictionary))
train_words = full_dictionary[:split_idx]
test_words = full_dictionary[split_idx:]

print(f"Dictionary split - Train: {len(train_words)}, Test: {len(test_words)}")

# Create custom environments
class TrainAnagramEnv(AnagramSolverEnv):
    """AnagramSolverEnv whose ``dictionary`` is replaced by the training word split."""
    def __init__(self) -> None:
        super().__init__()
        # Override the vocabulary after the base class has initialised.
        # NOTE(review): anything the base class derived from its own
        # dictionary during __init__ is not rebuilt here — confirm that
        # task generation reads `self.dictionary` lazily.
        self.dictionary = set(train_words)
        
class TestAnagramEnv(AnagramSolverEnv):
    """AnagramSolverEnv whose ``dictionary`` is replaced by the held-out test word split."""
    def __init__(self) -> None:
        super().__init__()
        # Override the vocabulary after the base class has initialised.
        # NOTE(review): anything the base class derived from its own
        # dictionary during __init__ is not rebuilt here — confirm that
        # task generation reads `self.dictionary` lazily.
        self.dictionary = set(test_words)

# One environment per vocabulary split.
train_env, test_env = TrainAnagramEnv(), TestAnagramEnv()

# Training set: 71 puzzles at each difficulty level from 1 through 7.
train_dataset = get_anagram_dataset(
    train_env,
    num_samples=71,
    difficulties=range(1, 8),
)
print(f"Generated {len(train_dataset)} training examples")

## Reward Functions

In [None]:
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """Score each completion 2.0 when its JSON equals the reference answer.

    Args:
        prompts: Batch of chat prompts; the last message of the first prompt
            is only used for the debug trace.
        completions: Batch of completions, each a list whose first element
            is a message dict with a ``'content'`` string.
        answer: Reference answers aligned with ``completions``. Strings are
            parsed as JSON; other values are compared as-is.
        **kwargs: Extra dataset columns passed by the trainer (ignored).

    Returns:
        One reward per (completion, answer) pair: 2.0 for an exact match of
        the parsed JSON values, 0.0 otherwise (including unparsable JSON).
    """
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_json_answer(r) for r in responses]

    # Debug trace of the first sample only. Guarded so that an empty batch
    # yields an empty reward list instead of an IndexError.
    if prompts and responses and answer:
        q = prompts[0][-1]['content']
        print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")

    rewards = []
    for r, a in zip(extracted_responses, answer):
        try:
            # Parse both sides so that formatting differences (whitespace,
            # key order) do not affect the comparison.
            r_json = json.loads(r) if isinstance(r, str) else r
            a_json = json.loads(a) if isinstance(a, str) else a
        except (ValueError, TypeError, RecursionError):
            # Unparsable response or reference: no credit, but do not
            # swallow unrelated errors such as KeyboardInterrupt.
            rewards.append(0.0)
            continue
        rewards.append(2.0 if r_json == a_json else 0.0)

    return rewards

def format_reward_func(completions, **kwargs) -> list[float]:
    """Score each completion 0.5 when it follows the required JSON format.

    A completion is compliant when the JSON extracted from it is an object
    with a ``"solutions"`` key whose value is a list.

    Args:
        completions: Batch of completions, each a list whose first element
            is a message dict with a ``'content'`` entry.
        **kwargs: Extra dataset columns passed by the trainer (ignored).

    Returns:
        One reward per completion: 0.5 when compliant, 0.0 otherwise.
    """
    responses = [completion[0]['content'] for completion in completions]
    rewards = []

    for response in responses:
        # Non-text content cannot be compliant; score it instead of raising.
        if not isinstance(response, str):
            rewards.append(0.0)
            continue
        try:
            parsed = json.loads(extract_json_answer(response))
        except (ValueError, RecursionError):
            rewards.append(0.0)
            continue
        # Explicit type check: a JSON list or scalar is valid JSON but does
        # not match the {"solutions": [...]} contract.
        compliant = isinstance(parsed, dict) and isinstance(parsed.get('solutions'), list)
        rewards.append(0.5 if compliant else 0.0)

    return rewards

## Model Loading

In [None]:
# Load the Qwen2.5-1.5B-Instruct checkpoint through Unsloth with 4-bit
# weights and a 2048-token maximum sequence length.
# NOTE(review): dtype=None presumably lets Unsloth pick the compute dtype for
# the current GPU — confirm.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/Qwen2.5-1.5B-Instruct",
    max_seq_length=2048,
    dtype=None,
    load_in_4bit=True,
)

# Attach rank-16 LoRA adapters to the attention projections (q/k/v/o) and the
# MLP projections (gate/up/down); bias terms are left untouched.
# NOTE(review): Unsloth's fast LoRA path is reportedly tuned for
# lora_dropout=0 — confirm that 0.1 is intentional.
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                   "gate_proj", "up_proj", "down_proj"],
    lora_alpha=16,
    lora_dropout=0.1,
    bias="none",
    use_gradient_checkpointing="unsloth",
)

print("Model loaded with LoRA adapters")

## SFT Training

In [None]:
# Format dataset for SFT
def format_conversation(example):
    """Render one dataset row as a single chat-templated training string.

    The reference answer is appended to the row's prompt as the assistant
    turn, and the full conversation is rendered with the tokenizer's chat
    template (as text, without a trailing generation prompt).

    Returns:
        A dict with the rendered conversation under the ``"text"`` key.
    """
    history = example['prompt']
    assistant_turn = {'role': 'assistant', 'content': example['answer']}
    rendered = tokenizer.apply_chat_template(
        history + [assistant_turn],
        tokenize=False,
        add_generation_prompt=False,
    )
    return {"text": rendered}

# Add a "text" column holding the chat-templated conversation for every row.
formatted_dataset = train_dataset.map(format_conversation)
print(f"Formatted dataset size: {len(formatted_dataset)}")

# Training arguments
# Effective batch size per device is 2 * 4 = 8 sequences per optimizer step.
# Checkpoints are written every 200 steps and only the two latest are kept;
# external experiment trackers are disabled.
training_args = TrainingArguments(
    output_dir="./anagram_sft",
    num_train_epochs=2,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    learning_rate=5e-5,
    warmup_steps=50,
    logging_steps=10,
    save_steps=200,
    save_total_limit=2,
    report_to="none",
    remove_unused_columns=False,
)

# Supervised fine-tuning on the rendered "text" column.
# NOTE(review): newer TRL releases expect `dataset_text_field` and
# `max_seq_length` on an SFTConfig and take the tokenizer as
# `processing_class` — confirm these keyword arguments against the installed
# TRL version.
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=formatted_dataset,
    tokenizer=tokenizer,
    dataset_text_field="text",
    max_seq_length=2048,
)

print("Starting SFT training...")
trainer.train()
print("SFT training complete!")

## Evaluation

In [None]:
def evaluate_model(model, tokenizer, env, difficulty=5, num_samples=20):
    """Estimate the model's solve rate on freshly generated anagram tasks.

    Args:
        model: Causal LM exposing ``generate``, ``device`` and the usual
            ``torch.nn.Module`` train/eval switches.
        tokenizer: Tokenizer providing ``apply_chat_template``, ``__call__``
            and ``decode``.
        env: Environment providing ``generate(num_of_questions, difficulty)``
            and ``verifier.verify(task, response)``.
        difficulty: Difficulty level passed to ``env.generate``.
        num_samples: Number of tasks requested from the environment.

    Returns:
        Fraction of generated tasks whose sampled response passes the
        verifier. Tasks that raise during generation or verification count
        as incorrect. Returns 0.0 when the environment yields no tasks.
    """
    test_tasks = list(env.generate(num_of_questions=num_samples, difficulty=difficulty))
    if not test_tasks:
        # Nothing to score; avoids dividing by zero.
        return 0.0

    correct_count = 0
    failed_count = 0
    first_error = None

    # Generation must run in eval mode, otherwise dropout (e.g. LoRA dropout)
    # stays active after training and perturbs the outputs. The previous
    # mode is restored afterwards so later training is unaffected.
    was_training = model.training
    model.eval()
    try:
        for task in test_tasks:
            try:
                messages = [
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": task.question}
                ]

                input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
                inputs = tokenizer(input_text, return_tensors="pt").to(model.device)

                with torch.no_grad():
                    outputs = model.generate(**inputs, max_new_tokens=100, temperature=0.7, do_sample=True, pad_token_id=tokenizer.eos_token_id)

                # Decode only the newly generated tokens, not the prompt.
                prompt_length = inputs['input_ids'].shape[1]
                response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
                if env.verifier.verify(task, response):
                    correct_count += 1
            except Exception as exc:
                # Best effort: a failing sample is scored as incorrect, but
                # the failure is reported below instead of being hidden.
                failed_count += 1
                if first_error is None:
                    first_error = exc
    finally:
        model.train(was_training)

    if failed_count:
        print(f"Warning: {failed_count}/{len(test_tasks)} evaluation samples raised an error "
              f"and were scored as incorrect (first error: {first_error!r})")

    # Divide by the number of tasks actually generated, which may differ
    # from the number requested.
    return correct_count / len(test_tasks)

# Score the model on three settings, in this order: training vocabulary at
# difficulty 5, held-out vocabulary at difficulty 5, and held-out vocabulary
# at difficulty 8.
print("Evaluating model...")
train_acc, test_acc, test_hard = (
    evaluate_model(model, tokenizer, eval_env, difficulty=level)
    for eval_env, level in ((train_env, 5), (test_env, 5), (test_env, 8))
)

# Report accuracies with three decimals.
print("Results:")
print(f"Train accuracy: {train_acc:.3f}")
print(f"Test accuracy: {test_acc:.3f}")
print(f"Test hard accuracy: {test_hard:.3f}")