# MoE Training with GRPO

### WandB Configuration

In this step, we set up **Weights & Biases (WandB)** to track and log our GRPO training experiments.  
WandB allows us to monitor key metrics, store hyperparameters, and compare runs for reproducibility.

Specifically, we:

- Set WandB environment variables including the API key and logging directory.  
- Define the main hyperparameters for **GRPO training**: learning rate, batch size, number of rollouts, buffer size, and number of epochs.  
- Initialize a new WandB run with these hyperparameters for tracking.
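
One way to keep the hyperparameters and the logged config from drifting apart is to define them once and derive the WandB `config` dictionary from that definition. The sketch below is an assumption about how this could be organized, not the notebook's own code: `GRPOConfig` is a hypothetical name, the field comments reflect how the values are used later in this notebook, and the values mirror those in the cell that follows.

```python
from dataclasses import asdict, dataclass


@dataclass(frozen=True)
class GRPOConfig:  # hypothetical container, not part of the wandb API
    learning_rate: float = 1e-6
    batch_size: int = 1      # problems sampled from the environment per step
    num_rollouts: int = 4    # rollouts requested per sampled batch
    buffer_size: int = 40    # experiences collected before a policy update
    num_epochs: int = 2      # passes over the buffer per policy update


cfg = GRPOConfig()
wandb_config = asdict(cfg)  # plain dict, suitable for wandb.init(config=...)
print(wandb_config)
```

A frozen dataclass makes accidental mid-run changes to a hyperparameter fail loudly instead of silently diverging from what was logged.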

In [1]:
import wandb
import os 

# The API key is read from the WANDB_API_KEY environment variable (set it in the
# shell or run `wandb login`); never store the key itself in the notebook.
os.environ["WANDB_DIR"] = ".."

learning_rate = 1e-6
batch_size = 1
num_rollouts = 4
buffer_size = 40
num_epochs = 2

# Initialize wandb
wandb.init(
    project="blue-yonder-mle-assignment", 
    name="granite-3.0-grpo",
    config={
        "learning_rate": learning_rate,
        "batch_size": batch_size,
        "num_rollouts": num_rollouts,
        "buffer_size": buffer_size,
        "num_epochs": num_epochs
    }
)

wandb: Currently logged in as: fa_mekrache (aasr) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin


### Model and LoRA Configuration

In this step, we load the **fine-tuned SFT model** (LoRA adapters already trained) and prepare it for **GRPO training**.

Specifically, we:

- Load the **base IBM Granite MoE model** using `transformers`.  
- Load the previously fine-tuned **SFT LoRA checkpoint** via PEFT.  
- Initialize the tokenizer from the SFT checkpoint, ensuring `padding_side="left"` and aligning the pad token with the EOS token.  
- Ensure that **LoRA parameters remain trainable** for GRPO updates.  
- Set up the optimizer and prepare the model, tokenizer, and optimizer with **`Accelerator`** for efficient multi-device training.

This setup allows the model to start GRPO from the **SFT fine-tuned weights**, leveraging the previously learned reasoning capabilities while enabling policy optimization.
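
The tokenizer settings mentioned above can be sketched as follows. Left padding matters for decoder-only generation because new tokens are appended after the last position, so padding has to sit on the left of the prompt. This is a minimal illustration under assumptions: `configure_tokenizer` is a hypothetical helper, and a `SimpleNamespace` stands in for the real tokenizer so the snippet runs without downloading a model.

```python
from types import SimpleNamespace


def configure_tokenizer(tokenizer):
    """Pad on the left and reuse the EOS token as the pad token."""
    tokenizer.padding_side = "left"
    tokenizer.pad_token = tokenizer.eos_token
    return tokenizer


# Stand-in object for illustration only; in the notebook this would be the
# tokenizer returned by AutoTokenizer.from_pretrained(...).
dummy = SimpleNamespace(padding_side="right", pad_token=None, eos_token="<eos>")
configure_tokenizer(dummy)
print(dummy.padding_side, dummy.pad_token)
```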


In [2]:
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from accelerate import Accelerator

base_model_id = "ibm-granite/granite-3.0-1b-a400m-base"

# Load base model
model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    device_map="auto",
)

# Load SFT adapter
llm = PeftModel.from_pretrained(model, "../checkpoints/granite-1b-a400m-blue-yonder-sft/checkpoint-402")
tokenizer = AutoTokenizer.from_pretrained("../checkpoints/granite-1b-a400m-blue-yonder-sft/checkpoint-402",  local_files_only=True)

# Ensure LoRA parameters are trainable
for name, param in llm.named_parameters():
    if 'lora' in name.lower():
        param.requires_grad = True

device = "cuda" if torch.cuda.is_available() else "cpu"

# Print trainable parameters (the method prints its own summary and returns None)
llm.print_trainable_parameters()

# Pass only the trainable (LoRA) parameters to the optimizer
optimizer = torch.optim.Adam(
    [p for p in llm.parameters() if p.requires_grad], lr=learning_rate
)

accelerator = Accelerator()
llm, tokenizer, optimizer = accelerator.prepare(
    llm, tokenizer, optimizer
)

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  3.12it/s]

trainable params: 2,752,512 || all params: 1,337,377,792 || trainable%: 0.2058


### Expert Utilization Monitoring

This section monitors **how frequently each expert is used** in the MoE layers during training.  

- A **hook function** is registered on the main MoE layers and their gates to track the outputs of the gating mechanism.  
- During each forward pass, the hook accumulates the usage of each expert in a global dictionary.  
- The `log_expert_usage` function computes the **percentage of usage per expert** and clears the buffer for the next step.  

This monitoring helps analyze **load balancing and expert activity** throughout training.
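
To make the idea of a usage percentage concrete, here is a count-based variant of the same measurement: for each token, take the `top_k` highest-scoring experts and count how often each expert is selected. It is a sketch under assumptions, not the hook used in this notebook (which sums a tensor taken from the layer output); `expert_usage_percent` and the toy router scores are made up for illustration.

```python
from collections import Counter


def expert_usage_percent(router_logits, top_k):
    """Percentage of routing slots assigned to each expert.

    router_logits: per-token score lists, shape (num_tokens, num_experts).
    """
    num_experts = len(router_logits[0])
    counts = Counter()
    for token_scores in router_logits:
        # Indices of the top_k highest-scoring experts for this token
        ranked = sorted(range(num_experts), key=lambda e: token_scores[e], reverse=True)
        counts.update(ranked[:top_k])
    total = sum(counts.values())
    return {e: 100.0 * counts[e] / total for e in range(num_experts)}


# Three tokens routed over four experts (toy values)
logits = [
    [2.0, 0.5, -1.0, 0.1],
    [1.5, 1.7, -0.3, 0.0],
    [0.2, 0.1, 3.0, 2.5],
]
usage = expert_usage_percent(logits, top_k=2)
print({e: round(p, 2) for e, p in usage.items()})
```

With perfectly balanced routing every expert would sit at `100 / num_experts` percent, so large deviations from that value point to load imbalance.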

In [3]:
# Expert Utilization Monitoring
expert_usage = {}

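def hook_fn(module, input, output):
    """Accumulate per-expert totals from a module's forward output.

    Only tuple outputs whose second element is a 2D tensor are counted; that
    tensor is assumed to have shape (num_tokens, num_experts).
    """
    if isinstance(output, tuple) and len(output) > 1:
        expert_weights = output[1]
        if isinstance(expert_weights, torch.Tensor) and expert_weights.dim() == 2:
            # Sum over tokens to get one total per expert
            usage_per_expert = expert_weights.sum(dim=0).detach().float().cpu().tolist()
            for i, usage in enumerate(usage_per_expert):
                expert_usage[i] = expert_usage.get(i, 0.0) + usage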

# Find only main MoE layers and register hooks
moe_layers = []
for name, module in model.named_modules():
    if name.endswith('block_sparse_moe'):
        moe_layers.append((name, module))
        module.register_forward_hook(hook_fn)
        if hasattr(module, 'gate'):
            module.gate.register_forward_hook(hook_fn)

print(f"Found {len(moe_layers)} main MoE layers: {[n for n,m in moe_layers]}")

def log_expert_usage():
    """
    Convert the accumulated expert usage into percentages and reset the buffer.

    Reads the module-level `expert_usage` dict (expert index -> accumulated value),
    e.g., {0: 5.0, 1: 3.0, ...}, and clears it afterwards.

    Returns:
        dict: Flattened expert usage percentages, e.g.,
              {'expert_usage/expert_0': 62.5, 'expert_usage/expert_1': 37.5, ...};
              empty if nothing was recorded.
    """
    logs = {}
    if expert_usage:
        total = sum(expert_usage.values())
        if total > 0:
            expert_percent = {f"expert_usage/expert_{k}": float(v) / total * 100 for k, v in expert_usage.items()}
        else:
            expert_percent = {f"expert_usage/expert_{k}": 0.0 for k in expert_usage}

        logs.update(expert_percent)
        expert_usage.clear()  # optional: clear buffer after computing

    return logs

Found 24 main MoE layers: ['model.layers.0.block_sparse_moe', 'model.layers.1.block_sparse_moe', 'model.layers.2.block_sparse_moe', 'model.layers.3.block_sparse_moe', 'model.layers.4.block_sparse_moe', 'model.layers.5.block_sparse_moe', 'model.layers.6.block_sparse_moe', 'model.layers.7.block_sparse_moe', 'model.layers.8.block_sparse_moe', 'model.layers.9.block_sparse_moe', 'model.layers.10.block_sparse_moe', 'model.layers.11.block_sparse_moe', 'model.layers.12.block_sparse_moe', 'model.layers.13.block_sparse_moe', 'model.layers.14.block_sparse_moe', 'model.layers.15.block_sparse_moe', 'model.layers.16.block_sparse_moe', 'model.layers.17.block_sparse_moe', 'model.layers.18.block_sparse_moe', 'model.layers.19.block_sparse_moe', 'model.layers.20.block_sparse_moe', 'model.layers.21.block_sparse_moe', 'model.layers.22.block_sparse_moe', 'model.layers.23.block_sparse_moe']


### GRPO Training Loop

This section runs the **GRPO training loop** on the GSM8K dataset using the SFT fine-tuned LoRA model.

Key steps:

1. **Environment Setup:**  
   - Load GSM8K dataset and initialize a custom RL environment (`GSM8KEnv`) that handles batching and reward computation.

2. **Experience Collection:**  
   - Sample a batch of problems from the environment each iteration.  
   - Generate model outputs and compute **rewards** using the environment.  
   - Accumulate experiences in a buffer for training.  
   - Track **expert usage** via `log_expert_usage`.

3. **Policy Update (GRPO):**  
   - When the buffer reaches the defined size:
     - Shuffle and split it into mini-batches.  
     - For each mini-batch and epoch, compute loss with `train_on_batch`.  
     - Update gradients using the optimizer and `Accelerator` for multi-device support.  
   - Compute the **average loss** over all updates.

4. **Reward Function:**  
   - Combines multiple aspects to guide policy updates:
     1. **Correctness of the final answer (70%)** – reward 1.0 if correct, else 0.0.  
     2. **Formatting (15%)** – ensures model outputs both reasoning (`<think>`) and final answer (`<answer>`).  
     3. **Reasoning quality (15%)** – compares model’s reasoning to reference using **BERTScore F1**.  

5. **Logging and Monitoring:**  
   - Log **mean reward**, **average loss**, and **expert usage** to WandB at each step.  
   - Print sample metrics to monitor training progress.

6. **Termination and Saving:**  
   - Stop training after reaching the maximum number of steps.  
   - Save the updated LoRA model and tokenizer to the checkpoint directory.

This loop allows the model to **improve its policy** via GRPO while leveraging reasoning capabilities learned during SFT, guided by a structured reward function.
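
`collect_experiences` and `train_on_batch` live in `utils.grpo` and are not shown in this notebook. In GRPO as it is generally described, the rewards of several rollouts for the same prompt are normalized against each other to form advantages, which replaces a learned value baseline. The sketch below illustrates only that normalization step; it is not the code from `utils.grpo`, and `group_relative_advantages`, `eps`, and the reward values are assumptions. It uses the population standard deviation, and implementations differ on that choice.

```python
from statistics import mean, pstdev


def group_relative_advantages(rewards, eps=1e-8):
    """Normalize rewards within one group of rollouts for the same prompt."""
    mu = mean(rewards)
    sigma = pstdev(rewards)
    # eps keeps the division finite when all rollouts receive the same reward
    return [(r - mu) / (sigma + eps) for r in rewards]


# Four rollouts for one problem (num_rollouts = 4); reward values are made up.
rewards = [1.0, 0.3, 0.3, 0.0]
advantages = group_relative_advantages(rewards)
print([round(a, 3) for a in advantages])
```

Rollouts that score above the group mean get a positive advantage and are reinforced, those below are discouraged, and a group with identical rewards contributes no learning signal.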

In [4]:
import sys
from datasets import load_dataset
sys.path.append('..')
from utils.dataset import GSM8KEnv
from utils.grpo import collect_experiences, train_on_batch
import random

gsm8k_dataset = load_dataset("openai/gsm8k", "main", split="train")
gsm8k_train_env = GSM8KEnv(gsm8k_dataset, tokenizer)
gsm8k_train_env.reset()

buffer = []
step = 0
training_steps = 500

while True:  
    try:
        # Sample batch from environment
        batch = gsm8k_train_env.sample_batch(batch_size)

        experiences, mean_reward = collect_experiences(llm, tokenizer, accelerator, batch, batch_size, num_rollouts)
        buffer.extend(experiences)
        
        # Expert usage
        logs = {
            "rl_training/mean_reward": mean_reward,
            **log_expert_usage(),
        }

        if len(buffer) >= buffer_size:
            # Keep only the most recent experiences, then shuffle them
            buffer = buffer[-buffer_size:]
            random.shuffle(buffer)

            llm.train()

            total_loss = 0.0
            num_batches = 0
            
            for _ in range(num_epochs):
                for i in range(0, buffer_size, batch_size):
                    
                    training_batch = buffer[i : i + batch_size]
                    loss = train_on_batch(llm, tokenizer, accelerator, training_batch)

                    optimizer.zero_grad()
                    accelerator.backward(loss)
                    optimizer.step()

                    total_loss += loss.item()
                    num_batches += 1
            
            avg_loss = total_loss / num_batches if num_batches > 0 else 0.0

            # Prepare logs for loss and expert usage
            logs["rl_training/avg_loss"] = avg_loss

            print(f"Step {step}: Loss = {avg_loss:.4f}, Mean Reward = {mean_reward:.4f}")

            buffer = []

        # Logging
        wandb.log(logs, step=step)
        step += 1

        if step >= training_steps:
            print("Reached maximum training steps. Exiting training loop.")
            break

    except Exception as e:
        print(f"Error at step {step}: {e}")
        print(f"Error type: {type(e)}")
        import traceback
        traceback.print_exc()
        break

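output_dir = "../checkpoints/granite-1b-a400m-blue-yonder-grpo"
# Unwrap in case Accelerator wrapped the model (e.g. for distributed training)
accelerator.unwrap_model(llm).save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)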

Step 9: Loss = 0.0178, Mean Reward = 0.4571
Step 19: Loss = 0.0271, Mean Reward = 0.5992
Step 29: Loss = 0.0410, Mean Reward = 0.2623
Step 39: Loss = 0.0251, Mean Reward = 0.1956
Step 49: Loss = 0.0299, Mean Reward = 0.2577
Step 59: Loss = 0.0382, Mean Reward = 0.2502
Step 70: Loss = -0.0158, Mean Reward = 0.9488
Step 80: Loss = 0.0289, Mean Reward = 0.2431
Step 90: Loss = 0.0409, Mean Reward = 0.2544
Step 100: Loss = 0.0259, Mean Reward = 0.2659
Step 110: Loss = 0.0326, Mean Reward = 0.2323
Step 120: Loss = 0.0327, Mean Reward = 0.2363
Step 131: Loss = 0.0905, Mean Reward = 0.2689
Step 141: Loss = 0.0377, Mean Reward = 0.7933
Step 151: Loss = 0.0333, Mean Reward = 0.4204
Step 161: Loss = 0.0369, Mean Reward = 0.9684
Step 171: Loss = 0.0281, Mean Reward = 0.5939
Step 181: Loss = 0.0270, Mean Reward = 0.2430
Step 192: Loss = 0.0159, Mean Reward = 0.2490
Step 202: Loss = 0.0241, Mean Reward = 0.2575
Step 212: Loss = 0.0393, Mean Reward = 0.2225
Step 222: Loss = 0.0320, Mean Reward = 0.60

('../checkpoints/granite-1b-a400m-blue-yonder-grpo/tokenizer_config.json',
 '../checkpoints/granite-1b-a400m-blue-yonder-grpo/special_tokens_map.json',
 '../checkpoints/granite-1b-a400m-blue-yonder-grpo/vocab.json',
 '../checkpoints/granite-1b-a400m-blue-yonder-grpo/merges.txt',
 '../checkpoints/granite-1b-a400m-blue-yonder-grpo/added_tokens.json',
 '../checkpoints/granite-1b-a400m-blue-yonder-grpo/tokenizer.json')

## Evaluation on GSM8K Test Set

We evaluate the GRPO-trained model on the **GSM8K test set** using a custom environment (`GSM8KEnv`) that handles problem presentation, step-by-step reasoning, and scoring.

**Evaluation loop:**
1. Reset the environment to get a new math problem.  
2. Apply the chat template to format the input.  
3. Generate reasoning and answer with the model.  
4. Decode the output and submit it to the environment to receive a reward.  
5. Accumulate the rewards and average them over the test set to obtain the weighted score.

The **reward function** considers not only whether the final answer is correct but also the **format and reasoning quality**, combining three components:

1. **Correctness of the final answer (70%)**  
   - 1.0 if the predicted answer matches the gold answer exactly, else 0.0.

2. **Formatting reward (15%)**  
   - Checks that the model produced both reasoning (`<think>`) and an answer (`<answer>`).  
   - Encourages structured and readable outputs.

3. **Reasoning similarity (15%)**  
   - Measures how close the model's reasoning is to the reference using **BERTScore F1**.  
   - Encourages coherent, step-by-step explanations.

The **weighted sum** of these three metrics forms the final reward for each sample, providing a comprehensive evaluation of reasoning and answer quality.
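
The weighted sum can be written out as a small function. This is a sketch of the 70/15/15 combination described above, not the implementation inside `GSM8KEnv`: the closing tags `</think>` and `</answer>`, the binary formatting score, and all function names are assumptions, and `reasoning_f1` stands in for a BERTScore F1 computed elsewhere.

```python
import re

W_CORRECT, W_FORMAT, W_REASONING = 0.70, 0.15, 0.15


def format_score(text):
    """1.0 if both a <think>...</think> and an <answer>...</answer> block exist."""
    has_think = re.search(r"<think>.*?</think>", text, re.DOTALL) is not None
    has_answer = re.search(r"<answer>.*?</answer>", text, re.DOTALL) is not None
    return 1.0 if (has_think and has_answer) else 0.0


def extract_answer(text):
    """Return the stripped contents of the first <answer> block, or None."""
    match = re.search(r"<answer>(.*?)</answer>", text, re.DOTALL)
    return match.group(1).strip() if match else None


def weighted_reward(text, gold, reasoning_f1):
    """Combine exact-match correctness, formatting, and reasoning similarity."""
    correct = 1.0 if extract_answer(text) == gold.strip() else 0.0
    return W_CORRECT * correct + W_FORMAT * format_score(text) + W_REASONING * reasoning_f1


sample = "<think>3 apples + 4 apples = 7 apples</think><answer>7</answer>"
print(round(weighted_reward(sample, "7", reasoning_f1=0.8), 4))
```

Under these weights a well-formatted but wrong answer can earn at most 0.30, so correctness dominates the signal while format and reasoning similarity still separate otherwise equal rollouts.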

In [5]:
from tqdm import tqdm

eval_dataset = load_dataset("openai/gsm8k", "main", split="test")
gsm8k_eval_env = GSM8KEnv(eval_dataset, tokenizer)

gsm8k_eval_env.current_idx = 0
total_score = 0.0
llm.eval()

for i in tqdm(range(len(gsm8k_eval_env.dataset)), desc="Evaluating"):
    # Get problem from environment
    obs, _ = gsm8k_eval_env.reset()

    # Tokenize the prompt and move it to the device selected earlier
    inputs = tokenizer(obs, return_tensors="pt").to(device)
    output_ids = llm.generate(
        **inputs,
        max_new_tokens=512,
    )

    pred = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    # Gymnasium-style step: (observation, reward, terminated, truncated, info)
    _, reward, terminated, truncated, info = gsm8k_eval_env.step(pred)
    
    total_score += reward

N = len(gsm8k_eval_env.dataset)
print(f"\nResults on {N} samples:")
print(f"  Weighted score : {total_score / N:.2%}")

Evaluating: 100%|██████████| 1319/1319 [1:23:40<00:00,  3.81s/it]


Results on 1319 samples:
  Weighted score : 40.68%



