# Entropy Convergence Analysis

This notebook compares how the batch means of `basic_entropy_sum` and `diag_rb_entropy_sum` converge as the batch size grows.

**Research Question**: Does RB entropy provide better (lower variance) estimates than basic entropy?

In [None]:
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

# Notebook-wide plotting defaults.
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

# Load the data
data_file = "data/entropy_study_1.json"
print(f"Loading data from: {data_file}")

# JSON text is UTF-8; pass the encoding explicitly so the free-form response
# text decodes the same way regardless of the platform's locale default.
with open(data_file, 'r', encoding='utf-8') as f:
    results = json.load(f)

# The file is expected to hold a 'per_sequence_data' collection and an
# 'experiment_info' entry; both are read directly below.
print(f"Loaded {len(results['per_sequence_data'])} sequences")
print(f"Experiment info: {results['experiment_info']}")

## Data Inspection

Let's first understand what we're working with and check for data quality issues.

In [None]:
# Turn the per-sequence records into a DataFrame (one row per record)
per_seq_data = results['per_sequence_data']
df = pd.DataFrame(per_seq_data)

# Summarise the table before looking at individual values
frame_shape = df.shape
column_names = list(df.columns)
print("Dataset shape:", frame_shape)
print("\nColumns:", column_names)
print("\nFirst few rows:")
# Last expression of the cell, so the notebook renders the preview
df.head()

In [None]:
# Sanity-check both entropy columns: missing values, exact zeros, and values below 0.01
entropy_columns = ('basic_entropy_sum', 'diag_rb_entropy_sum')

print("=== DATA QUALITY CHECKS ===")
print(f"Total sequences: {len(df)}")
for column in entropy_columns:
    print(f"Missing {column}: {df[column].isna().sum()}")
for column in entropy_columns:
    print(f"Zero {column}: {(df[column] == 0).sum()}")
for column in entropy_columns:
    print(f"Very small {column} (<0.01): {(df[column] < 0.01).sum()}")

# Descriptive statistics, one section per estimator
for heading, column in zip(("BASIC ENTROPY", "RB ENTROPY"), entropy_columns):
    print(f"\n=== {heading} STATISTICS ===")
    print(df[column].describe())

In [None]:
# Look at both ends of the series and test for a trend along the row position
print("=== DATA ORDERING CHECK ===")
basic_series = df['basic_entropy_sum']
print("First 10 basic entropy values:")
print(basic_series.head(10).tolist())
print("\nLast 10 basic entropy values:")
print(basic_series.tail(10).tolist())

# A non-trivial correlation between row position and entropy would indicate ordering
row_positions = range(len(df))
index_corr_basic = np.corrcoef(row_positions, basic_series)[0, 1]
index_corr_rb = np.corrcoef(row_positions, df['diag_rb_entropy_sum'])[0, 1]
print(f"\nCorrelation between index and basic entropy: {index_corr_basic:.4f}")
print(f"Correlation between index and RB entropy: {index_corr_rb:.4f}")
if any(abs(corr) > 0.1 for corr in (index_corr_basic, index_corr_rb)):
    print("⚠️ WARNING: Data might be ordered, which could affect batch analysis!")

## Batch Convergence Analysis

Now let's carefully implement the batch convergence analysis, being explicit about what we're computing.

