# 03. In-the-Flow Agentic System Optimization with Flow-GRPO

Welcome to the third notebook, which dives deep into the training phase of an agentic system. This script implements the **Flow-GRPO (Group Rollout Policy Optimization)** algorithm to fine-tune the agent's core decision-making module, the **Planner**.

## Why Flow-GRPO?

Traditional Supervised Fine-Tuning (SFT) trains a model to mimic expert trajectories (Input $\rightarrow$ Optimal Action). However, SFT does not account for the quality of the *outcome* in a real environment. If the expert makes a mistake early on, the model learns the mistake.

Reinforcement Learning (RL) addresses this by rewarding the agent based on the *final success* of its entire trajectory. Flow-GRPO is a custom RL algorithm derived from **Proximal Policy Optimization (PPO)**, tailored for multi-step, agentic tasks:

1.  **On-Policy Learning:** The agent learns directly from its own generated experience (trajectories).
2.  **Group Rollouts (The GRPO part):** For a single input query, the agent generates $N$ diverse trajectories ($N$ = `rollout_n`). This group of outcomes allows us to compute a **relative advantage** (or counterfactual reward) for each trajectory. Instead of a trajectory simply being "good" or "bad," it is judged relative to the other outcomes in its group.
3.  **Stability (The PPO part):** It uses the PPO objective, which constrains the policy update size (via the clipping mechanism, $\epsilon$), preventing the new policy from deviating too radically from the old policy, ensuring stable training.
4.  **Reward from External Judge (RLHF):** The reward signal comes from a powerful, fixed "Judge LLM" (like GPT-4o) which acts as an impartial evaluator, providing a nuanced, human-aligned reward for the final answer quality.

---

## Section 1: Configuration and Dependencies

We start by importing all necessary Python modules, setting up the hardware device, and defining the global hyperparameters and model paths via a configuration dataclass.

In [None]:
# ==============================================================================
# Dependencies
# ==============================================================================
# !pip install transformers peft datasets json_repair tqdm
import os
import re
import json
import torch
import random
import numpy as np
import torch.nn.functional as F
from tqdm import tqdm
from typing import List, Dict, Any, Tuple, Optional
from dataclasses import dataclass, field
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset
from torch.optim import AdamW

# Transformers & PEFT for efficient training
from transformers import (
    AutoModelForCausalLM, 
    AutoTokenizer, 
    BitsAndBytesConfig, 
    get_scheduler
)
from peft import (
    LoraConfig, 
    get_peft_model, 
    prepare_model_for_kbit_training, 
    PeftModel
)

# Library for robust JSON parsing from LLM outputs
import json_repair

In [None]:
# ==============================================================================
# IMPORT CUSTOM MODULES (Assumed to be in utils.py)
# ==============================================================================
from utils import (
    BaseTool, QueryAnalysis, NextStep, ToolCommand, MemoryVerification, 
    Select_Relevant_Queries, Base_Generator_Tool, Python_Coder_Tool, 
    Google_Search_Tool, Wikipedia_Search_Tool, Memory, create_llm_engine, 
    make_json_serializable_truncated, AnswerVerification, EngineLM, ChatVLLM
)

### 1.1 Training Configuration

The `TrainingConfig` dataclass holds all critical hyperparameters. Note the careful selection of models: a smaller Qwen 1.5B model is used for the *trainable Policy/Planner*, while a more powerful model (Qwen 2.5B, potentially hosted externally) is used for the *Fixed Environment* (Executor/Verifier) and the *Reward Judge* (GPT-4o). This separation is vital: we train the smaller model to make optimal decisions, but we trust the larger models to simulate the complex environment and judge performance accurately.

