- import an SAE, get its decoder weights.
- get multiple ones for a few layers. 
- calculate cosine similarity. 
- plot it in a heatmap.
- do decoder weights live in the same space?

In [2]:
import torch
from sae_lens import SAE
import torch.nn.functional as F
import matplotlib.pyplot as plt
import os
import numpy as np
from tqdm import tqdm

In [None]:

# Disable gradients for memory efficiency
torch.set_grad_enabled(False)

# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Initialize SAE for layer 1
sae_id = "layer_1/width_16k/canonical"
sae_1, _, _ = SAE.from_pretrained(
    release="gemma-scope-2b-pt-res-canonical",
    sae_id=sae_id
)
sae_1 = sae_1.to(device)
sae_1.eval()
sae_id = "layer_19/width_16k/canonical"
sae_2, _, _ = SAE.from_pretrained(
    release="gemma-scope-2b-pt-res-canonical",
    sae_id=sae_id
)
sae_2 = sae_2.to(device)
sae_2.eval()

device = "cuda" if torch.cuda.is_available() else "cpu"

# Get decoder weights and ensure they're on GPU
W_dec_1 = sae_1.W_dec.to(device)  # shape: [16384, 2304]
W_dec_2 = sae_2.W_dec.to(device)  # shape: [16384, 2304]

print(f"W_dec_1 device: {W_dec_1.device}")
print(f"W_dec_2 device: {W_dec_2.device}")

# Normalize each decoder vector (each row) in both matrices
W_dec_1_normalized = F.normalize(W_dec_1, p=2, dim=1)  # shape: [16384, 2304]
W_dec_2_normalized = F.normalize(W_dec_2, p=2, dim=1)  # shape: [16384, 2304]

print("Starting matrix multiplication...")
# Calculate cross-SAE cosine similarity matrix
cosine_sim_cross = torch.mm(W_dec_1_normalized, W_dec_2_normalized.T)  # shape: [16384, 16384]
print("Matrix multiplication complete!")

# Move to CPU only the subset we want to visualize
subset_sim = cosine_sim_cross[:100, :100].cpu().numpy()


# Find most similar pairs while keeping computation on GPU
k = 10  # number of top pairs to show
values, indices = torch.topk(cosine_sim_cross.view(-1), k)

# Convert flat indices to 2D indices (still on GPU)
rows = indices // cosine_sim_cross.shape[1]  # SAE1 feature indices
cols = indices % cosine_sim_cross.shape[1]   # SAE2 feature indices

print("\nTop similar feature pairs between SAE1 and SAE2:")
for val, row, col in zip(values.cpu(), rows.cpu(), cols.cpu()):
    print(f"SAE1 feature {row.item()} and SAE2 feature {col.item()}: similarity = {val:.3f}")


# Get decoder weights
W_dec = sae_1.W_dec  # shape: [16384, 2304]

# Normalize each decoder vector (each row)
# This divides each row by its L2 norm
W_dec_normalized = F.normalize(W_dec, p=2, dim=1)  # shape: still [16384, 2304]

# Calculate cosine similarity matrix
# W_dec_normalized @ W_dec_normalized.T will give us a [16384, 16384] matrix
# where entry [i,j] is the cosine similarity between feature i's and feature j's decoder vectors
cosine_sim = torch.mm(W_dec_normalized, W_dec_normalized.T)

# Take first 100 features for visualization
subset_sim = cosine_sim[:100, :100].cpu().numpy()


W_dec_1 device: cuda:0
W_dec_2 device: cuda:0
Starting matrix multiplication...
Matrix multiplication complete!

Top similar feature pairs between SAE1 and SAE2:
SAE1 feature 9201 and SAE2 feature 4346: similarity = 0.815
SAE1 feature 740 and SAE2 feature 12025: similarity = 0.810
SAE1 feature 15664 and SAE2 feature 3019: similarity = 0.716
SAE1 feature 13525 and SAE2 feature 9325: similarity = 0.674
SAE1 feature 7833 and SAE2 feature 4538: similarity = 0.666
SAE1 feature 15561 and SAE2 feature 3353: similarity = 0.659
SAE1 feature 1842 and SAE2 feature 15887: similarity = 0.657
SAE1 feature 5899 and SAE2 feature 14235: similarity = 0.653
SAE1 feature 1147 and SAE2 feature 13039: similarity = 0.651
SAE1 feature 8648 and SAE2 feature 9835: similarity = 0.642


In [3]:


# Disable gradients for efficiency
torch.set_grad_enabled(False)

# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Parameters
num_layers = 26  # Now includes layers 0-25
latent_dim = 16384
batch_size = 1024  # Adjust based on GPU memory
base_dir = "./similarities"

# Create directories for all layers
os.makedirs(base_dir, exist_ok=True)
for layer in range(0, num_layers):  # 0 to 25
    os.makedirs(f"{base_dir}/layer_{layer}", exist_ok=True)

# Load SAEs and store decoder weights for ALL layers
decoder_weights = {}
for layer in tqdm(range(0, num_layers), desc="Loading SAEs"):
    sae_id = f"layer_{layer}/width_16k/canonical"
    sae, _, _ = SAE.from_pretrained(
        release="gemma-scope-2b-pt-res-canonical",
        sae_id=sae_id
    )
    sae = sae.to(device)
    sae.eval()
    # Normalize decoder weights and store in float16
    W_dec = sae.W_dec.to(device)  # Shape: [16384, 2304]
    norms = torch.norm(W_dec, p=2, dim=1, keepdim=True)
    if torch.any(norms == 0):
        print(f"Warning: Zero-norm vectors detected in Layer {layer}")
        W_dec = W_dec.clone()
        W_dec[norms == 0] = 1e-6
        norms[norms == 0] = 1e-6
    W_dec_normalized = (W_dec / norms).to(torch.float16)
    decoder_weights[layer] = W_dec_normalized
    del sae, W_dec
    torch.cuda.empty_cache()

# Precompute triu indices for upper triangular part (including diagonal)
triu_indices = torch.triu_indices(latent_dim, latent_dim, offset=0, device=device)

def should_compute_similarity(layer_x, layer_y):
    """Check if similarity should be computed (involves layers 0, 24, or 25)"""
    new_layers = {0, 24, 25}
    return layer_x in new_layers or layer_y in new_layers

