In [None]:
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizer
from huggingface_hub import hf_hub_download, notebook_login
import numpy as np
import torch


In [None]:
import json
import os
from collections import Counter

import torch.nn as nn

In [None]:
# Take the Hugging Face access token from the HF_TOKEN environment variable so
# a real token never has to be typed into (and committed with) the notebook.
# Falls back to "" — the previous hard-coded value — when the variable is unset.
huggingface_token = os.environ.get("HF_TOKEN", "")

In [None]:
# Interactive Hugging Face login from inside the notebook.
# NOTE(review): presumably required because the Gemma checkpoint loaded below is gated — confirm.
notebook_login()

In [None]:
# Inference only: turn autograd off globally so no graph is retained in memory.
torch.set_grad_enabled(False)

# Load the Gemma 2 2B causal LM; device_map="auto" leaves weight placement to the library.
model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b", device_map="auto")

In [None]:
# Tokenizer for the same checkpoint name that the model cell loads ("google/gemma-2-2b").
tokenizer =  AutoTokenizer.from_pretrained("google/gemma-2-2b")

In [None]:


# Specify the JSONL filename and the target prefix.
filename = "./data/train_test_data/transcript_componenttext_2010_1.jsonl"
# target_prefix = "405869_84102073"
target_prefix = "285467_78846340"

# Compare whole "_"-separated ID components rather than raw characters:
# a plain startswith() would also accept a key such as "285467_788463401_...",
# whose second ID merely begins with the same digits. Empty components are
# dropped so that "" (match everything) and a trailing "_" keep working.
prefix_tokens = [token for token in target_prefix.split("_") if token]

# This list will hold tuples of (key, text) for the matching records.
filtered_records = []

# Load the JSONL file, one JSON object per non-blank line.
with open(filename, "r", encoding="utf-8") as file:
    for line in file:
        line = line.strip()
        if not line:
            continue
        record = json.loads(line)
        # Each JSON object is expected to have a single key/value pair.
        for key, text in record.items():
            if key.split("_")[:len(prefix_tokens)] == prefix_tokens:
                filtered_records.append((key, text))

# Sort-key helper for (key, text) records.
# Keys are assumed to look like "154924_5435195_117_1_1"; ordering uses the
# 4th "_"-separated field (index 3), compared numerically.
def sort_key(item):
    """Return the 4th underscore-separated field of the record key as an int.

    `item` is a (key, text) pair; only the key is inspected. Raises IndexError
    when the key has fewer than four fields and ValueError when the field is
    not an integer.
    """
    record_key = item[0]
    fourth_field = record_key.split("_")[3]
    return int(fourth_field)

# Order the matching records by the numeric 4th field of their key.
sorted_records = sorted(filtered_records, key=sort_key)

# Stitch the ordered texts together, one record per line.
ordered_texts = [record_text for _key, record_text in sorted_records]
concatenated_text = "\n".join(ordered_texts)

# Show the assembled text.
print(concatenated_text)

In [None]:
# Use the tokenizer to convert it to tokens. Note that this implicitly adds a special "Beginning of Sequence" or <bos> token to the start
# The ids are moved to the device the model reports (it was placed with
# device_map='auto') instead of a hard-coded "cuda", so the inputs always
# land where the model expects them.
inputs = tokenizer.encode(concatenated_text, return_tensors="pt", add_special_tokens=True).to(model.device)
print(inputs)

# Pass it in to the model and generate text
# outputs = model.generate(input_ids=inputs, max_new_tokens=50)
# print(tokenizer.decode(outputs[0]))

In [None]:
# Fetch the SAE parameter archive from the Hugging Face Hub and keep its local path.
# Going by the file path, this is the layer-20, 16k-width SAE variant.
# force_download=False: an already-downloaded copy is reused rather than fetched again.
path_to_params = hf_hub_download(
    repo_id="google/gemma-scope-2b-pt-res",
    filename="layer_20/width_16k/average_l0_71/params.npz",
    force_download=False,
)


