# CoT Phase Transition Experiment v3 - Full Scale

**Version**: 3.0 (2025-12-24)

**Purpose**: Full-scale experiment with 200 GSM8K problems across all I×λ conditions.

**Design**:
- I (Integration Depth): 5, 10, 15 steps
- λ (Corruption Rate): 0.0, 0.2, 0.4, 0.6, 0.8, 1.0
- A (Alignment) = 1 - λ
- Problems: 200 from GSM8K test set
- Total inferences: 200 × 18 = 3,600 (+ Direct condition — not implemented in this notebook)

**Changes from v2**:
- Full 200 problems (not 5)
- All I levels (5, 10, 15)
- All λ levels (0.0 to 1.0)
- Progress saving and resume capability
- Visualization (heatmaps)
- Versioned file naming

## 0. Google Drive Connection

In [None]:
# Google Drive integration (for persisting experiment data)
from google.colab import drive
drive.mount('/content/drive')

# Versioned directory for experiment data
import os
from datetime import datetime

EXPERIMENT_VERSION = 'v3'
# The run directory is keyed by date. To resume an interrupted run on a later day
# (so the checkpoint and cached traces are found again), set the environment
# variable COT_EXPERIMENT_DATE to the original run's YYYYMMDD before running this
# cell, e.g. os.environ['COT_EXPERIMENT_DATE'] = '20251224'. Otherwise today's
# date is used and a fresh directory is created.
EXPERIMENT_DATE = os.environ.get('COT_EXPERIMENT_DATE') or datetime.now().strftime('%Y%m%d')
SAVE_DIR = '/content/drive/MyDrive/CoT_Experiment'
SAVE_DIR_V3 = f'{SAVE_DIR}/full_experiment_{EXPERIMENT_VERSION}_{EXPERIMENT_DATE}'

os.makedirs(SAVE_DIR, exist_ok=True)
os.makedirs(SAVE_DIR_V3, exist_ok=True)
os.makedirs(f'{SAVE_DIR_V3}/clean_traces', exist_ok=True)
os.makedirs(f'{SAVE_DIR_V3}/corrupted_traces', exist_ok=True)
os.makedirs(f'{SAVE_DIR_V3}/results', exist_ok=True)
os.makedirs(f'{SAVE_DIR_V3}/checkpoints', exist_ok=True)
os.makedirs(f'{SAVE_DIR_V3}/figures', exist_ok=True)

print(f'Experiment version: {EXPERIMENT_VERSION}')
print(f'Experiment date: {EXPERIMENT_DATE}')
print(f'Save directory: {SAVE_DIR_V3}')

## 1. Install Dependencies

In [None]:
# Install third-party packages used below (IPython shell escape, for Colab).
# statsmodels is not listed; the regression cell skips itself if it is missing.
!pip install datasets anthropic matplotlib seaborn pandas tqdm -q
print('Dependencies installed.')

## 2. Configuration & Constants

In [None]:
import hashlib
import random
import json
import re
import time
from typing import List, Dict, Tuple, Optional, Any
from dataclasses import dataclass, asdict, field
from datetime import datetime
from tqdm import tqdm
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# =============================================================================
# Global Configuration
# =============================================================================
GLOBAL_SEED = 20251224  # fixed seed for the whole experiment
N_PROBLEMS = 200        # number of problems in the main experiment

# I (Integration Depth) levels
I_LEVELS = [5, 10, 15]
I_BASE = 10  # reference step count; NOTE: not referenced below (the code hard-codes I=10 as the base)

# λ (Corruption Rate) levels
# Every λ·I must be an integer for every I in I_LEVELS; pick_corrupted_steps raises otherwise.
LAMBDA_LEVELS = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]

# Corruption type ratio (IRR:LOC:WRONG = 1:2:2)
# NOTE: not referenced below; assign_corruption_types hard-codes the same 1:2:2 split.
CORRUPTION_RATIO = {'IRR': 1, 'LOC': 2, 'WRONG': 2}

# Step length constraints (words)
# Only STEP_MAX_WORDS is enforced (by gate_check_cot); STEP_MIN_WORDS is currently unused.
STEP_MIN_WORDS = 12
STEP_MAX_WORDS = 35

# Gate retry limit
GATE_MAX_RETRIES = 3

# API settings
API_MAX_TOKENS_TRACE = 2048
API_MAX_TOKENS_ANSWER = 256
API_RETRY_DELAY = 1.0  # base delay in seconds; call_claude sleeps API_RETRY_DELAY * attempt_number between retries
API_RATE_LIMIT_DELAY = 0.5  # seconds between API calls

# Checkpoint frequency
CHECKPOINT_EVERY = 50  # Save checkpoint every N results

print('='*60)
print('EXPERIMENT CONFIGURATION')
print('='*60)
print(f'  Version: {EXPERIMENT_VERSION}')
print(f'  GLOBAL_SEED: {GLOBAL_SEED}')
print(f'  N_PROBLEMS: {N_PROBLEMS}')
print(f'  I_LEVELS: {I_LEVELS}')
print(f'  LAMBDA_LEVELS: {LAMBDA_LEVELS}')
print(f'  Total conditions: {len(I_LEVELS) * len(LAMBDA_LEVELS)} = {len(I_LEVELS)} × {len(LAMBDA_LEVELS)}')
print(f'  Total inferences: {N_PROBLEMS * len(I_LEVELS) * len(LAMBDA_LEVELS)}')
print('='*60)

## 3. Data Structures

In [None]:
@dataclass
class GSM8KProblem:
    """One selected GSM8K test problem."""
    index: int          # row index in the GSM8K test split
    question: str       # problem statement
    answer_text: str    # full reference solution, ending in a "#### <answer>" line
    final_answer: int   # integer parsed from the "####" line by extract_final_answer


@dataclass
class CleanTrace:
    """A clean (uncorrupted) CoT trace for one problem at one I level."""
    problem_index: int  # GSM8KProblem.index of the source problem
    I: int              # number of steps in the trace
    steps: List[str]    # step texts without the "Step k:" prefix, in order
    full_text: str      # trace text containing the [[COT_START]] / [[COT_END]] markers


