# E4: Hard Subset Analysis

**Purpose**: Avoid the ceiling effect for Claude 4 Sonnet (92.5% baseline accuracy on the full dataset)

**Approach**:
1. **GSM8K-Hard-Consensus**: Problems that ≥4 models answered incorrectly in the direct condition
2. Test on this harder subset where baseline is lower
3. Check if backfire becomes clearer

**Hypothesis**:
- On harder problems, Claude 4 Sonnet's λ* should be more clearly defined
- The backfire effect should be more pronounced

**λ conditions**: 0.0, 0.4, 0.8, 1.0 (4 points for efficiency)

## 0. Setup

In [None]:
# Mount Google Drive at /content/drive; every save path below lives under it.
from google.colab import drive
drive.mount('/content/drive')

import os
from datetime import datetime

# Shared root folder, plus a date-stamped folder (YYYYMMDD) for this run.
SAVE_DIR = '/content/drive/MyDrive/CoT_Experiment'
EXPERIMENT_DATE = datetime.now().strftime('%Y%m%d')
SAVE_DIR_EXP = f'{SAVE_DIR}/e4_hard_subset_{EXPERIMENT_DATE}'

# Create the run folder and its results/ subfolder; existing folders are left alone.
for _folder in (SAVE_DIR_EXP, f'{SAVE_DIR_EXP}/results'):
    os.makedirs(_folder, exist_ok=True)

print(f'Save directory: {SAVE_DIR_EXP}')

In [None]:
# Install the third-party packages this notebook imports (-q keeps pip output quiet).
!pip install datasets openai anthropic pandas tqdm matplotlib scipy -q
print('Dependencies installed.')

## 1. Load Previous Direct Results

In [None]:
import json
import re
import random
import time
import hashlib
from typing import List, Dict, Optional, Any
from dataclasses import dataclass
from datetime import datetime
from tqdm import tqdm
import pandas as pd
import numpy as np

def load_json(filepath):
    """Read the UTF-8 encoded JSON file at *filepath* and return the decoded object."""
    with open(filepath, mode='r', encoding='utf-8') as handle:
        raw_text = handle.read()
    return json.loads(raw_text)

def save_json(data, filepath):
    """Write *data* to *filepath* as UTF-8 JSON and print the saved path.

    Output uses a 2-space indent and keeps non-ASCII characters unescaped.
    """
    with open(filepath, mode='w', encoding='utf-8') as handle:
        encoder = json.JSONEncoder(ensure_ascii=False, indent=2)
        # Stream the encoder's chunks straight into the file.
        handle.writelines(encoder.iterencode(data))
    print(f'Saved: {filepath}')

# Per-model result files from the earlier direct-answer runs.
# UPDATE THESE PATHS based on your folder structure
DIRECT_RESULTS_PATHS = {
    'Claude 3 Haiku': f'{SAVE_DIR}/a3_claude_haiku_20251228/results/direct_results_haiku.json',
    'Claude 3.5 Haiku': f'{SAVE_DIR}/a3_claude_haiku35_20251228/results/direct_results_haiku35.json',
    'Claude 4 Sonnet': f'{SAVE_DIR}/a3_claude_sonnet4_20251228/results/direct_results_sonnet4.json',
    'GPT-3.5 Turbo': f'{SAVE_DIR}/a3_gpt_gpt35_20251228/results/direct_results_gpt35.json',
    'GPT-4o-mini': f'{SAVE_DIR}/a3_gpt_gpt4omini_20251228/results/direct_results_gpt4omini.json',
    'GPT-4o': f'{SAVE_DIR}/chatgpt_experiment_20251224/results/direct_results_gpt4o.json',
}

# all_direct maps model name -> {problem_index: is_correct}.
# A model whose file cannot be loaded or parsed is reported and left out of the mapping.
all_direct = {}
for model in DIRECT_RESULTS_PATHS:
    path = DIRECT_RESULTS_PATHS[model]
    try:
        data = load_json(path)
        _correct_by_problem = {}
        for _row in data:
            _correct_by_problem[_row['problem_index']] = _row['is_correct']
        all_direct[model] = _correct_by_problem

        _n_correct = sum(_row["is_correct"] for _row in data)
        _accuracy = _n_correct / len(data)
        print(f'✓ {model}: {len(data)} problems, {_accuracy:.1%} accuracy')
    except Exception as e:
        print(f'✗ {model}: {e}')

