In [19]:
# Imports and session configuration. Statement order matters: the project-local
# imports below presumably only resolve after the os.chdir / sys.path.append calls.
import sys
import os

import torch
from torch import Tensor
import numpy as np
import pandas as pd
from jaxtyping import Int, Float
import einops
from transformer_lens import HookedTransformer
from transformer_lens.utils import get_act_name, to_numpy
from functools import partial
import plotly.express as px

# Re-import edited local modules automatically before each cell executes.
%load_ext autoreload
%autoreload 2
# NOTE(review): hardcoded absolute path. Relative paths used later ('./models/...')
# and the local `train` / `dataset` / `interpreting_models` / `utils` imports are
# resolved from this directory (presumably via the cwd entry on sys.path — confirm).
os.chdir('/home/alejo/Projects/Interpretability_Collections')
from train import load_model
from dataset import BalancedParenthesisDataGenerator
from interpreting_models.utils_exploration import to_str_toks, gen_balanced_paren_toks, gen_fail_both_conditions_toks, \
    gen_only_equal_count_toks, gen_only_horizon_toks, gen_filtered_toks
from utils import compute_accuracy, compute_cross_entropy_loss
import circuitsvis as cv
from IPython.display import display

# NOTE(review): hardcoded absolute path; presumably where my_plotly_utils and
# path_patching live. Consider making this configurable (e.g. an env var).
sys.path.append('/home/alejo/Projects')
from my_plotly_utils import hist, bar, imshow, scatter, line, figs_to_subplots
from path_patching import act_patch, path_patch, Node, IterNode

# Disable autograd globally: every cell below only runs forward passes.
torch.set_grad_enabled(False)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


<torch.autograd.grad_mode.set_grad_enabled at 0x7fe94e49b2d0>

In [20]:
# Data generator for numeric length 20; it is also passed to load_model with the checkpoint.
data_gen = BalancedParenthesisDataGenerator(n_ctx_numeric=20)
model: HookedTransformer = load_model(
    './models/final/bal_paren_20-l2_h1_d16_m1-1000.pt',
    data_gen,
)
# Alternative checkpoint trained with weight decay:
# model_wdecay: HookedTransformer = load_model('./models/bal_paren_20_wdecay-l2_h2_d32_m1-1000.pt', data_gen)

# Evaluation across sequence lengths (numeric length 2–20)

In [23]:
# Sweep every even numeric length 2..20. Each length gets its own generator,
# seeded with 0, and is scored on a 10k-sample batch at the label position.
accuracy_by_ctx, loss_by_ctx = [], []

for ctx_numeric_len in range(2, 22, 2):
    len_data_gen = BalancedParenthesisDataGenerator(n_ctx_numeric=ctx_numeric_len)
    len_data_gen.set_seed(0)

    toks = len_data_gen.gen_toks(batch_size=10000).to(model.cfg.device)
    labels = len_data_gen.get_token_labels(toks)
    logits_at_pos_label = model(toks)[:, len_data_gen.pos_label]

    acc_for_len = compute_accuracy(logits_at_pos_label, labels, as_percentage=True)
    loss_for_len = compute_cross_entropy_loss(logits_at_pos_label, labels).item()
    accuracy_by_ctx.append(acc_for_len)
    loss_by_ctx.append(loss_for_len)

In [24]:
# Row 0 = accuracy, row 1 = loss; both rows share the same sequence-length x values.
loss_and_acc_to_plot = np.stack([accuracy_by_ctx, loss_by_ctx])
duplicated_x_axis = np.tile(np.arange(2, 22, 2), (2, 1))
line(
    loss_and_acc_to_plot,
    use_secondary_yaxis=True,
    x=duplicated_x_axis,
    labels=dict(y1='Accuracy', y2='Loss', x='Sequence length'),
    names=['Accuracy', 'Loss'],
    title='Accuracy and loss for different sequence lengths',
)
# print('Loss for different sequence lengths:', np.round(loss_by_ctx, 2))

# Logit Attribution

In [25]:
# Fresh 1000-sample batch; cache all activations for direct logit attribution.
toks = data_gen.gen_toks(batch_size=1000).to(model.cfg.device)
labels = data_gen.get_token_labels(toks)
logits, cache = model.run_with_cache(toks)
acc = compute_accuracy(logits[:, data_gen.pos_label], labels, as_percentage=True)
# NOTE(review): with as_percentage=True the recorded output is "1.000"; confirm whether
# compute_accuracy returns a fraction or a percentage.
print(f'Accuracy {acc:.3f}')

