# Dataset 00: Same Prompt Visualization

This notebook visualizes the dataset generated using a single prompt template.

In [6]:
import numpy as np
import json
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns

# Set style
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (12, 6)

## Load Dataset

In [7]:
# Load data
data_dir = Path("data")

if not data_dir.exists():
    print("Data directory not found. Please run generate.py first.")
else:
    # Load training data
    train_hidden_states = np.load(data_dir / "train_hidden_states.npy")
    train_remaining_tokens = np.load(data_dir / "train_remaining_tokens.npy")
    train_token_metadata = np.load(data_dir / "train_token_metadata.npy")
    
    # Load validation data
    val_hidden_states = np.load(data_dir / "val_hidden_states.npy")
    val_remaining_tokens = np.load(data_dir / "val_remaining_tokens.npy")
    val_token_metadata = np.load(data_dir / "val_token_metadata.npy")
    
    # Load metadata if exists
    metadata_path = data_dir / "metadata.json"
    if metadata_path.exists():
        with open(metadata_path, 'r') as f:
            metadata = json.load(f)
        print("Dataset Metadata:")
        for key, value in metadata.items():
            if key != 'prompt_usage_distribution':
                print(f"  {key}: {value}")
    
    print(f"\nData shapes:")
    print(f"  Train hidden states: {train_hidden_states.shape}")
    print(f"  Train remaining tokens: {train_remaining_tokens.shape}")
    print(f"  Val hidden states: {val_hidden_states.shape}")
    print(f"  Val remaining tokens: {val_remaining_tokens.shape}")

Dataset Metadata:
  dataset_id: 00_same_prompt
  description: Dataset generated using a single prompt template
  model_name: meta-llama/Llama-3.2-3B-Instruct
  prompt_template: Print exactly {count} repetitions of the token "{word}". Do not include anything else.
  counts_range: [5, 49]
  words: ['hello', 'world', 'cat', 'dog', 'python', 'test', 'apple', 'blue', 'sun', 'code']
  total_samples: 450
  total_tokens: 40522
  layers_extracted: last_layer_only

Data shapes:
  Train hidden states: (36469, 3072)
  Train remaining tokens: (36469,)
  Val hidden states: (4053, 3072)
  Val remaining tokens: (4053,)


## Sample Data Points

In [8]:
if 'train_hidden_states' in locals():
    # Show some random samples
    print("Random samples from training data:")
    print("="*60)
    
    np.random.seed(42)
    sample_indices = np.random.choice(len(train_hidden_states), 10, replace=False)
    
    for i, idx in enumerate(sample_indices):
        hidden_state = train_hidden_states[idx]
        remaining = train_remaining_tokens[idx]
        token_text = train_token_metadata[idx]['token_text']
        token_id = train_token_metadata[idx]['token_id']
        
        print(f"\nSample {i+1} (index {idx}):")
        print(f"  Token: {repr(token_text)} (id={token_id})")
        print(f"  Remaining tokens: {remaining}")
        print(f"  Hidden state norm: {np.linalg.norm(hidden_state):.4f}")
        print(f"  Hidden state mean: {hidden_state.mean():.4f}")
        print(f"  Hidden state std: {hidden_state.std():.4f}")

Random samples from training data:

Sample 1 (index 9064):
  Token: np.str_('blue') (id=12481)
  Remaining tokens: 31
  Hidden state norm: 89.3750
  Hidden state mean: 0.0536
  Hidden state std: 1.6113

Sample 2 (index 2296):
  Token: np.str_('sun') (id=40619)
  Remaining tokens: 25
  Hidden state norm: 89.3125
  Hidden state mean: -0.0174
  Hidden state std: 1.6113

Sample 3 (index 36406):
  Token: np.str_('sun') (id=40619)
  Remaining tokens: 145
  Hidden state norm: 89.3750
  Hidden state mean: -0.0196
  Hidden state std: 1.6123

Sample 4 (index 35294):
  Token: np.str_('cat') (id=4719)
  Remaining tokens: 57
  Hidden state norm: 88.8750
  Hidden state mean: 0.0319
  Hidden state std: 1.6035

Sample 5 (index 4469):
  Token: np.str_('\n') (id=198)
  Remaining tokens: 30
  Hidden state norm: 86.6875
  Hidden state mean: -0.0058
  Hidden state std: 1.5645

Sample 6 (index 6267):
  Token: np.str_('\n') (id=198)
  Remaining tokens: 54
  Hidden state norm: 84.8750
  Hidden state mean: -0.

## Complete Output Reconstruction

Since the data is stored in generation order (not shuffled), we can reconstruct complete outputs by grouping tokens until we reach remaining_tokens == 0.

In [9]:
if 'train_hidden_states' in locals():
    # Reconstruct complete outputs
    def reconstruct_sequences(remaining_tokens, token_metadata, max_sequences=10):
        """Group tokens into complete generation sequences."""
        sequences = []
        current_sequence = []

        for i in range(len(remaining_tokens)):
            token_text = str(token_metadata[i]['token_text'])
            remaining = remaining_tokens[i]

            current_sequence.append({
                'token': token_text,
                'remaining': remaining,
                'index': i
            })

            # End of sequence
            if remaining == 0:
                sequences.append(current_sequence)
                current_sequence = []

                if len(sequences) >= max_sequences:
                    break

        return sequences

    # Reconstruct first 10 complete sequences
    sequences = reconstruct_sequences(train_remaining_tokens, train_token_metadata, max_sequences=10)

    print(f"Reconstructed {len(sequences)} complete generation sequences")
    print("="*80)

    for i, seq in enumerate(sequences):
        # Reconstruct the full output text
        output_text = ''.join([token['token'] for token in seq])
        num_tokens = len(seq)
        start_idx = seq[0]['index']
        end_idx = seq[-1]['index']

        print(f"\nSequence {i+1}:")
        print(f"  Indices: {start_idx} to {end_idx}")
        print(f"  Total tokens: {num_tokens}")
        print(f"  Output: {repr(output_text)}")
        print(f"  Token breakdown:")
        for j, token_info in enumerate(seq[:15]):  # Show first 15 tokens
            print(f"    {j+1:2d}. {repr(token_info['token']):20s} (remaining: {token_info['remaining']:2d})")
        if len(seq) > 15:
            print(f"    ... ({len(seq) - 15} more tokens)")
        print()

Reconstructed 10 complete generation sequences

Sequence 1:
  Indices: 0 to 9
  Total tokens: 10
  Output: 'hello\nhello\nhello\nhello\nhello<|eot_id|>'
  Token breakdown:
     1. 'hello'              (remaining:  9)
     2. '\n'                 (remaining:  8)
     3. 'hello'              (remaining:  7)
     4. '\n'                 (remaining:  6)
     5. 'hello'              (remaining:  5)
     6. '\n'                 (remaining:  4)
     7. 'hello'              (remaining:  3)
     8. '\n'                 (remaining:  2)
     9. 'hello'              (remaining:  1)
    10. '<|eot_id|>'         (remaining:  0)


Sequence 2:
  Indices: 10 to 19
  Total tokens: 10
  Output: 'world\nworld\nworld\nworld\nworld<|eot_id|>'
  Token breakdown:
     1. 'world'              (remaining:  9)
     2. '\n'                 (remaining:  8)
     3. 'world'              (remaining:  7)
     4. '\n'                 (remaining:  6)
     5. 'world'              (remaining:  5)
     6. '\n'             