## 2. Identify Hard Problems

In [None]:
# Collect every problem index that at least one loaded model has a result for.
all_problem_ids = set()
for results in all_direct.values():
    all_problem_ids.update(results)

# For each problem, record how many models have a result for it (n_models)
# and how many of those results are incorrect (n_wrong).
problem_difficulty = {}
for pid in all_problem_ids:
    outcomes = [results[pid] for results in all_direct.values() if pid in results]
    problem_difficulty[pid] = {
        'n_wrong': sum(1 for is_correct in outcomes if not is_correct),
        'n_models': len(outcomes),
    }

# Create hard subset: problems where >= 4 models got WRONG
HARD_THRESHOLD = 4  # At least 4 models got it wrong

# Sorted so the subset (and the list saved to disk) has a stable, readable
# order instead of whatever order the set/dict iteration happens to give.
hard_problems = sorted(
    pid for pid, info in problem_difficulty.items()
    if info['n_wrong'] >= HARD_THRESHOLD
)

print(f'\nTotal problems: {len(all_problem_ids)}')
print(f'Hard problems (≥{HARD_THRESHOLD} models wrong): {len(hard_problems)}')

# Distribution: one row per possible wrong-count, from 0 up to the number of
# models that actually loaded (not a fixed 6).
print('\nDifficulty distribution:')
for n in range(len(all_direct) + 1):
    count = sum(1 for info in problem_difficulty.values() if info['n_wrong'] == n)
    print(f'  {n} models wrong: {count} problems')

In [None]:
# Report how Claude 4 Sonnet's direct accuracy on the hard subset compares with
# the 92.5% full-dataset figure. Skipped when its results were not loaded.
if 'Claude 4 Sonnet' in all_direct:
    _sonnet4_direct = all_direct['Claude 4 Sonnet']
    # A hard problem with no Sonnet result is counted as incorrect.
    sonnet4_hard = [_sonnet4_direct.get(pid, False) for pid in hard_problems]
    if sonnet4_hard:
        sonnet4_hard_acc = sum(sonnet4_hard) / len(sonnet4_hard)
    else:
        sonnet4_hard_acc = 0
    _drop_pp = 92.5 - sonnet4_hard_acc*100
    print(f'\nClaude 4 Sonnet on hard subset:')
    print(f'  Full dataset baseline: 92.5%')
    print(f'  Hard subset baseline: {sonnet4_hard_acc:.1%}')
    print(f'  → Baseline dropped by {_drop_pp:.1f}pp')

# Persist the hard-subset definition (threshold, size, indices) next to the results.
_hard_list_payload = {
    'threshold': HARD_THRESHOLD,
    'n_problems': len(hard_problems),
    'problem_indices': hard_problems,
}
save_json(_hard_list_payload, f'{SAVE_DIR_EXP}/hard_problem_list.json')

## 3. Configuration

In [None]:
from datasets import load_dataset

# Experiment constants.
GLOBAL_SEED = 20251224   # fed into derive_seed for every (problem, λ) pair
I_FIXED = 10             # step count passed to generate_mixed_cot in the run cell
LAMBDA_VALUES = [0.0, 0.4, 0.8, 1.0]  # 4 points for efficiency

# Models to test on hard subset: display name -> provider, API model id, file-name tag.
MODELS = {
    'Claude 4 Sonnet': dict(
        provider='anthropic',
        api_name='claude-sonnet-4-20250514',
        short='sonnet4',
    ),
    'GPT-4o': dict(
        provider='openai',
        api_name='gpt-4o',
        short='gpt4o',
    ),
    'Claude 3.5 Haiku': dict(
        provider='anthropic',
        api_name='claude-3-5-haiku-latest',
        short='haiku35',
    ),
}

