# Comprehensive Soft Thinking Comparison: Standard vs Vanilla vs Enhanced

This notebook provides a comprehensive comparison of multiple sampling strategies:

1. **Standard Sampling**: Traditional discrete token generation
2. **Vanilla Soft Thinking**: Basic soft token approach
3. **Soft Thinking + Dirichlet Noise**: With stochastic perturbation
4. **Soft Thinking + Gumbel-Softmax**: With Gumbel noise injection

## Enhanced Features from RicardoLaMo/Soft-Thinking-rl

### 1. **Noise Injection Strategies**

**Dirichlet Noise**:
- Samples from Dirichlet distribution: `Dir(α + (1-α)·p)`
- Maintains general distribution shape while adding stochasticity
- Good for exploring nearby probability regions

**Gumbel-Softmax**:
- Adds Gumbel noise to logits: `logits' = (logits + Gumbel) / τ`
- Reparameterization trick for categorical sampling
- Temperature `τ` controls exploration vs exploitation

### 2. **Dual Temperature Control**

- **Thinking Temperature**: Used during soft thinking steps (exploration)
- **Output Temperature**: Used for final token selection (exploitation)
- Enables separate control of exploration during reasoning vs final decision

### 3. **Multi-Criteria Early Stopping**

- **Entropy-based**: Stop when model is confident (low entropy)
- **Length-based**: Stop after max thinking steps to control compute
- Tracks which criterion triggered stopping for analysis

In [None]:
# Import required libraries
import os
import sys
import json
import random
from pathlib import Path
from typing import List, Dict, Tuple, Optional
import warnings
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm
from IPython.display import display, Markdown, HTML
import re

# Statistical tests
from scipy import stats
from scipy.stats import wilcoxon, friedmanchisquare

# PyTorch
import torch
import torch.nn.functional as F
import transformers

# Add project path
project_root = Path('/home/wliu23/github/reasoning-with-sampling/llm_experiments')
sys.path.insert(0, str(project_root))

from grader_utils.parse_utils import parse_answer
from constants import *
from power_samp_utils import AutoregressiveSampler, SoftThinkingSampler, soft_thinking_generate, format_prompt

# Seed every RNG used in this notebook (Python, NumPy, PyTorch CPU and CUDA)
# with the same value so sampled generations are repeatable.
SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# Plotting: global look applied to every figure drawn below.
# The 'seaborn-v0_8-*' style names exist in matplotlib >= 3.6 (older releases
# name this style 'seaborn-darkgrid').
plt.style.use('seaborn-v0_8-darkgrid')
# Default seaborn colour cycle for plots that do not pass explicit colours.
sns.set_palette("husl")
%matplotlib inline

print("✓ Libraries imported successfully")

## 1. Configuration

In [None]:
# Point all HuggingFace caches at one directory and force offline mode.
# NOTE(review): `transformers` is imported in the first code cell, before these
# variables are set. If the library reads them at import time they will not take
# effect here -- confirm, or move this cell above the imports. The tokenizer/model
# loaders pass `cache_dir` and `local_files_only=True` explicitly.
HF_HOME = Path("/home/wliu23/projects/reasoning-with-sampling/.hf_home")
_hub_cache = str(HF_HOME / "hub")
os.environ.update({
    "HF_HOME": str(HF_HOME),
    "HF_HUB_CACHE": _hub_cache,
    "TRANSFORMERS_CACHE": _hub_cache,
    "HF_DATASETS_CACHE": str(HF_HOME / "datasets"),
    "HF_HUB_OFFLINE": "1",
})

In [None]:
# Experiment configuration

# Settings shared by every Soft Thinking variant. Each variant below is a copy of
# these with only its noise settings added, so the variants differ in noise alone.
_soft_common = {
    'num_thinking_steps': 3,
    'max_topk': 10,
    'min_p': 0.001,
    'early_stopping_entropy_threshold': 0.05,
    'early_stopping_length_threshold': 5,  # Max thinking steps
    'thinking_temperature': 1.0,
    'output_temperature': 0.7,  # Sharper for final output
}

