# A2 GPT-4o E6: Cross-Model Validation (FIXED VERSION)

**Paper**: A2 (Cue-Dominant Extraction Explains Length Effects)

**CRITICAL FIX**: Replaces the entire final step with a clean cue sentence (same approach as E6 v2).

**Purpose**: Cross-model validation of E6 using GPT-4o.

**Conditions** (A and B only for cost efficiency):
| Condition | Step 1-9 | Step 10 (Cue) | Prediction |
|-----------|----------|---------------|------------|
| **A: Wrong Cue** | Clean | Wrong answer | WRONG (if cue-dominant) |
| **B: Correct Cue** | Corrupted (7 of 9 steps, rate 0.8) | Correct answer | CORRECT (if cue-dominant) |

**Expected inference count**: 199 problems × 2 conditions = 398

**Date**: 2026-01-03
**VERSION**: 2.0 (FIXED)

## 0. Google Drive Connection

In [None]:
from google.colab import drive
drive.mount('/content/drive')

import os
from datetime import datetime

# Run identity: the date suffix keeps runs on different days in separate folders.
EXPERIMENT_NAME = 'A2_GPT4o_E6_v2'
EXPERIMENT_DATE = datetime.now().strftime('%Y%m%d')

# Shared Drive root and the frozen v3 inputs (problems + clean traces).
BASE_DIR = '/content/drive/MyDrive/CoT_Experiment'
V3_DATA_DIR = f'{BASE_DIR}/full_experiment_v3_20251224'

# Output folder for this run and its results/ subfolder (created if missing).
SAVE_DIR = f'{BASE_DIR}/{EXPERIMENT_NAME}_{EXPERIMENT_DATE}'
for _out_dir in (SAVE_DIR, f'{SAVE_DIR}/results'):
    os.makedirs(_out_dir, exist_ok=True)

print(f'Experiment: {EXPERIMENT_NAME}')
print(f'Save directory: {SAVE_DIR}')

## 1. Install Dependencies

In [None]:
# Install runtime dependencies quietly (-q). The leading `!` is IPython/Colab
# shell syntax, so this cell only runs inside a notebook.
# NOTE(review): datasets, matplotlib and scipy are not used in the cells shown
# here — confirm they are still needed.
!pip install datasets openai matplotlib pandas tqdm scipy -q
print('Dependencies installed.')

## 2. Configuration

In [None]:
import hashlib
import random
import json
import re
import time
from typing import List, Dict, Tuple, Optional, Any
from dataclasses import dataclass, asdict
from datetime import datetime
from tqdm import tqdm
import pandas as pd
import numpy as np

# Seeds. E6_SEED feeds derive_seed() in the run loop. GLOBAL_SEED is not
# referenced in the cells shown here (presumably kept for parity with earlier
# experiments — TODO confirm).
GLOBAL_SEED = 20251224
E6_SEED = 20260103
# Fraction of the non-final steps (1..L-1) that condition B corrupts.
CORRUPTION_RATE = 0.8

# OpenAI call settings: completion-token cap for the answer, base back-off in
# seconds between retries (multiplied by the attempt number), and a fixed pause
# in seconds after every successful call.
API_MAX_TOKENS_ANSWER = 256
API_RETRY_DELAY = 1.0
API_RATE_LIMIT_DELAY = 0.5

print('='*70)
print('GPT-4o E6: CROSS-MODEL VALIDATION (FIXED)')
print('='*70)

## 3. Data Structures

In [None]:
@dataclass
class GSM8KProblem:
    """One problem record from problems_v3.json (built as GSM8KProblem(**record))."""
    index: int  # problem id; key of prob_map, matched against CleanTrace.problem_index
    question: str  # problem statement inserted into the user prompt
    answer_text: str  # reference solution text (not read by the cells shown here)
    final_answer: int  # gold numeric answer, used as correct_answer

@dataclass
class CleanTrace:
    """An unmanipulated reasoning trace from clean_traces_I10_v3.json."""
    problem_index: int  # GSM8KProblem.index this trace belongs to
    I: int  # trace length; used as L by the condition builders (presumably == len(steps) — TODO confirm)
    steps: List[str]  # step texts; the builders prepend "Step k: " themselves
    full_text: str  # pre-rendered trace (not read by the cells shown here)