_RULE = '='*60
# One pass over the hard subset per λ value, plus one for the direct baseline.
_n_conditions = len(LAMBDA_VALUES) + 1

print(_RULE)
print('E4: HARD SUBSET EXPERIMENT')
print(_RULE)
print(f'Hard problems: {len(hard_problems)}')
print(f'λ values: {LAMBDA_VALUES}')
print(f'Models: {list(MODELS.keys())}')
print(f'Total inferences per model: {len(hard_problems) * _n_conditions}')

## 4. API Setup

In [None]:
import getpass
from openai import OpenAI
import anthropic

def _read_api_key(banner, prompt):
    """Print *banner*, then read a key from the user without echoing it."""
    print(banner)
    return getpass.getpass(prompt)


# The banners are in Japanese ("Please enter your ... API key:").
OPENAI_API_KEY = _read_api_key("OpenAI APIキーを入力してください：", "OpenAI API Key: ")
ANTHROPIC_API_KEY = _read_api_key("\nAnthropic APIキーを入力してください：", "Anthropic API Key: ")

# One client per provider; call_api picks between them via model_config['provider'].
openai_client = OpenAI(api_key=OPENAI_API_KEY)
anthropic_client = anthropic.Anthropic(api_key=ANTHROPIC_API_KEY)

def call_api(prompt: str, model_config: dict, max_tokens: int = 256,
             max_retries: int = 3) -> str:
    """Send *prompt* as a single user message and return the model's text reply.

    Args:
        prompt: Text of the user message.
        model_config: Entry from MODELS; ``provider`` selects the client
            ('openai' uses the OpenAI client, anything else the Anthropic
            client) and ``api_name`` is the model id sent to the API.
        max_tokens: Upper bound on generated tokens.
        max_retries: Number of attempts before giving up.

    Returns:
        The reply text, or "" if every attempt raised or the reply had no text.
    """
    for attempt in range(max_retries):
        try:
            if model_config['provider'] == 'openai':
                response = openai_client.chat.completions.create(
                    model=model_config['api_name'],
                    messages=[{"role": "user", "content": prompt}],
                    max_tokens=max_tokens,
                    temperature=0,
                )
                # message.content may be None; keep the declared str return type.
                return response.choices[0].message.content or ""

            response = anthropic_client.messages.create(
                model=model_config['api_name'],
                max_tokens=max_tokens,
                # Same greedy setting as the OpenAI branch, so both providers
                # are sampled under the same conditions.
                temperature=0,
                messages=[{"role": "user", "content": prompt}],
            )
            return response.content[0].text
        except Exception as e:
            print(f'API error: {e}')
            # Exponential back-off (1s, 2s, ...); no sleep once nothing is left to retry.
            if attempt < max_retries - 1:
                time.sleep(2 ** attempt)
    return ""

# Smoke-test every configured model with a trivial prompt.
print('\nTesting APIs...')
for name, config in MODELS.items():
    resp = call_api("2+2=?", config)
    # call_api returns "" when all attempts failed, so an empty reply is a failure.
    status = 'OK' if resp else 'FAILED (empty response)'
    print(f'{name}: {status}')

## 5. Load GSM8K Hard Subset

In [None]:
@dataclass
class GSM8KProblem:
    """One GSM8K test item selected for the hard subset."""
    index: int  # index of the item in the loaded GSM8K test split
    question: str  # the item's 'question' field
    answer_text: str  # the item's full 'answer' field (reference solution text)
    final_answer: int  # integer parsed from the '#### <n>' marker in answer_text

def extract_final_answer(answer_text: str) -> int:
    """Return the integer that follows the '####' marker in *answer_text*.

    Thousands separators are removed and an optional leading minus sign is
    kept, so '#### 1,234' gives 1234 and '#### -5' gives -5.

    Raises:
        ValueError: If no '#### <number>' marker is present.
    """
    match = re.search(r'####\s*(-?[\d,]+)', answer_text)
    if match is None:
        raise ValueError('Could not extract')
    return int(match.group(1).replace(',', ''))

dataset = load_dataset('gsm8k', 'main', split='test')

