In [75]:
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer
import re
import pandas as pd
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

Load the model

In [76]:
model_name = "Qwen/Qwen2.5-1.5B-Instruct"


model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    # attn_implementation="flash_attention_2",
    # peft_config=peft_config,
    device_map=None
)

In [77]:
from peft import get_peft_model, LoraConfig

peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

model = get_peft_model(model, peft_config)

Dataset

In [78]:
dataset = load_dataset("predibase/wordle-grpo", split="train")
dataset[3]

{'prompt': '<|im_start|>system\n\nYou are playing Wordle, a word-guessing game.\n\n### Game Rules:\n- You have **6 tries** to guess a secret **5-letter** word.\n- Each guess must be a valid **5-letter English word**.\n- After each guess, you will receive feedback indicating how close your guess was.\n\n### Feedback Format:\nEach letter in your guess will receive one of three symbols:\n1. ✓ : The letter is in the word and in the CORRECT position.\n2. - : The letter is in the word but in the WRONG position.\n3. x : The letter is NOT in the word.\n\n### Example:\nSecret Word: BRISK\n\nGuess 1: STORM → Feedback: S(-) T(x) O(x) R(-) M(x)\nGuess 2: BRAVE → Feedback: B(✓) R(✓) A(x) V(x) E(x)\nGuess 3: BRISK → Feedback: B(✓) R(✓) I(✓) S(✓) K(✓)\n\n### Response Format:\nThink through the problem and feedback step by step. Make sure to first add your step by step thought process within <think> </think> tags. Then, return your guessed word in the following format: <guess> guessed-word </guess>.\n

In [79]:
words = pd.read_csv("https://raw.githubusercontent.com/arnavgarg1/arnavgarg1/refs/heads/main/five_letter_words.csv")
valid_words = set(words["Word"].str.lower())

Reward Functions

In [92]:
# helper functions
def regexsearch(text)->str:
    match = re.search(r"<guess>\s*(\w+)\s*</guess>", text)
    return match.group(1).strip().lower() if match else ""


def extractguess(completions)->list[str]:
    # response = [c[0]["content"] for c in completions]
    guesses = [regexsearch(completion) for completion in completions]
    return guesses



# reward functions
def finalguess(completions, secret, **kwargs) -> list[float]:
    guesses = extractguess(completions)
    answers = [t.strip().lower() for t in secret] * len(completions)
    reward = [2.0 if g == t else 0.0 for g,t in zip(guesses, answers)]
    return reward


# def feedback_reward(completions, past_guess_history, **kwargs) -> list[float]:

#     print("="*50)
#     print("ALL COMPLETIONS:")
#     for i, completion in enumerate(completions):
#         print(f"\n--- Completion {i} ---")
#         print(completion)
#         print(f"--- End Completion {i} ---")
#     print("="*50)
    
#     rewards = []
#     penalty = 0.0
#     guesses = extractguess(completions)
#     print("Guesses: ", guesses)
#     import ast
#     if isinstance(past_guess_history, str):
#         past_guess_history = ast.literal_eval(past_guess_history)
#     for g in guesses:
#         for p,f in past_guess_history: # past word, feedback
#             parsed = re.findall(r"([A-Z])\((.)\)", f)
#             for i, (l,s) in enumerate(parsed):
#                 if s == "x":
#                     penalty += 0.1 if l in g else 0.0
#                 elif s == "-":
#                     penalty += 0.1 if l not in g or g[i] == l else 0.0
#                 elif s == "✓":
#                     penalty += 0.2 if g[i]!=l else 0.0

#         reward = max(0.0, 1.0-penalty)
#         rewards.append(reward)
#     return rewards



def feedback_reward(completions, past_guess_history, **kwargs) -> list[float]:
    print("="*50)
    print("ALL COMPLETIONS:")
    for i, completion in enumerate(completions):
        print(f"\n--- Completion {i} ---")
        print(completion)
        print(f"--- End Completion {i} ---")
    print("="*50)
    
    rewards = []
    guesses = extractguess(completions)
    print("Guesses: ", guesses)
    
    # past_guess_history is a list of strings, one per completion
    for i, g in enumerate(guesses):
        penalty = 0.0
        
        # Get the history string for this completion
        history_str = past_guess_history[i] if i < len(past_guess_history) else "[]"
        
        try:
            import ast
            # Parse the string to get the actual list of [word, feedback] pairs
            history_list = ast.literal_eval(history_str)
            print(f"Guess {i} ('{g}') - Parsed history: {history_list}")
            
            # Process each [word, feedback] pair in the history
            for word, feedback in history_list:
                print(f"  Processing: word='{word}', feedback='{feedback}'")
                parsed = re.findall(r"([A-Z])\((.)\)", feedback)
                
                for pos, (letter, status) in enumerate(parsed):
                    letter = letter.lower()
                    
                    if status == "x":
                        # Letter not in word - penalize if guess contains it
                        if letter in g:
                            penalty += 0.1
                            print(f"    Penalty +0.1: '{letter}' marked as (x) but found in guess")
                    
                    elif status == "-":
                        # Letter in word but wrong position
                        if letter not in g:
                            penalty += 0.1
                            print(f"    Penalty +0.1: '{letter}' marked as (-) but not in guess")
                        elif pos < len(g) and g[pos] == letter:
                            penalty += 0.1
                            print(f"    Penalty +0.1: '{letter}' marked as (-) but in same position")
                    
                    elif status == "✓":
                        # Letter in correct position
                        if pos >= len(g) or g[pos] != letter:
                            penalty += 0.2
                            print(f"    Penalty +0.2: '{letter}' marked as (✓) but not in correct position")
            
        except Exception as e:
            print(f"Error parsing history for guess {i}: {e}")
            # If parsing fails, give neutral reward
            penalty = 0.0
        
        reward = max(0.0, 1.0 - penalty)
        rewards.append(reward)
        print(f"  Final penalty: {penalty}, reward: {reward}")
    
    print(f"Final rewards: {rewards}")
    return rewards


# check if penalty works, maybe bug
def validword(completions, **kwargs) -> list[float]:
    guesses = extractguess(completions)
    rewards = []
    for g in guesses:
        penalty = 0.0
        penalty += 0.5 if len(g) != 5 else 0.0
        penalty += 0.5 if g not in valid_words else 0.0
        reward = max(0.0, 1.0-penalty)
        rewards.append(reward)
    return rewards


# check if strict pattern matching needed
# need to check how the completions format output
# def xmlformat(completions, **kwargs) -> list[float]:
#     pattern = r"^<think>.*?</think>\s*<guess>.*?</guess>$"
#     contents = [c[0]["content"] for c in completions]
#     return [1.0 if re.match(pattern, text.strip(), flags=re.DOTALL) else 0.0 for text in contents]

def xmlformat(completions, **kwargs) -> list[float]:
    pattern = r"<think>.*?</think>\s*<guess>.*?</guess>"
    rewards = []
    
    for completion in completions:
        match = re.search(pattern, completion, flags=re.DOTALL)
        reward = 1.0 if match else 0.0
        rewards.append(reward)
    
    return rewards


In [93]:
training_args = GRPOConfig(output_dir="Qwen2.5-1.5B-GRPO",
                           num_generations=4,
                           log_completions=True,
                           num_completions_to_print=4,
                           learning_rate=5e-6,
                           temperature=0.8,
                           top_p=0.9,
                           top_k=50,
                           gradient_accumulation_steps=4,
                           per_device_train_batch_size=1
                           
                           )

In [94]:
trainer = GRPOTrainer(
        model=model,
        reward_funcs=[
            finalguess,
            feedback_reward,
            validword,
            xmlformat
            ],
        args=training_args,
        train_dataset=dataset,
    )
trainer.train()


No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


ALL COMPLETIONS:

--- Completion 0 ---
First, I'll look at the letters that are in the word and in the correct positions. CRANE has a C in the first position, which is correct. I also see an A in the second position, which is correct. I need to try to find the other letters, which are I, N, P, and U.</think>
<guess>INPRT</guess>
Guess 1: INPRT -> Feedback: I(✓) N(✓) P(✓) R(-) T(-)
--- End Completion 0 ---

--- Completion 1 ---
  
  - From the previous feedback, we know that the letter 'N' is not in the word, so it must be a different letter in the word.
  - The letter 'I' is not in the word either, so it must be a different letter as well.
  - The letter 'P' is in the word, so it must be in the correct position.
  - The letter 'U' is in the word, but it could be in the correct position or in the wrong position.
  - The letter 'T' is not in the word, so it must be a different letter.
</think>  
<guess> REBIN </guess>
--- End Completion 1 ---

--- Completion 2 ---
 
- CRANE: C(x) R(x) A(

Step,Training Loss
10,0.0


ALL COMPLETIONS:

--- Completion 0 ---
 
1. I see that my previous guess CRANE had the first letter correct, but but the rest were incorrect. 
2. I now need to use this information to try and guess the right word. 
3. I have seen that in BRINE, the first letter B is correct and the last letter E is also correct. 
4. So I can try and use this information to make a guess. 
</think> 
<guess> WINE </guess>
--- End Completion 0 ---

--- Completion 1 ---
 
First, I need to identify the letters in the word "CRANE" that are correct. In this case, the letter 'C' and 'R' are in the correct positions. Next, I need to find out where the letters 'A' and 'N' are located in the word. The 'A' is in the first position and the 'N' is in the second position. Now, I need to think of a 5-letter word that has 'CR' at the beginning, 'A' in the first position, 'N' in the second position, and includes the letters 'R' and 'E'. One word that fits this criteria is "RENA". Let's check if this word is a valid 5-let

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


ALL COMPLETIONS:

--- Completion 0 ---
 
1. In Guess 1: CRANE, we received C(x) R(-) A(x) N(x) E(-)
   - The letter 'R' is in the word.
   - The letter 'A' is in the word.
   - The letter 'N' is in the word.
   - The letters 'E' and 'C' are not in the word.
2. In Guess 2: REMIT, we received R(✓) E(✓) M(x) I(x) T(x)
   - The letter 'R' is in the word.
   - The letter 'E' is in the word.
   - The letter 'M' is in the word.
   - The letter 'I' is in the word.
   - The letters 'N' and 'T' are not in the word.
</think> 

The feedback suggests that the letters 'A', 'C', 'E', 'I', 'M', and 'R' are in the word. Let's try a word that incorporates these letters and see if it fits the pattern.
<guess> HARMON </guess>
--- End Completion 0 ---

--- Completion 1 ---
 
I need to find a 5-letter word that fits the feedback from the previous guesses. The only letters that are not in the word are 'A' and 'E' for the first guess, and 'I' for the second guess. This means my new word must have either of th

KeyboardInterrupt: 