# A3 Scaling Law Full Experiment: GPT Models (200 problems)

**Purpose**: Measure λ* (the backfire boundary: the corruption rate at which chain-of-thought accuracy falls to the direct-answer baseline) across the GPT model family

**Available Models**:
| Model | API Name | Capability |
|-------|----------|------------|
| GPT-3.5 Turbo | gpt-3.5-turbo | Low |
| GPT-4o-mini | gpt-4o-mini | Medium |
| GPT-4o | gpt-4o | High |

**Pilot Results (50 problems)**:
- GPT-3.5: Baseline 46%, λ* = 0.693
- GPT-4o-mini: Baseline 44%, λ* = 0.783
- GPT-4o: Baseline 56.28%, λ* = 0.865

## 0. Google Drive Connection

In [None]:
# Mount Google Drive so experiment outputs persist beyond the Colab session.
from google.colab import drive
drive.mount('/content/drive')

import os
from datetime import datetime

# Run date (YYYYMMDD); used to tag this run's output directory.
EXPERIMENT_DATE = f'{datetime.now():%Y%m%d}'
# Root folder on Drive that holds every experiment run.
SAVE_DIR = '/content/drive/MyDrive/CoT_Experiment'
os.makedirs(SAVE_DIR, exist_ok=True)

print('Base directory: ' + SAVE_DIR)

## 1. Install Dependencies

In [None]:
# IPython shell escape (`!`): install the libraries this notebook imports.
# `-q` keeps pip's output quiet; the print below runs regardless of pip's result.
!pip install datasets openai pandas tqdm matplotlib scipy -q
print('Dependencies installed.')

## 2. ⭐ MODEL SELECTION ⭐

**Select the model here!** (ここでモデルを選択してください！)

In [None]:
#@title Select GPT Model { run: "auto" }
#@markdown **Select a model:**

# Colab form parameter: the dropdown writes the chosen display name here.
# The options must match the keys of MODEL_CONFIG below.
MODEL_CHOICE = "GPT-3.5 Turbo" #@param ["GPT-3.5 Turbo", "GPT-4o-mini", "GPT-4o"]

# Model configuration mapping, keyed by the dropdown display name.
#   api_name           - model identifier sent to the OpenAI API
#   short_name         - tag used in output directory and file names
#   capability         - informal capability tier (descriptive label)
#   cost_per_1k_input  - USD per 1,000 input tokens (feeds the cost estimate)
#   cost_per_1k_output - USD per 1,000 output tokens (feeds the cost estimate)
# NOTE(review): prices are hard-coded; verify against current OpenAI pricing.
MODEL_CONFIG = {
    "GPT-3.5 Turbo": {
        "api_name": "gpt-3.5-turbo",
        "short_name": "gpt35",
        "capability": "Low",
        "cost_per_1k_input": 0.0005,
        "cost_per_1k_output": 0.0015
    },
    "GPT-4o-mini": {
        "api_name": "gpt-4o-mini",
        "short_name": "gpt4omini",
        "capability": "Medium",
        "cost_per_1k_input": 0.00015,
        "cost_per_1k_output": 0.0006
    },
    "GPT-4o": {
        "api_name": "gpt-4o",
        "short_name": "gpt4o",
        "capability": "High",
        "cost_per_1k_input": 0.005,
        "cost_per_1k_output": 0.015
    }
}

# Resolve the selected model's settings into module-level names.
config = MODEL_CONFIG[MODEL_CHOICE]
API_MODEL, MODEL_SHORT, MODEL_CAPABILITY = (
    config["api_name"],
    config["short_name"],
    config["capability"],
)

# Per-run output tree: <SAVE_DIR>/a3_gpt_<short>_<date>/{results,checkpoints}
SAVE_DIR_EXP = f'{SAVE_DIR}/a3_gpt_{MODEL_SHORT}_{EXPERIMENT_DATE}'
for _out_dir in (SAVE_DIR_EXP,
                 f'{SAVE_DIR_EXP}/results',
                 f'{SAVE_DIR_EXP}/checkpoints'):
    os.makedirs(_out_dir, exist_ok=True)

print('='*60)
print(f'SELECTED MODEL: {MODEL_CHOICE}')
print('='*60)
for _label, _value in (('API Name', API_MODEL),
                       ('Capability', MODEL_CAPABILITY),
                       ('Save Directory', SAVE_DIR_EXP)):
    print(f'  {_label}: {_value}')