# Load only hard problems. Items whose reference answer cannot be parsed are
# skipped, but reported, so a shrinking subset never goes unnoticed.
problems = []
skipped_indices = []
for idx in hard_problems:
    item = dataset[idx]
    try:
        final_ans = extract_final_answer(item['answer'])
    except ValueError:
        skipped_indices.append(idx)
        continue
    problems.append(GSM8KProblem(
        index=idx,
        question=item['question'],
        answer_text=item['answer'],
        final_answer=final_ans
    ))

print(f'Loaded {len(problems)} hard problems')
if skipped_indices:
    print(f'Skipped {len(skipped_indices)} problems with unparseable reference answers: {skipped_indices}')

## 6. CoT Functions

In [None]:
def generate_clean_cot(problem: GSM8KProblem, I: int) -> List[str]:
    """Build an I-step reasoning trace from the problem's reference solution.

    Each non-blank line of ``problem.answer_text`` (stripped) becomes a step,
    except lines starting with '####'. Short traces are padded with
    "Step k: Continue." filler; long ones are cut to the first I steps.
    """
    steps = []
    for raw_line in problem.answer_text.split('\n'):
        text = raw_line.strip()
        if text and not text.startswith('####'):
            steps.append(text)

    # Pad with numbered filler steps until there are I of them.
    for step_number in range(len(steps) + 1, I + 1):
        steps.append(f"Step {step_number}: Continue.")

    return steps[:I]

def generate_corrupted_step(problem, step_idx, rng):
    """Return a made-up arithmetic step labelled ``Step {step_idx+1}``.

    The operands and the stated result are drawn independently from *rng*, so
    the sentence is unrelated to *problem* (which is not read).
    """
    # Both candidate sentences are filled in before one is picked: every call
    # makes six randint draws followed by one choice.
    mul_a = rng.randint(10, 100)
    mul_b = rng.randint(2, 10)
    mul_result = rng.randint(100, 1000)
    add_a = rng.randint(50, 200)
    add_b = rng.randint(50, 200)
    add_result = rng.randint(100, 500)

    candidates = [
        f"Step {step_idx+1}: Calculate {mul_a} × {mul_b} = {mul_result}.",
        f"Step {step_idx+1}: Add {add_a} + {add_b} = {add_result}.",
    ]
    return rng.choice(candidates)

def generate_mixed_cot(problem, I, lam, seed):
    """Return an I-step trace in which round(I * lam) steps are corrupted.

    The clean trace comes from generate_clean_cot. The positions to corrupt
    and the corrupted text are both drawn from a ``random.Random(seed)``
    instance, so the same arguments always give the same trace.
    """
    rng = random.Random(seed)
    clean_steps = generate_clean_cot(problem, I)

    n_corrupt = int(round(I * lam))
    if n_corrupt > 0:
        corrupt_indices = set(rng.sample(range(I), n_corrupt))
    else:
        corrupt_indices = set()

    # Walk positions in order so corrupted steps draw from rng in ascending index order.
    mixed_steps = []
    for position in range(I):
        if position in corrupt_indices:
            mixed_steps.append(generate_corrupted_step(problem, position, rng))
        else:
            mixed_steps.append(clean_steps[position])
    return mixed_steps

def derive_seed(global_seed, problem_id, lam):
    """Derive a 32-bit seed from the (global seed, problem id, λ) triple.

    The triple is formatted as "<global_seed>|<problem_id>|lam=<lam>", hashed
    with SHA-256, and the first four digest bytes are read as a big-endian integer.
    """
    key = "|".join((f"{global_seed}", f"{problem_id}", f"lam={lam}"))
    digest = hashlib.sha256(key.encode()).digest()
    return int.from_bytes(digest[:4], 'big')

def create_direct_prompt(problem):
    """Build the no-reasoning prompt: instruction, the question, and the JSON reply format."""
    instruction = "Solve this math problem. Give ONLY the final numerical answer in JSON format."
    question_line = f"Problem: {problem.question}"
    reply_format = 'Reply with ONLY: {"final": <number>}'
    # The three sections are separated by a blank line.
    return "\n\n".join([instruction, question_line, reply_format])

