In [1]:
import os 
import gc
import torch
os.chdir("/workspace/CircuitAnalysisSAEs")
from functools import partial
from typing import Callable, Optional, Sequence, Tuple, Union, overload
import einops
import pandas as pd
import torch
from jaxtyping import Float, Int
from tqdm.auto import tqdm
from typing_extensions import Literal
from transformer_lens.utils import Slice, SliceInput
import sys 
import functools
import re
from collections import defaultdict
from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory
import json
from sae_lens import SAE, HookedSAETransformer
from utils import plot
from circ4latents import data_gen
# sys.path.append("../../utils/")
with open("config.json", 'r') as file:
    config = json.load(file)
    token = config.get('huggingface_token', None)
os.environ["HF_TOKEN"] = token
if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

hf_cache = "/workspace/hf_cache"
os.environ["HF_HOME"] = "/workspace/hf_cache"

Device: cuda


In [13]:
!nvidia-smi

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Wed Oct 23 07:37:34 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.90.07              Driver Version: 550.90.07      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100 80GB PCIe          On  |   00000000:81:00.0 Off |                    0 |
| N/A   43C    P0             65W /  300W |   69611MiB /  81920MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                

In [3]:
model = HookedSAETransformer.from_pretrained("google/gemma-2-9b", device=device, cache_dir=hf_cache)
sae, cfg_dict, sparsity = SAE.from_pretrained(release="gemma-scope-9b-pt-res-canonical", sae_id="layer_10/width_16k/canonical", device=device)



Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]



Loaded pretrained model google/gemma-2-9b into HookedTransformer


In [None]:
!nvidia-smi

In [None]:
from transformer_lens.utils import test_prompt
from tasks.error_detection.type.data import generate_samples

selected_templates = [1] #, 2, 3, 4, 5]
N = 50
samples = generate_samples(selected_templates, N)
for sample in samples[0]:
    prompt = sample
    print(prompt)

# Token ID for "Traceback"
traceback_token_id = model.tokenizer.encode("Traceback", add_special_tokens=False)[0]

In [None]:
selected_pos  = {
    "s_start": [],
    "s_end": [],
    "i_start": [],
    "i_end": [],
    "end": []
}

for i in range(N):
    str_tokens_clean = model.to_str_tokens(samples[0][i])
    str_tokens_corr = model.to_str_tokens(samples[1][i])
    # Find the positions with differences
    diff_positions = [i for i, (a, b) in enumerate(zip(str_tokens_clean, str_tokens_corr)) if a != b]

    # Find positions of the first '("', the first '"' after '("', and the end position
    pos_open_paren_quote = str_tokens_clean.index('("')
    pos_first_quote_after_open = pos_open_paren_quote + str_tokens_clean[pos_open_paren_quote:].index('"') 
    pos_end = len(str_tokens_clean) - 1  # The last position

    # Return the positions with differences, and the positions found
    # print(diff_positions, pos_open_paren_quote, pos_first_quote_after_open, pos_end)
    # print(str_tokens_clean[pos_first_quote_after_open])
    selected_pos["s_start"].append(pos_open_paren_quote)
    selected_pos["s_end"].append(pos_first_quote_after_open)
    selected_pos["i_start"].append(diff_positions[0])
    selected_pos["i_end"].append(diff_positions[-1])
    selected_pos["end"].append(pos_end)

selected_pos

In [7]:
# %%

for param in model.parameters():
    param.requires_grad_(False)

# %%

def type_error_patch_metric_prob(logits, end_positions, err1_tok=traceback_token_id):
    probs = logits.softmax(dim=-1)
    err1_logits = probs[range(logits.size(0)), end_positions, :][:, err1_tok]
    return err1_logits.mean()

with torch.no_grad():
    logits = model(samples[0])
clean_diff = type_error_patch_metric_prob(logits, selected_pos['end'])
print(clean_diff)
with torch.no_grad():
    logits = model(samples[1])
corr_diff = type_error_patch_metric_prob(logits, selected_pos['end'])
print(corr_diff)

