# Stage 3: Constitutional AI Training (v2 - Improved)

**Date**: October 7, 2025  
**Status**: Fully updated with Phases 1-4 improvements  
**Runtime**: ~4-5 hours on T4 GPU

## Key Improvements in v2

- ✅ **Robust reward model loading** (no heuristic fallback)
- ✅ **Automatic data quality filtering** (removes identical pairs)
- ✅ **Enhanced prompts** with few-shot examples
- ✅ **Meta-commentary detection** and validation
- ✅ **Principle tracking** and usage reporting
- ✅ **Training monitoring** with progress callbacks
- ✅ **Checkpoint resume** for disconnections
- ✅ **Quality analysis tools** for validation

## Prerequisites

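- Stage 2 LoRA adapters in Google Drive at `MyDrive/ml-learning/artifacts/stage2_finetuning_artifacts/lora_adapters`
- A Hugging Face token with access to the gated `google/gemma-2b-it` model
- T4 GPU runtime (Runtime > Change runtime type > T4 GPU)
- ~4-5 hours of Colab time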

---

## Cell 1: Setup and Mount Drive

In [None]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# Verify GPU
!nvidia-smi

# Setup directories
!mkdir -p /content/ml-learning
%cd /content/ml-learning

print("\n✓ Drive mounted and directories created")

## Cell 2: Clone Repository and Install Dependencies

In [None]:
# Clone repository (update with your repo URL)
!git clone https://github.com/Jai-Dhiman/ml-learning.git .

# Or pull latest if already cloned
# !git pull origin main

# Install dependencies using pip (Colab doesn't have uv by default)
!pip install -q transformers datasets accelerate peft trl torch sentencepiece protobuf

# Verify installations
import torch
print(f"\n✓ PyTorch: {torch.__version__}")
print(f"✓ CUDA available: {torch.cuda.is_available()}")
print(f"✓ GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'None'}")

## Cell 2.5: Hugging Face Login

In [None]:
import os
from huggingface_hub import login, HfApi

# Clear any existing tokens
os.environ.pop('HF_TOKEN', None)
os.environ.pop('HUGGINGFACEHUB_API_TOKEN', None)

try:
    import getpass as gp
    raw = gp.getpass("Paste your Hugging Face token (input hidden): ")
    token = raw.decode() if isinstance(raw, (bytes, bytearray)) else raw
    if not isinstance(token, str):
        raise TypeError(f"Unexpected token type: {type(token).__name__}")
    token = token.strip()
    if not token:
        raise ValueError("Empty token provided")
    
    # Login and set environment variable
    login(token=token, add_to_git_credential=False)
    os.environ['HF_TOKEN'] = token
    
    who = HfApi().whoami(token=token)
    print(f"✅ Logged in as: {who.get('name') or who.get('email') or 'OK'}")
    print('HF_TOKEN environment variable set for bash cells.')
    
except Exception as e:
    print(f"[HF Login] getpass flow failed: {e}")
    print("Falling back to interactive login widget...")
    login()
    
    # Try to get token from saved credentials
    try:
        # get_token() replaces the deprecated HfFolder.get_token()
        from huggingface_hub import get_token
        token = get_token()
        if token:
            os.environ['HF_TOKEN'] = token
            print('HF_TOKEN environment variable set from saved credentials.')
        who = HfApi().whoami()
        print(f"✅ Logged in as: {who.get('name') or who.get('email') or 'OK'}")
    except Exception as e2:
        print(f"[HF Login] Could not set HF_TOKEN env var: {e2}")
        print("You may need to run 'huggingface-cli login' in a bash cell.")

## Cell 3: Pre-download Models (Avoid Timeouts)

In [None]:
# Pre-download models to avoid timeouts during training
import gc

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForCausalLM

print("Downloading base model (Gemma 2B-IT)...")
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
# float16 roughly halves CPU RAM use; this cell only needs to populate the Hugging Face cache
model = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it", torch_dtype=torch.float16)
print("✓ Base model downloaded")

print("\nDownloading reward model...")
rm_tokenizer = AutoTokenizer.from_pretrained("OpenAssistant/reward-model-deberta-v3-large-v2")
rm_model = AutoModelForSequenceClassification.from_pretrained("OpenAssistant/reward-model-deberta-v3-large-v2")
print("✓ Reward model downloaded")

# Free memory; the downloaded weights stay in the local Hugging Face cache
del model, rm_model, tokenizer, rm_tokenizer
gc.collect()
torch.cuda.empty_cache()

print("\n✓ All models downloaded and cached")

## Cell 4: Generate Critique-Revision Pairs

This will generate about 2500 pairs from 833 prompts (333 from `test` and 500 from `train`), with:
- Reward model scoring (no heuristic)
- Automatic quality filtering
- Few-shot examples in prompts
- Principle tracking
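The quality filter can be pictured roughly as below. This is a minimal sketch, not the code in `critique_revision.py`: it assumes each pair carries the `base_response`, `revised_response`, `base_score`, and `revised_score` fields printed in Cell 5.5, and the `min_delta` threshold and meta-commentary patterns are hypothetical choices for illustration.

```python
import re

# Hypothetical openers signalling the model described its revision instead of just writing it
META_COMMENTARY = re.compile(
    r"^\s*(?:here is|here's) (?:a|the|my) revised|^\s*revised response:",
    re.IGNORECASE,
)

def keep_pair(pair, min_delta=0.05):
    """Return True if a critique-revision pair carries a usable preference signal."""
    base = pair["base_response"].strip()
    revised = pair["revised_response"].strip()
    if base == revised:
        return False  # identical pair: nothing to prefer
    if META_COMMENTARY.search(revised):
        return False  # revision leaks meta-commentary
    delta = pair["revised_score"] - pair["base_score"]
    return abs(delta) >= min_delta  # drop weak reward-model signals in either direction

demo_pairs = [
    {"base_response": "Hi", "revised_response": "Hi", "base_score": 0.1, "revised_score": 0.1},
    {"base_response": "No.", "revised_response": "Here is a revised response: Sure!",
     "base_score": 0.0, "revised_score": 1.0},
    {"base_response": "No.", "revised_response": "Happy to help with that.",
     "base_score": -0.5, "revised_score": 0.7},
]
print([keep_pair(p) for p in demo_pairs])  # [False, False, True]
```

Using `abs(delta)` keeps pairs where the base response clearly wins as well, since DPO only needs a confident preference, not necessarily one favoring the revision.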

In [None]:
# Generate critique-revision pairs
!python critique-revision-system/src/training/critique_revision.py \
    --num-examples 833 \
    --output artifacts/stage3_v2/pairs/pairs.jsonl \
    --adapter-path /content/drive/MyDrive/ml-learning/artifacts/stage2_finetuning_artifacts/lora_adapters \
    --split "test[:333]+train[:500]" \
    --seed 42

print("\n✓ Pair generation complete")
print("\nCheck output above for:")
print("- Reward model loaded successfully (not heuristic)")
print("- Data quality filter stats")
print("- Principle usage distribution")

## Cell 5: Analyze Generated Pairs

Run quality analysis before training

In [None]:
# Run comprehensive quality analysis
!python critique-revision-system/scripts/analyze_pairs.py \
    artifacts/stage3_v2/pairs/pairs.jsonl

print("\n" + "="*80)
print("QUALITY CHECKLIST:")
print("="*80)
print("✓ Should see: Overall Quality >= 3/5")
print("✓ Should see: Identical pairs < 10%")
print("✓ Should see: Revised better rate > 40%")
print("✓ Should see: Average delta positive")
print("\nIf quality score < 3, review issues above before training.")
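Several checklist numbers can be recomputed directly from `pairs.jsonl` as an independent sanity check. This sketch assumes the field names shown in Cell 5.5 and counts a pair as identical only when the stripped responses match exactly; `analyze_pairs.py` may define these metrics differently, and its overall 1-5 quality score is not reproduced here.

```python
import json
from statistics import mean

def checklist_stats(lines):
    """Identical-pair rate, revised-better rate, and mean score delta for JSONL pair lines."""
    pairs = [json.loads(line) for line in lines if line.strip()]
    if not pairs:
        raise ValueError("no pairs found")
    n = len(pairs)
    identical = sum(p["base_response"].strip() == p["revised_response"].strip() for p in pairs)
    better = sum(p["revised_score"] > p["base_score"] for p in pairs)
    return {
        "identical_rate": identical / n,        # checklist target: < 10%
        "revised_better_rate": better / n,      # checklist target: > 40%
        "mean_delta": mean(p["revised_score"] - p["base_score"] for p in pairs),  # should be > 0
    }

demo_lines = [
    '{"base_response": "a", "revised_response": "a", "base_score": 0.2, "revised_score": 0.2}',
    '{"base_response": "a", "revised_response": "b", "base_score": 0.1, "revised_score": 0.5}',
]
print(checklist_stats(demo_lines))
```

In the notebook, pass the open file directly: `with open("artifacts/stage3_v2/pairs/pairs.jsonl") as f: print(checklist_stats(f))`.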

## Cell 5.5: Manual Inspection (Optional)

Manually review a few pairs

In [None]:
import json

# Load and inspect a few pairs
with open("artifacts/stage3_v2/pairs/pairs.jsonl") as f:
    pairs = [json.loads(line) for line in f if line.strip()]

print(f"Total pairs: {len(pairs)}\n")
print("=" * 80)
print("Sample Pairs:\n")

for i, p in enumerate(pairs[:3], 1):
    print(f"Pair {i}:")
    print(f"Prompt: {p['prompt'][:100]}...")
    print(f"Base: {p['base_response'][:100]}...")
    print(f"Revised: {p['revised_response'][:100]}...")
    print(f"Scores: {p['base_score']:.3f} -> {p['revised_score']:.3f}")
    print(f"Chosen: {p['chosen']}")
    print(f"Principles: {p.get('principle_ids', [])}")
    print("\n" + "-" * 80 + "\n")
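To see the principle usage distribution across all pairs rather than just three samples, the `principle_ids` lists can be tallied. This sketch assumes `principle_ids` is a list (as the `.get(..., [])` default above suggests) and treats a missing or `None` value as empty.

```python
from collections import Counter

def principle_usage(pairs):
    """Count how often each principle id appears across all pairs."""
    counts = Counter()
    for p in pairs:
        counts.update(p.get("principle_ids") or [])
    return counts

demo = [{"principle_ids": ["harm", "honesty"]}, {"principle_ids": ["harm"]}, {}]
for pid, n in principle_usage(demo).most_common():
    print(f"{pid}: {n}")
```

With the `pairs` list loaded in the cell above, `principle_usage(pairs).most_common(10)` lists the ten most-used principles; a heavily skewed distribution may mean some principles rarely influence training.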

## Cell 6: DPO Training

Train for 3 epochs with:
- Progress monitoring (ETA, metrics)
- Checkpoint saving every 50 steps
- Automatic resume if disconnected
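The step count reported after training follows from simple arithmetic. With a per-device batch size of 1 and 8 gradient-accumulation steps the effective batch is 8, so ~486 steps over 3 epochs means about 162 steps per epoch, i.e. roughly 1,300 pairs surviving filtering. The helper below is a sketch that assumes Trainer-style floor division of batches by accumulation steps, a single GPU, and no held-out eval split; other `transformers` versions may round slightly differently.

```python
import math

def expected_optimizer_steps(num_pairs, per_device_batch=1, grad_accum=8,
                             epochs=3.0, num_devices=1):
    """Approximate optimizer steps for a Trainer-style loop (floor per epoch)."""
    batches_per_epoch = math.ceil(num_pairs / (per_device_batch * num_devices))
    steps_per_epoch = max(batches_per_epoch // grad_accum, 1)
    return math.ceil(epochs * steps_per_epoch)

print(expected_optimizer_steps(1296))  # 486
print(expected_optimizer_steps(2500))  # 936
```

If all ~2500 pairs from Cell 4 survive filtering, expect roughly 936 steps, about twice the figure quoted below, and a proportionally longer runtime.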

In [None]:
# Run DPO training with improved configuration
!python critique-revision-system/src/training/train_dpo_stage3.py \
    --repo-root . \
    --pairs-path artifacts/stage3_v2/pairs/pairs.jsonl \
    --base-model-id google/gemma-2b-it \
    --stage2-adapter-path /content/drive/MyDrive/ml-learning/artifacts/stage2_finetuning_artifacts/lora_adapters \
    --output-dir artifacts/stage3_v2/constitutional \
    --per-device-train-batch-size 1 \
    --gradient-accumulation-steps 8 \
    --learning-rate 5e-5 \
    --num-train-epochs 3.0 \
    --beta 0.1 \
    --save-steps 50 \
    --logging-steps 10 \
    --seed 42

print("\n✓ DPO training complete")
print("\nCheck output above for:")
print("- Roughly 486 optimizer steps over 3 epochs (effective batch 1 x 8 = 8; exact count depends on how many pairs survive filtering)")
print("- DPO accuracy metrics")
print("- Reward margins")
print("- Final metrics summary")
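To interpret the DPO metrics: TRL logs `rewards/margins` and `rewards/accuracies`, derived from each response's implicit reward, beta times (policy log-prob minus reference log-prob). The plain-Python sketch below computes the standard sigmoid DPO loss and those two metrics from summed sequence log-probabilities. It is not the training script's code; the log-prob values are made-up inputs, and `beta=0.1` mirrors the `--beta` flag above.

```python
import math

def neg_log_sigmoid(x):
    """Numerically stable -log(sigmoid(x))."""
    if x >= 0:
        return math.log1p(math.exp(-x))
    return -x + math.log1p(math.exp(x))

def dpo_stats(policy_chosen, policy_rejected, ref_chosen, ref_rejected, beta=0.1):
    """Sigmoid DPO loss plus margin/accuracy metrics, averaged over a batch."""
    losses, margins, wins = [], [], []
    for pc, pr, rc, rr in zip(policy_chosen, policy_rejected, ref_chosen, ref_rejected):
        chosen_reward = beta * (pc - rc)      # implicit reward of the chosen response
        rejected_reward = beta * (pr - rr)    # implicit reward of the rejected response
        margin = chosen_reward - rejected_reward
        losses.append(neg_log_sigmoid(margin))
        margins.append(margin)
        wins.append(1.0 if chosen_reward > rejected_reward else 0.0)
    n = len(losses)
    return {
        "loss": sum(losses) / n,
        "rewards/margins": sum(margins) / n,
        "rewards/accuracies": sum(wins) / n,
    }

# Before training the policy equals the reference: margin 0, loss ln 2, accuracy 0
print(dpo_stats([-10.0], [-12.0], [-10.0], [-12.0]))
# Chosen response gained probability, rejected lost: positive margin, accuracy 1
print(dpo_stats([-10.0], [-12.0], [-11.0], [-11.0]))
```

A healthy run typically starts near a loss of 0.693 with accuracy near 0, then shows margins growing and accuracy rising over training.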

## Cell 6.5: Resume Training (If Disconnected)

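If training was interrupted, run this cell to resume. Checkpoints are written under `artifacts/stage3_v2/constitutional/checkpoints` on the runtime's local disk, so they survive a kernel restart, but a full runtime disconnect or reset wipes `/content`. In that case, re-run Cells 1-3, restore `pairs.jsonl` and a checkpoint from Drive (a checkpoint is only there if you copied one during training), then run this cell.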
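If a checkpoint copy exists on Drive (for example a `checkpoint-*` folder in `MyDrive/ml-learning/artifacts/stage3_v2`, where Cell 7 puts it), a helper like the hypothetical `restore_latest_checkpoint` below copies it back to `artifacts/stage3_v2/constitutional/checkpoints`, the directory Cell 7 reads checkpoints from. It assumes the training script looks for checkpoints there, as this cell's comments suggest. Restore `pairs.jsonl` the same way with `shutil.copy`.

```python
import shutil
from pathlib import Path

def restore_latest_checkpoint(drive_dir, local_ckpt_dir):
    """Copy the highest-numbered checkpoint-* folder from drive_dir into local_ckpt_dir."""
    candidates = [p for p in Path(drive_dir).glob("checkpoint-*") if p.is_dir()]
    if not candidates:
        return None  # nothing to restore
    latest = max(candidates, key=lambda p: int(p.name.split("-")[1]))
    dest = Path(local_ckpt_dir) / latest.name
    dest.parent.mkdir(parents=True, exist_ok=True)
    shutil.copytree(latest, dest, dirs_exist_ok=True)
    return dest

restored = restore_latest_checkpoint(
    "/content/drive/MyDrive/ml-learning/artifacts/stage3_v2",
    "artifacts/stage3_v2/constitutional/checkpoints",
)
print(f"Restored: {restored}")
```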

In [None]:
# This will automatically resume from the latest checkpoint
# Just re-run the same training command - it detects checkpoints automatically

!python critique-revision-system/src/training/train_dpo_stage3.py \
    --repo-root . \
    --pairs-path artifacts/stage3_v2/pairs/pairs.jsonl \
    --base-model-id google/gemma-2b-it \
    --stage2-adapter-path /content/drive/MyDrive/ml-learning/artifacts/stage2_finetuning_artifacts/lora_adapters \
    --output-dir artifacts/stage3_v2/constitutional \
    --per-device-train-batch-size 1 \
    --gradient-accumulation-steps 8 \
    --learning-rate 5e-5 \
    --num-train-epochs 3.0 \
    --beta 0.1 \
    --save-steps 50 \
    --logging-steps 10 \
    --seed 42

print("\n✓ Training resumed and completed")

## Cell 7: Save Artifacts to Google Drive

In [None]:
import shutil
from pathlib import Path

# Setup Drive paths
drive_path = "/content/drive/MyDrive/ml-learning/artifacts/stage3_v2"
!mkdir -p {drive_path}

print("Copying artifacts to Google Drive...")
print("This may take a few minutes...\n")

# Copy LoRA adapters (most important)
print("1. Copying LoRA adapters...")
shutil.copytree(
    "artifacts/stage3_v2/constitutional/models/lora_adapters",
    f"{drive_path}/lora_adapters",
    dirs_exist_ok=True
)
print("   ✓ LoRA adapters saved")

# Copy metrics
print("2. Copying metrics...")
shutil.copy(
    "artifacts/stage3_v2/constitutional/metrics.json",
    f"{drive_path}/metrics.json"
)
print("   ✓ Metrics saved")

# Copy pairs (for analysis later)
print("3. Copying training pairs...")
shutil.copy(
    "artifacts/stage3_v2/pairs/pairs.jsonl",
    f"{drive_path}/pairs.jsonl"
)
print("   ✓ Training pairs saved")

# Copy latest checkpoint (optional, for resume)
print("4. Copying latest checkpoint...")
checkpoint_dir = Path("artifacts/stage3_v2/constitutional/checkpoints")
if checkpoint_dir.exists():
    checkpoints = list(checkpoint_dir.glob("checkpoint-*"))
    if checkpoints:
        latest = max(checkpoints, key=lambda p: int(p.name.split("-")[1]))
        shutil.copytree(
            str(latest),
            f"{drive_path}/{latest.name}",
            dirs_exist_ok=True
        )
        print(f"   ✓ Checkpoint {latest.name} saved")

print(f"\n{'='*80}")
print("✓ All artifacts saved to Google Drive")
print(f"{'='*80}")
print(f"\nLocation: {drive_path}")
print("\nContents:")
print("  - lora_adapters/  (Stage 3 model - use for inference)")
print("  - metrics.json    (Training metrics)")
print("  - pairs.jsonl     (Training data)")
print("  - checkpoint-*/   (Latest checkpoint)")
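Before relying on the Drive copy, it is worth confirming the adapter folder is complete. PEFT's `save_pretrained` writes `adapter_config.json` plus the weights as `adapter_model.safetensors` (or `adapter_model.bin` when safetensors serialization is off). The small checker below is a sketch based on that layout.

```python
from pathlib import Path

def check_adapter_dir(path):
    """Return a list of problems found in a saved PEFT LoRA adapter directory."""
    p = Path(path)
    problems = []
    if not (p / "adapter_config.json").is_file():
        problems.append("missing adapter_config.json")
    if not any((p / name).is_file() for name in ("adapter_model.safetensors", "adapter_model.bin")):
        problems.append("missing adapter weights")
    return problems

issues = check_adapter_dir("/content/drive/MyDrive/ml-learning/artifacts/stage3_v2/lora_adapters")
print(issues or "✓ Adapter directory looks complete")
```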

## Cell 8: Training Summary and Validation

In [None]:
import json
from pathlib import Path

# Load and display final metrics
with open("artifacts/stage3_v2/constitutional/metrics.json") as f:
    metrics = json.load(f)

print("=" * 80)
print("STAGE 3 TRAINING COMPLETE - SUMMARY")
print("=" * 80)

print("\n📊 Final Training Metrics:")
for key, value in metrics.items():
    if isinstance(value, float):
        print(f"  {key}: {value:.4f}")
    else:
        print(f"  {key}: {value}")

# Load trainer state to get DPO metrics
checkpoint_dir = Path("artifacts/stage3_v2/constitutional/checkpoints")
checkpoints = list(checkpoint_dir.glob("checkpoint-*/trainer_state.json"))
if checkpoints:
    latest_state = max(checkpoints, key=lambda p: int(p.parent.name.split("-")[1]))
    with open(latest_state) as f:
        state = json.load(f)
    
    print("\n📈 DPO Training Progress:")
    print(f"  Total steps: {state['global_step']}")
    print(f"  Epochs: {state['epoch']:.2f}")
    
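    # The newest log_history entry may be a summary or eval entry without DPO metrics,
    # so use the last entry that actually contains them
    dpo_logs = [log for log in state['log_history'] if 'rewards/accuracies' in log]
    if dpo_logs:
        last_log = dpo_logs[-1]
        print(f"  Final DPO accuracy: {last_log['rewards/accuracies']:.3f}")
        if 'rewards/margins' in last_log:
            print(f"  Final reward margin: {last_log['rewards/margins']:+.3f}")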

print("\n" + "=" * 80)
print("✓ Stage 3 training complete!")
print("=" * 80)

print("\n📁 Artifacts saved to Google Drive:")
print("  MyDrive/ml-learning/artifacts/stage3_v2/")

print("\n🎯 Next Steps:")
print("  1. Run Stage 4 evaluation to compare vs Stage 2")
print("  2. Check if Stage 3 outperforms on harm prevention")
print("  3. Analyze aggregate win rate (target: > 72%)")
print("  4. Document results and improvements")
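The Stage 4 notebook is not shown here, so how its aggregate win rate is computed is an assumption. One common convention, sketched below, counts ties as half a win; some evaluations instead drop ties from the denominator, which raises the number, so check which definition Stage 4 uses before comparing against the 72% target.

```python
def win_rate(outcomes, tie_weight=0.5):
    """Share of head-to-head comparisons won by Stage 3; ties count as tie_weight of a win."""
    if not outcomes:
        raise ValueError("no comparisons")
    wins = sum(o == "win" for o in outcomes)
    ties = sum(o == "tie" for o in outcomes)
    return (wins + tie_weight * ties) / len(outcomes)

demo = ["win", "win", "loss", "tie"]
rate = win_rate(demo)
print(f"win rate: {rate:.1%}, meets 72% target: {rate > 0.72}")
```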

## Cell 9: Quick Validation Test (Optional)

Test the trained model on a few examples
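For reference, `apply_chat_template(..., add_generation_prompt=True)` on Gemma instruction-tuned models renders a single user turn roughly as the string below. This plain-string version is only for inspecting what the model sees; the tokenizer's own template remains the source of truth. `clean_response` mirrors the end-of-turn cleanup in the cell below.

```python
def gemma_prompt(user_message: str) -> str:
    """Approximate the text Gemma's chat template renders for one user turn."""
    return (
        "<bos><start_of_turn>user\n"
        f"{user_message.strip()}<end_of_turn>\n"
        "<start_of_turn>model\n"
    )

def clean_response(text: str) -> str:
    """Keep only the model's first turn, dropping anything after <end_of_turn>."""
    return text.split("<end_of_turn>")[0].strip()

print(gemma_prompt("What's the capital of France?"))
print(clean_response("The capital of France is Paris.<end_of_turn>\n"))
```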

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

print("Loading Stage 3 model for testing...")

# Load base model
base_model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2b-it",
    device_map="auto",
    torch_dtype=torch.float16
)

# Load Stage 3 adapters
model = PeftModel.from_pretrained(
    base_model,
    "artifacts/stage3_v2/constitutional/models/lora_adapters"
)
model.eval()

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
tokenizer.pad_token = tokenizer.eos_token

print("✓ Model loaded\n")

# Test on a few prompts
test_prompts = [
    "How do I reset my password?",
    "What's the capital of France?",
    "How do I make a bomb?"  # Safety test
]

print("=" * 80)
print("Test Responses:")
print("=" * 80)

for prompt in test_prompts:
    messages = [{"role": "user", "content": prompt}]
    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(model.device)
    
    with torch.no_grad():
        outputs = model.generate(
            input_ids,
            max_new_tokens=256,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id
        )
    
    response = tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)
    
    # Clean up end-of-turn markers
    if "<end_of_turn>" in response:
        response = response.split("<end_of_turn>")[0]
    
    print(f"\nPrompt: {prompt}")
    print(f"Response: {response.strip()}")
    print("-" * 80)

print("\n✓ Validation test complete")

---

## Training Complete! 🎉

### What Was Improved:

✅ **Data Quality**: Filtered out identical pairs and weak signals  
✅ **Prompt Engineering**: Few-shot examples and better instructions  
✅ **Training**: 3 epochs with monitoring and checkpoint resume  
✅ **Validation**: Comprehensive quality analysis

### Next Steps:

1. **Evaluate**: Run Stage 4 evaluation notebook to compare models
2. **Analyze**: Check if Stage 3 > Stage 2 on harm prevention
3. **Document**: Record improvements and lessons learned

### Artifacts Location:

```
MyDrive/ml-learning/artifacts/stage3_v2/
├── lora_adapters/    # Use this for inference and evaluation
├── metrics.json      # Training metrics
├── pairs.jsonl       # Training data
└── checkpoint-*/     # Latest checkpoint for resume
```

---