# A2: Direct Baseline Experiment (Claude + GPT-4o)

**Paper**: A2 (Cue-Dominant Extraction)

**Purpose**: Establish direct baseline accuracy (no trace provided) for both models on the A2 problem set.

**Design**:
- Models: Claude Sonnet 4, GPT-4o
- N = 199 problems
- Condition: Direct (no reasoning trace provided)

**Rationale**:
- Establish baseline capability for each model
- Enable comparison: Direct vs Trace-provided conditions
- Answer the question: "Is the model extracting the answer from the trace, or solving the problem independently?"

**Expected inference count**: 398 (199 × 2 models)

**Date**: 2026-01-03

**GLOBAL_SEED**: 20251224

## 0. Google Drive Connection

In [None]:
# Mount Google Drive so that experiment inputs and outputs persist outside
# the Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

import os
from datetime import datetime

EXPERIMENT_NAME = 'A2_Direct_baseline'
# Run date (YYYYMMDD) taken from the clock at execution time; it is part of
# SAVE_DIR, so re-running on a different day writes to a different folder.
EXPERIMENT_DATE = datetime.now().strftime('%Y%m%d')

BASE_DIR = '/content/drive/MyDrive/CoT_Experiment'
# Input folder: the problem set is read from here (problems_v3.json).
V3_DATA_DIR = f'{BASE_DIR}/full_experiment_v3_20251224'

# Output folder for this run; result and summary JSON files go to
# SAVE_DIR/results. exist_ok=True makes re-running this cell harmless.
SAVE_DIR = f'{BASE_DIR}/{EXPERIMENT_NAME}_{EXPERIMENT_DATE}'
os.makedirs(SAVE_DIR, exist_ok=True)
os.makedirs(f'{SAVE_DIR}/results', exist_ok=True)

print(f'Experiment: {EXPERIMENT_NAME}')
print(f'V3 data directory: {V3_DATA_DIR}')
print(f'Save directory: {SAVE_DIR}')

## 1. Install Dependencies

In [None]:
# Notebook shell escape: install the third-party packages quietly (-q).
# NOTE(review): `datasets` is installed but not imported by the cells visible
# here (problems are loaded from a JSON file) - confirm it is still needed.
!pip install datasets openai anthropic matplotlib pandas tqdm scipy -q
print('Dependencies installed.')

## 2. Configuration

In [None]:
import hashlib
import random
import json
import re
import time
from typing import List, Dict, Tuple, Optional, Any
from dataclasses import dataclass, asdict, field
from datetime import datetime
from tqdm import tqdm
import pandas as pd
import numpy as np
from scipy import stats

# =============================================================================
# Global Configuration
# =============================================================================
# Seed recorded for provenance. In the cells visible here it is only printed;
# no random generator is seeded with it.
# NOTE(review): confirm that is intended (API calls use temperature=0).
GLOBAL_SEED = 20251224

# API settings
# Passed as max_tokens to both model calls in run_direct_experiment.
API_MAX_TOKENS_ANSWER = 1024  # More tokens for direct solving
# Base back-off in seconds: after a failed call the helpers sleep
# API_RETRY_DELAY * attempt_number before retrying.
API_RETRY_DELAY = 1.0
# Pause in seconds after every successful API call.
API_RATE_LIMIT_DELAY = 0.5

# Model identifiers
# These strings are sent to the APIs and are also the values that
# run_direct_experiment compares its `model` argument against.
CLAUDE_MODEL = 'claude-sonnet-4-20250514'
GPT_MODEL = 'gpt-4o'

# Banner echoing the configuration into the notebook output.
print('='*70)
print('A2: DIRECT BASELINE EXPERIMENT')
print('='*70)
print(f'  Models: {CLAUDE_MODEL}, {GPT_MODEL}')
print(f'  GLOBAL_SEED: {GLOBAL_SEED}')
print(f'  Condition: Direct (no trace provided)')
print('='*70)
print('\nPurpose:')
print('  Establish each model\'s baseline capability on GSM8K')
print('  for comparison with trace-provided conditions')

## 3. Data Structures