@dataclass
class CorruptedTrace:
    """A clean trace after apply_corruption at corruption rate lam."""
    problem_index: int                 # GSM8KProblem.index of the source problem
    I: int                             # number of steps
    lam: float                         # corruption rate λ
    A_target: float                    # target alignment, 1 - lam
    corrupted_steps: List[int]         # 1-based step numbers that were replaced
    corruption_types: Dict[int, str]   # step number -> "IRR" | "LOC" | "WRONG"
    steps: List[str]                   # step texts after corruption
    full_text: str                     # marker-delimited text rebuilt from `steps`
    seed: int                          # seed passed to apply_corruption


@dataclass
class ExperimentResult:
    """Outcome of one answer-only inference for a (problem, I, λ) cell."""
    problem_index: int             # GSM8KProblem.index
    I: int                         # integration depth of the trace shown
    lam: float                     # corruption rate λ
    A_target: float                # 1 - lam
    model_answer: Optional[int]    # parsed model answer, None if unparseable
    correct_answer: int            # GSM8K gold answer
    is_correct: bool               # model_answer == correct_answer (False when None)
    raw_output: str                # verbatim model reply
    timestamp: str                 # datetime.now().isoformat() at scoring time


def extract_final_answer(answer_text: str) -> int:
    """Extract the final integer answer from a GSM8K solution string.

    GSM8K solutions end with a line of the form ``#### <answer>``. The answer
    may contain thousands separators (``1,000``) and a leading minus sign;
    both are accepted. Any decimal part after ``.`` is ignored (parsing stops
    at the first non-digit, non-comma character).

    Args:
        answer_text: The raw ``answer`` field of a GSM8K item.

    Returns:
        The final answer as an ``int``.

    Raises:
        ValueError: If no ``####`` marker followed by an integer is found.
    """
    # `[\d,]*\d` requires the match to end in a digit, so a trailing comma is excluded.
    match = re.search(r'####\s*(-?[\d,]*\d)', answer_text)
    if match:
        return int(match.group(1).replace(',', ''))
    raise ValueError(f'Could not extract final answer from: {answer_text}')

## 4. Utility Functions

In [None]:
def derive_seed(global_seed: int, problem_id: int, I: int, lam: float, replicate_id: int = 0) -> int:
    """Derive a reproducible 32-bit seed for one (problem, I, λ, replicate) cell.

    The arguments are joined into a ``|``-separated key string. The key is
    hashed with SHA-256, and the first 8 hex digits (32 bits) of the digest
    are returned as an int.
    """
    fields = (
        f"{global_seed}",
        f"{problem_id}",
        f"I={I}",
        f"lam={lam}",
        f"rep={replicate_id}",
    )
    digest = hashlib.sha256("|".join(fields).encode("utf-8")).hexdigest()
    return int(digest[:8], base=16)


def save_json(data: Any, filepath: str):
    """Serialise ``data`` as pretty-printed UTF-8 JSON at ``filepath``.

    The JSON is first written to a sibling ``<filepath>.tmp`` file and then
    moved over ``filepath`` with ``os.replace``. An interruption mid-write
    (kernel restart, KeyboardInterrupt, serialisation error) therefore leaves
    the previous file intact, instead of a truncated one that ``load_json``
    cannot parse. This matters for the checkpoint that is rewritten repeatedly
    during the run.

    Raises:
        TypeError: If ``data`` is not JSON-serialisable (the target file is
            left untouched and the temporary file is removed).
        OSError: On filesystem errors.
    """
    tmp_path = f'{filepath}.tmp'
    try:
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False, indent=2)
        os.replace(tmp_path, filepath)
    except BaseException:
        # Don't leave a half-written temp file behind; re-raise the original error.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise


def load_json(filepath: str) -> Any:
    """Read and return the JSON document stored as UTF-8 at ``filepath``."""
    with open(filepath, mode='r', encoding='utf-8') as handle:
        payload = json.load(handle)
    return payload


def save_checkpoint(results: List[ExperimentResult], checkpoint_name: str):
    """Persist ``results`` to ``<SAVE_DIR_V3>/checkpoints/<checkpoint_name>.json``.

    Each result dataclass is converted with ``asdict`` before saving. A
    one-line confirmation with the path and the result count is printed.
    """
    target = f'{SAVE_DIR_V3}/checkpoints/{checkpoint_name}.json'
    records = list(map(asdict, results))
    save_json(records, target)
    print(f'Checkpoint saved: {target} ({len(results)} results)')


def load_checkpoint(checkpoint_name: str) -> List[Dict]:
    """Return the records of a saved checkpoint, or ``[]`` if none exists.

    Records are the plain dicts written by ``save_checkpoint``.
    """
    source = f'{SAVE_DIR_V3}/checkpoints/{checkpoint_name}.json'
    if not os.path.exists(source):
        return []
    return load_json(source)

## 5. Load GSM8K Dataset

In [None]:
from datasets import load_dataset

# GSM8K, "main" configuration, test split (Hugging Face `datasets`).
dataset = load_dataset('gsm8k', 'main', split='test')
print(f'GSM8K test set loaded: {len(dataset)} problems')

## 6. Problem Selection

In [None]:
def select_problems(dataset, n_problems: int, seed: int) -> List[int]:
    """Deterministically sample ``n_problems`` dataset indices.

    All indices ``0 .. len(dataset)-1`` are shuffled with a ``random.Random``
    seeded by ``seed``. The first ``n_problems`` are kept and returned in
    ascending order (all of them if the dataset is smaller).
    """
    order = list(range(len(dataset)))
    random.Random(seed).shuffle(order)
    chosen = order[:n_problems]
    chosen.sort()
    return chosen


# Select N_PROBLEMS problems deterministically from the GSM8K test split
selected_indices = select_problems(dataset, N_PROBLEMS, GLOBAL_SEED)
print(f'Selected {len(selected_indices)} problems')
print(f'First 10 indices: {selected_indices[:10]}')
print(f'Last 10 indices: {selected_indices[-10:]}')

# Build GSM8KProblem objects; items whose "####" answer cannot be parsed are
# skipped with a warning, so len(problems) may be smaller than N_PROBLEMS.
problems = []
for idx in selected_indices:
    item = dataset[idx]
    try:
        final_ans = extract_final_answer(item['answer'])
        prob = GSM8KProblem(
            index=idx,
            question=item['question'],
            answer_text=item['answer'],
            final_answer=final_ans
        )
        problems.append(prob)
    except ValueError as e:
        print(f'Warning: Could not extract answer for index {idx}: {e}')