# Compute upper triangular cosine similarities
for layer_x in tqdm(range(0, num_layers), desc="Computing similarities"):
    W_dec_x = decoder_weights[layer_x]
    
    # Within-layer similarity (X == Y)
    if should_compute_similarity(layer_x, layer_x):
        output_path = f"{base_dir}/layer_{layer_x}/similarity_{layer_x}_{layer_x}.npy"
        if not os.path.exists(output_path):
            similarity = torch.zeros(len(triu_indices[0]), dtype=torch.float16, device=device)
            for i in range(0, latent_dim, batch_size):
                W_dec_x_batch = W_dec_x[i:i + batch_size]
                # Compute full batch similarity
                batch_sim = torch.matmul(W_dec_x_batch, W_dec_x.T)
                # Extract upper triangular part
                batch_mask = (triu_indices[0] >= i) & (triu_indices[0] < i + batch_size)
                batch_indices = torch.where(batch_mask)[0]
                if len(batch_indices) > 0:
                    batch_i = triu_indices[0][batch_indices] - i
                    batch_j = triu_indices[1][batch_indices]
                    similarity[batch_indices] = batch_sim[batch_i, batch_j]
                torch.cuda.empty_cache()
            # Save as float16
            np.save(output_path, similarity.cpu().numpy())
            print(f"Saved upper triangular similarity for Layer {layer_x} vs. Layer {layer_x}")
            del similarity
            torch.cuda.empty_cache()
    
    # Between-layer similarities (X < Y)
    for layer_y in range(layer_x + 1, num_layers):
        if should_compute_similarity(layer_x, layer_y):
            output_path = f"{base_dir}/layer_{layer_x}/similarity_{layer_x}_{layer_y}.npy"
            if not os.path.exists(output_path):
                W_dec_y = decoder_weights[layer_y]
                similarity = torch.zeros(len(triu_indices[0]), dtype=torch.float16, device=device)
                for i in range(0, latent_dim, batch_size):
                    W_dec_x_batch = W_dec_x[i:i + batch_size]
                    # Compute full batch similarity
                    batch_sim = torch.matmul(W_dec_x_batch, W_dec_y.T)
                    # Extract upper triangular part
                    batch_mask = (triu_indices[0] >= i) & (triu_indices[0] < i + batch_size)
                    batch_indices = torch.where(batch_mask)[0]
                    if len(batch_indices) > 0:
                        batch_i = triu_indices[0][batch_indices] - i
                        batch_j = triu_indices[1][batch_indices]
                        similarity[batch_indices] = batch_sim[batch_i, batch_j]
                    torch.cuda.empty_cache()
                # Save as float16
                np.save(output_path, similarity.cpu().numpy())
                print(f"Saved upper triangular similarity for Layer {layer_x} vs. Layer {layer_y}")
                del similarity
                torch.cuda.empty_cache()

# Clean up
del decoder_weights, triu_indices
torch.cuda.empty_cache()

Loading SAEs: 100%|██████████| 26/26 [00:34<00:00,  1.33s/it]
Computing similarities:   0%|          | 0/26 [00:00<?, ?it/s]

Saved upper triangular similarity for Layer 0 vs. Layer 0
Saved upper triangular similarity for Layer 0 vs. Layer 1
Saved upper triangular similarity for Layer 0 vs. Layer 2
Saved upper triangular similarity for Layer 0 vs. Layer 3
Saved upper triangular similarity for Layer 0 vs. Layer 4
Saved upper triangular similarity for Layer 0 vs. Layer 5
Saved upper triangular similarity for Layer 0 vs. Layer 6
Saved upper triangular similarity for Layer 0 vs. Layer 7
Saved upper triangular similarity for Layer 0 vs. Layer 8
Saved upper triangular similarity for Layer 0 vs. Layer 9
Saved upper triangular similarity for Layer 0 vs. Layer 10
Saved upper triangular similarity for Layer 0 vs. Layer 11
Saved upper triangular similarity for Layer 0 vs. Layer 12
Saved upper triangular similarity for Layer 0 vs. Layer 13
Saved upper triangular similarity for Layer 0 vs. Layer 14
Saved upper triangular similarity for Layer 0 vs. Layer 15
Saved upper triangular similarity for Layer 0 vs. Layer 16
Saved u

Computing similarities:   4%|▍         | 1/26 [02:54<1:12:48, 174.72s/it]

Saved upper triangular similarity for Layer 0 vs. Layer 25
Saved upper triangular similarity for Layer 1 vs. Layer 24


Computing similarities:   8%|▊         | 2/26 [03:06<31:35, 78.98s/it]   

Saved upper triangular similarity for Layer 1 vs. Layer 25
Saved upper triangular similarity for Layer 2 vs. Layer 24


Computing similarities:  12%|█▏        | 3/26 [03:18<18:32, 48.36s/it]

Saved upper triangular similarity for Layer 2 vs. Layer 25
Saved upper triangular similarity for Layer 3 vs. Layer 24


Computing similarities:  15%|█▌        | 4/26 [03:30<12:28, 34.04s/it]

Saved upper triangular similarity for Layer 3 vs. Layer 25
Saved upper triangular similarity for Layer 4 vs. Layer 24


Computing similarities:  19%|█▉        | 5/26 [03:43<09:13, 26.37s/it]

Saved upper triangular similarity for Layer 4 vs. Layer 25
Saved upper triangular similarity for Layer 5 vs. Layer 24


Computing similarities:  23%|██▎       | 6/26 [03:56<07:16, 21.84s/it]

Saved upper triangular similarity for Layer 5 vs. Layer 25
Saved upper triangular similarity for Layer 6 vs. Layer 24


Computing similarities:  27%|██▋       | 7/26 [04:09<05:57, 18.82s/it]

Saved upper triangular similarity for Layer 6 vs. Layer 25
Saved upper triangular similarity for Layer 7 vs. Layer 24


Computing similarities:  31%|███       | 8/26 [04:19<04:49, 16.09s/it]

Saved upper triangular similarity for Layer 7 vs. Layer 25
Saved upper triangular similarity for Layer 8 vs. Layer 24


Computing similarities:  35%|███▍      | 9/26 [04:29<04:02, 14.26s/it]

Saved upper triangular similarity for Layer 8 vs. Layer 25
Saved upper triangular similarity for Layer 9 vs. Layer 24