@dataclass
class ManipulatedTrace:
    """A trace rebuilt for one E6 condition, plus the metadata needed to score it."""
    problem_index: int  # GSM8KProblem.index of the source problem
    L: int  # trace length, copied from CleanTrace.I
    condition: str  # 'wrong_cue' (A) or 'correct_cue_only' (B) in this notebook
    reasoning_status: str  # 'clean' or 'corrupted' (describes steps 1..L-1)
    cue_status: str  # 'wrong' or 'correct' (describes the final-step value)
    cue_answer: int  # number stated in the replaced final step
    correct_answer: int  # gold answer (GSM8KProblem.final_answer)
    steps: List[str]  # manipulated step texts, final cue step included
    full_text: str  # [[COT_START]] ... [[COT_END]] block shown to the model
    seed: int  # per-problem seed the manipulation was derived from

@dataclass
class ExperimentResult:
    """Outcome of one model query on one manipulated trace."""
    problem_index: int
    model: str  # model name sent to the API (MODEL)
    condition: str  # copied from ManipulatedTrace.condition
    cue_answer: int
    correct_answer: int
    model_answer: Optional[int]  # parsed answer; None when nothing could be parsed
    is_correct: bool  # model_answer == correct_answer (False if unparsed)
    followed_cue: bool  # model_answer == cue_answer (False if unparsed)
    raw_output: str  # unparsed model reply
    timestamp: str  # local time the result was built, ISO 8601 string

## 4. Utility Functions

In [None]:
def derive_seed(global_seed: int, problem_id: int, condition: str) -> int:
    """Derive a reproducible 32-bit seed for one (problem, condition) pair.

    The key ``"<global_seed>|E6|<problem_id>|<condition>"`` is hashed with
    SHA-256, so the seed is stable across processes and machines (unlike the
    built-in ``hash``). The first 8 hex digits (32 bits) become the seed.
    """
    key_text = "|".join((str(global_seed), "E6", str(problem_id), str(condition)))
    digest = hashlib.sha256(key_text.encode("utf-8")).hexdigest()
    return int(digest[:8], base=16)

def save_json(data: Any, filepath: str):
    """Write ``data`` to ``filepath`` as UTF-8 JSON with 2-space indentation.

    ``ensure_ascii=False`` keeps non-ASCII text human-readable in the file.
    Any existing file at ``filepath`` is overwritten.
    """
    with open(filepath, mode='w', encoding='utf-8') as out_file:
        json.dump(data, out_file, indent=2, ensure_ascii=False)

def load_json(filepath: str) -> Any:
    """Parse and return the UTF-8 JSON document stored at ``filepath``."""
    with open(filepath, mode='r', encoding='utf-8') as src:
        payload = json.load(src)
    return payload

## 5. Load Data

In [None]:
# Problems: load the records, build GSM8KProblem objects, index them by id.
problems_path = f'{V3_DATA_DIR}/problems_v3.json'
problems_data = load_json(problems_path)
problems = [GSM8KProblem(**record) for record in problems_data]
prob_map = {problem.index: problem for problem in problems}
print(f'Loaded {len(problems)} problems')

# Clean (unmanipulated) traces; the file name suggests trace length I = 10.
traces_path = f'{V3_DATA_DIR}/clean_traces/clean_traces_I10_v3.json'
traces_data = load_json(traces_path)
clean_traces = [CleanTrace(**record) for record in traces_data]
print(f'Loaded {len(clean_traces)} traces')

## 6. Corruption Logic

In [None]:
def pick_corrupted_steps(L: int, c: float, seed: int) -> List[int]:
    """Choose which non-final steps of an L-step trace to corrupt.

    ``round(c * (L - 1))`` steps are drawn from the candidates 1..L-1 (the
    final step L is never a candidate) by shuffling them with a private
    ``random.Random(seed)``. The chosen steps are returned in ascending order.
    Note that Python's ``round`` uses round-half-to-even.
    """
    n_pick = int(round(c * (L - 1)))
    if not n_pick:
        return []
    candidates = list(range(1, L))
    random.Random(seed).shuffle(candidates)
    chosen = candidates[:n_pick]
    chosen.sort()
    return chosen