def create_cot_prompt(problem, cot_steps):
    """Build the prompt that shows the question with a supplied reasoning trace.

    *cot_steps* are placed one per line under "Reasoning trace:", followed by
    the request for a JSON-formatted final answer.
    """
    trace_block = '\n'.join(cot_steps)
    prompt_lines = [
        "Here is a math problem with a provided reasoning trace.",
        "Follow the reasoning and give the final numerical answer in JSON format.",
        "",
        f"Problem: {problem.question}",
        "",
        "Reasoning trace:",
        trace_block,
        "",
        "Based on this reasoning, what is the final answer?",
        'Reply with ONLY: {"final": <number>}',
    ]
    return "\n".join(prompt_lines)

# A number as models tend to write it: optional minus sign, digits with optional
# thousands separators ("1,234"), optional decimal part.
_NUMBER = r'-?(?:\d{1,3}(?:,\d{3})+|\d+)(?:\.\d+)?'
# The value of a "final" field inside a JSON-like object (the value may be quoted).
_FINAL_FIELD_RE = re.compile(r'\{[^}]*"final"\s*:\s*"?(' + _NUMBER + r')(?!\d)')
# Any free-standing number: not glued to a preceding word character or '.',
# nor to a following word character (so "x2" and "2nd" are ignored).
_ANY_NUMBER_RE = re.compile(r'(?<![\w.])' + _NUMBER + r'(?!\w)')


def _to_number(token):
    """Convert a matched number token to int when integral, otherwise to float."""
    text = token.replace(',', '')
    if '.' not in text:
        return int(text)
    value = float(text)
    return int(value) if value.is_integer() else value


def parse_answer(response):
    """Extract the model's numeric answer from *response*.

    The value of a ``{"final": <number>}`` object is preferred; if none is
    present, the last free-standing number in the text is used.

    Returns:
        An int for integral values (e.g. ``42``, ``-7``, ``1,234`` -> 1234,
        ``12.0`` -> 12), a float for non-integral values (``12.5``), or None
        when *response* is not a string or contains no number.
    """
    if not isinstance(response, str):
        return None

    field_match = _FINAL_FIELD_RE.search(response)
    if field_match:
        return _to_number(field_match.group(1))

    loose_numbers = _ANY_NUMBER_RE.findall(response)
    if loose_numbers:
        return _to_number(loose_numbers[-1])
    return None

## 7. Run Experiment

In [None]:
# The "#@title" and "#@param" comments below are Colab form annotations, not
# ordinary comments: keep them, and keep "#@param" on the MODEL_CHOICE line.
# Select model
#@title Select Model { run: "auto" }
MODEL_CHOICE = "Claude 4 Sonnet" #@param ["Claude 4 Sonnet", "GPT-4o", "Claude 3.5 Haiku"]

# Look up the provider / API settings for the chosen model in MODELS.
model_config = MODELS[MODEL_CHOICE]
print(f'Running E4 for: {MODEL_CHOICE}')

In [None]:
print('='*60)
print(f'E4: HARD SUBSET EXPERIMENT - {MODEL_CHOICE}')
print('='*60)

all_results = []


def _record_result(problem, condition, lam, response):
    """Parse and grade *response*, then append one result row to all_results."""
    answer = parse_answer(response)
    # Compare against None explicitly: a parsed answer of 0 is a real answer
    # and must be graded, not treated as "no answer".
    is_correct = answer is not None and answer == problem.final_answer
    all_results.append({
        'problem_index': problem.index,
        'model': MODEL_CHOICE,
        'condition': condition,
        'lam': lam,
        'model_answer': answer,
        'correct_answer': problem.final_answer,
        'is_correct': is_correct,
        'subset': 'hard'
    })


# Direct condition: question only, no reasoning trace.
print('\n--- DIRECT (Baseline on Hard) ---')
for problem in tqdm(problems, desc='Direct'):
    prompt = create_direct_prompt(problem)
    response = call_api(prompt, model_config)
    _record_result(problem, 'direct', None, response)
    time.sleep(0.3)  # pause between API calls

