In [1]:
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizer
import torch
import gc
from sae_lens import SAE  # pip install sae-lens
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import os
# Inference only: turn autograd off globally so forward passes keep no graphs.
torch.set_grad_enabled(False)

# Pick the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


In [None]:
# Every "a+b" / "a-b" expression for operands 0..99 (row-major: a outer, b inner),
# written one expression per line.
addition_dataset = [f"{x}+{y}" for x in range(100) for y in range(100)]
subtraction_dataset = [f"{x}-{y}" for x in range(100) for y in range(100)]

for out_path, expressions in (
    ("addition_dataset.txt", addition_dataset),
    ("subtraction_dataset.txt", subtraction_dataset),
):
    with open(out_path, "w") as fh:
        fh.writelines(f"{expr}\n" for expr in expressions)

In [None]:
import random

# Operand lists, each drawn from its own fixed seed so the files are reproducible.
random.seed(42)
a = [random.randint(10_000, 100_000) for _ in range(100)]
random.seed(123)
b = [random.randint(10_000, 100_000) for _ in range(100)]

# All 100 x 100 ordered pairs (x from `a` outer, y from `b` inner).
addition_list = [f"{x}+{y}" for x in a for y in b]
subtraction_list = [f"{x}-{y}" for x in a for y in b]

for out_path, rows in (
    ("addition_dataset_high_range.txt", addition_list),
    ("subtraction_dataset_high_range.txt", subtraction_list),
):
    with open(out_path, "w") as fh:
        for row in rows:
            fh.write("%s\n" % row)


In [None]:
import platform
import sys
import torch
import subprocess

def get_cuda_version():
    """Return the line of `nvcc --version` output that mentions "release".

    Returns a fallback message when nvcc is missing or fails, or when no line
    contains "release".
    """
    try:
        report = subprocess.check_output(["nvcc", "--version"]).decode("utf-8")
        release_line = next(
            (row for row in report.splitlines() if "release" in row.lower()), None
        )
        if release_line is None:
            return "CUDA version not found in nvcc output"
        return release_line
    except (subprocess.CalledProcessError, FileNotFoundError):
        return "nvcc not found or CUDA not installed"

def get_gpu_info():
    """Describe the first CUDA device as a one-line string.

    Returns a "no GPU" message when CUDA is unavailable, and an error message
    (instead of raising) if querying the device fails.
    """
    try:
        if not torch.cuda.is_available():
            return "No CUDA-capable GPU detected"
        name = torch.cuda.get_device_name(0)
        available = torch.cuda.is_available()
        count = torch.cuda.device_count()
        return f"GPU: {name}, CUDA Available: {available}, CUDA Device Count: {count}"
    except Exception as exc:
        return f"Error detecting GPU: {str(exc)}"

# Environment report: OS, Python, PyTorch, CUDA toolkit, GPU, and Triton (optional).
print("System Information:")
for label, value in (
    ("Operating System", f"{platform.system()} {platform.release()} ({platform.platform()})"),
    ("Python Version", sys.version),
    ("PyTorch Version", torch.__version__),
    ("CUDA Version", get_cuda_version()),
    ("GPU Info", get_gpu_info()),
):
    print(f"{label}: {value}")

try:
    import triton
except ImportError:
    print("Triton Version: Not installed")
else:
    print(f"Triton Version: {triton.__version__}")

In [None]:
import os
import random
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Disable PyTorch compilation to avoid Triton dependency.
# NOTE(review): "TORCH_NO_CUDA_COMPILATION" is not a PyTorch setting I can confirm.
# torch.compile is normally disabled via TORCH_COMPILE_DISABLE=1 / TORCHDYNAMO_DISABLE=1,
# which likely must be set before torch is imported. Left unchanged; verify intent.
os.environ["TORCH_NO_CUDA_COMPILATION"] = "1"

