In [None]:
!pip install -q transformers accelerate datasets trl pandas bitsandbytes

In [None]:
import ast
import math
import os
import re
import warnings
from collections import Counter
from functools import lru_cache
from typing import List

import pandas as pd
import torch
from datasets import load_dataset
from peft import LoraConfig, PeftModel, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import GRPOConfig, GRPOTrainer

In [None]:
# Training prompts plus per-row metadata (word list, guess history, secret).
dataset = load_dataset(path="predibase/wordle-grpo", split="train")
sample_row = dataset[0]
print(sample_row)

{'prompt': '<|im_start|>system\n\nYou are playing Wordle, a word-guessing game.\n\n### Game Rules:\n- You have **6 tries** to guess a secret **5-letter** word.\n- Each guess must be a valid **5-letter English word**.\n- After each guess, you will receive feedback indicating how close your guess was.\n\n### Feedback Format:\nEach letter in your guess will receive one of three symbols:\n1. ✓ : The letter is in the word and in the CORRECT position.\n2. - : The letter is in the word but in the WRONG position.\n3. x : The letter is NOT in the word.\n\n### Example:\nSecret Word: BRISK\n\nGuess 1: STORM → Feedback: S(-) T(x) O(x) R(-) M(x)\nGuess 2: BRAVE → Feedback: B(✓) R(✓) A(x) V(x) E(x)\nGuess 3: BRISK → Feedback: B(✓) R(✓) I(✓) S(✓) K(✓)\n\n### Response Format:\nThink through the problem and feedback step by step. Make sure to first add your step by step thought process within <think> </think> tags. Then, return your guessed word in the following format: <guess> guessed-word </guess>.\n

In [None]:
# Columns on each row; every column other than `prompt` is forwarded to the reward functions as a keyword argument.
row_keys = dataset[0].keys()
print(row_keys)

dict_keys(['prompt', 'word_list', 'past_guess_history', 'secret'])


In [None]:
model_name = "Qwen/Qwen2.5-7B-Instruct"

# Tokenizer and policy model come from the same checkpoint; weights load in the
# checkpoint's own dtype ("auto") and are placed on available devices ("auto").
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype="auto",
)

In [None]:
# LoRA hyper-parameters (adapter scaling = lora_alpha / lora_r).
lora_r = 64
lora_alpha = 16
lora_dropout = 0.1

lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=lora_r,
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    bias="none",
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",  # attention projections
        "gate_proj", "up_proj", "down_proj",    # MLP projections
    ],
)

# Wrap the base model so that only the adapter weights are trainable.
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

trainable params: 161,480,704 || all params: 7,777,097,216 || trainable%: 2.0764


In [None]:
# Reward Functions
#
# TRL's GRPOTrainer calls each reward function as
#     reward_func(prompts=..., completions=..., **columns)
# where every extra dataset column (`word_list`, `past_guess_history`,
# `secret`) arrives as a keyword argument holding one value per completion.
# It does not pass an `examples` list, so the helpers below read columns from
# **kwargs and use `examples` (a list of row dicts) only as a fallback for
# direct callers.

_FORMAT_RE = re.compile(
    r"^<think>\s*([^<]*(?:<(?!/?think>)[^<]*)*)\s*<\/think>\n"
    r"<guess>\s*([\s\S]*?)\s*<\/guess>$",
    re.DOTALL,
)
_GUESS_RE = re.compile(r"<guess>\s*([\s\S]*?)\s*<\/guess>$", re.DOTALL)
_MISSING = object()


def _completion_text(completion) -> str:
    """Return the generated text of one completion.

    String-prompt datasets (as used here) yield plain strings. For
    conversational datasets the completion is a list of message dicts; the last
    message's "content" is used (assumed to be the assistant reply).
    """
    if isinstance(completion, str):
        return completion
    return completion[-1]["content"]


def _example_field(i, key, examples, kwargs, default=_MISSING):
    """Return dataset column `key` for completion `i`.

    Looks in the legacy `examples` rows first (if given), then in the TRL
    keyword-argument columns. Raises KeyError if the column is unavailable and
    no `default` was supplied.
    """
    if examples is not None and key in examples[i]:
        return examples[i][key]
    column = kwargs.get(key)
    if column is not None:
        return column[i]
    if default is _MISSING:
        raise KeyError(f"dataset column {key!r} was not passed to the reward function")
    return default


def _warn_reward_failure(func_name: str, exc: Exception) -> None:
    """Report an exception that a reward function swallowed (scored as 0.0).

    Under the default warnings filter each distinct message is shown once, so a
    systematic failure becomes visible without flooding the training log.
    """
    warnings.warn(
        f"{func_name}: scoring failed with {type(exc).__name__}: {exc}; reward set to 0.0",
        RuntimeWarning,
        stacklevel=2,
    )


@lru_cache(maxsize=8)
def _read_word_csv(path: str) -> tuple:
    """Read the `Word` column of a word-list CSV as stripped upper-case strings (cached per path)."""
    # keep_default_na=False: never let pandas turn a word into NaN.
    frame = pd.read_csv(path, dtype=str, keep_default_na=False)
    return tuple(frame["Word"].str.strip().str.upper())


def _resolve_words(word_list_obj) -> tuple:
    """Normalize a `word_list` value to a tuple of upper-case words.

    Accepts an in-memory list/tuple or a path ending in ".csv". Anything else
    (including None) means "word list unknown" and yields an empty tuple.
    """
    if isinstance(word_list_obj, (list, tuple)):
        return tuple(str(w).strip().upper() for w in word_list_obj)
    if isinstance(word_list_obj, str) and word_list_obj.endswith(".csv"):
        return _read_word_csv(word_list_obj)
    return ()


def _parse_history(raw) -> list:
    """Return past guesses as a list of (guess, feedback) string pairs.

    A string is parsed with `ast.literal_eval` (how the dataset column is
    stored); an already-parsed sequence of pairs is accepted as-is.
    """
    if isinstance(raw, str):
        raw = ast.literal_eval(raw)
    return [(str(guess), str(feedback)) for guess, feedback in raw]


def _extract_guess(completion):
    """Return the upper-cased text inside a trailing <guess>...</guess>, or None."""
    match = _GUESS_RE.search(_completion_text(completion))
    return match.group(1).strip().upper() if match else None


def _feedback_pattern(secret: str, guess: str) -> str:
    """Wordle feedback for `guess` against `secret` as a pattern string.

    '1' = right letter in the right position; '0' = letter in the secret at
    another position; 'x' = letter absent, or all its copies already used.
    Exact matches are assigned first, so repeated letters are not over-credited.
    """
    pattern = ["x"] * len(guess)
    unmatched = list(secret)
    for pos, (g, s) in enumerate(zip(guess, secret)):
        if g == s:
            pattern[pos] = "1"
            unmatched[pos] = None
    for pos, g in enumerate(guess):
        if pattern[pos] == "x" and g in unmatched:
            pattern[pos] = "0"
            unmatched[unmatched.index(g)] = None
    return "".join(pattern)


def _tokens_to_pattern(feedback: str) -> str:
    """Convert a feedback string like "S(-) T(x) O(x) R(-) M(x)" into a `_feedback_pattern` string."""
    return "".join(
        "1" if "✓" in token else ("0" if "-" in token else "x")
        for token in feedback.split()
    )


@lru_cache(maxsize=64)
def _consistent_candidates(words: tuple, history: tuple) -> tuple:
    """Words that would have produced exactly the recorded feedback for every past guess.

    Cached because all generations sampled for one prompt share the same
    word list and history.
    """
    expected = [(guess.upper(), _tokens_to_pattern(feedback)) for guess, feedback in history]
    return tuple(
        word
        for word in words
        if all(_feedback_pattern(word, guess) == pattern for guess, pattern in expected)
    )


