In [1]:
import sys
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

sys.path.append('sparsify')

from sparsify import Sae

import pandas as pd
import plotly.express as px
from types import SimpleNamespace

Triton not installed, using eager implementation of SAE decoder.


In [2]:
# Inference only: turn autograd off globally so generate() / sae.encode() don't record graphs.
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x26ab6601520>

In [3]:
# Use the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

In [4]:
# Make DEVICE the default for tensor creation, so tensors built later (presumably including the
# tokenizer's return_tensors='pt' output) land there without explicit .to(DEVICE) calls.
# NOTE(review): the model is never moved explicitly either — presumably it ends up on DEVICE
# because of this default; keep this cell before the model-loading cell.
torch.set_default_device(DEVICE)
torch.get_default_device()

device(type='cuda', index=0)

In [5]:
# Load the distilled R1 checkpoint analysed throughout this notebook.
# NOTE(review): no torch_dtype is passed, so the weights are presumably materialised in the
# transformers default dtype rather than the checkpoint's bfloat16 (model.config.torch_dtype) — confirm.
tokenizer = AutoTokenizer.from_pretrained('deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B')
model = AutoModelForCausalLM.from_pretrained('deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B')

In [6]:
def inference(chat_prompt, max_new_tokens=400, temperature=0.6, add_generation_prompt=True, output_hidden_states=True, do_sample=True, seed=None):
    """Generate a continuation for one chat conversation with the global `model`/`tokenizer`.

    Args:
        chat_prompt: a single conversation, i.e. a list of {'role': ..., 'content': ...} dicts.
        max_new_tokens: upper bound on the number of generated tokens (prompt excluded).
        temperature: sampling temperature; the model card recommends 0.5-0.7 (0.6 default).
            A value <= 0 (or None) falls back to greedy decoding instead of erroring.
        add_generation_prompt: let the chat template append the assistant header (for this
            tokenizer: '<｜Assistant｜><think>' plus a newline). Pass False when the prompt
            content already ends with its own assistant header.
        output_hidden_states: also return every layer's hidden states for every step.
        do_sample: sample (True) or decode greedily (False). Passed explicitly so `temperature`
            is honoured regardless of what the model's saved generation config says.
        seed: optional torch RNG seed, for reproducible sampled runs.

    Returns:
        The `generate` output (return_dict_in_generate=True): `.sequences` holds the prompt ids
        followed by the generated ids; `.hidden_states` is populated when requested.
    """
    tokenized_inputs = tokenizer.apply_chat_template(
        chat_prompt,
        tokenize=True,
        add_generation_prompt=add_generation_prompt,
        return_tensors='pt',
        return_dict=True,
    )

    if seed is not None:
        torch.manual_seed(seed)

    # `temperature` only has an effect when sampling, and generate() rejects a non-positive
    # temperature when sampling, so treat that case as a request for greedy decoding.
    use_sampling = do_sample and temperature is not None and temperature > 0
    sampling_kwargs = {'do_sample': True, 'temperature': temperature} if use_sampling else {'do_sample': False}

    # Pass the pad id explicitly; otherwise generate() falls back to eos_token_id on its own and
    # logs "Setting `pad_token_id` to `eos_token_id`" on every call.
    pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id

    return model.generate(
        tokenized_inputs.input_ids,
        attention_mask=tokenized_inputs.attention_mask,
        max_new_tokens=max_new_tokens,
        pad_token_id=pad_token_id,
        output_hidden_states=output_hidden_states,
        return_dict_in_generate=True,
        **sampling_kwargs,
    )

In [7]:
def decode_inference(model_out_sequences, skip_special_tokens=False):
    """Turn token-id sequences back into text with the global `tokenizer`.

    Args:
        model_out_sequences: an iterable of id sequences, e.g. `generate(...).sequences`.
            Each element is decoded separately, so a 1-D tensor of ids yields one string
            per id.
        skip_special_tokens: drop the tokens the tokenizer flags as special.

    Returns:
        list[str] with one decoded string per element of `model_out_sequences`.
    """
    return tokenizer.batch_decode(
        model_out_sequences,
        skip_special_tokens=skip_special_tokens,
    )

In [8]:
def create_chat_prompts(prompts, suffix=''):
    """Wrap each raw prompt string as a single-turn user conversation.

    Args:
        prompts: iterable of prompt strings.
        suffix: text appended verbatim to every prompt (e.g. a reasoning directive or a
            pre-filled assistant header).

    Returns:
        list of conversations; each is a one-element list holding a
        {'role': 'user', 'content': prompt + suffix} message.
    """
    def as_user_turn(text):
        return [{'role': 'user', 'content': text + suffix}]

    return list(map(as_user_turn, prompts))

In [9]:
# Directive appended to every problem. It follows the model card's advice to ask for
# step-by-step reasoning and a \boxed{} answer; "in a few words" presumably aims to keep
# generations short.
REASONING_SUFFIX = ' Please reason step by step in a few words, and put your final answer within \\boxed{}'

# Small arithmetic word problems used as probes throughout the notebook.
problem_prompts = [
    '2 + x = 10 What is the value of x?',
    'Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?',
    'Weng earns $12 an hour for babysitting. Yesterday, she just did 50 minutes of babysitting. How much did she earn?'
]


chat_prompts = create_chat_prompts(problem_prompts, suffix=REASONING_SUFFIX)