In [None]:
def analyze_batch_convergence(basic_vals, rb_vals, batch_sizes, shuffle_data=False):
    """
    Analyze how batch means and their spread change with batch size.

    For each batch size B:
    1. Split the first ``(N // B) * B`` samples into consecutive,
       non-overlapping batches of size B (a trailing partial batch is dropped)
    2. Compute the mean of each batch
    3. Compute the mean and sample std (ddof=1) of those batch means

    The key insight:
    - The mean of equal-sized batch means is, by construction, the mean of the
      samples that were used; it equals the overall mean whenever B divides N
    - Std of batch means should decrease as B increases (law of large numbers)
    - RB entropy should have lower std than basic entropy (variance reduction)

    Args:
        basic_vals: 1-D sequence or array of basic entropy values.
        rb_vals: 1-D sequence or array of RB entropy values, paired
            element-wise with ``basic_vals`` (same length).
        batch_sizes: Iterable of integer batch sizes to evaluate.
        shuffle_data: If True, apply one shared random permutation to both
            inputs before batching. The permutation comes from NumPy's global
            RNG, so call ``np.random.seed`` first for reproducibility.

    Returns:
        A list with one dict per batch size that yields at least two complete
        batches, with keys ``batch_size``, ``n_batches``, ``basic_mean``,
        ``basic_std``, ``rb_mean``, ``rb_std``, ``var_reduction`` (percent;
        NaN when the basic batch means have no positive variance),
        ``basic_batch_means`` and ``rb_batch_means``.

    Raises:
        ValueError: If ``basic_vals`` and ``rb_vals`` differ in length.
    """

    # Coerce to arrays so plain lists work too (the permutation indexing below
    # needs NumPy fancy indexing).
    basic_vals = np.asarray(basic_vals)
    rb_vals = np.asarray(rb_vals)
    if len(basic_vals) != len(rb_vals):
        raise ValueError(
            f"basic_vals and rb_vals must have the same length "
            f"(got {len(basic_vals)} and {len(rb_vals)})"
        )

    if shuffle_data:
        # Shuffle to remove any ordering effects; the same permutation is
        # applied to both arrays so the samples stay paired.
        indices = np.random.permutation(len(basic_vals))
        basic_vals = basic_vals[indices]
        rb_vals = rb_vals[indices]
        print("🔀 Data shuffled to remove ordering effects")

    N = len(basic_vals)
    print(f"Total samples: {N}")

    # Overall statistics for reference
    overall_basic_mean = np.mean(basic_vals)
    overall_rb_mean = np.mean(rb_vals)
    print(f"Overall basic entropy mean: {overall_basic_mean:.4f}")
    print(f"Overall RB entropy mean: {overall_rb_mean:.4f}")
    print()

    summaries = []

    for B in batch_sizes:
        n_batches = N // B
        if n_batches < 2:
            # A sample std across batches needs at least two batch means.
            print(f"Batch size B={B}: Skipping (would give <2 batches)")
            continue

        print(f"Batch size B={B}: {n_batches} complete batches")

        # Compute batch means over consecutive, non-overlapping windows
        windows = [slice(i * B, (i + 1) * B) for i in range(n_batches)]
        basic_batch_means = [np.mean(basic_vals[w]) for w in windows]
        rb_batch_means = [np.mean(rb_vals[w]) for w in windows]

        # Statistics across batches
        basic_mean = np.mean(basic_batch_means)
        basic_std = np.std(basic_batch_means, ddof=1)

        rb_mean = np.mean(rb_batch_means)
        rb_std = np.std(rb_batch_means, ddof=1)

        # Variance reduction relative to the basic estimator, in percent.
        # It is undefined when the basic batch means do not vary, so report
        # NaN instead of dividing by zero.
        basic_var = basic_std**2
        if basic_var > 0:
            var_reduction = (basic_var - rb_std**2) / basic_var * 100
        else:
            var_reduction = float('nan')

        print(f"  Basic entropy - Mean: {basic_mean:.6f}, Std: {basic_std:.6f}")
        print(f"  RB entropy    - Mean: {rb_mean:.6f}, Std: {rb_std:.6f}")
        print(f"  Variance reduction: {var_reduction:.1f}%")

        # When B divides N every sample is used, so the mean of the
        # (equal-sized) batch means equals the overall mean by construction.
        # Only flag the match when samples were dropped.
        n_dropped = N - n_batches * B
        if abs(basic_mean - overall_basic_mean) < 1e-10:
            if n_dropped == 0:
                print("  ℹ️ Mean of batch means equals the overall mean (expected: B divides N, so every sample is used)")
            else:
                print(f"  ⚠️ WARNING: Batch mean is EXACTLY equal to overall mean (to 10 decimal places)")

        print(f"  Standard error of mean: Basic={basic_std/np.sqrt(n_batches):.6f}, RB={rb_std/np.sqrt(n_batches):.6f}")
        print()

        summaries.append({
            'batch_size': B,
            'n_batches': n_batches,
            'basic_mean': basic_mean,
            'basic_std': basic_std,
            'rb_mean': rb_mean,
            'rb_std': rb_std,
            'var_reduction': var_reduction,
            'basic_batch_means': basic_batch_means,
            'rb_batch_means': rb_batch_means,
        })

    return summaries

# Inputs for the convergence study: the batch sizes to sweep and the two
# entropy columns as plain arrays
batch_sizes = [32, 64, 128, 256, 512]
basic_entropies, rb_entropies = (
    df[column].values for column in ('basic_entropy_sum', 'diag_rb_entropy_sum')
)

# First pass keeps the rows in the order they appear in the file
print("=== ANALYSIS WITH ORIGINAL DATA ORDER ===")
results_original = analyze_batch_convergence(
    basic_entropies,
    rb_entropies,
    batch_sizes,
    shuffle_data=False,
)

In [None]:
# Seed NumPy's global RNG so the permutation drawn inside the analysis is reproducible
np.random.seed(42)
print("\n=== ANALYSIS WITH SHUFFLED DATA ===")
results_shuffled = analyze_batch_convergence(
    basic_entropies,
    rb_entropies,
    batch_sizes,
    shuffle_data=True,
)

## Visualization

