# Stage 3: Constitutional AI Training

Generate critique–revision preference pairs with the Stage 2 model, then fine-tune on those pairs with DPO (Direct Preference Optimization).

**Requires**: Stage 2 LoRA adapters saved to Google Drive by the Stage 2 notebook (Cell 3 copies them in; if they are missing, re-run Stage 2 first)

**Expected time**: ~4-5 hours on T4 GPU

In [None]:
# Cell 1: Setup
from google.colab import drive
import os

drive.mount('/content/drive')
os.chdir('/content')

!git clone https://github.com/Jai-Dhiman/ml-learning.git
os.chdir('/content/ml-learning/critique-revision-system')

print(f"✅ Ready in: {os.getcwd()}")

In [None]:
# Cell 2: Install Dependencies
!pip install -q transformers==4.36.2 datasets==2.16.1 peft==0.7.1 trl==0.7.10
!pip install -q accelerate==0.25.0 bitsandbytes==0.41.3 torch==2.1.2 pyyaml

print("✅ Dependencies installed")

In [None]:
# Cell 3: Load Stage 2 Adapters from Drive
!mkdir -p artifacts/stage2_helpful
!cp -r /content/drive/MyDrive/ml-learning/artifacts/stage2_helpful/* artifacts/stage2_helpful/

# Verify
import os
adapter_path = 'artifacts/stage2_helpful/lora_adapters'
if os.path.exists(adapter_path):
    print(f"✅ Stage 2 adapters loaded from Drive")
else:
    print("❌ Stage 2 adapters not found! Run Stage 2 notebook first.")
    raise FileNotFoundError(f"Missing: {adapter_path}")

In [None]:
# Cell 4: PREFLIGHT TEST (1 pair)
import os
os.environ['WANDB_DISABLED'] = 'true'

print("="*70)
print("PREFLIGHT: Testing with 1 sample")
print("="*70)

# Test critique-revision
print("\n[1/2] Testing critique-revision generation...")
!python src/training/critique_revision.py \
  --num-examples 1 \
  --split 'test[:1]' \
  --output /tmp/preflight_pairs.jsonl \
  --adapter-path artifacts/stage2_helpful/lora_adapters \
  --seed 42

# Test DPO
print("\n[2/2] Testing DPO training...")
!python src/training/train_dpo.py \
  --pairs-path /tmp/preflight_pairs.jsonl \
  --base-model-id google/gemma-2b-it \
  --stage2-adapter-path artifacts/stage2_helpful/lora_adapters \
  --output-dir /tmp/preflight_stage3 \
  --per-device-train-batch-size 1 \
  --gradient-accumulation-steps 1 \
  --learning-rate 5e-5 \
  --num-train-epochs 0.01 \
  --beta 0.3 \
  --seed 42 \
  --save-steps 1000 \
  --logging-steps 1

print("\n" + "="*70)
print("✅ PREFLIGHT PASSED - Ready for full training")
print("="*70)

In [None]:
# Cell 5: Generate Critique-Revision Pairs (2500 pairs)
!mkdir -p artifacts/stage3_pairs

print("="*70)
print("STAGE 3 - PART 1: CRITIQUE-REVISION GENERATION")
print("="*70)

!python src/training/critique_revision.py \
  --num-examples 2500 \
  --split 'test[:1000]+train[:1500]' \
  --output artifacts/stage3_pairs/pairs.jsonl \
  --adapter-path artifacts/stage2_helpful/lora_adapters \
  --seed 42

print("\n✅ Pairs generated")

In [None]:
# Cell 6: Validate Pairs Quality
# Heuristic sanity check on the generated pairs before spending hours on DPO.
# A revision counts as malformed when it is very short, or when its opening
# contains phrases that look like leftover critique text rather than a
# rewritten answer.
import json
from collections import Counter

PAIRS_PATH = 'artifacts/stage3_pairs/pairs.jsonl'
MIN_REVISED_CHARS = 30        # revisions shorter than this are malformed
CRITIQUE_PREFIX_CHARS = 100   # how much of the (lowercased) revision to scan
CRITIQUE_INDICATORS = ('accurate, but', 'could be improved')
MAX_MALFORMED_PCT = 5         # quality gate, in percent
TOP_K_PRINCIPLES = 5

with open(PAIRS_PATH, encoding='utf-8') as f:
    # Skip blank lines so a stray empty line doesn't abort with JSONDecodeError.
    pairs = [json.loads(line) for line in f if line.strip()]

print(f"Total pairs: {len(pairs)}")
if not pairs:
    # Without this guard the rate below divides by zero.
    raise ValueError(f"No pairs found in {PAIRS_PATH}; re-run Cell 5")


def is_malformed(pair: dict) -> bool:
    """Return True if the pair's revised_response looks unusable for DPO."""
    # `or ''` also covers an explicit null, which .get(key, '') would return as None.
    revised = pair.get('revised_response') or ''
    head = revised.lower()[:CRITIQUE_PREFIX_CHARS]
    return len(revised) < MIN_REVISED_CHARS or any(ind in head for ind in CRITIQUE_INDICATORS)


malformed = sum(is_malformed(p) for p in pairs)
malformed_rate = malformed / len(pairs) * 100
print(f"\nQuality: {len(pairs)-malformed} valid ({100-malformed_rate:.1f}%)")
print(f"Malformed: {malformed} ({malformed_rate:.1f}%)")

if malformed_rate < MAX_MALFORMED_PCT:
    print("✅ Quality GOOD")
else:
    print("⚠️  Quality needs improvement")

# Top principles: how often each principle id appears across all pairs.
principle_counts = Counter(
    pid for p in pairs for pid in (p.get('principle_ids') or [])
)

print("\nTop Principles:")
for pid, count in principle_counts.most_common(TOP_K_PRINCIPLES):
    print(f"  {pid}: {count}")

In [None]:
# Cell 7: DPO Training
!mkdir -p artifacts/stage3_constitutional

print("="*70)
print("STAGE 3 - PART 2: DPO TRAINING")
print("="*70)

!python src/training/train_dpo.py \
  --pairs-path artifacts/stage3_pairs/pairs.jsonl \
  --base-model-id google/gemma-2b-it \
  --stage2-adapter-path artifacts/stage2_helpful/lora_adapters \
  --output-dir artifacts/stage3_constitutional \
  --per-device-train-batch-size 1 \
  --gradient-accumulation-steps 8 \
  --learning-rate 5e-5 \
  --num-train-epochs 2.0 \
  --beta 0.3 \
  --seed 42 \
  --save-steps 200 \
  --logging-steps 10

print("\n✅ Stage 3 training complete!")

In [None]:
# Cell 8: Save to Drive
!mkdir -p /content/drive/MyDrive/ml-learning/artifacts/stage3_constitutional
!cp -r artifacts/stage3_constitutional/* /content/drive/MyDrive/ml-learning/artifacts/stage3_constitutional/

print("✅ Saved to Google Drive")
print("Location: MyDrive/ml-learning/artifacts/stage3_constitutional/")
print("\n" + "="*70)
print("✅ STAGE 3 COMPLETE")
print("="*70)