# E8': GPT-4o Cross-Model Validation (FIXED VERSION)

**Paper**: A2 (Cue-Dominant Extraction Explains Length Effects)

**Purpose**: Cross-model validation of E8 findings using GPT-4o.

**FIXES**: Uses complete final step replacement (same as E8 v2).

**Cost-optimized design**:
- L ∈ {10, 20} only (2 lengths vs 4)
- c = 0.8 only (single corruption level)
- cue ∈ {present, absent}

**Expected inferences**: ~188 problems × 2 lengths × 2 cue conditions ≈ 752

**Date**: 2026-01-03
**VERSION**: 2.0 (FIXED)

## 0. Google Drive Connection

In [None]:
# Mount Google Drive; every data path below lives under /content/drive.
from google.colab import drive
drive.mount('/content/drive')

import os
from datetime import datetime

# Run identity. EXPERIMENT_DATE is taken from the clock at execution time (YYYYMMDD),
# so the save directory name depends on the day this cell is run.
EXPERIMENT_NAME = 'E8prime_GPT4o_v2'
EXPERIMENT_DATE = datetime.now().strftime('%Y%m%d')

# Input locations: the problem file and the clean traces are read from the v3 data directory.
BASE_DIR = '/content/drive/MyDrive/CoT_Experiment'
V3_DATA_DIR = f'{BASE_DIR}/full_experiment_v3_20251224'
TRACES_DIR = f'{V3_DATA_DIR}/clean_traces'

# Output location, created up front; exist_ok=True makes re-running this cell harmless.
SAVE_DIR = f'{BASE_DIR}/{EXPERIMENT_NAME}_{EXPERIMENT_DATE}'
os.makedirs(SAVE_DIR, exist_ok=True)
os.makedirs(f'{SAVE_DIR}/results', exist_ok=True)

print(f'Experiment: {EXPERIMENT_NAME}')
print(f'Save directory: {SAVE_DIR}')

## 1. Install Dependencies

In [None]:
!pip install datasets openai matplotlib pandas tqdm scipy -q
print('Dependencies installed.')

## 2. Configuration

In [None]:
import hashlib
import random
import json
import re
import time
from typing import List, Dict, Tuple, Optional, Any, Set
from dataclasses import dataclass, asdict
from datetime import datetime
from tqdm import tqdm
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# =============================================================================
# Configuration - Cost-optimized for GPT-4o
# =============================================================================
# Seeds. E8_SEED is the value passed to derive_seed() in the experiment loop.
# NOTE(review): GLOBAL_SEED is not referenced anywhere else in this notebook —
# presumably kept for parity with the v3 pipeline; confirm before removing.
GLOBAL_SEED = 20251224
E8_SEED = 20260103

# Experimental grid: one condition per (length, cue) pair at a single corruption level.
LENGTHS = [10, 20]  # Reduced from [5, 10, 15, 20]
CORRUPTION_LEVEL = 0.8  # Single level
CUE_CONDITIONS = ['present', 'absent']

# API settings used by call_gpt4o() / run_single_experiment().
API_MAX_TOKENS_ANSWER = 256   # max_tokens for each answer request
API_RETRY_DELAY = 1.0         # seconds; multiplied by the attempt number between retries
API_RATE_LIMIT_DELAY = 0.5    # seconds slept after every successful request

print('='*70)
print('E8\': GPT-4o CROSS-MODEL VALIDATION (FIXED)')
print('='*70)
print(f'  Lengths: {LENGTHS}')
print(f'  Corruption: {CORRUPTION_LEVEL}')
print(f'  Cue: {CUE_CONDITIONS}')
print('='*70)

## 3. Data Structures

In [None]:
@dataclass
class GSM8KProblem:
    """One problem record, constructed from an entry of problems_v3.json."""
    index: int  # problem identifier; matched against CleanTrace.problem_index
    question: str  # problem statement inserted into the model prompt
    answer_text: str  # presumably the reference solution text — not read in this notebook; confirm
    final_answer: int  # gold answer that the parsed model answer is compared against