def assign_corruption_types(corrupted_steps: List[int], seed: int) -> Dict[int, str]:
    """Label each corrupted step as "IRR", "LOC" or "WRONG".

    For K = len(corrupted_steps) the split is:

    - floor(K/5) irrelevant ("IRR") steps,
    - floor(2K/5) local-error ("LOC") steps,
    - the remaining steps as wrong-constraint ("WRONG").

    Labels are assigned in that order to a copy of the input shuffled with
    ``random.Random(seed + 1)``. That generator is distinct from the
    ``Random(seed)`` that pick_corrupted_steps uses for the same trace.

    Args:
        corrupted_steps: step numbers to label; the list is not modified.
        seed: per-problem base seed.

    Returns:
        Mapping step number -> "IRR" | "LOC" | "WRONG". Empty when no steps
        are given.
    """
    K = len(corrupted_steps)
    if K == 0:
        return {}
    n_irr = K // 5
    n_loc = (2 * K) // 5
    # K - floor(K/5) - floor(2K/5) >= K - 3K/5 = 2K/5 > 0, so every K >= 1 already
    # gets at least one WRONG step; no rebalancing is needed.
    n_wrong = K - n_irr - n_loc
    rng = random.Random(seed + 1)
    perm = list(corrupted_steps)
    rng.shuffle(perm)
    labels = ["IRR"] * n_irr + ["LOC"] * n_loc + ["WRONG"] * n_wrong
    return dict(zip(perm, labels))

# Filler steps for "IRR" corruption: a correct side computation that is never
# used. generate_irrelevant_step() chooses the result formula by checking for '+'
# in the selected template, so keep that convention when adding templates.
IRRELEVANT_TEMPLATES = [
    "Compute an auxiliary value: aux = {a} + {b} = {result}, but it will not be used later.",
    "Compute a side quantity: aux = {a} * 2 = {result}, unrelated to the final result.",
]

# Steps for "WRONG" corruption: state an arbitrary value as if it were a given.
WRONG_CONSTRAINT_TEMPLATES = [
    "Fix an intermediate condition: set {var} = {wrong_value} as a given constraint.",
    "Assume the total is {var} = {wrong_value} and proceed using this fixed value.",
]

def generate_irrelevant_step(step_num: int, seed: int) -> str:
    """Return an arithmetically correct step whose value is never used.

    ``step_num`` is unused; it is accepted so all step generators share one
    signature. Values are drawn from ``random.Random(seed)`` in the order
    a, b, then template.
    """
    rng = random.Random(seed)
    a = rng.randint(2, 20)
    b = rng.randint(2, 20)
    template = rng.choice(IRRELEVANT_TEMPLATES)
    if '+' in template:
        result = a + b
    else:
        result = a * 2
    return template.format(a=a, b=b, result=result)

def generate_local_error_step(original_step: str, seed: int) -> str:
    """Return ``original_step`` with its last number shifted by a small error.

    The last integer in the step is shifted by one of -3..-1 or +1..+3 (clamped
    at 0) and written back. The rewrite is tried in this order:

    1. a trailing ``= N.``;
    2. any trailing ``N.``;
    3. otherwise, that last integer wherever it appears.

    A step with no digits is replaced by ``Compute t = 10 * 3 = <wrong>.``.

    The returned step always differs from ``original_step``, so a step labelled
    LOC is never left clean. For inputs where the trailing-number substitutions
    already produced a changed step, the output (and the RNG draws) are the same
    as before this guarantee was added. Only the previously degenerate cases get
    new text, which keeps traces reproducible.
    """
    rng = random.Random(seed)
    number_matches = list(re.finditer(r'\d+', original_step))
    if not number_matches:
        wrong_product = rng.randint(28, 32)
        # 10 * 3 really is 30; a "local error" must not state the correct product.
        if wrong_product == 30:
            wrong_product = 31
        return f"Compute t = 10 * 3 = {wrong_product}."
    last_number = number_matches[-1]
    original_result = int(last_number.group())
    offset = rng.choice([-3, -2, -1, 1, 2, 3])
    wrong_result = max(0, original_result + offset)
    if wrong_result == original_result:
        # Only happens when original_result == 0 and offset < 0 (clamped back to 0).
        wrong_result = original_result + abs(offset)
    modified = re.sub(r'= (\d+)\.$', f'= {wrong_result}.', original_step)
    if modified == original_step:
        modified = re.sub(r'(\d+)\.$', f'{wrong_result}.', original_step)
    if modified == original_step:
        # Step does not end in "N." (e.g. "= 20 cookies."): rewrite the last
        # integer in place instead of returning the step unchanged.
        modified = (
            original_step[:last_number.start()]
            + str(wrong_result)
            + original_step[last_number.end():]
        )
    return modified