Computing similarities:  38%|███▊      | 10/26 [04:39<03:27, 12.96s/it]

Saved upper triangular similarity for Layer 9 vs. Layer 25
Saved upper triangular similarity for Layer 10 vs. Layer 24


Computing similarities:  42%|████▏     | 11/26 [04:51<03:08, 12.59s/it]

Saved upper triangular similarity for Layer 10 vs. Layer 25
Saved upper triangular similarity for Layer 11 vs. Layer 24


Computing similarities:  46%|████▌     | 12/26 [05:02<02:48, 12.06s/it]

Saved upper triangular similarity for Layer 11 vs. Layer 25
Saved upper triangular similarity for Layer 12 vs. Layer 24


Computing similarities:  50%|█████     | 13/26 [05:14<02:37, 12.09s/it]

Saved upper triangular similarity for Layer 12 vs. Layer 25
Saved upper triangular similarity for Layer 13 vs. Layer 24


Computing similarities:  54%|█████▍    | 14/26 [05:26<02:24, 12.07s/it]

Saved upper triangular similarity for Layer 13 vs. Layer 25
Saved upper triangular similarity for Layer 14 vs. Layer 24


Computing similarities:  58%|█████▊    | 15/26 [05:38<02:12, 12.02s/it]

Saved upper triangular similarity for Layer 14 vs. Layer 25
Saved upper triangular similarity for Layer 15 vs. Layer 24


Computing similarities:  62%|██████▏   | 16/26 [05:51<02:02, 12.28s/it]

Saved upper triangular similarity for Layer 15 vs. Layer 25
Saved upper triangular similarity for Layer 16 vs. Layer 24


Computing similarities:  65%|██████▌   | 17/26 [06:04<01:52, 12.53s/it]

Saved upper triangular similarity for Layer 16 vs. Layer 25
Saved upper triangular similarity for Layer 17 vs. Layer 24


Computing similarities:  69%|██████▉   | 18/26 [06:15<01:37, 12.25s/it]

Saved upper triangular similarity for Layer 17 vs. Layer 25
Saved upper triangular similarity for Layer 18 vs. Layer 24


Computing similarities:  73%|███████▎  | 19/26 [06:27<01:24, 12.10s/it]

Saved upper triangular similarity for Layer 18 vs. Layer 25
Saved upper triangular similarity for Layer 19 vs. Layer 24


Computing similarities:  77%|███████▋  | 20/26 [06:39<01:11, 11.93s/it]

Saved upper triangular similarity for Layer 19 vs. Layer 25
Saved upper triangular similarity for Layer 20 vs. Layer 24


Computing similarities:  81%|████████  | 21/26 [06:51<00:59, 11.89s/it]

Saved upper triangular similarity for Layer 20 vs. Layer 25
Saved upper triangular similarity for Layer 21 vs. Layer 24


Computing similarities:  85%|████████▍ | 22/26 [07:02<00:47, 11.78s/it]

Saved upper triangular similarity for Layer 21 vs. Layer 25
Saved upper triangular similarity for Layer 22 vs. Layer 24


Computing similarities:  88%|████████▊ | 23/26 [07:14<00:35, 11.77s/it]

Saved upper triangular similarity for Layer 22 vs. Layer 25
Saved upper triangular similarity for Layer 23 vs. Layer 24


Computing similarities:  92%|█████████▏| 24/26 [07:25<00:23, 11.68s/it]

Saved upper triangular similarity for Layer 23 vs. Layer 25
Saved upper triangular similarity for Layer 24 vs. Layer 24


Computing similarities:  96%|█████████▌| 25/26 [07:37<00:11, 11.77s/it]

Saved upper triangular similarity for Layer 24 vs. Layer 25


Computing similarities: 100%|██████████| 26/26 [07:43<00:00, 17.84s/it]

Saved upper triangular similarity for Layer 25 vs. Layer 25





In [None]:
# import os
# import shutil
# from pathlib import Path

# # Define source and destination paths
# source_dir = Path(r"D:\Master's\gemma_scope_math\src\notebooks\similarities")
# dest_dir = Path(r"D:\Master's\gemma_scope_math\similarities")

# # Create destination directory if it doesn't exist
# dest_dir.mkdir(parents=True, exist_ok=True)

# # Function to move files safely
# def move_files_safely(src_path, dst_path):
#     """Move files from src_path to dst_path, skipping if destination exists"""
#     if not src_path.exists():
#         print(f"Source directory doesn't exist: {src_path}")
#         return
    
#     moved_count = 0
#     skipped_count = 0
    
#     # Walk through all files in source directory
#     for file_path in src_path.rglob("*.npy"):
#         # Calculate relative path from source
#         rel_path = file_path.relative_to(src_path)
        
#         # Calculate destination path
#         dest_file_path = dst_path / rel_path
        
#         # Create destination directory if needed
#         dest_file_path.parent.mkdir(parents=True, exist_ok=True)
        
#         # Check if destination file already exists
#         if dest_file_path.exists():
#             print(f"SKIPPED (already exists): {rel_path}")
#             skipped_count += 1
#         else:
#             # Move the file
#             shutil.move(str(file_path), str(dest_file_path))
#             print(f"MOVED: {rel_path}")
#             moved_count += 1
    
#     print(f"\nSummary:")
#     print(f"Files moved: {moved_count}")
#     print(f"Files skipped: {skipped_count}")
    
#     # Remove empty directories from source
#     try:
#         for layer_dir in src_path.iterdir():
#             if layer_dir.is_dir() and not any(layer_dir.iterdir()):
#                 layer_dir.rmdir()
#                 print(f"Removed empty directory: {layer_dir}")
        
#         # Remove source directory if empty
#         if not any(src_path.iterdir()):
#             src_path.rmdir()
#             print(f"Removed empty source directory: {src_path}")
#     except OSError as e:
#         print(f"Note: Could not remove some directories: {e}")

# # Execute the move
# print(f"Moving similarity files from:")
# print(f"  Source: {source_dir}")
# print(f"  Destination: {dest_dir}")
# print()

# move_files_safely(source_dir, dest_dir)
# print("\nMove operation completed!")

Moving similarity files from:
  Source: D:\Master's\gemma_scope_math\src\notebooks\similarities
  Destination: D:\Master's\gemma_scope_math\similarities

