In [None]:
import pandas as pd
import numpy as np
import sys
sys.path.append("../")
from shared_utils.data import CSVPromptDataset

# Deduplicate the test CSV so that each story appears exactly once, keeping one
# randomly chosen row (i.e. one random question) per story, and write the result
# next to the original file with a "_deduplicated" suffix.
dataset_path = "../results_and_data/early_exit_sft_dataset/test/data.csv"
prompt_config_path = "../results_and_data/early_exit_sft_dataset/test/prompt_config.json"

# Read the CSV to analyze duplicates
df = pd.read_csv(dataset_path)
print(f"Original dataset size: {len(df)} rows")

# The 'question' column is treated as optional everywhere in this cell, so the
# example display below must not assume it exists.
has_question_column = 'question' in df.columns

# Check for duplicates in the story column
print(f"Number of unique stories: {df['story'].nunique()}")
print(f"Number of duplicate story entries: {len(df) - df['story'].nunique()}")

# Show example of duplicates (keep=False marks every member of a duplicate group)
duplicate_stories = df[df.duplicated(subset=['story'], keep=False)]
if len(duplicate_stories) > 0:
    print("\nExample of duplicate stories:")
    # Show first duplicate story group
    first_dup_story = duplicate_stories.iloc[0]['story']
    example_columns = ['story', 'question'] if has_question_column else ['story']
    examples = df[df['story'] == first_dup_story][example_columns].head()
    print(examples)

# Remove duplicates by keeping one random row per story.
# NOTE: the shuffle below is made reproducible by its own random_state=42; the
# global NumPy seed is set as before but does not influence df.sample().
np.random.seed(2)

# Randomly shuffle the dataframe
df_shuffled = df.sample(frac=1, random_state=42).reset_index(drop=True)

# Keep only the first occurrence of each story (which is now random due to shuffle)
df_deduplicated = df_shuffled.drop_duplicates(subset=['story'], keep='first')

print(f"\nAfter deduplication: {len(df_deduplicated)} rows")
print(f"Removed {len(df) - len(df_deduplicated)} duplicate entries")

# Save the deduplicated dataset
output_path = dataset_path.replace('.csv', '_deduplicated.csv')
df_deduplicated.to_csv(output_path, index=False)
print(f"\nSaved deduplicated dataset to: {output_path}")

# Verify the deduplication: no story may be lost, only repeated rows
print("\nVerification:")
print(f"Original unique stories: {df['story'].nunique()}")
print(f"Deduplicated unique stories: {df_deduplicated['story'].nunique()}")
print(f"These should be equal: {df['story'].nunique() == df_deduplicated['story'].nunique()}")

# Show distribution of questions retained
if has_question_column:
    print("\nQuestion distribution in deduplicated data:")
    # Counts the first whitespace-separated word of each retained question;
    # this is only a rough proxy for question type.
    question_starts = df_deduplicated['question'].str.split().str[0].value_counts().head(10)
    print(question_starts)

# Update your dataset path for the generation process
print("\n⚠️  UPDATE YOUR CODE:")
print(f"Change: dataset_path = '{dataset_path}'")
print(f"To:     dataset_path = '{output_path}'")

In [None]:
import pickle
import gzip
import os
import torch
import gc

