# Clarity Classification with GRPO (Reinforcement Learning)

This notebook trains a Qwen3-14B model (loaded 4-bit quantized, with LoRA adapters) using:
1. **Phase 1 (SFT)**: Pre-fine-tune on a Chain-of-Thought (CoT) dataset to learn the reasoning format
2. **Phase 2 (GRPO)**: Reinforcement learning with classification accuracy rewards

Task: Classify political interview responses as:
- Clear Reply
- Clear Non-Reply  
- Ambivalent

### Installation

In [None]:
# =============================================================================
# FIX: Clean reinstall of typing_extensions to fix corrupted package
# =============================================================================
# This cell is order-dependent and should be run top to bottom on a fresh runtime:
#   1-3. remove, reinstall and verify typing_extensions
#   then install unsloth / vllm (versions are pinned on the Colab path only)
#   4-5. re-verify typing_extensions and force-reinstall it once more
#   finally, check that `TypeIs` can be imported in this kernel.
print("=== Step 1: Completely removing typing_extensions ===")
!pip uninstall typing_extensions -y
!pip cache purge
# Remove any stale files
# NOTE(review): the dist-packages path below hard-codes Python 3.11; on a runtime
# with a different Python version it presumably matches nothing — confirm.
!rm -rf /usr/local/lib/python3.11/dist-packages/typing_extensions*
!rm -rf /root/.cache/pip
print()

print("=== Step 2: Fresh install of typing_extensions ===")
!pip install --no-cache-dir "typing_extensions>=4.10.0"
print()

print("=== Step 3: Verify installation ===")
!pip show typing_extensions | grep -E "Version|Location"
# Check the actual package structure
# (runs in a separate python3 process, so it reflects what is on disk rather than
# what this kernel has already imported)
!python3 -c "import typing_extensions; print('Package file:', typing_extensions.__file__); print('Has TypeIs:', hasattr(typing_extensions, 'TypeIs'))"
print()

import os
# Environment flag set before `unsloth` is imported later in the notebook;
# presumably read by Unsloth's vLLM integration — TODO confirm it is still needed.
os.environ["UNSLOTH_VLLM_STANDBY"] = "1"

# Detect environment
# Colab is assumed whenever the substring "COLAB_" appears anywhere in the
# concatenated environment-variable names.
is_colab = "COLAB_" in "".join(os.environ.keys())
print(f"=== Environment: {'Google Colab' if is_colab else 'Local/Other'} ===")
print()

if not is_colab:
    # Non-Colab path: installs the latest releases (no version pins).
    print("Installing unsloth and vllm...")
    !pip install unsloth vllm
else:
    # Colab path: pinned vllm / transformers / trl versions, installed through uv.
    print("Installing via uv (Colab path)...")
    !pip install --upgrade -qqq uv
    !uv pip install vllm==0.11.2 unsloth-zoo unsloth
    !uv pip install transformers==4.56.2
    !uv pip install --no-deps trl==0.22.2

print()
print("=== Step 4: Re-verify typing_extensions after other installs ===")
!pip show typing_extensions | grep -E "Version|Location"
!python3 -c "import typing_extensions; print('Has TypeIs:', hasattr(typing_extensions, 'TypeIs'))"
print()

# Force-reinstall typing_extensions once more. This step is unconditional: it
# runs even when the Step 4 check above already reported TypeIs as available.
print("=== Step 5: Final verification/fix ===")
!pip install --no-cache-dir --force-reinstall "typing_extensions>=4.10.0"
print()

# Final test
# Unlike the python3 subprocess checks above, this runs inside the notebook kernel.
print("=== FINAL TEST: TypeIs import ===")
try:
    # Force Python to reload the module
    # (so a copy imported earlier in this kernel is re-executed from disk)
    import importlib
    import typing_extensions
    importlib.reload(typing_extensions)
    from typing_extensions import TypeIs
    print("SUCCESS: TypeIs is importable!")
except ImportError as e:
    # Only ImportError is handled; the cell then asks for a manual runtime restart.
    print(f"FAILED: {e}")
    print("\n*** IMPORTANT: You MUST restart the runtime now! ***")
    print("Go to: Runtime -> Restart runtime, then run this cell again.")

### Load Model with Unsloth

In [None]:
from unsloth import FastLanguageModel
import torch

