In [1]:
import os
import gc
import torch
os.chdir("/work/pi_jensen_umass_edu/jnainani_umass_edu/CircuitAnalysisSAEs")
import json
from sae_lens import SAE, HookedSAETransformer
from circ4latents import data_gen
from functools import partial
import einops

# Function to manage CUDA memory and clean up
def cleanup_cuda():
   torch.cuda.empty_cache()
   gc.collect()
# cleanup_cuda()
# Load the config
with open("config.json", 'r') as file:
   config = json.load(file)
token = config.get('huggingface_token', None)
os.environ["HF_TOKEN"] = token

# Define device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

hf_cache = "/work/pi_jensen_umass_edu/jnainani_umass_edu/mechinterp/huggingface_cache/hub"
os.environ["HF_HOME"] = hf_cache

# Load the model
model = HookedSAETransformer.from_pretrained("google/gemma-2-9b", device=device, cache_dir=hf_cache)



Device: cuda


Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]



Loaded pretrained model google/gemma-2-9b into HookedTransformer


In [2]:
sae21 = SAE.from_pretrained(release="gemma-scope-9b-pt-res", sae_id=f"layer_21/width_16k/average_l0_129", device=device)[0]

# Simple String:Int Dictionary Key Error 

In [3]:

import json
import random

# Expanding the name pool with a larger set of names
extended_name_pool = [
    "Bob", "Sam", "Lilly", "Rob", "Alice", "Charlie", "Sally", "Tom", "Jake", "Emily", 
    "Megan", "Chris", "Sophia", "James", "Oliver", "Isabella", "Mia", "Jackson", 
    "Emma", "Ava", "Lucas", "Benjamin", "Ethan", "Grace", "Olivia", "Liam", "Noah"
]

for name in extended_name_pool:
    assert len(model.tokenizer.encode(name)) == 2, f"Name {name} has more than 1 token"

# Function to generate the dataset with correct and incorrect keying into dictionaries
def generate_extended_dataset(name_pool, num_samples=5):
    dataset = []
    
    for _ in range(num_samples):
        # Randomly select 5 names from the pool
        selected_names = random.sample(name_pool, 5)
        # Assign random ages to the selected names
        age_dict = {name: random.randint(10, 19) for name in selected_names}
        
        # Create a correct example
        correct_name = random.choice(list(age_dict.keys()))
        correct_prompt = f'Type "help", "copyright", "credits" or "license" for more information.\n>>> age = {age_dict}\n>>> var = age["{correct_name}"]\n'
        correct_response = age_dict[correct_name]
        correct_token = str(correct_response)[0]
        
        # Create an incorrect example with a name not in the dictionary
        incorrect_name = random.choice([name for name in name_pool if name not in age_dict])
        incorrect_prompt = f'Type "help", "copyright", "credits" or "license" for more information.\n>>> age = {age_dict}\n>>> var = age["{incorrect_name}"]\n'
        incorrect_response = "Traceback"
        incorrect_token = "Traceback"
        
        # Append the pair of correct and incorrect examples
        dataset.append({
            "correct": {
                "prompt": correct_prompt,
                "response": correct_response,
                "token": correct_token
            },
            "error": {
                "prompt": incorrect_prompt,
                "response": incorrect_response,
                "token": incorrect_token
            }
        })
        
    return dataset

# Generate the extended dataset
json_dataset = generate_extended_dataset(extended_name_pool, num_samples=100)

# Output the JSON structure

# %%
clean_prompts = []
corr_prompts = []

answer_token = model.to_single_token(">>>")
traceback_token = model.to_single_token("Traceback")

for item in json_dataset[:50]:
    corr_prompts.append(item["correct"]["prompt"])
    clean_prompts.append(item["error"]["prompt"])

clean_tokens = model.to_tokens(clean_prompts)
corr_tokens = model.to_tokens(corr_prompts)

# %%
def logit_diff_fn(logits):
    err = logits[:, -1, traceback_token]
    no_err = logits[:, -1, answer_token]
    return (err - no_err).mean()

# Disable gradients for all parameters
for param in model.parameters():
   param.requires_grad_(False)

# # Compute logits for clean and corrupted samples
logits = model(clean_tokens)
clean_diff = logit_diff_fn(logits)

logits = model(corr_tokens)
corr_diff = logit_diff_fn(logits)

print(f"clean_diff: {clean_diff}")
print(f"corr_diff: {corr_diff}")

# # Cleanup
del logits
cleanup_cuda()

# # Define error type metric
def _err_type_metric(logits, clean_logit_diff, corr_logit_diff):
    patched_logit_diff = logit_diff_fn(logits)
    return (patched_logit_diff - corr_logit_diff) / (clean_logit_diff - corr_logit_diff)

err_metric_denoising = partial(_err_type_metric, clean_logit_diff=clean_diff, corr_logit_diff=corr_diff)

clean_diff: 3.566526412963867
corr_diff: -4.803060531616211


In [4]:
print(json_dataset[0]['error']['prompt'])

Type "help", "copyright", "credits" or "license" for more information.
>>> age = {'Benjamin': 13, 'Chris': 16, 'Lilly': 12, 'Charlie': 10, 'Emma': 16}
>>> var = age["Isabella"]