def generate_wrong_constraint_step(step_num: int, seed: int) -> str:
    """Return a step that imposes an arbitrary value as a fixed constraint.

    ``step_num`` is unused (kept for a uniform generator signature). Values are
    drawn from ``random.Random(seed)`` in the order: variable name, value in
    [10, 100], template.
    """
    draw = random.Random(seed)
    chosen_var = draw.choice(['x', 'total', 'result', 'n'])
    chosen_value = draw.randint(10, 100)
    chosen_template = draw.choice(WRONG_CONSTRAINT_TEMPLATES)
    return chosen_template.format(var=chosen_var, wrong_value=chosen_value)

## 7. FIXED: Final Step Manipulation

In [None]:
def generate_wrong_answer(correct_answer: int, seed: int) -> int:
    """Return a plausible but incorrect answer near ``correct_answer``.

    The offset, drawn from ``random.Random(seed)``, scales with
    ``max(1, abs(correct_answer))``:

    - below 10: +/-2..5;
    - below 100: +/-10, 15, 20;
    - below 1000: -100, -50, +50, +100, +150;
    - otherwise: +/-10, 15 or 20 % of the answer, truncated toward zero.

    A positive correct answer never yields a non-positive wrong one: such values
    are sign-flipped plus one.

    The result is guaranteed to differ from ``correct_answer``. The sign flip
    alone could land back on the correct value (2 with offset -3, or 3 with
    offset -5). That would silently turn condition A into a *correct* cue, so
    those cases fall back to ``correct_answer + 10``. Every other
    (answer, seed) pair yields the same value as before this guard was moved.
    """
    rng = random.Random(seed)
    magnitude = max(1, abs(correct_answer))
    if magnitude < 10:
        offset = rng.choice([-5, -4, -3, -2, 2, 3, 4, 5])
    elif magnitude < 100:
        offset = rng.choice([-20, -15, -10, 10, 15, 20])
    elif magnitude < 1000:
        offset = rng.choice([-100, -50, 50, 100, 150])
    else:
        pct = rng.choice([0.1, 0.15, 0.2, -0.1, -0.15, -0.2])
        offset = int(correct_answer * pct)
    wrong_answer = correct_answer + offset
    if correct_answer > 0 and wrong_answer <= 0:
        # Keep positive answers positive (+1 so the result is never 0).
        wrong_answer = abs(wrong_answer) + 1
    if wrong_answer == correct_answer:
        # With the offset tables above this is only reachable through the sign
        # flip; check last so no adjustment can reintroduce the correct value.
        wrong_answer = correct_answer + 10
    return wrong_answer

def create_clean_final_step(answer: int) -> str:
    """Build the cue sentence that replaces a trace's entire final step.

    The whole final step is swapped for this sentence rather than just its
    number, so no wording from the original final step survives.
    """
    template = "Therefore, the final answer is Final = {}."
    return template.format(answer)

# Sanity check: show what the replacement final step looks like.
print("FIXED: Final step is now completely replaced")
print(f"Example: {create_clean_final_step(70000)}")

## 8. Trace Creation

In [None]:
def create_condition_a_trace(clean_trace: CleanTrace, correct_answer: int, seed: int) -> ManipulatedTrace:
    """Condition A: keep steps 1..L-1 clean and replace the final step with a WRONG cue.

    The wrong value comes from ``generate_wrong_answer(correct_answer, seed)``.
    The rendered text is one ``Step k: ...`` line per step, wrapped in
    ``[[COT_START]]`` / ``[[COT_END]]`` marker lines.
    """
    wrong_answer = generate_wrong_answer(correct_answer, seed)
    new_steps = [*clean_trace.steps[:-1], create_clean_final_step(wrong_answer)]
    numbered = [f'Step {n}: {text}' for n, text in enumerate(new_steps, start=1)]
    full_text = '\n'.join(['[[COT_START]]', *numbered, '[[COT_END]]'])
    return ManipulatedTrace(
        problem_index=clean_trace.problem_index,
        L=clean_trace.I,
        condition='wrong_cue',
        reasoning_status='clean',
        cue_status='wrong',
        cue_answer=wrong_answer,
        correct_answer=correct_answer,
        steps=new_steps,
        full_text=full_text,
        seed=seed,
    )