MOVED: layer_0\similarity_0_0.npy
MOVED: layer_0\similarity_0_1.npy
MOVED: layer_0\similarity_0_10.npy
MOVED: layer_0\similarity_0_11.npy
MOVED: layer_0\similarity_0_12.npy
MOVED: layer_0\similarity_0_13.npy
MOVED: layer_0\similarity_0_14.npy
MOVED: layer_0\similarity_0_15.npy
MOVED: layer_0\similarity_0_16.npy
MOVED: layer_0\similarity_0_17.npy
MOVED: layer_0\similarity_0_18.npy
MOVED: layer_0\similarity_0_19.npy
MOVED: layer_0\similarity_0_2.npy
MOVED: layer_0\similarity_0_20.npy
MOVED: layer_0\similarity_0_21.npy
MOVED: layer_0\similarity_0_22.npy
MOVED: layer_0\similarity_0_23.npy
MOVED: layer_0\similarity_0_24.npy
MOVED: layer_0\similarity_0_25.npy
MOVED: layer_0\similarity_0_3.npy
MOVED: layer_0\similarity_0_4.npy
MOVED: layer_0\similarity_0_5.npy
MOVED: layer_0\similarity_0_6.npy
MOVED: layer_0\similarity_0_7.npy
MOVED: layer_0

In [None]:
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

# Create two dummy 5x3 matrices
A = np.array([
    [1, 0, 0],
    [0, 1, 0],
    [1, 1, 0],
    [1, 2, 0],
    [0, 0, 1]
])

B = np.array([
    [1, 0, 0],
    [1, 1, 0],
    [0, 1, 0],
    [0, 0, 1],
    [1, 2, 0]
])

# Compute cosine similarity between all rows in A and all rows in B
similarity = cosine_similarity(A, B)

print("Cosine similarity matrix (A vs B):")
print(np.round(similarity, 2))
print(np.round(similarity.T, 2))

In [None]:
cosine_similarity(A, A)

In [11]:
file_path = r"../../similarities/layer_1/similarity_1_1.npy"
similarities = np.load(file_path)

In [16]:
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from tqdm import tqdm
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

# Parameters
num_layers = 26  # Now layers 0 to 25 inclusive
base_dir = r"../../similarities" 
subsample_fraction = 0.005
output_dir = "./"
clip_range = 0.2
max_density = 0.5

# Define color schemes
early_layers = range(0, 9)      # Layers 0-8
middle_layers = range(9, 17)    # Layers 9-16
late_layers = range(17, 26)     # Layers 17-25

# Initialize colormaps
blues = cm.get_cmap('Blues')
greens = cm.get_cmap('Greens')
greys = cm.get_cmap('Greys')

# Assign shades: lighter for earlier layers, darker for later layers
def get_color(layer):
    if layer in early_layers:
        shade = 0.1 + (layer - early_layers.start) / (len(early_layers) - 1) * (0.9 - 0.1)
        return blues(shade)
    elif layer in middle_layers:
        shade = 0.1 + (layer - middle_layers.start) / (len(middle_layers) - 1) * (0.9 - 0.1)
        return greens(shade)
    else:
        shade = 0.1 + (layer - late_layers.start) / (len(late_layers) - 1) * (0.9 - 0.1)
        return greys(shade)

# Store KDE data to avoid recomputation
kde_data = {}
kde_data_clipped = {}
subsampled_similarities = {}

# Load and compute KDEs once
for layer in tqdm(range(0, num_layers), desc="Computing KDEs"):
    file_path = f"{base_dir}/layer_{layer}/similarity_{layer}_{layer}.npy"
    similarities = np.load(file_path)
    similarities = similarities[abs(similarities - 1.0) > 1e-4]
    np.random.seed(42)
    subsample_size = int(len(similarities) * subsample_fraction)
    similarities_subsampled = np.random.choice(similarities, size=subsample_size, replace=False)
    subsampled_similarities[layer] = similarities_subsampled

    kde = sns.kdeplot(similarities_subsampled, color='black', linewidth=0)
    kde_data[layer] = (kde.get_lines()[-1].get_xdata(), kde.get_lines()[-1].get_ydata())
    kde.remove()

    clipped_similarities = similarities_subsampled[abs(similarities_subsampled) > clip_range]
    if len(clipped_similarities) > 0:
        kde = sns.kdeplot(clipped_similarities, color='black', linewidth=0)
        kde_data_clipped[layer] = (kde.get_lines()[-1].get_xdata(), kde.get_lines()[-1].get_ydata())
        kde.remove()
    else:
        kde_data_clipped[layer] = (np.array([]), np.array([]))

# Option 1: Clipped similarities
plt.figure(figsize=(12, 8))
sns.set_style("whitegrid")
for layer in range(0, num_layers):
    x, y = kde_data_clipped[layer]
    if len(x) > 0:
        plt.plot(x, y, label=f"Layer {layer}", color=get_color(layer), linewidth=1.5)
plt.title("KDE of Within-Layer Cosine Similarities (Excluding [-0.05, 0.05])", fontsize=14)
plt.xlabel("Cosine Similarity", fontsize=12)
plt.ylabel("Density", fontsize=12)
plt.legend(title="Layer", ncol=2, fontsize=10)
plt.tight_layout()
plt.savefig(f"{output_dir}/kde_clipped.png", dpi=300, bbox_inches="tight")
plt.close()
print(f"Saved clipped KDE plot to {output_dir}/kde_clipped.png")

# Option 2: Logarithmic y-axis
plt.figure(figsize=(12, 8))
sns.set_style("whitegrid")
for layer in range(0, num_layers):
    x, y = kde_data[layer]
    plt.plot(x, y, label=f"Layer {layer}", color=get_color(layer), linewidth=1.5)
plt.title("KDE of Within-Layer Cosine Similarities (Log Y-Axis)", fontsize=14)
plt.xlabel("Cosine Similarity", fontsize=12)
plt.ylabel("Density (Log Scale)", fontsize=12)
plt.yscale('log')
plt.ylim(1e-3, None)
plt.legend(title="Layer", ncol=2, fontsize=10)
plt.tight_layout()
plt.savefig(f"{output_dir}/kde_log_y.png", dpi=300, bbox_inches="tight")
plt.close()
print(f"Saved log y-axis KDE plot to {output_dir}/kde_log_y.png")

