# E3: Verification Instruction Gating Experiment

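**Purpose**: Separate instruction "compliance" (whether the model does what the prompt tells it to do with the provided trace) from "verification ability" (whether the model can actually detect and correct errors in that trace)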

**Prompt Conditions**:
1. **USE**: Follow the provided trace and answer
2. **VERIFY**: Verify the trace, correct if wrong, then answer
3. **IGNORE**: Ignore the trace and solve directly

**Models**: Claude 4 Sonnet, GPT-4o, Claude 3.5 Haiku

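**λ condition**: fixed at 0.8, i.e. 80% of the steps in each reasoning trace are replaced with corrupted steps (chosen as the most discriminative setting)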

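**Expected outcomes** (VERIFY and IGNORE are each compared against USE):
- If VERIFY beats USE: the model can detect and correct errors in the trace when asked to (dual-route = verification pathway)
- If IGNORE beats USE: the model can bypass the contaminated trace when instructed to
- If USE is at least as good as both: the model benefits from the trace structure even with contamination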

## 0. Setup

In [None]:
# Mount Google Drive at /content/drive; every output path below lives under it.
from google.colab import drive
drive.mount('/content/drive')

import os
from datetime import datetime

# Date stamp (YYYYMMDD) taken when this cell runs; it is part of the output
# directory name, so runs on different days write to different directories.
EXPERIMENT_DATE = datetime.now().strftime('%Y%m%d')
SAVE_DIR = '/content/drive/MyDrive/CoT_Experiment'
SAVE_DIR_EXP = f'{SAVE_DIR}/e3_verify_ignore_{EXPERIMENT_DATE}'
# Create the run directory and its results/ and checkpoints/ subdirectories.
# exist_ok=True makes re-running this cell harmless.
os.makedirs(SAVE_DIR_EXP, exist_ok=True)
os.makedirs(f'{SAVE_DIR_EXP}/results', exist_ok=True)
os.makedirs(f'{SAVE_DIR_EXP}/checkpoints', exist_ok=True)

print(f'Save directory: {SAVE_DIR_EXP}')

In [None]:
# Install third-party packages via the notebook shell escape; -q keeps pip's output quiet.
!pip install datasets openai anthropic pandas tqdm matplotlib -q
print('Dependencies installed.')

## 1. Configuration

In [None]:
import json
import re
import random
import time
import hashlib
from typing import List, Dict, Optional, Any
from dataclasses import dataclass
from datetime import datetime
from tqdm import tqdm
import pandas as pd
import numpy as np

# ---------------------------------------------------------------------------
# Experiment-wide constants
# ---------------------------------------------------------------------------
GLOBAL_SEED = 20251224       # base seed for problem sampling and per-problem trace seeds
N_PROBLEMS = 100             # Subset for efficiency
I_FIXED = 10                 # number of steps in every reasoning trace
LAMBDA_FIXED = 0.8           # Fixed at most discriminative point
INSTRUCTION_CONDITIONS = ['USE', 'VERIFY', 'IGNORE']

# Models to test: display name -> provider, API model identifier, and the
# short tag used in output file names.
MODELS = {
    'Claude 4 Sonnet': dict(
        provider='anthropic',
        api_name='claude-sonnet-4-20250514',
        short='sonnet4',
    ),
    'GPT-4o': dict(
        provider='openai',
        api_name='gpt-4o',
        short='gpt4o',
    ),
    'Claude 3.5 Haiku': dict(
        provider='anthropic',
        api_name='claude-3-5-haiku-latest',
        short='haiku35',
    ),
}

# Banner summarising the run. Each model sees every problem once per
# instruction condition plus once for the DIRECT baseline.
print('=' * 60, 'E3: VERIFY/IGNORE INSTRUCTION EXPERIMENT', '=' * 60, sep='\n')
print(
    f'Models: {list(MODELS.keys())}',
    f'λ (fixed): {LAMBDA_FIXED}',
    f'Instruction conditions: {INSTRUCTION_CONDITIONS}',
    f'Problems: {N_PROBLEMS}',
    f'Total inferences per model: {N_PROBLEMS * (len(INSTRUCTION_CONDITIONS) + 1)}',
    sep='\n',
)

## 2. API Setup

In [None]:
import getpass
from openai import OpenAI
import anthropic

# Ask for both API keys interactively; getpass keeps the typed key off the
# screen. The two printed lines are Japanese for "Please enter your
# OpenAI / Anthropic API key:".
print("OpenAI APIキーを入力してください：")
OPENAI_API_KEY = getpass.getpass("OpenAI API Key: ")

print("\nAnthropic APIキーを入力してください：")
ANTHROPIC_API_KEY = getpass.getpass("Anthropic API Key: ")

# Module-level clients used by call_api for every request.
openai_client = OpenAI(api_key=OPENAI_API_KEY)
anthropic_client = anthropic.Anthropic(api_key=ANTHROPIC_API_KEY)

def call_api(prompt: str, model_config: dict, max_tokens: int = 512) -> str:
    """Send a single-turn user prompt to the configured provider and return the reply text.

    Both providers are called with ``temperature=0`` so that decoding settings
    are the same across models; otherwise accuracy differences between
    providers would be confounded by sampling randomness.

    Args:
        prompt: The full user message.
        model_config: An entry of ``MODELS``; ``provider`` selects the client
            (``'openai'`` uses the OpenAI client, anything else the Anthropic
            client) and ``api_name`` is the model identifier sent to the API.
        max_tokens: Upper bound on the number of generated tokens.

    Returns:
        The reply text, or ``""`` if the model returned no text or all
        attempts failed. Errors are printed, never raised.
    """
    max_attempts = 3
    for attempt in range(max_attempts):
        try:
            if model_config['provider'] == 'openai':
                response = openai_client.chat.completions.create(
                    model=model_config['api_name'],
                    messages=[{"role": "user", "content": prompt}],
                    max_tokens=max_tokens,
                    temperature=0,
                )
                # `content` can be None (e.g. when the model returns no text);
                # normalise to "" so callers always receive a str.
                return response.choices[0].message.content or ""
            else:
                response = anthropic_client.messages.create(
                    model=model_config['api_name'],
                    max_tokens=max_tokens,
                    temperature=0,
                    messages=[{"role": "user", "content": prompt}],
                )
                return response.content[0].text
        except Exception as e:
            # Deliberately broad: any failure (network, rate limit, malformed
            # response) is reported and retried rather than aborting the run.
            print(f'API error (attempt {attempt+1}): {e}')
            # Exponential backoff between attempts; waiting after the last
            # attempt would only delay returning the failure.
            if attempt < max_attempts - 1:
                time.sleep(2 ** attempt)
    return ""

# Test APIs
# Send one trivial prompt to every configured model so that a bad key or model
# name shows up here rather than in the middle of the experiment. call_api
# returns "" after repeated failures, so a blank answer next to a model name
# means the call did not succeed.
print('\nTesting APIs...')
for name, config in MODELS.items():
    resp = call_api("What is 2+2? Reply with just the number.", config)
    print(f'{name}: {resp.strip()}')

## 3. Load GSM8K

In [None]:
from datasets import load_dataset

@dataclass
class GSM8KProblem:
    """One GSM8K problem together with its parsed ground-truth answer."""
    index: int  # position of the problem in the loaded dataset split
    question: str  # problem statement (dataset field 'question')
    answer_text: str  # full reference solution (dataset field 'answer')
    final_answer: int  # integer parsed from the '#### <n>' marker in answer_text

def extract_final_answer(answer_text: str) -> int:
    """Return the integer that follows the ``####`` marker in a reference solution.

    Thousands separators are removed (``1,234`` -> ``1234``) and a leading
    minus sign is honoured, so negative reference answers parse instead of
    being rejected.

    Args:
        answer_text: Reference solution text containing a ``#### <number>`` marker.

    Returns:
        The final answer as an int.

    Raises:
        ValueError: If no ``#### <number>`` marker is present.
    """
    match = re.search(r'####\s*(-?\d[\d,]*)', answer_text)
    if match is None:
        # Include the tail of the text so the offending item can be identified.
        raise ValueError(f'Could not extract final answer from: {answer_text[-80:]!r}')
    return int(match.group(1).replace(',', ''))

def save_json(data, filepath):
    """Write ``data`` to ``filepath`` as UTF-8 JSON and print a confirmation.

    The output is indented by two spaces and non-ASCII characters are written
    as-is rather than escaped.
    """
    with open(filepath, 'w', encoding='utf-8') as handle:
        handle.write(json.dumps(data, ensure_ascii=False, indent=2))
    print(f'Saved: {filepath}')

# Load the GSM8K test split.
dataset = load_dataset('gsm8k', 'main', split='test')
print(f'GSM8K loaded: {len(dataset)} problems')

# Deterministic subset: shuffle all indices with GLOBAL_SEED, keep the first
# N_PROBLEMS, then sort them so problems are processed in dataset order.
rng = random.Random(GLOBAL_SEED)
indices = list(range(len(dataset)))
rng.shuffle(indices)
selected_indices = sorted(indices[:N_PROBLEMS])

problems = []
skipped_indices = []  # dataset indices whose reference answer could not be parsed
for idx in selected_indices:
    item = dataset[idx]
    try:
        final_ans = extract_final_answer(item['answer'])
    except ValueError:
        # Only the "no parsable final answer" case is skipped; any other
        # error (including KeyboardInterrupt) propagates instead of being
        # silently swallowed.
        skipped_indices.append(idx)
        continue
    problems.append(GSM8KProblem(
        index=idx,
        question=item['question'],
        answer_text=item['answer'],
        final_answer=final_ans
    ))

if skipped_indices:
    # Make a shrunken sample visible: accuracies are computed over
    # len(problems), which is then smaller than N_PROBLEMS.
    print(f'WARNING: skipped {len(skipped_indices)} problems with unparsable answers: {skipped_indices}')
print(f'Loaded {len(problems)} problems')

## 4. CoT Generation (Same as A3)

In [None]:
def generate_clean_cot(problem: GSM8KProblem, I: int) -> List[str]:
    """Build an uncorrupted reasoning trace of ``I`` steps from the reference solution.

    Each non-empty line of ``problem.answer_text`` (whitespace-stripped,
    excluding lines that start with ``####``) becomes one step. If there are
    fewer than ``I`` such lines, the trace is padded with numbered
    "Continue calculation." filler steps; if there are more, it is truncated.
    """
    stripped_lines = (raw.strip() for raw in problem.answer_text.split('\n'))
    trace = [line for line in stripped_lines if line and not line.startswith('####')]

    # Pad up to I steps; filler numbering continues from the current length.
    for step_number in range(len(trace) + 1, I + 1):
        trace.append(f"Step {step_number}: Continue calculation.")

    return trace[:I]

def generate_corrupted_step(problem: GSM8KProblem, step_idx: int, rng: random.Random) -> str:
    """Return a fabricated arithmetic step labelled ``Step {step_idx+1}``.

    Three candidate sentences (a multiplication, an addition, a subtraction)
    are filled with numbers drawn from ``rng`` and one of them is picked at
    random. The stated result is drawn independently of the operands. The
    ``problem`` argument is not used.

    The draw order is fixed: three numbers per candidate, candidates in the
    order listed, then the final pick. It is part of the function's behaviour
    because it determines which values a seeded ``rng`` produces.
    """
    label = f"Step {step_idx+1}"

    mul_left, mul_right, mul_result = rng.randint(10, 100), rng.randint(2, 10), rng.randint(100, 1000)
    add_left, add_right, add_result = rng.randint(50, 200), rng.randint(50, 200), rng.randint(100, 500)
    sub_left, sub_right, sub_result = rng.randint(100, 500), rng.randint(20, 100), rng.randint(50, 400)

    candidates = [
        f"{label}: Let's multiply: {mul_left} × {mul_right} = {mul_result}.",
        f"{label}: Adding totals: {add_left} + {add_right} = {add_result}.",
        f"{label}: The difference: {sub_left} - {sub_right} = {sub_result}.",
    ]
    return rng.choice(candidates)

def generate_mixed_cot(problem: GSM8KProblem, I: int, lam: float, seed: int) -> List[str]:
    """Build an ``I``-step trace in which a fraction ``lam`` of the steps is corrupted.

    ``round(I * lam)`` step positions are sampled with a private RNG seeded by
    ``seed`` and replaced by fabricated steps from ``generate_corrupted_step``.
    All other positions keep the clean step from ``generate_clean_cot``. The
    same ``(problem, I, lam, seed)`` always yields the same trace.
    """
    local_rng = random.Random(seed)
    baseline = generate_clean_cot(problem, I)

    corrupt_count = int(round(I * lam))
    if corrupt_count > 0:
        tainted_positions = set(local_rng.sample(range(I), corrupt_count))
    else:
        tainted_positions = set()

    def step_at(pos: int) -> str:
        # Corrupted positions draw from the shared RNG, so they must be
        # evaluated in ascending position order (the comprehension does this).
        if pos in tainted_positions:
            return generate_corrupted_step(problem, pos, local_rng)
        if pos < len(baseline):
            return baseline[pos]
        return f"Step {pos+1}: Continue."

    return [step_at(pos) for pos in range(I)]

def derive_seed(global_seed: int, problem_id: int, lam: float) -> int:
    """Derive a reproducible 32-bit seed from the global seed, problem id and λ.

    The three values are joined into a text key and hashed with SHA-256. The
    first four bytes of the digest, read as a big-endian unsigned integer,
    form the seed.
    """
    key = f"{global_seed}|{problem_id}|lam={lam}"
    digest = hashlib.sha256(key.encode("utf-8")).digest()
    return int.from_bytes(digest[:4], "big")

## 5. Instruction-Specific Prompts

In [None]:
def create_prompt(problem: GSM8KProblem, cot_steps: List[str], instruction: str) -> str:
    """Build the prompt for one problem under the USE, VERIFY or IGNORE condition.

    All three prompts share the same skeleton (intro line, problem, trace,
    reply-format line). They differ only in the directive given to the model,
    the heading above the trace, and the closing request.

    Raises:
        ValueError: If ``instruction`` is not one of 'USE', 'VERIFY', 'IGNORE'.
    """
    trace_block = '\n'.join(cot_steps)

    # Per condition: (directive line(s), heading above the trace, closing request).
    wording = {
        'USE': (
            "Follow the provided reasoning trace and give the final numerical answer.",
            "Reasoning trace:",
            "Based on this reasoning, what is the final answer?",
        ),
        'VERIFY': (
            "CAREFULLY VERIFY the reasoning trace. If you find any errors, CORRECT them.\n"
            "Then give the final numerical answer based on your verified/corrected reasoning.",
            "Reasoning trace to verify:",
            "Verify the above reasoning, correct any errors, and give the final answer.",
        ),
        'IGNORE': (
            "IGNORE the provided reasoning trace completely.\n"
            "Solve the problem yourself from scratch and give the final numerical answer.",
            "Reasoning trace (IGNORE THIS):",
            "Solve the problem yourself, ignoring the trace above.",
        ),
    }
    if instruction not in wording:
        raise ValueError(f"Unknown instruction: {instruction}")
    directive, trace_heading, closing = wording[instruction]

    return (
        "Here is a math problem with a provided reasoning trace.\n"
        f"{directive}\n"
        "\n"
        f"Problem: {problem.question}\n"
        "\n"
        f"{trace_heading}\n"
        f"{trace_block}\n"
        "\n"
        f"{closing}\n"
        'Reply with ONLY: {"final": <number>}'
    )

def create_direct_prompt(problem: GSM8KProblem) -> str:
    """Build the baseline prompt: the problem alone, with no reasoning trace."""
    sections = (
        "Solve this math problem. Give ONLY the final numerical answer in JSON format.",
        f"Problem: {problem.question}",
        'Reply with ONLY: {"final": <number>}',
    )
    # Sections are separated by one blank line.
    return "\n\n".join(sections)

# A number as models typically write it: digits with optional thousands
# separators and an optional decimal part.
_NUMBER_BODY = r'\d+(?:,\d{3}(?!\d))*(?:\.\d+)?'
# Value of the "final" field; tolerates a quoted value and a leading "$".
_FINAL_FIELD_RE = re.compile(r'"final"\s*:\s*"?\s*\$?\s*(-?' + _NUMBER_BODY + r')')
# Any number in free text. A "-" counts as a sign only when it does not follow
# a word character or ")", so "5-2" yields 5 and 2, not 5 and -2.
_LOOSE_NUMBER_RE = re.compile(r'(?:(?<![\w)])-)?\b' + _NUMBER_BODY + r'\b')


def _number_token_to_int(token: str) -> Optional[int]:
    """Convert a matched number token to int, or None if it is not integral."""
    whole, _, fraction = token.replace(',', '').partition('.')
    if fraction.strip('0'):
        # e.g. "12.5": reference answers are integers, so truncating would
        # risk scoring a wrong answer as correct.
        return None
    return int(whole)


def parse_answer(response: str) -> Optional[int]:
    """Extract the model's final integer answer from its raw reply.

    The value of a ``"final"`` field is preferred; if there is none, the last
    number anywhere in the reply is used. Signs and thousands separators are
    honoured (``-3``, ``1,234``) and integral decimals are accepted
    (``18.0`` -> 18).

    Args:
        response: Raw reply text; may be empty.

    Returns:
        The parsed integer, or None if the reply is empty, contains no
        number, or the selected number is not an integer.
    """
    if not isinstance(response, str) or not response:
        return None

    field = _FINAL_FIELD_RE.search(response)
    if field:
        return _number_token_to_int(field.group(1))

    tokens = _LOOSE_NUMBER_RE.findall(response)
    if tokens:
        return _number_token_to_int(tokens[-1])
    return None

# Test prompts
# Preview the first 500 characters of the USE and VERIFY prompts for the first
# loaded problem. The trace here is built with the literal seed 42, not with
# derive_seed, so it is for inspection only and generally differs from the
# trace the experiment loop generates for the same problem.
test_prob = problems[0]
test_cot = generate_mixed_cot(test_prob, I_FIXED, LAMBDA_FIXED, 42)
print('=== USE Prompt ===')
print(create_prompt(test_prob, test_cot, 'USE')[:500])
print('\n=== VERIFY Prompt ===')
print(create_prompt(test_prob, test_cot, 'VERIFY')[:500])

## 6. Run Experiment

In [None]:
# Select model
# Each run of the experiment cells evaluates the single model chosen here.
# The "#@title" / "#@param" comments below are Colab form annotations and
# must be kept as written.
#@title Select Model { run: "auto" }
MODEL_CHOICE = "Claude 4 Sonnet" #@param ["Claude 4 Sonnet", "GPT-4o", "Claude 3.5 Haiku"]

# MODEL_CHOICE must be a key of MODELS; model_config supplies the provider,
# API model name and the short tag used in output file names.
model_config = MODELS[MODEL_CHOICE]
print(f'Running E3 for: {MODEL_CHOICE}')

In [None]:
print('='*60)
print(f'E3: VERIFY/IGNORE EXPERIMENT - {MODEL_CHOICE}')
print('='*60)


def _build_result(problem, condition, lam, response):
    """Score one raw model reply and package it as a result record."""
    answer = parse_answer(response)
    # Test against None explicitly: a parsed answer of 0 is falsy but is a
    # legitimate answer and must still be compared with the ground truth.
    is_correct = answer is not None and answer == problem.final_answer
    return {
        'problem_index': problem.index,
        'model': MODEL_CHOICE,
        'condition': condition,
        'lam': lam,
        'model_answer': answer,
        'correct_answer': problem.final_answer,
        'is_correct': is_correct,
        'raw_output': response,
        'timestamp': datetime.now().isoformat()
    }


# Records for every condition accumulate here; each checkpoint file therefore
# contains everything collected up to that point.
all_results = []

# First: Direct condition (baseline) - the problem alone, no reasoning trace.
print('\n--- DIRECT (Baseline) ---')
for problem in tqdm(problems, desc='Direct'):
    prompt = create_direct_prompt(problem)
    response = call_api(prompt, model_config)
    all_results.append(_build_result(problem, 'DIRECT', None, response))
    time.sleep(0.3)  # pause between requests

save_json(all_results, f"{SAVE_DIR_EXP}/checkpoints/e3_{model_config['short']}_direct.json")

# Then: Each instruction condition with contaminated CoT
for instruction in INSTRUCTION_CONDITIONS:
    print(f'\n--- {instruction} (λ={LAMBDA_FIXED}) ---')

    for problem in tqdm(problems, desc=instruction):
        # The seed depends only on (GLOBAL_SEED, problem, λ), so every
        # instruction condition sees the identical contaminated trace.
        seed = derive_seed(GLOBAL_SEED, problem.index, LAMBDA_FIXED)
        cot_steps = generate_mixed_cot(problem, I_FIXED, LAMBDA_FIXED, seed)

        # Create instruction-specific prompt
        prompt = create_prompt(problem, cot_steps, instruction)
        response = call_api(prompt, model_config)

        all_results.append(_build_result(problem, instruction, LAMBDA_FIXED, response))
        time.sleep(0.5)  # pause between requests

    # Checkpoint after each condition so a later failure does not lose results.
    save_json(all_results, f"{SAVE_DIR_EXP}/checkpoints/e3_{model_config['short']}_{instruction}.json")

# Save final
save_json(all_results, f"{SAVE_DIR_EXP}/results/e3_results_{model_config['short']}.json")
print('\n✓ E3 experiment complete!')

## 7. Analyze Results

In [None]:
import matplotlib.pyplot as plt

# One row per (problem, condition) record collected by the experiment loop.
df = pd.DataFrame(all_results)

# Calculate accuracy by condition
# (the mean of the boolean is_correct column is the fraction answered correctly)
acc_by_condition = df.groupby('condition')['is_correct'].mean()

print('='*60)
print(f'E3 RESULTS: {MODEL_CHOICE}')
print('='*60)
print('\nAccuracy by Instruction Condition:')
for cond in ['DIRECT', 'USE', 'VERIFY', 'IGNORE']:
    # Conditions with no records are skipped rather than reported as 0%.
    if cond in acc_by_condition:
        acc = acc_by_condition[cond]
        print(f'  {cond}: {acc:.1%}')

# Visualization
# Bar chart of accuracy (%) per condition, with the DIRECT baseline repeated
# as a dashed horizontal line for comparison.
fig, ax = plt.subplots(figsize=(10, 6))

conditions = ['DIRECT', 'USE', 'VERIFY', 'IGNORE']
colors = ['#333333', '#d62728', '#2ca02c', '#1f77b4']
# NOTE: a condition with no records is drawn as a 0% bar here, which looks the
# same as a condition that was run and got everything wrong.
accs = [acc_by_condition.get(c, 0) * 100 for c in conditions]

bars = ax.bar(conditions, accs, color=colors, edgecolor='black', linewidth=1.5)

# Add value labels
for bar, acc in zip(bars, accs):
    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 1,
            f'{acc:.1f}%', ha='center', va='bottom', fontsize=12, fontweight='bold')

ax.set_ylabel('Accuracy (%)', fontsize=13)
ax.set_xlabel('Instruction Condition', fontsize=13)
ax.set_title(f'E3: Instruction Gating Analysis - {MODEL_CHOICE}\n(λ={LAMBDA_FIXED})', fontsize=14)
# Upper limit slightly above 100 so the label on a 100% bar still fits.
ax.set_ylim(0, 105)
ax.axhline(y=acc_by_condition.get('DIRECT', 0) * 100, color='gray', linestyle='--', alpha=0.7, label='Direct baseline')
ax.legend()

plt.tight_layout()
# Save the figure into the run directory before displaying it.
plt.savefig(f"{SAVE_DIR_EXP}/e3_instruction_gating_{model_config['short']}.png", dpi=300)
plt.show()

In [None]:
# Summary and interpretation
# Accuracies are fractions in [0, 1]; a condition with no records defaults to 0.
direct_acc = acc_by_condition.get('DIRECT', 0)
use_acc = acc_by_condition.get('USE', 0)
verify_acc = acc_by_condition.get('VERIFY', 0)
ignore_acc = acc_by_condition.get('IGNORE', 0)

# Machine-readable summary written to the results directory at the end of this cell.
summary = {
    'experiment': 'E3_Instruction_Gating',
    'model': MODEL_CHOICE,
    'date': EXPERIMENT_DATE,
    'n_problems': len(problems),
    'lambda': LAMBDA_FIXED,
    'accuracy': {
        'DIRECT': direct_acc,
        'USE': use_acc,
        'VERIFY': verify_acc,
        'IGNORE': ignore_acc
    },
    # Accuracy change relative to the no-trace DIRECT baseline.
    'delta_vs_direct': {
        'USE': use_acc - direct_acc,
        'VERIFY': verify_acc - direct_acc,
        'IGNORE': ignore_acc - direct_acc
    }
}

print('\n' + '='*60)
print('INTERPRETATION')
print('='*60)

print(f'\nDelta vs Direct baseline:')
print(f'  USE:    {use_acc - direct_acc:+.1%}')
print(f'  VERIFY: {verify_acc - direct_acc:+.1%}')
print(f'  IGNORE: {ignore_acc - direct_acc:+.1%}')

# On a tie, max() returns the first maximal entry, i.e. the order USE, VERIFY, IGNORE.
best_condition = max(['USE', 'VERIFY', 'IGNORE'], key=lambda c: acc_by_condition.get(c, 0))
print(f'\nBest condition: {best_condition}')

# The interpretation compares VERIFY and IGNORE against USE, not against
# DIRECT. The first two messages are independent, so both can be printed.
if verify_acc > use_acc:
    print('→ VERIFY helps: Model can detect and correct errors (verification pathway)')
if ignore_acc > use_acc:
    print('→ IGNORE helps: Model can bypass contaminated trace')
if use_acc >= verify_acc and use_acc >= ignore_acc:
    print('→ USE is best: Model benefits from trace structure even with contamination')

save_json(summary, f"{SAVE_DIR_EXP}/results/e3_summary_{model_config['short']}.json")