print('='*60)
del _out_dir, _label, _value

## 3. Experiment Configuration

In [None]:
import json
import re
import random
import time
import hashlib
from typing import List, Dict, Optional, Any
from dataclasses import dataclass, asdict
from datetime import datetime
from tqdm import tqdm
import pandas as pd
import numpy as np

# =============================================================================
# Configuration - MUST MATCH ORIGINAL EXPERIMENTS
# =============================================================================
GLOBAL_SEED = 20251224  # Same as all other experiments (problem selection + trace seeds)
N_PROBLEMS = 200  # Full experiment

# Experimental conditions
I_FIXED = 10  # Same trace depth: number of reasoning steps shown to the model
LAMBDA_VALUES = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]  # fraction of trace steps corrupted

# API settings
API_MAX_TOKENS = 256
API_RATE_LIMIT_DELAY = 0.3  # seconds slept after each call; OpenAI is generally faster
CHECKPOINT_EVERY = 50  # direct-phase checkpoint interval, in problems

# One direct call per problem plus one CoT call per (problem, λ).
total_calls = N_PROBLEMS * (len(LAMBDA_VALUES) + 1)

print('='*60)
print('EXPERIMENT CONFIGURATION')
print('='*60)
print(f'  Model: {MODEL_CHOICE} ({API_MODEL})')
print(f'  GLOBAL_SEED: {GLOBAL_SEED}')
print(f'  N_PROBLEMS: {N_PROBLEMS}')
print(f'  I (fixed): {I_FIXED}')
print(f'  λ values: {LAMBDA_VALUES}')
print(f'  Total inferences: {total_calls}')
print('='*60)

# Estimate cost (upper bound). Input and output tokens are billed at different
# rates, so they are estimated separately rather than charging every token at
# both rates. Output per call can never exceed API_MAX_TOKENS.
est_tokens_per_call = 500  # rough guess of prompt (input) tokens per call
est_output_tokens_per_call = API_MAX_TOKENS
est_cost = total_calls * (
    est_tokens_per_call / 1000 * config['cost_per_1k_input']
    + est_output_tokens_per_call / 1000 * config['cost_per_1k_output']
)
print(f'\n💰 Estimated cost (upper bound): ${est_cost:.2f}')

## 4. Data Structures & Utilities

In [None]:
@dataclass()
class GSM8KProblem:
    """One GSM8K test item prepared for the experiment.

    Attributes:
        index: Row index in the GSM8K ``test`` split.
        question: Problem statement shown to the model.
        answer_text: Full reference solution (reasoning lines plus the
            ``#### <answer>`` line).
        final_answer: Integer parsed from the ``####`` line; the ground truth.
    """

    index: int
    question: str
    answer_text: str
    final_answer: int

def extract_final_answer(answer_text: str) -> int:
    """Return the integer after the ``####`` marker of a GSM8K reference solution.

    Comma thousands separators are removed and a leading minus sign is kept
    (``#### -1,200`` -> -1200).

    Raises:
        ValueError: if no ``#### <number>`` marker is present.
    """
    # Require a leading digit so a stray "#### ," cannot reach int('').
    match = re.search(r'####\s*(-?\d[\d,]*)', answer_text)
    if match is None:
        # Include the tail of the text so skipped problems can be diagnosed.
        raise ValueError(f'Could not extract final answer from: {answer_text[-80:]!r}')
    return int(match.group(1).replace(',', ''))

def save_json(data: Any, filepath: str):
    """Write ``data`` to ``filepath`` as 2-space-indented UTF-8 JSON, then log the path."""
    with open(filepath, mode='w', encoding='utf-8') as sink:
        json.dump(data, sink, indent=2, ensure_ascii=False)
    print(f'Saved: {filepath}')

def load_json(filepath: str) -> Any:
    """Read the UTF-8 JSON document at ``filepath`` and return the decoded value."""
    with open(filepath, encoding='utf-8') as source:
        raw = source.read()
    return json.loads(raw)

