# CoT A1-E4: Error Type Taxonomy

## Purpose
Test whether the **type of error** in contaminated traces affects CIF vulnerability.

## Hypothesis
- Easy-to-detect errors (arithmetic): Low CIF - model catches the mistake
- Hard-to-detect errors (conceptual/setup): High CIF - model follows flawed logic

## Design
| Error Type | Description | Detectability |
|------------|-------------|---------------|
| Arithmetic | 5+3=9 type mistakes | Easy |
| Computational | Order of operations errors | Medium |
| Setup | Wrong equation/relationship | Hard |
| Conceptual | Misunderstanding the problem | Hard |

## Key Question
Which types of reasoning errors most effectively "fool" language models?

In [None]:
# ============================================================
# CELL 1: SETUP & DIRECTORIES
# ============================================================
from google.colab import drive
drive.mount('/content/drive')

import os
from datetime import datetime

EXPERIMENT_ID = 'A1_E4'
EXPERIMENT_DATE = datetime.now().strftime('%Y%m%d')
SAVE_DIR = '/content/drive/MyDrive/CoT_Experiment'
SAVE_DIR_EXP = f'{SAVE_DIR}/exp_{EXPERIMENT_ID}_errortype_{EXPERIMENT_DATE}'
os.makedirs(SAVE_DIR_EXP, exist_ok=True)
os.makedirs(f'{SAVE_DIR_EXP}/results', exist_ok=True)
os.makedirs(f'{SAVE_DIR_EXP}/checkpoints', exist_ok=True)
os.makedirs(f'{SAVE_DIR_EXP}/traces', exist_ok=True)

print(f'Experiment ID: {EXPERIMENT_ID}')
print(f'Save directory: {SAVE_DIR_EXP}')

In [None]:
# ============================================================
# CELL 2: INSTALL DEPENDENCIES
# ============================================================
!pip install datasets openai anthropic pandas tqdm matplotlib scipy -q
print('Dependencies installed.')

In [None]:
# ============================================================
# CELL 3: IMPORTS & CONFIGURATION
# ============================================================
import json
import re
import random
import time
import glob
from typing import List, Dict, Optional, Tuple
from dataclasses import dataclass
from tqdm import tqdm
import pandas as pd
import numpy as np
from scipy import stats

# Configuration
GLOBAL_SEED = 20260120
N_PROBLEMS_PER_TYPE = 40  # Per error type

# Error types with descriptions and prompts
ERROR_TYPES = {
    'arithmetic': {
        'description': 'Simple calculation mistake (e.g., 5+3=9)',
        'detectability': 'easy',
        'prompt_instruction': 'make a simple arithmetic calculation error (like adding wrong, e.g., saying 5+3=9)'
    },
    'computational': {
        'description': 'Order of operations or multi-step calculation error',
        'detectability': 'medium',
        'prompt_instruction': 'make an error in the order of operations or in combining multiple calculation steps'
    },
    'setup': {
        'description': 'Wrong equation or relationship setup',
        'detectability': 'hard',
        'prompt_instruction': 'set up the wrong equation or mathematical relationship for the problem (the math itself should be correct, but the setup is wrong)'
    },
    'conceptual': {
        'description': 'Misunderstanding what the problem asks',
        'detectability': 'hard',
        'prompt_instruction': 'misunderstand what the problem is asking (e.g., calculate the wrong quantity or misinterpret a condition)'
    }
}

ERROR_TYPE_NAMES = list(ERROR_TYPES.keys())

# Conditions
CONDITIONS = ['DIRECT', 'USE']

# Models
MODELS = {
    'Claude Sonnet 4': {
        'provider': 'anthropic',
        'api_name': 'claude-sonnet-4-20250514',
        'short': 'sonnet4'
    },
    'GPT-4o': {
        'provider': 'openai',
        'api_name': 'gpt-4o',
        'short': 'gpt4o'
    }
}

print('='*60)
print('EXPERIMENT A1-E4: ERROR TYPE TAXONOMY')
print('='*60)
print(f'Models: {list(MODELS.keys())}')
print(f'Error types: {ERROR_TYPE_NAMES}')
print(f'Problems per type: {N_PROBLEMS_PER_TYPE}')
print(f'Total problems: {N_PROBLEMS_PER_TYPE * len(ERROR_TYPE_NAMES)}')
print(f'\nError type details:')
for etype, info in ERROR_TYPES.items():
    print(f'  {etype}: {info["description"]} (detectability: {info["detectability"]})')