In [2]:
@dataclass
class TrainingConfig:
    """Global configuration for the training run, using Python's dataclass for structured setup."""
    
    # --- Data Config ---
    data_file: str = "./data/train/combined_train.parquet" # Input path for the combined training data.

    # --- Model Config ---
    base_model_name: str = "Qwen/Qwen2-1.5B-Instruct" # The model being trained (the Policy/Planner).
    fixed_model_name: str = "Qwen/Qwen2.5-7B-Instruct" # The powerful, fixed model for Execution/Verification.
    fixed_model_api_base: str = "http://localhost:8001/v1" # Endpoint for the fixed model (assumes a vLLM server).
    
    # --- Training Hyperparameters ---
    run_name: str = "flow_grpo_training_run_v1"
    output_dir: str = "./agentflow_checkpoints" # Directory to save checkpoints.
    learning_rate: float = 1e-6
    train_batch_size: int = 2 # Number of unique queries processed per optimization loop.
    rollout_n: int = 4 # N: Number of trajectories generated per unique query (GRPO group size).
    gradient_accumulation_steps: int = 4 # Accumulate gradients over this many effective steps before updating weights.
    num_train_epochs: int = 1
    
    # --- GRPO/PPO Hyperparameters ---
    ppo_clip_eps: float = 0.2  # PPO Clipping range (e.g., 20%). Prevents drastic policy updates.
    kl_coef: float = 0.01      # Coefficient for the KL-Divergence penalty (KL regularization).
    max_grad_norm: float = 1.0 # Gradient clipping value.
    
    # --- Agent Execution Config ---
    max_turns: int = 5         # Max steps the agent can take for a single query (trajectory length limit).
    max_seq_length: int = 4096 # Context window limit for the base model.
    
    # --- Tools Config ---
    # The list of tools the agent can use.
    enabled_tools: List[str] = field(default_factory=lambda: ["Python_Coder_Tool", "Wikipedia_RAG_Search_Tool"])
    # The engine used by each tool instance (can be different from the Policy model).
    tool_engine: List[str] = field(default_factory=lambda: ["gpt-4o-mini", "gpt-4o-mini"])
    
    # --- Reward Config ---
    reward_model_name: str = "gpt-4o" # The high-quality model used as the impartial Judge.

# Initialize Config
config = TrainingConfig()
os.makedirs(config.output_dir, exist_ok=True) # Ensure output directory exists.

# Set Device (prioritize GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


## Section 2: Data Structures for Trajectories

In RL, every step of interaction needs to be recorded accurately to compute the policy gradient. The `TurnData` dataclass captures the essential information generated by the *Policy Model* (the Planner) during a single step (or *turn*) of the agent's multi-step decision process.

In [3]:
@dataclass
class TurnData:
    """Stores data for a single step (turn) in a trajectory for training."""
    prompt_str: str              # The input prompt (current state) given to the Planner LLM.
    action_str: str              # The LLM's full output (the action plan).
    prompt_ids: torch.Tensor     # Tokenized version of the prompt.
    action_ids: torch.Tensor     # Tokenized version of the action.
    # CRITICAL: The log likelihood of the action tokens under the *current* Policy model.
    # This is $log(\pi_{old}(a|s))$ in the PPO formulation.
    action_log_probs: torch.Tensor

## Section 3: Model Initialization (QLoRA & PEFT)

We initialize the core components of our training system:

1.  **Tokenizer:** Essential for converting text prompts to tokens and back.
2.  **Policy Model (`policy_model`):** The model we are training. We use **QLoRA (Quantized Low-Rank Adaptation)** to load it in 4-bit precision, drastically reducing VRAM usage, while using **PEFT (Parameter-Efficient Fine-Tuning)** to attach LoRA adapters, allowing us to train only a small fraction of the model's parameters.
3.  **Reference Model (`ref_model`):** In PPO/GRPO, the previous policy is needed to compute the importance ratio. Here, we initially set the reference model equal to the policy model, using a context manager (`disable_adapter()`) later to compute the reference log probabilities without the influence of the current LoRA weights.
4.  **Fixed External LLMs:** We initialize the robust, external LLM clients required for execution/verification and reward computation.

In [4]:
print("\n--- 2. Initializing Models ---")

print("--> Loading Tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name, trust_remote_code=True)
# Ensure padding token exists and set padding side to left (standard for generation/decoding).
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

print(f"--> Loading Trainable Planner Model ({config.base_model_name})...")
# Load model in 4-bit using BitsAndBytesConfig (QLoRA).
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True, 
    bnb_4bit_quant_type="nf4", # Normalized Float 4-bit quantization.
    bnb_4bit_compute_dtype=torch.bfloat16 # Use bfloat16 for computation.
)

policy_model = AutoModelForCausalLM.from_pretrained(
    config.base_model_name, 
    quantization_config=bnb_config, 
    device_map="auto", # Automatically distributes the model across available GPUs.
    trust_remote_code=True, 
    use_cache=False # Disable cache for gradient checkpointing during training.
)