def derive_seed(global_seed: int, problem_id: int, I: int, lam: float) -> int:
    """Derive a reproducible 32-bit seed for one (problem, I, λ) condition.

    SHA-256 hashes ``"<global_seed>|<problem_id>|I=<I>|lam=<lam>"`` and keeps
    the first four digest bytes as a big-endian integer, so the seed is the
    same on every run and machine.
    """
    material = f"{global_seed}|{problem_id}|I={I}|lam={lam}".encode("utf-8")
    digest = hashlib.sha256(material).digest()
    return int.from_bytes(digest[:4], "big")

## 5. Load GSM8K and Select Problems

In [None]:
from datasets import load_dataset

# GSM8K "main" configuration, test split, via Hugging Face `datasets`.
dataset = load_dataset('gsm8k', 'main', split='test')
print('GSM8K test set loaded: {} problems'.format(len(dataset)))

def select_problems(dataset, n_problems: int, seed: int) -> List[int]:
    """Return a sorted, seeded random subset of ``n_problems`` dataset indices.

    All indices are shuffled with ``random.Random(seed)`` and the first
    ``n_problems`` are kept, so the same seed and dataset size always give
    the same subset.
    """
    order = list(range(len(dataset)))
    random.Random(seed).shuffle(order)
    chosen = order[:n_problems]
    chosen.sort()
    return chosen

# Same seed → same problems as all other experiments
selected_indices = select_problems(dataset, N_PROBLEMS, GLOBAL_SEED)
print(f'Selected {len(selected_indices)} problems (identical across all experiments)')

problems = []
for idx in selected_indices:
    row = dataset[idx]
    try:
        gold = extract_final_answer(row['answer'])
    except ValueError as e:
        # Items whose reference answer cannot be parsed are reported and skipped.
        print(f'Skipping problem {idx}: {e}')
        continue
    problems.append(GSM8KProblem(
        index=idx,
        question=row['question'],
        answer_text=row['answer'],
        final_answer=gold,
    ))

print(f'\nLoaded {len(problems)} problems for experiment')

## 6. OpenAI API Setup

In [None]:
from openai import OpenAI
import getpass

# Ask for the API key with hidden input so it is not typed into a code cell.
print("OpenAI API„Ç≠„Éº„ÇíÂÖ•Âäõ„Åó„Å¶„Åè„Å†„Åï„ÅÑÔºö")
OPENAI_API_KEY = getpass.getpass("API Key: ")

# Shared client used by call_openai().
client = OpenAI(api_key=OPENAI_API_KEY)

def call_openai(prompt: str, max_tokens: int = API_MAX_TOKENS, max_attempts: int = 3) -> str:
    """Send ``prompt`` as a single user turn to ``API_MODEL`` and return the reply text.

    Failed calls are retried with exponential backoff (1 s, 2 s, ...) between
    attempts; there is no sleep after the final attempt. Callers score a failed
    call as an unparseable answer, so this never raises.

    Args:
        prompt: Content of the only chat message.
        max_tokens: Cap on generated tokens.
        max_attempts: Total number of API attempts before giving up.

    Returns:
        The reply text, or ``""`` when every attempt failed or the API
        returned no message content.
    """
    for attempt in range(1, max_attempts + 1):
        try:
            response = client.chat.completions.create(
                model=API_MODEL,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=max_tokens,
                temperature=0,  # greedy decoding to minimise run-to-run variation
            )
        except Exception as e:  # best-effort: any API/network error triggers a retry
            print(f'API error (attempt {attempt}): {e}')
            if attempt < max_attempts:
                time.sleep(2 ** (attempt - 1))
            continue
        # The SDK types `message.content` as optional; normalise None to "" so the
        # declared `str` return type holds for downstream parsing.
        return response.choices[0].message.content or ""
    print(f'API call failed after {max_attempts} attempts; returning empty response.')
    return ""

# Smoke-test the connection with a trivial prompt before the long runs.
print(f'\nTesting {MODEL_CHOICE}...')
test_response = call_openai(prompt="What is 2+2? Reply with just the number.")
print(f'API test: {test_response}', f'Model: {API_MODEL} ‚úì', sep='\n')

## 7. CoT Generation Functions

In [None]:
def generate_clean_cot(problem: GSM8KProblem, I: int) -> List[str]:
    """Return exactly ``I`` clean reasoning steps taken from the reference solution.

    Non-empty, whitespace-stripped lines of ``problem.answer_text`` (excluding
    the ``####`` answer line) are used in order. Extra lines are truncated;
    a shortfall is padded with ``"Step k: Continue calculation."`` filler
    lines numbered from the first missing position.
    """
    stripped = (raw.strip() for raw in problem.answer_text.split('\n'))
    steps = [s for s in stripped if s and not s.startswith('####')]

    if len(steps) < I:
        steps.extend(
            f"Step {k}: Continue calculation." for k in range(len(steps) + 1, I + 1)
        )
        return steps
    return steps[:I]