def _normalized_info_gain(candidates, guess: str) -> float:
    """Expected entropy reduction of `guess` over `candidates`, divided by log2(#candidates).

    Returns 0.0 when there are fewer than two candidates (nothing left to learn).
    """
    total = len(candidates)
    if total <= 1:
        return 0.0
    buckets = Counter(_feedback_pattern(word, guess) for word in candidates)
    expected_entropy = sum((n / total) * math.log2(n) for n in buckets.values())
    current_entropy = math.log2(total)
    return (current_entropy - expected_entropy) / current_entropy


def output_format_check(prompts, completions, examples=None, **kwargs) -> List[float]:
    """Reward 1: strict <think>/<guess> format plus a valid guess.

    A "<think>" prefix is re-attached to the completion. The result must then be
    exactly "<think>...</think>", a newline, and "<guess>...</guess>".
    Per completion:
        0.0  format mismatch, or scoring raised
        0.1  format ok, but the guess is not 5 characters
        0.5  5-character guess not in the word list (or list unavailable)
        1.0  5-character guess found in the word list (case-insensitive)
    """
    rewards = []
    for i, completion in enumerate(completions):
        try:
            # The prompt is expected to already open the <think> block, so the
            # completion starts inside it; re-attach the tag before matching.
            match = _FORMAT_RE.search("<think>" + _completion_text(completion))
            if match is None:
                reward = 0.0
            else:
                guess = match.group(2).strip().upper()
                if len(guess) != 5:
                    reward = 0.1
                else:
                    words = _resolve_words(
                        _example_field(i, "word_list", examples, kwargs, default=None)
                    )
                    reward = 1.0 if guess in words else 0.5
        except Exception as exc:  # keep training alive, but make the failure visible
            _warn_reward_failure("output_format_check", exc)
            reward = 0.0
        rewards.append(reward)
    return rewards


def uses_previous_feedback(prompts, completions, examples=None, **kwargs) -> List[float]:
    """Reward 2: how well the guess respects feedback from earlier guesses.

    Per letter of a 5-character guess (compared case-insensitively):
        +0.2   letter at a position previously marked ✓
        +0.1   letter previously marked "-", moved to a new position
        -0.2   letter previously marked "-", repeated at the same position
        -0.5   letter previously marked x (and not covered above)
        +0.05  letter with no prior feedback
    A missing or wrong-length guess, or a scoring error, gives 0.0.
    """
    rewards = []
    for i, completion in enumerate(completions):
        reward = 0.0
        try:
            guess = _extract_guess(completion)
            if guess is None or len(guess) != 5:
                rewards.append(0.0)
                continue

            history = _parse_history(
                _example_field(i, "past_guess_history", examples, kwargs)
            )
            correct, valid, wrong = {}, {}, {}
            for _, feedback in history:
                # split() (not split(" ")) tolerates repeated whitespace.
                for pos, token in enumerate(feedback.split()):
                    letter = token[0].upper()
                    if "✓" in token:
                        correct.setdefault(letter, set()).add(pos)
                    elif "-" in token:
                        valid.setdefault(letter, set()).add(pos)
                    else:
                        wrong.setdefault(letter, set()).add(pos)

            for idx, letter in enumerate(guess):
                if letter in correct and idx in correct[letter]:
                    reward += 0.2
                elif letter in valid and idx not in valid[letter]:
                    reward += 0.1
                elif letter in valid and idx in valid[letter]:
                    reward -= 0.2
                elif letter in wrong:
                    reward -= 0.5
                else:
                    reward += 0.05
        except Exception as exc:
            _warn_reward_failure("uses_previous_feedback", exc)
            reward = 0.0
        rewards.append(reward)
    return rewards


