# Phase 2: GSPO Alignment Training on Qwen3.5-35B-A3B

This notebook performs **GSPO (Group Sequence Policy Optimization)** alignment training on **Qwen/Qwen3.5-35B-A3B** using the GSPO-prepared Norwegian Alpaca dataset from Phase 1.5.

## What is GSPO?

GSPO is the RL alignment algorithm developed by the Qwen team ([arXiv:2507.18071](https://arxiv.org/abs/2507.18071)). It improves upon GRPO (DeepSeek's Group Relative Policy Optimization) in two critical ways:

1. **Sequence-level importance ratios** instead of token-level: this removes the high-variance noise and the instability that token-level ratios introduce into the gradient
2. **Inherent MoE stability**: GRPO requires "Routing Replay" to stabilize MoE expert routing; GSPO eliminates this dependency entirely
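
The difference between the two ratios can be illustrated with a small, self-contained sketch. It assumes per-token log-probabilities are already available for one completion under the current and the old policy; the function names and numbers are hypothetical and are not part of TRL. Following the GSPO paper, the sequence-level ratio is the length-normalized likelihood ratio, which is the geometric mean of the token-level ratios.

```python
import math

def token_level_ratios(new_logps, old_logps):
    """GRPO-style: one importance ratio per token."""
    return [math.exp(n - o) for n, o in zip(new_logps, old_logps)]

def sequence_level_ratio(new_logps, old_logps):
    """GSPO-style: a single length-normalized ratio per completion."""
    if not new_logps or len(new_logps) != len(old_logps):
        raise ValueError("expected two non-empty log-prob lists of equal length")
    # Mean of the per-token log-ratios, then exponentiate (geometric mean).
    mean_log_ratio = sum(n - o for n, o in zip(new_logps, old_logps)) / len(new_logps)
    return math.exp(mean_log_ratio)

# Hypothetical per-token log-probs for one 4-token completion.
old_logps = [-1.2, -0.7, -2.0, -0.3]
new_logps = [-1.0, -0.9, -1.7, -0.4]

print("token-level:", [round(r, 3) for r in token_level_ratios(new_logps, old_logps)])
print("sequence-level:", round(sequence_level_ratio(new_logps, old_logps), 3))
```

In this toy example the token-level ratios range from roughly 0.8 to 1.35, while the single sequence-level ratio stays close to 1, which is the averaging effect the first point above refers to.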

In TRL, GSPO is enabled by setting `importance_sampling_level="sequence"` in the `GRPOConfig` passed to `GRPOTrainer`.

## Why the Instruct Model?

- GSPO is a **policy optimization** method that refines an already-capable policy
- The instruct model already follows instructions, giving GSPO meaningful behavior to optimize
- Qwen3.5-35B-A3B is a sparse MoE model (35B total, 3B activated) — GSPO was designed for stable MoE RL training

In [None]:
# Install the library with GSPO extras
# On Colab, clone the repo first:
#   !git clone https://github.com/your-username/NORAI-Tools.git /content/NORAI-Tools
#   %pip install -e "/content/NORAI-Tools[gspo]"

%pip install -e "../[gspo]"
# Quote the version specifiers so the shell does not treat ">" as a redirect
%pip install "trl>=0.28.0" "peft>=0.15.0" accelerate bitsandbytes

from norai_tools import (
    semantic_reward,
    language_reward,
    length_reward,
    accuracy_reward,
)
from datasets import load_dataset
import pandas as pd

In [None]:
# ============================================================
# Configuration — Model, GSPO params, LoRA, reward weights
# ============================================================

# Model
MODEL_NAME = "Qwen/Qwen3.5-35B-A3B"

# Dataset
GSPO_DATASET_PATH = "norwegian_alpaca_gspo.parquet"

# GSPO-specific settings
IMPORTANCE_SAMPLING_LEVEL = "sequence"  # Key GSPO setting (vs "token" for GRPO)
LOSS_TYPE = "grpo"
BETA = 0.04                             # KL penalty coefficient (TRL's GSPO example uses beta=0.0; 0.04 is a choice made here)
EPSILON = 3e-4                          # Clipping range (tighter than GRPO default)
NUM_GENERATIONS = 8                     # Group size G

# Training hyperparams
LEARNING_RATE = 5e-7
PER_DEVICE_TRAIN_BATCH_SIZE = 2
GRADIENT_ACCUMULATION_STEPS = 4
MAX_COMPLETION_LENGTH = 512
NUM_TRAIN_EPOCHS = 1
LOGGING_STEPS = 10
SAVE_STEPS = 200
BF16 = True
USE_VLLM = True
VLLM_MODE = "colocate"

# LoRA config
LORA_R = 16
LORA_ALPHA = 32
LORA_DROPOUT = 0.05
LORA_TARGET_MODULES = ["q_proj", "v_proj", "k_proj", "o_proj"]

# Reward weights: [semantic, language, length, accuracy]
REWARD_WEIGHTS = [2.0, 1.0, 0.5, 1.5]

# Output
OUTPUT_DIR = "./gspo_qwen_norwegian"

print(f"Model: {MODEL_NAME}")
print(f"GSPO: importance_sampling_level={IMPORTANCE_SAMPLING_LEVEL}, beta={BETA}, epsilon={EPSILON}")
print(f"LoRA: r={LORA_R}, alpha={LORA_ALPHA}, targets={LORA_TARGET_MODULES}")
print(f"Generations per prompt: {NUM_GENERATIONS}")
print(f"Reward weights: {REWARD_WEIGHTS}")

In [None]:
# ============================================================
# Load GSPO-prepared dataset, verify columns + task distribution
# ============================================================

from collections import Counter

dataset = load_dataset("parquet", data_files=GSPO_DATASET_PATH, split="train")

print(f"Dataset loaded: {len(dataset)} rows")
print(f"Columns: {dataset.column_names}")

# Verify required columns
required = ["prompt", "task_type", "output_improved"]
missing = [c for c in required if c not in dataset.column_names]
if missing:
    raise ValueError(
        f"Missing columns: {missing}. "
        "Run Phase 1.5 (prepare_gspo_dataset.ipynb) first."
    )

# Task type distribution
task_counts = Counter(dataset["task_type"])
print("\nTask type distribution:")
for task, count in sorted(task_counts.items(), key=lambda x: -x[1]):
    print(f"  {task:20s}: {count:6d} ({100*count/len(dataset):.1f}%)")

# Show sample
pd.set_option("display.max_colwidth", 100)
display(dataset.select(range(3)).to_pandas()[["task_type", "output_improved"]].head())

In [None]:
# ============================================================
# Pre-load semantic similarity model + reward summary
# ============================================================

# Warm up the semantic_reward model with a dummy call so it's
# loaded before training starts (avoids timeout on first batch)
print("Loading semantic similarity model (intfloat/multilingual-e5-large)...")
_ = semantic_reward(
    completions=[[{"content": "test"}]],
    output_improved=["test"],
)
print("Similarity model loaded.")

print("\nReward functions:")
names = ["semantic_reward", "language_reward", "length_reward", "accuracy_reward"]
for name, weight in zip(names, REWARD_WEIGHTS):
    print(f"  {name:20s}  weight={weight}")

In [None]:
# ============================================================
# Setup LoRA + GRPOConfig with GSPO settings
# ============================================================

from peft import LoraConfig
from trl import GRPOConfig

peft_config = LoraConfig(
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    lora_dropout=LORA_DROPOUT,
    target_modules=LORA_TARGET_MODULES,
    task_type="CAUSAL_LM",
)

training_args = GRPOConfig(
    output_dir=OUTPUT_DIR,
    # GSPO-specific
    importance_sampling_level=IMPORTANCE_SAMPLING_LEVEL,
    loss_type=LOSS_TYPE,
    beta=BETA,
    epsilon=EPSILON,
    # Generation
    num_generations=NUM_GENERATIONS,
    max_completion_length=MAX_COMPLETION_LENGTH,
    # Training
    per_device_train_batch_size=PER_DEVICE_TRAIN_BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    learning_rate=LEARNING_RATE,
    num_train_epochs=NUM_TRAIN_EPOCHS,
    bf16=BF16,
    # vLLM
    use_vllm=USE_VLLM,
    vllm_mode=VLLM_MODE,
    # Logging/saving
    logging_steps=LOGGING_STEPS,
    save_steps=SAVE_STEPS,
    save_total_limit=3,
    report_to="none",  # Set to "wandb" for monitoring
)

print("LoRA config:")
print(f"  r={peft_config.r}, alpha={peft_config.lora_alpha}, dropout={peft_config.lora_dropout}")
print(f"  Target modules: {peft_config.target_modules}")
print()
print("GSPO training config:")
print(f"  importance_sampling_level: {training_args.importance_sampling_level}")
print(f"  beta={training_args.beta}, epsilon={training_args.epsilon}")
print(f"  num_generations={training_args.num_generations}")
print(f"  lr={training_args.learning_rate}, batch={training_args.per_device_train_batch_size}")

In [None]:
# ============================================================
# Initialize GRPOTrainer and train
# ============================================================

from trl import GRPOTrainer

trainer = GRPOTrainer(
    model=MODEL_NAME,
    reward_funcs=[
        semantic_reward,
        language_reward,
        length_reward,
        accuracy_reward,
    ],
    reward_weights=REWARD_WEIGHTS,
    train_dataset=dataset,
    args=training_args,
    peft_config=peft_config,
)

print(f"Trainer initialized. Dataset: {len(dataset)} rows")
print("Starting training...")
print("=" * 60)

train_result = trainer.train()

print("=" * 60)
print("Training complete!")
print(f"  Total steps: {train_result.global_step}")
print(f"  Training loss: {train_result.training_loss:.4f}")

In [None]:
# ============================================================
# Save LoRA adapter, print metrics, generate samples
# ============================================================

import torch
from transformers import AutoTokenizer

# Save adapter
trainer.save_model(OUTPUT_DIR)
print(f"Model saved to: {OUTPUT_DIR}")

# Training metrics
metrics = train_result.metrics
print("\nTraining metrics:")
for key, value in sorted(metrics.items()):
    print(f"  {key}: {value}")

# Sample Norwegian outputs
print("\n" + "=" * 60)
print("Sample generations:")
print("=" * 60)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
eval_prompts = [
    "Forklar hva fotosyntese er med enkle ord.",
    "Skriv tre fordeler med regelmessig trening.",
    "Hva er forskjellen mellom vær og klima?",
]

for prompt_text in eval_prompts:
    messages = [{"role": "user", "content": prompt_text}]
    formatted = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer(formatted, return_tensors="pt").to(trainer.model.device)

    with torch.inference_mode():
        outputs = trainer.model.generate(
            **inputs, max_new_tokens=256, do_sample=True,
            temperature=0.7, top_p=0.9,
        )

    generated = tokenizer.decode(
        outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
    )
    print(f"\nPrompt: {prompt_text}")
    print(f"Response: {generated[:300]}")
    print("-" * 40)

print(f"\nLoRA adapter saved at: {OUTPUT_DIR}")
print(f"Load with: PeftModel.from_pretrained(base_model, '{OUTPUT_DIR}')")

# Hardware Requirements & Tips

GSPO is more demanding than SFT because it generates G completions per prompt online.

| Setup | Configuration | Notes |
|-------|---------------|-------|
| **Minimum** | 1x A100 80GB | LoRA + 4-bit quant + vLLM colocate, G=4 |
| **Recommended** | 2x A100 80GB | LoRA + BF16, vLLM server on 2nd GPU, G=8 |
| **Comfortable** | 4x A100 80GB | ZeRO Stage 3 + vLLM, G=16, larger batches |

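Because Qwen3.5-35B-A3B is a sparse MoE that activates only about 3B parameters per token, compute per token is much lower than for a dense 35B model. Weight memory is not reduced in the same way: all 35B parameters must still be loaded (or offloaded), which is why the minimum setup above relies on 4-bit quantization.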
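
A back-of-the-envelope sketch of the weight footprint may help when choosing a row from the table. It assumes every one of the 35B parameters is kept resident at the chosen precision, and it counts weights only: activations, the KV cache for the G rollouts, LoRA parameters, and optimizer state come on top. The helper name is hypothetical.

```python
def weight_memory_gb(num_params, bits_per_param):
    """Approximate weight storage in gigabytes (1 GB = 1e9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9

TOTAL_PARAMS = 35e9   # total parameters, taken from the model name
ACTIVE_PARAMS = 3e9   # parameters activated per token, taken from the model name

for label, bits in [("BF16", 16), ("4-bit", 4)]:
    print(f"{label}: ~{weight_memory_gb(TOTAL_PARAMS, bits):.1f} GB of weights")

print(f"Active fraction per token: {ACTIVE_PARAMS / TOTAL_PARAMS:.1%}")
```

Real 4-bit checkpoints are usually somewhat larger than this estimate because quantization constants and any unquantized layers add overhead.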

## Tips

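- **4-bit quantization:** If VRAM is tight, load the base model in 4-bit, for example with a `BitsAndBytesConfig(load_in_4bit=True)` passed as `quantization_config` at model load time; this requires `bitsandbytes`
- **Reduce `num_generations`:** Lower from 8 to 4 if generation runs out of memory, keeping the effective batch size divisible by the new value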
- **Disable vLLM:** Set `USE_VLLM = False` if vLLM causes issues (slower but functional)
- **Monitoring:** Set `report_to="wandb"` in GRPOConfig for real-time training curves
- **Resume training:** Pass `resume_from_checkpoint=True` to `trainer.train()` if interrupted
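
When lowering `num_generations`, it can help to check the batch arithmetic before launching a run. The sketch below assumes the rule applied by recent TRL versions: the generation batch (per-device batch size × number of processes × steps per generation, where steps per generation defaults to the gradient accumulation steps) must be divisible by `num_generations`. The helper is hypothetical and the exact rule can differ between TRL versions, so treat it as a pre-flight check and not as a reproduction of TRL's own validation.

```python
def check_group_size(per_device_batch, grad_accum_steps, num_processes, num_generations):
    """Return the number of distinct prompts per generation batch, or raise if
    the batch cannot be split evenly into groups of `num_generations`."""
    generation_batch = per_device_batch * num_processes * grad_accum_steps
    if generation_batch % num_generations != 0:
        raise ValueError(
            f"generation batch {generation_batch} is not divisible by "
            f"num_generations={num_generations}"
        )
    return generation_batch // num_generations

# Values from the configuration cell, assuming a single GPU (one process).
for g in (8, 4):
    prompts = check_group_size(
        per_device_batch=2, grad_accum_steps=4, num_processes=1, num_generations=g
    )
    print(f"G={g}: {prompts} prompt(s) per generation batch")
```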

## GSPO vs GRPO

| Aspect | GRPO | GSPO |
|--------|------|------|
| Importance ratio | Token-level | Sequence-level |
| MoE stability | Requires Routing Replay | Inherently stable |
| Clipping epsilon | 0.1–0.2 (typical) | 3e-4 (much tighter) |
| TRL parameter | `importance_sampling_level="token"` | `importance_sampling_level="sequence"` |
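
To make the clipping row concrete, here is a minimal sketch of a sequence-level clipped surrogate for one group, following the general form in the GSPO paper: advantages are group-normalized rewards, and each completion contributes min(s·A, clip(s, 1−ε, 1+ε)·A). The rewards and ratios are made-up numbers, the function names are hypothetical, and the population standard deviation is an arbitrary choice for this illustration; TRL's implementation operates on tensors and differs in detail.

```python
import statistics

def group_advantages(rewards):
    """Group-relative advantages: (r - mean) / std over the G completions."""
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)
    if std == 0:
        return [0.0 for _ in rewards]  # identical rewards carry no signal
    return [(r - mean) / std for r in rewards]

def clipped_surrogate(ratios, advantages, epsilon):
    """Mean over the group of min(s * A, clip(s, 1 - eps, 1 + eps) * A)."""
    terms = []
    for s, a in zip(ratios, advantages):
        clipped = min(max(s, 1.0 - epsilon), 1.0 + epsilon)
        terms.append(min(s * a, clipped * a))
    return sum(terms) / len(terms)

rewards = [1.0, 0.0, 0.5, 0.5]              # hypothetical rewards, G = 4
ratios = [1.0010, 0.9995, 1.0001, 0.9999]   # hypothetical sequence-level ratios
adv = group_advantages(rewards)
print("advantages:", [round(a, 3) for a in adv])
print("objective:", clipped_surrogate(ratios, adv, epsilon=3e-4))
```

Because the sequence-level ratio is length-normalized, it typically sits very close to 1, which is consistent with a clipping range on the order of 1e-4 instead of the 0.1 to 0.2 typical for token-level ratios.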

## References

- [GSPO Paper](https://arxiv.org/abs/2507.18071)
- [Qwen Blog: GSPO](https://qwenlm.github.io/blog/gspo/)
- [TRL: GRPOTrainer](https://huggingface.co/docs/trl/grpo_trainer)
- [Qwen3.5-35B-A3B](https://huggingface.co/Qwen/Qwen3.5-35B-A3B)