# Dataset 00: Same Prompt Visualization

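This notebook visualizes the dataset produced by `generate.py` using a single prompt template. It shows the distribution of remaining-token targets, the hidden-state norms, the most frequent tokens, and a handful of individual samples.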

In [None]:
import json
from collections import Counter
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

# Global plotting defaults applied to every figure drawn in this notebook.
sns.set_style("whitegrid")
plt.rcParams.update({"figure.figsize": (12, 6)})

## Load Dataset

In [None]:
# Load the train/val splits written by generate.py (three .npy files per split).
data_dir = Path("data")


def load_split(data_dir: Path, split: str):
    """Load hidden states, remaining-token targets and token metadata for one split.

    Args:
        data_dir: directory containing the ``{split}_*.npy`` files.
        split: file-name prefix of the split, e.g. ``"train"`` or ``"val"``.

    Returns:
        ``(hidden_states, remaining_tokens, token_metadata)`` as returned by ``np.load``.

    Raises:
        FileNotFoundError: if any of the three files is missing.
        AssertionError: if the three arrays do not have the same length.
    """
    hidden_states = np.load(data_dir / f"{split}_hidden_states.npy")
    remaining_tokens = np.load(data_dir / f"{split}_remaining_tokens.npy")
    # Metadata records are accessed by key later (item['token_text']). If they were
    # saved as Python dicts the array has dtype=object, which np.load rejects unless
    # allow_pickle=True. Unpickling can execute arbitrary code: only load files you
    # generated yourself.
    token_metadata = np.load(data_dir / f"{split}_token_metadata.npy", allow_pickle=True)
    # Downstream cells index these three arrays with the same row index.
    assert len(hidden_states) == len(remaining_tokens) == len(token_metadata), (
        f"{split}: hidden_states, remaining_tokens and token_metadata differ in length"
    )
    return hidden_states, remaining_tokens, token_metadata


if not data_dir.exists():
    print("Data directory not found. Please run generate.py first.")
else:
    train_hidden_states, train_remaining_tokens, train_token_metadata = load_split(data_dir, "train")
    val_hidden_states, val_remaining_tokens, val_token_metadata = load_split(data_dir, "val")

    # Optional dataset-level metadata
    metadata_path = data_dir / "metadata.json"
    if metadata_path.exists():
        with open(metadata_path, 'r') as f:
            metadata = json.load(f)
        print("Dataset Metadata:")
        for key, value in metadata.items():
            # 'prompt_usage_distribution' is deliberately left out of this summary
            if key != 'prompt_usage_distribution':
                print(f"  {key}: {value}")

    print("\nData shapes:")
    print(f"  Train hidden states: {train_hidden_states.shape}")
    print(f"  Train remaining tokens: {train_remaining_tokens.shape}")
    print(f"  Val hidden states: {val_hidden_states.shape}")
    print(f"  Val remaining tokens: {val_remaining_tokens.shape}")

## Distribution of Remaining Tokens

In [None]:
def format_token_stats(values) -> str:
    """Return the two-line min/max and mean/std summary printed for one split.

    An empty array yields a placeholder line instead of raising; ``ndarray.min``
    raises ``ValueError`` on zero-size input.
    """
    values = np.asarray(values)
    if values.size == 0:
        return "  (no samples)"
    return (
        f"  Min: {values.min()}, Max: {values.max()}\n"
        f"  Mean: {values.mean():.2f}, Std: {values.std():.2f}"
    )


def plot_remaining_hist(ax, values, split_name, **hist_kwargs):
    """Draw a 50-bin histogram of remaining-token counts for one split onto ``ax``.

    Extra keyword arguments (e.g. ``color``) are forwarded to ``ax.hist``.
    """
    ax.hist(values, bins=50, alpha=0.7, edgecolor='black', **hist_kwargs)
    ax.set_xlabel('Remaining Tokens')
    ax.set_ylabel('Frequency')
    ax.set_title(f'{split_name} Set Distribution (n={len(values):,})')
    ax.grid(True, alpha=0.3)


if 'train_remaining_tokens' in locals():
    fig, axes = plt.subplots(1, 2, figsize=(15, 5))
    plot_remaining_hist(axes[0], train_remaining_tokens, 'Training')
    plot_remaining_hist(axes[1], val_remaining_tokens, 'Validation', color='orange')
    plt.tight_layout()
    plt.show()

    # Summary statistics per split
    print("\nStatistics:")
    print("Training set:")
    print(format_token_stats(train_remaining_tokens))
    print("\nValidation set:")
    print(format_token_stats(val_remaining_tokens))

## Hidden State Analysis

In [None]:
SCATTER_SAMPLE_SIZE = 5000  # max points drawn in the norm-vs-remaining scatter
SCATTER_SEED = 0            # fixed so the scatter subsample is reproducible
NORM_BINS = 50


def subsample_indices(n: int, k: int, seed: int = SCATTER_SEED) -> np.ndarray:
    """Return ``min(k, n)`` distinct indices from ``range(n)``, reproducible for a given seed.

    Uses a local ``Generator`` so the global NumPy RNG state is not consumed.
    """
    rng = np.random.default_rng(seed)
    return rng.choice(n, size=min(k, n), replace=False)


def shared_bin_edges(*arrays, bins: int = NORM_BINS) -> np.ndarray:
    """Compute one set of histogram bin edges spanning all ``arrays``.

    Overlaid histograms are only comparable when they share edges; calling
    ``hist(..., bins=50)`` separately derives a different range for each array.
    """
    return np.histogram_bin_edges(np.concatenate([np.ravel(a) for a in arrays]), bins=bins)


if 'train_hidden_states' in locals():
    # L2 norm over axis 1; assumes hidden states are 2-D (n_samples, hidden_dim), TODO confirm against generate.py
    train_norms = np.linalg.norm(train_hidden_states, axis=1)
    val_norms = np.linalg.norm(val_hidden_states, axis=1)

    fig, axes = plt.subplots(1, 2, figsize=(15, 5))

    # Norm distributions, drawn on common bin edges so train/val bars line up
    norm_bins = shared_bin_edges(train_norms, val_norms)
    axes[0].hist(train_norms, bins=norm_bins, alpha=0.5, label='Train', edgecolor='black')
    axes[0].hist(val_norms, bins=norm_bins, alpha=0.5, label='Val', edgecolor='black')
    axes[0].set_xlabel('L2 Norm')
    axes[0].set_ylabel('Frequency')
    axes[0].set_title('Hidden State Norm Distribution')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)

    # Norm vs remaining tokens scatter on a reproducible subsample
    scatter_sample = subsample_indices(len(train_remaining_tokens), SCATTER_SAMPLE_SIZE)
    axes[1].scatter(train_remaining_tokens[scatter_sample], train_norms[scatter_sample], alpha=0.3, s=1)
    axes[1].set_xlabel('Remaining Tokens')
    axes[1].set_ylabel('Hidden State Norm')
    axes[1].set_title('Hidden State Norm vs Remaining Tokens')
    axes[1].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

## Token Distribution Analysis

In [None]:
def format_token_for_display(token) -> str:
    """Return ``token`` as-is if it renders legibly, otherwise its ``repr``.

    Tokens that are empty, have leading/trailing whitespace, or contain
    non-printable characters (e.g. an embedded newline or tab) are shown via
    ``repr``. The value is first converted to ``str`` so NumPy string scalars
    do not repr as ``np.str_('...')``.
    """
    token = str(token)
    if not token.strip() or token != token.strip() or not token.isprintable():
        return repr(token)
    return token


if 'train_token_metadata' in locals():
    # Normalise to plain str so NumPy string scalars count and display like ordinary strings
    train_tokens = [str(item['token_text']) for item in train_token_metadata]
    token_counts = Counter(train_tokens)
    n_train_tokens = len(train_tokens)

    # Show top 20 most common tokens
    print("Top 20 most common tokens in training set:")
    for token, count in token_counts.most_common(20):
        percentage = (count / n_train_tokens) * 100
        print(f"  {format_token_for_display(token):20s}: {count:6d} ({percentage:5.2f}%)")

    # Visualize the 15 most common tokens, reusing the (token, count) pairs directly
    top_pairs = token_counts.most_common(15)
    # Escape '$' so Matplotlib does not treat tokens such as '$$' as mathtext
    tick_labels = [format_token_for_display(t).replace('$', r'\$') for t, _ in top_pairs]
    top_counts = [c for _, c in top_pairs]
    positions = range(len(top_pairs))

    fig, ax = plt.subplots(figsize=(12, 6))
    ax.bar(positions, top_counts)
    ax.set_xticks(positions)
    ax.set_xticklabels(tick_labels, rotation=45, ha='right')
    ax.set_xlabel('Token')
    ax.set_ylabel('Frequency')
    ax.set_title('Top 15 Most Common Tokens')
    ax.grid(True, alpha=0.3, axis='y')
    fig.tight_layout()
    plt.show()

## Sample Data Points

In [None]:
N_DISPLAY_SAMPLES = 10  # how many random training rows to print
SAMPLE_SEED = 42


def pick_sample_indices(n: int, k: int = N_DISPLAY_SAMPLES, seed: int = SAMPLE_SEED) -> np.ndarray:
    """Draw ``min(k, n)`` distinct indices from ``range(n)``.

    Uses a private legacy ``RandomState`` so the global NumPy RNG is not reseeded.
    For ``n >= k`` the result equals
    ``np.random.seed(seed); np.random.choice(n, k, replace=False)``.
    """
    return np.random.RandomState(seed).choice(n, min(k, n), replace=False)


if 'train_hidden_states' in locals():
    print("Random samples from training data:")
    print("="*60)

    sample_indices = pick_sample_indices(len(train_hidden_states))

    for i, idx in enumerate(sample_indices, start=1):
        hidden_state = train_hidden_states[idx]
        remaining = train_remaining_tokens[idx]
        meta = train_token_metadata[idx]
        # str() so NumPy string scalars repr as plain quoted text
        token_text = str(meta['token_text'])
        token_id = meta['token_id']

        print(f"\nSample {i} (index {idx}):")
        print(f"  Token: {repr(token_text)} (id={token_id})")
        print(f"  Remaining tokens: {remaining}")
        print(f"  Hidden state norm: {np.linalg.norm(hidden_state):.4f}")
        print(f"  Hidden state mean: {hidden_state.mean():.4f}")
        print(f"  Hidden state std: {hidden_state.std():.4f}")