# E4': Late-Protected at c=0.4 - Matched Condition Comparison with E2

**Paper**: A2 (Cue-Dominant Extraction in CoT Length Effects)

**Purpose**: Definitive test of cue-dominant extraction with MATCHED corruption level.

**Background**:
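- E4 (c=0.3: steps 7-9 corrupted, Step 10 protected) reached 94.5% accuracy, versus 60.3% for E2-Late (c=0.4)
- But E2 used c=0.4, so reviewers may argue that the lower corruption level, not the protected cue, explains the recovery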
- E4' uses c=0.4 to eliminate this criticism

**Design**:
- c = 0.4 (SAME as E2)
- L = 10
- Corrupt steps 6, 7, 8, 9 (K_corrupt = 4)
- PROTECT Step 10 (final-answer cue)

**Comparison targets**:
- E2-Early (steps 1-4 corrupted, Step 10 CLEAN): 96.5%
- E2-Late (steps 7-10 corrupted, Step 10 CORRUPTED): 60.3%

**Prediction**:
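- E4' should approximate E2-Early (~96%), NOT E2-Late (~60%)
- This would support the claim that the position effect comes from destroying the final-answer cue, not from chain depth (note that E4' corrupts steps 6-9 while E2-Late corrupts steps 7-10, so the corrupted block sits one step earlier)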

**Expected inference count**: 199

**Date**: 2025-01-02

**GLOBAL_SEED**: 20251224

## 0. Google Drive Connection

In [None]:
from google.colab import drive
drive.mount('/content/drive')

import os
from datetime import datetime

EXPERIMENT_NAME = 'E4prime_late_protected_c04'
# Run date (YYYYMMDD); becomes part of the output folder name.
EXPERIMENT_DATE = datetime.now().strftime('%Y%m%d')

# Inputs are read from the v3 run folder; outputs go to a new dated folder.
BASE_DIR = '/content/drive/MyDrive/CoT_Experiment'
V3_DATA_DIR = f'{BASE_DIR}/full_experiment_v3_20251224'
SAVE_DIR = f'{BASE_DIR}/{EXPERIMENT_NAME}_{EXPERIMENT_DATE}'

# os.makedirs creates missing parents, so this single call also creates SAVE_DIR.
os.makedirs(f'{SAVE_DIR}/results', exist_ok=True)

for _heading, _path in (
    ('Experiment', EXPERIMENT_NAME),
    ('V3 data directory', V3_DATA_DIR),
    ('Save directory', SAVE_DIR),
):
    print(f'{_heading}: {_path}')

## 1. Install Dependencies

In [None]:
# Install third-party packages for later cells (IPython shell escape; runs only inside Jupyter/Colab).
# NOTE(review): datasets, matplotlib and scipy do not appear to be imported by the cells shown; confirm they are still needed.
!pip install datasets anthropic matplotlib pandas tqdm scipy -q
print('Dependencies installed.')

## 2. Configuration

In [None]:
import hashlib
import random
import json
import re
import time
from typing import List, Dict, Tuple, Optional, Any
from dataclasses import dataclass, asdict, field
from datetime import datetime
from tqdm import tqdm
import pandas as pd
import numpy as np

# =============================================================================
# Global Configuration
# =============================================================================
GLOBAL_SEED = 20251224
E4PRIME_SEED = 20250102

# Experiment parameters - MATCHED to E2
L = 10
C_TARGET = 0.4  # SAME as E2

# Late-Protected at c=0.4: Corrupt steps 6-9, protect step 10
CORRUPTED_STEPS = [6, 7, 8, 9]  # Step 10 is PROTECTED
# Derived from CORRUPTED_STEPS (instead of being typed separately) so the count
# and the step list cannot drift apart.
K_CORRUPT = len(CORRUPTED_STEPS)   # 4 steps corrupted (steps 6,7,8,9)

# Fail fast on an inconsistent design: the corrupted fraction must equal the
# E2-matched target, the steps must be distinct and inside the trace, and the
# final step (the final-answer cue) must stay clean.
if abs(K_CORRUPT / L - C_TARGET) > 1e-9:
    raise ValueError(f'K_CORRUPT / L = {K_CORRUPT / L} does not match C_TARGET = {C_TARGET}')
if len(set(CORRUPTED_STEPS)) != K_CORRUPT:
    raise ValueError(f'CORRUPTED_STEPS contains duplicates: {CORRUPTED_STEPS}')
if L in CORRUPTED_STEPS or any(not 1 <= s <= L for s in CORRUPTED_STEPS):
    raise ValueError(f'CORRUPTED_STEPS {CORRUPTED_STEPS} must lie within steps 1..{L - 1}')

# API settings
API_MAX_TOKENS_ANSWER = 256
API_RETRY_DELAY = 1.0
API_RATE_LIMIT_DELAY = 0.5

print('='*70)
print('E4\': LATE-PROTECTED AT c=0.4 (MATCHED CONDITION)')
print('='*70)
print(f'  GLOBAL_SEED: {GLOBAL_SEED}')
print(f'  L (trace length): {L}')
print(f'  c (corruption fraction): {C_TARGET}')
print(f'  K_corrupt: {K_CORRUPT}')
print(f'  Corrupted steps: {CORRUPTED_STEPS}')
print(f'  Protected step: Step {L} (final-answer cue)')
print(f'  K_clean: {L - K_CORRUPT}')
print('='*70)
print('\nComparison targets (from E2, same c=0.4):')
print('  E2-Early (steps 1-4 corrupted, Step 10 CLEAN): 96.5%')
print('  E2-Late (steps 7-10 corrupted, Step 10 GONE):  60.3%')
print('\nPrediction: E4\' should be close to E2-Early (~96%)')

## 3. Data Structures

In [None]:
@dataclass
class GSM8KProblem:
    """One GSM8K problem as loaded from problems_v3.json."""
    index: int  # problem id; matched against CleanTrace.problem_index and used for seed derivation
    question: str  # problem statement inserted into the user prompt
    answer_text: str  # NOTE(review): presumably the reference solution text; not read by the cells shown
    final_answer: int  # gold integer answer compared against the parsed model output

@dataclass
class CleanTrace:
    """An uncorrupted reasoning trace loaded from clean_traces_I10_v3.json."""
    problem_index: int  # matches GSM8KProblem.index
    I: int  # number of steps; used as the trace length L when building corrupted traces
    steps: List[str]  # step texts in order; steps[0] is rendered as "Step 1"
    full_text: str  # NOTE(review): presumably the rendered clean trace; not read by the cells shown

@dataclass
class LateProtectedTrace:
    """A clean trace with late steps corrupted and the final step kept clean."""
    problem_index: int
    L: int  # trace length (number of steps)
    c: float  # corruption fraction recorded for this condition
    corrupted_steps: List[int]  # 1-based step numbers selected for corruption
    protected_step: int  # 1-based step number kept clean (the final-answer cue)
    corruption_types: Dict[int, str]  # step number -> "IRR" | "LOC" | "WRONG"; json.dump writes these int keys as strings
    final_cue_clean: bool  # whether the protected step is clean
    steps: List[str]  # step texts after corruption
    full_text: str  # steps wrapped in [[COT_START]] / [[COT_END]], as sent to the model
    seed: int  # base seed the corruption was derived from

@dataclass
class ExperimentResult:
    """Outcome of one model query on one corrupted trace."""
    problem_index: int
    condition: str  # condition label, e.g. 'late_protected_c04'
    L: int  # trace length
    c: float  # corruption fraction of the trace
    K_corrupt: int  # number of steps selected for corruption
    K_clean: int  # L - K_corrupt
    corrupted_steps: List[int]  # 1-based step numbers selected for corruption
    final_cue_clean: bool
    model_answer: Optional[int]  # parsed integer answer; None when parsing failed
    correct_answer: int  # gold answer
    is_correct: bool  # always False when model_answer is None
    raw_output: str  # unparsed model response text
    timestamp: str  # local time from datetime.now().isoformat()

## 4. Utility Functions

In [None]:
def derive_seed(global_seed: int, problem_id: int, L: int, extra: str = '') -> int:
    """Derive a reproducible 32-bit seed for one (problem, length, tag) combination.

    The key ``"{global_seed}|{problem_id}|L={L}|{extra}"`` is hashed with
    SHA-256, and the first 4 digest bytes (the first 8 hex digits) are read as a
    big-endian unsigned integer. Identical inputs therefore give identical seeds
    on every machine and run.
    """
    key = "|".join((f"{global_seed}", f"{problem_id}", f"L={L}", f"{extra}"))
    digest = hashlib.sha256(key.encode("utf-8")).digest()
    return int.from_bytes(digest[:4], "big")

def save_json(data: Any, filepath: str):
    """Write ``data`` to ``filepath`` as pretty-printed (indent=2) UTF-8 JSON.

    Non-ASCII text is written as-is (``ensure_ascii=False``). The document is
    serialized in memory *before* the file is opened. A serialization error
    (e.g. a non-JSON-serializable value) therefore raises without truncating an
    existing file at ``filepath``.

    Raises:
        TypeError / ValueError: if ``data`` cannot be encoded as JSON.
        OSError: if the file cannot be written.
    """
    text = json.dumps(data, ensure_ascii=False, indent=2)
    with open(filepath, 'w', encoding='utf-8') as f:
        f.write(text)

def load_json(filepath: str) -> Any:
    """Read ``filepath`` as UTF-8 text and return the decoded JSON document."""
    with open(filepath, encoding='utf-8') as source:
        text = source.read()
    return json.loads(text)

## 5. Load Existing Data

In [None]:
# Problem set from the v3 run, plus a lookup by problem index.
problems_path = f'{V3_DATA_DIR}/problems_v3.json'
problems_data = load_json(problems_path)
problems = [GSM8KProblem(**record) for record in problems_data]
prob_map = dict((p.index, p) for p in problems)
print(f'Loaded {len(problems)} problems')

# Cached clean traces (I=10 steps), keyed by the problem they belong to.
traces_path = f'{V3_DATA_DIR}/clean_traces/clean_traces_I10_v3.json'
traces_data = load_json(traces_path)
clean_traces = [CleanTrace(**record) for record in traces_data]
trace_map = dict((t.problem_index, t) for t in clean_traces)
print(f'Loaded {len(clean_traces)} clean traces (L=10)')

## 6. Corruption Templates

In [None]:
# Distractor-step templates for "IRR" corruption. generate_irrelevant_step picks
# one with a seeded rng.choice and computes {result} from the operator symbol it
# finds in the chosen text ('+', then '*', otherwise subtraction). Both the list
# order and the operator characters therefore affect reproducible output.
IRRELEVANT_TEMPLATES = [
    "Compute an auxiliary value: aux = {a} + {b} = {result}, but it will not be used later.",
    "Compute a side quantity: aux = {a} * 2 = {result}, unrelated to the final result.",
    "Note that we can also compute aux = {a} - {b} = {result}, though this is not needed.",
]

# Templates for "WRONG" corruption: impose an arbitrary value as a fixed constraint.
# Picked with a seeded rng.choice in generate_wrong_constraint_step, so order matters.
WRONG_CONSTRAINT_TEMPLATES = [
    "Fix an intermediate condition: set {var} = {wrong_value} as a given constraint for the rest of the steps.",
    "Assume the total is {var} = {wrong_value} and proceed using this fixed value.",
]

def generate_irrelevant_step(step_num: int, seed: int) -> str:
    """Return a distractor step whose arithmetic is unrelated to the problem ("IRR").

    Using ``random.Random(seed)``, this draws two operands in [2, 20] and then a
    template from IRRELEVANT_TEMPLATES. The displayed result follows the operator
    symbol in the chosen template: '+' gives a + b, '*' gives a * 2, and anything
    else gives a - b. ``step_num`` is accepted for signature parity with the
    other generators and is not used.
    """
    draw = random.Random(seed)
    a, b = draw.randint(2, 20), draw.randint(2, 20)
    template = draw.choice(IRRELEVANT_TEMPLATES)
    result = a + b if '+' in template else (a * 2 if '*' in template else a - b)
    return template.format(a=a, b=b, result=result)

def generate_local_error_step(original_step: str, seed: int) -> str:
    """Return ``original_step`` with its last number perturbed by a small offset ("LOC").

    The last integer in the step is shifted by a seeded offset in {±1, ±2, ±3},
    clamped at 0. The replacement is tried in this order:

    1. a trailing "= N."
    2. a trailing "N."
    3. the last number, wherever it occurs

    The earlier implementation stopped after step 2 and returned steps such as
    "... 18 apples." or "... is 18" unchanged. Such steps were still logged as
    corrupted. A negative offset that clamps 0 back onto 0 is now flipped positive
    for the same reason. Outputs only differ from the earlier implementation where
    it produced no numeric change.

    Steps containing no digits are replaced by a fixed-form arithmetic line.
    NOTE(review): in that case randint(28, 32) can yield 30, which makes
    "10 * 3 = 30" arithmetically correct; kept as-is for parity with earlier runs.
    """
    rng = random.Random(seed)
    numbers = re.findall(r'\d+', original_step)
    if not numbers:
        return f"Compute t = 10 * 3 = {rng.randint(28, 32)} (using the previous values)."
    original_result = int(numbers[-1])
    offset = rng.choice([-3, -2, -1, 1, 2, 3])
    wrong_result = max(0, original_result + offset)
    if wrong_result == original_result:
        # Only reachable when original_result == 0 and offset < 0: the clamp
        # would reproduce the original value, so perturb upward instead.
        wrong_result = original_result + abs(offset)
    modified = re.sub(r'= (\d+)\.$', f'= {wrong_result}.', original_step)
    if modified == original_step:
        modified = re.sub(r'(\d+)\.$', f'{wrong_result}.', original_step)
    if modified == original_step:
        # Neither end-of-sentence pattern matched: replace the last number in place.
        last = list(re.finditer(r'\d+', original_step))[-1]
        modified = original_step[:last.start()] + str(wrong_result) + original_step[last.end():]
    return modified

def generate_wrong_constraint_step(step_num: int, seed: int) -> str:
    """Return a step that imposes an arbitrary value as a fixed constraint ("WRONG").

    Using ``random.Random(seed)``, this draws, in order:

    - a variable name
    - a value in [10, 100]
    - a template from WRONG_CONSTRAINT_TEMPLATES

    ``step_num`` is not used.
    """
    draw = random.Random(seed)
    name = draw.choice(['x', 'total', 'result', 'n'])
    value = draw.randint(10, 100)
    return draw.choice(WRONG_CONSTRAINT_TEMPLATES).format(var=name, wrong_value=value)

## 7. Late-Protected Corruption Logic (c=0.4)

In [None]:
def assign_corruption_types_for_steps(corrupted_steps: List[int], seed: int) -> Dict[int, str]:
    """Map each corrupted step number to a corruption type: "IRR", "LOC" or "WRONG".

    The nominal ratio IRR:LOC:WRONG = 1:2:2 is realized with floor division:
    ``n_irr = K // 5``, ``n_loc = max(1, 2K // 5)``, and WRONG takes the
    remainder. Small K therefore skews toward WRONG: K=4 gives 0 IRR, 1 LOC and
    3 WRONG. K=5 gives 1 IRR, 2 LOC and 2 WRONG.
    NOTE(review): presumably E2 used this same allocation; changing it would
    break the matched comparison, so confirm before "fixing" the ratio.

    A shuffled copy of the steps (``random.Random(seed)``) receives the labels
    in IRR, LOC, WRONG order, so the mapping is reproducible per seed and the
    input list is left unmodified.
    """
    total = len(corrupted_steps)
    if not total:
        return {}

    irr_count = total // 5
    loc_count = max(1, (2 * total) // 5)
    wrong_count = total - irr_count - loc_count

    shuffled = corrupted_steps[:]
    random.Random(seed).shuffle(shuffled)

    labels = ["IRR"] * irr_count + ["LOC"] * loc_count + ["WRONG"] * wrong_count
    return dict(zip(shuffled, labels))


def create_late_protected_trace_c04(
    clean_trace: CleanTrace,
    seed: int
) -> LateProtectedTrace:
    """
    Create a trace with Late corruption at c=0.4 but the final step PROTECTED.

    - Corrupt the steps in CORRUPTED_STEPS (6, 7, 8, 9 -> K=4, c=0.4)
    - Keep the last step (Step 10, the final-answer cue) CLEAN

    Args:
        clean_trace: Uncorrupted trace; ``clean_trace.steps`` must hold exactly
            ``clean_trace.I`` steps.
        seed: Base seed. Corruption types are assigned from ``seed``, and each
            corrupted step's content from ``seed + step_num * 1000``, so the
            result is reproducible.

    Returns:
        A LateProtectedTrace whose ``final_cue_clean`` is computed by comparing
        the protected step with the clean trace instead of being assumed.

    Raises:
        ValueError: in any of these cases:
            - the step count disagrees with ``clean_trace.I``;
            - the corrupted fraction does not equal C_TARGET for this trace length;
            - a corrupted step falls outside the trace or on the protected final step.
    """
    trace_len = clean_trace.I
    if trace_len <= 0 or len(clean_trace.steps) != trace_len:
        raise ValueError(
            f'Problem {clean_trace.problem_index}: trace has {len(clean_trace.steps)} '
            f'steps but I={trace_len}'
        )
    # Copy so each trace (and its JSON log) never aliases the global config list.
    corrupted_steps = list(CORRUPTED_STEPS)
    # The final step carries the final-answer cue and is the one kept clean.
    protected_step = trace_len
    if abs(len(corrupted_steps) / trace_len - C_TARGET) > 1e-9:
        raise ValueError(
            f'Problem {clean_trace.problem_index}: {len(corrupted_steps)}/{trace_len} '
            f'corrupted steps does not match c={C_TARGET}'
        )
    if protected_step in corrupted_steps or any(not 1 <= s <= trace_len for s in corrupted_steps):
        raise ValueError(
            f'Corrupted steps {corrupted_steps} must lie within steps 1..{trace_len - 1}'
        )

    # Assign corruption types
    corruption_types = assign_corruption_types_for_steps(corrupted_steps, seed)

    # Apply corruption; steps without an assigned type are copied unchanged
    # (steps 1-5 and the protected final step).
    new_steps = []
    for step_num, step_content in enumerate(clean_trace.steps, start=1):
        ctype = corruption_types.get(step_num)
        if ctype is None:
            new_steps.append(step_content)
            continue
        step_seed = seed + step_num * 1000
        if ctype == 'IRR':
            new_content = generate_irrelevant_step(step_num, step_seed)
        elif ctype == 'LOC':
            new_content = generate_local_error_step(step_content, step_seed)
        else:  # WRONG
            new_content = generate_wrong_constraint_step(step_num, step_seed)
        new_steps.append(new_content)

    # Build full text
    lines = ['[[COT_START]]']
    for i, content in enumerate(new_steps):
        lines.append(f'Step {i+1}: {content}')
    lines.append('[[COT_END]]')
    full_text = '\n'.join(lines)

    # Verify rather than assume that the final-answer cue survived untouched.
    final_cue_clean = new_steps[protected_step - 1] == clean_trace.steps[protected_step - 1]

    return LateProtectedTrace(
        problem_index=clean_trace.problem_index,
        L=trace_len,
        c=C_TARGET,
        corrupted_steps=corrupted_steps,
        protected_step=protected_step,
        corruption_types=corruption_types,
        final_cue_clean=final_cue_clean,
        steps=new_steps,
        full_text=full_text,
        seed=seed
    )

In [None]:
# Test Late-Protected c=0.4 logic
test_trace = clean_traces[0]
test_seed = derive_seed(E4PRIME_SEED, test_trace.problem_index, L, 'late_protected_c04')

protected_trace = create_late_protected_trace_c04(test_trace, test_seed)

print('E4\' Late-Protected Trace (c=0.4):')
print(f'  c: {protected_trace.c}')
print(f'  Corrupted steps: {protected_trace.corrupted_steps}')
print(f'  Protected step: {protected_trace.protected_step}')
print(f'  Corruption types: {protected_trace.corruption_types}')
print(f'  Final cue clean: {protected_trace.final_cue_clean}')
print(f'  K_corrupt: {len(protected_trace.corrupted_steps)}')
print(f'  K_clean: {L - len(protected_trace.corrupted_steps)}')
print(f'\nStep 10 (should be clean):')
print(f'  {protected_trace.steps[9][:100]}...')

# Check the protection programmatically instead of eyeballing the print above.
_pidx = protected_trace.protected_step - 1
if protected_trace.steps[_pidx] != test_trace.steps[_pidx]:
    raise RuntimeError('Protected step differs from the clean trace')

# Dry run over every trace (no API calls): find steps that were selected for
# corruption but came out identical to the clean step, i.e. silent no-ops that
# would make the effective corruption level lower than reported.
noop_corruptions = []
for _ct in clean_traces:
    _dry_seed = derive_seed(E4PRIME_SEED, _ct.problem_index, L, 'late_protected_c04')
    _dry = create_late_protected_trace_c04(_ct, _dry_seed)
    for _s in _dry.corrupted_steps:
        if _s <= len(_ct.steps) and _dry.steps[_s - 1] == _ct.steps[_s - 1]:
            noop_corruptions.append((_ct.problem_index, _s))
print(f'\nNo-op corruptions across {len(clean_traces)} traces: {len(noop_corruptions)}')
if noop_corruptions:
    print(f'  First few (problem_index, step): {noop_corruptions[:10]}')

## 8. API Setup

In [None]:
from getpass import getpass
import os

# Reuse a key already set in the ANTHROPIC_API_KEY environment variable so
# re-running the notebook does not require re-typing it; otherwise prompt
# without echoing. Surrounding whitespace (common when pasting) is stripped.
ANTHROPIC_API_KEY = (
    os.environ.get('ANTHROPIC_API_KEY', '').strip()
    or getpass('Enter Anthropic API Key: ').strip()
)
if not ANTHROPIC_API_KEY:
    raise ValueError('No Anthropic API key provided.')
print('API Key set.')

In [None]:
import anthropic

# Single shared client used by call_claude for every request.
# NOTE(review): the SDK client may retry failed requests on its own; confirm its
# max_retries setting so those retries do not compound with call_claude's loop.
client = anthropic.Anthropic(api_key=ANTHROPIC_API_KEY)

def call_claude(system_prompt: str, user_prompt: str, max_tokens: int = 1024, retries: int = 3) -> str:
    """Send one temperature-0 request and return the text of the first content block.

    Args:
        system_prompt: System prompt for the request.
        user_prompt: Content of the single user message.
        max_tokens: Generation cap passed to the API.
        retries: Total number of attempts; must be >= 1.

    Returns:
        ``message.content[0].text`` of the response. After each successful call
        the function sleeps API_RATE_LIMIT_DELAY seconds to pace requests.

    Raises:
        ValueError: if ``retries`` < 1. The previous version silently returned
            None in that case.
        anthropic.AuthenticationError, PermissionDeniedError, BadRequestError,
        NotFoundError: re-raised immediately, because retrying cannot fix them.
        Exception: the last error once all attempts are exhausted.
    """
    if retries < 1:
        raise ValueError(f'retries must be >= 1, got {retries}')
    # Client-side errors that would fail identically on every attempt.
    non_retryable = (
        anthropic.AuthenticationError,
        anthropic.PermissionDeniedError,
        anthropic.BadRequestError,
        anthropic.NotFoundError,
    )
    for attempt in range(retries):
        try:
            message = client.messages.create(
                model="claude-sonnet-4-20250514",
                max_tokens=max_tokens,
                messages=[{"role": "user", "content": user_prompt}],
                system=system_prompt,
                temperature=0
            )
            time.sleep(API_RATE_LIMIT_DELAY)
            return message.content[0].text
        except non_retryable:
            raise
        except Exception as e:
            print(f'API error (attempt {attempt+1}): {e}')
            if attempt < retries - 1:
                # Exponential backoff (1x, 2x, 4x, ...); for the default
                # retries=3 this equals the previous 1x, 2x delays.
                time.sleep(API_RETRY_DELAY * (2 ** attempt))
            else:
                raise

# Smoke-test the connection with a tiny JSON-only request before the main run.
test_response = call_claude(
    system_prompt="You output ONLY JSON.",
    user_prompt='Respond with exactly: {"test": "ok"}',
    max_tokens=50,
)
print(f'API test: {test_response}')

## 9. Experiment Prompts

In [None]:
EXPERIMENT_SYSTEM_PROMPT = """You are a calculator that outputs ONLY JSON.

CRITICAL RULES:
1. Your output MUST start with the character '{'
2. Your output MUST be exactly: {"final": <number>}
3. Replace <number> with an integer (the numerical answer)
4. Do NOT write ANY explanation, reasoning, or text before or after the JSON
5. Do NOT write "I need to" or "Let me" or any other words
6. ONLY output the JSON object, nothing else

CORRECT OUTPUT EXAMPLE:
{"final": 42}
"""

def create_experiment_prompt(problem: GSM8KProblem, cot_text: str) -> Tuple[str, str]:
    """Build the (system_prompt, user_prompt) pair for one extraction query.

    The user prompt has three parts: the problem text, the (possibly corrupted)
    trace framed as given facts, and a reminder of the JSON output format.
    Doubled braces in the f-string render as literal '{' and '}'.
    NOTE(review): the wording is the experimental stimulus; keep it identical
    to the prompt used for E2 so the comparison stays matched.
    """
    user = f"""Problem: {problem.question}

Reasoning trace (use these steps as given facts):
{cot_text}

Based on the trace above, compute the final numerical answer.
OUTPUT ONLY: {{"final": <number>}}
START YOUR RESPONSE WITH '{{'"""
    return EXPERIMENT_SYSTEM_PROMPT, user

def parse_model_answer(response: str) -> Optional[int]:
    """Extract an integer answer from a model response, or None if no number is found.

    Patterns are tried from strictest to loosest, and the first hit wins:

    1. ``{"final": N}``
    2. ``{'final': N}`` (either quote style)
    3. ``"final": N`` anywhere in the text
    4. otherwise, the last whitespace-delimited number in the response

    Decimal values are rounded with Python's ``round`` (ties go to even).
    NOTE(review): rounding can credit a non-integer answer, and a value written
    with thousands separators ("1,200") parses as 1. This is kept unchanged so
    scoring stays comparable with the E2 runs.
    """
    def to_int(text: str) -> int:
        return int(round(float(text)))

    for pattern in (
        r'\{\s*"final"\s*:\s*(-?\d+(?:\.\d+)?)\s*\}',
        r"\{\s*[\"']final[\"']\s*:\s*(-?\d+(?:\.\d+)?)\s*\}",
        r'"final"\s*:\s*(-?\d+(?:\.\d+)?)',
    ):
        hit = re.search(pattern, response)
        if hit is not None:
            return to_int(hit.group(1))

    loose = re.findall(r'(?:^|\s)(-?\d+(?:\.\d+)?)(?:\s|$|\.|,)', response)
    return to_int(loose[-1]) if loose else None

## 10. Run Experiment

In [None]:
def run_experiment(
    problem: GSM8KProblem,
    trace: LateProtectedTrace
) -> ExperimentResult:
    """Query the model once with ``trace`` and score the parsed answer against the gold answer.

    The answer counts as correct only when it parsed to an integer equal to
    ``problem.final_answer``. The raw response is kept for later inspection.
    """
    system_msg, user_msg = create_experiment_prompt(problem, trace.full_text)
    raw = call_claude(system_msg, user_msg, max_tokens=API_MAX_TOKENS_ANSWER)
    parsed = parse_model_answer(raw)
    n_corrupt = len(trace.corrupted_steps)

    return ExperimentResult(
        problem_index=problem.index,
        condition='late_protected_c04',
        L=trace.L,
        c=trace.c,
        K_corrupt=n_corrupt,
        K_clean=trace.L - n_corrupt,
        corrupted_steps=trace.corrupted_steps,
        final_cue_clean=trace.final_cue_clean,
        model_answer=parsed,
        correct_answer=problem.final_answer,
        is_correct=parsed is not None and parsed == problem.final_answer,
        raw_output=raw,
        timestamp=datetime.now().isoformat(),
    )

In [None]:
print('='*70)
print('E4\': LATE-PROTECTED EXPERIMENT (c=0.4)')
print('='*70)
print(f'Condition: Late-Protected (steps 6-9 corrupted, step 10 CLEAN)')
print(f'c = {C_TARGET} (MATCHED to E2)')

# Only problems with a cached clean trace can be run. Report the real count
# instead of len(problems), so missing traces are not silently skipped.
eligible_problems = [p for p in problems if p.index in trace_map]
n_skipped = len(problems) - len(eligible_problems)
print(f'Expected inferences: {len(eligible_problems)}')
if n_skipped:
    print(f'WARNING: {n_skipped} problem(s) have no clean trace and will be skipped')
print('='*70)

results = []
traces_log = []
partial_path = f'{SAVE_DIR}/results/E4prime_late_protected_c04_partial.json'

try:
    for prob in tqdm(eligible_problems, desc='E4\' Late-Protected c=0.4'):
        clean_trace = trace_map[prob.index]

        # Create Late-Protected trace at c=0.4
        seed = derive_seed(E4PRIME_SEED, prob.index, L, 'late_protected_c04')
        protected_trace = create_late_protected_trace_c04(clean_trace, seed)
        traces_log.append(asdict(protected_trace))

        # Run experiment
        result = run_experiment(prob, protected_trace)
        results.append(result)
finally:
    # An API error that survives call_claude's retries (or a manual interrupt)
    # would otherwise discard every completed call; checkpoint what finished.
    if len(results) < len(eligible_problems):
        save_json({'results': [asdict(r) for r in results], 'traces': traces_log}, partial_path)
        print(f'\nPartial progress saved ({len(results)}/{len(eligible_problems)}): {partial_path}')

print(f'\nCompleted: {len(results)} experiments')

## 11. Save Results

In [None]:
# Persist the per-problem results and the exact traces that were sent to the model.
for _label, _payload, _kind in (
    ('Results', [asdict(r) for r in results], 'results'),
    ('Traces', traces_log, 'traces'),
):
    _out_path = f'{SAVE_DIR}/results/E4prime_late_protected_c04_{_kind}.json'
    save_json(_payload, _out_path)
    print(f'{_label} saved: {_out_path}')

## 12. Analysis

In [None]:
df = pd.DataFrame([asdict(r) for r in results])
if df.empty:
    raise RuntimeError('No results to analyze - run the experiment cell first.')

n_trials = len(df)
n_correct = int(df['is_correct'].sum())
e4prime_acc = n_correct / n_trials
# Unparseable responses are scored as incorrect. Report them so format
# failures are not mistaken for reasoning failures.
n_unparsed = int(df['model_answer'].isna().sum())

# Wilson score 95% interval for the accuracy. Unlike the normal approximation,
# it stays inside [0, 1] when accuracy is close to 100%.
wilson_z = 1.959963984540054
wilson_denom = 1 + wilson_z**2 / n_trials
wilson_center = (e4prime_acc + wilson_z**2 / (2 * n_trials)) / wilson_denom
wilson_half = (wilson_z / wilson_denom) * np.sqrt(
    e4prime_acc * (1 - e4prime_acc) / n_trials + wilson_z**2 / (4 * n_trials**2)
)
e4prime_ci = (wilson_center - wilson_half, wilson_center + wilson_half)

print('='*70)
print('E4\' RESULTS (c=0.4, MATCHED TO E2)')
print('='*70)
print(f'E4\' Late-Protected (steps 6-9 corrupted, step 10 CLEAN):')
print(f'  Accuracy: {e4prime_acc:.1%} ({n_correct}/{n_trials})')
print(f'  95% CI (Wilson): [{e4prime_ci[0]:.1%}, {e4prime_ci[1]:.1%}]')
print(f'  Unparsed answers (scored incorrect): {n_unparsed}')
print(f'  c: {C_TARGET}')
print(f'  K_corrupt: {K_CORRUPT}')
print(f'  K_clean: {L - K_CORRUPT}')
print(f'  Final cue clean in all trials: {bool(df["final_cue_clean"].all())}')
print('='*70)

In [None]:
# Comparison with E2 results (SAME c=0.4)
# Labels containing an apostrophe are bound to variables first. A backslash
# inside an f-string replacement field is a SyntaxError before Python 3.12.
e4prime_label_matched = "E4' Late-Protected (steps 6-9)"
e4prime_label_vs_e4 = "E4' Late-Protected (steps 6-9, c=0.4)"

print('\n' + '='*70)
print('MATCHED-CONDITION COMPARISON WITH E2 (c=0.4)')
print('='*70)
print(f'{"Condition":<40} {"Accuracy":>10} {"c":>6} {"Final cue":>12}')
print('-'*70)
print(f'{"E2-Early (steps 1-4 corrupted)":<40} {"96.5%":>10} {"0.4":>6} {"CLEAN":>12}')
print(f'{e4prime_label_matched:<40} {e4prime_acc:>10.1%} {"0.4":>6} {"CLEAN":>12}')
print(f'{"E2-Late (steps 7-10 corrupted)":<40} {"60.3%":>10} {"0.4":>6} {"CORRUPTED":>12}')
print('='*70)

# Also compare with E4 (c=0.3); this table is narrower, so its rules are too.
print('\n' + '='*60)
print('COMPARISON WITH E4 (c=0.3)')
print('='*60)
print(f'{"Condition":<40} {"Accuracy":>10} {"c":>6}')
print('-'*60)
print(f'{"E4 Late-Protected (steps 7-9, c=0.3)":<40} {"94.5%":>10} {"0.3":>6}')
print(f'{e4prime_label_vs_e4:<40} {e4prime_acc:>10.1%} {"0.4":>6}')
print('='*60)

In [None]:
# Interpretation
print('\n' + '='*70)
print('INTERPRETATION')
print('='*70)

# E2 reference accuracies at the same c=0.4.
e2_early_acc = 0.965  # steps 1-4 corrupted, Step 10 clean
e2_late_acc = 0.603   # steps 7-10 corrupted, Step 10 corrupted
recovery = e4prime_acc - e2_late_acc
# Share of the E2 Early-Late gap recovered by protecting Step 10:
# 1.0 means E4' matches E2-Early, 0.0 means it is no better than E2-Late.
gap_recovered = recovery / (e2_early_acc - e2_late_acc)

if e4prime_acc >= 0.90:
    print('RESULT: E4\' (Late-Protected c=0.4) achieves ~90%+ accuracy')
    print('')
    print('CONCLUSION: CUE-DOMINANT EXTRACTION DEFINITIVELY CONFIRMED')
    print('')
    print('At MATCHED corruption level (c=0.4):')
    print(f'  - E2-Late (Step 10 CORRUPTED): {e2_late_acc:.1%}')
    print(f'  - E4\' (Step 10 CLEAN):        {e4prime_acc:.1%}')
    print(f'  - Recovery: {recovery:+.1%}')
    print(f'  - Share of E2 Early-Late gap recovered: {gap_recovered:.0%}')
    print(f'  - Remaining gap to E2-Early: {e2_early_acc - e4prime_acc:.1%}')
    print('')
    print('This eliminates the "c difference" criticism.')
    print('Final-answer cue status accounts for the share of the position effect shown above.')

elif e4prime_acc >= 0.80:
    print('RESULT: E4\' shows substantial recovery')
    print(f'  Recovery from E2-Late: {recovery:+.1%}')
    print(f'  Share of E2 Early-Late gap recovered: {gap_recovered:.0%}')
    print('  Supports cue-dominant extraction, though not as strong as E4 (c=0.3)')

else:
    print('RESULT: Unexpected - needs investigation')
    print(f'  E4\' accuracy: {e4prime_acc:.1%} (recovery from E2-Late: {recovery:+.1%})')

## 13. Summary

In [None]:
# Final summary, printed as one block.
summary_lines = [
    '='*70,
    'E4\' EXPERIMENT COMPLETE',
    '='*70,
    f'Date: {EXPERIMENT_DATE}',
    f'Total experiments: {len(results)}',
    f'E4\' accuracy (c=0.4): {e4prime_acc:.1%}',
    '',
    'FINAL COMPARISON (all at c=0.4):',
    '  E2-Early (Step 10 CLEAN):     96.5%',
    f'  E4\' (Step 10 CLEAN):         {e4prime_acc:.1%}',
    '  E2-Late (Step 10 CORRUPTED):  60.3%',
    '',
    f'Files saved to: {SAVE_DIR}',
    '='*70,
]
print('\n'.join(summary_lines))