# Entropy-Adaptive Branching Tutorial

This notebook is an interactive tutorial on Entropy-Adaptive Branching (EAB), which produces multiple samples efficiently by branching only at high-entropy (uncertain) decoding steps.

## 1. Installation and Setup

In [None]:
# Install if not already installed
# !pip install -e ..

from eab import EntropyAdaptiveBranching
from eab.utils import set_seed, visualize_branching_statistics
import warnings
warnings.filterwarnings('ignore')

# Set seed for reproducibility
set_seed(42)

## 2. Initialize EAB

We'll start with a small model (GPT-2) for quick experimentation.

In [None]:
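eab = EntropyAdaptiveBranching(
    model_name="gpt2",        # model to load
    entropy_threshold=0.4,    # branch when next-token entropy exceeds this value
    branch_factor=3,          # number of continuations created at each branch point
    max_paths=20              # upper limit on the number of concurrent paths
)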

print("✓ EAB initialized successfully!")
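
The notebook does not spell out how the three branching parameters interact. Below is a minimal sketch of one plausible per-step decision rule. It assumes that `entropy_threshold` is compared against Shannon entropy normalized by `log(vocab_size)` (so values fall in [0, 1]) and that the children of a branch are the top-`branch_factor` tokens. Both are assumptions rather than the library's confirmed behavior, and `normalized_entropy` and `expand` are hypothetical helpers.

```python
import math

def normalized_entropy(probs):
    """Shannon entropy of a next-token distribution divided by log(V), in [0, 1]."""
    h = -sum(p * math.log(p) for p in probs if p > 0)
    return h / math.log(len(probs))

def expand(probs, num_active_paths, entropy_threshold=0.4, branch_factor=3, max_paths=20):
    """Token ids one path continues with at this step (hypothetical rule, not EAB's API)."""
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    if normalized_entropy(probs) <= entropy_threshold:
        return ranked[:1]  # confident step: keep a single continuation
    # Uncertain step: fork, but never let the total number of paths exceed max_paths.
    # The +1 accounts for this path being replaced by its children.
    room = max_paths - num_active_paths + 1
    return ranked[:max(1, min(branch_factor, room))]

confident = [0.97, 0.01, 0.01, 0.01]
uncertain = [0.30, 0.25, 0.25, 0.20]
print(expand(confident, num_active_paths=1))
print(expand(uncertain, num_active_paths=1))
print(expand(uncertain, num_active_paths=19))
```

Under this rule the confident distribution keeps one path, the uncertain one forks into three, and near the cap the fork shrinks so that the total stays at `max_paths`.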

## 3. Basic Generation

Let's generate multiple completions for a simple prompt.

In [None]:
prompt = "The capital of France is"

results = eab.generate(
    prompt=prompt,
    max_new_tokens=10,
    temperature=0.8
)

print(f"\nGenerated {len(results)} completions (top 5 by probability):\n")
top = sorted(results, key=lambda r: r['probability'], reverse=True)
for i, result in enumerate(top[:5], 1):
    print(f"{i}. {result['text']} (p={result['probability']:.4f})")
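
Each result carries a `probability` and, as section 7 shows, a `log_prob`. A path's probability is the product of its per-token probabilities, which is why it is usually accumulated in log space. The sketch below assumes `probability == exp(log_prob)` and that `log_prob` sums over the generated tokens only. The token probabilities are invented for illustration. It also shows why the returned probabilities need not sum to 1: they cover only the paths that were kept.

```python
import math

def path_log_prob(token_probs):
    """Sum of per-token log-probabilities; avoids underflow on long paths."""
    return sum(math.log(p) for p in token_probs)

# Invented per-token probabilities for three completions of the same prompt
paths = {
    "Paris.": [0.92, 0.60],
    "Paris, and": [0.92, 0.25, 0.40],
    "a city": [0.03, 0.50],
}
log_probs = {text: path_log_prob(tp) for text, tp in paths.items()}
probs = {text: math.exp(lp) for text, lp in log_probs.items()}

# The returned paths cover only part of the probability mass, so renormalize
total = sum(probs.values())
for text, p in sorted(probs.items(), key=lambda kv: kv[1], reverse=True):
    print(f"{text!r}: p={p:.4f}, share of returned mass={p / total:.1%}")
```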

## 4. Analyzing Entropy

Let's examine how entropy evolved over the decoding steps of the most recent `generate` call.

In [None]:
# Get entropy history
entropy_history = eab.get_entropy_history()

print("Entropy Statistics:")
stats = entropy_history['statistics']
for key, value in stats.items():
    print(f"  {key}: {value}")

# Plot entropy evolution
eab.plot_entropy()
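
The exact fields in `entropy_history['statistics']` depend on the library. Sections 5 and 6 read `mean_entropy`, `branch_rate`, and `num_branches`. The sketch below shows how such summaries could be derived from a per-step entropy trace. It assumes that `branch_rate` is the fraction of decoding steps whose entropy exceeded the threshold. `summarize` is hypothetical and the trace values are invented.

```python
from statistics import mean

def summarize(entropies, threshold):
    """Summary statistics for one entropy trace (one value per decoding step)."""
    branch_steps = [t for t, h in enumerate(entropies) if h > threshold]
    return {
        "mean_entropy": mean(entropies),
        "max_entropy": max(entropies),
        "num_branches": len(branch_steps),
        "branch_rate": len(branch_steps) / len(entropies),
        "branch_steps": branch_steps,
    }

trace = [0.12, 0.55, 0.20, 0.48, 0.05, 0.31]  # invented per-step entropies
stats = summarize(trace, threshold=0.4)
for key, value in stats.items():
    print(f"  {key}: {value}")
```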

## 5. Comparing Different Prompts

Let's compare branching behavior for different types of prompts.

In [None]:
prompts = [
    ("The capital of Japan is", "factual"),
    ("In my opinion, the best movie is", "subjective"),
    ("Once upon a time,", "creative")
]

for prompt, prompt_type in prompts:
    print(f"\n{'=' * 60}")
    print(f"Prompt ({prompt_type}): {prompt}")
    print('=' * 60)

    results = eab.generate(
        prompt=prompt,
        max_new_tokens=15,
        temperature=0.9,
        show_progress=False
    )

    # Statistics refer to the generate() call that just finished
    stats = eab.get_entropy_history()['statistics']

    print("\nResults:")
    print(f"  Paths: {len(results)}")
    print(f"  Branch rate: {stats['branch_rate']:.1%}")
    print(f"  Avg entropy: {stats['mean_entropy']:.3f}")

    print("\nTop 3 completions by probability:")
    top = sorted(results, key=lambda r: r['probability'], reverse=True)
    for i, r in enumerate(top[:3], 1):
        # Truncate long completions and mark the cut with "..."
        text = r['text'] if len(r['text']) <= 60 else r['text'][:60] + "..."
        print(f"  {i}. {text} (p={r['probability']:.4f})")
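
The factual/subjective/creative contrast comes down to the shape of the next-token distribution, which `temperature` also reshapes. The small sketch below uses invented logits. A peaked distribution, as one might expect after "The capital of Japan is", stays low-entropy. A flat one stays high. Raising the temperature pushes both upward. The logit values are illustrative assumptions, not measurements from GPT-2.

```python
import math

def softmax(logits, temperature=1.0):
    """Convert logits to probabilities after dividing by the sampling temperature."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def entropy(probs):
    """Shannon entropy in nats."""
    return -sum(p * math.log(p) for p in probs if p > 0)

peaked = [8.0, 2.0, 1.0, 0.5, 0.0]  # one clear favorite, as for a factual question
flat = [2.0, 1.8, 1.7, 1.5, 1.4]    # several plausible continuations
for temperature in (0.5, 0.9, 1.5):
    h_peaked = entropy(softmax(peaked, temperature))
    h_flat = entropy(softmax(flat, temperature))
    print(f"T={temperature}: peaked={h_peaked:.3f} nats, flat={h_flat:.3f} nats")
```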

## 6. Tuning Hyperparameters

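Let's see how the entropy threshold affects the number of branches and paths. Note that `set_entropy_threshold` changes the setting on the shared `eab` object, so later cells inherit the last value until it is reset.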

In [None]:
prompt = "The best way to learn programming is"
thresholds = [0.3, 0.4, 0.5, 0.6]

for threshold in thresholds:
    eab.set_entropy_threshold(threshold)
    
    results = eab.generate(
        prompt=prompt,
        max_new_tokens=20,
        temperature=0.9,
        show_progress=False
    )
    
    stats = eab.get_entropy_history()['statistics']
    
    print(f"\nThreshold={threshold}:")
    print(f"  Paths: {len(results)}")
    print(f"  Branches: {stats['num_branches']}")
    print(f"  Branch rate: {stats['branch_rate']:.1%}")
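
A lower threshold yields more paths because every step whose entropy crosses the threshold multiplies the active paths by up to `branch_factor`, until `max_paths` caps growth. The simplified simulation below assumes that all active paths share one entropy trace and branch at the same steps. Real paths diverge, so their traces differ. `count_paths` is hypothetical and the trace is invented.

```python
def count_paths(entropies, threshold, branch_factor=3, max_paths=20):
    """Active paths after decoding, assuming all paths share one entropy trace."""
    paths = 1
    for h in entropies:
        if h > threshold:
            # Each path forks into branch_factor children, capped at max_paths
            paths = min(paths * branch_factor, max_paths)
    return paths

trace = [0.35, 0.62, 0.28, 0.45, 0.33, 0.52]  # invented per-step entropies
for threshold in (0.3, 0.4, 0.5, 0.6):
    branches = sum(h > threshold for h in trace)
    print(f"threshold={threshold}: branch steps={branches}, paths={count_paths(trace, threshold)}")
```

The two lowest thresholds both hit the `max_paths` cap, which is why lowering the threshold beyond some point stops increasing the path count.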

## 7. Visualizing Branching Statistics

Let's visualize branching statistics across all paths from a longer generation.

In [None]:
from eab.path import GenerationPath

# Generate with moderate settings (the same values used at initialization)
eab.set_entropy_threshold(0.4)
eab.set_branch_factor(3)

results = eab.generate(
    prompt="The future of artificial intelligence will be",
    max_new_tokens=30,
    temperature=1.0
)

# Convert the result dicts to GenerationPath objects for visualization
paths = [
    GenerationPath(
        tokens=r['tokens'],
        log_prob=r['log_prob'],
        branch_points=r.get('branch_points', []),
        path_id=r.get('path_id')
    )
    for r in results
]

# Visualize
visualize_branching_statistics(paths)
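
If you want the numbers behind the plots, the same kind of summary can be computed directly. The sketch below assumes that each path records its `branch_points` as the decoding-step indices where it forked. `SimplePath` is a hypothetical stand-in for `GenerationPath` with only the fields used here, and the example paths are invented.

```python
import math
from collections import Counter
from dataclasses import dataclass, field

@dataclass
class SimplePath:
    """Hypothetical stand-in for GenerationPath (not the library class)."""
    tokens: list
    log_prob: float
    branch_points: list = field(default_factory=list)

def branching_summary(paths):
    """Paths recording a fork at each step, plus length and probability summaries."""
    forks_per_step = Counter(step for p in paths for step in p.branch_points)
    return {
        "num_paths": len(paths),
        "forks_per_step": dict(sorted(forks_per_step.items())),
        "mean_branch_points": sum(len(p.branch_points) for p in paths) / len(paths),
        "max_length": max(len(p.tokens) for p in paths),
        "total_probability": sum(math.exp(p.log_prob) for p in paths),
    }

paths = [
    SimplePath(tokens=[11, 42, 7], log_prob=-1.2, branch_points=[1]),
    SimplePath(tokens=[11, 42, 9], log_prob=-1.9, branch_points=[1]),
    SimplePath(tokens=[11, 5, 3, 8], log_prob=-2.6, branch_points=[1, 2]),
]
print(branching_summary(paths))
```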

## 8. Uncertainty Quantification

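Let's use the spread of sampled completions as a rough uncertainty signal. Here we measure the entropy of the first word across paths. This is a cheap lexical proxy for semantic uncertainty: it counts every path equally and does not group synonyms.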

In [None]:
from collections import Counter

# Generate samples
results = eab.generate(
    prompt="The most important quality in a leader is",
    max_new_tokens=15,
    temperature=0.9
)

# Extract first word of each completion
first_words = [r['text'].strip().split()[0] for r in results if r['text'].strip()]
word_counts = Counter(first_words)

print(f"\nFirst word distribution:")
for word, count in word_counts.most_common(5):
    print(f"  '{word}': {count} times ({count/len(first_words):.1%})")

# Compute uncertainty
import numpy as np
probs = np.array([count for count in word_counts.values()]) / len(first_words)
entropy = -np.sum(probs * np.log(probs))

print(f"\nSemantic uncertainty (entropy): {entropy:.3f}")
print(f"Interpretation: {'High' if entropy > 1.5 else 'Medium' if entropy > 1.0 else 'Low'} uncertainty")
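
The first-word entropy above counts every path equally, even though EAB paths carry different probabilities. Its magnitude also depends on how many paths were returned. The variant sketch below weights each first word by its path probability and also reports entropy normalized by the log of the number of distinct words. `weighted_first_word_entropy` is hypothetical, and the example results are invented dicts shaped like the `text`/`probability` entries used above.

```python
import math
from collections import defaultdict

def weighted_first_word_entropy(results):
    """First-word entropy in nats, weighting each path by its probability.

    Returns (entropy, entropy / log(number of distinct first words)).
    """
    mass = defaultdict(float)
    for r in results:
        words = r["text"].split()
        if words:
            word = words[0].strip(".,;:!?").lower()
            mass[word] += r["probability"]
    total = sum(mass.values())
    if total == 0:
        return 0.0, 0.0
    probs = [m / total for m in mass.values()]
    h = -sum(p * math.log(p) for p in probs if p > 0)
    h_max = math.log(len(probs)) if len(probs) > 1 else 1.0
    return h, h / h_max

results = [
    {"text": " integrity and honesty", "probability": 0.020},
    {"text": " integrity, because", "probability": 0.015},
    {"text": " the ability to listen", "probability": 0.010},
    {"text": " empathy", "probability": 0.005},
]
h, h_norm = weighted_first_word_entropy(results)
print(f"weighted entropy: {h:.3f} nats (normalized: {h_norm:.2f})")
```

Normalizing makes runs with different path counts easier to compare, since the raw value is bounded by the log of the number of distinct first words.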

## 9. Summary

Key takeaways (trends you should typically observe; exact numbers vary with model, prompt, and sampling settings):

1. **Factual prompts** → Low branching, low entropy
2. **Subjective prompts** → Medium branching, medium entropy
3. **Creative prompts** → High branching, high entropy

4. **Lower threshold** → More aggressive branching
5. **Higher threshold** → More conservative branching

6. **Efficiency**: Paths share the computation for the prompt and every token before they fork
7. **Diversity**: Multiple samples from a single generation call
8. **Uncertainty**: Where and how often the model branches indicates how confident it was
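
The efficiency point can be made concrete by counting decoding steps (one forward pass per token per path). The back-of-the-envelope sketch below assumes that independent sampling regenerates every token for each of N samples, while branched paths share all steps before a single fork. It ignores batching, KV-cache memory, and the fact that real paths fork at different steps.

```python
def independent_steps(num_samples, length):
    """Token steps when each sample is generated from scratch."""
    return num_samples * length

def shared_prefix_steps(num_samples, length, fork_step):
    """Token steps when all samples share the first `fork_step` tokens, then diverge."""
    return fork_step + num_samples * (length - fork_step)

n, length = 10, 30
for fork_step in (0, 10, 20):
    naive = independent_steps(n, length)
    shared = shared_prefix_steps(n, length, fork_step)
    print(f"fork at step {fork_step}: {shared} vs {naive} steps ({1 - shared / naive:.0%} saved)")
```

The later the fork, the larger the saving, which matches the intuition that low-entropy prefixes are cheap to share.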

## Next Steps

- Try larger models (e.g. `gpt2-medium`, `gpt2-large`)
- Experiment with different hyperparameters
- Apply to your own use cases
- Integrate with uncertainty quantification pipelines
- See `examples/` directory for more advanced usage