# Experiment A: Instruction × Authority × Answer Removal (v2)

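**Purpose**: Resolve the contradiction among Studies 1–3 and decompose the conditions under which CIF fires (CIF: a problem the model answers correctly in the DIRECT baseline but incorrectly when shown the contaminated trace).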

**v2 Changes**:
- Fixed numpy int64/float64 JSON serialization issue
- Improved error handling throughout
- Better checkpoint resume logic

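**Design (2×2×2 − 2 = 6 trace conditions, plus the DIRECT baseline A0)**: answer removal is crossed only with the USE instruction, so the two NEUTRAL × Removed cells are omitted.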
- **Instruction**: USE vs NEUTRAL
- **Authority**: AUTHORITATIVE vs UNCERTAIN
- **Answer**: Present vs Removed

| ID | Instruction | Authority | Answer | Description |
|----|-------------|-----------|--------|-------------|
| A1 | USE | AUTH | Present | "Use the following expert solution" |
| A2 | USE | UNCERT | Present | "Use the following attempted solution (may contain errors)" |
| A3 | NEUTRAL | AUTH | Present | "Here is an expert reasoning trace" |
| A4 | NEUTRAL | UNCERT | Present | "Here is an attempted reasoning trace (may contain errors)" |
| A5 | USE | AUTH | **Removed** | A1 without final answer |
| A6 | USE | UNCERT | **Removed** | A2 without final answer |
| A0 | — | — | — | DIRECT baseline (no trace) |

**Models**: Claude 4 Sonnet, GPT-4o, Claude 3.5 Haiku

**λ** (fraction of the I = 10 reasoning steps that are contaminated): 0.8, WRONG type (coherent-but-wrong steps)

**N**: 200 problems, sampled from the GSM8K test split with a fixed seed

## 0. Setup & Google Drive

In [None]:
from google.colab import drive
drive.mount('/content/drive')

import os
from datetime import datetime

EXPERIMENT_ID = 'exp_A'

# Output directories are date-stamped, and checkpoints are looked up inside the
# current directory. To resume a run that was started on an earlier day (e.g. a
# Colab session that crossed midnight), set RESUME_DATE to that run's 'YYYYMMDD'
# stamp. With None, today's date is used and an earlier day's checkpoint is not found.
RESUME_DATE = None
EXPERIMENT_DATE = RESUME_DATE or datetime.now().strftime('%Y%m%d')
SAVE_DIR = '/content/drive/MyDrive/CoT_Experiment'
SAVE_DIR_EXP = f'{SAVE_DIR}/{EXPERIMENT_ID}_instruction_authority_{EXPERIMENT_DATE}'

if RESUME_DATE and not os.path.isdir(SAVE_DIR_EXP):
    print(f'WARNING: RESUME_DATE={RESUME_DATE!r} but {SAVE_DIR_EXP} does not exist; '
          f'starting a new run there.')

for path in (SAVE_DIR_EXP, f'{SAVE_DIR_EXP}/results', f'{SAVE_DIR_EXP}/checkpoints'):
    os.makedirs(path, exist_ok=True)

print(f'Experiment ID: {EXPERIMENT_ID}')
print(f'Save directory: {SAVE_DIR_EXP}')

In [None]:
# Install the third-party packages this notebook imports (-q: quiet pip output).
!pip install datasets openai anthropic pandas tqdm matplotlib scipy statsmodels -q
print('Dependencies installed.')

## 1. Configuration

In [None]:
import json
import re
import random
import time
import hashlib
from typing import List, Dict, Optional, Any, Tuple
from dataclasses import dataclass, asdict
from datetime import datetime
from tqdm import tqdm
import pandas as pd
import numpy as np

# ============================================================
# CONFIGURATION
# ============================================================
GLOBAL_SEED = 20251224   # master seed: problem sampling and per-problem trace seeds
N_PROBLEMS = 200         # number of GSM8K test problems sampled
I_FIXED = 10             # number of reasoning steps in every trace
LAMBDA_FIXED = 0.8       # fraction of those steps that are contaminated


def _condition(instruction, authority, answer_present):
    """Build one condition record; the DIRECT baseline uses None for every factor."""
    return {'instruction': instruction, 'authority': authority, 'answer_present': answer_present}


# Experimental conditions (insertion order is the order they are listed/printed)
CONDITIONS = {
    'A0_DIRECT': _condition(None, None, None),
    'A1_USE_AUTH_ANS': _condition('USE', 'AUTH', True),
    'A2_USE_UNCERT_ANS': _condition('USE', 'UNCERT', True),
    'A3_NEUTRAL_AUTH_ANS': _condition('NEUTRAL', 'AUTH', True),
    'A4_NEUTRAL_UNCERT_ANS': _condition('NEUTRAL', 'UNCERT', True),
    'A5_USE_AUTH_NOANS': _condition('USE', 'AUTH', False),
    'A6_USE_UNCERT_NOANS': _condition('USE', 'UNCERT', False),
}


def _model(provider, api_name, short):
    """Build one model record: API provider, API model name, and short tag for file names."""
    return {'provider': provider, 'api_name': api_name, 'short': short}


# Models to test
MODELS = {
    'Claude 4 Sonnet': _model('anthropic', 'claude-sonnet-4-20250514', 'sonnet4'),
    'GPT-4o': _model('openai', 'gpt-4o', 'gpt4o'),
    # NOTE(review): '-latest' looks like a moving alias rather than a pinned
    # snapshot; confirm whether a dated model ID is wanted for reproducibility.
    'Claude 3.5 Haiku': _model('anthropic', 'claude-3-5-haiku-latest', 'haiku35'),
}

print('='*60)
print('EXPERIMENT A: INSTRUCTION × AUTHORITY × ANSWER REMOVAL (v2)')
print('='*60)
for line in (
    f'Models: {list(MODELS)}',
    f'λ (fixed): {LAMBDA_FIXED}',
    f'Conditions: {list(CONDITIONS)}',
    f'Problems: {N_PROBLEMS}',
    f'Total inferences per model: {N_PROBLEMS * len(CONDITIONS)}',
):
    print(line)

## 2. Utility Functions (v2: Fixed JSON serialization)

In [None]:
def convert_to_native(obj):
    """Recursively convert numpy/pandas values into JSON-serializable Python natives.

    - dict keys are stringified (JSON object keys must be strings);
    - lists, tuples, sets and numpy arrays become lists;
    - numpy bool/integer/floating scalars become bool/int/float;
    - missing values (None, NaN of any float type, pd.NA, pd.NaT) become None,
      so the output never contains the non-standard ``NaN`` JSON token.

    Anything else is returned unchanged.
    """
    if isinstance(obj, dict):
        return {str(k): convert_to_native(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple, set, frozenset)):
        # Tuples/sets used to fall through to pd.isna(), which returns an array
        # for sequences and raised "truth value of an array is ambiguous".
        return [convert_to_native(v) for v in obj]
    if isinstance(obj, np.ndarray):
        # tolist() yields natives, but may still contain NaN floats.
        return convert_to_native(obj.tolist())
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        # Fall through so a numpy NaN is mapped to None like a plain float NaN.
        obj = float(obj)
    if pd.api.types.is_scalar(obj) and pd.isna(obj):
        return None
    return obj

def save_json(data, filepath):
    """Save *data* to *filepath* as pretty-printed UTF-8 JSON.

    numpy/pandas values are first converted with ``convert_to_native``. The JSON
    is written to a sibling ``.tmp`` file, which then replaces *filepath* via
    ``os.replace``. An interrupted write (e.g. a Colab disconnect while saving a
    checkpoint) therefore leaves the previous file intact instead of a truncated
    one that cannot be loaded.
    """
    converted_data = convert_to_native(data)
    tmp_path = f'{filepath}.tmp'
    try:
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(converted_data, f, ensure_ascii=False, indent=2)
    except BaseException:
        # Do not leave a half-written temp file behind; re-raise the original error.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    os.replace(tmp_path, filepath)
    print(f'Saved: {filepath}')

def load_json(filepath):
    """Read *filepath* as UTF-8 text and return the decoded JSON value."""
    with open(filepath, encoding='utf-8') as handle:
        decoded = json.load(handle)
    return decoded

print('Utility functions defined (v2: with numpy type conversion)')

## 3. API Setup

In [None]:
import getpass
from openai import OpenAI
import anthropic

def _prompt_for_key(banner, label):
    """Print *banner*, then read a secret with getpass (input is not echoed)."""
    print(banner)
    return getpass.getpass(label)


OPENAI_API_KEY = _prompt_for_key("OpenAI APIキーを入力してください：", "OpenAI API Key: ")
ANTHROPIC_API_KEY = _prompt_for_key("\nAnthropic APIキーを入力してください：", "Anthropic API Key: ")

# One client per provider; call_api chooses between them by model_config['provider'].
openai_client = OpenAI(api_key=OPENAI_API_KEY)
anthropic_client = anthropic.Anthropic(api_key=ANTHROPIC_API_KEY)

def call_api(prompt: str, model_config: dict, max_tokens: int = 512,
             temperature: float = 0.0, max_retries: int = 3) -> str:
    """Send *prompt* as a single user message and return the model's reply text.

    Both providers receive the same ``temperature`` (default 0). Previously only
    the OpenAI call set it, so Anthropic models ran at the provider's default and
    the models were not compared under the same decoding setting.

    Args:
        prompt: The user message.
        model_config: A MODELS entry; ``'provider'`` selects the client
            ('openai' or, otherwise, Anthropic), ``'api_name'`` is the model ID.
        max_tokens: Completion token limit.
        temperature: Sampling temperature passed to both providers.
        max_retries: Number of attempts before giving up.

    Returns:
        The reply text. Returns "" if every attempt raised; callers treat that
        as an unparseable answer.
    """
    for attempt in range(max_retries):
        try:
            if model_config['provider'] == 'openai':
                response = openai_client.chat.completions.create(
                    model=model_config['api_name'],
                    messages=[{"role": "user", "content": prompt}],
                    max_tokens=max_tokens,
                    temperature=temperature,
                )
                # message.content may be None; keep the str return contract.
                return response.choices[0].message.content or ""
            response = anthropic_client.messages.create(
                model=model_config['api_name'],
                max_tokens=max_tokens,
                temperature=temperature,
                messages=[{"role": "user", "content": prompt}],
            )
            return "".join(part.text for part in response.content
                           if getattr(part, 'type', None) == 'text')
        except Exception as e:  # network / rate-limit / API errors: retry with backoff
            print(f'API error (attempt {attempt+1}): {e}')
            if attempt < max_retries - 1:  # no point sleeping after the last attempt
                time.sleep(2 ** attempt)
    print(f'API call failed after {max_retries} attempts; returning empty string')
    return ""

# Smoke test: one trivial request per configured model before the main run.
print('\nTesting APIs...')
for model_name, cfg in MODELS.items():
    reply = call_api("What is 2+2? Reply with just the number.", cfg)
    print(f'{model_name}: {reply.strip()}')

## 4. Load GSM8K

In [None]:
from datasets import load_dataset

@dataclass
class GSM8KProblem:
    """One sampled GSM8K test problem together with its gold solution."""
    index: int          # row index in the GSM8K test split (also feeds derive_seed and result keys)
    question: str       # problem statement inserted into every prompt
    answer_text: str    # full gold solution text, including the '#### <answer>' line
    final_answer: int   # integer parsed from the '####' line of answer_text

def extract_final_answer(answer_text: str) -> int:
    """Return the integer after GSM8K's '####' marker, with thousands commas removed.

    An optional leading minus sign is accepted. The previous pattern allowed only
    digits and commas, so a negative gold answer raised ValueError and the
    problem was silently dropped from the sample.

    Raises:
        ValueError: if no '#### <number>' marker is present.
    """
    match = re.search(r'####\s*(-?\d[\d,]*)', answer_text)
    if match:
        return int(match.group(1).replace(',', ''))
    raise ValueError('Could not extract final answer')

dataset = load_dataset('gsm8k', 'main', split='test')
print(f'GSM8K loaded: {len(dataset)} problems')

# Deterministic sample: shuffle every test index with a dedicated seeded RNG and
# keep the first N_PROBLEMS.
rng = random.Random(GLOBAL_SEED)
indices = list(range(len(dataset)))
rng.shuffle(indices)
selected_indices = indices[:N_PROBLEMS]


def _to_problem(idx):
    """Wrap dataset row *idx* as a GSM8KProblem; None if its final answer cannot be parsed."""
    row = dataset[idx]
    try:
        gold = extract_final_answer(row['answer'])
    except ValueError:
        return None
    return GSM8KProblem(index=idx, question=row['question'],
                        answer_text=row['answer'], final_answer=gold)


# Rows without a parsable answer are skipped, so len(problems) may be < N_PROBLEMS.
problems = [p for p in map(_to_problem, selected_indices) if p is not None]

print(f'Selected problems: {len(problems)}')
print(f'First 5 indices: {[p.index for p in problems[:5]]}')

## 5. CoT Generation & Contamination

In [None]:
def derive_seed(base_seed: int, problem_idx: int, lam: float) -> int:
    """Map (base_seed, problem_idx, lam) to a reproducible 32-bit seed.

    The key "<base_seed>_<problem_idx>_<lam>" is hashed with MD5. The first four
    digest bytes are read as a big-endian unsigned integer, which equals the
    first eight hex digits of the digest.
    """
    key = f"{base_seed}_{problem_idx}_{lam}".encode()
    return int.from_bytes(hashlib.md5(key).digest()[:4], 'big')

def generate_clean_steps(problem: GSM8KProblem, I: int) -> List[str]:
    """Return exactly I clean reasoning steps taken from the gold solution.

    The gold text before the '####' marker is split into stripped, non-empty
    lines, and the first I are kept. Shorter solutions are padded with a filler
    step.

    The filler carries no "Step N:" label of its own. The trace builders prefix
    every step with "Step {i+1}: ", so a labelled filler used to render as
    "Step 5: Step 5: (continuing calculation)". The contamination step, which
    perturbs the first number in a step, then altered that label instead of a
    calculation.

    NOTE(review): if the gold lines contain GSM8K's <<expr=value>> calculator
    annotations, they are kept verbatim and contamination does not edit them.
    Confirm that is intended.
    """
    raw = problem.answer_text.split('####')[0].strip()
    lines = [line.strip() for line in raw.split('\n') if line.strip()][:I]
    lines.extend(["(continuing calculation)"] * (I - len(lines)))
    return lines

def generate_wrong_step(step: str, rng: random.Random, correct_answer: int) -> str:
    """Perturb the first standalone integer in *step* that differs from the correct answer.

    That integer is shifted by a random non-zero offset in [-3, 3], using one
    ``rng.choice`` call (WRONG-type, coherent-but-wrong contamination). Steps
    with no eligible number are returned unchanged and consume no randomness.

    The new value is spliced in at the token's own position. The previous
    ``str.replace(num, ..., 1)`` edited the first *substring* occurrence, which
    could sit inside an earlier number. For example, with correct_answer=18,
    "18 eggs minus 8" had its "18" corrupted instead of the "8".
    """
    for match in re.finditer(r'\b\d+\b', step):
        value = int(match.group())
        if value != correct_answer:
            shifted = value + rng.choice([-3, -2, -1, 1, 2, 3])
            return f'{step[:match.start()]}{shifted}{step[match.end():]}'
    return step

def generate_mixed_cot(problem: GSM8KProblem, I: int, lam: float, seed: int) -> Tuple[List[str], int]:
    """Build a λ-contaminated copy of the problem's gold reasoning.

    int(I * lam) step positions are drawn without replacement and those steps
    are passed through generate_wrong_step. A wrong final answer (gold answer
    shifted by one of ±2, ±3, ±5) is also drawn.

    Returns:
        (steps, wrong_answer): the I possibly-perturbed steps and the wrong final answer.

    All randomness comes from a fresh Random(seed), consumed in a fixed order:
    positions first, then corrupted steps in step order, then the final answer.
    Reordering these draws would change every derived trace.
    """
    rng = random.Random(seed)
    steps = generate_clean_steps(problem, I)

    corrupt_at = set(rng.sample(range(I), int(I * lam)))
    mixed = [
        generate_wrong_step(step, rng, problem.final_answer) if pos in corrupt_at else step
        for pos, step in enumerate(steps)
    ]

    # Offset is never zero, so the stated final answer is always wrong.
    wrong_answer = problem.final_answer + rng.choice([-5, -3, -2, 2, 3, 5])
    return mixed, wrong_answer

# Smoke test: build the contaminated trace for the first sampled problem.
test_prob = problems[0]
test_seed = derive_seed(GLOBAL_SEED, test_prob.index, LAMBDA_FIXED)
test_steps, test_wrong_ans = generate_mixed_cot(test_prob, I_FIXED, LAMBDA_FIXED, test_seed)
for label, value in (
    ('Test problem correct answer', test_prob.final_answer),
    ('Test wrong answer', test_wrong_ans),
    ('Test steps (first 3)', test_steps[:3]),
):
    print(f'{label}: {value}')

## 6. Answer Removal with Regex Verification

In [None]:
def create_trace_with_answer(steps: List[str], wrong_answer: int) -> str:
    """Number the steps as "Step k: ..." lines and append a concluding answer sentence."""
    numbered = [f"Step {n}: {text}" for n, text in enumerate(steps, start=1)]
    conclusion = f"Therefore, the answer is {wrong_answer}."
    return "\n".join(numbered) + "\n\n" + conclusion

def create_trace_without_answer(steps: List[str], wrong_answer: int) -> str:
    """Number the steps as "Step k: ..." lines, with no concluding answer sentence.

    wrong_answer is accepted for signature parity with create_trace_with_answer
    and is not used.
    """
    numbered = (f"Step {n}: {text}" for n, text in enumerate(steps, start=1))
    return "\n".join(numbered)

def verify_answer_removed(trace: str, wrong_answer: int) -> bool:
    """Return True if *trace* shows neither the wrong answer nor a conclusion phrase.

    The check fails (returns False) in two cases:

    - The wrong answer's value appears as a standalone number. Look-arounds
      replace the old word-boundary anchors, which cannot sit between a space
      and "-". With those anchors a negative wrong answer such as -3 was never
      found, and the trace was wrongly reported as clean.
    - A conclusion phrase appears, such as "the answer is" or
      "therefore ... <number>".

    The number check is deliberately conservative: an intermediate step that
    happens to contain the same value also fails verification.
    """
    value_pattern = rf'(?<![\w.]){re.escape(str(wrong_answer))}(?!\w)'
    if re.search(value_pattern, trace):
        return False
    answer_patterns = [
        r'the answer is',
        r'final answer',
        r'therefore.*\d+',
        r'so we get.*\d+',
        r'result is.*\d+',
    ]
    return not any(re.search(pattern, trace, re.IGNORECASE) for pattern in answer_patterns)

# Smoke test: build both trace variants for the test problem and check the removal.
trace_with = create_trace_with_answer(test_steps, test_wrong_ans)
trace_without = create_trace_without_answer(test_steps, test_wrong_ans)
removal_ok = verify_answer_removed(trace_without, test_wrong_ans)
print(f'Answer removal verified: {removal_ok}')

## 7. Prompt Templates

In [None]:
def create_prompt(problem: GSM8KProblem, trace: str, condition_id: str, condition_config: dict) -> str:
    """Render the trace-conditioned prompt for one experimental condition.

    The (instruction, authority) pair from *condition_config* selects the header
    sentence and the label of the delimited trace block. The instruction alone
    selects the closing request. *condition_id* is accepted but not used.

    Raises:
        ValueError: if (instruction, authority) is not one of the four trace
            conditions (e.g. the DIRECT baseline's None/None).
    """
    instruction = condition_config['instruction']
    authority = condition_config['authority']

    # (instruction, authority) -> (header sentence, trace block label)
    framing = {
        ('USE', 'AUTH'): ("Use the following expert solution to solve the problem:",
                          "Expert Solution"),
        ('USE', 'UNCERT'): ("Use the following attempted solution (may contain errors):",
                            "Attempted Solution (may contain errors)"),
        ('NEUTRAL', 'AUTH'): ("Here is an expert reasoning trace:",
                              "Expert Reasoning Trace"),
        ('NEUTRAL', 'UNCERT'): ("Here is an attempted reasoning trace (may contain errors):",
                                "Attempted Reasoning Trace (may contain errors)"),
    }.get((instruction, authority))
    if framing is None:
        raise ValueError(f"Invalid condition config: {condition_config}")
    header, trace_label = framing

    footer = ("Provide your final answer." if instruction == 'USE'
              else "Solve the problem and provide your final answer.")

    # Sections are separated by one blank line; the trace sits between --- fences.
    return "\n\n".join([
        f"Problem: {problem.question}",
        header,
        f"--- {trace_label} ---\n{trace}\n---",
        footer,
        'Reply with ONLY: {"final": <number>}',
    ])

def create_direct_prompt(problem: GSM8KProblem) -> str:
    """Render the DIRECT-baseline prompt: the problem alone, no reasoning trace."""
    lines = [
        "Solve this math problem. Give ONLY the final numerical answer in JSON format.",
        "",
        f"Problem: {problem.question}",
        "",
        'Reply with ONLY: {"final": <number>}',
    ]
    return "\n".join(lines)

# One number token: optional minus, then either comma-grouped thousands or plain
# digits, then an optional decimal part. The look-arounds keep a match from
# starting or ending inside a word or another number, while still letting a
# leading "-" follow whitespace or punctuation.
_NUMBER = r'(?<![\w.])-?(?:\d{1,3}(?:,\d{3})+|\d+)(?:\.\d+)?(?!\w)'
_NUMBER_RE = re.compile(_NUMBER)
_FINAL_RE = re.compile(r'"final"\s*:\s*"?(' + _NUMBER + r')')


def _to_int(token: str) -> Optional[int]:
    """Convert a matched number token to int; None if it has a non-zero fractional part."""
    whole, _, fraction = token.replace(',', '').partition('.')
    return None if fraction.strip('0') else int(whole)


def parse_answer(response: str) -> Optional[int]:
    """Extract the model's integer answer from its reply.

    Prefers the value of a ``"final"`` key (e.g. ``{"final": 18}``), tolerating a
    quoted value, a minus sign and thousands separators. Otherwise it falls back
    to the last number anywhere in the reply.

    Returns None when the reply is empty/None, when no number is found, or when
    the chosen number is not an integer (e.g. 18.5). Previously negative values
    lost their sign and "1,234" parsed as 1 or 234.
    """
    if not response:
        return None
    final = _FINAL_RE.search(response)
    if final:
        return _to_int(final.group(1))
    tokens = _NUMBER_RE.findall(response)
    return _to_int(tokens[-1]) if tokens else None

print('Prompt templates defined.')

## 8. Checkpoint & Resume Logic

In [None]:
def get_checkpoint_path(model_short: str) -> str:
    """Path of the per-model checkpoint file inside this experiment's checkpoints/ folder."""
    filename = f"checkpoint_{model_short}.json"
    return "/".join([SAVE_DIR_EXP, "checkpoints", filename])

def save_checkpoint(results: List[dict], model_short: str, completed_conditions: List[str]):
    """Write every result collected so far, plus the finished condition IDs, for *model_short*."""
    payload = dict(
        model=model_short,
        completed_conditions=completed_conditions,
        n_results=len(results),
        timestamp=datetime.now().isoformat(),
        results=results,
    )
    save_json(payload, get_checkpoint_path(model_short))

def load_checkpoint(model_short: str) -> Tuple[List[dict], List[str]]:
    """Return (results, completed_conditions) from *model_short*'s checkpoint, or ([], []) if none."""
    path = get_checkpoint_path(model_short)
    if not os.path.exists(path):
        return [], []
    saved = load_json(path)
    print(f'✓ Checkpoint found for {model_short}')
    print(f'  Completed conditions: {saved["completed_conditions"]}')
    print(f'  Results: {saved["n_results"]}')
    return saved['results'], saved['completed_conditions']

print('Checkpoint functions defined.')

## 9. Select Model to Run

In [None]:
#@title Select Model { run: "auto" }
MODEL_CHOICE = "Claude 4 Sonnet" #@param ["Claude 4 Sonnet", "GPT-4o", "Claude 3.5 Haiku"]

model_config = MODELS[MODEL_CHOICE]
model_short = model_config['short']

for label, value in (
    ('Selected model', MODEL_CHOICE),
    ('API name', model_config["api_name"]),
    ('Short name', model_short),
):
    print(f'{label}: {value}')

# Resume state for this model: saved results and finished condition IDs (empty if none).
existing_results, completed_conditions = load_checkpoint(model_short)

# Execution order of the conditions (DIRECT baseline first)
condition_order = ['A0_DIRECT', 'A1_USE_AUTH_ANS', 'A2_USE_UNCERT_ANS',
                   'A3_NEUTRAL_AUTH_ANS', 'A4_NEUTRAL_UNCERT_ANS',
                   'A5_USE_AUTH_NOANS', 'A6_USE_UNCERT_NOANS']

if not completed_conditions:
    print('\nNo checkpoint found. Starting fresh.')
else:
    remaining = [c for c in condition_order if c not in completed_conditions]
    print(f'\nRemaining conditions: {remaining}' if remaining
          else '\n✓ ALL CONDITIONS COMPLETED - Will show results only')

## 10. Run Experiment (or Load Completed Results)

In [None]:
print('='*60)
print(f'EXPERIMENT A: {MODEL_CHOICE}')
print('='*60)

# Write an intermediate checkpoint every CHECKPOINT_EVERY new API calls, so a
# Colab disconnect loses at most that many calls rather than a whole condition.
CHECKPOINT_EVERY = 25

# Resume from checkpoint or start fresh
all_results = existing_results.copy()
completed = set(completed_conditions)
# (condition_id, problem_index) pairs that already have a result. This also covers
# problems saved by an intermediate checkpoint of a condition not yet marked complete.
done_pairs = {(r['condition_id'], r['problem_index']) for r in all_results}


def _completed_in_order():
    """Completed condition IDs in execution order (keeps checkpoint contents stable)."""
    return [c for c in condition_order if c in completed]


def _run_single(problem, condition_id, condition_config):
    """Build the prompt for one (problem, condition) pair, query the model, and return the result record."""
    is_direct = condition_id == 'A0_DIRECT'
    # The seed depends only on the problem and λ, so every trace condition shows
    # the model the same steps and the same wrong answer.
    seed = derive_seed(GLOBAL_SEED, problem.index, LAMBDA_FIXED)
    cot_steps, wrong_answer = generate_mixed_cot(problem, I_FIXED, LAMBDA_FIXED, seed)

    if is_direct:
        prompt = create_direct_prompt(problem)
        trace_used = None
        answer_removed_verified = None
    else:
        if condition_config['answer_present']:
            trace_used = create_trace_with_answer(cot_steps, wrong_answer)
            answer_removed_verified = False
        else:
            trace_used = create_trace_without_answer(cot_steps, wrong_answer)
            answer_removed_verified = verify_answer_removed(trace_used, wrong_answer)
        prompt = create_prompt(problem, trace_used, condition_id, condition_config)

    response = call_api(prompt, model_config)
    answer = parse_answer(response)
    # Compare with None explicitly: a parsed answer of 0 is a real answer. The
    # previous `if answer` check scored it as missing and incorrect.
    is_correct = answer is not None and answer == problem.final_answer

    return {
        'experiment_id': EXPERIMENT_ID,
        'problem_index': int(problem.index),
        'model': MODEL_CHOICE,
        'model_short': model_short,
        'condition_id': condition_id,
        'instruction': condition_config.get('instruction'),
        'authority': condition_config.get('authority'),
        'answer_present': condition_config.get('answer_present'),
        'lam': None if is_direct else LAMBDA_FIXED,
        'model_answer': None if answer is None else int(answer),
        'correct_answer': int(problem.final_answer),
        'wrong_answer_in_trace': None if is_direct else int(wrong_answer),
        'is_correct': bool(is_correct),
        'answer_removed_verified': answer_removed_verified,
        # Exact trace shown to the model (None for DIRECT), kept for auditing.
        'trace_used': trace_used,
        # call_api returns "" once all retries fail. Flag it so such rows can be
        # excluded instead of silently counting as wrong answers.
        'api_error': not response,
        'raw_output': response,
        'timestamp': datetime.now().isoformat()
    }


if all(c in completed for c in condition_order):
    print('\n✓ All conditions already completed. Skipping to analysis.')
else:
    for condition_id in condition_order:
        if condition_id in completed:
            print(f'\n--- {condition_id}: SKIPPED (already completed) ---')
            continue

        condition_config = CONDITIONS[condition_id]
        print(f'\n--- {condition_id} ---')

        n_new = 0
        n_api_errors = 0
        for problem in tqdm(problems, desc=condition_id):
            if (condition_id, problem.index) in done_pairs:
                continue  # answered before an intermediate checkpoint
            result = _run_single(problem, condition_id, condition_config)
            all_results.append(result)
            done_pairs.add((condition_id, problem.index))
            n_new += 1
            n_api_errors += result['api_error']
            if n_new % CHECKPOINT_EVERY == 0:
                save_checkpoint(all_results, model_short, _completed_in_order())
            time.sleep(0.5)

        # Mark the condition complete only after every problem has a result
        completed.add(condition_id)
        save_checkpoint(all_results, model_short, _completed_in_order())
        print(f'✓ {condition_id} complete. Checkpoint saved.')
        if n_api_errors:
            print(f'  ⚠ {n_api_errors} API call(s) returned no text (api_error=True)')

    # Save final results
    save_json(all_results, f"{SAVE_DIR_EXP}/results/exp_A_results_{model_short}.json")

print('\n' + '='*60)
print('✓ EXPERIMENT A DATA READY!')
print('='*60)

## 11. Analyze Results

In [None]:
import matplotlib.pyplot as plt
from scipy import stats

df = pd.DataFrame(all_results)

# Per-condition accuracy: mean of is_correct, number correct, number of rows
acc_by_condition = (
    df.groupby('condition_id')['is_correct']
    .agg(['mean', 'sum', 'count'])
    .rename(columns={'mean': 'accuracy', 'sum': 'correct', 'count': 'total'})
)

print('='*60)
print(f'EXPERIMENT A RESULTS: {MODEL_CHOICE}')
print('='*60)
print('\nAccuracy by Condition:')
reported = [c for c in condition_order if c in acc_by_condition.index]
for cond in reported:
    acc = acc_by_condition.at[cond, 'accuracy']
    n = acc_by_condition.at[cond, 'total']
    print(f'  {cond}: {acc:.1%} (n={int(n)})')

# DIRECT baseline accuracy: the reference point for every delta below
direct_acc = float(acc_by_condition.at['A0_DIRECT', 'accuracy'])
print(f'\nDIRECT baseline: {direct_acc:.1%}')

print('\nΔ vs DIRECT:')
for cond in condition_order[1:]:
    if cond not in acc_by_condition.index:
        continue
    delta = float(acc_by_condition.at[cond, 'accuracy']) - direct_acc
    print(f'  {cond}: {delta:+.1%}')

In [None]:
# Calculate CIF and Recovery
# Per-problem DIRECT correctness; also reused by the McNemar cell.
direct_results = (
    df.loc[df['condition_id'] == 'A0_DIRECT', ['problem_index', 'is_correct']]
    .rename(columns={'is_correct': 'direct_correct'})
)


def _rate(count, base):
    """count / base, or 0.0 when the base is empty."""
    return count / base if base > 0 else 0.0


cif_recovery = {}

for cond in condition_order[1:]:
    cond_results = (
        df.loc[df['condition_id'] == cond, ['problem_index', 'is_correct']]
        .rename(columns={'is_correct': 'cond_correct'})
    )
    # Pair each problem's DIRECT outcome with its outcome under this condition
    merged = direct_results.merge(cond_results, on='problem_index')

    direct_ok = merged['direct_correct'].eq(True)
    direct_bad = merged['direct_correct'].eq(False)
    cond_ok = merged['cond_correct'].eq(True)
    cond_bad = merged['cond_correct'].eq(False)

    # CIF: DIRECT correct → condition wrong; Recovery: DIRECT wrong → condition correct
    cif_count = int((direct_ok & cond_bad).sum())
    cif_base = int(direct_ok.sum())
    recovery_count = int((direct_bad & cond_ok).sum())
    recovery_base = int(direct_bad.sum())

    cif_recovery[cond] = {
        'CIF_count': cif_count,
        'CIF_base': cif_base,
        'CIF_rate': float(_rate(cif_count, cif_base)),
        'Recovery_count': recovery_count,
        'Recovery_base': recovery_base,
        'Recovery_rate': float(_rate(recovery_count, recovery_base)),
        'Asymmetry': cif_count - recovery_count,
    }

print('\n' + '='*60)
print('CIF & RECOVERY ANALYSIS')
print('='*60)
print('\nCIF = P(Condition wrong | DIRECT correct)')
print('Recovery = P(Condition correct | DIRECT wrong)')

for cond, metrics in cif_recovery.items():
    print(f'\n{cond}:')
    print(f'  CIF: {metrics["CIF_rate"]:.1%} ({metrics["CIF_count"]}/{metrics["CIF_base"]})')
    print(f'  Recovery: {metrics["Recovery_rate"]:.1%} ({metrics["Recovery_count"]}/{metrics["Recovery_base"]})')
    print(f'  Asymmetry: {metrics["Asymmetry"]:+d}')

In [None]:
# McNemar's Test
print('\n' + '='*60)
print('McNEMAR\'S TEST (vs DIRECT)')
print('='*60)

# Below this many discordant pairs the chi-square approximation is unreliable,
# so the exact binomial form is used instead (a common rule of thumb).
MCNEMAR_EXACT_THRESHOLD = 25


def _mcnemar(b, c):
    """Two-sided McNemar test from the discordant counts b and c.

    Returns (chi2, p_value, method):
    - chi2: continuity-corrected statistic max(|b - c| - 1, 0)**2 / (b + c). The
      max() keeps b == c at chi2 = 0; the unfloored formula gave 1 / (b + c).
    - p_value: exact, 2 * P[X <= min(b, c)] with X ~ Binomial(b + c, 0.5), capped
      at 1, when b + c < MCNEMAR_EXACT_THRESHOLD. Otherwise the chi-square(1)
      upper tail.
    - method: 'exact', 'chi2_cc', or 'none' (no discordant pairs).
    """
    n = b + c
    if n == 0:
        return 0.0, 1.0, 'none'
    chi2 = float(max(abs(b - c) - 1, 0) ** 2 / n)
    if n < MCNEMAR_EXACT_THRESHOLD:
        p_value = float(min(1.0, 2 * stats.binom.cdf(min(b, c), n, 0.5)))
        return chi2, p_value, 'exact'
    # sf() computes the upper tail directly; 1 - cdf() loses precision for tiny p.
    return chi2, float(stats.chi2.sf(chi2, df=1)), 'chi2_cc'


mcnemar_results = {}

for cond in condition_order[1:]:
    cond_results = df[df['condition_id'] == cond][['problem_index', 'is_correct']].copy()
    cond_results.columns = ['problem_index', 'cond_correct']

    merged = direct_results.merge(cond_results, on='problem_index')

    # Discordant pairs: b = DIRECT right / condition wrong (CIF), c = the reverse (Recovery)
    b = int(((merged['direct_correct'] == True) & (merged['cond_correct'] == False)).sum())
    c = int(((merged['direct_correct'] == False) & (merged['cond_correct'] == True)).sum())

    chi2, p_value, method = _mcnemar(b, c)

    mcnemar_results[cond] = {
        'b_CIF': b,
        'c_Recovery': c,
        'chi2': chi2,
        'p_value': p_value,
        'method': method
    }

    sig = '***' if p_value < 0.001 else '**' if p_value < 0.01 else '*' if p_value < 0.05 else ''
    print(f'\n{cond}:')
    print(f'  b (CIF): {b}, c (Recovery): {c}')
    print(f'  χ² = {chi2:.2f}, p = {p_value:.4f} {sig} [{method}]')

In [None]:
# Visualization: (left) accuracy per condition, (right) CIF vs Recovery rates
fig, axes = plt.subplots(1, 2, figsize=(14, 6))
ax1, ax2 = axes


def _wrap_label(condition_id):
    """Break a condition ID at underscores so tick labels stack vertically."""
    return condition_id.replace('_', '\n')


def _accuracy_pct(condition_id):
    """Accuracy of *condition_id* in percent, or 0 if it has no results."""
    if condition_id not in acc_by_condition.index:
        return 0
    return float(acc_by_condition.loc[condition_id, 'accuracy']) * 100


# Plot 1: Accuracy by condition
conditions_plot = condition_order
accs = [_accuracy_pct(c) for c in conditions_plot]

colors = ['#333333', '#d62728', '#ff7f0e', '#2ca02c', '#9467bd', '#1f77b4', '#17becf']
positions = range(len(conditions_plot))
bars = ax1.bar(positions, accs, color=colors, edgecolor='black', linewidth=1.5)

for bar, acc in zip(bars, accs):
    ax1.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 1,
             f'{acc:.1f}%', ha='center', va='bottom', fontsize=10, fontweight='bold')

ax1.set_xticks(positions)
ax1.set_xticklabels([_wrap_label(c) for c in conditions_plot], fontsize=9)
ax1.set_ylabel('Accuracy (%)', fontsize=12)
ax1.set_title(f'Experiment A: Accuracy by Condition\n{MODEL_CHOICE}', fontsize=13)
ax1.set_ylim(0, 105)
ax1.axhline(y=direct_acc * 100, color='gray', linestyle='--', alpha=0.7, label='DIRECT baseline')
ax1.legend()

# Plot 2: CIF vs Recovery, side-by-side bars per condition
conds = list(cif_recovery.keys())
cif_rates = [cif_recovery[c]['CIF_rate'] * 100 for c in conds]
recovery_rates = [cif_recovery[c]['Recovery_rate'] * 100 for c in conds]

x = np.arange(len(conds))
width = 0.35

bars1 = ax2.bar(x - width / 2, cif_rates, width, label='CIF', color='#d62728', edgecolor='black')
bars2 = ax2.bar(x + width / 2, recovery_rates, width, label='Recovery', color='#2ca02c', edgecolor='black')

ax2.set_xticks(x)
ax2.set_xticklabels([_wrap_label(c) for c in conds], fontsize=9)
ax2.set_ylabel('Rate (%)', fontsize=12)
ax2.set_title(f'CIF vs Recovery by Condition\n{MODEL_CHOICE}', fontsize=13)
ax2.legend()
ax2.set_ylim(0, max(cif_rates + recovery_rates) * 1.2 + 5)

figure_path = f"{SAVE_DIR_EXP}/exp_A_results_{model_short}.png"
plt.tight_layout()
plt.savefig(figure_path, dpi=300)
plt.show()

print(f'\nFigure saved: {figure_path}')

In [None]:
# Save summary (v2: save_json converts numpy/pandas values before writing)
results_dir = f"{SAVE_DIR_EXP}/results"
results_path = f"{results_dir}/exp_A_results_{model_short}.json"
summary_path = f"{results_dir}/exp_A_summary_{model_short}.json"
figure_path = f'{SAVE_DIR_EXP}/exp_A_results_{model_short}.png'

summary = {
    'experiment_id': EXPERIMENT_ID,
    'model': MODEL_CHOICE,
    'model_short': model_short,
    'date': EXPERIMENT_DATE,
    'n_problems': len(problems),
    'lambda': LAMBDA_FIXED,
    'accuracy_by_condition': acc_by_condition.to_dict(),
    'cif_recovery': cif_recovery,
    'mcnemar_tests': mcnemar_results,
    'direct_baseline': direct_acc,
}
save_json(summary, summary_path)

print('\n' + '='*60)
print('SUMMARY SAVED')
print('='*60)
for label, path in (('Results', results_path), ('Summary', summary_path), ('Figure', figure_path)):
    print(f'{label}: {path}')

## 12. Logistic Regression on Design Factors (Optional)

In [None]:
import statsmodels.formula.api as smf

# Prepare data for the regressions (exclude DIRECT, which has no factor levels)
df_factors = df[df['condition_id'] != 'A0_DIRECT'].copy()

# Binary (0/1) factor codes
df_factors['instruction_use'] = (df_factors['instruction'] == 'USE').astype(int)
df_factors['authority_auth'] = (df_factors['authority'] == 'AUTH').astype(int)
df_factors['answer_present'] = df_factors['answer_present'].fillna(True).astype(int)
df_factors['is_correct_int'] = df_factors['is_correct'].astype(int)

print('='*60)
print('FACTOR ANALYSIS: MAIN EFFECTS')
print('='*60)


def _fit_clustered_logit(formula, data):
    """Fit a logistic regression with standard errors clustered by problem_index.

    Each problem appears once per condition, so rows sharing a problem are not
    independent. Clustering by problem keeps the standard errors from
    overstating precision.
    """
    model = smf.logit(formula, data=data)
    return model.fit(disp=0, cov_type='cluster',
                     cov_kwds={'groups': data['problem_index'].to_numpy()})


def _report_logit(title, formula, data):
    """Fit *formula* on *data*, print the summary and odds ratios; return the fit or None on error."""
    print(f'\n--- {title} ---')
    try:
        fitted = _fit_clustered_logit(formula, data)
    except Exception as e:  # e.g. perfect separation on a small sample
        print(f'Factor analysis error: {e}')
        return None
    print(fitted.summary())
    print('\nOdds Ratios:')
    print(np.exp(fitted.params))
    return fitted


# Model 1: main effects over all six trace conditions. answer_present must be in
# the model because answer removal only occurs under USE; leaving it out folds
# the removal effect into the instruction coefficient.
result1 = _report_logit(
    'All trace conditions (A1–A6), main effects',
    'is_correct_int ~ instruction_use + authority_auth + answer_present',
    df_factors,
)

# Model 2: the fully crossed 2×2 (answer present, A1–A4), including the interaction
result2 = _report_logit(
    'Answer-present conditions (A1–A4), 2×2 with interaction',
    'is_correct_int ~ instruction_use * authority_auth',
    df_factors[df_factors['answer_present'] == 1],
)

In [None]:
# Key findings summary
print('\n' + '='*60)
print('KEY FINDINGS SUMMARY')
print('='*60)

# USE vs NEUTRAL effect, on the answer-present conditions (A1–A4) only. NEUTRAL
# has no answer-removed cells, so pooling A5/A6 into the USE mean would mix the
# instruction effect with the answer-removal effect.
df_2x2 = df_factors[df_factors['answer_present'] == 1]
use_acc = df_2x2[df_2x2['instruction'] == 'USE']['is_correct'].mean()
neutral_acc = df_2x2[df_2x2['instruction'] == 'NEUTRAL']['is_correct'].mean()
print(f'\n1. INSTRUCTION EFFECT (answer-present conditions A1–A4):')
print(f'   USE: {use_acc:.1%} vs NEUTRAL: {neutral_acc:.1%} (Δ = {use_acc - neutral_acc:+.1%})')

# AUTH vs UNCERT effect over all trace conditions. This pooling is balanced: each
# authority level has two answer-present conditions and one answer-removed one.
auth_acc = df_factors[df_factors['authority'] == 'AUTH']['is_correct'].mean()
uncert_acc = df_factors[df_factors['authority'] == 'UNCERT']['is_correct'].mean()
print(f'\n2. AUTHORITY EFFECT:')
print(f'   AUTH: {auth_acc:.1%} vs UNCERT: {uncert_acc:.1%} (Δ = {auth_acc - uncert_acc:+.1%})')

# Answer Present vs Removed (USE conditions only: A1/A2 vs A5/A6)
df_use = df_factors[df_factors['instruction'] == 'USE']
present_acc = df_use[df_use['answer_present'] == 1]['is_correct'].mean()
removed_acc = df_use[df_use['answer_present'] == 0]['is_correct'].mean()
print(f'\n3. ANSWER REMOVAL EFFECT (USE only):')
print(f'   Present: {present_acc:.1%} vs Removed: {removed_acc:.1%} (Δ = {removed_acc - present_acc:+.1%})')

print('\n' + '='*60)
print('EXPERIMENT A ANALYSIS COMPLETE!')
print('='*60)