In [None]:
# Create comprehensive plots: a 2x3 grid with convergence curves on the top row
# and the raw value distributions / scatter on the bottom row.
fig = plt.figure(figsize=(20, 15))

# Use the shuffled results for plotting (more reliable).
# Bound to a dedicated name so the `results` dict loaded from the JSON file is
# not overwritten; otherwise re-running the data-inspection cell would fail.
plot_results = results_shuffled
batch_sizes_plot = [r['batch_size'] for r in plot_results]

# Plot 1: Batch means convergence
ax1 = plt.subplot(2, 3, 1)
basic_means = [r['basic_mean'] for r in plot_results]
rb_means = [r['rb_mean'] for r in plot_results]
basic_line, = plt.plot(batch_sizes_plot, basic_means, 'o-', label='Basic Entropy', linewidth=2, markersize=8)
rb_line, = plt.plot(batch_sizes_plot, rb_means, 's-', label='RB Entropy', linewidth=2, markersize=8)
# Draw each reference line in the colour of its series so the legend is unambiguous
plt.axhline(np.mean(basic_entropies), color=basic_line.get_color(), linestyle='--', alpha=0.5, label='True Basic Mean')
plt.axhline(np.mean(rb_entropies), color=rb_line.get_color(), linestyle='--', alpha=0.5, label='True RB Mean')
plt.xlabel('Batch Size')
plt.ylabel('Mean of Batch Means')
plt.title('Convergence of Batch Means')
plt.legend()
plt.grid(True, alpha=0.3)
plt.xscale('log', base=2)

# Plot 2: Standard deviation (convergence rate), log-log axes
ax2 = plt.subplot(2, 3, 2)
basic_stds = [r['basic_std'] for r in plot_results]
rb_stds = [r['rb_std'] for r in plot_results]
plt.plot(batch_sizes_plot, basic_stds, 'o-', label='Basic Entropy', linewidth=2, markersize=8)
plt.plot(batch_sizes_plot, rb_stds, 's-', label='RB Entropy', linewidth=2, markersize=8)
plt.xlabel('Batch Size')
plt.ylabel('Std of Batch Means')
plt.title('Convergence Rate (Lower = Better)')
plt.legend()
plt.grid(True, alpha=0.3)
plt.xscale('log', base=2)
plt.yscale('log')

# Plot 3: Variance reduction (percent); the dashed zero line marks "no reduction"
ax3 = plt.subplot(2, 3, 3)
var_reductions = [r['var_reduction'] for r in plot_results]
plt.plot(batch_sizes_plot, var_reductions, 'o-', color='green', linewidth=2, markersize=8)
plt.axhline(0, color='black', linestyle='--', alpha=0.5)
plt.xlabel('Batch Size')
plt.ylabel('Variance Reduction (%)')
plt.title('RB Entropy Variance Reduction')
plt.grid(True, alpha=0.3)
plt.xscale('log', base=2)

# Plot 4: Distribution of basic entropies (mean dashed, mean +/- 1 std dotted)
ax4 = plt.subplot(2, 3, 4)
plt.hist(basic_entropies, bins=50, alpha=0.7, density=True, color='blue', label='Basic Entropy')
mean_basic = np.mean(basic_entropies)
std_basic = np.std(basic_entropies)
plt.axvline(mean_basic, color='red', linestyle='--', label=f'Mean: {mean_basic:.3f}')
plt.axvline(mean_basic + std_basic, color='red', linestyle=':', alpha=0.7)
plt.axvline(mean_basic - std_basic, color='red', linestyle=':', alpha=0.7)
plt.xlabel('Basic Entropy Sum')
plt.ylabel('Density')
plt.title('Distribution of Basic Entropy')
plt.legend()
plt.grid(True, alpha=0.3)

# Plot 5: Distribution of RB entropies (mean dashed, mean +/- 1 std dotted)
ax5 = plt.subplot(2, 3, 5)
plt.hist(rb_entropies, bins=50, alpha=0.7, density=True, color='green', label='RB Entropy')
mean_rb = np.mean(rb_entropies)
std_rb = np.std(rb_entropies)
plt.axvline(mean_rb, color='red', linestyle='--', label=f'Mean: {mean_rb:.3f}')
plt.axvline(mean_rb + std_rb, color='red', linestyle=':', alpha=0.7)
plt.axvline(mean_rb - std_rb, color='red', linestyle=':', alpha=0.7)
plt.xlabel('RB Entropy Sum')
plt.ylabel('Density')
plt.title('Distribution of RB Entropy')
plt.legend()
plt.grid(True, alpha=0.3)