# CoT conditions: one pass over the subset per corruption rate λ.
for lam in LAMBDA_VALUES:
    print(f'\n--- λ={lam} ---')

    for problem in tqdm(problems, desc=f'λ={lam}'):
        # The seed depends only on (GLOBAL_SEED, problem, λ), so reruns
        # rebuild exactly the same corrupted trace.
        seed = derive_seed(GLOBAL_SEED, problem.index, lam)
        cot_steps = generate_mixed_cot(problem, I_FIXED, lam, seed)
        prompt = create_cot_prompt(problem, cot_steps)
        response = call_api(prompt, model_config)
        _record_result(problem, 'cot', lam, response)
        time.sleep(0.5)  # pause between API calls

save_json(all_results, f"{SAVE_DIR_EXP}/results/e4_results_{model_config['short']}.json")
print('\n✓ E4 experiment complete!')

## 8. Analyze Results

In [None]:
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d
from scipy.optimize import brentq

df = pd.DataFrame(all_results)

# Baseline on hard subset: accuracy of the direct (no reasoning trace) condition.
baseline_hard = df[df['condition'] == 'direct']['is_correct'].mean()

# Accuracy by lambda: mean correctness of the CoT rows, keyed by corruption rate.
cot_df = df[df['condition'] == 'cot']
acc_by_lam = cot_df.groupby('lam')['is_correct'].mean().to_dict()

print('='*60)
print(f'E4 RESULTS: {MODEL_CHOICE} on Hard Subset')
print('='*60)
print(f'\nBaseline (Hard): {baseline_hard:.1%}')
# The 92.5% full-dataset figure belongs to Claude 4 Sonnet only, so the
# reference and the drop computed from it are shown for that model alone.
if MODEL_CHOICE == 'Claude 4 Sonnet':
    print(f'Baseline (Full): 92.5%')  # From main experiment
    print(f'Drop: {92.5 - baseline_hard*100:.1f}pp')
print('\nAccuracy by λ:')
for lam, acc in sorted(acc_by_lam.items()):
    # Flag every λ at which the provided trace scores below the direct baseline.
    marker = '← BACKFIRE' if acc < baseline_hard else ''
    print(f'  λ={lam}: {acc:.1%} {marker}')

In [None]:
# Estimate λ* on hard subset
def estimate_lambda_crit(lam_arr, acc_arr, baseline):
    """Estimate λ*, the corruption rate at which accuracy crosses *baseline*.

    The accuracy curve is treated as piecewise linear through the points
    (lam_arr[i], acc_arr[i]). The first place where it moves from one side of
    the baseline to the other is returned:

    * if two neighbouring points lie strictly on opposite sides, the crossing
      is found by linear interpolation between them;
    * if the curve passes through one or more points that sit exactly on the
      baseline on its way across, the λ of the first such point is returned.

    Args:
        lam_arr: λ values in increasing order.
        acc_arr: Accuracy at each λ (same length as lam_arr).
        baseline: Reference accuracy to compare against.

    Returns:
        The crossing λ as a float; 1.0 if the curve never crosses and its last
        point is above the baseline; None otherwise (including empty input).
    """
    lam_arr = np.asarray(lam_arr, dtype=float)
    diff = np.asarray(acc_arr, dtype=float) - baseline
    if diff.size == 0:
        return None

    prev = None         # index of the latest point strictly above/below baseline
    first_touch = None  # first exactly-on-baseline point seen since `prev`
    for i, d in enumerate(diff):
        if d == 0:
            # Points on the baseline before any off-baseline point carry no
            # direction information and are ignored.
            if prev is not None and first_touch is None:
                first_touch = i
            continue
        if prev is not None and d * diff[prev] < 0:
            if first_touch is not None:
                return float(lam_arr[first_touch])
            # Neighbouring points on opposite sides: root of the straight segment.
            left = diff[prev]
            span = lam_arr[i] - lam_arr[prev]
            return float(lam_arr[prev] + span * left / (left - d))
        prev, first_touch = i, None

    return 1.0 if diff[-1] > 0 else None