@dataclass
class CleanTrace:
    """An uncorrupted reasoning trace, constructed from an entry of clean_traces_I{L}_v3.json."""
    problem_index: int  # index of the problem this trace belongs to
    I: int  # trace length; create_corrupted_trace() uses it as L (presumably == len(steps) — confirm)
    steps: List[str]  # step contents in order; the last one is replaced when a corrupted trace is built
    full_text: str  # presumably the rendered trace text — not read in this notebook; confirm

@dataclass
class CorruptedTrace:
    """A trace after step corruption and final-step replacement (see create_corrupted_trace)."""
    problem_index: int  # index of the problem this trace belongs to
    L: int  # trace length, copied from CleanTrace.I
    c: float  # corruption level used to choose how many steps to corrupt
    cue_condition: str  # 'present' -> final step states the answer; anything else -> neutral final step
    K_corrupt: int  # number of corrupted steps
    corrupted_steps: List[int]  # sorted 1-based step numbers that were corrupted
    corruption_types: Dict[int, str]  # step number -> "IRR" | "LOC" | "WRONG"
    steps: List[str]  # step contents after corruption, including the replaced final step
    full_text: str  # "Step i: ..." lines wrapped in [[COT_START]] / [[COT_END]]
    seed: int  # seed the corruption was derived from

@dataclass
class ExperimentResult:
    """Outcome of one model call on one corrupted trace (see run_single_experiment)."""
    problem_index: int  # index of the problem that was asked
    model: str  # model name used for the request
    L: int  # trace length of the condition
    c: float  # corruption level of the condition
    cue_condition: str  # cue condition of the trace
    K_corrupt: int  # number of corrupted steps in the trace
    K_clean: int  # L - K_corrupt (so the replaced final step is counted here)
    model_answer: Optional[int]  # parsed answer, or None when no number could be extracted
    correct_answer: int  # gold answer of the problem
    is_correct: bool  # True only when model_answer was parsed and equals correct_answer
    raw_output: str  # unparsed model response text
    timestamp: str  # datetime.now().isoformat() at the time the result was built

## 4. Utility Functions

In [None]:
def derive_seed(global_seed: int, problem_id: int, L: int, c: float, cue: str) -> int:
    """Derive a deterministic 32-bit seed for one (problem, L, c, cue) condition.

    The fields are joined into a '|'-separated key tagged with "E8", hashed
    with SHA-256, and the first four digest bytes are read as an unsigned
    big-endian integer (equivalent to the first 8 hex digits of the digest).
    """
    fields = (
        f"{global_seed}",
        "E8",
        f"{problem_id}",
        f"L={L}",
        f"c={c}",
        f"cue={cue}",
    )
    digest = hashlib.sha256("|".join(fields).encode("utf-8")).digest()
    return int.from_bytes(digest[:4], "big")

def save_json(data: Any, filepath: str):
    """Serialize *data* as UTF-8 JSON to *filepath*, overwriting any existing file.

    Output is indented by two spaces and non-ASCII characters are written
    as-is rather than escaped.
    """
    with open(filepath, mode='w', encoding='utf-8') as handle:
        json.dump(obj=data, fp=handle, indent=2, ensure_ascii=False)

def load_json(filepath: str) -> Any:
    """Read the UTF-8 file at *filepath* and return its parsed JSON content."""
    with open(filepath, mode='r', encoding='utf-8') as handle:
        raw_text = handle.read()
    return json.loads(raw_text)

## 5. Load Data

In [None]:
# Load problems
problems_path = f'{V3_DATA_DIR}/problems_v3.json'
problems_data = load_json(problems_path)
problems = [GSM8KProblem(**p) for p in problems_data]
prob_map = {p.index: p for p in problems}
print(f'Loaded {len(problems)} problems')

