In [16]:
!pip install datasets



In [17]:
import os
import re
import pandas as pd
import numpy as np
from datasets import Dataset
from transformers import (
    GPT2LMHeadModel,
    GPT2Tokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
    GenerationConfig
)
import kagglehub
import math
import re
from tqdm.auto import tqdm

In [18]:
# Download the dataset
path = kagglehub.dataset_download("thedevastator/grade-school-math-8k-q-a")
print("Path to dataset files:", path)

Path to dataset files: /root/.cache/kagglehub/datasets/thedevastator/grade-school-math-8k-q-a/versions/2


In [19]:
# ----------------------------
# Load the Tokenizer
# ----------------------------
tokenizer = GPT2Tokenizer.from_pretrained('gpt2-large', padding_side='left')

# Add/modify special tokens to handle chat format better
special_tokens_dict = {
    'pad_token': '[PAD]',
    'bos_token': '<|im_start|>',
    'eos_token': '<|im_end|>',
    'sep_token': '<|im_sep|>',
    'additional_special_tokens': [
        '####'
    ]
}
tokenizer.add_special_tokens(special_tokens_dict)

# ----------------------------
# Load the Model
# ----------------------------
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.resize_token_embeddings(len(tokenizer))

# Configure model settings
model.config.pad_token_id = tokenizer.pad_token_id
model.config.bos_token_id = tokenizer.bos_token_id
model.config.eos_token_id = tokenizer.eos_token_id

# Set up generation config
generation_config = GenerationConfig(
    max_new_tokens=512,
    do_sample=False,
    pad_token_id=tokenizer.pad_token_id,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    repetition_penalty=1.1,
)
model.generation_config = generation_config

In [20]:
# Load and combine all CSV files
base_path = "/root/.cache/kagglehub/datasets/thedevastator/grade-school-math-8k-q-a/versions/2"
files = ['main_train.csv', 'main_test.csv']

max_tokens = 0
max_example = None
max_file = None

for file in files:
    df = pd.read_csv(os.path.join(path, file))

    # Count tokens for questions and answers separately
    question_tokens = df['question'].apply(lambda x: len(tokenizer.encode(str(x))))
    answer_tokens = df['answer'].apply(lambda x: len(tokenizer.encode(str(x))))

    print(f"\n{file}:")
    print(f"Questions - Max tokens: {question_tokens.max()}, Mean: {question_tokens.mean():.1f}")
    print(f"Answers - Max tokens: {answer_tokens.max()}, Mean: {answer_tokens.mean():.1f}")

    # Find max examples
    max_q_idx = question_tokens.idxmax()
    max_a_idx = answer_tokens.idxmax()

    print(f"\nLongest question ({question_tokens.max()} tokens):")
    print(df['question'].iloc[max_q_idx])
    print(f"\nLongest answer ({answer_tokens.max()} tokens):")
    print(df['answer'].iloc[max_a_idx])


main_train.csv:
Questions - Max tokens: 213, Mean: 55.2
Answers - Max tokens: 342, Mean: 95.3

Longest question (213 tokens):
Hasan is packing up his apartment because he’s moving across the country for a new job. He needs to ship several boxes to his new home. The movers have asked that Hasan avoid putting more than a certain weight in pounds in any cardboard box. The moving company has helpfully provided Hasan with a digital scale that will alert him if a package is too heavy. Hasan is in the kitchen, and he fills a cardboard box with 38 dinner plates. When he checks the box, the scale reports his box is too heavy. Hasan knows each of his plates weighs 10 ounces. He removes a single plate from the box and checks the movers’ scale again. The scale reports his box is still too heavy. Hasan repeats the process again and again. When he has removed enough plates, the movers’ scale shows the box is now an acceptable weight for shipping. Hasan deduces that each shipping box can hold 20 pou

In [21]:
# Load datasets
train_data = pd.read_csv(os.path.join(base_path, "main_train.csv"))
test_data = pd.read_csv(os.path.join(base_path, "main_test.csv"))

# Count tokens per question for curriculum learning
train_data["question_tokens"] = train_data["question"].apply(
    lambda x: len(tokenizer.encode(str(x)))
)