# Options 3 and 7: Truncated y-axis and focused x-axis ranges in subplots
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 6), sharey=False)
sns.set_style("whitegrid")

# Option 3: Truncated y-axis
for layer in range(0, num_layers):
    x, y = kde_data[layer]
    ax1.plot(x, y, label=f"Layer {layer}", color=get_color(layer), linewidth=1.5)
ax1.set_title("Truncated Y-Axis", fontsize=12)
ax1.set_xlabel("Cosine Similarity", fontsize=10)
ax1.set_ylabel("Density", fontsize=10)
ax1.set_ylim(0, max_density)
ax1.legend(title="Layer", ncol=2, fontsize=8)

# Option 7a: Focus on x-axis [-1, -0.05]
for layer in range(0, num_layers):
    x, y = kde_data[layer]
    mask = (x >= -1) & (x <= -clip_range)
    ax2.plot(x[mask], y[mask], label=f"Layer {layer}", color=get_color(layer), linewidth=1.5)
ax2.set_title(f"X-Axis [-1, -{clip_range}]", fontsize=12)
ax2.set_xlabel("Cosine Similarity", fontsize=10)
ax2.set_ylabel("Density", fontsize=10)
ax2.legend(title="Layer", ncol=2, fontsize=8)

# Option 7b: Focus on x-axis [0.05, 1]
for layer in range(0, num_layers):
    x, y = kde_data[layer]
    mask = (x >= clip_range) & (x <= 1)
    ax3.plot(x[mask], y[mask], label=f"Layer {layer}", color=get_color(layer), linewidth=1.5)
ax3.set_title(f"X-Axis [{clip_range}, 1]", fontsize=12)
ax3.set_xlabel("Cosine Similarity", fontsize=10)
ax3.set_ylabel("Density", fontsize=10)
ax3.legend(title="Layer", ncol=2, fontsize=8)

plt.suptitle("KDE of Within-Layer Cosine Similarities (Y-Limit and X-Range Variations)", fontsize=14)
plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.savefig(f"{output_dir}/kde_ylim_xlim.png", dpi=300, bbox_inches="tight")
plt.close()
print(f"Saved y-limit and x-range KDE plots to {output_dir}/kde_ylim_xlim.png")

# Option 4: Inset plot
fig, ax = plt.subplots(figsize=(12, 8))
sns.set_style("whitegrid")
for layer in range(0, num_layers):
    x, y = kde_data[layer]
    ax.plot(x, y, label=f"Layer {layer}", color=get_color(layer), linewidth=1.5)

axins = inset_axes(ax, width="40%", height="30%", loc="upper right")
for layer in range(0, num_layers):
    x, y = kde_data_clipped[layer]
    if len(x) > 0:
        axins.plot(x, y, color=get_color(layer), linewidth=1.5)
axins.set_xlim(-1, -clip_range)
axins.set_ylim(0, max_density)
axins.set_xlabel("Cosine Similarity", fontsize=8)
axins.set_ylabel("Density", fontsize=8)
axins.tick_params(labelsize=6)

ax.set_title("KDE of Within-Layer Cosine Similarities with Inset", fontsize=14)
ax.set_xlabel("Cosine Similarity", fontsize=12)
ax.set_ylabel("Density", fontsize=12)
ax.legend(title="Layer", ncol=2, fontsize=10)
plt.tight_layout()
plt.savefig(f"{output_dir}/kde_inset.png", dpi=300, bbox_inches="tight")
plt.close()
print(f"Saved inset KDE plot to {output_dir}/kde_inset.png")


  blues = cm.get_cmap('Blues')
  greens = cm.get_cmap('Greens')
  greys = cm.get_cmap('Greys')
  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as

Saved clipped KDE plot to .//kde_clipped.png
Saved log y-axis KDE plot to .//kde_log_y.png
Saved y-limit and x-range KDE plots to .//kde_ylim_xlim.png


  plt.tight_layout()


Saved inset KDE plot to .//kde_inset.png


<Figure size 640x480 with 0 Axes>

In [18]:
import numpy as np
import seaborn as sns
from tqdm import tqdm
import os
import pickle
import matplotlib.pyplot as plt  # Needed for plt.close()

# Parameters
num_layers = 26  # Now includes layers 0–25
base_dir = r"../../similarities"
subsample_fraction = 0.01  # Subsample 1% of similarities for KDE
output_dir = r"../../kde_cache"
clip_range = 0.2  # For clipped similarities: exclude [-clip_range, clip_range]

# Create output directory if it doesn't exist
os.makedirs(output_dir, exist_ok=True)

# Initialize dictionaries to store data
kde_data = {}  # {layer: (x, y)}
kde_data_clipped = {}  # {layer: (x, y)} for clipped similarities
subsampled_similarities = {}  # {layer: array} for original subsampled data

# Load and compute KDEs once
for layer in tqdm(range(0, num_layers), desc="Computing KDEs"):  # 0 to 25 inclusive
    # Load similarity matrix
    file_path = f"{base_dir}/layer_{layer}/similarity_{layer}_{layer}.npy"
    similarities = np.load(file_path)
    
    # Filter out diagonal elements (self-similarities ≈ 1)
    similarities = similarities[abs(similarities - 1.0) > 1e-4]
    
    # Subsample
    np.random.seed(42)  # For reproducibility
    subsample_size = int(len(similarities) * subsample_fraction)
    similarities_subsampled = np.random.choice(similarities, size=subsample_size, replace=False)
    subsampled_similarities[layer] = similarities_subsampled
    
    # Compute KDE for original data
    kde = sns.kdeplot(similarities_subsampled, color='black', linewidth=0)  # Temporary plot to get KDE
    kde_data[layer] = (kde.get_lines()[-1].get_xdata(), kde.get_lines()[-1].get_ydata())
    kde.get_lines()[-1].remove()  # Remove temporary line
    
    # Compute KDE for clipped data
    clipped_similarities = similarities_subsampled[abs(similarities_subsampled) > clip_range]
    if len(clipped_similarities) > 0:  # Ensure there's data to compute KDE
        kde = sns.kdeplot(clipped_similarities, color='black', linewidth=0)
        kde_data_clipped[layer] = (kde.get_lines()[-1].get_xdata(), kde.get_lines()[-1].get_ydata())
        kde.get_lines()[-1].remove()
    else:
        kde_data_clipped[layer] = (np.array([]), np.array([]))
    
    # Clear the figure to avoid memory issues
    plt.close()