# Load the clean traces that this experiment actually runs on (one map per length,
# keyed by problem index). A missing file here is an error and raises.
trace_maps = {}
for L in LENGTHS:
    trace_path = f'{TRACES_DIR}/clean_traces_I{L}_v3.json'
    traces_data = load_json(trace_path)
    traces = [CleanTrace(**t) for t in traces_data]
    trace_maps[L] = {t.problem_index: t for t in traces}
    print(f'Loaded {len(traces)} traces for L={L}')

# Restrict to problems that have a trace at EVERY length of the full design
# (L=5,10,15,20) so the problem set stays consistent with the 4-length experiment.
# Lengths already loaded above are reused instead of being read from disk again;
# L=5 and L=15 are loaded only to compute the common indices.
all_lengths = [5, 10, 15, 20]
all_trace_maps = dict(trace_maps)
missing_lengths = []
for L in all_lengths:
    if L in all_trace_maps:
        continue
    trace_path = f'{TRACES_DIR}/clean_traces_I{L}_v3.json'
    if not os.path.exists(trace_path):
        # Best-effort, as before: a missing optional length is skipped, but it is now reported.
        missing_lengths.append(L)
        continue
    traces_data = load_json(trace_path)
    traces = [CleanTrace(**t) for t in traces_data]
    all_trace_maps[L] = {t.problem_index: t for t in traces}

if missing_lengths:
    print(f'WARNING: no trace file found for L={missing_lengths}; they are ignored when intersecting')

common_indices = set.intersection(*[set(tm.keys()) for tm in all_trace_maps.values()])
# Report the lengths that really took part in the intersection, not the requested list.
print(f'\nCommon problems across L={sorted(all_trace_maps)}: {len(common_indices)}')

experiment_problems = [p for p in problems if p.index in common_indices]
print(f'Using {len(experiment_problems)} problems')

## 6. Corruption Logic

In [None]:
def pick_corrupted_steps(L: int, c: float, seed: int) -> List[int]:
    """Choose which of steps 1..L-1 to corrupt; the final step L is never a candidate.

    The number of corrupted steps is round(c * (L - 1)). Candidates are
    shuffled with a generator seeded by *seed*, the first K are kept, and the
    result is returned in ascending order.
    """
    n_corrupt = int(round(c * (L - 1)))
    if not n_corrupt:
        return []
    candidates = [*range(1, L)]
    random.Random(seed).shuffle(candidates)
    chosen = candidates[:n_corrupt]
    chosen.sort()
    return chosen

def assign_corruption_types(corrupted_steps: List[int], seed: int) -> Dict[int, str]:
    """Assign a corruption type to every corrupted step.

    With K corrupted steps, K // 5 are labelled "IRR", (2 * K) // 5 are
    labelled "LOC", and the remainder "WRONG". Which step gets which label is
    decided by shuffling a copy of *corrupted_steps* with a generator seeded
    by ``seed + 1``; the input list is not modified.

    Args:
        corrupted_steps: step numbers selected for corruption.
        seed: base seed of the condition.

    Returns:
        Mapping of step number to "IRR", "LOC" or "WRONG" (empty for no steps),
        with keys in shuffled order.
    """
    K = len(corrupted_steps)
    if K == 0:
        return {}

    n_irr = K // 5
    n_loc = (2 * K) // 5
    # The remainder is always >= 1 for K >= 1 (K - K//5 - 2K//5 >= 2K/5 > 0), so at
    # least one "WRONG" step is guaranteed without any special-case adjustment.
    n_wrong = K - n_irr - n_loc

    rng = random.Random(seed + 1)
    perm = list(corrupted_steps)
    rng.shuffle(perm)

    labels = ["IRR"] * n_irr + ["LOC"] * n_loc + ["WRONG"] * n_wrong
    return dict(zip(perm, labels))