In [None]:
@dataclass
class GSM8KProblem:
    """One problem of the A2 problem set.

    Instances are built by splatting each entry of problems_v3.json into the
    constructor, so the JSON keys must match these field names exactly.
    """
    # Problem identifier; copied into DirectResult.problem_index.
    index: int
    # Problem statement inserted into the user prompt.
    question: str
    # Not read by the cells visible here; presumably the reference solution
    # text - TODO confirm against the data file.
    answer_text: str
    # Ground-truth integer answer the parsed model answer is compared with.
    final_answer: int

@dataclass
class DirectResult:
    """Outcome of one direct (no-trace) inference for one problem and model.

    Converted with dataclasses.asdict for JSON export and DataFrame analysis.
    """
    # GSM8KProblem.index of the problem that was solved.
    problem_index: int
    # Model identifier string the run was requested with.
    model: str
    # Integer parsed from the model output, or None when nothing was parsed.
    model_answer: Optional[int]
    # Ground-truth answer (GSM8KProblem.final_answer).
    correct_answer: int
    # True only when an answer was parsed and it equals correct_answer.
    is_correct: bool
    # Unmodified text returned by the API helper.
    raw_output: str
    # Local time of result creation, ISO 8601 (datetime.now().isoformat()).
    timestamp: str

## 4. Utility Functions

In [None]:
def save_json(data: Any, filepath: str):
    """Write *data* to *filepath* as pretty-printed UTF-8 JSON.

    The payload is rendered to a string before the file is opened, so a
    serialization error leaves an already existing file at *filepath*
    untouched instead of truncating it or leaving half-written JSON behind.

    Args:
        data: Any JSON-serializable object.
        filepath: Destination path; overwritten when it already exists.

    Raises:
        TypeError: If *data* contains a value ``json`` cannot serialize.
        OSError: If the file cannot be opened or written.
    """
    # Serialize first: opening with 'w' truncates the target immediately.
    payload = json.dumps(data, ensure_ascii=False, indent=2)
    with open(filepath, 'w', encoding='utf-8') as f:
        f.write(payload)

def load_json(filepath: str) -> Any:
    """Read the UTF-8 encoded JSON file at *filepath* and return its content.

    Args:
        filepath: Path of the JSON file to read.

    Returns:
        The decoded Python object (dict, list, str, number, bool or None).
    """
    with open(filepath, mode='r', encoding='utf-8') as source:
        raw_text = source.read()
    return json.loads(raw_text)

## 5. Load Problems

In [None]:
# Load the fixed problem set from the v3 data directory. Each JSON entry is
# splatted into GSM8KProblem, so it must carry exactly that class's fields.
problems_path = f'{V3_DATA_DIR}/problems_v3.json'
problems_data = load_json(problems_path)
problems = [GSM8KProblem(**p) for p in problems_data]
# The notebook header states N = 199; check the printed count against it.
print(f'Loaded {len(problems)} problems')

# Lookup from problem index to problem. Not referenced by the other cells
# visible here.
prob_map = {p.index: p for p in problems}

## 6. API Setup

In [None]:
from getpass import getpass

# Prompt for both API keys interactively (getpass does not echo the input),
# so the secrets are not stored in the notebook source.
print('Enter API Keys for both models:')
ANTHROPIC_API_KEY = getpass('Enter Anthropic API Key: ')
OPENAI_API_KEY = getpass('Enter OpenAI API Key: ')
print('API Keys set.')

In [None]:
import anthropic
from openai import OpenAI

# Module-level API clients shared by call_claude and call_gpt4o.
claude_client = anthropic.Anthropic(api_key=ANTHROPIC_API_KEY)
gpt_client = OpenAI(api_key=OPENAI_API_KEY)

# Only the client objects are created here; the first real requests are the
# API test calls further down.
print(f'Claude client initialized: {CLAUDE_MODEL}')
print(f'GPT client initialized: {GPT_MODEL}')

