<a href="https://colab.research.google.com/github/Kristina-26/LLM-interpretability/blob/main/Pop_ConflictQA_hidden_states_experiment.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# !pip install -q "transformers>=4.44.2" "accelerate>=0.33.0" "bitsandbytes>=0.44.0"
# !pip install -q scipy matplotlib seaborn scikit-learn tqdm

In [None]:
import os
import re
import json
import random
import logging
from pathlib import Path
from dataclasses import dataclass, field
from typing import List, Dict, Any, Tuple, Optional
from collections import Counter
from datetime import datetime

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.auto import tqdm

import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats as spstats
from scipy.special import erf
from scipy.stats import bootstrap

from transformers import AutoTokenizer, AutoModelForCausalLM
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split, StratifiedKFold, cross_val_score
from sklearn.pipeline import Pipeline
from sklearn.metrics import roc_auc_score, confusion_matrix, roc_curve

In [None]:
# Notebook-wide plotting and logging configuration.
sns.set_style("whitegrid")
plt.rcParams.update({'figure.dpi': 100})

# INFO-level logging with a short wall-clock timestamp on every record.
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s',
                    datefmt='%H:%M:%S')
logger = logging.getLogger(__name__)


In [None]:
# Hugging Face model identifier used for the whole run. It is stored in
# CONFIG['model']['name'] and handed to AutoTokenizer / AutoModelForCausalLM
# by ModelWrapper.
# Options listed by the author:
#   "meta-llama/Llama-3.1-8B"
#   "meta-llama/Llama-3.1-8B-Instruct"
#   "Qwen/Qwen3-14B"
#   "Qwen/Qwen2.5-7B-Instruct"
#   "google/gemma-3-12b-it"
#   "google/gemma-3-4b-it"

MODEL_NAME = "meta-llama/Llama-3.1-8B"

In [None]:
# Central experiment configuration. Only the 'model', 'processing' and 'seed'
# entries are consumed by code visible in this part of the notebook; the other
# sections are presumably read by later cells (TODO confirm against them).
CONFIG = {
    'model': {
        'name': MODEL_NAME,       # passed to from_pretrained() for tokenizer and model
        'use_4bit': False,        # ModelWrapper only honours this when CUDA is available
        'max_length': 256,        # stored as ModelWrapper.max_length
        'dtype': 'bfloat16',      # one of 'bfloat16' / 'float16' / 'float32' (see ModelWrapper dtype_map)
    },
    'data': {
        'n_items_total': 6948, # 6948 in total
        'n_baseline_items': 600,  # From TRAIN only
        'test_size': 0.2,
        'random_state': 42,
    },
    'processing': {
        'batch_size': 8,
        'layer_slice_start': 1,   # ModelWrapper builds slice(layer_slice_start, None) from this
    },
    'analysis': {
        'last_layers_band': 4,
        'top_k_activations': 20,
        'threshold_percentile': 95,
    },
    'contrastive': {
        'enabled': True,
        'epochs': 40,
        'batch_size': 256,
        'learning_rate': 0.001,
        'weight_decay': 0.0001,
        'temperature': 0.07,
        'projection_dim': 64,
        'hidden_dim': 256,
        'last_k_layers': 16,
    },
    'bootstrap': {
        'n_resamples': 1000,
        'confidence_level': 0.95,
    },
    'seed': 42,                   # seeds the RNGs below and compute_confidence_interval's bootstrap RNG
}

# Seed Python's, NumPy's and PyTorch's global RNGs (plus every CUDA device when
# one is present) from the single configured seed.
random.seed(CONFIG['seed'])
np.random.seed(CONFIG['seed'])
torch.manual_seed(CONFIG['seed'])
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(CONFIG['seed'])


In [None]:
# utils
def normalize_text(s: str) -> str:
    """Lower-case *s* and reduce it to single-space-separated ``a-z0-9`` runs.

    Every run of characters outside ``a-z0-9`` (after lower-casing) becomes one
    space and the result is trimmed. Falsy input yields an empty string.
    """
    if not s:
        return ""
    spaced = re.sub(r'[^a-z0-9]+', ' ', s.lower())
    return re.sub(r'\s+', ' ', spaced).strip()

def postprocess_short_phrase(text: str, max_words: int = 6) -> str:
    """Extract a short answer phrase from raw model output.

    Processing steps:
      1. keep only the text before the first line break or ``.``/``?``/``!``;
      2. trim surrounding quotes and punctuation, including curly quotes;
      3. drop a leading "the answer is" / "it is" / "it's" / "this is" /
         "answer:" lead-in (case-insensitive);
      4. trim again and keep at most ``max_words`` whitespace-separated words.

    Args:
        text: Raw generated text; falsy input yields ``""``.
        max_words: Maximum number of words kept in the result.

    Returns:
        The cleaned phrase, or ``""`` when nothing is left.
    """
    if not text:
        return ""

    # Characters trimmed from both ends. The curly quotes are written as
    # escapes so they cannot be silently flattened to straight quotes again.
    wrappers = " \"'\u201c\u201d\u2018\u2019:,;-"

    # First line / first sentence only.
    t = re.split(r'[\n\r]|[.?!]', text.strip(), maxsplit=1)[0]

    # Trim before matching the lead-in so a quoted answer such as
    # '"The answer is Paris' is still recognised.
    t = t.strip(wrappers)
    t = re.sub(r'^(the answer is|it is|it\'s|this is|answer:)\s+', '', t, flags=re.I)
    t = t.strip(wrappers)

    tokens = t.split()
    return " ".join(tokens[:max_words]) if tokens else ""

def gt_overlap(answer: str, gt_list: List[str]) -> bool:
    """Check whether *answer* lexically overlaps any ground-truth string.

    Both sides are normalized with ``normalize_text``. For a ground truth that
    contains at least one word of three or more characters, a single shared
    word of that length counts as a match (shorter words are ignored so that
    "of", "in", ... do not cause matches). A ground truth made up only of
    short tokens (e.g. "42", "UK") would otherwise be unmatchable, so for
    those every ground-truth token must appear in the answer.

    Args:
        answer: Text to test (a model answer or a context passage).
        gt_list: Accepted ground-truth strings.

    Returns:
        True if any ground truth matches, False otherwise (including when the
        answer has no tokens at all).
    """
    answer_tokens = set(normalize_text(answer).split())
    if not answer_tokens:
        return False
    answer_words = {w for w in answer_tokens if len(w) >= 3}

    for gt in gt_list:
        gt_tokens = set(normalize_text(gt).split())
        gt_words = {w for w in gt_tokens if len(w) >= 3}
        if gt_words:
            # Any shared "long" word is enough (this also covers the case
            # where every ground-truth word is present).
            if not gt_words.isdisjoint(answer_words):
                return True
        elif gt_tokens and gt_tokens <= answer_tokens:
            # Short-token-only ground truth: require all of its tokens.
            return True
    return False

def short_context_sentence(s: str, max_chars: int = 200) -> str:
    """Return the leading sentence of *s*, clipped to ``max_chars`` characters.

    The sentence ends at the first ``.``, ``!`` or ``?`` that is followed by
    whitespace, except when the period directly follows a single capital
    letter (an initial such as "J."). A sentence longer than ``max_chars`` is
    cut there and then shortened back to the last space inside the cut.
    Falsy input yields an empty string.
    """
    if not s:
        return ""
    text = s.strip()
    boundary = re.search(r'(?<!\b[A-Z]\.)(?<=[.!?])\s+', text)
    sentence = text if boundary is None else text[:boundary.start()]
    if len(sentence) <= max_chars:
        return sentence
    return sentence[:max_chars].rsplit(' ', 1)[0]

def classify_domain(question: str) -> str:
    """Assign a coarse topical domain to *question* using keyword rules.

    Keywords are matched against whole words of the lower-cased question (a
    trailing plural "s" is also accepted), so "award" no longer counts as
    "war" and "making" no longer counts as "king". Arithmetic symbols only
    count as "math" when they sit between two digits, so hyphenated names are
    not classified as math. Domains are tested in the order listed below and
    the first hit wins.

    Args:
        question: The question text.

    Returns:
        One of "geography", "history", "science", "art", "math" or "general".
    """
    q_lower = question.lower()
    domain_keywords = {
        "geography": ["capital", "country", "city", "continent", "ocean"],
        "history": ["year", "century", "war", "president", "king", "queen"],
        "science": ["chemical", "planet", "element", "species", "theory"],
        "art": ["wrote", "author", "book", "novel", "poem", "play"],
        "math": ["plus", "minus", "times", "divided"],
    }

    words = set(re.findall(r"[a-z]+", q_lower))
    for domain, keywords in domain_keywords.items():
        if any(kw in words or f"{kw}s" in words for kw in keywords):
            return domain

    # Symbolic arithmetic such as "3 + 4", "7-2" or "6 × 5". Checked last,
    # matching the position of "math" in the keyword table.
    if re.search(r"\d\s*[+\-×]\s*\d", q_lower):
        return "math"
    return "general"

def compute_confidence_interval(data: np.ndarray,
                               statistic_fn,
                               confidence_level: float = 0.95,
                               n_resamples: int = 1000,
                               method: str = 'percentile',
                               seed: Optional[int] = None) -> Tuple[float, float, float]:
    """Compute a statistic and its bootstrap confidence interval.

    Args:
        data: One-dimensional sample.
        statistic_fn: Callable mapping a 1-D sample to a scalar (for example
            ``np.mean``). It is called on the full sample for the point
            estimate and on each bootstrap resample for the interval.
        confidence_level: Confidence level of the interval.
        n_resamples: Number of bootstrap resamples.
        method: ``scipy.stats.bootstrap`` method name.
        seed: Seed for the bootstrap RNG; ``None`` uses ``CONFIG['seed']``.

    Returns:
        ``(point_estimate, ci_low, ci_high)``. All three are NaN for empty
        input; the two bounds are NaN when the bootstrap itself fails (the
        failure is logged as a warning).
    """
    data = np.asarray(data)
    if len(data) == 0:
        return np.nan, np.nan, np.nan

    point_estimate = statistic_fn(data)

    try:
        rng = np.random.default_rng(CONFIG['seed'] if seed is None else seed)
        # scipy passes each sample in `data` to the statistic as a separate
        # positional argument, so the callable receives the resampled array
        # itself. Indexing it with [0] would reduce every resample to its
        # first element. vectorized=False makes scipy call the statistic once
        # per 1-D resample, so it needs no `axis` parameter.
        result = bootstrap(
            (data,),
            statistic_fn,
            n_resamples=n_resamples,
            confidence_level=confidence_level,
            random_state=rng,
            method=method,
            vectorized=False,
        )
        return point_estimate, result.confidence_interval.low, result.confidence_interval.high
    except Exception as e:
        # Best effort: keep the point estimate even if the CI cannot be formed.
        logger.warning(f"Bootstrap failed: {e}")
        return point_estimate, np.nan, np.nan


def bootstrap_auroc_ci(
    y_true: np.ndarray,
    scores: np.ndarray,
    n_resamples: int = 1000,
    confidence_level: float = 0.95,
    seed: int = 42,
) -> Tuple[float, float]:
    """Compute a percentile-bootstrap confidence interval for AUROC.

    Samples are drawn with replacement ``n_resamples`` times. Resamples that
    contain a single class, or on which ``roc_auc_score`` fails, are skipped
    and counted; the interval is taken from the remaining AUROC values.

    Args:
        y_true: Binary labels.
        scores: Scores aligned with ``y_true``.
        n_resamples: Number of bootstrap resamples to draw.
        confidence_level: Confidence level of the interval.
        seed: Seed of the resampling RNG.

    Returns:
        ``(ci_low, ci_high)``.

    Raises:
        ValueError: If the input is empty or the two arrays differ in length.
        RuntimeError: If no resample produced a valid AUROC.
    """
    y_true = np.asarray(y_true)
    scores = np.asarray(scores)
    n = len(y_true)
    if n == 0:
        raise ValueError("Cannot bootstrap an AUROC CI from an empty sample.")
    if len(scores) != n:
        raise ValueError(
            f"y_true and scores must have the same length (got {n} and {len(scores)})."
        )
    rng = np.random.default_rng(seed)

    aurocs = []
    n_single_class = 0
    n_failed = 0

    for _ in range(n_resamples):
        idx = rng.integers(0, n, size=n)
        y_boot = y_true[idx]
        s_boot = scores[idx]

        # skip degenerate resamples with only one class
        if len(np.unique(y_boot)) < 2:
            n_single_class += 1
            continue

        try:
            aurocs.append(roc_auc_score(y_boot, s_boot))
        except Exception as exc:
            # Still best effort, but no longer silent: report the first
            # failure so systematic problems (e.g. NaN scores) are visible.
            if n_failed == 0:
                logger.warning(f"AUROC failed on a bootstrap resample: {exc}")
            n_failed += 1

    if len(aurocs) == 0:
        # no valid resamples, cannot form a CI
        raise RuntimeError(
            f"No valid bootstrap resamples ({n_single_class} single-class, "
            f"{n_failed} scoring failures), cannot compute AUROC CI."
        )

    aurocs = np.array(aurocs, dtype=float)
    alpha = 1.0 - confidence_level
    ci_low = np.percentile(aurocs, 100 * alpha / 2.0)
    ci_high = np.percentile(aurocs, 100 * (1.0 - alpha / 2.0))

    logger.info(
        f"Bootstrap AUROC CI: used {len(aurocs)} valid resamples out of {n_resamples} "
        f"(skipped {n_resamples - len(aurocs)})"
    )

    return ci_low, ci_high


def analyze_misclassifications(test_items, y_test, predictions, scores):
    """Summarize false positives and false negatives of a binary detector.

    Args:
        test_items: Sequence of objects exposing a ``domain`` attribute,
            aligned with the other arguments.
        y_test: True labels (1 = positive, 0 = negative); array or list.
        predictions: Binary predictions; array or list.
        scores: Continuous detector scores; array or list.

    Returns:
        Dict with ``"false_positives"`` and ``"false_negatives"`` entries,
        each holding the error ``count``, per-domain counts (``domains``), the
        mean and standard deviation of the scores (``None`` when there are no
        such errors) and up to five ``example_indices``.
    """
    # Coerce to arrays: with plain lists `y_test == 0` evaluates to a single
    # False and every error count would silently come out as zero.
    y_test = np.asarray(y_test)
    predictions = np.asarray(predictions)
    scores = np.asarray(scores)

    def _summarize(idx: np.ndarray) -> Dict[str, Any]:
        """Build the summary entry for the samples at positions *idx*."""
        group_scores = scores[idx]
        has_errors = group_scores.size > 0
        return {
            "count": int(len(idx)),
            "domains": dict(Counter(test_items[i].domain for i in idx)),
            "score_mean": float(group_scores.mean()) if has_errors else None,
            "score_std": float(group_scores.std()) if has_errors else None,
            "example_indices": idx[:5].tolist(),
        }

    # indices of false positives and false negatives
    fp_idx = np.where((y_test == 0) & (predictions == 1))[0]
    fn_idx = np.where((y_test == 1) & (predictions == 0))[0]

    return {
        "false_positives": _summarize(fp_idx),
        "false_negatives": _summarize(fn_idx),
    }

