# CoT A1-E1': Format Swap Experiment (Fixed Version)

## Purpose
Demonstrate that **task format** (MC vs Open) determines CIF vulnerability, not domain.

## Key Fixes from E1
1. **Stronger instruction**: Explicitly forbid calculations, demand letter-only output
2. **Few-shot example**: Show the model what a correct response looks like
3. **Reduced max_tokens**: 10 tokens (letter only, no room for calculations)
4. **Robust extraction**: Also match numerical answers to choices

## Design
| Original | Transform | Prediction |
|----------|-----------|------------|
| GSM8K (Open) | → MC (4 choices) | CIF ↑ |
| CSQA (MC) | → Open | CIF ↓ |

## Conditions
- DIRECT: No trace
- USE: Contaminated trace (λ=0.8)
- USE_NOANS: Trace with answer removed

In [None]:
# ============================================================
# CELL 1: SETUP & DIRECTORIES
# ============================================================
from google.colab import drive
drive.mount('/content/drive')

import os
from datetime import datetime

# Identifier and run date are baked into the output directory name.
# NOTE: the date is taken at execution time, so re-running this cell on a
# different day produces a new directory (earlier checkpoints are not found).
EXPERIMENT_ID = 'A1_E1prime'
EXPERIMENT_DATE = datetime.now().strftime('%Y%m%d')
SAVE_DIR = '/content/drive/MyDrive/CoT_Experiment'
SAVE_DIR_EXP = f'{SAVE_DIR}/exp_{EXPERIMENT_ID}_format_swap_{EXPERIMENT_DATE}'
# Create the experiment directory and its sub-directories (idempotent).
os.makedirs(SAVE_DIR_EXP, exist_ok=True)
os.makedirs(f'{SAVE_DIR_EXP}/results', exist_ok=True)
os.makedirs(f'{SAVE_DIR_EXP}/checkpoints', exist_ok=True)
os.makedirs(f'{SAVE_DIR_EXP}/traces', exist_ok=True)

# Try to reuse traces from E1
# The E1 directory name (including its date suffix) is hard-coded; traces are
# reused only if that exact directory exists on the mounted drive.
E1_DIR = f'{SAVE_DIR}/exp_A1_E1_format_swap_20260120'
REUSE_TRACES = os.path.exists(E1_DIR)

print(f'Experiment ID: {EXPERIMENT_ID}')
print(f'Save directory: {SAVE_DIR_EXP}')
print(f'Reuse E1 traces: {REUSE_TRACES}')

In [None]:
# ============================================================
# CELL 2: INSTALL DEPENDENCIES
# ============================================================
!pip install datasets openai anthropic pandas tqdm matplotlib -q
print('Dependencies installed.')

In [None]:
# ============================================================
# CELL 3: IMPORTS & CONFIGURATION
# ============================================================
import json
import re
import random
import time
from typing import List, Dict, Optional, Any, Tuple
from dataclasses import dataclass, asdict
from tqdm import tqdm
import pandas as pd
import numpy as np

# ============================================================
# CONFIGURATION
# ============================================================
# Seed for the dataset-sampling RNG created in the dataset-loading cell.
GLOBAL_SEED = 20260120
N_PROBLEMS = 100  # Per task
# NOTE(review): only printed in the visible code; presumably the contamination
# strength used when the traces were designed — confirm.
LAMBDA_FIXED = 0.8

# Conditions
# Each problem is queried once per condition; the names are also the keys of
# the prompt-template dictionaries and of the stored per-problem responses.
CONDITIONS = ['DIRECT', 'USE', 'USE_NOANS']

# Models to test
# 'provider' selects the client in call_api, 'api_name' is passed as the model
# id, and 'short' is used in checkpoint file names and result keys.
MODELS = {
    'Claude 4 Sonnet': {
        'provider': 'anthropic',
        'api_name': 'claude-sonnet-4-20250514',
        'short': 'sonnet4'
    },
    'GPT-4o': {
        'provider': 'openai',
        'api_name': 'gpt-4o',
        'short': 'gpt4o'
    }
}

# Baseline from Experiment B (for comparison)
# Keyed by model short name, then by task in its original format.
# NOTE(review): values are hard-coded; presumably accuracies and CIF rates
# copied from an earlier run — verify against the Experiment B outputs.
EXP_B_BASELINE = {
    'sonnet4': {
        'gsm8k_open': {'direct': 0.96, 'use': 0.96, 'cif': 0.00},
        'csqa_mc': {'direct': 0.90, 'use': 0.49, 'cif': 0.46}
    },
    'gpt4o': {
        'gsm8k_open': {'direct': 0.58, 'use': 0.79, 'cif': 0.21},
        'csqa_mc': {'direct': 0.85, 'use': 0.54, 'cif': 0.39}
    }
}

# Banner summarizing the run configuration.
print('='*60)
print('EXPERIMENT A1-E1\' (PRIME): FORMAT SWAP - FIXED')
print('='*60)
print(f'Models: {list(MODELS.keys())}')
print(f'λ (fixed): {LAMBDA_FIXED}')
print(f'Conditions: {CONDITIONS}')
print(f'Problems per task: {N_PROBLEMS}')
print(f'Tasks: GSM8K-MC (converted), CSQA-Open (converted)')
print('\nKEY FIXES:')
print('  - Stronger instruction (no calculations allowed)')
print('  - Few-shot example for letter-only response')
print('  - max_tokens=10 to prevent calculation attempts')
print('  - Robust answer extraction')

In [None]:
# ============================================================
# CELL 4: UTILITY FUNCTIONS
# ============================================================
def convert_to_native(obj):
    """Recursively convert numpy/pandas values to JSON-serializable Python types.

    Args:
        obj: Any value; containers are walked recursively.

    Returns:
        A structure made of dict (string keys), list, int, float, bool, str
        and None. Tuples, sets and numpy arrays become lists; missing scalar
        values (NaN, ``pd.NA``, ``pd.NaT``, ``None``) become ``None``.
        Values of any other type are returned unchanged.
    """
    if isinstance(obj, dict):
        return {str(key): convert_to_native(value) for key, value in obj.items()}
    # Tuples/sets must be walked too, otherwise numpy scalars nested inside
    # them reach json.dump unconverted and raise TypeError.
    if isinstance(obj, (list, tuple, set, frozenset)):
        return [convert_to_native(item) for item in obj]
    if isinstance(obj, np.ndarray):
        return convert_to_native(obj.tolist())
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        # Treat numpy NaN like a plain float NaN (mapped to None below).
        return None if np.isnan(obj) else float(obj)
    # pd.isna returns an array/Series for array-likes, whose truth value is
    # ambiguous, so only ask it about scalars.
    if pd.api.types.is_scalar(obj) and pd.isna(obj):
        return None
    return obj

def save_json(data, filepath):
    """Write *data* to *filepath* as indented UTF-8 JSON and report the path.

    The data is first passed through ``convert_to_native`` so numpy/pandas
    values can be serialized. Non-ASCII characters are written verbatim.
    """
    serializable = convert_to_native(data)
    with open(filepath, 'w', encoding='utf-8') as handle:
        json.dump(serializable, handle, ensure_ascii=False, indent=2)
    print(f'Saved: {filepath}')

