# Phase 1.5: Prepare GSPO Dataset

This notebook transforms the improved Norwegian Alpaca dataset (from Phase 1) into a format ready for GSPO (Group Sequence Policy Optimization) alignment training.

**What this adds:**
- `prompt` — Chat-formatted prompt (`[{"role": "user", "content": ...}]`) compatible with TRL's `GRPOTrainer` and Qwen3.5's chat template
- `task_type` — Heuristic classification of each instruction (qa, generation, classification, extraction, rewriting, creative, other) used for task-specific reward routing

**Input:** `norwegian_alpaca_improved.parquet` (from Phase 1)

**Output:** `norwegian_alpaca_gspo.parquet` (ready for Phase 2)

In [None]:
# Install the library (editable mode)
# On Colab, clone the repo first:
#   !git clone https://github.com/your-username/NORAI-Tools.git /content/NORAI-Tools
#   %pip install -e /content/NORAI-Tools

%pip install -e ..

from norai_tools import (
    prepare_gspo_dataset,
    validate_gspo_dataset,
    classify_task_type,
    OUTPUT_FILE,
)
from datasets import load_dataset
import pandas as pd

In [None]:
# ============================================================
# Configuration + load improved dataset from Phase 1
# ============================================================

# Input comes from the path constant exported by norai_tools
# (noted in this notebook as "norwegian_alpaca_improved.parquet").
IMPROVED_DATASET_PATH = OUTPUT_FILE
GSPO_OUTPUT_PATH = "norwegian_alpaca_gspo.parquet"

# Hub upload is opt-in; edit HUB_REPO_ID before switching it on.
PUSH_TO_HUB = False
HUB_REPO_ID = "your-username/norwegian-alpaca-gspo"

dataset = load_dataset(
    "parquet",
    data_files=IMPROVED_DATASET_PATH,
    split="train",
)

print(f"Loaded: {len(dataset)} rows")
print(f"Columns: {dataset.column_names}")

# Fail fast when any column expected from Phase 1 is absent.
required = [
    "instruction_improved",
    "input_improved",
    "output_improved",
    "instruction_en",
]
missing = [name for name in required if name not in dataset.column_names]
if missing:
    raise ValueError(f"Missing columns from Phase 1: {missing}")
print("All required columns present.")

In [None]:
# ============================================================
# Prepare GSPO dataset (add prompt + task_type columns)
# ============================================================

gspo_dataset = prepare_gspo_dataset(dataset)

print(f"GSPO dataset: {len(gspo_dataset)} rows")
print(f"Columns: {gspo_dataset.column_names}")

# Print the first row's prompt as a quick visual check of the result.
first_row = gspo_dataset[0]
print("\nSample prompt (row 0):")
print(first_row["prompt"])

In [None]:
# ============================================================
# Validate the GSPO dataset
# ============================================================

validation = validate_gspo_dataset(gspo_dataset)

# (label, key) pairs for the summary; labels are pre-padded so values align.
summary_fields = (
    ("Total rows:      ", "total_rows"),
    ("Empty prompts:   ", "empty_prompts"),
    ("Missing columns: ", "missing_columns"),
    ("Valid:           ", "is_valid"),
)
for label, key in summary_fields:
    print(f"{label}{validation[key]}")

# This cell only reports; it does not stop execution on failure.
if not validation["is_valid"]:
    print("\nWARNING: Dataset failed validation. Check issues above before proceeding.")

In [None]:
# ============================================================
# Inspect task_type distribution + sample prompts
# ============================================================

import matplotlib.pyplot as plt

PREVIEW_CHARS = 120  # max characters of each sample prompt to print

# Per-task counts taken from the validation result, largest first.
dist = validation["task_type_distribution"]
sorted_dist = dict(sorted(dist.items(), key=lambda x: -x[1]))

total_rows = len(gspo_dataset)
print("Task type distribution:")
for task, count in sorted_dist.items():
    print(f"  {task:20s}: {count:6d} ({100*count/total_rows:.1f}%)")

# Bar chart
fig, ax = plt.subplots(figsize=(8, 4))
ax.barh(list(sorted_dist.keys()), list(sorted_dist.values()))
# barh places the first category at the bottom; flip the y-axis so the most
# frequent task type is on top, in the same order as the printed table.
ax.invert_yaxis()
ax.set_xlabel("Count")
ax.set_title("Task Type Distribution")
plt.tight_layout()
plt.show()

# Show one sample prompt per task type (first occurrence in dataset order)
print("\nSample prompts by task type:")
shown = set()
for row in gspo_dataset:
    task = row["task_type"]
    if task not in shown:
        shown.add(task)
        messages = row["prompt"]
        # The validation cell only warns about empty prompts, so a row may
        # reach this point without a usable first message; show it as empty
        # instead of raising.
        content = (messages[0]["content"] if messages else "") or ""
        # Only mark the preview as truncated when text was actually cut.
        suffix = "..." if len(content) > PREVIEW_CHARS else ""
        print(f"  [{task}] {content[:PREVIEW_CHARS]}{suffix}")
    # Stop scanning once every task type in the distribution has a sample.
    if len(shown) == len(sorted_dist):
        break

In [None]:
# ============================================================
# Save to parquet + optional Hub push
# ============================================================

# Refuse to persist or publish a dataset that failed validation. The
# validation cell only prints a warning, so under "Run All" execution would
# otherwise continue straight into the write and the optional Hub push.
if not validation["is_valid"]:
    raise ValueError(
        "GSPO dataset failed validation; refusing to save or push. "
        "Fix the reported issues and re-run the validation cell."
    )

gspo_dataset.to_parquet(GSPO_OUTPUT_PATH)
print(f"Saved GSPO dataset to: {GSPO_OUTPUT_PATH}")
print(f"  Rows: {len(gspo_dataset)}")
print(f"  Columns: {gspo_dataset.column_names}")

# Hub upload is opt-in via the PUSH_TO_HUB flag; the repo is created private.
if PUSH_TO_HUB:
    gspo_dataset.push_to_hub(HUB_REPO_ID, private=True)
    print(f"Pushed to Hub: {HUB_REPO_ID}")
else:
    print("\nSet PUSH_TO_HUB = True and update HUB_REPO_ID to push to the Hub.")