# Load model and tokenizer onto the device chosen in the setup cell
# (instead of hard-coding "cuda", which fails on CPU-only machines).
model_name = "google/gemma-2-2b-it"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).to(device)

# Function to read dataset from file
def read_dataset(file_path):
    """Return the non-blank lines of `file_path`, each stripped of surrounding whitespace."""
    with open(file_path, "r") as handle:
        stripped = (raw.strip() for raw in handle)
        return [row for row in stripped if row]

# Load datasets
addition_data = read_dataset("addition_dataset_high_range.txt")
subtraction_data = read_dataset("subtraction_dataset_high_range.txt")

# Combine datasets and draw a few random examples for a quick sanity check.
# (The comment previously said 10, but the code draws 2; the count is now a named constant.)
combined_data = addition_data + subtraction_data
num_demo_examples = 2
random_examples = random.sample(combined_data, num_demo_examples)

# Function to generate response from model
def generate_response(prompt, max_new_tokens=50, temperature=0.7, top_p=0.9):
    """Sample a completion for `prompt` using the global `model` and `tokenizer`.

    Args:
        prompt: raw text prompt (no chat template is applied).
        max_new_tokens: maximum number of tokens to generate.
        temperature: sampling temperature.
        top_p: nucleus-sampling probability mass.

    Returns:
        The decoded full sequence (prompt followed by the continuation), with special
        tokens removed and surrounding whitespace stripped.
    """
    # Follow the model's own device rather than hard-coding "cuda".
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True).strip()

# Ask the model each sampled arithmetic question and print prompt / output pairs.
for expression in random_examples:
    question = f"What is {expression}?"
    answer = generate_response(question)
    print(f"Prompt: {question}")
    print(f"Output: {answer}")
    print("-" * 50)

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch

# Reload the instruct model with automatic device placement and run one
# free-form prompt through a text-generation pipeline.
model_name = "google/gemma-2-2b-it"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.float16, device_map="auto"
)
text_generator = pipeline("text-generation", model=model, tokenizer=tokenizer)

prompt = "Explain the theory of relativity in simple terms."
output = text_generator(prompt, max_new_tokens=100)
first_generation = output[0]
print(first_generation['generated_text'])


In [None]:
# The high-range lists are already on disk; drop the in-memory copies.
del addition_list
del subtraction_list
gc.collect()

In [None]:
# Load per-feature activation counts saved by an earlier run.
# Build the path with os.path.join: the previous hard-coded "\" separators only
# work on Windows.
active_counts_path = os.path.join("activ_freq", "addition", "layer_20", "active_counts.npy")
# NOTE(review): allow_pickle=True is only needed for object arrays. Drop it if this
# file is a plain numeric array, since loading pickles can execute code.
data = np.load(active_counts_path, allow_pickle=True)
data, data.shape

In [None]:
# --- Load Model and Tokenizer ---
model_name = "google/gemma-2-2b-it"
num_examples = 800  # how many expressions to push through the model in one batch

# 4-bit NF4 weights with double quantization; compute in fp16.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    quantization_config=quantization_config,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# --- Prepare Example Batch ---
input_file = "./data/addition.txt"
with open(input_file, "r") as f:
    lines = [line.strip() for line in f.readlines() if line.strip()]

# Reproducible sample of distinct expressions.
np.random.seed(42)
examples = np.random.choice(lines, num_examples, replace=False).tolist()

# Pad to the longest sequence and move every tensor to the model's device.
batch = tokenizer(examples, padding=True, return_tensors="pt", return_attention_mask=True)
batch = {name: tensor.to(model.device) for name, tensor in batch.items()}



In [None]:
# --- Forward Pass ---
import warnings

with torch.no_grad():
    outputs = model(**batch, output_hidden_states=True)

# outputs.hidden_states is a tuple:
# (embeddings_output, layer1_output, layer2_output, ..., layerN_output)
# Index 0 is the embedding output, so index 21 is the output of the 21st decoder
# layer, i.e. layer 20 in 0-based numbering (matching the "layer_20" SAE below).
layer_idx = 21
layer_hidden = outputs.hidden_states[layer_idx]  # (batch_size, seq_len, hidden_size)

