# Compute Lattice Displacement ΔW′[t]

**Purpose:** Convert the raw embedding trajectory W[t] into per-step displacements measured in lattice-cell (bfloat16 ULP) units.

**Input:** `box_4/tensors/Thimble-8/thimble_8_trajectory.safetensors`

**Output:** `box_4/tensors/Thimble-8/lattice_displacement.safetensors`

**Calculation:**
$$\Delta W[t] = W[t] - W[t-1]$$
$$\text{ULP}[t-1] = 2^{E[t-1] - 134}$$
$$\Delta W'[t] = \Delta W[t] \,/\, \text{ULP}[t-1]$$

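where $E[t-1]$ is the 8-bit biased exponent field of W[t-1] in its bfloat16 representation, and 134 = 127 (exponent bias) + 7 (mantissa bits). For subnormals ($E = 0$) the ULP is fixed at $2^{-133}$.

ΔW′ values are in "lattice cell" units: the number of bfloat16 quantization steps, measured at the starting position W[t-1], that each token coordinate moved between steps t-1 and t.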

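To make the units concrete, here is a minimal plain-Python sketch of the three formulas above for a single coordinate. It assumes scalar Python floats that are exactly representable in bfloat16 stand in for W[t-1] and W[t]; the helper names (`bf16_bits`, `ulp_bf16`, `lattice_displacement`) are hypothetical and not part of this notebook.

```python
import struct


def bf16_bits(x: float) -> int:
    """Top 16 bits of the float32 encoding of x (exact when x is bf16-representable)."""
    return struct.unpack(">I", struct.pack(">f", x))[0] >> 16


def ulp_bf16(x: float) -> float:
    """ULP = 2**(E - 134), with subnormals (E = 0) treated as E = 1."""
    exponent = (bf16_bits(x) >> 7) & 0xFF
    return 2.0 ** (max(exponent, 1) - 134)


def lattice_displacement(w_prev: float, w_curr: float) -> float:
    """ΔW′ = (W[t] - W[t-1]) / ULP[t-1] for one coordinate."""
    return (w_curr - w_prev) / ulp_bf16(w_prev)


# The same raw move (3 * 2**-6 = 0.046875) is 6 cells near 1.0 ...
print(lattice_displacement(1.0, 1.046875))  # 6.0
# ... but only 3 cells near 2.0, where the bf16 grid is twice as coarse.
print(lattice_displacement(2.0, 2.046875))  # 3.0
```

Normalizing by ULP[t-1] is what makes displacements comparable across coordinates of very different magnitude.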
## Parameters

```python
# Paths (relative to notebook location: box_4/notebooks/data_processing/)
INPUT_PATH = '../../tensors/Thimble-8/thimble_8_trajectory.safetensors'
OUTPUT_PATH = '../../tensors/Thimble-8/lattice_displacement.safetensors'
```

## Imports

```python
import torch
from safetensors.torch import load_file, save_file
from pathlib import Path
```

## Load Data

```python
data = load_file(INPUT_PATH)

# W is stored as uint16 to preserve bfloat16 bit patterns
W_uint16 = data['W']
W = W_uint16.view(torch.bfloat16)

print(f"W shape: {W.shape}")  # (4001, 3699, 64)
print(f"W dtype: {W.dtype}")
```

```
W shape: torch.Size([4001, 3699, 64])
W dtype: torch.bfloat16
```


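Storing `W` as `uint16` keeps each bfloat16 value's exact bit pattern, and `.view(torch.bfloat16)` reinterprets those bits without any numeric conversion. Because bfloat16 is the upper 16 bits of a float32, the same reinterpretation can be illustrated in plain Python. The helper names and sample bit patterns below are hypothetical, chosen only for illustration:

```python
import struct


def decode_bf16(bits: int) -> float:
    """Reinterpret a uint16 bit pattern as bfloat16 (the top half of a float32)."""
    return struct.unpack(">f", struct.pack(">I", (bits & 0xFFFF) << 16))[0]


def bf16_fields(bits: int) -> tuple:
    """Split a bfloat16 pattern into (sign, biased exponent E, 7-bit mantissa)."""
    return (bits >> 15) & 0x1, (bits >> 7) & 0xFF, bits & 0x7F


for bits in (0x3F80, 0xBFC0, 0x3E00):
    print(hex(bits), bf16_fields(bits), decode_bf16(bits))
# 0x3f80 (0, 127, 0) 1.0
# 0xbfc0 (1, 127, 64) -1.5
# 0x3e00 (0, 124, 0) 0.125
```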
## Compute ULP Function

For bfloat16 (1 sign bit, 8 exponent bits, 7 mantissa bits), the ULP (Unit in Last Place) depends only on the exponent:
- Extract exponent bits: `(uint16 >> 7) & 0xFF`
- For normal numbers (0 < E < 255): `ULP = 2^(E - 134)`
- For subnormals (E = 0): `ULP = 2^(-133)` (fixed)
- E = 255 (Inf/NaN) is not special-cased by the function below.

The magic number 134 = 127 (exponent bias) + 7 (mantissa bits).

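The offset of 134 follows directly from the bfloat16 encoding. A short derivation, writing $s$ for the sign bit, $E$ for the biased exponent, and $m$ for the 7-bit mantissa field:

$$x = (-1)^s \, 2^{E-127} \left(1 + \frac{m}{2^{7}}\right), \qquad 1 \le E \le 254$$

$$\text{ULP} = 2^{E-127} \cdot \frac{1}{2^{7}} = 2^{E-134}$$

$$x_{\text{sub}} = (-1)^s \, 2^{-126} \cdot \frac{m}{2^{7}} \;\Rightarrow\; \text{ULP} = 2^{-126-7} = 2^{-133} = 2^{1-134}, \qquad E = 0$$

The last line is why `compute_ulp_bf16` can substitute $E = 1$ for $E = 0$.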
```python
def compute_ulp_bf16(tensor_bf16: torch.Tensor) -> torch.Tensor:
    """
    Compute the ULP (Unit in Last Place) for each element of a bfloat16 tensor.

    Returns a float32 tensor of ULP values.
    """
    # View as uint16 to reinterpret the raw bits, then widen to int32 for bitwise ops
    bits = tensor_bf16.view(torch.uint16).to(torch.int32)

    # Extract the 8-bit biased exponent (bits 7-14); result stays int32
    exponent = (bits >> 7) & 0xFF

    # Normal numbers: ULP = 2^(E - 134) where 134 = 127 (bias) + 7 (mantissa bits)
    # Subnormals (E=0): ULP = 2^(-133) = 2^(1 - 134), so treat E=0 as E=1
    # E=255 (Inf/NaN) is not special-cased
    effective_exp = torch.where(exponent == 0, torch.ones_like(exponent), exponent)

    ulp = torch.pow(2.0, (effective_exp - 134).float())

    return ulp
```

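`compute_ulp_bf16` can be cross-checked without torch. bfloat16 has only 2^16 bit patterns, so the rule can be compared exhaustively against the actual gap to the next representable value. This is a sketch that decodes patterns through the top half of a float32 with Python's `struct`; the helper names are hypothetical:

```python
import struct


def bf16_from_bits(bits: int) -> float:
    """Decode a bfloat16 bit pattern by placing it in the top half of a float32."""
    return struct.unpack(">f", struct.pack(">I", bits << 16))[0]


def ulp_formula(bits: int) -> float:
    """Same rule as compute_ulp_bf16: 2**(E - 134), with E = 0 treated as E = 1."""
    exponent = (bits >> 7) & 0xFF
    return 2.0 ** (max(exponent, 1) - 134)


def check_ulp_formula() -> int:
    """Compare the formula with the gap to the next bf16 value; return patterns checked."""
    checked = 0
    # 0x0000 .. 0x7F7E: every non-negative finite pattern whose successor is also finite
    for bits in range(0x0000, 0x7F7F):
        gap = bf16_from_bits(bits + 1) - bf16_from_bits(bits)
        assert gap == ulp_formula(bits), hex(bits)
        checked += 1
    return checked


print(check_ulp_formula())  # 32639
```

Negative patterns differ only in the sign bit, so they share the exponent field and therefore the ULP.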
## Compute ΔW and ΔW′

```python
# ΔW[t] = W[t] - W[t-1] for t = 1, 2, ..., 4000
# Shape: (4000, 3699, 64)
delta_W = W[1:].float() - W[:-1].float()  # Compute in float32 for precision

print(f"ΔW shape: {delta_W.shape}")
```

```
ΔW shape: torch.Size([4000, 3699, 64])
```


```python
# ULP at W[t-1] for t = 1, 2, ..., 4000
# Shape: (4000, 3699, 64)
ulp = compute_ulp_bf16(W[:-1])

print(f"ULP shape: {ulp.shape}")
print(f"ULP range: [{ulp.min():.2e}, {ulp.max():.2e}]")
```

```
ULP shape: torch.Size([4000, 3699, 64])
ULP range: [9.18e-41, 9.77e-04]
```

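The printed ULP range can be read back as exponents. A small arithmetic sketch, using the values copied from the output above (so rounded to three significant figures):

```python
import math

# Observed ULP range from the cell above.
ulp_min, ulp_max = 9.18e-41, 9.77e-04

# Invert ULP = 2**(E - 134) to recover the effective biased exponent E.
e_min = round(math.log2(ulp_min)) + 134
e_max = round(math.log2(ulp_max)) + 134
print(e_min, e_max)  # 1 124

# A normal value with exponent E lies in [2**(E - 127), 2**(E - 126)).
print(2.0 ** (e_max - 127), 2.0 ** (e_max - 126))  # 0.125 0.25
```

If this reading is right, every |W[t-1]| is below 0.25, and there are no Inf/NaN entries in `W[:-1]` (E = 255 would give a ULP of 2^121). E = 1 is what `compute_ulp_bf16` assigns to zeros, subnormals, and the smallest normal binade alike, so the minimum alone cannot tell these cases apart.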

```python
# ΔW′[t] = ΔW[t] / ULP[t-1]
# This gives displacement in lattice-cell units
delta_W_prime = delta_W / ulp

print(f"ΔW′ shape: {delta_W_prime.shape}")
print(f"ΔW′ dtype: {delta_W_prime.dtype}")
```

```
ΔW′ shape: torch.Size([4000, 3699, 64])
ΔW′ dtype: torch.float32
```

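The sanity check expects ΔW′ to be roughly integer-valued. A small worked example shows when that holds exactly. The values are hand-picked, exactly representable bfloat16 numbers, and every W[t-1] sits in the binade [1, 2), meaning the range sharing one exponent:

```python
# Every W[t-1] below lies in [1, 2), i.e. biased exponent E = 127,
# so ULP[t-1] = 2**(127 - 134) = 2**-7 for all of them.
ULP_PREV = 2.0 ** (127 - 134)

cases = {
    "same binade, one step up": (1.0, 1.0078125),
    "up one binade": (1.9921875, 2.0),
    "up two binades": (1.0, 4.0),
    "down one binade": (1.0, 0.99609375),
}
results = {name: (w_curr - w_prev) / ULP_PREV for name, (w_prev, w_curr) in cases.items()}

for name, cells in results.items():
    print(f"{name}: {cells} cells")
# same binade, one step up: 1.0 cells
# up one binade: 1.0 cells
# up two binades: 384.0 cells
# down one binade: -0.5 cells
```

Moves into a coarser binade stay on multiples of ULP[t-1] because the coarser grid spacing is a power-of-two multiple of it. Moves into a finer binade can land between cells, giving fractional ΔW′.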

## Quick Sanity Check

```python
print("=" * 50)
print("SANITY CHECK")
print("=" * 50)

# ΔW′ should be (close to) integer-valued: tokens hop between discrete lattice cells.
# It is exactly integral when |W[t]| stays in the same binade as |W[t-1]| or moves
# to a coarser one; a move into a finer (smaller-magnitude) binade can be fractional.
# Compare an early timestep (hot phase) with a late timestep (cold phase).

print("\nEarly timestep (t=1, should be hot):")
early = delta_W_prime[0].abs()  # t=1
print(f"  Mean |ΔW′|: {early.mean():.2f} cells")
print(f"  Max |ΔW′|: {early.max():.2f} cells")
print(f"  Fraction |ΔW′| < 0.5: {(early < 0.5).float().mean():.1%}")

print("\nLate timestep (t=4000, should be cold):")
late = delta_W_prime[-1].abs()  # t=4000
print(f"  Mean |ΔW′|: {late.mean():.2f} cells")
print(f"  Max |ΔW′|: {late.max():.2f} cells")
print(f"  Fraction |ΔW′| < 0.5: {(late < 0.5).float().mean():.1%}")
```

```
==================================================
SANITY CHECK
==================================================

Early timestep (t=1, should be hot):
  Mean |ΔW′|: 44.11 cells
  Max |ΔW′|: 2588527.00 cells
  Fraction |ΔW′| < 0.5: 0.0%

Late timestep (t=4000, should be cold):
  Mean |ΔW′|: 0.00 cells
  Max |ΔW′|: 0.00 cells
  Fraction |ΔW′| < 0.5: 100.0%
```


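The check above reports magnitude statistics but does not test integrality directly. Here is a hypothetical helper, `summarize_step`, that also reports the fraction of exactly integer values, sketched on a plain list of floats rather than the full tensor:

```python
def summarize_step(values):
    """Summary statistics for one timestep of ΔW′ values (a flat list of floats)."""
    magnitudes = [abs(v) for v in values]
    n = len(magnitudes)
    return {
        "mean_abs": sum(magnitudes) / n,
        "max_abs": max(magnitudes),
        "frac_below_half": sum(m < 0.5 for m in magnitudes) / n,
        # float.is_integer() is False for fractional values and for inf/nan
        "frac_integer": sum(float(v).is_integer() for v in values) / n,
    }


stats = summarize_step([0.0, 3.0, -1.0, -0.5])
print(stats)
# {'mean_abs': 1.125, 'max_abs': 3.0, 'frac_below_half': 0.25, 'frac_integer': 0.75}
```

On a torch tensor `x`, an analogous integrality fraction would be `(x == x.round()).float().mean()`.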
## Save Output

```python
output_data = {
    'delta_W_prime': delta_W_prime,  # (4000, 3699, 64) float32
}

output_path = Path(OUTPUT_PATH)
save_file(output_data, str(output_path))

print(f"Saved to {output_path}")
print(f"File size: {output_path.stat().st_size / 1e9:.2f} GB")
```

```
Saved to ../../tensors/Thimble-8/lattice_displacement.safetensors
File size: 3.79 GB
```

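The saved file can be spot-checked without loading 3.79 GB. A `.safetensors` file begins with an 8-byte little-endian header length, followed by a JSON header giving each tensor's `dtype`, `shape`, and `data_offsets`. A minimal stdlib sketch, exercised on a tiny in-memory file. `read_safetensors_header` is a hypothetical helper, not part of the `safetensors` API:

```python
import io
import json
import struct


def read_safetensors_header(stream) -> dict:
    """Return the JSON header of a .safetensors stream without reading tensor data."""
    (header_len,) = struct.unpack("<Q", stream.read(8))  # little-endian u64 length
    return json.loads(stream.read(header_len))


# Tiny in-memory file: one 2x3 float32 tensor (24 bytes of zeros).
header_bytes = json.dumps(
    {"delta_W_prime": {"dtype": "F32", "shape": [2, 3], "data_offsets": [0, 24]}}
).encode()
demo = read_safetensors_header(
    io.BytesIO(struct.pack("<Q", len(header_bytes)) + header_bytes + bytes(24))
)
print(demo["delta_W_prime"]["shape"])  # [2, 3]

# Expected payload for the real (4000, 3699, 64) float32 tensor.
expected_bytes = 4000 * 3699 * 64 * 4
print(expected_bytes)  # 3787776000

# On the real output (OUTPUT_PATH from the Parameters cell):
# with open(OUTPUT_PATH, "rb") as f:
#     info = read_safetensors_header(f)["delta_W_prime"]
```

For this notebook's output, the header should list `delta_W_prime` with dtype `F32` and shape `[4000, 3699, 64]`. The 3,787,776,000-byte payload plus a small header accounts for the reported 3.79 GB.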

```python
print("\nProcessing complete.")
```

```

Processing complete.
```