# %%

def _err_type_metric(logits, clean_logit_diff, corr_logit_diff, end_positions):
    patched_logit_diff = type_error_patch_metric_prob(logits, end_positions)
    return (patched_logit_diff - corr_logit_diff) / (clean_logit_diff - corr_logit_diff)

err_metric_denoising = partial(_err_type_metric, clean_logit_diff=clean_diff, corr_logit_diff=corr_diff, end_positions=selected_pos['end'])


tensor(0.4328, device='cuda:0')
tensor(0.0895, device='cuda:0')


In [8]:
import gc 
del logits
gc.collect()

0

In [10]:
from transformer_lens import ActivationCache, utils
from transformer_lens.hook_points import HookPoint
# from torchtyping import TensorType as TT

def get_cache_fwd_and_bwd(
    model,
    tokens,
    metric,
    sae,
    error_term: bool = True,
    retain_graph: bool = True
):
    # torch.set_grad_enabled(True)
    model.reset_hooks()
    # model.reset_saes()
    cache = {}
    grad_cache = {}
    filter_base_acts = lambda name: "blocks.10.hook_resid_post" in name
    # filter_sae_acts = lambda name: "hook_sae_acts_post" in name

    def forward_cache_hook(act, hook):
        act.requires_grad_(True)
        # act.retain_graph()
        cache[hook.name] = act.detach()

    def backward_cache_hook(grad, hook):
        grad.requires_grad_(True)
        # grad.retain_graph()
        grad_cache[hook.name] = grad.detach()

    # sae.use_error_term = error_term
    # model.add_sae(sae)
    model.add_hook(filter_base_acts, forward_cache_hook, "fwd")
    model.add_hook(filter_base_acts, backward_cache_hook, "bwd")
    value = metric(model(tokens))
    value.backward() #retain_graph=retain_graph)

    model.reset_hooks()
    # model.reset_saes()
    # torch.set_grad_enabled(False)
    return (
        value,
        ActivationCache(cache, model),
        ActivationCache(grad_cache, model),
    )

In [11]:
clean_value, clean_cache, _ = get_cache_fwd_and_bwd(model, samples[0], err_metric_denoising, sae)
print("Clean Value:", clean_value)
print("Clean Activations Cached:", len(clean_cache))
# print("Clean Gradients Cached:", len(clean_grad_cache))

corrupted_value, corrupted_cache, corrupted_grad_cache = get_cache_fwd_and_bwd(model, samples[1], err_metric_denoising, sae)
print("Corrupted Value:", corrupted_value)
print("Corrupted Activations Cached:", len(corrupted_cache))
print("Corrupted Gradients Cached:", len(corrupted_grad_cache))

Clean Value: tensor(1., device='cuda:0', grad_fn=<DivBackward0>)
Clean Activations Cached: 1
Corrupted Value: tensor(0., device='cuda:0', grad_fn=<DivBackward0>)
Corrupted Activations Cached: 1
Corrupted Gradients Cached: 1


In [12]:
sae_acts = sae.encode(clean_cache['blocks.10.hook_resid_post'])
sae_acts_corr = sae.encode(corrupted_cache['blocks.10.hook_resid_post'])
print(sae_acts.shape, sae_acts_corr.shape)

torch.Size([50, 33, 16384]) torch.Size([50, 33, 16384])


In [11]:
sae_grad_cache = torch.einsum('bij,kj->bik', corrupted_grad_cache['blocks.10.hook_resid_post'], sae.W_dec)
print(sae_grad_cache.shape)

torch.Size([50, 33, 16384])


In [14]:
from IPython.display import IFrame
import torch
import einops
import requests
from bs4 import BeautifulSoup
import re
import json

# Function to get HTML for a specific feature
def get_dashboard_html(sae_release="gemma-2-9b", sae_id="10-gemmascope-res-16k", feature_idx=0):
    html_template = "https://neuronpedia.org/{}/{}/{}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"
    return html_template.format(sae_release, sae_id, feature_idx)

