In [24]:
%load_ext autoreload
%autoreload 2
import pickle
import torch
from transformer_lens import HookedTransformerConfig, HookedTransformer
from transformer_lens import HookedTransformer
from circuits_benchmark.utils.get_cases import get_cases

# InterpBench case to analyse; change this to study a different task.
TASK_INDEX = '11'
task = get_cases(indices=[TASK_INDEX])[0]
task_idx = task.get_index()

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [25]:
dir_name = f"../InterpBench/{task_idx}"
# The low-level model config is stored as a pickle; only unpickle files from a trusted source.
# `with` closes the file handle, which the previous `pickle.load(open(...))` leaked.
with open(f"{dir_name}/ll_model_cfg.pkl", "rb") as cfg_file:
    cfg_dict = pickle.load(cfg_file)
cfg = HookedTransformerConfig.from_dict(cfg_dict)
cfg.device = "cuda" if torch.cuda.is_available() else "cpu"
model = HookedTransformer(cfg)
weights = torch.load(f"{dir_name}/ll_model.pth", map_location=cfg.device)
model.load_state_dict(weights)

<All keys matched successfully>

In [26]:
# Inference only: put the model in eval mode, freeze all of its parameters,
# then disable autograd globally for the rest of the notebook.
model.eval().requires_grad_(False)
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x301fb2050>

In [27]:
# load high level model
from circuits_benchmark.utils.iit import make_iit_hl_model
import circuits_benchmark.utils.iit.correspondence as correspondence
import iit.model_pairs as mp

def make_model_pair(benchmark_case, ll_model=None):
    """Build a strict IIT model pair for ``benchmark_case``.

    The high-level side is the case's TransformerLens model wrapped by
    ``make_iit_hl_model`` in eval mode. The low-level side is ``ll_model``.

    Args:
        benchmark_case: benchmark case providing the HL model and tracr output.
        ll_model: low-level model to pair with. Defaults to the notebook-level
            ``model`` (looked up at call time), which matches the previous
            behaviour of always using that global.

    Returns:
        An ``mp.StrictIITModelPair`` linking the two models through the tracr
        correspondence built from the case's tracr output.
    """
    if ll_model is None:
        ll_model = model
    hl_model = make_iit_hl_model(benchmark_case.build_transformer_lens_model(), eval_mode=True)
    hl_ll_corr = correspondence.TracrCorrespondence.from_output(
        case=benchmark_case, tracr_output=benchmark_case.get_tracr_output()
    )
    return mp.StrictIITModelPair(hl_model, ll_model, hl_ll_corr)

In [28]:
max_len: int = 100  # passed as max_len to get_unique_data when building the test set

In [29]:
from circuits_benchmark.utils.iit.dataset import get_unique_data

# Build the HL/LL model pair for this task, then the test dataset returned by
# get_unique_data (bounded by the max_len set in the config cell).
model_pair = make_model_pair(task)
unique_test_data = get_unique_data(task, max_len=max_len)