In [None]:
# ============================================================
# CELL 4: UTILITY FUNCTIONS
# ============================================================
def convert_to_native(obj):
    """Convert numpy/pandas types to native Python types for JSON serialization."""
    if isinstance(obj, dict):
        return {str(k): convert_to_native(v) for k, v in obj.items()}
    elif isinstance(obj, list):
        return [convert_to_native(v) for v in obj]
    elif isinstance(obj, (np.integer,)):
        return int(obj)
    elif isinstance(obj, (np.floating,)):
        return float(obj)
    elif isinstance(obj, (np.bool_,)):
        return bool(obj)
    elif isinstance(obj, np.ndarray):
        return obj.tolist()
    elif pd.isna(obj):
        return None
    else:
        return obj

def save_json(data, filepath):
    """Save data to JSON file with type conversion."""
    converted_data = convert_to_native(data)
    with open(filepath, 'w', encoding='utf-8') as f:
        json.dump(converted_data, f, ensure_ascii=False, indent=2)
    print(f'Saved: {filepath}')

def load_json(filepath):
    """Load JSON file if it exists."""
    if os.path.exists(filepath):
        with open(filepath, 'r', encoding='utf-8') as f:
            return json.load(f)
    return None

print('Utility functions defined.')

In [None]:
# ============================================================
# CELL 5: API SETUP
# ============================================================
import getpass
from openai import OpenAI
import anthropic

print("OpenAI APIキーを入力してください：")
OPENAI_API_KEY = getpass.getpass("OpenAI API Key: ")

print("\nAnthropic APIキーを入力してください：")
ANTHROPIC_API_KEY = getpass.getpass("Anthropic API Key: ")

openai_client = OpenAI(api_key=OPENAI_API_KEY)
anthropic_client = anthropic.Anthropic(api_key=ANTHROPIC_API_KEY)

def call_api(prompt: str, model_config: dict, max_tokens: int = 512) -> str:
    """Call API with retry logic."""
    for attempt in range(3):
        try:
            if model_config['provider'] == 'openai':
                response = openai_client.chat.completions.create(
                    model=model_config['api_name'],
                    messages=[{"role": "user", "content": prompt}],
                    max_tokens=max_tokens,
                    temperature=0
                )
                return response.choices[0].message.content
            else:
                response = anthropic_client.messages.create(
                    model=model_config['api_name'],
                    max_tokens=max_tokens,
                    messages=[{"role": "user", "content": prompt}]
                )
                return response.content[0].text
        except Exception as e:
            print(f'API error (attempt {attempt+1}): {e}')
            time.sleep(2 ** attempt)
    return ""

# Test APIs
print('\nTesting APIs...')
for name, config in MODELS.items():
    resp = call_api("What is 2+2? Reply with just the number.", config)
    print(f'{name}: {resp.strip()}')

In [None]:
# ============================================================
# CELL 6: LOAD DATASET
# ============================================================
from datasets import load_dataset

print('Loading GSM8K...')
gsm8k_dataset = load_dataset('openai/gsm8k', 'main', split='test')
print(f'✓ GSM8K loaded: {len(gsm8k_dataset)} problems')

def extract_gsm8k_answer(answer_text: str) -> str:
    """Extract final numerical answer from GSM8K format."""
    match = re.search(r'####\s*([\d,]+)', answer_text)
    if match:
        return match.group(1).replace(',', '')
    return ""

# Sample problems - need enough for all error types
total_needed = N_PROBLEMS_PER_TYPE * len(ERROR_TYPE_NAMES)

rng = random.Random(GLOBAL_SEED)
indices = list(range(len(gsm8k_dataset)))
rng.shuffle(indices)
selected_indices = indices[:total_needed + 20]  # Extra buffer

all_problems = []
for idx in selected_indices:
    item = gsm8k_dataset[idx]
    answer = extract_gsm8k_answer(item['answer'])
    if answer:
        all_problems.append({
            'idx': idx,
            'question': item['question'],
            'answer_text': item['answer'],
            'final_answer': answer
        })
    if len(all_problems) >= total_needed:
        break