In [None]:
def call_claude(system_prompt: str, user_prompt: str, max_tokens: int = 1024, retries: int = 3) -> str:
    """Send one system + user prompt pair to Claude and return the reply text.

    The request uses temperature=0. Every successful call is followed by a
    pause of API_RATE_LIMIT_DELAY seconds.

    Args:
        system_prompt: Text passed as the ``system`` parameter.
        user_prompt: Content of the single user message.
        max_tokens: Upper bound on generated tokens.
        retries: Total number of attempts. With ``retries <= 0`` no request
            is made and the function falls through, returning None.

    Returns:
        The ``text`` of the first content block of the response.

    Raises:
        Exception: Whatever the last attempt raised, once all attempts failed.
    """
    final_attempt = retries - 1
    for attempt in range(retries):
        try:
            reply = claude_client.messages.create(
                model=CLAUDE_MODEL,
                system=system_prompt,
                messages=[{"role": "user", "content": user_prompt}],
                max_tokens=max_tokens,
                temperature=0,
            )
            time.sleep(API_RATE_LIMIT_DELAY)
            return reply.content[0].text
        except Exception as err:
            print(f'Claude API error (attempt {attempt+1}): {err}')
            if attempt >= final_attempt:
                # Out of attempts: surface the error to the caller.
                raise
            # Linear back-off: wait longer after each failed attempt.
            time.sleep(API_RETRY_DELAY * (attempt + 1))

def call_gpt4o(system_prompt: str, user_prompt: str, max_tokens: int = 1024, retries: int = 3) -> str:
    """Send one system + user prompt pair to GPT-4o and return the reply text.

    The request uses temperature=0. Every successful call is followed by a
    pause of API_RATE_LIMIT_DELAY seconds.

    Args:
        system_prompt: Content of the system message.
        user_prompt: Content of the user message.
        max_tokens: Upper bound on generated tokens.
        retries: Total number of attempts. With ``retries <= 0`` no request
            is made and the function falls through, returning None.

    Returns:
        ``message.content`` of the first choice.
        NOTE(review): the SDK may deliver None here for a reply without text;
        parse_model_answer tolerates None - confirm against the SDK docs.

    Raises:
        Exception: Whatever the last attempt raised, once all attempts failed.
    """
    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    final_attempt = retries - 1
    for attempt in range(retries):
        try:
            completion = gpt_client.chat.completions.create(
                model=GPT_MODEL,
                messages=conversation,
                max_tokens=max_tokens,
                temperature=0,
            )
            time.sleep(API_RATE_LIMIT_DELAY)
            return completion.choices[0].message.content
        except Exception as err:
            print(f'GPT API error (attempt {attempt+1}): {err}')
            if attempt >= final_attempt:
                # Out of attempts: surface the error to the caller.
                raise
            # Linear back-off: wait longer after each failed attempt.
            time.sleep(API_RETRY_DELAY * (attempt + 1))

# Test both APIs
# Smoke test: one live request per provider with a tiny token budget, so a bad
# key or model name fails here rather than in the middle of the main loop.
print('Testing APIs...')
test_claude = call_claude("You output ONLY JSON.", 'Respond with exactly: {"test": "ok"}', max_tokens=50)
print(f'Claude test: {test_claude}')

test_gpt = call_gpt4o("You output ONLY JSON.", 'Respond with exactly: {"test": "ok"}', max_tokens=50)
print(f'GPT-4o test: {test_gpt}')

## 7. Direct Solving Prompts

In [None]:
# System prompt shared by both models in the direct condition. It asks for
# step-by-step reasoning followed by a {"final": <number>} line, which is the
# first format parse_model_answer looks for.
DIRECT_SYSTEM_PROMPT = """You are a math problem solver. Solve the problem step by step, then provide the final answer.

After your reasoning, you MUST output the final answer in this exact format:
{"final": <number>}

Replace <number> with the integer answer.
"""

def create_direct_prompt(problem: GSM8KProblem) -> Tuple[str, str]:
    """Build the (system, user) prompt pair for the direct condition.

    Args:
        problem: Problem whose ``question`` text is embedded in the user prompt.

    Returns:
        A tuple of DIRECT_SYSTEM_PROMPT and the user prompt. The user prompt
        has three paragraphs: the instruction, the problem statement, and a
        reminder of the required answer format.
    """
    paragraphs = [
        "Solve this math problem. Show your work step by step.",
        f"Problem: {problem.question}",
        'After solving, output the final answer as: {"final": <number>}',
    ]
    user_prompt = "\n\n".join(paragraphs)
    return DIRECT_SYSTEM_PROMPT, user_prompt

