# E5: Single-Step Corruption Experiment - Direct Test of Answer-Cue Importance

**Paper**: A2 (Redundancy vs Depth in CoT Length Effects)

**Purpose**: Most direct test of whether Step 10 (Final cue) is uniquely important.

**Sofia's suggestion**: Comparing Step10-Only against Step1-Only is the most direct test.
If Step10-Only causes failures while Step1-Only does not, the E2/E3 effects are explained by destruction of the Final cue.

**Design**:
- c = 0.1 (only 1 step corrupted out of 10)
- L = 10
- **Step1-Only**: Corrupt ONLY Step 1
- **Step10-Only**: Corrupt ONLY Step 10 (Final cue)

**Predictions**:
- Step1-Only: High accuracy (first step is dispensable)
- Step10-Only: Low accuracy (Final cue destroyed)
- If Step10-Only << Step1-Only -> Answer-cue dominance confirmed

**Expected inference count**: 199 x 2 = 398

**Date**: 2025-01-02

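**GLOBAL_SEED**: 20251224 (recorded for reference; the per-trace corruption seeds in this notebook are derived from **E5_SEED** = 20250102)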

## 0. Google Drive Connection

In [None]:
# Mount Google Drive at /content/drive; all inputs and outputs live under it.
import os
from datetime import datetime

from google.colab import drive

drive.mount('/content/drive')

# Run identity: the date stamp is taken at the moment this cell executes.
EXPERIMENT_NAME = 'E5_single_step'
EXPERIMENT_DATE = datetime.now().strftime('%Y%m%d')

# Inputs come from the V3 experiment folder; outputs go to a dated folder.
BASE_DIR = '/content/drive/MyDrive/CoT_Experiment'
V3_DATA_DIR = BASE_DIR + '/full_experiment_v3_20251224'
SAVE_DIR = '{}/{}_{}'.format(BASE_DIR, EXPERIMENT_NAME, EXPERIMENT_DATE)

# Create the output folder and its results/ subfolder (no error if present).
os.makedirs(SAVE_DIR, exist_ok=True)
os.makedirs(SAVE_DIR + '/results', exist_ok=True)

print(
    f'Experiment: {EXPERIMENT_NAME}',
    f'V3 data directory: {V3_DATA_DIR}',
    f'Save directory: {SAVE_DIR}',
    sep='\n',
)

## 1. Install Dependencies

In [None]:
# Install the third-party packages for this notebook (-q suppresses pip output).
# NOTE(review): `datasets` and `matplotlib` are not imported in the cells shown
# here — confirm they are still needed.
!pip install datasets anthropic matplotlib pandas tqdm scipy -q
print('Dependencies installed.')

## 2. Configuration

In [None]:
import hashlib
import random
import json
import re
import time
from typing import List, Dict, Tuple, Optional, Any
from dataclasses import dataclass, asdict, field
from datetime import datetime
from tqdm import tqdm
import pandas as pd
import numpy as np

# =============================================================================
# Global Configuration
# =============================================================================
# GLOBAL_SEED is only reported in the banner below; every per-trace seed in
# this notebook is derived from E5_SEED, so both are logged for reproducibility.
GLOBAL_SEED = 20251224
E5_SEED = 20250102

# Experiment parameters
L = 10          # trace length (number of steps)
C_TARGET = 0.1  # Only 1 step corrupted (1 of L = 10)

# API settings
API_MAX_TOKENS_ANSWER = 256  # max_tokens used for the answer-only completions
API_RETRY_DELAY = 1.0        # base delay in seconds between retries after an API error
API_RATE_LIMIT_DELAY = 0.5   # pause in seconds after each successful API call

print('='*70)
print('E5: SINGLE-STEP CORRUPTION EXPERIMENT')
print('='*70)
print(f'  GLOBAL_SEED: {GLOBAL_SEED}')
print(f'  E5_SEED: {E5_SEED}')
print(f'  L (trace length): {L}')
print(f'  c (corruption fraction): {C_TARGET}')
print(f'  Conditions:')
print(f'    - Step1-Only: Corrupt ONLY Step 1')
print(f'    - Step10-Only: Corrupt ONLY Step 10 (Final cue)')
print('='*70)

## 3. Data Structures

In [None]:
@dataclass
class GSM8KProblem:
    """One problem record, built from an entry of `problems_v3.json`."""
    index: int  # problem identifier; used as the key for looking up its clean trace
    question: str  # problem statement inserted into the experiment prompt
    answer_text: str  # NOTE(review): not read in this notebook — presumably the reference solution text
    final_answer: int  # reference answer the parsed model answer is compared against

@dataclass
class CleanTrace:
    """An uncorrupted reasoning trace, built from an entry of the clean-traces file."""
    problem_index: int  # `index` of the problem this trace belongs to
    I: int  # copied into SingleStepTrace.L as the trace length; presumably the step count — TODO confirm
    steps: List[str]  # step contents in order; a "Step N: " prefix is prepended when a trace is rendered
    full_text: str  # NOTE(review): not read in this notebook — presumably the rendered clean trace

@dataclass
class SingleStepTrace:
    """A trace in which exactly one step has been replaced by a corrupted step."""
    problem_index: int  # `index` of the problem this trace belongs to
    L: int  # trace length, taken from CleanTrace.I
    condition: str  # 'step1_only' or 'step10_only'
    corrupted_step: int  # 1-based number of the replaced step
    corruption_type: str  # kind of corruption applied; this notebook always uses 'WRONG'
    final_cue_clean: bool  # False when the final-cue step (Step 10 in this experiment) is the corrupted one
    steps: List[str]  # step contents after corruption, in order
    full_text: str  # rendered trace: "Step N: ..." lines between [[COT_START]] and [[COT_END]]
    seed: int  # base seed passed to create_single_step_trace for this trace

@dataclass
class ExperimentResult:
    """Outcome of one model call on one (problem, corrupted trace) pair."""
    problem_index: int  # `index` of the problem that was asked
    condition: str  # condition label copied from the trace ('step1_only' / 'step10_only')
    L: int  # trace length copied from the trace
    corrupted_step: int  # 1-based number of the corrupted step, copied from the trace
    corruption_type: str  # corruption kind copied from the trace
    final_cue_clean: bool  # copied from the trace
    model_answer: Optional[int]  # integer parsed from the model output; None if nothing could be parsed
    correct_answer: int  # the problem's reference final answer
    is_correct: bool  # True only when model_answer was parsed and equals correct_answer
    raw_output: str  # unmodified text returned by the model
    timestamp: str  # local time the result was recorded, as an ISO-8601 string

## 4. Utility Functions

In [None]:
def derive_seed(global_seed: int, problem_id: int, L: int, extra: str = '') -> int:
    """Derive a deterministic 32-bit seed from the given identifiers.

    The identifiers are formatted as ``"<global_seed>|<problem_id>|L=<L>|<extra>"``,
    hashed with SHA-256, and the first four digest bytes (equivalently, the
    first eight hex digits) are returned as an unsigned integer.
    """
    material = '{}|{}|L={}|{}'.format(global_seed, problem_id, L, extra)
    digest = hashlib.sha256(material.encode('utf-8')).digest()
    return int.from_bytes(digest[:4], 'big')

def save_json(data: Any, filepath: str):
    """Write `data` to `filepath` as pretty-printed (indent=2) UTF-8 JSON.

    The document is serialized completely in memory before the file is opened,
    so a serialization failure raises without truncating or partially
    overwriting a file that already exists at `filepath`.

    Args:
        data: any JSON-serializable value.
        filepath: destination path; overwritten if it exists.

    Raises:
        TypeError: if `data` contains a value `json` cannot serialize.
        OSError: if the file cannot be opened or written.
    """
    # Serialize first: opening with 'w' truncates immediately, which would
    # destroy previous results if encoding then failed part-way through.
    payload = json.dumps(data, ensure_ascii=False, indent=2)
    with open(filepath, 'w', encoding='utf-8') as f:
        f.write(payload)

def load_json(filepath: str) -> Any:
    """Read the UTF-8 JSON document at `filepath` and return the parsed value."""
    with open(filepath, 'r', encoding='utf-8') as handle:
        raw_text = handle.read()
    return json.loads(raw_text)

## 5. Load Existing Data

In [None]:
# Problems: one GSM8KProblem per record, plus an index -> problem lookup.
problems_path = V3_DATA_DIR + '/problems_v3.json'
problems_data = load_json(problems_path)
problems = [GSM8KProblem(**record) for record in problems_data]
prob_map = {problem.index: problem for problem in problems}
print(f'Loaded {len(problems)} problems')

# Clean traces from the I10 file, plus a problem_index -> trace lookup.
traces_path = V3_DATA_DIR + '/clean_traces/clean_traces_I10_v3.json'
traces_data = load_json(traces_path)
clean_traces = [CleanTrace(**record) for record in traces_data]
trace_map = {trace.problem_index: trace for trace in clean_traces}
print(f'Loaded {len(clean_traces)} clean traces (L=10)')

## 6. Corruption Templates

In [None]:
# Single-step corruption always uses the WRONG-constraint type: the replaced
# step asserts an arbitrary fixed value, chosen as the most harmful corruption
# so that any effect is as easy as possible to detect.

WRONG_CONSTRAINT_TEMPLATES = [
    "Fix an intermediate condition: set {var} = {wrong_value} as a given constraint for the rest of the steps.",
    "Assume the total is {var} = {wrong_value} and proceed using this fixed value.",
]

def generate_wrong_constraint_step(step_num: int, seed: int) -> str:
    """Return a wrong-constraint sentence determined entirely by `seed`.

    A private RNG seeded with `seed` picks, in this order, a variable name,
    an integer in [10, 100], and one of WRONG_CONSTRAINT_TEMPLATES.
    `step_num` is accepted but does not influence the output.

    NOTE(review): the value is not checked against the problem's true answer,
    so it could coincide with it by chance.
    """
    rng = random.Random(seed)
    # The order of the three draws is part of the seed -> text mapping.
    variable_name = rng.choice(('x', 'total', 'result', 'n'))
    bogus_value = rng.randint(10, 100)
    chosen_template = rng.choice(WRONG_CONSTRAINT_TEMPLATES)
    return chosen_template.format(var=variable_name, wrong_value=bogus_value)

## 7. Single-Step Corruption Logic

In [None]:
def create_single_step_trace(
    clean_trace: CleanTrace,
    target_step: int,  # 1 or 10
    seed: int
) -> SingleStepTrace:
    """
    Create a trace with ONLY one step corrupted.

    - target_step=1: Corrupt Step 1 only (first step)
    - target_step=10: Corrupt Step 10 only (Final cue)

    Args:
        clean_trace: source trace; its `steps` list is copied, not mutated.
        target_step: 1-based number of the step to replace.
        seed: base seed; the replacement text is generated from
            `seed + target_step * 1000`.

    Returns:
        A SingleStepTrace whose `condition` is 'step{target_step}_only' and
        whose `full_text` lists every step as "Step N: ..." between
        [[COT_START]] and [[COT_END]]. `final_cue_clean` is False only when
        the corrupted step is the last step of the trace (Step 10 for the
        10-step traces used in this experiment).

    Raises:
        ValueError: if `target_step` is not a valid step number for this
            trace. Without this check no step would be replaced and the
            returned trace would be clean while labelled as corrupted.
    """
    n_steps = len(clean_trace.steps)
    if not 1 <= target_step <= n_steps:
        raise ValueError(
            f'target_step={target_step} is out of range for trace of problem '
            f'{clean_trace.problem_index} with {n_steps} steps'
        )

    # Replace exactly one step; all others are kept verbatim.
    step_seed = seed + target_step * 1000
    new_steps = list(clean_trace.steps)
    new_steps[target_step - 1] = generate_wrong_constraint_step(target_step, step_seed)

    # Build full text
    numbered = [f'Step {num}: {content}' for num, content in enumerate(new_steps, start=1)]
    full_text = '\n'.join(['[[COT_START]]', *numbered, '[[COT_END]]'])

    return SingleStepTrace(
        problem_index=clean_trace.problem_index,
        L=clean_trace.I,
        condition=f'step{target_step}_only',
        corrupted_step=target_step,
        corruption_type='WRONG',  # Use WRONG for maximum effect
        # The final cue is the last step, not a hard-coded step number.
        final_cue_clean=(target_step != n_steps),
        steps=new_steps,
        full_text=full_text,
        seed=seed
    )

In [None]:
# Smoke-test the single-step corruption on the first clean trace.
test_trace = clean_traces[0]


def _print_trace_preview(trace):
    """Print the corruption metadata and the first 80 characters of Steps 1 and 10."""
    print(f'  Corrupted step: {trace.corrupted_step}')
    print(f'  Final cue clean: {trace.final_cue_clean}')
    print(f'  Step 1: {trace.steps[0][:80]}...')
    print(f'  Step 10: {trace.steps[9][:80]}...')


# Condition A: only Step 1 is replaced.
step1_trace = create_single_step_trace(test_trace, 1, E5_SEED)
print('Step1-Only Trace:')
_print_trace_preview(step1_trace)

# Condition B: only Step 10 is replaced.
step10_trace = create_single_step_trace(test_trace, 10, E5_SEED)
print('\nStep10-Only Trace:')
_print_trace_preview(step10_trace)

## 8. API Setup

In [None]:
# Read the Anthropic API key interactively so it is never stored in the notebook.
from getpass import getpass

ANTHROPIC_API_KEY = getpass('Enter Anthropic API Key: ')
# NOTE(review): printed unconditionally — the key is not validated in this cell.
print('API Key set.')

In [None]:
import anthropic

# Shared client used by every call_claude() invocation in this notebook.
client = anthropic.Anthropic(api_key=ANTHROPIC_API_KEY)


def call_claude(
    system_prompt: str,
    user_prompt: str,
    max_tokens: int = 1024,
    retries: int = 3,
    model: str = "claude-sonnet-4-20250514",
) -> str:
    """Send one single-turn request and return the text of the first content block.

    Requests use temperature 0. After a successful call the function pauses
    for API_RATE_LIMIT_DELAY seconds. On any exception it logs the error and
    retries after a delay that grows linearly (API_RETRY_DELAY * attempt
    number).

    Args:
        system_prompt: system prompt for the request.
        user_prompt: content of the single user message.
        max_tokens: completion token limit.
        retries: total number of attempts; must be at least 1.
        model: model identifier; defaults to the model used for this experiment.

    Returns:
        The text of the first content block of the response.

    Raises:
        ValueError: if `retries` is less than 1 (previously the function
            silently returned None in that case, contradicting `-> str`).
        Exception: the error from the final attempt is re-raised unchanged.
    """
    if retries < 1:
        raise ValueError(f'retries must be >= 1, got {retries}')

    for attempt in range(retries):
        try:
            message = client.messages.create(
                model=model,
                max_tokens=max_tokens,
                messages=[{"role": "user", "content": user_prompt}],
                system=system_prompt,
                temperature=0
            )
            # Extracted inside the try so a malformed response is retried too.
            reply = message.content[0].text
        except Exception as e:
            print(f'API error (attempt {attempt+1}): {e}')
            if attempt == retries - 1:
                raise
            time.sleep(API_RETRY_DELAY * (attempt + 1))
        else:
            time.sleep(API_RATE_LIMIT_DELAY)
            return reply

# Connectivity check: one tiny request to confirm the key and client work.
test_response = call_claude(
    system_prompt="You output ONLY JSON.",
    user_prompt='Respond with exactly: {"test": "ok"}',
    max_tokens=50,
)
print(f'API test: {test_response}')

## 9. Experiment Prompts

In [None]:
# System prompt shared by every trial: instructs the model to reply with
# nothing but {"final": <integer>}.
EXPERIMENT_SYSTEM_PROMPT = """You are a calculator that outputs ONLY JSON.

CRITICAL RULES:
1. Your output MUST start with the character '{'
2. Your output MUST be exactly: {"final": <number>}
3. Replace <number> with an integer (the numerical answer)
4. Do NOT write ANY explanation, reasoning, or text before or after the JSON
5. Do NOT write "I need to" or "Let me" or any other words
6. ONLY output the JSON object, nothing else

CORRECT OUTPUT EXAMPLE:
{"final": 42}
"""

def create_experiment_prompt(problem: GSM8KProblem, cot_text: str) -> Tuple[str, str]:
    """Build the (system prompt, user prompt) pair for one trial.

    Args:
        problem: supplies the question text placed at the top of the user prompt.
        cot_text: rendered reasoning trace, inserted verbatim and presented to
            the model as given facts.

    Returns:
        A tuple (EXPERIMENT_SYSTEM_PROMPT, user_prompt).
    """
    # Doubled braces are f-string escapes; the model sees single '{' and '}'.
    user = f"""Problem: {problem.question}

Reasoning trace (use these steps as given facts):
{cot_text}

Based on the trace above, compute the final numerical answer.
OUTPUT ONLY: {{"final": <number>}}
START YOUR RESPONSE WITH '{{'"""
    return EXPERIMENT_SYSTEM_PROMPT, user

def parse_model_answer(response: str) -> Optional[int]:
    """Extract the model's integer answer from its raw text output.

    Patterns are tried from strictest to loosest; the first hit wins:
      1. ``{"final": N}`` anywhere in the text,
      2. the same with single or double quotes around ``final``,
      3. ``"final": N`` without the surrounding braces,
      4. the last free-standing number in the text.

    A decimal value is converted with ``round()`` (ties go to the even
    integer) before being returned as an int.

    Args:
        response: raw model output.

    Returns:
        The parsed integer, or None if the text contains no usable number.
    """
    number = r'(-?\d+(?:\.\d+)?)'
    structured_patterns = (
        r'\{\s*"final"\s*:\s*' + number + r'\s*\}',
        r"\{\s*[\"']final[\"']\s*:\s*" + number + r"\s*\}",
        r'"final"\s*:\s*' + number,
    )
    for pattern in structured_patterns:
        match = re.search(pattern, response)
        if match:
            return int(round(float(match.group(1))))

    # Fallback: last free-standing number. The delimiters are lookarounds, not
    # consumed characters: if the whitespace after one number were consumed it
    # could not also serve as the delimiter before the next, so the true last
    # number would be missed (e.g. "10 20" would yield 10).
    loose_numbers = re.findall(r'(?<!\S)' + number + r'(?=\s|$|[.,])', response)
    if loose_numbers:
        return int(round(float(loose_numbers[-1])))
    return None

## 10. Run Experiment

In [None]:
def run_experiment(
    problem: GSM8KProblem,
    trace: SingleStepTrace
) -> ExperimentResult:
    """Run one trial: prompt the model with `trace` for `problem` and score the reply.

    The reply counts as correct only when an integer could be parsed from it
    and that integer equals the problem's reference answer.
    """
    prompts = create_experiment_prompt(problem, trace.full_text)
    raw_reply = call_claude(*prompts, max_tokens=API_MAX_TOKENS_ANSWER)

    parsed = parse_model_answer(raw_reply)
    expected = problem.final_answer

    return ExperimentResult(
        problem_index=problem.index,
        condition=trace.condition,
        L=trace.L,
        corrupted_step=trace.corrupted_step,
        corruption_type=trace.corruption_type,
        final_cue_clean=trace.final_cue_clean,
        model_answer=parsed,
        correct_answer=expected,
        # An unparseable reply (None) is always scored as incorrect.
        is_correct=parsed is not None and parsed == expected,
        raw_output=raw_reply,
        timestamp=datetime.now().isoformat(),
    )

In [None]:
# Only problems that have a clean trace can be run. Counting them up front
# keeps the announced inference count (and the progress bar total) accurate
# and makes skipped problems visible instead of silent.
runnable_problems = [p for p in problems if p.index in trace_map]
n_skipped = len(problems) - len(runnable_problems)

print('='*70)
print('E5: SINGLE-STEP CORRUPTION EXPERIMENT')
print('='*70)
print(f'Conditions: Step1-Only, Step10-Only')
print(f'Expected inferences: {len(runnable_problems) * 2}')
if n_skipped:
    print(f'WARNING: skipping {n_skipped} problem(s) with no clean trace')
print('='*70)

results_step1 = []
results_step10 = []
traces_log = []

# (step to corrupt, list collecting that condition's results); for each
# problem the conditions are run in this order.
conditions = ((1, results_step1), (10, results_step10))

for prob in tqdm(runnable_problems, desc='Single-step experiment'):
    clean_trace = trace_map[prob.index]

    for target_step, condition_results in conditions:
        # Seed depends on the problem and the condition label, so each
        # (problem, condition) pair gets its own reproducible corruption.
        seed = derive_seed(E5_SEED, prob.index, L, f'step{target_step}_only')
        trace = create_single_step_trace(clean_trace, target_step, seed)
        traces_log.append(asdict(trace))
        condition_results.append(run_experiment(prob, trace))

print(f'\nCompleted: {len(results_step1)} Step1-Only + {len(results_step10)} Step10-Only')

## 11. Save Results

In [None]:
# Persist the per-trial results (Step1-Only first, then Step10-Only) and the
# corrupted traces that were generated for them.
all_results = results_step1 + results_step10

results_file = f'{SAVE_DIR}/results/E5_single_step_results.json'
save_json(list(map(asdict, all_results)), results_file)
print(f'Results saved: {results_file}')

traces_file = f'{SAVE_DIR}/results/E5_single_step_traces.json'
save_json(traces_log, traces_file)
print(f'Traces saved: {traces_file}')

## 12. Analysis

In [None]:
# One DataFrame per condition; accuracy is the mean of the is_correct column.
df_step1 = pd.DataFrame(list(map(asdict, results_step1)))
df_step10 = pd.DataFrame(list(map(asdict, results_step10)))

step1_acc = df_step1['is_correct'].mean()
step10_acc = df_step10['is_correct'].mean()

# Results table, emitted in a single print call (one item per output line).
print(
    '='*70,
    'E5 RESULTS',
    '='*70,
    f'{"Condition":<25} {"Accuracy":>10} {"N":>6} {"Final cue":>12}',
    '-'*55,
    f'{"Step1-Only":<25} {step1_acc:>9.1%} {len(df_step1):>6} {"CLEAN":>12}',
    f'{"Step10-Only":<25} {step10_acc:>9.1%} {len(df_step10):>6} {"CORRUPTED":>12}',
    '-'*55,
    f'Difference (Step10 - Step1): {step10_acc - step1_acc:+.1%}',
    '='*70,
    sep='\n',
)

In [None]:
from scipy import stats

# McNemar's test: the two conditions are run on the same problems, so the
# outcomes are paired and only the discordant pairs (b, c) carry information.
merged = pd.merge(
    df_step1[['problem_index', 'is_correct']].rename(columns={'is_correct': 'step1'}),
    df_step10[['problem_index', 'is_correct']].rename(columns={'is_correct': 'step10'}),
    on='problem_index'
)

step1_ok = merged['step1'].astype(bool)
step10_ok = merged['step10'].astype(bool)

a = (step1_ok & step10_ok).sum()      # both correct
b = (step1_ok & ~step10_ok).sum()     # correct under Step1-Only only
c = (~step1_ok & step10_ok).sum()     # correct under Step10-Only only
d = (~step1_ok & ~step10_ok).sum()    # both wrong

print('\nContingency Table:')
print(f'  Both correct: {a}')
print(f'  Step1 only correct: {b}')
print(f'  Step10 only correct: {c}')
print(f'  Both wrong: {d}')

if b + c > 0:
    # Chi-square statistic with continuity correction, 1 degree of freedom.
    chi2 = (abs(b - c) - 1)**2 / (b + c)
    # sf() is the survival function (1 - cdf) evaluated directly; it keeps
    # precision for very small p-values where `1 - cdf` rounds to 0.
    p_value = stats.chi2.sf(chi2, df=1)
    print(f'\nMcNemar test: chi2 = {chi2:.2f}, p = {p_value:.6f}')

    if p_value < 0.001:
        print('*** Highly significant (p < 0.001)')
    elif p_value < 0.01:
        print('** Significant (p < 0.01)')
    elif p_value < 0.05:
        print('* Significant (p < 0.05)')
    else:
        print('Not significant (p >= 0.05)')
else:
    # Without discordant pairs the statistic is undefined; say so explicitly
    # rather than printing nothing.
    print('\nMcNemar test: not applicable (no discordant pairs)')

In [None]:
# Interpretation: choose a verdict from the accuracy gap between conditions
# (more than 20 points, more than 10 points, or neither).
print('\n' + '='*70)
print('INTERPRETATION')
print('='*70)

if step1_acc > step10_acc + 0.20:
    verdict_lines = [
        'RESULT: Step1-Only >> Step10-Only',
        'CONCLUSION: ANSWER-CUE DOMINANCE CONFIRMED',
        '',
        'Corrupting Step 1 (first step) has minimal impact.',
        'Corrupting Step 10 (Final cue) is catastrophic.',
        '',
        'This directly proves:',
        '  1. The model relies heavily on Step 10 (Final=...)',
        '  2. Earlier steps are dispensable (redundant)',
        '  3. E2/E3 effects are due to Final cue destruction,',
        '     NOT sequential chain integration (depth)',
    ]
elif step1_acc > step10_acc + 0.10:
    verdict_lines = [
        'RESULT: Step1-Only > Step10-Only',
        '  Partial support for answer-cue dominance.',
    ]
else:
    verdict_lines = [
        'RESULT: Unexpected',
        '  Need further investigation.',
    ]

print('\n'.join(verdict_lines))

## 13. Summary

In [None]:
# Final recap of the run. The difference is labelled with its direction
# because the analysis cell reports the opposite one (Step10 - Step1), and an
# unlabelled value here would appear to contradict it.
print('='*70)
print('E5 EXPERIMENT COMPLETE')
print('='*70)
print(f'Date: {EXPERIMENT_DATE}')
print(f'Total experiments: {len(all_results)}')
print(f'Step1-Only accuracy: {step1_acc:.1%}')
print(f'Step10-Only accuracy: {step10_acc:.1%}')
print(f'Difference (Step1 - Step10): {step1_acc - step10_acc:+.1%}')
print(f'\nFiles saved to: {SAVE_DIR}')
print('='*70)