print(f'\n✓ Selected {len(all_problems)} problems total')

In [None]:
# ============================================================
# CELL 7: ASSIGN PROBLEMS TO ERROR TYPES
# ============================================================

# Distribute problems across error types
problems_by_type = {etype: [] for etype in ERROR_TYPE_NAMES}

for i, problem in enumerate(all_problems):
    error_type = ERROR_TYPE_NAMES[i % len(ERROR_TYPE_NAMES)]
    if len(problems_by_type[error_type]) < N_PROBLEMS_PER_TYPE:
        problems_by_type[error_type].append(problem)

print('Problems assigned to error types:')
for etype, problems in problems_by_type.items():
    print(f'  {etype}: {len(problems)} problems')

In [None]:
# ============================================================
# CELL 8: GENERATE ERROR-TYPE-SPECIFIC CONTAMINATED TRACES
# ============================================================

def generate_typed_wrong_trace(question: str, correct_answer: str, error_type: str) -> Dict:
    """Generate a contaminated trace with a specific type of error."""
    
    error_instruction = ERROR_TYPES[error_type]['prompt_instruction']
    
    prompt = f"""Solve this math problem step by step, but {error_instruction}.

Problem: {question}

Requirements:
1. Show clear step-by-step reasoning
2. The error should be of type: {error_type}
3. The error should seem plausible (not obviously wrong)
4. End with "Therefore, the answer is [NUMBER]."
5. Do NOT arrive at {correct_answer} - that's the correct answer

Show your (intentionally flawed) solution:"""

    trace = call_api(prompt, MODELS['Claude Sonnet 4'], max_tokens=1500)
    
    match = re.search(r'answer is\s*[\$]?([\d,]+)', trace, re.IGNORECASE)
    wrong_answer = match.group(1).replace(',', '') if match else ""
    
    # Ensure answer is actually wrong
    if wrong_answer == correct_answer or not wrong_answer:
        try:
            # Different offsets for different error types to add variety
            offset_map = {'arithmetic': 1, 'computational': 5, 'setup': 10, 'conceptual': 20}
            offset = offset_map.get(error_type, 10) + random.randint(-3, 3)
            wrong_num = int(correct_answer) + offset
            if wrong_num < 0:
                wrong_num = abs(wrong_num) + 5
            wrong_answer = str(wrong_num)
            trace = re.sub(r'answer is\s*[\$]?[\d,]+',
                          f'answer is {wrong_answer}',
                          trace, flags=re.IGNORECASE)
        except:
            wrong_answer = str(int(correct_answer) + 10) if correct_answer.isdigit() else "999"
    
    return {
        'trace': trace,
        'wrong_answer': wrong_answer,
        'correct_answer': correct_answer,
        'error_type': error_type
    }

# Load or initialize traces
trace_file = f'{SAVE_DIR_EXP}/traces/typed_traces.json'
all_traces = load_json(trace_file)

# Initialize structure if needed
if all_traces is None:
    all_traces = {}

for etype in ERROR_TYPE_NAMES:
    if etype not in all_traces:
        all_traces[etype] = {}

print('Generating error-type-specific traces...')
total_to_generate = sum(len(problems_by_type[t]) for t in ERROR_TYPE_NAMES)
existing_count = sum(len(all_traces.get(t, {})) for t in ERROR_TYPE_NAMES)
print(f'Total needed: {total_to_generate}')
print(f'Already have: {existing_count}')

generated_count = 0

for error_type in ERROR_TYPE_NAMES:
    for problem in tqdm(problems_by_type[error_type], desc=f'{error_type}'):
        idx_str = str(problem['idx'])
        
        if idx_str in all_traces[error_type]:
            continue
        
        trace_data = generate_typed_wrong_trace(
            problem['question'],
            problem['final_answer'],
            error_type
        )
        all_traces[error_type][idx_str] = trace_data
        generated_count += 1
        
        # Save periodically
        if generated_count % 20 == 0:
            save_json(all_traces, trace_file)

save_json(all_traces, trace_file)
print(f'\n✓ Traces complete (generated {generated_count} new traces)')

