# Stage 3: Critique–Revision DPO Orchestrator (Colab)

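Lightweight notebook that:
- Clones the repo
- Creates a uv venv and installs dependencies
- Logs in to Hugging Face (required for the gated Gemma weights)
- Runs a small preflight (2 pairs, 1 DPO step) to catch errors cheaply
- Generates critique/revision pairs (default config)
- Trains DPO, initializing from the Stage 2 LoRA adapters
- Runs a quick evaluation and, if a Stage 1 checkpoint is available, attempts safety scoring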

Defaults:
- Base model: google/gemma-2b-it
- Outputs are saved locally only (nothing is pushed to the Hub)
- W&B disabled
- Dataset subset: default (requires no special access permissions)


In [None]:
# Runtime checks and config
import os, sys, json, random
print('Python:', sys.version)

# Colab GPU info (may be absent): on CPU runtimes nvidia-smi is missing and
# subprocess.run raises, which is reported rather than treated as fatal.
try:
    import subprocess
    smi_result = subprocess.run(['nvidia-smi'], check=False, capture_output=True, text=True)
    print(smi_result.stdout[:1000])
except Exception as e:
    print('No GPU info available:', e)

# Repository URL and directory layout. Every artifact path is derived from
# REPO_DIR, so relocating the checkout only requires changing it here.
REPO_URL = 'https://github.com/Jai-Dhiman/ml-learning.git'
REPO_DIR = '/content/ml-learning'
BASE_MODEL_ID = 'google/gemma-2b-it'
ARTIFACTS_DIR = f'{REPO_DIR}/artifacts'
STAGE2_ADAPTER_DIR = f'{ARTIFACTS_DIR}/stage2_artifacts/lora_adapters'
STAGE3_DIR = f'{ARTIFACTS_DIR}/stage3_artifacts'
PAIRS_PATH = f'{STAGE3_DIR}/pairs/pairs.jsonl'
STAGE3_LORA_DIR = f'{STAGE3_DIR}/models/lora_adapters'

# Mirror the paths into the process environment: %%bash cells inherit this
# kernel's environment and reference them as $NAME.
os.environ.update({
    'REPO_DIR': REPO_DIR,
    'REPO_URL': REPO_URL,
    'BASE_MODEL_ID': BASE_MODEL_ID,
    'ARTIFACTS_DIR': ARTIFACTS_DIR,
    'STAGE2_ADAPTER_DIR': STAGE2_ADAPTER_DIR,
    'STAGE3_DIR': STAGE3_DIR,
    'PAIRS_PATH': PAIRS_PATH,
    'STAGE3_LORA_DIR': STAGE3_LORA_DIR,
})

# Env flags: disable W&B logging, advisory warnings and Hub telemetry.
os.environ.update(
    WANDB_DISABLED='true',
    TRANSFORMERS_NO_ADVISORY_WARNINGS='1',
    HF_HUB_DISABLE_TELEMETRY='1',
)
print('Configured environment flags for quiet, offline-friendly operation.')
print(f'REPO_DIR={REPO_DIR}')


In [None]:
# Install uv
!pip -q install -U uv
import shutil, os
print('uv version:', shutil.which('uv'))


In [None]:
%%bash
# Clone the repository (top of notebook)
# %%bash must be the first line of the cell, otherwise IPython treats it as an
# unknown line magic.
# NOTE: this deletes any existing checkout, including artifacts produced by
# earlier runs under $REPO_DIR/artifacts.
set -e
# Prefer the paths exported by the config cell; fall back to the Colab defaults.
REPO_DIR="${REPO_DIR:-/content/ml-learning}"
REPO_URL="${REPO_URL:-https://github.com/Jai-Dhiman/ml-learning.git}"
rm -rf "$REPO_DIR"
git clone "$REPO_URL" "$REPO_DIR"
ls -la "$REPO_DIR"


In [None]:
# Login to Hugging Face (required for Gemma model access)
# Secure login without storing/printing your token.
# Token will be set as HF_TOKEN environment variable for bash cells.
import os

# Clear any existing tokens so a stale value cannot shadow the new login.
os.environ.pop("HF_TOKEN", None)
os.environ.pop("HUGGINGFACEHUB_API_TOKEN", None)

from huggingface_hub import login, HfApi


def _saved_hf_token():
    """Return the token persisted by ``login()`` (None if nobody is logged in)."""
    try:
        from huggingface_hub import get_token  # current huggingface_hub API
    except ImportError:
        from huggingface_hub import HfFolder  # older releases only
        return HfFolder.get_token()
    return get_token()


token = None
try:
    import getpass as gp
    raw = gp.getpass("Paste your Hugging Face token (input hidden): ")
    token = raw.decode() if isinstance(raw, (bytes, bytearray)) else raw
    if not isinstance(token, str):
        raise TypeError(f"Unexpected token type: {type(token).__name__}")
    token = token.strip()
    if not token:
        raise ValueError("Empty token provided")
    login(token=token, add_to_git_credential=False)
except Exception as e:
    print(f"[HF Login] getpass flow failed: {e}")
    print("Falling back to interactive login widget...")
    token = None
    login()
    # NOTE(review): in notebooks login() may return before the widget is
    # submitted; if HF_TOKEN ends up unset, re-run this cell after logging in.
    try:
        token = _saved_hf_token()
    except Exception as e2:
        print(f"[HF Login] Could not read saved token: {e2}")
    if not token:
        print("[HF Login] Could not set HF_TOKEN env var.")
        print("You may need to run 'huggingface-cli login' in a bash cell.")

if token:
    os.environ['HF_TOKEN'] = token
    print('HF_TOKEN environment variable set for bash cells.')

# whoami is informational only: a failure here (e.g. a network hiccup) must not
# re-trigger the login flow, which the original single try/except did.
try:
    who = HfApi().whoami(token=token)
    print(f"Logged in as: {who.get('name') or who.get('email') or 'OK'}")
except Exception as e:
    print(f"[HF Login] whoami check failed: {e}")


In [None]:
%%bash
# Create uv venv and install dependencies
# (%%bash must be the first line of the cell for IPython to run it as bash.)
set -e
cd "${REPO_DIR:-/content/ml-learning}"

# Create venv if it doesn't exist
if [ ! -d .venv ]; then
  echo 'Creating uv virtual environment...'
  uv venv
fi

# Install straight into .venv (--python targets it, so no activation is needed)
echo 'Installing dependencies...'

# Try the CUDA 11.8 PyTorch wheel index first; if that install fails, fall back
# to the default index. `set -e` does not abort on the left-hand side of `||`.
uv pip install --python .venv/bin/python torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 || \
  uv pip install --python .venv/bin/python torch torchvision torchaudio

# Install remaining dependencies (JAX/Flax/Optax on CPU are for the Stage 1 safety classifier)
uv pip install --python .venv/bin/python \
  "transformers>=4.43.0" \
  "trl>=0.9.6" \
  "peft>=0.13.0" \
  "datasets>=2.19.0" \
  "accelerate>=0.28.0" \
  sentencepiece \
  safetensors \
  einops \
  evaluate \
  "protobuf<5" \
  "jax[cpu]==0.4.38" \
  "flax>=0.8.4,<0.9.0" \
  "optax>=0.2.2,<0.3.0"

# Verify installation
echo 'Verifying torch installation...'
.venv/bin/python -c "import torch; print(f'PyTorch {torch.__version__}'); print(f'CUDA available: {torch.cuda.is_available()}')"
echo 'Dependencies installed successfully!'


In [None]:
# Validate Stage 2 adapters exist (fail fast)
from pathlib import Path
stage2 = Path(STAGE2_ADAPTER_DIR)
# is_dir() (rather than exists()) also protects iterdir(), which would raise a
# bare NotADirectoryError if the path were a regular file.
if not stage2.is_dir() or not any(stage2.iterdir()):
    raise FileNotFoundError(
        f'Stage 2 LoRA adapters not found at {stage2}. Please train Stage 2 first or place adapters there.'
    )
print('Stage 2 adapters found at:', STAGE2_ADAPTER_DIR)


In [None]:
# Optional: Set Stage 1 checkpoint env var for safety scoring if present.
# Candidates are checked in priority order; the first existing directory wins.
import os
stage1_candidates = (
    f'{REPO_DIR}/artifacts/stage1_artifacts/checkpoints/best_model',
    f'{REPO_DIR}/safety-text-classifier/checkpoints/best_model',
)
stage1_found = next((cand for cand in stage1_candidates if os.path.isdir(cand)), None)
if stage1_found is not None:
    os.environ['STAGE1_CKPT_DIR'] = stage1_found
print('STAGE1_CKPT_DIR =', os.environ.get('STAGE1_CKPT_DIR'))


In [None]:
%%bash
# Preflight: generate 2 pairs with quality validation
# (%%bash must be the first line of the cell for IPython to run it as bash.)
set -e
# Paths come from the config cell's exported env vars, with Colab defaults.
REPO_DIR="${REPO_DIR:-/content/ml-learning}"
STAGE2_ADAPTER_DIR="${STAGE2_ADAPTER_DIR:-$REPO_DIR/artifacts/stage2_artifacts/lora_adapters}"
STAGE3_DIR="${STAGE3_DIR:-$REPO_DIR/artifacts/stage3_artifacts}"
PREFLIGHT_PAIRS="$STAGE3_DIR/preflight/pairs_2.jsonl"
cd "$REPO_DIR"
mkdir -p "$STAGE3_DIR/preflight"
echo '=== Generating 2 preflight pairs for quality check ==='
uv run python "$REPO_DIR/critique-revision-system/src/critique_revision.py" \
  --adapter-path "$STAGE2_ADAPTER_DIR" \
  --dataset-subset default \
  --split "test[:2]" \
  --num-examples 2 \
  --max-new-tokens 256 \
  --temperature 0.5 \
  --output "$PREFLIGHT_PAIRS"
echo "Preflight pairs at: $PREFLIGHT_PAIRS"
ls -lh "$PREFLIGHT_PAIRS"


In [None]:
# Preflight: Quality validation of generated pairs
import json
from collections import Counter
from pathlib import Path

# Prompt-like keys, tried in order (same precedence as the other cells).
PROMPT_KEYS = ('prompt', 'instruction', 'input', 'user_prompt')
# Phrases suggesting the "revision" talks about the draft instead of replacing it.
META_INDICATORS = ('draft answer', 'original answer', 'Assistant answers', 'Assistant is')


def _as_text(value):
    """Return *value* as a string; missing/null JSON fields become ''."""
    if value is None:
        return ''
    return value if isinstance(value, str) else str(value)


def _as_score(value):
    """Return *value* as a float, or NaN when null/non-numeric (keeps ':.2f' safe)."""
    try:
        return float(value)
    except (TypeError, ValueError):
        return float('nan')


def _pair_quality_issues(base_resp, revised_resp, critic_notes):
    """Return the list of quality problems found in one critique/revision pair."""
    issues = []

    # Check 1: Multi-turn hallucination (unwanted Human:/Assistant: in responses)
    if 'Human:' in base_resp or 'Assistant:' in base_resp:
        issues.append('Base response contains multi-turn dialogue (hallucinated)')
    if 'Human:' in revised_resp or 'Assistant:' in revised_resp:
        issues.append('Revised response contains multi-turn dialogue (hallucinated)')

    # Check 2: Empty or very short responses
    base_words = len(base_resp.split())
    revised_words = len(revised_resp.split())
    if base_words < 10:
        issues.append(f'Base response too short ({base_words} words)')
    if revised_words < 10:
        issues.append(f'Revised response too short ({revised_words} words)')

    # Check 3: Critique format validation
    if len(critic_notes.strip()) < 10:
        issues.append('Critique is empty or too short')

    # Check 4: Meta-commentary instead of actual revision (only the opening is inspected)
    if any(indicator in revised_resp[:100] for indicator in META_INDICATORS):
        issues.append('Revised response contains meta-commentary instead of clean answer')

    # Check 5: Identical base and revised responses
    if base_resp.strip() == revised_resp.strip():
        issues.append('Base and revised responses are identical')

    return issues


print('=' * 60)
print('PREFLIGHT QUALITY CHECK')
print('=' * 60)

pairs_file = Path(f'{STAGE3_DIR}/preflight/pairs_2.jsonl')
if not pairs_file.exists():
    raise FileNotFoundError(f'Pairs file not found: {pairs_file}')

with open(pairs_file, 'r', encoding='utf-8') as f:
    pairs = [json.loads(line) for line in f if line.strip()]

print(f'Loaded {len(pairs)} pairs for validation\n')

quality_issues = []

for i, pair in enumerate(pairs, 1):
    print(f'--- PAIR {i} ---')
    prompt_text = next((_as_text(pair[k]) for k in PROMPT_KEYS if pair.get(k)), 'N/A')
    print(f'Prompt: {prompt_text[:100]}...')

    # Null fields are coerced to '' so the string checks below cannot crash.
    base_resp = _as_text(pair.get('base_response'))
    revised_resp = _as_text(pair.get('revised_response'))
    critic_notes = _as_text(pair.get('critic_notes'))
    issues = _pair_quality_issues(base_resp, revised_resp, critic_notes)

    # Check 6: Score report (missing scores default to 0; null/invalid show as nan)
    base_score = _as_score(pair.get('base_score', 0))
    revised_score = _as_score(pair.get('revised_score', 0))
    chosen = pair.get('chosen', 'unknown')

    print(f'Base score: {base_score:.2f}')
    print(f'Revised score: {revised_score:.2f}')
    print(f'Chosen: {chosen}')
    print(f'Score delta: {revised_score - base_score:.2f}')

    if issues:
        quality_issues.extend(issues)
        print('\nQUALITY ISSUES DETECTED:')
        for issue in issues:
            print(f'  - {issue}')
    else:
        print('No quality issues detected')

    # Show sample outputs
    print(f'\nBase response preview: {base_resp[:200]}...')
    print(f'Revised response preview: {revised_resp[:200]}...')
    print(f'Critique preview: {critic_notes[:150]}...')
    print()

print('=' * 60)
print('QUALITY CHECK SUMMARY')
print('=' * 60)
print(f'Total pairs checked: {len(pairs)}')
print(f'Total quality issues: {len(quality_issues)}')

if quality_issues:
    print('\nISSUES FOUND:')
    # Counter tallies in one pass and keeps first-seen order (set + count was O(n^2)).
    for issue, count in Counter(quality_issues).items():
        print(f'  [{count}x] {issue}')
    print('\nWARNING: Quality issues detected in preflight pairs!')
    print('Consider:')
    print('  1. Adjusting generation parameters (temperature, max_tokens)')
    print('  2. Improving prompt engineering in critique_revision.py')
    print('  3. Using fewer examples for initial testing')
    print('\nProceeding with full run may produce low-quality training data.')
else:
    print('\nAll quality checks passed! Safe to proceed with full generation.')


In [None]:
%%bash
# Preflight: 1-step DPO training to catch training-time errors with minimal cost
# (%%bash must be the first line of the cell for IPython to run it as bash.)
set -e
# Paths come from the config cell's exported env vars, with Colab defaults.
REPO_DIR="${REPO_DIR:-/content/ml-learning}"
BASE_MODEL_ID="${BASE_MODEL_ID:-google/gemma-2b-it}"
STAGE2_ADAPTER_DIR="${STAGE2_ADAPTER_DIR:-$REPO_DIR/artifacts/stage2_artifacts/lora_adapters}"
STAGE3_DIR="${STAGE3_DIR:-$REPO_DIR/artifacts/stage3_artifacts}"
cd "$REPO_DIR"
mkdir -p "$STAGE3_DIR/preflight_run"
echo '=== Starting preflight DPO training (1 step) ==='
uv run python "$REPO_DIR/critique-revision-system/src/training/train_dpo_stage3.py" \
  --repo-root "$REPO_DIR" \
  --pairs-path "$STAGE3_DIR/preflight/pairs_2.jsonl" \
  --base-model-id "$BASE_MODEL_ID" \
  --stage2-adapter-path "$STAGE2_ADAPTER_DIR" \
  --output-dir "$STAGE3_DIR/preflight_run" \
  --per-device-train-batch-size 1 \
  --gradient-accumulation-steps 1 \
  --learning-rate 5e-5 \
  --num-train-epochs 1 \
  --max-steps 1 \
  --beta 0.1 \
  --cpu-ref-model
echo '=== Preflight DPO training completed ==='


In [None]:
# Preflight: sanity checks and single quick generation
import gc, json, os, sys
from pathlib import Path
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

pre_pairs = f"{STAGE3_DIR}/preflight/pairs_2.jsonl"
pre_lora = f"{STAGE3_DIR}/preflight_run/models/lora_adapters"
# Explicit raises instead of assert: asserts vanish under `python -O`, and
# FileNotFoundError matches the Stage 2 validation cell.
if not Path(pre_pairs).exists():
    raise FileNotFoundError(f"Missing preflight pairs: {pre_pairs}")
if not Path(pre_lora).exists():
    raise FileNotFoundError(f"Missing preflight adapters: {pre_lora}")

with open(pre_pairs, 'r', encoding='utf-8') as f:
    # Use the first non-blank line; json.loads would fail on a blank one.
    first_line = next((raw_line for raw_line in f if raw_line.strip()), None)
if first_line is None:
    raise ValueError(f"Preflight pairs file has no records: {pre_pairs}")
ex = json.loads(first_line)
prompt = ex.get('prompt') or ex.get('instruction') or ex.get('input') or ex.get('user_prompt')
if not prompt:
    raise KeyError(f"No prompt in preflight pair keys={list(ex.keys())}")
# Same plain "User:/Assistant:" framing as the evaluation cell.
prompt = f"User: {prompt}\nAssistant:"

tok = AutoTokenizer.from_pretrained(BASE_MODEL_ID, use_fast=True)
if tok.pad_token_id is None:
    tok.pad_token = tok.eos_token
base = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map='auto' if torch.cuda.is_available() else None,
)
model = PeftModel.from_pretrained(base, pre_lora)
inputs = tok(prompt, return_tensors='pt').to(model.device)
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=64, do_sample=True, top_p=0.9, temperature=0.7)
# For a causal LM, generate() returns prompt + new tokens, so `text` includes the prompt.
text = tok.decode(out[0], skip_special_tokens=True)
print('Preflight generation ok. Token count:', out.shape[-1])

# Tiny safety scoring if Stage 1 checkpoint is available
try:
    helpers_dir = f'{REPO_DIR}/helpful-finetuning/src'
    if helpers_dir not in sys.path:  # avoid stacking duplicate entries on re-runs
        sys.path.insert(0, helpers_dir)
    from utils.safety_integration import SafetyFilter
    ckpt = os.environ.get('STAGE1_CKPT_DIR')
    if not ckpt:
        a = f'{REPO_DIR}/artifacts/stage1_artifacts/checkpoints/best_model'
        b = f'{REPO_DIR}/safety-text-classifier/checkpoints/best_model'
        if os.path.isdir(a):
            ckpt = a
        elif os.path.isdir(b):
            ckpt = b
    if ckpt and os.path.isdir(ckpt):
        cfg_path = f'{REPO_DIR}/safety-text-classifier/configs/base_config.yaml'
        sf = SafetyFilter(cfg_path, ckpt)
        score = float(sf.score_text(text))
        print('Preflight safety score:', score)
    else:
        print('Preflight: Stage 1 checkpoint not found; skipping safety scoring.')
except Exception as e:
    print('Preflight safety scoring failed:', e)

# Release the preflight model: otherwise this kernel keeps it resident on the GPU
# while the later %%bash cells launch DPO training in a separate process.
del model, base, inputs, out
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()


In [None]:
%%bash
# Generate critique/revision pairs (full run)
# (%%bash must be the first line of the cell for IPython to run it as bash.)
set -e
# Write to $PAIRS_PATH, the same file the DPO training cell reads, so the two
# cells cannot drift apart. Colab defaults apply if the vars are unset.
REPO_DIR="${REPO_DIR:-/content/ml-learning}"
STAGE2_ADAPTER_DIR="${STAGE2_ADAPTER_DIR:-$REPO_DIR/artifacts/stage2_artifacts/lora_adapters}"
PAIRS_PATH="${PAIRS_PATH:-$REPO_DIR/artifacts/stage3_artifacts/pairs/pairs.jsonl}"
cd "$REPO_DIR"
mkdir -p "$(dirname "$PAIRS_PATH")"
echo '=== Generating 400 critique/revision pairs ==='
# Use default subset and a modest slice for a first run
uv run python "$REPO_DIR/critique-revision-system/src/critique_revision.py" \
  --adapter-path "$STAGE2_ADAPTER_DIR" \
  --dataset-subset default \
  --split "test[:400]" \
  --num-examples 400 \
  --output "$PAIRS_PATH"
echo "Pairs generated at: $PAIRS_PATH"
ls -lh "$PAIRS_PATH"


In [None]:
%%bash
# Train DPO starting from Stage 2 adapters
# (%%bash must be the first line of the cell for IPython to run it as bash.)
# Effective batch size per device: 1 x 8 gradient-accumulation steps = 8.
set -e
cd "$REPO_DIR"
mkdir -p "$STAGE3_DIR"
uv run python "$REPO_DIR/critique-revision-system/src/training/train_dpo_stage3.py" \
  --repo-root "$REPO_DIR" \
  --pairs-path "$PAIRS_PATH" \
  --base-model-id "$BASE_MODEL_ID" \
  --stage2-adapter-path "$STAGE2_ADAPTER_DIR" \
  --output-dir "$STAGE3_DIR" \
  --per-device-train-batch-size 1 \
  --gradient-accumulation-steps 8 \
  --learning-rate 5e-5 \
  --num-train-epochs 1 \
  --beta 0.1 \
  --cpu-ref-model


In [None]:
# Quick evaluation: load Stage 3 adapters and generate a few outputs
import json, os, random
from pathlib import Path
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

# Skip blank lines (e.g. a trailing newline), as the preflight quality check
# does; json.loads('') would otherwise raise.
with open(PAIRS_PATH, 'r', encoding='utf-8') as f:
    rows = [json.loads(line) for line in f if line.strip()]
random.seed(42)
# Evaluate on up to 12 pairs; sample reproducibly (seeded) when more exist.
sample = rows[:12] if len(rows) <= 12 else random.sample(rows, 12)

def extract_prompt(ex):
    """Wrap the first non-empty prompt-like field of *ex* in the User/Assistant template.

    Fields are tried in the order 'prompt', 'instruction', 'input', 'user_prompt';
    empty or falsy values are skipped.

    Raises:
        KeyError: if none of those fields holds a truthy value.
    """
    prompt_fields = ('prompt', 'instruction', 'input', 'user_prompt')
    chosen = next((ex[name] for name in prompt_fields if name in ex and ex[name]), None)
    if chosen is None:
        raise KeyError('No prompt-like key found')
    return f"User: {chosen}\nAssistant:"

tok = AutoTokenizer.from_pretrained(BASE_MODEL_ID, use_fast=True)
if tok.pad_token_id is None:
    tok.pad_token = tok.eos_token

# fp16 + automatic device placement on GPU; full precision on CPU.
has_cuda = torch.cuda.is_available()
base = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID,
    torch_dtype=torch.float16 if has_cuda else torch.float32,
    device_map='auto' if has_cuda else None,
)
model = PeftModel.from_pretrained(base, STAGE3_LORA_DIR)
model.eval()
gen_conf = {'max_new_tokens': 256, 'do_sample': True, 'top_p': 0.9, 'temperature': 0.7}

# One JSON object per line: the templated prompt plus the decoded generation.
EVAL_OUT = f'{STAGE3_DIR}/eval/generated_eval.jsonl'
Path(f'{STAGE3_DIR}/eval').mkdir(parents=True, exist_ok=True)
with open(EVAL_OUT, 'w', encoding='utf-8') as out_f:
    for example in sample:
        templated = extract_prompt(example)
        encoded = tok(templated, return_tensors='pt').to(model.device)
        with torch.no_grad():
            generated_ids = model.generate(**encoded, **gen_conf)
        record = {'prompt': templated, 'generated': tok.decode(generated_ids[0], skip_special_tokens=True)}
        out_f.write(json.dumps(record, ensure_ascii=False) + '\n')
print('Wrote eval generations to:', EVAL_OUT)

print('Sample generations:')
with open(EVAL_OUT, 'r', encoding='utf-8') as f:
    # zip with range(3) stops after the first three records.
    for _, line in zip(range(3), f):
        print(json.loads(line)['generated'][:500])


In [None]:
# Attempt safety scoring via Stage 1 SafetyFilter (CPU JAX) if checkpoint is available
# Imports are explicit so the cell does not depend on earlier cells for `os`.
import json, os, sys

try:
    helpers_dir = f'{REPO_DIR}/helpful-finetuning/src'
    if helpers_dir not in sys.path:  # avoid stacking duplicate entries on re-runs
        sys.path.insert(0, helpers_dir)
    from utils.safety_integration import SafetyFilter
    ckpt = os.environ.get('STAGE1_CKPT_DIR')
    cfg_path = f'{REPO_DIR}/safety-text-classifier/configs/base_config.yaml'
    if ckpt and os.path.isdir(ckpt):
        sf = SafetyFilter(cfg_path, ckpt)
        # Score generated eval set
        scores = []
        failures = 0
        first_error = None
        with open(f'{STAGE3_DIR}/eval/generated_eval.jsonl', 'r', encoding='utf-8') as f:
            for line in f:
                if not line.strip():
                    continue
                txt = json.loads(line).get('generated', '')
                try:
                    scores.append(float(sf.score_text(txt)))
                except Exception as err:
                    # Best-effort per text: tally and keep going, but don't hide it.
                    failures += 1
                    if first_error is None:
                        first_error = err
        if scores:
            print('Safety scores count:', len(scores), 'avg:', sum(scores)/len(scores))
        else:
            print('No safety scores computed (empty or errors).')
        if failures:
            print(f'{failures} text(s) failed to score; first error: {first_error!r}')
    else:
        print('Stage 1 checkpoint not found; skipping safety scoring.')
except Exception as e:
    print('Safety scoring unavailable or failed:', e)


In [None]:
# Artifact summary: report which Stage 3 outputs exist on disk.
from pathlib import Path
import json
print('Pairs:', PAIRS_PATH, 'exists:', Path(PAIRS_PATH).exists())
print('DPO adapters:', STAGE3_LORA_DIR, 'exists:', Path(STAGE3_LORA_DIR).exists())
metrics_path = Path(f'{STAGE3_DIR}/metrics.json')
if metrics_path.exists():
    print('Metrics:')
    # read_text() closes the file immediately, unlike json.load(open(...)).
    print(json.dumps(json.loads(metrics_path.read_text(encoding='utf-8')), indent=2))
else:
    print('Metrics file not found.')
eval_path = Path(f'{STAGE3_DIR}/eval/generated_eval.jsonl')
print('Eval generations:', str(eval_path), 'exists:', eval_path.exists())
