# Human Review: Alignment-Related vs Not (300 random features)

This tool samples 300 feature explanations from the base SAE and lets you label each as related or not related to the topic selected by `LABEL_MODE` (e.g. alignment or formatting).

## Workflow:
1. **Human Labeling**: Features are sampled from the base SAE file (`12-gemmascope-res-65k__l0-21.json`) - model-independent
2. **Only descriptions shown** to avoid bias from any model's classifications
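3. **Use buttons** to record your judgment; progress bar shows completion. Labels are only written to disk when you click **Save now**
4. **Labels saved** to `outputs/feature_classification/human_labels/12-gemmascope-res-65k__l0-21_human_labels_<LABEL_MODE>.json`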
5. **Analysis**: Compare human labels against any model's classifications

## Benefits:
- Human labels are **model-independent** - create once, test against multiple models
- Clean separation between ground truth (human) and model predictions
- Easy to compare different models against the same human baseline

In [7]:
import os, json, re, random, math
from pathlib import Path
import pandas as pd
from IPython.display import display, HTML
import ipywidgets as widgets

# --- CONFIGURABLE LABELING MODE ---
# Selects the question shown to the labeler and the suffix of the output file.
# 'alignment' and 'formatting' have dedicated question text; any other value
# is interpolated into a generic "related to <mode>?" question.
LABEL_MODE = 'formatting'

# Paths - Use base SAE JSON for human labeling (model-independent)
BASE_SAE_PATH = Path('../../models/NeuronpediaCache/gemma-2-2b/12-gemmascope-res-65k__l0-21.json')
# Raise explicitly rather than assert: asserts are stripped under `python -O`,
# and FileNotFoundError states the actual problem.
if not BASE_SAE_PATH.exists():
    raise FileNotFoundError(f'Base SAE file not found: {BASE_SAE_PATH}')

# Output path for human labels - save in outputs/feature_classification/human_labels/ folder (mode-dependent)
outputs_dir = Path('../../outputs/feature_classification/human_labels')
outputs_dir.mkdir(parents=True, exist_ok=True)  # Create outputs directory if it doesn't exist
out_name = BASE_SAE_PATH.stem + f'_human_labels_{LABEL_MODE}.json'
HUMAN_OUT_PATH = outputs_dir / out_name
print('Saving human labels to:', HUMAN_OUT_PATH)

# Load base SAE data (contains descriptions)
with BASE_SAE_PATH.open('r', encoding='utf-8') as f:
    sae_data = json.load(f)

# Normalize the payload into a list of per-feature records.
if not isinstance(sae_data, dict):
    # Anything that is not a dict is used as the record list as-is.
    sae_records = sae_data
else:
    # A dict may wrap the record list under one of a few well-known keys;
    # the first key that holds a list wins.
    sae_records = next(
        (
            sae_data[name]
            for name in ('features', 'items', 'data')
            if isinstance(sae_data.get(name), list)
        ),
        None,
    )
    if sae_records is None:
        # No wrapped list: treat the dict as {feature_id: record-or-description}.
        sae_records = [
            {'feature_id': fid, **(payload if isinstance(payload, dict) else {'description': payload})}
            for fid, payload in sae_data.items()
        ]

sae_df = pd.DataFrame(sae_records)

# Ensure we have description field
if 'description' not in sae_df.columns:
    # Fall back to the first column whose name hints at a description.
    candidates = [
        col for col in sae_df.columns
        if any(tag in col.lower() for tag in ('desc', 'explanation'))
    ]
    if not candidates:
        raise ValueError('No description field found in the base SAE JSON.')
    sae_df = sae_df.rename(columns={candidates[0]: 'description'})
# Try to get an integer feature index for reference
def extract_index(row):
    """Return an integer feature index for ``row``, or ``None`` if unavailable.

    Lookup order:
      1. ``row['index']`` when present and not null, converted with ``int()``.
      2. The trailing ``-<digits>`` suffix of a string ``row['feature_id']``.

    Args:
        row: A mapping-like record (dict or pandas Series) supporting ``in``,
            ``[]`` and ``.get``.

    Returns:
        The parsed index as ``int``, or ``None`` when neither source yields one.
    """
    if 'index' in row and pd.notna(row['index']):
        try:
            return int(row['index'])
        except (TypeError, ValueError, OverflowError):
            # Only the errors int() raises for unusable values (non-numeric
            # text, unsupported types, infinities); fall back to feature_id.
            pass
    fid = row.get('feature_id')
    if isinstance(fid, str):
        m = re.search(r'-(\d+)$', fid)
        if m:
            return int(m.group(1))
    return None

# Attach the parsed integer index to every row (None where extract_index
# cannot derive one).
sae_df['feature_index'] = sae_df.apply(extract_index, axis=1)

# Sample 300 random rows
SAMPLE_N = 300
# The seed is fixed so the same rows come back on every run; saved labels are
# keyed by position within this sample.
draw_count = min(SAMPLE_N, len(sae_df))
sample_df = sae_df.sample(n=draw_count, random_state=42).reset_index(drop=True)
print('Sample size:', len(sample_df))
print('Data source: Base SAE features (model-independent)')

# Resume from existing human labels if present
human_labels = {}  # key: row_id (int within sample), value: label str
if HUMAN_OUT_PATH.exists():
    try:
        with HUMAN_OUT_PATH.open('r', encoding='utf-8') as f:
            existing = json.load(f)
        # Use new standardized format only.
        # Parse into a scratch dict first so that a malformed file never leaves
        # human_labels half-populated while the message below says "starting fresh".
        restored = {}
        skipped = 0
        for rec in existing:
            row_id = rec.get('_row_id')
            label = rec.get('label')
            if row_id is None or not label:
                continue
            row_id = int(row_id)
            # A row id outside the current sample cannot be rendered or saved
            # (sample_df.iloc would raise), so it is skipped and reported.
            if not 0 <= row_id < len(sample_df):
                skipped += 1
                continue
            restored[row_id] = label
        human_labels.update(restored)
        print(f'Resumed {len(human_labels)} existing labels from file.')
        if skipped:
            print(f'Skipped {skipped} saved labels whose _row_id is outside the current sample.')
    except Exception as e:
        # Best-effort resume: any read/parse problem is reported, not fatal.
        print('Could not parse existing labels; starting fresh. Error:', e)

# UI state
# Start just past the highest labeled row, clamped to the last sample row.
resume_at = (max(human_labels) + 1) if human_labels else 0
current_idx = min(resume_at, len(sample_df) - 1)

# Define question text based on mode
# Known modes have dedicated wording; anything else gets a generic question.
question_text = {
    'alignment': 'Is this feature related to AI alignment, safety, or ethical behavior?',
    'formatting': 'Is this feature related to text formatting, structure, or presentation?',
}.get(LABEL_MODE, f'Is this feature related to {LABEL_MODE}?')

# Widgets
progress = widgets.IntProgress(
    value=len(human_labels),
    min=0,
    max=len(sample_df),
    description='Progress:',
)
desc_html = widgets.HTML(value='')      # description panel, filled by render()
status_lbl = widgets.HTML(value='')     # "Labeled: x / y" or save confirmation
btn_related = widgets.Button(description='Related', button_style='success')
btn_not_related = widgets.Button(description='Not Related', button_style='warning')
btn_prev = widgets.Button(description='Back', button_style='')
btn_save = widgets.Button(description='Save now', button_style='info')

def render(idx):
    """Show sample row ``idx`` in the description panel and refresh progress.

    Reads the row from ``sample_df`` and updates the ``desc_html``,
    ``status_lbl`` and ``progress`` widgets.

    Args:
        idx: Positional row index into ``sample_df``.
    """
    # Function-scope import: `html` is not among this cell's imports.
    from html import escape

    rec = sample_df.iloc[idx]
    # Only show description with clear question.
    # The description is arbitrary text and may itself contain markup-like
    # characters (e.g. "<div>" or "&"); escape it so the HTML widget displays
    # it literally instead of interpreting it.
    text = escape(str(rec.get('description', '')))
    desc_html.value = (
        f'<div style="padding:8px; background:#FAFAFA; border:1px solid #eee;">'
        f'<b>Item {idx+1} / {len(sample_df)}</b><br/>'
        f'<b>Question: {question_text}</b><br/><br/>'
        f'<div style="background:#F0F8FF; padding:6px; border-left:3px solid #4A90E2;">{text}</div>'
        f'</div>'
    )
    status_lbl.value = f'Labeled: {len(human_labels)} / {len(sample_df)}'
    progress.value = len(human_labels)

def record_label(idx, label):
    """Store ``label`` for sample row ``idx`` in the in-memory ``human_labels`` map.

    An existing label for the same row is overwritten.
    """
    human_labels.update({idx: label})

def on_click_related(_):
    """Button callback: label the current item 'related', then move forward.

    The single argument (the clicked button) is ignored.
    """
    # current_idx is only read here, so no `global` declaration is needed.
    record_label(current_idx, 'related')
    advance(1)

def on_click_not_related(_):
    """Button callback: label the current item 'not-related', then move forward.

    The single argument (the clicked button) is ignored.
    """
    # current_idx is only read here, so no `global` declaration is needed.
    record_label(current_idx, 'not-related')
    advance(1)

def on_click_prev(_):
    """Button callback: step back one item (never below the first) and redraw.

    The single argument (the clicked button) is ignored. No label is changed.
    """
    global current_idx
    current_idx = current_idx - 1 if current_idx >= 1 else 0
    render(current_idx)

def save_now(_=None):
    """Write every recorded label to ``HUMAN_OUT_PATH`` as a JSON list.

    Each record holds ``_row_id`` (position within ``sample_df``),
    ``feature_index``, ``feature_id``, ``description`` and ``label``
    ('related' / 'not-related'). Updates ``status_lbl`` with a confirmation.

    Args:
        _: Ignored; present so the function can be used as a button callback.
    """
    # Build records with description and feature index using new standardized format
    records = []
    for row_id, label in human_labels.items():
        rec = sample_df.iloc[row_id]
        finx = rec.get('feature_index')
        finx = int(finx) if pd.notna(finx) else None
        fid = rec.get('feature_id')
        # A numpy scalar id (e.g. int64) is not JSON serializable; unwrap it
        # to the equivalent plain Python value.
        if hasattr(fid, 'item') and getattr(fid, 'ndim', 0) == 0:
            fid = fid.item()
        records.append({
            '_row_id': int(row_id),
            'feature_index': finx,
            'feature_id': fid,
            'description': rec.get('description'),
            'label': label  # Uses new standardized format: 'related'/'not-related'
        })
    # Serialize BEFORE opening the file: opening with 'w' truncates it, so a
    # serialization error after that point would wipe previously saved labels.
    payload = json.dumps(records, indent=2, ensure_ascii=False)
    with HUMAN_OUT_PATH.open('w', encoding='utf-8') as f:
        f.write(payload)
    status_lbl.value = f'Saved {len(records)} labels to {HUMAN_OUT_PATH.name}'

def advance(step):
    """Move the cursor by ``step``, skip forward past labeled rows, and redraw.

    If every remaining row is already labeled, the cursor is parked on the
    last row of ``sample_df``.
    """
    global current_idx
    total = len(sample_df)
    candidate = current_idx + step
    # Move to next unlabeled item if possible
    while candidate < total and candidate in human_labels:
        candidate += 1
    # Ran off the end: clamp to the final row.
    current_idx = min(candidate, total - 1)
    render(current_idx)

# Hook each button up to its callback.
for _button, _callback in (
    (btn_related, on_click_related),
    (btn_not_related, on_click_not_related),
    (btn_prev, on_click_prev),
    (btn_save, save_now),
):
    _button.on_click(_callback)

# Initial render
render(current_idx)

# Layout: progress bar, description panel, button row, status line.
button_row = widgets.HBox([btn_prev, btn_related, btn_not_related, btn_save])
display(widgets.VBox([progress, desc_html, button_row, status_lbl]))

Saving human labels to: ../../outputs/feature_classification/human_labels/12-gemmascope-res-65k__l0-21_human_labels_formatting.json
Sample size: 300
Data source: Base SAE features (model-independent)


VBox(children=(IntProgress(value=0, description='Progress:', max=300), HTML(value='<div style="padding:8px; ba…

In [10]:
# Analysis/metrics: run this after labeling to compare against model classifications
from pathlib import Path
import json
import pandas as pd
import numpy as np
from datetime import datetime
import re

# --- CONFIGURABLE LABELING MODE ---
# Must match the mode used when the human labels were saved: it is part of the
# human-labels filename below.
LABEL_MODE = 'formatting'

# Model classification file to compare against (change this to test different models)
# The mode is interpolated so the model file always matches the human-label
# mode instead of staying on a hard-coded 'formatting'.
# NOTE(review): assumes classified files are named
# "..._<mode>_classified_<model>.json" for every mode - confirm.
MODEL_CLASSIFICATION_FILE = f'12-gemmascope-res-65k__l0-21_{LABEL_MODE}_classified_deepseek-v3-0324.json'
PATH_MODEL_CLASS = Path('../../outputs/feature_classification/gemma-2-2b') / MODEL_CLASSIFICATION_FILE

# Human labels file (mode-dependent, generated from base SAE)
PATH_HUMAN_LABELS = Path(f'../../outputs/feature_classification/human_labels/12-gemmascope-res-65k__l0-21_human_labels_{LABEL_MODE}.json')

print("Model classifications:", PATH_MODEL_CLASS)
print("Human labels:", PATH_HUMAN_LABELS)

# Check if files exist
if not PATH_MODEL_CLASS.exists():
    raise FileNotFoundError(f"Model classified file not found: {PATH_MODEL_CLASS}")
if not PATH_HUMAN_LABELS.exists():
    raise FileNotFoundError(f"Human labels file not found: {PATH_HUMAN_LABELS}. Run the UI and click 'Save now'.")

# --- Helper Functions ---
def load_json_records(p: Path) -> list[dict]:
    """Load a JSON file and return its contents as a list of records.

    Accepted top-level shapes:
      * a list - returned as-is;
      * a dict wrapping a list under ``"items"``, ``"records"`` or ``"data"``
        (first match wins) - that list is returned;
      * any other dict - treated as ``{feature_id: value}``. Dict values are
        merged into the record (a ``feature_id`` inside the value takes
        precedence over the outer key); other values are stored under
        ``"value"``.

    Args:
        p: Path of the JSON file to read (UTF-8).

    Returns:
        The list of records.

    Raises:
        TypeError: If the top-level JSON value is neither a list nor a dict.
    """
    with p.open('r', encoding='utf-8') as f:
        data = json.load(f)
    # Support dict-of-items or list-of-dicts
    if isinstance(data, dict):
        for k in ("items", "records", "data"):
            if k in data and isinstance(data[k], list):
                return data[k]
        # Dict-literal merge instead of dict(feature_id=k, **v): the keyword
        # form raises TypeError whenever v already contains 'feature_id'.
        return [
            {'feature_id': k, **(v if isinstance(v, dict) else {"value": v})}
            for k, v in data.items()
        ]
    if not isinstance(data, list):
        # Explicit raise rather than assert, which is stripped under -O.
        raise TypeError("Expected list[dict] JSON structure")
    return data

def extract_index_from_row(row: dict) -> int | None:
    """Extract feature index from various possible fields."""
    # Prefer explicit 'index' if provided
    if 'index' in row and row['index'] is not None:
        try:
            return int(row['index'])
        except Exception:
            pass
    fid = row.get('feature_id')
    if isinstance(fid, str):
        m = re.search(r'-(\d+)$', fid)
        if m:
            try:
                return int(m.group(1))
            except Exception:
                pass
    return None

# --- Load Data ---
model_records = load_json_records(PATH_MODEL_CLASS)
human_records = load_json_records(PATH_HUMAN_LABELS)

# Process model data - using new standardized format
model_df_raw = pd.DataFrame(model_records)
if 'label' not in model_df_raw.columns:
    raise ValueError("Expected 'label' in model classified JSON (new standardized format).")

if 'feature_id' not in model_df_raw.columns:
    # Some variants may name it differently; try to reconstruct
    if 'id' in model_df_raw.columns:
        model_df_raw = model_df_raw.rename(columns={'id': 'feature_id'})
    else:
        # Fail with a clear message instead of the KeyError the column
        # selection below would otherwise raise.
        raise ValueError("Expected 'feature_id' (or 'id') in model classified JSON.")

model_df_raw['feature_index'] = model_df_raw.apply(extract_index_from_row, axis=1)
model_df = model_df_raw[['feature_id', 'label', 'feature_index']].copy()

# Process human data - using new standardized format
human_df_raw = pd.DataFrame(human_records)
if 'label' not in human_df_raw.columns:
    raise ValueError("Expected 'label' in human labels JSON (new standardized format).")

# Standardize column names
# An alias is renamed only when its canonical column does not exist yet and no
# earlier alias has claimed it; otherwise the rename would produce duplicate
# column names and break the later column selection / merge.
_column_aliases = {
    'feature_id': {"feature_id", "id", "featureid"},
    'feature_index': {"feature_index", "index", "featureidx", "idx"},
}
rename_map = {}
for col in human_df_raw.columns:
    lc = str(col).lower()
    for canonical, aliases in _column_aliases.items():
        if (
            lc in aliases
            and canonical not in human_df_raw.columns
            and canonical not in rename_map.values()
        ):
            rename_map[col] = canonical
human_df = human_df_raw.rename(columns=rename_map)

# Determine join key
join_key = None
if 'feature_index' in human_df.columns and 'feature_index' in model_df.columns:
    join_key = 'feature_index'
elif 'feature_id' in human_df.columns and 'feature_id' in model_df.columns:
    join_key = 'feature_id'
else:
    raise ValueError("Cannot find a common key to join on. Ensure human labels include 'feature_index' (preferred) or 'feature_id'.")

# Join data
# Rows without a usable key are removed first: pandas matches null keys against
# each other, which would pair unrelated rows and can make the many_to_one
# validation fail as soon as two model rows lack a key.
human_keyed = human_df[[join_key, 'label']].dropna(subset=[join_key])
model_keyed = model_df[[join_key, 'label']].dropna(subset=[join_key])
_dropped_human = len(human_df) - len(human_keyed)
_dropped_model = len(model_df) - len(model_keyed)
if _dropped_human or _dropped_model:
    print(f"Dropped rows without {join_key}: human={_dropped_human}, model={_dropped_model}")

merged = human_keyed.merge(
    model_keyed, on=join_key, how='inner', suffixes=('_human', '_model'), validate='many_to_one'
)
print(f"Joined {len(merged)} items on {join_key}")

# Convert labels to binary (0/1) - using new standardized format only
label_map = {"related": 1, "not-related": 0}
for _target, _source in (('y_true', 'label_human'), ('y_pred', 'label_model')):
    # Labels missing from label_map become NaN and are reported below.
    merged[_target] = merged[_source].map(label_map)

# Verify all labels were mapped correctly
unmapped_human = merged.loc[merged['y_true'].isna()]
unmapped_model = merged.loc[merged['y_pred'].isna()]
if not unmapped_human.empty:
    print(f"Warning: {len(unmapped_human)} human labels couldn't be mapped:", unmapped_human['label_human'].unique())
if not unmapped_model.empty:
    print(f"Warning: {len(unmapped_model)} model labels couldn't be mapped:", unmapped_model['label_model'].unique())

# Keep only rows where both sides mapped, then store the flags as integers.
merged = merged.dropna(subset=['y_true', 'y_pred']).astype({'y_true': int, 'y_pred': int})

# Calculate metrics
n = len(merged)
_truth = merged['y_true']
_guess = merged['y_pred']
if n > 0:
    acc = float((_truth == _guess).mean())
else:
    acc = float('nan')

# Confusion-matrix cells (human label = truth, model label = prediction).
tp = int(((_truth == 1) & (_guess == 1)).sum())
tn = int(((_truth == 0) & (_guess == 0)).sum())
fp = int(((_truth == 0) & (_guess == 1)).sum())
fn = int(((_truth == 1) & (_guess == 0)).sum())

# Matthews correlation coefficient (phi)
# Left as None when any marginal total is zero (denominator would be 0).
_den = (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)
phi = float((tp * tn - fp * fn) / np.sqrt(_den)) if _den != 0 else None

_phi_text = 'nan' if phi is None else f'{phi:.3f}'
print(f"Samples: {n}")
print(f"Accuracy: {acc:.3f}")
print(f"Confusion: TP={tp}, TN={tn}, FP={fp}, FN={fn}")
print(f"Phi/Matthews correlation: {_phi_text} ")

# Save results to outputs/feature_classification/metrics folder
# Block-local import: only `datetime` itself is imported at the top of this cell.
from datetime import timezone

metrics_dir = Path('../../outputs/feature_classification/metrics')
metrics_dir.mkdir(parents=True, exist_ok=True)

# Use model name in output filenames
model_name = PATH_MODEL_CLASS.stem
human_base = PATH_HUMAN_LABELS.stem

# Save metrics
out_metrics = metrics_dir / f"{human_base}_vs_{model_name}_metrics.json"
metrics = {
    # datetime.utcnow() is deprecated (Python 3.12+); take an aware UTC "now"
    # and strip tzinfo so the stored string keeps its "...Z" form.
    "timestamp": datetime.now(timezone.utc).replace(tzinfo=None).isoformat() + 'Z',
    "model_name": model_name,
    "samples": n,
    # NaN (no samples) would be written as the non-standard JSON token NaN;
    # store null instead, matching how an undefined phi is stored.
    "accuracy": None if np.isnan(acc) else acc,
    "tp": tp,
    "tn": tn,
    "fp": fp,
    "fn": fn,
    "phi": phi,
    "human_labels": str(PATH_HUMAN_LABELS),
    "model_classifications": str(PATH_MODEL_CLASS),
    "join_key": join_key,
}
with out_metrics.open('w', encoding='utf-8') as f:
    json.dump(metrics, f, indent=2)
print("Saved metrics to:", out_metrics)

# Save merged table
out_csv = metrics_dir / f"{human_base}_vs_{model_name}_merged.csv"
merged.to_csv(out_csv, index=False)
print("Saved merged table to:", out_csv)

Model classifications: ../../outputs/feature_classification/gemma-2-2b/12-gemmascope-res-65k__l0-21_formatting_classified_deepseek-v3-0324.json
Human labels: ../../outputs/feature_classification/human_labels/12-gemmascope-res-65k__l0-21_human_labels_formatting.json
Joined 300 items on feature_index
Samples: 300
Accuracy: 0.900
Confusion: TP=69, TN=201, FP=5, FN=25
Phi/Matthews correlation: 0.764 
Saved metrics to: ../../outputs/feature_classification/metrics/12-gemmascope-res-65k__l0-21_human_labels_formatting_vs_12-gemmascope-res-65k__l0-21_formatting_classified_deepseek-v3-0324_metrics.json
Saved merged table to: ../../outputs/feature_classification/metrics/12-gemmascope-res-65k__l0-21_human_labels_formatting_vs_12-gemmascope-res-65k__l0-21_formatting_classified_deepseek-v3-0324_merged.csv