print(f'Successfully prepared {len(problems)} problems')

# Save
save_json([asdict(p) for p in problems], f'{SAVE_DIR_V3}/problems_v3.json')
print(f'Saved to: {SAVE_DIR_V3}/problems_v3.json')

## 7. Corruption Logic

In [None]:
def pick_corrupted_steps(I: int, lam: float, seed: int) -> List[int]:
    """Choose which 1-based step numbers to corrupt for a given I and λ.

    Exactly K = λ·I steps are corrupted. K must be an integer (within 1e-9),
    otherwise ValueError is raised. Steps 1..I are shuffled with
    ``random.Random(seed)`` and the first K are returned in ascending order.
    """
    k_exact = lam * I
    k = int(round(k_exact))
    if abs(k - k_exact) > 1e-9:
        raise ValueError(f"lam*I must be integer: lam={lam}, I={I}, lam*I={k_exact}")
    if not k:
        return []

    pool = list(range(1, I + 1))
    random.Random(seed).shuffle(pool)
    return sorted(pool[:k])


def assign_corruption_types(corrupted_steps: List[int], seed: int) -> Dict[int, str]:
    """Map each corrupted step number to a corruption type: IRR, LOC or WRONG.

    The target ratio is IRR:LOC:WRONG = 1:2:2. Counts are floored
    (IRR = ⌊K/5⌋, LOC = ⌊2K/5⌋) and the remainder goes to WRONG, so small K is
    skewed towards WRONG (e.g. K=2 gives two WRONG steps). Which step receives
    which type is decided by shuffling a copy of the steps with
    ``random.Random(seed + 1)``.
    """
    total = len(corrupted_steps)
    if total == 0:
        return {}

    n_irr = total // 5
    n_loc = (2 * total) // 5
    # WRONG gets total - ⌊K/5⌋ - ⌊2K/5⌋ >= 2K/5 > 0 steps, so it is never empty
    # and no rebalancing is needed.
    n_wrong = total - n_irr - n_loc

    shuffled = list(corrupted_steps)
    random.Random(seed + 1).shuffle(shuffled)

    labels = ["IRR"] * n_irr + ["LOC"] * n_loc + ["WRONG"] * n_wrong
    return dict(zip(shuffled, labels))

## 8. Corruption Templates

In [None]:
# Templates for IRR (irrelevant) corrupted steps. generate_irrelevant_step picks the
# arithmetic from the template text ('+' -> a + b, '*' -> a * 2, neither -> a - b),
# so any new template must follow that convention.
IRRELEVANT_TEMPLATES = [
    "Compute an auxiliary value: aux = {a} + {b} = {result}, but it will not be used later.",
    "Compute a side quantity: aux = {a} * 2 = {result}, unrelated to the final result.",
    "Note that we can also compute aux = {a} - {b} = {result}, though this is not needed.",
]

# Templates for WRONG (wrong-constraint) steps: assert an arbitrary value as a fixed fact.
WRONG_CONSTRAINT_TEMPLATES = [
    "Fix an intermediate condition: set {var} = {wrong_value} as a given constraint for the rest of the steps.",
    "Assume the total is {var} = {wrong_value} and proceed using this fixed value.",
]


def generate_irrelevant_step(step_num: int, seed: int) -> str:
    """Build an IRR corruption step: a correct side computation that is never used.

    Two operands in [2, 20] are drawn from ``random.Random(seed)``, then a
    template. The arithmetic follows the template text: '+' gives a + b,
    '*' gives a * 2, anything else gives a - b. ``step_num`` is not used.
    """
    rng = random.Random(seed)
    a = rng.randint(2, 20)
    b = rng.randint(2, 20)
    template = rng.choice(IRRELEVANT_TEMPLATES)
    result = a + b if '+' in template else a * 2 if '*' in template else a - b
    return template.format(a=a, b=b, result=result)


def generate_local_error_step(original_step: str, seed: int) -> str:
    """Build a LOC corruption step: the original step with one number perturbed.

    The last integer literal in ``original_step`` is replaced by that value
    plus a random offset in {±1, ±2, ±3}, clamped at 0. If the clamped value
    equals the original (only possible for 0 with a negative offset), the
    offset is applied upwards instead. If the step contains no digits, a fixed
    arithmetic step with a wrong product (never the true 30) is returned.

    The last number is replaced directly rather than only a trailing
    ``= <n>.``. Steps such as ``Let x = 5 (given).`` (the pattern the trace
    prompt asks for) or ``... = 12 apples.`` are therefore actually corrupted
    instead of being returned unchanged. For steps ending in ``<n>.`` the
    output matches the previous implementation for the same seed.

    Args:
        original_step: Step text without the ``Step k:`` prefix.
        seed: Seed for ``random.Random``; identical inputs give identical output.

    Returns:
        The corrupted step text, always different from ``original_step``.
    """
    rng = random.Random(seed)
    matches = list(re.finditer(r'\d+', original_step))
    if not matches:
        # Any value except the true product 10 * 3 = 30.
        wrong_product = rng.choice([28, 29, 31, 32])
        return f"Compute t = 10 * 3 = {wrong_product} (using the previous values)."

    target = matches[-1]
    original_result = int(target.group())
    offset = rng.choice([-3, -2, -1, 1, 2, 3])
    wrong_result = max(0, original_result + offset)
    if wrong_result == original_result:
        wrong_result = original_result + abs(offset)

    start, end = target.span()
    return f"{original_step[:start]}{wrong_result}{original_step[end:]}"


def generate_wrong_constraint_step(step_num: int, seed: int) -> str:
    """Build a WRONG corruption step that asserts an arbitrary fixed value.

    The variable name, a value in [10, 100] and the template are drawn, in
    that order, from ``random.Random(seed)``. ``step_num`` is not used.
    """
    rng = random.Random(seed)
    name = rng.choice(['x', 'total', 'result', 'n'])
    value = rng.randint(10, 100)
    chosen = rng.choice(WRONG_CONSTRAINT_TEMPLATES)
    return chosen.format(var=name, wrong_value=value)

## 9. Apply Corruption