def analyze_merged_teacher_data(file_path: str, max_samples: int = 2):
    """Print a structural summary of the first records of a merged teacher-data file.

    The file is read as a gzip stream of consecutive pickles: one header object
    followed by one object per sample. Reading stops after ``max_samples``
    samples, at a dict whose ``'_end'`` entry is truthy, at end of file, or at
    the first error raised while reading/describing a sample.

    Args:
        file_path: Path to the gzip-compressed pickle stream.
        max_samples: Maximum number of sample records to read and describe.

    Returns:
        dict with:
            ``'samples_read'``: number of sample records that were read.
            ``'total_size_mb'``: summed in-memory size (MiB) of all tensors found
                at the top level of dict samples.
            ``'tensor_info'``: per tensor key, the shape/dtype/size of its first
                occurrence and whether any dimension exceeds 10000.

    Raises:
        OSError: if ``file_path`` cannot be stat'ed or opened.
    """
    file_size = os.path.getsize(file_path)
    print(f"File path: {file_path}")
    print(f"File size: {file_size / 1e9:.2f} GB ({file_size / 1e6:.2f} MB)")

    sample_count = 0
    total_size_estimate = 0
    tensor_info = {}

    with gzip.open(file_path, "rb") as f:
        # NOTE(security): pickle.load can execute arbitrary code while
        # deserialising; only point this at files you produced or trust.
        header = pickle.load(f)
        print("Header/Metadata:")
        # Tolerate headers that are not dicts or carry non-dict metadata.
        metadata = header.get('metadata') if isinstance(header, dict) else None
        if isinstance(metadata, dict):
            for key, value in metadata.items():
                if isinstance(value, str) and len(value) > 100:
                    print(f"  {key}: {value[:100]}...")
                elif isinstance(value, dict):
                    print(f"  {key}: dict with {len(value)} keys")
                else:
                    print(f"  {key}: {value}")

        while sample_count < max_samples:
            try:
                sample = pickle.load(f)

                # Check if it's the end marker
                if isinstance(sample, dict) and sample.get('_end'):
                    print(f"\nReached end marker. Total samples: {sample.get('num_samples', 'unknown')}")
                    break

                sample_count += 1
                print(f"\nSample {sample_count}:")
                print(f"  Type: {type(sample)}")

                if isinstance(sample, dict):
                    print(f"  Keys: {list(sample.keys())}")

                    sample_size_mb = 0
                    for key, value in sample.items():
                        if isinstance(value, torch.Tensor):
                            size_mb = (value.numel() * value.element_size()) / (1024 * 1024)
                            sample_size_mb += size_mb
                            # Heuristic: a dimension above 10000 is taken to be a vocabulary axis.
                            has_vocab_dim = any(dim > 10000 for dim in value.shape)

                            # Remember only the first occurrence of each tensor key.
                            if key not in tensor_info:
                                tensor_info[key] = {
                                    'shape': value.shape,
                                    'dtype': value.dtype,
                                    'size_mb': size_mb,
                                    'has_vocab_dim': has_vocab_dim,
                                }

                            print(f"    {key}:")
                            print(f"      Shape: {value.shape}")
                            print(f"      Dtype: {value.dtype}")
                            print(f"      Size: {size_mb:.2f} MB")

                            if has_vocab_dim:
                                print(f"      *** Has vocabulary dimension: {max(value.shape)}")

                        elif isinstance(value, str):
                            print(f"    {key}: string ({len(value)} chars)")
                            if len(value) < 200:
                                print(f"      Content: {value}")

                        elif isinstance(value, (int, float)):
                            print(f"    {key}: {type(value).__name__} = {value}")

                        elif isinstance(value, list):
                            print(f"    {key}: list with {len(value)} items")
                            if value and isinstance(value[0], torch.Tensor):
                                print(f"      First tensor shape: {value[0].shape}")

                        else:
                            print(f"    {key}: {type(value)}")

                    print(f"  Total sample size: ~{sample_size_mb:.2f} MB")
                    total_size_estimate += sample_size_mb

                # Release the (potentially very large) sample before reading the next one.
                del sample
                gc.collect()

            except EOFError:
                print("\nReached end of file")
                break
            except Exception as e:
                # Best-effort inspection: report the problem and stop reading.
                print(f"\nError reading sample: {e}")
                break

    # Previously these accumulators were computed and then dropped.
    return {
        'samples_read': sample_count,
        'total_size_mb': total_size_estimate,
        'tensor_info': tensor_info,
    }

