# Phase 2b: GSPO Training on Qwen-Distilled Dataset

This notebook performs **GSPO alignment training** on **Qwen/Qwen3.5-35B-A3B** using the Qwen-distilled + Borealis-polished dataset from Phase 1b.

## How this differs from the standard GSPO notebook

| Aspect | Standard (Phase 2) | Distilled (Phase 2b) |
|--------|--------------------|----------------------|
| Reference outputs | gpt-3.5-turbo, improved by Borealis | Qwen-generated, polished by Borealis |
| Content quality | Mediocre (gpt-3.5 era) | High (Qwen3.5 knowledge) |
| Norwegian quality | Good (Borealis-improved) | Good (Borealis-polished) |
| Training signal | Learns from weaker reference | Self-distillation — refines its own outputs |

Using Qwen's own outputs as reference creates a **self-distillation** loop: the model is rewarded for producing outputs similar to its best (polished) generations. This focuses GSPO on Norwegian fluency rather than content improvement.

## GSPO Recap

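GSPO (Group Sequence Policy Optimization) is enabled in TRL's `GRPOTrainer` by setting `importance_sampling_level="sequence"`. Importance ratios are then computed and clipped once per sequence rather than once per token, which keeps training stable on MoE models.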
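
To make the difference concrete, here is a minimal sketch of the two kinds of ratio in plain Python. It follows the length-normalized definition from the GSPO paper, where the sequence ratio is the exponential of the mean per-token log-probability difference. The function names and the log-probabilities are made up for illustration and are not part of TRL or `norai_tools`.

```python
import math


def token_level_ratios(logp_new, logp_old):
    """One importance ratio per token (token-level importance sampling)."""
    return [math.exp(n - o) for n, o in zip(logp_new, logp_old)]


def sequence_level_ratio(logp_new, logp_old):
    """A single length-normalized ratio for the whole completion:
    exp(mean_t(logp_new[t] - logp_old[t]))."""
    diffs = [n - o for n, o in zip(logp_new, logp_old)]
    return math.exp(sum(diffs) / len(diffs))


# Hypothetical per-token log-probabilities for one 4-token completion.
logp_old = [-1.0, -2.0, -0.5, -1.5]
logp_new = [-0.9, -2.6, -0.4, -1.3]

print("token-level:", token_level_ratios(logp_new, logp_old))
print("sequence-level:", sequence_level_ratio(logp_new, logp_old))
```

In this toy case one token has a ratio near 0.55 while the sequence-level ratio stays near 0.95. A single outlier token therefore does not dominate the clipping decision, which is the intuition behind the stability claim.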

In [None]:
# Install the library with GSPO extras
# On Colab:
#   !git clone https://github.com/your-username/NORAI-Tools.git /content/NORAI-Tools
#   %pip install -e "/content/NORAI-Tools[gspo]"

%pip install -e "../[gspo]"
# Quote the version specifiers so the shell does not treat ">" as output redirection
%pip install "trl>=0.28.0" "peft>=0.15.0" accelerate bitsandbytes

from norai_tools import (
    semantic_reward,
    language_reward,
    length_reward,
    accuracy_reward,
    prepare_gspo_dataset,
    validate_gspo_dataset,
)
from datasets import load_dataset
import pandas as pd

In [None]:
# ============================================================
# Configuration
# ============================================================

# Model
MODEL_NAME = "Qwen/Qwen3.5-35B-A3B"

# Dataset (from Phase 1b distillation notebook)
DISTILLED_DATASET_PATH = "norwegian_alpaca_qwen_polished.parquet"

# GSPO-specific settings
IMPORTANCE_SAMPLING_LEVEL = "sequence"
LOSS_TYPE = "grpo"
BETA = 0.04
EPSILON = 3e-4
NUM_GENERATIONS = 8

# Training hyperparams
LEARNING_RATE = 5e-7
PER_DEVICE_TRAIN_BATCH_SIZE = 2
GRADIENT_ACCUMULATION_STEPS = 4
MAX_COMPLETION_LENGTH = 512
NUM_TRAIN_EPOCHS = 1
LOGGING_STEPS = 10
SAVE_STEPS = 200
BF16 = True
USE_VLLM = True
VLLM_MODE = "colocate"

# LoRA config
LORA_R = 16
LORA_ALPHA = 32
LORA_DROPOUT = 0.05
LORA_TARGET_MODULES = ["q_proj", "v_proj", "k_proj", "o_proj"]

# Reward weights: [semantic, language, length, accuracy]
REWARD_WEIGHTS = [2.0, 1.0, 0.5, 1.5]

# Output
OUTPUT_DIR = "./gspo_qwen_norwegian_distilled"

print(f"Model: {MODEL_NAME}")
print(f"Dataset: {DISTILLED_DATASET_PATH}")
print(f"GSPO: importance_sampling_level={IMPORTANCE_SAMPLING_LEVEL}, beta={BETA}, epsilon={EPSILON}")
print(f"Output: {OUTPUT_DIR}")

In [None]:
# ============================================================
# Load distilled dataset, map columns, prepare for GSPO
# ============================================================

from collections import Counter

dataset = load_dataset("parquet", data_files=DISTILLED_DATASET_PATH, split="train")
print(f"Loaded: {len(dataset)} rows")
print(f"Columns: {dataset.column_names}")

# Verify distilled columns exist
required = ["output_qwen_polished", "instruction_improved", "input_improved", "instruction_en"]
missing = [c for c in required if c not in dataset.column_names]
if missing:
    raise ValueError(
        f"Missing columns: {missing}. "
        "Run Phase 1b (distill_qwen_norwegian.ipynb) first."
    )

# Map output_qwen_polished → output_improved for reward function compatibility.
# Keep the original Phase 1 output_improved as output_improved_phase1 for reference.
if "output_improved" in dataset.column_names:
    dataset = dataset.rename_column("output_improved", "output_improved_phase1")
dataset = dataset.rename_column("output_qwen_polished", "output_improved")
print("\nRenamed: output_qwen_polished → output_improved (for reward functions)")

# Prepare GSPO columns (prompt + task_type)
dataset = prepare_gspo_dataset(dataset)

# Validate
validation = validate_gspo_dataset(dataset)
print(f"\nValid: {validation['is_valid']}")
print(f"Empty prompts: {validation['empty_prompts']}")

# Task type distribution
print("\nTask type distribution:")
for task, count in sorted(validation["task_type_distribution"].items(), key=lambda x: -x[1]):
    print(f"  {task:20s}: {count:6d} ({100*count/len(dataset):.1f}%)")

In [None]:
# ============================================================
# Pre-load semantic similarity model
# ============================================================

print("Loading semantic similarity model (intfloat/multilingual-e5-large)...")
_ = semantic_reward(
    completions=[[{"content": "test"}]],
    output_improved=["test"],
)
print("Similarity model loaded.")

print("\nReward functions:")
names = ["semantic_reward", "language_reward", "length_reward", "accuracy_reward"]
for name, weight in zip(names, REWARD_WEIGHTS):
    print(f"  {name:20s}  weight={weight}")

In [None]:
# ============================================================
# Setup LoRA + GRPOConfig with GSPO settings
# ============================================================

from peft import LoraConfig
from trl import GRPOConfig

peft_config = LoraConfig(
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    lora_dropout=LORA_DROPOUT,
    target_modules=LORA_TARGET_MODULES,
    task_type="CAUSAL_LM",
)

training_args = GRPOConfig(
    output_dir=OUTPUT_DIR,
    # GSPO-specific
    importance_sampling_level=IMPORTANCE_SAMPLING_LEVEL,
    loss_type=LOSS_TYPE,
    beta=BETA,
    epsilon=EPSILON,
    # Generation
    num_generations=NUM_GENERATIONS,
    max_completion_length=MAX_COMPLETION_LENGTH,
    # Training
    per_device_train_batch_size=PER_DEVICE_TRAIN_BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    learning_rate=LEARNING_RATE,
    num_train_epochs=NUM_TRAIN_EPOCHS,
    bf16=BF16,
    # vLLM
    use_vllm=USE_VLLM,
    vllm_mode=VLLM_MODE,
    # Logging/saving
    logging_steps=LOGGING_STEPS,
    save_steps=SAVE_STEPS,
    save_total_limit=3,
    report_to="none",
)

print("LoRA config:")
print(f"  r={peft_config.r}, alpha={peft_config.lora_alpha}, dropout={peft_config.lora_dropout}")
print(f"  Target modules: {peft_config.target_modules}")
print()
print("GSPO training config:")
print(f"  importance_sampling_level: {training_args.importance_sampling_level}")
print(f"  beta={training_args.beta}, epsilon={training_args.epsilon}")
print(f"  num_generations={training_args.num_generations}")
print(f"  lr={training_args.learning_rate}, batch={training_args.per_device_train_batch_size}")

In [None]:
# ============================================================
# Initialize GRPOTrainer and train
# ============================================================

from trl import GRPOTrainer

trainer = GRPOTrainer(
    model=MODEL_NAME,
    reward_funcs=[
        semantic_reward,
        language_reward,
        length_reward,
        accuracy_reward,
    ],
    reward_weights=REWARD_WEIGHTS,
    train_dataset=dataset,
    args=training_args,
    peft_config=peft_config,
)

print(f"Trainer initialized. Dataset: {len(dataset)} rows")
print("Starting training...")
print("=" * 60)

train_result = trainer.train()

print("=" * 60)
print("Training complete!")
print(f"  Total steps: {train_result.global_step}")
print(f"  Training loss: {train_result.training_loss:.4f}")

In [None]:
# ============================================================
# Save LoRA adapter, print metrics, generate samples
# ============================================================

import torch
from transformers import AutoTokenizer

trainer.save_model(OUTPUT_DIR)
print(f"Model saved to: {OUTPUT_DIR}")

metrics = train_result.metrics
print("\nTraining metrics:")
for key, value in sorted(metrics.items()):
    print(f"  {key}: {value}")

# Sample Norwegian outputs
print("\n" + "=" * 60)
print("Sample generations:")
print("=" * 60)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
eval_prompts = [
    "Forklar hva fotosyntese er med enkle ord.",
    "Skriv tre fordeler med regelmessig trening.",
    "Hva er forskjellen mellom vær og klima?",
]

# Switch to eval mode so LoRA dropout is disabled while sampling
trainer.model.eval()

for prompt_text in eval_prompts:
    messages = [{"role": "user", "content": prompt_text}]
    formatted = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer(formatted, return_tensors="pt").to(trainer.model.device)

    with torch.inference_mode():
        outputs = trainer.model.generate(
            **inputs, max_new_tokens=256, do_sample=True,
            temperature=0.7, top_p=0.9,
        )

    generated = tokenizer.decode(
        outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
    )
    print(f"\nPrompt: {prompt_text}")
    print(f"Response: {generated[:300]}")
    print("-" * 40)

print(f"\nLoRA adapter saved at: {OUTPUT_DIR}")
print(f"Load with: PeftModel.from_pretrained(base_model, '{OUTPUT_DIR}')")

## Hardware Requirements & Tips

Same requirements as the standard GSPO notebook (Phase 2). In the table, `G` is `num_generations`, the number of completions sampled per prompt.

| Setup | Configuration | Notes |
|-------|---------------|-------|
| **Minimum** | 1x A100 80GB | LoRA + 4-bit quant + vLLM colocate, G=4 |
| **Recommended** | 2x A100 80GB | LoRA + BF16, vLLM server on 2nd GPU, G=8 |
| **Comfortable** | 4x A100 80GB | ZeRO Stage 3 + vLLM, G=16, larger batches |
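
When changing `G`, it helps to check that the batch settings still divide evenly. The sketch below assumes the behaviour of recent TRL releases as I understand it: the generation batch is `per_device_train_batch_size * num_processes * steps_per_generation`, `steps_per_generation` defaults to `gradient_accumulation_steps`, and the result must be a multiple of `num_generations`. The helper name is hypothetical; verify the rule against the `GRPOConfig` documentation for your installed version.

```python
def prompts_per_generation_batch(per_device_batch_size, num_processes,
                                 steps_per_generation, num_generations):
    """Return the number of unique prompts in one generation batch.

    Assumption: generation batch = per-device batch * processes * steps,
    and it must be divisible by num_generations.
    """
    generation_batch_size = (
        per_device_batch_size * num_processes * steps_per_generation
    )
    if generation_batch_size % num_generations != 0:
        raise ValueError(
            f"generation batch size {generation_batch_size} is not divisible "
            f"by num_generations={num_generations}"
        )
    return generation_batch_size // num_generations


# This notebook's defaults on one training process: 2 * 1 * 4 = 8 completions.
print(prompts_per_generation_batch(2, 1, 4, 8))  # G=8
print(prompts_per_generation_batch(2, 1, 4, 4))  # G=4
```

Under these assumptions the default settings give one prompt per generation batch at `G=8` and two at `G=4`. With `G=16` the same single-process settings fail the check, so the batch size, accumulation steps, or number of training processes would have to grow.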

## Self-Distillation Notes

Since the reference outputs were generated by the same Qwen model being trained:

- **Semantic similarity rewards will be high from the start** — the model already "knows" how to produce similar content
- **Language reward is the key differentiator** — it pushes the model toward more natural Norwegian
- Consider **increasing the language reward weight** (e.g., `[1.5, 2.0, 0.5, 1.0]`) to emphasize Norwegian fluency
- Monitor the reward curves. If all rewards plateau early in training, reduce the learning rate or increase `epsilon` (the clipping range)
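
The effect of re-weighting can be checked by hand before committing to a long run. This sketch assumes the per-completion reward is the weighted sum of the individual reward function outputs, which is how TRL documents `reward_weights`. The scores for the two completions are invented for illustration, and `combined_reward` is a hypothetical helper.

```python
def combined_reward(scores, weights):
    """Weighted sum of per-function rewards for one completion."""
    return sum(s * w for s, w in zip(scores, weights))


# Order: [semantic, language, length, accuracy]
default_weights = [2.0, 1.0, 0.5, 1.5]   # REWARD_WEIGHTS in this notebook
fluency_weights = [1.5, 2.0, 0.5, 1.0]   # alternative suggested above

# Hypothetical scores: similar content, different Norwegian fluency.
fluent = [0.90, 0.95, 0.8, 0.8]
stilted = [0.95, 0.60, 0.8, 0.8]

gap_default = combined_reward(fluent, default_weights) - combined_reward(stilted, default_weights)
gap_fluency = combined_reward(fluent, fluency_weights) - combined_reward(stilted, fluency_weights)

print(f"gap with default weights: {gap_default:.3f}")
print(f"gap with fluency weights: {gap_fluency:.3f}")
```

For these made-up scores the gap between the fluent and the stilted completion grows from 0.25 to 0.625. Since advantages are computed relative to the other completions in the group, a larger gap gives fluency more influence on the update.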

## References

- [GSPO Paper](https://arxiv.org/abs/2507.18071)
- [TRL: GRPOTrainer](https://huggingface.co/docs/trl/grpo_trainer)
- [Qwen3.5-35B-A3B](https://huggingface.co/Qwen/Qwen3.5-35B-A3B)