# Exploring Extracted Logits

This notebook demonstrates how to explore and visualize extracted logit data.

In [None]:
# Import all exploration functions
# NOTE(review): the helpers called in later cells (load_logit_dataset, get_sample,
# plot_*, dequantize_*, logits_to_probs, quick_explore, ...) are not defined in this
# notebook, so they presumably arrive through this star import -- confirm against
# explore_logits.
from explore_logits import *
import matplotlib.pyplot as plt
import numpy as np

# Set up plotting style
# Select matplotlib's 'default' style sheet, then ask IPython to render figures
# inline in the notebook output.
plt.style.use('default')
%matplotlib inline

## 1. Load Dataset

In [None]:
# Replace with your dataset ID
# DATASET_ID is also passed to quick_explore later in the notebook.
DATASET_ID = "seba/devLogits-Q3-0.6B"

# Load the dataset
# `ds` is read again by the sample-selection and dataset-statistics cells;
# len(ds) is reported below as the number of samples.
ds = load_logit_dataset(DATASET_ID)
print(f"Dataset has {len(ds)} samples")

## 2. Explore a Single Sample

In [None]:
# Get a sample
# `sample` is the object every per-token cell below works on; later cells read
# sample['num_tokens'] and sample['logsumexp'] and pass it to the plotting and
# dequantization helpers. Change `index` to look at a different entry
# (presumably a 0-based position in `ds` -- confirm in explore_logits).
sample = get_sample(ds, index=0)

# Print information about the sample
print_sample_info(sample, verbose=True)

## 3. Visualize Nucleus Sizes

In [None]:
# Show how nucleus size varies across tokens
# Operates on the `sample` selected in the previous section.
plot_nucleus_sizes(sample)

## 4. Inspect a Specific Token

In [None]:
# Choose a token to inspect
# TOKEN_IDX is shared by the plotting, dequantization and decoding cells below;
# after changing it, re-run those cells so they all look at the same position.
TOKEN_IDX = 0

# Plot logit distribution for this token
# NOTE(review): show_top_n=30 presumably caps how many of the highest logits are
# drawn -- confirm in explore_logits.
plot_logit_distribution(sample, token_idx=TOKEN_IDX, show_top_n=30)

In [None]:
# Compare nucleus vs sampled distributions
# Uses the same sample and token position (TOKEN_IDX) as the logit-distribution plot.
plot_nucleus_vs_sampled(sample, token_idx=TOKEN_IDX)

## 5. Dequantize and Inspect Logits

In [None]:
# Dequantize nucleus logits
# The helper returns a pair for position TOKEN_IDX: token indices and their logit
# values; both are sliced with [:5] below.
# NOTE(review): labelling the first five entries "Top 5" assumes the helper returns
# them ordered by descending logit -- confirm in explore_logits.
nucleus_indices, nucleus_logits = dequantize_top_logits(sample, token_idx=TOKEN_IDX)

print(f"Nucleus has {len(nucleus_indices)} tokens")
print(f"Top 5 logit values: {nucleus_logits[:5]}")
print(f"Top 5 token indices: {nucleus_indices[:5]}")

In [None]:
# Dequantize sampled logits
# Same (indices, logits) return pair as the nucleus call above, for the sampled set
# at position TOKEN_IDX. `sampled_logits` must expose .min() and .max(), which are
# used to print its value range.
sampled_indices, sampled_logits = dequantize_sampled_logits(sample, token_idx=TOKEN_IDX)

print(f"Sampled has {len(sampled_indices)} tokens")
print(f"Sampled logit range: [{sampled_logits.min():.3f}, {sampled_logits.max():.3f}]")

In [None]:
# Get all logits at once
# The result is used like a mapping: its keys are listed and its 'logsumexp' entry
# is formatted as a float below.
logits_dict = dequantize_all_logits(sample, token_idx=TOKEN_IDX)

print("Keys in logits_dict:", logits_dict.keys())
print(f"LogSumExp: {logits_dict['logsumexp']:.3f}")

# Check nucleus probability mass
# The returned value is multiplied by 100 for display, i.e. it is treated as a
# fraction rather than a percentage.
nucleus_mass = get_nucleus_probability_mass(sample, token_idx=TOKEN_IDX)
print(f"\nNucleus captures {nucleus_mass*100:.2f}% of probability mass")

## 6. Decode Tokens (Optional - requires the model's tokenizer)

In [None]:
# Load tokenizer to decode token IDs to strings
# Only a tokenizer is created in this cell; no model object is instantiated.
from transformers import AutoTokenizer

