In [1]:
import torch

from sklearn.neural_network import MLPRegressor
from sklearn.linear_model import LinearRegression
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

from transformer_lens import HookedTransformer, HookedTransformerConfig
import transformer_lens.utils as utils

from tree_generation import *
from utils import *
from interp_utils import *
from probing import *
from sparse_coding import *

ModuleNotFoundError: No module named 'torch'

### Load Model

In [None]:
# Dataset configuration. n_states is also passed to generate_example() in later
# cells; the hand-written prompts in this notebook use node labels 0..15.
n_examples = 300_000
n_states = 16

# NOTE(review): presumably GraphDataset reads (or creates) "dataset.txt" -- confirm
# in tree_generation / utils, which are star-imported above.
dataset = GraphDataset(n_states, "dataset.txt", n_examples)
dataset.visualize_example(0)

In [None]:
# Single-head, 6-layer transformer sized from the dataset's vocabulary and
# maximum sequence length.
cfg = HookedTransformerConfig(
    n_layers=6,
    d_model=128,
    # One shorter than the longest sequence; the cells below always feed
    # tokenize(...)[0:-1] to the model.
    n_ctx=dataset.max_seq_length - 1,
    n_heads=1,
    d_mlp=512,
    d_head=128,
    #attn_only=True,
    d_vocab=len(dataset.idx2tokens),
    # The rest of the notebook calls .cuda() on every token tensor, so a GPU is required.
    device="cuda",
    attention_dir= "causal",
    act_fn="gelu",
)
model = HookedTransformer(cfg)


# Load saved weights if "model.pt" is in the working directory.
# NOTE(review): there is no else branch -- if the file is missing, nothing is
# trained here and every later cell runs on the freshly initialised model.
if os.path.exists("model.pt"):
    model.load_state_dict(torch.load("model.pt"))

In [None]:
import random

# Draw a seed from NumPy's global RNG (unseeded, so the example differs per run)
# and generate one example with order="backward".
random_seed = np.random.randint(1_000_000, 1_000_000_000)
pred = generate_example(n_states, random_seed, order="backward")
# The example is only drawn when is_model_correct(...) is truthy ...
if is_model_correct(model, dataset, pred):
    parse_example(pred)
# ... but the cache and labels are collected either way.
labels,cache = get_example_cache(pred, model, dataset)
# Prefix every label with its position index: "N<i>_<label>".
labels= [f'N{i}_{component}' for i, component in enumerate(labels)]

In [None]:
# Call display_head once per (layer, head) pair for the example cached above.
# The returned figure is rebound on each iteration and not used afterwards.
for l in range(model.cfg.n_layers):
    for h in range(model.cfg.n_heads):
        fig = display_head(cache, labels, l, h, show=True)


In [None]:
import random

# NOTE(review): random_seed is not used in this cell -- pred is a hard-coded
# prompt. Later cells assign random_seed again before using it.
random_seed = np.random.randint(1_000_000, 1_000_000_000)
pred = "14>15,0>2,7>14,5>6,5>0,4>7,12>13,8>5,8>4,9>12,9>8,1>3,10>9,10>1,11>10|6:11>10>9>8>5>6"
# The example is only drawn when is_model_correct(...) is truthy; the cache and
# labels below are collected either way.
if is_model_correct(model, dataset, pred):
    parse_example(pred)
labels,cache = get_example_cache(pred, model, dataset)
# Prefix every label with its position index: "N<i>_<label>".
labels= [f'N{i}_{component}' for i, component in enumerate(labels)]

In [None]:

# Call display_head once per (layer, head) pair for the hard-coded example cached above.
for l in range(model.cfg.n_layers):
    for h in range(model.cfg.n_heads):
        fig = display_head(cache, labels, l, h, show=True)

In [None]:

# Clean prompt: a single chain 0>1>...>15 listed in forward edge order.
clean_prompt = "0>1,1>2,2>3,3>4,4>5,5>6,6>7,7>8,8>9,9>10,10>11,11>12,12>13,13>14,14>15|15:0>1>2>3>4>5>6>7>8>9>10>11>12>13>14>15"

# Corrupted prompt: the edges 5>6,6>7 are replaced by 4>6,5>7, so the listed
# path goes 5>7 and skips node 6.
corrupted_prompt = "0>1,1>2,2>3,3>4,4>5,4>6,5>7,7>8,8>9,9>10,10>11,11>12,12>13,13>14,14>15|15:0>1>2>3>4>5>7>8>9>10>11>12>13>14>15"

parse_example(clean_prompt)
plt.show()
parse_example(corrupted_prompt)
plt.show()
# Tokenize; the final token is dropped before the sequence is moved to the GPU.
clean_tokens = torch.from_numpy(dataset.tokenize(clean_prompt)[0:-1]).cuda()
corrupted_tokens = torch.from_numpy(dataset.tokenize(corrupted_prompt)[0:-1]).cuda()


# Same two graphs with the edge list written in reverse order.
clean_prompt_backwards = "14>15,13>14,12>13,11>12,10>11,9>10,8>9,7>8,6>7,5>6,4>5,3>4,2>3,1>2,0>1|15:0>1>2>3>4>5>6>7>8>9>10>11>12>13>14>15"

corrupted_prompt_backwards = "14>15,13>14,12>13,11>12,10>11,9>10,8>9,7>8,5>7,4>6,4>5,3>4,2>3,1>2,0>1|15:0>1>2>3>4>5>7>8>9>10>11>12>13>14>15"