CONFIG = {
    # Model settings
    'model': 'qwen_math',
    'model_str': 'Qwen/Qwen2.5-Math-7B',
    'cache_dir': os.environ['TRANSFORMERS_CACHE'],
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',

    # Dataset settings
    'dataset': 'MATH',
    'dataset_path': '/home/wliu23/github/reasoning-with-sampling/llm_experiments/data/MATH500.json',
    'num_questions': 20,  # Start small

    # Generation settings
    'max_new_tokens': 3072,
    'cot': True,

    # Standard sampling
    'standard': {
        'temperature': 1.0,
    },

    # Soft Thinking: Vanilla (no noise)
    'soft_vanilla': {**_soft_common, 'noise_type': 'none'},

    # Soft Thinking: Dirichlet Noise (noise_alpha: concentration parameter)
    'soft_dirichlet': {**_soft_common, 'noise_type': 'dirichlet', 'noise_alpha': 0.1},

    # Soft Thinking: Gumbel-Softmax (gumbel_tau: Gumbel-Softmax temperature)
    'soft_gumbel': {**_soft_common, 'noise_type': 'gumbel', 'gumbel_tau': 1.0},

    # Output settings
    'save_dir': '/home/wliu23/github/reasoning-with-sampling/notebooks/results/comprehensive_soft_thinking',
    'alpha': 0.05,  # Statistical significance
}

# Create the output directory up front so later cells can write into it.
os.makedirs(CONFIG['save_dir'], exist_ok=True)

print("Configuration Summary:")
for _label, _value in (
    ('Model', CONFIG['model_str']),
    ('Device', CONFIG['device']),
    ('Questions', CONFIG['num_questions']),
    ('Methods', 'Standard, Vanilla Soft, Dirichlet Soft, Gumbel Soft'),
    ('Output', CONFIG['save_dir']),
):
    print(f"  {_label}: {_value}")

## 2. Load Dataset and Model, Initialize Samplers

In [None]:
# Load dataset
# Decode the JSON file as UTF-8 explicitly so the result does not depend on the
# platform's default locale encoding.
with open(CONFIG['dataset_path'], 'r', encoding='utf-8') as f:
    dataset = json.load(f)

print(f"Loaded {len(dataset)} problems from MATH500")
# The experiment loop slices dataset[:num_questions], so it can never process more
# questions than the file contains; report the number that will actually run.
print(f"Will process {min(CONFIG['num_questions'], len(dataset))} questions")

In [None]:
# Load model
print(f"Loading model: {CONFIG['model_str']}")

# Arguments common to the tokenizer and the model: resolve files from the local
# cache directory only.
_pretrained_kwargs = dict(
    cache_dir=CONFIG['cache_dir'],
    local_files_only=True,
    trust_remote_code=True,
)

tokenizer = transformers.AutoTokenizer.from_pretrained(
    CONFIG['model_str'], **_pretrained_kwargs
)

# Reuse EOS as the padding token when the tokenizer does not define one.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

# Weights are loaded in bfloat16 with the whole model mapped to a single device.
model = transformers.AutoModelForCausalLM.from_pretrained(
    CONFIG['model_str'],
    torch_dtype=torch.bfloat16,
    device_map={'': CONFIG['device']},
    **_pretrained_kwargs,
)
model = model.to(CONFIG['device'])

print(f"✓ Model loaded on {CONFIG['device']}")

In [None]:
# Initialize all samplers
print("Initializing samplers...")


def _make_soft_sampler(config_key):
    """Build a SoftThinkingSampler from the CONFIG entry named ``config_key``.

    Every setting in that entry is forwarded as a keyword argument except
    'num_thinking_steps', which is supplied per call at generation time.
    """
    sampler_kwargs = dict(CONFIG[config_key])
    sampler_kwargs.pop('num_thinking_steps', None)
    return SoftThinkingSampler(model, tokenizer, CONFIG['device'], **sampler_kwargs)


