# Dataset 01: Different Prompts Visualization

This notebook explores the dataset generated from several prompt templates. Each sample pairs an input (the hidden state at one generated token) with an output (the number of tokens remaining in that generation).

In [1]:
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt

# Load training data
data_dir = Path("data")
train_hidden_states = np.load(data_dir / "train_hidden_states.npy")
train_remaining_tokens = np.load(data_dir / "train_remaining_tokens.npy")
train_token_metadata = np.load(data_dir / "train_token_metadata.npy")

print("Training data shape:")
print(f"Hidden states: {train_hidden_states.shape}")
print(f"Remaining tokens: {train_remaining_tokens.shape}")
print(f"Token metadata: {train_token_metadata.shape}")
print(f"\nTotal samples: {len(train_hidden_states):,}")
print(f"\nMetadata fields: {train_token_metadata.dtype.names}")

Training data shape:
Hidden states: (37817, 3072)
Remaining tokens: (37817,)
Token metadata: (37817,)

Total samples: 37,817

Metadata fields: ('token_id', 'token_text')
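
Because every later cell indexes the three arrays with the same position, it can help to confirm they stay aligned. The following is a minimal sketch, not part of the original notebook: `check_alignment` is a hypothetical helper, and the synthetic arrays (including the `float16` and `U32` dtypes) are assumptions standing in for the real files.

```python
import numpy as np

def check_alignment(hidden_states, remaining_tokens, token_metadata):
    """Return basic facts about the three arrays, raising if they are misaligned."""
    n = len(hidden_states)
    if not (len(remaining_tokens) == len(token_metadata) == n):
        raise ValueError("arrays must have one row per sample")
    if hidden_states.ndim != 2:
        raise ValueError("hidden states must be 2-D: (samples, hidden_dim)")
    if remaining_tokens.min() < 0:
        raise ValueError("remaining-token counts must be non-negative")
    return {
        "samples": n,
        "hidden_dim": hidden_states.shape[1],
        "fields": token_metadata.dtype.names,
    }

# Tiny synthetic stand-in for the real files (shapes and dtypes are assumptions)
meta_dtype = np.dtype([("token_id", np.int64), ("token_text", "U32")])
hidden = np.zeros((4, 8), dtype=np.float16)
remaining = np.array([3, 2, 1, 0])
meta = np.array(
    [(15339, "hello"), (198, "\n"), (15339, "hello"), (128009, "<|eot_id|>")],
    dtype=meta_dtype,
)

summary = check_alignment(hidden, remaining, meta)
print(summary)
```

With the real data, the same call would take `train_hidden_states`, `train_remaining_tokens`, and `train_token_metadata` as arguments.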


## Sample Data Points

Let's look at ten consecutive samples (indices 10 to 19) to see the input-output pairs.

In [10]:
# Show 10 samples at fixed indices 10-19 (no random draw is made, so no seed is needed)
sample_indices = range(10, 20)

print("Random samples from training data:")
print("="*60)
for i, idx in enumerate(sample_indices):
    hidden_state = train_hidden_states[idx]
    remaining = train_remaining_tokens[idx]
    token_id = train_token_metadata[idx]['token_id']
    token_text = train_token_metadata[idx]['token_text']
    
    print(f"\nSample {i+1} (index {idx}):")
    print(f"  Token: '{token_text}' (id={token_id})")
    print(f"  Input (hidden state): shape={hidden_state.shape}, norm={np.linalg.norm(hidden_state):.4f}")
    print(f"  First 10 values: {hidden_state[:10]}")
    print(f"  Output (remaining tokens): {remaining}")

Random samples from training data:

Sample 1 (index 10):
  Token: '1' (id=16)
  Input (hidden state): shape=(3072,), norm=89.1250
  First 10 values: [ 0.04495 -1.132    0.2103  -0.686   -0.324   -2.238   -0.10504  3.17
 -0.3372  -2.094  ]
  Output (remaining tokens): 59

Sample 2 (index 11):
  Token: '.' (id=13)
  Input (hidden state): shape=(3072,), norm=88.1875
  First 10 values: [-0.1519 -0.9116  5.203  -2.49   -0.989  -0.8203 -0.2336  0.318  -0.2212
 -1.076 ]
  Output (remaining tokens): 58

Sample 3 (index 12):
  Token: ' The' (id=578)
  Input (hidden state): shape=(3072,), norm=84.6875
  First 10 values: [-1.110e-03 -8.047e-01  9.131e-01 -8.008e-01  2.196e-01 -2.801e+00
  2.939e-01  1.860e+00  1.357e-01 -1.169e+00]
  Output (remaining tokens): 57

Sample 4 (index 13):
  Token: ' world' (id=1917)
  Input (hidden state): shape=(3072,), norm=86.7500
  First 10 values: [-1.192   -0.4163   3.018   -1.173    0.2212   0.03424 -0.3984   1.78
  0.19    -0.9526 ]
  Output (remaining tokens
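
The per-sample norms printed above can also be computed for all rows at once. This is a rough sketch under assumptions: `row_norms` is a hypothetical helper, and the cast to `float32` assumes the hidden states are stored in a lower-precision type such as `float16`, which the notebook does not state.

```python
import numpy as np

def row_norms(hidden_states):
    """L2 norm of every row, computed in float32 to limit rounding error."""
    return np.linalg.norm(hidden_states.astype(np.float32), axis=1)

# Synthetic stand-in for train_hidden_states (float16 storage is an assumption)
demo = np.array([[3.0, 4.0], [6.0, 8.0], [0.0, 0.0]], dtype=np.float16)
norms = row_norms(demo)
print(norms)                      # one norm per sample
print(norms.mean(), norms.std())  # summary across samples
```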

## Sample Sequence Analysis

Let's look at how remaining tokens decrease during generation for specific examples.

In [9]:
# Find samples that form sequences (consecutive remaining token counts)
# This assumes the data is stored in generation order (i.e. it was not shuffled)

print("Examples of how remaining tokens change during generation:")
print("="*60)

# Show first 20 samples with their tokens
print("\nFirst 20 samples (may show part of a generation sequence):")
for i in range(min(20, len(train_remaining_tokens))):
    token_text = train_token_metadata[i]['token_text']
    token_id = train_token_metadata[i]['token_id']
    print(f"Sample {i:3d}: token={repr(token_text):22s} (id={token_id:6d}), remaining={train_remaining_tokens[i]:3d}")

# Find windows of 5 samples where remaining tokens decrease by exactly 1 at each step
print("\n\nLooking for consecutive decreasing sequences...")
sequence_starts = []
for i in range(len(train_remaining_tokens) - 4):
    # Check the 4 steps between samples i..i+4 (cast to int so unsigned dtypes cannot wrap)
    is_sequence = all(
        int(train_remaining_tokens[i+j]) - int(train_remaining_tokens[i+j+1]) == 1
        for j in range(4)
    )
    if is_sequence:
        sequence_starts.append(i)
        if len(sequence_starts) >= 3:  # Show first 3 found
            break

if sequence_starts:
    for seq_start in sequence_starts:
        print(f"\nSequence starting at index {seq_start}:")
        for j in range(10):
            if seq_start + j < len(train_remaining_tokens):
                token_text = train_token_metadata[seq_start + j]['token_text']
                print(f"  Step {j}: token={repr(token_text):22s} remaining={train_remaining_tokens[seq_start + j]}")
else:
    print("No obvious sequences found (data may be shuffled)")

Examples of how remaining tokens change during generation:

First 20 samples (may show part of a generation sequence):
Sample   0: token=np.str_('hello')       (id= 15339), remaining=  9
Sample   1: token=np.str_('\n')          (id=   198), remaining=  8
Sample   2: token=np.str_('hello')       (id= 15339), remaining=  7
Sample   3: token=np.str_('\n')          (id=   198), remaining=  6
Sample   4: token=np.str_('hello')       (id= 15339), remaining=  5
Sample   5: token=np.str_('\n')          (id=   198), remaining=  4
Sample   6: token=np.str_('hello')       (id= 15339), remaining=  3
Sample   7: token=np.str_('\n')          (id=   198), remaining=  2
Sample   8: token=np.str_('hello')       (id= 15339), remaining=  1
Sample   9: token=np.str_('<|eot_id|>')  (id=128009), remaining=  0
Sample  10: token=np.str_('1')           (id=    16), remaining= 59
Sample  11: token=np.str_('.')           (id=    13), remaining= 58
Sample  12: token=np.str_(' The')        (id=   578), remaining= 
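
The sliding-window search above reports overlapping windows that belong to the same countdown. As an alternative, here is a sketch that returns each maximal decreasing run once, using `np.diff`. The helper name `countdown_runs` and the `min_len` parameter are assumptions for illustration, not part of the original notebook.

```python
import numpy as np

def countdown_runs(remaining, min_len=5):
    """Return (start, stop) pairs of maximal runs that decrease by exactly 1.

    `stop` is exclusive, so remaining[start:stop] is the run itself.
    """
    remaining = np.asarray(remaining, dtype=np.int64)  # signed, so diffs cannot wrap
    if remaining.size == 0:
        return []
    # A run breaks wherever the step to the next element is not -1
    breaks = np.flatnonzero(np.diff(remaining) != -1) + 1
    starts = np.concatenate(([0], breaks))
    stops = np.concatenate((breaks, [remaining.size]))
    return [(int(a), int(b)) for a, b in zip(starts, stops) if b - a >= min_len]

demo = [4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1, 0, 7, 7]
print(countdown_runs(demo))
```

On the real data this would be called as `countdown_runs(train_remaining_tokens)`.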

## Complete Output Reconstruction

Since the data is stored in generation order (not shuffled), we can reconstruct complete outputs by grouping tokens until we reach remaining_tokens == 0.
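
The grouping rule can also be sketched with array operations: every position where the count hits zero closes one sequence. The helpers `split_at_zeros` and `is_clean_countdown` are hypothetical names used for illustration. They assume, as stated above, that the data is in generation order.

```python
import numpy as np

def split_at_zeros(remaining):
    """Return (start, stop) slices, one per complete sequence ending at remaining == 0."""
    remaining = np.asarray(remaining)
    ends = np.flatnonzero(remaining == 0) + 1  # stop index is exclusive
    starts = np.concatenate(([0], ends[:-1]))
    return [(int(a), int(b)) for a, b in zip(starts, ends)]

def is_clean_countdown(remaining, start, stop):
    """True if the slice counts down len-1, ..., 1, 0 with no gaps."""
    expected = np.arange(stop - start - 1, -1, -1)
    return bool(np.array_equal(np.asarray(remaining[start:stop]), expected))

demo = np.array([2, 1, 0, 3, 2, 1, 0, 5, 4])  # the last two tokens never reach 0
spans = split_at_zeros(demo)
print(spans)
print([is_clean_countdown(demo, a, b) for a, b in spans])
```

A trailing run that never reaches zero is dropped. The countdown check is a cheap way to detect shuffled or truncated data before trusting the reconstructed text.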

In [11]:
# Reconstruct complete outputs
def reconstruct_sequences(remaining_tokens, token_metadata, max_sequences=10):
    """Group tokens into complete generation sequences."""
    sequences = []
    current_sequence = []

    for i in range(len(remaining_tokens)):
        token_text = str(token_metadata[i]['token_text'])
        remaining = remaining_tokens[i]

        current_sequence.append({
            'token': token_text,
            'remaining': remaining,
            'index': i
        })

        # End of sequence
        if remaining == 0:
            sequences.append(current_sequence)
            current_sequence = []

            if len(sequences) >= max_sequences:
                break

    return sequences

# Reconstruct first 10 complete sequences
sequences = reconstruct_sequences(train_remaining_tokens, train_token_metadata, max_sequences=10)

print(f"Reconstructed {len(sequences)} complete generation sequences")
print("="*80)

for i, seq in enumerate(sequences):
    # Reconstruct the full output text
    output_text = ''.join([token['token'] for token in seq])
    num_tokens = len(seq)
    start_idx = seq[0]['index']
    end_idx = seq[-1]['index']

    print(f"\nSequence {i+1}:")
    print(f"  Indices: {start_idx} to {end_idx}")
    print(f"  Total tokens: {num_tokens}")
    print(f"  Output: {repr(output_text)}")
    print(f"  Token breakdown:")
    for j, token_info in enumerate(seq[:15]):  # Show first 15 tokens
        print(f"    {j+1:2d}. {repr(token_info['token']):20s} (remaining: {token_info['remaining']:2d})")
    if len(seq) > 15:
        print(f"    ... ({len(seq) - 15} more tokens)")
    print()

Reconstructed 10 complete generation sequences

Sequence 1:
  Indices: 0 to 9
  Total tokens: 10
  Output: 'hello\nhello\nhello\nhello\nhello<|eot_id|>'
  Token breakdown:
     1. 'hello'              (remaining:  9)
     2. '\n'                 (remaining:  8)
     3. 'hello'              (remaining:  7)
     4. '\n'                 (remaining:  6)
     5. 'hello'              (remaining:  5)
     6. '\n'                 (remaining:  4)
     7. 'hello'              (remaining:  3)
     8. '\n'                 (remaining:  2)
     9. 'hello'              (remaining:  1)
    10. '<|eot_id|>'         (remaining:  0)


Sequence 2:
  Indices: 10 to 69
  Total tokens: 60
  Output: '1. The world is a complex and dynamic place.\n2. The world is full of diverse cultures and landscapes.\n3. The world is a place of endless possibilities.\n4. The world is a stage for human drama and comedy.\n5. The world is a place of wonder and discovery.<|eot_id|>'
  Token breakdown:
     1. '1'                