In [None]:
def apply_corruption(clean_trace: CleanTrace, lam: float, seed: int) -> CorruptedTrace:
    """Return a corrupted copy of ``clean_trace`` at corruption rate ``lam``.

    ``pick_corrupted_steps`` selects lam·I step numbers and
    ``assign_corruption_types`` labels each one IRR / LOC / WRONG. Each
    selected step is regenerated with its own seed (``seed + step_num * 1000``);
    all other steps are copied unchanged. ``full_text`` is rebuilt from the
    resulting steps between ``[[COT_START]]`` / ``[[COT_END]]`` markers.
    """
    n_steps = clean_trace.I
    chosen = pick_corrupted_steps(n_steps, lam, seed)
    type_of = assign_corruption_types(chosen, seed)

    def _corrupt(step_num: int, text: str) -> str:
        # Unselected steps pass through unchanged.
        if step_num not in type_of:
            return text
        kind = type_of[step_num]
        sub_seed = seed + step_num * 1000
        if kind == 'IRR':
            return generate_irrelevant_step(step_num, sub_seed)
        if kind == 'LOC':
            return generate_local_error_step(text, sub_seed)
        return generate_wrong_constraint_step(step_num, sub_seed)

    steps_out = [_corrupt(k, text) for k, text in enumerate(clean_trace.steps, start=1)]
    numbered = [f'Step {k}: {text}' for k, text in enumerate(steps_out, start=1)]
    full_text = '\n'.join(['[[COT_START]]', *numbered, '[[COT_END]]'])

    return CorruptedTrace(
        problem_index=clean_trace.problem_index,
        I=n_steps,
        lam=lam,
        A_target=1.0 - lam,
        corrupted_steps=chosen,
        corruption_types=type_of,
        steps=steps_out,
        full_text=full_text,
        seed=seed,
    )

## 10. Gate System

In [None]:
@dataclass
class GateResult:
    """Outcome of gate_check_cot for one generated trace."""
    passed: bool            # True when no errors were recorded
    errors: List[str]       # human-readable reasons the trace was rejected
    step_count: int         # number of "Step k:" lines found (0 if a marker is missing)
    word_counts: List[int]  # whitespace-split word count of each step, in order


def gate_check_cot(cot_text: str, expected_I: int) -> GateResult:
    """Validate the format of a generated CoT trace.

    Checks, in order:
      * both ``[[COT_START]]`` and ``[[COT_END]]`` are present (if not, return
        immediately with step_count 0);
      * exactly ``expected_I`` ``Step k:`` lines between the markers;
      * numbering is exactly 1..expected_I;
      * no step exceeds ``STEP_MAX_WORDS`` words (``STEP_MIN_WORDS`` is not checked);
      * the last step mentions "final" (case-insensitive).
    """
    marker_errors = [
        f'Missing {marker} marker'
        for marker in ('[[COT_START]]', '[[COT_END]]')
        if marker not in cot_text
    ]
    if marker_errors:
        return GateResult(passed=False, errors=marker_errors, step_count=0, word_counts=[])

    head = '[[COT_START]]'
    body = cot_text[cot_text.find(head) + len(head):cot_text.find('[[COT_END]]')].strip()
    found = re.findall(r'^Step\s+(\d+):\s*(.+)$', body, flags=re.MULTILINE)
    numbers = [int(num) for num, _ in found]
    contents = [text for _, text in found]

    errors: List[str] = []
    if len(found) != expected_I:
        errors.append(f'Step count mismatch: expected {expected_I}, got {len(found)}')
    if numbers != list(range(1, expected_I + 1)):
        errors.append('Step numbering error')

    word_counts = [len(text.split()) for text in contents]
    errors.extend(
        f'Step {pos} too long: {count} words'
        for pos, count in enumerate(word_counts, start=1)
        if count > STEP_MAX_WORDS
    )

    # 'Final' in s already implies 'final' in s.lower(), so one case-insensitive test suffices.
    if contents and 'final' not in contents[-1].lower():
        errors.append('Last step must contain "Final"')

    return GateResult(passed=not errors, errors=errors, step_count=len(found), word_counts=word_counts)

## 11. API Setup

In [None]:
from getpass import getpass

# Read the API key interactively without echoing it (getpass).
ANTHROPIC_API_KEY = getpass('Enter Anthropic API Key: ')
print('API Key set (hidden for security).')

In [None]:
import anthropic

# Module-level client used by call_claude.
client = anthropic.Anthropic(api_key=ANTHROPIC_API_KEY)


def call_claude(system_prompt: str, user_prompt: str, max_tokens: int = 1024, retries: int = 3,
                model: str = "claude-sonnet-4-20250514") -> str:
    """Send one single-turn request to the Anthropic Messages API and return its text.

    Uses the module-level ``client`` at temperature 0. After a successful call
    it sleeps ``API_RATE_LIMIT_DELAY`` seconds as a crude rate limiter. Any
    exception, including an empty ``content`` list, is retried up to
    ``retries`` attempts in total, sleeping ``API_RETRY_DELAY * attempt``
    seconds between attempts (linear backoff). The last error is re-raised.

    Args:
        system_prompt: System prompt.
        user_prompt: Content of the single user message.
        max_tokens: Generation cap.
        retries: Total number of attempts; must be >= 1.
        model: Model identifier (defaults to the model used by this experiment).

    Returns:
        The text of the first content block of the response.

    Raises:
        ValueError: If ``retries`` < 1 (rather than silently returning None).
        Exception: Whatever the final failed attempt raised.
    """
    if retries < 1:
        raise ValueError(f'retries must be >= 1, got {retries}')

    for attempt in range(1, retries + 1):
        try:
            message = client.messages.create(
                model=model,
                max_tokens=max_tokens,
                messages=[{"role": "user", "content": user_prompt}],
                system=system_prompt,
                temperature=0
            )
            time.sleep(API_RATE_LIMIT_DELAY)
            return message.content[0].text
        except Exception as e:
            print(f'API error (attempt {attempt}): {e}')
            if attempt == retries:
                raise
            time.sleep(API_RETRY_DELAY * attempt)


# Connection test: one short call to confirm the key and model respond.
test_response = call_claude(
    "You output ONLY JSON.",
    'Respond with exactly: {"test": "ok"}',
    max_tokens=50
)
print(f'API test: {test_response}')

## 12. Clean Trace Generation