def generate_corrupted_step(problem: GSM8KProblem, step_idx: int, rng: random.Random) -> str:
    """
    Generate a corrupted reasoning step.

    Returns one of five arithmetic-looking templates labelled
    "Step {step_idx+1}", filled with random integers. The stated result of
    each template is drawn independently of its operands, so the arithmetic
    is generally wrong.

    Args:
        problem: Not used by the body.
        step_idx: 0-based position of the step within the trace.
        rng: Random source; advanced in place.

    Reproducibility note: all five templates are formatted before one is
    chosen, so every call consumes 15 ``randint`` draws plus one ``choice``
    from ``rng`` regardless of which template is returned. Reordering or
    lazily building the templates would change every later draw and hence
    every generated trace.
    """
    # NOTE(review): the operator glyphs in these literals look mis-encoded
    # (probably '×' and '÷'); confirm the bytes in the original notebook,
    # since these strings are sent to the model verbatim.
    corruption_templates = [
        f"Step {step_idx+1}: Let's multiply the values: {rng.randint(10, 100)} √ó {rng.randint(2, 10)} = {rng.randint(100, 1000)}.",
        f"Step {step_idx+1}: Adding the totals: {rng.randint(50, 200)} + {rng.randint(50, 200)} = {rng.randint(100, 500)}.",
        f"Step {step_idx+1}: The difference is: {rng.randint(100, 500)} - {rng.randint(20, 100)} = {rng.randint(50, 400)}.",
        f"Step {step_idx+1}: Dividing gives us: {rng.randint(100, 1000)} √∑ {rng.randint(2, 10)} = {rng.randint(10, 200)}.",
        f"Step {step_idx+1}: Converting: {rng.randint(1, 10)} √ó {rng.randint(10, 100)} = {rng.randint(10, 1000)}."
    ]
    return rng.choice(corruption_templates)

def generate_mixed_cot(problem: GSM8KProblem, I: int, lam: float, seed: int) -> List[str]:
    """Build an ``I``-step trace in which ``int(round(I * lam))`` steps are corrupted.

    Positions to corrupt are sampled from ``random.Random(seed)``. The same RNG
    then produces each corrupted step in ascending position order, so a given
    (problem, I, lam, seed) always yields the same trace. Other positions keep
    the clean step from ``generate_clean_cot``.
    """
    rng = random.Random(seed)
    clean = generate_clean_cot(problem, I)

    n_bad = int(round(I * lam))
    bad_positions = set(rng.sample(range(I), n_bad)) if n_bad > 0 else set()

    def _clean_or_filler(pos: int) -> str:
        # Filler only if the clean trace were shorter than I.
        return clean[pos] if pos < len(clean) else f"Step {pos+1}: Continue."

    # The comprehension visits positions in ascending order, which preserves
    # the RNG draw sequence used for corrupted steps.
    return [
        generate_corrupted_step(problem, pos, rng) if pos in bad_positions else _clean_or_filler(pos)
        for pos in range(I)
    ]

## 8. Prompt Creation Functions

In [None]:
def create_direct_prompt(problem: GSM8KProblem) -> str:
    """Build the no-reasoning (baseline) prompt asking for a JSON ``final`` answer."""
    instruction = "Solve this math problem. Give ONLY the final numerical answer in JSON format."
    reply_format = 'Reply with ONLY: {"final": <number>}'
    return f"{instruction}\n\nProblem: {problem.question}\n\n{reply_format}"

def create_cot_prompt(problem: GSM8KProblem, cot_steps: List[str]) -> str:
    """Build the prompt that supplies a reasoning trace and asks for a JSON ``final`` answer."""
    parts = [
        "Here is a math problem with a provided reasoning trace.",
        "Follow the reasoning and give the final numerical answer in JSON format.",
        "",
        f"Problem: {problem.question}",
        "",
        "Reasoning trace:",
        "\n".join(cot_steps),  # one element, so an empty trace still leaves a blank line
        "",
        "Based on this reasoning, what is the final answer?",
        'Reply with ONLY: {"final": <number>}',
    ]
    return "\n".join(parts)