# Templates for "IRR" corruptions: a side computation that the text itself declares unused.
# generate_irrelevant_step() fills in {a}, {b} and {result}; the second template has no {b}
# placeholder, and that function picks a + b vs a * 2 for {result} by checking for '+'.
IRRELEVANT_TEMPLATES = [
    "Compute an auxiliary value: aux = {a} + {b} = {result}, but it will not be used later.",
    "Compute a side quantity: aux = {a} * 2 = {result}, unrelated to the final result.",
]

# Templates for "WRONG" corruptions: a step that asserts a fixed value for a variable.
# generate_wrong_constraint_step() fills in {var} and {wrong_value} with random choices.
WRONG_CONSTRAINT_TEMPLATES = [
    "Fix an intermediate condition: set {var} = {wrong_value} as a given constraint.",
    "Assume the total is {var} = {wrong_value} and proceed using this fixed value.",
]

def generate_irrelevant_step(step_num: int, seed: int) -> str:
    """Build an "IRR" step: a correct but unused side computation.

    Two operands in [2, 20] and one of IRRELEVANT_TEMPLATES are drawn from a
    generator seeded by *seed*. The stated result is a + b for the addition
    template and a * 2 otherwise. *step_num* is accepted but not used.
    """
    rng = random.Random(seed)
    a = rng.randint(2, 20)
    b = rng.randint(2, 20)
    template = rng.choice(IRRELEVANT_TEMPLATES)
    if '+' in template:
        result = a + b
    else:
        result = a * 2
    return template.format(a=a, b=b, result=result)

def generate_local_error_step(original_step: str, seed: int) -> str:
    """Return *original_step* with its last number replaced by a nearby wrong value.

    The last run of digits in the step is shifted by a random offset in
    {-3, -2, -1, 1, 2, 3} (clamped at 0). The replacement is tried in order:

    1. a trailing ``= <number>.``,
    2. a trailing ``<number>.``,
    3. the last number wherever it occurs in the step.

    Case 3 and the two guards below guarantee that the returned step always
    differs from the input and always contains an error; previously a step
    not ending in ``<number>.`` came back unchanged. Whenever the earlier
    behaviour did alter the step, the output here is identical for the same
    seed.

    Args:
        original_step: text of the clean step.
        seed: seed for the local random generator.

    Returns:
        The corrupted step text. If the step contains no digits, a synthetic
        ``Compute t = 10 * 3 = <wrong>.`` step is returned instead.
    """
    rng = random.Random(seed)
    number_matches = list(re.finditer(r'\d+', original_step))
    if not number_matches:
        fake_result = rng.randint(28, 32)
        if fake_result == 30:
            # 10 * 3 = 30 would be a correct statement, i.e. no error at all.
            fake_result = rng.choice([29, 31])
        return f"Compute t = 10 * 3 = {fake_result}."

    last_number = number_matches[-1]
    original_result = int(last_number.group())
    offset = rng.choice([-3, -2, -1, 1, 2, 3])
    wrong_result = max(0, original_result + offset)
    if wrong_result == original_result:
        # Only happens when clamping at 0 cancelled a negative offset; go upward instead.
        wrong_result = original_result + abs(offset)

    modified = re.sub(r'= (\d+)\.$', f'= {wrong_result}.', original_step)
    if modified == original_step:
        modified = re.sub(r'(\d+)\.$', f'{wrong_result}.', original_step)
    if modified == original_step:
        # The step does not end in "<number>.": patch the last number in place.
        modified = (
            original_step[:last_number.start()]
            + str(wrong_result)
            + original_step[last_number.end():]
        )
    return modified

def generate_wrong_constraint_step(step_num: int, seed: int) -> str:
    """Build a "WRONG" step that asserts an arbitrary fixed value for a variable.

    A variable name, a value in [10, 100] and one of
    WRONG_CONSTRAINT_TEMPLATES are drawn, in that order, from a generator
    seeded by *seed*. *step_num* is accepted but not used.
    """
    rng = random.Random(seed)
    variable_names = ['x', 'total', 'result', 'n']
    chosen_var = rng.choice(variable_names)
    chosen_value = rng.randint(10, 100)
    chosen_template = rng.choice(WRONG_CONSTRAINT_TEMPLATES)
    return chosen_template.format(var=chosen_var, wrong_value=chosen_value)

