# CoT Phase Transition Experiment v3 - Direct Condition (No CoT)

**Version**: 3.0 (2025-12-24)

**Purpose**: Baseline experiment without Chain-of-Thought reasoning.

**Design**:
- Same 200 problems as Experiment 1
- NO CoT provided - model answers directly
- Measures baseline accuracy without reasoning traces

**Why This Matters**:
- Establishes whether CoT actually helps (Clean vs Direct)
- Shows at what corruption level CoT becomes worse than no CoT
- Critical for the "collapse" narrative in the paper

## 0. Google Drive Connection

In [None]:
from google.colab import drive
import os
from datetime import datetime

# Mount Google Drive so that everything written below SAVE_DIR outlives the
# Colab runtime.
drive.mount('/content/drive')

EXPERIMENT_VERSION = 'v3'
EXPERIMENT_DATE = datetime.now().strftime('%Y%m%d')  # e.g. '20251224'

SAVE_DIR = '/content/drive/MyDrive/CoT_Experiment'
SAVE_DIR_DIRECT = f'{SAVE_DIR}/direct_experiment_{EXPERIMENT_VERSION}_{EXPERIMENT_DATE}'

# os.makedirs creates every missing parent, so creating the results folder
# also creates SAVE_DIR_DIRECT itself.
os.makedirs(f'{SAVE_DIR_DIRECT}/results', exist_ok=True)

print(f'Direct experiment save directory: {SAVE_DIR_DIRECT}')

## 1. Install Dependencies

In [None]:
# IPython shell escape (!): quietly (-q) install the Hugging Face `datasets`
# loader, the Anthropic SDK, pandas and tqdm into the Colab runtime.
!pip install datasets anthropic pandas tqdm -q
# NOTE(review): printed unconditionally; a failed pip install is not detected here.
print('Dependencies installed.')

## 2. Configuration

In [None]:
import json
import re
import random
import time
from typing import List, Dict, Optional, Any
from dataclasses import dataclass, asdict
from datetime import datetime
from tqdm import tqdm
import pandas as pd

# =============================================================================
# Configuration - MUST MATCH EXPERIMENT 1
# =============================================================================
# Problem selection: seed and count must equal Experiment 1's values so both
# runs score the identical GSM8K subset.
GLOBAL_SEED = 20251224
N_PROBLEMS = 200

# API settings: token cap for the JSON answer, and the pause (seconds) taken
# after every successful API call.
API_MAX_TOKENS_ANSWER = 256
API_RATE_LIMIT_DELAY = 0.5

print('\n'.join([
    '=' * 60,
    'DIRECT CONDITION EXPERIMENT',
    '=' * 60,
    f'  GLOBAL_SEED: {GLOBAL_SEED}',
    f'  N_PROBLEMS: {N_PROBLEMS}',
    '  Condition: Direct (No CoT)',
    '=' * 60,
]))

## 3. Data Structures

In [None]:
@dataclass
class GSM8KProblem:
    """One selected GSM8K test problem together with its parsed gold answer."""
    index: int  # position of the problem in the GSM8K test split
    question: str  # problem statement inserted into the prompt
    answer_text: str  # full reference solution text (contains the '#### <answer>' marker)
    final_answer: int  # integer extracted from answer_text by extract_final_answer

@dataclass
class DirectResult:
    """Outcome of one Direct-condition (no-CoT) query on one problem."""
    problem_index: int  # GSM8KProblem.index of the problem that was asked
    condition: str  # Always "direct"
    model_answer: Optional[int]  # parsed model answer; None when the reply could not be parsed
    correct_answer: int  # gold answer (GSM8KProblem.final_answer)
    is_correct: bool  # True only when model_answer is not None and equals correct_answer
    raw_output: str  # unmodified text returned by the model
    timestamp: str  # datetime.now().isoformat() at the time the result was built