# Select 10% shortest questions for SFT
sft_data = train_data.nsmallest(int(0.1 * len(train_data)), "question_tokens")
grpo_data = train_data.drop(sft_data.index)  # Save rest for GRPO
grpo_data = grpo_data.sample(frac=0.025, random_state=42)

def format_answer(answer):
    """Format answer with proper tokens"""
    answer = answer.strip()
    # Ensure #### is properly spaced
    answer = re.sub(r'####(\S)', r'#### \1', answer)
    # Replace calculation format if needed
    answer = re.sub(r'<<(.*?)>>', r'<<\1>>', answer)
    if not answer.endswith(tokenizer.eos_token):
        answer += f" {tokenizer.eos_token}"
    return answer

def preprocess_function(examples):
    """Format input-output pairs with proper tokens"""
    combined_texts = []
    for q, a in zip(examples["question"], examples["answer"]):
        formatted_answer = format_answer(a)
        # Use proper tokens for chat format
        prompt = (f"{tokenizer.bos_token}{q.strip()}"
                 f"{tokenizer.sep_token}{formatted_answer}")
        combined_texts.append(prompt)

    # Tokenize with proper padding and truncation
    return tokenizer(
        combined_texts,
        truncation=True,
        padding=True,
        max_length=512,
        return_tensors="pt",
        return_attention_mask=True,
        return_token_type_ids=False  # Not needed for GPT-2
    )

# Create datasets for SFT<>
sft_train_dataset = Dataset.from_pandas(sft_data)
sft_train_dataset = sft_train_dataset.map(
    preprocess_function,
    batched=True,
    remove_columns=sft_train_dataset.column_names  # Remove original columns
)

# Create validation dataset
test_data = test_data.sample(frac=0.1, random_state=42)
sft_val_dataset = Dataset.from_pandas(test_data)
sft_val_dataset = sft_val_dataset.map(
    preprocess_function,
    batched=True,
    remove_columns=sft_val_dataset.column_names
)

# Save GRPO data for later
grpo_dataset = Dataset.from_pandas(grpo_data)

Map:   0%|          | 0/747 [00:00<?, ? examples/s]

Map:   0%|          | 0/132 [00:00<?, ? examples/s]

In [22]:
print(tokenizer.eos_token, tokenizer.eos_token_id)

<|im_end|> 50259


In [23]:
# 5) Training setup
training_args = TrainingArguments(
    output_dir="./results",
    evaluation_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=4,
    eval_accumulation_steps=4,
    num_train_epochs=2,
    weight_decay=0.01,
    logging_steps=1,
    logging_first_step=True,
    dataloader_drop_last=True,
    save_strategy="epoch"  # Add this to save the model after SFT
)

data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False
)

# 6) Custom metric & utilities
def remove_excess_after_answer(text):
    """Removes everything after the first '#### <answer>' (with a space)."""
    match = re.search(r"####\s+(\S+)", text)  # Requires at least one space
    if match:
        prefix = text[:match.start()]
        return prefix + "#### " + match.group(1)  # Re-add the space for consistency
    return text

def custom_compute_metrics(eval_preds):
    """Computes accuracy based on '#### <answer>' (number before eos)."""
    logits, labels = eval_preds
    pred_ids = np.argmax(logits, axis=-1)
    decoded_preds = tokenizer.batch_decode(pred_ids, skip_special_tokens=False)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=False)

    correct = 0
    total = len(decoded_preds)

    for pred_text, label_text in zip(decoded_preds, decoded_labels):
        match_pred = re.search(r"####\s*(\d+)\s*<\|im_end\|>", pred_text) # Number before eos
        match_label = re.search(r"####\s*(\d+)\s*<\|im_end\|>", label_text) # Number before eos

        if match_pred and match_label:
            answer_pred = match_pred.group(1).strip()
            answer_label = match_label.group(1).strip()
            if answer_pred == answer_label:
                correct += 1

    accuracy = correct / total if total > 0 else 0
    return {"custom_accuracy": accuracy}


trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=sft_train_dataset,
    eval_dataset=sft_val_dataset,
    data_collator=data_collator,
    compute_metrics=custom_compute_metrics
)

# 7) Train
trainer.train()



Epoch,Training Loss,Validation Loss,Custom Accuracy
1,4.448,3.477013,0.0
2,3.7417,3.154286,0.0