# Plot 6: Scatter plot of the paired per-sequence values
ax6 = plt.subplot(2, 3, 6)
plt.scatter(basic_entropies, rb_entropies, alpha=0.5, s=20)
plt.xlabel('Basic Entropy Sum')
plt.ylabel('RB Entropy Sum')
plt.title('Basic vs RB Entropy Correlation')
corr = np.corrcoef(basic_entropies, rb_entropies)[0,1]
# Axes-relative coordinates keep the annotation in the top-left corner whatever the data range
plt.text(0.05, 0.95, f'Correlation: {corr:.3f}', transform=ax6.transAxes, 
         bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## Detailed Investigation of Near-Zero Values

In [None]:
# Look at sequences whose entropy sum is below 0.1
print("=== INVESTIGATING NEAR-ZERO VALUES ===")

# Boolean masks, one entry per sequence
low_basic_mask = basic_entropies < 0.1
low_rb_mask = rb_entropies < 0.1
n_low_basic = low_basic_mask.sum()

print(f"Sequences with basic entropy < 0.1: {n_low_basic}")
print(f"Sequences with RB entropy < 0.1: {low_rb_mask.sum()}")

if n_low_basic > 0:
    print("\nExamples of low basic entropy sequences:")
    # Show at most the first five offenders
    low_basic_indices = np.flatnonzero(low_basic_mask)[:5]
    for idx in low_basic_indices:
        seq = per_seq_data[idx]
        snippet = seq['response_text'][:100]
        print(f"  Index {idx}: basic={seq['basic_entropy_sum']:.6f}, rb={seq['diag_rb_entropy_sum']:.6f}")
        print(f"    Response length: {seq['response_length_tokens']} tokens")
        print(f"    Response text: '{snippet}...'")
        print()

# Check whether each entropy estimate moves with response length
response_lengths = df['response_length_tokens'].values
for estimator_label, estimator_values in (("basic", basic_entropies), ("RB", rb_entropies)):
    length_corr = np.corrcoef(estimator_values, response_lengths)[0,1]
    print(f"Correlation between {estimator_label} entropy and response length: {length_corr:.3f}")

## Summary and Conclusions

In [None]:
# Headline numbers over all sequences, independent of batching
print("=== FINAL SUMMARY ===")
print(f"Total sequences analyzed: {len(df)}")
print(f"Basic entropy: mean={np.mean(basic_entropies):.4f}, std={np.std(basic_entropies):.4f}")
print(f"RB entropy: mean={np.mean(rb_entropies):.4f}, std={np.std(rb_entropies):.4f}")
print(f"Correlation: {np.corrcoef(basic_entropies, rb_entropies)[0,1]:.4f}")
print()

if not results_shuffled:
    # Every batch size was skipped (fewer than 2 complete batches each), so
    # there is nothing to average and max()/min() below would raise.
    print("No batch size produced at least 2 complete batches; skipping the variance-reduction summary.")
else:
    print("Variance reduction by batch size:")
    for r in results_shuffled:
        print(f"  B={r['batch_size']:>3}: {r['var_reduction']:>6.1f}% (Basic std: {r['basic_std']:.4f}, RB std: {r['rb_std']:.4f})")

    print()
    avg_var_reduction = np.mean([r['var_reduction'] for r in results_shuffled])
    print(f"Average variance reduction: {avg_var_reduction:.1f}%")

    if np.isnan(avg_var_reduction):
        # NaN compares False against every threshold and would otherwise be
        # reported as "minimal" reduction.
        print("⚠️ Average variance reduction is undefined (NaN for at least one batch size).")
    elif avg_var_reduction > 30:
        print("✅ RB entropy shows significant variance reduction!")
    elif avg_var_reduction > 10:
        print("⚠️ RB entropy shows moderate variance reduction.")
    else:
        print("❌ RB entropy shows minimal variance reduction.")

    # Check for the "exactly identical means" issue
    basic_means = [r['basic_mean'] for r in results_shuffled]
    rb_means = [r['rb_mean'] for r in results_shuffled]

    basic_mean_range = max(basic_means) - min(basic_means)
    rb_mean_range = max(rb_means) - min(rb_means)

    print(f"\nRange of batch means across different batch sizes:")
    print(f"  Basic entropy: {basic_mean_range:.8f}")
    print(f"  RB entropy: {rb_mean_range:.8f}")

    # With equal-sized batches that use every sample, the mean of batch means
    # is the overall mean for every batch size, so identical means are only
    # suspicious when some batch size leaves samples out.
    n_samples = len(basic_entropies)
    all_sizes_divide = all(n_samples % r['batch_size'] == 0 for r in results_shuffled)

    if basic_mean_range < 1e-6 or rb_mean_range < 1e-6:
        if all_sizes_divide:
            print("ℹ️ Means are identical across batch sizes, as expected: every batch size divides the sample count, so each mean of batch means equals the overall mean.")
        else:
            print("⚠️ Means are suspiciously identical across batch sizes - this suggests a potential issue with the analysis.")
    else:
        print("✅ Batch means show reasonable variation across batch sizes.")