def parse_model_answer(response: str) -> Optional[int]:
    """Extract the model's final integer answer from its raw text output.

    Strategies are tried in order and the first that finds a number wins:

    1. A ``"final": <number>`` pair. The key may be single or double quoted,
       the braces are optional and the value may itself be quoted. When
       several occur, the LAST one is used, because the prompt asks for the
       answer after the reasoning.
    2. A GSM8K-style ``#### <integer>`` marker (first occurrence).
    3. The last standalone number in the text: preceded by whitespace or the
       start of the text, followed by whitespace, ``.``, ``,`` or the end.

    Numbers may use comma thousands separators (``1,250``). Decimal values
    are rounded to the nearest integer with ``round``.

    Args:
        response: Raw model output; None is accepted.

    Returns:
        The parsed integer, or None when *response* is None or contains no
        number in any recognised position.
    """
    if response is None:
        return None

    # Integer part: comma-grouped thousands ("1,250") or plain digits. The
    # (?!\d) guard stops "1,0000" from being read as the grouped "1,000".
    integer = r'-?(?:\d{1,3}(?:,\d{3})+(?!\d)|\d+)'
    number = integer + r'(?:\.\d+)?'

    def as_int(token: str) -> int:
        cleaned = token.replace(',', '')
        if '.' in cleaned:
            return int(round(float(cleaned)))
        # Exact conversion; avoids a lossy float round-trip for large values.
        return int(cleaned)

    keyed = re.findall(r"[\"']final[\"']\s*:\s*[\"']?(" + number + r")", response)
    if keyed:
        return as_int(keyed[-1])

    # Try #### format (GSM8K style); integer only.
    marked = re.search(r'####\s*(' + integer + r')', response)
    if marked:
        return as_int(marked.group(1))

    # Last number in response. Lookarounds keep the delimiters unconsumed, so
    # a number that directly follows another one ("3 4") is still found.
    standalone = re.findall(r'(?<!\S)(' + number + r')(?=\s|$|\.|,)', response)
    if standalone:
        return as_int(standalone[-1])
    return None

## 8. Run Experiment

In [None]:
def run_direct_experiment(problem: GSM8KProblem, model: str) -> DirectResult:
    """Solve one problem with one model in the direct (no-trace) condition.

    Args:
        problem: Problem to solve.
        model: Model identifier; must equal CLAUDE_MODEL or GPT_MODEL.

    Returns:
        A DirectResult holding the parsed answer, its correctness against
        ``problem.final_answer`` and the unmodified model output.

    Raises:
        ValueError: If *model* is neither CLAUDE_MODEL nor GPT_MODEL.
        Exception: Whatever the API helper raises after exhausting retries.
    """
    # Dispatch explicitly: an unrecognised identifier must not be silently
    # sent to GPT-4o and then recorded under the wrong model name.
    if model == CLAUDE_MODEL:
        call_model = call_claude
    elif model == GPT_MODEL:
        call_model = call_gpt4o
    else:
        raise ValueError(f'Unknown model identifier: {model!r}')

    sys_prompt, usr_prompt = create_direct_prompt(problem)
    response = call_model(sys_prompt, usr_prompt, max_tokens=API_MAX_TOKENS_ANSWER)

    model_answer = parse_model_answer(response)
    # An unparseable output counts as incorrect.
    is_correct = model_answer is not None and model_answer == problem.final_answer

    return DirectResult(
        problem_index=problem.index,
        model=model,
        model_answer=model_answer,
        correct_answer=problem.final_answer,
        is_correct=is_correct,
        raw_output=response,
        timestamp=datetime.now().isoformat(),
    )

In [None]:
# Main experiment loop: every problem is solved once by each model.
print('='*70)
print('A2: DIRECT BASELINE EXPERIMENT')
print('='*70)
print(f'Models: Claude Sonnet 4, GPT-4o')
print(f'Condition: Direct (no trace)')
print(f'Expected inferences: {len(problems) * 2}')
print('='*70)

# Results are appended in problem order, so position i in both lists refers
# to problems[i]; the agreement analysis relies on that alignment.
# NOTE(review): results live only in memory until the save cell runs. If an
# API call fails after its retries, the exception aborts this loop and can
# leave results_claude one entry longer than results_gpt; re-running the cell
# starts from scratch.
results_claude = []
results_gpt = []

