# CoT Phase Transition Experiment v3 - A_crit Detection (Fine-grained λ)

**Version**: 3.0 (2024-12-24)

**Purpose**: Precisely identify the critical alignment point (A_crit, where alignment A = 1 − λ and λ is the fraction of corrupted reasoning steps) at which CoT collapse occurs.

**Design**:
- Same 200 problems as Experiment 1
- I = 10 only (representative depth)
- λ ∈ {0.1, 0.3, 0.5, 0.7, 0.9} (fills gaps in Experiment 1's grid)
- Combined with Exp1 data: full 0.1-step resolution from 0.0 to 1.0

**Why This Matters**:
- Experiment 1 has λ ∈ {0.0, 0.2, 0.4, 0.6, 0.8, 1.0}
- This adds λ ∈ {0.1, 0.3, 0.5, 0.7, 0.9}
- Together: 11-point λ curve for precise A_crit estimation

## 0. Google Drive Connection

In [None]:
# Mount Google Drive so that checkpoints, results and figures are written to Drive.
from google.colab import drive
drive.mount('/content/drive')

import os
from datetime import datetime

EXPERIMENT_VERSION = 'v3'
# YYYYMMDD stamp taken when this cell runs; it is part of the output folder name.
EXPERIMENT_DATE = datetime.now().strftime('%Y%m%d')

SAVE_DIR = '/content/drive/MyDrive/CoT_Experiment'
SAVE_DIR_ACRIT = '/'.join([SAVE_DIR, f'acrit_experiment_{EXPERIMENT_VERSION}_{EXPERIMENT_DATE}'])

# Create the experiment folder first, then its three sub-folders.
os.makedirs(SAVE_DIR_ACRIT, exist_ok=True)
for _subdir in ('results', 'checkpoints', 'figures'):
    os.makedirs(f'{SAVE_DIR_ACRIT}/{_subdir}', exist_ok=True)

print(f'A_crit experiment save directory: {SAVE_DIR_ACRIT}')

## 1. Install Dependencies

In [None]:
# Notebook shell magic: install third-party packages imported by later cells (-q = quiet output).
!pip install datasets anthropic pandas tqdm matplotlib -q
print('Dependencies installed.')

## 2. Configuration

In [None]:
import json
import re
import random
import time
import hashlib
from typing import List, Dict, Optional, Any
from dataclasses import dataclass, asdict
from datetime import datetime
from tqdm import tqdm
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# -----------------------------------------------------------------------------
# Configuration - MUST MATCH EXPERIMENT 1
# -----------------------------------------------------------------------------
GLOBAL_SEED = 20251224  # drives problem selection and every derived corruption seed
N_PROBLEMS = 200

# Odd tenths only: these fill the gaps of Experiment 1's grid
# (Experiment 1: λ ∈ {0.0, 0.2, 0.4, 0.6, 0.8, 1.0}).
LAMBDA_FINE = [0.1, 0.3, 0.5, 0.7, 0.9]

# Reasoning depth is held constant in this experiment.
I_FIXED = 10

# API settings
API_MAX_TOKENS_ANSWER = 256   # max_tokens passed for each answer request
API_RATE_LIMIT_DELAY = 0.5    # seconds slept after each successful API call
CHECKPOINT_EVERY = 50         # write a checkpoint every this many results

# Echo the configuration as one banner so it is recorded in the cell output.
print('\n'.join([
    '=' * 60,
    'A_CRIT DETECTION EXPERIMENT',
    '=' * 60,
    f'  GLOBAL_SEED: {GLOBAL_SEED}',
    f'  N_PROBLEMS: {N_PROBLEMS}',
    f'  I (fixed): {I_FIXED}',
    f'  λ values: {LAMBDA_FINE}',
    f'  Total inferences: {N_PROBLEMS * len(LAMBDA_FINE)}',
    '=' * 60,
]))

## 3. Data Structures & Utilities

In [None]:
@dataclass
class GSM8KProblem:
    """One GSM8K problem selected for this experiment."""
    index: int  # row index of the problem in the loaded GSM8K split
    question: str  # the dataset's 'question' field
    answer_text: str  # the dataset's 'answer' field (reference solution text)
    final_answer: int  # integer parsed from answer_text by extract_final_answer

@dataclass
class CleanTrace:
    """Uncorrupted reasoning trace for one problem.

    NOTE(review): this class is not instantiated in the visible code; the traces
    loaded from Experiment 1 are used as plain dicts (keys 'problem_index' and
    'steps' are read). Field meanings below are presumed from the names — confirm
    against the Experiment 1 notebook.
    """
    problem_index: int  # presumably the GSM8K row index the trace belongs to
    I: int  # presumably the number of reasoning steps (depth)
    steps: List[str]  # presumably one string per step, without the 'Step k:' prefix
    full_text: str  # presumably the whole trace as a single string

@dataclass
class ExperimentResult:
    """Outcome of one model query for a (problem, I, λ) condition."""
    problem_index: int  # GSM8KProblem.index of the problem queried
    I: int  # reasoning depth passed to run_experiment
    lam: float  # corruption rate λ passed to run_experiment
    A_target: float  # stored as 1.0 - lam by run_experiment
    model_answer: Optional[int]  # parsed model answer; None when parsing failed
    correct_answer: int  # the problem's reference final answer
    is_correct: bool  # True only when model_answer was parsed and equals correct_answer
    raw_output: str  # unmodified model response text
    timestamp: str  # ISO-format local time at which the result was built

def extract_final_answer(answer_text: str) -> int:
    """Return the integer that follows the '####' marker in a GSM8K answer.

    Thousands separators (commas) are stripped and a leading minus sign is
    accepted, so both '#### 1,234' and '#### -5' parse.

    Args:
        answer_text: Reference solution text containing a '#### <number>' marker.

    Returns:
        The final answer as an int.

    Raises:
        ValueError: If no '####' marker followed by a number is found.
    """
    match = re.search(r'####\s*(-?[\d,]+)', answer_text)
    if match:
        digits = match.group(1).replace(',', '')
        # Guard against a match made only of separators or a lone sign.
        if digits not in ('', '-'):
            return int(digits)
    # Show the tail of the text, which is where the marker is expected.
    raise ValueError(f'Could not extract final answer from: {answer_text[-80:]!r}')

def save_json(data: Any, filepath: str):
    """Write `data` to `filepath` as indented UTF-8 JSON and print a confirmation.

    The payload is serialized before the file is opened. Opening with 'w'
    truncates the target, so serializing first means a failure (for example a
    non-JSON-serializable value) leaves an existing file such as the previous
    checkpoint untouched.

    Args:
        data: Any JSON-serializable object.
        filepath: Destination path; overwritten if it exists.

    Raises:
        TypeError / ValueError: Propagated from json.dumps when `data` cannot be serialized.
        OSError: Propagated from open/write.
    """
    payload = json.dumps(data, ensure_ascii=False, indent=2)
    with open(filepath, 'w', encoding='utf-8') as f:
        f.write(payload)
    print(f'Saved: {filepath}')

def load_json(filepath: str) -> Any:
    """Read the UTF-8 file at `filepath` and return its parsed JSON content."""
    with open(filepath, 'r', encoding='utf-8') as source:
        text = source.read()
    return json.loads(text)

def derive_seed(global_seed: int, problem_id: int, I: int, lam: float, replicate_id: int = 0) -> int:
    """Derive a reproducible 32-bit seed for one (problem, I, λ, replicate) condition.

    The fields are joined into a '|'-separated key, hashed with SHA-256, and the
    first four digest bytes are read as a big-endian unsigned integer (the same
    value as the first eight hex digits of the digest).
    """
    fields = (
        str(global_seed),
        str(problem_id),
        f"I={I}",
        f"lam={lam}",
        f"rep={replicate_id}",
    )
    digest = hashlib.sha256("|".join(fields).encode("utf-8")).digest()
    return int.from_bytes(digest[:4], "big")

## 4. Load GSM8K and Problems

In [None]:
from datasets import load_dataset

# Load the test split of the 'main' GSM8K configuration via the `datasets` library.
dataset = load_dataset('gsm8k', 'main', split='test')
print(f'GSM8K test set loaded: {len(dataset)} problems')

def select_problems(dataset, n_problems: int, seed: int) -> List[int]:
    """Pick `n_problems` dataset row indices reproducibly and return them sorted.

    All row indices are shuffled with a private RNG seeded by `seed`; the first
    `n_problems` of the shuffled order are kept and returned in ascending order.
    """
    shuffled = list(range(len(dataset)))
    random.Random(seed).shuffle(shuffled)
    chosen = shuffled[:n_problems]
    chosen.sort()
    return chosen

# Same seed → Same problems as Experiment 1
selected_indices = select_problems(dataset, N_PROBLEMS, GLOBAL_SEED)
print(f'Selected {len(selected_indices)} problems (same as Experiment 1)')

# Build GSM8KProblem records. A problem whose reference answer cannot be parsed
# is skipped, but it is recorded and reported so a shrunken sample is visible.
problems = []
skipped_indices = []
for idx in selected_indices:
    item = dataset[idx]
    try:
        final_ans = extract_final_answer(item['answer'])
    except ValueError:
        skipped_indices.append(idx)
        continue
    problems.append(GSM8KProblem(
        index=idx,
        question=item['question'],
        answer_text=item['answer'],
        final_answer=final_ans
    ))

print(f'Prepared {len(problems)} problems')
if skipped_indices:
    print(f'WARNING: skipped {len(skipped_indices)} problems with unparseable answers: {skipped_indices}')

# Lookup table from dataset index to problem record.
prob_map = {p.index: p for p in problems}

## 5. Load Clean Traces from Experiment 1

In [None]:
# Find Experiment 1 directory: the lexicographically last 'full_experiment_v3*'
# folder under SAVE_DIR, falling back to the pilot folder when none exists.
exp1_dirs = [d for d in os.listdir(SAVE_DIR) if d.startswith('full_experiment_v3')]
if exp1_dirs:
    EXP1_DIR = f'{SAVE_DIR}/{sorted(exp1_dirs)[-1]}'
else:
    EXP1_DIR = f'{SAVE_DIR}/pilot_v2'

print(f'Experiment 1 directory: {EXP1_DIR}')

# Load I=10 clean traces. Known locations are tried first, in priority order.
clean_traces_candidates = [
    f'{EXP1_DIR}/clean_traces/clean_traces_I10_v3.json',
    f'{EXP1_DIR}/clean_traces_I10.json',
]
clean_traces_path = next((p for p in clean_traces_candidates if os.path.exists(p)), None)

if clean_traces_path is None:
    # Last resort: scan all of SAVE_DIR and stop at the FIRST matching file.
    for root, dirs, files in os.walk(SAVE_DIR):
        hit = next((f for f in files if 'clean_traces_I10' in f and f.endswith('.json')), None)
        if hit is not None:
            clean_traces_path = os.path.join(root, hit)
            break

if clean_traces_path is None:
    raise FileNotFoundError(
        f'No clean_traces_I10*.json found under {SAVE_DIR}; run Experiment 1 first.'
    )

print(f'Loading clean traces from: {clean_traces_path}')
clean_traces_data = load_json(clean_traces_path)
print(f'Loaded {len(clean_traces_data)} clean traces')

# Index traces by problem so the main loop can look them up by GSM8K index.
traces_dict = {t['problem_index']: t for t in clean_traces_data}

## 6. Corruption Logic

In [None]:
def pick_corrupted_steps(I: int, lam: float, seed: int) -> List[int]:
    """Choose which 1-based step numbers of an I-step trace get corrupted.

    round(lam * I) steps are taken from the front of a seeded shuffle of
    1..I and returned in ascending order; an empty list when that count is 0.
    """
    n_corrupt = int(round(lam * I))
    if not n_corrupt:
        return []
    order = list(range(1, I + 1))
    random.Random(seed).shuffle(order)
    return sorted(order[:n_corrupt])

def assign_corruption_types(corrupted_steps: List[int], seed: int) -> Dict[int, str]:
    """Assign a corruption type ('IRR', 'LOC' or 'WRONG') to each corrupted step.

    With K corrupted steps, K // 5 become 'IRR', (2 * K) // 5 become 'LOC' and
    the remainder become 'WRONG'. The remainder is at least 1 for every K >= 1
    (K - K//5 - 2K//5 >= 2K/5 > 0), so no rebalancing step is needed.

    Args:
        corrupted_steps: Step numbers to label; the list is not modified.
        seed: Base seed; the shuffle uses seed + 1 so it is decorrelated from
            the step-selection shuffle that uses the same base seed.

    Returns:
        Mapping step number -> type; empty when `corrupted_steps` is empty.
    """
    K = len(corrupted_steps)
    if K == 0:
        return {}
    n_irr = K // 5
    n_loc = (2 * K) // 5
    n_wrong = K - n_irr - n_loc

    rng = random.Random(seed + 1)
    perm = list(corrupted_steps)
    rng.shuffle(perm)

    # Labels are laid out in blocks and paired with the shuffled steps in order.
    labels = ["IRR"] * n_irr + ["LOC"] * n_loc + ["WRONG"] * n_wrong
    return dict(zip(perm, labels))

# Text of an 'IRR' step: a self-consistent side computation that nothing else uses.
IRRELEVANT_TEMPLATES = [
    "Compute an auxiliary value: aux = {a} + {b} = {result}, "
    "but it will not be used later.",
]
# Text of a 'WRONG' step: asserts an arbitrary value as a binding constraint.
WRONG_CONSTRAINT_TEMPLATES = [
    "Fix an intermediate condition: set {var} = {wrong_value} "
    "as a given constraint for the rest of the steps.",
]

def generate_irrelevant_step(step_num: int, seed: int) -> str:
    """Build an 'IRR' step: a correct but unused addition of two seeded integers in [2, 20].

    `step_num` is accepted for signature parity with the other generators and is not used.
    """
    rng = random.Random(seed)
    # Draw order matters for reproducibility: first operand, then second.
    first = rng.randint(2, 20)
    second = rng.randint(2, 20)
    template = IRRELEVANT_TEMPLATES[0]
    return template.format(a=first, b=second, result=first + second)

def generate_local_error_step(original_step: str, seed: int) -> str:
    """Build a 'LOC' step: the original step with its last number made wrong.

    The last number in the step is shifted by a seeded offset from
    {-3, -2, -1, 1, 2, 3}, clamped at 0. The rewrite is attempted in order:
    a trailing '= N.', then a trailing 'N.', then the last number wherever it
    appears. The final fallback, together with the clamp correction below,
    guarantees the returned step differs from the input whenever the input
    contains a number.

    Args:
        original_step: Text of the clean step.
        seed: Seed for the private RNG.

    Returns:
        The modified step, or a fixed-form substitute step when the original
        contains no digits.
    """
    rng = random.Random(seed)
    number_matches = list(re.finditer(r'\d+', original_step))
    if not number_matches:
        # NOTE(review): randint(28, 32) can draw 30, which makes this substitute
        # arithmetically correct; left as is to keep seeded outputs stable.
        return f"Compute t = 10 * 3 = {rng.randint(28, 32)} (using the previous values)."

    last_number = number_matches[-1]
    original_result = int(last_number.group())
    offset = rng.choice([-3, -2, -1, 1, 2, 3])
    wrong_result = max(0, original_result + offset)
    if wrong_result == original_result:
        # Only reachable when the clamp at 0 cancelled a negative offset on 0;
        # flip the offset's sign so the value really changes.
        wrong_result = original_result + abs(offset)

    modified = re.sub(r'= (\d+)\.$', f'= {wrong_result}.', original_step)
    if modified == original_step:
        modified = re.sub(r'(\d+)\.$', f'{wrong_result}.', original_step)
    if modified == original_step:
        # The step does not end in 'N.': splice the wrong value over the last number.
        modified = (
            original_step[:last_number.start()]
            + str(wrong_result)
            + original_step[last_number.end():]
        )
    return modified

def generate_wrong_constraint_step(step_num: int, seed: int) -> str:
    """Build a 'WRONG' step that pins a seeded variable name to a seeded value in [10, 100].

    `step_num` is accepted for signature parity with the other generators and is not used.
    """
    rng = random.Random(seed)
    # Draw order matters for reproducibility: variable name first, then value.
    variable_names = ['x', 'total', 'result', 'n']
    chosen_name = rng.choice(variable_names)
    bogus_value = rng.randint(10, 100)
    return WRONG_CONSTRAINT_TEMPLATES[0].format(var=chosen_name, wrong_value=bogus_value)

def apply_corruption(trace_data: dict, lam: float, seed: int) -> str:
    """Return the trace as text with a fraction `lam` of its steps corrupted.

    Reads trace_data['steps'], selects and types the steps to corrupt from
    `seed`, replaces each selected step with generated text, and renders the
    result as 'Step k: ...' lines between [[COT_START]] and [[COT_END]] markers.
    """
    original_steps = list(trace_data['steps'])
    targets = pick_corrupted_steps(len(original_steps), lam, seed)
    kinds = assign_corruption_types(targets, seed)

    def rewrite(step_num: int, content: str) -> str:
        """Return the step's text, corrupted if the step was selected."""
        kind = kinds.get(step_num)
        if kind is None:
            return content
        # Each step gets its own seed so generators do not share RNG state.
        step_seed = seed + step_num * 1000
        if kind == 'IRR':
            return generate_irrelevant_step(step_num, step_seed)
        if kind == 'LOC':
            return generate_local_error_step(content, step_seed)
        return generate_wrong_constraint_step(step_num, step_seed)

    body = [
        f'Step {number}: {rewrite(number, text)}'
        for number, text in enumerate(original_steps, start=1)
    ]
    return '\n'.join(['[[COT_START]]', *body, '[[COT_END]]'])

print('Corruption logic defined.')

## 7. API Setup

In [None]:
from getpass import getpass

# Prompt without echoing the key. Pasted keys often carry stray whitespace or a
# newline, so strip it and fail here rather than with an auth error later.
ANTHROPIC_API_KEY = getpass('Enter Anthropic API Key: ').strip()
if not ANTHROPIC_API_KEY:
    raise ValueError('Anthropic API key is empty; re-run this cell and paste a valid key.')
print('API Key set.')

In [None]:
import anthropic

client = anthropic.Anthropic(api_key=ANTHROPIC_API_KEY)

def call_claude(system_prompt: str, user_prompt: str, max_tokens: int = 256, retries: int = 3,
                model: str = "claude-sonnet-4-20250514") -> str:
    """Send one single-turn request and return the text of the first content block.

    Each failed attempt is printed and retried after a linearly growing pause
    (1 s, 2 s, ...); the exception from the final attempt is re-raised. After a
    successful call the function sleeps API_RATE_LIMIT_DELAY seconds before
    returning.

    Args:
        system_prompt: System prompt for the request.
        user_prompt: Content of the single user message.
        max_tokens: Response token limit.
        retries: Total number of attempts; must be at least 1.
        model: Model identifier; defaults to the model used by this experiment.

    Returns:
        The response text.

    Raises:
        ValueError: If `retries` is less than 1 (the loop would otherwise fall
            through and return None).
        Exception: Whatever the last attempt raised.
    """
    if retries < 1:
        raise ValueError(f'retries must be >= 1, got {retries}')
    for attempt in range(retries):
        try:
            message = client.messages.create(
                model=model,
                max_tokens=max_tokens,
                messages=[{"role": "user", "content": user_prompt}],
                system=system_prompt,
                temperature=0
            )
            time.sleep(API_RATE_LIMIT_DELAY)
            # Kept inside the try so an empty content list is retried too.
            return message.content[0].text
        except Exception as e:
            print(f'API error (attempt {attempt+1}): {e}')
            if attempt < retries - 1:
                time.sleep(1.0 * (attempt + 1))
            else:
                raise

# Smoke test: one small request to confirm the key and client work before the long run.
test_response = call_claude(
    system_prompt="You output ONLY JSON.",
    user_prompt='Respond with exactly: {"test": "ok"}',
    max_tokens=50,
)
print(f'API test: {test_response}')

## 8. Experiment Prompts & Parsing

In [None]:
# System prompt for every experiment query: restricts the model to a bare
# {"final": <integer>} JSON object so parse_model_answer can read the answer.
# The text is part of the experimental condition; do not edit it casually.
EXPERIMENT_SYSTEM_PROMPT = """You are a calculator that outputs ONLY JSON.

CRITICAL RULES:
1. Your output MUST start with the character '{'
2. Your output MUST be exactly: {"final": <number>}
3. Replace <number> with an integer (the numerical answer)
4. Do NOT write ANY explanation, reasoning, or text before or after the JSON
5. ONLY output the JSON object, nothing else

CORRECT OUTPUT EXAMPLE:
{"final": 42}
"""

def create_experiment_prompt(problem: GSM8KProblem, cot_text: str) -> tuple:
    """Return (system_prompt, user_prompt) for one experiment query.

    The user prompt contains the problem question, the (possibly corrupted)
    reasoning trace presented as given facts, and the output-format reminder.
    """
    user_lines = [
        f"Problem: {problem.question}",
        "",
        "Reasoning trace (use these steps as given facts):",
        cot_text,
        "",
        "Based on the trace above, compute the final numerical answer.",
        'OUTPUT ONLY: {"final": <number>}',
        "START YOUR RESPONSE WITH '{'",
    ]
    return EXPERIMENT_SYSTEM_PROMPT, "\n".join(user_lines)

_NUMBER = r'(-?\d+(?:\.\d+)?)'

# Tried in order; the first pattern that matches anywhere in the response wins.
# The first three are kept in their original priority order; the fourth accepts
# a quoted number such as {"final": "42"}, which the others reject.
_FINAL_PATTERNS = tuple(re.compile(p) for p in (
    r'\{\s*"final"\s*:\s*' + _NUMBER + r'\s*\}',
    r"\{\s*[\"']final[\"']\s*:\s*" + _NUMBER + r"\s*\}",
    r'"final"\s*:\s*' + _NUMBER,
    r"[\"']final[\"']\s*:\s*[\"']" + _NUMBER + r"[\"']",
))

# Last resort: a free-standing number delimited by whitespace or punctuation.
_LOOSE_NUMBER = re.compile(r'(?:^|\s)(-?\d+(?:\.\d+)?)(?:\s|$|\.|,)')


def _number_text_to_int(text: str) -> int:
    """Convert matched number text to int; decimals are rounded, integers stay exact."""
    if '.' in text:
        return int(round(float(text)))
    # Direct int() avoids float precision loss on very large integers.
    return int(text)


def parse_model_answer(response: str) -> Optional[int]:
    """Extract the model's integer answer from its raw response text.

    Looks for a {"final": <number>} object (strict form first, then looser
    variants, then a quoted number) and finally falls back to the last
    free-standing number in the text.

    Args:
        response: Raw model output.

    Returns:
        The answer as an int (decimal values are rounded), or None when no
        number can be found.
    """
    for pattern in _FINAL_PATTERNS:
        match = pattern.search(response)
        if match:
            return _number_text_to_int(match.group(1))
    loose = _LOOSE_NUMBER.findall(response)
    if loose:
        return _number_text_to_int(loose[-1])
    return None

print('Prompts and parsing defined.')

## 9. Run A_crit Experiment

In [None]:
def run_experiment(problem: GSM8KProblem, cot_text: str, I: int, lam: float) -> ExperimentResult:
    """Query the model once with the given trace and score the parsed answer.

    Builds the prompts, calls the API with the answer token limit, parses the
    response, and packages everything (including the raw output and a local
    ISO timestamp) into an ExperimentResult. An unparseable response counts
    as incorrect.
    """
    prompts = create_experiment_prompt(problem, cot_text)
    raw = call_claude(prompts[0], prompts[1], max_tokens=API_MAX_TOKENS_ANSWER)

    parsed = parse_model_answer(raw)
    expected = problem.final_answer
    answered_correctly = parsed is not None and parsed == expected

    return ExperimentResult(
        problem_index=problem.index,
        I=I,
        lam=lam,
        A_target=1.0 - lam,
        model_answer=parsed,
        correct_answer=expected,
        is_correct=answered_correctly,
        raw_output=raw,
        timestamp=datetime.now().isoformat(),
    )

In [None]:
# Full experiment: one model query per (λ, problem) pair at I = I_FIXED.
print('='*60)
print('A_CRIT DETECTION EXPERIMENT')
print('='*60)

all_results = []
checkpoint_path = f'{SAVE_DIR_ACRIT}/checkpoints/latest_acrit_v3.json'

try:
    for lam in LAMBDA_FINE:
        # :.1f avoids float artifacts such as 0.30000000000000004 for 1 - 0.7.
        print(f'\nλ = {lam} (A = {1 - lam:.1f})')

        lam_results = []
        n_missing_trace = 0
        for prob in tqdm(problems, desc=f'λ={lam}'):
            trace_data = traces_dict.get(prob.index)
            if trace_data is None:
                # No clean trace from Experiment 1 for this problem: skip but count it.
                n_missing_trace += 1
                continue

            # Apply corruption (seed is unique per problem and λ)
            seed = derive_seed(GLOBAL_SEED, prob.index, I=I_FIXED, lam=lam)
            corrupted_cot = apply_corruption(trace_data, lam, seed)

            # Run experiment
            result = run_experiment(prob, corrupted_cot, I=I_FIXED, lam=lam)
            lam_results.append(result)
            all_results.append(result)

            # Checkpoint
            if len(all_results) % CHECKPOINT_EVERY == 0:
                save_json([asdict(r) for r in all_results], checkpoint_path)

        # Report accuracy for this λ
        n_correct = sum(r.is_correct for r in lam_results)
        acc = n_correct / len(lam_results) if lam_results else 0
        print(f'  Accuracy: {acc:.1%} ({n_correct}/{len(lam_results)})')
        if n_missing_trace:
            print(f'  WARNING: {n_missing_trace} problems skipped (no clean trace found)')
except BaseException:
    # Covers API failures and manual interrupts: keep the results gathered since
    # the last periodic checkpoint, then let the original error propagate.
    if all_results:
        save_json([asdict(r) for r in all_results], checkpoint_path)
    raise

# Save final results
save_json([asdict(r) for r in all_results], f'{SAVE_DIR_ACRIT}/results/acrit_results_v3.json')

## 10. Analysis & Visualization

In [None]:
# One row per model query, with the ExperimentResult fields as columns.
df = pd.DataFrame(list(map(asdict, all_results)))

# Per-λ accuracy ('mean'), number correct ('sum') and number of queries ('count').
print('\nA_crit Experiment Results (λ ∈ {0.1, 0.3, 0.5, 0.7, 0.9}):')
print(
    df.groupby('lam')['is_correct']
      .agg(['mean', 'sum', 'count'])
)

In [None]:
# Merge with Experiment 1's I = I_FIXED rows when its results file exists;
# otherwise fall back to this experiment's results alone.
exp1_results_path = f'{EXP1_DIR}/results/results_full_v3.json'

if not os.path.exists(exp1_results_path):
    print('Experiment 1 results not found. Will combine after Exp1 completes.')
    combined_df = df
else:
    exp1_data = load_json(exp1_results_path)
    exp1_df = pd.DataFrame(exp1_data)

    # Keep only the depth used in this experiment
    exp1_I10 = exp1_df.loc[exp1_df['I'] == I_FIXED]

    # Stack Experiment 1 rows on top of this experiment's rows
    combined_df = pd.concat([exp1_I10, df], ignore_index=True)

    print('\n=== COMBINED RESULTS (Exp1 + Exp4, I=10) ===')
    combined_summary = (
        combined_df.groupby('lam')['is_correct']
        .agg(['mean', 'count'])
        .rename(columns={'mean': 'Accuracy', 'count': 'N'})
    )
    combined_summary['A'] = 1 - combined_summary.index
    print(combined_summary.sort_index())

In [None]:
# Accuracy as a function of λ, with a secondary top axis labelled in A = 1 - λ.
acc_by_lam = combined_df.groupby('lam')['is_correct'].mean().sort_index()
figure_path = f'{SAVE_DIR_ACRIT}/figures/accuracy_curve_fine_v3.png'

fig, ax = plt.subplots(figsize=(10, 6))

ax.plot(acc_by_lam.index, acc_by_lam.values, 'o-', linewidth=2, markersize=8, color='#2E8B57')

# Horizontal reference line at 70% accuracy
ax.axhline(y=0.7, color='red', linestyle='--', alpha=0.5, label='Potential A_crit threshold')

ax.set_title(f'Accuracy vs Corruption Rate (I={I_FIXED}, Fine-grained)', fontsize=14)
ax.set_xlabel('λ (Corruption Rate)', fontsize=12)
ax.set_ylabel('Accuracy', fontsize=12)
ax.set_xlim(-0.05, 1.05)
ax.set_ylim(0, 1.05)
ax.grid(True, alpha=0.3)
ax.legend()

# Top axis shares the λ limits; each tick is labelled with the matching A value.
top_ticks = [0, 0.2, 0.4, 0.6, 0.8, 1.0]
ax2 = ax.twiny()
ax2.set_xlim(ax.get_xlim())
ax2.set_xticks(top_ticks)
ax2.set_xticklabels([f'{1 - tick:.1f}' for tick in top_ticks])
ax2.set_xlabel('A (Alignment = 1 - λ)', fontsize=12)

plt.tight_layout()
plt.savefig(figure_path, dpi=150)
plt.show()
print(f'Saved: {figure_path}')

In [None]:
# Estimate A_crit (where accuracy drops below threshold)
THRESHOLD = 0.7  # 70% accuracy threshold

acc_sorted = acc_by_lam.sort_index()

# Scan λ in ascending order for the FIRST measured point below the threshold,
# remembering the last point that was still at or above it.
lam_crit = None
lam_prev = acc_prev = None
for lam, acc in acc_sorted.items():
    if acc < THRESHOLD:
        lam_crit = lam
        break
    lam_prev, acc_prev = lam, acc

if lam_crit is not None:
    # round() removes float artifacts such as 1 - 0.7 = 0.30000000000000004.
    A_crit = round(1 - lam_crit, 10)
    print(f'\n=== A_CRIT ESTIMATION ===')
    print(f'Threshold: {THRESHOLD:.0%}')
    print(f'λ_crit ≈ {lam_crit}')
    print(f'A_crit ≈ {A_crit}')
    if lam_prev is not None:
        # Sub-grid estimate: linear interpolation of the threshold crossing.
        # acc_prev >= THRESHOLD > acc, so the denominator is strictly positive.
        lam_interp = lam_prev + (acc_prev - THRESHOLD) / (acc_prev - acc) * (lam_crit - lam_prev)
        print(f'Interpolated crossing between λ={lam_prev} and λ={lam_crit}: '
              f'λ_crit ≈ {lam_interp:.3f}, A_crit ≈ {1 - lam_interp:.3f}')
    print(f'\nInterpretation: CoT becomes unreliable when alignment drops below ~{A_crit:.1f}')
else:
    print('\nA_crit not detected within the measured range.')

## 11. Summary

In [None]:
# Final recap of the run: settings, result count and this experiment's accuracy per λ.
summary_rule = '=' * 60
print(summary_rule)
print('A_CRIT EXPERIMENT SUMMARY')
print(summary_rule)
print(f'Version: {EXPERIMENT_VERSION}')
print(f'Date: {EXPERIMENT_DATE}')
print(f'I (fixed): {I_FIXED}')
print(f'λ values tested: {LAMBDA_FINE}')
print(f'Total results: {len(all_results)}')
print(f'\nAccuracy by λ:')
# groupby yields the λ values in ascending order, one mean accuracy per λ.
for lam, acc in df.groupby('lam')['is_correct'].mean().items():
    print(f'  λ={lam} (A={1-lam:.1f}): {acc:.1%}')
print(f'\nResults saved to: {SAVE_DIR_ACRIT}')
print(summary_rule)