In [1]:
# Prerequisite (run once in the kernel's environment):
#   pip install git+https://github.com/Phylliida/MambaLens.git

from mamba_lens import HookedMamba  # note: this import takes a little while

# Identifier of the pretrained checkpoint handed to `from_pretrained`.
model_path = "state-spaces/mamba-370m"
# Load the hooked model directly onto the GPU.
model = HookedMamba.from_pretrained(model_path, device="cuda")

  return self.fget.__get__(instance, owner)()
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Moving model to device:  cuda


In [2]:
# Prerequisites:
#   1. Install this repo in editable mode (run from the repo's root directory):
#        pip install -e .
#   2. Install graphviz:
#        sudo apt-get update
#        sudo apt-get install graphviz xdg-utils

from acdc.data.ioi import ioi_data_generator, ABC_TEMPLATES, get_all_single_name_abc_patching_formats
from acdc.data.utils import generate_dataset

# Dataset settings, forwarded unchanged to `generate_dataset` below.
num_patching_pairs = 50
seed = 27
valid_seed = 28
constrain_to_answers = True
has_symmetric_patching = True

# Prompt templates and patching formats supplied by the IOI data module.
templates = ABC_TEMPLATES
patching_formats = [*get_all_single_name_abc_patching_formats()]

# Build the dataset for the model loaded earlier in the notebook.
data = generate_dataset(
    model=model,
    data_generator=ioi_data_generator,
    num_patching_pairs=num_patching_pairs,
    seed=seed,
    valid_seed=valid_seed,
    constrain_to_answers=constrain_to_answers,
    has_symmetric_patching=has_symmetric_patching,
    varying_data_lengths=True,
    templates=templates,
    patching_formats=patching_formats,
)