# Save KDE data and subsampled similarities
with open(f"{output_dir}/kde_data.pkl", "wb") as f:
    pickle.dump(kde_data, f)
with open(f"{output_dir}/kde_data_clipped.pkl", "wb") as f:
    pickle.dump(kde_data_clipped, f)
with open(f"{output_dir}/subsampled_similarities.pkl", "wb") as f:
    pickle.dump(subsampled_similarities, f)

print(f"Saved KDE data and subsampled similarities to {output_dir}")

  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mo

Saved KDE data and subsampled similarities to ../../kde_cache


In [21]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import seaborn as sns
import pickle
import os

# Parameters
num_layers = 25
output_dir = r"../../kde_cache"
plot_output_dir =  r"../../kde_plots"

# Plot configuration
plot_configs = [
    {
        "x_range": [(-0.2, 0.2)],  # Single range for Plot 1
        "y_range": (1e-2, 100),   # Log scale: 10^-2 to 100
        "y_scale": "log",
        "vlines": [0],            # Vertical line at x=0
        "y_label": "Density (Log Scale)"
    },
    {
        "x_range": [(-0.4, -0.2), (0.2, 0.4)],  # Two ranges for Plot 2
        "y_range": (0, 0.025),    # Linear scale: 0 to 0.25
        "y_scale": "linear",
        "vlines": [-0.2, 0.2],   # Vertical lines at x=-0.1, 0.1
        "y_label": "Density"
    },
    {
        "x_range": [(-1, -0.4)],  # Single range for Plot 3
        "y_range": (0, 0.0004),    # Linear scale: 0 to 0.01
        "y_scale": "linear",
        "vlines": [-0.4],        # Vertical line at x=-0.3
        "y_label": "Density"
    },
    {
        "x_range": [(0.4, 1)],   # Single range for Plot 4
        "y_range": (0, 0.0008),    # Linear scale: 0 to 0.01
        "y_scale": "linear",
        "vlines": [0.4],         # Vertical line at x=0.3
        "y_label": "Density"
    }
]

# Create plot output directory
os.makedirs(plot_output_dir, exist_ok=True)

# Load precomputed KDE data
with open(f"{output_dir}/kde_data.pkl", "rb") as f:
    kde_data = pickle.load(f)
# new ranges
# Define color schemes
early_layers = range(0, 9)  # Layers 0-8
middle_layers = range(9, 17)  # Layers 9-16
late_layers = range(17, 26)  # Layers 17-25

# Initialize colormaps
blues = cm.get_cmap('Blues')
greens = cm.get_cmap('Greens')
greys = cm.get_cmap('Greys')

# Assign shades: lighter for earlier layers, darker for later layers
# Assign shades: lighter for earlier layers, darker for later layers
def get_color(layer):
    if layer in early_layers:
        shade = 0.1 + (layer - early_layers.start) / (len(early_layers) - 1) * (0.9 - 0.1)
        return blues(shade)
    elif layer in middle_layers:
        shade = 0.1 + (layer - middle_layers.start) / (len(middle_layers) - 1) * (0.9 - 0.1)
        return greens(shade)
    else:
        shade = 0.1 + (layer - late_layers.start) / (len(late_layers) - 1) * (0.9 - 0.1)
        return greys(shade)

# Function to generate x-range mask
def get_x_mask(x, x_ranges):
    mask = np.zeros_like(x, dtype=bool)
    for x_min, x_max in x_ranges:
        mask |= (x >= x_min) & (x <= x_max)
    return mask

# Function to format x-range for title
def format_x_range(x_ranges):
    ranges = [f"[{x_min:.1f}, {x_max:.1f}]" for x_min, x_max in x_ranges]
    return " ∪ ".join(ranges)

# Create subplots
fig, axes = plt.subplots(1, 4, figsize=(24, 6), sharey=False)
sns.set_style("whitegrid")

# Generate plots
for ax, config in zip(axes, plot_configs):
    x_ranges = config["x_range"]
    y_range = config["y_range"]
    y_scale = config["y_scale"]
    vlines = config["vlines"]
    y_label = config["y_label"]
    
    # Plot KDEs for each layer
    for layer in range(0, num_layers + 1):
        x, y = kde_data[layer]
        mask = get_x_mask(x, x_ranges)
        if np.any(mask):
            ax.plot(x[mask], y[mask], label=f"Layer {layer}", color=get_color(layer), linewidth=1.5, alpha=0.8)
    
    # Set plot properties
    ax.set_title(f"Cosine Similarity {format_x_range(x_ranges)}", fontsize=12)
    ax.set_xlabel("Cosine Similarity", fontsize=10)
    ax.set_ylabel(y_label, fontsize=10)
    ax.set_yscale(y_scale)
    ax.set_ylim(y_range)
    for vline in vlines:
        ax.axvline(x=vline, color='black', linestyle='--', alpha=0.3)
    ax.legend(title="Layer", ncol=2, fontsize=8)

# Finalize and save
plt.suptitle("KDE of Within-Layer Cosine Similarities (Zoomed Ranges)", fontsize=14)
plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.savefig(f"{plot_output_dir}/kde_zoomed_ranges_adjusted.png", dpi=300, bbox_inches="tight")
plt.close()
print(f"Saved adjusted KDE plots to {plot_output_dir}/kde_zoomed_ranges_adjusted.png")

  blues = cm.get_cmap('Blues')
  greens = cm.get_cmap('Greens')
  greys = cm.get_cmap('Greys')
  plt.tight_layout(rect=[0, 0, 1, 0.95])
  plt.savefig(f"{plot_output_dir}/kde_zoomed_ranges_adjusted.png", dpi=300, bbox_inches="tight")


Saved adjusted KDE plots to ../../kde_plots/kde_zoomed_ranges_adjusted.png


# Misc

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os

# Parameters
num_layers = 23
base_dir = "./similarities"
plot_output_dir = "./kde_plots"
target_samples = 100  # Fixed number of samples instead of percentage

# Create plot output directory
os.makedirs(plot_output_dir, exist_ok=True)

# Initialize similarity matrix
mean_similarities = np.zeros((num_layers, num_layers))