# Context length passed to the model loader; later cells also use it to cap SFT
# example length and to derive the GRPO completion budget.
max_seq_length = 4096  # Increased for longer CoT reasoning
# LoRA rank: passed as max_lora_rank at load time, as `r` below, and used to derive lora_alpha.
lora_rank = 32

# Load the 4-bit checkpoint and its tokenizer through Unsloth.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/Qwen3-14B-unsloth-bnb-4bit",
    max_seq_length=max_seq_length,
    load_in_4bit=True,
    fast_inference=False,  # Disabled due to vLLM/LoRA version incompatibility
    max_lora_rank=lora_rank,
    # gpu_memory_utilization=0.9,  # Only needed with fast_inference=True
)

# Attach LoRA adapters to the listed projection modules; `model` is rebound to
# the adapter-wrapped model that both training phases update.
model = FastLanguageModel.get_peft_model(
    model,
    r=lora_rank,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_alpha=lora_rank * 2,  # alpha is always twice the rank
    use_gradient_checkpointing="unsloth",
    random_state=3407,
)

### Chat Template for Clarity Classification

We use a chat template that encourages step-by-step reasoning before outputting the final classification label.

In [None]:
# Marker string for the label line (same "LABEL:" prefix the extraction regex looks for)
solution_start = "LABEL: "

# Instruction used as the system turn of every conversation
system_prompt = (
    "You are an expert political discourse analyst. "
    "Analyze political interviews step by step and classify response clarity."
)

# The three classes a response can be assigned to
VALID_LABELS = [
    "Clear Reply",
    "Clear Non-Reply",
    "Ambivalent",
]

print("System prompt: " + system_prompt)
print("Valid labels: " + str(VALID_LABELS))

In [None]:
# Chat template for clarity classification
# This template handles multi-turn conversations with system/user/assistant roles:
#   - a leading system message (or, when absent, the default `system_prompt`)
#     is emitted followed by the EOS token
#   - user turns are emitted verbatim; assistant turns are emitted followed by EOS
#   - messages with any other role are skipped
#   - `add_generation_prompt` appends no extra text

# The default system prompt is spliced into the template as a single-quoted Jinja
# string literal, so backslashes and single quotes must be escaped; otherwise a
# prompt containing an apostrophe would produce an invalid template.
_system_prompt_literal = system_prompt.replace("\\", "\\\\").replace("'", "\\'")

chat_template = \
    "{% if messages[0]['role'] == 'system' %}" \
        "{{ messages[0]['content'] + eos_token }}" \
        "{% set loop_messages = messages[1:] %}" \
    "{% else %}" \
        "{{ '" + _system_prompt_literal + "' + eos_token }}" \
        "{% set loop_messages = messages %}" \
    "{% endif %}" \
    "{% for message in loop_messages %}" \
        "{% if message['role'] == 'user' %}" \
            "{{ message['content'] }}" \
        "{% elif message['role'] == 'assistant' %}" \
            "{{ message['content'] + eos_token }}" \
        "{% endif %}" \
    "{% endfor %}" \
    "{% if add_generation_prompt %}{{ '' }}{% endif %}"

# Replace the tokenizer's built-in template; every later apply_chat_template call uses this one.
tokenizer.chat_template = chat_template

In [None]:
# Render one complete system/user/assistant exchange through the custom template
# and print the resulting text for visual inspection.
test_messages = [
    dict(role="system", content=system_prompt),
    dict(
        role="user",
        content="Classify this response...\n\nQuestion: What is your policy?\nAnswer: We are working on it.",
    ),
    dict(
        role="assistant",
        content="Step 1 - The question asks for a specific policy.\nStep 2 - The answer is vague and does not provide details.\n\nLABEL: Ambivalent",
    ),
]

rendered_example = tokenizer.apply_chat_template(test_messages, tokenize=False)
print(rendered_example)

---
## Phase 1: Supervised Fine-Tuning (SFT) on CoT Dataset

First, we pre-fine-tune the model on the Chain-of-Thought dataset to learn the reasoning format.

In [None]:
from datasets import load_dataset
import pandas as pd