# --- Find "=" token position for each example ---
equal_token_id = tokenizer.convert_tokens_to_ids("=")
equal_mask = batch["input_ids"] == equal_token_id  # (batch_size, seq_len)

# argmax over an all-False row returns 0, which would silently select position 0
# for examples that contain no "=" token. Report those rows explicitly.
missing_equal = ~equal_mask.any(dim=1)
if missing_equal.any():
    warnings.warn(
        f"{int(missing_equal.sum())} of {missing_equal.numel()} examples contain no '=' "
        f"token (id {equal_token_id}); their equal_hidden_states rows use position 0."
    )

# Index of the first "=" in each sequence (argmax returns the first maximal index).
equal_positions = equal_mask.int().argmax(dim=1)

batch_size, seq_len, hidden_size = layer_hidden.shape
batch_indices = torch.arange(batch_size, device=layer_hidden.device)

# Extract activations
equal_hidden_states = layer_hidden[batch_indices, equal_positions, :]  # (batch_size, hidden_size)

print(equal_hidden_states.shape)  # Should be (batch_size, hidden_size)

In [None]:
# --- Hidden state at the last real (non-padding) token of each example ---
# Flip the mask so argmax finds the first 1 from the right, then map that
# reversed index back to a forward position. Works for left or right padding.
attention_mask = batch["attention_mask"]  # (batch_size, seq_len), 1 = real token
reversed_mask = attention_mask.eq(1).int().flip(dims=[1])
reversed_idx = torch.argmax(reversed_mask, dim=1)
last_idx = (attention_mask.size(1) - 1) - reversed_idx

batch_size, seq_len, hidden_size = layer_hidden.shape
batch_indices = torch.arange(batch_size, device=layer_hidden.device)

last_meaningful_hidden_states = layer_hidden[batch_indices, last_idx, :]  # (batch_size, hidden_size)
print(last_meaningful_hidden_states.shape)  # Should be (batch_size, hidden_size)


In [None]:

# Check what the last real token of each tokenized example is (e.g. whether it is EOS).
attention_mask = batch["attention_mask"]    # shape (batch_size, seq_len)
input_ids = batch["input_ids"]              # shape (batch_size, seq_len)
eos_token_id = tokenizer.eos_token_id

# Last index where attention_mask == 1 in each row (flip + argmax, mapped back).
last_indices = attention_mask.size(1) - 1 - torch.argmax(attention_mask.flip(dims=[1]), dim=1)

# Build the row index on the same device as the tensors being indexed.
row_indices = torch.arange(input_ids.size(0), device=input_ids.device)
last_tokens = input_ids[row_indices, last_indices]
is_eos = last_tokens == eos_token_id  # bool tensor, one entry per example

print(is_eos.sum().item())  # number of examples whose last real token is EOS

# Decode the last tokens to see what they are
last_tokens_decoded = tokenizer.batch_decode(last_tokens.unsqueeze(1), skip_special_tokens=False)

for idx, token in enumerate(last_tokens_decoded[:10]):
    print(f"Sample {idx}: Last token = '{token}'")
    print(f"Sample: {examples[idx]}\n")


In [None]:
# Drop the model and large forward-pass intermediates; equal_hidden_states and
# last_meaningful_hidden_states are kept for the SAE cells below.
del model
del tokenizer
del outputs, batch, layer_hidden
del equal_positions, equal_token_id, batch_indices
gc.collect()
torch.cuda.empty_cache()

In [None]:
# sae, cfg_dict, sparsity = SAE.from_pretrained(
#     release = "gemma-scope-2b-pt-res-canonical",
#     sae_id = "layer_20/width_16k/canonical",
# )
# sae = sae.to(device)
# sae, cfg_dict, sparsity = SAE.from_pretrained(
#     release = "gemma-scope-2b-pt-att-canonical",
#     sae_id = "layer_20/width_16k/canonical",
# )


