In [2]:
import torch

In [3]:
import reasoning_gym

In [1]:
from dataclasses import dataclass
from typing import Optional
import torch

@dataclass
class GRPOConfig:
    """Hyperparameters for GRPO (Group Relative Policy Optimization) training.

    Field declaration order is kept stable because the notebook prints the
    configuration by iterating over ``config.__dict__``.

    Raises:
        ValueError: from ``__post_init__`` when a count/interval field is not
            positive, when ``learning_rate`` is not positive, or when
            ``clip_epsilon`` / ``kl_beta`` is negative.
    """

    # Model and training hyperparameters
    model_name: str = "Qwen/Qwen2.5-0.5B-Instruct"  # Hugging Face checkpoint id
    learning_rate: float = 3e-4
    batch_size: int = 32       # Questions sampled per training step
    num_updates: int = 1000    # Outer training-loop iterations
    max_steps: int = 20        # Inner steps per update
    n_outputs: int = 4         # Completions sampled per question (one GRPO group)
    max_length: int = 256      # Forwarded to generate() as max_length
    grpo_iterations: int = 4  # Number of GRPO iterations per update

    # GRPO specific hyperparameters
    clip_epsilon: float = 0.2  # PPO clipping parameter
    kl_beta: float = 0.02      # KL divergence coefficient

    # Training configuration
    # (not wired up yet)
    # gradient_accumulation_steps: int = 1
    # warmup_steps: int = 100
    # max_grad_norm: float = 1.0
    seed: int = 42

    # Dataset configuration
    dataset_name: str = "syllogism"
    dataset_size: int = 1000

    # Device configuration
    # NOTE: this default is evaluated once, when the class body is executed.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    # Optimization
    adam_epsilon: float = 1e-8
    weight_decay: float = 0.01

    # Logging and saving
    log_interval: int = 10
    save_interval: int = 100
    eval_interval: int = 50

    # Generation parameters
    temperature: float = 1.0
    top_p: float = 0.9
    top_k: int = 50
    do_sample: bool = True

    def __post_init__(self):
        """Fail fast on values that would only break later inside the training loop."""
        # Counts and intervals are used as loop bounds, sample sizes, or modulo
        # divisors (e.g. ``i % log_interval``), so zero/negative values are invalid.
        positive_fields = (
            "batch_size",
            "num_updates",
            "max_steps",
            "n_outputs",
            "max_length",
            "grpo_iterations",
            "dataset_size",
            "log_interval",
            "save_interval",
            "eval_interval",
        )
        for name in positive_fields:
            current = getattr(self, name)
            if current <= 0:
                raise ValueError(f"{name} must be positive, got {current!r}")

        if self.learning_rate <= 0:
            raise ValueError(f"learning_rate must be positive, got {self.learning_rate!r}")

        for name in ("clip_epsilon", "kl_beta"):
            current = getattr(self, name)
            if current < 0:
                raise ValueError(f"{name} must be non-negative, got {current!r}")


# Initialize configuration
config = GRPOConfig()

In [2]:
# Print every configuration field as an aligned "name: value" table
separator = "=" * 50
print("GRPO Configuration:")
print(separator)
for name, setting in vars(config).items():
    print(f"{name:<25}: {setting}")
print(separator)

GRPO Configuration:
model_name               : Qwen/Qwen2.5-0.5B-Instruct
learning_rate            : 0.0003
batch_size               : 32
num_updates              : 1000
max_steps                : 20
n_outputs                : 4
max_length               : 256
grpo_iterations          : 4
clip_epsilon             : 0.2
kl_beta                  : 0.02
seed                     : 42
dataset_name             : syllogism
dataset_size             : 1000
device                   : cuda
adam_epsilon             : 1e-08
weight_decay             : 0.01
log_interval             : 10
save_interval            : 100
eval_interval            : 50
temperature              : 1.0
top_p                    : 0.9
top_k                    : 50
do_sample                : True


In [5]:
# Load a small LLM using configuration

# Load model directly
from transformers import AutoTokenizer, AutoModelForCausalLM

# Fetch the tokenizer and the causal-LM weights for the configured checkpoint
checkpoint = config.model_name
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint)

# Reuse EOS as the padding token when the tokenizer does not define one
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Total number of scalar weights across all parameter tensors
param_count = sum(weights.numel() for weights in model.parameters())

print(f"Model loaded: {config.model_name}")
print(f"Device: {config.device}")
print(f"Model parameters: {param_count:,}")


Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.