# Replace with the model you used for extraction
# NOTE(review): the default DATASET_ID ("seba/devLogits-Q3-0.6B") does not obviously
# correspond to this 0.5B Qwen2.5 checkpoint. Confirm that this tokenizer matches the
# model the logits were extracted from; a mismatched vocabulary would decode token
# IDs to the wrong strings.
# MODEL_ID is also passed to quick_explore later in the notebook.
MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

In [None]:
# Print top tokens with their probabilities
# Needs `tokenizer` from the previous cell. top_k=20 presumably sets how many
# tokens are listed -- confirm in explore_logits.
print_top_tokens(sample, token_idx=TOKEN_IDX, tokenizer=tokenizer, top_k=20)

## 7. Compute Dataset Statistics

In [None]:
# Compute stats across the dataset (may take a while for large datasets)
# Works on the full `ds` handle rather than the single `sample` used above;
# adjust max_samples to trade coverage for runtime.
stats = compute_dataset_stats(ds, max_samples=1000)  # Limit to first 1000 samples

# Print statistics
# `stats` is passed straight through to the printing helper.
print_dataset_stats(stats)

## 8. Quick Exploration (All-in-One)

For quick analysis, use the `quick_explore` function:

In [None]:
# Quick exploration with all visualizations
# Takes the dataset ID rather than the already-loaded `ds` / `sample` objects, and
# uses literal sample/token indices instead of TOKEN_IDX.
# MODEL_ID is defined in the optional tokenizer cell (section 6): run that cell
# first, otherwise this call raises NameError on a fresh kernel.
quick_explore(
    dataset_id=DATASET_ID,
    sample_idx=0,
    token_idx=0,
    model_id=MODEL_ID  # Optional, for token decoding
)

## 9. Custom Analysis

Example: compare the top-token probability across multiple token positions

In [None]:
# Compare top token probabilities across positions.
# For each of the first MAX_POSITIONS positions: dequantize the nucleus logits,
# convert them to probabilities with that position's logsumexp, and keep entry 0.
# NOTE(review): treating entry 0 as the most likely token assumes
# dequantize_top_logits returns logits in descending order -- confirm.
MAX_POSITIONS = 10  # number of leading token positions to compare
token_positions = range(min(MAX_POSITIONS, sample['num_tokens']))
top_probs = []

for pos in token_positions:
    # Keyword form matches the other dequantize_top_logits calls in this notebook.
    _pos_ids, pos_logits = dequantize_top_logits(sample, token_idx=pos)
    pos_lse = sample['logsumexp'][pos]
    pos_probs = logits_to_probs(pos_logits, pos_lse)
    top_probs.append(pos_probs[0])  # Probability of top token

if top_probs:
    # Explicit figure/axes objects instead of the implicit pyplot state machine.
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(token_positions, top_probs, marker='o', linewidth=2)
    ax.set_xlabel('Token Position')
    ax.set_ylabel('Top Token Probability')
    ax.set_title('Confidence (Top Token Probability) Across Sequence')
    ax.grid(alpha=0.3)
    plt.show()

    print(f"Average top token probability: {np.mean(top_probs):.3f}")
    print(f"Min: {min(top_probs):.3f}, Max: {max(top_probs):.3f}")
else:
    # With zero tokens, min()/max() would raise ValueError and np.mean() would
    # return nan with a RuntimeWarning, so report the situation instead.
    print("Sample has no tokens; nothing to plot.")

## 10. Export Data for Further Analysis

In [None]:
# Example: Export top-5 tokens for each position to CSV
import pandas as pd


def _top_token_rows(sample, top_k=5):
    """Yield one record per (token position, rank) for the leading nucleus entries.

    Args:
        sample: Sample object as used throughout this notebook; read here via
            sample['num_tokens'] and sample['logsumexp'].
        top_k: Maximum number of ranks emitted per position. A position whose
            nucleus holds fewer entries yields only what is available.

    Yields:
        dict with keys 'token_position', 'rank' (1-based), 'token_id',
        'logit' and 'probability'.
    """
    for pos in range(sample['num_tokens']):
        token_ids, pos_logits = dequantize_top_logits(sample, token_idx=pos)
        pos_probs = logits_to_probs(pos_logits, sample['logsumexp'][pos])

        # NOTE(review): taking the first entries as the "top" tokens assumes the
        # helper returns them in descending-logit order -- confirm.
        for rank in range(min(top_k, len(token_ids))):
            yield {
                'token_position': pos,
                'rank': rank + 1,
                'token_id': token_ids[rank],
                'logit': pos_logits[rank],
                'probability': pos_probs[rank]
            }


TOP_K = 5  # ranks exported per token position

# Build the record list once, then construct the frame in a single call.
data = list(_top_token_rows(sample, top_k=TOP_K))
df = pd.DataFrame(data)
print(df.head(10))

# Optionally save to CSV
# df.to_csv('top_tokens.csv', index=False)