# Clean Trace Generation for L=5 and L=20

**Paper**: A2 (Cue-Dominant Extraction)

**Purpose**: Generate clean reasoning traces with L=5 and L=20 steps to enable
the Length × Cue experiment (E8), which supports the "Length Effects" claim in the paper's title.

**Design Principles** (from Sofia's guidance):
- Use the same model and the same temperature (0) for all traces
- Use the same template structure across all L values
- Verify that the final answer matches the ground truth
- Ensure that differences in trace quality do not confound the length effect

**Output**:
- `clean_traces_I5_v3.json` (199 traces, L=5)
- `clean_traces_I20_v3.json` (199 traces, L=20)

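**Expected API calls**: 199 × 2 = 398 trace requests, plus one API connectivity test and any retries (triggered by parse failures, validation failures such as answer mismatches, or API errors)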

**Date**: 2026-01-03
**GLOBAL_SEED**: 20251224

## 0. Google Drive Connection

In [None]:
from google.colab import drive
drive.mount('/content/drive')

import os
from datetime import datetime

# Run identity: the name is fixed, the date stamp is taken at execution time.
EXPERIMENT_NAME = 'trace_generation_L5_L20'
EXPERIMENT_DATE = datetime.now().strftime('%Y%m%d')

# Drive layout: generated traces go under the v3 data directory, while logs
# go to a per-run folder named after the experiment and its date stamp.
BASE_DIR = '/content/drive/MyDrive/CoT_Experiment'
V3_DATA_DIR = f'{BASE_DIR}/full_experiment_v3_20251224'
TRACES_OUTPUT_DIR = f'{V3_DATA_DIR}/clean_traces'
LOG_DIR = f'{BASE_DIR}/{EXPERIMENT_NAME}_{EXPERIMENT_DATE}'

# Create both output locations up front (no-op when they already exist).
for _directory in (TRACES_OUTPUT_DIR, LOG_DIR):
    os.makedirs(_directory, exist_ok=True)

for _label, _value in (
    ('Experiment', EXPERIMENT_NAME),
    ('V3 data directory', V3_DATA_DIR),
    ('Traces output', TRACES_OUTPUT_DIR),
    ('Log directory', LOG_DIR),
):
    print(f'{_label}: {_value}')

## 1. Install Dependencies

In [None]:
!pip install datasets anthropic matplotlib pandas tqdm -q
print('Dependencies installed.')

## 2. Configuration

In [None]:
import hashlib
import random
import json
import re
import time
from typing import List, Dict, Tuple, Optional, Any
from dataclasses import dataclass, asdict, field
from datetime import datetime
from tqdm import tqdm
import pandas as pd
import numpy as np

# -----------------------------------------------------------------------------
# Global configuration
# -----------------------------------------------------------------------------
GLOBAL_SEED = 20251224
TRACE_GEN_SEED = 20260103  # Seed for trace generation

# Step counts (L) for which traces are generated
TARGET_LENGTHS = [5, 20]

# API settings
API_MAX_TOKENS_TRACE = 2048    # Token budget per trace; sized for L=20
API_RETRY_DELAY = 1.0          # Base back-off in seconds after an API error
API_RATE_LIMIT_DELAY = 0.5     # Pause in seconds after each successful call
MAX_RETRIES_PER_PROBLEM = 3    # Attempts per problem when parsing/validation fails

_BANNER = '=' * 70
for _line in (
    _BANNER,
    'CLEAN TRACE GENERATION: L=5 and L=20',
    _BANNER,
    f'  GLOBAL_SEED: {GLOBAL_SEED}',
    f'  TRACE_GEN_SEED: {TRACE_GEN_SEED}',
    f'  Target lengths: {TARGET_LENGTHS}',
    _BANNER,
):
    print(_line)

## 3. Data Structures

In [None]:
@dataclass
class GSM8KProblem:
    """One math problem record; instances are built from the entries of problems_v3.json."""
    index: int  # Problem identifier; copied into CleanTrace.problem_index
    question: str  # Problem statement inserted into the generation prompt
    answer_text: str  # presumably the reference solution text - TODO confirm (not read in this notebook)
    final_answer: int  # Ground-truth integer answer that generated traces are validated against

@dataclass
class CleanTrace:
    """Clean reasoning trace - same structure as existing traces.

    Instances are serialized with ``asdict`` when the trace files are written
    and rebuilt with ``CleanTrace(**record)`` when the L=10 file is loaded.
    """
    problem_index: int  # Index of the GSM8KProblem this trace solves
    I: int  # Number of steps (L)
    steps: List[str]  # Step contents in order, without the "Step N:" prefix
    full_text: str  # Complete trace text: [[COT_START]], one "Step N: ..." line per step, [[COT_END]]

## 4. Utility Functions

In [None]:
def derive_seed(global_seed: int, problem_id: int, L: int) -> int:
    """Derive a reproducible 32-bit seed for one (problem, length) pair.

    The three inputs are combined into a tagged key string, hashed with
    SHA-256, and the leading 4 bytes of the digest are read as an unsigned
    big-endian integer.

    Args:
        global_seed: Experiment-wide seed.
        problem_id: Identifier of the problem.
        L: Target number of steps.

    Returns:
        An integer in the range [0, 2**32).
    """
    key_material = "|".join(
        (str(global_seed), "trace_gen", str(problem_id), f"L={L}")
    )
    digest = hashlib.sha256(key_material.encode("utf-8")).digest()
    # The first 4 digest bytes are the same value as the first 8 hex characters.
    return int.from_bytes(digest[:4], "big")

def save_json(data: Any, filepath: str):
    """Write *data* to *filepath* as UTF-8 JSON.

    Output is indented by two spaces and non-ASCII characters are written
    verbatim rather than as ``\\uXXXX`` escapes. An existing file is overwritten.
    """
    with open(filepath, mode='w', encoding='utf-8') as out_file:
        json.dump(data, out_file, indent=2, ensure_ascii=False)

def load_json(filepath: str) -> Any:
    """Read the UTF-8 encoded JSON file at *filepath* and return the decoded value."""
    with open(filepath, encoding='utf-8') as in_file:
        raw_text = in_file.read()
    return json.loads(raw_text)

## 5. Load Problems

In [None]:
# Read the problem set and turn every record into a GSM8KProblem
problems_path = f'{V3_DATA_DIR}/problems_v3.json'
problems_data = load_json(problems_path)
problems = [GSM8KProblem(**record) for record in problems_data]
print(f'Loaded {len(problems)} problems')

# The existing L=10 traces are only a reference point; their absence is
# reported but does not stop the notebook.
traces_10_path = f'{V3_DATA_DIR}/clean_traces/clean_traces_I10_v3.json'
if not os.path.exists(traces_10_path):
    print('Warning: L=10 traces not found for reference')
else:
    traces_10_data = load_json(traces_10_path)
    print(f'Reference: {len(traces_10_data)} existing L=10 traces')

## 6. API Setup

In [None]:
from getpass import getpass

# Read the key interactively at run time instead of hard-coding it in the notebook.
ANTHROPIC_API_KEY = getpass('Enter Anthropic API Key: ')
print('API Key set.')

In [None]:
import anthropic

# Module-level client; call_claude() below sends its requests through it.
client = anthropic.Anthropic(api_key=ANTHROPIC_API_KEY)

def call_claude(system_prompt: str, user_prompt: str, max_tokens: int = 2048, retries: int = 3) -> str:
    """Send one system/user prompt pair to the model and return the reply text.

    The request uses the fixed model "claude-sonnet-4-20250514" at
    temperature 0. After a successful call the function pauses for
    API_RATE_LIMIT_DELAY seconds before returning.

    Args:
        system_prompt: System prompt for the request.
        user_prompt: Content of the single user message.
        max_tokens: Upper bound on generated tokens.
        retries: Total number of attempts; must be at least 1.

    Returns:
        The text of the first content block of the response.

    Raises:
        ValueError: If ``retries`` is smaller than 1. Previously the loop was
            skipped and ``None`` was returned silently.
        Exception: The error from the final attempt, once all attempts fail.
    """
    if retries < 1:
        raise ValueError(f'retries must be >= 1, got {retries}')

    for attempt in range(1, retries + 1):
        try:
            message = client.messages.create(
                model="claude-sonnet-4-20250514",
                max_tokens=max_tokens,
                messages=[{"role": "user", "content": user_prompt}],
                system=system_prompt,
                temperature=0  # Deterministic
            )
            # Kept inside the try so an empty content list is retried as well.
            reply_text = message.content[0].text
        except Exception as e:
            print(f'API error (attempt {attempt}): {e}')
            if attempt == retries:
                raise
            # Linear back-off: wait longer after each failed attempt.
            time.sleep(API_RETRY_DELAY * attempt)
        else:
            time.sleep(API_RATE_LIMIT_DELAY)
            return reply_text

# Smoke test: one tiny request to confirm the key and client work
test_response = call_claude(
    system_prompt="You are a helpful assistant.",
    user_prompt='Say "API OK" and nothing else.',
    max_tokens=50,
)
print(f'API test: {test_response}')

## 7. Trace Generation Prompts

In [None]:
# System prompt template for trace generation. The "{L}" placeholders are
# filled in with str.format(L=...) by create_trace_generation_prompt().
# The [[COT_START]] / [[COT_END]] markers, the "Step N:" numbering and the
# "Final = X" ending are the same conventions that parse_trace_response()
# looks for when reading the reply.
TRACE_SYSTEM_PROMPT = """You are a math tutor who solves problems step by step.

CRITICAL FORMAT REQUIREMENTS:
1. Start with [[COT_START]]
2. Write EXACTLY {L} steps, numbered as "Step 1:", "Step 2:", etc.
3. Each step should contain ONE clear calculation or reasoning
4. The FINAL step (Step {L}) MUST end with "Final = X" where X is the integer answer
5. End with [[COT_END]]

EXAMPLE FORMAT (for L=3):
[[COT_START]]
Step 1: First, we identify that there are 5 apples at $2 each.
Step 2: Calculate the total: 5 × 2 = 10 dollars.
Step 3: Therefore, the total cost is Final = 10.
[[COT_END]]

IMPORTANT:
- You MUST have exactly {L} steps
- The final step MUST contain "Final = " followed by the integer answer
- Do NOT add any text after [[COT_END]]
"""

def create_trace_generation_prompt(problem: GSM8KProblem, L: int) -> Tuple[str, str]:
    """Build the (system, user) prompt pair requesting an L-step solution.

    Args:
        problem: Problem whose ``question`` text is embedded in the user prompt.
        L: Required number of steps; substituted into both prompts.

    Returns:
        ``(system, user)``: the formatted system prompt template and the
        user message.
    """
    system = TRACE_SYSTEM_PROMPT.format(L=L)
    # Assembled line by line so the prompt text is independent of source indentation.
    user_lines = [
        f'Solve the following math problem in EXACTLY {L} steps.',
        '',
        f'Problem: {problem.question}',
        '',
        'Remember:',
        f'- Use EXACTLY {L} steps (Step 1 through Step {L})',
        '- Final step must end with "Final = <answer>"',
        '- The correct answer is an integer',
        '',
        'Start your response with [[COT_START]]',
    ]
    return system, '\n'.join(user_lines)

## 8. Trace Parsing and Validation

In [None]:
def parse_trace_response(response: str, expected_L: int) -> Tuple[Optional[List[str]], Optional[int], str]:
    """
    Parse a generated trace into its step texts and final answer.

    The trace body is the text after ``[[COT_START]]``, up to the last
    ``[[COT_END]]`` when one is present. "Step N:" headers are located in
    order, each one searched for after the previous header, and the text
    between consecutive headers becomes one step with whitespace collapsed.

    A response is rejected when a step header is missing, when a step is
    empty, or when a "Step {expected_L + 1}:" header follows the last
    expected step. The last case was previously folded into the final step,
    so traces with too many steps were accepted as having exactly
    ``expected_L``.

    Args:
        response: Raw model output.
        expected_L: Number of steps the trace must contain (at least 1).

    Returns:
        (steps, extracted_answer, error_message)
        - steps: List of step contents (or None if parsing failed)
        - extracted_answer: The final answer from the trace (or None)
        - error_message: Description of any error (empty if success)
    """
    if expected_L < 1:
        return None, None, f"expected_L must be at least 1, got {expected_L}"

    # Extract content between markers
    body_match = re.search(r'\[\[COT_START\]\](.*)\[\[COT_END\]\]', response, re.DOTALL)
    if not body_match:
        # Tolerate a missing end marker
        body_match = re.search(r'\[\[COT_START\]\](.*)', response, re.DOTALL)
        if not body_match:
            return None, None, "Missing [[COT_START]] marker"

    content = body_match.group(1).strip()

    # Locate the headers strictly in order so that a mention of a later step
    # number inside an earlier step cannot be mistaken for that step's header.
    headers = []
    cursor = 0
    for number in range(1, expected_L + 1):
        header_pattern = re.compile(rf'Step\s*{number}\s*:', re.IGNORECASE)
        header = header_pattern.search(content, cursor)
        if header is None:
            return None, None, f"Missing Step {number}"
        headers.append(header)
        cursor = header.end()

    steps = []
    last_position = len(headers) - 1
    for position, header in enumerate(headers):
        is_last = position == last_position
        stop = len(content) if is_last else headers[position + 1].start()
        raw_step = content[header.end():stop]

        if is_last:
            # Ignore anything after an inner end marker.
            raw_step = raw_step.split('[[COT_END]]', 1)[0]
            # An extra header means the model wrote more steps than requested.
            if re.search(rf'Step\s*{expected_L + 1}\s*:', raw_step, re.IGNORECASE):
                return None, None, (
                    f"Expected {expected_L} steps, found extra Step {expected_L + 1}"
                )

        # Clean up newlines within step
        step_content = ' '.join(raw_step.split())
        if not step_content:
            return None, None, f"Empty Step {position + 1}"
        steps.append(step_content)

    # Extract final answer from last step
    # NOTE(review): "Final = 1,250" yields 1 because separators are not handled.
    final_step = steps[-1]
    answer_match = re.search(r'Final\s*=\s*(-?\d+)', final_step, re.IGNORECASE)
    if not answer_match:
        # Fall back to a trailing "= <number>"
        answer_match = re.search(r'=\s*(-?\d+)\s*\.?\s*$', final_step)

    if answer_match:
        return steps, int(answer_match.group(1)), ""

    # Last resort: the last integer mentioned in the final step
    numbers = re.findall(r'(-?\d+)', final_step)
    if numbers:
        return steps, int(numbers[-1]), ""

    return steps, None, "Could not extract final answer"

def validate_trace(steps: List[str], extracted_answer: int, correct_answer: int, L: int) -> Tuple[bool, str]:
    """
    Validate a generated trace.

    Checks, in order: the step count equals ``L``, the extracted answer
    equals the ground truth, and the final step contains "Final =" (case
    insensitive, optional whitespace before "="). The previous format check
    combined two substring tests with ``and``, so a final step passed if it
    contained either the word "final" or any "=" sign.

    Args:
        steps: Parsed step contents.
        extracted_answer: Answer extracted from the trace.
        correct_answer: Ground-truth answer for the problem.
        L: Required number of steps.

    Returns:
        (is_valid, error_message): error_message is empty when valid.
    """
    if len(steps) != L:
        return False, f"Wrong number of steps: {len(steps)} vs {L}"

    if extracted_answer != correct_answer:
        return False, f"Answer mismatch: {extracted_answer} vs {correct_answer}"

    # Check that final step has proper format (also guards the empty-steps case)
    if not steps or not re.search(r'final\s*=', steps[-1], re.IGNORECASE):
        return False, "Final step missing 'Final =' format"

    return True, ""

In [None]:
# Exercise the parser on a hand-written 5-step sample
test_response = """[[COT_START]]
Step 1: The problem states there are 5 boxes with 3 apples each.
Step 2: Calculate total apples: 5 × 3 = 15.
Step 3: Each apple costs $2, so total = 15 × 2 = 30.
Step 4: Add tax of $5: 30 + 5 = 35.
Step 5: Therefore, the answer is Final = 35.
[[COT_END]]"""

steps, answer, error = parse_trace_response(test_response, 5)
parsed_count = len(steps) if steps else 0
print(f'Parsed steps: {parsed_count}')
print(f'Extracted answer: {answer}')
print(f'Error: {error}')
for number, text in enumerate(steps or [], start=1):
    # Long steps are cut to 60 characters for display
    shown = f'{text[:60]}...' if len(text) > 60 else text
    print(f'  Step {number}: {shown}')

## 9. Trace Generation Function

In [None]:
def generate_clean_trace(
    problem: GSM8KProblem,
    L: int,
    max_attempts: int = MAX_RETRIES_PER_PROBLEM
) -> Tuple[Optional[CleanTrace], Dict]:
    """
    Produce a validated L-step trace for one problem.

    Each attempt requests a trace, parses it and validates it against the
    problem's ground-truth answer. The first attempt that passes is returned;
    every failed attempt adds a message to the log.

    Args:
        problem: Problem to solve.
        L: Required number of steps.
        max_attempts: Maximum number of generate/parse/validate rounds.

    Returns:
        (trace, log_info): trace is None when every attempt failed. log_info
        holds 'problem_index', 'L', 'attempts', 'success' and 'errors'.
    """
    log_info = {
        'problem_index': problem.index,
        'L': L,
        'attempts': 0,
        'success': False,
        'errors': []
    }
    errors = log_info['errors']

    for attempt_no in range(1, max_attempts + 1):
        log_info['attempts'] = attempt_no

        try:
            system, user = create_trace_generation_prompt(problem, L)
            response = call_claude(system, user, max_tokens=API_MAX_TOKENS_TRACE)

            steps, extracted_answer, parse_error = parse_trace_response(response, L)
            if parse_error:
                errors.append(f"Attempt {attempt_no}: Parse error - {parse_error}")
                continue

            is_valid, valid_error = validate_trace(steps, extracted_answer, problem.final_answer, L)
            if not is_valid:
                errors.append(f"Attempt {attempt_no}: Validation error - {valid_error}")
                continue

            # Re-render the trace from the parsed steps so the stored text
            # always has one "Step N:" line per step between the markers.
            numbered = [f'Step {n}: {text}' for n, text in enumerate(steps, start=1)]
            full_text = '\n'.join(['[[COT_START]]', *numbered, '[[COT_END]]'])

            trace = CleanTrace(
                problem_index=problem.index,
                I=L,
                steps=steps,
                full_text=full_text
            )

            log_info['success'] = True
            return trace, log_info

        except Exception as e:
            # Any failure in this attempt is logged and the next attempt runs.
            errors.append(f"Attempt {attempt_no}: Exception - {e}")
            continue

    return None, log_info

## 10. Generate L=5 Traces

In [None]:
_rule = '=' * 70
print(_rule)
print('GENERATING L=5 TRACES')
print(_rule)

L5_traces, L5_logs, L5_failures = [], [], []

for prob in tqdm(problems, desc='L=5 traces'):
    trace, log_info = generate_clean_trace(prob, L=5)
    L5_logs.append(log_info)

    if not trace:
        # Remember which problems could not be solved in the required format
        L5_failures.append(prob.index)
        print(f'  Failed: Problem {prob.index}')
        continue

    L5_traces.append(trace)

print(f'\nL=5 Generation Complete:')
print(f'  Success: {len(L5_traces)}/{len(problems)}')
print(f'  Failed: {len(L5_failures)}')
if L5_failures:
    # Only the first ten failed indices are listed when there are more
    if len(L5_failures) > 10:
        shown_failures = f'{L5_failures[:10]}...'
    else:
        shown_failures = f'{L5_failures}'
    print(f'  Failed indices: {shown_failures}')

In [None]:
# Save L=5 traces (only when at least one trace was generated)
if L5_traces:
    L5_output_path = f'{TRACES_OUTPUT_DIR}/clean_traces_I5_v3.json'
    save_json([asdict(t) for t in L5_traces], L5_output_path)
    print(f'L=5 traces saved: {L5_output_path}')
else:
    print('Warning: no L=5 traces were generated; trace file not written.')

# Save the log unconditionally: it holds the attempt counts and error
# messages for every problem, which matter most when generation failed.
L5_log_path = f'{LOG_DIR}/L5_generation_log.json'
save_json(L5_logs, L5_log_path)
print(f'L=5 log saved: {L5_log_path}')

## 11. Generate L=20 Traces

In [None]:
_rule = '=' * 70
print(_rule)
print('GENERATING L=20 TRACES')
print(_rule)

L20_traces, L20_logs, L20_failures = [], [], []

for prob in tqdm(problems, desc='L=20 traces'):
    trace, log_info = generate_clean_trace(prob, L=20)
    L20_logs.append(log_info)

    if not trace:
        # Remember which problems could not be solved in the required format
        L20_failures.append(prob.index)
        print(f'  Failed: Problem {prob.index}')
        continue

    L20_traces.append(trace)

print(f'\nL=20 Generation Complete:')
print(f'  Success: {len(L20_traces)}/{len(problems)}')
print(f'  Failed: {len(L20_failures)}')
if L20_failures:
    # Only the first ten failed indices are listed when there are more
    if len(L20_failures) > 10:
        shown_failures = f'{L20_failures[:10]}...'
    else:
        shown_failures = f'{L20_failures}'
    print(f'  Failed indices: {shown_failures}')

In [None]:
# Save L=20 traces (only when at least one trace was generated)
if L20_traces:
    L20_output_path = f'{TRACES_OUTPUT_DIR}/clean_traces_I20_v3.json'
    save_json([asdict(t) for t in L20_traces], L20_output_path)
    print(f'L=20 traces saved: {L20_output_path}')
else:
    print('Warning: no L=20 traces were generated; trace file not written.')

# Save the log unconditionally: it holds the attempt counts and error
# messages for every problem, which matter most when generation failed.
L20_log_path = f'{LOG_DIR}/L20_generation_log.json'
save_json(L20_logs, L20_log_path)
print(f'L=20 log saved: {L20_log_path}')

## 12. Quality Check

In [None]:
_rule = '=' * 70
print(_rule)
print('QUALITY CHECK')
print(_rule)

# Existing L=10 traces are the comparison baseline; fall back to an empty
# list (and therefore an empty index set) when the file is absent.
L10_traces = (
    [CleanTrace(**record) for record in load_json(traces_10_path)]
    if os.path.exists(traces_10_path)
    else []
)
L10_indices = {t.problem_index for t in L10_traces}
L5_indices = {t.problem_index for t in L5_traces}
L20_indices = {t.problem_index for t in L20_traces}

# Problems that have a trace at every one of the three lengths
common_indices = set.intersection(L5_indices, L10_indices, L20_indices)

print(f'\nCoverage:')
print(f'  L=5:  {len(L5_traces)} traces')
print(f'  L=10: {len(L10_traces)} traces')
print(f'  L=20: {len(L20_traces)} traces')
print(f'  Common across all: {len(common_indices)} problems')

# Spot check: show the first and last step at each length for one problem
_dashes = '-' * 70
print('\n' + _dashes)
print('SAMPLE TRACES (first common problem)')
print(_dashes)

if common_indices:
    sample_idx = min(common_indices)

    for traces, L in [(L5_traces, 5), (L10_traces, 10), (L20_traces, 20)]:
        # First trace for the sample problem at this length, if any
        sample = next((t for t in traces if t.problem_index == sample_idx), None)
        if sample is None:
            continue
        print(f'\nL={L} (Problem {sample_idx}):')
        print(f'  Steps: {len(sample.steps)}')
        print(f'  Step 1: {sample.steps[0][:60]}...')
        print(f'  Final step: {sample.steps[-1][:60]}...')

In [None]:
# Re-check that stored traces end in the ground-truth answer
_dashes = '-' * 70
print('\n' + _dashes)
print('ANSWER VERIFICATION')
print(_dashes)

# Lookup table from problem index to problem record
prob_map = {problem.index: problem for problem in problems}

def extract_final_answer(trace: CleanTrace) -> Optional[int]:
    """Return the integer answer stated in the last step of *trace*.

    Tries, in order: an explicit "Final = N", a trailing "= N" at the end of
    the step, and finally the last integer that appears anywhere in the step.
    Returns None when the step contains no integer at all.
    """
    last_step = trace.steps[-1]

    found = (
        re.search(r'Final\s*=\s*(-?\d+)', last_step, re.IGNORECASE)
        or re.search(r'=\s*(-?\d+)\s*\.?\s*$', last_step)
    )
    if found:
        return int(found.group(1))

    integers = re.findall(r'(-?\d+)', last_step)
    return int(integers[-1]) if integers else None

for traces, L in [(L5_traces, 5), (L20_traces, 20)]:
    # Collect every trace whose stated answer differs from the ground truth
    disagreements = []
    for t in traces:
        extracted = extract_final_answer(t)
        correct = prob_map[t.problem_index].final_answer
        if extracted != correct:
            disagreements.append((t.problem_index, extracted, correct))

    # Show at most the first three mismatches in detail
    for problem_index, extracted, correct in disagreements[:3]:
        print(f'  L={L} Problem {problem_index}: extracted={extracted}, correct={correct}')

    mismatches = len(disagreements)
    print(f'L={L}: {mismatches} answer mismatches out of {len(traces)}')

## 13. Summary

In [None]:
# Paths of the two trace files, listed in the summary and in the console report
_output_files = {
    'L5': f'{TRACES_OUTPUT_DIR}/clean_traces_I5_v3.json',
    'L20': f'{TRACES_OUTPUT_DIR}/clean_traces_I20_v3.json',
}

# Per-length outcome: (successful traces, indices of failed problems)
_outcomes = {
    'L5': (L5_traces, L5_failures),
    'L20': (L20_traces, L20_failures),
}

summary = {
    'experiment': 'trace_generation_L5_L20',
    'date': EXPERIMENT_DATE,
    'seeds': {
        'global': GLOBAL_SEED,
        'trace_gen': TRACE_GEN_SEED
    },
    'results': {
        length_key: {
            'total_problems': len(problems),
            'successful_traces': len(generated),
            'failed_indices': failed,
        }
        for length_key, (generated, failed) in _outcomes.items()
    },
    'common_problems': len(common_indices),
    'output_files': _output_files,
}

save_json(summary, f'{LOG_DIR}/generation_summary.json')

_rule = '=' * 70
for _line in (
    _rule,
    'TRACE GENERATION COMPLETE',
    _rule,
    f'Date: {EXPERIMENT_DATE}',
    f'\nResults:',
    f'  L=5:  {len(L5_traces)} traces generated',
    f'  L=20: {len(L20_traces)} traces generated',
    f'  Common problems: {len(common_indices)}',
    f'\nOutput files:',
    f'  {TRACES_OUTPUT_DIR}/clean_traces_I5_v3.json',
    f'  {TRACES_OUTPUT_DIR}/clean_traces_I20_v3.json',
    f'\nLogs: {LOG_DIR}',
    _rule,
    '\nNext: Run E8_length_cue_experiment.ipynb',
):
    print(_line)