Model loaded: Qwen/Qwen2.5-0.5B-Instruct
Device: cuda
Model parameters: 494,032,768


In [6]:
# Build the procedural dataset described by the configuration
dataset = reasoning_gym.create_dataset(
    config.dataset_name,
    size=config.dataset_size,
    seed=config.seed,
)

print(f"Dataset loaded: {config.dataset_name}")
print(f"Dataset size: {config.dataset_size}")
print("Sample data point:")
# Preview the first entry only, then stop iterating
for data in dataset:
    print(f"Question: {data['question']}")
    print(f"Answer: {data['answer']}")
    break


Dataset loaded: syllogism
Dataset size: 1000
Sample data point:
Question: Consider these statements:
1. No students are humans
2. All humans are chefs

Does it logically follow that:
Some chefs are humans?
(Answer Yes or No)
Answer: Yes


In [8]:
# Smoke-test generation on the first dataset entry only (the loop exits after one pass)
for data in dataset:
    print(data)

    # NOTE: the reference answer is appended to the question, so the model sees it
    input_text = " ".join([data['question'], data['answer']])
    inputs = tokenizer(input_text, return_tensors='pt')

    # max_length=50 is the cap passed to generate() for this quick check
    outputs = model.generate(**inputs, max_length=50)

    # Turn the first returned sequence back into text, dropping special tokens
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

    print(f"Input: {input_text}\nGenerated: {generated_text}\n")
    break

{'question': 'Consider these statements:\n1. No students are humans\n2. All humans are chefs\n\nDoes it logically follow that:\nSome chefs are humans?\n(Answer Yes or No)', 'answer': 'Yes', 'metadata': {'source_dataset': 'syllogism', 'source_index': 0, 'premise1': 'No students are humans', 'premise2': 'All humans are chefs', 'selected_premise': 2, 'conclusion': 'Some chefs are humans', 'is_valid': True, 'type': 'inversion'}}
Input: Consider these statements:
1. No students are humans
2. All humans are chefs

Does it logically follow that:
Some chefs are humans?
(Answer Yes or No) Yes
Generated: Consider these statements:
1. No students are humans
2. All humans are chefs

Does it logically follow that:
Some chefs are humans?
(Answer Yes or No) Yes.

Let's break this down step-by-step:

1. The first



