In [None]:
"""
Demonstration of a batched "DeepSeek/GRPO" loop with:
  - Multiple prompts/tasks per iteration
  - Monitoring metrics (avg reward, KL, success rate)
  - Periodic checkpoints (saving new_policy)
"""

import copy
import os
from contextlib import contextmanager
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoConfig,
)

# Run on CUDA when a GPU is visible to torch, otherwise fall back to CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


class GRPOTrainer:
    """
    "DeepSeek/GRPO"-style batched training approach.

    Each call to ``train_on_batch`` samples ``group_size`` completions per
    prompt from ``old_policy``, scores them with ``_reward_func``, normalizes
    the rewards within each prompt's group into advantages, and takes one
    clipped-surrogate gradient step on ``new_policy`` with a KL penalty toward
    ``old_policy``. ``old_policy`` is then re-synced to ``new_policy``.
    """

    def __init__(
        self,
        model_name="gpt2",
        lr=1e-5,
        epsilon=0.2,
        kl_coef=0.01,
        checkpoint_dir="checkpoints",
        verbose=False,
    ):
        """
        :param model_name: HF model name or local directory
        :param lr: Adam learning rate for new_policy
        :param epsilon: PPO clip param; ratios are clamped to [1 - epsilon, 1 + epsilon]
        :param kl_coef: coefficient for the KL penalty toward old_policy
        :param checkpoint_dir: where to save model checkpoints (created if missing)
        :param verbose: if True, print each prompt's sampled answers and old-policy log-probs
        """
        self.checkpoint_dir = checkpoint_dir
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        self.verbose = verbose

        # Load config & model. old_policy is only used for sampling and as the
        # denominator of the importance ratio, so it never needs gradients.
        self.config = AutoConfig.from_pretrained(model_name)
        self.new_policy = AutoModelForCausalLM.from_pretrained(model_name, config=self.config).to(DEVICE)
        self.old_policy = copy.deepcopy(self.new_policy).to(DEVICE)
        self.old_policy.eval()
        self.old_policy.requires_grad_(False)

        # Tokenizer: fall back to EOS as the pad token when none is defined.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Params
        self.epsilon = epsilon
        self.kl_coef = kl_coef

        self.optimizer = torch.optim.Adam(self.new_policy.parameters(), lr=lr)

        # Summed token cross-entropy == negative log-likelihood of the scored tokens.
        self.ce_loss_fct = nn.CrossEntropyLoss(reduction="sum")

    def train_on_batch(
        self,
        prompts: List[str],
        correct_answers: List[str],
        group_size=4,
        max_new_tokens=20,
        temperature=1.0
    ) -> dict:
        """
        Perform one batched GRPO update on a list of prompts.

        :param prompts: prompts to sample from; must be non-empty
        :param correct_answers: reference answer per prompt (same length as prompts)
        :param group_size: completions sampled per prompt; advantages are
            normalized within each prompt's group
        :param max_new_tokens: generation length cap per completion
        :param temperature: sampling temperature
        :returns: dict with "avg_reward", "success_rate", "kl_div", "pg_loss"
        :raises ValueError: on mismatched/empty inputs or group_size < 1
        """
        if len(prompts) != len(correct_answers):
            raise ValueError(
                f"got {len(prompts)} prompts but {len(correct_answers)} correct answers"
            )
        if not prompts:
            raise ValueError("prompts must be non-empty")
        if group_size < 1:
            raise ValueError(f"group_size must be >= 1, got {group_size}")

        # 1) Roll out from old_policy and score. Completions are kept raw
        #    (unstripped) so the log-probs are for the text that was actually
        #    sampled; the reward sees the stripped text.
        completions_per_prompt = []
        old_logprob_chunks = []
        rewards = []
        for prompt, correct_ans in zip(prompts, correct_answers):
            completions = self._sample_completions(
                prompt, self.old_policy, group_size, max_new_tokens, temperature
            )
            _, logprobs_old = self._compute_logprobs(
                prompt, completions, model=self.old_policy, requires_grad=False
            )
            answers = [c.strip() for c in completions]
            completions_per_prompt.append(completions)
            old_logprob_chunks.append(logprobs_old)
            rewards.extend(self._reward_func(a, correct_ans) for a in answers)
            if self.verbose:
                print(prompt)
                print(answers)
                print(logprobs_old)
                print()

        # 2) new_policy log-probs for the same completions. These must stay on
        #    the autograd graph or backward() never reaches the model weights.
        new_logprob_chunks = [
            self._compute_logprobs(prompt, completions, model=self.new_policy, requires_grad=True)[1]
            for prompt, completions in zip(prompts, completions_per_prompt)
        ]
        logprobs_old_t = torch.cat(old_logprob_chunks).detach()
        logprobs_new_t = torch.cat(new_logprob_chunks)
        rewards_t = torch.tensor(rewards, device=DEVICE, dtype=torch.float)

        # 3) Group-relative advantages and clipped surrogate + KL loss.
        advantages_t = self._group_advantages(rewards_t, group_size)
        pg_loss, kl = self._grpo_loss(
            logprobs_new_t, logprobs_old_t, advantages_t, self.epsilon, self.kl_coef
        )

        # 4) Gradient update. The loss has no grad path only in degenerate
        #    cases (e.g. every completion tokenizes to nothing).
        self.optimizer.zero_grad()
        if pg_loss.requires_grad:
            pg_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.new_policy.parameters(), 1.0)
            self.optimizer.step()

        # 5) refresh old_policy for the next batch
        self.refresh_old_policy()

        # metrics
        success_rate = sum(r > 0.99 for r in rewards) / len(rewards)
        return {
            "avg_reward": rewards_t.mean().item(),
            "success_rate": success_rate,
            "kl_div": kl.item(),
            "pg_loss": pg_loss.item(),
        }

    @staticmethod
    def _group_advantages(rewards: torch.Tensor, group_size: int) -> torch.Tensor:
        """
        Normalize a flat reward vector within consecutive groups of ``group_size``.

        Each group becomes (r - mean) / (std + 1e-6), using the unbiased std.
        A single-sample group, or a group of identical rewards, gets zero advantage.
        """
        grouped = rewards.view(-1, group_size)
        mean = grouped.mean(dim=1, keepdim=True)
        if group_size > 1:
            std = grouped.std(dim=1, keepdim=True)
        else:
            # The unbiased std of one sample is NaN; treat it as zero spread.
            std = torch.zeros_like(mean)
        return ((grouped - mean) / (std + 1e-6)).reshape(-1)

    @staticmethod
    def _grpo_loss(
        logprobs_new: torch.Tensor,
        logprobs_old: torch.Tensor,
        advantages: torch.Tensor,
        epsilon: float,
        kl_coef: float,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Clipped-surrogate policy loss plus KL penalty, averaged over samples.

        :param logprobs_new: per-completion log-probs under new_policy (carries grad)
        :param logprobs_old: per-completion log-probs under old_policy (no grad)
        :param advantages: per-completion advantages
        :returns: (loss, kl) scalar tensors
        """
        log_ratio = logprobs_new - logprobs_old
        ratio = torch.exp(log_ratio)
        clipped_ratio = torch.clamp(ratio, 1.0 - epsilon, 1.0 + epsilon)
        # PPO takes the min of the two surrogates (ratio * A), not of the ratios;
        # min(ratio, clipped) * A would pick the optimistic term when A < 0.
        surrogate = torch.minimum(ratio * advantages, clipped_ratio * advantages)
        # "k3" estimator of KL(new || old): exp(-d) + d - 1 with d = log_ratio.
        # It is always >= 0 and its gradient vanishes when new == old.
        kl = (torch.exp(-log_ratio) + log_ratio - 1.0).mean()
        loss = -surrogate.mean() + kl_coef * kl
        return loss, kl

    def refresh_old_policy(self):
        """Copy new_policy's weights into old_policy."""
        self.old_policy.load_state_dict(self.new_policy.state_dict())

    def save_checkpoint(self, step: int):
        """Save new_policy and the tokenizer to ``<checkpoint_dir>/new_policy_step_<step>``."""
        ckpt_path = os.path.join(self.checkpoint_dir, f"new_policy_step_{step}")
        self.new_policy.save_pretrained(ckpt_path)
        # Saving the tokenizer too lets the directory be reloaded with from_pretrained alone.
        self.tokenizer.save_pretrained(ckpt_path)
        print(f"Checkpoint saved to {ckpt_path}")

    def _sample_completions(
        self,
        prompt: str,
        model: nn.Module,
        group_size: int,
        max_new_tokens: int,
        temperature: float,
    ) -> List[str]:
        """
        Sample ``group_size`` continuations of ``prompt`` from ``model``.

        Only the newly generated tokens are decoded (special tokens dropped,
        spacing left as generated), so the result does not depend on the
        prompt surviving a decode round trip. Completions are NOT stripped.
        """
        model.eval()
        encoded = self.tokenizer(prompt, return_tensors="pt").to(DEVICE)
        prompt_len = encoded["input_ids"].shape[1]
        with torch.no_grad():
            gen_outputs = model.generate(
                encoded["input_ids"],
                attention_mask=encoded["attention_mask"],
                do_sample=True,
                top_k=0,  # no top-k filtering: sample from the full distribution
                temperature=temperature,
                num_return_sequences=group_size,
                max_new_tokens=max_new_tokens,
                pad_token_id=self.tokenizer.pad_token_id,
            )
        # generate() returns prompt + continuation for decoder-only models.
        return [
            self.tokenizer.decode(
                seq[prompt_len:],
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
            for seq in gen_outputs
        ]

    def _generate_answers(
        self,
        prompt: str,
        group_size: int,
        max_new_tokens: int,
        temperature: float,
        requires_grad: bool,
        model: Optional[nn.Module] = None,
    ) -> Tuple[List[str], List[float]]:
        """
        Generate ``group_size`` completions and return (answers, logprobs).

        ``answers`` are the stripped completions; ``logprobs`` are
        log P(answer | prompt) under the same model, as Python floats.

        :param requires_grad: whether the log-prob forward pass is tracked by
            autograd (the returned floats are detached either way)
        :param model: model to sample from and score with; defaults to new_policy
        """
        if model is None:
            model = self.new_policy
        completions = self._sample_completions(
            prompt, model, group_size, max_new_tokens, temperature
        )
        answers = [c.strip() for c in completions]
        _, logprobs_tensor = self._compute_logprobs(
            prompt, answers, model, requires_grad=requires_grad
        )
        if self.verbose:
            print(prompt)
            print(answers)
            print(logprobs_tensor)
            print()
        return answers, logprobs_tensor.detach().cpu().tolist()

    def _compute_logprobs(
        self,
        prompt: str,
        answers: List[str],
        model: nn.Module,
        requires_grad: bool
    ) -> Tuple[List[str], torch.Tensor]:
        """
        Compute log P(answer | prompt) for each answer.

        If requires_grad=False the forward passes run under no_grad().
        Returns: (answers, logprob_tensor) with one entry per answer.
        """
        if not requires_grad:
            with torch.no_grad():
                return self._compute_logprobs_impl(prompt, answers, model)
        return self._compute_logprobs_impl(prompt, answers, model)

    def _compute_logprobs_impl(
        self, prompt: str, answers: List[str], model: nn.Module
    ) -> Tuple[List[str], torch.Tensor]:
        """
        Run one forward pass per answer and return log P(answer | prompt),
        summed over the answer tokens only.

        Prompt and answer are tokenized separately and concatenated as ids, so
        the prompt/answer boundary is exact and prompt tokens are never scored.
        An answer that tokenizes to nothing gets log-prob 0.
        """
        model.eval()
        prompt_ids = self.tokenizer(prompt).input_ids
        # First position whose token is scored; with an empty prompt the first
        # answer token has no context and cannot be scored.
        start = max(len(prompt_ids), 1)
        logprob_values = []
        for ans in answers:
            answer_ids = self.tokenizer(ans, add_special_tokens=False).input_ids
            if not answer_ids:
                logprob_values.append(torch.zeros((), device=DEVICE))
                continue
            input_ids = torch.tensor([prompt_ids + answer_ids], device=DEVICE)
            logits = model(input_ids, use_cache=False).logits  # [1, seq_len, vocab_size]
            # logits at position t-1 predict the token at position t.
            pred_logits = logits[0, start - 1 : -1, :].float()
            targets = input_ids[0, start:]
            # reduction="sum" => total NLL of the answer tokens.
            logprob_values.append(-self.ce_loss_fct(pred_logits, targets))

        if not logprob_values:
            return answers, torch.zeros(0, device=DEVICE)
        return answers, torch.stack(logprob_values, dim=0)

    @contextmanager
    def _swap_models_temporarily(self, source: nn.Module):
        """
        Temporarily load ``source``'s weights into new_policy, restoring
        new_policy's own weights on exit (even if the body raises).
        Not used by train_on_batch, which samples from old_policy directly.
        """
        new_state = copy.deepcopy(self.new_policy.state_dict())
        self.new_policy.load_state_dict(source.state_dict())
        try:
            yield
        finally:
            self.new_policy.load_state_dict(new_state)

    def _reward_func(self, candidate: str, correct: str) -> float:
        """
        Shaped reward steering the model toward a bare-number answer.

          1.0 -- exactly one whitespace-separated token, equal to ``correct``
          0.5 -- one token that parses as a number but is wrong
          0.3 -- one token that is not a number
          0.0 -- empty, or more than one token
        """
        tokens = candidate.split()
        if len(tokens) != 1:
            return 0.0
        token = tokens[0]
        if token == correct.strip():
            return 1.0
        reward = 0.3
        try:
            float(token)
        except ValueError:
            return reward
        return reward + 0.2


def main():
    """
    Run a short GRPO demo on three toy arithmetic prompts.

    Prints the metrics returned by ``train_on_batch`` after every epoch and
    saves a checkpoint after each even-numbered epoch.
    """
    # Each entry pairs a prompt with the answer the reward function expects.
    toy_tasks = [
        ("Q: What is 2+3*4?\nRETURN ONLY THE NUMBER:", "14"),
        ("Q: Solve 1+1?\nRETURN ONLY THE NUMBER:", "2"),
        ("Q: What is 10 - 3?\nRETURN ONLY THE NUMBER:", "7"),
    ]
    prompts = [question for question, _ in toy_tasks]
    correct_answers = [answer for _, answer in toy_tasks]

    # Hyperparameters
    epochs = 5
    save_every = 2
    rollout_kwargs = dict(group_size=4, max_new_tokens=5, temperature=1.0)

    trainer = GRPOTrainer(
        model_name="HuggingFaceTB/SmolLM-135M",  # or "gpt2"
        lr=1e-5,
        epsilon=0.2,
        kl_coef=0.01,
        checkpoint_dir="checkpoints",
    )

    for epoch in range(1, epochs + 1):
        metrics = trainer.train_on_batch(
            prompts=prompts,
            correct_answers=correct_answers,
            **rollout_kwargs,
        )

        print(f"\nEpoch {epoch}/{epochs} - Metrics:")
        for name, value in metrics.items():
            print(f"  {name}: {value:.4f}")

        if epoch % save_every == 0:
            trainer.save_checkpoint(step=epoch)

    print("\nTraining completed!")


if __name__ == "__main__":
    main()

Q: What is 2+3*4?
RETURN ONLY THE NUMBER:
['13 will return', '0, FLO', '- If D is', '10']
tensor([-81.9412, -81.3530, -80.2183, -69.6766], device='cuda:0')
Q: Solve 1+1?
RETURN ONLY THE NUMBER:
['TOPIC: Module', '* (MULTIP', 'ALWAYS GUESS', '24. Which']
tensor([-96.2953, -88.4987, -89.4476, -88.0674], device='cuda:0')
Q: What is 10 - 3?
RETURN ONLY THE NUMBER:
['10\nAND', 'And your answer is', 'For text substituting nuclear', '-10 /']
tensor([ -78.9078,  -86.0183, -108.9350,  -75.8747], device='cuda:0')

Epoch 1/5 - Metrics:
  avg_reward: 0.0833
  success_rate: 0.0000
  kl_div: 0.0000
  pg_loss: 0.0000
Q: What is 2+3*4?
RETURN ONLY THE NUMBER:
['(So the diamond', '7\n3*', '5+ 0', 'CUTVI -']
tensor([-89.6841, -77.4595, -82.8370, -94.8110], device='cuda:0')
Q: Solve 1+1?
RETURN ONLY THE NUMBER:
['A successful solution consists=', '-120', 'Express the remainder', 'FIFTH SIDE']
tensor([-106.2680,  -80.6138,  -87.1040,  -90.3150], device='cuda:0')
Q: What is 10 - 3?
RETURN ONLY THE NUMBER:
