In [1]:
import os
import sys
from typing import Any
from einops import rearrange
from fancy_einsum import einsum
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformer_lens import HookedTransformer
from mamba_lens import HookedMamba
import os
from tqdm import tqdm

torch.set_grad_enabled(False)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

model = HookedMamba.from_pretrained("state-spaces/mamba-130m", device=device)

tokenizer = model.tokenizer
vocab = tokenizer.vocab

  from .autonotebook import tqdm as notebook_tqdm


Using cuda device
Moving model to device:  cuda


In [18]:
batch_size = 4
N = 128 // batch_size
seq_len = 1 + 128

specific_patch_layer = 17

names_filter = []
for layer in range(model.cfg.n_layers):
    names_filter.append(f'blocks.{layer}.hook_h.{0}')
    names_filter.append(f'blocks.{layer}.hook_h.{seq_len - 1}')
    names_filter.append(f'blocks.{layer}.hook_A_bar')
    names_filter.append(f'blocks.{layer}.hook_B_bar')
    names_filter.append(f'blocks.{layer}.hook_ssm_input')
    
    
patch_names_filter = []
for layer in [specific_patch_layer]:
    patch_names_filter.append(f'blocks.{layer}.hook_A_bar')
    patch_names_filter.append(f'blocks.{layer}.hook_B_bar')
    patch_names_filter.append(f'blocks.{layer}.hook_ssm_input')
    # patch_names_filter.append(f'blocks.{layer}.hook_h.{0}')
    # patch_names_filter.append(f'blocks.{layer}.hook_h.{2}')
    # patch_names_filter.append(f'blocks.{layer}.hook_h.{3}')
    # patch_names_filter.append(f'blocks.{layer}.hook_h.{4}')

sum_real_logit_new = torch.zeros([model.cfg.d_conv]).to(device)
sum_fake_logit_new = torch.zeros([model.cfg.d_conv]).to(device)
sum_real_logit_orig = torch.zeros([model.cfg.d_conv]).to(device)
sum_fake_logit_orig = torch.zeros([model.cfg.d_conv]).to(device)

for batch in tqdm(range(N)):
    random_input = torch.randint(0, len(vocab), size=(batch_size, seq_len)).to(device)
    first_A_index = 0
    B_index = 1
    second_A_index = seq_len - 1

    A_values = random_input[:, first_A_index]
    B_values = random_input[:, B_index]
    patch_B_values = torch.randint(0, len(vocab), size=(batch_size,)).to(device)
    induction_input = random_input.clone()
    induction_input[:, second_A_index] = A_values
    patch_input = induction_input.clone()
    patch_input[:, B_index] = patch_B_values
    
    logits, cache = model.run_with_cache(induction_input, names_filter=names_filter, fast_ssm=False, fast_conv=True, warn_disabled_hooks=False)
    _, patch_cache = model.run_with_cache(patch_input, names_filter=patch_names_filter, fast_ssm=False, fast_conv=True, warn_disabled_hooks=False)
    real_logit_orig = logits[:, second_A_index][torch.arange(logits.shape[0]), B_values]
    fake_logit_orig = logits[:, second_A_index][torch.arange(logits.shape[0]), patch_B_values]
    sum_real_logit_orig += real_logit_orig.sum()
    sum_fake_logit_orig += fake_logit_orig.sum()
    def generate_replacement_hook(layer, patch_layer, cache, patch_cache, pos):
        if layer == patch_layer:
            def replacement_hook(activations: torch.Tensor, hook: Any):
                activations = torch.zeros_like(cache[f'blocks.{layer}.hook_h.{0}'])
                for p in range(0, seq_len):
                    if p != pos:
                        activations = activations * cache[f'blocks.{layer}.hook_A_bar'][:,p,:,:] + cache[f'blocks.{layer}.hook_B_bar'][:,p,:,:] * cache[f'blocks.{layer}.hook_ssm_input'][:,p].view(batch_size, model.cfg.d_inner, 1)
                    else:
                        activations = activations * patch_cache[f'blocks.{layer}.hook_A_bar'][:,p+1,:,:] + patch_cache[f'blocks.{layer}.hook_B_bar'][:,p,:,:] * patch_cache[f'blocks.{layer}.hook_ssm_input'][:,p].view(batch_size, model.cfg.d_inner, 1)
                return activations
        else:
            def replacement_hook(activations: torch.Tensor, hook: Any):
                activations = cache[f'blocks.{layer}.hook_h.{seq_len - 1}']
                return activations
        return replacement_hook
    
    for pos in range(1, 5):
        fwd_hooks = []
        for l in range(model.cfg.n_layers):
            hook_name = f'blocks.{l}.hook_h.{seq_len - 1}'
            fwd_hooks.append((hook_name, generate_replacement_hook(l, patch_layer=specific_patch_layer, cache=cache, patch_cache=patch_cache, pos=pos)))
        logits = model.run_with_hooks(induction_input, fwd_hooks=fwd_hooks, fast_ssm=False, fast_conv=True, warn_disabled_hooks=False)
        real_logit_new = logits[:, second_A_index][torch.arange(logits.shape[0]), B_values]
        fake_logit_new = logits[:, second_A_index][torch.arange(logits.shape[0]), patch_B_values]
        sum_real_logit_new[pos - 1] += real_logit_new.sum()
        sum_fake_logit_new[pos - 1] += fake_logit_new.sum()

mean_real_logit_new = sum_real_logit_new / (N * batch_size)
mean_fake_logit_new = sum_fake_logit_new / (N * batch_size)
mean_real_logit_orig = sum_real_logit_orig / (N * batch_size)
mean_fake_logit_orig = sum_fake_logit_orig / (N * batch_size)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [01:22<00:00,  2.59s/it]


In [19]:
rounded_diff = [0.0] + [round(x.item(), 2) for x in mean_real_logit_new - mean_real_logit_orig]
print(rounded_diff)

[0.0, -0.01, -9.13, 0.52, 0.16]