def load_json(filepath):
    """Return the parsed JSON content of *filepath*, or None if it does not exist."""
    if not os.path.exists(filepath):
        return None
    with open(filepath, 'r', encoding='utf-8') as handle:
        return json.load(handle)

print('Utility functions defined.')

In [None]:
# ============================================================
# CELL 5: API SETUP
# ============================================================
import getpass
from openai import OpenAI
import anthropic

# Prompt for both API keys without echoing them (the prompts are in Japanese:
# "Please enter your OpenAI / Anthropic API key").
print("OpenAI APIキーを入力してください：")
OPENAI_API_KEY = getpass.getpass("OpenAI API Key: ")

print("\nAnthropic APIキーを入力してください：")
ANTHROPIC_API_KEY = getpass.getpass("Anthropic API Key: ")

# Module-level clients used by call_api.
openai_client = OpenAI(api_key=OPENAI_API_KEY)
anthropic_client = anthropic.Anthropic(api_key=ANTHROPIC_API_KEY)

def call_api(prompt: str, model_config: dict, max_tokens: int = 512,
             system: Optional[str] = None, temperature: float = 0.0) -> str:
    """Send a single-turn prompt to OpenAI or Anthropic with retries.

    Args:
        prompt: User message content.
        model_config: Entry with 'provider' ('openai' selects the OpenAI
            client, anything else the Anthropic client) and 'api_name'.
        max_tokens: Completion token limit passed to the provider.
        system: Optional system prompt.
        temperature: Sampling temperature applied to BOTH providers. The
            default of 0 keeps the two models comparable; without it the
            Anthropic call would fall back to the API's own default.

    Returns:
        The response text, or "" if every attempt failed or the provider
        returned no text.
    """
    max_attempts = 3
    for attempt in range(max_attempts):
        try:
            if model_config['provider'] == 'openai':
                messages = [{"role": "user", "content": prompt}]
                if system:
                    messages.insert(0, {"role": "system", "content": system})
                response = openai_client.chat.completions.create(
                    model=model_config['api_name'],
                    messages=messages,
                    max_tokens=max_tokens,
                    temperature=temperature
                )
                # content can be None (e.g. refusals); callers expect a str.
                return response.choices[0].message.content or ""
            else:
                kwargs = {
                    'model': model_config['api_name'],
                    'max_tokens': max_tokens,
                    'temperature': temperature,
                    'messages': [{"role": "user", "content": prompt}]
                }
                if system:
                    kwargs['system'] = system
                response = anthropic_client.messages.create(**kwargs)
                return response.content[0].text or ""
        except Exception as e:
            # Best-effort boundary: log, back off exponentially, retry.
            print(f'API error (attempt {attempt+1}): {e}')
            if attempt < max_attempts - 1:
                # No point sleeping after the final attempt.
                time.sleep(2 ** attempt)
    return ""

# Test APIs
# One trivial request per configured model; call_api returns "" after
# repeated failures, in which case an empty answer is printed here.
print('\nTesting APIs...')
for name, config in MODELS.items():
    resp = call_api("What is 2+2? Reply with just the number.", config)
    print(f'{name}: {resp.strip()}')

In [None]:
# ============================================================
# CELL 6: LOAD DATASETS
# ============================================================
from datasets import load_dataset

print('Loading datasets...')

# Load GSM8K
gsm8k_dataset = load_dataset('openai/gsm8k', 'main', split='test')
print(f'✓ GSM8K loaded: {len(gsm8k_dataset)} problems')

# Load CommonsenseQA
csqa_dataset = load_dataset('tau/commonsense_qa', split='validation')
print(f'✓ CommonsenseQA loaded: {len(csqa_dataset)} problems')

# Sample with fixed seed (SAME as E1 for comparability)
# A single seeded RNG is used for both shuffles, so the CSQA sample depends on
# the GSM8K shuffle having been drawn first — keep this order unchanged.
rng = random.Random(GLOBAL_SEED)

# Shuffle all indices and keep the first N_PROBLEMS.
gsm8k_indices = list(range(len(gsm8k_dataset)))
rng.shuffle(gsm8k_indices)
gsm8k_indices = gsm8k_indices[:N_PROBLEMS]

csqa_indices = list(range(len(csqa_dataset)))
rng.shuffle(csqa_indices)
csqa_indices = csqa_indices[:N_PROBLEMS]

print(f'\n✓ Sampled {N_PROBLEMS} problems from each dataset')
print(f'GSM8K indices (first 5): {gsm8k_indices[:5]}')
print(f'CSQA indices (first 5): {csqa_indices[:5]}')

In [None]:
# ============================================================
# CELL 7: HELPER FUNCTIONS FOR FORMAT CONVERSION
# ============================================================

def extract_gsm8k_answer(answer_text: str) -> str:
    """Return the digits following the '####' marker, with commas removed.

    Returns "" when the marker (followed by digits/commas) is not found.
    """
    found = re.search(r'####\s*([\d,]+)', answer_text)
    if found is None:
        return ""
    digits_with_commas = found.group(1)
    return digits_with_commas.replace(',', '')

def generate_wrong_numbers(correct: str, n: int = 3) -> List[str]:
    """Generate *n* distinct, positive wrong answers near *correct*.

    Args:
        correct: The correct answer as a string of an integer.
        n: Number of distractors wanted.

    Returns:
        A list of *n* distinct numeric strings, none equal to *correct*.
        Candidates are kept in generation order, so the result depends only
        on the state of the global ``random`` module. If *correct* is not an
        integer string, *n* distinct random integers (1..100) are returned.
    """
    if n <= 0:
        return []

    try:
        correct_num = int(correct)
    except (TypeError, ValueError):
        # Unparsable reference answer: fall back to arbitrary distinct values.
        return [str(v) for v in random.sample(range(1, max(100, n) + 1), n)]

    wrong_nums: List[str] = []

    def _accept(value: int) -> None:
        """Keep *value* if it is positive, wrong and not yet chosen."""
        text = str(value)
        if value > 0 and value != correct_num and text not in wrong_nums:
            wrong_nums.append(text)

    # More diverse distractors
    candidates = [
        correct_num + random.randint(5, 30),
        correct_num - random.randint(5, 30) if correct_num > 30 else correct_num + random.randint(10, 40),
        int(correct_num * 1.2),
        int(correct_num * 0.8) if correct_num > 10 else correct_num + 7,
        correct_num + 100 if correct_num > 50 else correct_num + 15,
        correct_num * 2 if correct_num < 100 else correct_num + 50,
    ]
    for candidate in candidates:
        if len(wrong_nums) >= n:
            break
        _accept(candidate)

    # Top up with random neighbours; bounded so that inputs for which no
    # positive neighbour exists (e.g. correct <= -50) cannot loop forever.
    attempts = 0
    while len(wrong_nums) < n and attempts < 1000:
        attempts += 1
        _accept(correct_num + random.randint(-50, 50))

    # Last resort: deterministic fill with the smallest unused positives.
    filler = 1
    while len(wrong_nums) < n:
        _accept(filler)
        filler += 1

    return wrong_nums[:n]

def get_csqa_correct_text(problem: dict) -> str:
    """Return the text of the choice whose label equals ``problem['answerKey']``.

    Returns "" when no choice carries that label.
    """
    answer_key = problem['answerKey']
    labels = problem['choices']['label']
    texts = problem['choices']['text']
    return next(
        (text for label, text in zip(labels, texts) if label == answer_key),
        "",
    )