def parse_answer(response: str) -> Optional[int]:
    """Extract the model's integer answer from its raw reply.

    Tries, in order:
      1. a JSON-like object containing ``"final": <number>`` (value may be quoted);
      2. the last integer appearing anywhere in the text.

    Integers may carry a leading minus sign and comma thousands separators
    (``1,234`` -> 1234). For a decimal value only the integer part is read
    (``12.5`` -> 12).

    Returns:
        The parsed integer, or None for an empty/None reply or when no integer
        is present.
    """
    if not response:
        return None

    # Optional sign; comma grouping accepted only as well-formed 3-digit groups.
    number = r'-?(?:\d{1,3}(?:,\d{3})+(?!\d)|\d+)'

    match = re.search(r'\{[^}]*"final"\s*:\s*"?(' + number + r')[^}]*\}', response)
    if match:
        return int(match.group(1).replace(',', ''))

    # Fallback: last standalone integer. The lookbehind skips digits glued to
    # words or following a decimal point (so "12.5" yields 12, not 5).
    numbers = re.findall(r'(?<![\w.])' + number, response)
    if numbers:
        return int(numbers[-1].replace(',', ''))

    return None

## 9. Run Direct Condition (Baseline)

In [None]:
print('='*60)
print(f'PHASE 1: Direct Condition (Baseline) - {MODEL_CHOICE}')
print('='*60)

direct_results = []

for n_done, problem in enumerate(tqdm(problems, desc='Direct'), start=1):
    response = call_openai(create_direct_prompt(problem))
    answer = parse_answer(response)

    direct_results.append({
        'problem_index': problem.index,
        'condition': 'direct',
        'model': API_MODEL,
        'model_answer': answer,
        'correct_answer': problem.final_answer,
        'is_correct': answer is not None and answer == problem.final_answer,
        'raw_output': response,
        'timestamp': datetime.now().isoformat(),
    })

    time.sleep(API_RATE_LIMIT_DELAY)

    # Periodic checkpoint so a crash mid-run keeps the finished calls.
    if n_done % CHECKPOINT_EVERY == 0:
        save_json(direct_results, f'{SAVE_DIR_EXP}/checkpoints/direct_checkpoint_{n_done}.json')

# Save final direct results
save_json(direct_results, f'{SAVE_DIR_EXP}/results/direct_results_{MODEL_SHORT}.json')

# Baseline accuracy = fraction of problems answered correctly without a trace.
baseline_acc = sum(1 for r in direct_results if r['is_correct']) / len(direct_results)
print(f'\n‚úì {MODEL_CHOICE} Direct (Baseline) Accuracy: {baseline_acc:.1%}')

## 10. Run CoT Conditions (Œª sweep)

In [None]:
print('='*60)
print(f'PHASE 2: CoT Conditions (Œª sweep) - {MODEL_CHOICE}')
print('='*60)

cot_results = []
total_trials = len(problems) * len(LAMBDA_VALUES)

with tqdm(total=total_trials, desc='CoT sweep') as pbar:
    for lam in LAMBDA_VALUES:
        print(f'\n--- Œª = {lam} ---')

        for problem in problems:
            # Per-(problem, I, λ) seed keeps the corrupted trace reproducible.
            trace = generate_mixed_cot(
                problem, I_FIXED, lam,
                derive_seed(GLOBAL_SEED, problem.index, I_FIXED, lam),
            )
            response = call_openai(create_cot_prompt(problem, trace))
            answer = parse_answer(response)

            cot_results.append({
                'problem_index': problem.index,
                'condition': 'cot',
                'model': API_MODEL,
                'I': I_FIXED,
                'lam': lam,
                'A_target': 1 - lam,
                'model_answer': answer,
                'correct_answer': problem.final_answer,
                'is_correct': answer is not None and answer == problem.final_answer,
                'raw_output': response,
                'timestamp': datetime.now().isoformat(),
            })

            time.sleep(API_RATE_LIMIT_DELAY)
            pbar.update(1)

        # Checkpoint after each λ (cumulative results so far).
        save_json(cot_results, f'{SAVE_DIR_EXP}/checkpoints/cot_checkpoint_lam{lam}.json')