In [None]:
class GRPO:
    """Bundle a causal LM, its tokenizer and the GRPO loss/advantage helpers.

    Attributes:
        model: the policy model; moved to ``config.device`` on construction.
        tokenizer: tokenizer matching ``model``.
        config: object exposing the hyperparameters read by the methods below
            (``device``, ``clip_epsilon``, ``max_length``, ``temperature``,
            ``top_p``, ``top_k``, ``do_sample``).
    """

    # The annotation is quoted so the class can be defined without GRPOConfig
    # having to be resolvable at definition time.
    def __init__(self, model, tokenizer, config: "GRPOConfig"):
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.config = config

        # Move model to device
        self.model.to(config.device)

    @staticmethod
    def _to_float_tensor(values):
        """Return ``values`` as a floating-point tensor (bool/int inputs are cast)."""
        tensor = torch.as_tensor(values)
        return tensor if tensor.is_floating_point() else tensor.float()

    def compute_kl_term(self, ref_log_probs, current_model_log_probs):
        """Compute the element-wise KL regularisation term.

        With ``d = ref_log_probs - current_model_log_probs`` the result is
        ``exp(d) - d - 1``, which is non-negative and zero when both agree.

        The difference is taken directly in log space. Going through
        ``log(exp(x))`` underflows to ``log(0)`` for very negative log-probs
        and turns the result into NaN/inf.

        Args:
            ref_log_probs: log-probabilities under the reference model.
            current_model_log_probs: log-probabilities under the current model.

        Returns:
            Tensor with the broadcast shape of the two inputs.
        """
        log_ratio = ref_log_probs - current_model_log_probs
        return torch.exp(log_ratio) - log_ratio - 1

    def loss_function(self, old_log_probs, old_model_log_probs, advantages):
        """GRPO loss function with clipping.

        The probability ratio is ``exp(old_log_probs - old_model_log_probs)``,
        i.e. the first argument is the numerator policy and the second the
        denominator policy. The ratio is clipped to
        ``[1 - clip_epsilon, 1 + clip_epsilon]`` and the pessimistic (minimum)
        surrogate is averaged.

        Args:
            old_log_probs: log-probs forming the numerator of the ratio.
            old_model_log_probs: log-probs forming the denominator of the ratio.
            advantages: advantages broadcastable against the log-probs.

        Returns:
            Scalar tensor: the negated mean clipped surrogate objective.
        """
        eps = self.config.clip_epsilon
        ratio = (old_log_probs - old_model_log_probs).exp()
        surrogate1 = ratio * advantages
        surrogate2 = torch.clamp(ratio, 1 - eps, 1 + eps) * advantages
        return -torch.min(surrogate1, surrogate2).mean()

    def compute_advantages(self, scores, rewards):
        """Normalise ``scores`` by the mean and standard deviation of ``rewards``.

        Inputs are cast to floating point first, because ``mean``/``std`` are
        undefined for bool and integer tensors. A group with fewer than two
        rewards has no defined sample standard deviation, so it is treated as
        zero instead of propagating NaN.

        Args:
            scores: values to normalise (tensor or sequence).
            rewards: group rewards providing the statistics (tensor or sequence).

        Returns:
            Floating-point tensor ``(scores - mean) / (std + 1e-8)``.
        """
        scores = self._to_float_tensor(scores)
        rewards = self._to_float_tensor(rewards)

        if rewards.numel() > 1:
            spread = rewards.std()
        else:
            spread = rewards.new_zeros(())
        return (scores - rewards.mean()) / (spread + 1e-8)

    def generate_text(self, input_text, max_length=None):
        """Sample one completion for ``input_text`` using the config's generation settings.

        The question is wrapped in a prompt asking for ``<think>`` / ``<answer>``
        tags. Only the newly generated tokens are decoded and returned. The
        prompt itself contains a literal ``<answer> answer here </answer>``
        example, which would otherwise be picked up by tag extraction on the
        returned text.

        Args:
            input_text: the user question inserted into the prompt template.
            max_length: value passed to ``generate`` as ``max_length``; falls
                back to ``config.max_length`` when omitted.

        Returns:
            The decoded completion (prompt excluded, special tokens skipped).
        """

        prompt = '''
        
        A conversation between User and Assistant. The User asks a question, and the Assistant provides an answer.
        The assistant first thinks about the reasoning process in mind, and then provide the user with the answer. The reasoning process and the answer are to be enclosed within <think> and </think>, <answer> </answer> tags respectively.
        i.e, <think> reasoning process here </think> <answer> answer here </answer>. User {}. Assistant:  
        
        '''

        final_input = prompt.format(input_text)

        max_length = max_length or self.config.max_length
        inputs = self.tokenizer(final_input, return_tensors='pt').to(self.config.device)
        prompt_token_count = inputs["input_ids"].shape[1]

        with torch.no_grad():
            # `padding` is a tokenizer argument, not a generate() argument, so it
            # is not forwarded here.
            outputs = self.model.generate(
                **inputs,
                max_length=max_length,
                temperature=self.config.temperature,
                top_p=self.config.top_p,
                top_k=self.config.top_k,
                do_sample=self.config.do_sample,
            )

        # The returned sequence starts with the prompt tokens; keep what follows.
        completion_ids = outputs[0][prompt_token_count:]
        return self.tokenizer.decode(completion_ids, skip_special_tokens=True)

In [9]:
#Dataset Creation and Sampling

# Inspect the dataset object: its class and its public attribute names
print("Dataset type:", type(dataset))
public_names = [name for name in dir(dataset) if not name.startswith('_')]
print("Available methods:", public_names)

# Since procedural datasets don't have .sample(), we need to use random sampling
import random

def sample_from_dataset(dataset, n):
    """Draw up to ``n`` distinct entries from ``dataset`` at random.

    Objects exposing ``__iter__`` are materialised into a list first; anything
    else is used as-is. When fewer than ``n`` entries exist, all of them are
    returned in their original order; otherwise ``random.sample`` picks ``n``
    entries without replacement.
    """
    pool = dataset
    if hasattr(dataset, '__iter__'):
        pool = list(dataset)

    # Enough entries available: sample without replacement
    if len(pool) >= n:
        return random.sample(pool, n)

    # Too few entries: hand back everything
    return pool

# Draw five entries as a quick demonstration of the sampler
batch = sample_from_dataset(dataset, 5)
print(f"Sampled batch of {len(batch)} items:")
for position, item in enumerate(batch, start=1):
    print(f"Item {position}: {item}")