# Function to scrape the description for a feature
def scrape_description(layer, feature_idx):
    url = get_dashboard_html(sae_release="gemma-2-2b", sae_id=f"{layer}-gemmascope-res-16k", feature_idx=feature_idx)
    response = requests.get(url)
    
    if response.status_code == 200:
        html = response.text
        soup = BeautifulSoup(html, 'html.parser')
        soup_str = str(soup)

        # Use regex to find the "description" field in the JSON structure
        all_descriptions = re.findall(r'description\\":\\"(.*?)",', soup_str)
        
        if all_descriptions:
            return all_descriptions[-1]  # Return the last description
        else:
            return "No description found."
    else:
        return f"Failed to retrieve the webpage. Status code: {response.status_code}"


In [35]:
top_feats_per_pos = {}
K = 10
for idx, val in selected_pos.items():
    # Get the selected activations and gradients
    clean_residual_selected = sae_acts[torch.arange(sae_acts.shape[0]), val, :]
    corr_residual_selected = sae_acts_corr[torch.arange(sae_acts_corr.shape[0]), val, :]
    corr_grad_residual_selected = sae_grad_cache[torch.arange(sae_grad_cache.shape[0]), val, :]

    # Residual attribution calculation only for the selected positions
    residual_attr_final = einops.reduce(
        corr_grad_residual_selected * (clean_residual_selected - corr_residual_selected),
        "batch n_features -> n_features",
        "sum",
    )

    # Get the top K features based on the absolute values
    abs_residual_attr_final = torch.abs(residual_attr_final)
    top_feats = torch.topk(abs_residual_attr_final, K)
    
    # Retrieve the top indices and the original signed values for these indices
    top_indices = top_feats.indices
    top_values = residual_attr_final[top_indices]  # Use original residual attribution values (with signs)

    # Save the results
    top_feats_per_pos[idx] = (top_indices, top_values)

# %%
top_feats_per_pos


{'s_start': (tensor([7, 6, 4, 5, 1, 0, 2, 3, 8, 9], device='cuda:0'),
  tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], device='cuda:0',
         grad_fn=<IndexBackward0>)),
 's_end': (tensor([7, 6, 4, 5, 1, 0, 2, 3, 8, 9], device='cuda:0'),
  tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], device='cuda:0',
         grad_fn=<IndexBackward0>)),
 'i_start': (tensor([10694,  4718, 15358, 10207,   107,  5233,  6492,   430,  1911,   116],
         device='cuda:0'),
  tensor([ 0.0846,  0.0418,  0.0340,  0.0316, -0.0180,  0.0155,  0.0151, -0.0121,
          -0.0118,  0.0108], device='cuda:0', grad_fn=<IndexBackward0>)),
 'i_end': (tensor([ 2769,  9184, 14504, 12754,  7409,   771, 10182,  9154, 14819,  4482],
         device='cuda:0'),
  tensor([ 0.0794, -0.0248,  0.0098,  0.0093,  0.0083,  0.0081,  0.0079, -0.0075,
           0.0074,  0.0072], device='cuda:0', grad_fn=<IndexBackward0>)),
 'end': (tensor([ 6478,  8503, 13982,  2350,  1721,   632,  7500,  5730, 10406, 16010],
         device=

In [36]:
ttl_latent_attr = 0
for key, val in top_feats_per_pos.items():
    print(val[1])
    ttl_latent_attr += val[1].sum()

ttl_latent_attr
    

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], device='cuda:0',
       grad_fn=<IndexBackward0>)
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], device='cuda:0',
       grad_fn=<IndexBackward0>)
tensor([ 0.0846,  0.0418,  0.0340,  0.0316, -0.0180,  0.0155,  0.0151, -0.0121,
        -0.0118,  0.0108], device='cuda:0', grad_fn=<IndexBackward0>)
tensor([ 0.0794, -0.0248,  0.0098,  0.0093,  0.0083,  0.0081,  0.0079, -0.0075,
         0.0074,  0.0072], device='cuda:0', grad_fn=<IndexBackward0>)