# 1. Vanilla Soft Thinking
soft_vanilla_sampler = _make_soft_sampler('soft_vanilla')
print("  ✓ Vanilla Soft Thinking (no noise)")

# 2. Dirichlet Soft Thinking
soft_dirichlet_sampler = _make_soft_sampler('soft_dirichlet')
print("  ✓ Soft Thinking + Dirichlet Noise")

# 3. Gumbel Soft Thinking
soft_gumbel_sampler = _make_soft_sampler('soft_gumbel')
print("  ✓ Soft Thinking + Gumbel-Softmax")

print("\nAll samplers initialized successfully!")

## 3. Generation Functions

In [None]:
def generate_standard(model, tokenizer, input_ids, device, max_new_tokens=3072, temperature=1.0) -> Dict:
    """Sample one completion with ordinary discrete-token generation.

    Args:
        model: Causal LM exposing a HuggingFace-style ``generate`` method.
        tokenizer: Tokenizer paired with ``model``; provides ``pad_token_id`` and ``decode``.
        input_ids: Prompt token ids as a 2-D tensor. Only row 0 of the output is
            read back, so a batch of one is expected.
        device: Not used by this function; kept so existing callers keep working.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature passed to ``generate``.

    Returns:
        Dict with the decoded ``completion``, the parsed ``answer``, per-token
        ``tokens`` / ``token_ids`` / ``log_probs``, plus ``cumulative_log_prob``,
        ``num_tokens`` and ``mean_log_prob``. Log-probabilities are taken from the
        per-step scores returned by ``generate``.
    """
    attention_mask = torch.ones_like(input_ids)

    output = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.pad_token_id,
        return_dict_in_generate=True,
        output_scores=True,
        do_sample=True,
        temperature=temperature,
    )

    # Drop the prompt: everything after it is the sampled continuation.
    generated_ids = output.sequences[0][len(input_ids[0]):]
    completion = tokenizer.decode(generated_ids, skip_special_tokens=True)
    parsed_answer = parse_answer(completion)

    # Move the ids to the host once instead of indexing a device tensor per token.
    id_list = generated_ids.cpu().tolist()
    tokens = [tokenizer.decode([token_id]) for token_id in id_list]

    log_probs = []
    # zip stops at the shorter sequence, so a token without a score entry is skipped.
    for step_scores, token_id in zip(output.scores, id_list):
        # Upcast before log-softmax: the model is loaded in bfloat16, and
        # low-precision log-probs would accumulate error over long completions.
        # This is a no-op when the scores are already float32.
        step_log_probs = F.log_softmax(step_scores[0].float(), dim=-1)
        log_probs.append(step_log_probs[token_id].item())

    return {
        'completion': completion,
        'answer': parsed_answer,
        'tokens': tokens,
        'token_ids': id_list,
        'log_probs': log_probs,
        'cumulative_log_prob': sum(log_probs) if log_probs else 0.0,
        'num_tokens': len(tokens),
        'mean_log_prob': np.mean(log_probs) if log_probs else 0.0,
    }


