# 1.7a: Cluster Outlier Detection

**Hypothesis:** The cluster contains a distant outlier: a token on the same ray as the core cluster but with a much larger norm.

**Evidence:** Median distance from centroid is 0.00089653, but max is 1.16091752.

**Prediction:** The largest pairwise L∞ distance will identify two tokens, one of which has ||t|| >> ||centroid||.

**Test:** Compute all pairwise L∞ distances, find argmax, check norms.
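
To make the prediction concrete, the following is a minimal sketch in plain Python using toy 3-dimensional vectors rather than the real embeddings (`direction`, `core`, and `outlier` are illustrative names, not from the notebook). If two tokens lie on the same ray, t = a·u and s = b·u, their L∞ distance is |a - b| · max_k |u_k|, so a token with a much larger norm should dominate the pairwise maximum.

```python
def linf_distance(x, y):
    """Chebyshev (L-infinity) distance: the largest absolute coordinate difference."""
    return max(abs(a - b) for a, b in zip(x, y))


def scale(u, a):
    """Return the vector a * u."""
    return [a * c for c in u]


# Toy unit direction (hypothetical, chosen only for illustration)
direction = [0.6, -0.8, 0.0]

core = scale(direction, 0.37)     # norm 0.37, similar in scale to the centroid norm
outlier = scale(direction, 1.20)  # same ray, much larger norm

d = linf_distance(core, outlier)
expected = abs(1.20 - 0.37) * max(abs(c) for c in direction)

print(f"L-inf distance:      {d:.6f}")
print(f"|a - b| * max|u_k|:  {expected:.6f}")
```

The two printed values agree, which is the same-ray identity the prediction relies on. It says nothing yet about whether the real tokens actually share a ray.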

## Parameters

In [7]:
# Model to analyze
MODEL_NAME = "Qwen3-4B-Instruct-2507"

## Imports

In [8]:
import torch
from safetensors.torch import load_file
from pathlib import Path

## Device Detection

In [9]:
# Detect available device
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

print(f"Using device: {device}")

Using device: mps


## Load Data

In [10]:
# Load W
W_path = Path(f"../tensors/{MODEL_NAME}/W.safetensors")
W = load_file(W_path)["W"].to(torch.float32)

print(f"Loaded W from {W_path}")
print(f"  Shape: {W.shape}")

# Load cluster data from 1.6a
cluster_path = Path(f"../tensors/{MODEL_NAME}/1.6a_cluster_mask.safetensors")
cluster_data = load_file(cluster_path)

cluster_mask = cluster_data["cluster_mask"].to(torch.bool)
cluster_token_ids = cluster_data["cluster_token_ids"].to(torch.int64)
centroid = cluster_data["centroid"].to(torch.float32)
n_cluster = cluster_data["n_cluster"].item()

print(f"\nLoaded cluster from {cluster_path}")
print(f"  Cluster size: {n_cluster:,} tokens")
print(f"  Centroid norm: {centroid.norm().item():.8f}")

Loaded W from ../tensors/Qwen3-4B-Instruct-2507/W.safetensors
  Shape: torch.Size([151936, 2560])

Loaded cluster from ../tensors/Qwen3-4B-Instruct-2507/1.6a_cluster_mask.safetensors
  Cluster size: 2,248 tokens
  Centroid norm: 0.37091014


## Extract Cluster Embeddings

In [11]:
# Get cluster embeddings
W_cluster = W[cluster_mask]

print(f"Extracted {W_cluster.shape[0]:,} cluster embeddings")
print(f"  Dimensionality: {W_cluster.shape[1]:,}")

Extracted 2,248 cluster embeddings
  Dimensionality: 2,560


## Compute Pairwise L∞ Distances

In [12]:
print("Computing pairwise L∞ distances (batched to avoid OOM)...\n")

# Move to device for computation
W_cluster_device = W_cluster.to(device)

# Batch size for computing distances (the full (n, n, d) float32 difference tensor would need about 48 GiB)
BATCH_SIZE = 256

# Track global maximum
global_max_linf = -1.0
global_i = -1
global_j = -1

with torch.no_grad():
    # Process in batches of rows
    for batch_start in range(0, n_cluster, BATCH_SIZE):
        batch_end = min(batch_start + BATCH_SIZE, n_cluster)
        batch_size = batch_end - batch_start
        
        # Get batch of tokens: shape (batch_size, d)
        batch_tokens = W_cluster_device[batch_start:batch_end]
        
        # Compute differences with ALL tokens: (batch_size, 1, d) - (1, n_cluster, d)
        # This creates (batch_size, n_cluster, d) instead of (n_cluster, n_cluster, d)
        diffs = batch_tokens.unsqueeze(1) - W_cluster_device.unsqueeze(0)
        
        # L∞ for this batch: (batch_size, n_cluster)
        linf_batch = diffs.abs().max(dim=2).values
        
        # Find max in this batch
        batch_max = linf_batch.max().item()
        
        if batch_max > global_max_linf:
            # Update global maximum
            global_max_linf = batch_max
            
            # Find indices within batch
            flat_idx = linf_batch.argmax().item()
            local_i = flat_idx // n_cluster
            local_j = flat_idx % n_cluster
            
            # Convert to global indices
            global_i = batch_start + local_i
            global_j = local_j
        
        print(f"  Processed rows {batch_start:4d}-{batch_end:4d}, batch max: {batch_max:.8f}")

print(f"\n✓ Computed pairwise L∞ distances in batches of {BATCH_SIZE}")
print(f"  Global maximum: {global_max_linf:.8f} at ({global_i}, {global_j})")

Computing pairwise L∞ distances (batched to avoid OOM)...

  Processed rows    0- 256, batch max: 0.13134766
  Processed rows  256- 512, batch max: 0.08604431
  Processed rows  512- 768, batch max: 0.08551025
  Processed rows  768-1024, batch max: 0.08551025
  Processed rows 1024-1280, batch max: 0.08551407
  Processed rows 1280-1536, batch max: 0.10180664
  Processed rows 1536-1792, batch max: 0.08576202
  Processed rows 1792-2048, batch max: 0.08660889
  Processed rows 2048-2248, batch max: 0.08551025

✓ Computed pairwise L∞ distances in batches of 256
  Global maximum: 0.13134766 at (14, 17)
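
The batching above is motivated by memory. As a rough check of the sizes involved, assuming float32 (4 bytes per element) and counting only the `diffs` tensor, the arithmetic can be sketched as follows. `diff_tensor_bytes` is a helper introduced here for illustration.

```python
def diff_tensor_bytes(rows, n, d, bytes_per_elem=4):
    """Size in bytes of a (rows, n, d) tensor of pairwise differences."""
    return rows * n * d * bytes_per_elem


# Values taken from the notebook output above
N_CLUSTER, DIM, BATCH = 2248, 2560, 256
GIB = 2**30

full = diff_tensor_bytes(N_CLUSTER, N_CLUSTER, DIM)   # all pairs at once
batched = diff_tensor_bytes(BATCH, N_CLUSTER, DIM)    # one batch of rows

print(f"full (n, n, d):      {full / GIB:.2f} GiB")
print(f"batched (256, n, d): {batched / GIB:.2f} GiB")
```

This gives about 48.19 GiB for the full tensor and 5.49 GiB per batch of 256. Since `diffs.abs()` materializes a second tensor of the same shape, the actual peak per batch is probably closer to twice that figure.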


## Find Maximum L∞ Distance

In [13]:
print("\nExtracting result...\n")

# The global_i and global_j are already computed from the batched loop
max_linf = global_max_linf
i = global_i
j = global_j

# Get actual token IDs
token_i = cluster_token_ids[i].item()
token_j = cluster_token_ids[j].item()

print(f"Maximum L∞ distance: {max_linf:.8f}")
print(f"  Between tokens: {token_i} and {token_j}")
print(f"  (Cluster indices: {i} and {j})")


Extracting result...

Maximum L∞ distance: 0.13134766
  Between tokens: 48494 and 71473
  (Cluster indices: 14 and 17)
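
The pair reported above can be cross-checked without forming any pairwise tensor. The largest pairwise L∞ distance equals the largest per-coordinate spread, max_k (max_i x_ik - min_i x_ik), because the maximum over pairs and the maximum over coordinates can be swapped. The sketch below shows this on random toy data in plain Python and compares it with brute force. It is an alternative to the notebook's batched loop, not the method used there, and `max_pairwise_linf` is a name introduced here.

```python
import random


def max_pairwise_linf(rows):
    """Return (distance, i, j, k): largest pairwise L-inf distance,
    the pair of row indices attaining it, and the coordinate k."""
    n, dim = len(rows), len(rows[0])
    best = (-1.0, -1, -1, -1)
    for k in range(dim):
        i_max = max(range(n), key=lambda i: rows[i][k])
        i_min = min(range(n), key=lambda i: rows[i][k])
        spread = rows[i_max][k] - rows[i_min][k]
        if spread > best[0]:
            best = (spread, i_max, i_min, k)
    return best


def brute_force(rows):
    """O(n^2 * d) reference: maximum L-inf distance over all pairs."""
    return max(
        max(abs(a - b) for a, b in zip(x, y))
        for x in rows
        for y in rows
    )


random.seed(0)
toy = [[random.gauss(0, 1) for _ in range(8)] for _ in range(50)]

fast = max_pairwise_linf(toy)
slow = brute_force(toy)
print(f"per-coordinate spread: {fast[0]:.6f} (rows {fast[1]} and {fast[2]}, coordinate {fast[3]})")
print(f"brute force:           {slow:.6f}")
```

With tensors, the same idea could presumably be expressed with `W_cluster.max(dim=0)` and `W_cluster.min(dim=0)`, which return both values and row indices. If several pairs tie for the maximum, the pair found this way may differ from the one the batched loop reports.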


## Test the Hypothesis

In [14]:
print("\n" + "="*60)
print("HYPOTHESIS TEST")
print("="*60)
print()

# Get the two tokens' embeddings
t_i = W[token_i]
t_j = W[token_j]

# Compute norms
norm_i = t_i.norm().item()
norm_j = t_j.norm().item()
norm_centroid = centroid.norm().item()

# Compute distances from centroid
dist_i = (t_i - centroid).norm().item()
dist_j = (t_j - centroid).norm().item()

print(f"Token {token_i}:")
print(f"  Norm: {norm_i:.8f}")
print(f"  Distance from centroid: {dist_i:.8f}")
print()
print(f"Token {token_j}:")
print(f"  Norm: {norm_j:.8f}")
print(f"  Distance from centroid: {dist_j:.8f}")
print()
print("Centroid:")
print(f"  Norm: {norm_centroid:.8f}")
print()
print("="*60)
print()

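# Treat the larger-norm token of the pair as the outlier.
# The other token is labelled "core" by construction, even if its own norm is far from the centroid's.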
if norm_i > norm_j:
    outlier_id = token_i
    outlier_norm = norm_i
    outlier_dist = dist_i
    core_id = token_j
    core_norm = norm_j
    core_dist = dist_j
else:
    outlier_id = token_j
    outlier_norm = norm_j
    outlier_dist = dist_j
    core_id = token_i
    core_norm = norm_i
    core_dist = dist_i

print("RESULT:")
print(f"  Outlier: token {outlier_id}")
print(f"    Norm: {outlier_norm:.8f} ({outlier_norm/norm_centroid:.2f}× centroid)")
print(f"    Distance from centroid: {outlier_dist:.8f}")
print()
print(f"  Core token: token {core_id}")
print(f"    Norm: {core_norm:.8f} ({core_norm/norm_centroid:.2f}× centroid)")
print(f"    Distance from centroid: {core_dist:.8f}")
print()
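# Heuristic: "much larger" is taken to mean more than 2× the centroid norm.
# Direction (the "same ray" part of the hypothesis) is not checked here.
print(f"Hypothesis supported: {outlier_norm > 2 * norm_centroid}")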


HYPOTHESIS TEST

Token 48494:
  Norm: 1.16991878
  Distance from centroid: 1.10977745

Token 71473:
  Norm: 1.21912193
  Distance from centroid: 1.16093612

Centroid:
  Norm: 0.37091014


RESULT:
  Outlier: token 71473
    Norm: 1.21912193 (3.29× centroid)
    Distance from centroid: 1.16093612

  Core token: token 48494
    Norm: 1.16991878 (3.15× centroid)
    Distance from centroid: 1.10977745

Hypothesis supported: True
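
The check above compares norms only, and both tokens in the pair have norms of roughly 3× the centroid norm, so the "Core token" label reflects the if/else ordering rather than membership in the core. The "same ray" part of the hypothesis can be probed from the printed numbers alone: by the law of cosines, ||t - c||² = ||t||² + ||c||² - 2 ||t|| ||c|| cos θ. The sketch below solves for cos θ, assuming the printed norms and distances are accurate to the digits shown. `cosine_from_norms` is a helper introduced here.

```python
def cosine_from_norms(norm_t, norm_c, dist):
    """Cosine of the angle between t and c, given ||t||, ||c|| and ||t - c||."""
    return (norm_t**2 + norm_c**2 - dist**2) / (2 * norm_t * norm_c)


# Values copied from the HYPOTHESIS TEST output above
NORM_CENTROID = 0.37091014

cos_71473 = cosine_from_norms(1.21912193, NORM_CENTROID, 1.16093612)
cos_48494 = cosine_from_norms(1.16991878, NORM_CENTROID, 1.10977745)

print(f"token 71473: cos(theta) = {cos_71473:.4f}")
print(f"token 48494: cos(theta) = {cos_48494:.4f}")

# For comparison: the distance a same-ray token of this norm would have
print(f"same-ray distance for 71473 would be {1.21912193 - NORM_CENTROID:.4f}")
```

A token on the centroid's ray would give a cosine near 1 and a distance of about 0.85. The printed values instead give cosines of roughly 0.31 and 0.32, which suggests that both tokens have large norms but point well away from the centroid direction. Computing the cosine directly from `t_i`, `t_j`, and `centroid` would be the more reliable confirmation.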


## Summary Statistics

In [15]:
# Compute all norms and distances for context
all_norms = W_cluster.norm(dim=1)
all_distances = (W_cluster - centroid).norm(dim=1)

print("\nCluster statistics:")
print("  Norms:")
print(f"    Min: {all_norms.min().item():.8f}")
print(f"    Max: {all_norms.max().item():.8f}")
print(f"    Median: {all_norms.median().item():.8f}")
print(f"    Mean: {all_norms.mean().item():.8f}")
print()
print("  Distances from centroid:")
print(f"    Min: {all_distances.min().item():.8f}")
print(f"    Max: {all_distances.max().item():.8f}")
print(f"    Median: {all_distances.median().item():.8f}")
print(f"    Mean: {all_distances.mean().item():.8f}")


Cluster statistics:
  Norms:
    Min: 0.37029281
    Max: 1.21912193
    Median: 0.37091675
    Mean: 0.37223831

  Distances from centroid:
    Min: 0.00125120
    Max: 1.16093612
    Median: 0.00125222
    Mean: 0.00374321