def main():
    """Describe the first two records of the merged teacher-data file."""
    analyze_merged_teacher_data(
        '/workspace/data/teacher_generated_data_gzip/merged_teacher_data.pkl.gz',
        max_samples=2,
    )


# Run the inspection when this cell is executed as the main module.
if __name__ == "__main__":
    main()

In [None]:
import pickle
import gzip
import os
import torch
import gc
import numpy as np
import matplotlib.pyplot as plt

def analyze_merged_teacher_data(file_path: str, max_samples: int = 10):
    """Summarise the first records of a merged teacher-data file and their log-prob distribution.

    The file is read as a gzip stream of consecutive pickles: one header object
    followed by one object per sample. For each dict sample, every top-level
    tensor is described; for the tensor stored under
    ``'sft_teacher_final_layer_logprobs'`` every 10th index along dim 1 is
    collected (randomly subsampled to at most 1,000,000 values per sample). The
    pooled values are then used to print percentiles, sparsity at fixed
    thresholds, thresholds for target keep-ratios, and to save a histogram to
    ``/workspace/logprob_distribution.png``.

    Args:
        file_path: Path to the gzip-compressed pickle stream.
        max_samples: Maximum number of sample records to read.

    Returns:
        dict with ``'samples_read'``, ``'total_size_mb'`` (summed MiB of all
        top-level tensors in dict samples), ``'tensor_info'`` (first-occurrence
        shape/dtype/size per tensor key) and ``'logprob_values_analyzed'``
        (number of pooled log-prob values, 0 if none were found).

    Raises:
        OSError: if ``file_path`` cannot be stat'ed or opened.
    """
    file_size = os.path.getsize(file_path)
    print(f"File path: {file_path}")
    print(f"File size: {file_size / 1e9:.2f} GB ({file_size / 1e6:.2f} MB)")

    sample_count = 0
    total_size_estimate = 0
    tensor_info = {}

    # Collect logprob statistics
    all_logprob_samples = []
    logprob_values_analyzed = 0

    with gzip.open(file_path, "rb") as f:
        # NOTE(security): pickle.load can execute arbitrary code while
        # deserialising; only point this at files you produced or trust.
        header = pickle.load(f)
        print("Header/Metadata:")
        # Tolerate headers that are not dicts or carry non-dict metadata.
        metadata = header.get('metadata') if isinstance(header, dict) else None
        if isinstance(metadata, dict):
            for key, value in metadata.items():
                if isinstance(value, str) and len(value) > 100:
                    print(f"  {key}: {value[:100]}...")
                elif isinstance(value, dict):
                    print(f"  {key}: dict with {len(value)} keys")
                else:
                    print(f"  {key}: {value}")

        print("\n" + "="*80)
        print(f"Analyzing first {max_samples} samples for logprob distribution...")
        print("="*80)

        while sample_count < max_samples:
            try:
                sample = pickle.load(f)

                # Check if it's the end marker
                if isinstance(sample, dict) and sample.get('_end'):
                    print(f"\nReached end marker. Total samples: {sample.get('num_samples', 'unknown')}")
                    break

                sample_count += 1
                print(f"\nSample {sample_count}:")
                print(f"  Type: {type(sample)}")

                if isinstance(sample, dict):
                    print(f"  Keys: {list(sample.keys())}")

                    sample_size_mb = 0
                    for key, value in sample.items():
                        if isinstance(value, torch.Tensor):
                            size_mb = (value.numel() * value.element_size()) / (1024 * 1024)
                            sample_size_mb += size_mb
                            # Heuristic: a dimension above 10000 is taken to be a vocabulary axis.
                            has_vocab_dim = any(dim > 10000 for dim in value.shape)

                            # Remember only the first occurrence of each tensor key.
                            if key not in tensor_info:
                                tensor_info[key] = {
                                    'shape': value.shape,
                                    'dtype': value.dtype,
                                    'size_mb': size_mb,
                                    'has_vocab_dim': has_vocab_dim,
                                }

                            print(f"    {key}:")
                            print(f"      Shape: {value.shape}")
                            print(f"      Dtype: {value.dtype}")
                            print(f"      Size: {size_mb:.2f} MB")

                            # Analyze logprob distribution
                            if key == 'sft_teacher_final_layer_logprobs':
                                print(f"      *** Analyzing logprob distribution...")

                                # Take every 10th index along dim 1 to limit memory use
                                # (the 3-index slice requires a 3-D tensor).
                                sampled_logprobs = value[:, ::10, :].flatten()

                                # Random sample if still too large
                                if sampled_logprobs.numel() > 1_000_000:
                                    indices = torch.randperm(sampled_logprobs.numel())[:1_000_000]
                                    sampled_logprobs = sampled_logprobs[indices]

                                all_logprob_samples.append(sampled_logprobs.cpu())

                                # Quick stats over the full (unsampled) tensor
                                print(f"      Min: {value.min().item():.2f}")
                                print(f"      Max: {value.max().item():.2f}")
                                print(f"      Mean: {value.mean().item():.2f}")
                                print(f"      Median: {value.median().item():.2f}")

                            if has_vocab_dim:
                                print(f"      *** Has vocabulary dimension: {max(value.shape)}")

                        elif isinstance(value, str):
                            print(f"    {key}: string ({len(value)} chars)")

                        elif isinstance(value, (int, float)):
                            print(f"    {key}: {type(value).__name__} = {value}")

                        elif isinstance(value, list):
                            print(f"    {key}: list with {len(value)} items")

                        else:
                            print(f"    {key}: {type(value)}")

                    print(f"  Total sample size: ~{sample_size_mb:.2f} MB")
                    total_size_estimate += sample_size_mb

                # Release the (potentially very large) sample before reading the next one.
                del sample
                gc.collect()

            except EOFError:
                print("\nReached end of file")
                break
            except Exception as e:
                # Best-effort inspection: report the problem and stop reading.
                print(f"\nError reading sample: {e}")
                break

    # Analyze combined logprob distribution
    if all_logprob_samples:
        print("\n" + "="*80)
        print("LOGPROB DISTRIBUTION ANALYSIS")
        print("="*80)

        # Combine all samples and convert to float32 for analysis
        all_logprobs = torch.cat(all_logprob_samples).float()
        logprob_values_analyzed = all_logprobs.numel()
        print(f"\nTotal logprob values analyzed: {all_logprobs.numel():,}")

        # Calculate percentiles
        percentiles = [1, 5, 10, 25, 50, 75, 90, 95, 99, 99.9, 99.99]
        print("\nPercentiles:")
        for p in percentiles:
            val = torch.quantile(all_logprobs, p/100).item()
            prob = np.exp(val)
            print(f"  {p:6.2f}%: {val:8.2f} (prob: {prob:.2e})")

        # Threshold analysis: fraction of values that a "keep > thresh" filter would drop
        print("\nSparsity at different thresholds:")
        thresholds = [0, -1, -2, -5, -8, -10, -12, -15, -20, -25, -30]
        for thresh in thresholds:
            sparsity = (all_logprobs <= thresh).float().mean().item()
            keep_ratio = 1 - sparsity
            compression = 1 / keep_ratio if keep_ratio > 0 else float('inf')
            print(f"  Threshold {thresh:3d}: {sparsity:6.1%} sparse ({keep_ratio:6.1%} kept, {compression:6.1f}x compression)")

        # Find threshold for specific compression ratios
        print("\nThresholds for target compression ratios:")
        target_keep_ratios = [0.1, 0.05, 0.01, 0.005, 0.001, 0.0005, 0.0001]
        for keep_ratio in target_keep_ratios:
            threshold = torch.quantile(all_logprobs, 1 - keep_ratio).item()
            compression = 1 / keep_ratio
            print(f"  {compression:6.0f}x compression (keep {keep_ratio*100:5.2f}%): threshold = {threshold:6.2f}")

        # Recommendation
        # NOTE(review): the text below is a fixed string; it is not derived from
        # the statistics computed above, so check it against the printed tables.
        print("\n" + "="*80)
        print("RECOMMENDATION")
        print("="*80)
        print("""
Based on the distribution:
- For ~1000x compression: use threshold around -11 to -12
- For ~500x compression: use threshold around -10
- For ~100x compression: use threshold around -7

Remember: Lower threshold = more aggressive filtering = better compression but more information loss
""")

        # Optional: Create histogram. The figure is left open on purpose so an
        # inline notebook backend can still render it after the cell finishes.
        try:
            plt.figure(figsize=(10, 6))
            plt.hist(all_logprobs.numpy(), bins=100, alpha=0.7, edgecolor='black')
            plt.axvline(-10, color='red', linestyle='--', label='threshold=-10')
            plt.axvline(-15, color='orange', linestyle='--', label='threshold=-15')
            plt.xlabel('Log Probability')
            plt.ylabel('Count')
            plt.title('Distribution of Teacher Log Probabilities')
            plt.legend()
            plt.savefig('/workspace/logprob_distribution.png')
            print("\nHistogram saved to /workspace/logprob_distribution.png")
        except Exception as exc:
            # Plotting stays best-effort, but report the real cause (e.g. an
            # unwritable path) instead of guessing, and let KeyboardInterrupt /
            # SystemExit propagate.
            print(f"\nCouldn't create histogram: {type(exc).__name__}: {exc}")

    # Previously these accumulators were computed and then dropped.
    return {
        'samples_read': sample_count,
        'total_size_mb': total_size_estimate,
        'tensor_info': tensor_info,
        'logprob_values_analyzed': logprob_values_analyzed,
    }

def main():
    """Analyse the log-prob distribution over the first ten records of the merged teacher data."""
    analyze_merged_teacher_data(
        '/workspace/data/teacher_generated_data_gzip/merged_teacher_data.pkl.gz',
        max_samples=10,
    )


# Run the analysis when this cell is executed as the main module.
if __name__ == "__main__":
    main()

In [2]:
import pickle
import gzip
import torch
import torch.nn.functional as F
import numpy as np

def test_kl_divergence_sparse(file_path: str, num_samples: int = 30, threshold: float = -12.0):
    """Test KL divergence between original and sparse representations"""
    
    print(f"Testing KL divergence with threshold={threshold}")
    print("=" * 80)

    torch.manual_seed(0)
    
    sample_count = 0
    kl_results = []
    
    with gzip.open(file_path, "rb") as f:
        # Skip header
        header = pickle.load(f)
        
        while sample_count < num_samples:
            try:
                sample = pickle.load(f)
                
                # Check if it's the end marker
                if isinstance(sample, dict) and sample.get('_end'):
                    break
                
                if 'sft_teacher_final_layer_logprobs' in sample:
                    sample_count += 1
                    
                    # Get original logprobs
                    original_logprobs = sample['sft_teacher_final_layer_logprobs']
                    batch_size, seq_len, vocab_size = original_logprobs.shape
                    
                    print(f"\nSample {sample_count}:")
                    print(f"  Shape: {original_logprobs.shape}")
                    
                    # Create sparse version
                    sparse_logprobs = (original_logprobs * (original_logprobs > threshold)).to_sparse()
                    
                    # Calculate sparsity
                    nnz = sparse_logprobs._nnz()
                    total_elements = batch_size * seq_len * vocab_size
                    sparsity = 1 - (nnz / total_elements)
                    print(f"  Sparsity: {sparsity:.2%} ({nnz:,} non-zero out of {total_elements:,})")
                    print(f"  Compression: {total_elements/nnz:.1f}x")
                    
                    # Convert back to dense for comparison
                    reconstructed_logprobs = sparse_logprobs.to_dense()
                    
                    # Fill zeros with very negative value (representing zero probability)
                    mask_zeros = (reconstructed_logprobs == 0)
                    reconstructed_logprobs[mask_zeros] = -50.0
                    
                    # Convert to probabilities
                    original_probs = F.softmax(original_logprobs.float(), dim=-1)
                    reconstructed_probs = F.softmax(reconstructed_logprobs.float(), dim=-1)
                    
                    # Calculate KL divergence for each position
                    # KL(P||Q) = sum(P * log(P/Q))
                    # We'll calculate for a few random positions to avoid memory issues
                    
                    num_positions_to_test = min(10, seq_len)
                    position_indices = torch.randperm(seq_len)[:num_positions_to_test]
                    
                    position_kls = []
                    for pos_idx in position_indices:
                        p = original_probs[0, pos_idx]  # Original distribution
                        q = reconstructed_probs[0, pos_idx]  # Sparse reconstruction
                        
                        # Use PyTorch's KL divergence function which handles numerical issues
                        # First convert back to log space
                        log_p = torch.log(p + 1e-20)
                        log_q = torch.log(q + 1e-20)
                        
                        # KL divergence using stable computation
                        kl_div = F.kl_div(log_q, p, reduction='sum', log_target=False).item()
                        
                        # Alternative manual calculation with better numerical stability
                        # Only calculate KL where p > 0 to avoid numerical issues
                        valid_mask = p > 1e-10
                        if valid_mask.sum() > 0:
                            p_valid = p[valid_mask]
                            q_valid = q[valid_mask]
                            kl_manual = (p_valid * torch.log(p_valid / (q_valid + 1e-20))).sum().item()
                        else:
                            kl_manual = 0.0
                        
                        # Use the manual calculation as it's more stable
                        position_kls.append(kl_manual)
                        
                        # Also calculate how much probability mass we kept
                        kept_mask = original_logprobs[0, pos_idx] > threshold
                        prob_mass_kept = p[kept_mask].sum().item()
                        num_kept = kept_mask.sum().item()
                        
                        print(f"    Position {pos_idx}: KL={kl_manual:.6f}, Prob mass kept={prob_mass_kept:.4f}, Tokens kept={num_kept}/{vocab_size}")
                    
                    # Average KL for this sample
                    avg_kl = np.mean(position_kls)
                    print(f"  Average KL divergence: {avg_kl:.6f}")
                    
                    # Also test with generated tokens specifically
                    if 'sft_teacher_generated_tokens' in sample:
                        generated_tokens = sample['sft_teacher_generated_tokens']
                        print(f"\n  Testing on generated tokens specifically:")
                        
                        # Get probabilities for the actually generated tokens
                        batch_indices = torch.arange(batch_size).unsqueeze(1)
                        seq_indices = torch.arange(min(seq_len, generated_tokens.shape[1]))
                        
                        # Get the probabilities assigned to the generated tokens
                        orig_probs_for_generated = original_probs[batch_indices, seq_indices, generated_tokens[0, :seq_len]]
                        recon_probs_for_generated = reconstructed_probs[batch_indices, seq_indices, generated_tokens[0, :seq_len]]
                        
                        # Check if any generated tokens were filtered out
                        filtered_out = (original_logprobs[batch_indices, seq_indices, generated_tokens[0, :seq_len]] <= threshold).sum()
                        print(f"    Generated tokens filtered out: {filtered_out}/{seq_len}")
                        print(f"    Original prob for generated: min={orig_probs_for_generated.min():.6f}, mean={orig_probs_for_generated.mean():.6f}")
                        print(f"    Reconstructed prob for generated: min={recon_probs_for_generated.min():.6f}, mean={recon_probs_for_generated.mean():.6f}")
                    
                    kl_results.append({
                        'sample_id': sample_count,
                        'shape': original_logprobs.shape,
                        'sparsity': sparsity,
                        'compression': total_elements/nnz,
                        'avg_kl': avg_kl,
                        'position_kls': position_kls
                    })
                    
            except EOFError:
                break
            except Exception as e:
                print(f"Error: {e}")
                import traceback
                traceback.print_exc()
                break
    
    # Summary
    print("\n" + "=" * 80)
    print("SUMMARY")
    print("=" * 80)
    
    if kl_results:
        avg_sparsity = np.mean([r['sparsity'] for r in kl_results])
        avg_compression = np.mean([r['compression'] for r in kl_results])
        avg_kl = np.mean([r['avg_kl'] for r in kl_results])
        all_position_kls = [kl for r in kl_results for kl in r['position_kls']]
        
        print(f"Average sparsity: {avg_sparsity:.2%}")
        print(f"Average compression: {avg_compression:.1f}x")
        print(f"Average KL divergence: {avg_kl:.6f}")
        print(f"KL divergence range: [{min(all_position_kls):.6f}, {max(all_position_kls):.6f}]")
        
        print(f"\nInterpretation:")
        print(f"- KL < 0.01: Excellent reconstruction")
        print(f"- KL < 0.1: Good reconstruction")
        print(f"- KL < 1.0: Acceptable for most purposes")
        print(f"- KL > 1.0: Significant information loss")

def main():
    """Run the sparse-KL check on 5 samples of the merged teacher data at threshold -15.0."""
    test_kl_divergence_sparse(
        '/workspace/data/teacher_generated_data_gzip/merged_teacher_data.pkl.gz',
        num_samples=5,
        threshold=-15.0,
    )


# Run the check when this cell is executed as the main module.
if __name__ == "__main__":
    main()

Testing KL divergence with threshold=-15.0

Sample 1:
  Shape: torch.Size([1, 271, 151936])
  Sparsity: 99.29% (292,580 non-zero out of 41,174,656)
  Compression: 140.7x
    Position 266: KL=0.001491, Prob mass kept=0.9999, Tokens kept=57/151936
    Position 10: KL=0.000193, Prob mass kept=1.0000, Tokens kept=16/151936
    Position 169: KL=0.010000, Prob mass kept=0.9996, Tokens kept=363/151936
    Position 163: KL=0.003043, Prob mass kept=0.9999, Tokens kept=84/151936
    Position 146: KL=0.000054, Prob mass kept=1.0000, Tokens kept=8/151936
    Position 126: KL=0.006999, Prob mass kept=0.9998, Tokens kept=699/151936
    Position 153: KL=0.000040, Prob mass kept=1.0000, Tokens kept=8/151936
    Position 62: KL=0.088796, Prob mass kept=0.9969, Tokens kept=2821/151936
    Position 121: KL=0.000695, Prob mass kept=1.0000, Tokens kept=30/151936
    Position 208: KL=0.027490, Prob mass kept=0.9990, Tokens kept=1059/151936
  Average KL divergence: 0.013880

  Testing on generated tokens spe