def generate_soft_thinking(soft_sampler, prefix, max_new_tokens, num_thinking_steps, config_name) -> Dict:
    """Run Soft Thinking generation and collect per-run statistics.

    Args:
        soft_sampler: Sampler passed to ``soft_thinking_generate``; its ``tokenizer``
            attribute is used to decode the output.
        prefix: Prompt token ids as a list; its length marks where generation starts.
        max_new_tokens: Generation budget forwarded to ``soft_thinking_generate``.
        num_thinking_steps: Thinking-step setting forwarded to ``soft_thinking_generate``.
        config_name: Label stored under ``'config'`` in the result.

    Returns:
        Dict with the decoded ``completion``, the parsed ``answer``, per-token
        ``tokens`` / ``token_ids`` / ``log_probs``, aggregate log-prob and length
        statistics, the raw ``soft_info`` records, ``avg_thinking_steps``,
        ``total_thinking_steps`` and a ``stopping_reasons`` count per stop criterion.
    """
    token_ids, log_probs, soft_info, avg_thinking_steps = soft_thinking_generate(
        soft_sampler, prefix, max_new_tokens, num_thinking_steps
    )

    # The returned ids start with the prompt; keep only the generated part.
    generated_ids = token_ids[len(prefix):]
    tokenizer = soft_sampler.tokenizer
    completion = tokenizer.decode(generated_ids, skip_special_tokens=True)
    parsed_answer = parse_answer(completion)
    tokens = [tokenizer.decode([tid]) for tid in generated_ids]

    # Analyze stopping criteria. The three expected labels are pre-seeded so they
    # always appear in the output. A label outside this set is counted under its
    # own key instead of raising KeyError and aborting the run.
    stopping_reasons = {'entropy': 0, 'length': 0, 'max_steps': 0}
    for info in soft_info:
        reason = info['stopped_by']
        stopping_reasons[reason] = stopping_reasons.get(reason, 0) + 1

    return {
        'completion': completion,
        'answer': parsed_answer,
        'tokens': tokens,
        'token_ids': generated_ids,
        'log_probs': log_probs,
        'cumulative_log_prob': sum(log_probs) if log_probs else 0.0,
        'num_tokens': len(tokens),
        'mean_log_prob': np.mean(log_probs) if log_probs else 0.0,
        'avg_thinking_steps': avg_thinking_steps,
        'soft_info': soft_info,
        'total_thinking_steps': sum(info['thinking_steps'] for info in soft_info),
        'stopping_reasons': stopping_reasons,
        'config': config_name,
    }

print("✓ Generation functions defined")

## 4. Run Comprehensive Experiments

In [None]:
# Data collection: run all four methods on each question and record the outcomes.
results = []
questions_to_process = dataset[:CONFIG['num_questions']]
_rule = '=' * 80

# (result key, console heading, sampler) for each Soft Thinking variant, in run order.
# The key is also the config label and selects CONFIG[f'soft_{key}'].
_soft_runs = [
    ('vanilla', 'Vanilla Soft Thinking', soft_vanilla_sampler),
    ('dirichlet', 'Soft Thinking + Dirichlet', soft_dirichlet_sampler),
    ('gumbel', 'Soft Thinking + Gumbel', soft_gumbel_sampler),
]

for idx, data in enumerate(tqdm(questions_to_process, desc="Processing MATH problems")):
    question = data['prompt']
    correct_answer = data['answer']

    print(f"\n{_rule}")
    print(f"Question {idx+1}/{len(questions_to_process)}")
    print(f"{_rule}")
    print(f"Q: {question[:100]}...")

    # Prepare input: tensor form for model.generate, plain list for the soft samplers
    input_text = format_prompt(question, CONFIG['model'], tokenizer, CONFIG['cot'])
    input_ids = tokenizer.encode(input_text, return_tensors="pt").to(CONFIG['device'])
    prefix = input_ids[0].tolist()

    record = {
        'question_idx': idx,
        'question': question,
        'correct_answer': correct_answer,
    }

    # 1. Standard sampling (baseline)
    print("\n[Standard Sampling]")
    std_result = generate_standard(
        model, tokenizer, input_ids, CONFIG['device'],
        CONFIG['max_new_tokens'], CONFIG['standard']['temperature']
    )
    print(f"  Answer: {std_result['answer']}")
    print(f"  Tokens: {std_result['num_tokens']}")
    print(f"  Cumulative log-prob: {std_result['cumulative_log_prob']:.4f}")
    record['std'] = std_result
    record['std_correct'] = std_result['answer'] == correct_answer

    # 2-4. Soft Thinking variants, all run on the same prompt
    for key, heading, sampler in _soft_runs:
        print(f"\n[{heading}]")
        soft_result = generate_soft_thinking(
            sampler, prefix,
            CONFIG['max_new_tokens'],
            CONFIG[f'soft_{key}']['num_thinking_steps'],
            key
        )
        print(f"  Answer: {soft_result['answer']}")
        print(f"  Tokens: {soft_result['num_tokens']}")
        print(f"  Cumulative log-prob: {soft_result['cumulative_log_prob']:.4f}")
        print(f"  Avg thinking steps: {soft_result['avg_thinking_steps']:.2f}")
        record[key] = soft_result
        record[f'{key}_correct'] = soft_result['answer'] == correct_answer

    # Store results
    results.append(record)

