In [113]:
import ast
import re

import pandas as pd
import torch
import wandb
from datasets import load_dataset
from peft import LoraConfig, PeftModel, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import GRPOConfig, GRPOTrainer

Load the model

In [114]:
model_name = "Qwen/Qwen2.5-1.5B-Instruct"

# Base policy weights in bfloat16. No device map is given, so from_pretrained
# loads onto the default device and placement is presumably left to the trainer.
# (FlashAttention 2 can be enabled with attn_implementation="flash_attention_2"
# when the package is installed.)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map=None,
    torch_dtype=torch.bfloat16,
)

In [115]:
# LoRA adapter settings. No `target_modules` is given, so PEFT chooses its
# default target modules for this model architecture.
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

# Wrap only once: re-running this cell would otherwise nest a second PeftModel
# around the already-wrapped model and stack another set of adapters.
if not isinstance(model, PeftModel):
    model = get_peft_model(model, peft_config)

Dataset

In [116]:
# Training split. Besides `prompt`, rows presumably carry the `secret` and
# `past_guess_history` columns consumed by the reward functions below.
dataset = load_dataset(
    "predibase/wordle-grpo",
    split="train",
)
dataset[3]

{'prompt': '<|im_start|>system\n\nYou are playing Wordle, a word-guessing game.\n\n### Game Rules:\n- You have **6 tries** to guess a secret **5-letter** word.\n- Each guess must be a valid **5-letter English word**.\n- After each guess, you will receive feedback indicating how close your guess was.\n\n### Feedback Format:\nEach letter in your guess will receive one of three symbols:\n1. ✓ : The letter is in the word and in the CORRECT position.\n2. - : The letter is in the word but in the WRONG position.\n3. x : The letter is NOT in the word.\n\n### Example:\nSecret Word: BRISK\n\nGuess 1: STORM → Feedback: S(-) T(x) O(x) R(-) M(x)\nGuess 2: BRAVE → Feedback: B(✓) R(✓) A(x) V(x) E(x)\nGuess 3: BRISK → Feedback: B(✓) R(✓) I(✓) S(✓) K(✓)\n\n### Response Format:\nThink through the problem and feedback step by step. Make sure to first add your step by step thought process within <think> </think> tags. Then, return your guessed word in the following format: <guess> guessed-word </guess>.\n

In [117]:
WORDS_CSV_URL = "https://raw.githubusercontent.com/arnavgarg1/arnavgarg1/refs/heads/main/five_letter_words.csv"

words = pd.read_csv(WORDS_CSV_URL)
# Lower-cased lookup set used by `validword` to check that a guess is a real word.
valid_words = {word for word in words["Word"].str.lower()}

Reward Functions

In [118]:
# helper functions
def regexsearch(text) -> str:
    """Extract the guessed word from a completion's ``<guess>...</guess>`` tag.

    When the text contains several ``<guess>`` tags (for example one mentioned
    while reasoning inside ``<think>``), the LAST one is used, because the
    answer format places the guess after the reasoning.

    Args:
        text: Completion text. ``None`` or any non-string value yields ``""``.

    Returns:
        The guessed word lower-cased, or ``""`` when no tag is found.
    """
    if not isinstance(text, str):
        return ""
    # One capture group, so findall returns the captured words directly;
    # \w+ cannot contain whitespace, so no extra strip is needed.
    matches = re.findall(r"<guess>\s*(\w+)\s*</guess>", text)
    return matches[-1].lower() if matches else ""


def extractguess(completions, verbose: bool = False) -> list[str]:
    """Extract the guessed word from every completion in a batch.

    Handles both completion formats a GRPO reward function can receive:
    plain strings (text prompts, as with this dataset) and conversational
    lists of message dicts, where the last message's ``"content"`` is used.

    Args:
        completions: Sequence of completions (``str`` or list of message dicts).
        verbose: If True, print the extracted guesses for debugging.

    Returns:
        One lower-cased guess per completion; ``""`` where none was found.
    """

    def _text(completion):
        # Conversational format: take the content of the final message.
        if isinstance(completion, list):
            return completion[-1].get("content", "") if completion else ""
        return completion

    guesses = [regexsearch(_text(completion)) for completion in completions]
    if verbose:
        print(f"[EXTRACT_GUESS] Extracted guesses: {guesses}")
    return guesses



# reward functions
def finalguess(completions, secret, verbose: bool = False, **kwargs) -> list[float]:
    """Reward 2.0 when a completion's guess equals its secret word, else 0.0.

    GRPOTrainer passes extra dataset columns to reward functions as lists with
    one entry per completion, so ``secret`` is normally a list aligned with
    ``completions``. A single string is also accepted and applied to all.

    Args:
        completions: Completions for the batch.
        secret: Secret word(s): a list aligned with ``completions`` or one str.
        verbose: If True, print the secret words for debugging.

    Returns:
        One reward (2.0 or 0.0) per completion.

    Raises:
        ValueError: If ``secret`` is a list whose length differs from the
            number of completions.
    """
    guesses = extractguess(completions)
    secrets = [secret] * len(guesses) if isinstance(secret, str) else list(secret)
    if len(secrets) != len(guesses):
        raise ValueError(
            f"got {len(secrets)} secret words for {len(guesses)} completions"
        )
    answers = [s.strip().lower() for s in secrets]
    # An empty guess (no <guess> tag) never counts as correct.
    rewards = [2.0 if g and g == a else 0.0 for g, a in zip(guesses, answers)]
    if verbose:
        print(f"[FINAL_GUESS] secret words: {answers}")
    return rewards




def feedback_reward(completions, past_guess_history, verbose: bool = False, **kwargs) -> list[float]:
    """Reward guesses that respect the feedback from earlier turns of the game.

    Every completion starts at 1.0 and is penalised once per violated clue in
    its ``past_guess_history``:

    * an ``(x)`` letter reused in the guess: -0.1. This is skipped when the same
      letter is also marked ``(✓)`` or ``(-)`` somewhere in the history, because
      Wordle marks surplus copies of a present letter as ``(x)``;
    * a ``(-)`` letter missing from the guess, or kept in the same slot: -0.1;
    * a ``(✓)`` letter not kept in its slot: -0.2.

    The result is clipped at 0.0. A completion without an extractable
    ``<guess>`` scores 0.0, since it cannot be checked against the clues. If a
    history cannot be parsed, the error is printed and no penalty is applied.

    Args:
        completions: Completions for the batch, aligned with ``past_guess_history``.
        past_guess_history: Per-completion history. Each entry is either a
            string holding a Python literal list of ``(word, feedback)`` pairs
            or such a list itself. Feedback looks like ``"S(-) T(x) O(x) R(-) M(x)"``.
        verbose: Print the completions and every applied penalty for debugging.

    Returns:
        One reward in ``[0.0, 1.0]`` per completion.
    """
    log = print if verbose else (lambda *args, **kw: None)

    log("=" * 50)
    log("ALL COMPLETIONS:")
    for i, completion in enumerate(completions):
        log(f"\n--- Completion {i} ---")
        log(completion)
        log(f"--- End Completion {i} ---")
    log("=" * 50)

    guesses = extractguess(completions)
    log("Guesses: ", guesses)

    rewards = []
    for i, g in enumerate(guesses):
        if not g:
            # Nothing to check against the clues; do not hand out the 1.0 default.
            rewards.append(0.0)
            log(f"Guess {i}: no <guess> found, reward: 0.0")
            continue

        history = past_guess_history[i] if i < len(past_guess_history) else "[]"
        try:
            history_list = (
                history if isinstance(history, (list, tuple)) else ast.literal_eval(history)
            )
            # (word, [(LETTER, status), ...]) per past guess; list index == letter slot.
            parsed_history = [
                (word, re.findall(r"([A-Z])\((.)\)", feedback))
                for word, feedback in history_list
            ]
        except Exception as e:  # best effort: a malformed history applies no penalty
            print(f"Error parsing history for guess {i}: {e}")
            parsed_history = []
        log(f"Guess {i} ('{g}') - Parsed history: {parsed_history}")

        # Letters confirmed to be in the secret word anywhere in the history.
        known_present = {
            letter.lower()
            for _, clues in parsed_history
            for letter, status in clues
            if status in ("✓", "-")
        }

        penalty = 0.0
        for word, clues in parsed_history:
            log(f"  Processing: word='{word}', clues={clues}")
            for pos, (letter, status) in enumerate(clues):
                letter = letter.lower()

                if status == "x":
                    if letter in g and letter not in known_present:
                        penalty += 0.1
                        log(f"    Penalty +0.1: '{letter}' marked as (x) but found in guess")

                elif status == "-":
                    if letter not in g:
                        penalty += 0.1
                        log(f"    Penalty +0.1: '{letter}' marked as (-) but not in guess")
                    elif pos < len(g) and g[pos] == letter:
                        penalty += 0.1
                        log(f"    Penalty +0.1: '{letter}' marked as (-) but in same position")

                elif status == "✓":
                    if pos >= len(g) or g[pos] != letter:
                        penalty += 0.2
                        log(f"    Penalty +0.2: '{letter}' marked as (✓) but not in correct position")

        reward = max(0.0, 1.0 - penalty)
        rewards.append(reward)
        log(f"  Final penalty: {penalty}, reward: {reward}")

    log(f"Final rewards: {rewards}")
    return rewards


def validword(completions, **kwargs) -> list[float]:
    """Score each completion's guess on being a real five-letter word.

    Each score starts at 1.0. It loses 0.5 when the guess is not exactly five
    characters long and another 0.5 when the guess is not in ``valid_words``.
    The result is floored at 0.0.
    """

    def _score(word: str) -> float:
        wrong_length = 0.5 if len(word) != 5 else 0.0
        unknown_word = 0.5 if word not in valid_words else 0.0
        return max(0.0, 1.0 - (wrong_length + unknown_word))

    return [_score(word) for word in extractguess(completions)]



def xmlformat(completions, **kwargs) -> list[float]:
    """Return 1.0 for each completion containing a ``<think>...</think>`` block
    followed (after optional whitespace) by a ``<guess>...</guess>`` block, else 0.0.

    The match may start anywhere in the completion and may span newlines.
    NOTE(review): if the dataset prompt already ends with an opening ``<think>``
    tag, completions would never contain it and this reward would stay 0.0.
    Confirm this against a full prompt.
    """
    think_then_guess = re.compile(r"<think>.*?</think>\s*<guess>.*?</guess>", flags=re.DOTALL)
    return [0.0 if think_then_guess.search(text) is None else 1.0 for text in completions]


In [119]:
# NOTE(review): some TRL versions require the global batch size
# (per_device_train_batch_size x number of processes) to be divisible by
# num_generations; batch size 1 with num_generations=4 fails that check there.
# Confirm against the installed TRL version.
training_args = GRPOConfig(
    # run bookkeeping
    output_dir="Qwen2.5-1.5B-GRPO",
    run_name="wordle-grpo",
    report_to="wandb",
    # optimisation
    learning_rate=5e-6,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    # sampling: completions generated per prompt and compared within the group
    num_generations=4,
    temperature=0.8,
    top_p=0.9,
    top_k=50,
    # completion logging
    log_completions=True,
    num_completions_to_print=4,
)

In [None]:
# Rewards are passed in this order: exact match, consistency with earlier
# feedback, dictionary validity, and output format.
reward_functions = [finalguess, feedback_reward, validword, xmlformat]

trainer = GRPOTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    reward_funcs=reward_functions,
)
trainer.train()
