# E2: Position Experiment - Early vs Late Corruption

**Paper**: A2 (Redundancy vs Depth in CoT Length Effects)

**Purpose**: Test whether the POSITION of corruption matters.

**Hypothesis**:
- **Depth hypothesis**: Late corruption is more harmful because it breaks the final integration of the reasoning chain → Acc(Late) << Acc(Early)
- **Redundancy hypothesis**: Position doesn't matter; only the COUNT of corrupted steps matters → Acc(Early) ≈ Acc(Late)

**Design**:
- c = 0.4 (4 steps corrupted out of 10)
- L = 10
- Early: Corrupt steps 1-4
- Late: Corrupt steps 7-10

**Expected inference count**: 199 problems × 2 conditions = 398

**Date**: 2025-01-02
**GLOBAL_SEED**: 20251224

## 0. Google Drive Connection

In [None]:
from google.colab import drive
drive.mount('/content/drive')

import os
from datetime import datetime

# Identify this run; the date suffix gives each day's run its own output folder.
EXPERIMENT_NAME = 'E2_position'
EXPERIMENT_DATE = datetime.now().strftime('%Y%m%d')

# Inputs are read from the v3 run; outputs go to the per-run folder.
BASE_DIR = '/content/drive/MyDrive/CoT_Experiment'
V3_DATA_DIR = f'{BASE_DIR}/full_experiment_v3_20251224'
SAVE_DIR = f'{BASE_DIR}/{EXPERIMENT_NAME}_{EXPERIMENT_DATE}'

for _path in (SAVE_DIR, f'{SAVE_DIR}/results'):
    os.makedirs(_path, exist_ok=True)

for _label, _value in (
    ('Experiment', EXPERIMENT_NAME),
    ('V3 data directory', V3_DATA_DIR),
    ('Save directory', SAVE_DIR),
):
    print(f'{_label}: {_value}')

## 1. Install Dependencies

In [None]:
# IPython/Colab shell escape (not plain Python): installs the third-party packages
# for this notebook; -q silences pip's progress output.
# NOTE(review): datasets and matplotlib are not imported in the visible cells — confirm they are still needed.
!pip install datasets anthropic matplotlib pandas tqdm scipy -q
print('Dependencies installed.')

## 2. Configuration

In [None]:
import hashlib
import random
import json
import re
import time
from typing import List, Dict, Tuple, Optional, Any
from dataclasses import dataclass, asdict, field
from datetime import datetime
from tqdm import tqdm
import pandas as pd
import numpy as np

# =============================================================================
# Global Configuration
# =============================================================================
GLOBAL_SEED = 20251224
POSITION_SEED = 20250102  # Additional seed for position experiments

# Experiment parameters
L = 10
C_TARGET = 0.4  # 4 steps corrupted
K_CORRUPT = int(round(C_TARGET * L))  # = 4

# The Early/Late comparison is only meaningful if both windows corrupt the same
# number of steps and do not overlap.
if K_CORRUPT < 1 or 2 * K_CORRUPT > L:
    raise ValueError(
        f'K_CORRUPT={K_CORRUPT} must be between 1 and L // 2 = {L // 2} '
        f'so the Early and Late windows do not overlap'
    )

# Position definitions (1-based step numbers): the first / last K_CORRUPT steps.
# Derived from L and K_CORRUPT so both windows stay consistent if either changes.
EARLY_STEPS = list(range(1, K_CORRUPT + 1))         # [1, 2, 3, 4]
LATE_STEPS = list(range(L - K_CORRUPT + 1, L + 1))  # [7, 8, 9, 10]

# Corruption type ratio
# NOTE(review): assign_corruption_types_fixed does not read this; for K=4 it assigns
# 0 IRR / 1 LOC / 3 WRONG. Confirm which mix the paper reports.
CORRUPTION_RATIO = {'IRR': 1, 'LOC': 2, 'WRONG': 2}