# Load and compute mean similarities efficiently
for i in range(1, num_layers + 1):
    for j in range(i, num_layers + 1):
        file_path = f"{base_dir}/layer_{i}/similarity_{i}_{j}.npy"
        if os.path.exists(file_path):
            # Memory-map the file instead of loading it
            similarities = np.load(file_path, mmap_mode='r')
            
            # Generate random indices for subsampling
            np.random.seed(42)
            total_size = similarities.size
            sample_indices = np.random.choice(total_size, 
                                            size=min(target_samples, total_size), 
                                            replace=False)
            
            # Load only the sampled values
            sampled_similarities = similarities.flat[sample_indices]
            
            # Filter out near-1 similarities
            filtered_similarities = sampled_similarities[abs(sampled_similarities - 1.0) > 1e-4]
            
            # Compute mean
            if len(filtered_similarities) > 0:
                mean_sim = np.mean(filtered_similarities)
            else:
                mean_sim = np.nan
                
            mean_similarities[i-1, j-1] = mean_sim
            mean_similarities[j-1, i-1] = mean_sim
        else:
            mean_similarities[i-1, j-1] = np.nan
            mean_similarities[j-1, i-1] = np.nan

# Create heatmap
plt.figure(figsize=(10, 8))
sns.set_style("whitegrid")
sns.heatmap(mean_similarities, cmap="Blues", annot=True, fmt=".2f", 
            xticklabels=range(1, num_layers + 1), yticklabels=range(1, num_layers + 1),
            cbar_kws={"label": "Mean Cosine Similarity"})
plt.title("Mean Cosine Similarity Between Layers", fontsize=14)
plt.xlabel("Layer", fontsize=12)
plt.ylabel("Layer", fontsize=12)
plt.tight_layout()
plt.savefig(f"{plot_output_dir}/heatmap_mean_similarities.png", dpi=300, bbox_inches="tight")
plt.close()
print(f"Saved heatmap to {plot_output_dir}/heatmap_mean_similarities.png")

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
from matplotlib import cm
from collections import defaultdict

# Parameters
num_layers = 23
base_dir = "./similarities"
plot_output_dir = "./kde_plots"
target_samples_per_file = 10  # Fixed number instead of percentage

# Define color schemes
early_layers = range(1, 8)  # Layers 1-7
middle_layers = range(8, 17)  # Layers 8-16
late_layers = range(17, 24)  # Layers 17-23

# Initialize colormaps
blues = cm.get_cmap('Blues')
greens = cm.get_cmap('Greens')
greys = cm.get_cmap('Greys')

# Assign shades
def get_color(layer):
    if layer in early_layers:
        shade = 0.1 + (layer - 1) / (7 - 1) * (0.9 - 0.1)
        return blues(shade)
    elif layer in middle_layers:
        shade = 0.1 + (layer - 8) / (16 - 8) * (0.9 - 0.1)
        return greens(shade)
    else:
        shade = 0.1 + (layer - 17) / (23 - 17) * (0.9 - 0.1)
        return greys(shade)

def efficient_sample_file(file_path, target_samples):
    """Efficiently sample from large numpy file without loading it entirely"""
    # Memory-map the file
    similarities = np.load(file_path, mmap_mode='r')
    total_size = similarities.size
    
    # Generate random indices for subsampling
    np.random.seed(hash(file_path) % 2**32)  # Consistent seed per file
    sample_size = min(target_samples, total_size)
    sample_indices = np.random.choice(total_size, size=sample_size, replace=False)
    
    # Load only the sampled values
    sampled_similarities = similarities.flat[sample_indices]
    
    # Filter out near-1 similarities and return
    return sampled_similarities[abs(sampled_similarities - 1.0) > 1e-4]

# Create plot output directory
os.makedirs(plot_output_dir, exist_ok=True)

# Collect statistics by layer distance (streaming approach)
distance_stats = defaultdict(lambda: {'sum': 0.0, 'sum_sq': 0.0, 'count': 0})

print("Processing similarity files...")
total_files = sum(1 for i in range(1, num_layers + 1) for j in range(i + 1, num_layers + 1))
processed = 0

for i in range(1, num_layers + 1):
    for j in range(i + 1, num_layers + 1):  # Exclude diagonal (i==j)
        file_path = f"{base_dir}/layer_{i}/similarity_{i}_{j}.npy"
        if os.path.exists(file_path):
            # Efficiently sample the file
            similarities_subsampled = efficient_sample_file(file_path, target_samples_per_file)
            
            if len(similarities_subsampled) > 0:
                distance = abs(i - j)
                
                # Update running statistics (more memory efficient)
                distance_stats[distance]['sum'] += np.sum(similarities_subsampled)
                distance_stats[distance]['sum_sq'] += np.sum(similarities_subsampled**2)
                distance_stats[distance]['count'] += len(similarities_subsampled)
        
        processed += 1
        if processed % 50 == 0:
            print(f"Processed {processed}/{total_files} files...")

# Compute mean and std for each distance
distances = sorted(distance_stats.keys())
mean_sims = []
std_sims = []

for d in distances:
    stats = distance_stats[d]
    if stats['count'] > 0:
        mean = stats['sum'] / stats['count']
        # Standard deviation using online algorithm
        variance = (stats['sum_sq'] / stats['count']) - (mean**2)
        std = np.sqrt(max(0, variance))  # Ensure non-negative
        
        mean_sims.append(mean)
        std_sims.append(std)
    else:
        mean_sims.append(np.nan)
        std_sims.append(np.nan)

# Plot
plt.figure(figsize=(10, 6))
sns.set_style("whitegrid")

# Remove any NaN values for plotting
valid_indices = ~np.isnan(mean_sims)
valid_distances = [distances[i] for i in range(len(distances)) if valid_indices[i]]
valid_means = [mean_sims[i] for i in range(len(mean_sims)) if valid_indices[i]]
valid_stds = [std_sims[i] for i in range(len(std_sims)) if valid_indices[i]]

plt.plot(valid_distances, valid_means, marker='o', color=get_color(1), linewidth=2, label="Mean Similarity")
plt.fill_between(valid_distances, 
                 [m - s for m, s in zip(valid_means, valid_stds)], 
                 [m + s for m, s in zip(valid_means, valid_stds)], 
                 color=get_color(1), alpha=0.2, label="±1 Std")