def summarize_layer_importance(auroc_by_layer: Dict[str, np.ndarray]) -> List[str]:
    """Describe, per feature, which layer has the highest AUROC.

    Features whose AUROC values are all NaN are left out. The best layer is
    reported 1-based and placed in the first 30% ("early"), the next 40%
    ("middle") or the remainder ("late") of the layer range.
    """

    def _region_of(layer_idx: int, n_layers: int) -> str:
        # Thresholds compare the 0-based index against fractions of the depth.
        if layer_idx < 0.3 * n_layers:
            return "early layers"
        if layer_idx < 0.7 * n_layers:
            return "middle layers"
        return "late layers"

    lines = []
    for feat_name, aurocs in auroc_by_layer.items():
        if not np.all(np.isnan(aurocs)):
            best_layer = int(np.nanargmax(aurocs))
            best_score = float(aurocs[best_layer])
            region = _region_of(best_layer, len(aurocs))
            lines.append(
                f"{feat_name}: strongest signal in {region} "
                f"(layer {best_layer+1}, AUROC={best_score:.3f})"
            )
    return lines

In [None]:
def load_results_and_items(results_path: str = "results.json"):
    """Load a saved results JSON file and return its parsed content.

    Only the parsed JSON is returned; no test items are reconstructed here.

    Args:
        results_path: Path of the JSON file to read.

    Returns:
        The decoded JSON document.
    """
    # JSON is UTF-8 by specification; do not depend on the platform default.
    with open(results_path, 'r', encoding='utf-8') as f:
        return json.load(f)

def _error_example_record(rank: int, score: float, item, res) -> Dict[str, Any]:
    """Build the record stored in the analysis for one misclassified sample."""
    return {
        "rank": rank,
        "score": float(score),
        "question": item.question,
        "domain": item.domain,
        "ground_truth": item.gt_list,
        "answer_base": res.answer_base,
        "answer_correct": res.answer_correct,
        "answer_incorrect": res.answer_incorrect,
        "correct_ctx": item.correct_ctx,
        "incorrect_ctx": item.incorrect_ctx,
        "base_correct": res.base_correct,
        "correct_correct": res.correct_correct,
        "incorrect_correct": res.incorrect_correct,
    }


def _print_error_example(rank: int, score: float, item, res, verdict: str) -> None:
    """Print the human-readable report for one misclassified sample."""
    print(f"\n#{rank} | Score: {score:.4f} | Domain: {item.domain}")
    print(f"  Question: {item.question}")
    print(f"  Ground truth: {item.gt_list}")
    print(f"  Base answer: '{res.answer_base}' (correct: {res.base_correct})")
    print(f"  Correct ctx answer: '{res.answer_correct}' (correct: {res.correct_correct})")
    print(f"  Incorrect ctx answer: '{res.answer_incorrect}' (correct: {res.incorrect_correct})")
    print(f"  Correct context: {item.correct_ctx[:100]}...")
    print(f"  Incorrect context: {item.incorrect_ctx[:100]}...")
    print(verdict)


def _report_group_stats(group_scores: np.ndarray, domains: Counter, patterns: Dict[str, Any]) -> None:
    """Print score and domain statistics of one error group and store them in *patterns*."""
    print(f"\nScore statistics:")
    print(f"  Mean: {group_scores.mean():.4f}")
    print(f"  Std:  {group_scores.std():.4f}")
    print(f"  Min:  {group_scores.min():.4f}")
    print(f"  Max:  {group_scores.max():.4f}")

    print(f"\nDomain distribution:")
    n_errors = len(group_scores)
    for domain, count in domains.most_common():
        pct = 100 * count / n_errors
        print(f"  {domain:15s}: {count:3d} ({pct:5.1f}%)")

    patterns["domains"] = dict(domains)
    patterns["score_mean"] = float(group_scores.mean())
    patterns["score_std"] = float(group_scores.std())


def _report_word_lengths(label: str, texts: List[str]) -> float:
    """Print mean/median word counts of *texts* under *label* and return the mean."""
    lengths = [len(text.split()) for text in texts]
    print(f"\n{label}:")
    print(f"  Mean: {np.mean(lengths):.1f} words")
    print(f"  Median: {np.median(lengths):.1f} words")
    return float(np.mean(lengths))


def detailed_error_analysis(
    test_items: List,
    test_results: List,
    y_test: np.ndarray,
    predictions: np.ndarray,
    scores: np.ndarray,
    method_name: str = "Method"
) -> Dict[str, Any]:
    """
    Perform detailed error analysis on misclassifications.

    Prints a report (score statistics, domain distribution, answer behaviour,
    question/context lengths, the ten most extreme examples per error type and
    a per-domain FP/FN comparison) and returns the same information as a dict.

    Args:
        test_items: List of Item objects
        test_results: List of EvaluationResult objects
        y_test: True labels (1=flip, 0=non-flip); array or list
        predictions: Binary predictions; array or list
        scores: Continuous scores from the model; array or list
        method_name: Name of the detection method

    Returns:
        Dictionary with keys "method", "total_samples", "n_flips",
        "n_non_flips", "false_positives" and "false_negatives". Each error
        entry holds "count", "rate", up to five "examples" and a "patterns"
        dict that is only filled when that error type occurred.
    """
    # Accept plain lists as well: the comparisons and reductions below need
    # element-wise array semantics.
    y_test = np.asarray(y_test)
    predictions = np.asarray(predictions)
    scores = np.asarray(scores)

    # Identify errors
    fp_idx = np.where((y_test == 0) & (predictions == 1))[0]
    fn_idx = np.where((y_test == 1) & (predictions == 0))[0]
    n_flips = int(y_test.sum())
    n_non_flips = int((y_test == 0).sum())

    print("\n" + "="*70)
    print(f"DETAILED ERROR ANALYSIS: {method_name}")
    print("="*70)
    print(f"Total test samples: {len(y_test)}")
    print(f"  True flips (y=1): {n_flips}")
    print(f"  True non-flips (y=0): {n_non_flips}")
    print(f"False Positives: {len(fp_idx)} (predicted flip, actually non-flip)")
    print(f"False Negatives: {len(fn_idx)} (predicted non-flip, actually flip)")

    analysis = {
        "method": method_name,
        "total_samples": int(len(y_test)),
        "n_flips": n_flips,
        "n_non_flips": n_non_flips,
        "false_positives": {
            "count": int(len(fp_idx)),
            "rate": float(len(fp_idx) / n_non_flips) if n_non_flips > 0 else 0.0,
            "examples": [],
            "patterns": {}
        },
        "false_negatives": {
            "count": int(len(fn_idx)),
            "rate": float(len(fn_idx) / n_flips) if n_flips > 0 else 0.0,
            "examples": [],
            "patterns": {}
        }
    }

    # Per-domain error counts, reused by the comparison section below.
    fp_domains = Counter(test_items[i].domain for i in fp_idx)
    fn_domains = Counter(test_items[i].domain for i in fn_idx)

    # =================================================================
    # FALSE POSITIVES (model says flip, but it's actually non-flip)
    # =================================================================
    if len(fp_idx) > 0:
        fp_patterns = analysis["false_positives"]["patterns"]
        print("\n" + "-"*70)
        print("FALSE POSITIVES (Model incorrectly predicted flip)")
        print("-"*70)

        fp_scores = scores[fp_idx]
        _report_group_stats(fp_scores, fp_domains, fp_patterns)

        # How often the incorrect context changed the answer text even though
        # the sample is labelled non-flip.
        n_changed = sum(
            test_results[i].answer_base != test_results[i].answer_incorrect
            for i in fp_idx
        )
        print(f"\nAnswer behavior:")
        print(f"  Incorrect context changed answer: {n_changed}/{len(fp_idx)} "
              f"({100*n_changed/len(fp_idx):.1f}%)")
        fp_patterns["incorrect_changed_answer"] = int(n_changed)

        fp_patterns["question_length_mean"] = _report_word_lengths(
            "Question length", [test_items[i].question for i in fp_idx])
        fp_patterns["context_length_mean"] = _report_word_lengths(
            "Incorrect context length", [test_items[i].incorrect_ctx for i in fp_idx])

        # Show top 10 examples with highest scores (most confident FPs)
        print(f"\n{'='*70}")
        print("Top 10 FALSE POSITIVES (most confident errors):")
        print("="*70)

        sorted_fp_idx = fp_idx[np.argsort(fp_scores)[::-1]][:10]
        for rank, idx in enumerate(sorted_fp_idx, 1):
            item = test_items[idx]
            res = test_results[idx]
            verdict = (f"  → Model thought this was a flip, but both contexts gave "
                       f"{'correct' if res.incorrect_correct else 'wrong'} answers")
            _print_error_example(rank, scores[idx], item, res, verdict)
            if rank <= 5:  # Save first 5 examples
                analysis["false_positives"]["examples"].append(
                    _error_example_record(rank, scores[idx], item, res))

    # =================================================================
    # FALSE NEGATIVES (model says non-flip, but it's actually a flip)
    # =================================================================
    if len(fn_idx) > 0:
        fn_patterns = analysis["false_negatives"]["patterns"]
        print("\n" + "="*70)
        print("FALSE NEGATIVES (Model failed to detect flip)")
        print("="*70)

        fn_scores = scores[fn_idx]
        _report_group_stats(fn_scores, fn_domains, fn_patterns)

        # Check base model performance
        n_base_correct = sum(test_results[i].base_correct for i in fn_idx)
        print(f"\nBase model (no context) performance:")
        print(f"  Base was correct: {n_base_correct}/{len(fn_idx)} "
              f"({100*n_base_correct/len(fn_idx):.1f}%)")
        fn_patterns["base_correct"] = int(n_base_correct)

        fn_patterns["question_length_mean"] = _report_word_lengths(
            "Question length", [test_items[i].question for i in fn_idx])
        fn_patterns["context_length_mean"] = _report_word_lengths(
            "Incorrect context length", [test_items[i].incorrect_ctx for i in fn_idx])

        # Count missed flips whose base and incorrect-context answers still
        # share at least one (case-insensitive) word.
        close_answers = 0
        for idx in fn_idx:
            res = test_results[idx]
            if res.answer_base and res.answer_incorrect:
                base_words = set(res.answer_base.lower().split())
                inc_words = set(res.answer_incorrect.lower().split())
                if base_words & inc_words:
                    close_answers += 1

        print(f"\nAnswer similarity:")
        print(f"  Base and incorrect answers share words: {close_answers}/{len(fn_idx)} "
              f"({100*close_answers/len(fn_idx):.1f}%)")
        fn_patterns["similar_answers"] = int(close_answers)

        # Show the 10 missed flips with the lowest scores (most confident FNs)
        print(f"\n{'='*70}")
        print("Top 10 FALSE NEGATIVES (missed flips with lowest scores):")
        print("="*70)

        sorted_fn_idx = fn_idx[np.argsort(fn_scores)][:10]
        for rank, idx in enumerate(sorted_fn_idx, 1):
            item = test_items[idx]
            res = test_results[idx]
            verdict = "  → This WAS a flip (base correct, incorrect wrong), but model missed it"
            _print_error_example(rank, scores[idx], item, res, verdict)
            if rank <= 5:  # Save first 5 examples
                analysis["false_negatives"]["examples"].append(
                    _error_example_record(rank, scores[idx], item, res))

    # =================================================================
    # COMPARATIVE PATTERNS
    # =================================================================
    print("\n" + "="*70)
    print("COMPARATIVE PATTERNS")
    print("="*70)

    if len(fp_idx) > 0 and len(fn_idx) > 0:
        # Count items per domain once instead of rescanning for every domain.
        domain_totals = Counter(item.domain for item in test_items)

        print("\nError rates by domain:")
        print(f"{'Domain':<15} {'FP Rate':<10} {'FN Rate':<10} {'Harder for':<15}")
        print("-" * 50)

        for domain in sorted(set(fp_domains) | set(fn_domains)):
            # Every listed domain comes from an item in test_items, so the
            # total is at least one.
            n_domain = domain_totals[domain]
            fp_rate = fp_domains.get(domain, 0) / n_domain
            fn_rate = fn_domains.get(domain, 0) / n_domain

            if fp_rate > fn_rate:
                harder_for = "FP"
            elif fn_rate > fp_rate:
                harder_for = "FN"
            else:
                harder_for = "Equal"

            print(f"{domain:<15} {fp_rate:<10.3f} {fn_rate:<10.3f} {harder_for:<15}")

        # Question length comparison
        fp_q_len = analysis["false_positives"]["patterns"]["question_length_mean"]
        fn_q_len = analysis["false_negatives"]["patterns"]["question_length_mean"]

        print(f"\nAverage question length:")
        print(f"  False Positives: {fp_q_len:.1f} words")
        print(f"  False Negatives: {fn_q_len:.1f} words")
        print(f"  Difference: {abs(fp_q_len - fn_q_len):.1f} words "
              f"({'FP longer' if fp_q_len > fn_q_len else 'FN longer'})")

    return analysis


# =================================================================
# USAGE EXAMPLE
# =================================================================

def run_error_analysis_on_saved_data(
    test_items,
    test_results,
    y_test,
    test_scores_baseline,
    test_predictions_baseline,
    test_scores_contrastive=None,
    test_predictions_contrastive=None,
    output_dir=Path("results")
):
    """
    Run error analysis on saved test data.

    Runs ``detailed_error_analysis`` for the baseline method and, when both
    contrastive arrays are given, for the contrastive method too, then reports
    which errors are unique to each method. All analyses are written to
    ``<output_dir>/error_analysis.json``.

    Args:
        test_items: List of Item objects from test set
        test_results: List of EvaluationResult objects from test set
        y_test: True labels
        test_scores_baseline: Scores from baseline method
        test_predictions_baseline: Predictions from baseline method
        test_scores_contrastive: Optional scores from contrastive method
        test_predictions_contrastive: Optional predictions from contrastive method
        output_dir: Directory to save results (created with parents if missing)

    Returns:
        Dict with a "baseline" analysis and, when the contrastive inputs are
        provided, "contrastive" and "comparison" entries.
    """

    output_dir = Path(output_dir)
    # parents=True so that nested output directories work as well.
    output_dir.mkdir(parents=True, exist_ok=True)

    # Element-wise comparisons below need arrays; plain lists are accepted.
    y_test = np.asarray(y_test)
    test_predictions_baseline = np.asarray(test_predictions_baseline)
    test_scores_baseline = np.asarray(test_scores_baseline)

    def _error_sets(predictions):
        """Return (false-positive indices, false-negative indices) as sets."""
        fp = set(np.where((y_test == 0) & (predictions == 1))[0].tolist())
        fn = set(np.where((y_test == 1) & (predictions == 0))[0].tolist())
        return fp, fn

    all_analyses = {}

    # Analyze baseline method
    print("\n" + "="*70)
    print("ANALYZING BASELINE METHOD (ΔΔ MAD)")
    print("="*70)

    baseline_analysis = detailed_error_analysis(
        test_items,
        test_results,
        y_test,
        test_predictions_baseline,
        test_scores_baseline,
        method_name="Baseline (ΔΔ MAD)"
    )
    all_analyses["baseline"] = baseline_analysis

    # Analyze contrastive method if available
    if test_scores_contrastive is not None and test_predictions_contrastive is not None:
        test_predictions_contrastive = np.asarray(test_predictions_contrastive)
        test_scores_contrastive = np.asarray(test_scores_contrastive)

        print("\n" + "="*70)
        print("ANALYZING CONTRASTIVE LEARNING METHOD")
        print("="*70)

        contrastive_analysis = detailed_error_analysis(
            test_items,
            test_results,
            y_test,
            test_predictions_contrastive,
            test_scores_contrastive,
            method_name="Contrastive Learning"
        )
        all_analyses["contrastive"] = contrastive_analysis

        # Compare the two methods
        print("\n" + "="*70)
        print("COMPARISON: BASELINE vs CONTRASTIVE")
        print("="*70)

        baseline_fp, baseline_fn = _error_sets(test_predictions_baseline)
        contrastive_fp, contrastive_fn = _error_sets(test_predictions_contrastive)

        # Errors unique to each method
        baseline_only_fp = baseline_fp - contrastive_fp
        contrastive_only_fp = contrastive_fp - baseline_fp
        baseline_only_fn = baseline_fn - contrastive_fn
        contrastive_only_fn = contrastive_fn - baseline_fn

        print(f"\nFalse Positives:")
        print(f"  Baseline only: {len(baseline_only_fp)}")
        print(f"  Contrastive only: {len(contrastive_only_fp)}")
        print(f"  Both methods: {len(baseline_fp & contrastive_fp)}")

        print(f"\nFalse Negatives:")
        print(f"  Baseline only: {len(baseline_only_fn)}")
        print(f"  Contrastive only: {len(contrastive_only_fn)}")
        print(f"  Both methods: {len(baseline_fn & contrastive_fn)}")

        # What did contrastive fix? (errors only the baseline makes)
        print(f"\nContrastive improvements:")
        print(f"  Fixed FPs: {len(baseline_only_fp)}")
        print(f"  Fixed FNs: {len(baseline_only_fn)}")
        print(f"  New FPs: {len(contrastive_only_fp)}")
        print(f"  New FNs: {len(contrastive_only_fn)}")

        all_analyses["comparison"] = {
            "baseline_only_fp": int(len(baseline_only_fp)),
            "contrastive_only_fp": int(len(contrastive_only_fp)),
            "baseline_only_fn": int(len(baseline_only_fn)),
            "contrastive_only_fn": int(len(contrastive_only_fn)),
            "fixed_fp": int(len(baseline_only_fp)),
            "fixed_fn": int(len(baseline_only_fn)),
        }

    # Save all analyses
    output_file = output_dir / "error_analysis.json"
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(all_analyses, f, indent=2)

    print(f"\n" + "="*70)
    print(f"Error analysis saved to: {output_file}")
    print("="*70)

    return all_analyses