def create_condition_b_trace(clean_trace: CleanTrace, correct_answer: int, seed: int) -> ManipulatedTrace:
    """Condition B: corrupt most of steps 1..L-1 but keep a CORRECT final cue.

    ``round(CORRUPTION_RATE * (L - 1))`` steps are picked and labelled
    IRR / LOC / WRONG. Each corrupted step is regenerated with its own seed,
    ``seed + step_num * 1000``. The final step is replaced with the correct
    answer. The rendered text matches condition A's layout.

    NOTE(review): steps are numbered from ``clean_trace.steps[:-1]``, while the
    picks come from ``range(1, L)`` with ``L = clean_trace.I``; this presumably
    relies on ``len(steps) == I`` — confirm.
    """
    L = clean_trace.I
    corrupted_steps = pick_corrupted_steps(L, CORRUPTION_RATE, seed)
    corruption_types = assign_corruption_types(corrupted_steps, seed)

    def _rewrite(step_num: int, step_content: str) -> str:
        # Leave untouched steps as-is; dispatch corrupted ones by label.
        if step_num not in corrupted_steps:
            return step_content
        kind = corruption_types[step_num]
        step_seed = seed + step_num * 1000
        if kind == 'IRR':
            return generate_irrelevant_step(step_num, step_seed)
        if kind == 'LOC':
            return generate_local_error_step(step_content, step_seed)
        return generate_wrong_constraint_step(step_num, step_seed)

    new_steps = [
        _rewrite(num, content)
        for num, content in enumerate(clean_trace.steps[:-1], start=1)
    ]
    new_steps.append(create_clean_final_step(correct_answer))

    numbered = [f'Step {n}: {text}' for n, text in enumerate(new_steps, start=1)]
    full_text = '\n'.join(['[[COT_START]]', *numbered, '[[COT_END]]'])

    return ManipulatedTrace(
        problem_index=clean_trace.problem_index,
        L=L,
        condition='correct_cue_only',
        reasoning_status='corrupted',
        cue_status='correct',
        cue_answer=correct_answer,
        correct_answer=correct_answer,
        steps=new_steps,
        full_text=full_text,
        seed=seed,
    )

## 9. API Setup

In [None]:
from getpass import getpass

# Prompt for the key without echoing it, so it is never stored in the notebook.
OPENAI_API_KEY = getpass(prompt='Enter OpenAI API Key: ')
print('API Key set.')

In [None]:
from openai import OpenAI

# One shared client for the whole notebook; call_gpt4o() uses it.
client = OpenAI(api_key=OPENAI_API_KEY)
# Model under test; also recorded in every ExperimentResult.model.
MODEL = 'gpt-4o'

def call_gpt4o(system_prompt: str, user_prompt: str, max_tokens: int = 1024, retries: int = 3) -> str:
    """Send one system+user chat turn to MODEL at temperature 0 and return its text.

    Any exception is retried, for up to ``retries`` attempts in total (at least
    one request is always made). Back-off is linear:
    ``API_RETRY_DELAY * attempt_number`` seconds. The last error is re-raised.
    After a successful call the function sleeps ``API_RATE_LIMIT_DELAY``
    seconds as a simple client-side rate limit.

    Returns:
        The assistant message text, or ``''`` when the API returns no text
        content (``message.content`` is ``None``, e.g. for a refusal). Callers
        such as the regex-based answer parser therefore never receive ``None``.
    """
    attempts = max(1, retries)  # always make at least one request
    for attempt in range(attempts):
        try:
            response = client.chat.completions.create(
                model=MODEL,
                max_tokens=max_tokens,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt},
                ],
                temperature=0,
            )
            content = response.choices[0].message.content
        except Exception as e:
            print(f'API error (attempt {attempt+1}): {e}')
            if attempt < attempts - 1:
                time.sleep(API_RETRY_DELAY * (attempt + 1))
                continue
            raise
        time.sleep(API_RATE_LIMIT_DELAY)
        return content or ''