# Load the CoT dataset
# Reads a local JSON-lines file; the path is relative to the working directory.
# Later cells read a "conversations" list of role/content messages from each record.
cot_dataset = load_dataset("json", data_files="cot_data/train_cot.jsonl", split="train")
print(f"Loaded {len(cot_dataset)} CoT examples")
print(f"Features: {cot_dataset.features}")

In [None]:
# Preview each turn of the first CoT conversation; turns longer than
# 500 characters are cut off and marked with a trailing "...".
print("Example conversation:")
for msg in cot_dataset[0]["conversations"]:
    print(f"\n[{msg['role'].upper()}]:")
    shown = msg['content']
    if len(shown) > 500:
        shown = shown[:500] + "..."
    print(shown)

In [None]:
# Render every CoT conversation into one training string for SFT
def format_cot_for_sft(example):
    """Render a record's conversation into a single string via the chat template.

    Args:
        example: dataset record with a "conversations" list of chat messages.

    Returns:
        A dict with one key, "text", holding the rendered (untokenized) string.
    """
    rendered = tokenizer.apply_chat_template(example["conversations"], tokenize=False)
    return dict(text=rendered)

sft_dataset = cot_dataset.map(format_cot_for_sft)
print(f"SFT dataset size: {len(sft_dataset)}")

In [None]:
# Check token lengths and filter out very long examples
def get_token_length(example):
    """Return {"token_length": n}, where n is the length of tokenizer.encode() on the rendered text."""
    return {"token_length": len(tokenizer.encode(example["text"]))}

sft_dataset = sft_dataset.map(get_token_length)

import numpy as np
lengths = np.array(sft_dataset["token_length"])
print(f"Token length stats: min={lengths.min()}, max={lengths.max()}, mean={lengths.mean():.0f}, median={np.median(lengths):.0f}")

# Filter to keep examples that fit in context
# The cutoff is the smaller of the model context length and the 95th percentile
# of observed lengths, so roughly the longest 5% of examples are dropped even
# when they would fit in max_seq_length.
max_sft_length = min(max_seq_length, int(np.percentile(lengths, 95)))
print(f"Using max SFT length: {max_sft_length}")

sft_dataset = sft_dataset.filter(lambda x: x["token_length"] <= max_sft_length)
print(f"Filtered SFT dataset size: {len(sft_dataset)}")

In [None]:
from trl import SFTTrainer, SFTConfig

# Build the SFT trainer over the rendered "text" column of sft_dataset.
# Effective batch size is per_device_train_batch_size * gradient_accumulation_steps = 4
# sequences per optimizer step (per device).
# NOTE(review): no sequence-length limit is passed to SFTConfig, so whatever
# default the installed TRL/Unsloth applies is in effect — confirm it is not
# shorter than max_sft_length, otherwise long CoT examples could be truncated
# before their final LABEL line.
sft_trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=sft_dataset,
    args=SFTConfig(
        dataset_text_field="text",
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        warmup_steps=10,
        num_train_epochs=1,  # 1 epoch for format learning
        learning_rate=2e-5,
        logging_steps=10,
        optim="adamw_8bit",
        weight_decay=0.001,
        lr_scheduler_type="linear",
        seed=3407,
        output_dir="outputs_sft",
        report_to="none",  # no external experiment tracker
    ),
)

In [None]:
# Run SFT training
# Updates the LoRA-wrapped `model` in place; the GRPO phase below continues from these weights.
print("Starting SFT training...")
sft_trainer.train()
print("SFT training complete!")

In [None]:
# Test if model learned the format
# Builds a prompt from the first two turns of the first CoT example and streams
# one sampled completion to the cell output for visual inspection.
test_prompt = cot_dataset[0]["conversations"][:2]  # System + user only
text = tokenizer.apply_chat_template(test_prompt, tokenize=False, add_generation_prompt=True)

from transformers import TextStreamer
print("Testing model output after SFT:")
print("=" * 50)
_ = model.generate(
    **tokenizer(text, return_tensors="pt").to("cuda"),
    # Request sampling explicitly: `temperature` only takes effect when sampling
    # is enabled, and the final inference cell of this notebook also passes
    # do_sample=True with the same temperature.
    do_sample=True,
    temperature=0.7,
    max_new_tokens=1024,
    streamer=TextStreamer(tokenizer, skip_prompt=True),  # print only the generated continuation
)