## 7. FIXED: Final Step Manipulation

In [None]:
def create_final_step_with_cue(answer: int) -> str:
    """Return the final-step text for the cue-present condition.

    The step fully replaces the trace's original last step and states
    *answer* explicitly.
    """
    cue_template = "Therefore, the final answer is Final = {}."
    return cue_template.format(answer)

def create_final_step_without_cue() -> str:
    """Return the final-step text for the cue-absent condition.

    The step fully replaces the trace's original last step and contains no
    number, so it carries no answer cue.
    """
    sentences = (
        "The reasoning steps above lead to the solution.",
        "The calculation is now complete.",
    )
    return " ".join(sentences)

# Print both final-step variants once as a visual sanity check (70000 is just a sample answer).
print("FIXED final step templates:")
print(f"  With cue: {create_final_step_with_cue(70000)}")
print(f"  Without cue: {create_final_step_without_cue()}")

## 8. Trace Creation

In [None]:
def create_corrupted_trace(
    clean_trace: CleanTrace,
    correct_answer: int,
    c: float,
    cue_condition: str,
    seed: int
) -> CorruptedTrace:
    """Build the corrupted version of *clean_trace* for one experimental condition.

    Steps 1..L-1 are copied from the clean trace, except that the steps chosen
    by pick_corrupted_steps() are replaced according to the type given by
    assign_corruption_types(). The last clean step is always dropped and
    replaced by a fixed final step: one stating *correct_answer* when
    *cue_condition* is 'present', a neutral one otherwise.

    Args:
        clean_trace: source trace; its ``I`` field is used as the length L.
        correct_answer: answer stated in the final step in the cue-present case.
        c: corruption level.
        cue_condition: 'present' or any other value (treated as absent).
        seed: condition seed; each corrupted step uses ``seed + step_num * 1000``.

    Returns:
        CorruptedTrace holding the new steps, their rendered text and the
        corruption bookkeeping.
    """
    L = clean_trace.I
    corrupted_steps = pick_corrupted_steps(L, c, seed)
    corruption_types = assign_corruption_types(corrupted_steps, seed)

    new_steps = []
    for step_num, original_content in enumerate(clean_trace.steps[:-1], start=1):
        if step_num not in corrupted_steps:
            new_steps.append(original_content)
            continue

        kind = corruption_types[step_num]
        step_seed = seed + step_num * 1000
        if kind == 'IRR':
            replacement = generate_irrelevant_step(step_num, step_seed)
        elif kind == 'LOC':
            replacement = generate_local_error_step(original_content, step_seed)
        else:
            replacement = generate_wrong_constraint_step(step_num, step_seed)
        new_steps.append(replacement)

    # FIXED: the final step is replaced wholesale rather than edited.
    if cue_condition == 'present':
        new_steps.append(create_final_step_with_cue(correct_answer))
    else:
        new_steps.append(create_final_step_without_cue())

    numbered = [f'Step {n}: {text}' for n, text in enumerate(new_steps, start=1)]
    full_text = '\n'.join(['[[COT_START]]', *numbered, '[[COT_END]]'])

    return CorruptedTrace(
        problem_index=clean_trace.problem_index,
        L=L,
        c=c,
        cue_condition=cue_condition,
        K_corrupt=len(corrupted_steps),
        corrupted_steps=corrupted_steps,
        corruption_types=corruption_types,
        steps=new_steps,
        full_text=full_text,
        seed=seed,
    )

## 9. API Setup

In [None]:
from getpass import getpass

# Ask for the key interactively so it is not hard-coded in the notebook;
# getpass does not echo what is typed.
OPENAI_API_KEY = getpass('Enter OpenAI API Key: ')
print('API Key set.')

In [None]:
from openai import OpenAI

