In [1]:
from data_utils import *
from data_utils import test_distribution
from model_utils import *
import joblib
# (layer, head) indices of the name-mover attention heads studied in this notebook
NAME_MOVERS = [(9, 6), (9, 9), (10, 0)]

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the model under study; the cell output shows it is gpt2-small loaded
# into a TransformerLens HookedTransformer.
model = get_model()

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-small into HookedTransformer


# Training DAS (Distributed Alignment Search)

In [None]:
### candidate DAS training sites: three hook locations in layer 8, seq_pos=-1
resid_node = Node('resid_post', seq_pos=-1, layer=8)      # 'resid_post' hook
mlp_node = Node('post', seq_pos=-1, layer=8)              # MLP 'post' hook
resid_node_mid = Node('resid_mid', seq_pos=-1, layer=8)   # 'resid_mid' hook

# ------------------------------------------------------------------------------
# Train / test datasets (labels='position', 50 samples per combination)
#   train: base and source patterns both drawn from {ABB, BAB}
#          (presumably every base x source combination)
#   test:  only cross-pattern pairs -- ABB base with BAB source, then
#          BAB base with ABB source, concatenated in that order
# ------------------------------------------------------------------------------
D_train = train_distribution.sample_das(
    model=model,
    labels='position',
    samples_per_combination=50,
    base_patterns=['ABB', 'BAB'],
    source_patterns=['ABB', 'BAB'],
)
_test_abb_base = test_distribution.sample_das(
    model=model,
    labels='position',
    samples_per_combination=50,
    base_patterns=['ABB'],
    source_patterns=['BAB'],
)
_test_bab_base = test_distribution.sample_das(
    model=model,
    labels='position',
    samples_per_combination=50,
    base_patterns=['BAB'],
    source_patterns=['ABB'],
)
D_test = _test_abb_base + _test_bab_base
del _test_abb_base, _test_bab_base  # keep the notebook namespace unchanged

################################################################################
### patchers for different locations
################################################################################
def make_das_patcher(node, n):
    """Return a Patcher that applies a learnable n x n rotation at `node`.

    Args:
        node: the location to patch.
        n: side length of the rotation matrix; the hand-written patchers used
            the activation width at the patched location.

    The rotation matrix is moved to the GPU (`.cuda()`) and applied with
    dim=1, exactly as the previous copy-pasted constructions did.
    """
    return Patcher(
        nodes=[node],
        patch_impl=Rotation(
            rotation=RotationMatrix(n=n).cuda(),
            dim=1,
        ),
    )

# Widths come from the model config instead of gpt2-small's hard-coded
# 3072 (d_mlp) / 768 (d_model). Construction order is unchanged, so the
# rotations are initialised in the same order as before.
das_patcher_mlp = make_das_patcher(mlp_node, n=model.cfg.d_mlp)
das_patcher_resid = make_das_patcher(resid_node, n=model.cfg.d_model)
das_patcher_resid_mid = make_das_patcher(resid_node_mid, n=model.cfg.d_model)

# NOTE(review): the baseline always fully patches `resid_node` (resid_post),
# even when the DAS patcher being trained sits at resid_mid or the MLP --
# confirm it is meant as a fixed reference rather than a same-site baseline.
baseline_patcher = Patcher(
    nodes=[resid_node],
    patch_impl=Full(),
)

################################################################################
### training
################################################################################
# Release cached-but-unused GPU memory held by PyTorch before training.
torch.cuda.empty_cache()
metrics = patch_training(
    model=model,
    # the patcher whose rotation is optimized; swap in das_patcher_mlp or
    # das_patcher_resid to train at a different location
    patcher=das_patcher_resid_mid,
    baseline_patcher=baseline_patcher,
    D_train=D_train,
    D_test=D_test,
    n_epochs=30,
    batch_size=20,
    initial_lr=0.01,
    eval_every=5,
)

In [11]:
# Take column 0 of the learned rotation (from the patcher trained above) as the
# DAS direction, as a NumPy vector on the CPU.
# NOTE(review): whether column 0 or row 0 is the right axis depends on how
# RotationMatrix applies R (x @ W vs. x @ W.T) -- confirm against its forward.
v = (
    das_patcher_resid_mid.patch_impl.rotation.R.weight
    .detach()
    .cpu()
    .numpy()
)[:, 0]
print(v.shape)

(768,)


# Finding other interesting directions (attention-score gradients and mean activation differences)