# API settings
API_MAX_TOKENS_ANSWER = 256
API_RETRY_DELAY = 1.0       # seconds; multiplied by the attempt number between retries
API_RATE_LIMIT_DELAY = 0.5  # seconds slept after every successful call

print('='*60)
print('E2: POSITION EXPERIMENT CONFIGURATION')
print('='*60)
print(f'  GLOBAL_SEED: {GLOBAL_SEED}')
print(f'  L (trace length): {L}')
print(f'  c (corruption fraction): {C_TARGET}')
print(f'  K_corrupt: {K_CORRUPT}')
print(f'  Early corruption steps: {EARLY_STEPS}')
print(f'  Late corruption steps: {LATE_STEPS}')
print('='*60)

## 3. Data Structures

In [None]:
@dataclass
class GSM8KProblem:
    """One GSM8K problem record, built from each entry of problems_v3.json."""
    index: int  # problem id; key of prob_map and the lookup key into trace_map
    question: str  # problem statement inserted into the experiment prompt
    answer_text: str  # not read in this notebook; presumably the reference solution text — TODO confirm
    final_answer: int  # gold answer compared against the parsed model answer

@dataclass
class CleanTrace:
    """An uncorrupted reasoning trace for one problem, loaded from clean_traces_I10_v3.json."""
    problem_index: int  # GSM8KProblem.index this trace belongs to; key of trace_map
    I: int  # trace length; recorded as L by create_position_corrupted_trace
    steps: List[str]  # step texts in order; step k (1-based) is steps[k - 1]
    full_text: str  # not read in this notebook; presumably the rendered trace — TODO confirm

@dataclass
class PositionCorruptedTrace:
    """A clean trace with a fixed window of steps corrupted (one E2 condition)."""
    problem_index: int  # GSM8KProblem.index of the source problem
    L: int  # trace length, taken from CleanTrace.I
    c: float  # nominal corruption fraction (C_TARGET)
    position: str  # 'early' or 'late'
    corrupted_steps: List[int]  # 1-based step numbers that were targeted
    corruption_types: Dict[int, str]  # step number -> 'IRR' / 'LOC' / 'WRONG'; int keys become strings when dumped to JSON
    steps: List[str]  # step texts after corruption
    full_text: str  # steps rendered as "Step k: ..." lines between [[COT_START]] and [[COT_END]]
    seed: int  # base seed passed to create_position_corrupted_trace

@dataclass
class ExperimentResult:
    """Outcome of one model query on one position-corrupted trace."""
    problem_index: int  # GSM8KProblem.index
    condition: str  # 'early' or 'late'
    L: int  # trace length
    c: float  # nominal corruption fraction
    K_clean: int  # L minus the number of corrupted steps
    corrupted_steps: List[int]  # 1-based step numbers that were corrupted
    model_answer: Optional[int]  # parsed integer answer, or None if nothing could be parsed
    correct_answer: int  # gold answer (GSM8KProblem.final_answer)
    is_correct: bool  # False whenever model_answer is None
    raw_output: str  # unparsed model response text
    timestamp: str  # naive local time in ISO-8601 format when the result was built

## 4. Utility Functions

In [None]:
def derive_seed(global_seed: int, problem_id: int, L: int, c: float, extra: str = '') -> int:
    """Derive a reproducible 32-bit seed from the experiment coordinates.

    The fields are joined as "global|problem|L=..|c=..|extra", hashed with
    SHA-256, and the first four digest bytes are read as a big-endian unsigned
    integer (equivalently, the first eight hex digits).
    """
    material = "|".join((
        f"{global_seed}",
        f"{problem_id}",
        f"L={L}",
        f"c={c}",
        f"{extra}",
    ))
    digest = hashlib.sha256(material.encode("utf-8")).digest()
    return int.from_bytes(digest[:4], "big")

def save_json(data: Any, filepath: str):
    """Write *data* to *filepath* as UTF-8 JSON, 2-space indented, keeping non-ASCII characters literal."""
    with open(filepath, mode='w', encoding='utf-8') as handle:
        json.dump(obj=data, fp=handle, indent=2, ensure_ascii=False)