# Shared API client and the model name used for every request
# (MODEL is also recorded in each ExperimentResult and in the summary).
client = OpenAI(api_key=OPENAI_API_KEY)
MODEL = 'gpt-4o'

def call_gpt4o(system_prompt: str, user_prompt: str, max_tokens: int = 1024, retries: int = 3) -> str:
    """Send one chat-completion request to MODEL and return the reply text.

    The request uses temperature 0. After a successful call the function
    sleeps API_RATE_LIMIT_DELAY seconds. On any exception it prints the error
    and retries after API_RETRY_DELAY * attempt_number seconds; the exception
    from the last attempt is re-raised.

    Args:
        system_prompt: content of the system message.
        user_prompt: content of the user message.
        max_tokens: completion token limit passed to the API.
        retries: total number of attempts; values below 1 are treated as 1 so
            the function can never fall through and return None.

    Returns:
        The content of the first choice. An empty string is returned when the
        API reports no text content (``message.content`` is None), so callers
        always receive a ``str``.

    Raises:
        Exception: whatever the final failed attempt raised.
    """
    attempts = max(1, retries)
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    for attempt in range(attempts):
        try:
            response = client.chat.completions.create(
                model=MODEL, max_tokens=max_tokens,
                messages=messages,
                temperature=0
            )
            time.sleep(API_RATE_LIMIT_DELAY)
            # content is nullable in the API response; normalise it so that the
            # regex-based answer parser downstream never receives None.
            return response.choices[0].message.content or ""
        except Exception as e:
            print(f'API error (attempt {attempt+1}): {e}')
            if attempt < attempts - 1:
                # Linear back-off: 1x, 2x, ... the base retry delay.
                time.sleep(API_RETRY_DELAY * (attempt + 1))
            else:
                raise

# Smoke test: one small live request to confirm the key and client work before the full run.
test_response = call_gpt4o("You output ONLY JSON.", 'Respond: {"test": "ok"}', max_tokens=50)
print(f'API test: {test_response}')

## 10. Experiment

In [None]:
# System message sent with every experiment request. It asks for a bare
# {"final": <number>} JSON object, which is the shape parse_model_answer() looks for first.
SYSTEM_PROMPT = """You are a calculator that outputs ONLY JSON.
Output exactly: {"final": <number>}
Do NOT write any explanation.
"""

def create_prompt(problem: GSM8KProblem, trace_text: str) -> str:
    """Assemble the user message: the problem, the reasoning trace, then the output instruction."""
    sections = [
        f'Problem: {problem.question}',
        '',
        'Reasoning trace:',
        f'{trace_text}',
        '',
        'OUTPUT ONLY: {"final": <number>}',
    ]
    return '\n'.join(sections)

# A JSON-style number: integer or decimal, with an optional exponent.
_JSON_NUMBER = r'-?\d+(?:\.\d+)?(?:[eE][+-]?\d+)?'
_FINAL_OBJECT_RE = re.compile(r'\{\s*"final"\s*:\s*(' + _JSON_NUMBER + r')\s*\}')
_FINAL_FIELD_RE = re.compile(r'"final"\s*:\s*(' + _JSON_NUMBER + r')')
# Free-text number: either properly comma-grouped thousands ("1,234", "12,345.6")
# or a plain integer/decimal. The (?!\d) guard rejects malformed groups like "1,2345".
_LOOSE_NUMBER_RE = re.compile(r'-?\d{1,3}(?:,\d{3})+(?!\d)(?:\.\d+)?|-?\d+(?:\.\d+)?')


def _number_to_int(text: str) -> Optional[int]:
    """Convert a matched number string to int, rounding decimals; None if not convertible."""
    cleaned = text.replace(',', '')
    try:
        # Exact path for plain integers (no float round-trip).
        return int(cleaned)
    except ValueError:
        pass
    try:
        return int(round(float(cleaned)))
    except (ValueError, OverflowError):
        return None