parse_example(clean_prompt_backwards)
plt.show()
parse_example(corrupted_prompt_backwards)
plt.show()
# Tokenize; the final token is dropped before the sequence is moved to the GPU.
clean_tokens_backwards  = torch.from_numpy(dataset.tokenize(clean_prompt_backwards )[0:-1]).cuda()
corrupted_tokens_backwards  = torch.from_numpy(dataset.tokenize(corrupted_prompt_backwards )[0:-1]).cuda()

# 46 + 7 is the comparison index handed to activation_patching.
# NOTE(review): presumably this is the sequence position where the clean and
# corrupted paths diverge (6 vs 7) -- confirm against the tokenizer's layout.
patching_result = activation_patching(model, dataset, clean_tokens, corrupted_tokens, 46 + 7 )
patching_result_backwards  = activation_patching(model, dataset, clean_tokens_backwards , corrupted_tokens_backwards , 46 + 7 )

In [None]:
plot_activations(patching_result,clean_tokens,dataset)

In [None]:
plot_activations(patching_result_backwards,clean_tokens_backwards,dataset)

In [None]:

def activation_patching_register(model, dataset, clean_tokens, corrupted_tokens, comparison_index,positions):
    """Patch a fixed set of positions into the corrupted run, one layer at a time.

    For every layer, the model is run on ``corrupted_tokens`` while the
    ``resid_pre`` activation of that layer is overwritten, at all ``positions``
    at once, with the activation cached from the run on ``clean_tokens``.

    Args:
        model: Hooked transformer exposing ``run_with_cache``, ``run_with_hooks``
            and ``cfg.n_layers`` / ``cfg.device``.
        dataset: Object whose ``idx2tokens`` maps token ids to printable tokens
            (only used for the printed summary).
        clean_tokens: Token ids of the clean prompt.
        corrupted_tokens: Token ids of the corrupted prompt.
        comparison_index: Sequence index forwarded to ``logits_to_logit_diff``;
            the tokens of both prompts at this index are printed as the
            positive / negative direction.
        positions: Iterable of sequence positions to patch together.

    Returns:
        1-D tensor of length ``model.cfg.n_layers`` on ``model.cfg.device`` with
        ``(patched - corrupted) / (clean - corrupted)`` logit differences; the
        denominator is replaced by 1 when it is exactly zero.
    """
    # We run on the clean prompt with the cache so we store activations to patch in later.
    clean_logits, clean_cache = model.run_with_cache(clean_tokens)
    clean_logit_diff = logits_to_logit_diff(clean_tokens, corrupted_tokens, clean_logits, comparison_index)
    print(f"Clean logit difference: {clean_logit_diff.item():.3f}")

    # We don't need to cache on the corrupted prompt.
    corrupted_logits = model(corrupted_tokens)
    corrupted_logit_diff = logits_to_logit_diff(clean_tokens, corrupted_tokens, corrupted_logits, comparison_index)
    print(f"Corrupted logit difference: {corrupted_logit_diff.item():.3f}")
    print(f"Positive Direction: {dataset.idx2tokens[clean_tokens[comparison_index]]}")
    print(f"Negative Direction: {dataset.idx2tokens[corrupted_tokens[comparison_index]]}")

    def residual_stream_patching_hook(resid_pre, hook, positions):
        # Each HookPoint has a name attribute giving the name of the hook, which
        # doubles as the key of the matching clean activation in the cache.
        clean_resid_pre = clean_cache[hook.name]
        for position in positions:
            resid_pre[:, position, :] = clean_resid_pre[:, position, :]
        return resid_pre

    # Neither the normaliser nor the hook depends on the layer, so build them once.
    normalize_ratio = (clean_logit_diff - corrupted_logit_diff)
    if normalize_ratio == 0:
        # Avoid dividing by zero when the two prompts already agree.
        normalize_ratio = 1
    temp_hook_fn = partial(residual_stream_patching_hook, positions=positions)

    # One result per layer, kept on the model's device to avoid GPU<->CPU transfers.
    patching_result = torch.zeros((model.cfg.n_layers), device=model.cfg.device)
    for layer in tqdm_auto.tqdm(range(model.cfg.n_layers)):
        # Run the corrupted prompt with this layer's resid_pre patched.
        patched_logits = model.run_with_hooks(corrupted_tokens, fwd_hooks=[
            (tl_util.get_act_name("resid_pre", layer), temp_hook_fn)
        ])
        patched_logit_diff = logits_to_logit_diff(clean_tokens, corrupted_tokens, patched_logits, comparison_index).detach()
        # Normalise so that 0 ~ corrupted behaviour and 1 ~ clean behaviour.
        patching_result[layer] = (patched_logit_diff - corrupted_logit_diff) / normalize_ratio
    return patching_result


def plot_activations(patching_result, clean_tokens, dataset):
    """Show ``patching_result`` via ``imshow`` with one x-axis label per clean token.

    Each label is ``"<token>_<index>"``; the index suffix keeps labels unique,
    because plotly doesn't like duplicate labels.
    """
    token_labels = []
    for index, token in enumerate(clean_tokens):
        token_text = dataset.idx2tokens[token]
        token_labels.append(f"{token_text}_{index}")
    imshow(
        patching_result,
        x=token_labels,
        xaxis="Position",
        yaxis="Layer",
        title="Activation patching",
    )

In [None]:

def activation_patching_position_and_layers(model, dataset, clean_tokens, corrupted_tokens, comparison_index,positions,layers):
    """Patch a set of positions at a set of layers in one single forward pass.

    The model is run once on ``corrupted_tokens`` while, for every layer in
    ``layers``, the ``resid_pre`` activation at all ``positions`` is overwritten
    with the activation cached from the run on ``clean_tokens``.

    Args:
        model: Hooked transformer exposing ``run_with_cache`` and ``run_with_hooks``.
        dataset: Object whose ``idx2tokens`` maps token ids to printable tokens
            (only used for the printed summary).
        clean_tokens: Token ids of the clean prompt.
        corrupted_tokens: Token ids of the corrupted prompt.
        comparison_index: Sequence index forwarded to ``logits_to_logit_diff``.
        positions: Iterable of sequence positions to patch.
        layers: Iterable of layer indices whose ``resid_pre`` is patched.

    Returns:
        ``(patched - corrupted) / (clean - corrupted)`` logit difference as
        produced by the tensor arithmetic below; the denominator is replaced by
        1 when it is exactly zero.
    """
    # We run on the clean prompt with the cache so we store activations to patch in later.
    clean_logits, clean_cache = model.run_with_cache(clean_tokens)
    clean_logit_diff = logits_to_logit_diff(clean_tokens, corrupted_tokens, clean_logits, comparison_index)
    print(f"Clean logit difference: {clean_logit_diff.item():.3f}")

    # We don't need to cache on the corrupted prompt.
    corrupted_logits = model(corrupted_tokens)
    corrupted_logit_diff = logits_to_logit_diff(clean_tokens, corrupted_tokens, corrupted_logits, comparison_index)
    print(f"Corrupted logit difference: {corrupted_logit_diff.item():.3f}")
    print(f"Positive Direction: {dataset.idx2tokens[clean_tokens[comparison_index]]}")
    print(f"Negative Direction: {dataset.idx2tokens[corrupted_tokens[comparison_index]]}")

    def residual_stream_patching_hook(resid_pre, hook, positions):
        # Each HookPoint has a name attribute giving the name of the hook, which
        # doubles as the key of the matching clean activation in the cache.
        clean_resid_pre = clean_cache[hook.name]
        for position in positions:
            resid_pre[:, position, :] = clean_resid_pre[:, position, :]
        return resid_pre

    # The same hook function is attached to resid_pre of every requested layer.
    temp_hook_fn = partial(residual_stream_patching_hook, positions=positions)
    hooks = [(tl_util.get_act_name("resid_pre", layer), temp_hook_fn) for layer in layers]

    # Single run of the corrupted prompt with all hooks active at once.
    patched_logits = model.run_with_hooks(corrupted_tokens, fwd_hooks=hooks)
    patched_logit_diff = logits_to_logit_diff(clean_tokens, corrupted_tokens, patched_logits, comparison_index).detach()

    # Normalise so that 0 ~ corrupted behaviour and 1 ~ clean behaviour.
    normalize_ratio = (clean_logit_diff - corrupted_logit_diff)
    if normalize_ratio == 0:
        # Avoid dividing by zero when the two prompts already agree.
        normalize_ratio = 1
    return (patched_logit_diff - corrupted_logit_diff) / normalize_ratio

In [None]:
model.reset_hooks()

In [None]:
# Re-tokenize the forward prompts (final token dropped) so this cell does not
# depend on clean_tokens / corrupted_tokens left over from an earlier cell.
clean_tokens = torch.from_numpy(dataset.tokenize(clean_prompt)[0:-1]).cuda()
corrupted_tokens = torch.from_numpy(dataset.tokenize(corrupted_prompt)[0:-1]).cuda()
# Patch positions 36,38,39,41,42,44,45 together, one layer at a time.
# NOTE: this rebinds patching_result / patching_result_backwards, which an
# earlier cell set to the outputs of activation_patching(...).
patching_result = activation_patching_register(model, dataset, clean_tokens, corrupted_tokens, 46 + 7 ,[36,38,39,41,42,44,45])
patching_result_backwards = activation_patching_register(model, dataset, clean_tokens_backwards, corrupted_tokens_backwards, 46 + 7 ,[36,38,39,41,42,44,45])


In [None]:
imshow([patching_result])

In [None]:
imshow([patching_result_backwards])

In [None]:
patching_result_backwards_layers=activation_patching_position_and_layers(model, dataset, clean_tokens_backwards, corrupted_tokens_backwards, 46 + 7 ,[36,38,39,41,42,44,45],[0,1,2,3,4,5])#[36,38,39,41,42,44,45],[0,1,2,3,4,5])
patching_result_backwards_layers

In [None]:
patching_result_backwards_layers=activation_patching_position_and_layers(model, dataset, clean_tokens_backwards, corrupted_tokens_backwards, 46 + 7 ,[36,38,39,41,42,44,45,46,47],[0,1,2,3,4,5])#[36,38,39,41,42,44,45],[0,1,2,3,4,5])
patching_result_backwards_layers