In [None]:
# System prompt for clean-trace generation. It is filled with str.format(I=...)
# in create_clean_trace_prompt, so any literal brace added here must be doubled.
# The 12-25 word guidance in rule 3 is advisory: gate_check_cot enforces only the
# STEP_MAX_WORDS maximum.
CLEAN_TRACE_SYSTEM_PROMPT = """You are a math problem solver that generates step-by-step reasoning traces.

CRITICAL RULES:
1. Generate EXACTLY {I} steps, numbered Step 1 through Step {I}.
2. Each step must be on ONE line, starting with "Step k:" and ending with a period.
3. Each step should be 12-25 words (max 35 words).
4. Use variables: a, b, c, n, k, t, rate, cost, total, left, each, days, t1, t2, etc.
5. Step patterns:
   - Given: "Step k: Let x = <number> (given)."
   - Compute: "Step k: Compute x = <expr> = <value>."
   - Final: "Step {I}: Final = <expression>." (DO NOT evaluate to a single number)
6. The final step MUST contain "Final =" with an unevaluated expression.
7. Do NOT write "The answer is" or "Therefore".
8. All steps must contribute to the final answer (no irrelevant steps).
9. Wrap output in [[COT_START]] and [[COT_END]] markers.

OUTPUT FORMAT:
[[COT_START]]
Step 1: ...
Step 2: ...
...
Step {I}: Final = <expression>.
[[COT_END]]
"""


def create_clean_trace_prompt(problem: GSM8KProblem, I: int) -> Tuple[str, str]:
    """Build the (system, user) prompt pair asking for an I-step clean trace.

    The gold answer is included in the user prompt so the model writes a trace
    that reaches it, with step I an unevaluated ``Final = <expression>.``.
    """
    answer = problem.final_answer
    user_lines = [
        f"Problem: {problem.question}",
        "",
        f"Correct final answer (for your reference, do not include in output): {answer}",
        "",
        f"Generate a {I}-step reasoning trace that leads to this answer.",
        f'Remember: Step {I} must be "Final = <expression>." where <expression> evaluates to {answer}.',
    ]
    return CLEAN_TRACE_SYSTEM_PROMPT.format(I=I), "\n".join(user_lines)


def generate_clean_trace(problem: GSM8KProblem, I: int) -> Optional[CleanTrace]:
    """Generate a clean I-step trace for ``problem`` with Claude, validated by gate_check_cot.

    Up to ``GATE_MAX_RETRIES`` attempts are made. Retries append a reminder to
    the user prompt, since call_claude runs at temperature 0 and an unchanged
    prompt would likely reproduce the same output. Returns None if no attempt
    passes the gate. API errors from call_claude propagate.

    ``full_text`` is the canonical ``[[COT_START]] / Step k: ... / [[COT_END]]``
    block rebuilt from the parsed steps. This is the same layout produced by
    the I=5 / I=15 converters and by apply_corruption. The raw model response
    may contain text outside the markers; using it verbatim would make the
    I=10, λ=0 prompt differ in format from every other condition.
    """
    sys_prompt, base_user_prompt = create_clean_trace_prompt(problem, I)

    for attempt in range(GATE_MAX_RETRIES):
        usr_prompt = base_user_prompt
        if attempt > 0:
            usr_prompt += f"\n\n(Attempt {attempt + 1}: Please ensure EXACTLY {I} steps.)"

        response = call_claude(sys_prompt, usr_prompt, max_tokens=API_MAX_TOKENS_TRACE)
        if not gate_check_cot(response, expected_I=I).passed:
            continue

        start_idx = response.find('[[COT_START]]') + len('[[COT_START]]')
        end_idx = response.find('[[COT_END]]')
        cot_content = response[start_idx:end_idx].strip()
        steps = re.findall(r'^Step\s+\d+:\s*(.+)$', cot_content, flags=re.MULTILINE)

        lines = ['[[COT_START]]']
        lines.extend(f'Step {k}: {text}' for k, text in enumerate(steps, start=1))
        lines.append('[[COT_END]]')

        return CleanTrace(
            problem_index=problem.index,
            I=I,
            steps=steps,
            full_text='\n'.join(lines)
        )

    return None

## 13. Clean Trace Compression/Expansion (I=5, I=15)

In [None]:
def compress_trace_I10_to_I5(trace_I10: CleanTrace) -> CleanTrace:
    """Compress a 10-step clean trace into 5 steps by merging consecutive pairs.

    New step k is ``<step 2k-1 with trailing periods stripped>; <step 2k>``
    (Step1' = Step1 + Step2, Step2' = Step3 + Step4, ...).

    Raises:
        ValueError: If the input does not have exactly 10 steps.
    """
    source = trace_I10.steps
    if len(source) != 10:
        raise ValueError(f"Expected 10 steps, got {len(source)}")

    # Pair (0,1), (2,3), ... and join each pair with a semicolon.
    merged = [f"{left.rstrip('.')}; {right}" for left, right in zip(source[0::2], source[1::2])]
    numbered = '\n'.join(f'Step {k}: {text}' for k, text in enumerate(merged, start=1))

    return CleanTrace(
        problem_index=trace_I10.problem_index,
        I=5,
        steps=merged,
        full_text=f'[[COT_START]]\n{numbered}\n[[COT_END]]'
    )


def _fresh_filler_prefix(steps: List[str]) -> str:
    """Return a variable prefix p such that no token ``p<digits>`` occurs in ``steps``."""
    text = "\n".join(steps)
    for prefix in ("t", "u", "w"):
        if not re.search(rf"\b{prefix}\d+\b", text):
            return prefix
    prefix = "tmp"
    while re.search(rf"\b{prefix}\d+\b", text):
        prefix += "_"
    return prefix


