# Locating vulnerabilities

Now we have a set of misclassified adversarial samples: how can we locate which components are affected?

One idea is to analyze the logit attribution to the correct answer: if a component contributes negatively to the correct answer on the adversarial samples, the vulnerability is likely to be in or near that component.

In essence, we have to:

1. Obtain the contributions of each head to the residual stream.
2. Unembed them to obtain the logits.
3. Compute the logit difference.

In [1]:
import random
from functools import partial
from IPython.display import clear_output
from string import ascii_uppercase

import numpy as np
import einops

from plotly_utils import imshow, line, scatter

import torch

from transformer_lens import HookedTransformer
from transformer_lens import utils, patching

In [2]:
# Load GPT-2 small through TransformerLens with all weight-processing options
# switched on.
# NOTE(review): per the TransformerLens documentation these flags rewrite the
# weights into a functionally equivalent form that is easier to interpret
# (LayerNorm folded into the following linear maps, centred writing/unembed
# weights) — confirm against the installed library version.
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    fold_ln=True,
    center_writing_weights=True,
    center_unembed=True,
    refactor_factored_attn_matrices=True,
)

`torch_dtype` is deprecated! Use `dtype` instead!


Loaded pretrained model gpt2-small into HookedTransformer


In [None]:
# Pin all computation to the CPU: `device` is read by later cells when they
# build tensors, and the model is moved onto it here.
device = torch.device("cpu")
model = model.to(device)
# Report which device was selected.
print(f"Model loaded on {device}")

In [None]:
def topk_of_Nd_tensor(tensor, k):
    '''
    Return the coordinates of the `k` largest entries of a tensor of any rank.

    Behaves like `tensor.topk(k).indices`, except that the search runs over
    every element of the tensor rather than along a single dimension, and each
    hit is reported as a full multi-dimensional index.

    Returns a nested list of shape [k, tensor.ndim], in the order produced by
    `torch.topk` on the flattened tensor.

    Example: for a 2D tensor holding one value per (layer, head), the result
    is a list of [layer, head] pairs.
    '''
    # Positions of the top-k values in the flattened view of the tensor.
    flat_positions = tensor.flatten().topk(k).indices
    # One coordinate array per dimension of the original shape.
    coords_per_dim = np.unravel_index(utils.to_numpy(flat_positions), tensor.shape)
    # Stack so that each row is the full index of a single hit.
    return np.stack(coords_per_dim, axis=-1).tolist()

def get_logit_diff_directions(answer_tokens, pred_answer_tokens):
    """
    Obtains the direction of the logit difference, i.e. it takes the
    unembedding vector of the correct answer and the unembedding vector of the
    most likely incorrect answer and returns their difference.
    This allows us to compute the logit attribution of any residual vector by
    performing a simple dot product.

    The directions are looked up directly with
    `model.tokens_to_residual_directions`, so the function no longer needs a
    hard-coded table of capital-letter token ids, nor the notebook-level
    `device` variable. For tokens that are capital letters the result is
    identical to masking a pre-computed table; tokens outside that table,
    which previously broke the mask + reshape step, are now handled as well.

    Parameters:
    -----------
    - `answer_tokens`: Tensor of shape (batch_size, 3) containing the correct tokens.
    - `pred_answer_tokens`: Tensor of shape (batch_size, 3) containing the most likely incorrect tokens.

    Returns:
    --------
    - `logit_diff_directions`: Tensor of shape (batch_size, 3, d_model) containing the directions of the logit difference.
    """
    # Each call maps a (batch_size, 3) tensor of token ids to the matching
    # (batch_size, 3, d_model) residual-stream directions.
    correct_directions = model.tokens_to_residual_directions(answer_tokens)
    max_incorrect_directions = model.tokens_to_residual_directions(pred_answer_tokens)

    logit_diff_directions = correct_directions - max_incorrect_directions  # (batch_size, 3, d_model)

    return logit_diff_directions

Moving model to device:  cpu


In [4]:
def residual_stack_to_logit_diff(residual_stack, cache, logit_diff_directions):
    '''
    Projects every component of a residual-stream stack onto the per-sample
    logit-difference directions.

    The stack is first scaled with `cache.apply_ln_to_stack(..., layer=-1,
    pos_slice=-1)` and then dotted with `logit_diff_directions` along the
    `d_model` axis.

    Parameters:
    -----------
    - `residual_stack`: Tensor of shape (..., batch, d_model) with the component outputs.
    - `cache`: Activation cache providing `apply_ln_to_stack`.
    - `logit_diff_directions`: Tensor of shape (batch, d_model).

    Returns:
    --------
    - Tensor of shape (..., batch) with one logit difference per component and
      per sample. No averaging is done here; callers that want the mean over
      the batch must reduce the last axis themselves.
    '''
    scaled_residual_stack = cache.apply_ln_to_stack(residual_stack, layer=-1, pos_slice=-1)
    return einops.einsum(
        scaled_residual_stack, logit_diff_directions,
        "... batch d_model, batch d_model -> ... batch"
    )

In [5]:
# Token ids of the 26 capital letters, first bare ("A") and then preceded by a
# space (" A"); only the first token of each encoded string is kept.
cap_tokens = model.to_tokens(list(ascii_uppercase), prepend_bos=False)[:, 0]
cap_tokens_space = model.to_tokens([" " + ch for ch in ascii_uppercase], prepend_bos=False)[:, 0]

