In [10]:
import transformer_lens
import sae_lens
from transformer_lens import HookedTransformer
from sae_lens import SAE
import torch
from utils.ioi_dataset import IOIDataset
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device: ", device)

Using device:  cuda


In [8]:
model = HookedTransformer.from_pretrained('gpt2-small', device=device)

Loaded pretrained model gpt2-small into HookedTransformer


In [7]:
sae, _, _ = SAE.from_pretrained(release='gpt2-small-resid-post-v5-32k', sae_id="blocks.6.hook_resid_post", device=device)

In [11]:
ds_base= IOIDataset('mixed', N=200, tokenizer=model.tokenizer)
abc_dataset_base = (  # TODO seeded
    ds_base.gen_flipped_prompts(("IO", "RAND"))
    .gen_flipped_prompts(("S", "RAND"))
    .gen_flipped_prompts(("S1", "RAND"))
    )



In [25]:
# helper functions 
from jaxtyping import Bool, Float
from torch import Tensor
import gc

def logits_to_ave_logit_diff_2(
    logits: Float[Tensor, "batch seq d_vocab"],
    ioi_dataset: IOIDataset ,
    per_prompt=False
) -> Float[Tensor, "*batch"]:
    '''
    Returns logit difference between the correct and incorrect answer.

    If per_prompt=True, return the array of differences rather than the average.
    '''

    # Only the final logits are relevant for the answer
    # Get the logits corresponding to the indirect object / subject tokens respectively
    io_logits: Float[Tensor, "batch"] = logits[range(logits.size(0)), ioi_dataset.word_idx["end"], ioi_dataset.io_tokenIDs]
    s_logits: Float[Tensor, "batch"] = logits[range(logits.size(0)), ioi_dataset.word_idx["end"], ioi_dataset.s_tokenIDs]
    # Find logit difference
    answer_logit_diff = io_logits - s_logits
    return answer_logit_diff if per_prompt else answer_logit_diff.mean()

def cleanup_cuda():
   torch.cuda.empty_cache()
   gc.collect()

def run_with_saes_cache(tokens, model, sae):
    model.reset_hooks()
    sae_outs = {}
    hook_point = sae.cfg.hook_name

    # Define the filtered hook function (optimized)
    def filtered_hook(act, hook, sae=sae):
        enc_sae = sae.encode(act) 
        sae_outs[hook.name] = enc_sae.detach().cpu() 
        modified_act = sae.decode(enc_sae) 
        return modified_act

    model.add_hook(hook_point, filtered_hook, dir='fwd')
    # Run the model with the tokens (no gradients needed)
    with torch.no_grad():
        logits = model(tokens)

    # Reset the hooks after computation to free memory
    model.reset_hooks()
    return logits, sae_outs 

def run_with_patched_saes(tokens, model, sae, cache, feature_idx):
    model.reset_hooks()
    hook_point = sae.cfg.hook_name

    # Define the filtered hook function (optimized)
    def filtered_hook(act, hook, sae=sae):
        enc_sae = sae.encode(act) 
        enc_sae[:, :, feature_idx] = cache[hook.name][:, :, feature_idx]
        modified_act = sae.decode(enc_sae) 
        return modified_act

    model.add_hook(hook_point, filtered_hook, dir='fwd')
    # Run the model with the tokens (no gradients needed)
    with torch.no_grad():
        logits = model(tokens)

    # Reset the hooks after computation to free memory
    model.reset_hooks()
    return logits


In [23]:
# clean cache 
cleanup_cuda()
logits, sae_outs = run_with_saes_cache(ds_base.toks, model, sae)
clean_logit_diff = logits_to_ave_logit_diff_2(logits, ds_base)
print(clean_logit_diff)
del logits
cleanup_cuda()

tensor(3.4980, device='cuda:0')


In [29]:
# corr cache 
cleanup_cuda()
logits, sae_outs_corr = run_with_saes_cache(abc_dataset_base.toks, model, sae)
corr_logit_diff = logits_to_ave_logit_diff_2(logits, ds_base)
print(corr_logit_diff)
del logits
cleanup_cuda()

tensor(0.0873, device='cuda:0')


In [35]:
# patching clean to corr - denoising
cleanup_cuda()
from tqdm import tqdm
total_steps = sae.cfg.d_sae
denoising_results = {}
with tqdm(total=total_steps, desc="Denoising Progress") as pbar:
    for latent in range(sae.cfg.d_sae):
        
        filtered_ids = [model.tokenizer.bos_token_id] 
        logits = run_with_patched_saes(abc_dataset_base.toks, model, sae, sae_outs, latent)
        patched_err_metric = logits_to_ave_logit_diff_2(logits, ds_base)
        normalized_metric = (patched_err_metric - corr_logit_diff) / (clean_logit_diff - corr_logit_diff)
        if normalized_metric>0.1:
            print(f"Latent: {latent}")
            print(f"Error Metric: {normalized_metric}")
        denoising_results[latent] = normalized_metric.detach().cpu().item()
        pbar.update(1)

Denoising Progress:  20%|█▉        | 6444/32768 [11:39<47:37,  9.21it/s]


KeyboardInterrupt: 

In [36]:
# find the max index and value of denoising results
max_idx = max(denoising_results, key=denoising_results.get)
max_val = denoising_results[max_idx]
print(f"Max Index: {max_idx}, Max Value: {max_val}")

Max Index: 3527, Max Value: 0.004557418636977673


In [19]:
sae_outs['blocks.6.hook_resid_post'].shape

torch.Size([200, 21, 32768])

In [37]:
denoising_results

{0: 0.0,
 1: 0.0,
 2: 0.0,
 3: 0.0,
 4: 0.0,
 5: 0.0,
 6: 0.0,
 7: 0.0,
 8: 0.0,
 9: 0.0,
 10: 0.0,
 11: 0.0,
 12: 0.0,
 13: 0.0,
 14: -7.901399840193335e-06,
 15: -5.395758080339874e-07,
 16: 0.0,
 17: 0.0,
 18: 0.0,
 19: 0.0,
 20: 0.0,
 21: -8.454082376374572e-07,
 22: 0.0,
 23: 0.0,
 24: 0.0,
 25: 0.0,
 26: -2.6978789264830993e-06,
 27: 0.0,
 28: 0.0,
 29: 0.0,
 30: 0.0,
 31: 0.0,
 32: 0.0,
 33: 0.0,
 34: 0.0,
 35: 0.0,
 36: 0.0,
 37: 0.0,
 38: 0.0,
 39: 0.0,
 40: 0.0,
 41: 0.0,
 42: 0.00014889889280311763,
 43: 0.0,
 44: 0.0,
 45: 0.0,
 46: 0.0,
 47: 0.0,
 48: 0.0,
 49: 0.0,
 50: -0.00012352790508884937,
 51: 0.0,
 52: 0.0,
 53: 0.0,
 54: 0.0,
 55: 0.0,
 56: 0.0,
 57: 0.0,
 58: 0.0,
 59: 0.0,
 60: -1.0656075573933776e-05,
 61: 0.0,
 62: 0.0,
 63: 0.0,
 64: -0.00039657947490923107,
 65: 0.0,
 66: 0.0,
 67: 0.0,
 68: 0.0,
 69: 0.0,
 70: 0.0,
 71: 0.0,
 72: 0.0,
 73: 0.0,
 74: 0.0,
 75: 0.0,
 76: 0.0,
 77: 0.0,
 78: 0.0,
 79: 0.0,
 80: 0.0,
 81: -1.184008397103753e-06,
 82: 0.0,
 83: 