def expand_trace_I10_to_I15(trace_I10: CleanTrace, seed: int) -> CleanTrace:
    """Expand a 10-step clean trace into 15 steps (simplified spec-pack variant).

    Candidate steps are those containing two or more of ``+ - * /``. If fewer
    than five are found, steps 4-9 are added as padding. The candidates are
    shuffled with ``random.Random(seed)`` and the first five are chosen.
    Before each chosen step, a filler step
    ``Let <name> = intermediate value from previous computation.`` is
    inserted, and the chosen step itself is kept verbatim. NOTE: this does not
    actually decompose the arithmetic; the filler steps carry no content.

    Filler names are numbered in reading order (first filler is ``<prefix>1``).
    They use a prefix that does not already appear as ``<prefix><digits>`` in
    the trace, so they cannot shadow variables such as ``t1``/``t2``, which
    the generation prompt allows the model to use.

    Raises:
        ValueError: If ``trace_I10`` does not contain exactly 10 steps.
    """
    steps_10 = list(trace_I10.steps)
    if len(steps_10) != 10:
        raise ValueError(f"Expected 10 steps, got {len(steps_10)}")

    # Split candidates: steps with two or more arithmetic operators.
    candidates = [i for i, step in enumerate(steps_10) if len(re.findall(r'[+\-*/]', step)) >= 2]

    # Fewer than 5 candidates: pad from Step 4-9 (0-indexed 3-8).
    if len(candidates) < 5:
        for i in range(3, 9):
            if i not in candidates:
                candidates.append(i)
            if len(candidates) >= 5:
                break

    # Choose 5 deterministically from the seed.
    rng = random.Random(seed)
    rng.shuffle(candidates)
    split_indices = set(candidates[:5])

    prefix = _fresh_filler_prefix(steps_10)
    steps_15 = []
    label = 1
    for i, step in enumerate(steps_10):
        if i in split_indices:
            steps_15.append(f"Let {prefix}{label} = intermediate value from previous computation.")
            label += 1
        steps_15.append(step)

    lines = ['[[COT_START]]']
    lines.extend(f'Step {k}: {content}' for k, content in enumerate(steps_15, start=1))
    lines.append('[[COT_END]]')

    return CleanTrace(
        problem_index=trace_I10.problem_index,
        I=15,
        steps=steps_15,
        full_text='\n'.join(lines)
    )

## 14. Generate All Clean Traces

In [None]:
# Generate the I=10 clean traces (the base level; I=5 and I=15 are derived from them).
# If a trace file already exists in SAVE_DIR_V3 (e.g. after a kernel restart), it is
# reloaded instead of regenerated. A resumed run then keeps using exactly the traces
# its checkpointed results were produced from. Set FORCE_REGENERATE_I10 = True to
# regenerate anyway.
FORCE_REGENERATE_I10 = False
I10_TRACE_PATH = f'{SAVE_DIR_V3}/clean_traces/clean_traces_I10_v3.json'

failed_problems = []

if os.path.exists(I10_TRACE_PATH) and not FORCE_REGENERATE_I10:
    clean_traces_I10 = [CleanTrace(**t) for t in load_json(I10_TRACE_PATH)]
    print(f'Loaded {len(clean_traces_I10)} cached I=10 traces from {I10_TRACE_PATH}')
else:
    print('Generating I=10 clean traces...')
    clean_traces_I10 = []

    for prob in tqdm(problems, desc='I=10 traces'):
        trace = generate_clean_trace(prob, I=10)
        if trace:
            clean_traces_I10.append(trace)
        else:
            failed_problems.append(prob.index)
            print(f'\n  Failed: problem {prob.index}')

    print(f'\nGenerated {len(clean_traces_I10)} I=10 traces')
    if failed_problems:
        print(f'Failed problems: {failed_problems}')

    # Save
    save_json([asdict(t) for t in clean_traces_I10], I10_TRACE_PATH)

In [None]:
# Derive the I=5 and I=15 traces from the I=10 traces (no API calls).
# A conversion failure drops that problem from the affected level only.
print('Generating I=5 and I=15 traces from I=10...')

clean_traces_I5 = []
clean_traces_I15 = []

for trace_10 in tqdm(clean_traces_I10, desc='Converting'):
    # I=5 (compression)
    try:
        trace_5 = compress_trace_I10_to_I5(trace_10)
        clean_traces_I5.append(trace_5)
    except Exception as e:
        print(f'  I=5 conversion failed for problem {trace_10.problem_index}: {e}')
    
    # I=15 (expansion; the split choice is seeded per problem)
    try:
        seed = derive_seed(GLOBAL_SEED, trace_10.problem_index, I=15, lam=0.0)
        trace_15 = expand_trace_I10_to_I15(trace_10, seed)
        clean_traces_I15.append(trace_15)
    except Exception as e:
        print(f'  I=15 conversion failed for problem {trace_10.problem_index}: {e}')

print(f'Generated {len(clean_traces_I5)} I=5 traces')
print(f'Generated {len(clean_traces_I15)} I=15 traces')

# Save
save_json([asdict(t) for t in clean_traces_I5], f'{SAVE_DIR_V3}/clean_traces/clean_traces_I5_v3.json')
save_json([asdict(t) for t in clean_traces_I15], f'{SAVE_DIR_V3}/clean_traces/clean_traces_I15_v3.json')

## 15. Experiment Prompts

In [None]:
# System prompt for the answer-only inference. It is returned verbatim by
# create_experiment_prompt (not passed through str.format), so the literal braces are safe.
EXPERIMENT_SYSTEM_PROMPT = """You are a calculator that outputs ONLY JSON.

CRITICAL RULES:
1. Your output MUST start with the character '{'
2. Your output MUST be exactly: {"final": <number>}
3. Replace <number> with an integer (the numerical answer)
4. Do NOT write ANY explanation, reasoning, or text before or after the JSON
5. Do NOT write "I need to" or "Let me" or any other words
6. ONLY output the JSON object, nothing else

CORRECT OUTPUT EXAMPLE:
{"final": 42}
"""


def create_experiment_prompt(problem: GSM8KProblem, cot_text: str) -> Tuple[str, str]:
    """Build the (system, user) prompt pair for one answer-only inference.

    The user message gives the problem, then the (possibly corrupted) trace to
    be treated as given facts, and then demands a bare ``{"final": <number>}``.
    """
    parts = [
        f"Problem: {problem.question}",
        "",
        "Reasoning trace (use these steps as given facts):",
        cot_text,
        "",
        "Based on the trace above, compute the final numerical answer.",
        'OUTPUT ONLY: {"final": <number>}',
        "START YOUR RESPONSE WITH '{'",
    ]
    return EXPERIMENT_SYSTEM_PROMPT, "\n".join(parts)


