# 12.4g: bfloat16 Bit Pattern Analysis

**Goal:** Analyze the raw bfloat16 bit patterns of Qwen's 13 black holes to understand how they actually differ.

## The Question

We've been doing float32 math on bfloat16 data, which obscures the actual bit-level structure.

By looking at the **raw 16-bit representations**, we can see:
1. Which dimensions are **identical** across all 13 black holes (same bit pattern)
2. Which dimensions **vary**, and by how much (bit-level differences)
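3. Whether the differences are 1 ULP steps (adjacent representable values) or larger. A 1 ULP step is not the same thing as a single-bit flip, because a carry can change several bits at once.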

## bfloat16 Format

Each bfloat16 value is 16 bits:
```
[sign: 1 bit][exponent: 8 bits][mantissa: 7 bits]
```
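To make the layout concrete, here is a minimal pure-Python sketch that splits a raw 16-bit pattern into its three fields and reconstructs the value. It relies on bfloat16 being the upper half of an IEEE-754 float32: padding the pattern with 16 zero bits and unpacking it with `struct` gives the exact value. The helper name `decode_bf16` is hypothetical, and NaN/infinity handling is ignored.

```python
import struct

def decode_bf16(bits: int) -> tuple[int, int, int, float]:
    """Split a raw bfloat16 pattern into (sign, exponent, mantissa, value)."""
    bits &= 0xFFFF                    # also accepts negative int16 views
    sign = bits >> 15                 # 1 bit
    exponent = (bits >> 7) & 0xFF     # 8 bits, bias 127
    mantissa = bits & 0x7F            # 7 bits
    # bfloat16 is the top half of a float32, so append 16 zero bits and unpack.
    (value,) = struct.unpack("<f", struct.pack("<I", bits << 16))
    return sign, exponent, mantissa, value

# 0x3bc7 is the pattern reported for dimension 0 of the first black hole later on.
print(decode_bf16(0x3bc7))  # (0, 119, 71, 0.006072998046875)
```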

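A difference of "1 ULP" (unit in the last place) means the two values are adjacent representable bfloat16 numbers. For two finite values with the same sign, this is equivalent to their raw 16-bit encodings differing by exactly 1 as integers. Because the increment can carry through the low mantissa bits (and even into the exponent), a 1 ULP step may flip more than one bit.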
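A small sketch of the distinction between ULP distance and flipped bits. The hypothetical helper `bf16_order_key` maps each pattern onto an integer scale that increases with the float value (negative patterns are mirrored below zero), so ULP distance is just a difference of keys. The pattern pairs are taken from the varying-dimension table later in this notebook.

```python
def bf16_order_key(bits: int) -> int:
    """Map a raw bfloat16 pattern to an integer that increases with the float value.

    Positive patterns keep their encoding; negative patterns are mirrored below
    zero, so adjacent representable values differ by exactly 1 (NaNs not handled).
    """
    bits &= 0xFFFF
    return bits if bits < 0x8000 else -(bits & 0x7FFF)

def ulp_distance(a: int, b: int) -> int:
    """Number of representable bfloat16 steps between two patterns."""
    return abs(bf16_order_key(a) - bf16_order_key(b))

def flipped_bits(a: int, b: int) -> int:
    """Number of bit positions that differ between two 16-bit patterns."""
    return bin((a ^ b) & 0xFFFF).count("1")

for a, b in [(0xbb68, 0xbb69), (0xbadf, 0xbae0), (0x3b5f, 0x3b61)]:
    print(f"0x{a:04x} vs 0x{b:04x}: ulp={ulp_distance(a, b)}, flipped bits={flipped_bits(a, b)}")
# 0xbb68 vs 0xbb69: ulp=1, flipped bits=1
# 0xbadf vs 0xbae0: ulp=1, flipped bits=6
# 0x3b5f vs 0x3b61: ulp=2, flipped bits=5
```

The middle pair is the case that a "single-bit flip" test misses: the values are neighbours, but the carry from `...df` to `...e0` touches six bits.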

## Method

1. Load γ, cast it to **bfloat16** if it is stored in another dtype, and extract the dead tokens (no float32 arithmetic!)
2. Find the 13 unique black hole vectors
3. Reinterpret as `int16` to get raw bit patterns
4. For each dimension, analyze the unique bit patterns
5. Show which dimensions are constant vs varying

## Parameters

In [1]:
# Input paths
GAMMA_PATH = "../data/tensors/gamma_qwen3_4b_instruct_2507.safetensors"
MASK_PATH = "../data/tensors/black_hole_mask.safetensors"

RANDOM_SEED = 42

## Imports

In [2]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from safetensors.torch import load_file
from collections import Counter

torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

print("✓ Imports complete")

✓ Imports complete


## Load Data (Keep as bfloat16!)

In [3]:
print("Loading data (keeping bfloat16 precision)...\n")

# Load gamma - check its dtype
gamma_data = load_file(GAMMA_PATH)
gamma_raw = gamma_data['gamma']
print(f"✓ Loaded γ")
print(f"  Shape: {gamma_raw.shape}")
print(f"  Dtype: {gamma_raw.dtype}")

# Convert to bfloat16 if needed (exact only if the stored float32 values came from bfloat16)
if gamma_raw.dtype != torch.bfloat16:
    print(f"  Converting from {gamma_raw.dtype} to bfloat16...")
    gamma = gamma_raw.to(torch.bfloat16)
else:
    gamma = gamma_raw

# Load black hole mask
mask_data = load_file(MASK_PATH)
mask = mask_data['mask']
print(f"\n✓ Loaded black hole mask")
print(f"  Dead tokens: {mask.sum().item():,}")

# Extract dead token embeddings (KEEP AS BFLOAT16)
dead_tokens = gamma[mask]
print(f"\n✓ Extracted dead token embeddings")
print(f"  Shape: {dead_tokens.shape}")
print(f"  Dtype: {dead_tokens.dtype}")

Loading data (keeping bfloat16 precision)...

✓ Loaded γ
  Shape: torch.Size([151936, 2560])
  Dtype: torch.float32
  Converting from torch.float32 to bfloat16...

✓ Loaded black hole mask
  Dead tokens: 2,100

✓ Extracted dead token embeddings
  Shape: torch.Size([2100, 2560])
  Dtype: torch.bfloat16
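The file stores γ as float32, and the cell above casts it to bfloat16. That cast is exact only if every float32 value already fits in bfloat16, i.e. its low 16 bits are zero (as they would be if the tensor had been upcast from bfloat16 weights). A minimal check, sketched on synthetic data; the helper name `bf16_cast_is_lossless` is hypothetical, and on the real data it would be called as `bf16_cast_is_lossless(gamma_raw)`.

```python
import torch

def bf16_cast_is_lossless(x: torch.Tensor) -> bool:
    """True if casting a float32 tensor to bfloat16 and back reproduces it exactly."""
    assert x.dtype == torch.float32
    roundtrip = x.to(torch.bfloat16).to(torch.float32)
    # Compare raw bit patterns so that e.g. -0.0 vs 0.0 counts as a difference.
    return torch.equal(roundtrip.view(torch.int32), x.view(torch.int32))

torch.manual_seed(0)
# Values that were already bfloat16 survive the round trip; generic float32 values do not.
from_bf16 = torch.randn(1000).to(torch.bfloat16).to(torch.float32)
arbitrary = torch.randn(1000)
print(bf16_cast_is_lossless(from_bf16), bf16_cast_is_lossless(arbitrary))
```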


## Find Unique Vectors (in bfloat16)

In [4]:
print("\nFinding unique vectors (staying in bfloat16)...\n")

unique_vectors, inverse_indices, counts = torch.unique(
    dead_tokens,
    dim=0,
    return_inverse=True,
    return_counts=True
)

print(f"✓ Found {len(unique_vectors)} unique vectors")
print(f"  Dtype: {unique_vectors.dtype}")
print(f"\nPopulations: {sorted(counts.tolist(), reverse=True)}")


Finding unique vectors (staying in bfloat16)...

✓ Found 13 unique vectors
  Dtype: torch.bfloat16

Populations: [814, 704, 306, 228, 11, 10, 6, 5, 4, 4, 3, 3, 2]
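`torch.unique(..., dim=0)` also returned `inverse_indices`, which maps every dead token to its black hole. A small sketch on a toy float32 tensor (not the real data) showing that `torch.bincount` of the inverse reproduces `counts`, and how to list black hole indices from most to least populous. The unique rows are not in population order, which is why the sort is needed.

```python
import torch

# Toy stand-in for dead_tokens: 6 rows, only 3 distinct vectors.
toy = torch.tensor([[1.0, 2.0], [3.0, 4.0], [1.0, 2.0],
                    [5.0, 6.0], [1.0, 2.0], [3.0, 4.0]])

uniq, inverse, counts = torch.unique(toy, dim=0, return_inverse=True, return_counts=True)

# Indexing the unique rows with the inverse rebuilds the original rows.
assert torch.equal(uniq[inverse], toy)
# Population of each unique vector = how often its index appears in `inverse`.
assert torch.equal(torch.bincount(inverse, minlength=len(uniq)), counts)

# Black hole indices ordered by population, largest first.
order = torch.argsort(counts, descending=True)
print(order.tolist(), counts[order].tolist())
```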


## Convert to Raw 16-bit Integer Representation

In [5]:
print("\nConverting to raw 16-bit integer representation...\n")

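# Reinterpret bfloat16 bytes as int16 (same 16 bits, no value conversion).
# Note: int16 is signed, so patterns with the sign bit set show up as negative
# numbers; we add 65536 to those when formatting them as unsigned hex/binary.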
unique_as_bits = unique_vectors.view(torch.int16)

print(f"✓ Converted to int16 view")
print(f"  Shape: {unique_as_bits.shape}")
print(f"  Dtype: {unique_as_bits.dtype}")

# Example: show first vector's first few components in different formats
print(f"\nExample - First vector, first 5 dimensions:")
for i in range(5):
    bf16_val = unique_vectors[0, i].item()
    bits_val = unique_as_bits[0, i].item()
    # Convert to unsigned for binary display
    unsigned_bits = bits_val if bits_val >= 0 else bits_val + 65536
    binary = format(unsigned_bits, '016b')
    hex_val = format(unsigned_bits, '04x')
    
    print(f"  Dim {i}: {bf16_val:+.6e} = 0x{hex_val} = 0b{binary}")


Converting to raw 16-bit integer representation...

✓ Converted to int16 view
  Shape: torch.Size([13, 2560])
  Dtype: torch.int16

Example - First vector, first 5 dimensions:
  Dim 0: +6.072998e-03 = 0x3bc7 = 0b0011101111000111
  Dim 1: +1.324463e-02 = 0x3c59 = 0b0011110001011001
  Dim 2: +1.177979e-02 = 0x3c41 = 0b0011110001000001
  Dim 3: +3.662109e-02 = 0x3d16 = 0b0011110100010110
  Dim 4: +1.586914e-02 = 0x3c82 = 0b0011110010000010
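The cell above converts negative `int16` views to unsigned one element at a time. A vectorized alternative widens to `int32` and masks with `0xFFFF`, which gives the same result as adding 65536 to negative values; the exponent and mantissa fields then fall out of shifts and masks. This sketch uses two values copied from this notebook's output rather than the real tensors.

```python
import torch

# Both values are exactly representable in bfloat16 (patterns 0x3bc7 and 0xbae0).
vals = torch.tensor([0.006072998046875, -0.001708984375], dtype=torch.bfloat16)

signed = vals.view(torch.int16)                # raw patterns; sign bit makes some negative
unsigned = signed.to(torch.int32) & 0xFFFF     # widen first, then keep the low 16 bits

exponent = (unsigned >> 7) & 0xFF              # 8-bit biased exponent
mantissa = unsigned & 0x7F                     # 7-bit mantissa

print([f"0x{u:04x}" for u in unsigned.tolist()])  # ['0x3bc7', '0xbae0']
print(exponent.tolist(), mantissa.tolist())       # [119, 117] [71, 96]
```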


## Analyze Per-Dimension Bit Patterns

In [6]:
print("\nAnalyzing per-dimension bit patterns...\n")

n_unique, n_dims = unique_as_bits.shape

# For each dimension, count unique bit patterns
constant_dims = []
varying_dims = []

for dim in range(n_dims):
    values_in_dim = unique_as_bits[:, dim]  # [13] int16 values
    unique_bit_patterns = torch.unique(values_in_dim)
    
    if len(unique_bit_patterns) == 1:
        # All 13 vectors have the same value in this dimension
        constant_dims.append(dim)
    else:
        # This dimension varies
        varying_dims.append({
            'dim': dim,
            'n_unique_patterns': len(unique_bit_patterns),
            'patterns': unique_bit_patterns.tolist(),
        })

print(f"Dimension classification:")
print(f"  Constant dimensions: {len(constant_dims)} ({len(constant_dims)/n_dims*100:.1f}%)")
print(f"  Varying dimensions: {len(varying_dims)} ({len(varying_dims)/n_dims*100:.1f}%)")


Analyzing per-dimension bit patterns...

Dimension classification:
  Constant dimensions: 2540 (99.2%)
  Varying dimensions: 20 (0.8%)
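The per-dimension loop above can also be written without Python iteration. A sketch, assuming an `[n_vectors, n_dims]` integer tensor of bit patterns such as `unique_as_bits`; the helper name `classify_dims` is hypothetical. It sorts each column and counts the positions where consecutive entries change, which yields the number of distinct patterns per dimension.

```python
import torch

def classify_dims(bits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Return (is_constant mask, number of distinct patterns) for each column."""
    bits = bits.to(torch.int32)                      # widen for sorting/comparison
    is_constant = (bits == bits[0:1]).all(dim=0)
    # In each sorted column, a new pattern starts wherever neighbours differ.
    sorted_bits, _ = torch.sort(bits, dim=0)
    n_patterns = 1 + (sorted_bits[1:] != sorted_bits[:-1]).sum(dim=0)
    return is_constant, n_patterns

toy = torch.tensor([[5, 7, 9, 1],
                    [5, 8, 9, 2],
                    [5, 7, 9, 3]], dtype=torch.int16)
is_constant, n_patterns = classify_dims(toy)
print(is_constant.tolist(), n_patterns.tolist())  # [True, False, True, False] [1, 2, 1, 3]
```

On the real data, `classify_dims(unique_as_bits)` followed by `torch.bincount(n_patterns[~is_constant])` gives a distribution that can be compared with the `Counter` in the next cell.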


## Analyze Varying Dimensions

In [7]:
print(f"\nDetailed analysis of varying dimensions:")
print(f"{'='*80}\n")

# Count how many dimensions have 2 patterns, 3 patterns, etc.
pattern_count_distribution = Counter([d['n_unique_patterns'] for d in varying_dims])

print(f"Distribution of unique patterns per dimension:")
for n_patterns in sorted(pattern_count_distribution.keys()):
    count = pattern_count_distribution[n_patterns]
    print(f"  {n_patterns} unique patterns: {count} dimensions")

print(f"\nFirst 20 varying dimensions (detailed):")
print(f"{'Dim':>6} {'N patterns':>12} {'Bit patterns (hex)':>40}")
print("-" * 80)

for i, info in enumerate(varying_dims[:20]):
    dim = info['dim']
    n_patterns = info['n_unique_patterns']
    patterns = info['patterns']
    
    # Convert to unsigned hex for display
    hex_patterns = []
    for p in patterns:
        unsigned = p if p >= 0 else p + 65536
        hex_patterns.append(f"0x{unsigned:04x}")
    
    patterns_str = ', '.join(hex_patterns)
    if len(patterns_str) > 40:
        patterns_str = patterns_str[:37] + '...'
    
    print(f"{dim:>6} {n_patterns:>12} {patterns_str:>40}")


Detailed analysis of varying dimensions:

Distribution of unique patterns per dimension:
  2 unique patterns: 17 dimensions
  3 unique patterns: 2 dimensions
  7 unique patterns: 1 dimensions

First 20 varying dimensions (detailed):
   Dim   N patterns                       Bit patterns (hex)
--------------------------------------------------------------------------------
   216            2                           0xbadf, 0xbae0
   282            2                           0xbb68, 0xbb69
   322            2                           0x3b5f, 0x3b61
   450            2                           0xb954, 0xb955
   993            2                           0xb9a4, 0xb9a5
  1008            3                   0xb5e7, 0xb5e8, 0xb5ec
  1149            2                           0x3802, 0x3803
  1155            2                           0xba5d, 0xba5e
  1272            2                           0xb90a, 0xb90b
  1382            7 0xb595, 0xb596, 0xb597, 0xb598, 0xb59...
  1403        
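The table lists which patterns occur but not how far apart they are. A sketch that reports the ULP spread (largest minus smallest value, measured in bfloat16 steps) for each column; `bf16_order_key_tensor` and `ulp_spread` are hypothetical helpers, and the sample columns are patterns copied from dimensions 216, 322 and 1008 in the table above. The toy tensor is `int32` because literals like `0xbadf` do not fit in `int16`.

```python
import torch

def bf16_order_key_tensor(bits: torch.Tensor) -> torch.Tensor:
    """Map raw bfloat16 patterns (any integer dtype) to int32 keys ordered like the floats."""
    u = bits.to(torch.int32) & 0xFFFF
    # Non-negative values keep their encoding; negative ones become -(magnitude bits).
    return torch.where(u < 0x8000, u, -(u & 0x7FFF))

def ulp_spread(bits: torch.Tensor) -> torch.Tensor:
    """Per-column distance in ULPs between the smallest and largest value."""
    keys = bf16_order_key_tensor(bits)
    return keys.max(dim=0).values - keys.min(dim=0).values

# Columns: dims 216, 322 and 1008 from the table above (rows may repeat a pattern).
patterns = torch.tensor([[0xbadf, 0x3b5f, 0xb5e7],
                         [0xbae0, 0x3b61, 0xb5e8],
                         [0xbae0, 0x3b61, 0xb5ec]], dtype=torch.int32)
print(ulp_spread(patterns).tolist())  # [1, 2, 5]
```

On the real data this would be applied as `ulp_spread(unique_as_bits[:, [d['dim'] for d in varying_dims]])`.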

## Compute Hamming Distances Between Black Holes

In [8]:
print(f"\nComputing Hamming distances (bit-level differences)...\n")

# Hamming distance over dimensions: for each pair of vectors, count the
# dimensions whose 16-bit patterns differ (not the number of differing bits)
hamming_matrix = torch.zeros(n_unique, n_unique, dtype=torch.int32)

for i in range(n_unique):
    for j in range(i+1, n_unique):
        # Count dimensions where bit patterns differ
        diff_mask = unique_as_bits[i] != unique_as_bits[j]
        hamming_dist = diff_mask.sum().item()
        hamming_matrix[i, j] = hamming_dist
        hamming_matrix[j, i] = hamming_dist

print(f"Hamming distance statistics (dimensions that differ):")
# Off-diagonal entries, one per pair (upper triangle)
iu = torch.triu_indices(n_unique, n_unique, offset=1)
non_diag = hamming_matrix[iu[0], iu[1]]
print(f"  Min: {non_diag.min().item()}")
print(f"  Max: {non_diag.max().item()}")
print(f"  Mean: {non_diag.float().mean().item():.1f}")
print(f"  Median: {non_diag.float().median().item():.0f}")

# Show the matrix
print(f"\nHamming distance matrix:")
print(f"(Number of dimensions where bit patterns differ)\n")
print(hamming_matrix.cpu().numpy())


Computing Hamming distances (bit-level differences)...

Hamming distance statistics (dimensions that differ):
  Min: 1
  Max: 14
  Mean: 6.8
  Median: 6

Hamming distance matrix:
(Number of dimensions where bit patterns differ)

[[ 0  9  7  3  6  5  6  6  7 10  9  7  7]
 [ 9  0 11  9 12 12 13 12 13 14 13 14 11]
 [ 7 11  0  5  8  8  9  8  9 12 11 10  9]
 [ 3  9  5  0  6  6  7  6  4  7  6  8  5]
 [ 6 12  8  6  0  1  2  2  6  7  6  3  6]
 [ 5 12  8  6  1  0  1  2  6  7  6  2  6]
 [ 6 13  9  7  2  1  0  3  5  6  5  1  5]
 [ 6 12  8  6  2  2  3  0  7  8  7  4  6]
 [ 7 13  9  4  6  6  5  7  0  4  4  6  5]
 [10 14 12  7  7  7  6  8  4  0  1  7  6]
 [ 9 13 11  6  6  6  5  7  4  1  0  6  5]
 [ 7 14 10  8  3  2  1  4  6  7  6  0  4]
 [ 7 11  9  5  6  6  5  6  5  6  5  4  0]]
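The matrix above counts differing *dimensions*. A sketch that computes it with broadcasting instead of a double loop, and alongside it a true bit-level Hamming distance (popcount of the XOR, summed over dimensions). The helper name `pairwise_distances` is hypothetical; on the real `[13, 2560]` tensor the intermediate `[13, 13, 2560, 16]` array is roughly 27 MB as int32. The toy rows reuse patterns from dimensions 216 and 322.

```python
import torch

def pairwise_distances(bits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """For [n, d] raw 16-bit patterns, return two [n, n] matrices:
    the number of differing dimensions and the number of differing bits."""
    u = bits.to(torch.int32) & 0xFFFF
    xor = u[:, None, :] ^ u[None, :, :]                # [n, n, d] via broadcasting
    dims_differ = (xor != 0).sum(dim=-1)
    # Popcount of each 16-bit XOR: extract every bit position and add them up.
    shifts = torch.arange(16, dtype=torch.int32)
    bits_differ = ((xor[..., None] >> shifts) & 1).sum(dim=(-1, -2))
    return dims_differ, bits_differ

toy = torch.tensor([[0xbadf, 0x3b5f],
                    [0xbae0, 0x3b5f],
                    [0xbae0, 0x3b61]], dtype=torch.int32)
dims_differ, bits_differ = pairwise_distances(toy)
print(dims_differ.tolist())  # [[0, 1, 2], [1, 0, 1], [2, 1, 0]]
print(bits_differ.tolist())  # [[0, 6, 11], [6, 0, 5], [11, 5, 0]]
```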


## Deep Dive: Pick One Varying Dimension

In [None]:
if len(varying_dims) > 0:
    # Pick the first varying dimension
    example_dim_info = varying_dims[0]
    example_dim = example_dim_info['dim']
    
    print(f"\nDeep dive into dimension {example_dim}:")
    print(f"{'='*80}\n")
    
    print(f"All 13 black holes' values in this dimension:")
    print(f"{'BH':>4} {'Population':>12} {'bfloat16 value':>18} {'Hex':>8} {'Binary':>18}")
    print("-" * 80)
    
    for bh_idx in range(n_unique):
        pop = counts[bh_idx].item()
        bf16_val = unique_vectors[bh_idx, example_dim].item()
        bits_val = unique_as_bits[bh_idx, example_dim].item()
        unsigned_bits = bits_val if bits_val >= 0 else bits_val + 65536
        hex_val = format(unsigned_bits, '04x')
        binary = format(unsigned_bits, '016b')
        
        print(f"{bh_idx:>4} {pop:>12} {bf16_val:>+18.10e} 0x{hex_val} 0b{binary}")
    
    # Analyze bit differences
    print(f"\nBit-level differences:")
    all_vals = unique_as_bits[:, example_dim]
    unique_vals = torch.unique(all_vals)
    
    if len(unique_vals) == 2:
        v1 = unique_vals[0].item()
        v2 = unique_vals[1].item()
        u1 = v1 if v1 >= 0 else v1 + 65536
        u2 = v2 if v2 >= 0 else v2 + 65536
        
        xor = u1 ^ u2
        n_bits_diff = bin(xor).count('1')
        
        print(f"  Two unique values: 0x{u1:04x} and 0x{u2:04x}")
        print(f"  XOR: 0b{xor:016b}")
        print(f"  Number of bits that differ: {n_bits_diff}")
        
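        # ULP distance via a monotonic integer key (negative patterns mirrored
        # below zero). This differs from a single-bit flip: a carry such as
        # 0xbadf -> 0xbae0 is a 1 ULP step that flips 6 bits.
        k1 = u1 if u1 < 0x8000 else -(u1 & 0x7FFF)
        k2 = u2 if u2 < 0x8000 else -(u2 & 0x7FFF)
        print(f"  ULP distance: {abs(k1 - k2)}")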
else:
    print(f"\nNo varying dimensions found (all vectors identical?!)")


Deep dive into dimension 216:

All 13 black holes' values in this dimension:
  BH   Population     bfloat16 value      Hex             Binary
--------------------------------------------------------------------------------
   0           10  -1.7089843750e-03 0xbae0 0b1011101011100000
   1            4  -1.7089843750e-03 0xbae0 0b1011101011100000
   2            5  -1.7089843750e-03 0xbae0 0b1011101011100000
   3          306  -1.7089843750e-03 0xbae0 0b1011101011100000
   4            4  -1.7089843750e-03 0xbae0 0b1011101011100000
   5          814  -1.7089843750e-03 0xbae0 0b1011101011100000
   6          228  -1.7089843750e-03 0xbae0 0b1011101011100000
   7            3  -1.7089843750e-03 0xbae0 0b1011101011100000
   8            3  -1.7089843750e-03 0xbae0 0b1011101011100000
   9            2  -1.7089843750e-03 0xbae0 0b1011101011100000
  10          704  -1.7089843750e-03 0xbae0 0b1011101011100000
  11           11  -1.7013549805e-03 0xbadf 0b1011101011011111
  12            6  -
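In dimension 216, only black hole 11 departs from the pattern shared by the others. A sketch that generalizes this view to any dimension by reporting each black hole's signed ULP offset from the most populous black hole; the helper name `offsets_from_reference` is hypothetical, and choosing the most populous vector as the reference is an assumption made for illustration. The toy inputs below are not the real data.

```python
import torch

def offsets_from_reference(bits: torch.Tensor, counts: torch.Tensor, dim: int) -> dict[int, int]:
    """Signed ULP offset of each vector from the most populous vector in one dimension.

    Only vectors whose pattern differs from the reference are returned.
    """
    u = bits[:, dim].to(torch.int32) & 0xFFFF
    keys = torch.where(u < 0x8000, u, -(u & 0x7FFF))   # monotonic in the float value
    ref = int(torch.argmax(counts))
    offsets = keys - keys[ref]
    return {i: int(o) for i, o in enumerate(offsets.tolist()) if o != 0}

# Toy data: three vectors, one dimension; vector 1 has the largest population.
toy_bits = torch.tensor([[0xbae0], [0xbae0], [0xbadf]], dtype=torch.int32)
toy_counts = torch.tensor([10, 814, 11])
print(offsets_from_reference(toy_bits, toy_counts, dim=0))  # {2: 1}
```

A positive offset means the value is one or more bfloat16 steps larger than the reference (here, -1.7013549805e-03 versus -1.7089843750e-03).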

## Summary

In [10]:
print(f"\n{'='*80}")
print(f"BIT PATTERN ANALYSIS SUMMARY")
print(f"{'='*80}")
print(f"\n{n_unique} unique black hole vectors in bfloat16")
print(f"{n_dims} dimensions total\n")

print(f"Dimension classification:")
print(f"  Constant: {len(constant_dims)} dimensions ({len(constant_dims)/n_dims*100:.1f}%)")
print(f"  Varying:  {len(varying_dims)} dimensions ({len(varying_dims)/n_dims*100:.1f}%)")

if len(varying_dims) > 0:
    print(f"\nPattern distribution in varying dimensions:")
    for n_patterns in sorted(pattern_count_distribution.keys()):
        count = pattern_count_distribution[n_patterns]
        print(f"  {count} dims have {n_patterns} unique bit patterns")

print(f"\nHamming distances (dimensions with different bits):")
print(f"  Range: {non_diag.min().item()} to {non_diag.max().item()}")
print(f"  Mean: {non_diag.float().mean().item():.1f} dimensions differ per pair")

print(f"\n{'='*80}")


BIT PATTERN ANALYSIS SUMMARY

13 unique black hole vectors in bfloat16
2560 dimensions total

Dimension classification:
  Constant: 2540 dimensions (99.2%)
  Varying:  20 dimensions (0.8%)

Pattern distribution in varying dimensions:
  17 dims have 2 unique bit patterns
  2 dims have 3 unique bit patterns
  1 dims have 7 unique bit patterns

Hamming distances (dimensions with different bits):
  Range: 1 to 14
  Mean: 6.8 dimensions differ per pair