In [3]:
def main():
    """Run the sparse-KL check on 5 samples of the merged teacher data at threshold -14.0."""
    test_kl_divergence_sparse(
        '/workspace/data/teacher_generated_data_gzip/merged_teacher_data.pkl.gz',
        num_samples=5,
        threshold=-14.0,
    )


# Run the check when this cell is executed as the main module.
if __name__ == "__main__":
    main()

Testing KL divergence with threshold=-14.0

Sample 1:
  Shape: torch.Size([1, 271, 151936])
  Sparsity: 99.69% (127,153 non-zero out of 41,174,656)
  Compression: 323.8x
    Position 266: KL=0.001869, Prob mass kept=0.9999, Tokens kept=32/151936
    Position 10: KL=0.000292, Prob mass kept=1.0000, Tokens kept=9/151936
    Position 169: KL=0.012788, Prob mass kept=0.9995, Tokens kept=180/151936
    Position 163: KL=0.003699, Prob mass kept=0.9999, Tokens kept=41/151936
    Position 146: KL=0.000137, Prob mass kept=1.0000, Tokens kept=3/151936
    Position 126: KL=0.011046, Prob mass kept=0.9996, Tokens kept=440/151936
    Position 153: KL=0.000101, Prob mass kept=1.0000, Tokens kept=4/151936
    Position 62: KL=0.117000, Prob mass kept=0.9960, Tokens kept=921/151936
    Position 121: KL=0.000800, Prob mass kept=1.0000, Tokens kept=22/151936
    Position 208: KL=0.036588, Prob mass kept=0.9987, Tokens kept=449/151936
  Average KL divergence: 0.018432

  Testing on generated tokens specif

