## Imports

The following cell imports the libraries required to run the rest of the notebook.

In [1]:

from transformers import AutoModelForCausalLM, AutoTokenizer
from sae_lens import SAE
from sklearn.metrics.pairwise import cosine_similarity
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

  from .autonotebook import tqdm as notebook_tqdm


## Loading model

We experiment with the Gemma 2 2B model (`google/gemma-2-2b`) and the pre-trained sparse autoencoders (SAEs) released by Gemma Scope.

In [2]:
# The cells shown here only run forward passes, so switch autograd off globally to save GPU memory
torch.set_grad_enabled(False)

# Tokenizer and model are loaded from the same Hugging Face checkpoint
checkpoint = "google/gemma-2-2b"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="cuda:1")

Loading checkpoint shards: 100%|██████████| 3/3 [00:13<00:00,  4.45s/it]


In [3]:
# Show the module tree of the inner decoder (`model.model`), including its
# `layers` list, which is what the forward hooks later in the notebook attach to.
print(model.model)

Gemma2Model(
  (embed_tokens): Embedding(256000, 2304, padding_idx=0)
  (layers): ModuleList(
    (0-25): 26 x Gemma2DecoderLayer(
      (self_attn): Gemma2Attention(
        (q_proj): Linear(in_features=2304, out_features=2048, bias=False)
        (k_proj): Linear(in_features=2304, out_features=1024, bias=False)
        (v_proj): Linear(in_features=2304, out_features=1024, bias=False)
        (o_proj): Linear(in_features=2048, out_features=2304, bias=False)
        (rotary_emb): Gemma2RotaryEmbedding()
      )
      (mlp): Gemma2MLP(
        (gate_proj): Linear(in_features=2304, out_features=9216, bias=False)
        (up_proj): Linear(in_features=2304, out_features=9216, bias=False)
        (down_proj): Linear(in_features=9216, out_features=2304, bias=False)
        (act_fn): PytorchGELUTanh()
      )
      (input_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
      (post_attention_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
      (pre_feedforward_layernorm): Gemma2RMSNorm((2304,), e

## Pairs of layers

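The following cell computes the token-by-token cosine similarity between the SAE feature activations of each pair of layers in a predefined list, saves one heatmap per pair, and plots the mean similarity across all pairs.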

In [4]:
import os

import matplotlib.pyplot as plt
import seaborn as sns

# Define the input prompt
prompt = "I want to time travel back to Spain in 2000"
# Use the device the model lives on rather than repeating a hard-coded device string
device = model.device
inputs = tokenizer.encode(prompt, return_tensors="pt", add_special_tokens=True).to(device)
print(len(inputs[0]))
print(inputs[0])

# Make sure the output folder exists: plt.savefig does not create missing directories
heatmap_dir = "./output/token"
os.makedirs(heatmap_dir, exist_ok=True)

# Initialize lists to store layer indices and similarities
layers = []
mean_similarities = []

# Compare layer 0 against each of the layers 0..14
layer_pairs = [[0, i] for i in range(15)]


def gather_activations(model, target_layer, inputs):
    """Run one forward pass and return the output of decoder layer `target_layer`.

    Args:
        model: Causal LM whose decoder layers are reachable as `model.model.layers`.
        target_layer: Index of the decoder layer to capture.
        inputs: Token ids, passed to the model unchanged.

    Returns:
        The hidden states emitted by the hooked layer, or None if the hook never fired.
    """
    target_act = None

    def hook(mod, hook_inputs, outputs):
        nonlocal target_act
        # Decoder layers may return a tuple with the hidden states first; also
        # accept a bare tensor so that indexing never slices the batch dimension.
        target_act = outputs[0] if isinstance(outputs, tuple) else outputs

    handle = model.model.layers[target_layer].register_forward_hook(hook)
    try:
        model(inputs)
    finally:
        # Always detach the hook, even if the forward pass raises
        handle.remove()
    return target_act


# SAE features per layer, so each SAE is loaded and applied only once even
# when the same layer appears in several pairs
sae_features = {}


def encode_layer(layer):
    """Return the SAE feature activations of `layer` for the prompt (cached per layer)."""
    if layer not in sae_features:
        sae, _, _ = SAE.from_pretrained(
            release="gemma-scope-2b-pt-res-canonical",
            sae_id=f"layer_{layer}/width_16k/canonical",
        )
        sae.to(device)
        activations = gather_activations(model, layer, inputs)
        sae_features[layer] = sae.encode(activations.to(torch.float32))
        # Release the SAE weights right away; only the encoded features are kept
        del sae
        torch.cuda.empty_cache()
    return sae_features[layer]


for pair in layer_pairs:
    print("------------------------")
    print(f"Comparing layers {pair}")

    # Encode the residual activations of both layers with their own SAEs
    sae_acts_base = encode_layer(pair[0])
    sae_acts_compare = encode_layer(pair[1])
    print(sae_acts_base.shape)

    # One row per token; the feature width is read from the tensor itself
    sae_acts_base_flat = sae_acts_base.reshape(-1, sae_acts_base.shape[-1])
    sae_acts_compare_flat = sae_acts_compare.reshape(-1, sae_acts_compare.shape[-1])
    print(len(sae_acts_base_flat))
    print(len(sae_acts_compare_flat))

    # Compute cosine similarity: rows follow the base layer, columns the compared layer
    cosine_sim = cosine_similarity(sae_acts_base_flat.cpu().numpy(), sae_acts_compare_flat.cpu().numpy())

    # Min-max rescale to [0, 1]; a constant matrix becomes all zeros instead of NaN
    cosine_sim = cosine_sim - cosine_sim.min()
    sim_range = cosine_sim.max()
    if sim_range > 0:
        cosine_sim = cosine_sim / sim_range

    # NOTE(review): this is the mean of the rescaled matrix, not of the raw
    # cosine similarities, although the message below calls it "Mean Cosine Similarity".
    mean_similarity = cosine_sim.mean()
    print(f"Mean Cosine Similarity Between Layers {pair[0]} and {pair[1]}: {mean_similarity}")

    # Store the layer and mean similarity
    layers.append(pair[1])
    mean_similarities.append(mean_similarity)

    # Plot the cosine similarity heatmap
    plt.figure(figsize=(10, 8))
    sns.heatmap(cosine_sim, annot=True, fmt=".2f", cmap="viridis", cbar=True)
    plt.title(f"Cosine Similarity Heatmap Between Layer {pair[0]} and Layer {pair[1]}")
    # The heatmap draws matrix rows on the y-axis and columns on the x-axis,
    # so the base layer labels y and the compared layer labels x.
    plt.xlabel(f"Layer {pair[1]} Tokens")
    plt.ylabel(f"Layer {pair[0]} Tokens")
    filename = f"{heatmap_dir}/cosine_similarity_heatmap_layers_{pair[0]}_and_{pair[1]}.png"
    plt.savefig(filename)
    plt.close()

# Plot the mean cosine similarity trend across layers
min_value = min(mean_similarities)
max_value = max(mean_similarities)
value_range = max_value - min_value

# Normalize each value in the list; if all means are equal, use 0.0 rather than dividing by zero
mean_similarities = [
    (value - min_value) / value_range if value_range > 0 else 0.0
    for value in mean_similarities
]
plt.figure(figsize=(10, 6))
plt.plot(layers, mean_similarities, marker='o')
plt.title("Mean Cosine Similarity Between Layer 0 and Other Layers")
plt.xlabel("Layer Index")
plt.ylabel("Mean Normalized Cosine Similarity")
plt.grid()
filename = "./output/mean_cosine_similarity_trend.png"
plt.savefig(filename)
plt.close()


15
tensor([     2, 235285,   1938,    577,   1069,   5056,   1355,    577,  14034,
           575, 235248, 235284, 235276, 235276, 235276], device='cuda:1')
------------------------
Comparing layers [0, 0]


Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)


torch.Size([1, 15, 16384])
15
15
Mean Cosine Similarity Between Layers 0 and 0: 0.12899655103683472
------------------------
Comparing layers [0, 1]
torch.Size([1, 15, 16384])
15
15
Mean Cosine Similarity Between Layers 0 and 1: 0.025364335626363754
------------------------
Comparing layers [0, 2]


In [4]:
import os

import matplotlib.pyplot as plt
import seaborn as sns

# Define the input prompt
prompt = "I want to time travel back to Spain in 2000"
# Use the device the model lives on rather than repeating a hard-coded device string
device = model.device
inputs = tokenizer.encode(prompt, return_tensors="pt", add_special_tokens=True).to(device)
print(len(inputs[0]))
print(inputs[0])

# Make sure the output folder exists: plt.savefig does not create missing directories
heatmap_dir = "./output/with_prev"
os.makedirs(heatmap_dir, exist_ok=True)

# Initialize lists to store layer indices and similarities
layers = []
mean_similarities = []

# Compare each layer i (0..14) with the layer that follows it
layer_pairs = [[i, i + 1] for i in range(15)]


def gather_activations(model, target_layer, inputs):
    """Run one forward pass and return the output of decoder layer `target_layer`.

    Args:
        model: Causal LM whose decoder layers are reachable as `model.model.layers`.
        target_layer: Index of the decoder layer to capture.
        inputs: Token ids, passed to the model unchanged.

    Returns:
        The hidden states emitted by the hooked layer, or None if the hook never fired.
    """
    target_act = None

    def hook(mod, hook_inputs, outputs):
        nonlocal target_act
        # Decoder layers may return a tuple with the hidden states first; also
        # accept a bare tensor so that indexing never slices the batch dimension.
        target_act = outputs[0] if isinstance(outputs, tuple) else outputs

    handle = model.model.layers[target_layer].register_forward_hook(hook)
    try:
        model(inputs)
    finally:
        # Always detach the hook, even if the forward pass raises
        handle.remove()
    return target_act


# SAE features per layer: consecutive pairs share a layer, so caching means
# each SAE is loaded and applied once instead of twice
sae_features = {}


def encode_layer(layer):
    """Return the SAE feature activations of `layer` for the prompt (cached per layer)."""
    if layer not in sae_features:
        sae, _, _ = SAE.from_pretrained(
            release="gemma-scope-2b-pt-res-canonical",
            sae_id=f"layer_{layer}/width_16k/canonical",
        )
        sae.to(device)
        activations = gather_activations(model, layer, inputs)
        sae_features[layer] = sae.encode(activations.to(torch.float32))
        # Release the SAE weights right away; only the encoded features are kept
        del sae
        torch.cuda.empty_cache()
    return sae_features[layer]


for pair in layer_pairs:
    print("------------------------")
    print(f"Comparing layers {pair}")

    # Encode the residual activations of both layers with their own SAEs
    sae_acts_base = encode_layer(pair[0])
    sae_acts_compare = encode_layer(pair[1])
    print(sae_acts_base.shape)

    # One row per token; the feature width is read from the tensor itself
    sae_acts_base_flat = sae_acts_base.reshape(-1, sae_acts_base.shape[-1])
    sae_acts_compare_flat = sae_acts_compare.reshape(-1, sae_acts_compare.shape[-1])
    print(len(sae_acts_base_flat))
    print(len(sae_acts_compare_flat))

    # Compute cosine similarity: rows follow the earlier layer, columns the later layer
    cosine_sim = cosine_similarity(sae_acts_base_flat.cpu().numpy(), sae_acts_compare_flat.cpu().numpy())

    # Min-max rescale to [0, 1]; a constant matrix becomes all zeros instead of NaN
    cosine_sim = cosine_sim - cosine_sim.min()
    sim_range = cosine_sim.max()
    if sim_range > 0:
        cosine_sim = cosine_sim / sim_range

    # NOTE(review): this is the mean of the rescaled matrix, not of the raw
    # cosine similarities, although the message below calls it "Mean Cosine Similarity".
    mean_similarity = cosine_sim.mean()
    print(f"Mean Cosine Similarity Between Layers {pair[0]} and {pair[1]}: {mean_similarity}")

    # Store the layer and mean similarity
    layers.append(pair[1])
    mean_similarities.append(mean_similarity)

    # Plot the cosine similarity heatmap
    plt.figure(figsize=(10, 8))
    sns.heatmap(cosine_sim, annot=True, fmt=".2f", cmap="viridis", cbar=True)
    plt.title(f"Cosine Similarity Heatmap Between Layer {pair[0]} and Layer {pair[1]}")
    # The heatmap draws matrix rows on the y-axis and columns on the x-axis,
    # so the earlier layer labels y and the later layer labels x.
    plt.xlabel(f"Layer {pair[1]} Tokens")
    plt.ylabel(f"Layer {pair[0]} Tokens")
    filename = f"{heatmap_dir}/cosine_similarity_heatmap_layers_{pair[0]}_and_{pair[1]}.png"
    plt.savefig(filename)
    plt.close()

# Plot the mean cosine similarity trend across layers
min_value = min(mean_similarities)
max_value = max(mean_similarities)
value_range = max_value - min_value

# Normalize each value in the list; if all means are equal, use 0.0 rather than dividing by zero
mean_similarities = [
    (value - min_value) / value_range if value_range > 0 else 0.0
    for value in mean_similarities
]
plt.figure(figsize=(10, 6))
plt.plot(layers, mean_similarities, marker='o')
plt.title("Mean Cosine Similarity Between Layer prev and next Layers")
plt.xlabel("Layer Index")
plt.ylabel("Mean Normalized Cosine Similarity")
plt.grid()
filename = "./output/mean_cosine_similarity_trend_with_previous.png"
plt.savefig(filename)
plt.close()


15
tensor([     2, 235285,   1938,    577,   1069,   5056,   1355,    577,  14034,
           575, 235248, 235284, 235276, 235276, 235276], device='cuda:1')
------------------------
Comparing layers [0, 1]


Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)


torch.Size([1, 15, 16384])
15
15
Mean Cosine Similarity Between Layers 0 and 1: 0.025364335626363754
------------------------
Comparing layers [1, 2]
torch.Size([1, 15, 16384])
15
15
Mean Cosine Similarity Between Layers 1 and 2: 0.010482142679393291
------------------------
Comparing layers [2, 3]
torch.Size([1, 15, 16384])
15
15
Mean Cosine Similarity Between Layers 2 and 3: 0.02757006697356701
------------------------
Comparing layers [3, 4]
torch.Size([1, 15, 16384])
15
15
Mean Cosine Similarity Between Layers 3 and 4: 0.022210560739040375
------------------------
Comparing layers [4, 5]
torch.Size([1, 15, 16384])
15
15
Mean Cosine Similarity Between Layers 4 and 5: 0.02662784792482853
------------------------
Comparing layers [5, 6]
torch.Size([1, 15, 16384])
15
15
Mean Cosine Similarity Between Layers 5 and 6: 0.0184357687830925
------------------------
Comparing layers [6, 7]
torch.Size([1, 15, 16384])
15
15
Mean Cosine Similarity Between Layers 6 and 7: 0.0148434117436409
-----