# Feed the measured curve (λ in ascending order) to the λ* estimator.
lam_arr = sorted(acc_by_lam.keys())
acc_arr = [acc_by_lam[lam_value] for lam_value in lam_arr]
lambda_crit_hard = estimate_lambda_crit(lam_arr, acc_arr, baseline_hard)

# Test against None rather than truthiness: λ* = 0.0 is a valid estimate.
if lambda_crit_hard is not None:
    print(f'\nλ* (Hard subset): {lambda_crit_hard:.3f}')
else:
    print('λ* not found')
print(f'λ* (Full dataset): ~0.4')

In [None]:
# Visualization: Compare full vs hard
fig, ax = plt.subplots(figsize=(10, 6))

# Hard subset: measured accuracy (in %) per λ, plus the direct baseline as a dashed line.
_hard_color = '#9467bd'
lams = sorted(acc_by_lam)
accs = [100 * acc_by_lam[lam_value] for lam_value in lams]
ax.plot(
    lams, accs, 'o-',
    color=_hard_color, linewidth=2.5, markersize=10,
    label=f'{MODEL_CHOICE} (Hard subset)',
)
ax.axhline(y=baseline_hard * 100, color=_hard_color, linestyle='--', alpha=0.7)

# Reference: Full dataset (hardcoded from main experiment); available for Claude 4 Sonnet only.
if MODEL_CHOICE == 'Claude 4 Sonnet':
    full_acc = {0.0: 98.0, 0.2: 97.0, 0.4: 92.5, 0.6: 90.5, 0.8: 91.5, 1.0: 88.4}
    full_baseline = 92.5
    ax.plot(
        list(full_acc), list(full_acc.values()), 's--',
        color='gray', linewidth=1.5, markersize=8, alpha=0.5,
        label='Full dataset (reference)',
    )
    ax.axhline(y=full_baseline, color='gray', linestyle=':', alpha=0.5)

# Labels, legend and axis cosmetics.
ax.set_xlabel('Corruption Rate (λ)', fontsize=13)
ax.set_ylabel('Accuracy (%)', fontsize=13)
ax.set_title(f'E4: Hard Subset Analysis - {MODEL_CHOICE}\n(n={len(problems)} hard problems)', fontsize=14)
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)
ax.set_xlim(-0.05, 1.05)

# Save the figure next to the results, then display it.
plt.tight_layout()
_figure_path = f"{SAVE_DIR_EXP}/e4_hard_subset_{model_config['short']}.png"
plt.savefig(_figure_path, dpi=300)
plt.show()

In [None]:
# Summary of this run; the *_full entries are reference values from the main experiment.
summary = {
    'experiment': 'E4_Hard_Subset',
    'model': MODEL_CHOICE,
    'date': EXPERIMENT_DATE,
    'n_hard_problems': len(problems),
    'hard_threshold': HARD_THRESHOLD,
    'baseline_hard': baseline_hard,
    'baseline_full': 0.925,  # Reference
    'accuracy_by_lambda': acc_by_lam,
    'lambda_crit_hard': lambda_crit_hard,
    'lambda_crit_full': 0.4,  # Reference
}

print('\n' + '='*60)
print('INTERPRETATION')
print('='*60)

# Test against None rather than truthiness: an estimate of λ* = 0.0 is a
# crossing too and must not fall through to "no clear backfire".
if lambda_crit_hard is None:
    print('\n→ No clear backfire on hard subset')
elif lambda_crit_hard < 0.8:
    print(f'\n→ λ* on hard subset ({lambda_crit_hard:.3f}) is similar to full dataset (~0.4)')
    print('→ Backfire effect is consistent, not due to ceiling effect')
else:
    print(f'\n→ λ* on hard subset ({lambda_crit_hard:.3f}) is HIGHER than full dataset')
    print('→ Ceiling effect may have contributed to early backfire on full dataset')

save_json(summary, f"{SAVE_DIR_EXP}/results/e4_summary_{model_config['short']}.json")