In [4]:
def main():
    """Run the sparse-KL check on 5 samples of the merged teacher data at threshold -13.0."""
    test_kl_divergence_sparse(
        '/workspace/data/teacher_generated_data_gzip/merged_teacher_data.pkl.gz',
        num_samples=5,
        threshold=-13.0,
    )


# Run the check when this cell is executed as the main module.
if __name__ == "__main__":
    main()

Testing KL divergence with threshold=-13.0

Sample 1:
  Shape: torch.Size([1, 271, 151936])
  Sparsity: 99.86% (56,197 non-zero out of 41,174,656)
  Compression: 732.7x
    Position 266: KL=0.002586, Prob mass kept=0.9999, Tokens kept=17/151936
    Position 10: KL=0.000341, Prob mass kept=1.0000, Tokens kept=8/151936
    Position 169: KL=0.016743, Prob mass kept=0.9994, Tokens kept=89/151936
    Position 163: KL=0.004540, Prob mass kept=0.9998, Tokens kept=22/151936
    Position 146: KL=0.000206, Prob mass kept=1.0000, Tokens kept=2/151936
    Position 126: KL=0.018748, Prob mass kept=0.9994, Tokens kept=261/151936
    Position 153: KL=0.000101, Prob mass kept=1.0000, Tokens kept=4/151936
    Position 62: KL=0.141363, Prob mass kept=0.9952, Tokens kept=332/151936
    Position 121: KL=0.001039, Prob mass kept=1.0000, Tokens kept=15/151936
    Position 208: KL=0.046690, Prob mass kept=0.9984, Tokens kept=212/151936
  Average KL divergence: 0.023236

  Testing on generated tokens specific