print(f"\n\n✓ Completed data collection for {len(results)} questions")

In [None]:
# Save results
results_file = os.path.join(CONFIG['save_dir'], 'comprehensive_results.json')

# Per-method scalar fields copied into the JSON file. Token lists and the raw
# soft_info records are not written.
_scalar_keys = ['answer', 'num_tokens', 'cumulative_log_prob', 'mean_log_prob']

# Build the whole payload BEFORE opening the file: opening with 'w' truncates it,
# so an error while assembling the records must not destroy a previous results file.
json_results = []
for r in results:
    json_r = {
        'question_idx': r['question_idx'],
        'question': r['question'],
        'correct_answer': r['correct_answer'],
        'std_correct': r['std_correct'],
        'vanilla_correct': r['vanilla_correct'],
        'dirichlet_correct': r['dirichlet_correct'],
        'gumbel_correct': r['gumbel_correct'],
    }
    for method in ['std', 'vanilla', 'dirichlet', 'gumbel']:
        for key in _scalar_keys:
            if key in r[method]:
                json_r[f'{method}_{key}'] = r[method][key]
        # Thinking-step statistics exist only for the Soft Thinking methods.
        if method != 'std':
            json_r[f'{method}_avg_thinking_steps'] = r[method]['avg_thinking_steps']
            json_r[f'{method}_total_thinking_steps'] = r[method]['total_thinking_steps']
    json_results.append(json_r)

with open(results_file, 'w', encoding='utf-8') as f:
    json.dump(json_results, f, indent=2)

print(f"✓ Saved results to {results_file}")

## 5. Comparative Analysis

Compare all four methods across multiple metrics.

In [None]:
# Create comprehensive summary: one row per question, one column group per method.
methods = ['std', 'vanilla', 'dirichlet', 'gumbel']
method_names = ['Standard', 'Vanilla Soft', 'Dirichlet Soft', 'Gumbel Soft']

summary_data = []
for r in results:
    row = {'question_idx': r['question_idx']}
    for method in methods:
        outcome = r[method]
        row[f'{method}_log_prob'] = outcome['cumulative_log_prob']
        row[f'{method}_tokens'] = outcome['num_tokens']
        row[f'{method}_correct'] = r[f'{method}_correct']
        # Standard sampling has no thinking steps, so it gets no such column.
        if method != 'std':
            row[f'{method}_thinking_steps'] = outcome['avg_thinking_steps']
    summary_data.append(row)

df_summary = pd.DataFrame(summary_data)

_rule = "=" * 80
print("\n" + _rule)
print("COMPREHENSIVE METHOD COMPARISON")
print(_rule)

for method, name in zip(methods, method_names):
    accuracy = df_summary[f'{method}_correct'].mean()
    mean_log_prob = df_summary[f'{method}_log_prob'].mean()
    mean_tokens = df_summary[f'{method}_tokens'].mean()
    print(f"\n{name}:")
    print(f"  Accuracy: {accuracy:.2%}")
    print(f"  Mean log-prob: {mean_log_prob:.4f}")
    print(f"  Mean tokens: {mean_tokens:.2f}")
    if method != 'std':
        print(f"  Avg thinking steps: {df_summary[f'{method}_thinking_steps'].mean():.2f}")

df_summary.head(10)

## 6. Statistical Analysis

Use the Friedman test (non-parametric) to compare the cumulative log-probabilities of all four methods as related samples, followed by pairwise Wilcoxon signed-rank tests.

In [None]:
# Friedman test for log-probabilities
# Omnibus test: each question is one block and the four methods are the treatments.
# `stat` and `p_value` are reused by the summary report cell.
stat, p_value = friedmanchisquare(
    df_summary['std_log_prob'],
    df_summary['vanilla_log_prob'],
    df_summary['dirichlet_log_prob'],
    df_summary['gumbel_log_prob']
)