# Decompose the residual stream at the label position and project each component onto
# the (correct - incorrect) logit direction. `1 - labels` is used as the incorrect token
# (assumes binary 0/1 label token ids — TODO confirm).
resid_comps, resid_labels = cache.decompose_resid(pos_slice=data_gen.pos_label, return_labels=True)
logit_attr = cache.logit_attrs(resid_comps, tokens=labels, incorrect_tokens=1-labels, pos_slice=data_gen.pos_label)
# Transpose so rows are samples and columns are the residual-stream components.
logit_attr_df = pd.DataFrame(logit_attr.squeeze().T.cpu(), columns=resid_labels)
px.box(logit_attr_df, labels=dict(variable='Residual stream component', value='Logit difference'),
       title='Logit attribution for each residual stream component')

Accuracy 1.000


# Attention Patterns

In [26]:
# Fresh batch, then pull the LayerNorm scale that feeds layer-0 attention.
toks = data_gen.gen_toks(batch_size=1000).to(model.cfg.device)
logits, cache = model.run_with_cache(toks)
labels = data_gen.get_token_labels(toks)

# Same entry as cache['scale', 0, 'ln1']. The trailing [0, 0] keeps the first index of the
# two extra dims (presumably head and a size-1 scale dim — confirm against the cache shape).
ln_attn_scale = cache[get_act_name('scale', 0, 'ln1')][:, :, 0, 0]

In [27]:
# Distribution of the layer-0 pre-attention LayerNorm scale over every (batch, position) entry.
hist(ln_attn_scale.reshape(-1), title='LayerNorm scale before attn at L0 for all positions')

In [23]:
# Input embedding (token + position) for each parenthesis token at every position.
idx_tokens = [data_gen.OPEN_TOKEN, data_gen.CLOSED_TOKEN]
# Token embeddings come from W_E (rows indexed by token id); W_pos rows are positions.
token_pos_combinations = model.W_E[idx_tokens][:, None, :] + model.W_pos[None, :, :]  # [token, pos, d_model]
model.W_K.shape

torch.Size([2, 2, 32, 32])

In [63]:
batch_idx = 30

data_gen.set_seed(0)
# Choose the token distribution to inspect (keep exactly one `toks = ...` active):
# toks = data_gen.gen_toks(batch_size=100).to(model.cfg.device)
# toks = data_gen.gen_off_by_one_balanced_parentheses_toks(batch_size=100).to(model.cfg.device)
toks = data_gen.gen_balanced_parentheses_toks(batch_size=100).to(model.cfg.device)
# toks = gen_only_horizon_toks(batch_size=100).to(model.cfg.device)
# toks = gen_only_equal_count_toks(batch_size=100).to(model.cfg.device)
# toks = data_gen.convert_str_to_toks('(()))(' + 7 * '()').to(model.cfg.device)
# toks = data_gen.convert_str_to_toks(6 * '(' + 9 * ')' + '(' + 2 * '()').to(model.cfg.device)
# toks = data_gen.convert_str_to_toks(9 * '(' + 11 * ')').to(model.cfg.device)
# toks = data_gen.convert_str_to_toks('(' + 10 * ')' + 9 * '(').to(model.cfg.device)
# toks = data_gen.convert_str_to_toks(8 * '()' + '()))').to(model.cfg.device)
# toks = data_gen.convert_str_to_toks('()())((()))))()(((()').to(model.cfg.device)
# toks = data_gen.convert_str_to_toks('()()))((())))()(((()').to(model.cfg.device)
# toks = data_gen.convert_str_to_toks('()))((' + 7 * '()' ).to(model.cfg.device)

labels = data_gen.get_token_labels(toks)
logits, cache = model.run_with_cache(toks)

# Probability assigned to the correct label at the label position.
logits_at_pos_label = logits[:, data_gen.pos_label, :]
probs_at_pos_label = logits_at_pos_label.softmax(dim=-1).gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)

# Running (prefix) count of open minus closed parentheses over the numeric positions.
num_open_toks = (toks[:, data_gen.pos_numeric] == data_gen.OPEN_TOKEN).float().cumsum(dim=-1)
num_closed_toks = (toks[:, data_gen.pos_numeric] == data_gen.CLOSED_TOKEN).float().cumsum(dim=-1)
diff_num_open_closed = num_open_toks - num_closed_toks