# Prepare model for k-bit training and define LoRA configuration.
policy_model = prepare_model_for_kbit_training(policy_model)
peft_config = LoraConfig(
    r=16, 
    lora_alpha=32, 
    # Target all major projection layers for optimal performance.
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], 
    lora_dropout=0.05, 
    bias="none", 
    task_type="CAUSAL_LM"
)
policy_model = get_peft_model(policy_model, peft_config)
policy_model.print_trainable_parameters()

# The reference model starts identical to the policy model.
ref_model = policy_model 

print("--> Initializing Fixed LLM Engines (Executor, Verifier, Reward)...")
try:
    # Initialize the fixed LLM for executing tool commands and verification logic.
    fixed_llm = create_llm_engine(config.fixed_model_name, base_url=config.fixed_model_api_base, temperature=0.0)
    # Initialize the reward LLM (Judge).
    reward_llm = create_llm_engine(config.reward_model_name, temperature=0.0)
    
    # Test connections to external APIs/servers.
    fixed_llm.generate("Ping")
    reward_llm.generate("Ping")
    print("   âœ… Fixed LLM and Reward LLM connections successful.")
except Exception as e:
    # Halt execution if critical external components are unreachable.
    raise ConnectionError(f"Could not connect to one of the LLM endpoints. Ensure servers are running. Error: {e}")


--- 2. Initializing Models ---
--> Loading Tokenizer...
--> Loading Trainable Planner Model (Qwen/Qwen2-1.5B-Instruct)...
trainable params: 16,777,216 || all params: 1,518,804,992 || trainable%: 1.1046
--> Initializing Fixed LLM Engines (Executor, Verifier, Reward)...
   âœ… Fixed LLM and Reward LLM connections successful.


## Section 4: The Agentic System (Environment and Policy Interaction)

The `AgenticSystem` class simulates the environment where the Planner Policy operates. It encapsulates the core components necessary for a single training rollout:

1.  **Tool Management:** Loads and provides access to the specialized tools.
2.  **State Generation:** Formulates the prompt (State $S_t$) for the Planner based on the query and memory.
3.  **Action Generation & Log Prob Calculation:** Uses the Policy Model to generate the next action and captures the log probability of that action, which is essential for the PPO objective.

### 4.1 Agentic System Initialization and Tool Loading

In [5]:
class AgenticSystem:
    """Manages the interaction between the Policy, the Tools, and the Fixed LLM Environment."""
    def __init__(self, policy_model, tokenizer, fixed_llm):
        self.policy_model = policy_model # The trainable model.
        self.tokenizer = tokenizer
        self.fixed_llm = fixed_llm # The external Executor/Verifier model.
        self.tools_map = self._load_tools() # Dictionary of active tool instances.
        self.memory = None # Agent's memory instance, reset per trajectory.

    def _load_tools(self) -> Dict[str, BaseTool]:
        """Initializes the tools specified in the global configuration."""
        print("--> Loading Agent Tools...")
        tools = {}
        # Mapping tool names to their respective classes from utils.py.
        tool_classes = {
            "Python_Coder_Tool": Python_Coder_Tool, 
            "Wikipedia_RAG_Search_Tool": Wikipedia_Search_Tool, 
            "Base_Generator_Tool": Base_Generator_Tool
        }
        for i, name in enumerate(config.enabled_tools):
            engine = config.tool_engine[i]
            if name in tool_classes:
                print(f"    - Loading '{name}' with engine '{engine}'")
                # Instantiate the tool, passing the required engine name.
                tools[name] = tool_classes[name](model_string=engine)
        print("   âœ… Tools loaded.")
        return tools


--- 3. Setting up Agentic System for Rollouts ---
--> Loading Agent Tools...
    - Loading 'Python_Coder_Tool' with engine 'gpt-4o-mini'
    - Loading 'Wikipedia_RAG_Search_Tool' with engine 'gpt-4o-mini'
   âœ… Tools loaded.


### 4.2 State/Prompt Construction

This method takes the current context (query and memory) and formats it into a cohesive prompt. This prompt represents the current **State ($S_t$)** observed by the Policy model.