def parse_model_answer(response: str) -> Optional[int]:
    """Extract the integer answer from a model reply.

    Tried in order:

    1. a complete ``{"final": <number>}`` object,
    2. a bare ``"final": <number>`` field,
    3. the last number anywhere in the text.

    Decimals are rounded to the nearest integer. Compared with a plain
    ``-?\\d+`` scan, the fallback keeps decimals and comma-grouped thousands
    together ("12.6" -> 13, "1,234" -> 1234) instead of returning only the
    digits after the separator, and the JSON forms accept an exponent
    ("7e4" -> 70000).

    Args:
        response: raw reply text; None or an empty string is accepted.

    Returns:
        The parsed integer, or None when the reply contains no usable number.
    """
    if not response:
        return None

    for pattern in (_FINAL_OBJECT_RE, _FINAL_FIELD_RE):
        match = pattern.search(response)
        if match:
            return _number_to_int(match.group(1))

    candidates = _LOOSE_NUMBER_RE.findall(response)
    if candidates:
        return _number_to_int(candidates[-1])
    return None

def run_single_experiment(problem: GSM8KProblem, trace: CorruptedTrace) -> ExperimentResult:
    """Query the model once with *problem* and *trace*, then score the reply.

    The reply is parsed with parse_model_answer(); the result counts as
    correct only when a number was parsed and it equals the problem's
    final_answer.
    """
    raw_reply = call_gpt4o(
        SYSTEM_PROMPT,
        create_prompt(problem, trace.full_text),
        max_tokens=API_MAX_TOKENS_ANSWER,
    )
    parsed = parse_model_answer(raw_reply)
    gold = problem.final_answer

    return ExperimentResult(
        problem_index=problem.index,
        model=MODEL,
        L=trace.L,
        c=trace.c,
        cue_condition=trace.cue_condition,
        K_corrupt=trace.K_corrupt,
        K_clean=trace.L - trace.K_corrupt,
        model_answer=parsed,
        correct_answer=gold,
        is_correct=parsed is not None and parsed == gold,
        raw_output=raw_reply,
        timestamp=datetime.now().isoformat(),
    )

In [None]:
# Main run: one model call per (length, cue condition, problem).
print('='*70)
print('E8\': GPT-4o CROSS-MODEL VALIDATION (FIXED)')
print('='*70)

# Planned workload; the actual count can be lower because problems without a trace are skipped below.
total_conditions = len(LENGTHS) * len(CUE_CONDITIONS)
total_inferences = len(experiment_problems) * total_conditions
print(f'Problems: {len(experiment_problems)}')
print(f'Conditions: {total_conditions}')
print(f'Total inferences: {total_inferences}')
print('='*70)

# Results of all conditions, in execution order; consumed by the save/analysis cells.
# NOTE(review): nothing is written to disk inside this loop. If call_gpt4o() exhausts its
# retries the exception stops the run here; results gathered so far remain only in all_results.
all_results = []

for L in LENGTHS:
    for cue in CUE_CONDITIONS:
        condition_name = f'L={L}_c={CORRUPTION_LEVEL}_cue={cue}'
        print(f'\nRunning: {condition_name}')
        
        condition_results = []
        
        for prob in tqdm(experiment_problems, desc=condition_name):
            clean_trace = trace_maps[L].get(prob.index)
            if clean_trace is None:
                continue
            
            # NOTE(review): the seed key includes `cue`, so for the same problem and L the
            # cue-present and cue-absent traces get different seeds and therefore generally
            # different corrupted steps, not just a different final step — confirm this is intended.
            seed = derive_seed(E8_SEED, prob.index, L, CORRUPTION_LEVEL, cue)
            corrupted_trace = create_corrupted_trace(clean_trace, prob.final_answer, CORRUPTION_LEVEL, cue, seed)
            result = run_single_experiment(prob, corrupted_trace)
            
            condition_results.append(result)
            all_results.append(result)
        
        # Per-condition accuracy; reported as 0 when no problem had a trace for this length.
        acc = sum(r.is_correct for r in condition_results) / len(condition_results) if condition_results else 0
        print(f'  → Accuracy: {acc*100:.1f}%')