Dataset type: <class 'reasoning_gym.logic.syllogisms.SyllogismDataset'>
Available methods: ['DEFAULT_TERMS', 'category', 'config', 'score_answer', 'seed', 'size', 'terms']
Sampled batch of 5 items:
Item 1: {'question': 'Consider these statements:\n1. All horses are tigers\n2. All tigers are musicians\n\nDoes it logically follow that:\nSome horses are musicians?\n(Answer Yes or No)', 'answer': 'Yes', 'metadata': {'source_dataset': 'syllogism', 'source_index': 92, 'premise1': 'All horses are tigers', 'premise2': 'All tigers are musicians', 'conclusion': 'Some horses are musicians', 'is_valid': True, 'type': 'syllogism'}}
Item 2: {'question': 'Consider these statements:\n1. Some cats are musicians\n2. Some musicians are adults\n\nDoes it logically follow that:\nSome musicians are cats?\n(Answer Yes or No)', 'answer': 'Yes', 'metadata': {'source_dataset': 'syllogism', 'source_index': 934, 'premise1': 'Some cats are musicians', 'premise2': 'Some musicians are adults', 'selected_premise': 1,

In [None]:
# Initialize GRPO trainer
from tqdm import tqdm
import torch.optim as optim
import numpy as np
import random
import re

# Seed the RNGs that drive batch sampling (random) and token sampling (torch),
# so a fresh run of this cell is reproducible from config.seed.
random.seed(config.seed)
torch.manual_seed(config.seed)

# Build the trainer first: GRPO.__init__ moves the model to config.device, and
# the optimizer should be constructed only after the model has been moved.
grpo = GRPO(model, tokenizer, config)

# Initialize optimizer (epsilon now comes from the config instead of being
# left at the library default)
optimizer = optim.Adam(
    model.parameters(),
    lr=config.learning_rate,
    eps=config.adam_epsilon,
    weight_decay=config.weight_decay,
)

# Convert dataset to list for sampling
dataset_list = list(dataset)
print(f"Total dataset size: {len(dataset_list)}")

def sample_batch(dataset_list, batch_size):
    """Pick a random batch, without replacement, capped at the dataset's length."""
    available = len(dataset_list)
    draw_count = batch_size if batch_size < available else available
    return random.sample(dataset_list, draw_count)

def extract_answer_with_regex(text):
    """Return the stripped content of the first <answer>...</answer> pair, or "" if absent."""
    # Lazy quantifier ends the capture at the nearest closing tag;
    # DOTALL lets the capture span line breaks.
    found = re.search(r'<answer>(.*?)</answer>', text, flags=re.DOTALL)
    return found.group(1).strip() if found else ""

def extract_thinking_with_regex(text):
    """Return the stripped content of the first <think>...</think> pair, or "" if absent."""
    found = re.search(r'<think>(.*?)</think>', text, flags=re.DOTALL)
    if found is None:
        return ""
    return found.group(1).strip()

# Training loop: for every update, run `max_steps` rollout steps. Each step
# samples a batch of questions, draws `n_outputs` completions per question
# (one GRPO group), scores them, and turns the rewards into group-relative
# advantages.
for i in tqdm(range(config.num_updates), desc="Training"):

    for step in range(config.max_steps):
        # Sample a batch of data using our custom sampling function
        batch = sample_batch(dataset_list, config.batch_size)

        for data in batch:
            # Completions are collected per question so that rewards from one
            # question never leak into another question's group statistics.
            completions = [
                grpo.generate_text(data['question'], max_length=config.max_length)
                for _ in range(config.n_outputs)
            ]

            # Binary reward: 1.0 when the extracted answer gets a perfect score.
            # score_answer takes the candidate answer and the full dataset
            # entry (not just the reference answer string).
            rewards = torch.tensor(
                [
                    float(
                        dataset.score_answer(
                            answer=extract_answer_with_regex(completion),
                            entry=data,
                        ) == 1.0
                    )
                    for completion in completions
                ],
                dtype=torch.float32,
            )

            # Float rewards: mean/std are undefined for bool tensors.
            advantages = grpo.compute_advantages(rewards, rewards)

            # TODO(policy update): grpo.loss_function needs the log-probs of
            # each completion under the current policy and under the policy
            # that sampled it; nothing in this notebook computes them yet.
            # Once they are available, run `config.grpo_iterations`
            # optimisation steps for this group:
            #     loss = grpo.loss_function(new_log_probs, old_log_probs, advantages)
            #     optimizer.zero_grad()
            #     loss.backward()
            #     optimizer.step()

    if i % config.log_interval == 0:
        print(f"Update {i}")

print("Training completed!")