In [1]:
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizer
from huggingface_hub import hf_hub_download, notebook_login
import numpy as np
import torch


In [2]:
import json
from collections import Counter
import torch.nn as nn

In [None]:
huggingface_token = ""

In [4]:
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [5]:
torch.set_grad_enabled(False) # avoid blowing up mem

model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-2b",
    device_map='auto',
)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [6]:
tokenizer =  AutoTokenizer.from_pretrained("google/gemma-2-2b")

In [7]:


# Specify the JSONL filename and the target prefix.
filename = "./data/train_test_data/transcript_componenttext_2010_1.jsonl"
# target_prefix = "405869_84102073"
target_prefix = "285467_78846340"

# This list will hold tuples of (key, text) for the matching records.
filtered_records = []

# Load the JSONL file.
with open(filename, "r", encoding="utf-8") as file:
    for line in file:
        line = line.strip()
        if not line:
            continue
        record = json.loads(line)
        # Each JSON object is expected to have a single key/value pair.
        for key, text in record.items():
            if key.startswith(target_prefix):
                filtered_records.append((key, text))

# Define a helper function to extract the sorting element.
# We assume the key is formatted as "154924_5435195_117_1_1"
# and we want to sort by the 4th component (index 3 when split by '_').
def sort_key(item):
    tokens = item[0].split("_")
    # Convert the 4th token to an integer for proper numerical sorting.
    return int(tokens[3])

# Sort the filtered records by the fourth element in the key.
sorted_records = sorted(filtered_records, key=sort_key)

# Concatenate the texts in the sorted order.
concatenated_text = "\n".join(text for key, text in sorted_records)

# Optionally, print or save the concatenated text.
print(concatenated_text)

Ladies and gentlemen, thank you for standing by and welcome to the Q4 2009 Earnings Call. [Operator Instructions] And I would now like to turn the conference over to your host, Vice President of Investor Relations, Mr. Phil Johnson. Please go ahead, sir.
Good morning, and thanks for joining us for Eli Lilly & Co.'s Fourth Quarter 2009 Earnings Conference Call. I'm Phil Johnson, Vice President of Investor Relations. Joining me are our President, CEO and Chairman, John Lechleiter; our Chief Financial Officer, Derica Rice; our President of Lilly Research Laboratories, Dr. Steve Paul; and Ronika Pletcher and Nick Lemen from Investor Relations.During this conference call, we anticipate making projections and forward-looking statements based on our current expectations. Our actual results could differ materially due to a number of factors, including those listed on Slide 3 and those outlined in our latest 10-K and 10-Q filed with the Securities and Exchange Commission. The information we pro

In [8]:
# Use the tokenizer to convert it to tokens. Note that this implicitly adds a special "Beginning of Sequence" or <bos> token to the start
inputs = tokenizer.encode(concatenated_text, return_tensors="pt", add_special_tokens=True).to("cuda")
print(inputs)

# Pass it in to the model and generate text
# outputs = model.generate(input_ids=inputs, max_new_tokens=50)
# print(tokenizer.decode(outputs[0]))

tensor([[     2,  75114,    578,  ...,   1490,  61826, 235265]],
       device='cuda:0')


This is a friendly reminder - the current text generation call will exceed the model's predefined maximum length (8192). Depending on the model, you may observe exceptions, performance degradation, or nothing at all.


<bos>Ladies and gentlemen, thank you for standing by and welcome to the Q4 2009 Earnings Call. [Operator Instructions] And I would now like to turn the conference over to your host, Vice President of Investor Relations, Mr. Phil Johnson. Please go ahead, sir.
Good morning, and thanks for joining us for Eli Lilly & Co.'s Fourth Quarter 2009 Earnings Conference Call. I'm Phil Johnson, Vice President of Investor Relations. Joining me are our President, CEO and Chairman, John Lechleiter; our Chief Financial Officer, Derica Rice; our President of Lilly Research Laboratories, Dr. Steve Paul; and Ronika Pletcher and Nick Lemen from Investor Relations.During this conference call, we anticipate making projections and forward-looking statements based on our current expectations. Our actual results could differ materially due to a number of factors, including those listed on Slide 3 and those outlined in our latest 10-K and 10-Q filed with the Securities and Exchange Commission. The information w