print(f'\n\nTotal experiments: {len(all_results)}')

## 11. Save & Analysis

In [None]:
# Persist the raw per-call results first, before any analysis can fail.
save_json([asdict(r) for r in all_results], f'{SAVE_DIR}/results/E8prime_GPT4o_results.json')
print(f'Results saved')

# One row per model call; columns are the ExperimentResult fields.
df = pd.DataFrame([asdict(r) for r in all_results])

print('\n' + '='*70)
print('E8\' GPT-4o RESULTS')
print('='*70)

# Accuracy per (length, cue) cell. An empty cell yields NaN from .mean() and prints as "nan%".
for L in LENGTHS:
    for cue in CUE_CONDITIONS:
        mask = (df['L'] == L) & (df['cue_condition'] == cue)
        acc = df[mask]['is_correct'].mean()
        print(f'L={L}, cue={cue}: {acc*100:.1f}%')

In [None]:
# L effect analysis: accuracy difference between L=20 and L=10 within each cue condition.
# NOTE(review): the lengths 10 and 20 are hard-coded here and must be kept in sync with LENGTHS.
print('\n' + '='*70)
print('L EFFECT ANALYSIS')
print('='*70)

for cue in CUE_CONDITIONS:
    acc_10 = df[(df['L'] == 10) & (df['cue_condition'] == cue)]['is_correct'].mean()
    acc_20 = df[(df['L'] == 20) & (df['cue_condition'] == cue)]['is_correct'].mean()
    diff = acc_20 - acc_10
    
    print(f'\nCue {cue}:')
    print(f'  L=10: {acc_10*100:.1f}%')
    print(f'  L=20: {acc_20*100:.1f}%')
    print(f'  L effect: {diff*100:+.1f}pp')
    
    # A check mark is printed only when the expected pattern holds:
    # cue present -> |difference| under 10 pp; cue absent -> L=20 better by more than 5 pp.
    # Nothing is printed when the pattern does not hold.
    if cue == 'present' and abs(diff) < 0.1:
        print(f'  ✓ L-insensitive (cue dominates)')
    elif cue == 'absent' and diff > 0.05:
        print(f'  ✓ L-sensitive (redundancy matters)')

In [None]:
# Summary
summary = {
    'experiment': 'E8prime_GPT4o_v2_FIXED',
    'model': MODEL,
    'date': EXPERIMENT_DATE,
    'n_problems': len(experiment_problems),
    'total_inferences': len(all_results),
    'conditions': {
        'lengths': LENGTHS,
        'corruption_level': CORRUPTION_LEVEL,
        'cue_conditions': CUE_CONDITIONS
    },
    'results': {},
    'l_effects': {}
}

for L in LENGTHS:
    for cue in CUE_CONDITIONS:
        mask = (df['L'] == L) & (df['cue_condition'] == cue)
        key = f'L{L}_cue{cue}'
        summary['results'][key] = {
            'accuracy': float(df[mask]['is_correct'].mean()),
            'n_correct': int(df[mask]['is_correct'].sum()),
            'n_total': int(mask.sum())
        }

for cue in CUE_CONDITIONS:
    acc_10 = summary['results'][f'L10_cue{cue}']['accuracy']
    acc_20 = summary['results'][f'L20_cue{cue}']['accuracy']
    summary['l_effects'][f'cue{cue}'] = acc_20 - acc_10

save_json(summary, f'{SAVE_DIR}/results/E8prime_GPT4o_summary.json')

print('\n' + '='*70)
print('E8\' COMPLETE')
print('='*70)
print(f'Total experiments: {len(all_results)}')
print(f'\nL Effects (L=20 - L=10):')
for cue, effect in summary['l_effects'].items():
    print(f'  {cue}: {effect*100:+.1f}pp')
print(f'\nFiles saved to: {SAVE_DIR}')