# sae, cfg_dict, sparsity = SAE.from_pretrained(
#     release = "gemma-scope-2b-pt-mlp-canonical",
#     sae_id = "layer_20/width_16k/canonical",
# )

In [None]:
# Gemma Scope residual-stream SAE for layer 20 with 16k latents.
# NOTE(review): the "pt" release name suggests it was trained on the base model,
# while these activations come from gemma-2-2b-it (4-bit). Confirm the transfer is acceptable.
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="gemma-scope-2b-pt-res-canonical",
    sae_id="layer_20/width_16k/canonical",
)
sae = sae.to(device)

# Encode last-real-token activations (swap in equal_hidden_states to probe "=").
sae_acts = sae.encode(last_meaningful_hidden_states)
sae_acts.shape

In [None]:
sae_acts.size()  # (num_samples, num_features)

In [None]:
threshold = 1e-5  # |activation| below this counts as "zero"

# Fraction of (sample, feature) entries that are ~zero. Stored under its own name
# so it no longer overwrites the `sparsity` returned by SAE.from_pretrained.
feature_sparsity = (torch.abs(sae_acts) < threshold).float().mean()
print(f"Sparsity: {feature_sparsity.item() * 100:.2f}%")


In [None]:
threshold = 1e-5

# Boolean mask (num_samples, num_features): True where a feature fires for a sample.
active = torch.abs(sae_acts) > threshold
n_samples = active.shape[0]  # derived, not hard-coded, so the label follows num_examples

# Per-feature number of samples in which it is active.
active_counts = active.sum(dim=0).cpu().numpy()  # (num_features,)

plt.figure(figsize=(20, 5))
plt.bar(range(len(active_counts)), active_counts, width=1.0)
plt.title("Number of times each feature is active")
plt.xlabel("Feature index")
plt.ylabel(f"Active count (out of {n_samples})")
plt.show()


In [None]:
np.flip(np.sort(active_counts))  # active counts, largest first

In [None]:
# Features active in more than `min_active_count` samples. Keep their original
# feature ids so each bar can be traced back to an SAE latent (bar positions alone
# lose that information).
min_active_count = 600
n_samples = sae_acts.shape[0]
filtered_feature_ids = np.flatnonzero(active_counts > min_active_count)
filtered_counts = active_counts[filtered_feature_ids]
filtered_indices = np.arange(len(filtered_counts))

plt.figure(figsize=(20, 5))
plt.bar(filtered_indices, filtered_counts, width=1.0)
plt.xticks(filtered_indices, filtered_feature_ids, rotation=90)
plt.title(f"Features with Active Count > {min_active_count}")
plt.xlabel("Feature index")
plt.ylabel(f"Active Count (out of {n_samples})")
plt.savefig('filtered_feature_active_counts.png')

In [None]:
# Active counts sorted descending: how quickly activation frequency falls off.
n_samples = sae_acts.shape[0]  # label follows the actual batch size
sorted_counts = np.sort(active_counts)[::-1]

plt.figure(figsize=(20, 5))
plt.plot(sorted_counts)
plt.title("Sorted Active Counts per Feature")
plt.xlabel("Feature rank (sorted)")
plt.ylabel(f"Active count (out of {n_samples})")
plt.grid(True)
plt.xlim(0, 1000)
plt.show()


In [None]:
# A feature counts as active for a sample when its activation exceeds 0.01.
active_features = sae_acts > 0.01  # (batch_size, num_features)

# Features active in at least `min_fraction` of examples (95%, not strictly all).
min_fraction = 0.95
consistent_features = active_features.sum(dim=0) >= (sae_acts.shape[0] * min_fraction)  # (num_features,)