In [None]:
def activation_patching_layers(model, dataset, clean_tokens, corrupted_tokens, comparison_index,layers):
    """Patch one position at a time, at all given layers simultaneously.

    For every sequence position, the model is run on ``corrupted_tokens`` while
    the ``resid_pre`` activation at that position is overwritten, in every layer
    of ``layers``, with the activation cached from the run on ``clean_tokens``.

    Args:
        model: Hooked transformer exposing ``run_with_cache``, ``run_with_hooks``
            and ``cfg.device``.
        dataset: Object whose ``idx2tokens`` maps token ids to printable tokens
            (only used for the printed summary).
        clean_tokens: Token ids of the clean prompt; ``shape[0]`` is used as the
            number of positions.
        corrupted_tokens: Token ids of the corrupted prompt.
        comparison_index: Sequence index forwarded to ``logits_to_logit_diff``.
        layers: Iterable of layer indices whose ``resid_pre`` is patched.

    Returns:
        1-D tensor with one entry per position on ``model.cfg.device`` holding
        ``(patched - corrupted) / (clean - corrupted)`` logit differences; the
        denominator is replaced by 1 when it is exactly zero.
    """
    # We run on the clean prompt with the cache so we store activations to patch in later.
    clean_logits, clean_cache = model.run_with_cache(clean_tokens)
    clean_logit_diff = logits_to_logit_diff(clean_tokens, corrupted_tokens, clean_logits, comparison_index)
    print(f"Clean logit difference: {clean_logit_diff.item():.3f}")

    # We don't need to cache on the corrupted prompt.
    corrupted_logits = model(corrupted_tokens)
    corrupted_logit_diff = logits_to_logit_diff(clean_tokens, corrupted_tokens, corrupted_logits, comparison_index)
    print(f"Corrupted logit difference: {corrupted_logit_diff.item():.3f}")
    print(f"Positive Direction: {dataset.idx2tokens[clean_tokens[comparison_index]]}")
    print(f"Negative Direction: {dataset.idx2tokens[corrupted_tokens[comparison_index]]}")

    def residual_stream_patching_hook(resid_pre, hook, position):
        # Each HookPoint has a name attribute giving the name of the hook, which
        # doubles as the key of the matching clean activation in the cache.
        clean_resid_pre = clean_cache[hook.name]
        resid_pre[:, position, :] = clean_resid_pre[:, position, :]
        return resid_pre

    # Neither the normaliser nor the hook names depend on the position.
    normalize_ratio = (clean_logit_diff - corrupted_logit_diff)
    if normalize_ratio == 0:
        # Avoid dividing by zero when the two prompts already agree.
        normalize_ratio = 1
    hook_names = [tl_util.get_act_name("resid_pre", layer) for layer in layers]

    # One result per position, kept on the model's device to avoid GPU<->CPU transfers.
    num_positions = clean_tokens.shape[0]
    patching_result = torch.zeros(num_positions, device=model.cfg.device)
    for position in range(num_positions):
        # Use functools.partial to create a temporary hook function with the position fixed
        temp_hook_fn = partial(residual_stream_patching_hook, position=position)
        hooks = [(hook_name, temp_hook_fn) for hook_name in hook_names]
        patched_logits = model.run_with_hooks(corrupted_tokens, fwd_hooks=hooks)
        patched_logit_diff = logits_to_logit_diff(clean_tokens, corrupted_tokens, patched_logits, comparison_index).detach()
        # Normalise so that 0 ~ corrupted behaviour and 1 ~ clean behaviour.
        patching_result[position] = (patched_logit_diff - corrupted_logit_diff) / normalize_ratio
    return patching_result

In [None]:
def delete_non_paths(input_dict):
    """Drop, in place, every entry whose value has two or fewer elements.

    The same dictionary object is returned for convenience.
    """
    # Iterate over a snapshot of the keys so the dict can shrink while we scan it.
    for key in list(input_dict):
        if len(input_dict[key]) <= 2:
            input_dict.pop(key)
    return input_dict

special_chars = [",", ":", "|"]
def get_paths(cache, labels, threshold=0.6, layers=None):
    """Collect, per query position, the tokens preceding strongly attended tokens.

    For every layer in ``layers`` the attention pattern stored under
    ``blocks.<layer>.attn.hook_pattern`` is scanned (batch 0, head 0). Whenever
    the attention from ``current_pos`` to ``attended_pos`` exceeds ``threshold``
    and neither the attended label nor the label just before it is a separator
    from ``special_chars`` (after stripping ``">"``), the preceding label is
    recorded under the key ``(current_pos, labels[current_pos])``. The first hit
    for a key stores ``[attended, previous]``; later hits append ``previous``.

    Args:
        cache: Mapping from hook name to a 4-D attention pattern indexed as
            ``[0, 0, query_pos, key_pos]``.
        labels: One label per sequence position.
        threshold: Minimum attention value (exclusive) for a hit.
        layers: Layer indices to scan; defaults to ``range(1, 6)``, the range
            that was previously hard-coded.

    Returns:
        Dict of the collected lists, keeping only those with more than two entries.
    """
    if layers is None:
        layers = range(1, 6)

    paths = {}
    for layer in layers:
        attn_pattern = cache[f"blocks.{layer}.attn.hook_pattern"]
        _, _, seq_len, _ = attn_pattern.shape

        for current_pos in range(seq_len):
            current_token = labels[current_pos]
            # Start at 1: position 0 has no predecessor, and labels[0 - 1] would
            # silently wrap around to the last label and record a bogus step.
            for attended_pos in range(1, seq_len):
                attn_value = attn_pattern[0, 0, current_pos, attended_pos]
                if not attn_value > threshold:
                    continue
                attended_token = labels[attended_pos].replace(">", "")
                previous_token = labels[attended_pos - 1].replace(">", "")
                if attended_token in special_chars or previous_token in special_chars:
                    continue
                identifier = (current_pos, current_token)
                if identifier in paths:
                    paths[identifier].append(previous_token)
                else:
                    paths[identifier] = [attended_token, previous_token]

    # Same filter as delete_non_paths: a list of two or fewer entries is not a path.
    return {identifier: tokens for identifier, tokens in paths.items() if len(tokens) > 2}

