# Explore Training Data

This notebook visualizes samples from the training data to understand the input-output relationship.

In [12]:
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt

# Load training data
data_dir = Path("data")
train_hidden_states = np.load(data_dir / "train_hidden_states.npy")
train_remaining_tokens = np.load(data_dir / "train_remaining_tokens.npy")
train_token_metadata = np.load(data_dir / "train_token_metadata.npy")

print(f"Training data shape:")
print(f"Hidden states: {train_hidden_states.shape}")
print(f"Remaining tokens: {train_remaining_tokens.shape}")
print(f"Token metadata: {train_token_metadata.shape}")
print(f"\nTotal samples: {len(train_hidden_states):,}")
print(f"\nMetadata fields: {train_token_metadata.dtype.names}")

Training data shape:
Hidden states: (99, 3072)
Remaining tokens: (99,)
Token metadata: (99,)

Total samples: 99

Metadata fields: ('token_id', 'token_text')


## Sample Data Points

Let's look at a few random samples to see the input-output pairs.

In [13]:
# Show 10 random samples
np.random.seed(42)
sample_indices = range(10)

print("Random samples from training data:")
print("="*60)
for i, idx in enumerate(sample_indices):
    hidden_state = train_hidden_states[idx]
    remaining = train_remaining_tokens[idx]
    token_id = train_token_metadata[idx]['token_id']
    token_text = train_token_metadata[idx]['token_text']
    
    print(f"\nSample {i+1} (index {idx}):")
    print(f"  Token: '{token_text}' (id={token_id})")
    print(f"  Input (hidden state): shape={hidden_state.shape}, norm={np.linalg.norm(hidden_state):.4f}")
    print(f"  First 10 values: {hidden_state[:10]}")
    print(f"  Output (remaining tokens): {remaining}")

Random samples from training data:

Sample 1 (index 0):
  Token: '

' (id=271)
  Input (hidden state): shape=(3072,), norm=89.6875
  First 10 values: [-1.068   1.066   1.274   0.2205  1.554  -0.8564 -1.797   1.861  -0.986
 -0.9805]
  Output (remaining tokens): 10

Sample 2 (index 1):
  Token: 'hello' (id=15339)
  Input (hidden state): shape=(3072,), norm=89.9375
  First 10 values: [ 2.729    2.598    0.02315  1.398   -0.629   -0.196   -1.846   -0.7637
  0.4102  -1.384  ]
  Output (remaining tokens): 9

Sample 3 (index 2):
  Token: '
' (id=198)
  Input (hidden state): shape=(3072,), norm=86.6875
  First 10 values: [-0.604  1.889  1.465  1.266  1.081 -1.03  -3.016 -0.745  1.162 -1.031]
  Output (remaining tokens): 8

Sample 4 (index 3):
  Token: 'hello' (id=15339)
  Input (hidden state): shape=(3072,), norm=89.3750
  First 10 values: [-1.184   0.9717 -0.1624  0.647   0.2041  0.4475 -1.611   0.6177 -0.157
 -1.122 ]
  Output (remaining tokens): 7

Sample 5 (index 4):
  Token: '
' (id=198)


## Distribution of Remaining Tokens

Let's visualize the distribution of the target variable (remaining tokens).

In [None]:
plt.figure(figsize=(12, 5))

# Histogram
plt.subplot(1, 2, 1)
plt.hist(train_remaining_tokens, bins=50, edgecolor='black', alpha=0.7)
plt.xlabel('Remaining Tokens')
plt.ylabel('Frequency')
plt.title('Distribution of Remaining Tokens')
plt.grid(True, alpha=0.3)

# Statistics
plt.subplot(1, 2, 2)
stats_text = f"""Statistics:
Min: {train_remaining_tokens.min()}
Max: {train_remaining_tokens.max()}
Mean: {train_remaining_tokens.mean():.2f}
Median: {np.median(train_remaining_tokens):.2f}
Std: {train_remaining_tokens.std():.2f}

Unique values: {len(np.unique(train_remaining_tokens))}
"""
plt.text(0.1, 0.5, stats_text, fontsize=12, family='monospace',
         verticalalignment='center')
plt.axis('off')
plt.title('Target Variable Statistics')