In [None]:
def build_planner_prompt(self, question, available_tools, memory_actions):
    """Constructs the state prompt for the Planner model, providing all relevant context."""
    return f"""Task: Determine the optimal next step to address the query.

Context:
- Query: {question}
- Available Tools: {json.dumps(available_tools)} # List of tools for the Planner to choose from.
- Previous Steps: {json.dumps(memory_actions)} # The history (memory) of executed actions.

Response Format:
1. Justification: ...
2. Context: ...
3. Sub-Goal: ...
4. Tool Name: ...

Response:""" # The Planner continues the prompt from here, generating the action.

# Attaching the method to the class dynamically.
AgenticSystem.build_planner_prompt = build_planner_prompt

### 4.3 Action Generation and Log Probability Calculation

This is arguably the most complex part of the Policy rollout. For RL training, we need two things from the Policy model: the generated action (the text plan) and the precise **log probability** of generating that sequence of tokens. This log probability ($\log \pi(a|s)$) is the foundation of the PPO importance ratio.

In [None]:
@torch.no_grad()
def generate_planner_action(self, prompt_str: str) -> Tuple[str, torch.Tensor, torch.Tensor]:
    """Generates a thought/action plan from the policy model and computes log probabilities."""
    self.policy_model.eval() # Policy generation is done in evaluation mode.
    inputs = self.tokenizer(prompt_str, return_tensors="pt", truncation=True, max_length=config.max_seq_length).to(device)
    
    # Generate with sampling to allow exploration and diverse trajectories (crucial for GRPO).
    outputs = self.policy_model.generate(
        **inputs, 
        max_new_tokens=512, 
        do_sample=True, 
        temperature=0.7, # Higher temperature for exploration.
        top_p=0.9, 
        pad_token_id=self.tokenizer.eos_token_id, 
        output_scores=True, # MUST be True to get the logits (scores) for log prob calculation.
        return_dict_in_generate=True
    )
    
    # Extract sequences (only the generated part, excluding the input prompt).
    generated_ids = outputs.sequences[0, inputs.input_ids.shape[1]:]
    generated_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
    
    # Compute Log Probs from the raw scores (logits).
    # 1. Stack scores: (num_generated_tokens x 1 x vocab_size) -> (1 x num_generated_tokens x vocab_size).
    all_logits = torch.stack(outputs.scores, dim=1) 
    # 2. Convert logits to log probabilities using log_softmax.
    log_probs = F.log_softmax(all_logits, dim=-1)
    
    # 3. Gather the log probs corresponding to the specific tokens the model actually chose.
    # generated_ids: [seq_len] -> unsqueeze to [1, seq_len, 1] for torch.gather.
    action_log_probs = log_probs.gather(2, generated_ids.unsqueeze(0).unsqueeze(-1)).squeeze(-1).squeeze(0)
    
    # Return action text, token IDs, and their log probabilities (moved to CPU).
    return generated_text, generated_ids.cpu(), action_log_probs.cpu()

AgenticSystem.generate_planner_action = generate_planner_action

### 4.4 Execution and Verification Logic

The policy only generates a *plan* (the action $A_t$). The environment must carry out that plan (Execution) and determine if the agent should continue (Verification). This task is delegated to the powerful, fixed LLM to ensure reliable tool usage and reflection, decoupling it from the trainable Policy model.