In [2]:
# Clean / corrupted prompt pair used by the "_test" cells below. The two
# commented-out pairs are earlier variants of the same experiment.
#clean_prompt_test="8>14,8>9,5>11,5>2,10>8,10>5,1>4,6>15,6>10,0>1,12>6,12>0,3>13,7>12,7>3|11:7>12>6>10>5>11"
#corrupted_prompt_test="8>14,8>9,5>11,5>2,10>8,10>5,1>4,6>15,6>10,0>1,3>6,3>0,12>13,7>3,7>12|11:7>3>6>10>5>11"
#clean_prompt_test="15>14,8>9,5>11,5>2,10>8,10>5,1>4,7>15,6>10,0>1,12>6,12>0,3>13,7>12,7>3|11:7>12>6>10>5>11"
#corrupted_prompt_test="15>14,8>9,5>11,5>2,10>8,10>5,1>4,7>15,6>10,0>1,3>6,3>0,12>13,7>3,7>12|11:7>3>6>10>5>11"
# In the active pair the listed path goes through 12 in the clean prompt
# (7>12>6) and through 3 in the corrupted one (7>3>6).
clean_prompt_test="15>14,7>9,5>11,9>2,10>8,10>5,1>4,7>15,6>10,0>1,12>6,12>0,3>13,7>12,7>3|11:7>12>6>10>5>11"
corrupted_prompt_test="15>14,7>9,5>11,9>2,10>8,10>5,1>4,7>15,6>10,0>1,3>6,3>0,12>13,7>3,7>12|11:7>3>6>10>5>11"

In [3]:
parse_example(clean_prompt_test)
plt.show()
parse_example(corrupted_prompt_test)
plt.show()

NameError: name 'parse_example' is not defined

In [None]:
# Tokenize the test prompts (final token dropped). These tensors are not used
# in this cell; the following cell rebuilds them identically.
clean_tokens_backwards_test = torch.from_numpy(dataset.tokenize(clean_prompt_test)[0:-1]).cuda()
corrupted_tokens_backwards_test= torch.from_numpy(dataset.tokenize(corrupted_prompt_test)[0:-1]).cuda()
# Attention-derived sub-paths for the clean prompt ...
labels, cache = get_example_cache(clean_prompt_test, model, dataset)
subpaths_clean = get_paths(cache, labels)
print(f'subpaths clean:{subpaths_clean}')
# ... and for the corrupted prompt. labels / cache are rebound here, so after
# this cell they refer to the corrupted prompt.
labels, cache = get_example_cache(corrupted_prompt_test, model, dataset)
subpaths_corrupted = get_paths(cache, labels)
print(f'subpaths corrupted:{subpaths_corrupted}')

In [None]:
# Tokenize the test prompts (final token dropped) and move them to the GPU.
clean_tokens_backwards_test = torch.from_numpy(dataset.tokenize(clean_prompt_test)[0:-1]).cuda()
corrupted_tokens_backwards_test= torch.from_numpy(dataset.tokenize(corrupted_prompt_test)[0:-1]).cuda()
# Patch positions 36,38,39,41,42,44,45 together, one layer at a time, comparing at index 46 + 2.
register_patching_result_backwards_test = activation_patching_register(model, dataset, clean_tokens_backwards_test, corrupted_tokens_backwards_test, 46 + 2 ,[36,38,39,41,42,44,45])

In [None]:
imshow(model(clean_tokens_backwards_test)[0])

In [None]:
imshow(model(corrupted_tokens_backwards_test)[0])

In [None]:
# Manual patching experiment: cache the clean run, then re-run the corrupted
# prompt with resid_pre of layers 3 and 4 overwritten at the chosen positions.
model.reset_hooks()
clean_logits, clean_cache = model.run_with_cache(clean_tokens_backwards_test)
def residual_stream_patching_hook(
        resid_pre,
        hook,
        positions):
        # Each HookPoint has a name attribute giving the name of the hook; it is
        # used as the key into the clean cache defined in this cell.
        clean_resid_pre = clean_cache[hook.name]
        for position in positions:
            resid_pre[:, position, :] = clean_resid_pre[:, position, :]
        return resid_pre

# Only position 38 is patched here; the trailing comment keeps the full list for reference.
temp_hook = partial(residual_stream_patching_hook, positions=[38])#[36,38,39,41,42,44,45])
# Run the model with the patching hook
patched_logits = model.run_with_hooks(corrupted_tokens_backwards_test, fwd_hooks=[
            (tl_util.get_act_name("resid_pre", 3), temp_hook),(tl_util.get_act_name("resid_pre", 4), temp_hook)
        ])
model.reset_hooks()

In [None]:
imshow(patched_logits[0])

In [None]:
imshow(torch.softmax(patched_logits[0],1)[47:49])

In [None]:
imshow(clean_logits[0][47:49])

In [None]:
imshow(torch.softmax(clean_logits[0],1)[47:49])

In [None]:
imshow((torch.softmax(clean_logits[0],1)-torch.softmax(patched_logits[0],1))[47:49])

In [None]:
register_patching_result_backwards_test = activation_patching_register(model, dataset, clean_tokens_backwards_test, corrupted_tokens_backwards_test, 46 + 2 ,[36,38,39,41,42,44,45])

In [None]:
imshow([register_patching_result_backwards_test])

In [None]:
patching_result_backwards_test = activation_patching(model, dataset, clean_tokens_backwards_test, corrupted_tokens_backwards_test, 46 + 2)

In [None]:
plot_activations(patching_result_backwards_test, clean_tokens_backwards_test, dataset)


In [None]:
# Patch one position at a time across all six layers of the backwards prompt pair.
patching_result_backwards_layers=activation_patching_layers(model, dataset, clean_tokens_backwards, corrupted_tokens_backwards, 46 + 7 ,[0,1,2,3,4,5])#[36,38,39,41,42,44,45],[0,1,2,3,4,5])
# Label the x-axis with the tokens of the prompt that was actually patched (the
# backwards one); the forward prompt's tokens would mislabel the edge positions.
plot_activations([patching_result_backwards_layers], clean_tokens_backwards, dataset)

# Patching Result

