# 03.2c: Investigate Jet - Cosine Similarity Analysis

**Goal:** Find the most extreme jet token and examine its nearest neighbors by cosine similarity.

We'll:
1. Load original gamma (uncentered) for cosine similarity
2. Find the jet token with largest L2 norm
3. Compute cosine similarity to all other tokens
4. Show the top 30 most similar tokens
5. Decode tokens to see semantic/syntactic patterns

**Key question:** Are jet tokens similar to each other in the original embedding space? Or are they just geometrically separated after centering?

## Parameters

In [1]:
TENSOR_DIR = "../data/tensors"
MODEL_NAME = "Qwen/Qwen3-4B-Instruct-2507"

# How many nearest neighbors to show
TOP_K = 30

## Imports

In [2]:
import torch
import pandas as pd
import numpy as np
from safetensors.torch import load_file
from pathlib import Path
from transformers import AutoTokenizer

print("Imports loaded successfully.")

Imports loaded successfully.


## Step 1: Load Original Gamma (Uncentered)

In [3]:
# Load original gamma (uncentered)
gamma_path = Path(TENSOR_DIR) / "gamma_qwen3_4b_instruct_2507.safetensors"
gamma = load_file(gamma_path)['gamma']

N, d = gamma.shape

print(f"Loaded γ (original, uncentered):")
print(f"  Tokens: {N:,}")
print(f"  Dimensions: {d:,}")
print(f"  Memory: {gamma.element_size() * gamma.nelement() / 1024**2:.1f} MB")

Loaded γ (original, uncentered):
  Tokens: 151,936
  Dimensions: 2,560
  Memory: 1483.8 MB


## Step 2: Load Jet Mask

In [4]:
# Load jet mask
jet_mask_path = Path(TENSOR_DIR) / "jet_mask.safetensors"
jet_mask = load_file(jet_mask_path)['jet_mask']

n_jet = jet_mask.sum().item()

print(f"Loaded jet mask:")
print(f"  Jet tokens: {n_jet:,} ({n_jet/N*100:.2f}%)")
print(f"  Bulk tokens: {(~jet_mask).sum().item():,} ({(~jet_mask).sum().item()/N*100:.2f}%)")

Loaded jet mask:
  Jet tokens: 3,055 (2.01%)
  Bulk tokens: 148,881 (97.99%)


## Step 3: Find Jet Token with Largest Norm

In [5]:
# Compute norms for all tokens
gamma_norms = gamma.norm(dim=1)

print(f"Norm statistics (all tokens):")
print(f"  Mean: {gamma_norms.mean().item():.6f} gamma units")
print(f"  Std: {gamma_norms.std().item():.6f} gamma units")
print(f"  Min: {gamma_norms.min().item():.6f} gamma units")
print(f"  Max: {gamma_norms.max().item():.6f} gamma units")
print()

# Get norms for jet tokens only
jet_norms = gamma_norms[jet_mask]
jet_token_ids = torch.where(jet_mask)[0]

print(f"Norm statistics (jet tokens only):")
print(f"  Mean: {jet_norms.mean().item():.6f} gamma units")
print(f"  Std: {jet_norms.std().item():.6f} gamma units")
print(f"  Min: {jet_norms.min().item():.6f} gamma units")
print(f"  Max: {jet_norms.max().item():.6f} gamma units")
print()

# Find jet token with largest norm
max_norm_idx_in_jet = jet_norms.argmax()
max_norm_token_id = jet_token_ids[max_norm_idx_in_jet].item()
max_norm_value = jet_norms[max_norm_idx_in_jet].item()

print(f"Jet token with largest norm:")
print(f"  Token ID: {max_norm_token_id}")
print(f"  Norm: {max_norm_value:.6f} gamma units")

Norm statistics (all tokens):
  Mean: 1.087283 gamma units
  Std: 0.168146 gamma units
  Min: 0.359538 gamma units
  Max: 1.605024 gamma units

Norm statistics (jet tokens only):
  Mean: 1.068991 gamma units
  Std: 0.090963 gamma units
  Min: 0.826530 gamma units
  Max: 1.432686 gamma units