TrainOutput(global_step=46, training_loss=4.739146398461384, metrics={'train_runtime': 81.9947, 'train_samples_per_second': 18.221, 'train_steps_per_second': 0.561, 'total_flos': 153998991360000.0, 'train_loss': 4.739146398461384, 'epoch': 2.0})

In [24]:
# 8) Free-form inference (Batched)
batch_size = 16
eos_token_id = tokenizer.eos_token_id  # Using proper eos token

questions = test_data["question"].tolist()
references = test_data["answer"].tolist()

rows = []
num_batches = math.ceil(len(questions) / batch_size)

for b_idx in tqdm(range(num_batches), desc="Batched Inference"):
    start_idx = b_idx * batch_size
    end_idx = start_idx + batch_size

    batch_questions = questions[start_idx:end_idx]
    batch_refs = references[start_idx:end_idx]

    # Build prompt with proper tokens
    batch_prompts = [
        f"{tokenizer.bos_token}{q}{tokenizer.sep_token}"
        for q in batch_questions
    ]

    # Rest of tokenization and generation stays the same
    encoded_batch = tokenizer(
        batch_prompts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512
    )
    encoded_batch = {k: v.to(model.device) for k, v in encoded_batch.items()}

    # Generate
    output_ids = model.generate(
        input_ids=encoded_batch["input_ids"],
        attention_mask=encoded_batch["attention_mask"],
        eos_token_id=eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        max_new_tokens=512,
        do_sample=False,
    )

    # Rest remains the same
    for q_text, ref, out_ids in zip(batch_questions, batch_refs, output_ids):
        gen_text = tokenizer.decode(out_ids)
        cleaned_prediction = remove_excess_after_answer(gen_text)

        rows.append({
            "question": q_text,
            "prediction": cleaned_prediction,
            "reference": ref
        })

inference_df = pd.DataFrame(rows)
inference_df.to_csv("test_predictions_freeform.csv", index=False)
print("Saved batched free-form generation results to 'test_predictions_freeform.csv'")

Batched Inference:   0%|          | 0/9 [00:00<?, ?it/s]

Saved batched free-form generation results to 'test_predictions_freeform.csv'


In [25]:
import torch
from torch.optim import AdamW
from typing import List, Dict, Any,Tuple
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import List, Callable