def guess_value(prompts, completions, examples=None, **kwargs) -> List[float]:
    """Reward 3: normalized information gain of the guess, in [0, 1].

    Candidates are the word-list entries still consistent with every
    (guess, feedback) pair in `past_guess_history`. The full list is used when
    there is no history, or when no entry survives the filter. The reward is
    (H_before - E[H_after]) / H_before with H = log2(#candidates), so 1.0 means
    the guess's feedback would always pin down the secret uniquely.
    A missing or wrong-length guess, an unknown word list, or a scoring error
    gives 0.0.
    """
    rewards = []
    for i, completion in enumerate(completions):
        reward = 0.0
        try:
            guess = _extract_guess(completion)
            if guess is not None and len(guess) == 5:
                words = _resolve_words(
                    _example_field(i, "word_list", examples, kwargs, default=None)
                )
                history = _parse_history(
                    _example_field(i, "past_guess_history", examples, kwargs, default=())
                )
                # Fall back to the full list if history and list disagree
                # (e.g. the secret is missing from the list).
                candidates = _consistent_candidates(words, tuple(history)) or words
                reward = _normalized_info_gain(candidates, guess)
        except Exception as exc:
            _warn_reward_failure("guess_value", exc)
            reward = 0.0
        rewards.append(reward)
    return rewards

In [None]:
# GRPO run configuration.
training_args = GRPOConfig(
    output_dir="./wordle_grpo_lora",
    # optimisation schedule
    learning_rate=5e-5,
    max_steps=200,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=2,
    # completions sampled per prompt for GRPO's group comparison
    num_generations=4,
    # token budgets for the prompt and the generated completion
    max_prompt_length=512,
    max_completion_length=256,
    # checkpointing and logging
    save_steps=25,
    report_to="tensorboard",
)

In [None]:
# Combine the LoRA-wrapped policy, the data and the three reward signals.
reward_functions = [output_format_check, uses_previous_feedback, guess_value]
trainer = GRPOTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    processing_class=tokenizer,
    reward_funcs=reward_functions,
)

In [None]:
# Run GRPO training; the returned TrainOutput is displayed as the cell result.
trainer.train()

The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': None, 'pad_token_id': 151643}.


Step,Training Loss
10,0.0066
20,0.0229
30,-0.016
40,0.0109
50,-0.0368
60,0.0204
70,-0.0231
80,0.0225
90,-0.0399
100,-0.0243


TrainOutput(global_step=200, training_loss=0.004926497954875231, metrics={'train_runtime': 2436.467, 'train_samples_per_second': 0.328, 'train_steps_per_second': 0.082, 'total_flos': 0.0, 'train_loss': 0.004926497954875231})

In [None]:
# Persist the trained model (a PEFT/LoRA wrapper), which the inference section reloads as an adapter.
trainer.save_model(output_dir="./wordle_grpo/final_model")

In [None]:
base_model_name = "Qwen/Qwen2.5-7B-Instruct"
# Must match the directory passed to trainer.save_model. It is kept relative so
# it also resolves outside Colab's /content working directory.
adapter_dir = "./wordle_grpo/final_model"

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(base_model_name)

In [None]:
# Reload the base model in the same dtype used for training (torch_dtype="auto",
# i.e. the checkpoint's own dtype) instead of forcing float16, then attach the
# trained LoRA adapter. PeftModel is imported in the top imports cell.
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    torch_dtype="auto",
    device_map="auto",
)

# Attach LoRA adapters
model = PeftModel.from_pretrained(base_model, adapter_dir)

## Inference

In [None]:
# Inference prompt in ChatML layout (<|im_start|>/<|im_end|>).
# - The example feedback is consistent with the stated rules for secret PLANT.
# - The prompt ends by opening the assistant turn with "<think>", mirroring
#   training. The reward functions re-attach a "<think>" prefix to each
#   completion, implying the training prompts already end with it.
#   NOTE(review): confirm against dataset[0]["prompt"].
prompt = """<|im_start|>system
You are playing Wordle, a word-guessing game.

### Game Rules:
- You have **6 tries** to guess a secret **5-letter** word.
- Each guess must be a valid **5-letter English word**.
- After each guess, you will receive feedback indicating how close your guess was.

### Feedback Format:
Each letter in your guess will receive one of three symbols:
1. ✓ : The letter is in the word and in the CORRECT position.
2. - : The letter is in the word but in the WRONG position.
3. x : The letter is NOT in the word.

### Example:
Secret Word: PLANT

Guess 1: CRANE → Feedback: C(x) R(x) A(✓) N(✓) E(x)
Guess 2: STONE → Feedback: S(x) T(-) O(x) N(✓) E(x)
Guess 3: PLANT → Feedback: P(✓) L(✓) A(✓) N(✓) T(✓)

### Response Format:
Think through the problem and feedback step by step.
Make sure to include your reasoning inside <think></think> tags.
Then, return your final guess in this format: <guess> guessed-word </guess>.
<|im_end|>
<|im_start|>user
Make a new 5-letter word guess.

Here is some previous feedback:
Guess 1: SLATE → Feedback: S(x) L(-) A(x) T(x) E(-)
Guess 2: PRONG → Feedback: P(✓) R(x) O(x) N(x) G(x)
<|im_end|>
<|im_start|>assistant
<think>"""