for prob in tqdm(problems, desc='Direct Baseline (Claude + GPT-4o)'):
    # Claude
    result_claude = run_direct_experiment(prob, CLAUDE_MODEL)
    results_claude.append(result_claude)

    # GPT-4o
    result_gpt = run_direct_experiment(prob, GPT_MODEL)
    results_gpt.append(result_gpt)

print(f'\nCompleted: {len(results_claude) + len(results_gpt)} experiments')

## 9. Save Results

In [None]:
# Persist the per-problem results of each model as a JSON list of dicts
# (one dict per DirectResult, including the raw model output).
save_json([asdict(r) for r in results_claude], f'{SAVE_DIR}/results/A2_Direct_claude_results.json')
print(f'Claude results saved: {SAVE_DIR}/results/A2_Direct_claude_results.json')

save_json([asdict(r) for r in results_gpt], f'{SAVE_DIR}/results/A2_Direct_gpt4o_results.json')
print(f'GPT-4o results saved: {SAVE_DIR}/results/A2_Direct_gpt4o_results.json')

## 10. Analysis

In [None]:
# One DataFrame per model, one row per problem, columns = DirectResult fields.
df_claude = pd.DataFrame([asdict(r) for r in results_claude])
df_gpt = pd.DataFrame([asdict(r) for r in results_gpt])

# Accuracy as a fraction in [0, 1]: mean of the boolean is_correct column.
# Unparseable answers have is_correct == False and thus count as wrong.
claude_acc = df_claude['is_correct'].mean()
gpt_acc = df_gpt['is_correct'].mean()

print('='*70)
print('A2 DIRECT BASELINE RESULTS')
print('='*70)
print(f'Claude Sonnet 4: {claude_acc:.1%} ({df_claude["is_correct"].sum()}/{len(df_claude)})')
print(f'GPT-4o:          {gpt_acc:.1%} ({df_gpt["is_correct"].sum()}/{len(df_gpt)})')
# Signed difference in percentage points (positive = Claude higher).
print(f'Difference:      {(claude_acc - gpt_acc)*100:+.1f} pp')
print('='*70)

In [None]:
# Wilson confidence intervals
def wilson_ci(successes, n, confidence=0.95):
    """Return the Wilson score interval for a binomial proportion, in percent.

    Args:
        successes: Number of successes, with ``0 <= successes <= n``.
        n: Number of trials; must be positive.
        confidence: Two-sided confidence level, strictly between 0 and 1.

    Returns:
        ``(lower, upper)`` as plain floats in percent, clamped to
        ``[0, 100]`` so floating-point error cannot push a bound outside
        the valid range.

    Raises:
        ValueError: If ``n`` is not positive, ``successes`` lies outside
            ``[0, n]``, or ``confidence`` is not strictly between 0 and 1.
    """
    from scipy.stats import norm

    if n <= 0:
        raise ValueError(f'n must be positive, got {n}')
    if not 0 <= successes <= n:
        raise ValueError(f'successes must be in [0, {n}], got {successes}')
    if not 0 < confidence < 1:
        raise ValueError(f'confidence must be in (0, 1), got {confidence}')

    # Two-sided critical value of the standard normal distribution.
    z = norm.ppf(1 - (1 - confidence) / 2)
    p = successes / n
    denominator = 1 + z**2 / n
    center = (p + z**2 / (2 * n)) / denominator
    margin = z * np.sqrt((p * (1 - p) + z**2 / (4 * n)) / n) / denominator
    lower = max(0.0, float(center - margin) * 100)
    upper = min(100.0, float(center + margin) * 100)
    return lower, upper

# Number of correct answers per model.
claude_correct = df_claude['is_correct'].sum()
gpt_correct = df_gpt['is_correct'].sum()
# N is the number of loaded problems, not the number of result rows. The two
# agree only when every problem produced a result for both models. N is also
# reused by the agreement and summary cells.
N = len(problems)

# wilson_ci returns bounds in percent (it multiplies by 100), hence the
# accuracies are scaled by 100 in the report below.
claude_ci = wilson_ci(claude_correct, N)
gpt_ci = wilson_ci(gpt_correct, N)