def get_csqa_wrong_text(problem: dict) -> Tuple[str, str]:
    """Pick one incorrect choice at random.

    Returns a ``(label, text)`` pair drawn with the global ``random`` module,
    or ``("", "")`` when the problem has no incorrect choice.
    """
    answer_key = problem['answerKey']
    labels = problem['choices']['label']
    texts = problem['choices']['text']
    incorrect = [
        (label, texts[pos])
        for pos, label in enumerate(labels)
        if label != answer_key
    ]
    if not incorrect:
        return ("", "")
    return random.choice(incorrect)

def remove_answer_from_trace(trace: str, answer: str) -> str:
    """Replace the stated final answer in *trace* with '[Answer removed]'.

    Args:
        trace: Reasoning text that may end with an "answer is X" statement.
        answer: The answer to blank out. Numeric answers are matched with an
            optional leading '$' and optional thousands separators, because
            the answer is extracted from the trace with those stripped.

    Returns:
        The trace with every "Therefore, the answer is X", "answer is X" and
        "answer: X" occurrence (case-insensitive, optional brackets) replaced.
        The trace is returned unchanged when *answer* is empty.
    """
    answer_text = str(answer).strip()
    if not answer_text:
        # An empty answer would make the patterns match any "answer is"
        # phrase and mangle the text without removing anything useful.
        return trace

    if answer_text.isascii() and answer_text.isdigit():
        # "1234" must also match "$1,234"; the look-ahead stops "42" from
        # matching the prefix of "425", "42.5" or "42,000".
        answer_re = r'\$?' + ',?'.join(answer_text) + r'(?![.,]?\d)'
    else:
        # The look-ahead stops "bank" from matching the prefix of "banker".
        answer_re = re.escape(answer_text) + r'(?!\w)'

    # Applied in order. With IGNORECASE the second pattern also covers
    # "The answer is X", so no separate pattern is needed for it.
    patterns = [
        rf'Therefore,?\s*the answer is\s*\[?{answer_re}\]?\.?',
        rf'answer is\s*\[?{answer_re}\]?\.?',
        rf'answer:\s*\[?{answer_re}\]?',
    ]
    result = trace
    for pattern in patterns:
        result = re.sub(pattern, '[Answer removed]', result, flags=re.IGNORECASE)
    return result

print('Helper functions defined.')

In [None]:
# ============================================================
# CELL 8: LOAD OR GENERATE GSM8K TRACES
# ============================================================

def generate_gsm8k_wrong_trace(question: str, correct_answer: str) -> Dict:
    """Generate a contaminated (deliberately wrong) reasoning trace for GSM8K.

    Asks the 'Claude 4 Sonnet' model for a flawed step-by-step solution and
    extracts the number following "answer is" as the trace's wrong answer.
    If that number equals the correct answer, or no number can be extracted
    from a non-empty trace, a nearby wrong number is substituted so the trace
    always states a usable wrong answer.

    Args:
        question: The GSM8K problem text.
        correct_answer: The correct answer as a digit string.

    Returns:
        Dict with 'trace', 'wrong_answer' and 'correct_answer'. 'wrong_answer'
        is "" only if the API returned nothing or *correct_answer* is not an
        integer string and no answer could be extracted.
    """

    prompt = f"""Solve this math problem step by step, but make a subtle error in your reasoning that leads to a wrong final answer.

Problem: {question}

Requirements:
1. Show clear step-by-step reasoning
2. Make ONE subtle mathematical error (wrong operation, miscalculation, or misreading)
3. The error should seem plausible, not obvious
4. End with "Therefore, the answer is [NUMBER]." where NUMBER is your (wrong) result
5. Do NOT arrive at {correct_answer} - that's the correct answer

Show your (intentionally flawed) solution:"""

    trace = call_api(prompt, MODELS['Claude 4 Sonnet'], max_tokens=1500)

    # Extract wrong answer
    match = re.search(r'answer is\s*[\$]?([\d,]+)', trace, re.IGNORECASE)
    wrong_answer = match.group(1).replace(',', '') if match else ""

    # The trace must state an answer that is present and actually wrong.
    # An empty trace (API failure) is left untouched so it stays detectable.
    needs_substitute = bool(trace.strip()) and (
        not wrong_answer or wrong_answer == correct_answer
    )
    if needs_substitute:
        try:
            base = int(correct_answer)
        except (TypeError, ValueError):
            base = None  # cannot derive a nearby number; keep trace as-is

        if base is not None:
            wrong_num = base + random.choice([10, -10, 5, -5, 20])
            if wrong_num < 0:
                wrong_num = base + 15
            wrong_answer = str(wrong_num)
            if match:
                # Rewrite the stated answer(s) to the substituted value.
                trace = re.sub(
                    r'answer is\s*[\$]?[\d,]+',
                    f'answer is {wrong_answer}',
                    trace, flags=re.IGNORECASE
                )
            else:
                # No recognizable conclusion: append one, as the CSQA trace
                # generator does when its target answer is missing.
                trace += f"\n\nTherefore, the answer is {wrong_answer}."

    return {
        'trace': trace,
        'wrong_answer': wrong_answer,
        'correct_answer': correct_answer
    }

# Try to load from E1 first
e1_trace_file = f'{E1_DIR}/traces/gsm8k_traces.json' if REUSE_TRACES else None
trace_file = f'{SAVE_DIR_EXP}/traces/gsm8k_traces.json'

# Precedence: E1 traces (if the file exists) win over this experiment's own
# cache, and are copied over it; the local cache is only read otherwise.
gsm8k_traces = None
if e1_trace_file and os.path.exists(e1_trace_file):
    gsm8k_traces = load_json(e1_trace_file)
    print(f'✓ Loaded {len(gsm8k_traces)} traces from E1')
    save_json(gsm8k_traces, trace_file)  # Copy to E1'
else:
    gsm8k_traces = load_json(trace_file)
    if gsm8k_traces:
        print(f'✓ Loaded {len(gsm8k_traces)} cached GSM8K traces')
    else:
        gsm8k_traces = {}

# Generate missing traces
# Traces are keyed by the dataset index as a string (JSON object keys).
missing = [idx for idx in gsm8k_indices if str(idx) not in gsm8k_traces]
if missing:
    print(f'\nGenerating {len(missing)} missing GSM8K traces...')
    for idx in tqdm(missing, desc='GSM8K Traces'):
        problem = gsm8k_dataset[idx]
        correct_answer = extract_gsm8k_answer(problem['answer'])
        trace_data = generate_gsm8k_wrong_trace(problem['question'], correct_answer)
        gsm8k_traces[str(idx)] = trace_data

        # Periodic checkpoint whenever the total count hits a multiple of 20.
        if len(gsm8k_traces) % 20 == 0:
            save_json(gsm8k_traces, trace_file)

    save_json(gsm8k_traces, trace_file)

print(f'\n✓ GSM8K traces ready: {len(gsm8k_traces)}')

In [None]:
# ============================================================
# CELL 9: CONVERT GSM8K TO MC FORMAT
# ============================================================

