# 1.7b: Cluster Outlier Detection (L2)

**Hypothesis:** The cluster contains distant outliers: tokens that lie far from its tight core.

**Evidence:** The median distance from the centroid is 0.00089653, but the maximum is 1.16091752, roughly 1,300 times larger.

**Prediction:** The largest pairwise L2 distance in the cluster will be between two such outlier tokens.

**Test:** Compute all pairwise L2 distances, take the argmax pair, then check each token's norm, each token's distance from the centroid, and the distance between the two.
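
A toy version of this test on made-up 2-D points (not real embeddings; every coordinate below is invented for illustration) shows the intuition behind the prediction: with a tight core and two outliers pointing in different directions, the brute-force argmax over all pairs lands on the outlier pair.

```python
import itertools
import math

# Hypothetical 2-D data: a tight core near (0.37, 0.0) plus two outliers.
core = [(0.37 + 0.001 * k, 0.001 * (k % 3)) for k in range(10)]
outliers = [(1.5, 0.0), (0.37, 1.2)]  # indices 10 and 11 in `points`
points = core + outliers

# Brute-force argmax over all unordered pairs (i < j).
best_pair = max(
    itertools.combinations(range(len(points)), 2),
    key=lambda ij: math.dist(points[ij[0]], points[ij[1]]),
)
best_dist = math.dist(points[best_pair[0]], points[best_pair[1]])
print(best_pair, round(best_dist, 4))
```

Core-to-outlier distances here top out near 1.2, while the two outliers are about 1.65 apart, so the maximum is the outlier pair. If both outliers sat on the same side of the core, the argmax could instead pair one outlier with a core point.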

## Parameters

In [10]:
# Model to analyze
MODEL_NAME = "Qwen3-4B-Instruct-2507"

## Imports

In [11]:
import torch
from safetensors.torch import load_file
from pathlib import Path

## Device Detection

In [12]:
# Detect available device
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

print(f"Using device: {device}")

Using device: mps


## Load Data

In [13]:
# Load W
W_path = Path(f"../tensors/{MODEL_NAME}/W.safetensors")
W = load_file(W_path)["W"].to(torch.float32)

print(f"Loaded W from {W_path}")
print(f"  Shape: {W.shape}")

# Load cluster data from 1.6a
cluster_path = Path(f"../tensors/{MODEL_NAME}/1.6a_cluster_mask.safetensors")
cluster_data = load_file(cluster_path)

cluster_mask = cluster_data["cluster_mask"].to(torch.bool)
cluster_token_ids = cluster_data["cluster_token_ids"].to(torch.int64)
centroid = cluster_data["centroid"].to(torch.float32)
n_cluster = cluster_data["n_cluster"].item()

print(f"\nLoaded cluster from {cluster_path}")
print(f"  Cluster size: {n_cluster:,} tokens")
print(f"  Centroid norm: {centroid.norm().item():.8f}")

Loaded W from ../tensors/Qwen3-4B-Instruct-2507/W.safetensors
  Shape: torch.Size([151936, 2560])

Loaded cluster from ../tensors/Qwen3-4B-Instruct-2507/1.6a_cluster_mask.safetensors
  Cluster size: 2,248 tokens
  Centroid norm: 0.37091014


## Extract Cluster Embeddings

In [14]:
# Get cluster embeddings
W_cluster = W[cluster_mask]

print(f"Extracted {W_cluster.shape[0]:,} cluster embeddings")
print(f"  Dimensionality: {W_cluster.shape[1]:,}")

Extracted 2,248 cluster embeddings
  Dimensionality: 2,560
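
The later lookup `cluster_token_ids[i]` assumes that row `i` of `W_cluster` is token `cluster_token_ids[i]`. Boolean-mask indexing returns rows in ascending index order, so this holds exactly when `cluster_token_ids` is stored sorted ascending. That is an assumption about the 1.6a file rather than something this notebook checks. A minimal sketch of the check on a tiny made-up matrix (`W_toy`, `mask_toy` and the ID tensors are hypothetical stand-ins):

```python
import torch

# Hypothetical stand-ins for W, cluster_mask and cluster_token_ids.
W_toy = torch.arange(12, dtype=torch.float32).reshape(6, 2)
mask_toy = torch.tensor([False, True, False, True, True, False])
ids_sorted = torch.tensor([1, 3, 4])
ids_shuffled = torch.tensor([4, 1, 3])

# Boolean indexing returns the selected rows in ascending row order,
# i.e. the same order as mask.nonzero().
rows_from_mask = W_toy[mask_toy]
mask_order = mask_toy.nonzero(as_tuple=True)[0]

# Integer indexing follows the order of the index tensor itself.
same_sorted = torch.equal(rows_from_mask, W_toy[ids_sorted])
same_shuffled = torch.equal(rows_from_mask, W_toy[ids_shuffled])

# The equivalent guard on the real data would be:
#   torch.equal(cluster_mask.nonzero(as_tuple=True)[0], cluster_token_ids)
print(mask_order.tolist(), same_sorted, same_shuffled)
```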


## Compute Pairwise L2 Distances
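
The batching is driven by memory. Broadcasting `(rows, 1, d) - (1, n, d)` materializes the whole difference tensor before the norm is taken, so its size follows directly from n = 2,248 and d = 2,560 (both printed above) at 4 bytes per float32 element. A quick estimate (the helper name `diff_tensor_gib` is just for illustration):

```python
def diff_tensor_gib(rows: int, n: int, d: int, bytes_per_elem: int = 4) -> float:
    """Size in GiB of a (rows, n, d) tensor with the given element size."""
    return rows * n * d * bytes_per_elem / 2**30


n, d = 2248, 2560
full = diff_tensor_gib(n, n, d)       # all rows at once
batched = diff_tensor_gib(256, n, d)  # one BATCH_SIZE = 256 slice
print(f"full: {full:.1f} GiB, per batch: {batched:.2f} GiB")
```

All rows at once would need about 48 GiB; even a single 256-row batch allocates roughly 5.5 GiB, so a smaller `BATCH_SIZE` is the first knob to turn on a memory-limited device.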

In [15]:
print("Computing pairwise L2 distances (batched to avoid OOM)...\n")

# Move to device for computation
W_cluster_device = W_cluster.to(device)

# Rows per batch. Computing all rows at once would materialize the full
# (n_cluster, n_cluster, d) float32 difference tensor, about 48 GiB.
BATCH_SIZE = 256

# Track global maximum
global_max_l2 = -1.0
global_i = -1
global_j = -1

with torch.no_grad():
    # Process in batches of rows
    for batch_start in range(0, n_cluster, BATCH_SIZE):
        batch_end = min(batch_start + BATCH_SIZE, n_cluster)
        
        # Get batch of tokens: shape (batch_size, d)
        batch_tokens = W_cluster_device[batch_start:batch_end]
        
        # Compute differences with ALL tokens: (batch_size, 1, d) - (1, n_cluster, d)
        # This creates (batch_size, n_cluster, d) instead of (n_cluster, n_cluster, d)
        diffs = batch_tokens.unsqueeze(1) - W_cluster_device.unsqueeze(0)
        
        # Per-pair L2 distance, sqrt(sum(diffs**2)) over the feature axis: shape (batch_size, n_cluster)
        l2_batch = diffs.norm(dim=2)
        
        # Find max in this batch
        batch_max = l2_batch.max().item()
        
        if batch_max > global_max_l2:
            # Update global maximum
            global_max_l2 = batch_max
            
            # Find indices within batch
            flat_idx = l2_batch.argmax().item()
            local_i = flat_idx // n_cluster
            local_j = flat_idx % n_cluster
            
            # Convert to global indices
            global_i = batch_start + local_i
            global_j = local_j
        
        print(f"  Processed rows {batch_start:4d}-{batch_end:4d}, batch max: {batch_max:.8f}")

print(f"\n✓ Computed pairwise L2 distances in batches of {BATCH_SIZE}")
print(f"  Global maximum: {global_max_l2:.8f} at ({global_i}, {global_j})")

Computing pairwise L2 distances (batched to avoid OOM)...

  Processed rows    0- 256, batch max: 1.60828805
  Processed rows  256- 512, batch max: 1.16175008
  Processed rows  512- 768, batch max: 1.16141784
  Processed rows  768-1024, batch max: 1.16141856
  Processed rows 1024-1280, batch max: 1.19637430
  Processed rows 1280-1536, batch max: 1.25368178
  Processed rows 1536-1792, batch max: 1.16413260
  Processed rows 1792-2048, batch max: 1.21384871
  Processed rows 2048-2248, batch max: 1.16141760

✓ Computed pairwise L2 distances in batches of 256
  Global maximum: 1.60828805 at (14, 17)
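
A possible lower-memory alternative, not what this notebook ran: expand ‖x − y‖² = ‖x‖² + ‖y‖² − 2 x·y, so each batch needs only a `(batch, n)` matrix product instead of a `(batch, n, d)` difference tensor. `torch.cdist` with its default `compute_mode` switches to this matrix-multiply form for larger inputs. The sketch below (the function name `max_pairwise_l2`, the sizes and the seed are arbitrary choices) checks the expansion against direct differences on random data. The expansion subtracts nearly equal numbers when two points are close, so it is less accurate than direct differences for tiny distances, though the maximum sought here is far from that regime.

```python
import torch


def max_pairwise_l2(X: torch.Tensor, batch_size: int = 256):
    """Return (max_distance, i, j) over all pairs of rows of X, using
    ||x - y||^2 = ||x||^2 + ||y||^2 - 2 x.y one row-batch at a time."""
    n = X.shape[0]
    sq_norms = (X * X).sum(dim=1)  # ||x||^2 for every row, shape (n,)
    best, best_i, best_j = -1.0, -1, -1
    for start in range(0, n, batch_size):
        B = X[start:start + batch_size]  # (b, d)
        # Squared distances for this batch, shape (b, n); no (b, n, d) tensor
        sq = sq_norms[start:start + batch_size, None] + sq_norms[None, :] - 2 * (B @ X.T)
        dists = sq.clamp_min(0).sqrt()  # rounding can push tiny values below 0
        flat = int(dists.argmax())  # index into the flattened (b, n) matrix
        val = dists.flatten()[flat].item()
        if val > best:
            best, best_i, best_j = val, start + flat // n, flat % n
    return best, best_i, best_j


torch.manual_seed(0)
X = torch.randn(300, 64, dtype=torch.float64)  # arbitrary random data
best, i, j = max_pairwise_l2(X, batch_size=64)

# Reference: direct broadcast differences, affordable at this small size.
ref = (X[:, None, :] - X[None, :, :]).norm(dim=2)
print(f"expansion: {best:.10f}  direct: {ref.max().item():.10f}  pair: ({i}, {j})")
```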


## Find Maximum L2 Distance

In [16]:
print("\nExtracting result...\n")

# The global_i and global_j are already computed from the batched loop
max_l2 = global_max_l2
i = global_i
j = global_j

# Get actual token IDs
token_i = cluster_token_ids[i].item()
token_j = cluster_token_ids[j].item()

print(f"Maximum L2 distance: {max_l2:.8f}")
print(f"  Between tokens: {token_i} and {token_j}")
print(f"  (Cluster indices: {i} and {j})")


Extracting result...

Maximum L2 distance: 1.60828805
  Between tokens: 48494 and 71473
  (Cluster indices: 14 and 17)
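
The conversion from `l2_batch.argmax()` back to a pair works because `argmax` without `dim` indexes the flattened, row-major tensor, so integer division by `n_cluster` gives the row within the batch and the remainder gives the column. The distance matrix is symmetric, so the same maximum also appears at (17, 14); the notebook reports the first occurrence in row-major order. A toy check with made-up numbers (`batch` and `batch_start` are hypothetical):

```python
# Hypothetical 3x4 "batch" of distances: 3 rows of a batch, 4 cluster columns.
n_cols = 4
batch = [
    [0.0, 0.2, 0.9, 0.1],
    [0.2, 0.0, 0.3, 1.4],  # maximum 1.4 at (row 1, col 3)
    [0.9, 0.3, 0.0, 0.5],
]
batch_start = 8  # hypothetical offset of this batch within the cluster

flat = [v for row in batch for v in row]  # row-major flattening
flat_idx = max(range(len(flat)), key=flat.__getitem__)
local_i, local_j = divmod(flat_idx, n_cols)
global_i, global_j = batch_start + local_i, local_j
print(flat_idx, (local_i, local_j), (global_i, global_j))
```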


## Test the Hypothesis

In [17]:
print("\n" + "="*60)
print("HYPOTHESIS TEST")
print("="*60)
print()

# Get the two tokens' embeddings
t_i = W[token_i]
t_j = W[token_j]

# Compute norms
norm_i = t_i.norm().item()
norm_j = t_j.norm().item()
norm_centroid = centroid.norm().item()

# Compute distances from centroid
dist_i = (t_i - centroid).norm().item()
dist_j = (t_j - centroid).norm().item()

# Compute L2 distance between the two tokens
dist_ij = (t_i - t_j).norm().item()

print(f"Token {token_i}:")
print(f"  Norm: {norm_i:.8f}")
print(f"  Distance from centroid: {dist_i:.8f}")
print()
print(f"Token {token_j}:")
print(f"  Norm: {norm_j:.8f}")
print(f"  Distance from centroid: {dist_j:.8f}")
print()
print("Distance between the two tokens:")
print(f"  L2({token_i}, {token_j}): {dist_ij:.8f}")
print()
print("Centroid:")
print(f"  Norm: {norm_centroid:.8f}")
print()
print("="*60)
print()

# Order the pair by norm. Both tokens are outliers; "outlier" is simply the
# one with the larger norm and "other" the one with the smaller norm.
if norm_i > norm_j:
    outlier_id, outlier_norm, outlier_dist = token_i, norm_i, dist_i
    other_id, other_norm, other_dist = token_j, norm_j, dist_j
else:
    outlier_id, outlier_norm, outlier_dist = token_j, norm_j, dist_j
    other_id, other_norm, other_dist = token_i, norm_i, dist_i

print("RESULT:")
print(f"  Outlier: token {outlier_id}")
print(f"    Norm: {outlier_norm:.8f} ({outlier_norm/norm_centroid:.2f}× centroid)")
print(f"    Distance from centroid: {outlier_dist:.8f}")
print()
print(f"  Other outlier: token {other_id}")
print(f"    Norm: {other_norm:.8f} ({other_norm/norm_centroid:.2f}× centroid)")
print(f"    Distance from centroid: {other_dist:.8f}")
print()
# "Far" is an ad hoc threshold: farther from the centroid than twice the centroid's norm
both_far = (outlier_dist > 2 * norm_centroid) and (other_dist > 2 * norm_centroid)
print(f"Both tokens are far from centroid: {both_far}")
print(f"Distance between them vs their distance from centroid:")
print(f"  {dist_ij:.8f} vs ~{(dist_i + dist_j)/2:.8f}")


HYPOTHESIS TEST

Token 48494:
  Norm: 1.16991878
  Distance from centroid: 1.10977745

Token 71473:
  Norm: 1.21912193
  Distance from centroid: 1.16093612

Distance between the two tokens:
  L2(48494, 71473): 1.60828757

Centroid:
  Norm: 0.37091014


RESULT:
  Outlier: token 71473
    Norm: 1.21912193 (3.29× centroid)
    Distance from centroid: 1.16093612

  Other outlier: token 48494
    Norm: 1.16991878 (3.15× centroid)
    Distance from centroid: 1.10977745

Both tokens are far from centroid: True
Distance between them vs their distance from centroid:
  1.60828757 vs ~1.13535678
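
The comparison printed last is against the *average* centroid distance. A sharper reference point, using only the three printed distances: if the two displacement vectors from the centroid were exactly orthogonal, Pythagoras would give L2(t_i, t_j) = √(a² + b²), with a and b the two centroid distances.

```python
import math

a = 1.10977745  # distance from centroid to token 48494 (printed above)
b = 1.16093612  # distance from centroid to token 71473 (printed above)
c = 1.60828757  # distance between the two tokens (printed above)

orthogonal_prediction = math.sqrt(a**2 + b**2)
print(f"sqrt(a^2 + b^2) = {orthogonal_prediction:.6f}, observed c = {c:.6f}")
print(f"ratio c / prediction = {c / orthogonal_prediction:.6f}")
```

The observed distance exceeds the orthogonal prediction by only about 0.14%, consistent with an angle at the centroid just over 90°.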


## Triangle Geometry

In [18]:
print("\n" + "="*60)
print("TRIANGLE GEOMETRY")
print("="*60)
print()

# We have a triangle: centroid, token_i, token_j
# Sides:
#   a = dist_i (centroid to token_i)
#   b = dist_j (centroid to token_j)  
#   c = dist_ij (token_i to token_j)

a = dist_i
b = dist_j
c = dist_ij

print("Triangle sides:")
print(f"  Centroid to {token_i}: {a:.8f}")
print(f"  Centroid to {token_j}: {b:.8f}")
print(f"  {token_i} to {token_j}: {c:.8f}")
print()

# Compute interior angles using law of cosines
# Angle at centroid (between the two tokens as seen from centroid)
cos_at_centroid = (a**2 + b**2 - c**2) / (2 * a * b)
angle_at_centroid = torch.acos(torch.tensor(cos_at_centroid))
angle_at_centroid_deg = torch.rad2deg(angle_at_centroid).item()

# Angle at token_i (between centroid and token_j as seen from token_i)
cos_at_i = (a**2 + c**2 - b**2) / (2 * a * c)
angle_at_i = torch.acos(torch.tensor(cos_at_i))
angle_at_i_deg = torch.rad2deg(angle_at_i).item()

# Angle at token_j (between centroid and token_i as seen from token_j)
cos_at_j = (b**2 + c**2 - a**2) / (2 * b * c)
angle_at_j = torch.acos(torch.tensor(cos_at_j))
angle_at_j_deg = torch.rad2deg(angle_at_j).item()

print("Interior angles:")
print(f"  At centroid: {angle_at_centroid_deg:.2f}°")
print(f"  At token {token_i}: {angle_at_i_deg:.2f}°")
print(f"  At token {token_j}: {angle_at_j_deg:.2f}°")
print(f"  Sum: {angle_at_centroid_deg + angle_at_i_deg + angle_at_j_deg:.2f}° (should be 180°)")
print()

# Alternative computation using dot products (sanity check)
# Angle at centroid using dot product
vec_to_i = t_i - centroid
vec_to_j = t_j - centroid
cos_dot = (vec_to_i @ vec_to_j) / (vec_to_i.norm() * vec_to_j.norm())
angle_dot = torch.acos(cos_dot)
angle_dot_deg = torch.rad2deg(angle_dot).item()

print(f"Verification (angle at centroid via dot product): {angle_dot_deg:.2f}°")
print()
print("="*60)


TRIANGLE GEOMETRY

Triangle sides:
  Centroid to 48494: 1.10977745
  Centroid to 71473: 1.16093612
  48494 to 71473: 1.60828757

Interior angles:
  At centroid: 90.16°
  At token 48494: 46.21°
  At token 71473: 43.63°
  Sum: 180.00° (should be 180°)

Verification (angle at centroid via dot product): 90.16°
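
For scale, a baseline under an explicit isotropy assumption rather than a claim about this model: two independent, isotropic random directions in d dimensions are nearly orthogonal, and the cosine of the angle between them has standard deviation about 1/√d. For d = 2,560 that is roughly 0.02, i.e. angles of about 90° ± 1.1°. The sketch below draws Gaussian vectors with the standard library only; the helper name `random_angle_deg`, the sample count and the seed are arbitrary choices.

```python
import math
import random
import statistics


def random_angle_deg(rng: random.Random, d: int) -> float:
    """Angle in degrees between two independent standard-Gaussian vectors in R^d."""
    u = [rng.gauss(0.0, 1.0) for _ in range(d)]
    v = [rng.gauss(0.0, 1.0) for _ in range(d)]
    dot = sum(x * y for x, y in zip(u, v))
    cos = dot / (math.sqrt(sum(x * x for x in u)) * math.sqrt(sum(y * y for y in v)))
    return math.degrees(math.acos(max(-1.0, min(1.0, cos))))  # clamp for safety


rng = random.Random(0)
d = 2560  # embedding dimensionality printed above
angles = [random_angle_deg(rng, d) for _ in range(200)]
print(f"mean {statistics.mean(angles):.2f} deg, stdev {statistics.stdev(angles):.2f} deg")
print(f"1/sqrt(d) in degrees: {math.degrees(1 / math.sqrt(d)):.2f}")
```

By this yardstick, the measured 90.16° lies well inside the spread expected for two unrelated directions in 2,560 dimensions.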



## Summary Statistics

In [19]:
# Compute all norms and distances for context
all_norms = W_cluster.norm(dim=1)
all_distances = (W_cluster - centroid).norm(dim=1)

print("\nCluster statistics:")
print(f"  Norms:")
print(f"    Min: {all_norms.min().item():.8f}")
print(f"    Max: {all_norms.max().item():.8f}")
print(f"    Median: {all_norms.median().item():.8f}")
print(f"    Mean: {all_norms.mean().item():.8f}")
print()
print(f"  Distances from centroid:")
print(f"    Min: {all_distances.min().item():.8f}")
print(f"    Max: {all_distances.max().item():.8f}")
print(f"    Median: {all_distances.median().item():.8f}")
print(f"    Mean: {all_distances.mean().item():.8f}")


Cluster statistics:
  Norms:
    Min: 0.37029281
    Max: 1.21912193
    Median: 0.37091675
    Mean: 0.37223831

  Distances from centroid:
    Min: 0.00125120
    Max: 1.16093612
    Median: 0.00125222
    Mean: 0.00374321
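
The per-batch maxima printed during the pairwise computation carry more information than the single argmax. For any cluster token r with centroid c, the triangle inequality gives L2(r, j) ≤ ‖r − c‖ + ‖j − c‖ ≤ ‖r − c‖ + 1.16093612, the maximum centroid distance just printed. So the row attaining a batch maximum must lie at least (batch max − 1.16093612) from the centroid. The batch maxima were computed on MPS and the summary on CPU; the two agree to within about 1e-6 (1.60828805 vs 1.60828757), which does not matter at this scale. The 10× median threshold below is an arbitrary, illustrative choice.

```python
# Per-batch maxima (rows -> max L2 to any cluster token), copied from the output above.
batch_max = {
    "0-256": 1.60828805,
    "256-512": 1.16175008,
    "512-768": 1.16141784,
    "768-1024": 1.16141856,
    "1024-1280": 1.19637430,
    "1280-1536": 1.25368178,
    "1536-1792": 1.16413260,
    "1792-2048": 1.21384871,
    "2048-2248": 1.16141760,
}
max_centroid_dist = 1.16093612     # max distance from centroid (summary above)
median_centroid_dist = 0.00125222  # median distance from centroid (summary above)

# Lower bound on the centroid distance of the row that attains each batch maximum.
lower_bounds = {rows: m - max_centroid_dist for rows, m in batch_max.items()}

# Arbitrary, illustrative threshold: 10x the median centroid distance.
threshold = 10 * median_centroid_dist
flagged = [rows for rows, lb in lower_bounds.items() if lb > threshold]

for rows, lb in lower_bounds.items():
    mark = "  <- flagged" if rows in flagged else ""
    print(f"rows {rows:>9}: centroid distance >= {lb:.8f}{mark}")
```

Under these bounds, besides the pair in rows 0-256, at least three other tokens (one each in rows 1024-1280, 1280-1536 and 1792-2048) lie at least about 0.035, 0.093 and 0.053 from the centroid, well outside the core. Listing every index where `all_distances` exceeds a chosen threshold would enumerate them directly.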