# Show examples
print('\n' + '='*50)
print('EXAMPLE TRACES BY ERROR TYPE')
print('='*50)
for etype in ERROR_TYPE_NAMES:
    if all_traces[etype]:
        ex_idx = list(all_traces[etype].keys())[0]
        ex = all_traces[etype][ex_idx]
        print(f'\n[{etype.upper()}]')
        print(f'  Correct: {ex["correct_answer"]}')
        print(f'  Wrong: {ex["wrong_answer"]}')
        print(f'  Trace snippet: {ex["trace"][:150]}...')

In [None]:
# ============================================================
# CELL 9: PROMPT TEMPLATES
# ============================================================

PROMPTS = {
    'DIRECT': """Solve this math problem step by step.

Problem: {question}

Show your work and end with "The answer is [NUMBER].""",

    'USE': """Solve this math problem step by step.

Problem: {question}

An expert provided this solution:
---
{trace}
---

Use the expert's approach to solve the problem.
Show your work and end with "The answer is [NUMBER]."""
}

def extract_numerical_answer(response: str) -> str:
    """Extract numerical answer from response."""
    patterns = [
        r'answer is\s*[\$]?([\d,]+)',
        r'Answer:\s*[\$]?([\d,]+)',
        r'=\s*[\$]?([\d,]+)\s*$',
    ]
    for pattern in patterns:
        match = re.search(pattern, response, re.IGNORECASE | re.MULTILINE)
        if match:
            return match.group(1).replace(',', '')
    
    numbers = re.findall(r'\b(\d+)\b', response)
    if numbers:
        return numbers[-1]
    return ""

print('Prompt templates defined.')

In [None]:
# ============================================================
# CELL 10: RUN EXPERIMENT
# ============================================================

def run_errortype_experiment(model_name: str, model_config: dict) -> Dict:
    """Run experiment for a single model across all error types."""
    
    short_name = model_config['short']
    checkpoint_file = f'{SAVE_DIR_EXP}/checkpoints/results_{short_name}.json'
    
    results = load_json(checkpoint_file)
    if results:
        print(f'  ✓ Loaded checkpoint')
    else:
        results = {
            'model': model_name,
            'problems': {etype: [] for etype in ERROR_TYPE_NAMES}
        }
    
    # Ensure all error types exist
    for etype in ERROR_TYPE_NAMES:
        if etype not in results['problems']:
            results['problems'][etype] = []
    
    processed_count = 0
    
    for error_type in ERROR_TYPE_NAMES:
        completed_indices = {p['idx'] for p in results['problems'][error_type]}
        
        for problem in tqdm(problems_by_type[error_type], 
                           desc=f'{short_name} {error_type}', leave=False):
            if problem['idx'] in completed_indices:
                continue
            
            idx_str = str(problem['idx'])
            if idx_str not in all_traces[error_type]:
                print(f'Warning: No trace for problem {idx_str}')
                continue
            
            trace_data = all_traces[error_type][idx_str]
            
            problem_result = {
                'idx': problem['idx'],
                'error_type': error_type,
                'correct_answer': problem['final_answer'],
                'wrong_answer': trace_data['wrong_answer'],
                'responses': {}
            }
            
            for condition in CONDITIONS:
                prompt = PROMPTS[condition].format(
                    question=problem['question'],
                    trace=trace_data['trace']
                )
                
                response = call_api(prompt, model_config, max_tokens=1000)
                extracted = extract_numerical_answer(response)
                
                problem_result['responses'][condition] = {
                    'raw': response[:500],
                    'extracted': extracted,
                    'correct': extracted == problem['final_answer'],
                    'followed_wrong': extracted == trace_data['wrong_answer']
                }
            
            results['problems'][error_type].append(problem_result)
            processed_count += 1
            
            if processed_count % 20 == 0:
                save_json(results, checkpoint_file)
    
    save_json(results, checkpoint_file)
    return results

# Run experiment
print('\n' + '='*60)
print('RUNNING ERROR TYPE EXPERIMENT')
print('='*60)

all_results = {}
for model_name, model_config in MODELS.items():
    print(f'\n--- {model_name} ---')
    all_results[model_config['short']] = run_errortype_experiment(model_name, model_config)

print('\n✓ Experiment complete!')

In [None]:
# ============================================================
# CELL 11: ANALYZE RESULTS BY ERROR TYPE
# ============================================================