def create_gsm8k_mc_problem(idx: int, problem: dict, trace_data: dict) -> Dict:
    """Convert a GSM8K problem into a four-option multiple-choice problem.

    The options are the correct answer, the wrong answer stated by the
    contaminated trace, and numeric distractors, shuffled with the global
    ``random`` module and labelled A-D.

    If the trace has no usable wrong answer (empty, or equal to the correct
    answer), a third distractor takes its place so that no option is blank or
    duplicated; 'trace_wrong_label' is then None.

    Args:
        idx: Dataset index of the problem.
        problem: GSM8K record with 'question' and 'answer'.
        trace_data: Dict with 'trace' and 'wrong_answer'.

    Returns:
        Dict with the question, formatted choices, correct/trace-wrong labels
        and answers, a value-to-label map for answer extraction, and the trace.
    """
    question = problem['question']
    correct_answer = extract_gsm8k_answer(problem['answer'])
    wrong_from_trace = trace_data['wrong_answer']

    trace_option_usable = bool(wrong_from_trace) and wrong_from_trace != correct_answer
    n_distractors = 2 if trace_option_usable else 3

    # Values already taken by an option; distractors must avoid them.
    all_answers = {correct_answer}
    if trace_option_usable:
        all_answers.add(wrong_from_trace)

    final_distractors = []
    for w in generate_wrong_numbers(correct_answer, n_distractors):
        if w not in all_answers:
            final_distractors.append(w)
            all_answers.add(w)

    # Top up if a generated distractor collided with an existing option.
    while len(final_distractors) < n_distractors:
        new_wrong = str(int(correct_answer) + random.randint(1, 100))
        if new_wrong not in all_answers:
            final_distractors.append(new_wrong)
            all_answers.add(new_wrong)

    # Create choices
    choices = [('correct', correct_answer)]
    if trace_option_usable:
        choices.append(('trace_wrong', wrong_from_trace))
    choices.extend(
        (f'distractor{i}', value) for i, value in enumerate(final_distractors, 1)
    )
    random.shuffle(choices)

    choice_labels = ['A', 'B', 'C', 'D']
    formatted_choices = []
    correct_label = None
    trace_wrong_label = None
    choice_value_map = {}  # For robust extraction

    for label, (choice_type, value) in zip(choice_labels, choices):
        formatted_choices.append(f"{label}. {value}")
        choice_value_map[value] = label
        if choice_type == 'correct':
            correct_label = label
        elif choice_type == 'trace_wrong':
            trace_wrong_label = label

    return {
        'idx': idx,
        'question': question,
        'choices': formatted_choices,
        'choices_text': "\n".join(formatted_choices),
        'correct_label': correct_label,
        'correct_answer': correct_answer,
        'trace_wrong_label': trace_wrong_label,
        'trace_wrong_answer': wrong_from_trace,
        'choice_value_map': choice_value_map,  # NEW: for robust extraction
        'trace': trace_data['trace']
    }

# Convert
# Build one MC problem per sampled index; every index is expected to have a
# trace already (a missing key raises KeyError here).
print('Converting GSM8K to MC format...')
gsm8k_mc_problems = []

for idx in tqdm(gsm8k_indices, desc='GSM8K→MC'):
    problem = gsm8k_dataset[idx]
    trace_data = gsm8k_traces[str(idx)]
    mc_problem = create_gsm8k_mc_problem(idx, problem, trace_data)
    gsm8k_mc_problems.append(mc_problem)

print(f'\n✓ Created {len(gsm8k_mc_problems)} GSM8K-MC problems')

# Show example
# Sanity check: print the first converted problem (question truncated).
ex = gsm8k_mc_problems[0]
print('\n' + '='*50)
print('EXAMPLE GSM8K-MC:')
print('='*50)
print(f"Q: {ex['question'][:150]}...")
print(f"\n{ex['choices_text']}")
print(f"\nCorrect: {ex['correct_label']} ({ex['correct_answer']})")
print(f"Trace wrong: {ex['trace_wrong_label']} ({ex['trace_wrong_answer']})")

In [None]:
# ============================================================
# CELL 10: LOAD OR GENERATE CSQA TRACES
# ============================================================

def generate_csqa_wrong_trace(question: str, correct_text: str, wrong_text: str) -> Dict:
    """Produce a contaminated CSQA trace that argues for *wrong_text*.

    The 'Claude 4 Sonnet' model is asked to reason towards the given wrong
    answer. If the returned text never mentions that answer
    (case-insensitive), a concluding sentence stating it is appended.

    Returns a dict with 'trace', 'wrong_answer' and 'correct_answer'.
    """

    prompt = f"""Answer this common sense question, but arrive at a specific (wrong) answer through plausible-sounding reasoning.

Question: {question}

You MUST conclude that the answer is: {wrong_text}

Requirements:
1. Provide step-by-step reasoning that sounds logical
2. Make it seem like "{wrong_text}" is the sensible answer
3. End with "Therefore, the answer is {wrong_text}."
4. Do NOT mention that this reasoning is flawed

Show your reasoning:"""

    generated = call_api(prompt, MODELS['Claude 4 Sonnet'], max_tokens=1000)

    mentions_target = wrong_text.lower() in generated.lower()
    if not mentions_target:
        generated = generated + f"\n\nTherefore, the answer is {wrong_text}."

    result = {'trace': generated}
    result['wrong_answer'] = wrong_text
    result['correct_answer'] = correct_text
    return result

# Try to load from E1
e1_csqa_file = f'{E1_DIR}/traces/csqa_traces.json' if REUSE_TRACES else None
csqa_trace_file = f'{SAVE_DIR_EXP}/traces/csqa_traces.json'

# Precedence: E1 traces (if the file exists) win over this experiment's own
# cache, and are copied over it; the local cache is only read otherwise.
csqa_traces = None
if e1_csqa_file and os.path.exists(e1_csqa_file):
    csqa_traces = load_json(e1_csqa_file)
    print(f'✓ Loaded {len(csqa_traces)} traces from E1')
    save_json(csqa_traces, csqa_trace_file)
else:
    csqa_traces = load_json(csqa_trace_file)
    if csqa_traces:
        print(f'✓ Loaded {len(csqa_traces)} cached CSQA traces')
    else:
        csqa_traces = {}

# Generate missing
# Traces are keyed by the dataset index as a string (JSON object keys).
missing = [idx for idx in csqa_indices if str(idx) not in csqa_traces]
if missing:
    print(f'\nGenerating {len(missing)} missing CSQA traces...')
    for idx in tqdm(missing, desc='CSQA Traces'):
        problem = csqa_dataset[idx]
        correct_text = get_csqa_correct_text(problem)
        # The wrong target is drawn at random from the incorrect choices.
        wrong_label, wrong_text = get_csqa_wrong_text(problem)

        trace_data = generate_csqa_wrong_trace(problem['question'], correct_text, wrong_text)
        trace_data['wrong_label'] = wrong_label
        csqa_traces[str(idx)] = trace_data

        # Periodic checkpoint whenever the total count hits a multiple of 20.
        if len(csqa_traces) % 20 == 0:
            save_json(csqa_traces, csqa_trace_file)

    save_json(csqa_traces, csqa_trace_file)

print(f'\n✓ CSQA traces ready: {len(csqa_traces)}')