In [None]:
# Read every array out of the .npz archive and close the file handle right away.
# np.load returns a lazy NpzFile that would otherwise stay open for the whole
# session and re-read the archive on every key access. `params` ends up as a
# plain dict mapping parameter name -> NumPy array.
with np.load(path_to_params) as npz_file:
    params = {name: npz_file[name] for name in npz_file.files}

# The same arrays as CUDA tensors, keyed by parameter name.
pt_params = {k: torch.from_numpy(v).cuda() for k, v in params.items()}


In [None]:
# Shape of every loaded SAE tensor, keyed by parameter name.
{name: tensor.shape for name, tensor in pt_params.items()}

In [None]:
# Norm of W_enc reduced over dim 0, i.e. one value per column of the encoder matrix.
pt_params["W_enc"].norm(dim=0)

In [None]:
class JumpReLUSAE(nn.Module):
  """Sparse autoencoder whose encoder gates each latent with a per-latent threshold.

  Parameters (all created as zeros; real values are expected to be loaded via
  `load_state_dict`):
    W_enc      (d_model, d_sae)  encoder weights
    W_dec      (d_sae, d_model)  decoder weights
    threshold  (d_sae,)          per-latent activation threshold
    b_enc      (d_sae,)          encoder bias
    b_dec      (d_model,)        decoder bias
  """

  def __init__(self, d_model, d_sae):
    super().__init__()
    zeros = torch.zeros
    # Registration order is kept as-is so state_dict key order does not change.
    self.W_enc = nn.Parameter(zeros(d_model, d_sae))
    self.W_dec = nn.Parameter(zeros(d_sae, d_model))
    self.threshold = nn.Parameter(zeros(d_sae))
    self.b_enc = nn.Parameter(zeros(d_sae))
    self.b_dec = nn.Parameter(zeros(d_model))

  def encode(self, input_acts):
    """Map activations (..., d_model) to latent activations (..., d_sae)."""
    pre_acts = torch.matmul(input_acts, self.W_enc) + self.b_enc
    # A latent passes only when its pre-activation is strictly above its threshold;
    # the ReLU additionally zeroes negatives that pass a negative threshold.
    gate = pre_acts > self.threshold
    return gate * torch.relu(pre_acts)

  def decode(self, acts):
    """Map latent activations (..., d_sae) back to (..., d_model)."""
    return torch.matmul(acts, self.W_dec) + self.b_dec

  def forward(self, acts):
    """Encode then decode, returning the reconstruction of `acts`."""
    return self.decode(self.encode(acts))

In [None]:
# W_enc's shape supplies both constructor sizes: rows -> d_model, columns -> d_sae.
d_model, d_sae = params['W_enc'].shape
sae = JumpReLUSAE(d_model, d_sae)
# Copy the downloaded weights in; the key-matching report it returns is the cell output.
sae.load_state_dict(pt_params)

In [None]:

def gather_residual_activations(model, target_layer, inputs):
    """Run `model` on `inputs` and return the output of one decoder layer.

    A temporary forward hook is attached to `model.model.layers[target_layer]`
    and the model's forward pass is executed once.

    Args:
        model: module exposing its decoder layers as `model.model.layers`.
        target_layer: index of the layer whose output is captured.
        inputs: passed unchanged to `model.forward`.

    Returns:
        The hooked layer's output tensor. When the layer returns a sequence,
        its first element is returned. `None` if the layer never ran.

    Raises:
        Whatever the forward pass raises; the hook is removed in that case too.
    """
    target_act = None

    def gather_target_act_hook(mod, hook_inputs, outputs):
        nonlocal target_act  # write to the variable of the enclosing function
        # Layers may return either a tuple whose first item is the hidden state
        # or the hidden-state tensor itself; indexing a bare tensor with [0]
        # would silently drop the batch dimension, so only index non-tensors.
        target_act = outputs if isinstance(outputs, torch.Tensor) else outputs[0]
        return outputs

    handle = model.model.layers[target_layer].register_forward_hook(gather_target_act_hook)
    try:
        model.forward(inputs)
    finally:
        # Always detach the hook, otherwise a failed forward pass would leave it
        # on the model and it would fire on every later call.
        handle.remove()
    return target_act