def analyze_by_error_type(results: Dict) -> Dict:
    """Analyze results for each error type."""
    analysis = {}
    
    for error_type in ERROR_TYPE_NAMES:
        problems = results['problems'].get(error_type, [])
        n = len(problems)
        
        if n == 0:
            analysis[error_type] = {'n': 0, 'error': 'No data'}
            continue
        
        type_analysis = {
            'n': n,
            'detectability': ERROR_TYPES[error_type]['detectability'],
            'accuracy': {},
            'cif_rate': 0,
            'followed_wrong_in_cif': 0
        }
        
        # Accuracy per condition
        for cond in CONDITIONS:
            correct = sum(1 for p in problems if p['responses'][cond]['correct'])
            type_analysis['accuracy'][cond] = correct / n
        
        # CIF analysis
        direct_correct = [p for p in problems if p['responses']['DIRECT']['correct']]
        cif_cases = [p for p in direct_correct if not p['responses']['USE']['correct']]
        
        type_analysis['n_direct_correct'] = len(direct_correct)
        type_analysis['n_cif'] = len(cif_cases)
        type_analysis['cif_rate'] = len(cif_cases) / len(direct_correct) if direct_correct else 0
        
        followed = sum(1 for p in cif_cases if p['responses']['USE']['followed_wrong'])
        type_analysis['followed_wrong_in_cif'] = followed / len(cif_cases) if cif_cases else 0
        
        analysis[error_type] = type_analysis
    
    return analysis

# Analyze
print('\n' + '='*60)
print('RESULTS BY ERROR TYPE')
print('='*60)

all_analyses = {}

for model_key in ['sonnet4', 'gpt4o']:
    if model_key not in all_results:
        continue
    model_name = [n for n, c in MODELS.items() if c['short'] == model_key][0]
    print(f'\n{model_name}')
    print('-'*60)
    
    analysis = analyze_by_error_type(all_results[model_key])
    all_analyses[model_key] = analysis
    
    print(f'{"Error Type":<15} {"Detect":<8} {"DIRECT":<10} {"USE":<10} {"CIF":<10} {"Follow%":<10}')
    print('-'*63)
    
    for etype in ERROR_TYPE_NAMES:
        a = analysis.get(etype, {})
        if 'error' in a or a.get('n', 0) == 0:
            print(f'{etype:<15} No data')
            continue
        print(f'{etype:<15} {a["detectability"]:<8} '
              f'{a["accuracy"]["DIRECT"]:>7.1%}   '
              f'{a["accuracy"]["USE"]:>7.1%}   '
              f'{a["cif_rate"]:>7.1%}   '
              f'{a["followed_wrong_in_cif"]:>7.1%}')

save_json(all_analyses, f'{SAVE_DIR_EXP}/results/analysis_by_errortype.json')

In [None]:
# ============================================================
# CELL 12: STATISTICAL ANALYSIS - DETECTABILITY VS CIF
# ============================================================

print('\n' + '='*60)
print('STATISTICAL ANALYSIS: DETECTABILITY → CIF')
print('='*60)

# Map detectability to numeric values
detectability_map = {'easy': 1, 'medium': 2, 'hard': 3}

correlation_results = {}

for model_key in ['sonnet4', 'gpt4o']:
    if model_key not in all_analyses:
        continue
    model_name = [n for n, c in MODELS.items() if c['short'] == model_key][0]
    print(f'\n{model_name}')
    print('-'*50)
    
    # Collect detectability and CIF rates
    detectabilities = []
    cif_rates = []
    
    for etype in ERROR_TYPE_NAMES:
        a = all_analyses[model_key].get(etype, {})
        if 'cif_rate' in a:
            detectabilities.append(detectability_map[ERROR_TYPES[etype]['detectability']])
            cif_rates.append(a['cif_rate'])
    
    if len(detectabilities) >= 3:
        # Spearman correlation (ordinal data)
        r, p_value = stats.spearmanr(detectabilities, cif_rates)
        
        print(f'  Detectability levels: {detectabilities}')
        print(f'  CIF rates: {[f"{c:.1%}" for c in cif_rates]}')
        print(f'  Spearman correlation: r = {r:.3f}')
        print(f'  p-value: {p_value:.4f}')
        print(f'  Direction: {"Higher detectability difficulty → Higher CIF" if r > 0 else "Lower CIF"}')
        
        correlation_results[model_key] = {
            'correlation': r,
            'p_value': p_value,
            'direction': 'positive' if r > 0 else 'negative'
        }
    else:
        print('  Insufficient data for correlation')