def extract_final_answer(answer_text: str) -> int:
    """Return the integer gold answer from a GSM8K reference solution.

    GSM8K solutions end with a ``#### <answer>`` marker.  The answer may carry
    thousands separators (``1,234``) or a leading minus sign (``-5``); an
    integer written with a zero fractional part (``4.00``) is accepted.

    Raises:
        ValueError: if no ``####`` marker followed by a number is present, or
            if the number has a non-zero fractional part (e.g. ``3.5``).  The
            previous pattern silently truncated such values at the decimal
            point and could not match negative answers at all.
    """
    match = re.search(r'####\s*(-?[\d,]*\d(?:\.\d+)?)', answer_text)
    if match is None:
        raise ValueError(f'Could not extract final answer from: {answer_text[-80:]!r}')

    token = match.group(1).replace(',', '')
    if '.' in token:
        integral, _, fraction = token.partition('.')
        # Only "N.000..." is an integer; anything else cannot be stored in the
        # int-typed final_answer without changing its value.
        if fraction.strip('0'):
            raise ValueError(f'Final answer is not an integer: {match.group(1)!r}')
        token = integral
    return int(token)

def save_json(data: Any, filepath: str):
    """Write *data* to *filepath* as pretty-printed UTF-8 JSON and log the path.

    The object is serialized *before* the file is opened, so a value that
    ``json`` cannot encode raises without touching *filepath*.  Previously
    the file was truncated first, and a failing dump destroyed any results
    already saved there.  Non-ASCII text is written as-is
    (``ensure_ascii=False``) with 2-space indentation.
    """
    payload = json.dumps(data, ensure_ascii=False, indent=2)
    with open(filepath, 'w', encoding='utf-8') as f:
        f.write(payload)
    print(f'Saved: {filepath}')

def load_json(filepath: str) -> Any:
    """Read the UTF-8 JSON document at *filepath* and return the decoded object."""
    with open(filepath, encoding='utf-8') as source:
        text = source.read()
    return json.loads(text)

## 4. Load GSM8K and Select Problems (Same as Experiment 1)

In [None]:
from datasets import load_dataset

# GSM8K, "main" configuration, test split, loaded through Hugging Face `datasets`.
dataset = load_dataset('gsm8k', 'main', split='test')
print(f'GSM8K test set loaded: {len(dataset)} problems')

def select_problems(dataset, n_problems: int, seed: int) -> List[int]:
    """Pick ``n_problems`` distinct dataset indices reproducibly from *seed*.

    Every index is shuffled with a private ``random.Random(seed)`` (the global
    RNG is left untouched), and the first ``n_problems`` of the shuffled order
    are returned in ascending order.  Only ``len(dataset)`` is used.
    """
    order = list(range(len(dataset)))
    random.Random(seed).shuffle(order)
    chosen = order[:n_problems]
    chosen.sort()
    return chosen

# IMPORTANT: Same seed as Experiment 1 → Same problems
selected_indices = select_problems(dataset, N_PROBLEMS, GLOBAL_SEED)
print(f'Selected {len(selected_indices)} problems (same as Experiment 1)')

# Create problem objects. A problem whose gold answer cannot be parsed is
# skipped, with the reason printed, instead of aborting the whole cell.
problems = []
for idx in selected_indices:
    item = dataset[idx]
    try:
        final_ans = extract_final_answer(item['answer'])
    except ValueError as e:
        print(f'Warning: Could not extract answer for index {idx}: {e}')
        continue
    problems.append(GSM8KProblem(
        index=idx,
        question=item['question'],
        answer_text=item['answer'],
        final_answer=final_ans,
    ))

print(f'Prepared {len(problems)} problems')
if len(problems) != len(selected_indices):
    # Skipped problems shrink the denominator. The Direct vs CoT comparison is
    # only like-for-like if Experiment 1 skipped the same indices.
    print(f'Warning: {len(selected_indices) - len(problems)} selected problems were skipped')

## 5. API Setup

In [None]:
from getpass import getpass

# getpass reads the key without echoing it, so the key is not shown in cell
# output and is not written into the notebook source.
ANTHROPIC_API_KEY = getpass('Enter Anthropic API Key: ')
print('API Key set.')  # NOTE(review): printed even if the entered key is empty

In [None]:
import anthropic

# Shared client; call_claude sends every request through this object.
client = anthropic.Anthropic(api_key=ANTHROPIC_API_KEY)