print('\n95% Confidence Intervals:')
print(f'Claude: {claude_acc*100:.1f}% [{claude_ci[0]:.1f}%, {claude_ci[1]:.1f}%]')
print(f'GPT-4o: {gpt_acc*100:.1f}% [{gpt_ci[0]:.1f}%, {gpt_ci[1]:.1f}%]')

In [None]:
# Problem-level comparison
both_correct = sum(1 for i in range(N) if results_claude[i].is_correct and results_gpt[i].is_correct)
claude_only = sum(1 for i in range(N) if results_claude[i].is_correct and not results_gpt[i].is_correct)
gpt_only = sum(1 for i in range(N) if not results_claude[i].is_correct and results_gpt[i].is_correct)
both_wrong = sum(1 for i in range(N) if not results_claude[i].is_correct and not results_gpt[i].is_correct)

print('\nProblem-level Agreement:')
print('                  GPT-4o')
print('                  Correct  Wrong')
print(f'Claude  Correct    {both_correct:3d}     {claude_only:3d}')
print(f'        Wrong      {gpt_only:3d}     {both_wrong:3d}')
print(f'\nAgreement rate: {(both_correct + both_wrong) / N * 100:.1f}%')

In [None]:
# Side-by-side report: direct accuracy from this run next to the
# trace-provided conditions. The trace-provided figures are hard-coded text
# quoted from the A2 main experiments, not computed in this notebook; the
# GPT-4o entries are placeholders.
print(
    '\n' + '=' * 70,
    'COMPARISON: DIRECT vs TRACE-PROVIDED',
    '=' * 70,
    '\nClaude Sonnet 4:',
    f'  Direct (this exp):       {claude_acc*100:.1f}%',
    "  E4' (cue protected):    92.0%  (from A2 main experiments)",
    "  E2-Late (cue corrupted): 60.3%  (from A2 main experiments)",
    '\nGPT-4o:',
    f'  Direct (this exp):       {gpt_acc*100:.1f}%',
    "  E4' (cue protected):    [RUN E4' NOTEBOOK]",
    "  E2-Late (cue corrupted): [RUN E4' NOTEBOOK]",
    '=' * 70,
    sep='\n',
)

# Reading guide for the three possible outcomes.
print(
    '\nINTERPRETATION:',
    "  If Direct ≈ E4' (cue protected): Model may be solving independently",
    "  If Direct > E2-Late: Corrupted traces are actively harmful",
    "  If E4' > Direct: Clean traces provide useful information",
    sep='\n',
)

## 11. Summary

In [None]:
# Machine-readable summary of the run.
# Units differ between fields: 'accuracy' and 'agreement_rate' are fractions
# in [0, 1], whereas 'ci_95' holds percentages (wilson_ci multiplies by 100).
summary = {
    'experiment': 'A2_Direct_baseline',
    'date': EXPERIMENT_DATE,
    'n_problems': N,
    'models': [CLAUDE_MODEL, GPT_MODEL],
    'results': {
        'claude': {
            'model': CLAUDE_MODEL,
            'accuracy': claude_acc,
            # int(): presumably needed because the pandas sum is a NumPy
            # integer, which json cannot serialize.
            'correct': int(claude_correct),
            'ci_95': claude_ci
        },
        'gpt4o': {
            'model': GPT_MODEL,
            'accuracy': gpt_acc,
            'correct': int(gpt_correct),
            'ci_95': gpt_ci
        }
    },
    # Cells of the 2x2 per-problem correctness table.
    'agreement': {
        'both_correct': both_correct,
        'claude_only': claude_only,
        'gpt_only': gpt_only,
        'both_wrong': both_wrong,
        'agreement_rate': (both_correct + both_wrong) / N
    }
}

save_json(summary, f'{SAVE_DIR}/results/A2_Direct_summary.json')

# Closing banner with the headline numbers.
print('='*70)
print('A2 DIRECT BASELINE EXPERIMENT COMPLETE')
print('='*70)
print(f'Date: {EXPERIMENT_DATE}')
print(f'Total experiments: {len(results_claude) + len(results_gpt)}')
print(f'\nResults:')
print(f'  Claude Sonnet 4: {claude_acc:.1%}')
print(f'  GPT-4o:          {gpt_acc:.1%}')
print(f'\nFiles saved to: {SAVE_DIR}')
print('='*70)