In [9]:
path_to_params = hf_hub_download(
    repo_id="google/gemma-scope-2b-pt-res",
    filename="layer_20/width_16k/average_l0_71/params.npz",
    force_download=False,
)


In [10]:
params = np.load(path_to_params)
pt_params = {k: torch.from_numpy(v).cuda() for k, v in params.items()}


In [11]:
{k:v.shape for k, v in pt_params.items()}

{'W_dec': torch.Size([16384, 2304]),
 'W_enc': torch.Size([2304, 16384]),
 'b_dec': torch.Size([2304]),
 'b_enc': torch.Size([16384]),
 'threshold': torch.Size([16384])}

In [12]:
pt_params["W_enc"].norm(dim=0)

tensor([1.2101, 1.1695, 0.9836,  ..., 1.0630, 0.9997, 1.1070], device='cuda:0')

In [13]:
class JumpReLUSAE(nn.Module):
  def __init__(self, d_model, d_sae):

    super().__init__()
    self.W_enc = nn.Parameter(torch.zeros(d_model, d_sae))
    self.W_dec = nn.Parameter(torch.zeros(d_sae, d_model))
    self.threshold = nn.Parameter(torch.zeros(d_sae))
    self.b_enc = nn.Parameter(torch.zeros(d_sae))
    self.b_dec = nn.Parameter(torch.zeros(d_model))

  def encode(self, input_acts):
    pre_acts = input_acts @ self.W_enc + self.b_enc
    mask = (pre_acts > self.threshold)
    acts = mask * torch.nn.functional.relu(pre_acts)
    return acts

  def decode(self, acts):
    return acts @ self.W_dec + self.b_dec

  def forward(self, acts):
    acts = self.encode(acts)
    recon = self.decode(acts)
    return recon

In [14]:
sae = JumpReLUSAE(params['W_enc'].shape[0], params['W_enc'].shape[1])
sae.load_state_dict(pt_params)

<All keys matched successfully>

In [15]:

def gather_residual_activations(model, target_layer, inputs):
    target_act = None
    def gather_target_act_hook(mod, inputs, outputs):
        nonlocal target_act # make sure we can modify the target_act from the outer scope
        target_act = outputs[0]
        return outputs
    handle = model.model.layers[target_layer].register_forward_hook(gather_target_act_hook)
    _ = model.forward(inputs)
    handle.remove()
    return target_act

In [16]:
target_act = gather_residual_activations(model, 20, inputs)

In [17]:
sae.cuda()

JumpReLUSAE()

In [18]:
sae_acts = sae.encode(target_act.to(torch.float32))
recon = sae.decode(sae_acts)

In [19]:
1 - torch.mean((recon[:, 1:] - target_act[:, 1:].to(torch.float32)) **2) / (target_act[:, 1:].to(torch.float32).var())

tensor(0.3704, device='cuda:0')

In [20]:
(sae_acts > 1).sum(-1)

tensor([[7017,   72,  106,  ...,   39,   51,   35]], device='cuda:0')

In [21]:
sae_acts.shape

torch.Size([1, 10392, 16384])

In [22]:
values, inds = sae_acts.max(-1)

inds

tensor([[ 6631, 15743,  7400,  ...,   121,  8684,  8684]], device='cuda:0')

In [23]:
# positive
list_val = inds.cpu().tolist()

flat_list = list_val[0] if len(list_val) == 1 else [item for sublist in list_val for item in sublist]
counts = Counter(flat_list)

sorted_counts = counts.most_common()

print("Summary of value counts (ranked high to low):")
for value, count in sorted_counts:
    print(f"Value: {value}, Count: {count}")

