# CoT A1-E7: Sequential Degradation

## Purpose
Test whether **repeated exposure** to contaminated traces causes cumulative degradation.

## Hypothesis
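- Single exposure: baseline CIF rate (CIF = a problem answered correctly in R0 but answered incorrectly once a contaminated trace is shown)
- Multiple exposures: increasing CIF rate from round to round (cumulative effect)
- Alternative: no cumulative effect (each exposure is independent)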

## Design
| Round | Condition |
|-------|----------|
| R0 | DIRECT (baseline) |
| R1 | First contaminated trace |
| R2 | Second contaminated trace (different wrong answer) |
| R3 | Third contaminated trace |

## Key Question
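Does prior exposure to contaminated reasoning make models more susceptible to future contamination?

Note: as implemented below, every round is sent as an independent single-turn request (the question plus one trace); no conversation history is carried from one round to the next.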

In [None]:
# ============================================================
# CELL 1: SETUP & DIRECTORIES
# ============================================================
from google.colab import drive
drive.mount('/content/drive')

import os
from datetime import datetime

# Identity of this run. The date stamp becomes part of the output folder name.
EXPERIMENT_ID = 'A1_E7'
EXPERIMENT_DATE = datetime.now().strftime('%Y%m%d')

# Output tree on the mounted Drive.
SAVE_DIR = '/content/drive/MyDrive/CoT_Experiment'
SAVE_DIR_EXP = f'{SAVE_DIR}/exp_{EXPERIMENT_ID}_sequential_{EXPERIMENT_DATE}'

# NOTE(review): because the folder name embeds today's date, re-running the
# notebook on a later date points at a new, empty folder, so checkpoints and
# traces written under an earlier date are not picked up.
# The empty suffix creates the experiment folder itself; the rest are its subfolders.
for _subdir in ('', '/results', '/checkpoints', '/traces'):
    os.makedirs(f'{SAVE_DIR_EXP}{_subdir}', exist_ok=True)

print(f'Experiment ID: {EXPERIMENT_ID}')
print(f'Save directory: {SAVE_DIR_EXP}')

In [None]:
# ============================================================
# CELL 2: INSTALL DEPENDENCIES
# ============================================================
# The leading `!` is an IPython shell escape, so this line only runs inside a
# notebook (Colab/Jupyter), not as a plain Python script. `-q` keeps pip quiet.
!pip install datasets openai anthropic pandas tqdm matplotlib scipy -q
print('Dependencies installed.')

In [None]:
# ============================================================
# CELL 3: IMPORTS & CONFIGURATION
# ============================================================
import json
import re
import random
import time
from typing import List, Dict, Optional, Tuple
from dataclasses import dataclass
from tqdm import tqdm
import pandas as pd
import numpy as np
from scipy import stats

# Experiment-wide constants
GLOBAL_SEED = 20260120  # seed for the reproducible problem sample
N_PROBLEMS = 100  # number of problems drawn from the dataset
N_ROUNDS = 4  # R0 (direct) + R1, R2, R3 (contaminated)

# Models under test, keyed by display name. Each entry carries the provider
# switch used when calling the API, the provider-side model id, and a short
# key used for file names and result dictionaries.
MODELS = {
    'Claude 4 Sonnet': dict(
        provider='anthropic',
        api_name='claude-sonnet-4-20250514',
        short='sonnet4',
    ),
    'GPT-4o': dict(
        provider='openai',
        api_name='gpt-4o',
        short='gpt4o',
    ),
}

# Run banner
for _line in (
    '=' * 60,
    'EXPERIMENT A1-E7: SEQUENTIAL DEGRADATION',
    '=' * 60,
    f'Models: {list(MODELS.keys())}',
    f'Problems: {N_PROBLEMS}',
    f'Rounds: {N_ROUNDS} (R0=direct, R1-R3=contaminated)',
    f'Total API calls per model: {N_PROBLEMS * N_ROUNDS}',
):
    print(_line)

In [None]:
# ============================================================
# CELL 4: UTILITY FUNCTIONS
# ============================================================
def convert_to_native(obj):
    """Recursively convert numpy/pandas values into JSON-friendly Python objects.

    - dict keys are stringified and values converted recursively;
    - lists, tuples and numpy arrays become lists of converted items;
    - numpy integer / floating / bool scalars become int / float / bool;
    - missing scalars (None, NaN from Python or numpy, pandas NA/NaT) become None;
    - anything else is returned unchanged.
    """
    if isinstance(obj, dict):
        return {str(k): convert_to_native(v) for k, v in obj.items()}
    # Tuples are handled like lists so numpy scalars nested inside them are
    # converted too (json would otherwise fail on them).
    if isinstance(obj, (list, tuple)):
        return [convert_to_native(v) for v in obj]
    if isinstance(obj, np.ndarray):
        # Re-run the conversion on the list so NaN entries are mapped to None
        # the same way as everywhere else.
        return convert_to_native(obj.tolist())
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        value = float(obj)
        # NaN is the only float that differs from itself; emit None instead of
        # the non-standard JSON token `NaN`.
        return None if value != value else value
    # Only ask pd.isna about scalars: for array-likes it returns an array,
    # whose truth value is ambiguous.
    if pd.api.types.is_scalar(obj) and pd.isna(obj):
        return None
    return obj

def save_json(data, filepath):
    """Write *data* to *filepath* as pretty-printed UTF-8 JSON.

    The data is first passed through convert_to_native and fully serialized
    in memory, then written to a sibling ``.tmp`` file that replaces the
    target. A serialization error or an interrupted write therefore cannot
    leave a truncated file where a previous checkpoint used to be.
    """
    payload = json.dumps(convert_to_native(data), ensure_ascii=False, indent=2)
    tmp_path = f'{filepath}.tmp'
    with open(tmp_path, 'w', encoding='utf-8') as f:
        f.write(payload)
    # os.replace overwrites an existing target in a single rename step.
    os.replace(tmp_path, filepath)
    print(f'Saved: {filepath}')

def load_json(filepath):
    """Return the parsed JSON content of *filepath*, or None if the path does not exist."""
    if not os.path.exists(filepath):
        return None
    with open(filepath, 'r', encoding='utf-8') as handle:
        return json.load(handle)

print('Utility functions defined.')

In [None]:
# ============================================================
# CELL 5: API SETUP
# ============================================================
import getpass
from openai import OpenAI
import anthropic

# Ask for the OpenAI key (the printed Japanese text means "Please enter your
# OpenAI API key"); getpass is used so the key is not echoed in the output.
print("OpenAI APIキーを入力してください：")
OPENAI_API_KEY = getpass.getpass("OpenAI API Key: ")

# Same for the Anthropic key ("Please enter your Anthropic API key").
print("\nAnthropic APIキーを入力してください：")
ANTHROPIC_API_KEY = getpass.getpass("Anthropic API Key: ")

# Module-level clients used by call_api.
openai_client = OpenAI(api_key=OPENAI_API_KEY)
anthropic_client = anthropic.Anthropic(api_key=ANTHROPIC_API_KEY)

def call_api(prompt: str, model_config: dict, max_tokens: int = 512) -> str:
    """Send *prompt* as a single user message and return the reply text.

    Args:
        prompt: Full text of the user message.
        model_config: Entry with a 'provider' key ('openai' selects the OpenAI
            client, anything else the Anthropic client) and an 'api_name' key
            holding the provider-side model id.
        max_tokens: Upper bound on generated tokens.

    Returns:
        The reply text, or "" if all attempts failed or the reply had no text.

    Both providers are called with temperature 0 so the two models are
    sampled under the same setting. Errors are printed and retried up to
    three times with exponential backoff (1s, 2s between attempts).
    """
    max_attempts = 3
    for attempt in range(max_attempts):
        try:
            if model_config['provider'] == 'openai':
                response = openai_client.chat.completions.create(
                    model=model_config['api_name'],
                    messages=[{"role": "user", "content": prompt}],
                    max_tokens=max_tokens,
                    temperature=0
                )
                text = response.choices[0].message.content
            else:
                response = anthropic_client.messages.create(
                    model=model_config['api_name'],
                    max_tokens=max_tokens,
                    temperature=0,
                    messages=[{"role": "user", "content": prompt}]
                )
                text = response.content[0].text
            # The content can be missing; honour the declared `-> str`.
            return text or ""
        except Exception as e:
            # Deliberately broad: any client/network error is reported and retried.
            print(f'API error (attempt {attempt+1}): {e}')
            # No point sleeping after the final attempt.
            if attempt < max_attempts - 1:
                time.sleep(2 ** attempt)
    return ""

# Smoke test: send one trivial question to every configured model.
# call_api returns "" after repeated failures, so an empty value after the
# model name means the request did not succeed.
print('\nTesting APIs...')
for name, config in MODELS.items():
    resp = call_api("What is 2+2? Reply with just the number.", config)
    print(f'{name}: {resp.strip()}')

In [None]:
# ============================================================
# CELL 6: LOAD DATASET
# ============================================================
from datasets import load_dataset

# Load the test split of the 'main' configuration of openai/gsm8k and report its size.
print('Loading GSM8K...')
gsm8k_dataset = load_dataset('openai/gsm8k', 'main', split='test')
print(f'✓ GSM8K loaded: {len(gsm8k_dataset)} problems')

def extract_gsm8k_answer(answer_text: str) -> str:
    """Return the number that follows the '####' marker, without commas.

    Returns "" when no '####' followed by digits is found.

    NOTE(review): the pattern accepts only digits and commas, so a negative
    value yields "" and a decimal value is cut at the decimal point.
    """
    found = re.search(r'####\s*([\d,]+)', answer_text)
    return found.group(1).replace(',', '') if found else ""

# Reproducible sample: shuffle every dataset index with a seeded RNG and keep
# the first N_PROBLEMS of the shuffled order.
rng = random.Random(GLOBAL_SEED)
indices = list(range(len(gsm8k_dataset)))
rng.shuffle(indices)
selected_indices = indices[:N_PROBLEMS]

# Keep only the problems whose reference answer could be parsed, so the final
# list can be shorter than N_PROBLEMS.
problems = []
for idx in selected_indices:
    item = gsm8k_dataset[idx]
    answer = extract_gsm8k_answer(item['answer'])
    if not answer:
        continue
    problems.append(dict(
        idx=idx,
        question=item['question'],
        answer_text=item['answer'],
        final_answer=answer,
    ))

print(f'\n✓ Selected {len(problems)} problems')

In [None]:
# ============================================================
# CELL 7: GENERATE MULTIPLE CONTAMINATED TRACES PER PROBLEM
# ============================================================

def generate_wrong_trace(question: str, correct_answer: str,
                         avoid_answers: Optional[List[str]] = None,
                         variation: int = 1) -> Dict:
    """Ask the trace-generating model for a step-by-step solution with one planted error.

    Args:
        question: Problem text.
        correct_answer: Reference answer as a digit string; the trace must not
            end on it.
        avoid_answers: Wrong answers already used for this problem; the new
            trace must not end on any of them either.
        variation: 1-based index that selects the kind of error requested and
            shifts the fallback answer.

    Returns:
        Dict with 'trace' (the solution text), 'wrong_answer' (the number the
        trace concludes with) and 'correct_answer'.

    If the model's final number cannot be parsed, equals the correct answer,
    or repeats an avoided answer, a substitute wrong answer is derived from
    the correct one. The substitute is drawn from an RNG seeded with
    GLOBAL_SEED, the variation and the question so re-runs reproduce it, it
    is bumped until it collides with nothing forbidden, and it is written
    into the trace so 'trace' and 'wrong_answer' always agree.
    """
    avoid_answers = list(avoid_answers or [])

    avoid_str = ""
    if avoid_answers:
        avoid_str = f"Do NOT arrive at any of these: {', '.join(avoid_answers)}\n"

    error_types = [
        "a simple arithmetic error (like adding instead of subtracting)",
        "a logical mistake in understanding the problem",
        "an error in the order of operations",
        "a mistake in setting up the equation"
    ]
    error_type = error_types[(variation - 1) % len(error_types)]

    prompt = f"""Solve this math problem step by step, but make {error_type} that leads to a wrong answer.

Problem: {question}

Requirements:
1. Show step-by-step reasoning
2. Make ONE plausible error
3. End with "Therefore, the answer is [NUMBER]."
4. Do NOT get {correct_answer} (the correct answer)
{avoid_str}
Solution:"""

    trace = call_api(prompt, MODELS['Claude 4 Sonnet'], max_tokens=1500)
    if not trace:
        print('Warning: trace generation returned no text; using a bare fallback conclusion')

    answer_pattern = r'answer is\s*[\$]?([\d,]+)'
    match = re.search(answer_pattern, trace, re.IGNORECASE)
    wrong_answer = match.group(1).replace(',', '') if match else ""

    # Ensure the answer is present, wrong, and different from avoided ones.
    forbidden = {correct_answer, *avoid_answers}
    if not wrong_answer or wrong_answer in forbidden:
        fallback_rng = random.Random(f'{GLOBAL_SEED}:{variation}:{question}')
        try:
            candidate = abs(int(correct_answer) + 10 * variation + fallback_rng.randint(1, 20))
        except ValueError:
            # Reference answer is not an integer string; start from a fixed value.
            candidate = 999
        # Offset ranges of neighbouring variations overlap, so step past any
        # value that is already taken.
        while str(candidate) in forbidden:
            candidate += 1
        wrong_answer = str(candidate)

        if match:
            trace = re.sub(answer_pattern,
                           f'answer is {wrong_answer}',
                           trace, flags=re.IGNORECASE)
        else:
            # No "answer is N" sentence to rewrite: append one instead.
            trace = f'{trace.rstrip()}\n\nTherefore, the answer is {wrong_answer}.'.lstrip()

    return {'trace': trace, 'wrong_answer': wrong_answer, 'correct_answer': correct_answer}

# Load previously generated traces (if any) so an interrupted run can resume.
trace_file = f'{SAVE_DIR_EXP}/traces/multi_traces.json'
all_traces = load_json(trace_file)
if all_traces is None:
    all_traces = {}

# One contaminated trace per contaminated round (every round except R0).
# Derived from N_ROUNDS so the trace count cannot drift from the number of
# rounds the experiment loop indexes into.
n_traces_per_problem = N_ROUNDS - 1

print(f'Generating {n_traces_per_problem} traces per problem...')
print(f'(This will take ~{len(problems) * n_traces_per_problem * 0.1:.1f} minutes)')

generated_count = 0

for problem in tqdm(problems, desc='Generating traces'):
    idx_str = str(problem['idx'])

    # Create the entry and its 'traces' list if either is missing.
    traces = all_traces.setdefault(idx_str, {}).setdefault('traces', [])

    # Generate only the traces that are still missing for this problem.
    for round_num in range(len(traces) + 1, n_traces_per_problem + 1):
        # Wrong answers already used for this problem must not be repeated.
        avoid_answers = [t['wrong_answer'] for t in traces]
        trace_data = generate_wrong_trace(
            problem['question'],
            problem['final_answer'],
            avoid_answers=avoid_answers,
            variation=round_num
        )
        trace_data['round'] = round_num
        traces.append(trace_data)
        generated_count += 1

        # Save periodically
        if generated_count % 50 == 0:
            save_json(all_traces, trace_file)

save_json(all_traces, trace_file)
print(f'\n✓ Traces complete (generated {generated_count} new traces)')

# Print the wrong answers generated for the first sampled problem as a sanity check.
if problems:
    first_problem = problems[0]
    ex_idx = str(first_problem['idx'])
    if ex_idx in all_traces and all_traces[ex_idx]['traces']:
        print(f'\nExample - Problem {ex_idx}:')
        print(f'  Correct: {first_problem["final_answer"]}')
        for round_no, t in enumerate(all_traces[ex_idx]['traces'], start=1):
            print(f'  R{round_no} wrong: {t["wrong_answer"]}')

In [None]:
# ============================================================
# CELL 8: PROMPT TEMPLATES
# ============================================================

# Templates are delimited with ''' so the quoted sentence inside them can end
# with its own closing double quote; with a """ delimiter that final quote is
# absorbed by the terminator and the prompt ends with an unbalanced quote.

# R0: Direct (no trace). Placeholder: {question}.
R0_PROMPT = '''Solve this math problem step by step.

Problem: {question}

Show your work and end with "The answer is [NUMBER]."'''

# R1-R3: With contaminated trace. Placeholders: {question}, {trace}.
RN_PROMPT = '''Solve this math problem step by step.

Problem: {question}

An expert provided this solution approach:
---
{trace}
---

Consider the expert's reasoning and provide your answer.
Show your work and end with "The answer is [NUMBER]."'''

def extract_numerical_answer(response: str) -> str:
    """Extract the model's final numeric answer from *response* as a digit string.

    Looks for "answer is N" / "Answer: N" (case-insensitive; an optional
    colon, markdown bold markers and a dollar sign before the number are
    tolerated) and returns the LAST such statement, so a response that first
    quotes someone else's answer and then states its own conclusion is scored
    on the conclusion. If no such statement exists, the last standalone
    integer in the text is used. Commas are removed. Returns "" when the
    response is empty or contains no number.
    """
    if not response:
        return ""

    patterns = [
        r'answer is\s*:?\s*\**\s*\$?\s*([\d,]*\d)',
        r'Answer:\s*\**\s*\$?\s*([\d,]*\d)',
    ]
    last_match = None
    for pattern in patterns:
        for candidate in re.finditer(pattern, response, re.IGNORECASE):
            if last_match is None or candidate.start() > last_match.start():
                last_match = candidate
    if last_match is not None:
        return last_match.group(1).replace(',', '')

    # Fallback: last standalone integer anywhere in the text.
    numbers = re.findall(r'\b(\d+)\b', response)
    return numbers[-1] if numbers else ""

print('Prompt templates defined.')

In [None]:
# ============================================================
# CELL 9: RUN SEQUENTIAL EXPERIMENT
# ============================================================

def run_sequential_experiment(model_name: str, model_config: dict) -> Dict:
    """Run R0 (direct) and the contaminated rounds for every sampled problem.

    Args:
        model_name: Display name stored in the results.
        model_config: Model entry; its 'short' key names the checkpoint file
            and the whole dict is passed to call_api.

    Returns:
        {'model': ..., 'problems': [...]} where each problem holds 'idx',
        'correct_answer' and a 'rounds' list with one record per round
        ('round', 'condition', 'answer', 'correct', 'wrong_in_trace',
        'followed_wrong').

    Results are checkpointed every 20 newly processed problems and at the
    end; problems already present in the checkpoint are skipped. A problem
    for which any round returns no text is not recorded, so that a failed
    API call is not scored as a wrong answer and the problem is retried on
    the next run.

    NOTE(review): every round is a fresh prompt built from the question (plus
    one trace); nothing from earlier rounds is included in later prompts.
    """
    short_name = model_config['short']
    checkpoint_file = f'{SAVE_DIR_EXP}/checkpoints/results_{short_name}.json'
    # One trace per contaminated round (all rounds except R0).
    n_traces_needed = N_ROUNDS - 1

    results = load_json(checkpoint_file)
    if results:
        print(f'✓ Loaded checkpoint with {len(results["problems"])} problems')
    else:
        results = {'model': model_name, 'problems': []}

    completed_indices = {p['idx'] for p in results['problems']}
    processed_count = 0

    for problem in tqdm(problems, desc=f'{short_name}'):
        if problem['idx'] in completed_indices:
            continue

        idx_str = str(problem['idx'])
        traces_data = all_traces.get(idx_str, {}).get('traces') or []
        if not traces_data:
            print(f'Warning: No traces for problem {idx_str}')
            continue
        if len(traces_data) < n_traces_needed:
            print(f'Warning: Only {len(traces_data)} traces for problem {idx_str}')
            continue

        rounds = []
        api_failed = False
        for round_num in range(N_ROUNDS):
            if round_num == 0:
                # R0: Direct (baseline)
                prompt = R0_PROMPT.format(question=problem['question'])
                condition = 'DIRECT'
                wrong_in_trace = None
            else:
                # R1..: contaminated trace for this round (list is 0-indexed)
                trace_data = traces_data[round_num - 1]
                prompt = RN_PROMPT.format(
                    question=problem['question'],
                    trace=trace_data['trace']
                )
                condition = f'CONTAMINATED_R{round_num}'
                wrong_in_trace = trace_data['wrong_answer']

            response = call_api(prompt, model_config, max_tokens=1000)
            if not response:
                api_failed = True
                break

            answer = extract_numerical_answer(response)
            rounds.append({
                'round': round_num,
                'condition': condition,
                'answer': answer,
                'correct': answer == problem['final_answer'],
                'wrong_in_trace': wrong_in_trace,
                # R0 has no trace, so it can never "follow" a wrong answer.
                'followed_wrong': round_num > 0 and answer == wrong_in_trace
            })

        if api_failed:
            print(f'Warning: API returned no response for problem {idx_str}; '
                  f'not recorded, will be retried on the next run')
            continue

        results['problems'].append({
            'idx': problem['idx'],
            'correct_answer': problem['final_answer'],
            'rounds': rounds
        })
        processed_count += 1

        if processed_count % 20 == 0:
            save_json(results, checkpoint_file)

    save_json(results, checkpoint_file)
    return results

# Run the experiment for every configured model, one after the other, and
# collect each model's results under its short key.
print('\n' + '='*60)
print('RUNNING SEQUENTIAL EXPERIMENT')
print('='*60)

all_results = {}
for model_name, model_config in MODELS.items():
    print(f'\n--- {model_name} ---')
    all_results[model_config['short']] = run_sequential_experiment(model_name, model_config)

print('\n✓ Experiment complete!')

In [None]:
# ============================================================
# CELL 10: ANALYZE SEQUENTIAL EFFECTS
# ============================================================

def analyze_sequential(results: Dict, n_rounds: Optional[int] = None) -> Dict:
    """Summarise accuracy, followed-wrong rate and CIF rate per round.

    Args:
        results: Dict with a 'problems' list; each problem has a 'rounds' list
            whose records carry 'correct' and (for contaminated rounds)
            'followed_wrong'.
        n_rounds: Number of rounds to analyse, R0 included. Defaults to the
            module-level N_ROUNDS.

    Returns:
        {'error': ...} when there are no problems, otherwise a dict with:
        - 'n_problems';
        - 'by_round': per round 'n', 'accuracy', 'followed_wrong_rate' (None
          for R0) and, when the round has CIF cases, 'followed_in_cif';
        - 'cumulative_cif': per contaminated round, the share of R0-correct
          problems answered incorrectly in that round (empty when no problem
          was correct in R0);
        - 'degradation_pattern': 'increasing' / 'decreasing' / 'stable' from
          comparing the last and first CIF rates with a 0.05 margin, or
          'unknown' when fewer than two CIF rates are available.
    """
    if n_rounds is None:
        n_rounds = N_ROUNDS

    records = results['problems']
    n = len(records)

    if n == 0:
        return {'error': 'No problems to analyze'}

    analysis = {
        'n_problems': n,
        'by_round': {},
        'cumulative_cif': {},
        'degradation_pattern': 'unknown'
    }

    # Per-round accuracy and followed-wrong rate over all problems.
    for round_num in range(n_rounds):
        round_data = [p['rounds'][round_num] for p in records if len(p['rounds']) > round_num]
        if not round_data:
            continue

        correct_count = sum(1 for r in round_data if r['correct'])
        followed_rate = None
        if round_num > 0:
            followed = sum(1 for r in round_data if r.get('followed_wrong', False))
            followed_rate = followed / len(round_data)

        analysis['by_round'][f'R{round_num}'] = {
            'n': len(round_data),
            'accuracy': correct_count / len(round_data),
            'followed_wrong_rate': followed_rate
        }

    # CIF rate per contaminated round, restricted to problems correct in R0.
    r0_correct = [p for p in records if len(p['rounds']) > 0 and p['rounds'][0]['correct']]

    if r0_correct:
        for round_num in range(1, n_rounds):
            cif_cases = [
                p for p in r0_correct
                if len(p['rounds']) > round_num and not p['rounds'][round_num]['correct']
            ]
            analysis['cumulative_cif'][f'R{round_num}'] = len(cif_cases) / len(r0_correct)

            # Among the CIF cases, how many ended on the trace's wrong answer.
            if cif_cases:
                followed_in_cif = sum(
                    1 for p in cif_cases
                    if p['rounds'][round_num].get('followed_wrong', False)
                )
                analysis['by_round'][f'R{round_num}']['followed_in_cif'] = followed_in_cif / len(cif_cases)

    # Classify the trend only from rates that were actually measured; with no
    # R0-correct problems the pattern stays 'unknown' instead of 'stable'.
    cif_rates = list(analysis['cumulative_cif'].values())
    if len(cif_rates) >= 2:
        if cif_rates[-1] > cif_rates[0] + 0.05:
            analysis['degradation_pattern'] = 'increasing'
        elif cif_rates[-1] < cif_rates[0] - 0.05:
            analysis['degradation_pattern'] = 'decreasing'
        else:
            analysis['degradation_pattern'] = 'stable'

    return analysis

# Analyze every model that has results and print a per-round report.
print('\n' + '='*60)
print('SEQUENTIAL DEGRADATION ANALYSIS')
print('='*60)

all_analyses = {}

# Iterate the configured models instead of a hard-coded key list so the
# report follows MODELS.
for model_name, model_config in MODELS.items():
    model_key = model_config['short']
    if model_key not in all_results:
        continue
    print(f'\n{model_name}')
    print('-'*50)

    analysis = analyze_sequential(all_results[model_key])
    all_analyses[model_key] = analysis

    if 'error' in analysis:
        print(f'  {analysis["error"]}')
        continue

    print(f'\nAccuracy by Round:')
    for round_name, data in analysis['by_round'].items():
        acc = data['accuracy']
        # R0 stores None for this rate; show '-' rather than the text "None".
        fw = data.get('followed_wrong_rate')
        fw_str = f'{fw:.1%}' if fw is not None else '-'
        print(f'  {round_name}: {acc:.1%} (followed wrong: {fw_str})')

    print(f'\nCIF Rate by Round (among R0 correct):')
    for round_name, cif_rate in analysis['cumulative_cif'].items():
        print(f'  {round_name}: {cif_rate:.1%}')

    print(f'\nDegradation Pattern: {analysis["degradation_pattern"]}')

save_json(all_analyses, f'{SAVE_DIR_EXP}/results/sequential_analysis.json')

In [None]:
# ============================================================
# CELL 11: STATISTICAL TEST FOR TREND
# ============================================================

print('\n' + '='*60)
print('STATISTICAL ANALYSIS: TREND TEST')
print('='*60)

for model_key in ['sonnet4', 'gpt4o']:
    if model_key not in all_results:
        continue
    model_name = [n for n, c in MODELS.items() if c['short'] == model_key][0]
    print(f'\n{model_name}')
    print('-'*50)

    # Problems that have all rounds recorded and were answered correctly in R0.
    r0_correct = [
        p for p in all_results[model_key]['problems']
        if len(p['rounds']) >= N_ROUNDS and p['rounds'][0]['correct']
    ]

    if len(r0_correct) < 20:
        print('  Insufficient R0-correct problems for analysis')
        continue

    # Per-problem CIF flags (True = answered incorrectly in that round).
    r1_flags = [not p['rounds'][1]['correct'] for p in r0_correct]
    r3_flags = [not p['rounds'][3]['correct'] for p in r0_correct]
    r1_cif = sum(r1_flags)
    r3_cif = sum(r3_flags)
    r1_nocif = len(r0_correct) - r1_cif
    r3_nocif = len(r0_correct) - r3_cif

    print(f'  R1 CIF: {r1_cif}/{len(r0_correct)} ({r1_cif/len(r0_correct):.1%})')
    print(f'  R3 CIF: {r3_cif}/{len(r0_correct)} ({r3_cif/len(r0_correct):.1%})')

    # Chi-square test: R1 vs R3 CIF rates. chi2_contingency raises ValueError
    # when an expected frequency is zero, which happens when neither round
    # (or every problem in both rounds) shows CIF.
    contingency = [[r1_cif, r1_nocif], [r3_cif, r3_nocif]]
    try:
        chi2, p_value, dof, expected = stats.chi2_contingency(contingency)
    except ValueError:
        print('  Chi-square: not computable (a CIF / no-CIF column is all zeros)')
    else:
        print(f'  Chi-square: χ² = {chi2:.3f}')
        print(f'  p-value: {p_value:.4f}')
        print(f'  Significant difference: {"Yes" if p_value < 0.05 else "No"}')

    # The same problems are measured in R1 and R3, so the samples are paired;
    # the chi-square above treats them as independent. The exact McNemar test
    # uses only the discordant pairs: under H0 each is equally likely to be
    # "CIF in R1 only" or "CIF in R3 only".
    only_r1 = sum(1 for a, b in zip(r1_flags, r3_flags) if a and not b)
    only_r3 = sum(1 for a, b in zip(r1_flags, r3_flags) if b and not a)
    discordant = only_r1 + only_r3
    if discordant > 0:
        p_mcnemar = stats.binomtest(only_r3, discordant, 0.5).pvalue
    else:
        p_mcnemar = 1.0  # no problem changed status between R1 and R3
    print(f'\n  McNemar exact test (paired, R1 vs R3):')
    print(f'    CIF in R1 only: {only_r1}, CIF in R3 only: {only_r3}')
    print(f'    p = {p_mcnemar:.4f}')
    print(f'    Significant difference: {"Yes" if p_mcnemar < 0.05 else "No"}')

    # Cochran's Q test for repeated measures: rows = problems, columns = R1..,
    # entry 1 when the problem was answered incorrectly in that round.
    cif_matrix = np.array([
        [1 if not p['rounds'][r]['correct'] else 0 for r in range(1, N_ROUNDS)]
        for p in r0_correct
    ])

    k = cif_matrix.shape[1]  # Number of rounds
    n_subj = cif_matrix.shape[0]  # Number of problems

    # Cochran's Q statistic
    T = cif_matrix.sum(axis=1)  # Row sums
    C = cif_matrix.sum(axis=0)  # Column sums
    G = cif_matrix.sum()

    numerator = (k - 1) * (k * np.sum(C**2) - G**2)
    denominator = k * G - np.sum(T**2)

    print(f'\n  Cochran\'s Q (across R1-R3):')
    if denominator > 0:
        Q = numerator / denominator
        # Survival function = 1 - cdf, evaluated with k - 1 degrees of freedom.
        p_cochran = stats.chi2.sf(Q, k - 1)
        print(f'    Q = {Q:.3f}')
        print(f'    p = {p_cochran:.4f}')
        print(f'    Significant variation: {"Yes" if p_cochran < 0.05 else "No"}')
    else:
        # Denominator is zero when every problem has the same outcome in all
        # contaminated rounds, so there is no within-problem variation to test.
        print('    not computable (no problem changed outcome between rounds)')

In [None]:
# ============================================================
# CELL 12: VISUALIZATION
# ============================================================
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 3, figsize=(15, 5))
ax1, ax2, ax3 = axes

colors = {'sonnet4': '#5B8FF9', 'gpt4o': '#5AD8A6'}
model_labels = {'sonnet4': 'Claude 4 Sonnet', 'gpt4o': 'GPT-4o'}


def _draw_model_lines(ax, xs, values_for):
    """Draw one line per analysed model on *ax*; values_for(analysis) gives the y-values."""
    for model_key in ['sonnet4', 'gpt4o']:
        if model_key not in all_analyses or 'error' in all_analyses[model_key]:
            continue
        ax.plot(xs, values_for(all_analyses[model_key]), 'o-', label=model_labels[model_key],
                color=colors[model_key], linewidth=2, markersize=8)


def _decorate(ax, ylabel, title, xs, tick_labels):
    """Apply the axis labels, title, ticks, legend and 0-1 y-range shared by all panels."""
    ax.set_ylabel(ylabel, fontsize=12)
    ax.set_xlabel('Round', fontsize=12)
    ax.set_title(title, fontsize=14)
    ax.set_xticks(xs)
    ax.set_xticklabels(tick_labels)
    ax.legend()
    ax.set_ylim(0, 1)


# Panel 1: accuracy in every round, R0 included. Missing rounds are drawn as 0.
x = np.arange(N_ROUNDS)
_draw_model_lines(
    ax1, x,
    lambda a: [a['by_round'].get(f'R{r}', {}).get('accuracy', 0) for r in range(N_ROUNDS)],
)
_decorate(ax1, 'Accuracy', 'Accuracy Across Sequential Rounds', x,
          ['R0\n(Direct)', 'R1', 'R2', 'R3'])
# Dashed separator between the direct round and the contaminated rounds.
ax1.axvline(x=0.5, color='gray', linestyle='--', alpha=0.5)

# Panel 2: CIF rate per contaminated round (among R0-correct problems).
x = np.arange(1, N_ROUNDS)
_draw_model_lines(
    ax2, x,
    lambda a: [a['cumulative_cif'].get(f'R{r}', 0) for r in range(1, N_ROUNDS)],
)
_decorate(ax2, 'CIF Rate', 'CIF Rate by Round (R0 Correct Only)', x, ['R1', 'R2', 'R3'])

# Panel 3: share of answers equal to the trace's wrong answer; a missing or
# None rate is drawn as 0. Uses the same x positions as panel 2.
_draw_model_lines(
    ax3, x,
    lambda a: [a['by_round'].get(f'R{r}', {}).get('followed_wrong_rate', 0) or 0
               for r in range(1, N_ROUNDS)],
)
_decorate(ax3, 'Followed-Wrong Rate', 'Rate of Following Trace\'s Wrong Answer', x,
          ['R1', 'R2', 'R3'])

plt.tight_layout()
plt.savefig(f'{SAVE_DIR_EXP}/exp_A1_E7_sequential.png', dpi=150, bbox_inches='tight')
plt.show()

print(f'\n✓ Figure saved')

In [None]:
# ============================================================
# CELL 13: FINAL SUMMARY
# ============================================================

# Static description of the experiment plus the full per-model analyses;
# 'key_findings' is filled in below.
summary = {
    'experiment_id': 'A1_E7',
    'experiment_name': 'Sequential Degradation',
    'date': EXPERIMENT_DATE,
    'hypothesis': 'Repeated exposure to contaminated traces causes cumulative degradation',
    'design': {
        'R0': 'Direct (baseline)',
        'R1': 'First contaminated trace',
        'R2': 'Second contaminated trace (different wrong answer)',
        'R3': 'Third contaminated trace'
    },
    'n_problems': N_PROBLEMS,
    'models': list(MODELS.keys()),
    'results': all_analyses,
    'key_findings': []
}

# One headline record per model that was analysed without error.
for model_key in ['sonnet4', 'gpt4o']:
    analysis = all_analyses.get(model_key)
    if analysis is None or 'error' in analysis:
        continue

    cif_by_round = analysis['cumulative_cif']
    r1_cif = cif_by_round.get('R1', 0)
    r3_cif = cif_by_round.get('R3', 0)

    summary['key_findings'].append({
        'model': model_key,
        'r0_accuracy': analysis['by_round'].get('R0', {}).get('accuracy', 0),
        'r1_cif': r1_cif,
        'r3_cif': r3_cif,
        'cif_change': r3_cif - r1_cif,
        'pattern': analysis['degradation_pattern'],
        # Flagged when the R3 rate exceeds the R1 rate by more than 0.05.
        'cumulative_effect': r3_cif > r1_cif + 0.05
    })

save_json(summary, f'{SAVE_DIR_EXP}/results/exp_A1_E7_summary.json')

# Closing report: where the results went, the per-model headline numbers, and
# a fixed guide for reading the R1 -> R3 change.
rule = '=' * 60

print('\n' + rule)
print('EXPERIMENT A1-E7 COMPLETE')
print(rule)
print(f'\nResults saved to: {SAVE_DIR_EXP}')
print('\n' + rule)
print('KEY FINDINGS')
print(rule)

for finding in summary['key_findings']:
    # Map the short key stored in the finding back to the display name.
    model_name = [n for n, c in MODELS.items() if c['short'] == finding['model']][0]
    verdict = '✓ YES' if finding['cumulative_effect'] else '✗ NO'
    print(f"\n{model_name}:")
    print(f"  R0 Accuracy: {finding['r0_accuracy']:.1%}")
    print(f"  R1 CIF: {finding['r1_cif']:.1%}")
    print(f"  R3 CIF: {finding['r3_cif']:.1%}")
    print(f"  CIF Change (R1→R3): {finding['cif_change']:+.1%}")
    print(f"  Pattern: {finding['pattern']}")
    print(f"  Cumulative effect: {verdict}")

print('\n' + rule)
print('INTERPRETATION')
print(rule)
print('''
If CIF rate increases from R1 → R3:
  → Cumulative degradation effect exists
  → Models become more susceptible with repeated exposure

If CIF rate stays stable:
  → Each exposure is independent
  → No cumulative effect

If CIF rate decreases:
  → Possible adaptation/learning effect
''')