print("\n" + "="*80)
print("FRIEDMAN TEST (Multiple Related Samples)")
print("="*80)
print(f"\nTesting: Are the log-probabilities different across methods?")
print(f"  Test statistic: {stat:.4f}")
print(f"  p-value: {p_value:.6f}")

if p_value < CONFIG['alpha']:
    print(f"  ✓ REJECT H₀: Methods produce significantly different results (p < {CONFIG['alpha']})")
else:
    print(f"  ✗ FAIL TO REJECT H₀: No significant difference (p >= {CONFIG['alpha']})")

# Pairwise comparisons (Wilcoxon)
# NOTE: these p-values are not corrected for multiple comparisons; treat them as
# exploratory follow-ups to the omnibus test above.
print("\nPairwise Wilcoxon Tests:")
pairs = [
    ('Standard', 'Vanilla Soft', 'std_log_prob', 'vanilla_log_prob'),
    ('Vanilla Soft', 'Dirichlet Soft', 'vanilla_log_prob', 'dirichlet_log_prob'),
    ('Vanilla Soft', 'Gumbel Soft', 'vanilla_log_prob', 'gumbel_log_prob'),
    ('Dirichlet Soft', 'Gumbel Soft', 'dirichlet_log_prob', 'gumbel_log_prob'),
]

for name1, name2, col1, col2 in pairs:
    try:
        stat_w, p_w = wilcoxon(df_summary[col1], df_summary[col2])
    except ValueError as err:
        # The signed-rank test is undefined when, for example, every paired
        # difference is zero. Report that pair and keep going rather than
        # aborting the whole cell.
        print(f"  - {name1} vs {name2}: not computed ({err})")
        continue
    sig = "✓" if p_w < CONFIG['alpha'] else "✗"
    print(f"  {sig} {name1} vs {name2}: p={p_w:.4f}")

## 7. Visualization Dashboard

In [None]:
# Comprehensive visualization dashboard
# 3x3 grid: row 0 = headline metrics, row 1 = thinking steps and log-prob
# comparisons, row 2 = per-method log-prob difference against standard sampling.
fig = plt.figure(figsize=(20, 12))
gs = fig.add_gridspec(3, 3, hspace=0.3, wspace=0.3)

# One colour per method, aligned with `methods`; soft-only panels skip the first entry.
method_colors = ['gray', 'blue', 'green', 'orange']
soft_panels = list(zip(methods[1:], method_names[1:], method_colors[1:]))
log_prob_cols = [f'{m}_log_prob' for m in methods]

# 1. Accuracy comparison
ax1 = fig.add_subplot(gs[0, 0])
accuracies = [df_summary[f'{m}_correct'].mean() for m in methods]
ax1.bar(method_names, accuracies, alpha=0.7, color=method_colors)
ax1.set_ylabel('Accuracy')
ax1.set_title('Accuracy Comparison')
ax1.set_ylim([0, 1])
ax1.grid(True, alpha=0.3)

# 2. Log-prob box plot
ax2 = fig.add_subplot(gs[0, 1])
log_prob_data = [df_summary[col] for col in log_prob_cols]
ax2.boxplot(log_prob_data)
# Label the boxes through the axis instead of boxplot's `labels=` keyword, which
# newer matplotlib releases deprecate. Boxes sit at positions 1..N.
ax2.set_xticks(range(1, len(method_names) + 1))
ax2.set_xticklabels(method_names)
ax2.set_ylabel('Cumulative Log-Prob')
ax2.set_title('Log-Probability Distribution')
ax2.grid(True, alpha=0.3)
plt.setp(ax2.xaxis.get_majorticklabels(), rotation=45)

# 3. Token efficiency
ax3 = fig.add_subplot(gs[0, 2])
token_means = [df_summary[f'{m}_tokens'].mean() for m in methods]
ax3.bar(method_names, token_means, alpha=0.7, color=method_colors)
ax3.set_ylabel('Average Tokens')
ax3.set_title('Token Efficiency')
ax3.grid(True, alpha=0.3)