Summary of value counts (ranked high to low):
Value: 8684, Count: 1352
Value: 13860, Count: 384
Value: 14537, Count: 323
Value: 15509, Count: 294
Value: 8221, Count: 238
Value: 10881, Count: 228
Value: 14767, Count: 207
Value: 10461, Count: 169
Value: 5816, Count: 162
Value: 12225, Count: 162
Value: 6840, Count: 159
Value: 1858, Count: 149
Value: 7400, Count: 118
Value: 5627, Count: 101
Value: 9982, Count: 97
Value: 13186, Count: 89
Value: 12459, Count: 84
Value: 6631, Count: 75
Value: 5895, Count: 75
Value: 12261, Count: 70
Value: 10192, Count: 66
Value: 15596, Count: 66
Value: 10867, Count: 65
Value: 6027, Count: 63
Value: 2222, Count: 62
Value: 7449, Count: 62
Value: 3223, Count: 57
Value: 12935, Count: 56
Value: 7767, Count: 52
Value: 6504, Count: 50
Value: 4223, Count: 49
Value: 2236, Count: 45
Value: 1564, Count: 44
Value: 10640, Count: 39
Value: 8820, Count: 39
Value: 5002, Count: 39
Value: 7136, Count: 37
Value: 7182, Count: 37
Value: 13478, Count: 37
Value: 3371, Count: 36
Val

In [None]:
# negative 
list_val = inds.cpu().tolist()

flat_list = list_val[0] if len(list_val) == 1 else [item for sublist in list_val for item in sublist]
counts = Counter(flat_list)

sorted_counts = counts.most_common()

print("Summary of value counts (ranked high to low):")
for value, count in sorted_counts:
    print(f"Value: {value}, Count: {count}")

Summary of value counts (ranked high to low):
Value: 8684, Count: 1302
Value: 13860, Count: 740
Value: 15509, Count: 389
Value: 14537, Count: 274
Value: 10881, Count: 260
Value: 10461, Count: 228
Value: 5816, Count: 193
Value: 14767, Count: 163
Value: 6840, Count: 161
Value: 1858, Count: 138
Value: 9982, Count: 124
Value: 7400, Count: 114
Value: 5895, Count: 113
Value: 5627, Count: 111
Value: 15596, Count: 107
Value: 13186, Count: 86
Value: 4380, Count: 82
Value: 12261, Count: 73
Value: 3344, Count: 69
Value: 10192, Count: 65
Value: 7449, Count: 60
Value: 4223, Count: 60
Value: 12459, Count: 58
Value: 6027, Count: 58
Value: 6631, Count: 57
Value: 3223, Count: 55
Value: 12935, Count: 53
Value: 3229, Count: 50
Value: 10867, Count: 46
Value: 1564, Count: 42
Value: 10640, Count: 42
Value: 12748, Count: 40
Value: 15147, Count: 39
Value: 2236, Count: 38
Value: 10873, Count: 38
Value: 11571, Count: 38
Value: 1374, Count: 37
Value: 11666, Count: 35
Value: 4615, Count: 35
Value: 7431, Count: 34

In [25]:
from IPython.display import IFrame
html_template = "https://neuronpedia.org/{}/{}/{}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"

def get_dashboard_html(sae_release = "gemma-2-2b", sae_id="20-gemmascope-res-16k", feature_idx=0):
    return html_template.format(sae_release, sae_id, feature_idx)


In [None]:

html = get_dashboard_html(sae_release = "gemma-2-2b", sae_id="20-gemmascope-res-16k", feature_idx=4380)
IFrame(html, width=1200, height=600)

In [26]:
# html = get_dashboard_html(sae_release = "gemma-2-2b", sae_id="20-gemmascope-res-16k", feature_idx=13860)
html = get_dashboard_html(sae_release = "gemma-2-2b", sae_id="20-gemmascope-res-16k", feature_idx=12225)
IFrame(html, width=1200, height=600)