# 1.33a: Lattice Displacement Calculation (Revised)

Computing displacements in lattice-cell coordinates for all tokens, with proper validation and memory management.

## Memory Budget

Maximum 24 GB RAM at any one time. We maximize usage up to that ceiling.

## The Algorithm

For a displacement ΔW = W[t+1] - W[t], the lattice coordinate displacement is:

$$\Delta W' = \frac{\Delta W}{U[t]}$$

where U[t] = ULP(W[t]) is the lattice spacing at the starting position.

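**Expected result**: for nonzero W[t], ΔW′ should be an exact integer whenever |W[t+1]| lies in the same binade as |W[t]| or a higher one, because every bfloat16 value there is a multiple of U[t]. When W[t+1] falls into a lower (finer) binade, ΔW′ can be a dyadic fraction: stepping from W[t] = 1.0 (U[t] = 2^-7) to 1 - 2^-8 gives ΔW′ = -0.5.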
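To see when ΔW′ is and is not an integer, here is a minimal pure-Python sketch on single coordinates. `bf16_ulp` and `lattice_displacement` are hypothetical helpers that assume a nonzero, finite, normal bfloat16 input; `bf16_ulp` is meant to mirror the upward ULP that `compute_ulp` obtains with `torch.nextafter`. The third case shows that a step into a finer (smaller-magnitude) binade yields a fractional displacement.

```python
import math

def bf16_ulp(x: float) -> float:
    """Upward ULP of a nonzero, finite, normal bfloat16 value (hypothetical helper).

    bfloat16 keeps 8 significant bits, so values in the binade
    [2**k, 2**(k + 1)) are spaced 2**(k - 7) apart.
    """
    if x == 0 or not math.isfinite(x):
        raise ValueError("sketch assumes a nonzero finite input")
    _, e = math.frexp(abs(x))      # abs(x) = m * 2**e with 0.5 <= m < 1
    return math.ldexp(1.0, e - 8)  # binade is [2**(e - 1), 2**e)

def lattice_displacement(w_t: float, w_t1: float) -> float:
    """Lattice displacement (W[t+1] - W[t]) / ULP(W[t]) for one coordinate."""
    return (w_t1 - w_t) / bf16_ulp(w_t)

print(lattice_displacement(1.0, 1.0 + 3 * 2**-7))  # 3.0: same binade
print(lattice_displacement(1.0, 2.0))              # 128.0: coarser binade above
print(lattice_displacement(1.0, 1.0 - 2**-8))      # -0.5: finer binade below
```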

## Validation Strategy

Compute delta_W_prime in **float32** to detect non-integer results:

- `frac = |x - round(x)|` (distance to the nearest integer)
- `frac == 0`: ✅ Exact integer (expected)
- `0 < frac < ε`: ⚠️ Float rounding error (needs attention, but the algorithm is basically correct)
- `frac >= ε`: 🚨 GENERAL QUARTERS - algorithm is broken

**Float32 limitation**: every integer with magnitude up to 2^24 = 16,777,216 is exactly representable; beyond that, some integers are skipped. We check for overflow.
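The classification can be sketched in NumPy on a handful of values. This sketch measures `frac` as the distance to the nearest integer: with `x - floor(x)`, a value sitting just below an integer, such as 2.9999998, would give a fraction near 1 and be misread as non-integer. `classify` is a hypothetical helper; the threshold follows the `EPSILON` used in this notebook.

```python
import numpy as np

EPSILON = 1e-6

def classify(coords: np.ndarray, eps: float = EPSILON) -> dict:
    """Count exact, almost, half-integer and non-integer lattice coordinates."""
    x = coords.astype(np.float32)
    dist = np.abs(x - np.round(x))  # distance to nearest integer, in [0, 0.5]
    half = np.abs(dist - 0.5) < eps
    return {
        "exact": int(np.count_nonzero(dist == 0)),
        "almost": int(np.count_nonzero((dist > 0) & (dist < eps))),
        "half": int(np.count_nonzero(half)),
        "non_integer": int(np.count_nonzero((dist >= eps) & ~half)),
    }

sample = np.array([3.0, -7.0, 2.9999998, -0.5, 0.25], dtype=np.float32)
print(classify(sample))
# {'exact': 2, 'almost': 1, 'half': 1, 'non_integer': 1}
```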

## What We DON'T Do (Yet)

We do NOT save results until we know which integer type is needed. The save cells are written but gated on `SAVE_DTYPE`, which stays `None` until the range analysis settles the type.

## Parameters

In [1]:
THIMBLE_PATH = "../tensors/Thimble/thimble_7.h5"
OUTPUT_PATH = "../tensors/Thimble/1.33a_lattice_displacements.safetensors"

# Validation thresholds
EPSILON = 1e-6  # Tolerance for "almost integer" (float32 rounding error)

# Float32 integer fidelity limit
FLOAT32_INT_LIMIT = 2**24  # 16,777,216

RANDOM_SEED = 42

## Imports

In [2]:
import h5py
import torch
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from safetensors.torch import save_file

np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

print("✓ Imports complete")

✓ Imports complete


## Device Detection

In [3]:
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

print(f"Using device: {device}")

Using device: mps


## Load Data

In [4]:
print(f"Loading data from {THIMBLE_PATH}...\n")

with h5py.File(THIMBLE_PATH, 'r') as f:
    # Load W for all tokens in bfloat16: (6001, 10000, 64)
    W = torch.from_numpy(f['W'][:]).view(torch.bfloat16).to(device)
    
    # Load dead token mask
    dead_mask = torch.from_numpy(f['dead_mask'][:]).bool()

n_steps, n_tokens, n_dims = W.shape
n_dead = dead_mask.sum().item()
dead_token_ids = torch.where(dead_mask)[0]

print(f"✓ Loaded W")
print(f"  Shape: {W.shape}")
print(f"  Dtype: {W.dtype}")
print(f"  Steps: {n_steps}, Tokens: {n_tokens}, Dimensions: {n_dims}")
print(f"  Dead tokens: {n_dead} ({n_dead/n_tokens:.1%})")
print(f"  Memory: {W.element_size() * W.nelement() / 1e9:.2f} GB")

Loading data from ../tensors/Thimble/thimble_7.h5...

✓ Loaded W
  Shape: torch.Size([6001, 10000, 64])
  Dtype: torch.bfloat16
  Steps: 6001, Tokens: 10000, Dimensions: 64
  Dead tokens: 3699 (37.0%)
  Memory: 7.68 GB


## Compute ULP Matrix

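**Note:** We compute ULP on the CPU in bfloat16, because `torch.nextafter` does not support bfloat16 on MPS. Computing it in float32 instead would return the float32 spacing, which is far finer than the bfloat16 lattice spacing we need.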

In [5]:
print("Computing ULP matrix...\n")

def compute_ulp(x):
    """
    Compute ULP for bfloat16 values, handling zeros correctly.
    
    For non-zero values: ULP = nextafter(|x|, |x|+1) - |x|
    For exact zeros: ULP = smallest positive normal bfloat16 = 2**-126 ≈ 1.175e-38
    
    Returns a tensor with the same dtype and device as x.
    """
    x_abs = x.abs()
    ulp = torch.nextafter(x_abs, x_abs + torch.ones_like(x_abs)) - x_abs
    
    # For exact zeros, use smallest normal bfloat16
    min_normal_bf16 = torch.tensor(2.0**-126, dtype=torch.bfloat16, device=x.device)
    ulp = torch.where(x == 0, min_normal_bf16, ulp)
    
    return ulp

# Move W to CPU for ULP calculation (torch.nextafter doesn't work for bfloat16 on MPS)
print("  Moving W to CPU for ULP calculation...")
W_cpu = W.cpu()

# Compute ULP on CPU
print("  Computing ULP on CPU (this takes ~5-6 seconds)...")
U_cpu = compute_ulp(W_cpu)

# Move result back to device
print(f"  Moving U to {device}...")
U = U_cpu.to(device)

# Can free CPU copy
del W_cpu, U_cpu

print(f"\n✓ ULP matrix computed")
print(f"  Shape: {U.shape}")
print(f"  Dtype: {U.dtype}")
print(f"  Memory: {U.element_size() * U.nelement() / 1e9:.2f} GB")

# Diagnostic: count exact zeros
n_zeros = (W == 0).sum().item()
print(f"\n  Exact zeros in W: {n_zeros:,} ({n_zeros/W.numel():.6%})")
print(f"  Min ULP: {U.min().item():.6e}")
print(f"  Max ULP: {U.max().item():.6e}")

# Current memory usage
mem_W = W.element_size() * W.nelement() / 1e9
mem_U = U.element_size() * U.nelement() / 1e9
print(f"\n  Current memory usage: {mem_W + mem_U:.2f} GB (W + U)")

Computing ULP matrix...

  Moving W to CPU for ULP calculation...
  Computing ULP on CPU (this takes ~5-6 seconds)...
  Moving U to mps...

✓ ULP matrix computed
  Shape: torch.Size([6001, 10000, 64])
  Dtype: torch.bfloat16
  Memory: 7.68 GB

  Exact zeros in W: 263 (0.000007%)
  Min ULP: 1.175494e-38
  Max ULP: 3.906250e-03

  Current memory usage: 15.36 GB (W + U)
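For reference, the same ULP can be read straight off the bfloat16 bit patterns without `torch.nextafter`, which would sidestep both the MPS limitation and the CPU round trip. The sketch below uses NumPy on `uint16` bit patterns; it assumes the HDF5 data is stored as 2-byte elements (the Load Data cell reinterprets it with `.view(torch.bfloat16)`, which requires that). `bf16_bits` and `ulp_from_bits` are hypothetical helpers, inf/NaN are not handled, and although the sketch is intended to mirror `compute_ulp` (including the 2^-126 convention for exact zeros), it has not been checked against it on the real data.

```python
import numpy as np

def bf16_bits(values) -> np.ndarray:
    """Truncate float32 values to bfloat16 bit patterns (top 16 bits).

    Truncation is only used to build test inputs that are already
    exactly representable in bfloat16.
    """
    return (np.asarray(values, dtype=np.float32).view(np.uint32) >> 16).astype(np.uint16)

def ulp_from_bits(bits: np.ndarray) -> np.ndarray:
    """Upward ULP of bfloat16 values, computed from the 8-bit exponent field."""
    exp_field = ((bits >> 7) & 0xFF).astype(np.int32)
    is_zero = (bits & 0x7FFF) == 0
    # Normal values with exponent field E are spaced 2**(E - 127 - 7) apart;
    # subnormals (E == 0) share the spacing 2**-133.
    ulp = np.ldexp(np.float32(1.0), np.maximum(exp_field, 1) - 134).astype(np.float32)
    # Match compute_ulp: exact zeros get the smallest normal, 2**-126.
    return np.where(is_zero, np.float32(2.0**-126), ulp)

print(ulp_from_bits(bf16_bits([1.0, -1.5, 0.75, 0.0])))
```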


## Compute delta_W (Cartesian Displacement)

In [6]:
print("Computing delta_W (Cartesian displacement)...\n")

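# Compute displacement in bfloat16. The difference is rounded to bfloat16; by
# Sterbenz's lemma it is exact when W[t] and W[t+1] have the same sign and lie
# within a factor of two of each other (e.g. any step that stays in one binade).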
delta_W = W[1:] - W[:-1]  # (6000, 10000, 64)

print(f"✓ delta_W computed")
print(f"  Shape: {delta_W.shape}")
print(f"  Dtype: {delta_W.dtype}")
print(f"  Memory: {delta_W.element_size() * delta_W.nelement() / 1e9:.2f} GB")

# Current memory usage
mem_delta_W = delta_W.element_size() * delta_W.nelement() / 1e9
total_mem = mem_W + mem_U + mem_delta_W
print(f"\n  Current memory usage: {total_mem:.2f} GB (W + U + delta_W)")
print(f"  Budget remaining: {24 - total_mem:.2f} GB")

Computing delta_W (Cartesian displacement)...

✓ delta_W computed
  Shape: torch.Size([6000, 10000, 64])
  Dtype: torch.bfloat16
  Memory: 7.68 GB

  Current memory usage: 23.04 GB (W + U + delta_W)
  Budget remaining: 0.96 GB


## Compute delta_W_prime (Lattice Displacement)

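Divide by the starting ULP. This MUST be done in float32 to detect non-integer outputs: bfloat16 carries only 8 significant bits, so not every integer above 256 is representable (257 rounds to 256), and a half-integer such as 128.5 rounds to 128.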
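To make the 8-bit limitation concrete, this NumPy sketch rounds float32 values to the nearest bfloat16 using round-to-nearest-even (the IEEE default rounding mode) on the top 16 bits. `to_bf16` is a hypothetical helper and deliberately ignores NaN and overflow.

```python
import numpy as np

def to_bf16(values) -> np.ndarray:
    """Round finite float32 values to the nearest bfloat16 (ties to even), as float32."""
    bits = np.asarray(values, dtype=np.float32).view(np.uint32)
    lsb = (bits >> 16) & 1                  # lowest bit that survives truncation
    rounded = (bits + 0x7FFF + lsb) >> 16   # round to nearest, ties to even
    return (rounded << 16).astype(np.uint32).view(np.float32)

print(to_bf16([255.0, 256.0, 257.0, 127.5, 128.5]))
# [255.  256.  256.  127.5 128. ]
```

Integers up to 256 and half-integers below 128 survive; 257 and 128.5 do not, which is why the division below is carried out in float32.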

In [7]:
print("Computing delta_W_prime (lattice displacement)...\n")

# We can now discard W (we only need U and delta_W)
del W
print("✓ Freed W from memory")

# Get starting ULP (t, not t+1)
U_start = U[:-1]  # (6000, 10000, 64)

# Convert to float32 for division (to detect non-integers)
delta_W_f32 = delta_W.to(torch.float32)
U_start_f32 = U_start.to(torch.float32)

# Compute lattice displacement
delta_W_prime = delta_W_f32 / U_start_f32

print(f"✓ delta_W_prime computed")
print(f"  Shape: {delta_W_prime.shape}")
print(f"  Dtype: {delta_W_prime.dtype}")
print(f"  Memory: {delta_W_prime.element_size() * delta_W_prime.nelement() / 1e9:.2f} GB")

# Current memory
mem_delta_W_prime = delta_W_prime.element_size() * delta_W_prime.nelement() / 1e9
mem_U_full = U.element_size() * U.nelement() / 1e9
# Peak during this cell: U, delta_W (bf16), delta_W_f32, U_start_f32 and
# delta_W_prime were all alive together (~61 GB), well over the 24 GB budget.
# U_start is only a view into U, so deleting it frees nothing; delta_W and
# the two float32 copies are what actually get released here.
del delta_W, U_start, delta_W_f32, U_start_f32
print("\n✓ Freed intermediate tensors")

current_mem = mem_U_full + mem_delta_W_prime
print(f"  Current memory usage: {current_mem:.2f} GB (U + delta_W_prime)")
print(f"  Budget remaining: {24 - current_mem:.2f} GB")

Computing delta_W_prime (lattice displacement)...

✓ Freed W from memory
✓ delta_W_prime computed
  Shape: torch.Size([6000, 10000, 64])
  Dtype: torch.float32
  Memory: 15.36 GB

✓ Freed intermediate tensors
  Current memory usage: 23.04 GB (U + delta_W_prime)
  Budget remaining: 0.96 GB
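The per-cell accounting above tracks only the tensors left alive at the end of each cell. While `delta_W_prime` is being computed, `U`, `delta_W`, `delta_W_f32`, `U_start_f32` and the output coexist, roughly 7.68 + 7.68 + 3 × 15.36 ≈ 61 GB. One way to respect the 24 GB ceiling is to process a block of time steps at a time. This is a minimal NumPy sketch under stated assumptions: `W` and `U` are float arrays (or h5py datasets of floats) indexed `[step, token, dim]`, with `U[t]` the ULP at `W[t]`; the function name and `chunk_steps` are hypothetical. Unlike the notebook, it subtracts after casting to float32, which is exact whenever the two bfloat16 exponents differ by at most 15. The same loop works on PyTorch tensors with `.to(torch.float32)` in place of `.astype(np.float32)`.

```python
import numpy as np

def lattice_displacement_chunked(W, U, chunk_steps: int = 500) -> np.ndarray:
    """(W[t+1] - W[t]) / U[t] in float32, one block of time steps at a time.

    Besides the float32 output, peak extra memory is three chunk-sized
    float32 temporaries (the two W slices and the U slice).
    """
    n_steps = W.shape[0] - 1
    out = np.empty((n_steps,) + tuple(W.shape[1:]), dtype=np.float32)
    for start in range(0, n_steps, chunk_steps):
        stop = min(start + chunk_steps, n_steps)
        w0 = W[start:stop].astype(np.float32)
        w1 = W[start + 1:stop + 1].astype(np.float32)
        # Subtract in float32 so the difference is not rounded back to bfloat16.
        np.subtract(w1, w0, out=out[start:stop])
        out[start:stop] /= U[start:stop].astype(np.float32)
    return out
```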


## Check: Float32 Integer Fidelity

Float32 represents every integer with magnitude up to 2^24 = 16,777,216 exactly. Beyond that, consecutive float32 values are at least 2 apart, so some integers are skipped.

In [8]:
print("Checking float32 integer fidelity...\n")

min_val = delta_W_prime.min().item()
max_val = delta_W_prime.max().item()
max_abs = max(abs(min_val), abs(max_val))

print(f"delta_W_prime range:")
print(f"  Min: {min_val:.2e}")
print(f"  Max: {max_val:.2e}")
print(f"  Max absolute: {max_abs:.2e}")
print()
print(f"Float32 exact integer limit: {FLOAT32_INT_LIMIT:,} (2^24)")

if max_abs > FLOAT32_INT_LIMIT:
    n_overflow = (delta_W_prime.abs() > FLOAT32_INT_LIMIT).sum().item()
    print(f"\n⚠️  WARNING: {n_overflow:,} values exceed float32 integer fidelity limit!")
    print(f"   Above 2^24, float32 cannot represent every integer, so these values may have been rounded.")
    print(f"   Validation results may be unreliable for these values.")
else:
    print(f"\n✓ All values within float32 exact integer range.")

Checking float32 integer fidelity...

delta_W_prime range:
  Min: 0.00e+00
  Max: 0.00e+00
  Max absolute: 0.00e+00

Float32 exact integer limit: 16,777,216 (2^24)

✓ All values within float32 exact integer range.
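The [0, 0] range above says that no coordinate of any token moved in 6,000 steps. That cannot be right: the exponent-crossing cell later in this notebook finds 33,786,307 transitions whose exponent changes, and a changed exponent means a changed value. Zero is trivially an integer, so the fidelity check above and the integer-lattice verdict that follows are vacuous for this run. A cheap guard is to confirm that ΔW′ is nonzero exactly where W changed. This is a NumPy sketch; `displacement_consistency` is a hypothetical helper, and on the real data it would run chunk by chunk on CPU copies. Which step produced the zeros is not established here. One hypothesis worth testing is the MPS backend, since these tensors hold more than 2^31 elements, so repeating the division on CPU for a slice of steps would help isolate it.

```python
import numpy as np

def displacement_consistency(W: np.ndarray, delta_prime: np.ndarray) -> int:
    """Count positions where 'W changed between steps' and 'delta_prime != 0' disagree.

    W has shape (steps, ...) and delta_prime has shape (steps - 1, ...).
    Comparison is by value, so +0.0 and -0.0 count as unchanged.
    """
    changed = W[1:] != W[:-1]
    moved = delta_prime != 0
    return int(np.count_nonzero(changed != moved))

W = np.array([[1.0], [1.0], [2.0]], dtype=np.float32)
print(displacement_consistency(W, np.array([[0.0], [128.0]], dtype=np.float32)))  # 0
print(displacement_consistency(W, np.zeros((2, 1), dtype=np.float32)))           # 1
```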


## Validation: Integer Quantization Check

Check that all lattice coordinates are integers (or nearly so).

Categories, where `frac` is the distance from each coordinate to its nearest integer:
- `frac == 0`: ✅ Exact integer
- `0 < frac < ε`: ⚠️ Float rounding error (acceptable)
- `frac ≈ 0.5`: ⚠️ Half-integer (flag for investigation)
- `frac >= ε` (and not ≈ 0.5): 🚨 GENERAL QUARTERS (algorithm broken)

In [9]:
print("\nValidating integer quantization...\n")

# Flatten all coordinates
coords = delta_W_prime.flatten().cpu()
n_total = len(coords)

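# Distance to the nearest integer, in [0, 0.5]. (x - floor(x) would report
# 0.9999998 for a value just below an integer and misflag it as non-integer.)
frac_part = torch.abs(coords - torch.round(coords))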

# Classify
exact = (frac_part == 0)
almost = (frac_part > 0) & (frac_part < EPSILON)
half = (torch.abs(frac_part - 0.5) < EPSILON)  # Near 0.5
non_int = (frac_part >= EPSILON) & ~half

n_exact = exact.sum().item()
n_almost = almost.sum().item()
n_half = half.sum().item()
n_non = non_int.sum().item()

print("=" * 80)
print("INTEGER QUANTIZATION VALIDATION")
print("=" * 80)
print()
print(f"Total coordinates: {n_total:,}")
print(f"Epsilon: {EPSILON:.0e}")
print()
print(f"Exact integers:    {n_exact:,}  ({n_exact/n_total:.6%})")
print(f"Almost integers:   {n_almost:,}  ({n_almost/n_total:.6%})  [float32 rounding]")
print(f"Half-integers:     {n_half:,}  ({n_half/n_total:.6%})  [FLAGGED]")
print(f"Non-integers:      {n_non:,}  ({n_non/n_total:.6%})  [ERROR!]")
print()

# Determine status
if n_non > 0:
    print("🚨 " * 20)
    print("\n🚢 GENERAL QUARTERS! GENERAL QUARTERS! ALL HANDS TO BATTLE STATIONS! 🚢")
    print()
    print(f"Found {n_non:,} NON-INTEGER coordinates!")
    print("This indicates the lattice coordinate algorithm is not correct.")
    print()
    print("🚨 " * 20)
elif n_half > 0:
    print("⚠️  WARNING: Found half-integer coordinates.")
    print("   This may indicate edge cases in the algorithm.")
    print("   Flagged for investigation.")
elif n_almost > 0:
    print("✓ PASS (with float32 rounding)")
    print(f"  All coordinates are integers within tolerance.")
    print(f"  {n_almost:,} coordinates have minor float32 rounding errors.")
else:
    print("✓✓✓ PERFECT INTEGER LATTICE ✓✓✓")
    print("  All coordinates are EXACT integers.")

print("\n" + "=" * 80)


Validating integer quantization...

INTEGER QUANTIZATION VALIDATION

Total coordinates: 3,840,000,000
Epsilon: 1e-06

Exact integers:    3,840,000,000  (100.000000%)
Almost integers:   0  (0.000000%)  [float32 rounding]
Half-integers:     0  (0.000000%)  [FLAGGED]
Non-integers:      0  (0.000000%)  [ERROR!]

✓✓✓ PERFECT INTEGER LATTICE ✓✓✓
  All coordinates are EXACT integers.



## Investigation: Non-Standard Coordinates (If Any)

In [10]:
if n_non > 0 or n_half > 0:
    print("\n" + "=" * 80)
    print("INVESTIGATING NON-STANDARD COORDINATES")
    print("=" * 80)
    print()
    
    if n_non > 0:
        print(f"Non-integer coordinates (sample of 20):\n")
        non_int_indices = torch.where(non_int)[0].numpy()
        sample_size = min(20, len(non_int_indices))
        sample_indices = np.random.choice(non_int_indices, size=sample_size, replace=False)
        
        for i, idx in enumerate(sample_indices, 1):
            val = coords[idx].item()
            nearest = np.floor(val)
            frac = abs(val - nearest)
            print(f"  {i:2d}. {val:20.10f}  (nearest: {nearest:12.0f}, frac: {frac:.10f})")
    
    if n_half > 0:
        print(f"\nHalf-integer coordinates (sample of 20):\n")
        half_indices = torch.where(half)[0].numpy()
        sample_size = min(20, len(half_indices))
        sample_indices = np.random.choice(half_indices, size=sample_size, replace=False)
        
        for i, idx in enumerate(sample_indices, 1):
            val = coords[idx].item()
            print(f"  {i:2d}. {val:20.10f}")
    
    print("\n" + "=" * 80)
else:
    print("\n✓ No non-standard coordinates found.")


✓ No non-standard coordinates found.


## Range Analysis: What Integer Type Do We Need?

In [11]:
print("\nAnalyzing range to determine integer type...\n")

# Round to nearest integer for range analysis
coords_int = torch.round(coords)

min_int = coords_int.min().item()
max_int = coords_int.max().item()
max_abs_int = max(abs(min_int), abs(max_int))

print(f"Integer range:")
print(f"  Min: {min_int:,.0f}")
print(f"  Max: {max_int:,.0f}")
print(f"  Max absolute: {max_abs_int:,.0f}")
print()

# Determine required integer type
int_types = [
    ('int8', 2**7 - 1, 127),
    ('int16', 2**15 - 1, 32_767),
    ('int32', 2**31 - 1, 2_147_483_647),
    ('int64', 2**63 - 1, 9_223_372_036_854_775_807),
]

print("Integer type requirements:")
recommended_type = None
for dtype, limit, limit_val in int_types:
    fits = max_abs_int <= limit
    status = "✓" if fits else "✗"
    print(f"  {status} {dtype:6s}: range [±{limit_val:,}]  {'FITS' if fits else 'TOO SMALL'}")
    if fits and recommended_type is None:
        recommended_type = dtype

print()
if recommended_type:
    print(f"‚úì Recommended integer type: {recommended_type}")
else:
    print("⚠️  WARNING: Values exceed int64 range!")


Analyzing range to determine integer type...

Integer range:
  Min: 0
  Max: 0
  Max absolute: 0

Integer type requirements:
  ✓ int8  : range [±127]  FITS
  ✓ int16 : range [±32,767]  FITS
  ✓ int32 : range [±2,147,483,647]  FITS
  ✓ int64 : range [±9,223,372,036,854,775,807]  FITS

✓ Recommended integer type: int8
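The dtype table can also be driven by `np.iinfo`, which reports the true asymmetric bounds (int8 spans -128..127, so the ±127 test above is one unit conservative on the negative side). This is a minimal sketch; `smallest_int_dtype` is a hypothetical helper that takes the rounded min and max from the cell above.

```python
import numpy as np

def smallest_int_dtype(lo: float, hi: float) -> np.dtype:
    """Smallest signed integer dtype whose full range covers [lo, hi]."""
    for candidate in (np.int8, np.int16, np.int32, np.int64):
        info = np.iinfo(candidate)
        if info.min <= lo and hi <= info.max:
            return np.dtype(candidate)
    raise OverflowError(f"[{lo}, {hi}] exceeds the int64 range")

print(smallest_int_dtype(0, 0))        # int8
print(smallest_int_dtype(-128, 127))   # int8
print(smallest_int_dtype(-129, 0))     # int16
print(smallest_int_dtype(0, 2**31))    # int64
```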


## Summary Statistics

In [12]:
print("\nComputing summary statistics...\n")

# Displacement magnitudes (L2 norm)
displacement_mag = torch.norm(delta_W_prime, dim=2).cpu().numpy()  # (6000, 10000)

# Filter out inf/nan for statistics
mag_finite = displacement_mag[np.isfinite(displacement_mag)]

print(f"Displacement magnitude (L2 norm):")
if len(mag_finite) < len(displacement_mag.flatten()):
    n_inf = np.isinf(displacement_mag).sum()
    n_nan = np.isnan(displacement_mag).sum()
    print(f"  ⚠️  Found {n_inf:,} inf and {n_nan:,} nan values")
    print(f"  Statistics computed on {len(mag_finite):,} finite values:")
else:
    print(f"  All values finite:")

print(f"  Min:    {mag_finite.min():.2e} cells")
print(f"  Max:    {mag_finite.max():.2e} cells")
print(f"  Mean:   {mag_finite.mean():.2f} cells")
print(f"  Median: {np.median(mag_finite):.2f} cells")
print()
print(f"Percentiles:")
for p in [50, 90, 95, 99, 99.9]:
    print(f"  {p:5.1f}%: {np.percentile(mag_finite, p):12.2f} cells")


Computing summary statistics...

Displacement magnitude (L2 norm):
  All values finite:
  Min:    0.00e+00 cells
  Max:    0.00e+00 cells
  Mean:   0.00 cells
  Median: 0.00 cells

Percentiles:
   50.0%:         0.00 cells
   90.0%:         0.00 cells
   95.0%:         0.00 cells
   99.0%:         0.00 cells
   99.9%:         0.00 cells


## Detect Exponent Crossings

In [13]:
print("\nDetecting exponent crossings...\n")

# Reload W temporarily to check exponents (we freed it earlier)
with h5py.File(THIMBLE_PATH, 'r') as f:
    W_for_exp = torch.from_numpy(f['W'][:]).view(torch.bfloat16)

# Extract exponents from bfloat16
# bfloat16: [sign: 1 bit][exponent: 8 bits][mantissa: 7 bits]
W_uint16 = W_for_exp.view(torch.uint16).numpy()
exponents = (W_uint16 >> 7) & 0xFF  # Shift right 7, mask to 8 bits

# Compare consecutive timesteps
exp_t = exponents[:-1]   # (6000, 10000, 64)
exp_t1 = exponents[1:]   # (6000, 10000, 64)

exponent_crossings = (exp_t != exp_t1)

n_transitions = exp_t.size
n_crossings = exponent_crossings.sum()

print(f"✓ Exponent crossings detected")
print(f"  Total transitions: {n_transitions:,}")
print(f"  Same exponent: {n_transitions - n_crossings:,} ({(n_transitions - n_crossings)/n_transitions:.6%})")
print(f"  Crossed exponent: {n_crossings:,} ({n_crossings/n_transitions:.6%})")

# Free temporary W
del W_for_exp, W_uint16, exponents

# Convert to torch tensor for saving
exponent_crossings_torch = torch.from_numpy(exponent_crossings)


Detecting exponent crossings...

✓ Exponent crossings detected
  Total transitions: 3,840,000,000
  Same exponent: 3,806,213,693 (99.120148%)
  Crossed exponent: 33,786,307 (0.879852%)
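Because the exponent field excludes the sign bit, a pure sign flip such as +1.0 to -1.0 is not counted as a crossing. Splitting crossings by direction may also be useful: only downward crossings can produce a fractional ΔW′, because a smaller exponent means a finer lattice than U[t], while any value at the same or a larger exponent is a multiple of U[t]. This is a NumPy sketch on bfloat16 bit patterns; `to_bf16_bits` and `crossing_counts` are hypothetical helpers.

```python
import numpy as np

def to_bf16_bits(values) -> np.ndarray:
    """bfloat16 bit patterns of float32 values that are already bfloat16-exact."""
    return (np.asarray(values, dtype=np.float32).view(np.uint32) >> 16).astype(np.uint16)

def crossing_counts(bits: np.ndarray) -> tuple:
    """Count exponent crossings along axis 0, split into (upward, downward)."""
    exp = ((bits >> 7) & 0xFF).astype(np.int16)  # 8-bit exponent field, sign dropped
    up = int(np.count_nonzero(exp[1:] > exp[:-1]))
    down = int(np.count_nonzero(exp[1:] < exp[:-1]))
    return up, down

print(crossing_counts(to_bf16_bits([1.0, 0.99609375, 2.0])))  # (1, 1)
print(crossing_counts(to_bf16_bits([1.0, -1.0])))              # (0, 0): sign flip only
```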


## Save Results (DO NOT RUN YET)

These cells are ready to execute once we determine:
1. The validation passes
2. The appropriate integer type to use

**Instructions:** Fill in the `SAVE_DTYPE` variable below, then run these cells.

In [14]:
# ========== CONFIGURATION ==========
# Set this to the integer type determined above (e.g., torch.int8, torch.int16, torch.int32, torch.int64)
SAVE_DTYPE = None  # <-- FILL THIS IN BEFORE RUNNING

if SAVE_DTYPE is None:
    print("⚠️  SAVE_DTYPE not set. Please configure before saving.")
else:
    print(f"Configured to save as: {SAVE_DTYPE}")

⚠️  SAVE_DTYPE not set. Please configure before saving.


In [15]:
# DO NOT RUN until SAVE_DTYPE is configured

if SAVE_DTYPE is None:
    print("❌ Cannot save: SAVE_DTYPE not configured.")
else:
    print(f"Saving results to {OUTPUT_PATH}...\n")
    
    # Round and convert to integer type
    delta_W_prime_int = torch.round(delta_W_prime).to(SAVE_DTYPE)
    
    # Move to CPU for saving
    save_dict = {
        'delta_W_prime': delta_W_prime_int.cpu(),
        'U': U.cpu(),
        'dead_mask': dead_mask,
        'dead_token_ids': dead_token_ids,
        'exponent_crossings': exponent_crossings_torch,
        # Metadata
        'n_steps': torch.tensor(n_steps - 1),  # number of t -> t+1 transitions, not W snapshots
        'n_tokens': torch.tensor(n_tokens),
        'n_dims': torch.tensor(n_dims),
        'n_dead': torch.tensor(n_dead),
    }
    
    save_file(save_dict, OUTPUT_PATH)
    
    file_size_gb = Path(OUTPUT_PATH).stat().st_size / 1e9
    
    print(f"✓ Saved to {OUTPUT_PATH}")
    print(f"  File size: {file_size_gb:.2f} GB")
    print()
    print("Saved tensors:")
    print(f"  delta_W_prime: {delta_W_prime_int.shape} ({SAVE_DTYPE})")
    print(f"  U: {U.shape} (bfloat16)")
    print(f"  dead_mask: {dead_mask.shape} (bool)")
    print(f"  dead_token_ids: {dead_token_ids.shape} (int64)")
    print(f"  exponent_crossings: {exponent_crossings_torch.shape} (bool)")
    print()
    print("=" * 80)
    print("✓ Lattice displacement calculation complete.")
    print("=" * 80)

❌ Cannot save: SAVE_DTYPE not configured.
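Before the save cell's `torch.round(...).to(SAVE_DTYPE)` runs, it may be worth refusing any cast that would lose information: rounding silently collapses half-integers, and float-to-integer casts of out-of-range values are not guaranteed to raise. This is a minimal NumPy guard; `checked_integer_cast` is a hypothetical helper, and for the torch tensor it would be applied to `delta_W_prime.cpu().numpy()` or to chunks of it.

```python
import numpy as np

def checked_integer_cast(x: np.ndarray, dtype) -> np.ndarray:
    """Cast to an integer dtype only if every value is a finite, in-range integer."""
    info = np.iinfo(dtype)
    if not np.all(np.isfinite(x)):
        raise ValueError("non-finite lattice displacements")
    if not np.array_equal(x, np.round(x)):
        raise ValueError("non-integer lattice displacements; rounding would lose them")
    if x.min() < info.min or x.max() > info.max:
        raise OverflowError(f"values fall outside the {np.dtype(dtype).name} range")
    return x.astype(dtype)

print(checked_integer_cast(np.array([0.0, -3.0, 127.0], dtype=np.float32), np.int8))
# [  0  -3 127]
```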