In [None]:
def replace_nodes(graph,n1,n2):
    """Swap the labels ``n1`` and ``n2`` in a graph prompt string.

    A prompt looks like ``"a>b,c>d|goal:start>x>y>goal"``. Labels are swapped
    where they appear as

    * the source of an edge (at the very start of the string or after ``,``), and
    * a node preceded by ``>`` and followed by ``,``, ``>`` or ``|`` (edge
      targets and the interior nodes of the path).

    The goal after ``|``, the first node of the path after ``:`` and the last
    node of the path are left untouched, as before. The swap is done in a
    single pass over whole labels, so no placeholder character is needed and
    the first edge of the list is handled like any other edge.

    Args:
        graph: Prompt string to relabel.
        n1: First node label (anything ``str()`` renders as the label).
        n2: Second node label.

    Returns:
        The relabelled prompt string.
    """
    import re

    first, second = str(n1), str(n2)
    swap = {first: second, second: first}
    # A label is a maximal run of non-separator characters; the lookarounds
    # restrict the match to the positions listed in the docstring.
    node_pattern = re.compile(r"(?<=>)[^,>|:]+(?=[,>|])|(?:^|(?<=,))[^,>|:]+(?=>)")
    return node_pattern.sub(lambda match: swap.get(match.group(), match.group()), graph)

In [None]:
# Generate a fresh random example with order="backward" (unseeded, so it differs per run).
random_seed = np.random.randint(1_000_000, 1_000_000_000)
graph = generate_example(n_states, random_seed, order="backward")
# Everything after ":" is the path; split it into nodes and drop the first one.
full_path = graph.split(":")[1].split(">")[1:]  # we ignore the first position, might need to reconsider this at some point


In [None]:
'7>2,10>7,14>10,8>14,0>8,6>0,9>6,11>9,3>11,12>3,1>12,15>1,5>13,4>15,4>5|2:4>15>1>12>3>11>9>6>0>8>14>10>7>2'

In [None]:
parse_example(graph)

In [None]:
corrupted_graph= replace_nodes(graph,13,10)

In [None]:
corrupted_graph

In [None]:
parse_example(graph)
plt.show()
parse_example(corrupted_graph)
plt.show()

In [None]:
# Compare the clean graph against its node-swapped counterpart at index 46 + 5.
model.reset_hooks()
position=46 + 5
clean_graph_tokens = torch.from_numpy(dataset.tokenize(graph)[0:-1]).cuda()
clean_logits, clean_cache = model.run_with_cache(clean_graph_tokens)
# Print the model's argmax predictions from sequence index 47 onwards.
print( dataset.untokenize(np.argmax(clean_logits.detach().cpu(),2)[0][47:]))
corrupted_graph_tokens= torch.from_numpy(dataset.tokenize(corrupted_graph)[0:-1]).cuda()
corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_graph_tokens)
# Patching via activation_patching, then all "register" positions at once per layer.
activation_patching_result=activation_patching(model, dataset, clean_graph_tokens,corrupted_graph_tokens , position)#46 + 2
register_pathcing_result= activation_patching_register(model, dataset, clean_graph_tokens,corrupted_graph_tokens , position,[36,38,39,41,42,44,45])




In [None]:
# Use the labels and cache returned together by get_example_cache, so the two
# always describe the same example and the cell does not depend on clean_cache
# left over from another cell.
labels, cache = get_example_cache(graph, model, dataset)
get_paths(cache,labels)

In [None]:
imshow([register_pathcing_result[1:]],x=list(range(1,model.cfg.n_layers)))

In [None]:
imshow([register_pathcing_result[1:]],x=list(range(1,model.cfg.n_layers)))

In [None]:
plot_activations(activation_patching_result,clean_graph_tokens, dataset)

In [None]:
plot_activations(activation_patching_result,clean_graph_tokens, dataset)

# Test leaf

In [None]:

# Two prompts that differ in a single edge: node 13 hangs off node 5 in
# graph_not_leaf ("5>13") and off node 4 in graph_leaf ("4>13").
graph_not_leaf='7>2,10>7,14>10,8>14,0>8,6>0,9>6,11>9,3>11,12>3,1>12,15>1,5>13,4>15,4>5|2:4>15>1>12>3>11>9>6>0>8>14>10>7>2'
graph_leaf='7>2,10>7,14>10,8>14,0>8,6>0,9>6,11>9,3>11,12>3,1>12,15>1,4>13,4>15,4>5|2:4>15>1>12>3>11>9>6>0>8>14>10>7>2'

# Corrupted versions: labels 15 and 5 exchanged by replace_nodes.
corrupted_graph_not_leaf= replace_nodes(graph_not_leaf,15,5)
corrupted_graph_leaf= replace_nodes(graph_leaf,15,5)


In [None]:
parse_example(graph_not_leaf)
plt.show()
parse_example(graph_leaf)
plt.show()

In [None]:
parse_example(corrupted_graph_not_leaf)
plt.show()
parse_example(corrupted_graph_leaf)
plt.show()

In [None]:
# Patching experiment for the "leaf" graph pair, compared at index 46 + 2.
model.reset_hooks()
position=46 + 2
clean_graph_leaf_tokens = torch.from_numpy(dataset.tokenize(graph_leaf)[0:-1]).cuda()
clean_logits_leaf, clean_cache_leaf = model.run_with_cache(clean_graph_leaf_tokens)
# Print the model's argmax predictions from sequence index 47 onwards.
print( dataset.untokenize(np.argmax(clean_logits_leaf.detach().cpu(),2)[0][47:]))
corrupted_graph_leaf_tokens= torch.from_numpy(dataset.tokenize(corrupted_graph_leaf)[0:-1]).cuda()
corrupted_logits_leaf, corrupted_cache_leaf = model.run_with_cache(corrupted_graph_leaf_tokens)
# Patching via activation_patching, then all "register" positions at once per layer.
activation_patching_result_leaf=activation_patching(model, dataset, clean_graph_leaf_tokens,corrupted_graph_leaf_tokens , position)#46 + 2
register_pathcing_result_leaf= activation_patching_register(model, dataset, clean_graph_leaf_tokens,corrupted_graph_leaf_tokens , position,[36,38,39,41,42,44,45])