In [None]:
# Cleanup before GRPO
# Order matters: drop the references, let Python's garbage collector reclaim the
# objects (including any held in reference cycles), and only then ask PyTorch
# to return its now-unused cached CUDA blocks. Calling empty_cache() before
# gc.collect() would leave memory freed by the collection sitting in the cache.
import gc

del sft_dataset, sft_trainer
gc.collect()
torch.cuda.empty_cache()

---
## Phase 2: GRPO (Reinforcement Learning)

Now we use GRPO with classification accuracy as the reward signal.

In [None]:
# Load the training data with ground truth labels for GRPO
# Reads a local parquet file (path relative to the working directory). Later
# cells read a "conversations" list from each record and take the assistant
# turn as the ground-truth label.
grpo_raw_dataset = load_dataset("parquet", data_files="data/train/train.parquet", split="train")
print(f"Loaded {len(grpo_raw_dataset)} examples for GRPO")
print(f"Features: {grpo_raw_dataset.features}")

In [None]:
# Look at the data structure
# Prints the role and the first 100 characters of every turn in the first record.
# The trailing "..." is appended unconditionally, even for turns shorter than 100 characters.
print("Example conversation structure:")
example = grpo_raw_dataset[0]
for msg in example["conversations"]:
    print(f"  [{msg['role']}]: {msg['content'][:100]}...")

In [None]:
def prepare_grpo_dataset(example):
    """Split one conversation record into a GRPO prompt and its ground-truth answer.

    Args:
        example: record with a "conversations" list of {"role", "content"} messages.

    Returns:
        A dict with:
          - "prompt": every message whose role is not "assistant", in original order
          - "answer": the whitespace-stripped content of the first assistant
            message, or None when the conversation has no assistant message
    """
    turns = example["conversations"]

    # Ground truth = first assistant turn, if there is one
    first_reply = next((turn for turn in turns if turn["role"] == "assistant"), None)
    answer = None if first_reply is None else first_reply["content"].strip()

    # Prompt = everything the model should see (all non-assistant turns)
    prompt = [turn for turn in turns if turn["role"] != "assistant"]

    return dict(prompt=prompt, answer=answer)

# Add "prompt" and "answer" fields to every record via prepare_grpo_dataset.
grpo_dataset = grpo_raw_dataset.map(prepare_grpo_dataset)
print(f"GRPO dataset prepared: {len(grpo_dataset)} examples")

In [None]:
# Check label distribution
from collections import Counter
# Count ground-truth answers (this runs before the valid-label filter, so
# unexpected or missing answers show up here too) and list them most frequent first.
label_counts = Counter(grpo_dataset["answer"])
print("Label distribution:")
for label, count in label_counts.most_common():
    print(f"  {label}: {count}")

In [None]:
# Filter to valid labels only
# Membership is an exact string comparison against VALID_LABELS, so answers that
# differ in case or carry extra text (and records with no answer) are dropped.
grpo_dataset = grpo_dataset.filter(lambda x: x["answer"] in VALID_LABELS)
print(f"Filtered GRPO dataset: {len(grpo_dataset)} examples with valid labels")

### Reward Functions for Classification Accuracy

In [None]:
import re

# Regex to extract the label from model output
# Matches "LABEL:" followed by optional whitespace and one of the three class
# names, case-insensitively; the class name is captured in group 1.
label_pattern = re.compile(
    r"LABEL:\s*(Clear Reply|Clear Non-Reply|Ambivalent)",
    flags=re.IGNORECASE
)

# Lower-cased spelling -> canonical spelling (built once instead of on every call)
_CANONICAL_LABELS = {
    "clear reply": "Clear Reply",
    "clear non-reply": "Clear Non-Reply",
    "ambivalent": "Ambivalent",
}

def extract_label(text):
    """Extract the final classification label from model output.

    The model reasons step by step before committing to an answer, so when the
    output contains several "LABEL: ..." occurrences the last one is treated as
    the final classification.

    Args:
        text: generated text to search; None or an empty string yields None.

    Returns:
        The canonical label ("Clear Reply", "Clear Non-Reply" or "Ambivalent"),
        or None when no "LABEL: <class>" occurrence is present.
    """
    if not text:
        return None

    # findall returns the captured class names in order of appearance
    found = label_pattern.findall(text)
    if not found:
        return None

    # Normalize the label (the match is case-insensitive)
    label = found[-1].strip()
    return _CANONICAL_LABELS.get(label.lower(), label)