# flatten() instead of squeeze(): with exactly one match, squeeze() gives a 0-d
# tensor that cannot be iterated over and whose tolist() is an int, not a list.
consistent_feature_indices = torch.nonzero(consistent_features).flatten()

print(f"Consistent feature indices: {consistent_feature_indices}")
# print(f"Number of consistent features: {consistent_feature_indices}")


In [None]:
# Distribution of every SAE activation value across all samples and features.
flat_acts = sae_acts.reshape(-1).cpu().numpy()

plt.figure(figsize=(10, 6))
plt.hist(flat_acts, bins=100, log=True, color="skyblue", edgecolor="black")
plt.title("Histogram of All SAE Activations (flattened)")
plt.xlabel("Activation Value")
plt.ylabel("Log Frequency")
plt.grid(True, linestyle="--", alpha=0.5)
plt.tight_layout()
plt.show()


In [None]:
# (#activations above 5, #activations below 1, total shape)
((flat_acts > 5).sum(), (flat_acts < 1).sum(), flat_acts.shape)

In [None]:

# Count how many features are > threshold for each sample (row).
threshold = 0.01
sparsity_counts = (sae_acts > threshold).sum(dim=1).cpu().numpy()  # (num_samples,)

mean_count = sparsity_counts.mean()
std_count = sparsity_counts.std()
n_std = 2  # half-width of the band around the mean, in standard deviations

plt.figure(figsize=(14, 5))
plt.bar(range(len(sparsity_counts)), sparsity_counts, color='slateblue')
plt.title(f"Number of Active Features per Sample (Activation > {threshold})")
plt.xlabel("Sample Index")
plt.ylabel("Number of Active Features")

plt.axhline(y=mean_count, color='b', linestyle='--', label=f'Mean: {mean_count:.2f}', alpha=0.7)
# Draw the band lines at the same offset the legend reports (they were previously
# drawn at +/-1 std while labelled +/-2 std).
upper = mean_count + n_std * std_count
lower = mean_count - n_std * std_count
plt.axhline(y=upper, color='orange', linestyle='--', label=f'Mean + {n_std} Std: {upper:.2f}', alpha=0.7)
plt.axhline(y=lower, color='purple', linestyle='--', label=f'Mean - {n_std} Std: {lower:.2f}', alpha=0.7)
plt.legend()
plt.grid(True, linestyle='--', alpha=0.4)
plt.tight_layout()
plt.show()


In [None]:
# Activation cutoff for counting a feature as "on"; same value as the per-sample
# count above. (The previous comments mentioned 4 and ">7", which did not match the code.)
threshold = 0.01

# Number of samples in which each latent exceeds the threshold.
freqs = (sae_acts > threshold).sum(dim=0).cpu().numpy()  # (num_features,)

plt.figure(figsize=(14, 5))
plt.plot(freqs, color='mediumseagreen')
plt.title("Feature Activation Frequency")
plt.xlabel("Latent Dimension Index")
plt.ylabel(f"Number of Samples with Activation > {threshold}")
plt.grid(True, linestyle='--', alpha=0.5)
plt.tight_layout()
plt.show()


In [None]:
# Latents firing in >100 samples: count, fraction of all latents, and array shape.
(freqs > 100).sum(), (freqs > 100).mean(), freqs.shape

In [None]:

# Column indices of the consistently-active features as a plain Python list.
# reshape(-1) guards against a 0-d tensor (one match after squeeze()), for which
# tolist() would return a bare int and break the column selection below.
indices = consistent_feature_indices.reshape(-1).cpu().tolist()

# Extract data for the selected indices
subset_data = sae_acts[:, indices].cpu().numpy()

# Convert to long-form DataFrame for Seaborn
df = pd.DataFrame(subset_data, columns=[f"Feature {i}" for i in indices])
df_melted = df.melt(var_name="Dimension", value_name="Activation")