print('Label: ', labels[batch_idx].item())
print(f'Prob to correct label: {probs_at_pos_label[batch_idx].item(): .3f}')
print('Difference num open and closed: ', diff_num_open_closed[batch_idx])

# TransformerLens patterns are [batch, head, query (dst), key (src)]; stack over layers
# and merge layer/head so every head is its own facet.
attn_patterns = einops.rearrange(cache.stack_activation('pattern'), 'layer batch head dst src -> batch (layer head) dst src')
str_toks = to_str_toks(data_gen, toks)
str_toks_initials = [[str_tok[0] for str_tok in tok_seq] for tok_seq in str_toks]
head_names = [f'H{layer}.{head}' for layer in range(model.cfg.n_layers) for head in range(model.cfg.n_heads)]
plot_attn_pattern = cv.attention.attention_patterns(attn_patterns[batch_idx], tokens=str_toks_initials[batch_idx],
                                attention_head_names=head_names)
display(plot_attn_pattern)


resid_comps, resid_labels = cache.decompose_resid(pos_slice=data_gen.pos_label, return_labels=True)
logit_attr = cache.logit_attrs(resid_comps, tokens=labels, incorrect_tokens=1-labels, pos_slice=data_gen.pos_label)
# The x values are passed as `x=`, so the x-axis label key is 'x' (not 'value').
bar(logit_attr[:, batch_idx].squeeze(), labels=dict(y='Logit difference', x='Residual stream component'),
    x=resid_labels, title='Logit attribution for each residual stream component on selected datapoint')

logit_attr_df = pd.DataFrame(logit_attr.squeeze().T.cpu(), columns=resid_labels)
px.box(logit_attr_df, labels=dict(variable='Residual stream component', value='Logit difference'),
       title='Logit attribution for each residual stream component').show()



# Multiply each query (dst) row q by q + 1 so uniform attention over the q + 1 visible keys
# maps to 1 (assumes causal attention — TODO confirm).
attn_patterns_norm = attn_patterns * torch.arange(1, data_gen.n_ctx + 1, device=attn_patterns.device)[:, None]
# Spread and mean of each normalized (dst, src) entry across the batch.
attn_pattern_range = attn_patterns_norm.max(0)[0] - attn_patterns_norm.min(0)[0]
attn_pattern_mean = attn_patterns_norm.mean(0)
imshow(attn_pattern_mean, facet_col=0, facet_labels=head_names, title='Average attention patterns normalized by position')
# Entries whose mean is 0 (e.g. masked positions) become 0/0 = NaN in this ratio.
imshow(attn_pattern_range/attn_pattern_mean, facet_col=0, facet_labels=head_names, title='Range/mean attention for each src-dst pair')

Label:  1
Prob to correct label:  0.999
Difference num open and closed:  tensor([1., 0., 1., 2., 3., 2., 3., 4., 3., 4., 3., 4., 5., 4., 3., 2., 1., 0.,
        1., 0.], device='cuda:0')


# SVD on head output

In [73]:
layer = 0

toks = data_gen.gen_toks(batch_size=1000).to(model.cfg.device)
labels = data_gen.get_token_labels(toks)
logits, cache = model.run_with_cache(toks)

# Attention-layer output stacked over layers: indexed below as [layer, batch, pos, d_model].
head_out = cache.stack_activation('attn_out')
# head_result = einops.rearrange(cache.stack_activation('result'),
#                                'layer batch pos head d_model -> batch pos layer head d_model')

for pos in [5, 10, 15, 20, 21]:
    head_out_at_layer_and_pos = head_out[layer, :, pos]  # [batch, d_model]
    # Reduced SVD (torch.svd is deprecated). Row i of U holds sample i's coordinates
    # along the singular directions, up to scaling by S.
    U, S, Vh = torch.linalg.svd(head_out_at_layer_and_pos, full_matrices=False)
    scatter(U[:, 0], U[:, 1], color=labels.squeeze().cpu(),
            labels=dict(color='Balanced', y='SVD Comp 1', x='SVD Comp 0'),
            title=f'First two components of SVD of attention output at layer {layer} and position {pos}')
    # Plot the spectrum for every position (previously only the last position was plotted).
    line(S, title=f'Singular values of attention output at layer {layer} and position {pos}')

# Patching