# 4. Thinking steps comparison (standard sampling has none, so it is excluded)
ax4 = fig.add_subplot(gs[1, 0])
thinking_steps = [df_summary[f'{m}_thinking_steps'].mean() for m in methods[1:]]
ax4.bar(method_names[1:], thinking_steps, alpha=0.7, color=method_colors[1:])
ax4.set_ylabel('Avg Thinking Steps')
ax4.set_title('Thinking Steps per Token')
ax4.grid(True, alpha=0.3)

# 5. Log-prob scatter: Standard vs each Soft method
ax5 = fig.add_subplot(gs[1, 1])
for method, name, color in soft_panels:
    ax5.scatter(df_summary['std_log_prob'], df_summary[f'{method}_log_prob'],
                alpha=0.6, label=name, color=color, s=50)
# Diagonal spanning the full data range; points above it beat standard sampling.
lims = [df_summary[log_prob_cols].min().min(),
        df_summary[log_prob_cols].max().max()]
ax5.plot(lims, lims, 'r--', alpha=0.5, label='y=x')
ax5.set_xlabel('Standard Log-Prob')
ax5.set_ylabel('Soft Thinking Log-Prob')
ax5.set_title('Standard vs Soft Thinking')
ax5.legend()
ax5.grid(True, alpha=0.3)

# 6. Improvement heatmap
# Cell [i, j] = mean log-prob of the COLUMN method j minus that of the ROW method i,
# so green means the column method scores higher than the row method.
ax6 = fig.add_subplot(gs[1, 2])
mean_log_probs = df_summary[log_prob_cols].mean().to_numpy()
improvements = mean_log_probs[np.newaxis, :] - mean_log_probs[:, np.newaxis]
# Symmetric colour range taken from the data. Cumulative log-probs over long
# completions can differ by far more than a fixed +/-50, which would saturate the map.
color_limit = np.nanmax(np.abs(improvements))
if not np.isfinite(color_limit) or color_limit == 0:
    color_limit = 1.0
im = ax6.imshow(improvements, cmap='RdYlGn', aspect='auto',
                vmin=-color_limit, vmax=color_limit)
ax6.set_xticks(range(len(methods)))
ax6.set_yticks(range(len(methods)))
ax6.set_xticklabels(method_names, rotation=45, ha='right')
ax6.set_yticklabels(method_names)
ax6.set_title('Log-Prob Improvement Matrix\n(Column - Row)')
plt.colorbar(im, ax=ax6)

# 7-9. Distribution histograms for each soft method
for col, (method, name, color) in enumerate(soft_panels):
    ax = fig.add_subplot(gs[2, col])
    diff = df_summary[f'{method}_log_prob'] - df_summary['std_log_prob']
    ax.hist(diff, bins=15, alpha=0.7, color=color, edgecolor='black')
    ax.axvline(0, color='red', linestyle='--', linewidth=2, label='No improvement')
    ax.axvline(diff.mean(), color='green', linestyle='--', linewidth=2,
               label=f'Mean={diff.mean():.2f}')
    ax.set_xlabel('Log-Prob Improvement')
    ax.set_ylabel('Frequency')
    ax.set_title(f'{name} vs Standard')
    ax.legend(fontsize=8)
    ax.grid(True, alpha=0.3)

plt.savefig(os.path.join(CONFIG['save_dir'], 'comprehensive_dashboard.png'), dpi=300, bbox_inches='tight')
plt.show()

print("\n✓ Visualization dashboard generated")

## 8. Summary Report

In [None]:
# Generate comprehensive report
# Relies on `methods`, `method_names` and `df_summary` from the comparison cell and
# on `stat` / `p_value` from the Friedman test cell.


def _top_method(metric, pick):
    """Return the display name of the method whose mean ``metric`` is selected by ``pick``.

    ``pick`` is ``max`` or ``min``; ties go to the method listed first.
    """
    means = [(name, df_summary[f'{m}_{metric}'].mean()) for m, name in zip(methods, method_names)]
    return pick(means, key=lambda pair: pair[1])[0]