plt.figure(figsize=(12, 6))
sns.violinplot(x="Dimension", y="Activation", data=df_melted, palette="pastel")
plt.title("Activation Distributions for Selected Latent Dimensions")
plt.xticks(rotation=45)
plt.grid(True, linestyle='--', alpha=0.4)
plt.tight_layout()
plt.show()


In [None]:
# Average activation of each latent over all samples, largest first.
mean_acts = sae_acts.mean(dim=0).cpu().numpy()
sorted_acts = np.sort(mean_acts)[::-1]

# Clamp so the bar plot still works when there are fewer latents than requested.
top_n = min(300, len(sorted_acts))
n_samples = sae_acts.shape[0]

plt.figure(figsize=(14, 5))
plt.bar(range(top_n), sorted_acts[:top_n], color='coral')
plt.title(f"Top {top_n} Latent Dimensions by Mean Activation")
plt.xlabel("Sorted Latent Dimensions (Descending)")
plt.ylabel(f"Mean Activation Across {n_samples} Samples")
plt.grid(True, linestyle='--', alpha=0.5)
plt.tight_layout()
plt.show()


In [None]:
from IPython.display import IFrame, display

html_template = "https://neuronpedia.org/{}/{}/{}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"


def get_dashboard_html(sae_release="gemma-2-2b", sae_id="20-gemmascope-res-16k", feature_idx=0):
    """Return the Neuronpedia embed URL for one SAE feature.

    Args:
        sae_release: Neuronpedia model slug.
        sae_id: Neuronpedia SAE slug.
        feature_idx: latent index within that SAE.
    """
    return html_template.format(sae_release, sae_id, feature_idx)


# Example dashboard. A bare expression in the middle of a cell is not rendered by
# Jupyter, so display it explicitly.
html = get_dashboard_html(sae_release="gemma-2-2b", sae_id="20-gemmascope-res-16k", feature_idx=10004)
display(IFrame(html, width=1200, height=600))

# Features active in >= 95% of samples (see `consistent_features`). flatten() keeps a
# single match iterable; squeeze() would produce a 0-d tensor.
consistent_feature_indices = torch.nonzero(consistent_features).flatten()

# One Neuronpedia dashboard URL per consistent feature.
dashboard_htmls = [
    get_dashboard_html(sae_release="gemma-2-2b", sae_id="20-gemmascope-res-16k", feature_idx=feature_idx)
    for feature_idx in consistent_feature_indices.tolist()
]

for html in dashboard_htmls:
    print(html)
    # display(IFrame(html, width=1200, height=600))


In [None]:
# Intentional hard stop so "Run All" halts here; the MISC cells below are not
# part of the main pipeline.
assert False

# MISC

In [None]:
# Delete the old model and inputs *if they exist*. A plain `del` raised NameError here:
# `model` was already deleted earlier, and `inputs` was only ever a local inside
# generate_response.
for _stale_name in ("model", "inputs"):
    globals().pop(_stale_name, None)

# run garbage collection
gc.collect()

# clear cached memory in CUDA
torch.cuda.empty_cache()


In [None]:
# Tokenizer for the instruct model, plus every non-blank expression from the data file.
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b-it")

input_file = "./data/addition.txt"
with open(input_file, "r") as f:
    lines = [line.strip() for line in f.readlines() if line.strip()]

len(lines)

In [None]:
# Reproducibly select 1,000 distinct lines (the old comments said 10 thousand / 100).
# NOTE: this rebinds `lines`; re-run the loading cell first if you re-run this one.
np.random.seed(42)
lines = np.random.choice(lines, 1000, replace=False).tolist()  # list of 1000 str

In [None]:
# 8-bit weights via bitsandbytes. BitsAndBytesConfig has no `bnb_8bit_compute_dtype`
# option (only the 4-bit path has a compute dtype); it was being ignored as an
# unused kwarg, so it is removed.
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-2b-it",
    device_map='auto',
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)