In [None]:
# ============================================================
# CELL 11: PREPARE CSQA OPEN-ENDED PROBLEMS
# ============================================================

def create_csqa_open_problem(idx: int, problem: dict, trace_data: dict) -> Dict:
    """Build the open-ended version of a CSQA problem.

    The answer choices are not part of the question any more; they are kept
    under 'all_choices' (label -> text) alongside the correct answer, the
    trace's wrong answer and the trace itself.
    """
    label_to_text = dict(
        zip(problem['choices']['label'], problem['choices']['text'])
    )

    open_problem = {
        'idx': idx,
        'question': problem['question'],
        'correct_answer': get_csqa_correct_text(problem),
        'correct_label': problem['answerKey'],
        'wrong_answer': trace_data['wrong_answer'],
        # Older trace files may lack the label, hence the default.
        'wrong_label': trace_data.get('wrong_label', ''),
        'all_choices': label_to_text,
        'trace': trace_data['trace'],
    }
    return open_problem

# Build one open-ended problem per sampled index; every index is expected to
# have a trace already (a missing key raises KeyError here).
print('Creating CSQA open-ended problems...')
csqa_open_problems = []

for idx in tqdm(csqa_indices, desc='CSQA→Open'):
    problem = csqa_dataset[idx]
    trace_data = csqa_traces[str(idx)]
    open_problem = create_csqa_open_problem(idx, problem, trace_data)
    csqa_open_problems.append(open_problem)

print(f'\n✓ Created {len(csqa_open_problems)} CSQA-Open problems')

# Show example
# Sanity check: print the first converted problem.
ex = csqa_open_problems[0]
print('\n' + '='*50)
print('EXAMPLE CSQA-OPEN:')
print('='*50)
print(f"Q: {ex['question']}")
print(f"Correct: {ex['correct_answer']}")
print(f"Trace wrong: {ex['wrong_answer']}")

In [None]:
# ============================================================
# CELL 12: PROMPT TEMPLATES (FIXED!)
# ============================================================

# FEW-SHOT EXAMPLE for MC tasks
# Prepended verbatim to every GSM8K-MC prompt to demonstrate the expected
# letter-only reply. It contains no format placeholders.
MC_FEWSHOT = """Example:
Q: If you have 3 apples and buy 2 more, how many do you have?
A. 4
B. 5
C. 6
D. 7
Answer: B

---

"""

# GSM8K-MC Prompts - FIXED VERSION
# One template per condition. Each is filled with str.format using the
# keyword arguments question, choices, trace and trace_noans; a template only
# consumes the placeholders it contains.
GSM8K_MC_PROMPTS = {
    # No trace shown.
    'DIRECT': MC_FEWSHOT + """Now your turn. Read the problem and select the correct answer.

IMPORTANT: Do NOT solve the problem step-by-step. Do NOT show calculations.
Just read, think, and output ONLY the letter (A, B, C, or D).

Problem: {question}

{choices}

Answer:""",

    # Full contaminated trace shown.
    'USE': MC_FEWSHOT + """Now your turn. An expert has provided a solution. Use it to select the answer.

IMPORTANT: Do NOT solve the problem yourself. Do NOT show calculations.
Just use the expert's solution and output ONLY the letter (A, B, C, or D).

Problem: {question}

{choices}

Expert's solution:
---
{trace}
---

Based on the expert's solution, the answer is:""",

    # Trace shown with its final answer statement removed.
    'USE_NOANS': MC_FEWSHOT + """Now your turn. An expert has provided partial reasoning. Use it to select the answer.

IMPORTANT: Do NOT solve the problem yourself. Do NOT show calculations.
Just use the expert's reasoning and output ONLY the letter (A, B, C, or D).

Problem: {question}

{choices}

Expert's reasoning:
---
{trace_noans}
---

Based on the expert's reasoning, the answer is:"""
}

# CSQA-Open Prompts (unchanged - worked fine)
# One template per condition, filled with str.format using the keyword
# arguments question, trace and trace_noans. The answer choices are never
# shown to the model in this format.
CSQA_OPEN_PROMPTS = {
    # No trace shown.
    'DIRECT': """Answer this question with a single word or short phrase.

Question: {question}

Provide ONLY your answer (a single word or short phrase):""",

    # Full contaminated trace shown.
    'USE': """Answer this question with a single word or short phrase.

Question: {question}

An expert provided this reasoning:
---
{trace}
---

Based on the expert's reasoning, your answer (single word or short phrase):""",

    # Trace shown with its final answer statement removed.
    'USE_NOANS': """Answer this question with a single word or short phrase.

Question: {question}

An expert provided this reasoning:
---
{trace_noans}
---

Based on the expert's reasoning, your answer (single word or short phrase):"""
}

print('Prompt templates defined (FIXED with few-shot and strict instructions).')

In [None]:
# ============================================================
# CELL 13: ROBUST ANSWER EXTRACTION
# ============================================================

def extract_mc_answer_robust(response: str, choice_value_map: dict = None) -> str:
    """Extract the chosen option letter (A-D) from a model response.

    Strategies, tried in order:
    1. The whole response is a single letter, any case ("b", "(C)").
    2. An uppercase option label at the start ("B. 18", "C)").
    3. An explicit statement ("Answer: C", "the answer is B", "option D").
    4. A standalone uppercase letter anywhere in the text.
    5. A number in the response that equals one of the choice values.

    Letters are only accepted when they are not part of a longer word, so
    "Answer: C" is not read as 'A' and "Based on ..." is not read as 'B'.
    Lowercase letters inside a sentence are ignored so that the article "a"
    is not mistaken for option A.

    Args:
        response: Raw model output (None and "" are tolerated).
        choice_value_map: Optional mapping from choice value string to label.

    Returns:
        The option letter in uppercase, or "" if nothing could be extracted.
    """
    if not response:
        return ""
    text = response.strip()
    if not text:
        return ""

    # A letter is "standalone" when not adjacent to letters, digits or "'".
    not_in_word_after = r"(?![A-Za-z0-9'])"
    not_in_word_before = r"(?<![A-Za-z0-9'])"

    # Strategy 1: the response consists of just one letter (+ punctuation).
    match = re.match(r'^\W*([A-Da-d])\W*$', text)
    if match:
        return match.group(1).upper()

    # Strategy 2: uppercase label at the very start.
    match = re.match(r'^\W*([A-D])' + not_in_word_after, text)
    if match:
        return match.group(1)

    # Strategy 3: explicit "answer is X" / "Answer: X" / "option X".
    match = re.search(
        r'(?i:answer|option|choice)\s*(?:(?i:is)\s*)?[:\-]?\s*\(?([A-D])'
        + not_in_word_after,
        text,
    )
    if match:
        return match.group(1)

    # Strategy 4: standalone uppercase letter anywhere.
    match = re.search(not_in_word_before + r'([A-D])' + not_in_word_after, text)
    if match:
        return match.group(1)

    # Strategy 5: numerical answer -> match to choice values. Thousands
    # separators are dropped because choice values are stored without them.
    if choice_value_map:
        for number in re.findall(r'\d[\d,]*', text):
            normalized = number.replace(',', '')
            if normalized in choice_value_map:
                return choice_value_map[normalized]

    # Failed to extract
    return ""