# Smoke test: one cheap request so auth/model errors surface before the main loop.
test_response = call_gpt4o("You output ONLY JSON.", 'Respond: {"test": "ok"}', max_tokens=50)
print(f'API test: {test_response}')

## 10. Experiment

In [None]:
# System prompt for every trial: asks for a bare {"final": <number>} reply with
# no explanation; parse_model_answer() extracts the number from the reply.
SYSTEM_PROMPT = """You are a calculator that outputs ONLY JSON.
Output exactly: {"final": <number>}
Do NOT write any explanation.
"""

def create_prompt(problem: GSM8KProblem, trace_text: str) -> str:
    """Build the user message: the problem, the manipulated trace, the output format.

    The three sections are separated by a single blank line.
    """
    sections = (
        f"Problem: {problem.question}",
        f"Reasoning trace:\n{trace_text}",
        'OUTPUT ONLY: {"final": <number>}',
    )
    return "\n\n".join(sections)

# A number as the model may write it: optional minus sign, optional thousands
# separators ("70,000"), optional decimal part.
_NUMBER_PATTERN = r'-?(?:\d{1,3}(?:,\d{3})+|\d+)(?:\.\d+)?'
_JSON_FINAL_RE = re.compile(r'\{\s*"final"\s*:\s*(' + _NUMBER_PATTERN + r')\s*\}')
_FINAL_KEY_RE = re.compile(r'"final"\s*:\s*(' + _NUMBER_PATTERN + r')')
_ANY_NUMBER_RE = re.compile(_NUMBER_PATTERN)


def _number_to_int(number_text: str) -> int:
    """Convert a matched number (thousands commas allowed) to the nearest int."""
    return int(round(float(number_text.replace(',', ''))))


def parse_model_answer(response: str) -> Optional[int]:
    """Extract the model's numeric answer from its reply.

    Strategies are tried in order:

    1. a ``{"final": N}`` object;
    2. any ``"final": N`` pair;
    3. the last number anywhere in the text.

    Numbers may carry thousands separators and decimals. Decimals are rounded
    with Python's ``round`` (half-to-even).

    Returns:
        The parsed integer, or ``None`` when ``response`` is empty/``None`` or
        contains no number. Previously ``None`` raised a ``TypeError``,
        ``"70,000"`` parsed as 70 and ``"12.7"`` (in the fallback) as 7.
    """
    if not response:
        return None
    for pattern in (_JSON_FINAL_RE, _FINAL_KEY_RE):
        match = pattern.search(response)
        if match:
            return _number_to_int(match.group(1))
    numbers = _ANY_NUMBER_RE.findall(response)
    if numbers:
        return _number_to_int(numbers[-1])
    return None

def run_single_experiment(problem: GSM8KProblem, trace: ManipulatedTrace) -> ExperimentResult:
    """Query the model once on (problem, trace) and score the parsed answer.

    A reply with no parseable number (``model_answer is None``) counts as
    neither correct nor cue-following.
    """
    response = call_gpt4o(
        SYSTEM_PROMPT,
        create_prompt(problem, trace.full_text),
        max_tokens=API_MAX_TOKENS_ANSWER,
    )
    model_answer = parse_model_answer(response)
    parsed = model_answer is not None
    return ExperimentResult(
        problem_index=problem.index,
        model=MODEL,
        condition=trace.condition,
        cue_answer=trace.cue_answer,
        correct_answer=trace.correct_answer,
        model_answer=model_answer,
        is_correct=parsed and model_answer == trace.correct_answer,
        followed_cue=parsed and model_answer == trace.cue_answer,
        raw_output=response,
        timestamp=datetime.now().isoformat(),
    )

In [None]:
print('='*70)
print('GPT-4o E6 (FIXED VERSION)')
print('='*70)

all_results = []
all_traces = []
# Reported after the loop instead of being dropped silently: skipped traces
# shrink the sample size, and a "wrong" cue equal to the gold answer would make
# condition A meaningless for that problem.
skipped_problem_indices = []
wrong_cue_collisions = []