def parse_model_answer(response: str) -> Optional[int]:
    """Extract the model's integer answer from its raw output, or None.

    Tried in order; the first hit wins:
      1. a complete ``{"final": N}`` object;
      2. a complete ``{'final': N}`` / ``{"final": N}`` object (either quote style);
      3. a bare ``"final": N`` / ``'final': N`` pair anywhere;
      4. the last free-standing number in the text.
    N may be negative or decimal (decimals are rounded with Python's
    round-half-to-even). It may also use thousands separators (``1,000``), be
    quoted (``"42"``), or carry a ``$`` prefix. Previously ``{"final": 1,000}``
    parsed as 1.
    """
    number = r'-?\d{1,3}(?:,\d{3})+(?:\.\d+)?|-?\d+(?:\.\d+)?'
    value = rf'["\']?\$?\s*({number})["\']?'

    def to_int(text: str) -> int:
        return int(round(float(text.replace(',', ''))))

    for pattern in (
        rf'\{{\s*"final"\s*:\s*{value}\s*\}}',
        rf'\{{\s*["\']final["\']\s*:\s*{value}\s*\}}',
        rf'["\']final["\']\s*:\s*{value}',
    ):
        match = re.search(pattern, response)
        if match:
            return to_int(match.group(1))

    # Fallback: last number preceded by start-of-text/whitespace (optionally '$') and
    # followed by whitespace, end, '.' or ','. The trailing check is a lookahead so
    # adjacent numbers ("5 6") are all found.
    candidates = re.findall(rf'(?:^|\s)\$?({number})(?=\s|$|[.,])', response)
    if candidates:
        return to_int(candidates[-1])

    return None

## 16. Run Full Experiment

In [None]:
def run_experiment(
    problem: GSM8KProblem,
    cot_text: str,
    I: int,
    lam: float
) -> ExperimentResult:
    """Run one answer-only inference for a (problem, I, λ) cell and score it.

    The reply is parsed with parse_model_answer; an unparseable reply counts as
    incorrect. The raw reply and an ISO timestamp are kept for auditing.
    """
    system_prompt, user_prompt = create_experiment_prompt(problem, cot_text)
    raw = call_claude(system_prompt, user_prompt, max_tokens=API_MAX_TOKENS_ANSWER)
    answer = parse_model_answer(raw)
    gold = problem.final_answer

    return ExperimentResult(
        problem_index=problem.index,
        I=I,
        lam=lam,
        A_target=1.0 - lam,
        model_answer=answer,
        correct_answer=gold,
        is_correct=answer is not None and answer == gold,
        raw_output=raw,
        timestamp=datetime.now().isoformat(),
    )

In [None]:
# Index all clean traces by I, then by problem index
all_traces = {
    5: {t.problem_index: t for t in clean_traces_I5},
    10: {t.problem_index: t for t in clean_traces_I10},
    15: {t.problem_index: t for t in clean_traces_I15}
}

# Map problem index -> GSM8KProblem
prob_map = {p.index: p for p in problems}

print(f'Traces available:')
for I in I_LEVELS:
    print(f'  I={I}: {len(all_traces[I])} traces')

In [None]:
# Run the full experiment (resumable from the 'latest_v3' checkpoint).
print('='*60)
print('STARTING FULL EXPERIMENT')
print('='*60)

all_results = []
total_conditions = len(I_LEVELS) * len(LAMBDA_LEVELS) * len(problems)
completed = 0

# Load an existing checkpoint; (problem, I, λ) cells already present are skipped.
existing_checkpoint = load_checkpoint('latest_v3')
if existing_checkpoint:
    all_results = [ExperimentResult(**r) for r in existing_checkpoint]
    completed = len(all_results)
    print(f'Resumed from checkpoint: {completed} results')
completed_keys = {(r.problem_index, r.I, r.lam) for r in all_results}

try:
    for I in I_LEVELS:
        print(f'\n--- I = {I} ---')

        for lam in LAMBDA_LEVELS:
            print(f'  λ = {lam} (A = {1 - lam:.1f})')

            for prob in tqdm(problems, desc=f'I={I}, λ={lam}', leave=False):
                # Skip cells that are already done
                if (prob.index, I, lam) in completed_keys:
                    continue

                # Problems whose trace failed to generate/convert at this I are skipped
                clean_trace = all_traces[I].get(prob.index)
                if clean_trace is None:
                    continue

                # Always go through apply_corruption. At λ=0 it corrupts nothing but
                # produces the same canonical marker layout as every λ>0 condition
                # (the stored I=10 full_text may be a raw model response).
                seed = derive_seed(GLOBAL_SEED, prob.index, I, lam)
                cot_text = apply_corruption(clean_trace, lam, seed).full_text

                result = run_experiment(prob, cot_text, I, lam)
                all_results.append(result)
                completed += 1

                # Periodic checkpoint
                if completed % CHECKPOINT_EVERY == 0:
                    save_checkpoint(all_results, 'latest_v3')

            # Also save after each λ
            save_checkpoint(all_results, 'latest_v3')

            # Interim accuracy for this (I, λ) cell
            lam_results = [r for r in all_results if r.I == I and r.lam == lam]
            n_correct = sum(r.is_correct for r in lam_results)
            acc = n_correct / len(lam_results) if lam_results else 0
            print(f'    Accuracy: {acc:.1%} ({n_correct}/{len(lam_results)})')

except KeyboardInterrupt:
    print('\nInterrupted! Saving checkpoint...')
    save_checkpoint(all_results, 'latest_v3')
except Exception:
    # An API failure that exhausted call_claude's retries (or any other error) would
    # otherwise discard up to CHECKPOINT_EVERY - 1 unsaved results. Save, then re-raise.
    print('\nError! Saving checkpoint before re-raising...')
    save_checkpoint(all_results, 'latest_v3')
    raise

print(f'\nCompleted: {len(all_results)} / {total_conditions}')

In [None]:
# Save the final results (all cells, including any loaded from the checkpoint)
save_json([asdict(r) for r in all_results], f'{SAVE_DIR_V3}/results/results_full_v3.json')
print(f'Results saved to: {SAVE_DIR_V3}/results/results_full_v3.json')

## 17. Results Analysis

In [None]:
# Convert results to a DataFrame
df = pd.DataFrame([asdict(r) for r in all_results])
# A = 1 - λ, rounded: plain float subtraction gives e.g. 1 - 0.8 = 0.19999999999999996,
# which would otherwise appear as-is in groupby keys, heatmap tick labels and
# the printed tables.
df['A'] = (1 - df['lam']).round(6)