## Attr calculation

In [5]:
from transformer_lens import ActivationCache, utils
from transformer_lens.hook_points import HookPoint
# from torchtyping import TensorType as TT

def get_cache_fwd_and_bwd(
    model,
    tokens,
    metric,
    sae,
    error_term: bool = True,
    retain_graph: bool = True
):
    # torch.set_grad_enabled(True)
    model.reset_hooks()
    # model.reset_saes()
    cache = {}
    grad_cache = {}
    filter_base_acts = lambda name: "blocks.21.hook_resid_post" in name
    # filter_sae_acts = lambda name: "hook_sae_acts_post" in name

    def forward_cache_hook(act, hook):
        act.requires_grad_(True)
        # act.retain_graph()
        cache[hook.name] = act.detach()

    def backward_cache_hook(grad, hook):
        grad.requires_grad_(True)
        # grad.retain_graph()
        grad_cache[hook.name] = grad.detach()

    # sae.use_error_term = error_term
    # model.add_sae(sae)
    model.add_hook(filter_base_acts, forward_cache_hook, "fwd")
    model.add_hook(filter_base_acts, backward_cache_hook, "bwd")
    # logits = run_with_saes_filtered(tokens, [model.tokenizer.bos_token_id, model.tokenizer.eos_token_id, model.tokenizer.pad_token_id], model, [sae])
    value = metric(model(tokens)) #logits)
    value.backward() #retain_graph=retain_graph)

    model.reset_hooks()
    # model.reset_saes()
    # torch.set_grad_enabled(False)
    return (
        value,
        ActivationCache(cache, model),
        ActivationCache(grad_cache, model),
    )


In [None]:

clean_value, clean_cache, _ = get_cache_fwd_and_bwd(model, clean_tokens, err_metric_denoising, sae21)
print("Clean Value:", clean_value)
print("Clean Activations Cached:", len(clean_cache))
# print("Clean Gradients Cached:", len(clean_grad_cache))

corrupted_value, corrupted_cache, corrupted_grad_cache = get_cache_fwd_and_bwd(model, corr_tokens, err_metric_denoising, sae21)
print("Corrupted Value:", corrupted_value)
print("Corrupted Activations Cached:", len(corrupted_cache))
print("Corrupted Gradients Cached:", len(corrupted_grad_cache))

# # Cleanup
del clean_value, corrupted_value
cleanup_cuda()

In [16]:
corrupted_grad_cache['blocks.21.hook_resid_post'][:, 1:, :]