# Save final CoT results
save_json(cot_results, f'{SAVE_DIR_EXP}/results/cot_results_{MODEL_SHORT}.json')
print('\n‚úì CoT experiment complete!')

## 11. Analyze Results

In [None]:
import matplotlib.pyplot as plt

# One row per (problem, λ) CoT trial.
cot_df = pd.DataFrame(cot_results)

# Mean of the boolean column per λ = CoT accuracy at that corruption rate.
acc_by_lam = cot_df.groupby('lam')['is_correct'].mean().to_dict()

print('='*60)
print(f'{MODEL_CHOICE} RESULTS')
print('='*60)
print(f'\nDirect (Baseline): {baseline_acc:.1%}')
print('\nCoT accuracy by Œª:')
for lam in sorted(acc_by_lam):
    acc = acc_by_lam[lam]
    print(f'  Œª={lam:.1f}: {acc:.1%} {"‚Üê BACKFIRE" if acc < baseline_acc else ""}')

## 12. Estimate Œª*

In [None]:
from scipy.interpolate import interp1d
from scipy.optimize import brentq

# Ascending λ grid and the matching CoT accuracies, as arrays.
lam_points = np.array(sorted(acc_by_lam))
acc_points = np.array([acc_by_lam[lam_val] for lam_val in lam_points])

def estimate_lambda_crit(lam_arr, acc_arr, baseline):
    """Find the first λ where CoT accuracy crosses the direct baseline.

    Uses piecewise-linear interpolation of (accuracy - baseline) over the
    λ grid, which must be sorted ascending.

    - Adjacent points on opposite sides of the baseline: return the exact
      root of the connecting line segment.
    - Points sitting exactly on the baseline between opposite sides: return
      the first such λ (such ties are common because accuracies are k/N).
    - Points that touch the baseline without crossing it are ignored.

    Returns:
        The crossing λ as a float; 1.0 if no crossing is found and the last
        accuracy is still above baseline; otherwise None (including empty input).
    """
    lam_arr = np.asarray(lam_arr, dtype=float)
    diffs = np.asarray(acc_arr, dtype=float) - baseline
    if diffs.size == 0:
        return None

    prev = None  # index of the most recent point not exactly on the baseline
    for i, d in enumerate(diffs):
        if d == 0:
            continue
        if prev is not None and np.sign(d) != np.sign(diffs[prev]):
            if i == prev + 1:
                lo, hi = lam_arr[prev], lam_arr[i]
                d_lo = diffs[prev]
                # Root of the line through (lo, d_lo) and (hi, d).
                return float(lo + d_lo * (hi - lo) / (d_lo - d))
            # Zeros between opposite signs: accuracy reaches the baseline here.
            return float(lam_arr[prev + 1])
        prev = i

    if diffs[-1] > 0:
        return 1.0
    return None

lambda_crit = estimate_lambda_crit(lam_points, acc_points, baseline_acc)

print('='*60)
print(f'Œª* ESTIMATION FOR {MODEL_CHOICE}')
print('='*60)

if lambda_crit is None:
    print('\n  Could not estimate Œª* (no crossing detected)')
else:
    # A* = 1 - λ*: the clean-step fraction at the boundary.
    print(f'\n  Œª* = {lambda_crit:.3f}')
    print(f'  A* = {1 - lambda_crit:.3f}')

## 13. Visualization

In [None]:
fig, ax = plt.subplots(figsize=(10, 6))

# Plot accuracy curve (λ in ascending order)
lams = sorted(acc_by_lam.keys())
accs = [acc_by_lam[l] * 100 for l in lams]
ax.plot(lams, accs, 'o-', color='#d62728', linewidth=2.5, markersize=10, label=f'{MODEL_CHOICE}')

# Baseline
ax.axhline(y=baseline_acc * 100, color='blue', linestyle='--', linewidth=2, label=f'Baseline ({baseline_acc:.1%})')

# Mark λ* (only when a real crossing was found below λ = 1.0)
if lambda_crit is not None and lambda_crit < 1.0:
    ax.axvline(x=lambda_crit, color='green', linestyle=':', linewidth=2, alpha=0.7)
    ax.annotate(f'λ*={lambda_crit:.3f}', xy=(lambda_crit, baseline_acc*100),
                xytext=(lambda_crit+0.1, baseline_acc*100+5),
                fontsize=12, color='green',
                arrowprops=dict(arrowstyle='->', color='green'))