friedman_significant = p_value < CONFIG['alpha']

report = f"""
{'='*80}
COMPREHENSIVE SOFT THINKING ANALYSIS
Standard vs Vanilla vs Dirichlet vs Gumbel-Softmax
{'='*80}

1. CONFIGURATION
{'-'*80}
Model: {CONFIG['model_str']}
Questions: {len(results)}
Methods:
  - Standard Sampling
  - Vanilla Soft Thinking (no noise)
  - Dirichlet Soft Thinking (alpha={CONFIG['soft_dirichlet']['noise_alpha']})
  - Gumbel-Softmax Soft Thinking (tau={CONFIG['soft_gumbel']['gumbel_tau']})

2. PERFORMANCE METRICS
{'-'*80}
"""

for method, name in zip(methods, method_names):
    report += f"""
{name}:
  Accuracy: {df_summary[f'{method}_correct'].mean():.2%} ({df_summary[f'{method}_correct'].sum()}/{len(df_summary)})
  Mean log-prob: {df_summary[f'{method}_log_prob'].mean():.4f}
  Mean tokens: {df_summary[f'{method}_tokens'].mean():.2f}
"""
    if method != 'std':
        report += f"  Avg thinking steps: {df_summary[f'{method}_thinking_steps'].mean():.2f}\n"

report += f"""
3. STATISTICAL ANALYSIS
{'-'*80}
Friedman Test (4 methods):
  Test statistic: {stat:.4f}
  p-value: {p_value:.6f}
  Result: {'SIGNIFICANT' if friedman_significant else 'NOT SIGNIFICANT'}

4. KEY FINDINGS
{'-'*80}
Best Accuracy: {_top_method('correct', max)}
Best Log-Prob: {_top_method('log_prob', max)}
Most Efficient (tokens): {_top_method('tokens', min)}

5. NOISE INJECTION ANALYSIS
{'-'*80}
Dirichlet vs Vanilla:
  Accuracy diff: {(df_summary['dirichlet_correct'].mean() - df_summary['vanilla_correct'].mean()):.2%}
  Log-prob diff: {(df_summary['dirichlet_log_prob'].mean() - df_summary['vanilla_log_prob'].mean()):.4f}

Gumbel vs Vanilla:
  Accuracy diff: {(df_summary['gumbel_correct'].mean() - df_summary['vanilla_correct'].mean()):.2%}
  Log-prob diff: {(df_summary['gumbel_log_prob'].mean() - df_summary['vanilla_log_prob'].mean()):.4f}

6. CONCLUSION
{'-'*80}
The Friedman test {'finds' if friedman_significant else 'does not find'} a statistically significant difference
in cumulative log-probability across the four methods (alpha={CONFIG['alpha']}).
This is an omnibus test: it does not identify which methods differ and says
nothing about accuracy. The dual temperature control and multi-criteria
early stopping provide fine-grained control over the thinking process.

{'='*80}
"""

print(report)

# Save report
report_file = os.path.join(CONFIG['save_dir'], 'comprehensive_report.txt')
with open(report_file, 'w', encoding='utf-8') as f:
    f.write(report)

print(f"\n✓ Report saved to: {report_file}")

In [None]:
# Save summary DataFrame, then list every file produced in the output directory.
_save_dir = CONFIG['save_dir']
df_file = os.path.join(_save_dir, 'comprehensive_summary.csv')
df_summary.to_csv(df_file, index=False)

_rule = "=" * 80
print("\n" + _rule)
print("✓ ANALYSIS COMPLETE!")
print(_rule)
print(f"\nAll results saved to: {_save_dir}")
print("\nGenerated files:")
for filename in os.listdir(_save_dir):
    filepath = os.path.join(_save_dir, filename)
    # Sub-directories and other non-file entries are skipped.
    if not os.path.isfile(filepath):
        continue
    print(f"  - {filename} ({os.path.getsize(filepath):,} bytes)")