for trace in tqdm(clean_traces, desc='Running GPT-4o E6'):
    problem = prob_map.get(trace.problem_index)
    if problem is None:
        skipped_problem_indices.append(trace.problem_index)
        continue

    # One seed per problem, shared by both conditions; each builder derives its
    # own random.Random instances from it.
    seed = derive_seed(E6_SEED, trace.problem_index, 'E6')
    correct_answer = problem.final_answer

    trace_a = create_condition_a_trace(trace, correct_answer, seed)
    trace_b = create_condition_b_trace(trace, correct_answer, seed)
    if trace_a.cue_answer == correct_answer:
        wrong_cue_collisions.append(trace.problem_index)

    result_a = run_single_experiment(problem, trace_a)
    result_b = run_single_experiment(problem, trace_b)

    all_results.extend([result_a, result_b])
    all_traces.append({
        'problem_index': trace.problem_index,
        'condition_a': asdict(trace_a),
        'condition_b': asdict(trace_b)
    })

print(f'\nTotal experiments: {len(all_results)}')
if skipped_problem_indices:
    print(f'WARNING: skipped {len(skipped_problem_indices)} trace(s) with no matching '
          f'problem: {skipped_problem_indices}')
if wrong_cue_collisions:
    print(f'WARNING: condition A cue equals the correct answer for problem(s): '
          f'{wrong_cue_collisions}')

## 11. Save & Analysis

In [None]:
# Persist raw per-call results and the exact traces shown to the model.
save_json(list(map(asdict, all_results)), f'{SAVE_DIR}/results/A2_GPT4o_E6_results.json')
save_json(all_traces, f'{SAVE_DIR}/results/A2_GPT4o_E6_traces.json')
print('Results saved')

df = pd.DataFrame(list(map(asdict, all_results)))

print('\n' + '='*70)
print('GPT-4o E6 RESULTS (FIXED)')
print('='*70)

# Per condition: mean of the boolean columns = fraction correct / cue-following.
for condition in ('wrong_cue', 'correct_cue_only'):
    cond_df = df.loc[df['condition'] == condition]
    acc, cue_follow = cond_df['is_correct'].mean(), cond_df['followed_cue'].mean()
    print(f'\n{condition.upper()}:')
    print(f'  Accuracy: {acc*100:.2f}%')
    print(f'  Cue Following: {cue_follow*100:.2f}%')

In [None]:
# Split once by condition instead of re-filtering df for every statistic.
cond_a_df = df[df['condition'] == 'wrong_cue']
cond_b_df = df[df['condition'] == 'correct_cue_only']

# float() turns numpy scalars into plain Python floats for the JSON summary.
acc_a = float(cond_a_df['is_correct'].mean())
acc_b = float(cond_b_df['is_correct'].mean())
cue_follow_a = float(cond_a_df['followed_cue'].mean())
# In condition B the cue *is* the correct answer, so this equals acc_b.
cue_follow_b = float(cond_b_df['followed_cue'].mean())

summary = {
    'experiment': 'A2_GPT4o_E6_v2_FIXED',
    'model': MODEL,
    'date': EXPERIMENT_DATE,
    # Problems actually evaluated. The run loop skips traces without a matching
    # problem, so this can be smaller than the number of traces loaded.
    'n_problems': int(df['problem_index'].nunique()),
    'n_traces_loaded': len(clean_traces),
    'total_inferences': len(all_results),
    'results': {
        'condition_a': {
            'description': 'Wrong Cue + Clean Reasoning',
            'accuracy': acc_a,
            'followed_cue': cue_follow_a,
            'n_correct': int(cond_a_df['is_correct'].sum()),
            'n': len(cond_a_df),
        },
        'condition_b': {
            'description': 'Correct Cue + Corrupted Reasoning',
            'accuracy': acc_b,
            'followed_cue': cue_follow_b,
            'n_correct': int(cond_b_df['is_correct'].sum()),
            'n': len(cond_b_df),
        },
    },
    # Heuristic thresholds: >50 % cue-following in A, >70 % accuracy in B.
    'interpretation': {
        'cue_dominant': bool(cue_follow_a > 0.5),
        'cue_rescues': bool(acc_b > 0.7),
    },
}

save_json(summary, f'{SAVE_DIR}/results/A2_GPT4o_E6_summary.json')

print('\n' + '='*70)
print('COMPLETE')
print('='*70)
print(f'Condition A (Wrong Cue): {acc_a*100:.1f}% acc, {cue_follow_a*100:.1f}% cue follow')
print(f'Condition B (Correct Cue): {acc_b*100:.1f}% acc')