names ['Aaron', 'Abraham', 'Adam', 'Adrian', 'Alan', 'Alexander', 'Alexandria', 'Allison', 'Amanda', 'Amber', 'Amy', 'Ana', 'Andrea', 'Andrew', 'Angela', 'Anna', 'Anthony', 'Antonio', 'Ashley', 'Austin', 'Bailey', 'Benjamin', 'Blake', 'Bradley', 'Brady', 'Brandon', 'Brian', 'Brooklyn', 'Bryan', 'Cameron', 'Carlos', 'Caroline', 'Carson', 'Carter', 'Catherine', 'Charles', 'Charlotte', 'Chase', 'Chelsea', 'Chloe', 'Christian', 'Christina', 'Christopher', 'Claire', 'Cody', 'Cole', 'Colin', 'Connor', 'Cooper', 'Dakota', 'Dalton', 'Daniel', 'David', 'Derek', 'Devon', 'Diana', 'Diego', 'Donovan', 'Dylan', 'Edgar', 'Edward', 'Edwin', 'Elizabeth', 'Emily', 'Emmanuel', 'Eric', 'Erik', 'Erin', 'Ethan', 'Evan', 'Faith', 'Fernando', 'Francisco', 'Gabriel', 'Garrett', 'Gavin', 'Genesis', 'George', 'Giovanni', 'Grace', 'Grant', 'Gregory', 'Hannah', 'Henry', 'Hunter', 'Ian', 'Isaac', 'Isabel', 'Isaiah', 'Ivan', 'Jackson', 'Jacob', 'Jake', 'James', 'Jared', 'Jason', 'Jeffrey', 'Jennifer', 'Jeremy', 'Je

In [4]:
import torch
from acdc import get_pad_token

print("printing example data points:")
# The pad token is the same for every example, so look it up once
# instead of once per loop iteration.
pad_token = get_pad_token(model.tokenizer)
# Show at most 10 examples without indexing past the end of a smaller dataset.
num_examples = min(10, data.data.size(0))
for b in range(num_examples):
    # because there is padding if lengths vary, this only fetches the tokens that are part of the sequence
    toks = data.data[b][:data.last_token_position[b]+1]
    print(model.tokenizer.decode(toks))
    # Both answer tables are printed the same way; entries equal to the pad
    # token are skipped.
    for label, answers in (("correct", data.correct), ("incorrect", data.incorrect)):
        for tok in answers[b]:
            tok_id = tok.item()
            if tok_id != pad_token:
                print(f"  {label} answer: {repr(model.tokenizer.decode([tok_id]))}")

from acdc import accuracy_metric
from acdc import ACDCEvalData
from acdc import get_pad_token
def logging_incorrect_metric(data: ACDCEvalData):
    """Print every data point whose top prediction was wrong, then return the metric.

    Walks the `patched` and `corrupted` subsets of `data`. For each row whose
    `top_is_correct` flag is false, prints the decoded prompt followed by the
    probability and decoded text of each non-pad correct and incorrect answer.

    Args:
        data: evaluation data exposing `patched` and `corrupted` subsets.

    Returns:
        `data.patched.top_is_correct`, unchanged.
    """
    tokenizer = model.tokenizer
    pad_token = get_pad_token(tokenizer)
    for subset in (data.patched, data.corrupted):
        num_rows, _ = subset.data.size()
        for row in range(num_rows):
            if subset.top_is_correct[row].item():
                continue  # prediction was right; nothing to report
            # Keep only the tokens up to and including the last real position.
            prompt_toks = subset.data[row][:subset.last_token_position[row] + 1]
            print("failed on this data point:")
            print(tokenizer.decode(prompt_toks))
            for header, answers_name, prs_name in (
                ("correct prs:", "correct", "correct_prs"),
                ("incorrect prs:", "incorrect", "incorrect_prs"),
            ):
                print(header)
                for col, tok in enumerate(getattr(subset, answers_name)[row]):
                    tok_id = tok.item()
                    if tok_id != pad_token:
                        pr = getattr(subset, prs_name)[row, col].item()
                        print(pr, tokenizer.decode([tok_id]))
    return data.patched.top_is_correct

# Evaluate the unmodified model; the metric also prints every failed data point.
top_is_correct = data.eval(model=model, batch_size=10, metric=logging_incorrect_metric)
# Fraction of entries flagged correct.
num_correct = top_is_correct.sum().item()
num_evaluated = top_is_correct.size()[0]
accuracy = num_correct / num_evaluated
print(f"accuracy: {accuracy}")


printing example data points:
<|endoftext|>Friends Donovan, Diana and Brian went to the house. Donovan and Diana gave a apple to
  correct answer: ' Brian'
  incorrect answer: ' Diana'
  incorrect answer: ' Donovan'
  incorrect answer: ' Lily'
<|endoftext|>Friends Donovan, Diana and Lily went to the house. Donovan and Diana gave a apple to
  correct answer: ' Lily'
  incorrect answer: ' Brian'
  incorrect answer: ' Diana'
  incorrect answer: ' Donovan'
<|endoftext|>When Henry, Catherine and Jordan arrived at the office, Henry and Jordan gave a apple to
  correct answer: ' Catherine'
  incorrect answer: ' Henry'
  incorrect answer: ' Jordan'
  incorrect answer: ' Owen'
<|endoftext|>When Henry, Owen and Jordan arrived at the office, Henry and Jordan gave a apple to
  correct answer: ' Owen'
  incorrect answer: ' Catherine'
  incorrect answer: ' Henry'
  incorrect answer: ' Jordan'
<|endoftext|>When Miranda, Edgar and Edwin arrived at the office, Edgar and Edwin gave a kiss to
  correct a

In [5]:
from transformer_lens.hook_points import HookPoint
from acdc import Edge, ACDCConfig, LOG_LEVEL_INFO, LOG_LEVEL_DEBUG, run_acdc

# Module-level activation cache shared by the hooks, keyed by hook-point name.
storage = {}


def storage_hook(
    x,
    hook: HookPoint,
    **kwargs,
):
    """Cache the activation seen at `hook` and pass it through unchanged.

    Args:
        x: the activation handed to the hook.
        hook: the hook point; its `name` is used as the cache key.
        **kwargs: accepted and ignored.

    Returns:
        `x`, the very same object that was passed in.
    """
    # Item assignment mutates the existing dict, so no `global` statement is needed.
    # NOTE(review): `x` is stored by reference, not copied -- an in-place edit of
    # the same tensor elsewhere would also change the cached value; confirm intended.
    storage[hook.name] = x
    return x

def resid_patching_hook(
    x,
    hook: HookPoint,
    input_hook_name: str,
    batch_start: int,
    batch_end: int,
):
    """Swap the cached uncorrupted contribution for the corrupted one, in place.

    Within `[batch_start, batch_end)`, every second row starting at
    `batch_start` has the cached activation of the same row subtracted and the
    cached activation of the following row added.
    (Presumably the batch interleaves uncorrupted/corrupted examples -- confirm
    against the data layout.)

    Args:
        x: activation at the hook point; modified in place.
        hook: the hook point (unused).
        input_hook_name: key into the module-level `storage` cache.
        batch_start: first row of the slice to patch.
        batch_end: end (exclusive) of the slice to patch.

    Returns:
        `x`, the same tensor object, after patching.
    """
    cached = storage[input_hook_name]
    uncorrupted_rows = slice(batch_start, batch_end, 2)
    corrupted_rows = slice(batch_start + 1, batch_end, 2)
    # Same evaluation order as (x - uncorrupted) + corrupted, written back in place.
    x[uncorrupted_rows] = x[uncorrupted_rows] - cached[uncorrupted_rows] + cached[corrupted_rows]
    return x

# One graph node per model block.
layers = [*range(model.cfg.n_layers)]

## Setup edges for ACDC
edges = []

# First dimension is the batch; second is presumably the (padded) sequence length.
B, L = data.data.size()

# Graph source: the embedding hook.
INPUT_HOOK = 'hook_embed'
INPUT_NODE = 'input'

# Graph sink: the residual stream after the last block.
last_layer = max(layers)
OUTPUT_HOOK = f'blocks.{last_layer}.hook_resid_post'
OUTPUT_NODE = 'output'

def layer_node(layer):
    """Return the graph node name for `layer`: its string form."""
    return str(layer)

# direct connection from embed to output
edges.append(
    Edge(
        input_node=INPUT_NODE,
        input_hook=(INPUT_HOOK, storage_hook),
        output_node=OUTPUT_NODE,
        output_hook=(OUTPUT_HOOK, resid_patching_hook),
    )
)

for layer in layers:
    this_node = layer_node(layer)
    layer_input_hook_name = f'blocks.{layer}.hook_layer_input'

    # edge from embed to layer input
    edges.append(
        Edge(
            input_node=INPUT_NODE,
            input_hook=(INPUT_HOOK, storage_hook),
            output_node=this_node,
            output_hook=(layer_input_hook_name, resid_patching_hook),
        )
    )

    # one edge from every earlier layer's output into this layer's input
    edges.extend(
        Edge(
            input_node=layer_node(earlier),
            input_hook=(f'blocks.{earlier}.hook_out_proj', storage_hook),
            output_node=this_node,
            output_hook=(layer_input_hook_name, resid_patching_hook),
        )
        for earlier in layers
        if earlier < layer
    )

    # edge from layer output to final layer output
    edges.append(
        Edge(
            input_node=this_node,
            input_hook=(f'blocks.{layer}.hook_out_proj', storage_hook),
            output_node=OUTPUT_NODE,
            output_hook=(OUTPUT_HOOK, resid_patching_hook),
        )
    )

# Keyword arguments handed to the config as `model_kwargs`.
model_kwargs = dict(fast_ssm=True, fast_conv=True)

# ACDC run configuration.
cfg = ACDCConfig(
    thresh=1e-5,
    rollback_thresh=1e-5,
    metric=accuracy_metric,
    model_kwargs=model_kwargs,
    input_node=INPUT_NODE,
    output_node=OUTPUT_NODE,
    auto_hide_unused_default_edges=True,
    batch_size=3,
    log_level=LOG_LEVEL_INFO,
    batched=True,
    recursive=True,
    try_patching_multiple_at_same_time=True,
)

# Run ACDC over the edge list built above.
result_edges = run_acdc(model=model, data=data, cfg=cfg, edges=edges)




TypeError: 'str' object is not callable