plt.tight_layout()
plt.show()

## Hidden State Analysis

Let's examine the hidden state vectors to understand their properties.

In [None]:
# Compute statistics across all hidden states
hidden_norms = np.linalg.norm(train_hidden_states, axis=1)
hidden_means = train_hidden_states.mean(axis=1)
hidden_stds = train_hidden_states.std(axis=1)

plt.figure(figsize=(15, 4))

# Norms
plt.subplot(1, 3, 1)
plt.hist(hidden_norms, bins=50, edgecolor='black', alpha=0.7)
plt.xlabel('L2 Norm')
plt.ylabel('Frequency')
plt.title('Distribution of Hidden State Norms')
plt.grid(True, alpha=0.3)

# Means
plt.subplot(1, 3, 2)
plt.hist(hidden_means, bins=50, edgecolor='black', alpha=0.7)
plt.xlabel('Mean Value')
plt.ylabel('Frequency')
plt.title('Distribution of Hidden State Means')
plt.grid(True, alpha=0.3)

# Stds
plt.subplot(1, 3, 3)
plt.hist(hidden_stds, bins=50, edgecolor='black', alpha=0.7)
plt.xlabel('Standard Deviation')
plt.ylabel('Frequency')
plt.title('Distribution of Hidden State Std Devs')
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## Relationship: Hidden States vs Remaining Tokens

Let's explore if there's any obvious relationship between hidden state properties and remaining tokens.

In [None]:
plt.figure(figsize=(15, 4))

# Norm vs Remaining Tokens
plt.subplot(1, 3, 1)
plt.scatter(train_remaining_tokens, hidden_norms, alpha=0.1, s=1)
plt.xlabel('Remaining Tokens')
plt.ylabel('Hidden State Norm')
plt.title('Norm vs Remaining Tokens')
plt.grid(True, alpha=0.3)

# Mean vs Remaining Tokens
plt.subplot(1, 3, 2)
plt.scatter(train_remaining_tokens, hidden_means, alpha=0.1, s=1)
plt.xlabel('Remaining Tokens')
plt.ylabel('Hidden State Mean')
plt.title('Mean vs Remaining Tokens')
plt.grid(True, alpha=0.3)

# Std vs Remaining Tokens
plt.subplot(1, 3, 3)
plt.scatter(train_remaining_tokens, hidden_stds, alpha=0.1, s=1)
plt.xlabel('Remaining Tokens')
plt.ylabel('Hidden State Std Dev')
plt.title('Std Dev vs Remaining Tokens')
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## Sample Sequence Analysis

Let's look at how remaining tokens decrease during generation for specific examples.

In [None]:
# Find samples that form sequences (consecutive remaining token counts)
# This assumes the data was generated sequentially before shuffling

print("Examples of how remaining tokens change during generation:")
print("="*60)

# Show first 20 samples with their tokens
print("\nFirst 20 samples (may show part of a generation sequence):")
for i in range(min(20, len(train_remaining_tokens))):
    token_text = train_token_metadata[i]['token_text']
    token_id = train_token_metadata[i]['token_id']
    print(f"Sample {i:3d}: token='{token_text:20s}' (id={token_id:6d}), remaining={train_remaining_tokens[i]:3d}")

# Find sequences where remaining tokens decrease by 1
print("\n\nLooking for consecutive decreasing sequences...")
sequence_starts = []
for i in range(len(train_remaining_tokens) - 5):
    # Check if we have a decreasing sequence
    is_sequence = all(
        train_remaining_tokens[i+j] - train_remaining_tokens[i+j+1] == 1
        for j in range(4)
    )
    if is_sequence:
        sequence_starts.append(i)
        if len(sequence_starts) >= 3:  # Show first 3 found
            break

if sequence_starts:
    for seq_start in sequence_starts:
        print(f"\nSequence starting at index {seq_start}:")
        for j in range(10):
            if seq_start + j < len(train_remaining_tokens):
                token_text = train_token_metadata[seq_start + j]['token_text']
                print(f"  Step {j}: token='{token_text:20s}' remaining={train_remaining_tokens[seq_start + j]}")
else:
    print("No obvious sequences found (data is shuffled)")