In [26]:
class SimpleGRPOTrainer:
    def __init__(
        self,
        model: AutoModelForCausalLM,
        ref_model: AutoModelForCausalLM,
        tokenizer: AutoTokenizer,
        reward_funcs: List[Callable],
        num_generations: int = 4,
        learning_rate: float = 1e-6,
        beta: float = 0.04,
        max_prompt_length: int = 512,
        max_completion_length: int = 256,
        temperature: float = 0.9
    ):
        self.model = model
        self.ref_model = ref_model
        self.tokenizer = tokenizer
        self.reward_funcs = reward_funcs
        self.num_generations = num_generations
        self.optimizer = AdamW(model.parameters(), lr=learning_rate)
        self.beta = beta
        self.max_prompt_length = max_prompt_length
        self.max_completion_length = max_completion_length
        self.temperature = temperature

        # Generation config
        self.generation_config = GenerationConfig(
            max_new_tokens=self.max_completion_length,
            do_sample=True,
            temperature=self.temperature,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id
        )

    def _prepare_inputs(self, prompts: List[dict]) -> dict:
        """Format prompts with math-specific tokens and prepare for model"""
        # Format prompts with our special tokens
        prompt_texts = [
            f"{self.tokenizer.bos_token}{p['prompt']}{self.tokenizer.sep_token}"
            for p in prompts
        ]

        # Tokenize with right padding
        prompt_inputs = self.tokenizer(
            prompt_texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.max_prompt_length,
            add_special_tokens=False
        ).to(self.model.device)

        return prompt_inputs

    def generate_completions(self, prompts: List[dict]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Generate multiple completions per prompt with proper masking and padding"""
        prompt_inputs = self._prepare_inputs(prompts)
        prompt_ids = prompt_inputs["input_ids"]
        prompt_mask = prompt_inputs["attention_mask"]

        all_completions = []
        completion_masks = []

        # Generate multiple times per prompt
        for _ in range(self.num_generations):
            with torch.no_grad():
                outputs = self.model.generate(
                    input_ids=prompt_ids,
                    attention_mask=prompt_mask,
                    generation_config=self.generation_config,
                    return_dict_in_generate=True,
                    output_scores=True,
                    early_stopping=True,
                )

                completion_ids = outputs.sequences[:, prompt_ids.shape[1]:]

                # Find max length in this batch of generations
                max_len = completion_ids.shape[1]

                # Create completion mask (handling EOS tokens)
                is_eos = completion_ids == self.tokenizer.eos_token_id
                eos_indices = torch.argmax(is_eos.float(), dim=1)
                eos_indices[~is_eos.any(1)] = completion_ids.shape[1] - 1
                completion_mask = torch.arange(completion_ids.shape[1], device=completion_ids.device)[None, :] <= eos_indices[:, None]

                all_completions.append(completion_ids)
                completion_masks.append(completion_mask)

        # Find the maximum length across all generations
        max_length = max(comp.shape[1] for comp in all_completions)

        # Pad all completions to the maximum length
        padded_completions = []
        padded_masks = []

        for comp, mask in zip(all_completions, completion_masks):
            if comp.shape[1] < max_length:
                # Pad completions
                padding = torch.full(
                    (comp.shape[0], max_length - comp.shape[1]),
                    self.tokenizer.pad_token_id,
                    dtype=comp.dtype,
                    device=comp.device
                )
                padded_comp = torch.cat([comp, padding], dim=1)

                # Pad masks
                mask_padding = torch.zeros(
                    (mask.shape[0], max_length - mask.shape[1]),
                    dtype=mask.dtype,
                    device=mask.device
                )
                padded_mask = torch.cat([mask, mask_padding], dim=1)
            else:
                padded_comp = comp
                padded_mask = mask

            padded_completions.append(padded_comp)
            padded_masks.append(padded_mask)

        return torch.stack(padded_completions, dim=1), torch.stack(padded_masks, dim=1)

    def compute_rewards(self, prompts: List[dict], completions: torch.Tensor) -> torch.Tensor:
        """Compute and normalize rewards group-wise"""
        # Flatten completions for reward computation
        batch_size = completions.shape[0]
        num_generations = completions.shape[1]

        # Create flattened list of prompts to match completion structure
        flattened_prompts = []
        for prompt in prompts:
            flattened_prompts.extend([prompt] * num_generations)

        # Convert completions to texts
        completion_texts = []
        for i in range(batch_size):
            for j in range(num_generations):
                text = self.tokenizer.decode(completions[i,j], skip_special_tokens=False)
                completion_texts.append(text)

        # Get rewards from all functions
        total_rewards = torch.zeros(len(completion_texts), device=self.model.device)
        for reward_func in self.reward_funcs:
            rewards = reward_func(prompts=flattened_prompts, completions=completion_texts)
            total_rewards += torch.tensor(rewards, device=self.model.device)

        # Reshape rewards back to (batch_size, num_generations)
        rewards = total_rewards.view(batch_size, num_generations)

        # Normalize group-wise
        mean_rewards = rewards.mean(dim=1, keepdim=True)
        std_rewards = rewards.std(dim=1, keepdim=True)
        advantages = (rewards - mean_rewards) / (std_rewards + 1e-8)

        return advantages

    def compute_logprobs(self, model: AutoModelForCausalLM, input_ids: torch.Tensor, attention_mask: torch.Tensor, logits_to_keep: int) -> torch.Tensor:
        """Compute per-token log probabilities with proper masking"""
        with torch.set_grad_enabled(model is self.model):
            # Get logits for completion portion
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            logits = outputs.logits[:, :-1, :]  # Remove last prediction

            # Get target tokens
            targets = input_ids[:, 1:]  # Shift right for next token prediction

            # Keep only completion portion with proper size
            logits = logits[:, -logits_to_keep:]
            targets = targets[:, -logits_to_keep:]

            # Compute log probabilities
            log_probs = torch.log_softmax(logits, dim=-1)
            token_log_probs = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)

            return token_log_probs

    def train_step(self, batch: List[dict]):
        # Generate completions and get masks
        completions, completion_masks = self.generate_completions(batch)

        # Get rewards and advantages
        advantages = self.compute_rewards(batch, completions)

        # Prepare for loss computation
        total_loss = 0

        for i in range(completions.shape[0]):  # For each prompt
            for j in range(completions.shape[1]):  # For each completion
                # Concatenate prompt and completion
                prompt_ids = self._prepare_inputs(batch)["input_ids"][i:i+1]  # Keep batch dimension
                completion_ids = completions[i,j:j+1]  # Keep batch dimension
                completion_mask = completion_masks[i,j:j+1]

                input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
                attention_mask = torch.cat([
                    torch.ones_like(prompt_ids),
                    completion_mask
                ], dim=1)

                logits_to_keep = completion_ids.size(1)

                # Get log probs for both policy and reference model
                policy_log_probs = self.compute_logprobs(
                    self.model, input_ids, attention_mask, logits_to_keep
                )

                with torch.no_grad():
                    ref_log_probs = self.compute_logprobs(
                        self.ref_model, input_ids, attention_mask, logits_to_keep
                    )

                # Make sure all tensors have the same size
                policy_log_probs = policy_log_probs[:, :logits_to_keep-1]  # Adjust for next token prediction
                ref_log_probs = ref_log_probs[:, :logits_to_keep-1]
                mask = completion_mask[:, :logits_to_keep-1]  # Adjust mask to match

                # Compute KL divergence with size-matched tensors
                kl = torch.exp(ref_log_probs - policy_log_probs) - (ref_log_probs - policy_log_probs) - 1

                # Compute policy gradient loss with KL penalty
                per_token_loss = torch.exp(policy_log_probs - policy_log_probs.detach()) * advantages[i,j]
                per_token_loss = -(per_token_loss - self.beta * kl)

                # Average loss over non-padded tokens using matched sizes
                loss = ((per_token_loss * mask).sum(dim=1) / (mask.sum(dim=1) + 1e-8)).mean()
                total_loss += loss

        # Average loss across batch and optimize
        loss = total_loss / (completions.shape[0] * completions.shape[1])
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.item()

In [27]:
def extract_final_answer(text: str) -> str | None:
    """Extract the final answer after ####"""
    if "####" not in text:
        return None
    answer = text.split("####")[1].strip()
    return answer.replace(",", "").replace("$", "").strip()

def extract_calculations(text: str) -> List[tuple[str, str]]:
    """Extract all calculations within <<>> tags, returning (expression, result) pairs"""
    pattern = r"<<(.*?)=(.*?)>>"
    matches = re.finditer(pattern, text)
    return [(match.group(1).strip(), match.group(2).strip()) for match in matches]

def evaluate_expression(expr: str) -> float:
    """Safely evaluate a mathematical expression"""
    # Remove any commas and convert to plain numbers
    expr = expr.replace(",", "").replace("$", "")
    try:
        return float(eval(expr))
    except:
        return float('nan')

def format_reward_func(completions: List[str], **kwargs) -> List[float]:
    """
    Reward proper formatting:
    - Presence of calculations in <<>> format
    - Presence of #### for final answer
    - Presence of the EOS token (e.g., <|im_end|>) at the end of the completion
    - General structure
    """
    rewards = []
    for completion in completions:
        score = 0.0

        # Check for presence of calculations in <<>> format
        if re.search(r"<<.*?=.*?>>", completion):
            score += 0.2

        # Check for the final answer marker '####'
        if "####" in completion:
            score += 0.2

            # Check if there's exactly one '####'
            if completion.count("####") == 1:
                score += 0.1

            # Check if calculations come before the final answer marker
            parts = completion.split("####")
            if len(parts) == 2 and "<<" in parts[0]:
                score += 0.2

        # Check for the presence of the EOS token at the end.
        # Either hardcode "<|im_end|>" or use tokenizer.eos_token if available.
        eos_token = "<|im_end|>"  # or, if available: eos_token = tokenizer.eos_token
        if completion.strip().endswith(eos_token):
            score += 0.2

        rewards.append(score)

    return rewards

def calculation_reward_func(completions: List[str], **kwargs) -> List[float]:
    """
    Reward correct intermediate calculations:
    - Each calculation within <<>> must be mathematically correct
    - Format should be expression=result
    """
    rewards = []
    for completion in completions:
        score = 0.0
        calculations = extract_calculations(completion)

        if not calculations:
            rewards.append(0.0)
            continue

        total_calcs = len(calculations)
        correct_calcs = 0

        for expr, result in calculations:
            try:
                expected = evaluate_expression(expr)
                actual = evaluate_expression(result)
                if abs(expected - actual) < 0.01:  # Allow small floating point differences
                    correct_calcs += 1
            except:
                continue

        if total_calcs > 0:
            score = (correct_calcs / total_calcs) * 0.5  # Max 0.5 points for calculations

        rewards.append(score)

    return rewards

def answer_correctness_reward_func(completions: List[str], answer: List[str], **kwargs) -> List[float]:
    """
    Reward correct final answer:
    - Answer after #### must match the expected answer
    - Handles number formatting (commas, spaces, etc.)
    """
    rewards = []
    for completion, expected in zip(completions, answer):
        score = 0.0

        extracted = extract_final_answer(completion)
        if extracted is not None and expected is not None:
            # Clean up both answers for comparison
            extracted = extracted.replace(" ", "").replace(",", "").replace("$", "").strip()
            expected = str(expected).replace(" ", "").replace(",", "").replace("$", "").strip()

            try:
                # Try numerical comparison first
                if abs(float(extracted) - float(expected)) < 0.01:
                    score = 1.0
            except:
                # Fallback to string comparison
                if extracted == expected:
                    score = 1.0

        rewards.append(score)

    return rewards

def combined_reward_func(prompts: List[dict], completions: List[str], **kwargs) -> List[float]:
    """Modified signature to match GRPO expectations"""
    answers = [p["answer"] for p in prompts]  # Extract answers from prompts
    format_rewards = format_reward_func(completions)
    calc_rewards = calculation_reward_func(completions)
    answer_rewards = answer_correctness_reward_func(completions, answers)

    combined_rewards = []
    for f, c, a in zip(format_rewards, calc_rewards, answer_rewards):
        reward = (f * 0.05) + (c * 0.05) + (a * 0.9)
        combined_rewards.append(reward)

    return combined_rewards

In [28]:
import random
import copy

In [29]:
import gc
gc.collect()
torch.cuda.empty_cache()
with torch.cuda.device('cuda'):
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()

In [30]:
print("Starting GRPO training...")
grpo_model = model  # use the SFT-trained model
grpo_ref_model = copy.deepcopy(model)
grpo_ref_model.eval()  # ensure the reference model is in evaluation mode
for param in grpo_ref_model.parameters():
    param.requires_grad = False

# ------------------------------------------------------------------------------
# Prepare GRPO dataset for training
# ------------------------------------------------------------------------------
# grpo_dataset was created earlier as Dataset.from_pandas(grpo_data)
# Convert the dataset to a pandas DataFrame, rename "question" to "prompt",
# and then convert it to a list of dictionaries.
grpo_df = grpo_dataset.to_pandas()
grpo_df.rename(columns={"question": "prompt"}, inplace=True)
grpo_examples = grpo_df.to_dict(orient="records")
# Each element in grpo_examples is now a dict like: {"prompt": <question>, "answer": <answer>}

# ------------------------------------------------------------------------------
# Create the GRPO trainer
# ------------------------------------------------------------------------------
# (Assuming SimpleGRPOTrainer and all reward functions were defined earlier)
grpo_trainer = SimpleGRPOTrainer(
    model=grpo_model,
    ref_model=grpo_ref_model,
    tokenizer=tokenizer,
    reward_funcs=[combined_reward_func],  # use your combined reward function
    num_generations=8,
    learning_rate=1e-6,
    beta=0.04,
    temperature=0.9,
)

# ------------------------------------------------------------------------------
# GRPO Training Loop with tqdm
# ------------------------------------------------------------------------------
num_epochs = 1      # Adjust the number of GRPO epochs as needed
batch_size = 1      # GRPO batch size (number of prompts per step)

print("Starting GRPO training...")
for epoch in tqdm(range(num_epochs), desc="Epochs"):
    # Shuffle the GRPO examples at the beginning of each epoch
    random.shuffle(grpo_examples)
    num_batches = math.ceil(len(grpo_examples) / batch_size)

    # Use tqdm for batch-level progress within the epoch.
    for batch_idx in tqdm(range(num_batches), desc="Batches", leave=False):
        start = batch_idx * batch_size
        batch = grpo_examples[start: start + batch_size]
        loss = grpo_trainer.train_step(batch)
        # Using tqdm.write to log without interfering with the progress bar.
        tqdm.write(f"Epoch {epoch+1}/{num_epochs} - Batch {batch_idx+1}/{num_batches} - Loss: {loss:.4f}")

print("GRPO training completed.")

# ------------------------------------------------------------------------------
# Save the GRPO-trained model and tokenizer
# ------------------------------------------------------------------------------
output_dir = "./grpo_model"
os.makedirs(output_dir, exist_ok=True)
grpo_model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
print(f"GRPO-trained model saved to '{output_dir}'.")

Starting GRPO training...
Starting GRPO training...


Epochs:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/168 [00:00<?, ?it/s]



Epoch 1/1 - Batch 1/168 - Loss: 0.0000
Epoch 1/1 - Batch 2/168 - Loss: -0.0000
Epoch 1/1 - Batch 3/168 - Loss: 0.0000
Epoch 1/1 - Batch 4/168 - Loss: 0.0000
Epoch 1/1 - Batch 5/168 - Loss: 0.0000
Epoch 1/1 - Batch 6/168 - Loss: 0.0000
Epoch 1/1 - Batch 7/168 - Loss: 0.0000
Epoch 1/1 - Batch 8/168 - Loss: 0.0000
Epoch 1/1 - Batch 9/168 - Loss: 0.0000
Epoch 1/1 - Batch 10/168 - Loss: 0.0000
Epoch 1/1 - Batch 11/168 - Loss: 0.0000
Epoch 1/1 - Batch 12/168 - Loss: 0.0000
Epoch 1/1 - Batch 13/168 - Loss: 0.0000
Epoch 1/1 - Batch 14/168 - Loss: 0.0000
Epoch 1/1 - Batch 15/168 - Loss: 0.0000
Epoch 1/1 - Batch 16/168 - Loss: 0.0000
Epoch 1/1 - Batch 17/168 - Loss: 0.0000
Epoch 1/1 - Batch 18/168 - Loss: 0.0000
Epoch 1/1 - Batch 19/168 - Loss: 0.0000
Epoch 1/1 - Batch 20/168 - Loss: 0.0000
Epoch 1/1 - Batch 21/168 - Loss: 0.0000
Epoch 1/1 - Batch 22/168 - Loss: 0.0000
Epoch 1/1 - Batch 23/168 - Loss: 0.0000
Epoch 1/1 - Batch 24/168 - Loss: 0.0000
Epoch 1/1 - Batch 25/168 - Loss: 0.0000
Epoch 1/

In [32]:
# --------------------------------------------------------------------------
# Inference Evaluation after GRPO training
# --------------------------------------------------------------------------
# Use the GRPO-trained model to generate predictions on the test set.
# Here we use batched generation with the same special tokens.
test_questions = test_data["question"].tolist()
test_references = test_data["answer"].tolist()

inference_rows = []
inference_batch_size = 32  # adjust as needed
num_inference_batches = math.ceil(len(test_questions) / inference_batch_size)

print("Running inference on test data...")
for b_idx in tqdm(range(num_inference_batches), desc="GRPO Inference"):
    start_idx = b_idx * inference_batch_size
    end_idx = start_idx + inference_batch_size
    batch_questions = test_questions[start_idx:end_idx]
    batch_refs = test_references[start_idx:end_idx]

    # Build prompts using the same chat format as before.
    batch_prompts = [
        f"{tokenizer.bos_token}{q}{tokenizer.sep_token}"
        for q in batch_questions
    ]

    encoded_batch = tokenizer(
        batch_prompts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=1024,
    )
    # Move inputs to the model device
    encoded_batch = {k: v.to(grpo_model.device) for k, v in encoded_batch.items()}

    # Generate completions using the GRPO-trained model.
    output_ids = grpo_model.generate(
        input_ids=encoded_batch["input_ids"],
        attention_mask=encoded_batch["attention_mask"],
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        max_new_tokens=1024,
        do_sample=False,  # or True if you prefer sampling,
        early_stopping=True,
    )

    # Decode and clean predictions.
    for q_text, ref_text, out_ids in zip(batch_questions, batch_refs, output_ids):
        gen_text = tokenizer.decode(out_ids, skip_special_tokens=False)
        cleaned_prediction = remove_excess_after_answer(gen_text)
        inference_rows.append({
            "question": q_text,
            "prediction": cleaned_prediction,
            "reference": ref_text
        })

# Save inference results as CSV.
inference_df = pd.DataFrame(inference_rows)
csv_filename = "grpo_test_predictions.csv"
inference_df.to_csv(csv_filename, index=False)
print(f"Saved GRPO inference predictions to '{csv_filename}'.")

Running inference on test data...


GRPO Inference:   0%|          | 0/5 [00:00<?, ?it/s]

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [None]:
# Process GRPO dataset the same way as before
grpo_train_dataset = Dataset.from_pandas(grpo_data)
grpo_train_dataset = grpo_train_dataset.map(
    preprocess_function,
    batched=True,
    remove_columns=grpo_train_dataset.column_names
)

# Training arguments for continued training
continued_training_args = TrainingArguments(
    output_dir="./results_continued",
    evaluation_strategy="epoch",
    learning_rate=5e-5,  # Same as before
    per_device_train_batch_size=16,
    per_device_eval_batch_size=4,
    eval_accumulation_steps=4,
    num_train_epochs=1,
    weight_decay=0.01,
    logging_steps=1,
    logging_first_step=True,
    dataloader_drop_last=True,
    save_strategy="epoch"
)

# Initialize new trainer with GRPO data
continued_trainer = Trainer(
    model=trainer.model,  # Use the previously trained model
    args=continued_training_args,
    train_dataset=grpo_train_dataset,
    eval_dataset=sft_val_dataset,  # Keep same validation dataset
    data_collator=data_collator,
    compute_metrics=custom_compute_metrics
)

# Continue training
print("Starting continued training with GRPO data...")
continued_trainer.train()

# Run inference with the continued model
batch_size = 2
eos_token_id = tokenizer.eos_token_id

questions = test_data["question"].tolist()
references = test_data["answer"].tolist()

rows = []
num_batches = math.ceil(len(questions) / batch_size)

for b_idx in tqdm(range(num_batches), desc="Continued Model Inference"):
    start_idx = b_idx * batch_size
    end_idx = start_idx + batch_size

    batch_questions = questions[start_idx:end_idx]
    batch_refs = references[start_idx:end_idx]

    # Build prompts
    batch_prompts = [
        f"{tokenizer.bos_token}{q}{tokenizer.sep_token}"
        for q in batch_questions
    ]

    encoded_batch = tokenizer(
        batch_prompts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512
    )
    encoded_batch = {k: v.to(model.device) for k, v in encoded_batch.items()}

    # Generate
    output_ids = model.generate(
        input_ids=encoded_batch["input_ids"],
        attention_mask=encoded_batch["attention_mask"],
        eos_token_id=eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        max_new_tokens=512,
        do_sample=False,
    )

    for q_text, ref, out_ids in zip(batch_questions, batch_refs, output_ids):
        gen_text = tokenizer.decode(out_ids)
        cleaned_prediction = remove_excess_after_answer(gen_text)

        rows.append({
            "question": q_text,
            "prediction": cleaned_prediction,
            "reference": ref
        })

# Save results
continued_inference_df = pd.DataFrame(rows)
continued_inference_df.to_csv("test_predictions_continued.csv", index=False)
print("Saved continued model predictions to 'test_predictions_continued.csv'")

# Compare initial vs continued results
initial_df = pd.read_csv("test_predictions_freeform.csv")
continued_df = pd.read_csv("test_predictions_continued.csv")

print("\nResults Comparison:")
print("Initial Model Results:")
initial_accuracy = custom_compute_metrics({"predictions": initial_df["prediction"].tolist(),
                                         "references": initial_df["reference"].tolist()})
print(f"Accuracy: {initial_accuracy['custom_accuracy']:.4f}")

print("\nContinued Model Results:")
continued_accuracy = custom_compute_metrics({"predictions": continued_df["prediction"].tolist(),
                                           "references": continued_df["reference"].tolist()})
print(f"Accuracy: {continued_accuracy['custom_accuracy']:.4f}")