def load_json(filepath: str) -> Any:
    """Read the UTF-8 file at *filepath* and return its decoded JSON content."""
    with open(filepath, encoding='utf-8') as handle:
        raw_text = handle.read()
    return json.loads(raw_text)

## 5. Load Existing Data

In [None]:
# --- Problems: one GSM8KProblem per record ---------------------------------
problems_path = f'{V3_DATA_DIR}/problems_v3.json'
problems_data = load_json(problems_path)
problems = [GSM8KProblem(**record) for record in problems_data]
print(f'Loaded {len(problems)} problems')

# --- Clean L=10 traces, indexed by the problem they belong to ---------------
traces_path = f'{V3_DATA_DIR}/clean_traces/clean_traces_I10_v3.json'
traces_data = load_json(traces_path)
clean_traces = [CleanTrace(**record) for record in traces_data]
trace_map = {}
for _trace in clean_traces:
    trace_map[_trace.problem_index] = _trace  # later duplicates overwrite earlier ones
print(f'Loaded {len(clean_traces)} clean traces (L=10)')

prob_map = dict((problem.index, problem) for problem in problems)

## 6. Corruption Templates

In [None]:
# Distractor steps: a correct but unused side computation (IRR corruption).
IRRELEVANT_TEMPLATES = [
    "Compute an auxiliary value: aux = {a} + {b} = {result}, but it will not be used later.",
    "Compute a side quantity: aux = {a} * 2 = {result}, unrelated to the final result.",
    "Note that we can also compute aux = {a} - {b} = {result}, though this is not needed.",
]

# Steps that assert a fabricated fixed value (WRONG corruption).
WRONG_CONSTRAINT_TEMPLATES = [
    "Fix an intermediate condition: set {var} = {wrong_value} as a given constraint for the rest of the steps.",
    "Assume the total is {var} = {wrong_value} and proceed using this fixed value.",
]

def generate_irrelevant_step(step_num: int, seed: int) -> str:
    """Build a seeded, arithmetically correct but irrelevant step.

    Draws a and b in [2, 20], then a template; the stated result matches the
    template's operator (+ -> a + b, * -> a * 2, otherwise a - b).
    ``step_num`` is accepted for interface symmetry and is not used.
    """
    draw = random.Random(seed)
    a = draw.randint(2, 20)
    b = draw.randint(2, 20)
    template = draw.choice(IRRELEVANT_TEMPLATES)
    result = a + b if '+' in template else (a * 2 if '*' in template else a - b)
    return template.format(a=a, b=b, result=result)

def generate_local_error_step(original_step: str, seed: int) -> str:
    """Return *original_step* with its last integer perturbed (LOC corruption).

    The last run of digits in the step, normally the result in "... = N.", is
    replaced by N + offset. The offset is drawn from {-3, -2, -1, 1, 2, 3} and the
    result is clamped at 0. If clamping would leave the value unchanged (N == 0
    with a negative offset), the offset is applied upward instead. The number is
    replaced wherever it sits, whether or not the step ends in "N.".
    A step containing no digits is replaced by a fixed
    "10 * 3 = t" statement whose stated result is never 30, so it is always wrong.

    Args:
        original_step: Text of the clean step (without the "Step k:" prefix).
        seed: Seed for the local RNG; identical inputs give identical output.

    Returns:
        The corrupted step. When the input contains digits it always differs
        from the input.
    """
    rng = random.Random(seed)
    matches = list(re.finditer(r'\d+', original_step))
    if not matches:
        # 10 * 3 == 30, so 30 is excluded to guarantee an actual arithmetic error.
        return f"Compute t = 10 * 3 = {rng.choice([28, 29, 31, 32])} (using the previous values)."

    last = matches[-1]
    original_result = int(last.group())
    offset = rng.choice([-3, -2, -1, 1, 2, 3])
    wrong_result = max(0, original_result + offset)
    if wrong_result == original_result:
        # Only happens when original_result == 0 and offset < 0: flip the offset.
        wrong_result = original_result + abs(offset)

    # Splice by position instead of matching "= N.$", which silently missed steps
    # without a trailing period or with words after the number.
    return original_step[:last.start()] + str(wrong_result) + original_step[last.end():]