In [None]:
# Capture the output of decoder layer 20 for the tokenized text.
# The index matches the "layer_20" in the SAE parameter file path used in this notebook.
target_act = gather_residual_activations(model, 20, inputs)

In [None]:
# Move the SAE's parameters to the GPU; the returned module's repr is the cell output.
sae.cuda()

In [None]:
# The SAE weights are float32, so cast the captured activations before encoding.
residual_fp32 = target_act.float()
sae_acts = sae.encode(residual_fp32)

# Reconstruct the activations from the latent activations.
recon = sae.decode(sae_acts)

In [None]:
# Fraction of variance explained by the reconstruction, skipping position 0
# along the second axis (the <bos> token that tokenization prepends).
target_fp32 = target_act[:, 1:].to(torch.float32)
mse = torch.mean((recon[:, 1:] - target_fp32) ** 2)
1 - mse / target_fp32.var()

In [None]:
# For each position, count the entries along the last (latent) axis whose activation exceeds 1.
(sae_acts > 1).sum(-1)

In [None]:
# Sanity check: display the shape of the latent-activation tensor.
sae_acts.shape

In [None]:
# Strongest latent at every position: `values` is the peak activation along
# the last axis, `inds` the index of the latent that attains it.
values, inds = torch.max(sae_acts, dim=-1)

inds

In [None]:
# positive
# Tally how often each latent index is the strongest one across positions.
list_val = inds.cpu().tolist()

# Flatten the per-sequence rows into one list of latent indices.
flat_list = [latent_idx for row in list_val for latent_idx in row]
counts = Counter(flat_list)

# (index, count) pairs, most frequent first.
sorted_counts = counts.most_common()

print("Summary of value counts (ranked high to low):")
for latent_idx, n_hits in sorted_counts:
    print(f"Value: {latent_idx}, Count: {n_hits}")

In [None]:
# negative 
# NOTE(review): this cell repeats the computation of the preceding "positive"
# cell on the same `inds`; it yields different numbers only if the upstream
# cells were re-run with another target_prefix in between — confirm intent.
list_val = inds.cpu().tolist()

# Flatten the per-sequence rows into one list of latent indices.
flat_list = [latent_idx for row in list_val for latent_idx in row]
counts = Counter(flat_list)

# (index, count) pairs, most frequent first.
sorted_counts = counts.most_common()

print("Summary of value counts (ranked high to low):")
for latent_idx, n_hits in sorted_counts:
    print(f"Value: {latent_idx}, Count: {n_hits}")

In [None]:
from IPython.display import IFrame
# Neuronpedia embed URL; the three placeholders are filled positionally with
# the release name, the SAE id and the feature index.
html_template = "https://neuronpedia.org/{}/{}/{}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"

def get_dashboard_html(sae_release = "gemma-2-2b", sae_id="20-gemmascope-res-16k", feature_idx=0):
    """Return the Neuronpedia embed URL for one SAE feature.

    Args:
        sae_release: first path segment of the URL.
        sae_id: second path segment of the URL.
        feature_idx: index of the feature, third path segment.

    Returns:
        `html_template` with the three values substituted in order.
    """
    path_parts = (sae_release, sae_id, feature_idx)
    return html_template.format(*path_parts)


In [None]:

# Embed the Neuronpedia dashboard for feature 4380 of the layer-20 16k SAE.
html = get_dashboard_html(sae_release = "gemma-2-2b", sae_id="20-gemmascope-res-16k", feature_idx=4380)
IFrame(html, width=1200, height=600)

In [None]:
# Embed the Neuronpedia dashboard for feature 12225; the commented-out line is the same call for feature 13860.
# html = get_dashboard_html(sae_release = "gemma-2-2b", sae_id="20-gemmascope-res-16k", feature_idx=13860)
html = get_dashboard_html(sae_release = "gemma-2-2b", sae_id="20-gemmascope-res-16k", feature_idx=12225)
IFrame(html, width=1200, height=600)