In [3]:
import torch
import torch.nn as nn
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoConfig,
)
import os
from typing import List, Tuple
from contextlib import contextmanager

# Set the device to GPU if available, otherwise use CPU
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class GRPOTrainer:
    """
    GRPOTrainer implements a "DeepSeek/GRPO"-style batched training approach for language models.
    """

    def __init__(self, model_name="gpt2", lr=1e-5, epsilon=0.2, kl_coef=0.01, checkpoint_dir="checkpoints"):
        """
        Initializes the GRPOTrainer with the specified parameters.

        :param model_name: Name of the Hugging Face model to use.
        :param lr: Learning rate for the new policy.
        :param epsilon: PPO clipping parameter.
        :param kl_coef: Coefficient for KL divergence penalty.
        :param checkpoint_dir: Directory to save model checkpoints.
        """
        self.checkpoint_dir = checkpoint_dir
        os.makedirs(self.checkpoint_dir, exist_ok=True)  # Create checkpoint directory if it doesn't exist

        # Load model configuration and model
        self.config = AutoConfig.from_pretrained(model_name)
        self.new_policy = AutoModelForCausalLM.from_pretrained(model_name, config=self.config).to(DEVICE)
        self.old_policy = self.new_policy  # Use the same instance for memory efficiency
        self.old_policy.eval()  # Set old policy to evaluation mode

        # Initialize tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token  # Set pad token if not defined

        # Set parameters
        self.epsilon = epsilon
        self.kl_coef = kl_coef
        self.optimizer = torch.optim.AdamW(self.new_policy.parameters(), lr=lr)  # AdamW optimizer
        self.ce_loss_fct = nn.CrossEntropyLoss(reduction="sum")  # Cross-entropy loss for log probabilities

    def train_on_batch(self, prompts: List[str], correct_answers: List[str], group_size=4, max_new_tokens=20, temperature=1.0) -> dict:
        """
        Perform one batched training step on a list of prompts.

        :param prompts: List of input prompts.
        :param correct_answers: List of correct answers corresponding to the prompts.
        :param group_size: Number of answers to generate for each prompt.
        :param max_new_tokens: Maximum number of new tokens to generate.
        :param temperature: Sampling temperature for generation.
        :return: A dictionary containing training metrics.
        """
        old_samples = []  # Store samples generated from the old policy
        with self._swap_models_temporarily(self.old_policy):  # Temporarily swap models
            for prompt, correct_ans in zip(prompts, correct_answers):
                # Generate answers using the old policy
                answers, logprobs_old = self._generate_answers(prompt, group_size, max_new_tokens, temperature, requires_grad=False)
                rewards = [self._reward_func(a, correct_ans) for a in answers]  # Calculate rewards
                old_samples.append((prompt, answers, logprobs_old, rewards))  # Store results

        # Initialize lists for log probabilities and rewards
        all_logprobs_old, all_logprobs_new, all_rewards = [], [], []
        total_correct, total_samples = 0, 0  # Initialize counters for correct answers and total samples

        for (prompt, answers, logprobs_old, rewards) in old_samples:
            # Compute new log probabilities using the new policy
            _, logprobs_new = self._compute_logprobs(prompt, answers, model=self.new_policy, requires_grad=True)

            all_logprobs_old.extend(logprobs_old)  # Collect old log probabilities
            all_logprobs_new.extend(logprobs_new.detach().cpu().tolist())  # Collect new log probabilities
            all_rewards.extend(rewards)  # Collect rewards

            total_correct += sum(r > 0.99 for r in rewards)  # Count correct rewards
            total_samples += len(rewards)  # Update total samples

        # Convert lists to tensors for computation
        logprobs_old_t = torch.tensor(all_logprobs_old, device=DEVICE)
        logprobs_new_t = torch.tensor(all_logprobs_new, device=DEVICE, requires_grad=True)
        rewards_t = torch.tensor(all_rewards, device=DEVICE, dtype=torch.float)

        # Compute advantages for the current batch
        advantages_t = self._compute_advantages(rewards_t, group_size, prompts)

        # Calculate the ratio of new to old log probabilities
        ratio_t = torch.exp(logprobs_new_t - logprobs_old_t)
        clipped_ratio_t = torch.clamp(ratio_t, 1.0 - self.epsilon, 1.0 + self.epsilon)  # Clip the ratio
        obj_min = torch.minimum(ratio_t, clipped_ratio_t) * advantages_t  # Objective function

        kl = (logprobs_new_t - logprobs_old_t).mean()  # Compute KL divergence
        pg_loss = -obj_min.mean() + self.kl_coef * kl  # Calculate policy gradient loss

        self.optimizer.zero_grad()  # Reset gradients
        pg_loss.backward()  # Backpropagate loss
        torch.nn.utils.clip_grad_norm_(self.new_policy.parameters(), 1.0)  # Clip gradients
        self.optimizer.step()  # Update parameters

        self.refresh_old_policy()  # Refresh old policy parameters

        avg_reward = rewards_t.mean().item()  # Calculate average reward
        success_rate = float(total_correct) / float(total_samples) if total_samples else 0.0  # Calculate success rate

        return {
            "avg_reward": avg_reward,
            "success_rate": success_rate,
            "kl_div": kl.item(),
            "pg_loss": pg_loss.item(),
        }

    def _compute_advantages(self, rewards_t, group_size, prompts):
        """
        Compute advantages for the current batch based on rewards.

        :param rewards_t: Tensor of rewards for the current batch.
        :param group_size: Number of answers generated for each prompt.
        :param prompts: List of input prompts.
        :return: Tensor of advantages.
        """
        advantages_list = []  # List to store advantages for each prompt
        offset = 0  # Initialize offset for slicing rewards

        for _ in prompts:
            chunk_rewards = rewards_t[offset: offset + group_size]  # Get rewards for the current prompt
            r_mean = chunk_rewards.mean()  # Calculate mean reward
            r_std = chunk_rewards.std() + 1e-6  # Calculate standard deviation with a small epsilon to avoid division by zero
            chunk_adv = (chunk_rewards - r_mean) / r_std  # Normalize advantages
            advantages_list.append(chunk_adv)  # Append to advantages list
            offset += group_size  # Update offset for the next prompt

        return torch.cat(advantages_list, dim=0)  # Concatenate advantages into a single tensor

    def refresh_old_policy(self):
        """
        Refresh the old policy parameters by copying the new policy parameters.
        This is done after each training step to keep the old policy updated.
        """
        for old_param, new_param in zip(self.old_policy.parameters(), self.new_policy.parameters()):
            old_param.data.copy_(new_param.data)  # Copy new policy parameters to old policy

    def save_checkpoint(self, step: int):
        """
        Save the current model and optimizer state to a checkpoint.

        :param step: Current training step for naming the checkpoint.
        """
        ckpt_path = os.path.join(self.checkpoint_dir, f"new_policy_step_{step}")
        self.new_policy.save_pretrained(ckpt_path)  # Save model weights
        torch.save(self.optimizer.state_dict(), os.path.join(ckpt_path, "optimizer.pt"))  # Save optimizer state
        print(f"Checkpoint saved to {ckpt_path}")

    def _generate_answers(self, prompt: str, group_size: int, max_new_tokens: int, temperature: float, requires_grad: bool) -> Tuple[List[str], List[float]]:
        """
        Generate answers for a given prompt using the current policy.

        :param prompt: Input prompt for which to generate answers.
        :param group_size: Number of answers to generate.
        :param max_new_tokens: Maximum number of new tokens to generate.
        :param temperature: Sampling temperature for generation.
        :param requires_grad: Whether to compute gradients for the generated answers.
        :return: A tuple containing the generated answers and their log probabilities.
        """
        model = self.new_policy
        model.eval()  # Set model to evaluation mode

        with torch.no_grad():  # Disable gradient computation
            input_ids = self.tokenizer(prompt, return_tensors="pt", padding=True).input_ids.to(DEVICE)  # Tokenize prompt
            gen_outputs = model.generate(
                input_ids,
                do_sample=True,
                top_k=0,
                temperature=temperature,
                num_return_sequences=group_size,
                max_new_tokens=max_new_tokens,
                pad_token_id=self.tokenizer.pad_token_id  # Use pad token for padding
            )

        answers = []  # List to store generated answers
        for seq_idx in range(group_size):
            seq_ids = gen_outputs[seq_idx]  # Get the generated sequence
            full_text = self.tokenizer.decode(seq_ids, skip_special_tokens=True)  # Decode the sequence to text
            ans = full_text[len(prompt):].strip()  # Extract the answer by removing the prompt
            answers.append(ans)  # Append the answer to the list

        # Compute log probabilities for the generated answers
        _, logprobs_tensor = self._compute_logprobs(prompt, answers, model, requires_grad=requires_grad)
        return answers, logprobs_tensor.detach().cpu().tolist()  # Return answers and log probabilities

    def _compute_logprobs(self, prompt: str, answers: List[str], model: nn.Module, requires_grad: bool) -> Tuple[List[str], torch.Tensor]:
        """
        Compute log probabilities for the generated answers.

        :param prompt: Input prompt.
        :param answers: List of generated answers.
        :param model: The model to use for computing log probabilities.
        :param requires_grad: Whether to compute gradients.
        :return: A tuple containing the answers and their log probabilities.
        """
        if not requires_grad:
            with torch.no_grad():  # Disable gradient computation
                return self._compute_logprobs_impl(prompt, answers, model)
        else:
            return self._compute_logprobs_impl(prompt, answers, model)

    def _compute_logprobs_impl(self, prompt: str, answers: List[str], model: nn.Module) -> Tuple[List[str], torch.Tensor]:
        """
        Implementation of log probability computation.

        :param prompt: Input prompt.
        :param answers: List of generated answers.
        :param model: The model to use for computing log probabilities.
        :return: A tuple containing the answers and their log probabilities.
        """
        model.eval()  # Set model to evaluation mode
        logprob_values = []  # List to store log probabilities
        prompt_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(DEVICE)  # Tokenize prompt
        prompt_length = prompt_ids.shape[1]  # Get the length of the prompt

        # Prepare full texts for log probability computation
        full_texts = [prompt + ans for ans in answers]
        inputs = self.tokenizer(full_texts, return_tensors="pt", padding=True, truncation=True).to(DEVICE)  # Tokenize full texts

        with torch.no_grad():  # Disable gradient computation
            outputs = model(**inputs)  # Forward pass through the model

        # Compute log probabilities for each answer
        for i in range(len(answers)):
            answer_start = prompt_length - 1  # Start index for the answer
            answer_end = inputs.input_ids[i].ne(self.tokenizer.pad_token_id).sum() - 1  # End index for the answer

            if answer_end <= answer_start:  # Check for valid answer range
                logprob_values.append(torch.tensor(0.0, device=DEVICE))  # Append zero log probability if invalid
                continue

            logits = outputs.logits[i, answer_start:answer_end]  # Get logits for the answer
            labels = inputs.input_ids[i, answer_start + 1:answer_end + 1]  # Get labels for the answer
            ce_loss = self.ce_loss_fct(logits, labels)  # Compute cross-entropy loss
            logprob_values.append(-ce_loss)  # Append negative loss as log probability
        print(prompt)
        print(answers)
        print(logprob_values)
        return answers, torch.tensor(logprob_values, device=DEVICE)  # Return answers and log probabilities as tensor

    @contextmanager
    def _swap_models_temporarily(self, source: nn.Module):
        """
        Temporarily replace self.new_policy params with source's params,
        then restore after the context.

        :param source: The model whose parameters will be swapped in.
        """
        # Save the current state of the new_policy
        new_state = {name: param.data.clone() for name, param in self.new_policy.named_parameters()}

        # Load only the parameters that exist in the new_policy
        source_state = source.state_dict()  # Get state dict of the source model
        for name, param in self.new_policy.named_parameters():
            if name in source_state:  # Check if parameter exists in source
                param.data.copy_(source_state[name])  # Copy parameter data

        try:
            yield  # Yield control back to the context
        finally:
            # Restore the original parameters
            for name, param in self.new_policy.named_parameters():
                if name in new_state:  # Check if parameter exists in saved state
                    param.data.copy_(new_state[name])  # Restore parameter data

    def _reward_func(self, candidate: str, correct: str) -> float:
        """
        Calculate the reward based on the candidate answer and the correct answer.
        If the candidate is a number, it checks if it is close to the correct answer.
        Otherwise, it checks for string equality.

        :param candidate: The generated answer from the model.
        :param correct: The correct answer to compare against.
        :return: A float representing the reward (1.0 for correct, 0.0 for incorrect).
        """
        candidate = candidate.split("\n")[0] # Removing thinking process
        candidate = candidate.strip()  # Remove leading/trailing whitespace
        correct = correct.strip()  # Remove leading/trailing whitespace

        if candidate == correct:
          return 1.0

        reward = .0

        if len(candidate.split(" ")) > 1:
          reward += -.3
        else:
          reward += -.1

        try:
          candidate_num = float(candidate)
          reward += .5
        except ValueError:
          reward += -.3 # Penalty

        return reward

