# DeepConf Branching Experiment

This notebook demonstrates the confidence-based branching approach where traces spawn more branches in high-confidence regions.

In [None]:
# Setup and imports
import sys
import os

# Make the directory two levels above the current working directory importable
# (presumably the repository root, so that `deepconf` resolves without being
# installed -- confirm against the notebook's location in the repo).
_repo_root = os.path.dirname(os.path.dirname(os.path.abspath('.')))
if _repo_root not in sys.path:
    sys.path.append(_repo_root)

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import List, Dict, Any, Optional
import json
import pickle
from datetime import datetime

# Set style for better plots.
# matplotlib 3.6 renamed the bundled seaborn sheets to 'seaborn-v0_8-*'; older
# releases only know the legacy 'seaborn-darkgrid' name. Use whichever one this
# installation provides and fall back to the default style if neither exists,
# so a purely cosmetic setting cannot abort the notebook.
for _style_name in ('seaborn-v0_8-darkgrid', 'seaborn-darkgrid'):
    if _style_name in plt.style.available:
        plt.style.use(_style_name)
        break
sns.set_palette("husl")

## Visualization Functions

In [None]:
def plot_trace_confidence(trace: Dict[str, Any], ax=None, label=None, alpha=1.0):
    """
    Draw one trace's per-token confidence curve and star its branch points.

    Args:
        trace: Trace dictionary. 'confs' supplies the curve; the optional
            'branch_history' entries supply ('step', 'confidence') markers.
        ax: Matplotlib axis to draw on (a new 12x6 figure is created if None)
        label: Legend label for the curve
        alpha: Transparency of the curve

    Returns:
        The axis that was drawn on. It is returned untouched when the trace
        has no (or empty) 'confs'.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(12, 6))

    # Nothing to draw for a trace without confidence values
    if not ('confs' in trace and trace['confs']):
        return ax

    values = trace['confs']
    curve = ax.plot(
        np.arange(len(values)), values,
        alpha=alpha, label=label, linewidth=1.5,
    )[0]

    # Branch markers reuse the curve's colour so they can be matched visually
    star_style = dict(
        color=curve.get_color(), s=100, marker='*',
        edgecolor='black', linewidth=1, zorder=5,
    )
    for event in trace.get('branch_history', ()):
        ax.scatter(event.get('step', 0), event.get('confidence', 0), **star_style)

    ax.set_xlabel('Token Position')
    ax.set_ylabel('Confidence Score')
    ax.set_title('Token-Level Confidence Over Generation')

    return ax


def plot_confidence_with_sliding_window(trace: Dict[str, Any], window_size: int = 128):
    """
    Plot raw token confidence (top panel) and its sliding-window mean (bottom panel).

    Args:
        trace: Trace dictionary. Reads 'confs' (sequence of per-token
            confidence values) and, for the title, the optional 'trace_id'.
        window_size: Number of consecutive tokens averaged per window.

    Returns:
        The matplotlib Figure holding both panels. The panels are left empty
        when the trace has no confidence values, and the bottom panel is left
        empty when the trace is shorter than ``window_size``.

    Raises:
        ValueError: If ``window_size`` is smaller than 1.
    """
    if window_size < 1:
        raise ValueError(f"window_size must be a positive integer, got {window_size}")

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 10), sharex=True)

    # len() instead of truthiness so ndarray / tuple inputs are accepted too
    raw_confs = trace.get('confs')
    if raw_confs is None or len(raw_confs) == 0:
        return fig

    confs = np.asarray(raw_confs, dtype=float)
    positions = np.arange(len(confs))

    # Plot raw confidence
    ax1.plot(positions, confs, alpha=0.6, label='Raw Confidence', color='blue')
    ax1.set_ylabel('Raw Confidence')
    ax1.set_title(f'Token-Level Confidence (Trace: {trace.get("trace_id", "unknown")})')
    ax1.legend()

    # Calculate and plot sliding window average
    if len(confs) >= window_size:
        kernel = np.ones(window_size) / window_size
        sliding_avg = np.convolve(confs, kernel, mode='valid')

        # mode='valid' yields len(confs) - window_size + 1 samples. Derive the
        # x positions from that length (anchored at the window midpoint) so x
        # and y always match; computing the end as len(confs) - window_size//2
        # comes up one short for even window sizes and makes ax.plot raise.
        first_center = window_size // 2
        sliding_positions = np.arange(first_center, first_center + len(sliding_avg))

        ax2.plot(sliding_positions, sliding_avg,
                 label=f'Sliding Avg (window={window_size})',
                 color='green', linewidth=2)

        # Mark high confidence regions
        threshold = np.percentile(sliding_avg, 75)
        ax2.axhline(y=threshold, color='red', linestyle='--', alpha=0.5,
                    label=f'75th percentile: {threshold:.3f}')

        # Highlight potential branching regions
        high_conf_mask = sliding_avg > threshold
        if np.any(high_conf_mask):
            ax2.fill_between(sliding_positions, 0, sliding_avg,
                             where=high_conf_mask, alpha=0.3, color='red',
                             label='High Confidence Regions')

        # Only request a legend when something labelled was actually drawn
        ax2.legend()

    ax2.set_xlabel('Token Position')
    ax2.set_ylabel('Sliding Window Confidence')

    plt.tight_layout()
    return fig


def plot_branching_tree_confidence(result_data: Dict[str, Any]):
    """
    Plot confidence traces organized by their branching relationships.

    Traces are grouped by their 'depth' field (default 0) and each depth level
    gets its own subplot, stacked vertically with a shared x axis.

    Args:
        result_data: Experiment result dictionary. Reads 'all_traces', a list
            of trace dictionaries ('depth', 'trace_id', 'parent_id', 'confs',
            'branch_history').

    Returns:
        The matplotlib Figure, or None (after printing a notice) when
        'all_traces' is missing or empty.
    """
    traces = result_data.get('all_traces')
    # An empty list must be rejected as well: plt.subplots(0, 1) raises.
    if not traces:
        print("No traces found in result data")
        return None

    # Group traces by depth
    depth_groups = {}
    for trace in traces:
        depth_groups.setdefault(trace.get('depth', 0), []).append(trace)

    # Create subplots for each depth level; squeeze=False keeps `axes` 2-D
    # even when there is a single depth level.
    num_depths = len(depth_groups)
    fig, axes = plt.subplots(num_depths, 1, figsize=(14, 5 * num_depths),
                             sharex=True, squeeze=False)

    # Plot traces by depth
    for depth_idx, depth in enumerate(sorted(depth_groups)):
        ax = axes[depth_idx, 0]
        traces_at_depth = depth_groups[depth]

        for trace_idx, trace in enumerate(traces_at_depth):
            trace_id = trace.get('trace_id', f'trace_{trace_idx}')
            parent_id = trace.get('parent_id', 'None')
            plot_trace_confidence(trace, ax=ax,
                                  label=f"{trace_id} (parent: {parent_id})",
                                  alpha=0.7)

        ax.set_title(f'Depth {depth} Traces ({len(traces_at_depth)} traces)')

        # Traces without 'confs' draw nothing; skip the legend in that case
        # instead of triggering matplotlib's "no artists with labels" warning.
        handles, _ = ax.get_legend_handles_labels()
        if handles:
            ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

    axes[-1, 0].set_xlabel('Token Position')
    fig.suptitle('Branching Tree: Confidence by Depth', fontsize=16)
    fig.tight_layout()
    return fig


def _stats_collect_branch_points(all_traces):
    """Return (confidences, positions) of every branch event across all traces."""
    confidences, positions = [], []
    for trace in all_traces:
        for branch in trace.get('branch_history', ()):
            confidences.append(branch.get('confidence', 0))
            positions.append(branch.get('step', 0))
    return confidences, positions


def _stats_draw_branch_confidence_hist(ax, branch_confidences):
    """Panel 1: histogram of the confidence recorded at each branch point."""
    if not branch_confidences:
        return
    mean_conf = np.mean(branch_confidences)
    ax.hist(branch_confidences, bins=20, alpha=0.7, edgecolor='black')
    ax.axvline(mean_conf, color='red', linestyle='--',
               label=f'Mean: {mean_conf:.3f}')
    ax.set_xlabel('Confidence at Branch Point')
    ax.set_ylabel('Count')
    ax.set_title('Distribution of Confidence at Branching Points')
    ax.legend()


def _stats_draw_branch_position_hist(ax, branch_positions):
    """Panel 2: histogram of the token positions at which branches were spawned."""
    if not branch_positions:
        return
    ax.hist(branch_positions, bins=30, alpha=0.7, edgecolor='black')
    ax.set_xlabel('Token Position')
    ax.set_ylabel('Branch Count')
    ax.set_title('Distribution of Branching Positions')


def _stats_draw_accuracy_by_depth(ax, trace_accuracy):
    """Panel 3: bar chart of correct/total per depth, annotated with sample counts."""
    if not trace_accuracy:
        return

    depths = sorted(trace_accuracy.keys())
    accuracies = []
    totals = []
    for depth in depths:
        stats = trace_accuracy[depth]
        if stats['total'] > 0:
            accuracies.append(stats['correct'] / stats['total'])
            totals.append(stats['total'])
        else:
            # Depth with no answered traces: show an empty bar, not a ZeroDivisionError
            accuracies.append(0)
            totals.append(0)

    bars = ax.bar(depths, accuracies, alpha=0.7)

    # Add count labels on bars
    for bar, total in zip(bars, totals):
        ax.text(bar.get_x() + bar.get_width() / 2., bar.get_height(),
                f'n={total}', ha='center', va='bottom')

    ax.set_xlabel('Trace Depth')
    ax.set_ylabel('Accuracy')
    ax.set_title('Accuracy by Trace Depth')
    ax.set_ylim(0, 1.1)


def _stats_draw_confidence_vs_correctness(ax, all_traces, ground_truth):
    """Panel 4: mean trace confidence against (jittered) 0/1 correctness."""
    if not ground_truth:
        return

    expected = str(ground_truth).strip()
    avg_confidences = []
    is_correct_list = []
    for trace in all_traces:
        if 'confs' in trace and trace['confs'] and 'extracted_answer' in trace:
            avg_confidences.append(np.mean(trace['confs']))
            # Simple string comparison for correctness
            is_correct = str(trace['extracted_answer']).strip() == expected
            is_correct_list.append(1 if is_correct else 0)

    if not avg_confidences:
        return

    # Vertical jitter keeps overlapping 0/1 points distinguishable
    jitter = 0.02
    is_correct_jittered = (np.array(is_correct_list)
                           + np.random.normal(0, jitter, len(is_correct_list)))

    ax.scatter(avg_confidences, is_correct_jittered, alpha=0.5)
    ax.set_xlabel('Average Trace Confidence')
    ax.set_ylabel('Correctness (1=Correct, 0=Incorrect)')
    ax.set_title('Confidence vs Correctness')
    ax.set_ylim(-0.1, 1.1)

    # Add a linear trend line once there are more than 10 points
    if len(avg_confidences) > 10:
        z = np.polyfit(avg_confidences, is_correct_list, 1)
        p = np.poly1d(z)
        ordered = sorted(avg_confidences)
        ax.plot(ordered, p(ordered),
                "r--", alpha=0.8, label=f'Trend: {z[0]:.3f}x + {z[1]:.3f}')
        ax.legend()


def plot_branching_statistics(result_data: Dict[str, Any]):
    """
    Plot statistics about the branching experiment on a 2x2 grid.

    Panels: (1) confidence at branch points, (2) token position of branch
    points, (3) accuracy per trace depth, (4) mean trace confidence versus
    correctness. A panel whose input data is missing is left empty.

    Args:
        result_data: Experiment result dictionary. Reads 'all_traces',
            'evaluation' -> 'trace_accuracy' (depth -> {'correct', 'total'})
            and 'ground_truth'.

    Returns:
        The matplotlib Figure containing the four panels.
    """
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    all_traces = result_data.get('all_traces', [])
    branch_confidences, branch_positions = _stats_collect_branch_points(all_traces)

    _stats_draw_branch_confidence_hist(axes[0, 0], branch_confidences)
    _stats_draw_branch_position_hist(axes[0, 1], branch_positions)
    _stats_draw_accuracy_by_depth(
        axes[1, 0], result_data.get('evaluation', {}).get('trace_accuracy', {}))
    _stats_draw_confidence_vs_correctness(
        axes[1, 1], all_traces, result_data.get('ground_truth'))

    plt.tight_layout()
    return fig

## Example: Running a Branching Experiment

In [None]:
# Example driver: one end-to-end branching run for a single question
def run_branching_example(
    question: str,
    ground_truth: str = None,
    model: str = "deepseek-ai/DeepSeek-R1-0528-Qwen3-8B",
    initial_branches: int = 4,
    max_total_branches: int = 16,
    confidence_threshold: float = 1.5,
    window_size: int = 128
):
    """
    Run one branching experiment and package the outcome for the plot helpers.

    Args:
        question: User question, sent as a single-turn chat message.
        ground_truth: Expected answer. When truthy, per-depth accuracy is
            computed by comparing stripped string forms of the answers.
        model: Model identifier handed to BranchingDeepThinkLLM.
        initial_branches: Forwarded to branching_deepthink.
        max_total_branches: Forwarded to branching_deepthink.
        confidence_threshold: Forwarded to branching_deepthink.
        window_size: Forwarded to branching_deepthink.

    Returns:
        result.to_dict() extended with 'ground_truth' and, if a ground truth
        was given, 'evaluation' -> 'trace_accuracy' (depth -> correct/total).
        None when deepconf or vllm cannot be imported.
    """
    # Heavy dependencies are imported lazily so the notebook still loads without them
    try:
        from deepconf.branching_wrapper import BranchingDeepThinkLLM
        from vllm import SamplingParams
    except ImportError:
        print("Error: Please ensure deepconf and vllm are installed")
        print("You may need to run: pip install -e . from the deepconf directory")
        return None

    print(f"Initializing model: {model}")
    branching_llm = BranchingDeepThinkLLM(model=model, enable_prefix_caching=True)

    # Render the question through the model's chat template
    chat = [{"role": "user", "content": question}]
    prompt = branching_llm.tokenizer.apply_chat_template(
        chat, tokenize=False, add_generation_prompt=True
    )

    # Decoding configuration
    sampling_params = SamplingParams(
        temperature=0.6, top_p=0.95, max_tokens=8000, logprobs=20
    )

    print(f"\nRunning branching experiment...")
    print(f"  Initial branches: {initial_branches}")
    print(f"  Max total branches: {max_total_branches}")
    print(f"  Confidence threshold: {confidence_threshold}")

    result = branching_llm.branching_deepthink(
        prompt=prompt,
        initial_branches=initial_branches,
        max_total_branches=max_total_branches,
        confidence_threshold=confidence_threshold,
        window_size=window_size,
        sampling_params=sampling_params
    )

    # Plain-dict view consumed by the visualization functions
    result_data = result.to_dict()
    result_data['ground_truth'] = ground_truth

    if not ground_truth:
        return result_data

    # Tally answered / correct traces per depth level
    expected = str(ground_truth).strip()
    per_depth = {}
    for trace in result.all_traces:
        tally = per_depth.setdefault(trace.get('depth', 0), {'correct': 0, 'total': 0})
        answer = trace.get('extracted_answer')
        if not answer:
            # Traces without an extracted answer are not counted at all
            continue
        tally['total'] += 1
        if str(answer).strip() == expected:
            tally['correct'] += 1

    result_data['evaluation'] = {'trace_accuracy': per_depth}
    return result_data

## Demo: Math Problem with Branching

In [None]:
# Demo inputs: a percentage question and the answer string it is scored against
ground_truth = "36"
question = "What is 15% of 240?"

# The real experiment is disabled by default. It goes through
# run_branching_example, which needs the `deepconf` and `vllm` packages and
# instantiates the model. Uncomment the call below to run it:

# result_data = run_branching_example(
#     question=question,
#     ground_truth=ground_truth,
#     initial_branches=2,
#     max_total_branches=8,
#     confidence_threshold=1.5
# )

In [None]:
# Load example data (if you have a saved result file)
# Replace with your actual result file path
example_result_path = "outputs/branching_example.pkl"


def _make_synthetic_result_data() -> Dict[str, Any]:
    """
    Build a small fake experiment result for exercising the plots.

    Produces six traces (two roots at depth 0, four children at depth 1) with
    random confidence curves; every child carries one branch event located at
    the confidence maximum inside tokens 200-799. Seeds NumPy's global RNG
    with 42 so the data is reproducible. All values are native Python types,
    so the result can be dumped to JSON as well as pickled.
    """
    np.random.seed(42)
    synthetic = {
        'all_traces': [],
        'ground_truth': '36',
        'branching_stats': {
            'total_branches': 4,
            'avg_confidence_at_branch': 2.1
        }
    }

    for i in range(6):
        trace_len = np.random.randint(500, 1500)
        # Gamma noise plus a slow sine wave, clipped to [0.5, 4.0]
        confs = (np.random.gamma(2, 0.5, trace_len)
                 + np.sin(np.linspace(0, 4 * np.pi, trace_len)) * 0.3)
        confs = np.clip(confs, 0.5, 4.0)

        trace = {
            'trace_id': f'trace_{i}',
            'parent_id': 'None' if i < 2 else f'trace_{i//2 - 1}',
            'depth': 0 if i < 2 else 1,
            'confs': confs.tolist(),
            'extracted_answer': '36' if i % 3 == 0 else '40',
            'branch_history': []
        }

        # Add branch points for child traces
        if i >= 2:
            # Cast NumPy scalars to int/float so they match confs.tolist()
            # and stay JSON-serialisable.
            branch_point = int(np.argmax(confs[200:800])) + 200
            trace['branch_history'] = [{
                'step': branch_point,
                'confidence': float(confs[branch_point]),
                'parent_trace': trace['parent_id']
            }]

        synthetic['all_traces'].append(trace)

    return synthetic


try:
    # NOTE(security): pickle.load can execute arbitrary code embedded in the
    # file. Only point example_result_path at result files you created yourself.
    with open(example_result_path, 'rb') as f:
        result_data = pickle.load(f)
except FileNotFoundError:
    print(f"No example file found at {example_result_path}")
    print("Creating synthetic example data for visualization...")
    result_data = _make_synthetic_result_data()
else:
    print("Loaded example result data")

## Visualize Results

In [None]:
# Plot individual trace with sliding window.
# .get() instead of indexing: a result loaded from disk may lack 'all_traces',
# and this cell should then be skipped rather than fail with a KeyError.
if result_data and result_data.get('all_traces'):
    # Show confidence pattern for the first trace
    first_trace = result_data['all_traces'][0]
    fig = plot_confidence_with_sliding_window(first_trace, window_size=128)
    plt.show()

In [None]:
# Plot branching tree structure: one subplot per depth level, with every trace
# at that depth overlaid and labelled by its trace id and parent id.
if result_data:
    fig = plot_branching_tree_confidence(result_data)
    plt.show()

In [None]:
# Plot branching statistics: a 2x2 grid covering confidence at branch points,
# branch positions, accuracy by depth, and confidence versus correctness.
if result_data:
    fig = plot_branching_statistics(result_data)
    plt.show()

## Analysis Functions

In [None]:
def analyze_confidence_patterns(result_data: Dict[str, Any]):
    """
    Print a text report of confidence statistics for a branching experiment.

    The report has three parts: statistics pooled over all tokens of all
    traces, mean/std/token count per trace depth, and a summary of the branch
    points (count, mean confidence, mean/earliest/latest position).

    Args:
        result_data: Experiment result dictionary; reads 'all_traces'.
            Nothing is printed when that key is absent.
    """
    if 'all_traces' not in result_data:
        return

    traces = result_data['all_traces']

    # --- Part 1: pooled token-level statistics ---
    pooled = [
        value
        for trace in traces
        if 'confs' in trace and trace['confs']
        for value in trace['confs']
    ]
    if pooled:
        print("=== Overall Confidence Statistics ===")
        print(f"Mean confidence: {np.mean(pooled):.3f}")
        print(f"Std confidence: {np.std(pooled):.3f}")
        print(f"Min confidence: {np.min(pooled):.3f}")
        print(f"Max confidence: {np.max(pooled):.3f}")
        print(f"Median confidence: {np.median(pooled):.3f}")

    # --- Part 2: the same values bucketed by trace depth ---
    print("\n=== Confidence by Depth ===")
    by_depth = {}
    for trace in traces:
        if not ('confs' in trace and trace['confs']):
            continue
        by_depth.setdefault(trace.get('depth', 0), []).extend(trace['confs'])

    for depth in sorted(by_depth):
        values = by_depth[depth]
        print(f"Depth {depth}: mean={np.mean(values):.3f}, std={np.std(values):.3f}, n={len(values)}")

    # --- Part 3: where, and at what confidence, branches were spawned ---
    print("\n=== Branch Point Analysis ===")
    events = [
        event
        for trace in traces
        if 'branch_history' in trace
        for event in trace['branch_history']
    ]
    branch_confidences = [event.get('confidence', 0) for event in events]
    branch_positions = [event.get('step', 0) for event in events]

    if not branch_confidences:
        return

    print(f"Number of branches: {len(branch_confidences)}")
    print(f"Mean confidence at branch: {np.mean(branch_confidences):.3f}")
    print(f"Mean position of branch: {np.mean(branch_positions):.1f} tokens")
    print(f"Earliest branch: token {min(branch_positions)}")
    print(f"Latest branch: token {max(branch_positions)}")

# Run analysis: prints the pooled, per-depth and branch-point confidence
# statistics for the loaded (or synthetic) result to stdout.
if result_data:
    analyze_confidence_patterns(result_data)

## Confidence Heatmap Explorer

In [None]:
def plot_confidence_heatmap(result_data: Dict[str, Any], max_length: int = 1000):
    """
    Create a heatmap showing confidence patterns across all traces.

    Each trace with confidence values becomes one row. Rows are truncated to
    ``max_length`` tokens and shorter rows are padded with NaN, which imshow
    leaves blank. The figure is displayed with plt.show().

    Args:
        result_data: Experiment result dictionary; reads 'all_traces' and, per
            trace, 'confs', 'trace_id' and 'depth'.
        max_length: Number of token positions (columns) in the heatmap.

    Returns:
        None. Nothing is drawn when 'all_traces' is missing or no trace has
        confidence values.

    Raises:
        ValueError: If ``max_length`` is smaller than 1.
    """
    if max_length < 1:
        raise ValueError(f"max_length must be a positive integer, got {max_length}")

    if 'all_traces' not in result_data:
        return

    # Prepare data matrix
    rows = []
    trace_labels = []
    for trace in result_data['all_traces']:
        raw_confs = trace.get('confs')
        # len() rather than truthiness / list concatenation, so tuples and
        # ndarrays work as well as lists.
        if raw_confs is None or len(raw_confs) == 0:
            continue

        clipped = np.asarray(raw_confs, dtype=float)[:max_length]
        row = np.full(max_length, np.nan)  # NaN marks positions past the trace end
        row[:len(clipped)] = clipped
        rows.append(row)
        trace_labels.append(f"{trace.get('trace_id', 'unknown')} (d={trace.get('depth', 0)})")

    if not rows:
        return

    confidence_matrix = np.vstack(rows)
    fig, ax = plt.subplots(figsize=(16, 8))

    # Create heatmap
    im = ax.imshow(confidence_matrix, aspect='auto', cmap='RdYlBu_r',
                   interpolation='nearest')

    # Set labels
    ax.set_yticks(range(len(trace_labels)))
    ax.set_yticklabels(trace_labels)
    ax.set_xlabel('Token Position')
    ax.set_ylabel('Trace ID')
    ax.set_title('Confidence Heatmap Across All Traces')

    # Add colorbar
    cbar = plt.colorbar(im, ax=ax)
    cbar.set_label('Confidence Score')

    plt.tight_layout()
    plt.show()

# Create confidence heatmap: every trace is cut or NaN-padded to 800 token
# positions. plot_confidence_heatmap calls plt.show() itself, so no explicit
# show is needed in this cell.
if result_data:
    plot_confidence_heatmap(result_data, max_length=800)

## Summary

This notebook demonstrates:

1. **Confidence Visualization**: Multiple ways to visualize token-level confidence
2. **Branching Analysis**: Understanding where and why branches occur
3. **Performance Comparison**: Comparing accuracy between base and branched traces
4. **Pattern Detection**: Finding high-confidence regions suitable for branching

Key insights from branching experiments:
- High confidence regions often correspond to "certain" reasoning steps
- Branching from high-confidence states can explore alternative paths from strong foundations
- The effectiveness depends on the confidence threshold and branching parameters