Setup

In [1]:
import pickle
import torch
from transformer_lens import HookedTransformerConfig, HookedTransformer
from transformer_lens import HookedTransformer
from circuits_benchmark.utils.get_cases import get_cases

task = get_cases(indices=['3'])[0]
task_idx = task.get_index()

In [2]:
dir_name = f"../InterpBench/{task_idx}"
cfg_dict = pickle.load(open(f"{dir_name}/ll_model_cfg.pkl", "rb"))
cfg = HookedTransformerConfig.from_dict(cfg_dict)
cfg.device = "cuda" if torch.cuda.is_available() else "cpu"
model = HookedTransformer(cfg)
weights = torch.load(f"{dir_name}/ll_model.pth", map_location=cfg.device)
model.load_state_dict(weights)
# turn off grads
model.eval()
model.requires_grad_(False)
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x2c4e4c290>

In [3]:
# load high level model
from circuits_benchmark.utils.iit import make_iit_hl_model
import circuits_benchmark.utils.iit.correspondence as correspondence
import iit.model_pairs as mp

def make_model_pair(benchmark_case):
    hl_model = benchmark_case.build_transformer_lens_model()
    hl_model = make_iit_hl_model(hl_model, eval_mode=True)
    tracr_output = benchmark_case.get_tracr_output()
    hl_ll_corr = correspondence.TracrCorrespondence.from_output(
            case=benchmark_case, tracr_output=tracr_output
        )
    model_pair = mp.StrictIITModelPair(hl_model, model, hl_ll_corr)
    return model_pair

In [4]:
from circuits_benchmark.utils.iit.dataset import get_unique_data

max_len = 100
model_pair = make_model_pair(task)
unique_test_data = get_unique_data(task, max_len=max_len)