def main():
    """
    Main function to run the training process.
    It initializes the trainer and runs multiple epochs of training.
    """
    prompts = [
        "Q: What is 2+3*4?\nRETURN ONLY THE NUMBER:",
        "Q: Solve 1+1?\nRETURN ONLY THE NUMBER:",
        "Q: What is 10 - 3?\nRETURN ONLY THE NUMBER:",
    ]
    correct_answers = ["14", "2", "7"]  # Expected answers for the prompts

    group_size = 4  # Number of answers to generate for each prompt
    max_new_tokens = 5  # Maximum number of new tokens to generate
    epochs = 5  # Number of training epochs

    # Initialize the GRPOTrainer with specified parameters
    trainer = GRPOTrainer(
        model_name="prithivMLmods/Bellatrix-Tiny-1B-v2",  # Model name to load
        lr=1e-5,  # Learning rate
        epsilon=0.2,  # PPO clipping parameter
        kl_coef=0.01,  # KL divergence coefficient
        checkpoint_dir="checkpoints"  # Directory for saving checkpoints
    )

    # Run training for the specified number of epochs
    for epoch in range(1, epochs + 1):
        metrics = trainer.train_on_batch(
            prompts=prompts,
            correct_answers=correct_answers,
            group_size=group_size,
            max_new_tokens=max_new_tokens,
            temperature=.5  # Temperature for sampling
        )

        # Print metrics for the current epoch
        print(f"\nEpoch {epoch}/{epochs} - Metrics:")
        for k, v in metrics.items():
            print(f"  {k}: {v:.4f}")  # Display each metric

        # Save checkpoint every 2 epochs
        if epoch % 2 == 0:
            trainer.save_checkpoint(step=epoch)  # Save model checkpoint

    print("\nTraining completed!")