In [None]:
# Tokenize all sampled expressions into one padded batch on `device`.
# offset_mapping maps tokens back to character spans; it is not a model input.
inputs = tokenizer(
    lines,
    padding=True,
    return_tensors="pt",
    return_offsets_mapping=True,
    return_attention_mask=True,
    return_token_type_ids=False,
).to(device)



In [None]:
# print(inputs.keys())
# print(type(inputs))

# print(inputs["input_ids"].shape) # examples * tokens
# print(inputs["attention_mask"].shape) # examples * tokens
# print(inputs["offset_mapping"].shape) # examples * tokens * 2 (start and end offsets)

In [None]:
# Id of "=" when tokenized on its own (no BOS/EOS added).
equal_token_id = tokenizer.encode("=", add_special_tokens=False)[0]
print(f"Equal token id: {equal_token_id}")

In [None]:
# (batch_idx, sequence_idx) pair for every "=" occurrence; shape (num_equals_found, 2).
equal_positions = torch.nonzero(inputs["input_ids"] == equal_token_id)
In [None]:
# offset_mapping was requested from the tokenizer but is not a model input.
# Forwarding it raises TypeError or is silently swallowed, depending on the
# transformers version, so strip it first.
model_inputs = {name: value for name, value in inputs.items() if name != "offset_mapping"}
outputs = model(**model_inputs, output_hidden_states=True)
hidden_states = outputs.hidden_states

In [None]:
# Number of hidden-state tensors, container type, and shape of the embedding output.
for item in (len(hidden_states), type(hidden_states), hidden_states[0].shape):
    print(item)

In [None]:
# Delete the old model and inputs *if they exist*, so re-running this cell
# (or running it out of order) does not raise NameError.
for _stale_name in ("model", "inputs"):
    globals().pop(_stale_name, None)

# run garbage collection
gc.collect()

# clear cached memory in CUDA
torch.cuda.empty_cache()


In [None]:
from sae_lens import SAE  # pip install sae-lens

# Gemma Scope residual-stream SAE for layer 20 (16k latents), moved to `device`.
sae, cfg_dict, sparsity = SAE.from_pretrained(
    sae_id="layer_20/width_16k/canonical",
    release="gemma-scope-2b-pt-res-canonical",
)
sae = sae.to(device)
# sae, cfg_dict, sparsity = SAE.from_pretrained(
#     release = "gemma-scope-2b-pt-att-canonical",
#     sae_id = "layer_20/width_16k/canonical",
# )


# sae, cfg_dict, sparsity = SAE.from_pretrained(
#     release = "gemma-scope-2b-pt-mlp-canonical",
#     sae_id = "layer_20/width_16k/canonical",
# )


In [None]:
# Keep `hidden_states` as the full per-layer tuple. Rebinding it to one layer made
# the later `hidden_states[21]` pick example 21 instead of layer 21.
hidden_states = outputs.hidden_states
layer_hidden = hidden_states[21]  # output of decoder layer 20 (index 0 = embeddings)

In [None]:
# Encode every token position of layer index 21. Index via `outputs` so this is
# correct even if `hidden_states` was rebound to a single layer.
sae_acts = sae.encode(outputs.hidden_states[21])

In [None]:
# Shapes and dtypes of the SAE codes and the layer-21 hidden states they came from.
sae_acts.shape, outputs.hidden_states[21].shape, outputs.hidden_states[21].dtype, sae_acts.dtype

In [None]:
from IPython.display import IFrame

html_template = (
    "https://neuronpedia.org/{}/{}/{}"
    "?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"
)


def get_dashboard_html(sae_release="gemma-2-2b", sae_id="20-gemmascope-res-16k", feature_idx=0):
    """Build the Neuronpedia embed URL for `feature_idx` of the given model/SAE slugs."""
    return html_template.format(sae_release, sae_id, feature_idx)


html = get_dashboard_html(feature_idx=10004, sae_release="gemma-2-2b", sae_id="20-gemmascope-res-16k")
IFrame(html, width=1200, height=600)