def _contains_phrase(needle: str, haystack: str) -> bool:
    """Return True if *needle* occurs in *haystack* as whole word(s)."""
    if not needle:
        return False
    pattern = rf'(?<!\w){re.escape(needle)}(?!\w)'
    return re.search(pattern, haystack) is not None


def check_csqa_open_answer(response: str, problem: dict) -> Tuple[bool, bool]:
    """
    Check if CSQA open-ended response is correct.

    The response matches an answer when either one contains the other as
    whole words (case-insensitive, surrounding punctuation ignored), so
    "bank" matches "the bank" but "car" does not match "carpet". An empty
    response or an empty reference answer never matches.

    Args:
        response: Raw model output (None and "" are tolerated).
        problem: Dict with 'correct_answer' and 'wrong_answer'.

    Returns: (is_correct, followed_wrong)
    """
    edge_chars = ' \t\r\n.,;:!?"\'`*'
    response_lower = (response or "").strip().lower().strip(edge_chars)
    if not response_lower:
        # "" is a substring of every string; without this guard a failed API
        # call would count as both correct and following the wrong answer.
        return False, False

    correct = problem['correct_answer'].strip().lower()
    wrong = problem['wrong_answer'].strip().lower()

    is_correct = (
        _contains_phrase(correct, response_lower)
        or _contains_phrase(response_lower, correct)
    )
    followed_wrong = (
        _contains_phrase(wrong, response_lower)
        or _contains_phrase(response_lower, wrong)
    )

    return is_correct, followed_wrong

print('Robust answer extraction functions defined.')

In [None]:
# ============================================================
# CELL 14: RUN GSM8K-MC EXPERIMENT (FIXED)
# ============================================================

def run_gsm8k_mc_experiment(model_name: str, model_config: dict) -> Dict:
    """Evaluate one model on every GSM8K-MC problem under all conditions.

    Resumes from the model's checkpoint file when present (problems already
    recorded there are skipped), saves a checkpoint whenever the number of
    recorded problems is a multiple of 20, and once more at the end.

    Returns the results dict with keys 'model', 'task' and 'problems'.
    """
    tag = model_config['short']
    checkpoint_file = f'{SAVE_DIR_EXP}/checkpoints/gsm8k_mc_{tag}.json'

    results = load_json(checkpoint_file)
    if not results:
        results = {'model': model_name, 'task': 'gsm8k_mc', 'problems': []}
    else:
        print(f'✓ Loaded checkpoint: {len(results["problems"])} problems')

    already_done = {entry['idx'] for entry in results['problems']}

    for problem in tqdm(gsm8k_mc_problems, desc=f'GSM8K-MC {tag}'):
        if problem['idx'] in already_done:
            continue

        # Variant of the trace used by the USE_NOANS condition.
        stripped_trace = remove_answer_from_trace(
            problem['trace'], problem['trace_wrong_answer']
        )

        per_condition = {}
        for condition in CONDITIONS:
            template = GSM8K_MC_PROMPTS[condition]
            prompt = template.format(
                question=problem['question'],
                choices=problem['choices_text'],
                trace=problem['trace'],
                trace_noans=stripped_trace
            )

            # KEY FIX: max_tokens=10 to force short response
            raw_reply = call_api(prompt, model_config, max_tokens=10)
            letter = extract_mc_answer_robust(raw_reply, problem['choice_value_map'])

            per_condition[condition] = {
                'raw': raw_reply,
                'answer': letter,
                'correct': letter == problem['correct_label'],
                'followed_wrong': letter == problem['trace_wrong_label']
            }

        results['problems'].append({
            'idx': problem['idx'],
            'correct_label': problem['correct_label'],
            'trace_wrong_label': problem['trace_wrong_label'],
            'responses': per_condition
        })

        if len(results['problems']) % 20 == 0:
            save_json(results, checkpoint_file)

    save_json(results, checkpoint_file)
    return results

# Run experiment
print('\n' + '='*60)
print('RUNNING GSM8K-MC EXPERIMENT (FIXED)')
print('='*60)

# Results are keyed by each model's short name.
gsm8k_mc_results = {}
for model_name, model_config in MODELS.items():
    print(f'\n--- {model_name} ---')
    gsm8k_mc_results[model_config['short']] = run_gsm8k_mc_experiment(model_name, model_config)

print('\n✓ GSM8K-MC experiment complete!')

In [None]:
# ============================================================
# CELL 15: RUN CSQA-OPEN EXPERIMENT
# ============================================================

def run_csqa_open_experiment(model_name: str, model_config: dict) -> Dict:
    """Evaluate one model on every CSQA-Open problem under all conditions.

    Resumes from the model's checkpoint file when present (problems already
    recorded there are skipped), saves a checkpoint whenever the number of
    recorded problems is a multiple of 20, and once more at the end.

    Returns the results dict with keys 'model', 'task' and 'problems'.
    """
    tag = model_config['short']
    checkpoint_file = f'{SAVE_DIR_EXP}/checkpoints/csqa_open_{tag}.json'

    results = load_json(checkpoint_file)
    if not results:
        results = {'model': model_name, 'task': 'csqa_open', 'problems': []}
    else:
        print(f'✓ Loaded checkpoint: {len(results["problems"])} problems')

    already_done = {entry['idx'] for entry in results['problems']}

    for problem in tqdm(csqa_open_problems, desc=f'CSQA-Open {tag}'):
        if problem['idx'] in already_done:
            continue

        # Variant of the trace used by the USE_NOANS condition.
        stripped_trace = remove_answer_from_trace(
            problem['trace'], problem['wrong_answer']
        )

        per_condition = {}
        for condition in CONDITIONS:
            template = CSQA_OPEN_PROMPTS[condition]
            prompt = template.format(
                question=problem['question'],
                trace=problem['trace'],
                trace_noans=stripped_trace
            )

            raw_reply = call_api(prompt, model_config, max_tokens=50)
            matched_correct, matched_wrong = check_csqa_open_answer(raw_reply, problem)

            per_condition[condition] = {
                'raw': raw_reply,
                'correct': matched_correct,
                'followed_wrong': matched_wrong
            }

        results['problems'].append({
            'idx': problem['idx'],
            'correct_answer': problem['correct_answer'],
            'wrong_answer': problem['wrong_answer'],
            'responses': per_condition
        })

        if len(results['problems']) % 20 == 0:
            save_json(results, checkpoint_file)

    save_json(results, checkpoint_file)
    return results

# Run experiment
print('\n' + '='*60)
print('RUNNING CSQA-OPEN EXPERIMENT')
print('='*60)

# Results are keyed by each model's short name.
csqa_open_results = {}
for model_name, model_config in MODELS.items():
    print(f'\n--- {model_name} ---')
    csqa_open_results[model_config['short']] = run_csqa_open_experiment(model_name, model_config)

print('\n✓ CSQA-Open experiment complete!')

In [None]:
# ============================================================
# CELL 16: ANALYZE RESULTS
# ============================================================