ax.set_xlabel('Corruption Rate (λ)', fontsize=13)
ax.set_ylabel('Accuracy (%)', fontsize=13)
ax.set_title(f'CoT Collapse Curve: {MODEL_CHOICE}', fontsize=14)
ax.set_xlim(-0.05, 1.05)
# Keep a 20% floor when possible, but extend it down (never below 0) so that
# low accuracies at high λ, or a low baseline, are not clipped off the chart.
y_low = max(0, min([20] + [a - 5 for a in accs] + [baseline_acc * 100 - 5]))
ax.set_ylim(y_low, 105)
ax.legend(loc='lower left', fontsize=11)
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(f'{SAVE_DIR_EXP}/collapse_curve_{MODEL_SHORT}.png', dpi=300, bbox_inches='tight')
plt.show()
print(f'Saved: {SAVE_DIR_EXP}/collapse_curve_{MODEL_SHORT}.png')

## 14. Save Summary

In [None]:
summary = {
    'experiment': 'A3_Scaling_Law_GPT',
    'date': EXPERIMENT_DATE,
    'model': {
        'name': MODEL_CHOICE,
        'api_name': API_MODEL,
        'capability': MODEL_CAPABILITY
    },
    'n_problems': len(problems),
    'baseline_accuracy': baseline_acc,
    'accuracy_by_lambda': {str(k): v for k, v in acc_by_lam.items()},
    'lambda_crit': lambda_crit,
    # Compare with None explicitly: λ* = 0.0 is a valid estimate (A* = 1.0).
    'a_crit': 1 - lambda_crit if lambda_crit is not None else None
}

save_json(summary, f'{SAVE_DIR_EXP}/results/summary_{MODEL_SHORT}.json')

lambda_str = f'{lambda_crit:.3f}' if lambda_crit is not None else 'N/A'
a_crit_str = f'{1 - lambda_crit:.3f}' if lambda_crit is not None else 'N/A'

print('\n' + '='*60)
print(f'EXPERIMENT COMPLETE: {MODEL_CHOICE}')
print('='*60)
# Box interior is 40 columns wide; every row is padded to that width.
box_rule = '─' * 40
print('\n'.join([
    '',
    f'┌{box_rule}┐',
    f'│ {MODEL_CHOICE:^38} │',
    f'├{box_rule}┤',
    f'│ Baseline Accuracy: {baseline_acc:>19.1%} │',
    f'│ λ* (Backfire):     {lambda_str:>19} │',
    f'│ A* (Critical):     {a_crit_str:>19} │',
    f'└{box_rule}┘',
    '',
]))
print(f'Results saved to: {SAVE_DIR_EXP}')

## 15. Compare with All Models

In [None]:
# All results so far (update as experiments complete).
# GPT keys use the same display names as MODEL_CHOICE, so the current run
# replaces the stored entry for its model instead of adding a duplicate row.
all_results = {
    # Claude family
    'Claude 3 Haiku': {'baseline': 0.322, 'lambda_crit': 0.826},
    'Claude 3.5 Haiku': {'baseline': 0.467, 'lambda_crit': 1.0},
    'Claude 4 Sonnet': {'baseline': 0.925, 'lambda_crit': None},  # TBD
    # GPT family
    'GPT-3.5 Turbo': {'baseline': 0.46, 'lambda_crit': 0.693},  # Pilot
    'GPT-4o-mini': {'baseline': 0.44, 'lambda_crit': 0.783},  # Pilot
    'GPT-4o': {'baseline': 0.563, 'lambda_crit': 0.865},
}

# Add current result (overwrites the stored value for the selected model)
all_results[MODEL_CHOICE] = {'baseline': baseline_acc, 'lambda_crit': lambda_crit}

print('='*60)
print('ALL MODELS COMPARISON')
print('='*60)
print('\n| Model | Baseline | λ* |')
print('|-------|----------|-----|')
for model, data in sorted(all_results.items(), key=lambda x: x[1]['baseline'] or 0):
    # Explicit None checks so that a genuine 0.0 is printed rather than "TBD".
    b = f"{data['baseline']:.1%}" if data['baseline'] is not None else "TBD"
    lc = f"{data['lambda_crit']:.3f}" if data['lambda_crit'] is not None else "TBD"
    print(f'| {model} | {b} | {lc} |')