In [None]:
# data structures
@dataclass
class Item:
    """Represents a single ConflictQA item.

    Instances are created by ``build_items_from_raw``, which decides which of
    the two raw passages agrees with the ground truth.
    """
    # Question text (whitespace-stripped by build_items_from_raw).
    question: str
    # Accepted ground-truth answer strings; never empty for built items.
    gt_list: List[str]
    # First sentence of the passage that overlaps the ground truth.
    correct_ctx: str
    # First sentence of the passage that does not overlap the ground truth.
    incorrect_ctx: str
    # Coarse topic label; build_items_from_raw fills it via classify_domain.
    domain: str = "general"
    # Free-form extras; build_items_from_raw stores {"source_index": <raw index>}.
    metadata: Dict[str, Any] = field(default_factory=dict)

@dataclass
class EvaluationResult:
    """Stores evaluation results for an item.

    Holds the three answers reported by the error analysis ("Base", "Correct
    ctx", "Incorrect ctx") together with their correctness flags.
    """
    # Presumably the position of the evaluated Item in its list (TODO confirm).
    item_idx: int
    # Answer for the base condition (reported as "no context" in the error analysis).
    answer_base: str
    # Answer given with the correct context.
    answer_correct: str
    # Answer given with the incorrect context.
    answer_incorrect: str
    # Correctness flag for each of the three answers above.
    base_correct: bool
    correct_correct: bool
    incorrect_correct: bool
    is_flip: bool  # base correct AND incorrect wrong


def _clean_text(value: Any) -> str:
    """Return *value* as a stripped string; ``None`` becomes an empty string."""
    return "" if value is None else str(value).strip()


def build_items_from_raw(raw_data: List[Dict[str, Any]]) -> List[Item]:
    """Build Item objects from raw JSON data.

    For every record the question, ground truth and the two passages
    ("memory_answer"/"memory" and "counter_answer"/"counter") are read. The
    passage whose words overlap the ground truth becomes the correct context
    and the other one the incorrect context. Records are skipped when the
    question, ground truth or both passages are missing, or when the two
    passages do not disagree about the ground truth. JSON ``null`` values are
    treated as missing rather than as the text "None".

    Args:
        raw_data: Decoded JSON records (dicts).

    Returns:
        The list of usable items. Skip reasons and the domain distribution
        are logged.
    """
    items = []
    skip_reasons = Counter()
    n_errors_logged = 0

    for idx, ex in enumerate(raw_data):
        try:
            q = _clean_text(ex.get("question"))
            if not q:
                skip_reasons["missing_question"] += 1
                continue

            gt = ex.get("ground_truth", [])
            if isinstance(gt, str):
                gt_list = [gt.strip()] if gt.strip() else []
            elif isinstance(gt, list):
                gt_list = [text for text in map(_clean_text, gt) if text]
            else:
                gt_list = []

            if not gt_list:
                skip_reasons["missing_ground_truth"] += 1
                continue

            mem = _clean_text(ex.get("memory_answer", ex.get("memory")))
            ctr = _clean_text(ex.get("counter_answer", ex.get("counter")))

            if not mem and not ctr:
                skip_reasons["missing_contexts"] += 1
                continue

            # NOTE(review): a record with only one passage can still pass and
            # then gets an empty context on one side; confirm this is intended.
            mem_aligns = gt_overlap(mem, gt_list) if mem else False
            ctr_aligns = gt_overlap(ctr, gt_list) if ctr else False

            if mem_aligns and not ctr_aligns:
                correct_ctx = short_context_sentence(mem)
                incorrect_ctx = short_context_sentence(ctr)
            elif ctr_aligns and not mem_aligns:
                correct_ctx = short_context_sentence(ctr)
                incorrect_ctx = short_context_sentence(mem)
            else:
                # Both or neither passage matches the ground truth.
                skip_reasons["ambiguous_alignment"] += 1
                continue

            items.append(Item(
                question=q,
                gt_list=gt_list,
                correct_ctx=correct_ctx,
                incorrect_ctx=incorrect_ctx,
                domain=classify_domain(q),
                metadata={"source_index": idx}
            ))

        except Exception as e:
            # Best effort: a malformed record must not abort the whole build.
            skip_reasons[f"error_{type(e).__name__}"] += 1
            # Only the first few failures are logged individually.
            if n_errors_logged < 3:
                logger.warning(f"Skipping item {idx}: {e}")
                n_errors_logged += 1

    logger.info(f"Prepared {len(items)} items (skipped {sum(skip_reasons.values())})")

    if skip_reasons:
        logger.info("Skip reasons:")
        for reason, count in skip_reasons.most_common(5):
            logger.info(f"  {reason}: {count}")

    if items:
        domains = Counter(it.domain for it in items)
        logger.info("Domain distribution:")
        for domain, count in sorted(domains.items(), key=lambda x: -x[1]):
            logger.info(f"  {domain}: {count}")

    return items

In [None]:
# model wrapper
class ModelWrapper:
    """Wrapper for a causal LLM with prompting, generation and activation extraction.

    Reads ``config['model']`` (``name``, ``use_4bit``, ``max_length``, ``dtype``)
    and ``config['processing']['layer_slice_start']``; loads the tokenizer and
    model from the Hugging Face hub and puts the model in eval mode.
    """

    def __init__(self, config: dict, token: Optional[str] = None):
        """Load tokenizer and model.

        Args:
            config: Experiment configuration dict (see class docstring).
            token: Optional Hugging Face access token forwarded to the loaders.
        """
        self.config = config['model']
        self.max_length = self.config['max_length']
        # Keep only layers from `layer_slice_start` onwards when extracting.
        self.layer_slice = slice(config['processing']['layer_slice_start'], None)

        dtype_map = {
            "bfloat16": torch.bfloat16,
            "float16": torch.float16,
            "float32": torch.float32
        }
        self.dtype = dtype_map.get(self.config.get('dtype', 'bfloat16'), torch.bfloat16)

        self.has_cuda = torch.cuda.is_available()
        use_4bit = self.config['use_4bit'] and self.has_cuda
        if self.config['use_4bit'] and not self.has_cuda:
            logger.warning("use_4bit requested but CUDA is unavailable; loading without quantization")

        logger.info(f"Loading tokenizer: {self.config['name']}")
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.config['name'],
            use_fast=True,
            token=token
        )

        # Batched encoding needs a pad token; fall back to EOS when none is defined.
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        logger.info(f"Loading model: {self.config['name']}")
        load_kwargs = {
            "device_map": "auto",
            "token": token
        }
        if use_4bit:
            # Imported lazily so quantization support is only required when requested.
            from transformers import BitsAndBytesConfig
            load_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=self.dtype,
            )
        else:
            load_kwargs["dtype"] = self.dtype

        self.model = AutoModelForCausalLM.from_pretrained(
            self.config['name'],
            **load_kwargs
        ).eval()

        self.model.config.pad_token_id = self.tokenizer.pad_token_id

        logger.info("  Model loaded successfully")
        logger.info(f"  dtype: {self.dtype}")
        logger.info(f"  4-bit quantization: {bool(use_4bit)}")

    @property
    def device(self):
        """Device of the model's first parameter (where inputs are sent)."""
        return next(self.model.parameters()).device

    def build_prompt(self, question: str, context: str = "") -> str:
        """Build prompt from question and optional context.

        The context (if non-empty) is placed before the ``q:`` line; the prompt
        always ends with ``a:`` so the model continues with the answer.
        """
        instr = "give a short, concrete answer (fewer than 6 words)."
        if context:
            return f"{context}\n\nq: {question}\n{instr}\na:"
        return f"q: {question}\n{instr}\na:"

    def _find_question_start_idx(self, prompt: str, offset_mapping: List) -> int:
        """Find token index where question starts.

        Looks for the last ``q:`` (or, failing that, ``question:``) marker in
        ``prompt`` and returns the index of the first non-empty token whose
        character span starts at or after the end of that marker. Falls back to
        the first non-empty token, then to 0.
        """
        q_char_start = 0
        for pattern in [r'(?i)\bq:\s*', r'(?i)\bquestion:\s*']:
            matches = list(re.finditer(pattern, prompt))
            if matches:
                # Use the last marker so a "q:" inside the context is ignored.
                q_char_start = matches[-1].end()
                break

        first_real = None
        for idx, (start, end) in enumerate(offset_mapping):
            if (end - start) <= 0:
                # Zero-width spans (e.g. special tokens) carry no prompt text.
                continue
            if start >= q_char_start:
                return idx
            if first_real is None:
                first_real = idx

        return first_real if first_real is not None else 0

    @torch.inference_mode()
    def encode_and_extract_hidden_batch(self, prompts: List[str]) -> Tuple[np.ndarray, np.ndarray]:
        """Extract hidden states for a batch of prompts.

        Token positions are computed relative to the first attended token of
        each row, so the result is correct for both left- and right-padding
        tokenizers.

        Returns:
            first_vecs: (batch_size, n_layers, hidden_dim) hidden state at the
                first question token.
            last_vecs: (batch_size, n_layers, hidden_dim) hidden state at the
                last non-padding token.
        """
        # Unpadded pass: only used to map characters to token indices.
        enc_with_offsets = self.tokenizer(
            prompts,
            add_special_tokens=True,
            truncation=True,
            max_length=self.max_length,
            return_offsets_mapping=True
        )

        first_indices = [
            self._find_question_start_idx(prompt, offset_map)
            for prompt, offset_map in zip(prompts, enc_with_offsets["offset_mapping"])
        ]

        enc = self.tokenizer(
            prompts,
            return_tensors="pt",
            truncation=True,
            max_length=self.max_length,
            add_special_tokens=True,
            padding=True
        ).to(self.device)

        outputs = self.model(**enc, output_hidden_states=True)

        # Skip the embedding output, then keep the configured layer range.
        selected_hidden = outputs.hidden_states[1:][self.layer_slice]

        mask = enc["attention_mask"]
        # argmax returns the first maximal index, i.e. the first attended token:
        # 0 under right padding, the number of pad tokens under left padding.
        pad_offset = mask.argmax(dim=1)
        last_pos = pad_offset + mask.sum(dim=1) - 1
        first_pos = pad_offset + torch.tensor(
            first_indices, device=mask.device, dtype=pad_offset.dtype
        )

        # Gather the two positions per layer instead of stacking full
        # (layers, batch, seq, hidden) activations.
        first_layers, last_layers = [], []
        for layer_hidden in selected_hidden:
            dev = layer_hidden.device
            rows = torch.arange(layer_hidden.shape[0], device=dev)
            first_layers.append(
                layer_hidden[rows, first_pos.to(dev)].to(torch.float32).cpu()
            )
            last_layers.append(
                layer_hidden[rows, last_pos.to(dev)].to(torch.float32).cpu()
            )

        first_vecs = torch.stack(first_layers, dim=1).numpy()
        last_vecs = torch.stack(last_layers, dim=1).numpy()

        return first_vecs, last_vecs

    @torch.inference_mode()
    def generate_answer(self, prompt: str, max_new_tokens: int = 8) -> str:
        """Greedily generate a short answer for a prompt.

        Only the newly generated tokens are decoded, so the result does not
        depend on the decoded prompt matching ``prompt`` character-for-character
        (which fails when the prompt is truncated to ``max_length``).
        """
        enc = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=self.max_length,
            add_special_tokens=True
        ).to(self.device)

        outputs = self.model.generate(
            **enc,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id
        )

        n_prompt_tokens = enc["input_ids"].shape[1]
        new_tokens = outputs[0][n_prompt_tokens:]
        generated = self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

        return postprocess_short_phrase(generated)

    def evaluate_item(self, item: Item) -> EvaluationResult:
        """Evaluate a single item across all conditions.

        Generates answers with no context, the correct context and the
        incorrect context, scores each against ``item.gt_list`` with
        ``gt_overlap``, and marks a flip when the no-context answer is correct
        but the incorrect-context answer is not.
        """
        prompt_base = self.build_prompt(item.question, context="")
        prompt_correct = self.build_prompt(item.question, context=item.correct_ctx)
        prompt_incorrect = self.build_prompt(item.question, context=item.incorrect_ctx)

        answer_base = self.generate_answer(prompt_base)
        answer_correct = self.generate_answer(prompt_correct)
        answer_incorrect = self.generate_answer(prompt_incorrect)

        base_correct = gt_overlap(answer_base, item.gt_list)
        correct_correct = gt_overlap(answer_correct, item.gt_list)
        incorrect_correct = gt_overlap(answer_incorrect, item.gt_list)

        is_flip = base_correct and not incorrect_correct

        return EvaluationResult(
            item_idx=item.metadata.get("source_index", -1),
            answer_base=answer_base,
            answer_correct=answer_correct,
            answer_incorrect=answer_incorrect,
            base_correct=base_correct,
            correct_correct=correct_correct,
            incorrect_correct=incorrect_correct,
            is_flip=is_flip
        )

    @torch.inference_mode()
    def answer_logprob(self, prompt: str, answer: str) -> float:
        """Compute average log-probability of 'answer' tokens given 'prompt'.

        The prompt and ``prompt + " " + answer`` are tokenized separately; every
        token of the full sequence past the prompt's token count is treated as
        an answer token.
        NOTE(review): this assumes the prompt tokenizes to a prefix of the full
        sequence; tokens merging across the boundary would shift the split.

        Returns:
            Mean log-probability of the answer tokens, or 0.0 when ``answer`` is
            empty or no answer token can be scored.
        """
        if not answer:
            return 0.0

        # Tokenize separately to get the prompt length in tokens.
        enc_prompt = self.tokenizer(
            prompt,
            return_tensors="pt",
            add_special_tokens=True
        )
        enc_full = self.tokenizer(
            prompt + " " + answer,
            return_tensors="pt",
            add_special_tokens=True
        )

        input_ids = enc_full["input_ids"].to(self.device)
        attn_mask = enc_full["attention_mask"].to(self.device)

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attn_mask
        )
        # logits shape (seq_len, vocab_size) after dropping the batch dim
        logits = outputs.logits[0]

        prompt_len = enc_prompt["input_ids"].shape[1]
        seq_len = input_ids.shape[1]

        # The token at position t is predicted by logits[t - 1], so position 0
        # can never be scored.
        start = max(prompt_len, 1)
        if start >= seq_len:
            return 0.0

        targets = input_ids[0, start:].to(logits.device)
        # Only the rows that predict answer tokens are normalised, in float32,
        # so low-precision model dtypes do not degrade the log-probabilities.
        step_logprobs = torch.log_softmax(logits[start - 1:seq_len - 1].float(), dim=-1)
        token_logprobs = step_logprobs.gather(1, targets.unsqueeze(1)).squeeze(1)

        return float(token_logprobs.double().mean().item())