# Test the extraction
# Quick visual check of extract_label; the third sample has no "LABEL:" prefix
# and is expected to print None.
test_outputs = [
    "Step 1... Step 2... LABEL: Clear Reply",
    "Some reasoning here.\n\nLABEL: Ambivalent",
    "The answer is Clear Non-Reply based on...",  # No LABEL prefix
    "LABEL: clear reply",  # lowercase
]

# Each input is shown cut to its first 50 characters; the "..." is always appended.
for out in test_outputs:
    print(f"Input: '{out[:50]}...' -> Extracted: {extract_label(out)}")

In [None]:
# Bookkeeping for the periodic debug print inside the reward function
PRINTED_TIMES = 0       # number of reward-function calls that had at least one completion
PRINT_EVERY_STEPS = 10  # a sample is printed on every Nth such call


def _score_prediction(predicted, true_label):
    """Map one (predicted label, true label) pair onto the reward scale."""
    if predicted is None:
        return -4.0  # No label found
    if predicted == true_label:
        return 5.0  # Exact match
    if predicted.lower() == true_label.lower():
        return 4.0  # Case-insensitive match
    if predicted in VALID_LABELS:
        return -1.0  # Valid label but wrong
    return -3.0  # Invalid label


def classification_accuracy_reward(prompts, completions, answer, **kwargs):
    """
    Score each completion by whether its extracted label matches the ground truth.

    Args:
        prompts: unused; part of the reward-function call signature.
        completions: one entry per completion; the generated text is read from
            completion[0]["content"].
        answer: ground-truth labels, aligned with `completions`.
        **kwargs: any further fields passed by the caller; ignored.

    Returns:
        A list of floats, one per completion:
        - +5.0: predicted label equals the true label
        - +4.0: labels are equal only when compared case-insensitively
        - -1.0: predicted label is in VALID_LABELS but wrong
        - -3.0: a label was extracted that is not in VALID_LABELS
        - -4.0: no label found in the output

    Notes:
        extract_label (defined in an earlier cell) returns either None or a
        canonical label, so the +4.0 tier only applies when the ground-truth
        label itself is not canonically cased, and the -3.0 tier acts as a
        fallback.

    Side effects:
        On every PRINT_EVERY_STEPS-th call, prints the first completion's true
        label, prediction and trailing text; PRINTED_TIMES is incremented once
        per call that has at least one completion.
    """
    global PRINTED_TIMES

    scores = []
    for position, (completion, true_label) in enumerate(zip(completions, answer)):
        response = completion[0]["content"]
        predicted = extract_label(response)

        # Only the first completion of a call is logged / counted
        if position == 0:
            if PRINTED_TIMES % PRINT_EVERY_STEPS == 0:
                rule = "=" * 50
                print(rule)
                print(f"True label: {true_label}")
                print(f"Predicted: {predicted}")
                print(f"Response (last 300 chars): ...{response[-300:]}")
                print(rule)
            PRINTED_TIMES += 1

        scores.append(_score_prediction(predicted, true_label))

    return scores

### Configure and Run GRPO Training

In [None]:
# Calculate max prompt length based on dataset
def get_prompt_length(example):
    """Return {"prompt_length": n}: token count of the prompt rendered with the chat template."""
    text = tokenizer.apply_chat_template(example["prompt"], tokenize=False, add_generation_prompt=True)
    return {"prompt_length": len(tokenizer.encode(text))}

grpo_dataset = grpo_dataset.map(get_prompt_length)

# `np` is the numpy module imported in the SFT token-length cell above.
prompt_lengths = np.array(grpo_dataset["prompt_length"])
print(f"Prompt length stats: min={prompt_lengths.min()}, max={prompt_lengths.max()}, mean={prompt_lengths.mean():.0f}")

# Prompt budget = 90th percentile of observed prompt lengths plus 10 tokens;
# whatever remains of the context window is given to the completion.
# NOTE(review): if the prompt budget reaches max_seq_length, the completion
# budget becomes zero or negative — confirm this cannot happen for this data.
max_prompt_length = int(np.percentile(prompt_lengths, 90)) + 10
max_completion_length = max_seq_length - max_prompt_length

