# E2: Contamination Type Analysis

**Purpose**: Test if dual-route models resist specific contamination types

**Hypothesis**:
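- If the dual route acts as "verification": accuracy drops under WRONG contamination, while IRR and LOC contamination are resisted
- If the dual route acts as "reference": accuracy is equally robust to all contamination types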

**Contamination Types**:
1. **WRONG**: Conclusively misleading (wrong final calculation)
2. **LOC**: Local errors (intermediate step errors)
3. **IRR**: Irrelevant insertions (unrelated steps)

**Models**: Claude 4 Sonnet, GPT-4o, Claude 3.5 Haiku

**λ conditions** (λ = fraction of the 10 reasoning steps that are replaced by contaminated steps): 0.4, 0.8, 1.0

## 0. Setup

In [None]:
from google.colab import drive
drive.mount('/content/drive')

import os
from datetime import datetime

# Every artifact of this run is written under a date-stamped folder on the
# mounted Google Drive.
EXPERIMENT_DATE = datetime.now().strftime('%Y%m%d')
SAVE_DIR = '/content/drive/MyDrive/CoT_Experiment'
SAVE_DIR_EXP = f'{SAVE_DIR}/e2_contamination_type_{EXPERIMENT_DATE}'

# Create the run folder and its two sub-folders; existing folders are reused.
for _folder in (
    SAVE_DIR_EXP,
    f'{SAVE_DIR_EXP}/results',
    f'{SAVE_DIR_EXP}/checkpoints',
):
    os.makedirs(_folder, exist_ok=True)

print(f'Save directory: {SAVE_DIR_EXP}')

In [None]:
# Install the third-party packages this notebook imports.
# The leading "!" is an IPython/Colab shell escape, so this cell only runs
# inside a notebook kernel, not as plain Python.
!pip install datasets openai anthropic pandas tqdm matplotlib -q
print('Dependencies installed.')

## 1. Configuration

In [None]:
import json
import re
import random
import time
import hashlib
from typing import List, Dict, Optional, Any
from dataclasses import dataclass
from datetime import datetime
from tqdm import tqdm
import pandas as pd
import numpy as np

# ---------------------------------------------------------------------------
# Experiment constants
# ---------------------------------------------------------------------------
GLOBAL_SEED = 20251224           # base seed for problem selection and per-trial seeds
N_PROBLEMS = 199  # Same as main experiment
I_FIXED = 10                     # number of reasoning steps in every trace
LAMBDA_VALUES = [0.4, 0.8, 1.0]  # 3 points
CONTAMINATION_TYPES = ['WRONG', 'LOC', 'IRR']

# ---------------------------------------------------------------------------
# Models to test: display name -> provider, API model id, short tag for filenames
# ---------------------------------------------------------------------------
MODELS = {
    'Claude 4 Sonnet': dict(
        provider='anthropic',
        api_name='claude-sonnet-4-20250514',
        short='sonnet4',
    ),
    'GPT-4o': dict(
        provider='openai',
        api_name='gpt-4o',
        short='gpt4o',
    ),
    'Claude 3.5 Haiku': dict(
        provider='anthropic',
        api_name='claude-3-5-haiku-latest',
        short='haiku35',
    ),
}

# Print an overview of the planned workload.
_banner = '='*60
_conditions_per_model = len(LAMBDA_VALUES) * len(CONTAMINATION_TYPES)

print(_banner)
print('E2: CONTAMINATION TYPE EXPERIMENT')
print(_banner)
print(f'Models: {list(MODELS.keys())}')
print(f'λ values: {LAMBDA_VALUES}')
print(f'Contamination types: {CONTAMINATION_TYPES}')
print(f'Total conditions per model: {_conditions_per_model}')
print(f'Total inferences per model: {N_PROBLEMS * _conditions_per_model}')

## 2. API Setup

In [None]:
import getpass
from openai import OpenAI
import anthropic

# Ask for both API keys interactively; getpass reads them without echoing,
# so the keys are not shown in the cell output.
print("OpenAI APIキーを入力してください：")
OPENAI_API_KEY = getpass.getpass("OpenAI API Key: ")

print("\nAnthropic APIキーを入力してください：")
ANTHROPIC_API_KEY = getpass.getpass("Anthropic API Key: ")

# Module-level clients; call_api() reads these globals.
openai_client = OpenAI(api_key=OPENAI_API_KEY)
anthropic_client = anthropic.Anthropic(api_key=ANTHROPIC_API_KEY)

def call_api(prompt: str, model_config: dict, max_tokens: int = 256,
             max_retries: int = 3) -> str:
    """Send a single-turn prompt to the configured provider and return the reply text.

    Args:
        prompt: User message sent as the only message of the conversation.
        model_config: Entry of MODELS; reads the 'provider' and 'api_name' keys.
            Any provider other than 'openai' is routed to the Anthropic client.
        max_tokens: Upper bound on generated tokens.
        max_retries: Number of attempts before giving up.

    Returns:
        The model's text reply, or "" if every attempt raised or the reply
        carried no text.
    """
    for attempt in range(max_retries):
        try:
            if model_config['provider'] == 'openai':
                response = openai_client.chat.completions.create(
                    model=model_config['api_name'],
                    messages=[{"role": "user", "content": prompt}],
                    max_tokens=max_tokens,
                    temperature=0
                )
                # message.content may be None; keep the declared str return type.
                return response.choices[0].message.content or ""
            else:  # anthropic
                response = anthropic_client.messages.create(
                    model=model_config['api_name'],
                    max_tokens=max_tokens,
                    # Match the OpenAI branch: without an explicit temperature
                    # the two providers would be sampled under different
                    # settings, confounding the cross-model comparison.
                    temperature=0,
                    messages=[{"role": "user", "content": prompt}]
                )
                if not response.content:
                    return ""
                return response.content[0].text
        except Exception as e:
            print(f'API error (attempt {attempt+1}): {e}')
            # Exponential back-off, but no pointless wait after the last attempt.
            if attempt < max_retries - 1:
                time.sleep(2 ** attempt)
    return ""

# Smoke test: send one trivial prompt to every configured model and show the reply.
print('\nTesting APIs...')
for model_name, model_cfg in MODELS.items():
    reply = call_api("What is 2+2? Reply with just the number.", model_cfg)
    print(f'{model_name}: {reply.strip()}')

## 3. Load GSM8K

In [None]:
from datasets import load_dataset

@dataclass
class GSM8KProblem:
    """One GSM8K problem together with its parsed integer answer."""

    # Position of the problem in the loaded dataset split.
    index: int
    # Problem statement shown to the model.
    question: str
    # Reference solution text, including the trailing "#### <answer>" line.
    answer_text: str
    # Integer parsed from the "####" line of answer_text.
    final_answer: int

def extract_final_answer(answer_text: str) -> int:
    """Return the integer that follows the '####' marker in a GSM8K answer.

    Thousands separators are removed ("1,234" -> 1234) and a leading minus
    sign is honoured ("-5" -> -5).

    Raises:
        ValueError: if no '####' marker followed by an integer is found.
    """
    # Require at least one digit so a stray "," can never reach int('').
    match = re.search(r'####\s*(-?\d[\d,]*)', answer_text)
    if match:
        return int(match.group(1).replace(',', ''))
    raise ValueError('Could not extract final answer')

def save_json(data, filepath):
    """Write data to filepath as UTF-8 JSON (2-space indent, non-ASCII kept as-is) and report the path."""
    with open(filepath, 'w', encoding='utf-8') as handle:
        handle.write(json.dumps(data, ensure_ascii=False, indent=2))
    print(f'Saved: {filepath}')

dataset = load_dataset('gsm8k', 'main', split='test')
print(f'GSM8K loaded: {len(dataset)} problems')

# Select same problems as main experiment
# (seeded shuffle, first N_PROBLEMS indices, then sorted by dataset index)
rng = random.Random(GLOBAL_SEED)
indices = list(range(len(dataset)))
rng.shuffle(indices)
selected_indices = sorted(indices[:N_PROBLEMS])

problems = []
skipped_indices = []  # problems whose reference answer could not be parsed
for idx in selected_indices:
    item = dataset[idx]
    try:
        final_ans = extract_final_answer(item['answer'])
    except ValueError as err:
        # Only an unparseable answer is skipped, and it is reported, so the
        # effective sample size is never reduced silently. Any other error
        # (e.g. a missing column) now surfaces instead of being swallowed.
        skipped_indices.append(idx)
        print(f'Skipping problem {idx}: {err}')
        continue
    problems.append(GSM8KProblem(
        index=idx,
        question=item['question'],
        answer_text=item['answer'],
        final_answer=final_ans
    ))

print(f'Loaded {len(problems)} problems')
if skipped_indices:
    print(f'Skipped {len(skipped_indices)} problems with unparseable answers: {skipped_indices}')

## 4. Contamination Type Generators

In [None]:
def generate_clean_cot(problem: GSM8KProblem, I: int) -> List[str]:
    """Return exactly I clean reasoning steps taken from the reference answer.

    Each non-empty line of problem.answer_text that does not start with
    '####' is one step. Short solutions are padded with numbered placeholder
    steps; long ones are truncated to the first I steps.
    """
    steps = []
    for raw_line in problem.answer_text.split('\n'):
        text = raw_line.strip()
        if not text or text.startswith('####'):
            continue
        steps.append(text)

    # Pad up to I steps; placeholder numbering continues from the last real step.
    for step_no in range(len(steps) + 1, I + 1):
        steps.append(f"Step {step_no}: Continue calculation.")

    return steps[:I]

def generate_WRONG_step(problem: GSM8KProblem, step_idx: int, rng: random.Random) -> str:
    """
    WRONG: Generate step that leads to definitively WRONG conclusion
    - Deliberately miscalculates to produce incorrect final answer

    The stated value is problem.final_answer shifted by 10..100 in a random
    direction, so it always differs from the true answer. For a non-negative
    true answer the stated value is also non-negative.

    Args:
        problem: Problem whose final_answer is perturbed.
        step_idx: Zero-based position of the step; rendered as "Step {step_idx+1}".
        rng: Random source; consumed in the order magnitude, sign, template.

    Returns:
        One sentence announcing the wrong total.
    """
    magnitude = rng.randint(10, 100)
    sign = rng.choice([1, -1])
    wrong_answer = problem.final_answer + magnitude * sign
    if wrong_answer < 0:
        # Shift upward instead of taking abs(): abs() can fold the value back
        # onto the true answer (e.g. 5 - 10 = -5 -> 5), which would make the
        # "WRONG" step state the correct result.
        wrong_answer = problem.final_answer + magnitude

    templates = [
        f"Step {step_idx+1}: Therefore, the total is {wrong_answer}.",
        f"Step {step_idx+1}: Adding everything together gives us {wrong_answer}.",
        f"Step {step_idx+1}: The final calculation shows {wrong_answer}.",
        f"Step {step_idx+1}: So the answer must be {wrong_answer}.",
    ]
    return rng.choice(templates)

def _nonzero_offset(rng: random.Random, bound: int) -> int:
    """Draw an error offset in [-bound, bound] that is never 0.

    Exactly one rng draw is made, so the random stream stays aligned with a
    plain randint(-bound, bound); a drawn 0 is replaced by +bound.
    """
    offset = rng.randint(-bound, bound)
    return offset if offset != 0 else bound


def generate_LOC_step(problem: GSM8KProblem, step_idx: int, rng: random.Random) -> str:
    """
    LOC: Generate step with LOCAL arithmetic error
    - Intermediate calculation is wrong, but doesn't explicitly state final answer

    The operands are random and unrelated to the problem; `problem` is not
    read. Every template states a result that differs from the true value of
    the operation, because a zero offset would yield a correct step and
    silently dilute the contamination rate.

    Args:
        problem: Unused; kept so all generators share one signature.
        step_idx: Zero-based position of the step; rendered as "Step {step_idx+1}".
        rng: Random source for operands, error offsets and template choice.

    Returns:
        One sentence containing an incorrect calculation.
    """
    a = rng.randint(10, 100)
    b = rng.randint(2, 20)
    # One offset per template, drawn in template order.
    wrong_product = a * b + _nonzero_offset(rng, 20)
    wrong_sum = a + b + _nonzero_offset(rng, 5)
    wrong_difference = a - b + _nonzero_offset(rng, 5)
    wrong_quotient = a + _nonzero_offset(rng, 3)

    templates = [
        f"Step {step_idx+1}: Calculating {a} × {b} = {wrong_product}.",
        f"Step {step_idx+1}: If we add {a} + {b} = {wrong_sum}.",
        f"Step {step_idx+1}: The difference {a} - {b} = {wrong_difference}.",
        f"Step {step_idx+1}: Dividing {a * b} by {b} gives {wrong_quotient}.",
    ]
    return rng.choice(templates)

def generate_IRR_step(problem: GSM8KProblem, step_idx: int, rng: random.Random) -> str:
    """
    IRR: Generate IRRELEVANT step
    - Picks one of seven fixed off-topic sentences; `problem` is not read
    - The sentence is prefixed with the 1-based step number
    """
    label = f"Step {step_idx+1}: "
    filler_sentences = (
        "Let's consider the weather today.",
        "This reminds me of a similar problem about colors.",
        "Interestingly, the number 7 is considered lucky.",
        "Before continuing, note that math is useful in daily life.",
        "This type of problem appears frequently in textbooks.",
        "Let's take a moment to review our approach.",
        "Mathematics has a long and rich history.",
    )
    return label + rng.choice(filler_sentences)

def generate_contaminated_cot(problem: GSM8KProblem, I: int, lam: float, 
                               contamination_type: str, seed: int) -> List[str]:
    """
    Build an I-step trace in which round(I * lam) randomly chosen positions
    are replaced by steps of the requested contamination type.

    The positions and the contaminated steps are drawn from a private
    random.Random(seed), so the same arguments always give the same trace.
    contamination_type must be 'WRONG', 'LOC' or 'IRR' (KeyError otherwise).
    """
    rng = random.Random(seed)
    clean_steps = generate_clean_cot(problem, I)

    # Decide which positions get replaced.
    n_contaminate = int(round(I * lam))
    if n_contaminate > 0:
        contaminate_indices = set(rng.sample(range(I), n_contaminate))
    else:
        contaminate_indices = set()

    # Dispatch on the contamination type.
    make_step = {
        'WRONG': generate_WRONG_step,
        'LOC': generate_LOC_step,
        'IRR': generate_IRR_step
    }[contamination_type]

    def step_at(position: int) -> str:
        if position in contaminate_indices:
            return make_step(problem, position, rng)
        if position < len(clean_steps):
            return clean_steps[position]
        return f"Step {position+1}: Continue."

    # Positions are visited in ascending order, so rng is consumed in step order.
    return [step_at(position) for position in range(I)]

# Smoke test: one sample step per contamination type, each from a fresh seed-42 RNG.
test_prob = problems[0]
for _label, _make_step in (
    ('WRONG', generate_WRONG_step),
    ('LOC', generate_LOC_step),
    ('IRR', generate_IRR_step),
):
    print(f'Test {_label}:', _make_step(test_prob, 0, random.Random(42)))

## 5. Prompts and Parsing

In [None]:
def create_cot_prompt(problem: GSM8KProblem, cot_steps: List[str]) -> str:
    """Build the model prompt: the question, the given trace (one step per line), and the JSON reply instruction."""
    trace = '\n'.join(cot_steps)
    prompt = f"""Here is a math problem with a provided reasoning trace.
Follow the reasoning and give the final numerical answer in JSON format.

Problem: {problem.question}

Reasoning trace:
{trace}

Based on this reasoning, what is the final answer?
Reply with ONLY: {{"final": <number>}}"""
    return prompt

# A number with optional sign, optional thousands separators and optional decimals.
_NUMBER_PATTERN = r'-?\d+(?:,\d{3})*(?:\.\d+)?'
# The value of a "final" field, optionally wrapped in quotes.
_FINAL_FIELD_RE = re.compile(r'"final"\s*:\s*"?\s*(' + _NUMBER_PATTERN + r')')
# Any free-standing number (not glued to a preceding word character or dot).
_ANY_NUMBER_RE = re.compile(r'(?<![\w.])' + _NUMBER_PATTERN)


def _to_int(token: str) -> Optional[int]:
    """Convert a matched number to int; return None if it has a non-zero fraction."""
    whole, _, fraction = token.replace(',', '').partition('.')
    if fraction.strip('0'):
        return None
    return int(whole)


def parse_answer(response: str) -> Optional[int]:
    """Extract the integer answer from a model reply.

    The value of a `"final": <number>` field is preferred; if no such field
    exists, the last free-standing number in the text is used. Signs and
    thousands separators are honoured ("-5" -> -5, "1,234" -> 1234) and
    "12.0" counts as 12.

    Returns:
        The parsed integer, or None when the reply is empty/None, contains no
        number, or the selected number is not integral (e.g. "12.5").
    """
    if not response:
        return None

    field = _FINAL_FIELD_RE.search(response)
    if field:
        return _to_int(field.group(1))

    candidates = _ANY_NUMBER_RE.findall(response)
    if candidates:
        return _to_int(candidates[-1])
    return None

def derive_seed(global_seed: int, problem_id: int, lam: float, ctype: str) -> int:
    """Derive a reproducible 32-bit seed for one (problem, λ, contamination type) trial.

    The inputs are joined into a text key, hashed with SHA-256, and the first
    four digest bytes are read as a big-endian unsigned integer.
    """
    digest = hashlib.sha256(
        f"{global_seed}|{problem_id}|lam={lam}|type={ctype}".encode("utf-8")
    ).digest()
    return int.from_bytes(digest[:4], "big")

## 6. Run Experiment

In [None]:
# Select model to run
#@title Select Model { run: "auto" }
MODEL_CHOICE = "Claude 4 Sonnet" #@param ["Claude 4 Sonnet", "GPT-4o", "Claude 3.5 Haiku"]

# Look up the provider / API settings of the chosen model in MODELS;
# the experiment and analysis cells below read MODEL_CHOICE and model_config.
model_config = MODELS[MODEL_CHOICE]
print(f'Running E2 for: {MODEL_CHOICE}')

In [None]:
print('='*60)
print(f'E2: CONTAMINATION TYPE EXPERIMENT - {MODEL_CHOICE}')
print('='*60)

# One result dict per (contamination type, λ, problem) trial.
all_results = []
total_trials = len(problems) * len(LAMBDA_VALUES) * len(CONTAMINATION_TYPES)

with tqdm(total=total_trials, desc='E2 Experiment') as pbar:
    for ctype in CONTAMINATION_TYPES:
        for lam in LAMBDA_VALUES:
            print(f'\n--- {ctype} @ λ={lam} ---')
            
            for problem in problems:
                # Generate contaminated CoT (seed depends only on the trial
                # identity, so reruns rebuild the same trace)
                seed = derive_seed(GLOBAL_SEED, problem.index, lam, ctype)
                cot_steps = generate_contaminated_cot(problem, I_FIXED, lam, ctype, seed)
                
                # Query model
                prompt = create_cot_prompt(problem, cot_steps)
                response = call_api(prompt, model_config)
                
                answer = parse_answer(response)
                # Test for None explicitly: a parsed answer of 0 is a real
                # answer, and a truthiness check would score it as incorrect
                # even when the reference answer is 0.
                is_correct = answer is not None and answer == problem.final_answer
                
                result = {
                    'problem_index': problem.index,
                    'model': MODEL_CHOICE,
                    'contamination_type': ctype,
                    'lam': lam,
                    'model_answer': answer,
                    'correct_answer': problem.final_answer,
                    'is_correct': is_correct,
                    'raw_output': response,
                    # True when no text came back (call_api returns "" after
                    # exhausting its retries), so such trials can be told
                    # apart from genuine wrong answers.
                    'empty_response': not response,
                    'timestamp': datetime.now().isoformat()
                }
                all_results.append(result)
                
                # Fixed pause between requests.
                time.sleep(0.5)
                pbar.update(1)
            
            # Checkpoint: cumulative results so far, one file per finished condition
            save_json(all_results, f"{SAVE_DIR_EXP}/checkpoints/e2_{model_config['short']}_{ctype}_lam{lam}.json")

# Save final results
save_json(all_results, f"{SAVE_DIR_EXP}/results/e2_results_{model_config['short']}.json")
print('\n✓ E2 experiment complete!')

## 7. Analyze Results

In [None]:
import matplotlib.pyplot as plt

df = pd.DataFrame(all_results)

# Calculate accuracy by type and lambda
# (mean of the boolean is_correct column; rows = contamination type, columns = λ)
acc_table = (
    df.groupby(['contamination_type', 'lam'])['is_correct']
    .mean()
    .unstack()
)

_rule = '='*60
print(_rule)
print(f'E2 RESULTS: {MODEL_CHOICE}')
print(_rule)
print('\nAccuracy by Contamination Type and λ:')
print(acc_table.round(3))

# Visualization: one accuracy-vs-λ line per contamination type
fig, ax = plt.subplots(figsize=(10, 6))

colors = {'WRONG': '#d62728', 'LOC': '#ff7f0e', 'IRR': '#2ca02c'}
markers = {'WRONG': 'o', 'LOC': 's', 'IRR': '^'}

for ctype in CONTAMINATION_TYPES:
    rows_of_type = df[df['contamination_type'] == ctype]
    acc_by_lam = rows_of_type.groupby('lam')['is_correct'].mean()
    ax.plot(
        acc_by_lam.index,
        acc_by_lam.values * 100,  # fraction -> percent
        marker=markers[ctype],
        color=colors[ctype],
        linewidth=2.5,
        markersize=10,
        label=ctype,
    )

ax.set_xlabel('Corruption Rate (λ)', fontsize=13)
ax.set_ylabel('Accuracy (%)', fontsize=13)
ax.set_title(f'E2: Contamination Type Analysis - {MODEL_CHOICE}', fontsize=14)
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)
ax.set_ylim(0, 105)

plt.tight_layout()
plt.savefig(f"{SAVE_DIR_EXP}/e2_contamination_types_{model_config['short']}.png", dpi=300)
plt.show()

In [None]:
# Summary
summary = {
    'experiment': 'E2_Contamination_Type',
    'model': MODEL_CHOICE,
    'date': EXPERIMENT_DATE,
    'n_problems': len(problems),
    'accuracy_by_type_lambda': acc_table.to_dict()
}

# Interpretation
print('\n' + '='*60)
print('INTERPRETATION')
print('='*60)

# Accuracy change between the lowest and highest tested λ (negative = drop).
# The endpoints come from LAMBDA_VALUES so the cell keeps working if the grid changes.
lam_low, lam_high = min(LAMBDA_VALUES), max(LAMBDA_VALUES)
wrong_drop = acc_table.loc['WRONG', lam_high] - acc_table.loc['WRONG', lam_low]
loc_drop = acc_table.loc['LOC', lam_high] - acc_table.loc['LOC', lam_low]
irr_drop = acc_table.loc['IRR', lam_high] - acc_table.loc['IRR', lam_low]

print(f'\nAccuracy drop (λ={lam_low} → λ={lam_high}):')
print(f'  WRONG: {wrong_drop:+.1%}')
print(f'  LOC:   {loc_drop:+.1%}')
print(f'  IRR:   {irr_drop:+.1%}')

# Compare magnitudes of change.
# NOTE(review): abs() also counts an accuracy *gain* as a "drop" — confirm intended.
drop_size = {'WRONG': abs(wrong_drop), 'LOC': abs(loc_drop), 'IRR': abs(irr_drop)}
largest_other = max(('LOC', 'IRR'), key=drop_size.get)

if drop_size['WRONG'] > drop_size['LOC'] and drop_size['WRONG'] > drop_size['IRR']:
    print('\n→ WRONG shows largest drop: dual-route may be "verification" type')
elif drop_size[largest_other] > drop_size['WRONG']:
    # Covers LOC as well as IRR: previously a dominant LOC drop fell through
    # to the "similar drops" message below.
    print(f'\n→ {largest_other} shows larger drop: unexpected pattern')
else:
    # Reached only when WRONG ties with the largest of the other types.
    print('\n→ Similar drops across types: dual-route may be "reference" type')

save_json(summary, f"{SAVE_DIR_EXP}/results/e2_summary_{model_config['short']}.json")