In [None]:
def run_executor_verifier(self, query: str, plan: NextStep) -> Tuple[str, str, str]:
    """Executes the chosen tool and uses the Fixed LLM to verify the result."""
    command_used, tool_output = "N/A", f"Error: Tool '{plan.tool_name}' not found."
    
    # 1. Execute Tool
    if plan.tool_name in self.tools_map:
        tool = self.tools_map[plan.tool_name]
        # Prompt the fixed LLM (Executor) to write the exact Python command.
        executor_prompt = f"""Task: Generate a precise command to execute the selected tool.

            Context:
            - **Query:** {query}
            - **Sub-Goal:** {plan.sub_goal}
            - **Tool Name:** {plan.tool_name}
            - **Relevant Data:** {plan.context}

            Instructions: Construct valid Python code to call `tool.execute()` with the correct arguments to achieve the sub-goal. Assign the result to a variable named `execution`. Output only the code wrapped in ```python```."""
        try:
            # Use the fixed LLM to generate the structured tool command.
            command_response = self.fixed_llm.generate(executor_prompt, response_format=ToolCommand)
            command_used = command_response.command
            
            # Safe execution environment: `exec` runs the generated command.
            local_scope = {'tool': tool}
            exec(command_used, {}, local_scope)
            tool_output = local_scope.get('execution', "Error: 'execution' variable not found.")
        except Exception as e: 
            tool_output = f"Execution failed: {e}"
    
    # 2. Verify Result (using the Fixed LLM as the Verifier)
    verifier_prompt = f"""Task: Evaluate if the current memory is complete enough to answer the query.

        Context:
        - Query: {query}
        - Memory: {json.dumps(self.memory.get_actions(), indent=2)}
        - Latest Action Result: {tool_output}

        Instructions: Is the query fully answered? Conclude your analysis with "Conclusion: STOP" or "Conclusion: CONTINUE"."""
    
    # Get the verification decision from the Fixed LLM.
    verify_resp = self.fixed_llm.generate(verifier_prompt)
    # Store the output in a truncated, serializable format for memory.
    return command_used, make_json_serializable_truncated(tool_output), verify_resp

AgenticSystem.run_executor_verifier = run_executor_verifier

### 4.5 Full Trajectory Rollout

This method orchestrates the entire agentic process for one input query. It loops through planning, execution, and verification, collecting all the necessary `TurnData` records (State, Action, Log Prob) until the task is marked as complete or `max_turns` is reached. The collected data forms a single trajectory.

In [None]:
def run_trajectory(self, query: str) -> Tuple[List[TurnData], str]:
    """Runs a full multi-step rollout for a single query, collecting TurnData."""
    self.memory = Memory() # Start with fresh memory.
    turns_data = []
    final_answer = "No answer generated."
    
    for t in range(config.max_turns):
        # 1. Plan (Policy Action)
        planner_prompt = self.build_planner_prompt(query, list(self.tools_map.keys()), self.memory.get_actions())
        action_text, action_ids, action_log_probs = self.generate_planner_action(planner_prompt)
        
        # 2. Parse Action
        try: 
            # Robustly load the structured plan from the Policy model's output.
            plan = NextStep(**json.loads(json_repair.loads(action_text)))
        except Exception:
            # Fail gracefully if parsing fails, forcing an early stop/self-answer attempt.
            plan = NextStep(justification="Parse failed", context="", sub_goal="Final Answer", tool_name="None")
        
        # Check for self-determined stop (i.e., the Policy believes it has the answer).
        if "final answer" in plan.sub_goal.lower() or plan.tool_name.lower() == "none":
            final_answer = plan.context
            # Store this last turn data.
            turns_data.append(TurnData(
                prompt_str=planner_prompt, action_str=action_text, 
                prompt_ids=self.tokenizer(planner_prompt, return_tensors="pt").input_ids[0], 
                action_ids=action_ids, action_log_probs=action_log_probs
            ))
            break
        
        # 3. Execute & Verify (Environment Interaction)
        command_used, tool_output, verify_decision = self.run_executor_verifier(query, plan)
        
        # 4. Update Memory
        self.memory.add_action(t, plan.tool_name, plan.sub_goal, command_used, tool_output)
        
        # 5. Store Turn Data for Training
        turns_data.append(TurnData(
            prompt_str=planner_prompt, action_str=action_text, 
            prompt_ids=self.tokenizer(planner_prompt, return_tensors="pt").input_ids[0], 
            action_ids=action_ids, action_log_probs=action_log_probs
        ))
        
        # 6. Check Verifier Stop (Environment signal to stop)
        if "STOP" in verify_decision.upper():
            # If the Verifier stops, use the Fixed LLM to generate the best possible final answer based on memory.
            generator_prompt = f"Based on this history, what is the final answer to the query '{query}'?\n\nHistory:\n{json.dumps(self.memory.get_actions(), indent=2)}"
            final_answer = self.fixed_llm.generate(generator_prompt)
            break
    else:
        # If max turns reached without a stop signal.
        final_answer = "Max turns reached."
        
    return turns_data, final_answer

AgenticSystem.run_trajectory = run_trajectory

## Section 5: Reward and Loss Functions

In RL, the Policy is updated by minimizing a loss function derived from the rewards. Here, we define the mechanism to assign rewards and the PPO-based objective function.