In [5]:
@batched(args=['prompt_dataset'], n_outputs=1, reducer='cat')
def get_gradients(prompt_dataset: PromptDataset,
                  layer: int, head: int,
                  batch_size: int = 20,
                  ) -> torch.Tensor:
    """Gradient of an attention-score difference w.r.t. layer-8 resid_post.

    For each prompt, the score of head `head` in layer `layer` (read from
    `hook_attn_scores`) from the last query position to key position 3, minus
    its score to key position 5, is summed over the batch and backpropagated
    to `model.blocks[8].hook_resid_post`. The gradient at the last sequence
    position is returned.

    Args:
        prompt_dataset: prompts to run; only `.tokens` is read here.
        layer: layer of the attention head whose scores are differentiated.
        head: index of that head within `layer`.
        batch_size: not used in the body; presumably consumed by `@batched`
            as the chunk size (results are combined with reducer='cat').

    Returns:
        Detached tensor of shape (batch, hidden_size): the gradient at the
        last position.
    """
    activation_container = []

    def forward_hk(_module, _inputs, output):
        # capture the attention scores produced during the forward pass
        activation_container.append(output)

    gradient_container = []

    def backward_hk(_module, _grad_in, grad_out):
        # grad_out[0] is the gradient w.r.t. the hook point's output
        gradient_container.append(grad_out[0])

    fwd_handle = model.blocks[layer].attn.hook_attn_scores.register_forward_hook(forward_hk)
    # register_backward_hook is deprecated in PyTorch; the "full" variant
    # reports gradients w.r.t. the module's actual output.
    bwd_handle = model.blocks[8].hook_resid_post.register_full_backward_hook(backward_hk)

    try:
        model.requires_grad_(True)
        _ = model(prompt_dataset.tokens)
        # assumed (batch, head, query_pos, key_pos) -- TODO confirm
        attn_scores = activation_container[0]
        diff = attn_scores[:, head, -1, 3] - attn_scores[:, head, -1, 5]
        diff.sum().backward()
        grad = gradient_container[0]
    finally:
        fwd_handle.remove()
        bwd_handle.remove()
        model.requires_grad_(False)
        # backward() also wrote .grad into every model parameter; drop it so
        # it neither holds memory nor accumulates across batches and calls.
        model.zero_grad(set_to_none=True)
    # grad is (batch, seq_len, hidden_size); keep only the last position
    return grad[:, -1, :].detach()

def compute_avg_gradient(patching_dataset: PatchingDataset, layer: int, head: int, random_seed: int = 0,
                         batch_size: int = 20):
    """Unit-norm average of `get_gradients` over a dataset's base and source prompts.

    The mean gradient over the base prompts and the mean gradient over the
    source prompts are averaged with equal weight, and the result is scaled
    to unit norm.

    Args:
        patching_dataset: dataset whose `.base` and `.source` prompt sets are used.
        layer: layer of the attention head, forwarded to `get_gradients`.
        head: head index, forwarded to `get_gradients`.
        random_seed: unused; kept so existing calls keep working.
        batch_size: batch size forwarded to `get_gradients` (default 20, as before).

    Returns:
        1-D unit-norm tensor of length hidden_size.
    """
    # Name the halves by role, not by pattern: `base`/`source` need not be
    # pattern-pure (e.g. D_test concatenates ABB->BAB and BAB->ABB pairs).
    base_grad = get_gradients(prompt_dataset=patching_dataset.base, layer=layer, head=head,
                              batch_size=batch_size).mean(dim=0)
    source_grad = get_gradients(prompt_dataset=patching_dataset.source, layer=layer, head=head,
                                batch_size=batch_size).mean(dim=0)
    g = (base_grad + source_grad) / 2
    return g / g.norm()

def get_mean_diff_direction(patching_dataset: PatchingDataset):
    """Unit vector along mean(base activations) - mean(source activations).

    Activations are read at layer-8 `resid_post`, seq_pos=-1, via
    `run_with_cache` with batch_size=100.

    NOTE(review): when base and source contain the same mixture of patterns
    (as in D_test, which concatenates ABB->BAB and BAB->ABB pairs), the two
    means estimate the same quantity and the difference may be mostly noise.
    Confirm the intended dataset.
    """
    site = Node('resid_post', layer=8, seq_pos=-1)

    def mean_activation(prompts):
        # one node requested, so take the first (only) cached entry
        acts = run_with_cache(prompts=prompts, nodes=[site], model=model, batch_size=100)[0]
        return acts.mean(dim=0)

    direction = (mean_activation(patching_dataset.base.prompts)
                 - mean_activation(patching_dataset.source.prompts))
    return direction / direction.norm()

In [6]:
# One unit-norm gradient direction per name-mover head (on D_test); their mean,
# renormalized, gives a single combined direction `g`.
gs = [
    compute_avg_gradient(patching_dataset=D_test, layer=layer, head=head)
    for layer, head in NAME_MOVERS
]
g_mean = torch.stack(gs).mean(dim=0)
g = g_mean / g_mean.norm()
del g_mean

100%|██████████| 5/5 [00:00<00:00,  7.92it/s]
100%|██████████| 5/5 [00:00<00:00,  7.49it/s]
100%|██████████| 5/5 [00:00<00:00,  7.74it/s]
100%|██████████| 5/5 [00:00<00:00,  7.59it/s]
100%|██████████| 5/5 [00:00<00:00,  7.24it/s]
100%|██████████| 5/5 [00:00<00:00,  7.25it/s]