Jet token with largest norm:
  Token ID: 56761
  Norm: 1.432686 gamma units


## Step 4: Load Tokenizer and Decode Target Token

In [6]:
print(f"Loading tokenizer: {MODEL_NAME}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
print(f"Tokenizer loaded. Vocab size: {len(tokenizer):,}\n")

# Decode target token
target_token_str = tokenizer.decode([max_norm_token_id])

print(f"Target token (largest norm in jet):")
print(f"  Token ID: {max_norm_token_id}")
print(f"  Token string: '{target_token_str}'")
print(f"  Norm: {max_norm_value:.6f} gamma units")

Loading tokenizer: Qwen/Qwen3-4B-Instruct-2507...
Tokenizer loaded. Vocab size: 151,669

Target token (largest norm in jet):
  Token ID: 56761
  Token string: '…”

'
  Norm: 1.432686 gamma units


## Step 5: Compute Cosine Similarity to All Tokens

In [7]:
print(f"Computing cosine similarity to all {N:,} tokens...\n")

# Get target vector
target_vector = gamma[max_norm_token_id]

# Compute cosine similarity: cos(θ) = (u·v) / (||u|| ||v||)
# Use torch's built-in cosine_similarity
cosine_similarities = torch.nn.functional.cosine_similarity(
    target_vector.unsqueeze(0),  # Shape: (1, 2560)
    gamma,                        # Shape: (151936, 2560)
    dim=1
)

print(f"Cosine similarity statistics:")
print(f"  Mean: {cosine_similarities.mean().item():.6f}")
print(f"  Std: {cosine_similarities.std().item():.6f}")
print(f"  Min: {cosine_similarities.min().item():.6f}")
print(f"  Max: {cosine_similarities.max().item():.6f}")
print(f"  (Max should be 1.0 for target token itself)")

Computing cosine similarity to all 151,936 tokens...

Cosine similarity statistics:
  Mean: 0.122188
  Std: 0.064655
  Min: -0.382815
  Max: 1.000002
  (Max should be 1.0 for target token itself)


## Step 6: Find Top K Most Similar Tokens

In [8]:
# Get top K most similar tokens (including target itself)
top_k_similarities, top_k_indices = torch.topk(cosine_similarities, TOP_K + 1)

# Exclude the target token itself (should be first with similarity = 1.0)
top_k_similarities = top_k_similarities[1:]
top_k_indices = top_k_indices[1:]

print(f"\nTop {TOP_K} most similar tokens to '{target_token_str}' (token {max_norm_token_id}):\n")

# Create dataframe
similar_tokens = []
for rank, (token_id, similarity) in enumerate(zip(top_k_indices.cpu().numpy(), top_k_similarities.cpu().numpy()), 1):
    token_str = tokenizer.decode([int(token_id)])
    is_jet = jet_mask[token_id].item()
    token_norm = gamma_norms[token_id].item()
    
    similar_tokens.append({
        'rank': rank,
        'token_id': int(token_id),
        'token_str': f"'{token_str}'",
        'cosine_sim': similarity,
        'norm': token_norm,
        'in_jet': 'Yes' if is_jet else 'No'
    })

similar_df = pd.DataFrame(similar_tokens)
print(similar_df.to_string(index=False))


Top 30 most similar tokens to '…”

' (token 56761):

 rank  token_id    token_str  cosine_sim     norm in_jet
    1     50179         '…”'    0.820233 1.283059     No
    2     60803     '…"\n\n'    0.820080 1.237508    Yes
    3     76058     '…)\n\n'    0.708508 1.404078    Yes
    4     44993     '….\n\n'    0.662396 1.267065    Yes
    5     55109         '…"'    0.658880 1.178809     No
    6      2879     '.”\n\n'    0.650727 1.004609    Yes
    7     76379     ',…\n\n'    0.647343 1.484372     No
    8      5434      '…\n\n'    0.641754 1.211141    Yes
    9     16218     '?”\n\n'    0.630575 1.152109    Yes
   10     65579    '…\n\n\n'    0.628381 1.300524    Yes
   11     24727     '!”\n\n'    0.613053 1.195868    Yes
   12     78900     '。”\n\n'    0.604772 1.050671    Yes
   13     72839         '…)'    0.603005 1.283452     No
   14     72229  '…\n\n\n\n'    0.599721 1.359151     No
   15     47486   '..."\n\n'    0.591448 1.208160    Yes
   16     79515     '.…\n\n'    0.

## Step 7: Analyze Jet Membership of Neighbors

In [9]:
n_jet_neighbors = (similar_df['in_jet'] == 'Yes').sum()
n_bulk_neighbors = (similar_df['in_jet'] == 'No').sum()

print(f"\nNeighborhood composition:")
print(f"  Jet neighbors: {n_jet_neighbors}/{TOP_K} ({n_jet_neighbors/TOP_K*100:.1f}%)")
print(f"  Bulk neighbors: {n_bulk_neighbors}/{TOP_K} ({n_bulk_neighbors/TOP_K*100:.1f}%)")
print()
print(f"Baseline (overall population):")
print(f"  Jet: {n_jet:,}/{N:,} ({n_jet/N*100:.1f}%)")
print(f"  Bulk: {(~jet_mask).sum().item():,}/{N:,} ({(~jet_mask).sum().item()/N*100:.1f}%)")
print()

# Compute enrichment
enrichment = (n_jet_neighbors / TOP_K) / (n_jet / N)
print(f"Jet enrichment in neighborhood: {enrichment:.2f}×")
if enrichment > 2:
    print(f"  → Strong clustering: jet tokens are much more similar to each other")
elif enrichment > 1.5:
    print(f"  → Moderate clustering: jet tokens show some similarity")
else:
    print(f"  → Weak clustering: jet tokens are not particularly similar to each other")


Neighborhood composition:
  Jet neighbors: 22/30 (73.3%)
  Bulk neighbors: 8/30 (26.7%)

Baseline (overall population):
  Jet: 3,055/151,936 (2.0%)
  Bulk: 148,881/151,936 (98.0%)

Jet enrichment in neighborhood: 36.47×
  → Strong clustering: jet tokens are much more similar to each other


## Step 8: Examine Token Strings for Patterns

In [12]:
print(f"\nToken string patterns in top {TOP_K} neighbors:\n")

# Group by common patterns
token_strings = [t['token_str'] for t in similar_tokens]

# Count patterns
has_newline = sum(1 for s in token_strings if '\n' in s or '\r' in s)
has_punctuation = sum(1 for s in token_strings if any(p in s for p in [';', ',', '.', ')', '(', '}', '{', ']', '[', '>', '<']))
has_whitespace_only = sum(1 for s in token_strings if s.strip() in ["''", "' '", "'  '"])

print(f"Pattern counts:")
print(f"  Contains newline: {has_newline}/{TOP_K} ({has_newline/TOP_K*100:.1f}%)")
print(f"  Contains punctuation: {has_punctuation}/{TOP_K} ({has_punctuation/TOP_K*100:.1f}%)")
print(f"  Contains quotes: {has_quotes}/{TOP_K} ({has_quotes/TOP_K*100:.1f}%)")
print(f"  Whitespace only: {has_whitespace_only}/{TOP_K} ({has_whitespace_only/TOP_K*100:.1f}%)")


Token string patterns in top 30 neighbors:

Pattern counts:
  Contains newline: 24/30 (80.0%)
  Contains punctuation: 13/30 (43.3%)
  Contains quotes: 30/30 (100.0%)
  Whitespace only: 0/30 (0.0%)


## Summary

We identified the jet token with the largest norm and examined its nearest neighbors by cosine similarity.

**Key findings:**
- Target token: ??? (ID ???, norm ???)
- Top {TOP_K} neighbors show ???% jet membership vs ???% baseline
- Enrichment: ???× (jet tokens cluster/don't cluster)
- Pattern: Neighbors are/aren't syntactically similar

**Interpretation:**
- High jet enrichment (>2×) + syntactic similarity → Jet is a real semantic cluster
- Low jet enrichment (<1.5×) → Jet is a geometric artifact of the PC4×5 slice
- Mixed results → Jet has substructure, some tokens genuinely similar, others not