def generate_wrong_constraint_step(step_num: int, seed: int) -> str:
    """Build a seeded step that asserts a fabricated fixed value (WRONG corruption).

    Draws, in order: a variable name, a value in [10, 100], and a template from
    WRONG_CONSTRAINT_TEMPLATES. ``step_num`` is accepted for interface symmetry
    and is not used.
    """
    draw = random.Random(seed)
    chosen_var = draw.choice(['x', 'total', 'result', 'n'])
    fake_value = draw.randint(10, 100)
    return draw.choice(WRONG_CONSTRAINT_TEMPLATES).format(var=chosen_var, wrong_value=fake_value)

## 7. Position-Based Corruption Logic

In [None]:
def assign_corruption_types_fixed(corrupted_steps: List[int], seed: int) -> Dict[int, str]:
    """Label each targeted step as IRR, LOC or WRONG using a seeded shuffle.

    With K = len(corrupted_steps), the counts are n_irr = K // 5,
    n_loc = max(1, 2K // 5) and n_wrong = the remainder. For the K = 4 used in
    this experiment that is 0 IRR, 1 LOC and 3 WRONG. This is not the 1:2:2 mix
    named in CORRUPTION_RATIO, which this function does not read.

    A copy of the input is shuffled with random.Random(seed). Labels are then
    handed out in shuffled order (IRR first, then LOC, then WRONG), so the
    returned dict's key order is the shuffled step order. The input list is not
    modified.

    Returns:
        {} for an empty input, otherwise a dict mapping every step to its label.
    """
    n_total = len(corrupted_steps)
    if not n_total:
        return {}

    n_irr = max(0, n_total // 5)
    n_loc = max(1, (2 * n_total) // 5)
    n_wrong = n_total - n_irr - n_loc

    shuffled = corrupted_steps[:]
    random.Random(seed).shuffle(shuffled)

    labels = ["IRR"] * n_irr + ["LOC"] * n_loc + ["WRONG"] * n_wrong
    return dict(zip(shuffled, labels))


def create_position_corrupted_trace(
    clean_trace: CleanTrace,
    position: str,  # 'early' or 'late'
    seed: int
) -> PositionCorruptedTrace:
    """
    Create a trace with corruption at fixed positions.

    - 'early': corrupt the steps in EARLY_STEPS (steps 1-4 with the default config)
    - 'late':  corrupt the steps in LATE_STEPS (steps 7-10 with the default config)

    Each targeted step is replaced according to the type chosen by
    assign_corruption_types_fixed:
    - IRR uses generate_irrelevant_step.
    - LOC uses generate_local_error_step.
    - WRONG uses generate_wrong_constraint_step.
    Each step uses the per-step seed ``seed + step_num * 1000``. All other steps
    are copied unchanged.

    Args:
        clean_trace: Source trace; ``clean_trace.I`` is recorded as L.
        position: 'early' or 'late'.
        seed: Base seed for type assignment and step generation.

    Returns:
        The corrupted trace. Its ``full_text`` is ``[[COT_START]]``, then one
        ``Step k: ...`` line per step, then ``[[COT_END]]``.

    Raises:
        ValueError: If ``position`` is unknown, if the trace's step count differs
            from ``clean_trace.I``, or if a targeted step does not exist.
    """
    L = clean_trace.I

    if position == 'early':
        corrupted_steps = EARLY_STEPS[:]
    elif position == 'late':
        corrupted_steps = LATE_STEPS[:]
    else:
        raise ValueError(f"Unknown position: {position}")

    # Guard the experiment's core assumptions. 'late' must mean the final steps,
    # and every step listed in corrupted_steps must really be corrupted. Otherwise
    # a short or mislabelled trace would silently corrupt fewer steps than
    # recorded, skewing K_clean and the early/late comparison.
    n_steps = len(clean_trace.steps)
    if n_steps != L:
        raise ValueError(
            f"Trace for problem {clean_trace.problem_index} has {n_steps} steps but I={L}"
        )
    missing = [s for s in corrupted_steps if not 1 <= s <= n_steps]
    if missing:
        raise ValueError(
            f"Steps {missing} do not exist in the {n_steps}-step trace "
            f"for problem {clean_trace.problem_index}"
        )

    # Assign corruption types
    corruption_types = assign_corruption_types_fixed(corrupted_steps, seed)

    # Apply corruption (steps are 1-based)
    new_steps = []
    for step_num, step_content in enumerate(clean_trace.steps, start=1):
        if step_num not in corruption_types:
            new_steps.append(step_content)
            continue
        ctype = corruption_types[step_num]
        step_seed = seed + step_num * 1000
        if ctype == 'IRR':
            new_content = generate_irrelevant_step(step_num, step_seed)
        elif ctype == 'LOC':
            new_content = generate_local_error_step(step_content, step_seed)
        else:
            new_content = generate_wrong_constraint_step(step_num, step_seed)
        new_steps.append(new_content)

    # Build full text
    lines = ['[[COT_START]]']
    lines.extend(f'Step {i}: {content}' for i, content in enumerate(new_steps, start=1))
    lines.append('[[COT_END]]')
    full_text = '\n'.join(lines)

    return PositionCorruptedTrace(
        problem_index=clean_trace.problem_index,
        L=L,
        c=C_TARGET,
        position=position,
        corrupted_steps=corrupted_steps,
        corruption_types=corruption_types,
        steps=new_steps,
        full_text=full_text,
        seed=seed
    )

In [None]:
# Test position corruption.
# Preview both conditions for the first trace using exactly the seeds the main
# loop derives, so what is printed here is what the model will actually see.
test_trace = clean_traces[0]
test_seed = derive_seed(POSITION_SEED, test_trace.problem_index, L, C_TARGET, 'early')
late_test_seed = derive_seed(POSITION_SEED, test_trace.problem_index, L, C_TARGET, 'late')

early_trace = create_position_corrupted_trace(test_trace, 'early', test_seed)
late_trace = create_position_corrupted_trace(test_trace, 'late', late_test_seed)

print('Early corruption:')
print(f'  Corrupted steps: {early_trace.corrupted_steps}')
print(f'  Types: {early_trace.corruption_types}')

print('\nLate corruption:')
print(f'  Corrupted steps: {late_trace.corrupted_steps}')
print(f'  Types: {late_trace.corruption_types}')

## 8. API Setup

In [None]:
from getpass import getpass

# Read the key interactively without echoing it, instead of hard-coding it in the cell.
ANTHROPIC_API_KEY = getpass(prompt='Enter Anthropic API Key: ')
print('API Key set.')

In [None]:
import anthropic

# Module-level client used by call_claude; authenticated with the key entered via getpass.
client = anthropic.Anthropic(api_key=ANTHROPIC_API_KEY)

def call_claude(system_prompt: str, user_prompt: str, max_tokens: int = 1024, retries: int = 3) -> str:
    """Send one single-turn request to Claude and return the first content block's text.

    Uses the module-level ``client`` with temperature=0. Any exception raised
    while making the request or reading its text is printed and retried. There
    are at most ``retries`` attempts in total, with a sleep of
    ``API_RETRY_DELAY * attempt_number`` seconds between attempts. After a
    successful call the function sleeps ``API_RATE_LIMIT_DELAY`` seconds to
    throttle the request rate.

    Raises:
        ValueError: If ``retries`` < 1. Otherwise the function would return None
            without calling the API, which violates the ``-> str`` contract.
        Exception: Whatever the final failed attempt raised (re-raised unchanged).
    """
    if retries < 1:
        raise ValueError(f"retries must be >= 1, got {retries}")
    for attempt in range(retries):
        try:
            message = client.messages.create(
                model="claude-sonnet-4-20250514",
                max_tokens=max_tokens,
                messages=[{"role": "user", "content": user_prompt}],
                system=system_prompt,
                temperature=0
            )
            time.sleep(API_RATE_LIMIT_DELAY)
            return message.content[0].text
        except Exception as e:
            print(f'API error (attempt {attempt+1}): {e}')
            if attempt < retries - 1:
                time.sleep(API_RETRY_DELAY * (attempt + 1))
            else:
                raise

# Smoke test: confirms the key works and the model follows a JSON-only instruction.
test_response = call_claude(
    system_prompt="You output ONLY JSON.",
    user_prompt='Respond with exactly: {"test": "ok"}',
    max_tokens=50,
)
print(f'API test: {test_response}')

## 9. Experiment Prompts

In [None]:
EXPERIMENT_SYSTEM_PROMPT = """You are a calculator that outputs ONLY JSON.

CRITICAL RULES:
1. Your output MUST start with the character '{'
2. Your output MUST be exactly: {"final": <number>}
3. Replace <number> with an integer (the numerical answer)
4. Do NOT write ANY explanation, reasoning, or text before or after the JSON
5. Do NOT write "I need to" or "Let me" or any other words
6. ONLY output the JSON object, nothing else

CORRECT OUTPUT EXAMPLE:
{"final": 42}
"""

def create_experiment_prompt(problem: GSM8KProblem, cot_text: str) -> Tuple[str, str]:
    """Return the (system, user) prompt pair for one problem and its rendered trace.

    The user message presents the problem and the trace, tells the model to treat
    the trace steps as given facts, and asks for a bare {"final": <number>} object.
    """
    user_lines = [
        f"Problem: {problem.question}",
        "",
        "Reasoning trace (use these steps as given facts):",
        cot_text,
        "",
        "Based on the trace above, compute the final numerical answer.",
        'OUTPUT ONLY: {"final": <number>}',
        "START YOUR RESPONSE WITH '{'",
    ]
    return EXPERIMENT_SYSTEM_PROMPT, "\n".join(user_lines)

# Value part of a "final" entry. It may be quoted and may use thousands
# separators, e.g. {"final": "1,200"}.
_FINAL_VALUE_PATTERN = r"""["']?\s*(-?\d[\d,]*(?:\.\d+)?)"""
# A complete {"final": N} object (either quote style around the key).
_FINAL_OBJECT_RE = re.compile(r"""\{\s*["']final["']\s*:\s*""" + _FINAL_VALUE_PATTERN + r"""\s*["']?\s*\}""")
# A "final": N key/value pair anywhere in the text.
_FINAL_KEY_RE = re.compile(r"""["']final["']\s*:\s*""" + _FINAL_VALUE_PATTERN)
# Fallback: a whitespace-delimited number. The separators are matched with
# lookarounds rather than consumed, so adjacent numbers such as "5 10" are all
# found and the last one is used.
_STANDALONE_NUMBER_RE = re.compile(r'(?<!\S)(-?\d+(?:\.\d+)?)(?=\s|$|[.,])')


def _numeric_text_to_int(text: str) -> int:
    """Convert a matched numeric string (thousands separators allowed) to a rounded int."""
    return int(round(float(text.replace(',', ''))))


def parse_model_answer(response: str) -> Optional[int]:
    """Extract the model's integer answer from its raw response.

    Strategies are tried in order and the first hit wins:
    1. A JSON-like object ``{"final": N}``.
    2. A ``"final": N`` pair anywhere in the text.
    3. The last standalone number in the text.
    N may be quoted, may use single quotes around the key, and may contain
    thousands separators. Decimals are rounded with ``round`` (ties to even).

    Returns:
        The parsed integer, or None if the response contains no usable number.
    """
    for pattern in (_FINAL_OBJECT_RE, _FINAL_KEY_RE):
        match = pattern.search(response)
        if match:
            return _numeric_text_to_int(match.group(1))
    numbers = _STANDALONE_NUMBER_RE.findall(response)
    if numbers:
        return _numeric_text_to_int(numbers[-1])
    return None

## 10. Run Experiment

In [None]:
def run_experiment(
    problem: GSM8KProblem,
    trace: PositionCorruptedTrace
) -> ExperimentResult:
    """Query the model once on a position-corrupted trace and score its answer.

    The answer counts as correct only if it parsed and equals
    ``problem.final_answer``. K_clean is the number of steps in the trace that
    were not corrupted.
    """
    system_prompt, user_prompt = create_experiment_prompt(problem, trace.full_text)
    raw = call_claude(system_prompt, user_prompt, max_tokens=API_MAX_TOKENS_ANSWER)

    predicted = parse_model_answer(raw)
    correct = predicted is not None and predicted == problem.final_answer
    n_clean = trace.L - len(trace.corrupted_steps)

    return ExperimentResult(
        problem_index=problem.index,
        condition=trace.position,
        L=trace.L,
        c=trace.c,
        K_clean=n_clean,
        corrupted_steps=trace.corrupted_steps,
        model_answer=predicted,
        correct_answer=problem.final_answer,
        is_correct=correct,
        raw_output=raw,
        timestamp=datetime.now().isoformat()
    )

In [None]:
# Only problems with a clean trace can be run. Count them up front so the
# "expected" figure matches the number of API calls actually made.
eligible_problems = [p for p in problems if p.index in trace_map]

print('='*60)
print('E2: POSITION EXPERIMENT')
print('='*60)
print(f'Conditions: Early (steps {EARLY_STEPS[0]}-{EARLY_STEPS[-1]}), '
      f'Late (steps {LATE_STEPS[0]}-{LATE_STEPS[-1]})')
print(f'c = {C_TARGET}, L = {L}, K_corrupt = {K_CORRUPT}')
print(f'Expected inferences: {len(eligible_problems) * 2}')
if len(eligible_problems) < len(problems):
    print(f'  (skipping {len(problems) - len(eligible_problems)} problems with no clean trace)')
print('='*60)

results_early = []
results_late = []
traces_log = []
results_by_position = {'early': results_early, 'late': results_late}

try:
    for prob in tqdm(eligible_problems, desc='Position experiment'):
        clean_trace = trace_map[prob.index]
        for position, position_results in results_by_position.items():
            # The seed includes the condition name, so Early and Late draw
            # independent corruption-type assignments.
            seed = derive_seed(POSITION_SEED, prob.index, L, C_TARGET, position)
            trace = create_position_corrupted_trace(clean_trace, position, seed)
            traces_log.append(asdict(trace))
            position_results.append(run_experiment(prob, trace))
finally:
    # Checkpoint whatever has completed. call_claude re-raises after its retries,
    # and a manual interrupt also lands here; without this, every paid call made
    # so far would be lost.
    try:
        save_json([asdict(r) for r in results_early + results_late],
                  f'{SAVE_DIR}/results/E2_position_results_checkpoint.json')
        save_json(traces_log, f'{SAVE_DIR}/results/E2_position_traces_checkpoint.json')
    except OSError as err:
        print(f'WARNING: could not write checkpoint: {err}')

print(f'\nCompleted: {len(results_early)} early + {len(results_late)} late')

## 11. Save Results

In [None]:
all_results = results_early + results_late
results_dir = f'{SAVE_DIR}/results'

results_file = f'{results_dir}/E2_position_results.json'
save_json([asdict(res) for res in all_results], results_file)
print(f'Results saved: {results_file}')

traces_file = f'{results_dir}/E2_position_traces.json'
save_json(traces_log, traces_file)
print(f'Traces saved: {traces_file}')

## 12. Analysis

In [None]:
df_early, df_late = (
    pd.DataFrame([asdict(res) for res in batch])
    for batch in (results_early, results_late)
)

# Unpaired per-condition accuracy (mean of the boolean is_correct column).
early_acc = df_early['is_correct'].mean()
late_acc = df_late['is_correct'].mean()

banner = '=' * 60
print(banner)
print('E2 RESULTS SUMMARY')
print(banner)
for label, frame, acc in (
    ('Early (steps 1-4 corrupted)', df_early, early_acc),
    ('Late (steps 7-10 corrupted)', df_late, late_acc),
):
    print(f'{label}: {acc:.1%} ({frame["is_correct"].sum()}/{len(frame)})')
print(f'Difference (Late - Early): {late_acc - early_acc:+.1%}')
print(banner)

In [None]:
from scipy import stats

# Paired comparison: the inner merge on problem_index keeps only problems that
# have a result in BOTH conditions.
df_early_m = df_early[['problem_index', 'is_correct']].rename(columns={'is_correct': 'correct_early'})
df_late_m = df_late[['problem_index', 'is_correct']].rename(columns={'is_correct': 'correct_late'})
merged = pd.merge(df_early_m, df_late_m, on='problem_index')

early_ok = merged['correct_early'].astype(bool)
late_ok = merged['correct_late'].astype(bool)

# 2x2 paired contingency table. Only the discordant cells (early-only and
# late-only) carry information about a position effect.
n_both = int((early_ok & late_ok).sum())
n_early_only = int((early_ok & ~late_ok).sum())
n_late_only = int((~early_ok & late_ok).sum())
n_neither = int((~early_ok & ~late_ok).sum())

print('\nContingency Table:')
print(f'  Both correct: {n_both}')
print(f'  Early only: {n_early_only}')
print(f'  Late only: {n_late_only}')
print(f'  Both wrong: {n_neither}')

n_discordant = n_early_only + n_late_only
if n_discordant == 0:
    print('\nMcNemar test: no discordant pairs (Early and Late agree on every problem); test not applicable.')
else:
    # Continuity-corrected McNemar chi-square; sf() avoids the precision loss of 1 - cdf().
    chi2 = (abs(n_early_only - n_late_only) - 1) ** 2 / n_discordant
    p_chi2 = stats.chi2.sf(chi2, df=1)
    # Exact McNemar: under H0 each discordant pair favours either condition with probability 0.5.
    p_exact = stats.binomtest(n_early_only, n_discordant, 0.5).pvalue
    # Common rule of thumb: the chi-square approximation is unreliable below 25 discordant pairs.
    use_exact = n_discordant < 25
    p_value = p_exact if use_exact else p_chi2

    print(f'\nMcNemar test: χ² = {chi2:.2f}, p = {p_chi2:.4f}')
    print(f'Exact McNemar (binomial) test: p = {p_exact:.4f}')
    print(f'Using {"exact" if use_exact else "chi-square"} p-value '
          f'({n_discordant} discordant pairs): p = {p_value:.4f}')

    if p_value > 0.05:
        print('→ No significant difference between Early and Late')
        print('→ Consistent with REDUNDANCY hypothesis (position doesn\'t matter), but a non-significant')
        print('  result is not evidence of equivalence; an equivalence test (e.g. TOST) is needed to claim Early ≈ Late')
    else:
        print('→ Significant difference detected')
        # Direction comes from the discordant pairs, i.e. the same paired data the test used.
        if n_early_only > n_late_only:
            print('→ Late corruption more harmful (supports DEPTH hypothesis)')
        else:
            print('→ Early corruption more harmful')

## 13. Summary

In [None]:
divider = '=' * 60
print(divider)
print('E2 EXPERIMENT COMPLETE')
print(divider)
# Run metadata and headline (unpaired) accuracies.
print(f'Date: {EXPERIMENT_DATE}')
print(f'Total experiments: {len(all_results)}')
print(f'Early accuracy: {early_acc:.1%}')
print(f'Late accuracy: {late_acc:.1%}')
print(f'\nFiles saved to: {SAVE_DIR}')
print(divider)