# 03.1a: Identify and Extract Jet Tokens

**Goal:** Isolate the jet structure using geometric bounds and save it for further analysis.

From visual inspection of PC4×5×6 orthographic projections, we identified a jet-like structure extending from the main token cloud. We'll:

1. Load centered gamma and compute PCA
2. Apply rectangular bounds in PC4×5 space: **PC4 > 0.1, PC5 < -0.05**
3. Extract jet token embeddings (gamma_centered rows)
4. Save jet embeddings as safetensors for downstream analysis
5. Also save jet token IDs and mask for reference

This gives us a separate subset that we can analyze on its own: what are the jet's *own* principal axes? What is its internal structure?

## Parameters

In [1]:
TENSOR_DIR = "../data/tensors"

# Jet bounds (from visual inspection)
JET_PC4_MIN = 0.1
JET_PC5_MAX = -0.05

## Imports

In [2]:
import torch
import numpy as np
from safetensors.torch import load_file, save_file
from pathlib import Path

print("Imports loaded successfully.")

Imports loaded successfully.


## Step 1: Load Centered Gamma and Compute PCA

In [3]:
gamma_centered_path = Path(TENSOR_DIR) / "gamma_centered_qwen3_4b_instruct_2507.safetensors"
gamma_centered = load_file(gamma_centered_path)['gamma_centered']

N, d = gamma_centered.shape

print(f"Loaded γ' (gamma_centered):")
print(f"  Tokens: {N:,}")
print(f"  Dimensions: {d:,}")
print()

print("Computing covariance matrix...")
Cov = (gamma_centered.T @ gamma_centered) / (N - 1)

print(f"Computing eigendecomposition...")
eigenvalues, eigenvectors = torch.linalg.eigh(Cov)

# Sort descending (highest variance first)
eigenvalues = eigenvalues.flip(0)
eigenvectors = eigenvectors.flip(1)

print(f"\nTop 6 eigenvalues:")
for i in range(6):
    variance_explained = eigenvalues[i] / eigenvalues.sum()
    print(f"  PC{i+1}: λ = {eigenvalues[i].item():.6e} ({variance_explained.item()*100:.2f}% of variance)")

Loaded γ' (gamma_centered):
  Tokens: 151,936
  Dimensions: 2,560

Computing covariance matrix...
Computing eigendecomposition...

Top 6 eigenvalues:
  PC1: λ = 1.048719e-02 (0.94% of variance)
  PC2: λ = 3.177739e-03 (0.28% of variance)
  PC3: λ = 2.791374e-03 (0.25% of variance)
  PC4: λ = 2.616169e-03 (0.23% of variance)
  PC5: λ = 1.973001e-03 (0.18% of variance)
  PC6: λ = 1.805293e-03 (0.16% of variance)
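The covariance-then-`eigh` route above forms a d×d matrix (2,560 × 2,560 here). As a cross-check, here is a minimal sketch on synthetic data: for a centered matrix X, the eigenvalues of XᵀX/(N-1) equal the squared singular values of X divided by N-1, and the right singular vectors match the eigenvectors up to sign. The random matrix is a stand-in, not the real `gamma_centered`, and the function names are illustrative.

```python
import torch

def pca_via_eigh(X: torch.Tensor):
    """Eigen-decompose the sample covariance of an already-centered X, descending order."""
    N = X.shape[0]
    cov = (X.T @ X) / (N - 1)
    evals, evecs = torch.linalg.eigh(cov)   # eigh returns ascending order
    return evals.flip(0), evecs.flip(1)     # highest variance first

def pca_via_svd(X: torch.Tensor):
    """Same decomposition from the thin SVD of X, without forming the covariance."""
    N = X.shape[0]
    _, S, Vh = torch.linalg.svd(X, full_matrices=False)  # S is already descending
    return S**2 / (N - 1), Vh.T

torch.manual_seed(0)
# Synthetic data with clearly separated per-axis variances
X = torch.randn(2000, 16, dtype=torch.float64) * torch.linspace(3.0, 0.5, 16)
X = X - X.mean(dim=0)  # center, as gamma_centered already is

ev_eigh, V_eigh = pca_via_eigh(X)
ev_svd, V_svd = pca_via_svd(X)

print(torch.allclose(ev_eigh, ev_svd))
# Columns agree only up to a per-column sign, so compare absolute dot products
print(torch.allclose((V_eigh * V_svd).sum(dim=0).abs(), torch.ones(16, dtype=torch.float64)))
# Fraction of total variance per component, as in the table above
print((ev_eigh / ev_eigh.sum())[:3])
```

The covariance route keeps memory small for a tall matrix like this one (the thin SVD's U factor alone would be N×d), while the SVD route avoids squaring the condition number.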


## Step 2: Project onto PC4 and PC5

In [4]:
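# Extract PC4 and PC5 axes (0-indexed: columns 3 and 4).
# Eigenvector signs from eigh are arbitrary; the bounds below assume the signs obtained in this run.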
PC4_axis = eigenvectors[:, 3]
PC5_axis = eigenvectors[:, 4]

# Project all tokens onto PC4 and PC5
proj_PC4 = gamma_centered @ PC4_axis
proj_PC5 = gamma_centered @ PC5_axis

print(f"Projection statistics:")
print(f"  PC4: range [{proj_PC4.min().item():.4f}, {proj_PC4.max().item():.4f}], std = {proj_PC4.std().item():.4f}")
print(f"  PC5: range [{proj_PC5.min().item():.4f}, {proj_PC5.max().item():.4f}], std = {proj_PC5.std().item():.4f}")

Projection statistics:
  PC4: range [-0.1587, 0.3432], std = 0.0511
  PC5: range [-0.2012, 0.1963], std = 0.0444
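Eigenvectors are defined only up to sign, so `torch.linalg.eigh` may return -u instead of u on a different backend, device, or library version. If that happens, every projection onto that axis flips and the bounds PC4 > 0.1, PC5 < -0.05 select the mirror-image region. A minimal sketch, assuming you keep reference axes from this run (the helper `align_to_reference` is hypothetical): flip each recomputed axis so it has a non-negative dot product with its reference.

```python
import torch

def align_to_reference(axis: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
    """Return axis or -axis, whichever points the same way as reference.

    Hypothetical helper: eigenvector signs are arbitrary, so this pins them
    to a stored reference before projecting and thresholding.
    """
    return axis if torch.dot(axis, reference) >= 0 else -axis

torch.manual_seed(0)
X = torch.randn(500, 8)                                    # synthetic stand-in for gamma_centered
ref_axis = torch.nn.functional.normalize(torch.randn(8), dim=0)
flipped = -ref_axis                                        # what another eigh run might return

proj_ref = X @ ref_axis
proj_flipped = X @ flipped
proj_fixed = X @ align_to_reference(flipped, ref_axis)

print(torch.allclose(proj_flipped, -proj_ref))  # a sign flip mirrors every projection
print(torch.allclose(proj_fixed, proj_ref))     # alignment restores the original orientation
```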


## Step 3: Apply Jet Bounds

In [5]:
# Create jet mask
jet_mask = (proj_PC4 > JET_PC4_MIN) & (proj_PC5 < JET_PC5_MAX)

n_jet = jet_mask.sum().item()
n_bulk = (~jet_mask).sum().item()

print(f"Jet identification (PC4 > {JET_PC4_MIN}, PC5 < {JET_PC5_MAX}):")
print(f"  Jet tokens: {n_jet:,} ({n_jet/N*100:.2f}%)")
print(f"  Bulk tokens: {n_bulk:,} ({n_bulk/N*100:.2f}%)")
print()

print(f"Jet statistics in PC4×5 space:")
print(f"  PC4: mean = {proj_PC4[jet_mask].mean().item():.4f}, std = {proj_PC4[jet_mask].std().item():.4f}")
print(f"  PC5: mean = {proj_PC5[jet_mask].mean().item():.4f}, std = {proj_PC5[jet_mask].std().item():.4f}")

Jet identification (PC4 > 0.1, PC5 < -0.05):
  Jet tokens: 3,055 (2.01%)
  Bulk tokens: 148,881 (97.99%)

Jet statistics in PC4×5 space:
  PC4: mean = 0.1888, std = 0.0551
  PC5: mean = -0.1083, std = 0.0319
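Because the bounds were chosen by eye, it may be worth checking how the jet count moves as they shift. A sketch that uses broadcasting to count tokens inside the rectangle for a whole grid of thresholds at once; the function name and threshold values are illustrative, and on the real data you would pass `proj_PC4` and `proj_PC5`.

```python
import torch

def jet_counts(proj4: torch.Tensor, proj5: torch.Tensor,
               pc4_mins: torch.Tensor, pc5_maxs: torch.Tensor) -> torch.Tensor:
    """Count points with proj4 > pc4_min and proj5 < pc5_max for every threshold pair.

    Returns a (len(pc4_mins), len(pc5_maxs)) integer tensor.
    """
    above = proj4[None, :] > pc4_mins[:, None]   # (A, N)
    below = proj5[None, :] < pc5_maxs[:, None]   # (B, N)
    # (A, 1, N) & (1, B, N) -> (A, B, N), then sum over tokens
    return (above[:, None, :] & below[None, :, :]).sum(dim=-1)

# Tiny synthetic example with hand-checkable counts
proj4 = torch.tensor([0.05, 0.12, 0.20, 0.30])
proj5 = torch.tensor([-0.10, -0.08, -0.02, -0.20])
counts = jet_counts(proj4, proj5,
                    pc4_mins=torch.tensor([0.1, 0.15]),
                    pc5_maxs=torch.tensor([-0.05, 0.0]))
print(counts.tolist())  # [[2, 3], [1, 2]]
```

A count that changes smoothly across nearby thresholds would suggest the jet boundary is not sharp in PC4×5 space; a plateau would support the chosen rectangle.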


## Step 4: Extract Jet Embeddings

In [6]:
# Extract jet token embeddings (rows of gamma_centered)
jet_embeddings = gamma_centered[jet_mask]

print(f"Extracted jet embeddings:")
print(f"  Shape: {jet_embeddings.shape}")
print(f"  Dtype: {jet_embeddings.dtype}")
print(f"  Memory: {jet_embeddings.element_size() * jet_embeddings.nelement() / 1024**2:.1f} MB")
print()

# Check the jet centroid: gamma_centered is centered on the full-vocabulary mean,
# so the jet subset is not expected to be zero-mean
jet_mean = jet_embeddings.mean(dim=0)
jet_mean_norm = jet_mean.norm().item()
print(f"Jet mean vector norm: {jet_mean_norm:.6e}")
print(f"(Should be small but non-zero - jet has different centroid than full cloud)")

Extracted jet embeddings:
  Shape: torch.Size([3055, 2560])
  Dtype: torch.float32
  Memory: 29.8 MB

Jet mean vector norm: 3.019564e-01
(Should be small but non-zero - jet has different centroid than full cloud)
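The jet mean norm can be sanity-checked against the Step 3 statistics. Projection is linear, so the jet mean's coordinates along PC4 and PC5 equal the mean projections (0.1888 and -0.1083). Since those two axes are orthonormal, Bessel's inequality gives ‖μ_jet‖ ≥ √(0.1888² + 0.1083²) ≈ 0.218, which the measured 0.302 satisfies; the remainder of the offset lies along other directions. A short check, plus a synthetic verification of the linearity step:

```python
import math
import torch

# Values reported in the outputs above
pc4_mean, pc5_mean = 0.1888, -0.1083
jet_mean_norm = 3.019564e-01

lower_bound = math.hypot(pc4_mean, pc5_mean)
print(f"lower bound = {lower_bound:.4f}, measured = {jet_mean_norm:.4f}")

# Synthetic check: mean of projections == projection of the mean
torch.manual_seed(0)
X = torch.randn(300, 10, dtype=torch.float64) + 0.5               # stand-in for jet rows
Q, _ = torch.linalg.qr(torch.randn(10, 2, dtype=torch.float64))   # two orthonormal axes
mean_of_proj = (X @ Q).mean(dim=0)
proj_of_mean = X.mean(dim=0) @ Q
print(torch.allclose(mean_of_proj, proj_of_mean))
# Bessel's inequality: a vector's norm is at least the norm of its coordinates in an orthonormal set
print(bool(X.mean(dim=0).norm() >= proj_of_mean.norm()))
```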


## Step 5: Get Jet Token IDs

In [7]:
# Row indices of gamma_centered are token IDs, so the True positions of the mask are the jet token IDs
jet_token_ids = torch.where(jet_mask)[0]

print(f"Jet token IDs:")
print(f"  Count: {len(jet_token_ids):,}")
print(f"  Min ID: {jet_token_ids.min().item()}")
print(f"  Max ID: {jet_token_ids.max().item()}")
print(f"  Mean ID: {jet_token_ids.float().mean().item():.1f}")
print()
print(f"First 20 jet token IDs: {jet_token_ids[:20].tolist()}")

Jet token IDs:
  Count: 3,055
  Min ID: 317
  Max ID: 144129
  Mean ID: 46909.9

First 20 jet token IDs: [317, 319, 340, 397, 401, 456, 463, 515, 532, 543, 555, 626, 630, 692, 698, 735, 736, 741, 751, 756]
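Because the mask is indexed by embedding row and each row index is a token ID, the mask and the ID list carry the same information. A minimal sketch of converting between the two on a toy mask (the helper names are hypothetical):

```python
import torch

def mask_to_ids(mask: torch.Tensor) -> torch.Tensor:
    """Indices where mask is True (int64), equivalent to torch.where(mask)[0]."""
    return torch.nonzero(mask, as_tuple=True)[0]

def ids_to_mask(ids: torch.Tensor, n: int) -> torch.Tensor:
    """Boolean mask of length n with True at each id."""
    mask = torch.zeros(n, dtype=torch.bool)
    mask[ids] = True
    return mask

mask = torch.tensor([False, True, False, True, True, False])
ids = mask_to_ids(mask)
print(ids.tolist())                                       # [1, 3, 4]
print(torch.equal(ids_to_mask(ids, mask.numel()), mask))  # True
```

Saving both is redundant but convenient: the IDs are compact for lookups, while the mask lines up directly with the 151,936 rows of `gamma_centered`.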


## Step 6: Save Jet Data

In [8]:
# Save jet embeddings
jet_embeddings_path = Path(TENSOR_DIR) / "jet_embeddings.safetensors"
save_file({'jet_embeddings': jet_embeddings.contiguous()}, jet_embeddings_path)
print(f"Saved jet embeddings to: {jet_embeddings_path}")

# Save jet token IDs
jet_token_ids_path = Path(TENSOR_DIR) / "jet_token_ids.safetensors"
save_file({'jet_token_ids': jet_token_ids.contiguous()}, jet_token_ids_path)
print(f"Saved jet token IDs to: {jet_token_ids_path}")

# Save jet mask (for convenience)
jet_mask_path = Path(TENSOR_DIR) / "jet_mask.safetensors"
save_file({'jet_mask': jet_mask.contiguous()}, jet_mask_path)
print(f"Saved jet mask to: {jet_mask_path}")

Saved jet embeddings to: ../data/tensors/jet_embeddings.safetensors
Saved jet token IDs to: ../data/tensors/jet_token_ids.safetensors
Saved jet mask to: ../data/tensors/jet_mask.safetensors


## Summary

Successfully extracted jet tokens from the main cloud!

**Jet identification:**
- Bounds: PC4 > 0.1, PC5 < -0.05
- Tokens: 3,055 (2.01% of the 151,936 embedding rows)

**Saved files:**
- `jet_embeddings.safetensors` - (3055, 2560) float32 rows of γ', centered on the full-vocabulary mean (not on the jet's own mean)
- `jet_token_ids.safetensors` - (3055,) int64 token IDs
- `jet_mask.safetensors` - (151936,) boolean mask

**Next steps:**
- Re-center the jet embeddings on their own mean, then compute PCA to find the jet's natural axes
- Visualize jet in its own coordinate system
- Decode jet tokens to understand semantic properties
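The saved `jet_embeddings` are rows of γ', centered on the full-vocabulary mean, and the Step 4 check shows the jet's own mean is offset by about 0.30. PCA on the jet therefore needs re-centering first; otherwise the leading eigenvector of the uncentered second-moment matrix tends to track that offset rather than the jet's spread. A minimal sketch on synthetic data (the function name `jet_pca` is hypothetical):

```python
import torch

def jet_pca(jet_embeddings: torch.Tensor):
    """PCA of a subset after re-centering on the subset's own mean.

    Returns (eigenvalues descending, eigenvectors as columns, subset mean).
    """
    mu = jet_embeddings.mean(dim=0)
    X = jet_embeddings - mu                  # re-center on the jet centroid
    cov = (X.T @ X) / (X.shape[0] - 1)
    evals, evecs = torch.linalg.eigh(cov)
    return evals.flip(0), evecs.flip(1), mu

# Synthetic "jet": a cloud offset along coordinate 4 and elongated along coordinate 1
torch.manual_seed(0)
offset = torch.tensor([0.0, 0.0, 0.0, 2.0], dtype=torch.float64)
scales = torch.tensor([1.0, 0.3, 0.2, 0.1], dtype=torch.float64)
jet = torch.randn(1000, 4, dtype=torch.float64) * scales + offset

evals, evecs, mu = jet_pca(jet)
print(mu)                 # close to the offset
print(evecs[:, 0].abs())  # top axis lies along coordinate 1 (the spread)

# Without re-centering, the second-moment matrix is dominated by the offset
M = (jet.T @ jet) / (jet.shape[0] - 1)
print(torch.linalg.eigh(M).eigenvectors[:, -1].abs())  # lies along coordinate 4 (the offset)
```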