Moving model to device:  cpu
{'hook_embed': HookPoint(), 'hook_pos_embed': HookPoint(), 'blocks.0.attn.hook_k': HookPoint(), 'blocks.0.attn.hook_q': HookPoint(), 'blocks.0.attn.hook_v': HookPoint(), 'blocks.0.attn.hook_z': HookPoint(), 'blocks.0.attn.hook_attn_scores': HookPoint(), 'blocks.0.attn.hook_pattern': HookPoint(), 'blocks.0.attn.hook_result': HookPoint(), 'blocks.0.mlp.hook_pre': HookPoint(), 'blocks.0.mlp.hook_post': HookPoint(), 'blocks.0.hook_attn_in': HookPoint(), 'blocks.0.hook_q_input': HookPoint(), 'blocks.0.hook_k_input': HookPoint(), 'blocks.0.hook_v_input': HookPoint(), 'blocks.0.hook_mlp_in': HookPoint(), 'blocks.0.hook_attn_out': HookPoint(), 'blocks.0.hook_mlp_out': HookPoint(), 'blocks.0.hook_resid_pre': HookPoint(), 'blocks.0.hook_resid_mid': HookPoint(), 'blocks.0.hook_resid_post': HookPoint(), 'blocks.1.attn.hook_k': HookPoint(), 'blocks.1.attn.hook_q': HookPoint(), 'blocks.1.attn.hook_v': HookPoint(), 'blocks.1.attn.hook_z': HookPoint(), 'blocks.1.attn.hook_

In [30]:
def collate_fn(batch):
    """Encode the first field of every sample in ``batch`` for the HL model.

    Assumes each sample is a tuple whose first element is the tracr-format
    input (TODO confirm against ``get_unique_data``). The remaining fields are
    discarded. The inputs are mapped to TransformerLens input via
    ``model_pair.hl_model.map_tracr_input_to_tl_input``.
    """
    first_fields = list(zip(*batch))[0]
    return model_pair.hl_model.map_tracr_input_to_tl_input(first_fields)


loader = torch.utils.data.DataLoader(
    unique_test_data,
    batch_size=256,
    shuffle=False,
    drop_last=False,
    collate_fn=collate_fn,
)

### Get the mean activations, norm and variance 

In [8]:
import utils.node_stats as node_stats

In [9]:
from utils.node_stats import get_node_stats, node_stats_to_df

# Collect per-node activation statistics over the whole loader and tabulate them.
cache_dict = get_node_stats(model_pair, loader)
node_norms = node_stats_to_df(cache_dict)

  node_norms = pd.concat(


In [10]:
# Display the per-node activation-norm statistics table.
node_norms

Unnamed: 0,name,in_circuit,norm_cache,norm_std
0,"blocks.0.attn.hook_result, head 0",False,0.171328,0.035944
1,"blocks.0.attn.hook_result, head 1",False,0.140061,0.030674
2,"blocks.0.attn.hook_result, head 2",False,0.146607,0.01204
3,"blocks.0.attn.hook_result, head 3",False,0.077289,0.017604
6,"blocks.1.attn.hook_result, head 0",False,1.081345,0.097147
7,"blocks.1.attn.hook_result, head 1",False,0.359874,0.048454
8,"blocks.1.attn.hook_result, head 3",False,1.645881,0.053254
9,blocks.1.mlp.hook_post,False,1.258341,0.262288
4,blocks.0.mlp.hook_post,True,1.593136,0.106626
5,"blocks.1.attn.hook_result, head 2",True,4.03551,1.253511


In [11]:
import circuits_benchmark.commands.evaluation.iit.iit_eval as eval_node_effect

# Node-effect evaluation uses shorter sequences than the dataset built above.
# Use a dedicated name so the notebook-level `max_len` (used to build
# `unique_test_data`) is not silently overwritten.
EVAL_MAX_LEN = 50

args = eval_node_effect.setup_args_parser(None, True)
args.max_len = EVAL_MAX_LEN
# Rebuild the model pair before evaluating node effects.
model_pair = make_model_pair(task)
node_effects, eval_metrics = eval_node_effect.get_node_effects(
    case=task, model_pair=model_pair, args=args, use_mean_cache=False
)

Moving model to device:  cpu
{'hook_embed': HookPoint(), 'hook_pos_embed': HookPoint(), 'blocks.0.attn.hook_k': HookPoint(), 'blocks.0.attn.hook_q': HookPoint(), 'blocks.0.attn.hook_v': HookPoint(), 'blocks.0.attn.hook_z': HookPoint(), 'blocks.0.attn.hook_attn_scores': HookPoint(), 'blocks.0.attn.hook_pattern': HookPoint(), 'blocks.0.attn.hook_result': HookPoint(), 'blocks.0.mlp.hook_pre': HookPoint(), 'blocks.0.mlp.hook_post': HookPoint(), 'blocks.0.hook_attn_in': HookPoint(), 'blocks.0.hook_q_input': HookPoint(), 'blocks.0.hook_k_input': HookPoint(), 'blocks.0.hook_v_input': HookPoint(), 'blocks.0.hook_mlp_in': HookPoint(), 'blocks.0.hook_attn_out': HookPoint(), 'blocks.0.hook_mlp_out': HookPoint(), 'blocks.0.hook_resid_pre': HookPoint(), 'blocks.0.hook_resid_mid': HookPoint(), 'blocks.0.hook_resid_post': HookPoint(), 'blocks.1.attn.hook_k': HookPoint(), 'blocks.1.attn.hook_q': HookPoint(), 'blocks.1.attn.hook_v': HookPoint(), 'blocks.1.attn.hook_z': HookPoint(), 'blocks.1.attn.hook_

100%|██████████| 10/10 [00:00<00:00, 12.26it/s]
100%|██████████| 10/10 [00:00<00:00, 29.63it/s]
100%|██████████| 4/4 [00:00<00:00, 17.97it/s]
100%|██████████| 4/4 [00:00<00:00, 59.90it/s]


In [12]:
# Display per-node resample- and zero-ablation effects with circuit status.
node_effects

Unnamed: 0,node,status,resample_ablate_effect,zero_ablate_effect
0,"blocks.0.attn.hook_result, head 0",not_in_circuit,0.002666,1.0
1,"blocks.0.attn.hook_result, head 1",not_in_circuit,0.0,1.0
2,"blocks.0.attn.hook_result, head 2",not_in_circuit,0.0,0.758594
3,"blocks.0.attn.hook_result, head 3",not_in_circuit,0.0,1.0
4,"blocks.1.attn.hook_result, head 2",not_in_circuit,0.0,1.0
5,"blocks.1.attn.hook_result, head 3",not_in_circuit,0.103477,1.0
6,blocks.1.mlp.hook_post,not_in_circuit,0.053177,1.0
7,blocks.0.mlp.hook_post,in_circuit,1.0,0.999219
8,"blocks.1.attn.hook_result, head :2",in_circuit,1.0,0.977344


In [122]:
# combine node effects with node_norms
import pandas as pd
# Inner join: only nodes whose names match exactly in both tables are kept.
# validate="one_to_one" fails loudly if a node name is duplicated in either table.
combined_df = pd.merge(
    node_effects,
    node_norms,
    left_on="node",
    right_on="name",
    how="inner",
    validate="one_to_one",
).drop(columns=["name", "in_circuit"])

# Report nodes the inner join would otherwise drop silently (e.g. a head
# written as "head :2" in one table but "head 2" in the other).
dropped_effects = sorted(set(node_effects["node"]) - set(node_norms["name"]))
dropped_norms = sorted(set(node_norms["name"]) - set(node_effects["node"]))
if dropped_effects or dropped_norms:
    print(f"Not merged (only in node_effects): {dropped_effects}")
    print(f"Not merged (only in node_norms): {dropped_norms}")

combined_df

Unnamed: 0,node,status,resample_ablate_effect,zero_ablate_effect,norm_cache,norm_std
0,"blocks.0.attn.hook_result, head 0",not_in_circuit,0.002666,1.0,0.961356,0.048892
1,"blocks.0.attn.hook_result, head 1",not_in_circuit,0.0,1.0,0.608414,0.037886
2,"blocks.0.attn.hook_result, head 2",not_in_circuit,0.0,0.758594,0.651837,0.078307
3,"blocks.0.attn.hook_result, head 3",not_in_circuit,0.0,1.0,0.632466,0.023356
4,"blocks.1.attn.hook_result, head 2",not_in_circuit,0.0,1.0,1.031157,0.103031
5,"blocks.1.attn.hook_result, head 3",not_in_circuit,0.103477,1.0,1.609794,0.046769
6,blocks.1.mlp.hook_post,not_in_circuit,0.053177,1.0,3.403802,0.167944
7,blocks.0.mlp.hook_post,in_circuit,1.0,0.999219,3.991503,0.375723
8,"blocks.1.attn.hook_result, head :2",in_circuit,1.0,0.977344,2.15252,0.560006


In [127]:
import plotly.express as px

# Zero-ablation effect vs. activation norm per node, coloured by circuit membership.
axis_labels = {
    "zero_ablate_effect": "Zero Ablation Effect",
    "norm_cache": "Norm of Node Activation",
    "status": "",
    "resample_ablate_effect": "Resample Ablate Effect",
}
status_colors = {"in_circuit": "green", "not_in_circuit": "orange"}

fig = px.scatter(
    combined_df,
    x="zero_ablate_effect",
    y="norm_cache",
    error_y="norm_std",
    color="status",
    color_discrete_map=status_colors,
    labels=axis_labels,
    hover_data=["node", "resample_ablate_effect"],
    template="plotly_white",  # white background without the grey grid fill
)
# Tighter margins and a larger font for the exported figure.
fig.update_layout(margin=dict(l=70, r=70, t=70, b=70), font=dict(size=16))
fig.show()
fig.write_image(f"node_stats_{task.get_index()}.pdf")


### Do logit lens on all nodes in the model

The range of `0_mlp_out` differs across proportions; ideally, this shouldn't happen!

There is no node before it that calculates the fraction, so how does it do this?

1. Decompose resid for each layer and get mlp logit lens
2. For head results I can do stack_head_results

Once I have the activations, I can compare them with the following (this is SCORE):
1. logit diffs for classification
2. only true regression output (maybe MSE with label or something)

Tuned lens: I just need to take the activations and train a linear layer between them and the final-layer activations.
I have two choices:
1. train a map from hook points to pre unembed directly
2. train it on decomposed heads and resids

Then compute SCORE = unembed(LN(Linear(act))) vs logit diff/something

Pearson R coefficient:
1. Take the entire dataset, get SCORES for each prompt and calculate pearson R for all

Other experiments: 
1. I need to check whether or not the mean is orthogonal to the direction we write to, so I need the cosine similarity between activations. Combined with the fact that resample ablation has almost no effect, this would make sense.
2. I can resample after multiplying the node by 1e-3, 1e-2, ..., 10, 100, etc. and see its effect. I can also do this after running PCA on it, finding the subspace with maximum variance, and scaling that.

In [31]:
# Preprocess the low-level model's weights for logit lens.
# TransformerLens' state-dict helpers (center_writing_weights, center_unembed,
# refactor_factored_attn_matrices, fold_layer_norm) *return* a processed state
# dict instead of modifying the model, so calling them on `model.state_dict()`
# and discarding the result left the model unchanged. `process_weights_`
# applies these transforms to the model in place.
is_categorical = model_pair.hl_model.is_categorical()
has_layer_norm = model.cfg.normalization_type == "LN"
if not has_layer_norm:
    print("No layer norm to fold")
if is_categorical or has_layer_norm:
    # NOTE(review): process_weights_ also folds value biases into b_O by default,
    # which changes per-head `hook_result` activations — confirm this is acceptable
    # for the per-head logit lens below.
    model.process_weights_(
        fold_ln=has_layer_norm,
        # Centering residual-stream writes only preserves the model's function
        # when a LayerNorm follows, so only do it when LN is present.
        center_writing_weights=is_categorical and has_layer_norm,
        center_unembed=is_categorical,
        refactor_factored_attn_matrices=is_categorical,
    )

No layer norm to fold


In [36]:
import utils.logit_lens as logit_lens

# Run logit lens over the whole loader: per-node results keyed by node name
# (see the keys listed below) plus the true labels for the same inputs.
logit_lens_results, labels = logit_lens.do_logit_lens(model_pair, loader)

In [33]:
# Node names for which logit-lens results were computed.
logit_lens_results.keys()

dict_keys(['embed', 'pos_embed', '0_mlp_out', '1_mlp_out', 'L0H0', 'L0H1', 'L0H2', 'L0H3', 'L1H0', 'L1H1', 'L1H2', 'L1H3'])

In [66]:
from iit.utils.node_picker import get_all_individual_nodes_in_circuit

# Individual low-level nodes in the circuit according to the correspondence;
# converted to logit-lens key strings a few lines below.
nodes = get_all_individual_nodes_in_circuit(model, model_pair.corr)
def convert_ll_node_to_str(node: "mp.LLNode"):
    """Map a low-level node to the key used in ``logit_lens_results``.

    - attention nodes (``blocks.{L}.attn.*``): ``"L{L}H{head}"``, where the head
      is element 2 of ``node.index.as_index``
    - MLP nodes (``blocks.{L}.mlp.*``): ``"{L}_mlp_out"``
    - ``hook_embed`` / ``hook_pos_embed``: ``"embed"`` / ``"pos_embed"``

    Raises:
        ValueError: for any other node. Previously such nodes silently mapped to
            ``None``, which never matches a logit-lens key.
    """
    name = node.name
    if 'attn' in name:
        block = name.split('.')[1]
        head = node.index.as_index[2]
        return f"L{block}H{head}"
    if 'mlp' in name:
        block = name.split('.')[1]
        return f"{block}_mlp_out"
    if name == 'hook_pos_embed':
        return "pos_embed"
    if name == 'hook_embed':
        return "embed"
    raise ValueError(f"Cannot map low-level node {name!r} to a logit-lens key")

nodes = list(map(convert_ll_node_to_str, nodes))
nodes

['0_mlp_out', 'L1H0']

In [68]:
from scipy import stats
import plotly.graph_objects as go
# Node to inspect (e.g. "L1H2" for an attention head).
k = "0_mlp_out"
in_circuit_str = "in circuit" if k in nodes else "not in circuit"

# One scatter trace per position: logit-lens output vs. true label,
# with the Pearson correlation shown in the legend.
node_lens = logit_lens_results[k]
traces = []
for pos in range(node_lens.shape[1]):
    true_vals = labels[:, pos].squeeze().detach().cpu().numpy()
    lens_vals = node_lens[:, pos].detach().cpu().numpy()
    r = stats.pearsonr(lens_vals, true_vals)[0]
    traces.append(
        go.Scatter(x=lens_vals, y=true_vals, mode='markers', name=f"pos {pos}, corr: {r:.2f}")
    )

fig = go.Figure(data=traces)
fig.update_layout(
    title=f"Logit Lens Results for {k} ({in_circuit_str})",
    xaxis_title="Logit Lens Results",
    yaxis_title="True Logits",
)
fig.show()

In [71]:
if model_pair.hl_model.is_categorical():
    logit_lens_per_vocab, per_vocab_labels = logit_lens.do_logit_lens_per_vocab_idx(model_pair, loader)
    # Node and output-vocabulary index to inspect (e.g. k = "L1H2").
    k = "0_mlp_out"
    vocab_dim = 1

    node_vocab_lens = logit_lens_per_vocab[k][vocab_dim]
    vocab_labels = per_vocab_labels[vocab_dim]

    fig = go.Figure()
    # Loop over the position dimension of the per-vocab tensor actually being
    # plotted, rather than borrowing it from logit_lens_results[k].
    for i in range(node_vocab_lens.shape[1]):
        y = vocab_labels[:, i].squeeze().detach().cpu().numpy()
        x = node_vocab_lens[:, i].detach().cpu().numpy()
        pearson_corr = stats.pearsonr(x, y)
        fig.add_trace(go.Scatter(x=x, y=y, mode='markers', name=f"pos {i}, corr: {pearson_corr[0]:.2f}"))

    in_circuit_str = "in circuit" if k in nodes else "not in circuit"
    fig.update_layout(
        title=f"Logit Lens Results for {k}, vocab idx {vocab_dim} ({in_circuit_str})",
        yaxis_title="True Logits",
        xaxis_title="Logit Lens Results",
    )
    fig.show()

In [40]:
# In TransformerLens, W_U is [d_model, d_vocab_out], so its transpose is
# [d_vocab_out, d_model].
logit_diff_directions = model.unembed.W_U.T
# Broadcast to [batch, pos, d_vocab_out, d_model]; expand returns a view, not a copy.
# NOTE(review): batch/pos sizes are hard-coded rather than derived from the data.
batch_dims = 10
pos_dims = 8
logit_diff_directions = logit_diff_directions[None, None].expand(batch_dims, pos_dims, -1, -1)

In [42]:
# Sanity check of the broadcast direction tensor's shape.
logit_diff_directions.shape

torch.Size([10, 8, 5, 12])

In [20]:
import pandas as pd
import plotly.express as px
# Pearson correlation between each node's logit-lens output and the true labels,
# computed separately for every position.
labels_np = labels.detach().cpu().numpy()  # converted once, reused for every node
pearson_corrs = {}
for k, lens in logit_lens_results.items():
    lens_np = lens.detach().cpu().numpy()
    # Index [0] is the coefficient for both the tuple returned by older SciPy and
    # the PearsonRResult returned by newer versions (matches the cells above).
    # Constant inputs (e.g. pos_embed, whose range is 0) yield NaN with a warning.
    pearson_corrs[k] = {
        str(pos): stats.pearsonr(lens_np[:, pos], labels_np[:, pos])[0]
        for pos in range(lens_np.shape[1])
    }

pearson_corrs = pd.DataFrame(pearson_corrs)
px.imshow(
    pearson_corrs,
    color_continuous_scale="Viridis",
    labels=dict(y="Position", x="Layer/Head", color="Pearson Correlation"),
)


An input array is constant; the correlation coefficient is not defined.



In [32]:
import pandas as pd
# Spread of each node's logit-lens output at every position, measured as the
# range (max - min over the batch dimension) — not the standard deviation,
# despite the variable names.
stds = {
    k: tensor.max(dim=0).values.detach().cpu().numpy() - tensor.min(dim=0).values.detach().cpu().numpy()
    for k, tensor in logit_lens_results.items()
}

stds_df = pd.DataFrame(stds)
stds_df

Unnamed: 0,embed,pos_embed,0_mlp_out,1_mlp_out,L0H0,L0H1,L0H2,L0H3,L1H0,L1H1,L1H2,L1H3
0,0.141054,0.0,0.326435,0.047036,0.043156,0.017481,0.057509,0.017454,0.698715,0.064397,0.017532,0.029497
1,0.141054,0.0,0.389155,0.04467,0.050453,0.023787,0.100339,0.020985,0.696097,0.079091,0.038061,0.047075
2,0.141054,0.0,0.40722,0.058867,0.045557,0.026155,0.114847,0.023848,0.71045,0.089134,0.04184,0.050687
3,0.141054,0.0,0.371703,0.085974,0.056181,0.028328,0.132014,0.021679,0.717134,0.096526,0.052626,0.057512


### Rough

In [None]:
# per_layer_residual, layers = cache.decompose_resid(-1, mode="mlp", return_labels=True, pos_slice=slice(1, None, None ))
# per_head_residual, attns = cache.stack_head_results(
#     layer=-1, pos_slice=slice(1, None, None ), return_labels=True
# )

# logit_diff_directions = model.unembed.W_U.t().squeeze(0)

# per_layer_logit_diff = residual_stack_to_logit_diff(per_layer_residual, cache, logit_diff_directions)
# per_head_logit_diff = residual_stack_to_logit_diff(per_head_residual, cache, logit_diff_directions)

# mlp_loss = torch.nn.functional.mse_loss(per_layer_logit_diff, labels.squeeze()[:, 1:], reduction="none")
# attn_loss = torch.nn.functional.mse_loss(per_head_logit_diff, labels.squeeze()[:, 1:], reduction="none")
# # mean everything except dim 0
# mlp_loss = mlp_loss.mean(dim=[1, 2])
# attn_loss = attn_loss.mean(dim=[1, 2])

In [None]:
# results = list(zip(layers, mlp_loss)) + (list(zip(attns, attn_loss)))
# results

In [None]:
# head = 7
# ex = 0
# per_head_logit_diff[head][ex], labels[ex][1:], attns[head], batch[ex][1:]

In [None]:
# layer = 3
# ex = 2
# per_layer_logit_diff[layer][ex], labels[ex], layers[layer], batch[ex][1:]