Moving model to device:  cpu
{'hook_embed': HookPoint(), 'hook_pos_embed': HookPoint(), 'blocks.0.attn.hook_k': HookPoint(), 'blocks.0.attn.hook_q': HookPoint(), 'blocks.0.attn.hook_v': HookPoint(), 'blocks.0.attn.hook_z': HookPoint(), 'blocks.0.attn.hook_attn_scores': HookPoint(), 'blocks.0.attn.hook_pattern': HookPoint(), 'blocks.0.attn.hook_result': HookPoint(), 'blocks.0.mlp.hook_pre': HookPoint(), 'blocks.0.mlp.hook_post': HookPoint(), 'blocks.0.hook_attn_in': HookPoint(), 'blocks.0.hook_q_input': HookPoint(), 'blocks.0.hook_k_input': HookPoint(), 'blocks.0.hook_v_input': HookPoint(), 'blocks.0.hook_mlp_in': HookPoint(), 'blocks.0.hook_attn_out': HookPoint(), 'blocks.0.hook_mlp_out': HookPoint(), 'blocks.0.hook_resid_pre': HookPoint(), 'blocks.0.hook_resid_mid': HookPoint(), 'blocks.0.hook_resid_post': HookPoint(), 'blocks.1.attn.hook_k': HookPoint(), 'blocks.1.attn.hook_q': HookPoint(), 'blocks.1.attn.hook_v': HookPoint(), 'blocks.1.attn.hook_z': HookPoint(), 'blocks.1.attn.hook_

Resample ablate with 10%, 20% etc. of the activation

In [5]:
from iit.model_pairs.nodes import LLNode
from typing import Callable
from torch import Tensor
from transformer_lens.hook_points import HookPoint
import iit.utils.eval_ablations as eval_ablations
from importlib import reload
from circuits_benchmark.utils.iit.dataset import TracrIITDataset
import pandas as pd


def get_effects_for_scales(
    model_pair,
    unique_test_data,
    hook_maker: Callable[
        [mp.BaseModelPair, LLNode, float], Callable[[Tensor, HookPoint], Tensor]
    ],
    scales=[0.1, 1.0],
):
    combined_scales_df = pd.DataFrame(
        columns=["node", "status"] + [f"scale {scale}" for scale in scales]
    )

    for scale in scales:
        print(f"Running scale {scale}\n")
        test_set = TracrIITDataset(
            unique_test_data,
            unique_test_data,
            model_pair.hl_model,
            every_combination=True,
        )

        hook_maker_for_node = lambda ll_node: hook_maker(model_pair=model_pair, ll_node=ll_node, scale=scale)

        causal_effects_not_in_circuit = eval_ablations.check_causal_effect(
            model_pair=model_pair,
            dataset=test_set,
            hook_maker=hook_maker_for_node,
            node_type="n",
        )

        causal_effects_in_circuit = eval_ablations.check_causal_effect(
            model_pair=model_pair,
            dataset=test_set,
            hook_maker=hook_maker_for_node,
            node_type="individual_c",
        )

        causal_effects = eval_ablations.make_dataframe_of_results(
            causal_effects_not_in_circuit, causal_effects_in_circuit
        )

        # change column name causal effect to scale
        causal_effects.rename(columns={"causal effect": f"scale {scale}"}, inplace=True)
        combined_scales_df = pd.merge(
            combined_scales_df, causal_effects, on=["node", "status"], how="outer"
        )
        # drop columns with nan
        combined_scales_df.dropna(axis=1, how="all", inplace=True)
    return combined_scales_df

In [6]:
def make_ll_ablation_hook_scale_activations(
        model_pair, ll_node: LLNode, scale: float
    ) -> Callable[[Tensor, HookPoint], Tensor]:
        """
        Resample ablations, but with the patched activations scaled by the given factor.
        """
        if ll_node.subspace is not None:
            raise NotImplementedError

        def ll_ablation_hook(hook_point_out: Tensor, hook: HookPoint) -> Tensor:
            out = hook_point_out.clone()
            index = ll_node.index if ll_node.index is not None else index.Ix[[None]]
            out[index.as_index] = model_pair.ll_cache[hook.name][index.as_index] * scale
            return out

        return ll_ablation_hook

scales = [0.0, 0.1, 0.2, 0.5, 0.7, 0.8, 1.0, 1.2, 1.4, 2.0, 5.0]
combined_scales_df = get_effects_for_scales(model_pair, unique_test_data, 
                                            hook_maker=make_ll_ablation_hook_scale_activations,
                                            scales=scales)

Running scale 0.0



100%|██████████| 40/40 [00:03<00:00, 12.51it/s]
100%|██████████| 40/40 [00:01<00:00, 27.56it/s]


Running scale 0.1



100%|██████████| 40/40 [00:03<00:00, 11.83it/s]
100%|██████████| 40/40 [00:01<00:00, 26.53it/s]


Running scale 0.2



100%|██████████| 40/40 [00:03<00:00, 12.50it/s]
100%|██████████| 40/40 [00:01<00:00, 28.35it/s]


Running scale 0.5



100%|██████████| 40/40 [00:03<00:00, 12.87it/s]
100%|██████████| 40/40 [00:01<00:00, 27.42it/s]


Running scale 0.7



100%|██████████| 40/40 [00:03<00:00, 12.63it/s]
100%|██████████| 40/40 [00:01<00:00, 27.93it/s]


Running scale 0.8



100%|██████████| 40/40 [00:03<00:00, 12.78it/s]
100%|██████████| 40/40 [00:01<00:00, 27.91it/s]


Running scale 1.0



100%|██████████| 40/40 [00:03<00:00, 12.48it/s]
100%|██████████| 40/40 [00:01<00:00, 25.24it/s]


Running scale 1.2



100%|██████████| 40/40 [00:03<00:00, 12.57it/s]
100%|██████████| 40/40 [00:01<00:00, 26.44it/s]


Running scale 1.4



100%|██████████| 40/40 [00:03<00:00, 12.52it/s]
100%|██████████| 40/40 [00:01<00:00, 26.97it/s]


Running scale 2.0



100%|██████████| 40/40 [00:03<00:00, 12.18it/s]
100%|██████████| 40/40 [00:01<00:00, 27.62it/s]


Running scale 5.0



100%|██████████| 40/40 [00:03<00:00, 11.81it/s]
100%|██████████| 40/40 [00:01<00:00, 25.91it/s]


Running scale 10.0



100%|██████████| 40/40 [00:03<00:00, 11.49it/s]
100%|██████████| 40/40 [00:01<00:00, 26.48it/s]


In [7]:
combined_scales_df.rename(columns={"scale 0.0_y": "scale 0.0"}, inplace=True)
combined_scales_df = combined_scales_df.sort_values(by=["status"], ascending=False)
combined_scales_df

Unnamed: 0,node,status,scale 0.0,scale 0.1,scale 0.2,scale 0.5,scale 0.7,scale 0.8,scale 1.0,scale 1.2,scale 1.4,scale 2.0,scale 5.0,scale 10.0
0,"blocks.0.attn.hook_result, head 0",not_in_circuit,1.0,0.99573,0.984805,0.424364,0.147599,0.02038,0.00598,0.204187,0.44454,0.806988,0.9997,1.0
1,"blocks.0.attn.hook_result, head 1",not_in_circuit,1.0,1.0,0.988372,0.64886,0.173528,0.018253,0.0,0.025237,0.223273,0.997539,1.0,1.0
2,"blocks.0.attn.hook_result, head 2",not_in_circuit,0.680534,0.588329,0.437677,0.058443,0.035223,0.002068,0.0,0.003185,0.036501,0.460792,0.900187,0.743509
3,"blocks.0.attn.hook_result, head 3",not_in_circuit,1.0,1.0,0.999408,0.635306,0.118211,0.035539,0.000168,0.164127,0.523615,0.990259,1.0,1.0
7,"blocks.1.attn.hook_result, head 2",not_in_circuit,1.0,1.0,1.0,1.0,0.961237,0.536721,0.0,0.367025,0.964079,1.0,1.0,1.0
8,"blocks.1.attn.hook_result, head 3",not_in_circuit,1.0,1.0,1.0,1.0,0.72284,0.457345,0.074625,0.354972,0.727884,1.0,1.0,1.0
9,blocks.1.mlp.hook_post,not_in_circuit,1.0,1.0,1.0,1.0,0.892223,0.682897,0.02783,0.531765,0.904191,1.0,1.0,1.0
4,blocks.0.mlp.hook_post,in_circuit,0.995777,0.911988,0.97285,0.933968,0.894231,0.960773,1.0,1.0,0.910674,0.951065,0.998301,0.99975
5,"blocks.1.attn.hook_result, head 0",in_circuit,0.947688,0.93232,0.850772,0.87972,0.984155,1.0,1.0,1.0,0.995376,0.933816,0.955799,0.999309
6,"blocks.1.attn.hook_result, head 1",in_circuit,0.920349,0.808206,0.749656,0.532832,0.623412,0.587123,0.616037,0.701412,0.686588,0.807814,0.891579,0.997364


In [56]:
import plotly.graph_objects as go

def plot_causal_effect(combined_scales_df, scales, image_name):
    fig = go.Figure()
    scale_columns = [f"scale {scale}" for scale in scales]
    for i, row in combined_scales_df.iterrows():
        y = [row[col] for col in scale_columns]
        x = scales
        # plot lines for each node
        fig.add_trace(go.Line(x=x, y=y, mode='lines+markers', 
                            # set color based on status
                            line=dict(color="green" if row["status"] == "in_circuit" else "orange"),
                            hovertext=f"Node: {row['node']}, Status: {row['status']}",
                            # define legend only for color, not for line
                            showlegend=False,
                            ),
                    )
    fig.update_layout(xaxis_title="Scale", yaxis_title="Causal Effect")
    # make legend for color
    fig.add_trace(go.Scatter(x=[None], y=[None], mode='markers', marker=dict(color="green"), name="in_circuit"))
    fig.add_trace(go.Scatter(x=[None], y=[None], mode='markers', marker=dict(color="orange"), name="not_in_circuit"))
    # make background transparent and remove grid
    fig.update_layout(template="plotly_white")
    # remove grid
    fig.update_xaxes(showgrid=False)
    fig.update_yaxes(showgrid=False)
    # decrease margin
    fig.update_layout(margin=dict(l=70, r=70, t=70, b=70))
    # increase font size
    fig.update_layout(font=dict(size=16))

    fig.show()
    # save to file as pdf with same width and height
    fig.write_image(f"{image_name}.pdf", width=1000, height=400)

In [57]:
plot_causal_effect(combined_scales_df, scales, f"causal_effect_scale_{task.get_index()}")

Do PCA and patch only the varying part 

In [10]:
from iit.utils.node_picker import get_nodes_in_circuit, get_nodes_not_in_circuit, get_all_nodes
from sklearn.decomposition import PCA
import iit.utils.index as index

def find_ll_node_by_name(name, list_of_nodes) -> list:
    ll_nodes = []
    for node in list_of_nodes:
        if node.name == name:
            ll_nodes.append(node)
    return ll_nodes


def collect_activations(model_pair, loader, pos_slice=slice(1, None, None)):
    activation_cache = {}
    nodes = get_all_nodes(model_pair.ll_model)
    pos_idx = index.TorchIndex(((slice(None), pos_slice))) if pos_slice is not None else index.Ix[slice(None)]
    for node in nodes:
        activation_cache[node] = None
    for batch in loader:
        _, batch_cache = model_pair.ll_model.run_with_cache(batch)
        for k, tensor in batch_cache.items():
            ll_node_for_k = find_ll_node_by_name(k, nodes)
            if len(ll_node_for_k) > 0:
                for node in ll_node_for_k:
                    act = tensor[node.index.as_index].cpu()[pos_idx.as_index]
                    if activation_cache[node] is None:
                        activation_cache[node] = act
                    else:
                        activation_cache[node] = torch.cat((activation_cache[node], act), dim=0).cpu()
    return activation_cache


def collect_pca_directions(activation_cache, num_pca_components=2):
    pca_dirs = {}

    for node, activations in activation_cache.items():
        # calculate pca directions for activations
        for i in range(activations.shape[1]):
            pca = PCA(n_components=num_pca_components)
            pca.fit(activations[:, i].detach().numpy())
            if pca_dirs.get(node) is None:
                pca_dirs[node] = {}
            pca_dirs[node][i] = pca.components_
    return pca_dirs

In [11]:
def collate_fn(batch):
    encoded_x = model_pair.hl_model.map_tracr_input_to_tl_input(list(zip(*batch))[0])
    return encoded_x

loader = torch.utils.data.DataLoader(unique_test_data, batch_size=256, shuffle=False, drop_last=False, collate_fn=collate_fn)

activation_cache = collect_activations(model_pair, loader=loader)
activation_cache.keys()

dict_keys([LLNode(name='blocks.0.attn.hook_result', index=[:, :, 0, :], subspace=None), LLNode(name='blocks.0.attn.hook_result', index=[:, :, 1, :], subspace=None), LLNode(name='blocks.0.attn.hook_result', index=[:, :, 2, :], subspace=None), LLNode(name='blocks.0.attn.hook_result', index=[:, :, 3, :], subspace=None), LLNode(name='blocks.0.mlp.hook_post', index=[:], subspace=None), LLNode(name='blocks.1.attn.hook_result', index=[:, :, 0, :], subspace=None), LLNode(name='blocks.1.attn.hook_result', index=[:, :, 1, :], subspace=None), LLNode(name='blocks.1.attn.hook_result', index=[:, :, 2, :], subspace=None), LLNode(name='blocks.1.attn.hook_result', index=[:, :, 3, :], subspace=None), LLNode(name='blocks.1.mlp.hook_post', index=[:], subspace=None)])

In [42]:
pca_dirs = collect_pca_directions(activation_cache, num_pca_components=2)

In [86]:
from fancy_einsum import einsum


def make_ll_ablation_hook_scale_activations_with_variance(
    model_pair, ll_node: LLNode, pca_dirs: dict, scale: float, 
    self_patch: bool = False,
    ablate_high_variance: bool = True
) -> Callable[[Tensor, HookPoint], Tensor]:
    """
    Resample ablations, but with the patched activations scaled by the given factor, along the PCA directions. Since the PCA directions capture the variance in activations, this may help us to distinguish between constant nodes and nodes whose variance is important for the model. My hypothesis is that constant nodes will have a smaller effect on the model than nodes whose variance is important for any scale provided.

    If self_patch is True, the ablation will be done with the activations at the node itself, rather than the activations at the node in the cache. So this is not a resample ablation.  
    """
    if ll_node.subspace is not None:
        raise NotImplementedError

    def ll_ablation_hook(hook_point_out: Tensor, hook: HookPoint) -> Tensor:
        out = hook_point_out.clone()
        index = ll_node.index if ll_node.index is not None else index.Ix[[None]]
        cached_activation = model_pair.ll_cache[hook.name][index.as_index]
        pca_dirs_at_node = pca_dirs[ll_node]
        for i in range(0, cached_activation.shape[1]-1):
            pca_dirs_at_i = pca_dirs_at_node[i]
            components_at_clean_dir = []
            for component in range(pca_dirs_at_i.shape[0]):
                components_at_clean_dir.append(
                    einsum(
                        "batch d_model, d_model -> batch",
                        out[index.as_index][:, i+1],
                        torch.tensor(pca_dirs_at_i[component]),
                    ).unsqueeze(1)
                )
            
            
            if self_patch and ablate_high_variance:
                # take the pca direction as the direction to remove
                components_to_remove = sum(components_at_clean_dir)
                components_to_add = components_to_remove * scale
            elif self_patch and not ablate_high_variance:
                # take the mean direction as the direction to remove
                components_to_remove = out[index.as_index][:, i+1] - sum(components_at_clean_dir) 
                components_to_add = components_to_remove * scale
            elif not self_patch and ablate_high_variance:
                components_at_cached_dir = []
                components_to_remove = sum(components_at_clean_dir)
                for component in range(pca_dirs_at_i.shape[0]):
                    components_at_cached_dir.append(
                        einsum(
                            "batch d_model, d_model -> batch",
                            cached_activation[:, i+1],
                            torch.tensor(pca_dirs_at_i[component]),
                        ).unsqueeze(1)
                    )
                components_to_add = sum(components_at_cached_dir) * scale
            elif not self_patch and not ablate_high_variance:
                components_at_cached_dir = []
                components_to_remove = out[index.as_index][:, i+1] - sum(components_at_clean_dir) 
                for component in range(pca_dirs_at_i.shape[0]):
                    components_at_cached_dir.append(
                        einsum(
                            "batch d_model, d_model -> batch",
                            cached_activation[:, i+1],
                            torch.tensor(pca_dirs_at_i[component]),
                        ).unsqueeze(1)
                    )
                components_to_add = (cached_activation[:, i+1] - sum(components_at_cached_dir)) * scale

            out[index.as_index][:, i+1] = (
                out[index.as_index][:, i+1] + components_to_add - components_to_remove
            )
        return out

    return ll_ablation_hook

In [98]:
%%capture
def hook_maker(model_pair, ll_node: LLNode, scale: float) -> Callable[[Tensor, HookPoint], Tensor]:
    return make_ll_ablation_hook_scale_activations_with_variance(model_pair, ll_node, pca_dirs, scale, self_patch=True, ablate_high_variance=False)
combined_scales_df_orthogonal = get_effects_for_scales(model_pair, unique_test_data, hook_maker, scales=scales)

In [100]:
# sort by status
combined_scales_df_orthogonal = combined_scales_df_orthogonal.sort_values(by="status", ascending=False)
combined_scales_df_orthogonal.rename(columns={"scale 0.0_y": "scale 0.0"}, inplace=True)
plot_causal_effect(combined_scales_df_orthogonal, scales, f"causal_effect_scale_orthogonal_{task.get_index()}")


plotly.graph_objs.Line is deprecated.
Please replace it with one of the following more specific types
  - plotly.graph_objs.scatter.Line
  - plotly.graph_objs.layout.shape.Line
  - etc.




In [101]:
%%capture
def hook_maker(model_pair, ll_node: LLNode, scale: float) -> Callable[[Tensor, HookPoint], Tensor]:
    return make_ll_ablation_hook_scale_activations_with_variance(model_pair, ll_node, pca_dirs, scale, self_patch=False, ablate_high_variance=False)
combined_scales_df_orthogonal = get_effects_for_scales(model_pair, unique_test_data, hook_maker, scales=scales)

In [102]:
# sort by status
combined_scales_df_orthogonal = combined_scales_df_orthogonal.sort_values(by="status", ascending=False)
combined_scales_df_orthogonal.rename(columns={"scale 0.0_y": "scale 0.0"}, inplace=True)
plot_causal_effect(combined_scales_df_orthogonal, scales, f"causal_effect_scale_orthogonal_{task.get_index()}")


plotly.graph_objs.Line is deprecated.
Please replace it with one of the following more specific types
  - plotly.graph_objs.scatter.Line
  - plotly.graph_objs.layout.shape.Line
  - etc.




In [103]:
%%capture
def hook_maker(model_pair, ll_node: LLNode, scale: float) -> Callable[[Tensor, HookPoint], Tensor]:
    return make_ll_ablation_hook_scale_activations_with_variance(model_pair, ll_node, pca_dirs, scale, self_patch=True, ablate_high_variance=True)
combined_scales_df_orthogonal = get_effects_for_scales(model_pair, unique_test_data, hook_maker, scales=scales)

In [104]:
# sort by status
combined_scales_df_orthogonal = combined_scales_df_orthogonal.sort_values(by="status", ascending=False)
combined_scales_df_orthogonal.rename(columns={"scale 0.0_y": "scale 0.0"}, inplace=True)
plot_causal_effect(combined_scales_df_orthogonal, scales, f"causal_effect_scale_orthogonal_{task.get_index()}")


plotly.graph_objs.Line is deprecated.
Please replace it with one of the following more specific types
  - plotly.graph_objs.scatter.Line
  - plotly.graph_objs.layout.shape.Line
  - etc.




In [105]:
%%capture
def hook_maker(model_pair, ll_node: LLNode, scale: float) -> Callable[[Tensor, HookPoint], Tensor]:
    return make_ll_ablation_hook_scale_activations_with_variance(model_pair, ll_node, pca_dirs, scale, self_patch=False, ablate_high_variance=True)
combined_scales_df_orthogonal = get_effects_for_scales(model_pair, unique_test_data, hook_maker, scales=scales)

In [106]:
# sort by status
combined_scales_df_orthogonal = combined_scales_df_orthogonal.sort_values(by="status", ascending=False)
combined_scales_df_orthogonal.rename(columns={"scale 0.0_y": "scale 0.0"}, inplace=True)
plot_causal_effect(combined_scales_df_orthogonal, scales, f"causal_effect_scale_orthogonal_{task.get_index()}")


plotly.graph_objs.Line is deprecated.
Please replace it with one of the following more specific types
  - plotly.graph_objs.scatter.Line
  - plotly.graph_objs.layout.shape.Line
  - etc.




Check if the PCA dir is orthogonal using logit lens

In [85]:
pca_dirs

{LLNode(name='blocks.0.attn.hook_result', index=[:, :, 0, :], subspace=None): {0: array([[-0.1543014 ,  0.03284107,  0.14567627, -0.1269737 ,  0.6946124 ,
           0.24951641,  0.25282773, -0.19772966,  0.31867442,  0.23631977,
           0.2914874 , -0.21826585],
         [-0.03656477,  0.4434389 , -0.1949604 ,  0.42029554,  0.3295396 ,
           0.11560258, -0.26876563,  0.39282045, -0.43194056,  0.03826209,
           0.21348502, -0.07246836]], dtype=float32),
  1: array([[ 0.15616408, -0.06144751, -0.13153268,  0.10038847, -0.7124758 ,
          -0.25488907, -0.2368863 ,  0.17466052, -0.29196116, -0.23861626,
          -0.3046295 ,  0.22420558],
         [-0.02726392,  0.43676808, -0.19852485,  0.42585406,  0.28905812,
           0.10336784, -0.2864966 ,  0.40794384, -0.45174152,  0.02264791,
           0.19364206, -0.05392178]], dtype=float32),
  2: array([[ 0.15651718, -0.06771581, -0.12850896,  0.09437191, -0.7163007 ,
          -0.2561449 , -0.2330011 ,  0.16913734, -0.28568

subtract mean of everything before PCA

In [96]:
import numpy as np
mean_dir = np.array([-4.0, 3.0, -5.0])
noise_3by3 = np.random.normal(0, 1, (3, 3)) / 10

mean_plus_noise = mean_dir + noise_3by3

pca = PCA(n_components=3)
pca.fit(mean_plus_noise)

pca2 = PCA(n_components=3)
pca2.fit(noise_3by3)

pca.components_ - pca2.components_

array([[-2.10942375e-15, -1.44328993e-15,  3.27515792e-15],
       [-2.16493490e-15,  5.66213743e-15, -1.11022302e-16],
       [ 5.72746184e-01,  1.83560645e+00,  5.50009800e-01]])