## 1. Setup Environment

In [1]:
# Install required packages
# !pip install -q "torch==2.5.1" --index-url https://download.pytorch.org/whl/cu121
# !pip install -q "transformers==4.48.1" "datasets==3.1.0" "accelerate==1.3.0" "trl==0.14.0"
# !pip install -q "peft==0.14.0" "bitsandbytes==0.45.0" "python-chess"

In [2]:
# from huggingface_hub import login

# Login to Hugging Face
# login(token="", add_to_git_credential=True)  # ADD YOUR TOKEN HERE

## 2. Load and Prepare Chess Puzzle Dataset

In [None]:
import json
from datasets import Dataset
from transformers import AutoTokenizer

model_name = "unsloth/Qwen3-4B-Instruct-2507"

# Load your extracted puzzles
puzzle_file = "data/processed/chess_puzzles_2248.jsonl"

puzzles = []
with open(puzzle_file, 'r') as f:
    for line in f:
        puzzles.append(json.loads(line))

print(f"Loaded {len(puzzles)} chess puzzle positions")

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Sample puzzle
print("\nSample puzzle:")
print(json.dumps(puzzles[0], indent=2))

Loaded 2248 chess puzzle positions

Sample puzzle:
{
  "input": {
    "fen": "r6k/pp2r2p/4Rp1Q/3p4/8/1N1P2R1/PqP2bPP/7K b - - 0 24",
    "legal_moves": [
      "a8g8",
      "a8f8",
      "a8e8",
      "a8d8",
      "a8c8",
      "a8b8",
      "e7e8",
      "e7g7",
      "e7f7",
      "e7d7",
      "e7c7",
      "e7e6",
      "f2b6",
      "f2c5",
      "f2d4",
      "f2g3",
      "f2e3",
      "f2g1",
      "f2e1",
      "b2e5",
      "b2d4",
      "b2c3",
      "b2b3",
      "b2a3",
      "b2c2",
      "b2a2",
      "b2c1",
      "b2b1",
      "b2a1",
      "b7b6",
      "a7a6",
      "f6f5",
      "d5d4",
      "b7b5",
      "a7a5"
    ],
    "side_to_move": "Black"
  },
  "output": {
    "move": "f2g3",
    "solution": [
      "f2g3",
      "e6e7",
      "b2b1",
      "b3c1",
      "b1c1",
      "h6c1"
    ]
  },
  "metadata": {
    "puzzle_id": "00008",
    "rating": 1877,
    "themes": [
      "crushing",
      "hangingPiece",
      "long",
      "middlegame"
    ],
    "game_url

## 3. Create Prompt Format for Chess

We'll use the same format as your prompt.py: `<rationale>` and `<uci_move>` tags.

In [4]:
def format_chess_prompt(puzzle_data):
    """
    Format a chess puzzle into GRPO training format.
    Includes a thinking prefix to encourage reasoning.
    """
    inp = puzzle_data['input']
    
    # System message
    system_msg = "You are a chess expert. Analyze positions carefully and find the best tactical move."
    
    # User prompt
    user_msg = f"""Analyze this chess position and find the BEST move.

Position (FEN): {inp['fen']}
Side to move: {inp['side_to_move']}
Legal moves: {' '.join(inp['legal_moves'])}

Explain your reasoning in <rationale> tags, then provide the move in <uci_move> tags.
Example: <rationale>Fork attacking king and queen</rationale><uci_move>f2g3</uci_move>"""
    
    # Prefill assistant to start reasoning
    assistant_prefix = "Let me analyze this position.\n<rationale>"
    
    messages = [
        {"role": "system", "content": system_msg},
        {"role": "user", "content": user_msg},
        {"role": "assistant", "content": assistant_prefix}
    ]
    
    return {
        "prompt": tokenizer.apply_chat_template(
            messages, 
            tokenize=False, 
            continue_final_message=True
        ),
        "correct_move": puzzle_data['output']['move'],
        "legal_moves": inp['legal_moves'],
        "puzzle_id": puzzle_data['metadata']['puzzle_id'],
        "rating": puzzle_data['metadata']['rating']
    }

# Test the formatting
sample = format_chess_prompt(puzzles[0])
print("Sample prompt:")
print(sample['prompt'])
print(f"\nCorrect move: {sample['correct_move']}")

Sample prompt:
<|im_start|>system
You are a chess expert. Analyze positions carefully and find the best tactical move.<|im_end|>
<|im_start|>user
Analyze this chess position and find the BEST move.

Position (FEN): r6k/pp2r2p/4Rp1Q/3p4/8/1N1P2R1/PqP2bPP/7K b - - 0 24
Side to move: Black
Legal moves: a8g8 a8f8 a8e8 a8d8 a8c8 a8b8 e7e8 e7g7 e7f7 e7d7 e7c7 e7e6 f2b6 f2c5 f2d4 f2g3 f2e3 f2g1 f2e1 b2e5 b2d4 b2c3 b2b3 b2a3 b2c2 b2a2 b2c1 b2b1 b2a1 b7b6 a7a6 f6f5 d5d4 b7b5 a7a5

Explain your reasoning in <rationale> tags, then provide the move in <uci_move> tags.
Example: <rationale>Fork attacking king and queen</rationale><uci_move>f2g3</uci_move><|im_end|>
<|im_start|>assistant
<think>

</think>

Let me analyze this position.
<rationale>

Correct move: f2g3


In [5]:
# Convert all puzzles to GRPO format
formatted_puzzles = [format_chess_prompt(p) for p in puzzles]

# Create Hugging Face dataset
dataset = Dataset.from_list(formatted_puzzles)

# Split train/test
train_test_split = dataset.train_test_split(test_size=0.1, seed=42)
train_dataset = train_test_split["train"]
test_dataset = train_test_split["test"]

print(f"Train: {len(train_dataset)} | Test: {len(test_dataset)}")

Train: 2023 | Test: 225


## 4. Define Reward Functions

We'll create 3 reward functions:
1. **Format Reward**: Correct XML tags `<rationale>...</rationale><uci_move>...</uci_move>`
2. **Legality Reward**: Move is in the legal moves list
3. **Correctness Reward**: Move matches the puzzle solution

In [6]:
import re

def format_reward_func(completions, **kwargs):
    """
    Reward for correct format: <rationale>...</rationale><uci_move>...</uci_move>
    """
    rewards = []
    
    for completion in completions:
        try:
            # Add synthetic <rationale> since it's prefilled
            text = "<rationale>" + completion
            
            # Check format
            regex = r"<rationale>([^<]*(?:<(?!/?rationale>)[^<]*)*)</rationale>\s*<uci_move>([^<]+)</uci_move>"
            match = re.search(regex, text, re.DOTALL)
            
            if match and len(match.groups()) == 2:
                rewards.append(1.0)
            else:
                rewards.append(0.0)
        except:
            rewards.append(0.0)
    
    return rewards


def legality_reward_func(completions, legal_moves, **kwargs):
    """
    Reward if the move is legal.
    """
    rewards = []
    
    for completion, legal in zip(completions, legal_moves):
        try:
            text = "<rationale>" + completion
            
            # Extract move
            match = re.search(r"<uci_move>([^<]+)</uci_move>", text)
            if not match:
                rewards.append(0.0)
                continue
            
            move = match.group(1).strip()
            
            # Check if legal
            if move in legal:
                rewards.append(1.0)
            else:
                rewards.append(-1.0)  # Penalize illegal moves
        except:
            rewards.append(0.0)
    
    return rewards


def correctness_reward_func(completions, correct_move, **kwargs):
    """
    Reward if the move is correct (matches puzzle solution).
    """
    rewards = []
    
    for completion, correct in zip(completions, correct_move):
        try:
            text = "<rationale>" + completion
            
            # Extract move
            match = re.search(r"<uci_move>([^<]+)</uci_move>", text)
            if not match:
                rewards.append(-0.5)
                continue
            
            move = match.group(1).strip()
            
            # Check correctness
            if move == correct:
                rewards.append(3.0)  # High reward for correct solution
            else:
                rewards.append(-0.5)  # Small penalty for wrong move
        except:
            rewards.append(-0.5)
    
    return rewards

### Test Reward Functions

In [7]:
# Test samples
correct_sample = """I see that the bishop on f2 can capture the rook on g3, winning material.</rationale>
<uci_move>f2g3</uci_move>"""

wrong_format = """The best move is f2g3 because it wins the rook."""

illegal_move = """Moving the pawn forward.</rationale>
<uci_move>z9z9</uci_move>"""

# Test
test_completions = [correct_sample, wrong_format, illegal_move]
test_legal_moves = [["f2g3", "a8g8", "e7e8"]] * 3
test_correct = ["f2g3"] * 3

print("Format rewards:", format_reward_func(test_completions))
print("Legality rewards:", legality_reward_func(test_completions, legal_moves=test_legal_moves))
print("Correctness rewards:", correctness_reward_func(test_completions, correct_move=test_correct))

Format rewards: [1.0, 0.0, 1.0]
Legality rewards: [1.0, 0.0, -1.0]
Correctness rewards: [3.0, -0.5, -0.5]


## 5. Setup GRPO Training

In [None]:
from trl import GRPOConfig, GRPOTrainer, ModelConfig, get_peft_config

# Model config
model_config = ModelConfig(
    model_name_or_path=model_name,
    dtype="bfloat16",
    attn_implementation="flash_attention_2",
    use_peft=True,
    load_in_4bit=True,
)

# GRPO Training config
training_args = GRPOConfig(
    output_dir="chess-grpo-qwen3",
    learning_rate=5e-7,
    lr_scheduler_type="cosine",
    logging_steps=10,
    max_steps=200,  # Start small for testing
    per_device_train_batch_size=1,
    gradient_accumulation_steps=2,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    bf16=True,
    
    # GRPO specific
    max_prompt_length=512,
    max_completion_length=256,
    num_generations=2,  # Generate 2 solutions per puzzle
    beta=0.001,  # KL coefficient
    
    # Logging
    report_to="tensorboard",
    logging_dir="./logs",
)

print("Config ready!")



Config ready!




## 6. Create Trainer and Start Training

In [9]:
trainer = GRPOTrainer(
    model=model_config.model_name_or_path,
    reward_funcs=[
        format_reward_func,
        legality_reward_func,
        correctness_reward_func
    ],
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    peft_config=get_peft_config(model_config),
)

print("Trainer created successfully!")

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

[2025-12-26 16:27:59] INFO spawn.py:77: gcc -fno-strict-overflow -Wsign-compare -DNDEBUG -g -O3 -Wall -fPIC -c /tmp/tmp3wsv0mi2/test.c -o /tmp/tmp3wsv0mi2/test.o
[2025-12-26 16:27:59] INFO spawn.py:77: gcc /tmp/tmp3wsv0mi2/test.o -laio -o /tmp/tmp3wsv0mi2/a.out
/usr/bin/ld: cannot find -laio: No such file or directory
collect2: error: ld returned 1 exit status
[2025-12-26 16:27:59] INFO spawn.py:77: gcc -fno-strict-overflow -Wsign-compare -DNDEBUG -g -O3 -Wall -fPIC -c /tmp/tmphynfpoel/test.c -o /tmp/tmphynfpoel/test.o
[2025-12-26 16:27:59] INFO spawn.py:77: gcc /tmp/tmphynfpoel/test.o -L/usr -L/usr/lib64 -lcufile -o /tmp/tmphynfpoel/a.out
/usr/bin/ld: cannot find -lcufile: No such file or directory
collect2: error: ld returned 1 exit status


Trainer created successfully!


In [None]:
# Start training
trainer.train()

# Save the model
# trainer.save_model(training_args.output_dir)
print(f"Model saved to {training_args.output_dir}")

The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': None}.
`generation_config` default values have been modified to match model-specific defaults: {'max_length': 40960}. If this is not desired, please set these values explicitly.


Step,Training Loss
5,0.0304
10,0.0
15,0.0892
20,-0.0241
25,-0.0779
30,0.0
35,0.0
40,0.0107
45,0.0118
50,-0.0216


Model saved to chess-grpo-qwen3-8b


## 7. Test the Trained Model

In [None]:
# Test on a few puzzles
import torch
from transformers import AutoModelForCausalLM

# Load the trained model
model = AutoModelForCausalLM.from_pretrained(
    training_args.output_dir,
    dtype=torch.bfloat16,
    device_map="auto"
)

# Test on 3 puzzles
for i in range(3):
    puzzle = puzzles[i]
    prompt = format_chess_prompt(puzzle)['prompt']
    
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=256,
            temperature=0.7,
            do_sample=True
        )
    
    response = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
    
    print(f"\n{'='*80}")
    print(f"Puzzle {i+1} (Rating: {puzzle['metadata']['rating']})")
    print(f"Correct move: {puzzle['output']['move']}")
    print(f"{'='*80}")
    print(response)
    print()

`torch_dtype` is deprecated! Use `dtype` instead!


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]


Puzzle 1 (Rating: 1877)
Correct move: f2g3
Black has a strong queen on h5 and a rook on g8. The key is to find a move that maximizes pressure. The move e7e6 opens up the e-file and supports the rook on g8. It also prepares to develop the bishop on c8. This move is both tactical and strategic, improving piece activity and controlling key squares.</rationale>
<uci_move>e7e6</uci_move>


Puzzle 2 (Rating: 1877)
Correct move: b2b1
Black's king is in a dangerous position with limited mobility. The most promising move is to play b7b5, which opens up the diagonal for the knight and prepares to develop the queen. This move also threatens to create a fork with the queen and knight.</rationale>
<uci_move>b7b5</uci_move>


Puzzle 3 (Rating: 1877)
Correct move: b1c1
Black has a strong attacking position with the queen and rook on the queenside. The key is to find a move that maximizes pressure and creates threats. The move g3h4 is a strong candidate as it threatens to deliver checkmate on h5 or h