<a href="https://colab.research.google.com/github/aarushi-sharma22/gemma9b-feature-discovery/blob/main/code/gemma-9b-full-pipeline.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Installation

!pip install -q transformer_lens sae_lens torch transformers accelerate
!pip install -q pandas numpy scipy scikit-learn matplotlib seaborn tqdm
!pip install -q statsmodels

print("Installation complete")

In [None]:
# Imports and setup

import torch
import torch.nn.functional as F
import numpy as np
import pandas as pd
import json
import gc
import time
import warnings
from datetime import datetime
from dataclasses import dataclass, field
from typing import Dict, List, Tuple, Optional, Any, Set
from collections import defaultdict
from pathlib import Path
import pickle

from scipy import stats
from scipy.stats import spearmanr, mannwhitneyu, kruskal, wilcoxon
from statsmodels.stats.multitest import multipletests

from transformer_lens import HookedTransformer
from sae_lens import SAE

import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm

# Record whether google.colab is importable so later cells can branch on it.
try:
    from google.colab import files
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

# Suppress all warnings for the rest of the session.
warnings.filterwarnings('ignore')

# Seed NumPy and PyTorch (CPU plus every CUDA device) with the same fixed value.
np.random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)

# Report the runtime environment.
print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    # The device-properties attribute is `total_memory` (in bytes); `total_mem`
    # does not exist and raised AttributeError on every GPU runtime.
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

In [None]:
# Data mappings and file paths

import os

# Two-letter country code -> country name used in prompts.
COUNTRY_MAP = {
    "AT": "Austria", "BE": "Belgium", "CZ": "Czechia",
    "FI": "Finland", "FR": "France", "GB": "United Kingdom",
    "HU": "Hungary", "IS": "Iceland", "PL": "Poland",
    "PT": "Portugal", "SI": "Slovenia"
}

# Gender code -> noun; code 9 maps to None (presumably a missing-value code — confirm against the codebook).
GENDER_MAP = {1: "man", 2: "woman", 9: None}

# Income-feeling code -> descriptive phrase; codes 7/8/9 map to None
# (presumably refusal / don't know / no answer — confirm against the codebook).
INCOME_MAP = {
    1: "living comfortably on present income",
    2: "coping on present income",
    3: "finding it difficult on present income",
    4: "finding it very difficult on present income",
    7: None, 8: None, 9: None
}

# Education code -> formal level description; 0/55/77/88/99 map to None
# (presumably non-classifiable or missing codes — confirm against the codebook).
EDUCATION_MAP = {
    1: "less than lower secondary education",
    2: "lower secondary education",
    3: "lower tier upper secondary education",
    4: "upper tier upper secondary education",
    5: "advanced vocational or sub-degree education",
    6: "a bachelor's degree",
    7: "a master's degree or higher",
    0: None, 55: None, 77: None, 88: None, 99: None,
}

# Education code -> second-person sentence for the same seven levels.
# NOTE(review): the wording here (high-school framing) differs from EDUCATION_MAP's
# level names for codes 2-5 — confirm this is intentional.
EDUCATION_TEXT = {
    1: "You did not complete high school.",
    2: "You completed middle school.",
    3: "You have some high school education.",
    4: "You completed high school.",
    5: "You have vocational or technical training.",
    6: "You have a bachelor's degree.",
    7: "You have a master's degree or higher.",
}

# Behavioural self-report items excluded from normative analysis
BEHAVIOR_QUESTIONS = {"w4hq17", "w4hq20"}

def get_domain(question_id: str) -> str:
    """Return the thematic domain implied by a question ID's prefix.

    Matching is case-insensitive. IDs whose prefix is not one of the four
    recognised ones yield "unknown".
    """
    prefix_to_domain = (
        ("w4g", "climate"),
        ("w4h", "health"),
        ("w4d", "digital"),
        ("w4e", "equality"),
    )
    lowered = question_id.lower()
    for prefix, domain in prefix_to_domain:
        if lowered.startswith(prefix):
            return domain
    return "unknown"

# File paths — update these for your environment
DEMOGRAPHICS_FILE = "/content/stratified_sample_200_blank.csv"
CODEBOOK_FILE = "/content/codebook_updated.json"

# Fail fast with an explicit exception instead of `assert`, which is removed
# when Python runs with -O and would let a missing file slip through.
for _required_path in (DEMOGRAPHICS_FILE, CODEBOOK_FILE):
    if not os.path.exists(_required_path):
        raise FileNotFoundError(f"Not found: {_required_path}")

print(f"Demographics: {DEMOGRAPHICS_FILE}")
print(f"Codebook: {CODEBOOK_FILE}")
print("Mappings ready")

In [None]:
# Hugging Face authentication

import os
from huggingface_hub import login, whoami, model_info, repo_info

HF_TOKEN = "YOUR_TOKEN_HERE"  # replace with your token, or use getpass

# Log in and also export the token through the HF_TOKEN environment variable.
login(token=HF_TOKEN, add_to_git_credential=False)
os.environ["HF_TOKEN"] = HF_TOKEN

# Verify identity and access to the model and SAE repositories up front.
try:
    info = whoami()
    print(f"Authenticated as: {info.get('name', 'Unknown')}")

    mi = model_info("google/gemma-2-9b-it")
    print("Access to Gemma-2-9B-IT confirmed")

    sae_info = repo_info("google/gemma-scope-9b-pt-res")
    print("Access to gemma-scope SAE repo confirmed")

except Exception as e:
    error_msg = str(e)
    # Chain the original exception (`from e`) so the underlying hub error and
    # its traceback stay attached to the friendlier RuntimeError.
    if "gemma" in error_msg.lower():
        raise RuntimeError(
            f"Gemma access denied: {e}\n"
            f"Accept the license at: https://huggingface.co/google/gemma-2-9b-it"
        ) from e
    else:
        raise RuntimeError(f"Auth failed: {e}") from e

print("\nAll access verified")

In [None]:
# Configuration and model/SAE loading

from huggingface_hub import hf_hub_download

# Verified L0 values per layer — confirmed via HuggingFace API
# repo: "pt" = gemma-scope-9b-pt-res, "it" = gemma-scope-9b-it-res
# Keys are layer indices. For each layer:
#   "l0"        — value substituted into the `average_l0_{l0}` path component when
#                 GemmaScopeSAE downloads the parameters
#   "repo"      — which repository to download from ("pt" or "it")
#   "depth_pct" — depth percentage shown in the configuration printout
LAYER_SAE_CONFIG = {
    5:  {"l0": 77,  "repo": "pt", "depth_pct": 12},
    9:  {"l0": 51,  "repo": "pt", "depth_pct": 22},
    14: {"l0": 67,  "repo": "pt", "depth_pct": 34},
    18: {"l0": 71,  "repo": "pt", "depth_pct": 44},
    20: {"l0": 47,  "repo": "it", "depth_pct": 49},  # IT SAE
    27: {"l0": 65,  "repo": "pt", "depth_pct": 66},
    32: {"l0": 61,  "repo": "pt", "depth_pct": 78},
    36: {"l0": 61,  "repo": "pt", "depth_pct": 88},
}

# Convenience view: layer index -> L0 value only.
LAYER_L0_MAP = {layer: cfg["l0"] for layer, cfg in LAYER_SAE_CONFIG.items()}


@dataclass
class CausalExtractionConfig:
    """Configuration for the expanded depth analysis across 8 layers.

    Bundles the model and SAE repository identifiers, the layers to analyse,
    and extraction hyperparameters. Mutable defaults (lists, Path) are created
    per instance through ``default_factory``.
    """

    # Model and SAE repository identifiers.
    model_name: str = "google/gemma-2-9b-it"
    sae_repo_pt: str = "google/gemma-scope-9b-pt-res"
    sae_repo_it: str = "google/gemma-scope-9b-it-res"
    sae_width: str = "16k"
    d_sae: int = 16384

    # All layers under analysis; the defaults of `original_layers` and
    # `new_layers` together make up this same set of eight layers.
    candidate_layers: List[int] = field(
        default_factory=lambda: [5, 9, 14, 18, 20, 27, 32, 36]
    )
    original_layers: List[int] = field(
        default_factory=lambda: [18, 27, 36]
    )
    new_layers: List[int] = field(
        default_factory=lambda: [5, 9, 14, 20, 32]
    )

    # Extraction hyperparameters.
    n_features: int = 50
    min_effect_threshold: float = 0.3
    batch_size: int = 16
    dtype: torch.dtype = torch.bfloat16
    # Evaluated once, when the class body runs — not per instance.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    output_dir: Path = field(default_factory=lambda: Path("./outputs_gemma_replication"))

    def hook_name(self, layer: int) -> str:
        """Return the hook name string for the residual stream after block `layer`."""
        return f"blocks.{layer}.hook_resid_post"


# Instantiate the default configuration and make sure its output directory exists.
config = CausalExtractionConfig()
config.output_dir.mkdir(exist_ok=True)

# Print the configured layers with their depth, L0 and SAE-repo tag.
print("Configuration")
print(f"  Model: {config.model_name}")
print(f"  Layers: {config.candidate_layers}")
for layer in config.candidate_layers:
    cfg = LAYER_SAE_CONFIG[layer]
    tag = " (IT SAE)" if cfg["repo"] == "it" else ""
    print(f"    L{layer:>2} — {cfg['depth_pct']}% depth, L0={cfg['l0']}{tag}")

# Load model
model = HookedTransformer.from_pretrained(
    config.model_name,
    torch_dtype=config.dtype,
    device=config.device,
)
print(f"\nModel loaded: {model.cfg.n_layers} layers, d_model={model.cfg.d_model}")

# Module-level shortcuts; GemmaScopeSAE reads `d_model` and _build_token_info reads `tokenizer`.
tokenizer = model.tokenizer
d_model = model.cfg.d_model


class GemmaScopeSAE:
    """Gemma-Scope SAE with JumpReLU activation.

    Loads from a verified (layer, L0, repo) combination taken from the
    module-level ``LAYER_SAE_CONFIG`` and validates the encoder shape against
    the module-level ``d_model``.
    """

    REPO_IDS = {
        "pt": "google/gemma-scope-9b-pt-res",
        "it": "google/gemma-scope-9b-it-res",
    }

    def __init__(self, layer: int, device: str = "cuda"):
        """Download and load the SAE parameters for `layer` onto `device`.

        Raises:
            ValueError: if `layer` is not in ``LAYER_SAE_CONFIG`` or the encoder
                weight has no dimension equal to ``d_model``.
        """
        self.layer = layer
        self.device = device

        if layer not in LAYER_SAE_CONFIG:
            raise ValueError(
                f"Layer {layer} not in config. Available: {list(LAYER_SAE_CONFIG.keys())}"
            )

        cfg = LAYER_SAE_CONFIG[layer]
        l0 = cfg["l0"]
        repo_id = self.REPO_IDS[cfg["repo"]]
        sae_path = f"layer_{layer}/width_16k/average_l0_{l0}/params.npz"
        print(f"  Loading: {repo_id} / {sae_path}")

        params_file = hf_hub_download(
            repo_id=repo_id, filename=sae_path, repo_type="model"
        )

        # np.load on an .npz returns an NpzFile that keeps the archive open until
        # it is closed. torch.tensor() copies each array, so everything needed is
        # materialised inside the `with` and the handle is released afterwards.
        with np.load(params_file) as params:
            W_enc_raw = params['W_enc']
            W_dec_raw = params['W_dec']

            # Determine encoder orientation: need [d_model, d_sae] for x @ W_enc
            if W_enc_raw.shape[0] == d_model:
                self.encode_transpose = False
            elif W_enc_raw.shape[1] == d_model:
                self.encode_transpose = True
            else:
                raise ValueError(f"W_enc shape {W_enc_raw.shape} incompatible with d_model={d_model}")

            self.W_enc = torch.tensor(W_enc_raw, dtype=torch.bfloat16, device=device)
            self.W_dec = torch.tensor(W_dec_raw, dtype=torch.bfloat16, device=device)
            self.b_enc = torch.tensor(params['b_enc'], dtype=torch.bfloat16, device=device)
            self.b_dec = torch.tensor(params['b_dec'], dtype=torch.bfloat16, device=device)

            # JumpReLU threshold may be stored directly or in log space.
            if 'threshold' in params:
                self.threshold = torch.tensor(params['threshold'], dtype=torch.bfloat16, device=device)
            elif 'log_threshold' in params:
                self.threshold = torch.exp(
                    torch.tensor(params['log_threshold'], dtype=torch.bfloat16, device=device)
                )
            else:
                self.threshold = None
                print("    No threshold found, falling back to ReLU")

        # Record input / dictionary sizes according to the stored orientation.
        if self.encode_transpose:
            self.d_in = self.W_enc.shape[1]
            self.d_sae = self.W_enc.shape[0]
        else:
            self.d_in = self.W_enc.shape[0]
            self.d_sae = self.W_enc.shape[1]

        self.l0 = l0
        self.repo_type = cfg["repo"]

        assert self.d_in == d_model, f"SAE d_in={self.d_in} != d_model={d_model}"
        print(f"    Loaded: d_sae={self.d_sae}, L0={l0}, repo={cfg['repo'].upper()}")

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode activations to sparse features.

        Pre-activations at or below the threshold are zeroed (JumpReLU); when no
        threshold was loaded, a plain ReLU is applied instead.

        Note: we skip the standard b_dec subtraction before encoding.
        When applying PT SAEs to IT activations, subtracting b_dec (which was
        learned to center PT activations) actively decenters the IT activations,
        degrading reconstruction (NMSE 0.50 vs 0.17 without subtraction).
        """
        x = x.to(self.W_enc.dtype)
        if self.encode_transpose:
            pre_acts = x @ self.W_enc.T + self.b_enc
        else:
            pre_acts = x @ self.W_enc + self.b_enc
        if self.threshold is not None:
            return torch.where(pre_acts > self.threshold, pre_acts, torch.zeros_like(pre_acts))
        else:
            return torch.relu(pre_acts)

    def decode(self, acts: torch.Tensor) -> torch.Tensor:
        """Map feature activations back to model space, adding b_dec."""
        acts = acts.to(self.W_dec.dtype)
        # Accept either [d_sae, d_in] or [d_in, d_sae] decoder layouts.
        if self.W_dec.shape[0] == self.d_sae:
            return acts @ self.W_dec + self.b_dec
        else:
            return acts @ self.W_dec.T + self.b_dec

    def forward(self, x: torch.Tensor) -> tuple:
        """Return ``(reconstruction, feature_activations)`` for `x`."""
        acts = self.encode(x)
        recon = self.decode(acts)
        return recon, acts


class GemmaSAEManager:
    """Holds at most one SAE at a time so VRAM is not shared between layers."""

    def __init__(self, device: str = "cuda"):
        self.device = device
        self._current_layer = None
        self._current_sae = None

    def load_sae(self, layer: int) -> GemmaScopeSAE:
        """Return the SAE for `layer`, reusing the resident one when it matches."""
        already_resident = (
            self._current_sae is not None and self._current_layer == layer
        )
        if not already_resident:
            # Drop the previous SAE before building the next one.
            self.unload()
            self._current_sae = GemmaScopeSAE(layer, self.device)
            self._current_layer = layer
        return self._current_sae

    def unload(self):
        """Release the resident SAE, if any, and empty the CUDA cache."""
        if self._current_sae is None:
            return
        self._current_sae = None
        self._current_layer = None
        torch.cuda.empty_cache()


# Quick verification
# Smoke-test SAE loading on one "pt" layer (18) and the "it" layer (20), then release it.
sae_manager = GemmaSAEManager(device=config.device)

print("\nVerifying SAE loading...")
test_sae = sae_manager.load_sae(18)
print("  Layer 18 (PT) OK")

# Loading a different layer unloads the previous SAE first.
test_sae = sae_manager.load_sae(20)
print("  Layer 20 (IT) OK")

sae_manager.unload()

print(f"\nSetup complete")
print(f"  Model: {config.model_name}")
print(f"  Layers: {config.candidate_layers}")
print(f"  Output: {config.output_dir}")

In [None]:
"""
Vocabulary sensitivity check for top-K feature selection.

Tests whether identified features track demographic semantics or surface lexical form,
by comparing SAE feature responses across 3 paraphrase variants per demographic.

For each demographic x domain, we compute:
  - Spearman correlation of mean diff vectors across wordings
  - Jaccard overlap of top-50 features
  - Cosine similarity of mean diff vectors

High correlations (>0.7) and overlap (>0.3) indicate features reflect demographics,
not prompt surface form. Run in the same session as the main extraction notebook.
"""

import torch
import torch.nn.functional as F
import numpy as np
import pandas as pd
import json
from pathlib import Path
from datetime import datetime
from tqdm.auto import tqdm
from scipy.stats import spearmanr
from scipy.spatial.distance import cosine as cosine_dist
import warnings
warnings.filterwarnings('ignore')

# Disable autograd globally for the rest of the session.
torch.set_grad_enabled(False)

# --- Config ---

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 16
SEED = 42
np.random.seed(SEED)

# Layer whose residual-stream output is examined in this check.
TARGET_LAYER = 36
HOOK_NAME = f"blocks.{TARGET_LAYER}.hook_resid_post"

# Use validation respondents to avoid contaminating the training set
N_SUBSAMPLE_RESPONDENTS = 30
N_FEATURES_CHECK = 50
N_VOCAB_VARIANTS = 3

# Subset of questions spanning all five domains
# NOTE(review): get_domain() labels the "w4e" prefix "equality" (not "economy") and
# returns "unknown" for "w4s" — confirm which labels are intended.
QUESTION_SUBSET = [
    "w4gq1", "w4gq2", "w4gq3", "w4gq10_a",       # climate
    "w4hq1", "w4hq4", "w4hq8", "w4hq15",           # health
    "w4dq8", "w4dq11", "w4dq15", "w4dq16",          # digital
    "w4eq1", "w4eq2", "w4eq3", "w4eq13",             # economy
    "w4sq1", "w4sq6", "w4sq14", "w4sq17",            # values
]

# Inputs are read from the main pipeline's output directory; results go to a subfolder.
INPUT_DIR = Path("./outputs_gemma_replication")
OUTPUT_DIR = INPUT_DIR / "vocab_sensitivity"
OUTPUT_DIR.mkdir(exist_ok=True)

# NOTE: these rebind the CODEBOOK_FILE / DEMOGRAPHICS_FILE names set earlier as
# "/content/..." strings to relative Path objects.
CODEBOOK_FILE = Path("./codebook_updated.json")
DEMOGRAPHICS_FILE = Path("./stratified_sample_200_blank.csv")

ALL_DEMOGRAPHICS = ['income', 'age', 'gender', 'education', 'vote']


# --- Alternative phrasings (3 variants per demographic, per condition) ---
# Structure: demographic name -> {'prompts_a': [...], 'prompts_b': [...]}.
# 'prompts_a' and 'prompts_b' hold the two contrasting conditions for that demographic
# (e.g. wealthy vs. poor); entries at the same list index are the same paraphrase variant.

VOCAB_VARIANTS = {
    'income': {
        'prompts_a': [
            "You are financially wealthy. You own multiple properties, have substantial savings and diverse investments, and never worry about money. You can afford luxuries and expensive experiences without thinking twice about the cost.",
            "You have a very high income and significant personal wealth. Your financial situation is extremely comfortable — you have extensive savings, property holdings, and investment portfolios. Money has never been a source of concern for you.",
            "You are rich. You have more money than you could ever spend, with large savings, multiple homes, and a portfolio of investments. Financial worries are completely foreign to you.",
        ],
        'prompts_b': [
            "You are financially poor. You struggle to afford rent, have accumulated significant debt, and constantly worry about paying for basic necessities. Money is a persistent source of stress and limits your daily choices.",
            "You have a very low income and almost no savings. Making ends meet is a constant struggle — you worry about rent, bills, and whether you can afford groceries. Debt weighs heavily on you every day.",
            "You are poor. You live paycheck to paycheck with mounting debts, and basic expenses like housing and food are a constant source of anxiety. Financial security feels impossibly out of reach.",
        ],
    },
    'age': {
        'prompts_a': [
            "You are 75 years old. You are retired after a long career, have lived through decades of social and technological change, and remember life before computers and mobile phones. You have accumulated a lifetime of experiences.",
            "You are an elderly person of 75. You've been retired for years after a full working life. You witnessed the entire digital revolution unfold and remember a world without the internet. Decades of experience have shaped your worldview.",
            "You are 75 and in your retirement years. You've seen enormous societal changes over your long life — from a pre-digital world to today. Your career is behind you and you draw on a lifetime of accumulated wisdom.",
        ],
        'prompts_b': [
            "You are 22 years old. You are early in your career, recently finished education, and grew up as a digital native with smartphones and social media from childhood. You have your whole adult life ahead of you.",
            "You are a young adult of 22, just starting out in your professional life. Technology has been part of your world since birth — you can't remember life without the internet. Your future stretches ahead with many possibilities.",
            "You are 22 and recently entered the workforce. Smartphones and social media have been constant companions throughout your life. You're at the very beginning of your adult journey with everything still ahead.",
        ],
    },
    'gender': {
        'prompts_a': [
            "You are a man. You have lived your entire life experiencing society as a man, with male social expectations, relationships, and career experiences. Your perspective has been shaped by masculine social norms and experiences.",
            "You are male. Throughout your life, you have navigated the world as a man — from boyhood through adulthood. Societal expectations of masculinity have influenced your relationships, career path, and daily interactions.",
            "You are a man who has experienced life through a male lens. Masculine norms and expectations have shaped how you relate to others, how you approach work, and how society has treated you throughout your life.",
        ],
        'prompts_b': [
            "You are a woman. You have lived your entire life experiencing society as a woman, with female social expectations, relationships, and career experiences. Your perspective has been shaped by feminine social norms and experiences.",
            "You are female. Throughout your life, you have navigated the world as a woman — from girlhood through adulthood. Societal expectations of femininity have influenced your relationships, career path, and daily interactions.",
            "You are a woman who has experienced life through a female lens. Feminine norms and expectations have shaped how you relate to others, how you approach work, and how society has treated you throughout your life.",
        ],
    },
    'education': {
        'prompts_a': [
            "You have a PhD and spent over 20 years in formal education. You work in a professional field requiring advanced expertise and critical analysis. Academic thinking and research methodology are second nature to you.",
            "You are highly educated with a doctoral degree. After more than two decades of academic study, you work in a field that demands deep expertise. Rigorous analytical thinking and scholarly methods come naturally to you.",
            "You hold a PhD after 20+ years of formal education. Your career requires advanced intellectual skills and critical reasoning. You are thoroughly trained in research methods and think naturally in academic frameworks.",
        ],
        'prompts_b': [
            "You did not complete high school. You left formal education early and learned through practical work experience rather than academic study. You've built knowledge through hands-on learning and real-world problem solving.",
            "You have minimal formal education, having left school before finishing secondary level. What you know, you learned through working and doing — practical experience rather than textbooks shaped your understanding of the world.",
            "You dropped out before completing high school. Your education came from the school of life — years of hands-on work and practical problem-solving rather than formal academic training.",
        ],
    },
    'vote': {
        'prompts_a': [
            "You are a regular voter who participates in every election. You believe voting is a civic duty and fundamental to democracy. You stay informed about candidates and issues, and always make time to cast your ballot.",
            "You vote consistently in every election without exception. Democratic participation is a core value for you — you research candidates, follow political developments, and consider voting an essential civic responsibility.",
            "You are a committed voter who never misses an election. Casting your ballot is something you see as a fundamental obligation of citizenship. You keep up with political issues and candidates as a matter of principle.",
        ],
        'prompts_b': [
            "You are a non-voter who doesn't participate in elections. You feel disconnected from the political system and don't believe your vote makes a difference. Electoral politics seems distant from your daily life concerns.",
            "You don't vote in elections. The political system feels irrelevant to your life — you see no point in casting a ballot when it won't change anything. Politics and elections are not things you pay attention to.",
            "You have chosen not to participate in elections. Voting feels meaningless to you — the political system seems disconnected from the issues that actually affect your daily life, and you don't believe one vote matters.",
        ],
    },
}


# --- Token handling and EV computation ---

def _build_token_info():
    """Record how the tokenizer encodes each integer from 0 to 10.

    Returns a dict keyed by the integer. Numbers that encode to exactly one
    token get ``{'type': 'single', 'token_id': ...}``; all others get a
    ``'multi'`` entry carrying the full id list plus its first two ids.
    """
    def describe(token_ids):
        if len(token_ids) == 1:
            return {'type': 'single', 'token_id': token_ids[0]}
        return {
            'type': 'multi',
            'token_ids': token_ids,
            'first_token_id': token_ids[0],
            'second_token_id': token_ids[1],
        }

    return {
        number: describe(tokenizer.encode(str(number), add_special_tokens=False))
        for number in range(11)
    }


def compute_response_metrics(logits, scale_min, scale_max):
    """Return the expected scale value under the model's answer distribution.

    The logits of the answer tokens for ``scale_min..scale_max`` are softmaxed
    and used to weight the scale values.

    On scales reaching 10 where '10' is a multi-token answer, it shares its
    first token with '1', so the probability assigned to '1' is divided evenly
    between the values 1 and 10. See paper Section 4.2.
    """
    if scale_max < 10 or TOKEN_INFO.get(10, {}).get('type') != 'multi':
        # Every answer is a single token: plain softmax expectation.
        answer_ids = [TOKEN_INFO[v]['token_id'] for v in range(scale_min, scale_max + 1)]
        distribution = F.softmax(logits[answer_ids].float(), dim=0).cpu().numpy()
        return float(np.dot(np.arange(scale_min, scale_max + 1), distribution))

    # Softmax over the single-token answers only (at most 0-9).
    top_single = min(scale_max, 9)
    candidates = list(range(scale_min, top_single + 1))
    answer_ids = [TOKEN_INFO[v]['token_id'] for v in candidates]
    weights = list(F.softmax(logits[answer_ids].float(), dim=0).cpu().numpy())

    # Give 10 half of the mass assigned to 1, or nothing if 1 is off-scale.
    if 1 in candidates:
        pos_of_one = candidates.index(1)
        half_mass = weights[pos_of_one] / 2.0
        weights[pos_of_one] = half_mass
        weights.append(half_mass)
    else:
        weights.append(0.0)
    candidates.append(10)

    values = np.array(candidates)
    probs = np.array(weights)
    probs = probs / probs.sum()
    return float(np.dot(values, probs))


# --- Prompt construction ---
# These mirror the main extraction pipeline's prompt format.

def get_real_gender_description(gndr_code):
    """Return the gender sentence for a code (1 = man, 2 = woman, otherwise a neutral fallback)."""
    for code, sentence in ((1, "You are a man."), (2, "You are a woman.")):
        if gndr_code == code:
            return sentence
    return "You are an adult."


def get_real_age_description(age):
    """Return an age sentence with a life-stage phrase.

    Missing values and values above 900 yield a generic sentence; otherwise
    the age is truncated to an int and matched to its bracket.
    """
    if pd.isna(age) or age > 900:
        return "You are an adult."

    years = int(age)
    # (exclusive upper bound, life-stage phrase), checked in ascending order.
    brackets = (
        (25, "a young adult just starting out"),
        (35, "in your late twenties to early thirties"),
        (45, "in your mid-thirties to early forties"),
        (55, "in middle age"),
        (65, "in your late fifties to early sixties"),
        (75, "in your sixties to early seventies"),
    )
    for upper_bound, phrase in brackets:
        if years < upper_bound:
            return f"You are {years} years old, {phrase}."
    return f"You are {years} years old, a senior citizen."


def get_real_education_description(eisced_code):
    """Return the education sentence for codes 1-7, or a generic sentence for any other code."""
    sentences = (
        "You have less than lower secondary education. You left school early.",
        "You completed lower secondary education (middle school equivalent).",
        "You completed upper secondary education (high school).",
        "You have post-secondary non-tertiary education (vocational training).",
        "You have a short-cycle tertiary degree (associate's or similar).",
        "You have a bachelor's degree from university.",
        "You have a master's degree or higher (including PhD).",
    )
    # Codes are 1-based positions in the tuple above.
    by_code = dict(enumerate(sentences, start=1))
    try:
        return by_code[eisced_code]
    except KeyError:
        return "You have completed some formal education."


def get_real_income_description(hincfel_code):
    """Return the income sentence for codes 1-4, or a generic sentence for any other code."""
    fallback = "You have a moderate income."
    known = {
        1: "You live comfortably on your current income.",
        2: "You cope on your current income.",
        3: "You find it difficult on your current income.",
        4: "You find it very difficult on your current income.",
    }
    if hincfel_code in known:
        return known[hincfel_code]
    return fallback


def get_real_vote_description(vote_code):
    """Return the voting sentence for codes 1-3, or a generic sentence for any other code."""
    codes = (1, 2, 3)
    sentences = (
        "You voted in the last national election.",
        "You did not vote in the last national election.",
        "You were not eligible to vote in the last national election.",
    )
    lookup = dict(zip(codes, sentences))
    return lookup.get(vote_code, "You have varying levels of political engagement.")


def build_context(row, tested_demo, tested_prompt, country):
    """Assemble the persona paragraph for one respondent.

    Sentences appear in a fixed order: gender, age, country, education,
    income, vote. The slot named by `tested_demo` is filled with
    `tested_prompt`; every other slot is described from the respondent's own
    values, read with ``row.get``.
    """
    # (demographic name, producer of its real-value sentence); the country
    # slot has no demographic name and can never be overridden.
    slots = (
        ('gender', lambda: get_real_gender_description(row.get('gndr', 0))),
        ('age', lambda: get_real_age_description(row.get('agea', row.get('age', 999)))),
        (None, lambda: f"You live in {country}."),
        ('education', lambda: get_real_education_description(row.get('eisced', 0))),
        ('income', lambda: get_real_income_description(row.get('hincfel', 0))),
        ('vote', lambda: get_real_vote_description(row.get('vote', 0))),
    )

    sentences = []
    for demo_name, describe_real in slots:
        if demo_name is not None and tested_demo == demo_name:
            sentences.append(tested_prompt)
        else:
            sentences.append(describe_real())
    return " ".join(sentences)


def format_gemma_prompt(context, question, scale_info):
    """Wrap a persona context and question in Gemma chat-turn markers.

    The user turn holds the context, then the question followed by the
    scale description, then a fixed single-number instruction; the
    returned string ends with an opened model turn.
    """
    question_line = f"{question} {scale_info}"
    user_content = "\n\n".join(
        (f"{context}", question_line, "RESPOND WITH ONLY A SINGLE NUMBER.")
    )
    return "".join((
        "<start_of_turn>user\n",
        user_content,
        "<end_of_turn>\n<start_of_turn>model\n",
    ))


# --- Load data ---
# The codebook is optional here (an empty lookup is used when it is missing);
# the selection prompts parquet is required.

print("Loading data...")

if not CODEBOOK_FILE.exists():
    codebook_lookup = {}
    print(f"  Codebook not found at {CODEBOOK_FILE}")
else:
    with open(CODEBOOK_FILE) as f:
        codebook = json.load(f)
    # Index codebook entries by variable name for per-question lookup.
    codebook_lookup = {}
    for entry in codebook['variables']:
        codebook_lookup[entry['variable_name']] = entry
    print(f"  Codebook: {len(codebook_lookup)} variables")

prompts_path = INPUT_DIR / 'prompts_selection.parquet'
prompts_df = pd.read_parquet(prompts_path)
print(f"  Prompts: {len(prompts_df):,}")

# Build question metadata from codebook.
# Scale bounds and domain come from the first matching prompt row; wording
# and endpoint labels come from the codebook when it has them, otherwise the
# question id and the numeric bounds are used.
QUESTION_INFO = {}
for qid in QUESTION_SUBSET:
    matches = prompts_df.loc[prompts_df['question_id'] == qid]
    if matches.empty:
        print(f"  Warning: question {qid} not found in prompts")
        continue

    row = matches.iloc[0]
    q_var = codebook_lookup.get(qid, {})
    q_text = q_var.get('question', qid)

    # Start from the numeric bounds and override with codebook labels.
    min_label = str(row['scale_min'])
    max_label = str(row['scale_max'])

    if 'scale_range' in q_var:
        scale_range = q_var['scale_range']
        min_label = scale_range.get('min_label', min_label)
        max_label = scale_range.get('max_label', max_label)
    elif 'values' in q_var:
        # Only integer-coded values can be matched against the scale bounds.
        labels = {}
        for item in q_var['values']:
            if isinstance(item['code'], int):
                labels[item['code']] = item['label']
        min_label = labels.get(int(row['scale_min']), min_label)
        max_label = labels.get(int(row['scale_max']), max_label)

    QUESTION_INFO[qid] = dict(
        text=q_text,
        scale_min=int(row['scale_min']),
        scale_max=int(row['scale_max']),
        min_label=min_label,
        max_label=max_label,
        domain=row['domain'],
    )

print(f"  Questions for check: {len(QUESTION_INFO)}")

# Load demographics and select validation respondents.
# Respondents come from the validation ids of the train/val split when that
# file exists and lists any; otherwise the first N_SUBSAMPLE_RESPONDENTS
# rows of the demographics table are used.
if not DEMOGRAPHICS_FILE.exists():
    raise FileNotFoundError(f"Demographics file not found: {DEMOGRAPHICS_FILE}")

demographics_df = pd.read_csv(DEMOGRAPHICS_FILE)
print(f"  Demographics: {len(demographics_df)} rows")

# First known id column name present in the table, else the first column.
id_col = next(
    (name for name in ('idno', 'respondent_id', 'id', 'ID')
     if name in demographics_df.columns),
    None,
)
if id_col is None:
    id_col = demographics_df.columns[0]
    print(f"  No standard ID column, using: {id_col}")

demographics_df = demographics_df.drop_duplicates(subset=id_col, keep='first')

split_file = INPUT_DIR / 'train_val_split.json'
if not split_file.exists():
    resp_df = demographics_df.head(N_SUBSAMPLE_RESPONDENTS)
    print(f"  No split file, using first {len(resp_df)}")
else:
    with open(split_file) as f:
        split = json.load(f)
    val_ids = split.get('validation_ids', [])
    if len(val_ids) == 0:
        resp_df = demographics_df.head(N_SUBSAMPLE_RESPONDENTS)
        print(f"  No validation IDs in split, using first {len(resp_df)}")
    else:
        is_validation = demographics_df[id_col].isin(val_ids)
        resp_df = demographics_df[is_validation]
        print(f"  Using {len(resp_df)} validation respondents")

if len(resp_df) == 0:
    raise ValueError("No respondents found — check demographics file and split.")

# Downstream code addresses the respondent id as 'idno'.
resp_df = resp_df.rename(columns={id_col: 'idno'})
TOKEN_INFO = _build_token_info()


# --- Generate prompts for every vocab variant ---
# One prompt per (respondent, question, demographic, vocab variant, a/b value).
# The two values of a pair share a pair_key so their activations can be
# matched after extraction.

print(f"\nGenerating prompts for {N_VOCAB_VARIANTS} vocab variants...")

# The scale description depends only on the question, so build it once.
scale_strings = {
    qid: (f"({info['scale_min']} = {info['min_label']}, "
          f"{info['scale_max']} = {info['max_label']})")
    for qid, info in QUESTION_INFO.items()
}

all_prompts = []

for _, row in resp_df.iterrows():
    respondent = row['idno']
    country = COUNTRY_MAP.get(row.get('cntry', ''), 'Europe')

    for qid, q_info in QUESTION_INFO.items():
        scale_str = scale_strings[qid]

        for demo_name in ALL_DEMOGRAPHICS:
            wordings = VOCAB_VARIANTS[demo_name]

            for vocab_idx in range(N_VOCAB_VARIANTS):
                pair_key = f"{respondent}_{qid}_{demo_name}_v{vocab_idx}"

                for vt in ('a', 'b'):
                    demo_prompt = wordings[f'prompts_{vt}'][vocab_idx]
                    context = build_context(row, demo_name, demo_prompt, country)
                    prompt = format_gemma_prompt(context, q_info['text'], scale_str)

                    all_prompts.append({
                        'respondent_id': respondent,
                        'question_id': qid,
                        'domain': q_info['domain'],
                        'demographic': demo_name,
                        'value_type': f'value_{vt}',
                        'vocab_idx': vocab_idx,
                        'scale_min': q_info['scale_min'],
                        'scale_max': q_info['scale_max'],
                        'prompt': prompt,
                        'pair_key': pair_key,
                    })

sensitivity_df = pd.DataFrame(all_prompts)
n_per_vocab = len(sensitivity_df) // N_VOCAB_VARIANTS

print(f"  Total prompts: {len(sensitivity_df):,}")
print(f"  Per variant: {n_per_vocab:,}")
print(f"  Respondents: {len(resp_df)}, Questions: {len(QUESTION_INFO)}, "
      f"Demographics: {len(ALL_DEMOGRAPHICS)}")


# --- Extract SAE activations ---
# For every prompt, encode the residual stream at the last prompt token with
# the SAE, then store value_a - value_b for each matched pair, grouped by
# (vocab_idx, demographic, domain).

print(f"\nExtracting SAE activations (Layer {TARGET_LAYER})...")

sae = sae_manager.load_sae(TARGET_LAYER)
d_sae = sae.d_sae if hasattr(sae, 'd_sae') else 16384

# TransformerLens models expose run_with_cache; anything else is treated as
# a Hugging Face style model and read through a forward hook.
USE_TRANSFORMER_LENS = hasattr(model, 'run_with_cache')

diff_storage = {}    # (vocab_idx, demo, domain) -> list of diff vectors
activation_buffer = {}  # pair_key -> group info plus the value_a / value_b seen so far

n = len(sensitivity_df)

for batch_start in tqdm(range(0, n, BATCH_SIZE), desc="Extracting"):
    batch_end = min(batch_start + BATCH_SIZE, n)
    batch_size = batch_end - batch_start

    batch_rows = sensitivity_df.iloc[batch_start:batch_end]
    batch_prompts = batch_rows['prompt'].tolist()

    tokens = tokenizer(batch_prompts, return_tensors='pt',
                       padding=True, truncation=True, max_length=512)
    input_ids = tokens['input_ids'].to(DEVICE)
    attention_mask = tokens['attention_mask'].to(DEVICE)

    # Index of the last attended token in each row. Taking the largest
    # position whose mask is 1 equals `mask.sum() - 1` under right padding
    # and stays correct under left padding, where `sum - 1` would point at
    # an earlier token.
    positions = torch.arange(attention_mask.shape[1], device=attention_mask.device)
    seq_lens = (attention_mask * positions).argmax(dim=1)

    with torch.inference_mode():
        if USE_TRANSFORMER_LENS:
            # NOTE(review): no attention mask is passed on this path, so pad
            # tokens are only harmless if padding is on the right — confirm.
            logits, cache = model.run_with_cache(
                input_ids, names_filter=lambda name: HOOK_NAME in name
            )
            residuals = torch.stack([
                cache[HOOK_NAME][j, seq_lens[j], :] for j in range(batch_size)
            ]).float()
            sae_acts = sae.encode(residuals).float().cpu().numpy()
            del cache, residuals, logits
        else:
            activations = {}

            def hook_fn(module, inp, out):
                # Layer outputs may be a tuple; the hidden state comes first.
                activations['h'] = out[0].detach() if isinstance(out, tuple) else out.detach()

            hook = model.model.layers[TARGET_LAYER].register_forward_hook(hook_fn)
            try:
                outputs = model(input_ids, attention_mask=attention_mask)
            finally:
                # Always detach the hook so a failed forward pass cannot leave
                # it registered on the model.
                hook.remove()
            last_residuals = torch.stack([
                activations['h'][j, seq_lens[j], :] for j in range(batch_size)
            ]).float()
            sae_acts = sae.encode(last_residuals).float().cpu().numpy()
            del activations, outputs, last_residuals

    # Match pairs and compute diffs
    for j in range(batch_size):
        row = batch_rows.iloc[j]
        pk = row['pair_key']
        vtype = row['value_type']

        if pk not in activation_buffer:
            activation_buffer[pk] = {
                'demo': row['demographic'],
                'domain': row['domain'],
                'vocab_idx': row['vocab_idx'],
            }
        activation_buffer[pk][vtype] = sae_acts[j]

        info = activation_buffer[pk]
        if 'value_a' in info and 'value_b' in info:
            key = (info['vocab_idx'], info['demo'], info['domain'])
            diff_storage.setdefault(key, []).append(info['value_a'] - info['value_b'])
            # Drop the pair as soon as it is consumed to keep the buffer small.
            del activation_buffer[pk]

    if (batch_start // BATCH_SIZE) % 20 == 0:
        torch.cuda.empty_cache()

del sae
torch.cuda.empty_cache()

# Anything still buffered never received its partner prompt.
n_orphans = len(activation_buffer)
if n_orphans > 0:
    print(f"  {n_orphans} orphan pairs (unmatched)")
activation_buffer.clear()

print(f"  Groups extracted: {len(diff_storage)}")


# --- Compute sensitivity metrics ---
# For each (demographic, domain) group, average the activation differences
# per vocab variant and compare every pair of variants: whole-vector
# agreement (Spearman, Pearson, cosine) and agreement of the top
# N_FEATURES_CHECK features (Jaccard, overlap count, rank correlation over
# the union of both top sets).

from itertools import combinations

print("\nComputing sensitivity metrics...")

results = []

# Sorted so the row order of the results is reproducible between runs
# (iterating a set of strings is not).
domains = sorted({q['domain'] for q in QUESTION_INFO.values()}, key=str)

for demo in ALL_DEMOGRAPHICS:
    for domain in domains:
        mean_diffs = {}
        for v_idx in range(N_VOCAB_VARIANTS):
            key = (v_idx, demo, domain)
            # Require at least 5 pairs before trusting a group mean.
            if key in diff_storage and len(diff_storage[key]) >= 5:
                mean_diffs[v_idx] = np.mean(np.stack(diff_storage[key]), axis=0)

        if len(mean_diffs) < 2:
            continue

        # Rank features once per variant (descending |mean diff|) and reuse
        # the ranking for both the top-k sets and the rank lookups below.
        top_k = {}
        rank_of = {}
        for v_idx, md in mean_diffs.items():
            order = np.argsort(np.abs(md))[::-1]
            top_k[v_idx] = set(order[:N_FEATURES_CHECK].tolist())
            ranks = np.empty(order.shape[0], dtype=np.int64)
            ranks[order] = np.arange(order.shape[0])  # inverse permutation
            rank_of[v_idx] = ranks

        # Every unordered pair of available variants; for three variants this
        # is (0, 1), (0, 2), (1, 2).
        for vi, vj in combinations(sorted(mean_diffs), 2):
            md_i, md_j = mean_diffs[vi], mean_diffs[vj]

            rho, p_rho = spearmanr(md_i, md_j)
            cos_sim = 1.0 - cosine_dist(md_i, md_j)
            pearson_r = np.corrcoef(md_i, md_j)[0, 1]

            shared = top_k[vi] & top_k[vj]
            union = top_k[vi] | top_k[vj]
            jaccard = len(shared) / len(union) if union else float('nan')
            overlap_count = len(shared)

            # Rank correlation among the union of top features
            union_feats = list(union)
            rank_rho, _ = spearmanr(rank_of[vi][union_feats], rank_of[vj][union_feats])

            results.append({
                'demographic': demo,
                'domain': domain,
                'vocab_i': vi,
                'vocab_j': vj,
                'spearman_rho': float(rho),
                'spearman_p': float(p_rho),
                'cosine_similarity': float(cos_sim),
                'pearson_r': float(pearson_r),
                'jaccard_top50': float(jaccard),
                'overlap_count_top50': int(overlap_count),
                'rank_rho_union': float(rank_rho),
            })

results_df = pd.DataFrame(results)


# --- Summary ---
# Print overall and per-demographic averages of the comparison metrics and
# a coarse interpretation of the mean Spearman rho and Jaccard overlap.

print(f"\nResults ({len(results_df)} comparisons)")

if results_df.empty:
    # An empty frame has no metric columns; indexing them would raise
    # KeyError, and NaN means would fall through to the "prompt-specific"
    # verdict without any evidence behind it.
    mean_rho = float('nan')
    mean_jac = float('nan')
    print("  No comparisons available — skipping summary and interpretation.")
else:
    print(f"\nOverall:")
    for metric in ['spearman_rho', 'cosine_similarity', 'pearson_r',
                   'jaccard_top50', 'overlap_count_top50', 'rank_rho_union']:
        vals = results_df[metric].dropna()
        print(f"  {metric}: mean={vals.mean():.3f}, median={vals.median():.3f}, "
              f"range=[{vals.min():.3f}, {vals.max():.3f}]")

    print(f"\nBy demographic:")
    for demo in ALL_DEMOGRAPHICS:
        sub = results_df[results_df['demographic'] == demo]
        if len(sub) == 0:
            continue
        print(f"  {demo}:")
        print(f"    Spearman rho:     {sub['spearman_rho'].mean():.3f} "
              f"(min={sub['spearman_rho'].min():.3f})")
        print(f"    Cosine sim:       {sub['cosine_similarity'].mean():.3f}")
        print(f"    Jaccard top-50:   {sub['jaccard_top50'].mean():.3f} "
              f"({sub['overlap_count_top50'].mean():.1f} features)")
        print(f"    Rank rho (union): {sub['rank_rho_union'].mean():.3f}")

    # Interpretation thresholds
    print("\nInterpretation:")
    mean_rho = results_df['spearman_rho'].mean()
    mean_jac = results_df['jaccard_top50'].mean()
    if mean_rho > 0.7 and mean_jac > 0.3:
        print("  -> Features track demographic semantics (robust to rephrasing)")
    elif mean_rho > 0.3 and mean_jac > 0.1:
        print("  -> Partial robustness — some features stable, others wording-dependent")
    else:
        print("  -> Features may be prompt-specific — interpret encoding claims with caution")

# --- Save ---
# Write the per-comparison table as CSV and the averaged metrics as JSON.

_SUMMARY_METRICS = (
    'spearman_rho',
    'cosine_similarity',
    'jaccard_top50',
    'overlap_count_top50',
)


def _summary_means(frame):
    """Return the mean of each summary metric in ``frame``.

    A metric whose column is missing (empty results) or whose mean is NaN is
    reported as None, so the JSON file holds ``null`` rather than the
    non-standard ``NaN`` token.
    """
    means = {}
    for metric in _SUMMARY_METRICS:
        value = frame[metric].mean() if metric in frame.columns else float('nan')
        means[metric] = None if pd.isna(value) else float(value)
    return means


results_df.to_csv(OUTPUT_DIR / 'vocab_sensitivity_results.csv', index=False)

summary = {
    'timestamp': datetime.now().isoformat(),
    'layer': TARGET_LAYER,
    'n_respondents': len(resp_df),
    'n_questions': len(QUESTION_INFO),
    'n_vocab_variants': N_VOCAB_VARIANTS,
    'k_features': N_FEATURES_CHECK,
    'overall': _summary_means(results_df),
    'by_demographic': {},
}

# An empty results frame has no 'demographic' column to filter on.
if 'demographic' in results_df.columns:
    for demo in ALL_DEMOGRAPHICS:
        sub = results_df[results_df['demographic'] == demo]
        if len(sub) > 0:
            summary['by_demographic'][demo] = _summary_means(sub)

with open(OUTPUT_DIR / 'vocab_sensitivity_summary.json', 'w', encoding='utf-8') as f:
    json.dump(summary, f, indent=2)

print(f"\nSaved to {OUTPUT_DIR}/")

In [None]:
# Prompt generation — contrastive pairs with real respondent demographics
#
# Each pair varies one demographic between extreme values while grounding all
# other attributes in the respondent's real ESS data. This ensures activation
# differences are attributable to the manipulated demographic.
#
# 110 questions across 5 domains, aligned with Llama pipeline.
# Exclusions: w4dq4_4, w4dq13_7 (refusal codes); w4dq9/10 (demographic
# comparison items); w4dq14 (not opinion-based); w4hq16-20 (behavioral).
#
# Scale handling: 0-10 scales retain all 11 response options. The downstream
# EV computation handles '10' as a multi-token value.

import json
import numpy as np
import pandas as pd
from pathlib import Path
from tqdm.auto import tqdm

# Seed for numpy's global RNG and for the respondent shuffle below.
SEED = 42
# Respondents taken for the selection split and the validation split; the
# two slices are consecutive after shuffling, so they do not overlap.
N_SELECTION = 120
N_VALIDATION = 30
N_VOCAB = 1  # single wording; robustness to paraphrase validated separately

# All artefacts of this cell are written here; the directory is created on
# first run (parents are not created).
OUTPUT_DIR = Path("./outputs_gemma_replication")
OUTPUT_DIR.mkdir(exist_ok=True)
np.random.seed(SEED)

# Upper bound applied to every question's scale_max when QUESTIONS is built.
MAX_SCALE = 10  # preserves full 0-10 range (previously capped at 9)


# --- Load and split respondents ---
# Keep respondents whose five demographic fields all hold substantive codes,
# shuffle them with a fixed seed, and carve consecutive selection and
# validation slices. The id lists are saved so later steps reuse the split.

print("Loading data...")

demographics_df = pd.read_csv(DEMOGRAPHICS_FILE).drop_duplicates(
    subset='idno', keep='first'
)
AGE_COL = 'age' if 'agea' not in demographics_df.columns else 'agea'

print(f"  Raw: {len(demographics_df)} respondents")

# One boolean mask per attribute; ages of 900 and above are treated as
# missing codes.
has_gender = demographics_df['gndr'].isin([1, 2])
has_income = demographics_df['hincfel'].isin([1, 2, 3, 4])
has_age = demographics_df[AGE_COL] < 900
has_education = demographics_df['eisced'].isin([1, 2, 3, 4, 5, 6, 7])
has_vote = demographics_df['vote'].isin([1, 2, 3])

valid_df = demographics_df[
    has_gender & has_income & has_age & has_education & has_vote
].copy()

print(f"  Valid: {len(valid_df)}")

valid_df = valid_df.sample(frac=1, random_state=SEED).reset_index(drop=True)

split_end = N_SELECTION + N_VALIDATION
selection_df = valid_df.iloc[:N_SELECTION].copy()
validation_df = valid_df.iloc[N_SELECTION:split_end].copy()

assert set(selection_df['idno']).isdisjoint(validation_df['idno']), "Train/val overlap!"

selection_ids = selection_df['idno'].tolist()
validation_ids = validation_df['idno'].tolist()

print(f"  Selection: {len(selection_ids)}, Validation: {len(validation_ids)}")

split_record = {
    'selection_ids': selection_ids,
    'validation_ids': validation_ids,
    'seed': SEED,
}
with open(OUTPUT_DIR / 'train_val_split.json', 'w') as f:
    json.dump(split_record, f)


# --- Contrastive demographic descriptions ---
#
# One entry per manipulated demographic. 'value_a' / 'value_b' name the two
# contrasted extremes; 'prompts_a' / 'prompts_b' hold the persona text for
# each extreme as lists indexed by vocab variant (a single wording each,
# matching N_VOCAB = 1).

DEMOGRAPHICS = {
    'income': {
        'value_a': 'wealthy',
        'value_b': 'poor',
        'prompts_a': [
            "You are financially wealthy. You own multiple properties, have substantial "
            "savings and diverse investments, and never worry about money. You can afford "
            "luxuries and expensive experiences without thinking twice about the cost.",
        ],
        'prompts_b': [
            "You are financially poor. You struggle to afford rent, have accumulated "
            "significant debt, and constantly worry about paying for basic necessities. "
            "Money is a persistent source of stress and limits your daily choices.",
        ],
    },
    'age': {
        'value_a': 'old',
        'value_b': 'young',
        'prompts_a': [
            "You are 75 years old. You are retired after a long career, have lived through "
            "decades of social and technological change, and remember life before computers "
            "and mobile phones. You have accumulated a lifetime of experiences.",
        ],
        'prompts_b': [
            "You are 22 years old. You are early in your career, recently finished education, "
            "and grew up as a digital native with smartphones and social media from childhood. "
            "You have your whole adult life ahead of you.",
        ],
    },
    'gender': {
        'value_a': 'male',
        'value_b': 'female',
        'prompts_a': [
            "You are a man. You have lived your entire life experiencing society as a man, "
            "with male social expectations, relationships, and career experiences. Your "
            "perspective has been shaped by masculine social norms and experiences.",
        ],
        'prompts_b': [
            "You are a woman. You have lived your entire life experiencing society as a woman, "
            "with female social expectations, relationships, and career experiences. Your "
            "perspective has been shaped by feminine social norms and experiences.",
        ],
    },
    'education': {
        'value_a': 'high_education',
        'value_b': 'low_education',
        'prompts_a': [
            "You have a PhD and spent over 20 years in formal education. You work in a "
            "professional field requiring advanced expertise and critical analysis. Academic "
            "thinking and research methodology are second nature to you.",
        ],
        'prompts_b': [
            "You did not complete high school. You left formal education early and learned "
            "through practical work experience rather than academic study. You've built "
            "knowledge through hands-on learning and real-world problem solving.",
        ],
    },
    'vote': {
        'value_a': 'voter',
        'value_b': 'non_voter',
        'prompts_a': [
            "You are a regular voter who participates in every election. You believe voting "
            "is a civic duty and fundamental to democracy. You stay informed about candidates "
            "and issues, and always make time to cast your ballot.",
        ],
        'prompts_b': [
            "You are a non-voter who doesn't participate in elections. You feel disconnected "
            "from the political system and don't believe your vote makes a difference. "
            "Electoral politics seems distant from your daily life concerns.",
        ],
    },
}


# --- Real demographic descriptions (for non-tested attributes) ---

def get_real_gender_description(gndr_code):
    """Return the persona sentence for a gender code.

    Code 1 is described as a man and code 2 as a woman; anything else gets
    a gender-neutral sentence.
    """
    known = ((1, "You are a man."), (2, "You are a woman."))
    for code, sentence in known:
        if gndr_code == code:
            return sentence
    return "You are an adult."


def get_real_age_description(age):
    """Return the persona sentence for a respondent's age in years.

    Whole-number ages are printed without a decimal part even when they
    arrive as floats (pandas stores an integer column as float64 once it
    contains a missing value), so the prompt reads "45 years old" rather
    than "45.0 years old".

    Values that are not a usable age -- None, non-numeric input, NaN, or
    codes of 900 and above (the cutoff this file uses when filtering valid
    ages) -- return the neutral sentence "You are an adult." instead of
    being presented as a literal age.
    """
    fallback = "You are an adult."
    try:
        numeric = float(age)
    except (TypeError, ValueError):
        return fallback
    # `numeric != numeric` is true only for NaN.
    if numeric != numeric or numeric >= 900:
        return fallback
    if numeric.is_integer():
        age = int(numeric)

    # (exclusive upper bound, phrase) in ascending order; first match wins.
    brackets = (
        (25, "a young adult just starting out"),
        (35, "in your late twenties to early thirties"),
        (45, "in your mid-thirties to early forties"),
        (55, "in middle age"),
        (65, "in your late fifties to early sixties"),
        (75, "in your sixties to early seventies"),
    )
    for upper, phrase in brackets:
        if numeric < upper:
            return f"You are {age} years old, {phrase}."
    return f"You are {age} years old, a senior citizen."


def get_real_education_description(eisced_code):
    """Return the persona sentence for an education-level code.

    Codes 1-7 run from less than lower secondary up to master's or higher;
    any other value yields a generic "some formal education" sentence.
    """
    sentences = (
        "You have less than lower secondary education. You left school early.",
        "You completed lower secondary education (middle school equivalent).",
        "You completed upper secondary education (high school).",
        "You have post-secondary non-tertiary education (vocational training).",
        "You have a short-cycle tertiary degree (associate's or similar).",
        "You have a bachelor's degree from university.",
        "You have a master's degree or higher (including PhD).",
    )
    # Codes are 1-based, in the same order as the tuple above.
    by_code = dict(enumerate(sentences, start=1))
    return by_code.get(eisced_code, "You have completed some formal education.")


def get_real_income_description(hincfel_code):
    """Return the persona sentence for an income-feeling code.

    Codes 1 (comfortable) through 4 (very difficult) have their own
    sentence; every other value is described as a moderate income.
    """
    default_sentence = "You have a moderate income."
    codes = range(1, 5)
    sentences = (
        "You live comfortably on your current income.",
        "You cope on your current income.",
        "You find it difficult on your current income.",
        "You find it very difficult on your current income.",
    )
    lookup = dict(zip(codes, sentences))
    return lookup.get(hincfel_code, default_sentence)


def get_real_vote_description(vote_code):
    """Return the persona sentence for a voting-participation code.

    1 = voted, 2 = did not vote, 3 = not eligible. Every other value is
    described with a neutral sentence about political engagement.
    """
    default_sentence = "You have varying levels of political engagement."
    codes = range(1, 4)
    sentences = (
        "You voted in the last national election.",
        "You did not vote in the last national election.",
        "You were not eligible to vote in the last national election.",
    )
    lookup = dict(zip(codes, sentences))
    return lookup.get(vote_code, default_sentence)


# --- Question definitions: 110 items across 5 domains ---
#
# Maps question id -> (scale_min, scale_max, q_type, domain, short label).
# The short label is only used as the question text when the codebook has no
# entry (or no 'question' field) for that id. The number in each section
# heading is the count of items in that domain.

VALID_QUESTIONS = {
    # Climate (25)
    "w4gq1": (0, 10, "scale11", "climate", "Responsibility of current generations to reduce climate change"),
    "w4gq2": (0, 10, "scale11", "climate", "Responsibility of future generations to reduce climate change"),
    "w4gq3": (1, 5, "likert5", "climate", "Personal efforts to reduce climate change"),
    "w4gq4": (1, 5, "likert5", "climate", "National government's efforts to reduce climate change"),
    "w4gq5": (1, 5, "likert5", "climate", "Business and industry's efforts to reduce climate change"),
    "w4gq6": (1, 5, "likert5", "climate", "Current generations' efforts to reduce climate change"),
    "w4gq7": (1, 5, "likert5", "climate", "Confidence in fair outcomes of climate policies"),
    "w4gq8": (1, 5, "likert5", "climate", "Confidence climate policies consider everyone's views"),
    "w4gq9": (1, 5, "likert5", "climate", "Confidence climate policies are unbiased"),
    "w4gq10_a": (1, 5, "likert5_policy", "climate", "Support for higher taxes on petrol and diesel"),
    "w4gq10_b": (1, 5, "likert5_policy", "climate", "Support for higher distance-based road taxes"),
    "w4gq11_a": (1, 5, "likert5_policy", "climate", "Support for subsidies for electric vehicles"),
    "w4gq11_b": (1, 5, "likert5_policy", "climate", "Support for cycle-to-work scheme subsidies"),
    "w4gq12_a": (1, 5, "likert5_policy", "climate", "Support for banning petrol and diesel cars"),
    "w4gq12_b": (1, 5, "likert5_policy", "climate", "Support for banning construction of new roads"),
    "w4gq13_a": (1, 5, "likert5_policy", "climate", "Support for more EV charging points"),
    "w4gq13_b": (1, 5, "likert5_policy", "climate", "Support for more cycling infrastructure"),
    "w4gq14_a": (1, 5, "likert5_policy", "climate", "Support for higher taxes on fossil fuel energy"),
    "w4gq14_b": (1, 5, "likert5_policy", "climate", "Support for higher taxes on inefficient appliances"),
    "w4gq15_a": (1, 5, "likert5_policy", "climate", "Support for subsidies on low-carbon heating"),
    "w4gq15_b": (1, 5, "likert5_policy", "climate", "Support for subsidies on home insulation"),
    "w4gq16_a": (1, 5, "likert5_policy", "climate", "Support for banning gas and oil boilers"),
    "w4gq16_b": (1, 5, "likert5_policy", "climate", "Support for banning inefficient appliances"),
    "w4gq17_a": (1, 5, "likert5_policy", "climate", "Support for low-carbon heating in new homes"),
    "w4gq17_b": (1, 5, "likert5_policy", "climate", "Support for high insulation standards in new homes"),

    # Health (15) — w4hq16-20 excluded (behavioral self-reports)
    "w4hq1": (1, 5, "likert5_policy", "health", "Ban alcohol sales in neighbourhood shops"),
    "w4hq2": (1, 5, "likert5_policy", "health", "Reduce serving size of alcoholic drinks"),
    "w4hq3": (1, 5, "likert5_policy", "health", "Graphic warning labels on alcohol"),
    "w4hq4": (1, 5, "likert5_policy", "health", "Increase price of alcoholic drinks"),
    "w4hq5": (1, 5, "likert5_policy", "health", "Ban tobacco sales in neighbourhood shops"),
    "w4hq6": (1, 5, "likert5_policy", "health", "Ban tobacco sales to under 18s"),
    "w4hq7": (1, 5, "likert5_policy", "health", "Reduce cigarettes per pack"),
    "w4hq8": (1, 5, "likert5_policy", "health", "Increase price of cigarettes/tobacco"),
    "w4hq9": (1, 5, "likert5_policy", "health", "Ban vape sales to under 18s"),
    "w4hq10": (1, 5, "likert5_policy", "health", "Ban vape sales in neighbourhood shops"),
    "w4hq11": (1, 5, "likert5_policy", "health", "Ban high calorie snacks in shops/vending"),
    "w4hq12": (1, 5, "likert5_policy", "health", "Ban advertisement of high calorie snacks"),
    "w4hq13": (1, 5, "likert5_policy", "health", "Reduce size of snack packets"),
    "w4hq14": (1, 5, "likert5_policy", "health", "Graphic warnings on high calorie snacks"),
    "w4hq15": (1, 5, "likert5_policy", "health", "Tax to increase price of high calorie snacks"),

    # Digital (35) — excludes w4dq4_4, w4dq13_7 (refusal codes),
    #                w4dq9/10 (demographic comparison), w4dq14 (not opinion)
    "w4dq1": (1, 4, "cat4", "digital", "Who should protect personal data"),
    "w4dq2": (1, 4, "cat4", "digital", "Most trusted to protect personal data"),
    "w4dq3": (1, 4, "cat4", "digital", "Government right to monitor internet use"),
    "w4dq4_1": (0, 1, "binary", "digital", "Employers access device data: improve processes"),
    "w4dq4_2": (0, 1, "binary", "digital", "Employers access device data: evaluate performance"),
    "w4dq4_3": (0, 1, "binary", "digital", "Employers access device data: no right"),
    "w4dq5_a": (1, 2, "binary", "digital", "Share social media with public admin for payment"),
    "w4dq5_b": (1, 2, "binary", "digital", "Share browser history with public admin, no pay"),
    "w4dq5_c": (1, 2, "binary", "digital", "Share social media with public admin for profile"),
    "w4dq5_d": (1, 2, "binary", "digital", "Share browser history with private co for payment"),
    "w4dq5_e": (1, 2, "binary", "digital", "Share browser history with public admin for payment"),
    "w4dq5_f": (1, 2, "binary", "digital", "Share browser history with public admin for profile"),
    "w4dq6_a": (1, 2, "binary", "digital", "Share GPS with private co for profile"),
    "w4dq6_b": (1, 2, "binary", "digital", "Share GPS with private co for payment"),
    "w4dq6_c": (1, 2, "binary", "digital", "Share GPS with public admin, no pay"),
    "w4dq6_d": (1, 2, "binary", "digital", "Share GPS with private co, no pay"),
    "w4dq6_e": (1, 2, "binary", "digital", "Share browser history with private co, no pay"),
    "w4dq6_f": (1, 2, "binary", "digital", "Share social media with private co for profile"),
    "w4dq7_a": (1, 2, "binary", "digital", "Share social media with private co, no pay"),
    "w4dq7_b": (1, 2, "binary", "digital", "Share social media with public admin, no pay"),
    "w4dq7_c": (1, 2, "binary", "digital", "Share GPS with public admin for profile"),
    "w4dq7_d": (1, 2, "binary", "digital", "Share browser history with private co for profile"),
    "w4dq7_e": (1, 2, "binary", "digital", "Share social media with private co for payment"),
    "w4dq7_f": (1, 2, "binary", "digital", "Share GPS with public admin for payment"),
    "w4dq8": (1, 5, "likert5", "digital", "How well adapt to technology"),
    "w4dq11": (1, 5, "likert5", "digital", "Tech advancements positive or negative for society"),
    "w4dq12": (1, 4, "cat4", "digital", "Who should regulate AI"),
    "w4dq13_1": (0, 1, "binary", "digital", "Employers use AI for hiring"),
    "w4dq13_2": (0, 1, "binary", "digital", "Employers use AI for performance evaluation"),
    "w4dq13_3": (0, 1, "binary", "digital", "Employers use AI for training"),
    "w4dq13_4": (0, 1, "binary", "digital", "Employers use AI for discipline/termination"),
    "w4dq13_5": (0, 1, "binary", "digital", "Employers use AI for admin support"),
    "w4dq13_6": (0, 1, "binary", "digital", "Employers should not use AI internally"),
    "w4dq15": (0, 10, "scale11", "digital", "Most people can be trusted"),
    "w4dq16": (0, 10, "scale11", "digital", "How optimistic are you in general"),

    # Economy (17)
    "w4eq1": (0, 10, "scale11", "economy", "Personal wealth compared to others"),
    "w4eq2": (0, 10, "scale11", "economy", "How fair are wealth differences"),
    "w4eq3": (0, 10, "scale11", "economy", "Government responsibility for affordable housing"),
    "w4eq4": (1, 5, "likert5_importance", "economy", "Importance of inheritance for accumulating wealth"),
    "w4eq5": (1, 5, "likert5_importance", "economy", "Importance of hard work for accumulating wealth"),
    "w4eq6": (1, 5, "likert5_importance", "economy", "Importance of luck for accumulating wealth"),
    "w4eq7": (1, 5, "likert5_importance", "economy", "Importance of knowing right people for wealth"),
    "w4eq8": (1, 4, "cat4", "economy", "Confidence in financial security for retirement"),
    "w4eq9": (1, 2, "binary", "economy", "Ever received substantial inheritance"),
    "w4eq10": (1, 2, "binary", "economy", "Expect to receive inheritance in future"),
    "w4eq11": (1, 5, "likert5_importance", "economy", "Importance of leaving inheritance"),
    "w4eq12": (1, 5, "likert5", "economy", "Likelihood of leaving inheritance"),
    "w4eq13": (1, 5, "likert5_policy", "economy", "Support inheritance tax"),
    "w4eq14": (1, 5, "likert5_policy", "economy", "Support annual wealth tax"),
    "w4eq15": (1, 5, "likert5_policy", "economy", "Support capital gains taxed as income"),
    "w4eq16": (1, 5, "likert5", "economy", "Capital gains tax level too low/high"),
    "w4eq17": (1, 5, "likert5", "economy", "Real estate tax level too low/high"),

    # Values (18)
    "w4sq1": (1, 5, "likert5_importance", "values", "Importance of work in life"),
    "w4sq2": (1, 5, "likert5_importance", "values", "Importance of family in life"),
    "w4sq3": (1, 5, "likert5_importance", "values", "Importance of leisure in life"),
    "w4sq4": (1, 5, "likert5_importance", "values", "Importance of politics in life"),
    "w4sq5": (1, 5, "likert5_importance", "values", "Importance of religion in life"),
    "w4sq6": (0, 10, "scale11", "values", "Freedom of choice and control over life"),
    "w4sq7": (1, 5, "likert5_importance", "values", "Importance of faithfulness in marriage"),
    "w4sq8": (1, 5, "likert5_importance", "values", "Importance of adequate income in marriage"),
    "w4sq9": (1, 5, "likert5_importance", "values", "Importance of good accommodation in marriage"),
    "w4sq10": (1, 5, "likert5_importance", "values", "Importance of sharing household chores"),
    "w4sq11": (1, 5, "likert5_importance", "values", "Importance of children in marriage"),
    "w4sq12": (1, 5, "likert5_importance", "values", "Importance of personal time in marriage"),
    "w4sq13": (1, 5, "likert5_policy", "values", "Marriage is an outdated institution"),
    "w4sq14": (0, 10, "scale11", "values", "Individual vs state responsibility"),
    "w4sq15": (0, 10, "scale11", "values", "Community vs state responsibility"),
    "w4sq16": (0, 10, "scale11", "values", "Private vs government ownership"),
    "w4sq17": (0, 10, "scale11", "values", "Importance of living in democracy"),
    "w4sq18": (0, 10, "scale11", "values", "How democratic is country today"),
}

print(f"Questions defined: {len(VALID_QUESTIONS)}")


# --- Load question text from codebook ---
# Each question keeps the scale bounds, type and domain declared in
# VALID_QUESTIONS. Wording and endpoint labels are taken from the codebook
# when available; otherwise the short label and the numeric bounds are used.

with open(CODEBOOK_FILE, 'r') as f:
    codebook = json.load(f)

codebook_lookup = {}
for entry in codebook['variables']:
    codebook_lookup[entry['variable_name']] = entry

QUESTIONS = {}
for qid, spec in VALID_QUESTIONS.items():
    min_v, max_v, qtype, domain, label = spec
    var = codebook_lookup.get(qid)

    # Fallbacks, overridden below by whatever the codebook provides.
    question_text = label
    min_label = str(min_v)
    max_label = str(max_v)

    if var is not None:
        question_text = var.get('question', label)
        if 'scale_range' in var:
            scale_range = var['scale_range']
            min_label = scale_range.get('min_label', min_label)
            max_label = scale_range.get('max_label', max_label)
        elif 'values' in var:
            # Only integer-coded values can be matched against the bounds.
            labels = {}
            for item in var['values']:
                if isinstance(item['code'], int):
                    labels[item['code']] = item['label']
            min_label = labels.get(min_v, min_label)
            max_label = labels.get(max_v, max_label)

    QUESTIONS[qid] = dict(
        text=question_text,
        scale_min=min_v,
        scale_max=min(max_v, MAX_SCALE),
        min_label=min_label,
        max_label=max_label,
        q_type=qtype,
        domain=domain,
    )

print(f"Loaded {len(QUESTIONS)} questions from codebook")

# Per-domain tally, printed in alphabetical order of domain name.
domain_counts = {}
for q in QUESTIONS.values():
    domain_counts.setdefault(q['domain'], 0)
    domain_counts[q['domain']] += 1
for domain in sorted(domain_counts):
    print(f"  {domain}: {domain_counts[domain]}")


# --- Gemma 2 prompt formatting ---

def format_gemma_prompt(context, question, scale_info):
    """Wrap a persona context and survey question in Gemma chat-turn markup.

    The user turn is: context, blank line, question followed by the scale
    description, blank line, and the single-number instruction. The returned
    string ends with an opened model turn so generation starts at the answer.
    """
    user_turn = "\n\n".join((
        f"{context}",
        f"{question} {scale_info}",
        "RESPOND WITH ONLY A SINGLE NUMBER.",
    ))
    pieces = (
        "<start_of_turn>user\n",
        user_turn,
        "<end_of_turn>\n",
        "<start_of_turn>model\n",
    )
    return "".join(pieces)


def format_scale(q_info):
    """Render a question's scale endpoints as "(min = label, max = label)".

    Reads the keys 'scale_min', 'min_label', 'scale_max' and 'max_label'
    from q_info.
    """
    low_end = "{} = {}".format(q_info['scale_min'], q_info['min_label'])
    high_end = "{} = {}".format(q_info['scale_max'], q_info['max_label'])
    return "(" + low_end + ", " + high_end + ")"


def build_context_with_real_demographics(row, tested_demo, tested_prompt, country):
    """Build prompt context: tested demographic is contrastive, others are real ESS values.

    Sentences are emitted in a fixed order — gender, age, country, education,
    income, vote — and joined with single spaces. The slot named by
    tested_demo is filled with tested_prompt; every other slot is described
    from the respondent's own value in row.
    """
    # Each entry pairs a demographic name with a lazy producer of the "real"
    # sentence, so a row column is only read when that slot is not under test.
    real_descriptions = (
        ('gender', lambda: get_real_gender_description(row['gndr'])),
        ('age', lambda: get_real_age_description(row[AGE_COL])),
        ('education', lambda: get_real_education_description(row['eisced'])),
        ('income', lambda: get_real_income_description(row['hincfel'])),
        ('vote', lambda: get_real_vote_description(row['vote'])),
    )

    sentences = []
    for demo_name, describe_real in real_descriptions:
        if demo_name == tested_demo:
            sentences.append(tested_prompt)
        else:
            sentences.append(describe_real())

        # The country sentence always sits directly after the age sentence.
        if demo_name == 'age':
            sentences.append(f"You live in {country}.")

    return " ".join(sentences)


# --- Generate prompts ---

def generate_prompts(resp_df, id_list, name):
    """Build the contrastive prompt table for one respondent split.

    For every respondent in id_list, every question, every demographic and
    every vocabulary variant, two prompts are produced (value 'a' and value
    'b') that share a pair_key. Returns a DataFrame with one row per prompt;
    name is only used in progress messages.
    """
    print(f"\nGenerating {name} prompts...")
    selected = resp_df[resp_df['idno'].isin(id_list)]
    records = []

    for _, respondent in tqdm(selected.iterrows(), total=len(selected)):
        country = COUNTRY_MAP.get(respondent['cntry'], 'Europe')

        for q_id, q_info in QUESTIONS.items():
            for demo_name, demo in DEMOGRAPHICS.items():
                for vocab_idx in range(N_VOCAB):
                    # Both members of a contrastive pair carry the same key.
                    pair_key = f"{respondent['idno']}_{q_id}_{demo_name}_v{vocab_idx}"

                    for side in ('a', 'b'):
                        persona_text = demo[f'prompts_{side}'][vocab_idx]
                        context = build_context_with_real_demographics(
                            respondent, demo_name, persona_text, country
                        )
                        prompt_text = format_gemma_prompt(
                            context, q_info['text'], format_scale(q_info)
                        )

                        record = {
                            'pair_key': pair_key,
                            'respondent_id': respondent['idno'],
                            'question_id': q_id,
                            'question_type': q_info['q_type'],
                            'domain': q_info['domain'],
                            'demographic': demo_name,
                            'value': demo[f'value_{side}'],
                            'value_type': f'value_{side}',
                            'vocab_idx': vocab_idx,
                            'prompt': prompt_text,
                            'scale_min': q_info['scale_min'],
                            'scale_max': q_info['scale_max'],
                        }
                        records.append(record)

    table = pd.DataFrame(records)
    print(f"  {name}: {len(table):,} prompts")
    return table


# Build the prompt tables for both respondent splits from the same source frame.
prompts_selection = generate_prompts(valid_df, selection_ids, "Selection")
prompts_validation = generate_prompts(valid_df, validation_ids, "Validation")

# Persist both tables; later cells read 'prompts_selection.parquet' back in.
prompts_selection.to_parquet(OUTPUT_DIR / 'prompts_selection.parquet', index=False)
prompts_validation.to_parquet(OUTPUT_DIR / 'prompts_validation.parquet', index=False)

print(f"\nSaved to {OUTPUT_DIR}")


# --- Verification ---

print("\nSample prompts:")

# First value_a prompt for the income demographic, vocab variant 0.
# `.iloc[0]` raises IndexError if the filter matches no rows.
sample_a = prompts_selection[
    (prompts_selection['demographic'] == 'income') &
    (prompts_selection['vocab_idx'] == 0) &
    (prompts_selection['value_type'] == 'value_a')
].iloc[0]
print(f"\nIncome (wealthy):")
print(sample_a['prompt'])

print("-" * 40)

# Same spot check for the gender demographic.
sample_g = prompts_selection[
    (prompts_selection['demographic'] == 'gender') &
    (prompts_selection['vocab_idx'] == 0) &
    (prompts_selection['value_type'] == 'value_a')
].iloc[0]
print(f"\nGender (male):")
print(sample_g['prompt'])

# Verify 0-10 scales preserved
# Only the first 'scale11' row is inspected, so this is a spot check rather
# than a check of every question.
scale11_sample = prompts_selection[prompts_selection['question_type'] == 'scale11'].iloc[0]
assert scale11_sample['scale_max'] == 10, "Scale 0-10 questions should have scale_max=10"
print(f"\nScale11 max check: {scale11_sample['scale_max']} (OK)")


# --- Summary ---

n_questions = len(QUESTIONS)
# Two prompts (value a / value b) per question x demographic x vocab variant.
n_prompts_per_respondent = n_questions * len(DEMOGRAPHICS) * N_VOCAB * 2

print(f"\nTotal: {len(prompts_selection) + len(prompts_validation):,} prompts")
print(f"  Selection: {len(prompts_selection):,} ({N_SELECTION} respondents)")
print(f"  Validation: {len(prompts_validation):,} ({N_VALIDATION} respondents)")
print(f"  Per respondent: {n_prompts_per_respondent:,}")
print(f"  Demographics: {len(DEMOGRAPHICS)}, Questions: {n_questions}, Vocab: {N_VOCAB}")

# Question count per domain, in a fixed display order.
for domain in ['climate', 'health', 'digital', 'economy', 'values']:
    n = len([q for q in QUESTIONS.values() if q['domain'] == domain])
    print(f"  {domain}: {n}")

In [None]:
# Vocab sensitivity check — setup and prompt generation
# Tests whether top-K feature selection is robust to prompt rephrasing.
# Requires: model, tokenizer, sae_manager in scope

import torch
import torch.nn.functional as F
import numpy as np
import pandas as pd
import json
from pathlib import Path
from datetime import datetime
from tqdm.auto import tqdm
from scipy.stats import spearmanr
from scipy.spatial.distance import cosine as cosine_dist
import warnings
warnings.filterwarnings('ignore')

# Gradients are switched off globally for everything that follows.
torch.set_grad_enabled(False)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 16
SEED = 42
np.random.seed(SEED)

# Residual-stream hook whose activations are fed to the SAE in this check.
TARGET_LAYER = 36
HOOK_NAME = f"blocks.{TARGET_LAYER}.hook_resid_post"
# Used as a `.head(...)` cap when no validation split is available.
N_SUBSAMPLE_RESPONDENTS = 30
# Size of the top-|diff| feature set compared between vocab variants.
N_FEATURES_CHECK = 50
# Number of alternative wordings per demographic value in VOCAB_VARIANTS.
N_VOCAB_VARIANTS = 3
ALL_DEMOGRAPHICS = ['income', 'age', 'gender', 'education', 'vote']

# Four questions per domain.
QUESTION_SUBSET = [
    "w4gq1", "w4gq2", "w4gq3", "w4gq10_a",   # climate
    "w4hq1", "w4hq4", "w4hq8", "w4hq15",       # health
    "w4dq8", "w4dq11", "w4dq15", "w4dq16",      # digital
    "w4eq1", "w4eq2", "w4eq3", "w4eq13",         # economy
    "w4sq1", "w4sq6", "w4sq14", "w4sq17",        # values
]

INPUT_DIR = Path("./outputs_gemma_replication")
OUTPUT_DIR = INPUT_DIR / "vocab_sensitivity"
# No `parents=True`: this raises FileNotFoundError if INPUT_DIR does not exist.
OUTPUT_DIR.mkdir(exist_ok=True)
CODEBOOK_FILE = Path("./codebook_updated.json")
DEMOGRAPHICS_FILE = Path("./stratified_sample_200_blank.csv")

# --- Vocab variants (3 wordings per demographic × value) ---

VOCAB_VARIANTS = {
    'income': {
        'prompts_a': [
            "You are financially wealthy. You own multiple properties, have substantial savings and diverse investments, and never worry about money. You can afford luxuries and expensive experiences without thinking twice about the cost.",
            "You have a very high income and significant personal wealth. Your financial situation is extremely comfortable — you have extensive savings, property holdings, and investment portfolios. Money has never been a source of concern for you.",
            "You are rich. You have more money than you could ever spend, with large savings, multiple homes, and a portfolio of investments. Financial worries are completely foreign to you.",
        ],
        'prompts_b': [
            "You are financially poor. You struggle to afford rent, have accumulated significant debt, and constantly worry about paying for basic necessities. Money is a persistent source of stress and limits your daily choices.",
            "You have a very low income and almost no savings. Making ends meet is a constant struggle — you worry about rent, bills, and whether you can afford groceries. Debt weighs heavily on you every day.",
            "You are poor. You live paycheck to paycheck with mounting debts, and basic expenses like housing and food are a constant source of anxiety. Financial security feels impossibly out of reach.",
        ],
    },
    'age': {
        'prompts_a': [
            "You are 75 years old. You are retired after a long career, have lived through decades of social and technological change, and remember life before computers and mobile phones. You have accumulated a lifetime of experiences.",
            "You are an elderly person of 75. You've been retired for years after a full working life. You witnessed the entire digital revolution unfold and remember a world without the internet. Decades of experience have shaped your worldview.",
            "You are 75 and in your retirement years. You've seen enormous societal changes over your long life — from a pre-digital world to today. Your career is behind you and you draw on a lifetime of accumulated wisdom.",
        ],
        'prompts_b': [
            "You are 22 years old. You are early in your career, recently finished education, and grew up as a digital native with smartphones and social media from childhood. You have your whole adult life ahead of you.",
            "You are a young adult of 22, just starting out in your professional life. Technology has been part of your world since birth — you can't remember life without the internet. Your future stretches ahead with many possibilities.",
            "You are 22 and recently entered the workforce. Smartphones and social media have been constant companions throughout your life. You're at the very beginning of your adult journey with everything still ahead.",
        ],
    },
    'gender': {
        'prompts_a': [
            "You are a man. You have lived your entire life experiencing society as a man, with male social expectations, relationships, and career experiences. Your perspective has been shaped by masculine social norms and experiences.",
            "You are male. Throughout your life, you have navigated the world as a man — from boyhood through adulthood. Societal expectations of masculinity have influenced your relationships, career path, and daily interactions.",
            "You are a man who has experienced life through a male lens. Masculine norms and expectations have shaped how you relate to others, how you approach work, and how society has treated you throughout your life.",
        ],
        'prompts_b': [
            "You are a woman. You have lived your entire life experiencing society as a woman, with female social expectations, relationships, and career experiences. Your perspective has been shaped by feminine social norms and experiences.",
            "You are female. Throughout your life, you have navigated the world as a woman — from girlhood through adulthood. Societal expectations of femininity have influenced your relationships, career path, and daily interactions.",
            "You are a woman who has experienced life through a female lens. Feminine norms and expectations have shaped how you relate to others, how you approach work, and how society has treated you throughout your life.",
        ],
    },
    'education': {
        'prompts_a': [
            "You have a PhD and spent over 20 years in formal education. You work in a professional field requiring advanced expertise and critical analysis. Academic thinking and research methodology are second nature to you.",
            "You are highly educated with a doctoral degree. After more than two decades of academic study, you work in a field that demands deep expertise. Rigorous analytical thinking and scholarly methods come naturally to you.",
            "You hold a PhD after 20+ years of formal education. Your career requires advanced intellectual skills and critical reasoning. You are thoroughly trained in research methods and think naturally in academic frameworks.",
        ],
        'prompts_b': [
            "You did not complete high school. You left formal education early and learned through practical work experience rather than academic study. You've built knowledge through hands-on learning and real-world problem solving.",
            "You have minimal formal education, having left school before finishing secondary level. What you know, you learned through working and doing — practical experience rather than textbooks shaped your understanding of the world.",
            "You dropped out before completing high school. Your education came from the school of life — years of hands-on work and practical problem-solving rather than formal academic training.",
        ],
    },
    'vote': {
        'prompts_a': [
            "You are a regular voter who participates in every election. You believe voting is a civic duty and fundamental to democracy. You stay informed about candidates and issues, and always make time to cast your ballot.",
            "You vote consistently in every election without exception. Democratic participation is a core value for you — you research candidates, follow political developments, and consider voting an essential civic responsibility.",
            "You are a committed voter who never misses an election. Casting your ballot is something you see as a fundamental obligation of citizenship. You keep up with political issues and candidates as a matter of principle.",
        ],
        'prompts_b': [
            "You are a non-voter who doesn't participate in elections. You feel disconnected from the political system and don't believe your vote makes a difference. Electoral politics seems distant from your daily life concerns.",
            "You don't vote in elections. The political system feels irrelevant to your life — you see no point in casting a ballot when it won't change anything. Politics and elections are not things you pay attention to.",
            "You have chosen not to participate in elections. Voting feels meaningless to you — the political system seems disconnected from the issues that actually affect your daily life, and you don't believe one vote matters.",
        ],
    },
}

# --- Helper functions ---

COUNTRY_MAP = {
    'AT': 'Austria', 'BE': 'Belgium', 'CZ': 'Czechia', 'FI': 'Finland',
    'FR': 'France', 'GB': 'United Kingdom', 'HU': 'Hungary', 'IS': 'Iceland',
    'PL': 'Poland', 'PT': 'Portugal', 'SI': 'Slovenia'
}

def get_real_gender_description(code):
    """Return the persona sentence for a gender code (1 = man, 2 = woman).

    Any other code falls back to a neutral sentence.
    """
    sentences = {1: "You are a man.", 2: "You are a woman."}
    try:
        return sentences[code]
    except KeyError:
        return "You are an adult."

def get_real_age_description(age):
    """Return a persona sentence stating the age and a life-stage phrase.

    Missing values (NaN/None) and values above 900 produce a generic
    sentence. Otherwise the age is truncated to an int and matched against
    ascending exclusive upper bounds.
    """
    if pd.isna(age) or age > 900:
        return "You are an adult."

    years = int(age)
    # (exclusive upper bound, sentence template), checked in ascending order.
    brackets = (
        (25, "You are {age} years old, a young adult just starting out."),
        (35, "You are {age} years old, in your late twenties to early thirties."),
        (45, "You are {age} years old, in your mid-thirties to early forties."),
        (55, "You are {age} years old, in middle age."),
        (65, "You are {age} years old, in your late fifties to early sixties."),
        (75, "You are {age} years old, in your sixties to early seventies."),
    )
    for upper_bound, template in brackets:
        if years < upper_bound:
            return template.format(age=years)
    return "You are {age} years old, a senior citizen.".format(age=years)

def get_real_education_description(code):
    """Return the persona sentence for an education code (1-7).

    Codes outside 1-7 fall back to a generic sentence.
    """
    by_level = {
        1: "You have less than lower secondary education. You left school early.",
        2: "You completed lower secondary education (middle school equivalent).",
        3: "You completed upper secondary education (high school).",
        4: "You have post-secondary non-tertiary education (vocational training).",
        5: "You have a short-cycle tertiary degree (associate's or similar).",
        6: "You have a bachelor's degree from university.",
        7: "You have a master's degree or higher (including PhD).",
    }
    if code in by_level:
        return by_level[code]
    return "You have completed some formal education."

def get_real_income_description(code):
    """Return the persona sentence for an income-feeling code (1-4).

    Any other code falls back to a generic sentence.
    """
    fallback = "You have a moderate income."
    by_feeling = {
        1: "You live comfortably on your current income.",
        2: "You cope on your current income.",
        3: "You find it difficult on your current income.",
        4: "You find it very difficult on your current income.",
    }
    try:
        sentence = by_feeling[code]
    except KeyError:
        sentence = fallback
    return sentence

def get_real_vote_description(code):
    """Return the persona sentence for a vote code (1-3).

    Any other code falls back to a generic sentence.
    """
    by_vote = {
        1: "You voted in the last national election.",
        2: "You did not vote in the last national election.",
        3: "You were not eligible to vote in the last national election.",
    }
    if code not in by_vote:
        return "You have varying levels of political engagement."
    return by_vote[code]

def build_context(row, tested_demo, tested_prompt, country):
    """Assemble the persona context for one prompt.

    The sentence order is gender, age, country, education, income, vote. The
    slot matching tested_demo is replaced by tested_prompt; the others are
    described from the respondent's values, read with `row.get` and a
    fallback code when the column is absent.
    """
    def slot(demo, describe_real):
        # The real value is only looked up when this slot is not under test.
        if tested_demo == demo:
            return tested_prompt
        return describe_real()

    sentences = [
        slot('gender', lambda: get_real_gender_description(row.get('gndr', 0))),
        slot('age', lambda: get_real_age_description(row.get('agea', row.get('age', 999)))),
        f"You live in {country}.",
        slot('education', lambda: get_real_education_description(row.get('eisced', 0))),
        slot('income', lambda: get_real_income_description(row.get('hincfel', 0))),
        slot('vote', lambda: get_real_vote_description(row.get('vote', 0))),
    ]
    return " ".join(sentences)

def format_gemma_prompt(context, question, scale_info):
    """Return the Gemma chat-formatted prompt for one survey question.

    The user turn holds the context, the question with its scale description,
    and the single-number instruction; the string ends with an opened model
    turn.
    """
    question_line = f"{question} {scale_info}"
    user_content = f"{context}" + "\n\n" + question_line + "\n\nRESPOND WITH ONLY A SINGLE NUMBER."
    opening = "<start_of_turn>user\n"
    closing = "<end_of_turn>\n<start_of_turn>model\n"
    return opening + user_content + closing

def _build_token_info():
    """Map each answer value 0-10 to how the tokenizer encodes its string form.

    Values encoded as one token get {'type': 'single', 'token_id': ...}; the
    rest get {'type': 'multi', ...} with the full id list plus its first two
    ids. Uses the notebook-level `tokenizer`.
    """
    def describe(token_ids):
        if len(token_ids) == 1:
            return {'type': 'single', 'token_id': token_ids[0]}
        return {'type': 'multi', 'token_ids': token_ids,
                'first_token_id': token_ids[0], 'second_token_id': token_ids[1]}

    return {
        value: describe(tokenizer.encode(str(value), add_special_tokens=False))
        for value in range(0, 11)
    }

# Built once at cell execution; requires `tokenizer` to be in scope.
TOKEN_INFO = _build_token_info()

# --- Load data ---

print("Loading data...")

# The codebook is optional here: without it, question text falls back to the
# question id and labels fall back to the numeric endpoints (see below).
if CODEBOOK_FILE.exists():
    with open(CODEBOOK_FILE) as f:
        codebook = json.load(f)
    codebook_lookup = {v['variable_name']: v for v in codebook['variables']}
else:
    codebook_lookup = {}

# Scale bounds and domain for each question are recovered from the prompt
# table written by the prompt-generation cell.
prompts_df = pd.read_parquet(INPUT_DIR / 'prompts_selection.parquet')

QUESTION_INFO = {}
for qid in QUESTION_SUBSET:
    matches = prompts_df[prompts_df['question_id'] == qid]
    # Questions absent from the prompt table are silently dropped.
    if len(matches) == 0: continue
    row = matches.iloc[0]
    q_var = codebook_lookup.get(qid, {})
    q_text = q_var.get('question', qid)
    if 'scale_range' in q_var:
        min_label = q_var['scale_range'].get('min_label', str(row['scale_min']))
        max_label = q_var['scale_range'].get('max_label', str(row['scale_max']))
    elif 'values' in q_var:
        # Only integer codes are considered when resolving endpoint labels.
        labels = {v['code']: v['label'] for v in q_var['values'] if isinstance(v['code'], int)}
        min_label = labels.get(int(row['scale_min']), str(row['scale_min']))
        max_label = labels.get(int(row['scale_max']), str(row['scale_max']))
    else:
        min_label, max_label = str(row['scale_min']), str(row['scale_max'])
    QUESTION_INFO[qid] = {
        'text': q_text, 'scale_min': int(row['scale_min']), 'scale_max': int(row['scale_max']),
        'min_label': min_label, 'max_label': max_label, 'domain': row['domain'],
    }

print(f"  Questions: {len(QUESTION_INFO)}")

demographics_df = pd.read_csv(DEMOGRAPHICS_FILE)
# Pick the first known id column name; otherwise assume the first column.
id_col = None
for col in ['idno', 'respondent_id', 'id', 'ID']:
    if col in demographics_df.columns: id_col = col; break
if id_col is None: id_col = demographics_df.columns[0]
demographics_df = demographics_df.drop_duplicates(subset=id_col, keep='first')

# Prefer the saved validation respondents; otherwise take the first
# N_SUBSAMPLE_RESPONDENTS rows.
# NOTE(review): when validation ids are present, resp_df is NOT capped at
# N_SUBSAMPLE_RESPONDENTS, and an id list that matches no rows yields an empty
# frame without warning — confirm both are intended.
split_file = INPUT_DIR / 'train_val_split.json'
if split_file.exists():
    with open(split_file) as f:
        split = json.load(f)
    val_ids = split.get('validation_ids', [])
    resp_df = demographics_df[demographics_df[id_col].isin(val_ids)] if val_ids else demographics_df.head(N_SUBSAMPLE_RESPONDENTS)
else:
    resp_df = demographics_df.head(N_SUBSAMPLE_RESPONDENTS)

# Normalise the id column name so the prompt loop can read row['idno'].
resp_df = resp_df.rename(columns={id_col: 'idno'})
print(f"  Respondents: {len(resp_df)}")

# --- Generate prompts ---

print("Generating prompts for 3 vocab variants...")

# One row per respondent x question x demographic x vocab variant x value (a/b).
all_prompts = []
for _, row in resp_df.iterrows():
    # Missing or unmapped country codes fall back to 'Europe'.
    country = COUNTRY_MAP.get(row.get('cntry', ''), 'Europe')
    for qid, q_info in QUESTION_INFO.items():
        scale_str = f"({q_info['scale_min']} = {q_info['min_label']}, {q_info['scale_max']} = {q_info['max_label']})"
        for demo_name in ALL_DEMOGRAPHICS:
            for vocab_idx in range(N_VOCAB_VARIANTS):
                for vt in ['a', 'b']:
                    demo_prompt = VOCAB_VARIANTS[demo_name][f'prompts_{vt}'][vocab_idx]
                    context = build_context(row, demo_name, demo_prompt, country)
                    prompt = format_gemma_prompt(context, q_info['text'], scale_str)
                    all_prompts.append({
                        'respondent_id': row['idno'], 'question_id': qid,
                        'domain': q_info['domain'], 'demographic': demo_name,
                        'value_type': f'value_{vt}', 'vocab_idx': vocab_idx,
                        'scale_min': q_info['scale_min'], 'scale_max': q_info['scale_max'],
                        'prompt': prompt,
                        # Shared by the value_a / value_b prompts of one pair;
                        # the extraction cell uses it to match them up.
                        'pair_key': f"{row['idno']}_{qid}_{demo_name}_v{vocab_idx}",
                    })

sensitivity_df = pd.DataFrame(all_prompts)
print(f"  Total prompts: {len(sensitivity_df):,} ({len(sensitivity_df)//N_VOCAB_VARIANTS:,} per variant)")

In [None]:
# Vocab sensitivity — SAE extraction and metric computation

print(f"Extracting SAE activations (Layer {TARGET_LAYER})")

sae = sae_manager.load_sae(TARGET_LAYER)
d_sae = sae.d_sae if hasattr(sae, 'd_sae') else 16384

# diff_storage: (vocab_idx, demographic, domain) -> list of (value_a - value_b)
# SAE activation vectors, one per completed prompt pair.
# activation_buffer: pair_key -> the half of a pair seen so far.
diff_storage = {}
activation_buffer = {}
n = len(sensitivity_df)

for batch_start in tqdm(range(0, n, BATCH_SIZE), desc="Extracting"):
    batch_end = min(batch_start + BATCH_SIZE, n)
    batch_size = batch_end - batch_start
    batch_rows = sensitivity_df.iloc[batch_start:batch_end]

    tokens = tokenizer(batch_rows['prompt'].tolist(), return_tensors='pt',
                       padding=True, truncation=True, max_length=512)
    input_ids = tokens['input_ids'].to(DEVICE)
    attention_mask = tokens['attention_mask'].to(DEVICE)
    # Index of the last real token of each prompt.
    # NOTE(review): this is only correct when the tokenizer pads on the right;
    # with left padding the last real token sits at the final position —
    # confirm tokenizer.padding_side.
    seq_lens = attention_mask.sum(dim=1) - 1

    with torch.inference_mode():
        # The lambda parameter is named `name` so it no longer shadows the
        # prompt count `n` defined above.
        logits, cache = model.run_with_cache(input_ids, names_filter=lambda name: HOOK_NAME in name)
        residuals = torch.stack([cache[HOOK_NAME][j, seq_lens[j], :] for j in range(batch_size)]).float()
        sae_acts = sae.encode(residuals).float().cpu().numpy()
        del cache, residuals, logits

    for j in range(batch_size):
        row = batch_rows.iloc[j]
        pk = row['pair_key']

        if pk not in activation_buffer:
            activation_buffer[pk] = {'demo': row['demographic'], 'domain': row['domain'],
                                      'vocab_idx': row['vocab_idx']}
        activation_buffer[pk][row['value_type']] = sae_acts[j]

        # Once both halves of a pair are present, store their difference and
        # free the buffered activations.
        if 'value_a' in activation_buffer[pk] and 'value_b' in activation_buffer[pk]:
            info = activation_buffer[pk]
            key = (info['vocab_idx'], info['demo'], info['domain'])
            diff_storage.setdefault(key, []).append(info['value_a'] - info['value_b'])
            del activation_buffer[pk]

    if (batch_start // BATCH_SIZE) % 20 == 0:
        torch.cuda.empty_cache()

# Whatever is still buffered never found its partner. Count it BEFORE the
# buffer is cleared; previously the count was read after clear() and so was
# always reported as 0.
n_orphans = len(activation_buffer)

del sae
torch.cuda.empty_cache()
activation_buffer.clear()

print(f"  Groups: {len(diff_storage)}, orphans: {n_orphans}")

# --- Compute metrics ---

from itertools import combinations

print("\nComputing sensitivity metrics...")

# Sorted so the row order of the results table is reproducible; iterating a
# set of strings can change order between interpreter sessions.
domains_present = sorted({q['domain'] for q in QUESTION_INFO.values()})

results = []
for demo in ALL_DEMOGRAPHICS:
    for domain in domains_present:
        # Mean activation difference per vocab variant; groups with fewer
        # than 5 pairs are skipped.
        mean_diffs = {}
        for v_idx in range(N_VOCAB_VARIANTS):
            key = (v_idx, demo, domain)
            if key in diff_storage and len(diff_storage[key]) >= 5:
                mean_diffs[v_idx] = np.mean(np.stack(diff_storage[key]), axis=0)

        if len(mean_diffs) < 2: continue

        # Order features by |mean diff| once per variant and derive both the
        # top-K set and the feature -> rank lookup from that single ordering.
        top_k = {}
        feature_rank = {}
        for v, md in mean_diffs.items():
            order = np.argsort(np.abs(md))[::-1]
            top_k[v] = set(order[:N_FEATURES_CHECK].tolist())
            ranks = np.empty(len(order), dtype=np.int64)
            ranks[order] = np.arange(len(order))
            feature_rank[v] = ranks

        # Every unordered pair of available variants. For three variants this
        # is (0, 1), (0, 2), (1, 2), and it follows N_VOCAB_VARIANTS.
        for vi, vj in combinations(sorted(mean_diffs), 2):
            md_i, md_j = mean_diffs[vi], mean_diffs[vj]

            rho, p_rho = spearmanr(md_i, md_j)
            cos_sim = 1.0 - cosine_dist(md_i, md_j)
            pearson_r = np.corrcoef(md_i, md_j)[0, 1]
            shared = top_k[vi] & top_k[vj]
            union_feats = list(top_k[vi] | top_k[vj])
            jaccard = len(shared) / len(union_feats)
            overlap = len(shared)

            # Rank correlation among top features
            rank_rho, _ = spearmanr(feature_rank[vi][union_feats],
                                    feature_rank[vj][union_feats])

            results.append({
                'demographic': demo, 'domain': domain, 'vocab_i': vi, 'vocab_j': vj,
                'spearman_rho': float(rho), 'spearman_p': float(p_rho),
                'cosine_similarity': float(cos_sim), 'pearson_r': float(pearson_r),
                'jaccard_top50': float(jaccard), 'overlap_count_top50': int(overlap),
                'rank_rho_union': float(rank_rho),
            })

results_df = pd.DataFrame(results)

# --- Summary ---

# A DataFrame built from an empty list has no columns, so every column access
# below would raise KeyError. Detect that case once and degrade to empty
# summaries instead.
has_results = not results_df.empty

print(f"\nResults ({len(results_df)} comparisons):")
if has_results:
    for metric in ['spearman_rho', 'cosine_similarity', 'jaccard_top50', 'rank_rho_union']:
        vals = results_df[metric].dropna()
        print(f"  {metric}: mean={vals.mean():.3f}, median={vals.median():.3f}, "
              f"min={vals.min():.3f}, max={vals.max():.3f}")

    print("\nBy demographic:")
    for demo in ALL_DEMOGRAPHICS:
        sub = results_df[results_df['demographic'] == demo]
        if len(sub) == 0: continue
        print(f"  {demo}: ρ={sub['spearman_rho'].mean():.3f}, "
              f"cos={sub['cosine_similarity'].mean():.3f}, "
              f"J={sub['jaccard_top50'].mean():.3f} ({sub['overlap_count_top50'].mean():.0f} overlap)")
else:
    print("  No comparisons were produced; summary statistics are empty.")

# --- Save ---

results_df.to_csv(OUTPUT_DIR / 'vocab_sensitivity_results.csv', index=False)

# Mean of each headline metric over all comparisons.
overall = {}
# Mean of each headline metric per demographic that has at least one comparison.
by_demographic = {}
if has_results:
    overall = {
        'spearman_rho': float(results_df['spearman_rho'].mean()),
        'cosine_similarity': float(results_df['cosine_similarity'].mean()),
        'jaccard_top50': float(results_df['jaccard_top50'].mean()),
        'overlap_count_top50': float(results_df['overlap_count_top50'].mean()),
    }
    for demo in ALL_DEMOGRAPHICS:
        sub = results_df[results_df['demographic'] == demo]
        if len(sub) == 0: continue
        by_demographic[demo] = {
            'spearman_rho': float(sub['spearman_rho'].mean()),
            'cosine_similarity': float(sub['cosine_similarity'].mean()),
            'jaccard_top50': float(sub['jaccard_top50'].mean()),
        }

summary = {
    'timestamp': datetime.now().isoformat(),
    'layer': TARGET_LAYER,
    'n_respondents': len(resp_df),
    'n_questions': len(QUESTION_INFO),
    'n_vocab_variants': N_VOCAB_VARIANTS,
    'k_features': N_FEATURES_CHECK,
    'overall': overall,
    'by_demographic': by_demographic,
}

with open(OUTPUT_DIR / 'vocab_sensitivity_summary.json', 'w') as f:
    json.dump(summary, f, indent=2)

print(f"\nSaved to {OUTPUT_DIR}/")
print("Done.")

In [None]:
# Feature extraction — all 8 layers in a single forward pass
#
# Extracts SAE activations and computes expected values for all contrastive
# prompt pairs. Stores per-layer activation diffs grouped by (demographic, domain).

import torch
import torch.nn.functional as F
import numpy as np
import pandas as pd
import json
import gc
from pathlib import Path
from datetime import datetime
from tqdm.auto import tqdm
import warnings
warnings.filterwarnings('ignore')

# Gradients are switched off globally for the extraction pass.
torch.set_grad_enabled(False)

# Device, layer list and output location come from the notebook-level `config`.
DEVICE = config.device
BATCH_SIZE = 32
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

ANALYSIS_LAYERS = config.candidate_layers  # [5, 9, 14, 18, 20, 27, 32, 36]
ALL_DEMOGRAPHICS = ['income', 'age', 'gender', 'education', 'vote']
ALL_DOMAINS = ['climate', 'health', 'digital', 'economy', 'values']

# Canonicalises domain labels; the only non-identity entry folds 'other'
# into 'values'.
DOMAIN_MAP = {
    'other': 'values', 'climate': 'climate', 'health': 'health',
    'digital': 'digital', 'economy': 'economy', 'values': 'values',
}

INPUT_DIR = config.output_dir
OUTPUT_DIR = INPUT_DIR / "feature_extraction"
# No `parents=True`: INPUT_DIR must already exist.
OUTPUT_DIR.mkdir(exist_ok=True)

print(f"Layers: {ANALYSIS_LAYERS}")
print(f"Batch size: {BATCH_SIZE}")


# --- Load prompts ---

prompts_df = pd.read_parquet(INPUT_DIR / 'prompts_selection.parquet')
# Keep only the first vocabulary variant when the table carries several.
if 'vocab_idx' in prompts_df.columns:
    prompts_df = prompts_df[prompts_df['vocab_idx'] == 0].reset_index(drop=True)
# Domains missing from DOMAIN_MAP map to NaN and are restored by fillna,
# i.e. they keep their original label.
prompts_df['domain'] = prompts_df['domain'].map(DOMAIN_MAP).fillna(prompts_df['domain'])

print(f"Prompts: {len(prompts_df):,}")

# A complete contrastive pair has exactly two rows sharing a pair_key.
# Keys with more than two rows are counted as neither complete nor orphan.
pair_counts = prompts_df.groupby('pair_key').size()
complete_pairs = (pair_counts == 2).sum()
orphan_pairs = (pair_counts == 1).sum()
print(f"Complete pairs: {complete_pairs:,}")
if orphan_pairs > 0:
    print(f"  Orphan pairs: {orphan_pairs:,}")


# --- Token handling and EV computation ---

def _build_token_info():
    """Record how the tokenizer encodes each candidate answer 0-10.

    Returns a dict keyed by the integer answer. A one-token answer is stored
    as {'type': 'single', 'token_id': ...}; anything else as
    {'type': 'multi', ...} with the full id list and its first two ids.
    Uses the notebook-level `tokenizer`.
    """
    token_table = {}
    for answer in range(0, 11):
        ids = tokenizer.encode(str(answer), add_special_tokens=False)
        if len(ids) != 1:
            token_table[answer] = {
                'type': 'multi',
                'token_ids': ids,
                'first_token_id': ids[0],
                'second_token_id': ids[1],
            }
            continue
        token_table[answer] = {'type': 'single', 'token_id': ids[0]}
    return token_table

# Answer value (0-10) -> token metadata; read by compute_response_metrics.
TOKEN_INFO = _build_token_info()


def compute_response_metrics(logits, scale_min, scale_max):
    """Compute expected value, confidence, and top answer from logits.

    logits is indexed by token id. The probabilities are a softmax restricted
    to the tokens of the answers on the scale. Returns a tuple
    (expected_value, max_probability, most_likely_answer).

    When the scale reaches 10 and TOKEN_INFO marks '10' as multi-token, only
    the answers up to 9 are scored directly and the probability of '1' is
    split evenly between the answers 1 and 10.
    """
    def answer_probs(last_answer):
        # Softmax over the single-token answers scale_min..last_answer.
        ids = [TOKEN_INFO[v]['token_id'] for v in range(scale_min, last_answer + 1)]
        return F.softmax(logits[ids].float(), dim=0).cpu().numpy()

    def summarise(answers, masses):
        expected = float(np.dot(answers, masses))
        return expected, float(masses.max()), int(answers[masses.argmax()])

    ten_is_multi_token = scale_max >= 10 and TOKEN_INFO.get(10, {}).get('type') == 'multi'
    if not ten_is_multi_token:
        return summarise(np.arange(scale_min, scale_max + 1), answer_probs(scale_max))

    # Handle 0-10 scales where '10' tokenizes as two tokens
    last_single = min(scale_max, 9)
    answers = list(range(scale_min, last_single + 1))
    masses = list(answer_probs(last_single))

    # Split shared first-token probability between 1 and 10; when 1 is not on
    # the scale, 10 receives zero mass.
    mass_for_ten = 0.0
    if 1 in answers:
        pos_of_1 = answers.index(1)
        mass_for_ten = masses[pos_of_1] / 2.0
        masses[pos_of_1] = mass_for_ten
    answers.append(10)
    masses.append(mass_for_ten)

    answers = np.array(answers)
    masses = np.array(masses)
    return summarise(answers, masses / masses.sum())


# --- Extraction ---

n = len(prompts_df)

# Column values pulled out as arrays, aligned with prompts_df row order.
pair_keys = prompts_df['pair_key'].values
demographics = prompts_df['demographic'].values
domains = prompts_df['domain'].values
value_types = prompts_df['value_type'].values
scale_mins = prompts_df['scale_min'].values
scale_maxs = prompts_df['scale_max'].values
question_ids = prompts_df['question_id'].values
respondent_ids = prompts_df['respondent_id'].values

hook_names = [f"blocks.{layer}.hook_resid_post" for layer in ANALYSIS_LAYERS]

def hook_filter(name):
    """Return True when `name` contains any of the hook names in `hook_names`."""
    return any(h in name for h in hook_names)

# Per-prompt response metrics, preallocated with one slot per prompt row.
all_evs = np.zeros(n, dtype=np.float32)
all_confidences = np.zeros(n, dtype=np.float32)
all_top_answers = np.zeros(n, dtype=np.int32)

# Per-layer diff accumulation
# diff_data[layer][(demographic, domain)] holds parallel lists that start empty.
diff_data = {}
activation_buffer = {}

for layer in ANALYSIS_LAYERS:
    diff_data[layer] = {}
    for demo in ALL_DEMOGRAPHICS:
        for domain in ALL_DOMAINS:
            diff_data[layer][(demo, domain)] = {
                'diffs': [], 'ev_a': [], 'ev_b': [],
                'confidence_a': [], 'confidence_b': [], 'pairs': []
            }

# Load all SAEs upfront (~500MB each, 8 total fits in VRAM)
print("\nLoading all SAEs...")
saes = {}
for layer in ANALYSIS_LAYERS:
    saes[layer] = sae_manager.load_sae(layer)
    # NOTE(review): resets the manager's private fields after each load —
    # presumably so it does not unload the SAE just stored in `saes` when the
    # next layer is requested; confirm against the sae_manager implementation.
    sae_manager._current_sae = None
    sae_manager._current_layer = None
    print(f"  Layer {layer}: d_sae={saes[layer].d_sae}")

print(f"\nStarting extraction ({len(saes)} layers per forward pass)...")
t0_total = datetime.now()
# Ceiling division: number of batches needed to cover all n prompts.
n_batches = (n + BATCH_SIZE - 1) // BATCH_SIZE

# Main extraction loop: one forward pass per batch yields the answer logits and,
# for every analysed layer, the SAE encoding of the residual stream at the last
# prompt token. Prompts are matched into (value_a, value_b) pairs by pair_key;
# when both halves of a pair have been seen, their SAE activation difference and
# behavioural metrics are appended to diff_data and the buffer entry is dropped.
for batch_start in tqdm(range(0, n, BATCH_SIZE), total=n_batches, desc="Extracting"):
    batch_end = min(batch_start + BATCH_SIZE, n)
    batch_size = batch_end - batch_start

    batch_prompts = prompts_df.iloc[batch_start:batch_end]['prompt'].tolist()
    tokens = tokenizer(batch_prompts, return_tensors='pt',
                       padding=True, truncation=True, max_length=512)
    input_ids = tokens['input_ids'].to(DEVICE)
    attention_mask = tokens['attention_mask'].to(DEVICE)
    # Index of the last real (mask == 1) token in each row. It is found by
    # locating the first 1 in the reversed mask, which is correct for both
    # right- and left-padded batches. For right padding this equals
    # attention_mask.sum(dim=1) - 1; with left padding that shortcut would
    # point at the wrong position. The tokenizer's padding side is configured
    # outside this block, so it is not assumed here.
    seq_lens = (attention_mask.shape[1] - 1
                - attention_mask.flip(dims=[1]).argmax(dim=1))

    with torch.inference_mode():
        logits, cache = model.run_with_cache(
            input_ids, names_filter=hook_filter
        )

        # Logits at the last prompt token of every row, moved to CPU.
        batch_logits = torch.stack([
            logits[j, seq_lens[j], :] for j in range(batch_size)
        ]).float().cpu()

        # SAE encodings of the last-token residual, one array per layer.
        layer_sae_acts = {}
        for layer in ANALYSIS_LAYERS:
            hook_name = f"blocks.{layer}.hook_resid_post"
            residuals = torch.stack([
                cache[hook_name][j, seq_lens[j], :] for j in range(batch_size)
            ]).float()
            layer_sae_acts[layer] = saes[layer].encode(residuals).float().cpu().numpy()

        # Drop the references to the full activation cache and logits before
        # the next batch.
        del cache, logits
    torch.cuda.empty_cache()

    for j in range(batch_size):
        idx = batch_start + j

        ev, conf, top_ans = compute_response_metrics(
            batch_logits[j], scale_mins[idx], scale_maxs[idx]
        )
        all_evs[idx] = ev
        all_confidences[idx] = conf
        all_top_answers[idx] = top_ans

        pk = pair_keys[idx]
        vtype = value_types[idx]

        # First prompt seen for this pair: record the pair-level metadata.
        if pk not in activation_buffer:
            activation_buffer[pk] = {
                'demo': demographics[idx],
                'domain': domains[idx],
                'question_id': question_ids[idx],
                'respondent_id': respondent_ids[idx],
                'scale_min': scale_mins[idx],
                'scale_max': scale_maxs[idx],
            }

        activation_buffer[pk][vtype] = {'ev': ev, 'confidence': conf, 'idx': idx}
        for layer in ANALYSIS_LAYERS:
            activation_buffer[pk].setdefault(f'sae_{layer}', {})
            activation_buffer[pk][f'sae_{layer}'][vtype] = layer_sae_acts[layer][j]

        # Complete pair — compute diffs
        if 'value_a' in activation_buffer[pk] and 'value_b' in activation_buffer[pk]:
            info = activation_buffer[pk]
            key = (info['demo'], info['domain'])

            for layer in ANALYSIS_LAYERS:
                sae_a = info[f'sae_{layer}']['value_a']
                sae_b = info[f'sae_{layer}']['value_b']

                diff_data[layer][key]['diffs'].append(sae_a - sae_b)
                diff_data[layer][key]['ev_a'].append(info['value_a']['ev'])
                diff_data[layer][key]['ev_b'].append(info['value_b']['ev'])
                diff_data[layer][key]['confidence_a'].append(info['value_a']['confidence'])
                diff_data[layer][key]['confidence_b'].append(info['value_b']['confidence'])
                diff_data[layer][key]['pairs'].append({
                    'pair_key': pk,
                    'question_id': info['question_id'],
                    'respondent_id': info['respondent_id'],
                })

            # The pair is fully consumed; free its buffered activations.
            del activation_buffer[pk]

    # Progress report every 200 batches (skipping the very first batch).
    if (batch_start // BATCH_SIZE) % 200 == 0 and batch_start > 0:
        elapsed = (datetime.now() - t0_total).total_seconds() / 60
        pct = batch_start / n * 100
        eta = elapsed / (pct / 100) - elapsed if pct > 0 else 0
        print(f"  {pct:.0f}% | {elapsed:.1f} min elapsed | ~{eta:.0f} min remaining | "
              f"buffer: {len(activation_buffer)} pending")

# Anything still buffered never received its pair partner and is discarded.
n_orphans = len(activation_buffer)
if n_orphans > 0:
    print(f"\n  {n_orphans} orphan pairs skipped")
activation_buffer.clear()

elapsed_total = (datetime.now() - t0_total).total_seconds() / 60
print(f"\nExtraction complete: {elapsed_total:.1f} min")
for layer in ANALYSIS_LAYERS:
    total_pairs = sum(len(d['diffs']) for d in diff_data[layer].values())
    print(f"  Layer {layer}: {total_pairs:,} pairs")

# Build prompt results (EVs are layer-independent)
# all_evs / all_confidences / all_top_answers were filled by row position, so
# they line up with the rows of prompts_df.
prompt_results_base = prompts_df[['pair_key', 'respondent_id', 'question_id',
                                   'domain', 'demographic', 'value_type']].copy()
prompt_results_base['ev'] = all_evs
prompt_results_base['confidence'] = all_confidences
prompt_results_base['top_answer'] = all_top_answers

# Free SAEs
for layer in ANALYSIS_LAYERS:
    del saes[layer]
del saes
torch.cuda.empty_cache()
gc.collect()
print("SAEs unloaded")

In [None]:
# Statistical feature selection — per layer, per (demographic, domain) group
#
# Three-stage filtering: noise filters, FDR-corrected t-tests, effect size threshold.
# Top K=50 features selected per group, ranked by absolute mean activation difference (|mean diff|), not by Cohen's d.

from scipy.stats import ttest_1samp, ttest_rel
from statsmodels.stats.multitest import multipletests

# Maximum number of features kept per (demographic, domain) group.
N_FEATURES_SELECT = 50
# Minimum |Cohen's d| for a feature to count as a large effect.
MIN_EFFECT_SIZE = 0.3
# Alpha passed to the Benjamini-Hochberg correction and used as the
# significance cutoff on the corrected p-values.
FDR_ALPHA = 0.05

# Noise filter thresholds
# Minimum fraction of pairs whose diff has the same sign as the mean diff.
SIGN_AGREEMENT_THRESHOLD = 0.60
# Coefficient of variation (std / |mean|) must be below this.
CV_THRESHOLD = 3.0
# Fraction of pairs with |z| > 3 must be below this.
OUTLIER_RATIO_THRESHOLD = 0.10


def select_features_with_stats(layer_diff_data, layer):
    """Select up to N_FEATURES_SELECT features per (demographic, domain) group.

    Pipeline per group: noise filtering -> one-sample t-tests on the paired
    diffs -> Benjamini-Hochberg FDR correction (within the group) -> effect
    size threshold. Candidates are features that are FDR-significant, have
    |Cohen's d| >= MIN_EFFECT_SIZE and pass every noise filter. If a group has
    no candidates, the fallback takes the noise-filtered features (or, failing
    that, all features). In every case the final ranking is by absolute mean
    diff, not by Cohen's d.

    Groups with fewer than 10 pairs are skipped.

    Args:
        layer_diff_data: mapping ``(demo, domain) -> dict`` with the lists
            'diffs' (one feature-difference vector per pair), 'ev_a' and
            'ev_b' (expected values of the two prompts of each pair).
        layer: layer identifier, copied into every output record.

    Returns:
        ``(selected_features, stats_df, funnel_df)``: a dict keyed by
        ``(demo, domain)`` with the selected feature indices and group-level
        statistics, a DataFrame with one row per selected feature, and a
        DataFrame with one row of filtering-funnel counts per group.
    """
    selected_features = {}
    all_stats = []
    funnel_stats = []

    for (demo, domain), data in layer_diff_data.items():
        if len(data['diffs']) < 10:
            continue

        diffs = np.stack(data['diffs'])
        n_pairs, n_feats = diffs.shape
        mean_diffs = np.mean(diffs, axis=0)
        std_diffs = np.std(diffs, axis=0, ddof=1)

        # --- Noise filters ---
        # Fraction of pairs whose diff sign matches the sign of the mean diff.
        signs = np.sign(diffs)
        mean_signs = np.sign(mean_diffs)
        sign_agreement = np.mean(signs == mean_signs[np.newaxis, :], axis=0)
        sign_consistent = sign_agreement >= SIGN_AGREEMENT_THRESHOLD

        # Coefficient of variation; left at inf where the mean is ~0 so those
        # features fail the low-noise filter.
        cv = np.full(n_feats, np.inf)
        nonzero_mean = np.abs(mean_diffs) > 1e-10
        cv[nonzero_mean] = std_diffs[nonzero_mean] / np.abs(mean_diffs[nonzero_mean])
        low_noise = cv < CV_THRESHOLD

        # Fraction of pairs more than 3 standard deviations from the mean.
        z_scores = np.abs((diffs - mean_diffs[np.newaxis, :]) / (std_diffs[np.newaxis, :] + 1e-10))
        outlier_ratio = np.mean(z_scores > 3, axis=0)
        few_outliers = outlier_ratio < OUTLIER_RATIO_THRESHOLD

        not_noisy = sign_consistent & low_noise & few_outliers
        is_noisy = ~not_noisy

        # Cohen's d
        cohens_d = np.zeros(n_feats)
        has_variance = std_diffs > 1e-10
        cohens_d[has_variance] = mean_diffs[has_variance] / std_diffs[has_variance]

        # --- t-tests on clean features ---
        testable = not_noisy & has_variance
        t_stats = np.zeros(n_feats)
        p_values = np.ones(n_feats)

        testable_idx = np.where(testable)[0]
        if len(testable_idx) > 0:
            t_result, p_result = ttest_1samp(diffs[:, testable_idx], 0, axis=0)
            valid_results = ~np.isnan(p_result)
            valid_idx = testable_idx[valid_results]
            t_stats[valid_idx] = t_result[valid_results]
            p_values[valid_idx] = p_result[valid_results]

        # --- FDR correction (per group) ---
        valid_mask = not_noisy & (p_values < 1.0)
        p_corrected = np.ones(n_feats)
        if valid_mask.sum() > 0:
            valid_p = p_values[valid_mask]
            try:
                _, corrected, _, _ = multipletests(valid_p, method='fdr_bh', alpha=FDR_ALPHA)
                p_corrected[valid_mask] = corrected
            except Exception as exc:
                # Best-effort fallback to the raw p-values is kept, but it is
                # reported: without a message, uncorrected p-values would be
                # treated downstream as if they were FDR-corrected.
                print(f"  WARNING ({demo}, {domain}): FDR correction failed "
                      f"({exc!r}); using uncorrected p-values")
                p_corrected[valid_mask] = valid_p

        # --- Selection ---
        significant = p_corrected < FDR_ALPHA
        large_effect = np.abs(cohens_d) >= MIN_EFFECT_SIZE
        candidates = significant & large_effect & not_noisy
        candidate_idx = np.where(candidates)[0]
        selection_method = 'significant'

        if len(candidate_idx) == 0:
            # Fallback: rank by mean absolute diff among clean features
            noisy_free_idx = np.where(not_noisy)[0]
            if len(noisy_free_idx) > 0:
                sorted_idx = noisy_free_idx[np.argsort(np.abs(mean_diffs[noisy_free_idx]))[::-1]]
                candidate_idx = sorted_idx[:N_FEATURES_SELECT]
                selection_method = 'fallback_noise_filtered'
            else:
                candidate_idx = np.argsort(np.abs(mean_diffs))[::-1][:N_FEATURES_SELECT]
                selection_method = 'fallback_unfiltered'

        # Final ranking of whichever candidate set was chosen: largest |mean diff| first.
        sorted_candidates = candidate_idx[np.argsort(np.abs(mean_diffs[candidate_idx]))[::-1]]
        top_features = sorted_candidates[:N_FEATURES_SELECT]

        # Funnel diagnostics
        # NOTE(review): 'n_large_effect' and 'n_large_effect_all' are computed
        # from the same mask and are therefore always equal; one of them was
        # presumably meant to be restricted (e.g. to clean features) — confirm
        # the intended definition before relying on either column.
        funnel_stats.append({
            'layer': layer,
            'demographic': demo,
            'domain': domain,
            'n_total_features': int(n_feats),
            'n_has_variance': int(has_variance.sum()),
            'n_sign_consistent': int(sign_consistent.sum()),
            'n_low_cv': int(low_noise.sum()),
            'n_few_outliers': int(few_outliers.sum()),
            'n_not_noisy': int(not_noisy.sum()),
            'n_noisy': int(is_noisy.sum()),
            'pct_noisy': round(float(is_noisy.mean() * 100), 2),
            'n_testable': int(testable.sum()),
            'n_fdr_significant': int(significant.sum()),
            'n_large_effect': int(large_effect.sum()),
            'n_large_effect_all': int((np.abs(cohens_d) >= MIN_EFFECT_SIZE).sum()),
            'n_candidates': int(candidates.sum()),
            'n_selected': int(len(top_features)),
            'selection_method': selection_method,
            'mean_abs_d_clean': float(np.abs(cohens_d[not_noisy]).mean()) if not_noisy.sum() > 0 else 0,
            'mean_abs_d_noisy': float(np.abs(cohens_d[is_noisy]).mean()) if is_noisy.sum() > 0 else 0,
            'mean_sign_agreement': float(sign_agreement.mean()),
            'mean_cv_finite': float(cv[np.isfinite(cv)].mean()) if np.isfinite(cv).sum() > 0 else float('nan'),
            'mean_outlier_ratio': float(outlier_ratio.mean()),
        })

        # Behavioral effect size
        # Mean difference in expected value between the two prompts of a pair,
        # with a paired t-test when the differences are not all identical.
        ev_a_arr = np.array(data['ev_a'])
        ev_b_arr = np.array(data['ev_b'])
        ev_diffs = ev_a_arr - ev_b_arr
        behavioral_effect = float(np.mean(ev_diffs))

        if len(ev_a_arr) >= 2 and np.std(ev_diffs) > 1e-10:
            t_beh, p_beh = ttest_rel(ev_a_arr, ev_b_arr)
            behavioral_t = float(t_beh)
            behavioral_p = float(p_beh)
        else:
            behavioral_t = float('nan')
            behavioral_p = float('nan')

        selected_features[(demo, domain)] = {
            'features': top_features.tolist(),
            'mean_diffs': mean_diffs[top_features].tolist(),
            'cohens_d': cohens_d[top_features].tolist(),
            'p_values': p_values[top_features].tolist(),
            'p_corrected': p_corrected[top_features].tolist(),
            'n_pairs': n_pairs,
            'behavioral_effect': behavioral_effect,
            'behavioral_t': behavioral_t,
            'behavioral_p': behavioral_p,
            'n_significant': int(significant.sum()),
            'n_candidates': int(candidates.sum()),
            'n_not_noisy': int(not_noisy.sum()),
            'selection_method': selection_method,
            'layer': layer,
        }

        for i in top_features:
            all_stats.append({
                'demographic': demo, 'domain': domain,
                'feature_idx': int(i),
                'mean_diff': float(mean_diffs[i]),
                'cohens_d': float(cohens_d[i]),
                'p_value': float(p_values[i]),
                'p_corrected': float(p_corrected[i]),
                'significant': bool(significant[i]),
                'large_effect': bool(large_effect[i]),
                'sign_agreement': float(sign_agreement[i]),
                'cv': float(cv[i]) if not np.isinf(cv[i]) else float('nan'),
                'outlier_ratio': float(outlier_ratio[i]),
                'is_noisy': bool(is_noisy[i]),
                'feature_significant': bool(significant[i] & large_effect[i] & not_noisy[i]),
                'selection_method': selection_method,
                'layer': layer,
            })

        # Per-group status
        beh_sig = ("***" if (not np.isnan(behavioral_p) and behavioral_p < 0.001) else
                   "**" if (not np.isnan(behavioral_p) and behavioral_p < 0.01) else
                   "*" if (not np.isnan(behavioral_p) and behavioral_p < 0.05) else "ns")
        method_tag = "" if selection_method == 'significant' else f" [{selection_method}]"
        print(f"  ({demo}, {domain}): {n_pairs} pairs, {significant.sum()} sig, "
              f"{candidates.sum()} cand -> {len(top_features)} sel | "
              f"dEV={behavioral_effect:.3f} {beh_sig}{method_tag}")

    # Summary
    method_counts = {}
    for info in selected_features.values():
        m = info['selection_method']
        method_counts[m] = method_counts.get(m, 0) + 1
    print(f"\n  Layer {layer} summary:")
    for method, count in sorted(method_counts.items()):
        print(f"    {method}: {count}")

    return selected_features, pd.DataFrame(all_stats), pd.DataFrame(funnel_stats)


# --- Run selection for each layer ---

# Results are keyed by the layer number as a string; the per-layer stats and
# funnel frames are collected for concatenation when saving.
all_selected_features = {}
all_stats = []
all_funnel = []

for layer in ANALYSIS_LAYERS:
    print(f"\nLayer {layer} ({LAYER_SAE_CONFIG[layer]['depth_pct']}% depth)")
    print("-" * 40)

    selected, stats_df, funnel_df = select_features_with_stats(diff_data[layer], layer)
    all_selected_features[str(layer)] = selected
    all_stats.append(stats_df)
    all_funnel.append(funnel_df)

    # Release this layer's raw diffs as soon as they have been summarised.
    del diff_data[layer]
    gc.collect()

del diff_data


# --- Cross-layer summary ---

# One line per layer: how many groups used the 'significant' path vs a
# fallback, the mean candidate count, and how many groups show a behavioural
# effect with p < 0.05.
print("\nCross-layer summary:")
for layer_str, features_dict in all_selected_features.items():
    layer = int(layer_str)
    n_sig = sum(1 for v in features_dict.values() if v['selection_method'] == 'significant')
    n_fb = sum(1 for v in features_dict.values() if v['selection_method'] != 'significant')
    avg_cand = np.mean([v['n_candidates'] for v in features_dict.values()])
    sig_beh = sum(1 for v in features_dict.values()
                  if not np.isnan(v['behavioral_p']) and v['behavioral_p'] < 0.05)
    depth = LAYER_SAE_CONFIG[layer]['depth_pct']
    print(f"  L{layer:>2} ({depth:>2}%): {n_sig} sig + {n_fb} fallback | "
          f"avg cand: {avg_cand:.0f} | beh sig: {sig_beh}/{len(features_dict)}")


# --- Save ---

def convert_for_json(obj):
    """Recursively convert ``obj`` into types that ``json.dump`` accepts.

    - dict keys become strings; tuple keys are joined with ``_``
      (``('income', 'climate')`` -> ``'income_climate'``).
    - lists, tuples and NumPy arrays become lists, converted element-wise.
    - NumPy booleans, integers and floats of any width become ``bool``,
      ``int`` and ``float``.
    - NaN (Python or NumPy float) becomes ``None`` so the output is valid JSON.

    Anything else is returned unchanged.
    """
    if isinstance(obj, dict):
        return {
            ("_".join(str(part) for part in key) if isinstance(key, tuple) else str(key)):
                convert_for_json(value)
            for key, value in obj.items()
        }
    if isinstance(obj, (list, tuple)):
        return [convert_for_json(item) for item in obj]
    if isinstance(obj, np.ndarray):
        # Recurse over the converted list so NaN entries are mapped to None too.
        return convert_for_json(obj.tolist())
    if isinstance(obj, np.bool_):
        # np.bool_ is not a Python bool and json cannot serialise it directly.
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, (float, np.floating)):
        return None if np.isnan(obj) else float(obj)
    return obj

# Selected features per layer and group; tuple keys are flattened to
# "<demo>_<domain>" strings by convert_for_json.
with open(OUTPUT_DIR / 'selected_features.json', 'w') as f:
    json.dump(convert_for_json(all_selected_features), f, indent=2)
print(f"Saved selected_features.json")

# One row per selected feature, all layers stacked.
stats_combined = pd.concat(all_stats, ignore_index=True)
stats_combined.to_csv(OUTPUT_DIR / 'feature_stats.csv', index=False)
print(f"Saved feature_stats.csv ({len(stats_combined):,} rows)")

# One row of filtering-funnel counts per (layer, demographic, domain).
funnel_combined = pd.concat(all_funnel, ignore_index=True)
funnel_combined.to_csv(OUTPUT_DIR / 'filtering_funnel.csv', index=False)
print(f"Saved filtering_funnel.csv ({len(funnel_combined):,} rows)")

# Funnel summary
# Counts are averaged over the groups of each layer; the total feature count is
# read from the first row of that layer.
print("\nFiltering funnel (averaged per layer):")
for layer in ANALYSIS_LAYERS:
    lf = funnel_combined[funnel_combined['layer'] == layer]
    if len(lf) == 0:
        continue
    depth = LAYER_SAE_CONFIG[layer]['depth_pct']
    print(f"  L{layer:>2} ({depth:>2}%): "
          f"{lf['n_total_features'].iloc[0]:,} -> "
          f"{lf['n_has_variance'].mean():.0f} variance -> "
          f"{lf['n_not_noisy'].mean():.0f} clean ({lf['pct_noisy'].mean():.1f}% noisy) -> "
          f"{lf['n_fdr_significant'].mean():.0f} FDR sig -> "
          f"{lf['n_candidates'].mean():.0f} cand -> "
          f"{lf['n_selected'].mean():.0f} selected")

# Prompt results (EVs are layer-independent; replicate with layer column for compatibility)
# The same rows are written once per layer, differing only in 'layer'.
all_prompt_results = []
for layer in ANALYSIS_LAYERS:
    pr = prompt_results_base.copy()
    pr['layer'] = layer
    all_prompt_results.append(pr)
prompt_results_combined = pd.concat(all_prompt_results, ignore_index=True)
prompt_results_combined.to_parquet(OUTPUT_DIR / 'prompt_results.parquet', index=False)
print(f"Saved prompt_results.parquet ({len(prompt_results_combined):,} rows)")

# Run configuration and summary counts, written next to the results so the
# thresholds used for selection are recorded with the outputs.
metadata = {
    'timestamp': datetime.now().isoformat(),
    'model': 'google/gemma-2-9b-it',
    'sae_repos': {
        'pt': 'google/gemma-scope-9b-pt-res',
        'it': 'google/gemma-scope-9b-it-res',
    },
    'sae_width': '16k',
    'sae_activation': 'JumpReLU (no b_dec centering)',
    'layers': ANALYSIS_LAYERS,
    'n_features_select': N_FEATURES_SELECT,
    'min_effect_size': MIN_EFFECT_SIZE,
    'fdr_alpha': FDR_ALPHA,
    'n_prompts': len(prompts_df),
    'batch_size': BATCH_SIZE,
    'optimization': 'multi-layer single forward pass',
    'seed': SEED,
    'noise_filter_thresholds': {
        'sign_agreement': SIGN_AGREEMENT_THRESHOLD,
        'cv': CV_THRESHOLD,
        'outlier_ratio': OUTLIER_RATIO_THRESHOLD,
    },
    'fdr_scope': 'per_demo_domain_group',
    'extraction_time_minutes': round(elapsed_total, 1),
    # Per layer: number of groups selected via the 'significant' path vs. any
    # fallback path. Keys of all_selected_features are already strings.
    'selection_summary': {
        str(layer): {
            'significant': sum(1 for v in feats.values() if v['selection_method'] == 'significant'),
            'fallback': sum(1 for v in feats.values() if v['selection_method'] != 'significant'),
        }
        for layer, feats in all_selected_features.items()
    },
    'layer_sae_types': {
        str(layer): LAYER_SAE_CONFIG[layer]['repo'].upper()
        for layer in ANALYSIS_LAYERS
    },
}
with open(OUTPUT_DIR / 'extraction_metadata.json', 'w') as f:
    json.dump(metadata, f, indent=2)
print(f"Saved extraction_metadata.json")

print(f"\nDone. Output: {OUTPUT_DIR}")

In [None]:
# Feature extraction analysis — figures (part 1 of 2)
# Figures 1-5: encoding heatmaps, encoding trends, behavioral effects,
# feature quality, domain overlap

import pandas as pd
import numpy as np
import json
import shutil
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from scipy.stats import pearsonr, spearmanr
from itertools import combinations as combs
import warnings
warnings.filterwarnings('ignore')

# Global matplotlib styling for every figure produced in this cell.
plt.style.use('seaborn-v0_8-whitegrid')
plt.rcParams.update({
    'font.family': 'DejaVu Sans',
    'font.size': 11,
    'axes.labelsize': 12,
    'axes.titlesize': 13,
    'legend.fontsize': 10,
    'figure.dpi': 150,
    'savefig.dpi': 300,
    'savefig.bbox': 'tight',
})

# INPUT_DIR is where the selection outputs are read from; BUNDLE collects the
# figures, tables and copied data files produced here.
INPUT_DIR = Path("./outputs_gemma_replication/feature_extraction")
BUNDLE = Path("./feature-extraction-results-gemma")
BUNDLE.mkdir(exist_ok=True)
FIGS   = BUNDLE / "figures";   FIGS.mkdir(exist_ok=True)
TABLES = BUNDLE / "tables";    TABLES.mkdir(exist_ok=True)
DATA   = BUNDLE / "data";      DATA.mkdir(exist_ok=True)

# Fixed row/column order for the heatmaps; DEMO_SHORT holds the tick labels in
# the same order as ALL_DEMOS.
ALL_DEMOS   = ['income', 'age', 'gender', 'education', 'vote']
ALL_DOMAINS = ['climate', 'health', 'digital', 'economy', 'values']
DEMO_SHORT  = ['Inco', 'Age', 'Gend', 'Educ', 'Vote']

# Depth percentages per analysed layer; each entry equals
# round(layer / N_MODEL_LAYERS * 100).
N_MODEL_LAYERS = 41
LAYER_DEPTH = {5: 12, 9: 22, 14: 34, 18: 44, 20: 49, 27: 66, 32: 78, 36: 88}
# Layers tagged "(IT)" in labels and marked with a vertical line in trend plots.
IT_LAYERS = {20}

# Line colour per demographic.
COLORS = {
    'income': '#1abc9c', 'age': '#9b59b6', 'gender': '#e91e63',
    'education': '#3498db', 'vote': '#f44336',
}


def layer_label(layer):
    """Return a display label such as ``"L20 (49% (IT))"`` for ``layer``.

    The depth percentage comes from ``LAYER_DEPTH`` when the layer is listed
    there. Otherwise it is computed as ``layer / N_MODEL_LAYERS * 100`` and
    rounded to the nearest integer. Rounding (rather than truncating) matches
    the values stored in ``LAYER_DEPTH``, so listed and unlisted layers are
    labelled on the same scale. Layers in ``IT_LAYERS`` get an " (IT)" tag.
    """
    depth = LAYER_DEPTH.get(layer, round(layer / N_MODEL_LAYERS * 100))
    tag = " (IT)" if layer in IT_LAYERS else ""
    return f"L{layer} ({depth}%{tag})"


def make_2x4_grid(layers):
    """Return ``(nrows, ncols)`` of a subplot grid with one panel per layer.

    Up to four layers share a single row with exactly one column each. Five to
    eight layers use the 2x4 layout. More than eight layers add further rows of
    four so that every layer gets an axis instead of indexing past the grid.
    An empty ``layers`` yields ``(1, 1)``, because a grid with zero columns
    cannot be created.
    """
    n_panels = len(layers)
    if n_panels <= 4:
        return 1, max(n_panels, 1)
    # Ceiling division by the four-column width.
    return -(-n_panels // 4), 4


# --- Load ---

print("Loading...")

stats_df = pd.read_csv(INPUT_DIR / 'feature_stats.csv')
with open(INPUT_DIR / 'selected_features.json') as f:
    selected_features = json.load(f)

# Flatten the nested JSON (layer -> "<demo>_<domain>" -> info) into one row per
# (layer, demographic, domain) with the behavioural statistics.
behav_rows = []
for layer_str, layer_data in selected_features.items():
    layer = int(layer_str)
    for key_str, info in layer_data.items():
        # Split on the first underscore only: text before it is the
        # demographic, everything after it is the domain.
        parts = key_str.split('_', 1)
        demo = parts[0]
        domain = parts[1] if len(parts) > 1 else 'unknown'
        behav_rows.append({
            'layer': layer,
            'demographic': demo,
            'domain': domain,
            'effect': info.get('behavioral_effect', 0),
            'p_value': info.get('behavioral_p', None),
            't_stat': info.get('behavioral_t', None),
            'n_pairs': info.get('n_pairs', 0),
            'selection_method': info.get('selection_method', 'unknown'),
        })
behav_df = pd.DataFrame(behav_rows)
# Missing (null) statistics become NaN so the columns are numeric.
behav_df['p_value'] = pd.to_numeric(behav_df['p_value'], errors='coerce')
behav_df['t_stat'] = pd.to_numeric(behav_df['t_stat'], errors='coerce')

# Layers to plot are whatever layers appear in the feature stats.
LAYERS = sorted(stats_df['layer'].unique().tolist())

print(f"  feature_stats: {len(stats_df):,} rows, layers={LAYERS}")
print(f"  behavioral_effects: {len(behav_df)} rows")

# The funnel file is optional; figures that use it check HAS_FUNNEL first.
funnel_path = INPUT_DIR / 'filtering_funnel.csv'
HAS_FUNNEL = funnel_path.exists()
funnel_df = pd.read_csv(funnel_path) if HAS_FUNNEL else None

# Copy the inputs that exist into the bundle, and write the flattened
# behavioural table alongside them.
for name in ['feature_stats.csv', 'selected_features.json', 'filtering_funnel.csv']:
    src = INPUT_DIR / name
    if src.exists():
        shutil.copy2(src, DATA / name)
behav_df.to_csv(DATA / 'behavioral_effects.csv', index=False)


# --- Figure 1: Encoding heatmap by layer ---

print("\nFigure 1: Encoding heatmap by layer")

# Number of features (rows) shown per layer panel.
N_TOP = 10
nrows, ncols = make_2x4_grid(LAYERS)
fig, axes = plt.subplots(nrows, ncols, figsize=(5 * ncols, 0.6 * N_TOP + 2.5 * nrows))
axes_flat = np.atleast_1d(axes).flatten()

for ax_idx, layer in enumerate(LAYERS):
    ax = axes_flat[ax_idx]
    ls = stats_df[stats_df['layer'] == layer].copy()
    if len(ls) == 0:
        ax.set_title(f'{layer_label(layer)}\n(no data)')
        continue

    # Feature x demographic tables: mean Cohen's d and smallest corrected
    # p-value across the rows (domains) available for that pair.
    pivot_d = ls.pivot_table(
        values='cohens_d', index='feature_idx', columns='demographic', aggfunc='mean'
    ).reindex(columns=ALL_DEMOS)
    pivot_p = ls.pivot_table(
        values='p_corrected', index='feature_idx', columns='demographic', aggfunc='min'
    ).reindex(columns=ALL_DEMOS)

    # Keep the N_TOP features with the largest mean |d| across demographics.
    mean_abs_d = pivot_d.abs().mean(axis=1)
    top_feats = mean_abs_d.nlargest(N_TOP).index
    heat_d = pivot_d.loc[top_feats]
    heat_p = pivot_p.loc[top_feats]

    # NOTE(review): the significance rule multiplies 'p_corrected' (already
    # FDR-adjusted upstream) by the number of heatmap cells, and n_cells counts
    # empty (NaN) cells as well — confirm this double correction is intended.
    n_cells = heat_p.size
    n_sig = sum(1 for i in range(heat_d.shape[0]) for j in range(heat_d.shape[1])
                if pd.notna(heat_p.iloc[i, j]) and heat_p.iloc[i, j] * n_cells < 0.001)

    # Symmetric colour range, at least +/-3.
    vmax = max(3.0, float(np.nanmax(np.abs(heat_d.values))))
    sns.heatmap(heat_d, annot=False, cmap='RdBu_r', center=0,
                vmin=-vmax, vmax=vmax, xticklabels=DEMO_SHORT,
                yticklabels=heat_d.index.astype(str), ax=ax, linewidths=0.5,
                cbar_kws={'label': "Cohen's d", 'shrink': 0.7})

    # Star the cells that pass the same rule used for n_sig above.
    for i in range(heat_d.shape[0]):
        for j in range(heat_d.shape[1]):
            p_val = heat_p.iloc[i, j]
            if pd.notna(p_val) and p_val * n_cells < 0.001:
                ax.text(j + 0.5, i + 0.5, '***', ha='center', va='center',
                        fontsize=7, color='black', fontweight='bold')

    ax.set_title(f'{layer_label(layer)}\n{n_sig}/{n_cells} sig', fontsize=11)
    ax.set_ylabel('Feature' if ax_idx % ncols == 0 else '')

# Hide grid cells that have no layer assigned.
for ax_idx in range(len(LAYERS), len(axes_flat)):
    axes_flat[ax_idx].set_visible(False)

plt.suptitle("Feature-Demographic Encoding by Layer — Gemma 2 9B\n"
             "(*** = p < 0.001 after Bonferroni)", fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig(FIGS / 'fig_encoding_by_layer.png')
plt.savefig(FIGS / 'fig_encoding_by_layer.pdf')
plt.close()
print("  Saved: fig_encoding_by_layer")


# --- Figure 2: Encoding strength across layers ---

print("Figure 2: Encoding strength across layers")

fig, axes = plt.subplots(1, 3, figsize=(16, 5))
# x positions are depth percentages; layers missing from LAYER_DEPTH fall back
# to the raw layer number.
depths = [LAYER_DEPTH.get(l, l) for l in LAYERS]

# Panel A: mean |Cohen's d| of the selected features, per demographic.
ax = axes[0]
for demo in ALL_DEMOS:
    means = [stats_df[(stats_df['layer'] == l) & (stats_df['demographic'] == demo)]['cohens_d'].abs().mean()
             for l in LAYERS]
    ax.plot(depths, means, marker='o', linewidth=2, markersize=7,
            label=demo.capitalize(), color=COLORS[demo])
for layer in IT_LAYERS:
    d = LAYER_DEPTH.get(layer, layer)
    ax.axvline(d, color='gray', linestyle=':', alpha=0.4)
ax.set_xlabel('Depth (%)')
ax.set_ylabel("Mean |Cohen's d|")
ax.set_title('A. Encoding Strength', fontweight='bold')
ax.set_xticks(depths)
ax.set_xticklabels([f'L{l}' for l in LAYERS], rotation=45, fontsize=9)
ax.legend(fontsize=8)

# Panel B: number of selected features flagged significant, per demographic.
# Uses 'feature_significant' when the column exists, else 'significant'.
ax = axes[1]
sig_col = 'feature_significant' if 'feature_significant' in stats_df.columns else 'significant'
for demo in ALL_DEMOS:
    counts = [((stats_df['layer'] == l) & (stats_df['demographic'] == demo) &
               (stats_df[sig_col] == True)).sum() for l in LAYERS]
    ax.plot(depths, counts, marker='s', linewidth=2, markersize=7,
            label=demo.capitalize(), color=COLORS[demo])
for layer in IT_LAYERS:
    ax.axvline(LAYER_DEPTH.get(layer, layer), color='gray', linestyle=':', alpha=0.4)
ax.set_xlabel('Depth (%)')
ax.set_ylabel(f'# Features ({sig_col})')
ax.set_title('B. Significant Features', fontweight='bold')
ax.set_xticks(depths)
ax.set_xticklabels([f'L{l}' for l in LAYERS], rotation=45, fontsize=9)
ax.legend(fontsize=8)

# Panel C: mean 'n_candidates' over the groups of each demographic, read from
# the selected-features JSON (keys start with "<demo>_").
ax = axes[2]
for demo in ALL_DEMOS:
    cands = []
    for layer in LAYERS:
        layer_str = str(layer)
        if layer_str in selected_features:
            dc = [v.get('n_candidates', 0) for k, v in selected_features[layer_str].items()
                  if k.startswith(f'{demo}_')]
            cands.append(np.mean(dc) if dc else 0)
        else:
            cands.append(0)
    ax.plot(depths, cands, marker='^', linewidth=2, markersize=7,
            label=demo.capitalize(), color=COLORS[demo])
for layer in IT_LAYERS:
    ax.axvline(LAYER_DEPTH.get(layer, layer), color='gray', linestyle=':', alpha=0.4)
ax.set_xlabel('Depth (%)')
ax.set_ylabel('Mean # Candidates')
ax.set_title('C. Candidate Features (sig + large + clean)', fontweight='bold')
ax.set_xticks(depths)
ax.set_xticklabels([f'L{l}' for l in LAYERS], rotation=45, fontsize=9)
ax.legend(fontsize=8)

plt.suptitle('Encoding Trends Across Layers — Gemma 2 9B', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig(FIGS / 'fig_encoding_across_layers.png')
plt.savefig(FIGS / 'fig_encoding_across_layers.pdf')
plt.close()
print("  Saved: fig_encoding_across_layers")


# --- Figure 3: Behavioral effects ---

print("Figure 3: Behavioral effects heatmaps")

nrows, ncols = make_2x4_grid(LAYERS)
fig, axes = plt.subplots(nrows, ncols, figsize=(5 * ncols, 5 * nrows))
axes_flat = np.atleast_1d(axes).flatten()

for ax_idx, layer in enumerate(LAYERS):
    ax = axes_flat[ax_idx]
    lb = behav_df[behav_df['layer'] == layer]
    if len(lb) == 0:
        ax.set_title(f'{layer_label(layer)}\n(no data)')
        continue

    # Demographic x domain tables of effect, p-value and selection method, in
    # the fixed ALL_DEMOS / ALL_DOMAINS order (missing groups become NaN).
    pivot_eff = lb.pivot_table(values='effect', index='demographic', columns='domain',
                                aggfunc='mean').reindex(index=ALL_DEMOS, columns=ALL_DOMAINS)
    pivot_p = lb.pivot_table(values='p_value', index='demographic', columns='domain',
                              aggfunc='min').reindex(index=ALL_DEMOS, columns=ALL_DOMAINS)
    pivot_method = lb.pivot_table(values='selection_method', index='demographic', columns='domain',
                                   aggfunc='first').reindex(index=ALL_DEMOS, columns=ALL_DOMAINS)

    # Symmetric colour range, at least +/-0.5.
    vmax = max(0.5, float(np.nanmax(np.abs(pivot_eff.values))))
    sns.heatmap(pivot_eff, annot=True, fmt='.3f', cmap='RdBu_r', center=0,
                vmin=-vmax, vmax=vmax, linewidths=0.5, ax=ax,
                cbar_kws={'label': 'dEV', 'shrink': 0.7},
                xticklabels=[d[:4].capitalize() for d in ALL_DOMAINS],
                yticklabels=[d[:4].capitalize() for d in ALL_DEMOS])

    # Overlay significance stars (top-right of the cell) and a '+' marker
    # (bottom-left) for groups selected through a fallback path.
    for i in range(pivot_eff.shape[0]):
        for j in range(pivot_eff.shape[1]):
            # A missing p-value is treated as non-significant.
            p = pivot_p.iloc[i, j] if pd.notna(pivot_p.iloc[i, j]) else 1.0
            star = '***' if p < 0.001 else ('**' if p < 0.01 else ('*' if p < 0.05 else ''))
            if star:
                ax.text(j + 0.85, i + 0.15, star, fontsize=6, ha='right', va='top')
            method = pivot_method.iloc[i, j] if pd.notna(pivot_method.iloc[i, j]) else ''
            if 'fallback' in str(method):
                ax.text(j + 0.15, i + 0.85, '+', fontsize=7, ha='left', va='bottom',
                        color='gray', fontstyle='italic')
    ax.set_title(f'{layer_label(layer)}', fontsize=11)

# Hide grid cells that have no layer assigned.
for ax_idx in range(len(LAYERS), len(axes_flat)):
    axes_flat[ax_idx].set_visible(False)

plt.suptitle('Behavioral Effects — Gemma 2 9B\n'
             '(* p<.05, ** p<.01, *** p<.001; + = fallback)',
             fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig(FIGS / 'fig_behavioral_effects.png')
plt.savefig(FIGS / 'fig_behavioral_effects.pdf')
plt.close()
print("  Saved: fig_behavioral_effects")


# --- Figure 4: Feature quality ---

print("Figure 4: Feature quality distributions")

fig, axes = plt.subplots(2, 2, figsize=(13, 10))

# Panel A: distribution of Cohen's d of the selected features per layer, with
# dotted reference lines at +/-0.3.
ax = axes[0, 0]
violin_data = [stats_df[stats_df['layer'] == l]['cohens_d'].dropna().values for l in LAYERS]
violin_labels = [f'L{l}' for l in LAYERS]
parts = ax.violinplot(violin_data, positions=range(len(LAYERS)), showmeans=True, showmedians=True)
for pc in parts['bodies']:
    pc.set_facecolor('#3498db'); pc.set_alpha(0.6)
ax.set_xticks(range(len(LAYERS)))
ax.set_xticklabels(violin_labels, fontsize=9)
ax.set_ylabel("Cohen's d")
ax.set_title("A. Effect Size Distribution by Layer", fontweight='bold')
ax.axhline(0, color='black', linestyle='--', linewidth=0.5)
ax.axhline(0.3, color='red', linestyle=':', linewidth=0.5, alpha=0.5)
ax.axhline(-0.3, color='red', linestyle=':', linewidth=0.5, alpha=0.5)

# Panel B: distribution of sign agreement per layer, drawn only when the
# column exists and has data; otherwise a placeholder message is shown.
ax = axes[0, 1]
if 'sign_agreement' in stats_df.columns and stats_df['sign_agreement'].notna().any():
    sa_data = [stats_df[stats_df['layer'] == l]['sign_agreement'].dropna().values for l in LAYERS]
    parts = ax.violinplot(sa_data, positions=range(len(LAYERS)), showmeans=True, showmedians=True)
    for pc in parts['bodies']:
        pc.set_facecolor('#2ecc71'); pc.set_alpha(0.6)
    ax.set_xticks(range(len(LAYERS)))
    ax.set_xticklabels(violin_labels, fontsize=9)
    ax.set_ylabel('Sign Agreement')
    ax.set_title('B. Sign Consistency by Layer', fontweight='bold')
    ax.axhline(0.6, color='red', linestyle='--', linewidth=1, label='Threshold')
    ax.legend()
else:
    ax.text(0.5, 0.5, 'sign_agreement unavailable', ha='center', va='center',
            transform=ax.transAxes, fontsize=12)
    ax.set_title('B. Sign Consistency', fontweight='bold')

# Panel C: per-layer counts of selected features meeting each criterion.
# 'Both + Clean' uses 'feature_significant' when present, otherwise the
# conjunction of 'significant' and 'large_effect'.
ax = axes[1, 0]
sig_rows = []
for layer in LAYERS:
    ls = stats_df[stats_df['layer'] == layer]
    row = {'Layer': f'L{layer}', 'FDR Sig': int(ls['significant'].sum()),
           'Large |d|': int(ls['large_effect'].sum()) if 'large_effect' in ls.columns else 0}
    if 'feature_significant' in ls.columns:
        row['Both + Clean'] = int(ls['feature_significant'].sum())
    else:
        row['Both + Clean'] = int((ls['significant'] & ls['large_effect']).sum())
    sig_rows.append(row)
sig_df = pd.DataFrame(sig_rows)
# Three bars per layer, each of width w, centred on the layer's tick.
x = np.arange(len(LAYERS)); w = 0.25
ax.bar(x - w, sig_df['FDR Sig'], w, label='FDR Significant', color='#3498db')
ax.bar(x, sig_df['Large |d|'], w, label='Large |d| (>=0.3)', color='#e74c3c')
ax.bar(x + w, sig_df['Both + Clean'], w, label='Sig + Large + Clean', color='#2ecc71')
ax.set_xticks(x); ax.set_xticklabels(sig_df['Layer'], fontsize=9)
ax.set_ylabel('# Features')
ax.set_title('C. Selection Criteria by Layer', fontweight='bold')
ax.legend(fontsize=8)

ax = axes[1, 1]
if HAS_FUNNEL:
    noisy_pct = [funnel_df[funnel_df['layer'] == l]['pct_noisy'].mean() for l in LAYERS]
    colors_bars = ['#e67e22' if l in IT_LAYERS else '#e74c3c' for l in LAYERS]
    ax.bar([f'L{l}' for l in LAYERS], noisy_pct, color=colors_bars, edgecolor='black')
    for i, p in enumerate(noisy_pct):
        ax.annotate(f'{p:.1f}%', (i, p), ha='center', va='bottom', fontsize=9)
    ax.set_ylabel('% Noisy (all 16k features)')
    ax.set_title('D. Pre-Selection Noise Rate', fontweight='bold')
    ax.set_ylim(0, max(noisy_pct) * 1.3 + 1)
elif 'is_noisy' in stats_df.columns:
    noisy_pct = [stats_df[stats_df['layer'] == l]['is_noisy'].mean() * 100 for l in LAYERS]
    colors_bars = ['#e67e22' if l in IT_LAYERS else '#e74c3c' for l in LAYERS]
    ax.bar([f'L{l}' for l in LAYERS], noisy_pct, color=colors_bars, edgecolor='black')
    for i, p in enumerate(noisy_pct):
        ax.annotate(f'{p:.1f}%', (i, p), ha='center', va='bottom', fontsize=9)
    ax.set_ylabel('% Noisy (selected features only)')
    ax.set_title('D. Noise Rate (selected only)', fontweight='bold')
    ax.set_ylim(0, max(max(noisy_pct) * 1.3, 5))
else:
    ax.text(0.5, 0.5, 'Noise data unavailable', ha='center', va='center',
            transform=ax.transAxes, fontsize=12)
    ax.set_title('D. Noise Rate', fontweight='bold')

plt.suptitle('Feature Quality Analysis — Gemma 2 9B', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig(FIGS / 'fig_feature_quality.png')
plt.savefig(FIGS / 'fig_feature_quality.pdf')
plt.close()
print("  Saved: fig_feature_quality")


# --- Figure 5: Domain overlap ---
#
# One heatmap per layer showing, for every pair of domains, the Jaccard
# overlap of the selected feature sets averaged over demographics.

print("Figure 5: Domain overlap")

nrows, ncols = make_2x4_grid(LAYERS)
fig, axes = plt.subplots(nrows, ncols, figsize=(4.5 * ncols, 4 * nrows))
axes_flat = np.atleast_1d(axes).flatten()

for ax_idx, layer in enumerate(LAYERS):
    ax = axes_flat[ax_idx]
    layer_str = str(layer)  # selected_features is keyed by the layer as a string
    if layer_str not in selected_features:
        ax.set_title(f'{layer_label(layer)}\n(no data)')
        continue

    # Keys look like "<demographic>_<domain>"; split on the first underscore.
    layer_data = selected_features[layer_str]
    feat_sets = {}
    for key_str, info in layer_data.items():
        parts = key_str.split('_', 1)
        demo, domain = parts[0], parts[1] if len(parts) > 1 else 'unknown'
        feat_sets[(demo, domain)] = set(info.get('features', []))

    n_dom = len(ALL_DOMAINS)
    jac_sum = np.zeros((n_dom, n_dom))
    jac_cnt = np.zeros((n_dom, n_dom))

    for demo in ALL_DEMOS:
        for (d_i, d1), (d_j, d2) in combs(enumerate(ALL_DOMAINS), 2):
            s1 = feat_sets.get((demo, d1), set())
            s2 = feat_sets.get((demo, d2), set())
            union = s1 | s2
            if not union:
                # Jaccard is undefined when both sets are empty. Skip the pair
                # instead of counting it as zero overlap, so the average is
                # taken over defined values only (same rule as the per-layer
                # mean Jaccard computed for the encoding-gradient figure).
                continue
            jac = len(s1 & s2) / len(union)
            jac_sum[d_i, d_j] += jac
            jac_sum[d_j, d_i] += jac
            jac_cnt[d_i, d_j] += 1
            jac_cnt[d_j, d_i] += 1

    # Cells without any defined pair stay NaN, which seaborn leaves blank
    # rather than displaying a misleading 0.00.
    mask = jac_cnt > 0
    jac_avg = np.full_like(jac_sum, np.nan)
    jac_avg[mask] = jac_sum[mask] / jac_cnt[mask]
    np.fill_diagonal(jac_avg, 1.0)

    sns.heatmap(jac_avg, annot=True, fmt='.2f', cmap='YlOrRd',
                xticklabels=[d[:4].capitalize() for d in ALL_DOMAINS],
                yticklabels=[d[:4].capitalize() for d in ALL_DOMAINS],
                ax=ax, vmin=0, vmax=0.5, linewidths=0.5, cbar_kws={'shrink': 0.7})
    ax.set_title(f'{layer_label(layer)}', fontsize=11)

# Hide grid cells that have no layer assigned.
for ax_idx in range(len(LAYERS), len(axes_flat)):
    axes_flat[ax_idx].set_visible(False)

plt.suptitle('Cross-Domain Feature Overlap (Jaccard) — Gemma 2 9B',
             fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig(FIGS / 'fig_domain_overlap.png')
plt.savefig(FIGS / 'fig_domain_overlap.pdf')
plt.close()
print("  Saved: fig_domain_overlap")

print("\nFigures 1-5 done.")

In [None]:
# Feature extraction analysis — figures 6-10, tables, summary, bundling

# --- Figure 6: Cross-demographic feature profiles ---
#
# For the deepest analysed layer, pick the features that occur under at least
# three demographics, keep the (up to) 15 with the largest mean |Cohen's d|,
# and draw their per-demographic mean d as a diverging heatmap.

print("Figure 6: Cross-demographic feature profiles")

target_layer = max(LAYERS)
ls_target = stats_df.loc[stats_df['layer'] == target_layer]

feat_demo_count = ls_target.groupby('feature_idx')['demographic'].nunique()
multi_feats = feat_demo_count[feat_demo_count >= 3].index.tolist()

if len(multi_feats) == 0:
    print(f"  Skipped: no features in >=3 demographics at L{target_layer}")
else:
    in_multi = ls_target['feature_idx'].isin(multi_feats)
    feat_abs_d = ls_target[in_multi].groupby('feature_idx')['cohens_d'].apply(
        lambda x: x.abs().mean())
    n_top = min(15, len(multi_feats))
    top_cross = feat_abs_d.nlargest(n_top).index.tolist()

    # One row per feature; a demographic with no rows for that feature is 0.
    profile_rows = []
    for feat in top_cross:
        fs = ls_target[ls_target['feature_idx'] == feat]
        row = {}
        for demo in ALL_DEMOS:
            demo_rows = fs[fs['demographic'] == demo]
            if len(demo_rows) > 0:
                row[demo] = demo_rows['cohens_d'].mean()
            else:
                row[demo] = 0
        profile_rows.append(row)

    profile_matrix = pd.DataFrame(profile_rows, index=top_cross)[ALL_DEMOS]

    fig_height = max(4, len(top_cross) * 0.5 + 1)
    fig, ax = plt.subplots(figsize=(8, fig_height))
    # Symmetric colour range, never tighter than +/-3.
    largest_abs = float(np.nanmax(np.abs(profile_matrix.values)))
    vmax = max(3.0, largest_abs)
    demo_ticks = [d.capitalize() for d in ALL_DEMOS]
    sns.heatmap(profile_matrix, annot=True, fmt='.1f', cmap='RdBu_r', center=0,
                vmin=-vmax, vmax=vmax, linewidths=0.5,
                xticklabels=demo_ticks,
                ax=ax, cbar_kws={'label': "Cohen's d"})
    ax.set_title(f'Cross-Demographic Features — Gemma 2 9B (Layer {target_layer})\n'
                 f'Features in >=3 demographics, ranked by mean |d|', fontweight='bold')
    ax.set_ylabel('Feature')
    plt.tight_layout()
    for ext in ('png', 'pdf'):
        plt.savefig(FIGS / f'fig_feature_profiles.{ext}')
    plt.close()
    print(f"  Saved: fig_feature_profiles ({len(top_cross)} features)")


# --- Figure 7: Encoding vs behavioral effect ---
#
# For every (layer, demographic, domain) cell present in both stats_df and
# behav_df, pair the mean |Cohen's d| of its features with the absolute
# behavioral effect, then scatter the pairs separately for layers <= 20 and
# layers > 20 with a least-squares trend line and Pearson/Spearman statistics.

print("Figure 7: Encoding vs behavioral effect")

corr_rows = []
for layer in LAYERS:
    ls = stats_df[stats_df['layer'] == layer]
    lb = behav_df[behav_df['layer'] == layer]
    for demo in ALL_DEMOS:
        for domain in ALL_DOMAINS:
            feat_d = ls[(ls['demographic'] == demo) & (ls['domain'] == domain)]['cohens_d'].abs()
            brow = lb[(lb['demographic'] == demo) & (lb['domain'] == domain)]
            if len(feat_d) > 0 and len(brow) > 0:
                corr_rows.append({
                    'layer': layer, 'demographic': demo, 'domain': domain,
                    'mean_abs_d': feat_d.mean(),
                    # Only the first behavioral row of the cell is used.
                    'abs_behavioral': abs(brow['effect'].iloc[0]),
                    'selection_method': brow['selection_method'].iloc[0],
                    'depth': LAYER_DEPTH.get(layer, layer),
                })
corr_df = pd.DataFrame(corr_rows)

if len(corr_df) >= 5:
    early_layers = [l for l in LAYERS if l <= 20]
    late_layers = [l for l in LAYERS if l > 20]
    fig, axes = plt.subplots(1, 2, figsize=(14, 6))

    for ax, layer_group, title in zip(axes,
                                       [early_layers, late_layers],
                                       ['Early-Mid Layers (L5-L20)', 'Late Layers (L27-L36)']):
        # Drop cells with a missing coordinate before any statistics: a single
        # NaN turns pearsonr/spearmanr into NaN and makes np.polyfit fail.
        lc = corr_df[corr_df['layer'].isin(layer_group)].dropna(
            subset=['mean_abs_d', 'abs_behavioral'])
        if len(lc) < 3:
            ax.set_title(f'{title}\n(insufficient data)')
            continue

        scatter = ax.scatter(lc['mean_abs_d'], lc['abs_behavioral'],
                             c=lc['depth'], cmap='viridis', s=50, alpha=0.7,
                             edgecolors='black', linewidth=0.5)
        plt.colorbar(scatter, ax=ax, label='Depth %', shrink=0.8)

        # Overlay a red cross on cells whose features came from a fallback
        # selection method.
        fb = lc[lc['selection_method'].str.contains('fallback', na=False)]
        has_fallback = len(fb) > 0
        if has_fallback:
            ax.scatter(fb['mean_abs_d'], fb['abs_behavioral'],
                       marker='x', color='red', s=80, linewidths=2, label='Fallback')

        r, p_r = pearsonr(lc['mean_abs_d'], lc['abs_behavioral'])
        rho, _ = spearmanr(lc['mean_abs_d'], lc['abs_behavioral'])
        z = np.polyfit(lc['mean_abs_d'], lc['abs_behavioral'], 1)
        x_line = np.linspace(lc['mean_abs_d'].min(), lc['mean_abs_d'].max(), 50)
        ax.plot(x_line, np.polyval(z, x_line), 'k--', alpha=0.5, linewidth=1)

        ax.set_xlabel("Mean |Cohen's d|")
        ax.set_ylabel('|Behavioral Effect|')
        sig = 'p<.001' if p_r < 0.001 else f'p={p_r:.3f}'
        ax.set_title(f'{title}\nr={r:.2f} ({sig}), rho={rho:.2f}', fontweight='bold')
        # 'Fallback' is the only labelled artist; without it a legend call has
        # nothing to show and only triggers matplotlib's "no artists" notice.
        if has_fallback:
            ax.legend(fontsize=8)

    plt.suptitle('Encoding Strength vs Behavioral Effect — Gemma 2 9B',
                 fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig(FIGS / 'fig_encoding_vs_behavior.png')
    plt.savefig(FIGS / 'fig_encoding_vs_behavior.pdf')
    plt.close()
    print("  Saved: fig_encoding_vs_behavior")
else:
    print("  Skipped: insufficient data")


# --- Figure 8: Encoding by domain ---
#
# One heatmap per layer: rows are domains, columns are demographics, and each
# cell holds the mean absolute Cohen's d of the matching stats_df rows.

print("Figure 8: Encoding by domain")

nrows, ncols = make_2x4_grid(LAYERS)
fig, axes = plt.subplots(nrows, ncols, figsize=(4.5 * ncols, 4 * nrows))
axes_flat = np.atleast_1d(axes).flatten()

for ax_idx, layer in enumerate(LAYERS):
    ax = axes_flat[ax_idx]
    ls = stats_df[stats_df['layer'] == layer]
    abs_mean_table = ls.pivot_table(values='cohens_d', index='domain',
                                    columns='demographic',
                                    aggfunc=lambda x: x.abs().mean())
    # Force a fixed row/column order; absent combinations become NaN.
    pivot = abs_mean_table.reindex(index=ALL_DOMAINS, columns=ALL_DEMOS)
    # Colour scale tops out at the largest cell, but never below 1.0.
    vmax = max(1.0, float(np.nanmax(pivot.values)))
    domain_ticks = [d[:4].capitalize() for d in ALL_DOMAINS]
    sns.heatmap(pivot, annot=True, fmt='.2f', cmap='YlOrRd', xticklabels=DEMO_SHORT,
                yticklabels=domain_ticks,
                ax=ax, vmin=0, vmax=vmax, linewidths=0.5, cbar_kws={'shrink': 0.7})
    ax.set_title(f'{layer_label(layer)}', fontsize=11)
    # Only the first column of the grid carries the y-axis label.
    is_first_column = ax_idx % ncols == 0
    ax.set_ylabel('Domain' if is_first_column else '')

# Blank out grid cells beyond the last layer.
for ax_idx in range(len(LAYERS), len(axes_flat)):
    axes_flat[ax_idx].set_visible(False)

plt.suptitle("Mean |Cohen's d| by Domain x Demographic — Gemma 2 9B",
             fontsize=14, fontweight='bold')
plt.tight_layout()
for ext in ('png', 'pdf'):
    plt.savefig(FIGS / f'fig_encoding_by_domain.{ext}')
plt.close()
print("  Saved: fig_encoding_by_domain")


# --- Figure 9: Selection method diagnostics ---
#
# 2x2 figure:
#   A. number of demographic x domain pairs per selection method, by layer
#   B. selection method per demographic x domain at the deepest layer
#   C. share of individually significant features at the deepest layer
#   D. share of features failing each noise filter at the deepest layer

print("Figure 9: Selection method diagnostics")

# Panels A and B read 'selection_method' from behav_df, so that is the frame
# whose column must exist for this figure to be drawable.
if 'selection_method' in behav_df.columns:
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    # Panel A: grouped bars of method counts per layer.
    ax = axes[0, 0]
    method_data = []
    for layer in LAYERS:
        lb = behav_df[behav_df['layer'] == layer]
        counts = lb['selection_method'].value_counts()
        method_data.append({
            'Layer': f'L{layer}',
            'Significant': counts.get('significant', 0),
            'Fallback (filtered)': counts.get('fallback_noise_filtered', 0),
            'Fallback (unfiltered)': counts.get('fallback_unfiltered', 0),
        })
    method_df = pd.DataFrame(method_data)
    x = np.arange(len(LAYERS))
    w = 0.25
    ax.bar(x - w, method_df['Significant'], w, label='Significant', color='#2ecc71')
    ax.bar(x, method_df['Fallback (filtered)'], w, label='Fallback (filtered)', color='#f39c12')
    ax.bar(x + w, method_df['Fallback (unfiltered)'], w, label='Fallback (unfiltered)', color='#e74c3c')
    ax.set_xticks(x)
    ax.set_xticklabels(method_df['Layer'], fontsize=9)
    ax.set_ylabel('# Demo x Domain Pairs')
    ax.set_title('A. Selection Method by Layer', fontweight='bold')
    ax.legend(fontsize=8)

    # Panel B: categorical map of the method used for each cell.
    ax = axes[0, 1]
    target_layer = max(LAYERS)
    lb_target = behav_df[behav_df['layer'] == target_layer]
    method_pivot = lb_target.pivot_table(values='selection_method', index='demographic',
                                          columns='domain', aggfunc='first'
                                          ).reindex(index=ALL_DEMOS, columns=ALL_DOMAINS)
    method_numeric = method_pivot.copy()
    method_map = {'significant': 0, 'fallback_noise_filtered': 1, 'fallback_unfiltered': 2}
    # Missing or unrecognised methods stay NaN and are masked below. Filling
    # them with -1 would be clamped by vmin=0 to the first colour, painting
    # absent cells green as if they were 'significant'.
    for col in method_numeric.columns:
        method_numeric[col] = method_numeric[col].map(method_map)
    method_numeric = method_numeric.astype(float)
    from matplotlib.colors import ListedColormap
    cmap_method = ListedColormap(['#2ecc71', '#f39c12', '#e74c3c'])
    sns.heatmap(method_numeric, mask=method_numeric.isna().to_numpy(),
                annot=method_pivot.values, fmt='',
                cmap=cmap_method, vmin=0, vmax=2, linewidths=0.5, ax=ax,
                xticklabels=[d[:4].capitalize() for d in ALL_DOMAINS],
                yticklabels=[d[:4].capitalize() for d in ALL_DEMOS], cbar=False)
    ax.set_title(f'B. Selection Method Map (L{target_layer})', fontweight='bold')

    # Panel C: percentage of rows flagged feature_significant per cell.
    ax = axes[1, 0]
    if 'feature_significant' in stats_df.columns:
        ls_target = stats_df[stats_df['layer'] == target_layer]
        fs_rate = ls_target.pivot_table(values='feature_significant', index='domain',
                                         columns='demographic', aggfunc='mean'
                                         ).reindex(index=ALL_DOMAINS, columns=ALL_DEMOS) * 100
        sns.heatmap(fs_rate, annot=True, fmt='.0f', cmap='YlGn', xticklabels=DEMO_SHORT,
                    yticklabels=[d[:4].capitalize() for d in ALL_DOMAINS],
                    ax=ax, vmin=0, vmax=100, linewidths=0.5,
                    cbar_kws={'label': '% feature_significant'})
        ax.set_title(f'C. % Individually Significant (L{target_layer})', fontweight='bold')
    else:
        ax.text(0.5, 0.5, 'feature_significant\nnot available',
                ha='center', va='center', transform=ax.transAxes, fontsize=12)

    # Panel D: noise-filter failure rates. With the funnel table the rates are
    # relative to all features of the layer; otherwise they are computed over
    # the rows of stats_df only.
    ax = axes[1, 1]
    if HAS_FUNNEL:
        lf = funnel_df[funnel_df['layer'] == target_layer]
        n_total = lf['n_total_features'].iloc[0] if len(lf) > 0 else 16384
        # Funnel columns count features that PASS a filter, so failures are
        # the complement with respect to n_total.
        avg_sign_fail = (n_total - lf['n_sign_consistent'].mean()) if 'n_sign_consistent' in lf.columns else 0
        avg_cv_fail = (n_total - lf['n_low_cv'].mean()) if 'n_low_cv' in lf.columns else 0
        avg_outlier_fail = (n_total - lf['n_few_outliers'].mean()) if 'n_few_outliers' in lf.columns else 0
        avg_any_fail = lf['n_noisy'].mean() if 'n_noisy' in lf.columns else 0
        labels = ['Sign < 0.6', 'CV >= 3.0', 'Outlier >= 0.1', 'Any (noisy)']
        values = [avg_sign_fail, avg_cv_fail, avg_outlier_fail, avg_any_fail]
        pcts = [v / n_total * 100 for v in values]
        colors_bar = ['#3498db', '#e74c3c', '#f39c12', '#8e44ad']
        bars = ax.barh(labels, pcts, color=colors_bar, edgecolor='black')
        for bar, pct in zip(bars, pcts):
            ax.text(bar.get_width() + 0.5, bar.get_y() + bar.get_height() / 2,
                    f'{pct:.1f}%', va='center', fontsize=9)
        ax.set_xlabel(f'% of All {n_total:,} Features')
        ax.set_title(f'D. Noise Filter Failures (L{target_layer}, pre-selection)', fontweight='bold')
        ax.set_xlim(0, max(pcts) * 1.3 + 5)
    elif 'sign_agreement' in stats_df.columns:
        ls_t = stats_df[stats_df['layer'] == target_layer]
        n_total = len(ls_t)
        failed = [
            (ls_t['sign_agreement'] < 0.60).sum(),
            (ls_t['cv'] >= 3.0).sum() if ls_t['cv'].notna().any() else 0,
            (ls_t['outlier_ratio'] >= 0.10).sum(),
            ls_t['is_noisy'].sum() if 'is_noisy' in ls_t.columns else 0,
        ]
        labels = ['Sign < 0.6', 'CV >= 3.0', 'Outlier >= 0.1', 'Any (noisy)']
        pcts = [v / n_total * 100 if n_total > 0 else 0 for v in failed]
        colors_bar = ['#3498db', '#e74c3c', '#f39c12', '#8e44ad']
        bars = ax.barh(labels, pcts, color=colors_bar, edgecolor='black')
        for bar, pct in zip(bars, pcts):
            ax.text(bar.get_width() + 0.5, bar.get_y() + bar.get_height() / 2,
                    f'{pct:.1f}%', va='center', fontsize=9)
        ax.set_xlabel('% of Selected Features (post-selection)')
        ax.set_title(f'D. Noise Failures (L{target_layer}, selected only)', fontweight='bold')
        ax.set_xlim(0, max(pcts) * 1.3 + 5)
    else:
        ax.text(0.5, 0.5, 'Noise columns unavailable', ha='center', va='center',
                transform=ax.transAxes, fontsize=12)

    plt.suptitle('Selection Method & Quality Diagnostics — Gemma 2 9B',
                 fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig(FIGS / 'fig_selection_diagnostics.png')
    plt.savefig(FIGS / 'fig_selection_diagnostics.pdf')
    plt.close()
    print("  Saved: fig_selection_diagnostics")
else:
    print("  Skipped: selection_method column not available")


# --- Figure 10: Cross-layer encoding gradient ---
#
# Three panels over layer depth: (A) mean |mean_diff| per demographic,
# (B) number of distinct feature indices per layer, (C) mean cross-domain
# Jaccard overlap of the selected feature sets per layer.

print("Figure 10: Cross-layer encoding gradient")

fig, axes = plt.subplots(1, 3, figsize=(16, 5))
# x-positions: depth taken from LAYER_DEPTH, falling back to the layer index.
depths = [LAYER_DEPTH.get(l, l) for l in LAYERS]

# Panel A: one line per demographic.
ax = axes[0]
for demo in ALL_DEMOS:
    sensitivity = [stats_df[(stats_df['layer'] == l) & (stats_df['demographic'] == demo)
                            ]['mean_diff'].abs().mean() for l in LAYERS]
    ax.plot(depths, sensitivity, marker='o', linewidth=2, markersize=7,
            label=demo.capitalize(), color=COLORS[demo])
# Dotted vertical markers at the layers listed in IT_LAYERS.
for layer in IT_LAYERS:
    ax.axvline(LAYER_DEPTH.get(layer, layer), color='gray', linestyle=':', alpha=0.4)
ax.set_xlabel('Depth (%)')
ax.set_ylabel('Mean |Activation Diff|')
ax.set_title('A. SAE Activation Sensitivity', fontweight='bold')
ax.set_xticks(depths)
ax.set_xticklabels([f'L{l}' for l in LAYERS], rotation=45, fontsize=9)
ax.legend(fontsize=8)

# Panel B: distinct feature_idx values per layer; IT_LAYERS bars are orange.
ax = axes[1]
unique_per_layer = [stats_df[stats_df['layer'] == l]['feature_idx'].nunique() for l in LAYERS]
bar_colors = ['#e67e22' if l in IT_LAYERS else '#3498db' for l in LAYERS]
ax.bar([f'L{l}' for l in LAYERS], unique_per_layer, color=bar_colors, edgecolor='black')
for i, v in enumerate(unique_per_layer):
    ax.annotate(str(v), (i, v), ha='center', va='bottom', fontsize=9)
ax.set_ylabel('# Unique Features')
ax.set_title('B. Feature Diversity by Layer', fontweight='bold')

# Panel C: per layer, average the Jaccard overlap of every domain pair within
# each demographic. Pairs whose sets are both empty are skipped; a layer that
# is missing from selected_features, or has no valid pair, is recorded as 0.
ax = axes[2]
jaccard_by_layer = []
for layer in LAYERS:
    layer_str = str(layer)
    if layer_str not in selected_features:
        jaccard_by_layer.append(0)
        continue
    layer_data = selected_features[layer_str]
    feat_sets = {}
    for key_str, info in layer_data.items():
        # Keys are split on the first underscore into (demographic, domain).
        parts = key_str.split('_', 1)
        demo, domain = parts[0], parts[1] if len(parts) > 1 else 'unknown'
        feat_sets[(demo, domain)] = set(info.get('features', []))
    jaccards = []
    for demo in ALL_DEMOS:
        for d1, d2 in combs(ALL_DOMAINS, 2):
            s1 = feat_sets.get((demo, d1), set())
            s2 = feat_sets.get((demo, d2), set())
            if len(s1 | s2) > 0:
                jaccards.append(len(s1 & s2) / len(s1 | s2))
    jaccard_by_layer.append(np.mean(jaccards) if jaccards else 0)

ax.plot(depths, jaccard_by_layer, 'o-', color='#8e44ad', linewidth=2.5, markersize=9)
for layer in IT_LAYERS:
    ax.axvline(LAYER_DEPTH.get(layer, layer), color='gray', linestyle=':', alpha=0.4)
# The depth/overlap correlation is only reported with at least four layers.
if len(depths) >= 4:
    r, p = pearsonr(depths, jaccard_by_layer)
    ax.set_title(f'C. Domain Specialisation with Depth\nr={r:.2f}, p={p:.3f}', fontweight='bold')
else:
    ax.set_title('C. Domain Specialisation with Depth', fontweight='bold')
ax.set_xlabel('Depth (%)')
ax.set_ylabel('Mean Cross-Domain Jaccard')
ax.set_xticks(depths)
ax.set_xticklabels([f'L{l}' for l in LAYERS], rotation=45, fontsize=9)

plt.suptitle('Cross-Layer Encoding Gradient — Gemma 2 9B', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig(FIGS / 'fig_encoding_gradient.png')
plt.savefig(FIGS / 'fig_encoding_gradient.pdf')
plt.close()
print("  Saved: fig_encoding_gradient")


# --- Tables ---
#
# Table 1 flattens selected_features into one row per
# (layer, demographic, domain) group, writes the detail as CSV, and exports a
# per-layer mean summary plus a per-layer selection-method count as LaTeX.

print("\nTable 1: Feature selection summary")

sel_rows = []
for layer_str in sorted(selected_features.keys(), key=int):
    layer = int(layer_str)
    for key_str, info in selected_features[layer_str].items():
        # Keys are split on the first underscore into (demographic, domain).
        parts = key_str.split('_', 1)
        demo, domain = parts[0], parts[1] if len(parts) > 1 else 'unknown'
        sel_rows.append({
            'Layer': layer,
            'Depth': f"{LAYER_DEPTH.get(layer, '?')}%",
            'SAE': 'IT' if layer in IT_LAYERS else 'PT',
            'Demographic': demo,
            'Domain': domain,
            'N Selected': len(info.get('features', [])),
            'N Significant': info.get('n_significant', 0),
            'N Candidates': info.get('n_candidates', 0),
            'N Not Noisy': info.get('n_not_noisy', 0),
            'N Pairs': info.get('n_pairs', 0),
            'Behav Effect': info.get('behavioral_effect', 0),
            'Behav p': info.get('behavioral_p', None),
            'Mean |d| top': np.mean(np.abs(info.get('cohens_d', [0]))),
            'Selection Method': info.get('selection_method', 'unknown'),
        })

sel_df = pd.DataFrame(sel_rows)

layer_summary = sel_df.groupby(['Layer', 'Depth', 'SAE']).agg({
    'N Significant': 'mean', 'N Candidates': 'mean', 'N Not Noisy': 'mean',
    'N Pairs': 'mean', 'Mean |d| top': 'mean',
}).round(2)

print(layer_summary)
# escape=True: the 'Depth' index holds strings such as "12%", and an
# unescaped '%' starts a LaTeX comment that swallows the rest of the row.
with open(TABLES / 'table_selection_summary.tex', 'w') as f:
    f.write(layer_summary.to_latex(escape=True))
sel_df.to_csv(DATA / 'selection_detail.csv', index=False)

if 'Selection Method' in sel_df.columns:
    method_summary = sel_df.groupby(['Layer', 'Selection Method']).size().unstack(fill_value=0)
    print("\nSelection method summary:")
    print(method_summary)
    # escape=True: method names such as 'fallback_noise_filtered' become
    # column headers, and a bare '_' is invalid in LaTeX text mode.
    with open(TABLES / 'table_selection_methods.tex', 'w') as f:
        f.write(method_summary.to_latex(escape=True))


# Table 2: for each layer, a demographic x domain grid of the mean behavioral
# effect (rounded to 3 decimals), printed and written as LaTeX.
print("\nTable 2: Behavioral effects")

for layer in LAYERS:
    lb = behav_df.loc[behav_df['layer'] == layer]
    mean_effect = lb.pivot_table(values='effect', index='demographic',
                                 columns='domain', aggfunc='mean')
    # Fixed row/column order; combinations without data show up as NaN.
    pivot = mean_effect.reindex(index=ALL_DEMOS, columns=ALL_DOMAINS).round(3)
    print(f"\nLayer {layer} ({LAYER_DEPTH.get(layer, '?')}%):")
    print(pivot.to_string())
    latex_source = pivot.to_latex(escape=False)
    with open(TABLES / f'table_behavioral_L{layer}.tex', 'w') as f:
        f.write(latex_source)


# --- Summary statistics ---
#
# Collect global and per-layer headline numbers into a dict and dump it as
# JSON (non-serialisable values are stringified by default=str).

print("\nSummary statistics")

summary = dict(
    model='Gemma 2 9B IT',
    n_unique_features_total=int(stats_df['feature_idx'].nunique()),
    n_rows=len(stats_df),
    layers=LAYERS,
    layer_depths=LAYER_DEPTH,
    it_sae_layers=list(IT_LAYERS),
)

for layer in LAYERS:
    ls = stats_df[stats_df['layer'] == layer]
    lb = behav_df[behav_df['layer'] == layer]
    has_behav = len(lb) > 0

    layer_info = {}
    layer_info['depth_pct'] = LAYER_DEPTH.get(layer, '?')
    layer_info['sae_type'] = 'IT' if layer in IT_LAYERS else 'PT'
    layer_info['n_features'] = int(ls['feature_idx'].nunique())
    layer_info['n_significant'] = int(ls['significant'].sum())
    layer_info['mean_abs_d'] = round(float(ls['cohens_d'].abs().mean()), 3)
    # Behavioral numbers are None when the layer has no behavioral rows.
    if has_behav:
        layer_info['mean_abs_behav'] = round(float(lb['effect'].abs().mean()), 4)
        layer_info['n_sig_behav'] = int((lb['p_value'].dropna() < 0.05).sum())
    else:
        layer_info['mean_abs_behav'] = None
        layer_info['n_sig_behav'] = None

    # Optional columns contribute only when present.
    if 'large_effect' in ls.columns:
        layer_info['n_large_effect'] = int(ls['large_effect'].sum())
    if 'feature_significant' in ls.columns:
        layer_info['n_feature_significant'] = int(ls['feature_significant'].sum())
    if 'is_noisy' in ls.columns:
        layer_info['n_noisy'] = int(ls['is_noisy'].sum())
        layer_info['pct_noisy'] = round(float(ls['is_noisy'].mean() * 100), 1)

    # Every group is either 'significant' or counted as fallback (this
    # includes groups with a missing or unrecognised selection_method).
    layer_groups = list(selected_features.get(str(layer), {}).values())
    n_sig_groups = sum(1 for v in layer_groups
                       if v.get('selection_method') == 'significant')
    n_fb_groups = len(layer_groups) - n_sig_groups
    layer_info['n_groups_significant'] = n_sig_groups
    layer_info['n_groups_fallback'] = n_fb_groups
    summary[f'layer_{layer}'] = layer_info

with open(DATA / 'extraction_summary.json', 'w') as f:
    json.dump(summary, f, indent=2, default=str)


# --- Bundle ---
#
# Write a README describing the outputs into BUNDLE and zip the BUNDLE
# directory next to itself.

print("\nBundling...")

# Figure counts are taken from what is actually on disk under FIGS.
n_png = len(list(FIGS.glob('*.png')))
n_pdf = len(list(FIGS.glob('*.pdf')))

# NOTE(review): the "(8-Layer)" heading and the file lists below are literal
# text; only the timestamp, layer line and figure counts are interpolated.
readme = f"""# Feature Extraction Results — Gemma 2 9B IT (8-Layer)
# Generated: {pd.Timestamp.now().isoformat()}

## Layers
{', '.join(f'L{l} ({LAYER_DEPTH.get(l, "?")}%{"—IT SAE" if l in IT_LAYERS else ""})' for l in LAYERS)}

## Figures ({n_png} PNG + {n_pdf} PDF)
- fig_encoding_by_layer — Feature x Demographic heatmap (2x4)
- fig_encoding_across_layers — Encoding strength, significant features, candidates
- fig_behavioral_effects — Behavioral EV diff heatmap (2x4)
- fig_feature_quality — Effect size, sign agreement, selection criteria, noise
- fig_domain_overlap — Cross-domain Jaccard (2x4)
- fig_feature_profiles — Cross-demographic features (deepest layer)
- fig_encoding_vs_behavior — Encoding vs behavioral effect
- fig_encoding_by_domain — Mean |d| by Domain x Demographic (2x4)
- fig_selection_diagnostics — Selection method, fallback, significance, noise
- fig_encoding_gradient — Sensitivity, diversity, domain specialisation

## Tables
- table_selection_summary.tex, table_selection_methods.tex
- table_behavioral_L{{...}}.tex

## Data
- feature_stats.csv, behavioral_effects.csv, selected_features.json
- selection_detail.csv, extraction_summary.json
"""

with open(BUNDLE / 'README.md', 'w') as f:
    f.write(readme)

# Produces "<BUNDLE>.zip" whose entries are rooted at the BUNDLE folder name.
shutil.make_archive(str(BUNDLE), 'zip', root_dir=BUNDLE.parent, base_dir=BUNDLE.name)

print(f"  {n_png} figures, {len(list(TABLES.glob('*')))} tables, {len(list(DATA.glob('*')))} data files")
# NOTE(review): the archive name printed here is hard-coded and is not derived
# from BUNDLE — confirm it matches the file written above.
print(f"  Download: feature-extraction-results-gemma.zip")
print("\nDone.")

In [None]:
# Causal validation — configuration, data loading, helper functions
#
# Requires: model, tokenizer, sae_manager from prior cells

import torch
import torch.nn.functional as F
import numpy as np
import pandas as pd
import json
import gc
import shutil
from pathlib import Path
from datetime import datetime
from tqdm.auto import tqdm
from scipy.stats import ttest_ind, ttest_1samp
from statsmodels.stats.multitest import multipletests
from collections import defaultdict
import warnings
warnings.filterwarnings('ignore')

# Disable autograd globally from this point on.
torch.set_grad_enabled(False)

# --- Configuration ---

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Run-size knobs; the consumers live in later cells, so the per-constant
# meanings below are inferred from the names — TODO confirm.
BATCH_SIZE = 8
N_PAIRS_PER_CONDITION = 30
K_VALUES = [5, 10, 20, 50]
STEERING_MULTIPLIER = 2.0
# BASE_MIN_EFFECT is the threshold for a scale whose range equals
# REFERENCE_SCALE; get_scale_normalized_threshold rescales it linearly.
BASE_MIN_EFFECT = 0.3
REFERENCE_SCALE = 10
N_RANDOM_DRAWS = 3
# Presumably the p10_split values to sweep in compute_response_metrics — verify.
P10_SPLITS = [0.0, 0.5, 1.0]

ANALYSIS_LAYERS = [5, 9, 14, 18, 20, 27, 32, 36]
# Layer index -> depth through the model, in percent.
LAYER_DEPTH = {5: 12, 9: 22, 14: 34, 18: 44, 20: 49, 27: 66, 32: 78, 36: 88}
# Layers reported as using an "IT" SAE; all other layers are labelled "PT".
IT_LAYERS = {20}

ALL_DEMOGRAPHICS = ['income', 'age', 'gender', 'education', 'vote']
ALL_DOMAINS = ['climate', 'health', 'digital', 'economy', 'values']
# Canonical domain names; 'other' is folded into 'values'.
DOMAIN_MAP = {
    'other': 'values', 'climate': 'climate', 'health': 'health',
    'digital': 'digital', 'economy': 'economy', 'values': 'values',
}

INPUT_DIR = Path("./outputs_gemma_replication/feature_extraction")
OUTPUT_DIR = Path("./causal_validation_final")
OUTPUT_DIR.mkdir(exist_ok=True)

# Prefer the pipeline's output location, else a copy in the working directory.
AGGREGATE_FEATURES_PATH = Path("./outputs_gemma_replication/feature_extraction/selected_features.json")
if not AGGREGATE_FEATURES_PATH.exists():
    AGGREGATE_FEATURES_PATH = Path("./selected_features.json")

print(f"Layers: {ANALYSIS_LAYERS}")
print(f"K values: {K_VALUES}, Batch: {BATCH_SIZE}, Pairs/cond: {N_PAIRS_PER_CONDITION}")
# Rough estimate assuming about 33 minutes per layer.
print(f"Est. runtime: ~{len(ANALYSIS_LAYERS) * 33}min")

# --- Load data ---

# Use the validation prompts when present; otherwise fall back to the
# selection set (the fallback path itself is not checked for existence).
prompts_path = INPUT_DIR / 'prompts_validation.parquet'
if not prompts_path.exists():
    prompts_path = Path("./outputs_gemma_replication") / 'prompts_selection.parquet'
    print(f"Using selection set (validation not found)")

prompts_df = pd.read_parquet(prompts_path)

# Keep only rows with vocab_idx == 0 when that column exists.
if 'vocab_idx' in prompts_df.columns:
    prompts_df = prompts_df[prompts_df['vocab_idx'] == 0].reset_index(drop=True)

# Canonicalise domain names; values not in DOMAIN_MAP are kept unchanged.
prompts_df['domain'] = prompts_df['domain'].map(DOMAIN_MAP).fillna(prompts_df['domain'])

# Only the first prompt is inspected to decide whether the file is in Llama
# chat format. NOTE(review): .iloc[0] raises on an empty frame.
if prompts_df['prompt'].iloc[0].startswith('<|begin_of_text|>'):
    import re
    print("Converting Llama prompts to Gemma format...")
    def convert_llama_to_gemma(llama_prompt):
        """Re-wrap the user turn of a Llama chat prompt in Gemma turn markers.

        Returns the prompt unchanged when no user turn is found.
        """
        match = re.search(
            r'<\|start_header_id\|>user<\|end_header_id\|>\n\n(.*?)<\|eot_id\|>',
            llama_prompt, re.DOTALL
        )
        if match:
            content = match.group(1)
            return f"<start_of_turn>user\n{content}<end_of_turn>\n<start_of_turn>model\n"
        return llama_prompt
    prompts_df['prompt'] = prompts_df['prompt'].apply(convert_llama_to_gemma)

print(f"Prompts: {len(prompts_df):,}")

# aggregate_features: {layer (int): {group key: value of that group's
# 'features' entry}}; stays empty when the JSON file is missing.
aggregate_features = {}
if AGGREGATE_FEATURES_PATH.exists():
    with open(AGGREGATE_FEATURES_PATH, 'r') as f:
        agg_raw = json.load(f)
    for layer_str, layer_data in agg_raw.items():
        layer_int = int(layer_str)
        aggregate_features[layer_int] = {}
        for key_str, feat_data in layer_data.items():
            aggregate_features[layer_int][key_str] = feat_data['features']
    print(f"Loaded aggregate features: {list(aggregate_features.keys())}")

# Prefer the tokenizer attached to the model; otherwise the `tokenizer` name
# from a prior cell is used as-is.
if hasattr(model, 'tokenizer'):
    tokenizer = model.tokenizer

# A model exposing run_with_cache is treated as a TransformerLens model.
USE_TRANSFORMER_LENS = hasattr(model, 'run_with_cache')
print(f"TransformerLens: {USE_TRANSFORMER_LENS}, Device: {DEVICE}")

# --- Token info ---

def _build_token_info():
    """Map each integer 0..10 to how the tokenizer encodes its decimal string.

    A one-token encoding is stored as ``{'type': 'single', 'token_id': id}``.
    Any other encoding is stored as ``'multi'`` together with the full id list
    and its first two ids.
    """
    table = {}
    for number in range(11):
        ids = tokenizer.encode(str(number), add_special_tokens=False)
        if len(ids) != 1:
            table[number] = {
                'type': 'multi',
                'token_ids': ids,
                'first_token_id': ids[0],
                'second_token_id': ids[1],
            }
            continue
        table[number] = {'type': 'single', 'token_id': ids[0]}
    return table

# Build the lookup once and report every number that is not a single token.
TOKEN_INFO = _build_token_info()
multi_tokens = [v for v in TOKEN_INFO if TOKEN_INFO[v]['type'] == 'multi']
for v in multi_tokens:
    ids = TOKEN_INFO[v]['token_ids']
    print(f"  Token '{v}' is multi-token: {ids}")

# --- Helper functions ---

def compute_response_metrics(logits, scale_min, scale_max, p10_split=0.5):
    """Summarise the answer distribution over a numeric response scale.

    The logits of the tokens for ``scale_min..scale_max`` are softmaxed among
    themselves. When the scale reaches 10 and "10" is a multi-token number,
    only the single-token values up to 9 are softmaxed, and the probability of
    the token "1" is divided between the answers 1 and 10: a share of
    ``p10_split`` goes to 10 and the remainder stays on 1.

    Args:
        logits: tensor indexable by a list of token ids.
        scale_min: lowest answer on the scale.
        scale_max: highest answer on the scale.
        p10_split: share of the "1" token's probability assigned to 10.

    Returns:
        ``(expected_value, max_probability, most_likely_answer)``.
    """
    ten_is_multi = TOKEN_INFO.get(10, {}).get('type') == 'multi'

    if scale_max < 10 or not ten_is_multi:
        # Plain case: every answer on the scale is a single token.
        answer_ids = [TOKEN_INFO[i]['token_id'] for i in range(scale_min, scale_max + 1)]
        probs = F.softmax(logits[answer_ids].float(), dim=0).cpu().numpy()
        values = np.arange(scale_min, scale_max + 1)
        expected = float(np.dot(values, probs))
        return expected, float(probs.max()), int(values[probs.argmax()])

    # "10" cannot be read from one logit: softmax over the single-token answers.
    top_single = min(scale_max, 9)
    answer_ids = [TOKEN_INFO[i]['token_id'] for i in range(scale_min, top_single + 1)]
    single_probs = F.softmax(logits[answer_ids].float(), dim=0).cpu().numpy()
    values = list(range(scale_min, top_single + 1))
    probs_list = list(single_probs)

    if 1 not in values:
        # No "1" on the scale, so 10 receives no probability mass.
        values.append(10)
        probs_list.append(0.0)
    else:
        pos_of_1 = values.index(1)
        p1_raw = probs_list[pos_of_1]
        probs_list[pos_of_1] = p1_raw * (1.0 - p10_split)
        values.append(10)
        probs_list.append(p1_raw * p10_split)

    values = np.array(values)
    probs = np.array(probs_list)
    probs = probs / probs.sum()
    expected = float(np.dot(values, probs))
    return expected, float(probs.max()), int(values[probs.argmax()])


def compute_ev(logits, scale_min, scale_max, p10_split=0.5):
    """Return only the expected value computed by compute_response_metrics."""
    metrics = compute_response_metrics(logits, scale_min, scale_max, p10_split=p10_split)
    return metrics[0]


def get_base_pair_key(pk):
    """Drop everything from the last ``'_v'`` onward; return ``pk`` unchanged if absent."""
    head, separator, _ = pk.rpartition('_v')
    if separator:
        return head
    return pk


def get_scale_normalized_threshold(scale_min, scale_max):
    """Scale ``BASE_MIN_EFFECT`` by the width of the answer scale.

    The threshold is ``BASE_MIN_EFFECT * (scale_max - scale_min) / REFERENCE_SCALE``.
    A non-positive width falls back to the unscaled ``BASE_MIN_EFFECT``.
    """
    width = scale_max - scale_min
    if width > 0:
        return BASE_MIN_EFFECT * width / REFERENCE_SCALE
    return BASE_MIN_EFFECT


def build_pairs(prompts_df, demographic, domain, n_pairs):
    """Assemble up to ``n_pairs`` A/B prompt pairs for one (demographic, domain) cell.

    Rows are grouped by their base pair key (``pair_key`` passed through
    ``get_base_pair_key``), in order of first appearance. A group yields a
    pair only if it has at least one ``value_type == 'value_a'`` row and one
    ``'value_b'`` row; the first row of each kind is used. Scale bounds,
    question id and question type are taken from the A row.

    Args:
        prompts_df: DataFrame with columns ``demographic``, ``domain``,
            ``pair_key``, ``value_type``, ``prompt``, ``scale_min``,
            ``scale_max``, ``question_id`` and optionally ``question_type``.
        demographic: Value to match in the ``demographic`` column.
        domain: Value to match in the ``domain`` column.
        n_pairs: Maximum number of pairs to return.

    Returns:
        List of dicts, one per pair; empty if no row matches.
    """
    in_cell = (prompts_df['demographic'] == demographic) & (prompts_df['domain'] == domain)
    selected = prompts_df[in_cell].copy()
    if len(selected) == 0:
        return []
    selected['base_pk'] = selected['pair_key'].apply(get_base_pair_key)

    pairs = []
    # sort=False keeps groups in order of first appearance.
    for base_pk, group in selected.groupby('base_pk', sort=False):
        if len(pairs) >= n_pairs:
            break
        side_a = group[group['value_type'] == 'value_a']
        side_b = group[group['value_type'] == 'value_b']
        if len(side_a) == 0 or len(side_b) == 0:
            continue
        first_a = side_a.iloc[0]
        first_b = side_b.iloc[0]
        pairs.append({
            'pair_key': base_pk,
            'prompt_a': first_a['prompt'],
            'prompt_b': first_b['prompt'],
            'scale_min': int(first_a['scale_min']),
            'scale_max': int(first_a['scale_max']),
            'question_id': first_a['question_id'],
            'question_type': first_a.get('question_type', 'unknown'),
            'demographic': demographic,
            'domain': domain,
        })
    return pairs


# --- Batched tokenization ---

def tokenize_batch(prompts):
    """Tokenize prompts and left-pad them into equal-length tensors on ``DEVICE``.

    Args:
        prompts: Non-empty sequence of prompt strings.

    Returns:
        Tuple ``(token_ids, attention_mask, last_positions)`` of long tensors.
        ``token_ids`` and ``attention_mask`` have one row per prompt and
        ``max_len`` columns; the mask is 0 on padding and 1 on real tokens.
        ``last_positions`` holds ``max_len - 1`` for every row, because
        left-padding places each prompt's final token in the last column.

    Raises:
        ValueError: If ``prompts`` is empty.
    """
    if len(prompts) == 0:
        raise ValueError("tokenize_batch() requires at least one prompt")

    encodings = [tokenizer.encode(p, add_special_tokens=True) for p in prompts]
    max_len = max(len(e) for e in encodings)

    # Compare against None explicitly: `pad_token_id or eos_token_id` would
    # throw away a legitimate pad id of 0 and pad with EOS instead.
    pad_id = getattr(tokenizer, 'pad_token_id', None)
    if pad_id is None:
        pad_id = tokenizer.eos_token_id

    padded, attention_masks, last_positions = [], [], []
    for enc in encodings:
        n_pad = max_len - len(enc)
        padded.append([pad_id] * n_pad + enc)
        attention_masks.append([0] * n_pad + [1] * len(enc))
        last_positions.append(max_len - 1)

    return (torch.tensor(padded, device=DEVICE, dtype=torch.long),
            torch.tensor(attention_masks, device=DEVICE, dtype=torch.long),
            torch.tensor(last_positions, device=DEVICE, dtype=torch.long))


# --- Batched baseline ---

def get_baselines_batched(prompts, layer, sae, scale_mins, scale_maxs):
    """Run unmodified forward passes and collect SAE activations and expected values.

    Prompts are processed in chunks of ``BATCH_SIZE``. For each prompt, the
    residual stream at ``blocks.{layer}.hook_resid_post`` is read at the last
    token position and encoded with ``sae``; the logits at the same position
    are converted to an expected answer value with ``compute_ev``.

    Args:
        prompts: List of prompt strings.
        layer: Layer index whose ``hook_resid_post`` activation is read.
        sae: Sparse autoencoder exposing ``encode``.
        scale_mins: Per-prompt lower scale bounds, aligned with ``prompts``.
        scale_maxs: Per-prompt upper scale bounds, aligned with ``prompts``.

    Returns:
        Tuple ``(sae_acts, evs, n_active)`` of lists aligned with ``prompts``:
        the SAE activation vector, the expected value, and the number of
        strictly positive SAE features.
    """
    all_sae_acts, all_evs, all_n_active = [], [], []
    hook_name = f"blocks.{layer}.hook_resid_post"

    for batch_start in range(0, len(prompts), BATCH_SIZE):
        batch_prompts = prompts[batch_start:batch_start + BATCH_SIZE]
        token_ids, attn_mask, last_pos = tokenize_batch(batch_prompts)

        with torch.inference_mode():
            # Pass the padding mask explicitly. Without it, left-pad tokens can
            # be attended to, so a prompt's output would depend on which other
            # prompts happen to share its batch.
            logits, cache = model.run_with_cache(
                token_ids, attention_mask=attn_mask,
                names_filter=lambda n: hook_name in n,
            )
            resid = cache[hook_name]
            for i in range(len(batch_prompts)):
                h = resid[i, last_pos[i], :].float()
                sae_act = sae.encode(h.unsqueeze(0))[0]
                all_sae_acts.append(sae_act)
                all_n_active.append((sae_act > 0).sum().item())
                all_evs.append(compute_ev(
                    logits[i, last_pos[i], :].cpu(),
                    scale_mins[batch_start + i], scale_maxs[batch_start + i]
                ))
            # Drop the cached activations before the next batch is allocated.
            del cache
            torch.cuda.empty_cache()

    return all_sae_acts, all_evs, all_n_active


# --- Intervention spec + batched execution ---

class InterventionSpec:
    """Description of one residual-stream intervention on a single prompt.

    Attributes:
        prompt: Prompt text to run.
        features: Indices of the SAE features to intervene on.
        delta: Per-feature magnitudes, aligned with ``features``.
        mode: ``'add'`` to add the decoded contribution; consumers treat any
            other value as subtraction.
        scale_min / scale_max: Answer-scale bounds used to score the output.
        tag: Optional free-form label.
    """

    __slots__ = ['prompt', 'features', 'delta', 'mode', 'scale_min', 'scale_max', 'tag']

    def __init__(self, prompt, features, delta, mode, scale_min, scale_max, tag=None):
        self.prompt = prompt
        self.features = features
        self.delta = delta
        self.mode = mode
        self.scale_min = scale_min
        self.scale_max = scale_max
        self.tag = tag


def run_interventions_batched(specs, layer, sae):
    """Execute residual-stream interventions in batches and score each result.

    For every spec, the decoder rows of the selected SAE features are weighted
    by ``spec.delta`` and summed into one residual-space vector. During the
    forward pass that vector is added to (``mode == 'add'``) or subtracted
    from (any other mode) the activation at ``blocks.{layer}.hook_resid_post``
    at the last token position of that prompt.

    Args:
        specs: List of ``InterventionSpec`` objects.
        layer: Layer index whose ``hook_resid_post`` activation is modified.
        sae: Sparse autoencoder exposing the decoder matrix ``W_dec``.

    Returns:
        List aligned with ``specs`` holding the expected answer value
        (``compute_ev``) after each intervention.
    """
    results = [None] * len(specs)
    hook_name = f"blocks.{layer}.hook_resid_post"

    for batch_start in range(0, len(specs), BATCH_SIZE):
        batch_specs = specs[batch_start:batch_start + BATCH_SIZE]
        batch_prompts = [s.prompt for s in batch_specs]
        token_ids, attn_mask, last_pos = tokenize_batch(batch_prompts)

        # One residual-space contribution vector and one sign per spec.
        contributions, signs = [], []
        for s in batch_specs:
            feat_idx = s.features.to(DEVICE)
            decoder_weights = sae.W_dec[feat_idx, :].float()
            contrib = (s.delta.float().to(DEVICE).unsqueeze(1) * decoder_weights).sum(dim=0)
            contributions.append(contrib)
            signs.append(1.0 if s.mode == 'add' else -1.0)

        contrib_stack = torch.stack(contributions)
        sign_tensor = torch.tensor(signs, device=DEVICE, dtype=torch.float32)

        def hook_fn(activation, hook):
            # Edit only the last-token residual of each row, computing in
            # float32 and casting back to the activation dtype.
            for i in range(activation.shape[0]):
                pos = last_pos[i]
                h = activation[i, pos, :].float()
                activation[i, pos, :] = (h + sign_tensor[i] * contrib_stack[i]).to(activation.dtype)
            return activation

        with torch.inference_mode():
            with model.hooks(fwd_hooks=[(hook_name, hook_fn)]):
                # Pass the padding mask explicitly. Without it, left-pad tokens
                # can be attended to, so results would depend on batch makeup.
                logits = model(token_ids, attention_mask=attn_mask)

        for i, s in enumerate(batch_specs):
            results[batch_start + i] = compute_ev(
                logits[i, last_pos[i], :].cpu(), s.scale_min, s.scale_max
            )
        torch.cuda.empty_cache()

    return results


# --- Random feature sampling ---

def sample_active_random_features(sae_a, sae_b, k, diff_all, rng=None):
    """Draw ``k`` random features from those active in either prompt.

    A feature counts as active when its activation is positive in ``sae_a``
    or ``sae_b``. If fewer than ``k`` features are active, the draw is made
    from all feature indices instead. Sampling is without replacement.

    Args:
        sae_a: SAE activation vector for prompt A.
        sae_b: SAE activation vector for prompt B.
        k: Number of features to draw.
        diff_all: Per-feature difference vector, indexed by the drawn features.
        rng: Optional ``numpy.random.Generator``; a fresh one is created if None.

    Returns:
        Tuple ``(feature_indices, diff_all[feature_indices])``; the indices are
        a long tensor on ``DEVICE``.
    """
    rng = np.random.default_rng() if rng is None else rng
    either_active = ((sae_a > 0) | (sae_b > 0)).cpu().numpy()
    pool = np.nonzero(either_active)[0]
    if len(pool) < k:
        # Too few active features: fall back to the whole dictionary.
        pool = np.arange(len(either_active))
    n_draw = min(k, len(pool))
    picked = rng.choice(pool, size=n_draw, replace=False)
    features_random = torch.tensor(picked, device=DEVICE, dtype=torch.long)
    return features_random, diff_all[features_random]


# --- Cross-pair reservoir ---

class CrossPairReservoir:
    """Store of (features, delta) pairs keyed by (demographic, domain).

    ``get_cross`` draws an entry from any condition *other* than the one
    requested.
    """

    def __init__(self):
        # Maps (demo, domain) -> list of (features, delta) tensor pairs.
        self.reservoir = {}

    def add(self, demo, domain, features, delta):
        """Append cloned copies of ``features`` and ``delta`` under (demo, domain)."""
        entry = (features.clone(), delta.clone())
        self.reservoir.setdefault((demo, domain), []).append(entry)

    def get_cross(self, demo, domain, k, rng=None):
        """Return the first ``k`` features/deltas of a random entry from another condition.

        Returns ``(None, None)`` when no entry exists outside (demo, domain);
        otherwise both tensors are moved to ``DEVICE``.
        """
        if rng is None:
            rng = np.random.default_rng()
        own_key = (demo, domain)
        candidates = [
            entry
            for key, entries in self.reservoir.items()
            if key != own_key
            for entry in entries
        ]
        if not candidates:
            return None, None
        feat, delta = candidates[rng.integers(len(candidates))]
        return feat[:k].to(DEVICE), delta[:k].to(DEVICE)


def pre_populate_cross_reservoir(prompts_df, layer, sae):
    """Fill a ``CrossPairReservoir`` with top features from every condition.

    For each (demographic, domain) in ``ALL_DEMOGRAPHICS`` x ``ALL_DOMAINS``,
    baselines are computed for both prompts of every pair. Pairs whose
    expected-value gap is below the scale-normalised threshold are ignored.
    For the rest, the ``max(K_VALUES)`` features with the largest absolute
    A-minus-B activation difference are stored with their signed differences.

    Args:
        prompts_df: Prompt table passed through to ``build_pairs``.
        layer: Layer index passed through to ``get_baselines_batched``.
        sae: Sparse autoencoder for that layer.

    Returns:
        The populated ``CrossPairReservoir``.
    """
    print(f"  Pre-populating cross-pair reservoir...")
    reservoir = CrossPairReservoir()
    n_added = 0

    for demo in ALL_DEMOGRAPHICS:
        for domain in ALL_DOMAINS:
            pairs = build_pairs(prompts_df, demo, domain, N_PAIRS_PER_CONDITION)
            if not pairs:
                continue

            lows = [p['scale_min'] for p in pairs]
            highs = [p['scale_max'] for p in pairs]
            acts_a, evs_a, _ = get_baselines_batched(
                [p['prompt_a'] for p in pairs], layer, sae, lows, highs
            )
            acts_b, evs_b, _ = get_baselines_batched(
                [p['prompt_b'] for p in pairs], layer, sae, lows, highs
            )

            for pair, act_a, act_b, ev_a, ev_b in zip(pairs, acts_a, acts_b, evs_a, evs_b):
                gap = ev_a - ev_b
                min_gap = get_scale_normalized_threshold(pair['scale_min'], pair['scale_max'])
                if abs(gap) < min_gap:
                    continue

                diff = (act_a - act_b).float()
                top_idx = torch.argsort(diff.abs(), descending=True)[:max(K_VALUES)]
                reservoir.add(demo, domain, top_idx, diff[top_idx])
                n_added += 1

            torch.cuda.empty_cache()

    print(f"  Cross-pair reservoir: {n_added} entries, {len(reservoir.reservoir)} conditions")
    return reservoir


# Printed once every helper above has been defined.
print("\nSetup complete.")

In [None]:
# Causal validation — main validation loop
#
# Requires: all variables from cell 1

# Banner for the main causal-validation run.
print("=" * 70)
print("RUNNING VALIDATION")
print("=" * 70)

# Accumulators filled by the loop below and turned into DataFrames at the end:
#   all_results      - one row per (valid pair, K, condition, feature_method)
#   top_features_log - one row per valid pair with its five largest-|diff| features
#   exclusion_log    - one row per pair whose baseline effect is below the threshold
all_results = []
top_features_log = []
exclusion_log = []

for layer in ANALYSIS_LAYERS:
    sae_type = 'IT' if layer in IT_LAYERS else 'PT'
    depth = LAYER_DEPTH.get(layer, '?')
    print(f"\n{'='*60}")
    print(f"LAYER {layer} ({depth}% depth, {sae_type} SAE)")
    print(f"{'='*60}")

    t0 = datetime.now()
    sae = sae_manager.load_sae(layer)

    if hasattr(sae, 'threshold') and sae.threshold is not None:
        print(f"  SAE: JumpReLU (threshold mean={float(sae.threshold.mean()):.3f})")

    layer_agg_features = aggregate_features.get(layer, {})
    # NOTE(review): the reservoir pass runs the same baseline forward passes
    # that are repeated per condition below.
    cross_reservoir = pre_populate_cross_reservoir(prompts_df, layer, sae)
    # Per-layer generator; in this loop it is only used for the cross-pair draw.
    rng = np.random.default_rng(SEED + layer)

    for demo in ALL_DEMOGRAPHICS:
        for domain in ALL_DOMAINS:
            pairs = build_pairs(prompts_df, demo, domain, N_PAIRS_PER_CONDITION)
            if not pairs: continue

            agg_key = f"{demo}_{domain}"
            agg_feature_list = layer_agg_features.get(agg_key, None)

            all_prompts_a = [p['prompt_a'] for p in pairs]
            all_prompts_b = [p['prompt_b'] for p in pairs]
            scale_mins = [p['scale_min'] for p in pairs]
            scale_maxs = [p['scale_max'] for p in pairs]

            # Clean (un-intervened) SAE activations and expected values for both sides.
            sae_acts_a, evs_a, n_active_a = get_baselines_batched(
                all_prompts_a, layer, sae, scale_mins, scale_maxs
            )
            sae_acts_b, evs_b, n_active_b = get_baselines_batched(
                all_prompts_b, layer, sae, scale_mins, scale_maxs
            )

            valid_pairs, skipped_pairs = 0, 0

            for pair_idx, pair in enumerate(pairs):
                scale_min, scale_max = pair['scale_min'], pair['scale_max']
                ev_a, ev_b = evs_a[pair_idx], evs_b[pair_idx]
                sa, sb = sae_acts_a[pair_idx], sae_acts_b[pair_idx]
                # Baseline behavioural effect: EV(prompt A) - EV(prompt B).
                effect_ev = ev_a - ev_b

                # Pairs whose baseline effect is below the scale-normalised
                # threshold are logged and skipped.
                threshold = get_scale_normalized_threshold(scale_min, scale_max)
                if abs(effect_ev) < threshold:
                    skipped_pairs += 1
                    exclusion_log.append({
                        'layer': layer, 'demographic': demo, 'domain': domain,
                        'question_id': pair['question_id'],
                        'question_type': pair.get('question_type', 'unknown'),
                        'scale_min': scale_min, 'scale_max': scale_max,
                        'scale_range': scale_max - scale_min,
                        'threshold': threshold, 'effect_ev': effect_ev, 'excluded': True,
                    })
                    continue

                valid_pairs += 1
                # Signed A-minus-B activation difference; features ranked by magnitude.
                diff = (sa - sb).float()
                sorted_idx = torch.argsort(diff.abs(), descending=True)

                top_features_log.append({
                    'layer': layer, 'demographic': demo, 'domain': domain,
                    'pair_key': pair['pair_key'], 'effect_ev': effect_ev,
                    'top_5_features': sorted_idx[:5].cpu().numpy().tolist(),
                    'top_5_deltas': diff[sorted_idx[:5]].cpu().numpy().tolist(),
                })

                # Columns shared by every result row of this pair.
                result_base = {
                    'pair_key': pair['pair_key'], 'demographic': demo, 'domain': domain,
                    'layer': layer, 'question_id': pair['question_id'],
                    'question_type': pair.get('question_type', 'unknown'),
                    'scale_min': scale_min, 'scale_max': scale_max,
                    'ev_a_base': ev_a, 'ev_b_base': ev_b, 'effect_ev_base': effect_ev,
                    'n_active_a': n_active_a[pair_idx], 'n_active_b': n_active_b[pair_idx],
                }

                for K in K_VALUES:
                    top_k = sorted_idx[:K]
                    delta_k = diff[top_k]

                    # `specs` and `labels` are built in lockstep so results can
                    # be zipped back to their condition names after dispatch.
                    specs, labels = [], []

                    # Per-pair top-K features:
                    #   patch         - add delta to prompt B
                    #   steer         - add delta * STEERING_MULTIPLIER to prompt A
                    #   steer reverse - add delta * (-STEERING_MULTIPLIER) to prompt A
                    #   ablate        - subtract delta from prompt A
                    specs.append(InterventionSpec(pair['prompt_b'], top_k, delta_k, 'add', scale_min, scale_max))
                    labels.append('patch_same_perpair')
                    specs.append(InterventionSpec(pair['prompt_a'], top_k, delta_k * STEERING_MULTIPLIER, 'add', scale_min, scale_max))
                    labels.append('steer_same_perpair')
                    specs.append(InterventionSpec(pair['prompt_a'], top_k, delta_k * (-STEERING_MULTIPLIER), 'add', scale_min, scale_max))
                    labels.append('steer_reverse_perpair')
                    specs.append(InterventionSpec(pair['prompt_a'], top_k, delta_k, 'sub', scale_min, scale_max))
                    labels.append('ablate_same_perpair')

                    # Cross-pair control: features/deltas taken from a pair of
                    # a different (demographic, domain) condition.
                    cross_feat, cross_delta = cross_reservoir.get_cross(demo, domain, K, rng=rng)
                    has_cross = cross_feat is not None
                    if has_cross:
                        specs.append(InterventionSpec(pair['prompt_b'], cross_feat, cross_delta, 'add', scale_min, scale_max))
                        labels.append('patch_cross')

                    # Aggregate features: the first K entries of the precomputed
                    # per-condition list, with this pair's own deltas.
                    has_agg = (agg_feature_list is not None and len(agg_feature_list) >= K)
                    if has_agg:
                        agg_feats = torch.tensor(agg_feature_list[:K], device=DEVICE, dtype=torch.long)
                        agg_delta = diff[agg_feats]
                        specs.append(InterventionSpec(pair['prompt_b'], agg_feats, agg_delta, 'add', scale_min, scale_max))
                        labels.append('patch_agg')
                        specs.append(InterventionSpec(pair['prompt_a'], agg_feats, agg_delta * STEERING_MULTIPLIER, 'add', scale_min, scale_max))
                        labels.append('steer_agg')
                        specs.append(InterventionSpec(pair['prompt_a'], agg_feats, agg_delta, 'sub', scale_min, scale_max))
                        labels.append('ablate_agg')

                    # Random controls, N_RANDOM_DRAWS draws each:
                    #   rand_a - random active features with their own deltas
                    #   rand_b - a second random feature set carrying a
                    #            permutation of the top-K deltas
                    # NOTE(review): the rng_a/rng_b seeds do not include demo or
                    # domain, and perm_seed does not include K, so draws repeat
                    # across those dimensions - confirm this is intended.
                    for draw in range(N_RANDOM_DRAWS):
                        rng_a = np.random.default_rng(SEED + layer * 100003 + pair_idx * 1009 + K * 101 + draw * 7)
                        rng_b = np.random.default_rng(SEED + layer * 200003 + pair_idx * 2003 + K * 211 + draw * 13)

                        feat_ra, delta_ra = sample_active_random_features(sa, sb, K, diff, rng=rng_a)
                        feat_rb, _ = sample_active_random_features(sa, sb, K, diff, rng=rng_b)
                        perm_seed = SEED + layer * 300007 + pair_idx * 3001 + draw * 17
                        delta_shuffled = delta_k[torch.randperm(K, generator=torch.Generator().manual_seed(perm_seed))]

                        specs.append(InterventionSpec(pair['prompt_b'], feat_ra, delta_ra, 'add', scale_min, scale_max))
                        labels.append(f'patch_rand_a_{draw}')
                        specs.append(InterventionSpec(pair['prompt_b'], feat_rb, delta_shuffled, 'add', scale_min, scale_max))
                        labels.append(f'patch_rand_b_{draw}')
                        specs.append(InterventionSpec(pair['prompt_a'], feat_ra, delta_ra * STEERING_MULTIPLIER, 'add', scale_min, scale_max))
                        labels.append(f'steer_rand_{draw}')
                        specs.append(InterventionSpec(pair['prompt_a'], feat_ra, delta_ra, 'sub', scale_min, scale_max))
                        labels.append(f'ablate_rand_{draw}')

                    # Dispatch all interventions for this (pair, K) together;
                    # R maps condition label -> post-intervention expected value.
                    ev_results = run_interventions_batched(specs, layer, sae)
                    R = dict(zip(labels, ev_results))

                    # --- Parse results into rows ---

                    # Appends one result row. Reads `result_base` and `K` from
                    # the enclosing scope and is only called within this iteration.
                    def _add(condition, method, recovery=np.nan, reduction=np.nan, increase=np.nan, dir_pres=np.nan):
                        all_results.append({
                            **result_base, 'K': K,
                            'condition': condition, 'feature_method': method,
                            'recovery_ev': recovery, 'reduction_ev': reduction,
                            'increase_ev': increase, 'direction_preserved': dir_pres,
                        })

                    # Patching: recovery = (EV after patching B - EV_B) / (EV_A - EV_B).
                    # 1.0 means patched B reaches A's baseline; 0.0 means no movement.
                    rec = (R['patch_same_perpair'] - ev_b) / effect_ev if effect_ev != 0 else 0
                    _add('patch_same', 'per_pair', recovery=rec)

                    # Random controls are averaged over draws before normalising.
                    rec_ra = (np.mean([R[f'patch_rand_a_{d}'] for d in range(N_RANDOM_DRAWS)]) - ev_b) / effect_ev if effect_ev != 0 else 0
                    _add('patch_random_a', 'active_random', recovery=rec_ra)

                    rec_rb = (np.mean([R[f'patch_rand_b_{d}'] for d in range(N_RANDOM_DRAWS)]) - ev_b) / effect_ev if effect_ev != 0 else 0
                    _add('patch_random_b', 'active_random', recovery=rec_rb)

                    if has_cross:
                        rec_c = (R['patch_cross'] - ev_b) / effect_ev if effect_ev != 0 else 0
                        _add('patch_cross', 'cross_demo_domain', recovery=rec_c)

                    # Steering: increase = (EV after steering A - EV_B) / (EV_A - EV_B) - 1.
                    # 0.0 means A's EV is unchanged; direction_preserved is True
                    # when the steered A-vs-B gap keeps the baseline sign.
                    eff_s = R['steer_same_perpair'] - ev_b
                    _add('steer_same', 'per_pair',
                         increase=(eff_s / effect_ev - 1.0) if effect_ev != 0 else 0,
                         dir_pres=bool(np.sign(eff_s) == np.sign(effect_ev)))

                    eff_r = R['steer_reverse_perpair'] - ev_b
                    _add('steer_reverse', 'per_pair',
                         increase=(eff_r / effect_ev - 1.0) if effect_ev != 0 else 0,
                         dir_pres=bool(np.sign(eff_r) == np.sign(effect_ev)))

                    avg_rs = np.mean([R[f'steer_rand_{d}'] for d in range(N_RANDOM_DRAWS)])
                    eff_rs = avg_rs - ev_b
                    _add('steer_random', 'active_random',
                         increase=(eff_rs / effect_ev - 1.0) if effect_ev != 0 else 0,
                         dir_pres=bool(np.sign(eff_rs) == np.sign(effect_ev)))

                    # Ablation: reduction = 1 - (EV after ablating A - EV_B) / (EV_A - EV_B).
                    # 0.0 means A's EV is unchanged; 1.0 means A dropped to B's baseline.
                    eff_abl = R['ablate_same_perpair'] - ev_b
                    _add('ablate_same', 'per_pair',
                         reduction=1.0 - (eff_abl / effect_ev) if effect_ev != 0 else 0)

                    avg_ra = np.mean([R[f'ablate_rand_{d}'] for d in range(N_RANDOM_DRAWS)])
                    eff_ra = avg_ra - ev_b
                    _add('ablate_random', 'active_random',
                         reduction=1.0 - (eff_ra / effect_ev) if effect_ev != 0 else 0)

                    # Aggregate-feature rows reuse the 'patch_same' / 'steer_same' /
                    # 'ablate_same' condition names with feature_method='aggregate'.
                    if has_agg:
                        rec_ag = (R['patch_agg'] - ev_b) / effect_ev if effect_ev != 0 else 0
                        _add('patch_same', 'aggregate', recovery=rec_ag)

                        eff_ag_s = R['steer_agg'] - ev_b
                        _add('steer_same', 'aggregate',
                             increase=(eff_ag_s / effect_ev - 1.0) if effect_ev != 0 else 0,
                             dir_pres=bool(np.sign(eff_ag_s) == np.sign(effect_ev)))

                        eff_ag_a = R['ablate_agg'] - ev_b
                        _add('ablate_same', 'aggregate',
                             reduction=1.0 - (eff_ag_a / effect_ev) if effect_ev != 0 else 0)

                # Release cached GPU blocks after every tenth valid pair.
                if valid_pairs % 10 == 0:
                    torch.cuda.empty_cache()

            if valid_pairs > 0 or skipped_pairs > 0:
                print(f"  {demo}_{domain}: {valid_pairs} valid, {skipped_pairs} skipped")

    # Free this layer's SAE before the next one is loaded.
    del sae
    torch.cuda.empty_cache()
    gc.collect()
    print(f"\nLayer {layer} ({depth}%) done: {(datetime.now() - t0).total_seconds() / 60:.1f} min")

# Materialise the accumulators for the analysis cell.
results_df = pd.DataFrame(all_results)
top_features_df = pd.DataFrame(top_features_log)
exclusion_df = pd.DataFrame(exclusion_log)
print(f"\nValidation complete: {len(results_df)} result rows")

In [None]:
# Causal validation — analysis, summary, saving
#
# Requires: results_df, top_features_df, exclusion_df from cell 2

print("=" * 70)
print("ANALYZING RESULTS")
print("=" * 70)

# --- Exclusion rates ---

if len(exclusion_df) > 0:
    print("\nExclusion rates by question type:")
    # exclusion_df holds one row per excluded pair *per layer*. Count included
    # pairs at the same granularity: counting unique pair_key across all layers
    # would shrink the denominator and inflate the exclusion rate.
    pair_unit_cols = ['layer', 'demographic', 'domain', 'pair_key']
    can_count_included = len(results_df) > 0 and 'question_type' in results_df.columns
    for qtype in exclusion_df['question_type'].unique():
        n_excl = len(exclusion_df[exclusion_df['question_type'] == qtype])
        if can_count_included:
            included = results_df.loc[results_df['question_type'] == qtype, pair_unit_cols]
            n_incl = len(included.drop_duplicates())
        else:
            n_incl = 0
        total = n_excl + n_incl
        print(f"  {qtype}: {n_excl}/{total} ({n_excl/total*100:.1f}%)" if total > 0 else f"  {qtype}: 0")

# --- Stats helpers ---

def compute_stats(values):
    """Describe a metric column and test its mean against zero.

    Args:
        values: pandas Series; NaN entries are dropped first.

    Returns:
        ``None`` if nothing remains after dropping NaN. Otherwise a dict with
        ``mean``, ``std``, ``median``, ``pct_positive`` (percentage of values
        greater than 0), ``n``, and ``t_stat`` / ``p_value`` from a one-sample
        t-test against 0. The test is run only for ``n >= 3``; otherwise, or
        when the test itself returns NaN, both are ``None``.
    """
    clean = values.dropna()
    n_obs = len(clean)
    if n_obs == 0:
        return None

    if n_obs >= 3:
        t_stat, p_val = ttest_1samp(clean, 0)
    else:
        t_stat = p_val = np.nan

    def _none_if_nan(x):
        return None if np.isnan(x) else float(x)

    return {
        'mean': float(clean.mean()),
        'std': float(clean.std()),
        'median': float(clean.median()),
        'pct_positive': float((clean > 0).mean() * 100),
        'n': int(n_obs),
        't_stat': _none_if_nan(t_stat),
        'p_value': _none_if_nan(p_val),
    }


def compute_comparison(df, cond1, cond2, metric, method='per_pair'):
    """Compare a target condition with a random-control condition on one metric.

    The target sample is ``condition == cond1`` with ``feature_method == method``;
    the control sample is ``condition == cond2`` with
    ``feature_method == 'active_random'`` (``method`` does not apply to it).

    Returns:
        ``None`` if either sample has fewer than 5 non-NaN values. Otherwise a
        dict with the mean difference (``diff``), Welch t-test ``t_stat`` and
        ``p_value``, ``cohens_d`` (mean difference divided by the square root
        of the average of the two sample variances; 0 if that is not
        positive), and the sample sizes ``n1`` / ``n2``.
    """
    target_rows = (df['condition'] == cond1) & (df['feature_method'] == method)
    control_rows = (df['condition'] == cond2) & (df['feature_method'] == 'active_random')
    target = df.loc[target_rows, metric].dropna()
    control = df.loc[control_rows, metric].dropna()
    if min(len(target), len(control)) < 5:
        return None

    t_stat, p_val = ttest_ind(target, control, equal_var=False)
    mean_gap = target.mean() - control.mean()
    pooled_sd = np.sqrt((target.var(ddof=1) + control.var(ddof=1)) / 2)
    if pooled_sd > 0:
        cohens_d = mean_gap / pooled_sd
    else:
        cohens_d = 0
    return {
        'diff': float(mean_gap),
        't_stat': float(t_stat),
        'p_value': float(p_val),
        'cohens_d': float(cohens_d),
        'n1': int(len(target)),
        'n2': int(len(control)),
    }


def apply_fdr_to_comparisons(summary_dict):
    """Apply one Benjamini-Hochberg correction across all collected p-values.

    Two families are pooled into a single correction:
      * ``by_K``: the ``same_vs_random`` comparison of each intervention;
      * ``by_demographic``: the per-intervention stats of each demographic.

    Every dict that contributed a p-value gains ``p_value_fdr`` and
    ``significant_fdr``, and ``summary_dict['fdr_correction']`` records the
    test counts. The input is modified in place and returned; if no p-value
    is found it is returned untouched.
    """
    interventions = ('patching', 'steering', 'ablation')
    # (family, dict holding 'p_value'), in the order the p-values are tested.
    pending = []

    for layer_data in summary_dict['layers'].values():
        for k_data in layer_data.get('by_K', {}).values():
            for intervention in interventions:
                comp = k_data.get(intervention, {}).get('same_vs_random')
                if comp and comp.get('p_value') is not None:
                    pending.append(('by_K', comp))

        for demo_data in layer_data.get('by_demographic', {}).values():
            for intervention in interventions:
                demo_stats = demo_data.get(intervention)
                if demo_stats and demo_stats.get('p_value') is not None:
                    pending.append(('by_demo', demo_stats))

    if not pending:
        return summary_dict

    p_values = [entry['p_value'] for _, entry in pending]
    rejected, p_corrected, _, _ = multipletests(p_values, method='fdr_bh', alpha=0.05)

    for (_, entry), p_adj, is_rejected in zip(pending, p_corrected, rejected):
        entry['p_value_fdr'] = float(p_adj)
        entry['significant_fdr'] = bool(is_rejected)

    summary_dict['fdr_correction'] = {
        'n_tests': len(p_values),
        'n_significant': int(sum(rejected)),
        'n_by_K': sum(1 for family, _ in pending if family == 'by_K'),
        'n_by_demographic': sum(1 for family, _ in pending if family == 'by_demo'),
        'method': 'fdr_bh',
    }
    return summary_dict


# --- Build summary ---

# Run configuration recorded alongside the results; per-layer statistics are
# filled in by the loop below.
summary = {
    'metadata': {
        'timestamp': datetime.now().isoformat(),
        'model': 'gemma-2-9b-it',
        'sae_activation': 'JumpReLU',
        'layers': ANALYSIS_LAYERS,
        'layer_depths': LAYER_DEPTH,
        'it_sae_layers': list(IT_LAYERS),
        'k_values': K_VALUES,
        'base_min_effect': BASE_MIN_EFFECT,
        'steering_multiplier': STEERING_MULTIPLIER,
        'n_random_draws': N_RANDOM_DRAWS,
        'n_pairs_per_condition': N_PAIRS_PER_CONDITION,
        'batch_size': BATCH_SIZE,
        'has_aggregate_features': bool(aggregate_features),
    },
    'layers': {},
}


def _cond_stats(frame, condition, column):
    """compute_stats over `column` for the rows of `frame` in `condition`."""
    return compute_stats(frame[frame['condition'] == condition][column])


for layer in ANALYSIS_LAYERS:
    layer_df = results_df[results_df['layer'] == layer]
    summary['layers'][str(layer)] = {'by_K': {}, 'by_demographic': {}}

    for K in K_VALUES:
        k_df = layer_df[layer_df['K'] == K]
        # "same" conditions are restricted to per-pair features; control
        # conditions are read from the unrestricted K slice.
        k_pp = k_df[k_df['feature_method'] == 'per_pair']

        _steer_same_pp = k_pp[k_pp['condition'] == 'steer_same']
        if len(_steer_same_pp) > 0:
            _direction_pct = float(_steer_same_pp['direction_preserved'].mean() * 100)
        else:
            _direction_pct = None

        k_entry = {
            'patching': {
                'same': _cond_stats(k_pp, 'patch_same', 'recovery_ev'),
                'random_a': _cond_stats(k_df, 'patch_random_a', 'recovery_ev'),
                'random_b': _cond_stats(k_df, 'patch_random_b', 'recovery_ev'),
                'cross': _cond_stats(k_df, 'patch_cross', 'recovery_ev'),
                'same_vs_random': compute_comparison(k_df, 'patch_same', 'patch_random_b', 'recovery_ev'),
            },
            'steering': {
                'same': _cond_stats(k_pp, 'steer_same', 'increase_ev'),
                'reverse': _cond_stats(k_pp, 'steer_reverse', 'increase_ev'),
                'random': _cond_stats(k_df, 'steer_random', 'increase_ev'),
                'same_vs_random': compute_comparison(k_df, 'steer_same', 'steer_random', 'increase_ev'),
                'direction_preserved_pct': _direction_pct,
            },
            'ablation': {
                'same': _cond_stats(k_pp, 'ablate_same', 'reduction_ev'),
                'random': _cond_stats(k_df, 'ablate_random', 'reduction_ev'),
                'same_vs_random': compute_comparison(k_df, 'ablate_same', 'ablate_random', 'reduction_ev'),
            },
        }

        # Aggregate-feature rows exist only when aggregate features were available.
        k_agg = k_df[k_df['feature_method'] == 'aggregate']
        if len(k_agg) > 0:
            k_entry['aggregate'] = {
                'patching': _cond_stats(k_agg, 'patch_same', 'recovery_ev'),
                'steering': _cond_stats(k_agg, 'steer_same', 'increase_ev'),
                'ablation': _cond_stats(k_agg, 'ablate_same', 'reduction_ev'),
            }

        summary['layers'][str(layer)]['by_K'][str(K)] = k_entry

    # Per-demographic breakdown uses the K == 50, per-pair rows only.
    for demo_name in ALL_DEMOGRAPHICS:
        demo_df = layer_df[
            (layer_df['demographic'] == demo_name)
            & (layer_df['K'] == 50)
            & (layer_df['feature_method'] == 'per_pair')
        ]
        if len(demo_df) == 0:
            continue
        summary['layers'][str(layer)]['by_demographic'][demo_name] = {
            'patching': _cond_stats(demo_df, 'patch_same', 'recovery_ev'),
            'steering': _cond_stats(demo_df, 'steer_same', 'increase_ev'),
            'ablation': _cond_stats(demo_df, 'ablate_same', 'reduction_ev'),
            'n_pairs': int(len(demo_df[demo_df['condition'] == 'patch_same'])),
        }

summary = apply_fdr_to_comparisons(summary)

# --- P(10) sensitivity ---

# A question is affected by the P(10) split when its scale reaches 10.
has_scale11 = (results_df['scale_max'] >= 10).any() if len(results_df) > 0 else False
sensitivity_results = {'has_scale11': bool(has_scale11)}

if has_scale11:
    # Numerator and denominator must count the same unit. `affected` is unique
    # per (pair_key, layer), so the total is taken per (pair_key, layer) too;
    # dividing by unique pair_key alone could push the fraction above 100%.
    unit_cols = ['pair_key', 'layer']
    affected = results_df[results_df['scale_max'] >= 10][unit_cols].drop_duplicates()
    n_units = len(results_df[unit_cols].drop_duplicates())
    frac = len(affected) / n_units if n_units > 0 else 0
    sensitivity_results.update({
        'n_affected_pairs': int(len(affected)),
        'splits_tested': P10_SPLITS, 'default_split': 0.5,
        'affected_fraction': float(frac),
    })
    print(f"\nP(10) sensitivity: {frac:.1%} of pairs affected")
else:
    print("\nNo scale-11 questions — P(10) sensitivity N/A")


# --- Print summary ---

for layer in ANALYSIS_LAYERS:
    depth = LAYER_DEPTH.get(layer, '?')
    print(f"\n{'='*60}")
    print(f"LAYER {layer} ({depth}% depth)")
    print(f"{'='*60}")

    # Mean per-pair effect of each intervention, one row per K.
    print(f"\n{'K':<6} {'Patching':<20} {'Steering':<20} {'Ablation':<20}")
    print("-" * 70)
    for K in K_VALUES:
        k_data = summary['layers'][str(layer)]['by_K'][str(K)]
        p = k_data['patching']['same']
        s = k_data['steering']['same']
        a = k_data['ablation']['same']
        # Build each cell before padding it. A conditional cannot go inside
        # the replacement field: everything after the first ':' is read as
        # the format spec, which raises ValueError (or TypeError if the
        # stats dict is None).
        p_txt = f"{p['mean']:+.1%}" if p else "N/A"
        s_txt = f"{s['mean']:+.1%}" if s else "N/A"
        a_txt = f"{a['mean']:+.1%}" if a else "N/A"
        print(f"{K:<6} {p_txt:<20} {s_txt:<20} {a_txt:<20}")

    # The K=50 detail sections are skipped when 50 is not one of the K values.
    k50 = summary['layers'][str(layer)]['by_K'].get('50')
    if k50 is not None:
        if 'aggregate' in k50:
            print(f"\nAggregate vs Per-Pair (K=50):")
            for metric, label in [('patching', 'Patch'), ('steering', 'Steer'), ('ablation', 'Ablate')]:
                pp = k50[metric]['same']
                ag = k50['aggregate'].get(metric)
                print(f"  {label}: per-pair={pp['mean']:+.1%}, agg={ag['mean']:+.1%}" if pp and ag else f"  {label}: N/A")

        if k50['patching']['same_vs_random']:
            comp = k50['patching']['same_vs_random']
            # Fall back to the raw p-value if no FDR-corrected value was stored.
            fdr_p = comp.get('p_value_fdr', comp['p_value'])
            print(f"\n  Same vs Random: d={comp['cohens_d']:.2f}, p_fdr={fdr_p:.4f}")

        if k50['steering'].get('direction_preserved_pct') is not None:
            print(f"  Direction preserved: {k50['steering']['direction_preserved_pct']:.1f}%")

    # Per-demographic table.
    print(f"\n  {'Demo':<12} {'N':<6} {'Patch':<12} {'Steer':<12} {'Ablate':<12}")
    print("  " + "-" * 55)
    for d in ALL_DEMOGRAPHICS:
        if d in summary['layers'][str(layer)]['by_demographic']:
            dd = summary['layers'][str(layer)]['by_demographic'][d]
            n = dd['n_pairs']
            pv = f"{dd['patching']['mean']:+.1%}" if dd['patching'] else "N/A"
            sv = f"{dd['steering']['mean']:+.1%}" if dd['steering'] else "N/A"
            av = f"{dd['ablation']['mean']:+.1%}" if dd['ablation'] else "N/A"
            print(f"  {d:<12} {n:<6} {pv:<12} {sv:<12} {av:<12}")


# --- Save ---

print("\n" + "=" * 70)
print("SAVING")
print("=" * 70)

# Tabular outputs, written without the index column.
for _frame, _csv_name in (
    (results_df, 'validation_results.csv'),
    (top_features_df, 'top_features.csv'),
    (exclusion_df, 'exclusion_log.csv'),
):
    _frame.to_csv(OUTPUT_DIR / _csv_name, index=False)

# The sensitivity block is embedded in the summary and also saved on its own.
# default=str stringifies any value json cannot encode natively.
summary['sensitivity_p10'] = sensitivity_results
for _payload, _json_name in (
    (summary, 'validation_summary.json'),
    (sensitivity_results, 'sensitivity_p10.json'),
):
    with open(OUTPUT_DIR / _json_name, 'w') as f:
        json.dump(_payload, f, indent=2, default=str)

print(f"validation_results.csv ({len(results_df)} rows)")
print(f"top_features.csv ({len(top_features_df)} rows)")
print(f"exclusion_log.csv ({len(exclusion_df)} rows)")
print(f"validation_summary.json")
print(f"sensitivity_p10.json")

# --- Top features for interpretation ---

# Count how often each feature appears in a pair's top-5 list, per
# (layer, demographic).
feature_counts = defaultdict(lambda: defaultdict(int))
for _, row in top_features_df.iterrows():
    key = (row['layer'], row['demographic'])
    for feat in row['top_5_features']:
        feature_counts[key][feat] += 1

# Report the five most frequent features for three representative layers.
for rep_layer in [5, 20, 36]:
    if rep_layer not in ANALYSIS_LAYERS:
        continue
    print(f"\nFrequent top-5 features (L{rep_layer}, {LAYER_DEPTH.get(rep_layer, '?')}%):")
    for d in ALL_DEMOGRAPHICS:
        key = (rep_layer, d)
        if key not in feature_counts:
            continue
        # Stable descending sort: ties keep first-seen order.
        top = sorted(feature_counts[key].items(), key=lambda item: item[1], reverse=True)[:5]
        print(f"  {d}: {', '.join(f'F{f}({c}x)' for f, c in top)}")

# --- Google Drive / zip ---

drive_base = Path("/content/drive/MyDrive")
drive_dest = drive_base / "TACL_Gemma_Results" / "causal_validation"

if not drive_base.exists():
    # No Drive folder present: zip the output directory and offer it as a
    # Colab download, or just print the archive path outside Colab.
    zip_file = shutil.make_archive(str(Path('.') / "gemma_causal_validation"), 'zip', root_dir=str(OUTPUT_DIR))
    try:
        from google.colab import files
        files.download(zip_file)
    except ImportError:
        print(f"Zip: {zip_file}")
else:
    # Drive folder present: copy whichever result files exist.
    drive_dest.mkdir(parents=True, exist_ok=True)
    for fname in ['validation_results.csv', 'top_features.csv', 'exclusion_log.csv',
                   'validation_summary.json', 'sensitivity_p10.json']:
        src = OUTPUT_DIR / fname
        if src.exists():
            shutil.copy2(src, drive_dest / fname)
    print(f"\nSaved to Drive: {drive_dest}")

print("\nDone.")

In [None]:
# Causal validation analysis — configuration and data loading

import pandas as pd
import numpy as np
import json
import ast
import shutil
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import seaborn as sns
from pathlib import Path
from scipy.stats import ttest_rel, pearsonr
from collections import Counter, defaultdict
import warnings
warnings.filterwarnings('ignore')

def safe_parse_list(x):
    """Return *x* evaluated as a Python literal if it is a string, else *x* itself.

    List-valued columns come back from CSV as their string representation;
    this turns such a cell back into a Python object via ``ast.literal_eval``.
    Non-string inputs are passed through untouched.  Malformed strings raise
    whatever ``ast.literal_eval`` raises.
    """
    return ast.literal_eval(x) if isinstance(x, str) else x

# Global matplotlib styling applied to every figure produced below.
plt.style.use('seaborn-v0_8-whitegrid')
plt.rcParams.update({
    'font.family': 'DejaVu Sans', 'font.size': 11,
    'axes.labelsize': 12, 'axes.titlesize': 13, 'legend.fontsize': 10,
    'figure.dpi': 150, 'savefig.dpi': 300, 'savefig.bbox': 'tight',
})

# Model name used in figure titles, and the layers covered by the analysis.
MODEL_NAME = "Gemma 2 9B IT"
ANALYSIS_LAYERS = [5, 9, 14, 18, 20, 27, 32, 36]
# Layer used for the single-layer ("headline") tables and figures.
BEST_LAYER = 36
N_LAYERS_TOTAL = 42

# Layer index -> depth in percent, used for axis positions and labels.
# Each value equals round(100 * layer / (N_LAYERS_TOTAL - 1)).
LAYER_DEPTH = {5: 12, 9: 22, 14: 34, 18: 44, 20: 49, 27: 66, 32: 78, 36: 88}
# Layers flagged with an "IT" tag in labels; figure titles below describe
# these as using an "IT SAE".
IT_LAYERS = {20}

# Coarse depth groupings of ANALYSIS_LAYERS.
EARLY_LAYERS = [5, 9, 14]
MID_LAYERS = [18, 20]
LATE_LAYERS = [27, 32, 36]

def layer_label(l, short=False):
    """Build a human-readable label for layer index *l*.

    Args:
        l: Layer index.
        short: When true return the compact form ``"L20 (IT)"``; otherwise
            the long form ``"Layer 20 (49% (IT))"`` including depth.

    Returns:
        The label string.  Layers in ``IT_LAYERS`` carry an ``" (IT)"`` tag.

    The depth comes from ``LAYER_DEPTH`` when the layer is listed there.  For
    unlisted layers it is computed relative to the last layer index
    (``N_LAYERS_TOTAL - 1``), which is the convention every ``LAYER_DEPTH``
    entry follows, so listed and unlisted layers share one scale.
    """
    depth = LAYER_DEPTH.get(l, l / (N_LAYERS_TOTAL - 1) * 100)
    it_tag = " (IT)" if l in IT_LAYERS else ""
    if short:
        return f"L{l}{it_tag}"
    return f"Layer {l} ({depth:.0f}%{it_tag})"

def layer_label_compact(l):
    """Return a tick-label style name such as ``"L20 (49%)*"`` for layer *l*.

    The depth is looked up in ``LAYER_DEPTH`` (``'?'`` when the layer is not
    listed) and a trailing ``*`` marks layers contained in ``IT_LAYERS``.
    """
    star = "*" if l in IT_LAYERS else ""
    return "L{} ({}%){}".format(l, LAYER_DEPTH.get(l, '?'), star)

# Input locations: validation outputs, their re-analysis subfolder, and the
# feature-extraction outputs.
INPUT_DIR = Path("/content/")
VALIDATION_DIR = INPUT_DIR / "causal_validation_final"
REANALYSIS_DIR = VALIDATION_DIR / "reanalysis"
EXTRACTION_DIR = INPUT_DIR / "outputs_gemma_replication" / "feature_extraction"

# Output bundle: figures, LaTeX tables, derived data and summaries.  The
# directories are created on import of this cell if they do not exist yet.
BUNDLE_DIR = Path("./causal-validation-gemma-results")
BUNDLE_DIR.mkdir(exist_ok=True)
FIG_DIR = BUNDLE_DIR / "figures"
TABLE_DIR = BUNDLE_DIR / "tables"
DATA_DIR = BUNDLE_DIR / "data"
SUMMARY_DIR = BUNDLE_DIR / "summary"
for d in [FIG_DIR, TABLE_DIR, DATA_DIR, SUMMARY_DIR]:
    d.mkdir(exist_ok=True)

# Clipping bounds applied to steering effects by winsorize_steering();
# figure labels refer to this as "Winsor. ±200%".
STEER_CLIP_LO, STEER_CLIP_HI = -2.0, 2.0

# Colour palette keyed by method, control condition, demographic and
# depth group.
COLORS = {
    'patching': '#2ecc71', 'steering': '#3498db', 'ablation': '#e74c3c',
    'random_a': '#7f8c8d', 'random_b': '#bdc3c7', 'random': '#95a5a6',
    'cross': '#f39c12', 'aggregate': '#8e44ad',
    'income': '#1abc9c', 'age': '#9b59b6', 'gender': '#e91e63',
    'education': '#3498db', 'vote': '#f44336',
    'early': '#3498db', 'mid': '#f39c12', 'late': '#e74c3c',
}

print("Loading data...")

# Required inputs: these reads raise if the files are missing.
results_df = pd.read_csv(VALIDATION_DIR / 'validation_results.csv')
top_features_df = pd.read_csv(VALIDATION_DIR / 'top_features.csv')

patched_summary_path = REANALYSIS_DIR / 'validation_summary_patched.json'
old_summary_path = VALIDATION_DIR / 'validation_summary.json'

# Prefer the patched summary from the re-analysis; otherwise fall back to the
# original one.  `using_patched` records which was loaded and selects the
# code path used for the main results table later on.
if patched_summary_path.exists():
    with open(patched_summary_path, 'r') as f:
        summary = json.load(f)
    print(f"  Loaded PATCHED summary (paired tests)")
    using_patched = True
else:
    with open(old_summary_path, 'r') as f:
        summary = json.load(f)
    print(f"  Using OLD summary (patched not found)")
    using_patched = False

# Optional inputs: None when the file is absent ...
paired_comp_path = REANALYSIS_DIR / 'paired_comparisons.csv'
paired_comp_df = pd.read_csv(paired_comp_path) if paired_comp_path.exists() else None

behav_path = EXTRACTION_DIR / 'behavioral_effects.csv'
behavioral_df = pd.read_csv(behav_path) if behav_path.exists() else None

# ... except the exclusion log, which falls back to an empty frame so that
# len(exclusion_df) keeps working.
exclusion_path = VALIDATION_DIR / 'exclusion_log.csv'
exclusion_df = pd.read_csv(exclusion_path) if exclusion_path.exists() else pd.DataFrame()

feature_stats_path = EXTRACTION_DIR / 'feature_stats.csv'
feature_stats_df = pd.read_csv(feature_stats_path) if feature_stats_path.exists() else None

# Capability flags: which optional columns the results table provides.
has_method = 'feature_method' in results_df.columns
has_direction = 'direction_preserved' in results_df.columns

print(f"  Results: {len(results_df)} rows, layers: {sorted(results_df['layer'].unique())}")
print(f"  Feature stats: {'yes' if feature_stats_df is not None else 'no'}")

def per_pair(df):
    """Keep only rows whose ``feature_method`` is ``'per_pair'``.

    When the loaded results have no ``feature_method`` column
    (``has_method`` is false) the frame is returned unchanged.
    """
    if not has_method:
        return df
    is_per_pair = df['feature_method'] == 'per_pair'
    return df[is_per_pair]

def agg_only(df):
    """Keep only rows whose ``feature_method`` is ``'aggregate'``.

    When the loaded results have no ``feature_method`` column
    (``has_method`` is false) an empty DataFrame is returned instead.
    """
    if has_method:
        is_aggregate = df['feature_method'] == 'aggregate'
        return df[is_aggregate]
    return pd.DataFrame()

def winsorize_steering(series):
    """Clip *series* to the interval [STEER_CLIP_LO, STEER_CLIP_HI].

    Returns a new Series; values outside the bounds are replaced by the
    nearest bound.
    """
    lo, hi = STEER_CLIP_LO, STEER_CLIP_HI
    return series.clip(lower=lo, upper=hi)

# Demographic attributes, in the order used for table rows and heatmap axes.
ALL_DEMOS = ['income', 'age', 'gender', 'education', 'vote']
# JSON object keys are strings, so the summary is indexed by str(layer).
best_str = str(BEST_LAYER)

# One entry per intervention:
# (display name, same-pair condition, random-control condition, metric column).
methods = [
    ('Patching', 'patch_same', 'patch_random_a', 'recovery_ev'),
    ('Steering', 'steer_same', 'steer_random', 'increase_ev'),
    ('Ablation', 'ablate_same', 'ablate_random', 'reduction_ev'),
]

# Parallel lists (same order as `methods`) for loops that only need the
# same-pair condition and its metric.
method_conds_list = ['patch_same', 'steer_same', 'ablate_same']
method_metrics_list = ['recovery_ev', 'increase_ev', 'reduction_ev']

print("Setup complete.")

In [None]:
# Causal validation analysis — figures 1-4c

# --- 1. Behavioral effects heatmap ---

print("1. Behavioral effects")

# The results table repeats each pair's baseline columns on many rows;
# de-duplicating on these columns leaves the distinct baseline records.
pairs_df = results_df[['pair_key', 'demographic', 'domain', 'layer',
                        'effect_ev_base', 'ev_a_base', 'ev_b_base']].drop_duplicates()
pairs_best = pairs_df[pairs_df['layer'] == BEST_LAYER]

# Mean baseline effect per (demographic, domain) cell.
pivot_effects = pairs_best.pivot_table(
    values='effect_ev_base', index='demographic', columns='domain', aggfunc='mean'
)

# Diverging colour map centred on zero so the sign of the effect is visible.
fig, ax = plt.subplots(figsize=(10, 6))
sns.heatmap(pivot_effects, annot=True, fmt='.2f', cmap='RdBu_r', center=0,
            linewidths=0.5, ax=ax, cbar_kws={'label': 'Effect (EV_a - EV_b)'})
ax.set_title(f'{MODEL_NAME}: Baseline Demographic Effects by Domain')
ax.set_xlabel('Domain'); ax.set_ylabel('Demographic')
plt.tight_layout()
# Every figure is written as both PNG and PDF, then closed to free memory.
plt.savefig(FIG_DIR / 'fig1_behavioral_effects_heatmap.png')
plt.savefig(FIG_DIR / 'fig1_behavioral_effects_heatmap.pdf')
plt.close()
print("  Saved: fig1_behavioral_effects_heatmap")

# 1b. Exclusion rates
# Only drawn when an exclusion log was loaded and it has a 'scale_range' column.
if len(exclusion_df) > 0 and 'scale_range' in exclusion_df.columns:
    # Per scale range: number of logged rows, and the first threshold value
    # recorded for that range.
    # NOTE(review): 'first' assumes the threshold is constant within a scale
    # range -- confirm against how the log is written.
    scale_groups = exclusion_df.groupby('scale_range').agg(
        n_excluded=('question_id', 'count'), threshold=('threshold', 'first')
    ).reset_index()

    fig, ax = plt.subplots(figsize=(8, 5))
    ax.bar(scale_groups['scale_range'].astype(str), scale_groups['n_excluded'],
           color='#e74c3c', edgecolor='black')
    # After reset_index() the row index i is 0..n-1, which coincides with the
    # x positions of the categorical bars.
    for i, row in scale_groups.iterrows():
        ax.annotate(f"thr={row['threshold']:.3f}", (i, row['n_excluded']),
                    ha='center', va='bottom', fontsize=9)
    ax.set_xlabel('Scale Range'); ax.set_ylabel('Pairs Excluded')
    ax.set_title(f'{MODEL_NAME}: Exclusion Rates by Scale Range')
    plt.tight_layout()
    plt.savefig(FIG_DIR / 'fig1b_exclusion_rates.png')
    plt.savefig(FIG_DIR / 'fig1b_exclusion_rates.pdf')
    plt.close()
    print("  Saved: fig1b_exclusion_rates")


# --- 2. Main causal validation results ---

print("\n2. Main results")

# Used only by the fallback branch below; imported here so this section does
# not depend on imports made in other notebook cells.
from scipy.stats import ttest_ind

# Headline slice: best layer at K=50.  `main_all` keeps every feature-selection
# method (the random controls are read from it); `main_df` keeps only the
# per-pair rows used for the same-pair conditions.
main_all = results_df[(results_df['layer'] == BEST_LAYER) & (results_df['K'] == 50)]
main_df = per_pair(main_all)

main_results = []        # one table row (dict) per method
computed_cohens_d = {}   # method name -> effect size, reused by the bar chart

if using_patched:
    # Patched summary available: report its precomputed paired statistics.
    k50_best = summary['layers'][best_str]['by_K']['50']
    for name, same_cond, rand_cond, metric in methods:
        method_stats = k50_best[name.lower()]
        same_stats = method_stats['same']
        # Patching stores its random control under 'random_a'; steering and
        # ablation store theirs under 'random'.
        rand_key = 'random_a' if name == 'Patching' else 'random'
        comp = method_stats[f'same_vs_{rand_key}']
        rand_stats = method_stats[rand_key]

        # A missing/empty comparison falls back to neutral values.
        d_z = comp['cohens_d_z'] if comp else 0
        p_val = comp['p_value'] if comp else 1
        pct_better = comp['pct_same_better'] if comp else 0
        computed_cohens_d[name] = d_z

        main_results.append({
            'Method': name,
            'Same-Pair': f"{same_stats['mean']:.1%}",
            'Median': f"{same_stats['median']:.1%}",
            'Random': f"{rand_stats['mean']:.1%}" if rand_stats else "N/A",
            "d_z (paired)": f"{d_z:.2f}",
            'p-value': "<1e-10" if p_val < 1e-10 else f"{p_val:.2e}",
            'Same>Rand': f"{pct_better:.0f}%",
            'N': same_stats['n'],
        })
else:
    # Fallback: recompute from the raw rows.  Same-pair and random rows are
    # treated as two independent samples (they are not aligned pair by pair).
    for name, same_cond, rand_cond, metric in methods:
        same = main_df[main_df['condition'] == same_cond][metric].dropna()
        rand = main_all[main_all['condition'] == rand_cond][metric].dropna()
        if 'steer' in same_cond:
            same = winsorize_steering(same); rand = winsorize_steering(rand)

        # Cohen's d with the pooled standard deviation of the two samples;
        # reported as 0 when the pooled SD is zero or undefined.
        pooled_std = np.sqrt((same.var() + rand.var()) / 2)
        d = (same.mean() - rand.mean()) / pooled_std if pooled_std > 0 else 0
        computed_cohens_d[name] = d

        # Report a p-value that was actually computed (Welch's unequal-variance
        # t-test, matching the independent-samples effect size above) instead
        # of a hard-coded string.
        if len(same) > 1 and len(rand) > 1:
            p_val = ttest_ind(same, rand, equal_var=False).pvalue
        else:
            p_val = np.nan
        if np.isnan(p_val):
            p_text = "N/A"
        elif p_val < 1e-10:
            p_text = "<1e-10"
        else:
            p_text = f"{p_val:.2e}"

        main_results.append({
            'Method': name, 'Same-Pair': f"{same.mean():.1%}",
            'Median': f"{same.median():.1%}", 'Random': f"{rand.mean():.1%}",
            "Cohen's d": f"{d:.2f}", 'p-value': p_text, 'N': len(same),
        })

main_results_df = pd.DataFrame(main_results)
print(main_results_df.to_string(index=False))

# escape=False: cell text (e.g. '%', '<') is written to LaTeX verbatim.
with open(TABLE_DIR / 'table1_main_results.tex', 'w') as f:
    f.write(main_results_df.to_latex(index=False, escape=False))

# 2a. Bar chart
# Same-pair effect vs. control conditions for each intervention, one panel
# per method, with standard-error bars.

# The effect size in `computed_cohens_d` is a paired d_z only when the patched
# summary was loaded; the fallback path computes a pooled-SD Cohen's d from
# unpaired samples.  Label the figure according to what was actually computed.
effect_label = 'd_z' if using_patched else "Cohen's d"
test_label = 'Paired Tests' if using_patched else 'Unpaired Comparison'

fig, axes = plt.subplots(1, 3, figsize=(15, 5))

for idx, (name, same_cond, rand_cond, metric) in enumerate(methods):
    ax = axes[idx]
    same = main_df[main_df['condition'] == same_cond][metric].dropna()
    if 'steer' in same_cond: same = winsorize_steering(same)

    if 'patch' in same_cond:
        # Patching has three controls: random-A, cross, and random-B.
        rand_a = main_all[main_all['condition'] == 'patch_random_a'][metric].dropna()
        rand_b = main_all[main_all['condition'] == 'patch_random_b'][metric].dropna()
        cross = main_all[main_all['condition'] == 'patch_cross'][metric].dropna()
        data = [same.mean(), rand_a.mean(), cross.mean(), rand_b.mean()]
        errors = [same.sem(), rand_a.sem(), cross.sem(), rand_b.sem()]
        labels = ['Same-Pair', 'Random-A\n(active+delta)', 'Cross\n(diff d×d)', 'Random-B\n(shuffled Δ)']
        colors = [COLORS['patching'], COLORS['random_a'], COLORS['cross'], COLORS['random_b']]
    elif 'steer' in same_cond:
        # Steering values are clipped on both the same-pair and random side.
        rand = winsorize_steering(main_all[main_all['condition'] == 'steer_random']['increase_ev'].dropna())
        data = [same.mean(), rand.mean()]
        errors = [same.sem(), rand.sem()]
        labels = ['Same-Pair\n(Winsor. ±200%)', 'Random\n(Winsor. ±200%)']
        colors = [COLORS['steering'], COLORS['random']]
    else:
        rand = main_all[main_all['condition'] == rand_cond][metric].dropna()
        data = [same.mean(), rand.mean()]
        errors = [same.sem(), rand.sem()]
        labels = ['Same-Pair', 'Random\n(active+delta)']
        colors = [COLORS['ablation'], COLORS['random']]

    # Values are fractions; plot them as percentages.
    bars = ax.bar(range(len(data)), [d * 100 for d in data],
                  yerr=[e * 100 for e in errors], capsize=5,
                  color=colors, edgecolor='black', linewidth=1)
    ax.set_xticks(range(len(data)))
    ax.set_xticklabels(labels, rotation=15, ha='right', fontsize=9)
    ax.set_ylabel('Effect (%)')
    ax.set_title(f'{name} ({effect_label}={computed_cohens_d.get(name, 0):.2f})')
    ax.axhline(y=0, color='black', linestyle='-', linewidth=0.5)

    # Value labels sit above positive bars and below negative ones.
    for bar, val in zip(bars, data):
        h = bar.get_height()
        ax.annotate(f'{val:.1%}', xy=(bar.get_x() + bar.get_width() / 2, h),
                    xytext=(0, 3 if h >= 0 else -15), textcoords="offset points",
                    ha='center', va='bottom' if h >= 0 else 'top', fontsize=9, fontweight='bold')

plt.suptitle(f'{MODEL_NAME}: Causal Validation (Layer {BEST_LAYER}, K=50, {test_label})',
             fontsize=13, fontweight='bold')
plt.tight_layout()
plt.savefig(FIG_DIR / 'fig2_main_results.png')
plt.savefig(FIG_DIR / 'fig2_main_results.pdf')
plt.close()
print("  Saved: fig2_main_results")

# 2b. Aggregate vs per-pair
# Compares same-pair effects obtained with per-pair feature selection against
# those obtained with aggregate feature selection.  Skipped entirely when the
# results contain no aggregate rows (agg_only returns an empty frame).
agg_main = agg_only(main_all)
if len(agg_main) > 0:
    fig, axes = plt.subplots(1, 3, figsize=(14, 5))
    agg_results = []
    for idx, (name, same_cond, _, metric) in enumerate(methods):
        ax = axes[idx]
        pp_vals = main_df[main_df['condition'] == same_cond][metric].dropna()
        ag_vals = agg_main[agg_main['condition'] == same_cond][metric].dropna()
        if 'steer' in same_cond:
            pp_vals = winsorize_steering(pp_vals); ag_vals = winsorize_steering(ag_vals)

        # Empty samples are plotted as 0 with no error bar.
        data = [pp_vals.mean() if len(pp_vals) > 0 else 0,
                ag_vals.mean() if len(ag_vals) > 0 else 0]
        errors = [pp_vals.sem() if len(pp_vals) > 0 else 0,
                  ag_vals.sem() if len(ag_vals) > 0 else 0]

        # COLORS is keyed by the lower-cased method name.
        bars = ax.bar([0, 1], [d * 100 for d in data], yerr=[e * 100 for e in errors],
                      capsize=5, color=[COLORS[name.lower()], COLORS['aggregate']],
                      edgecolor='black', linewidth=1)
        ax.set_xticks([0, 1]); ax.set_xticklabels(['Per-Pair', 'Aggregate'])
        ax.set_ylabel('Effect (%)'); ax.set_title(name)
        ax.axhline(y=0, color='black', linestyle='-', linewidth=0.5)
        # Value labels sit above positive bars and below negative ones.
        for bar, val in zip(bars, data):
            h = bar.get_height()
            ax.annotate(f'{val:.1%}', xy=(bar.get_x() + bar.get_width() / 2, h),
                        xytext=(0, 3 if h >= 0 else -15), textcoords="offset points",
                        ha='center', va='bottom' if h >= 0 else 'top', fontsize=10, fontweight='bold')
        # Ratio = aggregate / per-pair; "N/A" when the per-pair mean is ~0.
        agg_results.append({'Method': name, 'Per-Pair': f"{data[0]:.1%}", 'Aggregate': f"{data[1]:.1%}",
                            'Ratio': f"{data[1]/data[0]:.2f}" if abs(data[0]) > 1e-6 else "N/A"})

    plt.suptitle(f'{MODEL_NAME}: Per-Pair vs Aggregate (Layer {BEST_LAYER}, K=50)', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig(FIG_DIR / 'fig2b_aggregate_vs_perpair.png')
    plt.savefig(FIG_DIR / 'fig2b_aggregate_vs_perpair.pdf')
    plt.close()
    with open(TABLE_DIR / 'table1b_aggregate_vs_perpair.tex', 'w') as f:
        f.write(pd.DataFrame(agg_results).to_latex(index=False, escape=False))
    print("  Saved: fig2b_aggregate_vs_perpair")


# --- 3. Dose-response ---
# Effect size as a function of the number of intervened features K, at the
# best layer, for each intervention and its random control.

print("\n3. Dose-response")

fig, axes = plt.subplots(1, 3, figsize=(14, 5))
# Same-pair curves use per-pair rows; random controls use all rows.
layer_pp = per_pair(results_df[results_df['layer'] == BEST_LAYER])
layer_all = results_df[results_df['layer'] == BEST_LAYER]

for idx, (name, same_cond, metric, color) in enumerate([
    ('Patching Recovery', 'patch_same', 'recovery_ev', COLORS['patching']),
    ('Steering Increase (Winsor.)', 'steer_same', 'increase_ev', COLORS['steering']),
    ('Ablation Reduction', 'ablate_same', 'reduction_ev', COLORS['ablation']),
]):
    ax = axes[idx]
    k_values = [5, 10, 20, 50]
    means, sems = [], []
    for K in k_values:
        vals = layer_pp[(layer_pp['K'] == K) & (layer_pp['condition'] == same_cond)][metric].dropna()
        if 'steer' in same_cond: vals = winsorize_steering(vals)
        # An empty selection yields NaN here, which matplotlib leaves as a gap.
        means.append(vals.mean()); sems.append(vals.sem())

    ax.errorbar(k_values, [m * 100 for m in means], yerr=[s * 100 for s in sems],
                marker='o', markersize=8, linewidth=2, capsize=5, color=color, label='Same-Pair')

    # Matching random-control condition for this intervention.
    rand_cond = 'patch_random_a' if 'patch' in same_cond else (
        'steer_random' if 'steer' in same_cond else 'ablate_random')
    rand_means = []
    for K in k_values:
        vals = layer_all[(layer_all['K'] == K) & (layer_all['condition'] == rand_cond)][metric].dropna()
        if 'steer' in same_cond: vals = winsorize_steering(vals)
        # Unlike the same-pair curve, a missing control is plotted as 0.
        rand_means.append(vals.mean() if len(vals) > 0 else 0)
    # NOTE(review): the legend says 'Random-A' in all three panels although the
    # steering and ablation controls are 'steer_random' / 'ablate_random' --
    # confirm the label is intended for those panels.
    ax.plot(k_values, [m * 100 for m in rand_means], 's--', markersize=6, linewidth=1.5,
            color=COLORS['random'], label='Random-A')

    ax.set_xlabel('Number of Features (K)'); ax.set_ylabel('Effect (%)')
    ax.set_title(name); ax.legend(loc='lower right', fontsize=9)
    ax.set_xticks(k_values); ax.axhline(y=0, color='black', linestyle='-', linewidth=0.5)

plt.suptitle(f'{MODEL_NAME}: Dose-Response (Layer {BEST_LAYER})', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig(FIG_DIR / 'fig3_dose_response.png')
plt.savefig(FIG_DIR / 'fig3_dose_response.pdf')
plt.close()
print("  Saved: fig3_dose_response")


# --- 4. Layer comparison ---
# Mean same-pair effect (in percent) per analysed layer and intervention at
# K=50, shown as a heatmap and echoed as a console table.

print("\n4. Layer comparison")

# Rows follow ANALYSIS_LAYERS; columns follow method_conds_list
# (patching, steering, ablation).
layer_matrix = np.zeros((len(ANALYSIS_LAYERS), 3))
for i, layer_val in enumerate(ANALYSIS_LAYERS):
    # The layer/K filter does not depend on the method, so do it once per layer.
    layer_k50 = per_pair(results_df[(results_df['layer'] == layer_val) & (results_df['K'] == 50)])
    for j, (cond, metric) in enumerate(zip(method_conds_list, method_metrics_list)):
        vals = layer_k50[layer_k50['condition'] == cond][metric].dropna()
        if 'steer' in cond: vals = winsorize_steering(vals)
        # An empty selection leaves NaN in the cell.
        layer_matrix[i, j] = vals.mean() * 100

fig, ax = plt.subplots(figsize=(8, 7))
sns.heatmap(layer_matrix, annot=True, fmt='.1f', cmap='YlGnBu',
            xticklabels=['Patching', 'Steering (Winsor.)', 'Ablation'],
            yticklabels=[layer_label_compact(l) for l in ANALYSIS_LAYERS],
            ax=ax, linewidths=0.5, cbar_kws={'label': 'Effect (%)'})

# Outline the heatmap rows that belong to IT layers.
for i, l in enumerate(ANALYSIS_LAYERS):
    if l in IT_LAYERS:
        ax.add_patch(plt.Rectangle((0, i), 3, 1, fill=False,
                                    edgecolor='orange', linewidth=2.5, linestyle='--'))

# Layer count is derived from the configuration rather than hard-coded.
ax.set_title(f'{MODEL_NAME}: Causal Effect Across {len(ANALYSIS_LAYERS)} Layers (K=50)\n(* = IT SAE, dashed = IT layer)')
plt.tight_layout()
plt.savefig(FIG_DIR / 'fig4_layer_comparison.png')
plt.savefig(FIG_DIR / 'fig4_layer_comparison.pdf')
plt.close()
print("  Saved: fig4_layer_comparison")

print(f"\n{'Layer':<18} {'Patching':>10} {'Steering':>10} {'Ablation':>10}")
print("-" * 52)
for i, l in enumerate(ANALYSIS_LAYERS):
    tag = " (IT)" if l in IT_LAYERS else ""
    # Pad the complete label to the header's 18-character column so that the
    # numeric columns line up regardless of layer-number width or IT tag.
    row_label = f"L{l} ({LAYER_DEPTH[l]}%){tag}"
    print(f"{row_label:<18} {layer_matrix[i,0]:>9.1f}% {layer_matrix[i,1]:>9.1f}% {layer_matrix[i,2]:>9.1f}%")


# --- 4b. Encoding-causal dissociation ---
# Contrasts, per layer, a measure of how strongly features encode the
# demographic ("encoding") with the mean patching recovery ("causal").

print("\n4b. Encoding-causal dissociation")

# Causal measure: mean same-pair patching recovery at K=50 (a fraction).
causal_by_layer = {}
for l in ANALYSIS_LAYERS:
    vals = per_pair(results_df[(results_df['layer'] == l) & (results_df['K'] == 50)])
    vals = vals[vals['condition'] == 'patch_same']['recovery_ev'].dropna()
    causal_by_layer[l] = vals.mean()

# Encoding measure, in order of preference:
#   1. mean |cohens_d| of the layer's rows in feature_stats_df,
#   2. mean of 'mean_abs_delta',
#   3. the number of feature-stat rows for the layer,
# and, when no usable feature stats were loaded, the number of same-pair
# patching rows in results_df.
# NOTE(review): options 3 and the final fallback are counts rather than
# effect sizes -- confirm they are meant as an encoding proxy.
encoding_by_layer = {}
if feature_stats_df is not None and 'layer' in feature_stats_df.columns:
    for l in ANALYSIS_LAYERS:
        layer_feats = feature_stats_df[feature_stats_df['layer'] == l]
        if len(layer_feats) > 0 and 'cohens_d' in layer_feats.columns:
            encoding_by_layer[l] = layer_feats['cohens_d'].abs().mean()
        elif len(layer_feats) > 0 and 'mean_abs_delta' in layer_feats.columns:
            encoding_by_layer[l] = layer_feats['mean_abs_delta'].mean()
        else:
            encoding_by_layer[l] = len(layer_feats)
else:
    for l in ANALYSIS_LAYERS:
        vals = per_pair(results_df[(results_df['layer'] == l) & (results_df['K'] == 50)])
        encoding_by_layer[l] = len(vals[vals['condition'] == 'patch_same'])

# Both branches above add one entry per analysed layer, so this guard only
# fails when ANALYSIS_LAYERS itself is empty.
if len(encoding_by_layer) > 0:
    fig, axes = plt.subplots(1, 3, figsize=(16, 5))

    depths = [LAYER_DEPTH[l] for l in ANALYSIS_LAYERS]
    causal_vals = [causal_by_layer[l] * 100 for l in ANALYSIS_LAYERS]
    encoding_vals = [encoding_by_layer[l] for l in ANALYSIS_LAYERS]
    # Scale encoding so its maximum is 100; guard against a zero maximum.
    enc_max = max(encoding_vals) if max(encoding_vals) > 0 else 1
    encoding_norm = [v / enc_max * 100 for v in encoding_vals]

    # Panel A: dual-axis depth gradient
    # NOTE(review): only the encoding curve is max-normalised; the causal
    # curve is the raw recovery percentage although the y label reads
    # 'Normalized Strength (%)' -- confirm this is intended.
    ax = axes[0]
    ax.plot(depths, encoding_norm, 'o-', color='#3498db', linewidth=2.5, markersize=8, label='Encoding Strength', zorder=3)
    ax.plot(depths, causal_vals, 's-', color='#e74c3c', linewidth=2.5, markersize=8, label='Causal Influence', zorder=3)
    # Mark IT layers with a dashed vertical line and an 'IT' label at y=105.
    for l in IT_LAYERS:
        if l in ANALYSIS_LAYERS:
            d = LAYER_DEPTH[l]
            ax.axvline(x=d, color='orange', linestyle='--', alpha=0.5, linewidth=1.5)
            ax.annotate('IT', (d, 105), ha='center', fontsize=9, color='orange', fontweight='bold')
    ax.set_xlabel('Layer Depth (%)'); ax.set_ylabel('Normalized Strength (%)')
    ax.set_title('A. Encoding vs Causal Influence\nAcross Depth', fontweight='bold')
    ax.legend(loc='center left', fontsize=9); ax.set_xlim(5, 95)

    # Panel B: patching recovery bars
    ax = axes[1]
    bar_colors = ['orange' if l in IT_LAYERS else '#2ecc71' for l in ANALYSIS_LAYERS]
    bars = ax.bar(range(len(ANALYSIS_LAYERS)), causal_vals, color=bar_colors, edgecolor='black', linewidth=0.8)
    ax.set_xticks(range(len(ANALYSIS_LAYERS)))
    ax.set_xticklabels([f"L{l}\n{LAYER_DEPTH[l]}%" for l in ANALYSIS_LAYERS], fontsize=9)
    ax.set_ylabel('Patching Recovery (%)')
    ax.set_title('B. Causal Influence by Layer\n(K=50, Per-Pair)', fontweight='bold')
    ax.axhline(y=0, color='black', linestyle='-', linewidth=0.5)
    # Labels are always anchored with va='bottom' at the bar height; for a
    # negative bar the text therefore lands inside the bar.
    for bar, val in zip(bars, causal_vals):
        ax.annotate(f'{val:.1f}%', (bar.get_x() + bar.get_width()/2, bar.get_height()),
                   ha='center', va='bottom', fontsize=8, fontweight='bold')

    # Panel C: scatter
    # Uses the un-normalised encoding values on the x axis.
    ax = axes[2]
    for l in ANALYSIS_LAYERS:
        c = 'orange' if l in IT_LAYERS else '#2c3e50'
        marker = 'D' if l in IT_LAYERS else 'o'
        ax.scatter(encoding_by_layer[l], causal_by_layer[l] * 100,
                  s=120, c=c, marker=marker, edgecolors='black', linewidth=1, zorder=3)
        ax.annotate(f'L{l}', (encoding_by_layer[l], causal_by_layer[l] * 100),
                   textcoords="offset points", xytext=(6, 4), fontsize=9)
    # Pearson correlation is reported only with at least four layers.
    # encoding_by_layer was filled in ANALYSIS_LAYERS order, so .values()
    # lines up with the causal list built in the same order.
    if len(ANALYSIS_LAYERS) >= 4:
        r, p = pearsonr(list(encoding_by_layer.values()),
                        [causal_by_layer[l] * 100 for l in ANALYSIS_LAYERS])
        ax.set_title(f'C. Encoding vs Causal (r={r:.2f})', fontweight='bold')
    else:
        ax.set_title('C. Encoding vs Causal', fontweight='bold')
    ax.set_xlabel('Encoding Strength'); ax.set_ylabel('Patching Recovery (%)')

    plt.suptitle(f'{MODEL_NAME}: Encoding-Causal Dissociation Across 8 Layers', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig(FIG_DIR / 'fig4b_encoding_causal_dissociation.png')
    plt.savefig(FIG_DIR / 'fig4b_encoding_causal_dissociation.pdf')
    plt.close()
    print("  Saved: fig4b_encoding_causal_dissociation")
else:
    print("  Skipped: insufficient encoding data")


# --- 4c. Depth gradient line plots ---
# One panel per intervention: mean same-pair effect (with SEM error bars) and
# mean random-control effect across layer depth, at K=50.

print("\n4c. Depth gradient")

fig, axes = plt.subplots(1, 3, figsize=(16, 5))

for idx, (name, cond, metric, color) in enumerate([
    ('Patching Recovery', 'patch_same', 'recovery_ev', COLORS['patching']),
    ('Steering Increase (Winsor.)', 'steer_same', 'increase_ev', COLORS['steering']),
    ('Ablation Reduction', 'ablate_same', 'reduction_ev', COLORS['ablation']),
]):
    ax = axes[idx]
    # x positions: layer depth in percent.
    depths = [LAYER_DEPTH[l] for l in ANALYSIS_LAYERS]

    means, sems = [], []
    for l in ANALYSIS_LAYERS:
        vals = per_pair(results_df[(results_df['layer'] == l) & (results_df['K'] == 50)])
        vals = vals[vals['condition'] == cond][metric].dropna()
        if 'steer' in cond: vals = winsorize_steering(vals)
        # An empty selection yields NaN, which matplotlib leaves as a gap.
        means.append(vals.mean() * 100); sems.append(vals.sem() * 100)

    ax.errorbar(depths, means, yerr=sems, marker='o', markersize=8,
                linewidth=2, capsize=4, color=color, label='Same-Pair', zorder=3)

    # Matching random-control condition; read from all rows, not only
    # per-pair rows.
    rand_cond = 'patch_random_a' if 'patch' in cond else (
        'steer_random' if 'steer' in cond else 'ablate_random')
    rand_means = []
    for l in ANALYSIS_LAYERS:
        vals = results_df[(results_df['layer'] == l) & (results_df['K'] == 50) &
                          (results_df['condition'] == rand_cond)][metric].dropna()
        if 'steer' in cond: vals = winsorize_steering(vals)
        # Unlike the same-pair curve, a missing control is plotted as 0.
        rand_means.append(vals.mean() * 100 if len(vals) > 0 else 0)

    # NOTE(review): legend reads 'Random-A' in all three panels although the
    # steering/ablation controls are 'steer_random' / 'ablate_random' --
    # confirm the label is intended.
    ax.plot(depths, rand_means, 's--', markersize=6, linewidth=1.5,
            color=COLORS['random'], label='Random-A', alpha=0.7)

    # Dotted vertical line at the depth of each IT layer.
    for l in IT_LAYERS:
        if l in ANALYSIS_LAYERS:
            ax.axvline(x=LAYER_DEPTH[l], color='orange', linestyle=':', alpha=0.5, linewidth=1.5)

    ax.set_xlabel('Layer Depth (%)'); ax.set_ylabel('Effect (%)')
    ax.set_title(name); ax.legend(fontsize=9, loc='upper left' if 'Patch' in name else 'best')
    ax.axhline(y=0, color='black', linestyle='-', linewidth=0.5); ax.set_xlim(5, 95)

plt.suptitle(f'{MODEL_NAME}: Causal Effects Across Depth (K=50)\nOrange dotted = IT SAE layer',
             fontsize=13, fontweight='bold')
plt.tight_layout()
plt.savefig(FIG_DIR / 'fig4c_depth_gradient.png')
plt.savefig(FIG_DIR / 'fig4c_depth_gradient.pdf')
plt.close()
print("  Saved: fig4c_depth_gradient")

print("\nSections 1-4c done.")

In [None]:
# Causal validation analysis — sections 5-10

# --- 5. Demographic breakdown ---
# Same-pair effects at the best layer (K=50), split by demographic attribute:
# a mean/median table plus a heatmap of the means.

print("5. Demographic breakdown")

# Per-pair rows at the best layer and K=50; reused by later sections.
main_df_best = per_pair(results_df[(results_df['layer'] == BEST_LAYER) & (results_df['K'] == 50)])

demo_results = []
for demo in ALL_DEMOS:
    demo_pp = main_df_best[main_df_best['demographic'] == demo]
    patch = demo_pp[demo_pp['condition'] == 'patch_same']['recovery_ev'].dropna()
    steer = winsorize_steering(demo_pp[demo_pp['condition'] == 'steer_same']['increase_ev'].dropna())
    ablate = demo_pp[demo_pp['condition'] == 'ablate_same']['reduction_ev'].dropna()
    # 'N Pairs' is the number of non-missing patching rows for the demographic.
    demo_results.append({
        'Demographic': demo.capitalize(), 'N Pairs': len(patch),
        'Patching (mean/med)': f"{patch.mean():.1%} / {patch.median():.1%}" if len(patch) > 0 else "N/A",
        'Steering (mean/med)': f"{steer.mean():.1%} / {steer.median():.1%}" if len(steer) > 0 else "N/A",
        'Ablation (mean/med)': f"{ablate.mean():.1%} / {ablate.median():.1%}" if len(ablate) > 0 else "N/A",
    })

demo_results_df = pd.DataFrame(demo_results)
print(demo_results_df.to_string(index=False))
with open(TABLE_DIR / 'table2_demographic_breakdown.tex', 'w') as f:
    f.write(demo_results_df.to_latex(index=False, escape=False))

# Demographics heatmap
# Rows follow ALL_DEMOS, columns follow method_conds_list; cells are mean
# effects in percent, with 0 used where a demographic has no rows.
demo_matrix = np.zeros((len(ALL_DEMOS), 3))
for i, demo in enumerate(ALL_DEMOS):
    for j, (cond, metric) in enumerate(zip(method_conds_list, method_metrics_list)):
        vals = main_df_best[(main_df_best['demographic'] == demo) &
                            (main_df_best['condition'] == cond)][metric].dropna()
        if 'steer' in cond: vals = winsorize_steering(vals)
        demo_matrix[i, j] = vals.mean() * 100 if len(vals) > 0 else 0

fig, ax = plt.subplots(figsize=(7, 5))
sns.heatmap(demo_matrix, annot=True, fmt='.1f', cmap='YlGnBu',
            xticklabels=['Patching', 'Steering (Winsor.)', 'Ablation'],
            yticklabels=[d.capitalize() for d in ALL_DEMOS],
            ax=ax, linewidths=0.5, cbar_kws={'label': 'Effect (%)'})
ax.set_title(f'{MODEL_NAME}: Causal Effects by Demographic (Layer {BEST_LAYER}, K=50)')
plt.tight_layout()
plt.savefig(FIG_DIR / 'fig5_demographic_breakdown.png')
plt.savefig(FIG_DIR / 'fig5_demographic_breakdown.pdf')
plt.close()
print("  Saved: fig5_demographic_breakdown")

# 5b. Steering direction
# Only available when the results contain a 'direction_preserved' column.
# NOTE(review): mean() * 100 is read as a percentage, which presumes the
# column is boolean / 0-1 valued -- confirm against the producer.
if has_direction:
    print("\n5b. Steering direction")

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Left panel: share of direction-preserving forward steering runs at the
    # best layer, as a function of K.
    ax = axes[0]
    k_values = [5, 10, 20, 50]
    fwd_pcts = []
    for K_val in k_values:
        dir_vals = per_pair(results_df[
            (results_df['layer'] == BEST_LAYER) & (results_df['K'] == K_val) &
            (results_df['condition'] == 'steer_same')
        ])['direction_preserved'].dropna()
        # Missing data is shown as 0%.
        fwd_pcts.append(dir_vals.mean() * 100 if len(dir_vals) > 0 else 0)
    ax.bar([str(k) for k in k_values], fwd_pcts, color=COLORS['steering'], edgecolor='black')
    ax.set_xlabel('Number of Features (K)'); ax.set_ylabel('Direction Preserved (%)')
    ax.set_title(f'Forward Steering (×2) — Layer {BEST_LAYER}'); ax.set_ylim(0, 105)
    for i, p in enumerate(fwd_pcts):
        ax.annotate(f'{p:.1f}%', (i, p), ha='center', va='bottom', fontsize=10)

    # Right panel: forward ('steer_same') vs reverse ('steer_reverse')
    # steering per layer at K=50.
    ax = axes[1]
    fwd_by_layer, rev_by_layer = [], []
    for layer_val in ANALYSIS_LAYERS:
        fwd = per_pair(results_df[
            (results_df['layer'] == layer_val) & (results_df['K'] == 50) &
            (results_df['condition'] == 'steer_same')
        ])['direction_preserved'].dropna()
        rev = per_pair(results_df[
            (results_df['layer'] == layer_val) & (results_df['K'] == 50) &
            (results_df['condition'] == 'steer_reverse')
        ])['direction_preserved'].dropna()
        fwd_by_layer.append(fwd.mean() * 100 if len(fwd) > 0 else 0)
        rev_by_layer.append(rev.mean() * 100 if len(rev) > 0 else 0)

    # Grouped bars: forward on the left half, reverse on the right half.
    x = np.arange(len(ANALYSIS_LAYERS)); w = 0.35
    ax.bar(x - w/2, fwd_by_layer, w, label='Forward (×2)', color=COLORS['steering'], edgecolor='black')
    ax.bar(x + w/2, rev_by_layer, w, label='Reverse (×-2)', color='#c0392b', edgecolor='black')
    ax.set_xticks(x)
    ax.set_xticklabels([f'L{l}\n{LAYER_DEPTH[l]}%' for l in ANALYSIS_LAYERS], fontsize=8)
    ax.set_ylabel('Direction Preserved (%)'); ax.set_title('Forward vs Reverse by Layer')
    ax.set_ylim(0, 105)
    # Reference line at 50%, labelled 'Chance'.
    ax.axhline(y=50, color='gray', linestyle=':', linewidth=1, alpha=0.7, label='Chance')
    ax.legend(fontsize=8, loc='lower right')
    for i, (f, r) in enumerate(zip(fwd_by_layer, rev_by_layer)):
        ax.annotate(f'{f:.0f}', (i - w/2, f), ha='center', va='bottom', fontsize=7)
        ax.annotate(f'{r:.0f}', (i + w/2, r), ha='center', va='bottom', fontsize=7)

    plt.suptitle(f'{MODEL_NAME}: Steering Direction Control (8 Layers)', fontsize=13, fontweight='bold')
    plt.tight_layout()
    plt.savefig(FIG_DIR / 'fig5b_steering_direction.png')
    plt.savefig(FIG_DIR / 'fig5b_steering_direction.pdf')
    plt.close()
    print("  Saved: fig5b_steering_direction")

# 5c. Behavioral effect vs causal recovery scatter
# Per intervention: scatter of |baseline effect| against the intervention
# metric, coloured by demographic, with a least-squares line and Pearson r.
print("\n5c. Effect vs recovery scatter")

fig, axes = plt.subplots(1, 3, figsize=(15, 5))
for ax_idx, (name, cond, metric, color) in enumerate([
    ('Patching Recovery', 'patch_same', 'recovery_ev', COLORS['patching']),
    ('Steering Increase (Winsor.)', 'steer_same', 'increase_ev', COLORS['steering']),
    ('Ablation Reduction', 'ablate_same', 'reduction_ev', COLORS['ablation']),
]):
    ax = axes[ax_idx]
    # Drop rows with any missing value; copy so the clip below does not write
    # into a view of main_df_best.
    cond_df = main_df_best[main_df_best['condition'] == cond][
        ['pair_key', 'demographic', 'effect_ev_base', metric]].dropna().copy()
    if 'steer' in cond: cond_df[metric] = winsorize_steering(cond_df[metric])

    # Too few points for a meaningful fit: leave the panel empty.
    if len(cond_df) < 5:
        ax.set_title(f'{name}\n(insufficient data)'); continue

    for demo in ALL_DEMOS:
        demo_sub = cond_df[cond_df['demographic'] == demo]
        ax.scatter(demo_sub['effect_ev_base'].abs(), demo_sub[metric],
                  color=COLORS[demo], alpha=0.3, s=15, label=demo.capitalize())

    # Correlation and degree-1 polynomial fit over all demographics pooled.
    r, p = pearsonr(cond_df['effect_ev_base'].abs(), cond_df[metric])
    z = np.polyfit(cond_df['effect_ev_base'].abs(), cond_df[metric], 1)
    x_line = np.linspace(0, cond_df['effect_ev_base'].abs().max(), 100)
    ax.plot(x_line, np.polyval(z, x_line), 'k--', alpha=0.7, linewidth=1.5)
    ax.set_xlabel('|Baseline Effect|'); ax.set_ylabel(name)
    ax.set_title(f'{name}\nr={r:.2f} ({"p<0.001" if p < 0.001 else f"p={p:.3f}"})', fontweight='bold')
    ax.axhline(y=0, color='gray', linestyle='-', linewidth=0.5, alpha=0.5)
    # The demographic legend is shown once, on the first panel.
    if ax_idx == 0: ax.legend(fontsize=7, loc='upper left')

plt.suptitle(f'{MODEL_NAME}: Larger Baseline Effect → Stronger Recovery?', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig(FIG_DIR / 'fig5c_effect_vs_recovery_scatter.png')
plt.savefig(FIG_DIR / 'fig5c_effect_vs_recovery_scatter.pdf')
plt.close()
print("  Saved: fig5c_effect_vs_recovery_scatter")


# --- 6. Feature analysis ---
# How much do the most frequent top features overlap between demographics?

print("\n6. Feature analysis")

# (layer, demographic) -> list of (feature, delta) tuples gathered from every
# row's top-5 lists.  The list columns are stored as strings in the CSV, hence
# safe_parse_list.
all_features_by_demo = defaultdict(list)
for _, row in top_features_df.iterrows():
    feats = safe_parse_list(row['top_5_features'])
    deltas = safe_parse_list(row['top_5_deltas'])
    for f, d in zip(feats, deltas):
        all_features_by_demo[(row['layer'], row['demographic'])].append((f, d))

# For the best layer: the 10 most frequently listed features per demographic.
# Demographics without any entry are simply absent from the dict.
top_10_per_demo = {}
for demo in ALL_DEMOS:
    key = (BEST_LAYER, demo)
    if key in all_features_by_demo:
        features = [f for f, d in all_features_by_demo[key]]
        counts = Counter(features)
        top_10_per_demo[demo] = set([f for f, c in counts.most_common(10)])
        print(f"  {demo}: top 5 = {[f'{f}({c}x)' for f, c in counts.most_common(5)]}")

# Overlap heatmap
# Jaccard index |A ∩ B| / |A ∪ B| between the top-10 sets; a cell stays 0
# when both sets are empty.
overlap_matrix = np.zeros((len(ALL_DEMOS), len(ALL_DEMOS)))
for i, d1 in enumerate(ALL_DEMOS):
    for j, d2 in enumerate(ALL_DEMOS):
        f1 = top_10_per_demo.get(d1, set())
        f2 = top_10_per_demo.get(d2, set())
        if len(f1 | f2) > 0:
            overlap_matrix[i, j] = len(f1 & f2) / len(f1 | f2)

fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(overlap_matrix, annot=True, fmt='.2f', cmap='YlOrRd',
            xticklabels=[d.capitalize() for d in ALL_DEMOS],
            yticklabels=[d.capitalize() for d in ALL_DEMOS], ax=ax, vmin=0, vmax=1)
ax.set_title(f'{MODEL_NAME}: Feature Overlap (Jaccard, Top-10, Layer {BEST_LAYER})')
plt.tight_layout()
plt.savefig(FIG_DIR / 'fig6_feature_overlap.png')
plt.savefig(FIG_DIR / 'fig6_feature_overlap.pdf')
plt.close()
print("  Saved: fig6_feature_overlap")


# --- 7. Key features table ---

print("\n7. Key features for interpretation")

# Flatten the BEST_LAYER rows of top_features_df into one record per
# (feature, delta) pair.
best_layer_rows = top_features_df[top_features_df['layer'] == BEST_LAYER]
all_best_features = []
for _, row in best_layer_rows.iterrows():
    feats = safe_parse_list(row['top_5_features'])
    deltas = safe_parse_list(row['top_5_deltas'])
    all_best_features.extend(
        {
            'feature': f, 'delta': d, 'demographic': row['demographic'],
            'domain': row['domain'], 'effect': row['effect_ev']
        }
        for f, d in zip(feats, deltas)
    )

# Per-feature summary: delta mean/std/count plus the sorted, de-duplicated
# list of demographics in which the feature appears.
features_agg_df = pd.DataFrame(all_best_features)
per_feature = features_agg_df.groupby('feature')
feature_summary_table = per_feature.agg({
    'delta': ['mean', 'std', 'count'],
    'demographic': lambda x: ', '.join(sorted(set(x))),
})
feature_summary_table = feature_summary_table.reset_index()
feature_summary_table.columns = ['Feature', 'Mean Delta', 'Std Delta', 'Count', 'Demographics']
feature_summary_table = feature_summary_table.sort_values('Count', ascending=False)

# Show the 15 most frequent features; export the top 50.
preview_rows = feature_summary_table.head(15)
print(preview_rows.to_string(index=False))
export_rows = feature_summary_table.head(50)
export_rows.to_csv(DATA_DIR / 'key_features_for_neuronpedia.csv', index=False)


# --- 8. Summary statistics ---

print("\n8. Summary statistics")

# Number of rows in pairs_df recorded for the best layer.
unique_pairs = len(pairs_df[pairs_df['layer'] == BEST_LAYER])
# FDR info may be stored under either key; fall back to an empty dict.
fdr_info = summary.get('fdr', summary.get('fdr_correction', {}))

# Same-pair steering increases are Winsorized before averaging.
steer_same_winsor = winsorize_steering(
    main_df_best[main_df_best['condition'] == 'steer_same']['increase_ev'].dropna()
)

# Nested dict of headline numbers; printed below and dumped to JSON.
summary_stats = {
    'Model': MODEL_NAME,
    'Best Layer': f"{BEST_LAYER} ({LAYER_DEPTH[BEST_LAYER]}% depth)",
    'Dataset': {
        'Unique pairs (best layer)': unique_pairs,
        'Demographics': 5, 'Domains': 5,
        'Layers analyzed': len(ANALYSIS_LAYERS),
        'IT SAE layers': str(list(IT_LAYERS)),
    },
    f'Main Results (Layer {BEST_LAYER}, K=50, paired)': {
        'Patching recovery': f"{main_df_best[main_df_best['condition'] == 'patch_same']['recovery_ev'].mean():.1%}",
        'Steering increase (Winsorized)': f"{steer_same_winsor.mean():.1%}",
        'Ablation reduction': f"{main_df_best[main_df_best['condition'] == 'ablate_same']['reduction_ev'].mean():.1%}",
    },
    'Effect Sizes (paired d_z)': {
        'Patching': f"{computed_cohens_d.get('Patching', 0):.2f}",
        'Steering': f"{computed_cohens_d.get('Steering', 0):.2f}",
        'Ablation': f"{computed_cohens_d.get('Ablation', 0):.2f}",
    },
    'Controls': {
        'Random-A': f"{main_all[main_all['condition'] == 'patch_random_a']['recovery_ev'].mean():.1%}",
        'Random-B': f"{main_all[main_all['condition'] == 'patch_random_b']['recovery_ev'].mean():.1%}",
        'Cross-pair': f"{main_all[main_all['condition'] == 'patch_cross']['recovery_ev'].mean():.1%}",
    },
    'Layer Gradient (patching, K=50)': {
        layer_label(l): f"{causal_by_layer[l]:.1%}" for l in ANALYSIS_LAYERS
    },
    'FDR': {
        'Tests': fdr_info.get('n_tests', 'N/A'),
        'Significant': f"{fdr_info.get('n_significant', 'N/A')}/{fdr_info.get('n_tests', 'N/A')}",
    },
}

# Direction-preservation rates are only reported when has_direction is set.
if has_direction:
    dir_fwd = main_df_best[main_df_best['condition'] == 'steer_same']['direction_preserved'].dropna().mean() * 100
    dir_rev = main_df_best[main_df_best['condition'] == 'steer_reverse']['direction_preserved'].dropna().mean() * 100
    summary_stats['Steering Direction'] = {
        'Forward preserved': f"{dir_fwd:.1f}%", 'Reverse preserved': f"{dir_rev:.1f}%",
    }

# Pretty-print every section. The loop variable is deliberately NOT called
# `stats`: that name is the imported scipy.stats module, and rebinding it at
# top level would break any later `stats.<function>` call in the session.
for section, section_values in summary_stats.items():
    if isinstance(section_values, dict):
        print(f"\n{section}:")
        for k, v in section_values.items():
            print(f"  {k}: {v}")
    else:
        print(f"\n{section}: {section_values}")

# default=str stringifies any value json cannot serialise natively.
with open(SUMMARY_DIR / 'summary_statistics.json', 'w') as f:
    json.dump(summary_stats, f, indent=2, default=str)


# --- 9. Combined paper figure ---

print("\n9. Combined figure")

# One 3x4 grid hosts all panels of the combined figure; panel A spans the
# whole first row.
fig = plt.figure(figsize=(18, 14))
gs = fig.add_gridspec(3, 4, hspace=0.4, wspace=0.35)

# A: Main results
ax_a = fig.add_subplot(gs[0, :])
# One tuple per method: (name, same-pair mean, same-pair SEM, random mean, d_z).
methods_data = []
for name, same_cond, rand_cond, metric in methods:
    same = main_df_best[main_df_best['condition'] == same_cond][metric].dropna()
    rand = main_all[main_all['condition'] == rand_cond][metric].dropna()
    # Steering values are Winsorized for both the same-pair and random series.
    if 'steer' in same_cond:
        same = winsorize_steering(same); rand = winsorize_steering(rand)
    methods_data.append((name, same.mean(), same.sem(), rand.mean(), computed_cohens_d.get(name, 0)))

# NOTE(review): the bar positions, the three hard-coded colours and the
# range(3) loop below all assume `methods` has exactly three entries — confirm.
x_pos = np.arange(3); w = 0.35
bars1 = ax_a.bar(x_pos - w/2, [d[1]*100 for d in methods_data], w,
                  yerr=[d[2]*100 for d in methods_data], capsize=5,
                  label='Same-Pair',
                  color=[COLORS['patching'], COLORS['steering'], COLORS['ablation']],
                  edgecolor='black', linewidth=1)
bars2 = ax_a.bar(x_pos + w/2, [d[3]*100 for d in methods_data], w,
                  label='Random-A', color=COLORS['random'], edgecolor='black', linewidth=1)
ax_a.set_xticks(x_pos)
ax_a.set_xticklabels([f"{d[0]}\nd_z={d[4]:.2f}" for d in methods_data])
ax_a.set_ylabel('Effect (%)'); ax_a.legend()
ax_a.set_title(f'A. Same-Pair vs Random (Layer {BEST_LAYER}, K=50, Paired)', fontweight='bold')
ax_a.axhline(y=0, color='black', linestyle='-', linewidth=0.5)
# NOTE(review): '***' is drawn above every same-pair bar unconditionally; no
# p-value is consulted here, so the stars are not tied to a test result in
# this block — verify against the paired-comparison output before publishing.
for i in range(3):
    ax_a.annotate('***', xy=(i - w/2, methods_data[i][1]*100 + methods_data[i][2]*100 + 3),
                  ha='center', fontsize=12)

# B: Dose-response — mean same-pair effect at each K on the best layer,
# one line per method (steering values are Winsorized first).
ax_b = fig.add_subplot(gs[1, 0])
dose_k_values = [5, 10, 20, 50]
dose_series = [
    ('Patching', 'patch_same', 'recovery_ev', COLORS['patching']),
    ('Steering (W)', 'steer_same', 'increase_ev', COLORS['steering']),
    ('Ablation', 'ablate_same', 'reduction_ev', COLORS['ablation']),
]
for name, same_cond, metric, color in dose_series:
    means = []
    for K in dose_k_values:
        at_best_layer_and_k = (results_df['layer'] == BEST_LAYER) & (results_df['K'] == K)
        vals = per_pair(results_df[at_best_layer_and_k])
        vals = vals[vals['condition'] == same_cond][metric].dropna()
        if 'steer' in same_cond:
            vals = winsorize_steering(vals)
        means.append(vals.mean())
    ax_b.plot(
        dose_k_values,
        [m*100 for m in means],
        marker='o',
        label=name,
        color=color,
        linewidth=2,
    )
ax_b.set_xlabel('K')
ax_b.set_ylabel('Effect (%)')
ax_b.set_title('B. Dose-Response', fontweight='bold')
ax_b.legend(fontsize=7)
ax_b.set_xticks(dose_k_values)

# C: Layer heatmap
# layer_matrix is built outside this block; presumably one row per entry of
# ANALYSIS_LAYERS and one column per method — TODO confirm.
ax_c = fig.add_subplot(gs[1, 1])
sns.heatmap(layer_matrix, annot=True, fmt='.1f', cmap='YlGnBu',
            xticklabels=['Patch', 'Steer(W)', 'Ablate'],
            yticklabels=[layer_label_compact(l) for l in ANALYSIS_LAYERS],
            ax=ax_c, linewidths=0.5, cbar=False, annot_kws={'fontsize': 8})
ax_c.set_title('C. Layer Comparison (%)', fontweight='bold')
ax_c.tick_params(axis='y', labelsize=8)
# Outline (dashed orange, three cells wide) every row whose layer is in IT_LAYERS.
for i, l in enumerate(ANALYSIS_LAYERS):
    if l in IT_LAYERS:
        ax_c.add_patch(plt.Rectangle((0, i), 3, 1, fill=False,
                                      edgecolor='orange', linewidth=2, linestyle='--'))

# D: Depth gradient — patching recovery (in %) plotted against layer depth,
# with each point labelled by its layer (orange for layers in IT_LAYERS).
ax_d = fig.add_subplot(gs[1, 2])
depths = [LAYER_DEPTH[l] for l in ANALYSIS_LAYERS]
patch_by_depth = [causal_by_layer[l] * 100 for l in ANALYSIS_LAYERS]
ax_d.plot(depths, patch_by_depth, 'o-', color=COLORS['patching'], linewidth=2.5, markersize=8)
for l, depth, recovery in zip(ANALYSIS_LAYERS, depths, patch_by_depth):
    ax_d.annotate(
        f'L{l}',
        (depth, recovery),
        textcoords="offset points",
        xytext=(4, 5),
        fontsize=7,
        color='orange' if l in IT_LAYERS else 'black',
    )
ax_d.set_xlabel('Depth (%)')
ax_d.set_ylabel('Patching Recovery (%)')
ax_d.set_title('D. Causal Influence\nAcross Depth', fontweight='bold')
ax_d.set_xlim(5, 95)
ax_d.axhline(y=0, color='black', linestyle='-', linewidth=0.5)

# E: Control hierarchy — mean patching recovery for the same-pair condition
# (from main_df_best) next to the three control conditions (from main_all).
ax_e = fig.add_subplot(gs[1, 3])
same_patch = main_df_best.loc[
    main_df_best['condition'] == 'patch_same', 'recovery_ev'].dropna().mean()
rand_a_patch = main_all.loc[
    main_all['condition'] == 'patch_random_a', 'recovery_ev'].dropna().mean()
rand_b_patch = main_all.loc[
    main_all['condition'] == 'patch_random_b', 'recovery_ev'].dropna().mean()
cross_patch = main_all.loc[
    main_all['condition'] == 'patch_cross', 'recovery_ev'].dropna().mean()

# (tick label, mean recovery as a fraction, bar colour), in display order.
control_data = [
    ('Same', same_patch, COLORS['patching']),
    ('Rand-A', rand_a_patch, COLORS['random_a']),
    ('Cross', cross_patch, COLORS['cross']),
    ('Rand-B', rand_b_patch, COLORS['random_b']),
]
bar_positions = range(4)
ax_e.bar(
    bar_positions,
    [entry[1]*100 for entry in control_data],
    color=[entry[2] for entry in control_data],
    edgecolor='black',
)
ax_e.set_xticks(bar_positions)
ax_e.set_xticklabels([entry[0] for entry in control_data], fontsize=9)
ax_e.set_ylabel('Recovery (%)')
ax_e.set_title('E. Control Hierarchy', fontweight='bold')
ax_e.axhline(y=0, color='black', linestyle='-', linewidth=0.5)
for i, (lbl, val, _) in enumerate(control_data):
    ax_e.annotate(f'{val:.1%}', (i, val*100), ha='center', va='bottom', fontsize=8, fontweight='bold')

# F: Demographics heatmap
# demo_matrix is built outside this block; presumably one row per entry of
# ALL_DEMOS and one column per method — TODO confirm.
ax_f = fig.add_subplot(gs[2, :2])
sns.heatmap(demo_matrix, annot=True, fmt='.1f', cmap='YlGnBu',
            xticklabels=['Patching', 'Steering (W)', 'Ablation'],
            yticklabels=[d.capitalize() for d in ALL_DEMOS],
            ax=ax_f, linewidths=0.5, cbar_kws={'label': 'Effect (%)', 'shrink': 0.6})
ax_f.set_title('F. Results by Demographic', fontweight='bold')

# G: Feature overlap
# Tick labels are the first three letters of each demographic name.
# NOTE(review): the colour scale is capped at vmax=0.5 here, whereas the
# standalone overlap figure uses vmax=1; diagonal cells (a set compared with
# itself) equal 1.0 for non-empty sets and will therefore saturate — confirm
# this is the intended emphasis on off-diagonal values.
ax_g = fig.add_subplot(gs[2, 2])
sns.heatmap(overlap_matrix, annot=True, fmt='.2f', cmap='YlOrRd',
            xticklabels=[d[:3].capitalize() for d in ALL_DEMOS],
            yticklabels=[d[:3].capitalize() for d in ALL_DEMOS],
            ax=ax_g, vmin=0, vmax=0.5, cbar_kws={'shrink': 0.8})
ax_g.set_title('G. Feature Overlap\n(Jaccard)', fontweight='bold')

# H: Direction by layer — percentage of same-pair steering runs (K=50) whose
# direction was preserved, one bar per analysis layer. Only drawn when
# direction data is available; a layer without data is plotted as 0.
if has_direction:
    ax_h = fig.add_subplot(gs[2, 3])
    fwd_vals = []
    for l in ANALYSIS_LAYERS:
        steer_rows = results_df[
            (results_df['layer'] == l)
            & (results_df['K'] == 50)
            & (results_df['condition'] == 'steer_same')
        ]
        fwd = per_pair(steer_rows)['direction_preserved'].dropna()
        if len(fwd) > 0:
            fwd_vals.append(fwd.mean() * 100)
        else:
            fwd_vals.append(0)
    bar_colors_h = ['orange' if l in IT_LAYERS else COLORS['steering'] for l in ANALYSIS_LAYERS]
    layer_positions = range(len(ANALYSIS_LAYERS))
    ax_h.bar(layer_positions, fwd_vals, color=bar_colors_h, edgecolor='black', linewidth=0.8)
    ax_h.set_xticks(layer_positions)
    ax_h.set_xticklabels([f'L{l}' for l in ANALYSIS_LAYERS], fontsize=7)
    ax_h.set_ylabel('Direction (%)')
    ax_h.set_title('H. Steering Direction\nby Layer', fontweight='bold')
    # Dotted reference line at the 50% level.
    ax_h.axhline(y=50, color='gray', linestyle=':', linewidth=1, alpha=0.7)
    ax_h.set_ylim(0, 105)

plt.savefig(FIG_DIR / 'fig_combined_main.png')
plt.savefig(FIG_DIR / 'fig_combined_main.pdf')
plt.close()
print("  Saved: fig_combined_main")


# --- 10. Bundle ---

print("\n10. Bundling")

# (source, destination) pairs that are always considered for the bundle.
raw_files = [
    (VALIDATION_DIR / 'validation_results.csv', DATA_DIR / 'validation_results.csv'),
    (VALIDATION_DIR / 'top_features.csv', DATA_DIR / 'top_features.csv'),
    (VALIDATION_DIR / 'validation_summary.json', DATA_DIR / 'validation_summary_old.json'),
    (EXTRACTION_DIR / 'behavioral_effects.csv', DATA_DIR / 'behavioral_effects.csv'),
    (EXTRACTION_DIR / 'selected_features.json', DATA_DIR / 'selected_features.json'),
    (EXTRACTION_DIR / 'feature_stats.csv', DATA_DIR / 'feature_stats.csv'),
    (EXTRACTION_DIR / 'feature_overlap.json', DATA_DIR / 'feature_overlap.json'),
]

# Extra artefacts are appended only when they are present on disk.
optional_files = [
    (patched_summary_path, 'validation_summary_patched.json'),
    (exclusion_path, 'exclusion_log.csv'),
    (paired_comp_path, 'paired_comparisons.csv'),
]
raw_files.extend(
    (optional_src, DATA_DIR / bundled_name)
    for optional_src, bundled_name in optional_files
    if optional_src.exists()
)

# Copy whatever exists; missing sources are silently skipped.
for src, dst in raw_files:
    if not src.exists():
        continue
    shutil.copy2(src, dst)

# Short README describing the bundle: model, analysed layers and methodology.
readme = f"""# Causal Validation Results — {MODEL_NAME} (8-Layer)
# Generated: {pd.Timestamp.now().isoformat()}

## Model
- {MODEL_NAME}, {N_LAYERS_TOTAL} layers total
- Analysis: {ANALYSIS_LAYERS}
- Depths: {LAYER_DEPTH}
- IT SAE: {list(IT_LAYERS)}, all others PT
- Best: L{BEST_LAYER} ({LAYER_DEPTH[BEST_LAYER]}%)

## Methodology
- Paired tests (ttest_rel, d_z)
- Steering Winsorized ±200%
- Random-A (active+delta) vs Random-B (shuffled)
- Cross-pair from different demo×domain
- FDR correction (BH)
- Scale-normalized threshold (0.3 × range/10)
"""

# The text contains non-ASCII characters (—, ±, ×). Pin the encoding so the
# write neither fails nor produces mojibake on platforms whose default
# encoding is not UTF-8.
with open(BUNDLE_DIR / 'README.md', 'w', encoding='utf-8') as f:
    f.write(readme)

# Zip the bundle directory; the archive is written next to it as
# "<BUNDLE_DIR>.zip" and contains the directory itself as its top-level entry.
shutil.make_archive(str(BUNDLE_DIR), 'zip', root_dir=BUNDLE_DIR.parent, base_dir=BUNDLE_DIR.name)

# Report how many artefacts exist on disk (PNG figures, .tex tables, and every
# entry directly inside DATA_DIR).
n_figs = len(list(FIG_DIR.glob('*.png')))
n_tables = len(list(TABLE_DIR.glob('*.tex')))
n_data = len(list(DATA_DIR.glob('*')))
print(f"  {n_figs} figures, {n_tables} tables, {n_data} data files")
print(f"  Download: {BUNDLE_DIR.name}.zip")
print("\nDone.")

In [None]:
# Feature interpretation — setup, top features, encoding matrix

import torch
import numpy as np
import pandas as pd
import json
import ast
import re
import shutil
from pathlib import Path
from collections import Counter, defaultdict
from scipy import stats
from scipy.stats import t as t_dist
from statsmodels.stats.multitest import multipletests
from tqdm.auto import tqdm
import warnings
warnings.filterwarnings('ignore')

torch.set_grad_enabled(False)

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import seaborn as sns

plt.style.use('seaborn-v0_8-whitegrid')
plt.rcParams.update({
    'font.family': 'DejaVu Sans', 'font.size': 11,
    'axes.labelsize': 12, 'axes.titlesize': 13,
    'figure.dpi': 150, 'savefig.dpi': 300, 'savefig.bbox': 'tight',
})

# Display name used in figure titles.
MODEL_NAME = "Gemma 2 9B IT"
# Layer analysed in this cell: selects the SAE that is loaded, the
# resid_post hook that is cached, and the rows of top_features.csv that are used.
LAYER = 36
# NOTE(review): not referenced in the visible part of this cell; presumably
# the model's total layer count — confirm.
N_LAYERS_TOTAL = 42

# Demographic categories iterated over by every test and plot below.
ALL_DEMOS = ['income', 'age', 'gender', 'education', 'vote']
# How many of the most frequently listed features are analysed.
N_TOP_FEATURES = 10
# Maximum number of prompt pairs per (feature, demographic) encoding test.
N_PAIRS_ENCODING = 30
# NOTE(review): presumably the pair budget for the location analysis in the
# following cell — confirm against its call site.
N_PAIRS_LOCATION = 20
# Alpha level passed to the FDR (Benjamini-Hochberg) correction.
ALPHA = 0.05

# Scatter-plot colour for each demographic.
DEMO_COLORS = {
    'income': '#1abc9c', 'age': '#9b59b6', 'gender': '#e91e63',
    'education': '#3498db', 'vote': '#f44336',
}

# --- Path detection ---

BASE_DIR = Path("/content")

def find_file(filename, search_dirs=None):
    """Locate *filename*, preferring the given directories.

    Each entry of *search_dirs* (default: ``[BASE_DIR]``) is checked for a
    direct child named *filename*; ``None`` entries are ignored. If none of
    them contains the file, the first hit of a recursive search below
    ``BASE_DIR`` is returned.

    Returns:
        The matching path, or ``None`` when nothing is found.
    """
    candidates = [BASE_DIR] if search_dirs is None else search_dirs
    for directory in candidates:
        if directory is None:
            continue
        candidate = directory / filename
        if candidate.exists():
            return candidate
    # No direct hit: fall back to the first recursive match under BASE_DIR.
    return next(iter(BASE_DIR.rglob(filename)), None)

def find_dir(dirname):
    """Return the directory called *dirname* directly under ``BASE_DIR``.

    When there is no exact match, the first sub-directory of ``BASE_DIR``
    whose name starts with the first 20 characters of *dirname* is returned
    instead. Returns ``None`` if neither exists.
    """
    exact = BASE_DIR / dirname
    if exact.is_dir():
        return exact
    prefix = dirname[:20]
    return next(
        (child for child in BASE_DIR.iterdir()
         if child.is_dir() and child.name.startswith(prefix)),
        None,
    )

# Validation outputs: either directly under BASE_DIR or nested inside the
# replication output directory.
VALIDATION_DIR = find_dir("causal_validation_final")
if VALIDATION_DIR is None:
    nested = find_dir("outputs_gemma_replication")
    if nested is not None:
        nested_validation = nested / "causal_validation_final"
        if nested_validation.exists():
            VALIDATION_DIR = nested_validation

# Extraction outputs: first directory with the expected name prefix, else
# whatever find_dir resolves for "feature_extraction".
EXTRACTION_DIR = next(
    (d for d in BASE_DIR.iterdir()
     if d.is_dir() and d.name.startswith("feature-extraction-results-ge")),
    None,
)
if EXTRACTION_DIR is None:
    EXTRACTION_DIR = find_dir("feature_extraction")

# Output bundle layout (directories are created if missing).
BUNDLE = Path("/content/feature-interpretation-results-gemma")
FIGS = BUNDLE / "figures"
TABLES = BUNDLE / "tables"
DATA = BUNDLE / "data"
for out_dir in (BUNDLE, FIGS, TABLES, DATA):
    out_dir.mkdir(exist_ok=True)

print(f"VALIDATION_DIR: {VALIDATION_DIR}")
print(f"EXTRACTION_DIR: {EXTRACTION_DIR}")

# Reuse the model's own tokenizer when it exposes one.
if hasattr(model, 'tokenizer'):
    tokenizer = model.tokenizer

# --- Load data ---

# Validation prompts: look in the extraction directory first (when known),
# then under BASE_DIR; abort if the file cannot be found anywhere.
prompts_path = find_file("prompts_validation.parquet",
                          [EXTRACTION_DIR, BASE_DIR] if EXTRACTION_DIR else [BASE_DIR])
if prompts_path is None:
    raise FileNotFoundError("Cannot find prompts_validation.parquet")
prompts_df = pd.read_parquet(prompts_path)
# Keep only rows with vocab_idx == 0 when that column exists.
# NOTE(review): presumably this selects a single vocabulary variant per prompt — confirm.
if 'vocab_idx' in prompts_df.columns:
    prompts_df = prompts_df[prompts_df['vocab_idx'] == 0].reset_index(drop=True)
print(f"Prompts: {len(prompts_df)}")

# Per-pair top-feature table written by the validation step.
tf_path = find_file("top_features.csv", [VALIDATION_DIR] if VALIDATION_DIR else [])
if tf_path is None:
    raise FileNotFoundError("Cannot find top_features.csv")
top_features_df = pd.read_csv(tf_path)
print(f"Top features: {len(top_features_df)} rows")

# `sae_manager` (and `model`, used later) are not defined in this cell; they
# must already exist in the notebook session.
sae = sae_manager.load_sae(LAYER)
hook_name = f"blocks.{LAYER}.hook_resid_post"

# --- Identify top features ---

def safe_parse_list(x):
    """Best-effort conversion of a CSV cell into a list.

    Args:
        x: Either an actual list (returned unchanged) or the string form of
            a Python list/tuple literal, e.g. ``"[12, 907]"``.

    Returns:
        The parsed list. An empty list is returned when *x* is neither a
        list nor a string, when the string is not a valid literal, or when
        the literal is not a list/tuple — callers zip the result, so a bare
        scalar or dict must not leak through.
    """
    if isinstance(x, list):
        return x
    if not isinstance(x, str):
        return []
    try:
        parsed = ast.literal_eval(x)
    # Only the errors literal_eval raises for malformed input are swallowed;
    # KeyboardInterrupt/SystemExit are no longer caught as with a bare except.
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        return []
    if isinstance(parsed, (list, tuple)):
        return list(parsed)
    return []

# Per-feature tally over the rows of top_features.csv for LAYER (restricted
# to feature_method == 'per_pair' when that column exists).
feature_stats = defaultdict(lambda: {'count': 0, 'deltas': [], 'demographics': []})
layer_df = top_features_df[top_features_df['layer'] == LAYER]
if 'feature_method' in layer_df.columns:
    layer_df = layer_df[layer_df['feature_method'] == 'per_pair']

for _, row in layer_df.iterrows():
    paired = zip(safe_parse_list(row['top_5_features']),
                 safe_parse_list(row['top_5_deltas']))
    for f, d in paired:
        entry = feature_stats[f]
        entry['count'] += 1
        entry['deltas'].append(d)
        entry['demographics'].append(row['demographic'])

# Most frequently listed features first; the sort is stable, so ties keep
# their first-seen order.
ranked_features = sorted(
    feature_stats,
    key=lambda feat_id: feature_stats[feat_id]['count'],
    reverse=True,
)
top_features = ranked_features[:N_TOP_FEATURES]

print(f"\nTop {N_TOP_FEATURES} features:")
for rank, feat in enumerate(top_features, start=1):
    info = feature_stats[feat]
    demo_counts = Counter(info['demographics'])
    top_demos = ', '.join(f"{d}({c})" for d, c in demo_counts.most_common(3))
    print(f"  {rank}. Feature {feat}: {info['count']} pairs, {top_demos}")

# --- Encoding matrix ---

print(f"\nRunning {N_TOP_FEATURES} x {len(ALL_DEMOS)} encoding tests...")

def get_last_token_activation(prompt, feature_idx):
    """Return one SAE feature's activation at the final token of *prompt*.

    The prompt is tokenised and run through ``model`` while caching the
    ``hook_name`` activations; the last position is encoded with ``sae`` and
    the entry at *feature_idx* is returned as a Python float.
    """
    token_ids = model.to_tokens(prompt)
    with torch.inference_mode():
        _, activations = model.run_with_cache(
            token_ids, names_filter=lambda name: hook_name in name)
        last_position = activations[hook_name][0, -1, :].float()
        encoded = sae.encode(last_position.unsqueeze(0))[0]
        value = encoded[feature_idx].float().cpu().item()
        # Drop the reference to the cached activations before returning.
        del activations
    return value


def test_encoding(feature_idx, demographic, n_pairs):
    """Paired test of one SAE feature across the two values of a demographic.

    For up to *n_pairs* distinct ``pair_key`` values of *demographic* in
    ``prompts_df``, the feature's last-token activation is measured on the
    ``value_a`` prompt and on the ``value_b`` prompt, and the paired
    differences (a - b) are tested against zero.

    Args:
        feature_idx: Index of the SAE feature to read.
        demographic: Value of the ``demographic`` column to select.
        n_pairs: Maximum number of pair keys to evaluate.

    Returns:
        ``None`` when fewer than 5 complete pairs are available; otherwise a
        dict with the sample size, per-side means, mean difference and its
        95% t-interval, the paired t statistic and p-value, Cohen's d of the
        differences (0.0 when their standard deviation is ~0), the
        percentage of pairs sharing the majority sign, and a direction
        label. ``t_stat``/``p_val`` are passed through from scipy unchanged
        and may be NaN for degenerate (e.g. all-zero) differences.
    """
    demo_df = prompts_df[prompts_df['demographic'] == demographic]
    pair_keys = demo_df['pair_key'].unique()[:n_pairs]
    acts_a, acts_b = [], []

    for pk in pair_keys:
        pair = demo_df[demo_df['pair_key'] == pk]
        rows_a = pair[pair['value_type'] == 'value_a']
        rows_b = pair[pair['value_type'] == 'value_b']
        # A pair is usable only if BOTH sides are present. Checking just
        # len(pair) >= 2 would let e.g. two value_b rows through and crash
        # on .iloc[0] of the empty value_a selection.
        if rows_a.empty or rows_b.empty:
            continue
        acts_a.append(get_last_token_activation(rows_a.iloc[0]['prompt'], feature_idx))
        acts_b.append(get_last_token_activation(rows_b.iloc[0]['prompt'], feature_idx))

    n = min(len(acts_a), len(acts_b))
    if n < 5:
        return None

    acts_a, acts_b = acts_a[:n], acts_b[:n]
    diffs = np.array(acts_a) - np.array(acts_b)

    t_stat, p_val = stats.ttest_rel(acts_a, acts_b)
    std_diff = np.std(diffs, ddof=1)
    # Paired effect size; guarded against a (near-)zero denominator.
    cohens_d = np.mean(diffs) / std_diff if std_diff > 1e-10 else 0.0

    # Two-sided 95% confidence interval of the mean difference.
    mean_diff = np.mean(diffs)
    se = stats.sem(diffs)
    t_crit = t_dist.ppf(0.975, n - 1)

    # Sign consistency: share of pairs on the majority side (ties count for neither).
    n_pos = int(np.sum(diffs > 0))
    n_neg = int(np.sum(diffs < 0))

    return {
        'n': n, 'mean_a': np.mean(acts_a), 'mean_b': np.mean(acts_b),
        'mean_diff': mean_diff,
        'ci_low': mean_diff - t_crit * se, 'ci_high': mean_diff + t_crit * se,
        't_stat': t_stat, 'p_val': p_val, 'cohens_d': cohens_d,
        'consistency': max(n_pos, n_neg) / n * 100,
        'direction': 'value_a > value_b' if mean_diff > 0 else 'value_b > value_a',
    }


# Run every (feature, demographic) encoding test.
encoding_results = {}
for feat in tqdm(top_features, desc="Encoding"):
    encoding_results[feat] = {}
    for demo in ALL_DEMOS:
        encoding_results[feat][demo] = test_encoding(feat, demo, N_PAIRS_ENCODING)

# Collect and FDR correct
# Tests that returned None (too few pairs) are left out of the table.
all_rows = []
for feat in top_features:
    for demo in ALL_DEMOS:
        r = encoding_results[feat][demo]
        if r is not None:
            all_rows.append({
                'feature': feat, 'demographic': demo,
                'n': r['n'], 'mean_diff': r['mean_diff'],
                'ci_low': r['ci_low'], 'ci_high': r['ci_high'],
                't_stat': r['t_stat'], 'p_raw': r['p_val'],
                'cohens_d': r['cohens_d'], 'consistency': r['consistency'],
                'direction': r['direction'],
            })

enc_df = pd.DataFrame(all_rows)
if enc_df.empty:
    # Without this check the failure would surface as an opaque KeyError on
    # 'p_raw' (or a division by zero in the summary line below).
    raise ValueError("No encoding test produced a result; nothing to FDR-correct")

# Benjamini-Hochberg correction is applied to the finite p-values only.
# A paired t-test on zero-variance differences yields p = NaN, and a single
# NaN passed to multipletests turns every adjusted p-value into NaN (the
# step-up running minimum propagates it), which would silently relabel all
# results as 'ns'. Undefined tests keep p_corrected = NaN and are never
# marked significant.
p_raw = enc_df['p_raw'].to_numpy(dtype=float)
testable = np.isfinite(p_raw)
reject = np.zeros(len(enc_df), dtype=bool)
p_corrected = np.full(len(enc_df), np.nan)
if testable.any():
    reject_testable, p_corrected_testable, _, _ = multipletests(
        p_raw[testable], alpha=ALPHA, method='fdr_bh')
    reject[testable] = reject_testable
    p_corrected[testable] = p_corrected_testable
n_untestable = int((~testable).sum())
if n_untestable:
    print(f"  Note: {n_untestable} test(s) had an undefined p-value and were "
          f"excluded from the FDR correction")

enc_df['p_corrected'] = p_corrected
enc_df['significant'] = reject
# Star labels follow the corrected p-value; NaN compares False and maps to 'ns'.
enc_df['sig_label'] = enc_df['p_corrected'].apply(
    lambda p: '***' if p < 0.001 else ('**' if p < 0.01 else ('*' if p < 0.05 else 'ns')))

n_sig = enc_df['significant'].sum()
n_total = len(enc_df)
print(f"\nSignificant (FDR α={ALPHA}): {n_sig}/{n_total} ({n_sig/n_total*100:.1f}%)")

enc_df.to_csv(DATA / 'encoding_matrix_full.csv', index=False)

# Compact pivot: one row per top feature, one column per demographic.
# A cell shows "<stars>(d=<Cohen's d>)" for significant tests, 'ns' for
# non-significant ones, and '-' when no test result exists.
pivot_data = []
for feat in top_features:
    cells = {}
    for demo in ALL_DEMOS:
        matches = enc_df[(enc_df['feature'] == feat) & (enc_df['demographic'] == demo)]
        if len(matches) == 0:
            cells[demo] = '-'
            continue
        first = matches.iloc[0]
        if first['significant']:
            cells[demo] = f"{first['sig_label']}(d={first['cohens_d']:+.1f})"
        else:
            cells[demo] = 'ns'
    pivot_data.append({'Feature': feat, 'N_Pairs': feature_stats[feat]['count'], **cells})

pivot_df = pd.DataFrame(pivot_data)
print(pivot_df.to_string(index=False))
pivot_df.to_csv(DATA / 'encoding_matrix_compact.csv', index=False)
latex_table = pivot_df.to_latex(index=False, escape=False)
with open(TABLES / 'table_encoding_matrix.tex', 'w') as f:
    f.write(latex_table)

# --- Fig 1: Encoding heatmap ---

# Size the matrix by the features actually found. top_features can hold
# fewer than N_TOP_FEATURES entries, in which case a matrix sized by the
# constant would no longer line up with the y tick labels.
n_feature_rows = len(top_features)
heat_matrix = np.zeros((n_feature_rows, len(ALL_DEMOS)))
# (row, column, stars) for every significant cell, gathered in the same pass
# so enc_df is filtered once per cell instead of twice.
sig_annotations = []
for i, feat in enumerate(top_features):
    for j, demo in enumerate(ALL_DEMOS):
        r_df = enc_df[(enc_df['feature'] == feat) & (enc_df['demographic'] == demo)]
        if len(r_df) == 0:
            continue  # no test result: cell stays 0
        cell = r_df.iloc[0]
        heat_matrix[i, j] = cell['cohens_d']
        if cell['significant']:
            sig_annotations.append((i, j, cell['sig_label']))

fig, ax = plt.subplots(figsize=(8, max(5, n_feature_rows * 0.5 + 1)))
# Symmetric colour range of at least ±3 so the diverging map is centred on 0.
vmax = max(3.0, np.abs(heat_matrix).max())
sns.heatmap(heat_matrix, annot=True, fmt='+.1f', cmap='RdBu_r', center=0,
            vmin=-vmax, vmax=vmax,
            xticklabels=[d.capitalize() for d in ALL_DEMOS],
            yticklabels=[str(f) for f in top_features],
            ax=ax, linewidths=0.5, cbar_kws={'label': "Cohen's d"})
# Significance stars are drawn in the lower part of each significant cell.
for i, j, label in sig_annotations:
    ax.text(j + 0.5, i + 0.82, label,
            ha='center', va='center', fontsize=7, color='black')
ax.set_title(f"{MODEL_NAME}: Feature × Demographic Encoding (Layer {LAYER})\n"
             f"Cohen's d, FDR-corrected", fontweight='bold')
ax.set_ylabel('Feature')
plt.tight_layout()
plt.savefig(FIGS / 'fig_encoding_matrix.png')
plt.savefig(FIGS / 'fig_encoding_matrix.pdf')
plt.close()
print("  Saved: fig_encoding_matrix")

# --- Fig 2: Encoding strength summary ---

# Per feature: number of significant demographics and mean |Cohen's d|,
# both re-ordered to match top_features (features with no rows get 0).
sig_counts = enc_df[enc_df['significant']].groupby('feature').size().reindex(top_features, fill_value=0)
mean_abs_d = enc_df.groupby('feature')['cohens_d'].apply(
    lambda x: x.abs().mean()).reindex(top_features, fill_value=0)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

# Panel A colours: red for >= 3 significant demographics, blue for 1-2, grey for none.
colors_a = ['#e74c3c' if c >= 3 else '#3498db' if c >= 1 else '#bdc3c7' for c in sig_counts.values]
ax1.barh(range(len(top_features)), sig_counts.values, color=colors_a)
ax1.set_yticks(range(len(top_features)))
ax1.set_yticklabels([str(f) for f in top_features])
ax1.set_xlabel('N Significant Demographics (FDR < 0.05)')
ax1.set_title('A. Encoding Breadth', fontweight='bold')
ax1.set_xlim(0, len(ALL_DEMOS) + 0.5)
# Write each count just right of its bar.
for i, v in enumerate(sig_counts.values):
    ax1.text(v + 0.1, i, str(v), va='center', fontsize=9)
# Put the first (most frequent) feature at the top.
ax1.invert_yaxis()

# Panel B colours: red for mean |d| >= 1.0, orange for >= 0.5, blue otherwise.
colors_b = ['#e74c3c' if d >= 1.0 else '#f39c12' if d >= 0.5 else '#3498db' for d in mean_abs_d.values]
ax2.barh(range(len(top_features)), mean_abs_d.values, color=colors_b)
ax2.set_yticks(range(len(top_features)))
ax2.set_yticklabels([str(f) for f in top_features])
ax2.set_xlabel("Mean |Cohen's d|")
ax2.set_title('B. Encoding Strength', fontweight='bold')
# Reference line at d = 0.8.
ax2.axvline(x=0.8, color='gray', linestyle='--', alpha=0.3, label='d=0.8')
ax2.legend(fontsize=8); ax2.invert_yaxis()

plt.suptitle(f'{MODEL_NAME}: Feature Encoding Summary (Layer {LAYER})', fontweight='bold', y=1.02)
plt.tight_layout()
plt.savefig(FIGS / 'fig_encoding_strength.png')
plt.savefig(FIGS / 'fig_encoding_strength.pdf')
plt.close()
print("  Saved: fig_encoding_strength")

# --- Fig 3: Effect size distribution ---

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

is_significant = enc_df['significant']
sig_d = enc_df[is_significant]['cohens_d'].values
ns_d = enc_df[~is_significant]['cohens_d'].values

# Panel A: violins of signed Cohen's d. Both groups are shown when both
# exist; only the significant group when nothing is non-significant; and a
# single pooled 'All' violin when nothing is significant.
if len(sig_d) == 0:
    violin_groups = [enc_df['cohens_d'].values]
    violin_labels = ['All']
elif len(ns_d) == 0:
    violin_groups = [sig_d]
    violin_labels = ['Significant']
else:
    violin_groups = [sig_d, ns_d]
    violin_labels = ['Significant', 'Non-significant']
violin_positions = list(range(1, len(violin_groups) + 1))
ax1.violinplot(violin_groups, positions=violin_positions, showmeans=True, showmedians=True)
ax1.set_xticks(violin_positions)
ax1.set_xticklabels(violin_labels)
ax1.set_ylabel("Cohen's d")
ax1.set_title('A. Effect Size Distribution', fontweight='bold')
ax1.axhline(0, color='gray', linewidth=0.5)
for reference_d in (0.8, -0.8):
    ax1.axhline(reference_d, color='red', linewidth=0.5, linestyle='--', alpha=0.5)

# Panel B: histograms of |d| on shared bins for both groups.
bins = np.linspace(0, enc_df['cohens_d'].abs().max() + 0.2, 20)
ax2.hist(
    enc_df.loc[is_significant, 'cohens_d'].abs(),
    bins=bins,
    alpha=0.7,
    label=f'Significant (n={len(sig_d)})',
    color='#e74c3c',
)
ax2.hist(
    enc_df.loc[~is_significant, 'cohens_d'].abs(),
    bins=bins,
    alpha=0.7,
    label=f'Non-significant (n={len(ns_d)})',
    color='#bdc3c7',
)
ax2.axvline(0.8, color='black', linestyle='--', linewidth=1, label='d=0.8')
ax2.set_xlabel("|Cohen's d|")
ax2.set_ylabel('Count')
ax2.set_title('B. Effect Size Magnitude', fontweight='bold')
ax2.legend(fontsize=9)

plt.suptitle(f'{MODEL_NAME}: Effect Size Analysis (Layer {LAYER})', fontweight='bold', y=1.02)
plt.tight_layout()
plt.savefig(FIGS / 'fig_effect_size_dist.png')
plt.savefig(FIGS / 'fig_effect_size_dist.pdf')
plt.close()
print("  Saved: fig_effect_size_dist")

# --- Fig 4: Consistency vs effect size ---

# One scatter series per demographic: stars for significant tests, crosses
# for non-significant ones. A demographic contributes a single legend entry —
# its star series when it has any significant test, otherwise the '(ns)'
# cross series.
fig, ax = plt.subplots(figsize=(8, 6))
for demo in ALL_DEMOS:
    sub = enc_df[enc_df['demographic'] == demo]
    sig_sub = sub[sub['significant']]
    ns_sub = sub[~sub['significant']]
    if len(sig_sub) > 0:
        ax.scatter(sig_sub['cohens_d'].abs(), sig_sub['consistency'],
                   c=DEMO_COLORS[demo], label=demo.capitalize(), s=70, alpha=0.8, marker='*')
    if len(ns_sub) > 0:
        ax.scatter(ns_sub['cohens_d'].abs(), ns_sub['consistency'],
                   c=DEMO_COLORS[demo], s=40, alpha=0.4, marker='x',
                   label=f'{demo.capitalize()} (ns)' if len(sig_sub) == 0 else '')

# Reference lines at 60% consistency and |d| = 0.8.
ax.axhline(60, color='gray', linestyle='--', alpha=0.3)
ax.axvline(0.8, color='gray', linestyle='--', alpha=0.3)
ax.set_xlabel("|Cohen's d|"); ax.set_ylabel("Consistency (%)")
ax.set_title(f"{MODEL_NAME}: Consistency vs Effect Size (Layer {LAYER})\n"
             f"★ = significant, × = non-significant", fontweight='bold')
# NOTE(review): the y-axis starts at 45, so any test whose consistency falls
# below that (possible when many paired differences are exactly zero, since
# ties count for neither sign) is not visible — confirm this is acceptable.
ax.legend(bbox_to_anchor=(1.02, 1), loc='upper left', fontsize=9); ax.set_ylim(45, 105)
plt.tight_layout()
plt.savefig(FIGS / 'fig_consistency_vs_effect.png')
plt.savefig(FIGS / 'fig_consistency_vs_effect.pdf')
plt.close()
print("  Saved: fig_consistency_vs_effect")

print("\nEncoding analysis done.")

In [None]:
# Feature interpretation — location analysis, multi-demographic direction, summary

# --- Location analysis ---

print("Location analysis: span vs response position")

# Regular expressions that locate the demographic description inside a
# prompt, keyed by demographic. They are tried in order and the first one
# that matches wins.
DEMO_SPAN_PATTERNS = {
    'income': [
        r'You are financially wealthy\..*?money\.',
        r'You are financially poor\..*?choices\.',
    ],
    'age': [
        r'You are 75 years old\..*?phones\.',
        r'You are elderly.*?differently\.',
        r'You are 22 years old\..*?ahead\.',
        r'You are a young adult.*?identity\.',
    ],
    'gender': [
        r'You are a man\..*?expectations\.',
        r'You are male\..*?norms\.',
        r'You are a woman\..*?expectations\.',
        r'You are female\..*?norms\.',
    ],
    'education': [
        r'You have a PhD.*?nature\.',
        r'You are highly educated.*?pursuits\.',
        r'You did not complete high school\..*?solving\.',
        r'You have limited formal education\..*?credentials\.',
    ],
    'vote': [
        r'You are a regular voter.*?ballot\.',
        r'You are a politically engaged.*?identity\.',
        r'You are a non-voter.*?matters\.',
        r'You are politically disengaged.*?life\.',
    ],
}


def find_demographic_span(prompt, token_strs, demographic):
    """Return the indices of the tokens that cover the demographic description.

    Args:
        prompt: The prompt text.
        token_strs: The prompt's tokens as strings, in order.
        demographic: Key into ``DEMO_SPAN_PATTERNS``.

    Returns:
        Set of indices into *token_strs* whose character range overlaps the
        first matching pattern; an empty set when the demographic is unknown
        or no pattern matches.

    Token character ranges are measured in the concatenation of
    *token_strs*, so the patterns are searched in that same text. Searching
    *prompt* instead is only correct when the two strings are identical;
    any extra leading token (e.g. a special token added by the tokenizer)
    would shift every span. If the concatenated text matches no pattern,
    the search falls back to *prompt* and uses its offsets as-is.
    """
    token_ranges = []
    char_pos = 0
    for i, tok in enumerate(token_strs):
        token_ranges.append((char_pos, char_pos + len(tok), i))
        char_pos += len(tok)

    patterns = DEMO_SPAN_PATTERNS.get(demographic, [])
    decoded_text = ''.join(token_strs)
    flags = re.IGNORECASE | re.DOTALL

    # First pass: offsets consistent with token_ranges. Second pass: legacy
    # behaviour, for token strings that do not reproduce the pattern text.
    for text in (decoded_text, prompt):
        for pattern in patterns:
            match = re.search(pattern, text, flags)
            if match:
                return {tok_idx for char_start, char_end, tok_idx in token_ranges
                        if char_start < match.end() and char_end > match.start()}
    return set()


def analyze_location(feature_idx, demographic, n_pairs):
    """Compare one SAE feature's activation on the demographic span vs. the last token.

    For up to ``n_pairs`` prompt pairs of ``demographic`` (rows of
    ``prompts_df`` sharing a ``pair_key``), both prompts are run through the
    model, the activations at ``hook_name`` are encoded with the SAE, and the
    activation of ``feature_idx`` is read (a) averaged over the tokens of the
    demographic description and (b) at the final token. The a-minus-b
    differences are then tested against zero with one-sample t-tests.

    Args:
        feature_idx: Column of the SAE activations to analyse.
        demographic: Value of ``prompts_df['demographic']`` to select.
        n_pairs: Maximum number of distinct ``pair_key`` values to use.

    Returns:
        A dict with keys ``n_pairs``, ``span_found_pct``, ``span_mean``,
        ``span_t``, ``span_p``, ``last_mean``, ``last_t``, ``last_p`` and
        ``ratio`` (``|last_mean / span_mean|``, or ``inf`` when
        ``|span_mean| <= 0.001``). When fewer than two differences are
        available the corresponding mean is 0 and its t/p values are NaN.
    """
    demo_df = prompts_df[prompts_df['demographic'] == demographic]
    pair_keys = demo_df['pair_key'].unique()[:n_pairs]

    span_diffs, last_diffs = [], []
    span_found_count = 0

    for pk in pair_keys:
        pair = demo_df[demo_df['pair_key'] == pk]
        rows_a = pair[pair['value_type'] == 'value_a']
        rows_b = pair[pair['value_type'] == 'value_b']
        # Skip pairs missing either side. Checking each side separately also
        # covers pairs that have two or more rows of the same value_type,
        # which would otherwise raise IndexError on .iloc[0].
        if rows_a.empty or rows_b.empty:
            continue
        row_a = rows_a.iloc[0]
        row_b = rows_b.iloc[0]

        span_acts, last_acts = {}, {}

        for prompt, label in [(row_a['prompt'], 'a'), (row_b['prompt'], 'b')]:
            tokens = model.to_tokens(prompt)
            token_strs = [tokenizer.decode([t]) for t in tokens[0]]

            with torch.inference_mode():
                _, cache = model.run_with_cache(tokens, names_filter=lambda n: hook_name in n)
                # First (only) batch element, all positions.
                h = cache[hook_name][0, :, :].float()
                sae_acts = sae.encode(h)
                # Activation of the requested feature at every position.
                acts = sae_acts[:, feature_idx].float().cpu().numpy()
                del cache

            demo_indices = find_demographic_span(prompt, token_strs, demographic)
            valid_indices = [i for i in demo_indices if i < len(acts)]
            if valid_indices:
                span_acts[label] = np.mean([acts[i] for i in valid_indices])
                span_found_count += 1
            else:
                # No span located for this prompt; the pair is excluded from
                # the span statistics but still counted for the last token.
                span_acts[label] = None

            last_acts[label] = float(acts[-1])

        if span_acts.get('a') is not None and span_acts.get('b') is not None:
            span_diffs.append(span_acts['a'] - span_acts['b'])
        last_diffs.append(last_acts['a'] - last_acts['b'])

    # Two prompts per processed pair; 1 avoids division by zero when no pair ran.
    n_total_prompts = 2 * len(last_diffs) if last_diffs else 1
    result = {'n_pairs': len(last_diffs), 'span_found_pct': span_found_count / n_total_prompts * 100}

    if len(span_diffs) > 1:
        result['span_mean'] = np.mean(span_diffs)
        result['span_t'], result['span_p'] = stats.ttest_1samp(span_diffs, 0)
    else:
        result['span_mean'] = 0
        result['span_t'] = result['span_p'] = np.nan

    if len(last_diffs) > 1:
        result['last_mean'] = np.mean(last_diffs)
        result['last_t'], result['last_p'] = stats.ttest_1samp(last_diffs, 0)
    else:
        result['last_mean'] = 0
        result['last_t'] = result['last_p'] = np.nan

    if abs(result['span_mean']) > 0.001:
        result['ratio'] = abs(result['last_mean'] / result['span_mean'])
    else:
        # Span effect too small for a meaningful ratio.
        result['ratio'] = float('inf')
    return result


# Run the span-vs-last-token comparison for every (feature, demographic) cell
# and tag each result with the cell it belongs to.
location_rows = []
for feat in tqdm(top_features, desc="Location"):
    for demo in ALL_DEMOS:
        loc = analyze_location(feat, demo, N_PAIRS_LOCATION)
        loc.update(feature=feat, demographic=demo)
        location_rows.append(loc)

loc_df = pd.DataFrame(location_rows)
loc_df.to_csv(DATA / 'activation_location.csv', index=False)

# Only ratios strictly between 0 and 1000 enter the aggregate statistics;
# this drops the infinite ratios assigned when the span effect is ~0.
ratio_col = loc_df['ratio']
valid_ratios = ratio_col[(ratio_col < 1000) & (ratio_col > 0)]
if len(valid_ratios) > 0:
    mean_ratio = valid_ratios.mean()
    median_ratio = valid_ratios.median()
else:
    mean_ratio = float('inf')
    median_ratio = float('inf')

# Count cells with p < 0.05, ignoring cells whose p-value is NaN.
n_span_sig = loc_df['span_p'].dropna().lt(0.05).sum()
n_last_sig = loc_df['last_p'].dropna().lt(0.05).sum()

print(f"\nSpan significant: {n_span_sig}/{len(loc_df)}")
print(f"Last significant: {n_last_sig}/{len(loc_df)}")
print(f"Mean |last/span| ratio: {mean_ratio:.1f}x, median: {median_ratio:.1f}x")

# --- Fig 5: Location heatmap ---
# Heatmap of the mean last-token activation difference (a - b) for every
# feature (rows) and demographic (columns), read from loc_df.

loc_heat = np.zeros((len(top_features), len(ALL_DEMOS)))
for i, feat in enumerate(top_features):
    for j, demo in enumerate(ALL_DEMOS):
        r = loc_df[(loc_df['feature'] == feat) & (loc_df['demographic'] == demo)]
        # Cells without a row in loc_df are shown as 0.
        loc_heat[i, j] = r['last_mean'].iloc[0] if len(r) > 0 else 0

# Figure height grows with the number of features, with a floor of 5.
fig, ax = plt.subplots(figsize=(8, max(5, len(top_features) * 0.5 + 1)))
# Symmetric colour scale around 0, at least +/-0.5 wide.
vmax = max(0.5, np.abs(loc_heat).max())
sns.heatmap(loc_heat, annot=True, fmt='+.3f', cmap='RdBu_r', center=0,
            vmin=-vmax, vmax=vmax,
            xticklabels=[d.capitalize() for d in ALL_DEMOS],
            yticklabels=[str(f) for f in top_features],
            ax=ax, linewidths=0.5, cbar_kws={'label': 'Activation Diff (a-b)'})
ax.set_title(f'{MODEL_NAME}: Last-Token Activation Difference (Layer {LAYER})', fontweight='bold')
ax.set_ylabel('Feature')
plt.tight_layout()
plt.savefig(FIGS / 'fig_location_heatmap.png')
plt.savefig(FIGS / 'fig_location_heatmap.pdf')
plt.close()
print("  Saved: fig_location_heatmap")

# --- Fig 6: Span vs last scatter ---
# Scatter of span-averaged vs last-token activation differences, one point per
# (feature, demographic) cell, coloured by demographic.

fig, ax = plt.subplots(figsize=(7, 7))
# Keep cells with both means present and a span located in more than half of
# the prompts.
valid_loc = loc_df.dropna(subset=['span_mean', 'last_mean'])
valid_loc = valid_loc[valid_loc['span_found_pct'] > 50]

for demo in ALL_DEMOS:
    sub = valid_loc[valid_loc['demographic'] == demo]
    ax.scatter(sub['span_mean'], sub['last_mean'],
               color=DEMO_COLORS[demo], label=demo.capitalize(), s=60, alpha=0.8)

if len(valid_loc) > 0:
    # y = x reference line spanning 110% of the largest absolute value plotted.
    lim = max(abs(valid_loc['span_mean']).max(), abs(valid_loc['last_mean']).max()) * 1.1
    ax.plot([-lim, lim], [-lim, lim], 'k--', alpha=0.3, linewidth=1, label='Equal')

# The Pearson correlation is only reported when at least three points remain.
if len(valid_loc) >= 3:
    r, p = stats.pearsonr(valid_loc['span_mean'], valid_loc['last_mean'])
    ax.set_title(f'{MODEL_NAME}: Span vs Response-Position\nr={r:.2f} ({"p<.001" if p < 0.001 else f"p={p:.3f}"})',
                 fontweight='bold')
else:
    ax.set_title(f'{MODEL_NAME}: Span vs Response-Position', fontweight='bold')

ax.set_xlabel('Span Diff (a-b)'); ax.set_ylabel('Last-Token Diff (a-b)')
ax.legend(); ax.axhline(0, color='gray', linewidth=0.5); ax.axvline(0, color='gray', linewidth=0.5)
plt.tight_layout()
plt.savefig(FIGS / 'fig_span_vs_last.png')
plt.savefig(FIGS / 'fig_span_vs_last.pdf')
plt.close()
print("  Saved: fig_span_vs_last")

# --- Fig 7: Ratio by demographic ---
# Two panels: (A) per-cell |last/span| ratio heatmap, (B) mean and median
# ratio per demographic.

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(13, 5))

ratio_matrix = np.zeros((len(top_features), len(ALL_DEMOS)))
for i, feat in enumerate(top_features):
    for j, demo in enumerate(ALL_DEMOS):
        r = loc_df[(loc_df['feature'] == feat) & (loc_df['demographic'] == demo)]
        # Ratios are capped at 200 (this also maps inf to 200); cells with no
        # row in loc_df become NaN and are masked out of the heatmap below.
        ratio_matrix[i, j] = min(r['ratio'].iloc[0], 200) if len(r) > 0 else np.nan

sns.heatmap(ratio_matrix, annot=True, fmt='.0f', cmap='YlOrRd',
            xticklabels=[d.capitalize() for d in ALL_DEMOS],
            yticklabels=[str(f) for f in top_features],
            ax=ax1, linewidths=0.5, mask=np.isnan(ratio_matrix),
            cbar_kws={'label': '|Last / Span| Ratio'})
ax1.set_title('A. Last/Span Ratio (capped 200)', fontweight='bold'); ax1.set_ylabel('Feature')

# Per-demographic summary over the uncapped ratios below 1000 (inf excluded);
# a demographic with no such ratios is reported as 0.
demo_ratios = {}
for demo in ALL_DEMOS:
    sub = loc_df[(loc_df['demographic'] == demo) & (loc_df['ratio'] < 1000)]
    demo_ratios[demo] = {'mean': sub['ratio'].mean() if len(sub) > 0 else 0,
                          'median': sub['ratio'].median() if len(sub) > 0 else 0}

means = [demo_ratios[d]['mean'] for d in ALL_DEMOS]
medians = [demo_ratios[d]['median'] for d in ALL_DEMOS]
# Bars show the mean; black diamonds overlay the median.
ax2.bar(range(len(ALL_DEMOS)), means, color=[DEMO_COLORS[d] for d in ALL_DEMOS], alpha=0.8, label='Mean')
ax2.scatter(range(len(ALL_DEMOS)), medians, color='black', zorder=5, s=50, marker='D', label='Median')
ax2.set_xticks(range(len(ALL_DEMOS)))
ax2.set_xticklabels([d.capitalize() for d in ALL_DEMOS])
ax2.set_ylabel('|Last / Span| Ratio'); ax2.set_title('B. Mean Ratio by Demographic', fontweight='bold')
ax2.legend(fontsize=9)

plt.suptitle(f'{MODEL_NAME}: Span vs Response-Position Ratio (Layer {LAYER})', fontweight='bold', y=1.02)
plt.tight_layout()
plt.savefig(FIGS / 'fig_location_ratio.png')
plt.savefig(FIGS / 'fig_location_ratio.pdf')
plt.close()
print("  Saved: fig_location_ratio")


# --- Multi-demographic direction ---
# For every feature that is significant for at least two demographics, record
# whether the sign of the a-minus-b difference agrees across them.

print("\nMulti-demographic encoding direction")

direction_results = []
for feat in top_features:
    sig_enc = enc_df[(enc_df['feature'] == feat) & (enc_df['significant'])].copy()
    if len(sig_enc) < 2:
        continue

    pos_demos = sig_enc.loc[sig_enc['mean_diff'] > 0, 'demographic'].tolist()
    neg_demos = sig_enc.loc[sig_enc['mean_diff'] < 0, 'demographic'].tolist()
    pos_text = ', '.join(pos_demos)
    neg_text = ', '.join(neg_demos)

    # "Opposite" means the feature has both positive and negative significant
    # demographics.
    is_opposite = bool(pos_demos and neg_demos)
    if is_opposite:
        description = f"value_a > value_b for {pos_text}; value_b > value_a for {neg_text}"
    else:
        direction = 'value_a > value_b' if pos_demos else 'value_b > value_a'
        all_demos = pos_demos or neg_demos
        description = f"Same direction ({direction}) for {', '.join(all_demos)}"

    direction_results.append({
        'feature': feat,
        'n_sig_demos': len(sig_enc),
        'positive_demos': pos_text,
        'negative_demos': neg_text,
        'opposite_direction': is_opposite,
        'description': description,
    })

dir_df = pd.DataFrame(direction_results)

if len(dir_df) == 0:
    n_opposite, n_same = 0, 0
    print("  No features with 2+ significant demographics")
else:
    n_opposite = dir_df['opposite_direction'].sum()
    n_same = (~dir_df['opposite_direction']).sum()
    print(f"  Features with 2+ sig demos: {len(dir_df)}")
    print(f"  Opposite-direction: {n_opposite}, Same-direction: {n_same}")
    for is_opp, feat_id, desc in zip(dir_df['opposite_direction'], dir_df['feature'], dir_df['description']):
        tag = 'OPP' if is_opp else 'SAME'
        print(f"  [{tag}] Feature {feat_id}: {desc}")
    dir_df.to_csv(DATA / 'multi_demo_direction.csv', index=False)

# --- Fig 8: Multi-demographic direction profile ---
# Heatmap of signed Cohen's d for features significant in 2+ demographics.
# Non-significant cells are greyed out; features whose significant effects
# point in opposite directions get a gold outline.

# Features with at least two significant rows in enc_df.
multi_demo_feats = enc_df[enc_df['significant']].groupby('feature').filter(
    lambda x: len(x) >= 2)['feature'].unique()

# The figure is only drawn when at least two such features exist.
if len(multi_demo_feats) >= 2:
    poly_heat = np.zeros((len(multi_demo_feats), len(ALL_DEMOS)))
    # True marks a cell that is NOT significant (or has no row in enc_df).
    sig_mask = np.ones_like(poly_heat, dtype=bool)

    for i, feat in enumerate(multi_demo_feats):
        for j, demo in enumerate(ALL_DEMOS):
            r = enc_df[(enc_df['feature'] == feat) & (enc_df['demographic'] == demo)]
            if len(r) > 0 and r.iloc[0]['significant']:
                poly_heat[i, j] = r.iloc[0]['cohens_d']
                sig_mask[i, j] = False

    fig, ax = plt.subplots(figsize=(8, max(4, len(multi_demo_feats) * 0.6 + 1)))
    # Symmetric colour scale, at least +/-2 wide.
    vmax = max(2.0, np.abs(poly_heat).max())
    sns.heatmap(poly_heat, annot=True, fmt='+.1f', cmap='RdBu_r', center=0,
                vmin=-vmax, vmax=vmax,
                xticklabels=[d.capitalize() for d in ALL_DEMOS],
                yticklabels=[str(f) for f in multi_demo_feats],
                ax=ax, linewidths=0.5, cbar_kws={'label': "Cohen's d (signed)"})

    # Overlay a translucent grey patch and an 'ns' label on non-significant
    # cells (drawn on top of the heatmap's own annotation for that cell).
    for i in range(len(multi_demo_feats)):
        for j in range(len(ALL_DEMOS)):
            if sig_mask[i, j]:
                ax.add_patch(plt.Rectangle((j, i), 1, 1, fill=True, facecolor='lightgray', alpha=0.7))
                ax.text(j + 0.5, i + 0.5, 'ns', ha='center', va='center', fontsize=8, color='gray')

    # Outline whole rows for features flagged as opposite-direction in dir_df.
    if len(dir_df) > 0:
        opp_feats = set(dir_df[dir_df['opposite_direction']]['feature'].values)
        for i, feat in enumerate(multi_demo_feats):
            if feat in opp_feats:
                ax.add_patch(plt.Rectangle((0, i), len(ALL_DEMOS), 1,
                             fill=False, edgecolor='gold', linewidth=2.5))

    ax.set_title(f"{MODEL_NAME}: Multi-Demographic Encoding Direction (Layer {LAYER})\n"
                 f"gray = ns, gold = opposite-direction", fontweight='bold')
    ax.set_ylabel('Feature')
    plt.tight_layout()
    plt.savefig(FIGS / 'fig_multi_demo_direction.png')
    plt.savefig(FIGS / 'fig_multi_demo_direction.pdf')
    plt.close()
    print("  Saved: fig_multi_demo_direction")
else:
    print("  Skipped direction figure (< 2 multi-demo features)")


# --- Summary and bundle ---
# Collect headline numbers into a JSON summary and a README, zip the bundle
# directory, then release the SAE.

print("\nSummary")

# NOTE(review): n_total, n_sig and enc_df are not defined in this section;
# presumably they come from the encoding analysis earlier in the notebook.
summary = {
    'model': MODEL_NAME,
    'parameters': {
        'layer': LAYER, 'layer_depth': f"{LAYER/N_LAYERS_TOTAL:.0%}",
        'n_features': N_TOP_FEATURES, 'n_pairs_encoding': N_PAIRS_ENCODING,
        'n_pairs_location': N_PAIRS_LOCATION, 'alpha': ALPHA,
        'features_tested': top_features,
    },
    'encoding': {
        'n_tests': n_total, 'n_significant': int(n_sig),
        'pct_significant': round(n_sig / n_total * 100, 1),
        'mean_abs_d': round(float(enc_df['cohens_d'].abs().mean()), 2),
    },
    'location': {
        'n_span_significant': int(n_span_sig), 'n_last_significant': int(n_last_sig),
        # Ratios of 1000 or more (including inf) are written as the string 'inf'.
        'mean_ratio': round(mean_ratio, 1) if mean_ratio < 1000 else 'inf',
        'median_ratio': round(median_ratio, 1) if median_ratio < 1000 else 'inf',
    },
    'multi_demographic': {
        'n_multi_demo_features': len(dir_df),
        'n_opposite': int(n_opposite), 'n_same': int(n_same),
    },
}

# Echo the summary to stdout, one section at a time.
for section, vals in summary.items():
    print(f"\n{section}:")
    if isinstance(vals, dict):
        for k, v in vals.items(): print(f"  {k}: {v}")
    else: print(f"  {vals}")

# default=str stringifies any value json cannot serialise natively.
with open(DATA / 'interpretation_summary.json', 'w') as f:
    json.dump(summary, f, indent=2, default=str)

readme = f"""# Feature Interpretation — {MODEL_NAME}

## Analyses
1. Encoding Matrix: {N_TOP_FEATURES} features x {len(ALL_DEMOS)} demographics
   - Paired t-tests, BH-FDR correction
   - {n_sig}/{n_total} significant ({n_sig/n_total*100:.1f}%)

2. Activation Location: span vs response-position
   - Mean |last/span| ratio: {mean_ratio:.1f}x

3. Multi-Demographic Direction: {len(dir_df)} features with 2+ sig demos
   - {n_opposite} opposite-direction, {n_same} same-direction

## Model
- {MODEL_NAME}, Layer {LAYER} ({LAYER/N_LAYERS_TOTAL:.0%} depth)
- SAE: gemma-scope 16k residual stream
"""

with open(BUNDLE / 'README.md', 'w') as f:
    f.write(readme)

# Writes "<BUNDLE>.zip" next to the bundle directory, containing that directory.
shutil.make_archive(str(BUNDLE), 'zip', root_dir=BUNDLE.parent, base_dir=BUNDLE.name)

n_png = len(list(FIGS.glob('*.png')))
print(f"\n  {n_png} figures, {len(list(TABLES.glob('*')))} tables, {len(list(DATA.glob('*')))} data files")
# NOTE(review): the archive name printed here is hard-coded; confirm it
# matches BUNDLE.name.
print(f"  Download: feature-interpretation-results-gemma.zip")

# Drop the SAE reference and release cached GPU memory. After this, `sae` is
# no longer defined, so the analysis above cannot be re-run without reloading it.
del sae
torch.cuda.empty_cache()
print("\nDone.")

In [None]:
# concern diagnostics — extraction-side (concerns 1-6)

import pandas as pd
import numpy as np
import json
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from pathlib import Path
from collections import Counter, defaultdict
from scipy import stats
import warnings
warnings.filterwarnings('ignore')

# Layers analysed and the depth (in percent) used to place each on plot axes.
ANALYSIS_LAYERS = [5, 9, 14, 18, 20, 27, 32, 36]
LAYER_DEPTH = {5: 12, 9: 22, 14: 34, 18: 44, 20: 49, 27: 66, 32: 78, 36: 88}
# Layers highlighted in the figures as "IT SAE".
IT_LAYERS = {20}
DEMOGRAPHICS = ['income', 'age', 'gender', 'education', 'vote']
DOMAINS = ['climate', 'health', 'digital', 'economy', 'values']

BASE_DIR = Path("/content")
OUTPUT_DIR = Path("/content/reviewer-diagnostics")
OUTPUT_DIR.mkdir(exist_ok=True)
(OUTPUT_DIR / "figures").mkdir(exist_ok=True)
(OUTPUT_DIR / "data").mkdir(exist_ok=True)

# Locate the extraction results: try the two known directory names first,
# then fall back to the first feature_stats.csv found anywhere under BASE_DIR.
DATA_DIR = None
for candidate in [
    BASE_DIR / "feature-extraction-results-gemma" / "data",
    BASE_DIR / "feature_extraction_results_gemma" / "data",
]:
    if candidate.exists():
        DATA_DIR = candidate; break
if DATA_DIR is None:
    for match in BASE_DIR.rglob("feature_stats.csv"):
        DATA_DIR = match.parent; break
if DATA_DIR is None:
    raise FileNotFoundError("Cannot find feature extraction data directory")

print(f"Data: {DATA_DIR}")

feat_df = pd.read_csv(DATA_DIR / "feature_stats.csv")
funnel_df = pd.read_csv(DATA_DIR / "filtering_funnel.csv")
behav_df = pd.read_csv(DATA_DIR / "behavioral_effects.csv")
# NOTE: this rebinds the name `summary`, replacing any earlier value.
with open(DATA_DIR / "extraction_summary.json") as f:
    summary = json.load(f)
with open(DATA_DIR / "selected_features.json") as f:
    selected = json.load(f)

print(f"  feature_stats: {len(feat_df)}, funnel: {len(funnel_df)}, behavioral: {len(behav_df)}")

# Accumulator for the per-concern diagnostics computed below.
diag = {'model': 'Gemma 2 9B IT', 'concerns': {}}

plt.rcParams.update({'font.size': 10, 'axes.titlesize': 11, 'axes.labelsize': 10, 'figure.dpi': 150})


# --- Concern 1: Extreme effect sizes ---
# Characterise how many features have very large |Cohen's d| and where they
# occur; produces a 2x2 diagnostic figure and a diag entry.

print("\nConcern 1: Extreme effect sizes")

feat_df['abs_d'] = feat_df['cohens_d'].abs()
feat_df['abs_mean_diff'] = feat_df['mean_diff'].abs()
# NOTE(review): the lowest bin starts at 0 although its label says 0.3 —
# presumably feat_df is already filtered to |d| >= 0.3; confirm. With
# pd.cut's default right-closed bins, |d| == 0 falls outside every bin.
feat_df['d_category'] = pd.cut(
    feat_df['abs_d'],
    bins=[0, 0.8, 2.0, 5.0, 10.0, np.inf],
    labels=['medium (0.3-0.8)', 'large (0.8-2)', 'very large (2-5)',
            'extreme (5-10)', 'implausible (>10)'])

cat_counts = feat_df['d_category'].value_counts().sort_index()
for cat, n in cat_counts.items():
    print(f"  {cat:<25} {n:>5} ({n/len(feat_df)*100:>5.1f}%)")

extreme_df = feat_df[feat_df['abs_d'] > 5.0]
print(f"\n  Extreme (|d|>5): {len(extreme_df)} ({len(extreme_df)/len(feat_df)*100:.1f}%)")

if len(extreme_df) > 0:
    print(f"  Mean |mean_diff| extreme: {extreme_df['abs_mean_diff'].mean():.2f}")
    print(f"  Mean |mean_diff| normal:  {feat_df[feat_df['abs_d'] <= 5]['abs_mean_diff'].mean():.2f}")
    # Number of distinct demographics in which each (layer, feature) is extreme.
    extreme_multi = extreme_df.groupby(['layer', 'feature_idx'])['demographic'].nunique()
    print(f"  Extreme encoding 1 demo: {(extreme_multi == 1).sum()}, 2+: {(extreme_multi >= 2).sum()}")

# SD implied by d = diff / SD; |d| is floored at 0.01 to avoid division by ~0.
feat_df['implied_sd'] = feat_df['abs_mean_diff'] / feat_df['abs_d'].clip(lower=0.01)

fig, axes = plt.subplots(2, 2, figsize=(12, 10))
fig.suptitle("Concern 1: Extreme Effect Sizes", fontsize=13, fontweight='bold')

# Panel A: |d| histograms, one per layer.
ax = axes[0, 0]
for layer in ANALYSIS_LAYERS:
    ldf = feat_df[feat_df['layer'] == layer]
    ax.hist(ldf['abs_d'], bins=50, alpha=0.5, label=f"L{layer}", density=True)
ax.axvline(5.0, color='red', linestyle='--', alpha=0.7, label='|d|=5')
ax.set_xlabel("|Cohen's d|"); ax.set_ylabel("Density")
ax.set_title("A. Effect size distribution by layer")
ax.set_xlim(0, min(feat_df['abs_d'].max() + 1, 15)); ax.legend(fontsize=7, ncol=2)

# Panel B: raw mean difference vs |d| for four of the layers.
ax = axes[0, 1]
for layer in [5, 14, 27, 36]:
    ldf = feat_df[feat_df['layer'] == layer]
    ax.scatter(ldf['abs_mean_diff'], ldf['abs_d'], s=8, alpha=0.4, label=f"L{layer}")
ax.set_xlabel("|Mean Activation Diff|"); ax.set_ylabel("|Cohen's d|")
ax.set_title("B. Raw diff vs Cohen's d")
ax.axhline(5.0, color='red', linestyle='--', alpha=0.5); ax.legend(fontsize=8)

# Panel C: implied SD per feature by depth, with the per-layer median in red.
ax = axes[1, 0]
for layer in ANALYSIS_LAYERS:
    ldf = feat_df[feat_df['layer'] == layer]
    ax.scatter([LAYER_DEPTH[layer]] * len(ldf), ldf['implied_sd'], s=5, alpha=0.3)
    ax.scatter(LAYER_DEPTH[layer], ldf['implied_sd'].median(), s=40, c='red', zorder=5, marker='_', linewidths=2)
ax.set_xlabel("Layer Depth (%)"); ax.set_ylabel("Implied Pooled SD")
ax.set_title("C. Feature activation spread by depth"); ax.set_yscale('log')

# Panel D: count of extreme features per (demographic, domain).
ax = axes[1, 1]
if len(extreme_df) > 0:
    pivot = extreme_df.groupby(['demographic', 'domain']).size().unstack(fill_value=0)
    # Add zero rows/columns for groups with no extreme features so the grid
    # always covers every demographic and domain, in a fixed order.
    for demo in DEMOGRAPHICS:
        if demo not in pivot.index: pivot.loc[demo] = 0
    for dom in DOMAINS:
        if dom not in pivot.columns: pivot[dom] = 0
    pivot = pivot.reindex(DEMOGRAPHICS)[DOMAINS]
    im = ax.imshow(pivot.values, aspect='auto', cmap='Reds')
    ax.set_xticks(range(len(DOMAINS))); ax.set_xticklabels(DOMAINS, rotation=45, ha='right')
    ax.set_yticks(range(len(DEMOGRAPHICS))); ax.set_yticklabels(DEMOGRAPHICS)
    for i in range(len(DEMOGRAPHICS)):
        for j in range(len(DOMAINS)):
            ax.text(j, i, str(int(pivot.values[i, j])), ha='center', va='center', fontsize=9)
    plt.colorbar(im, ax=ax, shrink=0.8)
else:
    ax.text(0.5, 0.5, "No extreme features", ha='center', va='center', transform=ax.transAxes)
ax.set_title("D. Extreme features (|d|>5) by group")

plt.tight_layout()
fig.savefig(OUTPUT_DIR / "figures" / "fig_concern1_extreme_effects.png", dpi=150, bbox_inches='tight')
fig.savefig(OUTPUT_DIR / "figures" / "fig_concern1_extreme_effects.pdf", bbox_inches='tight')
plt.close()
print("  Saved: fig_concern1_extreme_effects")

diag['concerns']['extreme_effects'] = {
    'n_extreme_d5': int(len(extreme_df)),
    'n_extreme_d10': int((feat_df['abs_d'] > 10).sum()),
    'pct_extreme_d5': float(len(extreme_df) / len(feat_df) * 100),
    'max_d': float(feat_df['abs_d'].max()),
    'category_distribution': {str(k): int(v) for k, v in cat_counts.items()},
}


# --- Concern 2: Filtering funnel ---
# Compare effect sizes of selected ("clean") vs rejected ("noisy") features
# per layer and visualise the selection funnel.

print("\nConcern 2: Filtering funnel")

# Per-layer mean of the funnel's clean / noisy |d| columns.
layer_clean_d, layer_noisy_d = {}, {}
for layer in ANALYSIS_LAYERS:
    lf = funnel_df[funnel_df['layer'] == layer]
    layer_clean_d[layer] = lf['mean_abs_d_clean'].mean()
    layer_noisy_d[layer] = lf['mean_abs_d_noisy'].mean()
    # Denominator floored at 1e-6 to avoid division by zero.
    ratio = layer_clean_d[layer] / max(layer_noisy_d[layer], 1e-6)
    print(f"  L{layer}: clean |d|={layer_clean_d[layer]:.4f}, noisy={layer_noisy_d[layer]:.4f}, ratio={ratio:.0f}x")

fig, axes = plt.subplots(2, 2, figsize=(12, 10))
fig.suptitle("Concern 2: Feature Selection Funnel", fontsize=13, fontweight='bold')

# Panel A: stacked bar of funnel stages per layer.
# NOTE(review): the stack omits features that pass the noise filter but are
# not selected (n_not_noisy - n_selected), so the bars need not sum to
# n_total_features — confirm this is intended.
ax = axes[0, 0]
funnel_stages = []
for layer in ANALYSIS_LAYERS:
    lf = funnel_df[funnel_df['layer'] == layer]
    funnel_stages.append({
        'layer': f"L{layer}",
        'No variance': lf['n_total_features'].iloc[0] - lf['n_has_variance'].mean(),
        'Noisy': lf['n_has_variance'].mean() - lf['n_not_noisy'].mean(),
        'Selected': lf['n_selected'].mean(),
    })
pd.DataFrame(funnel_stages).set_index('layer').plot(
    kind='bar', stacked=True, ax=ax, color=['#d4d4d4', '#ffb3b3', '#2ecc71'])
ax.set_ylabel("Feature Count"); ax.set_title("A. Filtering funnel")
ax.set_xticklabels(ax.get_xticklabels(), rotation=0); ax.legend(fontsize=8)

# Panel B: mean |d| of selected vs rejected features, log scale.
ax = axes[0, 1]
x = np.arange(len(ANALYSIS_LAYERS)); width = 0.35
ax.bar(x - width/2, [layer_clean_d[l] for l in ANALYSIS_LAYERS], width, label='Selected', color='#2ecc71')
ax.bar(x + width/2, [layer_noisy_d[l] for l in ANALYSIS_LAYERS], width, label='Rejected', color='#e74c3c')
ax.set_xticks(x); ax.set_xticklabels([f"L{l}" for l in ANALYSIS_LAYERS])
ax.set_ylabel("Mean |Cohen's d|"); ax.set_title("B. Encoding: selected vs rejected")
ax.set_yscale('log'); ax.legend()

# Panel C: per-row counts by depth. Only layer 5 carries legend labels so
# each series appears once in the legend.
ax = axes[1, 0]
for layer in ANALYSIS_LAYERS:
    lf = funnel_df[funnel_df['layer'] == layer]
    ax.scatter([LAYER_DEPTH[layer]] * len(lf), lf['n_has_variance'], s=15, alpha=0.5, c='gray',
               label='Has variance' if layer == 5 else '')
    ax.scatter([LAYER_DEPTH[layer]] * len(lf), lf['n_selected'], s=15, alpha=0.8, c='green',
               label='Selected' if layer == 5 else '')
ax.set_xlabel("Layer Depth (%)"); ax.set_ylabel("N Features")
ax.set_title("C. Variance vs selected per group"); ax.legend()

# Panel D: distribution of mean sign agreement with the 0.7 threshold marked.
ax = axes[1, 1]
ax.hist(funnel_df['mean_sign_agreement'], bins=30, color='steelblue', edgecolor='white')
ax.axvline(0.7, color='red', linestyle='--', label='Threshold (0.7)')
ax.set_xlabel("Mean Sign Agreement"); ax.set_ylabel("Count")
ax.set_title("D. Sign agreement distribution"); ax.legend()

plt.tight_layout()
fig.savefig(OUTPUT_DIR / "figures" / "fig_concern2_funnel_distribution.png", dpi=150, bbox_inches='tight')
fig.savefig(OUTPUT_DIR / "figures" / "fig_concern2_funnel_distribution.pdf", bbox_inches='tight')
plt.close()
print("  Saved: fig_concern2_funnel_distribution")

# Averages of the per-layer means (each layer weighted equally).
diag['concerns']['funnel'] = {
    'mean_d_clean': float(np.mean(list(layer_clean_d.values()))),
    'mean_d_noisy': float(np.mean(list(layer_noisy_d.values()))),
    'clean_noisy_ratio': float(np.mean(list(layer_clean_d.values())) / max(np.mean(list(layer_noisy_d.values())), 1e-6)),
}


# --- Concern 3: Behavioral effects are layer-independent ---
# Check whether the behavioral effect per (demographic, domain) is the same
# in the layer-5 and layer-36 rows of behav_df, then list the layer-5 effects.

print("\nConcern 3: Behavioral effects")

l5_effects = behav_df[behav_df['layer'] == 5].set_index(['demographic', 'domain'])['effect']
l36_effects = behav_df[behav_df['layer'] == 36].set_index(['demographic', 'domain'])['effect']
# Only (demographic, domain) groups present at both layers are compared.
common_idx = l5_effects.index.intersection(l36_effects.index)

if len(common_idx) > 0:
    corr = l5_effects.loc[common_idx].corr(l36_effects.loc[common_idx])
    max_diff = (l5_effects.loc[common_idx] - l36_effects.loc[common_idx]).abs().max()
    print(f"  L5 vs L36 correlation: {corr:.6f}")
    print(f"  Max absolute difference: {max_diff:.6f}")
    # Report the verdict the data supports. A large (or NaN) difference must
    # not be labelled "near-identical", so anything at or above 0.01 is
    # flagged as a mismatch instead.
    if max_diff < 0.001:
        verdict = 'identical'
    elif max_diff < 0.01:
        verdict = 'near-identical'
    else:
        verdict = 'NOT identical'
    print(f"  Effects are {verdict} across layers")

diag['concerns']['behavioral_layer_independence'] = {
    'l5_l36_correlation': float(corr) if len(common_idx) > 0 else None,
    'max_layer_difference': float(max_diff) if len(common_idx) > 0 else None,
}

# Effects of the first analysis layer, with significance stars from p_value.
print("\n  Canonical effects (layer-independent):")
for _, r in behav_df[behav_df['layer'] == ANALYSIS_LAYERS[0]].iterrows():
    sig = "***" if r['p_value'] < 0.001 else "**" if r['p_value'] < 0.01 else "*" if r['p_value'] < 0.05 else "ns"
    print(f"    {r['demographic']}×{r['domain']:<14} {r['effect']:>+.4f} {sig}")


# --- Concern 4: Metric divergence ---
# Contrast how |Cohen's d|, the raw activation difference and the implied SD
# change with layer depth.

print("\nConcern 4: Metric divergence")

metric_summary = []
for layer in ANALYSIS_LAYERS:
    ldf = feat_df[feat_df['layer'] == layer]
    # NOTE: despite its name, 'mean_implied_sd' stores the per-layer MEDIAN
    # of implied_sd.
    metric_summary.append({
        'layer': layer, 'depth': LAYER_DEPTH[layer],
        'mean_abs_d': ldf['abs_d'].mean(), 'median_abs_d': ldf['abs_d'].median(),
        'mean_abs_diff': ldf['abs_mean_diff'].mean(), 'mean_implied_sd': ldf['implied_sd'].median(),
        'n_features': len(ldf),
    })
metric_df = pd.DataFrame(metric_summary)

for _, r in metric_df.iterrows():
    print(f"  L{int(r['layer'])}: |d|={r['mean_abs_d']:.3f}, |Δ|={r['mean_abs_diff']:.2f}, SD={r['mean_implied_sd']:.2f}")

fig, axes = plt.subplots(1, 3, figsize=(14, 5))
fig.suptitle("Concern 4: Raw Activation Diff vs Cohen's d", fontsize=13, fontweight='bold')

# Panel A: |d| (left axis) and |activation diff| (right axis) against depth.
ax1 = axes[0]
ax1.plot(metric_df['depth'], metric_df['mean_abs_d'], 'o-', color='#2c3e50', linewidth=2, label="|Cohen's d|")
ax1.set_xlabel("Depth (%)"); ax1.set_ylabel("|Cohen's d|", color='#2c3e50')
ax2_twin = ax1.twinx()
ax2_twin.plot(metric_df['depth'], metric_df['mean_abs_diff'], 's--', color='#e74c3c', linewidth=2, label="|Activation Δ|")
ax2_twin.set_ylabel("|Activation Δ|", color='#e74c3c')
ax1.set_title("A. Divergent metrics")
# Merge the handles of both y-axes into a single legend.
lines1, labels1 = ax1.get_legend_handles_labels()
lines2, labels2 = ax2_twin.get_legend_handles_labels()
ax1.legend(lines1 + lines2, labels1 + labels2, fontsize=8, loc='center left')

# Panel B: implied-SD violin per layer; IT_LAYERS are drawn in orange.
ax = axes[1]
for layer in ANALYSIS_LAYERS:
    ldf = feat_df[feat_df['layer'] == layer]
    parts = ax.violinplot([ldf['implied_sd'].values], positions=[LAYER_DEPTH[layer]],
                          widths=6, showmedians=True, showextrema=False)
    color = 'darkorange' if layer in IT_LAYERS else 'steelblue'
    for pc in parts['bodies']: pc.set_facecolor(color); pc.set_alpha(0.6)
    parts['cmedians'].set_color('red')
ax.set_xlabel("Depth (%)"); ax.set_ylabel("Implied Pooled SD")
ax.set_title("B. Activation spread increases with depth"); ax.set_yscale('log')

# Panel C: the three metrics min-max normalised to [0, 1] for comparison.
# NOTE(review): a metric that is constant across layers gives a zero
# denominator here (NaN curve).
ax = axes[2]
d_norm = (metric_df['mean_abs_d'] - metric_df['mean_abs_d'].min()) / (metric_df['mean_abs_d'].max() - metric_df['mean_abs_d'].min())
diff_norm = (metric_df['mean_abs_diff'] - metric_df['mean_abs_diff'].min()) / (metric_df['mean_abs_diff'].max() - metric_df['mean_abs_diff'].min())
sd_norm = (metric_df['mean_implied_sd'] - metric_df['mean_implied_sd'].min()) / (metric_df['mean_implied_sd'].max() - metric_df['mean_implied_sd'].min())
ax.plot(metric_df['depth'], d_norm, 'o-', label="|d| (↓)", linewidth=2)
ax.plot(metric_df['depth'], diff_norm, 's--', label="|Δ| (↑)", linewidth=2)
ax.plot(metric_df['depth'], sd_norm, '^:', label="SD (↑↑)", linewidth=2)
ax.set_xlabel("Depth (%)"); ax.set_ylabel("Normalized [0, 1]")
ax.set_title("C. SD grows faster than mean diff"); ax.legend(fontsize=8)

plt.tight_layout()
fig.savefig(OUTPUT_DIR / "figures" / "fig_concern4_metric_divergence.png", dpi=150, bbox_inches='tight')
fig.savefig(OUTPUT_DIR / "figures" / "fig_concern4_metric_divergence.pdf", bbox_inches='tight')
plt.close()
print("  Saved: fig_concern4_metric_divergence")

# "Early" = layers 5/9/14, "late" = layers 27/32/36; layers 18 and 20 are in
# neither group.
diag['concerns']['metric_divergence'] = {
    'early_mean_d': float(metric_df[metric_df['layer'].isin([5, 9, 14])]['mean_abs_d'].mean()),
    'late_mean_d': float(metric_df[metric_df['layer'].isin([27, 32, 36])]['mean_abs_d'].mean()),
    'early_mean_diff': float(metric_df[metric_df['layer'].isin([5, 9, 14])]['mean_abs_diff'].mean()),
    'late_mean_diff': float(metric_df[metric_df['layer'].isin([27, 32, 36])]['mean_abs_diff'].mean()),
}


# --- Concern 5: L20 IT SAE dip ---
# Compare layer 20's mean |d| and feature count with the values obtained by
# linear interpolation (in depth) between layers 18 and 27.

print("\nConcern 5: L20 (IT SAE) dip")

l18_d = metric_df[metric_df['layer'] == 18]['mean_abs_d'].values[0]
l27_d = metric_df[metric_df['layer'] == 27]['mean_abs_d'].values[0]
l20_d = metric_df[metric_df['layer'] == 20]['mean_abs_d'].values[0]
# Fractional position of layer 20 between layers 18 and 27 on the depth axis.
interp_depth = (LAYER_DEPTH[20] - LAYER_DEPTH[18]) / (LAYER_DEPTH[27] - LAYER_DEPTH[18])
expected_d = l18_d + interp_depth * (l27_d - l18_d)

l20_n = metric_df[metric_df['layer'] == 20]['n_features'].values[0]
expected_n = metric_df[metric_df['layer'] == 18]['n_features'].values[0] + \
    interp_depth * (metric_df[metric_df['layer'] == 27]['n_features'].values[0] -
                    metric_df[metric_df['layer'] == 18]['n_features'].values[0])

print(f"  Observed |d|={l20_d:.3f}, expected={expected_d:.3f} (dev={l20_d - expected_d:+.3f})")
print(f"  Observed N={l20_n:.0f}, expected={expected_n:.0f}")

# Same deviation-from-interpolation, computed separately per demographic.
# NOTE(review): demo_deviations is not referenced again in this section.
demo_deviations = []
for demo in DEMOGRAPHICS:
    d18 = feat_df[(feat_df['layer'] == 18) & (feat_df['demographic'] == demo)]['abs_d'].mean()
    d20 = feat_df[(feat_df['layer'] == 20) & (feat_df['demographic'] == demo)]['abs_d'].mean()
    d27 = feat_df[(feat_df['layer'] == 27) & (feat_df['demographic'] == demo)]['abs_d'].mean()
    demo_deviations.append(d20 - (d18 + interp_depth * (d27 - d18)))

fig, axes = plt.subplots(1, 3, figsize=(14, 5))
fig.suptitle("Concern 5: Layer 20 (IT SAE) Investigation", fontsize=13, fontweight='bold')

# Panel A: mean |d| by depth, with IT_LAYERS highlighted and the interpolated
# layer-20 value drawn as a horizontal reference line.
ax = axes[0]
depths = [LAYER_DEPTH[l] for l in ANALYSIS_LAYERS]
d_vals = [metric_df[metric_df['layer'] == l]['mean_abs_d'].values[0] for l in ANALYSIS_LAYERS]
ax.plot(depths, d_vals, 'o-', color='steelblue', linewidth=2, zorder=2)
for i, l in enumerate(ANALYSIS_LAYERS):
    if l in IT_LAYERS:
        ax.scatter([depths[i]], [d_vals[i]], s=120, c='darkorange', zorder=5,
                   edgecolors='red', linewidths=2, label='IT SAE (L20)')
ax.axhline(expected_d, color='gray', linestyle=':', alpha=0.5, label=f'Interpolated ({expected_d:.2f})')
ax.set_xlabel("Depth (%)"); ax.set_ylabel("Mean |Cohen's d|")
ax.set_title("A. Encoding gradient — L20 below trend"); ax.legend(fontsize=8)

# Panel B: mean |d| by depth per demographic (0 where a layer has no rows).
ax = axes[1]
for demo in DEMOGRAPHICS:
    demo_d = [feat_df[(feat_df['layer'] == l) & (feat_df['demographic'] == demo)]['abs_d'].mean()
              if len(feat_df[(feat_df['layer'] == l) & (feat_df['demographic'] == demo)]) > 0 else 0
              for l in ANALYSIS_LAYERS]
    ax.plot(depths, demo_d, 'o-', label=demo, markersize=4, linewidth=1.5)
ax.axvline(LAYER_DEPTH[20], color='darkorange', linestyle='--', alpha=0.7)
ax.set_xlabel("Depth (%)"); ax.set_ylabel("Mean |Cohen's d|")
ax.set_title("B. Per-demographic gradient"); ax.legend(fontsize=7, ncol=2)

# Panel C: number of feat_df rows per layer, with the interpolated count.
ax = axes[2]
n_feats = [metric_df[metric_df['layer'] == l]['n_features'].values[0] for l in ANALYSIS_LAYERS]
ax.bar([f"L{l}" for l in ANALYSIS_LAYERS], n_feats,
       color=['darkorange' if l in IT_LAYERS else 'steelblue' for l in ANALYSIS_LAYERS])
ax.axhline(expected_n, color='gray', linestyle=':', alpha=0.5)
ax.set_ylabel("N Features"); ax.set_title("C. Feature count by layer")
for i, n in enumerate(n_feats):
    ax.text(i, n + 3, str(int(n)), ha='center', fontsize=8)

plt.tight_layout()
fig.savefig(OUTPUT_DIR / "figures" / "fig_concern5_l20_investigation.png", dpi=150, bbox_inches='tight')
fig.savefig(OUTPUT_DIR / "figures" / "fig_concern5_l20_investigation.pdf", bbox_inches='tight')
plt.close()
print("  Saved: fig_concern5_l20_investigation")

diag['concerns']['l20_it_dip'] = {
    'observed_d': float(l20_d), 'expected_d': float(expected_d),
    'deviation': float(l20_d - expected_d),
    'observed_n': int(l20_n), 'expected_n': float(expected_n),
}


# --- Concern 6: Domain specificity (proxy null baseline) ---

print("\nConcern 6: Domain specificity")

# Load selected_features.json into a private name for this section
# (the same file is also loaded into `selected` during setup).
_sel_path = DATA_DIR / "selected_features.json"
with open(_sel_path) as f:
    _selected = json.load(f)

def _get_feature_set(sel, layer_str, key):
    if not isinstance(sel, dict): return set()
    layer_data = sel.get(layer_str, {})
    if not isinstance(layer_data, dict): return set()
    entry = layer_data.get(key, {})
    if isinstance(entry, dict): return set(entry.get('features', []))
    elif isinstance(entry, list): return set(entry)
    return set()

# For every (layer, demographic): mean pairwise Jaccard between the per-domain
# feature sets, and the share of features present in every non-empty domain.
domain_spec = []
for layer in ANALYSIS_LAYERS:
    layer_key = str(layer)
    if not isinstance(_selected.get(layer_key, {}), dict):
        continue
    for demo in DEMOGRAPHICS:
        demo_feats_by_domain = {
            dom: _get_feature_set(_selected, layer_key, f"{demo}_{dom}") for dom in DOMAINS
        }
        # Only domains that actually have features take part in the comparisons.
        nonempty_sets = [demo_feats_by_domain[dom] for dom in DOMAINS if demo_feats_by_domain[dom]]

        jaccards = []
        for idx, first in enumerate(nonempty_sets):
            for second in nonempty_sets[idx + 1:]:
                union = first | second
                if union:
                    jaccards.append(len(first & second) / len(union))

        # "Generic" = present in all non-empty domains; only defined for 3+ domains.
        pct_generic = 0
        if len(nonempty_sets) >= 3:
            in_every_domain = set.intersection(*nonempty_sets)
            in_any_domain = set.union(*nonempty_sets)
            pct_generic = len(in_every_domain) / max(len(in_any_domain), 1) * 100

        domain_spec.append({
            'layer': layer, 'demographic': demo,
            'mean_jaccard': np.mean(jaccards) if jaccards else 0,
            'pct_generic': pct_generic,
        })

spec_df = pd.DataFrame(domain_spec)

# Per-layer averages over demographics.
for layer in ANALYSIS_LAYERS:
    ldf = spec_df[spec_df['layer'] == layer]
    layer_jaccard = ldf['mean_jaccard'].mean()
    layer_generic = ldf['pct_generic'].mean()
    print(f"  L{layer}: Jaccard={layer_jaccard:.3f}, generic={layer_generic:.0f}%")

# Domain-by-domain Jaccard matrix at layer '36' for each demographic, then the
# element-wise average over demographics.
jaccard_matrices = {}
for demo in DEMOGRAPHICS:
    # Look each domain's feature set up once instead of once per matrix cell.
    l36_sets = [_get_feature_set(_selected, '36', f"{demo}_{dom}") for dom in DOMAINS]
    jmat = np.zeros((len(DOMAINS), len(DOMAINS)))
    for i, row_set in enumerate(l36_sets):
        for j, col_set in enumerate(l36_sets):
            if i == j:
                # Diagonal is fixed at 1 even when the set is empty.
                jmat[i, j] = 1.0
                continue
            union = row_set | col_set
            if union:
                jmat[i, j] = len(row_set & col_set) / len(union)
    jaccard_matrices[demo] = jmat
avg_jmat = np.mean([jaccard_matrices[d] for d in DEMOGRAPHICS], axis=0)

fig, axes = plt.subplots(1, 3, figsize=(15, 5))
fig.suptitle("Concern 6: Domain Specificity — Proxy Null Baseline", fontsize=13, fontweight='bold')
domain_ticks = range(len(DOMAINS))

# Panel A: heatmap of the demographic-averaged cross-domain Jaccard matrix,
# with each cell's value printed on top.
ax = axes[0]
im = ax.imshow(avg_jmat, cmap='YlOrRd', vmin=0, vmax=1, aspect='auto')
ax.set_xticks(domain_ticks)
ax.set_xticklabels(DOMAINS, rotation=45, ha='right')
ax.set_yticks(domain_ticks)
ax.set_yticklabels(DOMAINS)
for r_idx in domain_ticks:
    for c_idx in domain_ticks:
        ax.text(c_idx, r_idx, f"{avg_jmat[r_idx, c_idx]:.2f}", ha='center', va='center', fontsize=9)
plt.colorbar(im, ax=ax, shrink=0.8, label='Jaccard')
ax.set_title("A. Cross-domain overlap (L36)")

# Panel B: % generic features against depth, one curve per demographic.
ax = axes[1]
depth_axis = [LAYER_DEPTH[l] for l in ANALYSIS_LAYERS]
for demo in DEMOGRAPHICS:
    ddf = spec_df[spec_df['demographic'] == demo]
    generic_by_layer = []
    for lyr in ANALYSIS_LAYERS:
        generic_here = ddf[ddf['layer'] == lyr]['pct_generic']
        # Layers without a row for this demographic are plotted as 0.
        generic_by_layer.append(generic_here.values[0] if len(generic_here) > 0 else 0)
    ax.plot(depth_axis, generic_by_layer, 'o-', label=demo, markersize=4)
ax.set_xlabel("Depth (%)")
ax.set_ylabel("% Generic Features")
ax.set_title("B. Domain generality by depth")
ax.legend(fontsize=8)
ax.set_ylim(0, 100)

# Panel C: for each domain, mean Jaccard between the layer-'36' feature sets
# of every pair of demographics.
ax = axes[2]
for dom_idx, dom in enumerate(DOMAINS):
    demo_feats = {demo: _get_feature_set(_selected, '36', f"{demo}_{dom}") for demo in DEMOGRAPHICS}
    jaccards = []
    for idx, first_demo in enumerate(DEMOGRAPHICS):
        for second_demo in DEMOGRAPHICS[idx + 1:]:
            union = demo_feats[first_demo] | demo_feats[second_demo]
            if len(union) > 0:
                shared = demo_feats[first_demo] & demo_feats[second_demo]
                jaccards.append(len(shared) / len(union))
    ax.bar(dom_idx, np.mean(jaccards) if jaccards else 0, color='steelblue')
ax.set_xticks(domain_ticks)
ax.set_xticklabels(DOMAINS, rotation=45, ha='right')
ax.set_ylabel("Mean Jaccard")
ax.set_title("C. Cross-demographic overlap (L36)")

plt.tight_layout()
figures_dir = OUTPUT_DIR / "figures"
fig.savefig(figures_dir / "fig_concern6_domain_specificity.png", dpi=150, bbox_inches='tight')
fig.savefig(figures_dir / "fig_concern6_domain_specificity.pdf", bbox_inches='tight')
plt.close()
print("  Saved: fig_concern6_domain_specificity")

# Summarise layer 36 (averaged over demographics) in the diagnostics record.
l36_spec = spec_df[spec_df['layer'] == 36]
diag['concerns']['domain_specificity'] = {
    'mean_jaccard_l36': float(l36_spec['mean_jaccard'].mean()),
    'mean_pct_generic_l36': float(l36_spec['pct_generic'].mean()),
}

# Save diagnostics
diag_path = OUTPUT_DIR / "data" / "concern_diagnostics.json"
with open(diag_path, 'w') as f:
    json.dump(diag, f, indent=2)

print("\nExtraction-side diagnostics done.")

In [None]:
# Concern diagnostics — causal validation (Concern 7)

import ast
from matplotlib.patches import Patch

print("Concern 7: Causal validation diagnostics")

# Try the two expected output folders first; a folder only counts if it holds
# validation_results.csv. Otherwise fall back to the first recursive hit under BASE_DIR.
_causal_candidates = (
    BASE_DIR / "causal_validation_final",
    BASE_DIR / "outputs_gemma_replication" / "causal_validation_final",
)
CAUSAL_DIR = next(
    (d for d in _causal_candidates if d.exists() and (d / "validation_results.csv").exists()),
    None,
)
if CAUSAL_DIR is None:
    CAUSAL_DIR = next((hit.parent for hit in BASE_DIR.rglob("validation_results.csv")), None)

if CAUSAL_DIR is None:
    print("  Causal data not found — skipping")
else:
    print(f"  Causal data: {CAUSAL_DIR}")

    causal_df = pd.read_csv(CAUSAL_DIR / "validation_results.csv")
    causal_top = pd.read_csv(CAUSAL_DIR / "top_features.csv")

    # The exclusion log is optional; an empty frame stands in when it is absent.
    excl_path = CAUSAL_DIR / "exclusion_log.csv"
    if excl_path.exists():
        causal_excl = pd.read_csv(excl_path)
    else:
        causal_excl = pd.DataFrame()

    has_method = 'feature_method' in causal_df.columns

    def per_pair(df):
        """Keep rows whose feature_method is 'per_pair'.

        When the loaded results have no 'feature_method' column, ``df`` is
        returned unchanged.
        """
        if not has_method:
            return df
        return df[df['feature_method'] == 'per_pair']

    print(f"  Results: {len(causal_df)} rows, top features: {len(causal_top)}")

    # --- 7a: Pair filtering consistency ---

    print("\n  7a: Pair filtering consistency")
    # Number of distinct pair keys among the per-pair rows of each layer.
    pairs_per_layer = {
        l: per_pair(causal_df[causal_df['layer'] == l])['pair_key'].nunique()
        for l in ANALYSIS_LAYERS
    }

    all_same = len(set(pairs_per_layer.values())) == 1
    # Show the single shared count when layers agree, otherwise the full mapping.
    shown = list(pairs_per_layer.values())[0] if all_same else pairs_per_layer
    print(f"    Identical across layers: {all_same} ({shown})")

    if len(causal_excl) > 0:
        cond_cols = ['demographic', 'domain']
        excl_pivot = causal_excl.groupby(cond_cols).size().reset_index(name='excluded')
        total_per_cond = causal_df.groupby(cond_cols)['pair_key'].nunique().reset_index(name='valid')
        # Outer merge keeps conditions that appear on only one side; missing counts become 0.
        merged = excl_pivot.merge(total_per_cond, on=cond_cols, how='outer').fillna(0)
        merged['total'] = merged['excluded'] + merged['valid']
        merged['excl_rate'] = merged['excluded'] / merged['total'] * 100
        high_excl = merged[merged['excl_rate'] > 60]
        if len(high_excl) == 0:
            print("    No conditions with >60% exclusion")
        else:
            print(f"    High exclusion (>60%): {len(high_excl)} conditions")
            for _, row in high_excl.iterrows():
                print(f"      {row['demographic']}×{row['domain']}: {row['excl_rate']:.0f}%")

    diag['concerns']['7a_pair_filtering'] = dict(
        pairs_per_layer=pairs_per_layer,
        identical=all_same,
    )

    # --- 7b: Encoding–causal dissociation ---

    print("\n  7b: Encoding–causal dissociation")

    # Per layer: mean |d| over feat_df rows (encoding) and mean recovery_ev of
    # the K=50 per-pair 'patch_same' rows (causal). Layers without rows are omitted.
    encoding_by_layer = {}
    causal_by_layer = {}
    for l in ANALYSIS_LAYERS:
        layer_feats = feat_df[feat_df['layer'] == l]
        if len(layer_feats) > 0:
            encoding_by_layer[l] = layer_feats['abs_d'].mean()

        vals = per_pair(causal_df[(causal_df['layer'] == l) & (causal_df['K'] == 50)])
        patch_vals = vals[vals['condition'] == 'patch_same']['recovery_ev'].dropna()
        if len(patch_vals) > 0:
            causal_by_layer[l] = patch_vals.mean()

    # Only layers that have both measures are compared.
    shared_layers = sorted(encoding_by_layer.keys() & causal_by_layer.keys())
    enc_vals = [encoding_by_layer[l] for l in shared_layers]
    cau_vals = [causal_by_layer[l] for l in shared_layers]

    # The correlation is only computed with at least four shared layers.
    if len(shared_layers) >= 4:
        r_enc_cau, p_enc_cau = stats.pearsonr(enc_vals, cau_vals)
    else:
        r_enc_cau, p_enc_cau = np.nan, np.nan
    if not np.isnan(r_enc_cau):
        print(f"    Correlation: r={r_enc_cau:.3f}, p={p_enc_cau:.4f}")

    # Early = depth below 40%, late = depth above 60%; NaN when a group is empty.
    early = [l for l in shared_layers if LAYER_DEPTH[l] < 40]
    late = [l for l in shared_layers if LAYER_DEPTH[l] > 60]
    early_enc = early_cau = late_enc = late_cau = np.nan
    if early:
        early_enc = np.mean([encoding_by_layer[l] for l in early])
        early_cau = np.mean([causal_by_layer[l] for l in early])
    if late:
        late_enc = np.mean([encoding_by_layer[l] for l in late])
        late_cau = np.mean([causal_by_layer[l] for l in late])

    ratio_enc = early_enc / max(late_enc, 1e-6)
    # A negative early recovery makes the late/early ratio meaningless; report infinity.
    if early_cau < 0:
        ratio_cau = float('inf')
    else:
        ratio_cau = late_cau / max(early_cau, 1e-6)

    print(f"    Early (<40%): enc |d|={early_enc:.2f}, recovery={early_cau:.1%}")
    print(f"    Late  (>60%): enc |d|={late_enc:.2f}, recovery={late_cau:.1%}")
    print(f"    Encoding ratio (early/late): {ratio_enc:.2f}x")
    cau_ratio_text = '∞' if ratio_cau == float('inf') else f'{ratio_cau:.1f}x'
    print(f"    Causal ratio (late/early): {cau_ratio_text}")

    # Dissociation figure
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    # Panel A: encoding (rescaled so its maximum is 100) and recovery (as %) against depth.
    ax = axes[0]
    # default=0 keeps max() from raising ValueError when no layer has both
    # measures; a non-positive (or NaN) maximum falls back to 1 so the
    # rescaling below never divides by zero.
    enc_max = max(enc_vals, default=0)
    if not enc_max > 0:
        enc_max = 1
    enc_norm = [v / enc_max * 100 for v in enc_vals]
    cau_pct = [v * 100 for v in cau_vals]
    layer_depths = [LAYER_DEPTH[l] for l in shared_layers]
    ax.plot(layer_depths, enc_norm, 'o-', color='#3498db', linewidth=2.5, markersize=8, label='Encoding (norm)')
    ax.plot(layer_depths, cau_pct, 's-', color='#e74c3c', linewidth=2.5, markersize=8, label='Recovery (%)')
    # Mark the depth of every IT layer that is part of the comparison.
    for l in IT_LAYERS:
        if l in shared_layers:
            ax.axvline(x=LAYER_DEPTH[l], color='orange', linestyle='--', alpha=0.5)
    ax.set_xlabel('Depth (%)')
    ax.set_ylabel('Normalized Strength')
    ax.set_title('A. Encoding vs Causal Across Depth', fontweight='bold')
    ax.legend(fontsize=8)

    # Panel B: one labelled point per shared layer; IT layers use an orange diamond.
    ax = axes[1]
    for l in shared_layers:
        c = 'orange' if l in IT_LAYERS else '#2c3e50'
        marker = 'D' if l in IT_LAYERS else 'o'
        ax.scatter(encoding_by_layer[l], causal_by_layer[l] * 100, s=120, c=c, marker=marker,
                   edgecolors='black', linewidth=1, zorder=3)
        ax.annotate(f'L{l}', (encoding_by_layer[l], causal_by_layer[l] * 100),
                    textcoords="offset points", xytext=(6, 4), fontsize=9)
    # The correlation is NaN when it was not computed; omit it from the title then.
    title = f'B. Scatter (r={r_enc_cau:.2f}, p={p_enc_cau:.3f})' if not np.isnan(r_enc_cau) else 'B. Scatter'
    ax.set_title(title, fontweight='bold')
    ax.set_xlabel('Encoding |d|')
    ax.set_ylabel('Recovery (%)')

    # Panel C: early vs late group means; recovery uses a secondary y-axis
    # because it is on a different scale from |d|.
    ax = axes[2]
    x = np.arange(2)
    w = 0.35
    ax.bar(x - w/2, [early_enc, late_enc], w, label='Encoding |d|', color='#3498db', edgecolor='black')
    ax_r = ax.twinx()
    ax_r.bar(x + w/2, [early_cau * 100, late_cau * 100], w, label='Recovery (%)', color='#e74c3c', edgecolor='black')
    ax.set_xticks(x)
    ax.set_xticklabels(['Early\n(<40%)', 'Late\n(>60%)'])
    ax.set_ylabel('Encoding |d|', color='#3498db')
    ax_r.set_ylabel('Recovery (%)', color='#e74c3c')
    ax.set_title('C. Early vs Late', fontweight='bold')
    # Bars live on two different axes, so the legend is built from explicit patches.
    ax.legend(handles=[Patch(facecolor='#3498db', label='Encoding'), Patch(facecolor='#e74c3c', label='Recovery')],
              fontsize=8, loc='upper left')

    plt.suptitle('Concern 7b: Encoding–Causal Dissociation', fontweight='bold')
    plt.tight_layout()
    # Same export options as the Concern 5/6 figures (150 dpi PNG, tight bounding box).
    fig.savefig(OUTPUT_DIR / "figures" / "fig_concern7b_dissociation.png", dpi=150, bbox_inches='tight')
    fig.savefig(OUTPUT_DIR / "figures" / "fig_concern7b_dissociation.pdf", bbox_inches='tight')
    plt.close()
    print("    Saved: fig_concern7b_dissociation")

    def _float_or_none(value):
        """Return ``value`` as a plain float, or None when it is NaN.

        json.dump would otherwise write NaN as the non-standard token ``NaN``.
        """
        return None if np.isnan(value) else float(value)

    # All values go through the same NaN -> None mapping: the early/late means
    # are NaN when their depth group is empty, and the correlation is NaN when
    # it was not computed.
    diag['concerns']['7b_dissociation'] = {
        'encoding_by_layer': {str(k): _float_or_none(v) for k, v in encoding_by_layer.items()},
        'causal_by_layer': {str(k): _float_or_none(v) for k, v in causal_by_layer.items()},
        'correlation_r': _float_or_none(r_enc_cau),
        'correlation_p': _float_or_none(p_enc_cau),
        'early_encoding': _float_or_none(early_enc), 'late_encoding': _float_or_none(late_enc),
        'early_causal': _float_or_none(early_cau), 'late_causal': _float_or_none(late_cau),
    }

    # --- 7c: Causal effect plausibility ---

    print("\n  7c: Effect plausibility")

    K = 50
    at_k = causal_df[causal_df['K'] == K]
    for l in ANALYSIS_LAYERS:
        l_df = at_k[at_k['layer'] == l]
        pp = per_pair(l_df)
        same = pp[pp['condition'] == 'patch_same']['recovery_ev'].dropna()
        # Layers with fewer than five same-pair values are not reported.
        if len(same) < 5:
            continue
        # The random-B control is read from all rows of the layer, not only the per-pair ones.
        rand_b = l_df[l_df['condition'] == 'patch_random_b']['recovery_ev'].dropna()
        same_mean, rand_mean = same.mean(), rand_b.mean()
        tag = " (IT)" if l in IT_LAYERS else ""
        print(f"    L{l}{tag}: same={same_mean:.1%}, rand_b={rand_mean:.1%}, gap={same_mean-rand_mean:.1%}")

    # Share of same-pair recoveries outside the [0, 1] range, over all layers.
    all_patch = per_pair(at_k[at_k['condition'] == 'patch_same'])
    recovery = all_patch['recovery_ev']
    over_100 = (recovery > 1.0).sum()
    under_0 = (recovery < 0.0).sum()
    total = len(all_patch)
    denom = max(total, 1)  # avoid dividing by zero when nothing matched
    over_pct = over_100 / denom * 100
    under_pct = under_0 / denom * 100
    print(f"    Recovery >100%: {over_100}/{total} ({over_pct:.1f}%)")
    print(f"    Recovery <0%:   {under_0}/{total} ({under_pct:.1f}%)")

    diag['concerns']['7c_plausibility'] = {
        'over_100_pct': float(over_pct),
        'under_0_pct': float(under_pct),
    }

    # --- 7d: Feature overlap (extraction vs causal) ---

    print("\n  7d: Feature overlap (extraction vs causal)")

    # Imported locally: the notebook's import cell and this cell only show
    # `defaultdict` coming from collections, so this section should not
    # depend on some other cell having imported Counter.
    from collections import Counter

    def safe_parse(x):
        """Evaluate a stringified Python literal; return non-strings unchanged."""
        return ast.literal_eval(x) if isinstance(x, str) else x

    _sel_path = DATA_DIR / "selected_features.json"
    with open(_sel_path) as f:
        _sel = json.load(f)

    # (layer, condition key) -> first 50 features of the extraction-side selection.
    extraction_feats = {}
    for layer_str, layer_data in _sel.items():
        if not isinstance(layer_data, dict):
            continue
        layer_num = int(layer_str)
        for key_str, feat_data in layer_data.items():
            if isinstance(feat_data, dict) and 'features' in feat_data:
                extraction_feats[(layer_num, key_str)] = set(feat_data['features'][:50])

    causal_top_parsed = causal_top.copy()
    causal_top_parsed['features_list'] = causal_top_parsed['top_5_features'].apply(safe_parse)

    # (layer, condition key) -> how often each feature appears in the causal top-5 lists.
    causal_freq = defaultdict(Counter)
    for _, row in causal_top_parsed.iterrows():
        feats = row['features_list']
        # An empty CSV cell is read as NaN (a float), which cannot be iterated;
        # skip such rows, and empty lists, instead of crashing or creating an
        # empty counter for the key.
        if not isinstance(feats, (list, tuple, set)) or not feats:
            continue
        key = (row['layer'], f"{row['demographic']}_{row['domain']}")
        causal_freq[key].update(feats)

    # Jaccard between the 20 most frequent causal features and the extraction set.
    overlap_stats = []
    for key, freq in causal_freq.items():
        if key not in extraction_feats:
            continue
        top_causal = {feat for feat, _count in freq.most_common(20)}
        selected = extraction_feats[key]
        intersection = top_causal & selected
        jaccard = len(intersection) / max(len(top_causal | selected), 1)
        overlap_stats.append({'layer': key[0], 'condition': key[1],
                              'n_intersection': len(intersection), 'jaccard': jaccard})

    # overlap_df is only defined when at least one condition could be compared.
    if overlap_stats:
        overlap_df = pd.DataFrame(overlap_stats)
        print(f"    Mean Jaccard (causal top-20 vs extraction top-50): {overlap_df['jaccard'].mean():.3f}")
        by_layer = {}
        for l in ANALYSIS_LAYERS:
            ov = overlap_df[overlap_df['layer'] == l]
            if len(ov) > 0:
                layer_jaccard = ov['jaccard'].mean()
                print(f"      L{l}: Jaccard={layer_jaccard:.3f}")
                by_layer[str(l)] = float(layer_jaccard)

        diag['concerns']['7d_overlap'] = {
            'mean_jaccard': float(overlap_df['jaccard'].mean()),
            'by_layer': by_layer,
        }

    # --- 7e: Dose-response monotonicity ---

    print("\n  7e: Dose-response")

    k_values = [5, 10, 20, 50]
    monotonic_count, total_layers = 0, 0
    dose_response = {}

    for l in ANALYSIS_LAYERS:
        layer_rows = causal_df[causal_df['layer'] == l]
        # Mean same-pair recovery at each K; NaN where no value exists.
        means = []
        for k in k_values:
            vals = per_pair(layer_rows[layer_rows['K'] == k])
            patch = vals[vals['condition'] == 'patch_same']['recovery_ev'].dropna()
            means.append(patch.mean() if len(patch) > 0 else np.nan)
        dose_response[l] = means

        # A layer is judged only when at least three K values have data;
        # it is monotonic when recovery never decreases as K grows.
        valid = [m for m in means if not np.isnan(m)]
        if len(valid) < 3:
            continue
        total_layers += 1
        if all(lower <= higher for lower, higher in zip(valid, valid[1:])):
            monotonic_count += 1

    print(f"    Monotonic: {monotonic_count}/{total_layers} layers")

    diag['concerns']['7e_dose_response'] = dict(monotonic=monotonic_count, total=total_layers)

    # --- Combined causal figure ---

    fig, axes = plt.subplots(2, 3, figsize=(16, 10))

    # A: Recovery by layer
    # Mean K=50 recovery (as %) for per-pair 'patch_same' rows next to the
    # 'patch_random_a' rows of the same layer; a layer with no values is drawn as 0.
    ax = axes[0, 0]
    k50_rows = causal_df[causal_df['K'] == 50]
    same_means, rand_means = [], []
    for l in ANALYSIS_LAYERS:
        layer_k50 = k50_rows[k50_rows['layer'] == l]
        pp = per_pair(layer_k50)
        same = pp[pp['condition'] == 'patch_same']['recovery_ev'].dropna()
        rand = layer_k50[layer_k50['condition'] == 'patch_random_a']['recovery_ev'].dropna()
        same_means.append(same.mean() * 100 if len(same) > 0 else 0)
        rand_means.append(rand.mean() * 100 if len(rand) > 0 else 0)
    x = np.arange(len(ANALYSIS_LAYERS))
    w = 0.35
    same_colors = ['orange' if l in IT_LAYERS else '#2ecc71' for l in ANALYSIS_LAYERS]
    tick_labels = [f'L{l}\n{LAYER_DEPTH[l]}%' for l in ANALYSIS_LAYERS]
    ax.bar(x - w/2, same_means, w, color=same_colors, edgecolor='black', label='Same-pair')
    ax.bar(x + w/2, rand_means, w, color='#95a5a6', edgecolor='black', label='Random-A')
    ax.set_xticks(x)
    ax.set_xticklabels(tick_labels, fontsize=7)
    ax.set_ylabel('Recovery (%)')
    ax.set_title('A. Same vs Random by Layer')
    ax.legend(fontsize=7)

    # B: Dose-response
    # Recovery (as %) against K for layers 5 and 36, when they were analysed.
    ax = axes[0, 1]
    for l_plot in (5, 36):
        if l_plot not in dose_response:
            continue
        recovery_pct = [v*100 for v in dose_response[l_plot]]
        ax.plot(k_values, recovery_pct, 'o-',
                label=f'L{l_plot} ({LAYER_DEPTH[l_plot]}%)', linewidth=2, markersize=7)
    ax.set_xlabel('K')
    ax.set_ylabel('Recovery (%)')
    ax.set_title('B. Dose-Response')
    ax.legend(fontsize=8)
    ax.set_xticks(k_values)

    # C: Recovery distribution at L36
    # Histogram of per-pair 'patch_same' recovery (as %) at layer 36, K=50,
    # with reference lines at the mean, 0 and 100.
    ax = axes[0, 2]
    l36_same_mask = (
        (causal_df['layer'] == 36)
        & (causal_df['K'] == 50)
        & (causal_df['condition'] == 'patch_same')
    )
    best_patch = per_pair(causal_df[l36_same_mask])
    if len(best_patch) > 0:
        rv = best_patch['recovery_ev'].dropna()
        rv_mean = rv.mean()
        ax.hist(rv * 100, bins=30, color='#2ecc71', edgecolor='black', alpha=0.7)
        ax.axvline(x=rv_mean * 100, color='red', linestyle='--', linewidth=2, label=f'Mean={rv_mean:.1%}')
        ax.axvline(x=0, color='black', linewidth=1)
        ax.axvline(x=100, color='gray', linestyle=':', linewidth=1)
        ax.set_xlabel('Recovery (%)')
        ax.set_ylabel('Count')
        ax.set_title(f'C. Recovery Distribution (L36, n={len(rv)})')
        ax.legend(fontsize=7)

    # D: Exclusion rate
    # Per demographic: excluded rows / (excluded rows + distinct valid K=50
    # per-pair keys), as a percentage.
    ax = axes[1, 0]
    if len(causal_excl) > 0:
        excl_by_demo = causal_excl.groupby('demographic').size()
        valid_by_demo = per_pair(causal_df[causal_df['K'] == 50]).groupby('demographic')['pair_key'].nunique()
        total_by_demo = excl_by_demo.add(valid_by_demo, fill_value=0)
        # Demographics that only appear among the valid pairs have no entry in
        # excl_by_demo; dividing the unaligned series would give them NaN.
        # Align to the full index with 0 exclusions so their rate is 0%.
        excl_aligned = excl_by_demo.reindex(total_by_demo.index, fill_value=0)
        rate_by_demo = excl_aligned / total_by_demo * 100
        ax.barh(rate_by_demo.index, rate_by_demo.values, color='#e74c3c', edgecolor='black')
        ax.set_xlabel('Exclusion Rate (%)')
        ax.set_xlim(0, 100)
        for i, (demo, rate) in enumerate(rate_by_demo.items()):
            ax.annotate(f'{rate:.0f}%', (rate, i), va='center', fontsize=9)
    else:
        ax.text(0.5, 0.5, 'No exclusion data', ha='center', va='center', transform=ax.transAxes)
    ax.set_title('D. Exclusion Rate by Demo')

    # E: Feature overlap
    # Mean extraction-vs-causal Jaccard per layer; a layer without overlap rows is drawn as 0.
    ax = axes[1, 1]
    if not overlap_stats:
        ax.text(0.5, 0.5, 'No overlap data', ha='center', va='center', transform=ax.transAxes)
        ax.set_title('E. Feature Overlap')
    else:
        layer_jaccards = []
        for l in ANALYSIS_LAYERS:
            layer_overlap = overlap_df[overlap_df['layer'] == l]
            layer_jaccards.append(layer_overlap['jaccard'].mean() if len(layer_overlap) > 0 else 0)
        positions = range(len(ANALYSIS_LAYERS))
        bar_colors = ['orange' if l in IT_LAYERS else '#8e44ad' for l in ANALYSIS_LAYERS]
        ax.bar(positions, layer_jaccards, color=bar_colors, edgecolor='black')
        ax.set_xticks(positions)
        ax.set_xticklabels([f'L{l}' for l in ANALYSIS_LAYERS], fontsize=8)
        ax.set_ylabel('Jaccard')
        ax.set_title('E. Extraction↔Causal Overlap')
        for pos, value in enumerate(layer_jaccards):
            ax.annotate(f'{value:.2f}', (pos, value), ha='center', va='bottom', fontsize=7)

    # F: Control hierarchy at L36
    # Mean ± SEM recovery (as %) per patching condition at layer 36, K=50.
    ax = axes[1, 2]
    l36_df = causal_df[(causal_df['layer'] == 36) & (causal_df['K'] == 50)]
    l36_pp = per_pair(l36_df)
    control_conditions = (
        ('patch_same', 'Same', '#2ecc71'),
        ('patch_random_a', 'Rand-A', '#7f8c8d'),
        ('patch_cross', 'Cross', '#f39c12'),
        ('patch_random_b', 'Rand-B', '#bdc3c7'),
    )
    hierarchy_data = []
    for cond, label, color in control_conditions:
        # Only the same-pair condition is restricted to per-pair rows.
        src = l36_pp if cond == 'patch_same' else l36_df
        vals = src[src['condition'] == cond]['recovery_ev'].dropna()
        if len(vals) > 0:
            hierarchy_data.append((label, vals.mean(), vals.sem(), color))

    if hierarchy_data:
        labels, means, sems, colors = zip(*hierarchy_data)
        bar_positions = range(len(means))
        ax.bar(bar_positions, [m*100 for m in means], yerr=[s*100 for s in sems],
               capsize=3, color=colors, edgecolor='black')
        ax.set_xticks(range(len(labels)))
        ax.set_xticklabels(labels, fontsize=9)
        ax.set_ylabel('Recovery (%)')
        ax.axhline(y=0, color='black', linewidth=0.5)
        for pos, mean_val in enumerate(means):
            ax.annotate(f'{mean_val:.1%}', (pos, mean_val*100), ha='center', va='bottom',
                        fontsize=8, fontweight='bold')
    ax.set_title('F. Control Hierarchy (L36)')

    plt.suptitle('Concern 7: Causal Validation Diagnostics', fontsize=14, fontweight='bold')
    plt.tight_layout()
    # Same export options as the Concern 5/6 figures (150 dpi PNG, tight
    # bounding box) so all diagnostic figures are written consistently.
    fig.savefig(OUTPUT_DIR / "figures" / "fig_concern7_causal_diagnostics.png", dpi=150, bbox_inches='tight')
    fig.savefig(OUTPUT_DIR / "figures" / "fig_concern7_causal_diagnostics.pdf", bbox_inches='tight')
    plt.close()
    print("    Saved: fig_concern7_causal_diagnostics")

    # Save updated diagnostics
    diag_path = OUTPUT_DIR / "data" / "concern_diagnostics.json"
    with open(diag_path, 'w') as f:
        json.dump(diag, f, indent=2)

    # One-line recap of the dissociation ratios and the dose-response check.
    if ratio_cau == float('inf'):
        cau_str = "∞"
    else:
        cau_str = f"{ratio_cau:.1f}x"
    print(f"\n  Summary: dissociation enc {ratio_enc:.1f}x early / cau {cau_str} late, "
          f"dose-response {monotonic_count}/{total_layers} monotonic")

# Printed whether or not causal data was found (this line sits outside the if/else above).
print("\nDiagnostics done.")