In [None]:
# statistics
class WelfordAccumulator:
    """Running per-layer mean/variance tracker based on Welford's update rule.

    Two token positions are tracked side by side along axis 1 of every buffer:
    index 0 holds the first-token statistics, index 1 the last-token ones.
    """

    def __init__(self, n_layers: int, hidden_dim: int):
        self.n_layers = n_layers
        self.hidden_dim = hidden_dim
        stat_shape = (n_layers, 2, hidden_dim)
        self.n = np.zeros(stat_shape[:2], dtype=np.int64)   # observation counts
        self.mu = np.zeros(stat_shape, dtype=np.float64)    # running means
        self.M2 = np.zeros(stat_shape, dtype=np.float64)    # running sums of squared deviations

    def update(self, first_vec: np.ndarray, last_vec: np.ndarray):
        """Fold one observation (first-token and last-token vectors) into the statistics."""
        # Place the two positions on axis -2 so they line up with the buffers.
        obs = np.stack([first_vec, last_vec], axis=-2).astype(np.float64)
        self.n += 1
        residual_before = obs - self.mu
        self.mu += residual_before / self.n[:, :, None]
        residual_after = obs - self.mu
        self.M2 += residual_before * residual_after

    def get_statistics(self):
        """Return the accumulated mean, sample standard deviation and counts as a dict."""
        # Degrees of freedom are floored at 1 so a single observation yields std 0.
        dof = np.clip(self.n[:, :, None] - 1, 1, None)
        sample_std = np.sqrt(np.divide(self.M2, dof))

        summary = {}
        summary['mu'] = self.mu.astype(np.float32)
        summary['std'] = sample_std.astype(np.float32)
        summary['n'] = self.n.copy()
        summary['n_layers'] = self.n_layers
        summary['hidden_dim'] = self.hidden_dim
        return summary

def probability_integral_transform(
    vec: np.ndarray,
    mu: np.ndarray,
    std: np.ndarray,
    eps: float = 1e-8
) -> np.ndarray:
    """Map values through the Gaussian CDF defined by ``mu`` and ``std`` (PIT).

    ``eps`` is added to ``std`` before dividing so zero-variance entries do not
    divide by zero. The result lies in [0, 1].
    """
    standardized = (vec - mu) / (std + eps)
    # Standard normal CDF expressed through the error function.
    return 0.5 * (1.0 + erf(standardized / np.sqrt(2.0)))

def ks_statistic_uniform(u: np.ndarray) -> float:
    """Kolmogorov-Smirnov distance between the sample ``u`` and Uniform(0, 1).

    The input is flattened and clipped to [0, 1] first; an empty sample gives NaN.
    """
    samples = np.asarray(u, dtype=float).ravel()
    samples = np.clip(samples, 0.0, 1.0)
    if samples.size == 0:
        return np.nan
    ks_result = spstats.kstest(samples, 'uniform')
    return float(ks_result.statistic)

def mean_absolute_deviation(u: np.ndarray, center: float = 0.5) -> float:
    """Average absolute distance of the flattened values from ``center`` (NaN if empty)."""
    values = np.asarray(u, dtype=float).ravel()
    if values.size == 0:
        return np.nan
    distances = np.abs(values - center)
    return float(distances.mean())

def tail_weight(u: np.ndarray, threshold: float = 0.01) -> float:
    """Share of values within ``threshold`` of either end of [0, 1].

    Values are flattened and clipped to [0, 1] before counting; an empty input
    gives NaN.
    """
    values = np.clip(np.asarray(u, dtype=float).ravel(), 0.0, 1.0)
    if values.size == 0:
        return np.nan
    in_lower_tail = values <= threshold
    in_upper_tail = values >= 1.0 - threshold
    return float((in_lower_tail | in_upper_tail).mean())

def compute_itemwise_statistics(
    U: np.ndarray,
    metric: str = "ks",
    threshold: float = 0.01
) -> np.ndarray:
    """Compute one summary statistic per (layer, item) over the hidden dimension.

    Args:
        U: Array of shape (n_layers, n_items, hidden_dim)
        metric: One of "ks", "mad", "tail"
        threshold: Threshold for tail metric (ignored by the other metrics)
    Returns:
        stats: Array of shape (n_layers, n_items)
    Raises:
        ValueError: If ``metric`` is not one of the supported names. Previously
            an unknown metric silently produced an all-zero array.
    """
    if metric not in ("ks", "mad", "tail"):
        raise ValueError(
            f"Unknown metric {metric!r}; expected one of 'ks', 'mad', 'tail'"
        )

    # Resolve the per-item scorer once instead of re-testing `metric` in the loop.
    if metric == "ks":
        score = ks_statistic_uniform
    elif metric == "mad":
        score = mean_absolute_deviation
    else:
        def score(u):
            return tail_weight(u, threshold)

    n_layers, n_items, _ = U.shape
    stats = np.zeros((n_layers, n_items), dtype=float)

    for i in range(n_layers):
        for j in range(n_items):
            stats[i, j] = score(U[i, j, :])

    return stats

def neuron_extremeness(U: np.ndarray) -> np.ndarray:
    """
    score how far each PIT activation sits from the middle of [0, 1].

    every value u is clipped to [0, 1] and mapped to max(u, 1 - u), which is
    the one-sample KS statistic of u against Uniform(0,1).

    U shape: (n_layers, n_items, hidden_dim)
    returns array of the same shape, values in [0.5, 1].
    """
    bounded = np.clip(U, 0.0, 1.0)
    mirrored = 1.0 - bounded
    return np.maximum(bounded, mirrored)


def select_top_k_neurons_globally(
    U_incorrect: np.ndarray,
    y: np.ndarray,
    k: int = 100,
    layer_indices: Optional[List[int]] = None,
) -> List[Tuple[int, int, float]]:
    """
    instead of averaging over layers, find the k individual neurons whose
    extremeness score has the highest AUROC across (a subset of) layers.

    U_incorrect: (n_layers, n_items, hidden_dim)  # PIT values for incorrect context
    y:           (n_items,)  binary labels (flip=1, non-flip=0)
    layer_indices: optional list of layer indices to search over
                   (e.g. only last 8 layers for speed)

    returns: list of (layer_idx, neuron_idx, auroc), sorted by auroc desc;
             empty when the labels contain fewer than two classes or no neuron
             yields a valid AUROC.
    """
    n_layers, n_items, hidden_dim = U_incorrect.shape

    if layer_indices is None:
        layer_indices = list(range(n_layers))

    # AUROC is undefined with a single class; bail out once instead of letting
    # every per-neuron call fail.
    if np.unique(np.asarray(y)).size < 2:
        logger.warning(
            "no valid neurons found in select_top_k_neurons_globally "
            "(labels contain fewer than two classes)"
        )
        return []

    # precompute extremeness scores for all activations
    U_ext = neuron_extremeness(U_incorrect)  # same shape

    neuron_aurocs: List[Tuple[int, int, float]] = []

    for layer in layer_indices:
        layer_scores = U_ext[layer]  # shape (n_items, hidden_dim)
        for neuron in range(hidden_dim):
            s = layer_scores[:, neuron]

            # skip completely constant neurons
            if np.allclose(s, s[0]):
                continue

            try:
                auc = roc_auc_score(y, s)
            except ValueError:
                # invalid input for this neuron (e.g. NaN scores); other error
                # types are left to propagate rather than being hidden
                continue
            neuron_aurocs.append((layer, neuron, float(auc)))

    neuron_aurocs.sort(key=lambda t: t[2], reverse=True)
    top_neurons = neuron_aurocs[:k]

    if top_neurons:
        l0, n0, a0 = top_neurons[0]
        logger.info(
            f"top neuron (global): layer {l0}, neuron {n0}, AUROC={a0:.3f} "
            f"(search over {len(layer_indices)} layers, k={k})"
        )
    else:
        logger.warning("no valid neurons found in select_top_k_neurons_globally")

    return top_neurons


def build_neuron_features(
    U: np.ndarray,
    selected_neurons: List[Tuple[int, int, float]],
) -> np.ndarray:
    """
    build a feature matrix from selected (layer, neuron) pairs.

    only the selected columns are gathered before the extremeness transform is
    applied, so the cost scales with len(selected_neurons) rather than with
    the full n_layers * hidden_dim tensor.

    U: PIT activations, shape (n_layers, n_items, hidden_dim)
    selected_neurons: list of (layer_idx, neuron_idx, auroc)

    returns:
        X: (n_items, len(selected_neurons))  # each column = extremeness of one neuron
    """
    n_layers, n_items, hidden_dim = U.shape
    n_selected = len(selected_neurons)

    if n_selected == 0:
        return np.zeros((n_items, 0), dtype=np.float32)

    layer_idx = np.fromiter(
        (layer for layer, _, _ in selected_neurons), dtype=np.intp, count=n_selected
    )
    neuron_idx = np.fromiter(
        (neuron for _, neuron, _ in selected_neurons), dtype=np.intp, count=n_selected
    )

    # paired advanced indices around a slice -> shape (n_selected, n_items)
    picked = U[layer_idx, :, neuron_idx]

    # reuse same extremeness definition as the neuron selection step
    X = neuron_extremeness(picked).T

    return np.ascontiguousarray(X, dtype=np.float32)



In [None]:
# activation extraction
def extract_pit_values(
    items: List[Item],
    condition: str,
    model_wrapper: ModelWrapper,
    baseline_stats: Dict,
    batch_size: int = 8
) -> np.ndarray:
    """Run the model on every item under one condition and PIT-transform the
    last-token hidden states against the baseline statistics.

    Args:
        items: Items to process
        condition: "base" (no context) or "correct"; any other value selects
            the incorrect context
        model_wrapper: Model wrapper used to build prompts and extract activations
        baseline_stats: Dict providing 'mu', 'std', 'n_layers' and 'hidden_dim'
        batch_size: Number of prompts per forward pass
    Returns:
        U_last: Array of shape (n_layers, n_items, hidden_dim)
    """
    def _context_for(item):
        if condition == "base":
            return ""
        if condition == "correct":
            return item.correct_ctx
        return item.incorrect_ctx

    prompts = [
        model_wrapper.build_prompt(item.question, context=_context_for(item))
        for item in items
    ]

    n_items = len(prompts)
    out_shape = (baseline_stats['n_layers'], n_items, baseline_stats['hidden_dim'])

    # Only the last-token activations are kept.
    U_last = np.zeros(out_shape, dtype=np.float32)

    batch_starts = range(0, n_items, batch_size)
    for lo in tqdm(batch_starts, desc=f"Extract {condition}"):
        hi = min(lo + batch_size, n_items)

        _, last_vecs = model_wrapper.encode_and_extract_hidden_batch(prompts[lo:hi])

        # (batch, n_layers, hidden_dim) -> (n_layers, batch, hidden_dim)
        layer_major = np.transpose(last_vecs, (1, 0, 2))

        # Index 1 on axis 1 is the last-token baseline; the inserted axis
        # broadcasts it over the batch.
        U_last[:, lo:hi, :] = probability_integral_transform(
            layer_major,
            baseline_stats['mu'][:, 1, :][:, None, :],
            baseline_stats['std'][:, 1, :][:, None, :]
        )

    return U_last

In [None]:
# feature engineering
def compute_basic_features(
    U_base: np.ndarray,
    U_correct: np.ndarray,
    U_incorrect: np.ndarray
) -> Dict[str, np.ndarray]:
    """Build the delta-delta features for KS, MAD and two tail thresholds.

    For each statistic, the shift of the incorrect condition relative to base
    minus the shift of the correct condition relative to base is stored under
    "deltadelta_<name>"; every value has shape (n_layers, n_items).
    """
    # (feature suffix, metric passed to compute_itemwise_statistics, threshold)
    specs = (
        ("ks", "ks", 0.01),
        ("mad", "mad", 0.01),
        ("tail01", "tail", 0.01),
        ("tail05", "tail", 0.05),
    )

    features = {}
    for suffix, metric, thr in specs:
        base_stat, correct_stat, incorrect_stat = (
            compute_itemwise_statistics(U, metric, thr)
            for U in (U_base, U_correct, U_incorrect)
        )
        shift_incorrect = incorrect_stat - base_stat
        shift_correct = correct_stat - base_stat
        features[f"deltadelta_{suffix}"] = shift_incorrect - shift_correct

    return features

def compute_enhanced_features(
    U_base: np.ndarray,
    U_correct: np.ndarray,
    U_incorrect: np.ndarray,
    last_k: int = 16
) -> np.ndarray:
    """Assemble the per-item feature matrix used by the contrastive model.

    Seven per-layer statistics are computed from the PIT arrays (each of shape
    (n_layers, n_items, hidden_dim)) and the final ``last_k`` layers of each
    are concatenated column-wise.

    Returns:
        X: float32 array of shape (n_items, n_features)
           where n_features = 7 * min(last_k, n_layers)
    """
    logger.info("Computing enhanced features...")

    # Conditions are always processed in the order base, incorrect, correct.
    conditions = (U_base, U_incorrect, U_correct)

    def _delta_delta(incorrect, base, correct):
        # (incorrect shift from base) minus (correct shift from base)
        return (incorrect - base) - (correct - base)

    logger.info("  - KS statistics")

    def _ks_per_item(U: np.ndarray) -> np.ndarray:
        n_layers, n_items, _ = U.shape
        out = np.zeros((n_layers, n_items), dtype=float)
        for layer, item in np.ndindex(n_layers, n_items):
            clipped = np.clip(U[layer, item, :], 0.0, 1.0)
            out[layer, item] = spstats.kstest(clipped, 'uniform').statistic
        return out

    ks_base, ks_incorrect, ks_correct = [_ks_per_item(U) for U in conditions]

    logger.info("  - Tail weights and MAD")

    def _tail_and_mad(U: np.ndarray, thr: float):
        n_layers, n_items, _ = U.shape
        in_tail = (U <= thr) | (U >= (1.0 - thr))
        return in_tail.mean(axis=2), np.abs(U - 0.5).mean(axis=2)

    (tail_base, mad_base), (tail_incorrect, mad_incorrect), (tail_correct, mad_correct) = [
        _tail_and_mad(U, 0.01) for U in conditions
    ]

    logger.info("  - Variance and skewness")

    def _var_and_skew(U: np.ndarray):
        n_layers, n_items, _ = U.shape
        return np.var(U, axis=2), spstats.skew(U, axis=2)

    (var_base, skew_base), (var_incorrect, skew_incorrect), (var_correct, skew_correct) = [
        _var_and_skew(U) for U in conditions
    ]

    logger.info("  - Quantile ranges")

    def _central_range(U: np.ndarray) -> np.ndarray:
        n_layers, n_items, _ = U.shape
        low = np.percentile(U, 1, axis=2)
        high = np.percentile(U, 99, axis=2)
        return high - low

    range_base, range_incorrect, range_correct = [_central_range(U) for U in conditions]

    logger.info("  - Computing delta-delta")
    per_layer_features = [
        ks_incorrect - ks_base,
        _delta_delta(ks_incorrect, ks_base, ks_correct),
        _delta_delta(mad_incorrect, mad_base, mad_correct),
        _delta_delta(tail_incorrect, tail_base, tail_correct),
        _delta_delta(var_incorrect, var_base, var_correct),
        _delta_delta(skew_incorrect, skew_base, skew_correct),
        _delta_delta(range_incorrect, range_base, range_correct),
    ]

    logger.info(f"  - Stacking last {last_k} layers")
    blocks = []
    for feature in per_layer_features:
        first_layer = max(0, feature.shape[0] - last_k)
        blocks.append(feature[first_layer:, :].T)  # (n_items, kept layers)

    X = np.concatenate(blocks, axis=1).astype(np.float32)

    logger.info(f"Feature matrix: {X.shape}")
    return X