def call_claude(system_prompt: str, user_prompt: str, max_tokens: int = 256, retries: int = 3,
                model: str = "claude-sonnet-4-20250514") -> str:
    """Send one single-turn request to Claude and return its text output.

    Requests use temperature 0.  After every successful call the function
    sleeps ``API_RATE_LIMIT_DELAY`` seconds as a simple client-side rate limit.

    Args:
        system_prompt: System prompt for the request.
        user_prompt: Content of the single user message.
        max_tokens: Maximum number of tokens to generate.
        retries: Total number of attempts; values below 1 are treated as 1.
            Previously such values skipped the loop and returned None.
        model: Model identifier sent to the API.

    Returns:
        The concatenated text of all text blocks in the response.  Returns ''
        if there is none, rather than raising IndexError on ``content[0]``.

    Raises:
        anthropic.APIError: the last API error, once every attempt has failed.
    """
    attempts = max(1, retries)
    for attempt in range(attempts):
        try:
            message = client.messages.create(
                model=model,
                max_tokens=max_tokens,
                messages=[{"role": "user", "content": user_prompt}],
                system=system_prompt,
                temperature=0,
            )
        except anthropic.APIError as e:
            # Only API/transport failures are retried; programming errors
            # (TypeError, NameError, ...) now surface immediately.
            print(f'API error (attempt {attempt+1}): {e}')
            if attempt == attempts - 1:
                raise
            # Exponential backoff between attempts: 1s, 2s, 4s, ...
            time.sleep(2 ** attempt)
            continue

        time.sleep(API_RATE_LIMIT_DELAY)
        return ''.join(
            block.text for block in message.content
            if getattr(block, 'type', None) == 'text'
        )

# Smoke test: one tiny request to confirm the key works and that the model
# follows a JSON-only instruction.
test_response = call_claude(
    system_prompt="You output ONLY JSON.",
    user_prompt='Respond with exactly: {"test": "ok"}',
    max_tokens=50,
)
print(f'API test: {test_response}')

## 6. Direct Condition Prompts

In [None]:
# System prompt - Same JSON format as Experiment 1 for consistency
# Rules 4-6 forbid any explanation or shown work, so the model must answer
# directly (the "no CoT" manipulation). Rules 1-3 pin the {"final": <int>}
# shape that parse_model_answer tries first.
DIRECT_SYSTEM_PROMPT = """You are a calculator that outputs ONLY JSON.

CRITICAL RULES:
1. Your output MUST start with the character '{'
2. Your output MUST be exactly: {"final": <number>}
3. Replace <number> with an integer (the numerical answer)
4. Do NOT write ANY explanation, reasoning, or text before or after the JSON
5. Do NOT show your work or thinking
6. ONLY output the JSON object, nothing else

CORRECT OUTPUT EXAMPLE:
{"final": 42}
"""

def create_direct_prompt(problem: GSM8KProblem) -> tuple:
    """Build the ``(system, user)`` prompt pair for the Direct (no-CoT) condition.

    The user message embeds the question and restates the JSON-only output
    format. The system prompt is the shared ``DIRECT_SYSTEM_PROMPT``.
    """
    user_lines = (
        f"Problem: {problem.question}",
        "",
        "Solve this problem and give the final numerical answer.",
        'OUTPUT ONLY: {"final": <number>}',
        "START YOUR RESPONSE WITH '{'",
    )
    return DIRECT_SYSTEM_PROMPT, "\n".join(user_lines)

# Test
# Sanity check: print the first 300 characters of the user prompt built for
# the first prepared problem (the system prompt is not printed).
test_sys, test_usr = create_direct_prompt(problems[0])
print('--- Direct Prompt Example ---')
print(test_usr[:300])

## 7. Answer Parsing

In [None]:
# Number token accepted in model output: optional minus sign, optional
# thousands separators ("1,234"), optional decimal part ("42.5").
_NUMBER_RE = r'-?\d+(?:,\d{3})*(?:\.\d+)?'

# JSON-style patterns, tried in order; the first one that matches wins.
_ANSWER_PATTERNS = (
    re.compile(r'\{\s*"final"\s*:\s*(' + _NUMBER_RE + r')\s*\}'),         # {"final": 123}
    re.compile(r"\{\s*[\"']final[\"']\s*:\s*(" + _NUMBER_RE + r")\s*\}"),  # {'final': 123}
    re.compile(r'"final"\s*:\s*(' + _NUMBER_RE + r')'),                   # "final": 123 (unclosed/embedded)
)

# Fallback: a whitespace-delimited number followed by whitespace, end of
# text, '.' or ','.
_BARE_NUMBER_RE = re.compile(r'(?:^|\s)(' + _NUMBER_RE + r')(?:\s|$|\.|,)')