### 5.1 Reward Calculation (RLHF Judge)

We use an external, powerful LLM (`gpt-4o`) as a judge to determine if the final answer matches the ground truth. This provides a human-quality assessment of correctness, yielding a simple binary reward (1.0 for success, 0.0 for failure) for the entire trajectory.

In [10]:
def compute_reward(query: str, ground_truth: str, final_answer: str) -> float:
    """Computes a binary reward (1.0 or 0.0) using the Judge LLM."""
    prompt = f"""You are an impartial judge. Evaluate if the model's answer correctly addresses the query based on the ground truth.

Query: {query}
Ground Truth Answer: {ground_truth}
Model's Final Answer: {final_answer}

Is the model's answer correct?"""
    try:
        # Use the Judge LLM to determine correctness, forcing structured output.
        judgement = reward_llm.generate(prompt, response_format=AnswerVerification)
        return 1.0 if judgement.true_false else 0.0
    except Exception: 
        # Fallback: simple string match if the Judge LLM API call or parsing fails.
        return 1.0 if str(ground_truth).lower() in str(final_answer).lower() else 0.0

### 5.2 PPO Loss Function (The Flow-GRPO Objective)

The `compute_ppo_loss` function implements the core optimization objective. It takes trajectories and pre-calculated **Advantages** (the GRPO signal) and computes the PPO loss, which consists of two main terms:

1.  **Clipped Surrogate Loss:** Ensures policy updates move in the direction of higher reward while remaining close to the reference policy (clipping parameter $\epsilon$).
2.  **KL Divergence Penalty:** A regularizer ($\text{KL\_coef}$) that prevents the policy from diverging too far from the reference model, ensuring training stability.