tensor([ 0.0044,  0.0040,  0.0027,  0.0027,  0.0026, -0.0026,  0.0020,  0.0018,
        -0.0015,  0.0015], device='cuda:0', grad_fn=<IndexBackward0>)


tensor(0.3142, device='cuda:0', grad_fn=<AddBackward0>)

In [16]:
residual_attr_final = einops.reduce(
    corrupted_grad_cache['blocks.10.hook_resid_post'] * (clean_cache['blocks.10.hook_resid_post'] - corrupted_cache['blocks.10.hook_resid_post']),
    "batch pos d_model -> pos",
    "sum",
)
residual_attr_final.shape

torch.Size([33])

In [20]:
residual_attr_final.sum()

tensor(0.4154, device='cuda:0')

In [19]:
sae_clean_out = sae.decode(sae_acts)
sae_clean_out.shape

torch.Size([50, 33, 3584])

In [48]:
top_feats_per_pos

{'s_start': (tensor([7, 6, 4, 5, 1, 0, 2, 3, 8, 9], device='cuda:0'),
  tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], device='cuda:0',
         grad_fn=<IndexBackward0>)),
 's_end': (tensor([7, 6, 4, 5, 1, 0, 2, 3, 8, 9], device='cuda:0'),
  tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], device='cuda:0',
         grad_fn=<IndexBackward0>)),
 'i_start': (tensor([10694,  4718, 15358, 10207,   107,  5233,  6492,   430,  1911,   116],
         device='cuda:0'),
  tensor([ 0.0846,  0.0418,  0.0340,  0.0316, -0.0180,  0.0155,  0.0151, -0.0121,
          -0.0118,  0.0108], device='cuda:0', grad_fn=<IndexBackward0>)),
 'i_end': (tensor([ 2769,  9184, 14504, 12754,  7409,   771, 10182,  9154, 14819,  4482],
         device='cuda:0'),
  tensor([ 0.0794, -0.0248,  0.0098,  0.0093,  0.0083,  0.0081,  0.0079, -0.0075,
           0.0074,  0.0072], device='cuda:0', grad_fn=<IndexBackward0>)),
 'end': (tensor([ 6478,  8503, 13982,  2350,  1721,   632,  7500,  5730, 10406, 16010],
         device=

In [54]:

import torch
import einops
import requests
from bs4 import BeautifulSoup
import re
import json

# Function to get HTML for a specific feature
def get_dashboard_html(sae_release="gemma-2-9b", sae_id="10-gemmascope-res-16k", feature_idx=0):
    html_template = "https://neuronpedia.org/{}/{}/{}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"
    return html_template.format(sae_release, sae_id, feature_idx)

# Function to scrape the description for a feature
def scrape_description(layer, feature_idx):
    url = get_dashboard_html(sae_release="gemma-2-9b", sae_id=f"{layer}-gemmascope-res-16k", feature_idx=feature_idx)
    response = requests.get(url)
    
    if response.status_code == 200:
        html = response.text
        soup = BeautifulSoup(html, 'html.parser')
        soup_str = str(soup)

        # Use regex to find the "description" field in the JSON structure
        all_descriptions = re.findall(r'description\\":\\"(.*?)",', soup_str)
        
        if all_descriptions:
            return all_descriptions[-1]  # Return the last description
        else:
            return "No description found."
    else:
        return f"Failed to retrieve the webpage. Status code: {response.status_code}"

# %%
layer = 10
top_10_features_for_rel_pos = {}
interesting_keys = list(top_feats_per_pos.keys())[2:]
# print(interesting_keys)
for key in interesting_keys:
    print(f"Position: {key}")
    indices, values = top_feats_per_pos[key]
    top_10_features_for_rel_pos[key] = []
    for idx, val in zip(indices, values):
        print(f"Feature Index: {idx}, Value: {val}")
        description = scrape_description(layer, idx)
        html_link = get_dashboard_html(sae_release="gemma-2-9b", sae_id="10-gemmascope-res-16k", feature_idx=idx)
        print(description)
        top_10_features_for_rel_pos[key].append((idx.item(), val.item(), description, html_link))

# Save the results to a JSON file
with open('tasks/error_detection/type/out/layer10_top_10_features_for_rel_pos_abs.json', 'w') as json_file:
    json.dump(top_10_features_for_rel_pos, json_file, indent=4)



Position: i_start
Feature Index: 10694, Value: 0.08463253080844879
structured textual formats or sections within documents\
Feature Index: 4718, Value: 0.04182133078575134
instances of quotations or direct speech within the text\
Feature Index: 15358, Value: 0.03400522843003273
references to statistical models or distributions\
Feature Index: 10207, Value: 0.03164256364107132
 elements and syntax related to programming and code structures\
Feature Index: 107, Value: -0.01801796443760395
quoted strings and their formatting details\
Feature Index: 5233, Value: 0.015487908385694027
different types of quotation marks and delimiters used in programming syntax\
Feature Index: 6492, Value: 0.015111686661839485
 coding syntax related to delimiters and string concatenation\
Feature Index: 430, Value: -0.012128639966249466
documentation references and links related to APIs and guides\
Feature Index: 1911, Value: -0.011777937412261963
 symbols and operators related to coding syntax and structure\

In [50]:
get_dashboard_html(sae_release="gemma-2-9b", sae_id="10-gemmascope-res-16k", feature_idx=1000)

'https://neuronpedia.org/gemma-2-9b/10-gemmascope-res-16k/1000?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300'

In [68]:
import matplotlib.pyplot as plt
import numpy as np
import os

def plot_and_save_position_heatmap(data_idx: int, position_name: str, top_features: list, str_tokens_clean: list, str_tokens_corr: list, clean_cache: dict, corr_cache: dict):
    """
    Function to plot and save a heatmap for the top features in a given position.

    Args:
    - data_idx: Index of the specific prompt to visualize.
    - position_name: The name of the position being analyzed (e.g., d1, d2, END).
    - top_features: List of top features for the current position.
    - str_tokens_clean: List of input tokens (strings) from the clean data.
    - str_tokens_corr: List of input tokens (strings) from the corrupted data.
    - clean_cache: Dictionary of activations from the clean cache.
    - corr_cache: Dictionary of activations from the corrupted cache.

    Returns:
    - Saves the heatmap as a PNG file in the 'features' directory.
    """
    activations_matrix = []

    # Iterate over the top features (each feature has two rows: clean and corrupted)
    for feature in top_features:
        feature_idx = feature[0]

        # Get activations from clean and corrupted caches
        clean_activations = clean_cache[data_idx, :, feature_idx].cpu().detach().numpy()
        corr_activations = corr_cache[data_idx, :, feature_idx].cpu().detach().numpy()

        # Append clean and corrupted activations to the matrix
        activations_matrix.append(clean_activations)
        activations_matrix.append(corr_activations)

    # Convert the activations matrix to a numpy array for plotting
    activations_matrix = np.array(activations_matrix)

    # Create a heatmap for the current position
    plt.figure(figsize=(10, 6))  # Adjust the figure size based on the number of rows

    plt.imshow(activations_matrix, aspect='auto', cmap='coolwarm')

    # Create combined labels for x-axis by stacking clean and corrupted tokens
    combined_tokens = [f"{clean_token} | {corr_token}" for clean_token, corr_token in zip(str_tokens_clean, str_tokens_corr)]

    # Set x-axis to display the combined input tokens
    plt.xticks(ticks=np.arange(len(combined_tokens)), labels=combined_tokens, rotation=90)

    # Set y-axis labels to show clean and corrupted rows for each feature
    y_ticks = []
    for feature in top_features:
        feature_idx = feature[0]
        y_ticks.append(f'{feature_idx} (clean)')
        y_ticks.append(f'{feature_idx} (corr)')

    plt.yticks(ticks=np.arange(len(y_ticks)), labels=y_ticks)

    # Add horizontal lines to separate clean and corrupted rows
    for i in range(1, len(y_ticks), 2):  # Add line after every two rows
        plt.axhline(i + 0.5, color='black', linewidth=1)
        
    # Add a color bar to the side
    plt.colorbar(label='Activation Value')

    # Set axis labels and title
    plt.xlabel("Tokens (Clean / Corrupted)")
    plt.ylabel("Features (Clean and Corrupted)")
    plt.title(f"Feature Activations for Position {position_name} (Top Features)")
    plt.subplots_adjust(left=0.25, right=0.9, top=0.9, bottom=0.3)  # Adjusted bottom for longer x-labels

    # Create the 'features' directory if it doesn't exist
    if not os.path.exists("features"):
        os.makedirs("features")

    # Save the heatmap as a PNG file
    filename = f"tasks/error_detection/type/out/layer10_features/heatmap_position_{position_name}_abs_stacked.png"
    plt.savefig(filename, bbox_inches="tight")
    plt.close()  # Close the plot to avoid display issues in loops

    print(f"Heatmap saved: {filename}")

# Example loop to generate heatmaps for each position
for position_name, top_10_features in top_10_features_for_rel_pos.items():
    print(f"Top Features for Position: {position_name}")

    # Example input tokens (replace these with the actual input tokens for clean and corrupted)
    str_tokens_clean = model.to_str_tokens(samples[0][0])  # Clean input tokens
    str_tokens_corr = model.to_str_tokens(samples[1][0])  # Corrupted input tokens

    # Generate and save the heatmap for the current position
    plot_and_save_position_heatmap(data_idx=0, position_name=position_name, top_features=top_10_features, str_tokens_clean=str_tokens_clean, str_tokens_corr=str_tokens_corr, clean_cache=sae_acts, corr_cache=sae_acts_corr)


Top Features for Position: i_start
Heatmap saved: tasks/error_detection/type/out/layer10_features/heatmap_position_i_start_abs_stacked.png
Top Features for Position: i_end
Heatmap saved: tasks/error_detection/type/out/layer10_features/heatmap_position_i_end_abs_stacked.png
Top Features for Position: end
Heatmap saved: tasks/error_detection/type/out/layer10_features/heatmap_position_end_abs_stacked.png


# Error attribution

In [31]:
error_clean_cache = sae.decode(sae_acts) - clean_cache['blocks.10.hook_resid_post']
error_corr_cache = sae.decode(sae_acts_corr) - corrupted_cache['blocks.10.hook_resid_post'] 
print(error_clean_cache.shape, error_corr_cache.shape)

torch.Size([50, 33, 3584]) torch.Size([50, 33, 3584])


In [None]:
err_corr_grad = 

In [None]:
grad ( in - out ) 

In [None]:
grad ( in - W_dec.W_enc. in ) 

In [None]:
grad in -  grad (W_dec.W_enc. in ) 

In [None]:
grad(in) -  W_dec.W_enc.grad(in) 

In [23]:
sae.W_enc.shape

torch.Size([3584, 16384])

In [24]:
sae.W_dec.shape

torch.Size([16384, 3584])

In [None]:
torch.einsum('ij,ji->1', sae.W_enc.shape, sae.W_dec.shape)

In [25]:
print(sae_grad_cache.shape)

torch.Size([50, 33, 16384])


In [32]:
sae_err_try = torch.einsum('ijk,lk->ijl', sae_grad_cache, sae.W_enc)
sae_err_try.shape

torch.Size([50, 33, 3584])

In [38]:
corrupted_grad_cache['blocks.10.hook_resid_post'].shape

torch.Size([50, 33, 3584])

In [62]:
residual_attr_err_try = einops.reduce(
    corrupted_grad_cache['blocks.10.hook_resid_post'] * error_clean_cache,
    "batch pos d_model -> pos",
    "sum",
)
residual_attr_err_try.sum()

tensor(1.7372, device='cuda:0', grad_fn=<SumBackward0>)

In [41]:
residual_attr_err = einops.reduce(
    corrupted_grad_cache['blocks.10.hook_resid_post'] * (error_clean_cache - error_corr_cache),
    "batch pos d_model -> pos",
    "sum",
)
residual_attr_err.sum()

tensor(0.0626, device='cuda:0', grad_fn=<SumBackward0>)

In [42]:
residual_attr_err = einops.reduce(
    sae_err_try * (error_clean_cache - error_corr_cache),
    "batch pos d_model -> pos",
    "sum",
)
residual_attr_err.sum()

tensor(0.4772, device='cuda:0', grad_fn=<SumBackward0>)

In [30]:
residual_attr_err.sum()

tensor(-0.4772, device='cuda:0', grad_fn=<SumBackward0>)

In [47]:
with torch.no_grad():
    sae.use_error_term = True
    model.add_sae(sae)
    logits = model(samples[0])
    model.reset_saes()
# clean_diff = type_error_patch_metric_prob(logits, selected_pos['end'])
print(err_metric_denoising(logits))
with torch.no_grad():
    sae.use_error_term = True
    model.add_sae(sae)
    logits = model(samples[1])
    model.reset_saes()
corr_diff = type_error_patch_metric_prob(logits, selected_pos['end'])
print(err_metric_denoising(logits))

tensor(1.0000, device='cuda:0')
tensor(-8.3084e-07, device='cuda:0')


In [44]:
without saes
52 clean
9 corr

with saes 
24 clean
15 corr 

SyntaxError: invalid syntax (1318574256.py, line 1)

In [45]:
err_metric_denoising(logits)

tensor(0.1930, device='cuda:0')

In [26]:
resi recons = sae_out - 3k 

w_dec - 16 x 3

latents - 16k

w_enc - 3 x 16 

sae_in - 3k  

NameError: name 'sae_out' is not defined

In [55]:
samples[0][0]

'Type "help", "copyright", "credits" or "license" for more information.\n>>> print("my_var" + 83)\n'

In [56]:
samples[1][0]

'Type "help", "copyright", "credits" or "license" for more information.\n>>> print("my_var" + "83")\n'

In [59]:
selected_pos['end']

[32,
 30,
 30,
 30,
 30,
 30,
 30,
 30,
 30,
 30,
 30,
 30,
 30,
 30,
 30,
 30,
 29,
 30,
 32,
 30,
 30,
 30,
 29,
 30,
 30,
 30,
 30,
 30,
 30,
 32,
 30,
 30,
 30,
 30,
 30,
 30,
 30,
 30,
 29,
 29,
 30,
 30,
 30,
 30,
 30,
 30,
 29,
 30,
 30,
 32]

# Feature patching 

In [69]:
sae_acts.shape

torch.Size([50, 33, 16384])

In [71]:
temp = [300, 1456, 7003]
sae_acts[:, :, temp].shape

torch.Size([50, 33, 3])

In [14]:
import json

# Load the JSON data
with open('tasks/error_detection/type/out/layer10_top_10_features_for_rel_pos.json', 'r') as json_file:
    top_10_features_for_rel_pos = json.load(json_file)
top_10_features_for_rel_pos

{'i_start': [[10694,
   0.06348311901092529,
   'words related to floral scents and their components\\'],
  [4718,
   0.04034152999520302,
   'percentage indicators and formatting symbols\\'],
  [15358,
   0.03502115234732628,
   'image references and their formatting in text\\'],
  [10207, 0.030231375247240067, ' references to forums or discussions\\'],
  [6492,
   0.016979293897747993,
   'terms related to nanotechnology and measurements\\'],
  [16190,
   0.010703159496188164,
   ' programming constructs related to object-oriented features and function declarations\\'],
  [5233,
   0.010215678252279758,
   'numerical data or statistics related to specific subjects\\'],
  [13768,
   0.009093033149838448,
   'numerical values related to measurements or quantities\\'],
  [116,
   0.007796787656843662,
   'keywords related to scientific concepts and phenomena involving particles and their interactions\\'],
  [10669,
   0.007719412446022034,
   'structures or patterns typically used in ma

In [15]:
imp_feats = []
for key, items in top_10_features_for_rel_pos.items(): 
    if key in ['i_start', 'i_end']:
        for item in items:
            imp_feats.append(item[0])
imp_feats

[10694,
 4718,
 15358,
 10207,
 6492,
 16190,
 5233,
 13768,
 116,
 10669,
 2769,
 12754,
 14819,
 771,
 14504,
 4476,
 4482,
 8668,
 5172,
 7965]

In [17]:
def patch_with_sae_features_list_with_hook(model, sae, clean_cache, corr_tokens, patching_metric, use_error_term=True, 
                                feature_list=None, next_token_column=True, progress_bar=True):
    # Initialize batch and sequence size based on the activation store settings

    def patching_hook(corrupted_activation, hook, clean_activation, feature_list):
        corrupted_activation[:, :, feature_list, ...] = clean_activation[:, :, feature_list, ...]
        return corrupted_activation
    
    # current_activation_name = utils.get_act_name("hook_sae_acts_post", layer=0)
    hook_point = sae.cfg.hook_name + '.hook_sae_acts_post'
    # The hook function cannot receive additional inputs, so we use partial to include the specific index and the corresponding clean activation
    current_hook = partial(
        patching_hook,
        clean_activation=clean_cache,
        feature_list=feature_list
    )

    model.add_sae(sae)
    sae.use_error_term = use_error_term
    # Define the hook point in the model where the ablation hook will be attached
    
    model.add_hook(hook_point, current_hook, "fwd")
    # Run the model with the hooks
    with torch.no_grad():
        patched_logits = model(corr_tokens)
    value = patching_metric(patched_logits)
    print(f"patching metric output {value}")

    model.reset_hooks()
    model.reset_saes()
    sae.reset_hooks()
    return patched_logits, value

def non_patch(model, corr_tokens, patching_metric):
    with torch.no_grad():
        value = patching_metric(model(corr_tokens))
    return value

def non_patch_saes(model, sae, corr_tokens, patching_metric,  use_error_term=True):
    model.add_sae(sae)
    sae.use_error_term = use_error_term
    with torch.no_grad():
        value = patching_metric(model(corr_tokens))
    model.reset_hooks()
    model.reset_saes()
    sae.reset_hooks()
    return value

In [17]:
non_patch(model, samples[1], err_metric_denoising)

tensor(0., device='cuda:0')

In [18]:
non_patch_saes(model, sae, samples[1], err_metric_denoising, use_error_term=False)

tensor(0.2367, device='cuda:0')

In [19]:
non_patch_saes(model, sae, samples[0], err_metric_denoising, use_error_term=False)

tensor(0.4330, device='cuda:0')

In [24]:
p_logits, met_val = patch_with_sae_features_list_with_hook(model, sae, sae_acts, samples[1], err_metric_denoising, use_error_term=True, feature_list=imp_feats)

patching metric output -4.340986592410445e-08


In [16]:
met_val

tensor(0.2619, device='cuda:0')

In [20]:
!nvidia-smi

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Wed Oct 23 07:40:47 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.90.07              Driver Version: 550.90.07      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100 80GB PCIe          On  |   00000000:81:00.0 Off |                    0 |
| N/A   42C    P0             64W /  300W |   69611MiB /  81920MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                

In [19]:
non_patch_saes(model, sae, samples[0], err_metric_denoising, use_error_term=False)

OutOfMemoryError: CUDA out of memory. Tried to allocate 92.00 MiB. GPU 0 has a total capacity of 79.14 GiB of which 60.75 MiB is free. Process 1602148 has 79.07 GiB memory in use. Of the allocated memory 77.18 GiB is allocated by PyTorch, and 1.40 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [22]:
samples[0][0]

'Type "help", "copyright", "credits" or "license" for more information.\n>>> print("price" + 7)\n'

In [23]:
samples[1][0]

'Type "help", "copyright", "credits" or "license" for more information.\n>>> print("price" + "7")\n'