In [70]:
def compute_logit_diff(logits: Float[Tensor, 'batch pos d_vocab'], labels: Int[Tensor, 'batch label'],
                       pos_label=None) -> float:
    """Mean logit difference (correct label minus the other label) at the label position.

    Args:
        logits: model output logits over the whole sequence.
        labels: correct label token ids at the label position(s). The "incorrect" token is
            taken to be ``1 - labels`` (assumes binary 0/1 label ids — TODO confirm).
        pos_label: index or slice selecting the label position(s). Defaults to the global
            ``data_gen.pos_label``, looked up at call time, so the existing
            ``partial(compute_logit_diff, labels=...)`` callers behave as before.

    Returns:
        Python float: mean of (correct - incorrect) logit over all batch / label entries.
    """
    if pos_label is None:
        pos_label = data_gen.pos_label
    logits_at_pos_label = logits[:, pos_label, :]
    correct_logits = logits_at_pos_label.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
    incorrect_logits = logits_at_pos_label.gather(dim=-1, index=(1-labels).unsqueeze(-1)).squeeze(-1)
    return (correct_logits - incorrect_logits).mean().item()

def patch_corrupt_and_recover_heads(clean_toks, corrupt_toks):
    """Activation-patch every head's value vectors ('v') in both directions.

    Corrupt: run on ``clean_toks`` and patch in values from ``corrupt_toks``.
    Recover: run on ``corrupt_toks`` and patch in values from ``clean_toks``.
    Both directions are scored with ``compute_logit_diff`` against the clean labels.

    Returns:
        Tensor stacking the two ``act_patch(...)['v']`` results, corrupt first.
    """
    metric = partial(compute_logit_diff, labels=data_gen.get_token_labels(clean_toks))
    directions = [
        (clean_toks, corrupt_toks),  # corrupt
        (corrupt_toks, clean_toks),  # recover
    ]

    per_direction = []
    for orig_input, new_input in directions:
        patch_out = act_patch(
            model=model,
            orig_input=orig_input,
            new_input=new_input,
            patching_nodes=IterNode('v'),
            patching_metric=metric,
        )
        per_direction.append(patch_out['v'])
    return torch.stack(per_direction)

def patch_corrupt_and_recover_head_layer(clean_toks, corrupt_toks):
    """Patch all head outputs ('result') of one layer at a time, in both directions.

    For each layer, row 0 (corrupt) runs on ``clean_toks`` with that layer's head results
    taken from ``corrupt_toks``; row 1 (recover) is the reverse. Both are scored with
    ``compute_logit_diff`` against the clean labels.

    Returns:
        Tensor of shape [2, n_layers, 1] (a trailing singleton dim is appended).
    """
    metric = partial(compute_logit_diff, labels=data_gen.get_token_labels(clean_toks))
    n_layers = model.cfg.n_layers
    directions = [
        (clean_toks, corrupt_toks),  # row 0: corrupt
        (corrupt_toks, clean_toks),  # row 1: recover
    ]

    scores = torch.zeros((2, n_layers), device=model.cfg.device)
    for layer_idx in range(n_layers):
        for row, (orig_input, new_input) in enumerate(directions):
            scores[row, layer_idx] = act_patch(
                model=model,
                orig_input=orig_input,
                new_input=new_input,
                patching_nodes=[Node('result', layer=layer_idx, head=h) for h in range(model.cfg.n_heads)],
                patching_metric=metric,
            )
    return scores.unsqueeze(-1)

In [60]:
# Number of sequences generated for each clean / corrupt batch in the patching cells below.
BATCH_SIZE = 100

In [71]:
# Clean: balanced sequences. Corrupt: sequences failing both conditions.
clean_toks = gen_balanced_paren_toks(BATCH_SIZE).to(model.cfg.device)
corrupt_toks = gen_fail_both_conditions_toks(BATCH_SIZE).to(model.cfg.device)

# Per-head patching of the value vectors.
patch_result_both = patch_corrupt_and_recover_heads(clean_toks, corrupt_toks)
imshow(patch_result_both,
       title='Patch v-hook for balanced seqs and fail both conditions seqs',
       labels=dict(x='Head', y='Layer'),
       facet_col=0, facet_labels=['Corrupt', 'Recover'])

# Whole-layer patching of all head outputs for the same clean / corrupt pair.
patch_result_both = patch_corrupt_and_recover_head_layer(clean_toks, corrupt_toks)
imshow(patch_result_both,
       title='Patch attn_out for balanced seqs and fail both conditions seqs',
       labels=dict(y='Layer'),
       facet_col=0, facet_labels=['Corrupt', 'Recover'])

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