# The capital letters form a small sub-vocabulary, so they are also numbered
# 0..25 in alphabetical order (e.g. 'A' is token 32 but index 0 in the subspace).
idx_to_token = dict(enumerate(cap_tokens.tolist()))
token_to_idx = {token: idx for idx, token in idx_to_token.items()}

# Translation tables between the spaced and the bare variant of each letter.
space_to_no_space = dict(zip(cap_tokens_space.tolist(), cap_tokens.tolist()))
no_space_to_space = dict(zip(cap_tokens.tolist(), cap_tokens_space.tolist()))


# Positions of the tokens holding the first/second/third capital letter.
indices_letters = list(range(2, 5))
# Positions of the acronym letters shifted by -1, i.e. where the logit that
# predicts each acronym letter is stored.
indices_logits = list(range(5, 8))
# Which acronym letter is analysed: 0, 1 or 2.
letter = 2

In [6]:
# Each line of the dataset has the form "<prompt>, <acronym>".
with open("data/2_adv_acronyms.txt", "r") as f:
    prompt_acronym_pairs = [line.split(", ") for line in f.read().splitlines()]

adv_letter = "A"

# Keep only the samples whose word at position `indices_letters[letter]` starts
# with `adv_letter` (character 0 of the token string is skipped, character 1 is
# compared). Prompts and acronyms are filtered TOGETHER: filtering the prompts
# alone would leave every surviving prompt paired with an unrelated acronym
# once the two sequences are zipped again.
filtered_pairs = [
    (prompt, acronym)
    for prompt, acronym, *_ in prompt_acronym_pairs
    if adv_letter == model.to_str_tokens(prompt)[indices_letters[letter]][1]
]
if not filtered_pairs:
    raise ValueError(
        f"No prompt in the dataset has {adv_letter!r} as letter {letter}"
    )

# take a subset of the dataset (we do this because VRAM limitations)
n_samples = len(filtered_pairs)
# `random.choices` draws WITH replacement, so the same sample may appear more
# than once in the result.
prompts, acronyms = map(list, zip(*random.choices(filtered_pairs, k=n_samples)))

In [7]:
device = torch.device("cpu")
model = model.to(device)

tokens = model.to_tokens(prompts)
answer_tokens = model.to_tokens(acronyms, prepend_bos=False).to(device)

# A single forward pass yields both the logits and the activation cache;
# running the model a second time only to recompute the logits is unnecessary.
logits, cache = model.run_with_cache(tokens)

# Logits of the correct answers (batch_size, 3). Only the gathered axis is
# squeezed so that a batch of size 1 keeps its batch dimension.
correct_logits = logits[:, -3:].gather(-1, answer_tokens[..., None]).squeeze(-1)

# Retrieve the maximum logit of the possible incorrect answers
capital_letters_tokens = torch.tensor([32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49,
        50, 51, 52, 53, 54, 55, 56, 57], dtype=torch.long, device=device)
batch_size = logits.shape[0]
capital_letters_tokens_expanded = capital_letters_tokens.expand(batch_size, 3, -1)
# For every (sample, position): the 25 capital-letter tokens that are NOT the
# correct answer -> (batch_size, 3, 25)
incorrect_capital_letters = capital_letters_tokens_expanded[capital_letters_tokens_expanded != answer_tokens[..., None]].reshape(batch_size, 3, -1)
# The argmax below is an index into the 25-entry list of incorrect letters,
# NOT into the 26-letter alphabet, so it must be mapped back through
# `incorrect_capital_letters`. Mapping it through the alphabet instead is off
# by one for every letter at or after the correct one and can even return the
# correct answer itself.
top_incorrect_idx = logits[:, -3:].gather(-1, incorrect_capital_letters).argmax(-1, keepdim=True)
pred_answer_tokens = incorrect_capital_letters.gather(-1, top_incorrect_idx).squeeze(-1).cpu()

Moving model to device:  cpu


In [8]:
# Logit-difference direction of the analysed letter only: selecting `letter`
# on axis 1 removes the position axis, leaving (batch_size, d_model).
logit_diff_directions = get_logit_diff_directions(answer_tokens, pred_answer_tokens)[:, letter]

# Output of every attention head at the last position, stacked along a single
# leading (layer * head) axis.
# NOTE(review): `pos_slice=-1` always reads the final position, which matches
# `indices_logits[letter]` only for the last acronym letter — confirm before
# changing `letter`.
per_head_residual, labels = cache.stack_head_results(
    return_labels=True,
    layer=-1,
    pos_slice=-1,
)
# Split the stacked axis back into separate layer and head axes.
per_head_residual = einops.rearrange(
    per_head_residual,
    "(layer head) ... -> layer head ...",
    layer=model.cfg.n_layers,
)
# Per-sample logit difference of every head, averaged over the batch.
per_head_logit_diffs = residual_stack_to_logit_diff(
    per_head_residual,
    cache,
    logit_diff_directions,
).mean(dim=-1)

imshow(
    per_head_logit_diffs,
    title="Logit Difference From Each Head",
    labels={"x":"Head", "y":"Layer"},
    width=600,
)

Tried to stack head results when they weren't cached. Computing head results now