In [10]:
# Put the assistant header inside the user content itself. When used with
# add_generation_prompt=False, this reproduces the templated prompt without the trailing
# '<think>' + newline that the template would otherwise append, so the model must open (or
# skip) the reasoning block on its own.
DIRECT_SUFFIX = '<｜Assistant｜>'

direct_chat_prompts = create_chat_prompts(problem_prompts, suffix=DIRECT_SUFFIX)

In [11]:
# Pre-fill an empty '<think>\n\n</think>' block so the model answers without reasoning,
# i.e. force the "bypass thinking" pattern mentioned in the model card's usage notes.
BYPASS_REASONING_SUFFIX = '<｜Assistant｜><think>\n\n</think>'

bypass_reasoning_chat_prompts = create_chat_prompts(problem_prompts, suffix=BYPASS_REASONING_SUFFIX)

### Usage Recommendations
Source: [DeepSeek-R1-Distill-Qwen-1.5B model card — Usage Recommendations](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B#usage-recommendations)


1. Set the temperature within the range of 0.5-0.7 (0.6 is recommended) to prevent endless repetitions or incoherent outputs.

2. Avoid adding a system prompt; all instructions should be contained within the user prompt.

3. For mathematical problems, it is advisable to include a directive in your prompt such as: "Please reason step by step, and put your final answer within \boxed{}."

4. When evaluating model performance, it is recommended to conduct multiple tests and average the results.

Additionally, we have observed that the DeepSeek-R1 series models tend to bypass thinking pattern (i.e., outputting "<think>\n\n</think>") when responding to certain queries, which can adversely affect the model's performance. To ensure that the model engages in thorough reasoning, we recommend enforcing the model to initiate its response with "<think>\n" at the beginning of every output.

### Testing the Model

In [10]:
# Inspect the architecture (e.g. num_hidden_layers, hidden_size) before choosing an SAE layer.
model.config

Qwen2Config {
  "_attn_implementation_autoset": true,
  "_name_or_path": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
  "architectures": [
    "Qwen2ForCausalLM"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 151643,
  "eos_token_id": 151643,
  "hidden_act": "silu",
  "hidden_size": 1536,
  "initializer_range": 0.02,
  "intermediate_size": 8960,
  "max_position_embeddings": 131072,
  "max_window_layers": 21,
  "model_type": "qwen2",
  "num_attention_heads": 12,
  "num_hidden_layers": 28,
  "num_key_value_heads": 2,
  "rms_norm_eps": 1e-06,
  "rope_scaling": null,
  "rope_theta": 10000,
  "sliding_window": null,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.48.3",
  "use_cache": true,
  "use_mrope": false,
  "use_sliding_window": false,
  "vocab_size": 151936
}

In [12]:
# Show exactly what the chat template emits with and without the generation prompt, to see
# which tokens add_generation_prompt=True appends after the user turn.
test_prompt_a = [
    {
        'role': 'user',
        'content': '2 + 3 = x What is the value of x? Please reason step by step, and put your final answer within \\boxed{}'
    }
]

chat_text_test_result = tokenizer.apply_chat_template(test_prompt_a, tokenize=False, add_generation_prompt=True)
no_generation_prompt_result = tokenizer.apply_chat_template(test_prompt_a, tokenize=False, add_generation_prompt=False)

print(f'With generation prompt: "{chat_text_test_result}"')
print(f'Without generation prompt: "{no_generation_prompt_result}"')

With generation prompt: "<｜begin▁of▁sentence｜><｜User｜>2 + 3 = x What is the value of x? Please reason step by step, and put your final answer within \boxed{}<｜Assistant｜><think>
"
Without generation prompt: "<｜begin▁of▁sentence｜><｜User｜>2 + 3 = x What is the value of x? Please reason step by step, and put your final answer within \boxed{}"


#### With Reasoning Prompt

In [13]:
# One generation on a toy equation using the model card's recommended reasoning directive;
# hidden states are not needed, so they are turned off.
test_chat_prompt = [
    {
        'role': 'user',
        'content': '2 + 3 = x What is the value of x? Please reason step by step, and put your final answer within \\boxed{}'
    }
]

decode_inference(
    inference(test_chat_prompt, output_hidden_states=False).sequences
)

Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
From v4.47 onwards, when a model cache is to be returned, `generate` will return a `Cache` instance instead by default (as opposed to the legacy tuple of tuples format). If you want to keep returning the legacy format, please set `return_legacy_cache=True`.


["<｜begin▁of▁sentence｜><｜User｜>2 + 3 = x What is the value of x? Please reason step by step, and put your final answer within \\boxed{}<｜Assistant｜><think>\nAlright, so I've got this equation here: 2 plus 3 equals x. Hmm, okay, that seems straightforward, but let me think through it step by step to make sure I understand everything properly.\n\nFirst, I know that in math, equations are statements that show that two expressions are equal. So, in this case, the expression on the left side is 2 + 3, and the expression on the right side is x. My goal is to find out what value x has.\n\nLet me start by recalling what addition means. Addition is combining two numbers to get a total. So, when I add 2 and 3 together, I'm essentially combining two quantities. It's like having 2 apples and then getting 3 more apples. How many apples do I have in total? I think that's 5 apples. So, does that mean x is 5?\n\nWait, let me make sure I'm not just guessing here. Maybe I should write it down to visuali

In [14]:
# One generation per reasoning prompt (template adds the '<think>' opener); only the decoded
# text is inspected here.
for i in range(len(chat_prompts)):
    print(f'** Inference Prompt {i+1}:')
    print(
        decode_inference(
            inference(chat_prompts[i], output_hidden_states=False).sequences
        )
    )
    print('\n')

Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.


** Inference Prompt 1:


Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.


["<｜begin▁of▁sentence｜><｜User｜>2 + x = 10 What is the value of x? Please reason step by step in a few words, and put your final answer within \\boxed{}<｜Assistant｜><think>\nOkay, so I have this equation here: 2 plus x equals 10. Hmm, I need to find the value of x. Alright, let me think about how to approach this. I remember from class that when you have an equation like this, you want to isolate the variable, which in this case is x. \n\nSo, the equation is 2 + x = 10. I need to get x by itself on one side of the equation. To do that, I should get rid of the 2 that's being added to x. How do I do that? Well, the opposite of addition is subtraction, right? So if I subtract 2 from both sides of the equation, that should help me isolate x. \n\nLet me write that out step by step. Starting with 2 + x = 10, if I subtract 2 from both sides, it should look like this:\n\n2 + x - 2 = 10 - 2\n\nSimplifying both sides, the 2 and -2 on the left side cancel each other out, leaving just x. On the rig

Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.


['<｜begin▁of▁sentence｜><｜User｜>Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? Please reason step by step in a few words, and put your final answer within \\boxed{}<｜Assistant｜><think>\nAlright, so I have this problem here about Natalia selling clips. Let me read it again: "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?" Hmm, okay, let\'s break this down step by step.\n\nFirst, in April, Natalia sold clips to 48 of her friends. So, that\'s pretty straightforward. She sold to 48 friends, so I think that means she sold 48 clips in April. Yeah, that makes sense. It says "to 48 of her friends," so each friend bought one clip, so 48 clips total in April.\n\nThen, in May, she sold half as many clips as she did in April. So, if she sold 48 in April, half of that would be... let m

#### Without the `<think>` token in the prompt

In [15]:
# These prompts already end with '<｜Assistant｜>' (DIRECT_SUFFIX), so the template must not add
# its own generation prompt; the model decides whether to open a <think> block itself.
for i in range(len(direct_chat_prompts)):
    print(f'** Inference Prompt {i+1}:')
    print(
        decode_inference(
            inference(direct_chat_prompts[i], add_generation_prompt=False, output_hidden_states=False).sequences
        )
    )
    print('\n')

Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.


** Inference Prompt 1:


Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.


["<｜begin▁of▁sentence｜><｜User｜>2 + x = 10 What is the value of x?<｜Assistant｜><think>\nFirst, I need to solve the equation 2 + x = 10 for x.\n\nTo isolate x, I'll subtract 2 from both sides of the equation.\n\nThis simplifies to x = 8.\n\nSo, the value of x is 8.\n</think>\n\nTo solve the equation \\(2 + x = 10\\) for \\(x\\), follow these steps:\n\n1. **Subtract 2 from both sides** of the equation to isolate \\(x\\):\n   \\[\n   2 + x - 2 = 10 - 2\n   \\]\n   \n2. **Simplify both sides**:\n   \\[\n   x = 8\n   \\]\n   \n**Final Answer:**\n\\[\n\\boxed{8}\n\\]<｜end▁of▁sentence｜>"]


** Inference Prompt 2:


Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.


["<｜begin▁of▁sentence｜><｜User｜>Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?<｜Assistant｜><think>\nFirst, I need to determine how many clips Natalia sold in May. The problem states that she sold half as many clips in May as she did in April. Since she sold 48 clips in April, I'll calculate half of 48.\n\nNext, I'll add the number of clips sold in April to the number sold in May to find the total for both months.\n</think>\n\nLet's determine how many clips Natalia sold in April and May and then find the total.\n\n1. ** Clips Sold in April:**\n   - Natalia sold **48 clips** in April.\n\n2. ** Clips Sold in May:**\n   - She sold **half as many clips** in May as she did in April.\n   - **Calculation:** \\( \\frac{48}{2} = 24 \\) clips.\n\n3. **Total Clips Sold in April and May:**\n   - **Sum:** \\( 48 \\text{ (April)} + 24 \\text{ (May)} = 72 \\text{ clips} \\).\n\n\\[\n\\boxed{72}\

#### Bypassing reasoning with an empty `<think></think>` block in the prompt

In [16]:
# These prompts already contain the assistant header plus an empty reasoning block
# (BYPASS_REASONING_SUFFIX), so no generation prompt is added by the template.
for i in range(len(bypass_reasoning_chat_prompts)):
    print(f'** Inference Prompt {i+1}:')
    print(
        decode_inference(
            inference(bypass_reasoning_chat_prompts[i], add_generation_prompt=False, output_hidden_states=False).sequences
        )
    )
    print('\n')

Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.


** Inference Prompt 1:


Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.


['<｜begin▁of▁sentence｜><｜User｜>2 + x = 10 What is the value of x?<｜Assistant｜><think>\n\n</think>\n\nTo find the value of \\( x \\) in the equation \\( 2 + x = 10 \\), follow these steps:\n\n1. **Isolate the variable** \\( x \\) by subtracting 2 from both sides of the equation:\n   \\[\n   2 + x - 2 = 10 - 2\n   \\]\n   Simplifying both sides gives:\n   \\[\n   x = 8\n   \\]\n\n**Final Answer:** \\( \\boxed{8} \\)<｜end▁of▁sentence｜>']


** Inference Prompt 2:


Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.


['<｜begin▁of▁sentence｜><｜User｜>Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?<｜Assistant｜><think>\n\n</think>\n\nTo determine the total number of clips Natalia sold in April and May, follow these steps:\n\n1. **April Sales**: Natalia sold 48 clips to her friends.\n   \n2. **May Sales**: She sold half as many clips in May compared to April.\n   - **Calculation**: \\( \\frac{48}{2} = 24 \\) clips.\n   \n3. **Total Sales**: Add the sales from both months.\n   - **Calculation**: \\( 48 + 24 = 72 \\) clips.\n\n**Final Answer:** Natalia sold a total of 72 clips in April and May.<｜end▁of▁sentence｜>']


** Inference Prompt 3:
['<｜begin▁of▁sentence｜><｜User｜>Weng earns $12 an hour for babysitting. Yesterday, she just did 50 minutes of babysitting. How much did she earn?<｜Assistant｜><think>\n\n</think>\n\nTo determine how much Weng earned for babysitting, we can use the following steps:\n\

## SAE Experiments

### Test k-sparse SAE

In [438]:
# Decoder layer analysed below; it picks both the SAE hookpoint and the hidden_states index.
model_sae_layer = 10

# Pretrained sparsify SAE for the MLP hookpoint of that layer, placed on the same device as the model.
sae = Sae.load_from_hub('EleutherAI/sae-DeepSeek-R1-Distill-Qwen-1.5B-65k', hookpoint=f'layers.{model_sae_layer}.mlp', device=DEVICE)

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

In [12]:
# Re-run the third problem (Weng's babysitting) with hidden states enabled, so they can be fed to the SAE.
# NOTE(review): inference() is not seeded here, so this sample — and every SAE result derived
# from it below — can differ between runs.
input_prompt = chat_prompts[2]

model_outs_for_sae = inference(input_prompt)

Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
From v4.47 onwards, when a model cache is to be returned, `generate` will return a `Cache` instance instead by default (as opposed to the legacy tuple of tuples format). If you want to keep returning the legacy format, please set `return_legacy_cache=True`.


In [13]:
# Tokenize the prompt exactly as inference() does, to know how many leading ids of the
# generated sequence belong to the prompt.
model_input_prompt_tokens = tokenizer.apply_chat_template(input_prompt, tokenize=True, add_generation_prompt=True, return_tensors='pt', return_dict=True).input_ids

print(
    decode_inference(model_input_prompt_tokens)
)

['<｜begin▁of▁sentence｜><｜User｜>Weng earns $12 an hour for babysitting. Yesterday, she just did 50 minutes of babysitting. How much did she earn? Please reason step by step in a few words, and put your final answer within \\boxed{}<｜Assistant｜><think>\n']


In [14]:
# Sanity check: generate() returns prompt + continuation, so the first prompt-length ids of the
# generated sequence must equal the re-tokenized prompt. Iterate over the sequence dimension
# (size(1)): len() of this (1, seq_len) tensor is the batch size and would only check position 0.
for i in range(model_input_prompt_tokens.size(1)):
    if model_input_prompt_tokens[0, i] != model_outs_for_sae.sequences[0, i]:
        print(f'Fail at {i}: {model_input_prompt_tokens[0, i]} - {model_outs_for_sae.sequences[0, i]}')
        break
else:
    print('Prompt tokens match the start of the generated sequence.')

In [439]:
# Inspect the structure of generate()'s output: prompt vs generated token counts, and the layout
# of hidden_states (one entry per generation step, each a tuple with num_hidden_layers + 1 tensors).
input_prompt_tokens_len = model_input_prompt_tokens.size()[1]

full_output_tokens_len = model_outs_for_sae.sequences.size()[1]
output_tokens_len = full_output_tokens_len - input_prompt_tokens_len

print(f'Input token num: {input_prompt_tokens_len}')
print(f'Output token num: {full_output_tokens_len} | {output_tokens_len} (Not counting input tokens)')

print('Hidden states output len:', len(model_outs_for_sae.hidden_states))
print('Inner hidden states len:', len(model_outs_for_sae.hidden_states[0]))

# (batch_size, sequence_length, hidden_size)
# Step 0 covers all prompt positions; every later step covers a single new position.
print('Within inner hidden state size:')
print(f'- First state: {model_outs_for_sae.hidden_states[0][0].size()}')
print(f'- Second & subsequent states: {model_outs_for_sae.hidden_states[1][0].size()}')

print()

print('Decoded output:')
print(
    decode_inference(model_outs_for_sae.sequences)
)

Input token num: 55
Output token num: 455 | 400 (Not counting input tokens)
Hidden states output len: 400
Inner hidden states len: 29
Within inner hidden state size:
- First state: torch.Size([1, 55, 1536])
- Second & subsequent states: torch.Size([1, 1, 1536])

Decoded output:
["<｜begin▁of▁sentence｜><｜User｜>Weng earns $12 an hour for babysitting. Yesterday, she just did 50 minutes of babysitting. How much did she earn? Please reason step by step in a few words, and put your final answer within \\boxed{}<｜Assistant｜><think>\nOkay, so Weng earns $12 an hour for babysitting, and she just did 50 minutes of babysitting yesterday. Hmm, I need to figure out how much she earned for those 50 minutes. Let me think about this step by step.\n\nFirst, her hourly rate is $12. That means for every hour she works, she gets paid $12. But she didn't work an entire hour; she only worked 50 minutes. So I need to calculate her earnings based on that 50-minute duration.\n\nI remember that to find out how m

In [16]:
# Confirm that only generation step 0 (the prompt prefill) spans more than one position.
for i in range(len(model_outs_for_sae.hidden_states)):
    if model_outs_for_sae.hidden_states[i][0].size(1) > 1:
        print(f'Found seq_length > 1 at tuple index: {i}')

Found seq_length > 1 at tuple index: 0


In [440]:
# Get layer `model_sae_layer` acts: one row per token position, in sequence order.
# NOTE(review): in HF transformers, hidden_states[step][0] is the embedding output, so
# hidden_states[step][model_sae_layer] is the residual stream *entering* decoder layer
# `model_sae_layer` (the output of the previous layer). The SAE was loaded for hookpoint
# 'layers.{model_sae_layer}.mlp', which presumably expects that layer's MLP output instead;
# consider capturing it with a forward hook on model.model.layers[model_sae_layer].mlp during
# generation — TODO confirm against how the sparsify checkpoint was trained. The very large
# activations on the first prompt positions may be related to this mismatch.

model_sae_layer_acts = []

for step_idx in range(len(model_outs_for_sae.hidden_states)):
    step_layer_acts = model_outs_for_sae.hidden_states[step_idx][model_sae_layer]
    # Flatten (batch, seq, hidden) -> (batch * seq, hidden); with batch size 1 this just drops the batch dim
    flattened_acts = step_layer_acts.reshape(-1, step_layer_acts.size(-1))
    model_sae_layer_acts.append(flattened_acts)

# Row i is sequence position i. The last generated token never gets a hidden state, so
# there is one row fewer than sequences has tokens.
all_activations = torch.cat(model_sae_layer_acts, dim=0)
# Drop step 0 (all prompt positions); the rest starts at the first generated position.
generated_only_acts = torch.cat(model_sae_layer_acts[1:], dim=0)

print(f'All acts size: {all_activations.size()}')
print(f'Generated tokens only size: {generated_only_acts.size()}')

All acts size: torch.Size([454, 1536])
Generated tokens only size: torch.Size([399, 1536])


In [441]:
def print_limit_acts(encoder_out):
    """Print the smallest and the largest value found in `encoder_out.top_acts`."""
    acts = encoder_out.top_acts
    lowest, highest = acts.min(), acts.max()
    print(f'- Min Activation: {lowest}')
    print(f'- Max Activation: {highest}')

In [442]:
# Encode every position (prompt + generated); the result exposes top_acts / top_indices.
all_acts_encoder_out = sae.encode(all_activations)

In [443]:
# (positions, k): number of active latents the SAE keeps per position.
all_acts_encoder_out.top_acts.size()

torch.Size([454, 32])

In [444]:
# Raw encoder output, for a quick look at activation values and feature indices.
all_acts_encoder_out

EncoderOutput(top_acts=tensor([[ 974.5755, 1080.3790, 4630.9312,  ..., 1021.4194,  981.2013,
          966.9274],
        [ 964.7895, 1067.2743, 4580.1421,  ..., 1009.8358,  970.7190,
          956.0707],
        [  13.7577,   10.5662,  102.7398,  ...,   11.4450,   12.3541,
           10.3635],
        ...,
        [  30.0918,    7.8799,    7.5453,  ...,    6.5066,    7.9494,
            6.2494],
        [  23.4145,    6.5452,    6.7281,  ...,    7.9307,    9.2657,
            6.5165],
        [  24.5339,    5.9652,    5.1287,  ...,    9.2237,    6.5299,
            4.8280]], device='cuda:0'), top_indices=tensor([[   65,   569,  1038,  ..., 60642, 62202, 53549],
        [   65,   569,  1038,  ..., 60642, 62202, 12740],
        [  490,   686,  1038,  ..., 59978, 64399, 11100],
        ...,
        [ 1038,  3477,  4608,  ..., 58023, 60685, 59361],
        [ 1038,  3581,  5954,  ..., 61592, 63979, 31157],
        [ 1038,  2920,  3581,  ..., 58023, 58202, 21310]], device='cuda:0'))

In [445]:
# Same encoding restricted to generated positions (prompt rows excluded).
generated_only_acts_encoder_out = sae.encode(generated_only_acts)

In [446]:
# Compare activation ranges: the extreme maximum only shows up when prompt positions are
# included, the generated-only maximum is far lower.
print('All tokens min and max acts:')
print_limit_acts(all_acts_encoder_out)

print('\nGenerated tokens only min and max acts:')
print_limit_acts(generated_only_acts_encoder_out)

All tokens min and max acts:
- Min Activation: 3.6030309200286865
- Max Activation: 4630.93115234375

Generated tokens only min and max acts:
- Min Activation: 3.6030309200286865
- Max Activation: 86.89747619628906


In [524]:
def get_top_k_encoder_out_by_acts(encoder_out, top_k=None, no_feature_repetition=False):
    """Pick the strongest SAE activations across all token positions.

    `encoder_out.top_acts` and `encoder_out.top_indices` are flattened together, so every
    (activation, feature index) pair is one candidate regardless of the row it came from.

    Args:
        encoder_out: object with equally-shaped `top_acts` and `top_indices` tensors
            (e.g. the output of `sae.encode`).
        top_k: number of results to keep; None, or a value larger than the number of
            candidates, keeps all of them.
        no_feature_repetition: if True, each feature index appears at most once, ranked by
            its maximum activation over all positions.

    Returns:
        SimpleNamespace(top_acts, top_indices), sorted by activation in descending order.
    """
    top_acts_flat = encoder_out.top_acts.flatten()
    top_indices_flat = encoder_out.top_indices.flatten()

    if no_feature_repetition:
        # Vectorised per-feature maximum. This replaces a Python loop that paid one device
        # sync (.item() / tensor comparison) per flattened element.
        unique_features, inverse = torch.unique(top_indices_flat, return_inverse=True)
        feature_max_acts = torch.full(
            (unique_features.numel(),),
            float('-inf'),
            dtype=top_acts_flat.dtype,
            device=top_acts_flat.device,
        ).scatter_reduce(0, inverse, top_acts_flat, reduce='amax')

        k = unique_features.numel() if top_k is None else min(top_k, unique_features.numel())
        top_acts_values, order = torch.topk(feature_max_acts, k)
        return SimpleNamespace(
            top_acts=top_acts_values,
            top_indices=unique_features[order]
        )

    k = top_acts_flat.numel() if top_k is None else min(top_k, top_acts_flat.numel())
    top_acts_values, top_flat_acts_pos = torch.topk(top_acts_flat, k)

    return SimpleNamespace(
        top_acts=top_acts_values,
        top_indices=top_indices_flat[top_flat_acts_pos]
    )

In [448]:
def get_top_k_feature_freq_encoder_out(all_steps_feature_indices, top_k=None):
    """Rank SAE features by how many times they occur in a tensor of active feature indices.

    Args:
        all_steps_feature_indices: integer tensor of feature indices (any shape), e.g.
            `sae.encode(...).top_indices`.
        top_k: number of features to keep; None, or a value larger than the number of
            distinct features, keeps all of them.

    Returns:
        SimpleNamespace(feature_indices, freqs), sorted by frequency in descending order.
    """
    # torch.unique already yields only the features that occur. The previous
    # bincount/nonzero().squeeze() approach collapsed to a 0-d tensor when exactly one
    # distinct feature was present, so the following .size(0) raised.
    features, counts = torch.unique(all_steps_feature_indices.flatten(), return_counts=True)

    k = features.numel() if top_k is None else min(top_k, features.numel())
    top_k_freqs, order = torch.topk(counts, k)

    return SimpleNamespace(
        feature_indices=features[order],
        freqs=top_k_freqs
    )

In [None]:
def relate_feature_acts_to_tokens(encoder_out, feature_idx, sample_tokens, top_k=None, no_token_repetition=False):
    """Find which tokens one SAE feature fires on, and how strongly.

    Row `i` of `encoder_out.top_acts` / `encoder_out.top_indices` is paired with
    `sample_tokens[i]`, so `sample_tokens` must be aligned with the encoded rows (it may be
    longer; extra trailing tokens are ignored).

    Args:
        encoder_out: object with equally-shaped 2-D `top_acts` and `top_indices` tensors
            (rows = token positions), e.g. the output of `sae.encode`.
        feature_idx: SAE feature index to inspect.
        sample_tokens: token ids (tensor or array-like) aligned with the encoder rows.
        top_k: number of results to keep; None keeps all.
        no_token_repetition: if True, each token id appears once, ranked by its maximum
            activation over all of its positions.

    Returns:
        SimpleNamespace(token_ids, top_acts), sorted by activation in descending order.
        When the feature never fires, both tensors are empty (the same return type as the
        normal path, so callers can always use `.token_ids` / `.top_acts`).
    """
    device = encoder_out.top_acts.device
    feature_acts = torch.where(
        encoder_out.top_indices == feature_idx,
        encoder_out.top_acts,
        torch.tensor(0.0, device=device)
    )
    # Per row: the feature's activation, or 0 where it is not among the row's active latents.
    max_acts_per_token = feature_acts.max(dim=1).values

    active_mask = max_acts_per_token > 0
    if not active_mask.any():
        print(f'Non-dead feature activations weren\'t found for feature with index {feature_idx}')
        return SimpleNamespace(
            token_ids=torch.empty(0, dtype=torch.long, device=device),
            top_acts=torch.empty(0, dtype=max_acts_per_token.dtype, device=device)
        )

    active_acts = max_acts_per_token[active_mask]
    active_positions = active_mask.nonzero(as_tuple=True)[0]
    active_tokens = torch.as_tensor(sample_tokens, device=device)[active_positions]

    if no_token_repetition:
        unique_tokens, inverse = torch.unique(active_tokens, return_inverse=True)
        # Per-token maximum; every active value is > 0, so a zero initial value is safe.
        candidate_acts = torch.zeros(
            unique_tokens.numel(), dtype=active_acts.dtype, device=device
        ).scatter_reduce(0, inverse, active_acts, reduce='amax')
        candidate_tokens = unique_tokens
    else:
        candidate_acts = active_acts
        candidate_tokens = active_tokens

    k = candidate_acts.numel() if top_k is None else min(top_k, candidate_acts.numel())
    top_k_acts, order = torch.topk(candidate_acts, k)

    return SimpleNamespace(
        token_ids=candidate_tokens[order],
        top_acts=top_k_acts
    )

In [498]:
# Number of (activation, feature) pairs shown in the top-activation bar plots.
TOP_K_ENCODER_OUT_BY_ACTS = 400
# Number of most frequently active features shown in the frequency bar plots.
TOP_K_ENCODER_OUT_FEATURE_FREQ = 70

In [503]:
# Strongest individual activations (the same feature may appear several times).
top_k_all_acts_encoder_out = get_top_k_encoder_out_by_acts(all_acts_encoder_out, TOP_K_ENCODER_OUT_BY_ACTS)
top_k_generated_only_acts_encoder_out = get_top_k_encoder_out_by_acts(generated_only_acts_encoder_out, TOP_K_ENCODER_OUT_BY_ACTS)

In [499]:
# Features that appear among the SAE's active latents at the most positions.
top_k_all_acts_feature_freq = get_top_k_feature_freq_encoder_out(all_acts_encoder_out.top_indices, TOP_K_ENCODER_OUT_FEATURE_FREQ)
top_k_generated_only_acts_feature_freq = get_top_k_feature_freq_encoder_out(generated_only_acts_encoder_out.top_indices, TOP_K_ENCODER_OUT_FEATURE_FREQ)

In [530]:
def plot_simple_bar(df, x, y, title, labels):
    """Show a plotly bar chart of `df[y]` against `df[x]`, with a log-scale y axis.

    X tick labels are rotated to vertical so long categorical labels stay readable.
    """
    fig = px.bar(df, x=x, y=y, title=title, labels=labels)
    fig.update_layout(
        xaxis=dict(tickangle=-90),
        yaxis=dict(type='log'),
    )
    fig.show()

def plot_feature_acts(encoder_out, title):
    """Bar-plot every (feature index, activation) pair held by `encoder_out`.

    Feature indices are cast to strings so plotly treats them as categories instead of
    spreading them along a numeric axis.
    """
    feature_acts_df = pd.DataFrame({
        'index': encoder_out.top_indices.cpu().numpy().flatten(),
        'act': encoder_out.top_acts.cpu().numpy().flatten(),
    }).astype({'index': str})

    plot_simple_bar(
        feature_acts_df,
        x='index',
        y='act',
        title=title,
        labels={'index': 'Feature Index', 'act': 'Activation'},
    )

def plot_feature_freqs(feature_freqs, title):
    """Bar-plot how often each feature occurs, as returned by the frequency ranking helper.

    Expects `feature_freqs.feature_indices` and `feature_freqs.freqs` tensors; indices are
    cast to strings so they are plotted as categories.
    """
    freqs_df = pd.DataFrame({
        'index': feature_freqs.feature_indices.cpu().numpy().flatten(),
        'freq': feature_freqs.freqs.cpu().numpy().flatten(),
    }).astype({'index': str})

    plot_simple_bar(
        freqs_df,
        x='index',
        y='freq',
        title=title,
        labels={'index': 'Feature Index', 'freq': 'Frequency'},
    )

def plot_acts_histogram(encoder_out, title, nbins=50):
    """Histogram of every value in `encoder_out.top_acts`, with a log-scale count axis.

    Args:
        encoder_out: object exposing a `top_acts` tensor (e.g. the output of `sae.encode`).
        title: figure title.
        nbins: number of bins requested from plotly.
    """
    acts_hist_fig = px.histogram(
        x=encoder_out.top_acts.cpu().numpy().flatten(),
        nbins=nbins,
        title=title,
        labels={'x': 'Activation'},
    )
    # The bar heights are computed counts, not a 'y' column, so a labels={'y': ...} entry never
    # reaches the axis (it would still read "count"); set the axis title directly instead.
    acts_hist_fig.update_yaxes(type='log', title_text='Frequency')
    acts_hist_fig.show()

def plot_acts_and_tokens_relation_scatter(token_acts_relation, title):
    """Scatter each activating token id against its activation for a single feature.

    `token_acts_relation` must expose `token_ids` and `top_acts` tensors of equal length.
    """
    relation_df = pd.DataFrame({
        'token_id': token_acts_relation.token_ids.cpu().numpy(),
        'top_acts': token_acts_relation.top_acts.cpu().numpy(),
    })
    fig = px.scatter(
        relation_df,
        x='token_id',
        y='top_acts',
        title=title,
        labels={'top_acts': 'Top Activations', 'token_id': 'Token ID'},
    )
    fig.update_traces(marker={'size': 12, 'opacity': 0.7})
    fig.show()

def plot_acts_and_tokens_relation_bar(token_acts_relation, title):
    """Horizontal bar chart of one feature's activation per token, with a log-scale x axis.

    Each bar is labelled '<decoded text>:<token id>'. A token that decodes to a bare space
    or a bare colon is shown as '<space>' / '<colon>' instead of its raw text.
    The figure height grows by 15px per bar, with a 400px minimum.
    """
    placeholder_for = {' ': '<space>', ':': '<colon>'}
    decoded_texts = decode_inference(token_acts_relation.token_ids)
    token_ids = token_acts_relation.token_ids.tolist()
    bar_labels = [
        f'{placeholder_for.get(text, text)}:{tid}'
        for text, tid in zip(decoded_texts, token_ids)
    ]

    token_act_df = pd.DataFrame({
        'token': bar_labels,
        'top_acts': token_acts_relation.top_acts.cpu().numpy(),
    })

    fig = px.bar(
        token_act_df,
        x='top_acts',
        y='token',
        orientation='h',
        title=title,
        labels={'top_acts': 'Top Activation', 'token': 'Token'},
    )
    fig.update_layout(
        yaxis=dict(
            categoryorder='total ascending',
            tickmode='array',
            tickvals=token_act_df['token'],
            ticktext=token_act_df['token'],
        ),
        xaxis=dict(type='log'),
        margin=dict(l=200),
        font=dict(size=10),
        height=max(400, 15 * len(token_act_df)),
    )
    fig.show()

In [504]:
# Strongest activations, prompt positions included.
plot_feature_acts(top_k_all_acts_encoder_out, title=f'Top {TOP_K_ENCODER_OUT_BY_ACTS} Feature Activations (Including activations from input tokens)')

In [496]:
# Strongest activations over generated positions only.
plot_feature_acts(top_k_generated_only_acts_encoder_out, title=f'Top {TOP_K_ENCODER_OUT_BY_ACTS} Feature Activations (Only activation from generated tokens)')

In [500]:
# Most frequently active features, prompt positions included.
plot_feature_freqs(top_k_all_acts_feature_freq, title=f'Top {TOP_K_ENCODER_OUT_FEATURE_FREQ} Feature Frequencies (Including features from input tokens)')

In [501]:
# Most frequently active features over generated positions only.
plot_feature_freqs(top_k_generated_only_acts_feature_freq, title=f'Top {TOP_K_ENCODER_OUT_FEATURE_FREQ} Feature Frequencies (Only features from generated tokens)')

In [480]:
# Distribution of all active-latent values, prompt positions included.
plot_acts_histogram(all_acts_encoder_out, title='Activations Frequency (Including activations from input tokens)')

In [481]:
# Same distribution for generated positions only, with coarser bins.
plot_acts_histogram(generated_only_acts_encoder_out, title='Activations Frequency (Only activation from generated tokens)', nbins=20)

In [482]:
# Feature to inspect: 1038 shows up among the active latents of the first rows and of the
# last generated rows in the encoder output printed above.
prominent_feature_idx = 1038
# Number of tokens shown in the bar plot.
TOKEN_ACTS_RELATION_TOP_K = 70

# Row i of all_acts_encoder_out is sequence position i, so it is paired with the full sequence;
# the final generated token has no encoder row and is never looked up.
all_acts_token_acts_relation = relate_feature_acts_to_tokens(
    all_acts_encoder_out,
    feature_idx=prominent_feature_idx,
    sample_tokens=model_outs_for_sae.sequences.squeeze(),
    no_token_repetition=True
)

all_acts_token_acts_relation_top_k = relate_feature_acts_to_tokens(
    all_acts_encoder_out,
    feature_idx=prominent_feature_idx,
    sample_tokens=model_outs_for_sae.sequences.squeeze(),
    top_k=TOKEN_ACTS_RELATION_TOP_K,
    no_token_repetition=True
)

In [483]:
# Every token id the feature fires on (max activation per token id).
plot_acts_and_tokens_relation_scatter(
    all_acts_token_acts_relation,
    title=f'Activating Tokens for Feature with index {prominent_feature_idx}'
)

In [531]:
# Top tokens for the feature, labelled with their decoded text.
plot_acts_and_tokens_relation_bar(
    all_acts_token_acts_relation_top_k,
    title=f'Top {TOKEN_ACTS_RELATION_TOP_K} Activating Tokens/Token ID for Feature with index {prominent_feature_idx}'
)

In [455]:
# generated_only_acts starts at hidden_states[1], i.e. at the first generated position, so the
# matching tokens are the sequence with the prompt sliced off.
model_out_generated_only_token_ids = model_outs_for_sae.sequences.squeeze()[input_prompt_tokens_len:]

generated_only_token_acts_relation = relate_feature_acts_to_tokens(
    generated_only_acts_encoder_out,
    feature_idx=prominent_feature_idx,
    sample_tokens=model_out_generated_only_token_ids,
    no_token_repetition=True
)

generated_only_token_acts_relation_top_k = relate_feature_acts_to_tokens(
    generated_only_acts_encoder_out,
    feature_idx=prominent_feature_idx,
    sample_tokens=model_out_generated_only_token_ids,
    top_k=TOKEN_ACTS_RELATION_TOP_K,
    no_token_repetition=True
)

In [461]:
# Token ids the feature fires on, generated positions only.
plot_acts_and_tokens_relation_scatter(
    generated_only_token_acts_relation,
    title=f'Activating Tokens for Feature with index {prominent_feature_idx} (Generated only tokens)'
)

In [462]:
# Top tokens for the feature, generated positions only.
plot_acts_and_tokens_relation_bar(
    generated_only_token_acts_relation_top_k,
    title=f'Top {TOKEN_ACTS_RELATION_TOP_K} Activating Tokens/Token ID for Feature with index {prominent_feature_idx} (Generated only tokens)'
)