In [None]:
plot_activations(activation_patching_result_leaf,clean_graph_leaf_tokens, dataset)

In [None]:
imshow([register_pathcing_result_leaf[1:]],x=list(range(1,model.cfg.n_layers)))

In [None]:
# Patching experiment for the "not leaf" graph pair, compared at index 46 + 2.
model.reset_hooks()
position=46 + 2  
clean_graph_not_leaf_tokens = torch.from_numpy(dataset.tokenize(graph_not_leaf)[0:-1]).cuda()
clean_logits_not_leaf, clean_cache_not_leaf = model.run_with_cache(clean_graph_not_leaf_tokens)
# Print the model's argmax predictions from sequence index 47 onwards.
print( dataset.untokenize(np.argmax(clean_logits_not_leaf.detach().cpu(),2)[0][47:]))
corrupted_graph_not_leaf_tokens= torch.from_numpy(dataset.tokenize(corrupted_graph_not_leaf)[0:-1]).cuda()
corrupted_logits_not_leaf, corrupted_cache_not_leaf = model.run_with_cache(corrupted_graph_not_leaf_tokens)
# Patching via activation_patching, then all "register" positions at once per layer.
activation_patching_result_not_leaf=activation_patching(model, dataset, clean_graph_not_leaf_tokens,corrupted_graph_not_leaf_tokens , position)#46 + 2
register_pathcing_result_not_leaf= activation_patching_register(model, dataset, clean_graph_not_leaf_tokens,corrupted_graph_not_leaf_tokens , position,[36,38,39,41,42,44,45])



In [None]:
# Label the axis with the not-leaf prompt's own tokens; the leaf prompt differs
# in one edge ("4>13" vs "5>13") and would mislabel that position.
plot_activations(activation_patching_result_not_leaf,clean_graph_not_leaf_tokens, dataset)

In [None]:
imshow([register_pathcing_result_not_leaf[1:]],x=list(range(1,model.cfg.n_layers)))

In [None]:
register_pathcing_result_not_leaf_test_46= activation_patching_register(model, dataset, clean_graph_not_leaf_tokens,corrupted_graph_not_leaf_tokens , position,[36,38,39,41,42,44,45,46])


In [None]:
imshow([register_pathcing_result_not_leaf_test_46[1:]],x=list(range(1,model.cfg.n_layers)))

### Test choices 

In [None]:
def activation_patching_logits(model, dataset, clean_tokens, corrupted_tokens, comparison_index):
    """Per-layer, per-position activation patching reported as raw logit-diff change.

    For every (layer, position) pair the model is run on ``corrupted_tokens``
    while the ``resid_pre`` activation of that layer at that position is
    overwritten with the activation cached from the run on ``clean_tokens``.
    Unlike the normalised variants, the stored value is the unnormalised
    difference ``patched_logit_diff - corrupted_logit_diff``.

    Args:
        model: Hooked transformer exposing ``run_with_cache``, ``run_with_hooks``
            and ``cfg.n_layers`` / ``cfg.device``.
        dataset: Object whose ``idx2tokens`` maps token ids to printable tokens
            (only used for the printed summary).
        clean_tokens: Token ids of the clean prompt; ``shape[0]`` is used as the
            number of positions.
        corrupted_tokens: Token ids of the corrupted prompt.
        comparison_index: Sequence index forwarded to ``logits_to_logit_diff``.

    Returns:
        Tensor of shape ``(n_layers, num_positions)`` on ``model.cfg.device``.
    """
    # We run on the clean prompt with the cache so we store activations to patch in later.
    clean_logits, clean_cache = model.run_with_cache(clean_tokens)
    # The clean logit difference is only reported; it does not enter the result.
    clean_logit_diff = logits_to_logit_diff(clean_tokens, corrupted_tokens, clean_logits, comparison_index)
    print(f"Clean logit difference: {clean_logit_diff.item():.3f}")

    # We don't need to cache on the corrupted prompt.
    corrupted_logits = model(corrupted_tokens)
    corrupted_logit_diff = logits_to_logit_diff(clean_tokens, corrupted_tokens, corrupted_logits, comparison_index)
    print(f"Corrupted logit difference: {corrupted_logit_diff.item():.3f}")
    print(f"Positive Direction: {dataset.idx2tokens[clean_tokens[comparison_index]]}")
    print(f"Negative Direction: {dataset.idx2tokens[corrupted_tokens[comparison_index]]}")

    def residual_stream_patching_hook(resid_pre, hook, position):
        # Each HookPoint has a name attribute giving the name of the hook, which
        # doubles as the key of the matching clean activation in the cache.
        clean_resid_pre = clean_cache[hook.name]
        resid_pre[:, position, :] = clean_resid_pre[:, position, :]
        return resid_pre

    # One result per (layer, position), kept on the model's device to avoid GPU<->CPU transfers.
    num_positions = clean_tokens.shape[0]
    patching_result = torch.zeros((model.cfg.n_layers, num_positions), device=model.cfg.device)
    for layer in tqdm_auto.tqdm(range(model.cfg.n_layers)):
        hook_name = tl_util.get_act_name("resid_pre", layer)
        for position in range(num_positions):
            # Use functools.partial to create a temporary hook function with the position fixed
            temp_hook_fn = partial(residual_stream_patching_hook, position=position)
            patched_logits = model.run_with_hooks(corrupted_tokens, fwd_hooks=[(hook_name, temp_hook_fn)])
            patched_logit_diff = logits_to_logit_diff(clean_tokens, corrupted_tokens, patched_logits, comparison_index).detach()
            # Raw change relative to the corrupted run (deliberately not normalised).
            patching_result[layer, position] = patched_logit_diff - corrupted_logit_diff
    return patching_result