print(f"Max prompt length: {max_prompt_length}")
print(f"Max completion length: {max_completion_length}")

# Filter out prompts that are too long
# (drops the longest prompts, roughly the top 10% given the percentile used above)
grpo_dataset = grpo_dataset.filter(lambda x: x["prompt_length"] <= max_prompt_length)
print(f"Filtered GRPO dataset: {len(grpo_dataset)} examples")

In [None]:
from trl import GRPOConfig, GRPOTrainer

# GRPO training configuration (without vLLM - using standard HF generation)
# Prompt and completion budgets come from the length analysis in the previous cell.
training_args = GRPOConfig(
    # Sampling parameters
    temperature=0.8,
    top_p=0.95,
    top_k=50,
    # Training parameters
    learning_rate=5e-6,
    weight_decay=0.001,
    warmup_ratio=0.1,
    lr_scheduler_type="linear",
    optim="adamw_8bit",
    logging_steps=1,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=2,
    num_generations=4,  # Number of completions per prompt
    max_prompt_length=max_prompt_length,
    max_completion_length=max_completion_length,
    max_steps=500,  # Adjust based on your needs
    save_steps=100,
    report_to="none",  # no external experiment tracker
    output_dir="outputs_grpo",
)

In [None]:
# Create GRPO trainer
# Continues training the same LoRA-wrapped `model` that went through SFT, using
# classification_accuracy_reward as the only reward signal.
# NOTE(review): TRL documents this constructor argument as `processing_class`;
# passing `tokenizer=` presumably relies on Unsloth's compatibility patching —
# confirm against the installed versions.
grpo_trainer = GRPOTrainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
    train_dataset=grpo_dataset,
    reward_funcs=[classification_accuracy_reward],
)

In [None]:
# Run GRPO training
# Runs for the max_steps set in training_args; the reward function also prints
# a sample completion periodically while this is running.
print("Starting GRPO training...")
print("Watch the 'reward' column - it should increase over time!")
print("")
grpo_trainer.train()

---
## Inference: Test the Trained Model

In [None]:
# Test on a few examples
# NOTE: these are the first (up to) three records of grpo_dataset, i.e. examples
# taken from the GRPO training set, so this is a sanity check rather than a
# held-out evaluation.
test_examples = grpo_dataset.select(range(min(3, len(grpo_dataset))))

for i, example in enumerate(test_examples):
    print(f"\n{'='*60}")
    print(f"Example {i+1}")
    print(f"{'='*60}")
    
    # Get the user message content for context
    # (takes the first user turn; raises IndexError if the prompt has none)
    user_content = [m["content"] for m in example["prompt"] if m["role"] == "user"][0]
    print(f"Question excerpt: {user_content[:200]}...")
    print(f"\nTrue label: {example['answer']}")
    
    # Generate
    text = tokenizer.apply_chat_template(
        example["prompt"],
        tokenize=False,
        add_generation_prompt=True
    )
    
    inputs = tokenizer(text, return_tensors="pt").to("cuda")
    # Sampling is enabled, so repeated runs can give different outputs and labels.
    outputs = model.generate(
        **inputs,
        temperature=0.7,
        max_new_tokens=1024,
        do_sample=True,
    )
    
    # Decode only the newly generated tokens (everything after the prompt tokens).
    response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
    predicted = extract_label(response)
    
    print(f"\nModel output (last 500 chars):")
    print(f"...{response[-500:]}")
    print(f"\nExtracted label: {predicted}")
    print(f"Correct: {predicted == example['answer']}")

---
## Save the Model

In [None]:
# Save LoRA adapters
# Writes the model and the tokenizer (including the custom chat template set
# earlier) into the same local directory, relative to the working directory.
model.save_pretrained("clarity_grpo_lora")
tokenizer.save_pretrained("clarity_grpo_lora")
print("Saved LoRA adapters to clarity_grpo_lora/")

In [None]:
# Optional: Merge and save as full model
# model.save_pretrained_merged("clarity_grpo_merged", tokenizer, save_method="merged_16bit")
# print("Saved merged model to clarity_grpo_merged/")

In [None]:
# Optional: Push to Hugging Face Hub
# model.push_to_hub_merged("your-username/clarity-grpo-qwen3-4b", tokenizer, save_method="merged_16bit")