if __name__ == "__main__":
    main()  # Execute the main function

Q: What is 2+3*4?
RETURN ONLY THE NUMBER:
['2+3*', "11\n\nLet's", "14\n\nLet's", "11\n\nLet's"]
[tensor(-7.5274, device='cuda:0'), tensor(-10.4278, device='cuda:0'), tensor(-9.8115, device='cuda:0'), tensor(-10.4278, device='cuda:0')]
Q: Solve 1+1?
RETURN ONLY THE NUMBER:
['2\n\n**Explanation', "2\n\nLet's", '2\n\nQ:', '2\n\n**Solution']
[tensor(-11.1349, device='cuda:0'), tensor(-8.9757, device='cuda:0'), tensor(-10.0975, device='cuda:0'), tensor(-10.7723, device='cuda:0')]
Q: What is 10 - 3?
RETURN ONLY THE NUMBER:
["7\n\nLet's", "7\n\nLet's", "7\n\nLet's", "7\n\nLet's"]
[tensor(-9.5574, device='cuda:0'), tensor(-9.5574, device='cuda:0'), tensor(-9.5574, device='cuda:0'), tensor(-9.5574, device='cuda:0')]
Q: What is 2+3*4?
RETURN ONLY THE NUMBER:
['2+3*', "11\n\nLet's", "14\n\nLet's", "11\n\nLet's"]
[tensor(-7.5274, device='cuda:0'), tensor(-10.4278, device='cuda:0'), tensor(-9.8115, device='cuda:0'), tensor(-10.4278, device='cuda:0')]
Q: Solve 1+1?
RETURN ONLY THE NUMBER:
['2\n\n**E