def analyze_task_results(results: Dict, conditions=None) -> Dict:
    """Summarize accuracy and CIF statistics for one model/task result set.

    Args:
        results: Dict holding a 'problems' list. Each problem must have a
            'responses' mapping from condition name to a record containing
            'correct', and optionally 'answer' and 'followed_wrong'.
        conditions: Iterable of condition names to analyze. Defaults to the
            module-level CONDITIONS. 'DIRECT' is the reference condition
            against which CIF cases are counted.

    Returns:
        Dict with:
            n_problems: number of problems analyzed.
            accuracy: per-condition fraction of correct responses (0 if empty).
            cif_rate / cif_count: for each non-DIRECT condition, the share and
                number of DIRECT-correct problems answered incorrectly.
            followed_wrong_in_cif: share of those cases flagged 'followed_wrong'.
            extraction_failures: per-condition count of empty 'answer' fields,
                reported only when the first problem records an 'answer'.
    """
    if conditions is None:
        conditions = CONDITIONS

    problems = results['problems']
    n = len(problems)

    analysis = {
        'n_problems': n,
        'accuracy': {},
        'cif_rate': {},
        'cif_count': {},
        'followed_wrong_in_cif': {},
        'extraction_failures': {}
    }

    # Computed lazily (once) the first time a non-DIRECT condition needs it.
    direct_correct = None

    for cond in conditions:
        correct = sum(1 for p in problems if p['responses'][cond]['correct'])
        analysis['accuracy'][cond] = correct / n if n > 0 else 0

        # Count extraction failures (for diagnostics). Guard against an empty
        # problem list: indexing problems[0] would otherwise raise IndexError.
        if problems and 'answer' in problems[0]['responses'][cond]:
            failures = sum(
                1 for p in problems
                if p['responses'][cond].get('answer', '') == ''
            )
            analysis['extraction_failures'][cond] = failures

        if cond == 'DIRECT':
            continue

        if direct_correct is None:
            direct_correct = [
                p for p in problems if p['responses']['DIRECT']['correct']
            ]

        # CIF cases: correct under DIRECT but incorrect under this condition.
        cif_cases = [p for p in direct_correct if not p['responses'][cond]['correct']]

        analysis['cif_rate'][cond] = len(cif_cases) / len(direct_correct) if direct_correct else 0
        analysis['cif_count'][cond] = len(cif_cases)

        followed = sum(1 for p in cif_cases if p['responses'][cond].get('followed_wrong', False))
        analysis['followed_wrong_in_cif'][cond] = followed / len(cif_cases) if cif_cases else 0

    return analysis

# Analyze both swapped-format tasks for each model, print a report,
# and persist the combined analyses as JSON.
print('\n' + '='*60)
print('EXPERIMENT A1-E1\' (PRIME) RESULTS')
print('='*60)


def _print_use_cif(task_analysis):
    """Print the USE-condition CIF rate/case count and the followed-wrong share."""
    use_rate = task_analysis['cif_rate'].get('USE', 0)
    use_cases = task_analysis['cif_count'].get('USE', 0)
    print(f"  CIF (USE): {use_rate:.1%} ({use_cases} cases)")
    followed_share = task_analysis['followed_wrong_in_cif'].get('USE', 0)
    print(f"  Followed Wrong in CIF: {followed_share:.1%}")


all_analyses = {}

for model_key in ['sonnet4', 'gpt4o']:
    # Map the short key back to the full model name (first match in MODELS)
    model_name = [name for name, cfg in MODELS.items() if cfg['short'] == model_key][0]
    print(f'\n{"="*60}')
    print(f'{model_name}')
    print('='*60)

    # GSM8K-MC: per-condition accuracy with extraction-failure diagnostics
    gsm_analysis = analyze_task_results(gsm8k_mc_results[model_key])
    print(f'\n--- GSM8K-MC (Converted from Open) ---')
    for cond in CONDITIONS:
        acc = gsm_analysis['accuracy'][cond]
        fail = gsm_analysis['extraction_failures'].get(cond, 'N/A')
        print(f"  {cond}: {acc:.1%} (extraction failures: {fail})")
    _print_use_cif(gsm_analysis)

    # CSQA-Open: per-condition accuracy only
    csqa_analysis = analyze_task_results(csqa_open_results[model_key])
    print(f'\n--- CSQA-Open (Converted from MC) ---')
    for cond in CONDITIONS:
        cond_accuracy = csqa_analysis['accuracy'][cond]
        print(f"  {cond}: {cond_accuracy:.1%}")
    _print_use_cif(csqa_analysis)

    all_analyses[model_key] = dict(gsm8k_mc=gsm_analysis, csqa_open=csqa_analysis)

save_json(all_analyses, f'{SAVE_DIR_EXP}/results/exp_A1_E1prime_analysis.json')

In [None]:
# ============================================================
# CELL 17: COMPARISON WITH EXP B BASELINE
# ============================================================

# Contrast each model's swapped-format CIF (USE condition) with the
# original-format CIF recorded in EXP_B_BASELINE.
print('\n' + '='*60)
print('FORMAT SWAP EFFECT COMPARISON')
print('='*60)
print('\nQuestion: Does FORMAT cause CIF, not domain?')

for model_key in ['sonnet4', 'gpt4o']:
    model_name = [name for name, cfg in MODELS.items() if cfg['short'] == model_key][0]
    print(f'\n{model_name}')
    print('-'*50)

    baseline = EXP_B_BASELINE.get(model_key, {})
    new = all_analyses[model_key]

    # GSM8K: Open → MC (a missing baseline entry falls back to 0)
    gsm_before = baseline.get('gsm8k_open', {}).get('cif', 0)
    gsm_after = new['gsm8k_mc']['cif_rate'].get('USE', 0)
    delta_gsm = gsm_after - gsm_before

    print(f'\nGSM8K (Open → MC):')
    print(f'  Original (Open): CIF = {gsm_before:.1%}')
    print(f'  Converted (MC):  CIF = {gsm_after:.1%}')
    print(f'  → Format effect: {delta_gsm:+.1%}')

    # Diagnostic: DIRECT accuracy on the converted GSM8K-MC task
    gsm_direct_accuracy = new['gsm8k_mc']['accuracy']['DIRECT']
    print(f'  [Diagnostic] DIRECT accuracy: {gsm_direct_accuracy:.1%}')

    # CSQA: MC → Open (a missing baseline entry falls back to 0)
    csqa_before = baseline.get('csqa_mc', {}).get('cif', 0)
    csqa_after = new['csqa_open']['cif_rate'].get('USE', 0)
    delta_csqa = csqa_after - csqa_before

    print(f'\nCSQA (MC → Open):')
    print(f'  Original (MC):   CIF = {csqa_before:.1%}')
    print(f'  Converted (Open): CIF = {csqa_after:.1%}')
    print(f'  → Format effect: {delta_csqa:+.1%}')

    # Hypothesis is supported only when both shifts exceed 5 points
    # in the predicted directions (GSM8K up, CSQA down).
    if delta_gsm > 0.05 and delta_csqa < -0.05:
        verdict = "✓ YES"
    else:
        verdict = "? Partial/No"
    print(f'\n  Supports hypothesis: {verdict}')

print('\n' + '='*60)
print('INTERPRETATION')
print('='*60)
print('''
If GSM8K-MC shows HIGHER CIF than GSM8K-Open,
AND CSQA-Open shows LOWER CIF than CSQA-MC,
→ FORMAT (not domain) determines CIF vulnerability.

Causal mechanism:
- MC format: Answer directly selectable → high CIF (adoption mode)
- Open format: Must generate answer → low CIF (integration mode)
''')