tensor([[[-3.8120e-07, -3.2566e-07, -1.0438e-06,  ...,  7.1768e-07,
          -9.9885e-07,  8.6916e-07],
         [-3.1879e-07, -6.1671e-09,  1.5720e-07,  ...,  1.1308e-07,
           1.0026e-07,  5.5729e-08],
         [-7.8373e-08, -9.4730e-08, -2.5211e-07,  ...,  1.4263e-07,
          -7.1474e-07,  1.2810e-07],
         ...,
         [ 1.6529e-06,  1.1248e-06, -4.2053e-07,  ..., -2.3218e-06,
           3.5809e-07, -4.9164e-07],
         [ 5.5071e-06,  9.5625e-08, -2.0255e-06,  ..., -3.0415e-06,
           6.6743e-06, -1.9352e-06],
         [-1.2636e-05, -2.7388e-05,  3.0116e-06,  ...,  1.2309e-05,
           2.0280e-05, -8.1272e-06]],

        [[-4.9830e-07, -1.5768e-07, -1.0518e-06,  ...,  5.8069e-07,
          -9.4853e-07,  1.0167e-06],
         [-3.2971e-07, -1.6169e-08,  1.6466e-07,  ...,  1.2835e-07,
           1.0842e-07,  5.0115e-08],
         [-1.0604e-07, -1.4316e-07, -2.7018e-07,  ...,  8.6350e-08,
          -7.1302e-07,  7.1841e-08],
         ...,
         [ 6.9058e-07,  8

In [12]:
sae_acts = sae21.encode(clean_cache['blocks.21.hook_resid_post'][:, 1:, :])
sae_acts_corr = sae21.encode(corrupted_cache['blocks.21.hook_resid_post'][:, 1:, :])
print(sae_acts.shape, sae_acts_corr.shape)

sae_grad_cache = torch.einsum('bij,kj->bik', corrupted_grad_cache['blocks.21.hook_resid_post'][:, 1:, :], sae21.W_dec)
print(sae_grad_cache.shape)

torch.Size([50, 66, 16384]) torch.Size([50, 66, 16384])
torch.Size([50, 66, 16384])


In [17]:
K = 100
# Residual attribution calculation only for the selected positions
residual_attr_final = einops.reduce(
    sae_grad_cache * (sae_acts - sae_acts_corr),
    "batch pos n_features -> n_features",
    "sum",
)
# Get the top K features based on the absolute values
abs_residual_attr_final = torch.abs(residual_attr_final)
top_feats = torch.topk(abs_residual_attr_final, K)
top_indices = top_feats.indices
top_values = residual_attr_final[top_indices] 
print(top_indices, top_values)

tensor([ 3237,  3271,  8156,  8233, 11219, 10711,  9585,  8535,  5795, 13756,
         3925,   430, 10160,  6290, 11774,  9654,  3377, 13275, 13163,   111,
         8311, 15795, 13862, 10350,  7892, 13218, 10618, 14598, 12223,  4766,
         5577,  3395, 11725, 14495, 15571,  7482,  9768,  9960, 13842,  2782,
         5804,  6840, 11779, 15430,  5498, 11712, 14638, 13839, 10587,  7675,
         3807,  4804,  2405, 14331,  6855,  8696,   691, 10110,  1662, 12189,
         4337,  1407,  8180, 12151,  7950, 14786,  9190, 12440,  9858, 10916,
           37, 13831,  2187,  9304, 14499,  3818, 10219,  2120,   940, 12739,
        15805,  5615,  7988,  7237,  6098,  2239,  6606,  2314,  4022, 14229,
         6646, 10719,  7381,  8680,  1436, 10023, 14613,  3863, 15777, 12609],
       device='cuda:0') tensor([ 0.0732,  0.0403,  0.0253,  0.0250,  0.0206,  0.0195,  0.0191,  0.0190,
         0.0115,  0.0100,  0.0098,  0.0088,  0.0087,  0.0085,  0.0082,  0.0071,
         0.0064,  0.0063,  0.0062, 

In [18]:
sum(top_values)

tensor(0.4175, device='cuda:0', grad_fn=<AddBackward0>)

## Performance recovery

In [15]:
def run_with_saes_filtered_cache(tokens, filtered_ids, model, saes):
    # Ensure tokens are a torch.Tensor
    if not isinstance(tokens, torch.Tensor):
        tokens = torch.tensor(tokens).to(model.cfg.device)  # Move to the device of the model

    # Create a mask where True indicates positions to modify
    mask = torch.ones_like(tokens, dtype=torch.bool)
    for token_id in filtered_ids:
        mask &= tokens != token_id


    # Expand the mask once, so it matches the shape [batch_size, seq_len, 1]
    mask_expanded = mask.unsqueeze(-1)  # Expand to allow broadcasting
    mask_expanded = mask_expanded.to(model.cfg.device)  # Move the mask to the same device as the model

    # Dictionary to store the modified activations
    sae_outs = {}


    # For each SAE, add the appropriate hook
    for sae in saes:
        hook_point = sae.cfg.hook_name


        # Define the filtered hook function (optimized)
        def filtered_hook(act, hook, sae=sae, mask_expanded=mask_expanded):
            # Apply the SAE only where mask_expanded is True
            enc_sae = sae.encode(act)  # Call the SAE once
            # Store the updated activation in the dictionary
            sae_outs[hook.name] = enc_sae.detach().cpu() 
            modified_act = sae.decode(enc_sae)  # Call the SAE once
            # In-place update where the mask is True
            updated_act = torch.where(mask_expanded, modified_act, act)
        
            return updated_act


        # Add the hook to the model
        model.add_hook(hook_point, filtered_hook, dir='fwd')


    # Run the model with the tokens (no gradients needed)
    with torch.no_grad():
        logits = model(tokens)


    # Reset the hooks after computation to free memory
    model.reset_hooks()


    return logits, sae_outs  # Return logits and the updated activations


def run_with_saes_latent_op_patch(new_tokens, filtered_ids, model, saes, cache, dict_feats):
   # Ensure tokens are a torch.Tensor
   if not isinstance(new_tokens, torch.Tensor):
       new_tokens = torch.tensor(new_tokens).to(model.cfg.device)  # Move to the device of the model

   # Create a mask where True indicates positions to modify
   mask = torch.ones_like(new_tokens, dtype=torch.bool)
   for token_id in filtered_ids:
       mask &= new_tokens != token_id

   # Expand the mask once, so it matches the shape [batch_size, seq_len, 1]
   mask_expanded = mask.unsqueeze(-1)  # Expand to allow broadcasting
   mask_expanded = mask_expanded.to(model.cfg.device)  # Move the mask to the same device as the model
   # For each SAE, add the appropriate hook
   for sae in saes:
       hook_point = sae.cfg.hook_name

       # Define the filtered hook function (optimized)
       def filtered_hook(act, hook, sae=sae, mask_expanded=mask_expanded):
           # Apply the SAE only where mask_expanded is True
           enc_sae = sae.encode(act)  # Call the SAE once
          
           if hook.name in cache and hook.name in dict_feats:
               prev_sae = cache[hook.name]  # Get cached activations from the cache
               feature_indices = dict_feats[hook.name]  # Get the feature indices to patch

               for feature_idx in range(sae.cfg.d_sae):
                   if feature_idx in feature_indices:
                       enc_sae[:, :, feature_idx] = prev_sae[:, :, feature_idx]

           # After patching, decode the modified enc_sae
           modified_act = sae.decode(enc_sae)

           # In-place update where the mask is True
           updated_act = torch.where(mask_expanded, modified_act, act)

           return updated_act

       # Add the hook to the model
       model.add_hook(hook_point, filtered_hook, dir='fwd')

   # Run the model with the tokens (no gradients needed)
   with torch.no_grad():
       logits = model(new_tokens)

   # Reset the hooks after computation to free memory
   model.reset_hooks()

   return logits  # Return only the logits


def run_with_saes_latent_op_patch_mean(new_tokens, filtered_ids, model, saes, mean_cache, dict_feats):
   # Ensure tokens are a torch.Tensor
   if not isinstance(new_tokens, torch.Tensor):
       new_tokens = torch.tensor(new_tokens).to(model.cfg.device)  # Move to the device of the model

   # Create a mask where True indicates positions to modify
   mask = torch.ones_like(new_tokens, dtype=torch.bool)
   for token_id in filtered_ids:
       mask &= new_tokens != token_id

   # Expand the mask once, so it matches the shape [batch_size, seq_len, 1]
   mask_expanded = mask.unsqueeze(-1)  # Expand to allow broadcasting
   mask_expanded = mask_expanded.to(model.cfg.device)  # Move the mask to the same device as the model
   # For each SAE, add the appropriate hook
   for sae in saes:
       hook_point = sae.cfg.hook_name

       # Define the filtered hook function (optimized)
       def filtered_hook(act, hook, sae=sae, mask_expanded=mask_expanded):
           # Apply the SAE only where mask_expanded is True
           enc_sae = sae.encode(act)  # Call the SAE once
          
           if hook.name in mean_cache and hook.name in dict_feats:
               prev_sae = mean_cache[hook.name]  # Get cached activations from the cache
               feature_indices = dict_feats[hook.name]  # Get the feature indices to patch

               for feature_idx in range(sae.cfg.d_sae):
                   if feature_idx not in feature_indices:
                       enc_sae[:, :, feature_idx] = prev_sae[:, feature_idx]

           # After patching, decode the modified enc_sae
           modified_act = sae.decode(enc_sae)

           # In-place update where the mask is True
           updated_act = torch.where(mask_expanded, modified_act, act)

           return updated_act

       # Add the hook to the model
       model.add_hook(hook_point, filtered_hook, dir='fwd')

   # Run the model with the tokens (no gradients needed)
   with torch.no_grad():
       logits = model(new_tokens)

   # Reset the hooks after computation to free memory
   model.reset_hooks()

   return logits  # Return only the logits


def run_with_saes_latent_op_patch_cache(new_tokens, filtered_ids, model, saes, cache, dict_feats):
   # Ensure tokens are a torch.Tensor
   if not isinstance(new_tokens, torch.Tensor):
      new_tokens = torch.tensor(new_tokens).to(model.cfg.device)  # Move to the device of the model

   # Create a mask where True indicates positions to modify
   mask = torch.ones_like(new_tokens, dtype=torch.bool)
   for token_id in filtered_ids:
      mask &= new_tokens != token_id

   # Expand the mask once, so it matches the shape [batch_size, seq_len, 1]
   mask_expanded = mask.unsqueeze(-1)  # Expand to allow broadcasting
   mask_expanded = mask_expanded.to(model.cfg.device)  # Move the mask to the same device as the model
   sae_outs = {}
   # For each SAE, add the appropriate hook
   for sae in saes:
      hook_point = sae.cfg.hook_name

      # Define the filtered hook function (optimized)
      def filtered_hook(act, hook, sae=sae, mask_expanded=mask_expanded):
         # Apply the SAE only where mask_expanded is True
         enc_sae = sae.encode(act)  # Call the SAE once
         
         if hook.name in cache and hook.name in dict_feats:
            prev_sae = cache[hook.name]  # Get cached activations from the cache
            feature_indices = dict_feats[hook.name]  # Get the feature indices to patch

            for feature_idx in range(sae.cfg.d_sae):
               if feature_idx in feature_indices:
                  enc_sae[:, :, feature_idx] = prev_sae[:, :, feature_idx]
         sae_outs[hook.name] = enc_sae.detach().cpu()
         # After patching, decode the modified enc_sae
         modified_act = sae.decode(enc_sae)

         # In-place update where the mask is True
         updated_act = torch.where(mask_expanded, modified_act, act)

         return updated_act

      # Add the hook to the model
      model.add_hook(hook_point, filtered_hook, dir='fwd')

   # Run the model with the tokens (no gradients needed)
   with torch.no_grad():
      logits = model(new_tokens)

   # Reset the hooks after computation to free memory
   model.reset_hooks()

   return logits, sae_outs  # Return only the logits
# def run_with_saes_latent_edge_patch(new_tokens, filtered_ids, model, saes, cache, sender_feats, receiver_feats):
    

In [7]:
# %% clean cache 
filtered_ids = [model.tokenizer.bos_token_id]
logits, clean_sae_cache = run_with_saes_filtered_cache(clean_tokens, filtered_ids=filtered_ids, model=model, saes=[sae21])
clean_sae_diff = logit_diff_fn(logits)
print(f"clean_sae_diff: {clean_sae_diff}")

clean_sae_diff: 0.44496282935142517


In [8]:
# %% corr cache 
logits, corr_sae_cache = run_with_saes_filtered_cache(corr_tokens, filtered_ids=filtered_ids, model=model, saes=[sae21])
corr_sae_diff = logit_diff_fn(logits)
print(f"corr_sae_diff: {corr_sae_diff}")

corr_sae_diff: -4.756180763244629


In [12]:
max(corr_sae_cache['blocks.21.hook_resid_post'][0, -1, :])

tensor(28.7968)

In [13]:

corr_sae_cache_means = {layer: sae_cache.mean(dim=0) for layer, sae_cache in corr_sae_cache.items()}

print(corr_sae_cache['blocks.21.hook_resid_post'].shape)
print(corr_sae_cache_means['blocks.21.hook_resid_post'].shape)

torch.Size([50, 67, 16384])
torch.Size([67, 16384])


In [16]:
dict_feats = {"blocks.21.hook_resid_post":[ 3237,  3271,  8156,  8233, 11219, 10711,  9585,  8535,  5795, 13756,
         3925,   430, 10160,  6290, 11774,  9654,  3377, 13275, 13163,   111,
         8311, 15795, 13862, 10350,  7892, 13218, 10618, 14598, 12223,  4766,
         5577,  3395, 11725, 14495, 15571,  7482,  9768,  9960, 13842,  2782,
         5804,  6840, 11779, 15430,  5498, 11712, 14638, 13839, 10587,  7675,
         3807,  4804,  2405, 14331,  6855,  8696,   691, 10110,  1662, 12189,
         4337,  1407,  8180, 12151,  7950, 14786,  9190, 12440,  9858, 10916,
           37, 13831,  2187,  9304, 14499,  3818, 10219,  2120,   940, 12739,
        15805,  5615,  7988,  7237,  6098,  2239,  6606,  2314,  4022, 14229,
         6646, 10719,  7381,  8680,  1436, 10023, 14613,  3863, 15777, 12609]}

In [17]:
logits = run_with_saes_latent_op_patch_mean(clean_tokens, filtered_ids=filtered_ids, model=model, saes=[sae21],mean_cache=corr_sae_cache_means,dict_feats=dict_feats)
mean_diff = logit_diff_fn(logits)
print(f"recovered diff: {mean_diff}")

recovered diff: -4.41883659362793


In [18]:
(-4.41883659362793 + 4.756180763244629 )/(0.44496282+4.418836593627)

0.06935815828908479

# Nested String:Int Dictionary Key Error 

In [4]:
import json
import random

# Expanding the name pool with a larger set of names
extended_name_pool = [
    "Bob", "Sam", "Lilly", "Rob", "Alice", "Charlie", "Sally", "Tom", "Jake", "Emily", 
    "Megan", "Chris", "Sophia", "James", "Oliver", "Isabella", "Mia", "Jackson", 
    "Emma", "Ava", "Lucas", "Benjamin", "Ethan", "Grace", "Olivia", "Liam", "Noah"
]

for name in extended_name_pool:
    assert len(model.tokenizer.encode(name)) == 2, f"Name {name} has more than 1 token"
# Function to generate the dataset for Level 2 - nested dictionaries
# Function to generate the dataset for Level 2 - nested dictionaries with unique inner keys per outer key
def generate_nested_dict_dataset_unique_inner_keys(name_pool, num_samples=5):
    dataset = []
    
    for _ in range(num_samples):
        # Randomly select 5 outer keys
        outer_keys = random.sample(name_pool, 2)
        
        # Create a nested dictionary where each outer key has a unique set of inner keys
        nested_dict = {
            outer_key: {inner_key: random.randint(10, 19) for inner_key in random.sample(name_pool, 2)}
            for outer_key in outer_keys
        }
        
        # Create a correct example
        correct_outer_key = random.choice(list(nested_dict.keys()))
        correct_inner_key = random.choice(list(nested_dict[correct_outer_key].keys()))
        correct_prompt = (
            f'Type "help", "copyright", "credits" or "license" for more information.\n'
            f'>>> age = {nested_dict}\n>>> var = age["{correct_outer_key}"]["{correct_inner_key}"]\n'
        )
        correct_response = nested_dict[correct_outer_key][correct_inner_key]
        correct_token = str(correct_response)[0]
        
        # Create an incorrect example with a non-existent inner key for an existing outer key
        incorrect_outer_key = correct_outer_key
        incorrect_inner_key = random.choice(
            [key for key in name_pool if key not in nested_dict[incorrect_outer_key]]
        )
        incorrect_prompt = (
            f'Type "help", "copyright", "credits" or "license" for more information.\n'
            f'>>> age = {nested_dict}\n>>> var = age["{incorrect_outer_key}"]["{incorrect_inner_key}"]\n'
        )
        incorrect_response = "Traceback"
        incorrect_token = "Traceback"
        
        # Append the pair of correct and incorrect examples
        dataset.append({
            "correct": {
                "prompt": correct_prompt,
                "response": correct_response,
                "token": correct_token
            },
            "error": {
                "prompt": incorrect_prompt,
                "response": incorrect_response,
                "token": incorrect_token
            }
        })
        
    return dataset

# Generate the nested dictionary dataset
json_dataset = generate_nested_dict_dataset_unique_inner_keys(extended_name_pool, num_samples=100)

print(json_dataset[0]['error']['prompt'])

# %%
clean_prompts = []
corr_prompts = []

answer_token = model.to_single_token(">>>")
traceback_token = model.to_single_token("Traceback")

for item in json_dataset[:50]:
    corr_prompts.append(item["correct"]["prompt"])
    clean_prompts.append(item["error"]["prompt"])

clean_tokens = model.to_tokens(clean_prompts)
corr_tokens = model.to_tokens(corr_prompts)

# %%
def logit_diff_fn(logits):
    err = logits[:, -1, traceback_token]
    no_err = logits[:, -1, answer_token]
    return (err - no_err).mean()

# Disable gradients for all parameters
for param in model.parameters():
   param.requires_grad_(False)

# # Compute logits for clean and corrupted samples
logits = model(clean_tokens)
clean_diff = logit_diff_fn(logits)

logits = model(corr_tokens)
corr_diff = logit_diff_fn(logits)

print(f"clean_diff: {clean_diff}")
print(f"corr_diff: {corr_diff}")

# # Cleanup
del logits
cleanup_cuda()


# # Define error type metric
def _err_type_metric(logits, clean_logit_diff, corr_logit_diff):
    patched_logit_diff = logit_diff_fn(logits)
    return (patched_logit_diff - corr_logit_diff) / (clean_logit_diff - corr_logit_diff)

err_metric_denoising = partial(_err_type_metric, clean_logit_diff=clean_diff, corr_logit_diff=corr_diff)

Type "help", "copyright", "credits" or "license" for more information.
>>> age = {'Lucas': {'Grace': 16, 'Rob': 15}, 'Sophia': {'Bob': 14, 'James': 15}}
>>> var = age["Sophia"]["Rob"]

clean_diff: 2.3436012268066406
corr_diff: -4.374980449676514


In [5]:
!nvidia-smi

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Wed Nov 27 13:30:25 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.127.05             Driver Version: 550.127.05     CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100-SXM4-80GB          On  |   00000000:84:00.0 Off |                    0 |
| N/A   40C    P0             70W /  400W |   43469MiB /  81920MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                

In [6]:
clean_value, clean_cache, _ = get_cache_fwd_and_bwd(model, clean_tokens, err_metric_denoising, sae21)
print("Clean Value:", clean_value)
print("Clean Activations Cached:", len(clean_cache))
# print("Clean Gradients Cached:", len(clean_grad_cache))

corrupted_value, corrupted_cache, corrupted_grad_cache = get_cache_fwd_and_bwd(model, corr_tokens, err_metric_denoising, sae21)
print("Corrupted Value:", corrupted_value)
print("Corrupted Activations Cached:", len(corrupted_cache))
print("Corrupted Gradients Cached:", len(corrupted_grad_cache))

# # Cleanup
del clean_value, corrupted_value
cleanup_cuda()

Clean Value: tensor(1., device='cuda:0', grad_fn=<DivBackward0>)
Clean Activations Cached: 1
Corrupted Value: tensor(0., device='cuda:0', grad_fn=<DivBackward0>)
Corrupted Activations Cached: 1
Corrupted Gradients Cached: 1


In [7]:
sae_acts = sae21.encode(clean_cache['blocks.21.hook_resid_post'][:, 1:, :])
sae_acts_corr = sae21.encode(corrupted_cache['blocks.21.hook_resid_post'][:, 1:, :])
print(sae_acts.shape, sae_acts_corr.shape)

sae_grad_cache = torch.einsum('bij,kj->bik', corrupted_grad_cache['blocks.21.hook_resid_post'][:, 1:, :], sae21.W_dec)
print(sae_grad_cache.shape)

torch.Size([50, 67, 16384]) torch.Size([50, 67, 16384])
torch.Size([50, 67, 16384])


In [9]:
K = 100
# Residual attribution calculation only for the selected positions
residual_attr_final = einops.reduce(
    sae_grad_cache[:, -1, :] * (sae_acts[:, -1, :] - sae_acts_corr[:, -1, :]),
    "batch n_features -> n_features",
    "sum",
)
# Get the top K features based on the absolute values
abs_residual_attr_final = torch.abs(residual_attr_final)
top_feats = torch.topk(abs_residual_attr_final, K)
top_indices = top_feats.indices
top_values = residual_attr_final[top_indices] 
print(top_indices, top_values)

tensor([ 3237,  9585, 11219,  8535,  8156,  3271,  6290, 10711,  8233, 13756,
        13163,  4766, 15430,  9858,  2405, 15571,  3395,   430, 10160,  8696,
         1662, 12189, 14499, 13842,  5795, 13831,  9654,  4337, 15777, 10350,
        14598, 11725,  3573,  3925,  6098, 10219,  5394, 11095,  9960, 11712,
         5615, 14638, 11686,  2376,  7641,  8556, 11213,   111,  5478,  7892,
         6353, 11774,  7674,  6138, 12223, 12813, 14605,   846,  4541, 13218,
         7159,  3104,     0,     1,     2,     3,     4,     5,     6,     7,
            8,     9,    10,    11,    12,    13,    14,    15,    16,    17,
           18,    19,    20,    21,    22,    23,    24,    25,    26,    27,
           28,    29,    30,    31,    32,    33,    34,    35,    36,    37],
       device='cuda:0') tensor([ 7.7622e-02,  3.7153e-02,  2.2978e-02,  2.2002e-02,  2.0766e-02,
         1.9949e-02,  1.7510e-02,  1.7235e-02,  1.5683e-02,  1.3382e-02,
         1.0201e-02,  9.2319e-03,  8.8157e-03,  8

# Simple Alphabet:Alphabet Dictionary Key Error 

In [10]:
import json
import random
import string

# Function to generate Level 1 dataset
def generate_level_1_dataset(num_samples=5):
    dataset = []

    for _ in range(num_samples):
        # Generate a dictionary of the form {'a': 'b', 'c': 'd', 'e': 'f'}
        keys = random.sample(string.ascii_lowercase, 3)
        values = random.sample([ch for ch in string.ascii_lowercase if ch not in keys], 3)
        alphabet_dict = dict(zip(keys, values))

        # Create a correct example
        correct_key = random.choice(keys)
        correct_prompt = (
            f'Type "help", "copyright", "credits" or "license" for more information.\n'
            f'>>> mapping = {alphabet_dict}\n>>> var = mapping["{correct_key}"]\n'
        )
        correct_response = alphabet_dict[correct_key]
        correct_token = str(correct_response)[0]

        # Create an incorrect example (key neither in keys nor values)
        non_key_non_value = random.choice(
            [ch for ch in string.ascii_lowercase if ch not in keys and ch not in values]
        )
        incorrect_prompt = (
            f'Type "help", "copyright", "credits" or "license" for more information.\n'
            f'>>> mapping = {alphabet_dict}\n>>> var = mapping["{non_key_non_value}"]\n'
        )
        incorrect_response = "Traceback"
        incorrect_token = "Traceback"

        # Append both correct and incorrect examples
        dataset.append({
            "correct": {
                "prompt": correct_prompt,
                "response": correct_response,
                "token": correct_token
            },
            "error": {
                "prompt": incorrect_prompt,
                "response": incorrect_response,
                "token": incorrect_token
            }
        })

    return dataset





# Generate the datasets
level_1_dataset = generate_level_1_dataset(num_samples=100)
print(level_1_dataset[0]['error']['prompt'])


Type "help", "copyright", "credits" or "license" for more information.
>>> mapping = {'i': 's', 'd': 'a', 'x': 'n'}
>>> var = mapping["e"]



In [12]:
# %%
clean_prompts = []
corr_prompts = []

answer_token = model.to_single_token(">>>")
traceback_token = model.to_single_token("Traceback")

for item in level_1_dataset[:50]:
    corr_prompts.append(item["correct"]["prompt"])
    clean_prompts.append(item["error"]["prompt"])

clean_tokens = model.to_tokens(clean_prompts)
corr_tokens = model.to_tokens(corr_prompts)

# %%
def logit_diff_fn(logits):
    err = logits[:, -1, traceback_token]
    no_err = logits[:, -1, answer_token]
    return (err - no_err).mean()

# Disable gradients for all parameters
for param in model.parameters():
   param.requires_grad_(False)

# # Compute logits for clean and corrupted samples
logits = model(clean_tokens)
clean_diff = logit_diff_fn(logits)

logits = model(corr_tokens)
corr_diff = logit_diff_fn(logits)

print(f"clean_diff: {clean_diff}")
print(f"corr_diff: {corr_diff}")

# # Cleanup
del logits
cleanup_cuda()

# # Define error type metric
def _err_type_metric(logits, clean_logit_diff, corr_logit_diff):
    patched_logit_diff = logit_diff_fn(logits)
    return (patched_logit_diff - corr_logit_diff) / (clean_logit_diff - corr_logit_diff)

err_metric_denoising = partial(_err_type_metric, clean_logit_diff=clean_diff, corr_logit_diff=corr_diff)

clean_diff: 1.7985607385635376
corr_diff: -3.763065814971924


# Adversarial Alphabet:Alphabet Dictionary Key Error

In [11]:
# Function to generate Level 2 dataset
def generate_level_2_dataset(num_samples=5):
    dataset = []

    for _ in range(num_samples):
        # Generate a dictionary of the form {'a': 'b', 'c': 'd', 'e': 'f'}
        keys = random.sample(string.ascii_lowercase, 3)
        values = random.sample([ch for ch in string.ascii_lowercase if ch not in keys], 3)
        alphabet_dict = dict(zip(keys, values))

        # Create a correct example
        correct_key = random.choice(keys)
        correct_prompt = (
            f'Type "help", "copyright", "credits" or "license" for more information.\n'
            f'>>> mapping = {alphabet_dict}\n>>> var = mapping["{correct_key}"]\n'
        )
        correct_response = alphabet_dict[correct_key]
        correct_token = str(correct_response)[0]

        # Create an incorrect example (key not in keys but IS in values)
        non_key_but_value = random.choice(values)
        incorrect_prompt = (
            f'Type "help", "copyright", "credits" or "license" for more information.\n'
            f'>>> mapping = {alphabet_dict}\n>>> var = mapping["{non_key_but_value}"]\n'
        )
        incorrect_response = "Traceback"
        incorrect_token = "Traceback"

        # Append both correct and incorrect examples
        dataset.append({
            "correct": {
                "prompt": correct_prompt,
                "response": correct_response,
                "token": correct_token
            },
            "error": {
                "prompt": incorrect_prompt,
                "response": incorrect_response,
                "token": incorrect_token
            }
        })

    return dataset
level_2_dataset = generate_level_2_dataset(num_samples=100)
print(level_2_dataset[0]['error']['prompt'])

Type "help", "copyright", "credits" or "license" for more information.
>>> mapping = {'c': 'i', 'k': 'a', 't': 'z'}
>>> var = mapping["i"]



In [13]:
# %%
clean_prompts = []
corr_prompts = []

answer_token = model.to_single_token(">>>")
traceback_token = model.to_single_token("Traceback")

for item in level_2_dataset[:50]:
    corr_prompts.append(item["correct"]["prompt"])
    clean_prompts.append(item["error"]["prompt"])

clean_tokens = model.to_tokens(clean_prompts)
corr_tokens = model.to_tokens(corr_prompts)

# %%
def logit_diff_fn(logits):
    err = logits[:, -1, traceback_token]
    no_err = logits[:, -1, answer_token]
    return (err - no_err).mean()

# Disable gradients for all parameters
for param in model.parameters():
   param.requires_grad_(False)

# # Compute logits for clean and corrupted samples
logits = model(clean_tokens)
clean_diff = logit_diff_fn(logits)

logits = model(corr_tokens)
corr_diff = logit_diff_fn(logits)

print(f"clean_diff: {clean_diff}")
print(f"corr_diff: {corr_diff}")

# # Cleanup
del logits
cleanup_cuda()

# # Define error type metric
def _err_type_metric(logits, clean_logit_diff, corr_logit_diff):
    patched_logit_diff = logit_diff_fn(logits)
    return (patched_logit_diff - corr_logit_diff) / (clean_logit_diff - corr_logit_diff)

err_metric_denoising = partial(_err_type_metric, clean_logit_diff=clean_diff, corr_logit_diff=corr_diff)

clean_diff: -1.2145028114318848
corr_diff: -3.7032835483551025


# IoU

In [14]:
simple = [ 3237,  3271,  8156,  8233, 11219, 10711,  9585,  8535,  5795, 13756,
         3925,   430, 10160,  6290, 11774,  9654,  3377, 13275, 13163,   111,
         8311, 15795, 13862, 10350,  7892, 13218, 10618, 14598, 12223,  4766,
         5577,  3395, 11725, 14495, 15571,  7482,  9768,  9960, 13842,  2782,
         5804,  6840, 11779, 15430,  5498, 11712, 14638, 13839, 10587,  7675,
         3807,  4804,  2405, 14331,  6855,  8696,   691, 10110,  1662, 12189,
         4337,  1407,  8180, 12151,  7950, 14786,  9190, 12440,  9858, 10916,
           37, 13831,  2187,  9304, 14499,  3818, 10219,  2120,   940, 12739,
        15805,  5615,  7988,  7237,  6098,  2239,  6606,  2314,  4022, 14229,
         6646, 10719,  7381,  8680,  1436, 10023, 14613,  3863, 15777, 12609]
nested = [ 3237, 11219,  9585,  3271,  8535,  8233,  8156,  6290, 10711, 13756,
         3395, 13163, 13275,  4766, 15430,  9858,  6855, 10160, 11774,   430,
        14495, 15795, 13862,  2405, 15571,  2782, 10618, 13218,  4022, 12223,
         8311,  3925,  6606,  3807, 14499,  8696,  1662,  7950,  4804,  9654,
         9768, 11725, 13842, 12189,  5804,    37,  5795,  5870, 16071,  7237,
         6840, 14229,   691, 13831, 11083,  4337,  2244,  5577, 15777, 12739,
        10350, 12440, 16045,  1815, 14605, 10219, 13839,  3573, 14638,  8498,
         5394, 11095, 13359, 14943,  9960, 14598,  6602, 10719,  8180,  9662,
         4314,  7159, 14077, 11779, 12151, 11712, 12531,  3841,  5498,  1620,
        11112,  5615, 10110,   524, 11609, 10023,  6098,  5875,  6503,  7482]

def calculate_iou(list1, list2):
    """
    Calculate the Intersection over Union (IoU) of two lists of integers.

    Args:
        list1 (list): First list of integers.
        list2 (list): Second list of integers.

    Returns:
        float: IoU value.
    """
    # Convert lists to sets
    set1 = set(list1)
    set2 = set(list2)
    
    # Calculate intersection and union
    intersection = set1.intersection(set2)
    union = set1.union(set2)
    
    # Compute IoU
    iou = len(intersection) / len(union) if len(union) > 0 else 0
    return iou

iou = calculate_iou(simple, nested)
print(f"IoU: {iou}")

IoU: 0.5873015873015873