In [134]:
# Clean: balanced sequences. Corrupt: sequences that pass only the horizon test,
# i.e. they fail the equal-count test (per the plot title below).
clean_toks = gen_balanced_paren_toks(BATCH_SIZE).to(model.cfg.device)
corrupt_toks = gen_only_horizon_toks(BATCH_SIZE).to(model.cfg.device)

# Per-head v-hook patching (currently disabled):
# patch_result_both = patch_corrupt_and_recover_heads(clean_toks, corrupt_toks)
# imshow(patch_result_both, facet_col=0, facet_labels=['Corrupt', 'Recover'],
#        title='Patch v-hook for balanced seqs and fail equal count test seqs',
#        labels=dict(x='Head', y='Layer'))

patch_result_both = patch_corrupt_and_recover_head_layer(clean_toks, corrupt_toks)
imshow(patch_result_both,
       title='Patch attn_out for balanced seqs and fail equal count test seqs',
       labels=dict(y='Layer'),
       facet_col=0, facet_labels=['Corrupt', 'Recover'])

In [135]:
# Clean: balanced sequences. Corrupt: sequences that pass only the equal-count test,
# i.e. they fail the horizon test (per the plot title below). The patching helpers
# derive the clean labels from clean_toks themselves.
clean_toks = gen_balanced_paren_toks(BATCH_SIZE).to(model.cfg.device)
corrupt_toks = gen_only_equal_count_toks(BATCH_SIZE).to(model.cfg.device)

# patch_result_both = patch_corrupt_and_recover_heads(clean_toks, corrupt_toks)
# imshow(patch_result_both, facet_col=0, facet_labels=['Corrupt', 'Recover'],
#        title='Patch v-hook for balanced seqs and fail horizon test seqs',
#        labels=dict(x='Head', y='Layer'))

patch_result_both = patch_corrupt_and_recover_head_layer(clean_toks, corrupt_toks)
imshow(patch_result_both, facet_col=0, facet_labels=['Corrupt', 'Recover'],
       title='Patch attn_out for balanced seqs and fail horizon test seqs',
       labels=dict(y='Layer'))

# Backdoor Evaluation

In [13]:
from backdoor_dataset import BackdoorFactory, ReverseLabelModifier, StartingNumberForBalancedParenthesisTrigger

# Build a generator class that wraps BalancedParenthesisDataGenerator with the starting-number
# trigger and the reverse-label modifier (behavior presumably as the class names suggest).
backdoor_factory = BackdoorFactory(
    data_gen_cls=BalancedParenthesisDataGenerator,
    trigger_cls_list=[StartingNumberForBalancedParenthesisTrigger],
    label_mod_cls_list=[ReverseLabelModifier],
)
BackdoorDataGen = backdoor_factory.create_backdoor_data_generator_class()

# From here on `data_gen` and `model` refer to the backdoored setup, shadowing the earlier
# ones (the patching helpers above read these globals at call time).
data_gen = BackdoorDataGen(n_ctx_numeric=20)
backdoor_model_path = './models/final/bal_paren_20_bdoor-l2_h1_d16_m1-1000.pt'
model: HookedTransformer = load_model(backdoor_model_path, data_gen)

In [17]:
# Evaluate the backdoored model on sequences generated by its first trigger
# (swap in the commented line to evaluate on the regular distribution instead).
toks = data_gen.triggers[0].gen_toks(batch_size=100_000).to(model.cfg.device)
# toks = data_gen.gen_toks(batch_size=100_000).to(model.cfg.device)
labels = data_gen.get_token_labels(toks)
logits = model(toks)
logits_at_pos_label = logits[:, data_gen.pos_label]
acc = compute_accuracy(logits_at_pos_label, labels, as_percentage=True)
# reduce='label' presumably yields one loss per sample, which is what the histogram plots.
loss = compute_cross_entropy_loss(logits_at_pos_label, labels, reduce='label')
# NOTE(review): the recorded output is "1.00000%"; confirm whether compute_accuracy with
# as_percentage=True returns a fraction (then drop the '%') or a true percentage.
print(f'Accuracy: {acc:.5f}%')
hist(loss.squeeze(), log_y=True, title='Cross-entropy loss per sequence at the label position')

Accuracy: 1.00000%