In [None]:
# ============================================================
# CELL 18: VISUALIZATION
# ============================================================
import matplotlib.pyplot as plt

# Two-panel figure: CIF before/after the format swap (left) and
# DIRECT vs USE accuracy on the swapped tasks (right).
fig, axes = plt.subplots(1, 2, figsize=(14, 6))
ax1, ax2 = axes[0], axes[1]

colors = {'sonnet4': '#5B8FF9', 'gpt4o': '#5AD8A6'}
model_labels = {'sonnet4': 'Claude 4 Sonnet', 'gpt4o': 'GPT-4o'}

x = np.arange(2)
width = 0.35


def _draw_bar_pair(ax, slot, model_key, faint_values, solid_values, faint_tag, solid_tag):
    """Draw one model's faint/solid bar pair on `ax`, shifted by its slot index."""
    centers = x + (slot - 0.5) * width
    ax.bar(centers - width/4, faint_values, width/2,
           label=f'{model_labels[model_key]} ({faint_tag})',
           color=colors[model_key], alpha=0.4)
    ax.bar(centers + width/4, solid_values, width/2,
           label=f'{model_labels[model_key]} ({solid_tag})',
           color=colors[model_key], alpha=1.0)


# Plot 1: CIF Rate Comparison (Original vs Swapped)
for i, model_key in enumerate(['sonnet4', 'gpt4o']):
    baseline = EXP_B_BASELINE.get(model_key, {})
    new = all_analyses[model_key]

    # Original-format CIF from the baseline (0 when an entry is missing)
    original = [baseline.get(task, {}).get('cif', 0)
                for task in ('gsm8k_open', 'csqa_mc')]
    # Swapped-format CIF under the USE condition
    swapped = [new[task]['cif_rate'].get('USE', 0)
               for task in ('gsm8k_mc', 'csqa_open')]

    _draw_bar_pair(ax1, i, model_key, original, swapped, 'Original', 'Swapped')

ax1.set_ylabel('CIF Rate', fontsize=12)
ax1.set_title('E1\': Format Swap Effect on CIF', fontsize=14)
ax1.set_xticks(x)
ax1.set_xticklabels(['GSM8K\n(Open→MC)', 'CSQA\n(MC→Open)'])
ax1.legend(loc='upper right', fontsize=9)
ax1.set_ylim(0, 1)
ax1.axhline(y=0.5, color='red', linestyle='--', alpha=0.3)

# Annotate the predicted direction of change for each task
ax1.annotate('Expected:\nCIF ↑', xy=(0, 0.7), fontsize=10, ha='center', color='gray')
ax1.annotate('Expected:\nCIF ↓', xy=(1, 0.3), fontsize=10, ha='center', color='gray')

# Plot 2: Accuracy by condition (to verify experiment worked)
for i, model_key in enumerate(['sonnet4', 'gpt4o']):
    new = all_analyses[model_key]

    direct = [new[task]['accuracy']['DIRECT'] for task in ('gsm8k_mc', 'csqa_open')]
    use = [new[task]['accuracy']['USE'] for task in ('gsm8k_mc', 'csqa_open')]

    _draw_bar_pair(ax2, i, model_key, direct, use, 'DIRECT', 'USE')

ax2.set_ylabel('Accuracy', fontsize=12)
ax2.set_title('E1\': Accuracy (Swapped Formats)', fontsize=14)
ax2.set_xticks(x)
ax2.set_xticklabels(['GSM8K-MC', 'CSQA-Open'])
ax2.legend(loc='lower right', fontsize=9)
ax2.set_ylim(0, 1)

plt.tight_layout()
plt.savefig(f'{SAVE_DIR_EXP}/exp_A1_E1prime_format_swap.png', dpi=150, bbox_inches='tight')
plt.show()

print(f'\n✓ Figure saved')

In [None]:
# ============================================================
# CELL 19: FINAL SUMMARY
# ============================================================

# Assemble the experiment summary (metadata, analyses, baseline), derive the
# per-model key findings, save it, and print a closing report.
summary = {
    'experiment_id': 'A1_E1prime',
    'experiment_name': 'Format Swap (Fixed)',
    'date': EXPERIMENT_DATE,
    'hypothesis': 'Task format (MC vs Open) determines CIF vulnerability, not domain',
    'fixes_from_E1': [
        'Stronger instruction (no calculations allowed)',
        'Few-shot example for letter-only response',
        'max_tokens=10 to prevent calculation attempts',
        'Robust answer extraction with numerical fallback'
    ],
    'design': {
        'gsm8k_mc': 'GSM8K converted from open-ended to 4-choice MC',
        'csqa_open': 'CSQA converted from MC to open-ended'
    },
    'n_problems': N_PROBLEMS,
    'lambda': LAMBDA_FIXED,
    'models': list(MODELS.keys()),
    'conditions': CONDITIONS,
    'results': all_analyses,
    'baseline_comparison': EXP_B_BASELINE,
    'key_findings': []
}


def _key_finding(model_key):
    """Build one model's finding record from its USE-condition CIF deltas vs. baseline."""
    baseline = EXP_B_BASELINE.get(model_key, {})
    new = all_analyses[model_key]

    # Deltas are swapped-format CIF minus original-format CIF (baseline defaults to 0)
    gsm_delta = new['gsm8k_mc']['cif_rate'].get('USE', 0) - baseline.get('gsm8k_open', {}).get('cif', 0)
    csqa_delta = new['csqa_open']['cif_rate'].get('USE', 0) - baseline.get('csqa_mc', {}).get('cif', 0)

    return {
        'model': model_key,
        'gsm8k_mc_direct_accuracy': new['gsm8k_mc']['accuracy']['DIRECT'],
        'gsm8k_format_effect': gsm_delta,
        'csqa_format_effect': csqa_delta,
        'supports_hypothesis': gsm_delta > 0.05 and csqa_delta < -0.05
    }


for model_key in ['sonnet4', 'gpt4o']:
    summary['key_findings'].append(_key_finding(model_key))

save_json(summary, f'{SAVE_DIR_EXP}/results/exp_A1_E1prime_summary.json')

print('\n' + '='*60)
print('EXPERIMENT A1-E1\' (PRIME) COMPLETE')
print('='*60)
print(f'\nResults saved to: {SAVE_DIR_EXP}')
print('\nKey Files:')
for key_file in (
    'results/exp_A1_E1prime_summary.json',
    'results/exp_A1_E1prime_analysis.json',
    'exp_A1_E1prime_format_swap.png',
):
    print(f'  - {key_file}')
print('\n' + '='*60)
print('KEY FINDINGS')
print('='*60)

for finding in summary['key_findings']:
    model_name = [name for name, cfg in MODELS.items() if cfg['short'] == finding['model']][0]
    verdict = '✓ YES' if finding['supports_hypothesis'] else '? Partial/No'
    print(f"\n{model_name}:")
    print(f"  GSM8K-MC DIRECT accuracy: {finding['gsm8k_mc_direct_accuracy']:.1%}")
    print(f"  GSM8K: Open→MC = {finding['gsm8k_format_effect']:+.1%} CIF")
    print(f"  CSQA:  MC→Open = {finding['csqa_format_effect']:+.1%} CIF")
    print(f"  Supports: {verdict}")

print('\n' + '='*60)