plt.title("Mean Cosine Similarity vs. Layer Distance", fontsize=14)
plt.xlabel("Layer Distance (|i - j|)", fontsize=12)
plt.ylabel("Mean Cosine Similarity", fontsize=12)
plt.legend(fontsize=10)
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig(f"{plot_output_dir}/similarity_vs_layer_distance.png", dpi=300, bbox_inches="tight")
plt.close()
print(f"Saved layer distance plot to {plot_output_dir}/similarity_vs_layer_distance.png")

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import seaborn as sns
import os
from collections import defaultdict

# Parameters
num_layers = 23
base_dir = "./similarities"
plot_output_dir = "./kde_plots"
target_samples_per_file = 5000  # Fixed number instead of percentage

# Define color schemes
early_layers = range(1, 8)  # Layers 1-7
middle_layers = range(8, 17)  # Layers 8-16
late_layers = range(17, 24)  # Layers 17-23

# Initialize colormaps
blues = cm.get_cmap('Blues')
greens = cm.get_cmap('Greens')
greys = cm.get_cmap('Greys')

# Assign shades
def get_color(layer):
    if layer in early_layers:
        shade = 0.1 + (layer - 1) / (7 - 1) * (0.9 - 0.1)
        return blues(shade)
    elif layer in middle_layers:
        shade = 0.1 + (layer - 8) / (16 - 8) * (0.9 - 0.1)
        return greens(shade)
    else:
        shade = 0.1 + (layer - 17) / (23 - 17) * (0.9 - 0.1)
        return greys(shade)

def efficient_sample_file(file_path, target_samples):
    """Efficiently sample from large numpy file without loading it entirely"""
    # Memory-map the file
    similarities = np.load(file_path, mmap_mode='r')
    total_size = similarities.size
    
    # Generate random indices for subsampling
    np.random.seed(hash(file_path) % 2**32)  # Consistent seed per file
    sample_size = min(target_samples, total_size)
    sample_indices = np.random.choice(total_size, size=sample_size, replace=False)
    
    # Load only the sampled values
    sampled_similarities = similarities.flat[sample_indices]
    
    # Filter out near-1 similarities and return
    return sampled_similarities[abs(sampled_similarities - 1.0) > 1e-4]

# Create plot output directory
os.makedirs(plot_output_dir, exist_ok=True)

# Load each file only once and store statistics for both layers
layer_stats = defaultdict(lambda: {'sum': 0.0, 'count': 0})

print("Processing similarity files (loading each file only once)...")
total_files = sum(1 for i in range(1, num_layers + 1) for j in range(i + 1, num_layers + 1))
processed = 0

# Process each unique file pair only once
for i in range(1, num_layers + 1):
    for j in range(i + 1, num_layers + 1):  # Only upper triangle
        file_path = f"{base_dir}/layer_{i}/similarity_{i}_{j}.npy"
        if os.path.exists(file_path):
            # Efficiently sample the file
            similarities_subsampled = efficient_sample_file(file_path, target_samples_per_file)
            
            if len(similarities_subsampled) > 0:
                sim_sum = np.sum(similarities_subsampled)
                sim_count = len(similarities_subsampled)
                
                # Update statistics for both layers involved
                layer_stats[i]['sum'] += sim_sum
                layer_stats[i]['count'] += sim_count
                layer_stats[j]['sum'] += sim_sum
                layer_stats[j]['count'] += sim_count
        
        processed += 1
        if processed % 50 == 0:
            print(f"Processed {processed}/{total_files} files...")

# Compute mean similarity for each layer
mean_sims_per_layer = []
for i in range(1, num_layers + 1):
    if layer_stats[i]['count'] > 0:
        mean_sim = layer_stats[i]['sum'] / layer_stats[i]['count']
        mean_sims_per_layer.append(mean_sim)
    else:
        mean_sims_per_layer.append(np.nan)

# Plot
plt.figure(figsize=(12, 6))
sns.set_style("whitegrid")

# Create the plot with colored points
layers = range(1, num_layers + 1)
colors = [get_color(i) for i in layers]

# Plot all points at once for efficiency
plt.scatter(layers, mean_sims_per_layer, c=colors, s=80, alpha=0.8, edgecolors='black', linewidth=0.5)

# Add a connecting line for trend visualization
valid_indices = ~np.isnan(mean_sims_per_layer)
if np.any(valid_indices):
    valid_layers = [i for i, valid in enumerate(layers) if valid_indices[i]]
    valid_means = [mean_sims_per_layer[i] for i in range(len(mean_sims_per_layer)) if valid_indices[i]]
    plt.plot(valid_layers, valid_means, color='gray', alpha=0.5, linewidth=1, zorder=0)

plt.title("Mean Cosine Similarity of Each Layer to All Others", fontsize=14)
plt.xlabel("Layer", fontsize=12)
plt.ylabel("Mean Cosine Similarity", fontsize=12)
plt.xticks(range(1, num_layers + 1))
plt.grid(True, alpha=0.3)

# Add color legend
from matplotlib.patches import Patch
legend_elements = [
    Patch(facecolor=blues(0.7), label='Early Layers (1-7)'),
    Patch(facecolor=greens(0.7), label='Middle Layers (8-16)'),
    Patch(facecolor=greys(0.7), label='Late Layers (17-23)')
]
plt.legend(handles=legend_elements, loc='best', fontsize=10)

plt.tight_layout()
plt.savefig(f"{plot_output_dir}/mean_similarity_per_layer.png", dpi=300, bbox_inches="tight")
plt.close()

# Print summary statistics
print(f"\nSaved mean similarity per layer plot to {plot_output_dir}/mean_similarity_per_layer.png")
print(f"\nSummary Statistics:")
print(f"Min similarity: {np.nanmin(mean_sims_per_layer):.4f}")
print(f"Max similarity: {np.nanmax(mean_sims_per_layer):.4f}")
print(f"Mean similarity: {np.nanmean(mean_sims_per_layer):.4f}")

# Find layers with highest and lowest average similarity
if not np.all(np.isnan(mean_sims_per_layer)):
    max_layer = np.nanargmax(mean_sims_per_layer) + 1
    min_layer = np.nanargmin(mean_sims_per_layer) + 1
    print(f"Layer with highest avg similarity: {max_layer} ({mean_sims_per_layer[max_layer-1]:.4f})")
    print(f"Layer with lowest avg similarity: {min_layer} ({mean_sims_per_layer[min_layer-1]:.4f})")