In [5]:
def main():
    """Run the sparse-KL check on 5 samples of the merged teacher data at threshold -12.0."""
    test_kl_divergence_sparse(
        '/workspace/data/teacher_generated_data_gzip/merged_teacher_data.pkl.gz',
        num_samples=5,
        threshold=-12.0,
    )


# Run the check when this cell is executed as the main module.
if __name__ == "__main__":
    main()

Testing KL divergence with threshold=-12.0

Sample 1:
  Shape: torch.Size([1, 271, 151936])
  Sparsity: 99.94% (26,741 non-zero out of 41,174,656)
  Compression: 1539.8x
    Position 266: KL=0.003577, Prob mass kept=0.9999, Tokens kept=9/151936
    Position 10: KL=0.000516, Prob mass kept=1.0000, Tokens kept=6/151936
    Position 169: KL=0.020890, Prob mass kept=0.9993, Tokens kept=53/151936
    Position 163: KL=0.005183, Prob mass kept=0.9998, Tokens kept=17/151936
    Position 146: KL=0.000206, Prob mass kept=1.0000, Tokens kept=2/151936
    Position 126: KL=0.030562, Prob mass kept=0.9990, Tokens kept=168/151936
    Position 153: KL=0.000524, Prob mass kept=1.0000, Tokens kept=1/151936
    Position 62: KL=0.163167, Prob mass kept=0.9946, Tokens kept=137/151936
    Position 121: KL=0.001843, Prob mass kept=0.9999, Tokens kept=9/151936
    Position 208: KL=0.062464, Prob mass kept=0.9979, Tokens kept=79/151936
  Average KL divergence: 0.028893

  Testing on generated tokens specifical