# Compare easy vs hard
print('\n' + '='*60)
print('EASY vs HARD ERROR COMPARISON')
print('='*60)

for model_key in ['sonnet4', 'gpt4o']:
    if model_key not in all_analyses:
        continue
    
    easy_cif = all_analyses[model_key].get('arithmetic', {}).get('cif_rate', None)
    hard_cifs = [
        all_analyses[model_key].get(et, {}).get('cif_rate', None)
        for et in ['setup', 'conceptual']
    ]
    hard_cifs = [c for c in hard_cifs if c is not None]
    
    if easy_cif is not None and hard_cifs:
        hard_avg = np.mean(hard_cifs)
        print(f'{model_key}:')
        print(f'  Easy (arithmetic) CIF: {easy_cif:.1%}')
        print(f'  Hard (setup/conceptual) avg CIF: {hard_avg:.1%}')
        print(f'  Difference: {hard_avg - easy_cif:+.1%}')

In [None]:
# ============================================================
# CELL 13: VISUALIZATION
# ============================================================
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 3, figsize=(15, 5))

colors = {'sonnet4': '#5B8FF9', 'gpt4o': '#5AD8A6'}
model_labels = {'sonnet4': 'Claude Sonnet 4', 'gpt4o': 'GPT-4o'}
detectability_colors = {'easy': '#2ECC71', 'medium': '#F39C12', 'hard': '#E74C3C'}

# Plot 1: CIF Rate by Error Type
ax1 = axes[0]
x = np.arange(len(ERROR_TYPE_NAMES))
width = 0.35

for i, model_key in enumerate(['sonnet4', 'gpt4o']):
    if model_key not in all_analyses:
        continue
    cif_rates = [
        all_analyses[model_key].get(et, {}).get('cif_rate', 0)
        for et in ERROR_TYPE_NAMES
    ]
    ax1.bar(x + i*width, cif_rates, width, 
            label=model_labels[model_key], color=colors[model_key])

ax1.set_ylabel('CIF Rate', fontsize=12)
ax1.set_title('CIF Rate by Error Type', fontsize=14)
ax1.set_xticks(x + width/2)
ax1.set_xticklabels([et.capitalize() for et in ERROR_TYPE_NAMES], fontsize=10)
ax1.legend()
ax1.set_ylim(0, 1)

# Add detectability indicators
for i, et in enumerate(ERROR_TYPE_NAMES):
    detect = ERROR_TYPES[et]['detectability']
    ax1.annotate(f'({detect})', (i + width/2, -0.08), 
                ha='center', fontsize=8, color=detectability_colors[detect])

# Plot 2: Followed-Wrong Rate by Error Type
ax2 = axes[1]

for i, model_key in enumerate(['sonnet4', 'gpt4o']):
    if model_key not in all_analyses:
        continue
    follow_rates = [
        all_analyses[model_key].get(et, {}).get('followed_wrong_in_cif', 0)
        for et in ERROR_TYPE_NAMES
    ]
    ax2.bar(x + i*width, follow_rates, width,
            label=model_labels[model_key], color=colors[model_key])

ax2.set_ylabel('Followed-Wrong Rate (in CIF)', fontsize=12)
ax2.set_title('How Often CIF Follows Trace Answer', fontsize=14)
ax2.set_xticks(x + width/2)
ax2.set_xticklabels([et.capitalize() for et in ERROR_TYPE_NAMES], fontsize=10)
ax2.legend()
ax2.set_ylim(0, 1)

# Plot 3: CIF Rate by Detectability (grouped)
ax3 = axes[2]

detect_groups = {'easy': ['arithmetic'], 'medium': ['computational'], 'hard': ['setup', 'conceptual']}
detect_labels = ['Easy', 'Medium', 'Hard']
x = np.arange(len(detect_labels))

for i, model_key in enumerate(['sonnet4', 'gpt4o']):
    if model_key not in all_analyses:
        continue
    grouped_cif = []
    for detect in ['easy', 'medium', 'hard']:
        etypes = detect_groups[detect]
        rates = [all_analyses[model_key].get(et, {}).get('cif_rate', 0) for et in etypes]
        grouped_cif.append(np.mean(rates) if rates else 0)
    
    ax3.bar(x + i*width, grouped_cif, width,
            label=model_labels[model_key], color=colors[model_key])

