# 1.8j: The Tetrahedron

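From 1.8i we learned that the 4 black holes have **rank 3**: they span a 3D subspace and form a tetrahedron.

Since the intrinsic dimensionality is 3, we can project the 4 points from 2560D down to 3D **with zero distortion**. After centering, every difference between two of the points lies in that 3D subspace, so projecting onto it preserves all pairwise distances exactly, up to floating-point rounding. (Any 4 points lie in an affine subspace of dimension at most 3. Rank 3 rules out the degenerate case where they are coplanar.)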
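Here is a short derivation of the zero-distortion claim. The notation is introduced only for illustration: $p_1,\dots,p_4 \in \mathbb{R}^{2560}$ are the four vectors, $\bar p$ is their mean, $c_i = p_i - \bar p$, and $C$ is the $4 \times 2560$ matrix with rows $c_i^\top$.

$$
\sum_{i=1}^{4} c_i = 0 \;\Longrightarrow\; \operatorname{rank}(C) \le 3,
\qquad C = U \Sigma V^\top .
$$

Suppose $\operatorname{rank}(C) = 3$. Then every $c_i$ lies in the span of $V_3 = [v_1\; v_2\; v_3]$, which has orthonormal columns. So $c_i = V_3 V_3^\top c_i$, and the 3D coordinates $y_i = V_3^\top c_i$ satisfy

$$
\lVert y_i - y_j \rVert = \lVert V_3^\top (c_i - c_j) \rVert = \lVert c_i - c_j \rVert = \lVert p_i - p_j \rVert .
$$

The middle equality holds because $V_3^\top$ preserves the length of any vector in the span of $V_3$. The last one holds because the mean cancels in differences.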

**Method:**
1. Center the 4 vectors and compute their SVD (as in 1.8i) to find the 3D subspace they span
2. Project all 4 black holes onto that subspace
3. Verify that the pairwise distances match the original 2560D distances
4. Visualize the result as an interactive 3D scatter plot
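Steps 1 to 3 can be sketched in a few lines of NumPy. This is a minimal illustration on random synthetic vectors, not the notebook's own code, and the function name `project_to_3d` is hypothetical. Any 4 centered points have rank at most 3, so even random inputs should give a distance error at the level of float64 rounding.

```python
import numpy as np


def project_to_3d(points):
    """Project 4 points (rows) onto the 3D subspace spanned by their centered copies.

    Returns the (4, 3) coordinates and the largest absolute change in any pairwise distance.
    """
    centered = points - points.mean(axis=0)
    # Rows of vt are right singular vectors, ordered by decreasing singular value
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    coords = centered @ vt[:3].T  # shape (4, 3)

    def pairwise(a):
        diff = a[:, None, :] - a[None, :, :]
        return np.linalg.norm(diff, axis=-1)

    max_err = float(np.abs(pairwise(points) - pairwise(coords)).max())
    return coords, max_err


rng = np.random.default_rng(0)
points = rng.standard_normal((4, 2560))  # synthetic stand-in for the 4 black-hole vectors
coords, max_err = project_to_3d(points)
print(coords.shape, max_err)
```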

## Parameters

```python
# Model to analyze
MODEL_NAME = "Qwen3-4B-Instruct-2507"
```

## Imports

```python
import torch
import numpy as np
import plotly.graph_objects as go
from safetensors.torch import load_file
from pathlib import Path
from itertools import combinations
```

## Device Detection

```python
# Detect available device
# (only reported here: the cells below keep all tensors on the CPU)
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

print(f"Using device: {device}")
```

```
Using device: mps
```


## Load Data

```python
# Load W in bfloat16
W_path = Path(f"../tensors/{MODEL_NAME}/W.safetensors")
W_bf16 = load_file(W_path)["W"]

print(f"Loaded W: {W_bf16.shape}")
```

```
Loaded W: torch.Size([151936, 2560])
```


```python
# Load black hole data
bh_path = Path(f"../tensors/{MODEL_NAME}/1.8e_black_hole_masks.safetensors")
bh_data = load_file(bh_path)

bh1_token_ids = bh_data["bh1_token_ids"].to(torch.int64)
bh2_token_ids = bh_data["bh2_token_ids"].to(torch.int64)
bh3_token_ids = bh_data["bh3_token_ids"].to(torch.int64)
bh4_token_ids = bh_data["bh4_token_ids"].to(torch.int64)

print(f"\nLoaded black holes:")
print(f"  BH1: {len(bh1_token_ids):,} tokens")
print(f"  BH2: {len(bh2_token_ids):,} tokens")
print(f"  BH3: {len(bh3_token_ids):,} tokens")
print(f"  BH4: {len(bh4_token_ids):,} tokens")
```

```
Loaded black holes:
  BH1: 866 tokens
  BH2: 734 tokens
  BH3: 329 tokens
  BH4: 249 tokens
```


## Extract Black Hole Representative Vectors

```python
print("\nExtracting black hole representative vectors...\n")

# Use the first token of each black hole as its representative
# (this treats every token in a black hole as sharing one embedding vector)
bh_token_ids = [
    bh1_token_ids[0].item(),
    bh2_token_ids[0].item(),
    bh3_token_ids[0].item(),
    bh4_token_ids[0].item()
]

# Get vectors from W, upcast to float32 for numerical stability
bh_vectors = []
for i, token_id in enumerate(bh_token_ids, 1):
    vector = W_bf16[token_id].to(torch.float32)
    bh_vectors.append(vector)
    print(f"BH{i}: Token {token_id}")

# Stack into a (4, 2560) matrix
bh_matrix = torch.stack(bh_vectors, dim=0)

print(f"\n✓ Extracted {len(bh_vectors)} representative vectors")
print(f"Matrix shape: {bh_matrix.shape}")
```

```
Extracting black hole representative vectors...

BH1: Token 80091
BH2: Token 125
BH3: Token 124
BH4: Token 123939

✓ Extracted 4 representative vectors
Matrix shape: torch.Size([4, 2560])
```
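The cell above uses the first token of each black hole as its representative. That choice is only valid if every token in a black hole maps to the same row of `W`. This is presumably how 1.8e defined the groups, but that is an assumption here and is not verified in this notebook. Below is a minimal check on synthetic data; the helper name `max_deviation_from_first` is hypothetical.

```python
import torch


def max_deviation_from_first(W: torch.Tensor, token_ids: torch.Tensor) -> float:
    """Hypothetical helper: largest L2 distance from any row W[token_ids] to the first such row."""
    rows = W[token_ids].to(torch.float32)
    return (rows - rows[0]).norm(dim=1).max().item()


# Synthetic stand-in for W: rows 3, 5 and 7 are exact copies; row 9 is shifted
torch.manual_seed(0)
W = torch.randn(10, 8).to(torch.bfloat16)
W[5] = W[3]
W[7] = W[3]
W[9] = W[3] + 0.5

same = max_deviation_from_first(W, torch.tensor([3, 5, 7]))
shifted = max_deviation_from_first(W, torch.tensor([3, 9]))
print(same, shifted)  # 0.0 for the identical rows, > 0 once a row differs
```

On the real data, `max_deviation_from_first(W_bf16, bh1_token_ids)` (and likewise for BH2 to BH4) should return 0.0 if the assumption holds.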


## Compute Original Pairwise Distances (2560D)

```python
print("\nComputing original pairwise distances (2560D)...\n")

# Compute pairwise L2 distances in the original 2560D space
original_distances = torch.cdist(bh_matrix, bh_matrix, p=2)

print("Original distance matrix (2560D):")
# Right-align each column header to the 12-character width of an entry
print("   " + "".join(f"{'BH' + str(j + 1):>12}" for j in range(4)))
for i in range(4):
    row_str = f"BH{i+1}"
    for j in range(4):
        row_str += f"  {original_distances[i, j].item():.4e}"
    print(row_str)

print(f"\n✓ Computed original distances")
```

```
Computing original pairwise distances (2560D)...

Original distance matrix (2560D):
            BH1         BH2         BH3         BH4
BH1  0.0000e+00  3.4137e-05  3.4975e-05  1.7060e-05
BH2  3.4137e-05  0.0000e+00  2.2894e-05  3.1476e-05
BH3  3.4975e-05  2.2894e-05  0.0000e+00  3.8914e-05
BH4  1.7060e-05  3.1476e-05  3.8914e-05  0.0000e+00

✓ Computed original distances
```
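All pairwise distances here are of order 1e-5, so it is worth confirming that float32 rounding is not driving them. One cheap cross-check recomputes them from explicit row differences in float64. The sketch below uses synthetic, nearly coincident rows, and `pairwise_l2_float64` is a hypothetical helper; on the real data it would be called on `bh_matrix`.

Per the PyTorch documentation, `torch.cdist` with the default `compute_mode` switches to a matrix-multiplication formula for `p=2` only when an input has more than 25 rows. That formula expands ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a·b, which can lose precision to cancellation when rows are nearly equal. The 4-row case here avoids it.

```python
import torch


def pairwise_l2_float64(M: torch.Tensor) -> torch.Tensor:
    """Pairwise L2 distances between rows, from explicit differences in float64."""
    M64 = M.to(torch.float64)
    diff = M64[:, None, :] - M64[None, :, :]  # shape (n, n, d)
    return diff.norm(dim=-1)


# Synthetic near-coincident rows: a shared base vector plus offsets of order 1e-5
torch.manual_seed(0)
base = torch.randn(2560)
M = base + 1e-5 * torch.randn(4, 2560)  # float32, shape (4, 2560)

d32 = torch.cdist(M, M, p=2).to(torch.float64)
d64 = pairwise_l2_float64(M)

# Compare off-diagonal entries only (the diagonal is exactly zero in both)
off_diag = ~torch.eye(4, dtype=torch.bool)
rel = ((d32 - d64).abs()[off_diag] / d64[off_diag]).max().item()
print(f"max relative difference: {rel:.2e}")
```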


## SVD and Projection to 3D

```python
print("\nProjecting to 3D via SVD...\n")

# Center the data (subtract the mean vector)
mean_vector = bh_matrix.mean(dim=0)
centered_matrix = bh_matrix - mean_vector

# Compute the SVD of the centered data. S has 4 entries; because the centered
# rows sum to zero, at most 3 of them are nonzero (up to rounding).
U, S, Vt = torch.linalg.svd(centered_matrix, full_matrices=False)

print("Singular values:")
for i, s in enumerate(S[:3], 1):
    print(f"  σ{i} = {s.item():.6e}")

# Project onto the top 3 principal directions:
#   X_3d = X_centered @ V[:, :3]
# where V[:, :3] are the first 3 columns of V (equivalently, the first 3 rows of Vt)
projection_matrix = Vt[:3, :].T  # Shape: (2560, 3)
coords_3d = centered_matrix @ projection_matrix  # Shape: (4, 3)

print(f"\n3D coordinates:")
# Right-align each axis label to the 15-character width of an entry
print("   " + "".join(f"{axis:>15}" for axis in "xyz"))
for i, (x, y, z) in enumerate(coords_3d, 1):
    print(f"BH{i}  {x.item():+.6e}  {y.item():+.6e}  {z.item():+.6e}")

print(f"\n✓ Projected to 3D")
```

```
Projecting to 3D via SVD...

Singular values:
  σ1 = 3.216804e-05
  σ2 = 1.789644e-05
  σ3 = 8.493865e-06

3D coordinates:
                 x              y              z
BH1  -1.498559e-05  +8.040861e-06  +4.887559e-06
BH2  +1.278198e-05  -1.176060e-05  +3.400412e-06
BH3  +1.905185e-05  +9.289443e-06  -3.060123e-06
BH4  -1.684825e-05  -5.569702e-06  -5.227847e-06

✓ Projected to 3D
```
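Two properties of the SVD step can be checked on synthetic data. First, the 3D coordinates equal `U[:, :3] * S[:3]`, so the explicit matrix product is optional. Second, the fourth singular value of the centered 4-row matrix is zero up to rounding, which is why only three are printed. Here is a minimal NumPy sketch on random data, not the black-hole vectors:

```python
import numpy as np

rng = np.random.default_rng(1)
P = rng.standard_normal((4, 2560))  # synthetic stand-in for bh_matrix
C = P - P.mean(axis=0)
U, S, Vt = np.linalg.svd(C, full_matrices=False)

coords_proj = C @ Vt[:3].T    # projection onto the top-3 right singular vectors
coords_us = U[:, :3] * S[:3]  # the same coordinates, read directly off the SVD

print(S)  # 4 values; the last is zero up to rounding
print(np.abs(coords_proj - coords_us).max())
```

Each axis is only defined up to sign: flipping the sign of a matched pair of singular vectors leaves the SVD valid. The signs of x, y and z may therefore differ between libraries or versions, while distances and the volume are unaffected.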


## Verify Distance Preservation

```python
print("\nVerifying distance preservation...\n")

# Compute pairwise distances in 3D
projected_distances = torch.cdist(coords_3d, coords_3d, p=2)

print("Projected distance matrix (3D):")
# Right-align each column header to the 12-character width of an entry
print("   " + "".join(f"{'BH' + str(j + 1):>12}" for j in range(4)))
for i in range(4):
    row_str = f"BH{i+1}"
    for j in range(4):
        row_str += f"  {projected_distances[i, j].item():.4e}"
    print(row_str)

# Compute the absolute error between original and projected distances
abs_diff = torch.abs(original_distances - projected_distances)
max_error = abs_diff.max().item()
mean_error = abs_diff.mean().item()

print(f"\nDistance preservation:")
print(f"  Max absolute error: {max_error:.6e}")
print(f"  Mean absolute error: {mean_error:.6e}")

# Note: these thresholds are absolute, whereas the distances themselves are ~1e-5
if max_error < 1e-6:
    print(f"  ✓ PERFECT preservation (error < 1e-6)")
elif max_error < 1e-3:
    print(f"  ✓ Excellent preservation (error < 1e-3)")
else:
    print(f"  ⚠ Some distortion detected (max error = {max_error:.6e})")

print(f"\n✓ Verification complete")
```

```
Verifying distance preservation...

Projected distance matrix (3D):
            BH1         BH2         BH3         BH4
BH1  0.0000e+00  3.4137e-05  3.4975e-05  1.7060e-05
BH2  3.4137e-05  0.0000e+00  2.2894e-05  3.1476e-05
BH3  3.4975e-05  2.2894e-05  0.0000e+00  3.8914e-05
BH4  1.7060e-05  3.1476e-05  3.8914e-05  0.0000e+00

Distance preservation:
  Max absolute error: 7.275958e-12
  Mean absolute error: 2.501110e-12
  ✓ PERFECT preservation (error < 1e-6)

✓ Verification complete
```
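The thresholds in the cell above are absolute (1e-6 and 1e-3), but every distance here is below 4e-5. An absolute error of 1e-6 would therefore already be a few percent of an edge length. A scale-aware alternative divides by the largest original distance. With the printed values, that ratio is 7.28e-12 / 3.89e-5 ≈ 1.9e-7, on the order of float32 machine epsilon (about 1.2e-7), which supports the verdict above. Below is a minimal sketch with a hypothetical helper and synthetic distances:

```python
import torch


def relative_distance_error(d_orig: torch.Tensor, d_proj: torch.Tensor) -> float:
    """Hypothetical helper: max absolute distance change divided by the largest original distance."""
    return ((d_orig - d_proj).abs().max() / d_orig.max()).item()


# Synthetic distances of order 1e-5, perturbed by an absolute 1e-7 off the diagonal
torch.manual_seed(0)
pts = 1e-5 * torch.randn(4, 3)
d_orig = torch.cdist(pts, pts, p=2)
d_proj = d_orig + 1e-7 * (1 - torch.eye(4))

rel = relative_distance_error(d_orig, d_proj)
# An absolute error of 1e-7 passes the "< 1e-6" test, yet is far above float32 precision here
print(f"relative error: {rel:.1e}")
print(relative_distance_error(d_orig, d_orig))  # 0.0 for an exact match
```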


## Interactive 3D Visualization

```python
print("\nCreating interactive 3D visualization...\n")

# Extract coordinates
x = coords_3d[:, 0].cpu().numpy()
y = coords_3d[:, 1].cpu().numpy()
z = coords_3d[:, 2].cpu().numpy()

# Labels (token counts read from the loaded masks) and colors
token_counts = [len(ids) for ids in (bh1_token_ids, bh2_token_ids, bh3_token_ids, bh4_token_ids)]
labels = [f"BH{i} ({n} tokens)" for i, n in enumerate(token_counts, 1)]
colors = ['red', 'blue', 'green', 'orange']

# Create scatter plot
fig = go.Figure()

# Add one marker trace per black hole
for i in range(4):
    fig.add_trace(go.Scatter3d(
        x=[x[i]],
        y=[y[i]],
        z=[z[i]],
        mode='markers+text',
        marker=dict(
            size=15,
            color=colors[i],
            line=dict(color='black', width=2)
        ),
        text=labels[i],
        textposition='top center',
        name=labels[i],
        hovertemplate=f"<b>{labels[i]}</b><br>" +
                      "x: %{x:.6e}<br>" +
                      "y: %{y:.6e}<br>" +
                      "z: %{z:.6e}<br>" +
                      "<extra></extra>"
    ))

# Add edges to form the tetrahedron
# Connect all pairs (6 edges for 4 vertices)
for i, j in combinations(range(4), 2):
    fig.add_trace(go.Scatter3d(
        x=[x[i], x[j]],
        y=[y[i], y[j]],
        z=[z[i], z[j]],
        mode='lines',
        line=dict(color='gray', width=2, dash='dash'),
        showlegend=False,
        hoverinfo='skip'
    ))

# Layout
fig.update_layout(
    title=dict(
        text="The 4 Black Holes: A 3D Tetrahedron<br><sub>Perfect projection from 2560D (rank 3, zero distortion)</sub>",
        x=0.5,
        xanchor='center'
    ),
    scene=dict(
        xaxis_title='Principal Component 1',
        yaxis_title='Principal Component 2',
        zaxis_title='Principal Component 3',
        aspectmode='data'
    ),
    width=900,
    height=700,
    showlegend=True,
    legend=dict(
        x=0.02,
        y=0.98,
        bgcolor='rgba(255,255,255,0.8)'
    )
)

fig.show()

print("✓ Visualization complete")
print("\nRotate the plot to see the tetrahedral structure!")
```

```
Creating interactive 3D visualization...

✓ Visualization complete

Rotate the plot to see the tetrahedral structure!
```
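The cell above adds one Plotly trace per edge, six in total. Plotly line traces, including `Scatter3d`, break the line wherever a coordinate is `None`, so the six edges could also be drawn as a single trace. Here is a sketch of building those coordinate lists; the helper name and the toy vertices are illustrative only.

```python
from itertools import combinations


def edge_coordinates(x, y, z):
    """Build x/y/z lists for all vertex pairs, with None between edges so each is a separate segment."""
    xe, ye, ze = [], [], []
    for i, j in combinations(range(len(x)), 2):
        xe += [x[i], x[j], None]
        ye += [y[i], y[j], None]
        ze += [z[i], z[j], None]
    return xe, ye, ze


# Toy tetrahedron vertices (illustrative values only)
x = [0.0, 1.0, 0.0, 0.0]
y = [0.0, 0.0, 1.0, 0.0]
z = [0.0, 0.0, 0.0, 1.0]
xe, ye, ze = edge_coordinates(x, y, z)
print(len(xe))  # 6 edges x 3 entries = 18
```

The lists would then go into one `go.Scatter3d(x=xe, y=ye, z=ze, mode='lines', line=dict(color='gray', width=2, dash='dash'), showlegend=False, hoverinfo='skip')` trace. The per-edge loop is equally correct; a single trace just keeps the figure's trace list shorter.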


## Summary: The Tetrahedral Structure

```python
print("\n" + "=" * 80)
print("SUMMARY: THE TETRAHEDRAL STRUCTURE")
print("=" * 80)
print()

print("The 4 black holes form a 3D tetrahedron:")
print()

# Edge lengths are the original 2560D pairwise distances
print("Edge lengths (pairwise distances):")
for i, j in combinations(range(4), 2):
    dist = original_distances[i, j].item()
    print(f"  BH{i+1} ↔ BH{j+1}: {dist:.6e}")

print()

# Volume of the tetrahedron:
#   V = (1/6) |det([v1, v2, v3])|, where v1, v2, v3 are the edge vectors from one vertex
v1 = coords_3d[1] - coords_3d[0]
v2 = coords_3d[2] - coords_3d[0]
v3 = coords_3d[3] - coords_3d[0]

edge_matrix = torch.stack([v1, v2, v3], dim=1)  # 3x3 matrix with the edge vectors as columns
det = torch.linalg.det(edge_matrix).item()
volume = abs(det) / 6.0

print(f"Tetrahedron volume: {volume:.6e}")
print()

print("This tetrahedron lives in a 3D subspace embedded in 2560D space.")
print("The projection is lossless: all distances are preserved up to float32 rounding.")
print()
print("The structure is NOT flat (rank 3, not rank 2).")
print("All 4 points are necessary to span the 3D subspace.")
print()
print("=" * 80)
```

```
================================================================================
SUMMARY: THE TETRAHEDRAL STRUCTURE
================================================================================

The 4 black holes form a 3D tetrahedron:

Edge lengths (pairwise distances):
  BH1 ↔ BH2: 3.413718e-05
  BH1 ↔ BH3: 3.497530e-05
  BH1 ↔ BH4: 1.705985e-05
  BH2 ↔ BH3: 2.289441e-05
  BH2 ↔ BH4: 3.147577e-05
  BH3 ↔ BH4: 3.891414e-05

Tetrahedron volume: 1.629954e-15

This tetrahedron lives in a 3D subspace embedded in 2560D space.
The projection is lossless: all distances are preserved up to float32 rounding.

The structure is NOT flat (rank 3, not rank 2).
All 4 points are necessary to span the 3D subspace.

================================================================================
```
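The determinant volume can be cross-checked in two other ways.

First, for four points the singular values of the centered matrix satisfy $V = \sigma_1\sigma_2\sigma_3/3$. To see this, take $A = [P \mid \mathbf{1}]$. Then $|\det A| = 6V$, and a Schur-complement step gives $\det(A^\top A) = 4\,(\sigma_1\sigma_2\sigma_3)^2$. With the singular values printed earlier, $3.216804\text{e-}5 \times 1.789644\text{e-}5 \times 8.493865\text{e-}6 / 3 \approx 1.63\text{e-}15$, which matches the volume above.

Second, the Cayley-Menger determinant, $288V^2 = \det(\mathrm{CM})$, needs only the six edge lengths. It can therefore use the original 2560D distances and involves no projection at all.

Here is a NumPy sketch on a simple box-corner tetrahedron with known volume $1 \cdot 2 \cdot 3 / 6 = 1$. The data and variable names are illustrative.

```python
import numpy as np
from itertools import combinations

# Box-corner tetrahedron with legs 1, 2, 3 along the axes: volume = 1*2*3/6 = 1
P = np.array([[0.0, 0.0, 0.0],
              [1.0, 0.0, 0.0],
              [0.0, 2.0, 0.0],
              [0.0, 0.0, 3.0]])

# 1) Determinant of the edge vectors from vertex 0, as in the notebook
E = P[1:] - P[0]
vol_det = abs(np.linalg.det(E)) / 6.0

# 2) Product of the centered singular values: V = s1 * s2 * s3 / 3
s = np.linalg.svd(P - P.mean(axis=0), compute_uv=False)
vol_svd = s[:3].prod() / 3.0

# 3) Cayley-Menger determinant from squared edge lengths only: 288 V^2 = det(CM)
D2 = np.zeros((4, 4))
for i, j in combinations(range(4), 2):
    D2[i, j] = D2[j, i] = np.sum((P[i] - P[j]) ** 2)
CM = np.ones((5, 5))
CM[0, 0] = 0.0
CM[1:, 1:] = D2
vol_cm = np.sqrt(np.linalg.det(CM) / 288.0)

print(vol_det, vol_svd, vol_cm)  # each ≈ 1.0
```

On the notebook's data, the SVD check would use `S[:3]` from the projection cell. The Cayley-Menger check could be built from `original_distances`, converted to float64 and squared, so it does not depend on the projection at all.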