def _number_token_to_int(token: str) -> int:
    """Convert a matched number token to int (commas dropped, decimals rounded half-to-even)."""
    return int(round(float(token.replace(',', ''))))


def parse_model_answer(response: str) -> Optional[int]:
    """Extract the numerical answer from a model response.

    Tries, in order: ``{"final": N}``, ``{'final': N}``, a bare
    ``"final": N``, and finally the last standalone number in the text.
    Numbers may contain thousands separators.  Previously ``1,000`` was read
    as ``1`` by every pattern.  Non-integers are rounded with ``round()``,
    i.e. half-to-even (``42.5`` -> 42), as before.

    Returns:
        The parsed integer, or None if the response contains no usable number.
    """
    for pattern in _ANSWER_PATTERNS:
        match = pattern.search(response)
        if match:
            return _number_token_to_int(match.group(1))

    bare_numbers = _BARE_NUMBER_RE.findall(response)
    if bare_numbers:
        return _number_token_to_int(bare_numbers[-1])

    return None

# Test
# Covers a JSON answer, a non-integer JSON answer (rounded half-to-even by
# round(), so 42.5 -> 42), and the free-text last-number fallback.
test_cases = ['{"final": 300}', '{"final": 42.5}', 'The answer is 100.']
for tc in test_cases:
    print(f'  "{tc}" → {parse_model_answer(tc)}')

## 8. Run Direct Experiment

In [None]:
def run_direct_experiment(problem: GSM8KProblem) -> DirectResult:
    """Query the model once on *problem* without CoT and score the reply.

    The reply is parsed with ``parse_model_answer``.  A reply that cannot be
    parsed (``None``) is recorded with ``is_correct=False``.
    """
    system_msg, user_msg = create_direct_prompt(problem)
    raw = call_claude(system_msg, user_msg, max_tokens=API_MAX_TOKENS_ANSWER)
    predicted = parse_model_answer(raw)

    return DirectResult(
        problem_index=problem.index,
        condition="direct",
        model_answer=predicted,
        correct_answer=problem.final_answer,
        is_correct=predicted is not None and predicted == problem.final_answer,
        raw_output=raw,
        timestamp=datetime.now().isoformat(),
    )

In [None]:
# Pilot test (5 problems)
# Runs the first 5 problems end-to-end (prompt -> API -> parse -> score) to
# catch format problems before the full run. pilot_results is only printed
# here, not saved.
print('='*60)
print('PILOT: Testing with 5 problems')
print('='*60)

pilot_results = []
for prob in tqdm(problems[:5], desc='Pilot'):
    result = run_direct_experiment(prob)
    pilot_results.append(result)
    status = '✓' if result.is_correct else '✗'
    print(f'  Problem {prob.index}: {status} (model={result.model_answer}, correct={result.correct_answer})')

# NOTE(review): divides by len(pilot_results); raises ZeroDivisionError if problems is empty.
pilot_acc = sum(r.is_correct for r in pilot_results) / len(pilot_results)
print(f'\nPilot accuracy: {pilot_acc:.1%}')

In [None]:
# Full experiment
print('='*60)
print('FULL DIRECT EXPERIMENT')
print('='*60)

results_path = f'{SAVE_DIR_DIRECT}/results/direct_results_v3.json'

# Re-running this cell resumes from the saved results file (same version/date
# directory) instead of re-querying problems that already have a result.
# Set to False to force a fresh run, e.g. after changing the prompt or model.
RESUME_FROM_CHECKPOINT = True
CHECKPOINT_EVERY = 25  # number of new results between intermediate saves

all_results = []
if RESUME_FROM_CHECKPOINT and os.path.exists(results_path):
    wanted_indices = {p.index for p in problems}
    all_results = [
        DirectResult(**row) for row in load_json(results_path)
        if row['problem_index'] in wanted_indices
    ]
    print(f'Resuming: {len(all_results)} results loaded from {results_path}')

done_indices = {r.problem_index for r in all_results}
pending = [prob for prob in problems if prob.index not in done_indices]

try:
    for n_new, prob in enumerate(tqdm(pending, desc='Direct condition'), start=1):
        all_results.append(run_direct_experiment(prob))
        if n_new % CHECKPOINT_EVERY == 0:
            save_json([asdict(r) for r in all_results], results_path)