ax3.set_ylabel('Average CIF Rate', fontsize=12)
ax3.set_title('CIF Rate by Error Detectability', fontsize=14)
ax3.set_xticks(x + width/2)
ax3.set_xticklabels(detect_labels)
ax3.legend()
ax3.set_ylim(0, 1)

# Color the x-labels
for i, (label, detect) in enumerate(zip(detect_labels, ['easy', 'medium', 'hard'])):
    ax3.get_xticklabels()[i].set_color(detectability_colors[detect])

plt.tight_layout()
plt.savefig(f'{SAVE_DIR_EXP}/exp_A1_E4_errortype.png', dpi=150, bbox_inches='tight')
plt.show()

print(f'\n✓ Figure saved')

In [None]:
# ============================================================
# CELL 14: FINAL SUMMARY
# ============================================================

summary = {
    'experiment_id': 'A1_E4',
    'experiment_name': 'Error Type Taxonomy',
    'date': EXPERIMENT_DATE,
    'hypothesis': 'Hard-to-detect errors cause higher CIF than easy-to-detect errors',
    'error_types': {et: ERROR_TYPES[et] for et in ERROR_TYPE_NAMES},
    'n_problems_per_type': N_PROBLEMS_PER_TYPE,
    'models': list(MODELS.keys()),
    'results': all_analyses,
    'correlation_results': correlation_results,
    'key_findings': []
}

for model_key in ['sonnet4', 'gpt4o']:
    if model_key not in all_analyses:
        continue
    
    analysis = all_analyses[model_key]
    
    # Get CIF by detectability
    easy_cif = analysis.get('arithmetic', {}).get('cif_rate', None)
    medium_cif = analysis.get('computational', {}).get('cif_rate', None)
    hard_cifs = [analysis.get(et, {}).get('cif_rate', None) for et in ['setup', 'conceptual']]
    hard_cif = np.mean([c for c in hard_cifs if c is not None]) if any(hard_cifs) else None
    
    finding = {
        'model': model_key,
        'cif_by_type': {et: analysis.get(et, {}).get('cif_rate', None) for et in ERROR_TYPE_NAMES},
        'easy_cif': easy_cif,
        'medium_cif': medium_cif,
        'hard_cif': hard_cif,
        'supports_hypothesis': easy_cif is not None and hard_cif is not None and hard_cif > easy_cif + 0.05
    }
    
    if model_key in correlation_results:
        finding['correlation'] = correlation_results[model_key]
    
    summary['key_findings'].append(finding)

save_json(summary, f'{SAVE_DIR_EXP}/results/exp_A1_E4_summary.json')

print('\n' + '='*60)
print('EXPERIMENT A1-E4 COMPLETE')
print('='*60)
print(f'\nResults saved to: {SAVE_DIR_EXP}')
print('\n' + '='*60)
print('KEY FINDINGS')
print('='*60)

for finding in summary['key_findings']:
    model_name = [n for n, c in MODELS.items() if c['short'] == finding['model']][0]
    print(f"\n{model_name}:")
    print(f"  CIF by error type:")
    for et, cif in finding['cif_by_type'].items():
        if cif is not None:
            detect = ERROR_TYPES[et]['detectability']
            print(f"    {et} ({detect}): {cif:.1%}")
    print(f"  Easy avg: {finding['easy_cif']:.1%}" if finding['easy_cif'] else "  Easy avg: N/A")
    print(f"  Hard avg: {finding['hard_cif']:.1%}" if finding['hard_cif'] else "  Hard avg: N/A")
    print(f"  Supports hypothesis: {'✓ YES' if finding['supports_hypothesis'] else '? No'}")
    if 'correlation' in finding:
        c = finding['correlation']
        print(f"  Correlation: r={c['correlation']:.3f}, p={c['p_value']:.4f}")

print('\n' + '='*60)
print('INTERPRETATION')
print('='*60)
print('''
If hypothesis supported (hard > easy CIF):
  → Models can catch obvious arithmetic errors
  → But struggle with conceptual/setup errors
  → Implication: Error type matters for defense

If not supported:
  → CIF vulnerability uniform across error types
  → Models either catch all or miss all
''')