In [None]:
# Sample one response from the fine-tuned (adapter) model. Only the newly
# generated tokens are decoded, so the long prompt is not echoed back.
torch.manual_seed(0)  # reproducible sampling
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

output = model.generate(
    **inputs,
    max_new_tokens=2000,
    temperature=0.7,
    top_p=0.9,
    do_sample=True,
    pad_token_id=tokenizer.eos_token_id,
)

prompt_len = inputs["input_ids"].shape[1]
decoded = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
print(decoded)


system
You are playing Wordle, a word-guessing game.

### Game Rules:
- You have **6 tries** to guess a secret **5-letter** word.
- Each guess must be a valid **5-letter English word**.
- After each guess, you will receive feedback indicating how close your guess was.

### Feedback Format:
Each letter in your guess will receive one of three symbols:
1. ✓ : The letter is in the word and in the CORRECT position.
2. - : The letter is in the word but in the WRONG position.
3. x : The letter is NOT in the word.

### Example:
Secret Word: PLANT

Guess 1: CRANE → Feedback: C(x) R(x) A(-) N(✓) E(x)
Guess 2: STONE → Feedback: S(x) T(x) O(x) N(✓) E(x)
Guess 3: PLANT → Feedback: P(✓) L(✓) A(✓) N(✓) T(✓)

### Response Format:
Think through the problem and feedback step by step.
Make sure to include your reasoning inside <think></think> tags.
Then, return your final guess in this format: <guess> guessed-word </guess>.

user
Make a new 5-letter word guess.

Here is some previous feedback:
Guess 1: 

In [None]:
# `merged_model` was previously used here without being defined. Create it by
# folding the LoRA adapter into the base weights (PEFT merge_and_unload), then
# sample from it.
# NOTE(review): merging modifies the wrapped base model in place, so `model`
# reflects the merged weights afterwards.
merged_model = model.merge_and_unload()

torch.manual_seed(0)  # reproducible sampling
inputs = tokenizer(prompt, return_tensors="pt").to(merged_model.device)

output = merged_model.generate(
    **inputs,
    max_new_tokens=1500,
    temperature=0.7,
    top_p=0.9,
    do_sample=True,
    pad_token_id=tokenizer.eos_token_id,
)

# Decode only the generated continuation, not the echoed prompt.
prompt_len = inputs["input_ids"].shape[1]
decoded = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
print(decoded)


system
You are playing Wordle, a word-guessing game.

### Game Rules:
- You have **6 tries** to guess a secret **5-letter** word.
- Each guess must be a valid **5-letter English word**.
- After each guess, you will receive feedback indicating how close your guess was.

### Feedback Format:
Each letter in your guess will receive one of three symbols:
1. ✓ : The letter is in the word and in the CORRECT position.
2. - : The letter is in the word but in the WRONG position.
3. x : The letter is NOT in the word.

### Example:
Secret Word: PLANT

Guess 1: CRANE → Feedback: C(x) R(x) A(-) N(✓) E(x)
Guess 2: STONE → Feedback: S(x) T(x) O(x) N(✓) E(x)
Guess 3: PLANT → Feedback: P(✓) L(✓) A(✓) N(✓) T(✓)

### Response Format:
Think through the problem and feedback step by step.
Make sure to include your reasoning inside <think></think> tags.
Then, return your final guess in this format: <guess> guessed-word </guess>.

user
Make a new 5-letter word guess.

Here is some previous feedback:
Guess 1: 