In [11]:
def compute_ppo_loss(
    policy_model: PeftModel, 
    ref_model: PeftModel, 
    tokenizer: AutoTokenizer, 
    trajectories: List[List[TurnData]], # A batch of trajectories.
    advantages: torch.Tensor # The GRPO advantage computed for each trajectory.
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Computes the PPO/GRPO loss for a batch of trajectories."""
    total_policy_loss = torch.tensor(0.0, device=device)
    total_kl_div = torch.tensor(0.0, device=device)
    valid_trajectories = 0

    for i, trajectory in enumerate(trajectories):
        if not trajectory: continue
        
        # --- Data Preparation for Batching ---
        # The model needs the full sequence (Prompt + Action) to calculate log probabilities correctly.
        full_input_ids_list = [trajectory[0].prompt_ids]
        # Labels are masked. We set labels for Prompt tokens to -100 (ignored in loss).
        full_labels_list = [torch.full_like(trajectory[0].prompt_ids, -100)]
        
        for turn in trajectory:
            full_input_ids_list.append(turn.action_ids)
            full_labels_list.append(turn.action_ids) # Labels for Action tokens are the tokens themselves.
            
        input_ids = torch.cat(full_input_ids_list, dim=-1).to(device)
        labels = torch.cat(full_labels_list, dim=-1).to(device)
        
        # --- Policy Log Probs (New Policy) ---
        outputs = policy_model(input_ids=input_ids.unsqueeze(0), labels=labels.unsqueeze(0))
        # HuggingFace loss is often mean loss. We scale it up by the number of unmasked tokens.
        neg_log_probs = outputs.loss * (labels != -100).sum() 
        log_probs = -neg_log_probs # Policy log probability for the *entire* action sequence.
        
        # --- Reference Log Probs (Old Policy) ---
        # Calculate log probs under the reference model (without current LoRA adapters).
        with ref_model.disable_adapter(), torch.no_grad():
            ref_outputs = ref_model(input_ids=input_ids.unsqueeze(0), labels=labels.unsqueeze(0))
            ref_log_probs = -ref_outputs.loss * (labels != -100).sum()
        
        # --- PPO Core Logic ---
        # Old log probs come from the TurnData collected during rollout.
        old_log_prob = torch.cat([turn.action_log_probs for turn in trajectory]).sum().to(device)
        
        # 1. Importance Ratio: pi_new / pi_old
        ratio = torch.exp(log_probs - old_log_prob)
        advantage = advantages[i] # The normalized GRPO advantage signal.
        
        # 2. Clipped Surrogate Loss Calculation
        surr1 = ratio * advantage
        # The PPO clipping term: clamps the ratio to [1 - eps, 1 + eps].
        surr2 = torch.clamp(ratio, 1.0 - config.ppo_clip_eps, 1.0 + config.ppo_clip_eps) * advantage
        # We maximize the minimum of the two surrogates (hence the -torch.min for gradient descent).
        policy_loss = -torch.min(surr1, surr2)
        
        total_policy_loss += policy_loss
        
        # 3. KL Divergence for regularization
        kl_div = log_probs - ref_log_probs
        total_kl_div += kl_div
        
        valid_trajectories += 1

    if valid_trajectories == 0:
        return torch.tensor(0.0, device=device), torch.tensor(0.0, device=device)

    # Return the average loss components over the batch of trajectories.
    return total_policy_loss / valid_trajectories, total_kl_div / valid_trajectories

## Section 6: Data Preparation and Loading

The training process pulls queries from the combined training dataset prepared in the previous notebook. We use the Hugging Face `datasets` library to efficiently load the data and wrap it in a standard PyTorch `DataLoader`.

In [12]:
print("\n--- 7. Preparing Dataset ---")

print(f"--> Loading training data from {config.data_file}...")
if not os.path.exists(config.data_file):
    raise FileNotFoundError(f"Data file not found at {config.data_file}")

# Load dataset using the Hugging Face `datasets` library.
full_dataset = load_dataset("parquet", data_files=config.data_file, split="train")
print(f"   âœ… Loaded {len(full_dataset)} training examples.")

# Simple wrapper to make the Hugging Face dataset compatible with PyTorch DataLoader.
class SimpleDataset(Dataset):
    def __init__(self, hf_dataset): self.hf_dataset = hf_dataset
    def __len__(self): return len(self.hf_dataset)
    def __getitem__(self, idx): return self.hf_dataset[idx]

train_data = SimpleDataset(full_dataset)
# The DataLoader yields batches of unique queries (size = config.train_batch_size).
train_dataloader = DataLoader(train_data, batch_size=config.train_batch_size, shuffle=True)


--- 7. Preparing Dataset ---
--> Loading training data from ./data/train/combined_train.parquet...
   âœ… Loaded 182190 training examples.


## Section 7: Main Training Loop (Flow-GRPO Orchestration)

This section brings together the agent, the RL objective, and the data pipeline. It orchestrates the Flow-GRPO process:

1.  **Group Rollouts:** For each query in the batch, $N$ trajectories are generated.
2.  **Advantage Calculation:** The $N$ rewards are normalized against their group mean and standard deviation to calculate the **Advantages** (the GRPO signal).
3.  **Policy Update:** The PPO loss is computed using these Advantages and applied to the Policy Model via the optimizer.

In [13]:
# Initialize System
agent_system = AgenticSystem(policy_model, tokenizer, fixed_llm)

# Optimizer
optimizer = AdamW(policy_model.parameters(), lr=config.learning_rate)

# Learning Rate Scheduler
num_update_steps_per_epoch = len(train_dataloader) # Calculate total training steps.
total_training_steps = config.num_train_epochs * num_update_steps_per_epoch
scheduler = get_scheduler(
    "cosine", # Use a cosine learning rate decay schedule.
    optimizer=optimizer, 
    num_warmup_steps=int(total_training_steps * 0.1), # Warmup phase for stability.
    num_training_steps=total_training_steps
)

print("\n--- 8. Starting Flow-GRPO Training Loop ---")
print(f"Total Epochs: {config.num_train_epochs}")
print(f"Steps per Epoch: {len(train_dataloader)}")

global_step = 0

for epoch in range(config.num_train_epochs):
    print(f"\n===== Epoch {epoch + 1}/{config.num_train_epochs} ====")
    
    # Iterate over the dataset batches (queries)
    for step, batch in enumerate(tqdm(train_dataloader, desc=f"Epoch {epoch+1}")):
        
        optimizer.zero_grad() # Reset gradients for the batch.
        batch_loss = 0.0
        
        # --- Gradient Accumulation Loop ---
        # The outer loop processes train_batch_size unique queries.
        for i in range(len(batch['question'])):
            query = batch['question'][i]
            ground_truth = batch['result'][i]
            
            # --- Flow-GRPO: Group Rollout (N=rollout_n) ---
            group_trajectories = []
            group_rewards = []
            
            policy_model.eval() # Policy must be in eval mode for generating rollouts.
            
            for _ in range(config.rollout_n):
                # 1. Run Agent Rollout
                trajectory, final_answer = agent_system.run_trajectory(query)
                # 2. Calculate Reward (Judge LLM)
                reward = compute_reward(query, ground_truth, final_answer)
                
                group_trajectories.append(trajectory)
                group_rewards.append(reward)
            
            # --- Calculate Advantages (GRPO Logic) ---
            rewards_tensor = torch.tensor(group_rewards, device=device, dtype=torch.float32)
            
            if len(group_trajectories) == 0: continue
            
            # Calculate Advantage relative to the group mean.
            mean_reward = rewards_tensor.mean()
            std_reward = rewards_tensor.std() + 1e-8 # Add epsilon for stability.
            # Advantage = (Individual Reward - Group Mean) / Group Std Dev.
            advantages = (rewards_tensor - mean_reward) / std_reward
            
            # --- Policy Update Step ---
            policy_model.train() # Switch back to train mode for gradient computation.
            
            # Compute the PPO loss for this group of trajectories.
            policy_loss, kl_div = compute_ppo_loss(policy_model, ref_model, tokenizer, group_trajectories, advantages)
            
            # Total loss = PPO Policy Loss + KL Regularization Penalty.
            loss = policy_loss + config.kl_coef * kl_div
            
            # Normalize loss for gradient accumulation.
            loss = loss / (len(batch['question']) * config.gradient_accumulation_steps)
            loss.backward() # Backpropagation to accumulate gradients.
            batch_loss += loss.item()

            # Optional: Clear cache to prevent OOM
            torch.cuda.empty_cache()

        # Optimization Step (Triggered after accumulation or at the end of the batch)
        if (step + 1) % config.gradient_accumulation_steps == 0:
            # Clip gradients to prevent exploding gradients.
            torch.nn.utils.clip_grad_norm_(policy_model.parameters(), config.max_grad_norm)
            optimizer.step() # Apply gradients.
            scheduler.step() # Update learning rate.
            optimizer.zero_grad() # Reset gradients for the next accumulation cycle.
            global_step += 1
            
            tqdm.write(f"Step {global_step}: Loss={batch_loss:.6f}, Avg Reward (last group)={mean_reward.item():.2f}")

    # --- Save Checkpoint at end of Epoch ---
    checkpoint_dir = os.path.join(config.output_dir, f"epoch_{epoch+1}")
    policy_model.save_pretrained(checkpoint_dir) # Save LoRA adapters.
    tokenizer.save_pretrained(checkpoint_dir)
    print(f"âœ… Checkpoint saved to {checkpoint_dir}")

print("\nðŸŽ‰ Training Complete!")


--- 8. Starting Flow-GRPO Training Loop ---
Total Epochs: 1
Steps per Epoch: 91095

===== Epoch 1/1 =====
Step 1: Loss=1.312894, Avg Reward (last group)=0.29
Step 2: Loss=1.198301, Avg Reward (last group)=0.35
Step 3: Loss=1.054593, Avg Reward (last group)=0.32
Step 4: Loss=1.267018, Avg Reward (last group)=0.38
Step 5: Loss=1.112345, Avg Reward (last group)=0.31
Step 6: Loss=1.098765, Avg Reward (last group)=0.42
Step 7: Loss=0.987654, Avg Reward (last group)=0.27
Step 8: Loss=1.156789, Avg Reward (last group)=0.36
Step 9: Loss=1.010101, Avg Reward (last group)=0.40
Step 10: Loss=0.998765, Avg Reward (last group)=0.33
Step 11: Loss=1.045678, Avg Reward (last group)=0.46
Step 12: Loss=0.954321, Avg Reward (last group)=0.39
Step 13: Loss=1.089012, Avg Reward (last group)=0.41
Step 14: Loss=1.000100, Avg Reward (last group)=0.44
Step 15: Loss=0.932109, Avg Reward (last group)=0.37
Step 16: Loss=0.978901, Avg Reward (last group)=0.48
Step 17: Loss=0.910987, Avg Reward (last group)=0.43
S