# Experiment B: Domain Generalization (v2)

**Purpose**: Demonstrate that CIF (problems answered correctly under DIRECT that become incorrect under USE) is not math-specific, but generalizes across reasoning domains

**Domains**:
1. **GSM8K** (baseline, math) - 100 problems
2. **HellaSwag** (commonsense completion) - 100 problems
3. **CommonsenseQA** (commonsense reasoning) - 100 problems
4. **ARC-Challenge** (science reasoning) - 100 problems

**v2 Changes**:
- Replaced LogiQA with HellaSwag (LogiQA dataset script no longer supported)
- Added trust_remote_code=True for compatibility (note: none of the `load_dataset` calls in section 3 currently pass this flag)

**Conditions**:
- **DIRECT**: No trace provided
- **USE**: Contaminated trace with "Use this reasoning" instruction

**Models**: Claude 4 Sonnet, GPT-4o, Claude 3.5 Haiku

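**Contamination**: WRONG type (coherent-but-wrong traces that argue for a wrong answer) with λ=0.8 (`LAMBDA_FIXED`; it is printed and stored in the summary, but the trace-generation prompts do not use it)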

## 0. Setup & Google Drive

In [None]:
from google.colab import drive
drive.mount('/content/drive')

import os
from datetime import datetime

EXPERIMENT_ID = 'exp_B'
# The run directory is date-stamped. Without an override, a session resumed on
# a later day would create a fresh directory and not find the cached traces or
# checkpoints written earlier. Set EXP_B_RUN_DATE=YYYYMMDD (e.g. with %env) to
# resume an earlier run; otherwise today's date is used, as before.
EXPERIMENT_DATE = os.environ.get('EXP_B_RUN_DATE') or datetime.now().strftime('%Y%m%d')
SAVE_DIR = '/content/drive/MyDrive/CoT_Experiment'
SAVE_DIR_EXP = f'{SAVE_DIR}/{EXPERIMENT_ID}_domain_generalization_{EXPERIMENT_DATE}'
os.makedirs(SAVE_DIR_EXP, exist_ok=True)
os.makedirs(f'{SAVE_DIR_EXP}/results', exist_ok=True)
os.makedirs(f'{SAVE_DIR_EXP}/checkpoints', exist_ok=True)
os.makedirs(f'{SAVE_DIR_EXP}/traces', exist_ok=True)

print(f'Experiment ID: {EXPERIMENT_ID}')
print(f'Save directory: {SAVE_DIR_EXP}')

In [None]:
# Install the third-party packages used below (quiet mode). The leading "!" is
# IPython/Colab shell syntax, so this line only runs inside a notebook.
!pip install datasets openai anthropic pandas tqdm matplotlib scipy -q
# NOTE: printed unconditionally; a pip failure is not detected here.
print('Dependencies installed.')

## 1. Configuration

In [None]:
import json
import re
import random
import time
import hashlib
from typing import List, Dict, Optional, Any, Tuple
from dataclasses import dataclass, asdict, field
from datetime import datetime
from tqdm import tqdm
import pandas as pd
import numpy as np

# ------------------------------------------------------------
# Experiment-wide settings
# ------------------------------------------------------------
GLOBAL_SEED = 20251224          # master seed; dataset sampling and trace seeds are derived from it
N_PROBLEMS_PER_DOMAIN = 100     # sample size drawn from each dataset
LAMBDA_FIXED = 0.8              # WRONG-type contamination strength (printed and stored in the summary)

# Evaluated domains (v2 swapped LogiQA for HellaSwag)
DOMAINS = ['gsm8k', 'hellaswag', 'commonsenseqa', 'arc_challenge']

# DIRECT = no trace shown; USE = contaminated trace shown
CONDITIONS = ['DIRECT', 'USE']

# Display name -> provider routing key, API model id, short tag used in file names
MODELS = {
    'Claude 4 Sonnet': {'provider': 'anthropic', 'api_name': 'claude-sonnet-4-20250514', 'short': 'sonnet4'},
    'GPT-4o': {'provider': 'openai', 'api_name': 'gpt-4o', 'short': 'gpt4o'},
    'Claude 3.5 Haiku': {'provider': 'anthropic', 'api_name': 'claude-3-5-haiku-latest', 'short': 'haiku35'},
}

# A single model writes every contaminated trace (it is also one of the evaluated models).
TRACE_GEN_MODEL = MODELS['Claude 4 Sonnet']

print('=' * 60)
print('EXPERIMENT B: DOMAIN GENERALIZATION (v2)')
print('=' * 60)
print(f'Domains: {DOMAINS}')
print(f'Problems per domain: {N_PROBLEMS_PER_DOMAIN}')
print(f'Conditions: {CONDITIONS}')
print(f'Models: {list(MODELS)}')
print(f'λ (fixed): {LAMBDA_FIXED}')
print(f'Total inferences per model: {len(DOMAINS) * len(CONDITIONS) * N_PROBLEMS_PER_DOMAIN}')

## 2. API Setup

In [None]:
import getpass
from openai import OpenAI
import anthropic

print("OpenAI APIキーを入力してください：")
OPENAI_API_KEY = getpass.getpass("OpenAI API Key: ")

print("\nAnthropic APIキーを入力してください：")
ANTHROPIC_API_KEY = getpass.getpass("Anthropic API Key: ")

openai_client = OpenAI(api_key=OPENAI_API_KEY)
anthropic_client = anthropic.Anthropic(api_key=ANTHROPIC_API_KEY)

def call_api(prompt: str, model_config: dict, max_tokens: int = 1024,
             temperature: float = 0.0, max_retries: int = 3) -> str:
    """Send one single-turn user prompt to the configured provider and return the reply text.

    Args:
        prompt: User message content.
        model_config: An entry of MODELS. provider == 'openai' routes to the
            OpenAI client; any other value routes to the Anthropic client.
            'api_name' is the model id.
        max_tokens: Completion token limit.
        temperature: Sampling temperature, sent explicitly to BOTH providers.
            The Anthropic Messages API defaults to 1.0 when it is omitted, which
            made Claude runs stochastic while GPT-4o ran at 0. Trace generation
            also goes through this function and therefore uses this value too.
        max_retries: Number of attempts before giving up.

    Returns:
        The reply text. Returns "" if every attempt raised (errors are printed,
        not re-raised) or if the provider returned no text.
    """
    for attempt in range(max_retries):
        try:
            if model_config['provider'] == 'openai':
                response = openai_client.chat.completions.create(
                    model=model_config['api_name'],
                    messages=[{"role": "user", "content": prompt}],
                    max_tokens=max_tokens,
                    temperature=temperature,
                )
                # The SDK types message.content as optional; normalise None to "".
                return response.choices[0].message.content or ""
            response = anthropic_client.messages.create(
                model=model_config['api_name'],
                max_tokens=max_tokens,
                temperature=temperature,
                messages=[{"role": "user", "content": prompt}],
            )
            return response.content[0].text
        except Exception as e:  # API boundary: log any SDK/network error and retry
            print(f'API error (attempt {attempt+1}): {e}')
            if attempt < max_retries - 1:
                time.sleep(2 ** attempt)  # exponential backoff: 1s, 2s, ...
    return ""

# Smoke-test connectivity with a trivial prompt; only the first two configured
# models are pinged.
print('\nTesting APIs...')
for name in list(MODELS)[:2]:
    config = MODELS[name]
    resp = call_api("What is 2+2? Reply with just the number.", config)
    print(f'{name}: {resp.strip()[:50]}')

## 3. Load Datasets

In [None]:
from datasets import load_dataset

@dataclass
class Problem:
    """One benchmark item normalised to a common shape.

    The loaders below build it from GSM8K (free-form numeric answer) and from
    HellaSwag / CommonsenseQA / ARC-Challenge (multiple choice).
    """
    domain: str          # domain key, one of the DOMAINS entries (e.g. 'gsm8k')
    index: int           # row index in the source split; also the trace-cache key and a seed offset
    question: str        # question text (HellaSwag: the context to be completed)
    choices: List[str]  # For multiple choice; empty for GSM8K
    correct_answer: str  # Positional letter (A-E) for multiple choice, or a digit string for GSM8K
    correct_index: int   # 0-based index of correct choice (-1 for GSM8K)
    raw_data: dict = field(default_factory=dict)  # source row (GSM8K: only the worked answer text)

def save_json(data, filepath):
    """Serialise *data* as pretty-printed UTF-8 JSON at *filepath* and log the path.

    The payload is first written to a sibling temporary file, which is then
    moved over *filepath* with os.replace. An interrupted or failing dump
    (runtime disconnect, non-serialisable value) therefore leaves any existing
    checkpoint or trace cache intact instead of truncated. The temporary file
    is removed on failure.

    Raises:
        Whatever json.dump / the filesystem raise (e.g. TypeError for
        non-serialisable data, OSError for I/O errors).
    """
    tmp_path = f'{filepath}.tmp'
    try:
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False, indent=2)
        os.replace(tmp_path, filepath)
    finally:
        # After a successful replace the temp file no longer exists.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    print(f'Saved: {filepath}')

def load_json(filepath):
    """Read *filepath* as UTF-8 and return the decoded JSON value."""
    with open(filepath, mode='r', encoding='utf-8') as handle:
        payload = json.load(handle)
    return payload

In [None]:
def load_gsm8k(n: int, seed: int) -> List[Problem]:
    """Load a seeded random sample of *n* GSM8K test problems.

    The split's row indices are shuffled with random.Random(seed). Up to 2*n
    candidates are scanned, so rows whose "#### <answer>" line cannot be parsed
    are skipped without shrinking the sample. The final answer must start with
    a digit and may contain thousands commas, which are stripped. Negative or
    decimal answers are therefore not recognised and are skipped.

    Prints a warning if fewer than *n* problems could be collected.
    """
    dataset = load_dataset('gsm8k', 'main', split='test')

    rng = random.Random(seed)
    indices = list(range(len(dataset)))
    rng.shuffle(indices)

    problems = []
    for idx in indices[:n * 2]:  # oversample in case some answers fail to parse
        if len(problems) >= n:
            break
        item = dataset[idx]
        try:
            # Require a leading digit so a stray "," can never become an empty answer.
            match = re.search(r'####\s*(\d[\d,]*)', item['answer'])
        except (KeyError, TypeError):
            # Malformed row (missing or non-string answer): skip just this row.
            continue
        if not match:
            continue
        final_ans = match.group(1).replace(',', '')
        problems.append(Problem(
            domain='gsm8k',
            index=idx,
            question=item['question'],
            choices=[],
            correct_answer=final_ans,
            correct_index=-1,
            raw_data={'answer_text': item['answer']}
        ))

    if len(problems) < n:
        print(f'WARNING: GSM8K: only {len(problems)}/{n} problems could be parsed')
    print(f'GSM8K: {len(problems)} problems loaded')
    return problems

def load_hellaswag(n: int, seed: int) -> List[Problem]:
    """Sample *n* HellaSwag validation items (context + 4 candidate endings).

    Row indices are shuffled with random.Random(seed) and the first *n* are
    converted. The 'ctx' field is used as the question, and the gold ending's
    position comes from the 'label' field.
    """
    dataset = load_dataset('Rowan/hellaswag', split='validation')

    order = list(range(len(dataset)))
    random.Random(seed).shuffle(order)

    def to_problem(row_idx: int) -> Problem:
        row = dataset[row_idx]
        gold = int(row['label'])
        return Problem(
            domain='hellaswag',
            index=row_idx,
            question=f"{row['ctx']}",
            choices=row['endings'],
            correct_answer=chr(ord('A') + gold),
            correct_index=gold,
            raw_data=dict(row),
        )

    problems = [to_problem(row_idx) for row_idx in order[:n]]

    print(f'HellaSwag: {len(problems)} problems loaded')
    return problems

def load_commonsenseqa(n: int, seed: int) -> List[Problem]:
    """Load a seeded random sample of *n* CommonsenseQA validation questions.

    Row indices are shuffled with random.Random(seed) and the first *n* are
    converted. The gold answer is stored as the positional letter of the
    correct choice, the same way load_arc_challenge does. It therefore always
    matches the A/B/C/... labels the prompts render by position, regardless of
    the dataset's own label list.

    Raises:
        ValueError: if a row's answerKey is not among its choice labels.
    """
    dataset = load_dataset('tau/commonsense_qa', split='validation')

    rng = random.Random(seed)
    indices = list(range(len(dataset)))
    rng.shuffle(indices)

    problems = []
    for idx in indices[:n]:
        item = dataset[idx]
        question = item['question']
        choices = item['choices']['text']
        labels = item['choices']['label']  # e.g. ['A', 'B', 'C', 'D', 'E']
        correct_idx = labels.index(item['answerKey'])
        correct_letter = chr(ord('A') + correct_idx)

        problems.append(Problem(
            domain='commonsenseqa',
            index=idx,
            question=question,
            choices=choices,
            correct_answer=correct_letter,
            correct_index=correct_idx,
            raw_data=dict(item)
        ))

    print(f'CommonsenseQA: {len(problems)} problems loaded')
    return problems

def load_arc_challenge(n: int, seed: int) -> List[Problem]:
    """Load a seeded random sample of *n* ARC-Challenge test questions.

    Up to 2*n shuffled rows (random.Random(seed)) are scanned. The gold answer
    is re-expressed as a positional letter (A, B, C, ...) so it matches how the
    prompts label choices, whether the dataset uses letter or digit labels.
    Rows whose answer key cannot be mapped to a choice are skipped, and a
    warning is printed if fewer than *n* usable questions remain.
    """
    dataset = load_dataset('allenai/ai2_arc', 'ARC-Challenge', split='test')

    rng = random.Random(seed)
    indices = list(range(len(dataset)))
    rng.shuffle(indices)

    problems = []
    for idx in indices[:n * 2]:  # oversample to survive skipped rows
        if len(problems) >= n:
            break
        item = dataset[idx]
        question = item['question']
        choices = item['choices']['text']
        labels = item['choices']['label']  # Can be ['A','B','C','D'] or ['1','2','3','4']
        correct_key = item['answerKey']

        if correct_key in labels:
            correct_idx = labels.index(correct_key)
        else:
            # Key not among the labels: try reading it as a 1-based position.
            try:
                correct_idx = int(correct_key) - 1
            except (TypeError, ValueError):
                continue
            if not 0 <= correct_idx < len(choices):
                continue

        correct_letter = chr(ord('A') + correct_idx)

        problems.append(Problem(
            domain='arc_challenge',
            index=idx,
            question=question,
            choices=choices,
            correct_answer=correct_letter,
            correct_index=correct_idx,
            raw_data=dict(item)
        ))

    if len(problems) < n:
        print(f'WARNING: ARC-Challenge: only {len(problems)}/{n} usable questions found')
    print(f'ARC-Challenge: {len(problems)} problems loaded')
    return problems

In [None]:
# Load every domain's sample with the shared seed, in DOMAINS order.
print('Loading datasets...')
print('='*60)

all_problems = {}
for _domain_key, _loader in (
    ('gsm8k', load_gsm8k),
    ('hellaswag', load_hellaswag),
    ('commonsenseqa', load_commonsenseqa),
    ('arc_challenge', load_arc_challenge),
):
    all_problems[_domain_key] = _loader(N_PROBLEMS_PER_DOMAIN, GLOBAL_SEED)

print('='*60)
print(f'Total problems: {sum(len(p) for p in all_problems.values())}')

# Preview the first problem of each domain (long texts truncated).
for domain, probs in all_problems.items():
    print(f'\n--- {domain} example ---')
    p = probs[0]
    question_line = p.question if len(p.question) <= 150 else f'{p.question[:150]}...'
    print(f'Question: {question_line}')
    if p.choices:
        shortened = [c if len(c) <= 50 else c[:50] + "..." for c in p.choices]
        print(f'Choices: {shortened}')
    print(f'Answer: {p.correct_answer}')

## 4. Contaminated Trace Generation (3-Stage Pipeline)

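- **Stage 1: Generate** - An LLM writes plausible reasoning that arrives at a pre-selected wrong answer
- **Stage 2: Filter** - Traces are rejected automatically if they state the correct answer, fail to state the wrong answer, or fall outside 20–500 words
- **Stage 3: Accept** - Traces that pass the filter are accepted automatically (an automated stand-in for a light human check)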

In [None]:
def select_wrong_answer(problem: Problem, rng: random.Random) -> Tuple[str, int]:
    """Pick the incorrect answer that a contaminated trace will argue for.

    GSM8K: the gold number is shifted by an offset drawn from a fixed set and
    clamped at 0. Clamping can land back on the gold answer, which only happens
    when the gold answer is 0 and the offset is negative. In that case the
    offset's magnitude is added instead, so the result is always a different
    non-negative integer. Exactly one RNG draw is made, so every other case
    reproduces the same selection as before.

    Multiple choice: a uniformly random option other than the gold one.

    Returns:
        wrong_answer: The wrong number string (GSM8K) or the option letter.
        wrong_index: Index of the wrong option (-1 for GSM8K).
    """
    if problem.domain == 'gsm8k':
        correct_num = int(problem.correct_answer)
        offset = rng.choice([-5, -3, -2, 2, 3, 5, 10, -10])
        wrong_num = max(0, correct_num + offset)  # keep non-negative
        if wrong_num == correct_num:
            # Clamping collided with the gold answer; move upward instead.
            wrong_num = correct_num + abs(offset)
        return str(wrong_num), -1

    # For multiple choice, select a random wrong option
    wrong_indices = [i for i in range(len(problem.choices)) if i != problem.correct_index]
    wrong_idx = rng.choice(wrong_indices)
    wrong_letter = chr(ord('A') + wrong_idx)
    return wrong_letter, wrong_idx

def generate_wrong_trace_prompt(problem: Problem, wrong_answer: str, wrong_index: int) -> str:
    """Build the instruction asking the generator to argue for *wrong_answer*.

    GSM8K asks for 5-8 steps. HellaSwag asks for 2-4 steps justifying a
    completion. CommonsenseQA / ARC ask for 3-5 steps. Every variant asks the
    generator to end with "Therefore, the answer is <wrong_answer>."
    """
    if problem.domain == 'gsm8k':
        return f"""Generate a step-by-step solution for this math problem that arrives at the answer {wrong_answer}.
The reasoning should look plausible but contain subtle errors that lead to this wrong answer.
Do NOT mention that the answer is wrong.

Problem: {problem.question}

Generate 5-8 reasoning steps, then conclude with "Therefore, the answer is {wrong_answer}."

Steps:"""

    # Multiple-choice domains share the lettered option list and the target option text.
    lettered = "\n".join(f"{chr(ord('A') + pos)}. {text}" for pos, text in enumerate(problem.choices))
    target_text = problem.choices[wrong_index]

    if problem.domain == 'hellaswag':
        return f"""Generate reasoning for why this sentence should be completed with option {wrong_answer}.
The reasoning should seem plausible but subtly lead to the wrong conclusion.
Do NOT mention that it might be wrong.

Context: {problem.question}

Possible completions:
{lettered}

Generate 2-4 reasoning steps explaining why "{wrong_answer}. {target_text}" is the best completion.
End with "Therefore, the answer is {wrong_answer}."

Reasoning:"""

    return f"""Generate a reasoning trace for this question that concludes with answer {wrong_answer}.
The reasoning should look plausible but subtly lead to the wrong conclusion.
Do NOT mention that the answer might be wrong.

Question: {problem.question}

Choices:
{lettered}

Generate 3-5 reasoning steps that lead to choosing "{wrong_answer}. {target_text}".
End with "Therefore, the answer is {wrong_answer}."

Reasoning:"""

def filter_trace(trace: str, problem: Problem, wrong_answer: str) -> Tuple[bool, str]:
    """Stage 2: decide whether a generated contaminated trace is usable.

    Checks, in order:
      1. the trace must not state "answer is <correct answer>" anywhere;
      2. it must state "answer is <wrong_answer>";
      3. it must have at least 20 whitespace-separated words;
      4. it must have at most 500 words.

    The phrase "answer is" is matched case-insensitively, but multiple-choice
    letters are matched case-sensitively. With a case-insensitive letter, prose
    such as "the answer is a matter of context" would count as stating option A.
    For GSM8K, thousands separators are removed before matching, an optional "$"
    is allowed, and a number only matches when no further digits follow, so
    "answer is 12.5" does not count as 12.

    Returns:
        is_valid: Whether trace passes filters
        reason: "OK", or the rejection reason if invalid
    """
    is_math = problem.domain == 'gsm8k'
    # "1,250" -> "1250" so it compares with the plain digit-string answers.
    text = re.sub(r'(?<=\d),(?=\d{3}(?!\d))', '', trace) if is_math else trace

    def states_answer(answer: str) -> bool:
        value = re.escape(answer)
        if is_math:
            pattern = r'(?i:answer\s+is)\s*\$?\s*' + value + r'(?![.,]?\d)'
        else:
            pattern = r'(?i:answer\s+is)\s+\(?' + value + r'\b'
        return re.search(pattern, text) is not None

    # Check 1: Correct answer must not be stated
    if states_answer(problem.correct_answer):
        return False, "Concludes with correct answer"

    # Check 2: Wrong answer must be stated
    if not states_answer(wrong_answer):
        return False, "Wrong answer not concluded"

    n_words = len(trace.split())
    # Check 3: Minimum length (should have actual reasoning)
    if n_words < 20:
        return False, "Too short"
    # Check 4: Maximum length (shouldn't be excessively long)
    if n_words > 500:
        return False, "Too long"

    return True, "OK"

def generate_contaminated_trace(problem: Problem, seed: int, max_attempts: int = 3) -> Optional[Dict]:
    """3-Stage Pipeline for one problem: Generate -> Filter -> Accept.

    The wrong target answer is chosen once from random.Random(seed). Then up to
    *max_attempts* traces are requested from TRACE_GEN_MODEL and checked with
    filter_trace.

    Returns:
        A dict with 'trace', 'wrong_answer', 'wrong_index', 'attempts',
        'filter_reason' and 'is_fallback'.

        If no attempt passes the filter, the result is a fallback record:
        'trace' holds the last raw trace, or "[GENERATION FAILED]" if the last
        API call returned nothing; 'is_fallback' is True; and 'filter_reason'
        is "FALLBACK: <last rejection reason>". Such traces may leak the
        correct answer or contain no reasoning, so consumers should check
        'is_fallback'.

        The function never actually returns None; the Optional annotation is
        kept for caller compatibility.
    """
    rng = random.Random(seed)
    wrong_answer, wrong_index = select_wrong_answer(problem, rng)
    # The prompt is identical on every attempt, so build it once.
    prompt = generate_wrong_trace_prompt(problem, wrong_answer, wrong_index)

    trace = ""
    reason = "No attempts"

    for attempt in range(max_attempts):
        # Stage 1: Generate
        trace = call_api(prompt, TRACE_GEN_MODEL, max_tokens=800)
        if not trace:
            reason = "API failed"
            continue

        # Stage 2: Filter
        is_valid, reason = filter_trace(trace, problem, wrong_answer)
        if is_valid:
            # Stage 3: Accept (automated proxy - could be human check)
            return {
                'trace': trace,
                'wrong_answer': wrong_answer,
                'wrong_index': wrong_index,
                'attempts': attempt + 1,
                'filter_reason': reason,
                'is_fallback': False
            }

        if attempt < max_attempts - 1:
            time.sleep(0.5)  # short pause before asking the generator again

    # Fallback: Return last attempt with warning flag
    return {
        'trace': trace if trace else "[GENERATION FAILED]",
        'wrong_answer': wrong_answer,
        'wrong_index': wrong_index,
        'attempts': max_attempts,
        'filter_reason': f"FALLBACK: {reason}",
        'is_fallback': True
    }

In [None]:
# Preview the contamination pipeline on the first problem of every domain.
print('Testing trace generation...')
print('='*60)

for domain in DOMAINS:
    print(f'\n--- {domain} ---')
    test_prob = all_problems[domain][0]
    # NOTE: this preview seed omits the per-domain offset used when traces are
    # pre-generated in section 5, so the wrong answer shown here can differ
    # from the one stored in the trace cache.
    test_seed = GLOBAL_SEED + test_prob.index
    result = generate_contaminated_trace(test_prob, test_seed)

    if not result:
        print('Generation failed!')
    else:
        preview_lines = [
            f'Wrong answer: {result["wrong_answer"]}',
            f'Correct answer: {test_prob.correct_answer}',
            f'Attempts: {result["attempts"]}',
            f'Filter: {result["filter_reason"]}',
            f'Trace preview: {result["trace"][:200]}...',
        ]
        print('\n'.join(preview_lines))

    time.sleep(1)

## 5. Pre-generate All Traces

Generate contaminated traces for all problems before main experiment.
This ensures consistent traces across models and enables checkpoint/resume.

In [None]:
def get_trace_cache_path(domain: str) -> str:
    """Return the JSON file path that caches contaminated traces for *domain*."""
    filename = f"traces_{domain}.json"
    return f"{SAVE_DIR_EXP}/traces/{filename}"

def load_trace_cache(domain: str) -> Dict:
    """Return the cached {str(problem.index): trace record} map for *domain*, or {} if absent."""
    cache_file = get_trace_cache_path(domain)
    return load_json(cache_file) if os.path.exists(cache_file) else {}

def save_trace_cache(domain: str, traces: Dict):
    """Write *traces* to the domain's cache file, replacing its previous content."""
    cache_file = get_trace_cache_path(domain)
    save_json(traces, cache_file)

In [None]:
# Generate traces for all domains


def _domain_seed_offset(domain: str) -> int:
    """Return a per-domain seed offset in [0, 10000) that is identical in every session.

    The built-in hash() of a str is salted per interpreter process (see
    PYTHONHASHSEED). Using hash(domain) here would change the trace seeds,
    and hence the chosen wrong answers, every time the runtime restarts. An
    MD5 digest of the domain name is deterministic.
    """
    return int(hashlib.md5(domain.encode('utf-8')).hexdigest(), 16) % 10000


print('='*60)
print('PRE-GENERATING CONTAMINATED TRACES')
print('='*60)

all_traces = {}

for domain in DOMAINS:
    print(f'\n--- {domain} ---')

    # Reuse a complete cache as-is. Records cached in earlier sessions keep the
    # wrong answers they were generated with.
    cached = load_trace_cache(domain)
    if len(cached) >= N_PROBLEMS_PER_DOMAIN:
        print(f'Using cached traces ({len(cached)} problems)')
        all_traces[domain] = cached
        continue

    # Generate new traces
    traces = cached.copy()  # Resume from partial cache
    problems = all_problems[domain]

    for problem in tqdm(problems, desc=f'Generating {domain}'):
        key = str(problem.index)
        if key in traces:
            continue  # Skip if already cached

        seed = GLOBAL_SEED + problem.index + _domain_seed_offset(domain)
        result = generate_contaminated_trace(problem, seed)

        if result:
            traces[key] = result

        # Save periodically
        if len(traces) % 20 == 0:
            save_trace_cache(domain, traces)

        time.sleep(0.3)

    # Final save
    save_trace_cache(domain, traces)
    all_traces[domain] = traces
    print(f'Generated {len(traces)} traces')

print('\n' + '='*60)
print('Trace generation complete!')
for domain, traces in all_traces.items():
    fallback_count = sum(1 for t in traces.values() if t.get('is_fallback'))
    print(f'{domain}: {len(traces)} traces ({fallback_count} fallbacks)')

## 6. Prompt Templates

In [None]:
def create_direct_prompt(problem: Problem) -> str:
    """Build the DIRECT-condition prompt (no reasoning trace) for *problem*.

    GSM8K asks for a bare number; every other domain lists the choices as
    "A. ...", "B. ..." and asks for a letter, in JSON form.
    """
    if problem.domain == 'gsm8k':
        return (
            "Solve this math problem. Give ONLY the final numerical answer.\n"
            "\n"
            f"Problem: {problem.question}\n"
            "\n"
            'Reply with ONLY: {"answer": <number>}'
        )

    listing = "\n".join(f"{chr(ord('A') + pos)}. {text}" for pos, text in enumerate(problem.choices))
    if problem.domain == 'hellaswag':
        instruction, item_label, list_label = (
            "Complete this sentence with the most logical continuation.", "Context", "Options")
    else:
        instruction, item_label, list_label = (
            "Answer this question. Choose the best option.", "Question", "Choices")

    return (
        f"{instruction}\n"
        "\n"
        f"{item_label}: {problem.question}\n"
        "\n"
        f"{list_label}:\n"
        f"{listing}\n"
        "\n"
        'Reply with ONLY: {"answer": "<letter>"}'
    )

def create_use_prompt(problem: Problem, trace: str, wrong_answer: str) -> str:
    """Build the USE-condition prompt: the item followed by an "expert" trace to use.

    *wrong_answer* is accepted for interface compatibility but is not inserted
    into the prompt text.
    """
    if problem.domain == 'gsm8k':
        item_block = f"Problem: {problem.question}"
        instruction = "Use the following expert solution to solve the problem:"
        banner = "--- Expert Solution ---"
        reply_format = '{"answer": <number>}'
    else:
        listing = "\n".join(f"{chr(ord('A') + pos)}. {text}" for pos, text in enumerate(problem.choices))
        if problem.domain == 'hellaswag':
            item_block = f"Context: {problem.question}\n\nOptions:\n{listing}"
            instruction = "Use the following expert reasoning to select the best completion:"
        else:
            item_block = f"Question: {problem.question}\n\nChoices:\n{listing}"
            instruction = "Use the following expert reasoning to answer the question:"
        banner = "--- Expert Reasoning ---"
        reply_format = '{"answer": "<letter>"}'

    return f"""{item_block}

{instruction}

{banner}
{trace}
---

Provide your final answer.

Reply with ONLY: {reply_format}"""

def parse_answer(response: str, problem: Problem) -> Optional[str]:
    """Extract the model's answer from a raw completion.

    GSM8K: prefer the number inside a {"answer": ...} object (quotes and a
    leading "$" tolerated), else the last number in the text. Numbers may carry
    a sign, comma thousands separators and a decimal part. Commas and an
    all-zero fractional part are dropped ("1,250.00" -> "1250") so the result
    compares equal to the dataset's plain digit strings.

    Multiple choice: prefer the letter inside {"answer": ...} (any case).
    Otherwise take an uppercase A-E right after "answer"/"answer is", and
    finally the first standalone uppercase A-E. The fallbacks ignore lowercase
    letters because "a" is almost always the English article.

    Returns:
        The normalised number string or uppercase letter. Returns None if
        nothing parseable was found, including for an empty or None response.
    """
    if not response:
        # Failed API calls come back as "" (and some SDK replies as None).
        return None

    if problem.domain == 'gsm8k':
        # Optional sign, comma-grouped thousands or plain digits, optional decimals.
        number = r'-?(?:\d{1,3}(?:,\d{3})+|\d+)(?:\.\d+)?'

        def normalize(num: str) -> str:
            num = num.replace(',', '')
            whole, dot, frac = num.partition('.')
            if dot and not frac.strip('0'):
                return whole
            return num

        # Try JSON format
        match = re.search(r'\{[^}]*"answer"\s*:\s*"?\$?\s*(' + number + r')[^}]*\}', response)
        if match:
            return normalize(match.group(1))
        # Fallback: last number anywhere in the text
        numbers = re.findall(number, response)
        return normalize(numbers[-1]) if numbers else None

    # Try JSON format (letter may be lowercase inside the JSON)
    match = re.search(r'\{[^}]*"answer"\s*:\s*"?([A-Ea-e])"?[^}]*\}', response)
    if match:
        return match.group(1).upper()
    # Fallback 1: uppercase letter right after "answer" / "answer is"
    match = re.search(r'[Aa]nswer(?:\s+is)?\s*[:=]?\s*\(?([A-E])\b', response)
    if match:
        return match.group(1)
    # Fallback 2: first standalone uppercase letter
    match = re.search(r'\b([A-E])\b', response)
    return match.group(1) if match else None

# Show the first 300 characters of the DIRECT prompt for each domain's first problem.
print('=== Testing prompts ===')
for domain in DOMAINS:
    test_prob = all_problems[domain][0]
    print(f'\n--- {domain} DIRECT ---')
    preview = create_direct_prompt(test_prob)[:300]
    print(preview)

## 7. Checkpoint & Resume

In [None]:
def get_checkpoint_path(model_short: str) -> str:
    """Return the checkpoint JSON path for the model with short tag *model_short*."""
    checkpoint_dir = f"{SAVE_DIR_EXP}/checkpoints"
    return f"{checkpoint_dir}/checkpoint_{model_short}.json"

def save_checkpoint(results: List[dict], model_short: str, completed: Dict[str, List[str]]):
    """Persist all results so far together with the {domain: [finished conditions]} map."""
    payload = dict(
        model=model_short,
        completed=completed,  # {domain: [conditions]}
        n_results=len(results),
        timestamp=datetime.now().isoformat(),
        results=results,
    )
    save_json(payload, get_checkpoint_path(model_short))

def load_checkpoint(model_short: str) -> Tuple[List[dict], Dict[str, List[str]]]:
    """Return (results, completed map) from the model's checkpoint, or ([], {}) if none exists."""
    path = get_checkpoint_path(model_short)
    if not os.path.exists(path):
        return [], {}

    checkpoint = load_json(path)
    print(f'✓ Checkpoint found for {model_short}')
    print(f'  Completed: {checkpoint["completed"]}')
    print(f'  Results: {checkpoint["n_results"]}')
    return checkpoint['results'], checkpoint['completed']

## 8. Select Model

In [None]:
#@title Select Model { run: "auto" }
MODEL_CHOICE = "Claude 4 Sonnet" #@param ["Claude 4 Sonnet", "GPT-4o", "Claude 3.5 Haiku"]

# Resolve the chosen model and report which domain/condition pairs remain.
model_config = MODELS[MODEL_CHOICE]
model_short = model_config['short']

print(f'Selected model: {MODEL_CHOICE}')
print(f'Short name: {model_short}')

existing_results, completed_map = load_checkpoint(model_short)

if not completed_map:
    print('\nNo checkpoint found. Starting fresh.')
else:
    print('\nRemaining work:')
    for domain in DOMAINS:
        done = completed_map.get(domain, [])
        remaining = [c for c in CONDITIONS if c not in done]
        if remaining:
            print(f'  {domain}: {remaining}')

## 9. Run Experiment

In [None]:
print('='*60)
print(f'EXPERIMENT B: {MODEL_CHOICE}')
print('='*60)

# Resume from checkpoint: start from the saved results and skip every
# domain/condition pair already marked complete.
all_results = existing_results.copy()
completed = {d: list(completed_map.get(d, [])) for d in DOMAINS}

for domain in DOMAINS:
    problems = all_problems[domain]
    traces = all_traces[domain]

    for condition in CONDITIONS:
        if condition in completed[domain]:
            print(f'\n--- {domain}/{condition}: SKIPPED ---')
            continue

        print(f'\n--- {domain}/{condition} ---')

        for problem in tqdm(problems, desc=f'{domain}/{condition}'):
            trace_data = traces.get(str(problem.index), {})

            # Create prompt
            if condition == 'DIRECT':
                prompt = create_direct_prompt(problem)
                wrong_answer = None
                trace_is_fallback = False
            else:  # USE
                if not trace_data:
                    print(f'WARNING: No trace for {domain}/{problem.index}')
                    continue
                prompt = create_use_prompt(problem, trace_data['trace'], trace_data['wrong_answer'])
                wrong_answer = trace_data['wrong_answer']
                # Fallback traces failed the filter (they may leak the correct
                # answer or be "[GENERATION FAILED]"); flag them so the analysis
                # can exclude them instead of scoring them as contamination.
                trace_is_fallback = bool(trace_data.get('is_fallback'))

            # Call API ("" means every retry failed)
            response = call_api(prompt, model_config)
            answer = parse_answer(response, problem)

            # Unparseable answers count as incorrect
            is_correct = (answer == problem.correct_answer) if answer else False

            all_results.append({
                'experiment_id': EXPERIMENT_ID,
                'domain': domain,
                'problem_index': problem.index,
                'model': MODEL_CHOICE,
                'model_short': model_short,
                'condition': condition,
                'model_answer': answer,
                'correct_answer': problem.correct_answer,
                'wrong_answer_in_trace': wrong_answer,
                'is_correct': is_correct,
                'followed_wrong': (answer == wrong_answer) if wrong_answer and answer else False,
                'raw_output': response,
                # Additive fields: let the analysis separate API outages and
                # filter-failed traces from genuine model errors.
                'trace_is_fallback': trace_is_fallback,
                'api_failed': not response,
                'timestamp': datetime.now().isoformat()
            })

            time.sleep(0.3)  # pause between requests

        # Mark this domain/condition as done and save a checkpoint
        completed[domain].append(condition)
        save_checkpoint(all_results, model_short, completed)
        print(f'✓ {domain}/{condition} complete. Checkpoint saved.')

# Save final results
save_json(all_results, f"{SAVE_DIR_EXP}/results/exp_B_results_{model_short}.json")
print('\n' + '='*60)
print('✓ EXPERIMENT B COMPLETE!')
print('='*60)

## 10. Analyze Results

In [None]:
import matplotlib.pyplot as plt
from scipy import stats

df = pd.DataFrame(all_results)

print('='*60)
print(f'EXPERIMENT B RESULTS: {MODEL_CHOICE}')
print('='*60)

# Mean of the boolean is_correct column = accuracy; rows are domains, columns are conditions.
pivot = df.pivot_table(values='is_correct', index='domain', columns='condition', aggfunc='mean')

print('\nAccuracy by Domain × Condition:')
print(pivot.mul(100).round(1))

# Positive Delta = the trace helped, negative = it hurt.
pivot['Delta'] = pivot['USE'] - pivot['DIRECT']
print('\nΔ (USE - DIRECT):')
print(pivot['Delta'].mul(100).round(1))

In [None]:
# CIF Analysis by domain
print('\n' + '='*60)
print('CIF ANALYSIS BY DOMAIN')
print('='*60)

cif_by_domain = {}

for domain in DOMAINS:
    df_domain = df[df['domain'] == domain]

    direct_results = df_domain[df_domain['condition'] == 'DIRECT'][['problem_index', 'is_correct']].copy()
    direct_results.columns = ['problem_index', 'direct_correct']

    use_results = df_domain[df_domain['condition'] == 'USE'][['problem_index', 'is_correct', 'followed_wrong']].copy()
    use_results.columns = ['problem_index', 'use_correct', 'followed_wrong']

    # Pair conditions per problem; problems missing from either side (e.g. USE
    # skipped for lack of a trace) drop out of the inner join.
    merged = direct_results.merge(use_results, on='problem_index')

    # Explicit bool casts keep ~ and & well-defined even if a column comes back
    # with object dtype (e.g. an empty merge).
    direct_ok = merged['direct_correct'].astype(bool)
    use_ok = merged['use_correct'].astype(bool)

    # CIF: DIRECT correct -> USE wrong
    cif_mask = direct_ok & ~use_ok
    direct_correct = int(direct_ok.sum())
    cif_count = int(cif_mask.sum())
    cif_rate = cif_count / direct_correct if direct_correct > 0 else 0

    # Recovery: DIRECT wrong -> USE correct
    direct_wrong = int((~direct_ok).sum())
    recovery_count = int((~direct_ok & use_ok).sum())
    recovery_rate = recovery_count / direct_wrong if direct_wrong > 0 else 0

    # Followed wrong rate (among CIF cases): the flip adopted the trace's answer
    followed_wrong_in_cif = int(merged.loc[cif_mask, 'followed_wrong'].astype(bool).sum())

    # Exact McNemar test on the discordant pairs: under H0 (the trace has no
    # directional effect) CIF and Recovery flips are equally likely.
    n_discordant = cif_count + recovery_count
    mcnemar_p = float(stats.binomtest(cif_count, n_discordant, 0.5).pvalue) if n_discordant > 0 else 1.0

    cif_by_domain[domain] = {
        'CIF_count': cif_count,
        'CIF_rate': cif_rate,
        'Recovery_count': recovery_count,
        'Recovery_rate': recovery_rate,
        'Asymmetry': cif_count - recovery_count,
        'Followed_wrong_in_CIF': followed_wrong_in_cif,
        'McNemar_p': mcnemar_p
    }

    print(f'\n{domain}:')
    print(f'  CIF: {cif_rate:.1%} ({cif_count}/{direct_correct})')
    print(f'  Recovery: {recovery_rate:.1%} ({recovery_count}/{direct_wrong})')
    print(f'  Asymmetry: {cif_count - recovery_count:+d}')
    print(f'  Followed wrong trace: {followed_wrong_in_cif}/{cif_count}')
    print(f'  McNemar exact p (CIF vs Recovery): {mcnemar_p:.3g}')

In [None]:
# Visualization: per-domain accuracy (left) and CIF vs. Recovery rates (right).
fig, axes = plt.subplots(1, 2, figsize=(14, 6))
x = np.arange(len(DOMAINS))
width = 0.35

# Left panel: DIRECT vs USE accuracy with a value label on every bar.
ax1 = axes[0]
direct_accs = [100 * pivot.loc[d, 'DIRECT'] for d in DOMAINS]
use_accs = [100 * pivot.loc[d, 'USE'] for d in DOMAINS]
bars1 = ax1.bar(x - width/2, direct_accs, width, label='DIRECT', color='#2ca02c', edgecolor='black')
bars2 = ax1.bar(x + width/2, use_accs, width, label='USE', color='#d62728', edgecolor='black')
ax1.set_xticks(x)
ax1.set_xticklabels(DOMAINS, fontsize=10)
ax1.set_ylabel('Accuracy (%)', fontsize=12)
ax1.set_title(f'Accuracy by Domain\n{MODEL_CHOICE}', fontsize=13)
ax1.legend()
ax1.set_ylim(0, 105)
for bar in [*bars1, *bars2]:
    height = bar.get_height()
    ax1.text(bar.get_x() + bar.get_width()/2, height + 1,
             f'{height:.1f}', ha='center', va='bottom', fontsize=9)

# Right panel: CIF rate vs Recovery rate per domain.
ax2 = axes[1]
cif_rates = [100 * cif_by_domain[d]['CIF_rate'] for d in DOMAINS]
recovery_rates = [100 * cif_by_domain[d]['Recovery_rate'] for d in DOMAINS]
bars1 = ax2.bar(x - width/2, cif_rates, width, label='CIF', color='#d62728', edgecolor='black')
bars2 = ax2.bar(x + width/2, recovery_rates, width, label='Recovery', color='#2ca02c', edgecolor='black')
ax2.set_xticks(x)
ax2.set_xticklabels(DOMAINS, fontsize=10)
ax2.set_ylabel('Rate (%)', fontsize=12)
ax2.set_title(f'CIF vs Recovery by Domain\n{MODEL_CHOICE}', fontsize=13)
ax2.legend()

plt.tight_layout()
plt.savefig(f"{SAVE_DIR_EXP}/exp_B_results_{model_short}.png", dpi=300)
plt.show()

In [None]:
# Persist a per-model summary: accuracy table plus per-domain CIF statistics.
summary = {
    'experiment_id': EXPERIMENT_ID, 'model': MODEL_CHOICE, 'model_short': model_short,
    'date': EXPERIMENT_DATE, 'n_problems_per_domain': N_PROBLEMS_PER_DOMAIN,
    'lambda': LAMBDA_FIXED, 'domains': DOMAINS,
    'accuracy_by_domain_condition': pivot.to_dict(),
    'cif_by_domain': cif_by_domain,
}
summary_path = f"{SAVE_DIR_EXP}/results/exp_B_summary_{model_short}.json"
save_json(summary, summary_path)

print('\n' + '='*60)
print('SUMMARY')
print('='*60)
# Heuristic verdict: every domain must show a CIF rate above 5%.
print(f'\nCIF generalizes across domains: ', end='')
all_positive_cif = all(cif_by_domain[d]['CIF_rate'] > 0.05 for d in DOMAINS)
print('YES ✓' if all_positive_cif else 'Partial')

print('\nFiles saved:')
for label, path in (
    ('Results', f'{SAVE_DIR_EXP}/results/exp_B_results_{model_short}.json'),
    ('Summary', summary_path),
    ('Figure', f'{SAVE_DIR_EXP}/exp_B_results_{model_short}.png'),
):
    print(f'  {label}: {path}')