In [None]:
# Prompt for the "choices" experiment; the commented-out pairs are earlier
# variants. In the active graph node 4 has three outgoing edges (4>13, 4>15, 4>5).
#graph_choices='7>2,10>7,8>10,6>8,9>6,11>9,3>11,12>3,1>12,13>14,15>1,4>13,4>15,5>0,4>5|2:4>15>1>12>3>11>9>6>8>10>7>2'
#corrupted_graph_choices= replace_nodes(graph_choices,15,5)
#graph_choices='7>2,10>7,8>10,6>8,9>6,11>9,3>11,12>3,1>12,13>14,15>1,0>13,4>15,5>0,4>5|2:4>15>1>12>3>11>9>6>8>10>7>2'
#corrupted_graph_choices= replace_nodes(graph_choices,15,5)
graph_choices='7>2,10>7,8>10,6>8,9>6,11>0,0>9,3>11,12>3,1>12,13>14,15>1,4>13,4>15,4>5|2:4>15>1>12>3>11>0>9>6>8>10>7>2'
# Corrupted version: labels 15 and 5 exchanged by replace_nodes.
corrupted_graph_choices= replace_nodes(graph_choices,15,5)

In [None]:
parse_example(graph_choices)
plt.show()

In [4]:
# Patching experiment for the "choices" graph pair, compared at index 46 + 2.
model.reset_hooks()
position=46 + 2  
clean_graph_choices_tokens = torch.from_numpy(dataset.tokenize(graph_choices)[0:-1]).cuda()
clean_logits_choices, clean_cache_choices = model.run_with_cache(clean_graph_choices_tokens)
# Print the model's argmax predictions from sequence index 47 onwards.
print( dataset.untokenize(np.argmax(clean_logits_choices.detach().cpu(),2)[0][47:]))
corrupted_graph_choices_tokens= torch.from_numpy(dataset.tokenize(corrupted_graph_choices)[0:-1]).cuda()
corrupted_logits_choices, corrupted_cache_choices = model.run_with_cache(corrupted_graph_choices_tokens)
# Three views of the same pair: activation_patching, the unnormalised variant
# defined in this notebook, and the per-layer patch of all "register" positions.
activation_patching_result_choices=activation_patching(model, dataset, clean_graph_choices_tokens,corrupted_graph_choices_tokens , position)#46 + 2
activation_patching_result_choices_logits=activation_patching_logits(model, dataset, clean_graph_choices_tokens,corrupted_graph_choices_tokens , position)#46 + 2
register_pathcing_result_choices= activation_patching_register(model, dataset, clean_graph_choices_tokens,corrupted_graph_choices_tokens , position,[36,38,39,41,42,44,45])

NameError: name 'model' is not defined

In [None]:
imshow(clean_logits_choices[0][47:48],x=dataset.idx2tokens)

In [None]:
imshow(corrupted_logits_choices[0][47:48],x=dataset.idx2tokens)


In [None]:
labels, cache = get_example_cache(graph_choices, model, dataset)
get_paths(cache,labels)

In [None]:
parse_example(graph_choices)
plt.show()

In [None]:
plot_activations(activation_patching_result_choices,clean_graph_choices_tokens, dataset)

In [None]:
plot_activations(activation_patching_result_choices_logits,clean_graph_choices_tokens, dataset)

In [None]:
imshow([register_pathcing_result_choices[1:]],x=list(range(1,model.cfg.n_layers)))

In [None]:
parse_example(graph_choices)
plt.show()

# Test example

In [None]:
# Generate a fresh random example with order="backward" (unseeded, so it differs
# per run). This rebinds graph and full_path from the earlier cells.
random_seed = np.random.randint(1_000_000, 1_000_000_000)
graph = generate_example(n_states, random_seed, order="backward")
# Everything after ":" is the path; split it into nodes and drop the first one.
full_path = graph.split(":")[1].split(">")[1:]  # we ignore the first position, might need to reconsider this at some point


In [None]:
# Compare the freshly generated graph against corrupted_graph at index 46 + 5.
# NOTE(review): corrupted_graph is not recomputed in this section; it still
# holds the value built from the *previous* graph by replace_nodes(graph,13,10)
# in an earlier cell, so clean and corrupted prompts may be unrelated here.
# Presumably replace_nodes should be re-run on the new graph first -- confirm.
model.reset_hooks()
position=46 + 5
clean_graph_tokens = torch.from_numpy(dataset.tokenize(graph)[0:-1]).cuda()
clean_logits, clean_cache = model.run_with_cache(clean_graph_tokens)
# Print the model's argmax predictions from sequence index 47 onwards.
print( dataset.untokenize(np.argmax(clean_logits.detach().cpu(),2)[0][47:]))
corrupted_graph_tokens= torch.from_numpy(dataset.tokenize(corrupted_graph)[0:-1]).cuda()
corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_graph_tokens)
# Patching via activation_patching, then all "register" positions at once per layer.
activation_patching_result=activation_patching(model, dataset, clean_graph_tokens,corrupted_graph_tokens , position)#46 + 2
register_pathcing_result= activation_patching_register(model, dataset, clean_graph_tokens,corrupted_graph_tokens , position,[36,38,39,41,42,44,45])