finally:
    # Persist whatever was collected even if an API call ultimately failed or
    # the run was interrupted, so completed (paid-for) responses are not lost.
    save_json([asdict(r) for r in all_results], results_path)

## 9. Analysis

In [None]:
# Convert to DataFrame: one row per DirectResult
df = pd.DataFrame([asdict(res) for res in all_results])

# Accuracy over all problems; unparseable replies have is_correct=False and
# therefore count as wrong.
total = len(df)
correct = df['is_correct'].sum()
accuracy = correct / total

print('\n'.join([
    '=' * 60,
    'DIRECT CONDITION RESULTS',
    '=' * 60,
    f'Total problems: {total}',
    f'Correct: {correct}',
    f'Accuracy: {accuracy:.1%}',
    '=' * 60,
]))

In [None]:
# Parse failures analysis
# Rows whose model_answer is missing, i.e. parse_model_answer returned None
# (.isna() treats None as missing).
parse_failures = df[df['model_answer'].isna()]
print(f'Parse failures: {len(parse_failures)} ({len(parse_failures)/total:.1%})')

if len(parse_failures) > 0:
    print('\nSample parse failures:')
    # Show up to 3 examples, truncating each raw reply to 100 characters.
    for _, row in parse_failures.head(3).iterrows():
        print(f'  Problem {row["problem_index"]}: {row["raw_output"][:100]}...')

## 10. Comparison with Experiment 1 (if available)

In [None]:
# Try to load Experiment 1 results for comparison.
# Only directories are considered, and the lexicographically last name is
# taken as the most recent run. This assumes Exp 1 directory names end in a
# YYYYMMDD date like SAVE_DIR_DIRECT does (TODO confirm).
exp1_dirs = sorted(
    d for d in os.listdir(SAVE_DIR)
    if d.startswith('full_experiment_v3') and os.path.isdir(f'{SAVE_DIR}/{d}')
)

if exp1_dirs:
    exp1_dir = f'{SAVE_DIR}/{exp1_dirs[-1]}'
    exp1_results_path = f'{exp1_dir}/results/results_full_v3.json'

    if os.path.exists(exp1_results_path):
        exp1_data = load_json(exp1_results_path)
        exp1_df = pd.DataFrame(exp1_data)

        print('='*60)
        print('COMPARISON: Direct vs CoT Conditions')
        print('='*60)
        print(f'Experiment 1 results: {exp1_results_path}')

        # Direct accuracy
        direct_acc = accuracy
        print(f'Direct (No CoT):     {direct_acc:.1%}')

        # CoT accuracies by λ, printed alongside A = 1 - λ.
        # NOTE(review): unpaired comparison of overall accuracies. It is only
        # like-for-like if Exp 1 scored the same problem indices.
        if {'lam', 'is_correct'} <= set(exp1_df.columns):
            for lam in sorted(exp1_df['lam'].unique()):
                lam_acc = exp1_df[exp1_df['lam'] == lam]['is_correct'].mean()
                A = 1 - lam
                comparison = '>' if lam_acc > direct_acc else ('<' if lam_acc < direct_acc else '=')
                print(f'CoT (A={A:.1f}, λ={lam}): {lam_acc:.1%} {comparison} Direct')
        else:
            print("Experiment 1 results lack a 'lam' or 'is_correct' column; per-λ comparison skipped.")

        print('='*60)
    else:
        print(f'Experiment 1 results not found yet: {exp1_results_path}')
else:
    print('Experiment 1 directory not found. Run comparison after Experiment 1 completes.')

## 11. Summary

In [None]:
# Final summary: one print of all lines joined by newlines (same output as
# printing each line separately).
print('\n'.join([
    '=' * 60,
    'DIRECT EXPERIMENT SUMMARY',
    '=' * 60,
    f'Version: {EXPERIMENT_VERSION}',
    f'Date: {EXPERIMENT_DATE}',
    f'Problems: {len(all_results)}',
    f'\nDirect Condition Accuracy: {accuracy:.1%}',
    '\nThis serves as the baseline for CoT effectiveness.',
    'If CoT (A=1.0) > Direct: CoT helps',
    'If CoT (A=x) < Direct: CoT at that corruption level hurts',
    f'\nResults saved to: {SAVE_DIR_DIRECT}',
    '=' * 60,
]))