In [None]:
# contrastive learning
class ProjectionHead(nn.Module):
    """Two-layer MLP that maps features to unit-length embeddings."""

    def __init__(self, d_in: int, d_h: int = 256, d_out: int = 64):
        super().__init__()
        stages = [
            nn.Linear(d_in, d_h),
            nn.ReLU(inplace=True),
            nn.Linear(d_h, d_out),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        # L2-normalise along the feature axis so dot products are cosine similarities.
        return F.normalize(self.net(x), p=2, dim=-1)

def supervised_contrastive_loss(z: torch.Tensor, y: torch.Tensor, temperature: float = 0.07) -> torch.Tensor:
    """Supervised contrastive loss over a batch of embeddings ``z`` with labels ``y``.

    Each anchor's loss is the negative mean log-probability of its same-label
    samples, where probabilities are a softmax over all other samples in the
    batch. Anchors with no same-label partner contribute zero.
    """
    identity = torch.eye(z.size(0), device=z.device)
    not_self = 1.0 - identity

    # Temperature-scaled pairwise similarities.
    logits = torch.matmul(z, z.t()) / temperature

    # positives[i, j] = 1 when i and j share a label and i != j.
    labels = y.view(-1, 1)
    positives = torch.eq(labels, labels.t()).float() - identity

    # Shift each row by its maximum for numerical stability.
    row_max, _ = logits.max(dim=1, keepdim=True)
    logits = logits - row_max.detach()

    # Softmax denominator excludes the anchor itself.
    partition = (torch.exp(logits) * not_self).sum(dim=1, keepdim=True).clamp_min(1e-12)
    log_prob = logits - torch.log(partition)

    positives_per_anchor = positives.sum(dim=1).clamp_min(1.0)
    per_anchor_loss = -(positives * log_prob).sum(dim=1) / positives_per_anchor

    return per_anchor_loss.mean()

def _embed_and_build_prototypes(
    proj: ProjectionHead,
    X_std: np.ndarray,
    y: np.ndarray,
    device: torch.device
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Embed standardized features and return (embeddings, proto0, proto1).

    Embeddings and both class prototypes (means of the class-0 / class-1
    embeddings) are L2-normalised. Leaves ``proj`` in eval mode.
    """
    proj.eval()
    with torch.no_grad():
        Z = proj(torch.from_numpy(X_std).to(device)).cpu().numpy()
    Z = Z / (np.linalg.norm(Z, axis=1, keepdims=True) + 1e-12)

    proto0 = Z[y == 0].mean(axis=0)
    proto1 = Z[y == 1].mean(axis=0)
    proto0 /= (np.linalg.norm(proto0) + 1e-12)
    proto1 /= (np.linalg.norm(proto1) + 1e-12)

    return Z, proto0, proto1


def train_contrastive_model(
    X_train: np.ndarray,
    y_train: np.ndarray,
    config: dict,
    device: torch.device
) -> Tuple[ProjectionHead, StandardScaler, np.ndarray, np.ndarray]:
    """Train contrastive projection head on training data.

    Args:
        X_train: Feature matrix, one row per item.
        y_train: Binary labels (0/1), one per row of ``X_train``.
        config: Dict with 'batch_size', 'hidden_dim', 'projection_dim',
            'learning_rate', 'weight_decay', 'epochs' and 'temperature'.
        device: Device on which the projection head is trained.
    Returns:
        (proj, scaler, proto0, proto1): the trained head in eval mode, the
        scaler fitted on ``X_train``, and the normalised class-0 / class-1
        prototypes computed from the training embeddings.
    """
    # Standardize features
    scaler = StandardScaler(with_mean=True, with_std=True)
    X_train_std = scaler.fit_transform(X_train).astype(np.float32)

    dataset = torch.utils.data.TensorDataset(
        torch.from_numpy(X_train_std),
        torch.from_numpy(y_train.astype(np.int64))
    )

    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=config['batch_size'],
        shuffle=True,
        drop_last=True
    )

    if len(loader) == 0:
        # drop_last discards the only (partial) batch, so no update would run.
        logger.warning(
            f"Training set ({len(dataset)} items) is smaller than batch_size "
            f"({config['batch_size']}); the projection head will not be updated"
        )

    proj = ProjectionHead(
        d_in=X_train_std.shape[1],
        d_h=config['hidden_dim'],
        d_out=config['projection_dim']
    ).to(device)

    optimizer = torch.optim.AdamW(
        proj.parameters(),
        lr=config['learning_rate'],
        weight_decay=config['weight_decay']
    )

    logger.info(f"Training contrastive model for {config['epochs']} epochs...")
    proj.train()

    for epoch in range(1, config['epochs'] + 1):
        epoch_loss = 0.0
        n_seen = 0

        for xb, yb in loader:
            xb = xb.to(device)
            yb = yb.to(device)

            optimizer.zero_grad(set_to_none=True)
            z = proj(xb)
            loss = supervised_contrastive_loss(z, yb, temperature=config['temperature'])
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item() * xb.size(0)
            n_seen += xb.size(0)

        # Monitor progress every 10 epochs
        if epoch % 10 == 0:
            # Average over the samples actually seen; drop_last may skip some.
            avg_loss = epoch_loss / max(n_seen, 1)

            Z_train, proto0, proto1 = _embed_and_build_prototypes(
                proj, X_train_std, y_train, device
            )
            scores = (Z_train @ proto1) - (Z_train @ proto0)
            try:
                train_auroc = roc_auc_score(y_train, scores)
            except ValueError:
                # AUROC is undefined (e.g. a single class); report the loss only.
                logger.info(f"  Epoch {epoch:02d}/{config['epochs']}  Loss: {avg_loss:.4f}")
            else:
                logger.info(f"  Epoch {epoch:02d}/{config['epochs']}  "
                            f"Loss: {avg_loss:.4f}  Train AUROC: {train_auroc:.3f}")

            proj.train()

    # Final prototypes from the fully trained head (helper switches to eval mode).
    _, proto0, proto1 = _embed_and_build_prototypes(proj, X_train_std, y_train, device)

    return proj, scaler, proto0, proto1

def evaluate_contrastive_model(
    proj: ProjectionHead,
    scaler: StandardScaler,
    X_test: np.ndarray,
    y_test: np.ndarray,
    proto0: np.ndarray,
    proto1: np.ndarray,
    device: torch.device
) -> Tuple[np.ndarray, Dict[str, float]]:
    """Score test items with the trained head and report AUROC with a bootstrap CI.

    Each item's score is its similarity to ``proto1`` minus its similarity to
    ``proto0``. Bootstrap settings and the seed are read from the module-level
    ``CONFIG``.

    Returns:
        (scores, metrics) where metrics has keys 'auroc', 'auroc_ci_low' and
        'auroc_ci_high'; all three are NaN if the AUROC or its CI cannot be
        computed.
    """
    # Apply the training-set standardisation to the test features.
    features = scaler.transform(X_test).astype(np.float32)

    proj.eval()
    with torch.no_grad():
        embeddings = proj(torch.from_numpy(features).to(device)).cpu().numpy()
    row_norms = np.linalg.norm(embeddings, axis=1, keepdims=True) + 1e-12
    embeddings = embeddings / row_norms

    scores = (embeddings @ proto1) - (embeddings @ proto0)

    # Diagnostics on the raw scores.
    nan_mask = np.isnan(scores)
    logger.info("\n=== DEBUGGING CONTRASTIVE SCORES ===")
    logger.info(f"scores shape: {scores.shape}")
    logger.info(f"scores contains NaN: {nan_mask.any()}")
    logger.info(f"scores contains Inf: {np.isinf(scores).any()}")
    if not nan_mask.all():
        logger.info(
            f"scores min/max/mean: "
            f"{np.nanmin(scores):.6f} / {np.nanmax(scores):.6f} / {np.nanmean(scores):.6f}"
        )

    # Start from the failure values; overwritten only if everything succeeds.
    metrics: Dict[str, float] = {
        'auroc': np.nan,
        'auroc_ci_low': np.nan,
        'auroc_ci_high': np.nan,
    }

    try:
        auroc = roc_auc_score(y_test, scores)
        ci_low, ci_high = bootstrap_auroc_ci(
            y_test,
            scores,
            n_resamples=CONFIG['bootstrap']['n_resamples'],
            confidence_level=CONFIG['bootstrap']['confidence_level'],
            seed=CONFIG['seed'],
        )
    except Exception as e:
        logger.warning(f"Could not compute AUROC or its CI: {e}")
    else:
        metrics['auroc'] = auroc
        metrics['auroc_ci_low'] = ci_low
        metrics['auroc_ci_high'] = ci_high

    return scores, metrics

In [None]:
# visualization
def plot_qq_uniform(u: np.ndarray, title: str, save_path: Optional[Path] = None):
    """Draw a QQ plot of the sample ``u`` against Uniform(0, 1).

    The sample is flattened and clipped to [0, 1]. Nothing is drawn (a warning
    is logged) when the sample is empty or essentially constant. The figure is
    saved to ``save_path`` if given, then shown and closed.
    """
    sample = np.clip(np.asarray(u, dtype=float).ravel(), 0, 1)

    if len(sample) == 0 or np.std(sample) < 1e-10:
        logger.warning(f"Skipping QQ plot for {title} - insufficient data")
        return

    probs = np.linspace(0, 1, 200)
    sample_quantiles = np.quantile(sample, probs)

    fig, ax = plt.subplots(figsize=(5, 5))
    ax.plot(probs, sample_quantiles, 'o-', markersize=3, label="Empirical", alpha=0.7)
    # Diagonal reference: a perfectly uniform sample lies on this line.
    ax.plot([0, 1], [0, 1], '--', color='red', alpha=0.6, label="Uniform(0,1)")
    ax.set_xlabel("Theoretical quantiles")
    ax.set_ylabel("Empirical quantiles")
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
        logger.info(f"Saved: {save_path}")

    plt.show()
    plt.close(fig)

def plot_score_distributions(
    scores: np.ndarray,
    labels: np.ndarray,
    threshold: Optional[float] = None,
    title: str = "Score Distributions",
    xlabel: str = "Score",
    save_path: Optional[Path] = None
):
    """Overlay density histograms of ``scores`` for label 0 (non-flip) and label 1 (flip).

    A vertical dashed line marks ``threshold`` when one is given. The figure is
    saved to ``save_path`` if given, then shown and closed.
    """
    fig, ax = plt.subplots(figsize=(7, 4))

    # (label value, legend name, colour)
    groups = ((0, "Non-flip", 'C0'), (1, "Flip", 'C3'))
    for value, name, colour in groups:
        members = labels == value
        ax.hist(scores[members], bins=40, density=True, alpha=0.6,
                label=f"{name} (n={members.sum()})", color=colour)

    if threshold is not None:
        ax.axvline(threshold, linestyle='--', linewidth=2, color='red',
                   label=f"Threshold = {threshold:.4f}")

    ax.set_xlabel(xlabel)
    ax.set_ylabel("Density")
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
        logger.info(f"Saved: {save_path}")

    plt.show()
    plt.close(fig)

def plot_auroc_by_layer(
    auroc_results: Dict[str, np.ndarray],
    save_path: Optional[Path] = None
):
    """Draw one AUROC-vs-layer curve per feature, with a chance line at 0.5.

    ``auroc_results`` maps a feature name to its per-layer AUROC values. The
    figure is saved to ``save_path`` if given, then shown and closed.
    """
    fig, ax = plt.subplots(figsize=(10, 5))

    for feature_name, per_layer in auroc_results.items():
        # Shorten the "deltadelta_" prefix for the legend.
        legend_name = feature_name.replace("deltadelta_", "ΔΔ ")
        ax.plot(per_layer, marker='o', markersize=4, label=legend_name, alpha=0.8)

    ax.axhline(0.5, linestyle='--', color='black', alpha=0.3, label="Chance")
    ax.set_xlabel("Layer Index")
    ax.set_ylabel("AUROC")
    ax.set_title("Flip Detection Performance by Layer")
    ax.legend()
    ax.grid(True, alpha=0.3)
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
        logger.info(f"Saved: {save_path}")

    plt.show()
    plt.close(fig)

def plot_roc_curve_with_ci(
    y_true: np.ndarray,
    scores: np.ndarray,
    title: str = "ROC Curve",
    save_path: Optional[Path] = None
):
    """Plot the ROC curve of ``scores`` against ``y_true``.

    The AUROC is shown in the legend and the diagonal is drawn as the chance
    reference. Despite the function name, no confidence interval is drawn.

    Args:
        y_true: binary ground-truth labels.
        scores: one score per item, aligned with ``y_true``.
        title: figure title.
        save_path: when truthy, the figure is also written to this path.
    """
    false_pos_rate, true_pos_rate, _ = roc_curve(y_true, scores)
    area = roc_auc_score(y_true, scores)

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.plot(false_pos_rate, true_pos_rate, linewidth=2, label=f'AUROC = {area:.3f}')
    ax.plot([0, 1], [0, 1], '--', color='gray', alpha=0.5, label='Chance')
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
        logger.info(f"Saved: {save_path}")

    plt.show()
    plt.close()

In [None]:
# create output directory
# All figures written by the later cells are saved under this directory.
output_dir = Path("/content/results")
# parents=True: also create missing ancestors, so the cell does not fail with
# FileNotFoundError when the "/content" parent directory does not exist yet.
output_dir.mkdir(parents=True, exist_ok=True)
logger.info(f"Output directory: {output_dir}")

In [None]:
# authentication
# Resolve a HuggingFace access token into `hf_token` (passed to ModelWrapper in
# the model initialization cell). When no token is cached, an interactive
# login() is started and the token is read again afterwards.
logger.info("\n" + "="*70)
logger.info("HUGGINGFACE AUTHENTICATION")
logger.info("="*70)

try:
    # NOTE(review): HfFolder looks like a legacy huggingface_hub helper —
    # confirm it is still exported by the installed huggingface_hub version.
    from huggingface_hub import login, HfFolder

    hf_token = HfFolder.get_token()

    if hf_token:
        logger.info("HuggingFace token found")
    else:
        logger.info("No token found. Please login:")
        login()
        # re-read the token after the interactive login
        hf_token = HfFolder.get_token()

    # still no token after the login attempt: abort
    if not hf_token:
        raise RuntimeError("HuggingFace authentication required for Llama models")

except Exception as e:
    # Covers import failures, login errors and the RuntimeError above: log a
    # hint for the user, then re-raise so the notebook stops at this cell.
    logger.error(f"Authentication failed: {e}")
    logger.info("Please set your token with: from huggingface_hub import login; login()")
    raise

In [None]:
# data loading
# Ask the user to upload a ConflictQA file through the Colab upload widget,
# parse it as JSON (falling back to JSON Lines), convert the raw records with
# build_items_from_raw, shuffle them, and cap the number of items at
# CONFIG['data']['n_items_total']. Raises ValueError when nothing is available.
logger.info("\n" + "="*70)
logger.info("DATA LOADING")
logger.info("="*70)


try:
    from google.colab import files
    logger.info("Upload ConflictQA JSON file")
    uploaded = files.upload()

    if uploaded:
        # only the first uploaded file is used
        json_path = next(iter(uploaded.keys()))
        logger.info(f"File uploaded: {json_path}")

        # Read JSON
        try:
            with open(json_path, "r", encoding="utf-8") as f:
                raw = json.load(f)
            # a single top-level object is treated as a one-record dataset
            if not isinstance(raw, list):
                raw = [raw]
        except json.JSONDecodeError:
            # not one JSON document: retry as one JSON object per non-blank line
            logger.info("Trying JSONL format...")
            raw = []
            with open(json_path, "r", encoding="utf-8") as f:
                for line in f:
                    if line.strip():
                        raw.append(json.loads(line))

        logger.info(f"Read {len(raw)} raw items")

        items = build_items_from_raw(raw)

        if items:
            random.shuffle(items)
            logger.info(f"Loaded {len(items)} valid items from upload")
        else:
            # NOTE(review): the message promises demo data, but `items` has
            # already been rebound to the empty result above and no demo set is
            # loaded in this cell — confirm whether a demo fallback is intended.
            logger.warning("No valid items found, using demo data")
    else:
        logger.info("No file uploaded")

except Exception as e:
    # best effort: a failed upload is reported, not fatal at this point
    logger.warning(f"Upload failed: {e}")

# The upload path can fail before `items` is ever bound (for example when the
# google.colab import is unavailable). If no earlier cell defined `items`
# either, fall back to an empty list so the explicit "No items to process"
# error below is raised instead of an unrelated NameError.
try:
    items
except NameError:
    items = []

logger.info(f"\nTotal items available: {len(items)}")

# Sample if needed
n_items_total = CONFIG['data']['n_items_total']
if len(items) > n_items_total:
    items = random.sample(items, n_items_total)
    logger.info(f"Sampled {len(items)} items")

if len(items) == 0:
    raise ValueError("No items to process")

In [None]:
# model initialization
# Build the model wrapper, then run a single context-free prompt through it to
# read the number of layers and the hidden size off the returned activations.
logger.info("\n" + "="*70)
logger.info("MODEL INITIALIZATION")
logger.info("="*70)

model_wrapper = ModelWrapper(CONFIG, token=hf_token)

# Get model dimensions
prompt = model_wrapper.build_prompt(items[0].question, context="")
_, last_vec = model_wrapper.encode_and_extract_hidden_batch([prompt])
# assumes last_vec is (batch, n_layers, hidden_dim) — TODO confirm against ModelWrapper
n_layers, hidden_dim = last_vec.shape[1], last_vec.shape[2]

logger.info(f"  Model dimensions: {n_layers} layers, {hidden_dim} hidden dim")

In [None]:
# label computation (for stratified split)
# Evaluate every item with the model and turn each result's `is_flip` flag
# into a 0/1 label. `eval_results` and `y_all` are index-aligned with `items`
# and are reused by the train/test split cell.
logger.info("\n" + "="*70)
logger.info("COMPUTING LABELS FOR STRATIFICATION")
logger.info("="*70)

eval_results = []
for item in tqdm(items, desc="Evaluating items"):
    result = model_wrapper.evaluate_item(item)
    eval_results.append(result)

# 1 = flip, 0 = non-flip
y_all = np.array([int(r.is_flip) for r in eval_results])

n_flips = y_all.sum()
n_total = len(y_all)

# the percentages divide by n_total; the data loading cell raises when `items` is empty
logger.info(f"  Labels computed:")
logger.info(f"  Total items: {n_total}")
logger.info(f"  Flips (y=1): {n_flips} ({100*n_flips/n_total:.1f}%)")
logger.info(f"  Non-flips (y=0): {n_total - n_flips} ({100*(n_total-n_flips)/n_total:.1f}%)")

# warn only; the run continues even with a tiny class
if n_flips < 10 or (n_total - n_flips) < 10:
    logger.warning("Very few samples in one class")

In [None]:
# train-test split
# Split item *indices* (stratified on the flip label) and use the same index
# arrays to slice items, labels and evaluation results, so the three stay
# aligned within each split.
logger.info("\n" + "="*70)
logger.info("TRAIN/TEST SPLIT (STRATIFIED)")
logger.info("="*70)

train_idx, test_idx = train_test_split(
    np.arange(len(items)),
    test_size=CONFIG['data']['test_size'],
    stratify=y_all,
    random_state=CONFIG['data']['random_state']
)

train_items = [items[i] for i in train_idx]
test_items = [items[i] for i in test_idx]

y_train = y_all[train_idx]
y_test = y_all[test_idx]

# reuse the evaluation results computed for labelling instead of re-running the model
train_results = [eval_results[i] for i in train_idx]
test_results = [eval_results[i] for i in test_idx]

logger.info(f"  Data split:")
logger.info(f"  Train: {len(train_items)} items ({y_train.sum()} flips, {100*y_train.mean():.1f}%)")
logger.info(f"  Test:  {len(test_items)} items ({y_test.sum()} flips, {100*y_test.mean():.1f}%)")

# Sanity check
# NOTE(review): `assert` is skipped when Python runs with -O
assert len(set(train_idx) & set(test_idx)) == 0, "Train/test overlap detected"
logger.info("  No train/test overlap confirmed")

# Confidence baseline: score each TEST item by the value answer_logprob
# returns for the already generated incorrect-context answer, then measure how
# well that single number separates flips from non-flips.
logger.info("\n" + "="*70)
logger.info("CONFIDENCE-BASED BASELINE (INCORRECT CONTEXT, TEST SET)")
logger.info("="*70)

conf_scores = []
y_conf = []

for item, res in zip(test_items, test_results):
    prompt_incorrect = model_wrapper.build_prompt(item.question, context=item.incorrect_ctx)
    # reuse the already generated answer_incorrect from res
    answer_incorrect = res.answer_incorrect

    conf = model_wrapper.answer_logprob(prompt_incorrect, answer_incorrect)
    conf_scores.append(conf)
    y_conf.append(int(res.is_flip))

conf_scores = np.array(conf_scores, dtype=float)
y_conf = np.array(y_conf, dtype=int)

# flips should have lower confidence
# hence the sign flip: the negated confidence is used as the "flip" score
try:
    auroc_conf = roc_auc_score(y_conf, -conf_scores)
except Exception as e:
    logger.warning(f"could not compute AUROC for confidence baseline: {e}")
    auroc_conf = np.nan

# a NaN value is formatted as "nan" by the .3f format spec
logger.info(f"  confidence baseline AUROC (flip vs non-flip, TEST): {auroc_conf:.3f}")

In [None]:
# baseline statistics (TRAIN only)
# Accumulate activation statistics over context-free prompts built from the
# first `n_baseline` TRAIN items. Test items are never passed to the
# accumulator. The result (`baseline_stats`) is handed to extract_pit_values
# in the activation extraction cell.
logger.info("\n" + "="*70)
logger.info("COMPUTING BASELINE STATISTICS (TRAIN SET ONLY)")
logger.info("="*70)

accumulator = WelfordAccumulator(n_layers, hidden_dim)

# never request more baseline items than the train split contains
n_baseline = min(CONFIG['data']['n_baseline_items'], len(train_items))
logger.info(f"Using {n_baseline} training items for baseline")

for item in tqdm(train_items[:n_baseline], desc="Baseline"):
    prompt = model_wrapper.build_prompt(item.question, context="")
    _, last_vec = model_wrapper.encode_and_extract_hidden_batch([prompt])
    # Note: We only use last token for baseline, not first
    # the first argument is an all-zero placeholder with the same shape as last_vec[0]
    # (presumably the slot for first-token activations — confirm in WelfordAccumulator.update)
    accumulator.update(np.zeros_like(last_vec[0]), last_vec[0])

baseline_stats = accumulator.get_statistics()
logger.info("  Baseline statistics computed from TRAIN set only")


In [None]:
# activation extraction
# Run extract_pit_values once per prompt condition ("base", "correct",
# "incorrect") for the train split and then for the test split, always with
# the TRAIN-derived baseline statistics.
logger.info("\n" + "="*70)
logger.info("EXTRACTING ACTIVATIONS")
logger.info("="*70)

batch_size = CONFIG['processing']['batch_size']

# Train set (conditions are processed in base, correct, incorrect order)
logger.info("Processing TRAIN set...")
U_train_base, U_train_correct, U_train_incorrect = [
    extract_pit_values(train_items, condition, model_wrapper, baseline_stats, batch_size)
    for condition in ("base", "correct", "incorrect")
]

# Test set (same order)
logger.info("Processing TEST set...")
U_test_base, U_test_correct, U_test_incorrect = [
    extract_pit_values(test_items, condition, model_wrapper, baseline_stats, batch_size)
    for condition in ("base", "correct", "incorrect")
]

logger.info("  Activations extracted:")
logger.info(f"  Train shape: {U_train_base.shape}")
logger.info(f"  Test shape:  {U_test_base.shape}")

In [None]:
# NEURON-LEVEL EXPERIMENT
# Pick the most discriminative neurons using TRAIN data only, restricted to the
# last LAST_K_NEURON layers of the incorrect-context values.
logger.info("\n" + "="*70)
logger.info("NEURON-LEVEL FEATURES (INCORRECT CONTEXT ONLY)")
logger.info("="*70)

# to keep it computationally reasonable, search only in the last K layers
LAST_K_NEURON = 16
# axis 0 of U_train_incorrect is treated as the layer axis here
L_total = U_train_incorrect.shape[0]
# max(0, ...) keeps the range valid when there are fewer than LAST_K_NEURON layers
layer_indices = list(range(max(0, L_total - LAST_K_NEURON), L_total))

logger.info(
    f"searching top neurons in layers {layer_indices[0]}..{layer_indices[-1]} "
    f"({len(layer_indices)} layers, hidden_dim={U_train_incorrect.shape[2]})"
)

# 1) select top-k most discriminative neurons on TRAIN set
TOP_K_NEURONS = 100  # you can lower this if it is too slow
top_neurons = select_top_k_neurons_globally(
    U_incorrect=U_train_incorrect,
    y=y_train,
    k=TOP_K_NEURONS,
    layer_indices=layer_indices,
)

logger.info(f"selected {len(top_neurons)} neurons for neuron-level probe")

# Neuron-level probe: build features from the selected neurons, fit a
# standardised logistic-regression probe on TRAIN, score TEST, and (when
# enabled) train the contrastive model on the same neuron features.
if len(top_neurons) > 0:
    # 2) build neuron-extremeness features for TRAIN and TEST
    X_train_neuron = build_neuron_features(U_train_incorrect, top_neurons)
    X_test_neuron = build_neuron_features(U_test_incorrect, top_neurons)

    logger.info(f"neuron-level feature shapes: train={X_train_neuron.shape}, test={X_test_neuron.shape}")

    # 3) standardize + logistic regression
    # the scaler is fitted on TRAIN only and merely applied to TEST
    scaler_neuron = StandardScaler()
    X_train_std = scaler_neuron.fit_transform(X_train_neuron)
    X_test_std = scaler_neuron.transform(X_test_neuron)

    clf_neuron = LogisticRegression(
        penalty='l2',
        C=1.0,
        solver='liblinear',
        class_weight='balanced',
        max_iter=500,
        random_state=CONFIG['seed'],
    )

    clf_neuron.fit(X_train_std, y_train)
    # signed distance to the decision boundary, used as the flip score
    neuron_scores_test = clf_neuron.decision_function(X_test_std)

    # NOTE(review): this try also wraps the plotting calls, so a plotting failure
    # is reported with the "could not compute AUROC" message below.
    try:
        auroc_neuron = roc_auc_score(y_test, neuron_scores_test)
        ci_low_neuron, ci_high_neuron = bootstrap_auroc_ci(
            y_test,
            neuron_scores_test,
            n_resamples=CONFIG['bootstrap']['n_resamples'],
            confidence_level=CONFIG['bootstrap']['confidence_level'],
            seed=CONFIG['seed'],
        )

        logger.info("\n" + "-"*70)
        logger.info("NEURON-LEVEL PROBE RESULTS (TEST SET)")
        logger.info("-"*70)
        logger.info(f"  AUROC: {auroc_neuron:.3f} [{ci_low_neuron:.3f}, {ci_high_neuron:.3f}]")
        logger.info("-"*70)

        # optional: quick histogram / roc
        # NOTE(review): the plotted values are logistic-regression decision scores,
        # while the xlabel describes the input feature — confirm the intended label.
        plot_score_distributions(
            neuron_scores_test,
            y_test,
            title=f"Neuron-level Scores (Test Set, AUROC={auroc_neuron:.3f})",
            xlabel="Neuron extremeness (max(u, 1-u))",
            save_path=output_dir / "neuron_scores_test.png",
        )

        plot_roc_curve_with_ci(
            y_test,
            neuron_scores_test,
            title="Neuron-level ROC Curve (Test Set)",
            save_path=output_dir / "neuron_roc_test.png",
        )

    except Exception as e:
        logger.warning(f"could not compute AUROC for neuron-level probe: {e}")

    # ==============================================================
    # NEURON-LEVEL CONTRASTIVE LEARNING (using the same neuron features)
    # ==============================================================
    # requires the feature to be enabled and both classes present in TRAIN
    if CONFIG['contrastive']['enabled'] and len(np.unique(y_train)) >= 2:
        logger.info("\n" + "-"*70)
        logger.info("NEURON-LEVEL CONTRASTIVE LEARNING")
        logger.info("-"*70)

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"training neuron-level contrastive model on device: {device}")

        # re-use X_train_neuron / X_test_neuron directly
        # (unscaled: train_contrastive_model returns its own scaler)
        proj_neuron, scaler_neuron_con, proto0_neuron, proto1_neuron = train_contrastive_model(
            X_train_neuron,
            y_train,
            CONFIG['contrastive'],
            device,
        )

        # evaluate on TEST set
        neuron_contrastive_scores_test, neuron_contrastive_metrics = evaluate_contrastive_model(
            proj_neuron,
            scaler_neuron_con,
            X_test_neuron,
            y_test,
            proto0_neuron,
            proto1_neuron,
            device,
        )

        # only report, plot and analyse errors when the AUROC is a number
        if not np.isnan(neuron_contrastive_metrics['auroc']):
            logger.info(
                f"  neuron-level contrastive AUROC: "
                f"{neuron_contrastive_metrics['auroc']:.4f} "
                f"[{neuron_contrastive_metrics['auroc_ci_low']:.4f}, "
                f"{neuron_contrastive_metrics['auroc_ci_high']:.4f}]"
            )

            # visualize
            plot_score_distributions(
                neuron_contrastive_scores_test,
                y_test,
                title=(
                    "Neuron-level Contrastive Scores "
                    f"(Test Set, AUROC={neuron_contrastive_metrics['auroc']:.3f})"
                ),
                xlabel="Prototype score (cosine1 − cosine0)",
                save_path=output_dir / "neuron_contrastive_scores_test.png",
            )

            plot_roc_curve_with_ci(
                y_test,
                neuron_contrastive_scores_test,
                title="Neuron-level Contrastive ROC Curve (Test Set)",
                save_path=output_dir / "neuron_contrastive_roc_test.png",
            )

            # simple 0-threshold classification for error analysis
            neuron_contrastive_pred = (neuron_contrastive_scores_test > 0).astype(int)
            neuron_contrastive_errors = analyze_misclassifications(
                test_items, y_test, neuron_contrastive_pred, neuron_contrastive_scores_test
            )
            logger.info("\nERROR ANALYSIS (Neuron-level Contrastive, TEST set)")
            logger.info(
                f"  false positives: {neuron_contrastive_errors['false_positives']['count']}  "
                f"domains: {neuron_contrastive_errors['false_positives']['domains']}"
            )
            logger.info(
                f"  false negatives: {neuron_contrastive_errors['false_negatives']['count']}  "
                f"domains: {neuron_contrastive_errors['false_negatives']['domains']}"
            )
        else:
            logger.warning("  neuron-level contrastive AUROC is NaN; skipping plots")

else:
    logger.warning("no neurons selected, skipping neuron-level probe")


In [None]:
# visualization (TEST set)
# QQ plots and an overlaid histogram of the last-layer PIT values on TEST.
logger.info("\n" + "="*70)
logger.info("GENERATING VISUALIZATIONS (TEST SET)")
logger.info("="*70)

# QQ plots for last layer
layer_idx = -1
for _pit_values, _qq_title, _qq_file in (
    (U_test_base, "QQ Plot: Base Condition (Test Set, Last Layer)", "qq_base_test.png"),
    (U_test_incorrect, "QQ Plot: Incorrect Context (Test Set, Last Layer)", "qq_incorrect_test.png"),
):
    plot_qq_uniform(_pit_values[layer_idx].ravel(), _qq_title, output_dir / _qq_file)

# Histogram comparison: one translucent histogram per prompt condition
plt.figure(figsize=(7, 4))
for _pit_values, _condition, _colour in (
    (U_test_base, "Base", 'C0'),
    (U_test_correct, "Correct", 'C2'),
    (U_test_incorrect, "Incorrect", 'C3'),
):
    plt.hist(
        _pit_values[layer_idx].ravel(),
        bins=30,
        density=True,
        alpha=0.5,
        label=_condition,
        color=_colour,
    )
# reference density of a Uniform(0, 1) variable
plt.axhline(1.0, linestyle='--', color='black', alpha=0.3, label="Uniform(0,1)")
plt.xlabel("PIT value")
plt.ylabel("Density")
plt.title("PIT Distributions (Test Set, Last Layer)")
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig(output_dir / "pit_histogram_test.png", dpi=150, bbox_inches='tight')
plt.show()
plt.close()

In [None]:
# basic features and baseline classifier
# Compute the simple per-layer features for both splits, report the per-layer
# TEST AUROC of every feature, and plot/summarise the result.
logger.info("\n" + "="*70)
logger.info("BASELINE CLASSIFIER (SIMPLE FEATURES)")
logger.info("="*70)

# Compute features
train_features = compute_basic_features(U_train_base, U_train_correct, U_train_incorrect)
test_features = compute_basic_features(U_test_base, U_test_correct, U_test_incorrect)

# Layer-wise AUROC on test set
logger.info("\nPer-layer performance (TEST SET):")
auroc_results_test = {}

# AUROC needs both classes in y_test; this does not change per layer, so check once
test_has_both_classes = len(np.unique(y_test)) >= 2

for feat_name in train_features:
    # rows are layers, columns are test items
    test_feat_matrix = test_features[feat_name]

    n_layers_feat = test_feat_matrix.shape[0]
    aurocs = np.full(n_layers_feat, np.nan)

    if test_has_both_classes:
        for i, layer_scores in enumerate(test_feat_matrix):
            try:
                aurocs[i] = roc_auc_score(y_test, layer_scores)
            except ValueError as e:
                # invalid scores for this layer (e.g. NaN): leave the entry as NaN,
                # but do not swallow unrelated errors or KeyboardInterrupt
                logger.debug(f"  {feat_name} layer {i+1}: AUROC unavailable ({e})")

    auroc_results_test[feat_name] = aurocs

    # np.nanargmax raises on an all-NaN array; report that case instead of crashing
    if np.isnan(aurocs).all():
        logger.warning(f"  {feat_name:20s}: no valid per-layer AUROC")
        continue

    best_layer = int(np.nanargmax(aurocs))
    best_auroc = aurocs[best_layer]
    # layers are reported 1-based
    logger.info(f"  {feat_name:20s}: Best = Layer {best_layer+1:02d}, AUROC = {best_auroc:.3f}")

# Plot AUROC by layer
plot_auroc_by_layer(auroc_results_test, output_dir / "auroc_by_layer_test.png")

layer_summaries = summarize_layer_importance(auroc_results_test)
logger.info("\nLAYER-WISE INTERPRETATION (TEST SET):")
for line in layer_summaries:
    logger.info("  " + line)

# Aggregate last layers and evaluate
band = CONFIG['analysis']['last_layers_band']

def select_top_layers_by_train_auroc(
    train_feat_matrix: np.ndarray,
    y_train: np.ndarray,
    k: int = 8
) -> List[int]:
    """
    select top-k layers based on AUROC computed on the TRAIN set only.

    train_feat_matrix: shape (n_layers, n_items)
    y_train:           shape (n_items,)
    k:                 number of layers to select; must be >= 1
    returns: sorted list of layer indices (0-based)

    Falls back to the last k layers when y_train holds a single class or when
    no layer yields a valid AUROC (constant layers are skipped).

    raises: ValueError if k < 1
    """
    # a[-0:] is the whole array, so k=0 used to select *every* valid layer while
    # the fallback paths selected none; reject non-positive k explicitly
    if k < 1:
        raise ValueError(f"k must be a positive integer, got {k}")

    n_layers = train_feat_matrix.shape[0]
    last_k_layers = list(range(max(0, n_layers - k), n_layers))

    # compute AUROC per layer on TRAIN
    if len(np.unique(y_train)) < 2:
        logger.warning("cannot select layers by AUROC: only one class in y_train")
        # fallback: last k layers
        return last_k_layers

    aurocs = np.full(n_layers, np.nan, dtype=float)
    for i, layer_scores in enumerate(train_feat_matrix):
        # skip constant layers
        if np.allclose(layer_scores, layer_scores[0]):
            continue
        try:
            aurocs[i] = roc_auc_score(y_train, layer_scores)
        except Exception as e:
            # best effort: keep NaN if AUROC fails, but leave a trace
            logger.debug(f"AUROC failed for layer {i+1}: {e}")

    valid_idx = np.where(~np.isnan(aurocs))[0]

    if len(valid_idx) == 0:
        logger.warning("no valid AUROCs for layer selection; using last k layers as fallback")
        return last_k_layers

    # pick top-k layers among valid indices (argsort is ascending, so take the tail)
    top_idx = valid_idx[np.argsort(aurocs[valid_idx])[-k:]]
    top_idx = sorted(top_idx.tolist())

    logger.info("Selected layers for ΔΔ MAD aggregation (by TRAIN AUROC):")
    logger.info(
        "  " + ", ".join(f"L{idx+1}={aurocs[idx]:.3f}" for idx in top_idx)
    )

    return top_idx

# Use MAD feature aggregated over top layers chosen on TRAIN
# Baseline detector: average the "deltadelta_mad" feature over layers chosen on
# TRAIN, pick a threshold on TRAIN, and evaluate scores/predictions on TEST.
train_mad_feat = train_features["deltadelta_mad"]
test_mad_feat  = test_features["deltadelta_mad"]

selected_layers = select_top_layers_by_train_auroc(
    train_mad_feat,
    y_train,
    k=band
)

# average ΔΔ MAD over the selected layers
train_scores = train_mad_feat[selected_layers, :].mean(axis=0)
test_scores  = test_mad_feat[selected_layers, :].mean(axis=0)

# Select threshold on TRAIN set
train_flip_mask = (y_train == 1)
train_neg_scores = train_scores[~train_flip_mask]

# threshold = configured percentile of the non-flip TRAIN scores;
# with no non-flip TRAIN items, fall back to the median of all TRAIN scores
if len(train_neg_scores) > 0:
    threshold = np.percentile(train_neg_scores, CONFIG['analysis']['threshold_percentile'])
else:
    threshold = np.median(train_scores)

logger.info(f"\nThreshold (selected on TRAIN): {threshold:.6f}")

# Evaluate on TEST set
# a score strictly above the threshold is predicted as a flip
test_predictions = (test_scores > threshold).astype(int)

# labels=[0, 1] asks for a 2x2 matrix even if a class is absent from y_test
cm = confusion_matrix(y_test, test_predictions, labels=[0, 1])
tn, fp, fn, tp = cm.ravel() if cm.shape == (2, 2) else (0, 0, 0, 0)

# each ratio is NaN when its denominator is zero
specificity = tn / (tn + fp) if (tn + fp) > 0 else np.nan
sensitivity = tp / (tp + fn) if (tp + fn) > 0 else np.nan
precision = tp / (tp + fp) if (tp + fp) > 0 else np.nan
f1 = 2 * precision * sensitivity / (precision + sensitivity) if (precision + sensitivity) > 0 else np.nan

# diagnostics for the aggregated TEST scores
logger.info("\n=== DEBUGGING BASELINE SCORES ===")
logger.info(f"test_scores shape: {test_scores.shape}")
logger.info(f"test_scores contains NaN: {np.isnan(test_scores).any()}")
logger.info(f"test_scores contains Inf: {np.isinf(test_scores).any()}")
logger.info(f"test_scores min/max/mean: {np.nanmin(test_scores):.6f} / {np.nanmax(test_scores):.6f} / {np.nanmean(test_scores):.6f}")
logger.info(f"test_scores unique values: {len(np.unique(test_scores))}")
logger.info(f"y_test distribution: {np.bincount(y_test)}")

# AUROC and its bootstrap CI; on failure all three values become NaN
try:
    auroc_baseline = roc_auc_score(y_test, test_scores)

    auroc_ci_low, auroc_ci_high = bootstrap_auroc_ci(
        y_test,
        test_scores,
        n_resamples=CONFIG['bootstrap']['n_resamples'],
        confidence_level=CONFIG['bootstrap']['confidence_level'],
        seed=CONFIG['seed'],
    )

except Exception as e:
    logger.warning(f"Could not compute AUROC or its CI: {e}")
    auroc_baseline = np.nan
    auroc_ci_low = np.nan
    auroc_ci_high = np.nan

logger.info("\n" + "="*70)
logger.info(f"BASELINE RESULTS (ΔΔ MAD, {band}-layer aggregate, TEST SET)")
logger.info("="*70)
logger.info(f"  AUROC:       {auroc_baseline:.3f} [{auroc_ci_low:.3f}, {auroc_ci_high:.3f}]")
logger.info(f"  Specificity: {specificity:.3f}")
logger.info(f"  Sensitivity: {sensitivity:.3f}")
logger.info(f"  Precision:   {precision:.3f}")
logger.info(f"  F1 Score:    {f1:.3f}")
logger.info("="*70)

# Plot distributions
plot_score_distributions(
    test_scores,
    y_test,
    threshold=threshold,
    title=f"Baseline Score Distributions (Test Set, AUROC={auroc_baseline:.3f})",
    xlabel="ΔΔ MAD Score",
    save_path=output_dir / "baseline_scores_test.png"
)

# ROC curve
# NOTE(review): plot_roc_curve_with_ci calls roc_auc_score without a guard, so
# this call fails in the same situations that made the AUROC above NaN.
plot_roc_curve_with_ci(
    y_test,
    test_scores,
    title=f"Baseline ROC Curve (Test Set)",
    save_path=output_dir / "baseline_roc_test.png"
)

# break the TEST errors of the thresholded baseline down by type and domain
error_info = analyze_misclassifications(test_items, y_test, test_predictions, test_scores)
logger.info("\nERROR ANALYSIS (ΔΔ MAD baseline, TEST set)")
logger.info(f"  false positives: {error_info['false_positives']['count']}  "
            f"domains: {error_info['false_positives']['domains']}")
logger.info(f"  false negatives: {error_info['false_negatives']['count']}  "
            f"domains: {error_info['false_negatives']['domains']}")

In [None]:
# Initialize variables that will be used later in results section
metrics = None
test_contrastive_scores = None
ablation_results = None

if CONFIG['contrastive']['enabled'] and len(np.unique(y_train)) >= 2:
    logger.info("\n" + "="*70)
    logger.info("CONTRASTIVE LEARNING")
    logger.info("="*70)

    # Build enhanced features
    X_train = compute_enhanced_features(
        U_train_base,
        U_train_correct,
        U_train_incorrect,
        last_k=CONFIG['contrastive']['last_k_layers']
    )

    X_test = compute_enhanced_features(
        U_test_base,
        U_test_correct,
        U_test_incorrect,
        last_k=CONFIG['contrastive']['last_k_layers']
    )

    logger.info(f"Feature matrices:")
    logger.info(f"  Train: {X_train.shape}")
    logger.info(f"  Test:  {X_test.shape}")

    # Train contrastive model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Training on device: {device}")

    proj, scaler, proto0, proto1 = train_contrastive_model(
        X_train,
        y_train,
        CONFIG['contrastive'],
        device
    )

    # Evaluate on test set
    logger.info("\nEvaluating on TEST set...")
    test_contrastive_scores, metrics = evaluate_contrastive_model(
        proj,
        scaler,
        X_test,
        y_test,
        proto0,
        proto1,
        device
    )

    # Debug: Check for NaN in scores and metrics
    logger.info("\n" + "="*70)
    logger.info("CONTRASTIVE LEARNING RESULTS (TEST SET)")
    logger.info("="*70)

    # Handle potential NaN in metrics
    if not np.isnan(metrics['auroc']):
        logger.info(f"  Prototype AUROC: {metrics['auroc']:.4f} "
                    f"[{metrics['auroc_ci_low']:.4f}, {metrics['auroc_ci_high']:.4f}]")
    else:
        logger.warning(f"  Prototype AUROC: NaN (computation failed)")
        logger.warning(f"  CI Low: {metrics['auroc_ci_low']}, CI High: {metrics['auroc_ci_high']}")

    # Improvement over baseline
    if not np.isnan(auroc_baseline) and not np.isnan(metrics['auroc']):
        improvement = metrics['auroc'] - auroc_baseline
        logger.info(f"  Improvement:     +{improvement:.4f} over baseline")
    elif np.isnan(metrics['auroc']):
        logger.warning(f"  Cannot compute improvement: contrastive AUROC is NaN")
    elif np.isnan(auroc_baseline):
        logger.warning(f"  Cannot compute improvement: baseline AUROC is NaN")

    logger.info("="*70)

    # Visualize (only if there are valid scores)
    if not np.isnan(test_contrastive_scores).all():
        auroc_display = metrics['auroc'] if not np.isnan(metrics['auroc']) else np.nan

        plot_score_distributions(
            test_contrastive_scores,
            y_test,
            title=f"Contrastive Scores (Test Set, AUROC={auroc_display:.3f})",
            xlabel="Prototype Score (cosine₁ − cosine₀)",
            save_path=output_dir / "contrastive_scores_test.png"
        )

        plot_roc_curve_with_ci(
            y_test,
            test_contrastive_scores,
            title="Contrastive Learning ROC Curve (Test Set)",
            save_path=output_dir / "contrastive_roc_test.png"
        )

        # Error analysis
        test_contrastive_predictions = (test_contrastive_scores > 0).astype(int)
        contrastive_error_info = analyze_misclassifications(
            test_items, y_test, test_contrastive_predictions, test_contrastive_scores
        )
        logger.info("\nERROR ANALYSIS (Contrastive Learning, TEST set)")
        logger.info(f"  false positives: {contrastive_error_info['false_positives']['count']}  "
                    f"domains: {contrastive_error_info['false_positives']['domains']}")
        logger.info(f"  false negatives: {contrastive_error_info['false_negatives']['count']}  "
                    f"domains: {contrastive_error_info['false_negatives']['domains']}")
    else:
        logger.warning("  Skipping visualization and error analysis - all scores are NaN")

    # ====================================================================
    # STATISTICAL SIGNIFICANCE TESTING
    # ====================================================================

    if not np.isnan(test_contrastive_scores).all() and not np.isnan(metrics['auroc']):

        logger.info("\n" + "="*70)
        logger.info("STATISTICAL SIGNIFICANCE TESTING")
        logger.info("="*70)

        def auroc_difference(baseline_scores, contrastive_scores, y_true):
            """Compute difference in AUROC between two methods."""
            try:
                auroc_baseline = roc_auc_score(y_true, baseline_scores)
                auroc_contrastive = roc_auc_score(y_true, contrastive_scores)
                return auroc_contrastive - auroc_baseline
            except:
                return np.nan

        # Bootstrap test for significance
        logger.info("Running bootstrap test for method comparison...")
        rng = np.random.default_rng(CONFIG['seed'])
        n = len(y_test)
        diff_stats = []

        for _ in tqdm(range(CONFIG['bootstrap']['n_resamples']), desc="Bootstrap significance"):
            idx = rng.choice(n, size=n, replace=True)
            y_boot = y_test[idx]
            baseline_boot = test_scores[idx]
            contrastive_boot = test_contrastive_scores[idx]

            # Skip degenerate resamples
            if len(np.unique(y_boot)) < 2:
                continue

            diff = auroc_difference(baseline_boot, contrastive_boot, y_boot)
            if not np.isnan(diff):
                diff_stats.append(diff)

        if len(diff_stats) > 0:
            diff_stats = np.array(diff_stats)

            # Compute observed difference
            observed_diff = metrics['auroc'] - auroc_baseline

            # Compute 95% CI for the difference
            ci_low = np.percentile(diff_stats, 2.5)
            ci_high = np.percentile(diff_stats, 97.5)

            # Compute p-value (two-tailed)
            p_value = np.mean(diff_stats <= 0)
            if p_value > 0.5:
                p_value = 1 - p_value
            p_value = 2 * p_value

            logger.info(f"\nBaseline AUROC:     {auroc_baseline:.4f}")
            logger.info(f"Contrastive AUROC:  {metrics['auroc']:.4f}")
            logger.info(f"\nObserved difference: {observed_diff:.4f}")
            logger.info(f"Bootstrap 95% CI:    [{ci_low:.4f}, {ci_high:.4f}]")
            logger.info(f"P-value:             {p_value:.4f}")

            if ci_low > 0:
                logger.info("\n✓ Contrastive learning is SIGNIFICANTLY better than baseline (p < 0.05)")
            else:
                logger.info("\n✗ No significant difference between methods (p >= 0.05)")

            logger.info("="*70)

            # Save to results
            significance_results = {
                "method_comparison": "Contrastive vs Baseline",
                "baseline_auroc": float(auroc_baseline),
                "contrastive_auroc": float(metrics['auroc']),
                "observed_difference": float(observed_diff),
                "ci_low": float(ci_low),
                "ci_high": float(ci_high),
                "p_value": float(p_value),
                "n_bootstrap_samples": len(diff_stats),
                "is_significant": bool(ci_low > 0)
            }

            # Will be saved later in results.json
            if 'results' not in locals():
                results = {}
            results["significance_testing"] = significance_results

        else:
            logger.warning("Could not compute significance test - no valid bootstrap samples")

    # ====================================================================
    # 11. ABLATION STUDY
    # ====================================================================

    logger.info("\n" + "="*70)
    logger.info("ABLATION STUDY (Feature Importance)")
    logger.info("="*70)

    feature_groups = {
        'ΔKS (wrong)': [0],
        'ΔΔ KS': [1],
        'ΔΔ MAD': [2],
        'ΔΔ tail': [3],
        'ΔΔ var': [4],
        'ΔΔ skew': [5],
        'ΔΔ range': [6],
    }

    LAST_K = CONFIG['contrastive']['last_k_layers']
    ablation_results = {}

    for name, indices in feature_groups.items():
        # Select columns for this feature group
        cols = []
        for idx in indices:
            cols.extend(range(idx * LAST_K, (idx + 1) * LAST_K))

        X_train_subset = X_train[:, cols]
        X_test_subset = X_test[:, cols]

        # Train simple logistic regression probe
        scaler_ab = StandardScaler()
        X_train_std = scaler_ab.fit_transform(X_train_subset)
        X_test_std = scaler_ab.transform(X_test_subset)

        clf = LogisticRegression(
            penalty='l2',
            C=1.0,
            solver='liblinear',
            class_weight='balanced',
            max_iter=500,
            random_state=CONFIG['seed']
        )

        clf.fit(X_train_std, y_train)
        scores_ab = clf.decision_function(X_test_std)

        try:
            auroc_ab = roc_auc_score(y_test, scores_ab)
        except Exception as e:
            logger.warning(f"Failed to compute AUROC for {name}: {e}")
            auroc_ab = np.nan

        ablation_results[name] = auroc_ab
        logger.info(f"  {name:20s}: AUROC = {auroc_ab:.3f}")

    # Visualize ablation results: one horizontal bar per feature group showing
    # the test AUROC of the probe trained on that group alone.
    if ablation_results:
        plt.figure(figsize=(8, 5))
        names = list(ablation_results.keys())
        scores_list = list(ablation_results.values())

        plt.barh(names, scores_list, color='steelblue', alpha=0.8)
        plt.axvline(0.5, linestyle='--', color='black', alpha=0.3, label='Chance')
        plt.xlabel('AUROC (Test Set)')
        plt.title('Ablation Study: Individual Feature Group Performance')

        # Derive the axis limits from finite scores only: a group whose AUROC
        # could not be computed is stored as NaN, and matplotlib rejects NaN
        # axis limits.
        finite_scores = [s for s in scores_list if np.isfinite(s)]
        if finite_scores:
            # Keep the usual 0.4 lower bound, but widen it when a group scores
            # below chance so its bar is not clipped (and the axis not inverted).
            x_min = max(0.0, min(0.4, min(finite_scores) - 0.05))
            # Always keep the chance line at 0.5 inside the visible range.
            x_max = max(max(finite_scores) + 0.05, 0.55)
        else:
            x_min, x_max = 0.4, 1.0
        plt.xlim(x_min, x_max)

        plt.legend(loc='lower right')  # surfaces the 'Chance' reference line label
        plt.grid(True, alpha=0.3, axis='x')
        plt.tight_layout()
        plt.savefig(output_dir / "ablation_study.png", dpi=150, bbox_inches='tight')
        plt.show()
        plt.close()

else:
    # Announce that the contrastive stage did not run, then list every
    # condition checked here that applies (both may be reported).
    _rule = "=" * 70
    for _banner_line in ("\n" + _rule, "CONTRASTIVE LEARNING SKIPPED", _rule):
        logger.warning(_banner_line)

    if not CONFIG['contrastive']['enabled']:
        logger.warning("Reason: Contrastive learning is disabled in CONFIG")

    _n_train_classes = len(np.unique(y_train))
    if _n_train_classes < 2:
        logger.warning(f"Reason: Training set has only {_n_train_classes} class(es)")


In [None]:
# save results
# Collect the run configuration, data summary and every metric that was
# computed into one dict and write it to <output_dir>/results.json.
# Sections produced by optional stages (contrastive learning, ablation,
# neuron-level probes) are included only when their variables exist.
logger.info("\n" + "="*70)
logger.info("SAVING RESULTS")
logger.info("="*70)


def _json_float(value):
    """Return ``value`` as a JSON-safe float.

    Gives None when ``value`` is None or not finite (NaN / inf): ``json.dump``
    would otherwise emit the non-standard tokens ``NaN`` / ``Infinity``, which
    strict JSON parsers reject.
    """
    if value is None:
        return None
    value = float(value)
    return value if np.isfinite(value) else None


def _json_len(obj):
    """Return ``len(obj)`` as an int, or None when ``obj`` is None."""
    return None if obj is None else int(len(obj))


# Namespace used to look up optional variables without risking NameError.
# At notebook/module level locals() is the global namespace.
_ns = locals()

results = {
    "timestamp": datetime.now().isoformat(),
    "config": CONFIG,
    "data_summary": {
        "n_total": len(items),
        "n_train": len(train_items),
        "n_test": len(test_items),
        "n_train_flips": int(y_train.sum()),
        "n_test_flips": int(y_test.sum()),
        "train_flip_rate": float(y_train.mean()),
        "test_flip_rate": float(y_test.mean()),
    },
    "model_info": {
        "name": CONFIG['model']['name'],
        "n_layers": int(n_layers),
        "hidden_dim": int(hidden_dim),
    },
    "baseline_classifier": {
        "auroc": _json_float(auroc_baseline),
        "auroc_ci_low": _json_float(auroc_ci_low),
        "auroc_ci_high": _json_float(auroc_ci_high),
        "threshold": _json_float(threshold),
        "specificity": _json_float(specificity),
        "sensitivity": _json_float(sensitivity),
        "precision": _json_float(precision),
        "f1": _json_float(f1),
    },
    "confidence_baseline": {
        "auroc": _json_float(auroc_conf),
        "description": "Answer log-probability on incorrect context"
    }
}

# Only add contrastive results if the contrastive block executed.
_metrics = _ns.get('metrics')
if _metrics is not None:
    results["contrastive_learning"] = {
        "auroc": _json_float(_metrics['auroc']),
        "auroc_ci_low": _json_float(_metrics['auroc_ci_low']),
        "auroc_ci_high": _json_float(_metrics['auroc_ci_high']),
        # NaN in either operand propagates through the subtraction and is
        # then mapped to None.
        "improvement_over_baseline": _json_float(_metrics['auroc'] - auroc_baseline),
    }

_ablation = _ns.get('ablation_results')
if _ablation is not None:
    results["ablation_study"] = {
        name: _json_float(score) for name, score in _ablation.items()
    }

# `top_neurons` is shared by both neuron-level sections; a missing value is
# reported as null rather than discarding the whole section.
_top_neurons = _ns.get('top_neurons')

# neuron-level contrastive
_neuron_metrics = _ns.get('neuron_contrastive_metrics')
if _neuron_metrics is not None:
    results["neuron_contrastive_learning"] = {
        "auroc": _json_float(_neuron_metrics['auroc']),
        "auroc_ci_low": _json_float(_neuron_metrics['auroc_ci_low']),
        "auroc_ci_high": _json_float(_neuron_metrics['auroc_ci_high']),
        "k_neurons": _json_len(_top_neurons),
    }

# neuron-level probe
if 'auroc_neuron' in _ns:
    _last_k_neuron = _ns.get('LAST_K_NEURON')
    results["neuron_level_probe"] = {
        "auroc": _json_float(_ns['auroc_neuron']),
        "auroc_ci_low": _json_float(_ns.get('ci_low_neuron')),
        "auroc_ci_high": _json_float(_ns.get('ci_high_neuron')),
        "k_neurons": _json_len(_top_neurons),
        "last_k_layers_searched": None if _last_k_neuron is None else int(_last_k_neuron),
    }

with open(output_dir / "results.json", 'w', encoding='utf-8') as f:
    json.dump(results, f, indent=2)

logger.info(f"  Results saved to: {output_dir / 'results.json'}")

In [None]:
def _log_error_rates(heading, errors):
    """Log the false-positive / false-negative rates and counts of one method."""
    logger.info(heading)
    logger.info(f"  False Positive Rate: {errors['false_positives']['rate']:.1%}")
    logger.info(f"  False Negative Rate: {errors['false_negatives']['rate']:.1%}")
    logger.info(f"  FP Count: {errors['false_positives']['count']}")
    logger.info(f"  FN Count: {errors['false_negatives']['count']}")


logger.info("\n" + "="*70)
logger.info("DETAILED ERROR ANALYSIS")
logger.info("="*70)

# Baseline predictions: binarise the baseline scores at the chosen threshold.
test_predictions = (test_scores > threshold).astype(int)

# Contrastive predictions exist only when the stage is enabled and actually
# produced scores; otherwise both contrastive inputs are reset to None.
has_contrastive = bool(
    CONFIG['contrastive']['enabled']
    and 'test_contrastive_scores' in locals()
    and test_contrastive_scores is not None
)
if has_contrastive:
    test_contrastive_predictions = (test_contrastive_scores > 0).astype(int)
else:
    test_contrastive_scores = None
    test_contrastive_predictions = None

# Run error analysis (the contrastive arguments are None when unavailable).
error_analyses = run_error_analysis_on_saved_data(
    test_items=test_items,
    test_results=test_results,
    y_test=y_test,
    test_scores_baseline=test_scores,
    test_predictions_baseline=test_predictions,
    test_scores_contrastive=test_contrastive_scores,
    test_predictions_contrastive=test_contrastive_predictions,
    output_dir=output_dir,
)

# Print summary
logger.info("\n" + "="*70)
logger.info("ERROR ANALYSIS SUMMARY")
logger.info("="*70)

baseline_errors = error_analyses['baseline']
_log_error_rates("\nBaseline Method (ΔΔ MAD):", baseline_errors)

if 'contrastive' in error_analyses:
    contrastive_errors = error_analyses['contrastive']
    _log_error_rates("\nContrastive Learning Method:", contrastive_errors)

    if 'comparison' in error_analyses:
        comp = error_analyses['comparison']
        logger.info("\nContrastive improvements over baseline:")
        logger.info(f"  Fixed {comp['fixed_fp']} false positives")
        logger.info(f"  Fixed {comp['fixed_fn']} false negatives")
        logger.info(f"  Introduced {comp['contrastive_only_fp']} new false positives")
        logger.info(f"  Introduced {comp['contrastive_only_fn']} new false negatives")

        # Errors the contrastive method removed minus the errors it added.
        net_improvement = (comp['fixed_fp'] + comp['fixed_fn']) - (comp['contrastive_only_fp'] + comp['contrastive_only_fn'])
        logger.info(f"  Net improvement: {net_improvement} fewer total errors")

logger.info(f"\nDetailed analysis saved to: {output_dir / 'error_analysis.json'}")
logger.info("="*70)

In [None]:
# Display the significance-test results.
# NOTE(review): `significance_results` is defined in an earlier cell not shown
# here; this raises NameError if that cell has not been run — confirm ordering.
print(significance_results)