In [1]:
try:
    import google.colab # type: ignore
    from google.colab import output
    COLAB = True
    %pip install sae-lens transformer-lens
except:
    COLAB = False
    from IPython import get_ipython # type: ignore
    ipython = get_ipython(); assert ipython is not None
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

# Standard imports
import os
import torch
import numpy as np
from tqdm import tqdm
import plotly.express as px
import pandas as pd
import einops
from jaxtyping import Float, Int
from torch import Tensor

torch.set_grad_enabled(False)

# Device setup
GPU_TO_USE = 7

if torch.backends.mps.is_available():
    device = "mps"
else:
    device = f"cuda:{GPU_TO_USE}" if torch.cuda.is_available() else "cpu"

print(f"Device: {device}")

# utility to clear variables out of the memory & and clearing cuda cache
import gc
def clear_cache():
    gc.collect()
    torch.cuda.empty_cache()

Device: cuda:7


In [2]:
from pathlib import Path

def get_data_path(data_folder, in_colab=COLAB):
  if in_colab:
    from google.colab import drive
    drive.mount('/content/drive')

    return Path(f'/content/drive/MyDrive/{data_folder}')
  else:
    return Path(f'./{data_folder}')

In [3]:
datapath = get_data_path('./data')
datapath

PosixPath('data')

In [4]:
import sys
import os

# Add the parent directory (sfc_deception) to sys.path
sys.path.append(os.path.abspath(os.path.join('..')))

## Loading the model

In [5]:
from sae_lens import SAE, HookedSAETransformer, ActivationsStore

USE_INSTRUCT = False
PARAMS_COUNT = 2
MODEL_NAME = f'gemma-2-{PARAMS_COUNT}b' + ('-it' if USE_INSTRUCT else '')
print(f'Using {MODEL_NAME}')


model = HookedSAETransformer.from_pretrained(MODEL_NAME, device=device, dtype=torch.bfloat16)
model

  from .autonotebook import tqdm as notebook_tqdm


Using gemma-2-2b


Loading checkpoint shards: 100%|██████████| 3/3 [00:01<00:00,  2.48it/s]


Loaded pretrained model gemma-2-2b into HookedTransformer


HookedSAETransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (blocks): ModuleList(
    (0-25): 26 x TransformerBlock(
      (ln1): RMSNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln1_post): RMSNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): RMSNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2_post): RMSNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): GroupedQueryAttention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
        (hook_rot_k): HookPoint()
        (hook_rot_q): HookPoint()
      )
      (mlp): GatedMLP(
        (hook_pre): HookPoint()
        (hook_pre_linear): HookPoint()
      

## Loading the data

In [6]:
from classes.sfc_data_loader import SFCDatasetLoader
import utils.prompts as prompts
from utils.enums import *

In [7]:
DATASET_NAME = SupportedDatasets.VERB_AGREEMENT

if DATASET_NAME in [SupportedDatasets.CITIES, SupportedDatasets.FACTS, SupportedDatasets.COMPANIES]:
    dataloader = SFCDatasetLoader(DATASET_NAME, model, 
                                  clean_system_prompt=prompts.TRUTH_OR_USER_KILLED,
                                  corrupted_system_prompt=prompts.LIE_OR_USER_KILLED,
                                  task_prompt=prompts.ANSWER_TRUE_FALSE,
                                  local_dataset=True, base_folder_path=datapath,
                                  )
elif DATASET_NAME in [SupportedDatasets.COMMONSENSE_QA, SupportedDatasets.COMMONSENSE_QA_FILTERED]:
    dataloader = SFCDatasetLoader(DATASET_NAME, model, 
                                clean_system_prompt=prompts.TRUTH_OR_USER_KILLED,
                                corrupted_system_prompt=prompts.LIE_OR_USER_KILLED,
                                task_prompt=prompts.OUTPUT_SINGLE_LETTER)
elif DATASET_NAME in [SupportedDatasets.VERB_AGREEMENT]:
    dataloader = SFCDatasetLoader(DATASET_NAME, model,
                                  local_dataset=True, base_folder_path=datapath)
else:
    raise ValueError(f"Dataset {DATASET_NAME.value} is not supported")



In [8]:
if DATASET_NAME in [SupportedDatasets.VERB_AGREEMENT]:
    clean_dataset, corrupted_dataset = dataloader.get_clean_corrupted_datasets(tokenize=True, apply_chat_template=False, prepend_generation_prefix=True)
else:
    clean_dataset, corrupted_dataset = dataloader.get_clean_corrupted_datasets(tokenize=True, apply_chat_template=True, prepend_generation_prefix=True)

Figuring out optimal padding length...
Filtered out 484 longest prompts from a total of 10560 prompts.
Setting max prompt length to 8


100%|██████████| 10076/10076 [00:15<00:00, 669.37it/s]


In [9]:
CONTROL_SEQ_LEN = clean_dataset['control_sequence_length'][0].item()
N_CONTEXT = clean_dataset['prompt'].shape[1]

CONTROL_SEQ_LEN, N_CONTEXT

(2, 8)

In [10]:
print('Clean dataset:')
for prompt in clean_dataset['prompt'][:3]:
  print("\nPrompt:", model.to_string(prompt), end='\n\n')

  for i, tok in enumerate(prompt):
    str_token = model.to_string(tok)
    print(f"({i-CONTROL_SEQ_LEN}, {str_token})", end=' ')
  print()

print('Corrupted dataset:')
for prompt in corrupted_dataset['prompt'][:3]:
  print("\nPrompt:", model.to_string(prompt), end='\n\n')
  
  for i, tok in enumerate(prompt):
    str_token = model.to_string(tok)
    print(f"({i-CONTROL_SEQ_LEN}, {str_token})", end=' ')
  print()

Clean dataset:

Prompt: <bos>The doctors that the executives like<pad>

(-2, <bos>) (-1, The) (0,  doctors) (1,  that) (2,  the) (3,  executives) (4,  like) (5, <pad>) 

Prompt: <bos>The fathers that the driver visits<pad>

(-2, <bos>) (-1, The) (0,  fathers) (1,  that) (2,  the) (3,  driver) (4,  visits) (5, <pad>) 

Prompt: <bos>The boys that the parents inform<pad>

(-2, <bos>) (-1, The) (0,  boys) (1,  that) (2,  the) (3,  parents) (4,  inform) (5, <pad>) 
Corrupted dataset:

Prompt: <bos>The doctor that the executives like<pad>

(-2, <bos>) (-1, The) (0,  doctor) (1,  that) (2,  the) (3,  executives) (4,  like) (5, <pad>) 

Prompt: <bos>The father that the driver visits<pad>

(-2, <bos>) (-1, The) (0,  father) (1,  that) (2,  the) (3,  driver) (4,  visits) (5, <pad>) 

Prompt: <bos>The boy that the parents inform<pad>

(-2, <bos>) (-1, The) (0,  boy) (1,  that) (2,  the) (3,  parents) (4,  inform) (5, <pad>) 


In [11]:
# Sanity checks

# Control sequence length must be the same for all samples in both datasets
clean_ds_control_len = clean_dataset['control_sequence_length']
corrupted_ds_control_len = corrupted_dataset['control_sequence_length']

assert torch.all(corrupted_ds_control_len == corrupted_ds_control_len[0]), "Control sequence length is not the same for all samples in the dataset"
assert torch.all(clean_ds_control_len == clean_ds_control_len[0]), "Control sequence length is not the same for all samples in the dataset"
assert clean_ds_control_len[0] == corrupted_ds_control_len[0], "Control sequence length is not the same for clean and corrupted samples in the dataset"
assert clean_dataset['answer'].max().item() < model.cfg.d_vocab, "Clean answers exceed vocab size"
assert corrupted_dataset['answer'].max().item() < model.cfg.d_vocab, "Patched answers exceed vocab size"
assert (clean_dataset['answer_pos'] < N_CONTEXT).all().item(), "Answer positions exceed logits length"
assert (corrupted_dataset['answer_pos'] < N_CONTEXT).all().item(), "Answer positions exceed logits length"

In [12]:
def sample_dataset(start_idx=0, end_idx=-1, clean_dataset=None, corrupted_dataset=None):
    assert clean_dataset is not None or corrupted_dataset is not None, 'At least one dataset must be provided.'
    return_values = []

    for key in ['prompt', 'answer', 'answer_pos', 'attention_mask']:
        if clean_dataset is not None:
            return_values.append(clean_dataset[key][start_idx:end_idx])
        if corrupted_dataset is not None:
            return_values.append(corrupted_dataset[key][start_idx:end_idx])

    return return_values

## Loading the SAEs

In [13]:
from classes.sfc_model import *

RUN_WITH_SAES = True

if RUN_WITH_SAES:
    caching_device = device 
else:
    caching_device = "cuda:6"


print('Running the model' + (' with SAEs' if RUN_WITH_SAES else '') + '...')
print(f'Device for SAEs: {caching_device}')

Running the model with SAEs...
Device for SAEs: cuda:7


In [14]:
clear_cache()

sfc_model = SFC_Gemma(model, params_count=PARAMS_COUNT, control_seq_len=CONTROL_SEQ_LEN, 
                      attach_saes=RUN_WITH_SAES, caching_device=caching_device)
sfc_model.print_saes()

clear_cache()

# sfc_model.model.cfg
# , sfc_model.saes[0].cfg.dtype

Using 16K SAEs for the first 26 layers, the rest 0 layer(s) - 131k SAEs
Number of SAEs: 78
blocks.0.hook_resid_post SAE(
  (activation_fn): ReLU()
  (hook_sae_input): HookPoint()
  (hook_sae_acts_pre): HookPoint()
  (hook_sae_acts_post): HookPoint()
  (hook_sae_output): HookPoint()
  (hook_sae_recons): HookPoint()
  (hook_sae_error): HookPoint()
)
blocks.1.hook_resid_post SAE(
  (activation_fn): ReLU()
  (hook_sae_input): HookPoint()
  (hook_sae_acts_pre): HookPoint()
  (hook_sae_acts_post): HookPoint()
  (hook_sae_output): HookPoint()
  (hook_sae_recons): HookPoint()
  (hook_sae_error): HookPoint()
)
blocks.2.hook_resid_post SAE(
  (activation_fn): ReLU()
  (hook_sae_input): HookPoint()
  (hook_sae_acts_pre): HookPoint()
  (hook_sae_acts_post): HookPoint()
  (hook_sae_output): HookPoint()
  (hook_sae_recons): HookPoint()
  (hook_sae_error): HookPoint()
)
blocks.3.hook_resid_post SAE(
  (activation_fn): ReLU()
  (hook_sae_input): HookPoint()
  (hook_sae_acts_pre): HookPoint()
  (hook_s

# Getting the SFC nodes

In [15]:
import pickle

def load_dict(nodes_prefix='rc_dataset', aggregated_tokens=True):
    aggregation_type = 'All_tokens' if aggregated_tokens else 'None'

    if nodes_prefix:
        filename = f'{aggregation_type}_agg_{nodes_prefix}_scores.pkl'
    else:
        filename = f'{aggregation_type}_agg_scores.pkl'

    print(f'Loading {filename}...')
    filename = datapath / filename

    with open(filename, 'rb') as f:
        data_dict = pickle.load(f)

    return data_dict

nodes_dict = load_dict(aggregated_tokens=False)
nodes_dict.keys()

Loading None_agg_rc_dataset_scores.pkl...


dict_keys(['blocks.0.attn.hook_z.hook_sae_error', 'blocks.0.attn.hook_z.hook_sae_acts_post', 'blocks.0.hook_mlp_out.hook_sae_error', 'blocks.0.hook_mlp_out.hook_sae_acts_post', 'blocks.0.hook_resid_post.hook_sae_error', 'blocks.0.hook_resid_post.hook_sae_acts_post', 'blocks.1.attn.hook_z.hook_sae_error', 'blocks.1.attn.hook_z.hook_sae_acts_post', 'blocks.1.hook_mlp_out.hook_sae_error', 'blocks.1.hook_mlp_out.hook_sae_acts_post', 'blocks.1.hook_resid_post.hook_sae_error', 'blocks.1.hook_resid_post.hook_sae_acts_post', 'blocks.2.attn.hook_z.hook_sae_error', 'blocks.2.attn.hook_z.hook_sae_acts_post', 'blocks.2.hook_mlp_out.hook_sae_error', 'blocks.2.hook_mlp_out.hook_sae_acts_post', 'blocks.2.hook_resid_post.hook_sae_error', 'blocks.2.hook_resid_post.hook_sae_acts_post', 'blocks.3.attn.hook_z.hook_sae_error', 'blocks.3.attn.hook_z.hook_sae_acts_post', 'blocks.3.hook_mlp_out.hook_sae_error', 'blocks.3.hook_mlp_out.hook_sae_acts_post', 'blocks.3.hook_resid_post.hook_sae_error', 'blocks.3.ho

In [16]:
nodes_dict['blocks.0.attn.hook_z.hook_sae_acts_post'].shape, nodes_dict['blocks.0.attn.hook_z.hook_sae_error'].shape

(torch.Size([8, 16384]), torch.Size([8]))

In [17]:
def filter_sae_importance(data_dict, threshold):
    """
    Filters the input dictionary, keeping only the entries where at least one value exceeds the threshold.
    Also returns binary masks indicating which elements exceed the threshold.

    Parameters:
    - data_dict (dict): A dictionary where keys are activation names and values are torch tensors.
    - threshold (float): The threshold for filtering.

    Returns:
    - dict: A filtered dictionary containing only the keys where torch.any(tensor > threshold) is True.
    - dict: A dictionary of binary masks indicating which elements exceed the threshold.
    """
    filtered_dict = {}
    masks_dict = {}
    
    for key, tensor in data_dict.items():
        mask = tensor > threshold
        if torch.any(mask):
            filtered_dict[key] = tensor
            masks_dict[key] = mask
    
    return filtered_dict, masks_dict


filtered_nodes, nodes_mask = filter_sae_importance(nodes_dict, 0.005)
filtered_nodes.keys(), nodes_mask.keys()

(dict_keys(['blocks.5.hook_resid_post.hook_sae_acts_post', 'blocks.14.hook_resid_post.hook_sae_acts_post', 'blocks.15.hook_resid_post.hook_sae_error', 'blocks.15.hook_resid_post.hook_sae_acts_post', 'blocks.16.hook_resid_post.hook_sae_acts_post', 'blocks.17.hook_mlp_out.hook_sae_acts_post', 'blocks.17.hook_resid_post.hook_sae_error', 'blocks.18.hook_resid_post.hook_sae_error', 'blocks.18.hook_resid_post.hook_sae_acts_post', 'blocks.19.attn.hook_z.hook_sae_acts_post', 'blocks.19.hook_mlp_out.hook_sae_acts_post', 'blocks.19.hook_resid_post.hook_sae_error', 'blocks.19.hook_resid_post.hook_sae_acts_post', 'blocks.20.hook_resid_post.hook_sae_acts_post', 'blocks.21.hook_mlp_out.hook_sae_acts_post', 'blocks.21.hook_resid_post.hook_sae_acts_post', 'blocks.22.hook_mlp_out.hook_sae_acts_post', 'blocks.22.hook_resid_post.hook_sae_acts_post', 'blocks.23.hook_resid_post.hook_sae_acts_post', 'blocks.24.hook_resid_post.hook_sae_acts_post', 'blocks.25.hook_resid_post.hook_sae_acts_post']),
 dict_keys(

In [38]:
def act_name_to_hook_name(act_name, return_node_type=False):
    """
    Extracts the hook name from an SAE activation name.
    For example:
        'blocks.0.attn.hook_z.hook_sae_acts_post' -> 'hook_z'
        'blocks.19.hook_mlp_out.hook_sae_error' -> 'hook_mlp_out'
        'blocks.18.hook_resid_post.hook_sae_acts_post' -> 'hook_resid_post'
    """
    parts = act_name.split('.')
    if parts[-1] not in ['hook_sae_acts_post', 'hook_sae_error']:
        raise ValueError("Activation name is not an SAE activation name.")
    if not parts[-2].startswith('hook_'):
        raise ValueError("Activation name does not follow the expected format.")

    return parts[-2] if not return_node_type else (parts[-2], parts[-1])

def act_name_to_layer_number(act_name):
    # Split the input string by periods
    parts = act_name.split('.')
    
    # Validate that the string has the correct format
    if len(parts) < 3 or parts[0] != 'blocks':
        raise ValueError(f"Input string must start with 'blocks.<index>.'. Got: {act_name}")
    
    # Extract and return the block number as an integer
    return int(parts[1])


def get_adjacent_nodes(filtered_nodes, nodes_mask):
    """
    Returns a dictionary where keys are SFC parent nodes and values are their children nodes, i.e. the nodes that are directly upstream w.r.t.
    the parent nodes. More precisely, the format is:
    {
        'parent_node_name': [
            ('child_node_1', None),
            ('child_node_2', None),
            ...
        ],
        ...
    }
    where None value is to be filled with the gradient of the child node w.r.t. the parent node.
    Additionally, returns the metadata of the adjacent nodes.
    """
    def is_node_adjacent(parent_name, parent_layer, child_name, child_layer):
        """
        Checks if two hooks can be adjacent based on the hook name.
        """
        if parent_name == 'hook_resid_post':
            if child_name == 'hook_resid_post':
                return parent_layer - 1 == child_layer
            else:
                return parent_layer == child_layer

        if parent_name == 'hook_mlp_out':
            if child_name == 'hook_resid_post':
                return parent_layer - 1 == child_layer
            elif child_name == 'hook_z':
                return parent_layer == child_layer
            else:
                return False

        if parent_name == 'hook_z':
            return child_name == 'hook_resid_post' and parent_layer - 1 == child_layer
        
        raise ValueError(f"Hook name is not recognized: {parent_name}")
    
    
    adjacent_nodes = {}
    adjacent_nodes_metadata = {}

    for parent_act in filtered_nodes.keys():
        parent_hook_name, parent_node_type = act_name_to_hook_name(parent_act, return_node_type=True)
        parent_layer_num = act_name_to_layer_number(parent_act)

        for child_act in filtered_nodes.keys():
            child_hook_name, child_node_type = act_name_to_hook_name(child_act, return_node_type=True)
            child_layer_num = act_name_to_layer_number(child_act)

            if is_node_adjacent(parent_hook_name, parent_layer_num, child_hook_name, child_layer_num):
                if parent_act not in adjacent_nodes:
                    adjacent_nodes[parent_act] = []
                    adjacent_nodes_metadata[parent_act] = []

                adjacent_nodes[parent_act].append((child_act, None)) # initialize the gradient of the child node to None
                adjacent_nodes_metadata[parent_act].append({
                    'child_name': child_act,
                    'parent_hook': parent_hook_name,
                    'child_hook': child_hook_name,
                    'parent_layer': parent_layer_num,
                    'child_layer': child_layer_num,
                    'is_parent_error': parent_node_type == 'hook_sae_error',
                    'is_child_error': child_node_type == 'hook_sae_error',
                    'parent_threshold_mask': nodes_mask[parent_act],
                    'child_threshold_mask': nodes_mask[child_act],
                })
            
    return adjacent_nodes, adjacent_nodes_metadata

def print_nodes_metadata(nodes_metadata):
    for parent_act, metadata_list in nodes_metadata.items():
        print(f"\nParent activation: {parent_act}")
        print(f"  Parent hook: {metadata_list[0]['parent_hook']}")
        print(f"  Parent layer: {metadata_list[0]['parent_layer']}")
        print(f"  Is parent error: {metadata_list[0]['is_parent_error']}")
        print(f"  Parent threshold mask shape: {metadata_list[0]['parent_threshold_mask'].shape}")
        print('-'*47)

        for metadata in metadata_list:
            print(f"  Child activation: {metadata['child_name']}")
            print(f"  Child hook: {metadata['child_hook']}")
            print(f"  Child layer: {metadata['child_layer']}")
            print(f"  Is child error: {metadata['is_child_error']}")
            print(f"  Child threshold mask shape: {metadata['child_threshold_mask'].shape}")
            print('-'*47)

In [19]:
print(filtered_nodes.keys())

dict_keys(['blocks.5.hook_resid_post.hook_sae_acts_post', 'blocks.14.hook_resid_post.hook_sae_acts_post', 'blocks.15.hook_resid_post.hook_sae_error', 'blocks.15.hook_resid_post.hook_sae_acts_post', 'blocks.16.hook_resid_post.hook_sae_acts_post', 'blocks.17.hook_mlp_out.hook_sae_acts_post', 'blocks.17.hook_resid_post.hook_sae_error', 'blocks.18.hook_resid_post.hook_sae_error', 'blocks.18.hook_resid_post.hook_sae_acts_post', 'blocks.19.attn.hook_z.hook_sae_acts_post', 'blocks.19.hook_mlp_out.hook_sae_acts_post', 'blocks.19.hook_resid_post.hook_sae_error', 'blocks.19.hook_resid_post.hook_sae_acts_post', 'blocks.20.hook_resid_post.hook_sae_acts_post', 'blocks.21.hook_mlp_out.hook_sae_acts_post', 'blocks.21.hook_resid_post.hook_sae_acts_post', 'blocks.22.hook_mlp_out.hook_sae_acts_post', 'blocks.22.hook_resid_post.hook_sae_acts_post', 'blocks.23.hook_resid_post.hook_sae_acts_post', 'blocks.24.hook_resid_post.hook_sae_acts_post', 'blocks.25.hook_resid_post.hook_sae_acts_post'])


In [36]:
adjacent_nodes, adjacent_nodes_metadata = get_adjacent_nodes(filtered_nodes, nodes_mask)
len(adjacent_nodes.keys()), len(filtered_nodes.keys()), len(adjacent_nodes_metadata.keys())

(19, 21, 19)

In [21]:
d_name = 'blocks.19.attn.hook_z.hook_sae_acts_post'
d = filtered_nodes[d_name]

top_d, top_d_id = d.topk(10, dim=-1)
top_d[6], top_d_id[6]

(tensor([0.0086, 0.0041, 0.0023, 0.0019, 0.0006, 0.0003, 0.0003, 0.0003, 0.0003,
         0.0002], device='cuda:6', dtype=torch.bfloat16),
 tensor([ 3716, 10073, 10884, 10500,  6714, 15690,  5795, 10943, 13379,  8022],
        device='cuda:6'))

In [22]:
u_name = 'blocks.18.hook_resid_post.hook_sae_acts_post'
u = filtered_nodes[u_name]

top_u, top_u_id = u.topk(10, dim=-1)
top_u[6], top_u_id[6]

(tensor([0.0129, 0.0017, 0.0009, 0.0006, 0.0005, 0.0005, 0.0004, 0.0004, 0.0004,
         0.0004], device='cuda:6', dtype=torch.bfloat16),
 tensor([10665,  4442, 11484,  9777,  9525,  1271,  1495,  2343, 11381, 15387],
        device='cuda:6'))

In [23]:
d_feature = 3716
u_feature = 10665

In [24]:
u_hook_name = act_name_to_hook_name(u_name)
d_hook_name = act_name_to_hook_name(d_name)

d_layer = act_name_to_layer_number(d_name)
u_layer = act_name_to_layer_number(u_name)

d_hook_name, d_layer, u_hook_name, u_layer

('hook_z', 19, 'hook_resid_post', 18)

# Studying the toy example

In [23]:
clean_prompts, corrupted_prompts, clean_answers, corrupted_answers, clean_answers_pos, corrupted_answers_pos, \
    clean_attn_mask, corrupted_attn_mask = sample_dataset(0, 1, clean_dataset=clean_dataset, corrupted_dataset=corrupted_dataset)

clean_prompts.shape

torch.Size([1, 8])

In [24]:
clear_cache()
sfc_model.model.reset_hooks()
if sfc_model.are_saes_attached():
    print('Resetting SAE hooks...')
    sfc_model._reset_sae_hooks()

Resetting SAE hooks...


In [25]:
# Storing the forward cache for the selected nodes
fwd_cache = {}
fwd_cache_filter = lambda name: name in filtered_nodes.keys()

def forward_cache_hook(act, hook):
    fwd_cache[hook.name] = act

model.add_hook(fwd_cache_filter, forward_cache_hook, "fwd")

In [334]:
u_name

'blocks.18.hook_resid_post.hook_sae_acts_post'

In [335]:
# Attaching the backward hook to compute edges only for the selected nodes
bwd_cache = {}
bwd_cache_filter = lambda name: 'resid' in name or 'hook_z' in name or 'hook_mlp_out' in name

def backward_cache_hook(grad, hook):
    print(hook.name)

    # If the node is at the source node SAE site, store the gradients and stop the propagation
    if u_name == hook.name:
        print(f'Storing...')

        bwd_cache[hook.name] = grad

        raise StopIteration()

    # If the node is at the target node SAE site, let it flow to its destination (SAE latent/error tern)
    if u_hook_name in hook.name and hook.layer() == u_layer:
        print(f'Letting the gradients flow...')
        return (grad,)

    # If the node is the starting node, let the gradients flow
    if d_hook_name in hook.name and hook.layer() == d_layer:
        bwd_cache[hook.name] = grad
        
        print(f'Letting the gradients flow...')
        return (grad,)
    
    if 'resid' in hook.name:
        bwd_cache[hook.name] = grad
        return (grad,)

    print(f'Stopping the gradients...')
    # If the node is intermediate, stop the propagation of gradients
    return (None,)

model.add_hook(bwd_cache_filter, backward_cache_hook, "bwd")

In [337]:
d_act = fwd_cache[d_name]
d_act.shape, d_act.requires_grad

(torch.Size([1, 8, 16384]), True)

In [338]:
import torch.nn.functional as F

d_act_grad = F.one_hot(torch.tensor([d_feature]), num_classes=d_act.shape[-1]).squeeze().to(device)
d_act_grad = einops.repeat(d_act_grad, 'd -> batch pos d', batch=d_act.shape[0], pos=d_act.shape[1])

d_act_grad.shape

torch.Size([1, 8, 16384])

In [339]:
try:
    d_act.backward(d_act_grad)
except StopIteration:
    print(f'Gradient backpropagation stopped at {u_name}')

bwd_cache[u_name], bwd_cache.keys()

blocks.19.attn.hook_z.hook_sae_acts_pre
Letting the gradients flow...
blocks.19.attn.hook_z.hook_sae_input
Letting the gradients flow...
blocks.19.hook_resid_pre
blocks.18.hook_resid_post.hook_sae_output
Letting the gradients flow...
blocks.18.hook_resid_post.hook_sae_recons
Letting the gradients flow...
blocks.18.hook_resid_post.hook_sae_acts_post
Storing...
Gradient backpropagation stopped at blocks.18.hook_resid_post.hook_sae_acts_post


(tensor([[[ 1.1742e-05,  6.1512e-05, -5.4169e-04,  ..., -1.0443e-04,
           -5.1260e-05,  7.0572e-04],
          [-7.5698e-06, -3.4094e-05, -2.9373e-04,  ..., -5.0735e-04,
           -1.3065e-04,  8.0585e-05],
          [ 1.9646e-04,  1.1635e-04,  9.9182e-04,  ...,  1.6403e-03,
            7.2479e-04,  1.5717e-03],
          ...,
          [ 2.0623e-05, -1.4267e-03, -2.8992e-04,  ...,  1.9646e-04,
            4.7302e-04,  4.1580e-04],
          [-2.1210e-03, -2.9755e-03,  5.2214e-05,  ..., -1.9836e-03,
            2.7847e-04,  1.0529e-03],
          [-4.6921e-04, -1.6632e-03, -4.1008e-04,  ..., -1.7395e-03,
            6.8283e-04, -1.3123e-03]]], device='cuda:7', dtype=torch.bfloat16),
 dict_keys(['blocks.19.attn.hook_z.hook_sae_acts_pre', 'blocks.19.attn.hook_z.hook_sae_input', 'blocks.19.hook_resid_pre', 'blocks.18.hook_resid_post.hook_sae_acts_post']))

In [340]:
d_u_grad = bwd_cache[u_name][0, 6, u_feature]
d_u_grad

tensor(-0.0049, device='cuda:7', dtype=torch.bfloat16)

In [341]:
bwd_cache[u_name][0, 6, :].topk(50)

torch.return_types.topk(
values=tensor([0.0228, 0.0170, 0.0140, 0.0121, 0.0118, 0.0116, 0.0115, 0.0113, 0.0113,
        0.0113, 0.0112, 0.0109, 0.0107, 0.0106, 0.0105, 0.0103, 0.0100, 0.0100,
        0.0099, 0.0098, 0.0097, 0.0096, 0.0096, 0.0095, 0.0094, 0.0093, 0.0093,
        0.0090, 0.0090, 0.0087, 0.0087, 0.0086, 0.0086, 0.0084, 0.0084, 0.0083,
        0.0083, 0.0083, 0.0082, 0.0082, 0.0081, 0.0080, 0.0079, 0.0079, 0.0079,
        0.0079, 0.0079, 0.0078, 0.0077, 0.0077], device='cuda:7',
       dtype=torch.bfloat16),
indices=tensor([ 3085,  6305,  2860, 15786,  4239, 14539,  5696,  9579, 10318, 13789,
        14461,  7016,  5443,  4294,  3901,  5798, 11673, 13851, 11250, 12063,
         2862,  3140,  8197, 12368,  4471,   442,  4358, 11832,  2865,  4751,
         6991,  6543, 16017, 11089,  9645,   401,  8353, 13807,  4968,  8969,
        13532, 11087,   901, 10986,  3242,  8251, 11032, 13041, 11831, 11498],
       device='cuda:7'))

In [364]:
u_name

'blocks.18.hook_resid_post.hook_sae_acts_post'

In [373]:
bwd_cache[u_name][0, 2, :].topk(50)

torch.return_types.topk(
values=tensor([0.0091, 0.0074, 0.0071, 0.0069, 0.0065, 0.0064, 0.0063, 0.0062, 0.0060,
        0.0058, 0.0057, 0.0055, 0.0054, 0.0054, 0.0053, 0.0053, 0.0052, 0.0052,
        0.0050, 0.0049, 0.0048, 0.0048, 0.0048, 0.0046, 0.0046, 0.0046, 0.0045,
        0.0044, 0.0044, 0.0044, 0.0044, 0.0043, 0.0043, 0.0043, 0.0043, 0.0042,
        0.0042, 0.0042, 0.0042, 0.0041, 0.0041, 0.0041, 0.0041, 0.0041, 0.0040,
        0.0040, 0.0040, 0.0040, 0.0040, 0.0040], device='cuda:7',
       dtype=torch.bfloat16),
indices=tensor([ 4621, 10927,  5747,  3369,  1027, 14824,  9671, 10665,  6790,  1882,
         8104,  8129,  2502,  7379,   541, 12123, 14017, 14774,  3751, 11626,
         8554,  1708, 12256,  8437,  5322, 11016, 11543,  1885, 12781,  7679,
         9603, 13836, 15108, 12841,  3358,  1506,  2086,  2650,  5078,   226,
         1524,  4377,  4211,  9294,  6004,  1546,  6217, 14754,   819, 11225],
       device='cuda:7'))

In [377]:
bwd_cache[u_name][0, 2, 15377]

tensor(0.0005, device='cuda:7', dtype=torch.bfloat16)

In [367]:
u_name

'blocks.18.hook_resid_post.hook_sae_acts_post'

In [374]:
filtered_nodes[u_name][2, :].topk(10)

torch.return_types.topk(
values=tensor([4.5166e-03, 4.2419e-03, 1.0910e-03, 1.0529e-03, 2.1553e-04, 1.6689e-04,
        1.5545e-04, 1.1110e-04, 1.0443e-04, 9.3937e-05], device='cuda:6',
       dtype=torch.bfloat16),
indices=tensor([10665,  1506, 15377,  1271,  4973,  1286, 15152,  4856,  5676,  2856],
       device='cuda:6'))

In [366]:
filtered_nodes[d_name][6, :].topk(10)

torch.return_types.topk(
values=tensor([0.0086, 0.0041, 0.0023, 0.0019, 0.0006, 0.0003, 0.0003, 0.0003, 0.0003,
        0.0002], device='cuda:6', dtype=torch.bfloat16),
indices=tensor([ 3716, 10073, 10884, 10500,  6714, 15690,  5795, 10943, 13379,  8022],
       device='cuda:6'))

In [325]:
# resid_post_grad = bwd_cache['blocks.19.hook_resid_post.hook_sae_input'][0, 6, :]
resid_mid_grad = bwd_cache['blocks.19.hook_resid_mid'][0, 6, :]
resid_pre_grad = bwd_cache['blocks.19.hook_resid_pre'][0, 6, :]

# Check if the gradients are the same
# torch.all(resid_post_grad == resid_mid_grad) 
torch.all(resid_mid_grad == resid_pre_grad)


tensor(True, device='cuda:7')

In [288]:
bwd_cache.keys()

dict_keys(['blocks.19.hook_resid_post.hook_sae_acts_pre', 'blocks.19.hook_resid_post.hook_sae_input', 'blocks.19.hook_resid_mid', 'blocks.19.attn.hook_z.hook_sae_acts_post'])

# Scaling edge calculation

In [39]:
print_nodes_metadata(adjacent_nodes_metadata)


Parent activation: blocks.15.hook_resid_post.hook_sae_error
  Parent hook: hook_resid_post
  Parent layer: 15
  Is parent error: True
  Parent threshold mask shape: torch.Size([8])
-----------------------------------------------
  Child activation: blocks.14.hook_resid_post.hook_sae_acts_post
  Child hook: hook_resid_post
  Child layer: 14
  Is child error: False
  Child threshold mask shape: torch.Size([8, 16384])
-----------------------------------------------

Parent activation: blocks.15.hook_resid_post.hook_sae_acts_post
  Parent hook: hook_resid_post
  Parent layer: 15
  Is parent error: False
  Parent threshold mask shape: torch.Size([8, 16384])
-----------------------------------------------
  Child activation: blocks.14.hook_resid_post.hook_sae_acts_post
  Child hook: hook_resid_post
  Child layer: 14
  Is child error: False
  Child threshold mask shape: torch.Size([8, 16384])
-----------------------------------------------

Parent activation: blocks.16.hook_resid_post.hook_s

In [26]:
clean_prompts, corrupted_prompts, clean_answers, corrupted_answers, clean_answers_pos, corrupted_answers_pos, \
    clean_attn_mask, corrupted_attn_mask = sample_dataset(0, 1, clean_dataset=clean_dataset, corrupted_dataset=corrupted_dataset)

clear_cache()

sfc_model.detach_saes_except_few(filtered_nodes.keys())
sfc_model.model.reset_hooks()

if sfc_model.are_saes_attached():
    print('Resetting SAE hooks...')
    sfc_model._reset_sae_hooks()

Detached 57 SAEs from the model.
Discarded the remaining SAEs from memory.
Resetting SAE hooks...


In [27]:
# 1. Storing the forward cache for all the filtered nodes
fwd_cache = {}
fwd_cache_filter = lambda name: name in filtered_nodes.keys()

def forward_cache_hook(act, hook):
    fwd_cache[hook.name] = act

model.add_hook(fwd_cache_filter, forward_cache_hook, "fwd")

In [28]:
# 2. Storing the backward cache for the parent nodes
bwd_cache = {}
bwd_cache_filter = lambda name: name in adjacent_nodes.keys() or 'hook_sae_output' in name or 'hook_sae_input' in name

temp_cache = {}

def backward_cache_hook(gradient, hook):
    if 'hook_sae_output' in hook.name:
        hook_sae_error_name = hook.name.replace('hook_sae_output', 'hook_sae_error')
        if hook_sae_error_name in adjacent_nodes.keys():
            bwd_cache[hook_sae_error_name] = gradient.detach()

        # We're storing the gradients for the SAE output activations to copy them to the SAE input activations gradients
        if not 'hook_z' in hook.name:
            temp_cache[hook.name] = gradient.detach()
        else: # In the case of attention hook_z hooks, reshape them to match the SAE input shape, which doesn't include n_heads
            hook_z_grad = einops.rearrange(gradient.detach(),
                                        'batch pos n_head d_head -> batch pos (n_head d_head)')
            temp_cache[hook.name] = hook_z_grad
    elif 'hook_sae_input' in hook.name:
        # We're copying the gradients from the SAE output activations to the SAE input activations gradients
        sae_output_grad_name = hook.name.replace('hook_sae_input', 'hook_sae_output')

        gradient = temp_cache[sae_output_grad_name]

        # Pass-through: use the downstream gradients
        return (gradient,)
    elif hook.name in adjacent_nodes.keys():
        # Default case: just store the gradients
        bwd_cache[hook.name] = gradient.detach()

model.add_hook(bwd_cache_filter, backward_cache_hook, "bwd")

In [29]:
# Utility function to compute the metric
def get_answer_logit(logits: Float[Tensor, "batch pos d_vocab"], clean_answers: Int[Tensor, "batch"],
                        ansnwer_pos: Int[Tensor, "batch"], return_all_logits=False) -> Float[Tensor, "batch"]:
    # clean_answers_pos_idx = clean_answers_pos.unsqueeze(-1).unsqueeze(-1).expand(-1, logits.size(1), logits.size(2))

    answer_pos_idx = einops.repeat(ansnwer_pos, 'batch -> batch 1 d_vocab',
                                    d_vocab=logits.shape[-1])
    answer_logits = logits.gather(1, answer_pos_idx).squeeze(1) # shape [batch, d_vocab]

    correct_logits = answer_logits.gather(1, clean_answers.unsqueeze(1)).squeeze(1) # shape [batch]

    if return_all_logits:
        return answer_logits, correct_logits

    return correct_logits

def get_logit_diff(logits: Float[Tensor, "batch pos d_vocab"],
                clean_answers: Int[Tensor, "batch"], patched_answers: Int[Tensor, "batch count"],
                answer_pos: Int[Tensor, "batch"], patch_answer_reduce='max') -> Float[Tensor, "batch"]:
    # Compute the logits for the correct answers and the tokens they have been computed at (answer_logits)
    answer_logits, correct_logits = get_answer_logit(logits, clean_answers, answer_pos, return_all_logits=True)

    if patched_answers.dim() == 1:  # If there's only one incorrect answer, gather the incorrect answer logits
        incorrect_logits = answer_logits.gather(1, patched_answers.unsqueeze(1)).squeeze(1)  # shape [batch]
    else:
        incorrect_logits = answer_logits.gather(1, patched_answers)  # shape [batch, answer_count]

    # If there are multiple incorrect answer options, incorrect_logits is now of shape [batch, answer_count]
    if patched_answers.dim() == 2:
        # Sum the logits for each incorrect answer option
        if patch_answer_reduce == 'sum':
            incorrect_logits = incorrect_logits.sum(dim=1)
        # Or take their maximum: this should be a better option to avoid situations where the model outputs gibberish and all the answers have similar logits
        elif patch_answer_reduce == 'max':
            incorrect_logits = incorrect_logits.max(dim=1).values

    # Both logit tensors are now of shape [batch]
    return incorrect_logits - correct_logits

In [30]:
# Performing the forward pass to get the caches
metric_clean = lambda logits: get_logit_diff(logits, clean_answers, corrupted_answers, clean_answers_pos).mean()

# Enable gradients only during the backward pass
with torch.set_grad_enabled(True):
    metric_value = metric_clean(sfc_model.model(clean_prompts, attention_mask=clean_attn_mask))
    metric_value.backward()  # Compute gradients

fwd_cache.keys(), bwd_cache.keys()

(dict_keys(['blocks.5.hook_resid_post.hook_sae_acts_post', 'blocks.14.hook_resid_post.hook_sae_acts_post', 'blocks.15.hook_resid_post.hook_sae_acts_post', 'blocks.15.hook_resid_post.hook_sae_error', 'blocks.16.hook_resid_post.hook_sae_acts_post', 'blocks.17.hook_mlp_out.hook_sae_acts_post', 'blocks.17.hook_resid_post.hook_sae_error', 'blocks.18.hook_resid_post.hook_sae_acts_post', 'blocks.18.hook_resid_post.hook_sae_error', 'blocks.19.attn.hook_z.hook_sae_acts_post', 'blocks.19.hook_mlp_out.hook_sae_acts_post', 'blocks.19.hook_resid_post.hook_sae_acts_post', 'blocks.19.hook_resid_post.hook_sae_error', 'blocks.20.hook_resid_post.hook_sae_acts_post', 'blocks.21.hook_mlp_out.hook_sae_acts_post', 'blocks.21.hook_resid_post.hook_sae_acts_post', 'blocks.22.hook_mlp_out.hook_sae_acts_post', 'blocks.22.hook_resid_post.hook_sae_acts_post', 'blocks.23.hook_resid_post.hook_sae_acts_post', 'blocks.24.hook_resid_post.hook_sae_acts_post', 'blocks.25.hook_resid_post.hook_sae_acts_post']),
 dict_keys(

In [31]:
bwd_cache.keys() == adjacent_nodes.keys(), fwd_cache.keys() == filtered_nodes.keys(), filtered_nodes.keys()

(True,
 True,
 dict_keys(['blocks.5.hook_resid_post.hook_sae_acts_post', 'blocks.14.hook_resid_post.hook_sae_acts_post', 'blocks.15.hook_resid_post.hook_sae_error', 'blocks.15.hook_resid_post.hook_sae_acts_post', 'blocks.16.hook_resid_post.hook_sae_acts_post', 'blocks.17.hook_mlp_out.hook_sae_acts_post', 'blocks.17.hook_resid_post.hook_sae_error', 'blocks.18.hook_resid_post.hook_sae_error', 'blocks.18.hook_resid_post.hook_sae_acts_post', 'blocks.19.attn.hook_z.hook_sae_acts_post', 'blocks.19.hook_mlp_out.hook_sae_acts_post', 'blocks.19.hook_resid_post.hook_sae_error', 'blocks.19.hook_resid_post.hook_sae_acts_post', 'blocks.20.hook_resid_post.hook_sae_acts_post', 'blocks.21.hook_mlp_out.hook_sae_acts_post', 'blocks.21.hook_resid_post.hook_sae_acts_post', 'blocks.22.hook_mlp_out.hook_sae_acts_post', 'blocks.22.hook_resid_post.hook_sae_acts_post', 'blocks.23.hook_resid_post.hook_sae_acts_post', 'blocks.24.hook_resid_post.hook_sae_acts_post', 'blocks.25.hook_resid_post.hook_sae_acts_post']

In [47]:
current_parent = 'blocks.19.hook_resid_post.hook_sae_error'
parent_metadata = adjacent_nodes_metadata[current_parent][0]

parent_hook = parent_metadata['parent_hook']
parent_layer = parent_metadata['parent_layer']
is_parent_error = parent_metadata['is_parent_error']
parent_mask = parent_metadata['parent_threshold_mask']

adjacent_nodes[current_parent], parent_hook, parent_layer, is_parent_error, parent_mask.shape

([('blocks.18.hook_resid_post.hook_sae_error', None),
  ('blocks.18.hook_resid_post.hook_sae_acts_post', None),
  ('blocks.19.attn.hook_z.hook_sae_acts_post', None),
  ('blocks.19.hook_mlp_out.hook_sae_acts_post', None)],
 'hook_resid_post',
 19,
 True,
 torch.Size([8]))

In [None]:
if is_parent_error:
    # If the parent node is an error node, we need only one backward pass to get the gradients w.r.t. all the children nodes
    # for child_name, _ in adjacent_nodes[current_parent]:
    #     print(f"Processing  {child_name}")
    #     child_metadata = [metadata for metadata in adjacent_nodes_metadata[current_parent] if metadata['child_name'] == child_name][0]
    #     print(child_metadata)

In [43]:
if not is_parent_error:
    # If the parent node is a tensor of SAE latents, we'll need to perform a backward pass for each SAE latent above the threshold
    pass