print('Results DataFrame:')
print(df.head())
print(f'\nTotal results: {len(df)}')

In [None]:
# Accuracy per (I, λ) cell: rows are I, columns are λ
accuracy_table = df.groupby(['I', 'lam'])['is_correct'].mean().unstack()
print('Accuracy Table (I × λ):')
print(accuracy_table.round(3))

In [None]:
# Heatmap of accuracy
fig, ax = plt.subplots(figsize=(10, 6))

# Plot against A (alignment) rather than λ
accuracy_by_A = df.groupby(['I', 'A'])['is_correct'].mean().unstack()
accuracy_by_A = accuracy_by_A[sorted(accuracy_by_A.columns, reverse=True)]  # A in descending order (left = fully aligned)

# Colour scale fixed to [0, 1] so figures from different runs are comparable
sns.heatmap(
    accuracy_by_A,
    annot=True,
    fmt='.2f',
    cmap='RdYlGn',
    vmin=0,
    vmax=1,
    ax=ax
)
ax.set_title('CoT Accuracy: I (Integration Depth) × A (Alignment)', fontsize=14)
ax.set_xlabel('A (Alignment = 1 - λ)', fontsize=12)
ax.set_ylabel('I (Integration Depth)', fontsize=12)

plt.tight_layout()
plt.savefig(f'{SAVE_DIR_V3}/figures/heatmap_I_x_A_v3.png', dpi=150)
plt.show()
print(f'Saved: {SAVE_DIR_V3}/figures/heatmap_I_x_A_v3.png')

In [None]:
# Accuracy vs A, one curve per I
fig, ax = plt.subplots(figsize=(10, 6))

for I in I_LEVELS:
    I_data = df[df['I'] == I].groupby('A')['is_correct'].mean()
    ax.plot(I_data.index, I_data.values, 'o-', label=f'I={I}', linewidth=2, markersize=8)

ax.set_xlabel('A (Alignment = 1 - λ)', fontsize=12)
ax.set_ylabel('Accuracy', fontsize=12)
ax.set_title('Accuracy vs Alignment by Integration Depth', fontsize=14)
ax.legend()
ax.grid(True, alpha=0.3)
ax.set_ylim(0, 1)

plt.tight_layout()
plt.savefig(f'{SAVE_DIR_V3}/figures/accuracy_curves_v3.png', dpi=150)
plt.show()
print(f'Saved: {SAVE_DIR_V3}/figures/accuracy_curves_v3.png')

In [None]:
# Collapse detection: sign of ∂Acc/∂I, approximated by Acc(I=15) - Acc(I=5) at each A.
# NOTE(review): A_data.get(I, 0) treats a missing I level as accuracy 0, which would
# fake a "collapse" if, e.g., no I=15 results exist for some A.
print('\n=== COLLAPSE ANALYSIS ===')
print('Checking if increasing I hurts at low A...\n')

for A in sorted(df['A'].unique(), reverse=True):
    A_data = df[df['A'] == A].groupby('I')['is_correct'].mean()
    
    if len(A_data) >= 2:
        # Difference: I=15 minus I=5; ±0.05 band is reported as flat
        diff = A_data.get(15, 0) - A_data.get(5, 0)
        direction = '↑' if diff > 0.05 else ('↓' if diff < -0.05 else '→')
        
        print(f'A={A:.1f}: I=5→{A_data.get(5, 0):.2f}, I=10→{A_data.get(10, 0):.2f}, I=15→{A_data.get(15, 0):.2f}  {direction}')
        
        # Flag a drop of more than 10 percentage points
        if diff < -0.1:
            print(f'        ⚠️  COLLAPSE DETECTED: More steps hurt performance!')

## 18. Statistical Analysis (Optional)

In [None]:
# Test the I × A interaction with a logistic regression: logit(Acc) ~ I + A + I:A
try:
    import statsmodels.api as sm
    import statsmodels.formula.api as smf
except ImportError:
    print('statsmodels not installed. Skipping regression analysis.')
else:
    # patsy treats a bool column as categorical, which turns the response into two
    # columns that statsmodels rejects ("endog has evaluated to an array with multiple
    # columns ..."). Fit on 0/1 integers instead.
    reg_df = df.assign(is_correct=df['is_correct'].astype(int))
    try:
        model = smf.logit('is_correct ~ I * A', data=reg_df).fit(disp=0)
    except Exception as e:  # e.g. perfect separation or a singular design on degenerate data
        print(f'Logistic regression failed: {type(e).__name__}: {e}')
    else:
        print('Logistic Regression: is_correct ~ I * A')
        print('='*50)
        print(model.summary().tables[1])

        # Coefficient of the interaction term (I:A)
        interaction_coef = model.params.get('I:A', 0)
        interaction_pval = model.pvalues.get('I:A', 1)

        print(f'\nInteraction (I × A):')
        print(f'  Coefficient: {interaction_coef:.4f}')
        print(f'  P-value: {interaction_pval:.4f}')

        if interaction_coef < 0 and interaction_pval < 0.05:
            print('  ✓ Significant SUB-ADDITIVE interaction (β₃ < 0)')

## 19. Summary

In [None]:
# Plain-text summary of the run (uses df and all_results from the cells above).
# Overall and per-level accuracies are unweighted means over all stored results.
print('='*60)
print('EXPERIMENT SUMMARY')
print('='*60)
print(f'Version: {EXPERIMENT_VERSION}')
print(f'Date: {EXPERIMENT_DATE}')
print(f'Total problems: {N_PROBLEMS}')
print(f'Total results: {len(all_results)}')
print(f'\nOverall accuracy: {df["is_correct"].mean():.1%}')
print(f'\nAccuracy by I:')
for I in I_LEVELS:
    acc = df[df['I']==I]['is_correct'].mean()
    print(f'  I={I}: {acc:.1%}')
print(f'\nAccuracy by A (alignment):')
for A in sorted(df['A'].unique(), reverse=True):
    acc = df[df['A']==A]['is_correct'].mean()
    print(f'  A={A:.1f}: {acc:.1%}')
print(f'\nFiles saved to: {SAVE_DIR_V3}')
print('='*60)